
Commit c00e698

some progress towards non-causal variant

1 parent a6101f5

File tree

3 files changed: +63 -27 lines


native_sparse_attention_pytorch/native_sparse_attention.py

Lines changed: 59 additions & 26 deletions
@@ -41,19 +41,27 @@
 
 # flex attn sliding attention mask
 
-def create_sliding_mask(seq_len, window_size):
+def create_sliding_mask(seq_len, window_size, causal = True):
     def sliding_mask(_, __, q_idx, kv_idx):
-        causal_mask = q_idx >= kv_idx
 
-        sliding_mask = (q_idx - kv_idx) <= window_size
-        causal_mask = causal_mask & sliding_mask
+        distance = q_idx - kv_idx
+        mask = distance <= window_size
 
-        return causal_mask
+        if causal:
+            mask = mask & (q_idx >= kv_idx)
+        else:
+            mask = mask & (distance >= -window_size)
+
+        return mask
 
     block_mask = create_block_mask(sliding_mask, B = None, H = None, Q_LEN = seq_len, KV_LEN = seq_len, _compile = True)
     return block_mask
 
-def create_compress_mask(seq_len, kv_seq_len, compress_block_size, mem_kv_len = 0):
+def create_compress_mask(seq_len, kv_seq_len, compress_block_size, mem_kv_len = 0, causal = True):
+
+    if not causal:
+        return None
+
     # cannot be used as using attention logits for importance score
     # but just to show the immense potential of flex attention
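
The new causal flag above switches the sliding window from one-sided to symmetric. As a minimal reference sketch (plain torch, assumed toy sizes, not the flex-attention block mask itself), the same predicate can be materialized densely:

import torch

def dense_sliding_mask(seq_len, window_size, causal = True):
    # hypothetical helper mirroring the sliding_mask predicate in the diff above
    q_idx = torch.arange(seq_len)[:, None]
    kv_idx = torch.arange(seq_len)[None, :]

    distance = q_idx - kv_idx
    mask = distance <= window_size                 # at most window_size positions behind

    if causal:
        mask = mask & (q_idx >= kv_idx)            # one-sided: never look ahead
    else:
        mask = mask & (distance >= -window_size)   # symmetric: also allow window_size positions ahead

    return mask

# e.g. dense_sliding_mask(6, 2, causal = False) is a band of width 2 on both sides of the diagonal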

@@ -69,7 +77,7 @@ def compress_mask(_, __, q_idx, kv_idx):
     block_mask = create_block_mask(compress_mask, B = None, H = None, Q_LEN = seq_len, KV_LEN = kv_seq_len + mem_kv_len, _compile = True)
     return block_mask
 
-def create_fine_mask(seq_len, fine_block_size):
+def create_fine_mask(seq_len, fine_block_size, causal = True):
 
     def inner(selected_block_indices: Tensor, num_grouped_queries = 1):
         device = selected_block_indices.device
@@ -86,6 +94,9 @@ def fine_mask(b_idx, h_idx, q_idx, kv_idx):
 
             is_selected = one_hot_selected_block_indices[b_idx, kv_head_idx, q_idx, compressed_kv_idx]
 
+            if not causal:
+                return is_selected
+
             causal_mask = q_idx >= kv_idx
             block_diagonal = compressed_q_idx == compressed_kv_idx

@@ -189,6 +200,7 @@ def __init__(
         num_selected_blocks,
         kv_heads = None,
         num_compressed_mem_kv = 1,
+        causal = False,
         norm = True,
         use_diff_topk = False,
         use_triton_kernel = False,
@@ -219,6 +231,10 @@ def __init__(
 
         self.norm = nn.RMSNorm(dim) if norm else nn.Identity()
 
+        # autoregressive or not - will extend this work for long context video / genomics use-cases
+
+        self.causal = causal
+
         # rotary
 
         self.rotary_emb = RotaryEmbedding(dim_head)
@@ -236,7 +252,7 @@ def __init__(
         self.sliding_window = LocalAttention(
             dim = dim_head,
             window_size = sliding_window_size,
-            causal = True,
+            causal = causal,
             exact_windowsize = True,
             autopad = True,
             use_rotary_pos_emb = False
@@ -322,6 +338,8 @@ def forward_inference(
         cache,
         return_cache = True
     ):
+        assert self.causal, 'inference only relevant for autoregressive'
+
         # destruct cache
 
         (
@@ -515,6 +533,8 @@ def forward(
             assert inp.shape[1] == 1, 'input must be single tokens if inferencing with cache key values'
             return self.forward_inference(inp, cache, return_cache = return_cache)
 
+        assert not (self.causal and return_cache)
+
         batch, seq_len, scale, heads, device = *inp.shape[:2], self.scale, self.heads, inp.device
 
        compress_divisible_seq_len = round_down_mult(seq_len, self.compress_block_size)
@@ -560,11 +580,16 @@ def forward(
            ck = cat((mem_ck, ck), dim = -2)
            cv = cat((mem_cv, cv), dim = -2)
 
-        cq_seq = arange(seq_len, device = device)
-        ck_seq = ((arange(num_compress_blocks, device = device) + 1) * self.compress_block_size) - 1
-        ck_seq = F.pad(ck_seq, (num_mem_compress_kv, 0), value = -1)
+        # compressed masking
+
+        cmask = None
 
-        cmask = einx.less('j, i -> i j', ck_seq, cq_seq)
+        if self.causal:
+            cq_seq = arange(seq_len, device = device)
+            ck_seq = ((arange(num_compress_blocks, device = device) + 1) * self.compress_block_size) - 1
+            ck_seq = F.pad(ck_seq, (num_mem_compress_kv, 0), value = -1)
+
+            cmask = einx.less('j, i -> i j', ck_seq, cq_seq)
 
         compressed_attn_out, csim = attend(cq, ck, cv, mask = cmask, return_sim = True)
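
When self.causal is set, the compressed-attention mask built above lets each query see only compressed blocks that end strictly before its own position, while the memory compressed kv (padded to -1) stay visible to every query. A standalone sketch with assumed toy sizes, using plain broadcasting in place of einx.less:

import torch
import torch.nn.functional as F

seq_len, compress_block_size, num_mem_compress_kv = 6, 2, 1
num_compress_blocks = seq_len // compress_block_size

cq_seq = torch.arange(seq_len)                                                 # query positions 0..5
ck_seq = ((torch.arange(num_compress_blocks) + 1) * compress_block_size) - 1   # block end positions: 1, 3, 5
ck_seq = F.pad(ck_seq, (num_mem_compress_kv, 0), value = -1)                   # prepend memory slot: -1, 1, 3, 5

cmask = ck_seq[None, :] < cq_seq[:, None]   # equivalent to einx.less('j, i -> i j', ck_seq, cq_seq)

# row 0 attends only to the memory slot; row 5 also sees the blocks ending at positions 1 and 3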

@@ -657,7 +682,8 @@ def forward(
                 self.selection_block_size,
                 selected_block_indices,
                 fmask,
-                sel_scale = gates
+                sel_scale = gates,
+                include_block_diagonal = self.causal
             )
 
         elif exists(fine_selection_flex_mask):
@@ -685,19 +711,23 @@ def forward(
                 if exists(gates):
                     gates = pad_at_dim(gates, (0, remainder), value = 0, dim = -2)
 
-            # handle block causal diagonal in the diagram, but run experiments without to see
+            if self.causal:
+                # handle block causal diagonal in the diagram, but run experiments without to see
+
+                fine_window_seq = arange(fine_divisible_seq_len, device = device) // self.selection_block_size
+                fine_window_seq = repeat(fine_window_seq, 'n -> b h n 1', b = batch, h = selected_block_indices.shape[1])
+                selected_block_indices = cat((selected_block_indices, fine_window_seq), dim = -1) # for the block causal diagonal in fig2
 
-            fine_window_seq = arange(fine_divisible_seq_len, device = device) // self.selection_block_size
-            fine_window_seq = repeat(fine_window_seq, 'n -> b h n 1', b = batch, h = selected_block_indices.shape[1])
-            selected_block_indices = cat((selected_block_indices, fine_window_seq), dim = -1) # for the block causal diagonal in fig2
+                fmask = repeat(fmask, 'b h i w -> b h i w j', j = self.selection_block_size)
 
-            fmask = repeat(fmask, 'b h i w -> b h i w j', j = self.selection_block_size)
+                causal_mask = torch.ones((self.selection_block_size,) * 2, device = device, dtype = torch.bool).tril()
+                causal_mask = repeat(causal_mask, 'i j -> b h (w i) 1 j', w = num_fine_blocks, b = batch, h = fmask.shape[1])
 
-            causal_mask = torch.ones((self.selection_block_size,) * 2, device = device, dtype = torch.bool).tril()
-            causal_mask = repeat(causal_mask, 'i j -> b h (w i) 1 j', w = num_fine_blocks, b = batch, h = fmask.shape[1])
+                fmask = cat((fmask, causal_mask), dim = -2)
+                fmask = rearrange(fmask, 'b h i w j -> b h 1 i (w j)')
 
-            fmask = cat((fmask, causal_mask), dim = -2)
-            fmask = rearrange(fmask, 'b h i w j -> b h i (w j)')
+            else:
+                fmask = repeat(fmask, 'b h i w -> b h 1 i (w j)', j = self.selection_block_size)
 
             # select out the spatial crops of keys / values for fine attention
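
In the non-causal else branch above, the per-selected-block mask is simply broadcast so every key token inside a selected block shares that block's value, with no appended block-causal diagonal. A small shape sketch under assumed toy sizes:

import torch
from einops import repeat

b, h, i, w, block = 1, 2, 8, 3, 4   # hypothetical batch, kv heads, query length, selected blocks, selection_block_size
fmask = torch.rand(b, h, i, w) > 0.5

expanded = repeat(fmask, 'b h i w -> b h 1 i (w j)', j = block)
assert expanded.shape == (b, h, 1, i, w * block)   # one mask column per key token of each selected block

# the first `block` columns all carry the mask value of selected block 0
assert torch.equal(expanded[0, 0, 0, :, :block], repeat(fmask[0, 0, :, 0], 'i -> i j', j = block))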

@@ -721,7 +751,9 @@ def forward(
             # differential topk gating
 
             if self.use_diff_topk:
-                gates = F.pad(gates, (0, 1), value = 1.)
+                if self.causal:
+                    gates = F.pad(gates, (0, 1), value = 1.)
+
                 fk = einx.multiply('b h i sel, b h i sel j d -> b h i sel j d', gates, fk)
 
             # merge selected key values
@@ -730,8 +762,6 @@ def forward(
 
             # fine attention
 
-            fmask = rearrange(fmask, 'b h ... -> b h 1 ...')
-
             fq = rearrange(fq, 'b (h qh) ... -> b h qh ...', qh = fine_num_grouped_queries)
 
             fsim = einsum(fq, fk, 'b h qh i d, b h i j d -> b h qh i j') * self.scale
@@ -752,7 +782,10 @@ def forward(
             # if only first block, just do a simple block causal
 
             seq_len = fk.shape[-2]
-            fmask = causal_mask = torch.ones((seq_len, seq_len), device = device, dtype = torch.bool).tril()
+            fmask = None
+
+            if self.causal:
+                fmask = causal_mask = torch.ones((seq_len, seq_len), device = device, dtype = torch.bool).tril()
 
         fine_attn_out = attend(fq, fk, fv, mask = fmask)

pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 [project]
 name = "native-sparse-attention-pytorch"
-version = "0.0.73"
+version = "0.0.75"
 description = "Native Sparse Attention"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

tests/test_sparse_attn.py

Lines changed: 3 additions & 0 deletions

@@ -7,6 +7,7 @@
 from native_sparse_attention_pytorch import SparseAttention
 
 @pytest.mark.parametrize('use_diff_topk', (False, True))
+@pytest.mark.parametrize('causal', (False, True))
 @pytest.mark.parametrize('seq_len', (1, 4, 31, 32, 120))
 @pytest.mark.parametrize('kv_heads', (8, 4))
 @pytest.mark.parametrize('selection_block_size', (8, 4, 2))
@@ -15,6 +16,7 @@
 @pytest.mark.parametrize('interpolated_importance_score', (False, True))
 def test_sparse_attn(
     use_diff_topk,
+    causal,
     seq_len,
     kv_heads,
     selection_block_size,
@@ -27,6 +29,7 @@ def test_sparse_attn(
         dim_head = 64,
         heads = 8,
         kv_heads = kv_heads,
+        causal = causal,
         sliding_window_size = 2,
         compress_block_size = 4,
         selection_block_size = selection_block_size,
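
With the flag wired through, a non-causal module can be constructed much like in the test above. A minimal usage sketch; hyperparameters such as dim = 512 and num_selected_blocks = 2 are illustrative, not taken from this diff:

import torch
from native_sparse_attention_pytorch import SparseAttention

attn = SparseAttention(
    dim = 512,
    dim_head = 64,
    heads = 8,
    kv_heads = 4,
    causal = False,              # bidirectional variant introduced in this commit
    sliding_window_size = 2,
    compress_block_size = 4,
    selection_block_size = 4,
    num_selected_blocks = 2
)

tokens = torch.randn(2, 31, 512)
out = attn(tokens)               # expected to match the input shape: (2, 31, 512)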
