@@ -58,7 +58,7 @@ def sliding_mask(_, __, q_idx, kv_idx):

    block_mask = create_block_mask(sliding_mask, B = None, H = None, Q_LEN = seq_len, KV_LEN = seq_len, _compile = True)
    return block_mask

-def create_compress_mask(seq_len, kv_seq_len, compress_block_size, mem_kv_len = 0, causal = True):
+def create_compress_mask(seq_len, kv_seq_len, compress_block_sliding_stride, mem_kv_len = 0, causal = True):

    if not causal:
        return None
@@ -70,7 +70,7 @@ def compress_mask(_, __, q_idx, kv_idx):
        is_mem_kv = kv_idx < mem_kv_len

        kv_without_mem = kv_idx - mem_kv_len
-        compress_kv_idx = (kv_without_mem * compress_block_size) + (compress_block_size - 1)
+        compress_kv_idx = (kv_without_mem * compress_block_sliding_stride) + (compress_block_sliding_stride - 1)

        causal_mask = q_idx > compress_kv_idx
        return causal_mask | is_mem_kv
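A worked example of the index mapping above (reviewer sketch, not part of the patch): compressed token `kv` stands for the *last* original position inside its window, `kv * stride + (stride - 1)`, so causality is judged against the newest token a compressed block can contain.

```python
# minimal sketch, assuming a sliding stride of 4: compressed token 0 covers
# original positions 0..3, so a query may attend to it only once q_idx > 3
stride = 4
for kv in range(3):
    print(kv, kv * stride + (stride - 1))  # 0 -> 3, 1 -> 7, 2 -> 11
```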
@@ -193,9 +193,9 @@ def __init__(
        heads,
        sliding_window_size,
        compress_block_size,
+        compress_block_sliding_stride,
        selection_block_size,
        num_selected_blocks,
-        compress_block_overlap_len = 0, # the amount of overlap of a given compression block to the previous block
        kv_heads = None,
        num_compressed_mem_kv = 1,
        causal = False,
@@ -261,40 +261,28 @@ def __init__(
        # compress strategy

        self.compress_block_size = compress_block_size
+        self.compress_block_sliding_stride = compress_block_sliding_stride
+        assert self.compress_block_size >= self.compress_block_sliding_stride, 'compress_block_size must be >= compress_block_sliding_stride'
+        assert self.compress_block_sliding_stride > 0, 'compress_block_sliding_stride must be greater than 0'
+        assert divisible_by(selection_block_size, self.compress_block_sliding_stride), f'selection_block_size {selection_block_size} must be divisible by compress_block_sliding_stride {self.compress_block_sliding_stride}'
+
+        # compression window splitting
+        self.split_compress_window = nn.Sequential(
+            Rearrange('b h n d -> (b h) d 1 n'),
+            nn.ZeroPad2d((compress_block_size - compress_block_sliding_stride, 0, 0, 0)),
+            nn.Unfold(kernel_size = (1, self.compress_block_size), stride = (1, self.compress_block_sliding_stride)),
+            Rearrange('(b h) (d n) w -> b h w n d', d = dim_head, h = kv_heads, n = self.compress_block_size)
+        )

        assert num_compressed_mem_kv > 0
-
-        # the function for splitting out the compression windows for the mlp
-
-        compress_block_has_overlap = compress_block_overlap_len > 0
-        compress_window_size = compress_block_size + compress_block_overlap_len
-
-        if not compress_block_has_overlap:
-            split_compress_window_fn = Rearrange('b h (w n) d -> b h w n d', n = compress_block_size)
-        else:
-            split_compress_window_fn = nn.Sequential(
-                Rearrange('b h n d -> (b h) d 1 n'),
-                nn.ZeroPad2d((compress_block_overlap_len, 0, 0, 0)),
-                nn.Unfold(kernel_size = (1, compress_window_size), stride = (1, compress_block_size)),
-                Rearrange('(b h) (d n) w -> b h w n d', d = dim_head, h = kv_heads)
-            )
-
-        self.split_compress_window = split_compress_window_fn
-        self.compress_window_size = compress_window_size
-
-        assert compress_block_overlap_len <= compress_block_size
-        self.compress_block_overlap_len = compress_block_overlap_len
-
-        # compression attention related parameters
-
        self.num_mem_compress_kv = num_compressed_mem_kv
        self.compress_mem_kv = nn.Parameter(torch.zeros(2, kv_heads, num_compressed_mem_kv, dim_head))

-        self.k_intrablock_positions = nn.Parameter(torch.zeros(kv_heads, compress_window_size, dim_head))
-        self.v_intrablock_positions = nn.Parameter(torch.zeros(kv_heads, compress_window_size, dim_head))
+        self.k_intrablock_positions = nn.Parameter(torch.zeros(kv_heads, self.compress_block_size, dim_head))
+        self.v_intrablock_positions = nn.Parameter(torch.zeros(kv_heads, self.compress_block_size, dim_head))

        if not exists(compress_mlp):
-            compress_dim = compress_window_size * dim_head
+            compress_dim = self.compress_block_size * dim_head
            compress_mlp_dim_hidden = int(compress_mlp_expand_factor * compress_dim)

            compress_mlp = nn.Sequential(
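A quick shape check of the new `Unfold`-based splitter (a standalone sketch with made-up toy sizes, not part of the patch): left-padding by `block - stride` zeros and unfolding with the block size as kernel and the stride as step yields `seq_len / stride` overlapping windows of `compress_block_size` tokens each.

```python
import torch
from torch import nn
from einops.layers.torch import Rearrange

b, h, n, d = 2, 2, 16, 3   # toy batch, kv heads, seq len, head dim
block, stride = 8, 4       # compress_block_size, compress_block_sliding_stride

split = nn.Sequential(
    Rearrange('b h n d -> (b h) d 1 n'),
    nn.ZeroPad2d((block - stride, 0, 0, 0)),                  # pad left with the overlap
    nn.Unfold(kernel_size = (1, block), stride = (1, stride)),
    Rearrange('(b h) (d n) w -> b h w n d', d = d, h = h, n = block)
)

print(split(torch.randn(b, h, n, d)).shape)  # torch.Size([2, 2, 4, 8, 3]) -> 16 / 4 = 4 windows
```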
@@ -310,11 +298,7 @@ def __init__(
        # selection related

        self.use_diff_topk = use_diff_topk
-
        self.query_heads_share_selected_kv = query_heads_share_selected_kv
-
-        assert divisible_by(selection_block_size, compress_block_size), f'selection block size {selection_block_size} must be greater than or equal to compress block size {compress_block_size}, as well as divisible by the compress block size'
-
        self.selection_block_size = selection_block_size

        assert num_selected_blocks >= 0
@@ -376,8 +360,6 @@ def forward_inference(
        seq_len = cache_len + 1

        sliding_window = self.sliding_window_size
-        compress_divisible_seq_len = round_down_mult(seq_len, self.compress_block_size)
-        compress_overlap_len = self.compress_block_overlap_len

        fine_divisible_seq_len = round_up_mult(seq_len, self.selection_block_size)
        num_fine_blocks = fine_divisible_seq_len // self.selection_block_size
@@ -435,7 +417,7 @@ def forward_inference(

        running_compress_seq_len = run_k.shape[-2]

-        if divisible_by(running_compress_seq_len, self.compress_block_size + compress_overlap_len):
+        if divisible_by(running_compress_seq_len, self.compress_block_size):
            k_compress_input = rearrange(run_k, 'b h n d -> b h 1 n d')
            v_compress_input = rearrange(run_v, 'b h n d -> b h 1 n d')
@@ -445,6 +427,7 @@ def forward_inference(
            next_ck = self.k_compress(k_compress_input)
            next_cv = self.v_compress(v_compress_input)

+            compress_overlap_len = self.compress_block_size - self.compress_block_sliding_stride
            run_kv_slice = slice(-compress_overlap_len, None) if compress_overlap_len > 0 else slice(0, 0)

            run_k = run_k[..., run_kv_slice, :]
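Sketch of the buffer bookkeeping this hunk introduces (my own illustration, not from the patch): after a window is compressed, only the trailing `block - stride` tokens stay in the running buffer, so the next window overlaps the previous one.

```python
block, stride = 8, 4
overlap = block - stride                 # compress_overlap_len
buffer = list(range(8))                  # pretend these are cached token positions
kept = buffer[-overlap:] if overlap > 0 else []
print(kept)                              # [4, 5, 6, 7] carry over into the next window
```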
@@ -461,9 +444,9 @@ def forward_inference(
        importance_scores = csim[..., self.num_mem_compress_kv:]

        num_compress_blocks = importance_scores.shape[-1]
-        num_compress_per_fine = self.selection_block_size // self.compress_block_size
+        num_compress_per_fine = self.selection_block_size // self.compress_block_sliding_stride

-        if self.compress_block_size != self.selection_block_size:
+        if self.compress_block_sliding_stride != self.selection_block_size:
            compress_seq_len = round_down_mult(num_compress_blocks, num_compress_per_fine)
            importance_scores = importance_scores[..., :compress_seq_len]
            importance_scores = reduce(importance_scores, '... (j num_compress_per_fine) -> ... j', 'mean', num_compress_per_fine = num_compress_per_fine)
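The pooling above now groups compressed-block scores by the sliding stride rather than the block size. A toy-number sketch of the same `einops.reduce` call (not part of the patch):

```python
import torch
from einops import reduce

scores = torch.randn(1, 2, 10)  # (batch, heads, num_compress_blocks)
per_fine = 4                    # selection_block_size // compress_block_sliding_stride

trunc = (scores.shape[-1] // per_fine) * per_fine  # round_down_mult
pooled = reduce(scores[..., :trunc], '... (j r) -> ... j', 'mean', r = per_fine)
print(pooled.shape)             # torch.Size([1, 2, 2]) -> one score per fine block
```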
@@ -582,10 +565,10 @@ def forward(

        batch, seq_len, scale, heads, kv_heads, device = *inp.shape[:2], self.scale, self.heads, self.kv_heads, inp.device

-        compress_divisible_seq_len = round_down_mult(seq_len, self.compress_block_size)
-        num_compress_blocks = compress_divisible_seq_len // self.compress_block_size
+        compress_divisible_seq_len = round_down_mult(seq_len, self.compress_block_sliding_stride)
+        num_compress_blocks = compress_divisible_seq_len // self.compress_block_sliding_stride

-        compress_overlap_len = self.compress_block_overlap_len
+        compress_overlap_len = self.compress_block_size - self.compress_block_sliding_stride
        has_compress_overlap = compress_overlap_len > 0

        fine_divisible_seq_len = round_up_mult(seq_len, self.selection_block_size)
@@ -609,7 +592,7 @@ def forward(
            k_compress_input = self.split_compress_window(k_compress_input)
            v_compress_input = self.split_compress_window(v_compress_input)
        else:
-            k_compress_input, v_compress_input = tuple(t.reshape(batch, kv_heads, 0, self.compress_window_size, self.dim_head) for t in (k_compress_input, v_compress_input))
+            k_compress_input, v_compress_input = tuple(t.reshape(batch, kv_heads, 0, self.compress_block_size, self.dim_head) for t in (k_compress_input, v_compress_input))

        # add the intra block positions
@@ -645,10 +628,10 @@ def forward(
        # compressed masking

        cmask = None
-
+        # TODO
        if self.causal:
            cq_seq = arange(seq_len, device = device)
-            ck_seq = ((arange(num_compress_blocks, device = device) + 1) * self.compress_block_size) - 1
+            ck_seq = ((arange(num_compress_blocks, device = device) + 1) * self.compress_block_sliding_stride) - 1
            ck_seq = F.pad(ck_seq, (num_mem_compress_kv, 0), value = -1)

            cmask = einx.less('j, i -> i j', ck_seq, cq_seq)
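Standalone sketch of the compressed causal mask after this change (toy sizes, not from the patch): each compressed block is pinned to the last original position of its stride window, and the memory kv slots get position -1 so every query can see them.

```python
import torch
import torch.nn.functional as F

seq_len, num_blocks, stride, num_mem = 8, 2, 4, 1
cq_seq = torch.arange(seq_len)
ck_seq = (torch.arange(num_blocks) + 1) * stride - 1  # tensor([3, 7])
ck_seq = F.pad(ck_seq, (num_mem, 0), value = -1)      # tensor([-1, 3, 7])

# equivalent to einx.less('j, i -> i j', ck_seq, cq_seq)
cmask = cq_seq[:, None] > ck_seq[None, :]
print(cmask[0])  # tensor([ True, False, False]) -> query 0 sees only the mem kv
```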
@@ -686,9 +669,9 @@ def forward(

        if has_selected_kv_for_fine_attn:

-            if self.compress_block_size != self.selection_block_size:
+            if self.compress_block_sliding_stride != self.selection_block_size:

-                num_compress_per_fine = self.selection_block_size // self.compress_block_size
+                num_compress_per_fine = self.selection_block_size // self.compress_block_sliding_stride

                round_down_score_len = round_down_mult(importance_scores.shape[-1], num_compress_per_fine)
                importance_scores = importance_scores[..., :round_down_score_len]
@@ -729,10 +712,7 @@ def forward(

        selected_importance_values, selected_block_indices = importance_scores.topk(num_selected, dim = -1)

-        gates = None
-
-        if self.use_diff_topk:
-            gates = straight_through(selected_importance_values, 1.)
+        gates = straight_through(selected_importance_values, 1.) if self.use_diff_topk else None

        if self.use_triton_kernel and not disable_triton_kernel:
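The collapsed one-liner relies on a straight-through estimator for the top-k gates. A sketch assuming the repo's `straight_through` follows the common `t + (target - t).detach()` form (an assumption, not verified against this file):

```python
import torch

def straight_through(t, target):
    # hypothetical stand-in: forward value is `target`, gradient flows to `t`
    return t + (target - t).detach()

scores = torch.tensor([0.3, 0.9], requires_grad = True)
gates = straight_through(scores, 1.)
print(gates)            # tensor([1., 1.], grad_fn=<AddBackward0>)
gates.sum().backward()
print(scores.grad)      # tensor([1., 1.])
```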
@@ -762,9 +742,7 @@ def forward(
        fmask = selected_importance_values > 1e-10

        if seq_len < fine_divisible_seq_len:
-            fk = pad_to_multiple(fk)
-            fv = pad_to_multiple(fv)
-            fq = pad_to_multiple(fq)
+            fk, fv, fq = map(pad_to_multiple, (fk, fv, fq))

            fmask = pad_at_dim(fmask, (0, remainder), value = False, dim = -2)
@@ -846,9 +824,7 @@ def forward(
            seq_len = fk.shape[-2]
            fmask = None

-            fk = pad_to_multiple(fk)
-            fv = pad_to_multiple(fv)
-            fq = pad_to_multiple(fq)
+            fk, fv, fq = map(pad_to_multiple, (fk, fv, fq))

            fq, fk, fv = tuple(rearrange(t, 'b h (w n) d -> (b w) h n d', n = self.selection_block_size) for t in (fq, fk, fv))