@@ -742,7 +742,9 @@ def backward_kernel_one_col_block_sparse(
742
742
QUERY_HEAD_GROUPS : tl .constexpr ,
743
743
QUERY_EXPAND_DIM : tl .constexpr ,
744
744
RETURN_SEL_GRADS : tl .constexpr ,
745
- OFF_SEL_KV_BLOCKS : tl .constexpr
745
+ OFF_SEL_KV_BLOCKS : tl .constexpr ,
746
+ BLOCK_DV_USE_DOT : tl .constexpr ,
747
+ BLOCK_DK_USE_DOT : tl .constexpr ,
746
748
):
747
749
# We need to make sure begin_m is a multiple of BLOCK_M (not BLOCK_N)
748
750
@@ -918,23 +920,33 @@ def backward_kernel_one_col_block_sparse(
918
920
919
921
p = tl .exp (masked_qk * softmax_scale - lse_i [:, :, None ])
920
922
923
+ # prepare do
924
+
925
+ do_expanded = tl .expand_dims (do , 2 )
926
+ do_expanded = tl .broadcast_to (do_expanded , (BLOCK , QUERY_HEAD_GROUPS , QUERY_EXPAND_DIM , BLOCK_HEADDIM ))
927
+ do_expanded = do_expanded .reshape (BLOCK , 16 , BLOCK_HEADDIM )
928
+
921
929
# take care of block dv
922
930
923
- block_dv = p .to (do .dtype )[:, :, :, None ] * do [:, :, None , :]
931
+ if not BLOCK_DV_USE_DOT :
932
+ p_expanded = p .to (do .dtype )
933
+ p_expanded = tl .expand_dims (p_expanded , 2 )
934
+ p_expanded = tl .broadcast_to (p_expanded , (BLOCK , QUERY_HEAD_GROUPS , QUERY_EXPAND_DIM , BLOCK ))
935
+ p_expanded = p_expanded .reshape (BLOCK , QUERY_HEAD_GROUPS * QUERY_EXPAND_DIM , BLOCK )
936
+ p_expanded = tl .permute (p_expanded , (0 , 2 , 1 ))
924
937
925
- block_dv = tl .sum (block_dv , 1 )
938
+ block_dv = tl .dot (p_expanded , do_expanded ) / QUERY_EXPAND_DIM
939
+ else :
940
+ block_dv = p .to (do .dtype )[:, :, :, None ] * do [:, :, None , :]
941
+ block_dv = tl .sum (block_dv , 1 )
926
942
927
943
tl .atomic_add (block_dv_ptrs , block_dv , mask = block_masks [:, None , None ], sem = 'relaxed' )
928
944
929
945
# get dp
930
946
931
- do = tl .expand_dims (do , 2 )
932
- do = tl .broadcast_to (do , (BLOCK , QUERY_HEAD_GROUPS , QUERY_EXPAND_DIM , BLOCK_HEADDIM ))
933
- do = do .reshape (BLOCK , 16 , BLOCK_HEADDIM )
934
-
935
947
block_v = tl .permute (block_v , (0 , 2 , 1 ))
936
948
937
- dp = tl .dot (do , block_v )
949
+ dp = tl .dot (do_expanded , block_v )
938
950
939
951
dp = dp .reshape (BLOCK , QUERY_HEAD_GROUPS , QUERY_EXPAND_DIM , BLOCK )
940
952
dp = tl .sum (dp , 2 ) / QUERY_EXPAND_DIM
@@ -963,10 +975,23 @@ def backward_kernel_one_col_block_sparse(
963
975
sem = 'relaxed'
964
976
)
965
977
978
+ # ds
979
+
980
+ ds_expanded = tl .expand_dims (ds , 2 )
981
+ ds_expanded = tl .broadcast_to (ds_expanded , (BLOCK , QUERY_HEAD_GROUPS , QUERY_EXPAND_DIM , BLOCK ))
982
+ ds_expanded = ds_expanded .reshape (BLOCK , 16 , BLOCK )
983
+
966
984
# block dk
967
985
968
- block_dk = ds [:, :, :, None ] * q [:, :, None , :].to (ds .dtype )
969
- block_dk = tl .sum (block_dk , 1 )
986
+ if BLOCK_DK_USE_DOT :
987
+ q_expanded = tl .expand_dims (q , 2 )
988
+ q_expanded = tl .broadcast_to (q_expanded , (BLOCK , QUERY_HEAD_GROUPS , QUERY_EXPAND_DIM , BLOCK_HEADDIM ))
989
+ q_expanded = q_expanded .reshape (BLOCK , 16 , BLOCK_HEADDIM )
990
+
991
+ block_dk = tl .dot (tl .permute (ds_expanded , (0 , 2 , 1 )), q_expanded .to (ds .dtype )) / QUERY_EXPAND_DIM
992
+ else :
993
+ block_dk = ds [:, :, :, None ] * q [:, :, None , :].to (ds .dtype )
994
+ block_dk = tl .sum (block_dk , 1 )
970
995
971
996
tl .atomic_add (
972
997
block_dk_ptrs ,
@@ -977,11 +1002,7 @@ def backward_kernel_one_col_block_sparse(
977
1002
978
1003
# block dq
979
1004
980
- ds_expanded = tl .expand_dims (ds , 2 )
981
- ds_expanded = tl .broadcast_to (ds_expanded , (BLOCK , QUERY_HEAD_GROUPS , QUERY_EXPAND_DIM , BLOCK ))
982
- ds_expanded = ds_expanded .reshape (BLOCK , 16 , BLOCK )
983
1005
ds_expanded = ds_expanded .to (block_k .dtype )
984
-
985
1006
block_dq = tl .dot (ds_expanded , block_k )
986
1007
987
1008
block_dq = block_dq .reshape (BLOCK , QUERY_HEAD_GROUPS , QUERY_EXPAND_DIM , BLOCK_HEADDIM )
@@ -1348,6 +1369,8 @@ def backward_kernel(
1348
1369
RETURN_SEL_GRADS : tl .constexpr ,
1349
1370
INCLUDE_BLOCK_CAUSAL : tl .constexpr ,
1350
1371
SLIDING : tl .constexpr ,
1372
+ BLOCK_DV_USE_DOT : tl .constexpr ,
1373
+ BLOCK_DK_USE_DOT : tl .constexpr ,
1351
1374
):
1352
1375
off_hb = tl .program_id (1 )
1353
1376
off_b = off_hb // kv_heads
@@ -1467,7 +1490,9 @@ def backward_kernel(
1467
1490
QUERY_HEAD_GROUPS = QUERY_HEAD_GROUPS ,
1468
1491
QUERY_EXPAND_DIM = QUERY_EXPAND_DIM ,
1469
1492
RETURN_SEL_GRADS = RETURN_SEL_GRADS ,
1470
- OFF_SEL_KV_BLOCKS = OFF_SEL_KV_BLOCKS
1493
+ OFF_SEL_KV_BLOCKS = OFF_SEL_KV_BLOCKS ,
1494
+ BLOCK_DV_USE_DOT = BLOCK_DV_USE_DOT ,
1495
+ BLOCK_DK_USE_DOT = BLOCK_DK_USE_DOT ,
1471
1496
)
1472
1497
1473
1498
def native_sparse_attn_backward (
@@ -1482,7 +1507,8 @@ def native_sparse_attn_backward(
1482
1507
block_size = 128 ,
1483
1508
include_block_causal = True ,
1484
1509
return_sel_grads = False ,
1485
- sliding = False
1510
+ sliding = False ,
1511
+ block_dk_dv_use_dot = None
1486
1512
):
1487
1513
device = do .device
1488
1514
@@ -1596,7 +1622,9 @@ def native_sparse_attn_backward(
1596
1622
EVEN_HEADDIM = BLOCK_HEADDIM == dim ,
1597
1623
RETURN_SEL_GRADS = return_sel_grads ,
1598
1624
INCLUDE_BLOCK_CAUSAL = include_block_causal ,
1599
- SLIDING = sliding
1625
+ SLIDING = sliding ,
1626
+ BLOCK_DV_USE_DOT = default (block_dk_dv_use_dot , head_groups > 1 ),
1627
+ BLOCK_DK_USE_DOT = default (block_dk_dv_use_dot , head_groups > 1 )
1600
1628
# BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
1601
1629
# num_warps=num_warps,
1602
1630
# num_stages=1,
@@ -1619,7 +1647,8 @@ def forward(
1619
1647
selected_block_indices ,
1620
1648
fmask ,
1621
1649
sel_scale ,
1622
- include_block_causal
1650
+ include_block_causal ,
1651
+ block_dk_dv_use_dot
1623
1652
):
1624
1653
dtype = fq .dtype
1625
1654
@@ -1649,6 +1678,7 @@ def forward(
1649
1678
head_groups ,
1650
1679
return_sel_grads ,
1651
1680
include_block_causal ,
1681
+ block_dk_dv_use_dot
1652
1682
)
1653
1683
1654
1684
return out .type (dtype ), lse
@@ -1663,7 +1693,8 @@ def backward(self, ctx, do, _):
1663
1693
block_size ,
1664
1694
head_groups ,
1665
1695
return_sel_grads ,
1666
- include_block_causal
1696
+ include_block_causal ,
1697
+ block_dk_dv_use_dot
1667
1698
) = ctx ._saved_variables
1668
1699
1669
1700
do = do .half ()
@@ -1679,15 +1710,16 @@ def backward(self, ctx, do, _):
1679
1710
out , lse , dq , dk , dv ,
1680
1711
block_size = block_size ,
1681
1712
include_block_causal = include_block_causal ,
1682
- return_sel_grads = return_sel_grads
1713
+ return_sel_grads = return_sel_grads ,
1714
+ block_dk_dv_use_dot = block_dk_dv_use_dot
1683
1715
)
1684
1716
1685
1717
ret_sel_grads = None
1686
1718
1687
1719
if return_sel_grads :
1688
1720
ret_sel_grads = sel_grads
1689
1721
1690
- return dq , dk , dv , None , None , None , ret_sel_grads , None
1722
+ return dq , dk , dv , None , None , None , ret_sel_grads , None , None
1691
1723
1692
1724
_native_sparse_attend = NSA .apply
1693
1725
@@ -1709,7 +1741,8 @@ def native_sparse_attend(
1709
1741
fmask : Bool ['b qh n sel' ] | Bool ['b kh n sel' ],
1710
1742
sel_scale : Float ['b kh n sel' ] | Float ['b qh n sel' ] | None = None ,
1711
1743
include_block_causal = True ,
1712
- return_lse = False
1744
+ return_lse = False ,
1745
+ block_dk_dv_use_dot = False
1713
1746
):
1714
1747
seq_len = fq .shape [- 2 ]
1715
1748
q_heads , kv_heads , sel_heads = fq .shape [1 ], fk .shape [1 ], selected_block_indices .shape [1 ]
@@ -1730,7 +1763,8 @@ def native_sparse_attend(
1730
1763
selected_block_indices ,
1731
1764
fmask ,
1732
1765
sel_scale ,
1733
- include_block_causal
1766
+ include_block_causal ,
1767
+ block_dk_dv_use_dot
1734
1768
)
1735
1769
1736
1770
if not return_lse :
0 commit comments