From 84ddaa5764e22b3f22cdbaefdc44c14348f59b14 Mon Sep 17 00:00:00 2001 From: Sam Date: Tue, 29 Jul 2025 17:52:59 +1000 Subject: [PATCH 01/13] model: Add GLM 4.5 (#14921) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Sigbjørn Skjæret --- convert_hf_to_gguf.py | 148 ++++++++++++++++- convert_hf_to_gguf_update.py | 1 + gguf-py/gguf/constants.py | 46 ++++++ models/templates/README.md | 3 +- src/llama-arch.cpp | 43 +++++ src/llama-arch.h | 7 + src/llama-graph.cpp | 8 +- src/llama-kv-cache-unified.cpp | 4 + src/llama-model.cpp | 288 +++++++++++++++++++++++++++++++++ src/llama-model.h | 2 + 10 files changed, 544 insertions(+), 6 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 3f5cefe007cca..847a7c4cb6c3e 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -678,6 +678,9 @@ def get_vocab_base_pre(self, tokenizer) -> str: if chkhsh == "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2": # ref: https://huggingface.co/THUDM/glm-4-9b-hf res = "glm4" + if chkhsh == "9ca2dd618e8afaf09731a7cf6e2105b373ba6a1821559f258b272fe83e6eb902": + # ref: https://huggingface.co/zai-org/GLM-4.5-Air + res = "glm4" if chkhsh == "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35": # ref: https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0 res = "minerva-7b" @@ -6578,6 +6581,149 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter return super().modify_tensors(data_torch, name, bid) +@ModelBase.register("Glm4MoeForCausalLM") +class Glm4MoeModel(TextModel): + model_arch = gguf.MODEL_ARCH.GLM4_MOE + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # GLM4_MOE has num_hidden_layers + 1 actual layers (including NextN layer) + self.block_count = self.hparams["num_hidden_layers"] + 1 + self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count) + + def set_vocab(self): + from transformers import AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained(self.dir_model) + special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True) + tokens, toktypes, tokpre = self.get_vocab_base() + self.gguf_writer.add_tokenizer_model("gpt2") + self.gguf_writer.add_tokenizer_pre(tokpre) + self.gguf_writer.add_token_list(tokens) + self.gguf_writer.add_token_types(toktypes) + + # Special tokens + # Note: Using <|endoftext|> (151329) for eos and eot causes endless generation + special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["[gMASK]"]) # 151331 + special_vocab._set_special_token("eos", tokenizer.get_added_vocab()["<|user|>"]) # 151336 - end of + special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"]) # 151336 - same as EOS + special_vocab._set_special_token("eog", tokenizer.get_added_vocab()["<|user|>"]) # 151336 - same as EOS + special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"]) # 151329 + special_vocab._set_special_token("eom", tokenizer.get_added_vocab()["<|observation|>"]) # 151338 + + if "" in tokenizer.get_added_vocab(): + special_vocab._set_special_token("sop", tokenizer.get_added_vocab()[""]) # 151333 + if "" in tokenizer.get_added_vocab(): + special_vocab._set_special_token("eop", tokenizer.get_added_vocab()[""]) # 151334 + + special_vocab.add_to_gguf(self.gguf_writer) + + def set_gguf_parameters(self): + super().set_gguf_parameters() + if (rope_dim := self.hparams.get("head_dim")) is None: + rope_dim = ( + self.hparams["hidden_size"] // 
self.hparams["num_attention_heads"] + ) + self.gguf_writer.add_rope_dimension_count( + int(rope_dim * self.hparams.get("partial_rotary_factor", 0.5)) + ) + + # MoE parameters - Use only routed expert count (shared experts handled separately) + if (n_routed_experts := self.hparams.get("n_routed_experts")) is not None: + self.gguf_writer.add_expert_count(n_routed_experts) + if (num_experts_per_tok := self.hparams.get("num_experts_per_tok")) is not None: + self.gguf_writer.add_expert_used_count(num_experts_per_tok) + if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None: + self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size) + if (n_shared_experts := self.hparams.get("n_shared_experts")) is not None: + self.gguf_writer.add_expert_shared_count(n_shared_experts) + if (first_k_dense_replace := self.hparams.get("first_k_dense_replace")) is not None: + self.gguf_writer.add_leading_dense_block_count(first_k_dense_replace) + + # Expert gating function (sigmoid for GLM4_MOE) + self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID) + + # Routed scaling factor + if (routed_scaling_factor := self.hparams.get("routed_scaling_factor")) is not None: + self.gguf_writer.add_expert_weights_scale(routed_scaling_factor) + + # Normalise topk probabilities + if (norm_topk_prob := self.hparams.get("norm_topk_prob")) is not None: + self.gguf_writer.add_expert_weights_norm(norm_topk_prob) + + _experts: list[dict[str, Tensor]] | None = None + + def modify_tensors( + self, data_torch: Tensor, name: str, bid: int | None + ) -> Iterable[tuple[str, Tensor]]: + if name.startswith("model.visual."): # ignore visual part + return [] + elif name.startswith("model.language_model."): + name = name.replace("language_model.", "") # for multimodal variants + + # Handle main token embedding (but not layer-specific NextN embeddings) + if name == "model.embed_tokens.weight" and ".layers." not in name: + return [(self.map_tensor_name("token_embd.weight"), data_torch)] + + # Handle routed experts + if name.find("mlp.experts") != -1: + n_experts = self.hparams["n_routed_experts"] + assert bid is not None + + if self._experts is None: + self._experts = [{} for _ in range(self.block_count)] + + self._experts[bid][name] = data_torch + + if len(self._experts[bid]) >= n_experts * 3: + tensors: list[tuple[str, Tensor]] = [] + + # merge the experts into a single 3d tensor + for w_name in ["down_proj", "gate_proj", "up_proj"]: + datas: list[Tensor] = [] + + for xid in range(n_experts): + ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight" + datas.append(self._experts[bid][ename]) + del self._experts[bid][ename] + + data_torch = torch.stack(datas, dim=0) + + merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight" + + new_name = self.map_tensor_name(merged_name) + tensors.append((new_name, data_torch)) + return tensors + else: + return [] + + if name.endswith("e_score_correction_bias"): + name = name.replace("e_score_correction_bias", "e_score_correction.bias") + + # Handle special NextN tensors - preserve for future MTP support + if ( + ".embed_tokens." in name + or ".shared_head." in name + or ".eh_proj." in name + or ".enorm." in name + or ".hnorm." 
in name + ): + new_name = name.replace("model.layers.", "blk.").replace("model.", "").replace(".weight", "") + return [(new_name, data_torch)] + + new_name = self.map_tensor_name(name) + + return [(new_name, data_torch)] + + def prepare_tensors(self): + super().prepare_tensors() + if self._experts is not None: + # flatten `list[dict[str, Tensor]]` into `list[str]` + experts = [k for d in self._experts for k in d.keys()] + if len(experts) > 0: + raise ValueError(f"Unprocessed experts: {experts}") + + @ModelBase.register("GlmForCausalLM", "ChatGLMModel", "ChatGLMForConditionalGeneration") class ChatGLMModel(TextModel): model_arch = gguf.MODEL_ARCH.CHATGLM @@ -6594,7 +6740,7 @@ def set_vocab_chatglm3(self): vocab_size = hparams.get("padded_vocab_size", len(tokenizer.get_vocab())) assert max(tokenizer.get_vocab().values()) < vocab_size role_special_tokens = ["<|system|>", "<|user|>", "<|assistant|>", "<|observation|>"] - special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "sop", "eop"] + role_special_tokens + special_tokens = ["[MASK]", "[gMASK]", "sop", "eop"] + role_special_tokens for token_id in range(vocab_size): piece = tokenizer._convert_id_to_token(token_id) if token_id == 0: diff --git a/convert_hf_to_gguf_update.py b/convert_hf_to_gguf_update.py index abaf2ea9a1248..ea221fb1b5c4c 100755 --- a/convert_hf_to_gguf_update.py +++ b/convert_hf_to_gguf_update.py @@ -138,6 +138,7 @@ class TOKENIZER_TYPE(IntEnum): {"name": "chatglm-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-chat", "chkhsh": "b6e8e1518dc4305be2fe39c313ed643381c4da5db34a98f6a04c093f8afbe99b"}, {"name": "chatglm-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-chat", "chkhsh": "81d72c7348a9f0ebe86f23298d37debe0a5e71149e29bd283904c02262b27516"}, {"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-hf", "chkhsh": "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2"}, + {"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/zai-org/GLM-4.5-Air", "chkhsh": "9ca2dd618e8afaf09731a7cf6e2105b373ba6a1821559f258b272fe83e6eb902"}, {"name": "minerva-7b", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0", "chkhsh": "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35"}, {"name": "hunyuan", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Hunyuan-A13B-Instruct", "chkhsh": "7e57df22b1fe23a7b1e1c7f3dc4e3f96d43a4eb0836d0c6bdc3436d7b2f1c664"}, # falcon-h1 series uses 4 different tokenizers across model sizes (0.5b - 34b), hence we need to define 4 different hashes diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index c97b61d09c711..abb18db8ae59b 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -354,6 +354,7 @@ class MODEL_ARCH(IntEnum): DEEPSEEK2 = auto() CHATGLM = auto() GLM4 = auto() + GLM4_MOE = auto() BITNET = auto() T5 = auto() T5ENCODER = auto() @@ -609,6 +610,12 @@ class MODEL_TENSOR(IntEnum): A_MMPROJ_FC = auto() A_MM_NORM_PRE = auto() A_MM_NORM_MID = auto() + NEXTN_EH_PROJ = auto() # nextn tensors (glm4moe) + NEXTN_EMBED_TOKENS = auto() # nextn tensors (glm4moe) + NEXTN_ENORM = auto() # nextn tensors (glm4moe) + NEXTN_HNORM = auto() # nextn tensors (glm4moe) + NEXTN_SHARED_HEAD_HEAD = auto() # nextn tensors (glm4moe) + NEXTN_SHARED_HEAD_NORM = auto() # nextn tensors (glm4moe) MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { @@ -673,6 +680,7 @@ class MODEL_TENSOR(IntEnum): 
MODEL_ARCH.DEEPSEEK2: "deepseek2", MODEL_ARCH.CHATGLM: "chatglm", MODEL_ARCH.GLM4: "glm4", + MODEL_ARCH.GLM4_MOE: "glm4moe", MODEL_ARCH.BITNET: "bitnet", MODEL_ARCH.T5: "t5", MODEL_ARCH.T5ENCODER: "t5encoder", @@ -929,6 +937,13 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.A_MMPROJ_FC: "mm.a.fc", MODEL_TENSOR.A_MM_NORM_PRE: "mm.a.norm_pre", MODEL_TENSOR.A_MM_NORM_MID: "mm.a.norm_mid", + # NextN/MTP tensors (GLM4_MOE) + MODEL_TENSOR.NEXTN_EH_PROJ: "blk.{bid}.eh_proj", + MODEL_TENSOR.NEXTN_EMBED_TOKENS: "blk.{bid}.embed_tokens", + MODEL_TENSOR.NEXTN_ENORM: "blk.{bid}.enorm", + MODEL_TENSOR.NEXTN_HNORM: "blk.{bid}.hnorm", + MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD: "blk.{bid}.shared_head.head", + MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM: "blk.{bid}.shared_head.norm", } MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { @@ -2102,6 +2117,37 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.ATTN_POST_NORM, MODEL_TENSOR.FFN_POST_NORM, ], + MODEL_ARCH.GLM4_MOE: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_POST_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_Q_NORM, + MODEL_TENSOR.ATTN_K_NORM, + MODEL_TENSOR.FFN_GATE, # dense layers + MODEL_TENSOR.FFN_DOWN, # dense layers + MODEL_TENSOR.FFN_UP, # dense layers + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + MODEL_TENSOR.FFN_GATE_SHEXP, + MODEL_TENSOR.FFN_DOWN_SHEXP, + MODEL_TENSOR.FFN_UP_SHEXP, + MODEL_TENSOR.FFN_EXP_PROBS_B, + # NextN/MTP tensors - preserved but unused + MODEL_TENSOR.NEXTN_EH_PROJ, + MODEL_TENSOR.NEXTN_EMBED_TOKENS, + MODEL_TENSOR.NEXTN_ENORM, + MODEL_TENSOR.NEXTN_HNORM, + MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD, + MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM, + ], MODEL_ARCH.BITNET: [ MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, diff --git a/models/templates/README.md b/models/templates/README.md index 35b6386dd0649..2e8eaa5953b86 100644 --- a/models/templates/README.md +++ b/models/templates/README.md @@ -21,4 +21,5 @@ These templates can be updated with the following commands: ./scripts/get_chat_template.py Qwen/Qwen2.5-7B-Instruct > models/templates/Qwen-Qwen2.5-7B-Instruct.jinja ./scripts/get_chat_template.py Qwen/QwQ-32B > models/templates/Qwen-QwQ-32B.jinja ./scripts/get_chat_template.py Qwen/Qwen3-0.6B > models/templates/Qwen-Qwen3-0.6B.jinja -``` \ No newline at end of file +./scripts/get_chat_template.py zai-org/GLM-4.5 > models/templates/zai-org-GLM-4.5.jinja +``` diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index dbf977443ae85..a6a69839ecb63 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -62,6 +62,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_DEEPSEEK2, "deepseek2" }, { LLM_ARCH_CHATGLM, "chatglm" }, { LLM_ARCH_GLM4, "glm4" }, + { LLM_ARCH_GLM4_MOE, "glm4moe" }, { LLM_ARCH_BITNET, "bitnet" }, { LLM_ARCH_T5, "t5" }, { LLM_ARCH_T5ENCODER, "t5encoder" }, @@ -1389,6 +1390,40 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" }, }, }, + { + LLM_ARCH_GLM4_MOE, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, 
+ { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, // dense layers + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, // dense layers + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, // dense layers + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, + { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, + { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, + { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" }, + // NextN/MTP tensors - preserved but unused (in final layer, dynamic layer number) + { LLM_TENSOR_NEXTN_EH_PROJ, "blk.%d.eh_proj" }, + { LLM_TENSOR_NEXTN_EMBED_TOKENS, "blk.%d.embed_tokens" }, + { LLM_TENSOR_NEXTN_ENORM, "blk.%d.enorm" }, + { LLM_TENSOR_NEXTN_HNORM, "blk.%d.hnorm" }, + { LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "blk.%d.shared_head.head" }, + { LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "blk.%d.shared_head.norm" }, + }, + }, { LLM_ARCH_BITNET, { @@ -2142,6 +2177,14 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_SHORTCONV_CONV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_CONV}}, {LLM_TENSOR_SHORTCONV_INPROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_SHORTCONV_OUTPROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + // NextN/MTP tensors are loaded but never used (reserved for future MTP support) + // These tensors only exist in the last layer (layer 46 for GLM-4.5-Air) and are treated as output tensors + {LLM_TENSOR_NEXTN_EH_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_NONE}}, + {LLM_TENSOR_NEXTN_EMBED_TOKENS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_NONE}}, + {LLM_TENSOR_NEXTN_ENORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_NONE}}, + {LLM_TENSOR_NEXTN_HNORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_NONE}}, + {LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_NONE}}, + {LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_NONE}}, }; LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {} diff --git a/src/llama-arch.h b/src/llama-arch.h index 8267a8d3aa491..73e546673bc7b 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -66,6 +66,7 @@ enum llm_arch { LLM_ARCH_DEEPSEEK2, LLM_ARCH_CHATGLM, LLM_ARCH_GLM4, + LLM_ARCH_GLM4_MOE, LLM_ARCH_BITNET, LLM_ARCH_T5, LLM_ARCH_T5ENCODER, @@ -407,6 +408,12 @@ enum llm_tensor { LLM_TENSOR_SHORTCONV_CONV, LLM_TENSOR_SHORTCONV_INPROJ, LLM_TENSOR_SHORTCONV_OUTPROJ, + LLM_TENSOR_NEXTN_EH_PROJ, + LLM_TENSOR_NEXTN_EMBED_TOKENS, + LLM_TENSOR_NEXTN_ENORM, + LLM_TENSOR_NEXTN_HNORM, + LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, + LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, }; enum llm_tensor_layer { diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 1b9cc4aec0632..32ee267631e91 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -760,8 +760,8 @@ ggml_tensor * llm_graph_context::build_ffn( if (down) { cur = build_lora_mm(down, cur); - if (arch == LLM_ARCH_GLM4) { - // GLM4 seems to have numerical issues with half-precision accumulators + if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) { + // GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators ggml_mul_mat_set_prec(cur, GGML_PREC_F32); } } @@ -1481,8 +1481,8 @@ ggml_tensor * llm_graph_context::build_attn( if (wo) { cur = build_lora_mm(wo, cur); - if (arch == LLM_ARCH_GLM4) { - // GLM4 seems to 
have numerical issues with half-precision accumulators + if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) { + // GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators ggml_mul_mat_set_prec(cur, GGML_PREC_F32); } } diff --git a/src/llama-kv-cache-unified.cpp b/src/llama-kv-cache-unified.cpp index 321dc79fc36ab..7b9987edd03ff 100644 --- a/src/llama-kv-cache-unified.cpp +++ b/src/llama-kv-cache-unified.cpp @@ -39,6 +39,10 @@ llama_kv_cache_unified::llama_kv_cache_unified( if (model.arch == LLM_ARCH_GEMMA3N) { n_layer_cache = 20; } + if (model.arch == LLM_ARCH_GLM4_MOE) { + // GLM4_MOE: Only process first 46 transformer layers, skip NextN layer + n_layer_cache = hparams.n_layer - 1; + } // create a context for each buffer type std::map ctx_map; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index e3aa9e6f91af9..2533439c06957 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -111,6 +111,8 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_30B_A3B: return "30B.A3B"; case LLM_TYPE_235B_A22B: return "235B.A22B"; case LLM_TYPE_300B_A47B: return "300B.A47B"; + case LLM_TYPE_106B_A12B: return "106B.A12B"; + case LLM_TYPE_355B_A32B: return "355B.A32B"; case LLM_TYPE_E2B: return "E2B"; case LLM_TYPE_E4B: return "E4B"; default: return "?B"; @@ -1417,6 +1419,31 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_GLM4_MOE: + { + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + // MoE parameters + ml.get_key(LLM_KV_EXPERT_COUNT, hparams.n_expert); + ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); + + // Expert gating function (GLM4_MOE uses sigmoid) + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); + if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) { + hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID; + } + + switch (hparams.n_layer) { + case 47: type = LLM_TYPE_106B_A12B; break; // GLM-4.5-Air (46 layers + 1 NextN layer) + case 93: type = LLM_TYPE_355B_A32B; break; // GLM-4.5 (92 layers + 1 NextN layer) + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_BITNET: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -4345,6 +4372,99 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); } } break; + case LLM_ARCH_GLM4_MOE: + { + const int64_t n_expert = hparams.n_expert; + const int64_t n_expert_used = hparams.n_expert_used; + const int64_t n_expert_shared = hparams.n_expert_shared; + + GGML_ASSERT(hparams.n_expert > 0 && "n_expert must be > 0 for GLM4_MOE MoE layers"); + GGML_ASSERT(hparams.n_expert_used > 0 && "n_expert_used must be > 0 for GLM4_MOE MoE layers"); + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + // if output is NULL, init 
from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED); + } + + // NextN/MTP tensors (preserved but unused) - only in final layer (46 for Air, 92 for GLM-4.5) + const int final_layer = n_layer - 1; // NextN tensors are in the last layer only + create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, final_layer), { 2 * n_embd, n_embd }, TENSOR_NOT_REQUIRED); + create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, final_layer), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, final_layer), { n_embd }, TENSOR_NOT_REQUIRED); + create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, final_layer), { n_embd }, TENSOR_NOT_REQUIRED); + create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, final_layer), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, final_layer), { n_embd }, TENSOR_NOT_REQUIRED); + + // Load ALL tensors including NextN layer to satisfy tensor count (803) + // but only PROCESS first 46 transformer layers in forward pass + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); + + // GLM-style attention with bias terms + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head }, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, 0); + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), { n_embd_head_k * n_head }, 0); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), { n_embd_k_gqa }, 0); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), { n_embd_v_gqa }, 0); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); + + // K/Q norm tensors (optional for GLM-4.5 355B variant) + layer.attn_q_norm = create_tensor( + tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, TENSOR_NOT_REQUIRED); + layer.attn_k_norm = create_tensor( + tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, TENSOR_NOT_REQUIRED); + + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); + + // Check if this layer uses MoE or dense FFN based on n_layer_dense_lead + // GLM 4.5 uses hybrid architecture: layer 0 is dense, layers 1+ are MoE + const bool use_moe = (static_cast(i) >= hparams.n_layer_dense_lead); + + if (use_moe) { + // MoE layers + layer.ffn_gate_inp = + create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, 0); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), { n_expert }, 0); + + // MoE branch + const int64_t n_ff_exp = hparams.n_ff_exp ? 
hparams.n_ff_exp : n_ff / n_expert_used; + + layer.ffn_gate_exps = create_tensor( + tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0); + layer.ffn_down_exps = create_tensor( + tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, 0); + layer.ffn_up_exps = create_tensor( + tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0); + + // Shared expert + if (n_expert_shared > 0) { + const int64_t n_ff_shexp = n_ff_exp * n_expert_shared; + layer.ffn_gate_shexp = create_tensor( + tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0); + layer.ffn_down_shexp = create_tensor( + tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_shexp, n_embd }, 0); + layer.ffn_up_shexp = + create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0); + } + } else { + // Dense layers (first k layers) - GLM uses separate gate/up projections + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), { n_embd, n_ff }, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, 0); + } + + } + } + break; case LLM_ARCH_NEMOTRON: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -13349,6 +13469,169 @@ struct llm_build_glm4 : public llm_graph_context { } }; +struct llm_build_glm4_moe : public llm_graph_context { + llm_build_glm4_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + // Only process first 46 transformer layers (skip NextN layer 46) + // Layer 46 tensors are loaded but not processed in forward pass + const int n_transformer_layers = n_layer - 1; + for (int il = 0; il < n_transformer_layers; ++il) { + ggml_tensor * inpSA = inpL; + + // Pre-attention norm + cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + } + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + } + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + } + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + // Apply Q/K norm if available (GLM-4.5 355B variant) + if (model.layers[il].attn_q_norm) { + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + } + if (model.layers[il].attn_k_norm) { + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); + cb(Kcur, "Kcur_normed", il); + } + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, 
nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_transformer_layers - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // Post-attention norm + cur = build_norm(ffn_inp, model.layers[il].attn_post_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "post_attn_norm", il); + + // Check if this is a dense layer (n_layer_dense_lead=1, so layer 0 is dense) + if (static_cast(il) < hparams.n_layer_dense_lead) { + // Dense FFN layer + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } else { + // MoE layer with shared experts + const int64_t n_expert = hparams.n_expert; + const int64_t n_expert_used = hparams.n_expert_used; + + // Process routed experts using existing MoE infrastructure + ggml_tensor * routed_out = build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + model.layers[il].ffn_exp_probs_b, + n_expert, n_expert_used, + LLM_FFN_SILU, hparams.expert_weights_norm, + true, hparams.expert_weights_scale, + (llama_expert_gating_func_type) hparams.expert_gating_func, + il); + cb(routed_out, "ffn_moe_out", il); + + // Process shared expert on original input + ggml_tensor * shared_out = build_ffn(cur, + model.layers[il].ffn_up_shexp, NULL, NULL, + model.layers[il].ffn_gate_shexp, NULL, NULL, + model.layers[il].ffn_down_shexp, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(shared_out, "ffn_shexp_out", il); + + // Final output: routed_output + shared_output + cur = ggml_add(ctx0, routed_out, shared_out); + cb(cur, "ffn_out", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + struct llm_build_nemotron : public llm_graph_context { llm_build_nemotron(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; @@ -17509,6 +17792,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique(*this, params); } break; + case LLM_ARCH_GLM4_MOE: + { + llm = std::make_unique(*this, params); + } break; case LLM_ARCH_BITNET: { llm = std::make_unique(*this, params); @@ -17833,6 +18120,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_HUNYUAN_MOE: case LLM_ARCH_LFM2: case LLM_ARCH_SMALLTHINKER: + case LLM_ARCH_GLM4_MOE: return LLAMA_ROPE_TYPE_NEOX; case LLM_ARCH_QWEN2VL: diff --git 
a/src/llama-model.h b/src/llama-model.h index 094e23808a813..5e71247e37cec 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -103,6 +103,8 @@ enum llm_type { LLM_TYPE_30B_A3B, LLM_TYPE_235B_A22B, LLM_TYPE_300B_A47B, // Ernie MoE big + LLM_TYPE_106B_A12B, // GLM-4.5-Air (106B total, 12B active) + LLM_TYPE_355B_A32B, // GLM-4.5 (355B total, 32B active) LLM_TYPE_E2B, LLM_TYPE_E4B, }; From 22619348773b418ed3759bde91c6ad6038176a72 Mon Sep 17 00:00:00 2001 From: Sam Date: Mon, 4 Aug 2025 06:55:40 +1000 Subject: [PATCH 02/13] Merge in PR suggestions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Sigbjørn Skjæret --- convert_hf_to_gguf.py | 15 +++------------ gguf-py/gguf/constants.py | 21 +++++++++++---------- src/llama-arch.cpp | 8 ++++---- src/llama-kv-cache-unified.cpp | 2 +- src/llama-model.cpp | 22 +++++++++++----------- src/llama-model.h | 4 ++-- 6 files changed, 32 insertions(+), 40 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 847a7c4cb6c3e..a20fddc5dbb12 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -6603,19 +6603,12 @@ def set_vocab(self): self.gguf_writer.add_token_types(toktypes) # Special tokens - # Note: Using <|endoftext|> (151329) for eos and eot causes endless generation + # Note: Using <|endoftext|> (151329) for eot causes endless generation special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["[gMASK]"]) # 151331 - special_vocab._set_special_token("eos", tokenizer.get_added_vocab()["<|user|>"]) # 151336 - end of - special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"]) # 151336 - same as EOS - special_vocab._set_special_token("eog", tokenizer.get_added_vocab()["<|user|>"]) # 151336 - same as EOS + special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"]) # 151336 special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"]) # 151329 special_vocab._set_special_token("eom", tokenizer.get_added_vocab()["<|observation|>"]) # 151338 - if "" in tokenizer.get_added_vocab(): - special_vocab._set_special_token("sop", tokenizer.get_added_vocab()[""]) # 151333 - if "" in tokenizer.get_added_vocab(): - special_vocab._set_special_token("eop", tokenizer.get_added_vocab()[""]) # 151334 - special_vocab.add_to_gguf(self.gguf_writer) def set_gguf_parameters(self): @@ -6631,8 +6624,6 @@ def set_gguf_parameters(self): # MoE parameters - Use only routed expert count (shared experts handled separately) if (n_routed_experts := self.hparams.get("n_routed_experts")) is not None: self.gguf_writer.add_expert_count(n_routed_experts) - if (num_experts_per_tok := self.hparams.get("num_experts_per_tok")) is not None: - self.gguf_writer.add_expert_used_count(num_experts_per_tok) if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None: self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size) if (n_shared_experts := self.hparams.get("n_shared_experts")) is not None: @@ -6740,7 +6731,7 @@ def set_vocab_chatglm3(self): vocab_size = hparams.get("padded_vocab_size", len(tokenizer.get_vocab())) assert max(tokenizer.get_vocab().values()) < vocab_size role_special_tokens = ["<|system|>", "<|user|>", "<|assistant|>", "<|observation|>"] - special_tokens = ["[MASK]", "[gMASK]", "sop", "eop"] + role_special_tokens + special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "sop", "eop"] + role_special_tokens for token_id in range(vocab_size): piece = 
tokenizer._convert_id_to_token(token_id) if token_id == 0: diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index abb18db8ae59b..d50f4b5bceb53 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -610,12 +610,13 @@ class MODEL_TENSOR(IntEnum): A_MMPROJ_FC = auto() A_MM_NORM_PRE = auto() A_MM_NORM_MID = auto() - NEXTN_EH_PROJ = auto() # nextn tensors (glm4moe) - NEXTN_EMBED_TOKENS = auto() # nextn tensors (glm4moe) - NEXTN_ENORM = auto() # nextn tensors (glm4moe) - NEXTN_HNORM = auto() # nextn tensors (glm4moe) - NEXTN_SHARED_HEAD_HEAD = auto() # nextn tensors (glm4moe) - NEXTN_SHARED_HEAD_NORM = auto() # nextn tensors (glm4moe) + # nextn/mtp + NEXTN_EH_PROJ = auto() + NEXTN_EMBED_TOKENS = auto() + NEXTN_ENORM = auto() + NEXTN_HNORM = auto() + NEXTN_SHARED_HEAD_HEAD = auto() + NEXTN_SHARED_HEAD_NORM = auto() MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { @@ -937,7 +938,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.A_MMPROJ_FC: "mm.a.fc", MODEL_TENSOR.A_MM_NORM_PRE: "mm.a.norm_pre", MODEL_TENSOR.A_MM_NORM_MID: "mm.a.norm_mid", - # NextN/MTP tensors (GLM4_MOE) + # NextN/MTP MODEL_TENSOR.NEXTN_EH_PROJ: "blk.{bid}.eh_proj", MODEL_TENSOR.NEXTN_EMBED_TOKENS: "blk.{bid}.embed_tokens", MODEL_TENSOR.NEXTN_ENORM: "blk.{bid}.enorm", @@ -2129,9 +2130,9 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.ATTN_OUT, MODEL_TENSOR.ATTN_Q_NORM, MODEL_TENSOR.ATTN_K_NORM, - MODEL_TENSOR.FFN_GATE, # dense layers - MODEL_TENSOR.FFN_DOWN, # dense layers - MODEL_TENSOR.FFN_UP, # dense layers + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, MODEL_TENSOR.FFN_GATE_INP, MODEL_TENSOR.FFN_GATE_EXP, MODEL_TENSOR.FFN_DOWN_EXP, diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index a6a69839ecb63..e7b6e60bdbda5 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -1404,9 +1404,9 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, - { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, // dense layers - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, // dense layers - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, // dense layers + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, @@ -2178,7 +2178,7 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_SHORTCONV_INPROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_SHORTCONV_OUTPROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, // NextN/MTP tensors are loaded but never used (reserved for future MTP support) - // These tensors only exist in the last layer (layer 46 for GLM-4.5-Air) and are treated as output tensors + // These tensors only exist in the last layer and are treated as output tensors {LLM_TENSOR_NEXTN_EH_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_NONE}}, {LLM_TENSOR_NEXTN_EMBED_TOKENS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_NONE}}, {LLM_TENSOR_NEXTN_ENORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_NONE}}, diff --git a/src/llama-kv-cache-unified.cpp b/src/llama-kv-cache-unified.cpp index 7b9987edd03ff..1725f28fcf205 100644 --- a/src/llama-kv-cache-unified.cpp +++ b/src/llama-kv-cache-unified.cpp @@ -40,7 +40,7 @@ llama_kv_cache_unified::llama_kv_cache_unified( n_layer_cache = 20; } if (model.arch == LLM_ARCH_GLM4_MOE) { - // GLM4_MOE: Only process 
first 46 transformer layers, skip NextN layer + // GLM-4.5: Only process up to last layer, skip final NextN layer n_layer_cache = hparams.n_layer - 1; } diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 2533439c06957..6014d0dc87cd7 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -109,9 +109,9 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_A13B: return "A13B"; case LLM_TYPE_21B_A3B: return "21B.A3B"; case LLM_TYPE_30B_A3B: return "30B.A3B"; + case LLM_TYPE_106B_A12B: return "106B.A12B"; case LLM_TYPE_235B_A22B: return "235B.A22B"; case LLM_TYPE_300B_A47B: return "300B.A47B"; - case LLM_TYPE_106B_A12B: return "106B.A12B"; case LLM_TYPE_355B_A32B: return "355B.A32B"; case LLM_TYPE_E2B: return "E2B"; case LLM_TYPE_E4B: return "E4B"; @@ -1432,7 +1432,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); - // Expert gating function (GLM4_MOE uses sigmoid) + // Expert gating function (GLM-4.5 uses sigmoid) ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) { hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID; @@ -4400,8 +4400,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, final_layer), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, final_layer), { n_embd }, TENSOR_NOT_REQUIRED); - // Load ALL tensors including NextN layer to satisfy tensor count (803) - // but only PROCESS first 46 transformer layers in forward pass + // Load ALL tensors including NextN layer to satisfy total tensor count + // but only PROCESS up to last layer (skipping final NextN layer) in forward pass for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; @@ -4423,7 +4423,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_k_norm = create_tensor( tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, TENSOR_NOT_REQUIRED); - layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0); // Check if this layer uses MoE or dense FFN based on n_layer_dense_lead // GLM 4.5 uses hybrid architecture: layer 0 is dense, layers 1+ are MoE @@ -4448,18 +4448,18 @@ bool llama_model::load_tensors(llama_model_loader & ml) { // Shared expert if (n_expert_shared > 0) { const int64_t n_ff_shexp = n_ff_exp * n_expert_shared; - layer.ffn_gate_shexp = create_tensor( + layer.ffn_gate_shexp = create_tensor( tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0); layer.ffn_down_shexp = create_tensor( tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_shexp, n_embd }, 0); - layer.ffn_up_shexp = - create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0); + layer.ffn_up_shexp = create_tensor( + tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0); } } else { // Dense layers (first k layers) - GLM uses separate gate/up projections layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), { n_embd, n_ff }, 0); layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, 0); + layer.ffn_up = 
create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, 0); } } @@ -13487,8 +13487,8 @@ struct llm_build_glm4_moe : public llm_graph_context { ggml_tensor * inp_out_ids = build_inp_out_ids(); - // Only process first 46 transformer layers (skip NextN layer 46) - // Layer 46 tensors are loaded but not processed in forward pass + // Only process up to last layer (skip final NextN layer) + // Final layer tensors are loaded but not processed in forward pass const int n_transformer_layers = n_layer - 1; for (int il = 0; il < n_transformer_layers; ++il) { ggml_tensor * inpSA = inpL; diff --git a/src/llama-model.h b/src/llama-model.h index 5e71247e37cec..f49a4f968bf65 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -101,10 +101,10 @@ enum llm_type { LLM_TYPE_A13B, LLM_TYPE_21B_A3B, // Ernie MoE small LLM_TYPE_30B_A3B, + LLM_TYPE_106B_A12B, // GLM-4.5-Air LLM_TYPE_235B_A22B, LLM_TYPE_300B_A47B, // Ernie MoE big - LLM_TYPE_106B_A12B, // GLM-4.5-Air (106B total, 12B active) - LLM_TYPE_355B_A32B, // GLM-4.5 (355B total, 32B active) + LLM_TYPE_355B_A32B, // GLM-4.5 LLM_TYPE_E2B, LLM_TYPE_E4B, }; From 15698b03db71ba848a78b0a99e9e266a6216950f Mon Sep 17 00:00:00 2001 From: Sam Date: Mon, 4 Aug 2025 08:31:39 +1000 Subject: [PATCH 03/13] model: Add GLM 4.5 family of models (#14921) 1. Updated tensor_mapping.py with NextN tensor mappings - Added proper tensor mappings for all NextN/MTP tensors in /Users/samm/git/llama.cpp/gguf-py/gguf/tensor_mapping.py - Added mappings for: eh_proj, embed_tokens, enorm, hnorm, shared_head.head, shared_head.norm 2. Added num_nextn_predict_layers configuration - Added LLM_KV_NUM_NEXTN_PREDICT_LAYERS constant to llama-arch.h and llama-arch.cpp - Added num_nextn_predict_layers field to llama_hparams struct - Updated GLM4_MOE parameter loading in llama-model.cpp to read this parameter - Modified tensor loading logic to conditionally load NextN tensors based on num_nextn_predict_layers - Added GGUF writer support in gguf_writer.py with add_num_nextn_predict_layers() method - Updated conversion script to extract and write this parameter from HuggingFace config 3. Added FIM tokens for GLM4_MOE - Added GLM-4.5's FIM tokens to llama-vocab.cpp: - <|code_prefix|> for FIM_PRE - <|code_suffix|> for FIM_SUF - <|code_middle|> for FIM_MID 4. 
Removed manual NextN tensor handling - Removed the special-case handling in convert_hf_to_gguf.py that manually mapped NextN tensors - NextN tensors are now handled automatically through the proper tensor mapping system --- convert_hf_to_gguf.py | 15 ++++----------- gguf-py/gguf/constants.py | 1 + gguf-py/gguf/gguf_writer.py | 3 +++ gguf-py/gguf/tensor_mapping.py | 25 +++++++++++++++++++++++++ src/llama-arch.cpp | 1 + src/llama-arch.h | 1 + src/llama-hparams.h | 3 ++- src/llama-model.cpp | 21 +++++++++++++-------- src/llama-vocab.cpp | 3 +++ 9 files changed, 53 insertions(+), 20 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index a20fddc5dbb12..bc99dac9e7057 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -6642,6 +6642,10 @@ def set_gguf_parameters(self): if (norm_topk_prob := self.hparams.get("norm_topk_prob")) is not None: self.gguf_writer.add_expert_weights_norm(norm_topk_prob) + # NextN/MTP prediction layers + if (num_nextn_predict_layers := self.hparams.get("num_nextn_predict_layers")) is not None: + self.gguf_writer.add_num_nextn_predict_layers(num_nextn_predict_layers) + _experts: list[dict[str, Tensor]] | None = None def modify_tensors( @@ -6691,17 +6695,6 @@ def modify_tensors( if name.endswith("e_score_correction_bias"): name = name.replace("e_score_correction_bias", "e_score_correction.bias") - # Handle special NextN tensors - preserve for future MTP support - if ( - ".embed_tokens." in name - or ".shared_head." in name - or ".eh_proj." in name - or ".enorm." in name - or ".hnorm." in name - ): - new_name = name.replace("model.layers.", "blk.").replace("model.", "").replace(".weight", "") - return [(new_name, data_torch)] - new_name = self.map_tensor_name(name) return [(new_name, data_torch)] diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index d50f4b5bceb53..6bd9c05699c43 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -105,6 +105,7 @@ class LLM: EXPERT_WEIGHTS_NORM = "{arch}.expert_weights_norm" EXPERT_GATING_FUNC = "{arch}.expert_gating_func" MOE_EVERY_N_LAYERS = "{arch}.moe_every_n_layers" + NUM_NEXTN_PREDICT_LAYERS = "{arch}.num_nextn_predict_layers" POOLING_TYPE = "{arch}.pooling_type" LOGIT_SCALE = "{arch}.logit_scale" DECODER_START_TOKEN_ID = "{arch}.decoder_start_token_id" diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 4f23f9b024619..c20d2f9635472 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -753,6 +753,9 @@ def add_expert_gating_func(self, value: ExpertGatingFuncType) -> None: def add_moe_every_n_layers(self, value: int) -> None: self.add_uint32(Keys.LLM.MOE_EVERY_N_LAYERS.format(arch=self.arch), value) + def add_num_nextn_predict_layers(self, count: int) -> None: + self.add_uint32(Keys.LLM.NUM_NEXTN_PREDICT_LAYERS.format(arch=self.arch), count) + def add_swin_norm(self, value: bool) -> None: self.add_bool(Keys.LLM.SWIN_NORM.format(arch=self.arch), value) diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index bfd4fd37a3f68..964a4d7e1b72f 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -1348,6 +1348,31 @@ class TensorNameMap: MODEL_TENSOR.A_MM_NORM_MID: ( "audio.multi_modal_projector.ln_mid", # ultravox ), + + # NextN/MTP tensors for GLM4_MOE + MODEL_TENSOR.NEXTN_EH_PROJ: ( + "model.layers.{bid}.eh_proj", + ), + + MODEL_TENSOR.NEXTN_EMBED_TOKENS: ( + "model.layers.{bid}.embed_tokens", + ), + + MODEL_TENSOR.NEXTN_ENORM: ( + "model.layers.{bid}.enorm", + ), + 
+ MODEL_TENSOR.NEXTN_HNORM: ( + "model.layers.{bid}.hnorm", + ), + + MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD: ( + "model.layers.{bid}.shared_head.head", + ), + + MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM: ( + "model.layers.{bid}.shared_head.norm", + ), } # architecture-specific block mappings diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index e7b6e60bdbda5..9989946c08574 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -126,6 +126,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_EXPERT_WEIGHTS_NORM, "%s.expert_weights_norm" }, { LLM_KV_EXPERT_GATING_FUNC, "%s.expert_gating_func" }, { LLM_KV_MOE_EVERY_N_LAYERS, "%s.moe_every_n_layers" }, + { LLM_KV_NUM_NEXTN_PREDICT_LAYERS, "%s.num_nextn_predict_layers" }, { LLM_KV_POOLING_TYPE, "%s.pooling_type" }, { LLM_KV_LOGIT_SCALE, "%s.logit_scale" }, { LLM_KV_DECODER_START_TOKEN_ID, "%s.decoder_start_token_id" }, diff --git a/src/llama-arch.h b/src/llama-arch.h index 73e546673bc7b..42a20cba45124 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -130,6 +130,7 @@ enum llm_kv { LLM_KV_EXPERT_WEIGHTS_NORM, LLM_KV_EXPERT_GATING_FUNC, LLM_KV_MOE_EVERY_N_LAYERS, + LLM_KV_NUM_NEXTN_PREDICT_LAYERS, LLM_KV_POOLING_TYPE, LLM_KV_LOGIT_SCALE, LLM_KV_DECODER_START_TOKEN_ID, diff --git a/src/llama-hparams.h b/src/llama-hparams.h index 8b7e2a1130755..718e347c40182 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -72,7 +72,8 @@ struct llama_hparams { float expert_weights_scale = 0.0; bool expert_weights_norm = false; uint32_t expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_NONE; - uint32_t moe_every_n_layers = 0; + uint32_t moe_every_n_layers = 0; + uint32_t num_nextn_predict_layers = 0; float f_norm_eps; float f_norm_rms_eps; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 6014d0dc87cd7..8f89a23512788 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1438,6 +1438,9 @@ void llama_model::load_hparams(llama_model_loader & ml) { hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID; } + // NextN/MTP parameters + ml.get_key(LLM_KV_NUM_NEXTN_PREDICT_LAYERS, hparams.num_nextn_predict_layers, false); + switch (hparams.n_layer) { case 47: type = LLM_TYPE_106B_A12B; break; // GLM-4.5-Air (46 layers + 1 NextN layer) case 93: type = LLM_TYPE_355B_A32B; break; // GLM-4.5 (92 layers + 1 NextN layer) @@ -4391,14 +4394,16 @@ bool llama_model::load_tensors(llama_model_loader & ml) { output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED); } - // NextN/MTP tensors (preserved but unused) - only in final layer (46 for Air, 92 for GLM-4.5) - const int final_layer = n_layer - 1; // NextN tensors are in the last layer only - create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, final_layer), { 2 * n_embd, n_embd }, TENSOR_NOT_REQUIRED); - create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, final_layer), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); - create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, final_layer), { n_embd }, TENSOR_NOT_REQUIRED); - create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, final_layer), { n_embd }, TENSOR_NOT_REQUIRED); - create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, final_layer), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); - create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, final_layer), { n_embd }, TENSOR_NOT_REQUIRED); + // NextN/MTP tensors (preserved but unused) - conditionally load based on num_nextn_predict_layers + const uint32_t n_nextn_layers = hparams.num_nextn_predict_layers > 0 ? 
hparams.num_nextn_predict_layers : 1; + for (uint32_t i = n_layer - n_nextn_layers; i < n_layer; ++i) { + create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, i), { 2 * n_embd, n_embd }, TENSOR_NOT_REQUIRED); + create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, i), { n_embd }, TENSOR_NOT_REQUIRED); + create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, i), { n_embd }, TENSOR_NOT_REQUIRED); + create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, i), { n_embd }, TENSOR_NOT_REQUIRED); + } // Load ALL tensors including NextN layer to satisfy total tensor count // but only PROCESS up to last layer (skipping final NextN layer) in forward pass diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index e8bae645088dd..12faa78bd66f2 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -2185,6 +2185,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { || t.first == "<|fim▁begin|>" // DeepSeek || t.first == "
"
                         || t.first == "▁
"          // CodeLlama
+                        || t.first == "<|code_prefix|>" // GLM4_MOE
                         ) {
                     special_fim_pre_id = t.second;
                     if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
@@ -2204,6 +2205,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                         || t.first == "<|fim▁hole|>" // DeepSeek
                         || t.first == ""
                         || t.first == "▁"         // CodeLlama
+                        || t.first == "<|code_suffix|>" // GLM4_MOE
                         ) {
                     special_fim_suf_id = t.second;
                     if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
@@ -2223,6 +2225,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                         || t.first == "<|fim▁end|>"  // DeepSeek
                         || t.first == ""
                         || t.first == "▁"         // CodeLlama
+                        || t.first == "<|code_middle|>" // GLM4_MOE
                         ) {
                     special_fim_mid_id = t.second;
                     if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {

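As a quick illustration of the NextN bookkeeping that PATCH 03 introduces (a sketch only, not part of the patch series): `num_nextn_predict_layers` is read from the HuggingFace config and written to the GGUF, and the loader then treats the last `num_nextn_predict_layers` blocks as NextN-only blocks whose tensors are loaded but skipped in the forward pass. In Python terms, assuming GLM-4.5-Air's 47 blocks:

    # Sketch of the block split implied by num_nextn_predict_layers (assumption:
    # NextN blocks are the trailing blocks, mirroring llama-model.cpp after PATCH 03,
    # which falls back to 1 NextN block when the key is absent).
    def split_blocks(n_layer: int, num_nextn_predict_layers: int) -> tuple[range, range]:
        n_nextn = num_nextn_predict_layers if num_nextn_predict_layers > 0 else 1
        return range(0, n_layer - n_nextn), range(n_layer - n_nextn, n_layer)

    # GLM-4.5-Air: 47 blocks = 46 transformer layers + 1 NextN layer
    transformer_blocks, nextn_blocks = split_blocks(47, 1)
    assert list(nextn_blocks) == [46]   # only blk.46 carries eh_proj/enorm/hnorm/...
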
From c429c1ab1787a52f80c60de49e01877de5df001c Mon Sep 17 00:00:00 2001
From: Sam 
Date: Mon, 4 Aug 2025 11:28:51 +1000
Subject: [PATCH 04/13] glm 4.5 update tensor names

---
 gguf-py/gguf/tensor_mapping.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py
index 964a4d7e1b72f..7a4c4d7bac636 100644
--- a/gguf-py/gguf/tensor_mapping.py
+++ b/gguf-py/gguf/tensor_mapping.py
@@ -1351,27 +1351,27 @@ class TensorNameMap:
 
         # NextN/MTP tensors for GLM4_MOE
         MODEL_TENSOR.NEXTN_EH_PROJ: (
-            "model.layers.{bid}.eh_proj",
+            "model.layers.{bid}.eh_proj.weight",
         ),
 
         MODEL_TENSOR.NEXTN_EMBED_TOKENS: (
-            "model.layers.{bid}.embed_tokens",
+            "model.layers.{bid}.embed_tokens.weight",
         ),
 
         MODEL_TENSOR.NEXTN_ENORM: (
-            "model.layers.{bid}.enorm",
+            "model.layers.{bid}.enorm.weight",
         ),
 
         MODEL_TENSOR.NEXTN_HNORM: (
-            "model.layers.{bid}.hnorm",
+            "model.layers.{bid}.hnorm.weight",
         ),
 
         MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD: (
-            "model.layers.{bid}.shared_head.head",
+            "model.layers.{bid}.shared_head.head.weight",
         ),
 
         MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM: (
-            "model.layers.{bid}.shared_head.norm",
+            "model.layers.{bid}.shared_head.norm.weight",
         ),
     }
 

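For context on the back-and-forth over the `.weight` suffixes in PATCH 04 and PATCH 05: the converter's name mapping operates on base tensor names and re-appends the `.weight`/`.bias` suffix after the lookup, which is presumably why PATCH 05 below drops the suffixes again. A minimal, self-contained sketch of that resolution using the GGUF-side NextN names introduced in PATCH 05 (simplified, not the actual gguf-py TensorNameMap code):

    # Hypothetical, simplified mapping table; the real names come from gguf-py's constants.
    NEXTN_MAP = {
        "model.layers.{bid}.eh_proj":          "blk.{bid}.nextn.eh_proj",
        "model.layers.{bid}.embed_tokens":     "blk.{bid}.nextn.embed_tokens",
        "model.layers.{bid}.enorm":            "blk.{bid}.nextn.enorm",
        "model.layers.{bid}.hnorm":            "blk.{bid}.nextn.hnorm",
        "model.layers.{bid}.shared_head.head": "blk.{bid}.nextn.shared_head_head",
        "model.layers.{bid}.shared_head.norm": "blk.{bid}.nextn.shared_head_norm",
    }

    def map_nextn_name(hf_name: str, bid: int, suffixes=(".weight", ".bias")) -> str | None:
        # Strip a known suffix, map the base name, then re-append the suffix.
        for suffix in suffixes:
            if hf_name.endswith(suffix):
                base = hf_name[: -len(suffix)]
                for src, dst in NEXTN_MAP.items():
                    if base == src.format(bid=bid):
                        return dst.format(bid=bid) + suffix
        return None

    assert map_nextn_name("model.layers.46.eh_proj.weight", 46) == "blk.46.nextn.eh_proj.weight"
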
From 7c8fc019229c1a366224ff9a5913066b32b11712 Mon Sep 17 00:00:00 2001
From: Sam 
Date: Mon, 4 Aug 2025 17:37:35 +1000
Subject: [PATCH 05/13] model: glm 4.5 apply suggestions from code review
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Sigbjørn Skjæret 
---
 convert_hf_to_gguf.py          |  2 +-
 gguf-py/gguf/constants.py      | 14 +++++++-------
 gguf-py/gguf/gguf_writer.py    |  4 ++--
 gguf-py/gguf/tensor_mapping.py | 12 ++++++------
 src/llama-arch.cpp             | 14 +++++++-------
 src/llama-arch.h               |  2 +-
 src/llama-hparams.h            |  4 ++--
 src/llama-kv-cache-unified.cpp |  2 +-
 8 files changed, 27 insertions(+), 27 deletions(-)

diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index bc99dac9e7057..c70197575380d 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -6644,7 +6644,7 @@ def set_gguf_parameters(self):
 
         # NextN/MTP prediction layers
         if (num_nextn_predict_layers := self.hparams.get("num_nextn_predict_layers")) is not None:
-            self.gguf_writer.add_num_nextn_predict_layers(num_nextn_predict_layers)
+            self.gguf_writer.add_nextn_predict_layers(num_nextn_predict_layers)
 
     _experts: list[dict[str, Tensor]] | None = None
 
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index 6bd9c05699c43..290f4c4581558 100644
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -105,7 +105,7 @@ class LLM:
         EXPERT_WEIGHTS_NORM               = "{arch}.expert_weights_norm"
         EXPERT_GATING_FUNC                = "{arch}.expert_gating_func"
         MOE_EVERY_N_LAYERS                = "{arch}.moe_every_n_layers"
-        NUM_NEXTN_PREDICT_LAYERS          = "{arch}.num_nextn_predict_layers"
+        NEXTN_PREDICT_LAYERS              = "{arch}.num_nextn_predict_layers"
         POOLING_TYPE                      = "{arch}.pooling_type"
         LOGIT_SCALE                       = "{arch}.logit_scale"
         DECODER_START_TOKEN_ID            = "{arch}.decoder_start_token_id"
@@ -940,12 +940,12 @@ class MODEL_TENSOR(IntEnum):
     MODEL_TENSOR.A_MM_NORM_PRE:             "mm.a.norm_pre",
     MODEL_TENSOR.A_MM_NORM_MID:             "mm.a.norm_mid",
     # NextN/MTP
-    MODEL_TENSOR.NEXTN_EH_PROJ:             "blk.{bid}.eh_proj",
-    MODEL_TENSOR.NEXTN_EMBED_TOKENS:        "blk.{bid}.embed_tokens",
-    MODEL_TENSOR.NEXTN_ENORM:               "blk.{bid}.enorm",
-    MODEL_TENSOR.NEXTN_HNORM:               "blk.{bid}.hnorm",
-    MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD:    "blk.{bid}.shared_head.head",
-    MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM:    "blk.{bid}.shared_head.norm",
+    MODEL_TENSOR.NEXTN_EH_PROJ:             "blk.{bid}.nextn.eh_proj",
+    MODEL_TENSOR.NEXTN_EMBED_TOKENS:        "blk.{bid}.nextn.embed_tokens",
+    MODEL_TENSOR.NEXTN_ENORM:               "blk.{bid}.nextn.enorm",
+    MODEL_TENSOR.NEXTN_HNORM:               "blk.{bid}.nextn.hnorm",
+    MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD:    "blk.{bid}.nextn.shared_head_head",
+    MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM:    "blk.{bid}.nextn.shared_head_norm",
 }
 
 MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py
index c20d2f9635472..49223a2c011aa 100644
--- a/gguf-py/gguf/gguf_writer.py
+++ b/gguf-py/gguf/gguf_writer.py
@@ -753,8 +753,8 @@ def add_expert_gating_func(self, value: ExpertGatingFuncType) -> None:
     def add_moe_every_n_layers(self, value: int) -> None:
         self.add_uint32(Keys.LLM.MOE_EVERY_N_LAYERS.format(arch=self.arch), value)
 
-    def add_num_nextn_predict_layers(self, count: int) -> None:
-        self.add_uint32(Keys.LLM.NUM_NEXTN_PREDICT_LAYERS.format(arch=self.arch), count)
+    def add_nextn_predict_layers(self, count: int) -> None:
+        self.add_uint32(Keys.LLM.NEXTN_PREDICT_LAYERS.format(arch=self.arch), count)
 
     def add_swin_norm(self, value: bool) -> None:
         self.add_bool(Keys.LLM.SWIN_NORM.format(arch=self.arch), value)
diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py
index 7a4c4d7bac636..964a4d7e1b72f 100644
--- a/gguf-py/gguf/tensor_mapping.py
+++ b/gguf-py/gguf/tensor_mapping.py
@@ -1351,27 +1351,27 @@ class TensorNameMap:
 
         # NextN/MTP tensors for GLM4_MOE
         MODEL_TENSOR.NEXTN_EH_PROJ: (
-            "model.layers.{bid}.eh_proj.weight",
+            "model.layers.{bid}.eh_proj",
         ),
 
         MODEL_TENSOR.NEXTN_EMBED_TOKENS: (
-            "model.layers.{bid}.embed_tokens.weight",
+            "model.layers.{bid}.embed_tokens",
         ),
 
         MODEL_TENSOR.NEXTN_ENORM: (
-            "model.layers.{bid}.enorm.weight",
+            "model.layers.{bid}.enorm",
         ),
 
         MODEL_TENSOR.NEXTN_HNORM: (
-            "model.layers.{bid}.hnorm.weight",
+            "model.layers.{bid}.hnorm",
         ),
 
         MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD: (
-            "model.layers.{bid}.shared_head.head.weight",
+            "model.layers.{bid}.shared_head.head",
         ),
 
         MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM: (
-            "model.layers.{bid}.shared_head.norm.weight",
+            "model.layers.{bid}.shared_head.norm",
         ),
     }
 
diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp
index 9989946c08574..d7eee1883c1ff 100644
--- a/src/llama-arch.cpp
+++ b/src/llama-arch.cpp
@@ -126,7 +126,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_EXPERT_WEIGHTS_NORM,               "%s.expert_weights_norm"               },
     { LLM_KV_EXPERT_GATING_FUNC,                "%s.expert_gating_func"                },
     { LLM_KV_MOE_EVERY_N_LAYERS,                "%s.moe_every_n_layers"                },
-    { LLM_KV_NUM_NEXTN_PREDICT_LAYERS,           "%s.num_nextn_predict_layers"           },
+    { LLM_KV_NEXTN_PREDICT_LAYERS,              "%s.nextn_predict_layers"               },
     { LLM_KV_POOLING_TYPE,                      "%s.pooling_type"                      },
     { LLM_KV_LOGIT_SCALE,                       "%s.logit_scale"                       },
     { LLM_KV_DECODER_START_TOKEN_ID,            "%s.decoder_start_token_id"            },
@@ -1417,12 +1417,12 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_NAMES = {
             { LLM_TENSOR_FFN_UP_SHEXP,       "blk.%d.ffn_up_shexp" },
             { LLM_TENSOR_FFN_EXP_PROBS_B,    "blk.%d.exp_probs_b" },
             // NextN/MTP tensors - preserved but unused (in final layer, dynamic layer number)
-            { LLM_TENSOR_NEXTN_EH_PROJ,      "blk.%d.eh_proj" },
-            { LLM_TENSOR_NEXTN_EMBED_TOKENS, "blk.%d.embed_tokens" },
-            { LLM_TENSOR_NEXTN_ENORM,        "blk.%d.enorm" },
-            { LLM_TENSOR_NEXTN_HNORM,        "blk.%d.hnorm" },
-            { LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "blk.%d.shared_head.head" },
-            { LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "blk.%d.shared_head.norm" },
+            { LLM_TENSOR_NEXTN_EH_PROJ,      "blk.%d.nextn.eh_proj" },
+            { LLM_TENSOR_NEXTN_EMBED_TOKENS, "blk.%d.nextn.embed_tokens" },
+            { LLM_TENSOR_NEXTN_ENORM,        "blk.%d.nextn.enorm" },
+            { LLM_TENSOR_NEXTN_HNORM,        "blk.%d.nextn.hnorm" },
+            { LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "blk.%d.nextn.shared_head_head" },
+            { LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "blk.%d.nextn.shared_head_norm" },
         },
     },
     {
diff --git a/src/llama-arch.h b/src/llama-arch.h
index 42a20cba45124..c38969500e0e2 100644
--- a/src/llama-arch.h
+++ b/src/llama-arch.h
@@ -130,7 +130,7 @@ enum llm_kv {
     LLM_KV_EXPERT_WEIGHTS_NORM,
     LLM_KV_EXPERT_GATING_FUNC,
     LLM_KV_MOE_EVERY_N_LAYERS,
-    LLM_KV_NUM_NEXTN_PREDICT_LAYERS,
+    LLM_KV_NEXTN_PREDICT_LAYERS,
     LLM_KV_POOLING_TYPE,
     LLM_KV_LOGIT_SCALE,
     LLM_KV_DECODER_START_TOKEN_ID,
diff --git a/src/llama-hparams.h b/src/llama-hparams.h
index 718e347c40182..d60035726802e 100644
--- a/src/llama-hparams.h
+++ b/src/llama-hparams.h
@@ -72,8 +72,8 @@ struct llama_hparams {
     float    expert_weights_scale = 0.0;
     bool     expert_weights_norm  = false;
     uint32_t expert_gating_func   = LLAMA_EXPERT_GATING_FUNC_TYPE_NONE;
-    uint32_t moe_every_n_layers       = 0;
-    uint32_t num_nextn_predict_layers = 0;
+    uint32_t moe_every_n_layers   = 0;
+    uint32_t nextn_predict_layers = 0;
 
     float f_norm_eps;
     float f_norm_rms_eps;
diff --git a/src/llama-kv-cache-unified.cpp b/src/llama-kv-cache-unified.cpp
index 1725f28fcf205..449102f85bddf 100644
--- a/src/llama-kv-cache-unified.cpp
+++ b/src/llama-kv-cache-unified.cpp
@@ -41,7 +41,7 @@ llama_kv_cache_unified::llama_kv_cache_unified(
     }
     if (model.arch == LLM_ARCH_GLM4_MOE) {
         // GLM-4.5: Only process up to last layer, skip final NextN layer
-        n_layer_cache = hparams.n_layer - 1;
+        n_layer_cache = hparams.n_layer - hparam.nextn_predict_layers;
     }
 
     // create a context for each buffer type

From 07416e0a1de0fe41ee164f1a44d6fa3c7f756979 Mon Sep 17 00:00:00 2001
From: Sam 
Date: Mon, 4 Aug 2025 17:38:24 +1000
Subject: [PATCH 06/13] Update src/llama-model.cpp
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

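Rough sketch of the intended layer accounting (values are illustrative,
taken from the GLM-4.5-Air comments elsewhere in this series; this is not
the exact code being patched):

    // GLM-4.5-Air: 46 transformer layers + 1 trailing NextN/MTP layer
    const uint32_t n_layer              = 47;
    const uint32_t nextn_predict_layers = 1;

    // the forward pass walks only the real transformer layers; the trailing
    // NextN layer(s) are loaded but never executed
    const int n_transformer_layers = n_layer - nextn_predict_layers; // == 46
    for (int il = 0; il < n_transformer_layers; ++il) {
        // build the attention + FFN/MoE block for layer il
    }
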
Co-authored-by: Sigbjørn Skjæret 
---
 src/llama-model.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 8f89a23512788..35133407f4f43 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -13494,7 +13494,7 @@ struct llm_build_glm4_moe : public llm_graph_context {
 
         // Only process up to last layer (skip final NextN layer)
         // Final layer tensors are loaded but not processed in forward pass
-        const int n_transformer_layers = n_layer - 1;
+        const int n_transformer_layers = n_layer - hparam.nextn_predict_layers;
         for (int il = 0; il < n_transformer_layers; ++il) {
             ggml_tensor * inpSA = inpL;
 

From 21b10415cc503198e02f98dd86c0da8dd17e896c Mon Sep 17 00:00:00 2001
From: Sam 
Date: Mon, 4 Aug 2025 17:44:50 +1000
Subject: [PATCH 07/13] model: glm 4.5 apply suggestions from code review
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Sigbjørn Skjæret 
---
 src/llama-arch.cpp             | 2 +-
 src/llama-kv-cache-unified.cpp | 2 +-
 src/llama-model.cpp            | 2 +-
 3 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp
index d7eee1883c1ff..33ec26fd1bf40 100644
--- a/src/llama-arch.cpp
+++ b/src/llama-arch.cpp
@@ -126,7 +126,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_EXPERT_WEIGHTS_NORM,               "%s.expert_weights_norm"               },
     { LLM_KV_EXPERT_GATING_FUNC,                "%s.expert_gating_func"                },
     { LLM_KV_MOE_EVERY_N_LAYERS,                "%s.moe_every_n_layers"                },
-    { LLM_KV_NEXTN_PREDICT_LAYERS,              "%s.nextn_predict_layers"               },
+    { LLM_KV_NEXTN_PREDICT_LAYERS,              "%s.nextn_predict_layers"              },
     { LLM_KV_POOLING_TYPE,                      "%s.pooling_type"                      },
     { LLM_KV_LOGIT_SCALE,                       "%s.logit_scale"                       },
     { LLM_KV_DECODER_START_TOKEN_ID,            "%s.decoder_start_token_id"            },
diff --git a/src/llama-kv-cache-unified.cpp b/src/llama-kv-cache-unified.cpp
index 449102f85bddf..f092e7a066edd 100644
--- a/src/llama-kv-cache-unified.cpp
+++ b/src/llama-kv-cache-unified.cpp
@@ -41,7 +41,7 @@ llama_kv_cache_unified::llama_kv_cache_unified(
     }
     if (model.arch == LLM_ARCH_GLM4_MOE) {
         // GLM-4.5: Only process up to last layer, skip final NextN layer
-        n_layer_cache = hparams.n_layer - hparam.nextn_predict_layers;
+        n_layer_cache = hparams.n_layer - hparams.nextn_predict_layers;
     }
 
     // create a context for each buffer type
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 35133407f4f43..24800e829debe 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -13494,7 +13494,7 @@ struct llm_build_glm4_moe : public llm_graph_context {
 
         // Only process up to last layer (skip final NextN layer)
         // Final layer tensors are loaded but not processed in forward pass
-        const int n_transformer_layers = n_layer - hparam.nextn_predict_layers;
+        const int n_transformer_layers = n_layer - hparams.nextn_predict_layers;
         for (int il = 0; il < n_transformer_layers; ++il) {
             ggml_tensor * inpSA = inpL;
 

From 0bb9f28f4621040e5542ce95259e5fdde78ab4d9 Mon Sep 17 00:00:00 2001
From: Sam 
Date: Mon, 4 Aug 2025 17:47:22 +1000
Subject: [PATCH 08/13] model: glm 4.5 apply suggestions from code review

---
 src/llama-model.cpp | 22 +++++++++++-----------
 src/llama-model.h   | 11 +++++++++++
 2 files changed, 22 insertions(+), 11 deletions(-)

diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 24800e829debe..8c525cfc577cc 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -1439,7 +1439,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 }
 
                 // NextN/MTP parameters
-                ml.get_key(LLM_KV_NUM_NEXTN_PREDICT_LAYERS,    hparams.num_nextn_predict_layers, false);
+                ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS,    hparams.nextn_predict_layers, false);
 
                 switch (hparams.n_layer) {
                     case 47: type = LLM_TYPE_106B_A12B; break; // GLM-4.5-Air (46 layers + 1 NextN layer)
@@ -4394,16 +4394,6 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED);
                     }
 
-                    // NextN/MTP tensors (preserved but unused) - conditionally load based on num_nextn_predict_layers
-                    const uint32_t n_nextn_layers = hparams.num_nextn_predict_layers > 0 ? hparams.num_nextn_predict_layers : 1;
-                    for (uint32_t i = n_layer - n_nextn_layers; i < n_layer; ++i) {
-                        create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, i), { 2 * n_embd, n_embd }, TENSOR_NOT_REQUIRED);
-                        create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED);
-                        create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, i), { n_embd }, TENSOR_NOT_REQUIRED);
-                        create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, i), { n_embd }, TENSOR_NOT_REQUIRED);
-                        create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED);
-                        create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, i), { n_embd }, TENSOR_NOT_REQUIRED);
-                    }
 
                     // Load ALL tensors including NextN layer to satisfy total tensor count
                     // but only PROCESS up to last layer (skipping final NextN layer) in forward pass
@@ -4467,6 +4457,16 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                             layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), { n_embd, n_ff }, 0);
                         }
 
+                        // NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers
+                        if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) {
+                            layer.nextn.eh_proj          = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, TENSOR_NOT_REQUIRED);
+                            layer.nextn.embed_tokens     = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED);
+                            layer.nextn.enorm            = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED);
+                            layer.nextn.hnorm            = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED);
+                            layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED);
+                            layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED);
+                        }
+
                     }
                 }
                 break;
diff --git a/src/llama-model.h b/src/llama-model.h
index f49a4f968bf65..bdb81cecd89d0 100644
--- a/src/llama-model.h
+++ b/src/llama-model.h
@@ -168,6 +168,15 @@ struct llama_layer_shortconv {
     struct ggml_tensor * out_proj = nullptr;
 };
 
+struct llama_layer_nextn {
+    struct ggml_tensor * eh_proj          = nullptr;
+    struct ggml_tensor * embed_tokens     = nullptr;
+    struct ggml_tensor * enorm            = nullptr;
+    struct ggml_tensor * hnorm            = nullptr;
+    struct ggml_tensor * shared_head_head = nullptr;
+    struct ggml_tensor * shared_head_norm = nullptr;
+};
+
 struct llama_layer {
     // normalization
     struct ggml_tensor * attn_norm       = nullptr;
@@ -356,6 +365,8 @@ struct llama_layer {
     struct llama_layer_convnext convnext;
 
     struct llama_layer_shortconv shortconv;
+
+    struct llama_layer_nextn nextn;
 };
 
 struct llama_model {

From f5df812c90a4c198adac948e474b6f1992fa16ac Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?= 
Date: Mon, 4 Aug 2025 11:24:39 +0200
Subject: [PATCH 09/13] Apply suggestions from code review

---
 convert_hf_to_gguf.py     | 2 +-
 gguf-py/gguf/constants.py | 2 +-
 src/llama-model.cpp       | 3 +--
 src/llama-vocab.cpp       | 6 +++---
 4 files changed, 6 insertions(+), 7 deletions(-)

diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index c70197575380d..45ae74fe824f7 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -6588,7 +6588,7 @@ class Glm4MoeModel(TextModel):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         # GLM4_MOE has num_hidden_layers + 1 actual layers (including NextN layer)
-        self.block_count = self.hparams["num_hidden_layers"] + 1
+        self.block_count = self.hparams["num_hidden_layers"] + self.hparams.get("num_nextn_predict_layers", 0)
         self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
 
     def set_vocab(self):
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index 290f4c4581558..932779059f571 100644
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -105,7 +105,7 @@ class LLM:
         EXPERT_WEIGHTS_NORM               = "{arch}.expert_weights_norm"
         EXPERT_GATING_FUNC                = "{arch}.expert_gating_func"
         MOE_EVERY_N_LAYERS                = "{arch}.moe_every_n_layers"
-        NEXTN_PREDICT_LAYERS              = "{arch}.num_nextn_predict_layers"
+        NEXTN_PREDICT_LAYERS              = "{arch}.nextn_predict_layers"
         POOLING_TYPE                      = "{arch}.pooling_type"
         LOGIT_SCALE                       = "{arch}.logit_scale"
         DECODER_START_TOKEN_ID            = "{arch}.decoder_start_token_id"
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 8c525cfc577cc..9226231c2f868 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -1439,7 +1439,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 }
 
                 // NextN/MTP parameters
-                ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS,    hparams.nextn_predict_layers, false);
+                ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS,        hparams.nextn_predict_layers, false);
 
                 switch (hparams.n_layer) {
                     case 47: type = LLM_TYPE_106B_A12B; break; // GLM-4.5-Air (46 layers + 1 NextN layer)
@@ -4394,7 +4394,6 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED);
                     }
 
-
                     // Load ALL tensors including NextN layer to satisfy total tensor count
                     // but only PROCESS up to last layer (skipping final NextN layer) in forward pass
                     for (int i = 0; i < n_layer; ++i) {
diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp
index 12faa78bd66f2..3bf0cc31ee24d 100644
--- a/src/llama-vocab.cpp
+++ b/src/llama-vocab.cpp
@@ -2185,7 +2185,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                         || t.first == "<|fim▁begin|>" // DeepSeek
                         || t.first == "<PRE>"
                         || t.first == "▁<PRE>"          // CodeLlama
-                        || t.first == "<|code_prefix|>" // GLM4_MOE
+                        || t.first == "<|code_prefix|>" // GLM-4.5
                         ) {
                     special_fim_pre_id = t.second;
                     if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
@@ -2205,7 +2205,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                         || t.first == "<|fim▁hole|>" // DeepSeek
                         || t.first == "<SUF>"
                         || t.first == "▁<SUF>"         // CodeLlama
-                        || t.first == "<|code_suffix|>" // GLM4_MOE
+                        || t.first == "<|code_suffix|>" // GLM-4.5
                         ) {
                     special_fim_suf_id = t.second;
                     if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
@@ -2225,7 +2225,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                         || t.first == "<|fim▁end|>"  // DeepSeek
                         || t.first == "<MID>"
                         || t.first == "▁<MID>"         // CodeLlama
-                        || t.first == "<|code_middle|>" // GLM4_MOE
+                        || t.first == "<|code_middle|>" // GLM-4.5
                         ) {
                     special_fim_mid_id = t.second;
                     if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {

From 83eee1ce988115cb1bb2922b7bb22d8b832df642 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?= 
Date: Mon, 4 Aug 2025 15:55:27 +0200
Subject: [PATCH 10/13] patch broken chat template

---
 convert_hf_to_gguf.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index 45ae74fe824f7..a2b27fa3e4cbf 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -6609,6 +6609,12 @@ def set_vocab(self):
         special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"]) # 151329
         special_vocab._set_special_token("eom", tokenizer.get_added_vocab()["<|observation|>"])  # 151338
 
+        # Patch broken chat template
+        if special_vocab.chat_template and "visible_text(m.content).endswith" in special_vocab.chat_template:
+            special_vocab.chat_template = special_vocab.chat_template.replace(
+                """{{ visible_text(m.content) }}\n{{- '/nothink' if (enable_thinking is defined and not enable_thinking and not visible_text(m.content).endswith("/nothink")) else '' -}}""",
+                """{% set content = visible_text(m.content) %}{{ content }}\n{{- '/nothink' if (enable_thinking is defined and not enable_thinking and not content.endswith("/nothink")) else '' -}}""")
+
         special_vocab.add_to_gguf(self.gguf_writer)
 
     def set_gguf_parameters(self):

From 447f9337890d38e4702b1dc6204ff5009b828275 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?= 
Date: Mon, 4 Aug 2025 16:22:41 +0200
Subject: [PATCH 11/13] typings fix

---
 convert_hf_to_gguf.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index a2b27fa3e4cbf..ae25d94aadb4c 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -6610,7 +6610,7 @@ def set_vocab(self):
         special_vocab._set_special_token("eom", tokenizer.get_added_vocab()["<|observation|>"])  # 151338
 
         # Patch broken chat template
-        if special_vocab.chat_template and "visible_text(m.content).endswith" in special_vocab.chat_template:
+        if isinstance(special_vocab.chat_template, str) and "visible_text(m.content).endswith" in special_vocab.chat_template:
             special_vocab.chat_template = special_vocab.chat_template.replace(
                 """{{ visible_text(m.content) }}\n{{- '/nothink' if (enable_thinking is defined and not enable_thinking and not visible_text(m.content).endswith("/nothink")) else '' -}}""",
                 """{% set content = visible_text(m.content) %}{{ content }}\n{{- '/nothink' if (enable_thinking is defined and not enable_thinking and not content.endswith("/nothink")) else '' -}}""")

From f129567dc0232272358ea71c5017486554b2abd3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?= 
Date: Mon, 4 Aug 2025 19:11:17 +0200
Subject: [PATCH 12/13] add TENSOR_SKIP flag

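The loader flags are now distinct bits so they can be OR'd together per
tensor. A minimal sketch of the intended use, simplified from the GLM-4.5
tensor-creation loop in this patch (not the full loader):

    int flags = 0;
    if (hparams.nextn_predict_layers > 0 &&
        (uint32_t) i >= n_layer - hparams.nextn_predict_layers) {
        // NextN/MTP layer: create the tensor metadata but skip loading its data
        flags |= TENSOR_SKIP;
    }
    // optional tensors can carry both properties at once
    layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i),
                                      { n_embd_head_k }, TENSOR_NOT_REQUIRED | flags);
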
Co-authored-by: Diego Devesa 
---
 src/llama-arch.cpp       | 16 +++++-----
 src/llama-model-loader.h |  5 +--
 src/llama-model.cpp      | 66 ++++++++++++++++++++++------------------
 3 files changed, 47 insertions(+), 40 deletions(-)

diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp
index 33ec26fd1bf40..2e3e38d08fc9a 100644
--- a/src/llama-arch.cpp
+++ b/src/llama-arch.cpp
@@ -2178,14 +2178,14 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
     {LLM_TENSOR_SHORTCONV_CONV,             {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_CONV}},
     {LLM_TENSOR_SHORTCONV_INPROJ,           {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_SHORTCONV_OUTPROJ,          {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    // NextN/MTP tensors are loaded but never used (reserved for future MTP support)
-    // These tensors only exist in the last layer and are treated as output tensors
-    {LLM_TENSOR_NEXTN_EH_PROJ,              {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_NONE}},
-    {LLM_TENSOR_NEXTN_EMBED_TOKENS,         {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_NONE}},
-    {LLM_TENSOR_NEXTN_ENORM,                {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_NONE}},
-    {LLM_TENSOR_NEXTN_HNORM,                {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_NONE}},
-    {LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD,     {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_NONE}},
-    {LLM_TENSOR_NEXTN_SHARED_HEAD_NORM,     {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_NONE}},
+    // NextN/MTP tensors are currently ignored (reserved for future MTP support)
+    // These tensors only exist in the last layer(s) and are treated as output tensors
+    {LLM_TENSOR_NEXTN_EH_PROJ,              {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_NEXTN_EMBED_TOKENS,         {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}},
+    {LLM_TENSOR_NEXTN_ENORM,                {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}},
+    {LLM_TENSOR_NEXTN_HNORM,                {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
+    {LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD,     {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_NEXTN_SHARED_HEAD_NORM,     {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
 };
 
 LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {}
diff --git a/src/llama-model-loader.h b/src/llama-model-loader.h
index 0f52b011b6986..3b5707d5af6ae 100644
--- a/src/llama-model-loader.h
+++ b/src/llama-model-loader.h
@@ -58,8 +58,9 @@ struct llama_model_loader {
         }
     };
 
-    static const int TENSOR_NOT_REQUIRED = 1;
-    static const int TENSOR_DUPLICATED   = 2;
+    static const int TENSOR_NOT_REQUIRED = 1 << 1;
+    static const int TENSOR_DUPLICATED   = 1 << 2;
+    static const int TENSOR_SKIP         = 1 << 3;
 
     int n_kv      = 0;
     int n_tensors = 0;
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 9226231c2f868..d869997dca6e5 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -1950,6 +1950,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
 
     const auto TENSOR_DUPLICATED   = llama_model_loader::TENSOR_DUPLICATED;
     const auto TENSOR_NOT_REQUIRED = llama_model_loader::TENSOR_NOT_REQUIRED;
+    const auto TENSOR_SKIP         = llama_model_loader::TENSOR_SKIP;
 
     // create tensors for the weights
     {
@@ -2005,7 +2006,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
             }
 
             // skip unused tensors
-            if (info.op == GGML_OP_NONE) {
+            if (info.op == GGML_OP_NONE || flags & TENSOR_SKIP) {
                 const size_t nbytes = ggml_nbytes(t_meta);
                 LLAMA_LOG_WARN("model has unused tensor %s (size = %zu bytes) -- ignoring\n", tn.str().c_str(), nbytes);
 
@@ -4397,27 +4398,33 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                     // Load ALL tensors including NextN layer to satisfy total tensor count
                     // but only PROCESS up to last layer (skipping final NextN layer) in forward pass
                     for (int i = 0; i < n_layer; ++i) {
+                        int flags = 0;
+                        if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) {
+                            // skip all tensors in the NextN layers
+                            flags |= TENSOR_SKIP;
+                        }
+
                         auto & layer = layers[i];
 
-                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0);
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, flags);
 
                         // GLM-style attention with bias terms
-                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head }, 0);
-                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, 0);
-                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, 0);
-                        layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), { n_embd_head_k * n_head }, 0);
-                        layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), { n_embd_k_gqa }, 0);
-                        layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), { n_embd_v_gqa }, 0);
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head }, flags);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, flags);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, flags);
+                        layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), { n_embd_head_k * n_head }, flags);
+                        layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), { n_embd_k_gqa }, flags);
+                        layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), { n_embd_v_gqa }, flags);
 
-                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, flags);
 
                         // K/Q norm tensors (optional for GLM-4.5 355B variant)
                         layer.attn_q_norm = create_tensor(
-                            tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, TENSOR_NOT_REQUIRED);
+                            tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, TENSOR_NOT_REQUIRED | flags);
                         layer.attn_k_norm = create_tensor(
-                            tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, TENSOR_NOT_REQUIRED);
+                            tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, TENSOR_NOT_REQUIRED | flags);
 
-                        layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0);
+                        layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, flags);
 
                         // Check if this layer uses MoE or dense FFN based on n_layer_dense_lead
                         // GLM 4.5 uses hybrid architecture: layer 0 is dense, layers 1+ are MoE
@@ -4426,46 +4433,45 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         if (use_moe) {
                             // MoE layers
                             layer.ffn_gate_inp =
-                                create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, 0);
-                            layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), { n_expert }, 0);
+                                create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, flags);
+                            layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), { n_expert }, flags);
 
                             // MoE branch
                             const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used;
 
                             layer.ffn_gate_exps = create_tensor(
-                                tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0);
+                                tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, flags);
                             layer.ffn_down_exps = create_tensor(
-                                tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, 0);
+                                tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, flags);
                             layer.ffn_up_exps = create_tensor(
-                                tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0);
+                                tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, flags);
 
                             // Shared expert
                             if (n_expert_shared > 0) {
                                 const int64_t n_ff_shexp = n_ff_exp * n_expert_shared;
                                 layer.ffn_gate_shexp = create_tensor(
-                                    tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0);
+                                    tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp }, flags);
                                 layer.ffn_down_shexp = create_tensor(
-                                    tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_shexp, n_embd }, 0);
+                                    tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_shexp, n_embd }, flags);
                                 layer.ffn_up_shexp = create_tensor(
-                                    tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0);
+                                    tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp }, flags);
                             }
                         } else {
                             // Dense layers (first k layers) - GLM uses separate gate/up projections
-                            layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), { n_embd, n_ff }, 0);
-                            layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0);
-                            layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), { n_embd, n_ff }, 0);
+                            layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), { n_embd, n_ff }, flags);
+                            layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, flags);
+                            layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), { n_embd, n_ff }, flags);
                         }
 
                         // NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers
                         if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) {
-                            layer.nextn.eh_proj          = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, TENSOR_NOT_REQUIRED);
-                            layer.nextn.embed_tokens     = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED);
-                            layer.nextn.enorm            = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED);
-                            layer.nextn.hnorm            = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED);
-                            layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED);
-                            layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED);
+                            layer.nextn.eh_proj          = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, flags);
+                            layer.nextn.embed_tokens     = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, flags);
+                            layer.nextn.enorm            = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, flags);
+                            layer.nextn.hnorm            = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, flags);
+                            layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, flags);
+                            layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, flags);
                         }
-
                     }
                 }
                 break;

From dcbbd2cb057a6c6e907e0195395a74201ef19e1b Mon Sep 17 00:00:00 2001
From: Diego Devesa 
Date: Mon, 4 Aug 2025 19:21:58 +0200
Subject: [PATCH 13/13] Update src/llama-model-loader.h
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Sigbjørn Skjæret 
---
 src/llama-model-loader.h | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/llama-model-loader.h b/src/llama-model-loader.h
index 3b5707d5af6ae..c9189f6cb4466 100644
--- a/src/llama-model-loader.h
+++ b/src/llama-model-loader.h
@@ -58,9 +58,9 @@ struct llama_model_loader {
         }
     };
 
-    static const int TENSOR_NOT_REQUIRED = 1 << 1;
-    static const int TENSOR_DUPLICATED   = 1 << 2;
-    static const int TENSOR_SKIP         = 1 << 3;
+    static const int TENSOR_NOT_REQUIRED = 1 << 0;
+    static const int TENSOR_DUPLICATED   = 1 << 1;
+    static const int TENSOR_SKIP         = 1 << 2;
 
     int n_kv      = 0;
     int n_tensors = 0;