@@ -324,12 +324,19 @@ def forward_inference(
     ):
         # destruct cache

-        (cache_k, cache_v), (cache_ck, cache_cv) = cache
+        (
+            (cache_k, cache_v),
+            (
+                (cache_ck, cache_cv),
+                (run_k, run_v)
+            )
+        ) = cache

         # variables

         batch, scale, heads, device = inp.shape[0], self.scale, self.heads, inp.device
-        seq_len = cache_k.shape[-2] + 1
+        cache_len = cache_k.shape[-2]
+        seq_len = cache_len + 1

         sliding_window = self.sliding_window_size
         compress_divisible_seq_len = round_down_mult(seq_len, self.compress_block_size)
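For readers following the cache plumbing: the decode cache now carries a third pair, the unrotated running keys/values awaiting compression. A minimal sketch of the new layout; the shapes (batch, heads, dim_head, buffer lengths) are illustrative assumptions, not values from the repo:

```python
import torch

batch, heads, dim_head = 1, 8, 64
cache_len = 5

cache_k = torch.randn(batch, heads, cache_len, dim_head)  # rotated keys seen so far
cache_v = torch.randn(batch, heads, cache_len, dim_head)
cache_ck = torch.randn(batch, heads, 1, dim_head)         # compressed key blocks
cache_cv = torch.randn(batch, heads, 1, dim_head)
run_k = torch.randn(batch, heads, 1, dim_head)            # unrotated keys awaiting compression
run_v = torch.randn(batch, heads, 1, dim_head)

cache = ((cache_k, cache_v), ((cache_ck, cache_cv), (run_k, run_v)))

# the destructure from the hunk above
(
    (cache_k, cache_v),
    (
        (cache_ck, cache_cv),
        (run_k, run_v)
    )
) = cache

seq_len = cache_k.shape[-2] + 1  # one new token is decoded per call
```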
@@ -347,7 +354,17 @@ def forward_inference(

         q, k, v = map(self.split_heads, (q, k, v))

-        # handle cache
+        # take care of running k and v for compression, which should NOT be rotated https://arxiv.org/abs/2501.18795
+
+        run_k = cat((run_k, k), dim = -2)
+        run_v = cat((run_v, v), dim = -2)
+
+        # rotate after updating the compression running k/v
+
+        q = self.rotary_emb.rotate_queries_or_keys(q, offset = cache_len)
+        k = self.rotary_emb.rotate_queries_or_keys(k, offset = cache_len)
+
+        # handle cache, which stores the rotated keys and values

         k = cat((cache_k, k), dim = -2)
         v = cat((cache_v, v), dim = -2)
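The q/k entering the kv cache are rotated with `offset = cache_len`, so the single decoded token continues the position sequence instead of restarting at zero (`rotate_queries_or_keys` with an `offset` keyword is the rotary embedding API the diff itself calls). A toy check of that offset arithmetic, using hand-rolled rotary angles rather than the real module:

```python
import torch

def toy_rope_angles(positions, dim_head = 64, base = 10000.):
    # standard rotary frequencies; angles[i, j] = pos_i * freq_j
    freqs = base ** (-torch.arange(0, dim_head, 2).float() / dim_head)
    return torch.einsum('i,j->ij', positions.float(), freqs)

cache_len = 5

# prefill rotates positions 0..5; the decode step for the 6th token must use position 5
decode_angles = toy_rope_angles(torch.tensor([cache_len]))
prefill_angles = toy_rope_angles(torch.arange(cache_len + 1))

assert torch.allclose(decode_angles, prefill_angles[-1:])
```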
@@ -369,18 +386,24 @@ def forward_inference(

         compressed_attn_out = einsum(cattn, repeated_cv, 'b h i j, b h j d -> b h i d')

-        if divisible_by(seq_len, self.compress_block_size):
-            k_compress_input = self.split_compress_window(k[..., -self.compress_block_size:, :] + self.k_intrablock_positions)
-            v_compress_input = self.split_compress_window(v[..., -self.compress_block_size:, :] + self.v_intrablock_positions)
+        running_compress_seq_len = run_k.shape[-2]
+
+        if divisible_by(running_compress_seq_len, self.compress_block_size):
+
+            k_compress_input = self.split_compress_window(run_k + self.k_intrablock_positions)
+            v_compress_input = self.split_compress_window(run_v + self.v_intrablock_positions)

             next_ck = self.k_compress(k_compress_input)
             next_cv = self.v_compress(v_compress_input)

+            run_k = run_k[..., 0:0, :]
+            run_v = run_v[..., 0:0, :]
+
             ck = cat((ck, next_ck), dim = -2)
             cv = cat((cv, next_cv), dim = -2)

         if return_cache:
-            cache_compressed_kv = (ck, cv)
+            cache_compressed_kv = ((ck, cv), (run_k, run_v))

         # 2. fine attention inference (todo - compress and fine diff block sizes)

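The flush logic above is easiest to see in isolation: tokens accumulate unrotated in the running buffer, and once a full compress block is present it is compressed and the buffer is emptied via a zero-length slice (which keeps dtype and device). A self-contained sketch, with `compress` standing in for `self.k_compress`:

```python
import torch

compress_block_size = 4
run_k = torch.empty(1, 8, 0, 64)  # buffer starts empty

def decode_step(run_k, new_k, compress):
    run_k = torch.cat((run_k, new_k), dim = -2)
    next_ck = None
    if run_k.shape[-2] % compress_block_size == 0:
        next_ck = compress(run_k)   # emit one compressed block
        run_k = run_k[..., 0:0, :]  # reset buffer with a zero-length slice
    return run_k, next_ck

compress = lambda block: block.mean(dim = -2, keepdim = True)  # toy pooling

for step in range(6):
    run_k, next_ck = decode_step(run_k, torch.randn(1, 8, 1, 64), compress)
    print(step, run_k.shape[-2], next_ck is not None)  # flushes on step 3
```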
@@ -395,10 +418,8 @@ def forward_inference(

         # block causal diagonal

-        rotated_q, rotated_k = self.rotary_emb.rotate_queries_with_cached_keys(q, k)
-
         fine_sliding_window = (seq_len % self.selection_block_size) + 1
-        fk = rotated_k[..., -fine_sliding_window:, :]
+        fk = k[..., -fine_sliding_window:, :]
         fv = v[..., -fine_sliding_window:, :]

         # select out the sparse kv segments as defined by compressed attention map as importance score
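With rotation now done once up front, the block-causal diagonal slices the already-rotated `k` directly. Tabulating the window formula from the hunk, assuming `selection_block_size = 4`:

```python
# per the formula in the diff, the new token attends to the tokens occupying
# its (partial) selection block
selection_block_size = 4

for seq_len in range(1, 9):
    fine_sliding_window = (seq_len % selection_block_size) + 1
    print(f'seq_len={seq_len}  window={fine_sliding_window}')
```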
@@ -412,7 +433,7 @@ def forward_inference(
         fine_divisible_seq_len = round_up_mult(seq_len, self.selection_block_size)
         remainder = fine_divisible_seq_len - k.shape[-2]

-        sel_fk = pad_at_dim(rotated_k, (0, remainder), dim = -2)
+        sel_fk = pad_at_dim(k, (0, remainder), dim = -2)
         sel_fv = pad_at_dim(v, (0, remainder), dim = -2)

         sel_fk = rearrange(sel_fk, 'b h (w j) d -> b h w j d', j = self.selection_block_size)
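Selection needs the kv padded to a whole number of selection blocks before folding into `(windows, block)` form. A sketch of the shape bookkeeping; `round_up_mult` is reimplemented here to match the helper's name, and the shapes are illustrative:

```python
import torch
from einops import rearrange

def round_up_mult(n, mult):
    return ((n + mult - 1) // mult) * mult

selection_block_size = 4
k = torch.randn(1, 8, 10, 64)  # seq_len = 10
seq_len = k.shape[-2]

fine_divisible_seq_len = round_up_mult(seq_len, selection_block_size)  # 12
remainder = fine_divisible_seq_len - seq_len                           # 2

sel_fk = torch.nn.functional.pad(k, (0, 0, 0, remainder))  # right-pad dim -2
sel_fk = rearrange(sel_fk, 'b h (w j) d -> b h w j d', j = selection_block_size)

assert sel_fk.shape == (1, 8, 3, 4, 64)  # 3 windows of 4 tokens
```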
@@ -438,7 +459,7 @@ def forward_inference(

         # remove later

-        fq = rearrange(rotated_q, 'b (h gh) ... -> b h gh ...', gh = self.num_grouped_queries)
+        fq = rearrange(q, 'b (h gh) ... -> b h gh ...', gh = self.num_grouped_queries)

         fsim = einsum(fq, fk, 'b h gh i d, b h j d -> b h gh i j') * scale

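For the grouped-query score computation: query heads are factored into `(kv_heads, grouped_queries)` so each kv head serves `num_grouped_queries` query heads. A shape-level sketch with assumed head counts:

```python
import torch
from einops import rearrange, einsum

heads, num_grouped_queries, dim_head = 8, 4, 64
q = torch.randn(1, heads * num_grouped_queries, 1, dim_head)  # 32 query heads
fk = torch.randn(1, heads, 10, dim_head)                      # 8 kv heads

fq = rearrange(q, 'b (h gh) ... -> b h gh ...', gh = num_grouped_queries)
fsim = einsum(fq, fk, 'b h gh i d, b h j d -> b h gh i j') * (dim_head ** -0.5)

assert fsim.shape == (1, heads, num_grouped_queries, 1, 10)
```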
@@ -524,12 +545,15 @@ def forward(
         k_compress_input = self.split_compress_window(k[..., :compress_divisible_seq_len, :] + k_pos)
         v_compress_input = self.split_compress_window(v[..., :compress_divisible_seq_len, :] + v_pos)

+        run_k = k[..., compress_divisible_seq_len:, :]
+        run_v = v[..., compress_divisible_seq_len:, :]
+
         cq = q
         ck = self.k_compress(k_compress_input) # Equation (7) of the Native Sparse Attention paper
         cv = self.v_compress(v_compress_input)

         if return_cache:
-            cache_compressed_kv = (ck, cv)
+            cache_compressed_kv = ((ck, cv), (run_k, run_v))

         # 1. coarse attention over compressed

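On the prefill side, the same running buffer is seeded: the prefix that fills whole compress blocks is compressed immediately, and the ragged tail is carried in the cache. A sketch, with `round_down_mult` reimplemented to match the helper's name:

```python
import torch

def round_down_mult(n, mult):
    return (n // mult) * mult

compress_block_size = 4
k = torch.randn(1, 8, 10, 64)
seq_len = k.shape[-2]

compress_divisible_seq_len = round_down_mult(seq_len, compress_block_size)  # 8

to_compress = k[..., :compress_divisible_seq_len, :]  # 8 tokens -> 2 compressed blocks
run_k = k[..., compress_divisible_seq_len:, :]        # 2 leftover tokens seed the buffer

assert to_compress.shape[-2] == 8 and run_k.shape[-2] == 2
```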
@@ -549,7 +573,6 @@ def forward(
         compressed_attn_out, csim = attend(cq, ck, cv, mask = cmask, return_sim = True)

         # for 2. and 3., will give them relative positions with rotary - compressed needs to be handled separately (even if they already have intra block absolute positions)
-
         rotated_q, rotated_k = self.rotary_emb.rotate_queries_with_cached_keys(q, k)

         # 2. fine attention over selected based on compressed attention logits - variables prepended with `f` stand for the fine attention pathway