Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -95,23 +95,27 @@ def test_generate_and_rollback_memories(client):
# Update the memory again using generation. We use the original source
# content to ensure that the original memory is updated. The response should
# refer to the previous revision.
pre_extracted_fact = "I am a software engineer focusing in security"
response = client.agent_engines.memories.generate(
name=agent_engine.api_resource.name,
scope={"user_id": "test-user-id"},
direct_contents_source=types.GenerateMemoriesRequestDirectContentsSource(
events=[
types.GenerateMemoriesRequestDirectContentsSourceEvent(
content=genai_types.Content(
role="model",
parts=[genai_types.Part(text=memory_revisions[0].fact)],
)
direct_memories_source=types.GenerateMemoriesRequestDirectMemoriesSource(
direct_memories=[
types.GenerateMemoriesRequestDirectMemoriesSourceDirectMemory(
fact=pre_extracted_fact
)
]
),
)
# The memory was updated, so the previous revision is set.
assert response.response.generated_memories[0].previous_revision is not None

memory_revisions = list(
client.agent_engines.memories.revisions.list(name=memories[0].name)
)
# Memory Revisions are returned in descending order by revision create time.
# We can't make an assertion on the actual value, since it's
# generated and thus non-deterministic.
assert memory_revisions[0].extracted_memories[0].fact is not None
client.agent_engines.delete(name=agent_engine.api_resource.name, force=True)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,14 +92,33 @@ def test_retrieve_memories_with_simple_retrieval_params(client):
assert isinstance(memories, pagers.Pager)
assert isinstance(memories.page[0], types.RetrieveMemoriesResponseRetrievedMemory)
assert memories.page_size == 1

client.agent_engines.memories.create(
name=agent_engine.api_resource.name,
fact="memory_fact_2",
scope={"user_id": "123"},
)
memories = client.agent_engines.memories.retrieve(
name=agent_engine.api_resource.name, scope={"user_id": "123"}
)
assert memories.page_size == 2

memories = client.agent_engines.memories.retrieve(
name=agent_engine.api_resource.name,
scope={"user_id": "123"},
config={"filter": 'fact="memory_fact_2"'},
)
assert memories.page_size == 1
assert memories.page[0].memory.fact == "memory_fact_2"

# Clean up resources.
agent_engine.delete(force=True)


pytestmark = pytest_helper.setup(
file=__file__,
globals_for_file=globals(),
test_method="agent_engines.create_memory",
test_method="agent_engines.memories.retrieve",
)


Expand Down
20 changes: 19 additions & 1 deletion vertexai/_genai/memories.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,18 @@ def _ListAgentEngineMemoryRequestParameters_to_vertex(
return to_object


def _RetrieveAgentEngineMemoriesConfig_to_vertex(
from_object: Union[dict[str, Any], object],
parent_object: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
to_object: dict[str, Any] = {}

if getv(from_object, ["filter"]) is not None:
setv(parent_object, ["filter"], getv(from_object, ["filter"]))

return to_object


def _RetrieveAgentEngineMemoriesRequestParameters_to_vertex(
from_object: Union[dict[str, Any], object],
parent_object: Optional[dict[str, Any]] = None,
Expand All @@ -313,7 +325,13 @@ def _RetrieveAgentEngineMemoriesRequestParameters_to_vertex(
)

if getv(from_object, ["config"]) is not None:
setv(to_object, ["config"], getv(from_object, ["config"]))
setv(
to_object,
["config"],
_RetrieveAgentEngineMemoriesConfig_to_vertex(
getv(from_object, ["config"]), to_object
),
)

return to_object

Expand Down
6 changes: 6 additions & 0 deletions vertexai/_genai/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,9 @@
from .common import GetPromptConfigDict
from .common import GetPromptConfigOrDict
from .common import Importance
from .common import IntermediateExtractedMemory
from .common import IntermediateExtractedMemoryDict
from .common import IntermediateExtractedMemoryOrDict
from .common import JobState
from .common import Language
from .common import ListAgentEngineConfig
Expand Down Expand Up @@ -1499,6 +1502,9 @@
"GetAgentEngineMemoryRevisionConfig",
"GetAgentEngineMemoryRevisionConfigDict",
"GetAgentEngineMemoryRevisionConfigOrDict",
"IntermediateExtractedMemory",
"IntermediateExtractedMemoryDict",
"IntermediateExtractedMemoryOrDict",
"MemoryRevision",
"MemoryRevisionDict",
"MemoryRevisionOrDict",
Expand Down
48 changes: 48 additions & 0 deletions vertexai/_genai/types/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -7520,6 +7520,17 @@ class RetrieveAgentEngineMemoriesConfig(_common.BaseModel):
http_options: Optional[genai_types.HttpOptions] = Field(
default=None, description="""Used to override HTTP request options."""
)
filter: Optional[str] = Field(
default=None,
description="""The standard list filter that will be applied to the retrieved
memories. More detail in [AIP-160](https://google.aip.dev/160).

Supported fields:
* `fact`
* `create_time`
* `update_time`
""",
)


class RetrieveAgentEngineMemoriesConfigDict(TypedDict, total=False):
Expand All @@ -7528,6 +7539,16 @@ class RetrieveAgentEngineMemoriesConfigDict(TypedDict, total=False):
http_options: Optional[genai_types.HttpOptionsDict]
"""Used to override HTTP request options."""

filter: Optional[str]
"""The standard list filter that will be applied to the retrieved
memories. More detail in [AIP-160](https://google.aip.dev/160).

Supported fields:
* `fact`
* `create_time`
* `update_time`
"""


RetrieveAgentEngineMemoriesConfigOrDict = Union[
RetrieveAgentEngineMemoriesConfig, RetrieveAgentEngineMemoriesConfigDict
Expand Down Expand Up @@ -7946,6 +7967,26 @@ class _GetAgentEngineMemoryRevisionRequestParametersDict(TypedDict, total=False)
]


class IntermediateExtractedMemory(_common.BaseModel):
"""An extracted memory that is the intermediate result before consolidation."""

fact: Optional[str] = Field(
default=None, description="""Output only. The fact of the extracted memory."""
)


class IntermediateExtractedMemoryDict(TypedDict, total=False):
"""An extracted memory that is the intermediate result before consolidation."""

fact: Optional[str]
"""Output only. The fact of the extracted memory."""


IntermediateExtractedMemoryOrDict = Union[
IntermediateExtractedMemory, IntermediateExtractedMemoryDict
]


class MemoryRevision(_common.BaseModel):
"""A memory revision."""

Expand All @@ -7969,6 +8010,10 @@ class MemoryRevision(_common.BaseModel):
default=None,
description="""Output only. The labels of the Memory Revision. These labels are applied to the MemoryRevision when it is created based on `GenerateMemoriesRequest.revision_labels`.""",
)
extracted_memories: Optional[list[IntermediateExtractedMemory]] = Field(
default=None,
description="""Output only. The extracted memories from the source content before consolidation when the memory was updated via GenerateMemories. This information was used to modify an existing Memory via Consolidation.""",
)


class MemoryRevisionDict(TypedDict, total=False):
Expand All @@ -7989,6 +8034,9 @@ class MemoryRevisionDict(TypedDict, total=False):
labels: Optional[dict[str, str]]
"""Output only. The labels of the Memory Revision. These labels are applied to the MemoryRevision when it is created based on `GenerateMemoriesRequest.revision_labels`."""

extracted_memories: Optional[list[IntermediateExtractedMemoryDict]]
"""Output only. The extracted memories from the source content before consolidation when the memory was updated via GenerateMemories. This information was used to modify an existing Memory via Consolidation."""


MemoryRevisionOrDict = Union[MemoryRevision, MemoryRevisionDict]

Expand Down
Loading