Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
#
# pylint: disable=protected-access,bad-continuation,missing-function-docstring

import logging
import os

from tests.unit.vertexai.genai.replays import pytest_helper
from vertexai._genai import types
from google.genai import types as genai_types
Expand All @@ -38,7 +38,7 @@ def test_optimize(client):

_raise_for_unset_env_vars()

config = types.PromptOptimizerVAPOConfig(
config = types.PromptOptimizerConfig(
config_path=os.environ.get("VAPO_CONFIG_PATH"),
wait_for_completion=True,
service_account_project_number=os.environ.get(
Expand All @@ -47,7 +47,33 @@ def test_optimize(client):
optimizer_job_display_name="optimizer_job_test",
)
job = client.prompt_optimizer.optimize(
method="vapo",
method=types.PromptOptimizerMethod.VAPO,
config=config,
)
assert isinstance(job, types.CustomJob)
assert job.state == genai_types.JobState.JOB_STATE_SUCCEEDED


def test_optimize_nano(client):
"""Tests the optimize request parameters method."""

_raise_for_unset_env_vars()

config_path = os.environ.get("VAPO_CONFIG_PATH")
root, ext = os.path.splitext(config_path)
nano_path = f"{root}_nano{ext}"

config = types.PromptOptimizerConfig(
config_path=nano_path,
wait_for_completion=True,
service_account_project_number=os.environ.get(
"VAPO_SERVICE_ACCOUNT_PROJECT_NUMBER"
),
optimizer_job_display_name="optimizer_job_test",
)

job = client.prompt_optimizer.optimize(
method=types.PromptOptimizerMethod.OPTIMIZATION_TARGET_GEMINI_NANO,
config=config,
)
assert isinstance(job, types.CustomJob)
Expand All @@ -68,15 +94,37 @@ def test_optimize(client):
async def test_optimize_async(client):
_raise_for_unset_env_vars()

config = types.PromptOptimizerVAPOConfig(
config = types.PromptOptimizerConfig(
config_path=os.environ.get("VAPO_CONFIG_PATH"),
service_account_project_number=os.environ.get(
"VAPO_SERVICE_ACCOUNT_PROJECT_NUMBER"
),
optimizer_job_display_name="optimizer_job_test",
)
job = await client.aio.prompt_optimizer.optimize(
method="vapo",
method=types.PromptOptimizerMethod.VAPO,
config=config,
)
assert isinstance(job, types.CustomJob)
assert job.state == genai_types.JobState.JOB_STATE_PENDING


@pytest.mark.asyncio
async def test_optimize_nano_async(client):
_raise_for_unset_env_vars()
config_path = os.environ.get("VAPO_CONFIG_PATH")
root, ext = os.path.splitext(config_path)
nano_path = f"{root}_nano{ext}"

config = types.PromptOptimizerConfig(
config_path=nano_path,
service_account_project_number=os.environ.get(
"VAPO_SERVICE_ACCOUNT_PROJECT_NUMBER"
),
optimizer_job_display_name="optimizer_job_test",
)
job = await client.aio.prompt_optimizer.optimize(
method=types.PromptOptimizerMethod.OPTIMIZATION_TARGET_GEMINI_NANO,
config=config,
)
assert isinstance(job, types.CustomJob)
Expand All @@ -86,8 +134,9 @@ async def test_optimize_async(client):
@pytest.mark.asyncio
async def test_optimize_async_with_config_wait_for_completion(client, caplog):
_raise_for_unset_env_vars()
caplog.set_level(logging.INFO)

config = types.PromptOptimizerVAPOConfig(
config = types.PromptOptimizerConfig(
config_path=os.environ.get("VAPO_CONFIG_PATH"),
service_account_project_number=os.environ.get(
"VAPO_SERVICE_ACCOUNT_PROJECT_NUMBER"
Expand All @@ -96,7 +145,7 @@ async def test_optimize_async_with_config_wait_for_completion(client, caplog):
wait_for_completion=True,
)
job = await client.aio.prompt_optimizer.optimize(
method="vapo",
method=types.PromptOptimizerMethod.VAPO,
config=config,
)
assert isinstance(job, types.CustomJob)
Expand Down
20 changes: 18 additions & 2 deletions tests/unit/vertexai/genai/test_prompt_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,24 @@ def test_prompt_optimizer_optimize(self, mock_custom_job, mock_client):
"""Test that prompt_optimizer.optimize method creates a custom job."""
test_client = vertexai.Client(project=_TEST_PROJECT, location=_TEST_LOCATION)
test_client.prompt_optimizer.optimize(
method="vapo",
config=types.PromptOptimizerVAPOConfig(
method=types.PromptOptimizerMethod.VAPO,
config=types.PromptOptimizerConfig(
config_path="gs://ssusie-vapo-sdk-test/config.json",
wait_for_completion=False,
service_account="test-service-account",
),
)
mock_client.assert_called_once()
mock_custom_job.assert_called_once()

@mock.patch.object(client.Client, "_get_api_client")
@mock.patch.object(prompt_optimizer.PromptOptimizer, "_create_custom_job_resource")
def test_prompt_optimizer_optimize_nano(self, mock_custom_job, mock_client):
"""Test that prompt_optimizer.optimize method creates a custom job."""
test_client = vertexai.Client(project=_TEST_PROJECT, location=_TEST_LOCATION)
test_client.prompt_optimizer.optimize(
method=types.PromptOptimizerMethod.OPTIMIZATION_TARGET_GEMINI_NANO,
config=types.PromptOptimizerConfig(
config_path="gs://ssusie-vapo-sdk-test/config.json",
wait_for_completion=False,
service_account="test-service-account",
Expand Down
4 changes: 2 additions & 2 deletions vertexai/_genai/_prompt_optimizer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@


def _get_service_account(
config: types.PromptOptimizerVAPOConfigOrDict,
config: types.PromptOptimizerConfigOrDict,
) -> str:
"""Get the service account from the config for the custom job."""
if isinstance(config, dict):
config = types.PromptOptimizerVAPOConfig.model_validate(config)
config = types.PromptOptimizerConfig.model_validate(config)

if config.service_account and config.service_account_project_number:
raise ValueError(
Expand Down
72 changes: 45 additions & 27 deletions vertexai/_genai/prompt_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,37 +407,46 @@ def _wait_for_completion(self, job_name: str) -> types.CustomJob:

def optimize(
self,
method: str,
config: types.PromptOptimizerVAPOConfigOrDict,
method: types.PromptOptimizerMethod,
config: types.PromptOptimizerConfigOrDict,
) -> types.CustomJob:
"""Call PO-Data optimizer.

Args:
method: The method for optimizing multiple prompts.
config: PromptOptimizerVAPOConfig instance containing the
method: The method for optimizing multiple prompts. Supported methods:
VAPO, OPTIMIZATION_TARGET_GEMINI_NANO.
config: PromptOptimizerConfig instance containing the
configuration for prompt optimization.
Returns:
The custom job that was created.
"""

if method != "vapo":
raise ValueError("Only vapo method is currently supported.")

if isinstance(config, dict):
config = types.PromptOptimizerVAPOConfig(**config)
config = types.PromptOptimizerConfig(**config)

if not config.config_path:
raise ValueError("Config path is required.")

_OPTIMIZER_METHOD_TO_CONTAINER_URI = {
types.PromptOptimizerMethod.VAPO: "us-docker.pkg.dev/vertex-ai/cair/vaipo:preview_v1_0",
types.PromptOptimizerMethod.OPTIMIZATION_TARGET_GEMINI_NANO: "us-docker.pkg.dev/vertex-ai/cair/android-apo:preview_v1_0",
}
container_uri = _OPTIMIZER_METHOD_TO_CONTAINER_URI.get(method)
if not container_uri:
raise ValueError(
'Only "VAPO" and "OPTIMIZATION_TARGET_GEMINI_NANO" '
"methods are currently supported."
)

if config.optimizer_job_display_name:
display_name = config.optimizer_job_display_name
else:
timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
display_name = f"vapo-optimizer-{timestamp}"
display_name = f"{method.value.lower()}-optimizer-{timestamp}"

wait_for_completion = config.wait_for_completion
if not config.config_path:
raise ValueError("Config path is required.")
bucket = "/".join(config.config_path.split("/")[:-1])

container_uri = "us-docker.pkg.dev/vertex-ai/cair/vaipo:preview_v1_0"

region = self._api_client.location
project = self._api_client.project
container_args = {
Expand Down Expand Up @@ -766,8 +775,8 @@ async def _get_custom_job(
# Todo: b/428953357 - Add example in the README.
async def optimize(
self,
method: str,
config: types.PromptOptimizerVAPOConfigOrDict,
method: types.PromptOptimizerMethod,
config: types.PromptOptimizerConfigOrDict,
) -> types.CustomJob:
"""Call async Vertex AI Prompt Optimizer (VAPO).

Expand All @@ -777,26 +786,37 @@ async def optimize(

Example usage:
client = vertexai.Client(project=PROJECT_NAME, location='us-central1')
vapo_config = vertexai.types.PromptOptimizerVAPOConfig(
vapo_config = vertexai.types.PromptOptimizerConfig(
config_path='gs://you-bucket-name/your-config.json',
service_account=service_account,
)
job = await client.aio.prompt_optimizer.optimize(
method='vapo', config=vapo_config)
method=types.PromptOptimizerMethod.VAPO, config=vapo_config)

Args:
method: The method for optimizing multiple prompts (currently only
vapo is supported).
config: PromptOptimizerVAPOConfig instance containing the
method: The method for optimizing multiple prompts. Supported methods:
VAPO, OPTIMIZATION_TARGET_GEMINI_NANO.
config: PromptOptimizerConfig instance containing the
configuration for prompt optimization.
Returns:
The custom job that was created.
"""
if method != "vapo":
raise ValueError("Only vapo methods is currently supported.")

if isinstance(config, dict):
config = types.PromptOptimizerVAPOConfig(**config)
config = types.PromptOptimizerConfig(**config)

if not config.config_path:
raise ValueError("Config path is required.")

_OPTIMIZER_METHOD_TO_CONTAINER_URI = {
types.PromptOptimizerMethod.VAPO: "us-docker.pkg.dev/vertex-ai/cair/vaipo:preview_v1_0",
types.PromptOptimizerMethod.OPTIMIZATION_TARGET_GEMINI_NANO: "us-docker.pkg.dev/vertex-ai/cair/android-apo:preview_v1_0",
}
container_uri = _OPTIMIZER_METHOD_TO_CONTAINER_URI.get(method)
if not container_uri:
raise ValueError(
'Only "VAPO" and "OPTIMIZATION_TARGET_GEMINI_NANO" '
"methods are currently supported."
)

if config.wait_for_completion:
logger.info(
Expand All @@ -807,14 +827,12 @@ async def optimize(
display_name = config.optimizer_job_display_name
else:
timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
display_name = f"vapo-optimizer-{timestamp}"
display_name = f"{method.value.lower()}-optimizer-{timestamp}"

if not config.config_path:
raise ValueError("Config path is required.")
bucket = "/".join(config.config_path.split("/")[:-1])

container_uri = "us-docker.pkg.dev/vertex-ai/cair/vaipo:preview_v1_0"

region = self._api_client.location
project = self._api_client.project
container_args = {
Expand Down
20 changes: 14 additions & 6 deletions vertexai/_genai/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,6 +587,9 @@
from .common import OptimizeResponseEndpointDict
from .common import OptimizeResponseEndpointOrDict
from .common import OptimizeResponseOrDict
from .common import OptimizerMethodPlaceholder
from .common import OptimizerMethodPlaceholderDict
from .common import OptimizerMethodPlaceholderOrDict
from .common import OptimizeTarget
from .common import PairwiseChoice
from .common import PairwiseMetricInput
Expand Down Expand Up @@ -618,9 +621,10 @@
from .common import PromptDataDict
from .common import PromptDataOrDict
from .common import PromptDict
from .common import PromptOptimizerVAPOConfig
from .common import PromptOptimizerVAPOConfigDict
from .common import PromptOptimizerVAPOConfigOrDict
from .common import PromptOptimizerConfig
from .common import PromptOptimizerConfigDict
from .common import PromptOptimizerConfigOrDict
from .common import PromptOptimizerMethod
from .common import PromptOrDict
from .common import PromptRef
from .common import PromptRefDict
Expand Down Expand Up @@ -1739,9 +1743,12 @@
"UpdateDatasetConfig",
"UpdateDatasetConfigDict",
"UpdateDatasetConfigOrDict",
"PromptOptimizerVAPOConfig",
"PromptOptimizerVAPOConfigDict",
"PromptOptimizerVAPOConfigOrDict",
"PromptOptimizerConfig",
"PromptOptimizerConfigDict",
"PromptOptimizerConfigOrDict",
"OptimizerMethodPlaceholder",
"OptimizerMethodPlaceholderDict",
"OptimizerMethodPlaceholderOrDict",
"ApplicableGuideline",
"ApplicableGuidelineDict",
"ApplicableGuidelineOrDict",
Expand Down Expand Up @@ -1837,6 +1844,7 @@
"Importance",
"OptimizeTarget",
"GenerateMemoriesResponseGeneratedMemoryAction",
"PromptOptimizerMethod",
"PromptData",
"PromptDataDict",
"PromptDataOrDict",
Expand Down
35 changes: 31 additions & 4 deletions vertexai/_genai/types/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,15 @@ class GenerateMemoriesResponseGeneratedMemoryAction(_common.CaseInSensitiveEnum)
"""The memory was deleted."""


class PromptOptimizerMethod(_common.CaseInSensitiveEnum):
    """The method for data driven prompt optimization.

    Passed as the ``method`` argument to ``prompt_optimizer.optimize`` to
    select which optimizer container runs the job.
    """

    VAPO = "VAPO"
    """The default data driven Vertex AI Prompt Optimizer."""
    OPTIMIZATION_TARGET_GEMINI_NANO = "OPTIMIZATION_TARGET_GEMINI_NANO"
    """The data driven prompt optimizer designed for prompts from the Android core API."""


class CreateEvaluationItemConfig(_common.BaseModel):
"""Config to create an evaluation item."""

Expand Down Expand Up @@ -12025,7 +12034,7 @@ class _UpdateDatasetParametersDict(TypedDict, total=False):
]


class PromptOptimizerVAPOConfig(_common.BaseModel):
class PromptOptimizerConfig(_common.BaseModel):
"""VAPO Prompt Optimizer Config."""

config_path: Optional[str] = Field(
Expand All @@ -12050,7 +12059,7 @@ class PromptOptimizerVAPOConfig(_common.BaseModel):
)


class PromptOptimizerVAPOConfigDict(TypedDict, total=False):
class PromptOptimizerConfigDict(TypedDict, total=False):
"""VAPO Prompt Optimizer Config."""

config_path: Optional[str]
Expand All @@ -12069,8 +12078,26 @@ class PromptOptimizerVAPOConfigDict(TypedDict, total=False):
"""The display name of the optimization job. If not provided, a display name in the format of "vapo-optimizer-{timestamp}" will be used."""


PromptOptimizerVAPOConfigOrDict = Union[
PromptOptimizerVAPOConfig, PromptOptimizerVAPOConfigDict
PromptOptimizerConfigOrDict = Union[PromptOptimizerConfig, PromptOptimizerConfigDict]


class OptimizerMethodPlaceholder(_common.BaseModel):
    """Placeholder class to generate OptimizerMethod enum in common.py.

    Not intended for direct runtime use; per the summary above it exists so
    the ``PromptOptimizerMethod`` enum is emitted into this module.
    """

    # The selected optimization method (e.g. VAPO or
    # OPTIMIZATION_TARGET_GEMINI_NANO); None when unset.
    method: Optional[PromptOptimizerMethod] = Field(
        default=None, description="""The method for optimizing multiple prompts."""
    )


class OptimizerMethodPlaceholderDict(TypedDict, total=False):
    """Placeholder class to generate OptimizerMethod enum in common.py.

    Dict-typed counterpart of ``OptimizerMethodPlaceholder``; ``total=False``
    makes the key optional.
    """

    # The selected optimization method (e.g. VAPO or
    # OPTIMIZATION_TARGET_GEMINI_NANO).
    method: Optional[PromptOptimizerMethod]
    """The method for optimizing multiple prompts."""


# Accepts either the pydantic model or its TypedDict equivalent.
OptimizerMethodPlaceholderOrDict = Union[
    OptimizerMethodPlaceholder, OptimizerMethodPlaceholderDict
]


Expand Down
Loading