Commit b77f9d0

update NSA inference so rotated queries and keys are cached
1 parent: 1d3712c

File tree: 3 files changed (+39, −15 lines)
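
In short: the main KV cache now stores keys that have already had rotary embeddings applied, and a second running buffer of unrotated keys/values is carried alongside for the compression pathway (which, per the paper linked in the diff, should not see rotated positions). A sketch of the cache layout as it is destructured and returned below; the names are taken directly from the diff, the annotations are interpretation:

# cache layout before this commit:
#   ((cache_k, cache_v), (cache_ck, cache_cv))
#
# cache layout after this commit:
#   ((cache_k, cache_v), ((cache_ck, cache_cv), (run_k, run_v)))
#     │                    │                      └ unrotated k / v accumulating toward the next compress block
#     │                    └ compressed k / v
#     └ main KV cache, keys stored already rotated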

native_sparse_attention_pytorch/native_sparse_attention.py

Lines changed: 37 additions & 14 deletions
@@ -324,12 +324,19 @@ def forward_inference(
     ):
         # destruct cache

-        (cache_k, cache_v), (cache_ck, cache_cv) = cache
+        (
+            (cache_k, cache_v),
+            (
+                (cache_ck, cache_cv),
+                (run_k, run_v)
+            )
+        ) = cache

         # variables

         batch, scale, heads, device = inp.shape[0], self.scale, self.heads, inp.device
-        seq_len = cache_k.shape[-2] + 1
+        cache_len = cache_k.shape[-2]
+        seq_len = cache_len + 1

         sliding_window = self.sliding_window_size
         compress_divisible_seq_len = round_down_mult(seq_len, self.compress_block_size)
@@ -347,7 +354,17 @@ def forward_inference(

         q, k, v = map(self.split_heads, (q, k, v))

-        # handle cache
+        # take care of running k and v for compression, which should NOT be rotated https://arxiv.org/abs/2501.18795
+
+        run_k = cat((run_k, k), dim = -2)
+        run_v = cat((run_v, v), dim = -2)
+
+        # rotate after updating the compression running k/v
+
+        q = self.rotary_emb.rotate_queries_or_keys(q, offset = cache_len)
+        k = self.rotary_emb.rotate_queries_or_keys(k, offset = cache_len)
+
+        # handle cache, which stores the rotated

         k = cat((cache_k, k), dim = -2)
         v = cat((cache_v, v), dim = -2)
@@ -369,18 +386,24 @@ def forward_inference(

         compressed_attn_out = einsum(cattn, repeated_cv, 'b h i j, b h j d -> b h i d')

-        if divisible_by(seq_len, self.compress_block_size):
-            k_compress_input = self.split_compress_window(k[..., -self.compress_block_size:, :] + self.k_intrablock_positions)
-            v_compress_input = self.split_compress_window(v[..., -self.compress_block_size:, :] + self.v_intrablock_positions)
+        running_compress_seq_len = run_k.shape[-2]
+
+        if divisible_by(running_compress_seq_len, self.compress_block_size):
+
+            k_compress_input = self.split_compress_window(run_k + self.k_intrablock_positions)
+            v_compress_input = self.split_compress_window(run_v + self.v_intrablock_positions)

             next_ck = self.k_compress(k_compress_input)
             next_cv = self.v_compress(v_compress_input)

+            run_k = run_k[..., 0:0, :]
+            run_v = run_v[..., 0:0, :]
+
             ck = cat((ck, next_ck), dim = -2)
             cv = cat((cv, next_cv), dim = -2)

         if return_cache:
-            cache_compressed_kv = (ck, cv)
+            cache_compressed_kv = ((ck, cv), (run_k, run_v))

         # 2. fine attention inference (todo - compress and fine diff block sizes)

@@ -395,10 +418,8 @@ def forward_inference(

         # block causal diagonal

-        rotated_q, rotated_k = self.rotary_emb.rotate_queries_with_cached_keys(q, k)
-
         fine_sliding_window = (seq_len % self.selection_block_size) + 1
-        fk = rotated_k[..., -fine_sliding_window:, :]
+        fk = k[..., -fine_sliding_window:, :]
         fv = v[..., -fine_sliding_window:, :]

         # select out the sparse kv segments as defined by compressed attention map as importance score
@@ -412,7 +433,7 @@ def forward_inference(
         fine_divisible_seq_len = round_up_mult(seq_len, self.selection_block_size)
         remainder = fine_divisible_seq_len - k.shape[-2]

-        sel_fk = pad_at_dim(rotated_k, (0, remainder), dim = -2)
+        sel_fk = pad_at_dim(k, (0, remainder), dim = -2)
         sel_fv = pad_at_dim(v, (0, remainder), dim = -2)

         sel_fk = rearrange(sel_fk, 'b h (w j) d -> b h w j d', j = self.selection_block_size)
@@ -438,7 +459,7 @@ def forward_inference(

         # remove later

-        fq = rearrange(rotated_q, 'b (h gh) ... -> b h gh ...', gh = self.num_grouped_queries)
+        fq = rearrange(q, 'b (h gh) ... -> b h gh ...', gh = self.num_grouped_queries)

         fsim = einsum(fq, fk, 'b h gh i d, b h j d -> b h gh i j') * scale

@@ -524,12 +545,15 @@ def forward(
         k_compress_input = self.split_compress_window(k[..., :compress_divisible_seq_len, :] + k_pos)
         v_compress_input = self.split_compress_window(v[..., :compress_divisible_seq_len, :] + v_pos)

+        run_k = k[..., compress_divisible_seq_len:, :]
+        run_v = v[..., compress_divisible_seq_len:, :]
+
         cq = q
         ck = self.k_compress(k_compress_input) # Equation (7) of the Native Sparse Attention paper
         cv = self.v_compress(v_compress_input)

         if return_cache:
-            cache_compressed_kv = (ck, cv)
+            cache_compressed_kv = ((ck, cv), (run_k, run_v))

         # 1. coarse attention over compressed

@@ -549,7 +573,6 @@ def forward(
         compressed_attn_out, csim = attend(cq, ck, cv, mask = cmask, return_sim = True)

         # for 2. and 3., will give them relative positions with rotary - compressed needs to be handled separately (even if they already have intra block absolute positions)
-
         rotated_q, rotated_k = self.rotary_emb.rotate_queries_with_cached_keys(q, k)

         # 2. fine attention over selected based on compressed attention logits - variables prepended with `f` stands for the fine attention pathway
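
To make the new bookkeeping concrete, here is a hedged, standalone sketch of the per-token cache update implied by the hunks above. It is not the library's implementation: toy_rotate and toy_compress are placeholders for the rotary embedding and the learned compression modules, and decode_step, compress_block_size, and the tensor shapes are illustrative only. Only the cache nesting and the order of operations (append raw k/v to the running buffers, rotate once at offset cache_len, append the rotated key to the main cache, compress and reset when a block fills) come from the diff.

import torch

def toy_rotate(t, offset = 0):
    # stand-in for self.rotary_emb.rotate_queries_or_keys(t, offset = offset);
    # the real module applies rotary position embedding at absolute position `offset`
    return t

def toy_compress(block):
    # stand-in for the learned k_compress / v_compress over one full compress block
    return block.mean(dim = -2, keepdim = True)

def decode_step(k_new, v_new, cache, compress_block_size = 4):
    (cache_k, cache_v), ((ck, cv), (run_k, run_v)) = cache
    cache_len = cache_k.shape[-2]

    # the compression pathway accumulates UNROTATED keys / values
    run_k = torch.cat((run_k, k_new), dim = -2)
    run_v = torch.cat((run_v, v_new), dim = -2)

    # rotate the new key once at its absolute position, then cache it already rotated
    k_rot = toy_rotate(k_new, offset = cache_len)
    cache_k = torch.cat((cache_k, k_rot), dim = -2)
    cache_v = torch.cat((cache_v, v_new), dim = -2)

    # once a full compress block of raw k / v has accumulated, compress it and reset the running buffers
    if run_k.shape[-2] % compress_block_size == 0:
        ck = torch.cat((ck, toy_compress(run_k)), dim = -2)
        cv = torch.cat((cv, toy_compress(run_v)), dim = -2)
        run_k, run_v = run_k[..., 0:0, :], run_v[..., 0:0, :]

    return (cache_k, cache_v), ((ck, cv), (run_k, run_v))

# usage: start from an empty cache and decode eight tokens
b, h, d = 1, 2, 8
empty = torch.zeros(b, h, 0, d)
cache = ((empty, empty), ((empty, empty), (empty, empty)))

for _ in range(8):
    cache = decode_step(torch.randn(b, h, 1, d), torch.randn(b, h, 1, d), cache)

(k, v), ((ck, cv), (run_k, run_v)) = cache
print(k.shape, ck.shape, run_k.shape)
# torch.Size([1, 2, 8, 8]) torch.Size([1, 2, 2, 8]) torch.Size([1, 2, 0, 8])

The payoff of this arrangement is that cached keys never need to be re-rotated on later steps, while the compression branch still sees position-free keys and values, matching the comment referencing https://arxiv.org/abs/2501.18795 in the diff.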

native_sparse_attention_pytorch/triton_native_sparse_attention.py

Lines changed: 1 addition & 0 deletions
@@ -1094,6 +1094,7 @@ def backward_kernel_one_col_block_causal(
     if begin_m >= seqlen_q:
         dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])
         dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])
+
         backward_store_dk_dv(
             dk_ptrs,
             dv_ptrs,

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "native-sparse-attention-pytorch"
-version = "0.0.69"
+version = "0.0.71"
 description = "Native Sparse Attention"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
