From 2b78c8af272a9a52c98bb85066d289e6c943e8cf Mon Sep 17 00:00:00 2001 From: Suhana Date: Mon, 22 Sep 2025 10:11:54 +0530 Subject: [PATCH 1/7] Added gemma embedding to the backbone --- .../src/models/gemma3/gemma3_backbone.py | 105 +++++++++++++- .../src/models/gemma3/gemma3_backbone_test.py | 128 ++++++++++++++++++ 2 files changed, 232 insertions(+), 1 deletion(-) diff --git a/keras_hub/src/models/gemma3/gemma3_backbone.py b/keras_hub/src/models/gemma3/gemma3_backbone.py index a65dbd726b..f4ac227b58 100644 --- a/keras_hub/src/models/gemma3/gemma3_backbone.py +++ b/keras_hub/src/models/gemma3/gemma3_backbone.py @@ -1,5 +1,5 @@ import keras -from keras import ops +from keras import ops, layers from keras_hub.src.api_export import keras_hub_export from keras_hub.src.layers.modeling.reversible_embedding import ( @@ -424,3 +424,106 @@ def from_config(cls, config): ) return super().from_config(config) + +# --- ADD/REPLACE WITH THIS CODE --- + +class MeanPooling(layers.Layer): + """ + This layer calculates the mean of the token embeddings, ignoring + padded tokens. + """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.supports_masking = True + + def call(self, inputs): + # --- THIS IS THE CHANGE --- + # Unpack inputs from a single list + sequence_output, padding_mask = inputs + # --- -------------------- --- + + mask = ops.expand_dims( + ops.cast(padding_mask, sequence_output.dtype), axis=-1 + ) + masked_output = sequence_output * mask + sum_embeddings = ops.sum(masked_output, axis=1) + + num_tokens = ops.sum(ops.cast(padding_mask, sequence_output.dtype), axis=1) + num_tokens = ops.expand_dims(num_tokens, axis=1) + + num_tokens = ops.maximum(num_tokens, 1e-9) + + mean_embeddings = sum_embeddings / num_tokens + return mean_embeddings + + def compute_output_shape(self, input_shape): + # Now input_shape will correctly be a list of two shapes + sequence_output_shape, padding_mask_shape = input_shape + return (sequence_output_shape[0], sequence_output_shape[2]) + + +@keras_hub_export("keras_hub.models.Gemma3EmbeddingModel") +class Gemma3EmbeddingModel(keras.Model): + """ + A Gemma3 model for generating sequence embeddings. + (Docstring...) 
+ """ + + def __init__( + self, + backbone: Gemma3Backbone, + embedding_dim: int, + normalize: bool = True, + **kwargs, + ): + # === Layers === + self.backbone = backbone + self.pooling_layer = MeanPooling( + dtype=backbone.dtype, name="mean_pooling" + ) + self.projection_layer = layers.Dense( + embedding_dim, dtype=backbone.dtype, name="embedding_projection" + ) + self.normalize_output = normalize + if self.normalize_output: + self.normalization_layer = layers.UnitNormalization( + axis=-1, dtype=backbone.dtype, name="l2_normalization" # <-- FIX + ) + + # === Functional Model === + inputs = self.backbone.input + sequence_output = self.backbone.outputs[0] + padding_mask = inputs["padding_mask"] + + # --- THIS IS THE CHANGE --- + # Pass the inputs as a single list + pooled_output = self.pooling_layer([sequence_output, padding_mask]) + # --- -------------------- --- + + embedding = self.projection_layer(pooled_output) + + if self.normalize_output: + embedding = self.normalization_layer(embedding) + + super().__init__(inputs=inputs, outputs=embedding, **kwargs) + + # === Config === + self.embedding_dim = embedding_dim + self.normalize = normalize + + def get_config(self): + config = super().get_config() + config.update( + { + "backbone": keras.layers.serialize(self.backbone), + "embedding_dim": self.embedding_dim, + "normalize": self.normalize, + } + ) + return config + + @classmethod + def from_config(cls, config): + config["backbone"] = keras.layers.deserialize(config["backbone"]) + return cls(**config) diff --git a/keras_hub/src/models/gemma3/gemma3_backbone_test.py b/keras_hub/src/models/gemma3/gemma3_backbone_test.py index 7eb31f9ff6..f70a275841 100644 --- a/keras_hub/src/models/gemma3/gemma3_backbone_test.py +++ b/keras_hub/src/models/gemma3/gemma3_backbone_test.py @@ -9,6 +9,9 @@ from keras_hub.src.models.gemma3.gemma3_vision_encoder import ( Gemma3VisionEncoder, ) +from keras_hub.src.models.gemma3.gemma3_backbone import ( + Gemma3EmbeddingModel, +) from keras_hub.src.tests.test_case import TestCase @@ -193,3 +196,128 @@ def test_all_presets(self): if "_text" in preset or "1b" in preset else self.input_data, ) + +class Gemma3EmbeddingModelTest(TestCase, parameterized.TestCase): + def setUp(self): + self.batch_size = 2 + self.vocabulary_size = 256 + self.text_sequence_length = 64 + self.hidden_dim = 8 + self.embedding_dim = 16 + + self.backbone = Gemma3Backbone( + vocabulary_size=self.vocabulary_size, + image_size=16, + num_layers=2, + num_query_heads=2, + num_key_value_heads=1, + hidden_dim=self.hidden_dim, + intermediate_dim=32, + head_dim=4, + vision_encoder=None, + ) + + self.init_kwargs = { + "backbone": self.backbone, + "embedding_dim": self.embedding_dim, + "normalize": True, + } + + self.init_kwargs_no_norm = { + "backbone": self.backbone, + "embedding_dim": self.embedding_dim, + "normalize": False, + } + + dummy_text_token_ids = np.random.randint( + 0, + self.vocabulary_size, + (self.batch_size, self.text_sequence_length), + ) + padding_mask = np.ones( + (self.batch_size, self.text_sequence_length), dtype="int32" + ) + padding_mask[0, -10:] = 0 + padding_mask[1, -5:] = 0 + + self.input_data = { + "token_ids": dummy_text_token_ids, + "padding_mask": padding_mask, + } + + def test_model_basics(self): + """Test the model's forward pass and output shape.""" + model = Gemma3EmbeddingModel(**self.init_kwargs) + output = model(self.input_data) + expected_output_shape = (self.batch_size, self.embedding_dim) + self.assertEqual(output.shape, expected_output_shape) + + 
@parameterized.named_parameters( + ("normalize", True, 9), + ("no_normalize", False, 8), + ) + def test_architecture_characteristics(self, normalize, num_layers): + """Test parameter and layer counts.""" + init_kwargs = self.init_kwargs if normalize else self.init_kwargs_no_norm + model = Gemma3EmbeddingModel(**init_kwargs) + + backbone_params = self.backbone.count_params() + projection_params = ( + self.hidden_dim * self.embedding_dim + ) + self.embedding_dim + expected_params = backbone_params + projection_params + + self.assertEqual(model.count_params(), expected_params) + self.assertEqual(len(model.layers), num_layers) + + def test_normalization(self): + """Test that the `normalize` flag works correctly.""" + model_norm = Gemma3EmbeddingModel(**self.init_kwargs) + outputs_norm = model_norm(self.input_data) + + norms_squared = ops.sum(ops.square(outputs_norm), axis=1) + norms = ops.sqrt(norms_squared) + + self.assertAllClose(norms, ops.ones(self.batch_size), atol=1e-5) + + model_no_norm = Gemma3EmbeddingModel(**self.init_kwargs_no_norm) + outputs_no_norm = model_no_norm(self.input_data) + + norms_no_norm_squared = ops.sum(ops.square(outputs_no_norm), axis=1) + norms_no_norm = ops.sqrt(norms_no_norm_squared) + + self.assertNotAllClose(norms_no_norm, ops.ones(self.batch_size)) + + @parameterized.named_parameters( + ("normalize", True), + ("no_normalize", False), + ) + def test_saved_model(self, normalize): + init_kwargs = self.init_kwargs if normalize else self.init_kwargs_no_norm + + self.run_model_saving_test( + cls=Gemma3EmbeddingModel, + init_kwargs=init_kwargs, + input_data=self.input_data, + ) + + @pytest.mark.kaggle_key_required + @pytest.mark.extra_large + def test_build_from_preset_backbone(self): + backbone = Gemma3Backbone.from_preset("gemma3_instruct_1b_text") + model = Gemma3EmbeddingModel( + backbone=backbone, + embedding_dim=768, + normalize=True, + ) + + input_data = { + "token_ids": ops.array([[651, 4320, 8426, 25341, 235265]]), + "padding_mask": ops.ones((1, 5), dtype="int32"), + } + + outputs = model(input_data) + + self.assertEqual(outputs.shape, (1, 768)) + norm = ops.vector_norm(outputs, axis=1) + self.assertAllClose(norm, [1.0]) \ No newline at end of file From 36d829dad095b58623973b36eb5656dae0ecef88 Mon Sep 17 00:00:00 2001 From: Suhana Date: Wed, 24 Sep 2025 14:12:24 +0530 Subject: [PATCH 2/7] Refactoring the code --- .../src/models/gemma3/gemma3_backbone.py | 32 ++----------------- 1 file changed, 3 insertions(+), 29 deletions(-) diff --git a/keras_hub/src/models/gemma3/gemma3_backbone.py b/keras_hub/src/models/gemma3/gemma3_backbone.py index f4ac227b58..7ab48abf69 100644 --- a/keras_hub/src/models/gemma3/gemma3_backbone.py +++ b/keras_hub/src/models/gemma3/gemma3_backbone.py @@ -425,7 +425,6 @@ def from_config(cls, config): return super().from_config(config) -# --- ADD/REPLACE WITH THIS CODE --- class MeanPooling(layers.Layer): """ @@ -438,10 +437,7 @@ def __init__(self, **kwargs): self.supports_masking = True def call(self, inputs): - # --- THIS IS THE CHANGE --- - # Unpack inputs from a single list sequence_output, padding_mask = inputs - # --- -------------------- --- mask = ops.expand_dims( ops.cast(padding_mask, sequence_output.dtype), axis=-1 @@ -458,26 +454,19 @@ def call(self, inputs): return mean_embeddings def compute_output_shape(self, input_shape): - # Now input_shape will correctly be a list of two shapes sequence_output_shape, padding_mask_shape = input_shape return (sequence_output_shape[0], sequence_output_shape[2]) 
-@keras_hub_export("keras_hub.models.Gemma3EmbeddingModel") +@keras_hub_export("keras_hub.models.Gemma3Embedding") class Gemma3EmbeddingModel(keras.Model): - """ - A Gemma3 model for generating sequence embeddings. - (Docstring...) - """ def __init__( self, backbone: Gemma3Backbone, embedding_dim: int, - normalize: bool = True, **kwargs, ): - # === Layers === self.backbone = backbone self.pooling_layer = MeanPooling( dtype=backbone.dtype, name="mean_pooling" @@ -485,32 +474,18 @@ def __init__( self.projection_layer = layers.Dense( embedding_dim, dtype=backbone.dtype, name="embedding_projection" ) - self.normalize_output = normalize - if self.normalize_output: - self.normalization_layer = layers.UnitNormalization( - axis=-1, dtype=backbone.dtype, name="l2_normalization" # <-- FIX - ) - # === Functional Model === - inputs = self.backbone.input + inputs = self.backbone.input sequence_output = self.backbone.outputs[0] - padding_mask = inputs["padding_mask"] + padding_mask = inputs["padding_mask"] - # --- THIS IS THE CHANGE --- - # Pass the inputs as a single list pooled_output = self.pooling_layer([sequence_output, padding_mask]) - # --- -------------------- --- embedding = self.projection_layer(pooled_output) - if self.normalize_output: - embedding = self.normalization_layer(embedding) - super().__init__(inputs=inputs, outputs=embedding, **kwargs) - # === Config === self.embedding_dim = embedding_dim - self.normalize = normalize def get_config(self): config = super().get_config() @@ -518,7 +493,6 @@ def get_config(self): { "backbone": keras.layers.serialize(self.backbone), "embedding_dim": self.embedding_dim, - "normalize": self.normalize, } ) return config From d02bc438ef402ef2407275d78bef0114c964bf69 Mon Sep 17 00:00:00 2001 From: Suhana Date: Wed, 24 Sep 2025 14:16:42 +0530 Subject: [PATCH 3/7] Refactoring the code --- .../src/models/gemma3/gemma3_backbone_test.py | 51 +++---------------- 1 file changed, 8 insertions(+), 43 deletions(-) diff --git a/keras_hub/src/models/gemma3/gemma3_backbone_test.py b/keras_hub/src/models/gemma3/gemma3_backbone_test.py index f70a275841..1528344c65 100644 --- a/keras_hub/src/models/gemma3/gemma3_backbone_test.py +++ b/keras_hub/src/models/gemma3/gemma3_backbone_test.py @@ -220,13 +220,6 @@ def setUp(self): self.init_kwargs = { "backbone": self.backbone, "embedding_dim": self.embedding_dim, - "normalize": True, - } - - self.init_kwargs_no_norm = { - "backbone": self.backbone, - "embedding_dim": self.embedding_dim, - "normalize": False, } dummy_text_token_ids = np.random.randint( @@ -252,14 +245,9 @@ def test_model_basics(self): expected_output_shape = (self.batch_size, self.embedding_dim) self.assertEqual(output.shape, expected_output_shape) - @parameterized.named_parameters( - ("normalize", True, 9), - ("no_normalize", False, 8), - ) - def test_architecture_characteristics(self, normalize, num_layers): + def test_architecture_characteristics(self): """Test parameter and layer counts.""" - init_kwargs = self.init_kwargs if normalize else self.init_kwargs_no_norm - model = Gemma3EmbeddingModel(**init_kwargs) + model = Gemma3EmbeddingModel(**self.init_kwargs) backbone_params = self.backbone.count_params() projection_params = ( @@ -267,37 +255,15 @@ def test_architecture_characteristics(self, normalize, num_layers): ) + self.embedding_dim expected_params = backbone_params + projection_params - self.assertEqual(model.count_params(), expected_params) - self.assertEqual(len(model.layers), num_layers) + expected_layers = 8 - def test_normalization(self): - """Test 
that the `normalize` flag works correctly.""" - model_norm = Gemma3EmbeddingModel(**self.init_kwargs) - outputs_norm = model_norm(self.input_data) - - norms_squared = ops.sum(ops.square(outputs_norm), axis=1) - norms = ops.sqrt(norms_squared) - - self.assertAllClose(norms, ops.ones(self.batch_size), atol=1e-5) - - model_no_norm = Gemma3EmbeddingModel(**self.init_kwargs_no_norm) - outputs_no_norm = model_no_norm(self.input_data) - - norms_no_norm_squared = ops.sum(ops.square(outputs_no_norm), axis=1) - norms_no_norm = ops.sqrt(norms_no_norm_squared) - - self.assertNotAllClose(norms_no_norm, ops.ones(self.batch_size)) - - @parameterized.named_parameters( - ("normalize", True), - ("no_normalize", False), - ) - def test_saved_model(self, normalize): - init_kwargs = self.init_kwargs if normalize else self.init_kwargs_no_norm + self.assertEqual(model.count_params(), expected_params) + self.assertEqual(len(model.layers), expected_layers) + def test_saved_model(self): self.run_model_saving_test( cls=Gemma3EmbeddingModel, - init_kwargs=init_kwargs, + init_kwargs=self.init_kwargs, input_data=self.input_data, ) @@ -308,7 +274,6 @@ def test_build_from_preset_backbone(self): model = Gemma3EmbeddingModel( backbone=backbone, embedding_dim=768, - normalize=True, ) input_data = { @@ -320,4 +285,4 @@ def test_build_from_preset_backbone(self): self.assertEqual(outputs.shape, (1, 768)) norm = ops.vector_norm(outputs, axis=1) - self.assertAllClose(norm, [1.0]) \ No newline at end of file + self.assertGreater(norm[0], 0) \ No newline at end of file From 24f37403b4fb873fc1f3b32b525ec658c32fd45b Mon Sep 17 00:00:00 2001 From: Suhana Date: Mon, 6 Oct 2025 10:48:45 +0530 Subject: [PATCH 4/7] Adding gemma embedding backbone --- keras_hub/api/models/__init__.py | 3 + .../src/models/gemma3/gemma3_backbone.py | 92 ++++++++++++++----- .../src/models/gemma3/gemma3_backbone_test.py | 11 ++- 3 files changed, 80 insertions(+), 26 deletions(-) diff --git a/keras_hub/api/models/__init__.py b/keras_hub/api/models/__init__.py index fe220e2d43..3fdc420c18 100644 --- a/keras_hub/api/models/__init__.py +++ b/keras_hub/api/models/__init__.py @@ -285,6 +285,9 @@ from keras_hub.src.models.gemma3.gemma3_backbone import ( Gemma3Backbone as Gemma3Backbone, ) +from keras_hub.src.models.gemma3.gemma3_backbone import ( + Gemma3EmbeddingModel as Gemma3Embedding, +) from keras_hub.src.models.gemma3.gemma3_causal_lm import ( Gemma3CausalLM as Gemma3CausalLM, ) diff --git a/keras_hub/src/models/gemma3/gemma3_backbone.py b/keras_hub/src/models/gemma3/gemma3_backbone.py index 7ab48abf69..c882901651 100644 --- a/keras_hub/src/models/gemma3/gemma3_backbone.py +++ b/keras_hub/src/models/gemma3/gemma3_backbone.py @@ -1,5 +1,6 @@ import keras -from keras import ops, layers +from keras import layers +from keras import ops from keras_hub.src.api_export import keras_hub_export from keras_hub.src.layers.modeling.reversible_embedding import ( @@ -428,8 +429,29 @@ def from_config(cls, config): class MeanPooling(layers.Layer): """ - This layer calculates the mean of the token embeddings, ignoring - padded tokens. + Mean pooling layer that computes the average of token embeddings, + respecting a padding mask. + + This layer correctly handles variable-length sequences by ignoring + padded tokens in the mean calculation. + + Call arguments: + inputs: A tuple of `(sequence_output, padding_mask)`. + `sequence_output` is a tensor of shape `(batch_size, seq_len, + hidden_dim)`. 
`padding_mask` is a tensor of shape `(batch_size, + seq_len)` with `1` for valid tokens and `0` for padded tokens. + + Returns: + A tensor of shape `(batch_size, hidden_dim)`. + + Example: + ```python + sequence_output = np.random.rand(2, 4, 8).astype("float32") + padding_mask = np.array([[1, 1, 1, 0], [1, 1, 0, 0]]) + mean_pool_layer = MeanPooling() + pooled = mean_pool_layer((sequence_output, padding_mask)) + # pooled.shape -> (2, 8) + ``` """ def __init__(self, **kwargs): @@ -444,8 +466,9 @@ def call(self, inputs): ) masked_output = sequence_output * mask sum_embeddings = ops.sum(masked_output, axis=1) - - num_tokens = ops.sum(ops.cast(padding_mask, sequence_output.dtype), axis=1) + num_tokens = ops.sum( + ops.cast(padding_mask, sequence_output.dtype), axis=1 + ) num_tokens = ops.expand_dims(num_tokens, axis=1) num_tokens = ops.maximum(num_tokens, 1e-9) @@ -457,16 +480,46 @@ def compute_output_shape(self, input_shape): sequence_output_shape, padding_mask_shape = input_shape return (sequence_output_shape[0], sequence_output_shape[2]) + def get_config(self): + return super().get_config() + @keras_hub_export("keras_hub.models.Gemma3Embedding") class Gemma3EmbeddingModel(keras.Model): + """An end-to-end Gemma3 model for embedding tasks. - def __init__( - self, - backbone: Gemma3Backbone, - embedding_dim: int, - **kwargs, - ): + This model takes token ids as input and returns a fixed-size embedding + vector for the input sequence. It uses a `Gemma3Backbone` to generate + contextualized token embeddings, a `MeanPooling` layer to pool them into a + single vector, and a final `Dense` layer to project to the desired + embedding dimension. + + This model can be loaded with a pre-trained `Gemma3Backbone` and used for + tasks like semantic similarity, retrieval, or as a feature extractor. + + Args: + backbone: A `keras_hub.models.Gemma3Backbone` instance. + embedding_dim (int): The dimension of the output embedding. 
+ + Example: + ```python + # backbone = keras_hub.models.Gemma3Backbone.from_preset( + # "gemma3_instruct_1b" + # ) + # embedding_model = keras_hub.models.Gemma3EmbeddingModel( + # backbone=backbone, + # embedding_dim=768, + # ) + # input_data = { + # "token_ids": np.array([[651, 4320, 8426, 25341, 235265]]), + # "padding_mask": np.ones((1, 5), dtype="int32"), + # } + # embeddings = embedding_model.predict(input_data) + ``` + """ + + def __init__(self, backbone, embedding_dim, **kwargs): + super().__init__(**kwargs) self.backbone = backbone self.pooling_layer = MeanPooling( dtype=backbone.dtype, name="mean_pooling" @@ -474,24 +527,21 @@ def __init__( self.projection_layer = layers.Dense( embedding_dim, dtype=backbone.dtype, name="embedding_projection" ) + self.embedding_dim = embedding_dim - inputs = self.backbone.input - sequence_output = self.backbone.outputs[0] + def call(self, inputs): + sequence_output = self.backbone(inputs) padding_mask = inputs["padding_mask"] - pooled_output = self.pooling_layer([sequence_output, padding_mask]) - + pooled_output = self.pooling_layer((sequence_output, padding_mask)) embedding = self.projection_layer(pooled_output) - - super().__init__(inputs=inputs, outputs=embedding, **kwargs) - - self.embedding_dim = embedding_dim + return embedding def get_config(self): config = super().get_config() config.update( { - "backbone": keras.layers.serialize(self.backbone), + "backbone": layers.serialize(self.backbone), "embedding_dim": self.embedding_dim, } ) @@ -499,5 +549,5 @@ def get_config(self): @classmethod def from_config(cls, config): - config["backbone"] = keras.layers.deserialize(config["backbone"]) + config["backbone"] = layers.deserialize(config["backbone"]) return cls(**config) diff --git a/keras_hub/src/models/gemma3/gemma3_backbone_test.py b/keras_hub/src/models/gemma3/gemma3_backbone_test.py index 1528344c65..8e6f3f31ae 100644 --- a/keras_hub/src/models/gemma3/gemma3_backbone_test.py +++ b/keras_hub/src/models/gemma3/gemma3_backbone_test.py @@ -6,12 +6,10 @@ from keras import ops from keras_hub.src.models.gemma3.gemma3_backbone import Gemma3Backbone +from keras_hub.src.models.gemma3.gemma3_backbone import Gemma3EmbeddingModel from keras_hub.src.models.gemma3.gemma3_vision_encoder import ( Gemma3VisionEncoder, ) -from keras_hub.src.models.gemma3.gemma3_backbone import ( - Gemma3EmbeddingModel, -) from keras_hub.src.tests.test_case import TestCase @@ -197,6 +195,7 @@ def test_all_presets(self): else self.input_data, ) + class Gemma3EmbeddingModelTest(TestCase, parameterized.TestCase): def setUp(self): self.batch_size = 2 @@ -249,13 +248,15 @@ def test_architecture_characteristics(self): """Test parameter and layer counts.""" model = Gemma3EmbeddingModel(**self.init_kwargs) + model(self.input_data) + backbone_params = self.backbone.count_params() projection_params = ( self.hidden_dim * self.embedding_dim ) + self.embedding_dim expected_params = backbone_params + projection_params - expected_layers = 8 + expected_layers = 3 self.assertEqual(model.count_params(), expected_params) self.assertEqual(len(model.layers), expected_layers) @@ -285,4 +286,4 @@ def test_build_from_preset_backbone(self): self.assertEqual(outputs.shape, (1, 768)) norm = ops.vector_norm(outputs, axis=1) - self.assertGreater(norm[0], 0) \ No newline at end of file + self.assertGreater(norm[0], 0) From 5f55514a9ca0c372158ea479f37523d7af8ffb64 Mon Sep 17 00:00:00 2001 From: Suhana Date: Mon, 13 Oct 2025 13:01:22 +0530 Subject: [PATCH 5/7] adding gemma3 mean pooling --- 
.../src/models/gemma3/gemma3_backbone.py | 172 +++++------------- .../src/models/gemma3/gemma3_backbone_test.py | 120 +++--------- .../src/models/gemma3/gemma3_mean_pooling.py | 88 +++++++++ 3 files changed, 157 insertions(+), 223 deletions(-) create mode 100644 keras_hub/src/models/gemma3/gemma3_mean_pooling.py diff --git a/keras_hub/src/models/gemma3/gemma3_backbone.py b/keras_hub/src/models/gemma3/gemma3_backbone.py index c882901651..92fc0c7d80 100644 --- a/keras_hub/src/models/gemma3/gemma3_backbone.py +++ b/keras_hub/src/models/gemma3/gemma3_backbone.py @@ -12,6 +12,7 @@ from keras_hub.src.models.gemma3.gemma3_interleave_embeddings import ( Gemma3InterleaveEmbeddings, ) +from keras_hub.src.models.gemma3.gemma3_mean_pooling import MeanPooling @keras_hub_export("keras_hub.models.Gemma3Backbone") @@ -28,6 +29,9 @@ class Gemma3Backbone(Backbone): For a higher-level object for text-generation, see `keras_hub.models.Gemma3CausalLM`. + This backbone can also function as an end-to-end embedding model by + setting the `pooling_mode` argument. + The default constructor gives a fully customizable, randomly initialized Gemma3 model with any vision encoder, number of heads, embedding dimensions, and equivalent configuration for the decoder layers. To load preset @@ -71,6 +75,12 @@ class Gemma3Backbone(Backbone): in all transformer blocks. Defaults to `1e-6`. dropout: float. Dropout probability for the Transformer decoder blocks. Defaults to `0`. + pooling_mode (str, optional): The pooling mode for the final output. + If set to `"mean"`, the model will add a mean pooling and a dense + projection layer to output a fixed-size embedding. In this case, + `embedding_dim` must also be set. Defaults to `None`. + embedding_dim (int, optional): The dimension of the final projected + embedding. Required if `pooling_mode` is `"mean"`. Defaults to `None`. dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use for the models computations and weights. Note that some computations, such as softmax and layer normalization will always @@ -199,6 +209,8 @@ def __init__( layer_norm_epsilon=1e-6, use_bidirectional_attention=False, dropout=0, + pooling_mode=None, + embedding_dim=None, dtype=None, **kwargs, ): @@ -320,6 +332,32 @@ def __init__( ) sequence_output = self.layer_norm(x) + if pooling_mode is not None: + if pooling_mode != "mean": + raise ValueError(f"Unknown pooling mode: {pooling_mode}") + if embedding_dim is None: + raise ValueError( + "`embedding_dim` must be specified when `pooling_mode` is set." + ) + + pooled_output = MeanPooling(dtype=dtype, name="mean_pooling")( + sequence_output=sequence_output, + padding_mask=padding_mask_input, + ) + + pooled_output = layers.Dense( + embedding_dim, + dtype=dtype, + name="embedding_projection", + )(pooled_output) + + outputs = { + "sequence_output": sequence_output, + "pooled_output": pooled_output, + } + else: + outputs = sequence_output + inputs = { "token_ids": token_id_input, "padding_mask": padding_mask_input, @@ -362,7 +400,9 @@ def __init__( self.use_bidirectional_attention = use_bidirectional_attention self.layer_norm_epsilon = layer_norm_epsilon self.dropout = dropout - + self.pooling_mode = pooling_mode + self.embedding_dim = embedding_dim + # Keep `num_vision_tokens_per_image` as a backbone property for easy # access. 
if not text_only_model: @@ -402,6 +442,8 @@ def get_config(self): "use_bidirectional_attention": self.use_bidirectional_attention, "layer_norm_epsilon": self.layer_norm_epsilon, "dropout": self.dropout, + "pooling_mode": self.pooling_mode, + "embedding_dim": self.embedding_dim, } ) return config @@ -424,130 +466,4 @@ def from_config(cls, config): } ) - return super().from_config(config) - - -class MeanPooling(layers.Layer): - """ - Mean pooling layer that computes the average of token embeddings, - respecting a padding mask. - - This layer correctly handles variable-length sequences by ignoring - padded tokens in the mean calculation. - - Call arguments: - inputs: A tuple of `(sequence_output, padding_mask)`. - `sequence_output` is a tensor of shape `(batch_size, seq_len, - hidden_dim)`. `padding_mask` is a tensor of shape `(batch_size, - seq_len)` with `1` for valid tokens and `0` for padded tokens. - - Returns: - A tensor of shape `(batch_size, hidden_dim)`. - - Example: - ```python - sequence_output = np.random.rand(2, 4, 8).astype("float32") - padding_mask = np.array([[1, 1, 1, 0], [1, 1, 0, 0]]) - mean_pool_layer = MeanPooling() - pooled = mean_pool_layer((sequence_output, padding_mask)) - # pooled.shape -> (2, 8) - ``` - """ - - def __init__(self, **kwargs): - super().__init__(**kwargs) - self.supports_masking = True - - def call(self, inputs): - sequence_output, padding_mask = inputs - - mask = ops.expand_dims( - ops.cast(padding_mask, sequence_output.dtype), axis=-1 - ) - masked_output = sequence_output * mask - sum_embeddings = ops.sum(masked_output, axis=1) - num_tokens = ops.sum( - ops.cast(padding_mask, sequence_output.dtype), axis=1 - ) - num_tokens = ops.expand_dims(num_tokens, axis=1) - - num_tokens = ops.maximum(num_tokens, 1e-9) - - mean_embeddings = sum_embeddings / num_tokens - return mean_embeddings - - def compute_output_shape(self, input_shape): - sequence_output_shape, padding_mask_shape = input_shape - return (sequence_output_shape[0], sequence_output_shape[2]) - - def get_config(self): - return super().get_config() - - -@keras_hub_export("keras_hub.models.Gemma3Embedding") -class Gemma3EmbeddingModel(keras.Model): - """An end-to-end Gemma3 model for embedding tasks. - - This model takes token ids as input and returns a fixed-size embedding - vector for the input sequence. It uses a `Gemma3Backbone` to generate - contextualized token embeddings, a `MeanPooling` layer to pool them into a - single vector, and a final `Dense` layer to project to the desired - embedding dimension. - - This model can be loaded with a pre-trained `Gemma3Backbone` and used for - tasks like semantic similarity, retrieval, or as a feature extractor. - - Args: - backbone: A `keras_hub.models.Gemma3Backbone` instance. - embedding_dim (int): The dimension of the output embedding. 
- - Example: - ```python - # backbone = keras_hub.models.Gemma3Backbone.from_preset( - # "gemma3_instruct_1b" - # ) - # embedding_model = keras_hub.models.Gemma3EmbeddingModel( - # backbone=backbone, - # embedding_dim=768, - # ) - # input_data = { - # "token_ids": np.array([[651, 4320, 8426, 25341, 235265]]), - # "padding_mask": np.ones((1, 5), dtype="int32"), - # } - # embeddings = embedding_model.predict(input_data) - ``` - """ - - def __init__(self, backbone, embedding_dim, **kwargs): - super().__init__(**kwargs) - self.backbone = backbone - self.pooling_layer = MeanPooling( - dtype=backbone.dtype, name="mean_pooling" - ) - self.projection_layer = layers.Dense( - embedding_dim, dtype=backbone.dtype, name="embedding_projection" - ) - self.embedding_dim = embedding_dim - - def call(self, inputs): - sequence_output = self.backbone(inputs) - padding_mask = inputs["padding_mask"] - - pooled_output = self.pooling_layer((sequence_output, padding_mask)) - embedding = self.projection_layer(pooled_output) - return embedding - - def get_config(self): - config = super().get_config() - config.update( - { - "backbone": layers.serialize(self.backbone), - "embedding_dim": self.embedding_dim, - } - ) - return config - - @classmethod - def from_config(cls, config): - config["backbone"] = layers.deserialize(config["backbone"]) - return cls(**config) + return super().from_config(config) \ No newline at end of file diff --git a/keras_hub/src/models/gemma3/gemma3_backbone_test.py b/keras_hub/src/models/gemma3/gemma3_backbone_test.py index 8e6f3f31ae..38b7e5cbde 100644 --- a/keras_hub/src/models/gemma3/gemma3_backbone_test.py +++ b/keras_hub/src/models/gemma3/gemma3_backbone_test.py @@ -6,7 +6,6 @@ from keras import ops from keras_hub.src.models.gemma3.gemma3_backbone import Gemma3Backbone -from keras_hub.src.models.gemma3.gemma3_backbone import Gemma3EmbeddingModel from keras_hub.src.models.gemma3.gemma3_vision_encoder import ( Gemma3VisionEncoder, ) @@ -122,6 +121,30 @@ def test_backbone_basics(self, backbone_type): run_quantization_check=False, ) + def test_embedding_mode(self): + embedding_dim = 16 + init_kwargs = self.text_init_kwargs.copy() + input_data = self.text_backbone_input_data.copy() + + init_kwargs["pooling_mode"] = "mean" + init_kwargs["embedding_dim"] = embedding_dim + + model = Gemma3Backbone(**init_kwargs) + outputs = model(input_data) + + self.assertIsInstance(outputs, dict) + self.assertIn("sequence_output", outputs) + self.assertIn("pooled_output", outputs) + + expected_pooled_shape = (self.batch_size, embedding_dim) + self.assertEqual(outputs["pooled_output"].shape, expected_pooled_shape) + + self.run_model_saving_test( + cls=Gemma3Backbone, + init_kwargs=init_kwargs, + input_data=input_data, + ) + @parameterized.named_parameters( ("text_and_vision", "text_and_vision", 7560, 15), ("text_only", "text_only", 5752, 10), @@ -193,97 +216,4 @@ def test_all_presets(self): input_data=self.text_backbone_input_data if "_text" in preset or "1b" in preset else self.input_data, - ) - - -class Gemma3EmbeddingModelTest(TestCase, parameterized.TestCase): - def setUp(self): - self.batch_size = 2 - self.vocabulary_size = 256 - self.text_sequence_length = 64 - self.hidden_dim = 8 - self.embedding_dim = 16 - - self.backbone = Gemma3Backbone( - vocabulary_size=self.vocabulary_size, - image_size=16, - num_layers=2, - num_query_heads=2, - num_key_value_heads=1, - hidden_dim=self.hidden_dim, - intermediate_dim=32, - head_dim=4, - vision_encoder=None, - ) - - self.init_kwargs = { - "backbone": self.backbone, - 
"embedding_dim": self.embedding_dim, - } - - dummy_text_token_ids = np.random.randint( - 0, - self.vocabulary_size, - (self.batch_size, self.text_sequence_length), - ) - padding_mask = np.ones( - (self.batch_size, self.text_sequence_length), dtype="int32" - ) - padding_mask[0, -10:] = 0 - padding_mask[1, -5:] = 0 - - self.input_data = { - "token_ids": dummy_text_token_ids, - "padding_mask": padding_mask, - } - - def test_model_basics(self): - """Test the model's forward pass and output shape.""" - model = Gemma3EmbeddingModel(**self.init_kwargs) - output = model(self.input_data) - expected_output_shape = (self.batch_size, self.embedding_dim) - self.assertEqual(output.shape, expected_output_shape) - - def test_architecture_characteristics(self): - """Test parameter and layer counts.""" - model = Gemma3EmbeddingModel(**self.init_kwargs) - - model(self.input_data) - - backbone_params = self.backbone.count_params() - projection_params = ( - self.hidden_dim * self.embedding_dim - ) + self.embedding_dim - expected_params = backbone_params + projection_params - - expected_layers = 3 - - self.assertEqual(model.count_params(), expected_params) - self.assertEqual(len(model.layers), expected_layers) - - def test_saved_model(self): - self.run_model_saving_test( - cls=Gemma3EmbeddingModel, - init_kwargs=self.init_kwargs, - input_data=self.input_data, - ) - - @pytest.mark.kaggle_key_required - @pytest.mark.extra_large - def test_build_from_preset_backbone(self): - backbone = Gemma3Backbone.from_preset("gemma3_instruct_1b_text") - model = Gemma3EmbeddingModel( - backbone=backbone, - embedding_dim=768, - ) - - input_data = { - "token_ids": ops.array([[651, 4320, 8426, 25341, 235265]]), - "padding_mask": ops.ones((1, 5), dtype="int32"), - } - - outputs = model(input_data) - - self.assertEqual(outputs.shape, (1, 768)) - norm = ops.vector_norm(outputs, axis=1) - self.assertGreater(norm[0], 0) + ) \ No newline at end of file diff --git a/keras_hub/src/models/gemma3/gemma3_mean_pooling.py b/keras_hub/src/models/gemma3/gemma3_mean_pooling.py new file mode 100644 index 0000000000..dfd1f7d2e7 --- /dev/null +++ b/keras_hub/src/models/gemma3/gemma3_mean_pooling.py @@ -0,0 +1,88 @@ +# Copyright 2024 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import keras +from keras import ops + + +class MeanPooling(keras.layers.Layer): + """Mean pooling layer that computes the average of token embeddings. + + This layer correctly handles variable-length sequences by ignoring + padded tokens in the mean calculation, using a `padding_mask`. + + Call arguments: + sequence_output: A tensor of shape `(batch_size, seq_len, hidden_dim)`. + padding_mask: A tensor of shape `(batch_size, seq_len)` with `1` + for valid tokens and `0` for padded tokens. + + Returns: + A tensor of shape `(batch_size, hidden_dim)`. 
+ + Example: + ```python + sequence_output = np.random.rand(2, 4, 8).astype("float32") + padding_mask = np.array([[1, 1, 1, 0], [1, 1, 0, 0]]) + mean_pool_layer = MeanPooling() + pooled = mean_pool_layer( + sequence_output=sequence_output, + padding_mask=padding_mask + ) + # pooled.shape -> (2, 8) + ``` + """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.supports_masking = True + + def call(self, sequence_output, padding_mask): + """ + Computes the masked mean pooling. + + Args: + sequence_output: The tensor of token embeddings. + padding_mask: The mask indicating which tokens to consider. + """ + # Expand the mask to match the dimensions of the sequence output + mask = ops.expand_dims( + ops.cast(padding_mask, sequence_output.dtype), axis=-1 + ) + + # Apply the mask to zero out the padded tokens + masked_output = sequence_output * mask + + # Sum the embeddings of the valid tokens + sum_embeddings = ops.sum(masked_output, axis=1) + + # Count the number of valid tokens for each sequence + num_tokens = ops.sum( + ops.cast(padding_mask, sequence_output.dtype), axis=1 + ) + num_tokens = ops.expand_dims(num_tokens, axis=1) + + # Add a small epsilon to avoid division by zero for empty sequences + num_tokens = ops.maximum(num_tokens, 1e-9) + + # Compute the mean by dividing the sum by the count of valid tokens + mean_embeddings = sum_embeddings / num_tokens + return mean_embeddings + + def compute_output_shape(self, sequence_output_shape, padding_mask_shape): + """Computes the output shape of the layer.""" + return (sequence_output_shape[0], sequence_output_shape[2]) + + def get_config(self): + """Returns the config of the layer.""" + return super().get_config() \ No newline at end of file From 3e409f5dbc4b3ccc72765f4d119362ed5c12ebc8 Mon Sep 17 00:00:00 2001 From: Suhana Date: Mon, 13 Oct 2025 14:51:47 +0530 Subject: [PATCH 6/7] refactoring the code to add gemma3_mean_pooling.py --- keras_hub/api/models/__init__.py | 3 - .../src/models/gemma3/gemma3_backbone.py | 62 ++++++---- .../src/models/gemma3/gemma3_backbone_test.py | 11 +- .../src/models/gemma3/gemma3_mean_pooling.py | 69 +++++------- .../models/gemma3/gemma3_mean_pooling_test.py | 106 ++++++++++++++++++ 5 files changed, 181 insertions(+), 70 deletions(-) create mode 100644 keras_hub/src/models/gemma3/gemma3_mean_pooling_test.py diff --git a/keras_hub/api/models/__init__.py b/keras_hub/api/models/__init__.py index 3e79b5e01d..650487dcb1 100644 --- a/keras_hub/api/models/__init__.py +++ b/keras_hub/api/models/__init__.py @@ -300,9 +300,6 @@ from keras_hub.src.models.gemma3.gemma3_backbone import ( Gemma3Backbone as Gemma3Backbone, ) -from keras_hub.src.models.gemma3.gemma3_backbone import ( - Gemma3EmbeddingModel as Gemma3Embedding, -) from keras_hub.src.models.gemma3.gemma3_causal_lm import ( Gemma3CausalLM as Gemma3CausalLM, ) diff --git a/keras_hub/src/models/gemma3/gemma3_backbone.py b/keras_hub/src/models/gemma3/gemma3_backbone.py index 92fc0c7d80..5200e7195c 100644 --- a/keras_hub/src/models/gemma3/gemma3_backbone.py +++ b/keras_hub/src/models/gemma3/gemma3_backbone.py @@ -30,7 +30,9 @@ class Gemma3Backbone(Backbone): `keras_hub.models.Gemma3CausalLM`. This backbone can also function as an end-to-end embedding model by - setting the `pooling_mode` argument. + setting the `is_embedding_model` argument to `True`. When configured as an + embedding model with bi-directional attention, it matches the + `EmbeddingGemma` architecture. 
The default constructor gives a fully customizable, randomly initialized Gemma3 model with any vision encoder, number of heads, embedding dimensions, @@ -75,12 +77,17 @@ class Gemma3Backbone(Backbone): in all transformer blocks. Defaults to `1e-6`. dropout: float. Dropout probability for the Transformer decoder blocks. Defaults to `0`. - pooling_mode (str, optional): The pooling mode for the final output. - If set to `"mean"`, the model will add a mean pooling and a dense - projection layer to output a fixed-size embedding. In this case, - `embedding_dim` must also be set. Defaults to `None`. + is_embedding_model (bool, optional): If `True`, the model will function + as an embedding model. This adds mean pooling layer and a two-layer + dense projection head to the final sequence output. The model output + will be a dictionary containing `'sequence_output'` and + `'pooled_output'`. Defaults to `False`. + pooling_intermediate_dim (int, optional): The intermediate dimension of + the first dense layer in the two-layer pooling projection head. + Required if `is_embedding_model` is `True`. Defaults to `None`. embedding_dim (int, optional): The dimension of the final projected - embedding. Required if `pooling_mode` is `"mean"`. Defaults to `None`. + embedding. Required if `is_embedding_model` is `True`. Defaults to + `None`. dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use for the models computations and weights. Note that some computations, such as softmax and layer normalization will always @@ -209,7 +216,8 @@ def __init__( layer_norm_epsilon=1e-6, use_bidirectional_attention=False, dropout=0, - pooling_mode=None, + is_embedding_model=False, + pooling_intermediate_dim=None, embedding_dim=None, dtype=None, **kwargs, @@ -332,25 +340,33 @@ def __init__( ) sequence_output = self.layer_norm(x) - if pooling_mode is not None: - if pooling_mode != "mean": - raise ValueError(f"Unknown pooling mode: {pooling_mode}") - if embedding_dim is None: + if is_embedding_model: + if embedding_dim is None or pooling_intermediate_dim is None: raise ValueError( - "`embedding_dim` must be specified when `pooling_mode` is set." + "`embedding_dim` and `pooling_intermediate_dim` must be " + "specified when `is_embedding_model` is `True`." ) - + pooled_output = MeanPooling(dtype=dtype, name="mean_pooling")( - sequence_output=sequence_output, - padding_mask=padding_mask_input, + [sequence_output, padding_mask_input] ) - + + pooled_output = layers.Dense( + pooling_intermediate_dim, + dtype=dtype, + name="pooling_dense_1", + activation=None, + use_bias=False, + )(pooled_output) + pooled_output = layers.Dense( embedding_dim, dtype=dtype, name="embedding_projection", + activation=None, + use_bias=False, )(pooled_output) - + outputs = { "sequence_output": sequence_output, "pooled_output": pooled_output, @@ -373,7 +389,7 @@ def __init__( super().__init__( inputs=inputs, - outputs=sequence_output, + outputs=outputs, dtype=dtype, **kwargs, ) @@ -400,9 +416,10 @@ def __init__( self.use_bidirectional_attention = use_bidirectional_attention self.layer_norm_epsilon = layer_norm_epsilon self.dropout = dropout - self.pooling_mode = pooling_mode + self.is_embedding_model = is_embedding_model + self.pooling_intermediate_dim = pooling_intermediate_dim self.embedding_dim = embedding_dim - + # Keep `num_vision_tokens_per_image` as a backbone property for easy # access. 
if not text_only_model: @@ -442,7 +459,8 @@ def get_config(self): "use_bidirectional_attention": self.use_bidirectional_attention, "layer_norm_epsilon": self.layer_norm_epsilon, "dropout": self.dropout, - "pooling_mode": self.pooling_mode, + "is_embedding_model": self.is_embedding_model, + "pooling_intermediate_dim": self.pooling_intermediate_dim, "embedding_dim": self.embedding_dim, } ) @@ -466,4 +484,4 @@ def from_config(cls, config): } ) - return super().from_config(config) \ No newline at end of file + return super().from_config(config) diff --git a/keras_hub/src/models/gemma3/gemma3_backbone_test.py b/keras_hub/src/models/gemma3/gemma3_backbone_test.py index 38b7e5cbde..ce5bdafe8c 100644 --- a/keras_hub/src/models/gemma3/gemma3_backbone_test.py +++ b/keras_hub/src/models/gemma3/gemma3_backbone_test.py @@ -123,11 +123,13 @@ def test_backbone_basics(self, backbone_type): def test_embedding_mode(self): embedding_dim = 16 + pooling_intermediate_dim = 32 init_kwargs = self.text_init_kwargs.copy() input_data = self.text_backbone_input_data.copy() - init_kwargs["pooling_mode"] = "mean" + init_kwargs["is_embedding_model"] = True init_kwargs["embedding_dim"] = embedding_dim + init_kwargs["pooling_intermediate_dim"] = pooling_intermediate_dim model = Gemma3Backbone(**init_kwargs) outputs = model(input_data) @@ -138,7 +140,10 @@ def test_embedding_mode(self): expected_pooled_shape = (self.batch_size, embedding_dim) self.assertEqual(outputs["pooled_output"].shape, expected_pooled_shape) - + + self.assertEqual(model.count_params(), 6520) + self.assertEqual(len(model.layers), 13) + self.run_model_saving_test( cls=Gemma3Backbone, init_kwargs=init_kwargs, @@ -216,4 +221,4 @@ def test_all_presets(self): input_data=self.text_backbone_input_data if "_text" in preset or "1b" in preset else self.input_data, - ) \ No newline at end of file + ) diff --git a/keras_hub/src/models/gemma3/gemma3_mean_pooling.py b/keras_hub/src/models/gemma3/gemma3_mean_pooling.py index dfd1f7d2e7..d5c0e453ac 100644 --- a/keras_hub/src/models/gemma3/gemma3_mean_pooling.py +++ b/keras_hub/src/models/gemma3/gemma3_mean_pooling.py @@ -1,17 +1,3 @@ -# Copyright 2024 The KerasCV Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - import keras from keras import ops @@ -22,23 +8,14 @@ class MeanPooling(keras.layers.Layer): This layer correctly handles variable-length sequences by ignoring padded tokens in the mean calculation, using a `padding_mask`. - Call arguments: - sequence_output: A tensor of shape `(batch_size, seq_len, hidden_dim)`. - padding_mask: A tensor of shape `(batch_size, seq_len)` with `1` - for valid tokens and `0` for padded tokens. - - Returns: - A tensor of shape `(batch_size, hidden_dim)`. 
- Example: ```python + import numpy as np + sequence_output = np.random.rand(2, 4, 8).astype("float32") - padding_mask = np.array([[1, 1, 1, 0], [1, 1, 0, 0]]) + padding_mask = np.array([[1, 1, 1, 0], [1, 1, 0, 0]], dtype="int32") mean_pool_layer = MeanPooling() - pooled = mean_pool_layer( - sequence_output=sequence_output, - padding_mask=padding_mask - ) + pooled = mean_pool_layer([sequence_output, padding_mask]) # pooled.shape -> (2, 8) ``` """ @@ -47,42 +24,50 @@ def __init__(self, **kwargs): super().__init__(**kwargs) self.supports_masking = True - def call(self, sequence_output, padding_mask): - """ - Computes the masked mean pooling. + def call(self, inputs): + """Performs masked mean pooling on the token embeddings. Args: - sequence_output: The tensor of token embeddings. - padding_mask: The mask indicating which tokens to consider. + inputs: A list or tuple of two tensors: + - sequence_output: The sequence of embeddings to pool, with a + shape of `(batch_size, seq_len, hidden_dim)`. + - padding_mask: The mask indicating valid tokens, with a shape + of `(batch_size, seq_len)`. + + Returns: + A tensor representing the pooled embeddings, with a shape of + `(batch_size, hidden_dim)`. """ - # Expand the mask to match the dimensions of the sequence output + sequence_output, padding_mask = inputs mask = ops.expand_dims( ops.cast(padding_mask, sequence_output.dtype), axis=-1 ) - # Apply the mask to zero out the padded tokens masked_output = sequence_output * mask - # Sum the embeddings of the valid tokens sum_embeddings = ops.sum(masked_output, axis=1) - # Count the number of valid tokens for each sequence num_tokens = ops.sum( ops.cast(padding_mask, sequence_output.dtype), axis=1 ) num_tokens = ops.expand_dims(num_tokens, axis=1) - - # Add a small epsilon to avoid division by zero for empty sequences num_tokens = ops.maximum(num_tokens, 1e-9) - # Compute the mean by dividing the sum by the count of valid tokens mean_embeddings = sum_embeddings / num_tokens return mean_embeddings - def compute_output_shape(self, sequence_output_shape, padding_mask_shape): - """Computes the output shape of the layer.""" + def compute_output_shape(self, input_shape): + """Computes the output shape of the layer. + + Args: + input_shape: A tuple of input shapes. + + Returns: + A tuple representing the output shape. 
+ """ + sequence_output_shape, _ = input_shape return (sequence_output_shape[0], sequence_output_shape[2]) def get_config(self): """Returns the config of the layer.""" - return super().get_config() \ No newline at end of file + return super().get_config() diff --git a/keras_hub/src/models/gemma3/gemma3_mean_pooling_test.py b/keras_hub/src/models/gemma3/gemma3_mean_pooling_test.py new file mode 100644 index 0000000000..60f2dddf15 --- /dev/null +++ b/keras_hub/src/models/gemma3/gemma3_mean_pooling_test.py @@ -0,0 +1,106 @@ +import numpy as np +from absl.testing import parameterized +from keras import ops +from keras.src import testing + +from keras_hub.src.models.gemma3.gemma3_mean_pooling import MeanPooling + + +class MeanPoolingTest(testing.TestCase, parameterized.TestCase): + def test_basic_pooling_with_padding(self): + """Tests if the pooling correctly averages non-padded tokens.""" + sequence_output = np.array( + [ + [ + [1.0, 2.0, 3.0], + [4.0, 5.0, 6.0], + [99.0, 99.0, 99.0], + ], + [ + [10.0, 11.0, 12.0], + [13.0, 14.0, 15.0], + [16.0, 17.0, 18.0], + ], + ], + dtype="float32", + ) + padding_mask = np.array( + [ + [1, 1, 0], + [1, 1, 1], + ], + dtype="int32", + ) + + expected_output = np.array( + [ + [2.5, 3.5, 4.5], + [13.0, 14.0, 15.0], + ], + dtype="float32", + ) + + layer = MeanPooling() + inputs = [ + ops.convert_to_tensor(sequence_output), + ops.convert_to_tensor(padding_mask), + ] + actual_output = layer(inputs) + + self.assertAllClose(actual_output, expected_output) + + def test_all_padded_sequence(self): + """Tests that an all-padded sequence results in a zero vector.""" + sequence_output = np.random.rand(2, 3, 4).astype("float32") + padding_mask = np.array( + [ + [1, 1, 0], + [0, 0, 0], + ], + dtype="int32", + ) + expected_output_for_padded_seq = np.zeros(4, dtype="float32") + + layer = MeanPooling() + inputs = [ + ops.convert_to_tensor(sequence_output), + ops.convert_to_tensor(padding_mask), + ] + actual_output = layer(inputs) + + self.assertAllClose(actual_output[1], expected_output_for_padded_seq) + + @parameterized.named_parameters( + ("float32", "float32"), + ("float16", "float16"), + ) + def test_different_dtypes(self, dtype): + """Ensures the layer works with various float dtypes.""" + sequence_output = np.array([[[1.0, 2.0], [3.0, 4.0]]], dtype=dtype) + padding_mask = np.array([[1, 0]], dtype="int32") + expected_output = np.array([[1.0, 2.0]], dtype=dtype) + + layer = MeanPooling(dtype=dtype) + inputs = [ + ops.convert_to_tensor(sequence_output), + ops.convert_to_tensor(padding_mask), + ] + actual_output = layer(inputs) + + self.assertAllClose(actual_output, expected_output) + + def test_shape_computation(self): + """Validates the compute_output_shape method.""" + layer = MeanPooling() + input_shape = [(2, 10, 16), (2, 10)] + output_shape = layer.compute_output_shape(input_shape) + expected_shape = (2, 16) + self.assertEqual(output_shape, expected_shape) + + def test_config_serialization(self): + """Tests that the layer can be successfully saved and loaded.""" + layer = MeanPooling(name="mean_pooling_test") + config = layer.get_config() + new_layer = MeanPooling.from_config(config) + self.assertEqual(new_layer.name, layer.name) + self.assertIsInstance(new_layer, MeanPooling) From 3fd0345478750ccb5adc5e7d1407ba07b16117aa Mon Sep 17 00:00:00 2001 From: Suhana Date: Mon, 13 Oct 2025 16:44:50 +0530 Subject: [PATCH 7/7] Adding conversion script --- .../convert_embedding_gemma_checkpoints.py | 131 ++++++++++++++++++ 1 file changed, 131 insertions(+) create mode 100644 
tools/checkpoint_conversion/convert_embedding_gemma_checkpoints.py diff --git a/tools/checkpoint_conversion/convert_embedding_gemma_checkpoints.py b/tools/checkpoint_conversion/convert_embedding_gemma_checkpoints.py new file mode 100644 index 0000000000..2e0d5d27e5 --- /dev/null +++ b/tools/checkpoint_conversion/convert_embedding_gemma_checkpoints.py @@ -0,0 +1,131 @@ +""" +Convert a pre-trained causal Gemma3 Keras model to an Embedding Gemma model. + +This script takes a standard Gemma3 model designed for causal language modeling +and adapts it for sentence embedding tasks. It modifies the model architecture +for bi-directional attention, adds a new pooling head for generating fixed-size +embeddings, and transfers the weights from the original model. + +Setup: +```shell +# Make sure to install the necessary libraries, including the specific +# keras_hub package containing the Gemma3 models. +pip install keras-hub +pip install keras + +Usage: +```shell +cd tools/checkpoint_conversion +python3 convert_embedding_gemma_checkpoints.py \ + --source_preset gemma3_instruct_4b_text \ + --output_preset embedding_gemma3_4b_en \ + --pooling_intermediate_dim 4096 \ + --embedding_dim 768 +""" + +import argparse +import os + +os.environ["KERAS_BACKEND"] = "jax" +os.environ["CUDA_VISIBLE_DEVICES"] = "-1" + +import keras + +from keras_hub.src.models.gemma3.gemma3_backbone import Gemma3Backbone +from keras_hub.src.models.gemma3.gemma3_tokenizer import Gemma3Tokenizer + + +def convert_to_embedding_preset( + source_preset: str, + output_preset: str, + pooling_intermediate_dim: int, + embedding_dim: int, +): + """ + Converts a standard causal Gemma3 preset to an Embedding Gemma preset. + + This function loads a pre-trained causal Gemma3 backbone, reconfigures it + for bi-directional attention and adds a pooling head, transfers the original + weights, and saves the result as a new Keras Hub preset. + + Args: + source_preset (str): The path or name of source causal Gemma3 preset. + output_preset (str): The path to save the new embedding model preset. + pooling_intermediate_dim (int): The intermediate dimension for the + pooling head's dense layer. + embedding_dim (int): The final output dimension of sentence embedding. 
+ """ + source_model = Gemma3Backbone.from_preset(source_preset) + source_tokenizer = Gemma3Tokenizer.from_preset(source_preset) + + config = source_model.get_config() + + config["is_embedding_model"] = True + config["use_bidirectional_attention"] = True + config["pooling_intermediate_dim"] = pooling_intermediate_dim + config["embedding_dim"] = embedding_dim + + if config.get("vision_encoder") is not None: + config["vision_encoder"] = keras.layers.deserialize( + config["vision_encoder"] + ) + + embedding_model = Gemma3Backbone.from_config(config) + + transferred_layers = 0 + source_layer_names = {layer.name for layer in source_model.layers} + + for target_layer in embedding_model.layers: + if target_layer.name in source_layer_names: + source_layer = source_model.get_layer(name=target_layer.name) + if source_layer.get_weights(): + target_layer.set_weights(source_layer.get_weights()) + transferred_layers += 1 + + os.makedirs(output_preset, exist_ok=True) + embedding_model.save_to_preset(output_preset) + source_tokenizer.save_to_preset(output_preset) + print(f"Embedding Gemma preset successfully saved to: '{output_preset}'") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Convert a pre-trained causal Gemma3 model to " + "Embedding Gemma model.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--source_preset", + type=str, + required=True, + help="Path or name of the source causal Gemma3 preset " + "(e.g., 'gemma3_instruct_4b_text').", + ) + parser.add_argument( + "--output_preset", + type=str, + required=True, + help="Path to save the new Embedding Gemma preset " + "(e.g., 'embedding_gemma3_4b_en').", + ) + parser.add_argument( + "--pooling_intermediate_dim", + type=int, + default=4096, + help="Intermediate dimension for the pooling head's first dense layer.", + ) + parser.add_argument( + "--embedding_dim", + type=int, + default=768, + help="The final output dimension of the embedding projection.", + ) + + args = parser.parse_args() + + convert_to_embedding_preset( + source_preset=args.source_preset, + output_preset=args.output_preset, + pooling_intermediate_dim=args.pooling_intermediate_dim, + embedding_dim=args.embedding_dim, + )
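A minimal sketch of how the converted preset could be loaded and queried, assuming the final backbone API from this series (a dict output with `sequence_output` and `pooled_output` when `is_embedding_model=True`). The preset directory name and the token ids below are placeholders for illustration, not values produced by this PR.

```python
import numpy as np

from keras_hub.src.models.gemma3.gemma3_backbone import Gemma3Backbone

# Load the preset written by the conversion script; the directory name is
# whatever was passed to `--output_preset` (placeholder shown here).
backbone = Gemma3Backbone.from_preset("embedding_gemma3_4b_en")

# Token ids would normally come from `Gemma3Tokenizer`; hard-coded here
# purely for illustration.
inputs = {
    "token_ids": np.array([[651, 4320, 8426, 25341, 235265]]),
    "padding_mask": np.ones((1, 5), dtype="int32"),
}

# With `is_embedding_model=True`, the backbone returns a dict of outputs.
outputs = backbone(inputs)
sentence_embedding = outputs["pooled_output"]  # shape: (1, embedding_dim)
```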