@@ -590,6 +590,7 @@ def backward_kernel_one_col_block_sparse(
     V,
     kv_block_indices,
     kv_block_mask,
+    kv_block_grads,
     DO,
     DQ,
     DK,
@@ -619,6 +620,7 @@ def backward_kernel_one_col_block_sparse(
     BLOCK: tl.constexpr,
     QUERY_HEAD_GROUPS: tl.constexpr,
     QUERY_EXPAND_DIM: tl.constexpr,
+    RETURN_SEL_GRADS: tl.constexpr,
     OFF_SEL_KV_BLOCKS: tl.constexpr
 ):
     # We need to make sure begin_m is a multiple of BLOCK_M (not BLOCK_N)
@@ -638,9 +640,6 @@ def backward_kernel_one_col_block_sparse(

     # initialize pointers to value-like data

-    k_ptrs = K + (offs_n[:, None] * stride_kn + offs_d[None, :])
-    v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :])
-
     q_ptrs = (
         Q +
         offs_g[None, :, None] * stride_qh +
@@ -794,9 +793,9 @@ def backward_kernel_one_col_block_sparse(
         block_qk = block_qk.reshape(BLOCK, QUERY_HEAD_GROUPS, QUERY_EXPAND_DIM, BLOCK)
         qk = tl.sum(block_qk, 2) / QUERY_EXPAND_DIM

-        qk += tl.where(block_masks[:, None, None], 0, float("-inf"))
+        masked_qk = qk + tl.where(block_masks[:, None, None], 0, float("-inf"))

-        p = tl.exp(qk * softmax_scale - lse_i[:, :, None])
+        p = tl.exp(masked_qk * softmax_scale - lse_i[:, :, None])

         # take care of block dv

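Note on this change: the raw `qk` scores are now preserved and the `-inf` mask goes into a separate `masked_qk`. The selection-gradient path added in the next hunk reuses the unmasked `qk` in a `ds * qk` product; keeping `-inf` out of that product avoids NaNs at masked positions, where `ds` is exactly zero. A tiny illustration of the hazard, assuming PyTorch arithmetic mirrors Triton's here:

```python
import torch

# 0 * -inf is NaN, so reusing an in-place-masked qk would feed NaNs into
# the selection-gradient path even where ds is exactly zero.
ds = torch.tensor([0.5, 0.0])                   # 0.0 at the masked position
qk_masked = torch.tensor([1.0, float('-inf')])  # what in-place masking leaves
print(ds * qk_masked)                           # tensor([0.5000, nan])
```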
@@ -823,6 +822,26 @@ def backward_kernel_one_col_block_sparse(

         ds = (p * (dp - Di[:, :, None]) * softmax_scale)

+        # maybe return gradients for better differentiable topk
+
+        if RETURN_SEL_GRADS:
+
+            kv_block_grads_ptrs = (
+                kv_block_grads +
+                offs_m * stride_kvbl_m
+            )
+
+            sel_grads = ds * qk
+            sel_grads = tl.where(block_masks[:, None, None], sel_grads, 0.)
+            sel_grads = sel_grads.reshape(BLOCK, QUERY_HEAD_GROUPS * BLOCK)
+            sel_grads = tl.sum(sel_grads, 1)
+
+            tl.atomic_add(
+                kv_block_grads_ptrs + OFF_SEL_KV_BLOCKS,
+                sel_grads,
+                sem = 'relaxed'
+            )
+
         # block dk

         block_dk = ds[:, :, :, None] * q[:, :, None, :].to(ds.dtype)
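What the new block accumulates: treating each selected kv block as if it were gated by a multiplier on the pre-softmax scores (fixed at 1 in the forward pass), the gradient w.r.t. that gate is the sum of `ds * qk` over the block, reduced to one scalar per query row, then atomically added into `kv_block_grads`. Programs handling different kv blocks of the same row write into the same buffer, hence `tl.atomic_add`. A dense PyTorch sketch of the same reduction, with hypothetical names:

```python
import torch

def selection_grad_reference(ds, qk, block_mask):
    # ds, qk: (BLOCK, QUERY_HEAD_GROUPS, BLOCK) for one selected kv block
    # block_mask: (BLOCK,) bool, True where the block is actually selected
    sel = torch.where(block_mask[:, None, None], ds * qk, torch.zeros_like(ds))
    # collapse head groups and key positions -> one grad per query row
    return sel.reshape(sel.shape[0], -1).sum(dim = -1)
```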
@@ -1145,6 +1164,7 @@ def backward_kernel(
     V,
     kv_block_indices,
     kv_block_mask,
+    kv_block_grads,
     DO,
     DQ,
     DK,
@@ -1192,19 +1212,16 @@ def backward_kernel(
     BLOCK: tl.constexpr,
     QUERY_HEAD_GROUPS: tl.constexpr,
     QUERY_EXPAND_DIM: tl.constexpr,
+    RETURN_SEL_GRADS: tl.constexpr,
     INCLUDE_BLOCK_CAUSAL: tl.constexpr
 ):
     off_hb = tl.program_id(1)
     off_b = off_hb // kv_heads
     off_h = off_hb % kv_heads
     off_qh = off_h * QUERY_HEAD_GROUPS

-    if INCLUDE_BLOCK_CAUSAL:
-        IS_CAUSAL = tl.program_id(0) == 0
-        OFF_SEL_KV_BLOCKS = tl.program_id(0) - 1
-    else:
-        IS_CAUSAL = False
-        OFF_SEL_KV_BLOCKS = tl.program_id(0)
+    OFF_SEL_KV_BLOCKS = tl.program_id(0) - int(INCLUDE_BLOCK_CAUSAL)
+    IS_CAUSAL = INCLUDE_BLOCK_CAUSAL and tl.program_id(0) == 0

     # offset pointers for batch/head

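The refactor collapses the branch into two straight-line expressions: when `INCLUDE_BLOCK_CAUSAL` is false, `int(...)` is 0 and program 0 handles the first selected block, exactly as before. A quick exhaustive check of the equivalence in plain Python, with `pid` standing in for `tl.program_id(0)`:

```python
for include_block_causal in (False, True):
    for pid in range(4):
        # original branching form
        if include_block_causal:
            is_causal_old, off_old = pid == 0, pid - 1
        else:
            is_causal_old, off_old = False, pid
        # new branchless form
        off_new = pid - int(include_block_causal)
        is_causal_new = include_block_causal and pid == 0
        assert (is_causal_old, off_old) == (is_causal_new, off_new)
```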
@@ -1220,6 +1237,7 @@ def backward_kernel(

     kv_block_indices += off_b * stride_kvbl_b + off_h * stride_kvbl_h
     kv_block_mask += off_b * stride_kvbl_b + off_h * stride_kvbl_h
+    kv_block_grads += off_b * stride_kvbl_b + off_h * stride_kvbl_h

     # pointer to row-wise quantities in value-like data

@@ -1283,6 +1301,7 @@ def backward_kernel(
         V,
         kv_block_indices,
         kv_block_mask,
+        kv_block_grads,
         DO,
         DQ,
         DK,
@@ -1312,6 +1331,7 @@ def backward_kernel(
         BLOCK = BLOCK,
         QUERY_HEAD_GROUPS = QUERY_HEAD_GROUPS,
         QUERY_EXPAND_DIM = QUERY_EXPAND_DIM,
+        RETURN_SEL_GRADS = RETURN_SEL_GRADS,
         OFF_SEL_KV_BLOCKS = OFF_SEL_KV_BLOCKS
     )

@@ -1320,11 +1340,13 @@ def native_sparse_attn_backward(
     q, k, v,
     kv_block_indices,
     kv_block_mask,
+    kv_block_grads,
     o,
     lse,
     dq, dk, dv,
     block_size = 128,
-    include_block_causal = True
+    include_block_causal = True,
+    return_sel_grads = False
 ):
     device = do.device

@@ -1387,6 +1409,7 @@ def native_sparse_attn_backward(
         v,
         kv_block_indices,
         kv_block_mask,
+        kv_block_grads,
         do,
         dq,
         dk,
@@ -1436,6 +1459,7 @@ def native_sparse_attn_backward(
         EVEN_M = divisible_by(seqlen_q, block_size),
         EVEN_N = divisible_by(seqlen_k, block_size),
         EVEN_HEADDIM = BLOCK_HEADDIM == dim,
+        RETURN_SEL_GRADS = return_sel_grads,
         INCLUDE_BLOCK_CAUSAL = include_block_causal,
         # BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
         # num_warps=num_warps,
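Because `RETURN_SEL_GRADS` is declared `tl.constexpr`, Triton specializes and compiles a separate kernel per value, so the `if RETURN_SEL_GRADS:` branch and its atomics are eliminated entirely when the flag is off. A minimal standalone illustration of that pattern (not this repo's kernel):

```python
import triton
import triton.language as tl

@triton.jit
def scale_kernel(x_ptr, out_ptr, n, DOUBLE: tl.constexpr, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask = mask)
    if DOUBLE:  # resolved at compile time; dead code removed when False
        x = x * 2
    tl.store(out_ptr + offs, x, mask = mask)
```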
@@ -1458,6 +1482,7 @@ def forward(
         block_size,
         selected_block_indices,
         fmask,
+        sel_scale,
         include_block_causal
     ):
         dtype = fq.dtype
@@ -1478,10 +1503,16 @@ def forward(

         ctx.save_for_backward(fq, fk, fv, selected_block_indices, fmask, out, lse)

+        return_sel_grads = exists(sel_scale)
+
+        if return_sel_grads:
+            assert (sel_scale == 1.).all(), 'for now, must be straight through as multiplier of 1.'
+
         ctx._saved_variables = (
             block_size,
             head_groups,
-            include_block_causal
+            return_sel_grads,
+            include_block_causal,
         )

         return out.type(dtype), lse
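The assert documents the contract: `sel_scale` is a straight-through gate, all ones so the forward output is numerically unchanged, while the backward pass still deposits per-block selection gradients into its slot. The same idea in vanilla autograd, as a minimal sketch:

```python
import torch

scores = torch.randn(4, 8)
values = torch.randn(8, 16)
gate = torch.ones(4, 1, requires_grad = True)  # plays the role of sel_scale

attn = (scores * gate).softmax(dim = -1)       # forward identical to no gate
out = attn @ values
out.sum().backward()

print(gate.grad)  # nonzero: sensitivity of the output to each row's gate
```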
@@ -1495,6 +1526,7 @@ def backward(self, ctx, do, _):
         (
             block_size,
             head_groups,
+            return_sel_grads,
             include_block_causal
         ) = ctx._saved_variables

@@ -1503,15 +1535,23 @@ def backward(self, ctx, do, _):
         dk = torch.zeros(k.shape, dtype = torch.float32, device = device)
         dv = torch.zeros(v.shape, dtype = torch.float32, device = device)

+        sel_grads = torch.zeros_like(sel_block_indices).float()
+
         native_sparse_attn_backward(
             do, q, k, v,
-            sel_block_indices, mask,
+            sel_block_indices, mask, sel_grads,
             out, lse, dq, dk, dv,
             block_size = block_size,
-            include_block_causal = include_block_causal
+            include_block_causal = include_block_causal,
+            return_sel_grads = return_sel_grads
         )

-        return dq, dk, dv, None, None, None, None
+        ret_sel_grads = None
+
+        if return_sel_grads:
+            ret_sel_grads = sel_grads
+
+        return dq, dk, dv, None, None, None, ret_sel_grads, None

 _native_sparse_attend = NSA.apply

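The return tuple grows by one slot because `torch.autograd.Function.backward` must return exactly one gradient per `forward` input, positionally aligned; `ret_sel_grads` lines up with the new `sel_scale` argument, and the trailing `None` with `include_block_causal`. A toy function showing the alignment rule:

```python
import torch

class Scale(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, factor):            # two inputs ...
        ctx.save_for_backward(x, factor)
        return x * factor

    @staticmethod
    def backward(ctx, grad_out):            # ... so two grads, same order
        x, factor = ctx.saved_tensors
        return grad_out * factor, grad_out * x
```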
@@ -1531,6 +1571,7 @@ def native_sparse_attend(
     block_size: int,
     selected_block_indices: Int['b qh n sel'] | Int['b kh n sel'],
     fmask: Bool['b qh n sel'] | Bool['b kh n sel'],
+    sel_scale: Float['b kh n sel'] | Float['b qh n sel'] | None = None,
     include_block_causal = True,
     return_lse = False
 ):
@@ -1550,6 +1591,7 @@ def native_sparse_attend(
         block_size,
         selected_block_indices,
         fmask,
+        sel_scale,
         include_block_causal
     )

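Putting the new argument together, end to end: pass an all-ones `sel_scale` with `requires_grad = True` and read the selection gradients off `.grad` after backward. The sketch below is illustrative only; shapes are inferred from the type annotations in this diff, and a CUDA device is assumed since the kernels are Triton:

```python
import torch

b, qh, kh, n, sel, d, block_size = 1, 4, 2, 256, 2, 64, 64

q = torch.randn(b, qh, n, d, device = 'cuda')
k = torch.randn(b, kh, n, d, device = 'cuda')
v = torch.randn(b, kh, n, d, device = 'cuda')

indices = torch.zeros(b, kh, n, sel, dtype = torch.long, device = 'cuda')
fmask = torch.ones(b, kh, n, sel, dtype = torch.bool, device = 'cuda')

# must be all ones for now, per the assert in forward
sel_scale = torch.ones(b, kh, n, sel, device = 'cuda', requires_grad = True)

out = native_sparse_attend(q, k, v, block_size, indices, fmask, sel_scale = sel_scale)
out.sum().backward()

print(sel_scale.grad.shape)  # (b, kh, n, sel) selection gradients
```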