2 files changed, +7 −4 lines changed

native_sparse_attention_pytorch

@@ -88,7 +88,7 @@ def __init__(
         use_diff_topk = False,
         compress_mlp: Module | None = None,
         compress_mlp_expand_factor = 1.,
-
+        strategy_combine_mlp: Module | None = None
     ):
         super().__init__()
         self.heads = heads
@@ -138,7 +138,7 @@ def __init__(
             compress_mlp = nn.Sequential(
                 Rearrange('b h w n d -> b h w (n d)'),
                 nn.Linear(compress_dim, compress_mlp_dim_hidden),
-                nn.SiLU(),
+                nn.ReLU(),
                 nn.Linear(compress_mlp_dim_hidden, dim_head),
             )

@@ -154,8 +154,11 @@ def __init__(

         # they combine the three sparse branches through a learned combine with sigmoid activation

+        if not exists(strategy_combine_mlp):
+            strategy_combine_mlp = nn.Linear(dim, 3 * heads)
+
         self.to_strategy_combine = nn.Sequential(
-            nn.Linear(dim, 3 * heads),
+            strategy_combine_mlp,
             nn.Sigmoid(),
             Rearrange('b n (h s) -> b h n s', h = heads)
         )
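
For context on the new hook: whatever module is passed as strategy_combine_mlp only needs to project the token dimension dim down to 3 * heads gating logits, because to_strategy_combine still applies nn.Sigmoid() and the Rearrange into per-head, per-branch gates afterwards. Below is a minimal sketch of a drop-in replacement; the sizes (dim = 512, heads = 8, sequence length 1024) and the two-layer design are illustrative assumptions, not part of this commit.

import torch
from torch import nn
from einops.layers.torch import Rearrange

# illustrative sizes only - not taken from this commit
dim = 512
heads = 8

# a custom strategy combine module just has to map dim -> 3 * heads,
# one gate per head for each of the three sparse branches
custom_strategy_combine = nn.Sequential(
    nn.Linear(dim, dim // 2),
    nn.SiLU(),
    nn.Linear(dim // 2, 3 * heads)
)

# mirrors what to_strategy_combine becomes after this change, with the
# custom module substituted for the default nn.Linear(dim, 3 * heads)
to_strategy_combine = nn.Sequential(
    custom_strategy_combine,
    nn.Sigmoid(),
    Rearrange('b n (h s) -> b h n s', h = heads)
)

tokens = torch.randn(1, 1024, dim)
gates = to_strategy_combine(tokens)
assert gates.shape == (1, heads, 1024, 3)  # one sigmoid gate per head and branch

When strategy_combine_mlp is left as None, the fallback added in this diff keeps the previous behaviour of a single nn.Linear(dim, 3 * heads).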
pyproject.toml

@@ -1,6 +1,6 @@
 [project]
 name = "native-sparse-attention-pytorch"
-version = "0.0.11"
+version = "0.0.12"
 description = "Native Sparse Attention"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }