Skip to content

Commit 8c70887

Browse files
sammcjCISCslaren
authored andcommitted
model: support GLM 4.5 family of models (ggml-org#14939)
* model: Add GLM 4.5 (ggml-org#14921) Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com> * Merge in PR suggestions Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com> * model: Add GLM 4.5 family of models (ggml-org#14921) 1. Updated tensor_mapping.py with NextN tensor mappings - Added proper tensor mappings for all NextN/MTP tensors in /Users/samm/git/llama.cpp/gguf-py/gguf/tensor_mapping.py - Added mappings for: eh_proj, embed_tokens, enorm, hnorm, shared_head.head, shared_head.norm 2. Added num_nextn_predict_layers configuration - Added LLM_KV_NUM_NEXTN_PREDICT_LAYERS constant to llama-arch.h and llama-arch.cpp - Added num_nextn_predict_layers field to llama_hparams struct - Updated GLM4_MOE parameter loading in llama-model.cpp to read this parameter - Modified tensor loading logic to conditionally load NextN tensors based on num_nextn_predict_layers - Added GGUF writer support in gguf_writer.py with add_num_nextn_predict_layers() method - Updated conversion script to extract and write this parameter from HuggingFace config 3. Added FIM tokens for GLM4_MOE - Added GLM-4.5's FIM tokens to llama-vocab.cpp: - <|code_prefix|> for FIM_PRE - <|code_suffix|> for FIM_SUF - <|code_middle|> for FIM_MID 4. Removed manual NextN tensor handling - Removed the special-case handling in convert_hf_to_gguf.py that manually mapped NextN tensors - NextN tensors are now handled automatically through the proper tensor mapping system * glm 4.5 update tensors names * model: glm 4.5 apply suggestions from code review Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com> * Update src/llama-model.cpp Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com> * model: glm 4.5 apply suggestions from code review Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com> * model: glm 4.5 apply suggestions from code review * Apply suggestions from code review * patch broken chat template * typings fix * add TENSOR_SKIP flag Co-authored-by: Diego Devesa <slarengh@gmail.com> * Update src/llama-model-loader.h Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com> --------- Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com> Co-authored-by: Diego Devesa <slarengh@gmail.com>
1 parent c9f6675 commit 8c70887

14 files changed

+592
-7
lines changed

convert_hf_to_gguf.py

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1008,6 +1008,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
10081008
if chkhsh == "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2":
10091009
# ref: https://huggingface.co/THUDM/glm-4-9b-hf
10101010
res = "glm4"
1011+
if chkhsh == "9ca2dd618e8afaf09731a7cf6e2105b373ba6a1821559f258b272fe83e6eb902":
1012+
# ref: https://huggingface.co/zai-org/GLM-4.5-Air
1013+
res = "glm4"
10111014
if chkhsh == "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35":
10121015
# ref: https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0
10131016
res = "minerva-7b"
@@ -7026,6 +7029,139 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
70267029
return super().modify_tensors(data_torch, name, bid)
70277030

70287031

7032+
@ModelBase.register("Glm4MoeForCausalLM")
7033+
class Glm4MoeModel(TextModel):
7034+
model_arch = gguf.MODEL_ARCH.GLM4_MOE
7035+
7036+
def __init__(self, *args, **kwargs):
7037+
super().__init__(*args, **kwargs)
7038+
# GLM4_MOE has num_hidden_layers + 1 actual layers (including NextN layer)
7039+
self.block_count = self.hparams["num_hidden_layers"] + self.hparams.get("num_nextn_predict_layers", 0)
7040+
self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
7041+
7042+
def set_vocab(self):
7043+
from transformers import AutoTokenizer
7044+
7045+
tokenizer = AutoTokenizer.from_pretrained(self.dir_model)
7046+
special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
7047+
tokens, toktypes, tokpre = self.get_vocab_base()
7048+
self.gguf_writer.add_tokenizer_model("gpt2")
7049+
self.gguf_writer.add_tokenizer_pre(tokpre)
7050+
self.gguf_writer.add_token_list(tokens)
7051+
self.gguf_writer.add_token_types(toktypes)
7052+
7053+
# Special tokens
7054+
# Note: Using <|endoftext|> (151329) for eot causes endless generation
7055+
special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["[gMASK]"]) # 151331
7056+
special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"]) # 151336
7057+
special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"]) # 151329
7058+
special_vocab._set_special_token("eom", tokenizer.get_added_vocab()["<|observation|>"]) # 151338
7059+
7060+
# Patch broken chat template
7061+
if isinstance(special_vocab.chat_template, str) and "visible_text(m.content).endswith" in special_vocab.chat_template:
7062+
special_vocab.chat_template = special_vocab.chat_template.replace(
7063+
"""{{ visible_text(m.content) }}\n{{- '/nothink' if (enable_thinking is defined and not enable_thinking and not visible_text(m.content).endswith("/nothink")) else '' -}}""",
7064+
"""{% set content = visible_text(m.content) %}{{ content }}\n{{- '/nothink' if (enable_thinking is defined and not enable_thinking and not content.endswith("/nothink")) else '' -}}""")
7065+
7066+
special_vocab.add_to_gguf(self.gguf_writer)
7067+
7068+
def set_gguf_parameters(self):
7069+
super().set_gguf_parameters()
7070+
if (rope_dim := self.hparams.get("head_dim")) is None:
7071+
rope_dim = (
7072+
self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
7073+
)
7074+
self.gguf_writer.add_rope_dimension_count(
7075+
int(rope_dim * self.hparams.get("partial_rotary_factor", 0.5))
7076+
)
7077+
7078+
# MoE parameters - Use only routed expert count (shared experts handled separately)
7079+
if (n_routed_experts := self.hparams.get("n_routed_experts")) is not None:
7080+
self.gguf_writer.add_expert_count(n_routed_experts)
7081+
if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None:
7082+
self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)
7083+
if (n_shared_experts := self.hparams.get("n_shared_experts")) is not None:
7084+
self.gguf_writer.add_expert_shared_count(n_shared_experts)
7085+
if (first_k_dense_replace := self.hparams.get("first_k_dense_replace")) is not None:
7086+
self.gguf_writer.add_leading_dense_block_count(first_k_dense_replace)
7087+
7088+
# Expert gating function (sigmoid for GLM4_MOE)
7089+
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
7090+
7091+
# Routed scaling factor
7092+
if (routed_scaling_factor := self.hparams.get("routed_scaling_factor")) is not None:
7093+
self.gguf_writer.add_expert_weights_scale(routed_scaling_factor)
7094+
7095+
# Normalise topk probabilities
7096+
if (norm_topk_prob := self.hparams.get("norm_topk_prob")) is not None:
7097+
self.gguf_writer.add_expert_weights_norm(norm_topk_prob)
7098+
7099+
# NextN/MTP prediction layers
7100+
if (num_nextn_predict_layers := self.hparams.get("num_nextn_predict_layers")) is not None:
7101+
self.gguf_writer.add_nextn_predict_layers(num_nextn_predict_layers)
7102+
7103+
_experts: list[dict[str, Tensor]] | None = None
7104+
7105+
def modify_tensors(
7106+
self, data_torch: Tensor, name: str, bid: int | None
7107+
) -> Iterable[tuple[str, Tensor]]:
7108+
if name.startswith("model.visual."): # ignore visual part
7109+
return []
7110+
elif name.startswith("model.language_model."):
7111+
name = name.replace("language_model.", "") # for multimodal variants
7112+
7113+
# Handle main token embedding (but not layer-specific NextN embeddings)
7114+
if name == "model.embed_tokens.weight" and ".layers." not in name:
7115+
return [(self.map_tensor_name("token_embd.weight"), data_torch)]
7116+
7117+
# Handle routed experts
7118+
if name.find("mlp.experts") != -1:
7119+
n_experts = self.hparams["n_routed_experts"]
7120+
assert bid is not None
7121+
7122+
if self._experts is None:
7123+
self._experts = [{} for _ in range(self.block_count)]
7124+
7125+
self._experts[bid][name] = data_torch
7126+
7127+
if len(self._experts[bid]) >= n_experts * 3:
7128+
tensors: list[tuple[str, Tensor]] = []
7129+
7130+
# merge the experts into a single 3d tensor
7131+
for w_name in ["down_proj", "gate_proj", "up_proj"]:
7132+
datas: list[Tensor] = []
7133+
7134+
for xid in range(n_experts):
7135+
ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
7136+
datas.append(self._experts[bid][ename])
7137+
del self._experts[bid][ename]
7138+
7139+
data_torch = torch.stack(datas, dim=0)
7140+
7141+
merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
7142+
7143+
new_name = self.map_tensor_name(merged_name)
7144+
tensors.append((new_name, data_torch))
7145+
return tensors
7146+
else:
7147+
return []
7148+
7149+
if name.endswith("e_score_correction_bias"):
7150+
name = name.replace("e_score_correction_bias", "e_score_correction.bias")
7151+
7152+
new_name = self.map_tensor_name(name)
7153+
7154+
return [(new_name, data_torch)]
7155+
7156+
def prepare_tensors(self):
7157+
super().prepare_tensors()
7158+
if self._experts is not None:
7159+
# flatten `list[dict[str, Tensor]]` into `list[str]`
7160+
experts = [k for d in self._experts for k in d.keys()]
7161+
if len(experts) > 0:
7162+
raise ValueError(f"Unprocessed experts: {experts}")
7163+
7164+
70297165
@ModelBase.register("GlmForCausalLM", "ChatGLMModel", "ChatGLMForConditionalGeneration")
70307166
class ChatGLMModel(TextModel):
70317167
model_arch = gguf.MODEL_ARCH.CHATGLM

