diff --git a/vertexai/agent_engines/templates/adk.py b/vertexai/agent_engines/templates/adk.py index f3c186e903..97be3505c6 100644 --- a/vertexai/agent_engines/templates/adk.py +++ b/vertexai/agent_engines/templates/adk.py @@ -585,6 +585,7 @@ async def _init_session( ): """Initializes the session, and returns the session id.""" from google.adk.events.event import Event + import random session_state = None if request.authorizations: @@ -593,9 +594,14 @@ async def _init_session( auth = _Authorization(**auth) session_state[f"temp:{auth_id}"] = auth.access_token + if request.session_id: + session_id = request.session_id + else: + session_id = f"temp_session_{random.randbytes(8).hex()}" session = await session_service.create_session( app_name=self._tmpl_attrs.get("app_name"), user_id=request.user_id, + session_id=session_id, state=session_state, ) if not session: @@ -613,7 +619,7 @@ async def _init_session( saved_version = await artifact_service.save_artifact( app_name=self._tmpl_attrs.get("app_name"), user_id=request.user_id, - session_id=session.id, + session_id=session_id, filename=artifact.file_name, artifact=version_data.data, ) @@ -1052,53 +1058,35 @@ async def streaming_agent_run_with_events(self, request_json: str): import json from google.genai import types - from google.genai.errors import ClientError request = _StreamRunRequest(**json.loads(request_json)) if not self._tmpl_attrs.get("in_memory_runner"): self.set_up() - if not self._tmpl_attrs.get("runner"): - self.set_up() # Prepare the in-memory session. if not self._tmpl_attrs.get("in_memory_artifact_service"): self.set_up() - if not self._tmpl_attrs.get("artifact_service"): - self.set_up() if not self._tmpl_attrs.get("in_memory_session_service"): self.set_up() - if not self._tmpl_attrs.get("session_service"): - self.set_up() + session_service = self._tmpl_attrs.get("in_memory_session_service") + artifact_service = self._tmpl_attrs.get("in_memory_artifact_service") app = self._tmpl_attrs.get("app") - # Try to get the session, if it doesn't exist, create a new one. + session = None if request.session_id: - session_service = self._tmpl_attrs.get("session_service") - artifact_service = self._tmpl_attrs.get("artifact_service") - runner = self._tmpl_attrs.get("runner") try: session = await session_service.get_session( app_name=app.name if app else self._tmpl_attrs.get("app_name"), user_id=request.user_id, session_id=request.session_id, ) - except ClientError: - # Fall back to create session if the session is not found. - # Specifying session_id on creation is not supported, - # so session id will be regenerated. - session = await self._init_session( - session_service=session_service, - artifact_service=artifact_service, - request=request, - ) - else: - # Not providing a session ID will create a new in-memory session. - session_service = self._tmpl_attrs.get("in_memory_session_service") - artifact_service = self._tmpl_attrs.get("in_memory_artifact_service") - runner = self._tmpl_attrs.get("in_memory_runner") - session = await session_service.create_session( - app_name=self._tmpl_attrs.get("app_name"), - user_id=request.user_id, - session_id=request.session_id, + except RuntimeError: + pass + if not session: + # Fall back to create session if the session is not found. + session = await self._init_session( + session_service=session_service, + artifact_service=artifact_service, + request=request, ) if not session: raise RuntimeError("Session initialization failed.") @@ -1106,7 +1094,7 @@ async def streaming_agent_run_with_events(self, request_json: str): # Run the agent message_for_agent = types.Content(**request.message) try: - async for event in runner.run_async( + async for event in self._tmpl_attrs.get("in_memory_runner").run_async( user_id=request.user_id, session_id=session.id, new_message=message_for_agent, diff --git a/vertexai/preview/reasoning_engines/templates/adk.py b/vertexai/preview/reasoning_engines/templates/adk.py index 8b49b992c6..1bf5bb64a0 100644 --- a/vertexai/preview/reasoning_engines/templates/adk.py +++ b/vertexai/preview/reasoning_engines/templates/adk.py @@ -533,6 +533,7 @@ async def _init_session( ): """Initializes the session, and returns the session id.""" from google.adk.events.event import Event + import random session_state = None if request.authorizations: @@ -541,9 +542,14 @@ async def _init_session( auth = _Authorization(**auth) session_state[f"temp:{auth_id}"] = auth.access_token + if request.session_id: + session_id = request.session_id + else: + session_id = f"temp_session_{random.randbytes(8).hex()}" session = await session_service.create_session( app_name=self._tmpl_attrs.get("app_name"), user_id=request.user_id, + session_id=session_id, state=session_state, ) if not session: @@ -561,7 +567,7 @@ async def _init_session( saved_version = await artifact_service.save_artifact( app_name=self._tmpl_attrs.get("app_name"), user_id=request.user_id, - session_id=session.id, + session_id=session_id, filename=artifact.file_name, artifact=version_data.data, ) @@ -939,7 +945,6 @@ async def async_stream_query( def streaming_agent_run_with_events(self, request_json: str): import json from google.genai import types - from google.genai.errors import ClientError event_queue = queue.Queue(maxsize=1) @@ -947,52 +952,37 @@ async def _invoke_agent_async(): request = _StreamRunRequest(**json.loads(request_json)) if not self._tmpl_attrs.get("in_memory_runner"): self.set_up() - if not self._tmpl_attrs.get("runner"): - self.set_up() # Prepare the in-memory session. if not self._tmpl_attrs.get("in_memory_artifact_service"): self.set_up() - if not self._tmpl_attrs.get("artifact_service"): - self.set_up() if not self._tmpl_attrs.get("in_memory_session_service"): self.set_up() - if not self._tmpl_attrs.get("session_service"): - self.set_up() + session_service = self._tmpl_attrs.get("in_memory_session_service") + artifact_service = self._tmpl_attrs.get("in_memory_artifact_service") + # Try to get the session, if it doesn't exist, create a new one. + session = None if request.session_id: - session_service = self._tmpl_attrs.get("session_service") - artifact_service = self._tmpl_attrs.get("artifact_service") - runner = self._tmpl_attrs.get("runner") try: session = await session_service.get_session( app_name=self._tmpl_attrs.get("app_name"), user_id=request.user_id, session_id=request.session_id, ) - except ClientError: - # Fall back to create session if the session is not found. - # Specifying session_id on creation is not supported, - # so session id will be regenerated. - session = await self._init_session( - session_service=session_service, - artifact_service=artifact_service, - request=request, - ) - else: - # Not providing a session ID will create a new in-memory session. - session_service = self._tmpl_attrs.get("in_memory_session_service") - artifact_service = self._tmpl_attrs.get("in_memory_artifact_service") - runner = self._tmpl_attrs.get("in_memory_runner") - session = await session_service.create_session( - app_name=self._tmpl_attrs.get("app_name"), - user_id=request.user_id, - session_id=request.session_id, + except RuntimeError: + pass + if not session: + # Fall back to create session if the session is not found. + session = await self._init_session( + session_service=session_service, + artifact_service=artifact_service, + request=request, ) if not session: raise RuntimeError("Session initialization failed.") # Run the agent. message_for_agent = types.Content(**request.message) try: - for event in runner.run( + for event in self._tmpl_attrs.get("in_memory_runner").run( user_id=request.user_id, session_id=session.id, new_message=message_for_agent,