Commit da649ea

Merge pull request #160 from code-kern-ai/embedder-rework
Embedder rework
2 parents: 787270f + 6389945 · commit da649ea

File tree

13 files changed: +261 −375 lines


controller.py

Lines changed: 26 additions & 25 deletions
@@ -3,19 +3,28 @@
 from fastapi import status
 from spacy.tokens import DocBin, Doc
 from spacy.vocab import Vocab
+from functools import lru_cache

-import pickle
+import json
 import torch
 import traceback
 import logging
 import time
 import zlib
 import gc
 import os
-import openai
 import pandas as pd
+import shutil
+from openai import APIConnectionError

-from src.embedders import Transformer
+from src.embedders import Transformer, util
+
+# Embedder imports are used by eval(Embedder) in __setup_tmp_embedder
+from src.embedders.classification.contextual import (
+    OpenAISentenceEmbedder,
+    HuggingFaceSentenceEmbedder,
+)
+from src.embedders.classification.reduce import PCASentenceReducer
 from src.util import daemon, request_util
 from src.util.decorator import param_throttle
 from src.util.embedders import get_embedder
@@ -339,7 +348,7 @@ def run_encoding(
             enums.EmbeddingState.ENCODING.value,
             initial_count,
         )
-    except openai.error.APIConnectionError as e:
+    except APIConnectionError as e:
         embedding.update_embedding_state_failed(
             project_id,
             embedding_id,
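This hunk tracks the openai 0.x → 1.x package layout: the `openai.error` module is gone and the exception classes are exported at the package top level, matching the `from openai import APIConnectionError` added to the imports above. A minimal sketch of the new-style handling (client setup and model name are illustrative, not taken from this repo):

```python
from openai import OpenAI, APIConnectionError

client = OpenAI()  # reads OPENAI_API_KEY from the environment

try:
    # illustrative request; the repo's OpenAISentenceEmbedder wraps its own call logic
    response = client.embeddings.create(
        model="text-embedding-3-small",  # hypothetical model choice
        input=["some text to embed"],
    )
    vector = response.data[0].embedding
except APIConnectionError as exc:
    # in openai<1.0 this exception lived at openai.error.APIConnectionError
    print(f"Could not reach the OpenAI API: {exc}")
```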
@@ -407,9 +416,6 @@ def run_encoding(
             notification_message = "Access denied due to invalid api key."
         elif platform == enums.EmbeddingPlatform.AZURE.value:
             notification_message = "Access denied due to invalid subscription key or wrong endpoint data."
-        elif error_message == "invalid api token":
-            # cohere
-            notification_message = "Access denied due to invalid api token."
         notification.create(
             project_id,
             user_id,
@@ -453,14 +459,7 @@ def run_encoding(
        request_util.post_embedding_to_neural_search(project_id, embedding_id)

        # now always since otherwise record edit wouldn't work for embedded columns
-       pickle_path = os.path.join(
-           "/inference", project_id, f"embedder-{embedding_id}.pkl"
-       )
-       if not os.path.exists(pickle_path):
-           os.makedirs(os.path.dirname(pickle_path), exist_ok=True)
-       with open(pickle_path, "wb") as f:
-           pickle.dump(embedder, f)
-
+       embedder.dump(project_id, embedding_id)
        upload_embedding_as_file(project_id, embedding_id)
        embedding.update_embedding_state_finished(
            project_id,
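`embedder.dump(project_id, embedding_id)` replaces the hand-rolled pickle write with a serialization method owned by the embedder itself. The actual implementation lives in `src/embedders` and is not part of this diff; as a rough sketch under that assumption, a JSON-based dump would store the class name plus JSON-safe state so the embedder can be rebuilt later (the `get_config()` helper here is hypothetical):

```python
import json
from pathlib import Path

INFERENCE_DIR = Path("/inference")  # assumed to match util.INFERENCE_DIR


def dump_embedder(embedder, project_id: str, embedding_id: str) -> None:
    # Hypothetical sketch; the real Transformer.dump in src/embedders may differ.
    target = INFERENCE_DIR / project_id / f"embedder-{embedding_id}.json"
    target.parent.mkdir(parents=True, exist_ok=True)
    payload = {
        "cls": type(embedder).__name__,   # e.g. "PCASentenceReducer"
        "config": embedder.get_config(),  # hypothetical: JSON-serializable settings
    }
    with open(target, "w") as f:
        json.dump(payload, f)
```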
@@ -490,9 +489,8 @@ def delete_embedding(project_id: str, embedding_id: str) -> int:
     org_id = organization.get_id_by_project_id(project_id)
     s3.delete_object(org_id, f"{project_id}/{object_name}")
     request_util.delete_embedding_from_neural_search(embedding_id)
-    pickle_path = os.path.join("/inference", project_id, f"embedder-{embedding_id}.pkl")
-    if os.path.exists(pickle_path):
-        os.remove(pickle_path)
+    json_path = util.INFERENCE_DIR / project_id / f"embedder-{embedding_id}.json"
+    json_path.unlink(missing_ok=True)
     return status.HTTP_200_OK


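`delete_embedding` also moves from `os.path` string handling to `pathlib`: `util.INFERENCE_DIR` is evidently a `Path`, so the embedder file can be removed with `unlink(missing_ok=True)` instead of an explicit existence check. Side by side, with an illustrative path:

```python
import os
from pathlib import Path

path_str = "/inference/some-project/embedder-123.json"  # illustrative path only

# old style: check, then remove
if os.path.exists(path_str):
    os.remove(path_str)

# new style: no-op if the file is already gone (missing_ok requires Python 3.8+)
Path(path_str).unlink(missing_ok=True)
```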
@@ -629,15 +627,18 @@ def re_embed_records(project_id: str, changes: Dict[str, List[Dict[str, str]]]):


 def __setup_tmp_embedder(project_id: str, embedder_id: str) -> Transformer:
-    embedder_path = os.path.join(
-        "/inference", project_id, f"embedder-{embedder_id}.pkl"
-    )
-    if not os.path.exists(embedder_path):
+    embedder_path = util.INFERENCE_DIR / project_id / f"embedder-{embedder_id}.json"
+    if not embedder_path.exists():
         raise Exception(f"Embedder {embedder_id} not found")
-    with open(embedder_path, "rb") as f:
-        embedder = pickle.load(f)
+    return __load_embedder_by_path(embedder_path)
+

-    return embedder
+@lru_cache(maxsize=32)
+def __load_embedder_by_path(embedder_path: str) -> Transformer:
+    with open(embedder_path, "r") as f:
+        embedder = json.load(f)
+    Embedder = eval(embedder["cls"])
+    return Embedder.load(embedder)


 def calc_tensors(project_id: str, embedding_id: str, texts: List[str]) -> List[Any]:
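Loading is the mirror image of the dump: the JSON file stores the class name under `"cls"`, `eval(embedder["cls"])` resolves it back to a class object (which is why `OpenAISentenceEmbedder`, `HuggingFaceSentenceEmbedder`, and `PCASentenceReducer` are imported at the top even though they look unused), and `lru_cache(maxsize=32)` keeps recently rebuilt embedders in memory so repeated record edits do not re-read and re-construct the same file. The same idea can be sketched with an explicit registry instead of `eval`; the registry below is an illustration, not part of this PR:

```python
import json
from functools import lru_cache

from src.embedders.classification.contextual import (
    OpenAISentenceEmbedder,
    HuggingFaceSentenceEmbedder,
)
from src.embedders.classification.reduce import PCASentenceReducer

# Explicit allow-list mapping stored class names to classes (alternative to eval).
_EMBEDDER_CLASSES = {
    cls.__name__: cls
    for cls in (OpenAISentenceEmbedder, HuggingFaceSentenceEmbedder, PCASentenceReducer)
}


@lru_cache(maxsize=32)
def load_embedder_by_path(embedder_path: str):
    with open(embedder_path, "r") as f:
        data = json.load(f)
    embedder_cls = _EMBEDDER_CLASSES[data["cls"]]
    return embedder_cls.load(data)  # classmethod-style load, as used in the PR
```

One caveat of caching by path: if an embedding is rebuilt under the same id, a process that already cached the old object keeps serving it until the cache entry is evicted or the process restarts.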

gpu-requirements.txt

Lines changed: 10 additions & 41 deletions
@@ -6,12 +6,6 @@
 #
 --extra-index-url https://download.pytorch.org/whl/cu113

-aiohappyeyeballs==2.6.1
-    # via aiohttp
-aiohttp==3.12.14
-    # via openai
-aiosignal==1.4.0
-    # via aiohttp
 annotated-types==0.7.0
     # via
     #   -r requirements/torch-cuda-requirements.txt
@@ -20,6 +14,7 @@ anyio==4.9.0
     # via
     #   -r requirements/torch-cuda-requirements.txt
     #   httpx
+    #   openai
     #   starlette
 argon2-cffi==25.1.0
     # via
@@ -29,8 +24,6 @@ argon2-cffi-bindings==21.2.0
     # via
     #   -r requirements/torch-cuda-requirements.txt
     #   argon2-cffi
-attrs==25.3.0
-    # via aiohttp
 blis==0.7.11
     # via thinc
 boto3==1.39.6
@@ -67,8 +60,6 @@ click==8.2.1
     #   uvicorn
 cloudpathlib==0.21.1
     # via weasel
-cohere==5.16.1
-    # via -r requirements/gpu-requirements.in
 confection==0.1.5
     # via
     #   thinc
@@ -78,20 +69,16 @@ cymem==2.0.11
     #   preshed
     #   spacy
     #   thinc
+distro==1.9.0
+    # via openai
 fastapi==0.116.1
     # via -r requirements/torch-cuda-requirements.txt
-fastavro==1.11.1
-    # via cohere
 filelock==3.18.0
     # via
     #   -r requirements/torch-cuda-requirements.txt
     #   huggingface-hub
     #   torch
     #   transformers
-frozenlist==1.7.0
-    # via
-    #   aiohttp
-    #   aiosignal
 fsspec==2025.7.0
     # via
     #   -r requirements/torch-cuda-requirements.txt
@@ -109,9 +96,7 @@ hf-xet==1.1.5
 httpcore==1.0.9
     # via httpx
 httpx==0.28.1
-    # via cohere
-httpx-sse==0.4.0
-    # via cohere
+    # via openai
 huggingface-hub==0.33.4
     # via
     #   -r requirements/torch-cuda-requirements.txt
@@ -124,12 +109,13 @@ idna==3.10
     #   anyio
     #   httpx
     #   requests
-    #   yarl
 jinja2==3.1.6
     # via
     #   -r requirements/torch-cuda-requirements.txt
     #   spacy
     #   torch
+jiter==0.10.0
+    # via openai
 jmespath==1.0.1
     # via
     #   -r requirements/torch-cuda-requirements.txt
@@ -160,10 +146,6 @@ mpmath==1.3.0
     # via
     #   -r requirements/torch-cuda-requirements.txt
     #   sympy
-multidict==6.6.3
-    # via
-    #   aiohttp
-    #   yarl
 murmurhash==1.0.13
     # via
     #   preshed
@@ -185,7 +167,7 @@ numpy==1.23.4
     #   thinc
     #   torchvision
     #   transformers
-openai==0.28.1
+openai==1.97.1
     # via -r requirements/gpu-requirements.in
 packaging==25.0
     # via
@@ -205,10 +187,6 @@ preshed==3.0.10
     # via
     #   spacy
     #   thinc
-propcache==0.3.2
-    # via
-    #   aiohttp
-    #   yarl
 psycopg2-binary==2.9.9
     # via -r requirements/torch-cuda-requirements.txt
 pyaml==25.7.0
@@ -226,16 +204,15 @@ pycryptodome==3.23.0
 pydantic==2.7.4
     # via
     #   -r requirements/torch-cuda-requirements.txt
-    #   cohere
     #   confection
     #   fastapi
+    #   openai
     #   spacy
     #   thinc
     #   weasel
 pydantic-core==2.18.4
     # via
     #   -r requirements/torch-cuda-requirements.txt
-    #   cohere
     #   pydantic
 pygments==2.19.2
     # via rich
@@ -261,9 +238,7 @@ regex==2024.11.6
 requests==2.32.4
     # via
     #   -r requirements/torch-cuda-requirements.txt
-    #   cohere
     #   huggingface-hub
-    #   openai
     #   spacy
     #   transformers
     #   weasel
@@ -304,6 +279,7 @@ sniffio==1.3.1
     # via
     #   -r requirements/torch-cuda-requirements.txt
     #   anyio
+    #   openai
 spacy==3.7.5
     # via -r requirements/gpu-requirements.in
 spacy-legacy==3.0.12
@@ -335,7 +311,6 @@ threadpoolctl==3.6.0
 tokenizers==0.21.2
     # via
     #   -r requirements/torch-cuda-requirements.txt
-    #   cohere
     #   transformers
 torch==2.7.1
     # via
@@ -360,17 +335,14 @@ typer==0.16.0
     # via
     #   spacy
     #   weasel
-types-requests==2.32.4.20250611
-    # via cohere
 typing-extensions==4.14.1
     # via
     #   -r requirements/torch-cuda-requirements.txt
-    #   aiosignal
     #   anyio
-    #   cohere
     #   fastapi
     #   huggingface-hub
     #   minio
+    #   openai
     #   pydantic
     #   pydantic-core
     #   sentence-transformers
@@ -383,7 +355,6 @@ urllib3==2.5.0
     #   botocore
     #   minio
     #   requests
-    #   types-requests
 uvicorn==0.35.0
     # via -r requirements/torch-cuda-requirements.txt
 wasabi==1.1.3
@@ -395,8 +366,6 @@ weasel==0.4.1
     # via spacy
 wrapt==1.17.2
     # via smart-open
-yarl==1.20.1
-    # via aiohttp

 # The following packages are considered to be unsafe in a requirements file:
 # setuptools
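The dependency changes follow directly from the controller rework: `cohere` and its transitive extras (`fastavro`, `httpx-sse`, `types-requests`) drop out, the `aiohttp` stack (`aiohttp`, `aiohappyeyeballs`, `aiosignal`, `attrs`, `frozenlist`, `multidict`, `propcache`, `yarl`) disappears because `openai` 1.x talks to the API through `httpx` (picking up `distro` and `jiter` as new transitive dependencies), and `openai` itself jumps from 0.28.1 to 1.97.1. That major-version bump removes the module-level API the old code relied on; for comparison, a sketch of the legacy call style (key and model are illustrative):

```python
# Legacy openai==0.28.x style, no longer available after the upgrade.
import openai

openai.api_key = "sk-..."  # illustrative; the real key comes from configuration
legacy = openai.Embedding.create(
    model="text-embedding-ada-002",  # hypothetical model choice
    input=["some text to embed"],
)
# In openai>=1.0 the module-level Embedding API is gone; requests go through an
# OpenAI() client instead (see the client.embeddings.create sketch in the
# controller.py notes above).
```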
