Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,6 @@
"google-cloud-logging < 4",
"opentelemetry-sdk < 2",
"opentelemetry-exporter-gcp-trace < 2",
"opentelemetry-exporter-otlp-proto-http < 2",
"pydantic >= 2.11.1, < 3",
"typing_extensions",
]
Expand Down
46 changes: 36 additions & 10 deletions tests/unit/vertex_adk/test_agent_engine_templates_adk.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import sys
from unittest import mock
from typing import Optional
import dataclasses

from google import auth
from google.auth import credentials as auth_credentials
Expand Down Expand Up @@ -153,11 +154,27 @@ def vertexai_init_mock():


@pytest.fixture
def otlp_span_exporter_mock():
    """Yield a mock that stands in for the OTLP HTTP span exporter class.

    The ``OTLPSpanExporter`` attribute of the OTLP HTTP trace exporter module
    is patched for as long as the generator is suspended; the patch is undone
    when the generator is resumed and finishes.

    Yields:
        The mock object installed in place of ``OTLPSpanExporter``.
    """
    patcher = mock.patch(
        "opentelemetry.exporter.otlp.proto.http.trace_exporter.OTLPSpanExporter"
    )
    with patcher as patched_exporter:
        yield patched_exporter
def cloud_trace_exporter_mock():
    """Install a stub ``opentelemetry.exporter.cloud_trace`` module for a test.

    Two throwaway modules are registered under ``opentelemetry.exporter`` and
    ``opentelemetry.exporter.cloud_trace`` in ``sys.modules`` (and the former
    is attached to the ``opentelemetry`` package), so that code importing
    ``opentelemetry.exporter.cloud_trace`` resolves ``CloudTraceSpanExporter``
    to a mock.

    Whatever was registered under those names before the fixture ran is
    restored on teardown, so the stub does not leak into other tests and a
    previously imported real module is not discarded.

    Yields:
        The mock bound to ``CloudTraceSpanExporter`` on the stub module.
    """
    import sys
    import types

    import opentelemetry

    exporter_key = "opentelemetry.exporter"
    cloud_trace_key = "opentelemetry.exporter.cloud_trace"
    # Sentinel distinguishing "absent" from an entry whose value is None.
    missing = object()

    # Snapshot the current state so teardown can put everything back.
    previous_modules = {
        key: sys.modules.get(key, missing)
        for key in (exporter_key, cloud_trace_key)
    }
    previous_attr = vars(opentelemetry).get("exporter", missing)

    mock_cloud_trace_exporter = mock.Mock()

    exporter_module = types.ModuleType("exporter")
    cloud_trace_module = types.ModuleType("cloud_trace")
    cloud_trace_module.CloudTraceSpanExporter = mock_cloud_trace_exporter
    exporter_module.cloud_trace = cloud_trace_module

    opentelemetry.exporter = exporter_module
    sys.modules[exporter_key] = exporter_module
    sys.modules[cloud_trace_key] = cloud_trace_module

    try:
        yield mock_cloud_trace_exporter
    finally:
        # Restore rather than delete: the names may have pointed at real,
        # already-imported modules before this fixture replaced them.
        for key, previous in previous_modules.items():
            if previous is missing:
                sys.modules.pop(key, None)
            else:
                sys.modules[key] = previous

        if previous_attr is missing:
            vars(opentelemetry).pop("exporter", None)
        else:
            opentelemetry.exporter = previous_attr


