Commit ef77474

precautionary
1 parent 06f2eae commit ef77474

File tree

2 files changed: +3 -4 lines

native_sparse_attention_pytorch/triton_native_sparse_attention.py

Lines changed: 2 additions & 3 deletions
@@ -802,10 +802,9 @@ def backward_kernel_one_col_block_sparse(

     block_dv = p.to(do.dtype)[:, :, :, None] * do[:, :, None, :]

-    # block_dv = block_dv.reshape(QUERY_HEAD_GROUPS, BLOCK, BLOCK, BLOCK_HEADDIM)
     block_dv = tl.sum(block_dv, 1)

-    tl.atomic_add(block_dv_ptrs, block_dv, sem = 'relaxed')
+    tl.atomic_add(block_dv_ptrs, block_dv, mask = block_masks[:, None, None], sem = 'relaxed')

     # get dp

@@ -830,7 +829,7 @@ def backward_kernel_one_col_block_sparse(
     block_dk = ds[:, :, :, None] * q[:, :, None, :]
     block_dk = tl.sum(block_dk, 1)

-    tl.atomic_add(block_dk_ptrs, block_dk, sem = 'relaxed')
+    tl.atomic_add(block_dk_ptrs, block_dk, mask = block_masks[:, None, None], sem = 'relaxed')

     # block dq

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "native-sparse-attention-pytorch"
-version = "0.0.59"
+version = "0.0.60"
 description = "Native Sparse Attention"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

0 commit comments