We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 45e7408 commit 29bdfa4Copy full SHA for 29bdfa4
lib/bumblebee/text/m2m100.ex
@@ -437,14 +437,18 @@ defmodule Bumblebee.Text.M2m100 do
437
end
438
439
defnp sinusoidal_position_embedding_impl(position_ids, opts \\ []) do
440
+ position_ids = Nx.vectorize(position_ids, :batch)
441
+
442
size = opts[:size]
443
444
half_size = div(size, 2)
445
base = 10_000
446
range = Nx.iota({half_size}) / (half_size - 1)
447
inv_frequency = 1 / Nx.pow(base, range)
448
angle = Nx.outer(position_ids, inv_frequency)
- Nx.concatenate([Nx.sin(angle), Nx.cos(angle)], axis: -1)
449
+ sin_cos = Nx.concatenate([Nx.sin(angle), Nx.cos(angle)], axis: -1)
450
451
+ Nx.devectorize(sin_cos)
452
453
454
defp decoder(
0 commit comments