Commit aa359ea (1 parent: 9a038d1)

allow for different compress to fine block sizes during inference, and also make sure interpolated scores work

File tree: 4 files changed (+21, -7 lines)

native_sparse_attention_pytorch/native_sparse_attention.py
Lines changed: 17 additions & 5 deletions

@@ -171,7 +171,7 @@ def attend(
     mask_value = max_neg_value(sim)
 
     if exists(mask):
-        sim = sim.masked_fill(~mask, mask_value)
+        sim = sim.masked_fill(~mask, mask_value // 10)
 
     attn = sim.softmax(dim = -1)
 
@@ -425,13 +425,25 @@ def forward_inference(
 
         # 2. fine attention inference (todo - compress and fine diff block sizes)
 
-        assert self.compress_block_size == self.selection_block_size
-
         importance_scores = csim[..., self.num_mem_compress_kv:]
-        importance_scores += torch.randn_like(importance_scores) * 100
 
         num_compress_blocks = importance_scores.shape[-1]
-        num_selected = min(self.num_selected_blocks, num_compress_blocks)
+
+        if self.compress_block_size != self.selection_block_size:
+            compress_seq_len = num_compress_blocks * self.compress_block_size
+
+            if self.interpolated_importance_score:
+                importance_scores = interpolate_1d(importance_scores, compress_seq_len)
+            else:
+                importance_scores = repeat(importance_scores, '... j -> ... (bsz j)', bsz = self.compress_block_size)
+
+            fine_seq_len = round_down_mult(compress_seq_len, self.selection_block_size)
+
+            importance_scores = importance_scores[..., :fine_seq_len]
+            importance_scores = reduce(importance_scores, '... (bsz j) -> ... j', 'mean', bsz = self.selection_block_size)
+
+        num_fine_blocks = importance_scores.shape[-1]
+        num_selected = min(self.num_selected_blocks, num_fine_blocks)
         has_selected_kv_for_fine_attn = num_selected > 0
 
         # block causal diagonal
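
The core of this change is remapping compression-level importance scores onto fine selection blocks when the two block sizes differ: the scores are upsampled to token resolution (linear interpolation when interpolated scores are enabled, otherwise repeating each compress block's score), truncated to a multiple of the selection block size, then mean-pooled per selection block. Below is a minimal standalone sketch of that idea, not the repo's exact code: the helper name `remap_importance_scores` is hypothetical, `F.interpolate` stands in for the repo's `interpolate_1d` helper, and the einops patterns use a consecutive-token `(j b)` ordering convention for illustration.

```python
import torch
import torch.nn.functional as F
from einops import repeat, reduce

def remap_importance_scores(
    scores,                  # (batch, heads, num_compress_blocks)
    compress_block_size,
    selection_block_size,
    interpolated = False
):
    num_compress_blocks = scores.shape[-1]
    compress_seq_len = num_compress_blocks * compress_block_size

    if interpolated:
        # smooth upsample to per-token resolution (stand-in for the repo's interpolate_1d)
        scores = F.interpolate(scores, size = compress_seq_len, mode = 'linear', align_corners = False)
    else:
        # nearest-neighbor: every token inherits the score of its compress block
        scores = repeat(scores, '... j -> ... (j b)', b = compress_block_size)

    # truncate to a multiple of the selection block size, then mean-pool per selection block
    fine_seq_len = compress_seq_len // selection_block_size * selection_block_size
    scores = scores[..., :fine_seq_len]
    return reduce(scores, '... (j b) -> ... j', 'mean', b = selection_block_size)

scores = torch.randn(1, 8, 6)                  # 6 compress blocks of size 8 -> 48 tokens
fine = remap_importance_scores(scores, 8, 16)  # -> 3 selection blocks of size 16
print(fine.shape)                              # torch.Size([1, 8, 3])
```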

native_sparse_attention_pytorch/triton_native_sparse_attention.py
Lines changed: 2 additions & 0 deletions

@@ -1707,6 +1707,8 @@ def native_sparse_attend(
     assert divisible_by(q_heads, kv_heads)
     assert sel_heads in (q_heads, kv_heads)
 
+    assert block_size >= 16, 'fine selection block size must be 16 or greater for now'
+
     # query heads within each group to attend to different segments
 
     if kv_heads != sel_heads:

pyproject.toml
Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 [project]
 name = "native-sparse-attention-pytorch"
-version = "0.1.0"
+version = "0.1.1"
 description = "Native Sparse Attention"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

train.py
Lines changed: 1 addition & 1 deletion

@@ -39,7 +39,7 @@
 # sparse attention related
 
 SLIDING_WINDOW_SIZE = 64
-COMPRESS_BLOCK_SIZE = 16
+COMPRESS_BLOCK_SIZE = 8
 
 FINE_BLOCK_SIZE = 16
 NUM_FINE_SELECTED = 4
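
The training script now exercises the new capability directly: compression runs at block size 8 while fine selection stays at 16 (the Triton kernel above still requires a fine selection block size of at least 16). A hedged usage sketch follows; the constructor keywords shown (dim, dim_head, heads, sliding_window_size, compress_block_size, selection_block_size, num_selected_blocks) are assumed to match the package's SparseAttention module, so verify them against the repo README.

```python
import torch
from native_sparse_attention_pytorch import SparseAttention

# assumed constructor keywords - verify against the repo README
attn = SparseAttention(
    dim = 512,
    dim_head = 64,
    heads = 8,
    sliding_window_size = 64,
    compress_block_size = 8,     # coarse compression granularity, as in train.py above
    selection_block_size = 16,   # fine selection granularity, now allowed to differ
    num_selected_blocks = 4,
    # interpolated_importance_score = True,  # per this commit; verify the exact init kwarg name
)

tokens = torch.randn(1, 1024, 512)
out = attn(tokens)               # expected shape: (1, 1024, 512)
```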
