Skip to content

Commit 29bdfa4

Browse files
committed
Fix M2M100 with batched input
1 parent 45e7408 commit 29bdfa4

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

lib/bumblebee/text/m2m100.ex

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -437,14 +437,18 @@ defmodule Bumblebee.Text.M2m100 do
437437
end
438438

439439
defnp sinusoidal_position_embedding_impl(position_ids, opts \\ []) do
440+
position_ids = Nx.vectorize(position_ids, :batch)
441+
440442
size = opts[:size]
441443

442444
half_size = div(size, 2)
443445
base = 10_000
444446
range = Nx.iota({half_size}) / (half_size - 1)
445447
inv_frequency = 1 / Nx.pow(base, range)
446448
angle = Nx.outer(position_ids, inv_frequency)
447-
Nx.concatenate([Nx.sin(angle), Nx.cos(angle)], axis: -1)
449+
sin_cos = Nx.concatenate([Nx.sin(angle), Nx.cos(angle)], axis: -1)
450+
451+
Nx.devectorize(sin_cos)
448452
end
449453

450454
defp decoder(

0 commit comments

Comments
 (0)