Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 117 additions & 0 deletions tests/unit/vertex_adk/test_reasoning_engine_templates_adk.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,83 @@ def __init__(self, name: str, model: str):
"streaming_mode": "sse",
"max_llm_calls": 500,
}
_TEST_SESSION_EVENTS = [
{
"author": "user",
"content": {
"parts": [
{
"text": "What is the exchange rate from US dollars to "
"Swedish krona on 2025-09-25?"
}
],
"role": "user",
},
"id": "8967297909049524224",
"invocationId": "e-308f65d7-a99f-41e3-b80d-40feb5f1b065",
"timestamp": 1765832134.629513,
},
{
"author": "currency_exchange_agent",
"content": {
"parts": [
{
"functionCall": {
"args": {
"currency_date": "2025-09-25",
"currency_from": "USD",
"currency_to": "SEK",
},
"id": "adk-136738ad-9e57-4cfb-8e23-b0f3e50a37d7",
"name": "get_exchange_rate",
}
}
],
"role": "model",
},
"id": "3155402589927899136",
"invocationId": "e-308f65d7-a99f-41e3-b80d-40feb5f1b065",
"timestamp": 1765832134.723713,
},
{
"author": "currency_exchange_agent",
"content": {
"parts": [
{
"functionResponse": {
"id": "adk-136738ad-9e57-4cfb-8e23-b0f3e50a37d7",
"name": "get_exchange_rate",
"response": {
"amount": 1,
"base": "USD",
"date": "2025-09-25",
"rates": {"SEK": 9.4118},
},
}
}
],
"role": "user",
},
"id": "1678221912150376448",
"invocationId": "e-308f65d7-a99f-41e3-b80d-40feb5f1b065",
"timestamp": 1765832135.764961,
},
{
"author": "currency_exchange_agent",
"content": {
"parts": [
{
"text": "The exchange rate from US dollars to Swedish "
"krona on 2025-09-25 is 1 USD to 9.4118 SEK."
}
],
"role": "model",
},
"id": "2470855446567583744",
"invocationId": "e-308f65d7-a99f-41e3-b80d-40feb5f1b065",
"timestamp": 1765832135.853299,
},
]


@pytest.fixture(scope="module")
Expand Down Expand Up @@ -392,6 +469,46 @@ async def test_async_stream_query(self):
events.append(event)
assert len(events) == 1

@pytest.mark.asyncio
async def test_async_stream_query_with_empty_session_events(self):
app = reasoning_engines.AdkApp(
agent=Agent(name=_TEST_AGENT_NAME, model=_TEST_MODEL)
)
assert app._tmpl_attrs.get("runner") is None
app.set_up()
app._tmpl_attrs["runner"] = _MockRunner()
events = []
async for event in app.async_stream_query(
user_id=_TEST_USER_ID,
session_events=[],
message="test message",
):
events.append(event)
assert app._tmpl_attrs.get("session_service") is not None
sessions = app.list_sessions(user_id=_TEST_USER_ID)
assert len(sessions.sessions) == 1

@pytest.mark.asyncio
async def test_async_stream_query_with_session_events(
self,
):
app = reasoning_engines.AdkApp(
agent=Agent(name=_TEST_AGENT_NAME, model=_TEST_MODEL)
)
assert app._tmpl_attrs.get("runner") is None
app.set_up()
app._tmpl_attrs["runner"] = _MockRunner()
events = []
async for event in app.async_stream_query(
user_id=_TEST_USER_ID,
session_events=_TEST_SESSION_EVENTS,
message="on the day after that?",
):
events.append(event)
assert app._tmpl_attrs.get("session_service") is not None
sessions = app.list_sessions(user_id=_TEST_USER_ID)
assert len(sessions.sessions) == 1

@pytest.mark.asyncio
@mock.patch.dict(
os.environ,
Expand Down
28 changes: 27 additions & 1 deletion vertexai/agent_engines/templates/adk.py
Original file line number Diff line number Diff line change
Expand Up @@ -932,6 +932,7 @@ async def async_stream_query(
message: Union[str, Dict[str, Any]],
user_id: str,
session_id: Optional[str] = None,
session_events: Optional[List[Dict[str, Any]]] = None,
run_config: Optional[Dict[str, Any]] = None,
**kwargs,
) -> AsyncIterable[Dict[str, Any]]:
Expand All @@ -944,7 +945,11 @@ async def async_stream_query(
Required. The ID of the user.
session_id (str):
Optional. The ID of the session. If not provided, a new
session will be created for the user.
session will be created for the user. If this is specified, then
`session_events` will be ignored.
session_events (Optional[List[Dict[str, Any]]]):
Optional. The session events to use for the query. This will be
used to initialize the session if `session_id` is not provided.
run_config (Optional[Dict[str, Any]]):
Optional. The run config to use for the query. If you want to
pass in a `run_config` pydantic object, you can pass in a dict
Expand All @@ -955,6 +960,11 @@ async def async_stream_query(

Yields:
Event dictionaries asynchronously.

Raises:
TypeError: If message is not a string or a dictionary representing
a Content object.
ValueError: If both session_id and session_events are specified.
"""
from vertexai.agent_engines import _utils
from google.genai import types
Expand All @@ -971,9 +981,25 @@ async def async_stream_query(

if not self._tmpl_attrs.get("runner"):
self.set_up()
if session_id and session_events:
raise ValueError(
"Only one of session_id and session_events should be specified."
)
if not session_id:
session = await self.async_create_session(user_id=user_id)
session_id = session.id
if session_events is not None:
# We allow for session_events to be an empty list.
from google.adk.events.event import Event

session_service = self._tmpl_attrs.get("session_service")
for event in session_events:
if not isinstance(event, Event):
event = Event.model_validate(event)
await session_service.append_event(
session=session,
event=event,
)

run_config = _validate_run_config(run_config)
if run_config:
Expand Down