diff --git a/src/xpk/commands/cluster.py b/src/xpk/commands/cluster.py index e2dd9768c..2b4b3b096 100644 --- a/src/xpk/commands/cluster.py +++ b/src/xpk/commands/cluster.py @@ -84,6 +84,7 @@ from ..utils.templates import get_templates_absolute_path import shutil import os +from .managed_ml_diagnostics import install_mldiagnostics_prerequisites CLUSTER_PREHEAT_JINJA_FILE = 'cluster_preheat.yaml.j2' @@ -422,6 +423,13 @@ def cluster_create(args) -> None: # pylint: disable=line-too-long f' https://console.cloud.google.com/kubernetes/clusters/details/{get_cluster_location(args.project, args.cluster, args.zone)}/{args.cluster}/details?project={args.project}' ) + + if args.managed_mldiagnostics: + return_code = install_mldiagnostics_prerequisites() + if return_code != 0: + xpk_print('Installation of MLDiagnostics failed.') + xpk_exit(return_code) + xpk_exit(0) diff --git a/src/xpk/commands/cluster_test.py b/src/xpk/commands/cluster_test.py index b5d199621..c5f9e2fbe 100644 --- a/src/xpk/commands/cluster_test.py +++ b/src/xpk/commands/cluster_test.py @@ -84,6 +84,9 @@ def mocks(mocker) -> _Mocks: run_command_with_updates_path=( 'xpk.commands.cluster.run_command_with_updates' ), + run_command_for_value_path=( + 'xpk.commands.cluster.run_command_for_value' + ), ), ) @@ -121,6 +124,7 @@ def construct_args(**kwargs: Any) -> Namespace: cluster_cpu_machine_type='', create_vertex_tensorboard=False, enable_autoprovisioning=False, + managed_mldiagnostics=False, ) args_dict.update(kwargs) return Namespace(**args_dict) diff --git a/src/xpk/commands/managed_ml_diagnostics.py b/src/xpk/commands/managed_ml_diagnostics.py new file mode 100644 index 000000000..40f711e45 --- /dev/null +++ b/src/xpk/commands/managed_ml_diagnostics.py @@ -0,0 +1,249 @@ +""" +Copyright 2024 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from packaging.version import Version +from ..core.commands import run_command_for_value, run_command_with_updates +from ..utils.console import xpk_print +import os +import tempfile + +_KUEUE_DEPLOYMENT_NAME = 'kueue-controller-manager' +_KUEUE_NAMESPACE_NAME = 'kueue-system' +_CERT_WEBHOOK_DEPLOYMENT_NAME = 'cert-manager-webhook' +_CERT_WEBHOOK_NAMESPACE_NAME = 'cert-manager' +_WEBHOOK_PACKAGE = 'mldiagnostics-injection-webhook' +_WEBHOOK_VERSION = Version('v0.5.0') +_WEBHOOK_FILENAME = f'{_WEBHOOK_PACKAGE}-v{_WEBHOOK_VERSION}.yaml' +_OPERATOR_PACKAGE = 'mldiagnostics-connection-operator' +_OPERATOR_VERSION = Version('v0.5.0') +_OPERATOR_FILENAME = f'{_OPERATOR_PACKAGE}-v{_OPERATOR_VERSION}.yaml' +_CERT_MANAGER_VERSION = Version('v1.13.0') + + +def _install_cert_manager(version: Version = _CERT_MANAGER_VERSION) -> int: + """ + Apply the cert-manager manifest. + + Returns: + 0 if successful and 1 otherwise. + """ + + command = ( + 'kubectl apply -f' + ' https://github.com/cert-manager/cert-manager/releases/download/' + f'v{version}/cert-manager.yaml' + ) + + return_code = run_command_with_updates( + command, f'Applying cert-manager {version} manifest...' + ) + + return return_code + + +def _download_mldiagnostics_yaml(package_name: str, version: Version) -> int: + """ + Downloads the mldiagnostics injection webhook YAML from Artifact Registry. + + Returns: + 0 if successful and 1 otherwise. + """ + + command = ( + 'gcloud artifacts generic download' + ' --repository=mldiagnostics-webhook-and-operator-yaml --location=us' + f' --package={package_name} --version=v{version} --destination=/tmp/' + ' --project=ai-on-gke' + ) + + return_code, return_output = run_command_for_value( + command, + f'Download {package_name} {version}...', + ) + + if return_code != 0: + if 'already exists' in return_output: + xpk_print( + f'Artifact file for {package_name} {version} already exists locally.' + ' Skipping download.' + ) + return 0 + + return return_code + + +def _create_mldiagnostics_namespace() -> int: + """ + Creates the 'gke-mldiagnostics' namespace. + + Returns: + 0 if successful and 1 otherwise. + """ + + command = 'kubectl create namespace gke-mldiagnostics' + + return_code, return_output = run_command_for_value( + command, 'Create gke-mldiagnostics namespace...' + ) + + if return_code != 0: + if 'already exists' in return_output: + xpk_print('Namespace already exists. Skipping creation.') + return 0 + + return return_code + + +def _install_mldiagnostics_yaml(artifact_filename: str) -> int: + """ + Applies the mldiagnostics injection webhook YAML manifest. + + Returns: + 0 if successful and 1 otherwise. + """ + full_artifact_path = os.path.join(tempfile.gettempdir(), artifact_filename) + + command = f'kubectl apply -f {full_artifact_path} -n gke-mldiagnostics' + + return run_command_with_updates( + command, + f'Install {full_artifact_path}...', + ) + + +def _label_default_namespace_mldiagnostics() -> int: + """ + Labels the 'default' namespace with 'managed-mldiagnostics-gke=true'. + + Returns: + 0 if successful and 1 otherwise. + """ + + command = 'kubectl label namespace default managed-mldiagnostics-gke=true' + + return run_command_with_updates( + command, + 'Label default namespace with managed-mldiagnostics-gke=true', + ) + + +def install_mldiagnostics_prerequisites() -> int: + """ + Mldiagnostics installation requirements. + + Returns: + 0 if successful and 1 otherwise. + """ + + if not _wait_for_deployment_ready( + deployment_name=_KUEUE_DEPLOYMENT_NAME, namespace=_KUEUE_NAMESPACE_NAME + ): + xpk_print( + f'Application {_KUEUE_DEPLOYMENT_NAME} failed to become ready within' + ' the timeout.' + ) + return 1 + + return_code = _install_cert_manager() + if return_code != 0: + return return_code + + cert_webhook_ready = _wait_for_deployment_ready( + deployment_name=_CERT_WEBHOOK_DEPLOYMENT_NAME, + namespace=_CERT_WEBHOOK_NAMESPACE_NAME, + ) + if not cert_webhook_ready: + xpk_print('The cert-manager-webhook installation failed.') + return 1 + + return_code = _download_mldiagnostics_yaml( + package_name=_WEBHOOK_PACKAGE, version=_WEBHOOK_VERSION + ) + if return_code != 0: + return return_code + + return_code = _create_mldiagnostics_namespace() + if return_code != 0: + return return_code + + return_code = _install_mldiagnostics_yaml(artifact_filename=_WEBHOOK_FILENAME) + if return_code != 0: + return return_code + + return_code = _label_default_namespace_mldiagnostics() + if return_code != 0: + return return_code + + return_code = _download_mldiagnostics_yaml( + package_name=_OPERATOR_PACKAGE, version=_OPERATOR_VERSION + ) + if return_code != 0: + return return_code + + return_code = _install_mldiagnostics_yaml( + artifact_filename=_OPERATOR_FILENAME + ) + if return_code != 0: + return return_code + + xpk_print( + 'All mldiagnostics installation and setup steps have been' + ' successfully completed!' + ) + return 0 + + +def _wait_for_deployment_ready( + deployment_name: str, namespace: str, timeout_seconds: int = 300 +) -> bool: + """ + Polls the Kubernetes Deployment status using kubectl rollout status + until it successfully rolls out (all replicas are ready) or times out. + + Args: + deployment_name: The name of the Kubernetes Deployment (e.g., 'kueue-controller-manager'). + namespace: The namespace where the Deployment is located (e.g., 'kueue-system'). + timeout_seconds: Timeout duration in seconds (default is 300s / 5 minutes). + + Returns: + bool: True if the Deployment successfully rolled out, False otherwise (timeout or error). + """ + + command = ( + f'kubectl rollout status deployment/{deployment_name} -n {namespace}' + f' --timeout={timeout_seconds}s' + ) + + return_code = run_command_with_updates( + command, f'Checking status of deployment {deployment_name}...' + ) + + if return_code != 0: + return False + + # When the status changes to 'running,' it might need about 10 seconds to fully stabilize. + stabilization_seconds = 30 + stabilization_command = f'sleep {stabilization_seconds}' + stabilization_code = run_command_with_updates( + stabilization_command, + f'Deployment {deployment_name} is ready. Waiting {stabilization_seconds}' + ' seconds for full stabilization', + verbose=True, + ) + if stabilization_code != 0: + return False + + return True diff --git a/src/xpk/commands/managed_ml_diagnostics_test.py b/src/xpk/commands/managed_ml_diagnostics_test.py new file mode 100644 index 000000000..48cec68e9 --- /dev/null +++ b/src/xpk/commands/managed_ml_diagnostics_test.py @@ -0,0 +1,146 @@ +""" +Copyright 2025 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from dataclasses import dataclass +from unittest.mock import MagicMock +import pytest +from xpk.commands.managed_ml_diagnostics import install_mldiagnostics_prerequisites +from xpk.core.testing.commands_tester import CommandsTester + + +@dataclass +class _Mocks: + common_print_mock: MagicMock + commands_print_mock: MagicMock + commands_get_reservation_deployment_type: MagicMock + commands_tester: CommandsTester + + +@pytest.fixture +def mocks(mocker) -> _Mocks: + common_print_mock = mocker.patch( + 'xpk.commands.common.xpk_print', + return_value=None, + ) + commands_print_mock = mocker.patch( + 'xpk.commands.cluster.xpk_print', return_value=None + ) + commands_get_reservation_deployment_type = mocker.patch( + 'xpk.commands.cluster.get_reservation_deployment_type', + return_value='DENSE', + ) + return _Mocks( + common_print_mock=common_print_mock, + commands_get_reservation_deployment_type=commands_get_reservation_deployment_type, + commands_print_mock=commands_print_mock, + commands_tester=CommandsTester( + mocker, + run_command_with_updates_path=( + 'xpk.commands.managed_ml_diagnostics.run_command_with_updates' + ), + run_command_for_value_path=( + 'xpk.commands.managed_ml_diagnostics.run_command_for_value' + ), + ), + ) + + +def test_install_mldiagnostics_prerequisites_commands_executed( + mocks: _Mocks, +): + + install_mldiagnostics_prerequisites() + + mocks.commands_tester.assert_command_run( + 'kubectl', + 'rollout', + 'status', + 'deployment/kueue-controller-manager', + times=1, + ) + + mocks.commands_tester.assert_command_run( + 'kubectl', + 'apply', + '-f', + 'https://github.com/cert-manager/cert-manager/', + times=1, + ) + + mocks.commands_tester.assert_command_run( + 'kubectl', 'rollout', 'status', 'deployment/cert-manager-webhook', times=1 + ) + + mocks.commands_tester.assert_command_run( + 'gcloud', + 'artifacts', + 'generic', + 'download', + '--package=mldiagnostics-injection-webhook', + '--version=v0.5.0', + times=1, + ) + + mocks.commands_tester.assert_command_run( + 'kubectl', 'create', 'namespace', 'gke-mldiagnostics', times=1 + ) + + mocks.commands_tester.assert_command_run( + 'kubectl', + 'apply', + '-f', + '/tmp/mldiagnostics-injection-webhook-v0.5.0.yaml', + '-n', + 'gke-mldiagnostics', + times=1, + ) + + mocks.commands_tester.assert_command_run( + 'kubectl', + 'label', + 'namespace', + 'default', + 'managed-mldiagnostics-gke=true', + times=1, + ) + + mocks.commands_tester.assert_command_run( + 'gcloud', + 'artifacts', + 'generic', + 'download', + '--package=mldiagnostics-connection-operator', + '--version=v0.5.0', + times=1, + ) + + mocks.commands_tester.assert_command_run( + 'kubectl', + 'apply', + '-f', + '/tmp/mldiagnostics-connection-operator-v0.5.0.yaml', + '-n', + 'gke-mldiagnostics', + times=1, + ) + + mocks.commands_tester.assert_command_run( + 'gcloud', 'artifacts', 'generic', 'download', times=2 + ) + + mocks.commands_tester.assert_command_run( + 'kubectl', 'apply', '-f', '-n', 'gke-mldiagnostics', times=2 + ) diff --git a/src/xpk/parser/cluster.py b/src/xpk/parser/cluster.py index 4f113dbbd..4475ce232 100644 --- a/src/xpk/parser/cluster.py +++ b/src/xpk/parser/cluster.py @@ -150,6 +150,7 @@ def set_cluster_create_parser(cluster_create_parser: ArgumentParser): ' enable cluster to accept Pathways workloads.' ), ) + if FeatureFlags.SUB_SLICING_ENABLED: add_cluster_create_sub_slicing_arguments(cluster_create_optional_arguments) @@ -691,6 +692,11 @@ def add_shared_cluster_create_optional_arguments( ' regional clusters, all zones must support the machine type.' ), ) + parser_or_group.add_argument( + '--managed-mldiagnostics', + action='store_true', + help='Enables the installation of required ML Diagnostics components.', + ) parser_or_group.add_argument( '--cluster-cpu-machine-type', type=str, diff --git a/src/xpk/parser/cluster_test.py b/src/xpk/parser/cluster_test.py index 398c0d10f..6af84f331 100644 --- a/src/xpk/parser/cluster_test.py +++ b/src/xpk/parser/cluster_test.py @@ -103,3 +103,18 @@ def test_cluster_create_ray_sub_slicing_is_hidden_but_set_to_false(): assert args.sub_slicing is False assert "--sub-slicing" not in help_str + + +def test_cluster_create_managed_mldiagnostics(): + parser = argparse.ArgumentParser() + + set_cluster_create_parser(parser) + args = parser.parse_args([ + "--cluster", + "test-cluster", + "--tpu-type", + "v5p-8", + "--managed-mldiagnostics", + ]) + + assert args.managed_mldiagnostics is True