From 55b7c23950a40731ecbb3918f6e4721c66b1e9d3 Mon Sep 17 00:00:00 2001 From: A Vertex SDK engineer Date: Mon, 20 Oct 2025 15:12:11 -0700 Subject: [PATCH] feat: GenAI SDK client: Add async Memory and Memory Revisions methods PiperOrigin-RevId: 821824583 --- .../test_delete_agent_engine_memory.py | 13 +- .../test_generate_agent_engine_memories.py | 65 +++++ .../replays/test_get_agent_engine_memory.py | 15 +- .../test_list_agent_engine_memories.py | 35 ++- .../test_retrieve_agent_engine_memories.py | 43 ++- vertexai/_genai/memories.py | 256 +++++++++++++++++- vertexai/_genai/memory_revisions.py | 27 +- 7 files changed, 432 insertions(+), 22 deletions(-) diff --git a/tests/unit/vertexai/genai/replays/test_delete_agent_engine_memory.py b/tests/unit/vertexai/genai/replays/test_delete_agent_engine_memory.py index 120c0384f2..a9ab9b714c 100644 --- a/tests/unit/vertexai/genai/replays/test_delete_agent_engine_memory.py +++ b/tests/unit/vertexai/genai/replays/test_delete_agent_engine_memory.py @@ -23,15 +23,16 @@ def test_delete_memory(client): agent_engine = client.agent_engines.create() - operation = client.agent_engines.create_memory( + operation = client.agent_engines.memories.create( name=agent_engine.api_resource.name, fact="memory_fact", scope={"user_id": "123"}, ) memory = operation.response - operation = client.agent_engines.delete_memory(name=memory.name) + operation = client.agent_engines.memories.delete(name=memory.name) assert isinstance(operation, types.DeleteAgentEngineMemoryOperation) assert operation.name.startswith(memory.name + "/operations/") + client.agent_engines.delete(name=agent_engine.api_resource.name, force=True) pytestmark = pytest_helper.setup( @@ -46,14 +47,16 @@ def test_delete_memory(client): @pytest.mark.asyncio async def test_delete_memory_async(client): - # TODO(b/431785750): use async methods for create() and create_memory() when available agent_engine = client.agent_engines.create() - operation = client.agent_engines.create_memory( + operation = await client.aio.agent_engines.memories.create( name=agent_engine.api_resource.name, fact="memory_fact", scope={"user_id": "123"}, ) memory = operation.response - operation = await client.aio.agent_engines.delete_memory(name=memory.name) + operation = await client.aio.agent_engines.memories.delete(name=memory.name) assert isinstance(operation, types.DeleteAgentEngineMemoryOperation) assert operation.name.startswith(memory.name + "/operations/") + await client.aio.agent_engines.delete( + name=agent_engine.api_resource.name, force=True + ) diff --git a/tests/unit/vertexai/genai/replays/test_generate_agent_engine_memories.py b/tests/unit/vertexai/genai/replays/test_generate_agent_engine_memories.py index 8b2c67cab2..80298ecd13 100644 --- a/tests/unit/vertexai/genai/replays/test_generate_agent_engine_memories.py +++ b/tests/unit/vertexai/genai/replays/test_generate_agent_engine_memories.py @@ -14,12 +14,16 @@ # # pylint: disable=protected-access,bad-continuation,missing-function-docstring +import pytest + + from tests.unit.vertexai.genai.replays import pytest_helper from vertexai._genai import types from google.genai import types as genai_types def test_generate_and_rollback_memories(client): + # TODO(): Use prod endpoint once experiment is fully rolled out. client._api_client._http_options.base_url = ( "https://us-central1-autopush-aiplatform.sandbox.googleapis.com/" ) @@ -146,3 +150,64 @@ def test_generate_memories_direct_memories_source(client): globals_for_file=globals(), test_method="agent_engines.generate_memories", ) + + +pytest_plugins = ("pytest_asyncio",) + + +@pytest.mark.asyncio +async def test_generate_and_rollback_memories_async(client): + # TODO(): Use prod endpoint once revisions experiment is fully rolled out. + client._api_client._http_options.base_url = ( + "https://us-central1-autopush-aiplatform.sandbox.googleapis.com/" + ) + agent_engine = client.agent_engines.create() + await client.aio.agent_engines.memories.generate( + name=agent_engine.api_resource.name, + scope={"user_id": "test-user-id"}, + direct_memories_source=types.GenerateMemoriesRequestDirectMemoriesSource( + direct_memories=[ + types.GenerateMemoriesRequestDirectMemoriesSourceDirectMemory( + fact="I am a software engineer." + ), + types.GenerateMemoriesRequestDirectMemoriesSourceDirectMemory( + fact="I like to write replay tests." + ), + ] + ), + config=types.GenerateAgentEngineMemoriesConfig(wait_for_completion=True), + ) + memories_pager = await client.aio.agent_engines.memories.list( + name=agent_engine.api_resource.name + ) + memory_list = [item async for item in memories_pager] + assert len(memory_list) >= 1 + + revisions_pager = await client.aio.agent_engines.memories.revisions.list( + name=memory_list[0].name + ) + memory_revisions = [item async for item in revisions_pager] + assert len(memory_revisions) >= 1 + revision_name = memory_revisions[0].name + + # Update the memory. + client.agent_engines.memories._update( + name=memory_list[0].name, + fact="This is temporary", + scope={"user_id": "test-user-id"}, + ) + memory = await client.aio.agent_engines.memories.get(name=memory_list[0].name) + assert memory.fact == "This is temporary" + + # Rollback to the revision with the original fact that was created by the + # generation request. + await client.aio.agent_engines.memories.rollback( + name=memory_list[0].name, + target_revision_id=revision_name.split("/")[-1], + ) + memory = await client.aio.agent_engines.memories.get(name=memory_list[0].name) + assert memory.fact == memory_revisions[0].fact + + await client.aio.agent_engines.delete( + name=agent_engine.api_resource.name, force=True + ) diff --git a/tests/unit/vertexai/genai/replays/test_get_agent_engine_memory.py b/tests/unit/vertexai/genai/replays/test_get_agent_engine_memory.py index 4774ab90e2..04ef6b50bd 100644 --- a/tests/unit/vertexai/genai/replays/test_get_agent_engine_memory.py +++ b/tests/unit/vertexai/genai/replays/test_get_agent_engine_memory.py @@ -22,23 +22,24 @@ def test_get_memory(client): agent_engine = client.agent_engines.create() - operation = client.agent_engines.create_memory( + operation = client.agent_engines.memories.create( name=agent_engine.api_resource.name, fact="memory_fact", scope={"user_id": "123"}, ) assert isinstance(operation, types.AgentEngineMemoryOperation) - memory = client.agent_engines.get_memory( + memory = client.agent_engines.memories.get( name=operation.response.name, ) assert isinstance(memory, types.Memory) assert memory.name == operation.response.name + client.agent_engines.delete(name=agent_engine.api_resource.name, force=True) pytestmark = pytest_helper.setup( file=__file__, globals_for_file=globals(), - test_method="agent_engines.get_memory", + test_method="agent_engines.memories.get", ) @@ -47,16 +48,18 @@ def test_get_memory(client): @pytest.mark.asyncio async def test_get_memory_async(client): - # TODO(b/431785750): use async methods for create() and create_memory() when available agent_engine = client.agent_engines.create() - operation = client.agent_engines.create_memory( + operation = await client.aio.agent_engines.memories.create( name=agent_engine.api_resource.name, fact="memory_fact", scope={"user_id": "123"}, ) assert isinstance(operation, types.AgentEngineMemoryOperation) - memory = await client.aio.agent_engines.get_memory( + memory = await client.aio.agent_engines.memories.get( name=operation.response.name, ) assert isinstance(memory, types.Memory) assert memory.name == operation.response.name + await client.aio.agent_engines.delete( + name=agent_engine.api_resource.name, force=True + ) diff --git a/tests/unit/vertexai/genai/replays/test_list_agent_engine_memories.py b/tests/unit/vertexai/genai/replays/test_list_agent_engine_memories.py index 5ce3d77fb9..0e84b5c64e 100644 --- a/tests/unit/vertexai/genai/replays/test_list_agent_engine_memories.py +++ b/tests/unit/vertexai/genai/replays/test_list_agent_engine_memories.py @@ -14,6 +14,8 @@ # # pylint: disable=protected-access,bad-continuation,missing-function-docstring +import pytest + from tests.unit.vertexai.genai.replays import pytest_helper from vertexai._genai import types @@ -59,5 +61,36 @@ def test_list_memories(client): pytestmark = pytest_helper.setup( file=__file__, globals_for_file=globals(), - test_method="agent_engines.list_memories", + test_method="agent_engines.memories.list", ) + + +pytest_plugins = ("pytest_asyncio",) + + +@pytest.mark.asyncio +async def test_async_list_memories(client): + agent_engine = client.agent_engines.create() + pager = await client.aio.agent_engines.memories.list( + name=agent_engine.api_resource.name + ) + assert not [item async for item in pager] + + await client.aio.agent_engines.memories.create( + name=agent_engine.api_resource.name, + fact="memory_fact_2", + scope={"user_id": "456"}, + config={ + "wait_for_completion": True, + }, + ) + pager = await client.aio.agent_engines.memories.list( + name=agent_engine.api_resource.name + ) + memory_list = [item async for item in pager] + assert len(memory_list) == 1 + assert isinstance(memory_list[0], types.Memory) + + await client.aio.agent_engines.delete( + name=agent_engine.api_resource.name, force=True + ) diff --git a/tests/unit/vertexai/genai/replays/test_retrieve_agent_engine_memories.py b/tests/unit/vertexai/genai/replays/test_retrieve_agent_engine_memories.py index 745e86dcb6..f36166cfda 100644 --- a/tests/unit/vertexai/genai/replays/test_retrieve_agent_engine_memories.py +++ b/tests/unit/vertexai/genai/replays/test_retrieve_agent_engine_memories.py @@ -14,6 +14,9 @@ # # pylint: disable=protected-access,bad-continuation,missing-function-docstring +import pytest + + from tests.unit.vertexai.genai.replays import pytest_helper from vertexai._genai import types from google.genai import pagers @@ -22,7 +25,7 @@ def test_retrieve_memories_with_similarity_search_params(client): agent_engine = client.agent_engines.create() assert not list( - client.agent_engines.retrieve_memories( + client.agent_engines.memories.retrieve( name=agent_engine.api_resource.name, scope={"user_id": "123"}, similarity_search_params=types.RetrieveMemoriesRequestSimilaritySearchParams( @@ -30,7 +33,7 @@ def test_retrieve_memories_with_similarity_search_params(client): ), ) ) - client.agent_engines.create_memory( + client.agent_engines.memories.create( name=agent_engine.api_resource.name, fact="memory_fact_1", scope={"user_id": "123"}, @@ -38,7 +41,7 @@ def test_retrieve_memories_with_similarity_search_params(client): assert ( len( list( - client.agent_engines.retrieve_memories( + client.agent_engines.memories.retrieve( name=agent_engine.api_resource.name, scope={"user_id": "123"}, ) @@ -47,12 +50,12 @@ def test_retrieve_memories_with_similarity_search_params(client): == 1 ) assert not list( - client.agent_engines.retrieve_memories( + client.agent_engines.memories.retrieve( name=agent_engine.api_resource.name, scope={"user_id": "456"}, ) ) - client.agent_engines.create_memory( + client.agent_engines.memories.create( name=agent_engine.api_resource.name, fact="memory_fact_2", scope={"user_id": "123"}, @@ -60,7 +63,7 @@ def test_retrieve_memories_with_similarity_search_params(client): assert ( len( list( - client.agent_engines.retrieve_memories( + client.agent_engines.memories.retrieve( name=agent_engine.api_resource.name, scope={"user_id": "123"}, ) @@ -74,12 +77,12 @@ def test_retrieve_memories_with_similarity_search_params(client): def test_retrieve_memories_with_simple_retrieval_params(client): agent_engine = client.agent_engines.create() - client.agent_engines.create_memory( + client.agent_engines.memories.create( name=agent_engine.api_resource.name, fact="memory_fact_1", scope={"user_id": "123"}, ) - memories = client.agent_engines.retrieve_memories( + memories = client.agent_engines.memories.retrieve( name=agent_engine.api_resource.name, scope={"user_id": "123"}, simple_retrieval_params=types.RetrieveMemoriesRequestSimpleRetrievalParams( @@ -98,3 +101,27 @@ def test_retrieve_memories_with_simple_retrieval_params(client): globals_for_file=globals(), test_method="agent_engines.create_memory", ) + + +pytest_plugins = ("pytest_asyncio",) + + +@pytest.mark.asyncio +async def test_retrieve_memories_async(client): + agent_engine = client.agent_engines.create() + operation = await client.aio.agent_engines.memories.create( + name=agent_engine.api_resource.name, + fact="memory_fact", + scope={"user_id": "123"}, + ) + assert isinstance(operation, types.AgentEngineMemoryOperation) + pager = await client.aio.agent_engines.memories.retrieve( + name=agent_engine.api_resource.name, + scope={"user_id": "123"}, + ) + memories = [item async for item in pager] + assert len(memories) == 1 + assert isinstance(memories[0], types.RetrieveMemoriesResponseRetrievedMemory) + await client.aio.agent_engines.delete( + name=agent_engine.api_resource.name, force=True + ) diff --git a/vertexai/_genai/memories.py b/vertexai/_genai/memories.py index 747d6b5f96..57f76d3c3a 100644 --- a/vertexai/_genai/memories.py +++ b/vertexai/_genai/memories.py @@ -26,7 +26,7 @@ from google.genai import _common from google.genai._common import get_value_by_path as getv from google.genai._common import set_value_by_path as setv -from google.genai.pagers import Pager +from google.genai.pagers import AsyncPager, Pager from . import _agent_engines_utils from . import types @@ -1048,6 +1048,7 @@ def create( operation = _agent_engines_utils._await_operation( operation_name=operation.name, get_operation_fn=self._get_memory_operation, + poll_interval_seconds=0.5, ) # We need to make a call to get the memory because the operation # response might not contain the relevant fields. @@ -1855,3 +1856,256 @@ async def _update( self._api_client._verify_response(return_value) return return_value + + _revisions = None + + @property + def revisions(self): + if self._revisions is None: + try: + # We need to lazy load the revisions module to handle the + # possibility of ImportError when dependencies are not installed. + self._revisions = importlib.import_module( + ".memory_revisions", __package__ + ) + except ImportError as e: + raise ImportError( + "The 'agent_engines.memories.revisions' module requires " + "additional packages. Please install them using pip install " + "google-cloud-aiplatform[agent_engines]" + ) from e + return self._revisions.AsyncMemoryRevisions(self._api_client) + + async def create( + self, + *, + name: str, + fact: str, + scope: dict[str, str], + config: Optional[types.AgentEngineMemoryConfigOrDict] = None, + ) -> types.AgentEngineMemoryOperation: + """Creates a new memory in the Agent Engine. + + Args: + name (str): + Required. The name of the memory to create. + fact (str): + Required. The fact to be stored in the memory. + scope (dict[str, str]): + Required. The scope of the memory. For example, {"user_id": "123"}. + config (AgentEngineMemoryConfigOrDict): + Optional. The configuration for the memory. + + Returns: + AgentEngineMemoryOperation: The operation for creating the memory. + """ + if config is None: + config = types.AgentEngineMemoryConfig() + elif isinstance(config, dict): + config = types.AgentEngineMemoryConfig.model_validate(config) + operation = await self._create( + name=name, + fact=fact, + scope=scope, + config=config, + ) + if config.wait_for_completion: + if not operation.done: + operation = await _agent_engines_utils._await_async_operation( + operation_name=operation.name, + get_operation_fn=self._get_memory_operation, + poll_interval_seconds=0.5, + ) + # We need to make a call to get the memory because the operation + # response might not contain the relevant fields. + if operation.response: + operation.response = await self.get(name=operation.response.name) + elif operation.error: + raise RuntimeError(f"Failed to create memory: {operation.error}") + else: + raise RuntimeError("Error creating memory.") + return operation + + async def generate( + self, + *, + name: str, + vertex_session_source: Optional[ + types.GenerateMemoriesRequestVertexSessionSourceOrDict + ] = None, + direct_contents_source: Optional[ + types.GenerateMemoriesRequestDirectContentsSourceOrDict + ] = None, + direct_memories_source: Optional[ + types.GenerateMemoriesRequestDirectMemoriesSourceOrDict + ] = None, + scope: Optional[dict[str, str]] = None, + config: Optional[types.GenerateAgentEngineMemoriesConfigOrDict] = None, + ) -> types.AgentEngineGenerateMemoriesOperation: + """Generates memories for the agent engine. + + Args: + name (str): + Required. The name of the agent engine to generate memories for. + vertex_session_source (GenerateMemoriesRequestVertexSessionSource): + Optional. The vertex session source to use for generating + memories. Only one of vertex_session_source, + direct_contents_source, or direct_memories_source can be + specified. + direct_contents_source(GenerateMemoriesRequestDirectContentsSource): + Optional. The direct contents source to use for generating + memories. Only one of vertex_session_source, direct_contents_source, + or direct_memories_source can be specified. + direct_memories_source (GenerateMemoriesRequestDirectMemoriesSource): + Optional. The direct memories source to use for generating + memories. Only one of vertex_session_source, direct_contents_source, + or direct_memories_source can be specified. + scope (dict[str, str]): + Optional. The scope of the memories to generate. This is optional + if vertex_session_source is used, otherwise it must be specified. + config (GenerateMemoriesConfig): + Optional. The configuration for the memories to generate. + + Returns: + AgentEngineGenerateMemoriesOperation: + The operation for generating the memories. + """ + if config is None: + config = types.GenerateAgentEngineMemoriesConfig() + elif isinstance(config, dict): + config = types.GenerateAgentEngineMemoriesConfig.model_validate(config) + operation = await self._generate( + name=name, + vertex_session_source=vertex_session_source, + direct_contents_source=direct_contents_source, + direct_memories_source=direct_memories_source, + scope=scope, + config=config, + ) + if config.wait_for_completion and not operation.done: + operation = await _agent_engines_utils._await_async_operation( + operation_name=operation.name, + get_operation_fn=self._get_generate_memories_operation, + poll_interval_seconds=0.5, + ) + if operation.error: + raise RuntimeError(f"Failed to generate memory: {operation.error}") + return operation + + async def list( + self, + *, + name: str, + config: Optional[types.ListAgentEngineMemoryConfigOrDict] = None, + ) -> AsyncPager[types.Memory]: + """Lists Agent Engine memories. + + Args: + name (str): + Required. The name of the agent engine to list memories for. + config (ListAgentEngineMemoryConfig): + Optional. The configuration for the memories to list. + + Returns: + AsyncPager[Memory]: An async pager of memories. + """ + + return AsyncPager( + "memories", + functools.partial(self._list, name=name), + await self._list(name=name, config=config), + config, + ) + + async def retrieve( + self, + *, + name: str, + scope: dict[str, str], + similarity_search_params: Optional[ + types.RetrieveMemoriesRequestSimilaritySearchParamsOrDict + ] = None, + simple_retrieval_params: Optional[ + types.RetrieveMemoriesRequestSimpleRetrievalParamsOrDict + ] = None, + config: Optional[types.RetrieveAgentEngineMemoriesConfigOrDict] = None, + ) -> AsyncPager[types.RetrieveMemoriesResponseRetrievedMemory]: + """Retrieves memories for the agent. + + Args: + name (str): + Required. The name of the agent engine to retrieve memories for. + scope (dict[str, str]): + Required. The scope of the memories to retrieve. For example, + {"user_id": "123"}. + similarity_search_params (RetrieveMemoriesRequestSimilaritySearchParams): + Optional. The similarity search parameters to use for retrieving + memories. + simple_retrieval_params (RetrieveMemoriesRequestSimpleRetrievalParams): + Optional. The simple retrieval parameters to use for retrieving + memories. + config (RetrieveAgentEngineMemoriesConfig): + Optional. The configuration for the memories to retrieve. + + Returns: + AsyncPager[RetrieveMemoriesResponseRetrievedMemory]: An async pager of + retrieved memories. + """ + return AsyncPager( + "retrieved_memories", + lambda config: self._retrieve( + name=name, + similarity_search_params=similarity_search_params, + simple_retrieval_params=simple_retrieval_params, + scope=scope, + config=config, + ), + await self._retrieve( + name=name, + similarity_search_params=similarity_search_params, + simple_retrieval_params=simple_retrieval_params, + scope=scope, + config=config, + ), + config, + ) + + async def rollback( + self, + *, + name: str, + target_revision_id: str, + config: Optional[types.RollbackAgentEngineMemoryConfigOrDict] = None, + ) -> types.AgentEngineRollbackMemoryOperation: + """Rolls back a memory to a previous revision. + + Args: + name (str): + Required. The name of the memory to rollback. + target_revision_id (str): + Required. The revision ID to roll back to + config (RollbackAgentEngineMemoryConfig): + Optional. The configuration for the rollback. + + Returns: + AgentEngineRollbackMemoryOperation: + The operation for rolling back the memory. + """ + if config is None: + config = types.RollbackAgentEngineMemoryConfig() + elif isinstance(config, dict): + config = types.RollbackAgentEngineMemoryConfig.model_validate(config) + operation = await self._rollback( + name=name, + target_revision_id=target_revision_id, + config=config, + ) + if config.wait_for_completion and not operation.done: + operation = await _agent_engines_utils._await_async_operation( + operation_name=operation.name, + get_operation_fn=self._get_rollback_memory_operation, + poll_interval_seconds=0.5, + ) + if operation.error: + raise RuntimeError(f"Failed to rollback memory: {operation.error}") + return operation diff --git a/vertexai/_genai/memory_revisions.py b/vertexai/_genai/memory_revisions.py index 53541c0b4c..765ac32347 100644 --- a/vertexai/_genai/memory_revisions.py +++ b/vertexai/_genai/memory_revisions.py @@ -25,7 +25,7 @@ from google.genai import _common from google.genai._common import get_value_by_path as getv from google.genai._common import set_value_by_path as setv -from google.genai.pagers import Pager +from google.genai.pagers import AsyncPager, Pager from . import types @@ -380,3 +380,28 @@ async def _list( self._api_client._verify_response(return_value) return return_value + + async def list( + self, + *, + name: str, + config: Optional[types.ListAgentEngineMemoryRevisionsConfigOrDict] = None, + ) -> AsyncPager[types.MemoryRevision]: + """Lists Agent Engine memory revisions. + + Args: + name (str): + Required. The name of the Memory to list revisions for. + config (ListAgentEngineMemoryRevisionsConfigOrDict): + Optional. The configuration for the memories to list revisions. + + Returns: + AsyncPager[MemoryRevision]: An async pager of memory revisions. + """ + + return AsyncPager( + "memory_revisions", + functools.partial(self._list, name=name), + await self._list(name=name, config=config), + config, + )