79 changes: 78 additions & 1 deletion keras_hub/src/models/gemma3/gemma3_backbone.py
@@ -1,5 +1,5 @@
import keras
from keras import ops
from keras import layers, ops

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.layers.modeling.reversible_embedding import (
@@ -424,3 +424,80 @@ def from_config(cls, config):
)

return super().from_config(config)


class MeanPooling(layers.Layer):
"""
This layer calculates the mean of the token embeddings, ignoring
padded tokens.
"""

def __init__(self, **kwargs):
super().__init__(**kwargs)
self.supports_masking = True

    def call(self, inputs):
        sequence_output, padding_mask = inputs

        # Zero out embeddings at padded positions before summing.
        mask = ops.expand_dims(
            ops.cast(padding_mask, sequence_output.dtype), axis=-1
        )
        masked_output = sequence_output * mask
        sum_embeddings = ops.sum(masked_output, axis=1)

        # Count real tokens per sequence, clamping to avoid division by
        # zero for all-padding rows.
        num_tokens = ops.sum(mask, axis=1)
        num_tokens = ops.maximum(num_tokens, 1e-9)

        return sum_embeddings / num_tokens

def compute_output_shape(self, input_shape):
sequence_output_shape, padding_mask_shape = input_shape
return (sequence_output_shape[0], sequence_output_shape[2])
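
# A minimal, illustrative sketch of how `MeanPooling` is used (shapes are
# assumptions, not part of this change): given `sequence_output` of shape
# `(batch, seq_len, hidden_dim)` and a `padding_mask` of shape
# `(batch, seq_len)` with 1s at real tokens and 0s at padding,
#
#     pooled = MeanPooling()([sequence_output, padding_mask])
#
# yields a `(batch, hidden_dim)` tensor that averages only the unmasked
# positions. The `ops.maximum(num_tokens, 1e-9)` clamp keeps an all-padding
# row from dividing by zero.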


@keras_hub_export("keras_hub.models.Gemma3Embedding")
class Gemma3EmbeddingModel(keras.Model):
    """Gemma3 embedding model for producing fixed-size text embeddings.

    This model wraps a `Gemma3Backbone`, mean-pools the backbone's sequence
    output over non-padded tokens, and projects the pooled vector to
    `embedding_dim` with a dense layer.

    Args:
        backbone: A `Gemma3Backbone` instance.
        embedding_dim: int. The size of the output embedding vector.
    """

def __init__(
self,
backbone: Gemma3Backbone,
embedding_dim: int,
**kwargs,
):
        # === Layers ===
        self.backbone = backbone
        self.pooling_layer = MeanPooling(
            dtype=backbone.dtype_policy, name="mean_pooling"
        )
        self.projection_layer = layers.Dense(
            embedding_dim,
            dtype=backbone.dtype_policy,
            name="embedding_projection",
        )

        # === Functional Model ===
        inputs = self.backbone.input
        sequence_output = self.backbone.outputs[0]
        padding_mask = inputs["padding_mask"]
        pooled_output = self.pooling_layer([sequence_output, padding_mask])
        embedding = self.projection_layer(pooled_output)
        super().__init__(inputs=inputs, outputs=embedding, **kwargs)

        # === Config ===
        self.embedding_dim = embedding_dim

def get_config(self):
config = super().get_config()
config.update(
{
"backbone": keras.layers.serialize(self.backbone),
"embedding_dim": self.embedding_dim,
}
)
return config

@classmethod
def from_config(cls, config):
config["backbone"] = keras.layers.deserialize(config["backbone"])
return cls(**config)
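
# Illustrative usage sketch (mirrors the unit test setup below; not a
# documented example from this change): build a small text-only backbone and
# embed a batch of token ids.
#
#     backbone = Gemma3Backbone(
#         vocabulary_size=256,
#         image_size=16,
#         num_layers=2,
#         num_query_heads=2,
#         num_key_value_heads=1,
#         hidden_dim=8,
#         intermediate_dim=32,
#         head_dim=4,
#         vision_encoder=None,
#     )
#     model = Gemma3EmbeddingModel(backbone=backbone, embedding_dim=16)
#     embeddings = model(
#         {
#             "token_ids": token_ids,  # (batch, seq_len) int tensor
#             "padding_mask": padding_mask,  # (batch, seq_len) 1/0 mask
#         }
#     )
#     # embeddings.shape == (batch, 16)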
93 changes: 93 additions & 0 deletions keras_hub/src/models/gemma3/gemma3_backbone_test.py
@@ -9,6 +9,9 @@
from keras_hub.src.models.gemma3.gemma3_vision_encoder import (
Gemma3VisionEncoder,
)
from keras_hub.src.models.gemma3.gemma3_backbone import (
Gemma3EmbeddingModel,
)
from keras_hub.src.tests.test_case import TestCase


@@ -193,3 +196,93 @@ def test_all_presets(self):
if "_text" in preset or "1b" in preset
else self.input_data,
)


class Gemma3EmbeddingModelTest(TestCase):
def setUp(self):
self.batch_size = 2
self.vocabulary_size = 256
self.text_sequence_length = 64
self.hidden_dim = 8
self.embedding_dim = 16

self.backbone = Gemma3Backbone(
vocabulary_size=self.vocabulary_size,
image_size=16,
num_layers=2,
num_query_heads=2,
num_key_value_heads=1,
hidden_dim=self.hidden_dim,
intermediate_dim=32,
head_dim=4,
vision_encoder=None,
)

self.init_kwargs = {
"backbone": self.backbone,
"embedding_dim": self.embedding_dim,
}

dummy_text_token_ids = np.random.randint(
0,
self.vocabulary_size,
(self.batch_size, self.text_sequence_length),
)
padding_mask = np.ones(
(self.batch_size, self.text_sequence_length), dtype="int32"
)
padding_mask[0, -10:] = 0
padding_mask[1, -5:] = 0

self.input_data = {
"token_ids": dummy_text_token_ids,
"padding_mask": padding_mask,
}

def test_model_basics(self):
"""Test the model's forward pass and output shape."""
model = Gemma3EmbeddingModel(**self.init_kwargs)
output = model(self.input_data)
expected_output_shape = (self.batch_size, self.embedding_dim)
self.assertEqual(output.shape, expected_output_shape)

def test_architecture_characteristics(self):
"""Test parameter and layer counts."""
model = Gemma3EmbeddingModel(**self.init_kwargs)

backbone_params = self.backbone.count_params()
        # Dense projection parameters: kernel of shape
        # (hidden_dim, embedding_dim) plus a bias of shape (embedding_dim,).
        projection_params = (
            self.hidden_dim * self.embedding_dim
        ) + self.embedding_dim
expected_params = backbone_params + projection_params

expected_layers = 8

self.assertEqual(model.count_params(), expected_params)
self.assertEqual(len(model.layers), expected_layers)

def test_saved_model(self):
self.run_model_saving_test(
cls=Gemma3EmbeddingModel,
init_kwargs=self.init_kwargs,
input_data=self.input_data,
)

@pytest.mark.kaggle_key_required
@pytest.mark.extra_large
def test_build_from_preset_backbone(self):
backbone = Gemma3Backbone.from_preset("gemma3_instruct_1b_text")
model = Gemma3EmbeddingModel(
backbone=backbone,
embedding_dim=768,
)

input_data = {
"token_ids": ops.array([[651, 4320, 8426, 25341, 235265]]),
"padding_mask": ops.ones((1, 5), dtype="int32"),
}

outputs = model(input_data)

self.assertEqual(outputs.shape, (1, 768))
        norm = ops.norm(outputs, axis=1)
        self.assertGreater(float(norm[0]), 0.0)