
Commit 56cce5d

make it right for now, optimizer later
1 parent 1ba6ec6 commit 56cce5d

2 files changed, +7 −5 lines


native_sparse_attention_pytorch/native_sparse_attention.py

Lines changed: 6 additions & 4 deletions
```diff
@@ -395,8 +395,10 @@ def forward_inference(

         # block causal diagonal

+        rotated_q, rotated_k = self.rotary_emb.rotate_queries_with_cached_keys(q, k)
+
         fine_sliding_window = (seq_len % self.selection_block_size) + 1
-        fk = k[..., -fine_sliding_window:, :]
+        fk = rotated_k[..., -fine_sliding_window:, :]
         fv = v[..., -fine_sliding_window:, :]

         # select out the sparse kv segments as defined by compressed attention map as importance score
```
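For context, the added call leans on the rotary embedding to offset the single-step query against the cached keys. A minimal sketch, assuming the rotary-embedding-torch package and illustrative tensor shapes (not taken from the repo):

```python
import torch
from rotary_embedding_torch import RotaryEmbedding

rotary_emb = RotaryEmbedding(dim = 32)  # rotary applied to part of the head dimension

# one decoding step: a single new query attends over all cached keys
q = torch.randn(1, 8, 1, 64)    # (batch, heads, new tokens, dim_head)
k = torch.randn(1, 8, 512, 64)  # cached keys, including the current token

# q is rotated at an offset equal to the key length minus the query length,
# while k is rotated from position 0, keeping positions consistent at inference
rotated_q, rotated_k = rotary_emb.rotate_queries_with_cached_keys(q, k)
```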
```diff
@@ -410,7 +412,7 @@ def forward_inference(
         fine_divisible_seq_len = round_up_mult(seq_len, self.selection_block_size)
         remainder = fine_divisible_seq_len - k.shape[-2]

-        sel_fk = pad_at_dim(k, (0, remainder), dim = -2)
+        sel_fk = pad_at_dim(rotated_k, (0, remainder), dim = -2)
         sel_fv = pad_at_dim(v, (0, remainder), dim = -2)

         sel_fk = rearrange(sel_fk, 'b h (w j) d -> b h w j d', j = self.selection_block_size)
```
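The selection branch pads the rotated keys up to a multiple of the selection block size before splitting them into blocks. A rough sketch of that step under stated assumptions: `round_up_mult` and `pad_at_dim` below are hypothetical stand-ins for the repo's helpers, and the sizes are made up.

```python
import torch
import torch.nn.functional as F
from einops import rearrange

def round_up_mult(n, mult):
    # smallest multiple of `mult` that is >= n
    return ((n + mult - 1) // mult) * mult

def pad_at_dim(t, pad, dim = -2, value = 0.):
    # zero-pad a single dimension; F.pad counts dimensions from the last one backwards
    dims_from_right = (-dim - 1) if dim < 0 else (t.ndim - dim - 1)
    zeros = (0, 0) * dims_from_right
    return F.pad(t, (*zeros, *pad), value = value)

selection_block_size = 64
rotated_k = torch.randn(1, 8, 500, 64)  # (batch, kv heads, seq, dim_head)

fine_divisible_seq_len = round_up_mult(500, selection_block_size)  # 512
remainder = fine_divisible_seq_len - rotated_k.shape[-2]           # 12

sel_fk = pad_at_dim(rotated_k, (0, remainder), dim = -2)                           # (1, 8, 512, 64)
sel_fk = rearrange(sel_fk, 'b h (w j) d -> b h w j d', j = selection_block_size)   # (1, 8, 8, 64, 64)
```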
```diff
@@ -430,7 +432,7 @@ def forward_inference(

         # remove later

-        fq = rearrange(q, 'b (h gh) ... -> b h gh ...', gh = self.num_grouped_queries)
+        fq = rearrange(rotated_q, 'b (h gh) ... -> b h gh ...', gh = self.num_grouped_queries)

         fsim = einsum(fq, fk, 'b h gh i d, b h j d -> b h gh i j') * scale

```
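The fine branch uses grouped query attention, with several query heads sharing each key/value head, hence the rearrange before the similarity einsum. A small illustrative example with made-up sizes (names mirror the diff, tensors are random):

```python
import torch
from einops import rearrange, einsum

num_grouped_queries = 4
scale = 64 ** -0.5

rotated_q = torch.randn(1, 8, 1, 64)  # 8 query heads, one new token
fk = torch.randn(1, 2, 16, 64)        # 2 kv heads, 16 fine keys

# fold the 8 query heads into (2 kv heads x 4 grouped queries)
fq = rearrange(rotated_q, 'b (h gh) ... -> b h gh ...', gh = num_grouped_queries)  # (1, 2, 4, 1, 64)

# each group of queries scores against its own kv head's keys
fsim = einsum(fq, fk, 'b h gh i d, b h j d -> b h gh i j') * scale                 # (1, 2, 4, 1, 16)
```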

```diff
@@ -457,7 +459,7 @@ def forward_inference(

         strategy_weighted_combine = self.to_strategy_combine(inp)

-        out = einsum(strategy_weighted_combine, stack([compressed_attn_out, compressed_attn_out, sliding_window_attn_out]), 'b h n s, s b h n d -> b h n d')
+        out = einsum(strategy_weighted_combine, stack([compressed_attn_out, fine_attn_out, sliding_window_attn_out]), 'b h n s, s b h n d -> b h n d')

         # merge heads and combine them

```
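This last hunk is the actual fix: the stack previously repeated compressed_attn_out and never fed the fine branch into the gate, so the fine (selection) output was silently dropped at inference. A self-contained sketch of the corrected combine, with random tensors standing in for the three branch outputs and a softmax standing in for self.to_strategy_combine:

```python
import torch
from einops import einsum

b, h, n, d = 1, 8, 1, 64

compressed_attn_out     = torch.randn(b, h, n, d)
fine_attn_out           = torch.randn(b, h, n, d)
sliding_window_attn_out = torch.randn(b, h, n, d)

# per head and position, a 3-way gate over the three attention strategies
strategy_weighted_combine = torch.randn(b, h, n, 3).softmax(dim = -1)

out = einsum(
    strategy_weighted_combine,
    torch.stack([compressed_attn_out, fine_attn_out, sliding_window_attn_out]),
    'b h n s, s b h n d -> b h n d'
)  # (1, 8, 1, 64)
```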

pyproject.toml

Lines changed: 1 addition & 1 deletion
```diff
@@ -1,6 +1,6 @@
 [project]
 name = "native-sparse-attention-pytorch"
-version = "0.0.63"
+version = "0.0.64"
 description = "Native Sparse Attention"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
```

0 commit comments
