 # flex attn sliding attention mask

-def create_sliding_mask(seq_len, window_size):
+def create_sliding_mask(seq_len, window_size, causal = True):
     def sliding_mask(_, __, q_idx, kv_idx):
-        causal_mask = q_idx >= kv_idx

-        sliding_mask = (q_idx - kv_idx) <= window_size
-        causal_mask = causal_mask & sliding_mask
+        distance = q_idx - kv_idx
+        mask = distance <= window_size

-        return causal_mask
+        if causal:
+            mask = mask & (q_idx >= kv_idx)
+        else:
+            mask = mask & (distance >= -window_size)
+
+        return mask

     block_mask = create_block_mask(sliding_mask, B = None, H = None, Q_LEN = seq_len, KV_LEN = seq_len, _compile = True)
     return block_mask

-def create_compress_mask(seq_len, kv_seq_len, compress_block_size, mem_kv_len = 0):
+def create_compress_mask(seq_len, kv_seq_len, compress_block_size, mem_kv_len = 0, causal = True):
+
+    if not causal:
+        return None
+
     # cannot be used as using attention logits for importance score
     # but just to show the immense potential of flex attention
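For orientation, a minimal usage sketch of the patched helper above (assuming torch >= 2.5 with flex attention available; tensor sizes are hypothetical). With causal = False the window becomes symmetric, so each query sees up to window_size tokens on both sides:

import torch
from torch.nn.attention.flex_attention import flex_attention

# hypothetical sizes; create_sliding_mask is the helper patched above
seq_len, window_size = 256, 32
q = k = v = torch.randn(1, 8, seq_len, 64)

block_mask = create_sliding_mask(seq_len, window_size, causal = False)  # attend within +/- window_size
out = flex_attention(q, k, v, block_mask = block_mask)                  # (1, 8, 256, 64)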
@@ -69,7 +77,7 @@ def compress_mask(_, __, q_idx, kv_idx):
     block_mask = create_block_mask(compress_mask, B = None, H = None, Q_LEN = seq_len, KV_LEN = kv_seq_len + mem_kv_len, _compile = True)
     return block_mask

-def create_fine_mask(seq_len, fine_block_size):
+def create_fine_mask(seq_len, fine_block_size, causal = True):

     def inner(selected_block_indices: Tensor, num_grouped_queries = 1):
         device = selected_block_indices.device
@@ -86,6 +94,9 @@ def fine_mask(b_idx, h_idx, q_idx, kv_idx):

             is_selected = one_hot_selected_block_indices[b_idx, kv_head_idx, q_idx, compressed_kv_idx]

+            if not causal:
+                return is_selected
+
             causal_mask = q_idx >= kv_idx
             block_diagonal = compressed_q_idx == compressed_kv_idx

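In the non-causal path above, fine_mask reduces to a block membership test: is the kv token's block among the selected blocks for this query? A self-contained sketch of that idea with hypothetical sizes:

import torch

# hypothetical: 6 fine blocks of size 3, two of them selected for one (batch, head, query) entry
fine_block_size, num_blocks = 3, 6
selected = torch.tensor([1, 4])

one_hot = torch.zeros(num_blocks, dtype = torch.bool)
one_hot[selected] = True

kv_idx = torch.arange(num_blocks * fine_block_size)
compressed_kv_idx = kv_idx // fine_block_size

is_selected = one_hot[compressed_kv_idx]   # True for every kv token inside a selected block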
@@ -189,6 +200,7 @@ def __init__(
         num_selected_blocks,
         kv_heads = None,
         num_compressed_mem_kv = 1,
+        causal = False,
         norm = True,
         use_diff_topk = False,
         use_triton_kernel = False,
@@ -219,6 +231,10 @@ def __init__(

         self.norm = nn.RMSNorm(dim) if norm else nn.Identity()

+        # autoregressive or not - will extend this work for long context video / genomics use-cases
+
+        self.causal = causal
+
         # rotary

         self.rotary_emb = RotaryEmbedding(dim_head)
@@ -236,7 +252,7 @@ def __init__(
         self.sliding_window = LocalAttention(
             dim = dim_head,
             window_size = sliding_window_size,
-            causal = True,
+            causal = causal,
             exact_windowsize = True,
             autopad = True,
             use_rotary_pos_emb = False
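A hedged end-to-end sketch of the new flag, assuming the enclosing class is this repository's SparseAttention module and that the remaining constructor arguments keep their existing meanings:

import torch
from native_sparse_attention_pytorch import SparseAttention

attn = SparseAttention(
    dim = 512,
    dim_head = 64,
    heads = 8,
    sliding_window_size = 2,
    compress_block_size = 4,
    selection_block_size = 4,
    num_selected_blocks = 2,
    causal = False               # the new flag - bidirectional attention for video / genomics style inputs
)

tokens = torch.randn(1, 128, 512)
out = attn(tokens)               # (1, 128, 512)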
@@ -322,6 +338,8 @@ def forward_inference(
         cache,
         return_cache = True
     ):
+        assert self.causal, 'inference only relevant for autoregressive'
+
         # destruct cache

         (
@@ -515,6 +533,8 @@ def forward(
             assert inp.shape[1] == 1, 'input must be single tokens if inferencing with cache key values'
             return self.forward_inference(inp, cache, return_cache = return_cache)

+        assert not (self.causal and return_cache)
+
         batch, seq_len, scale, heads, device = *inp.shape[:2], self.scale, self.heads, inp.device

         compress_divisible_seq_len = round_down_mult(seq_len, self.compress_block_size)
@@ -560,11 +580,16 @@ def forward(
             ck = cat((mem_ck, ck), dim = -2)
             cv = cat((mem_cv, cv), dim = -2)

-        cq_seq = arange(seq_len, device = device)
-        ck_seq = ((arange(num_compress_blocks, device = device) + 1) * self.compress_block_size) - 1
-        ck_seq = F.pad(ck_seq, (num_mem_compress_kv, 0), value = -1)
+        # compressed masking
+
+        cmask = None

-        cmask = einx.less('j, i -> i j', ck_seq, cq_seq)
+        if self.causal:
+            cq_seq = arange(seq_len, device = device)
+            ck_seq = ((arange(num_compress_blocks, device = device) + 1) * self.compress_block_size) - 1
+            ck_seq = F.pad(ck_seq, (num_mem_compress_kv, 0), value = -1)
+
+            cmask = einx.less('j, i -> i j', ck_seq, cq_seq)

         compressed_attn_out, csim = attend(cq, ck, cv, mask = cmask, return_sim = True)
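To make the causal branch above concrete: a query may attend a compressed block only if that block's last position lies strictly before the query, while the memory kv positions (padded in with -1) stay visible to every query. An equivalent plain-torch sketch with hypothetical sizes:

import torch
import torch.nn.functional as F

# hypothetical sizes
seq_len, compress_block_size, num_mem_compress_kv = 8, 4, 1
num_compress_blocks = seq_len // compress_block_size

cq_seq = torch.arange(seq_len)
ck_seq = ((torch.arange(num_compress_blocks) + 1) * compress_block_size) - 1   # last position of each compressed block
ck_seq = F.pad(ck_seq, (num_mem_compress_kv, 0), value = -1)                   # mem kv always visible

cmask = cq_seq[:, None] > ck_seq[None, :]   # same boolean as einx.less('j, i -> i j', ck_seq, cq_seq)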
@@ -657,7 +682,8 @@ def forward(
                 self.selection_block_size,
                 selected_block_indices,
                 fmask,
-                sel_scale = gates
+                sel_scale = gates,
+                include_block_diagonal = self.causal
             )

         elif exists(fine_selection_flex_mask):
@@ -685,19 +711,23 @@ def forward(
             if exists(gates):
                 gates = pad_at_dim(gates, (0, remainder), value = 0, dim = -2)

-            # handle block causal diagonal in the diagram, but run experiments without to see
+            if self.causal:
+                # handle block causal diagonal in the diagram, but run experiments without to see
+
+                fine_window_seq = arange(fine_divisible_seq_len, device = device) // self.selection_block_size
+                fine_window_seq = repeat(fine_window_seq, 'n -> b h n 1', b = batch, h = selected_block_indices.shape[1])
+                selected_block_indices = cat((selected_block_indices, fine_window_seq), dim = -1) # for the block causal diagonal in fig2

-            fine_window_seq = arange(fine_divisible_seq_len, device = device) // self.selection_block_size
-            fine_window_seq = repeat(fine_window_seq, 'n -> b h n 1', b = batch, h = selected_block_indices.shape[1])
-            selected_block_indices = cat((selected_block_indices, fine_window_seq), dim = -1) # for the block causal diagonal in fig2
+                fmask = repeat(fmask, 'b h i w -> b h i w j', j = self.selection_block_size)

-            fmask = repeat(fmask, 'b h i w -> b h i w j', j = self.selection_block_size)
+                causal_mask = torch.ones((self.selection_block_size,) * 2, device = device, dtype = torch.bool).tril()
+                causal_mask = repeat(causal_mask, 'i j -> b h (w i) 1 j', w = num_fine_blocks, b = batch, h = fmask.shape[1])

-            causal_mask = torch.ones((self.selection_block_size,) * 2, device = device, dtype = torch.bool).tril()
-            causal_mask = repeat(causal_mask, 'i j -> b h (w i) 1 j', w = num_fine_blocks, b = batch, h = fmask.shape[1])
+                fmask = cat((fmask, causal_mask), dim = -2)
+                fmask = rearrange(fmask, 'b h i w j -> b h 1 i (w j)')

-            fmask = cat((fmask, causal_mask), dim = -2)
-            fmask = rearrange(fmask, 'b h i w j -> b h i (w j)')
+            else:
+                fmask = repeat(fmask, 'b h i w -> b h 1 i (w j)', j = self.selection_block_size)

             # select out the spatial crops of keys / values for fine attention

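The new non-causal branch above only needs to broadcast the per-block selection mask down to individual kv positions; a small shape-check sketch with hypothetical sizes:

import torch
from einops import repeat

# hypothetical sizes
b, h, i, num_sel, selection_block_size = 1, 2, 4, 3, 5
fmask = torch.rand(b, h, i, num_sel) < 0.5   # per selected block boolean mask

expanded = repeat(fmask, 'b h i w -> b h 1 i (w j)', j = selection_block_size)
assert expanded.shape == (b, h, 1, i, num_sel * selection_block_size)   # every token of a selected block is attendable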
@@ -721,7 +751,9 @@ def forward(
             # differential topk gating

             if self.use_diff_topk:
-                gates = F.pad(gates, (0, 1), value = 1.)
+                if self.causal:
+                    gates = F.pad(gates, (0, 1), value = 1.)
+
                 fk = einx.multiply('b h i sel, b h i sel j d -> b h i sel j d', gates, fk)

             # merge selected key values
@@ -730,8 +762,6 @@ def forward(

             # fine attention

-            fmask = rearrange(fmask, 'b h ... -> b h 1 ...')
-
             fq = rearrange(fq, 'b (h qh) ... -> b h qh ...', qh = fine_num_grouped_queries)

             fsim = einsum(fq, fk, 'b h qh i d, b h i j d -> b h qh i j') * self.scale
@@ -752,7 +782,10 @@ def forward(
             # if only first block, just do a simple block causal

             seq_len = fk.shape[-2]
-            fmask = causal_mask = torch.ones((seq_len, seq_len), device = device, dtype = torch.bool).tril()
+            fmask = None
+
+            if self.causal:
+                fmask = causal_mask = torch.ones((seq_len, seq_len), device = device, dtype = torch.bool).tril()

             fine_attn_out = attend(fq, fk, fv, mask = fmask)