2 parents 5355ccb + 20bf44a commit e7a867c
native_sparse_attention_pytorch/transformer.py
@@ -273,7 +273,10 @@ def forward(

         # token embedding

-        tokens = self.token_emb(ids)
+        if is_inferencing:
+            tokens = self.token_emb(ids[:, -1:])
+        else:
+            tokens = self.token_emb(ids)

         # prepare maybe flex attention masks

@@ -298,9 +301,6 @@ def forward(

         next_cache = []

-        if is_inferencing:
-            tokens = tokens[:, -1:]
-
         # layers

         for attn, ff in self.layers:
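
The change moves the inference-time slicing ahead of the embedding lookup: when decoding with a cache, only the newest token id is embedded, instead of embedding the full sequence and discarding all but the last position afterward. Below is a minimal sketch of that behavior, not the repo's actual module; the embedding layer, vocabulary size, and dimensions are placeholder assumptions.

import torch
import torch.nn as nn

# hypothetical embedding layer standing in for self.token_emb
token_emb = nn.Embedding(256, 512)

def embed_tokens(ids: torch.Tensor, is_inferencing: bool) -> torch.Tensor:
    if is_inferencing:
        # only the last position is needed once earlier tokens are cached
        return token_emb(ids[:, -1:])   # shape: (batch, 1, dim)
    return token_emb(ids)               # shape: (batch, seq, dim)

ids = torch.randint(0, 256, (2, 10))
assert embed_tokens(ids, is_inferencing=True).shape == (2, 1, 512)
assert embed_tokens(ids, is_inferencing=False).shape == (2, 10, 512)

Slicing the ids before the lookup avoids computing embeddings that would be thrown away during incremental decoding, which is why the later `tokens = tokens[:, -1:]` block becomes redundant and is removed.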