@@ -7,6 +7,8 @@
 import torch.nn.functional as F
 from torch.nn import Module, ModuleList
 
+from colt5_attention import topk as differentiable_topk
+
 from local_attention import LocalAttention
 
 from rotary_embedding_torch import RotaryEmbedding
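
For readers unfamiliar with the new import: `colt5_attention.topk` is a differentiable top-k that returns hard indices together with soft gates, so gradients can flow back into the scores that drove the selection. Below is a minimal sketch of the straight-through idea such a routine relies on; `st_topk` is a hypothetical helper for illustration, not the library's coordinate-descent implementation (sketched further down).

import torch

def st_topk(scores, k):
    # hard top-k in the forward pass
    values, indices = scores.topk(k, dim = -1)
    # soft gates that carry gradient back to the underlying scores
    soft = scores.softmax(dim = -1).gather(-1, indices)
    # straight-through: gates evaluate to exactly 1. forward, but grad flows via `soft`
    gates = soft + (1. - soft).detach()
    return values, indices, gates

scores = torch.randn(2, 8, requires_grad = True)
values, indices, gates = st_topk(scores, k = 3)
(values * gates).sum().backward()    # scores.grad is populated
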
@@ -84,8 +86,11 @@ def __init__(
         num_selected_blocks,
         num_compressed_mem_kv = 4,
         norm = True,
+        use_diff_topk = False,
+        diff_topk_coor_descent_iters = 10.
     ):
         super().__init__()
+        self.heads = heads
         self.scale = dim_head ** -0.5
 
         assert compress_block_size == selection_block_size, 'start off with compressed being equal to selection block sizes'
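
A hedged usage sketch of the two new hyperparameters. Only `use_diff_topk` and `diff_topk_coor_descent_iters` come from this diff; the class name `SparseAttention`, `dim`, and `sliding_window_size` are assumptions about the surrounding repo.

# assumed module / kwargs, except the two new flags from this commit
attn = SparseAttention(
    dim = 512,                           # assumed kwarg
    dim_head = 64,
    heads = 8,
    sliding_window_size = 32,            # assumed kwarg
    compress_block_size = 16,
    selection_block_size = 16,           # must equal compress_block_size per the assert above
    num_selected_blocks = 4,
    use_diff_topk = True,                # new: learn the block selection end to end
    diff_topk_coor_descent_iters = 10.   # new: iterations for the inner coordinate descent
)
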
@@ -136,6 +141,9 @@ def __init__(
 
         # selection related
 
+        self.use_diff_topk = use_diff_topk
+        self.diff_topk_coor_descent_iters = diff_topk_coor_descent_iters
+
         self.selection_block_size = selection_block_size
         self.num_selected_blocks = num_selected_blocks
 
@@ -160,7 +168,7 @@ def forward(
         self,
         inp
     ):
-        batch, seq_len, scale, device = *inp.shape[:2], self.scale, inp.device
+        batch, seq_len, scale, heads, device = *inp.shape[:2], self.scale, self.heads, inp.device
 
         compress_divisible_seq_len = round_down_mult(seq_len, self.compress_block_size)
         num_compress_blocks = compress_divisible_seq_len // self.compress_block_size
@@ -216,7 +224,10 @@ def forward(
 
         importance_scores = csim[..., num_mem_compress_kv:]
 
-        selected_importance_values, selected_block_indices = importance_scores.topk(self.num_selected_blocks, dim = -1)
+        if self.use_diff_topk:
+            selected_importance_values, selected_block_indices, _, gates = differentiable_topk(importance_scores, self.num_selected_blocks, fused = True)
+        else:
+            selected_importance_values, selected_block_indices = importance_scores.topk(self.num_selected_blocks, dim = -1)
 
         fmask = selected_importance_values > mask_value
 
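
When `use_diff_topk` is on, `differentiable_topk` returns gates alongside the values and indices; those gates are consumed in the gating hunk further down. The relaxation behind it is the CoLT5-style coordinate descent: solve for gates in (0, 1) that sum to k by alternating two dual updates. A sketch under those assumptions (the update order, `eps`, and the return signature are simplifications, not the real fused kernel):

import math
import torch
import torch.nn.functional as F

def coor_descent_gates(s, k, n_iters = 10, eps = 1.):
    # gates_i = exp((s_i + a + b_i) / eps); `a` enforces sum == k, `b` caps each gate at 1
    logk = math.log(k)
    a = torch.zeros_like(s[..., :1])
    for _ in range(n_iters):
        b = -F.relu(s + a)
        a = eps * (logk - ((s + b) / eps).logsumexp(dim = -1, keepdim = True))
    return ((s + a + b) / eps).exp()

gates = coor_descent_gates(torch.randn(2, 8, 16), k = 4)
assert torch.allclose(gates.sum(dim = -1), torch.tensor(4.), atol = 1e-3)
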
@@ -239,13 +250,13 @@ def forward(
         # handle block causal diagonal in the diagram, but run experiments without to see
 
         fine_window_seq = arange(fine_divisible_seq_len, device = device) // self.selection_block_size
-        fine_window_seq = rearrange(fine_window_seq, 'n -> n 1').expand_as(selected_block_indices)
+        fine_window_seq = repeat(fine_window_seq, 'n -> b h n 1', b = batch, h = heads)
         selected_block_indices = cat((selected_block_indices, fine_window_seq), dim = -1) # for the block causal diagonal in fig2
 
         fmask = repeat(fmask, 'b h i w -> b h i w j', j = self.selection_block_size)
 
         causal_mask = torch.ones((self.selection_block_size,) * 2, device = device, dtype = torch.bool).tril()
-        causal_mask = repeat(causal_mask, 'i j -> (w i) 1 j', w = num_fine_blocks).expand_as(fmask)
+        causal_mask = repeat(causal_mask, 'i j -> b h (w i) 1 j', w = num_fine_blocks, b = batch, h = heads)
 
         fmask = cat((fmask, causal_mask), dim = -2)
         fmask = rearrange(fmask, 'b h i w j -> b h i (w j)')
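
Note the switch from `rearrange(...).expand_as(...)` to an explicit `repeat`: the diagonal-block column now has to be a real `b h n 1` tensor so it can be concatenated onto `selected_block_indices` and later used for per-(batch, head) indexing. A toy shape check (sizes are arbitrary):

import torch
from einops import repeat

batch, heads, n, selection_block_size = 2, 4, 6, 2
fine_window_seq = torch.arange(n) // selection_block_size     # block id of each position
fine_window_seq = repeat(fine_window_seq, 'n -> b h n 1', b = batch, h = heads)
print(fine_window_seq.shape)    # torch.Size([2, 4, 6, 1])
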
@@ -255,8 +266,19 @@ def forward(
         fk = rearrange(fk, 'b h (w n) d -> b h w n d', w = num_fine_blocks)
         fv = rearrange(fv, 'b h (w n) d -> b h w n d', w = num_fine_blocks)
 
-        fk = einx.get_at('b h [w] j d, b h i selected -> b h i (selected j) d', fk, selected_block_indices)
-        fv = einx.get_at('b h [w] j d, b h i selected -> b h i (selected j) d', fv, selected_block_indices)
+        fk = einx.get_at('b h [w] j d, b h i selected -> b h i selected j d', fk, selected_block_indices)
+        fv = einx.get_at('b h [w] j d, b h i selected -> b h i selected j d', fv, selected_block_indices)
+
+        # handle maybe gating
+
+        if self.use_diff_topk:
+            gates = F.pad(gates, (0, 1, 0, remainder), value = 1.)
+
+            fk = einx.multiply('b h i w, b h i w j d -> b h i w j d', gates, fk)
+            fv = einx.multiply('b h i w, b h i w j d -> b h i w j d', gates, fv)
+
+        fk = rearrange(fk, 'b h i w j d -> b h i (w j) d')
+        fv = rearrange(fv, 'b h i w j d -> b h i (w j) d')
 
         # fine attention
 
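
The gating hunk is the payoff of the `(selected j) -> selected j` change above: with the block dim kept separate, one gate per selected block scales every position of the gathered fine keys/values, while the appended block-causal diagonal block gets a gate of 1 so it always passes through. A shape-level sketch (dimensions are illustrative; the diff additionally pads the query dim by `remainder`):

import torch
import torch.nn.functional as F
import einx

b, h, i, w, j, d = 1, 2, 8, 4, 16, 32
fk = torch.randn(b, h, i, w + 1, j, d)     # w selected blocks + the diagonal block
gates = torch.rand(b, h, i, w)
gates = F.pad(gates, (0, 1), value = 1.)   # diagonal block is always fully gated in
fk = einx.multiply('b h i w, b h i w j d -> b h i w j d', gates, fk)
print(fk.shape)    # torch.Size([1, 2, 8, 5, 16, 32])
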