Skip to content

Commit 2dfad85

Browse files
committed
Allow string pairs as input to text classification serving
1 parent 65bc636 commit 2dfad85

File tree

3 files changed

+10
-2
lines changed

3 files changed

+10
-2
lines changed

lib/bumblebee/shared.ex

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,14 @@ defmodule Bumblebee.Shared do
210210
end
211211
end
212212

213+
def validate_string_or_pairs(input) do
214+
case input do
215+
input when is_binary(input) -> {:ok, input}
216+
{left, right} when is_binary(left) and is_binary(right) -> {:ok, input}
217+
_other -> {:error, "expected a string or a pair of strings, got: #{inspect(input)}"}
218+
end
219+
end
220+
213221
@doc """
214222
Validates that the input is a single value and not a batch.
215223
"""

lib/bumblebee/text.ex

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,7 @@ defmodule Bumblebee.Text do
288288
defdelegate translation(model_info, tokenizer, generation_config, opts \\ []),
289289
to: Bumblebee.Text.Translation
290290

291-
@type text_classification_input :: String.t()
291+
@type text_classification_input :: String.t() | {String.t(), String.t()}
292292
@type text_classification_output :: %{predictions: list(text_classification_prediction())}
293293
@type text_classification_prediction :: %{score: number(), label: String.t()}
294294

lib/bumblebee/text/text_classification.ex

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ defmodule Bumblebee.Text.TextClassification do
7474
|> Nx.Serving.batch_size(batch_size)
7575
|> Nx.Serving.process_options(batch_keys: batch_keys)
7676
|> Nx.Serving.client_preprocessing(fn input ->
77-
{texts, multi?} = Shared.validate_serving_input!(input, &Shared.validate_string/1)
77+
{texts, multi?} = Shared.validate_serving_input!(input, &Shared.validate_string_or_pairs/1)
7878

7979
inputs =
8080
Nx.with_default_backend(Nx.BinaryBackend, fn ->

0 commit comments

Comments
 (0)