Skip to content

Commit fd7ae2a

Browse files
authored
Optimized embeddings generation (#119)
* Add functionality to load existing embeddings from the output JSON file, thereby optimizing the process of generating embeddings for new data.

  Signed-off-by: Ayesha Imran <ayesha.i1505@gmail.com>
  Signed-off-by: GitHub <noreply@github.com>

* Add functionality to load existing embeddings from the output JSON file, thereby optimizing the process of generating embeddings for new data.

  Signed-off-by: Ayesha Imran <ayesha.i1505@gmail.com>
  Signed-off-by: GitHub <noreply@github.com>

---------

Signed-off-by: Ayesha Imran <ayesha.i1505@gmail.com>
Signed-off-by: GitHub <noreply@github.com>
1 parent 2ce0dfb commit fd7ae2a

File tree

2 files changed

+47
-13
lines changed

2 files changed

+47
-13
lines changed

customize/customize_embeddings.py

Lines changed: 21 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -43,7 +43,27 @@
4343
# OUTPUT FILE
4444
json_out_file_name = f'{json_in_file_name}-{model_id}.json'
4545

46+
# check if the output file already exists
47+
if os.path.exists(json_out_file_name):
48+
try:
49+
# Load existing data from the output file
50+
existing_data = customize_helper.load_json(json_out_file_name)
51+
except Exception as e:
52+
existing_data = None
53+
54+
# hashmap
55+
prompts_embeddings = {}
56+
if existing_data:
57+
for d in existing_data["positive_values"]:
58+
for p in d["prompts"]:
59+
prompts_embeddings[p["text"]] = p["embedding"]
60+
for d in existing_data["negative_values"]:
61+
for p in d["prompts"]:
62+
prompts_embeddings[p["text"]] = p["embedding"]
63+
64+
65+
4666
prompt_json = json.load(open(json_in_file))
47-
prompt_json_embeddings = customize_helper.populate_embeddings(prompt_json, model_path)
67+
prompt_json_embeddings = customize_helper.populate_embeddings(prompt_json, model_path, prompts_embeddings)
4868
prompt_json_centroids = customize_helper.populate_centroids(prompt_json_embeddings)
4969
customize_helper.save_json(prompt_json_centroids, json_out_file_name)

customize/customize_helper.py

Lines changed: 26 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -91,10 +91,13 @@ def get_centroid(v, dimension = 384, k = 10):
9191
i += 1
9292
return centroid
9393

94-
def populate_embeddings(prompt_json, model_path):
95-
errors, successess = 0, 0
94+
def populate_embeddings(prompt_json, model_path, prompts_embeddings):
95+
errors, successes = 0, 0
9696
for v in prompt_json['positive_values']:
9797
for p in v['prompts']:
98+
if (p['text'] in prompts_embeddings):
99+
p['embedding'] = prompts_embeddings[p['text']]
100+
else:
98101
if( p['text'] != '' and p['embedding'] == []): # only considering missing embeddings
99102
embedding = query_model(p['text'], model_path)
100103
if( 'error' in embedding ):
@@ -106,14 +109,17 @@ def populate_embeddings(prompt_json, model_path):
106109

107110
for v in prompt_json['negative_values']:
108111
for p in v['prompts']:
109-
if(p['text'] != '' and p['embedding'] == []):
110-
embedding = query_model(p['text'], model_path)
111-
if('error' in embedding):
112-
p['embedding'] = []
113-
errors += 1
114-
else:
115-
p['embedding'] = embedding.tolist()
116-
#successes += 1
112+
if (p['text'] in prompts_embeddings):
113+
p['embedding'] = prompts_embeddings[p['text']]
114+
else:
115+
if(p['text'] != '' and p['embedding'] == []):
116+
embedding = query_model(p['text'], model_path)
117+
if('error' in embedding):
118+
p['embedding'] = []
119+
errors += 1
120+
else:
121+
p['embedding'] = embedding.tolist()
122+
#successes += 1
117123
return prompt_json
118124

119125
def populate_centroids(prompt_json):
@@ -123,7 +129,15 @@ def populate_centroids(prompt_json):
123129
v['centroid'] = get_centroid(v, dimension = 384, k = 10)
124130
return prompt_json
125131

126-
# Saving the embeddings for a specific LLM
132+
# Saving the embeddings for a specific LLM
127133
def save_json(prompt_json, json_out_file_name):
128134
with open(json_out_file_name, 'w') as outfile:
129-
json.dump(prompt_json, outfile)
135+
json.dump(prompt_json, outfile)
136+
137+
# load existing data from a JSON file
138+
def load_json(json_out_file):
139+
if os.path.exists(json_out_file):
140+
with open(json_out_file, 'r') as infile:
141+
return json.load(infile)
142+
else:
143+
return None

0 commit comments

Comments
 (0)