Commit e7a867c

Merge pull request #24 from Pasewark/token_inference_fix
Small change so token embeddings aren't looked up for past tokens during inference
2 parents: 5355ccb + 20bf44a


native_sparse_attention_pytorch/transformer.py

Lines changed: 4 additions & 4 deletions
@@ -273,7 +273,10 @@ def forward(
 
         # token embedding
 
-        tokens = self.token_emb(ids)
+        if is_inferencing:
+            tokens = self.token_emb(ids[:, -1:])
+        else:
+            tokens = self.token_emb(ids)
 
         # prepare maybe flex attention masks
 
@@ -298,9 +301,6 @@ def forward(
 
         next_cache = []
 
-        if is_inferencing:
-            tokens = tokens[:, -1:]
-
         # layers
 
         for attn, ff in self.layers:
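
Why this helps: before the change, `forward` embedded every id in the running sequence and only later sliced the result down to the newest position, so during cached inference the embedding table was gathered for all past tokens just to discard them. Moving the slice before the lookup makes the embedding gather constant-size per decoding step instead of growing with the sequence. A minimal sketch of the difference (the embedding sizes and shapes here are illustrative, not taken from the repo's actual Transformer configuration):

import torch
import torch.nn as nn

# illustrative sizes, not the repository's actual hyperparameters
token_emb = nn.Embedding(num_embeddings = 256, embedding_dim = 512)

ids = torch.randint(0, 256, (1, 1024))  # prompt plus tokens generated so far

# before: gather embeddings for all 1024 positions, then keep only the last
tokens_old = token_emb(ids)[:, -1:]      # (1, 1, 512), after a wasted (1, 1024, 512) lookup

# after: slice the ids first, so only the newest token is looked up
tokens_new = token_emb(ids[:, -1:])      # (1, 1, 512) directly

assert torch.equal(tokens_old, tokens_new)

The two paths are numerically identical, which is why the later `tokens = tokens[:, -1:]` slice can simply be deleted; the saving only applies during inference, since training still needs embeddings for every position.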
