
Commit 77e55aa

move the gating back onto the selected keys for improved differentiable topk, do it for the blocked values later as well
1 parent cee1248 · commit 77e55aa

File tree: 4 files changed, +87 −34 lines
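Before the per-file diffs, a minimal sketch of the straight-through gating the commit rearranges. `straight_through` is assumed here to follow the usual "take the target value forward, pass the gradient through" trick; its exact definition is not part of this diff. The point is that the gates are exact ones in the forward pass, while their gradient reaches the selected importance scores, which is what makes the topk selection trainable.

import torch

def straight_through(t, target):
    # assumed helper: forward value becomes `target`, gradient still flows to `t`
    return t + (target - t).detach()

importance_scores = torch.randn(1, 2, 8, 6, requires_grad = True)   # (batch, kv heads, seq, num blocks) - illustrative shape
selected_importance_values, selected_block_indices = importance_scores.topk(3, dim = -1)

gates = straight_through(selected_importance_values, 1.)

assert torch.allclose(gates.detach(), torch.ones_like(gates))        # forward pass is unchanged
gates.sum().backward()
assert (importance_scores.grad != 0).any()                           # but the topk'd scores receive gradient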

native_sparse_attention_pytorch/native_sparse_attention.py
Lines changed: 13 additions & 12 deletions

@@ -616,10 +616,10 @@ def forward(

         selected_importance_values, selected_block_indices = importance_scores.topk(num_selected, dim = -1)

+        gates = None
+
         if self.use_diff_topk:
             gates = straight_through(selected_importance_values, 1.)
-            gates = gates.cumprod(dim = -1)[..., -1]
-            gates = repeat(gates, 'b h ... -> b (h qh) ...', qh = fine_num_grouped_queries)

         if self.use_triton_kernel and not disable_triton_kernel:

@@ -631,10 +631,13 @@ def forward(
                 fq, fk, fv,
                 self.selection_block_size,
                 selected_block_indices,
-                fmask
+                fmask,
+                sel_scale = gates
             )

         elif exists(fine_selection_flex_mask):
+            assert not self.use_diff_topk, 'differential topk is not available for flex attention'
+
             # flex attention for the selection for fine attention

             fine_block_mask = fine_selection_flex_mask(selected_block_indices, num_grouped_queries = fine_num_grouped_queries)

@@ -654,9 +657,6 @@ def forward(

             selected_block_indices = pad_at_dim(selected_block_indices, (0, remainder), value = 0, dim = -2)

-            if self.use_diff_topk:
-                gates = pad_at_dim(gates, (0, remainder), value = 1.)
-
             # handle block causal diagonal in the diagram, but run experiments without to see

             fine_window_seq = arange(fine_divisible_seq_len, device = device) // self.selection_block_size

@@ -690,6 +690,13 @@ def forward(
             fk = fk.gather(3, selected_block_indices)
             fv = fv.gather(3, selected_block_indices)

+            # differential topk gating
+
+            if self.use_diff_topk:
+                fk = einx.multiply('b h i sel, b h i sel j d -> b h i sel j d', gates, fk)
+
+            # merge selected key values
+
             fk, fv = tuple(rearrange(t, 'b h i w j d -> b h i (w j) d') for t in (fk, fv))

             # fine attention

@@ -712,12 +719,6 @@ def forward(

             fine_attn_out = fine_attn_out[..., :seq_len, :]

-            # handle maybe gating
-
-            if self.use_diff_topk:
-                gates = gates[..., :seq_len]
-                fine_attn_out = einx.multiply('b h n, b h n d -> b h n d', gates, fine_attn_out)
-
         else:
             # if only first block, just do a simple block causal
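The effect of the diff above: previously a single cumprod-collapsed gate scaled the whole fine attention output; now a per-selected-block gate multiplies the gathered keys before they are merged, so each selected block gets its own gradient signal through the attention logits. A rough plain-PyTorch sketch of that gating point (shapes and the einsum formulation are illustrative, not the module's actual code, which uses einx and grouped query heads):

import torch

b, h, n, sel, blk, d = 1, 2, 4, 3, 8, 16   # illustrative sizes

gates = torch.ones(b, h, n, sel, requires_grad = True)    # straight-through gates, all ones in the forward
fq    = torch.randn(b, h, n, d)
fk    = torch.randn(b, h, n, sel, blk, d)                 # gathered keys, one block of `blk` keys per selection
fv    = torch.randn(b, h, n, sel, blk, d)

fk = fk * gates[..., None, None]                          # same role as the einx.multiply line in the diff

fk, fv = tuple(t.reshape(b, h, n, sel * blk, d) for t in (fk, fv))   # merge selected key value blocks

sim  = torch.einsum('bhnd,bhnmd->bhnm', fq, fk) * d ** -0.5
attn = sim.softmax(dim = -1)
out  = torch.einsum('bhnm,bhnmd->bhnd', attn, fv)

out.sum().backward()
print(gates.grad.shape)   # (b, h, n, sel): a distinct gradient per query position and selected block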

native_sparse_attention_pytorch/triton_native_sparse_attention.py
Lines changed: 58 additions & 16 deletions

@@ -590,6 +590,7 @@ def backward_kernel_one_col_block_sparse(
    V,
    kv_block_indices,
    kv_block_mask,
+   kv_block_grads,
    DO,
    DQ,
    DK,

@@ -619,6 +620,7 @@ def backward_kernel_one_col_block_sparse(
    BLOCK: tl.constexpr,
    QUERY_HEAD_GROUPS: tl.constexpr,
    QUERY_EXPAND_DIM: tl.constexpr,
+   RETURN_SEL_GRADS: tl.constexpr,
    OFF_SEL_KV_BLOCKS: tl.constexpr
 ):
    # We need to make sure begin_m is a multiple of BLOCK_M (not BLOCK_N)

@@ -638,9 +640,6 @@ def backward_kernel_one_col_block_sparse(

    # initialize pointers to value-like data

-   k_ptrs = K + (offs_n[:, None] * stride_kn + offs_d[None, :])
-   v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :])
-
    q_ptrs = (
        Q +
        offs_g[None, :, None] * stride_qh +

@@ -794,9 +793,9 @@ def backward_kernel_one_col_block_sparse(
        block_qk = block_qk.reshape(BLOCK, QUERY_HEAD_GROUPS, QUERY_EXPAND_DIM, BLOCK)
        qk = tl.sum(block_qk, 2) / QUERY_EXPAND_DIM

-       qk += tl.where(block_masks[:, None, None], 0, float("-inf"))
+       masked_qk = qk + tl.where(block_masks[:, None, None], 0, float("-inf"))

-       p = tl.exp(qk * softmax_scale - lse_i[:, :, None])
+       p = tl.exp(masked_qk * softmax_scale - lse_i[:, :, None])

        # take care of block dv

@@ -823,6 +822,26 @@ def backward_kernel_one_col_block_sparse(

        ds = (p * (dp - Di[:, :, None]) * softmax_scale)

+       # maybe return gradients for better differentiable topk
+
+       if RETURN_SEL_GRADS:
+
+           kv_block_grads_ptrs = (
+               kv_block_grads +
+               offs_m * stride_kvbl_m
+           )
+
+           sel_grads = ds * qk
+           sel_grads = tl.where(block_masks[:, None, None], sel_grads, 0.)
+           sel_grads = sel_grads.reshape(BLOCK, QUERY_HEAD_GROUPS * BLOCK)
+           sel_grads = tl.sum(sel_grads, 1)
+
+           tl.atomic_add(
+               kv_block_grads_ptrs + OFF_SEL_KV_BLOCKS,
+               sel_grads,
+               sem = 'relaxed'
+           )
+
        # block dk

        block_dk = ds[:, :, :, None] * q[:, :, None, :].to(ds.dtype)
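A sanity check on the quantity the new RETURN_SEL_GRADS branch accumulates. Assuming the forward kernel applies sel_scale as a per-block multiplier on the raw q·k logits (the forward kernel is not shown in this diff, and the straight-through value of 1. leaves the forward untouched), the chain rule gives d(loss)/d(sel_scale) as the sum over the block of d(loss)/d(logit) times the raw logit, which is what `sel_grads = ds * qk`, masked and summed, feeds into the atomic add. A tiny single-block autograd check of that identity, with placeholder sizes and 0.125 standing in for softmax_scale:

import torch

qk    = torch.randn(16)                          # raw q.k logits for one selected block
scale = torch.ones((), requires_grad = True)     # stand-in for the straight-through sel_scale
v     = torch.randn(16, 8)
do    = torch.randn(8)                           # upstream gradient of the attention output

logits = qk * scale                              # per-block gate applied to the unscaled logits
p      = (logits * 0.125).softmax(dim = -1)      # 0.125 plays the role of softmax_scale
loss   = (p @ v) @ do

dlogits, = torch.autograd.grad(loss, logits, retain_graph = True)
loss.backward()

# the block's selection gradient is the logit gradient dotted with the raw logits,
# mirroring sel_grads = tl.sum(ds * qk) in the kernel
assert torch.allclose(scale.grad, (dlogits * qk).sum(), atol = 1e-6)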
@@ -1145,6 +1164,7 @@ def backward_kernel(
    V,
    kv_block_indices,
    kv_block_mask,
+   kv_block_grads,
    DO,
    DQ,
    DK,

@@ -1192,19 +1212,16 @@ def backward_kernel(
    BLOCK: tl.constexpr,
    QUERY_HEAD_GROUPS: tl.constexpr,
    QUERY_EXPAND_DIM: tl.constexpr,
+   RETURN_SEL_GRADS: tl.constexpr,
    INCLUDE_BLOCK_CAUSAL: tl.constexpr
 ):
    off_hb = tl.program_id(1)
    off_b = off_hb // kv_heads
    off_h = off_hb % kv_heads
    off_qh = off_h * QUERY_HEAD_GROUPS

-   if INCLUDE_BLOCK_CAUSAL:
-       IS_CAUSAL = tl.program_id(0) == 0
-       OFF_SEL_KV_BLOCKS = tl.program_id(0) - 1
-   else:
-       IS_CAUSAL = False
-       OFF_SEL_KV_BLOCKS = tl.program_id(0)
+   OFF_SEL_KV_BLOCKS = tl.program_id(0) - int(INCLUDE_BLOCK_CAUSAL)
+   IS_CAUSAL = INCLUDE_BLOCK_CAUSAL and tl.program_id(0) == 0

    # offset pointers for batch/head

@@ -1220,6 +1237,7 @@ def backward_kernel(

    kv_block_indices += off_b * stride_kvbl_b + off_h * stride_kvbl_h
    kv_block_mask += off_b * stride_kvbl_b + off_h * stride_kvbl_h
+   kv_block_grads += off_b * stride_kvbl_b + off_h * stride_kvbl_h

    # pointer to row-wise quantities in value-like data

@@ -1283,6 +1301,7 @@ def backward_kernel(
        V,
        kv_block_indices,
        kv_block_mask,
+       kv_block_grads,
        DO,
        DQ,
        DK,

@@ -1312,6 +1331,7 @@ def backward_kernel(
        BLOCK = BLOCK,
        QUERY_HEAD_GROUPS = QUERY_HEAD_GROUPS,
        QUERY_EXPAND_DIM = QUERY_EXPAND_DIM,
+       RETURN_SEL_GRADS = RETURN_SEL_GRADS,
        OFF_SEL_KV_BLOCKS = OFF_SEL_KV_BLOCKS
    )

@@ -1320,11 +1340,13 @@ def native_sparse_attn_backward(
    q, k, v,
    kv_block_indices,
    kv_block_mask,
+   kv_block_grads,
    o,
    lse,
    dq, dk, dv,
    block_size = 128,
-   include_block_causal = True
+   include_block_causal = True,
+   return_sel_grads = False
 ):
    device = do.device

@@ -1387,6 +1409,7 @@ def native_sparse_attn_backward(
        v,
        kv_block_indices,
        kv_block_mask,
+       kv_block_grads,
        do,
        dq,
        dk,

@@ -1436,6 +1459,7 @@ def native_sparse_attn_backward(
        EVEN_M = divisible_by(seqlen_q, block_size),
        EVEN_N = divisible_by(seqlen_k, block_size),
        EVEN_HEADDIM = BLOCK_HEADDIM == dim,
+       RETURN_SEL_GRADS = return_sel_grads,
        INCLUDE_BLOCK_CAUSAL = include_block_causal,
        # BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
        # num_warps=num_warps,

@@ -1458,6 +1482,7 @@ def forward(
        block_size,
        selected_block_indices,
        fmask,
+       sel_scale,
        include_block_causal
    ):
        dtype = fq.dtype

@@ -1478,10 +1503,16 @@ def forward(

        ctx.save_for_backward(fq, fk, fv, selected_block_indices, fmask, out, lse)

+       return_sel_grads = exists(sel_scale)
+
+       if return_sel_grads:
+           assert (sel_scale == 1.).all(), 'for now, must be straight through as multiplier of 1.'
+
        ctx._saved_variables = (
            block_size,
            head_groups,
-           include_block_causal
+           return_sel_grads,
+           include_block_causal,
        )

        return out.type(dtype), lse

@@ -1495,6 +1526,7 @@ def backward(self, ctx, do, _):
        (
            block_size,
            head_groups,
+           return_sel_grads,
            include_block_causal
        ) = ctx._saved_variables

@@ -1503,15 +1535,23 @@ def backward(self, ctx, do, _):
        dk = torch.zeros(k.shape, dtype = torch.float32, device = device)
        dv = torch.zeros(v.shape, dtype = torch.float32, device = device)

+       sel_grads = torch.zeros_like(sel_block_indices).float()
+
        native_sparse_attn_backward(
            do, q, k, v,
-           sel_block_indices, mask,
+           sel_block_indices, mask, sel_grads,
            out, lse, dq, dk, dv,
            block_size = block_size,
-           include_block_causal = include_block_causal
+           include_block_causal = include_block_causal,
+           return_sel_grads = return_sel_grads
        )

-       return dq, dk, dv, None, None, None, None
+       ret_sel_grads = None
+
+       if return_sel_grads:
+           ret_sel_grads = sel_grads
+
+       return dq, dk, dv, None, None, None, ret_sel_grads, None

 _native_sparse_attend = NSA.apply

@@ -1531,6 +1571,7 @@ def native_sparse_attend(
    block_size: int,
    selected_block_indices: Int['b qh n sel'] | Int['b kh n sel'],
    fmask: Bool['b qh n sel'] | Bool['b kh n sel'],
+   sel_scale: Float['b kh n sel'] | Float['b qh n sel'] | None = None,
    include_block_causal = True,
    return_lse = False
 ):

@@ -1550,6 +1591,7 @@ def native_sparse_attend(
        block_size,
        selected_block_indices,
        fmask,
+       sel_scale,
        include_block_causal
    )

pyproject.toml
Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 [project]
 name = "native-sparse-attention-pytorch"
-version = "0.0.68"
+version = "0.0.69"
 description = "Native Sparse Attention"
 authors = [
    { name = "Phil Wang", email = "lucidrains@gmail.com" }

test_triton_nsa.py
Lines changed: 15 additions & 5 deletions

@@ -21,21 +21,24 @@ def regular_attend(
    indices,
    mask,
    block_size,
+   sel_scale = None,
    return_lse = False
 ):
    q_heads, seq_len, kv_heads, device = q.shape[1], q.shape[-2], k.shape[1], q.device
    assert divisible_by(q_heads, kv_heads)

    q, k, v = tuple(pad_to_multiple(t, block_size, dim = -2) for t in (q, k, v))

+   if exists(sel_scale):
+       sel_scale = pad_to_multiple(sel_scale, block_size, dim = -2)
+
    g = q_heads // kv_heads # `g` stands for `g`roups of query heads per kv head

    w = ceil(seq_len / block_size)

    q, k, v = tuple(rearrange(t, 'b h (w n) d -> b h w n d', n = block_size) for t in (q, k, v))

    scale = q.shape[-1] ** -0.5
-   q = q * scale

    q = rearrange(q, 'b (h g) ... -> b h g ...', g = g)

@@ -62,6 +65,10 @@ def regular_attend(

        bsim = rearrange(bsim, 'b h g (w i) (sel j) -> b h g w i sel j', sel = num_sel_kv_blocks, i = fine_block_size)

+       if exists(sel_scale):
+           sel_scale = rearrange(sel_scale, 'b h (w i) sel -> b h w i sel', i = fine_block_size)
+           bsim = einx.multiply('b h g w i sel j, b h w i sel -> b h g w i sel j', bsim, sel_scale)
+
        mask = rearrange(mask, 'b h (w i) sel -> b h 1 w i sel', i = fine_block_size)
        bsim = torch.where(mask[..., None], bsim, -torch.finfo(bsim.dtype).max)

@@ -78,6 +85,7 @@ def regular_attend(

    # attend

+   sim = sim * scale
    attn = sim.softmax(dim = -1)

    if has_sel_kv_blocks:

@@ -113,27 +121,29 @@ def regular_attend(

 indices = torch.randint(0, 2, (batch, kv_heads, seq_len, num_sel)).cuda()
 mask = torch.randint(0, 2, (batch, kv_heads, seq_len, num_sel)).bool().cuda()
+sel_scale = torch.ones((batch, kv_heads, seq_len, num_sel)).cuda()

 # both regular and nsa pathways `r` and `n`

-rq, rk, rv = tuple(t.clone().requires_grad_() for t in (q, k, v))
-nq, nk, nv = tuple(t.clone().requires_grad_() for t in (q, k, v))
+rq, rk, rv, rsel_scale = tuple(t.clone().requires_grad_() for t in (q, k, v, sel_scale))
+nq, nk, nv, nsel_scale = tuple(t.clone().requires_grad_() for t in (q, k, v, sel_scale))

 # regular forwards and backwards

-out, rlse = regular_attend(rq, rk, rv, indices, mask, block_size = fine_block_size, return_lse = True)
+out, rlse = regular_attend(rq, rk, rv, indices, mask, block_size = fine_block_size, sel_scale = rsel_scale, return_lse = True)
 out.sum().backward()

 # triton nsa forwards and backwards

-nsa_out, nlse = native_sparse_attend(nq, nk, nv, fine_block_size, indices, mask, return_lse = True)
+nsa_out, nlse = native_sparse_attend(nq, nk, nv, fine_block_size, indices, mask, sel_scale = nsel_scale, return_lse = True)
 nsa_out.sum().backward()

 # asserts

 assert torch.allclose(out, nsa_out, atol = 1e-2)
 assert torch.allclose(rlse, nlse, atol = 1e-2)

+assert torch.allclose(rsel_scale.grad, nsel_scale.grad, atol = 1e-2)
 assert torch.allclose(nv.grad, rv.grad, atol = 1e-2)
 assert torch.allclose(nq.grad, rq.grad, atol = 1e-2)
 assert torch.allclose(nk.grad, rk.grad, atol = 1e-2)