Commit 8165138

Give the compress MLP some depth, then allow it to be customizable by researchers; the paper is vague on this detail.
1 parent 7f567fc commit 8165138

File tree

2 files changed, +25 -13 lines changed

native_sparse_attention_pytorch/native_sparse_attention.py

Lines changed: 24 additions & 12 deletions
@@ -1,5 +1,6 @@
 from __future__ import annotations

+from copy import deepcopy
 from math import ceil

 import torch
@@ -85,6 +86,9 @@ def __init__(
         num_compressed_mem_kv = 4,
         norm = True,
         use_diff_topk = False,
+        compress_mlp: Module | None = None,
+        compress_mlp_expand_factor = 1.,
+
     ):
         super().__init__()
         self.heads = heads
@@ -120,21 +124,26 @@ def __init__(

         assert num_compressed_mem_kv > 0

+        self.split_compress_window = Rearrange('b h (w n) d -> b h w n d', n = compress_block_size)
+
         self.compress_mem_kv = nn.Parameter(torch.zeros(2, heads, num_compressed_mem_kv, dim_head))
+
         self.k_intrablock_positions = nn.Parameter(torch.zeros(heads, compress_block_size, dim_head))
         self.v_intrablock_positions = nn.Parameter(torch.zeros(heads, compress_block_size, dim_head))

-        self.k_compress = nn.Sequential(
-            Rearrange('b h n d -> b (h d) n'),
-            nn.Conv1d(dim_head * heads, dim_head * heads, compress_block_size, stride = compress_block_size, groups = heads),
-            Rearrange('b (h d) nc -> b h nc d', h = heads)
-        )
+        if not exists(compress_mlp):
+            compress_dim = compress_block_size * dim_head
+            compress_mlp_dim_hidden = int(compress_mlp_expand_factor * compress_dim)

-        self.v_compress = nn.Sequential(
-            Rearrange('b h n d -> b (h d) n'),
-            nn.Conv1d(dim_head * heads, dim_head * heads, compress_block_size, stride = compress_block_size, groups = heads),
-            Rearrange('b (h d) nc -> b h nc d', h = heads)
-        )
+            mlp = nn.Sequential(
+                Rearrange('b h w n d -> b h w (n d)'),
+                nn.Linear(compress_dim, compress_mlp_dim_hidden),
+                nn.SiLU(),
+                nn.Linear(compress_mlp_dim_hidden, dim_head),
+            )
+
+        self.k_compress = deepcopy(mlp)
+        self.v_compress = deepcopy(mlp)

         # selection related

@@ -187,8 +196,11 @@ def forward(
         k_pos = repeat(self.k_intrablock_positions, 'h n d -> h (r n) d', r = num_compress_blocks)
         v_pos = repeat(self.v_intrablock_positions, 'h n d -> h (r n) d', r = num_compress_blocks)

-        ck = self.k_compress(k[..., :compress_divisible_seq_len, :] + k_pos)
-        cv = self.v_compress(v[..., :compress_divisible_seq_len, :] + v_pos)
+        k_compress_input = self.split_compress_window(k[..., :compress_divisible_seq_len, :] + k_pos)
+        v_compress_input = self.split_compress_window(v[..., :compress_divisible_seq_len, :] + v_pos)
+
+        ck = self.k_compress(k_compress_input)
+        cv = self.v_compress(v_compress_input)

         # 1. coarse attention over compressed

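For readers who want to try the new customization hook, here is a minimal sketch of a researcher-supplied compression module. It is not part of the commit; the shape contract is inferred from the diff above: the module receives keys/values already split into compression windows, shaped 'b h w n d' (w = number of blocks, n = compress_block_size), and should return one dim_head-sized vector per block, shaped 'b h w d'. The name custom_compress_mlp and the dimensions (compress_block_size = 4, dim_head = 64, heads = 8) are arbitrary example values.

import torch
from torch import nn
from einops.layers.torch import Rearrange

# example hyperparameters (arbitrary, for illustration only)
compress_block_size = 4
dim_head = 64
heads = 8
compress_dim = compress_block_size * dim_head

# a deeper compression MLP than the two-layer default added in this commit;
# it flattens each window of compress_block_size tokens and maps it to a
# single dim_head-sized compressed key or value
custom_compress_mlp = nn.Sequential(
    Rearrange('b h w n d -> b h w (n d)'),
    nn.Linear(compress_dim, compress_dim * 2),
    nn.SiLU(),
    nn.Linear(compress_dim * 2, compress_dim),
    nn.SiLU(),
    nn.Linear(compress_dim, dim_head),
)

# mirror what forward() now does: split the sequence into compression windows
split_compress_window = Rearrange('b h (w n) d -> b h w n d', n = compress_block_size)

k = torch.randn(2, heads, 16, dim_head)             # (batch, heads, seq, dim_head), seq divisible by block size
ck = custom_compress_mlp(split_compress_window(k))  # one compressed key per block
print(ck.shape)                                     # torch.Size([2, 8, 4, 64])

A module like this would be handed in through the new compress_mlp keyword argument. Note that in the hunk above only the locally built mlp is deep-copied into self.k_compress / self.v_compress, and a supplied compress_mlp is not referenced there, so fully wiring in a custom module may need a follow-up change.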
pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "native-sparse-attention-pytorch"
-version = "0.0.8"
+version = "0.0.9"
 description = "Native Sparse Attention"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
