Commit 8380372
1 parent 7fad723

cache improvements; batch processing; health check endpoint

1 file changed: server.py (+86 additions, -60 deletions)
@@ -11,6 +11,8 @@
 
 import openai
 
+from typing import Dict, Union, Optional
+from collections import OrderedDict
 from flask import Flask, request, jsonify, abort
 from sentence_transformers import SentenceTransformer
 
@@ -22,7 +24,7 @@
 
 
 class Config:
-    def __init__(self, config_file):
+    def __init__(self, config_file: str):
         self.config_file = config_file
         self.config = configparser.ConfigParser()
         if not os.path.exists(self.config_file):
@@ -32,7 +34,7 @@ def __init__(self, config_file):
         logging.info(f'Loading config file: {self.config_file}')
         self.config.read(config_file)
 
-    def get_val(self, section, key):
+    def get_val(self, section: str, key: str) -> Optional[str]:
         answer = None
 
         try:
@@ -42,7 +44,7 @@ def get_val(self, section, key):
 
         return answer
 
-    def get_bool(self, section, key, default=False):
+    def get_bool(self, section: str, key: str, default: bool = False) -> bool:
         try:
             return self.config.getboolean(section, key)
         except Exception as err:
@@ -51,23 +53,28 @@ def get_bool(self, section, key, default=False):
 
 
 class EmbeddingCache:
-    def __init__(self):
-        logger.info('Created in-memory cache')
-        self.cache = {}
+    def __init__(self, max_size: int = 500):
+        logger.info(f'Created in-memory cache; max size={max_size}')
+        self.cache = OrderedDict()
+        self.max_size = max_size
 
-    def get_cache_key(self, text, model_type):
+    def get_cache_key(self, text: str, model_type: str) -> str:
         return hashlib.sha256((text + model_type).encode()).hexdigest()
 
-    def get(self, text, model_type):
+    def get(self, text: str, model_type: str):
         return self.cache.get(self.get_cache_key(text, model_type))
 
-    def set(self, text, model_type, embedding):
-        self.cache[self.get_cache_key(text, model_type)] = embedding
+    def set(self, text: str, model_type: str, embedding):
+        key = self.get_cache_key(text, model_type)
+        self.cache[key] = embedding
+        if len(self.cache) > self.max_size:
+            self.cache.popitem(last=False)
 
 
 class EmbeddingGenerator:
-    def __init__(self, sbert_model=None, openai_key=None):
+    def __init__(self, sbert_model: Optional[str] = None, openai_key: Optional[str] = None):
         self.sbert_model = sbert_model
+        self.openai_key = openai_key
         if self.sbert_model is not None:
             try:
                 self.model = SentenceTransformer(self.sbert_model)
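For reference, a minimal standalone sketch (not part of the commit) of the eviction behaviour introduced above: once the OrderedDict grows past max_size, popitem(last=False) drops the oldest-inserted entry, so the bound is first-in-first-out rather than true LRU, since get() never reorders entries.

from collections import OrderedDict

cache = OrderedDict()
max_size = 2  # tiny bound, for illustration only

for key in ['a', 'b', 'c']:
    cache[key] = f'embedding-for-{key}'
    if len(cache) > max_size:
        cache.popitem(last=False)  # evict the oldest-inserted entry

print(list(cache))  # ['b', 'c'] -- 'a' was evicted first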
@@ -77,48 +84,55 @@ def __init__(self, sbert_model=None, openai_key=None):
                 sys.exit(1)
 
         if openai_key is not None:
-            openai.api_key = openai_key
+            openai.api_key = self.openai_key
             logger.info('enabled model: text-embedding-ada-002')
 
-    def get_openai_embeddings(self, text):
+    def generate(self, text: str, model_type: str) -> Dict[str, Union[str, float, list]]:
         start_time = time.time()
+        result = {'status': 'success'}
 
-        try:
-            response = openai.Embedding.create(input=text, model='text-embedding-ada-002')
-            elapsed_time = (time.time() - start_time) * 1000
-            data = {
-                "embedding": response['data'][0]['embedding'],
-                "status": "success",
-                "elapsed": elapsed_time,
-                "model": "text-embedding-ada-002"
-            }
-            return data
-        except Exception as err:
-            logger.error(f'Failed to get OpenAI embeddings: {err}')
-            return {"status": "error", "message": str(err), "model": "text-embedding-ada-002"}
-
-    def get_transformers_embeddings(self, text):
-        start_time = time.time()
-
-        try:
-            embedding = self.model.encode(text).tolist()
-            elapsed_time = (time.time() - start_time) * 1000
-            data = {
-                "embedding": embedding,
-                "status": "success",
-                "elapsed": elapsed_time,
-                "model": self.sbert_model
-            }
-            return data
-        except Exception as err:
-            logger.error(f'Failed to get sentence-transformers embeddings: {err}')
-            return {"status": "error", "message": str(err), "model": self.sbert_model}
-
-    def generate(self, text, model_type):
         if model_type == 'openai':
-            return self.get_openai_embeddings(text)
+            try:
+                response = openai.Embedding.create(input=text, model='text-embedding-ada-002')
+                result['embedding'] = response['data'][0]['embedding']
+                result['model'] = 'text-embedding-ada-002'
+            except Exception as err:
+                logger.error(f'Failed to get OpenAI embeddings: {err}')
+                result['status'] = 'error'
+                result['message'] = str(err)
+
         else:
-            return self.get_transformers_embeddings(text)
+            try:
+                embedding = self.model.encode(text).tolist()
+                result['embedding'] = embedding
+                result['model'] = self.sbert_model
+            except Exception as err:
+                logger.error(f'Failed to get sentence-transformers embeddings: {err}')
+                result['status'] = 'error'
+                result['message'] = str(err)
+
+        result['elapsed'] = (time.time() - start_time) * 1000
+        return result
+
+
+@app.route('/health', methods=['GET'])
+def health_check():
+    sbert_on = embedding_generator.sbert_model if embedding_generator.sbert_model else 'disabled'
+    openai_on = True if embedding_generator.openai_key else 'disabled'
+
+    health_status = {
+        "models": {
+            "openai": openai_on,
+            'sentence-transformers': sbert_on
+        },
+        "cache": {
+            "enabled": embedding_cache is not None,
+            "size": len(embedding_cache.cache) if embedding_cache else None,
+            "max_size": None
+        }
+    }
+
+    return jsonify(health_status)
 
 
 @app.route('/submit', methods=['POST'])
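A quick way to exercise the new /health endpoint, shown as a hypothetical client snippet (not part of the commit); it assumes the server is running locally on Flask's default port 5000, since app.run() is called without an explicit port.

import json
import urllib.request

# Fetch and decode the health report from the running server.
with urllib.request.urlopen('http://127.0.0.1:5000/health') as resp:
    health = json.load(resp)

# Expected shape, per health_check() above, e.g.:
# {"models": {"openai": true, "sentence-transformers": "all-MiniLM-L6-v2"},
#  "cache": {"enabled": true, "size": 0, "max_size": null}}
print(health['models'], health['cache']['enabled'])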
@@ -134,22 +148,32 @@ def submit_text():
     if model_type not in ['local', 'openai']:
         abort(400, 'model field must be one of: local, openai')
 
-    if embedding_cache:
-        result = embedding_cache.get(text_data, model_type)
-        if result:
-            logger.info('found embedding in cache!')
-            result = {'embedding': result, 'cache': True, "status": 'success'}
-    else:
+    if isinstance(text_data, str):
+        text_data = [text_data]
+
+    if not all(isinstance(text, str) for text in text_data):
+        abort(400, 'all data must be text strings')
+
+    results = []
+    for text in text_data:
         result = None
 
-    if result is None:
-        result = embedding_generator.generate(text_data, model_type)
+        if embedding_cache:
+            result = embedding_cache.get(text, model_type)
+            if result:
+                logger.info('found embedding in cache!')
+                result = {'embedding': result, 'cache': True, "status": 'success'}
+
+        if result is None:
+            result = embedding_generator.generate(text, model_type)
+
+        if embedding_cache and result['status'] == 'success':
+            embedding_cache.set(text, model_type, result['embedding'])
+            logger.info('added to cache')
 
-    if embedding_cache and result['status'] == 'success':
-        embedding_cache.set(text_data, model_type, result['embedding'])
-        logger.info('added to cache')
+        results.append(result)
 
-    return jsonify(result)
+    return jsonify(results)
 
 
 if __name__ == '__main__':
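With /submit now accepting either a single string or a list of strings and returning a list of results, a hypothetical client call might look like the sketch below (not part of the commit). The 'model' field name comes from the validation above; the name of the text field is not visible in this diff, so 'text' is only a placeholder.

import json
import urllib.request

payload = {
    'model': 'local',                      # 'local' or 'openai', per the validation in submit_text()
    'text': ['first sentence', 'second'],  # placeholder field name; not shown in this diff
}
req = urllib.request.Request(
    'http://127.0.0.1:5000/submit',
    data=json.dumps(payload).encode(),
    headers={'Content-Type': 'application/json'},
    method='POST',
)
with urllib.request.urlopen(req) as resp:
    results = json.load(resp)              # now a list: one result dict per input string

for item in results:
    print(item['status'], item.get('cache', False), len(item.get('embedding', [])))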
@@ -168,6 +192,8 @@ def submit_text():
     openai_key = conf.get_val('main', 'openai_api_key')
     sbert_model = conf.get_val('main', 'sent_transformers_model')
     use_cache = conf.get_bool('main', 'use_cache', default=False)
+    if use_cache:
+        max_cache_size = int(conf.get_val('main', 'cache_max'))
 
     if openai_key is None:
         logger.warn('No OpenAI API key set in configuration file: server.conf')
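For context, a hypothetical server.conf matching the keys read at startup ('openai_api_key', 'sent_transformers_model', 'use_cache', and the new 'cache_max'). All values below are placeholders; the sketch writes the file via configparser to keep the example in Python.

import configparser

# Hypothetical: writes a minimal server.conf with the keys the startup block reads.
# 'cache_max' must be present whenever use_cache is true, since int(None) would
# raise a TypeError in the hunk above.
conf = configparser.ConfigParser()
conf['main'] = {
    'openai_api_key': 'sk-REPLACE_ME',              # optional
    'sent_transformers_model': 'all-MiniLM-L6-v2',  # example model name
    'use_cache': 'true',
    'cache_max': '500',
}
with open('server.conf', 'w') as fh:
    conf.write(fh)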
@@ -179,7 +205,7 @@ def submit_text():
         logger.error('No sbert model set *and* no openAI key set; exiting')
         sys.exit(1)
 
-    embedding_cache = EmbeddingCache() if use_cache else None
+    embedding_cache = EmbeddingCache(max_cache_size) if use_cache else None
    embedding_generator = EmbeddingGenerator(sbert_model, openai_key)
 
     app.run(debug=True)
