@@ -252,6 +252,7 @@ def __init__(
        self.split_compress_window = Rearrange('b h (w n) d -> b h w n d', n = compress_block_size)

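+        # keep the number of memory compressed kv around, so inference can strip them from the importance scores below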
+        self.num_mem_compress_kv = num_compressed_mem_kv
        self.compress_mem_kv = nn.Parameter(torch.zeros(2, kv_heads, num_compressed_mem_kv, dim_head))

        self.k_intrablock_positions = nn.Parameter(torch.zeros(kv_heads, compress_block_size, dim_head))
@@ -332,7 +333,6 @@ def forward_inference(
        sliding_window = self.sliding_window_size
        compress_divisible_seq_len = round_down_mult(seq_len, self.compress_block_size)
-        num_compress_blocks = compress_divisible_seq_len // self.compress_block_size

        fine_divisible_seq_len = round_up_mult(seq_len, self.selection_block_size)
        num_fine_blocks = fine_divisible_seq_len // self.selection_block_size
@@ -361,6 +361,14 @@ def forward_inference(
        ck = cache_ck
        cv = cache_cv

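+        # attend the incoming queries over the cached compressed kv (which include the learned memory kv) - csim is reused further down as the block importance scores for fine selection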
+        repeated_ck = repeat(ck, 'b h ... -> b (h gh) ...', gh = self.num_grouped_queries)
+        repeated_cv = repeat(cv, 'b h ... -> b (h gh) ...', gh = self.num_grouped_queries)
+
+        csim = einsum(q, repeated_ck, 'b h i d, b h j d -> b h i j') * scale
+        cattn = csim.softmax(dim = -1)
+
+        compressed_attn_out = einsum(cattn, repeated_cv, 'b h i j, b h j d -> b h i d')
+
        if divisible_by(seq_len, self.compress_block_size):
            k_compress_input = self.split_compress_window(k[..., -self.compress_block_size:, :] + self.k_intrablock_positions)
            v_compress_input = self.split_compress_window(v[..., -self.compress_block_size:, :] + self.v_intrablock_positions)
@@ -374,17 +382,64 @@ def forward_inference(
        if return_cache:
            cache_compressed_kv = (ck, cv)

-        ck = repeat(ck, 'b h ... -> b (h gh) ...', gh = self.num_grouped_queries)
-        cv = repeat(cv, 'b h ... -> b (h gh) ...', gh = self.num_grouped_queries)
+        # 2. fine attention inference (todo - handle compress and fine having different block sizes)

-        csim = einsum(q, ck, 'b h i d, b h j d -> b h i j') * scale
-        cattn = csim.softmax(dim = -1)
+        assert self.compress_block_size == self.selection_block_size
+
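+        # strip the learned memory kv columns so only real compressed blocks get scored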
+        importance_scores = csim[..., self.num_mem_compress_kv:]
+        importance_scores += torch.randn_like(importance_scores) * 100
+
+        num_compress_blocks = importance_scores.shape[-1]
+        num_selected = min(self.num_selected_blocks, num_compress_blocks)
+        has_selected_kv_for_fine_attn = num_selected > 0
+
+        # block causal diagonal - the fine branch always attends to the keys / values of the current, partially filled selection block
+
+        fine_sliding_window = (seq_len % self.selection_block_size) + 1
+        fk = k[..., -fine_sliding_window:, :]
+        fv = v[..., -fine_sliding_window:, :]
+
+        # select out the sparse kv segments, using the compressed attention map as the importance score
+
+        if has_selected_kv_for_fine_attn:
+            if self.query_heads_share_selected_kv:
+                importance_scores = reduce(importance_scores, 'b (h grouped_queries) ... -> b h ...', 'mean', grouped_queries = self.num_grouped_queries)
+
+            sel_scores, sel_indices = importance_scores.topk(num_selected, dim = -1)
+
+            fine_divisible_seq_len = round_up_mult(seq_len, self.selection_block_size)
+            remainder = fine_divisible_seq_len - k.shape[-2]
+
+            sel_fk = pad_at_dim(k, (0, remainder), dim = -2)
+            sel_fv = pad_at_dim(v, (0, remainder), dim = -2)
+
+            sel_fk = rearrange(sel_fk, 'b h (w j) d -> b h w j d', j = self.selection_block_size)
+            sel_fv = rearrange(sel_fv, 'b h (w j) d -> b h w j d', j = self.selection_block_size)
+
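+            # gather the selected blocks per head and flatten them back into a contiguous kv sequence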
+            sel_fk = einx.get_at('b h [w] j d, b h 1 sel -> b h (sel j) d', sel_fk, sel_indices)
+            sel_fv = einx.get_at('b h [w] j d, b h 1 sel -> b h (sel j) d', sel_fv, sel_indices)
+
+            fmask = sel_scores > 1e-10
+
+            fmask = repeat(fmask, 'b h i sel -> b h i (sel j)', j = self.selection_block_size)
+
+            fk = cat((sel_fk, fk), dim = -2)
+            fv = cat((sel_fv, fv), dim = -2)
+
+            fmask = F.pad(fmask, (0, fk.shape[-2] - fmask.shape[-1]), value = True)
+
+        # remove later
+
+        fq = rearrange(q, 'b (h gh) ... -> b h gh ...', gh = self.num_grouped_queries)
+
+        fsim = einsum(fq, fk, 'b h gh i d, b h j d -> b h gh i j') * scale

-        compressed_attn_out = einsum(cattn, cv, 'b h i j, b h j d -> b h i d')
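+        # mask out positions from selected blocks whose score did not clear the threshold (note: fmask is only defined when at least one block was selected above)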
+        fsim = einx.where('b h i j, b h gh i j, -> b h gh i j', fmask, fsim, max_neg_value(fsim))

-        # 2. fine attention inference (todo)
+        fattn = fsim.softmax(dim = -1)

-        # not implemented
+        fine_attn_out = einsum(fattn, fv, 'b h gh i j, b h j d -> b h gh i d')
+        fine_attn_out = rearrange(fine_attn_out, 'b h gh ... -> b (h gh) ...')

        # 3. sliding window

@@ -402,7 +457,7 @@ def forward_inference(
        strategy_weighted_combine = self.to_strategy_combine(inp)

-        out = einsum(strategy_weighted_combine, stack([compressed_attn_out, sliding_window_attn_out, sliding_window_attn_out]), 'b h n s, s b h n d -> b h n d')
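+        # gated combination of the three branch outputs (compressed, fine, sliding window)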
+        out = einsum(strategy_weighted_combine, stack([compressed_attn_out, fine_attn_out, sliding_window_attn_out]), 'b h n s, s b h n d -> b h n d')

        # merge heads and combine them