File tree Expand file tree Collapse file tree 2 files changed +8
-10
lines changed Expand file tree Collapse file tree 2 files changed +8
-10
lines changed Original file line number Diff line number Diff line change @@ -174,7 +174,10 @@ static __global__ void flash_attn_vec_ext_f16(
174
174
K += blockIdx .y *D * nb11;
175
175
V += blockIdx .y *D * nb21;
176
176
maskh += blockIdx .y *D;
177
- for (int k_VKQ_0 = blockIdx .y *D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim .y *D) {
177
+ for (int k_VKQ_0 = blockIdx .y *D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim .y *D,
178
+ // Increment pointers after each loop:
179
+ K += gridDim .y *D*nb11, V += gridDim .y *D*nb21, maskh += gridDim .y *D) {
180
+
178
181
// Calculate KQ tile and keep track of new maximum KQ values:
179
182
180
183
if (mask) {
@@ -291,10 +294,6 @@ static __global__ void flash_attn_vec_ext_f16(
291
294
}
292
295
}
293
296
294
- K += gridDim .y *D * nb11;
295
- V += gridDim .y *D * nb21;
296
- maskh += gridDim .y *D;
297
-
298
297
__syncthreads ();
299
298
}
300
299
Original file line number Diff line number Diff line change @@ -180,7 +180,10 @@ static __global__ void flash_attn_vec_ext_f32(
180
180
K += blockIdx .y *D * nb11;
181
181
V += blockIdx .y *D * nb21;
182
182
maskh += blockIdx .y *D;
183
- for (int k_VKQ_0 = blockIdx .y *D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim .y *D) {
183
+ for (int k_VKQ_0 = blockIdx .y *D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim .y *D,
184
+ // Increment pointers after each loop:
185
+ K += gridDim .y *D*nb11, V += gridDim .y *D*nb21, maskh += gridDim .y *D) {
186
+
184
187
// Calculate KQ tile and keep track of new maximum KQ values:
185
188
186
189
if (mask) {
@@ -286,10 +289,6 @@ static __global__ void flash_attn_vec_ext_f32(
286
289
}
287
290
}
288
291
289
- K += gridDim .y *D * nb11;
290
- V += gridDim .y *D * nb21;
291
- maskh += gridDim .y *D;
292
-
293
292
__syncthreads ();
294
293
}
295
294
You can’t perform that action at this time.
0 commit comments