Commit f77444e

allow for the strategy combine "mlp" to be customized as well, but this portion is not that critical imo
1 parent d531216 commit f77444e

2 files changed: +7 −4 lines changed


native_sparse_attention_pytorch/native_sparse_attention.py

Lines changed: 6 additions & 3 deletions
@@ -88,7 +88,7 @@ def __init__(
         use_diff_topk = False,
         compress_mlp: Module | None = None,
         compress_mlp_expand_factor = 1.,
-
+        strategy_combine_mlp: Module | None = None
     ):
         super().__init__()
         self.heads = heads
@@ -138,7 +138,7 @@ def __init__(
             compress_mlp = nn.Sequential(
                 Rearrange('b h w n d -> b h w (n d)'),
                 nn.Linear(compress_dim, compress_mlp_dim_hidden),
-                nn.SiLU(),
+                nn.ReLU(),
                 nn.Linear(compress_mlp_dim_hidden, dim_head),
             )

@@ -154,8 +154,11 @@ def __init__(

         # they combine the three sparse branches through a learned combine with sigmoid activation

+        if not exists(strategy_combine_mlp):
+            strategy_combine_mlp = nn.Linear(dim, 3 * heads)
+
         self.to_strategy_combine = nn.Sequential(
-            nn.Linear(dim, 3 * heads),
+            strategy_combine_mlp,
             nn.Sigmoid(),
             Rearrange('b n (h s) -> b h n s', h = heads)
         )
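
With this change, the learned gate that combines the three sparse branches can be swapped out via the new strategy_combine_mlp argument; when nothing is passed in, it falls back to the previous single nn.Linear(dim, 3 * heads). A minimal usage sketch, assuming the SparseAttention constructor arguments shown in the repository README; the two-layer custom_combine module here is purely illustrative, and any replacement only needs to map dim to 3 * heads so the downstream sigmoid and head rearrange still line up:

import torch
from torch import nn

from native_sparse_attention_pytorch import SparseAttention

dim = 512
heads = 8

# illustrative (hypothetical) custom combine module - whatever is passed in
# must map the model dimension `dim` to 3 * heads gate logits, since its
# output goes through nn.Sigmoid() and is rearranged to (b, h, n, 3)
custom_combine = nn.Sequential(
    nn.Linear(dim, dim),
    nn.SiLU(),
    nn.Linear(dim, 3 * heads)
)

attn = SparseAttention(
    dim = dim,
    dim_head = 64,
    heads = heads,
    sliding_window_size = 2,
    compress_block_size = 4,
    selection_block_size = 4,
    num_selected_blocks = 2,
    strategy_combine_mlp = custom_combine
)

tokens = torch.randn(2, 31, dim)
attended = attn(tokens)  # same shape as the input tokens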

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "native-sparse-attention-pytorch"
-version = "0.0.11"
+version = "0.0.12"
 description = "Native Sparse Attention"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
