Commit 1ba6ec6

complete the overall idea for inference, polish up the edge cases on sunday
1 parent 197112b

4 files changed: +68 -13 lines

native_sparse_attention_pytorch/native_sparse_attention.py

Lines changed: 64 additions & 9 deletions
@@ -252,6 +252,7 @@ def __init__(
 
         self.split_compress_window = Rearrange('b h (w n) d -> b h w n d', n = compress_block_size)
 
+        self.num_mem_compress_kv = num_compressed_mem_kv
         self.compress_mem_kv = nn.Parameter(torch.zeros(2, kv_heads, num_compressed_mem_kv, dim_head))
 
         self.k_intrablock_positions = nn.Parameter(torch.zeros(kv_heads, compress_block_size, dim_head))
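
The newly stored self.num_mem_compress_kv is what lets the fine branch later in this commit slice the learned memory-KV columns out of the compressed attention logits, since those columns do not correspond to real sequence blocks. A minimal illustration of that slice, with hypothetical sizes:

import torch

num_mem_compress_kv = 2                              # hypothetical number of learned memory kv
num_blocks = 5                                       # hypothetical number of compressed blocks

csim = torch.randn(1, 8, 1, num_mem_compress_kv + num_blocks)   # (batch, heads, 1, mem + blocks)
importance_scores = csim[..., num_mem_compress_kv:]             # keep only the real block columns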
@@ -332,7 +333,6 @@ def forward_inference(
 
         sliding_window = self.sliding_window_size
         compress_divisible_seq_len = round_down_mult(seq_len, self.compress_block_size)
-        num_compress_blocks = compress_divisible_seq_len // self.compress_block_size
 
         fine_divisible_seq_len = round_up_mult(seq_len, self.selection_block_size)
         num_fine_blocks = fine_divisible_seq_len // self.selection_block_size
@@ -361,6 +361,14 @@ def forward_inference(
         ck = cache_ck
         cv = cache_cv
 
+        repeated_ck = repeat(ck, 'b h ... -> b (h gh) ...', gh = self.num_grouped_queries)
+        repeated_cv = repeat(cv, 'b h ... -> b (h gh) ...', gh = self.num_grouped_queries)
+
+        csim = einsum(q, repeated_ck, 'b h i d, b h j d -> b h i j') * scale
+        cattn = csim.softmax(dim = -1)
+
+        compressed_attn_out = einsum(cattn, repeated_cv, 'b h i j, b h j d -> b h i d')
+
         if divisible_by(seq_len, self.compress_block_size):
             k_compress_input = self.split_compress_window(k[..., -self.compress_block_size:, :] + self.k_intrablock_positions)
             v_compress_input = self.split_compress_window(v[..., -self.compress_block_size:, :] + self.v_intrablock_positions)
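
The lines added above compute the compressed branch for a single decoding step: the cached compressed keys and values are repeated across the grouped query heads (GQA style), the new query attends over all compressed blocks plus the learned memory KV, and the raw logits csim are kept around because they double as block importance scores for the fine branch in the next hunk. A standalone sketch of that pattern, with assumed shapes and names (not the library's API):

import torch
from einops import repeat, einsum

def compressed_branch_step(q, ck, cv, num_grouped_queries, scale):
    # q:  (b, h_q, 1, d)        the single new query token, h_q = h_kv * num_grouped_queries
    # ck: (b, h_kv, blocks, d)  cached compressed keys (memory kv included)
    # cv: (b, h_kv, blocks, d)  cached compressed values

    ck = repeat(ck, 'b h n d -> b (h g) n d', g = num_grouped_queries)
    cv = repeat(cv, 'b h n d -> b (h g) n d', g = num_grouped_queries)

    sim = einsum(q, ck, 'b h i d, b h j d -> b h i j') * scale
    attn = sim.softmax(dim = -1)

    out = einsum(attn, cv, 'b h i j, b h j d -> b h i d')
    return out, sim    # sim is reused as the importance score for block selection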
@@ -374,17 +382,64 @@ def forward_inference(
         if return_cache:
             cache_compressed_kv = (ck, cv)
 
-        ck = repeat(ck, 'b h ... -> b (h gh) ...', gh = self.num_grouped_queries)
-        cv = repeat(cv, 'b h ... -> b (h gh) ...', gh = self.num_grouped_queries)
+        # 2. fine attention inference (todo - compress and fine diff block sizes)
 
-        csim = einsum(q, ck, 'b h i d, b h j d -> b h i j') * scale
-        cattn = csim.softmax(dim = -1)
+        assert self.compress_block_size == self.selection_block_size
+
+        importance_scores = csim[..., self.num_mem_compress_kv:]
+        importance_scores += torch.randn_like(importance_scores) * 100
+
+        num_compress_blocks = importance_scores.shape[-1]
+        num_selected = min(self.num_selected_blocks, num_compress_blocks)
+        has_selected_kv_for_fine_attn = num_selected > 0
+
+        # block causal diagonal
+
+        fine_sliding_window = (seq_len % self.selection_block_size) + 1
+        fk = k[..., -fine_sliding_window:, :]
+        fv = v[..., -fine_sliding_window:, :]
+
+        # select out the sparse kv segments as defined by compressed attention map as importance score
+
+        if has_selected_kv_for_fine_attn:
+            if self.query_heads_share_selected_kv:
+                importance_scores = reduce(importance_scores, 'b (h grouped_queries) ... -> b h ...', 'mean', grouped_queries = self.num_grouped_queries)
+
+            sel_scores, sel_indices = importance_scores.topk(num_selected, dim = -1)
+
+            fine_divisible_seq_len = round_up_mult(seq_len, self.selection_block_size)
+            remainder = fine_divisible_seq_len - k.shape[-2]
+
+            sel_fk = pad_at_dim(k, (0, remainder), dim = -2)
+            sel_fv = pad_at_dim(v, (0, remainder), dim = -2)
+
+            sel_fk = rearrange(sel_fk, 'b h (w j) d -> b h w j d', j = self.selection_block_size)
+            sel_fv = rearrange(sel_fv, 'b h (w j) d -> b h w j d', j = self.selection_block_size)
+
+            sel_fk = einx.get_at('b h [w] j d, b h 1 sel -> b h (sel j) d', sel_fk, sel_indices)
+            sel_fv = einx.get_at('b h [w] j d, b h 1 sel -> b h (sel j) d', sel_fv, sel_indices)
+
+            fmask = sel_scores > 1e-10
+
+            fmask = repeat(fmask, 'b h i sel -> b h i (sel j)', j = self.selection_block_size)
+
+            fk = cat((sel_fk, fk), dim = -2)
+            fv = cat((sel_fv, fv), dim = -2)
+
+            fmask = F.pad(fmask, (0, fk.shape[-2] - fmask.shape[-1]), value = True)
+
+        # remove later
+
+        fq = rearrange(q, 'b (h gh) ... -> b h gh ...', gh = self.num_grouped_queries)
+
+        fsim = einsum(fq, fk, 'b h gh i d, b h j d -> b h gh i j') * scale
 
-        compressed_attn_out = einsum(cattn, cv, 'b h i j, b h j d -> b h i d')
+        fsim = einx.where('b h i j, b h gh i j, -> b h gh i j', fmask, fsim, max_neg_value(fsim))
 
-        # 2. fine attention inference (todo)
+        fattn = fsim.softmax(dim = -1)
 
-        # not implemented
+        fine_attn_out = einsum(fattn, fv, 'b h gh i j, b h j d -> b h gh i d')
+        fine_attn_out = rearrange(fine_attn_out, 'b h gh ... -> b (h gh) ...')
 
         # 3. sliding window
 
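
The fine branch above turns the compressed logits into per-block importance scores (after dropping the memory-KV columns), optionally averages them across grouped query heads, takes the top-k blocks, gathers those key/value blocks, and attends over them together with the partially filled block at the tail of the sequence. The commit uses einx.get_at for the gather; the sketch below shows the same block selection with plain torch.gather, under assumed shapes:

import torch
from einops import rearrange, repeat

def gather_selected_blocks(k, importance_scores, block_size, num_selected):
    # k:                 (b, h, n, d)   keys, with n already padded to a multiple of block_size
    # importance_scores: (b, h, 1, num_blocks)   scores for one decoding step

    sel_scores, sel_indices = importance_scores.topk(num_selected, dim = -1)   # (b, h, 1, sel)
    sel_indices = rearrange(sel_indices, 'b h 1 sel -> b h sel')

    # view the keys as blocks, then pull out the selected ones
    blocks = rearrange(k, 'b h (w j) d -> b h w (j d)', j = block_size)        # (b, h, w, j * d)
    idx = repeat(sel_indices, 'b h sel -> b h sel e', e = blocks.shape[-1])
    picked = blocks.gather(2, idx)                                             # (b, h, sel, j * d)

    picked = rearrange(picked, 'b h sel (j d) -> b h (sel j) d', j = block_size)
    return picked, sel_scores

For a quick shape check: with k of shape (1, 2, 64, 16), scores of shape (1, 2, 1, 4), block_size = 16 and num_selected = 2, the gathered keys come back as (1, 2, 32, 16), i.e. two selected blocks flattened back into a sequence.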
@@ -402,7 +457,7 @@ def forward_inference(
 
         strategy_weighted_combine = self.to_strategy_combine(inp)
 
-        out = einsum(strategy_weighted_combine, stack([compressed_attn_out, sliding_window_attn_out, sliding_window_attn_out]), 'b h n s, s b h n d -> b h n d')
+        out = einsum(strategy_weighted_combine, stack([compressed_attn_out, compressed_attn_out, sliding_window_attn_out]), 'b h n s, s b h n d -> b h n d')
 
         # merge heads and combine them
 
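The final combine in the last hunk is a learned three-way gate over the branch outputs (compressed, fine, sliding window); note that in this commit the fine slot of the stack still reuses compressed_attn_out, presumably as a placeholder until the new fine_attn_out is wired in. A minimal sketch of the gating itself, with illustrative names:

import torch
from einops import einsum

def combine_branches(gates, compressed_out, fine_out, sliding_out):
    # gates:      (b, h, n, 3)  per-head, per-position strategy weights
    # each *_out: (b, h, n, d)  output of one attention branch
    branches = torch.stack([compressed_out, fine_out, sliding_out])   # (3, b, h, n, d)
    return einsum(gates, branches, 'b h n s, s b h n d -> b h n d')
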
native_sparse_attention_pytorch/transformer.py

Lines changed: 2 additions & 2 deletions
@@ -255,8 +255,8 @@ def forward(
         is_inferencing = exists(cache)
 
         if is_inferencing:
-            disable_flex &= False
-            disable_triton_kernel &= False
+            disable_flex |= True
+            disable_triton_kernel |= True
 
         if return_loss:
             ids, labels = ids[:, :-1], ids[:, 1:]
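
The operator change matters: disable_flex &= False always leaves the flag False (flex attention and the Triton kernel would still run), whereas disable_flex |= True always forces it on, so cached decoding presumably falls back to the plain PyTorch attention path. A quick check of the boolean semantics:

disable_flex = False

disable_flex &= False     # old behaviour: stays False, flex still enabled
print(disable_flex)       # False

disable_flex |= True      # new behaviour: forced True, flex disabled for inference
print(disable_flex)       # True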

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "native-sparse-attention-pytorch"
-version = "0.0.62"
+version = "0.0.63"
 description = "Native Sparse Attention"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

train.py

Lines changed: 1 addition & 1 deletion
@@ -47,7 +47,7 @@
 INTERPOLATED_IMPORTANCE_SCORE = False
 USE_DIFF_TOPK = True
 
-USE_EFFICIENT_INFERENCE = False # fine attn inference logic still needs implementing
+USE_EFFICIENT_INFERENCE = True # needs validation still
 
 # experiment related
 