Skip to content

Commit 65bc636

Browse files
committed
Remove TODO
1 parent 8c489cb commit 65bc636

File tree

1 file changed

+1
-8
lines changed

1 file changed

+1
-8
lines changed

lib/bumblebee/text/generation.ex

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -306,19 +306,12 @@ defmodule Bumblebee.Text.Generation do
306306

307307
output_policy = model_output_policy(model)
308308

309-
# TODO: fix Axon.MixedPrecision.cast/2 to not cast integers, to
310-
# match Axon compiler
311-
312309
# Cast all float cache tensors to match the model output. This way
313310
# we make sure the cache we pass as input has the same types as
314311
# the updated cache returned from the model
315312
cache =
316313
Bumblebee.Utils.Nx.map(cache, fn tensor ->
317-
if Nx.Type.integer?(Nx.type(tensor)) do
318-
tensor
319-
else
320-
Axon.MixedPrecision.cast(output_policy, tensor, :output)
321-
end
314+
Axon.MixedPrecision.cast(output_policy, tensor, :output)
322315
end)
323316

324317
Map.put(inputs, "cache", cache)

0 commit comments

Comments
 (0)