@@ -461,16 +461,12 @@ def forward_inference(
461
461
importance_scores = csim [..., self .num_mem_compress_kv :]
462
462
463
463
num_compress_blocks = importance_scores .shape [- 1 ]
464
+ num_compress_per_fine = self .selection_block_size // self .compress_block_size
464
465
465
466
if self .compress_block_size != self .selection_block_size :
466
- compress_seq_len = num_compress_blocks * self .compress_block_size
467
-
468
- importance_scores = repeat (importance_scores , '... j -> ... (bsz j)' , bsz = self .compress_block_size )
469
-
470
- fine_seq_len = round_down_mult (compress_seq_len , self .selection_block_size )
471
-
472
- importance_scores = importance_scores [..., :fine_seq_len ]
473
- importance_scores = reduce (importance_scores , '... (bsz j) -> ... j' , 'mean' , bsz = self .selection_block_size )
467
+ compress_seq_len = round_down_mult (num_compress_blocks , num_compress_per_fine )
468
+ importance_scores = importance_scores [..., :compress_seq_len ]
469
+ importance_scores = reduce (importance_scores , '... (j num_compress_per_fine) -> ... j' , 'mean' , num_compress_per_fine = num_compress_per_fine )
474
470
475
471
num_fine_blocks = importance_scores .shape [- 1 ]
476
472
num_selected = min (self .num_selected_blocks , num_fine_blocks )
@@ -490,7 +486,9 @@ def forward_inference(
490
486
if self .query_heads_share_selected_kv :
491
487
importance_scores = reduce (importance_scores , 'b (h grouped_queries) ... -> b h ...' , 'mean' , grouped_queries = self .num_grouped_queries )
492
488
489
+ importance_scores = F .pad (importance_scores , (1 , 0 ), value = - 1e3 )
493
490
importance_scores = importance_scores .softmax (dim = - 1 )
491
+ importance_scores = importance_scores [..., 1 :]
494
492
495
493
sel_scores , sel_indices = importance_scores .topk (num_selected , dim = - 1 )
496
494
@@ -689,26 +687,24 @@ def forward(
689
687
690
688
if self .compress_block_size != self .selection_block_size :
691
689
692
- compress_seq_len = num_compress_blocks * self .compress_block_size
693
-
694
- importance_scores = repeat (importance_scores , '... j -> ... (j block_size)' , block_size = self .compress_block_size )
690
+ num_compress_per_fine = self .selection_block_size // self .compress_block_size
695
691
696
- padding = fine_divisible_seq_len - compress_seq_len
692
+ round_down_score_len = round_down_mult (importance_scores .shape [- 1 ], num_compress_per_fine )
693
+ importance_scores = importance_scores [..., :round_down_score_len ]
697
694
698
- fine_query_seq_len = importance_scores . shape [ - 2 ]
699
- fine_query_padding = fine_divisible_seq_len - importance_scores . shape [ - 2 ]
695
+ if not is_empty ( importance_scores ):
696
+ importance_scores = reduce ( importance_scores , '... (j num_compress_per_fine) -> ... j' , 'mean' , num_compress_per_fine = num_compress_per_fine )
700
697
701
- importance_scores = F . pad ( importance_scores , ( 0 , padding ))
698
+ i , j = importance_scores . shape [ - 2 :]
702
699
703
- # mask out the diagonal since block causal is included by default for fine attending
700
+ # mask out block diagonal
704
701
705
- block_causal_mask = torch .ones ((num_fine_blocks ,) * 2 , device = device , dtype = torch .bool ).tril (- 1 )
706
- block_causal_mask = repeat (block_causal_mask , 'i j -> (i n1) (j n2)' , n1 = self .selection_block_size , n2 = self .selection_block_size )
707
- block_causal_mask = block_causal_mask [:fine_query_seq_len ]
702
+ q_seq = arange (i , device = device ) // self .selection_block_size
703
+ k_seq = arange (j , device = device )
708
704
709
- importance_scores = importance_scores . masked_fill ( ~ block_causal_mask , max_neg_value ( csim ) )
705
+ block_diagonal_mask = einx . equal ( 'i, j -> i j' , q_seq , k_seq )
710
706
711
- importance_scores = reduce ( importance_scores , '... (j block_size) -> ... j' , 'mean' , block_size = self . selection_block_size )
707
+ importance_scores = importance_scores . masked_fill ( block_diagonal_mask , max_neg_value ( csim ) )
712
708
713
709
importance_scores = F .pad (importance_scores , (1 , 0 ), value = - 1e3 )
714
710
importance_scores = importance_scores .softmax (dim = - 1 )
0 commit comments