[Groundedness] handle edge cases by copy #43923
@@ -2,13 +2,16 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
import os, logging
import re
import math
from typing import Dict, List, Optional, Union, Any, Tuple

from typing_extensions import overload, override
from azure.ai.evaluation._legacy.prompty import AsyncPrompty

from azure.ai.evaluation._evaluators._common import PromptyEvaluatorBase
from azure.ai.evaluation._model_configurations import Conversation
from azure.ai.evaluation._common.constants import PROMPT_BASED_REASON_EVALUATORS
from ..._common.utils import (
    ErrorBlame,
    ErrorTarget,
@@ -17,6 +20,7 @@
    construct_prompty_model_config,
    validate_model_config,
    simplify_messages,
    parse_quality_evaluator_reason_score,
)

try:
@@ -103,21 +107,25 @@ class GroundednessEvaluator(PromptyEvaluatorBase[Union[str, float]]):
    def __init__(self, model_config, *, threshold=3, credential=None, **kwargs):
        current_dir = os.path.dirname(__file__)
        prompty_path = os.path.join(current_dir, self._PROMPTY_FILE_NO_QUERY)  # Default to no query

        self._higher_is_better = True
        super().__init__(
            model_config=model_config,
            prompty_file=prompty_path,
            result_key=self._RESULT_KEY,
            threshold=threshold,
            credential=credential,
            _higher_is_better=self._higher_is_better,
            _higher_is_better=True,
            **kwargs,
        )
        self._model_config = model_config
        self.threshold = threshold
        # Needs to be set because it's used in call method to re-validate prompt if `query` is provided

        # To make sure they're not used directly
        self._flow = None
        self._prompty_file = None

        self._flow_with_query = self._load_flow(self._PROMPTY_FILE_WITH_QUERY, token_credential=credential)
        self._flow_no_query = self._load_flow(self._PROMPTY_FILE_NO_QUERY, token_credential=credential)

    @overload
    def __call__(
        self,
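The hunk above swaps the old lazy prompty swap (mutating `self._prompty_file` / `self._flow` once a query shows up) for two flows that are loaded once in the constructor and then picked per call. A minimal sketch of that pattern follows; `DummyFlow`, `load_flow`, `SketchEvaluator`, and the placeholder `.prompty` filenames are illustrative stand-ins, not the azure-ai-evaluation API.

```python
# Illustrative sketch only: names here are stand-ins, not the SDK's API.
class DummyFlow:
    def __init__(self, prompty_filename: str) -> None:
        self.prompty_filename = prompty_filename


def load_flow(prompty_filename: str) -> DummyFlow:
    # Stands in for GroundednessEvaluator._load_flow, which builds the model
    # config and calls AsyncPrompty.load for the given prompty file.
    return DummyFlow(prompty_filename)


class SketchEvaluator:
    # Placeholder filenames; the real class defines its own constants.
    _PROMPTY_FILE_WITH_QUERY = "with_query.prompty"
    _PROMPTY_FILE_NO_QUERY = "no_query.prompty"

    def __init__(self) -> None:
        # Both flows exist for the lifetime of the evaluator.
        self._flow_with_query = load_flow(self._PROMPTY_FILE_WITH_QUERY)
        self._flow_no_query = load_flow(self._PROMPTY_FILE_NO_QUERY)

    def pick_flow(self, query: object) -> DummyFlow:
        # Per-call dispatch replaces the old "reload the prompty" side effect.
        return self._flow_with_query if query else self._flow_no_query


assert SketchEvaluator().pick_flow("why is the sky blue?").prompty_filename == "with_query.prompty"
```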
@@ -201,31 +209,50 @@ def __call__( # pylint: disable=docstring-missing-param
        :rtype: Union[Dict[str, Union[str, float]], Dict[str, Union[float, Dict[str, List[Union[str, float]]]]]]
        """

        if kwargs.get("query", None):
            self._ensure_query_prompty_loaded()

        return super().__call__(*args, **kwargs)

    def _ensure_query_prompty_loaded(self):
        """Switch to the query prompty file if not already loaded."""
    def _load_flow(self, prompty_filename: str, **kwargs) -> AsyncPrompty:
        """Load the Prompty flow from the specified file.
        :param prompty_filename: The filename of the Prompty flow to load.
        :type prompty_filename: str
        :return: The loaded Prompty flow.
        :rtype: AsyncPrompty
        """

        current_dir = os.path.dirname(__file__)
        prompty_path = os.path.join(current_dir, self._PROMPTY_FILE_WITH_QUERY)
        prompty_path = os.path.join(current_dir, prompty_filename)

        self._prompty_file = prompty_path
        prompty_model_config = construct_prompty_model_config(
            validate_model_config(self._model_config),
            self._DEFAULT_OPEN_API_VERSION,
            UserAgentSingleton().value,
        )
        self._flow = AsyncPrompty.load(source=self._prompty_file, model=prompty_model_config)
        flow = AsyncPrompty.load(
            source=prompty_path,
            model=prompty_model_config,
            is_reasoning_model=self._is_reasoning_model,
            **kwargs,
        )

        return flow

    def _has_context(self, eval_input: dict) -> bool:
        """
        Return True if eval_input contains a non-empty 'context' field.
        Treats None, empty strings, empty lists, and lists of empty strings as no context.
        """
        context = eval_input.get("context", None)
        return self._validate_context(context)

    def _validate_context(self, context) -> bool:
        """
        Validate if the provided context is non-empty and meaningful.
        Treats None, empty strings, empty lists, and lists of empty strings as no context.
        :param context: The context to validate
        :type context: Union[str, List, None]
        :return: True if context is valid and non-empty, False otherwise
        :rtype: bool
        """
        if not context:
            return False
        if context == "<>":  # Special marker for no context
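The hunk cuts off inside `_validate_context`, so here is a self-contained sketch of the edge cases its docstring describes. The `"<>"` marker check is taken from the visible line; the list handling is an assumption based on "lists of empty strings" being treated as no context, and `validate_context` is an illustrative stand-alone name.

```python
from typing import List, Union


def validate_context(context: Union[str, List, None]) -> bool:
    if not context:          # None, "", and [] all fail the truthiness check
        return False
    if context == "<>":      # special marker for "no context"
        return False
    if isinstance(context, list):
        # Assumed intent: a list counts as context only if some entry is a
        # non-empty string.
        return any(isinstance(item, str) and item.strip() for item in context)
    return bool(str(context).strip())


assert validate_context("relevant passage") is True
assert validate_context(["", "   "]) is False
assert validate_context(None) is False
assert validate_context("<>") is False
```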
@@ -239,7 +266,7 @@ def _has_context(self, eval_input: dict) -> bool:
    @override
    async def _do_eval(self, eval_input: Dict) -> Dict[str, Union[float, str]]:
        if eval_input.get("query", None) is None:
            return await super()._do_eval(eval_input)
            return await self._do_eval_with_flow(eval_input, self._flow_no_query)

        contains_context = self._has_context(eval_input)
@@ -254,7 +281,85 @@ async def _do_eval(self, eval_input: Dict) -> Dict[str, Union[float, str]]:
        }

        # Replace and call the parent method
        return await super()._do_eval(simplified_eval_input)
        return await self._do_eval_with_flow(simplified_eval_input, self._flow_with_query)

    async def _do_eval_with_flow(self, eval_input: Dict, flow: AsyncPrompty) -> Dict[str, Union[float, str]]:  # type: ignore[override]
        """Do an evaluation.

        NOTE: This is a copy of the parent method, with an added flow parameter to allow choosing between two flows.
        :param eval_input: The input to the evaluator. Expected to contain
            whatever inputs are needed for the flow method, including context
            and other fields depending on the child class.
        :type eval_input: Dict
        :param flow: The AsyncPrompty flow to use for evaluation.
        :type flow: AsyncPrompty
        :return: The evaluation result.
        :rtype: Dict
        """
        if "query" not in eval_input and "response" not in eval_input:
            raise EvaluationException(
                message="Only text conversation inputs are supported.",
                internal_message="Only text conversation inputs are supported.",
                blame=ErrorBlame.USER_ERROR,
                category=ErrorCategory.INVALID_VALUE,
                target=ErrorTarget.CONVERSATION,
            )
        # Call the prompty flow to get the evaluation result.
        prompty_output_dict = await flow(timeout=self._LLM_CALL_TIMEOUT, **eval_input)

        score = math.nan
        if prompty_output_dict:
            llm_output = prompty_output_dict.get("llm_output", "")
            input_token_count = prompty_output_dict.get("input_token_count", 0)
            output_token_count = prompty_output_dict.get("output_token_count", 0)
            total_token_count = prompty_output_dict.get("total_token_count", 0)
            finish_reason = prompty_output_dict.get("finish_reason", "")
            model_id = prompty_output_dict.get("model_id", "")
            sample_input = prompty_output_dict.get("sample_input", "")
            sample_output = prompty_output_dict.get("sample_output", "")
            # Parse out score and reason from evaluators known to possess them.
            if self._result_key in PROMPT_BASED_REASON_EVALUATORS:
                score, reason = parse_quality_evaluator_reason_score(llm_output)
                binary_result = self._get_binary_result(score)
                return {
                    self._result_key: float(score),
                    f"gpt_{self._result_key}": float(score),
                    f"{self._result_key}_reason": reason,
                    f"{self._result_key}_result": binary_result,
                    f"{self._result_key}_threshold": self._threshold,
                    f"{self._result_key}_prompt_tokens": input_token_count,
                    f"{self._result_key}_completion_tokens": output_token_count,
                    f"{self._result_key}_total_tokens": total_token_count,
                    f"{self._result_key}_finish_reason": finish_reason,
                    f"{self._result_key}_model": model_id,
                    f"{self._result_key}_sample_input": sample_input,
                    f"{self._result_key}_sample_output": sample_output,
                }
            match = re.search(r"\d", llm_output)
            if match:
                score = float(match.group())
                binary_result = self._get_binary_result(score)
                return {
                    self._result_key: float(score),
                    f"gpt_{self._result_key}": float(score),
                    f"{self._result_key}_result": binary_result,
                    f"{self._result_key}_threshold": self._threshold,
                    f"{self._result_key}_prompt_tokens": input_token_count,
                    f"{self._result_key}_completion_tokens": output_token_count,
                    f"{self._result_key}_total_tokens": total_token_count,
                    f"{self._result_key}_finish_reason": finish_reason,
                    f"{self._result_key}_model": model_id,
                    f"{self._result_key}_sample_input": sample_input,
                    f"{self._result_key}_sample_output": sample_output,
                }

        binary_result = self._get_binary_result(score)
        return {
            self._result_key: float(score),
            f"gpt_{self._result_key}": float(score),
            f"{self._result_key}_result": binary_result,
            f"{self._result_key}_threshold": self._threshold,
        }

    async def _real_call(self, **kwargs):
        """The asynchronous call where real end-to-end evaluation logic is performed.
@@ -272,22 +377,27 @@ async def _real_call(self, **kwargs):
                return {
                    self._result_key: self._NOT_APPLICABLE_RESULT,
                    f"{self._result_key}_result": "pass",
                    f"{self._result_key}_threshold": self.threshold,
                    f"{self._result_key}_threshold": self._threshold,
                    f"{self._result_key}_reason": f"Supported tools were not called. Supported tools for groundedness are {self._SUPPORTED_TOOLS}.",
                }
            else:
                raise ex

    def _is_single_entry(self, value):
        """Determine if the input value represents a single entry; if unsure, False is returned."""
        if isinstance(value, str):
            return True
        if isinstance(value, list) and len(value) == 1:
            return True
        return False

    def _convert_kwargs_to_eval_input(self, **kwargs):
| if kwargs.get("context") or kwargs.get("conversation"): | ||
| return super()._convert_kwargs_to_eval_input(**kwargs) | ||
| query = kwargs.get("query") | ||
| response = kwargs.get("response") | ||
| tool_definitions = kwargs.get("tool_definitions") | ||
|
|
||
| if query and self._prompty_file != self._PROMPTY_FILE_WITH_QUERY: | ||
| self._ensure_query_prompty_loaded() | ||
|
|
||
| if (not query) or (not response): # or not tool_definitions: | ||
| msg = f"{type(self).__name__}: Either 'conversation' or individual inputs must be provided. For Agent groundedness 'query' and 'response' are required." | ||
| raise EvaluationException( | ||
|
|
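`_is_single_entry`, added in the hunk above, is what the new guard in the next hunk uses to decide whether the inputs are a bare single-turn pair; anything it is unsure about is treated as not single. A standalone sketch of the same check (the free function name is illustrative):

```python
from typing import Any


def is_single_entry(value: Any) -> bool:
    # Same logic as _is_single_entry above: a plain string or a one-element
    # list counts as a single entry; anything ambiguous is reported as False.
    if isinstance(value, str):
        return True
    if isinstance(value, list) and len(value) == 1:
        return True
    return False


print(is_single_entry("What is the refund policy?"))          # True
print(is_single_entry([{"role": "user", "content": "hi"}]))   # True
print(is_single_entry([{"role": "user"}, {"role": "assistant"}]))  # False
```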
@@ -298,7 +408,16 @@ def _convert_kwargs_to_eval_input(self, **kwargs):
            )
        context = self._get_context_from_agent_response(response, tool_definitions)

        filtered_response = self._filter_file_search_results(response)
        if not self._validate_context(context) and self._is_single_entry(response) and self._is_single_entry(query):
            msg = f"{type(self).__name__}: No valid context provided or could be extracted from the query or response."
            raise EvaluationException(
                message=msg,
                blame=ErrorBlame.USER_ERROR,
                category=ErrorCategory.NOT_APPLICABLE,
                target=ErrorTarget.GROUNDEDNESS_EVALUATOR,
            )

        filtered_response = self._filter_file_search_results(response) if self._validate_context(context) else response
        return super()._convert_kwargs_to_eval_input(response=filtered_response, context=context, query=query)

    def _filter_file_search_results(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
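Taken together, the new guard in `_convert_kwargs_to_eval_input` makes the no-context single-turn case fail fast instead of producing a meaningless score. Below is a simplified sketch of that flow, reusing the `validate_context` and `is_single_entry` helpers sketched earlier; `convert_inputs` is an illustrative name, `ValueError` stands in for the SDK's `EvaluationException`, and the file-search filtering step is stubbed out.

```python
from typing import Any, Dict, List, Optional, Union


def filter_file_search_results(response: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    # Stub for _filter_file_search_results; the real method strips file-search
    # tool results out of the agent response.
    return response


def convert_inputs(
    query: Union[str, List[Dict[str, Any]]],
    response: List[Dict[str, Any]],
    context: Optional[Union[str, List[Any]]],
) -> Dict[str, Any]:
    # New edge case: a single-turn query/response with no usable context is
    # "not applicable" for groundedness, so raise instead of scoring it.
    if not validate_context(context) and is_single_entry(response) and is_single_entry(query):
        raise ValueError("No valid context provided or could be extracted from the query or response.")
    # Only filter the response when usable context was actually extracted;
    # otherwise the response is passed through unchanged.
    filtered_response = filter_file_search_results(response) if validate_context(context) else response
    return {"query": query, "response": filtered_response, "context": context}
```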