Commit c4dbf69

feat: support GLM 4.5 family of models
1 parent 0a5036b · commit c4dbf69

File tree

5 files changed: +430 −0 lines changed


convert_hf_to_gguf.py

Lines changed: 111 additions & 0 deletions
```diff
@@ -6578,6 +6578,117 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         return super().modify_tensors(data_torch, name, bid)
 
 
+@ModelBase.register("Glm4MoeForCausalLM")
+class Glm4MoeModel(TextModel):
+    model_arch = gguf.MODEL_ARCH.GLM4_MOE
+
+    def set_vocab(self):
+        from transformers import AutoTokenizer
+        tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True)
+        special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
+        tokens, toktypes, tokpre = self.get_vocab_base()
+        self.gguf_writer.add_tokenizer_model("gpt2")
+        self.gguf_writer.add_tokenizer_pre(tokpre)
+        self.gguf_writer.add_token_list(tokens)
+        self.gguf_writer.add_token_types(toktypes)
+        special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
+        special_vocab._set_special_token("eos", tokenizer.get_added_vocab()["<|endoftext|>"])
+        special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"])
+        special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"])
+        special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["<|endoftext|>"])
+        special_vocab.add_to_gguf(self.gguf_writer)
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        if (rope_dim := self.hparams.get("head_dim")) is None:
+            rope_dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
+        self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.hparams.get("partial_rotary_factor", 0.5)))
+
+        # MoE parameters
+        if (n_experts := self.hparams.get("n_routed_experts")) is not None:
+            self.gguf_writer.add_expert_count(n_experts)
+        if (n_experts_used := self.hparams.get("num_experts_per_tok")) is not None:
+            self.gguf_writer.add_expert_used_count(n_experts_used)
+        if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None:
+            self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)
+        if (n_shared_experts := self.hparams.get("n_shared_experts")) is not None:
+            self.gguf_writer.add_expert_shared_count(n_shared_experts)
+        if (first_k_dense_replace := self.hparams.get("first_k_dense_replace")) is not None:
+            self.gguf_writer.add_leading_dense_block_count(first_k_dense_replace)
+
+        # Expert gating function (sigmoid for GLM4_MOE)
+        self.gguf_writer.add_expert_gating_func(2)  # LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID
+
+        # Routed scaling factor
+        if (routed_scaling_factor := self.hparams.get("routed_scaling_factor")) is not None:
+            self.gguf_writer.add_expert_weights_scale(routed_scaling_factor)
+
+        # Normalise topk probabilities
+        if (norm_topk_prob := self.hparams.get("norm_topk_prob")) is not None:
+            self.gguf_writer.add_expert_weights_norm(norm_topk_prob)
+
+    _experts: list[dict[str, Tensor]] | None = None
+    _shared_experts: list[dict[str, Tensor]] | None = None
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        if name.startswith("model.visual."):  # ignore visual part
+            return []
+        elif name.startswith("model.language_model."):
+            name = name.replace("language_model.", "")  # for multimodal variants
+
+        # Handle routed experts
+        if name.find("mlp.experts") != -1 and "shared_experts" not in name:
+            n_experts = self.hparams["n_routed_experts"]
+            assert bid is not None
+
+            if self._experts is None:
+                self._experts = [{} for _ in range(self.block_count)]
+
+            self._experts[bid][name] = data_torch
+
+            if len(self._experts[bid]) >= n_experts * 3:
+                tensors: list[tuple[str, Tensor]] = []
+
+                # merge the experts into a single 3d tensor
+                for w_name in ["down_proj", "gate_proj", "up_proj"]:
+                    datas: list[Tensor] = []
+
+                    for xid in range(n_experts):
+                        ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
+                        datas.append(self._experts[bid][ename])
+                        del self._experts[bid][ename]
+
+                    data_torch = torch.stack(datas, dim=0)
+                    merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
+                    new_name = self.map_tensor_name(merged_name)
+                    tensors.append((new_name, data_torch))
+                return tensors
+            else:
+                return []
+
+        # Handle shared experts - map to shared expert tensors
+        if "shared_experts" in name:
+            if "gate_proj" in name:
+                new_name = name.replace("shared_experts.gate_proj.weight", "ffn_gate_shexp.weight")
+            elif "up_proj" in name:
+                new_name = name.replace("shared_experts.up_proj.weight", "ffn_up_shexp.weight")
+            elif "down_proj" in name:
+                new_name = name.replace("shared_experts.down_proj.weight", "ffn_down_shexp.weight")
+            else:
+                new_name = name
+            return [(self.map_tensor_name(new_name), data_torch)]
+
+        return super().modify_tensors(data_torch, name, bid)
+
+    def prepare_tensors(self):
+        super().prepare_tensors()
+        if self._experts is not None:
+            # flatten `list[dict[str, Tensor]]` into `list[str]`
+            experts = [k for d in self._experts for k in d.keys()]
+            if len(experts) > 0:
+                raise ValueError(f"Unprocessed experts: {experts}")
+
+
 @ModelBase.register("GlmForCausalLM", "ChatGLMModel", "ChatGLMForConditionalGeneration")
 class ChatGLMModel(TextModel):
     model_arch = gguf.MODEL_ARCH.CHATGLM
```
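The expert-merge branch in `modify_tensors` buffers each routed expert's 2-D weights per block and emits only once all of them have arrived; the `>= n_experts * 3` check works because every routed expert contributes exactly three tensors (`gate_proj`, `up_proj`, `down_proj`). A minimal standalone sketch of the stacking step, using toy dimensions rather than GLM-4.5's real sizes:

```python
# Toy sketch of the expert-merge step above; dimensions are illustrative,
# not GLM-4.5's actual hidden/FFN sizes.
import torch

n_experts, n_embd, n_ff = 4, 8, 16  # toy sizes

# per-expert gate_proj weights for block 0, as they would arrive one at a time
buffered = {
    f"model.layers.0.mlp.experts.{xid}.gate_proj.weight": torch.randn(n_ff, n_embd)
    for xid in range(n_experts)
}

# stack along a new leading expert dimension, mirroring torch.stack(datas, dim=0)
datas = [buffered[f"model.layers.0.mlp.experts.{xid}.gate_proj.weight"]
         for xid in range(n_experts)]
merged = torch.stack(datas, dim=0)

print(merged.shape)  # torch.Size([4, 16, 8]) -> becomes blk.0.ffn_gate_exps
```

With the class registered, conversion should go through the script's usual entry point, e.g. `python convert_hf_to_gguf.py <glm-4.5-model-dir> --outfile glm-4.5.gguf` (exact flags per the script's `--help`).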

gguf-py/gguf/constants.py

Lines changed: 23 additions & 0 deletions
```diff
@@ -354,6 +354,7 @@ class MODEL_ARCH(IntEnum):
     DEEPSEEK2 = auto()
     CHATGLM = auto()
     GLM4 = auto()
+    GLM4_MOE = auto()
     BITNET = auto()
     T5 = auto()
     T5ENCODER = auto()
@@ -673,6 +674,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.DEEPSEEK2: "deepseek2",
     MODEL_ARCH.CHATGLM: "chatglm",
     MODEL_ARCH.GLM4: "glm4",
+    MODEL_ARCH.GLM4_MOE: "glm4moe",
     MODEL_ARCH.BITNET: "bitnet",
     MODEL_ARCH.T5: "t5",
     MODEL_ARCH.T5ENCODER: "t5encoder",
@@ -2102,6 +2104,27 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.ATTN_POST_NORM,
         MODEL_TENSOR.FFN_POST_NORM,
     ],
+    MODEL_ARCH.GLM4_MOE: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.ATTN_Q_NORM,
+        MODEL_TENSOR.ATTN_K_NORM,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE_INP,
+        MODEL_TENSOR.FFN_GATE_EXPS,
+        MODEL_TENSOR.FFN_DOWN_EXPS,
+        MODEL_TENSOR.FFN_UP_EXPS,
+        MODEL_TENSOR.FFN_GATE_SHEXP,
+        MODEL_TENSOR.FFN_DOWN_SHEXP,
+        MODEL_TENSOR.FFN_UP_SHEXP,
+        MODEL_TENSOR.ATTN_POST_NORM,
+    ],
     MODEL_ARCH.BITNET: [
         MODEL_TENSOR.ATTN_Q,
         MODEL_TENSOR.ATTN_K,
```
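To sanity-check a converted file against the tensor list registered above, gguf-py's `GGUFReader` can enumerate tensor names; a minimal sketch, assuming a GLM4_MOE model has already been converted (the output file name is hypothetical):

```python
# Minimal sketch: list the expert tensors of a converted model.
# "glm-4.5.gguf" is a hypothetical output path from convert_hf_to_gguf.py.
from gguf import GGUFReader

reader = GGUFReader("glm-4.5.gguf")

# names should follow the MODEL_TENSOR entries above, e.g.
# blk.0.ffn_gate_exps, blk.0.ffn_gate_shexp, ...
for t in reader.tensors:
    if "_exps" in t.name or "_shexp" in t.name:
        print(t.name, list(t.shape))
```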

src/llama-arch.cpp

Lines changed: 25 additions & 0 deletions
```diff
@@ -62,6 +62,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_DEEPSEEK2, "deepseek2" },
     { LLM_ARCH_CHATGLM, "chatglm" },
     { LLM_ARCH_GLM4, "glm4" },
+    { LLM_ARCH_GLM4_MOE, "glm4moe" },
     { LLM_ARCH_BITNET, "bitnet" },
     { LLM_ARCH_T5, "t5" },
     { LLM_ARCH_T5ENCODER, "t5encoder" },
@@ -1389,6 +1390,30 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_NAMES = {
             { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
         },
     },
+    {
+        LLM_ARCH_GLM4_MOE,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_OUTPUT, "output" },
+            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+            { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
+            { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
+            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
+            { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
+            { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
+            { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
+            { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
+            { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
+            { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
+            { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
+        },
+    },
     {
         LLM_ARCH_BITNET,
         {
```
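The `%d` placeholders in these templates are filled with the block index when tensors are resolved at load time; a quick sketch of the expansion for one block (templates copied from the table above):

```python
# Sketch: expand a few of the per-block GLM4_MOE name templates for block 0.
templates = [
    "blk.%d.attn_norm",
    "blk.%d.ffn_gate_inp",
    "blk.%d.ffn_gate_exps",
    "blk.%d.ffn_gate_shexp",
]
for tmpl in templates:
    print(tmpl % 0)  # blk.0.attn_norm, blk.0.ffn_gate_inp, ...
```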

src/llama-arch.h

Lines changed: 1 addition & 0 deletions
```diff
@@ -66,6 +66,7 @@ enum llm_arch {
    LLM_ARCH_DEEPSEEK2,
    LLM_ARCH_CHATGLM,
    LLM_ARCH_GLM4,
+   LLM_ARCH_GLM4_MOE,
    LLM_ARCH_BITNET,
    LLM_ARCH_T5,
    LLM_ARCH_T5ENCODER,
```
