1
1
from typing import Iterable , List
2
2
import os
3
3
import time
4
+ import base64
5
+ import io
4
6
5
7
from emd .models .utils .constants import ModelType ,ServiceType
6
8
12
14
from threading import Thread
13
15
import json
14
16
from transformers import AutoModel
17
+ from PIL import Image
15
18
16
19
17
20
# Module-level logger for this engine implementation.
logger = get_logger(__name__)
@@ -32,6 +35,7 @@ def __init__(self,*args,**kwargs):
32
35
self .proc = None
33
36
self .model = None
34
37
self .pretrained_model_init_kwargs = self .execute_model .executable_config .current_engine .pretrained_model_init_kwargs or {}
38
+ self .is_bge_vl = "bge-vl" in self .model_id .lower ()
35
39
36
40
37
41
def start (self ):
@@ -61,6 +65,15 @@ def start(self):
61
65
device_map = "cuda" ,
62
66
** self .pretrained_model_init_kwargs
63
67
)
68
+
69
+ # BGE-VL specific initialization
70
+ if self .is_bge_vl :
71
+ try :
72
+ self .model .set_processor (model_abs_path )
73
+ logger .info (f"BGE-VL processor set successfully for model: { self .model_id } " )
74
+ except Exception as e :
75
+ logger .warning (f"Failed to set BGE-VL processor: { e } " )
76
+
64
77
logger .info (f"model: { self .model } " )
65
78
# TODO add tokenizer init args from model's definition
66
79
# self.tokenizer = AutoTokenizer.from_pretrained(
@@ -87,16 +100,139 @@ def format_openai_response(self,responses:list[list[float]]):
87
100
}
88
101
}
89
102
103
def _process_base64_image(self, image_data: str) -> Image.Image:
    """Decode a base64-encoded image string into an RGB PIL Image.

    Accepts either a raw base64 payload or a data URL of the form
    ``data:image/...;base64,<payload>``.

    Raises:
        ValueError: if the payload cannot be decoded or opened as an image.
    """
    try:
        # Strip the "data:image/...;base64," prefix when present.
        payload = image_data.split(',')[1] if image_data.startswith('data:image') else image_data
        decoded = base64.b64decode(payload)
        pil_img = Image.open(io.BytesIO(decoded))
        # Normalize to RGB so downstream encoders always see one mode.
        return pil_img if pil_img.mode == 'RGB' else pil_img.convert('RGB')
    except Exception as e:
        logger.error(f"Failed to process base64 image: {e}")
        raise ValueError(f"Invalid image data: {e}")
122
+
123
def _convert_pil_to_bytesio(self, pil_image: Image.Image) -> io.BytesIO:
    """Serialize a PIL Image into an in-memory JPEG buffer.

    The image is re-encoded as JPEG (quality 95) — the format the BGE-VL
    encoder is known to accept — and the buffer is rewound before returning
    so callers can read it from the start.

    Raises:
        ValueError: if the image cannot be serialized.
    """
    try:
        buffer = io.BytesIO()
        pil_image.save(buffer, format='JPEG', quality=95)
        buffer.seek(0)  # rewind so readers start at the first byte
        return buffer
    except Exception as e:
        logger.error(f"Failed to convert PIL image to BytesIO: {e}")
        raise ValueError(f"Image conversion failed: {e}")
