Commit dd52171

Allow gemma to take in flat string inputs
1 parent db7b7a4 commit dd52171
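
For context, a minimal sketch of what this commit enables. The setup below is illustrative: the preset name is an assumption, not something this diff specifies.

```python
import keras_hub

# Hypothetical setup; the preset name is an assumption for illustration.
preprocessor = keras_hub.models.Gemma3CausalLMPreprocessor.from_preset(
    "gemma3_instruct_1b"
)

# Before this commit, the preprocessor required a dict input with explicit
# prompt/response keys:
x, y, sample_weight = preprocessor(
    {"prompts": ["the quick brown fox"], "responses": ["jumped."]}
)

# After it, a flat string (or batch of strings) is accepted too; the whole
# string is treated as the trainable response segment.
x, y, sample_weight = preprocessor(["the quick brown fox jumped."])
```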

File tree: 2 files changed, +38 −15 lines

keras_hub/src/models/gemma3/gemma3_causal_lm_preprocessor.py

Lines changed: 22 additions & 15 deletions
@@ -511,25 +511,30 @@ def call(
         # === Input extraction and validation ===
 
         # Extract text part of the input.
-        prompts, responses = x["prompts"], x["responses"]
-        tf.debugging.assert_shapes([(prompts, ("N",)), (responses, ("N",))])
+        if isinstance(x, dict):
+            prompts, responses = x["prompts"], x["responses"]
+            tf.debugging.assert_shapes([(prompts, ("N",)), (responses, ("N",))])
+            images = x.get("images", None)
+            has_prompt = True
+        else:
+            responses = tf.convert_to_tensor(x)
+            prompts = None
+            images = None
+            has_prompt = False
 
         # Find out if the input is batched/not batched. Uprank if not batched.
         # In other preprocessors, we don't have to do this, but here, all
         # the following logic (indices, etc.) uses tensors with a batch dim.
         # We will squeeze these back at the end.
         batched = True
-        if isinstance(prompts, str):
+        if isinstance(responses, str):
             batched = False
-            prompts = [prompts]
+            prompts = [prompts] if has_prompt else None
             responses = [responses]
-        if isinstance(prompts, tf.Tensor) and len(prompts.shape) == 0:
+        if isinstance(responses, tf.Tensor) and len(responses.shape) == 0:
             batched = False
-            prompts = tf.expand_dims(prompts, axis=0)
-            responses = tf.expand_dims(responses, axis=0)
-
-        # Extract images from the input.
-        images = x.get("images", None)
+            prompts = prompts[None] if has_prompt else None
+            responses = responses[None]
 
         # There are 8 cases, based on values of
         # a = `self.text_only_model`, b = `images` is `None`, and whether
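
A side note on the uprank change in this hunk: indexing a tensor with `None` prepends a batch axis and is equivalent to the `tf.expand_dims(..., axis=0)` it replaces; the conditional form is needed because `prompts` may now be `None`. A standalone TensorFlow check (illustrative, not from this repo):

```python
import tensorflow as tf

# A flat string input becomes a rank-0 string tensor.
scalar = tf.convert_to_tensor("the quick brown fox")
assert scalar.shape.rank == 0

# `x[None]` and `tf.expand_dims(x, axis=0)` both produce a (1,)-shaped
# batch of one string.
a = scalar[None]
b = tf.expand_dims(scalar, axis=0)
assert a.shape == b.shape == (1,)
```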
@@ -563,18 +568,20 @@ def call(
         # === Tokenization, padding, etc. ===
 
         # Tokenise the inputs.
-        prompts = self.tokenizer(prompts)
-        responses = self.tokenizer(responses)
+        if has_prompt:
+            segments = (self.tokenizer(prompts), self.tokenizer(responses))
+        else:
+            segments = (self.tokenizer(responses),)
 
         # Padding.
         token_ids, segment_ids = self.packer(
-            (prompts, responses),
+            segments,
             sequence_length=sequence_length + 1,
             add_start_value=self.add_start_token,
             add_end_value=self.add_end_token,
         )
-        response_mask = segment_ids == 1
         padding_mask = token_ids != self.tokenizer.pad_token_id
+        response_mask = segment_ids == 1 if has_prompt else padding_mask
 
         # === Text Model ===
         if self.text_only_model:
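
Why the `response_mask` fallback above: with a `(prompts, responses)` pair, the packer assigns segment id 0 to prompt tokens and 1 to response tokens, so `segment_ids == 1` selects the trainable tokens. With a single segment, every real token gets segment id 0, so the mask must instead come from `padding_mask`. A rough illustration with hand-written ids (not this repo's packer output):

```python
import numpy as np

# Two segments (prompt + response): prompt tokens get segment id 0,
# response tokens get 1, padding is 0.
segment_ids = np.array([[0, 0, 0, 1, 1, 0, 0]])
token_ids = np.array([[1, 9, 14, 10, 2, 0, 0]])  # 0 assumed as pad id
padding_mask = token_ids != 0

# Dict input: only the response tokens are trainable.
response_mask = segment_ids == 1

# Flat string input: all real tokens share segment id 0, so the entire
# non-padded sequence is treated as the response.
response_mask = padding_mask
```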
@@ -600,7 +607,7 @@ def call(
 
         # === Vision processing ===
 
-        batch_size = tf.shape(prompts)[0]
+        batch_size = tf.shape(responses)[0]
         desired_height = self.image_converter.image_size[0]
         desired_width = self.image_converter.image_size[1]
         if images is None:

keras_hub/src/models/gemma3/gemma3_causal_lm_preprocessor_test.py

Lines changed: 16 additions & 0 deletions
@@ -65,6 +65,22 @@ def test_text_preprocessor_basics(self):
             ),
         )
 
+    def test_text_preprocessor_single_string_input(self):
+        input_data = ["the quick brown fox"]
+        self.run_preprocessing_layer_test(
+            cls=Gemma3CausalLMPreprocessor,
+            init_kwargs=self.init_text_kwargs,
+            input_data=input_data,
+            expected_output=(
+                {
+                    "token_ids": [[1, 9, 14, 10, 12, 2, 0, 0]],
+                    "padding_mask": [[1, 1, 1, 1, 1, 1, 0, 0]],
+                },
+                [[9, 14, 10, 12, 2, 0, 0, 0]],  # Labels shifted.
+                [[1, 1, 1, 1, 1, 0, 0, 0]],
+            ),
+        )
+
     def test_preprocessor_basics(self):
         input_data = {
             "prompts": ["the quick brown fox <start_of_image>"],
