
Commit 26a1321

address overlapping compress blocks during inference issue
Parent: 01dacb9

3 files changed (+27 lines, −11 lines)

native_sparse_attention_pytorch/native_sparse_attention.py

Lines changed: 21 additions & 8 deletions
@@ -290,6 +290,9 @@ def __init__(
         self.split_compress_window = split_compress_window_fn
         self.compress_window_size = compress_window_size
 
+        assert compress_block_overlap_len < compress_block_size
+        self.compress_block_overlap_len = compress_block_overlap_len
+
         # compression attention related parameters
 
         self.num_mem_compress_kv = num_compressed_mem_kv
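The constructor change only stores the new hyperparameter and guards it with an assertion: the overlap carried between compress windows must be strictly smaller than a full compress block. A standalone sketch of that invariant, with hypothetical values:

    compress_block_size = 5        # hypothetical values
    compress_block_overlap_len = 2

    # mirrors the assertion added above: the carried-over overlap must be
    # strictly smaller than one full compress block
    assert compress_block_overlap_len < compress_block_size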
@@ -382,6 +385,7 @@ def forward_inference(
 
         sliding_window = self.sliding_window_size
         compress_divisible_seq_len = round_down_mult(seq_len, self.compress_block_size)
+        compress_overlap_len = self.compress_block_overlap_len
 
         fine_divisible_seq_len = round_up_mult(seq_len, self.selection_block_size)
         num_fine_blocks = fine_divisible_seq_len // self.selection_block_size
@@ -439,19 +443,20 @@
 
         running_compress_seq_len = run_k.shape[-2]
 
-        if divisible_by(running_compress_seq_len, self.compress_block_size):
-
-            k_compress_input = self.split_compress_window(run_k)
-            v_compress_input = self.split_compress_window(run_v)
+        if divisible_by(running_compress_seq_len, self.compress_block_size + compress_overlap_len):
+            k_compress_input = rearrange(run_k, 'b h n d -> b h 1 n d')
+            v_compress_input = rearrange(run_v, 'b h n d -> b h 1 n d')
 
             k_compress_input = einx.add('b h w n d, h n d', k_compress_input, self.k_intrablock_positions)
             v_compress_input = einx.add('b h w n d, h n d', v_compress_input, self.v_intrablock_positions)
 
             next_ck = self.k_compress(k_compress_input)
             next_cv = self.v_compress(v_compress_input)
 
-            run_k = run_k[..., 0:0, :]
-            run_v = run_v[..., 0:0, :]
+            run_kv_slice = slice(-compress_overlap_len, None) if compress_overlap_len > 0 else slice(0, 0)
+
+            run_k = run_k[..., run_kv_slice, :]
+            run_v = run_v[..., run_kv_slice, :]
 
             ck = cat((ck, next_ck), dim = -2)
             cv = cat((cv, next_cv), dim = -2)
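This hunk is the heart of the commit. During decoding, keys and values accumulate in a running buffer (run_k / run_v); a compressed block is now emitted once the buffer holds compress_block_size + compress_overlap_len tokens, and instead of clearing the buffer outright, its trailing compress_overlap_len entries are carried over so consecutive compress windows overlap. A toy simulation of that bookkeeping (a sketch, not part of the commit), assuming compress_block_size = 5 and an overlap of 2, with plain Python lists standing in for the key/value tensors:

    compress_block_size = 5
    overlap = 2

    buffer = []   # stands in for run_k / run_v along the sequence dim
    windows = []  # each entry is one window fed to k_compress / v_compress

    for t in range(20):  # one decoded token per step
        buffer.append(t)

        if len(buffer) % (compress_block_size + overlap) == 0:
            windows.append(list(buffer))                       # emit a full window
            buffer = buffer[-overlap:] if overlap > 0 else []  # carry the overlap over

    print(windows[0])  # [0, 1, 2, 3, 4, 5, 6]
    print(windows[1])  # [5, 6, 7, 8, 9, 10, 11]

Each emitted window shares its first two entries with the tail of the previous one, which is exactly what the run_kv_slice carry-over above achieves; with overlap = 0 the slice(0, 0) branch reproduces the old behavior of emptying the buffer.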
@@ -593,6 +598,8 @@ def forward(
         compress_divisible_seq_len = round_down_mult(seq_len, self.compress_block_size)
         num_compress_blocks = compress_divisible_seq_len // self.compress_block_size
 
+        compress_overlap_len = self.compress_block_overlap_len
+
         fine_divisible_seq_len = round_up_mult(seq_len, self.selection_block_size)
         num_fine_blocks = fine_divisible_seq_len // self.selection_block_size
 
@@ -622,8 +629,14 @@ def forward(
         k_compress_input = einx.add('b h w n d, h n d', k_compress_input, self.k_intrablock_positions)
         v_compress_input = einx.add('b h w n d, h n d', v_compress_input, self.v_intrablock_positions)
 
-        run_k = k[..., compress_divisible_seq_len:, :]
-        run_v = v[..., compress_divisible_seq_len:, :]
+        run_k, run_v = k, v
+
+        if return_cache and compress_overlap_len > 0:
+            run_k = F.pad(run_k, (0, 0, compress_overlap_len, 0), value = 0.)
+            run_v = F.pad(run_v, (0, 0, compress_overlap_len, 0), value = 0.)
+
+        run_k = run_k[..., compress_divisible_seq_len:, :]
+        run_v = run_v[..., compress_divisible_seq_len:, :]
 
         cq = q
         ck = self.k_compress(k_compress_input) # Equation (7) of the Native Sparse Attention paper
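On the prefill side, the tokens past the last full compress block become the running buffer handed to inference. When a cache is returned and an overlap is configured, the sequence is front-padded before slicing: F.pad reads its tuple from the last dim backwards, so (0, 0, overlap, 0) leaves the feature dim alone and shifts the sequence dim forward by overlap zeros. The unchanged slice at compress_divisible_seq_len then lands overlap positions earlier in the original sequence, so the buffer starts with the tail of the last compressed block (or zeros when the sequence is shorter than the overlap), mirroring the inference-side carry-over. A toy demonstration of the pad-then-slice trick, with hypothetical shapes:

    import torch
    import torch.nn.functional as F

    overlap = 2
    seq_len = 12
    compress_block_size = 5
    compress_divisible_seq_len = (seq_len // compress_block_size) * compress_block_size  # 10

    # (batch, heads, seq, dim_head); values equal their sequence position
    k = torch.arange(seq_len).view(1, 1, seq_len, 1).float()

    run_k = F.pad(k, (0, 0, overlap, 0), value = 0.)   # front-pad the seq dim
    run_k = run_k[..., compress_divisible_seq_len:, :] # same slice offset as before

    print(run_k.flatten().tolist())  # [8.0, 9.0, 10.0, 11.0]: 2 overlap tokens + 2 leftover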

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "native-sparse-attention-pytorch"
-version = "0.1.22"
+version = "0.1.23"
 description = "Native Sparse Attention"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

tests/test_sparse_attn.py

Lines changed: 5 additions & 2 deletions
@@ -52,9 +52,11 @@ def test_sparse_attn(
 
 @pytest.mark.parametrize('seq_len', (2, 8, 16))
 @pytest.mark.parametrize('num_selected_blocks', (0, 2))
+@pytest.mark.parametrize('compress_block_overlap_len', (0, 2))
 def test_inference(
     seq_len,
-    num_selected_blocks
+    num_selected_blocks,
+    compress_block_overlap_len
 ):
 
     attn = SparseAttention(
@@ -65,7 +67,8 @@ def test_inference(
         sliding_window_size = 2,
         compress_block_size = 5,
         selection_block_size = 10,
-        num_selected_blocks = num_selected_blocks
+        num_selected_blocks = num_selected_blocks,
+        compress_block_overlap_len = compress_block_overlap_len
     )
 
     tokens = torch.randn(2, seq_len, 512)
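For context, a construction sketch mirroring the updated test. The dim, dim_head and heads values are assumptions, since those kwargs fall outside the shown hunk; the remaining keyword arguments are taken verbatim from the test:

    import torch
    from native_sparse_attention_pytorch import SparseAttention

    attn = SparseAttention(
        dim = 512,      # matches the test's 512-dim tokens
        dim_head = 64,  # assumed, not visible in this hunk
        heads = 8,      # assumed, not visible in this hunk
        sliding_window_size = 2,
        compress_block_size = 5,
        selection_block_size = 10,
        num_selected_blocks = 2,
        compress_block_overlap_len = 2
    )

    tokens = torch.randn(2, 16, 512)
    out = attn(tokens)  # (2, 16, 512)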
