@@ -1208,13 +1208,15 @@ def backward_kernel(
     BLOCK: tl.constexpr,
     QUERY_HEAD_GROUPS: tl.constexpr,
     QUERY_EXPAND_DIM: tl.constexpr,
-    NUM_SEL_KV_BLOCKS: tl.constexpr
 ):
     off_hb = tl.program_id(1)
     off_b = off_hb // kv_heads
     off_h = off_hb % kv_heads
     off_qh = off_h * QUERY_HEAD_GROUPS
 
+    IS_CAUSAL = tl.program_id(0) == 0
+    OFF_SEL_KV_BLOCKS = tl.program_id(0) - 1
+
     # offset pointers for batch/head
 
     Q += off_b * stride_qb + off_qh * stride_qh
@@ -1244,46 +1246,47 @@ def backward_kernel(
 
     num_block_n = tl.cdiv(seqlen_k, BLOCK)
 
-    for start_n in range(0, num_block_n):
-        backward_kernel_one_col_block_causal(
-            start_n,
-            Q,
-            K,
-            V,
-            kv_block_indices,
-            kv_block_mask,
-            DO,
-            DQ,
-            DK,
-            DV,
-            LSE,
-            D,
-            softmax_scale,
-            stride_qm,
-            stride_kn,
-            stride_vn,
-            stride_dom,
-            stride_dqm,
-            stride_dkn,
-            stride_dvn,
-            stride_kvbl_m,
-            stride_qh,
-            stride_doh,
-            stride_dqh,
-            seqlen_q,
-            seqlen_k,
-            seqlen_q_rounded,
-            headdim,
-            BLOCK_HEADDIM = BLOCK_HEADDIM,
-            EVEN_M = EVEN_M,
-            EVEN_N = EVEN_N,
-            EVEN_HEADDIM = EVEN_HEADDIM,
-            BLOCK = BLOCK,
-            QUERY_HEAD_GROUPS = QUERY_HEAD_GROUPS,
-            QUERY_EXPAND_DIM = QUERY_EXPAND_DIM,
-        )
-
-        for off_sel_kv_blocks in range(NUM_SEL_KV_BLOCKS):
+    if IS_CAUSAL:
+        for start_n in range(0, num_block_n):
+            backward_kernel_one_col_block_causal(
+                start_n,
+                Q,
+                K,
+                V,
+                kv_block_indices,
+                kv_block_mask,
+                DO,
+                DQ,
+                DK,
+                DV,
+                LSE,
+                D,
+                softmax_scale,
+                stride_qm,
+                stride_kn,
+                stride_vn,
+                stride_dom,
+                stride_dqm,
+                stride_dkn,
+                stride_dvn,
+                stride_kvbl_m,
+                stride_qh,
+                stride_doh,
+                stride_dqh,
+                seqlen_q,
+                seqlen_k,
+                seqlen_q_rounded,
+                headdim,
+                BLOCK_HEADDIM = BLOCK_HEADDIM,
+                EVEN_M = EVEN_M,
+                EVEN_N = EVEN_N,
+                EVEN_HEADDIM = EVEN_HEADDIM,
+                BLOCK = BLOCK,
+                QUERY_HEAD_GROUPS = QUERY_HEAD_GROUPS,
+                QUERY_EXPAND_DIM = QUERY_EXPAND_DIM,
+            )
+    else:
+        for start_n in range(0, num_block_n):
             backward_kernel_one_col_block_sparse(
                 start_n,
                 Q,
@@ -1320,7 +1323,7 @@ def backward_kernel(
                 BLOCK = BLOCK,
                 QUERY_HEAD_GROUPS = QUERY_HEAD_GROUPS,
                 QUERY_EXPAND_DIM = QUERY_EXPAND_DIM,
-                OFF_SEL_KV_BLOCKS = off_sel_kv_blocks
+                OFF_SEL_KV_BLOCKS = OFF_SEL_KV_BLOCKS
             )
 
 def native_sparse_attn_backward(
@@ -1383,7 +1386,7 @@ def native_sparse_attn_backward(
         BLOCK_HEADDIM = BLOCK_HEADDIM,
     )
 
-    grid = lambda META: (1, batch * kv_heads)
+    grid = lambda META: (num_sel_fine_blocks + 1, batch * kv_heads)
 
     backward_kernel[grid](
         q,
@@ -1437,7 +1440,6 @@ def native_sparse_attn_backward(
         BLOCK = block_size,
         QUERY_HEAD_GROUPS = head_groups,
         QUERY_EXPAND_DIM = 16 // head_groups,
-        NUM_SEL_KV_BLOCKS = num_sel_fine_blocks,
         EVEN_M = divisible_by(seqlen_q, block_size),
         EVEN_N = divisible_by(seqlen_k, block_size),
         EVEN_HEADDIM = BLOCK_HEADDIM == dim
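The change trades the serial in-kernel `for off_sel_kv_blocks in range(NUM_SEL_KV_BLOCKS)` loop for an extra launch-grid axis: `tl.program_id(0) == 0` selects the causal column pass, while program ids `1..num_sel_fine_blocks` each handle one selected KV block, which is why the grid grows to `num_sel_fine_blocks + 1` on axis 0. Below is a minimal, self-contained sketch of that grid-dispatch pattern; the names `dispatch_kernel`, `out`, and `num_sel` are invented for illustration and are not part of the commit.

import torch
import triton
import triton.language as tl

@triton.jit
def dispatch_kernel(out_ptr):
    # Same dispatch test as in backward_kernel: program 0 on axis 0
    # runs the causal pass, programs 1..num_sel each take one
    # selected KV block instead of iterating over them in-kernel.
    pid = tl.program_id(0)
    IS_CAUSAL = pid == 0
    OFF_SEL_KV_BLOCKS = pid - 1

    if IS_CAUSAL:
        tl.store(out_ptr + pid, -1)                 # marker: causal branch
    else:
        tl.store(out_ptr + pid, OFF_SEL_KV_BLOCKS)  # one selected block per program

num_sel = 4
out = torch.empty(num_sel + 1, device = 'cuda', dtype = torch.int32)

# mirrors grid = lambda META: (num_sel_fine_blocks + 1, batch * kv_heads)
dispatch_kernel[(num_sel + 1,)](out)
print(out.tolist())  # [-1, 0, 1, 2, 3]

Under this scheme each selected KV block is processed by its own program in parallel rather than by a serial loop iteration, at the cost of launching `num_sel_fine_blocks + 1` times as many programs on axis 0.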