@@ -89,12 +89,14 @@ class ModelBase:
     block_count: int
     tensor_map: gguf.TensorNameMap
 
+    is_mistral_format: bool = False
+
     def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, *, is_big_endian: bool = False,
                  use_temp_file: bool = False, eager: bool = False,
                  metadata_override: Path | None = None, model_name: str | None = None,
                  split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False,
                  small_first_shard: bool = False, hparams: dict[str, Any] | None = None,
-                 remote_hf_model_id: str | None = None, n_ctx: int = 0, is_mistral_format: bool = False):
+                 remote_hf_model_id: str | None = None):
         if type(self) is ModelBase or \
                 type(self) is TextModel or \
                 type(self) is MmprojModel:
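Note: is_mistral_format moves from an __init__ keyword argument to a class attribute with a False default, so converter subclasses opt in by overriding the attribute instead of threading a flag through every constructor call (MistralModel and PixtralModel do exactly this later in the diff). A minimal self-contained sketch of that pattern; the class names below are illustrative, not part of this change:

class Base:
    is_mistral_format: bool = False  # default: HF-style checkpoint layout

class MistralLike(Base):
    is_mistral_format = True  # subclasses opt in by overriding the class attribute

assert Base().is_mistral_format is False
assert MistralLike().is_mistral_format is True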
@@ -109,11 +111,6 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path,
         self.use_temp_file = use_temp_file
         self.lazy = not eager or (remote_hf_model_id is not None)
         self.remote_hf_model_id = remote_hf_model_id
-        self.n_ctx = n_ctx
-        self.is_mistral_format = is_mistral_format
-
-        if is_mistral_format and not n_ctx:
-            raise ValueError("Please pass the context length using --ctx when using mistral formats.")
 
         if remote_hf_model_id is not None:
             self.is_safetensors = True
@@ -127,12 +124,12 @@ def get_remote_tensors() -> Iterator[tuple[str, Tensor]]:
 
             self.get_tensors = get_remote_tensors
         else:
-            prefix = "model" if not is_mistral_format else "consolidated"
+            prefix = "model" if not self.is_mistral_format else "consolidated"
             self.part_names = ModelBase.get_model_part_names(self.dir_model, prefix, ".safetensors")
             self.is_safetensors = len(self.part_names) > 0
             if not self.is_safetensors:
                 self.part_names = ModelBase.get_model_part_names(self.dir_model, "pytorch_model", ".bin")
-        self.hparams = ModelBase.load_hparams(self.dir_model, is_mistral_format) if hparams is None else hparams
+        self.hparams = ModelBase.load_hparams(self.dir_model, self.is_mistral_format) if hparams is None else hparams
         self.tensor_names = None
         self.metadata_override = metadata_override
         self.model_name = model_name
@@ -296,14 +293,6 @@ def prepare_tensors(self):
                     break
 
             for new_name, data_torch in (self.modify_tensors(data_torch, name, bid)):
-                # hard coded for pixtral
-                if name == "vision_language_adapter.w_in.weight":
-                    assert new_name == "mm.23.weight", new_name
-                    new_name = "mm.1.weight"
-                elif name == "vision_language_adapter.w_out.weight":
-                    assert new_name == "mm.23.weight", new_name
-                    new_name = "mm.2.weight"
-
                 # TODO: why do we squeeze here?
                 # data = data_torch.squeeze().numpy()
                 data = data_torch.numpy()
@@ -566,12 +555,7 @@ def prepare_metadata(self, vocab_only: bool):
     def set_gguf_parameters(self):
         self.gguf_writer.add_block_count(self.block_count)
 
-        if self.is_mistral_format:
-            n_ctx = self.n_ctx
-        else:
-            n_ctx = self.find_hparam(["max_position_embeddings", "n_ctx", "n_positions", "max_length"], optional=True)
-
-        if n_ctx is not None:
+        if (n_ctx := self.find_hparam(["max_position_embeddings", "n_ctx", "n_positions", "max_length"], optional=True)) is not None:
             self.gguf_writer.add_context_length(n_ctx)
             logger.info(f"gguf: context length = {n_ctx}")
 
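Note: with --n-ctx removed, the context length is read uniformly from the model hyperparameters, and the assignment expression keeps the lookup and the None check on one line. A small standalone illustration of the same pattern; the find_hparam stand-in below is a simplified assumption, not the real method:

def find_hparam(hparams: dict, keys: list[str]) -> int | None:
    # return the value of the first key that is present, or None when none of them are
    return next((hparams[k] for k in keys if k in hparams), None)

hparams = {"max_position_embeddings": 32768}
if (n_ctx := find_hparam(hparams, ["max_position_embeddings", "n_ctx", "n_positions", "max_length"])) is not None:
    print(f"gguf: context length = {n_ctx}")  # prints 32768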
@@ -2013,6 +1997,9 @@ def __init__(self, *args, **kwargs):
         self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 32)
 
     def set_vocab(self):
+        if self.is_mistral_format:
+            self._set_vocab_mistral()
+            return
         try:
             self._set_vocab_sentencepiece()
         except FileNotFoundError:
@@ -2048,7 +2035,9 @@ def set_vocab(self):
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
         hparams = self.hparams
-        self.gguf_writer.add_vocab_size(hparams["vocab_size"])
+
+        if not self.is_mistral_format:
+            self.gguf_writer.add_vocab_size(hparams["vocab_size"])
 
         if (rope_dim := hparams.get("head_dim")) is None:
             rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"]
@@ -2070,12 +2059,24 @@ def permute(weights: Tensor, n_head: int, n_head_kv: int | None):
     _experts: list[dict[str, Tensor]] | None = None
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
-        n_head = self.hparams["num_attention_heads"]
-        n_kv_head = self.hparams.get("num_key_value_heads")
+        n_head = self.find_hparam(["n_heads", "num_attention_heads"])
+        n_kv_head = self.find_hparam(["n_kv_heads", "num_key_value_heads"])
+
+        vision_prefixes = [
+            "vision_encoder.",
+            "vision_language_adapter.",
+            "patch_merger.",
+            "pre_mm_projector_norm",
+        ]
+
         is_vision_tensor = "vision_tower" in name \
             or "vision_model" in name \
             or "model.connector" in name \
-            or "multi_modal_projector" in name
+            or "multi_modal_projector" in name \
+            or any(
+                name.startswith(prefix)
+                for prefix in vision_prefixes
+            )
 
         if is_vision_tensor:
             return []  # skip vision tensors
@@ -2191,13 +2192,16 @@ class LlavaVisionModel(MmprojModel):
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        if self.hparams["model_type"] == "pixtral":
+        if self.hparams.get("model_type") == "pixtral":
             # layer_norm_eps is not in config.json, it is hard-coded in modeling_pixtral.py
             self.hparams["layer_norm_eps"] = self.hparams.get("layer_norm_eps", 1e-5)
             self.img_break_tok_id = self.get_token_id("[IMG_BREAK]")
-            logger.info(f"Image break token id: {self.img_break_tok_id}")
+        elif self.is_mistral_format:
+            self.hparams["layer_norm_eps"] = self.hparams.get("norm_eps", 1e-5)
+            self.img_break_tok_id = self.find_vparam(["image_break_token_id"])
         else:
             raise ValueError(f"Unsupported model type: {self.hparams['model_type']}")
+        logger.info(f"Image break token id: {self.img_break_tok_id}")
 
     def get_token_id(self, token: str) -> int:
         tokenizer_config_file = self.dir_model / 'tokenizer_config.json'
@@ -2211,7 +2215,7 @@ def get_token_id(self, token: str) -> int:
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
         hparams = self.hparams
-        if hparams["model_type"] == "pixtral":
+        if hparams.get("model_type") == "pixtral":
             self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.PIXTRAL)
             self.gguf_writer.add_vision_attention_layernorm_eps(hparams["layer_norm_eps"])
 
@@ -2229,18 +2233,30 @@ def set_gguf_parameters(self):
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         del bid  # unused
-        n_head = self.hparams["num_attention_heads"]
+        n_head = (
+            self.hparams["num_attention_heads"] if not self.is_mistral_format else self.find_vparam(["num_attention_heads"])
+        )
         n_kv_head = n_head
 
-        if name.startswith("multi_modal_projector.") or name.startswith("vision_tower."):
+        valid_prefixes = (
+            "multi_modal_projector.",
+            "vision_tower.",
+            "vision_encoder.",
+            "vision_language_adapter.",
+            "patch_merger.",
+            "pre_mm_projector_norm",
+        )
+
+        if any(name.startswith(prefix) for prefix in valid_prefixes):
             # process vision tensors
-            if name.endswith(("q_proj.weight", "q_proj.bias")):
+            if name.endswith(("q_proj.weight", "q_proj.bias")) and not self.is_mistral_format:
                 data_torch = LlamaModel.permute(data_torch, n_head, n_head)
-            if name.endswith(("k_proj.weight", "k_proj.bias")):
+            if name.endswith(("k_proj.weight", "k_proj.bias")) and not self.is_mistral_format:
                 data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head)
             return [(self.map_tensor_name(name), data_torch)]
 
-        if self.img_break_tok_id > 0 and "embed_tokens.weight" in name:
+        embed_key = "embed_tokens.weight" if not self.is_mistral_format else "tok_embeddings.weight"
+        if self.img_break_tok_id > 0 and embed_key in name:
             logger.info(f"Extracting [IMG_BREAK] token embedding from {name}")
             # for pixtral model, we need to extract the [IMG_BREAK] token embedding
             img_break_embd = data_torch[self.img_break_tok_id]
@@ -7682,81 +7698,18 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         return [(self.map_tensor_name(name), data_torch)]
 
 
-class MistralModel(TextModel):
+class MistralModel(LlamaModel):
+    model_arch = gguf.MODEL_ARCH.LLAMA
     model_name = "Mistral"
-    model_arch = MODEL_ARCH.LLAMA
+    hf_arch = ""
+    is_mistral_format = True
     undo_permute = True
 
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-
-    def set_gguf_parameters(self):
-        super().set_gguf_parameters()
-        hparams = self.hparams
-
-        if "head_dim" in hparams:
-            rope_dim = hparams["head_dim"]
-        else:
-            rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"]
-        self.gguf_writer.add_rope_dimension_count(rope_dim)
-
-        rope_scaling = self.hparams.get("rope_scaling") or {}
-        if (
-            rope_scaling.get("rope_type", rope_scaling.get("type")) == "linear"
-            and "factor" in rope_scaling
-        ):
-            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
-            self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
 
-    @staticmethod
-    def permute(weights: Tensor, n_head: int, n_head_kv: int | None):
-        if n_head_kv is not None and n_head != n_head_kv:
-            n_head = n_head_kv
-        return (
-            weights.reshape(
-                n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:]
-            )
-            .swapaxes(1, 2)
-            .reshape(weights.shape)
-        )
-
-    def modify_tensors(
-        self, data_torch: Tensor, name: str, bid: int | None
-    ) -> Iterable[tuple[str, Tensor]]:
-        n_head = self.hparams["n_heads"]
-        n_kv_head = self.hparams.get("n_kv_heads")
-        is_vision_tensor = any(
-            name.startswith(prefix)
-            for prefix in [
-                "vision_encoder.",
-                "vision_language_adapter.",
-                "patch_merger.",
-                "pre_mm_projector_norm",
-            ]
-        )
-
-        if is_vision_tensor:
-            return []  # skip vision tensors
-
-        if self.undo_permute:
-            if name.endswith(("q_proj.weight", "q_proj.bias")):
-                data_torch = self.permute(data_torch, n_head, n_head)
-            if name.endswith(("k_proj.weight", "k_proj.bias")):
-                data_torch = self.permute(data_torch, n_head, n_kv_head)
-
-        return [(self.map_tensor_name(name), data_torch)]
-
-
-class PixtralModel(MmprojModel):
+class PixtralModel(LlavaVisionModel):
     model_name = "Pixtral"
-    img_break_tok_id = -1
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        # layer_norm_eps is not in config.json, it is hard-coded in modeling_pixtral.py
-        self.hparams["layer_norm_eps"] = self.hparams.get("norm_eps", 1e-5)
-        self.img_break_tok_id = self.find_vparam(["image_break_token_id"])
-        logger.info(f"Image break token id: {self.img_break_tok_id}")
+    hf_arch = ""
+    is_mistral_format = True
 
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
@@ -7774,38 +7727,13 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_vision_spatial_merge_size(
             self.find_vparam(["spatial_merge_size"])
         )
-
-    def modify_tensors(
-        self, data_torch: Tensor, name: str, bid: int | None
-    ) -> Iterable[tuple[str, Tensor]]:
-        del bid  # unused
-        n_head = self.find_vparam(["num_attention_heads"])
-        n_kv_head = n_head
-
-        if any(
-            name.startswith(prefix)
-            for prefix in [
-                "vision_encoder.",
-                "vision_language_adapter.",
-                "patch_merger.",
-                "pre_mm_projector_norm",
-            ]
-        ):
-            # process vision tensors
-            if name.endswith(("q_proj.weight", "q_proj.bias")):
-                data_torch = MistralModel.permute(data_torch, n_head, n_head)
-            if name.endswith(("k_proj.weight", "k_proj.bias")):
-                data_torch = MistralModel.permute(data_torch, n_head, n_kv_head)
-            return [(self.map_tensor_name(name), data_torch)]
-
-        if self.img_break_tok_id > 0 and "tok_embeddings.weight" in name:
-            logger.info(f"Extracting [IMG_BREAK] token embedding from {name}")
-            # for pixtral model, we need to extract the [IMG_BREAK] token embedding
-            img_break_embd = data_torch[self.img_break_tok_id]
-            name = gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_TOK_EMBD_IMG_BREAK]
-            return [(self.map_tensor_name(name), img_break_embd)]
-
-        return []  # skip other tensors
+
+    def map_tensor_name(self, name: str, try_suffixes: Sequence[str] = (".weight", ".bias")) -> str:
+        if name == "vision_language_adapter.w_in.weight":
+            return "mm.1.weight"
+        elif name == "vision_language_adapter.w_out.weight":
+            return "mm.2.weight"
+        return super().map_tensor_name(name, try_suffixes)
 
 
 ###### CONVERSION LOGIC ######
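Note: the hard-coded mm.1.weight / mm.2.weight renames removed from prepare_tensors earlier in this diff now live in PixtralModel.map_tensor_name, so the vision-language adapter projections are handled at name-resolution time instead of being patched after the fact. A rough standalone stand-in for the override's behaviour; the passthrough fallback below is an assumption, since the real method defers to super().map_tensor_name:

def map_tensor_name(name: str) -> str:
    if name == "vision_language_adapter.w_in.weight":
        return "mm.1.weight"
    elif name == "vision_language_adapter.w_out.weight":
        return "mm.2.weight"
    return name  # stand-in only: the real override falls back to the base-class mapping

assert map_tensor_name("vision_language_adapter.w_in.weight") == "mm.1.weight"
assert map_tensor_name("vision_language_adapter.w_out.weight") == "mm.2.weight"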
@@ -7961,12 +7889,6 @@ def parse_args() -> argparse.Namespace:
         "--mistral-format", action="store_true",
         help="Whether the model is stored following the Mistral format.",
     )
-    parser.add_argument(
-        "--n-ctx",
-        type=int,
-        help="Training context size",
-        default=0
-    )
 
     args = parser.parse_args()
     if not args.print_supported_models and args.model is None:
@@ -8099,8 +8021,6 @@ def main() -> None:
                 split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run,
                 small_first_shard=args.no_tensor_first_split,
                 remote_hf_model_id=hf_repo_id,
-                n_ctx=args.n_ctx,
-                is_mistral_format=is_mistral_format
             )
 
         if args.vocab_only: