File tree Expand file tree Collapse file tree 1 file changed +1
-8
lines changed Expand file tree Collapse file tree 1 file changed +1
-8
lines changed Original file line number Diff line number Diff line change @@ -306,19 +306,12 @@ defmodule Bumblebee.Text.Generation do
306
306
307
307
output_policy = model_output_policy ( model )
308
308
309
- # TODO: fix Axon.MixedPrecision.cast/2 to not cast integers, to
310
- # match Axon compiler
311
-
312
309
# Cast all float cache tensors to match the model output. This way
313
310
# we make sure the cache we pass as input has the same types as
314
311
# the updated cache returned from the model
315
312
cache =
316
313
Bumblebee.Utils.Nx . map ( cache , fn tensor ->
317
- if Nx.Type . integer? ( Nx . type ( tensor ) ) do
318
- tensor
319
- else
320
- Axon.MixedPrecision . cast ( output_policy , tensor , :output )
321
- end
314
+ Axon.MixedPrecision . cast ( output_policy , tensor , :output )
322
315
end )
323
316
324
317
Map . put ( inputs , "cache" , cache )
You can’t perform that action at this time.
0 commit comments