
Commit 0e1f6fe

start with compressed and sliding window for inference
1 parent ef77474 commit 0e1f6fe

File tree

4 files changed: +178 -10 lines

native_sparse_attention_pytorch/native_sparse_attention.py

Lines changed: 123 additions & 2 deletions
@@ -315,13 +315,121 @@ def __init__(
 
         self.combine_heads = nn.Linear(dim_inner, dim, bias = False)
 
+    def forward_inference(
+        self,
+        inp,
+        cache,
+        return_cache = True
+    ):
+        # destruct cache
+
+        (cache_k, cache_v), (cache_ck, cache_cv) = cache
+
+        # variables
+
+        batch, scale, heads, device = inp.shape[0], self.scale, self.heads, inp.device
+        seq_len = cache_k.shape[-2] + 1
+
+        sliding_window = self.sliding_window_size
+        compress_divisible_seq_len = round_down_mult(seq_len, self.compress_block_size)
+        num_compress_blocks = compress_divisible_seq_len // self.compress_block_size
+
+        fine_divisible_seq_len = round_up_mult(seq_len, self.selection_block_size)
+        num_fine_blocks = fine_divisible_seq_len // self.selection_block_size
+
+        # maybe prenorm
+
+        inp = self.norm(inp)
+
+        # queries, keys, values
+
+        q, k, v = self.to_qkv(inp).split(self.qkv_split, dim = -1)
+
+        q, k, v = map(self.split_heads, (q, k, v))
+
+        # handle cache
+
+        k = cat((cache_k, k), dim = -2)
+        v = cat((cache_v, v), dim = -2)
+
+        if return_cache:
+            cache_kv = (k, v)
+
+        # 1. compressed attn inference
+
+        cq = q
+        ck = cache_ck
+        cv = cache_cv
+
+        if divisible_by(seq_len, self.compress_block_size):
+            k_compress_input = self.split_compress_window(k[..., -self.compress_block_size:, :] + self.k_intrablock_positions)
+            v_compress_input = self.split_compress_window(v[..., -self.compress_block_size:, :] + self.v_intrablock_positions)
+
+            next_ck = self.k_compress(k_compress_input)
+            next_cv = self.v_compress(v_compress_input)
+
+            ck = cat((ck, next_ck), dim = -2)
+            cv = cat((cv, next_cv), dim = -2)
+
+        if return_cache:
+            cache_compressed_kv = (ck, cv)
+
+        ck = repeat(ck, 'b h ... -> b (h gh) ...', gh = self.num_grouped_queries)
+        cv = repeat(cv, 'b h ... -> b (h gh) ...', gh = self.num_grouped_queries)
+
+        csim = einsum(q, ck, 'b h i d, b h j d -> b h i j') * scale
+        cattn = csim.softmax(dim = -1)
+
+        compressed_attn_out = einsum(cattn, cv, 'b h i j, b h j d -> b h i d')
+
+        # 2. fine attention inference (todo)
+
+        # not implemented
+
+        # 3. sliding window
+
+        k = repeat(k, 'b h ... -> b (h gh) ...', gh = self.num_grouped_queries)
+        v = repeat(v, 'b h ... -> b (h gh) ...', gh = self.num_grouped_queries)
+
+        sliding_slice = (Ellipsis, slice(-(sliding_window + 1), None), slice(None))
+        rotated_q, rotated_k = self.rotary_emb.rotate_queries_with_cached_keys(q, k[sliding_slice])
+
+        sim = einsum(rotated_q, rotated_k, 'b h i d, b h j d -> b h i j') * scale
+        attn = sim.softmax(dim = -1)
+        sliding_window_attn_out = einsum(attn, v[sliding_slice], 'b h i j, b h j d -> b h i d')
+
+        # combine strategies
+
+        strategy_weighted_combine = self.to_strategy_combine(inp)
+
+        out = einsum(strategy_weighted_combine, stack([compressed_attn_out, sliding_window_attn_out, sliding_window_attn_out]), 'b h n s, s b h n d -> b h n d')
+
+        # merge heads and combine them
+
+        out = self.merge_heads(out)
+
+        out = self.combine_heads(out)
+
+        if not return_cache:
+            return out
+
+        return out, (cache_kv, cache_compressed_kv)
+
     def forward(
         self,
         inp,
+        cache = None,
         disable_triton_kernel = False,
         sliding_window_flex_mask = None,
-        fine_selection_flex_mask = None
+        fine_selection_flex_mask = None,
+        return_cache = False
     ):
+        is_inferencing = exists(cache)
+
+        if is_inferencing:
+            assert inp.shape[1] == 1, 'input must be single tokens if inferencing with cache key values'
+            return self.forward_inference(inp, cache, return_cache = return_cache)
+
         batch, seq_len, scale, heads, device = *inp.shape[:2], self.scale, self.heads, inp.device
 
         compress_divisible_seq_len = round_down_mult(seq_len, self.compress_block_size)

@@ -340,6 +448,11 @@ def forward(
 
         q, k, v = map(self.split_heads, (q, k, v))
 
+        # handle cache
+
+        if return_cache:
+            cache_kv = (k, v)
+
         # compressed key / values - variables prepended with `c` stands for compressed
 
         k_pos = repeat(self.k_intrablock_positions, 'h n d -> h (r n) d', r = num_compress_blocks)

@@ -352,6 +465,9 @@ def forward(
         ck = self.k_compress(k_compress_input) # Equation (7) of the Native Sparse Attention paper
         cv = self.v_compress(v_compress_input)
 
+        if return_cache:
+            cache_compressed_kv = (ck, cv)
+
         # 1. coarse attention over compressed
 
         mem_ck, mem_cv = repeat(self.compress_mem_kv, 'kv ... -> kv b ...', b = batch)

@@ -570,4 +686,9 @@ def forward(
 
         out = self.merge_heads(out)
 
-        return self.combine_heads(out)
+        out = self.combine_heads(out)
+
+        if not return_cache:
+            return out
+
+        return out, (cache_kv, cache_compressed_kv)
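
For orientation, here is a minimal decoding sketch against the new cache path at the level of a single SparseAttention module. The import path and constructor arguments are assumptions patterned after the repository README rather than part of this diff; the cache layout ((k, v), (ck, cv)) is exactly what forward_inference destructures above.

import torch
from native_sparse_attention_pytorch import SparseAttention  # import path assumed

attn = SparseAttention(
    dim = 512,
    dim_head = 64,
    heads = 8,
    sliding_window_size = 2,
    compress_block_size = 4,
    selection_block_size = 4,
    num_selected_blocks = 2
)

prompt = torch.randn(1, 8, 512)

# prime the cache with an ordinary forward pass over the prompt
out, cache = attn(prompt, return_cache = True)

# the returned cache is ((k, v), (compressed k, compressed v))
(cache_k, cache_v), (cache_ck, cache_cv) = cache

# decode one new token at a time against the cache
next_token = torch.randn(1, 1, 512)
out, cache = attn(next_token, cache = cache, return_cache = True)

Two details worth noting from the code above: the compressed cache only grows when the running sequence length lands on a multiple of compress_block_size (steps 4, 8, 12, and so on with the values assumed here), and until fine selection inference is implemented, the strategy combine fills the fine branch's slot with the sliding window output.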

native_sparse_attention_pytorch/transformer.py

Lines changed: 47 additions & 6 deletions
@@ -200,23 +200,31 @@ def __init__(
         self.norm = RMSNorm(dim)
         self.to_logits = Linear(dim, num_tokens, bias = False)
 
+    @torch.no_grad()
     def sample(
         self,
         prompt: Tensor,
         seq_len: int,
         temperature = 1.,
         filter_thres = 0.9,
+        use_cache_kv = False
     ):
         prompt_seq_len, out = prompt.shape[-1], prompt.clone()
         sample_num_times = max(0, seq_len - prompt_seq_len)
 
+        cache = None
+
         for _ in tqdm(range(sample_num_times)):
-            logits = self.forward(
+
+            logits, next_cache = self.forward(
                 out,
-                disable_flex = True,
-                disable_triton_kernel = True
+                cache = cache,
+                return_cache = True
             )
 
+            if use_cache_kv:
+                cache = next_cache
+
             logits = logits[:, -1]
             logits = top_k(logits, thres = filter_thres)
             sample = gumbel_sample(logits, temperature = temperature, dim = -1)

@@ -225,13 +233,28 @@ def sample(
 
         return out[..., prompt_seq_len:]
 
+    def forward_inference(
+        self,
+        ids,
+        cache = None
+    ):
+        return ids
+
     def forward(
         self,
         ids,
         return_loss = False,
         disable_flex = False,
-        disable_triton_kernel = False
+        disable_triton_kernel = False,
+        cache = None,
+        return_cache = False
     ):
+        is_inferencing = exists(cache)
+
+        if is_inferencing:
+            disable_flex |= True
+            disable_triton_kernel |= True
+
         if return_loss:
             ids, labels = ids[:, :-1], ids[:, 1:]
 

@@ -257,14 +280,29 @@ def forward(
             fine_selection_flex_mask = create_fine_mask(seq_len, self.attn_fine_block_size)
         )
 
+        # cache
+
+        cache = default(cache, [])
+        iter_cache = iter(cache)
+
+        next_cache = []
+
+        if is_inferencing:
+            tokens = tokens[:, -1:]
+
         # layers
 
         for attn, ff in self.layers:
-            attn_out = attn(
+
+            attn_out, layer_cache = attn(
                 tokens,
+                cache = next(iter_cache, None),
+                return_cache = True,
                 **attn_kwargs
             )
 
+            next_cache.append(layer_cache)
+
             tokens = attn_out + tokens
             tokens = ff(tokens) + tokens
 

@@ -273,6 +311,9 @@ def forward(
         logits = self.to_logits(embed)
 
         if not return_loss:
-            return logits
+            if not return_cache:
+                return logits
+
+            return logits, next_cache
 
         return F.cross_entropy(rearrange(logits, 'b n l -> b l n'), labels)
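
At the transformer level, the cache is threaded through the layers one entry at a time, each entry holding that layer's ((k, v), (compressed k, compressed v)). Below is a minimal greedy decoding loop that mirrors the new sample() logic; `model` is assumed to be a Transformer from this repository, configured as in train.py, and greedy argmax stands in for the gumbel sampling used in sample().

import torch

@torch.no_grad()
def greedy_decode(model, prompt, num_tokens):
    out = prompt.clone()
    cache = None

    for _ in range(num_tokens):
        # returns (logits, per layer cache) when return_cache = True
        logits, next_cache = model(
            out,
            cache = cache,
            return_cache = True
        )

        cache = next_cache

        # pick the next token greedily and append it to the running sequence
        next_token = logits[:, -1].argmax(dim = -1, keepdim = True)
        out = torch.cat((out, next_token), dim = -1)

    return out[..., prompt.shape[-1]:]

When cache is not None, forward slices the input down to the last token and every SparseAttention layer takes its forward_inference path; passing cache = None on every iteration instead reproduces the full recompute behaviour, which is what sample() does when use_cache_kv is False.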

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "native-sparse-attention-pytorch"
-version = "0.0.60"
+version = "0.0.61"
 description = "Native Sparse Attention"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

train.py

Lines changed: 7 additions & 1 deletion
@@ -47,6 +47,8 @@
 INTERPOLATED_IMPORTANCE_SCORE = False
 USE_DIFF_TOPK = True
 
+USE_EFFICIENT_INFERENCE = False # fine attn inference logic still needs implementing
+
 # experiment related
 
 PROJECT_NAME = 'native-sparse-attention'

@@ -187,7 +189,11 @@ def __getitem__(self, index):
 
         prompt = inp[None, ...]
 
-        sampled = model.sample(prompt, GENERATE_LENGTH)
+        sampled = model.sample(
+            prompt,
+            GENERATE_LENGTH,
+            use_cache_kv = USE_EFFICIENT_INFERENCE
+        )
 
         base_decode_output = decode_tokens(sampled[0])
 