
Commit 599f312

prepare for knocking out inference logic over the weekend, last commit for this project for the day

1 parent c395e2a

3 files changed, +50 -47 lines changed


native_sparse_attention_pytorch/transformer.py

Lines changed: 48 additions & 1 deletion

@@ -1,8 +1,11 @@
 import torch
-from torch import nn
+from torch import nn, Tensor
 import torch.nn.functional as F
 from torch.nn import Module, ModuleList, Linear, RMSNorm
 
+from math import ceil
+from tqdm import tqdm
+
 from einops import rearrange, repeat
 from einops.layers.torch import Rearrange
 
@@ -38,6 +41,25 @@ def default(v, d):
 def at_most_one_of(*bools):
     return sum([*map(int, bools)]) <= 1
 
+# sampling helpers
+
+def log(t, eps = 1e-20):
+    return torch.log(t.clamp(min = eps))
+
+def gumbel_noise(t):
+    noise = torch.zeros_like(t).uniform_(0, 1)
+    return -log(-log(noise))
+
+def gumbel_sample(t, temperature = 1., dim = -1, keepdim = True):
+    return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim = dim, keepdim = keepdim)
+
+def top_k(logits, thres = 0.9):
+    k = ceil((1 - thres) * logits.shape[-1])
+    val, ind = torch.topk(logits, k)
+    probs = torch.full_like(logits, float('-inf'))
+    probs.scatter_(-1, ind, val)
+    return probs
+
 # attention
 
 class Attention(Module):

@@ -178,6 +200,31 @@ def __init__(
         self.norm = RMSNorm(dim)
         self.to_logits = Linear(dim, num_tokens, bias = False)
 
+    def sample(
+        self,
+        prompt: Tensor,
+        seq_len: int,
+        temperature = 1.,
+        filter_thres = 0.9,
+    ):
+        prompt_seq_len, out = prompt.shape[-1], prompt.clone()
+        sample_num_times = max(0, seq_len - prompt_seq_len)
+
+        for _ in tqdm(range(sample_num_times)):
+            logits = self.forward(
+                out,
+                disable_flex = True,
+                disable_triton_kernel = True
+            )
+
+            logits = logits[:, -1]
+            logits = top_k(logits, thres = filter_thres)
+            sample = gumbel_sample(logits, temperature = temperature, dim = -1)
+
+            out = torch.cat((out, sample), dim = -1)
+
+        return out[..., prompt_seq_len:]
+
     def forward(
         self,
         ids,
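The helpers above implement the Gumbel-max trick: adding Gumbel noise to temperature-scaled logits and taking the argmax is equivalent to sampling from the softmax over those logits, while top_k masks all but the top ceil((1 - thres) * vocab_size) logits to -inf (so the default thres = 0.9 keeps the top 10%). A minimal sanity-check sketch, not part of the commit, assuming the helpers are importable from the module as added above:

# sanity check for the gumbel-max trick (illustrative, not from this commit)
import torch
from native_sparse_attention_pytorch.transformer import gumbel_sample

logits = torch.tensor([2.0, 1.0, 0.0])

# 10k draws via argmax(logits / temperature + gumbel noise)
draws = torch.stack([gumbel_sample(logits, keepdim = False) for _ in range(10000)])
empirical = torch.bincount(draws, minlength = 3) / 10000

print(empirical)           # approximately tensor([0.665, 0.245, 0.090])
print(logits.softmax(-1))  # tensor([0.6652, 0.2447, 0.0900])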

pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 [project]
 name = "native-sparse-attention-pytorch"
-version = "0.0.55"
+version = "0.0.56"
 description = "Native Sparse Attention"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

train.py

Lines changed: 1 addition & 45 deletions

@@ -69,50 +69,6 @@ def decode_token(token):
 def decode_tokens(tokens):
     return "".join(list(map(decode_token, tokens)))
 
-# sampling helpers
-
-def log(t, eps = 1e-20):
-    return torch.log(t.clamp(min = eps))
-
-def gumbel_noise(t):
-    noise = torch.zeros_like(t).uniform_(0, 1)
-    return -log(-log(noise))
-
-def gumbel_sample(t, temperature = 1., dim = -1, keepdim = True):
-    return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim = dim, keepdim = keepdim)
-
-def top_k(logits, thres = 0.9):
-    k = math.ceil((1 - thres) * logits.shape[-1])
-    val, ind = torch.topk(logits, k)
-    probs = torch.full_like(logits, float('-inf'))
-    probs.scatter_(-1, ind, val)
-    return probs
-
-def base_decoding(
-    net,
-    prompt: Tensor,
-    seq_len: int,
-    temperature = 1.,
-    filter_thres = 0.9,
-):
-    prompt_seq_len, out = prompt.shape[-1], prompt.clone()
-    sample_num_times = max(0, seq_len - prompt_seq_len)
-
-    for _ in tqdm(range(sample_num_times)):
-        logits = net(
-            out,
-            disable_flex = True,
-            disable_triton_kernel = True
-        )
-
-        logits = logits[:, -1]
-        logits = top_k(logits, thres = filter_thres)
-        sample = gumbel_sample(logits, temperature = temperature, dim = -1)
-
-        out = torch.cat((out, sample), dim = -1)
-
-    return out[..., prompt_seq_len:]
-
 # printing
 
 if USE_TRITON_NSA:

@@ -231,7 +187,7 @@ def __getitem__(self, index):
 
         prompt = inp[None, ...]
 
-        sampled = base_decoding(model, prompt, GENERATE_LENGTH)
+        sampled = model.sample(prompt, GENERATE_LENGTH)
 
         base_decode_output = decode_tokens(sampled[0])
 
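With sampling moved onto the transformer, generation in train.py reduces to a single method call. A hypothetical end-to-end usage sketch (the Transformer constructor arguments below are illustrative assumptions, not taken from this commit; only sample's signature comes from the diff above):

# hypothetical usage sketch; constructor arguments are assumed, and the real
# model likely takes additional sparse-attention hyperparameters
import torch
from native_sparse_attention_pytorch.transformer import Transformer

model = Transformer(num_tokens = 256, dim = 512, depth = 2)

prompt = torch.randint(0, 256, (1, 32))  # batch of 1, 32-token prompt

# decode out to 256 total tokens; per the diff, only the newly sampled
# tokens (seq_len - prompt length = 224 here) are returned
sampled = model.sample(prompt, seq_len = 256, temperature = 1., filter_thres = 0.9)
print(sampled.shape)  # expected: torch.Size([1, 224])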
