105 changes: 105 additions & 0 deletions tests/unit/vertexai/genai/replays/test_custom_code_execution_metric.py
@@ -0,0 +1,105 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# pylint: disable=protected-access,bad-continuation,missing-function-docstring

from tests.unit.vertexai.genai.replays import pytest_helper
from vertexai._genai import types
import pandas as pd


def test_custom_code_execution(client):
    """Tests that custom code execution metric produces a correctly structured EvaluationResult."""

    code_snippet = """
def evaluate(instance):
    if instance['response'] == instance['reference']:
        return 1.0
    return 0.0
"""

    custom_metric = types.Metric(
        name="my_custom_code_metric",
        remote_custom_function=code_snippet,
    )

    prompts_df = pd.DataFrame(
        {
            "prompt": ["What is 2+2?", "What is 3+3?"],
            "response": ["4", "5"],
            "reference": ["4", "6"],
        }
    )

    eval_dataset = types.EvaluationDataset(
        eval_dataset_df=prompts_df,
        candidate_name="test_model",
    )

    evaluation_result = client.evals.evaluate(
        dataset=eval_dataset,
        metrics=[custom_metric],
    )

    assert isinstance(evaluation_result, types.EvaluationResult)

    assert evaluation_result.summary_metrics is not None
    assert evaluation_result.summary_metrics
    for summary in evaluation_result.summary_metrics:
        assert isinstance(summary, types.AggregatedMetricResult)
        assert summary.metric_name == "my_custom_code_metric"

    assert evaluation_result.eval_case_results is not None
    assert evaluation_result.eval_case_results
    for case_result in evaluation_result.eval_case_results:
        assert isinstance(case_result, types.EvalCaseResult)
        assert case_result.eval_case_index is not None
        assert case_result.response_candidate_results is not None


def test_custom_code_execution_batch_evaluate(client):
    """Tests that batch_evaluate() works with custom code execution metric."""

    code_snippet = """
def evaluate(instance):
    if instance['response'] == instance['reference']:
        return 1.0
    return 0.0
"""

    custom_metric = types.Metric(
        name="my_custom_code_metric",
        remote_custom_function=code_snippet,
    )

    eval_dataset = types.EvaluationDataset(
        gcs_source=types.GcsSource(
            uris=["gs://genai-eval-sdk-replay-test/test_data/inference_results.jsonl"]
        ),
    )

    evaluation_result = client.evals.batch_evaluate(
        dataset=eval_dataset,
        metrics=[custom_metric],
        dest="gs://genai-eval-sdk-replay-test/test_data/batch_eval_output",
    )

    assert evaluation_result is not None


pytestmark = pytest_helper.setup(
    file=__file__,
    globals_for_file=globals(),
    test_method="evals.evaluate",
)
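Note on the snippet passed as remote_custom_function: both tests send the code as a plain string, and the evaluation service (not the test process) runs the evaluate(instance) function it contains. As a hedged illustration only -- the exact fields exposed to the remote function are an assumption based on the string keys used above and on the EvaluationInstance payload built in _evals_metric_handlers.py below -- a slightly richer snippet might look like this:

# Hypothetical snippet a caller could pass as Metric.remote_custom_function.
# It assumes the remote executor calls evaluate(instance) with dict-style
# access to string-valued 'response' and 'reference' fields, as the test
# snippets above imply.
def evaluate(instance):
    response = (instance.get("response") or "").strip().lower()
    reference = (instance.get("reference") or "").strip().lower()
    if not reference:
        # Nothing to compare against; score as a failure.
        return 0.0
    if response == reference:
        # Exact (case-insensitive) match.
        return 1.0
    if reference in response:
        # Partial credit when the reference appears inside the response.
        return 0.5
    return 0.0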
144 changes: 142 additions & 2 deletions vertexai/_genai/_evals_metric_handlers.py
@@ -685,10 +685,9 @@ def get_metric_result(
            )
        except Exception as e:  # pylint: disable=broad-exception-caught
            logger.error(
                "Error processing metric %s for case %s: %s",
                "Error processing metric %s for case %s.",
                metric_name,
                eval_case.eval_case_id,
                e,
                exc_info=True,
            )
            return types.EvalCaseMetricResult(
@@ -1099,7 +1098,147 @@ def aggregate(
        )


class CustomCodeExecutionMetricHandler(MetricHandler):
    """Metric handler for custom code execution metrics."""

    def __init__(self, module: "evals.Evals", metric: types.Metric):
        super().__init__(module=module, metric=metric)

        if not self.metric.remote_custom_function:
            raise ValueError(
                f"CustomCodeExecutionMetricHandler for '{self.metric.name}' needs"
                " Metric.remote_custom_function to be set."
            )

    def _build_request_payload(
        self, eval_case: types.EvalCase, response_index: int
    ) -> dict[str, Any]:
        """Builds the request parameters for evaluate instances request."""
        if not eval_case.responses or response_index >= len(eval_case.responses):
            raise IndexError(f"response_index {response_index} is out of bounds.")

        response_content = eval_case.responses[response_index].response
        if not response_content:
            raise ValueError(
                f"Response content missing for candidate {response_index}."
            )

        reference_instance_data = None
        if eval_case.reference:
            reference_instance_data = PredefinedMetricHandler._content_to_instance_data(
                eval_case.reference.response
            )

        prompt_instance_data = PredefinedMetricHandler._content_to_instance_data(
            eval_case.prompt
        )

        instance_payload = types.EvaluationInstance(
            prompt=prompt_instance_data,
            response=PredefinedMetricHandler._content_to_instance_data(
                response_content
            ),
            reference=reference_instance_data,
        )

        return {
            "instance": instance_payload,
        }

    @override
    def get_metric_result(
        self, eval_case: types.EvalCase, response_index: int
    ) -> types.EvalCaseMetricResult:
        """Processes a single evaluation case for a specific custom code execution metric."""
        metric_name = self.metric.name
        try:
            payload = self._build_request_payload(eval_case, response_index)
            for attempt in range(_MAX_RETRIES):
                try:
                    api_response = self.module._evaluate_instances(
                        metrics=[self.metric],
                        instance=payload.get("instance"),
                    )
                    break
                except genai_errors.ClientError as e:
                    if e.code == 429:
                        logger.warning(
                            "Resource Exhausted error on attempt %d/%d: %s. Retrying in %s"
                            " seconds...",
                            attempt + 1,
                            _MAX_RETRIES,
                            e,
                            2**attempt,
                        )
                        if attempt == _MAX_RETRIES - 1:
                            return types.EvalCaseMetricResult(
                                metric_name=metric_name,
                                error_message=f"Resource exhausted after {_MAX_RETRIES} retries: {e}",
                            )
                        time.sleep(2**attempt)
                    else:
                        raise e

            if (
                api_response
                and hasattr(api_response, "metric_results")
                and api_response.metric_results
            ):
                result_data = api_response.metric_results[0]

                error_message = None
                if result_data.error and getattr(result_data.error, "code", None):
                    error_message = f"Error in metric result: {result_data.error}"
                return types.EvalCaseMetricResult(
                    metric_name=metric_name,
                    score=result_data.score,
                    explanation=result_data.explanation,
                    error_message=error_message,
                )
            else:
                logger.error(
                    "Metric results missing in API response for metric '%s'."
                    " API response: %s",
                    metric_name,
                    (
                        api_response.model_dump_json(exclude_none=True)
                        if api_response
                        else "None"
                    ),
                )
                return types.EvalCaseMetricResult(
                    metric_name=metric_name,
                    error_message="Metric results missing in API response.",
                )
        except Exception as e:  # pylint: disable=broad-exception-caught
            logger.error(
                "Error processing metric %s for case %s",
                metric_name,
                eval_case.eval_case_id,
                exc_info=True,
            )
            return types.EvalCaseMetricResult(
                metric_name=metric_name, error_message=str(e)
            )

    @override
    def aggregate(
        self, eval_case_metric_results: list[types.EvalCaseMetricResult]
    ) -> types.AggregatedMetricResult:
        """Aggregates the metric results for a custom code execution metric."""
        logger.debug(
            "Aggregating results for custom code execution metric: %s", self.metric.name
        )
        return _default_aggregate_scores(
            self.metric.name, eval_case_metric_results, calculate_pass_rate=True
        )


_METRIC_HANDLER_MAPPING = [
    (
        lambda m: hasattr(m, "remote_custom_function") and m.remote_custom_function,
        CustomCodeExecutionMetricHandler,
    ),
    (
        lambda m: m.custom_function and isinstance(m.custom_function, Callable),
        CustomMetricHandler,
@@ -1125,6 +1264,7 @@ def aggregate(
    TranslationMetricHandler,
    LLMMetricHandler,
    CustomMetricHandler,
    CustomCodeExecutionMetricHandler,
    PredefinedMetricHandler,
)

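The new (predicate, handler) pair sits at the top of _METRIC_HANDLER_MAPPING, so a metric with remote_custom_function set resolves to CustomCodeExecutionMetricHandler before the custom_function check is reached. The selection code itself is not shown in this diff; the sketch below only illustrates first-match dispatch, and _pick_handler_class is a hypothetical name used for illustration.

# Hypothetical first-match dispatch over _METRIC_HANDLER_MAPPING, shown only
# to illustrate why the ordering of the (predicate, handler) pairs matters.
def _pick_handler_class(metric):
    for predicate, handler_cls in _METRIC_HANDLER_MAPPING:
        if predicate(metric):
            return handler_cls
    raise ValueError(f"No metric handler matches metric {metric.name!r}.")

# A metric with remote_custom_function set would resolve to
# CustomCodeExecutionMetricHandler even if custom_function were also set,
# because its predicate is checked first.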
7 changes: 7 additions & 0 deletions vertexai/_genai/_transformers.py
@@ -60,6 +60,13 @@ def t_metrics(
"metric_spec_name": metric_name,
"metric_spec_parameters": metric.metric_spec_parameters,
}
# Custom Code Execution Metric
elif (
hasattr(metric, "remote_custom_function") and metric.remote_custom_function
):
metric_payload_item["custom_code_execution_spec"] = {
"evaluation_function": metric.remote_custom_function
}
# Pointwise metrics
elif hasattr(metric, "prompt_template") and metric.prompt_template:
pointwise_spec = {"metric_prompt_template": metric.prompt_template}
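For reference, the new branch forwards the metric's remote_custom_function string as the evaluation_function of a custom_code_execution_spec. A hedged sketch of the resulting payload entry follows; only the custom_code_execution_spec shape comes from the branch above, and any other structure on metric_payload_item is an assumption.

# Hypothetical illustration of the payload entry t_metrics would emit for a
# custom code execution metric; only the "custom_code_execution_spec" shape
# is taken from the diff above.
code_snippet = """
def evaluate(instance):
    return 1.0 if instance['response'] == instance['reference'] else 0.0
"""

metric_payload_item = {
    "custom_code_execution_spec": {
        "evaluation_function": code_snippet,
    },
}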