Commit 450e1c1

[refactor] merge Yi into llama model
1 parent 9b69c9c commit 450e1c1

File tree: 5 files changed, +41 -435 lines

src/model_loader/model_loader.cpp

Lines changed: 0 additions & 4 deletions
@@ -203,10 +203,6 @@ bool HFModelLoader::load_model_args(const std::string& model_weights_path) {
     return false;
   }
 
-  // override model type from gflag if exists
-  if (!FLAGS_model_type.empty()) {
-    model_type = FLAGS_model_type;
-  }
   auto args_loader = ModelRegistry::get_model_args_loader(model_type);
   if (args_loader == nullptr) {
     LOG(ERROR) << "Failed to find model args loader for model type "
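
The deleted lines let a non-empty --model_type gflag replace the type parsed from the checkpoint's config, so the registry lookup below now sees only the detected type. For reference, a standalone sketch of the removed override pattern, assuming the standard gflags setup (the main() harness and the "llama" fallback are illustrative, not from the codebase):

#include <gflags/gflags.h>
#include <iostream>
#include <string>

// Same flag the deleted code read; the definition is assumed here.
DEFINE_string(model_type, "", "optional override for the detected model type");

int main(int argc, char* argv[]) {
  gflags::ParseCommandLineFlags(&argc, &argv, /*remove_flags=*/true);

  // Pretend this came out of the model's config.json.
  std::string model_type = "llama";

  // The override this commit removes: a non-empty --model_type used to
  // replace whatever the checkpoint declared.
  if (!FLAGS_model_type.empty()) {
    model_type = FLAGS_model_type;
  }
  std::cout << "model_type = " << model_type << std::endl;
  return 0;
}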

src/models/huggingface/llama.h

Lines changed: 39 additions & 0 deletions
@@ -363,12 +363,45 @@ class LlamaForCausalLMImpl : public torch::nn::Module {
 };
 TORCH_MODULE(LlamaForCausalLM);
 
+class YiChatTemplate final : public CodedChatTemplate {
+ public:
+  // generate prompt from dialogs
+  // https://huggingface.co/01-ai/Yi-34B-Chat/blob/main/tokenizer_config.json#L60
+  // Prompt template:
+  // <|im_start|>user\n {message} <|im_end|>\n
+  // <|im_start|>assistant\n
+  std::optional<std::string> get_prompt(
+      const std::string_view& system_message,
+      const std::vector<std::string_view>& messages) const override {
+    // at least one user message
+    if (messages.size() % 2 == 0) {
+      return std::nullopt;
+    }
+
+    std::stringstream ss;
+    if (!system_message.empty()) {
+      ss << "<|im_start|>system\n" << system_message << "<|im_end|>\n";
+    }
+
+    // then user and assistant message pairs (u/a/u/a/u...)
+    for (size_t i = 0; i < messages.size(); ++i) {
+      const char* role = (i % 2) == 0 ? "user" : "assistant";
+      ss << "<|im_start|>" << role << "\n" << messages[i] << "<|im_end|>\n";
+    }
+    // end with assistant message
+    ss << "<|im_start|>assistant\n";
+    return ss.str();
+  }
+};
+
 // register the causal model
 REGISTER_CAUSAL_MODEL(llama, LlamaForCausalLM);
 REGISTER_CAUSAL_MODEL(llama3, LlamaForCausalLM);
+REGISTER_CAUSAL_MODEL(Yi, LlamaForCausalLM);
 
 REGISTER_DEFAULT_CHAT_TEMPLATE(llama, Llama2ChatTemplate);
 REGISTER_DEFAULT_CHAT_TEMPLATE(llama3, Llama3ChatTemplate);
+REGISTER_DEFAULT_CHAT_TEMPLATE(Yi, YiChatTemplate);
 // register the model args
 // example config:
 // https://huggingface.co/meta-llama/Meta-Llama-3-70B-Instruct/blob/main/config.json
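
To make the new template concrete, here is a minimal standalone sketch of the same ChatML-style rendering; yi_prompt is a hypothetical free-function stand-in for YiChatTemplate::get_prompt, since CodedChatTemplate and the registration macros live elsewhere in the codebase:

#include <iostream>
#include <optional>
#include <sstream>
#include <string>
#include <string_view>
#include <vector>

std::optional<std::string> yi_prompt(
    std::string_view system_message,
    const std::vector<std::string_view>& messages) {
  // turns alternate user/assistant and must end on a user turn,
  // so an even message count is rejected
  if (messages.size() % 2 == 0) {
    return std::nullopt;
  }

  std::stringstream ss;
  if (!system_message.empty()) {
    ss << "<|im_start|>system\n" << system_message << "<|im_end|>\n";
  }
  for (size_t i = 0; i < messages.size(); ++i) {
    const char* role = (i % 2) == 0 ? "user" : "assistant";
    ss << "<|im_start|>" << role << "\n" << messages[i] << "<|im_end|>\n";
  }
  // leave the assistant turn open for the model to complete
  ss << "<|im_start|>assistant\n";
  return ss.str();
}

int main() {
  auto prompt = yi_prompt("You are a helpful assistant.",
                          {"What is the capital of France?"});
  if (prompt) {
    std::cout << *prompt;
    // <|im_start|>system
    // You are a helpful assistant.<|im_end|>
    // <|im_start|>user
    // What is the capital of France?<|im_end|>
    // <|im_start|>assistant
  }
  return 0;
}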
@@ -399,6 +432,12 @@ REGISTER_MODEL_ARGS(llama, [&] {
     SET_ARG(model_type, "llama3");
     // stop token ids: "<|end_of_text|>", "<|eot_id|>"
     SET_ARG(stop_token_ids, std::unordered_set<int32_t>({128001, 128009}));
+  } else if (args->vocab_size() == 64000) {
+    // choose the right chat template
+    SET_ARG(model_type, "Yi");
+    // stop token ids: "<|endoftext|>", "<|im_start|>", "<|im_end|>",
+    // "<|im_sep|>"
+    SET_ARG(stop_token_ids, std::unordered_set<int32_t>({2, 6, 7, 8}));
   }
 });

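The new branch keys off vocab_size() because Yi checkpoints declare model_type "llama" in config.json, so vocabulary size is what tells the variants apart and picks the matching chat template and stop tokens. A minimal sketch of that dispatch, where ModelArgs is a hypothetical stand-in for the real args struct and the llama3 condition (which sits above this hunk) is assumed to be vocab_size == 128256, Llama 3's vocabulary size:

#include <cstdint>
#include <iostream>
#include <string>
#include <unordered_set>

// Hypothetical stand-in for the real model args object.
struct ModelArgs {
  int64_t vocab_size = 0;
  std::string model_type = "llama";
  std::unordered_set<int32_t> stop_token_ids;
};

void resolve_llama_variant(ModelArgs& args) {
  if (args.vocab_size == 128256) {  // assumed condition; Llama 3 vocab
    args.model_type = "llama3";
    args.stop_token_ids = {128001, 128009};  // "<|end_of_text|>", "<|eot_id|>"
  } else if (args.vocab_size == 64000) {  // matches the hunk above
    args.model_type = "Yi";
    // "<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|im_sep|>"
    args.stop_token_ids = {2, 6, 7, 8};
  }
}

int main() {
  ModelArgs args;
  args.vocab_size = 64000;  // e.g. 01-ai/Yi-34B-Chat
  resolve_llama_variant(args);
  std::cout << args.model_type << std::endl;  // prints: Yi
  return 0;
}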