
Commit 56586dd

release new compression block hparams

1 parent 832050a

4 files changed: +11 −11 lines

README.md

Lines changed: 2 additions & 0 deletions

````diff
@@ -16,6 +16,8 @@ This will be my last open sourced project under Meta
 
 - <a href="https://github.com/Pasewark">Eric Pasewark</a> for submitting a simple transformer based compression network
 
+- <a href="https://github.com/Mr-Grin">@Mr-Grin</a> for a pull request that fixes compression block hyperparameters
+
 ## Install
 
 ```bash
````

native_sparse_attention_pytorch/native_sparse_attention.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -439,7 +439,7 @@ def forward_inference(
         if return_cache:
             cache_compressed_kv = ((ck, cv), (run_k, run_v))
 
-        # 2. fine attention inference (todo - compress and fine diff block sizes)
+        # 2. fine attention inference
 
         importance_scores = csim[..., self.num_mem_compress_kv:]
 
@@ -628,7 +628,7 @@ def forward(
         # compressed masking
 
         cmask = None
-        # TODO
+
         if self.causal:
             cq_seq = arange(seq_len, device = device)
             ck_seq = ((arange(num_compress_blocks, device = device) + 1) * self.compress_block_sliding_stride) - 1
```
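The two index sequences in the second hunk enumerate query positions and the last source position covered by each compressed block. Below is a minimal standalone sketch of how a causal mask over compressed blocks could be derived from such sequences; the `>=` comparison, the concrete `seq_len` and stride values, and the use of plain tensors instead of the module's internals are assumptions for illustration, not the library's exact masking code.

```python
from torch import arange

# Hypothetical standalone illustration, not the library's code: with a sliding
# compression stride S, the i-th compressed block covers source positions up to
# (i + 1) * S - 1, so a causal query at position q may only attend to compressed
# blocks that end at or before q.
seq_len = 8
compress_block_sliding_stride = 4     # assumed value for illustration
num_compress_blocks = seq_len // compress_block_sliding_stride
device = 'cpu'

cq_seq = arange(seq_len, device = device)
ck_seq = ((arange(num_compress_blocks, device = device) + 1) * compress_block_sliding_stride) - 1

# cmask[q, k] is True where query position q is allowed to attend to compressed block k
cmask = cq_seq[:, None] >= ck_seq[None, :]
print(cmask)
```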

pyproject.toml

Lines changed: 1 addition & 1 deletion

```diff
@@ -1,6 +1,6 @@
 [project]
 name = "native-sparse-attention-pytorch"
-version = "0.1.27"
+version = "0.2.0"
 description = "Native Sparse Attention"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
```

tests/test_sparse_attn.py

Lines changed: 6 additions & 8 deletions

```diff
@@ -6,8 +6,6 @@
 
 from native_sparse_attention_pytorch import SparseAttention
 
-device = 'cpu'
-
 @pytest.mark.parametrize('use_diff_topk', (False, True))
 @pytest.mark.parametrize('causal', (False, True))
 @pytest.mark.parametrize('seq_len', (1, 4, 31, 32, 120))
@@ -41,9 +39,9 @@ def test_sparse_attn(
         num_selected_blocks = num_selected_block,
         use_diff_topk = use_diff_topk,
         query_heads_share_selected_kv = query_heads_share_selected_kv,
-    ).to(device)
+    )
 
-    tokens = torch.randn(2, seq_len, 512).to(device)
+    tokens = torch.randn(2, seq_len, 512)
 
     attended = attn(tokens)
 
@@ -70,9 +68,9 @@ def test_inference(
         selection_block_size = selection_block_size,
         num_selected_blocks = num_selected_blocks,
         compress_block_sliding_stride = compress_block_sliding_stride
-    ).to(device)
+    )
 
-    tokens = torch.randn(2, seq_len, 512).to(device)
+    tokens = torch.randn(2, seq_len, 512)
 
     parallel_out = attn(tokens)
 
@@ -106,9 +104,9 @@ def test_transformer_inference(
             selection_block_size = selection_block_size,
             num_selected_blocks = 2
         )
-    ).to(device)
+    )
 
-    prompt = torch.randint(0, 256, (1, 1)).to(device)
+    prompt = torch.randint(0, 256, (1, 1))
 
     sampled = model.sample(prompt, 128, temperature = 0., use_cache_kv = False)
     sampled_cached = model.sample(prompt, 128, temperature = 0., use_cache_kv = True)
```
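For reference, the updated tests construct `SparseAttention` and run it on the default device with no explicit `.to(device)` calls. Below is a minimal usage sketch modeled on those tests; the constructor arguments not visible in this diff (`dim`, `dim_head`, `heads`, `sliding_window_size`, `compress_block_size`) and all concrete values are assumptions drawn from the project README, not part of this commit.

```python
import torch
from native_sparse_attention_pytorch import SparseAttention

# assumed hyperparameter values - only selection_block_size, num_selected_blocks,
# and compress_block_sliding_stride appear in this commit's diff
attn = SparseAttention(
    dim = 512,                            # matches the (2, seq_len, 512) tokens used in the tests
    dim_head = 64,
    heads = 8,
    sliding_window_size = 2,
    compress_block_size = 4,
    compress_block_sliding_stride = 4,    # one of the new compression block hparams
    selection_block_size = 4,
    num_selected_blocks = 2
)

tokens = torch.randn(2, 32, 512)          # no .to(device), as in the updated tests

attended = attn(tokens)
assert attended.shape == tokens.shape
```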
