Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,20 @@ def test_optimize_prompt(client):
assert response.raw_text_response


# def test_optimize_prompt_w_optimization_target(client):
# """Tests the optimize request parameters method with optimization target."""
# from google.genai import types as genai_types
# test_prompt = "Generate system instructions for analyzing medical articles"
# response = client.prompt_optimizer.optimize_prompt(
# prompt=test_prompt,
# config=types.OptimizeConfig(
# optimization_target=types.OptimizeTarget.OPTIMIZATION_TARGET_GEMINI_NANO,
# ),
# )
# assert isinstance(response, types.OptimizeResponse)
# assert response.raw_text_response


pytestmark = pytest_helper.setup(
file=__file__,
globals_for_file=globals(),
Expand Down
51 changes: 49 additions & 2 deletions tests/unit/vertexai/genai/test_prompt_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,57 @@ def test_prompt_optimizer_optimize(self, mock_custom_job, mock_client):
def test_prompt_optimizer_optimize_prompt(
self, mock_custom_optimize_prompt, mock_client
):
"""Test that prompt_optimizer.optimize method creates a custom job."""
"""Test that prompt_optimizer.optimize_prompt method calls optimize_prompt API."""
test_client = vertexai.Client(project=_TEST_PROJECT, location=_TEST_LOCATION)
test_client.prompt_optimizer.optimize_prompt(prompt="test_prompt")
mock_client.assert_called_once()
mock_custom_optimize_prompt.assert_called_once()

# TODO(b/415060797): add more tests for prompt_optimizer.optimize
@mock.patch.object(prompt_optimizer.PromptOptimizer, "_custom_optimize_prompt")
def test_prompt_optimizer_optimize_prompt_with_optimization_target(
self, mock_custom_optimize_prompt
):
"""Test that prompt_optimizer.optimize_prompt method calls _custom_optimize_prompt with optimization_target."""
test_client = vertexai.Client(project=_TEST_PROJECT, location=_TEST_LOCATION)
config = types.OptimizeConfig(
optimization_target=types.OptimizeTarget.OPTIMIZATION_TARGET_GEMINI_NANO,
)
test_client.prompt_optimizer.optimize_prompt(
prompt="test_prompt",
config=config,
)
mock_custom_optimize_prompt.assert_called_once_with(
content=mock.ANY,
config=config,
)

@pytest.mark.asyncio
@mock.patch.object(prompt_optimizer.AsyncPromptOptimizer, "_custom_optimize_prompt")
async def test_async_prompt_optimizer_optimize_prompt(
self, mock_custom_optimize_prompt
):
"""Test that async prompt_optimizer.optimize_prompt method calls optimize_prompt API."""
test_client = vertexai.Client(project=_TEST_PROJECT, location=_TEST_LOCATION)
await test_client.aio.prompt_optimizer.optimize_prompt(prompt="test_prompt")
mock_custom_optimize_prompt.assert_called_once()

@pytest.mark.asyncio
@mock.patch.object(prompt_optimizer.AsyncPromptOptimizer, "_custom_optimize_prompt")
async def test_async_prompt_optimizer_optimize_prompt_with_optimization_target(
self, mock_custom_optimize_prompt
):
"""Test that async prompt_optimizer.optimize_prompt calls optimize_prompt with optimization_target."""
test_client = vertexai.Client(project=_TEST_PROJECT, location=_TEST_LOCATION)
config = types.OptimizeConfig(
optimization_target=types.OptimizeTarget.OPTIMIZATION_TARGET_GEMINI_NANO,
)
await test_client.aio.prompt_optimizer.optimize_prompt(
prompt="test_prompt",
config=config,
)
mock_custom_optimize_prompt.assert_called_once_with(
content=mock.ANY,
config=config,
)

# # TODO(b/415060797): add more tests for prompt_optimizer.optimize
66 changes: 47 additions & 19 deletions vertexai/_genai/prompt_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,22 @@ def _GetCustomJobParameters_to_vertex(
return to_object


def _OptimizeConfig_to_vertex(
from_object: Union[dict[str, Any], object],
parent_object: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
to_object: dict[str, Any] = {}

if getv(from_object, ["optimization_target"]) is not None:
setv(
parent_object,
["optimizationTarget"],
getv(from_object, ["optimization_target"]),
)

return to_object


def _OptimizeRequestParameters_to_vertex(
from_object: Union[dict[str, Any], object],
parent_object: Optional[dict[str, Any]] = None,
Expand All @@ -176,7 +192,11 @@ def _OptimizeRequestParameters_to_vertex(
setv(to_object, ["content"], getv(from_object, ["content"]))

if getv(from_object, ["config"]) is not None:
setv(to_object, ["config"], getv(from_object, ["config"]))
setv(
to_object,
["config"],
_OptimizeConfig_to_vertex(getv(from_object, ["config"]), to_object),
)

return to_object

Expand Down Expand Up @@ -468,7 +488,10 @@ def optimize(
return job

def optimize_prompt(
self, *, prompt: str, config: Optional[types.OptimizeConfig] = None
self,
*,
prompt: str,
config: Optional[types.OptimizeConfig] = None,
) -> types.OptimizeResponse:
"""Makes an API request to _optimize_prompt and returns the parsed response.

Expand All @@ -480,19 +503,21 @@ def optimize_prompt(

Args:
prompt: The prompt to optimize.
config: The configuration for prompt optimization. Currently, config is
not supported for a single prompt optimization.
config: Optional.The configuration for prompt optimization. To optimize
prompts from Android API provide
types.OptimizeConfig(
optimization_target=types.OptimizeTarget.OPTIMIZATION_TARGET_GEMINI_NANO
)
Returns:
The parsed response from the API request.
"""
if config is not None:
raise ValueError(
"Currently, config is not supported for a single prompt optimization."
)

prompt = genai_types.Content(parts=[genai_types.Part(text=prompt)], role="user")
# TODO: b/435653980 - replace the custom method with a generated method.
return self._custom_optimize_prompt(content=prompt)
return self._custom_optimize_prompt(
content=prompt,
config=config,
)

def _custom_optimize_prompt(
self,
Expand All @@ -511,7 +536,6 @@ def _custom_optimize_prompt(
content=content,
config=config,
)

request_url_dict: Optional[dict[str, str]]
if not self._api_client.vertexai:
raise ValueError("This method is only supported in the Vertex AI client.")
Expand Down Expand Up @@ -850,7 +874,6 @@ async def _custom_optimize_prompt(
content=content,
config=config,
)

request_url_dict: Optional[dict[str, str]]
if not self._api_client.vertexai:
raise ValueError("This method is only supported in the Vertex AI client.")
Expand Down Expand Up @@ -909,7 +932,10 @@ async def _custom_optimize_prompt(
return final_response

async def optimize_prompt(
self, *, prompt: str, config: Optional[types.OptimizeConfig] = None
self,
*,
prompt: str,
config: Optional[types.OptimizeConfig] = None,
) -> types.OptimizeResponse:
"""Makes an async request to _optimize_prompt and returns an optimized prompt.

Expand All @@ -920,16 +946,18 @@ async def optimize_prompt(

Args:
prompt: The prompt to optimize.
config: The configuration for prompt optimization. Currently, config is
not supported for a single prompt optimization.
config: Optional.The configuration for prompt optimization. To optimize
prompts from Android API provide
types.OptimizeConfig(
optimization_target=types.OptimizeTarget.OPTIMIZATION_TARGET_GEMINI_NANO
)
Returns:
The parsed response from the API request.
"""
if config is not None:
raise ValueError(
"Currently, config is not supported for a single prompt optimization."
)

prompt = genai_types.Content(parts=[genai_types.Part(text=prompt)], role="user")
# TODO: b/435653980 - replace the custom method with a generated method.
return await self._custom_optimize_prompt(content=prompt)
return await self._custom_optimize_prompt(
content=prompt,
config=config,
)
2 changes: 2 additions & 0 deletions vertexai/_genai/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,6 +584,7 @@
from .common import OptimizeResponseEndpointDict
from .common import OptimizeResponseEndpointOrDict
from .common import OptimizeResponseOrDict
from .common import OptimizeTarget
from .common import PairwiseChoice
from .common import PairwiseMetricInput
from .common import PairwiseMetricInputDict
Expand Down Expand Up @@ -1828,6 +1829,7 @@
"RubricContentType",
"EvaluationRunState",
"Importance",
"OptimizeTarget",
"GenerateMemoriesResponseGeneratedMemoryAction",
"PromptData",
"PromptDataDict",
Expand Down
13 changes: 13 additions & 0 deletions vertexai/_genai/types/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,13 @@ class Importance(_common.CaseInSensitiveEnum):
"""Low importance."""


class OptimizeTarget(_common.CaseInSensitiveEnum):
"""None"""

OPTIMIZATION_TARGET_GEMINI_NANO = "OPTIMIZATION_TARGET_GEMINI_NANO"
"""The data driven prompt optimizer designer for prompts from Android core API."""


class GenerateMemoriesResponseGeneratedMemoryAction(_common.CaseInSensitiveEnum):
"""The action to take."""

Expand Down Expand Up @@ -3986,6 +3993,9 @@ class OptimizeConfig(_common.BaseModel):
http_options: Optional[genai_types.HttpOptions] = Field(
default=None, description="""Used to override HTTP request options."""
)
optimization_target: Optional[OptimizeTarget] = Field(
default=None, description=""""""
)


class OptimizeConfigDict(TypedDict, total=False):
Expand All @@ -3994,6 +4004,9 @@ class OptimizeConfigDict(TypedDict, total=False):
http_options: Optional[genai_types.HttpOptionsDict]
"""Used to override HTTP request options."""

optimization_target: Optional[OptimizeTarget]
""""""


OptimizeConfigOrDict = Union[OptimizeConfig, OptimizeConfigDict]

Expand Down
Loading