
Commit 79d909b

privatemode ai (#161)
* Enables privatemode ai as embedding provider
* PR comment
* Adds auto trim size for privatemode ai
* Submodule merge
1 parent da649ea commit 79d909b

6 files changed: +89 -5 lines changed

controller.py

Lines changed: 3 additions & 3 deletions

@@ -14,17 +14,17 @@
 import gc
 import os
 import pandas as pd
-import shutil
 from openai import APIConnectionError

 from src.embedders import Transformer, util

 # Embedder imports are used by eval(Embedder) in __setup_tmp_embedder
-from src.embedders.classification.contextual import (
+from src.embedders.classification.contextual import (  # noqa: F401
     OpenAISentenceEmbedder,
     HuggingFaceSentenceEmbedder,
+    PrivatemodeAISentenceEmbedder,
 )
-from src.embedders.classification.reduce import PCASentenceReducer
+from src.embedders.classification.reduce import PCASentenceReducer  # noqa: F401
 from src.util import daemon, request_util
 from src.util.decorator import param_throttle
 from src.util.embedders import get_embedder
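The # noqa: F401 markers above are needed because these imports look unused to the linter: the controller re-creates embedders from a stored class name via eval(...) inside __setup_tmp_embedder, so the classes must exist in this module's namespace. A minimal sketch of that pattern (illustrative only; __setup_tmp_embedder itself is not part of this diff, and embedder_meta below is a hypothetical example of the stored config):

# Hypothetical stored config, shaped like PrivatemodeAISentenceEmbedder.to_json()
embedder_meta = {
    "cls": "PrivatemodeAISentenceEmbedder",
    "model_name": "intfloat/multilingual-e5-large-instruct",
    "batch_size": 128,
}

# eval() resolves the class by name, which only works if the class was imported
# into this module -- hence the otherwise "unused" imports kept with # noqa: F401.
embedder_cls = eval(embedder_meta["cls"])
embedder = embedder_cls.load(embedder_meta)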

src/embedders/classification/contextual.py

Lines changed: 79 additions & 0 deletions

@@ -8,6 +8,11 @@
 from openai import OpenAI, AzureOpenAI
 from openai import AuthenticationError, RateLimitError
 import time
+import os
+from transformers import AutoTokenizer
+
+
+PRIVATEMODE_AI_URL = os.getenv("PRIVATEMODE_AI_URL", "http://privatemode-proxy:8080/v1")


 class TransformerSentenceEmbedder(SentenceEmbedder):
@@ -199,3 +204,77 @@ def dump(self, project_id: str, embedding_id: str) -> None:
         export_file = util.INFERENCE_DIR / project_id / f"embedder-{embedding_id}.json"
         export_file.parent.mkdir(parents=True, exist_ok=True)
         util.write_json(self.to_json(), export_file, indent=2)
+
+
+class PrivatemodeAISentenceEmbedder(SentenceEmbedder):
+
+    def __init__(
+        self,
+        batch_size: int = 128,
+        model_name: str = "intfloat/multilingual-e5-large-instruct",
+    ):
+        """
+        Embeds documents using the Privatemode AI proxy via the OpenAI client classes.
+        Note that the model and API key are currently hardcoded since they aren't configurable.
+
+        Args:
+            batch_size (int, optional): Defines the number of conversions after which the embedder yields. Defaults to 128.
+            model_name (str, optional): Name of the embedding model from Privatemode AI (e.g. intfloat/multilingual-e5-large-instruct). Defaults to "intfloat/multilingual-e5-large-instruct".
+
+        Raises:
+            Exception: If authentication against the Privatemode AI proxy fails.
+
+
+        """
+        super().__init__(batch_size)
+        self.model_name = model_name
+        self.openai_client = OpenAI(
+            api_key="dummy",  # Set in proxy
+            base_url=PRIVATEMODE_AI_URL,
+        )
+        # for trimming the length of the text if > 512 tokens
+        self._auto_tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+
+    def _encode(
+        self, documents: List[Union[str, Doc]], fit_model: bool
+    ) -> Generator[List[List[float]], None, None]:
+        for documents_batch in util.batch(documents, self.batch_size):
+            documents_batch = [self._trim_length(doc.replace("\n", " ")) for doc in documents_batch]
+            try:
+                response = self.openai_client.embeddings.create(
+                    input=documents_batch, model=self.model_name
+                )
+                embeddings = [entry.embedding for entry in response.data]
+                yield embeddings
+            except AuthenticationError:
+                raise Exception(
+                    "Authentication against the Privatemode AI proxy failed. The API key is configured in the proxy itself, not in PrivatemodeAISentenceEmbedder."
+                )
+
+    @staticmethod
+    def load(embedder: dict) -> "PrivatemodeAISentenceEmbedder":
+        return PrivatemodeAISentenceEmbedder(
+            model_name=embedder["model_name"],
+            batch_size=embedder["batch_size"],
+        )
+
+    def to_json(self) -> dict:
+        return {
+            "cls": "PrivatemodeAISentenceEmbedder",
+            "model_name": self.model_name,
+            "batch_size": self.batch_size,
+        }
+
+    def dump(self, project_id: str, embedding_id: str) -> None:
+        export_file = util.INFERENCE_DIR / project_id / f"embedder-{embedding_id}.json"
+        export_file.parent.mkdir(parents=True, exist_ok=True)
+        util.write_json(self.to_json(), export_file, indent=2)
+
+    def _trim_length(self, text: str, max_length: int = 512) -> str:
+        tokens = self._auto_tokenizer(
+            text,
+            truncation=True,
+            max_length=max_length,
+            return_tensors=None  # No tensors needed for just truncating
+        )
+        return self._auto_tokenizer.decode(tokens["input_ids"], skip_special_tokens=True)
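A quick usage sketch for the new class, using only what this diff introduces. It assumes a Privatemode AI proxy is reachable at PRIVATEMODE_AI_URL (default http://privatemode-proxy:8080/v1) and that the e5 tokenizer can be downloaded; _encode is an internal generator normally driven by the SentenceEmbedder base class (not shown here), so calling it directly is for illustration only:

from src.embedders.classification.contextual import PrivatemodeAISentenceEmbedder

embedder = PrivatemodeAISentenceEmbedder(batch_size=32)

# The second document exceeds 512 tokens and is trimmed by _trim_length
# before being sent to the proxy.
docs = ["a short document", "a very long document " * 400]

for batch in embedder._encode(docs, fit_model=False):
    print(len(batch), len(batch[0]))  # documents in this batch, embedding dimension

# Serialization round trip used by dump()/load():
restored = PrivatemodeAISentenceEmbedder.load(embedder.to_json())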

src/embedders/classification/reduce.py

Lines changed: 2 additions & 1 deletion

@@ -5,9 +5,10 @@
 from src.embedders import PCAReducer, util

 # Embedder imports are used by eval(Embedder) in load methods
-from src.embedders.classification.contextual import (
+from src.embedders.classification.contextual import (  # noqa: F401
     OpenAISentenceEmbedder,
     HuggingFaceSentenceEmbedder,
+    PrivatemodeAISentenceEmbedder,
 )

src/embedders/util.py

Lines changed: 1 addition & 0 deletions

@@ -35,3 +35,4 @@ def read_json(file_path: str) -> dict[str, Any]:
 def write_json(obj: Any, file_path: str, **kwargs) -> None:
     with open(file_path, "w") as f:
         json.dump(obj, f, **kwargs)
+

src/util/embedders.py

Lines changed: 3 additions & 0 deletions

@@ -2,6 +2,7 @@
 from src.embedders.classification.contextual import (
     OpenAISentenceEmbedder,
     HuggingFaceSentenceEmbedder,
+    PrivatemodeAISentenceEmbedder,
 )
 from src.embedders.extraction.contextual import TransformerTokenEmbedder
 from src.embedders.classification.reduce import PCASentenceReducer
@@ -42,6 +43,8 @@ def get_embedder(
         embedder = HuggingFaceSentenceEmbedder(
             config_string=model, batch_size=batch_size
         )
+    elif platform == enums.EmbeddingPlatform.PRIVATEMODE_AI.value:
+        embedder = PrivatemodeAISentenceEmbedder(batch_size=batch_size)
     else:
         raise Exception(f"Unknown platform {platform}")
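The new elif wires the PRIVATEMODE_AI platform value into get_embedder's platform switch. A standalone sketch of that dispatch, assuming the imports at the top of src/util/embedders.py; _build_embedder and the HUGGING_FACE enum member are illustrative assumptions (only the PRIVATEMODE_AI value and the HuggingFaceSentenceEmbedder branch body appear in this hunk), and model/batch_size stand in for get_embedder's own arguments:

def _build_embedder(platform: str, model: str, batch_size: int):
    # Hypothetical helper mirroring the dispatch inside get_embedder.
    if platform == enums.EmbeddingPlatform.HUGGING_FACE.value:  # assumed enum member
        return HuggingFaceSentenceEmbedder(config_string=model, batch_size=batch_size)
    elif platform == enums.EmbeddingPlatform.PRIVATEMODE_AI.value:
        # model is not passed through: PrivatemodeAISentenceEmbedder currently
        # hardcodes intfloat/multilingual-e5-large-instruct.
        return PrivatemodeAISentenceEmbedder(batch_size=batch_size)
    else:
        raise Exception(f"Unknown platform {platform}")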

submodules/model

Submodule model updated 1 file
