@@ -315,13 +315,121 @@ def __init__(

        self.combine_heads = nn.Linear(dim_inner, dim, bias = False)

+    def forward_inference(
+        self,
+        inp,
+        cache,
+        return_cache = True
+    ):
+        # destruct cache
+
+        (cache_k, cache_v), (cache_ck, cache_cv) = cache
+
+        # variables
+
+        batch, scale, heads, device = inp.shape[0], self.scale, self.heads, inp.device
+        seq_len = cache_k.shape[-2] + 1
+
+        sliding_window = self.sliding_window_size
+        compress_divisible_seq_len = round_down_mult(seq_len, self.compress_block_size)
+        num_compress_blocks = compress_divisible_seq_len // self.compress_block_size
+
+        fine_divisible_seq_len = round_up_mult(seq_len, self.selection_block_size)
+        num_fine_blocks = fine_divisible_seq_len // self.selection_block_size
+
+        # maybe prenorm
+
+        inp = self.norm(inp)
+
+        # queries, keys, values
+
+        q, k, v = self.to_qkv(inp).split(self.qkv_split, dim = -1)
+
+        q, k, v = map(self.split_heads, (q, k, v))
+
+        # handle cache
+
+        k = cat((cache_k, k), dim = -2)
+        v = cat((cache_v, v), dim = -2)
+
+        if return_cache:
+            cache_kv = (k, v)
+
+        # 1. compressed attn inference
+
+        cq = q
+        ck = cache_ck
+        cv = cache_cv
+
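+        # a new compressed key / value pair is appended only when the cached length plus the incoming token completes another full compress block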
+        if divisible_by(seq_len, self.compress_block_size):
+            k_compress_input = self.split_compress_window(k[..., -self.compress_block_size:, :] + self.k_intrablock_positions)
+            v_compress_input = self.split_compress_window(v[..., -self.compress_block_size:, :] + self.v_intrablock_positions)
+
+            next_ck = self.k_compress(k_compress_input)
+            next_cv = self.v_compress(v_compress_input)
+
+            ck = cat((ck, next_ck), dim = -2)
+            cv = cat((cv, next_cv), dim = -2)
+
+        if return_cache:
+            cache_compressed_kv = (ck, cv)
+
+        ck = repeat(ck, 'b h ... -> b (h gh) ...', gh = self.num_grouped_queries)
+        cv = repeat(cv, 'b h ... -> b (h gh) ...', gh = self.num_grouped_queries)
+
+        csim = einsum(q, ck, 'b h i d, b h j d -> b h i j') * scale
+        cattn = csim.softmax(dim = -1)
+
+        compressed_attn_out = einsum(cattn, cv, 'b h i j, b h j d -> b h i d')
+
+        # 2. fine attention inference (todo)
+
+        # not implemented
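+        # for now the sliding window branch output also fills the fine branch slot when the three strategies are combined below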
+
+        # 3. sliding window
+
+        k = repeat(k, 'b h ... -> b (h gh) ...', gh = self.num_grouped_queries)
+        v = repeat(v, 'b h ... -> b (h gh) ...', gh = self.num_grouped_queries)
+
+        sliding_slice = (Ellipsis, slice(-(sliding_window + 1), None), slice(None))
+        rotated_q, rotated_k = self.rotary_emb.rotate_queries_with_cached_keys(q, k[sliding_slice])
+
+        sim = einsum(rotated_q, rotated_k, 'b h i d, b h j d -> b h i j') * scale
+        attn = sim.softmax(dim = -1)
+        sliding_window_attn_out = einsum(attn, v[sliding_slice], 'b h i j, b h j d -> b h i d')
+
+        # combine strategies
+
+        strategy_weighted_combine = self.to_strategy_combine(inp)
+
+        out = einsum(strategy_weighted_combine, stack([compressed_attn_out, sliding_window_attn_out, sliding_window_attn_out]), 'b h n s, s b h n d -> b h n d')
+
+        # merge heads and combine them
+
+        out = self.merge_heads(out)
+
+        out = self.combine_heads(out)
+
+        if not return_cache:
+            return out
+
+        return out, (cache_kv, cache_compressed_kv)
+
    def forward(
        self,
        inp,
+        cache = None,
        disable_triton_kernel = False,
        sliding_window_flex_mask = None,
-        fine_selection_flex_mask = None
+        fine_selection_flex_mask = None,
+        return_cache = False
    ):
+        is_inferencing = exists(cache)
+
+        if is_inferencing:
+            assert inp.shape[1] == 1, 'input must be single tokens if inferencing with cache key values'
+            return self.forward_inference(inp, cache, return_cache = return_cache)
+
        batch, seq_len, scale, heads, device = *inp.shape[:2], self.scale, self.heads, inp.device

        compress_divisible_seq_len = round_down_mult(seq_len, self.compress_block_size)
@@ -340,6 +448,11 @@ def forward(

        q, k, v = map(self.split_heads, (q, k, v))

+        # handle cache
+
+        if return_cache:
+            cache_kv = (k, v)
+
        # compressed key / values - variables prepended with `c` stands for compressed

        k_pos = repeat(self.k_intrablock_positions, 'h n d -> h (r n) d', r = num_compress_blocks)
@@ -352,6 +465,9 @@ def forward(
        ck = self.k_compress(k_compress_input) # Equation (7) of the Native Sparse Attention paper
        cv = self.v_compress(v_compress_input)

+        if return_cache:
+            cache_compressed_kv = (ck, cv)
+
        # 1. coarse attention over compressed

        mem_ck, mem_cv = repeat(self.compress_mem_kv, 'kv ... -> kv b ...', b = batch)
@@ -570,4 +686,9 @@ def forward(

        out = self.merge_heads(out)

-        return self.combine_heads(out)
+        out = self.combine_heads(out)
+
+        if not return_cache:
+            return out
+
+        return out, (cache_kv, cache_compressed_kv)
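For reference, a minimal sketch of how the cache round-trip introduced in this diff could be driven for incremental decoding. The import path, class name, and hyperparameters below are assumptions for illustration, not taken from the diff, and the fine selection branch is still a todo during cached inference:

import torch
from native_sparse_attention_pytorch import SparseAttention  # assumed import path / class name

attn = SparseAttention(
    dim = 512,
    dim_head = 64,
    heads = 8,
    sliding_window_size = 32,
    compress_block_size = 16,
    selection_block_size = 16,
    num_selected_blocks = 4
)

prompt = torch.randn(1, 128, 512)

# prime the cache with one full forward pass over the prompt
out, cache = attn(prompt, return_cache = True)

# then feed back single tokens along with the cache - forward() asserts a length of 1 and dispatches to forward_inference()
for _ in range(16):
    next_token = torch.randn(1, 1, 512)
    out, cache = attn(next_token, cache = cache, return_cache = True)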