Skip to content

Commit d1f5d41

Browse files
committed
Fix an off-by-one error in inference for the block-causal sliding window in fine attention; also fix an edge case where, with fine selection turned off, the regular forward pass did not use block-causal attention.
1 parent 9868f4c commit d1f5d41

File tree

2 files changed

+20
-8
lines changed

2 files changed

+20
-8
lines changed

native_sparse_attention_pytorch/native_sparse_attention.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from copy import deepcopy
44
from math import ceil
5+
from functools import partial
56

67
import torch
78
import torch.nn.functional as F
@@ -483,7 +484,7 @@ def forward_inference(
483484

484485
# block causal diagonal
485486

486-
fine_sliding_window = (seq_len % self.selection_block_size) + 1
487+
fine_sliding_window = ((seq_len - 1) % self.selection_block_size) + 1
487488
fk = k[..., -fine_sliding_window:, :]
488489
fv = v[..., -fine_sliding_window:, :]
489490

@@ -721,6 +722,9 @@ def forward(
721722
num_selected = min(num_selected, importance_scores.shape[-1])
722723
has_selected_kv_for_fine_attn = num_selected > 0
723724

725+
remainder = fine_divisible_seq_len - seq_len
726+
pad_to_multiple = partial(pad_at_dim, pad = (0, remainder), dim = -2)
727+
724728
if has_selected_kv_for_fine_attn:
725729

726730
# get the top-n kv segments for fine attention
@@ -760,10 +764,9 @@ def forward(
760764
fmask = selected_importance_values > 1e-10
761765

762766
if seq_len < fine_divisible_seq_len:
763-
remainder = fine_divisible_seq_len - seq_len
764-
fk = pad_at_dim(fk, (0, remainder), value = 0., dim = -2)
765-
fv = pad_at_dim(fv, (0, remainder), value = 0., dim = -2)
766-
fq = pad_at_dim(fq, (0, remainder), value = 0., dim = -2)
767+
fk = pad_to_multiple(fk)
768+
fv = pad_to_multiple(fv)
769+
fq = pad_to_multiple(fq)
767770

768771
fmask = pad_at_dim(fmask, (0, remainder), value = False, dim = -2)
769772

@@ -845,11 +848,20 @@ def forward(
845848
seq_len = fk.shape[-2]
846849
fmask = None
847850

851+
fk = pad_to_multiple(fk)
852+
fv = pad_to_multiple(fv)
853+
fq = pad_to_multiple(fq)
854+
855+
fq, fk, fv = tuple(rearrange(t, 'b h (w n) d -> (b w) h n d', n = self.selection_block_size) for t in (fq, fk, fv))
856+
848857
if self.causal:
849-
fmask = causal_mask = torch.ones((seq_len, seq_len), device = device, dtype = torch.bool).tril()
858+
fmask = causal_mask = torch.ones((self.selection_block_size, self.selection_block_size), device = device, dtype = torch.bool).tril()
850859

851860
fine_attn_out = attend(fq, fk, fv, mask = fmask)
852861

862+
fine_attn_out = rearrange(fine_attn_out, '(b w) h n d -> b h (w n) d', b = batch)
863+
fine_attn_out = fine_attn_out[..., :seq_len, :]
864+
853865
# 3. overlapping sliding window, this is unsurprising and expected - `s` for sliding
854866

855867
sq = q

tests/test_sparse_attn.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def test_sparse_attn(
5050

5151
assert tokens.shape == attended.shape
5252

53-
@pytest.mark.parametrize('seq_len', (8,))
53+
@pytest.mark.parametrize('seq_len', (2, 8, 16))
5454
def test_inference(seq_len):
5555

5656
attn = SparseAttention(
@@ -61,7 +61,7 @@ def test_inference(seq_len):
6161
sliding_window_size = 2,
6262
compress_block_size = 5,
6363
selection_block_size = 10,
64-
num_selected_blocks = 2
64+
num_selected_blocks = 0
6565
)
6666

6767
tokens = torch.randn(2, seq_len, 512)

0 commit comments

Comments
 (0)