
Commit e901e73

allow for single projection "mlp", for @Mr-Grin to experiment around with
1 parent 855c7f8 commit e901e73

5 files changed: +60 −4 lines changed

README.md

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@ This will be my last open sourced project under Meta

 - [Flex Attention](https://pytorch.org/blog/flexattention/) for allowing for rapid prototyping

-- <a href="https://github.com/Mr-Grin">@Mr-Grin</a> for the code review and pointing out a few inaccuracies in the implementation
+- <a href="https://github.com/Mr-Grin">@Mr-Grin</a> for the code review and pointing out an inaccuracy with the implementation

 ## Install

native_sparse_attention_pytorch/compress_networks.py

Lines changed: 36 additions & 1 deletion
@@ -3,7 +3,7 @@
 from torch.nn import Module, ModuleList

 from einops import einsum, rearrange
-from einops.layers.torch import EinMix as Mix
+from einops.layers.torch import EinMix as Mix, Rearrange

 # helpers

@@ -98,3 +98,38 @@ def forward(
         compressed = self.net(kv)

         return compressed
+
+# single projection "mlp"
+
+class SingleProjection(Module):
+    def __init__(
+        self,
+        dim_head,
+        compress_window_size,
+        heads = 1
+    ):
+        super().__init__()
+        dim = dim_head * compress_window_size
+        dim_out = dim_head
+
+        is_grouped = heads > 1
+
+        if not is_grouped:
+            self.compress = nn.Sequential(
+                Rearrange('b h w n d -> b h w (n d)'),
+                nn.Linear(dim, dim_out, bias = False)
+            )
+        else:
+            self.compress = Mix(
+                'b h w n i -> b h w o',
+                weight_shape = 'h i o',
+                h = heads,
+                i = dim_head,
+                o = dim_head
+            )
+
+    def forward(
+        self,
+        kv
+    ):
+        return self.compress(kv)
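
For reference, a minimal usage sketch of how the new SingleProjection could be dropped in as the compression network. It mirrors the new test added in this commit below; the top-level `from native_sparse_attention_pytorch import SparseAttention` import path is an assumption, and the shape comments are inferred from the Rearrange/EinMix patterns above.

# sketch only - mirrors tests/test_custom_compress_mlp.py from this commit
# the top-level SparseAttention import path is an assumption
import torch
from native_sparse_attention_pytorch import SparseAttention
from native_sparse_attention_pytorch.compress_networks import SingleProjection

# one linear projection per compressed window:
# (batch, heads, windows, compress_window_size, dim_head) -> (batch, heads, windows, dim_head)
compress_mlp = SingleProjection(
    dim_head = 64,
    compress_window_size = 4,
    heads = 1    # pass heads = 8 instead to get the grouped, per-head EinMix projection
)

attn = SparseAttention(
    dim = 512,
    dim_head = 64,
    heads = 8,
    sliding_window_size = 2,
    compress_block_size = 4,    # should match compress_window_size above
    selection_block_size = 4,
    num_selected_blocks = 2,
    compress_mlp = compress_mlp
)

tokens = torch.randn(2, 31, 512)
attended = attn(tokens)    # output has the same shape as the input tokens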

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "native-sparse-attention-pytorch"
-version = "0.1.15"
+version = "0.1.16"
 description = "Native Sparse Attention"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

tests/test_custom_compress_mlp.py

Lines changed: 21 additions & 0 deletions
@@ -77,3 +77,24 @@ def test_group_mlp():
     attended = attn(tokens)

     assert tokens.shape == attended.shape
+
+@pytest.mark.parametrize('grouped', (False, True))
+def test_single_projection_mlp(grouped):
+    from native_sparse_attention_pytorch.compress_networks import SingleProjection
+
+    attn = SparseAttention(
+        dim = 512,
+        dim_head = 64,
+        heads = 8,
+        sliding_window_size = 2,
+        compress_block_size = 4,
+        selection_block_size = 4,
+        num_selected_blocks = 2,
+        compress_mlp = SingleProjection(64, 4, 8 if grouped else 1)
+    )
+
+    tokens = torch.randn(2, 31, 512)
+
+    attended = attn(tokens)
+
+    assert tokens.shape == attended.shape

train.py

Lines changed: 1 addition & 1 deletion
@@ -48,7 +48,7 @@
 INTERPOLATED_IMPORTANCE_SCORE = False
 USE_DIFF_TOPK = True

-USE_EFFICIENT_INFERENCE = True # needs validation still
+USE_EFFICIENT_INFERENCE = False # needs validation still

 # experiment related
