diff --git a/tests/unit/vertex_adk/test_reasoning_engine_templates_adk.py b/tests/unit/vertex_adk/test_reasoning_engine_templates_adk.py index 5bf3bba293..c6e9c02e7c 100644 --- a/tests/unit/vertex_adk/test_reasoning_engine_templates_adk.py +++ b/tests/unit/vertex_adk/test_reasoning_engine_templates_adk.py @@ -16,7 +16,10 @@ import base64 import importlib import json +import dataclasses +import os from unittest import mock +from typing import Optional from google import auth import vertexai @@ -25,6 +28,7 @@ from vertexai.preview import reasoning_engines from google.genai import types import pytest +import uuid try: @@ -44,6 +48,7 @@ def __init__(self, name: str, model: str): _TEST_MODEL = "gemini-2.0-flash" _TEST_USER_ID = "test_user_id" _TEST_AGENT_NAME = "test_agent" +_TEST_AGENT = Agent(name=_TEST_AGENT_NAME, model=_TEST_MODEL) _TEST_SESSION = { "id": "ca18c25a-644b-4e13-9b24-78c150ec3eb9", "app_name": "default-app-name", @@ -92,15 +97,6 @@ def vertexai_init_mock(): yield vertexai_init_mock -@pytest.fixture -def cloud_trace_exporter_mock(): - with mock.patch.object( - _utils, - "_import_cloud_trace_exporter_or_warn", - ) as cloud_trace_exporter_mock: - yield cloud_trace_exporter_mock - - @pytest.fixture def tracer_provider_mock(): with mock.patch("opentelemetry.sdk.trace.TracerProvider") as tracer_provider_mock: @@ -116,12 +112,53 @@ def simple_span_processor_mock(): @pytest.fixture -def mock_adk_version(): +def cloud_trace_exporter_mock(): + import sys + import opentelemetry + + mock_cloud_trace_exporter = mock.Mock() + + opentelemetry.exporter = type(sys)("exporter") + opentelemetry.exporter.cloud_trace = type(sys)("cloud_trace") + opentelemetry.exporter.cloud_trace.CloudTraceSpanExporter = ( + mock_cloud_trace_exporter + ) + + sys.modules["opentelemetry.exporter"] = opentelemetry.exporter + sys.modules["opentelemetry.exporter.cloud_trace"] = ( + opentelemetry.exporter.cloud_trace + ) + + yield mock_cloud_trace_exporter + + del 
sys.modules["opentelemetry.exporter.cloud_trace"] + del sys.modules["opentelemetry.exporter"] + + +@pytest.fixture +def trace_provider_mock(): + import opentelemetry.sdk.trace + + with mock.patch.object( + opentelemetry.sdk.trace, "TracerProvider" + ) as tracer_provider_mock: + yield tracer_provider_mock + + +@pytest.fixture +def default_instrumentor_builder_mock(): with mock.patch( - "google.cloud.aiplatform.vertexai.preview.reasoning_engines.templates.adk.get_adk_version", - return_value="1.5.0", - ): - yield + "google.cloud.aiplatform.vertexai.preview.reasoning_engines.templates.adk._default_instrumentor_builder" + ) as default_instrumentor_builder_mock: + yield default_instrumentor_builder_mock + + +@pytest.fixture +def adk_version_mock(): + with mock.patch( + "google.cloud.aiplatform.vertexai.preview.reasoning_engines.templates.adk.get_adk_version" + ) as adk_version_mock: + yield adk_version_mock class _MockRunner: @@ -520,6 +557,130 @@ async def test_async_search_memory(self): ) assert len(response.memories) >= 1 + @pytest.mark.parametrize( + "adk_version,enable_tracing,enable_telemetry,want_tracing_setup,want_logging_setup", + [ + ("1.16.0", False, False, False, False), + ("1.16.0", False, True, False, True), + ("1.16.0", False, None, False, False), + ("1.16.0", True, False, False, False), + ("1.16.0", True, True, True, True), + ("1.16.0", True, None, True, False), + ("1.16.0", None, False, False, False), + ("1.16.0", None, True, False, True), + ("1.16.0", None, None, False, False), + ("1.17.0", False, False, False, False), + ("1.17.0", False, True, False, True), + ("1.17.0", False, None, False, False), + ("1.17.0", True, False, False, False), + ("1.17.0", True, True, True, True), + ("1.17.0", True, None, True, False), + ("1.17.0", None, False, False, False), + ("1.17.0", None, True, True, True), + ("1.17.0", None, None, False, False), + ], + ) + @mock.patch.dict(os.environ) + def test_default_instrumentor_enablement( + self, + adk_version: str, + 
enable_tracing: Optional[bool], + enable_telemetry: Optional[bool], + want_tracing_setup: bool, + want_logging_setup: bool, + default_instrumentor_builder_mock: mock.Mock, + adk_version_mock: mock.Mock, + ): + # Arrange + adk_version_mock.return_value = adk_version + if enable_telemetry is not None: + os.environ["GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY"] = str( + enable_telemetry + ) + + app = reasoning_engines.AdkApp(agent=_TEST_AGENT, enable_tracing=enable_tracing) + + # Act + app.set_up() + + # Assert + default_instrumentor_builder_mock.assert_called_once_with( + _TEST_PROJECT, + enable_tracing=want_tracing_setup, + enable_logging=want_logging_setup, + ) + + @mock.patch.dict( + os.environ, + { + "GOOGLE_CLOUD_AGENT_ENGINE_ID": "test_agent_id", + "OTEL_RESOURCE_ATTRIBUTES": "some-attribute=some-value", + }, + ) + def test_tracing_setup( + self, + trace_provider_mock: mock.Mock, + cloud_trace_exporter_mock: mock.Mock, + monkeypatch: pytest.MonkeyPatch, + ): + monkeypatch.setattr( + "uuid.uuid4", lambda: uuid.UUID("12345678123456781234567812345678") + ) + monkeypatch.setattr("os.getpid", lambda: 123123123) + app = reasoning_engines.AdkApp(agent=_TEST_AGENT, enable_tracing=True) + app.set_up() + + expected_attributes = { + "telemetry.sdk.language": "python", + "telemetry.sdk.name": "opentelemetry", + "telemetry.sdk.version": "1.36.0", + "gcp.project_id": "test-project", + "cloud.account.id": "test-project", + "service.name": "test_agent_id", + "cloud.resource_id": "//aiplatform.googleapis.com/projects/test-project/locations/us-central1/reasoningEngines/test_agent_id", + "service.instance.id": "12345678123456781234567812345678-123123123", + "cloud.region": "us-central1", + "some-attribute": "some-value", + } + + @dataclasses.dataclass + class RegexMatchingAll: + keys: set[str] + + def __eq__(self, regex: object) -> bool: + return isinstance(regex, str) and set(regex.split("|")) == self.keys + + cloud_trace_exporter_mock.assert_called_once_with( + 
project_id=_TEST_PROJECT, + client=mock.ANY, + resource_regex=RegexMatchingAll(keys=set(expected_attributes.keys())), + ) + + assert ( + trace_provider_mock.call_args.kwargs["resource"].attributes + == expected_attributes + ) + + @mock.patch.dict(os.environ) + def test_span_content_capture_disabled_by_default(self): + app = reasoning_engines.AdkApp(agent=_TEST_AGENT) + app.set_up() + assert os.environ["ADK_CAPTURE_MESSAGE_CONTENT_IN_SPANS"] == "false" + + @mock.patch.dict( + os.environ, {"OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT": "true"} + ) + def test_span_content_capture_disabled_with_env_var(self): + app = reasoning_engines.AdkApp(agent=_TEST_AGENT) + app.set_up() + assert os.environ["ADK_CAPTURE_MESSAGE_CONTENT_IN_SPANS"] == "false" + + @mock.patch.dict(os.environ) + def test_span_content_capture_enabled_with_tracing(self): + app = reasoning_engines.AdkApp(agent=_TEST_AGENT, enable_tracing=True) + app.set_up() + assert os.environ["ADK_CAPTURE_MESSAGE_CONTENT_IN_SPANS"] == "true" + @pytest.mark.usefixtures("caplog") def test_enable_tracing( self, @@ -584,7 +745,6 @@ def test_dump_event_for_json(): assert base64.b64decode(part["thought_signature"]) == raw_signature -@pytest.mark.usefixtures("mock_adk_version") class TestAdkAppErrors: def test_raise_get_session_not_found_error(self): with pytest.raises( diff --git a/vertexai/preview/reasoning_engines/templates/adk.py b/vertexai/preview/reasoning_engines/templates/adk.py index 2d39ca6cdd..9cc9dc876c 100644 --- a/vertexai/preview/reasoning_engines/templates/adk.py +++ b/vertexai/preview/reasoning_engines/templates/adk.py @@ -222,31 +222,121 @@ def dump(self) -> Dict[str, Any]: return result -def _default_instrumentor_builder(project_id: str): - from vertexai.agent_engines import _utils - - cloud_trace_exporter = _utils._import_cloud_trace_exporter_or_warn() - cloud_trace_v2 = _utils._import_cloud_trace_v2_or_warn() - opentelemetry = _utils._import_opentelemetry_or_warn() - opentelemetry_sdk_trace = 
_utils._import_opentelemetry_sdk_trace_or_warn() - if all( - ( - cloud_trace_exporter, - cloud_trace_v2, - opentelemetry, - opentelemetry_sdk_trace, +def _warn(msg: str): + if not hasattr(_warn, "_LOGGER"): + from google.cloud.aiplatform import base + + _warn._LOGGER = base.Logger( + __name__ + ) # pyright: ignore[reportFunctionMemberAccess] + + _warn._LOGGER.warning(msg) # pyright: ignore[reportFunctionMemberAccess] + + +def _default_instrumentor_builder( + project_id: str, + *, + enable_tracing: bool = False, + enable_logging: bool = False, +): + if not enable_tracing and not enable_logging: + return None + + import os + + def _warn_missing_dependency( + package: str, + *, + needed_for_logging: bool = False, + needed_for_tracing: bool = False, + ) -> None: + _warn( + f"{package} is not installed. Please call 'pip install google-cloud-aiplatform[agent_engines]'." ) - ): + MISSING_TRACE_IMPORT_ERROR_MESSAGE = "proceeding with tracing disabled because not all packages (i.e. `google-cloud-trace`, `opentelemetry-sdk`, `opentelemetry-exporter-gcp-trace`) for tracing have been installed" + MISSING_LOGGING_IMPORT_ERROR_MESSAGE = "proceeding with logging disabled because not all packages (i.e. 
`google-cloud-logging`, `opentelemetry-sdk`, `opentelemetry-exporter-gcp-logging`) for logging have been installed" + + if needed_for_tracing and enable_tracing: + _warn(MISSING_TRACE_IMPORT_ERROR_MESSAGE) + if needed_for_logging and enable_logging: + _warn(MISSING_LOGGING_IMPORT_ERROR_MESSAGE) + return None + + def _detect_cloud_resource_id(project_id: str) -> Optional[str]: + location = os.getenv("GOOGLE_CLOUD_LOCATION", None) + agent_engine_id = os.getenv("GOOGLE_CLOUD_AGENT_ENGINE_ID", None) + if all(v is not None for v in (location, agent_engine_id)): + return f"//aiplatform.googleapis.com/projects/{project_id}/locations/{location}/reasoningEngines/{agent_engine_id}" + return None + + try: + import opentelemetry + import opentelemetry.trace + import opentelemetry._logs + import opentelemetry._events + except (ImportError, AttributeError): + return _warn_missing_dependency( + "opentelemetry-api", needed_for_tracing=True, needed_for_logging=True + ) + + try: + import opentelemetry.sdk.resources + import opentelemetry.sdk.trace + import opentelemetry.sdk.trace.export + import opentelemetry.sdk._logs + import opentelemetry.sdk._logs.export + import opentelemetry.sdk._events + except (ImportError, AttributeError): + return _warn_missing_dependency( + "opentelemetry-sdk", needed_for_tracing=True, needed_for_logging=True + ) + + import uuid + + # Provide a set of resource attributes but allow overriding them with env + # variables like OTEL_RESOURCE_ATTRIBUTES and OTEL_SERVICE_NAME. 
+ cloud_resource_id = _detect_cloud_resource_id(project_id) + resource = opentelemetry.sdk.resources.Resource.create( + attributes={ + "gcp.project_id": project_id, + "cloud.account.id": project_id, + "service.name": os.getenv("GOOGLE_CLOUD_AGENT_ENGINE_ID", ""), + "service.instance.id": f"{uuid.uuid4().hex}-{os.getpid()}", + "cloud.region": os.getenv("GOOGLE_CLOUD_LOCATION", ""), + } + | ( + {"cloud.resource_id": cloud_resource_id} + if cloud_resource_id is not None + else {} + ) + ).merge(opentelemetry.sdk.resources.OTELResourceDetector().detect()) + + if enable_tracing: + try: + import opentelemetry.exporter.cloud_trace + except (ImportError, AttributeError): + return _warn_missing_dependency( + "opentelemetry-exporter-gcp-trace", needed_for_tracing=True + ) + + try: + import google.cloud.trace_v2 + except (ImportError, AttributeError): + return _warn_missing_dependency( + "google-cloud-trace", needed_for_tracing=True + ) + import google.auth credentials, _ = google.auth.default() - span_exporter = cloud_trace_exporter.CloudTraceSpanExporter( + span_exporter = opentelemetry.exporter.cloud_trace.CloudTraceSpanExporter( project_id=project_id, - client=cloud_trace_v2.TraceServiceClient( + client=google.cloud.trace_v2.TraceServiceClient( credentials=credentials.with_quota_project(project_id), ), + resource_regex="|".join(resource.attributes.keys()), ) - span_processor = opentelemetry_sdk_trace.export.BatchSpanProcessor( + span_processor = opentelemetry.sdk.trace.export.BatchSpanProcessor( span_exporter=span_exporter, ) tracer_provider = opentelemetry.trace.get_tracer_provider() @@ -258,40 +348,67 @@ def _default_instrumentor_builder(project_id: str): # If none of the above is set, we log a warning, and # create a tracer provider. if not tracer_provider: - from google.cloud.aiplatform import base - - _LOGGER = base.Logger(__name__) - _LOGGER.warning( + _warn( "No tracer provider. 
By default, " "we should get one of the following providers: " "OTEL_PYTHON_TRACER_PROVIDER, _TRACER_PROVIDER, " "or _PROXY_TRACER_PROVIDER." ) - tracer_provider = opentelemetry_sdk_trace.TracerProvider() + tracer_provider = opentelemetry.sdk.trace.TracerProvider(resource=resource) opentelemetry.trace.set_tracer_provider(tracer_provider) # Avoids AttributeError: # 'ProxyTracerProvider' and 'NoOpTracerProvider' objects has no # attribute 'add_span_processor'. + from vertexai.agent_engines import _utils + if _utils.is_noop_or_proxy_tracer_provider(tracer_provider): - tracer_provider = opentelemetry_sdk_trace.TracerProvider() + tracer_provider = opentelemetry.sdk.trace.TracerProvider(resource=resource) opentelemetry.trace.set_tracer_provider(tracer_provider) # Avoids OpenTelemetry client already exists error. _override_active_span_processor( tracer_provider, - opentelemetry_sdk_trace.SynchronousMultiSpanProcessor(), + opentelemetry.sdk.trace.SynchronousMultiSpanProcessor(), ) tracer_provider.add_span_processor(span_processor) - return None - else: - from google.cloud.aiplatform import base - _LOGGER = base.Logger(__name__) - _LOGGER.warning( - "enable_tracing=True but proceeding with tracing disabled " - "because not all packages (i.e. 
`google-cloud-trace`, `opentelemetry-sdk`, " - "`opentelemetry-exporter-gcp-trace`) for tracing have been installed" + if enable_logging: + try: + import opentelemetry.exporter.cloud_logging + except (ImportError, AttributeError): + return _warn_missing_dependency( + "opentelemetry-exporter-gcp-logging", needed_for_logging=True + ) + + logger_provider = opentelemetry.sdk._logs.LoggerProvider(resource=resource) + logger_provider.add_log_record_processor( + opentelemetry.sdk._logs.export.BatchLogRecordProcessor( + opentelemetry.exporter.cloud_logging.CloudLoggingExporter( + project_id=project_id, + default_log_name=os.getenv( + "GCP_DEFAULT_LOG_NAME", "adk-on-agent-engine" + ), + ), + ) + ) + event_logger_provider = opentelemetry.sdk._events.EventLoggerProvider( + logger_provider=logger_provider + ) + + opentelemetry._logs.set_logger_provider(logger_provider=logger_provider) + opentelemetry._events.set_event_logger_provider( + event_logger_provider=event_logger_provider ) - return None + + try: + from opentelemetry.instrumentation import google_genai + + google_genai.GoogleGenAiSdkInstrumentor().instrument() + except (ImportError, AttributeError): + _warn( + "telemetry enabled but proceeding without GenAI instrumentation, because not all packages (i.e. opentelemetry-instrumentation-google-genai) have been installed" + ) + + return None def _override_active_span_processor( @@ -506,10 +623,23 @@ def set_up(self): os.environ["GOOGLE_CLOUD_PROJECT"] = project location = self._tmpl_attrs.get("location") os.environ["GOOGLE_CLOUD_LOCATION"] = location + + # Disable content capture in custom ADK spans unless user enabled + # tracing explicitly with the old flag + # (this is to preserve compatibility with old behavior). 
if self._tmpl_attrs.get("enable_tracing"): - self._tmpl_attrs["instrumentor"] = _default_instrumentor_builder( - project_id=project - ) + os.environ["ADK_CAPTURE_MESSAGE_CONTENT_IN_SPANS"] = "true" + else: + os.environ["ADK_CAPTURE_MESSAGE_CONTENT_IN_SPANS"] = "false" + + enable_logging = bool(self._telemetry_enabled()) + + self._tmpl_attrs["instrumentor"] = _default_instrumentor_builder( + project, + enable_tracing=self._tracing_enabled(), + enable_logging=enable_logging, + ) + for key, value in self._tmpl_attrs.get("env_vars").items(): os.environ[key] = value if "GOOGLE_CLOUD_AGENT_ENGINE_ID" in os.environ: @@ -1238,3 +1368,52 @@ def register_operations(self) -> Dict[str, List[str]]: "async_stream": ["async_stream_query"], "bidi_stream": ["bidi_stream_query"], } + + def _telemetry_enabled(self) -> Optional[bool]: + """Return status of telemetry enablement depending on enablement env variable. + + In detail: + - Logging is always enabled when telemetry is enabled. + - Tracing is enabled depending on the truth table seen in `_tracing_enabled` method, in order to not break existing user enablement. + + Returns: + True if telemetry is enabled, False if telemetry is disabled, or None + if telemetry enablement is not set (i.e. old deployments which don't support this env variable). 
+ """ + import os + + GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY = ( + "GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY" + ) + + return ( + os.getenv(GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY, "0").lower() + in ("true", "1") + if GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY in os.environ + else None + ) + + # Tracing enablement follows truth table: + def _tracing_enabled(self) -> bool: + """Tracing enablement follows truth table: + + | enable_tracing | enable_telemetry(env) | tracing_actually_enabled | + |----------------|-----------------------|--------------------------| + | false | false | false | + | false | true | false | + | false | None | false | + | true | false | false | + | true | true | true | + | true | None | true | + | None(default) | false | false | + | None(default) | true | adk_version >= 1.17 | + | None(default) | None | false | + """ + enable_tracing: Optional[bool] = self._tmpl_attrs.get("enable_tracing") + enable_telemetry: Optional[bool] = self._telemetry_enabled() + + return (enable_tracing is True and enable_telemetry is not False) or ( + enable_tracing is None + and enable_telemetry is True + and is_version_sufficient("1.17.0") + )