diff --git a/tests/unit/vertex_adk/test_reasoning_engine_templates_adk.py b/tests/unit/vertex_adk/test_reasoning_engine_templates_adk.py index 0678eb348c..d646600426 100644 --- a/tests/unit/vertex_adk/test_reasoning_engine_templates_adk.py +++ b/tests/unit/vertex_adk/test_reasoning_engine_templates_adk.py @@ -78,6 +78,83 @@ def __init__(self, name: str, model: str): "streaming_mode": "sse", "max_llm_calls": 500, } +_TEST_SESSION_EVENTS = [ + { + "author": "user", + "content": { + "parts": [ + { + "text": "What is the exchange rate from US dollars to " + "Swedish krona on 2025-09-25?" + } + ], + "role": "user", + }, + "id": "8967297909049524224", + "invocationId": "e-308f65d7-a99f-41e3-b80d-40feb5f1b065", + "timestamp": 1765832134.629513, + }, + { + "author": "currency_exchange_agent", + "content": { + "parts": [ + { + "functionCall": { + "args": { + "currency_date": "2025-09-25", + "currency_from": "USD", + "currency_to": "SEK", + }, + "id": "adk-136738ad-9e57-4cfb-8e23-b0f3e50a37d7", + "name": "get_exchange_rate", + } + } + ], + "role": "model", + }, + "id": "3155402589927899136", + "invocationId": "e-308f65d7-a99f-41e3-b80d-40feb5f1b065", + "timestamp": 1765832134.723713, + }, + { + "author": "currency_exchange_agent", + "content": { + "parts": [ + { + "functionResponse": { + "id": "adk-136738ad-9e57-4cfb-8e23-b0f3e50a37d7", + "name": "get_exchange_rate", + "response": { + "amount": 1, + "base": "USD", + "date": "2025-09-25", + "rates": {"SEK": 9.4118}, + }, + } + } + ], + "role": "user", + }, + "id": "1678221912150376448", + "invocationId": "e-308f65d7-a99f-41e3-b80d-40feb5f1b065", + "timestamp": 1765832135.764961, + }, + { + "author": "currency_exchange_agent", + "content": { + "parts": [ + { + "text": "The exchange rate from US dollars to Swedish " + "krona on 2025-09-25 is 1 USD to 9.4118 SEK." + } + ], + "role": "model", + }, + "id": "2470855446567583744", + "invocationId": "e-308f65d7-a99f-41e3-b80d-40feb5f1b065", + "timestamp": 1765832135.853299, + }, +] @pytest.fixture(scope="module") @@ -392,6 +469,46 @@ async def test_async_stream_query(self): events.append(event) assert len(events) == 1 + @pytest.mark.asyncio + async def test_async_stream_query_with_empty_session_events(self): + app = reasoning_engines.AdkApp( + agent=Agent(name=_TEST_AGENT_NAME, model=_TEST_MODEL) + ) + assert app._tmpl_attrs.get("runner") is None + app.set_up() + app._tmpl_attrs["runner"] = _MockRunner() + events = [] + async for event in app.async_stream_query( + user_id=_TEST_USER_ID, + session_events=[], + message="test message", + ): + events.append(event) + assert app._tmpl_attrs.get("session_service") is not None + sessions = app.list_sessions(user_id=_TEST_USER_ID) + assert len(sessions.sessions) == 1 + + @pytest.mark.asyncio + async def test_async_stream_query_with_session_events( + self, + ): + app = reasoning_engines.AdkApp( + agent=Agent(name=_TEST_AGENT_NAME, model=_TEST_MODEL) + ) + assert app._tmpl_attrs.get("runner") is None + app.set_up() + app._tmpl_attrs["runner"] = _MockRunner() + events = [] + async for event in app.async_stream_query( + user_id=_TEST_USER_ID, + session_events=_TEST_SESSION_EVENTS, + message="on the day after that?", + ): + events.append(event) + assert app._tmpl_attrs.get("session_service") is not None + sessions = app.list_sessions(user_id=_TEST_USER_ID) + assert len(sessions.sessions) == 1 + @pytest.mark.asyncio @mock.patch.dict( os.environ, diff --git a/vertexai/agent_engines/templates/adk.py b/vertexai/agent_engines/templates/adk.py index df5ed03da0..2a9620ce57 100644 --- a/vertexai/agent_engines/templates/adk.py +++ b/vertexai/agent_engines/templates/adk.py @@ -932,6 +932,7 @@ async def async_stream_query( message: Union[str, Dict[str, Any]], user_id: str, session_id: Optional[str] = None, + session_events: Optional[List[Dict[str, Any]]] = None, run_config: Optional[Dict[str, Any]] = None, **kwargs, ) -> AsyncIterable[Dict[str, Any]]: @@ -944,7 +945,11 @@ async def async_stream_query( Required. The ID of the user. session_id (str): Optional. The ID of the session. If not provided, a new - session will be created for the user. + session will be created for the user. If this is specified, then + `session_events` will be ignored. + session_events (Optional[List[Dict[str, Any]]]): + Optional. The session events to use for the query. This will be + used to initialize the session if `session_id` is not provided. run_config (Optional[Dict[str, Any]]): Optional. The run config to use for the query. If you want to pass in a `run_config` pydantic object, you can pass in a dict @@ -955,6 +960,11 @@ async def async_stream_query( Yields: Event dictionaries asynchronously. + + Raises: + TypeError: If message is not a string or a dictionary representing + a Content object. + ValueError: If both session_id and session_events are specified. """ from vertexai.agent_engines import _utils from google.genai import types @@ -971,9 +981,25 @@ async def async_stream_query( if not self._tmpl_attrs.get("runner"): self.set_up() + if session_id and session_events: + raise ValueError( + "Only one of session_id and session_events should be specified." + ) if not session_id: session = await self.async_create_session(user_id=user_id) session_id = session.id + if session_events is not None: + # We allow for session_events to be an empty list. + from google.adk.events.event import Event + + session_service = self._tmpl_attrs.get("session_service") + for event in session_events: + if not isinstance(event, Event): + event = Event.model_validate(event) + await session_service.append_event( + session=session, + event=event, + ) run_config = _validate_run_config(run_config) if run_config: