
Commit 9b69c9c

[feat] added llama3 chat template support
1 parent 2aba2d3 commit 9b69c9c

File tree

- README.md
- src/chat_template/common_chat_template.cpp
- src/chat_template/common_chat_template.h
- src/models/huggingface/llama.h

4 files changed: +69, -20 lines changed

README.md

Lines changed: 8 additions & 8 deletions
@@ -57,7 +57,7 @@ ScaleLLM is a cutting-edge inference system engineered for large language models
 | GPT_NeoX | Yes | Yes | No | [EleutherAI/gpt-neox-20b](https://huggingface.co/EleutherAI/gpt-neox-20b) |
 | GPT2 | Yes | Yes | No | [gpt2](https://huggingface.co/gpt2)|
 | InternLM | Yes | Yes | Yes | [internlm/internlm-7b](https://huggingface.co/internlm/internlm-7b) |
-| Llama3/2 | Yes | Yes | Yes | [meta-llama/Meta-Llama-3-8B](https://huggingface.co/meta-llama/Meta-Llama-3-8B), [meta-llama/Llama-2-7b](https://huggingface.co/meta-llama/Llama-2-7b), [TheBloke/Llama-2-13B-chat-GPTQ](https://huggingface.co/TheBloke/Llama-2-13B-chat-GPTQ), [TheBloke/Llama-2-70B-AWQ](https://huggingface.co/TheBloke/Llama-2-70B-AWQ) |
+| Llama3/2 | Yes | Yes | Yes | [meta-llama/Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct), [meta-llama/Meta-Llama-3-8B](https://huggingface.co/meta-llama/Meta-Llama-3-8B), [meta-llama/Llama-2-7b](https://huggingface.co/meta-llama/Llama-2-7b) |
 | Mistral | Yes | Yes | Yes | [mistralai/Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1) |
 | MPT | Yes | Yes | Yes | [mosaicml/mpt-30b](https://huggingface.co/mosaicml/mpt-30b) |
 | Phi2 | Yes | Yes | No | [microsoft/phi-2](https://huggingface.co/microsoft/phi-2) |
@@ -89,7 +89,7 @@ Once you have Docker installed, you can run ScaleLLM Docker container with [late
 docker pull docker.io/vectorchai/scalellm:latest
 docker run -it --gpus=all --net=host --shm-size=1g \
   -v $HOME/.cache/huggingface/hub:/models \
-  -e HF_MODEL_ID=meta-llama/Meta-Llama-3-8B \
+  -e HF_MODEL_ID=meta-llama/Meta-Llama-3-8B-Instruct \
   -e DEVICE=cuda:0 \
   docker.io/vectorchai/scalellm:latest --logtostderr
 ```
@@ -100,7 +100,7 @@ This command starts the Docker container with GPU support and various configurat
 - `HF_MODEL_REVISION` specifies which Hugging Face model revision you want to run. By default, it is set to `"main"`.
 - `DEVICE` specifies the device on which this model should run. By default, it is set to `"auto"`, using all available GPUs. You can also specify specific GPUs by using `"cuda:0,cuda:1"`, or use CPU by using `"cpu"`.
 - `HF_MODEL_ALLOW_PATTERN` specifies which types of files are allowed to be downloaded. By default, it will be configured automatically based on tensor type. Only use this option if the default configuration is not working for you.
-- `HUGGING_FACE_HUB_TOKEN` specifies the token from [huggingface](https://huggingface.co/settings/tokens) for gated models.
+- `HUGGING_FACE_HUB_TOKEN` specifies the token from [huggingface](https://huggingface.co/settings/tokens) for gated models. `-e HUGGING_FACE_HUB_TOKEN=$HUGGING_FACE_HUB_TOKEN`
 
 > **Warning**<br />
 > * The docker image with tag '[latest](https://hub.docker.com/r/vectorchai/scalellm/tags)' could be changed to a new version upon new release. In order to use latest image, you may need to repull the image with specific tag.
@@ -155,7 +155,7 @@ Using Docker Compose is the easiest way to run ScaleLLM with all the services to
 
 ```bash
 curl https://raw.githubusercontent.com/vectorch-ai/ScaleLLM/main/scalellm.yml -sSf > scalellm_compose.yml
-HF_MODEL_ID=meta-llama/Meta-Llama-3-8B DEVICE=cuda docker compose -f ./scalellm_compose.yml up
+HF_MODEL_ID=meta-llama/Meta-Llama-3-8B-Instruct DEVICE=cuda docker compose -f ./scalellm_compose.yml up
 ```
 
 you will get following running services:
@@ -173,7 +173,7 @@ You can get chat completions with the following example:
 curl http://localhost:8080/v1/chat/completions \
   -H "Content-Type: application/json" \
   -d '{
-    "model": "meta-llama/Meta-Llama-3-8B",
+    "model": "meta-llama/Meta-Llama-3-8B-Instruct",
     "messages": [
       {
         "role": "system",
@@ -198,7 +198,7 @@ openai.api_base = "http://localhost:8080/v1"
 print("==== Available models ====")
 models = openai.Model.list()
 
-model = "meta-llama/Meta-Llama-3-8B"
+model = "meta-llama/Meta-Llama-3-8B-Instruct"
 
 completion = openai.ChatCompletion.create(
     model=model,
@@ -225,7 +225,7 @@ For regular completions, you can use this example:
 curl http://localhost:8080/v1/completions \
   -H "Content-Type: application/json" \
   -d '{
-    "model": "meta-llama/Meta-Llama-3-8B",
+    "model": "meta-llama/Meta-Llama-3-8B-Instruct",
     "prompt": "hello",
     "max_tokens": 32,
     "temperature": 0.7,
@@ -244,7 +244,7 @@ openai.api_base = "http://localhost:8080/v1"
 print("==== Available models ====")
 models = openai.Model.list()
 
-model = "meta-llama/Meta-Llama-3-8B"
+model = "meta-llama/Meta-Llama-3-8B-Instruct"
 
 completion = openai.Completion.create(
     model=model,

src/chat_template/common_chat_template.cpp

Lines changed: 32 additions & 0 deletions
@@ -42,4 +42,36 @@ std::optional<std::string> Llama2ChatTemplate::get_prompt(
   return ss.str();
 }
 
+// generate prompt from ChatTemplate
+std::optional<std::string> Llama3ChatTemplate::get_prompt(
+    const std::string_view& system_message,
+    const std::vector<std::string_view>& messages) const {
+  // at least one user message
+  if (messages.size() % 2 == 0) {
+    return std::nullopt;
+  }
+
+  std::stringstream ss;
+  ss << "<|begin_of_text|>";
+  auto add_message = [&ss](const std::string_view& role,
+                           const std::string_view& message) {
+    ss << "<|start_header_id|>" << role << "<|end_header_id|>\n\n";
+    ss << message << "<|eot_id|>";
+  };
+
+  // start with system message
+  if (!system_message.empty()) {
+    add_message("system", system_message);
+  }
+
+  // then user and assistant message pairs (u/a/u/a/u...)
+  for (size_t i = 0; i < messages.size(); ++i) {
+    const char* role = i % 2 == 0 ? "user" : "assistant";
+    add_message(role, messages[i]);
+  }
+  // end with assistant message
+  ss << "<|start_header_id|>assistant<|end_header_id|>\n\n";
+  return ss.str();
+}
+
 } // namespace llm
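
For reference, here is a standalone sketch (not part of the commit) of the prompt layout the new template produces for one system message and one user turn. It mirrors the logic above with a plain `std::stringstream`; the message strings are made up for illustration.

```cpp
// Illustrative only: reproduces the Llama 3 prompt layout built by
// Llama3ChatTemplate::get_prompt above for one system + one user message.
#include <iostream>
#include <sstream>
#include <string_view>
#include <vector>

int main() {
  const std::string_view system_message = "You are a helpful assistant.";
  const std::vector<std::string_view> messages = {"hello"};  // single user turn

  std::stringstream ss;
  ss << "<|begin_of_text|>";
  auto add_message = [&ss](std::string_view role, std::string_view message) {
    ss << "<|start_header_id|>" << role << "<|end_header_id|>\n\n"
       << message << "<|eot_id|>";
  };
  if (!system_message.empty()) {
    add_message("system", system_message);
  }
  for (size_t i = 0; i < messages.size(); ++i) {
    add_message(i % 2 == 0 ? "user" : "assistant", messages[i]);
  }
  // leave an open assistant header for the model to complete
  ss << "<|start_header_id|>assistant<|end_header_id|>\n\n";

  // Prints:
  // <|begin_of_text|><|start_header_id|>system<|end_header_id|>
  //
  // You are a helpful assistant.<|eot_id|><|start_header_id|>user<|end_header_id|>
  //
  // hello<|eot_id|><|start_header_id|>assistant<|end_header_id|>
  std::cout << ss.str() << std::endl;
  return 0;
}
```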

src/chat_template/common_chat_template.h

Lines changed: 8 additions & 0 deletions
@@ -18,4 +18,12 @@ class Llama2ChatTemplate final : public CodedChatTemplate {
       const std::vector<std::string_view>& messages) const override;
 };
 
+class Llama3ChatTemplate final : public CodedChatTemplate {
+ public:
+  // generate prompt from dialogs
+  std::optional<std::string> get_prompt(
+      const std::string_view& system_message,
+      const std::vector<std::string_view>& messages) const override;
+};
+
 } // namespace llm

src/models/huggingface/llama.h

Lines changed: 21 additions & 12 deletions
@@ -365,32 +365,41 @@ TORCH_MODULE(LlamaForCausalLM);
 
 // register the causal model
 REGISTER_CAUSAL_MODEL(llama, LlamaForCausalLM);
+REGISTER_CAUSAL_MODEL(llama3, LlamaForCausalLM);
+
 REGISTER_DEFAULT_CHAT_TEMPLATE(llama, Llama2ChatTemplate);
+REGISTER_DEFAULT_CHAT_TEMPLATE(llama3, Llama3ChatTemplate);
 // register the model args
 // example config:
-// https://huggingface.co/meta-llama/Llama-2-7b-hf/blob/main/config.json set
-// default values for args explicitly with values from:
-// https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/configuration_llama.py#L112
+// https://huggingface.co/meta-llama/Meta-Llama-3-70B-Instruct/blob/main/config.json
 REGISTER_MODEL_ARGS(llama, [&] {
   LOAD_ARG_OR(model_type, "model_type", "llama");
   LOAD_ARG_OR(dtype, "torch_dtype", "");
-  LOAD_ARG_OR(vocab_size, "vocab_size", 32000);
-  LOAD_ARG_OR(hidden_size, "hidden_size", 4096);
-  LOAD_ARG_OR(n_layers, "num_hidden_layers", 32);
-  LOAD_ARG_OR(n_heads, "num_attention_heads", 32);
+  LOAD_ARG_OR(vocab_size, "vocab_size", 128256);
+  LOAD_ARG_OR(hidden_size, "hidden_size", 8192);
+  LOAD_ARG_OR(n_layers, "num_hidden_layers", 80);
+  LOAD_ARG_OR(n_heads, "num_attention_heads", 64);
   LOAD_ARG(n_kv_heads, "num_key_value_heads");
-  LOAD_ARG_OR(intermediate_size, "intermediate_size", 11008);
+  LOAD_ARG_OR(intermediate_size, "intermediate_size", 28672);
   LOAD_ARG_OR(hidden_act, "hidden_act", "silu");
-  LOAD_ARG_OR(max_position_embeddings, "max_position_embeddings", 2048);
+  LOAD_ARG_OR(max_position_embeddings, "max_position_embeddings", 8192);
   LOAD_ARG_OR(rms_norm_eps, "rms_norm_eps", 1e-5);
-  LOAD_ARG_OR(bos_token_id, "bos_token_id", 1);
-  LOAD_ARG_OR(eos_token_id, "eos_token_id", 2);
-  LOAD_ARG_OR(rope_theta, "rope_theta", 10000.0f);
+  LOAD_ARG_OR(bos_token_id, "bos_token_id", 128000);
+  LOAD_ARG_OR(eos_token_id, "eos_token_id", 128001);
+  LOAD_ARG_OR(rope_theta, "rope_theta", 500000.0f);
   LOAD_ARG_OR(rope_scaling, "rope_scaling", 1.0f);
 
   LOAD_ARG_OR_FUNC(head_dim, "head_dim", [&] {
     return args->hidden_size() / args->n_heads();
   });
+
+  // decide model type based on vocab size
+  if (args->vocab_size() == 128256) {
+    // choose the right chat template
+    SET_ARG(model_type, "llama3");
+    // stop token ids: "<|end_of_text|>", "<|eot_id|>"
+    SET_ARG(stop_token_ids, std::unordered_set<int32_t>({128001, 128009}));
+  }
 });
 
 } // namespace llm::hf
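
The registration above keys off the vocabulary size because Llama 2 and Llama 3 checkpoints both report `"model_type": "llama"` in `config.json`. A minimal sketch of that dispatch follows; it is illustrative only, and the `LlamaConfig` struct is a hypothetical stand-in, not a ScaleLLM API.

```cpp
// Illustrative only: Llama 3 checkpoints keep "model_type": "llama" in
// config.json, so the loader distinguishes them by vocab size (128256)
// and then switches the chat template and stop tokens accordingly.
#include <cstdint>
#include <iostream>
#include <string>
#include <unordered_set>

// Hypothetical stand-in for the parsed model args.
struct LlamaConfig {
  int64_t vocab_size = 32000;  // Llama 2 default
  std::string model_type = "llama";
  std::unordered_set<int32_t> stop_token_ids;
};

int main() {
  LlamaConfig cfg;
  cfg.vocab_size = 128256;  // value read from a Llama 3 config.json

  if (cfg.vocab_size == 128256) {
    cfg.model_type = "llama3";
    // "<|end_of_text|>" and "<|eot_id|>"
    cfg.stop_token_ids = {128001, 128009};
  }

  std::cout << "model_type=" << cfg.model_type << std::endl;
  return 0;
}
```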
