Skip to content

Commit 8fad00a

Browse files
committed
fix flex sliding window attn
1 parent 4b357af commit 8fad00a

File tree

2 files changed

+7
-7
lines changed

2 files changed

+7
-7
lines changed

native_sparse_attention_pytorch/native_sparse_attention.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,18 +41,18 @@
4141

4242
# flex attn sliding attention mask
4343

44+
4445
# flex attn sliding attention mask

def create_sliding_mask(seq_len, window_size, causal = True):
    """Build a flex-attention block mask for sliding-window attention.

    A query at position ``q_idx`` may attend a key/value at ``kv_idx`` when the
    backward distance ``q_idx - kv_idx`` is at most ``window_size``. When
    ``causal`` is True, positions ahead of the query are excluded (distance
    must be >= 0); otherwise the window extends symmetrically ``window_size``
    positions forward as well (distance >= -window_size).

    Args:
        seq_len: sequence length used for both query and key/value axes.
        window_size: maximum attendable distance from the query position.
        causal: whether to forbid attending to future positions. Default True.

    Returns:
        The block mask produced by ``create_block_mask`` over a
        ``seq_len x seq_len`` attention grid (batch/head agnostic).
    """

    def sliding_mask(_, __, q_idx, kv_idx):
        # distance > 0 means the key/value lies behind (before) the query
        distance = q_idx - kv_idx

        # limit how far back a query may look
        backward_sliding_mask = distance <= window_size

        # causal: nothing ahead of the query (distance >= 0)
        # non-causal: allow up to window_size positions ahead
        forward_distance = 0 if causal else -window_size
        forward_sliding_mask = distance >= forward_distance

        return backward_sliding_mask & forward_sliding_mask

    block_mask = create_block_mask(sliding_mask, B = None, H = None, Q_LEN = seq_len, KV_LEN = seq_len, _compile = True)
    return block_mask

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "native-sparse-attention-pytorch"
3-
version = "0.0.77"
3+
version = "0.0.78"
44
description = "Native Sparse Attention"
55
authors = [
66
{ name = "Phil Wang", email = "lucidrains@gmail.com" }

0 commit comments

Comments
 (0)