
Commit 4290e30

last commit before a week of jury duty

1 parent: 68fd8ee

2 files changed: +9 −2 lines

native_sparse_attention_pytorch/triton_native_sparse_attention.py

Lines changed: 8 additions & 1 deletion
@@ -452,6 +452,7 @@ def forward_kernel(
     Out,
     SlidingOut,
     Lse,
+    SlidingLse,
     softmax_scale,
     stride_qb,
     stride_qh,
@@ -490,9 +491,13 @@ def forward_kernel(
     if RETURN_SLIDING_OUT:
         sliding = tl.program_id(2) == 0
         out_ptr = SlidingOut if sliding else Out
+        lse_ptr = SlidingLse if sliding else Lse
+        num_sel_kv_blocks = 0 if sliding else NUM_SEL_KV_BLOCKS
     else:
         sliding = False
         out_ptr = Out
+        lse_ptr = Lse
+        num_sel_kv_blocks = NUM_SEL_KV_BLOCKS
 
     forward_kernel_causal_and_sparse(
         Q,
@@ -533,7 +538,7 @@ def forward_kernel(
         BLOCK,
         QUERY_HEAD_GROUPS,
         QUERY_EXPAND_DIM,
-        NUM_SEL_KV_BLOCKS,
+        num_sel_kv_blocks,
         INCLUDE_BLOCK_CAUSAL,
         sliding
     )
@@ -570,6 +575,7 @@ def native_sparse_attn_forward(
     seqlen_q_rounded = round_up_multiple(seqlen_q, TRITON_BLOCK_SIZE)
 
     lse = torch.empty((batch, nheads, seqlen_q_rounded), device = device, dtype = torch.float32)
+    sliding_lse = torch.empty((batch, nheads, seqlen_q_rounded), device = device, dtype = torch.float32)
 
     o = torch.empty_like(q)
     slide_o = torch.empty_like(q)
@@ -592,6 +598,7 @@ def native_sparse_attn_forward(
         o,
         slide_o,
         lse,
+        sliding_lse,
         softmax_scale,
         q.stride(0),
         q.stride(1),
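Note: with this change the sliding-window program instance writes its log-sum-exp into its own SlidingLse buffer instead of sharing Lse, and it skips the selected-KV loop entirely (num_sel_kv_blocks = 0). Keeping a per-branch LSE is what allows each branch's softmax to be renormalized later, e.g. in the backward pass or when merging partial attention outputs. A minimal PyTorch sketch of such a merge (illustrative only, not code from this commit; combine_attention_branches is a made-up name, and it assumes each branch output is already normalized by its own denominator over a disjoint set of keys):

import torch

def combine_attention_branches(o_a, lse_a, o_b, lse_b):
    # o_*:   (batch, heads, seq, dim), each normalized by its own softmax denominator
    # lse_*: (batch, heads, seq), log-sum-exp of each branch's attention logits
    lse = torch.logaddexp(lse_a, lse_b)        # combined softmax denominator, in log space
    w_a = (lse_a - lse).exp().unsqueeze(-1)    # fraction of total probability mass in branch a
    w_b = (lse_b - lse).exp().unsqueeze(-1)    # fraction of total probability mass in branch b
    return o_a * w_a + o_b * w_b, lse          # merged output and its combined lse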

test_triton_nsa.py

Lines changed: 1 addition & 1 deletion

@@ -141,7 +141,7 @@ def regular_attend(
 fine_block_size = 16
 num_sel = 6
 dim_head = 64
-fused_sliding_window = False
+fused_sliding_window = True
 block_dk_dv_use_dot = False # need sufficient shared memory, A100 works
 
 q = torch.randn(batch, q_heads, seq_len, dim_head).cuda()
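With fused_sliding_window = True the test now exercises the fused path, where the sliding-window and selected-block branches run in one kernel launch and tl.program_id(2) picks the branch (the RETURN_SLIDING_OUT check in the kernel above). A sketch of the assumed launch shape (fused_grid and its arguments are illustrative, not taken from the repo):

import triton

def fused_grid(seqlen_q, batch, nheads, block, return_sliding_out):
    # a third grid axis doubles the program count on the fused path:
    # program_id(2) == 0 -> sliding-window branch, else -> selected-block branch
    return (
        triton.cdiv(seqlen_q, block),       # query blocks
        batch * nheads,                     # (batch, head) pairs
        2 if return_sliding_out else 1,     # branch selector read via tl.program_id(2)
    )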
