Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CODEOWNERS.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@ Vagner Santana - vsantana@ibm.com

Cássia Sampaio - csamp@ibm.com

Ashwath Vaithinathan Aravindan - ashwath.vaithina@ibm.com
Ashwath Vaithinathan Aravindan - ashwath.vaithina@ibm.com

9 changes: 5 additions & 4 deletions SECURITY.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@ currently being supported with security updates.

| Version | Supported |
| ------- | ------------------ |
| 5.1.x | :white_check_mark: |
| 5.0.x | :x: |
| v0.1.0 | :white_check_mark: |
<!--| 5.0.x | :x: |
| 4.0.x | :white_check_mark: |
| < 4.0 | :x: |
| < 4.0 | :x: |-->

## Reporting a Vulnerability

To report a security issue, please email $VMTalias with a description of the issue, the steps you took to create the issue, affected versions, and if known, mitigations for the issue. Our vulnerability management team will acknowledge receiving your email within 3 working days. This project follows a 90 day disclosure timeline.
To report a security issue, please email us with a description of the issue, the steps you took to create the issue, affected versions, and if known, mitigations for the issue. Our vulnerability management team will acknowledge receiving your email within 3 working days. This project follows a 90-day disclosure timeline.

75 changes: 74 additions & 1 deletion app.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@
import json
import os
import pickle
import numpy as np
from functools import lru_cache

app = Flask(__name__)

Expand All @@ -57,6 +59,41 @@
app.register_blueprint(cfg.SWAGGER_BLUEPRINT, url_prefix = cfg.SWAGGER_URL)
FRONT_LOG_FILE = 'front_log.json'

@lru_cache(maxsize=1)
def get_values_embedding_function():
    """
    Build the embedding callable used by the /values endpoint.

    The result is memoized through ``lru_cache`` so the model is only
    saved/loaded on the first call; later calls reuse the same object.

    Returns:
        Whatever ``recommendation_handler.get_embedding_func`` returns for
        local inference on the saved model path.
    """
    # save_model() is unpacked as (model_id, model_path); only the path is used here.
    _, local_model_path = save_model.save_model()
    return recommendation_handler.get_embedding_func(
        inference='local',
        model_id=local_model_path,
    )

@lru_cache(maxsize=1)
def get_values_centroids():
    """
    Load the positive and negative value centroids for the /values endpoint.

    The result is memoized through ``lru_cache`` so the prompt JSON is only
    read on the first call.

    Returns:
        Dictionary with two keys, 'positive' and 'negative', each mapping a
        value label to its centroid as a NumPy array.
    """
    values_json = recommendation_handler.populate_json()

    def _centroids_by_label(entries):
        # Each entry provides a 'label' and a 'centroid'.
        return {entry['label']: np.array(entry['centroid']) for entry in entries}

    return {
        'positive': _centroids_by_label(values_json['positive_values']),
        'negative': _centroids_by_label(values_json['negative_values']),
    }

@app.route("/")
def index():
Expand Down Expand Up @@ -109,7 +146,7 @@ def get_thresholds():
@cross_origin()
def recommend_local():
model_id, _ = save_model.save_model()
prompt_json, _ = recommendation_handler.populate_json()
prompt_json = recommendation_handler.populate_json()
args = request.args
print("args list = ", args)
prompt = args.get("prompt")
Expand Down Expand Up @@ -166,6 +203,42 @@ def demo_inference():
return response
except:
return "Model Inference failed.", 500

@app.route("/values", methods=['GET'])
@cross_origin()
def get_values():
    """
    Return the closest positive and negative values for each sentence of a prompt.

    Reads the ``prompt`` query parameter, validates it, and delegates to
    ``recommendation_handler.get_values`` using the cached embedding function
    and cached centroids.

    Returns:
        JSON payload from ``recommendation_handler.get_values`` on success,
        a JSON error with status 400 when ``prompt`` is missing, not a string
        or blank, and a JSON error with status 500 when processing fails.
    """
    prompt = request.args.get("prompt")

    # Input validation: reject absent, non-string and whitespace-only prompts.
    if not prompt:
        return jsonify({"error": "Missing required parameter: prompt"}), 400
    if not isinstance(prompt, str):
        return jsonify({"error": "Parameter 'prompt' must be a string"}), 400
    if not prompt.strip():
        return jsonify({"error": "Parameter 'prompt' cannot be empty"}), 400

    try:
        # Both lookups are cached, so only the first request pays the load cost.
        embed = get_values_embedding_function()
        centroid_sets = get_values_centroids()
        result = recommendation_handler.get_values(
            prompt,
            centroid_sets['positive'],
            centroid_sets['negative'],
            embed
        )
        return jsonify(result)
    except Exception as e:
        # Log the details server-side; the client only sees a generic message.
        logger.error(f'Error in /values endpoint: {str(e)}')
        return jsonify({"error": "Internal server error processing prompt"}), 500

