Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions tests/unit/vertexai/genai/replays/test_create_agent_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import os

from tests.unit.vertexai.genai.replays import pytest_helper
from vertexai._genai import types


def test_create_config_lightweight(client):
Expand Down Expand Up @@ -54,6 +55,29 @@ def test_create_with_context_spec(client):
parent = f"projects/{project}/locations/{location}"
generation_model = f"{parent}/publishers/google/models/gemini-2.0-flash-001"
embedding_model = f"{parent}/publishers/google/models/text-embedding-005"
customization_config = {
"memory_topics": [
{"managed_memory_topic": {"managed_topic_enum": "USER_PREFERENCES"}}
],
"generate_memories_examples": [
{
"conversation_source": {
"events": [
{"content": {"role": "user", "parts": [{"text": "Hello"}]}}
]
},
"generatedMemories": [
{
"fact": "I like to say hello.",
"topics": [{"managed_memory_topic": "USER_PREFERENCES"}],
}
],
}
],
}
memory_bank_customization_config = types.MemoryBankCustomizationConfig(
**customization_config
)

agent_engine = client.agent_engines.create(
config={
Expand All @@ -64,6 +88,7 @@ def test_create_with_context_spec(client):
"embedding_model": embedding_model,
},
"ttl_config": {"default_ttl": "120s"},
"customization_configs": [memory_bank_customization_config],
},
},
"http_options": {"api_version": "v1beta1"},
Expand All @@ -76,6 +101,9 @@ def test_create_with_context_spec(client):
memory_bank_config.similarity_search_config.embedding_model == embedding_model
)
assert memory_bank_config.ttl_config.default_ttl == "120s"
assert memory_bank_config.customization_configs == [
memory_bank_customization_config
]
# Clean up resources.
client.agent_engines.delete(name=agent_engine.api_resource.name, force=True)

Expand Down
10 changes: 10 additions & 0 deletions vertexai/_genai/memories.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,11 @@ def _AgentEngineMemoryConfig_to_vertex(
getv(from_object, ["disable_memory_revisions"]),
)

if getv(from_object, ["topics"]) is not None:
setv(
parent_object, ["topics"], [item for item in getv(from_object, ["topics"])]
)

return to_object


Expand Down Expand Up @@ -365,6 +370,11 @@ def _UpdateAgentEngineMemoryConfig_to_vertex(
getv(from_object, ["disable_memory_revisions"]),
)

if getv(from_object, ["topics"]) is not None:
setv(
parent_object, ["topics"], [item for item in getv(from_object, ["topics"])]
)

if getv(from_object, ["update_mask"]) is not None:
setv(
parent_object, ["_query", "updateMask"], getv(from_object, ["update_mask"])
Expand Down
56 changes: 56 additions & 0 deletions vertexai/_genai/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -5330,6 +5330,30 @@ class MemoryBankCustomizationConfigGenerateMemoriesExampleConversationSourceDict
]


class MemoryTopicId(_common.BaseModel):
"""The topic ID for a memory."""

custom_memory_topic_label: Optional[str] = Field(
default=None, description="""Optional. The custom memory topic label."""
)
managed_memory_topic: Optional[ManagedTopicEnum] = Field(
default=None, description="""Optional. The managed memory topic."""
)


class MemoryTopicIdDict(TypedDict, total=False):
"""The topic ID for a memory."""

custom_memory_topic_label: Optional[str]
"""Optional. The custom memory topic label."""

managed_memory_topic: Optional[ManagedTopicEnum]
"""Optional. The managed memory topic."""


MemoryTopicIdOrDict = Union[MemoryTopicId, MemoryTopicIdDict]


