Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 22 additions & 8 deletions tests/unit/vertex_adk/test_agent_engine_templates_adk.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def __init__(self, name: str, model: str):

_TEST_LOCATION = "us-central1"
_TEST_PROJECT = "test-project"
_TEST_PROJECT_ID = "test-project-id"
_TEST_API_KEY = "test-api-key"
_TEST_MODEL = "gemini-2.0-flash"
_TEST_USER_ID = "test_user_id"
Expand Down Expand Up @@ -224,6 +225,15 @@ def adk_version_mock():
yield adk_version_mock


@pytest.fixture
def get_project_id_mock():
with mock.patch(
"google.cloud.aiplatform.aiplatform.utils.resource_manager_utils.get_project_id"
) as get_project_id_mock:
get_project_id_mock.return_value = _TEST_PROJECT_ID
yield get_project_id_mock


class _MockRunner:
def run(self, *args, **kwargs):
from google.adk.events import event
Expand Down Expand Up @@ -757,26 +767,28 @@ def test_tracing_setup(
monkeypatch,
trace_provider_mock: mock.Mock,
otlp_span_exporter_mock: mock.Mock,
get_project_id_mock: mock.Mock,
):
monkeypatch.setattr(
"uuid.uuid4", lambda: uuid.UUID("12345678123456781234567812345678")
)
monkeypatch.setattr("os.getpid", lambda: 123123123)
app = agent_engines.AdkApp(agent=_TEST_AGENT, enable_tracing=True)
app._warn_if_telemetry_api_disabled = lambda: None
app.set_up()

expected_attributes = {
"cloud.account.id": _TEST_PROJECT_ID,
"cloud.platform": "gcp.agent_engine",
"cloud.region": "us-central1",
"cloud.resource_id": "//aiplatform.googleapis.com/projects/test-project-id/locations/us-central1/reasoningEngines/test_agent_id",
"gcp.project_id": _TEST_PROJECT_ID,
"service.instance.id": "12345678123456781234567812345678-123123123",
"service.name": "test_agent_id",
"some-attribute": "some-value",
"telemetry.sdk.language": "python",
"telemetry.sdk.name": "opentelemetry",
"telemetry.sdk.version": "1.36.0",
"gcp.project_id": "test-project",
"cloud.account.id": "test-project",
"cloud.provider": "gcp",
"cloud.platform": "gcp.agent_engine",
"service.name": "test_agent_id",
"cloud.resource_id": "//aiplatform.googleapis.com/projects/test-project/locations/us-central1/reasoningEngines/test_agent_id",
"service.instance.id": "12345678123456781234567812345678-123123123",
"cloud.region": "us-central1",
"some-attribute": "some-value",
}

Expand All @@ -786,6 +798,8 @@ def test_tracing_setup(
headers=mock.ANY,
)

get_project_id_mock.assert_called_once_with(_TEST_PROJECT)

