@@ -68,6 +68,7 @@ def __init__(
         dim,
         dim_head = 64,
         heads = 8,
+        causal = True,
         kv_heads = None
     ):
         super().__init__()
@@ -78,6 +79,8 @@ def __init__(
         dim_inner = heads * dim_head
         dim_kv_inner = kv_heads * dim_head

+        self.causal = causal
+
         self.rotary_embed = RotaryEmbedding(dim_head)

         self.to_q = nn.Linear(dim, dim_inner, bias = False)
@@ -114,7 +117,7 @@ def forward(

         out = F.scaled_dot_product_attention(
             q, k, v,
-            is_causal = True
+            is_causal = self.causal
         )

         out = self.merge_heads(out)
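For context, `is_causal` in PyTorch's built-in `F.scaled_dot_product_attention` applies a lower-triangular mask, so each query attends only to itself and earlier positions; routing it through `self.causal` makes bidirectional attention selectable. A minimal sketch of the two settings:

```python
import torch
import torch.nn.functional as F

# dummy projected tensors, shaped (batch, heads, seq_len, dim_head)
q = torch.randn(1, 8, 16, 64)
k = torch.randn(1, 8, 16, 64)
v = torch.randn(1, 8, 16, 64)

# causal = True: each position attends only to itself and earlier positions
causal_out = F.scaled_dot_product_attention(q, k, v, is_causal = True)

# causal = False: full bidirectional attention, newly reachable with this change
bidirectional_out = F.scaled_dot_product_attention(q, k, v, is_causal = False)
```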
@@ -146,6 +149,7 @@ def __init__(
         kv_heads = None,
         ff_expansion_factor = 4.,
         use_sparse_attn = False,
+        causal = True,
         use_flex_sliding_window = False,
         use_flex_fine_selection = False,
         use_triton_fine_selection = False,
@@ -164,6 +168,8 @@ def __init__(
         if use_flex_sliding_window or use_flex_fine_selection:
             assert exists(flex_attention), 'flex attention is not available on your current version of pytorch'

+        self.causal = causal
+
         self.use_sparse_attn = use_sparse_attn
         self.use_flex_sliding_window = use_sparse_attn & use_flex_sliding_window
         self.use_flex_fine_selection = use_sparse_attn & use_flex_fine_selection
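The `exists(flex_attention)` assert suggests the module imports flex attention defensively and falls back to `None` on older PyTorch versions; a sketch of such a guard (an assumption, since the import itself sits outside this diff):

```python
# flex attention only exists in recent pytorch (>= 2.5),
# so guard the import and let exists(flex_attention) gate its use
try:
    from torch.nn.attention.flex_attention import flex_attention, create_block_mask
except ImportError:
    flex_attention = None
```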
@@ -177,6 +183,7 @@ def __init__(
             dim_head = dim_head,
             heads = heads,
             kv_heads = kv_heads,
+            causal = causal,
             use_triton_kernel = use_triton_fine_selection,
             **sparse_attn_kwargs
         )
@@ -185,6 +192,7 @@ def __init__(
             dim = dim,
             dim_head = dim_head,
             heads = heads,
+            causal = causal,
             kv_heads = kv_heads
         )

@@ -275,12 +283,12 @@ def forward(

         if not disable_flex and self.use_flex_sliding_window:
             attn_kwargs.update(
-                sliding_window_flex_mask = create_sliding_mask(seq_len, self.attn_sliding_window_size)
+                sliding_window_flex_mask = create_sliding_mask(seq_len, self.attn_sliding_window_size, causal = self.causal)
             )

         if not disable_flex and self.use_flex_fine_selection:
             attn_kwargs.update(
-                fine_selection_flex_mask = create_fine_mask(seq_len, self.attn_fine_block_size)
+                fine_selection_flex_mask = create_fine_mask(seq_len, self.attn_fine_block_size, causal = self.causal)
             )

         # cache
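Only the new `causal = self.causal` keyword in the `create_sliding_mask` / `create_fine_mask` calls is visible here. As a sketch of how a causal-optional sliding window mask can be built with flex attention's `create_block_mask` (the body below is illustrative, not the repo's implementation; only the `(seq_len, window_size, causal)` signature is taken from the calls above):

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask

def create_sliding_mask_sketch(seq_len, window_size, causal = True):
    # mask_mod returns True wherever a query index may attend to a key index
    def mask_mod(b, h, q_idx, kv_idx):
        within_window = (q_idx - kv_idx).abs() <= window_size

        if causal:
            # additionally forbid looking at future positions
            return within_window & (kv_idx <= q_idx)

        return within_window

    # B = None / H = None broadcast the mask over batch and heads;
    # adjust device as needed for your hardware
    return create_block_mask(mask_mod, B = None, H = None, Q_LEN = seq_len, KV_LEN = seq_len, device = 'cuda')
```

With the flag threaded through both the dense and sparse attention paths, passing `causal = False` at construction yields a bidirectional model (e.g. for encoder-style or diffusion setups), while the default `causal = True` keeps the previous autoregressive behavior.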