Commit 9e884df

fix intermittent issue with triton nsa dk

1 parent: 860517c

3 files changed: +12 −7 lines

native_sparse_attention_pytorch/triton_native_sparse_attention.py

Lines changed: 8 additions & 3 deletions
@@ -822,20 +822,25 @@ def backward_kernel_one_col_block_sparse(
     # ds

     ds = (p * (dp - Di[:, :, None]) * softmax_scale)
-    ds = ds.to(q.dtype)

     # block dk

-    block_dk = ds[:, :, :, None] * q[:, :, None, :]
+    block_dk = ds[:, :, :, None] * q[:, :, None, :].to(ds.dtype)
     block_dk = tl.sum(block_dk, 1)

-    tl.atomic_add(block_dk_ptrs, block_dk, mask = block_masks[:, None, None], sem = 'relaxed')
+    tl.atomic_add(
+        block_dk_ptrs,
+        block_dk,
+        mask = block_masks[:, None, None] & (blocks_offs_n[:, :, None] < seqlen_k),
+        sem = 'relaxed'
+    )

     # block dq

     ds_expanded = tl.expand_dims(ds, 2)
     ds_expanded = tl.broadcast_to(ds_expanded, (BLOCK, QUERY_HEAD_GROUPS, QUERY_EXPAND_DIM, BLOCK))
     ds_expanded = ds_expanded.reshape(BLOCK, 16, BLOCK)
+    ds_expanded = ds_expanded.to(block_k.dtype)

     block_dq = tl.dot(ds_expanded, block_k)

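The substance of the fix is in the `tl.atomic_add` for `block_dk`: its mask now also requires `blocks_offs_n[:, :, None] < seqlen_k`, so a fine block that straddles the end of the key sequence no longer accumulates gradient into rows past the true length, which is a plausible source of the intermittent `dk` mismatches when the sequence length is not a multiple of the block size. The dtype handling is also tightened: `q` is cast to `ds.dtype` before forming `block_dk`, and `ds_expanded` is cast to `block_k.dtype` before the `tl.dot`. Below is a minimal PyTorch sketch of the boundary-masked accumulation only, with illustrative names and shapes (it is not the kernel's memory layout):

import torch

def accumulate_block_dk(dk, block_dk, block_offs_n, block_mask, seqlen_k):
    # dk:           (padded_seqlen_k, dim) global key-gradient buffer
    # block_dk:     (block, dim)           this key block's contribution
    # block_offs_n: (block,)               absolute key positions of the block
    # block_mask:   (block,)               bool, rows the sparse pattern selected
    # only accumulate rows that are both selected and inside the real sequence
    valid = block_mask & (block_offs_n < seqlen_k)
    dk[block_offs_n[valid]] += block_dk[valid]
    return dk

# example: a block of 16 key rows starting at position 496 of a 507-token sequence
dk = torch.zeros(512, 64)
block_offs_n = torch.arange(496, 512)
block_mask = torch.ones(16, dtype = torch.bool)
dk = accumulate_block_dk(dk, torch.randn(16, 64), block_offs_n, block_mask, seqlen_k = 507)
assert dk[507:].abs().sum() == 0   # nothing written past the true length
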
pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "native-sparse-attention-pytorch"
-version = "0.0.65"
+version = "0.0.66"
 description = "Native Sparse Attention"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

test_triton_nsa.py

Lines changed: 3 additions & 3 deletions
@@ -100,8 +100,8 @@ def regular_attend(

 # mock inputs

-batch = 2
-seq_len = 511
+batch = 4
+seq_len = 507
 q_heads = 4
 kv_heads = 2
 fine_block_size = 16
@@ -135,7 +135,7 @@ def regular_attend(
 assert torch.allclose(rlse, nlse, atol = 1e-2)

 assert torch.allclose(nv.grad, rv.grad, atol = 1e-2)
-assert torch.allclose(nk.grad, rk.grad, atol = 1e-2)
 assert torch.allclose(nq.grad, rq.grad, atol = 1e-2)
+assert torch.allclose(nk.grad, rk.grad, atol = 1e-2)

 print('✅ outputs and gradients are same between pytorch native sparse attn and triton native sparse attn')
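
Because the original failure was intermittent, a regression is easier to surface by sweeping the gradient comparison over several seeds and ragged sequence lengths rather than one fixed shape. The loop below is a hypothetical stress harness, not part of this repository's test file; `run_triton_nsa` and `run_reference` stand in for the two attention paths that `test_triton_nsa.py` compares, and are assumed to return the key gradient for a given sequence length:

import torch

def check_dk_close(run_triton_nsa, run_reference, seeds = (0, 1, 2, 3), seq_lens = (507, 511, 513)):
    # hypothetical callables: each returns k.grad for one configuration
    for seed in seeds:
        for seq_len in seq_lens:
            torch.manual_seed(seed)
            nk_grad = run_triton_nsa(seq_len)
            torch.manual_seed(seed)
            rk_grad = run_reference(seq_len)
            assert torch.allclose(nk_grad, rk_grad, atol = 1e-2), (seed, seq_len)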
