Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
174 changes: 174 additions & 0 deletions tests/unit/vertexai/genai/test_agent_engines.py
Original file line number Diff line number Diff line change
Expand Up @@ -1450,6 +1450,77 @@ def test_create_base64_encoded_tarball_outside_project_dir_raises(self):
finally:
os.chdir(origin_dir)

@mock.patch.object(_agent_engines_utils, "_upload_requirements")
@mock.patch.object(_agent_engines_utils, "_upload_extra_packages")
@mock.patch.object(_agent_engines_utils, "_upload_agent_engine")
@mock.patch.object(_agent_engines_utils, "_scan_requirements")
@mock.patch.object(_agent_engines_utils, "_get_gcs_bucket")
def test_prepare_with_creds(
self,
mock_get_gcs_bucket,
mock_scan_requirements,
mock_upload_agent_engine,
mock_upload_extra_packages,
mock_upload_requirements,
):
mock_scan_requirements.return_value = {}
mock_creds = mock.Mock(spec=auth_credentials.AnonymousCredentials())
mock_creds.universe_domain = "googleapis.com"
_agent_engines_utils._prepare(
agent=self.test_agent,
project=_TEST_PROJECT,
location=_TEST_LOCATION,
staging_bucket=_TEST_STAGING_BUCKET,
credentials=mock_creds,
gcs_dir_name=_TEST_GCS_DIR_NAME,
requirements=[],
extra_packages=[],
)
mock_upload_agent_engine.assert_called_once_with(
agent=self.test_agent,
gcs_bucket=mock.ANY,
gcs_dir_name=_TEST_GCS_DIR_NAME,
)

@mock.patch.object(_agent_engines_utils, "_upload_requirements")
@mock.patch.object(_agent_engines_utils, "_upload_extra_packages")
@mock.patch.object(_agent_engines_utils, "_upload_agent_engine")
@mock.patch.object(_agent_engines_utils, "_scan_requirements")
@mock.patch("google.auth.default")
@mock.patch.object(_agent_engines_utils, "_get_gcs_bucket")
def test_prepare_without_creds(
self,
mock_get_gcs_bucket,
mock_auth_default,
mock_scan_requirements,
mock_upload_agent_engine,
mock_upload_extra_packages,
mock_upload_requirements,
):
mock_scan_requirements.return_value = {}
mock_creds = mock.Mock(spec=auth_credentials.AnonymousCredentials())
mock_auth_default.return_value = (mock_creds, _TEST_PROJECT)
_agent_engines_utils._prepare(
agent=self.test_agent,
project=_TEST_PROJECT,
location=_TEST_LOCATION,
staging_bucket=_TEST_STAGING_BUCKET,
gcs_dir_name=_TEST_GCS_DIR_NAME,
requirements=[],
extra_packages=[],
)
mock_get_gcs_bucket.assert_called_once_with(
project=_TEST_PROJECT,
location=_TEST_LOCATION,
staging_bucket=_TEST_STAGING_BUCKET,
credentials=None,
)
mock_upload_agent_engine.assert_called_once_with(
agent=self.test_agent,
gcs_bucket=mock.ANY,
gcs_dir_name=_TEST_GCS_DIR_NAME,
)


