Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 91 additions & 0 deletions tests/unit/vertexai/genai/test_evals.py
Original file line number Diff line number Diff line change
Expand Up @@ -1515,6 +1515,64 @@ def test_run_inference_with_litellm_parsing(
pd.testing.assert_frame_equal(call_kwargs["prompt_dataset"], mock_df)


@pytest.mark.usefixtures("google_auth_mock")
class TestEvalsMetricHandlers:
    """Unit tests for utility functions in _evals_metric_handlers."""

    @staticmethod
    def _tool_call_event(event_id):
        """Builds an Event whose content carries a single function-call part."""
        call_part = genai_types.Part(
            function_call=genai_types.FunctionCall(name="search", args={})
        )
        return vertexai_genai_types.evals.Event(
            event_id=event_id,
            content=genai_types.Content(parts=[call_part]),
        )

    @staticmethod
    def _text_event(event_id):
        """Builds an Event whose content is plain text (no tool usage)."""
        return vertexai_genai_types.evals.Event(
            event_id=event_id,
            content=genai_types.Content(parts=[genai_types.Part(text="hello")]),
        )

    def test_has_tool_call_with_tool_call(self):
        assert _evals_metric_handlers._has_tool_call([self._tool_call_event("1")])

    def test_has_tool_call_no_tool_call(self):
        assert not _evals_metric_handlers._has_tool_call([self._text_event("1")])

    def test_has_tool_call_empty_events(self):
        assert not _evals_metric_handlers._has_tool_call([])

    def test_has_tool_call_none_events(self):
        assert not _evals_metric_handlers._has_tool_call(None)

    def test_has_tool_call_mixed_events(self):
        # A text-only event followed by a tool-call event: one match suffices.
        events = [self._text_event("1"), self._tool_call_event("2")]
        assert _evals_metric_handlers._has_tool_call(events)


@pytest.mark.usefixtures("google_auth_mock")
class TestRunAgentInternal:
"""Unit tests for the _run_agent_internal function."""
Expand Down Expand Up @@ -3890,6 +3948,39 @@ def test_eval_case_to_agent_data_agent_info_empty(self):

assert agent_data.agent_config is None

@mock.patch.object(_evals_metric_handlers.logger, "warning")
def test_tool_use_quality_metric_no_tool_call_logs_warning(
self, mock_warning, mock_api_client_fixture
):
"""Tests that PredefinedMetricHandler warns for tool_use_quality_v1 if no tool call."""
metric = vertexai_genai_types.Metric(name="tool_use_quality_v1")
handler = _evals_metric_handlers.PredefinedMetricHandler(
module=evals.Evals(api_client_=mock_api_client_fixture), metric=metric
)
eval_case = vertexai_genai_types.EvalCase(
eval_case_id="case-no-tool-call",
prompt=genai_types.Content(parts=[genai_types.Part(text="Hello")]),
responses=[
vertexai_genai_types.ResponseCandidate(
response=genai_types.Content(parts=[genai_types.Part(text="Hi")])
)
],
intermediate_events=[
vertexai_genai_types.evals.Event(
event_id="event1",
content=genai_types.Content(
parts=[genai_types.Part(text="intermediate event")]
),
)
],
)
handler._build_request_payload(eval_case, response_index=0)
mock_warning.assert_called_once_with(
"Metric 'tool_use_quality_v1' requires tool usage in "
"'intermediate_events', but no tool usage was found for case %s.",
"case-no-tool-call",
)


@pytest.mark.usefixtures("google_auth_mock")
class TestLLMMetricHandlerPayload:
Expand Down
20 changes: 20 additions & 0 deletions vertexai/_genai/_evals_metric_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,18 @@
_MAX_RETRIES = 3


def _has_tool_call(intermediate_events: Optional[list[types.evals.Event]]) -> bool:
"""Checks if any event in intermediate_events has a function call."""
if not intermediate_events:
return False
for event in intermediate_events:
if event.content and event.content.parts:
for part in event.content.parts:
if hasattr(part, "function_call") and part.function_call:
return True
return False


def _extract_text_from_content(
content: Optional[genai_types.Content], warn_property: str = "text"
) -> Optional[str]:
Expand Down Expand Up @@ -903,6 +915,14 @@ def _build_request_payload(
f"Response content missing for candidate {response_index}."
)

if self.metric.name == "tool_use_quality_v1":
if not _has_tool_call(eval_case.intermediate_events):
logger.warning(
"Metric 'tool_use_quality_v1' requires tool usage in "
"'intermediate_events', but no tool usage was found for case %s.",
eval_case.eval_case_id,
)

reference_instance_data = None
if eval_case.reference:
reference_instance_data = PredefinedMetricHandler._content_to_instance_data(
Expand Down
Loading