@@ -361,7 +361,7 @@ def forward_inference(
361
361
362
362
# rotate after updating the compression running k/v
363
363
364
- q = self .rotary_emb .rotate_queries_or_keys (q , offset = cache_len )
364
+ rotated_q = self .rotary_emb .rotate_queries_or_keys (q , offset = cache_len )
365
365
k = self .rotary_emb .rotate_queries_or_keys (k , offset = cache_len )
366
366
367
367
# handle cache, which stores the rotated
@@ -459,7 +459,7 @@ def forward_inference(
459
459
460
460
# remove later
461
461
462
- fq = rearrange (q , 'b (h gh) ... -> b h gh ...' , gh = self .num_grouped_queries )
462
+ fq = rearrange (rotated_q , 'b (h gh) ... -> b h gh ...' , gh = self .num_grouped_queries )
463
463
464
464
fsim = einsum (fq , fk , 'b h gh i d, b h j d -> b h gh i j' ) * scale
465
465
@@ -476,11 +476,12 @@ def forward_inference(
476
476
v = repeat (v , 'b h ... -> b (h gh) ...' , gh = self .num_grouped_queries )
477
477
478
478
sliding_slice = (Ellipsis , slice (- (sliding_window + 1 ), None ), slice (None ))
479
- rotated_q , rotated_k = self .rotary_emb .rotate_queries_with_cached_keys (q , k [sliding_slice ])
480
479
481
- sim = einsum (rotated_q , rotated_k , 'b h i d, b h j d -> b h i j' ) * scale
480
+ k , v = k [sliding_slice ], v [sliding_slice ]
481
+
482
+ sim = einsum (rotated_q , k , 'b h i d, b h j d -> b h i j' ) * scale
482
483
attn = sim .softmax (dim = - 1 )
483
- sliding_window_attn_out = einsum (attn , v [ sliding_slice ] , 'b h i j, b h j d -> b h i d' )
484
+ sliding_window_attn_out = einsum (attn , v , 'b h i j, b h j d -> b h i d' )
484
485
485
486
# combine strategies
486
487
@@ -630,8 +631,8 @@ def forward(
630
631
631
632
# handle if number of total blocks is less than number to select for fine attention
632
633
633
- fq = rotated_q
634
- fk = rotated_k
634
+ fq = q
635
+ fk = k
635
636
fv = v
636
637
637
638
if has_selected_kv_for_fine_attn :
@@ -757,8 +758,8 @@ def forward(
757
758
758
759
# 3. overlapping sliding window, this is unsurprising and expected - `s` for sliding
759
760
760
- sq = rotated_q
761
- sk = rotated_k
761
+ sq = q
762
+ sk = k
762
763
sv = v
763
764
764
765
if exists (sliding_window_flex_mask ):
0 commit comments