convert_hf_to_gguf_update.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@ class TOKENIZER_TYPE(IntEnum):
139139
{"name": "chatglm-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-chat", "chkhsh": "b6e8e1518dc4305be2fe39c313ed643381c4da5db34a98f6a04c093f8afbe99b"},
140140
{"name": "chatglm-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-chat", "chkhsh": "81d72c7348a9f0ebe86f23298d37debe0a5e71149e29bd283904c02262b27516"},
141141
{"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-hf", "chkhsh": "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2"},
142+
{"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/zai-org/GLM-4.5-Air", "chkhsh": "9ca2dd618e8afaf09731a7cf6e2105b373ba6a1821559f258b272fe83e6eb902"},
142143
{"name": "minerva-7b", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0", "chkhsh": "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35"},
143144
{"name": "hunyuan", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Hunyuan-A13B-Instruct", "chkhsh": "7e57df22b1fe23a7b1e1c7f3dc4e3f96d43a4eb0836d0c6bdc3436d7b2f1c664"},
144145
{"name": "hunyuan-dense", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Hunyuan-4B-Instruct", "chkhsh": "bba3b3366b646dbdded5dbc42d59598b849371afc42f7beafa914afaa5b70aa6"},

gguf-py/gguf/constants.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ class LLM:
105105
EXPERT_WEIGHTS_NORM = "{arch}.expert_weights_norm"
106106
EXPERT_GATING_FUNC = "{arch}.expert_gating_func"
107107
MOE_EVERY_N_LAYERS = "{arch}.moe_every_n_layers"
108+
NEXTN_PREDICT_LAYERS = "{arch}.nextn_predict_layers"
108109
POOLING_TYPE = "{arch}.pooling_type"
109110
LOGIT_SCALE = "{arch}.logit_scale"
110111
DECODER_START_TOKEN_ID = "{arch}.decoder_start_token_id"
@@ -358,6 +359,7 @@ class MODEL_ARCH(IntEnum):
358359
DEEPSEEK2 = auto()
359360
CHATGLM = auto()
360361
GLM4 = auto()
362+
GLM4_MOE = auto()
361363
BITNET = auto()
362364
BITNET_25 = auto()
363365
T5 = auto()
@@ -616,6 +618,13 @@ class MODEL_TENSOR(IntEnum):
616618
A_MMPROJ_FC = auto()
617619
A_MM_NORM_PRE = auto()
618620
A_MM_NORM_MID = auto()
621+
# nextn/mtp
622+
NEXTN_EH_PROJ = auto()
623+
NEXTN_EMBED_TOKENS = auto()
624+
NEXTN_ENORM = auto()
625+
NEXTN_HNORM = auto()
626+
NEXTN_SHARED_HEAD_HEAD = auto()
627+
NEXTN_SHARED_HEAD_NORM = auto()
619628

620629

621630
MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
@@ -680,6 +689,7 @@ class MODEL_TENSOR(IntEnum):
680689
MODEL_ARCH.DEEPSEEK2: "deepseek2",
681690
MODEL_ARCH.CHATGLM: "chatglm",
682691
MODEL_ARCH.GLM4: "glm4",
692+
MODEL_ARCH.GLM4_MOE: "glm4moe",
683693
MODEL_ARCH.BITNET: "bitnet",
684694
MODEL_ARCH.BITNET_25: "bitnet-25",
685695
MODEL_ARCH.T5: "t5",
@@ -939,6 +949,13 @@ class MODEL_TENSOR(IntEnum):
939949
MODEL_TENSOR.A_MMPROJ_FC: "mm.a.fc",
940950
MODEL_TENSOR.A_MM_NORM_PRE: "mm.a.norm_pre",
941951
MODEL_TENSOR.A_MM_NORM_MID: "mm.a.norm_mid",
952+
# NextN/MTP
953+
MODEL_TENSOR.NEXTN_EH_PROJ: "blk.{bid}.nextn.eh_proj",
954+
MODEL_TENSOR.NEXTN_EMBED_TOKENS: "blk.{bid}.nextn.embed_tokens",
955+
MODEL_TENSOR.NEXTN_ENORM: "blk.{bid}.nextn.enorm",
956+
MODEL_TENSOR.NEXTN_HNORM: "blk.{bid}.nextn.hnorm",
957+
MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD: "blk.{bid}.nextn.shared_head_head",
958+
MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM: "blk.{bid}.nextn.shared_head_norm",
942959
}
943960

944961
MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
@@ -2127,6 +2144,37 @@ class MODEL_TENSOR(IntEnum):
21272144
MODEL_TENSOR.ATTN_POST_NORM,
21282145
MODEL_TENSOR.FFN_POST_NORM,
21292146
],
2147+
MODEL_ARCH.GLM4_MOE: [
2148+
MODEL_TENSOR.TOKEN_EMBD,
2149+
MODEL_TENSOR.OUTPUT_NORM,
2150+
MODEL_TENSOR.OUTPUT,
2151+
MODEL_TENSOR.ATTN_NORM,
2152+
MODEL_TENSOR.ATTN_POST_NORM,
2153+
MODEL_TENSOR.ATTN_Q,
2154+
MODEL_TENSOR.ATTN_K,
2155+
MODEL_TENSOR.ATTN_V,
2156+
MODEL_TENSOR.ATTN_OUT,
2157+
MODEL_TENSOR.ATTN_Q_NORM,
2158+
MODEL_TENSOR.ATTN_K_NORM,
2159+
MODEL_TENSOR.FFN_GATE,
2160+
MODEL_TENSOR.FFN_DOWN,
2161+
MODEL_TENSOR.FFN_UP,
2162+
MODEL_TENSOR.FFN_GATE_INP,
2163+
MODEL_TENSOR.FFN_GATE_EXP,
2164+
MODEL_TENSOR.FFN_DOWN_EXP,
2165+
MODEL_TENSOR.FFN_UP_EXP,
2166+
MODEL_TENSOR.FFN_GATE_SHEXP,
2167+
MODEL_TENSOR.FFN_DOWN_SHEXP,
2168+
MODEL_TENSOR.FFN_UP_SHEXP,
2169+
MODEL_TENSOR.FFN_EXP_PROBS_B,
2170+
# NextN/MTP tensors - preserved but unused
2171+
MODEL_TENSOR.NEXTN_EH_PROJ,
2172+
MODEL_TENSOR.NEXTN_EMBED_TOKENS,
2173+
MODEL_TENSOR.NEXTN_ENORM,
2174+
MODEL_TENSOR.NEXTN_HNORM,
2175+
MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD,
2176+
MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM,
2177+
],
21302178
MODEL_ARCH.BITNET: [
21312179
MODEL_TENSOR.ATTN_Q,
21322180
MODEL_TENSOR.ATTN_K,

gguf-py/gguf/gguf_writer.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -849,6 +849,9 @@ def add_expert_gating_func(self, value: ExpertGatingFuncType) -> None:
849849
def add_moe_every_n_layers(self, value: int) -> None:
850850
self.add_uint32(Keys.LLM.MOE_EVERY_N_LAYERS.format(arch=self.arch), value)
851851

852+
def add_nextn_predict_layers(self, count: int) -> None:
853+
self.add_uint32(Keys.LLM.NEXTN_PREDICT_LAYERS.format(arch=self.arch), count)
854+
852855
def add_swin_norm(self, value: bool) -> None:
853856
self.add_bool(Keys.LLM.SWIN_NORM.format(arch=self.arch), value)
854857

gguf-py/gguf/tensor_mapping.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1372,6 +1372,31 @@ class TensorNameMap:
13721372
MODEL_TENSOR.A_MM_NORM_MID: (
13731373
"audio.multi_modal_projector.ln_mid", # ultravox
13741374
),
1375+
1376+
# NextN/MTP tensors for GLM4_MOE
1377+
MODEL_TENSOR.NEXTN_EH_PROJ: (
1378+
"model.layers.{bid}.eh_proj",
1379+
),
1380+
1381+
MODEL_TENSOR.NEXTN_EMBED_TOKENS: (
1382+
"model.layers.{bid}.embed_tokens",
1383+
),
1384+
1385+
MODEL_TENSOR.NEXTN_ENORM: (
1386+
"model.layers.{bid}.enorm",
1387+
),
1388+
1389+
MODEL_TENSOR.NEXTN_HNORM: (
1390+
"model.layers.{bid}.hnorm",
1391+
),
1392+
1393+
MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD: (
1394+
"model.layers.{bid}.shared_head.head",
1395+
),
1396+
1397+
MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM: (
1398+
"model.layers.{bid}.shared_head.norm",
1399+
),
13751400
}
13761401

13771402
# architecture-specific block mappings

src/llama-arch.cpp

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
6262
{ LLM_ARCH_DEEPSEEK2, "deepseek2" },
6363
{ LLM_ARCH_CHATGLM, "chatglm" },
6464
{ LLM_ARCH_GLM4, "glm4" },
65+
{ LLM_ARCH_GLM4_MOE, "glm4moe" },
6566
{ LLM_ARCH_BITNET, "bitnet" },
6667
{ LLM_ARCH_T5, "t5" },
6768
{ LLM_ARCH_T5ENCODER, "t5encoder" },
@@ -127,6 +128,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
127128
{ LLM_KV_EXPERT_WEIGHTS_NORM, "%s.expert_weights_norm" },
128129
{ LLM_KV_EXPERT_GATING_FUNC, "%s.expert_gating_func" },
129130
{ LLM_KV_MOE_EVERY_N_LAYERS, "%s.moe_every_n_layers" },
131+
{ LLM_KV_NEXTN_PREDICT_LAYERS, "%s.nextn_predict_layers" },
130132
{ LLM_KV_POOLING_TYPE, "%s.pooling_type" },
131133
{ LLM_KV_LOGIT_SCALE, "%s.logit_scale" },
132134
{ LLM_KV_DECODER_START_TOKEN_ID, "%s.decoder_start_token_id" },
@@ -1391,6 +1393,40 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
13911393
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
13921394
},
13931395
},
1396+
{
1397+
LLM_ARCH_GLM4_MOE,
1398+
{
1399+
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
1400+
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
1401+
{ LLM_TENSOR_OUTPUT, "output" },
1402+
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
1403+
{ LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
1404+
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
1405+
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
1406+
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
1407+
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
1408+
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
1409+
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
1410+
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
1411+
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
1412+
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
1413+
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
1414+
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
1415+
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
1416+
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
1417+
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
1418+
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
1419+
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
1420+
{ LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
1421+
// NextN/MTP tensors - preserved but unused (in final layer, dynamic layer number)
1422+
{ LLM_TENSOR_NEXTN_EH_PROJ, "blk.%d.nextn.eh_proj" },
1423+
{ LLM_TENSOR_NEXTN_EMBED_TOKENS, "blk.%d.nextn.embed_tokens" },
1424+
{ LLM_TENSOR_NEXTN_ENORM, "blk.%d.nextn.enorm" },
1425+
{ LLM_TENSOR_NEXTN_HNORM, "blk.%d.nextn.hnorm" },
1426+
{ LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "blk.%d.nextn.shared_head_head" },
1427+
{ LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "blk.%d.nextn.shared_head_norm" },
1428+
},
1429+
},
13941430
{
13951431
LLM_ARCH_BITNET,
13961432
{
@@ -2181,6 +2217,14 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
21812217
{LLM_TENSOR_SHORTCONV_CONV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_CONV}},
21822218
{LLM_TENSOR_SHORTCONV_INPROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
21832219
{LLM_TENSOR_SHORTCONV_OUTPROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
2220+
// NextN/MTP tensors are currently ignored (reserved for future MTP support)
2221+
// These tensors only exist in the last layer(s) and are treated as output tensors
2222+
{LLM_TENSOR_NEXTN_EH_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
2223+
{LLM_TENSOR_NEXTN_EMBED_TOKENS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}},
2224+
{LLM_TENSOR_NEXTN_ENORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}},
2225+
{LLM_TENSOR_NEXTN_HNORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
2226+
{LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
2227+
{LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
21842228
};
21852229

21862230
LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {}

src/llama-arch.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ enum llm_arch {
6666
LLM_ARCH_DEEPSEEK2,
6767
LLM_ARCH_CHATGLM,
6868
LLM_ARCH_GLM4,
69+
LLM_ARCH_GLM4_MOE,
6970
LLM_ARCH_BITNET,
7071
LLM_ARCH_T5,
7172
LLM_ARCH_T5ENCODER,
@@ -131,6 +132,7 @@ enum llm_kv {
131132
LLM_KV_EXPERT_WEIGHTS_NORM,
132133
LLM_KV_EXPERT_GATING_FUNC,
133134
LLM_KV_MOE_EVERY_N_LAYERS,
135+
LLM_KV_NEXTN_PREDICT_LAYERS,
134136
LLM_KV_POOLING_TYPE,
135137
LLM_KV_LOGIT_SCALE,
136138
LLM_KV_DECODER_START_TOKEN_ID,
@@ -409,6 +411,12 @@ enum llm_tensor {
409411
LLM_TENSOR_SHORTCONV_CONV,
410412
LLM_TENSOR_SHORTCONV_INPROJ,
411413
LLM_TENSOR_SHORTCONV_OUTPROJ,
414+
LLM_TENSOR_NEXTN_EH_PROJ,
415+
LLM_TENSOR_NEXTN_EMBED_TOKENS,
416+
LLM_TENSOR_NEXTN_ENORM,
417+
LLM_TENSOR_NEXTN_HNORM,
418+
LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD,
419+
LLM_TENSOR_NEXTN_SHARED_HEAD_NORM,
412420
};
413421

414422
enum llm_tensor_layer {

0 commit comments

Comments
 (0)