
Commit 39d29d0

fixing recommendation_handler.py
1 parent fca24f7 commit 39d29d0

File tree

1 file changed: +14 -151 lines changed

control/recommendation_handler.py

Lines changed: 14 additions & 151 deletions
@@ -33,10 +33,7 @@
 import numpy as np
 from sklearn.metrics.pairwise import cosine_similarity
 import os
-#os.environ['TRANSFORMERS_CACHE'] ="./models/allmini/cache"
-import os.path
 from sentence_transformers import SentenceTransformer
-import pickle
 
 def populate_json(json_file_path = './prompt-sentences-main/prompt_sentences-all-minilm-l6-v2.json',
                   existing_json_populated_file_path = './prompt-sentences-main/prompt_sentences-all-minilm-l6-v2.json'):
@@ -177,20 +174,9 @@ def recommend_prompt(
     Raises:
         Nothing.
     """
-    if(model_id == 'baai/bge-large-en-v1.5' ):
-        json_file = './prompt-sentences-main/prompt_sentences-bge-large-en-v1.5.json'
-        umap_model_file = './models/umap/intfloat/multilingual-e5-large/umap.pkl'
-    elif(model_id == 'intfloat/multilingual-e5-large'):
-        json_file = './prompt-sentences-main/prompt_sentences-multilingual-e5-large.json'
-        umap_model_file = './models/umap/intfloat/multilingual-e5-large/umap.pkl'
-    else: # fall back to all-minilm as default
-        json_file = './prompt-sentences-main/prompt_sentences-all-minilm-l6-v2.json'
-        umap_model_file = './models/umap/sentence-transformers/all-MiniLM-L6-v2/umap.pkl'
-
-    with open(umap_model_file, 'rb') as f:
-        umap_model = pickle.load(f)
-
-    prompt_json = json.load( open( json_file ) )
+    if embedding_fn is None:
+        # Use all-MiniLM-L6-v2 locally by default
+        embedding_fn = get_embedding_func('local', model_id='sentence-transformers/all-MiniLM-L6-v2')
 
     # Output initialization
     out, out['input'], out['add'], out['remove'] = {}, {}, {}, {}
@@ -243,16 +229,17 @@ def recommend_prompt(
     )
 
     # Recommendation of values to remove from the current prompt
-    for sentence in input_sentences:
-        input_embedding = query(sentence, api_url, headers) # remote
-        # Obtaining XY coords for input sentences from a UMAP model
-        if(len(prompt_json['negative_values'][0]['centroid']) == len(input_embedding) and sentence != ''):
-            embeddings_umap = umap_model.transform(np.expand_dims(pd.DataFrame(input_embedding).squeeze(), axis=0))
-            input_items.append({
-                'sentence': sentence,
-                'x': str(embeddings_umap[0][0]),
-                'y': str(embeddings_umap[0][1])
-            })
+    for sent_idx, sentence in enumerate(input_sentences):
+        input_embedding = inp_sentence_embeddings[sent_idx]
+        if umap_model:
+            # Obtaining XY coords for input sentences from a parametric UMAP model
+            if(len(prompt_json['negative_values'][0]['centroid']) == len(input_embedding) and sentence != ''):
+                embeddings_umap = umap_model.transform(np.expand_dims(pd.DataFrame(input_embedding).squeeze(), axis=0))
+                input_items.append({
+                    'sentence': sentence,
+                    'x': str(embeddings_umap[0][0]),
+                    'y': str(embeddings_umap[0][1])
+                })
 
     for value_idx, v in enumerate(prompt_json['negative_values']):
         # Dealing with values without prompts and making sure they have the same dimensions
@@ -348,127 +335,3 @@ def get_thresholds(
     thresholds['remove_higher_threshold'] = round(remove_similarities_df.describe([.9]).loc['90%', 'similarity'], 1)
 
     return thresholds
-
-def recommend_local(prompt, prompt_json, model_id, model_path = './models/all-MiniLM-L6-v2/', add_lower_threshold = 0.3,
-                    add_upper_threshold = 0.5, remove_lower_threshold = 0.1,
-                    remove_upper_threshold = 0.5):
-    """
-    Function that recommends prompts additions or removals
-    using a local model.
-
-    Args:
-        prompt: The entered prompt text.
-        prompt_json: Json file populated with embeddings.
-        model_id: Id of the local model.
-        model_path: Path to the local model.
-
-    Returns:
-        Prompt values to add or remove.
-
-    Raises:
-        Nothing.
-    """
-    if(model_id == 'baai/bge-large-en-v1.5' ):
-        json_file = './prompt-sentences-main/prompt_sentences-bge-large-en-v1.5.json'
-        umap_model_file = './models/umap/intfloat/multilingual-e5-large/umap.pkl'
-    elif(model_id == 'intfloat/multilingual-e5-large'):
-        json_file = './prompt-sentences-main/prompt_sentences-multilingual-e5-large.json'
-        umap_model_file = './models/umap/intfloat/multilingual-e5-large/umap.pkl'
-    else: # fall back to all-minilm as default
-        json_file = './prompt-sentences-main/prompt_sentences-all-minilm-l6-v2.json'
-        umap_model_file = './models/umap/sentence-transformers/all-MiniLM-L6-v2/umap.pkl'
-
-    with open(umap_model_file, 'rb') as f:
-        umap_model = pickle.load(f)
-
-    prompt_json = json.load( open( json_file ) )
-
-    # Output initialization
-    out, out['input'], out['add'], out['remove'] = {}, {}, {}, {}
-    input_items, items_to_add, items_to_remove = [], [], []
-
-    # Spliting prompt into sentences
-    input_sentences = split_into_sentences(prompt)
-
-    # Recommendation of values to add to the current prompt
-    # Using only the last sentence for the add recommendation
-    model = SentenceTransformer(model_path)
-    input_embedding = model.encode(input_sentences[-1])
-
-    for v in prompt_json['positive_values']:
-        # Dealing with values without prompts and makinig sure they have the same dimensions
-        if(len(v['centroid']) == len(input_embedding)):
-            if(get_similarity(pd.DataFrame(input_embedding), pd.DataFrame(v['centroid'])) > add_lower_threshold):
-                closer_prompt = -1
-                for p in v['prompts']:
-                    d_prompt = get_similarity(pd.DataFrame(input_embedding), pd.DataFrame(p['embedding']))
-                    # The sentence_threshold is being used as a ceiling meaning that for high similarities the sentence/value might already be presente in the prompt
-                    # So, we don't want to recommend adding something that is already there
-                    if(d_prompt > closer_prompt and d_prompt > add_lower_threshold and d_prompt < add_upper_threshold):
-                        closer_prompt = d_prompt
-                        items_to_add.append({
-                            'value': v['label'],
-                            'prompt': p['text'],
-                            'similarity': d_prompt,
-                            'x': p['x'],
-                            'y': p['y']})
-                out['add'] = items_to_add
-
-    # Recommendation of values to remove from the current prompt
-    i = 0
-
-    # Recommendation of values to remove from the current prompt
-    for sentence in input_sentences:
-        input_embedding = model.encode(sentence) # local
-        # Obtaining XY coords for input sentences from a UMAP model
-        if(len(prompt_json['negative_values'][0]['centroid']) == len(input_embedding) and sentence != ''):
-            embeddings_umap = umap_model.transform(np.expand_dims(pd.DataFrame(input_embedding).squeeze(), axis=0))
-            input_items.append({
-                'sentence': sentence,
-                'x': str(embeddings_umap[0][0]),
-                'y': str(embeddings_umap[0][1])
-            })
-
-        for v in prompt_json['negative_values']:
-            # Dealing with values without prompts and makinig sure they have the same dimensions
-            if(len(v['centroid']) == len(input_embedding)):
-                if(get_similarity(pd.DataFrame(input_embedding), pd.DataFrame(v['centroid'])) > remove_lower_threshold):
-                    closer_prompt = -1
-                    for p in v['prompts']:
-                        d_prompt = get_similarity(pd.DataFrame(input_embedding), pd.DataFrame(p['embedding']))
-                        # A more restrict threshold is used here to prevent false positives
-                        # The sentence_threhold is being used to indicate that there must be a sentence in the prompt that is similiar to one of our adversarial prompts
-                        # So, yes, we want to recommend the revolval of something adversarial we've found
-                        if(d_prompt > closer_prompt and d_prompt > remove_upper_threshold):
-                            closer_prompt = d_prompt
-                            items_to_remove.append({
-                                'value': v['label'],
-                                'sentence': sentence,
-                                'sentence_index': i,
-                                'closest_harmful_sentence': p['text'],
-                                'similarity': d_prompt,
-                                'x': p['x'],
-                                'y': p['y']})
-                    out['remove'] = items_to_remove
-        i += 1
-
-    out['input'] = input_items
-
-    out['add'] = sorted(out['add'], key=sort_by_similarity, reverse=True)
-    values_map = {}
-    for item in out['add'][:]:
-        if(item['value'] in values_map):
-            out['add'].remove(item)
-        else:
-            values_map[item['value']] = item['similarity']
-    out['add'] = out['add'][0:5]
-
-    out['remove'] = sorted(out['remove'], key=sort_by_similarity, reverse=True)
-    values_map = {}
-    for item in out['remove'][:]:
-        if(item['value'] in values_map):
-            out['remove'].remove(item)
-        else:
-            values_map[item['value']] = item['similarity']
-    out['remove'] = out['remove'][0:5]
-    return out
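
Note: the hunks above replace the hard-coded JSON/UMAP pickle selection and the separate recommend_local() path with a single embedding_fn hook. A minimal sketch of how a caller might supply that hook, assuming embedding_fn is any callable that maps sentences to embedding vectors; get_embedding_func and the exact recommend_prompt signature are not shown in this diff, so the helper and call below are illustrative only:

# Illustrative sketch, not the repository's confirmed API.
from sentence_transformers import SentenceTransformer

def make_local_embedding_fn(model_name='sentence-transformers/all-MiniLM-L6-v2'):
    # Build a callable that embeds a sentence (or list of sentences) with a local model.
    model = SentenceTransformer(model_name)
    def embed(sentences):
        return model.encode(sentences)
    return embed

# Hypothetical usage: recommend_prompt now receives the embedding callable instead of
# loading a UMAP pickle and JSON file based on model_id internally.
# out = recommend_prompt(prompt, prompt_json, embedding_fn=make_local_embedding_fn())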
