Skip to content

Commit 97aa4ae

Browse files
committed
Fix a tricky bug with lse: the lse is rounded to 128, but the padding needs to remain -inf.
1 parent: 8fad00a · commit: 97aa4ae

File tree

2 files changed: 5 additions (+5) and 5 deletions (-5).

native_sparse_attention_pytorch/triton_native_sparse_attention.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -401,7 +401,7 @@ def forward_kernel_causal_and_sparse(
401401
# write back lse
402402

403403
lse_i = lse_i.reshape(BLOCK, QUERY_HEAD_GROUPS)
404-
tl.store(lse_ptrs, lse_i, mask = offs_m[:, None] < seqlen_q)
404+
tl.store(lse_ptrs, lse_i)
405405

406406
# write to output
407407

@@ -1362,12 +1362,12 @@ def backward_kernel(
13621362

13631363
D += (
13641364
off_b * stride_D_b +
1365-
off_h * QUERY_HEAD_GROUPS * seqlen_q_rounded
1365+
off_qh * seqlen_q_rounded
13661366
)
13671367

13681368
LSE += (
13691369
off_b * stride_lse_b +
1370-
off_h * QUERY_HEAD_GROUPS * seqlen_q_rounded
1370+
off_qh * seqlen_q_rounded
13711371
)
13721372

13731373
num_block_n = tl.cdiv(seqlen_k, BLOCK)
@@ -1719,5 +1719,5 @@ def native_sparse_attend(
17191719

17201720
if not return_lse:
17211721
return out
1722-
1722+
17231723
return out, lse[..., :seq_len]

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "native-sparse-attention-pytorch"
3-
version = "0.0.78"
3+
version = "0.1.0"
44
description = "Native Sparse Attention"
55
authors = [
66
{ name = "Phil Wang", email = "lucidrains@gmail.com" }

0 commit comments

Comments
 (0)