148 changes: 101 additions & 47 deletions tests/unit/vertexai/genai/replays/test_create_evaluation_run.py
@@ -19,53 +19,96 @@
from google.genai import types as genai_types
import pytest


def test_create_eval_run_data_source_evaluation_set(client):
"""Tests that create_evaluation_run() creates a correctly structured EvaluationRun."""
client._api_client._http_options.api_version = "v1beta1"
tool = genai_types.Tool(
function_declarations=[
genai_types.FunctionDeclaration(
name="get_weather",
description="Get weather in a location",
parameters={
"type": "object",
"properties": {"location": {"type": "string"}},
},
GCS_DEST = "gs://lakeyk-test-limited/eval_run_output"
UNIVERSAL_AR_METRIC = types.EvaluationRunMetric(
metric="universal_ar_v1",
metric_config=types.UnifiedMetric(
predefined_metric_spec=types.PredefinedMetricSpec(
metric_spec_name="universal_ar_v1",
)
),
)
FINAL_RESPONSE_QUALITY_METRIC = types.EvaluationRunMetric(
metric="final_response_quality_v1",
metric_config=types.UnifiedMetric(
predefined_metric_spec=types.PredefinedMetricSpec(
metric_spec_name="final_response_quality_v1",
)
),
)
LLM_METRIC = types.EvaluationRunMetric(
metric="llm_metric",
metric_config=types.UnifiedMetric(
llm_based_metric_spec=types.LLMBasedMetricSpec(
metric_prompt_template=(
"\nEvaluate the fluency of the response. Provide a score from 1-5."
)
]
)
evaluation_run = client.evals.create_evaluation_run(
name="test4",
display_name="test4",
dataset=types.EvaluationRunDataSource(
evaluation_set="projects/503583131166/locations/us-central1/evaluationSets/6619939608513740800"
),
agent_info=types.AgentInfo(
name="agent-1",
instruction="agent-1 instruction",
tool_declarations=[tool],
),
dest="gs://lakeyk-limited-bucket/eval_run_output",
)
assert isinstance(evaluation_run, types.EvaluationRun)
assert evaluation_run.display_name == "test4"
assert evaluation_run.state == types.EvaluationRunState.PENDING
assert isinstance(evaluation_run.data_source, types.EvaluationRunDataSource)
assert evaluation_run.data_source.evaluation_set == (
"projects/503583131166/locations/us-central1/evaluationSets/6619939608513740800"
)
assert evaluation_run.inference_configs[
"agent-1"
] == types.EvaluationRunInferenceConfig(
agent_config=types.EvaluationRunAgentConfig(
developer_instruction=genai_types.Content(
parts=[genai_types.Part(text="agent-1 instruction")]
),
tools=[tool],
)
)
assert evaluation_run.error is None
),
)


# TODO(b/431231205): Re-enable once Unified Metrics are in prod.
# def test_create_eval_run_data_source_evaluation_set(client):
# """Tests that create_evaluation_run() creates a correctly structured EvaluationRun."""
# client._api_client._http_options.base_url = (
# "https://us-central1-autopush-aiplatform.sandbox.googleapis.com/"
# )
# client._api_client._http_options.api_version = "v1beta1"
# tool = genai_types.Tool(
# function_declarations=[
# genai_types.FunctionDeclaration(
# name="get_weather",
# description="Get weather in a location",
# parameters={
# "type": "object",
# "properties": {"location": {"type": "string"}},
# },
# )
# ]
# )
# evaluation_run = client.evals.create_evaluation_run(
# name="test4",
# display_name="test4",
# dataset=types.EvaluationRunDataSource(
# evaluation_set="projects/503583131166/locations/us-central1/evaluationSets/6619939608513740800"
# ),
# dest=GCS_DEST,
# metrics=[
# UNIVERSAL_AR_METRIC,
# types.RubricMetric.FINAL_RESPONSE_QUALITY,
# LLM_METRIC
# ],
# agent_info=types.AgentInfo(
# name="agent-1",
# instruction="agent-1 instruction",
# tool_declarations=[tool],
# ),
# )
# assert isinstance(evaluation_run, types.EvaluationRun)
# assert evaluation_run.display_name == "test4"
# assert evaluation_run.state == types.EvaluationRunState.PENDING
# assert isinstance(evaluation_run.data_source, types.EvaluationRunDataSource)
# assert evaluation_run.data_source.evaluation_set == (
# "projects/503583131166/locations/us-central1/evaluationSets/6619939608513740800"
# )
# assert evaluation_run.evaluation_config == types.EvaluationRunConfig(
# output_config=genai_types.OutputConfig(
# gcs_destination=genai_types.GcsDestination(output_uri_prefix=GCS_DEST)
# ),
# metrics=[UNIVERSAL_AR_METRIC, FINAL_RESPONSE_QUALITY_METRIC, LLM_METRIC],
# )
# assert evaluation_run.inference_configs[
# "agent-1"
# ] == types.EvaluationRunInferenceConfig(
# agent_config=types.EvaluationRunAgentConfig(
# developer_instruction=genai_types.Content(
# parts=[genai_types.Part(text="agent-1 instruction")]
# ),
# tools=[tool],
# )
# )
# assert evaluation_run.error is None


def test_create_eval_run_data_source_bigquery_request_set(client):
@@ -84,7 +127,7 @@ def test_create_eval_run_data_source_bigquery_request_set(client):
},
)
),
dest="gs://lakeyk-limited-bucket/eval_run_output",
dest=GCS_DEST,
)
assert isinstance(evaluation_run, types.EvaluationRun)
assert evaluation_run.display_name == "test5"
@@ -101,6 +144,11 @@ def test_create_eval_run_data_source_bigquery_request_set(client):
},
)
)
assert evaluation_run.evaluation_config == types.EvaluationRunConfig(
output_config=genai_types.OutputConfig(
gcs_destination=genai_types.GcsDestination(output_uri_prefix=GCS_DEST)
),
)
assert evaluation_run.inference_configs is None
assert evaluation_run.error is None

@@ -220,7 +268,7 @@ async def test_create_eval_run_async(client):
},
)
),
dest="gs://lakeyk-limited-bucket/eval_run_output",
dest=GCS_DEST,
)
assert isinstance(evaluation_run, types.EvaluationRun)
assert evaluation_run.display_name == "test8"
@@ -233,6 +281,12 @@ async def test_create_eval_run_async(client):
"checkpoint_2": "checkpoint_2",
},
)
assert evaluation_run.evaluation_config == types.EvaluationRunConfig(
output_config=genai_types.OutputConfig(
gcs_destination=genai_types.GcsDestination(output_uri_prefix=GCS_DEST)
),
)
assert evaluation_run.error is None
assert evaluation_run.inference_configs is None
assert evaluation_run.error is None

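Note on the commented-out assertions above: the test passes types.RubricMetric.FINAL_RESPONSE_QUALITY but expects FINAL_RESPONSE_QUALITY_METRIC back, because prebuilt rubric metrics are resolved into explicit EvaluationRunMetric objects before the request is sent. A minimal sketch of that expected equivalence (illustrative only; the import path and the resolved metric name "final_response_quality_v1" are assumptions taken from this test, not confirmed API guarantees):

# Illustrative sketch, not part of the change: assumes the prebuilt rubric metric
# resolves to the predefined metric named "final_response_quality_v1".
from vertexai._genai import types  # assumed import path, mirroring the test module

expected = types.EvaluationRunMetric(
    metric="final_response_quality_v1",
    metric_config=types.UnifiedMetric(
        predefined_metric_spec=types.PredefinedMetricSpec(
            metric_spec_name="final_response_quality_v1",
        )
    ),
)
# After create_evaluation_run(..., metrics=[types.RubricMetric.FINAL_RESPONSE_QUALITY, ...]),
# the run's evaluation_config.metrics is expected to contain `expected` rather than the
# lazy RubricMetric reference.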
67 changes: 67 additions & 0 deletions vertexai/_genai/_evals_common.py
@@ -933,6 +933,73 @@ def _resolve_dataset_inputs(
return processed_eval_dataset, num_response_candidates


def _resolve_evaluation_run_metrics(
metrics: list[types.EvaluationRunMetric], api_client: Any
) -> list[types.EvaluationRunMetric]:
"""Resolves a list of evaluation run metric instances, loading RubricMetric if necessary."""
if not metrics:
return []
resolved_metrics_list = []
for metric_instance in metrics:
if isinstance(metric_instance, types.EvaluationRunMetric):
resolved_metrics_list.append(metric_instance)
elif isinstance(metric_instance, _evals_utils.LazyLoadedPrebuiltMetric):
try:
resolved_metric = metric_instance.resolve(api_client=api_client)
if resolved_metric.name:
resolved_metrics_list.append(
types.EvaluationRunMetric(
metric=resolved_metric.name,
metric_config=types.UnifiedMetric(
predefined_metric_spec=types.PredefinedMetricSpec(
metric_spec_name=resolved_metric.name,
)
),
)
)
except Exception as e:
logger.error(
"Failed to resolve RubricMetric %s@%s: %s",
metric_instance.name,
metric_instance.version,
e,
)
raise
else:
try:
metric_name_str = str(metric_instance)
lazy_metric_instance = getattr(
_evals_utils.RubricMetric, metric_name_str.upper()
)
if isinstance(
lazy_metric_instance, _evals_utils.LazyLoadedPrebuiltMetric
):
resolved_metric = lazy_metric_instance.resolve(
api_client=api_client
)
if resolved_metric.name:
resolved_metrics_list.append(
types.EvaluationRunMetric(
metric=resolved_metric.name,
metric_config=types.UnifiedMetric(
predefined_metric_spec=types.PredefinedMetricSpec(
metric_spec_name=resolved_metric.name,
)
),
)
)
else:
raise TypeError(
f"RubricMetric.{metric_name_str.upper()} cannot be resolved."
)
except AttributeError as exc:
raise TypeError(
"Unsupported metric type or invalid RubricMetric name:"
f" {metric_instance}"
) from exc
return resolved_metrics_list


def _resolve_metrics(
metrics: list[types.Metric], api_client: Any
) -> list[types.Metric]:
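A rough usage sketch of the _resolve_evaluation_run_metrics helper added above, showing the three input kinds it normalizes. This is illustrative only: the client construction, the import paths, the bare-string lookup name, and the availability of RubricMetric.FINAL_RESPONSE_QUALITY are assumptions taken from the surrounding code and tests.

# Sketch: normalizing mixed metric inputs into EvaluationRunMetric objects.
import vertexai  # assumed: recent SDK exposing vertexai.Client
from vertexai._genai import _evals_common, types  # assumed import paths

# Assumed placeholder project/location; the helper only needs the client's API client.
api_client = vertexai.Client(project="my-project", location="us-central1")._api_client

metrics = [
    # 1. Already an EvaluationRunMetric: passed through unchanged.
    types.EvaluationRunMetric(
        metric="universal_ar_v1",
        metric_config=types.UnifiedMetric(
            predefined_metric_spec=types.PredefinedMetricSpec(
                metric_spec_name="universal_ar_v1",
            )
        ),
    ),
    # 2. A lazily loaded prebuilt metric: resolved via the API and wrapped in an
    #    EvaluationRunMetric carrying a PredefinedMetricSpec.
    types.RubricMetric.FINAL_RESPONSE_QUALITY,
    # 3. A bare string: looked up as RubricMetric.<NAME.upper()> and resolved the same
    #    way; an unknown name raises TypeError.
    "final_response_quality",
]

resolved = _evals_common._resolve_evaluation_run_metrics(metrics, api_client)
# Each entry in `resolved` is a types.EvaluationRunMetric whose predefined_metric_spec
# names the resolved metric.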
31 changes: 25 additions & 6 deletions vertexai/_genai/evals.py
@@ -230,6 +230,9 @@ def _EvaluationRun_from_vertex(
getv(from_object, ["evaluationResults"]),
)

if getv(from_object, ["evaluationConfig"]) is not None:
setv(to_object, ["evaluation_config"], getv(from_object, ["evaluationConfig"]))

if getv(from_object, ["inferenceConfigs"]) is not None:
setv(to_object, ["inference_configs"], getv(from_object, ["inferenceConfigs"]))

@@ -460,7 +463,7 @@ def _create_evaluation_run(
name: Optional[str] = None,
display_name: Optional[str] = None,
data_source: types.EvaluationRunDataSourceOrDict,
evaluation_config: genai_types.EvaluationConfigOrDict,
evaluation_config: types.EvaluationRunConfigOrDict,
config: Optional[types.CreateEvaluationRunConfigOrDict] = None,
inference_configs: Optional[
dict[str, types.EvaluationRunInferenceConfigOrDict]
@@ -1306,9 +1309,12 @@ def create_evaluation_run(
self,
*,
name: str,
display_name: Optional[str] = None,
dataset: Union[types.EvaluationRunDataSource, types.EvaluationDataset],
dest: str,
display_name: Optional[str] = None,
metrics: Optional[
list[types.EvaluationRunMetricOrDict]
] = None, # TODO: Make required unified metrics available in prod.
agent_info: Optional[types.AgentInfo] = None,
config: Optional[types.CreateEvaluationRunConfigOrDict] = None,
) -> types.EvaluationRun:
@@ -1328,7 +1334,12 @@ def create_evaluation_run(
output_config = genai_types.OutputConfig(
gcs_destination=genai_types.GcsDestination(output_uri_prefix=dest)
)
evaluation_config = genai_types.EvaluationConfig(output_config=output_config)
resolved_metrics = _evals_common._resolve_evaluation_run_metrics(
metrics, self._api_client
)
evaluation_config = types.EvaluationRunConfig(
output_config=output_config, metrics=resolved_metrics
)
inference_configs = {}
if agent_info:
logger.warning(
@@ -1554,7 +1565,7 @@ async def _create_evaluation_run(
name: Optional[str] = None,
display_name: Optional[str] = None,
data_source: types.EvaluationRunDataSourceOrDict,
evaluation_config: genai_types.EvaluationConfigOrDict,
evaluation_config: types.EvaluationRunConfigOrDict,
config: Optional[types.CreateEvaluationRunConfigOrDict] = None,
inference_configs: Optional[
dict[str, types.EvaluationRunInferenceConfigOrDict]
@@ -2103,9 +2114,12 @@ async def create_evaluation_run(
self,
*,
name: str,
display_name: Optional[str] = None,
dataset: Union[types.EvaluationRunDataSource, types.EvaluationDataset],
dest: str,
display_name: Optional[str] = None,
metrics: Optional[
list[types.EvaluationRunMetricOrDict]
] = None, # TODO: Make required unified metrics available in prod.
agent_info: Optional[types.AgentInfo] = None,
config: Optional[types.CreateEvaluationRunConfigOrDict] = None,
) -> types.EvaluationRun:
@@ -2125,7 +2139,12 @@ async def create_evaluation_run(
output_config = genai_types.OutputConfig(
gcs_destination=genai_types.GcsDestination(output_uri_prefix=dest)
)
evaluation_config = genai_types.EvaluationConfig(output_config=output_config)
resolved_metrics = _evals_common._resolve_evaluation_run_metrics(
metrics, self._api_client
)
evaluation_config = types.EvaluationRunConfig(
output_config=output_config, metrics=resolved_metrics
)
inference_configs = {}
if agent_info:
logger.warning(
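Taken together, these changes let callers attach metrics when creating an evaluation run. A minimal end-to-end sketch of the new call shape (project, location, evaluation-set resource name, and GCS bucket are placeholders; the metric definitions mirror the replay-test constants and are assumptions, not guaranteed names):

# Hypothetical usage of the new `metrics` parameter; all resource names below are
# placeholders.
import vertexai  # assumed: recent SDK exposing vertexai.Client
from vertexai._genai import types  # assumed import path

client = vertexai.Client(project="my-project", location="us-central1")

llm_metric = types.EvaluationRunMetric(
    metric="llm_metric",
    metric_config=types.UnifiedMetric(
        llm_based_metric_spec=types.LLMBasedMetricSpec(
            metric_prompt_template=(
                "Evaluate the fluency of the response. Provide a score from 1-5."
            )
        )
    ),
)

evaluation_run = client.evals.create_evaluation_run(
    name="my-eval-run",
    display_name="my-eval-run",
    dataset=types.EvaluationRunDataSource(
        evaluation_set="projects/PROJECT_NUMBER/locations/us-central1/evaluationSets/SET_ID"
    ),
    dest="gs://my-bucket/eval_run_output",
    # Explicit EvaluationRunMetric objects and prebuilt RubricMetric references can be
    # mixed; both are normalized by _resolve_evaluation_run_metrics before the request.
    metrics=[llm_metric, types.RubricMetric.FINAL_RESPONSE_QUALITY],
)
print(evaluation_run.state)  # expected to start as EvaluationRunState.PENDING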