
Commit a134880

fix: multimodal embeddings (text + image) (#159)
1 parent ae2515e commit a134880

File tree

3 files changed: +218 −5 lines

src/emd/models/embeddings/__init__.py
src/emd/models/embeddings/bge_vl.py
src/pipeline/backend/huggingface/embedding/transformers_embedding_backend.py

src/emd/models/embeddings/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -1,3 +1,4 @@
 from . import bert_embedding
 from . import jina
 from . import qwen
+from . import bge_vl
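
Registration in this package is an import side effect: the Model.register(...) calls in bge_vl.py (next file) execute when the module is imported, so this one-line import is what actually makes the new model IDs visible. A minimal sketch:

# Importing the package runs `from . import bge_vl`, which in turn
# executes the Model.register(...) calls for "bge-vl-base" and
# "bge-vl-large". Without the import, neither model_id is registered.
import emd.models.embeddings  # noqa: F401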

src/emd/models/embeddings/bge_vl.py

Lines changed: 76 additions & 0 deletions
@@ -0,0 +1,76 @@
+from .. import Model
+from ..engines import huggingface_embedding_engine449
+from ..services import sagemaker_service, local_service, ecs_service
+from ..frameworks import fastapi_framework
+from ..instances import (
+    g5dxlarge_instance,
+    g5d2xlarge_instance,
+    g5d4xlarge_instance,
+    g5d8xlarge_instance,
+    g5d16xlarge_instance,
+    local_instance
+)
+from emd.models.utils.constants import ModelType
+from ..model_series import BGE_SERIES
+
+
+Model.register(
+    dict(
+        model_id="bge-vl-base",
+        supported_engines=[huggingface_embedding_engine449],
+        supported_instances=[
+            g5dxlarge_instance,
+            g5d2xlarge_instance,
+            g5d4xlarge_instance,
+            g5d8xlarge_instance,
+            g5d16xlarge_instance,
+            local_instance,
+        ],
+        supported_services=[
+            sagemaker_service,
+            ecs_service,
+            local_service
+        ],
+        supported_frameworks=[
+            fastapi_framework
+        ],
+        allow_china_region=True,
+        huggingface_model_id="BAAI/BGE-VL-base",
+        modelscope_model_id="BAAI/BGE-VL-base",
+        require_huggingface_token=False,
+        application_scenario="Multimodal RAG, composed image retrieval, visual search",
+        model_type=ModelType.EMBEDDING,
+        model_series=BGE_SERIES,
+        description="BGE-VL-base is a multimodal embedding model that supports text, image, and text-image pair inputs for unified multimodal representation learning and cross-modal retrieval tasks. Lightweight with 149M parameters."
+    )
+)
+
+Model.register(
+    dict(
+        model_id="bge-vl-large",
+        supported_engines=[huggingface_embedding_engine449],
+        supported_instances=[
+            g5d2xlarge_instance,
+            g5d4xlarge_instance,
+            g5d8xlarge_instance,
+            g5d16xlarge_instance,
+            local_instance,
+        ],
+        supported_services=[
+            sagemaker_service,
+            ecs_service,
+            local_service
+        ],
+        supported_frameworks=[
+            fastapi_framework
+        ],
+        allow_china_region=True,
+        huggingface_model_id="BAAI/BGE-VL-large",
+        modelscope_model_id="BAAI/BGE-VL-large",
+        require_huggingface_token=False,
+        application_scenario="Multimodal RAG, composed image retrieval, visual search",
+        model_type=ModelType.EMBEDDING,
+        model_series=BGE_SERIES,
+        description="BGE-VL-large is a larger multimodal embedding model that supports text, image, and text-image pair inputs for high-performance multimodal representation learning and cross-modal retrieval tasks."
+    )
+)
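
For reference, a hedged sketch of a request payload that exercises all three input shapes the new backend parsing accepts (plain/typed text, image-only, and text+image; see transformers_embedding_backend.py below). The file name and surrounding client code are illustrative, not part of this commit:

import base64

# Hypothetical local image; any JPEG/PNG works.
with open("query.jpg", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode("utf-8")

request = {
    "input": [
        "a photo of a red bicycle",                            # plain string -> text embedding
        {"type": "text", "content": "a photo of a blue car"},  # typed text entry
        {"type": "image", "image": image_b64},                 # image-only entry
        {
            "type": "multimodal",                              # fused text + image entry
            "text": "same bicycle, but at night",
            "image": f"data:image/jpeg;base64,{image_b64}",    # data URLs are also handled
        },
    ],
}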

src/pipeline/backend/huggingface/embedding/transformers_embedding_backend.py

Lines changed: 141 additions & 5 deletions
@@ -1,6 +1,8 @@
 from typing import Iterable, List
 import os
 import time
+import base64
+import io

 from emd.models.utils.constants import ModelType,ServiceType

@@ -12,6 +14,7 @@
 from threading import Thread
 import json
 from transformers import AutoModel
+from PIL import Image


 logger = get_logger(__name__)
@@ -32,6 +35,7 @@ def __init__(self,*args,**kwargs):
         self.proc = None
         self.model = None
         self.pretrained_model_init_kwargs = self.execute_model.executable_config.current_engine.pretrained_model_init_kwargs or {}
+        self.is_bge_vl = "bge-vl" in self.model_id.lower()


     def start(self):
@@ -61,6 +65,15 @@ def start(self):
             device_map="cuda",
             **self.pretrained_model_init_kwargs
         )
+
+        # BGE-VL specific initialization
+        if self.is_bge_vl:
+            try:
+                self.model.set_processor(model_abs_path)
+                logger.info(f"BGE-VL processor set successfully for model: {self.model_id}")
+            except Exception as e:
+                logger.warning(f"Failed to set BGE-VL processor: {e}")
+
         logger.info(f"model: {self.model}")
         # TODO add tokenizer init args from model's definition
         # self.tokenizer = AutoTokenizer.from_pretrained(
@@ -87,16 +100,139 @@ def format_openai_response(self,responses:list[list[float]]):
             }
         }

+    def _process_base64_image(self, image_data: str) -> Image.Image:
+        """Convert base64 string to PIL Image"""
+        try:
+            # Handle data URL format
+            if image_data.startswith('data:image'):
+                image_data = image_data.split(',')[1]
+
+            # Decode base64
+            image_bytes = base64.b64decode(image_data)
+            image = Image.open(io.BytesIO(image_bytes))
+
+            # Convert to RGB if needed
+            if image.mode != 'RGB':
+                image = image.convert('RGB')
+
+            return image
+        except Exception as e:
+            logger.error(f"Failed to process base64 image: {e}")
+            raise ValueError(f"Invalid image data: {e}")
+
+    def _convert_pil_to_bytesio(self, pil_image: Image.Image) -> io.BytesIO:
+        """Convert PIL Image to BytesIO object for BGE-VL compatibility"""
+        try:
+            img_buffer = io.BytesIO()
+            # Save as JPEG to ensure compatibility with BGE-VL model
+            pil_image.save(img_buffer, format='JPEG', quality=95)
+            img_buffer.seek(0)  # Reset pointer to beginning
+            return img_buffer
+        except Exception as e:
+            logger.error(f"Failed to convert PIL image to BytesIO: {e}")
+            raise ValueError(f"Image conversion failed: {e}")
+
+    def _parse_multimodal_inputs(self, inputs):
+        """Parse and categorize multimodal inputs for BGE-VL"""
+        text_inputs = []
+        image_inputs = []
+        multimodal_inputs = []
+
+        for inp in inputs:
+            if isinstance(inp, str):
+                # Simple text input
+                text_inputs.append(inp)
+            elif isinstance(inp, dict):
+                if inp.get('type') == 'text':
+                    text_inputs.append(inp.get('content', ''))
+                elif inp.get('type') == 'image':
+                    # Image-only input
+                    image_data = inp.get('image') or inp.get('content')
+                    if image_data:
+                        pil_image = self._process_base64_image(image_data)
+                        # Convert PIL Image to BytesIO for BGE-VL compatibility
+                        bytesio_image = self._convert_pil_to_bytesio(pil_image)
+                        image_inputs.append(bytesio_image)
+                elif inp.get('type') == 'multimodal':
+                    # Text + Image input
+                    text = inp.get('text', '')
+                    image_data = inp.get('image')
+                    if image_data:
+                        pil_image = self._process_base64_image(image_data)
+                        # Convert PIL Image to BytesIO for BGE-VL compatibility
+                        bytesio_image = self._convert_pil_to_bytesio(pil_image)
+                        multimodal_inputs.append((text, bytesio_image))
+
+        return text_inputs, image_inputs, multimodal_inputs
+
+    def _generate_bge_vl_embeddings(self, inputs):
+        """Generate embeddings using BGE-VL model"""
+        text_inputs, image_inputs, multimodal_inputs = self._parse_multimodal_inputs(inputs)
+        all_embeddings = []
+
+        # Process text-only inputs
+        if text_inputs:
+            try:
+                # Use explicit text= parameter for BGE-VL model
+                text_embeddings = self.model.encode(text=text_inputs)
+                if hasattr(text_embeddings, 'tolist'):
+                    all_embeddings.extend(text_embeddings.tolist())
+                else:
+                    all_embeddings.extend(text_embeddings)
+            except Exception as e:
+                logger.error(f"Failed to encode text inputs: {e}")
+                raise ValueError(f"BGE-VL text encoding failed: {e}")
+
+        # Process image-only inputs
+        if image_inputs:
+            try:
+                # Use explicit images= parameter with list format
+                image_embeddings = self.model.encode(images=image_inputs)
+                if hasattr(image_embeddings, 'tolist'):
+                    all_embeddings.extend(image_embeddings.tolist())
+                else:
+                    all_embeddings.extend(image_embeddings)
+            except Exception as e:
+                logger.error(f"Failed to encode image inputs: {e}")
+                raise ValueError(f"BGE-VL image encoding failed: {e}")

+        # Process multimodal inputs (text + image)
+        if multimodal_inputs:
+            for text, bytesio_image in multimodal_inputs:
+                try:
+                    # Use explicit parameters with list format for both text and images
+                    multimodal_embedding = self.model.encode(text=[text], images=[bytesio_image])
+                    if hasattr(multimodal_embedding, 'tolist'):
+                        all_embeddings.append(multimodal_embedding.tolist())
+                    else:
+                        all_embeddings.append(multimodal_embedding)
+                except Exception as e:
+                    logger.error(f"Failed to encode multimodal input: {e}")
+                    raise ValueError(f"BGE-VL multimodal encoding failed: {e}")
+
+        return all_embeddings
+
     def invoke(self, request:dict):
         inputs = request['input']
         if not inputs:
             return []

-        task = request.get('task', 'text-matching')
-        truncate_dim = request.get('truncate_dim', None)
         logger.info(f'request: {request}')
         t0 = time.time()
-        embeddings = self.model.encode(inputs, task=task,truncate_dim=truncate_dim)
-        embeddings_list = embeddings.tolist()
-        logger.info(f'embeddings res: {embeddings_list},\nelapsed time: {time.time()-t0}')
+
+        if self.is_bge_vl:
+            # Use BGE-VL multimodal processing
+            embeddings_list = self._generate_bge_vl_embeddings(inputs)
+        else:
+            # Use standard text embedding processing
+            task = request.get('task', 'text-matching')
+            truncate_dim = request.get('truncate_dim', None)
+            embeddings = self.model.encode(inputs, task=task, truncate_dim=truncate_dim)
+            embeddings_list = embeddings.tolist()
+
+        logger.info(f'embeddings generated, count: {len(embeddings_list)}, elapsed time: {time.time()-t0}')
         return self.format_openai_response(embeddings_list)
+
+    async def ainvoke(self, request: dict):
+        """Async version of invoke method"""
+        return self.invoke(request)
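
Two usage notes. First, _generate_bge_vl_embeddings processes a mixed batch by category (all text, then all image-only, then each text+image pair), so the returned vectors are grouped by type rather than matching the original input order. Second, for anyone calling BGE-VL outside this pipeline, here is a minimal standalone sketch of the API surface the backend relies on; the trust_remote_code flag and the image path are assumptions, while set_processor and the text=/images= keywords mirror the code above:

import io

import torch
from PIL import Image
from transformers import AutoModel

# Assumed loading flags; the backend passes its own pretrained_model_init_kwargs.
model = AutoModel.from_pretrained("BAAI/BGE-VL-base", trust_remote_code=True)
model.set_processor("BAAI/BGE-VL-base")  # same call the backend wraps in try/except
model.eval()

# Mirror _convert_pil_to_bytesio: hand the model JPEG bytes in a BytesIO buffer.
buf = io.BytesIO()
Image.open("query.jpg").convert("RGB").save(buf, format="JPEG", quality=95)
buf.seek(0)

with torch.no_grad():
    text_emb = model.encode(text=["a photo of a red bicycle"])   # text-only
    image_emb = model.encode(images=[buf])                       # image-only
    pair_emb = model.encode(
        text=["same bicycle, but at night"],                     # fused text + image
        images=[buf],
    )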
