From 5d2c042513e093b5aa1a4e46ff749c2abc913d4a Mon Sep 17 00:00:00 2001 From: stevenkuang Date: Fri, 25 Jul 2025 19:55:33 +0800 Subject: [PATCH 01/10] support hunyuan_v1_dense Signed-off-by: stevenkuang --- convert_hf_to_gguf.py | 105 +++++++++++++++++++ convert_hf_to_gguf_update.py | 1 + gguf-py/gguf/constants.py | 18 ++++ src/llama-arch.cpp | 21 ++++ src/llama-arch.h | 1 + src/llama-chat.cpp | 26 +++++ src/llama-chat.h | 1 + src/llama-model.cpp | 188 +++++++++++++++++++++++++++++++++++ 8 files changed, 361 insertions(+) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index e12c922bd9ab4..010e209b22a7a 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -684,6 +684,9 @@ def get_vocab_base_pre(self, tokenizer) -> str: if chkhsh == "7e57df22b1fe23a7b1e1c7f3dc4e3f96d43a4eb0836d0c6bdc3436d7b2f1c664": # ref: https://huggingface.co/tencent/Hunyuan-A13B-Instruct res = "hunyuan" + if chkhsh == "bba3b3366b646dbdded5dbc42d59598b849371afc42f7beafa914afaa5b70aa6": + # TODO: update ref + res = "hunyuan" if chkhsh == "a6b57017d60e6edb4d88ecc2845188e0eb333a70357e45dcc9b53964a73bbae6": # ref: https://huggingface.co/tiiuae/Falcon-H1-0.5B-Base res = "falcon-h1" @@ -7531,6 +7534,108 @@ def prepare_tensors(self): raise ValueError(f"Unprocessed experts: {experts}") +@ModelBase.register("HunYuanDenseV1ForCausalLM") +class HunYuanModel(TextModel): + model_arch = gguf.MODEL_ARCH.HUNYUAN_V1_DENSE + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # For handling tied embeddings + self._tok_embd = None + + def set_vocab(self): + if (self.dir_model / "tokenizer.json").is_file(): + self._set_vocab_gpt2() + else: + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True) + + # 1. Get the pre-tokenizer identifier hash + tokpre = self.get_vocab_base_pre(tokenizer) + + # 2. Reverse-engineer the merges list from mergeable_ranks + merges = [] + vocab = {} + mergeable_ranks = tokenizer.mergeable_ranks + for token, rank in mergeable_ranks.items(): + vocab[QwenModel.token_bytes_to_string(token)] = rank + if len(token) == 1: + continue + merged = QwenModel.bpe(mergeable_ranks, token, max_rank=rank) + if len(merged) == 2: + merges.append(' '.join(map(QwenModel.token_bytes_to_string, merged))) + + # 3. Generate the tokens and toktypes lists + vocab_size = self.hparams["vocab_size"] + assert tokenizer.vocab_size == vocab_size + special_tokens = tokenizer.special_tokens + reverse_vocab = {id_ : encoded_tok for encoded_tok, id_ in {**vocab, **special_tokens}.items()} + tokens: list[str] = [] + toktypes: list[int] = [] + for i in range(vocab_size): + if i not in reverse_vocab: + tokens.append(f"[PAD{i}]") + toktypes.append(gguf.TokenType.UNUSED) + else: + token = reverse_vocab[i] + tokens.append(token) + if i in special_tokens.values(): + toktypes.append(gguf.TokenType.CONTROL) + else: + toktypes.append(gguf.TokenType.NORMAL) + + # 4. Write all vocab-related fields to the GGUF writer + self.gguf_writer.add_tokenizer_model("gpt2") + self.gguf_writer.add_tokenizer_pre(tokpre) + self.gguf_writer.add_token_list(tokens) + self.gguf_writer.add_token_types(toktypes) + self.gguf_writer.add_token_merges(merges) + + # 5. Add special tokens and chat templates + special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False) + special_vocab.add_to_gguf(self.gguf_writer) + # FIX for BOS token: Overwrite incorrect id read from config.json + self.gguf_writer.add_bos_token_id(127958) # <|bos|> + + def set_gguf_parameters(self): + super().set_gguf_parameters() + hparams = self.hparams + + self.gguf_writer.add_expert_shared_feed_forward_length(hparams["intermediate_size"]) + + # Rope + rope_scaling = hparams.get("rope_scaling", {}) + if rope_scaling.get("type") == "dynamic": + # HunYuan uses NTK Aware Alpha based scaling. Original implementation: https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ + # 1000 corresponds to a usable context length of 256k (https://github.com/Tencent-Hunyuan/Hunyuan-A13B/blob/main/report/Hunyuan_A13B_Technical_Report.pdf) + alpha = rope_scaling.get("alpha", 50) + base = hparams.get("rope_theta", 10000.0) + dim = hparams["head_dim"] + scaled_base = base * (alpha ** (dim / (dim - 2))) + self.gguf_writer.add_rope_freq_base(scaled_base) + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE) + self.gguf_writer.add_rope_scaling_factor(1) + # There is no consistent way to calculate ctx from alpha, and the config is incorrectly set to 32k + self.gguf_writer.add_rope_scaling_orig_ctx_len(256 * 1024) # 256k context length + self.gguf_writer.add_context_length(256 * 1024) # 256k context length + + # if any of our assumptions about the values are wrong, something has changed and this may need to be updated + assert alpha == 50 and base == 10000.0 and self.hparams["max_position_embeddings"] in [32 * 1024, 256 * 1024] , \ + "HunYuan dynamic RoPE scaling assumptions changed, please update the logic or context length manually" + + _experts: list[dict[str, Tensor]] | None = None + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + if name == "model.embed_tokens.weight": + self._tok_embd = data_torch.clone() + + if name == "lm_head.weight": + if self.hparams.get("tie_word_embeddings", False): + logger.info("Skipping tied output layer 'lm_head.weight'") + return [] + + return [(self.map_tensor_name(name), data_torch)] + @ModelBase.register("SmolLM3ForCausalLM") class SmolLM3Model(LlamaModel): model_arch = gguf.MODEL_ARCH.SMOLLM3 diff --git a/convert_hf_to_gguf_update.py b/convert_hf_to_gguf_update.py index abaf2ea9a1248..6a653a893321f 100755 --- a/convert_hf_to_gguf_update.py +++ b/convert_hf_to_gguf_update.py @@ -140,6 +140,7 @@ class TOKENIZER_TYPE(IntEnum): {"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-hf", "chkhsh": "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2"}, {"name": "minerva-7b", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0", "chkhsh": "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35"}, {"name": "hunyuan", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Hunyuan-A13B-Instruct", "chkhsh": "7e57df22b1fe23a7b1e1c7f3dc4e3f96d43a4eb0836d0c6bdc3436d7b2f1c664"}, + {"name": "hunyuan-v1-dense", "tokt": TOKENIZER_TYPE.BPE, "repo": "", "chkhsh": ""}, # TODO: update hunyuan-v1-dense repo # falcon-h1 series uses 4 different tokenizers across model sizes (0.5b - 34b), hence we need to define 4 different hashes {"name": "falcon-h1", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/Falcon-H1-0.5B-Base", "chkhsh": "a6b57017d60e6edb4d88ecc2845188e0eb333a70357e45dcc9b53964a73bbae6"}, {"name": "falcon-h1", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/Falcon-H1-1B-Base", "chkhsh": "60476e1243776c4fb1b993dbd7a5f15ac22f83c80afdf425fa5ae01c8d44ef86"}, diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 680210db7e9d5..d300c082534c2 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -373,6 +373,7 @@ class MODEL_ARCH(IntEnum): ERNIE4_5 = auto() ERNIE4_5_MOE = auto() HUNYUAN_MOE = auto() + HUNYUAN_V1_DENSE = auto() SMOLLM3 = auto() LFM2 = auto() DREAM = auto() @@ -692,6 +693,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.ERNIE4_5_MOE: "ernie4_5-moe", MODEL_ARCH.FALCON_H1: "falcon-h1", MODEL_ARCH.HUNYUAN_MOE: "hunyuan-moe", + MODEL_ARCH.HUNYUAN_V1_DENSE: "hunyuan-v1-dense", MODEL_ARCH.SMOLLM3: "smollm3", MODEL_ARCH.LFM2: "lfm2", MODEL_ARCH.DREAM: "dream", @@ -2449,6 +2451,22 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_DOWN_SHEXP, MODEL_TENSOR.FFN_UP_SHEXP, ], + MODEL_ARCH.HUNYUAN_V1_DENSE: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_Q_NORM, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_K_NORM, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], MODEL_ARCH.SMOLLM3: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 814ac93a6d87e..203d4606b735a 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -85,6 +85,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_ERNIE4_5, "ernie4_5" }, { LLM_ARCH_ERNIE4_5_MOE, "ernie4_5-moe" }, { LLM_ARCH_HUNYUAN_MOE, "hunyuan-moe" }, + { LLM_ARCH_HUNYUAN_V1_DENSE, "hunyuan-v1-dense" }, { LLM_ARCH_SMOLLM3, "smollm3" }, { LLM_ARCH_LFM2, "lfm2" }, { LLM_ARCH_DREAM, "dream" }, @@ -1895,6 +1896,26 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, }, }, + { + LLM_ARCH_HUNYUAN_V1_DENSE, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + + }, + }, { LLM_ARCH_SMOLLM3, { diff --git a/src/llama-arch.h b/src/llama-arch.h index d09b7d7810b03..c5c1455c275bf 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -89,6 +89,7 @@ enum llm_arch { LLM_ARCH_ERNIE4_5, LLM_ARCH_ERNIE4_5_MOE, LLM_ARCH_HUNYUAN_MOE, + LLM_ARCH_HUNYUAN_V1_DENSE, LLM_ARCH_SMOLLM3, LLM_ARCH_LFM2, LLM_ARCH_DREAM, diff --git a/src/llama-chat.cpp b/src/llama-chat.cpp index 80072ad2713c7..0cda3215edc8f 100644 --- a/src/llama-chat.cpp +++ b/src/llama-chat.cpp @@ -66,6 +66,7 @@ static const std::map LLM_CHAT_TEMPLATES = { { "llama4", LLM_CHAT_TEMPLATE_LLAMA4 }, { "smolvlm", LLM_CHAT_TEMPLATE_SMOLVLM }, { "hunyuan-moe", LLM_CHAT_TEMPLATE_HUNYUAN_MOE }, + { "hunyuan-v1-dense", LLM_CHAT_TEMPLATE_HUNYUAN_V1_DENSE }, { "kimi-k2", LLM_CHAT_TEMPLATE_KIMI_K2 }, }; @@ -193,6 +194,8 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) { return LLM_CHAT_TEMPLATE_DOTS1; } else if (tmpl_contains("<|startoftext|>") && tmpl_contains("<|extra_4|>")) { return LLM_CHAT_TEMPLATE_HUNYUAN_MOE; + } else if (tmpl_contains("<|hy_place▁holder▁no▁2|>") && tmpl_contains("<|hy_place▁holder▁no▁3|>")) { + return LLM_CHAT_TEMPLATE_HUNYUAN_V1_DENSE; } else if (tmpl_contains("<|im_assistant|>assistant<|im_middle|>")) { return LLM_CHAT_TEMPLATE_KIMI_K2; } @@ -703,6 +706,29 @@ int32_t llm_chat_apply_template( ss << "<|startoftext|>" << message->content << "<|extra_0|>"; } } + } else if (tmpl == LLM_CHAT_TEMPLATE_HUNYUAN_V1_DENSE) { + // Todo: add model name + for (size_t i = 0; i < chat.size(); i++) { + std::string role(chat[i]->role); + if (i == 0) { + if (role == "system") { + ss << "<|hy_begin▁of▁sentence|>" << chat[i]->content << "<|hy_place▁holder▁no▁3|>"; + } else { + ss << "<|hy_begin▁of▁sentence|>"; + } + } + + if (role == "assistant") { + ss << "<|hy_Assistant|>" << chat[i]->content << "<|hy_place▁holder▁no▁2|>"; + } else if (role == "user") { + ss << "<|hy_User|>" << chat[i]->content; + } + } + if (add_ass) { + ss << "<|hy_Assistant|>"; + } else { + ss << "<|hy_place▁holder▁no▁8|>"; + } } else if (tmpl == LLM_CHAT_TEMPLATE_KIMI_K2) { // moonshotai/Kimi-K2-Instruct for (auto message : chat) { diff --git a/src/llama-chat.h b/src/llama-chat.h index 6968a19fbe13c..ece18767e9ff5 100644 --- a/src/llama-chat.h +++ b/src/llama-chat.h @@ -46,6 +46,7 @@ enum llm_chat_template { LLM_CHAT_TEMPLATE_SMOLVLM, LLM_CHAT_TEMPLATE_DOTS1, LLM_CHAT_TEMPLATE_HUNYUAN_MOE, + LLM_CHAT_TEMPLATE_HUNYUAN_V1_DENSE, LLM_CHAT_TEMPLATE_KIMI_K2, LLM_CHAT_TEMPLATE_UNKNOWN, }; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index a997a1e80f8cf..e3375c2903824 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1744,6 +1744,18 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_HUNYUAN_V1_DENSE: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_embd) { + case 1024: type = LLM_TYPE_0_5B; break; + case 2048: type = LLM_TYPE_2B; break; + case 3072: type = LLM_TYPE_4B; break; + case 4096: type = LLM_TYPE_7B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_SMOLLM3: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -5103,6 +5115,39 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, 0); } } break; + case LLM_ARCH_HUNYUAN_V1_DENSE: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + + } + } break; case LLM_ARCH_SMOLLM3: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -16696,6 +16741,144 @@ struct llm_build_hunyuan_moe : public llm_graph_context { } }; +struct llm_build_hunyuan_v1_dense : public llm_graph_context { + llm_build_hunyuan_v1_dense(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + const float kq_scale = 1.0f / sqrtf(float(n_embd_head)); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + // self-attention + { + // rope freq factors for llama3; may return nullptr for llama2 and other models + ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); + + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = build_norm(Kcur, + model.layers[il].attn_k_norm, nullptr, + LLM_NORM_RMS, il); + cb(Kcur, "Kcur_norm", il); + + Qcur = build_norm(Qcur, + model.layers[il].attn_q_norm, nullptr, + LLM_NORM_RMS, il); + cb(Qcur, "Qcur_norm", il); + + cur = build_attn(inp_attn, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); + cb(cur, "attn_out", il); + } + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + // feed-forward network (non-MoE) + ggml_tensor * cur_mlp = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur_mlp, "ffn_out", il); + + cur = ggml_add(ctx0, cur_mlp, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + // lm_head + cur = build_lora_mm(model.output, cur); + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + struct llm_build_smollm3 : public llm_graph_context { llm_build_smollm3(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; @@ -17436,6 +17619,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique(*this, params); } break; + case LLM_ARCH_HUNYUAN_V1_DENSE: + { + llm = std::make_unique(*this, params); + } break; case LLM_ARCH_SMOLLM3: { llm = std::make_unique(*this, params); @@ -17645,6 +17832,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_MINICPM3: case LLM_ARCH_DOTS1: case LLM_ARCH_HUNYUAN_MOE: + case LLM_ARCH_HUNYUAN_V1_DENSE: case LLM_ARCH_LFM2: return LLAMA_ROPE_TYPE_NEOX; From aa973ca21913aba77f6e81a935270ef7be222e75 Mon Sep 17 00:00:00 2001 From: stevenkuang Date: Fri, 25 Jul 2025 19:56:25 +0800 Subject: [PATCH 02/10] update hunyuan_moe to hunyuan_v1_moe Signed-off-by: stevenkuang --- convert_hf_to_gguf.py | 2 +- convert_hf_to_gguf_update.py | 2 +- gguf-py/gguf/constants.py | 6 +++--- src/llama-arch.cpp | 4 ++-- src/llama-arch.h | 2 +- src/llama-chat.cpp | 8 ++++---- src/llama-chat.h | 2 +- src/llama-model.cpp | 14 +++++++------- 8 files changed, 20 insertions(+), 20 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 010e209b22a7a..26ffea568fa10 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -7387,7 +7387,7 @@ def set_gguf_parameters(self): @ModelBase.register("HunYuanMoEV1ForCausalLM") class HunYuanMoEModel(TextModel): - model_arch = gguf.MODEL_ARCH.HUNYUAN_MOE + model_arch = gguf.MODEL_ARCH.HUNYUAN_V1_MOE def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/convert_hf_to_gguf_update.py b/convert_hf_to_gguf_update.py index 6a653a893321f..cac881c5d93ec 100755 --- a/convert_hf_to_gguf_update.py +++ b/convert_hf_to_gguf_update.py @@ -139,7 +139,7 @@ class TOKENIZER_TYPE(IntEnum): {"name": "chatglm-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-chat", "chkhsh": "81d72c7348a9f0ebe86f23298d37debe0a5e71149e29bd283904c02262b27516"}, {"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-hf", "chkhsh": "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2"}, {"name": "minerva-7b", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0", "chkhsh": "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35"}, - {"name": "hunyuan", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Hunyuan-A13B-Instruct", "chkhsh": "7e57df22b1fe23a7b1e1c7f3dc4e3f96d43a4eb0836d0c6bdc3436d7b2f1c664"}, + {"name": "hunyuan-v1-moe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Hunyuan-A13B-Instruct", "chkhsh": "7e57df22b1fe23a7b1e1c7f3dc4e3f96d43a4eb0836d0c6bdc3436d7b2f1c664"}, {"name": "hunyuan-v1-dense", "tokt": TOKENIZER_TYPE.BPE, "repo": "", "chkhsh": ""}, # TODO: update hunyuan-v1-dense repo # falcon-h1 series uses 4 different tokenizers across model sizes (0.5b - 34b), hence we need to define 4 different hashes {"name": "falcon-h1", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/Falcon-H1-0.5B-Base", "chkhsh": "a6b57017d60e6edb4d88ecc2845188e0eb333a70357e45dcc9b53964a73bbae6"}, diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index d300c082534c2..70017f297f4cd 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -372,7 +372,7 @@ class MODEL_ARCH(IntEnum): ARCEE = auto() ERNIE4_5 = auto() ERNIE4_5_MOE = auto() - HUNYUAN_MOE = auto() + HUNYUAN_V1_MOE = auto() HUNYUAN_V1_DENSE = auto() SMOLLM3 = auto() LFM2 = auto() @@ -692,7 +692,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.ERNIE4_5: "ernie4_5", MODEL_ARCH.ERNIE4_5_MOE: "ernie4_5-moe", MODEL_ARCH.FALCON_H1: "falcon-h1", - MODEL_ARCH.HUNYUAN_MOE: "hunyuan-moe", + MODEL_ARCH.HUNYUAN_V1_MOE: "hunyuan-v1-moe", MODEL_ARCH.HUNYUAN_V1_DENSE: "hunyuan-v1-dense", MODEL_ARCH.SMOLLM3: "smollm3", MODEL_ARCH.LFM2: "lfm2", @@ -2430,7 +2430,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT_NORM, # Final layer norm MODEL_TENSOR.OUTPUT, # Output projection (lm_head) ], - MODEL_ARCH.HUNYUAN_MOE: [ + MODEL_ARCH.HUNYUAN_V1_MOE: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT, diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 203d4606b735a..f7062dc0cb459 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -84,7 +84,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_ARCEE, "arcee" }, { LLM_ARCH_ERNIE4_5, "ernie4_5" }, { LLM_ARCH_ERNIE4_5_MOE, "ernie4_5-moe" }, - { LLM_ARCH_HUNYUAN_MOE, "hunyuan-moe" }, + { LLM_ARCH_HUNYUAN_V1_MOE, "hunyuan-v1-moe" }, { LLM_ARCH_HUNYUAN_V1_DENSE, "hunyuan-v1-dense" }, { LLM_ARCH_SMOLLM3, "smollm3" }, { LLM_ARCH_LFM2, "lfm2" }, @@ -1874,7 +1874,7 @@ static const std::map> LLM_TENSOR_N }, }, { - LLM_ARCH_HUNYUAN_MOE, + LLM_ARCH_HUNYUAN_V1_MOE, { { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, diff --git a/src/llama-arch.h b/src/llama-arch.h index c5c1455c275bf..d29b367878149 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -88,7 +88,7 @@ enum llm_arch { LLM_ARCH_ARCEE, LLM_ARCH_ERNIE4_5, LLM_ARCH_ERNIE4_5_MOE, - LLM_ARCH_HUNYUAN_MOE, + LLM_ARCH_HUNYUAN_V1_MOE, LLM_ARCH_HUNYUAN_V1_DENSE, LLM_ARCH_SMOLLM3, LLM_ARCH_LFM2, diff --git a/src/llama-chat.cpp b/src/llama-chat.cpp index 0cda3215edc8f..882f08bb41235 100644 --- a/src/llama-chat.cpp +++ b/src/llama-chat.cpp @@ -65,7 +65,7 @@ static const std::map LLM_CHAT_TEMPLATES = { { "bailing", LLM_CHAT_TEMPLATE_BAILING }, { "llama4", LLM_CHAT_TEMPLATE_LLAMA4 }, { "smolvlm", LLM_CHAT_TEMPLATE_SMOLVLM }, - { "hunyuan-moe", LLM_CHAT_TEMPLATE_HUNYUAN_MOE }, + { "hunyuan-v1-moe", LLM_CHAT_TEMPLATE_HUNYUAN_V1_MOE }, { "hunyuan-v1-dense", LLM_CHAT_TEMPLATE_HUNYUAN_V1_DENSE }, { "kimi-k2", LLM_CHAT_TEMPLATE_KIMI_K2 }, }; @@ -193,7 +193,7 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) { } else if (tmpl_contains("<|endofuserprompt|>")) { return LLM_CHAT_TEMPLATE_DOTS1; } else if (tmpl_contains("<|startoftext|>") && tmpl_contains("<|extra_4|>")) { - return LLM_CHAT_TEMPLATE_HUNYUAN_MOE; + return LLM_CHAT_TEMPLATE_HUNYUAN_V1_MOE; } else if (tmpl_contains("<|hy_place▁holder▁no▁2|>") && tmpl_contains("<|hy_place▁holder▁no▁3|>")) { return LLM_CHAT_TEMPLATE_HUNYUAN_V1_DENSE; } else if (tmpl_contains("<|im_assistant|>assistant<|im_middle|>")) { @@ -694,14 +694,14 @@ int32_t llm_chat_apply_template( if (add_ass) { ss << "<|response|>"; } - } else if (tmpl == LLM_CHAT_TEMPLATE_HUNYUAN_MOE) { + } else if (tmpl == LLM_CHAT_TEMPLATE_HUNYUAN_V1_MOE) { // tencent/Hunyuan-A13B-Instruct for (auto message : chat) { std::string role(message->role); if (role == "system") { ss << "<|startoftext|>" << message->content << "<|extra_4|>"; } else if (role == "assistant") { - ss << "<|startoftext|>" << message->content << "<|eos|>"; + ss << message->content << "<|eos|>"; } else { ss << "<|startoftext|>" << message->content << "<|extra_0|>"; } diff --git a/src/llama-chat.h b/src/llama-chat.h index ece18767e9ff5..3e1f6a5843c37 100644 --- a/src/llama-chat.h +++ b/src/llama-chat.h @@ -45,7 +45,7 @@ enum llm_chat_template { LLM_CHAT_TEMPLATE_LLAMA4, LLM_CHAT_TEMPLATE_SMOLVLM, LLM_CHAT_TEMPLATE_DOTS1, - LLM_CHAT_TEMPLATE_HUNYUAN_MOE, + LLM_CHAT_TEMPLATE_HUNYUAN_V1_MOE, LLM_CHAT_TEMPLATE_HUNYUAN_V1_DENSE, LLM_CHAT_TEMPLATE_KIMI_K2, LLM_CHAT_TEMPLATE_UNKNOWN, diff --git a/src/llama-model.cpp b/src/llama-model.cpp index e3375c2903824..62d3441be02d8 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1733,7 +1733,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { type = LLM_TYPE_UNKNOWN; } } break; - case LLM_ARCH_HUNYUAN_MOE: + case LLM_ARCH_HUNYUAN_V1_MOE: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); @@ -5078,7 +5078,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {ffn_intermediate_size}, TENSOR_NOT_REQUIRED); } } break; - case LLM_ARCH_HUNYUAN_MOE: + case LLM_ARCH_HUNYUAN_V1_MOE: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -16580,8 +16580,8 @@ struct llm_build_arcee : public llm_graph_context { } }; -struct llm_build_hunyuan_moe : public llm_graph_context { - llm_build_hunyuan_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +struct llm_build_hunyuan_v1_moe : public llm_graph_context { + llm_build_hunyuan_v1_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -17615,9 +17615,9 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique(*this, params); } break; - case LLM_ARCH_HUNYUAN_MOE: + case LLM_ARCH_HUNYUAN_V1_MOE: { - llm = std::make_unique(*this, params); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_HUNYUAN_V1_DENSE: { @@ -17831,7 +17831,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_EXAONE4: case LLM_ARCH_MINICPM3: case LLM_ARCH_DOTS1: - case LLM_ARCH_HUNYUAN_MOE: + case LLM_ARCH_HUNYUAN_V1_MOE: case LLM_ARCH_HUNYUAN_V1_DENSE: case LLM_ARCH_LFM2: return LLAMA_ROPE_TYPE_NEOX; From 5645497429805ae3d9ecbe4a4a0d1362dfbe03d1 Mon Sep 17 00:00:00 2001 From: stevenkuang Date: Fri, 25 Jul 2025 21:17:06 +0800 Subject: [PATCH 03/10] fix rope alpha assert and bos token Signed-off-by: stevenkuang --- convert_hf_to_gguf.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 26ffea568fa10..75893543ac542 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -7595,7 +7595,8 @@ def set_vocab(self): special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False) special_vocab.add_to_gguf(self.gguf_writer) # FIX for BOS token: Overwrite incorrect id read from config.json - self.gguf_writer.add_bos_token_id(127958) # <|bos|> + if self.hparams['hidden_size'] == 4096: + self.gguf_writer.add_bos_token_id(127958) # only for 7b dense, fix <|bos|> token def set_gguf_parameters(self): super().set_gguf_parameters() @@ -7620,7 +7621,7 @@ def set_gguf_parameters(self): self.gguf_writer.add_context_length(256 * 1024) # 256k context length # if any of our assumptions about the values are wrong, something has changed and this may need to be updated - assert alpha == 50 and base == 10000.0 and self.hparams["max_position_embeddings"] in [32 * 1024, 256 * 1024] , \ + assert base == 10000.0 and self.hparams["max_position_embeddings"] in [32 * 1024, 256 * 1024] , \ "HunYuan dynamic RoPE scaling assumptions changed, please update the logic or context length manually" _experts: list[dict[str, Tensor]] | None = None From 63f32c301fc698104726afc135f52233ea444c05 Mon Sep 17 00:00:00 2001 From: stevenkuang Date: Fri, 25 Jul 2025 21:53:20 +0800 Subject: [PATCH 04/10] add blank line Signed-off-by: stevenkuang --- convert_hf_to_gguf.py | 1 + 1 file changed, 1 insertion(+) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 75893543ac542..d9317ebb1ddd4 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -7637,6 +7637,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter return [(self.map_tensor_name(name), data_torch)] + @ModelBase.register("SmolLM3ForCausalLM") class SmolLM3Model(LlamaModel): model_arch = gguf.MODEL_ARCH.SMOLLM3 From 78de8db462e9618304690a263218895f78196fb1 Mon Sep 17 00:00:00 2001 From: stevenkuang Date: Sat, 26 Jul 2025 03:26:11 +0800 Subject: [PATCH 05/10] Revert "update hunyuan_moe to hunyuan_v1_moe" This reverts commit aa973ca21913aba77f6e81a935270ef7be222e75. --- convert_hf_to_gguf.py | 2 +- convert_hf_to_gguf_update.py | 2 +- gguf-py/gguf/constants.py | 6 +++--- src/llama-arch.cpp | 4 ++-- src/llama-arch.h | 2 +- src/llama-chat.cpp | 8 ++++---- src/llama-chat.h | 2 +- src/llama-model.cpp | 14 +++++++------- 8 files changed, 20 insertions(+), 20 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index d9317ebb1ddd4..41399f2432abb 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -7387,7 +7387,7 @@ def set_gguf_parameters(self): @ModelBase.register("HunYuanMoEV1ForCausalLM") class HunYuanMoEModel(TextModel): - model_arch = gguf.MODEL_ARCH.HUNYUAN_V1_MOE + model_arch = gguf.MODEL_ARCH.HUNYUAN_MOE def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/convert_hf_to_gguf_update.py b/convert_hf_to_gguf_update.py index cac881c5d93ec..6a653a893321f 100755 --- a/convert_hf_to_gguf_update.py +++ b/convert_hf_to_gguf_update.py @@ -139,7 +139,7 @@ class TOKENIZER_TYPE(IntEnum): {"name": "chatglm-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-chat", "chkhsh": "81d72c7348a9f0ebe86f23298d37debe0a5e71149e29bd283904c02262b27516"}, {"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-hf", "chkhsh": "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2"}, {"name": "minerva-7b", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0", "chkhsh": "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35"}, - {"name": "hunyuan-v1-moe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Hunyuan-A13B-Instruct", "chkhsh": "7e57df22b1fe23a7b1e1c7f3dc4e3f96d43a4eb0836d0c6bdc3436d7b2f1c664"}, + {"name": "hunyuan", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Hunyuan-A13B-Instruct", "chkhsh": "7e57df22b1fe23a7b1e1c7f3dc4e3f96d43a4eb0836d0c6bdc3436d7b2f1c664"}, {"name": "hunyuan-v1-dense", "tokt": TOKENIZER_TYPE.BPE, "repo": "", "chkhsh": ""}, # TODO: update hunyuan-v1-dense repo # falcon-h1 series uses 4 different tokenizers across model sizes (0.5b - 34b), hence we need to define 4 different hashes {"name": "falcon-h1", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/Falcon-H1-0.5B-Base", "chkhsh": "a6b57017d60e6edb4d88ecc2845188e0eb333a70357e45dcc9b53964a73bbae6"}, diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 70017f297f4cd..d300c082534c2 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -372,7 +372,7 @@ class MODEL_ARCH(IntEnum): ARCEE = auto() ERNIE4_5 = auto() ERNIE4_5_MOE = auto() - HUNYUAN_V1_MOE = auto() + HUNYUAN_MOE = auto() HUNYUAN_V1_DENSE = auto() SMOLLM3 = auto() LFM2 = auto() @@ -692,7 +692,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.ERNIE4_5: "ernie4_5", MODEL_ARCH.ERNIE4_5_MOE: "ernie4_5-moe", MODEL_ARCH.FALCON_H1: "falcon-h1", - MODEL_ARCH.HUNYUAN_V1_MOE: "hunyuan-v1-moe", + MODEL_ARCH.HUNYUAN_MOE: "hunyuan-moe", MODEL_ARCH.HUNYUAN_V1_DENSE: "hunyuan-v1-dense", MODEL_ARCH.SMOLLM3: "smollm3", MODEL_ARCH.LFM2: "lfm2", @@ -2430,7 +2430,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT_NORM, # Final layer norm MODEL_TENSOR.OUTPUT, # Output projection (lm_head) ], - MODEL_ARCH.HUNYUAN_V1_MOE: [ + MODEL_ARCH.HUNYUAN_MOE: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT, diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index f7062dc0cb459..203d4606b735a 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -84,7 +84,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_ARCEE, "arcee" }, { LLM_ARCH_ERNIE4_5, "ernie4_5" }, { LLM_ARCH_ERNIE4_5_MOE, "ernie4_5-moe" }, - { LLM_ARCH_HUNYUAN_V1_MOE, "hunyuan-v1-moe" }, + { LLM_ARCH_HUNYUAN_MOE, "hunyuan-moe" }, { LLM_ARCH_HUNYUAN_V1_DENSE, "hunyuan-v1-dense" }, { LLM_ARCH_SMOLLM3, "smollm3" }, { LLM_ARCH_LFM2, "lfm2" }, @@ -1874,7 +1874,7 @@ static const std::map> LLM_TENSOR_N }, }, { - LLM_ARCH_HUNYUAN_V1_MOE, + LLM_ARCH_HUNYUAN_MOE, { { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, diff --git a/src/llama-arch.h b/src/llama-arch.h index d29b367878149..c5c1455c275bf 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -88,7 +88,7 @@ enum llm_arch { LLM_ARCH_ARCEE, LLM_ARCH_ERNIE4_5, LLM_ARCH_ERNIE4_5_MOE, - LLM_ARCH_HUNYUAN_V1_MOE, + LLM_ARCH_HUNYUAN_MOE, LLM_ARCH_HUNYUAN_V1_DENSE, LLM_ARCH_SMOLLM3, LLM_ARCH_LFM2, diff --git a/src/llama-chat.cpp b/src/llama-chat.cpp index 882f08bb41235..0cda3215edc8f 100644 --- a/src/llama-chat.cpp +++ b/src/llama-chat.cpp @@ -65,7 +65,7 @@ static const std::map LLM_CHAT_TEMPLATES = { { "bailing", LLM_CHAT_TEMPLATE_BAILING }, { "llama4", LLM_CHAT_TEMPLATE_LLAMA4 }, { "smolvlm", LLM_CHAT_TEMPLATE_SMOLVLM }, - { "hunyuan-v1-moe", LLM_CHAT_TEMPLATE_HUNYUAN_V1_MOE }, + { "hunyuan-moe", LLM_CHAT_TEMPLATE_HUNYUAN_MOE }, { "hunyuan-v1-dense", LLM_CHAT_TEMPLATE_HUNYUAN_V1_DENSE }, { "kimi-k2", LLM_CHAT_TEMPLATE_KIMI_K2 }, }; @@ -193,7 +193,7 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) { } else if (tmpl_contains("<|endofuserprompt|>")) { return LLM_CHAT_TEMPLATE_DOTS1; } else if (tmpl_contains("<|startoftext|>") && tmpl_contains("<|extra_4|>")) { - return LLM_CHAT_TEMPLATE_HUNYUAN_V1_MOE; + return LLM_CHAT_TEMPLATE_HUNYUAN_MOE; } else if (tmpl_contains("<|hy_place▁holder▁no▁2|>") && tmpl_contains("<|hy_place▁holder▁no▁3|>")) { return LLM_CHAT_TEMPLATE_HUNYUAN_V1_DENSE; } else if (tmpl_contains("<|im_assistant|>assistant<|im_middle|>")) { @@ -694,14 +694,14 @@ int32_t llm_chat_apply_template( if (add_ass) { ss << "<|response|>"; } - } else if (tmpl == LLM_CHAT_TEMPLATE_HUNYUAN_V1_MOE) { + } else if (tmpl == LLM_CHAT_TEMPLATE_HUNYUAN_MOE) { // tencent/Hunyuan-A13B-Instruct for (auto message : chat) { std::string role(message->role); if (role == "system") { ss << "<|startoftext|>" << message->content << "<|extra_4|>"; } else if (role == "assistant") { - ss << message->content << "<|eos|>"; + ss << "<|startoftext|>" << message->content << "<|eos|>"; } else { ss << "<|startoftext|>" << message->content << "<|extra_0|>"; } diff --git a/src/llama-chat.h b/src/llama-chat.h index 3e1f6a5843c37..ece18767e9ff5 100644 --- a/src/llama-chat.h +++ b/src/llama-chat.h @@ -45,7 +45,7 @@ enum llm_chat_template { LLM_CHAT_TEMPLATE_LLAMA4, LLM_CHAT_TEMPLATE_SMOLVLM, LLM_CHAT_TEMPLATE_DOTS1, - LLM_CHAT_TEMPLATE_HUNYUAN_V1_MOE, + LLM_CHAT_TEMPLATE_HUNYUAN_MOE, LLM_CHAT_TEMPLATE_HUNYUAN_V1_DENSE, LLM_CHAT_TEMPLATE_KIMI_K2, LLM_CHAT_TEMPLATE_UNKNOWN, diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 62d3441be02d8..e3375c2903824 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1733,7 +1733,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { type = LLM_TYPE_UNKNOWN; } } break; - case LLM_ARCH_HUNYUAN_V1_MOE: + case LLM_ARCH_HUNYUAN_MOE: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); @@ -5078,7 +5078,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {ffn_intermediate_size}, TENSOR_NOT_REQUIRED); } } break; - case LLM_ARCH_HUNYUAN_V1_MOE: + case LLM_ARCH_HUNYUAN_MOE: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -16580,8 +16580,8 @@ struct llm_build_arcee : public llm_graph_context { } }; -struct llm_build_hunyuan_v1_moe : public llm_graph_context { - llm_build_hunyuan_v1_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +struct llm_build_hunyuan_moe : public llm_graph_context { + llm_build_hunyuan_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -17615,9 +17615,9 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique(*this, params); } break; - case LLM_ARCH_HUNYUAN_V1_MOE: + case LLM_ARCH_HUNYUAN_MOE: { - llm = std::make_unique(*this, params); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_HUNYUAN_V1_DENSE: { @@ -17831,7 +17831,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_EXAONE4: case LLM_ARCH_MINICPM3: case LLM_ARCH_DOTS1: - case LLM_ARCH_HUNYUAN_V1_MOE: + case LLM_ARCH_HUNYUAN_MOE: case LLM_ARCH_HUNYUAN_V1_DENSE: case LLM_ARCH_LFM2: return LLAMA_ROPE_TYPE_NEOX; From c7329b4d01c0d51c1c7c829b0477c0714b15a30f Mon Sep 17 00:00:00 2001 From: stevenkuang Date: Sat, 26 Jul 2025 03:26:38 +0800 Subject: [PATCH 06/10] use hunyuan_dense instead of hunyuan_v1_dense Signed-off-by: stevenkuang --- convert_hf_to_gguf.py | 4 ++-- convert_hf_to_gguf_update.py | 2 +- gguf-py/gguf/constants.py | 6 +++--- src/llama-arch.cpp | 4 ++-- src/llama-arch.h | 2 +- src/llama-chat.cpp | 8 ++++---- src/llama-chat.h | 2 +- src/llama-model.cpp | 16 ++++++++-------- 8 files changed, 22 insertions(+), 22 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 41399f2432abb..15cca2a03234f 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -685,7 +685,7 @@ def get_vocab_base_pre(self, tokenizer) -> str: # ref: https://huggingface.co/tencent/Hunyuan-A13B-Instruct res = "hunyuan" if chkhsh == "bba3b3366b646dbdded5dbc42d59598b849371afc42f7beafa914afaa5b70aa6": - # TODO: update ref + # ref: https://huggingface.co/tencent/Hunyuan-4B res = "hunyuan" if chkhsh == "a6b57017d60e6edb4d88ecc2845188e0eb333a70357e45dcc9b53964a73bbae6": # ref: https://huggingface.co/tiiuae/Falcon-H1-0.5B-Base @@ -7536,7 +7536,7 @@ def prepare_tensors(self): @ModelBase.register("HunYuanDenseV1ForCausalLM") class HunYuanModel(TextModel): - model_arch = gguf.MODEL_ARCH.HUNYUAN_V1_DENSE + model_arch = gguf.MODEL_ARCH.HUNYUAN_DENSE def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/convert_hf_to_gguf_update.py b/convert_hf_to_gguf_update.py index 6a653a893321f..461522b3d2a4c 100755 --- a/convert_hf_to_gguf_update.py +++ b/convert_hf_to_gguf_update.py @@ -140,7 +140,7 @@ class TOKENIZER_TYPE(IntEnum): {"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-hf", "chkhsh": "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2"}, {"name": "minerva-7b", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0", "chkhsh": "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35"}, {"name": "hunyuan", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Hunyuan-A13B-Instruct", "chkhsh": "7e57df22b1fe23a7b1e1c7f3dc4e3f96d43a4eb0836d0c6bdc3436d7b2f1c664"}, - {"name": "hunyuan-v1-dense", "tokt": TOKENIZER_TYPE.BPE, "repo": "", "chkhsh": ""}, # TODO: update hunyuan-v1-dense repo + {"name": "hunyuan", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Hunyuan-4B", "chkhsh": "bba3b3366b646dbdded5dbc42d59598b849371afc42f7beafa914afaa5b70aa6"}, # falcon-h1 series uses 4 different tokenizers across model sizes (0.5b - 34b), hence we need to define 4 different hashes {"name": "falcon-h1", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/Falcon-H1-0.5B-Base", "chkhsh": "a6b57017d60e6edb4d88ecc2845188e0eb333a70357e45dcc9b53964a73bbae6"}, {"name": "falcon-h1", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/Falcon-H1-1B-Base", "chkhsh": "60476e1243776c4fb1b993dbd7a5f15ac22f83c80afdf425fa5ae01c8d44ef86"}, diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index d300c082534c2..7a01775fe8791 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -373,7 +373,7 @@ class MODEL_ARCH(IntEnum): ERNIE4_5 = auto() ERNIE4_5_MOE = auto() HUNYUAN_MOE = auto() - HUNYUAN_V1_DENSE = auto() + HUNYUAN_DENSE = auto() SMOLLM3 = auto() LFM2 = auto() DREAM = auto() @@ -693,7 +693,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.ERNIE4_5_MOE: "ernie4_5-moe", MODEL_ARCH.FALCON_H1: "falcon-h1", MODEL_ARCH.HUNYUAN_MOE: "hunyuan-moe", - MODEL_ARCH.HUNYUAN_V1_DENSE: "hunyuan-v1-dense", + MODEL_ARCH.HUNYUAN_DENSE: "hunyuan-dense", MODEL_ARCH.SMOLLM3: "smollm3", MODEL_ARCH.LFM2: "lfm2", MODEL_ARCH.DREAM: "dream", @@ -2451,7 +2451,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_DOWN_SHEXP, MODEL_TENSOR.FFN_UP_SHEXP, ], - MODEL_ARCH.HUNYUAN_V1_DENSE: [ + MODEL_ARCH.HUNYUAN_DENSE: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT, diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 203d4606b735a..b6be16ab97225 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -85,7 +85,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_ERNIE4_5, "ernie4_5" }, { LLM_ARCH_ERNIE4_5_MOE, "ernie4_5-moe" }, { LLM_ARCH_HUNYUAN_MOE, "hunyuan-moe" }, - { LLM_ARCH_HUNYUAN_V1_DENSE, "hunyuan-v1-dense" }, + { LLM_ARCH_HUNYUAN_DENSE, "hunyuan-dense" }, { LLM_ARCH_SMOLLM3, "smollm3" }, { LLM_ARCH_LFM2, "lfm2" }, { LLM_ARCH_DREAM, "dream" }, @@ -1897,7 +1897,7 @@ static const std::map> LLM_TENSOR_N }, }, { - LLM_ARCH_HUNYUAN_V1_DENSE, + LLM_ARCH_HUNYUAN_DENSE, { { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, diff --git a/src/llama-arch.h b/src/llama-arch.h index c5c1455c275bf..fead2b2696839 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -89,7 +89,7 @@ enum llm_arch { LLM_ARCH_ERNIE4_5, LLM_ARCH_ERNIE4_5_MOE, LLM_ARCH_HUNYUAN_MOE, - LLM_ARCH_HUNYUAN_V1_DENSE, + LLM_ARCH_HUNYUAN_DENSE, LLM_ARCH_SMOLLM3, LLM_ARCH_LFM2, LLM_ARCH_DREAM, diff --git a/src/llama-chat.cpp b/src/llama-chat.cpp index 0cda3215edc8f..d9c0a6da134db 100644 --- a/src/llama-chat.cpp +++ b/src/llama-chat.cpp @@ -66,7 +66,7 @@ static const std::map LLM_CHAT_TEMPLATES = { { "llama4", LLM_CHAT_TEMPLATE_LLAMA4 }, { "smolvlm", LLM_CHAT_TEMPLATE_SMOLVLM }, { "hunyuan-moe", LLM_CHAT_TEMPLATE_HUNYUAN_MOE }, - { "hunyuan-v1-dense", LLM_CHAT_TEMPLATE_HUNYUAN_V1_DENSE }, + { "hunyuan-dense", LLM_CHAT_TEMPLATE_HUNYUAN_DENSE }, { "kimi-k2", LLM_CHAT_TEMPLATE_KIMI_K2 }, }; @@ -195,7 +195,7 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) { } else if (tmpl_contains("<|startoftext|>") && tmpl_contains("<|extra_4|>")) { return LLM_CHAT_TEMPLATE_HUNYUAN_MOE; } else if (tmpl_contains("<|hy_place▁holder▁no▁2|>") && tmpl_contains("<|hy_place▁holder▁no▁3|>")) { - return LLM_CHAT_TEMPLATE_HUNYUAN_V1_DENSE; + return LLM_CHAT_TEMPLATE_HUNYUAN_DENSE; } else if (tmpl_contains("<|im_assistant|>assistant<|im_middle|>")) { return LLM_CHAT_TEMPLATE_KIMI_K2; } @@ -706,8 +706,8 @@ int32_t llm_chat_apply_template( ss << "<|startoftext|>" << message->content << "<|extra_0|>"; } } - } else if (tmpl == LLM_CHAT_TEMPLATE_HUNYUAN_V1_DENSE) { - // Todo: add model name + } else if (tmpl == LLM_CHAT_TEMPLATE_HUNYUAN_DENSE) { + // tencent/Hunyuan-4B for (size_t i = 0; i < chat.size(); i++) { std::string role(chat[i]->role); if (i == 0) { diff --git a/src/llama-chat.h b/src/llama-chat.h index ece18767e9ff5..4cf77fd286733 100644 --- a/src/llama-chat.h +++ b/src/llama-chat.h @@ -46,7 +46,7 @@ enum llm_chat_template { LLM_CHAT_TEMPLATE_SMOLVLM, LLM_CHAT_TEMPLATE_DOTS1, LLM_CHAT_TEMPLATE_HUNYUAN_MOE, - LLM_CHAT_TEMPLATE_HUNYUAN_V1_DENSE, + LLM_CHAT_TEMPLATE_HUNYUAN_DENSE, LLM_CHAT_TEMPLATE_KIMI_K2, LLM_CHAT_TEMPLATE_UNKNOWN, }; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index e3375c2903824..75228565cfb6c 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1744,13 +1744,13 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; - case LLM_ARCH_HUNYUAN_V1_DENSE: + case LLM_ARCH_HUNYUAN_DENSE: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_embd) { case 1024: type = LLM_TYPE_0_5B; break; - case 2048: type = LLM_TYPE_2B; break; + case 2048: type = LLM_TYPE_1_8B; break; case 3072: type = LLM_TYPE_4B; break; case 4096: type = LLM_TYPE_7B; break; default: type = LLM_TYPE_UNKNOWN; @@ -5115,7 +5115,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, 0); } } break; - case LLM_ARCH_HUNYUAN_V1_DENSE: + case LLM_ARCH_HUNYUAN_DENSE: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -16741,8 +16741,8 @@ struct llm_build_hunyuan_moe : public llm_graph_context { } }; -struct llm_build_hunyuan_v1_dense : public llm_graph_context { - llm_build_hunyuan_v1_dense(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +struct llm_build_hunyuan_dense : public llm_graph_context { + llm_build_hunyuan_dense(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -17619,9 +17619,9 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique(*this, params); } break; - case LLM_ARCH_HUNYUAN_V1_DENSE: + case LLM_ARCH_HUNYUAN_DENSE: { - llm = std::make_unique(*this, params); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_SMOLLM3: { @@ -17832,7 +17832,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_MINICPM3: case LLM_ARCH_DOTS1: case LLM_ARCH_HUNYUAN_MOE: - case LLM_ARCH_HUNYUAN_V1_DENSE: + case LLM_ARCH_HUNYUAN_DENSE: case LLM_ARCH_LFM2: return LLAMA_ROPE_TYPE_NEOX; From 0192c1297ddde77b0d9688179779e4218cf04db4 Mon Sep 17 00:00:00 2001 From: stevenkuang Date: Sat, 26 Jul 2025 03:28:06 +0800 Subject: [PATCH 07/10] fix hunyuan_moe chat template Signed-off-by: stevenkuang --- src/llama-chat.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llama-chat.cpp b/src/llama-chat.cpp index d9c0a6da134db..ec0f3dbdd483e 100644 --- a/src/llama-chat.cpp +++ b/src/llama-chat.cpp @@ -701,7 +701,7 @@ int32_t llm_chat_apply_template( if (role == "system") { ss << "<|startoftext|>" << message->content << "<|extra_4|>"; } else if (role == "assistant") { - ss << "<|startoftext|>" << message->content << "<|eos|>"; + ss << message->content << "<|eos|>"; } else { ss << "<|startoftext|>" << message->content << "<|extra_0|>"; } From 3ecc5d3ce5f2f83ffd528c5bc129276c480bd52d Mon Sep 17 00:00:00 2001 From: stevenkuang Date: Sun, 27 Jul 2025 01:08:40 +0800 Subject: [PATCH 08/10] remove leftover code Signed-off-by: stevenkuang --- convert_hf_to_gguf.py | 20 -------------------- src/llama-model.cpp | 2 +- 2 files changed, 1 insertion(+), 21 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 15cca2a03234f..2dad389d70bd9 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -7389,11 +7389,6 @@ def set_gguf_parameters(self): class HunYuanMoEModel(TextModel): model_arch = gguf.MODEL_ARCH.HUNYUAN_MOE - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - # For handling tied embeddings - self._tok_embd = None - def set_vocab(self): from transformers import AutoTokenizer tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True) @@ -7487,9 +7482,6 @@ def set_gguf_parameters(self): _experts: list[dict[str, Tensor]] | None = None def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: - if name == "model.embed_tokens.weight": - self._tok_embd = data_torch.clone() - if name == "lm_head.weight": if self.hparams.get("tie_word_embeddings", False): logger.info("Skipping tied output layer 'lm_head.weight'") @@ -7538,11 +7530,6 @@ def prepare_tensors(self): class HunYuanModel(TextModel): model_arch = gguf.MODEL_ARCH.HUNYUAN_DENSE - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - # For handling tied embeddings - self._tok_embd = None - def set_vocab(self): if (self.dir_model / "tokenizer.json").is_file(): self._set_vocab_gpt2() @@ -7602,8 +7589,6 @@ def set_gguf_parameters(self): super().set_gguf_parameters() hparams = self.hparams - self.gguf_writer.add_expert_shared_feed_forward_length(hparams["intermediate_size"]) - # Rope rope_scaling = hparams.get("rope_scaling", {}) if rope_scaling.get("type") == "dynamic": @@ -7624,12 +7609,7 @@ def set_gguf_parameters(self): assert base == 10000.0 and self.hparams["max_position_embeddings"] in [32 * 1024, 256 * 1024] , \ "HunYuan dynamic RoPE scaling assumptions changed, please update the logic or context length manually" - _experts: list[dict[str, Tensor]] | None = None - def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: - if name == "model.embed_tokens.weight": - self._tok_embd = data_torch.clone() - if name == "lm_head.weight": if self.hparams.get("tie_word_embeddings", False): logger.info("Skipping tied output layer 'lm_head.weight'") diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 75228565cfb6c..781f0a00bf81e 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1746,7 +1746,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { } break; case LLM_ARCH_HUNYUAN_DENSE: { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_embd) { case 1024: type = LLM_TYPE_0_5B; break; From 6c17323ad71e999cea1c51c00dc092f17af011f3 Mon Sep 17 00:00:00 2001 From: stevenkuang Date: Sun, 27 Jul 2025 01:09:32 +0800 Subject: [PATCH 09/10] update hunyuan dense chat template Signed-off-by: stevenkuang --- convert_hf_to_gguf.py | 1 + src/llama-chat.cpp | 4 +--- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 2dad389d70bd9..c1a192bd89dd4 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -7533,6 +7533,7 @@ class HunYuanModel(TextModel): def set_vocab(self): if (self.dir_model / "tokenizer.json").is_file(): self._set_vocab_gpt2() + self.gguf_writer.add_add_bos_token(True) else: from transformers import AutoTokenizer tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True) diff --git a/src/llama-chat.cpp b/src/llama-chat.cpp index ec0f3dbdd483e..7aac4db56376a 100644 --- a/src/llama-chat.cpp +++ b/src/llama-chat.cpp @@ -712,9 +712,7 @@ int32_t llm_chat_apply_template( std::string role(chat[i]->role); if (i == 0) { if (role == "system") { - ss << "<|hy_begin▁of▁sentence|>" << chat[i]->content << "<|hy_place▁holder▁no▁3|>"; - } else { - ss << "<|hy_begin▁of▁sentence|>"; + ss << chat[i]->content << "<|hy_place▁holder▁no▁3|>"; } } From 675f35d57f9b046db72075996cbfcc766c90b500 Mon Sep 17 00:00:00 2001 From: stevenkuang Date: Fri, 1 Aug 2025 00:00:05 +0800 Subject: [PATCH 10/10] fix hunyuan dense vocab and chat template Signed-off-by: stevenkuang --- convert_hf_to_gguf.py | 5 ++--- convert_hf_to_gguf_update.py | 2 +- src/llama-chat.cpp | 9 ++------- src/llama-vocab.cpp | 5 +++++ src/llama-vocab.h | 1 + 5 files changed, 11 insertions(+), 11 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index c1a192bd89dd4..f6ff543a1c5b1 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -685,8 +685,8 @@ def get_vocab_base_pre(self, tokenizer) -> str: # ref: https://huggingface.co/tencent/Hunyuan-A13B-Instruct res = "hunyuan" if chkhsh == "bba3b3366b646dbdded5dbc42d59598b849371afc42f7beafa914afaa5b70aa6": - # ref: https://huggingface.co/tencent/Hunyuan-4B - res = "hunyuan" + # ref: https://huggingface.co/tencent/Hunyuan-4B-Instruct + res = "hunyuan-dense" if chkhsh == "a6b57017d60e6edb4d88ecc2845188e0eb333a70357e45dcc9b53964a73bbae6": # ref: https://huggingface.co/tiiuae/Falcon-H1-0.5B-Base res = "falcon-h1" @@ -7533,7 +7533,6 @@ class HunYuanModel(TextModel): def set_vocab(self): if (self.dir_model / "tokenizer.json").is_file(): self._set_vocab_gpt2() - self.gguf_writer.add_add_bos_token(True) else: from transformers import AutoTokenizer tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True) diff --git a/convert_hf_to_gguf_update.py b/convert_hf_to_gguf_update.py index 461522b3d2a4c..c4904b53936f5 100755 --- a/convert_hf_to_gguf_update.py +++ b/convert_hf_to_gguf_update.py @@ -140,7 +140,7 @@ class TOKENIZER_TYPE(IntEnum): {"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-hf", "chkhsh": "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2"}, {"name": "minerva-7b", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0", "chkhsh": "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35"}, {"name": "hunyuan", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Hunyuan-A13B-Instruct", "chkhsh": "7e57df22b1fe23a7b1e1c7f3dc4e3f96d43a4eb0836d0c6bdc3436d7b2f1c664"}, - {"name": "hunyuan", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Hunyuan-4B", "chkhsh": "bba3b3366b646dbdded5dbc42d59598b849371afc42f7beafa914afaa5b70aa6"}, + {"name": "hunyuan-dense", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Hunyuan-4B-Instruct", "chkhsh": "bba3b3366b646dbdded5dbc42d59598b849371afc42f7beafa914afaa5b70aa6"}, # falcon-h1 series uses 4 different tokenizers across model sizes (0.5b - 34b), hence we need to define 4 different hashes {"name": "falcon-h1", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/Falcon-H1-0.5B-Base", "chkhsh": "a6b57017d60e6edb4d88ecc2845188e0eb333a70357e45dcc9b53964a73bbae6"}, {"name": "falcon-h1", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/Falcon-H1-1B-Base", "chkhsh": "60476e1243776c4fb1b993dbd7a5f15ac22f83c80afdf425fa5ae01c8d44ef86"}, diff --git a/src/llama-chat.cpp b/src/llama-chat.cpp index 7aac4db56376a..f869d3869a2e2 100644 --- a/src/llama-chat.cpp +++ b/src/llama-chat.cpp @@ -707,7 +707,7 @@ int32_t llm_chat_apply_template( } } } else if (tmpl == LLM_CHAT_TEMPLATE_HUNYUAN_DENSE) { - // tencent/Hunyuan-4B + // tencent/Hunyuan-4B-Instruct for (size_t i = 0; i < chat.size(); i++) { std::string role(chat[i]->role); if (i == 0) { @@ -719,14 +719,9 @@ int32_t llm_chat_apply_template( if (role == "assistant") { ss << "<|hy_Assistant|>" << chat[i]->content << "<|hy_place▁holder▁no▁2|>"; } else if (role == "user") { - ss << "<|hy_User|>" << chat[i]->content; + ss << "<|hy_User|>" << chat[i]->content << "<|hy_Assistant|>"; } } - if (add_ass) { - ss << "<|hy_Assistant|>"; - } else { - ss << "<|hy_place▁holder▁no▁8|>"; - } } else if (tmpl == LLM_CHAT_TEMPLATE_KIMI_K2) { // moonshotai/Kimi-K2-Instruct for (auto message : chat) { diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index e8bae645088dd..7b7a93566027a 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -307,6 +307,7 @@ struct llm_tokenizer_bpe : llm_tokenizer { }; break; case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM: + case LLAMA_VOCAB_PRE_TYPE_HUNYUAN_DENSE: regex_exprs = { "\\p{N}{1,3}", "[一-龥぀-ゟ゠-ヿ]+", @@ -1964,6 +1965,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "hunyuan") { pre_type = LLAMA_VOCAB_PRE_TYPE_HUNYUAN; clean_spaces = false; + } else if ( + tokenizer_pre == "hunyuan-dense") { + pre_type = LLAMA_VOCAB_PRE_TYPE_HUNYUAN_DENSE; + clean_spaces = false; } else if ( tokenizer_pre == "kimi-k2") { pre_type = LLAMA_VOCAB_PRE_TYPE_KIMI_K2; diff --git a/src/llama-vocab.h b/src/llama-vocab.h index 842b129e86171..61b8124216847 100644 --- a/src/llama-vocab.h +++ b/src/llama-vocab.h @@ -46,6 +46,7 @@ enum llama_vocab_pre_type { LLAMA_VOCAB_PRE_TYPE_SEED_CODER = 35, LLAMA_VOCAB_PRE_TYPE_HUNYUAN = 36, LLAMA_VOCAB_PRE_TYPE_KIMI_K2 = 37, + LLAMA_VOCAB_PRE_TYPE_HUNYUAN_DENSE = 38, }; struct LLM_KV;