diff --git a/tests/unit/vertexai/genai/replays/test_generate_agent_engine_memories.py b/tests/unit/vertexai/genai/replays/test_generate_agent_engine_memories.py index 80298ecd13..a22be6c44c 100644 --- a/tests/unit/vertexai/genai/replays/test_generate_agent_engine_memories.py +++ b/tests/unit/vertexai/genai/replays/test_generate_agent_engine_memories.py @@ -95,23 +95,27 @@ def test_generate_and_rollback_memories(client): # Update the memory again using generation. We use the original source # content to ensure that the original memory is updated. The response should # refer to the previous revision. + pre_extracted_fact = "I am a software engineer focusing in security" response = client.agent_engines.memories.generate( name=agent_engine.api_resource.name, scope={"user_id": "test-user-id"}, - direct_contents_source=types.GenerateMemoriesRequestDirectContentsSource( - events=[ - types.GenerateMemoriesRequestDirectContentsSourceEvent( - content=genai_types.Content( - role="model", - parts=[genai_types.Part(text=memory_revisions[0].fact)], - ) + direct_memories_source=types.GenerateMemoriesRequestDirectMemoriesSource( + direct_memories=[ + types.GenerateMemoriesRequestDirectMemoriesSourceDirectMemory( + fact=pre_extracted_fact ) ] ), ) # The memory was updated, so the previous revision is set. assert response.response.generated_memories[0].previous_revision is not None - + memory_revisions = list( + client.agent_engines.memories.revisions.list(name=memories[0].name) + ) + # Memory Revisions are returned in descending order by revision create time. + # We can't make an assertion on the actual value, since it's + # generated and thus non-deterministic. + assert memory_revisions[0].extracted_memories[0].fact is not None client.agent_engines.delete(name=agent_engine.api_resource.name, force=True) diff --git a/tests/unit/vertexai/genai/replays/test_retrieve_agent_engine_memories.py b/tests/unit/vertexai/genai/replays/test_retrieve_agent_engine_memories.py index f36166cfda..15cb7caca8 100644 --- a/tests/unit/vertexai/genai/replays/test_retrieve_agent_engine_memories.py +++ b/tests/unit/vertexai/genai/replays/test_retrieve_agent_engine_memories.py @@ -92,6 +92,25 @@ def test_retrieve_memories_with_simple_retrieval_params(client): assert isinstance(memories, pagers.Pager) assert isinstance(memories.page[0], types.RetrieveMemoriesResponseRetrievedMemory) assert memories.page_size == 1 + + client.agent_engines.memories.create( + name=agent_engine.api_resource.name, + fact="memory_fact_2", + scope={"user_id": "123"}, + ) + memories = client.agent_engines.memories.retrieve( + name=agent_engine.api_resource.name, scope={"user_id": "123"} + ) + assert memories.page_size == 2 + + memories = client.agent_engines.memories.retrieve( + name=agent_engine.api_resource.name, + scope={"user_id": "123"}, + config={"filter": 'fact="memory_fact_2"'}, + ) + assert memories.page_size == 1 + assert memories.page[0].memory.fact == "memory_fact_2" + # Clean up resources. agent_engine.delete(force=True) @@ -99,7 +118,7 @@ def test_retrieve_memories_with_simple_retrieval_params(client): pytestmark = pytest_helper.setup( file=__file__, globals_for_file=globals(), - test_method="agent_engines.create_memory", + test_method="agent_engines.memories.retrieve", ) diff --git a/vertexai/_genai/memories.py b/vertexai/_genai/memories.py index 4ac889c513..6e4e5f3b8c 100644 --- a/vertexai/_genai/memories.py +++ b/vertexai/_genai/memories.py @@ -287,6 +287,18 @@ def _ListAgentEngineMemoryRequestParameters_to_vertex( return to_object +def _RetrieveAgentEngineMemoriesConfig_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + + if getv(from_object, ["filter"]) is not None: + setv(parent_object, ["filter"], getv(from_object, ["filter"])) + + return to_object + + def _RetrieveAgentEngineMemoriesRequestParameters_to_vertex( from_object: Union[dict[str, Any], object], parent_object: Optional[dict[str, Any]] = None, @@ -313,7 +325,13 @@ def _RetrieveAgentEngineMemoriesRequestParameters_to_vertex( ) if getv(from_object, ["config"]) is not None: - setv(to_object, ["config"], getv(from_object, ["config"])) + setv( + to_object, + ["config"], + _RetrieveAgentEngineMemoriesConfig_to_vertex( + getv(from_object, ["config"]), to_object + ), + ) return to_object diff --git a/vertexai/_genai/types/__init__.py b/vertexai/_genai/types/__init__.py index b54d224c12..9110317794 100644 --- a/vertexai/_genai/types/__init__.py +++ b/vertexai/_genai/types/__init__.py @@ -433,6 +433,9 @@ from .common import GetPromptConfigDict from .common import GetPromptConfigOrDict from .common import Importance +from .common import IntermediateExtractedMemory +from .common import IntermediateExtractedMemoryDict +from .common import IntermediateExtractedMemoryOrDict from .common import JobState from .common import Language from .common import ListAgentEngineConfig @@ -1499,6 +1502,9 @@ "GetAgentEngineMemoryRevisionConfig", "GetAgentEngineMemoryRevisionConfigDict", "GetAgentEngineMemoryRevisionConfigOrDict", + "IntermediateExtractedMemory", + "IntermediateExtractedMemoryDict", + "IntermediateExtractedMemoryOrDict", "MemoryRevision", "MemoryRevisionDict", "MemoryRevisionOrDict", diff --git a/vertexai/_genai/types/common.py b/vertexai/_genai/types/common.py index 151adceaf8..d2285eb195 100644 --- a/vertexai/_genai/types/common.py +++ b/vertexai/_genai/types/common.py @@ -7520,6 +7520,17 @@ class RetrieveAgentEngineMemoriesConfig(_common.BaseModel): http_options: Optional[genai_types.HttpOptions] = Field( default=None, description="""Used to override HTTP request options.""" ) + filter: Optional[str] = Field( + default=None, + description="""The standard list filter that will be applied to the retrieved + memories. More detail in [AIP-160](https://google.aip.dev/160). + + Supported fields: + * `fact` + * `create_time` + * `update_time` + """, + ) class RetrieveAgentEngineMemoriesConfigDict(TypedDict, total=False): @@ -7528,6 +7539,16 @@ class RetrieveAgentEngineMemoriesConfigDict(TypedDict, total=False): http_options: Optional[genai_types.HttpOptionsDict] """Used to override HTTP request options.""" + filter: Optional[str] + """The standard list filter that will be applied to the retrieved + memories. More detail in [AIP-160](https://google.aip.dev/160). + + Supported fields: + * `fact` + * `create_time` + * `update_time` + """ + RetrieveAgentEngineMemoriesConfigOrDict = Union[ RetrieveAgentEngineMemoriesConfig, RetrieveAgentEngineMemoriesConfigDict @@ -7946,6 +7967,26 @@ class _GetAgentEngineMemoryRevisionRequestParametersDict(TypedDict, total=False) ] +class IntermediateExtractedMemory(_common.BaseModel): + """An extracted memory that is the intermediate result before consolidation.""" + + fact: Optional[str] = Field( + default=None, description="""Output only. The fact of the extracted memory.""" + ) + + +class IntermediateExtractedMemoryDict(TypedDict, total=False): + """An extracted memory that is the intermediate result before consolidation.""" + + fact: Optional[str] + """Output only. The fact of the extracted memory.""" + + +IntermediateExtractedMemoryOrDict = Union[ + IntermediateExtractedMemory, IntermediateExtractedMemoryDict +] + + class MemoryRevision(_common.BaseModel): """A memory revision.""" @@ -7969,6 +8010,10 @@ class MemoryRevision(_common.BaseModel): default=None, description="""Output only. The labels of the Memory Revision. These labels are applied to the MemoryRevision when it is created based on `GenerateMemoriesRequest.revision_labels`.""", ) + extracted_memories: Optional[list[IntermediateExtractedMemory]] = Field( + default=None, + description="""Output only. The extracted memories from the source content before consolidation when the memory was updated via GenerateMemories. This information was used to modify an existing Memory via Consolidation.""", + ) class MemoryRevisionDict(TypedDict, total=False): @@ -7989,6 +8034,9 @@ class MemoryRevisionDict(TypedDict, total=False): labels: Optional[dict[str, str]] """Output only. The labels of the Memory Revision. These labels are applied to the MemoryRevision when it is created based on `GenerateMemoriesRequest.revision_labels`.""" + extracted_memories: Optional[list[IntermediateExtractedMemoryDict]] + """Output only. The extracted memories from the source content before consolidation when the memory was updated via GenerateMemories. This information was used to modify an existing Memory via Consolidation.""" + MemoryRevisionOrDict = Union[MemoryRevision, MemoryRevisionDict]