
Commit 5a61460

fix an issue with mask, make sure it converges for enwik8

1 parent 949e716

File tree: 7 files changed (+372 −5 lines)

README.md
Lines changed: 14 additions & 0 deletions

````diff
@@ -33,6 +33,20 @@ attended = attn(tokens)
 assert tokens.shape == attended.shape
 ```
 
+## Example
+
+Enwik8 language modeling
+
+```bash
+$ pip install .[examples]
+```
+
+Then
+
+```bash
+$ python train.py
+```
+
 ## Citations
 
 ```bibtex
````

data/README.md
Lines changed: 3 additions & 0 deletions (new file)

```markdown
# Data source

The enwik8 data was downloaded from the Hutter prize page: http://prize.hutter1.net/
```

data/enwik8.gz
34.9 MB binary file (contents not shown)
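The training script that consumes this archive is not part of this view, so purely as an illustration, here is one common way to read enwik8 into byte-level tensors. The 95M read length and the 90M/5M train/validation split below are assumptions, not taken from the commit:

```python
import gzip

import numpy as np
import torch

# decompress the archive and view it as raw bytes (enwik8 is modeled byte-level)
with gzip.open('./data/enwik8.gz', 'rb') as f:
    data = np.frombuffer(f.read(int(95e6)), dtype = np.uint8).copy()

# illustrative split: first 90M bytes for training, remainder for validation
train_np, valid_np = np.split(data, [int(90e6)])
train_data = torch.from_numpy(train_np)
valid_data = torch.from_numpy(valid_np)

print(train_data.shape, valid_data.shape)  # torch.Size([90000000]) torch.Size([5000000])
```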

native_sparse_attention_pytorch/native_sparse_attention.py
Lines changed: 10 additions & 3 deletions

```diff
@@ -9,6 +9,8 @@
 
 from local_attention import LocalAttention
 
+from rotary_embedding_torch import RotaryEmbedding
+
 # einstein notation
 
 import einx
@@ -92,6 +94,10 @@ def __init__(
 
         self.norm = nn.RMSNorm(dim) if norm else nn.Identity()
 
+        # rotary
+
+        self.rotary_emb = RotaryEmbedding(dim_head)
+
         # qkv
 
         self.to_qkv = nn.Linear(dim, dim_inner * 3, bias = False)
@@ -193,14 +199,14 @@ def forward(
 
         cq_seq = arange(seq_len, device = device)
 
-        ck_seq = ((arange(num_compress_blocks) + 1) * self.compress_block_size) - 1
+        ck_seq = ((arange(num_compress_blocks, device = device) + 1) * self.compress_block_size) - 1
         ck_seq = F.pad(ck_seq, (num_mem_compress_kv, 0), value = -1)
 
         cmask = einx.less('j, i -> i j', ck_seq, cq_seq)
 
         mask_value = -torch.finfo(csim.dtype).max
 
-        csim = csim.masked_fill(cmask, mask_value)
+        csim = csim.masked_fill(~cmask, mask_value)
 
         cattn = csim.softmax(dim = -1)
 
@@ -218,6 +224,7 @@ def forward(
         fk = k
         fv = v
 
+        fq, fk = self.rotary_emb.rotate_queries_with_cached_keys(fq, fk)
 
         if seq_len < fine_divisible_seq_len:
             remainder = fine_divisible_seq_len - seq_len
@@ -255,7 +262,7 @@ def forward(
 
         fsim = einsum(fq, fk, 'b h i d, b h i j d -> b h i j') * self.scale
 
-        fsim = fsim.masked_fill(fmask, mask_value)
+        fsim = fsim.masked_fill(~fmask, mask_value)
 
         fattn = fsim.softmax(dim = -1)
 
```
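The substance of the fix is the flipped fill condition: `cmask` and `fmask` are built so that `True` marks positions a query may attend to, so it is the complement (`~cmask`, `~fmask`) that must be filled with the large negative value before the softmax; the earlier code masked the visible positions instead. The `arange(..., device = device)` change in the same hunk keeps `ck_seq` on the same device as the queries. A minimal sketch of this masking convention in plain `torch`, with illustrative tensor names not taken from the module:

```python
import torch

# toy attention scores: 4 queries over 6 keys
sim = torch.randn(4, 6)

# allowed[i, j] is True where query i may look at key j (simple causal pattern here)
allowed = torch.arange(6)[None, :] <= torch.arange(4)[:, None]

mask_value = -torch.finfo(sim.dtype).max

# fill the *disallowed* positions, i.e. the complement of the mask
sim = sim.masked_fill(~allowed, mask_value)

attn = sim.softmax(dim = -1)

# disallowed positions now receive (near) zero attention weight
assert attn.masked_select(~allowed).max() < 1e-6
```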

New file
Lines changed: 152 additions & 0 deletions

```python
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import Module, ModuleList, Linear, RMSNorm

from einops import rearrange
from einops.layers.torch import Rearrange

from rotary_embedding_torch import RotaryEmbedding

from native_sparse_attention_pytorch.native_sparse_attention import SparseAttention

# functions

def exists(v):
    return v is not None

def default(v, d):
    return v if exists(v) else d

# attention

class Attention(Module):
    def __init__(
        self,
        dim,
        dim_head = 64,
        heads = 8
    ):
        super().__init__()
        self.norm = RMSNorm(dim)

        self.heads = heads
        dim_inner = heads * dim_head

        self.rotary_embed = RotaryEmbedding(dim_head)

        self.to_q = nn.Linear(dim, dim_inner, bias = False)
        self.to_k = nn.Linear(dim, dim_inner, bias = False)
        self.to_v = nn.Linear(dim, dim_inner, bias = False)

        self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
        self.merge_heads = Rearrange('b h n d -> b n (h d)')

        self.to_out = nn.Linear(dim_inner, dim, bias = False)

    def forward(
        self,
        x
    ):

        x = self.norm(x)

        q = self.to_q(x)
        k = self.to_k(x)
        v = self.to_v(x)

        q, k, v = map(self.split_heads, (q, k, v))

        # relative positions

        q, k = self.rotary_embed.rotate_queries_with_cached_keys(q, k)

        # attention branch

        out = F.scaled_dot_product_attention(
            q, k, v,
            is_causal = True
        )

        out = self.merge_heads(out)

        return self.to_out(out)

# feedforward

def FeedForward(dim, expansion_factor = 4.):
    dim_hidden = int(dim * expansion_factor)

    return nn.Sequential(
        RMSNorm(dim),
        Linear(dim, dim_hidden),
        nn.GELU(),
        Linear(dim_hidden, dim)
    )

# classes

class Transformer(Module):
    def __init__(
        self,
        num_tokens,
        dim,
        depth,
        dim_head = 64,
        heads = 8,
        ff_expansion_factor = 4.,
        use_sparse_attn = False,
        sparse_attn_kwargs: dict = dict(
            sliding_window_size = 32,
            compress_block_size = 4,
            selection_block_size = 4,
            num_selected_blocks = 4,
        )
    ):
        super().__init__()
        self.token_emb = nn.Embedding(num_tokens, dim)

        layers = []
        for _ in range(depth):

            if use_sparse_attn:
                attn = SparseAttention(
                    dim = dim,
                    dim_head = dim_head,
                    heads = heads,
                    **sparse_attn_kwargs
                )
            else:
                attn = Attention(dim = dim, dim_head = dim_head, heads = heads)

            ff = FeedForward(dim = dim, expansion_factor = ff_expansion_factor)

            layers.append(ModuleList([attn, ff]))

        self.layers = ModuleList(layers)

        self.norm = RMSNorm(dim)
        self.to_logits = Linear(dim, num_tokens, bias = False)

    def forward(
        self,
        ids,
        return_loss = False
    ):
        if return_loss:
            ids, labels = ids[:, :-1], ids[:, 1:]

        tokens = self.token_emb(ids)

        for attn, ff in self.layers:
            tokens = attn(tokens) + tokens
            tokens = ff(tokens) + tokens

        embed = self.norm(tokens)

        logits = self.to_logits(embed)

        if not return_loss:
            return logits

        return F.cross_entropy(rearrange(logits, 'b n l -> b l n'), labels)
```
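A quick usage sketch for the Transformer defined above. The dimensions and batch shape below are illustrative, and since the file name is not shown in this view, the classes are assumed to already be in scope rather than imported:

```python
import torch

# byte-level vocab (e.g. for enwik8), small illustrative dimensions
model = Transformer(
    num_tokens = 256,
    dim = 512,
    depth = 6,
    use_sparse_attn = True   # use SparseAttention blocks instead of full attention
)

ids = torch.randint(0, 256, (2, 512))  # (batch, seq) of byte ids

loss = model(ids, return_loss = True)  # internally shifts ids to form input / target pairs
loss.backward()

logits = model(ids)                    # (2, 512, 256) when return_loss = False
```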

pyproject.toml
Lines changed: 7 additions & 2 deletions

```diff
@@ -1,6 +1,6 @@
 [project]
 name = "native-sparse-attention-pytorch"
-version = "0.0.2"
+version = "0.0.3"
 description = "Native Sparse Attention"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -26,6 +26,7 @@ dependencies = [
     "einx>=0.3.0",
     "einops>=0.8.0",
     "local-attention>=1.11.1",
+    "rotary-embedding-torch",
     "torch>=2.2",
 ]
 
@@ -34,7 +35,11 @@ Homepage = "https://pypi.org/project/native-sparse-attention-pytorch/"
 Repository = "https://github.com/lucidrains/native-sparse-attention-pytorch"
 
 [project.optional-dependencies]
-examples = []
+
+examples = [
+    "tqdm",
+    "wandb"
+]
 test = [
     "pytest"
 ]
```
