Commit cfa6e08

committed
do the differential topk gating in a less optimal way, but one that accommodates fused attention better
1 parent 671534a commit cfa6e08

File tree

2 files changed (+12, -13 lines)


native_sparse_attention_pytorch/native_sparse_attention.py

Lines changed: 11 additions & 12 deletions
@@ -398,8 +398,9 @@ def forward(
             selected_importance_values, selected_block_indices = importance_scores.topk(num_selected, dim = -1)

             if self.use_diff_topk:
-                assert not exists(fine_selection_flex_mask)
                 gates = straight_through(selected_importance_values, 1.)
+                gates = gates.cumsum(dim = -1)[..., -1]
+                gates = repeat(gates, 'b h ... -> b (h qh) ...', qh = self.num_grouped_queries)

             if exists(fine_selection_flex_mask):
                 # flex attention for the selection for fine attention
@@ -422,7 +423,7 @@ def forward(
                     selected_block_indices = pad_at_dim(selected_block_indices, (0, remainder), value = 0, dim = -2)

                     if self.use_diff_topk:
-                        gates = pad_at_dim(gates, (0, remainder), value = 1., dim = -2)
+                        gates = pad_at_dim(gates, (0, remainder), value = 1.)

                 # handle block causal diagonal in the diagram, but run experiments without to see

@@ -453,16 +454,7 @@ def forward(
                 fk = fk.gather(3, selected_block_indices)
                 fv = fv.gather(3, selected_block_indices)

-                # handle maybe gating
-
-                if self.use_diff_topk:
-                    gates = F.pad(gates, (0, 1), value = 1.)
-
-                    fk = einx.multiply('b h i w, b h i w j d -> b h i w j d', gates, fk)
-                    fv = einx.multiply('b h i w, b h i w j d -> b h i w j d', gates, fv)
-
-                fk = rearrange(fk, 'b h i w j d -> b h i (w j) d')
-                fv = rearrange(fv, 'b h i w j d -> b h i (w j) d')
+                fk, fv = tuple(rearrange(t, 'b h i w j d -> b h i (w j) d') for t in (fk, fv))

                 # fine attention

@@ -483,6 +475,13 @@ def forward(
                 fine_attn_out = rearrange(fine_attn_out, 'b h qh ... -> b (h qh) ...')

                 fine_attn_out = fine_attn_out[..., :seq_len, :]
+
+                # handle maybe gating
+
+                if self.use_diff_topk:
+                    gates = gates[..., :seq_len]
+                    fine_attn_out = einx.multiply('b h n, b h n d -> b h n d', gates, fine_attn_out)
+
         else:
             # if only first block, just do a simple block causal

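The substance of the change: previously, when use_diff_topk was enabled, a straight-through gate with forward value 1. was multiplied into every gathered fine key / value block (and the flex-attention path was asserted away entirely). Now the per-block gates are collapsed into a single gate per query position (cumsum over the selection dimension, last element), repeated across the grouped query heads, and multiplied onto the fine attention output afterwards, so the attention computation itself stays untouched and can run inside a fused kernel. Below is a minimal, self-contained sketch of the pattern, not the library's actual code: it assumes straight_through(t, target) is the usual t + (target - t).detach() trick, uses toy shapes, and omits the repeat across grouped query heads.

import torch

def straight_through(t, target):
    # assumed definition: forward pass behaves like `target`, gradients flow into `t`
    return t + (target - t).detach()

# toy shapes standing in for the real ones
batch, heads, seq_len, num_blocks, dim_head, num_selected = 1, 2, 8, 6, 16, 3

importance_scores = torch.randn(batch, heads, seq_len, num_blocks, requires_grad = True)
fine_attn_out = torch.randn(batch, heads, seq_len, dim_head)  # stand-in for the fused fine attention output

selected_importance_values, selected_block_indices = importance_scores.topk(num_selected, dim = -1)

# one gate per selected block, forward value exactly 1.
gates = straight_through(selected_importance_values, 1.)

# collapse to one gate per query position: the forward value is the constant
# num_selected, but gradients still reach every selected importance score
gates = gates.cumsum(dim = -1)[..., -1]

# gate the attention output instead of the gathered keys / values
gated_out = fine_attn_out * gates.unsqueeze(-1)

gated_out.sum().backward()

# gradient lands only on the top-k entries of the importance scores
assert (importance_scores.grad != 0).sum().item() == batch * heads * seq_len * num_selected

In this sketch the collapsed gate is purely a device for routing gradients from the attention output back into the top-k importance scores; the trade-off, presumably the "less optimal" part the commit message concedes, is that every selected block at a given query position now receives the same gradient through the shared gate rather than a per-block signal through its own keys and values.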
pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "native-sparse-attention-pytorch"
-version = "0.0.38"
+version = "0.0.39"
 description = "Native Sparse Attention"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
