From bbc5b8484c2031522eb0f4d482ad35ada3c501c7 Mon Sep 17 00:00:00 2001
From: Valko Milev
Date: Wed, 20 Aug 2025 23:07:32 +0300
Subject: [PATCH 1/4] adjusted the backbone and switched the tokenizer

---
 keras_hub/src/models/mistral/mistral_attention.py | 12 +++++++-----
 keras_hub/src/utils/transformers/convert_mistral.py | 6 ++++--
 2 files changed, 11 insertions(+), 7 deletions(-)

diff --git a/keras_hub/src/models/mistral/mistral_attention.py b/keras_hub/src/models/mistral/mistral_attention.py
index 6916133b78..d86d8243c0 100644
--- a/keras_hub/src/models/mistral/mistral_attention.py
+++ b/keras_hub/src/models/mistral/mistral_attention.py
@@ -45,6 +45,7 @@ def __init__(
         self._rope_scaling_factor = rope_scaling_factor
 
     def build(self, inputs_shape):
+        print("inputs_shape",inputs_shape)
         # Einsum variables:
         # b = batch size
         # q = query length
@@ -54,9 +55,10 @@ def build(self, inputs_shape):
         # v = num key/value heads
         # h = head dim
         self._hidden_dim = inputs_shape[-1]
+        print("self._hidden_dim // self._num_query_heads",self._hidden_dim , self._num_query_heads)
         self._head_dim = self._hidden_dim // self._num_query_heads
         self._inv_norm_factor = 1.0 / math.sqrt(self._head_dim)
-
+        print("(None, self._num_query_heads, self._head_dim)",(None, self._num_query_heads, self._head_dim))
         self._query_dense = keras.layers.EinsumDense(
             equation="bqm,muh->bquh",
             output_shape=(None, self._num_query_heads, self._head_dim),
@@ -64,7 +66,7 @@ def build(self, inputs_shape):
             dtype=self.dtype_policy,
             name="query",
         )
-        self._query_dense.build(inputs_shape)
+        self._query_dense.build((None,None,4096))#inputs_shape
 
         self._key_dense = keras.layers.EinsumDense(
             equation="bkm,mvh->bkvh",
@@ -77,7 +79,7 @@ def build(self, inputs_shape):
             dtype=self.dtype_policy,
             name="key",
         )
-        self._key_dense.build(inputs_shape)
+        self._key_dense.build((None,None,4096))#input_shape
 
         self._value_dense = keras.layers.EinsumDense(
             equation="bkm,mvh->bkvh",
@@ -90,7 +92,7 @@ def build(self, inputs_shape):
             dtype=self.dtype_policy,
             name="value",
         )
-        self._value_dense.build(inputs_shape)
+        self._value_dense.build((None,None,4096))
 
         self._softmax = keras.layers.Softmax(
             axis=-1,
@@ -111,7 +113,7 @@ def build(self, inputs_shape):
             name="attention_output",
         )
         self._output_dense.build(
-            (None, None, self._num_query_heads, self._head_dim)
+            (None, None, self._num_query_heads, 128)#self._head_dim)
         )
 
         self.rotary_embedding_layer = RotaryEmbedding(
diff --git a/keras_hub/src/utils/transformers/convert_mistral.py b/keras_hub/src/utils/transformers/convert_mistral.py
index 9c52a708ef..6433fab890 100644
--- a/keras_hub/src/utils/transformers/convert_mistral.py
+++ b/keras_hub/src/utils/transformers/convert_mistral.py
@@ -50,7 +50,7 @@ def convert_weights(backbone, loader, transformers_config):
             hf_weight_key=f"model.layers.{index}.post_attention_layernorm.weight",
             hook_fn=lambda hf_tensor, _: hf_tensor.astype(np.float16),
         )
-
+        print("decoder_layer._self_attention_layer._query_dense.kernel",decoder_layer._self_attention_layer._query_dense.kernel,index)
         # Attention layers
         loader.port_weight(
             keras_variable=decoder_layer._self_attention_layer._query_dense.kernel,
@@ -59,6 +59,8 @@ def convert_weights(backbone, loader, transformers_config):
                 np.transpose(hf_tensor.astype(np.float16)), keras_shape
             ),
         )
+        print("decoder_layer._self_attention_layer._key_dense.kernel",decoder_layer._self_attention_layer._key_dense.kernel,index)
+
         loader.port_weight(
             keras_variable=decoder_layer._self_attention_layer._key_dense.kernel,
             hf_weight_key=f"model.layers.{index}.self_attn.k_proj.weight",
@@ -113,4 +115,4 @@ def convert_weights(backbone, loader, transformers_config):
 
 
 def convert_tokenizer(cls, preset, **kwargs):
-    return cls(get_file(preset, "tokenizer.model"), **kwargs)
+    return cls(get_file(preset, "tekken.json"),**kwargs)#)"tokenizer.model"), **kwargs)

From 08e0291a27e50bc662d2ad33f6bb7d913597a25b Mon Sep 17 00:00:00 2001
From: Valko Milev
Date: Thu, 28 Aug 2025 21:58:54 +0300
Subject: [PATCH 2/4] refactoring

---
 .../src/models/mistral/mistral_attention.py | 23 +++++++++++--------
 .../src/models/mistral/mistral_backbone.py | 2 ++
 .../mistral/mistral_transformer_decoder.py | 4 +++-
 .../src/utils/transformers/preset_loader.py | 7 +++++-
 4 files changed, 25 insertions(+), 11 deletions(-)

diff --git a/keras_hub/src/models/mistral/mistral_attention.py b/keras_hub/src/models/mistral/mistral_attention.py
index d86d8243c0..02322663bd 100644
--- a/keras_hub/src/models/mistral/mistral_attention.py
+++ b/keras_hub/src/models/mistral/mistral_attention.py
@@ -26,6 +26,7 @@ def __init__(
         rope_scaling_factor=1.0,
         kernel_initializer="glorot_uniform",
         sliding_window=512,
+        mistral_type='Mistral',
         dropout=0,
         **kwargs,
     ):
@@ -34,7 +35,7 @@ def __init__(
         self._num_key_value_heads = num_key_value_heads
         self._sliding_window = sliding_window
         self._dropout = dropout
-
+        self._type = mistral_type
         self._num_key_value_groups = num_query_heads // num_key_value_heads
         self._rope_max_wavelength = rope_max_wavelength
 
@@ -45,7 +46,6 @@ def __init__(
         self._rope_scaling_factor = rope_scaling_factor
 
     def build(self, inputs_shape):
-        print("inputs_shape",inputs_shape)
         # Einsum variables:
         # b = batch size
         # q = query length
@@ -55,18 +55,21 @@ def build(self, inputs_shape):
         # v = num key/value heads
         # h = head dim
         self._hidden_dim = inputs_shape[-1]
-        print("self._hidden_dim // self._num_query_heads",self._hidden_dim , self._num_query_heads)
         self._head_dim = self._hidden_dim // self._num_query_heads
         self._inv_norm_factor = 1.0 / math.sqrt(self._head_dim)
-        print("(None, self._num_query_heads, self._head_dim)",(None, self._num_query_heads, self._head_dim))
+
+
         self._query_dense = keras.layers.EinsumDense(
             equation="bqm,muh->bquh",
             output_shape=(None, self._num_query_heads, self._head_dim),
             kernel_initializer=self._kernel_initializer,
             dtype=self.dtype_policy,
-            name="query",
+            name=
+            "query",
         )
-        self._query_dense.build((None,None,4096))#inputs_shape
+        if self._type == 'devstral':
+            inputs_shape = (None,None,4096)
+        self._query_dense.build(inputs_shape)
 
         self._key_dense = keras.layers.EinsumDense(
             equation="bkm,mvh->bkvh",
@@ -79,7 +82,7 @@ def build(self, inputs_shape):
             dtype=self.dtype_policy,
             name="key",
         )
-        self._key_dense.build((None,None,4096))#input_shape
+        self._key_dense.build(inputs_shape)
 
         self._value_dense = keras.layers.EinsumDense(
             equation="bkm,mvh->bkvh",
@@ -92,7 +95,7 @@ def build(self, inputs_shape):
             dtype=self.dtype_policy,
             name="value",
         )
-        self._value_dense.build((None,None,4096))
+        self._value_dense.build(inputs_shape)
 
         self._softmax = keras.layers.Softmax(
             axis=-1,
@@ -112,8 +115,10 @@ def build(self, inputs_shape):
             dtype=self.dtype_policy,
             name="attention_output",
         )
+        if self._type == 'devstral':
+            self._head_dim = 128
         self._output_dense.build(
-            (None, None, self._num_query_heads, 128)#self._head_dim)
+            (None, None, self._num_query_heads, self._head_dim)
         )
 
         self.rotary_embedding_layer = RotaryEmbedding(
diff --git a/keras_hub/src/models/mistral/mistral_backbone.py b/keras_hub/src/models/mistral/mistral_backbone.py
index 09a5d38129..76319729e4 100644
--- a/keras_hub/src/models/mistral/mistral_backbone.py
+++ b/keras_hub/src/models/mistral/mistral_backbone.py
@@ -101,6 +101,7 @@ def __init__(
         layer_norm_epsilon=1e-6,
         sliding_window=512,
         dropout=0,
+        mistral_type='Mistral',
         dtype=None,
         **kwargs,
     ):
@@ -127,6 +128,7 @@ def __init__(
                 sliding_window=sliding_window,
                 dropout=dropout,
                 dtype=dtype,
+                mistral_type=mistral_type,
                 name=f"transformer_layer_{i}",
             )
             self.transformer_layers.append(layer)
diff --git a/keras_hub/src/models/mistral/mistral_transformer_decoder.py b/keras_hub/src/models/mistral/mistral_transformer_decoder.py
index 79d5e93f7a..ba17dc5352 100644
--- a/keras_hub/src/models/mistral/mistral_transformer_decoder.py
+++ b/keras_hub/src/models/mistral/mistral_transformer_decoder.py
@@ -31,6 +31,7 @@ def __init__(
         kernel_initializer="glorot_uniform",
         sliding_window=512,
         dropout=0,
+        mistral_type='Mistral',
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -40,7 +41,7 @@ def __init__(
 
         self.rope_max_wavelength = rope_max_wavelength
         self.rope_scaling_factor = rope_scaling_factor
-
+        self.mistral_type = mistral_type
         self.dropout = dropout
         self.sliding_window = sliding_window
 
@@ -64,6 +65,7 @@ def build(self, decoder_sequence_shape):
             kernel_initializer=clone_initializer(self.kernel_initializer),
             dropout=self.dropout,
             dtype=self.dtype_policy,
+            mistral_type=self.mistral_type,
             name="self_attention",
         )
         self._self_attention_layer.build(decoder_sequence_shape)
diff --git a/keras_hub/src/utils/transformers/preset_loader.py b/keras_hub/src/utils/transformers/preset_loader.py
index 4accea67a1..5365c37bd7 100644
--- a/keras_hub/src/utils/transformers/preset_loader.py
+++ b/keras_hub/src/utils/transformers/preset_loader.py
@@ -20,7 +20,7 @@
 from keras_hub.src.utils.transformers import convert_qwen_moe
 from keras_hub.src.utils.transformers import convert_vit
 from keras_hub.src.utils.transformers.safetensor_utils import SafetensorLoader
-
+import re
 
 class TransformersPresetLoader(PresetLoader):
     def __init__(self, preset, config):
@@ -70,7 +70,12 @@ def check_backbone_class(self):
 
     def load_backbone(self, cls, load_weights, **kwargs):
         keras_config = self.converter.convert_backbone_config(self.config)
+
+        if re.search(r'devstral', self.preset,re.I):
+            keras_config["mistral_type"] = "devstral"
+
         backbone = cls(**{**keras_config, **kwargs})
+
         if load_weights:
             jax_memory_cleanup(backbone)
             with SafetensorLoader(self.preset) as loader:

From c999be0570dd0c9c1f36e44b943b9f8677905caf Mon Sep 17 00:00:00 2001
From: Valko Milev
Date: Thu, 28 Aug 2025 22:03:33 +0300
Subject: [PATCH 3/4] added new tokenizer and removed print statements

---
 keras_hub/src/utils/transformers/convert_mistral.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/keras_hub/src/utils/transformers/convert_mistral.py b/keras_hub/src/utils/transformers/convert_mistral.py
index 6433fab890..bad36b1bd8 100644
--- a/keras_hub/src/utils/transformers/convert_mistral.py
+++ b/keras_hub/src/utils/transformers/convert_mistral.py
@@ -4,7 +4,7 @@
 from keras_hub.src.utils.preset_utils import get_file
 
 backbone_cls = MistralBackbone
-
+import re
 
 def convert_backbone_config(transformers_config):
     return {
@@ -50,7 +50,6 @@ def convert_weights(backbone, loader, transformers_config):
             hf_weight_key=f"model.layers.{index}.post_attention_layernorm.weight",
             hook_fn=lambda hf_tensor, _: hf_tensor.astype(np.float16),
         )
-        print("decoder_layer._self_attention_layer._query_dense.kernel",decoder_layer._self_attention_layer._query_dense.kernel,index)
         # Attention layers
         loader.port_weight(
             keras_variable=decoder_layer._self_attention_layer._query_dense.kernel,
@@ -59,7 +58,6 @@ def convert_weights(backbone, loader, transformers_config):
                 np.transpose(hf_tensor.astype(np.float16)), keras_shape
             ),
         )
-        print("decoder_layer._self_attention_layer._key_dense.kernel",decoder_layer._self_attention_layer._key_dense.kernel,index)
 
         loader.port_weight(
             keras_variable=decoder_layer._self_attention_layer._key_dense.kernel,
@@ -114,5 +112,9 @@ def convert_weights(backbone, loader, transformers_config):
         )
+
 
 
 def convert_tokenizer(cls, preset, **kwargs):
-    return cls(get_file(preset, "tekken.json"),**kwargs)#)"tokenizer.model"), **kwargs)
+    tokenizer_name = "tokenizer.model"
+    if re.search(r'devstral', preset,re.I):
+        tokenizer_name = "tekken.json"
+    return cls(get_file(preset, tokenizer_name), **kwargs)

From 7bab8537204df541f6fec4e506fc6607b8c65e45 Mon Sep 17 00:00:00 2001
From: Valko Milev
Date: Thu, 28 Aug 2025 23:29:31 +0300
Subject: [PATCH 4/4] refactoring

---
 keras_hub/src/utils/transformers/convert_mistral.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/keras_hub/src/utils/transformers/convert_mistral.py b/keras_hub/src/utils/transformers/convert_mistral.py
index bad36b1bd8..7f7f421de5 100644
--- a/keras_hub/src/utils/transformers/convert_mistral.py
+++ b/keras_hub/src/utils/transformers/convert_mistral.py
@@ -1,10 +1,11 @@
+import re
 import numpy as np
 
 from keras_hub.src.models.mistral.mistral_backbone import MistralBackbone
 from keras_hub.src.utils.preset_utils import get_file
 
 backbone_cls = MistralBackbone
-import re
+
 
 def convert_backbone_config(transformers_config):
     return {
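
A minimal usage sketch of the loading paths this series wires up. The Hugging Face handle below is hypothetical; what matters is that any preset string containing "devstral" (case-insensitive) matches the `re.search(r'devstral', ...)` checks added in `TransformersPresetLoader.load_backbone` and `convert_tokenizer`. The `from_preset` calls themselves are the standard keras_hub API, not something introduced by these patches.

```python
import keras_hub

# Hypothetical Devstral checkpoint handle; the "devstral" substring is what
# routes the loader to mistral_type="devstral" and the "tekken.json" file.
preset = "hf://mistralai/Devstral-Small-2505"

# Backbone: convert_backbone_config() builds the kwargs, and
# TransformersPresetLoader.load_backbone() adds mistral_type="devstral"
# when the preset name matches.
backbone = keras_hub.models.MistralBackbone.from_preset(preset)

# Tokenizer: convert_tokenizer() now selects "tekken.json" for Devstral
# presets and falls back to "tokenizer.model" for regular Mistral checkpoints.
tokenizer = keras_hub.models.MistralTokenizer.from_preset(preset)
```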