Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
190 changes: 175 additions & 15 deletions tests/unit/vertex_adk/test_reasoning_engine_templates_adk.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@
import base64
import importlib
import json
import dataclasses
import os
from unittest import mock
from typing import Optional

from google import auth
import vertexai
Expand All @@ -25,6 +28,7 @@
from vertexai.preview import reasoning_engines
from google.genai import types
import pytest
import uuid


try:
Expand All @@ -44,6 +48,7 @@ def __init__(self, name: str, model: str):
_TEST_MODEL = "gemini-2.0-flash"
_TEST_USER_ID = "test_user_id"
_TEST_AGENT_NAME = "test_agent"
_TEST_AGENT = Agent(name=_TEST_AGENT_NAME, model=_TEST_MODEL)
_TEST_SESSION = {
"id": "ca18c25a-644b-4e13-9b24-78c150ec3eb9",
"app_name": "default-app-name",
Expand Down Expand Up @@ -92,15 +97,6 @@ def vertexai_init_mock():
yield vertexai_init_mock


@pytest.fixture
def cloud_trace_exporter_mock():
with mock.patch.object(
_utils,
"_import_cloud_trace_exporter_or_warn",
) as cloud_trace_exporter_mock:
yield cloud_trace_exporter_mock


@pytest.fixture
def tracer_provider_mock():
with mock.patch("opentelemetry.sdk.trace.TracerProvider") as tracer_provider_mock:
Expand All @@ -116,12 +112,53 @@ def simple_span_processor_mock():


@pytest.fixture
def mock_adk_version():
def cloud_trace_exporter_mock():
import sys
import opentelemetry

mock_cloud_trace_exporter = mock.Mock()

opentelemetry.exporter = type(sys)("exporter")
opentelemetry.exporter.cloud_trace = type(sys)("cloud_trace")
opentelemetry.exporter.cloud_trace.CloudTraceSpanExporter = (
mock_cloud_trace_exporter
)

sys.modules["opentelemetry.exporter"] = opentelemetry.exporter
sys.modules["opentelemetry.exporter.cloud_trace"] = (
opentelemetry.exporter.cloud_trace
)

yield mock_cloud_trace_exporter

del sys.modules["opentelemetry.exporter.cloud_trace"]
del sys.modules["opentelemetry.exporter"]


@pytest.fixture
def trace_provider_mock():
import opentelemetry.sdk.trace

with mock.patch.object(
opentelemetry.sdk.trace, "TracerProvider"
) as tracer_provider_mock:
yield tracer_provider_mock


@pytest.fixture
def default_instrumentor_builder_mock():
with mock.patch(
"google.cloud.aiplatform.vertexai.preview.reasoning_engines.templates.adk.get_adk_version",
return_value="1.5.0",
):
yield
"google.cloud.aiplatform.vertexai.preview.reasoning_engines.templates.adk._default_instrumentor_builder"
) as default_instrumentor_builder_mock:
yield default_instrumentor_builder_mock


@pytest.fixture
def adk_version_mock():
with mock.patch(
"google.cloud.aiplatform.vertexai.preview.reasoning_engines.templates.adk.get_adk_version"
) as adk_version_mock:
yield adk_version_mock


class _MockRunner:
Expand Down Expand Up @@ -520,6 +557,130 @@ async def test_async_search_memory(self):
)
assert len(response.memories) >= 1

@pytest.mark.parametrize(
"adk_version,enable_tracing,enable_telemetry,want_tracing_setup,want_logging_setup",
[
("1.16.0", False, False, False, False),
("1.16.0", False, True, False, True),
("1.16.0", False, None, False, False),
("1.16.0", True, False, False, False),
("1.16.0", True, True, True, True),
("1.16.0", True, None, True, False),
("1.16.0", None, False, False, False),
("1.16.0", None, True, False, True),
("1.16.0", None, None, False, False),
("1.17.0", False, False, False, False),
("1.17.0", False, True, False, True),
("1.17.0", False, None, False, False),
("1.17.0", True, False, False, False),
("1.17.0", True, True, True, True),
("1.17.0", True, None, True, False),
("1.17.0", None, False, False, False),
("1.17.0", None, True, True, True),
("1.17.0", None, None, False, False),
],
)
@mock.patch.dict(os.environ)
def test_default_instrumentor_enablement(
self,
adk_version: str,
enable_tracing: Optional[bool],
enable_telemetry: Optional[bool],
want_tracing_setup: bool,
want_logging_setup: bool,
default_instrumentor_builder_mock: mock.Mock,
adk_version_mock: mock.Mock,
):
# Arrange
adk_version_mock.return_value = adk_version
if enable_telemetry is not None:
os.environ["GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY"] = str(
enable_telemetry
)

