From c7fcab12dce01250240c13c0ef15a60086a34284 Mon Sep 17 00:00:00 2001 From: Emanuel Ruzak <48877259+emparu@users.noreply.github.com> Date: Wed, 1 Oct 2025 18:25:58 -0300 Subject: [PATCH] [FIX] Prevent TypeError in text-only Gemma3CausalLM and improve generate_step defaults 1. Bug Fix: TypeError during generation for text-only models The Problem: When using a Gemma3CausalLM model configured for text-only processing (i.e., with vision_encoder=None and preprocessor=None), a call to causal_lm.generate() fails with a TypeError. The root cause is that the internal generate_step method returns a dictionary containing an 'images': None key-value pair. This None value is eventually passed to ops.concatenate during the output normalization step, which does not accept None as a valid input. This workflow is common when pretraining a model from scratch. The Fix: The generate_step method has been modified to only include the 'images' key in its returned dictionary if an image tensor is actually present. This ensures that a None value is never passed to downstream functions, resolving the TypeError. Proof of Bug and Fix: The following Colab notebook demonstrates the bug with the original code and shows the successful execution after applying this fix: https://colab.research.google.com/drive/1QVk2idB6fcdYYJb1cBQGaKHe5QSGjCti?usp=sharing 2. Refactoring: Remove Hardcoded Stop Token The Problem: The internal generate_step method has a hardcoded default stop_token_ids=[106], which corresponds to the token. This is conceptually incorrect for a base architectural model, as the model itself should not have opinions about instruction-following or conversational tokens. This hardcoded value can interfere with pretraining or sampling raw text. The Fix: The method signature has been changed from stop_token_ids=[106] to stop_token_ids=None. This is a safe, non-breaking change because the public-facing Gemma3CausalLM.generate() method is already responsible for setting the appropriate stop tokens when a user specifies stop_token_ids="auto". --- keras_hub/src/models/gemma3/gemma3_causal_lm.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/keras_hub/src/models/gemma3/gemma3_causal_lm.py b/keras_hub/src/models/gemma3/gemma3_causal_lm.py index 8fa2811598..c07a4bdbc6 100644 --- a/keras_hub/src/models/gemma3/gemma3_causal_lm.py +++ b/keras_hub/src/models/gemma3/gemma3_causal_lm.py @@ -227,7 +227,7 @@ def _build_cache( ) return hidden_states, cache - def generate_step(self, inputs, stop_token_ids=[106]): + def generate_step(self, inputs, stop_token_ids=None): """A compilable generation function for a single batch of inputs. This function represents the inner, XLA-compilable, generation function @@ -326,11 +326,14 @@ def next(prompt, cache, index): else: # Without early stopping, all locations will have been updated. padding_mask = ops.ones_like(token_ids, dtype="bool") - return { + output_dict = { "token_ids": token_ids, "padding_mask": padding_mask, - "images": images, } + if images is not None: + output_dict["images"] = images + + return output_dict def generate( self,