@@ -129,8 +129,8 @@ def forward_kernel(
     q_ptrs = (
         Q +
         off_b * stride_qb +
-        offs_qh[:, None, None] * stride_qh +
-        offs_m[None, :, None] * stride_qm +
+        offs_qh[None, :, None] * stride_qh +
+        offs_m[:, None, None] * stride_qm +
         offs_d[None, None, :]
     )

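Aside (not part of the commit): this hunk swaps the query pointer grid from a (head-group, query, dim) layout to a (query, head-group, dim) layout. A quick NumPy sketch with made-up sizes and strides, just to show the two index grids address exactly the same elements with the first two axes transposed:

import numpy as np

# hypothetical sizes and strides, for illustration only
BLOCK, QUERY_HEAD_GROUPS, BLOCK_HEADDIM = 4, 2, 8
stride_qh, stride_qm = 1024, 64

offs_qh = np.arange(QUERY_HEAD_GROUPS)
offs_m = np.arange(BLOCK)
offs_d = np.arange(BLOCK_HEADDIM)

# old layout: (head-group, query, dim)
old = offs_qh[:, None, None] * stride_qh + offs_m[None, :, None] * stride_qm + offs_d[None, None, :]
# new layout: (query, head-group, dim)
new = offs_qh[None, :, None] * stride_qh + offs_m[:, None, None] * stride_qm + offs_d[None, None, :]

assert np.array_equal(new, old.transpose(1, 0, 2))  # same addresses, axes 0 and 1 swapped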
@@ -152,48 +152,56 @@ def forward_kernel(

     # maximum

-    m_i = tl.zeros([BLOCK * QUERY_HEAD_GROUPS], dtype = tl.float32) - float("inf")
+    m_i = tl.zeros([BLOCK, QUERY_HEAD_GROUPS], dtype = tl.float32) - float("inf")

     # lse

     lse_ptrs = (
         Lse +
         off_b * stride_lse_b +
-        offs_qh[:, None] * seqlen_q_rounded +
-        offs_m[None, :]
+        offs_qh[None, :] * seqlen_q_rounded +
+        offs_m[:, None]
     )

-    lse_i = tl.zeros([BLOCK * QUERY_HEAD_GROUPS], dtype = tl.float32) - float("inf")
+    lse_i = tl.zeros([BLOCK, QUERY_HEAD_GROUPS], dtype = tl.float32) - float("inf")

     # output

     out_ptrs = (
         Out +
         off_b * stride_ob +
-        offs_qh[:, None, None] * stride_oh +
-        offs_m[None, :, None] * stride_om +
+        offs_qh[None, :, None] * stride_oh +
+        offs_m[:, None, None] * stride_om +
         offs_d[None, None, :]
     )

-    acc_o = tl.zeros([QUERY_HEAD_GROUPS * BLOCK, BLOCK_HEADDIM], dtype = tl.float32)
+    acc_o = tl.zeros([BLOCK, QUERY_HEAD_GROUPS, BLOCK_HEADDIM], dtype = tl.float32)

     # load queries, keys, values

     if EVEN_M & EVEN_N:
         if EVEN_HEADDIM:
             q = tl.load(q_ptrs)
         else:
-            q = tl.load(q_ptrs, mask = offs_d[None, None, :] < headdim, other = 0.0)
+            q = tl.load(
+                q_ptrs,
+                mask = offs_d[None, None, :] < headdim,
+                other = 0.0
+            )
     else:
         if EVEN_HEADDIM:
-            q = tl.load(q_ptrs, mask = offs_m[None, :, None] < seqlen_q, other = 0.0)
+            q = tl.load(
+                q_ptrs,
+                mask = offs_m[:, None, None] < seqlen_q,
+                other = 0.0
+            )
         else:
             q = tl.load(
-                q_ptrs, mask = (offs_m[None, :, None] < seqlen_q) & (offs_d[None, None, :] < headdim), other = 0.0
+                q_ptrs,
+                mask = (offs_m[:, None, None] < seqlen_q) & (offs_d[None, None, :] < headdim),
+                other = 0.0
             )

-    q = q.reshape([QUERY_HEAD_GROUPS * BLOCK, BLOCK_HEADDIM])
-
     if EVEN_N & EVEN_M:
         if EVEN_HEADDIM:
             k = tl.load(k_ptrs)
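Aside (not part of the commit): the running statistics m_i / lse_i and the output accumulator acc_o now carry an explicit (BLOCK, QUERY_HEAD_GROUPS, ...) shape instead of a flattened leading dimension, matching the reordered pointer grids above. The reformatted tl.load calls keep the usual boundary handling; in NumPy terms, with made-up sizes, a masked load with other = 0.0 amounts to:

import numpy as np

BLOCK, BLOCK_HEADDIM = 4, 8
seqlen_q, headdim = 3, 5  # made-up ragged sizes
q_full = np.random.randn(BLOCK, BLOCK_HEADDIM)  # what an unmasked tile read would return

offs_m = np.arange(BLOCK)
offs_d = np.arange(BLOCK_HEADDIM)
mask = (offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim)

# rough equivalent of tl.load(q_ptrs, mask = mask, other = 0.0):
# out-of-range query rows and head-dim lanes come back as zeros
q = np.where(mask, q_full, 0.0)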
@@ -203,65 +211,75 @@ def forward_kernel(
         if EVEN_HEADDIM:
             k = tl.load(
                 k_ptrs,
-                mask = offs_n[:, None] < seqlen_k,
-                other = 0.0,
+                mask = offs_n[:, None] < seqlen_k,
+                other = 0.0,
             )
         else:
             k = tl.load(
                 k_ptrs,
-                mask = (offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
-                other = 0.0,
+                mask = (offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
+                other = 0.0,
             )

-    qk = tl.zeros([QUERY_HEAD_GROUPS * BLOCK, BLOCK], dtype = tl.float32)
+    qk = tl.zeros([BLOCK * QUERY_HEAD_GROUPS, BLOCK], dtype = tl.float32)
+
+    q = q.reshape(BLOCK * QUERY_HEAD_GROUPS, BLOCK_HEADDIM)
+
     qk += tl.dot(q, tl.trans(k))

+    qk = qk.reshape(BLOCK, QUERY_HEAD_GROUPS, BLOCK)
+
     if not EVEN_N:
         qk += tl.where(offs_n[None, :] < seqlen_k, 0, float("-inf"))

-    qk = qk.reshape([QUERY_HEAD_GROUPS, BLOCK, BLOCK])
-
-    qk += tl.where(offs_m[:, None] >= offs_n[None, :], 0, float("-inf"))
+    qk = qk.reshape(BLOCK, QUERY_HEAD_GROUPS, BLOCK)

-    qk = qk.reshape([QUERY_HEAD_GROUPS * BLOCK, BLOCK])
+    qk += tl.where(offs_m[:, None, None] >= offs_n[None, None, :], 0, float("-inf"))

-    m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i)
-    p = tl.exp(qk * softmax_scale - m_ij[:, None])
+    m_ij = tl.maximum(tl.max(qk, 2) * softmax_scale, lse_i)
+    p = tl.exp(qk * softmax_scale - m_ij[:, :, None])

-    l_ij = tl.sum(p, 1)
+    l_ij = tl.sum(p, 2)

     acc_o_scale = tl.exp(m_i - m_ij)
-    acc_o *= acc_o_scale[:, None]
+    acc_o *= acc_o_scale[:, :, None]

     if EVEN_N & EVEN_M:
         if EVEN_HEADDIM:
             v = tl.load(v_ptrs)
         else:
-            v = tl.load(v_ptrs, mask = offs_d[None, :] < headdim, other = 0.0)
+            v = tl.load(
+                v_ptrs,
+                mask = offs_d[None, :] < headdim,
+                other = 0.0
+            )
     else:
         if EVEN_HEADDIM:
             v = tl.load(
                 v_ptrs,
-                mask = offs_n[:, None] < seqlen_k,
-                other = 0.0,
+                mask = offs_n[:, None] < seqlen_k,
+                other = 0.0,
             )
         else:
             v = tl.load(
                 v_ptrs,
-                mask = (offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
-                other = 0.0,
+                mask = (offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
+                other = 0.0,
             )

-    p = p.to(v.dtype)
-    acc_o += tl.dot(p, v)
+    p = p.reshape(BLOCK * QUERY_HEAD_GROUPS, BLOCK).to(v.dtype)
+
+    causal_o = tl.dot(p, v)
+
+    acc_o += causal_o.reshape(BLOCK, QUERY_HEAD_GROUPS, BLOCK_HEADDIM)

     # -- update statistics

     m_i = m_ij
     l_i_new = tl.exp(lse_i - m_ij) + l_ij
     lse_i = m_ij + tl.log(l_i_new)

-    # take care of the selected kv blocks
+    # # take care of the selected kv blocks

     kv_block_indices_ptrs = (
         kv_block_indices +
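Aside (not part of the commit): with the query axis leading and the head-group axis in the middle, the causal term in the hunk above broadcasts over the group axis unchanged. A small NumPy sketch with made-up sizes:

import numpy as np

BLOCK, QUERY_HEAD_GROUPS = 4, 2  # made-up sizes
offs_m = np.arange(BLOCK)  # query positions within the diagonal block
offs_n = np.arange(BLOCK)  # key positions within the same block

bias = np.where(offs_m[:, None, None] >= offs_n[None, None, :], 0.0, -np.inf)  # (BLOCK, 1, BLOCK)
qk = np.zeros((BLOCK, QUERY_HEAD_GROUPS, BLOCK)) + bias  # broadcasts over the group axis

assert np.isneginf(qk[0, 0, 1])  # query 0 cannot attend to key 1
assert qk[1, 0, 1] == 0.0        # query 1 can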
@@ -277,8 +295,7 @@ def forward_kernel(
         offs_m * stride_kvbl_m
     )

-    q = q.reshape(QUERY_HEAD_GROUPS, BLOCK, BLOCK_HEADDIM)
-    q = q.permute((1, 0, 2))
+    q = q.reshape(BLOCK, QUERY_HEAD_GROUPS, BLOCK_HEADDIM)
     q = tl.expand_dims(q, 2)
     q = tl.broadcast_to(q, (BLOCK, QUERY_HEAD_GROUPS, QUERY_EXPAND_DIM, BLOCK_HEADDIM))
     q = q.reshape(BLOCK, 16, BLOCK_HEADDIM)
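Aside (not part of the commit): the expand_dims / broadcast_to / reshape dance above (and the matching tl.sum(..., 2) / QUERY_EXPAND_DIM later) appears to exist so the grouped-query dot products hand tl.dot a middle dimension of at least 16: each head-group row is replicated QUERY_EXPAND_DIM times, multiplied, then averaged back down, which is mathematically a no-op. A NumPy sketch with made-up sizes:

import numpy as np

BLOCK, GROUPS, EXPAND, HEADDIM = 4, 2, 8, 16  # made-up sizes, GROUPS * EXPAND == 16
q = np.random.randn(BLOCK, GROUPS, HEADDIM)
k = np.random.randn(BLOCK, HEADDIM, BLOCK)  # per-query key block, already transposed

direct = np.einsum('bgd,bdn->bgn', q, k)

# replicate each group row EXPAND times, batched matmul, then average the identical copies
q_exp = np.broadcast_to(q[:, :, None, :], (BLOCK, GROUPS, EXPAND, HEADDIM)).reshape(BLOCK, GROUPS * EXPAND, HEADDIM)
expanded = np.einsum('bed,bdn->ben', q_exp, k).reshape(BLOCK, GROUPS, EXPAND, BLOCK)
recovered = expanded.sum(axis=2) / EXPAND

assert np.allclose(direct, recovered)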
@@ -290,11 +307,19 @@ def forward_kernel(
     blocks_offs_n = block_indices[:, None] * BLOCK + tl.arange(0, BLOCK)[None, :]

     block_k_ptrs = (
-        K + off_b * stride_kb + off_h * stride_kh + (blocks_offs_n[:, :, None] * stride_kn + offs_d[None, None, :])
+        K +
+        off_b * stride_kb +
+        off_h * stride_kh +
+        blocks_offs_n[:, :, None] * stride_kn +
+        offs_d[None, None, :]
     )

     block_v_ptrs = (
-        V + off_b * stride_vb + off_h * stride_vh + (blocks_offs_n[:, :, None] * stride_vn + offs_d[None, None, :])
+        V +
+        off_b * stride_vb +
+        off_h * stride_vh +
+        blocks_offs_n[:, :, None] * stride_vn +
+        offs_d[None, None, :]
     )

     # load k of shape (m, n, d), sparsely selected by each query
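Aside (not part of the commit): the block_k_ptrs / block_v_ptrs arithmetic gathers, for each query row, the BLOCK keys (or values) of the kv block that row selected. Roughly, in NumPy with made-up shapes:

import numpy as np

BLOCK, BLOCK_HEADDIM, num_kv_blocks = 4, 8, 6  # made-up sizes
k = np.random.randn(num_kv_blocks * BLOCK, BLOCK_HEADDIM)  # keys for one (batch, head)
block_indices = np.random.randint(0, num_kv_blocks, size=BLOCK)  # one selected block per query row

blocks_offs_n = block_indices[:, None] * BLOCK + np.arange(BLOCK)[None, :]  # (BLOCK, BLOCK) row indices
k_block = k[blocks_offs_n]  # (BLOCK, BLOCK, BLOCK_HEADDIM): (query, key-within-block, dim)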
@@ -304,50 +329,44 @@ def forward_kernel(
     # similarities

     block_qk = tl.zeros([BLOCK, 16, BLOCK], dtype = tl.float32)
-    qk = tl.zeros([QUERY_HEAD_GROUPS, BLOCK, BLOCK], dtype = tl.float32)
+    qk = tl.zeros([BLOCK, QUERY_HEAD_GROUPS, BLOCK], dtype = tl.float32)

     k_block = k_block.reshape(BLOCK, BLOCK, BLOCK_HEADDIM)
     k_block = k_block.permute(0, 2, 1)

-    block_qk = tl.dot(q, k_block)
+    block_qk += tl.dot(q, k_block)
     block_qk = block_qk.reshape(BLOCK, QUERY_HEAD_GROUPS, QUERY_EXPAND_DIM, BLOCK)
     block_qk = tl.sum(block_qk, 2) / QUERY_EXPAND_DIM
-    block_qk = block_qk.permute(1, 0, 2)

     qk += block_qk
-    qk += tl.where(block_masks[:, None], 0, float("-inf"))
-
-    qk = qk.reshape(QUERY_HEAD_GROUPS * BLOCK, BLOCK)
+    qk += tl.where(block_masks[:, None, None], 0, float("-inf"))

     # attention

-    m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i)
-    p = tl.exp(qk * softmax_scale - m_ij[:, None])
+    m_ij = tl.maximum(tl.max(qk, 2) * softmax_scale, lse_i)
+    block_p = tl.exp(qk * softmax_scale - m_ij[:, :, None])

-    l_ij = tl.sum(p, 1)
+    l_ij = tl.sum(block_p, 2)

     # renormalize the running output

     acc_o_scale = tl.exp(m_i - m_ij)
-    acc_o = acc_o * acc_o_scale[:, None]
+    acc_o = acc_o * acc_o_scale[:, :, None]

     # aggregate values

     v_block = tl.load(block_v_ptrs)
     v_block = tl.reshape(v_block, (BLOCK, BLOCK, BLOCK_HEADDIM))

-    p = p.to(v_block.dtype)
-    p_expanded = p.reshape(QUERY_HEAD_GROUPS, BLOCK, BLOCK)
-    p_expanded = p_expanded.permute(1, 0, 2)
+    block_p = block_p.to(v_block.dtype)
+    p_expanded = block_p.reshape(BLOCK, QUERY_HEAD_GROUPS, BLOCK)
     p_expanded = tl.expand_dims(p_expanded, 2)
     p_expanded = tl.broadcast_to(p_expanded, (BLOCK, QUERY_HEAD_GROUPS, QUERY_EXPAND_DIM, BLOCK))
     p_expanded = p_expanded.reshape(BLOCK, 16, BLOCK)

     block_acc_o = tl.dot(p_expanded, v_block)
     block_acc_o = block_acc_o.reshape(BLOCK, QUERY_HEAD_GROUPS, QUERY_EXPAND_DIM, BLOCK_HEADDIM)
     block_acc_o = tl.sum(block_acc_o, 2) / QUERY_EXPAND_DIM
-    block_acc_o = block_acc_o.permute(1, 0, 2)
-    block_acc_o = block_acc_o.reshape(QUERY_HEAD_GROUPS * BLOCK, BLOCK_HEADDIM)

     acc_o += block_acc_o

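Aside (not part of the commit): both passes above, the causal diagonal block and the selected kv blocks, share the same running-softmax bookkeeping: m_i tracks the running max, lse_i the running log-sum-exp, and acc_o is rescaled by exp(m_i - m_ij) before each new contribution and by exp(m_i - lse_i) at the end. A NumPy sketch (softmax_scale folded into the scores, made-up sizes) showing that this recurrence reproduces full softmax attention:

import numpy as np

def online_softmax_attn(qk_chunks, v_chunks):
    dim = v_chunks[0].shape[-1]
    m_i, lse_i = -np.inf, -np.inf
    acc_o = np.zeros(dim)
    for qk, v in zip(qk_chunks, v_chunks):  # qk: (n,) scores, v: (n, dim) values per chunk
        m_ij = max(qk.max(), lse_i)
        p = np.exp(qk - m_ij)
        l_ij = p.sum()
        acc_o = acc_o * np.exp(m_i - m_ij) + p @ v  # rescale old accumulator, add this chunk
        m_i = m_ij
        lse_i = m_ij + np.log(np.exp(lse_i - m_ij) + l_ij)
    return acc_o * np.exp(m_i - lse_i)  # final normalization

np.random.seed(0)
qk = np.random.randn(2, 8)  # two chunks of attention scores
v = np.random.randn(2, 8, 4)

w = np.exp(qk.reshape(-1))
reference = (w / w.sum()) @ v.reshape(-1, 4)
assert np.allclose(online_softmax_attn(list(qk), list(v)), reference)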
@@ -360,28 +379,38 @@ def forward_kernel(
     # normalize accumulated out

     acc_o_scale = tl.exp(m_i - lse_i)
-    acc_o *= acc_o_scale[:, None]
+    acc_o *= acc_o_scale[:, :, None]

     # write back lse

-    lse_i = lse_i.reshape([QUERY_HEAD_GROUPS, BLOCK])
-    tl.store(lse_ptrs, lse_i, mask = offs_m[None, :] < seqlen_q)
+    lse_i = lse_i.reshape(BLOCK, QUERY_HEAD_GROUPS)
+    tl.store(lse_ptrs, lse_i, mask = offs_m[:, None] < seqlen_q)

     # write to output

-    acc_o = acc_o.reshape([QUERY_HEAD_GROUPS, BLOCK, BLOCK_HEADDIM])
+    acc_o = acc_o.reshape(BLOCK, QUERY_HEAD_GROUPS, BLOCK_HEADDIM)

     if EVEN_M:
         if EVEN_HEADDIM:
             tl.store(out_ptrs, acc_o)
         else:
-            tl.store(out_ptrs, acc_o, mask = offs_d[None, None, :] < headdim)
+            tl.store(
+                out_ptrs,
+                acc_o,
+                mask = offs_d[None, None, :] < headdim
+            )
     else:
         if EVEN_HEADDIM:
-            tl.store(out_ptrs, acc_o, mask = offs_m[None, :, None] < seqlen_q)
+            tl.store(
+                out_ptrs,
+                acc_o,
+                mask = offs_m[:, None, None] < seqlen_q
+            )
         else:
             tl.store(
-                out_ptrs, acc_o, mask = (offs_m[None, :, None] < seqlen_q) & (offs_d[None, None, :] < headdim)
+                out_ptrs,
+                acc_o,
+                mask = (offs_m[:, None, None] < seqlen_q) & (offs_d[None, None, :] < headdim)
             )

 def native_sparse_attn_forward(