@pytest.mark.usefixtures("google_auth_mock")
class TestAgentEngine:
Expand Down Expand Up @@ -2622,6 +2693,109 @@ def test_operation_schemas(
want_operation_schemas.append(want_operation_schema)
assert test_agent_engine.operation_schemas() == want_operation_schemas

@mock.patch.object(_agent_engines_utils, "_prepare")
@mock.patch.object(agent_engines.AgentEngines, "_create")
@mock.patch.object(_agent_engines_utils, "_await_operation")
def test_create_agent_engine_with_creds(
self, mock_await_operation, mock_create, mock_prepare
):
mock_operation = mock.Mock()
mock_operation.name = _TEST_AGENT_ENGINE_OPERATION_NAME
mock_create.return_value = mock_operation
mock_await_operation.return_value = _genai_types.AgentEngineOperation(
response=_genai_types.ReasoningEngine(
name=_TEST_AGENT_ENGINE_RESOURCE_NAME,
spec=_TEST_AGENT_ENGINE_SPEC,
)
)
self.client.agent_engines.create(
agent=self.test_agent,
config=_genai_types.AgentEngineConfig(
display_name=_TEST_AGENT_ENGINE_DISPLAY_NAME,
staging_bucket=_TEST_STAGING_BUCKET,
),
)
mock_args, mock_kwargs = mock_prepare.call_args
assert mock_kwargs["agent"] == self.test_agent
assert mock_kwargs["extra_packages"] == []
assert mock_kwargs["project"] == _TEST_PROJECT
assert mock_kwargs["location"] == _TEST_LOCATION
assert mock_kwargs["staging_bucket"] == _TEST_STAGING_BUCKET
assert mock_kwargs["credentials"] == _TEST_CREDENTIALS
assert mock_kwargs["gcs_dir_name"] == "agent_engine"

@mock.patch.object(_agent_engines_utils, "_prepare")
@mock.patch.object(agent_engines.AgentEngines, "_create")
@mock.patch("google.auth.default")
@mock.patch.object(_agent_engines_utils, "_await_operation")
def test_create_agent_engine_without_creds(
self, mock_await_operation, mock_auth_default, mock_create, mock_prepare
):
mock_operation = mock.Mock()
mock_operation.name = _TEST_AGENT_ENGINE_OPERATION_NAME
mock_create.return_value = mock_operation
mock_await_operation.return_value = _genai_types.AgentEngineOperation(
response=_genai_types.ReasoningEngine(
name=_TEST_AGENT_ENGINE_RESOURCE_NAME,
spec=_TEST_AGENT_ENGINE_SPEC,
)
)
mock_creds = mock.Mock(spec=auth_credentials.AnonymousCredentials())
mock_creds.quota_project_id = _TEST_PROJECT
mock_auth_default.return_value = (mock_creds, _TEST_PROJECT)
client = vertexai.Client(
project=_TEST_PROJECT, location=_TEST_LOCATION, credentials=mock_creds
)
client.agent_engines.create(
agent=self.test_agent,
config=_genai_types.AgentEngineConfig(
display_name=_TEST_AGENT_ENGINE_DISPLAY_NAME,
staging_bucket=_TEST_STAGING_BUCKET,
),
)
mock_args, mock_kwargs = mock_prepare.call_args
assert mock_kwargs["agent"] == self.test_agent
assert mock_kwargs["extra_packages"] == []
assert mock_kwargs["project"] == _TEST_PROJECT
assert mock_kwargs["location"] == _TEST_LOCATION
assert mock_kwargs["staging_bucket"] == _TEST_STAGING_BUCKET
assert mock_kwargs["credentials"] == mock_creds
assert mock_kwargs["gcs_dir_name"] == "agent_engine"

@mock.patch.object(_agent_engines_utils, "_prepare")
@mock.patch.object(agent_engines.AgentEngines, "_create")
@mock.patch.object(_agent_engines_utils, "_await_operation")
def test_create_agent_engine_with_no_creds_in_client(
self, mock_await_operation, mock_create, mock_prepare
):
mock_operation = mock.Mock()
mock_operation.name = _TEST_AGENT_ENGINE_OPERATION_NAME
mock_create.return_value = mock_operation
mock_await_operation.return_value = _genai_types.AgentEngineOperation(
response=_genai_types.ReasoningEngine(
name=_TEST_AGENT_ENGINE_RESOURCE_NAME,
spec=_TEST_AGENT_ENGINE_SPEC,
)
)
client = vertexai.Client(
project=_TEST_PROJECT, location=_TEST_LOCATION, credentials=None
)
client.agent_engines.create(
agent=self.test_agent,
config=_genai_types.AgentEngineConfig(
display_name=_TEST_AGENT_ENGINE_DISPLAY_NAME,
staging_bucket=_TEST_STAGING_BUCKET,
),
)
mock_args, mock_kwargs = mock_prepare.call_args
assert mock_kwargs["agent"] == self.test_agent
assert mock_kwargs["extra_packages"] == []
assert mock_kwargs["project"] == _TEST_PROJECT
assert mock_kwargs["location"] == _TEST_LOCATION
assert mock_kwargs["staging_bucket"] == _TEST_STAGING_BUCKET
assert mock_kwargs["credentials"] is None
assert mock_kwargs["gcs_dir_name"] == "agent_engine"


@pytest.mark.usefixtures("google_auth_mock")
class TestAgentEngineErrors:
Expand Down
10 changes: 7 additions & 3 deletions vertexai/_genai/_agent_engines_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -772,10 +772,11 @@ def _get_gcs_bucket(
project: str,
location: str,
staging_bucket: str,
credentials: Optional[Any] = None,
) -> _StorageBucket:
"""Gets or creates the GCS bucket."""
storage = _import_cloud_storage_or_raise()
storage_client = storage.Client(project=project)
storage_client = storage.Client(project=project, credentials=credentials)
staging_bucket = staging_bucket.replace("gs://", "")
try:
gcs_bucket = storage_client.get_bucket(staging_bucket)
Expand Down Expand Up @@ -910,6 +911,7 @@ def _prepare(
location: str,
staging_bucket: str,
gcs_dir_name: str,
credentials: Optional[Any] = None,
) -> None:
"""Prepares the agent engine for creation or updates in Vertex AI.

Expand All @@ -926,15 +928,17 @@ def _prepare(
project (str): The project for the staging bucket.
location (str): The location for the staging bucket.
staging_bucket (str): The staging bucket name in the form "gs://...".
gcs_dir_name (str): The GCS bucket directory under `staging_bucket` to
use for staging the artifacts needed.
gcs_dir_name (str): The GCS bucket directory under `staging_bucket` to use
for staging the artifacts needed.
credentials: The credentials to use for the storage client.
"""
if agent is None:
return
gcs_bucket = _get_gcs_bucket(
project=project,
location=location,
staging_bucket=staging_bucket,
credentials=credentials,
)
_upload_agent_engine(
agent=agent,
Expand Down
1 change: 1 addition & 0 deletions vertexai/_genai/agent_engines.py
Original file line number Diff line number Diff line change
Expand Up @@ -1082,6 +1082,7 @@ def _create_config(
staging_bucket=staging_bucket,
gcs_dir_name=gcs_dir_name,
extra_packages=extra_packages,
credentials=self._api_client._credentials,
)
# Update the package spec.
update_masks.append("spec.package_spec.pickle_object_gcs_uri")
Expand Down
Loading