
Commit 232f4eb

wire up flex attention for sliding windows
1 parent f77444e

4 files changed (+62, −11 lines)

native_sparse_attention_pytorch/native_sparse_attention.py

Lines changed: 8 additions & 2 deletions
@@ -118,6 +118,8 @@ def __init__(
            autopad = True
        )

+        self.sliding_window_size = sliding_window_size
+
        # compress strategy

        self.compress_block_size = compress_block_size
@@ -174,7 +176,8 @@ def __init__(

    def forward(
        self,
-        inp
+        inp,
+        sliding_window_flex_mask = None
    ):
        batch, seq_len, scale, heads, device = *inp.shape[:2], self.scale, self.heads, inp.device

@@ -315,7 +318,10 @@ def forward(

        # 3. overlapping sliding window, this is unsurprising and expected

-        sliding_window_attn_out = self.sliding_window(q, k, v)
+        if exists(sliding_window_flex_mask):
+            sliding_window_attn_out = flex_attention(q, k, v, block_mask = sliding_window_flex_mask)
+        else:
+            sliding_window_attn_out = self.sliding_window(q, k, v)

        # combine strategies
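The flex attention branch above expects a sliding window BlockMask. The `create_sliding_mask` helper it relies on (imported in transformer.py below) sits outside these hunks, so the following is only a minimal sketch of how such a mask could be built, assuming the standard flex attention recipe from https://pytorch.org/blog/flexattention/. The signature and the device argument here are illustrative, not the package's own implementation.

import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

def create_sliding_mask(seq_len, window_size, device = 'cuda'):
    # keep key/value positions that are causal and at most `window_size` tokens behind the query
    def sliding_mask(b, h, q_idx, kv_idx):
        causal = q_idx >= kv_idx
        windowed = (q_idx - kv_idx) <= window_size
        return causal & windowed

    # B = None and H = None broadcast the same mask over batch and heads
    return create_block_mask(sliding_mask, B = None, H = None, Q_LEN = seq_len, KV_LEN = seq_len, device = device)

if torch.cuda.is_available():
    # q, k, v are (batch, heads, seq_len, dim_head), the layout flex_attention expects
    q = k = v = torch.randn(1, 8, 512, 64, device = 'cuda')

    block_mask = create_sliding_mask(512, window_size = 32)
    out = flex_attention(q, k, v, block_mask = block_mask)   # (1, 8, 512, 64)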

native_sparse_attention_pytorch/transformer.py

Lines changed: 46 additions & 3 deletions
@@ -8,7 +8,19 @@

 from rotary_embedding_torch import RotaryEmbedding

-from native_sparse_attention_pytorch.native_sparse_attention import SparseAttention
+from native_sparse_attention_pytorch.native_sparse_attention import SparseAttention, create_sliding_mask
+
+# flex attention
+# https://pytorch.org/blog/flexattention/
+
+flex_attention = None
+
+try:
+    from torch.nn.attention.flex_attention import flex_attention, create_block_mask
+    if torch.cuda.is_available():
+        flex_attention = torch.compile(flex_attention)
+except ImportError:
+    pass

 # functions

@@ -96,6 +108,7 @@ def __init__(
        heads = 8,
        ff_expansion_factor = 4.,
        use_sparse_attn = False,
+        use_flex_sliding_window = False,
        sparse_attn_kwargs: dict = dict(
            sliding_window_size = 32,
            compress_block_size = 4,
@@ -106,6 +119,12 @@ def __init__(
        super().__init__()
        self.token_emb = nn.Embedding(num_tokens, dim)

+        if use_flex_sliding_window:
+            assert exists(flex_attention), 'flex attention is not available on your current version of pytorch'
+
+        self.use_sparse_attn = use_sparse_attn
+        self.use_flex_sliding_window = use_flex_sliding_window
+
        layers = []
        for _ in range(depth):

@@ -123,6 +142,8 @@ def __init__(

            layers.append(ModuleList([attn, ff]))

+        self.attn_sliding_window_size = attn.sliding_window_size
+
        self.layers = ModuleList(layers)

        self.norm = RMSNorm(dim)
@@ -131,15 +152,37 @@
    def forward(
        self,
        ids,
-        return_loss = False
+        return_loss = False,
+        disable_flex = False
    ):
        if return_loss:
            ids, labels = ids[:, :-1], ids[:, 1:]

+        seq_len = ids.shape[-1]
+
+        # token embedding
+
        tokens = self.token_emb(ids)

+        # prepare maybe flex attention masks
+
+        attn_kwargs = dict()
+
+        if not disable_flex and self.use_sparse_attn and self.use_flex_sliding_window:
+
+            attn_kwargs.update(
+                sliding_window_flex_mask = create_sliding_mask(seq_len, self.attn_sliding_window_size)
+            )
+
+        # layers
+
        for attn, ff in self.layers:
-            tokens = attn(tokens) + tokens
+            attn_out = attn(
+                tokens,
+                **attn_kwargs
+            )
+
+            tokens = attn_out + tokens
            tokens = ff(tokens) + tokens

        embed = self.norm(tokens)
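Put together, the transformer can now route its sliding window branch through flex attention. A hypothetical usage sketch follows; the class name `Transformer`, its import path, and the default `sparse_attn_kwargs` are assumed from the surrounding repository code rather than shown in this diff, and the example only takes the flex path when a CUDA device is available.

import torch
from native_sparse_attention_pytorch.transformer import Transformer  # assumed class name / import path

model = Transformer(
    num_tokens = 256,
    dim = 512,
    depth = 2,
    use_sparse_attn = True,
    use_flex_sliding_window = True   # asserts that flex attention imported successfully
)

ids = torch.randint(0, 256, (1, 512))

if torch.cuda.is_available():
    model, ids = model.cuda(), ids.cuda()
    logits = model(ids)                       # sliding window branch goes through the flex attention path
else:
    logits = model(ids, disable_flex = True)  # falls back to the local attention sliding window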

pyproject.toml

Lines changed: 2 additions & 2 deletions
@@ -1,6 +1,6 @@
 [project]
 name = "native-sparse-attention-pytorch"
-version = "0.0.12"
+version = "0.0.14"
 description = "Native Sparse Attention"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -27,7 +27,7 @@ dependencies = [
     "einops>=0.8.0",
     "local-attention>=1.11.1",
     "rotary-embedding-torch",
-    "torch>=2.2",
+    "torch>=2.5",
 ]

 [project.urls]

train.py

Lines changed: 6 additions & 4 deletions
@@ -1,7 +1,7 @@
 import math
 import gzip
 import random
-import tqdm
+from tqdm import tqdm
 import numpy as np

 import torch
@@ -76,8 +76,9 @@ def base_decoding(
    prompt_seq_len, out = prompt.shape[-1], prompt.clone()
    sample_num_times = max(0, seq_len - prompt_seq_len)

-    for _ in range(sample_num_times):
-        logits = net(out)
+    for _ in tqdm(range(sample_num_times)):
+        logits = net(out, disable_flex = True)
+
        logits = logits[:, -1]
        logits = top_k(logits, thres = filter_thres)
        sample = gumbel_sample(logits, temperature = temperature, dim = -1)
@@ -93,6 +94,7 @@ def base_decoding(
    dim = 512,
    depth = 6,
    use_sparse_attn = USE_SPARSE_ATTN,
+    use_flex_sliding_window = True,
    sparse_attn_kwargs = dict(
        sliding_window_size = 32,
        compress_block_size = 32,
@@ -144,7 +146,7 @@ def __getitem__(self, index):

 # training

-for i in tqdm.tqdm(range(NUM_BATCHES), mininterval = 10.0, desc = "training"):
+for i in tqdm(range(NUM_BATCHES), mininterval = 10.0, desc = "training"):
    model.train()

    for _ in range(GRAD_ACCUM_EVERY):
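Note that sampling in base_decoding passes disable_flex = True, presumably because the block mask is rebuilt from the current sequence length on every forward, so token-by-token decoding would otherwise reconstruct it at each step. As a standalone sanity check, not part of this commit, one can confirm that flex attention under a sliding window block mask matches ordinary scaled dot product attention under the equivalent dense boolean mask:

import torch
import torch.nn.functional as F
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

if torch.cuda.is_available():
    seq_len, window_size = 128, 32
    q, k, v = (torch.randn(1, 4, seq_len, 64, device = 'cuda') for _ in range(3))

    def sliding_mask(b, h, q_idx, kv_idx):
        return (q_idx >= kv_idx) & ((q_idx - kv_idx) <= window_size)

    block_mask = create_block_mask(sliding_mask, B = None, H = None, Q_LEN = seq_len, KV_LEN = seq_len, device = 'cuda')
    flex_out = flex_attention(q, k, v, block_mask = block_mask)

    # same sliding window, expressed as a dense boolean mask for F.scaled_dot_product_attention
    i = torch.arange(seq_len, device = 'cuda')
    bool_mask = (i[:, None] >= i[None, :]) & ((i[:, None] - i[None, :]) <= window_size)
    sdpa_out = F.scaled_dot_product_attention(q, k, v, attn_mask = bool_mask)

    assert torch.allclose(flex_out, sdpa_out, atol = 1e-4)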
