
Commit 0ad8c5e

allow for ablating fine block selection for negative control
1 parent: e9476ec

3 files changed (+9, -3 lines)


native_sparse_attention_pytorch/native_sparse_attention.py (5 additions, 1 deletion)

@@ -277,7 +277,11 @@ def __init__(
 
         self.selection_block_size = selection_block_size
 
-        assert num_selected_blocks > 0
+        assert num_selected_blocks >= 0
+
+        if num_selected_blocks == 0:
+            print(f'`num_selected_blocks` should be set greater than 0, unless if you are ablating it for experimental purposes')
+
         self.num_selected_blocks = num_selected_blocks
 
         # they combine the three sparse branches through a learned combine with sigmoid activation
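For context, a minimal sketch of how this new ablation path might be exercised: constructing SparseAttention with num_selected_blocks = 0 disables fine block selection as a negative control and triggers the message added above. Only the sparse-branch arguments also seen in the test file below are confirmed by this diff; `dim`, `dim_head`, `heads`, the forward call, and the input shape are assumptions for illustration.

import torch
from native_sparse_attention_pytorch.native_sparse_attention import SparseAttention

# assumed model hyperparameters (not part of this diff)
attn = SparseAttention(
    dim = 512,
    dim_head = 64,
    heads = 8,
    sliding_window_size = 2,
    compress_block_size = 4,
    selection_block_size = 4,
    num_selected_blocks = 0,   # ablate fine block selection; prints the notice added in this commit
    use_diff_topk = False
)

# assumed forward interface: a (batch, seq_len, dim) token tensor
tokens = torch.randn(1, 128, 512)
out = attn(tokens)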

pyproject.toml (1 addition, 1 deletion)

@@ -1,6 +1,6 @@
 [project]
 name = "native-sparse-attention-pytorch"
-version = "0.0.41"
+version = "0.0.42"
 description = "Native Sparse Attention"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

tests/test_sparse_attn.py (3 additions, 1 deletion)

@@ -10,12 +10,14 @@
 @pytest.mark.parametrize('seq_len', (1, 4, 31, 32, 120))
 @pytest.mark.parametrize('kv_heads', (8, 4))
 @pytest.mark.parametrize('selection_block_size', (8, 4, 2))
+@pytest.mark.parametrize('num_selected_block', (0, 2))
 @pytest.mark.parametrize('query_heads_share_selected_kv', (False, True))
 def test_sparse_attn(
     use_diff_topk,
     seq_len,
     kv_heads,
     selection_block_size,
+    num_selected_block,
     query_heads_share_selected_kv
 ):
     attn = SparseAttention(
@@ -26,7 +28,7 @@ def test_sparse_attn(
         sliding_window_size = 2,
         compress_block_size = 4,
         selection_block_size = selection_block_size,
-        num_selected_blocks = 2,
+        num_selected_blocks = num_selected_block,
         use_diff_topk = use_diff_topk,
         query_heads_share_selected_kv = query_heads_share_selected_kv
     )
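With `num_selected_block` parametrized over (0, 2), the existing test matrix now covers both the normal selection path and the ablated negative-control path; running `pytest tests/test_sparse_attn.py` from the repository root exercises both (assuming pytest and the package's dependencies are installed).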
