@@ -363,12 +363,45 @@ class LlamaForCausalLMImpl : public torch::nn::Module {
 };
 TORCH_MODULE(LlamaForCausalLM);
 
+class YiChatTemplate final : public CodedChatTemplate {
+ public:
+  // generate prompt from dialogs
+  // https://huggingface.co/01-ai/Yi-34B-Chat/blob/main/tokenizer_config.json#L60
+  // Prompt template:
+  // <|im_start|>user\n {message} <|im_end|>\n
+  // <|im_start|>assistant\n
+  std::optional<std::string> get_prompt(
+      const std::string_view& system_message,
+      const std::vector<std::string_view>& messages) const override {
+    // at least one user message
+    if (messages.size() % 2 == 0) {
+      return std::nullopt;
+    }
+
+    std::stringstream ss;
+    if (!system_message.empty()) {
+      ss << "<|im_start|>system\n" << system_message << "<|im_end|>\n";
+    }
+
+    // then user and assistant message pairs (u/a/u/a/u...)
+    for (size_t i = 0; i < messages.size(); ++i) {
+      const char* role = (i % 2) == 0 ? "user" : "assistant";
+      ss << "<|im_start|>" << role << "\n" << messages[i] << "<|im_end|>\n";
+    }
+    // end with assistant message
+    ss << "<|im_start|>assistant\n";
+    return ss.str();
+  }
+};
+
 // register the causal model
 REGISTER_CAUSAL_MODEL(llama, LlamaForCausalLM);
 REGISTER_CAUSAL_MODEL(llama3, LlamaForCausalLM);
+REGISTER_CAUSAL_MODEL(Yi, LlamaForCausalLM);
 
 REGISTER_DEFAULT_CHAT_TEMPLATE(llama, Llama2ChatTemplate);
 REGISTER_DEFAULT_CHAT_TEMPLATE(llama3, Llama3ChatTemplate);
+REGISTER_DEFAULT_CHAT_TEMPLATE(Yi, YiChatTemplate);
 // register the model args
 // example config:
 // https://huggingface.co/meta-llama/Meta-Llama-3-70B-Instruct/blob/main/config.json
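The added `YiChatTemplate` serializes a dialog into the ChatML-style format Yi's chat models were tuned on. Below is a minimal standalone sketch of what it produces, assuming nothing beyond the standard library; the free function `yi_prompt` and the dialog in `main` are hypothetical stand-ins for the `CodedChatTemplate` plumbing above.

```cpp
#include <iostream>
#include <optional>
#include <sstream>
#include <string>
#include <string_view>
#include <vector>

// Standalone re-implementation of YiChatTemplate::get_prompt for
// illustration; the real code lives behind the CodedChatTemplate interface.
std::optional<std::string> yi_prompt(
    std::string_view system_message,
    const std::vector<std::string_view>& messages) {
  // a valid dialog alternates u/a/u/... and ends on a user turn,
  // so the message count must be odd
  if (messages.size() % 2 == 0) {
    return std::nullopt;
  }
  std::stringstream ss;
  if (!system_message.empty()) {
    ss << "<|im_start|>system\n" << system_message << "<|im_end|>\n";
  }
  for (size_t i = 0; i < messages.size(); ++i) {
    const char* role = (i % 2) == 0 ? "user" : "assistant";
    ss << "<|im_start|>" << role << "\n" << messages[i] << "<|im_end|>\n";
  }
  // leave the prompt open for the model's reply
  ss << "<|im_start|>assistant\n";
  return ss.str();
}

int main() {
  // hypothetical two-turn dialog
  auto prompt =
      yi_prompt("", {"hi", "Hello! How can I help?", "tell me a joke"});
  if (prompt) {
    std::cout << *prompt;
  }
  // prints:
  //   <|im_start|>user
  //   hi<|im_end|>
  //   <|im_start|>assistant
  //   Hello! How can I help?<|im_end|>
  //   <|im_start|>user
  //   tell me a joke<|im_end|>
  //   <|im_start|>assistant
}
```

The parity check rejects dialogs that do not end on a user turn, and the trailing `<|im_start|>assistant\n` leaves the prompt open for the model to complete.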
@@ -399,6 +432,12 @@ REGISTER_MODEL_ARGS(llama, [&] {
     SET_ARG(model_type, "llama3");
     // stop token ids: "<|end_of_text|>", "<|eot_id|>"
     SET_ARG(stop_token_ids, std::unordered_set<int32_t>({128001, 128009}));
+  } else if (args->vocab_size() == 64000) {
+    // choose the right chat template
+    SET_ARG(model_type, "Yi");
+    // stop token ids: "<|endoftext|>", "<|im_start|>", "<|im_end|>",
+    // "<|im_sep|>"
+    SET_ARG(stop_token_ids, std::unordered_set<int32_t>({2, 6, 7, 8}));
   }
 });
 
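A note on this second hunk: Yi checkpoints reuse the Llama architecture and report `model_type: "llama"` in their `config.json`, so the args handler cannot tell the families apart by name and falls back to the vocabulary size as a fingerprint. Here is a sketch of that dispatch as a plain function; the Llama 3 branch (128256-token vocabulary) and the Llama 2 fallback (`</s>` = id 2) are assumptions about the branches elided from this hunk.

```cpp
#include <cstdint>
#include <string>
#include <unordered_set>

struct VariantDefaults {
  std::string model_type;
  std::unordered_set<int32_t> stop_token_ids;
};

// Hypothetical standalone mirror of the vocab-size dispatch inside
// REGISTER_MODEL_ARGS(llama, ...).
VariantDefaults detect_llama_variant(int64_t vocab_size) {
  if (vocab_size == 128256) {
    // assumed Llama 3 branch (not shown in the hunk above):
    // "<|end_of_text|>" = 128001, "<|eot_id|>" = 128009
    return {"llama3", {128001, 128009}};
  }
  if (vocab_size == 64000) {
    // Yi: "<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|im_sep|>"
    return {"Yi", {2, 6, 7, 8}};
  }
  // assumed fallback: plain Llama 2 with "</s>" = 2
  return {"llama", {2}};
}
```

Getting the stop set right is the point of the override: left on the Llama defaults, a Yi checkpoint would generate straight through `<|im_end|>` (id 7) instead of ending the assistant turn.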