@@ -112,7 +112,8 @@ def forward_kernel(
112
112
BLOCK : tl .constexpr ,
113
113
QUERY_HEAD_GROUPS : tl .constexpr ,
114
114
QUERY_EXPAND_DIM : tl .constexpr ,
115
- NUM_SEL_KV_BLOCKS : tl .constexpr
115
+ NUM_SEL_KV_BLOCKS : tl .constexpr ,
116
+ INCLUDE_BLOCK_CAUSAL : tl .constexpr
116
117
):
117
118
start_m = tl .program_id (0 )
118
119
off_hb = tl .program_id (1 )
@@ -134,22 +135,6 @@ def forward_kernel(
134
135
offs_d [None , None , :]
135
136
)
136
137
137
- k_ptrs = (
138
- K +
139
- off_b * stride_kb +
140
- off_h * stride_kh +
141
- offs_n [:, None ] * stride_kn +
142
- offs_d [None , :]
143
- )
144
-
145
- v_ptrs = (
146
- V +
147
- off_b * stride_vb +
148
- off_h * stride_vh +
149
- offs_n [:, None ] * stride_vn +
150
- offs_d [None , :]
151
- )
152
-
153
138
# maximum
154
139
155
140
m_i = tl .zeros ([BLOCK , QUERY_HEAD_GROUPS ], dtype = tl .float32 ) - float ("inf" )
@@ -202,82 +187,99 @@ def forward_kernel(
202
187
other = 0.0
203
188
)
204
189
205
- if EVEN_N & EVEN_M :
206
- if EVEN_HEADDIM :
207
- k = tl .load (k_ptrs )
208
- else :
209
- k = tl .load (k_ptrs , mask = offs_d [None , :] < headdim , other = 0.0 )
210
- else :
211
- if EVEN_HEADDIM :
212
- k = tl .load (
213
- k_ptrs ,
214
- mask = offs_n [:, None ] < seqlen_k ,
215
- other = 0.0 ,
216
- )
217
- else :
218
- k = tl .load (
219
- k_ptrs ,
220
- mask = (offs_n [:, None ] < seqlen_k ) & (offs_d [None , :] < headdim ),
221
- other = 0.0 ,
222
- )
190
+ if INCLUDE_BLOCK_CAUSAL :
191
+ k_ptrs = (
192
+ K +
193
+ off_b * stride_kb +
194
+ off_h * stride_kh +
195
+ offs_n [:, None ] * stride_kn +
196
+ offs_d [None , :]
197
+ )
223
198
224
- qk = tl .zeros ([BLOCK * QUERY_HEAD_GROUPS , BLOCK ], dtype = tl .float32 )
199
+ v_ptrs = (
200
+ V +
201
+ off_b * stride_vb +
202
+ off_h * stride_vh +
203
+ offs_n [:, None ] * stride_vn +
204
+ offs_d [None , :]
205
+ )
225
206
226
- q = q .reshape (BLOCK * QUERY_HEAD_GROUPS , BLOCK_HEADDIM )
207
+ if EVEN_N & EVEN_M :
208
+ if EVEN_HEADDIM :
209
+ k = tl .load (k_ptrs )
210
+ else :
211
+ k = tl .load (k_ptrs , mask = offs_d [None , :] < headdim , other = 0.0 )
212
+ else :
213
+ if EVEN_HEADDIM :
214
+ k = tl .load (
215
+ k_ptrs ,
216
+ mask = offs_n [:, None ] < seqlen_k ,
217
+ other = 0.0 ,
218
+ )
219
+ else :
220
+ k = tl .load (
221
+ k_ptrs ,
222
+ mask = (offs_n [:, None ] < seqlen_k ) & (offs_d [None , :] < headdim ),
223
+ other = 0.0 ,
224
+ )
227
225
228
- qk + = tl .dot ( q , tl .trans ( k ) )
226
+ qk = tl .zeros ([ BLOCK * QUERY_HEAD_GROUPS , BLOCK ], dtype = tl .float32 )
229
227
230
- qk = qk .reshape (BLOCK , QUERY_HEAD_GROUPS , BLOCK )
228
+ q = q .reshape (BLOCK * QUERY_HEAD_GROUPS , BLOCK_HEADDIM )
231
229
232
- if not EVEN_N :
233
- qk += tl .where (offs_n [None , :] < seqlen_k , 0 , float ("-inf" ))
230
+ qk += tl .dot (q , tl .trans (k ))
234
231
235
- qk = qk .reshape (BLOCK , QUERY_HEAD_GROUPS , BLOCK )
232
+ qk = qk .reshape (BLOCK , QUERY_HEAD_GROUPS , BLOCK )
236
233
237
- qk += tl .where (offs_m [:, None , None ] >= offs_n [None , None , :], 0 , float ("-inf" ))
234
+ if not EVEN_N :
235
+ qk += tl .where (offs_n [None , :] < seqlen_k , 0 , float ("-inf" ))
238
236
239
- m_ij = tl .maximum (tl .max (qk , 2 ) * softmax_scale , lse_i )
240
- p = tl .exp (qk * softmax_scale - m_ij [:, :, None ])
237
+ qk = qk .reshape (BLOCK , QUERY_HEAD_GROUPS , BLOCK )
241
238
242
- l_ij = tl .sum ( p , 2 )
239
+ qk + = tl .where ( offs_m [:, None , None ] >= offs_n [ None , None , :], 0 , float ( "-inf" ) )
243
240
244
- acc_o_scale = tl .exp ( m_i - m_ij )
245
- acc_o *= acc_o_scale [:, :, None ]
241
+ m_ij = tl .maximum ( tl . max ( qk , 2 ) * softmax_scale , lse_i )
242
+ p = tl . exp ( qk * softmax_scale - m_ij [:, :, None ])
246
243
247
- if EVEN_N & EVEN_M :
248
- if EVEN_HEADDIM :
249
- v = tl .load (v_ptrs )
250
- else :
251
- v = tl .load (
252
- v_ptrs ,
253
- mask = offs_d [None , :] < headdim ,
254
- other = 0.0
255
- )
256
- else :
257
- if EVEN_HEADDIM :
258
- v = tl .load (
259
- v_ptrs ,
260
- mask = offs_n [:, None ] < seqlen_k ,
261
- other = 0.0 ,
262
- )
244
+ l_ij = tl .sum (p , 2 )
245
+
246
+ acc_o_scale = tl .exp (m_i - m_ij )
247
+ acc_o *= acc_o_scale [:, :, None ]
248
+
249
+ if EVEN_N & EVEN_M :
250
+ if EVEN_HEADDIM :
251
+ v = tl .load (v_ptrs )
252
+ else :
253
+ v = tl .load (
254
+ v_ptrs ,
255
+ mask = offs_d [None , :] < headdim ,
256
+ other = 0.0
257
+ )
263
258
else :
264
- v = tl .load (
265
- v_ptrs ,
266
- mask = (offs_n [:, None ] < seqlen_k ) & (offs_d [None , :] < headdim ),
267
- other = 0.0 ,
268
- )
259
+ if EVEN_HEADDIM :
260
+ v = tl .load (
261
+ v_ptrs ,
262
+ mask = offs_n [:, None ] < seqlen_k ,
263
+ other = 0.0 ,
264
+ )
265
+ else :
266
+ v = tl .load (
267
+ v_ptrs ,
268
+ mask = (offs_n [:, None ] < seqlen_k ) & (offs_d [None , :] < headdim ),
269
+ other = 0.0 ,
270
+ )
269
271
270
- p = p .reshape (BLOCK * QUERY_HEAD_GROUPS , BLOCK ).to (v .dtype )
272
+ p = p .reshape (BLOCK * QUERY_HEAD_GROUPS , BLOCK ).to (v .dtype )
271
273
272
- causal_o = tl .dot (p , v )
274
+ causal_o = tl .dot (p , v )
273
275
274
- acc_o += causal_o .reshape (BLOCK , QUERY_HEAD_GROUPS , BLOCK_HEADDIM )
276
+ acc_o += causal_o .reshape (BLOCK , QUERY_HEAD_GROUPS , BLOCK_HEADDIM )
275
277
276
- # -- update statistics
278
+ # -- update statistics
277
279
278
- m_i = m_ij
279
- l_i_new = tl .exp (lse_i - m_ij ) + l_ij
280
- lse_i = m_ij + tl .log (l_i_new )
280
+ m_i = m_ij
281
+ l_i_new = tl .exp (lse_i - m_ij ) + l_ij
282
+ lse_i = m_ij + tl .log (l_i_new )
281
283
282
284
# # take care of the selected kv blocks
283
285
@@ -419,7 +421,8 @@ def native_sparse_attn_forward(
419
421
v ,
420
422
kv_block_indices ,
421
423
kv_block_mask ,
422
- block_size = 128
424
+ block_size = 128 ,
425
+ include_block_causal = True
423
426
):
424
427
q , k , v , kv_block_indices = [x if is_contiguous (x ) else x .contiguous () for x in (q , k , v , kv_block_indices )]
425
428
@@ -488,6 +491,7 @@ def native_sparse_attn_forward(
488
491
QUERY_HEAD_GROUPS = head_groups ,
489
492
QUERY_EXPAND_DIM = 16 // head_groups ,
490
493
NUM_SEL_KV_BLOCKS = num_selected_fine_blocks ,
494
+ INCLUDE_BLOCK_CAUSAL = include_block_causal ,
491
495
num_warps = num_warps ,
492
496
num_stages = 1 ,
493
497
)
@@ -1184,14 +1188,19 @@ def backward_kernel(
1184
1188
BLOCK : tl .constexpr ,
1185
1189
QUERY_HEAD_GROUPS : tl .constexpr ,
1186
1190
QUERY_EXPAND_DIM : tl .constexpr ,
1191
+ INCLUDE_BLOCK_CAUSAL : tl .constexpr
1187
1192
):
1188
1193
off_hb = tl .program_id (1 )
1189
1194
off_b = off_hb // kv_heads
1190
1195
off_h = off_hb % kv_heads
1191
1196
off_qh = off_h * QUERY_HEAD_GROUPS
1192
1197
1193
- IS_CAUSAL = tl .program_id (0 ) == 0
1194
- OFF_SEL_KV_BLOCKS = tl .program_id (0 ) - 1
1198
+ if INCLUDE_BLOCK_CAUSAL :
1199
+ IS_CAUSAL = tl .program_id (0 ) == 0
1200
+ OFF_SEL_KV_BLOCKS = tl .program_id (0 ) - 1
1201
+ else :
1202
+ IS_CAUSAL = False
1203
+ OFF_SEL_KV_BLOCKS = tl .program_id (0 )
1195
1204
1196
1205
# offset pointers for batch/head
1197
1206
@@ -1310,7 +1319,8 @@ def native_sparse_attn_backward(
1310
1319
o ,
1311
1320
lse ,
1312
1321
dq , dk , dv ,
1313
- block_size = 128
1322
+ block_size = 128 ,
1323
+ include_block_causal = True
1314
1324
):
1315
1325
device = do .device
1316
1326
@@ -1362,7 +1372,10 @@ def native_sparse_attn_backward(
1362
1372
BLOCK_HEADDIM = BLOCK_HEADDIM ,
1363
1373
)
1364
1374
1365
- grid = lambda META : (num_sel_fine_blocks + 1 , batch * kv_heads )
1375
+ grid = lambda META : (
1376
+ num_sel_fine_blocks + int (include_block_causal ),
1377
+ batch * kv_heads
1378
+ )
1366
1379
1367
1380
backward_kernel [grid ](
1368
1381
q ,
@@ -1418,7 +1431,8 @@ def native_sparse_attn_backward(
1418
1431
QUERY_EXPAND_DIM = 16 // head_groups ,
1419
1432
EVEN_M = divisible_by (seqlen_q , block_size ),
1420
1433
EVEN_N = divisible_by (seqlen_k , block_size ),
1421
- EVEN_HEADDIM = BLOCK_HEADDIM == dim
1434
+ EVEN_HEADDIM = BLOCK_HEADDIM == dim ,
1435
+ INCLUDE_BLOCK_CAUSAL = include_block_causal ,
1422
1436
# BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
1423
1437
# num_warps=num_warps,
1424
1438
# num_stages=1,
@@ -1440,6 +1454,7 @@ def forward(
1440
1454
block_size ,
1441
1455
selected_block_indices ,
1442
1456
fmask ,
1457
+ include_block_causal
1443
1458
):
1444
1459
dtype = fq .dtype
1445
1460
@@ -1453,14 +1468,16 @@ def forward(
1453
1468
fq , fk , fv ,
1454
1469
selected_block_indices ,
1455
1470
fmask ,
1456
- block_size = block_size
1471
+ block_size = block_size ,
1472
+ include_block_causal = include_block_causal
1457
1473
)
1458
1474
1459
1475
ctx .save_for_backward (fq , fk , fv , selected_block_indices , fmask , out , lse )
1460
1476
1461
1477
ctx ._saved_variables = (
1462
1478
block_size ,
1463
- head_groups
1479
+ head_groups ,
1480
+ include_block_causal
1464
1481
)
1465
1482
1466
1483
return out .type (dtype ), lse
@@ -1473,7 +1490,8 @@ def backward(self, ctx, do, _):
1473
1490
1474
1491
(
1475
1492
block_size ,
1476
- head_groups
1493
+ head_groups ,
1494
+ include_block_causal
1477
1495
) = ctx ._saved_variables
1478
1496
1479
1497
do = do .half ()
@@ -1485,7 +1503,8 @@ def backward(self, ctx, do, _):
1485
1503
do , q , k , v ,
1486
1504
sel_block_indices , mask ,
1487
1505
out , lse , dq , dk , dv ,
1488
- block_size = block_size
1506
+ block_size = block_size ,
1507
+ include_block_causal = include_block_causal
1489
1508
)
1490
1509
1491
1510
return dq , dk , dv , None , None , None , None
@@ -1508,6 +1527,7 @@ def native_sparse_attend(
1508
1527
block_size : int ,
1509
1528
selected_block_indices : Int ['b qh n sel' ] | Int ['b kh n sel' ],
1510
1529
fmask : Bool ['b qh n sel' ] | Bool ['b kh n sel' ],
1530
+ include_block_causal = True ,
1511
1531
return_lse = False
1512
1532
):
1513
1533
seq_len = fq .shape [- 2 ]
@@ -1526,6 +1546,7 @@ def native_sparse_attend(
1526
1546
block_size ,
1527
1547
selected_block_indices ,
1528
1548
fmask ,
1549
+ include_block_causal
1529
1550
)
1530
1551
1531
1552
if not return_lse :
0 commit comments