Commit 0c5289e

committed: some guards
1 parent 4290e30

File tree: 3 files changed, +25 −7 lines

native_sparse_attention_pytorch/triton_native_sparse_attention.py

Lines changed: 23 additions & 5 deletions

@@ -322,8 +322,17 @@ def forward_kernel_causal_and_sparse(
     q = q.reshape(BLOCK, 16, BLOCK_HEADDIM)

     for off_sel_kv_block in range(NUM_SEL_KV_BLOCKS):
-        block_indices = tl.load(kv_block_indices_ptrs + off_sel_kv_block)
-        block_masks = tl.load(kv_block_mask_ptrs + off_sel_kv_block)
+        block_indices = tl.load(
+            kv_block_indices_ptrs + off_sel_kv_block,
+            mask = offs_m < seqlen_q,
+            other = 0
+        )
+
+        block_masks = tl.load(
+            kv_block_mask_ptrs + off_sel_kv_block,
+            mask = offs_m < seqlen_q,
+            other = False
+        )

         blocks_offs_n = block_indices[:, None] * BLOCK + tl.arange(0, BLOCK)[None, :]

@@ -345,7 +354,11 @@ def forward_kernel_causal_and_sparse(

         # load k of shape (m, n, d), sparsely selected by each query

-        k_block = tl.load(block_k_ptrs)
+        k_block = tl.load(
+            block_k_ptrs,
+            mask = blocks_offs_n[:, :, None] < seqlen_k,
+            other = 0.
+        )

         # similarities

@@ -376,7 +389,12 @@ def forward_kernel_causal_and_sparse(

         # aggregate values

-        v_block = tl.load(block_v_ptrs)
+        v_block = tl.load(
+            block_v_ptrs,
+            mask = blocks_offs_n[:, :, None] < seqlen_k,
+            other = 0.
+        )
+
         v_block = tl.reshape(v_block, (BLOCK, BLOCK, BLOCK_HEADDIM))

         block_p = block_p.to(v_block.dtype)

@@ -628,7 +646,7 @@ def native_sparse_attn_forward(
         QUERY_HEAD_GROUPS = head_groups,
         NUM_SEL_KV_BLOCKS = num_selected_fine_blocks,
         INCLUDE_BLOCK_CAUSAL = include_block_causal,
-        RETURN_SLIDING_OUT = False,
+        RETURN_SLIDING_OUT = return_sliding_window_out,
         num_warps = num_warps,
         num_stages = 1,
     )
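The kernel changes above all use the same Triton idiom: passing a `mask` (and an `other` fill value) to `tl.load` so that lanes past the end of the tensor read a harmless default instead of touching out-of-bounds memory. Below is a minimal, standalone sketch of that guard pattern on a simple elementwise kernel; it is not the repository's attention kernel, and the names (`guarded_copy_kernel`, `guarded_copy`, `BLOCK`) are illustrative only.

# Minimal sketch of the guarded tl.load / tl.store pattern used in this commit.
# Assumes a CUDA device and an installed triton; nothing here is part of
# native-sparse-attention-pytorch itself.

import torch
import triton
import triton.language as tl

@triton.jit
def guarded_copy_kernel(
    x_ptr,            # input pointer
    y_ptr,            # output pointer
    n_elements,       # number of valid elements (may not be a multiple of BLOCK)
    BLOCK: tl.constexpr,
):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)

    # guard: lanes past the end of the tensor are masked off and receive `other`
    # instead of performing an out-of-bounds load
    mask = offs < n_elements
    x = tl.load(x_ptr + offs, mask = mask, other = 0.)

    # the matching store is masked the same way so nothing is written out of bounds
    tl.store(y_ptr + offs, x * 2, mask = mask)

def guarded_copy(x: torch.Tensor, block: int = 128) -> torch.Tensor:
    y = torch.empty_like(x)
    grid = (triton.cdiv(x.numel(), block),)
    guarded_copy_kernel[grid](x, y, x.numel(), BLOCK = block)
    return y

if __name__ == '__main__':
    # 1000 is deliberately not a multiple of 128, so the last program relies on the mask
    x = torch.randn(1000, device = 'cuda')
    assert torch.allclose(guarded_copy(x), x * 2)

In the commit itself, the analogous guards are `mask = offs_m < seqlen_q` for the per-query block indices/masks and `mask = blocks_offs_n[:, :, None] < seqlen_k` for the selected key/value blocks, which keeps the sparse loads from reading past the query and key sequence lengths when they are not block-aligned.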

pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 [project]
 name = "native-sparse-attention-pytorch"
-version = "0.1.5"
+version = "0.1.6"
 description = "Native Sparse Attention"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

test_triton_nsa.py

Lines changed: 1 addition & 1 deletion

@@ -141,7 +141,7 @@ def regular_attend(
 fine_block_size = 16
 num_sel = 6
 dim_head = 64
-fused_sliding_window = True
+fused_sliding_window = False
 block_dk_dv_use_dot = False # need sufficient shared memory, A100 works

 q = torch.randn(batch, q_heads, seq_len, dim_head).cuda()
