diff --git a/tests/unit/vertex_adk/test_agent_engine_templates_adk.py b/tests/unit/vertex_adk/test_agent_engine_templates_adk.py index de672bc451..4dfcf5acdd 100644 --- a/tests/unit/vertex_adk/test_agent_engine_templates_adk.py +++ b/tests/unit/vertex_adk/test_agent_engine_templates_adk.py @@ -264,6 +264,39 @@ def test_register_operations(self): for operation in operations: assert operation in dir(app) + def test_stream_query(self): + app = agent_engines.AdkApp(agent=_TEST_AGENT) + assert app._tmpl_attrs.get("runner") is None + app.set_up() + app._tmpl_attrs["runner"] = _MockRunner() + events = list( + app.stream_query( + user_id=_TEST_USER_ID, + message="test message", + ) + ) + assert len(events) == 1 + + def test_stream_query_with_content(self): + app = agent_engines.AdkApp(agent=_TEST_AGENT) + assert app._tmpl_attrs.get("runner") is None + app.set_up() + app._tmpl_attrs["runner"] = _MockRunner() + events = list( + app.stream_query( + user_id=_TEST_USER_ID, + message=types.Content( + role="user", + parts=[ + types.Part( + text="test message with content", + ) + ], + ).model_dump(), + ) + ) + assert len(events) == 1 + @pytest.mark.asyncio async def test_async_stream_query(self): app = agent_engines.AdkApp(agent=_TEST_AGENT) @@ -385,6 +418,51 @@ async def test_async_delete_session(self): response0 = await app.async_list_sessions(user_id=_TEST_USER_ID) assert not response0.sessions + def test_create_session(self): + app = agent_engines.AdkApp(agent=_TEST_AGENT) + session1 = app.create_session(user_id=_TEST_USER_ID) + assert session1.user_id == _TEST_USER_ID + session2 = app.create_session( + user_id=_TEST_USER_ID, session_id="test_session_id" + ) + assert session2.user_id == _TEST_USER_ID + assert session2.id == "test_session_id" + + def test_get_session(self): + app = agent_engines.AdkApp(agent=_TEST_AGENT) + session1 = app.create_session(user_id=_TEST_USER_ID) + session2 = app.get_session( + user_id=_TEST_USER_ID, + session_id=session1.id, + ) + assert session2.user_id == _TEST_USER_ID + assert session1.id == session2.id + + def test_list_sessions(self): + app = agent_engines.AdkApp(agent=_TEST_AGENT) + response0 = app.list_sessions(user_id=_TEST_USER_ID) + assert not response0.sessions + session = app.create_session(user_id=_TEST_USER_ID) + response1 = app.list_sessions(user_id=_TEST_USER_ID) + assert len(response1.sessions) == 1 + assert response1.sessions[0].id == session.id + session2 = app.create_session(user_id=_TEST_USER_ID) + response2 = app.list_sessions(user_id=_TEST_USER_ID) + assert len(response2.sessions) == 2 + assert response2.sessions[0].id == session.id + assert response2.sessions[1].id == session2.id + + def test_delete_session(self): + app = agent_engines.AdkApp(agent=_TEST_AGENT) + response = app.delete_session(user_id=_TEST_USER_ID, session_id="") + assert not response + session = app.create_session(user_id=_TEST_USER_ID) + response1 = app.list_sessions(user_id=_TEST_USER_ID) + assert len(response1.sessions) == 1 + app.delete_session(user_id=_TEST_USER_ID, session_id=session.id) + response0 = app.list_sessions(user_id=_TEST_USER_ID) + assert not response0.sessions + @pytest.mark.asyncio async def test_async_add_session_to_memory_dict(self): app = agent_engines.AdkApp(agent=_TEST_AGENT) diff --git a/vertexai/agent_engines/templates/adk.py b/vertexai/agent_engines/templates/adk.py index ad93acd405..1c5bc9d458 100644 --- a/vertexai/agent_engines/templates/adk.py +++ b/vertexai/agent_engines/templates/adk.py @@ -24,6 +24,11 @@ Union, ) +import asyncio +import queue +import threading +import warnings + if TYPE_CHECKING: try: from google.adk.events.event import Event @@ -858,6 +863,85 @@ async def async_stream_query( # Yield the event data as a dictionary yield _utils.dump_event_for_json(event) + def stream_query( + self, + *, + message: Union[str, Dict[str, Any]], + user_id: str, + session_id: Optional[str] = None, + run_config: Optional[Dict[str, Any]] = None, + **kwargs, + ): + """Deprecated. Use async_stream_query instead. + + Streams responses from the ADK application in response to a message. + + Args: + message (Union[str, Dict[str, Any]]): + Required. The message to stream responses for. + user_id (str): + Required. The ID of the user. + session_id (str): + Optional. The ID of the session. If not provided, a new + session will be created for the user. + run_config (Optional[Dict[str, Any]]): + Optional. The run config to use for the query. If you want to + pass in a `run_config` pydantic object, you can pass in a dict + representing it as `run_config.model_dump(mode="json")`. + **kwargs (dict[str, Any]): + Optional. Additional keyword arguments to pass to the + runner. + + Yields: + The output of querying the ADK application. + """ + warnings.warn( + ( + "AdkApp.stream_query(...) is deprecated. " + "Use AdkApp.async_stream_query(...) instead. See " + "https://cloud.google.com/vertex-ai/generative-ai/docs/agent-engine/use/adk#stream-responses " + "for more details." + ), + DeprecationWarning, + stacklevel=2, + ) + from vertexai.agent_engines import _utils + from google.genai import types + + if isinstance(message, Dict): + content = types.Content.model_validate(message) + elif isinstance(message, str): + content = types.Content(role="user", parts=[types.Part(text=message)]) + else: + raise TypeError( + "message must be a string or a dictionary representing" + " a Content object." + ) + + if not self._tmpl_attrs.get("runner"): + self.set_up() + if not session_id: + session = self.create_session(user_id=user_id) + session_id = session.id + run_config = _validate_run_config(run_config) + if run_config: + for event in self._tmpl_attrs.get("runner").run( + user_id=user_id, + session_id=session_id, + new_message=content, + run_config=run_config, + **kwargs, + ): + yield _utils.dump_event_for_json(event) + else: + for event in self._tmpl_attrs.get("runner").run( + user_id=user_id, + session_id=session_id, + new_message=content, + **kwargs, + ): + yield _utils.dump_event_for_json(event) + async def streaming_agent_run_with_events(self, request_json: str): """Streams responses asynchronously from the ADK application. @@ -967,6 +1051,56 @@ async def async_get_session( ) return session + def get_session( + self, + *, + user_id: str, + session_id: str, + **kwargs, + ): + """Deprecated. Use async_get_session instead. + + Get a session for the given user. + """ + warnings.warn( + ( + "AdkApp.get_session(...) is deprecated. " + "Use AdkApp.async_get_session(...) instead. See " + "https://cloud.google.com/vertex-ai/generative-ai/docs/agent-engine/use/adk#get-session " + "for more details." + ), + DeprecationWarning, + stacklevel=2, + ) + event_queue = queue.Queue(maxsize=1) + + async def _invoke_async_get_session(): + return await self.async_get_session( + user_id=user_id, session_id=session_id, **kwargs + ) + + def _asyncio_thread_main(): + try: + result = asyncio.run(_invoke_async_get_session()) + event_queue.put(result) + except Exception as e: + event_queue.put(e) + + thread = threading.Thread(target=_asyncio_thread_main) + thread.start() + + # Wait for the thread to finish + thread.join() + try: + outcome = event_queue.get(timeout=10) + except queue.Empty: + raise RuntimeError( + "Session not found. Please create it using .create_session()" + ) from None + if isinstance(outcome, RuntimeError): + raise outcome from None + return outcome + async def async_list_sessions(self, *, user_id: str, **kwargs): """List sessions for the given user. @@ -988,6 +1122,45 @@ async def async_list_sessions(self, *, user_id: str, **kwargs): **kwargs, ) + def list_sessions(self, *, user_id: str, **kwargs): + """Deprecated. Use async_list_sessions instead. + + List sessions for the given user. + """ + warnings.warn( + ( + "AdkApp.list_sessions(...) is deprecated. " + "Use AdkApp.async_list_sessions(...) instead. See " + "https://cloud.google.com/vertex-ai/generative-ai/docs/agent-engine/use/adk#list-sessions " + "for more details." + ), + DeprecationWarning, + stacklevel=2, + ) + event_queue = queue.Queue() + + async def _invoke_async_list_sessions(): + try: + response = await self.async_list_sessions(user_id=user_id, **kwargs) + event_queue.put(response) + except RuntimeError as e: + event_queue.put(e) + + def _asyncio_thread_main(): + try: + asyncio.run(_invoke_async_list_sessions()) + finally: + event_queue.put(None) + + thread = threading.Thread(target=_asyncio_thread_main) + thread.start() + # Wait for the thread to finish + thread.join() + try: + return event_queue.get(timeout=10) + except queue.Empty: + raise RuntimeError("Failed to list sessions.") from None + async def async_create_session( self, *, @@ -1024,6 +1197,58 @@ async def async_create_session( ) return session + def create_session( + self, + *, + user_id: str, + session_id: Optional[str] = None, + state: Optional[Dict[str, Any]] = None, + **kwargs, + ): + """Deprecated. Use async_create_session instead. + + Creates a new session. + """ + warnings.warn( + ( + "AdkApp.create_session(...) is deprecated. " + "Use AdkApp.async_create_session(...) instead. See " + "https://cloud.google.com/vertex-ai/generative-ai/docs/agent-engine/use/adk#create-session " + "for more details." + ), + DeprecationWarning, + stacklevel=2, + ) + event_queue = queue.Queue(maxsize=1) + + async def _invoke_async_create_session(): + return await self.async_create_session( + user_id=user_id, + session_id=session_id, + state=state, + **kwargs, + ) + + def _asyncio_thread_main(): + try: + result = asyncio.run(_invoke_async_create_session()) + event_queue.put(result) + except RuntimeError as e: + event_queue.put(e) + + thread = threading.Thread(target=_asyncio_thread_main) + thread.start() + # Wait for the thread to finish + thread.join() + + try: + outcome = event_queue.get(timeout=10) + except queue.Empty: + raise RuntimeError("Failed to create session.") from None + if isinstance(outcome, RuntimeError): + raise outcome from None + return outcome + async def async_delete_session( self, *, @@ -1051,6 +1276,50 @@ async def async_delete_session( **kwargs, ) + def delete_session( + self, + *, + user_id: str, + session_id: str, + **kwargs, + ): + """Deprecated. Use async_delete_session instead. + + Deletes a session for the given user. + """ + warnings.warn( + ( + "AdkApp.delete_session(...) is deprecated. " + "Use AdkApp.async_delete_session(...) instead. See " + "https://cloud.google.com/vertex-ai/generative-ai/docs/agent-engine/use/adk#delete-session " + "for more details." + ), + DeprecationWarning, + stacklevel=2, + ) + event_queue = queue.Queue(maxsize=1) + + async def _invoke_async_delete_session(): + await self.async_delete_session( + user_id=user_id, session_id=session_id, **kwargs + ) + + def _asyncio_thread_main(): + try: + asyncio.run(_invoke_async_delete_session()) + event_queue.put(None) + except RuntimeError as e: + event_queue.put(e) + + thread = threading.Thread(target=_asyncio_thread_main) + thread.start() + # Wait for the thread to finish + thread.join() + + outcome = event_queue.get(timeout=10) + if isinstance(outcome, RuntimeError): + raise outcome from None + async def async_add_session_to_memory(self, *, session: Dict[str, Any]): """Generates memories. @@ -1099,6 +1368,12 @@ async def async_search_memory(self, *, user_id: str, query: str): def register_operations(self) -> Dict[str, List[str]]: """Registers the operations of the ADK application.""" return { + "": [ + "get_session", + "list_sessions", + "create_session", + "delete_session", + ], "async": [ "async_get_session", "async_list_sessions", @@ -1107,6 +1382,7 @@ def register_operations(self) -> Dict[str, List[str]]: "async_add_session_to_memory", "async_search_memory", ], + "stream": ["stream_query"], "async_stream": [ "async_stream_query", "streaming_agent_run_with_events",