
Commit af75120

more guards
1 parent 0c5289e commit af75120

File tree

2 files changed: +19 -5 lines changed


native_sparse_attention_pytorch/triton_native_sparse_attention.py

Lines changed: 18 additions & 4 deletions
@@ -947,8 +947,17 @@ def backward_kernel_one_col_block_sparse(
             offs_d[None, None, :]
         )
 
-        block_k = tl.load(block_k_ptrs)
-        block_v = tl.load(block_v_ptrs)
+        block_k = tl.load(
+            block_k_ptrs,
+            mask = blocks_offs_n[:, :, None] < seqlen_k,
+            other = 0.
+        )
+
+        block_v = tl.load(
+            block_v_ptrs,
+            mask = blocks_offs_n[:, :, None] < seqlen_k,
+            other = 0.
+        )
 
         q_expanded = tl.expand_dims(q, 2)
         q_expanded = tl.broadcast_to(q_expanded, (BLOCK, QUERY_HEAD_GROUPS, QUERY_EXPAND_DIM, BLOCK_HEADDIM))
@@ -984,7 +993,11 @@ def backward_kernel_one_col_block_sparse(
         block_dv = p.to(do.dtype)[:, :, :, None] * do[:, :, None, :]
         block_dv = tl.sum(block_dv, 1)
 
-        tl.atomic_add(block_dv_ptrs, block_dv, mask = block_masks[:, None, None], sem = 'relaxed')
+        tl.atomic_add(
+            block_dv_ptrs, block_dv,
+            mask = block_masks[:, None, None] & blocks_offs_n[:, :, None] < seqlen_k,
+            sem = 'relaxed'
+        )
 
         # get dp
 
@@ -1016,6 +1029,7 @@ def backward_kernel_one_col_block_sparse(
         tl.atomic_add(
             kv_block_grads_ptrs + OFF_SEL_KV_BLOCKS,
             sel_grads,
+            mask = offs_m < seqlen_q,
             sem = 'relaxed'
         )
 
@@ -1037,7 +1051,7 @@ def backward_kernel_one_col_block_sparse(
         tl.atomic_add(
             block_dk_ptrs,
             block_dk,
-            mask = block_masks[:, None, None] & (blocks_offs_n[:, :, None] < seqlen_k),
+            mask = block_masks[:, None, None] & blocks_offs_n[:, :, None] < seqlen_k,
             sem = 'relaxed'
         )
 
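
The diff above applies one boundary-guard pattern in several places: tl.load gains a mask plus other = 0. so out-of-range key/value rows read zeros, and tl.atomic_add gains a matching mask so those rows never write into the gradient buffers. The sketch below is a minimal, standalone illustration of that pattern on a toy kernel, not the sparse-attention kernel itself; the kernel name, shapes, and launch parameters are hypothetical, and it assumes a CUDA device with a Triton version that supports the sem argument on tl.atomic_add (as used in the diff).

# Hypothetical toy kernel illustrating the masked-load / masked-atomic-add guard pattern.

import torch
import triton
import triton.language as tl

@triton.jit
def masked_accumulate_kernel(
    x_ptr,              # input of length n
    out_ptr,            # output of length n, accumulated into atomically
    n,                  # logical length, not necessarily a multiple of BLOCK
    BLOCK: tl.constexpr,
):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n                                     # boundary guard

    # guarded load: out-of-bounds lanes read 0. instead of undefined memory
    x = tl.load(x_ptr + offs, mask = mask, other = 0.)

    # guarded atomic accumulation: out-of-bounds lanes write nothing
    tl.atomic_add(out_ptr + offs, x * 2.0, mask = mask, sem = 'relaxed')

x = torch.randn(1000, device = 'cuda')
out = torch.zeros_like(x)
grid = (triton.cdiv(x.numel(), 128),)
masked_accumulate_kernel[grid](x, out, x.numel(), BLOCK = 128)
assert torch.allclose(out, x * 2.0, atol = 1e-6)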

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "native-sparse-attention-pytorch"
-version = "0.1.6"
+version = "0.1.7"
 description = "Native Sparse Attention"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