user_agent = otlp_span_exporter_mock.call_args.kwargs["headers"]["User-Agent"]
assert (
re.fullmatch(
Expand Down
30 changes: 22 additions & 8 deletions tests/unit/vertex_adk/test_reasoning_engine_templates_adk.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def __init__(self, name: str, model: str):

_TEST_LOCATION = "us-central1"
_TEST_PROJECT = "test-project"
_TEST_PROJECT_ID = "test-project-id"
_TEST_MODEL = "gemini-2.0-flash"
_TEST_USER_ID = "test_user_id"
_TEST_AGENT_NAME = "test_agent"
Expand Down Expand Up @@ -173,6 +174,15 @@ def adk_version_mock():
yield adk_version_mock


@pytest.fixture
def get_project_id_mock():
with mock.patch(
"google.cloud.aiplatform.aiplatform.utils.resource_manager_utils.get_project_id"
) as get_project_id_mock:
get_project_id_mock.return_value = _TEST_PROJECT_ID
yield get_project_id_mock


class _MockRunner:
def run(self, *args, **kwargs):
from google.adk.events import event
Expand Down Expand Up @@ -699,26 +709,28 @@ def test_tracing_setup(
monkeypatch: pytest.MonkeyPatch,
trace_provider_mock: mock.Mock,
otlp_span_exporter_mock: mock.Mock,
get_project_id_mock: mock.Mock,
):
monkeypatch.setattr(
"uuid.uuid4", lambda: uuid.UUID("12345678123456781234567812345678")
)
monkeypatch.setattr("os.getpid", lambda: 123123123)
app = reasoning_engines.AdkApp(agent=_TEST_AGENT, enable_tracing=True)
app._warn_if_telemetry_api_disabled = lambda: None
app.set_up()

expected_attributes = {
"cloud.account.id": _TEST_PROJECT_ID,
"cloud.platform": "gcp.agent_engine",
"cloud.region": "us-central1",
"cloud.resource_id": "//aiplatform.googleapis.com/projects/test-project-id/locations/us-central1/reasoningEngines/test_agent_id",
"gcp.project_id": _TEST_PROJECT_ID,
"service.instance.id": "12345678123456781234567812345678-123123123",
"service.name": "test_agent_id",
"some-attribute": "some-value",
"telemetry.sdk.language": "python",
"telemetry.sdk.name": "opentelemetry",
"telemetry.sdk.version": "1.36.0",
"gcp.project_id": "test-project",
"cloud.account.id": "test-project",
"cloud.provider": "gcp",
"cloud.platform": "gcp.agent_engine",
"service.name": "test_agent_id",
"cloud.resource_id": "//aiplatform.googleapis.com/projects/test-project/locations/us-central1/reasoningEngines/test_agent_id",
"service.instance.id": "12345678123456781234567812345678-123123123",
"cloud.region": "us-central1",
"some-attribute": "some-value",
}

Expand All @@ -728,6 +740,8 @@ def test_tracing_setup(
headers=mock.ANY,
)

get_project_id_mock.assert_called_once_with(_TEST_PROJECT)

user_agent = otlp_span_exporter_mock.call_args.kwargs["headers"]["User-Agent"]
assert (
re.fullmatch(
Expand Down
20 changes: 17 additions & 3 deletions vertexai/agent_engines/templates/adk.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,14 +282,20 @@ async def _force_flush_otel(tracing_enabled: bool, logging_enabled: bool):


def _default_instrumentor_builder(
project_id: str,
project_id: Optional[str],
*,
enable_tracing: bool = False,
enable_logging: bool = False,
):
if not enable_tracing and not enable_logging:
return None

if project_id is None:
_warn(
"telemetry is only supported when project is specified, proceeding with no telemetry"
)
return None

import os

def _warn_missing_dependency(
Expand Down Expand Up @@ -791,11 +797,11 @@ def set_up(self):
)

if custom_instrumentor and self._tracing_enabled():
self._tmpl_attrs["instrumentor"] = custom_instrumentor(project)
self._tmpl_attrs["instrumentor"] = custom_instrumentor(self.project_id())

if not custom_instrumentor:
self._tmpl_attrs["instrumentor"] = _default_instrumentor_builder(
project,
self.project_id(),
enable_tracing=self._tracing_enabled(),
enable_logging=enable_logging,
)
Expand Down Expand Up @@ -1607,3 +1613,11 @@ def _warn_if_telemetry_api_disabled(self):
r = session.post("https://telemetry.googleapis.com/v1/traces", data=None)
if "Telemetry API has not been used in project" in r.text:
_warn(_TELEMETRY_API_DISABLED_WARNING % (project, project))

def project_id(self) -> Optional[str]:
if project := self._tmpl_attrs.get("project"):
from google.cloud.aiplatform.utils import resource_manager_utils

return resource_manager_utils.get_project_id(project)

return None
18 changes: 16 additions & 2 deletions vertexai/preview/reasoning_engines/templates/adk.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,14 +284,20 @@ async def _force_flush_otel(tracing_enabled: bool, logging_enabled: bool):


def _default_instrumentor_builder(
project_id: str,
project_id: Optional[str],
*,
enable_tracing: bool = False,
enable_logging: bool = False,
):
if not enable_tracing and not enable_logging:
return None

if project_id is None:
_warn(
"telemetry is only supported when project is specified, proceeding with no telemetry"
)
return None

import os

def _warn_missing_dependency(
Expand Down Expand Up @@ -712,7 +718,7 @@ def set_up(self):
enable_logging = bool(self._telemetry_enabled())

self._tmpl_attrs["instrumentor"] = _default_instrumentor_builder(
project,
self.project_id(),
enable_tracing=self._tracing_enabled(),
enable_logging=enable_logging,
)
Expand Down Expand Up @@ -1540,3 +1546,11 @@ def _warn_if_telemetry_api_disabled(self):
r = session.post("https://telemetry.googleapis.com/v1/traces", data=None)
if "Telemetry API has not been used in project" in r.text:
_warn(_TELEMETRY_API_DISABLED_WARNING % (project, project))

def project_id(self) -> Optional[str]:
if project := self._tmpl_attrs.get("project"):
from google.cloud.aiplatform.utils import resource_manager_utils

return resource_manager_utils.get_project_id(project)

return None