11
11
12
12
import openai
13
13
14
+ from typing import Dict , Union , Optional
15
+ from collections import OrderedDict
14
16
from flask import Flask , request , jsonify , abort
15
17
from sentence_transformers import SentenceTransformer
16
18
22
24
23
25
24
26
class Config :
25
- def __init__ (self , config_file ):
27
+ def __init__ (self , config_file : str ):
26
28
self .config_file = config_file
27
29
self .config = configparser .ConfigParser ()
28
30
if not os .path .exists (self .config_file ):
@@ -32,7 +34,7 @@ def __init__(self, config_file):
32
34
logging .info (f'Loading config file: { self .config_file } ' )
33
35
self .config .read (config_file )
34
36
35
- def get_val (self , section , key ) :
37
+ def get_val (self , section : str , key : str ) -> Optional [ str ] :
36
38
answer = None
37
39
38
40
try :
@@ -42,7 +44,7 @@ def get_val(self, section, key):
42
44
43
45
return answer
44
46
45
- def get_bool (self , section , key , default = False ):
47
+ def get_bool (self , section : str , key : str , default : bool = False ) -> bool :
46
48
try :
47
49
return self .config .getboolean (section , key )
48
50
except Exception as err :
@@ -51,23 +53,28 @@ def get_bool(self, section, key, default=False):
51
53
52
54
53
55
class EmbeddingCache:
    """Bounded in-memory LRU cache mapping (text, model_type) -> embedding.

    Keys are SHA-256 digests of the concatenated text and model type, so
    arbitrarily long inputs collapse to fixed-size keys.  When the cache
    grows past ``max_size`` the least-recently-used entry is evicted.
    """

    def __init__(self, max_size: int = 500):
        logger.info(f'Created in-memory cache; max size={max_size}')
        # OrderedDict tracks recency: most-recently-used entries live at the
        # right end; eviction pops from the left end.
        self.cache = OrderedDict()
        self.max_size = max_size

    def get_cache_key(self, text: str, model_type: str) -> str:
        """Return a stable, fixed-length key for (text, model_type)."""
        return hashlib.sha256((text + model_type).encode()).hexdigest()

    def get(self, text: str, model_type: str):
        """Return the cached embedding for (text, model_type), or None on a miss.

        A hit refreshes the entry's recency.  Previously lookups never
        reordered keys, so eviction was effectively FIFO and the hottest
        entry could be evicted first.
        """
        key = self.get_cache_key(text, model_type)
        embedding = self.cache.get(key)
        if embedding is not None:
            self.cache.move_to_end(key)
        return embedding

    def set(self, text: str, model_type: str, embedding):
        """Store an embedding, evicting the least-recently-used entry if full."""
        key = self.get_cache_key(text, model_type)
        self.cache[key] = embedding
        # Re-inserting an existing key must also refresh its recency.
        self.cache.move_to_end(key)
        if len(self.cache) > self.max_size:
            self.cache.popitem(last=False)
