native_sparse_attention_pytorch/native_sparse_attention.py
4 additions & 18 deletions
@@ -145,13 +145,6 @@ def pad_at_dim(t, pad, dim = -1, value = 0.):
     zeros = ((0, 0) * dims_from_right)
     return F.pad(t, (*zeros, *pad), value = value)
 
-def interpolate_1d(x, length, mode = 'bilinear'):
-    x, inverse_pack = pack_one_with_inverse(x, '* n')
-    x = rearrange(x, 'b n -> b 1 n 1')
-    x = F.interpolate(x, (length, 1), mode = mode)
-    x = rearrange(x, 'b 1 n 1 -> b n')
-    return inverse_pack(x)
-
 def straight_through(t, target):
     return t + (target - t).detach()
 
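The removed `interpolate_1d` helper stretched a 1-D score sequence to a new length with `F.interpolate`, so importance scores computed at compressed-block resolution could be reused at a finer selection-block resolution. A minimal sketch of the same idea, with the repo's `pack_one_with_inverse` replaced by a plain reshape (the sketch's name, example shapes, and the reshape are assumptions for illustration, not part of this diff):

```python
import torch
import torch.nn.functional as F
from einops import rearrange

def interpolate_1d_sketch(x, length, mode = 'bilinear'):
    # x: (..., n) scores at compressed-block resolution
    lead_shape = x.shape[:-1]
    x = x.reshape(-1, x.shape[-1])                  # flatten leading dims (the original used pack_one_with_inverse)
    x = rearrange(x, 'b n -> b 1 n 1')              # fake an image so F.interpolate can resize along n
    x = F.interpolate(x, (length, 1), mode = mode)  # stretch n -> length
    x = rearrange(x, 'b 1 n 1 -> b n')
    return x.reshape(*lead_shape, length)

scores = torch.randn(2, 8, 16)            # e.g. (batch, heads, num compressed blocks)
fine = interpolate_1d_sketch(scores, 64)  # (2, 8, 64) - one score per finer position
```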
@@ -209,7 +202,6 @@ def __init__(
         norm = True,
         use_diff_topk = False,
         use_triton_kernel = False,
-        interpolated_importance_score = False,
         query_heads_share_selected_kv = True, # if set to True, importance score is averaged across query heads to select top-n buckets of kv per kv head - but can be set to False for each query head within a group to look at different sets of kv buckets. will be more memory and compute of course
         compress_mlp: Module | None = None,
         compress_mlp_expand_factor = 1.,
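For context, after this change the module is constructed without the `interpolated_importance_score` flag. A hedged usage sketch, assuming the constructor arguments from the repository's usual README example (`dim`, `dim_head`, `heads`, `sliding_window_size`, `compress_block_size`, `selection_block_size`, `num_selected_blocks`); only the keyword arguments visible in the hunk above are confirmed by this diff:

```python
import torch
from native_sparse_attention_pytorch import SparseAttention

attn = SparseAttention(
    dim = 512,
    dim_head = 64,
    heads = 8,
    sliding_window_size = 2,
    compress_block_size = 4,
    selection_block_size = 8,   # must now be a multiple of compress_block_size (see the assert below)
    num_selected_blocks = 2,
    use_diff_topk = False,                  # shown in this hunk
    query_heads_share_selected_kv = True    # shown in this hunk
)

tokens = torch.randn(1, 256, 512)  # (batch, seq, dim)
out = attn(tokens)                 # same shape as tokens
```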
@@ -319,10 +311,10 @@ def __init__(
 
         self.use_diff_topk = use_diff_topk
 
-        self.interpolated_importance_score = interpolated_importance_score # in the case fine block size < compressed block size, will weigh space better when selecting
+        assert divisible_by(selection_block_size, compress_block_size), f'selection block size {selection_block_size} must be greater than or equal to compress block size {compress_block_size}, as well as divisible by the compress block size'
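The assertion replaces the interpolation path: rather than stretching compressed-block importance scores to a finer grid, each selection block must now cover a whole number of compressed blocks. A small sketch of the constraint, restating `divisible_by` for illustration (the real helper lives in the library):

```python
def divisible_by(num, den):
    return (num % den) == 0

compress_block_size = 4
selection_block_size = 16

# the relationship now enforced up front in __init__
assert divisible_by(selection_block_size, compress_block_size), \
    f'selection block size {selection_block_size} must be a multiple of compress block size {compress_block_size}'

# each selection block covers this many compressed blocks, so compressed-block
# importance scores can be pooled per selection block instead of interpolated
compressed_blocks_per_selection = selection_block_size // compress_block_size  # -> 4
```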