
Commit ab3416c

make inference line up with selection block size > compress block size
1 parent 1798a2c commit ab3416c

3 files changed: +22 -24 lines changed


native_sparse_attention_pytorch/native_sparse_attention.py

Lines changed: 17 additions & 21 deletions
@@ -461,16 +461,12 @@ def forward_inference(
         importance_scores = csim[..., self.num_mem_compress_kv:]

         num_compress_blocks = importance_scores.shape[-1]
+        num_compress_per_fine = self.selection_block_size // self.compress_block_size

         if self.compress_block_size != self.selection_block_size:
-            compress_seq_len = num_compress_blocks * self.compress_block_size
-
-            importance_scores = repeat(importance_scores, '... j -> ... (bsz j)', bsz = self.compress_block_size)
-
-            fine_seq_len = round_down_mult(compress_seq_len, self.selection_block_size)
-
-            importance_scores = importance_scores[..., :fine_seq_len]
-            importance_scores = reduce(importance_scores, '... (bsz j) -> ... j', 'mean', bsz = self.selection_block_size)
+            compress_seq_len = round_down_mult(num_compress_blocks, num_compress_per_fine)
+            importance_scores = importance_scores[..., :compress_seq_len]
+            importance_scores = reduce(importance_scores, '... (j num_compress_per_fine) -> ... j', 'mean', num_compress_per_fine = num_compress_per_fine)

         num_fine_blocks = importance_scores.shape[-1]
         num_selected = min(self.num_selected_blocks, num_fine_blocks)
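For orientation, a minimal sketch (not the library's code) of what the rewritten inference branch computes: importance scores arrive as one value per compress block, trailing compress blocks that do not fill a whole selection block are dropped, and the rest are mean-pooled into one score per selection block. The block sizes and tensor shape below are assumptions chosen to mirror the tests (compress 5, selection 10); round_down_mult here stands in for the repo helper of the same name.

import torch
from einops import reduce

def round_down_mult(n, mult):
    # stand-in for the repo helper: floor n to a multiple of mult
    return n // mult * mult

compress_block_size, selection_block_size = 5, 10
num_compress_per_fine = selection_block_size // compress_block_size     # 2 compress blocks per selection block

importance_scores = torch.randn(1, 8, 7)                                 # (batch, heads, num compress blocks) - 7 is an arbitrary example

num_compress_blocks = importance_scores.shape[-1]
keep_len = round_down_mult(num_compress_blocks, num_compress_per_fine)   # 6 - drop the trailing partial selection block

importance_scores = importance_scores[..., :keep_len]
fine_scores = reduce(importance_scores, '... (j n) -> ... j', 'mean', n = num_compress_per_fine)

print(fine_scores.shape)                                                  # torch.Size([1, 8, 3]) - one score per selection block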
@@ -490,7 +486,9 @@ def forward_inference(
         if self.query_heads_share_selected_kv:
             importance_scores = reduce(importance_scores, 'b (h grouped_queries) ... -> b h ...', 'mean', grouped_queries = self.num_grouped_queries)

+        importance_scores = F.pad(importance_scores, (1, 0), value = -1e3)
         importance_scores = importance_scores.softmax(dim = -1)
+        importance_scores = importance_scores[..., 1:]

         sel_scores, sel_indices = importance_scores.topk(num_selected, dim = -1)

@@ -689,26 +687,24 @@ def forward(

         if self.compress_block_size != self.selection_block_size:

-            compress_seq_len = num_compress_blocks * self.compress_block_size
-
-            importance_scores = repeat(importance_scores, '... j -> ... (j block_size)', block_size = self.compress_block_size)
+            num_compress_per_fine = self.selection_block_size // self.compress_block_size

-            padding = fine_divisible_seq_len - compress_seq_len
+            round_down_score_len = round_down_mult(importance_scores.shape[-1], num_compress_per_fine)
+            importance_scores = importance_scores[..., :round_down_score_len]

-            fine_query_seq_len = importance_scores.shape[-2]
-            fine_query_padding = fine_divisible_seq_len - importance_scores.shape[-2]
+            if not is_empty(importance_scores):
+                importance_scores = reduce(importance_scores, '... (j num_compress_per_fine) -> ... j', 'mean', num_compress_per_fine = num_compress_per_fine)

-            importance_scores = F.pad(importance_scores, (0, padding))
+            i, j = importance_scores.shape[-2:]

-            # mask out the diagonal since block causal is included by default for fine attending
+            # mask out block diagonal

-            block_causal_mask = torch.ones((num_fine_blocks,) * 2, device = device, dtype = torch.bool).tril(-1)
-            block_causal_mask = repeat(block_causal_mask, 'i j -> (i n1) (j n2)', n1 = self.selection_block_size, n2 = self.selection_block_size)
-            block_causal_mask = block_causal_mask[:fine_query_seq_len]
+            q_seq = arange(i, device = device) // self.selection_block_size
+            k_seq = arange(j, device = device)

-            importance_scores = importance_scores.masked_fill(~block_causal_mask, max_neg_value(csim))
+            block_diagonal_mask = einx.equal('i, j -> i j', q_seq, k_seq)

-            importance_scores = reduce(importance_scores, '... (j block_size) -> ... j', 'mean', block_size = self.selection_block_size)
+            importance_scores = importance_scores.masked_fill(block_diagonal_mask, max_neg_value(csim))

         importance_scores = F.pad(importance_scores, (1, 0), value = -1e3)
         importance_scores = importance_scores.softmax(dim = -1)
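A standalone sketch (assumed shapes, not the repo's code) of the replacement block-diagonal mask: each fine-attention query row is mapped to its selection block by integer division, compared against the per-block score index, and matching positions are masked, since the diagonal block is attended by default. Plain broadcasting is used below in place of the einx.equal call in the diff; the two are equivalent.

import torch

selection_block_size = 4
i, j = 10, 3                                        # assumed: 10 query positions, 3 selection-block scores

q_seq = torch.arange(i) // selection_block_size     # selection block id of each query position
k_seq = torch.arange(j)                             # index of each selection block

block_diagonal_mask = q_seq[:, None] == k_seq[None, :]   # same result as einx.equal('i, j -> i j', q_seq, k_seq)

print(block_diagonal_mask.int())
# rows 0-3 mask block 0, rows 4-7 mask block 1, rows 8-9 mask block 2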

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "native-sparse-attention-pytorch"
-version = "0.1.23"
+version = "0.1.24"
 description = "Native Sparse Attention"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

tests/test_sparse_attn.py

Lines changed: 4 additions & 2 deletions
@@ -50,10 +50,12 @@ def test_sparse_attn(
 @pytest.mark.parametrize('seq_len', (2, 8, 16))
 @pytest.mark.parametrize('num_selected_blocks', (0, 2))
 @pytest.mark.parametrize('compress_block_overlap_len', (0, 2))
+@pytest.mark.parametrize('selection_block_size', (5, 10, 15))
 def test_inference(
     seq_len,
     num_selected_blocks,
-    compress_block_overlap_len
+    compress_block_overlap_len,
+    selection_block_size
 ):

     attn = SparseAttention(
@@ -63,7 +65,7 @@ def test_inference(
         causal = True,
         sliding_window_size = 2,
         compress_block_size = 5,
-        selection_block_size = 10,
+        selection_block_size = selection_block_size,
         num_selected_blocks = num_selected_blocks,
         compress_block_overlap_len = compress_block_overlap_len
     )
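A quick check (mine, not part of the test file) of what the new parametrization covers: with compress_block_size fixed at 5, the three selection_block_size values exercise ratios of one, two and three compress blocks per selection block, including the equal-size case the previous hard-coded value of 10 never hit.

compress_block_size = 5

for selection_block_size in (5, 10, 15):
    print(selection_block_size, selection_block_size // compress_block_size)

# 5 1
# 10 2
# 15 3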