app = reasoning_engines.AdkApp(agent=_TEST_AGENT, enable_tracing=enable_tracing)

# Act
app.set_up()

# Assert
default_instrumentor_builder_mock.assert_called_once_with(
_TEST_PROJECT,
enable_tracing=want_tracing_setup,
enable_logging=want_logging_setup,
)

@mock.patch.dict(
os.environ,
{
"GOOGLE_CLOUD_AGENT_ENGINE_ID": "test_agent_id",
"OTEL_RESOURCE_ATTRIBUTES": "some-attribute=some-value",
},
)
def test_tracing_setup(
self,
trace_provider_mock: mock.Mock,
cloud_trace_exporter_mock: mock.Mock,
monkeypatch: pytest.MonkeyPatch,
):
monkeypatch.setattr(
"uuid.uuid4", lambda: uuid.UUID("12345678123456781234567812345678")
)
monkeypatch.setattr("os.getpid", lambda: 123123123)
app = reasoning_engines.AdkApp(agent=_TEST_AGENT, enable_tracing=True)
app.set_up()

expected_attributes = {
"telemetry.sdk.language": "python",
"telemetry.sdk.name": "opentelemetry",
"telemetry.sdk.version": "1.36.0",
"gcp.project_id": "test-project",
"cloud.account.id": "test-project",
"service.name": "test_agent_id",
"cloud.resource_id": "//aiplatform.googleapis.com/projects/test-project/locations/us-central1/reasoningEngines/test_agent_id",
"service.instance.id": "12345678123456781234567812345678-123123123",
"cloud.region": "us-central1",
"some-attribute": "some-value",
}

@dataclasses.dataclass
class RegexMatchingAll:
keys: set[str]

def __eq__(self, regex: object) -> bool:
return isinstance(regex, str) and set(regex.split("|")) == self.keys

cloud_trace_exporter_mock.assert_called_once_with(
project_id=_TEST_PROJECT,
client=mock.ANY,
resource_regex=RegexMatchingAll(keys=set(expected_attributes.keys())),
)

assert (
trace_provider_mock.call_args.kwargs["resource"].attributes
== expected_attributes
)

@mock.patch.dict(os.environ)
def test_span_content_capture_disabled_by_default(self):
app = reasoning_engines.AdkApp(agent=_TEST_AGENT)
app.set_up()
assert os.environ["ADK_CAPTURE_MESSAGE_CONTENT_IN_SPANS"] == "false"

@mock.patch.dict(
os.environ, {"OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT": "true"}
)
def test_span_content_capture_disabled_with_env_var(self):
app = reasoning_engines.AdkApp(agent=_TEST_AGENT)
app.set_up()
assert os.environ["ADK_CAPTURE_MESSAGE_CONTENT_IN_SPANS"] == "false"

@mock.patch.dict(os.environ)
def test_span_content_capture_enabled_with_tracing(self):
app = reasoning_engines.AdkApp(agent=_TEST_AGENT, enable_tracing=True)
app.set_up()
assert os.environ["ADK_CAPTURE_MESSAGE_CONTENT_IN_SPANS"] == "true"

@pytest.mark.usefixtures("caplog")
def test_enable_tracing(
self,
Expand Down Expand Up @@ -584,7 +745,6 @@ def test_dump_event_for_json():
assert base64.b64decode(part["thought_signature"]) == raw_signature


@pytest.mark.usefixtures("mock_adk_version")
class TestAdkAppErrors:
def test_raise_get_session_not_found_error(self):
with pytest.raises(
Expand Down
Loading