
Commit 1798a2c

committed
remove interpolated importance score, and for now only focus on selection block sizes greater than compress block
1 parent cb34259 commit 1798a2c

4 files changed (+10, -28 lines)


native_sparse_attention_pytorch/compress_networks.py

Lines changed: 5 additions & 4 deletions
@@ -135,8 +135,9 @@ def forward(
     ):
         return self.compress(kv)

+# simple transformer compressor, pull requested by Eric Pasewark

-class SimpleMultiheadSelfAttention(nn.Module):
+class SimpleMultiheadSelfAttention(Module):
     def __init__(self, dim, num_heads, dropout=0.0):
         super().__init__()
         assert dim % num_heads == 0, "Hidden dimension must be divisible by number of heads"

@@ -167,7 +168,7 @@ def forward(self, x):
         attn_out = attn_out.transpose(1, 2).reshape(B, L, D)
         return self.out_proj(attn_out)

-class SimpleTransformerFeedForward(nn.Module):
+class SimpleTransformerFeedForward(Module):
     def __init__(self, dim, hidden_dim, dropout=0.0):
         """Two-layer feed-forward network with GELU activation."""
         super().__init__()

@@ -183,7 +184,7 @@ def forward(self, x):
         out = self.dropout(out)
         return out

-class SimpleTransformerLayer(nn.Module):
+class SimpleTransformerLayer(Module):
     def __init__(self, dim, num_heads, ff_hidden_dim=None, dropout=0.0):
         """Single Transformer layer: RMSNorm + Multi-head attention + RMSNorm + FeedForward."""
         super().__init__()

@@ -201,7 +202,7 @@ def forward(self, x):
         x = x + f
         return x

-class CompressTransformer(nn.Module):
+class CompressTransformer(Module):
     def __init__(self, num_layers, dim, num_heads, ff_hidden_dim=None, dropout=0.0):
         """
         Stacked Transformer encoder layers.
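
The change to compress_networks.py is a mechanical rename of the base class from nn.Module to the bare Module on the transformer compressor classes. A minimal sketch of the import this presumably relies on (the actual import block of compress_networks.py is not part of this diff), with a hypothetical Example class standing in for the real ones:

# assumed import style; subclassing `Module` is equivalent to subclassing `nn.Module`
from torch import nn
from torch.nn import Module

class Example(Module):
    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        return self.proj(x)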

native_sparse_attention_pytorch/native_sparse_attention.py

Lines changed: 4 additions & 18 deletions
@@ -145,13 +145,6 @@ def pad_at_dim(t, pad, dim = -1, value = 0.):
     zeros = ((0, 0) * dims_from_right)
     return F.pad(t, (*zeros, *pad), value = value)

-def interpolate_1d(x, length, mode = 'bilinear'):
-    x, inverse_pack = pack_one_with_inverse(x, '* n')
-    x = rearrange(x, 'b n -> b 1 n 1')
-    x = F.interpolate(x, (length, 1), mode = mode)
-    x = rearrange(x, 'b 1 n 1 -> b n')
-    return inverse_pack(x)
-
 def straight_through(t, target):
     return t + (target - t).detach()

@@ -209,7 +202,6 @@ def __init__(
         norm = True,
         use_diff_topk = False,
         use_triton_kernel = False,
-        interpolated_importance_score = False,
         query_heads_share_selected_kv = True, # if set to True, importance score is averaged across query heads to select top-n buckets of kv per kv head - but can be set to False for each query head within a group to look at different sets of kv buckets. will be more memory and compute of course
         compress_mlp: Module | None = None,
         compress_mlp_expand_factor = 1.,

@@ -319,10 +311,10 @@ def __init__(

         self.use_diff_topk = use_diff_topk

-        self.interpolated_importance_score = interpolated_importance_score # in the case fine block size < compressed block size, will weigh space better when selecting
-
         self.query_heads_share_selected_kv = query_heads_share_selected_kv

+        assert divisible_by(selection_block_size, compress_block_size), f'selection block size {selection_block_size} must be greater than or equal to compress block size {compress_block_size}, as well as divisible by the compress block size'
+
         self.selection_block_size = selection_block_size

         assert num_selected_blocks >= 0

@@ -473,10 +465,7 @@ def forward_inference(
         if self.compress_block_size != self.selection_block_size:
             compress_seq_len = num_compress_blocks * self.compress_block_size

-            if self.interpolated_importance_score:
-                importance_scores = interpolate_1d(importance_scores, compress_seq_len)
-            else:
-                importance_scores = repeat(importance_scores, '... j -> ... (bsz j)', bsz = self.compress_block_size)
+            importance_scores = repeat(importance_scores, '... j -> ... (bsz j)', bsz = self.compress_block_size)

             fine_seq_len = round_down_mult(compress_seq_len, self.selection_block_size)

@@ -702,10 +691,7 @@ def forward(

             compress_seq_len = num_compress_blocks * self.compress_block_size

-            if self.interpolated_importance_score:
-                importance_scores = interpolate_1d(importance_scores, compress_seq_len)
-            else:
-                importance_scores = repeat(importance_scores, '... j -> ... (j block_size)', block_size = self.compress_block_size)
+            importance_scores = repeat(importance_scores, '... j -> ... (j block_size)', block_size = self.compress_block_size)

             padding = fine_divisible_seq_len - compress_seq_len
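
With interpolate_1d removed, compressed-block importance scores are always expanded with einops repeat, and the new constructor assertion guarantees the expansion lines up with whole selection blocks. A minimal sketch of that path with made-up shapes (the real tensors carry their own batch, head and query dimensions, and the library's downstream pooling may differ from the simple mean used here):

import torch
from einops import repeat, reduce

compress_block_size = 4
selection_block_size = 16          # must now be a multiple of compress_block_size
num_compress_blocks = 8

# one importance score per compressed block: (batch, heads, query_len, num_compress_blocks)
importance_scores = torch.randn(2, 8, 32, num_compress_blocks)

# the surviving branch: copy each compressed-block score across the tokens of its block
per_token = repeat(importance_scores, '... j -> ... (j block_size)', block_size = compress_block_size)
assert per_token.shape[-1] == num_compress_blocks * compress_block_size   # 32 token positions

# each selection block now covers a whole number of compress blocks, so the per-token
# scores regroup cleanly (mean pooling here is just for illustration)
per_selection_block = reduce(per_token, '... (n w) -> ... n', 'mean', w = selection_block_size)
print(per_selection_block.shape)   # torch.Size([2, 8, 32, 2])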

tests/test_sparse_attn.py

Lines changed: 1 addition & 4 deletions
@@ -10,12 +10,11 @@
 @pytest.mark.parametrize('causal', (False, True))
 @pytest.mark.parametrize('seq_len', (1, 4, 31, 32, 120))
 @pytest.mark.parametrize('kv_heads', (8, 4))
-@pytest.mark.parametrize('selection_block_size', (8, 4, 2))
+@pytest.mark.parametrize('selection_block_size', (8, 16, 32))
 @pytest.mark.parametrize('compress_block_size', (8, 4))
 @pytest.mark.parametrize('compress_block_overlap_len', (0, 2))
 @pytest.mark.parametrize('num_selected_block', (0, 2))
 @pytest.mark.parametrize('query_heads_share_selected_kv', (False, True))
-@pytest.mark.parametrize('interpolated_importance_score', (False, True))
 def test_sparse_attn(
     use_diff_topk,
     causal,

@@ -26,7 +25,6 @@ def test_sparse_attn(
     compress_block_overlap_len,
     num_selected_block,
     query_heads_share_selected_kv,
-    interpolated_importance_score
 ):
     attn = SparseAttention(
         dim = 512,

@@ -41,7 +39,6 @@ def test_sparse_attn(
         num_selected_blocks = num_selected_block,
         use_diff_topk = use_diff_topk,
         query_heads_share_selected_kv = query_heads_share_selected_kv,
-        interpolated_importance_score = interpolated_importance_score
     )

     tokens = torch.randn(2, seq_len, 512)
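
The selection_block_size grid moves from (8, 4, 2) to (8, 16, 32) so that every combination with compress_block_size in (8, 4) satisfies the new assertion. A quick check of the updated grid, with a stand-in for the library's divisible_by helper:

# stand-in for the divisible_by helper the new assertion calls
def divisible_by(num, den):
    return (num % den) == 0

compress_block_sizes = (8, 4)
selection_block_sizes = (8, 16, 32)   # previously (8, 4, 2)

for compress in compress_block_sizes:
    for selection in selection_block_sizes:
        # every pair now passes; the old selection sizes 4 and 2 would fail for compress = 8
        assert selection >= compress and divisible_by(selection, compress)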

train.py

Lines changed: 0 additions & 2 deletions
@@ -45,7 +45,6 @@
 FINE_BLOCK_SIZE = 16
 NUM_FINE_SELECTED = 4

-INTERPOLATED_IMPORTANCE_SCORE = False
 USE_DIFF_TOPK = True

 USE_EFFICIENT_INFERENCE = True # needs validation still

@@ -106,7 +105,6 @@ def decode_tokens(tokens):
         selection_block_size = FINE_BLOCK_SIZE,
         num_selected_blocks = NUM_FINE_SELECTED,
         use_diff_topk = USE_DIFF_TOPK,
-        interpolated_importance_score = INTERPOLATED_IMPORTANCE_SCORE,
         query_heads_share_selected_kv = QUERY_HEADS_SHARE_SELECTION
     )
 ).cuda()
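
With INTERPOLATED_IMPORTANCE_SCORE gone, the training config no longer toggles interpolation; FINE_BLOCK_SIZE (passed as selection_block_size) must now be a whole multiple of the compress block size, which this hunk does not show. A small sanity check one could add, with the compress constant's name and value assumed purely for illustration:

FINE_BLOCK_SIZE = 16
COMPRESS_BLOCK_SIZE = 4   # assumed name and value; the real constant is outside this hunk

# mirrors the assertion SparseAttention now raises at construction time
assert FINE_BLOCK_SIZE % COMPRESS_BLOCK_SIZE == 0, 'selection (fine) block size must be a whole multiple of the compress block size'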