@pytest.fixture
Expand Down Expand Up @@ -649,9 +666,9 @@ def test_custom_instrumentor_enablement(
)
def test_tracing_setup(
self,
monkeypatch,
trace_provider_mock: mock.Mock,
otlp_span_exporter_mock: mock.Mock,
cloud_trace_exporter_mock: mock.Mock,
monkeypatch,
):
monkeypatch.setattr(
"uuid.uuid4", lambda: uuid.UUID("12345678123456781234567812345678")
Expand All @@ -673,9 +690,17 @@ def test_tracing_setup(
"some-attribute": "some-value",
}

otlp_span_exporter_mock.assert_called_once_with(
session=mock.ANY,
endpoint="https://telemetry.googleapis.com/v1/traces",
@dataclasses.dataclass
class RegexMatchingAll:
keys: set[str]

def __eq__(self, regex: object) -> bool:
return isinstance(regex, str) and set(regex.split("|")) == self.keys

cloud_trace_exporter_mock.assert_called_once_with(
project_id=_TEST_PROJECT,
client=mock.ANY,
resource_regex=RegexMatchingAll(keys=set(expected_attributes.keys())),
)

assert (
Expand All @@ -687,6 +712,7 @@ def test_tracing_setup(
def test_enable_tracing(
self,
caplog,
cloud_trace_exporter_mock,
tracer_provider_mock,
simple_span_processor_mock,
):
Expand Down
153 changes: 58 additions & 95 deletions vertexai/agent_engines/templates/adk.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,28 +231,6 @@ def _warn(msg: str):
_warn._LOGGER.warning(msg) # pyright: ignore[reportFunctionMemberAccess]


def _force_flush_traces():
try:
import opentelemetry.trace
except (ImportError, AttributeError):
_warn(
"Could not force flush traces. opentelemetry-api is not installed. Please call 'pip install google-cloud-aiplatform[agent_engines]'."
)
return None

try:
import opentelemetry.sdk.trace
except (ImportError, AttributeError):
_warn(
"Could not force flush traces. opentelemetry-sdk is not installed. Please call 'pip install google-cloud-aiplatform[agent_engines]'."
)
return None

provider = opentelemetry.trace.get_tracer_provider()
if isinstance(provider, opentelemetry.sdk.trace.TracerProvider):
_ = provider.force_flush()


def _default_instrumentor_builder(
project_id: str,
*,
Expand Down Expand Up @@ -333,23 +311,28 @@ def _detect_cloud_resource_id(project_id: str) -> Optional[str]:

if enable_tracing:
try:
import opentelemetry.exporter.otlp.proto.http.trace_exporter
import google.auth.transport.requests
import opentelemetry.exporter.cloud_trace
except (ImportError, AttributeError):
return _warn_missing_dependency(
"opentelemetry-exporter-otlp-proto-http", needed_for_tracing=True
"opentelemetry-exporter-gcp-trace", needed_for_tracing=True
)

try:
import google.cloud.trace_v2
except (ImportError, AttributeError):
return _warn_missing_dependency(
"google-cloud-trace", needed_for_tracing=True
)

import google.auth

credentials, _ = google.auth.default()
span_exporter = (
opentelemetry.exporter.otlp.proto.http.trace_exporter.OTLPSpanExporter(
session=google.auth.transport.requests.AuthorizedSession(
credentials=credentials
),
endpoint="https://telemetry.googleapis.com/v1/traces",
)
span_exporter = opentelemetry.exporter.cloud_trace.CloudTraceSpanExporter(
project_id=project_id,
client=google.cloud.trace_v2.TraceServiceClient(
credentials=credentials.with_quota_project(project_id),
),
resource_regex="|".join(resource.attributes.keys()),
)
span_processor = opentelemetry.sdk.trace.export.BatchSpanProcessor(
span_exporter=span_exporter,
Expand Down Expand Up @@ -712,17 +695,54 @@ def set_up(self):
else:
os.environ["ADK_CAPTURE_MESSAGE_CONTENT_IN_SPANS"] = "false"

enable_logging = bool(self._telemetry_enabled())
GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY = (
"GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY"
)

def telemetry_enabled() -> Optional[bool]:
return (
os.getenv(GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY, "0").lower()
in ("true", "1")
if GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY in os.environ
else None
)

# Tracing enablement follows truth table:
def tracing_enabled() -> bool:
"""Tracing enablement follows true table:

| enable_tracing | enable_telemetry(env) | tracing_actually_enabled |
|----------------|-----------------------|--------------------------|
| false | false | false |
| false | true | false |
| false | None | false |
| true | false | false |
| true | true | true |
| true | None | true |
| None(default) | false | false |
| None(default) | true | adk_version >= 1.17 |
| None(default) | None | false |
"""
enable_tracing: Optional[bool] = self._tmpl_attrs.get("enable_tracing")
enable_telemetry: Optional[bool] = telemetry_enabled()

return (enable_tracing is True and enable_telemetry is not False) or (
enable_tracing is None
and enable_telemetry is True
and is_version_sufficient("1.17.0")
)

enable_logging = bool(telemetry_enabled())

custom_instrumentor = self._tmpl_attrs.get("instrumentor_builder")

if custom_instrumentor and self._tracing_enabled():
if custom_instrumentor and tracing_enabled():
self._tmpl_attrs["instrumentor"] = custom_instrumentor(project)

if not custom_instrumentor:
self._tmpl_attrs["instrumentor"] = _default_instrumentor_builder(
project,
enable_tracing=self._tracing_enabled(),
enable_tracing=tracing_enabled(),
enable_logging=enable_logging,
)

Expand Down Expand Up @@ -894,14 +914,9 @@ async def async_stream_query(
**kwargs,
)

try:
async for event in events_async:
# Yield the event data as a dictionary
yield _utils.dump_event_for_json(event)
finally:
# Avoid trace data loss having to do with CPU throttling on instance turndown
if self._tracing_enabled():
_ = await asyncio.to_thread(_force_flush_traces)
async for event in events_async:
# Yield the event data as a dictionary
yield _utils.dump_event_for_json(event)

def stream_query(
self,
Expand Down Expand Up @@ -1053,9 +1068,6 @@ async def streaming_agent_run_with_events(self, request_json: str):
user_id=request.user_id,
session_id=session.id,
)
# Avoid trace data loss having to do with CPU throttling on instance turndown
if self._tracing_enabled():
_ = await asyncio.to_thread(_force_flush_traces)

async def async_get_session(
self,
Expand Down Expand Up @@ -1438,52 +1450,3 @@ def register_operations(self) -> Dict[str, List[str]]:
"streaming_agent_run_with_events",
],
}

def _telemetry_enabled(self) -> Optional[bool]:
"""Return status of telemetry enablement depending on enablement env variable.

In detail:
- Logging is always enabled when telemetry is enabled.
- Tracing is enabled depending on the truth table seen in `_tracing_enabled` method, in order to not break existing user enablement.

Returns:
True if telemetry is enabled, False if telemetry is disabled, or None
if telemetry enablement is not set (i.e. old deployments which don't support this env variable).
"""
import os

GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY = (
"GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY"
)

return (
os.getenv(GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY, "0").lower()
in ("true", "1")
if GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY in os.environ
else None
)

# Tracing enablement follows truth table:
def _tracing_enabled(self) -> bool:
"""Tracing enablement follows true table:

| enable_tracing | enable_telemetry(env) | tracing_actually_enabled |
|----------------|-----------------------|--------------------------|
| false | false | false |
| false | true | false |
| false | None | false |
| true | false | false |
| true | true | true |
| true | None | true |
| None(default) | false | false |
| None(default) | true | adk_version >= 1.17 |
| None(default) | None | false |
"""
enable_tracing: Optional[bool] = self._tmpl_attrs.get("enable_tracing")
enable_telemetry: Optional[bool] = self._telemetry_enabled()

return (enable_tracing is True and enable_telemetry is not False) or (
enable_tracing is None
and enable_telemetry is True
and is_version_sufficient("1.17.0")
)