@@ -947,8 +947,17 @@ def backward_kernel_one_col_block_sparse(
947
947
offs_d [None , None , :]
948
948
)
949
949
950
- block_k = tl .load (block_k_ptrs )
951
- block_v = tl .load (block_v_ptrs )
950
+ block_k = tl .load (
951
+ block_k_ptrs ,
952
+ mask = blocks_offs_n [:, :, None ] < seqlen_k ,
953
+ other = 0.
954
+ )
955
+
956
+ block_v = tl .load (
957
+ block_v_ptrs ,
958
+ mask = blocks_offs_n [:, :, None ] < seqlen_k ,
959
+ other = 0.
960
+ )
952
961
953
962
q_expanded = tl .expand_dims (q , 2 )
954
963
q_expanded = tl .broadcast_to (q_expanded , (BLOCK , QUERY_HEAD_GROUPS , QUERY_EXPAND_DIM , BLOCK_HEADDIM ))
@@ -984,7 +993,11 @@ def backward_kernel_one_col_block_sparse(
984
993
block_dv = p .to (do .dtype )[:, :, :, None ] * do [:, :, None , :]
985
994
block_dv = tl .sum (block_dv , 1 )
986
995
987
- tl .atomic_add (block_dv_ptrs , block_dv , mask = block_masks [:, None , None ], sem = 'relaxed' )
996
+ tl .atomic_add (
997
+ block_dv_ptrs , block_dv ,
998
+ mask = block_masks [:, None , None ] & ( blocks_offs_n [:, :, None ] < seqlen_k ) ,
999
+ sem = 'relaxed'
1000
+ )
988
1001
989
1002
# get dp
990
1003
@@ -1016,6 +1029,7 @@ def backward_kernel_one_col_block_sparse(
1016
1029
tl .atomic_add (
1017
1030
kv_block_grads_ptrs + OFF_SEL_KV_BLOCKS ,
1018
1031
sel_grads ,
1032
+ mask = offs_m < seqlen_q ,
1019
1033
sem = 'relaxed'
1020
1034
)
1021
1035
@@ -1037,7 +1051,7 @@ def backward_kernel_one_col_block_sparse(
1037
1051
tl .atomic_add (
1038
1052
block_dk_ptrs ,
1039
1053
block_dk ,
1040
- mask = block_masks [:, None , None ] & ( blocks_offs_n [:, :, None ] < seqlen_k ) ,
1054
+ mask = block_masks [:, None , None ] & ( blocks_offs_n [:, :, None ] < seqlen_k ) ,
1041
1055
sem = 'relaxed'
1042
1056
)
1043
1057
0 commit comments