134
+
135
+ def _parse_multimodal_inputs (self , inputs ):
136
+ """Parse and categorize multimodal inputs for BGE-VL"""
137
+ text_inputs = []
138
+ image_inputs = []
139
+ multimodal_inputs = []
140
+
141
+ for inp in inputs :
142
+ if isinstance (inp , str ):
143
+ # Simple text input
144
+ text_inputs .append (inp )
145
+ elif isinstance (inp , dict ):
146
+ if inp .get ('type' ) == 'text' :
147
+ text_inputs .append (inp .get ('content' , '' ))
148
+ elif inp .get ('type' ) == 'image' :
149
+ # Image-only input
150
+ image_data = inp .get ('image' ) or inp .get ('content' )
151
+ if image_data :
152
+ pil_image = self ._process_base64_image (image_data )
153
+ # Convert PIL Image to BytesIO for BGE-VL compatibility
154
+ bytesio_image = self ._convert_pil_to_bytesio (pil_image )
155
+ image_inputs .append (bytesio_image )
156
+ elif inp .get ('type' ) == 'multimodal' :
157
+ # Text + Image input
158
+ text = inp .get ('text' , '' )
159
+ image_data = inp .get ('image' )
160
+ if image_data :
161
+ pil_image = self ._process_base64_image (image_data )
162
+ # Convert PIL Image to BytesIO for BGE-VL compatibility
163
+ bytesio_image = self ._convert_pil_to_bytesio (pil_image )
164
+ multimodal_inputs .append ((text , bytesio_image ))
165
+
166
+ return text_inputs , image_inputs , multimodal_inputs
167
+
168
def _generate_bge_vl_embeddings(self, inputs):
    """Generate embeddings with a BGE-VL model, preserving input order.

    Each element of ``inputs`` may be a plain string, or a dict of type
    'text', 'image' or 'multimodal' (same shapes accepted elsewhere in this
    engine). Text and image items are batch-encoded for efficiency; results
    are then scattered back so the i-th returned embedding corresponds to
    the i-th input — the ordering contract OpenAI-style embedding clients
    rely on. (The previous implementation concatenated results
    group-by-group, so mixed batches came back out of order.)

    Items that are skipped during parsing (unknown shapes, image dicts with
    an empty payload) produce no embedding, matching the parser's
    silent-skip behavior.

    Raises:
        ValueError: if the model fails to encode any group of inputs.
    """
    # Categorize while remembering each item's original position.
    text_items = []        # (position, text)
    image_items = []       # (position, BytesIO)
    multimodal_items = []  # (position, text, BytesIO)
    for pos, inp in enumerate(inputs):
        if isinstance(inp, str):
            text_items.append((pos, inp))
        elif isinstance(inp, dict):
            kind = inp.get('type')
            if kind == 'text':
                text_items.append((pos, inp.get('content', '')))
            elif kind == 'image':
                image_data = inp.get('image') or inp.get('content')
                if image_data:
                    pil_image = self._process_base64_image(image_data)
                    image_items.append((pos, self._convert_pil_to_bytesio(pil_image)))
            elif kind == 'multimodal':
                image_data = inp.get('image')
                if image_data:
                    pil_image = self._process_base64_image(image_data)
                    multimodal_items.append(
                        (pos, inp.get('text', ''), self._convert_pil_to_bytesio(pil_image))
                    )

    results = {}  # position -> embedding vector

    # Text-only inputs: one batched encode() call.
    if text_items:
        try:
            embeddings = self.model.encode(text=[t for _, t in text_items])
            if hasattr(embeddings, 'tolist'):
                embeddings = embeddings.tolist()
            for (pos, _), emb in zip(text_items, embeddings):
                results[pos] = emb
        except Exception as e:
            logger.error(f"Failed to encode text inputs: {e}")
            raise ValueError(f"BGE-VL text encoding failed: {e}")

    # Image-only inputs: one batched encode() call.
    if image_items:
        try:
            embeddings = self.model.encode(images=[img for _, img in image_items])
            if hasattr(embeddings, 'tolist'):
                embeddings = embeddings.tolist()
            for (pos, _), emb in zip(image_items, embeddings):
                results[pos] = emb
        except Exception as e:
            logger.error(f"Failed to encode image inputs: {e}")
            raise ValueError(f"BGE-VL image encoding failed: {e}")

    # Multimodal (text + image) inputs: encoded one pair at a time.
    for pos, text, image in multimodal_items:
        try:
            embedding = self.model.encode(text=[text], images=[image])
            if hasattr(embedding, 'tolist'):
                embedding = embedding.tolist()
            # encode() on a single-item batch may return a (1, dim) batch;
            # unwrap it so every stored result is a flat vector — TODO
            # confirm against the BGE-VL encode() return shape.
            if isinstance(embedding, list) and len(embedding) == 1 and isinstance(embedding[0], list):
                embedding = embedding[0]
            results[pos] = embedding
        except Exception as e:
            logger.error(f"Failed to encode multimodal input: {e}")
            raise ValueError(f"BGE-VL multimodal encoding failed: {e}")

    # Emit in original input order.
    return [results[pos] for pos in sorted(results)]
214
+
90
215
def invoke(self, request: dict):
    """Compute embeddings for an OpenAI-style embedding request.

    Routes BGE-VL deployments through the multimodal pipeline and every
    other model through the plain-text encode() path, then wraps the
    vectors in an OpenAI-compatible response payload. An empty ``input``
    short-circuits to an empty list.
    """
    inputs = request['input']
    if not inputs:
        # Nothing to embed.
        return []

    logger.info(f'request: {request}')
    t0 = time.time()

    if self.is_bge_vl:
        # BGE-VL accepts text, image and text+image inputs.
        embeddings_list = self._generate_bge_vl_embeddings(inputs)
    else:
        # Plain text path: honor optional task / truncation hints.
        task = request.get('task', 'text-matching')
        truncate_dim = request.get('truncate_dim', None)
        embeddings_list = self.model.encode(inputs, task=task, truncate_dim=truncate_dim).tolist()

    logger.info(f'embeddings generated, count: {len(embeddings_list)}, elapsed time: {time.time()-t0}')
    return self.format_openai_response(embeddings_list)
235
+
236
async def ainvoke(self, request: dict):
    """Async counterpart of invoke(); delegates to the synchronous path.

    NOTE(review): this runs the synchronous invoke() directly on the event
    loop thread, so long encodes will block the loop — consider
    asyncio.to_thread if that matters. Behavior kept as-is.
    """
    response = self.invoke(request)
    return response
0 commit comments