Skip to content

Commit 399894f

Browse files
authored
fix: jina embedding v4 (#147)
* merge * merge * add Mistral-Small-3.1-24B-Instruct-2503 * modify qwq-32b deploy * add txgemma model * modify model list command * fix typo * add some ecs parameters * add glm4-z1 models * modify vllm backend * add qwen3 * fix cli bugs * fix * add deepseek r1/Qwen3-235B-A22B * fix local deploy account bug * add qwen 3 awq models * fix serialize_utils bugs * modify qwen3 deployment * modify docs
1 parent 1ff5900 commit 399894f

File tree

4 files changed

+28
-16
lines changed

4 files changed

+28
-16
lines changed

src/emd/models/embeddings/jina.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from .. import Model
2-
from ..engines import huggingface_embedding_engine449
2+
from ..engines import huggingface_embedding_engine449,vllm_embedding_engine091
33
from ..services import sagemaker_service,local_service,ecs_service
44
from ..frameworks import fastapi_framework
55
from ..instances import (
@@ -57,12 +57,13 @@
5757
Model.register(
5858
dict(
5959
model_id = "jina-embeddings-v4-vllm-retrieval",
60-
supported_engines=[huggingface_embedding_engine449],
60+
supported_engines=[vllm_embedding_engine091],
6161
supported_instances=[
6262
g5dxlarge_instance,
6363
g5d2xlarge_instance,
6464
g5d4xlarge_instance,
65-
g5d8xlarge_instance
65+
g5d8xlarge_instance,
66+
local_instance,
6667
],
6768
supported_services=[
6869
sagemaker_service,

src/emd/models/engines.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,13 @@ class KtransformersEngine(OpenAICompitableEngine):
176176
"default_cli_args": " --max_model_len 16000 --max_num_seq 30 --disable-log-stats --enable-reasoning --reasoning-parser qwen3 --enable-auto-tool-choice --tool-call-parser hermes --enable-prefix-caching"
177177
})
178178

179+
vllm_embedding_engine091 = VllmEngine(**{
180+
**vllm_engine064.model_dump(),
181+
"engine_dockerfile_config": {"VERSION":"v0.9.1"},
182+
"environment_variables": "export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True",
183+
"default_cli_args": " --max_num_seq 30 --disable-log-stats --trust-remote-code --task embed"
184+
})
185+
179186

180187
vllm_qwen2vl72b_engine064 = VllmEngine(**{
181188
**vllm_engine064.model_dump(),

tests/sdk_tests/client_tests/langchain_client_embedding_and_rerank_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from emd.integrations.langchain_clients import SageMakerVllmRerank
44

55
embedding_model = SageMakerVllmEmbeddings(
6-
model_id="bge-m3",
6+
model_id="jina-embeddings-v4-vllm-retrieval",
77
# model_tag='dev-2'
88
)
99

tests/sdk_tests/client_tests/openai_embedding_local_test.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import time
66
# Modify OpenAI's API key and API base to use vLLM's API server.
77
openai_api_key = "EMPTY"
8-
openai_api_base = "http://localhost:8000/v1"
8+
openai_api_base = "http://localhost:8080/v1"
99

1010

1111
def run():
@@ -15,8 +15,8 @@ def run():
1515
base_url=openai_api_base,
1616
)
1717

18-
models = client.models.list()
19-
model = models.data[0].id
18+
# models = client.models.list()
19+
# model = models.data[0].id
2020
t0 = time.time()
2121
responses = client.embeddings.create(
2222
# input=[
@@ -26,9 +26,11 @@ def run():
2626
input=[
2727
'The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.'
2828
],
29-
model=model,
29+
model="jina-embeddings-v4-vllm-retrieval",
3030
)
31+
3132
print(f'elapsed time: {time.time()-t0}')
33+
print(responses)
3234

3335
# for data in responses.data:
3436
# print(data.embedding) # list of float of len 4096
@@ -38,12 +40,14 @@ def run():
3840
threads = []
3941
t0 = time.time()
4042

41-
for i in range(2000):
42-
# time.sleep(0.01)
43-
# t = Thread(target=task)
44-
t = Thread(target=run)
45-
threads.append(t)
46-
t.start()
47-
for t in threads:
48-
t.join()
43+
run()
44+
45+
# for i in range(2000):
46+
# # time.sleep(0.01)
47+
# # t = Thread(target=task)
48+
# t = Thread(target=run)
49+
# threads.append(t)
50+
# t.start()
51+
# for t in threads:
52+
# t.join()
4953
print("done, all task elapsed time: ",time.time()-t0)

0 commit comments

Comments
 (0)