
Commit 1145e37

causal flag for the transformer and setting correct flag for flex attention block mask creation
1 parent c00e698 commit 1145e37

File tree

2 files changed: +12 −4 lines changed

native_sparse_attention_pytorch/transformer.py
pyproject.toml

native_sparse_attention_pytorch/transformer.py

Lines changed: 11 additions & 3 deletions
```diff
@@ -68,6 +68,7 @@ def __init__(
         dim,
         dim_head = 64,
         heads = 8,
+        causal = True,
         kv_heads = None
     ):
         super().__init__()
@@ -78,6 +79,8 @@ def __init__(
         dim_inner = heads * dim_head
         dim_kv_inner = kv_heads * dim_head
 
+        self.causal = causal
+
         self.rotary_embed = RotaryEmbedding(dim_head)
 
         self.to_q = nn.Linear(dim, dim_inner, bias = False)
@@ -114,7 +117,7 @@ def forward(
 
         out = F.scaled_dot_product_attention(
             q, k, v,
-            is_causal = True
+            is_causal = self.causal
         )
 
         out = self.merge_heads(out)
```
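
Context for the change above: the dense attention branch now feeds the stored flag into PyTorch's scaled dot product attention instead of hard-coding `is_causal = True`, so the same block can also run bidirectionally. A minimal, self-contained sketch of the two behaviors (illustration only, not code from this repository; shapes are arbitrary):

```python
# Illustration only (not part of this commit): the effect of toggling is_causal
# on F.scaled_dot_product_attention. Tensors follow the usual
# (batch, heads, seq_len, dim_head) layout.
import torch
import torch.nn.functional as F

q, k, v = (torch.randn(1, 8, 16, 64) for _ in range(3))

causal_out = F.scaled_dot_product_attention(q, k, v, is_causal = True)   # autoregressive masking
full_out   = F.scaled_dot_product_attention(q, k, v, is_causal = False)  # bidirectional attention

assert causal_out.shape == full_out.shape == (1, 8, 16, 64)
```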
```diff
@@ -146,6 +149,7 @@ def __init__(
         kv_heads = None,
         ff_expansion_factor = 4.,
         use_sparse_attn = False,
+        causal = True,
         use_flex_sliding_window = False,
         use_flex_fine_selection = False,
         use_triton_fine_selection = False,
@@ -164,6 +168,8 @@ def __init__(
         if use_flex_sliding_window or use_flex_fine_selection:
             assert exists(flex_attention), 'flex attention is not available on your current version of pytorch'
 
+        self.causal = causal
+
         self.use_sparse_attn = use_sparse_attn
         self.use_flex_sliding_window = use_sparse_attn & use_flex_sliding_window
         self.use_flex_fine_selection = use_sparse_attn & use_flex_fine_selection
@@ -177,6 +183,7 @@ def __init__(
                 dim_head = dim_head,
                 heads = heads,
                 kv_heads = kv_heads,
+                causal = causal,
                 use_triton_kernel = use_triton_fine_selection,
                 **sparse_attn_kwargs
             )
@@ -185,6 +192,7 @@ def __init__(
                 dim = dim,
                 dim_head = dim_head,
                 heads = heads,
+                causal = causal,
                 kv_heads = kv_heads
             )
 
@@ -275,12 +283,12 @@ def forward(
 
         if not disable_flex and self.use_flex_sliding_window:
             attn_kwargs.update(
-                sliding_window_flex_mask = create_sliding_mask(seq_len, self.attn_sliding_window_size)
+                sliding_window_flex_mask = create_sliding_mask(seq_len, self.attn_sliding_window_size, causal = self.causal)
            )
 
         if not disable_flex and self.use_flex_fine_selection:
             attn_kwargs.update(
-                fine_selection_flex_mask = create_fine_mask(seq_len, self.attn_fine_block_size)
+                fine_selection_flex_mask = create_fine_mask(seq_len, self.attn_fine_block_size, causal = self.causal)
             )
 
         # cache
```
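
The Transformer also forwards its causal setting when the flex attention block masks are built. The bodies of `create_sliding_mask` and `create_fine_mask` are not part of this diff, so the sketch below only illustrates what a causal toggle typically means for a sliding-window block mask built with PyTorch's flex attention API; the values of `seq_len` and `window_size` and the helper name `sliding_window_mask` are made up for illustration:

```python
# Sketch only: what a causal toggle roughly means when constructing a
# sliding-window block mask with torch.nn.attention.flex_attention. This is
# not the implementation of create_sliding_mask / create_fine_mask.
from torch.nn.attention.flex_attention import create_block_mask

seq_len, window_size, causal = 256, 32, True  # hypothetical settings

def sliding_window_mask(b, h, q_idx, kv_idx):
    within_window = (q_idx - kv_idx).abs() <= window_size
    if causal:
        # restrict each query to the current and earlier positions
        return within_window & (q_idx >= kv_idx)
    return within_window  # bidirectional window when causal = False

block_mask = create_block_mask(
    sliding_window_mask,
    B = None, H = None,            # broadcast over batch and heads
    Q_LEN = seq_len, KV_LEN = seq_len,
    device = 'cpu',                # use 'cuda' when building masks on GPU
)
```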

pyproject.toml

Lines changed: 1 addition & 1 deletion
```diff
@@ -1,6 +1,6 @@
 [project]
 name = "native-sparse-attention-pytorch"
-version = "0.0.75"
+version = "0.0.76"
 description = "Native Sparse Attention"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
```
