91 changes: 91 additions & 0 deletions tests/unit/vertex_adk/test_agent_engine_templates_adk.py
@@ -170,6 +170,34 @@ def trace_provider_mock():
yield tracer_provider_mock


@pytest.fixture
def trace_provider_force_flush_mock():
import opentelemetry.trace
import opentelemetry.sdk.trace

with mock.patch.object(
opentelemetry.trace, "get_tracer_provider"
) as get_tracer_provider_mock:
get_tracer_provider_mock.return_value = mock.Mock(
spec=opentelemetry.sdk.trace.TracerProvider()
)
yield get_tracer_provider_mock.return_value.force_flush


@pytest.fixture
def logger_provider_force_flush_mock():
import opentelemetry._logs
import opentelemetry.sdk._logs

with mock.patch.object(
opentelemetry._logs, "get_logger_provider"
) as get_logger_provider_mock:
get_logger_provider_mock.return_value = mock.Mock(
spec=opentelemetry.sdk._logs.LoggerProvider()
)
yield get_logger_provider_mock.return_value.force_flush


@pytest.fixture
def default_instrumentor_builder_mock():
with mock.patch(
@@ -351,6 +379,29 @@ async def test_async_stream_query(self):
events.append(event)
assert len(events) == 1

@pytest.mark.asyncio
@mock.patch.dict(
os.environ,
{GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY: "true"},
)
async def test_async_stream_query_force_flush_otel(
self,
trace_provider_force_flush_mock: mock.Mock,
logger_provider_force_flush_mock: mock.Mock,
):
app = agent_engines.AdkApp(agent=_TEST_AGENT)
assert app._tmpl_attrs.get("runner") is None
app.set_up()
app._tmpl_attrs["runner"] = _MockRunner()
async for _ in app.async_stream_query(
user_id=_TEST_USER_ID,
message="test message",
):
pass

trace_provider_force_flush_mock.assert_called_once()
logger_provider_force_flush_mock.assert_called_once()

@pytest.mark.asyncio
async def test_async_stream_query_with_content(self):
app = agent_engines.AdkApp(agent=_TEST_AGENT)
@@ -403,6 +454,46 @@ async def test_streaming_agent_run_with_events(self):
events.append(event)
assert len(events) == 1

@pytest.mark.asyncio
@mock.patch.dict(
os.environ,
{GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY: "true"},
)
async def test_streaming_agent_run_with_events_force_flush_otel(
self,
trace_provider_force_flush_mock: mock.Mock,
logger_provider_force_flush_mock: mock.Mock,
):
app = agent_engines.AdkApp(agent=_TEST_AGENT)
app.set_up()
app._tmpl_attrs["in_memory_runner"] = _MockRunner()
request_json = json.dumps(
{
"artifacts": [
{
"file_name": "test_file_name",
"versions": [{"version": "v1", "data": "v1data"}],
}
],
"authorizations": {
"test_user_id1": {"access_token": "test_access_token"},
"test_user_id2": {"accessToken": "test-access-token"},
},
"user_id": _TEST_USER_ID,
"message": {
"parts": [{"text": "What is the exchange rate from USD to SEK?"}],
"role": "user",
},
}
)
async for _ in app.streaming_agent_run_with_events(
request_json=request_json,
):
pass

trace_provider_force_flush_mock.assert_called_once()
logger_provider_force_flush_mock.assert_called_once()

@pytest.mark.asyncio
async def test_async_create_session(self):
app = agent_engines.AdkApp(agent=_TEST_AGENT)
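The two fixtures above stub out the global OpenTelemetry provider getters and yield the provider's force_flush mock. A condensed, self-contained sketch of the trace-side fixture follows, with a note on why it is spec'd the way it is; the fixture name comes from the diff, the rest is illustrative.

from unittest import mock

import pytest


@pytest.fixture
def trace_provider_force_flush_mock():
    import opentelemetry.trace
    import opentelemetry.sdk.trace

    with mock.patch.object(
        opentelemetry.trace, "get_tracer_provider"
    ) as get_tracer_provider_mock:
        # Spec'ing the mock against an SDK TracerProvider instance makes
        # isinstance(provider, opentelemetry.sdk.trace.TracerProvider) pass in
        # the code under test, so its force_flush call lands on this mock and
        # the tests can assert it was called exactly once.
        get_tracer_provider_mock.return_value = mock.Mock(
            spec=opentelemetry.sdk.trace.TracerProvider()
        )
        yield get_tracer_provider_mock.return_value.force_flush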
93 changes: 93 additions & 0 deletions tests/unit/vertex_adk/test_reasoning_engine_templates_adk.py
@@ -128,6 +128,34 @@ def trace_provider_mock():
yield tracer_provider_mock


@pytest.fixture
def trace_provider_force_flush_mock():
import opentelemetry.trace
import opentelemetry.sdk.trace

with mock.patch.object(
opentelemetry.trace, "get_tracer_provider"
) as get_tracer_provider_mock:
get_tracer_provider_mock.return_value = mock.Mock(
spec=opentelemetry.sdk.trace.TracerProvider()
)
yield get_tracer_provider_mock.return_value.force_flush


@pytest.fixture
def logger_provider_force_flush_mock():
import opentelemetry._logs
import opentelemetry.sdk._logs

with mock.patch.object(
opentelemetry._logs, "get_logger_provider"
) as get_logger_provider_mock:
get_logger_provider_mock.return_value = mock.Mock(
spec=opentelemetry.sdk._logs.LoggerProvider()
)
yield get_logger_provider_mock.return_value.force_flush


@pytest.fixture
def default_instrumentor_builder_mock():
with mock.patch(
@@ -353,6 +381,31 @@ async def test_async_stream_query(self):
events.append(event)
assert len(events) == 1

@pytest.mark.asyncio
@mock.patch.dict(
os.environ,
{"GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY": "true"},
)
async def test_async_stream_query_force_flush_otel(
self,
trace_provider_force_flush_mock: mock.Mock,
logger_provider_force_flush_mock: mock.Mock,
):
app = reasoning_engines.AdkApp(
agent=Agent(name=_TEST_AGENT_NAME, model=_TEST_MODEL), enable_tracing=True
)
assert app._tmpl_attrs.get("runner") is None
app.set_up()
app._tmpl_attrs["runner"] = _MockRunner()
async for _ in app.async_stream_query(
user_id=_TEST_USER_ID,
message="test message",
):
pass

trace_provider_force_flush_mock.assert_called_once()
logger_provider_force_flush_mock.assert_called_once()

@pytest.mark.asyncio
async def test_async_stream_query_with_content(self):
app = reasoning_engines.AdkApp(
@@ -404,6 +457,46 @@ def test_streaming_agent_run_with_events(self):
events = list(app.streaming_agent_run_with_events(request_json=request_json))
assert len(events) == 1

@pytest.mark.asyncio
@mock.patch.dict(
os.environ,
{"GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY": "true"},
)
async def test_streaming_agent_run_with_events_force_flush_otel(
self,
trace_provider_force_flush_mock: mock.Mock,
logger_provider_force_flush_mock: mock.Mock,
):
app = reasoning_engines.AdkApp(
agent=Agent(name=_TEST_AGENT_NAME, model=_TEST_MODEL),
enable_tracing=True,
)
app.set_up()
app._tmpl_attrs["in_memory_runner"] = _MockRunner()
request_json = json.dumps(
{
"artifacts": [
{
"file_name": "test_file_name",
"versions": [{"version": "v1", "data": "v1data"}],
}
],
"authorizations": {
"test_user_id1": {"access_token": "test_access_token"},
"test_user_id2": {"accessToken": "test-access-token"},
},
"user_id": _TEST_USER_ID,
"message": {
"parts": [{"text": "What is the exchange rate from USD to SEK?"}],
"role": "user",
},
}
)
list(app.streaming_agent_run_with_events(request_json=request_json))

trace_provider_force_flush_mock.assert_called_once()
logger_provider_force_flush_mock.assert_called_once()

@pytest.mark.asyncio
async def test_async_bidi_stream_query(self):
app = reasoning_engines.AdkApp(
41 changes: 29 additions & 12 deletions vertexai/agent_engines/templates/adk.py
@@ -25,6 +25,7 @@
)

import asyncio
from collections.abc import Awaitable
import queue
import threading
import warnings
@@ -231,26 +232,38 @@ def _warn(msg: str):
_warn._LOGGER.warning(msg) # pyright: ignore[reportFunctionMemberAccess]


def _force_flush_traces():
async def _force_flush_otel(tracing_enabled: bool, logging_enabled: bool):
try:
import opentelemetry.trace
import opentelemetry._logs
except (ImportError, AttributeError):
_warn(
"Could not force flush traces. opentelemetry-api is not installed. Please call 'pip install google-cloud-aiplatform[agent_engines]'."
"Could not force flush telemetry data. opentelemetry-api is not installed. Please call 'pip install google-cloud-aiplatform[agent_engines]'."
)
return None

try:
import opentelemetry.sdk.trace
import opentelemetry.sdk._logs
except (ImportError, AttributeError):
_warn(
"Could not force flush traces. opentelemetry-sdk is not installed. Please call 'pip install google-cloud-aiplatform[agent_engines]'."
"Could not force flush telemetry data. opentelemetry-sdk is not installed. Please call 'pip install google-cloud-aiplatform[agent_engines]'."
)
return None

provider = opentelemetry.trace.get_tracer_provider()
if isinstance(provider, opentelemetry.sdk.trace.TracerProvider):
_ = provider.force_flush()
coros: List[Awaitable[bool]] = []

if tracing_enabled:
tracer_provider = opentelemetry.trace.get_tracer_provider()
if isinstance(tracer_provider, opentelemetry.sdk.trace.TracerProvider):
coros.append(asyncio.to_thread(tracer_provider.force_flush))

if logging_enabled:
logger_provider = opentelemetry._logs.get_logger_provider()
if isinstance(logger_provider, opentelemetry.sdk._logs.LoggerProvider):
coros.append(asyncio.to_thread(logger_provider.force_flush))

await asyncio.gather(*coros, return_exceptions=True)


def _default_instrumentor_builder(
@@ -894,9 +907,11 @@ async def async_stream_query(
# Yield the event data as a dictionary
yield _utils.dump_event_for_json(event)
finally:
# Avoid trace data loss having to do with CPU throttling on instance turndown
if self._tracing_enabled():
_ = await asyncio.to_thread(_force_flush_traces)
# Avoid telemetry data loss having to do with CPU throttling on instance turndown
_ = await _force_flush_otel(
tracing_enabled=self._tracing_enabled(),
logging_enabled=bool(self._telemetry_enabled()),
)

def stream_query(
self,
@@ -1066,9 +1081,11 @@ async def streaming_agent_run_with_events(self, request_json: str):
user_id=request.user_id,
session_id=session.id,
)
# Avoid trace data loss having to do with CPU throttling on instance turndown
if self._tracing_enabled():
_ = await asyncio.to_thread(_force_flush_traces)
# Avoid telemetry data loss having to do with CPU throttling on instance turndown
_ = await _force_flush_otel(
tracing_enabled=self._tracing_enabled(),
logging_enabled=bool(self._telemetry_enabled()),
)

async def async_get_session(
self,
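Taken out of the diff context, the new _force_flush_otel helper reduces to the pattern below: a best-effort, concurrent flush of the SDK tracer and logger providers, run off the event loop with asyncio.to_thread. This is a minimal sketch assuming opentelemetry-api and opentelemetry-sdk are installed; the name flush_otel_providers is illustrative, and the real template additionally warns (as shown above) when the packages are missing.

import asyncio
from typing import Awaitable, List

import opentelemetry._logs
import opentelemetry.sdk._logs
import opentelemetry.sdk.trace
import opentelemetry.trace


async def flush_otel_providers(tracing_enabled: bool, logging_enabled: bool) -> None:
    """Best-effort flush of buffered spans and log records before shutdown."""
    flushes: List[Awaitable[bool]] = []

    if tracing_enabled:
        tracer_provider = opentelemetry.trace.get_tracer_provider()
        # force_flush is an SDK-level method, so skip API no-op/proxy providers.
        if isinstance(tracer_provider, opentelemetry.sdk.trace.TracerProvider):
            flushes.append(asyncio.to_thread(tracer_provider.force_flush))

    if logging_enabled:
        logger_provider = opentelemetry._logs.get_logger_provider()
        if isinstance(logger_provider, opentelemetry.sdk._logs.LoggerProvider):
            flushes.append(asyncio.to_thread(logger_provider.force_flush))

    # return_exceptions=True keeps one failing exporter from masking the other flush.
    await asyncio.gather(*flushes, return_exceptions=True)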
41 changes: 29 additions & 12 deletions vertexai/preview/reasoning_engines/templates/adk.py
@@ -25,6 +25,7 @@
)

import asyncio
from collections.abc import Awaitable
import queue
import threading

@@ -233,26 +234,38 @@ def _warn(msg: str):
_warn._LOGGER.warning(msg) # pyright: ignore[reportFunctionMemberAccess]


def _force_flush_traces():
async def _force_flush_otel(tracing_enabled: bool, logging_enabled: bool):
try:
import opentelemetry.trace
import opentelemetry._logs
except (ImportError, AttributeError):
_warn(
"Could not force flush traces. opentelemetry-api is not installed. Please call 'pip install google-cloud-aiplatform[agent_engines]'."
"Could not force flush telemetry data. opentelemetry-api is not installed. Please call 'pip install google-cloud-aiplatform[agent_engines]'."
)
return None

try:
import opentelemetry.sdk.trace
import opentelemetry.sdk._logs
except (ImportError, AttributeError):
_warn(
"Could not force flush traces. opentelemetry-sdk is not installed. Please call 'pip install google-cloud-aiplatform[agent_engines]'."
"Could not force flush telemetry data. opentelemetry-sdk is not installed. Please call 'pip install google-cloud-aiplatform[agent_engines]'."
)
return None

provider = opentelemetry.trace.get_tracer_provider()
if isinstance(provider, opentelemetry.sdk.trace.TracerProvider):
_ = provider.force_flush()
coros: List[Awaitable[bool]] = []

if tracing_enabled:
tracer_provider = opentelemetry.trace.get_tracer_provider()
if isinstance(tracer_provider, opentelemetry.sdk.trace.TracerProvider):
coros.append(asyncio.to_thread(tracer_provider.force_flush))

if logging_enabled:
logger_provider = opentelemetry._logs.get_logger_provider()
if isinstance(logger_provider, opentelemetry.sdk._logs.LoggerProvider):
coros.append(asyncio.to_thread(logger_provider.force_flush))

await asyncio.gather(*coros, return_exceptions=True)


def _default_instrumentor_builder(
Expand Down Expand Up @@ -891,9 +904,11 @@ async def async_stream_query(
# Yield the event data as a dictionary
yield _utils.dump_event_for_json(event)
finally:
# Avoid trace data loss having to do with CPU throttling on instance turndown
if self._tracing_enabled():
_ = await asyncio.to_thread(_force_flush_traces)
# Avoid telemetry data loss having to do with CPU throttling on instance turndown
_ = await _force_flush_otel(
tracing_enabled=self._tracing_enabled(),
logging_enabled=bool(self._telemetry_enabled()),
)

def streaming_agent_run_with_events(self, request_json: str):
import json
@@ -970,9 +985,11 @@ async def _invoke_agent_async():
user_id=request.user_id,
session_id=session.id,
)
# Avoid trace data loss having to do with CPU throttling on instance turndown
if self._tracing_enabled():
_ = await asyncio.to_thread(_force_flush_traces)
# Avoid telemetry data loss having to do with CPU throttling on instance turndown
_ = await _force_flush_otel(
tracing_enabled=self._tracing_enabled(),
logging_enabled=bool(self._telemetry_enabled()),
)

def _asyncio_thread_main():
try:
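Both templates invoke the flush from the finally: block of their streaming code paths, so it runs whether the generator completes, is closed early by the consumer, or is torn down with the instance. A hedged sketch of that calling pattern; _flush_telemetry and stream_query_with_flush are stand-ins, not names from the library.

import asyncio
from typing import AsyncIterator, Dict


async def _flush_telemetry() -> None:
    # Stand-in for the _force_flush_otel helper shown above; the real helper
    # flushes the OTel tracer and logger providers when telemetry is enabled.
    await asyncio.sleep(0)


async def stream_query_with_flush(
    events: AsyncIterator[Dict],
) -> AsyncIterator[Dict]:
    """Yield events, then flush telemetry even if the consumer stops early."""
    try:
        async for event in events:
            yield event
    finally:
        # Runs on normal completion or early close, so buffered spans and log
        # records are exported before CPU is throttled on instance turndown.
        await _flush_telemetry()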