From d3b12d57d1e4f1a8db5f41f597c6cb0f33e8a369 Mon Sep 17 00:00:00 2001 From: Tongzhou Jiang Date: Fri, 7 Nov 2025 10:59:57 -0800 Subject: [PATCH] feat: Reenable VertexAiSession for streaming_agent_run_with_events PiperOrigin-RevId: 829503268 --- vertexai/agent_engines/templates/adk.py | 42 ++++++++++++------- .../reasoning_engines/templates/adk.py | 40 +++++++++++------- 2 files changed, 52 insertions(+), 30 deletions(-) diff --git a/vertexai/agent_engines/templates/adk.py b/vertexai/agent_engines/templates/adk.py index 817e8475fc..fba12a863e 100644 --- a/vertexai/agent_engines/templates/adk.py +++ b/vertexai/agent_engines/templates/adk.py @@ -602,7 +602,6 @@ async def _init_session( ): """Initializes the session, and returns the session id.""" from google.adk.events.event import Event - import random session_state = None if request.authorizations: @@ -611,14 +610,9 @@ async def _init_session( auth = _Authorization(**auth) session_state[auth_id] = auth.access_token - if request.session_id: - session_id = request.session_id - else: - session_id = f"temp_session_{random.randbytes(8).hex()}" session = await session_service.create_session( app_name=self._tmpl_attrs.get("app_name"), user_id=request.user_id, - session_id=session_id, state=session_state, ) if not session: @@ -636,7 +630,7 @@ async def _init_session( saved_version = await artifact_service.save_artifact( app_name=self._tmpl_attrs.get("app_name"), user_id=request.user_id, - session_id=session_id, + session_id=session.id, filename=artifact.file_name, artifact=version_data.data, ) @@ -1078,31 +1072,49 @@ async def streaming_agent_run_with_events(self, request_json: str): import json from google.genai import types + from google.genai.errors import ClientError request = _StreamRunRequest(**json.loads(request_json)) if not self._tmpl_attrs.get("in_memory_runner"): self.set_up() + if not self._tmpl_attrs.get("runner"): + self.set_up() # Prepare the in-memory session. if not self._tmpl_attrs.get("in_memory_artifact_service"): self.set_up() + if not self._tmpl_attrs.get("artifact_service"): + self.set_up() if not self._tmpl_attrs.get("in_memory_session_service"): self.set_up() - session_service = self._tmpl_attrs.get("in_memory_session_service") - artifact_service = self._tmpl_attrs.get("in_memory_artifact_service") + if not self._tmpl_attrs.get("session_service"): + self.set_up() app = self._tmpl_attrs.get("app") + # Try to get the session, if it doesn't exist, create a new one. - session = None if request.session_id: + session_service = self._tmpl_attrs.get("session_service") + artifact_service = self._tmpl_attrs.get("artifact_service") + runner = self._tmpl_attrs.get("runner") try: session = await session_service.get_session( app_name=app.name if app else self._tmpl_attrs.get("app_name"), user_id=request.user_id, session_id=request.session_id, ) - except RuntimeError: - pass - if not session: - # Fall back to create session if the session is not found. + except ClientError: + # Fall back to create session if the session is not found. + # Specifying session_id on creation is not supported, + # so session id will be regenerated. + session = await self._init_session( + session_service=session_service, + artifact_service=artifact_service, + request=request, + ) + else: + # Not providing a session ID will create a new in-memory session. + session_service = self._tmpl_attrs.get("in_memory_session_service") + artifact_service = self._tmpl_attrs.get("in_memory_artifact_service") + runner = self._tmpl_attrs.get("in_memory_runner") session = await self._init_session( session_service=session_service, artifact_service=artifact_service, @@ -1114,7 +1126,7 @@ async def streaming_agent_run_with_events(self, request_json: str): # Run the agent message_for_agent = types.Content(**request.message) try: - async for event in self._tmpl_attrs.get("in_memory_runner").run_async( + async for event in runner.run_async( user_id=request.user_id, session_id=session.id, new_message=message_for_agent, diff --git a/vertexai/preview/reasoning_engines/templates/adk.py b/vertexai/preview/reasoning_engines/templates/adk.py index a14d314b68..9ba7a276d2 100644 --- a/vertexai/preview/reasoning_engines/templates/adk.py +++ b/vertexai/preview/reasoning_engines/templates/adk.py @@ -550,7 +550,6 @@ async def _init_session( ): """Initializes the session, and returns the session id.""" from google.adk.events.event import Event - import random session_state = None if request.authorizations: @@ -559,14 +558,9 @@ async def _init_session( auth = _Authorization(**auth) session_state[auth_id] = auth.access_token - if request.session_id: - session_id = request.session_id - else: - session_id = f"temp_session_{random.randbytes(8).hex()}" session = await session_service.create_session( app_name=self._tmpl_attrs.get("app_name"), user_id=request.user_id, - session_id=session_id, state=session_state, ) if not session: @@ -584,7 +578,7 @@ async def _init_session( saved_version = await artifact_service.save_artifact( app_name=self._tmpl_attrs.get("app_name"), user_id=request.user_id, - session_id=session_id, + session_id=session.id, filename=artifact.file_name, artifact=version_data.data, ) @@ -965,6 +959,7 @@ async def async_stream_query( def streaming_agent_run_with_events(self, request_json: str): import json from google.genai import types + from google.genai.errors import ClientError event_queue = queue.Queue(maxsize=1) @@ -972,26 +967,41 @@ async def _invoke_agent_async(): request = _StreamRunRequest(**json.loads(request_json)) if not self._tmpl_attrs.get("in_memory_runner"): self.set_up() + if not self._tmpl_attrs.get("runner"): + self.set_up() # Prepare the in-memory session. if not self._tmpl_attrs.get("in_memory_artifact_service"): self.set_up() + if not self._tmpl_attrs.get("artifact_service"): + self.set_up() if not self._tmpl_attrs.get("in_memory_session_service"): self.set_up() - session_service = self._tmpl_attrs.get("in_memory_session_service") - artifact_service = self._tmpl_attrs.get("in_memory_artifact_service") + if not self._tmpl_attrs.get("session_service"): + self.set_up() + # Try to get the session, if it doesn't exist, create a new one. - session = None if request.session_id: + session_service = self._tmpl_attrs.get("session_service") + artifact_service = self._tmpl_attrs.get("artifact_service") + runner = self._tmpl_attrs.get("runner") try: session = await session_service.get_session( app_name=self._tmpl_attrs.get("app_name"), user_id=request.user_id, session_id=request.session_id, ) - except RuntimeError: - pass - if not session: - # Fall back to create session if the session is not found. + except ClientError: + # Fall back to create session if the session is not found. + session = await self._init_session( + session_service=session_service, + artifact_service=artifact_service, + request=request, + ) + else: + # Not providing a session ID will create a new in-memory session. + session_service = self._tmpl_attrs.get("in_memory_session_service") + artifact_service = self._tmpl_attrs.get("in_memory_artifact_service") + runner = self._tmpl_attrs.get("in_memory_runner") session = await self._init_session( session_service=session_service, artifact_service=artifact_service, @@ -1002,7 +1012,7 @@ async def _invoke_agent_async(): # Run the agent. message_for_agent = types.Content(**request.message) try: - for event in self._tmpl_attrs.get("in_memory_runner").run( + for event in runner.run_async( user_id=request.user_id, session_id=session.id, new_message=message_for_agent,