
Commit 80a0594

feat: support GLM 4.5 family of models
1 parent 0a5036b commit 80a0594


6 files changed: +511 -0 lines changed


convert_hf_to_gguf.py

Lines changed: 175 additions & 0 deletions
@@ -6578,6 +6578,181 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         return super().modify_tensors(data_torch, name, bid)
 
 
+@ModelBase.register("Glm4MoeForCausalLM")
+class Glm4MoeModel(TextModel):
+    model_arch = gguf.MODEL_ARCH.GLM4_MOE
+
+    def set_vocab(self):
+        from transformers import AutoTokenizer
+
+        tokenizer = AutoTokenizer.from_pretrained(
+            self.dir_model, trust_remote_code=True
+        )
+        special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
+        tokens, toktypes, tokpre = self.get_vocab_base()
+        self.gguf_writer.add_tokenizer_model("gpt2")
+        self.gguf_writer.add_tokenizer_pre(tokpre)
+        self.gguf_writer.add_token_list(tokens)
+        self.gguf_writer.add_token_types(toktypes)
+        special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
+        special_vocab._set_special_token(
+            "eos", tokenizer.get_added_vocab()["<|endoftext|>"]
+        )
+        special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"])
+        special_vocab._set_special_token("eog", tokenizer.get_added_vocab()["<|user|>"])
+        special_vocab._set_special_token("eog", tokenizer.get_added_vocab()["<|observation|>"])
+        special_vocab._set_special_token(
+            "unk", tokenizer.get_added_vocab()["<|endoftext|>"]
+        )
+        special_vocab._set_special_token(
+            "bos", tokenizer.get_added_vocab()["<|endoftext|>"]
+        )
+        special_vocab.add_to_gguf(self.gguf_writer)
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        if (rope_dim := self.hparams.get("head_dim")) is None:
+            rope_dim = (
+                self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
+            )
+        self.gguf_writer.add_rope_dimension_count(
+            int(rope_dim * self.hparams.get("partial_rotary_factor", 0.5))
+        )
+
+        # MoE parameters
+        if (n_experts := self.hparams.get("n_routed_experts")) is not None:
+            self.gguf_writer.add_expert_count(n_experts)
+        # Note: expert_used_count is already set by parent class using num_experts_per_tok
+        if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None:
+            self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)
+        if (n_shared_experts := self.hparams.get("n_shared_experts")) is not None:
+            self.gguf_writer.add_expert_shared_count(n_shared_experts)
+        if (first_k_dense_replace := self.hparams.get("first_k_dense_replace")) is not None:
+            self.gguf_writer.add_leading_dense_block_count(first_k_dense_replace)
+
+        # Expert gating function (sigmoid for GLM4_MOE)
+        self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
+
+        # Routed scaling factor
+        if (routed_scaling_factor := self.hparams.get("routed_scaling_factor")) is not None:
+            self.gguf_writer.add_expert_weights_scale(routed_scaling_factor)
+
+        # Normalise topk probabilities
+        if (norm_topk_prob := self.hparams.get("norm_topk_prob")) is not None:
+            self.gguf_writer.add_expert_weights_norm(norm_topk_prob)
+
+    _experts: list[dict[str, Tensor]] | None = None
+    _shared_experts: list[dict[str, Tensor]] | None = None
+
+    def modify_tensors(
+        self, data_torch: Tensor, name: str, bid: int | None
+    ) -> Iterable[tuple[str, Tensor]]:
+        # Handle layer 46 tensors - preserve all for future MTP support
+        if bid is not None and bid == 46:
+            # Convert layer 46 tensors to GGUF naming but don't try to map them
+            new_name = name.replace("model.layers.", "blk.")
+            return [(new_name, data_torch)]
+
+        if name.startswith("model.visual."):  # ignore visual part
+            return []
+        elif name.startswith("model.language_model."):
+            name = name.replace("language_model.", "")  # for multimodal variants
+
+        # Handle main token embedding
+        if name == "model.embed_tokens.weight":
+            return [(self.map_tensor_name("token_embd.weight"), data_torch)]
+
+        # Handle routed experts (skip for NextN layer 46)
+        if name.find("mlp.experts") != -1 and "shared_experts" not in name and bid != 46:
+            n_experts = self.hparams["n_routed_experts"]
+            assert bid is not None
+
+            if self._experts is None:
+                self._experts = [{} for _ in range(self.block_count)]
+
+            self._experts[bid][name] = data_torch
+
+            if len(self._experts[bid]) >= n_experts * 3:
+                tensors: list[tuple[str, Tensor]] = []
+
+                # merge the experts into a single 3d tensor
+                for w_name in ["down_proj", "gate_proj", "up_proj"]:
+                    datas: list[Tensor] = []
+
+                    for xid in range(n_experts):
+                        ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
+                        datas.append(self._experts[bid][ename])
+                        del self._experts[bid][ename]
+
+                    data_torch = torch.stack(datas, dim=0)
+                    # Generate GGUF tensor names for merged experts
+                    if w_name == "down_proj":
+                        new_name = f"blk.{bid}.ffn_down_exps.weight"
+                    elif w_name == "gate_proj":
+                        new_name = f"blk.{bid}.ffn_gate_exps.weight"
+                    elif w_name == "up_proj":
+                        new_name = f"blk.{bid}.ffn_up_exps.weight"
+                    else:
+                        merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
+                        new_name = self.map_tensor_name(merged_name)
+                    tensors.append((new_name, data_torch))
+                return tensors
+            else:
+                return []
+
+        # Handle expert gating input (routing gate)
+        if ".mlp.gate.e_score_correction_bias" in name:
+            new_name = name.replace("model.layers.", "blk.").replace(
+                ".mlp.gate.e_score_correction_bias", ".ffn_gate_inp.bias"
+            )
+            return [(self.map_tensor_name(new_name), data_torch)]
+
+        # Handle shared expert tensors
+        if ".mlp.ffn_" in name and "_shexp" in name:
+            new_name = name.replace("model.layers.", "blk.")
+            return [(new_name, data_torch)]
+
+        # Handle regular dense FFN layers (for hybrid dense/MoE architecture)
+        if ".mlp." in name and "experts" not in name and "_shexp" not in name:
+            if "gate_proj" in name:
+                new_name = name.replace("model.layers.", "blk.").replace(
+                    ".mlp.gate_proj.weight", ".ffn_gate.weight"
+                )
+            elif "up_proj" in name:
+                new_name = name.replace("model.layers.", "blk.").replace(
+                    ".mlp.up_proj.weight", ".ffn_up.weight"
+                )
+            elif "down_proj" in name:
+                new_name = name.replace("model.layers.", "blk.").replace(
+                    ".mlp.down_proj.weight", ".ffn_down.weight"
+                )
+            else:
+                new_name = name
+            return [(self.map_tensor_name(new_name), data_torch)]
+
+        # Handle special NextN tensors - preserve for future MTP support
+        if (
+            ".embed_tokens." in name
+            or ".shared_head." in name
+            or ".eh_proj." in name
+            or ".enorm." in name
+            or ".hnorm." in name
+        ):
+            # For NextN tensors, convert to GGUF naming convention
+            new_name = name.replace("model.layers.", "blk.").replace("model.", "")
+            return [(new_name, data_torch)]
+
+        return super().modify_tensors(data_torch, name, bid)
+
+    def prepare_tensors(self):
+        super().prepare_tensors()
+        if self._experts is not None:
+            # flatten `list[dict[str, Tensor]]` into `list[str]`
+            experts = [k for d in self._experts for k in d.keys()]
+            if len(experts) > 0:
+                raise ValueError(f"Unprocessed experts: {experts}")
+
+
 @ModelBase.register("GlmForCausalLM", "ChatGLMModel", "ChatGLMForConditionalGeneration")
 class ChatGLMModel(TextModel):
     model_arch = gguf.MODEL_ARCH.CHATGLM
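The set_gguf_parameters() additions record a sigmoid gating function, a routed scaling factor, and an optional top-k probability normalisation flag, while modify_tensors() maps the per-expert score-correction bias to ffn_gate_inp.bias. As a rough, hedged illustration of what these hyper-parameters control, here is a toy DeepSeek-style sigmoid router with made-up sizes; the real routing runs in llama.cpp's graph code, not in this converter:

import torch

n_experts, top_k, routed_scaling_factor = 8, 2, 1.0   # toy values, not GLM 4.5's real config
logits = torch.randn(n_experts)                        # router logits for one token
bias = torch.zeros(n_experts)                          # e_score_correction_bias -> ffn_gate_inp.bias

scores = torch.sigmoid(logits)                         # EXPERT_GATING_FUNC = SIGMOID
selected = torch.topk(scores + bias, top_k).indices    # bias influences selection only
weights = scores[selected]
weights = weights / weights.sum()                      # norm_topk_prob
weights = weights * routed_scaling_factor              # expert_weights_scale
print(selected.tolist(), weights.tolist())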

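The routed-expert branch of modify_tensors() buffers each expert's 2D down_proj/gate_proj/up_proj weights per block and, once all have been seen, stacks them into a single 3D tensor named blk.{bid}.ffn_*_exps.weight. A minimal standalone sketch of that stacking, using toy shapes (3 experts, hidden size 8, expert FFN size 4) purely for illustration:

import torch

n_experts, n_ff_exp, n_embd, bid = 3, 4, 8, 1  # toy sizes and block index

# per-expert 2D weights as they appear in the HF checkpoint
experts = {
    f"model.layers.{bid}.mlp.experts.{xid}.gate_proj.weight": torch.randn(n_ff_exp, n_embd)
    for xid in range(n_experts)
}

# stack along a new leading expert dimension, as the converter does
merged = torch.stack(
    [experts[f"model.layers.{bid}.mlp.experts.{xid}.gate_proj.weight"] for xid in range(n_experts)],
    dim=0,
)
print(merged.shape)  # torch.Size([3, 4, 8]); written out as blk.1.ffn_gate_exps.weight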
gguf-py/gguf/constants.py

Lines changed: 31 additions & 0 deletions
@@ -354,6 +354,7 @@ class MODEL_ARCH(IntEnum):
     DEEPSEEK2 = auto()
     CHATGLM = auto()
     GLM4 = auto()
+    GLM4_MOE = auto()
     BITNET = auto()
     T5 = auto()
     T5ENCODER = auto()
@@ -422,6 +423,9 @@ class MODEL_TENSOR(IntEnum):
     FFN_GATE_EXP = auto()
     FFN_DOWN_EXP = auto()
     FFN_UP_EXP = auto()
+    FFN_GATE_EXPS = auto()  # merged experts
+    FFN_DOWN_EXPS = auto()  # merged experts
+    FFN_UP_EXPS = auto()  # merged experts
     FFN_GATE_SHEXP = auto()
     FFN_DOWN_SHEXP = auto()
     FFN_UP_SHEXP = auto()
@@ -673,6 +677,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.DEEPSEEK2: "deepseek2",
     MODEL_ARCH.CHATGLM: "chatglm",
     MODEL_ARCH.GLM4: "glm4",
+    MODEL_ARCH.GLM4_MOE: "glm4moe",
     MODEL_ARCH.BITNET: "bitnet",
     MODEL_ARCH.T5: "t5",
     MODEL_ARCH.T5ENCODER: "t5encoder",
@@ -747,6 +752,9 @@ class MODEL_TENSOR(IntEnum):
     MODEL_TENSOR.FFN_GATE_EXP: "blk.{bid}.ffn_gate_exps",
     MODEL_TENSOR.FFN_DOWN_EXP: "blk.{bid}.ffn_down_exps",
     MODEL_TENSOR.FFN_UP_EXP: "blk.{bid}.ffn_up_exps",
+    MODEL_TENSOR.FFN_GATE_EXPS: "blk.{bid}.ffn_gate_exps",  # merged experts
+    MODEL_TENSOR.FFN_DOWN_EXPS: "blk.{bid}.ffn_down_exps",  # merged experts
+    MODEL_TENSOR.FFN_UP_EXPS: "blk.{bid}.ffn_up_exps",  # merged experts
     MODEL_TENSOR.FFN_EXP_PROBS_B: "blk.{bid}.exp_probs_b",
     MODEL_TENSOR.LAYER_OUT_NORM: "blk.{bid}.layer_output_norm",
     MODEL_TENSOR.PER_LAYER_TOKEN_EMBD: "per_layer_token_embd",  # gemma3n
@@ -2102,6 +2110,29 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.ATTN_POST_NORM,
         MODEL_TENSOR.FFN_POST_NORM,
     ],
+    MODEL_ARCH.GLM4_MOE: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.ATTN_Q_NORM,
+        MODEL_TENSOR.ATTN_K_NORM,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE,  # dense layers
+        MODEL_TENSOR.FFN_DOWN,  # dense layers
+        MODEL_TENSOR.FFN_UP,  # dense layers
+        MODEL_TENSOR.FFN_GATE_INP,
+        MODEL_TENSOR.FFN_GATE_EXPS,
+        MODEL_TENSOR.FFN_DOWN_EXPS,
+        MODEL_TENSOR.FFN_UP_EXPS,
+        MODEL_TENSOR.FFN_GATE_SHEXP,
+        MODEL_TENSOR.FFN_DOWN_SHEXP,
+        MODEL_TENSOR.FFN_UP_SHEXP,
+    ],
     MODEL_ARCH.BITNET: [
         MODEL_TENSOR.ATTN_Q,
         MODEL_TENSOR.ATTN_K,
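The new FFN_*_EXPS entries reuse the blk.{bid}.ffn_*_exps name pattern of the existing per-expert tensors; {bid} is substituted with the block index when tensor names are emitted, and a ".weight" suffix is appended for the actual tensor. A trivial illustration of how one of these format strings resolves (standalone, not going through the gguf-py writer):

# value taken from TENSOR_NAMES above
FFN_GATE_EXPS = "blk.{bid}.ffn_gate_exps"

print(FFN_GATE_EXPS.format(bid=5) + ".weight")  # -> blk.5.ffn_gate_exps.weight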

src/llama-arch.cpp

Lines changed: 27 additions & 0 deletions
@@ -62,6 +62,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_DEEPSEEK2, "deepseek2" },
     { LLM_ARCH_CHATGLM, "chatglm" },
     { LLM_ARCH_GLM4, "glm4" },
+    { LLM_ARCH_GLM4_MOE, "glm4moe" },
     { LLM_ARCH_BITNET, "bitnet" },
     { LLM_ARCH_T5, "t5" },
     { LLM_ARCH_T5ENCODER, "t5encoder" },
@@ -1389,6 +1390,32 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
         },
     },
+    {
+        LLM_ARCH_GLM4_MOE,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_OUTPUT, "output" },
+            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+            { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
+            { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
+            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, // dense layers
+            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, // dense layers
+            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, // dense layers
+            { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
+            { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
+            { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
+            { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
+            { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
+            { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
+            { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
+        },
+    },
     {
         LLM_ARCH_BITNET,
         {

src/llama-arch.h

Lines changed: 1 addition & 0 deletions
@@ -66,6 +66,7 @@ enum llm_arch {
     LLM_ARCH_DEEPSEEK2,
     LLM_ARCH_CHATGLM,
     LLM_ARCH_GLM4,
+    LLM_ARCH_GLM4_MOE,
     LLM_ARCH_BITNET,
     LLM_ARCH_T5,
     LLM_ARCH_T5ENCODER,
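After converting a GLM 4.5 checkpoint with the updated script, one way to sanity-check the result is to open the output with gguf-py's GGUFReader and look for the merged expert tensors; the output file name here is hypothetical:

from gguf import GGUFReader  # same gguf-py package the converter uses

reader = GGUFReader("glm-4.5-moe.gguf")  # hypothetical converter output
exp_tensors = sorted({t.name for t in reader.tensors if "_exps." in t.name})
print(exp_tensors[:3])  # expect names like blk.3.ffn_down_exps.weight, blk.3.ffn_gate_exps.weight, ...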
