Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,16 @@

def test_delete_memory(client):
agent_engine = client.agent_engines.create()
operation = client.agent_engines.create_memory(
operation = client.agent_engines.memories.create(
name=agent_engine.api_resource.name,
fact="memory_fact",
scope={"user_id": "123"},
)
memory = operation.response
operation = client.agent_engines.delete_memory(name=memory.name)
operation = client.agent_engines.memories.delete(name=memory.name)
assert isinstance(operation, types.DeleteAgentEngineMemoryOperation)
assert operation.name.startswith(memory.name + "/operations/")
client.agent_engines.delete(name=agent_engine.api_resource.name, force=True)


pytestmark = pytest_helper.setup(
Expand All @@ -46,14 +47,16 @@ def test_delete_memory(client):

@pytest.mark.asyncio
async def test_delete_memory_async(client):
# TODO(b/431785750): use async methods for create() and create_memory() when available
agent_engine = client.agent_engines.create()
operation = client.agent_engines.create_memory(
operation = await client.aio.agent_engines.memories.create(
name=agent_engine.api_resource.name,
fact="memory_fact",
scope={"user_id": "123"},
)
memory = operation.response
operation = await client.aio.agent_engines.delete_memory(name=memory.name)
operation = await client.aio.agent_engines.memories.delete(name=memory.name)
assert isinstance(operation, types.DeleteAgentEngineMemoryOperation)
assert operation.name.startswith(memory.name + "/operations/")
await client.aio.agent_engines.delete(
name=agent_engine.api_resource.name, force=True
)
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,16 @@
#
# pylint: disable=protected-access,bad-continuation,missing-function-docstring

import pytest


from tests.unit.vertexai.genai.replays import pytest_helper
from vertexai._genai import types
from google.genai import types as genai_types


def test_generate_and_rollback_memories(client):
# TODO(): Use prod endpoint once experiment is fully rolled out.
client._api_client._http_options.base_url = (
"https://us-central1-autopush-aiplatform.sandbox.googleapis.com/"
)
Expand Down Expand Up @@ -146,3 +150,64 @@ def test_generate_memories_direct_memories_source(client):
globals_for_file=globals(),
test_method="agent_engines.generate_memories",
)


pytest_plugins = ("pytest_asyncio",)


@pytest.mark.asyncio
async def test_generate_and_rollback_memories_async(client):
# TODO(): Use prod endpoint once revisions experiment is fully rolled out.
client._api_client._http_options.base_url = (
"https://us-central1-autopush-aiplatform.sandbox.googleapis.com/"
)
agent_engine = client.agent_engines.create()
await client.aio.agent_engines.memories.generate(
name=agent_engine.api_resource.name,
scope={"user_id": "test-user-id"},
direct_memories_source=types.GenerateMemoriesRequestDirectMemoriesSource(
direct_memories=[
types.GenerateMemoriesRequestDirectMemoriesSourceDirectMemory(
fact="I am a software engineer."
),
types.GenerateMemoriesRequestDirectMemoriesSourceDirectMemory(
fact="I like to write replay tests."
),
]
),
config=types.GenerateAgentEngineMemoriesConfig(wait_for_completion=True),
)
memories_pager = await client.aio.agent_engines.memories.list(
name=agent_engine.api_resource.name
)
memory_list = [item async for item in memories_pager]
assert len(memory_list) >= 1

revisions_pager = await client.aio.agent_engines.memories.revisions.list(
name=memory_list[0].name
)
memory_revisions = [item async for item in revisions_pager]
assert len(memory_revisions) >= 1
revision_name = memory_revisions[0].name

# Update the memory.
client.agent_engines.memories._update(
name=memory_list[0].name,
fact="This is temporary",
scope={"user_id": "test-user-id"},
)
memory = await client.aio.agent_engines.memories.get(name=memory_list[0].name)
assert memory.fact == "This is temporary"

# Rollback to the revision with the original fact that was created by the
# generation request.
await client.aio.agent_engines.memories.rollback(
name=memory_list[0].name,
target_revision_id=revision_name.split("/")[-1],
)
memory = await client.aio.agent_engines.memories.get(name=memory_list[0].name)
assert memory.fact == memory_revisions[0].fact

await client.aio.agent_engines.delete(
name=agent_engine.api_resource.name, force=True
)
Original file line number Diff line number Diff line change
Expand Up @@ -22,23 +22,24 @@

def test_get_memory(client):
agent_engine = client.agent_engines.create()
operation = client.agent_engines.create_memory(
operation = client.agent_engines.memories.create(
name=agent_engine.api_resource.name,
fact="memory_fact",
scope={"user_id": "123"},
)
assert isinstance(operation, types.AgentEngineMemoryOperation)
memory = client.agent_engines.get_memory(
memory = client.agent_engines.memories.get(
name=operation.response.name,
)
assert isinstance(memory, types.Memory)
assert memory.name == operation.response.name
client.agent_engines.delete(name=agent_engine.api_resource.name, force=True)


pytestmark = pytest_helper.setup(
file=__file__,
globals_for_file=globals(),
test_method="agent_engines.get_memory",
test_method="agent_engines.memories.get",
)


Expand All @@ -47,16 +48,18 @@ def test_get_memory(client):

@pytest.mark.asyncio
async def test_get_memory_async(client):
# TODO(b/431785750): use async methods for create() and create_memory() when available
agent_engine = client.agent_engines.create()
operation = client.agent_engines.create_memory(
operation = await client.aio.agent_engines.memories.create(
name=agent_engine.api_resource.name,
fact="memory_fact",
scope={"user_id": "123"},
)
assert isinstance(operation, types.AgentEngineMemoryOperation)
memory = await client.aio.agent_engines.get_memory(
memory = await client.aio.agent_engines.memories.get(
name=operation.response.name,
)
assert isinstance(memory, types.Memory)
assert memory.name == operation.response.name
await client.aio.agent_engines.delete(
name=agent_engine.api_resource.name, force=True
)
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
#
# pylint: disable=protected-access,bad-continuation,missing-function-docstring

import pytest

from tests.unit.vertexai.genai.replays import pytest_helper
from vertexai._genai import types

Expand Down Expand Up @@ -59,5 +61,36 @@ def test_list_memories(client):
pytestmark = pytest_helper.setup(
file=__file__,
globals_for_file=globals(),
test_method="agent_engines.list_memories",
test_method="agent_engines.memories.list",
)


pytest_plugins = ("pytest_asyncio",)


@pytest.mark.asyncio
async def test_async_list_memories(client):
agent_engine = client.agent_engines.create()
pager = await client.aio.agent_engines.memories.list(
name=agent_engine.api_resource.name
)
assert not [item async for item in pager]

await client.aio.agent_engines.memories.create(
name=agent_engine.api_resource.name,
fact="memory_fact_2",
scope={"user_id": "456"},
config={
"wait_for_completion": True,
},
)
pager = await client.aio.agent_engines.memories.list(
name=agent_engine.api_resource.name
)
memory_list = [item async for item in pager]
assert len(memory_list) == 1
assert isinstance(memory_list[0], types.Memory)

await client.aio.agent_engines.delete(
name=agent_engine.api_resource.name, force=True
)
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
#
# pylint: disable=protected-access,bad-continuation,missing-function-docstring

import pytest


from tests.unit.vertexai.genai.replays import pytest_helper
from vertexai._genai import types
from google.genai import pagers
Expand All @@ -22,23 +25,23 @@
def test_retrieve_memories_with_similarity_search_params(client):
agent_engine = client.agent_engines.create()
assert not list(
client.agent_engines.retrieve_memories(
client.agent_engines.memories.retrieve(
name=agent_engine.api_resource.name,
scope={"user_id": "123"},
similarity_search_params=types.RetrieveMemoriesRequestSimilaritySearchParams(
search_query="memory_fact_1",
),
)
)
client.agent_engines.create_memory(
client.agent_engines.memories.create(
name=agent_engine.api_resource.name,
fact="memory_fact_1",
scope={"user_id": "123"},
)
assert (
len(
list(
client.agent_engines.retrieve_memories(
client.agent_engines.memories.retrieve(
name=agent_engine.api_resource.name,
scope={"user_id": "123"},
)
Expand All @@ -47,20 +50,20 @@ def test_retrieve_memories_with_similarity_search_params(client):
== 1
)
assert not list(
client.agent_engines.retrieve_memories(
client.agent_engines.memories.retrieve(
name=agent_engine.api_resource.name,
scope={"user_id": "456"},
)
)
client.agent_engines.create_memory(
client.agent_engines.memories.create(
name=agent_engine.api_resource.name,
fact="memory_fact_2",
scope={"user_id": "123"},
)
assert (
len(
list(
client.agent_engines.retrieve_memories(
client.agent_engines.memories.retrieve(
name=agent_engine.api_resource.name,
scope={"user_id": "123"},
)
Expand All @@ -74,12 +77,12 @@ def test_retrieve_memories_with_similarity_search_params(client):

def test_retrieve_memories_with_simple_retrieval_params(client):
agent_engine = client.agent_engines.create()
client.agent_engines.create_memory(
client.agent_engines.memories.create(
name=agent_engine.api_resource.name,
fact="memory_fact_1",
scope={"user_id": "123"},
)
memories = client.agent_engines.retrieve_memories(
memories = client.agent_engines.memories.retrieve(
name=agent_engine.api_resource.name,
scope={"user_id": "123"},
simple_retrieval_params=types.RetrieveMemoriesRequestSimpleRetrievalParams(
Expand All @@ -98,3 +101,27 @@ def test_retrieve_memories_with_simple_retrieval_params(client):
globals_for_file=globals(),
test_method="agent_engines.create_memory",
)


pytest_plugins = ("pytest_asyncio",)


@pytest.mark.asyncio
async def test_retrieve_memories_async(client):
agent_engine = client.agent_engines.create()
operation = await client.aio.agent_engines.memories.create(
name=agent_engine.api_resource.name,
fact="memory_fact",
scope={"user_id": "123"},
)
assert isinstance(operation, types.AgentEngineMemoryOperation)
pager = await client.aio.agent_engines.memories.retrieve(
name=agent_engine.api_resource.name,
scope={"user_id": "123"},
)
memories = [item async for item in pager]
assert len(memories) == 1
assert isinstance(memories[0], types.RetrieveMemoriesResponseRetrievedMemory)
await client.aio.agent_engines.delete(
name=agent_engine.api_resource.name, force=True
)
Loading
Loading