@@ -218,6 +218,7 @@ def forward(
218
218
fk = k
219
219
fv = v
220
220
221
+
221
222
if seq_len < fine_divisible_seq_len :
222
223
remainder = fine_divisible_seq_len - seq_len
223
224
fk = pad_at_dim (fk , (0 , remainder ), value = 0. , dim = - 2 )
@@ -228,13 +229,30 @@ def forward(
228
229
229
230
selected_block_indices = pad_at_dim (selected_block_indices , (0 , remainder ), value = 0 , dim = - 2 )
230
231
232
+ # handle block causal diagonal in the diagram, but run experiments without to see
233
+
234
+ fine_window_seq = arange (fine_divisible_seq_len , device = device ) // self .selection_block_size
235
+ fine_window_seq = rearrange (fine_window_seq , 'n -> n 1' ).expand_as (selected_block_indices )
236
+ selected_block_indices = cat ((selected_block_indices , fine_window_seq ), dim = - 1 ) # for the block causal diagonal in fig2
237
+
238
+ fmask = repeat (fmask , 'b h i w -> b h i w j' , j = self .selection_block_size )
239
+
240
+ causal_mask = torch .ones ((self .selection_block_size ,) * 2 , device = device , dtype = torch .bool ).tril ()
241
+ causal_mask = repeat (causal_mask , 'i j -> (w i) 1 j' , w = num_fine_blocks ).expand_as (fmask )
242
+
243
+ fmask = cat ((fmask , causal_mask ), dim = - 2 )
244
+ fmask = rearrange (fmask , 'b h i w j -> b h i (w j)' )
245
+
246
+ # select out the spatial crops of keys / values for fine attention
247
+
231
248
fk = rearrange (fk , 'b h (w n) d -> b h w n d' , w = num_fine_blocks )
232
249
fv = rearrange (fv , 'b h (w n) d -> b h w n d' , w = num_fine_blocks )
233
- fmask = repeat (fmask , 'b h i w -> b h i (w j)' , j = self .selection_block_size )
234
250
235
251
fk = einx .get_at ('b h [w] j d, b h i selected -> b h i (selected j) d' , fk , selected_block_indices )
236
252
fv = einx .get_at ('b h [w] j d, b h i selected -> b h i (selected j) d' , fv , selected_block_indices )
237
253
254
+ # fine attention
255
+
238
256
fsim = einsum (fq , fk , 'b h i d, b h i j d -> b h i j' ) * self .scale
239
257
240
258
fsim = fsim .masked_fill (fmask , mask_value )
0 commit comments