
Commit 10ef34f

refactor
1 parent 758c8e5 commit 10ef34f

File tree

2 files changed: +64 −146 lines changed


convert_hf_to_gguf.py

Lines changed: 64 additions & 144 deletions
@@ -89,12 +89,14 @@ class ModelBase:
     block_count: int
     tensor_map: gguf.TensorNameMap

+    is_mistral_format: bool = False
+
     def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, *, is_big_endian: bool = False,
                  use_temp_file: bool = False, eager: bool = False,
                  metadata_override: Path | None = None, model_name: str | None = None,
                  split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False,
                  small_first_shard: bool = False, hparams: dict[str, Any] | None = None,
-                 remote_hf_model_id: str | None = None, n_ctx: int = 0, is_mistral_format: bool = False):
+                 remote_hf_model_id: str | None = None):
         if type(self) is ModelBase or \
                 type(self) is TextModel or \
                 type(self) is MmprojModel:
@@ -109,11 +111,6 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path,
         self.use_temp_file = use_temp_file
         self.lazy = not eager or (remote_hf_model_id is not None)
         self.remote_hf_model_id = remote_hf_model_id
-        self.n_ctx = n_ctx
-        self.is_mistral_format = is_mistral_format
-
-        if is_mistral_format and not n_ctx:
-            raise ValueError("Please pass the context length using --ctx when using mistral formats.")

         if remote_hf_model_id is not None:
             self.is_safetensors = True
@@ -127,12 +124,12 @@ def get_remote_tensors() -> Iterator[tuple[str, Tensor]]:

             self.get_tensors = get_remote_tensors
         else:
-            prefix = "model" if not is_mistral_format else "consolidated"
+            prefix = "model" if not self.is_mistral_format else "consolidated"
             self.part_names = ModelBase.get_model_part_names(self.dir_model, prefix, ".safetensors")
             self.is_safetensors = len(self.part_names) > 0
             if not self.is_safetensors:
                 self.part_names = ModelBase.get_model_part_names(self.dir_model, "pytorch_model", ".bin")
-        self.hparams = ModelBase.load_hparams(self.dir_model, is_mistral_format) if hparams is None else hparams
+        self.hparams = ModelBase.load_hparams(self.dir_model, self.is_mistral_format) if hparams is None else hparams
         self.tensor_names = None
         self.metadata_override = metadata_override
         self.model_name = model_name
@@ -296,14 +293,6 @@ def prepare_tensors(self):
                    break

            for new_name, data_torch in (self.modify_tensors(data_torch, name, bid)):
-                # hard coded for pixtral
-                if name == "vision_language_adapter.w_in.weight":
-                    assert new_name == "mm.23.weight", new_name
-                    new_name = "mm.1.weight"
-                elif name == "vision_language_adapter.w_out.weight":
-                    assert new_name == "mm.23.weight", new_name
-                    new_name = "mm.2.weight"
-
                # TODO: why do we squeeze here?
                # data = data_torch.squeeze().numpy()
                data = data_torch.numpy()
@@ -566,12 +555,7 @@ def prepare_metadata(self, vocab_only: bool):
     def set_gguf_parameters(self):
         self.gguf_writer.add_block_count(self.block_count)

-        if self.is_mistral_format:
-            n_ctx = self.n_ctx
-        else:
-            n_ctx = self.find_hparam(["max_position_embeddings", "n_ctx", "n_positions", "max_length"], optional=True)
-
-        if n_ctx is not None:
+        if (n_ctx := self.find_hparam(["max_position_embeddings", "n_ctx", "n_positions", "max_length"], optional=True)) is not None:
             self.gguf_writer.add_context_length(n_ctx)
             logger.info(f"gguf: context length = {n_ctx}")

@@ -2013,6 +1997,9 @@ def __init__(self, *args, **kwargs):
             self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 32)

     def set_vocab(self):
+        if self.is_mistral_format:
+            self._set_vocab_mistral()
+            return
         try:
             self._set_vocab_sentencepiece()
         except FileNotFoundError:
@@ -2048,7 +2035,9 @@ def set_vocab(self):
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
         hparams = self.hparams
-        self.gguf_writer.add_vocab_size(hparams["vocab_size"])
+
+        if not self.is_mistral_format:
+            self.gguf_writer.add_vocab_size(hparams["vocab_size"])

         if (rope_dim := hparams.get("head_dim")) is None:
             rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"]
@@ -2070,12 +2059,24 @@ def permute(weights: Tensor, n_head: int, n_head_kv: int | None):
     _experts: list[dict[str, Tensor]] | None = None

     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
-        n_head = self.hparams["num_attention_heads"]
-        n_kv_head = self.hparams.get("num_key_value_heads")
+        n_head = self.find_hparam(["n_heads", "num_attention_heads"])
+        n_kv_head = self.find_hparam(["n_kv_heads", "num_key_value_heads"])
+
+        vision_prefixes = [
+            "vision_encoder.",
+            "vision_language_adapter.",
+            "patch_merger.",
+            "pre_mm_projector_norm",
+        ]
+
         is_vision_tensor = "vision_tower" in name \
             or "vision_model" in name \
             or "model.connector" in name \
-            or "multi_modal_projector" in name
+            or "multi_modal_projector" in name \
+            or any(
+                name.startswith(prefix)
+                for prefix in vision_prefixes
+            )

         if is_vision_tensor:
             return []  # skip vision tensors
@@ -2191,13 +2192,16 @@ class LlavaVisionModel(MmprojModel):

     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        if self.hparams["model_type"] == "pixtral":
+        if self.hparams.get("model_type") == "pixtral":
             # layer_norm_eps is not in config.json, it is hard-coded in modeling_pixtral.py
             self.hparams["layer_norm_eps"] = self.hparams.get("layer_norm_eps", 1e-5)
             self.img_break_tok_id = self.get_token_id("[IMG_BREAK]")
-            logger.info(f"Image break token id: {self.img_break_tok_id}")
+        elif self.is_mistral_format:
+            self.hparams["layer_norm_eps"] = self.hparams.get("norm_eps", 1e-5)
+            self.img_break_tok_id = self.find_vparam(["image_break_token_id"])
         else:
             raise ValueError(f"Unsupported model type: {self.hparams['model_type']}")
+        logger.info(f"Image break token id: {self.img_break_tok_id}")

     def get_token_id(self, token: str) -> int:
         tokenizer_config_file = self.dir_model / 'tokenizer_config.json'
@@ -2211,7 +2215,7 @@ def get_token_id(self, token: str) -> int:
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
         hparams = self.hparams
-        if hparams["model_type"] == "pixtral":
+        if hparams.get("model_type") == "pixtral":
             self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.PIXTRAL)
             self.gguf_writer.add_vision_attention_layernorm_eps(hparams["layer_norm_eps"])

@@ -2229,18 +2233,30 @@ def set_gguf_parameters(self):

     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         del bid  # unused
-        n_head = self.hparams["num_attention_heads"]
+        n_head = (
+            self.hparams["num_attention_heads"] if not self.is_mistral_format else self.find_vparam(["num_attention_heads"])
+        )
         n_kv_head = n_head

-        if name.startswith("multi_modal_projector.") or name.startswith("vision_tower."):
+        valid_prefixes = (
+            "multi_modal_projector.",
+            "vision_tower.",
+            "vision_encoder.",
+            "vision_language_adapter.",
+            "patch_merger.",
+            "pre_mm_projector_norm",
+        )
+
+        if any(name.startswith(prefix) for prefix in valid_prefixes):
             # process vision tensors
-            if name.endswith(("q_proj.weight", "q_proj.bias")):
+            if name.endswith(("q_proj.weight", "q_proj.bias")) and not self.is_mistral_format:
                 data_torch = LlamaModel.permute(data_torch, n_head, n_head)
-            if name.endswith(("k_proj.weight", "k_proj.bias")):
+            if name.endswith(("k_proj.weight", "k_proj.bias")) and not self.is_mistral_format:
                 data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head)
             return [(self.map_tensor_name(name), data_torch)]

-        if self.img_break_tok_id > 0 and "embed_tokens.weight" in name:
+        embed_key = "embed_tokens.weight" if not self.is_mistral_format else "tok_embeddings.weight"
+        if self.img_break_tok_id > 0 and embed_key in name:
             logger.info(f"Extracting [IMG_BREAK] token embedding from {name}")
             # for pixtral model, we need to extract the [IMG_BREAK] token embedding
             img_break_embd = data_torch[self.img_break_tok_id]
@@ -7682,81 +7698,18 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         return [(self.map_tensor_name(name), data_torch)]


-class MistralModel(TextModel):
+class MistralModel(LlamaModel):
+    model_arch = gguf.MODEL_ARCH.LLAMA
     model_name = "Mistral"
-    model_arch = MODEL_ARCH.LLAMA
+    hf_arch = ""
+    is_mistral_format = True
     undo_permute = True

-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-
-    def set_gguf_parameters(self):
-        super().set_gguf_parameters()
-        hparams = self.hparams
-
-        if "head_dim" in hparams:
-            rope_dim = hparams["head_dim"]
-        else:
-            rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"]
-        self.gguf_writer.add_rope_dimension_count(rope_dim)
-
-        rope_scaling = self.hparams.get("rope_scaling") or {}
-        if (
-            rope_scaling.get("rope_type", rope_scaling.get("type")) == "linear"
-            and "factor" in rope_scaling
-        ):
-            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
-            self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])

-    @staticmethod
-    def permute(weights: Tensor, n_head: int, n_head_kv: int | None):
-        if n_head_kv is not None and n_head != n_head_kv:
-            n_head = n_head_kv
-        return (
-            weights.reshape(
-                n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:]
-            )
-            .swapaxes(1, 2)
-            .reshape(weights.shape)
-        )
-
-    def modify_tensors(
-        self, data_torch: Tensor, name: str, bid: int | None
-    ) -> Iterable[tuple[str, Tensor]]:
-        n_head = self.hparams["n_heads"]
-        n_kv_head = self.hparams.get("n_kv_heads")
-        is_vision_tensor = any(
-            name.startswith(prefix)
-            for prefix in [
-                "vision_encoder.",
-                "vision_language_adapter.",
-                "patch_merger.",
-                "pre_mm_projector_norm",
-            ]
-        )
-
-        if is_vision_tensor:
-            return []  # skip vision tensors
-
-        if self.undo_permute:
-            if name.endswith(("q_proj.weight", "q_proj.bias")):
-                data_torch = self.permute(data_torch, n_head, n_head)
-            if name.endswith(("k_proj.weight", "k_proj.bias")):
-                data_torch = self.permute(data_torch, n_head, n_kv_head)
-
-        return [(self.map_tensor_name(name), data_torch)]
-
-
-class PixtralModel(MmprojModel):
+class PixtralModel(LlavaVisionModel):
     model_name = "Pixtral"
-    img_break_tok_id = -1
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        # layer_norm_eps is not in config.json, it is hard-coded in modeling_pixtral.py
-        self.hparams["layer_norm_eps"] = self.hparams.get("norm_eps", 1e-5)
-        self.img_break_tok_id = self.find_vparam(["image_break_token_id"])
-        logger.info(f"Image break token id: {self.img_break_tok_id}")
+    hf_arch = ""
+    is_mistral_format = True

     def set_gguf_parameters(self):
         super().set_gguf_parameters()
@@ -7774,38 +7727,13 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_vision_spatial_merge_size(
             self.find_vparam(["spatial_merge_size"])
         )
-
-    def modify_tensors(
-        self, data_torch: Tensor, name: str, bid: int | None
-    ) -> Iterable[tuple[str, Tensor]]:
-        del bid  # unused
-        n_head = self.find_vparam(["num_attention_heads"])
-        n_kv_head = n_head
-
-        if any(
-            name.startswith(prefix)
-            for prefix in [
-                "vision_encoder.",
-                "vision_language_adapter.",
-                "patch_merger.",
-                "pre_mm_projector_norm",
-            ]
-        ):
-            # process vision tensors
-            if name.endswith(("q_proj.weight", "q_proj.bias")):
-                data_torch = MistralModel.permute(data_torch, n_head, n_head)
-            if name.endswith(("k_proj.weight", "k_proj.bias")):
-                data_torch = MistralModel.permute(data_torch, n_head, n_kv_head)
-            return [(self.map_tensor_name(name), data_torch)]
-
-        if self.img_break_tok_id > 0 and "tok_embeddings.weight" in name:
-            logger.info(f"Extracting [IMG_BREAK] token embedding from {name}")
-            # for pixtral model, we need to extract the [IMG_BREAK] token embedding
-            img_break_embd = data_torch[self.img_break_tok_id]
-            name = gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_TOK_EMBD_IMG_BREAK]
-            return [(self.map_tensor_name(name), img_break_embd)]
-
-        return []  # skip other tensors
+
+    def map_tensor_name(self, name: str, try_suffixes: Sequence[str] = (".weight", ".bias")) -> str:
+        if name == "vision_language_adapter.w_in.weight":
+            return "mm.1.weight"
+        elif name == "vision_language_adapter.w_out.weight":
+            return "mm.2.weight"
+        return super().map_tensor_name(name, try_suffixes)


###### CONVERSION LOGIC ######
@@ -7961,12 +7889,6 @@ def parse_args() -> argparse.Namespace:
        "--mistral-format", action="store_true",
        help="Whether the model is stored following the Mistral format.",
    )
-    parser.add_argument(
-        "--n-ctx",
-        type=int,
-        help="Training context size",
-        default=0
-    )

    args = parser.parse_args()
    if not args.print_supported_models and args.model is None:
@@ -8099,8 +8021,6 @@ def main() -> None:
            split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run,
            small_first_shard=args.no_tensor_first_split,
            remote_hf_model_id=hf_repo_id,
-            n_ctx=args.n_ctx,
-            is_mistral_format=is_mistral_format
        )

    if args.vocab_only:
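
For context, the diff above moves format selection from constructor flags (is_mistral_format, n_ctx) to class attributes on dedicated subclasses. A minimal sketch of that pattern follows; the class bodies are simplified stand-ins, not the converter's actual code, and weight_prefix is a hypothetical helper name:

# Sketch only: simplified stand-ins for the converter classes.
class ModelBase:
    # Class-level default; a Mistral-format subclass overrides it instead of
    # threading an is_mistral_format=... argument through __init__.
    is_mistral_format: bool = False

    def weight_prefix(self) -> str:
        # Behaviour keyed off the class attribute (mirrors the prefix choice
        # in ModelBase.__init__ in the diff above).
        return "consolidated" if self.is_mistral_format else "model"


class LlamaModel(ModelBase):
    pass


class MistralModel(LlamaModel):
    is_mistral_format = True  # the subclass only flips the flag


print(LlamaModel().weight_prefix())    # model
print(MistralModel().weight_prefix())  # consolidated

Because the flag now lives on the class, the --mistral-format CLI switch only has to pick the MistralModel/PixtralModel classes; no per-instance plumbing (such as the removed n_ctx/--n-ctx parameter) is required.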

gguf-py/gguf/tensor_mapping.py

Lines changed: 0 additions & 2 deletions
@@ -1059,8 +1059,6 @@ class TensorNameMap:
        MODEL_TENSOR.V_MMPROJ: (
            "multi_modal_projector.linear_{bid}",
            "visual.merger.mlp.{bid}",  # qwen2vl
-            "vision_language_adapter.w_in",  # pixtral
-            "vision_language_adapter.w_out",  # pixtral
        ),

        MODEL_TENSOR.V_MMPROJ_FC: (
