5 files changed: +60 -4

README.md
@@ -12,7 +12,7 @@ This will be my last open sourced project under Meta
  - [Flex Attention](https://pytorch.org/blog/flexattention/) for allowing for rapid prototyping

- - <a href="https://github.com/Mr-Grin">@Mr-Grin</a> for the code review and pointing out a few inaccuracies in the implementation
+ - <a href="https://github.com/Mr-Grin">@Mr-Grin</a> for the code review and pointing out an inaccuracy with the implementation

  ## Install
native_sparse_attention_pytorch/compress_networks.py

  from torch.nn import Module, ModuleList

  from einops import einsum, rearrange
- from einops.layers.torch import EinMix as Mix
+ from einops.layers.torch import EinMix as Mix, Rearrange

  # helpers

@@ -98,3 +98,38 @@ def forward(
          compressed = self.net(kv)

          return compressed
+
+ # single projection "mlp"
+
+ class SingleProjection(Module):
+     def __init__(
+         self,
+         dim_head,
+         compress_window_size,
+         heads = 1
+     ):
+         super().__init__()
+         dim = dim_head * compress_window_size
+         dim_out = dim_head
+
+         is_grouped = heads > 1
+
+         if not is_grouped:
+             self.compress = nn.Sequential(
+                 Rearrange('b h w n d -> b h w (n d)'),
+                 nn.Linear(dim, dim_out, bias = False)
+             )
+         else:
+             self.compress = Mix(
+                 'b h w n i -> b h w o',
+                 weight_shape = 'h i o',
+                 h = heads,
+                 i = dim_head,
+                 o = dim_head
+             )
+
+     def forward(
+         self,
+         kv
+     ):
+         return self.compress(kv)
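As a quick sanity check on what the new compressor expects: the Rearrange pattern 'b h w n d -> b h w (n d)' implies a (batch, kv heads, num windows, window size, head dim) input, with each window of keys/values collapsed into a single compressed token of size dim_head. Below is a minimal sketch of the ungrouped (heads = 1) path; the tensor sizes are illustrative, and the layout is inferred from the einops patterns above rather than stated in the diff.

import torch
from native_sparse_attention_pytorch.compress_networks import SingleProjection

# ungrouped path: each window is flattened and sent through a single Linear projection
compress = SingleProjection(dim_head = 64, compress_window_size = 4, heads = 1)

# (batch, kv heads, num windows, window size, head dim) -- sizes are illustrative
kv = torch.randn(2, 1, 16, 4, 64)

compressed = compress(kv)
assert compressed.shape == (2, 1, 16, 64)  # one compressed token per window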
pyproject.toml

  [project]
  name = "native-sparse-attention-pytorch"
- version = "0.1.15"
+ version = "0.1.16"
  description = "Native Sparse Attention"
  authors = [
      { name = "Phil Wang", email = "lucidrains@gmail.com" }
test suite
@@ -77,3 +77,24 @@ def test_group_mlp():
      attended = attn(tokens)

      assert tokens.shape == attended.shape
+
+ @pytest.mark.parametrize('grouped', (False, True))
+ def test_single_projection_mlp(grouped):
+     from native_sparse_attention_pytorch.compress_networks import SingleProjection
+
+     attn = SparseAttention(
+         dim = 512,
+         dim_head = 64,
+         heads = 8,
+         sliding_window_size = 2,
+         compress_block_size = 4,
+         selection_block_size = 4,
+         num_selected_blocks = 2,
+         compress_mlp = SingleProjection(64, 4, 8 if grouped else 1)
+     )
+
+     tokens = torch.randn(2, 31, 512)
+
+     attended = attn(tokens)
+
+     assert tokens.shape == attended.shape
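To exercise only the new parametrized test (both its grouped and ungrouped cases), it can be selected by name; a minimal sketch, assuming pytest is installed and the working directory is the repository root:

import pytest

# collect and run only the new test, in both parametrizations
pytest.main(['-k', 'test_single_projection_mlp'])

The same selection works from a shell as pytest -k test_single_projection_mlp.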
training / experiment script

  INTERPOLATED_IMPORTANCE_SCORE = False
  USE_DIFF_TOPK = True

- USE_EFFICIENT_INFERENCE = True # needs validation still
+ USE_EFFICIENT_INFERENCE = False # needs validation still

  # experiment related