This repository was archived by the owner on Sep 23, 2025. It is now read-only.

Commit 8698a17 (parent: 60d9f15)

Commit message: update

File tree: 8 files changed, +58 −177 lines

examples/inference/api_server_simple/query_single.py

Lines changed: 6 additions & 1 deletion
@@ -55,7 +55,12 @@
 )
 
 args = parser.parse_args()
-prompt = "Once upon a time,"
+# prompt = "Once upon a time,"
+prompt = [
+    {"role": "user", "content": "Which is bigger, the moon or the sun?"},
+]
+
+
 config: Dict[str, Union[int, float]] = {}
 if args.max_new_tokens:
     config["max_new_tokens"] = int(args.max_new_tokens)
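For reference, a minimal sketch of how the new chat-style prompt could be posted to the simple API server. The endpoint URL and the payload keys ("text", "config", "stream") are assumptions based on the rest of query_single.py, which this hunk does not show.

```python
# Hypothetical usage sketch; the endpoint URL and payload keys are assumptions,
# not confirmed by this diff.
import requests

prompt = [
    {"role": "user", "content": "Which is bigger, the moon or the sun?"},
]
config = {"max_new_tokens": 128}

response = requests.post(
    "http://localhost:8000/gpt2",  # assumed Ray Serve endpoint for a deployed model
    json={"text": prompt, "config": config, "stream": False},
)
print(response.text)
```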

llm_on_ray/finetune/finetune_config.py

Lines changed: 1 addition & 1 deletion
@@ -64,7 +64,7 @@ class General(BaseModel):
     enable_gradient_checkpointing: bool = False
     chat_template: Optional[str] = None
     default_chat_template: str = (
-        "{{ bos_token }}"
+        "Below is an instruction that describes a task. Write a response that appropriately completes the request."
         "{% if messages[0]['role'] == 'system' %}"
         "{{ raise_exception('System role not supported') }}"
         "{% endif %}"
Lines changed: 32 additions & 90 deletions
@@ -1,20 +1,4 @@
-#
-# Copyright 2023 The LLM-on-Ray Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
 from typing import List, Union
-
 from llm_on_ray.inference.api_openai_backend.openai_protocol import ChatMessage
 
 
@@ -23,57 +7,31 @@ def __init__(self, predictor) -> None:
         self.predictor = predictor
 
     def get_prompt(self, input: List, is_mllm=False):
-        """Generate response based on input."""
-        if self.predictor.infer_conf.model_description.chat_template is not None:
-            self.predictor.tokenizer.chat_template = (
-                self.predictor.infer_conf.model_description.chat_template
+        self.predictor.tokenizer.chat_template = (
+            self.predictor.infer_conf.model_description.chat_template
+            or self.predictor.tokenizer.chat_template
+            or self.predictor.infer_conf.model_description.default_chat_template
+        )
+
+        if isinstance(input, list) and input and isinstance(input[0], (ChatMessage, dict)):
+            messages = (
+                [dict(chat_message) for chat_message in input]
+                if isinstance(input[0], ChatMessage)
+                else input
             )
-        elif self.predictor.tokenizer.chat_template is None:
-            self.predictor.tokenizer.chat_template = (
-                self.predictor.infer_conf.model_description.default_chat_template
+            prompt = self.predictor.tokenizer.apply_chat_template(
+                messages, add_generation_prompt=True, tokenize=False
            )
-
-        if is_mllm:
-            if isinstance(input, List):
-                if isinstance(input, list) and input and isinstance(input[0], ChatMessage):
-                    messages = []
-                    for chat_message in input:
-                        message = {
-                            "role": chat_message.role,
-                            "content": chat_message.content,
-                        }
-                        messages.append(message)
-                    texts, images = self._extract_messages(messages)
-                elif isinstance(input, list) and input and isinstance(input[0], dict):
-                    texts, images = self._extract_messages(input)
-                elif isinstance(input, list) and input and isinstance(input[0], list):
-                    texts, images = [self._extract_messages(p) for p in input]
-
+            if is_mllm:
+                texts, images = self._extract_messages(messages)
                 image = self._prepare_image(images)
-                prompt = self.predictor.tokenizer.apply_chat_template(texts, tokenize=False)
-                return prompt, image
-        else:
-            if isinstance(input, list) and input and isinstance(input[0], dict):
-                prompt = self.predictor.tokenizer.apply_chat_template(input, tokenize=False)
-            elif isinstance(input, list) and input and isinstance(input[0], list):
-                prompt = [
-                    self.predictor.tokenizer.apply_chat_template(t, tokenize=False) for t in input
-                ]
-            elif isinstance(input, list) and input and isinstance(input[0], ChatMessage):
-                messages = []
-                for chat_message in input:
-                    message = {"role": chat_message.role, "content": chat_message.content}
-                    messages.append(message)
-                prompt = self.predictor.tokenizer.apply_chat_template(messages, tokenize=False)
-            elif isinstance(input, list) and input and isinstance(input[0], str):
-                prompt = input
-            elif isinstance(input, str):
-                prompt = input
-            else:
-                raise TypeError(
-                    f"Unsupported type {type(input)} for text. Expected dict or list of dicts."
+                prompt = self.predictor.tokenizer.apply_chat_template(
+                    texts, add_generation_prompt=True, tokenize=False
                 )
-        return prompt
+                return prompt, image
+            return prompt
+
+        raise TypeError(f"Unsupported type {type(input)} for text. Expected dict or list of dicts.")
 
     def _extract_messages(self, messages):
         texts, images = [], []
@@ -88,39 +46,23 @@ def _extract_messages(self, messages):
         return texts, images
 
     def _prepare_image(self, messages: list):
-        """Prepare image from history messages."""
         from PIL import Image
         import requests
         from io import BytesIO
         import base64
         import re
 
-        # prepare images
         images: List = []
-        if isinstance(messages[0], List):
-            for i in range(len(messages)):
-                for msg in messages[i]:
-                    msg = dict(msg)
-                    content = msg["content"]
-                    if "url" not in content:
-                        continue
-                    is_data = len(re.findall("^data:image/.+;base64,", content["url"])) > 0
-                    if is_data:
-                        encoded_str = re.sub("^data:image/.+;base64,", "", content["url"])
-                        images[i].append(Image.open(BytesIO(base64.b64decode(encoded_str))))
-                    else:
-                        images[i].append(Image.open(requests.get(content["url"], stream=True).raw))
-        elif isinstance(messages[0], dict):
-            for msg in messages:
-                msg = dict(msg)
-                content = msg["content"]
-                if "url" not in content:
-                    continue
-                is_data = len(re.findall("^data:image/.+;base64,", content["url"])) > 0
-                if is_data:
-                    encoded_str = re.sub("^data:image/.+;base64,", "", content["url"])
-                    images.append(Image.open(BytesIO(base64.b64decode(encoded_str))))
-                else:
-                    images.append(Image.open(requests.get(content["url"], stream=True).raw))
+        for msg in messages:
+            msg = dict(msg)
+            content = msg["content"]
+            if "url" not in content:
+                continue
+            is_data = len(re.findall("^data:image/.+;base64,", content["url"])) > 0
+            if is_data:
+                encoded_str = re.sub("^data:image/.+;base64,", "", content["url"])
+                images.append(Image.open(BytesIO(base64.b64decode(encoded_str))))
+            else:
+                images.append(Image.open(requests.get(content["url"], stream=True).raw))
 
         return images
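A standalone sketch of the template-resolution order the refactored get_prompt() now uses, written against transformers directly. The template strings and the model name below are illustrative stand-ins for infer_conf.model_description, not values from this repository.

```python
# Illustrative sketch of the chat_template fallback chain; the template strings
# and the model name are example values, not taken from this commit.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # example tokenizer only

configured_template = None  # stands in for model_description.chat_template
default_template = (
    "{% for m in messages %}{{ m['role'] }}: {{ m['content'] }}\n{% endfor %}"
    "{% if add_generation_prompt %}assistant: {% endif %}"
)  # stands in for model_description.default_chat_template

# Same precedence as the new get_prompt(): an explicitly configured template
# wins, then the tokenizer's own template, then the configured default.
tokenizer.chat_template = (
    configured_template or tokenizer.chat_template or default_template
)

prompt = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Which is bigger, the moon or the sun?"}],
    add_generation_prompt=True,
    tokenize=False,
)
print(prompt)
# user: Which is bigger, the moon or the sun?
# assistant:
```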

llm_on_ray/inference/inference_config.py

Lines changed: 9 additions & 7 deletions
@@ -115,20 +115,22 @@ class ModelDescription(BaseModel):
     chat_model_with_image: bool = False
     chat_template: Union[str, None] = None
     default_chat_template: str = (
-        "{{ bos_token }}"
+        "Below is an instruction that describes a task. Write a response that appropriately completes the request."
         "{% if messages[0]['role'] == 'system' %}"
-        "{{ raise_exception('System role not supported') }}"
-        "{% endif %}"
-        "{% for message in messages %}"
+        "{% set loop_messages = messages[1:] %}"
+        "{% set system_message = messages[0]['content'] %}"
+        "{% else %}{% set loop_messages = messages %}"
+        "{% set system_message = false %}{% endif %}"
+        "{% for message in loop_messages %}"
         "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}"
         "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}"
         "{% endif %}"
         "{% if message['role'] == 'user' %}"
-        "{{ '### Instruction: ' + message['content'] + eos_token }}"
+        "{{ '### Instruction: ' + message['content'].strip() }}"
         "{% elif message['role'] == 'assistant' %}"
-        "{{ '### Response:' + message['content'] + eos_token }}"
+        "{{ '### Response:' + message['content'].strip() }}"
         "{% endif %}{% endfor %}"
-        "{{'### End \n'}}"
+        "{% if add_generation_prompt %}{{'### Response:\n'}}{% endif %}"
     )
 
     @validator("quantization_type")
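To see what the revised default template produces, here is a small check that renders it through transformers' chat templating (which supplies raise_exception). The template string is copied from the hunk above; the gpt2 tokenizer is used only as a convenient Jinja renderer, and the model choice is arbitrary.

```python
# Renders the new default_chat_template for a sample message; the tokenizer is
# only used as a template renderer here.
from transformers import AutoTokenizer

default_chat_template = (
    "Below is an instruction that describes a task. Write a response that appropriately completes the request."
    "{% if messages[0]['role'] == 'system' %}"
    "{% set loop_messages = messages[1:] %}"
    "{% set system_message = messages[0]['content'] %}"
    "{% else %}{% set loop_messages = messages %}"
    "{% set system_message = false %}{% endif %}"
    "{% for message in loop_messages %}"
    "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}"
    "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}"
    "{% endif %}"
    "{% if message['role'] == 'user' %}"
    "{{ '### Instruction: ' + message['content'].strip() }}"
    "{% elif message['role'] == 'assistant' %}"
    "{{ '### Response:' + message['content'].strip() }}"
    "{% endif %}{% endfor %}"
    "{% if add_generation_prompt %}{{'### Response:\n'}}{% endif %}"
)

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.chat_template = default_chat_template

text = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Which is bigger, the moon or the sun?"}],
    add_generation_prompt=True,
    tokenize=False,
)
print(text)
# One string (wrapped here for readability):
# Below is an instruction that describes a task. Write a response that
# appropriately completes the request.### Instruction: Which is bigger,
# the moon or the sun?### Response:
```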

llm_on_ray/inference/models/gemma-2b.yaml

Lines changed: 1 addition & 0 deletions
@@ -15,3 +15,4 @@ model_description:
   tokenizer_name_or_path: google/gemma-2b
   config:
     use_auth_token: ' '
+  chat_template: "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}"

llm_on_ray/inference/models/gpt2.yaml

Lines changed: 1 addition & 0 deletions
@@ -14,3 +14,4 @@ model_description:
   model_id_or_path: gpt2
   tokenizer_name_or_path: gpt2
   gpt_base_model: true
+  chat_template: "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + message['content'].strip() }}{% elif message['role'] == 'assistant' %}{{ message['content'].strip() + eos_token }}{% endif %}{% endfor %}"

llm_on_ray/inference/predictor_deployment.py

Lines changed: 6 additions & 76 deletions
@@ -311,7 +311,7 @@ def preprocess_prompts(self, input: Union[str, list], tools=None, tool_choice=No
         Raises:
             HTTPException: If the input prompt format is invalid or not supported.
         """
-
+        print("preprocess_prompts")
         if isinstance(input, str):
             return input
         elif isinstance(input, list):
@@ -344,31 +344,6 @@ def preprocess_prompts(self, input: Union[str, list], tools=None, tool_choice=No
                 else:
                     prompt = self.process_tool.get_prompt(input)
                 return prompt
-            else:
-                if isinstance(input, list) and input and isinstance(input[0], dict):
-                    prompt = self.predictor.tokenizer.apply_chat_template(input, tokenize=False)
-                elif isinstance(input, list) and input and isinstance(input[0], list):
-                    prompt = [
-                        self.predictor.tokenizer.apply_chat_template(t, tokenize=False)
-                        for t in input
-                    ]
-                elif isinstance(input, list) and input and isinstance(input[0], ChatMessage):
-                    messages = []
-                    for chat_message in input:
-                        message = {"role": chat_message.role, "content": chat_message.content}
-                        messages.append(message)
-                    prompt = self.predictor.tokenizer.apply_chat_template(
-                        messages, tokenize=False
-                    )
-                elif isinstance(input, list) and input and isinstance(input[0], str):
-                    prompt = input
-                elif isinstance(input, str):
-                    prompt = input
-                else:
-                    raise TypeError(
-                        f"Unsupported type {type(input)} for text. Expected dict or list of dicts."
-                    )
-                return prompt
             elif prompt_format == PromptFormat.PROMPTS_FORMAT:
                 raise HTTPException(400, "Invalid prompt format.")
         return input
@@ -414,63 +389,18 @@ async def openai_call(
         tool_choice=None,
     ):
         self.use_openai = True
+        print("openai_call")
+        print(input)
+        print(type(input))
 
         # return prompt or list of prompts preprocessed
         prompts = self.preprocess_prompts(input, tools, tool_choice)
+        print(prompts)
+        print(type(prompts))
 
         # Handle streaming response
         if streaming_response:
             async for result in self.handle_streaming(prompts, config):
                 yield result
         else:
             yield await self.handle_non_streaming(prompts, config)
-
-    def _extract_messages(self, messages):
-        texts, images = [], []
-        for message in messages:
-            if message["role"] == "user" and isinstance(message["content"], list):
-                texts.append({"role": "user", "content": message["content"][0]["text"]})
-                images.append(
-                    {"role": "user", "content": message["content"][1]["image_url"]["url"]}
-                )
-            else:
-                texts.append(message)
-        return texts, images
-
-    def _prepare_image(self, messages: list):
-        """Prepare image from history messages."""
-        from PIL import Image
-        import requests
-        from io import BytesIO
-        import base64
-        import re
-
-        # prepare images
-        images: List = []
-        if isinstance(messages[0], List):
-            for i in range(len(messages)):
-                for msg in messages[i]:
-                    msg = dict(msg)
-                    content = msg["content"]
-                    if "url" not in content:
-                        continue
-                    is_data = len(re.findall("^data:image/.+;base64,", content["url"])) > 0
-                    if is_data:
-                        encoded_str = re.sub("^data:image/.+;base64,", "", content["url"])
-                        images[i].append(Image.open(BytesIO(base64.b64decode(encoded_str))))
-                    else:
-                        images[i].append(Image.open(requests.get(content["url"], stream=True).raw))
-        elif isinstance(messages[0], dict):
-            for msg in messages:
-                msg = dict(msg)
-                content = msg["content"]
-                if "url" not in content:
-                    continue
-                is_data = len(re.findall("^data:image/.+;base64,", content["url"])) > 0
-                if is_data:
-                    encoded_str = re.sub("^data:image/.+;base64,", "", content["url"])
-                    images.append(Image.open(BytesIO(base64.b64decode(encoded_str))))
-                else:
-                    images.append(Image.open(requests.get(content["url"], stream=True).raw))
-
-        return images

llm_on_ray/inference/utils.py

Lines changed: 2 additions & 2 deletions
@@ -162,13 +162,13 @@ def is_cpu_without_ipex(infer_conf: InferenceConfig) -> bool:
     return (not infer_conf.ipex.enabled) and infer_conf.device == DEVICE_CPU
 
 
-def get_prompt_format(input: Union[List[str], List[dict], List[List[dict]], List[ChatMessage]]):
+def get_prompt_format(input: Union[List[str], List[dict], List[ChatMessage]]):
     chat_format = True
     prompts_format = True
     for item in input:
         if isinstance(item, str):
             chat_format = False
-        elif isinstance(item, dict) or isinstance(item, ChatMessage) or isinstance(item, list):
+        elif isinstance(item, dict) or isinstance(item, ChatMessage):
             prompts_format = False
         else:
             chat_format = False
