@@ -395,8 +395,10 @@ def forward_inference(
 
         # block causal diagonal
 
+        rotated_q, rotated_k = self.rotary_emb.rotate_queries_with_cached_keys(q, k)
+
         fine_sliding_window = (seq_len % self.selection_block_size) + 1
-        fk = k[..., -fine_sliding_window:, :]
+        fk = rotated_k[..., -fine_sliding_window:, :]
         fv = v[..., -fine_sliding_window:, :]
 
         # select out the sparse kv segments as defined by compressed attention map as importance score
@@ -410,7 +412,7 @@ def forward_inference(
         fine_divisible_seq_len = round_up_mult(seq_len, self.selection_block_size)
         remainder = fine_divisible_seq_len - k.shape[-2]
 
-        sel_fk = pad_at_dim(k, (0, remainder), dim = -2)
+        sel_fk = pad_at_dim(rotated_k, (0, remainder), dim = -2)
         sel_fv = pad_at_dim(v, (0, remainder), dim = -2)
 
         sel_fk = rearrange(sel_fk, 'b h (w j) d -> b h w j d', j = self.selection_block_size)
@@ -430,7 +432,7 @@ def forward_inference(
 
         # remove later
 
-        fq = rearrange(q, 'b (h gh) ... -> b h gh ...', gh = self.num_grouped_queries)
+        fq = rearrange(rotated_q, 'b (h gh) ... -> b h gh ...', gh = self.num_grouped_queries)
 
         fsim = einsum(fq, fk, 'b h gh i d, b h j d -> b h gh i j') * scale
 
@@ -457,7 +459,7 @@ def forward_inference(
 
         strategy_weighted_combine = self.to_strategy_combine(inp)
 
-        out = einsum(strategy_weighted_combine, stack([compressed_attn_out, compressed_attn_out, sliding_window_attn_out]), 'b h n s, s b h n d -> b h n d')
+        out = einsum(strategy_weighted_combine, stack([compressed_attn_out, fine_attn_out, sliding_window_attn_out]), 'b h n s, s b h n d -> b h n d')
 
         # merge heads and combine them
 
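A minimal sketch of the rotary call this commit introduces for the fine selection branch, using rotate_queries_with_cached_keys from lucidrains' rotary-embedding-torch; the shapes and rotary dim below are illustrative assumptions, not the module's actual hyperparameters:

    import torch
    from rotary_embedding_torch import RotaryEmbedding

    rotary_emb = RotaryEmbedding(dim = 32)

    # single-step decoding against a kv cache: one new query, 16 keys total
    q = torch.randn(1, 8, 1, 64)   # (batch, heads, new queries, dim_head)
    k = torch.randn(1, 8, 16, 64)  # (batch, heads, cached + new keys, dim_head)

    # offsets the query's rotary position by the cached key length,
    # then rotates both queries and keys
    rotated_q, rotated_k = rotary_emb.rotate_queries_with_cached_keys(q, k)

Before this change, the inference path sliced fk and built sel_fk from the unrotated k, so the fine branch's attention logits at decode time presumably disagreed with the rotary positions used in the parallel forward.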
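The last hunk fixes the strategy combine, which previously stacked compressed_attn_out twice and silently dropped the fine branch output. A hedged sketch of the corrected three-way gating, with tensor shapes chosen purely for illustration:

    import torch
    from einops import einsum

    b, h, n, d = 1, 8, 4, 64

    compressed_attn_out = torch.randn(b, h, n, d)
    fine_attn_out = torch.randn(b, h, n, d)
    sliding_window_attn_out = torch.randn(b, h, n, d)

    # one gate per strategy (s = 3), per head and position
    gates = torch.randn(b, h, n, 3).softmax(dim = -1)

    # stack along a new leading strategy axis, then mix per head/position
    stacked = torch.stack([compressed_attn_out, fine_attn_out, sliding_window_attn_out])
    out = einsum(gates, stacked, 'b h n s, s b h n d -> b h n d')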