class MemoryBankCustomizationConfigGenerateMemoriesExampleGeneratedMemory(
_common.BaseModel
):
Expand All @@ -5338,6 +5362,10 @@ class MemoryBankCustomizationConfigGenerateMemoriesExampleGeneratedMemory(
fact: Optional[str] = Field(
default=None, description="""Required. The fact to generate a memory from."""
)
topics: Optional[list[MemoryTopicId]] = Field(
default=None,
description="""Optional. The list of topics that the memory should be associated with. For example, use `custom_memory_topic_label = "jargon"` if the extracted memory is an example of memory extraction for the custom topic `jargon`.""",
)


class MemoryBankCustomizationConfigGenerateMemoriesExampleGeneratedMemoryDict(
Expand All @@ -5348,6 +5376,9 @@ class MemoryBankCustomizationConfigGenerateMemoriesExampleGeneratedMemoryDict(
fact: Optional[str]
"""Required. The fact to generate a memory from."""

topics: Optional[list[MemoryTopicIdDict]]
"""Optional. The list of topics that the memory should be associated with. For example, use `custom_memory_topic_label = "jargon"` if the extracted memory is an example of memory extraction for the custom topic `jargon`."""


MemoryBankCustomizationConfigGenerateMemoriesExampleGeneratedMemoryOrDict = Union[
MemoryBankCustomizationConfigGenerateMemoriesExampleGeneratedMemory,
Expand Down Expand Up @@ -6432,6 +6463,9 @@ class AgentEngineMemoryConfig(_common.BaseModel):
default=None,
description="""Optional. Input only. If true, no revision will be created for this request.""",
)
topics: Optional[list[MemoryTopicId]] = Field(
default=None, description="""Optional. The topics of the memory."""
)


class AgentEngineMemoryConfigDict(TypedDict, total=False):
Expand Down Expand Up @@ -6466,6 +6500,9 @@ class AgentEngineMemoryConfigDict(TypedDict, total=False):
disable_memory_revisions: Optional[bool]
"""Optional. Input only. If true, no revision will be created for this request."""

topics: Optional[list[MemoryTopicIdDict]]
"""Optional. The topics of the memory."""


AgentEngineMemoryConfigOrDict = Union[
AgentEngineMemoryConfig, AgentEngineMemoryConfigDict
Expand Down Expand Up @@ -6573,6 +6610,9 @@ class Memory(_common.BaseModel):
default=None,
description="""Output only. Timestamp when this Memory was most recently updated.""",
)
topics: Optional[list[MemoryTopicId]] = Field(
default=None, description="""Optional. The Topics of the Memory."""
)


class MemoryDict(TypedDict, total=False):
Expand Down Expand Up @@ -6614,6 +6654,9 @@ class MemoryDict(TypedDict, total=False):
update_time: Optional[datetime.datetime]
"""Output only. Timestamp when this Memory was most recently updated."""

topics: Optional[list[MemoryTopicIdDict]]
"""Optional. The Topics of the Memory."""


MemoryOrDict = Union[Memory, MemoryDict]

Expand Down Expand Up @@ -6840,6 +6883,10 @@ class GenerateMemoriesRequestDirectMemoriesSourceDirectMemory(_common.BaseModel)
default=None,
description="""Required. The fact to consolidate with existing memories.""",
)
topics: Optional[list[MemoryTopicId]] = Field(
default=None,
description="""Optional. The topics that the consolidated memories should be associated with.""",
)


class GenerateMemoriesRequestDirectMemoriesSourceDirectMemoryDict(
Expand All @@ -6850,6 +6897,9 @@ class GenerateMemoriesRequestDirectMemoriesSourceDirectMemoryDict(
fact: Optional[str]
"""Required. The fact to consolidate with existing memories."""

topics: Optional[list[MemoryTopicIdDict]]
"""Optional. The topics that the consolidated memories should be associated with."""


GenerateMemoriesRequestDirectMemoriesSourceDirectMemoryOrDict = Union[
GenerateMemoriesRequestDirectMemoriesSourceDirectMemory,
Expand Down Expand Up @@ -7683,6 +7733,9 @@ class UpdateAgentEngineMemoryConfig(_common.BaseModel):
default=None,
description="""Optional. Input only. If true, no revision will be created for this request.""",
)
topics: Optional[list[MemoryTopicId]] = Field(
default=None, description="""Optional. The topics of the memory."""
)
update_mask: Optional[str] = Field(
default=None,
description="""The update mask to apply. For the `FieldMask` definition, see
Expand Down Expand Up @@ -7722,6 +7775,9 @@ class UpdateAgentEngineMemoryConfigDict(TypedDict, total=False):
disable_memory_revisions: Optional[bool]
"""Optional. Input only. If true, no revision will be created for this request."""

topics: Optional[list[MemoryTopicIdDict]]
"""Optional. The topics of the memory."""

update_mask: Optional[str]
"""The update mask to apply. For the `FieldMask` definition, see
https://protobuf.dev/reference/protobuf/google.protobuf/#field-mask."""
Expand Down
Loading