
Commit 68fd8ee

continue setup towards optionally fused sliding window attention
1 parent d72f3a8 commit 68fd8ee

File tree

2 files changed: +67 −23 lines

native_sparse_attention_pytorch/triton_native_sparse_attention.py
test_triton_nsa.py


native_sparse_attention_pytorch/triton_native_sparse_attention.py

Lines changed: 50 additions & 19 deletions
@@ -450,6 +450,7 @@ def forward_kernel(
     kv_block_indices,
     kv_block_mask,
     Out,
+    SlidingOut,
     Lse,
     softmax_scale,
     stride_qb,
@@ -484,15 +485,22 @@ def forward_kernel(
     QUERY_EXPAND_DIM: tl.constexpr,
     NUM_SEL_KV_BLOCKS: tl.constexpr,
     INCLUDE_BLOCK_CAUSAL: tl.constexpr,
-    SLIDING: tl.constexpr
+    RETURN_SLIDING_OUT: tl.constexpr
 ):
+    if RETURN_SLIDING_OUT:
+        sliding = tl.program_id(2) == 0
+        out_ptr = SlidingOut if sliding else Out
+    else:
+        sliding = False
+        out_ptr = Out
+
     forward_kernel_causal_and_sparse(
         Q,
         K,
         V,
         kv_block_indices,
         kv_block_mask,
-        Out,
+        out_ptr,
         Lse,
         softmax_scale,
         stride_qb,
@@ -527,7 +535,7 @@ def forward_kernel(
         QUERY_EXPAND_DIM,
         NUM_SEL_KV_BLOCKS,
         INCLUDE_BLOCK_CAUSAL,
-        SLIDING
+        sliding
     )

 def native_sparse_attn_forward(
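The kernel now chooses its output buffer from the third grid axis: when RETURN_SLIDING_OUT is set, program id 0 along that axis writes into SlidingOut and program id 1 writes into Out, and the resulting sliding flag is forwarded to forward_kernel_causal_and_sparse in place of the old compile-time SLIDING constant. A rough host-side Python analogue of that dispatch (names follow the kernel arguments above, not a public API):

def pick_output(program_id_2, Out, SlidingOut, RETURN_SLIDING_OUT):
    # mirrors the branch at the top of forward_kernel: axis 2 of the launch
    # grid decides which of the two buffers this program instance fills
    if RETURN_SLIDING_OUT:
        sliding = (program_id_2 == 0)
        out_ptr = SlidingOut if sliding else Out
    else:
        sliding = False
        out_ptr = Out
    return sliding, out_ptr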
@@ -537,7 +545,8 @@ def native_sparse_attn_forward(
     kv_block_indices,
     kv_block_mask,
     block_size = 128,
-    include_block_causal = True
+    include_block_causal = True,
+    return_sliding_window_out = False
 ):
     q, k, v, kv_block_indices = [x if is_contiguous(x) else x.contiguous() for x in (q, k, v, kv_block_indices)]

@@ -563,11 +572,16 @@ def native_sparse_attn_forward(
     lse = torch.empty((batch, nheads, seqlen_q_rounded), device = device, dtype = torch.float32)

     o = torch.empty_like(q)
+    slide_o = torch.empty_like(q)

     BLOCK_HEADDIM = max(triton.next_power_of_2(dim), 16)
     num_warps = 4 if dim <= 64 else 8

-    grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK"]), batch * kv_heads) # kv heads here, as grouped query heads all loaded, following the paper
+    grid = lambda META: (
+        triton.cdiv(seqlen_q, META["BLOCK"]),
+        batch * kv_heads,
+        (2 if return_sliding_window_out else 1)
+    ) # kv heads here, as grouped query heads all loaded, following the paper

     forward_kernel[grid](
         q,
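With the new flag on, the launch grid gains a third axis of size 2, so every (query block, batch × kv head) tile is launched twice, once per output buffer. A minimal sketch of how the program count scales (the sizes here are made up for illustration):

import triton

seqlen_q, batch, kv_heads, BLOCK = 1024, 2, 4, 64

for return_sliding_window_out in (False, True):
    grid = (
        triton.cdiv(seqlen_q, BLOCK),   # query blocks along the sequence
        batch * kv_heads,               # one program per kv head; grouped query heads are loaded together
        2 if return_sliding_window_out else 1,
    )
    print(grid, grid[0] * grid[1] * grid[2])  # the third axis doubles the number of programs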
@@ -576,6 +590,7 @@ def native_sparse_attn_forward(
         kv_block_indices,
         kv_block_mask,
         o,
+        slide_o,
         lse,
         softmax_scale,
         q.stride(0),
@@ -606,12 +621,12 @@ def native_sparse_attn_forward(
         QUERY_HEAD_GROUPS = head_groups,
         NUM_SEL_KV_BLOCKS = num_selected_fine_blocks,
         INCLUDE_BLOCK_CAUSAL = include_block_causal,
-        SLIDING = False,
+        RETURN_SLIDING_OUT = False,
         num_warps = num_warps,
         num_stages = 1,
     )

-    return o, lse
+    return o, slide_o, lse

 @triton.jit
 def backward_preprocess_do_o_dot(
@@ -1648,7 +1663,8 @@ def forward(
         fmask,
         sel_scale,
         include_block_causal,
-        block_dk_dv_use_dot
+        block_dk_dv_use_dot,
+        return_sliding_window_out
     ):
         dtype = fq.dtype

@@ -1658,15 +1674,16 @@ def forward(

         fq, fk, fv = tuple(t.half() for t in (fq, fk, fv))

-        out, lse = native_sparse_attn_forward(
+        out, slide_out, lse = native_sparse_attn_forward(
             fq, fk, fv,
             selected_block_indices,
             fmask,
             block_size = block_size,
             include_block_causal = include_block_causal,
+            return_sliding_window_out = return_sliding_window_out
         )

-        ctx.save_for_backward(fq, fk, fv, selected_block_indices, fmask, out, lse)
+        ctx.save_for_backward(fq, fk, fv, selected_block_indices, fmask, out, slide_out, lse)

         return_sel_grads = exists(sel_scale)

@@ -1678,23 +1695,32 @@ def forward(
             head_groups,
             return_sel_grads,
             include_block_causal,
-            block_dk_dv_use_dot
+            block_dk_dv_use_dot,
+            return_sliding_window_out
         )

-        return out.type(dtype), lse
+        return out.type(dtype), slide_out.type(dtype), lse

     @classmethod
-    def backward(self, ctx, do, _):
+    def backward(self, ctx, do, do_sliding, _):
         device = do.device

-        q, k, v, sel_block_indices, mask, out, lse = ctx.saved_tensors
+        (
+            q, k, v,
+            sel_block_indices,
+            mask,
+            out,
+            slide_out,
+            lse
+        ) = ctx.saved_tensors

         (
             block_size,
             head_groups,
             return_sel_grads,
             include_block_causal,
-            block_dk_dv_use_dot
+            block_dk_dv_use_dot,
+            return_sliding_window_out
         ) = ctx._saved_variables

         do = do.half()
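The extra forward output and input change the autograd contract: backward now receives one incoming gradient per forward output (do, do_sliding, and the ignored gradient for lse) and must return one entry per forward input, which is why an additional trailing None appears in the return a few lines below. A toy sketch of that rule, unrelated to this kernel and written in the usual staticmethod style:

import torch

class TwoOutputs(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, some_flag):
        ctx.save_for_backward(x)
        return x * 2, x * 3          # two outputs -> backward receives two gradients

    @staticmethod
    def backward(ctx, grad_a, grad_b):
        (x,) = ctx.saved_tensors
        # one return value per forward input; non-tensor inputs get None
        return grad_a * 2 + grad_b * 3, None

x = torch.randn(4, requires_grad = True)
a, b = TwoOutputs.apply(x, True)
(a.sum() + b.sum()).backward()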
@@ -1719,7 +1745,7 @@ def backward(self, ctx, do, _):
         if return_sel_grads:
             ret_sel_grads = sel_grads

-        return dq, dk, dv, None, None, None, ret_sel_grads, None, None
+        return dq, dk, dv, None, None, None, ret_sel_grads, None, None, None

 _native_sparse_attend = NSA.apply

@@ -1742,7 +1768,8 @@ def native_sparse_attend(
     sel_scale: Float['b kh n sel'] | Float['b qh n sel'] | None = None,
     include_block_causal = True,
     return_lse = False,
-    block_dk_dv_use_dot = False
+    block_dk_dv_use_dot = False,
+    return_sliding_window_out = False
 ):
     seq_len = fq.shape[-2]
     q_heads, kv_heads, sel_heads = fq.shape[1], fk.shape[1], selected_block_indices.shape[1]
@@ -1757,16 +1784,20 @@ def native_sparse_attend(
     if kv_heads != sel_heads:
         fk, fv = tuple(repeat(t, 'b h ... -> b (h gh) ...', gh = q_heads // kv_heads) for t in (fk, fv))

-    out, lse = _native_sparse_attend(
+    out, sliding_out, lse = _native_sparse_attend(
         fq, fk, fv,
         block_size,
         selected_block_indices,
         fmask,
         sel_scale,
         include_block_causal,
-        block_dk_dv_use_dot
+        block_dk_dv_use_dot,
+        return_sliding_window_out
     )

+    if return_sliding_window_out:
+        out = (out, out)
+
     if not return_lse:
         return out

test_triton_nsa.py

Lines changed: 17 additions & 4 deletions
@@ -162,17 +162,30 @@ def regular_attend(
     out, rlse = regular_attend(rq, rk, rv, indices, mask, block_size = fine_block_size, sel_scale = rsel_scale, return_lse = True, return_sliding_window_out = fused_sliding_window)

     if fused_sliding_window:
-        out = sum(out)
+        loss = sum(out).sum()
+    else:
+        loss = out.sum()

-    out.sum().backward()
+    loss.backward()

     # triton nsa forwards and backwards

-    nsa_out, nlse = native_sparse_attend(nq, nk, nv, fine_block_size, indices, mask, sel_scale = nsel_scale, return_lse = True, block_dk_dv_use_dot = block_dk_dv_use_dot)
-    nsa_out.sum().backward()
+    nsa_out, nlse = native_sparse_attend(nq, nk, nv, fine_block_size, indices, mask, sel_scale = nsel_scale, return_lse = True, block_dk_dv_use_dot = block_dk_dv_use_dot, return_sliding_window_out = fused_sliding_window)
+
+    if fused_sliding_window:
+        nsa_loss = sum(nsa_out).sum()
+    else:
+        nsa_loss = nsa_out.sum()
+
+    nsa_loss.backward()

     # asserts

+    if fused_sliding_window:
+        out, sliding_out = out
+        nsa_out, sliding_nsa_out = nsa_out
+        assert torch.allclose(sliding_out, sliding_nsa_out, atol = 1e-2)
+
     assert torch.allclose(out, nsa_out, atol = 1e-2)
     assert torch.allclose(rlse, nlse, atol = 1e-2)

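When the fused path is exercised, both the reference and Triton outputs are 2-tuples, and sum(out) adds the two tensors element-wise before the scalar reduction, so a single backward pass propagates gradients through both branches. A small illustration of that idiom:

import torch

a = torch.randn(3, requires_grad = True)
out = (a * 2, a * 3)      # stand-in for the (attended, sliding_window_attended) pair

loss = sum(out).sum()     # sum(tuple) -> elementwise a * 2 + a * 3, then reduce to a scalar
loss.backward()

assert torch.allclose(a.grad, torch.full_like(a, 5.))  # gradients from both outputs accumulate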
