Skip to content

Commit 946b1f6

Browse files
CUDA: fix pointer incrementation in FA (#14916)
1 parent 6c6e397 commit 946b1f6

File tree

2 files changed

+8
-10
lines changed

2 files changed

+8
-10
lines changed

ggml/src/ggml-cuda/fattn-vec-f16.cuh

Lines changed: 4 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -174,7 +174,10 @@ static __global__ void flash_attn_vec_ext_f16(
174174
K += blockIdx.y*D * nb11;
175175
V += blockIdx.y*D * nb21;
176176
maskh += blockIdx.y*D;
177-
for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) {
177+
for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D,
178+
// Increment pointers after each loop:
179+
K += gridDim.y*D*nb11, V += gridDim.y*D*nb21, maskh += gridDim.y*D) {
180+
178181
// Calculate KQ tile and keep track of new maximum KQ values:
179182

180183
if (mask) {
@@ -291,10 +294,6 @@ static __global__ void flash_attn_vec_ext_f16(
291294
}
292295
}
293296

294-
K += gridDim.y*D * nb11;
295-
V += gridDim.y*D * nb21;
296-
maskh += gridDim.y*D;
297-
298297
__syncthreads();
299298
}
300299

ggml/src/ggml-cuda/fattn-vec-f32.cuh

Lines changed: 4 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -180,7 +180,10 @@ static __global__ void flash_attn_vec_ext_f32(
180180
K += blockIdx.y*D * nb11;
181181
V += blockIdx.y*D * nb21;
182182
maskh += blockIdx.y*D;
183-
for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) {
183+
for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D,
184+
// Increment pointers after each loop:
185+
K += gridDim.y*D*nb11, V += gridDim.y*D*nb21, maskh += gridDim.y*D) {
186+
184187
// Calculate KQ tile and keep track of new maximum KQ values:
185188

186189
if (mask) {
@@ -286,10 +289,6 @@ static __global__ void flash_attn_vec_ext_f32(
286289
}
287290
}
288291

289-
K += gridDim.y*D * nb11;
290-
V += gridDim.y*D * nb21;
291-
maskh += gridDim.y*D;
292-
293292
__syncthreads();
294293
}
295294

0 commit comments

Comments
 (0)