Skip to content

Commit f4d28f8

Browse files
committed
resolve mem_ck and mem_cv not being used during inference
1 parent e901e73 commit f4d28f8

File tree

2 files changed

+12
-3
lines changed

2 files changed

+12
-3
lines changed

native_sparse_attention_pytorch/native_sparse_attention.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -419,8 +419,17 @@ def forward_inference(
419419
ck = cache_ck
420420
cv = cache_cv
421421

422-
repeated_ck = repeat(ck, 'b h ... -> b (h gh) ...', gh = self.num_grouped_queries)
423-
repeated_cv = repeat(cv, 'b h ... -> b (h gh) ...', gh = self.num_grouped_queries)
422+
ck_for_attn = cache_ck
423+
cv_for_attn = cache_cv
424+
425+
if not is_empty(ck):
426+
mem_ck, mem_cv = repeat(self.compress_mem_kv, 'kv ... -> kv b ...', b = batch)
427+
428+
ck_for_attn = cat((mem_ck, ck_for_attn), dim = -2)
429+
cv_for_attn = cat((mem_cv, cv_for_attn), dim = -2)
430+
431+
repeated_ck = repeat(ck_for_attn, 'b h ... -> b (h gh) ...', gh = self.num_grouped_queries)
432+
repeated_cv = repeat(cv_for_attn, 'b h ... -> b (h gh) ...', gh = self.num_grouped_queries)
424433

425434
csim = einsum(q, repeated_ck, 'b h i d, b h j d -> b h i j') * scale
426435
cattn = csim.softmax(dim = -1)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "native-sparse-attention-pytorch"
3-
version = "0.1.16"
3+
version = "0.1.18"
44
description = "Native Sparse Attention"
55
authors = [
66
{ name = "Phil Wang", email = "lucidrains@gmail.com" }

0 commit comments

Comments
 (0)