diff --git a/tests/unit/vertexai/genai/test_agent_engines.py b/tests/unit/vertexai/genai/test_agent_engines.py index 37e678f7f6..508bf0377b 100644 --- a/tests/unit/vertexai/genai/test_agent_engines.py +++ b/tests/unit/vertexai/genai/test_agent_engines.py @@ -40,7 +40,7 @@ import pytest -_TEST_AGENT_FRAMEWORK = "test-agent-framework" +_TEST_AGENT_FRAMEWORK = "google-adk" GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY = ( "GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY" ) @@ -976,9 +976,11 @@ def test_create_agent_engine_config_with_source_packages( entrypoint_object="app", requirements_file=requirements_file_path, class_methods=_TEST_AGENT_ENGINE_CLASS_METHODS, + agent_framework=_TEST_AGENT_FRAMEWORK, ) assert config["display_name"] == _TEST_AGENT_ENGINE_DISPLAY_NAME assert config["description"] == _TEST_AGENT_ENGINE_DESCRIPTION + assert config["spec"]["agent_framework"] == _TEST_AGENT_FRAMEWORK assert config["spec"]["source_code_spec"] == { "inline_source": {"source_archive": "test_tarball"}, "python_spec": { @@ -1500,6 +1502,7 @@ def test_create_agent_engine_with_env_vars_dict( entrypoint_module=None, entrypoint_object=None, requirements_file=None, + agent_framework=None, ) request_mock.assert_called_with( "post", @@ -1513,7 +1516,9 @@ def test_create_agent_engine_with_env_vars_dict( "package_spec": { "pickle_object_gcs_uri": _TEST_AGENT_ENGINE_GCS_URI, "python_version": _TEST_PYTHON_VERSION, - "requirements_gcs_uri": _TEST_AGENT_ENGINE_REQUIREMENTS_GCS_URI, + "requirements_gcs_uri": ( + _TEST_AGENT_ENGINE_REQUIREMENTS_GCS_URI + ), }, }, }, @@ -1586,6 +1591,7 @@ def test_create_agent_engine_with_custom_service_account( entrypoint_module=None, entrypoint_object=None, requirements_file=None, + agent_framework=None, ) request_mock.assert_called_with( "post", @@ -1674,6 +1680,7 @@ def test_create_agent_engine_with_experimental_mode( entrypoint_module=None, entrypoint_object=None, requirements_file=None, + agent_framework=None, ) request_mock.assert_called_with( "post", @@ -1826,6 +1833,7 @@ def test_create_agent_engine_with_class_methods( entrypoint_module=None, entrypoint_object=None, requirements_file=None, + agent_framework=None, ) request_mock.assert_called_with( "post", @@ -1845,6 +1853,92 @@ def test_create_agent_engine_with_class_methods( None, ) + @mock.patch.object(agent_engines.AgentEngines, "_create_config") + @mock.patch.object(_agent_engines_utils, "_await_operation") + def test_create_agent_engine_with_agent_framework( + self, + mock_await_operation, + mock_create_config, + ): + mock_create_config.return_value = { + "display_name": _TEST_AGENT_ENGINE_DISPLAY_NAME, + "description": _TEST_AGENT_ENGINE_DESCRIPTION, + "spec": { + "package_spec": { + "python_version": _TEST_PYTHON_VERSION, + "pickle_object_gcs_uri": _TEST_AGENT_ENGINE_GCS_URI, + "requirements_gcs_uri": _TEST_AGENT_ENGINE_REQUIREMENTS_GCS_URI, + }, + "class_methods": [_TEST_AGENT_ENGINE_CLASS_METHOD_1], + "agent_framework": _TEST_AGENT_FRAMEWORK, + }, + } + mock_await_operation.return_value = _genai_types.AgentEngineOperation( + response=_genai_types.ReasoningEngine( + name=_TEST_AGENT_ENGINE_RESOURCE_NAME, + spec=_TEST_AGENT_ENGINE_SPEC, + ) + ) + with mock.patch.object( + self.client.agent_engines._api_client, "request" + ) as request_mock: + request_mock.return_value = genai_types.HttpResponse(body="") + self.client.agent_engines.create( + agent=self.test_agent, + config=_genai_types.AgentEngineConfig( + display_name=_TEST_AGENT_ENGINE_DISPLAY_NAME, + requirements=_TEST_AGENT_ENGINE_REQUIREMENTS, + extra_packages=[_TEST_AGENT_ENGINE_EXTRA_PACKAGE_PATH], + staging_bucket=_TEST_STAGING_BUCKET, + agent_framework=_TEST_AGENT_FRAMEWORK, + ), + ) + mock_create_config.assert_called_with( + mode="create", + agent=self.test_agent, + staging_bucket=_TEST_STAGING_BUCKET, + requirements=_TEST_AGENT_ENGINE_REQUIREMENTS, + display_name=_TEST_AGENT_ENGINE_DISPLAY_NAME, + description=None, + gcs_dir_name=None, + extra_packages=[_TEST_AGENT_ENGINE_EXTRA_PACKAGE_PATH], + env_vars=None, + service_account=None, + context_spec=None, + psc_interface_config=None, + min_instances=None, + max_instances=None, + resource_limits=None, + container_concurrency=None, + encryption_spec=None, + labels=None, + agent_server_mode=None, + class_methods=None, + source_packages=None, + entrypoint_module=None, + entrypoint_object=None, + requirements_file=None, + agent_framework=_TEST_AGENT_FRAMEWORK, + ) + request_mock.assert_called_with( + "post", + "reasoningEngines", + { + "displayName": _TEST_AGENT_ENGINE_DISPLAY_NAME, + "description": _TEST_AGENT_ENGINE_DESCRIPTION, + "spec": { + "agent_framework": _TEST_AGENT_FRAMEWORK, + "class_methods": [_TEST_AGENT_ENGINE_CLASS_METHOD_1], + "package_spec": { + "pickle_object_gcs_uri": _TEST_AGENT_ENGINE_GCS_URI, + "python_version": _TEST_PYTHON_VERSION, + "requirements_gcs_uri": _TEST_AGENT_ENGINE_REQUIREMENTS_GCS_URI, + }, + }, + }, + None, + ) + @pytest.mark.usefixtures("caplog") @mock.patch.object(_agent_engines_utils, "_prepare") @mock.patch.object(_agent_engines_utils, "_await_operation") diff --git a/vertexai/_genai/_agent_engines_utils.py b/vertexai/_genai/_agent_engines_utils.py index 8364212528..0e063a2c1b 100644 --- a/vertexai/_genai/_agent_engines_utils.py +++ b/vertexai/_genai/_agent_engines_utils.py @@ -128,6 +128,16 @@ _BASE_MODULES = set(_BUILTIN_MODULE_NAMES + tuple(_STDLIB_MODULE_NAMES)) _BLOB_FILENAME = "agent_engine.pkl" _DEFAULT_AGENT_FRAMEWORK = "custom" +_SUPPORTED_AGENT_FRAMEWORKS = frozenset( + [ + "google-adk", + "langchain", + "langgraph", + "ag2", + "llama-index", + "custom", + ] +) _DEFAULT_ASYNC_METHOD_NAME = "async_query" _DEFAULT_ASYNC_METHOD_RETURN_TYPE = "Coroutine[Any]" _DEFAULT_ASYNC_STREAM_METHOD_NAME = "async_stream_query" @@ -705,13 +715,43 @@ def _generate_schema( return schema -def _get_agent_framework(*, agent: _AgentEngineInterface) -> str: - if ( - hasattr(agent, _AGENT_FRAMEWORK_ATTR) - and getattr(agent, _AGENT_FRAMEWORK_ATTR) is not None - and isinstance(getattr(agent, _AGENT_FRAMEWORK_ATTR), str) - ): - return getattr(agent, _AGENT_FRAMEWORK_ATTR) +def _get_agent_framework( + *, + agent_framework: Optional[str], + agent: _AgentEngineInterface, +) -> str: + """Gets the agent framework to use. + + The agent framework is determined in the following order of priority: + 1. The `agent_framework` passed to this function. + 2. The `agent_framework` attribute on the `agent` object. + 3. The default framework, "custom". + + Args: + agent_framework (str): + The agent framework provided by the user. + agent (_AgentEngineInterface): + The agent engine instance. + + Returns: + str: The name of the agent framework to use. + """ + if agent_framework is not None and agent_framework in _SUPPORTED_AGENT_FRAMEWORKS: + logger.info(f"Using agent framework: {agent_framework}") + return agent_framework + if hasattr(agent, _AGENT_FRAMEWORK_ATTR): + agent_framework_attr = getattr(agent, _AGENT_FRAMEWORK_ATTR) + if ( + agent_framework_attr is not None + and isinstance(agent_framework_attr, str) + and agent_framework_attr in _SUPPORTED_AGENT_FRAMEWORKS + ): + logger.info(f"Using agent framework: {agent_framework_attr}") + return agent_framework_attr + logger.info( + f"The provided agent framework {agent_framework} is not supported." + f" Defaulting to {_DEFAULT_AGENT_FRAMEWORK}." + ) return _DEFAULT_AGENT_FRAMEWORK diff --git a/vertexai/_genai/agent_engines.py b/vertexai/_genai/agent_engines.py index 2b3ad56d8f..238c5d8d69 100644 --- a/vertexai/_genai/agent_engines.py +++ b/vertexai/_genai/agent_engines.py @@ -94,6 +94,9 @@ def _CreateAgentEngineConfig_to_vertex( getv(from_object, ["requirements_file"]), ) + if getv(from_object, ["agent_framework"]) is not None: + setv(parent_object, ["agentFramework"], getv(from_object, ["agent_framework"])) + return to_object @@ -285,6 +288,9 @@ def _UpdateAgentEngineConfig_to_vertex( getv(from_object, ["requirements_file"]), ) + if getv(from_object, ["agent_framework"]) is not None: + setv(parent_object, ["agentFramework"], getv(from_object, ["agent_framework"])) + if getv(from_object, ["update_mask"]) is not None: setv( parent_object, ["_query", "updateMask"], getv(from_object, ["update_mask"]) @@ -923,6 +929,7 @@ def create( entrypoint_module=config.entrypoint_module, entrypoint_object=config.entrypoint_object, requirements_file=config.requirements_file, + agent_framework=config.agent_framework, ) operation = self._create(config=api_config) # TODO: Use a more specific link. @@ -986,6 +993,7 @@ def _create_config( entrypoint_module: Optional[str] = None, entrypoint_object: Optional[str] = None, requirements_file: Optional[str] = None, + agent_framework: Optional[str] = None, ) -> types.UpdateAgentEngineConfigDict: import sys @@ -1195,7 +1203,10 @@ def _create_config( ] = agent_server_mode agent_engine_spec["agent_framework"] = ( - _agent_engines_utils._get_agent_framework(agent=agent) + _agent_engines_utils._get_agent_framework( + agent_framework=agent_framework, + agent=agent, + ) ) update_masks.append("spec.agent_framework") config["spec"] = agent_engine_spec @@ -1423,6 +1434,7 @@ def update( entrypoint_module=config.entrypoint_module, entrypoint_object=config.entrypoint_object, requirements_file=config.requirements_file, + agent_framework=config.agent_framework, ) operation = self._update(name=name, config=api_config) logger.info( diff --git a/vertexai/_genai/types/common.py b/vertexai/_genai/types/common.py index d193042586..d82a6903e3 100644 --- a/vertexai/_genai/types/common.py +++ b/vertexai/_genai/types/common.py @@ -5366,6 +5366,19 @@ class CreateAgentEngineConfig(_common.BaseModel): the source package. """, ) + agent_framework: Optional[ + Literal["google-adk", "langchain", "langgraph", "ag2", "llama-index", "custom"] + ] = Field( + default=None, + description="""The agent framework to be used for the Agent Engine. + The OSS agent framework used to develop the agent. + Currently supported values: "google-adk", "langchain", "langgraph", + "ag2", "llama-index", "custom". + If not specified: + - If `agent` is specified, the agent framework will be auto-detected. + - If `source_packages` is specified, the agent framework will + default to "custom".""", + ) class CreateAgentEngineConfigDict(TypedDict, total=False): @@ -5464,6 +5477,18 @@ class CreateAgentEngineConfigDict(TypedDict, total=False): the source package. """ + agent_framework: Optional[ + Literal["google-adk", "langchain", "langgraph", "ag2", "llama-index", "custom"] + ] + """The agent framework to be used for the Agent Engine. + The OSS agent framework used to develop the agent. + Currently supported values: "google-adk", "langchain", "langgraph", + "ag2", "llama-index", "custom". + If not specified: + - If `agent` is specified, the agent framework will be auto-detected. + - If `source_packages` is specified, the agent framework will + default to "custom".""" + CreateAgentEngineConfigOrDict = Union[ CreateAgentEngineConfig, CreateAgentEngineConfigDict @@ -6067,6 +6092,19 @@ class UpdateAgentEngineConfig(_common.BaseModel): the source package. """, ) + agent_framework: Optional[ + Literal["google-adk", "langchain", "langgraph", "ag2", "llama-index", "custom"] + ] = Field( + default=None, + description="""The agent framework to be used for the Agent Engine. + The OSS agent framework used to develop the agent. + Currently supported values: "google-adk", "langchain", "langgraph", + "ag2", "llama-index", "custom". + If not specified: + - If `agent` is specified, the agent framework will be auto-detected. + - If `source_packages` is specified, the agent framework will + default to "custom".""", + ) update_mask: Optional[str] = Field( default=None, description="""The update mask to apply. For the `FieldMask` definition, see @@ -6170,6 +6208,18 @@ class UpdateAgentEngineConfigDict(TypedDict, total=False): the source package. """ + agent_framework: Optional[ + Literal["google-adk", "langchain", "langgraph", "ag2", "llama-index", "custom"] + ] + """The agent framework to be used for the Agent Engine. + The OSS agent framework used to develop the agent. + Currently supported values: "google-adk", "langchain", "langgraph", + "ag2", "llama-index", "custom". + If not specified: + - If `agent` is specified, the agent framework will be auto-detected. + - If `source_packages` is specified, the agent framework will + default to "custom".""" + update_mask: Optional[str] """The update mask to apply. For the `FieldMask` definition, see https://protobuf.dev/reference/protobuf/google.protobuf/#field-mask.""" @@ -12907,6 +12957,19 @@ class AgentEngineConfig(_common.BaseModel): the source package. """, ) + agent_framework: Optional[ + Literal["google-adk", "langchain", "langgraph", "ag2", "llama-index", "custom"] + ] = Field( + default=None, + description="""The agent framework to be used for the Agent Engine. + The OSS agent framework used to develop the agent. + Currently supported values: "google-adk", "langchain", "langgraph", + "ag2", "llama-index", "custom". + If not specified: + - If `agent` is specified, the agent framework will be auto-detected. + - If `source_packages` is specified, the agent framework will + default to "custom".""", + ) class AgentEngineConfigDict(TypedDict, total=False): @@ -13034,6 +13097,18 @@ class AgentEngineConfigDict(TypedDict, total=False): the source package. """ + agent_framework: Optional[ + Literal["google-adk", "langchain", "langgraph", "ag2", "llama-index", "custom"] + ] + """The agent framework to be used for the Agent Engine. + The OSS agent framework used to develop the agent. + Currently supported values: "google-adk", "langchain", "langgraph", + "ag2", "llama-index", "custom". + If not specified: + - If `agent` is specified, the agent framework will be auto-detected. + - If `source_packages` is specified, the agent framework will + default to "custom".""" + AgentEngineConfigOrDict = Union[AgentEngineConfig, AgentEngineConfigDict]