@@ -398,8 +398,9 @@ def forward(
398
398
selected_importance_values , selected_block_indices = importance_scores .topk (num_selected , dim = - 1 )
399
399
400
400
if self .use_diff_topk :
401
- assert not exists (fine_selection_flex_mask )
402
401
gates = straight_through (selected_importance_values , 1. )
402
+ gates = gates .cumsum (dim = - 1 )[..., - 1 ]
403
+ gates = repeat (gates , 'b h ... -> b (h qh) ...' , qh = self .num_grouped_queries )
403
404
404
405
if exists (fine_selection_flex_mask ):
405
406
# flex attention for the selection for fine attention
@@ -422,7 +423,7 @@ def forward(
422
423
selected_block_indices = pad_at_dim (selected_block_indices , (0 , remainder ), value = 0 , dim = - 2 )
423
424
424
425
if self .use_diff_topk :
425
- gates = pad_at_dim (gates , (0 , remainder ), value = 1. , dim = - 2 )
426
+ gates = pad_at_dim (gates , (0 , remainder ), value = 1. )
426
427
427
428
# handle block causal diagonal in the diagram, but run experiments without to see
428
429
@@ -453,16 +454,7 @@ def forward(
453
454
fk = fk .gather (3 , selected_block_indices )
454
455
fv = fv .gather (3 , selected_block_indices )
455
456
456
- # handle maybe gating
457
-
458
- if self .use_diff_topk :
459
- gates = F .pad (gates , (0 , 1 ), value = 1. )
460
-
461
- fk = einx .multiply ('b h i w, b h i w j d -> b h i w j d' , gates , fk )
462
- fv = einx .multiply ('b h i w, b h i w j d -> b h i w j d' , gates , fv )
463
-
464
- fk = rearrange (fk , 'b h i w j d -> b h i (w j) d' )
465
- fv = rearrange (fv , 'b h i w j d -> b h i (w j) d' )
457
+ fk , fv = tuple (rearrange (t , 'b h i w j d -> b h i (w j) d' ) for t in (fk , fv ))
466
458
467
459
# fine attention
468
460
@@ -483,6 +475,13 @@ def forward(
483
475
fine_attn_out = rearrange (fine_attn_out , 'b h qh ... -> b (h qh) ...' )
484
476
485
477
fine_attn_out = fine_attn_out [..., :seq_len , :]
478
+
479
+ # handle maybe gating
480
+
481
+ if self .use_diff_topk :
482
+ gates = gates [..., :seq_len ]
483
+ fine_attn_out = einx .multiply ('b h n, b h n d -> b h n d' , gates , fine_attn_out )
484
+
486
485
else :
487
486
# if only first block, just do a simple block causal
488
487
0 commit comments