1- #
2- # Copyright 2023 The LLM-on-Ray Authors.
3- #
4- # Licensed under the Apache License, Version 2.0 (the "License");
5- # you may not use this file except in compliance with the License.
6- # You may obtain a copy of the License at
7- #
8- # http://www.apache.org/licenses/LICENSE-2.0
9- #
10- # Unless required by applicable law or agreed to in writing, software
11- # distributed under the License is distributed on an "AS IS" BASIS,
12- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13- # See the License for the specific language governing permissions and
14- # limitations under the License.
15- #
161from typing import List , Union
17-
182from llm_on_ray .inference .api_openai_backend .openai_protocol import ChatMessage
193
204
@@ -23,57 +7,31 @@ def __init__(self, predictor) -> None:
237 self .predictor = predictor
248
259 def get_prompt (self , input : List , is_mllm = False ):
26- """Generate response based on input."""
27- if self .predictor .infer_conf .model_description .chat_template is not None :
28- self .predictor .tokenizer .chat_template = (
29- self .predictor .infer_conf .model_description .chat_template
10+ self .predictor .tokenizer .chat_template = (
11+ self .predictor .infer_conf .model_description .chat_template
12+ or self .predictor .tokenizer .chat_template
13+ or self .predictor .infer_conf .model_description .default_chat_template
14+ )
15+
16+ if isinstance (input , list ) and input and isinstance (input [0 ], (ChatMessage , dict )):
17+ messages = (
18+ [dict (chat_message ) for chat_message in input ]
19+ if isinstance (input [0 ], ChatMessage )
20+ else input
3021 )
31- elif self .predictor .tokenizer .chat_template is None :
32- self .predictor .tokenizer .chat_template = (
33- self .predictor .infer_conf .model_description .default_chat_template
22+ prompt = self .predictor .tokenizer .apply_chat_template (
23+ messages , add_generation_prompt = True , tokenize = False
3424 )
35-
36- if is_mllm :
37- if isinstance (input , List ):
38- if isinstance (input , list ) and input and isinstance (input [0 ], ChatMessage ):
39- messages = []
40- for chat_message in input :
41- message = {
42- "role" : chat_message .role ,
43- "content" : chat_message .content ,
44- }
45- messages .append (message )
46- texts , images = self ._extract_messages (messages )
47- elif isinstance (input , list ) and input and isinstance (input [0 ], dict ):
48- texts , images = self ._extract_messages (input )
49- elif isinstance (input , list ) and input and isinstance (input [0 ], list ):
50- texts , images = [self ._extract_messages (p ) for p in input ]
51-
25+ if is_mllm :
26+ texts , images = self ._extract_messages (messages )
5227 image = self ._prepare_image (images )
53- prompt = self .predictor .tokenizer .apply_chat_template (texts , tokenize = False )
54- return prompt , image
55- else :
56- if isinstance (input , list ) and input and isinstance (input [0 ], dict ):
57- prompt = self .predictor .tokenizer .apply_chat_template (input , tokenize = False )
58- elif isinstance (input , list ) and input and isinstance (input [0 ], list ):
59- prompt = [
60- self .predictor .tokenizer .apply_chat_template (t , tokenize = False ) for t in input
61- ]
62- elif isinstance (input , list ) and input and isinstance (input [0 ], ChatMessage ):
63- messages = []
64- for chat_message in input :
65- message = {"role" : chat_message .role , "content" : chat_message .content }
66- messages .append (message )
67- prompt = self .predictor .tokenizer .apply_chat_template (messages , tokenize = False )
68- elif isinstance (input , list ) and input and isinstance (input [0 ], str ):
69- prompt = input
70- elif isinstance (input , str ):
71- prompt = input
72- else :
73- raise TypeError (
74- f"Unsupported type { type (input )} for text. Expected dict or list of dicts."
28+ prompt = self .predictor .tokenizer .apply_chat_template (
29+ texts , add_generation_prompt = True , tokenize = False
7530 )
76- return prompt
31+ return prompt , image
32+ return prompt
33+
34+ raise TypeError (f"Unsupported type { type (input )} for text. Expected dict or list of dicts." )
7735
7836 def _extract_messages (self , messages ):
7937 texts , images = [], []
@@ -88,39 +46,23 @@ def _extract_messages(self, messages):
8846 return texts , images
8947
9048 def _prepare_image (self , messages : list ):
91- """Prepare image from history messages."""
9249 from PIL import Image
9350 import requests
9451 from io import BytesIO
9552 import base64
9653 import re
9754
98- # prepare images
9955 images : List = []
100- if isinstance (messages [0 ], List ):
101- for i in range (len (messages )):
102- for msg in messages [i ]:
103- msg = dict (msg )
104- content = msg ["content" ]
105- if "url" not in content :
106- continue
107- is_data = len (re .findall ("^data:image/.+;base64," , content ["url" ])) > 0
108- if is_data :
109- encoded_str = re .sub ("^data:image/.+;base64," , "" , content ["url" ])
110- images [i ].append (Image .open (BytesIO (base64 .b64decode (encoded_str ))))
111- else :
112- images [i ].append (Image .open (requests .get (content ["url" ], stream = True ).raw ))
113- elif isinstance (messages [0 ], dict ):
114- for msg in messages :
115- msg = dict (msg )
116- content = msg ["content" ]
117- if "url" not in content :
118- continue
119- is_data = len (re .findall ("^data:image/.+;base64," , content ["url" ])) > 0
120- if is_data :
121- encoded_str = re .sub ("^data:image/.+;base64," , "" , content ["url" ])
122- images .append (Image .open (BytesIO (base64 .b64decode (encoded_str ))))
123- else :
124- images .append (Image .open (requests .get (content ["url" ], stream = True ).raw ))
56+ for msg in messages :
57+ msg = dict (msg )
58+ content = msg ["content" ]
59+ if "url" not in content :
60+ continue
61+ is_data = len (re .findall ("^data:image/.+;base64," , content ["url" ])) > 0
62+ if is_data :
63+ encoded_str = re .sub ("^data:image/.+;base64," , "" , content ["url" ])
64+ images .append (Image .open (BytesIO (base64 .b64decode (encoded_str ))))
65+ else :
66+ images .append (Image .open (requests .get (content ["url" ], stream = True ).raw ))
12567
12668 return images
0 commit comments