67
73
68
74
class EmbeddingGenerator :
69
- def __init__ (self , sbert_model = None , openai_key = None ):
75
+ def __init__ (self , sbert_model : Optional [ str ] = None , openai_key : Optional [ str ] = None ):
70
76
self .sbert_model = sbert_model
77
+ self .openai_key = openai_key
71
78
if self .sbert_model is not None :
72
79
try :
73
80
self .model = SentenceTransformer (self .sbert_model )
@@ -77,48 +84,55 @@ def __init__(self, sbert_model=None, openai_key=None):
77
84
sys .exit (1 )
78
85
79
86
if openai_key is not None :
80
- openai .api_key = openai_key
87
+ openai .api_key = self . openai_key
81
88
logger .info ('enabled model: text-embedding-ada-002' )
82
89
83
- def get_openai_embeddings (self , text ) :
90
+ def generate (self , text : str , model_type : str ) -> Dict [ str , Union [ str , float , list ]] :
84
91
start_time = time .time ()
92
+ result = {'status' : 'success' }
85
93
86
- try :
87
- response = openai .Embedding .create (input = text , model = 'text-embedding-ada-002' )
88
- elapsed_time = (time .time () - start_time ) * 1000
89
- data = {
90
- "embedding" : response ['data' ][0 ]['embedding' ],
91
- "status" : "success" ,
92
- "elapsed" : elapsed_time ,
93
- "model" : "text-embedding-ada-002"
94
- }
95
- return data
96
- except Exception as err :
97
- logger .error (f'Failed to get OpenAI embeddings: { err } ' )
98
- return {"status" : "error" , "message" : str (err ), "model" : "text-embedding-ada-002" }
99
-
100
- def get_transformers_embeddings (self , text ):
101
- start_time = time .time ()
102
-
103
- try :
104
- embedding = self .model .encode (text ).tolist ()
105
- elapsed_time = (time .time () - start_time ) * 1000
106
- data = {
107
- "embedding" : embedding ,
108
- "status" : "success" ,
109
- "elapsed" : elapsed_time ,
110
- "model" : self .sbert_model
111
- }
112
- return data
113
- except Exception as err :
114
- logger .error (f'Failed to get sentence-transformers embeddings: { err } ' )
115
- return {"status" : "error" , "message" : str (err ), "model" : self .sbert_model }
116
-
117
- def generate (self , text , model_type ):
118
94
if model_type == 'openai' :
119
- return self .get_openai_embeddings (text )
95
+ try :
96
+ response = openai .Embedding .create (input = text , model = 'text-embedding-ada-002' )
97
+ result ['embedding' ] = response ['data' ][0 ]['embedding' ]
98
+ result ['model' ] = 'text-embedding-ada-002'
99
+ except Exception as err :
100
+ logger .error (f'Failed to get OpenAI embeddings: { err } ' )
101
+ result ['status' ] = 'error'
102
+ result ['message' ] = str (err )
103
+
120
104
else :
121
- return self .get_transformers_embeddings (text )
105
+ try :
106
+ embedding = self .model .encode (text ).tolist ()
107
+ result ['embedding' ] = embedding
108
+ result ['model' ] = self .sbert_model
109
+ except Exception as err :
110
+ logger .error (f'Failed to get sentence-transformers embeddings: { err } ' )
111
+ result ['status' ] = 'error'
112
+ result ['message' ] = str (err )
113
+
114
+ result ['elapsed' ] = (time .time () - start_time ) * 1000
115
+ return result
116
+
117
+
118
@app.route('/health', methods=['GET'])
def health_check():
    """Report which embedding backends are enabled and the cache state.

    Returns JSON with a ``models`` section (the sbert model name or
    'disabled'; True or 'disabled' for OpenAI) and a ``cache`` section
    (enabled flag, current size, configured maximum size).
    """
    sbert_on = embedding_generator.sbert_model if embedding_generator.sbert_model else 'disabled'
    # NOTE(review): openai reports True while sbert reports a model name —
    # inconsistent value types, but kept as-is for response compatibility.
    openai_on = True if embedding_generator.openai_key else 'disabled'

    health_status = {
        "models": {
            "openai": openai_on,
            "sentence-transformers": sbert_on
        },
        "cache": {
            "enabled": embedding_cache is not None,
            "size": len(embedding_cache.cache) if embedding_cache else None,
            # Was hard-coded to None; report the configured bound instead.
            "max_size": embedding_cache.max_size if embedding_cache else None
        }
    }

    return jsonify(health_status)
122
136
123
137
124
138
@app .route ('/submit' , methods = ['POST' ])
@@ -134,22 +148,32 @@ def submit_text():
134
148
if model_type not in ['local' , 'openai' ]:
135
149
abort (400 , 'model field must be one of: local, openai' )
136
150
137
- if embedding_cache :
138
- result = embedding_cache .get (text_data , model_type )
139
- if result :
140
- logger .info ('found embedding in cache!' )
141
- result = {'embedding' : result , 'cache' : True , "status" : 'success' }
142
- else :
151
+ if isinstance (text_data , str ):
152
+ text_data = [text_data ]
153
+
154
+ if not all (isinstance (text , str ) for text in text_data ):
155
+ abort (400 , 'all data must be text strings' )
156
+
157
+ results = []
158
+ for text in text_data :
143
159
result = None
144
160
145
- if result is None :
146
- result = embedding_generator .generate (text_data , model_type )
161
+ if embedding_cache :
162
+ result = embedding_cache .get (text , model_type )
163
+ if result :
164
+ logger .info ('found embedding in cache!' )
165
+ result = {'embedding' : result , 'cache' : True , "status" : 'success' }
166
+
167
+ if result is None :
168
+ result = embedding_generator .generate (text , model_type )
169
+
170
+ if embedding_cache and result ['status' ] == 'success' :
171
+ embedding_cache .set (text , model_type , result ['embedding' ])
172
+ logger .info ('added to cache' )
147
173
148
- if embedding_cache and result ['status' ] == 'success' :
149
- embedding_cache .set (text_data , model_type , result ['embedding' ])
150
- logger .info ('added to cache' )
174
+ results .append (result )
151
175
152
- return jsonify (result )
176
+ return jsonify (results )
153
177
154
178
155
179
if __name__ == '__main__' :
@@ -168,6 +192,8 @@ def submit_text():
168
192
openai_key = conf .get_val ('main' , 'openai_api_key' )
169
193
sbert_model = conf .get_val ('main' , 'sent_transformers_model' )
170
194
use_cache = conf .get_bool ('main' , 'use_cache' , default = False )
195
+ if use_cache :
196
+ max_cache_size = int (conf .get_val ('main' , 'cache_max' ))
171
197
172
198
if openai_key is None :
173
199
logger .warn ('No OpenAI API key set in configuration file: server.conf' )
@@ -179,7 +205,7 @@ def submit_text():
179
205
logger .error ('No sbert model set *and* no openAI key set; exiting' )
180
206
sys .exit (1 )
181
207
182
- embedding_cache = EmbeddingCache () if use_cache else None
208
+ embedding_cache = EmbeddingCache (max_cache_size ) if use_cache else None
183
209
embedding_generator = EmbeddingGenerator (sbert_model , openai_key )
184
210
185
211
app .run (debug = True )
0 commit comments