|
8 | 8 | from openai import OpenAI, AzureOpenAI
|
9 | 9 | from openai import AuthenticationError, RateLimitError
|
10 | 10 | import time
|
| 11 | +import os |
| 12 | +from transformers import AutoTokenizer |
| 13 | + |
| 14 | + |
| 15 | +PRIVATEMODE_AI_URL = os.getenv("PRIVATEMODE_AI_URL", "http://privatemode-proxy:8080/v1") |
11 | 16 |
|
12 | 17 |
|
13 | 18 | class TransformerSentenceEmbedder(SentenceEmbedder):
|
@@ -199,3 +204,77 @@ def dump(self, project_id: str, embedding_id: str) -> None:
|
199 | 204 | export_file = util.INFERENCE_DIR / project_id / f"embedder-{embedding_id}.json"
|
200 | 205 | export_file.parent.mkdir(parents=True, exist_ok=True)
|
201 | 206 | util.write_json(self.to_json(), export_file, indent=2)
|
| 207 | + |
| 208 | + |
class PrivatemodeAISentenceEmbedder(SentenceEmbedder):

    def __init__(
        self,
        batch_size: int = 128,
        model_name: str = "intfloat/multilingual-e5-large-instruct",
    ):
        """
        Embeds documents using the privatemode AI proxy via the OpenAI client classes.

        Note that the API key is not configurable here: the client is created with a
        dummy key and the real key is injected by the proxy (see PRIVATEMODE_AI_URL).

        Args:
            batch_size (int, optional): Defines the number of conversions after which the embedder yields. Defaults to 128.
            model_name (str, optional): Name of the embedding model from Privatemode AI (e.g. intfloat/multilingual-e5-large-instruct). Defaults to "intfloat/multilingual-e5-large-instruct".
        """
        super().__init__(batch_size)
        self.model_name = model_name
        self.openai_client = OpenAI(
            api_key="dummy",  # Set in proxy
            base_url=PRIVATEMODE_AI_URL,
        )
        # for trimming the length of the text if > 512 tokens
        self._auto_tokenizer = AutoTokenizer.from_pretrained(self.model_name)

    def _encode(
        self, documents: List[Union[str, Doc]], fit_model: bool
    ) -> Generator[List[List[float]], None, None]:
        """Yield one list of embedding vectors per batch of documents.

        ``fit_model`` is part of the base interface and is unused here — the
        remote model is pre-trained.
        NOTE(review): assumes each document is a ``str`` (calls ``.replace``);
        a ``Doc`` input would fail — confirm what callers actually pass.
        """
        for documents_batch in util.batch(documents, self.batch_size):
            # Strip newlines (they can degrade embedding quality) and trim each
            # text to the model's 512-token context before sending it off.
            documents_batch = [
                self._trim_length(doc.replace("\n", " ")) for doc in documents_batch
            ]
            try:
                response = self.openai_client.embeddings.create(
                    input=documents_batch, model=self.model_name
                )
            except AuthenticationError as exc:
                # The API key is hardcoded and injected by the proxy, so an auth
                # failure points at the proxy setup, not at a constructor argument.
                raise Exception(
                    "Authentication against the privatemode AI proxy failed. "
                    "The API key is set in the proxy; check the proxy configuration "
                    "and PRIVATEMODE_AI_URL."
                ) from exc
            yield [entry.embedding for entry in response.data]

    @staticmethod
    def load(embedder: dict) -> "PrivatemodeAISentenceEmbedder":
        """Reconstruct an embedder from the dict produced by ``to_json``."""
        return PrivatemodeAISentenceEmbedder(
            model_name=embedder["model_name"],
            batch_size=embedder["batch_size"],
        )

    def to_json(self) -> dict:
        """Serialize the embedder configuration (inverse of ``load``)."""
        return {
            "cls": "PrivatemodeAISentenceEmbedder",
            "model_name": self.model_name,
            "batch_size": self.batch_size,
        }

    def dump(self, project_id: str, embedding_id: str) -> None:
        """Write this embedder's JSON config to the project's inference directory."""
        export_file = util.INFERENCE_DIR / project_id / f"embedder-{embedding_id}.json"
        export_file.parent.mkdir(parents=True, exist_ok=True)
        util.write_json(self.to_json(), export_file, indent=2)

    def _trim_length(self, text: str, max_length: int = 512) -> str:
        """Truncate ``text`` to at most ``max_length`` tokens of the embedding model.

        Tokenizes with truncation, then decodes back to a plain string so the
        API receives text that fits the model's context window.
        """
        tokens = self._auto_tokenizer(
            text,
            truncation=True,
            max_length=max_length,
            return_tensors=None,  # No tensors needed for just truncating
        )
        return self._auto_tokenizer.decode(tokens["input_ids"], skip_special_tokens=True)
0 commit comments