Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 96 additions & 2 deletions tests/unit/vertexai/genai/test_agent_engines.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
import pytest


_TEST_AGENT_FRAMEWORK = "test-agent-framework"
_TEST_AGENT_FRAMEWORK = "google-adk"
GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY = (
"GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY"
)
Expand Down Expand Up @@ -976,9 +976,11 @@ def test_create_agent_engine_config_with_source_packages(
entrypoint_object="app",
requirements_file=requirements_file_path,
class_methods=_TEST_AGENT_ENGINE_CLASS_METHODS,
agent_framework=_TEST_AGENT_FRAMEWORK,
)
assert config["display_name"] == _TEST_AGENT_ENGINE_DISPLAY_NAME
assert config["description"] == _TEST_AGENT_ENGINE_DESCRIPTION
assert config["spec"]["agent_framework"] == _TEST_AGENT_FRAMEWORK
assert config["spec"]["source_code_spec"] == {
"inline_source": {"source_archive": "test_tarball"},
"python_spec": {
Expand Down Expand Up @@ -1500,6 +1502,7 @@ def test_create_agent_engine_with_env_vars_dict(
entrypoint_module=None,
entrypoint_object=None,
requirements_file=None,
agent_framework=None,
)
request_mock.assert_called_with(
"post",
Expand All @@ -1513,7 +1516,9 @@ def test_create_agent_engine_with_env_vars_dict(
"package_spec": {
"pickle_object_gcs_uri": _TEST_AGENT_ENGINE_GCS_URI,
"python_version": _TEST_PYTHON_VERSION,
"requirements_gcs_uri": _TEST_AGENT_ENGINE_REQUIREMENTS_GCS_URI,
"requirements_gcs_uri": (
_TEST_AGENT_ENGINE_REQUIREMENTS_GCS_URI
),
},
},
},
Expand Down Expand Up @@ -1586,6 +1591,7 @@ def test_create_agent_engine_with_custom_service_account(
entrypoint_module=None,
entrypoint_object=None,
requirements_file=None,
agent_framework=None,
)
request_mock.assert_called_with(
"post",
Expand Down Expand Up @@ -1674,6 +1680,7 @@ def test_create_agent_engine_with_experimental_mode(
entrypoint_module=None,
entrypoint_object=None,
requirements_file=None,
agent_framework=None,
)
request_mock.assert_called_with(
"post",
Expand Down Expand Up @@ -1826,6 +1833,7 @@ def test_create_agent_engine_with_class_methods(
entrypoint_module=None,
entrypoint_object=None,
requirements_file=None,
agent_framework=None,
)
request_mock.assert_called_with(
"post",
Expand All @@ -1845,6 +1853,92 @@ def test_create_agent_engine_with_class_methods(
None,
)

@mock.patch.object(agent_engines.AgentEngines, "_create_config")
@mock.patch.object(_agent_engines_utils, "_await_operation")
def test_create_agent_engine_with_agent_framework(
self,
mock_await_operation,
mock_create_config,
):
mock_create_config.return_value = {
"display_name": _TEST_AGENT_ENGINE_DISPLAY_NAME,
"description": _TEST_AGENT_ENGINE_DESCRIPTION,
"spec": {
"package_spec": {
"python_version": _TEST_PYTHON_VERSION,
"pickle_object_gcs_uri": _TEST_AGENT_ENGINE_GCS_URI,
"requirements_gcs_uri": _TEST_AGENT_ENGINE_REQUIREMENTS_GCS_URI,
},
"class_methods": [_TEST_AGENT_ENGINE_CLASS_METHOD_1],
"agent_framework": _TEST_AGENT_FRAMEWORK,
},
}
mock_await_operation.return_value = _genai_types.AgentEngineOperation(
response=_genai_types.ReasoningEngine(
name=_TEST_AGENT_ENGINE_RESOURCE_NAME,
spec=_TEST_AGENT_ENGINE_SPEC,
)
)
with mock.patch.object(
self.client.agent_engines._api_client, "request"
) as request_mock:
request_mock.return_value = genai_types.HttpResponse(body="")
self.client.agent_engines.create(
agent=self.test_agent,
config=_genai_types.AgentEngineConfig(
display_name=_TEST_AGENT_ENGINE_DISPLAY_NAME,
requirements=_TEST_AGENT_ENGINE_REQUIREMENTS,
extra_packages=[_TEST_AGENT_ENGINE_EXTRA_PACKAGE_PATH],
staging_bucket=_TEST_STAGING_BUCKET,
agent_framework=_TEST_AGENT_FRAMEWORK,
),
)
mock_create_config.assert_called_with(
mode="create",
agent=self.test_agent,
staging_bucket=_TEST_STAGING_BUCKET,
requirements=_TEST_AGENT_ENGINE_REQUIREMENTS,
display_name=_TEST_AGENT_ENGINE_DISPLAY_NAME,
description=None,
gcs_dir_name=None,
extra_packages=[_TEST_AGENT_ENGINE_EXTRA_PACKAGE_PATH],
env_vars=None,
service_account=None,
context_spec=None,
psc_interface_config=None,
min_instances=None,
max_instances=None,
resource_limits=None,
container_concurrency=None,
encryption_spec=None,
labels=None,
agent_server_mode=None,
class_methods=None,
source_packages=None,
entrypoint_module=None,
entrypoint_object=None,
requirements_file=None,
agent_framework=_TEST_AGENT_FRAMEWORK,
)
request_mock.assert_called_with(
"post",
"reasoningEngines",
{
"displayName": _TEST_AGENT_ENGINE_DISPLAY_NAME,
"description": _TEST_AGENT_ENGINE_DESCRIPTION,
"spec": {
"agent_framework": _TEST_AGENT_FRAMEWORK,
"class_methods": [_TEST_AGENT_ENGINE_CLASS_METHOD_1],
"package_spec": {
"pickle_object_gcs_uri": _TEST_AGENT_ENGINE_GCS_URI,
"python_version": _TEST_PYTHON_VERSION,
"requirements_gcs_uri": _TEST_AGENT_ENGINE_REQUIREMENTS_GCS_URI,
},
},
},
None,
)

@pytest.mark.usefixtures("caplog")
@mock.patch.object(_agent_engines_utils, "_prepare")
@mock.patch.object(_agent_engines_utils, "_await_operation")
Expand Down
54 changes: 47 additions & 7 deletions vertexai/_genai/_agent_engines_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,16 @@
_BASE_MODULES = set(_BUILTIN_MODULE_NAMES + tuple(_STDLIB_MODULE_NAMES))
_BLOB_FILENAME = "agent_engine.pkl"
_DEFAULT_AGENT_FRAMEWORK = "custom"
_SUPPORTED_AGENT_FRAMEWORKS = frozenset(
[
"google-adk",
"langchain",
"langgraph",
"ag2",
"llama-index",
"custom",
]
)
_DEFAULT_ASYNC_METHOD_NAME = "async_query"
_DEFAULT_ASYNC_METHOD_RETURN_TYPE = "Coroutine[Any]"
_DEFAULT_ASYNC_STREAM_METHOD_NAME = "async_stream_query"
Expand Down Expand Up @@ -705,13 +715,43 @@ def _generate_schema(
return schema


def _get_agent_framework(*, agent: _AgentEngineInterface) -> str:
if (
hasattr(agent, _AGENT_FRAMEWORK_ATTR)
and getattr(agent, _AGENT_FRAMEWORK_ATTR) is not None
and isinstance(getattr(agent, _AGENT_FRAMEWORK_ATTR), str)
):
return getattr(agent, _AGENT_FRAMEWORK_ATTR)
def _get_agent_framework(
*,
agent_framework: Optional[str],
agent: _AgentEngineInterface,
) -> str:
"""Gets the agent framework to use.

The agent framework is determined in the following order of priority:
1. The `agent_framework` passed to this function.
2. The `agent_framework` attribute on the `agent` object.
3. The default framework, "custom".

Args:
agent_framework (str):
The agent framework provided by the user.
agent (_AgentEngineInterface):
The agent engine instance.

Returns:
str: The name of the agent framework to use.
"""
if agent_framework is not None and agent_framework in _SUPPORTED_AGENT_FRAMEWORKS:
logger.info(f"Using agent framework: {agent_framework}")
return agent_framework
if hasattr(agent, _AGENT_FRAMEWORK_ATTR):
agent_framework_attr = getattr(agent, _AGENT_FRAMEWORK_ATTR)
if (
agent_framework_attr is not None
and isinstance(agent_framework_attr, str)
and agent_framework_attr in _SUPPORTED_AGENT_FRAMEWORKS
):
logger.info(f"Using agent framework: {agent_framework_attr}")
return agent_framework_attr
logger.info(
f"The provided agent framework {agent_framework} is not supported."
f" Defaulting to {_DEFAULT_AGENT_FRAMEWORK}."
)
return _DEFAULT_AGENT_FRAMEWORK


Expand Down
14 changes: 13 additions & 1 deletion vertexai/_genai/agent_engines.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,9 @@ def _CreateAgentEngineConfig_to_vertex(
getv(from_object, ["requirements_file"]),
)

if getv(from_object, ["agent_framework"]) is not None:
setv(parent_object, ["agentFramework"], getv(from_object, ["agent_framework"]))

return to_object


Expand Down Expand Up @@ -285,6 +288,9 @@ def _UpdateAgentEngineConfig_to_vertex(
getv(from_object, ["requirements_file"]),
)

if getv(from_object, ["agent_framework"]) is not None:
setv(parent_object, ["agentFramework"], getv(from_object, ["agent_framework"]))

if getv(from_object, ["update_mask"]) is not None:
setv(
parent_object, ["_query", "updateMask"], getv(from_object, ["update_mask"])
Expand Down Expand Up @@ -923,6 +929,7 @@ def create(
entrypoint_module=config.entrypoint_module,
entrypoint_object=config.entrypoint_object,
requirements_file=config.requirements_file,
agent_framework=config.agent_framework,
)
operation = self._create(config=api_config)
# TODO: Use a more specific link.
Expand Down Expand Up @@ -986,6 +993,7 @@ def _create_config(
entrypoint_module: Optional[str] = None,
entrypoint_object: Optional[str] = None,
requirements_file: Optional[str] = None,
agent_framework: Optional[str] = None,
) -> types.UpdateAgentEngineConfigDict:
import sys

Expand Down Expand Up @@ -1195,7 +1203,10 @@ def _create_config(
] = agent_server_mode

agent_engine_spec["agent_framework"] = (
_agent_engines_utils._get_agent_framework(agent=agent)
_agent_engines_utils._get_agent_framework(
agent_framework=agent_framework,
agent=agent,
)
)
update_masks.append("spec.agent_framework")
config["spec"] = agent_engine_spec
Expand Down Expand Up @@ -1423,6 +1434,7 @@ def update(
entrypoint_module=config.entrypoint_module,
entrypoint_object=config.entrypoint_object,
requirements_file=config.requirements_file,
agent_framework=config.agent_framework,
)
operation = self._update(name=name, config=api_config)
logger.info(
Expand Down
75 changes: 75 additions & 0 deletions vertexai/_genai/types/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -5366,6 +5366,19 @@ class CreateAgentEngineConfig(_common.BaseModel):
the source package.
""",
)
agent_framework: Optional[
Literal["google-adk", "langchain", "langgraph", "ag2", "llama-index", "custom"]
] = Field(
default=None,
description="""The agent framework to be used for the Agent Engine.
The OSS agent framework used to develop the agent.
Currently supported values: "google-adk", "langchain", "langgraph",
"ag2", "llama-index", "custom".
If not specified:
- If `agent` is specified, the agent framework will be auto-detected.
- If `source_packages` is specified, the agent framework will
default to "custom".""",
)


class CreateAgentEngineConfigDict(TypedDict, total=False):
Expand Down Expand Up @@ -5464,6 +5477,18 @@ class CreateAgentEngineConfigDict(TypedDict, total=False):
the source package.
"""

agent_framework: Optional[
Literal["google-adk", "langchain", "langgraph", "ag2", "llama-index", "custom"]
]
"""The agent framework to be used for the Agent Engine.
The OSS agent framework used to develop the agent.
Currently supported values: "google-adk", "langchain", "langgraph",
"ag2", "llama-index", "custom".
If not specified:
- If `agent` is specified, the agent framework will be auto-detected.
- If `source_packages` is specified, the agent framework will
default to "custom"."""


CreateAgentEngineConfigOrDict = Union[
CreateAgentEngineConfig, CreateAgentEngineConfigDict
Expand Down Expand Up @@ -6067,6 +6092,19 @@ class UpdateAgentEngineConfig(_common.BaseModel):
the source package.
""",
)
agent_framework: Optional[
Literal["google-adk", "langchain", "langgraph", "ag2", "llama-index", "custom"]
] = Field(
default=None,
description="""The agent framework to be used for the Agent Engine.
The OSS agent framework used to develop the agent.
Currently supported values: "google-adk", "langchain", "langgraph",
"ag2", "llama-index", "custom".
If not specified:
- If `agent` is specified, the agent framework will be auto-detected.
- If `source_packages` is specified, the agent framework will
default to "custom".""",
)
update_mask: Optional[str] = Field(
default=None,
description="""The update mask to apply. For the `FieldMask` definition, see
Expand Down Expand Up @@ -6170,6 +6208,18 @@ class UpdateAgentEngineConfigDict(TypedDict, total=False):
the source package.
"""

agent_framework: Optional[
Literal["google-adk", "langchain", "langgraph", "ag2", "llama-index", "custom"]
]
"""The agent framework to be used for the Agent Engine.
The OSS agent framework used to develop the agent.
Currently supported values: "google-adk", "langchain", "langgraph",
"ag2", "llama-index", "custom".
If not specified:
- If `agent` is specified, the agent framework will be auto-detected.
- If `source_packages` is specified, the agent framework will
default to "custom"."""

update_mask: Optional[str]
"""The update mask to apply. For the `FieldMask` definition, see
https://protobuf.dev/reference/protobuf/google.protobuf/#field-mask."""
Expand Down Expand Up @@ -12907,6 +12957,19 @@ class AgentEngineConfig(_common.BaseModel):
the source package.
""",
)
agent_framework: Optional[
Literal["google-adk", "langchain", "langgraph", "ag2", "llama-index", "custom"]
] = Field(
default=None,
description="""The agent framework to be used for the Agent Engine.
The OSS agent framework used to develop the agent.
Currently supported values: "google-adk", "langchain", "langgraph",
"ag2", "llama-index", "custom".
If not specified:
- If `agent` is specified, the agent framework will be auto-detected.
- If `source_packages` is specified, the agent framework will
default to "custom".""",
)


class AgentEngineConfigDict(TypedDict, total=False):
Expand Down Expand Up @@ -13034,6 +13097,18 @@ class AgentEngineConfigDict(TypedDict, total=False):
the source package.
"""

agent_framework: Optional[
Literal["google-adk", "langchain", "langgraph", "ag2", "llama-index", "custom"]
]
"""The agent framework to be used for the Agent Engine.
The OSS agent framework used to develop the agent.
Currently supported values: "google-adk", "langchain", "langgraph",
"ag2", "llama-index", "custom".
If not specified:
- If `agent` is specified, the agent framework will be auto-detected.
- If `source_packages` is specified, the agent framework will
default to "custom"."""


AgentEngineConfigOrDict = Union[AgentEngineConfig, AgentEngineConfigDict]

Expand Down
Loading