if __name__=='__main__':
debug_mode = os.getenv('FLASK_DEBUG', 'False').lower() in ['true', '1', 't']
Expand Down
79 changes: 79 additions & 0 deletions control/recommendation_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,3 +335,82 @@ def get_thresholds(
thresholds['remove_higher_threshold'] = round(remove_similarities_df.describe([.9]).loc['90%', 'similarity'], 1)

return thresholds

def _closest_values(sentence_embeddings, centroids):
    """
    Find, for every sentence embedding, the centroid with the highest cosine similarity.

    Args:
        sentence_embeddings: 2-D array with one row per sentence.
        centroids: Dictionary mapping value labels to centroid embeddings.

    Returns:
        List with one (label, similarity) tuple per row of `sentence_embeddings`.
        When `centroids` is empty, every entry is (None, -1.0).
    """
    if not centroids:
        # No candidates at all: keep the "no label, similarity -1" placeholder.
        return [(None, -1.0)] * len(sentence_embeddings)

    labels = list(centroids)
    centroid_matrix = np.array([centroids[label] for label in labels])

    # One batched call gives a (sentences x labels) similarity matrix.
    similarities = np.asarray(cosine_similarity(sentence_embeddings, centroid_matrix))

    # argmax returns the first maximum, so ties go to the earliest label.
    # It always selects a label, even when the best similarity is <= -1.
    best_columns = similarities.argmax(axis=1)
    return [
        (labels[column], float(similarities[row, column]))
        for row, column in enumerate(best_columns)
    ]


def get_values(
    prompt,
    positive_embeddings,
    negative_embeddings,
    embedding_fn = None
):
    """
    Compute positive and negative value associations for each sentence in the input prompt.

    Args:
        prompt: Input prompt text.
        positive_embeddings: Dictionary mapping positive value labels to centroid embeddings.
        negative_embeddings: Dictionary mapping negative value labels to centroid embeddings.
        embedding_fn: Function to generate embeddings from a list of sentences. When None,
            a local 'sentence-transformers/all-MiniLM-L6-v2' embedding function is used.

    Returns:
        Dictionary with a single key "prompts" holding one entry per non-empty sentence:
        {"sentence": str,
         "positive_value": {"label": ..., "similarity": float},
         "negative_value": {"label": ..., "similarity": float}}.
        The list is empty when the prompt contains no non-empty sentences.
    """

    if embedding_fn is None:
        # using all-MiniLM-L6-v2 locally by default
        embedding_fn = get_embedding_func('local', model_id='sentence-transformers/all-MiniLM-L6-v2')

    # splitting the prompt and dropping empty / whitespace-only sentences
    sentences = [s for s in split_into_sentences(prompt) if s.strip()]

    values = {"prompts": []}

    # returning early if there is nothing to embed
    if not sentences:
        return values

    # generating all sentence embeddings in a single batched call
    sentence_embeddings = np.array(embedding_fn(sentences))

    # a 1-D result is treated as a single embedding and promoted to shape (1, dim)
    if sentence_embeddings.ndim == 1:
        sentence_embeddings = np.expand_dims(sentence_embeddings, axis=0)

    positive_matches = _closest_values(sentence_embeddings, positive_embeddings)
    negative_matches = _closest_values(sentence_embeddings, negative_embeddings)

    for sentence, (positive_label, positive_similarity), (negative_label, negative_similarity) in zip(
        sentences, positive_matches, negative_matches
    ):
        values["prompts"].append({
            "sentence": sentence,
            "positive_value": {"label": positive_label, "similarity": positive_similarity},
            "negative_value": {"label": negative_label, "similarity": negative_similarity}
        })

    return values

Loading
Loading