
Commit 36837b2

when doing interpolation of importance score, remask to 0 for illegal positions
1 parent: 463963b · commit: 36837b2

File tree: 4 files changed, +15 −8 lines


native_sparse_attention_pytorch/native_sparse_attention.py

Lines changed: 8 additions & 4 deletions

```diff
@@ -370,6 +370,9 @@ def forward(
 
         importance_scores = cattn[..., num_mem_compress_kv:]
 
+        num_selected = min(self.num_selected_blocks, importance_scores.shape[-1])
+        has_selected_kv_for_fine_attn = num_selected > 0
+
         # maybe average the compressed attention across each grouped queries (per key / values)
 
         if self.query_heads_share_selected_kv:
@@ -383,13 +386,16 @@ def forward(
             # cannot parse their equation, so will just improvise
             # first we expand all the compressed scores to the full sequence length, then average within each fine / selection block size - pad on the right to 0s, which should be fine as sliding window convers the local anyways
 
-        if self.compress_block_size != self.selection_block_size:
+        if has_selected_kv_for_fine_attn and self.compress_block_size != self.selection_block_size:
 
             score_len = importance_scores.shape[-1]
             compress_seq_len = score_len * self.compress_block_size
 
             if self.interpolated_importance_score:
+                mask = importance_scores > 1e-10
+                mask = repeat(mask, '... j -> ... (j block_size)', block_size = self.compress_block_size)
                 importance_scores = interpolate_1d(importance_scores, compress_seq_len)
+                importance_scores = importance_scores.masked_fill(~mask, 0.)
             else:
                 importance_scores = repeat(importance_scores, '... j -> ... (j block_size)', block_size = self.compress_block_size)
 
@@ -400,13 +406,11 @@ def forward(
 
         # handle if number of total blocks is less than number to select for fine attention
 
-        num_selected = min(self.num_selected_blocks, importance_scores.shape[-1])
-
         fq = rotated_q
         fk = rotated_k
         fv = v
 
-        if num_selected > 0:
+        if has_selected_kv_for_fine_attn:
             selected_importance_values, selected_block_indices = importance_scores.topk(num_selected, dim = -1)
 
             if self.use_diff_topk:
```
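For readers who want to see the remasking behaviour in isolation, here is a minimal, self-contained sketch of the idea behind the change above: per-block importance scores are interpolated up to the full sequence length, and every position whose source compressed block carried (effectively) zero attention mass is forced back to 0, so interpolation cannot bleed importance into illegal positions that would then win the later top-k selection. The helper name `interpolate_importance_scores` is made up for illustration, and `torch.nn.functional.interpolate` stands in for the repo's `interpolate_1d`; this is a sketch of the technique, not the library's actual implementation.

```python
import torch
import torch.nn.functional as F
from einops import repeat

def interpolate_importance_scores(importance_scores, compress_block_size):
    # importance_scores: (..., num_compressed_blocks) attention mass per compressed block
    *lead, num_blocks = importance_scores.shape
    seq_len = num_blocks * compress_block_size

    # positions are "legal" only if their originating compressed block carries real mass
    legal = importance_scores > 1e-10
    legal = repeat(legal, '... j -> ... (j block)', block = compress_block_size)

    # linearly interpolate the per-block scores up to the uncompressed sequence length
    # (F.interpolate stands in for the repo's interpolate_1d - an assumption)
    scores = F.interpolate(
        importance_scores.reshape(-1, 1, num_blocks),
        size = seq_len,
        mode = 'linear',
        align_corners = False
    ).reshape(*lead, seq_len)

    # remask illegal positions to 0 so they can never be selected by the fine-attention top-k
    return scores.masked_fill(~legal, 0.)

# quick check: the zeroed (illegal) block stays exactly zero after interpolation
scores = torch.tensor([[0.6, 0.4, 0.0]])   # last compressed block carries no mass
out = interpolate_importance_scores(scores, compress_block_size = 4)
assert (out[:, -4:] == 0.).all()
```

Without the final `masked_fill`, linear interpolation would smear nonzero values from the second block into the positions of the third, empty block; the remask restores exact zeros there.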

pyproject.toml

Lines changed: 1 addition & 1 deletion

```diff
@@ -1,6 +1,6 @@
 [project]
 name = "native-sparse-attention-pytorch"
-version = "0.0.43"
+version = "0.0.44"
 description = "Native Sparse Attention"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
```

tests/test_sparse_attn.py

Lines changed: 5 additions & 2 deletions

```diff
@@ -12,13 +12,15 @@
 @pytest.mark.parametrize('selection_block_size', (8, 4, 2))
 @pytest.mark.parametrize('num_selected_block', (0, 2))
 @pytest.mark.parametrize('query_heads_share_selected_kv', (False, True))
+@pytest.mark.parametrize('interpolated_importance_score', (False, True))
 def test_sparse_attn(
     use_diff_topk,
     seq_len,
     kv_heads,
     selection_block_size,
     num_selected_block,
-    query_heads_share_selected_kv
+    query_heads_share_selected_kv,
+    interpolated_importance_score
 ):
     attn = SparseAttention(
         dim = 512,
@@ -30,7 +32,8 @@ def test_sparse_attn(
         selection_block_size = selection_block_size,
         num_selected_blocks = num_selected_block,
         use_diff_topk = use_diff_topk,
-        query_heads_share_selected_kv = query_heads_share_selected_kv
+        query_heads_share_selected_kv = query_heads_share_selected_kv,
+        interpolated_importance_score = interpolated_importance_score
     )
 
     tokens = torch.randn(2, seq_len, 512)
```
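To exercise the new flag outside of pytest, a usage sketch along the lines of the test is below. Only `dim = 512`, `selection_block_size`, `num_selected_blocks`, `use_diff_topk`, `query_heads_share_selected_kv` and `interpolated_importance_score` come from the diff; the remaining constructor arguments and the sequence length are illustrative assumptions, not values from the test file.

```python
import torch
from native_sparse_attention_pytorch import SparseAttention

# mirrors one point of the test grid; interpolated_importance_score = True
# exercises the remasked interpolation path touched by this commit
attn = SparseAttention(
    dim = 512,                              # from the test
    dim_head = 64,                          # assumption
    heads = 8,                              # assumption
    sliding_window_size = 2,                # assumption
    compress_block_size = 4,                # assumption
    selection_block_size = 8,               # from the test grid
    num_selected_blocks = 2,                # from the test grid
    use_diff_topk = True,                   # from the test grid
    query_heads_share_selected_kv = True,   # from the test grid
    interpolated_importance_score = True    # new test parameter
)

tokens = torch.randn(2, 31, 512)            # (batch, seq_len, dim); seq_len is an assumption
out = attn(tokens)
assert out.shape == tokens.shape
```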

train.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -43,7 +43,7 @@
 FINE_BLOCK_SIZE = 32
 NUM_FINE_SELECTED = 0
 
-INTERPOLATED_IMPORTANCE_SCORE = True
+INTERPOLATED_IMPORTANCE_SCORE = False
 USE_DIFF_TOPK = True
 
 # experiment related
```
