
Commit 1b367c6

add a test for transformer inference being same with cache and without
1 parent ea86bb6

2 files changed: +19 -1 lines changed

native_sparse_attention_pytorch/transformer.py

Lines changed: 2 additions & 1 deletion

@@ -75,10 +75,11 @@ def __init__(
         self.norm = RMSNorm(dim)

         self.heads = heads
-        self.kv_heads = default(kv_heads, heads)
+        kv_heads = default(kv_heads, heads)
         dim_inner = heads * dim_head
         dim_kv_inner = kv_heads * dim_head

+        self.kv_heads = kv_heads
         self.causal = causal

         self.rotary_embed = RotaryEmbedding(dim_head)
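The ordering change in this hunk appears to matter because dim_kv_inner on the following line is computed from the local kv_heads name: before the change only self.kv_heads received the defaulted value, so constructing the layer without an explicit kv_heads would have multiplied None by dim_head. A minimal standalone sketch of that pattern (the default helper mirrors the one used in the repository; the function name and argument values here are hypothetical, not repository code):

    # hypothetical sketch of the ordering issue the hunk above addresses
    def default(val, d):
        return val if val is not None else d

    def resolve_dims(heads = 8, dim_head = 64, kv_heads = None):
        # resolve the default into the local name *before* it is read below
        kv_heads = default(kv_heads, heads)

        dim_inner = heads * dim_head
        dim_kv_inner = kv_heads * dim_head  # would raise TypeError if kv_heads were still None
        return dim_inner, dim_kv_inner

    print(resolve_dims())  # (512, 512) with the defaults above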

tests/test_sparse_attn.py

Lines changed: 17 additions & 0 deletions

@@ -82,3 +82,20 @@ def test_inference(
     sequential_out = torch.cat(sequential_out, dim = 1)

     assert torch.allclose(parallel_out, sequential_out, atol = 1e-5)
+
+def test_transformer_inference():
+    from native_sparse_attention_pytorch.transformer import Transformer
+
+    model = Transformer(
+        num_tokens = 256,
+        dim = 512,
+        depth = 2,
+        use_sparse_attn = True
+    )
+
+    prompt = torch.randint(0, 256, (1, 1))
+
+    sampled = model.sample(prompt, 25, temperature = 0., use_cache_kv = False)
+    sampled_cached = model.sample(prompt, 25, temperature = 0., use_cache_kv = True)
+
+    assert torch.allclose(sampled, sampled_cached)
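Because the test samples with temperature = 0., decoding is greedy and deterministic, so the cached and uncached paths must produce identical token ids and any drift in the KV-cache code shows up directly as a mismatched sequence. A generalized sketch of the same parity check (assert_cache_parity is a hypothetical helper, not part of the repository; it assumes only the sample(prompt, length, temperature, use_cache_kv) interface exercised above):

    # hypothetical helper generalizing the parity check in the new test
    import torch

    def assert_cache_parity(model, vocab_size = 256, decode_len = 25):
        prompt = torch.randint(0, vocab_size, (1, 1))

        uncached = model.sample(prompt, decode_len, temperature = 0., use_cache_kv = False)
        cached = model.sample(prompt, decode_len, temperature = 0., use_cache_kv = True)

        assert torch.equal(uncached, cached), 'cached decoding diverged from uncached decoding'

Assuming pytest is the runner for the tests/ directory, the new case can also be run on its own with pytest tests/test_sparse_attn.py::test_transformer_inference.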
