Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 43 additions & 41 deletions tests/unit/vertexai/genai/replays/test_create_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,15 +145,19 @@ def test_create(client):
)
assert isinstance(prompt_resource, types.Prompt)
assert isinstance(prompt_resource.dataset, types.Dataset)
assert isinstance(prompt_resource.dataset_version, types.DatasetVersion)
assert prompt_resource.version_id == "1"
assert (
prompt_resource.dataset_version.metadata.prompt_api_schema.multimodal_prompt
== prompt_resource.dataset.metadata.prompt_api_schema.multimodal_prompt
)


def test_create_e2e(client):
prompt_resource = client.prompts.create(
prompt=TEST_PROMPT,
config=TEST_CREATE_PROMPT_CONFIG,
)
assert isinstance(prompt_resource, types.Prompt)
assert isinstance(prompt_resource.dataset, types.Dataset)

# Test local prompt resource is the same after calling get()
retrieved_prompt = client.prompts.get(prompt_id=prompt_resource.prompt_id)
Expand Down Expand Up @@ -186,28 +190,31 @@ def test_create_e2e(client):
# Test calling create_version on the same prompt dataset and change the prompt
new_prompt = TEST_PROMPT.model_copy(deep=True)
new_prompt.prompt_data.contents[0].parts[0].text = "Is this Alice?"
prompt_resource_2 = client.prompts.create_version(
updated_prompt = client.prompts.update(
prompt_id=prompt_resource.prompt_id,
prompt=new_prompt,
config=types.CreatePromptVersionConfig(
version_display_name="my_version",
config=types.UpdatePromptConfig(
prompt_display_name="updated_prompt_display_name",
version_display_name="my_version_2",
),
)
assert prompt_resource_2.dataset.name == prompt_resource.dataset.name
assert prompt_resource_2.prompt_data.contents[0].parts[0].text == "Is this Alice?"
assert updated_prompt.dataset.display_name == "updated_prompt_display_name"
assert updated_prompt.dataset_version.display_name == "my_version_2"
assert updated_prompt.version_id == "2"
assert updated_prompt.prompt_data.contents[0].parts[0].text == "Is this Alice?"

# Update the prompt contents again and verify version history is preserved
prompt_v3 = TEST_PROMPT.model_copy(deep=True)
prompt_v3.prompt_data.contents[0].parts[0].text = "Is this Bob?"
prompt_resource_3 = client.prompts.create_version(
# Tests that assemble_contents() works on a prompt without variables.
assert updated_prompt.assemble_contents()[0].role == "user"

# Calling get_version on version "1" should return the original prompt contents
original_prompt = client.prompts.get_version(
prompt_id=prompt_resource.prompt_id,
prompt=prompt_v3,
config=types.CreatePromptVersionConfig(
version_display_name="my_version_2",
),
version_id="1",
)
assert (
original_prompt.prompt_data.contents[0].parts[0].text
== "Hello, {name}! How are you?"
)
assert prompt_resource_3.dataset.name == prompt_resource.dataset.name
assert prompt_resource_3.prompt_data.contents[0].parts[0].text == "Is this Bob?"


def test_create_version(client):
Expand Down Expand Up @@ -296,6 +303,7 @@ def test_create_with_encryption_spec(client):
config = types.CreatePromptConfig(
prompt_display_name="my_prompt_with_encryption_spec",
encryption_spec=encryption_spec,
version_display_name="my_version_with_encryption_spec",
)
prompt_resource = client.prompts.create(
prompt=TEST_PROMPT,
Expand All @@ -304,19 +312,19 @@ def test_create_with_encryption_spec(client):
assert isinstance(prompt_resource, types.Prompt)
assert isinstance(prompt_resource.dataset, types.Dataset)

# Create a version on a prompt with an encryption spec.
# Update a prompt with an encryption spec.
new_prompt = TEST_PROMPT.model_copy(deep=True)
new_prompt.prompt_data.contents[0].parts[0].text = "Is this Alice?"
prompt_version_resource = client.prompts.create_version(
updated_prompt_resource = client.prompts.update(
prompt_id=prompt_resource.prompt_id,
prompt=new_prompt,
config=types.CreatePromptVersionConfig(
config=types.UpdatePromptConfig(
version_display_name="my_version_existing_dataset",
),
)
assert isinstance(prompt_version_resource, types.Prompt)
assert isinstance(prompt_version_resource.dataset, types.Dataset)
assert isinstance(prompt_version_resource.dataset_version, types.DatasetVersion)
assert isinstance(updated_prompt_resource, types.Prompt)
assert isinstance(updated_prompt_resource.dataset, types.Dataset)
assert isinstance(updated_prompt_resource.dataset_version, types.DatasetVersion)


pytestmark = pytest_helper.setup(
Expand All @@ -329,35 +337,29 @@ def test_create_with_encryption_spec(client):


@pytest.mark.asyncio
async def test_create_async(client):
async def test_create_async_e2e(client):
prompt_resource = await client.aio.prompts.create(
prompt=TEST_PROMPT.model_dump(),
config=TEST_CREATE_PROMPT_CONFIG.model_dump(),
)
assert isinstance(prompt_resource, types.Prompt)
assert isinstance(prompt_resource.dataset, types.Dataset)


@pytest.mark.asyncio
async def test_create_version_async(client):
prompt_resource = await client.aio.prompts.create(
prompt=TEST_PROMPT.model_dump(),
config=TEST_CREATE_PROMPT_CONFIG.model_dump(),
assert isinstance(prompt_resource.dataset_version, types.DatasetVersion)
assert prompt_resource.version_id == "1"
assert (
prompt_resource.dataset.metadata.prompt_api_schema.multimodal_prompt.prompt_message
== prompt_resource.dataset_version.metadata.prompt_api_schema.multimodal_prompt.prompt_message
)

new_prompt = TEST_PROMPT.model_copy(deep=True)
new_prompt.prompt_data.contents[0].parts[0].text = "Is this Alice?"
prompt_version_resource = await client.aio.prompts.create_version(
updated_prompt_resource = await client.aio.prompts.update(
prompt_id=prompt_resource.prompt_id,
prompt=new_prompt,
config=types.CreatePromptVersionConfig(
config=types.UpdatePromptConfig(
version_display_name="my_version_existing_dataset",
),
)
assert isinstance(prompt_version_resource, types.Prompt)
assert isinstance(prompt_version_resource.dataset, types.Dataset)
assert isinstance(prompt_version_resource.dataset_version, types.DatasetVersion)
assert prompt_version_resource.dataset.name.endswith(prompt_resource.prompt_id)
assert (
prompt_version_resource.prompt_data.contents[0].parts[0].text
== "Is this Alice?"
)
assert isinstance(updated_prompt_resource, types.Prompt)
assert isinstance(updated_prompt_resource.dataset, types.Dataset)
assert isinstance(updated_prompt_resource.dataset_version, types.DatasetVersion)
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,7 @@ def test_get_prompt_version(client):
version_id=TEST_PROMPT_VERSION_ID,
)
assert isinstance(prompt, types.Prompt)
assert isinstance(prompt.dataset, types.Dataset)
assert isinstance(prompt.dataset_version, types.DatasetVersion)
assert prompt.dataset.name.endswith(TEST_PROMPT_DATASET_ID)
assert prompt.dataset_version.name.endswith(TEST_PROMPT_VERSION_ID)


Expand Down Expand Up @@ -100,10 +98,8 @@ async def test_get_prompt_version_async(client):
prompt_id=TEST_PROMPT_DATASET_ID, version_id=TEST_PROMPT_VERSION_ID
)
assert isinstance(prompt, types.Prompt)
assert isinstance(prompt.dataset, types.Dataset)
assert prompt.dataset.name.endswith(TEST_PROMPT_DATASET_ID)
assert (
prompt.prompt_data
== prompt.dataset.metadata.prompt_api_schema.multimodal_prompt.prompt_message
== prompt.dataset_version.metadata.prompt_api_schema.multimodal_prompt.prompt_message
)
assert isinstance(prompt.prompt_data, types.SchemaPromptSpecPromptMessage)
Loading
Loading