
Commit 7ebebca

fix maximum tracking in triton
1 parent e0c3bd3 commit 7ebebca

2 files changed: +3 −3 lines

native_sparse_attention_pytorch/triton_native_sparse_attention.py

Lines changed: 2 additions & 2 deletions
@@ -278,7 +278,7 @@ def forward_kernel_causal_and_sparse(
 qk += tl.where(causal_mask, 0, float("-inf"))

-m_ij = tl.maximum(tl.max(qk, 2) * softmax_scale, lse_i)
+m_ij = tl.maximum(tl.max(qk, 2) * softmax_scale, m_i)
 p = tl.exp(qk * softmax_scale - m_ij[:, :, None])

 l_ij = tl.sum(p, 2)
@@ -408,7 +408,7 @@ def forward_kernel_causal_and_sparse(
 # attention

-m_ij = tl.maximum(tl.max(sel_qk, 2) * softmax_scale, lse_i)
+m_ij = tl.maximum(tl.max(sel_qk, 2) * softmax_scale, m_i)
 block_p = tl.exp(sel_qk * softmax_scale - m_ij[:, :, None])

 l_ij = tl.sum(block_p, 2)
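For context on the change: in flash-attention-style online softmax, the per-block maximum m_ij should be taken against the previous running maximum m_i rather than the running log-sum-exp lse_i, which is what this commit switches to. Below is a minimal NumPy sketch of that streaming bookkeeping for a single query row; the variable names (m_i, m_ij, l_ij, lse_i) mirror the diff, but this is an illustration under those assumptions, not the Triton kernel itself.

import numpy as np

def online_logsumexp(score_blocks):
    # streaming log-sum-exp over blocks of attention scores for one query row
    m_i = -np.inf   # running maximum seen so far
    l_i = 0.0       # running sum of exp(score - m_i)
    for qk in score_blocks:
        # the new maximum comes from the previous running max m_i, not from lse_i
        m_ij = max(qk.max(), m_i)
        p = np.exp(qk - m_ij)                    # block weights, stabilized by m_ij
        l_ij = p.sum()
        l_i = l_i * np.exp(m_i - m_ij) + l_ij    # rescale the old sum to the new max
        m_i = m_ij
    lse_i = m_i + np.log(l_i)                    # log-sum-exp over all blocks seen
    return lse_i

# sanity check: the streamed result matches a direct log-sum-exp over all scores
blocks = [np.array([0.1, 2.0]), np.array([3.0, -1.0])]
direct = np.log(np.exp(np.concatenate(blocks)).sum())
assert np.isclose(online_logsumexp(blocks), direct)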

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "native-sparse-attention-pytorch"
-version = "0.2.0"
+version = "0.2.1"
 description = "Native Sparse Attention"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
