@@ -290,6 +290,9 @@ def __init__(
         self.split_compress_window = split_compress_window_fn
         self.compress_window_size = compress_window_size
 
+        assert compress_block_overlap_len < compress_block_size
+        self.compress_block_overlap_len = compress_block_overlap_len
+
         # compression attention related parameters
 
         self.num_mem_compress_kv = num_compressed_mem_kv
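Note: the snippet below is an illustrative sketch, not code from this commit. Assuming each compression window covers compress_block_size new tokens plus compress_block_overlap_len tokens carried over from the previous block (which is why the overlap must be strictly smaller than the block size), overlapping windows could be carved out of a key tensor roughly like this:

# hypothetical illustration of overlapping compression windows
import torch

compress_block_size = 4
compress_block_overlap_len = 2          # must be < compress_block_size
window = compress_block_size + compress_block_overlap_len

k = torch.randn(1, 2, 16, 8)            # (batch, heads, seq, dim)

# left-pad the sequence so the first window also has an overlap region
k_padded = torch.nn.functional.pad(k, (0, 0, compress_block_overlap_len, 0))

# window length = block + overlap, stride = block size
windows = k_padded.unfold(-2, window, compress_block_size)  # (b, h, w, d, window)
windows = windows.movedim(-1, -2)                           # (b, h, w, window, d)

print(windows.shape)                    # torch.Size([1, 2, 4, 6, 8])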
@@ -382,6 +385,7 @@ def forward_inference(
 
         sliding_window = self.sliding_window_size
         compress_divisible_seq_len = round_down_mult(seq_len, self.compress_block_size)
+        compress_overlap_len = self.compress_block_overlap_len
 
         fine_divisible_seq_len = round_up_mult(seq_len, self.selection_block_size)
         num_fine_blocks = fine_divisible_seq_len // self.selection_block_size
@@ -439,19 +443,20 @@ def forward_inference(
 
         running_compress_seq_len = run_k.shape[-2]
 
-        if divisible_by(running_compress_seq_len, self.compress_block_size):
-
-            k_compress_input = self.split_compress_window(run_k)
-            v_compress_input = self.split_compress_window(run_v)
+        if divisible_by(running_compress_seq_len, self.compress_block_size + compress_overlap_len):
+            k_compress_input = rearrange(run_k, 'b h n d -> b h 1 n d')
+            v_compress_input = rearrange(run_v, 'b h n d -> b h 1 n d')
 
             k_compress_input = einx.add('b h w n d, h n d', k_compress_input, self.k_intrablock_positions)
             v_compress_input = einx.add('b h w n d, h n d', v_compress_input, self.v_intrablock_positions)
 
             next_ck = self.k_compress(k_compress_input)
             next_cv = self.v_compress(v_compress_input)
 
-            run_k = run_k[..., 0:0, :]
-            run_v = run_v[..., 0:0, :]
+            run_kv_slice = slice(-compress_overlap_len, None) if compress_overlap_len > 0 else slice(0, 0)
+
+            run_k = run_k[..., run_kv_slice, :]
+            run_v = run_v[..., run_kv_slice, :]
 
             ck = cat((ck, next_ck), dim = -2)
             cv = cat((cv, next_cv), dim = -2)
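Note: the snippet below is an illustrative sketch of the slicing edge case handled above, not code from this commit. A negative start of -overlap keeps the last overlap tokens of the running keys/values, but when the overlap is zero, a slice starting at -0 would keep everything, so an explicitly empty slice(0, 0) is needed:

# hypothetical demonstration of the run_kv_slice edge case
import torch

run_k = torch.randn(1, 2, 6, 8)          # (batch, heads, seq, dim)

for overlap in (2, 0):
    sl = slice(-overlap, None) if overlap > 0 else slice(0, 0)
    kept = run_k[..., sl, :]
    print(overlap, kept.shape[-2])        # 2 -> 2 tokens kept, 0 -> 0 tokens kept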
@@ -593,6 +598,8 @@ def forward(
         compress_divisible_seq_len = round_down_mult(seq_len, self.compress_block_size)
         num_compress_blocks = compress_divisible_seq_len // self.compress_block_size
 
+        compress_overlap_len = self.compress_block_overlap_len
+
         fine_divisible_seq_len = round_up_mult(seq_len, self.selection_block_size)
         num_fine_blocks = fine_divisible_seq_len // self.selection_block_size
 
@@ -622,8 +629,14 @@ def forward(
         k_compress_input = einx.add('b h w n d, h n d', k_compress_input, self.k_intrablock_positions)
         v_compress_input = einx.add('b h w n d, h n d', v_compress_input, self.v_intrablock_positions)
 
-        run_k = k[..., compress_divisible_seq_len:, :]
-        run_v = v[..., compress_divisible_seq_len:, :]
+        run_k, run_v = k, v
+
+        if return_cache and compress_overlap_len > 0:
+            run_k = F.pad(run_k, (0, 0, compress_overlap_len, 0), value = 0.)
+            run_v = F.pad(run_v, (0, 0, compress_overlap_len, 0), value = 0.)
+
+        run_k = run_k[..., compress_divisible_seq_len:, :]
+        run_v = run_v[..., compress_divisible_seq_len:, :]
 
         cq = q
         ck = self.k_compress(k_compress_input) # Equation (7) of the Native Sparse Attention paper
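Note: the snippet below is an illustrative sketch of the padding call used above, not code from this commit. In F.pad the pad tuple is read from the last dimension backwards, so (0, 0, compress_overlap_len, 0) leaves the feature dimension alone and left-pads the sequence dimension (dim -2) with overlap-many zero positions:

# hypothetical demonstration of left-padding the sequence dimension
import torch
import torch.nn.functional as F

overlap = 2
k = torch.randn(1, 2, 5, 8)               # (batch, heads, seq, dim)

k_padded = F.pad(k, (0, 0, overlap, 0), value = 0.)
print(k_padded.shape)                                   # torch.Size([1, 2, 7, 8])
print(k_padded[..., :overlap, :].abs().sum().item())    # 0.0 -> left padding is zeros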