-
Notifications
You must be signed in to change notification settings - Fork 303
[SmolLM3] Add Backbone, CausalLM + Converter for HuggingFace Weights #2327
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
divyashreepathihalli
merged 17 commits into
keras-team:master
from
DavidLandup0:smollm3
Oct 17, 2025
Merged
Changes from all commits
Commits
Show all changes
17 commits
Select commit
Hold shift + click to select a range
f2dedc4
add first few utils
DavidLandup0 1d90715
add eager attention forward
DavidLandup0 e5a8f33
Add SmolLM3Attention
DavidLandup0 54191ca
Add SmolLM3MLP
DavidLandup0 1369733
Add SmolLM3DecoderLayer
DavidLandup0 2448d80
remove unnecessary comments
DavidLandup0 598fd74
Add SmolLM3RotaryEmbedding
DavidLandup0 b9e458d
add most of smollm3backbone
DavidLandup0 6a53a7d
Fix calls within causal model
DavidLandup0 adb05d9
Move causal mask computation to forward call
DavidLandup0 4d14120
Fix rope and caching indexing
DavidLandup0 a94acc9
Remove unnecessary trimming of cache padding
DavidLandup0 2156a5f
Remove type hints, expad docstrings
DavidLandup0 2838c87
Add basic tests
DavidLandup0 9619703
Merge branch 'master' into smollm3
DavidLandup0 09c1dea
Run linter
DavidLandup0 3fe7f86
Merge branch 'master' into smollm3
DavidLandup0 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,211 @@ | ||
import keras | ||
|
||
from keras_hub.src.api_export import keras_hub_export | ||
from keras_hub.src.layers.modeling.reversible_embedding import ( | ||
ReversibleEmbedding, | ||
) | ||
from keras_hub.src.models.backbone import Backbone | ||
from keras_hub.src.models.smollm3.smollm3_layers import SmolLM3DecoderLayer | ||
|
||
|
||
@keras_hub_export( | ||
[ | ||
"keras_hub.models.SmolLM3Backbone", | ||
"keras_hub.models.SmolLMBackbone", | ||
] | ||
) | ||
class SmolLM3Backbone(Backbone): | ||
"""SmolLM3 core network with hyperparameters. | ||
|
||
This network implements a Transformer-based decoder network, | ||
SmolLM3, as described in the SmolLM3 model architecture. | ||
It includes the embedding lookups and transformer layers. | ||
|
||
The default constructor gives a fully customizable, randomly initialized | ||
SmolLM3 model with any number of layers, heads, and embedding | ||
dimensions. To load preset architectures and weights, use the `from_preset` | ||
constructor. | ||
|
||
Args: | ||
vocabulary_size: int. The size of the token vocabulary. | ||
hidden_dim: int. The size of the transformer hidden state at the end | ||
of each transformer layer. | ||
intermediate_dim: int. The output dimension of the first Dense layer in | ||
the MLP network of each transformer layer. | ||
num_layers: int. The number of transformer layers. | ||
num_attention_heads: int. The number of attention heads for each | ||
transformer layer. | ||
num_key_value_heads: int. The number of key-value heads for grouped | ||
query attention in each transformer layer. | ||
attention_bias: bool. Whether to use bias in the query, key, value, and | ||
output projection layers in the attention blocks. | ||
attention_dropout: float. Dropout probability for the attention layers. | ||
rope_layer_enabled_list: list of bool. List indicating whether RoPE | ||
(Rotary Position Embedding) is enabled for each layer. Typically, | ||
some layers may disable RoPE for architectural variations. | ||
layer_types: list of str. List of layer types for each transformer | ||
layer (e.g., "attention" or other custom types). | ||
mlp_bias: bool. Whether to use bias in the MLP (feedforward) layers. | ||
layer_norm_epsilon: float. Epsilon value for layer normalization layers | ||
to prevent division by zero. | ||
max_position_embeddings: int. The maximum sequence length that this | ||
model might ever be used with. | ||
rope_theta: float. The base period of the RoPE embeddings. | ||
partial_rotary_factor: float. The percentage of hidden dimensions to | ||
rotate in RoPE. A value of 1.0 rotates all dimensions, while values | ||
less than 1.0 only rotate a subset. | ||
|
||
Examples: | ||
|
||
```python | ||
input_data = { | ||
"token_ids": np.ones(shape=(1, 12), dtype="int32"), | ||
"padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]), | ||
} | ||
|
||
# Pretrained SmolLM3 decoder. | ||
model = keras_hub.models.SmolLM3Backbone.from_preset( | ||
"hf://HuggingFaceTB/SmolLM3-3B" | ||
) | ||
model(input_data) | ||
|
||
# Randomly initialized SmolLM3 decoder with custom config. | ||
model = keras_hub.models.SmolLM3Backbone( | ||
vocabulary_size=49152, | ||
hidden_dim=576, | ||
intermediate_dim=1536, | ||
num_layers=30, | ||
num_attention_heads=9, | ||
num_key_value_heads=3, | ||
attention_bias=False, | ||
attention_dropout=0.0, | ||
rope_layer_enabled_list=[True] * 30, | ||
layer_types=["attention"] * 30, | ||
mlp_bias=False, | ||
layer_norm_epsilon=1e-5, | ||
max_position_embeddings=2048, | ||
rope_theta=10000.0, | ||
partial_rotary_factor=1.0, | ||
) | ||
model(input_data) | ||
``` | ||
""" | ||
|
||
def __init__( | ||
self, | ||
vocabulary_size, | ||
hidden_dim, | ||
intermediate_dim, | ||
num_layers, | ||
num_attention_heads, | ||
num_key_value_heads, | ||
attention_bias, | ||
attention_dropout, | ||
rope_layer_enabled_list, | ||
layer_types, | ||
mlp_bias, | ||
layer_norm_epsilon, | ||
max_position_embeddings, | ||
rope_theta, | ||
partial_rotary_factor, | ||
**kwargs, | ||
): | ||
# === Layers === | ||
self.token_embedding = ReversibleEmbedding( | ||
input_dim=vocabulary_size, | ||
output_dim=hidden_dim, | ||
name="token_embedding", | ||
) | ||
self.transformer_layers = [] | ||
for i in range(num_layers): | ||
layer = SmolLM3DecoderLayer( | ||
hidden_size=hidden_dim, | ||
num_attention_heads=num_attention_heads, | ||
num_key_value_heads=num_key_value_heads, | ||
attention_bias=attention_bias, | ||
attention_dropout=attention_dropout, | ||
rope_layer_enabled_list=rope_layer_enabled_list, | ||
layer_types=layer_types, | ||
layer_idx=i, | ||
intermediate_size=intermediate_dim, | ||
mlp_bias=mlp_bias, | ||
layer_norm_epsilon=layer_norm_epsilon, | ||
max_position_embeddings=max_position_embeddings, | ||
rope_theta=rope_theta, | ||
partial_rotary_factor=partial_rotary_factor, | ||
name=f"transformer_layer_{i}", | ||
) | ||
self.transformer_layers.append(layer) | ||
|
||
self.norm = keras.layers.RMSNormalization( | ||
epsilon=layer_norm_epsilon, | ||
name="sequence_output_layernorm", | ||
) | ||
|
||
# === Functional Model === | ||
token_id_input = keras.Input( | ||
shape=(None,), dtype="int32", name="token_ids" | ||
) | ||
|
||
padding_mask_input = keras.Input( | ||
shape=(None,), dtype="int32", name="padding_mask" | ||
) | ||
|
||
x = self.token_embedding(token_id_input) | ||
|
||
for decoder_layer in self.transformer_layers: | ||
x = decoder_layer( | ||
x, | ||
decoder_padding_mask=padding_mask_input, | ||
**kwargs, | ||
) | ||
|
||
sequence_output = self.norm(x) | ||
super().__init__( | ||
inputs={ | ||
"token_ids": token_id_input, | ||
"padding_mask": padding_mask_input, | ||
}, | ||
outputs=sequence_output, | ||
**kwargs, | ||
) | ||
|
||
# === Config === | ||
self.vocabulary_size = vocabulary_size | ||
self.hidden_dim = hidden_dim | ||
self.intermediate_dim = intermediate_dim | ||
self.num_layers = num_layers | ||
self.num_attention_heads = num_attention_heads | ||
self.num_key_value_heads = num_key_value_heads | ||
self.attention_bias = attention_bias | ||
self.attention_dropout = attention_dropout | ||
self.rope_layer_enabled_list = rope_layer_enabled_list | ||
self.layer_types = layer_types | ||
self.mlp_bias = mlp_bias | ||
self.layer_norm_epsilon = layer_norm_epsilon | ||
self.max_position_embeddings = max_position_embeddings | ||
self.rope_theta = rope_theta | ||
self.partial_rotary_factor = partial_rotary_factor | ||
|
||
def get_config(self): | ||
config = super().get_config() | ||
config.update( | ||
{ | ||
"vocabulary_size": self.vocabulary_size, | ||
"hidden_dim": self.hidden_dim, | ||
"intermediate_dim": self.intermediate_dim, | ||
"num_layers": self.num_layers, | ||
"num_attention_heads": self.num_attention_heads, | ||
"num_key_value_heads": self.num_key_value_heads, | ||
"attention_bias": self.attention_bias, | ||
"attention_dropout": self.attention_dropout, | ||
"rope_layer_enabled_list": self.rope_layer_enabled_list, | ||
"layer_types": self.layer_types, | ||
"mlp_bias": self.mlp_bias, | ||
"layer_norm_epsilon": self.layer_norm_epsilon, | ||
"max_position_embeddings": self.max_position_embeddings, | ||
"rope_theta": self.rope_theta, | ||
"partial_rotary_factor": self.partial_rotary_factor, | ||
} | ||
) | ||
return config |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
import pytest | ||
from keras import ops | ||
|
||
from keras_hub.src.models.smollm3.smollm3_backbone import SmolLM3Backbone | ||
from keras_hub.src.tests.test_case import TestCase | ||
|
||
|
||
class SmolLM3BackboneTest(TestCase): | ||
def setUp(self): | ||
self.init_kwargs = { | ||
"vocabulary_size": 100, | ||
"hidden_dim": 64, | ||
"intermediate_dim": 128, | ||
"num_layers": 2, | ||
"num_attention_heads": 4, | ||
"num_key_value_heads": 2, | ||
"attention_bias": False, | ||
"attention_dropout": 0.0, | ||
"rope_layer_enabled_list": [True, True], | ||
"layer_types": ["attention", "attention"], | ||
"mlp_bias": False, | ||
"layer_norm_epsilon": 1e-5, | ||
"max_position_embeddings": 128, | ||
"rope_theta": 10000.0, | ||
"partial_rotary_factor": 1.0, | ||
} | ||
self.input_data = { | ||
"token_ids": ops.ones((2, 5), dtype="int32"), | ||
"padding_mask": ops.ones((2, 5), dtype="int32"), | ||
} | ||
|
||
def test_backbone_basics(self): | ||
self.run_backbone_test( | ||
cls=SmolLM3Backbone, | ||
init_kwargs=self.init_kwargs, | ||
input_data=self.input_data, | ||
expected_output_shape=(2, 5, 64), | ||
run_mixed_precision_check=False, | ||
run_quantization_check=False, | ||
) | ||
|
||
@pytest.mark.large | ||
def test_saved_model(self): | ||
self.run_model_saving_test( | ||
cls=SmolLM3Backbone, | ||
init_kwargs=self.init_kwargs, | ||
input_data=self.input_data, | ||
) | ||
|
||
def test_num_parameters(self): | ||
model = SmolLM3Backbone(**self.init_kwargs) | ||
# Reference value calculated from the model architecture | ||
self.assertEqual(model.count_params(), 80464) |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Usually there's some of these terms (like the epsilon's and rope theta) that have a consistent value across all the presets we care about, and we give them defaults here. Not super important, just for people that wanted an easier time making a custom small version of the arch or something like that.