
Commit 839d906

fix an issue with block diagonal causal in fine attention + bring in coordinate descent based diff topk

1 parent 822dfdc

File tree

3 files changed, +36 -9 lines changed

native_sparse_attention_pytorch/native_sparse_attention.py

28 additions & 6 deletions

@@ -7,6 +7,8 @@
 import torch.nn.functional as F
 from torch.nn import Module, ModuleList

+from colt5_attention import topk as differentiable_topk
+
 from local_attention import LocalAttention

 from rotary_embedding_torch import RotaryEmbedding
@@ -84,8 +86,11 @@ def __init__(
         num_selected_blocks,
         num_compressed_mem_kv = 4,
         norm = True,
+        use_diff_topk = False,
+        diff_topk_coor_descent_iters = 10.
     ):
         super().__init__()
+        self.heads = heads
         self.scale = dim_head ** -0.5

         assert compress_block_size == selection_block_size, 'start off with compressed being equal to selection block sizes'
@@ -136,6 +141,9 @@ def __init__(

         # selection related

+        self.use_diff_topk = use_diff_topk
+        self.diff_topk_coor_descent_iters = diff_topk_coor_descent_iters
+
         self.selection_block_size = selection_block_size
         self.num_selected_blocks = num_selected_blocks

@@ -160,7 +168,7 @@ def forward(
         self,
         inp
     ):
-        batch, seq_len, scale, device = *inp.shape[:2], self.scale, inp.device
+        batch, seq_len, scale, heads, device = *inp.shape[:2], self.scale, self.heads, inp.device

         compress_divisible_seq_len = round_down_mult(seq_len, self.compress_block_size)
         num_compress_blocks = compress_divisible_seq_len // self.compress_block_size
@@ -216,7 +224,10 @@ def forward(

         importance_scores = csim[..., num_mem_compress_kv:]

-        selected_importance_values, selected_block_indices = importance_scores.topk(self.num_selected_blocks, dim = -1)
+        if self.use_diff_topk:
+            selected_importance_values, selected_block_indices, _, gates = differentiable_topk(importance_scores, self.num_selected_blocks, fused = True)
+        else:
+            selected_importance_values, selected_block_indices = importance_scores.topk(self.num_selected_blocks, dim = -1)

         fmask = selected_importance_values > mask_value

@@ -239,13 +250,13 @@ def forward(
         # handle block causal diagonal in the diagram, but run experiments without to see

         fine_window_seq = arange(fine_divisible_seq_len, device = device) // self.selection_block_size
-        fine_window_seq = rearrange(fine_window_seq, 'n -> n 1').expand_as(selected_block_indices)
+        fine_window_seq = repeat(fine_window_seq, 'n -> b h n 1', b = batch, h = heads)
         selected_block_indices = cat((selected_block_indices, fine_window_seq), dim = -1) # for the block causal diagonal in fig2

         fmask = repeat(fmask, 'b h i w -> b h i w j', j = self.selection_block_size)

         causal_mask = torch.ones((self.selection_block_size,) * 2, device = device, dtype = torch.bool).tril()
-        causal_mask = repeat(causal_mask, 'i j -> (w i) 1 j', w = num_fine_blocks).expand_as(fmask)
+        causal_mask = repeat(causal_mask, 'i j -> b h (w i) 1 j', w = num_fine_blocks, b = batch, h = heads)

         fmask = cat((fmask, causal_mask), dim = -2)
         fmask = rearrange(fmask, 'b h i w j -> b h i (w j)')
@@ -255,8 +266,19 @@ def forward(
         fk = rearrange(fk, 'b h (w n) d -> b h w n d', w = num_fine_blocks)
         fv = rearrange(fv, 'b h (w n) d -> b h w n d', w = num_fine_blocks)

-        fk = einx.get_at('b h [w] j d, b h i selected -> b h i (selected j) d', fk, selected_block_indices)
-        fv = einx.get_at('b h [w] j d, b h i selected -> b h i (selected j) d', fv, selected_block_indices)
+        fk = einx.get_at('b h [w] j d, b h i selected -> b h i selected j d', fk, selected_block_indices)
+        fv = einx.get_at('b h [w] j d, b h i selected -> b h i selected j d', fv, selected_block_indices)
+
+        # handle maybe gating
+
+        if self.use_diff_topk:
+            gates = F.pad(gates, (0, 1, 0, remainder), value = 1.)
+
+            fk = einx.multiply('b h i w, b h i w j d -> b h i w j d', gates, fk)
+            fv = einx.multiply('b h i w, b h i w j d -> b h i w j d', gates, fv)
+
+        fk = rearrange(fk, 'b h i w j d -> b h i (w j) d')
+        fv = rearrange(fv, 'b h i w j d -> b h i (w j) d')

         # fine attention

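To make the reshaping in this diff easier to follow, here is a small self-contained sketch of the selected-block gather, the block-causal diagonal append, and the gating of the fine key blocks. It uses made-up sizes, plain tensor indexing in place of einx.get_at, and a softmax over the top-k values as a stand-in for the coordinate descent gates from colt5_attention; it illustrates the shapes only and is not the module's implementation.

import torch
import torch.nn.functional as F
from einops import rearrange, repeat

b, h, d = 2, 4, 16             # batch, heads, head dimension (made up)
sel_block = 4                  # selection_block_size
num_blocks = 8                 # number of fine blocks
seq = sel_block * num_blocks   # fine-divisible sequence length, so no remainder padding needed
k_sel = 2                      # num_selected_blocks

importance_scores = torch.randn(b, h, seq, num_blocks)

# stand-in for differentiable_topk: hard top-k indices plus soft gates over the top-k values
values, indices = importance_scores.topk(k_sel, dim = -1)
gates = values.softmax(dim = -1)                          # (b, h, seq, k_sel)

# block causal diagonal: every query position also attends to its own fine block
own_block = torch.arange(seq) // sel_block
own_block = repeat(own_block, 'n -> b h n 1', b = b, h = h)
indices = torch.cat((indices, own_block), dim = -1)       # (b, h, seq, k_sel + 1)

# fine keys grouped by block; values would be handled identically
fk = torch.randn(b, h, num_blocks, sel_block, d)

# gather the selected blocks, keeping the selected (w) and within-block (j) axes
# separate so the gates can broadcast (plain indexing standing in for einx.get_at)
fk_sel = fk[
    torch.arange(b)[:, None, None, None],
    torch.arange(h)[None, :, None, None],
    indices
]                                                         # (b, h, seq, k_sel + 1, sel_block, d)

# gate the selected blocks; the appended diagonal block gets a gate of 1
gates = F.pad(gates, (0, 1), value = 1.)
fk_sel = fk_sel * gates[..., None, None]

# flatten back to the (w j) layout consumed by the fine attention step
fk_sel = rearrange(fk_sel, 'b h i w j d -> b h i (w j) d')
print(fk_sel.shape)                                       # torch.Size([2, 4, 32, 12, 16])

In the module itself, gates is additionally padded along the sequence dimension by remainder to cover a non-divisible tail; the sketch sidesteps that by choosing a divisible length.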
pyproject.toml

2 additions & 1 deletion

@@ -1,6 +1,6 @@
 [project]
 name = "native-sparse-attention-pytorch"
-version = "0.0.3"
+version = "0.0.4"
 description = "Native Sparse Attention"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -23,6 +23,7 @@ classifiers=[
 ]

 dependencies = [
+    "CoLT5-attention>=0.11.1",
     "einx>=0.3.0",
     "einops>=0.8.0",
     "local-attention>=1.11.1",

tests/test_sparse_attn.py

6 additions & 2 deletions

@@ -1,7 +1,10 @@
 import pytest
 import torch

-def test_sparse_attn():
+@pytest.mark.parametrize('use_diff_topk', (False, True))
+def test_sparse_attn(
+    use_diff_topk
+):
     from native_sparse_attention_pytorch import SparseAttention

     attn = SparseAttention(
@@ -11,7 +14,8 @@ def test_sparse_attn():
         sliding_window_size = 2,
         compress_block_size = 4,
         selection_block_size = 4,
-        num_selected_blocks = 2
+        num_selected_blocks = 2,
+        use_diff_topk = use_diff_topk
     )

     tokens = torch.randn(2, 31, 512)
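For reference, a hypothetical usage sketch of the new flag, modeled on this test. The dim, dim_head, and heads values are assumptions (the hunk does not show them), and the output is assumed to keep the input shape.

import torch
from native_sparse_attention_pytorch import SparseAttention

attn = SparseAttention(
    dim = 512,                 # assumed: matches the (2, 31, 512) test tokens
    dim_head = 64,             # assumed, not shown in the hunk
    heads = 8,                 # assumed, not shown in the hunk
    sliding_window_size = 2,
    compress_block_size = 4,
    selection_block_size = 4,
    num_selected_blocks = 2,
    use_diff_topk = True       # new flag: differentiable top-k block selection
)

tokens = torch.randn(2, 31, 512)
out = attn(tokens)             # expected to match tokens.shape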
