From 8d74b6968e763a8c31c77d551ace885e6a3fb2a5 Mon Sep 17 00:00:00 2001 From: Danny LI Date: Tue, 4 Nov 2025 03:20:44 +0000 Subject: [PATCH 1/7] Add Diagon installation during cluster creation and modify the workload.py Add wait_for_deployment_ready() Added unit test update goldens.yaml update goldens.yaml update goldens.yaml Fixed parser/cluster.py update goldens.yaml fixed linter fixed linter pyink Test unit test --- src/xpk/commands/cluster.py | 8 + src/xpk/commands/cluster_test.py | 3 + src/xpk/commands/managed_ml_diagnostics.py | 246 ++++++++++++++++++ .../commands/managed_ml_diagnostics_test.py | 240 +++++++++++++++++ src/xpk/parser/cluster.py | 13 + src/xpk/parser/cluster_test.py | 14 + 6 files changed, 524 insertions(+) create mode 100644 src/xpk/commands/managed_ml_diagnostics.py create mode 100644 src/xpk/commands/managed_ml_diagnostics_test.py diff --git a/src/xpk/commands/cluster.py b/src/xpk/commands/cluster.py index e2dd9768c..86dd79757 100644 --- a/src/xpk/commands/cluster.py +++ b/src/xpk/commands/cluster.py @@ -84,6 +84,7 @@ from ..utils.templates import get_templates_absolute_path import shutil import os +from . import managed_ml_diagnostics CLUSTER_PREHEAT_JINJA_FILE = 'cluster_preheat.yaml.j2' @@ -422,6 +423,13 @@ def cluster_create(args) -> None: # pylint: disable=line-too-long f' https://console.cloud.google.com/kubernetes/clusters/details/{get_cluster_location(args.project, args.cluster, args.zone)}/{args.cluster}/details?project={args.project}' ) + + if args.managed_ml_diagnostics: + return_code = managed_ml_diagnostics.install_mldiagnostics_prerequisites() + if return_code != 0: + xpk_print('Installation of MLDiagnostics failed.') + xpk_exit(return_code) + xpk_exit(0) diff --git a/src/xpk/commands/cluster_test.py b/src/xpk/commands/cluster_test.py index b5d199621..31eaf493b 100644 --- a/src/xpk/commands/cluster_test.py +++ b/src/xpk/commands/cluster_test.py @@ -84,6 +84,9 @@ def mocks(mocker) -> _Mocks: run_command_with_updates_path=( 'xpk.commands.cluster.run_command_with_updates' ), + run_command_for_value_path=( + 'xpk.commands.cluster.run_command_for_value' + ), ), ) diff --git a/src/xpk/commands/managed_ml_diagnostics.py b/src/xpk/commands/managed_ml_diagnostics.py new file mode 100644 index 000000000..73035d1f1 --- /dev/null +++ b/src/xpk/commands/managed_ml_diagnostics.py @@ -0,0 +1,246 @@ +""" +Copyright 2024 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import time +from packaging.version import Version +from ..core.commands import run_command_for_value, run_command_with_updates +from ..utils.console import xpk_exit, xpk_print +import os + + +def _install_cert_manager(version: Version = Version('v1.13.0')) -> int: + """ + Apply the cert-manager manifest. + + Returns: + 0 if successful and 1 otherwise. + """ + + command = ( + 'kubectl apply -f' + ' https://github.com/cert-manager/cert-manager/releases/download/' + f'v{version}/cert-manager.yaml' + ) + + return_code = run_command_with_updates( + command, f'Applying cert-manager {version} manifest...' + ) + + if return_code != 0: + xpk_exit(return_code) + + return return_code + + +def _download_mldiagnostics_yaml(package_name: str, version: Version) -> int: + """ + Downloads the mldiagnostics injection webhook YAML from Artifact Registry. + + Returns: + 0 if successful and 1 otherwise. + """ + + version_with_v = f'v{version}' + command = ( + 'gcloud artifacts generic download' + ' --repository=mldiagnostics-webhook-and-operator-yaml --location=us' + f' --package={package_name} --version={version_with_v} --destination=/tmp/' + ' --project=ai-on-gke' + ) + + return_code, return_output = run_command_for_value( + command, + f'Download {package_name} {version}...', + ) + + if return_code != 0: + if 'already exists' in return_output: + xpk_print( + f'Artifact file for {package_name} {version} already exists locally.' + ' Skipping download.' + ) + return 0 + + return return_code + + +def _create_mldiagnostics_namespace() -> int: + """ + Creates the 'gke-mldiagnostics' namespace. + + Returns: + 0 if successful and 1 otherwise. + """ + + command = 'kubectl create namespace gke-mldiagnostics' + + return_code, return_output = run_command_for_value( + command, 'Create gke-mldiagnostics namespace...' + ) + + if return_code != 0: + if 'already exists' in return_output: + xpk_print('Namespace already exists. Skipping creation.') + return 0 + + return return_code + + +def _install_mldiagnostics_yaml(artifact_filename: str) -> int: + """ + Applies the mldiagnostics injection webhook YAML manifest. + + Returns: + 0 if successful and 1 otherwise. + """ + full_artifact_path = os.path.join('/tmp', artifact_filename) + + command = f'kubectl apply -f {full_artifact_path} -n gke-mldiagnostics' + + return_code = run_command_with_updates( + command, + f'Install {full_artifact_path}...', + ) + + return return_code + + +def _label_default_namespace_mldiagnostics() -> int: + """ + Labels the 'default' namespace with 'managed-mldiagnostics-gke=true'. + + Returns: + 0 if successful and 1 otherwise. + """ + + command = 'kubectl label namespace default managed-mldiagnostics-gke=true' + + return_code = run_command_with_updates( + command, + 'Label default namespace with managed-mldiagnostics-gke=true', + ) + + return return_code + + +def install_mldiagnostics_prerequisites() -> int: + """ + Mldiagnostics installation requirements. + + Returns: + 0 if successful and 1 otherwise. + """ + + kueue_deployment_name = 'kueue-controller-manager' + kueue_namespace_name = 'kueue-system' + cert_webhook_deployment_name = 'cert-manager-webhook' + cert_webhook_namespace_name = 'cert-manager' + + if not _wait_for_deployment_ready( + deployment_name=kueue_deployment_name, namespace=kueue_namespace_name + ): + xpk_print( + f'Application {kueue_deployment_name} failed to become ready within the' + ' timeout.' + ) + return 1 + + return_code = _install_cert_manager() + if return_code != 0: + return return_code + + cert_webhook_ready = _wait_for_deployment_ready( + deployment_name=cert_webhook_deployment_name, + namespace=cert_webhook_namespace_name, + ) + if not cert_webhook_ready: + xpk_print('The cert-manager-webhook installation failed.') + return 1 + + webhook_package = 'mldiagnostics-injection-webhook' + webhook_version = 'v0.5.0' + webhook_filename = f'{webhook_package}-{webhook_version}.yaml' + + return_code = _download_mldiagnostics_yaml( + package_name=webhook_package, version=Version(webhook_version) + ) + if return_code != 0: + return return_code + + return_code = _create_mldiagnostics_namespace() + if return_code != 0: + return return_code + + return_code = _install_mldiagnostics_yaml(artifact_filename=webhook_filename) + if return_code != 0: + return return_code + + return_code = _label_default_namespace_mldiagnostics() + if return_code != 0: + return return_code + + operator_package = 'mldiagnostics-connection-operator' + operator_version = 'v0.5.0' + operator_filename = f'{operator_package}-{operator_version}.yaml' + + return_code = _download_mldiagnostics_yaml( + package_name=operator_package, version=Version(webhook_version) + ) + if return_code != 0: + return return_code + + return_code = _install_mldiagnostics_yaml(artifact_filename=operator_filename) + if return_code != 0: + return return_code + + xpk_print( + 'All mldiagnostics installation and setup steps have been' + ' successfully completed!' + ) + return 0 + + +def _wait_for_deployment_ready( + deployment_name: str, namespace: str, timeout_seconds: int = 300 +) -> bool: + """ + Polls the Kubernetes Deployment status using kubectl rollout status + until it successfully rolls out (all replicas are ready) or times out. + + Args: + deployment_name: The name of the Kubernetes Deployment (e.g., 'kueue-controller-manager'). + namespace: The namespace where the Deployment is located (e.g., 'kueue-system'). + timeout_seconds: Timeout duration in seconds (default is 300s / 5 minutes). + + Returns: + bool: True if the Deployment successfully rolled out, False otherwise (timeout or error). + """ + + command = ( + f'kubectl rollout status deployment/{deployment_name} -n {namespace}' + f' --timeout={timeout_seconds}s' + ) + + return_code = run_command_with_updates( + command, f'Checking status of deployment {deployment_name}...' + ) + + if return_code != 0: + return False + + # When the status changes to 'running,' it might need about 10 seconds to fully stabilize. + time.sleep(30) + return True diff --git a/src/xpk/commands/managed_ml_diagnostics_test.py b/src/xpk/commands/managed_ml_diagnostics_test.py new file mode 100644 index 000000000..505bdd623 --- /dev/null +++ b/src/xpk/commands/managed_ml_diagnostics_test.py @@ -0,0 +1,240 @@ +""" +Copyright 2025 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from argparse import Namespace +from dataclasses import dataclass +from typing import Any +from unittest.mock import MagicMock +import pytest +from xpk.commands.managed_ml_diagnostics import install_mldiagnostics_prerequisites +from xpk.core.testing.commands_tester import CommandsTester + + +@dataclass +class _Mocks: + common_print_mock: MagicMock + commands_print_mock: MagicMock + commands_get_reservation_deployment_type: MagicMock + commands_tester: CommandsTester + + +@pytest.fixture +def mocks(mocker) -> _Mocks: + common_print_mock = mocker.patch( + 'xpk.commands.common.xpk_print', + return_value=None, + ) + commands_print_mock = mocker.patch( + 'xpk.commands.cluster.xpk_print', return_value=None + ) + commands_get_reservation_deployment_type = mocker.patch( + 'xpk.commands.cluster.get_reservation_deployment_type', + return_value='DENSE', + ) + return _Mocks( + common_print_mock=common_print_mock, + commands_get_reservation_deployment_type=commands_get_reservation_deployment_type, + commands_print_mock=commands_print_mock, + commands_tester=CommandsTester( + mocker, + run_command_with_updates_path=( + 'xpk.commands.managed_ml_diagnostics.run_command_with_updates' + ), + run_command_for_value_path=( + 'xpk.commands.managed_ml_diagnostics.run_command_for_value' + ), + ), + ) + + +def construct_args(**kwargs: Any) -> Namespace: + args_dict = dict( + project='project', + zone='us-central1-a', + reservation='', + default_pool_cpu_machine_type='test-machine-type', + cluster='test-cluster', + default_pool_cpu_num_nodes='100', + sub_slicing=False, + gke_version='', + private=False, + authorized_networks=None, + enable_pathways=False, + enable_ray_cluster=False, + enable_workload_identity=False, + enable_gcsfuse_csi_driver=False, + enable_gcpfilestore_csi_driver=False, + enable_parallelstore_csi_driver=False, + enable_pd_csi_driver=False, + enable_lustre_csi_driver=False, + custom_cluster_arguments='', + num_slices=1, + num_nodes=1, + flex=False, + memory_limit='100Gi', + cpu_limit=100, + cluster_cpu_machine_type='', + managed_mldiagnostics=False, + ) + args_dict.update(kwargs) + return Namespace(**args_dict) + + +def test_install_mldiagnostics_prerequisites_commands_executed( + mocks: _Mocks, + mocker, +): + + mocks.commands_tester.set_result_for_command( + (0, ''), + 'kubectl', + 'rollout', + 'status', + 'deployment/kueue-controller-manager', + ) + + mocks.commands_tester.set_result_for_command( + (0, ''), + 'kubectl', + 'rollout', + 'status', + 'deployment/cert-manager-webhook', + ) + + mocks.commands_tester.set_result_for_command( + (0, ''), + 'kubectl', + 'apply', + '-f', + 'https://github.com/cert-manager/cert-manager/releases/', + ) + + mocks.commands_tester.set_result_for_command( + (0, ''), + 'gcloud', + 'artifacts', + 'generic', + 'download', + ) + + mocks.commands_tester.set_result_for_command( + (0, ''), + 'kubectl', + 'create', + 'namespace', + 'gke-mldiagnostics', + ) + + mocks.commands_tester.set_result_for_command( + (0, ''), + 'kubectl', + 'apply', + '-f', + '-n', + 'gke-mldiagnostics', + ) + + mocks.commands_tester.set_result_for_command( + (0, ''), + 'kubectl', + 'label', + 'namespace', + 'default', + 'managed-mldiagnostics-gke=true', + ) + + install_mldiagnostics_prerequisites() + + mocks.commands_tester.assert_command_run( + 'kubectl', + 'rollout', + 'status', + 'deployment/kueue-controller-manager', + times=1, + ) + + mocks.commands_tester.assert_command_run( + 'kubectl', + 'apply', + '-f', + 'https://github.com/cert-manager/cert-manager/', + times=1, + ) + + mocks.commands_tester.assert_command_run( + 'kubectl', 'rollout', 'status', 'deployment/cert-manager-webhook', times=1 + ) + + mocks.commands_tester.assert_command_run( + 'gcloud', + 'artifacts', + 'generic', + 'download', + '--package=mldiagnostics-injection-webhook', + '--version=v0.5.0', + times=1, + ) + + mocks.commands_tester.assert_command_run( + 'kubectl', 'create', 'namespace', 'gke-mldiagnostics', times=1 + ) + + mocks.commands_tester.assert_command_run( + 'kubectl', + 'apply', + '-f', + '/tmp/mldiagnostics-injection-webhook-v0.5.0.yaml', + '-n', + 'gke-mldiagnostics', + times=1, + ) + + mocks.commands_tester.assert_command_run( + 'kubectl', + 'label', + 'namespace', + 'default', + 'managed-mldiagnostics-gke=true', + times=1, + ) + + mocks.commands_tester.assert_command_run( + 'gcloud', + 'artifacts', + 'generic', + 'download', + '--package=mldiagnostics-connection-operator', + '--version=v0.5.0', + times=1, + ) + + mocks.commands_tester.assert_command_run( + 'kubectl', + 'apply', + '-f', + '/tmp/mldiagnostics-connection-operator-v0.5.0.yaml', + '-n', + 'gke-mldiagnostics', + times=1, + ) + + mocks.commands_tester.assert_command_run( + 'gcloud', 'artifacts', 'generic', 'download', times=2 + ) + + mocks.commands_tester.assert_command_run( + 'kubectl', 'apply', '-f', '-n', 'gke-mldiagnostics', times=2 + ) diff --git a/src/xpk/parser/cluster.py b/src/xpk/parser/cluster.py index 4f113dbbd..f1ea0b674 100644 --- a/src/xpk/parser/cluster.py +++ b/src/xpk/parser/cluster.py @@ -150,6 +150,13 @@ def set_cluster_create_parser(cluster_create_parser: ArgumentParser): ' enable cluster to accept Pathways workloads.' ), ) + + cluster_create_optional_arguments.add_argument( + '--managed-ml-diagnostics', + action='store_true', + help='Enables the installation of required ML Diagnostics components.', + ) + if FeatureFlags.SUB_SLICING_ENABLED: add_cluster_create_sub_slicing_arguments(cluster_create_optional_arguments) @@ -241,6 +248,12 @@ def set_cluster_create_pathways_parser( ) add_autoprovisioning_arguments(autoprovisioning_arguments) + cluster_create_pathways_optional_arguments.add_argument( + '--managed-ml-diagnostics', + action='store_true', + help='Enables the installation of required ML Diagnostics components.', + ) + ### Capacity arguments specific to "cluster create-pathways" cluster_create_pathways_capacity_arguments = ( cluster_create_pathways_parser.add_argument_group( diff --git a/src/xpk/parser/cluster_test.py b/src/xpk/parser/cluster_test.py index 398c0d10f..f9c34caac 100644 --- a/src/xpk/parser/cluster_test.py +++ b/src/xpk/parser/cluster_test.py @@ -103,3 +103,17 @@ def test_cluster_create_ray_sub_slicing_is_hidden_but_set_to_false(): assert args.sub_slicing is False assert "--sub-slicing" not in help_str + +def test_cluster_create_managed_mldiagnostics(): + parser = argparse.ArgumentParser() + + set_cluster_create_parser(parser) + args = parser.parse_args([ + "--cluster", + "test-cluster", + "--tpu-type", + "v5p-8", + "--managed-ml-diagnostics", + ]) + + assert args.managed_ml_diagnostics is True From 0a6c9e2411aa1773bddf3003401e775cdcc8f24c Mon Sep 17 00:00:00 2001 From: Danny LI Date: Wed, 19 Nov 2025 13:31:18 +0000 Subject: [PATCH 2/7] Resolve conflicts --- src/xpk/commands/cluster_test.py | 1 + src/xpk/commands/managed_ml_diagnostics.py | 56 ++++++++++--------- .../commands/managed_ml_diagnostics_test.py | 32 ----------- src/xpk/parser/cluster.py | 17 ++---- src/xpk/parser/cluster_test.py | 1 + 5 files changed, 38 insertions(+), 69 deletions(-) diff --git a/src/xpk/commands/cluster_test.py b/src/xpk/commands/cluster_test.py index 31eaf493b..b1b124ff6 100644 --- a/src/xpk/commands/cluster_test.py +++ b/src/xpk/commands/cluster_test.py @@ -124,6 +124,7 @@ def construct_args(**kwargs: Any) -> Namespace: cluster_cpu_machine_type='', create_vertex_tensorboard=False, enable_autoprovisioning=False, + managed_ml_diagnostics=False, ) args_dict.update(kwargs) return Namespace(**args_dict) diff --git a/src/xpk/commands/managed_ml_diagnostics.py b/src/xpk/commands/managed_ml_diagnostics.py index 73035d1f1..daebec76e 100644 --- a/src/xpk/commands/managed_ml_diagnostics.py +++ b/src/xpk/commands/managed_ml_diagnostics.py @@ -14,11 +14,16 @@ limitations under the License. """ -import time from packaging.version import Version from ..core.commands import run_command_for_value, run_command_with_updates -from ..utils.console import xpk_exit, xpk_print +from ..utils.console import xpk_print import os +import tempfile + +_KUEUE_DEPLOYMENT_NAME = 'kueue-controller-manager' +_KUEUE_NAMESPACE_NAME = 'kueue-system' +_CERT_WEBHOOK_DEPLOYMENT_NAME = 'cert-manager-webhook' +_CERT_WEBHOOK_NAMESPACE_NAME = 'cert-manager' def _install_cert_manager(version: Version = Version('v1.13.0')) -> int: @@ -39,9 +44,6 @@ def _install_cert_manager(version: Version = Version('v1.13.0')) -> int: command, f'Applying cert-manager {version} manifest...' ) - if return_code != 0: - xpk_exit(return_code) - return return_code @@ -53,11 +55,10 @@ def _download_mldiagnostics_yaml(package_name: str, version: Version) -> int: 0 if successful and 1 otherwise. """ - version_with_v = f'v{version}' command = ( 'gcloud artifacts generic download' ' --repository=mldiagnostics-webhook-and-operator-yaml --location=us' - f' --package={package_name} --version={version_with_v} --destination=/tmp/' + f' --package={package_name} --version=v{version} --destination=/tmp/' ' --project=ai-on-gke' ) @@ -106,7 +107,7 @@ def _install_mldiagnostics_yaml(artifact_filename: str) -> int: Returns: 0 if successful and 1 otherwise. """ - full_artifact_path = os.path.join('/tmp', artifact_filename) + full_artifact_path = os.path.join(tempfile.gettempdir(), artifact_filename) command = f'kubectl apply -f {full_artifact_path} -n gke-mldiagnostics' @@ -144,17 +145,12 @@ def install_mldiagnostics_prerequisites() -> int: 0 if successful and 1 otherwise. """ - kueue_deployment_name = 'kueue-controller-manager' - kueue_namespace_name = 'kueue-system' - cert_webhook_deployment_name = 'cert-manager-webhook' - cert_webhook_namespace_name = 'cert-manager' - if not _wait_for_deployment_ready( - deployment_name=kueue_deployment_name, namespace=kueue_namespace_name + deployment_name=_KUEUE_DEPLOYMENT_NAME, namespace=_KUEUE_NAMESPACE_NAME ): xpk_print( - f'Application {kueue_deployment_name} failed to become ready within the' - ' timeout.' + f'Application {_KUEUE_DEPLOYMENT_NAME} failed to become ready within' + ' the timeout.' ) return 1 @@ -163,19 +159,19 @@ def install_mldiagnostics_prerequisites() -> int: return return_code cert_webhook_ready = _wait_for_deployment_ready( - deployment_name=cert_webhook_deployment_name, - namespace=cert_webhook_namespace_name, + deployment_name=_CERT_WEBHOOK_DEPLOYMENT_NAME, + namespace=_CERT_WEBHOOK_NAMESPACE_NAME, ) if not cert_webhook_ready: xpk_print('The cert-manager-webhook installation failed.') return 1 webhook_package = 'mldiagnostics-injection-webhook' - webhook_version = 'v0.5.0' - webhook_filename = f'{webhook_package}-{webhook_version}.yaml' + webhook_version = Version('v0.5.0') + webhook_filename = f'{webhook_package}-v{webhook_version}.yaml' return_code = _download_mldiagnostics_yaml( - package_name=webhook_package, version=Version(webhook_version) + package_name=webhook_package, version=webhook_version ) if return_code != 0: return return_code @@ -193,11 +189,11 @@ def install_mldiagnostics_prerequisites() -> int: return return_code operator_package = 'mldiagnostics-connection-operator' - operator_version = 'v0.5.0' - operator_filename = f'{operator_package}-{operator_version}.yaml' + operator_version = Version('v0.5.0') + operator_filename = f'{operator_package}-v{operator_version}.yaml' return_code = _download_mldiagnostics_yaml( - package_name=operator_package, version=Version(webhook_version) + package_name=operator_package, version=operator_version ) if return_code != 0: return return_code @@ -242,5 +238,15 @@ def _wait_for_deployment_ready( return False # When the status changes to 'running,' it might need about 10 seconds to fully stabilize. - time.sleep(30) + stabilization_seconds = 30 + stabilization_command = f'sleep {stabilization_seconds}' + stabilization_code = run_command_with_updates( + stabilization_command, + f'Deployment {deployment_name} is ready. Waiting {stabilization_seconds}' + ' seconds for full stabilization', + verbose=True, + ) + if stabilization_code != 0: + return False + return True diff --git a/src/xpk/commands/managed_ml_diagnostics_test.py b/src/xpk/commands/managed_ml_diagnostics_test.py index 505bdd623..637641291 100644 --- a/src/xpk/commands/managed_ml_diagnostics_test.py +++ b/src/xpk/commands/managed_ml_diagnostics_test.py @@ -62,31 +62,6 @@ def mocks(mocker) -> _Mocks: def construct_args(**kwargs: Any) -> Namespace: args_dict = dict( - project='project', - zone='us-central1-a', - reservation='', - default_pool_cpu_machine_type='test-machine-type', - cluster='test-cluster', - default_pool_cpu_num_nodes='100', - sub_slicing=False, - gke_version='', - private=False, - authorized_networks=None, - enable_pathways=False, - enable_ray_cluster=False, - enable_workload_identity=False, - enable_gcsfuse_csi_driver=False, - enable_gcpfilestore_csi_driver=False, - enable_parallelstore_csi_driver=False, - enable_pd_csi_driver=False, - enable_lustre_csi_driver=False, - custom_cluster_arguments='', - num_slices=1, - num_nodes=1, - flex=False, - memory_limit='100Gi', - cpu_limit=100, - cluster_cpu_machine_type='', managed_mldiagnostics=False, ) args_dict.update(kwargs) @@ -99,7 +74,6 @@ def test_install_mldiagnostics_prerequisites_commands_executed( ): mocks.commands_tester.set_result_for_command( - (0, ''), 'kubectl', 'rollout', 'status', @@ -107,7 +81,6 @@ def test_install_mldiagnostics_prerequisites_commands_executed( ) mocks.commands_tester.set_result_for_command( - (0, ''), 'kubectl', 'rollout', 'status', @@ -115,7 +88,6 @@ def test_install_mldiagnostics_prerequisites_commands_executed( ) mocks.commands_tester.set_result_for_command( - (0, ''), 'kubectl', 'apply', '-f', @@ -123,7 +95,6 @@ def test_install_mldiagnostics_prerequisites_commands_executed( ) mocks.commands_tester.set_result_for_command( - (0, ''), 'gcloud', 'artifacts', 'generic', @@ -131,7 +102,6 @@ def test_install_mldiagnostics_prerequisites_commands_executed( ) mocks.commands_tester.set_result_for_command( - (0, ''), 'kubectl', 'create', 'namespace', @@ -139,7 +109,6 @@ def test_install_mldiagnostics_prerequisites_commands_executed( ) mocks.commands_tester.set_result_for_command( - (0, ''), 'kubectl', 'apply', '-f', @@ -148,7 +117,6 @@ def test_install_mldiagnostics_prerequisites_commands_executed( ) mocks.commands_tester.set_result_for_command( - (0, ''), 'kubectl', 'label', 'namespace', diff --git a/src/xpk/parser/cluster.py b/src/xpk/parser/cluster.py index f1ea0b674..3dda4ac5d 100644 --- a/src/xpk/parser/cluster.py +++ b/src/xpk/parser/cluster.py @@ -151,12 +151,6 @@ def set_cluster_create_parser(cluster_create_parser: ArgumentParser): ), ) - cluster_create_optional_arguments.add_argument( - '--managed-ml-diagnostics', - action='store_true', - help='Enables the installation of required ML Diagnostics components.', - ) - if FeatureFlags.SUB_SLICING_ENABLED: add_cluster_create_sub_slicing_arguments(cluster_create_optional_arguments) @@ -248,12 +242,6 @@ def set_cluster_create_pathways_parser( ) add_autoprovisioning_arguments(autoprovisioning_arguments) - cluster_create_pathways_optional_arguments.add_argument( - '--managed-ml-diagnostics', - action='store_true', - help='Enables the installation of required ML Diagnostics components.', - ) - ### Capacity arguments specific to "cluster create-pathways" cluster_create_pathways_capacity_arguments = ( cluster_create_pathways_parser.add_argument_group( @@ -917,6 +905,11 @@ def add_shared_cluster_create_capacity_arguments( ' types.' ), ) + parser_or_group.add_argument( + '--managed-ml-diagnostics', + action='store_true', + help='Enables the installation of required ML Diagnostics components.', + ) def add_shared_cluster_create_mtc_arguments( diff --git a/src/xpk/parser/cluster_test.py b/src/xpk/parser/cluster_test.py index f9c34caac..23a93a5c7 100644 --- a/src/xpk/parser/cluster_test.py +++ b/src/xpk/parser/cluster_test.py @@ -104,6 +104,7 @@ def test_cluster_create_ray_sub_slicing_is_hidden_but_set_to_false(): assert args.sub_slicing is False assert "--sub-slicing" not in help_str + def test_cluster_create_managed_mldiagnostics(): parser = argparse.ArgumentParser() From 4ca23e84d8a517ccdcac0559251692c7c090a909 Mon Sep 17 00:00:00 2001 From: Danny LI Date: Wed, 19 Nov 2025 13:31:18 +0000 Subject: [PATCH 3/7] Resolve conflicts --- src/xpk/commands/cluster.py | 6 ++-- src/xpk/commands/cluster_test.py | 2 +- src/xpk/commands/managed_ml_diagnostics.py | 35 +++++++++---------- .../commands/managed_ml_diagnostics_test.py | 14 +++++--- src/xpk/parser/cluster.py | 5 +++ 5 files changed, 34 insertions(+), 28 deletions(-) diff --git a/src/xpk/commands/cluster.py b/src/xpk/commands/cluster.py index 86dd79757..2b4b3b096 100644 --- a/src/xpk/commands/cluster.py +++ b/src/xpk/commands/cluster.py @@ -84,7 +84,7 @@ from ..utils.templates import get_templates_absolute_path import shutil import os -from . import managed_ml_diagnostics +from .managed_ml_diagnostics import install_mldiagnostics_prerequisites CLUSTER_PREHEAT_JINJA_FILE = 'cluster_preheat.yaml.j2' @@ -424,8 +424,8 @@ def cluster_create(args) -> None: f' https://console.cloud.google.com/kubernetes/clusters/details/{get_cluster_location(args.project, args.cluster, args.zone)}/{args.cluster}/details?project={args.project}' ) - if args.managed_ml_diagnostics: - return_code = managed_ml_diagnostics.install_mldiagnostics_prerequisites() + if args.managed_mldiagnostics: + return_code = install_mldiagnostics_prerequisites() if return_code != 0: xpk_print('Installation of MLDiagnostics failed.') xpk_exit(return_code) diff --git a/src/xpk/commands/cluster_test.py b/src/xpk/commands/cluster_test.py index b1b124ff6..c5f9e2fbe 100644 --- a/src/xpk/commands/cluster_test.py +++ b/src/xpk/commands/cluster_test.py @@ -124,7 +124,7 @@ def construct_args(**kwargs: Any) -> Namespace: cluster_cpu_machine_type='', create_vertex_tensorboard=False, enable_autoprovisioning=False, - managed_ml_diagnostics=False, + managed_mldiagnostics=False, ) args_dict.update(kwargs) return Namespace(**args_dict) diff --git a/src/xpk/commands/managed_ml_diagnostics.py b/src/xpk/commands/managed_ml_diagnostics.py index daebec76e..40f711e45 100644 --- a/src/xpk/commands/managed_ml_diagnostics.py +++ b/src/xpk/commands/managed_ml_diagnostics.py @@ -24,9 +24,16 @@ _KUEUE_NAMESPACE_NAME = 'kueue-system' _CERT_WEBHOOK_DEPLOYMENT_NAME = 'cert-manager-webhook' _CERT_WEBHOOK_NAMESPACE_NAME = 'cert-manager' +_WEBHOOK_PACKAGE = 'mldiagnostics-injection-webhook' +_WEBHOOK_VERSION = Version('v0.5.0') +_WEBHOOK_FILENAME = f'{_WEBHOOK_PACKAGE}-v{_WEBHOOK_VERSION}.yaml' +_OPERATOR_PACKAGE = 'mldiagnostics-connection-operator' +_OPERATOR_VERSION = Version('v0.5.0') +_OPERATOR_FILENAME = f'{_OPERATOR_PACKAGE}-v{_OPERATOR_VERSION}.yaml' +_CERT_MANAGER_VERSION = Version('v1.13.0') -def _install_cert_manager(version: Version = Version('v1.13.0')) -> int: +def _install_cert_manager(version: Version = _CERT_MANAGER_VERSION) -> int: """ Apply the cert-manager manifest. @@ -111,13 +118,11 @@ def _install_mldiagnostics_yaml(artifact_filename: str) -> int: command = f'kubectl apply -f {full_artifact_path} -n gke-mldiagnostics' - return_code = run_command_with_updates( + return run_command_with_updates( command, f'Install {full_artifact_path}...', ) - return return_code - def _label_default_namespace_mldiagnostics() -> int: """ @@ -129,13 +134,11 @@ def _label_default_namespace_mldiagnostics() -> int: command = 'kubectl label namespace default managed-mldiagnostics-gke=true' - return_code = run_command_with_updates( + return run_command_with_updates( command, 'Label default namespace with managed-mldiagnostics-gke=true', ) - return return_code - def install_mldiagnostics_prerequisites() -> int: """ @@ -166,12 +169,8 @@ def install_mldiagnostics_prerequisites() -> int: xpk_print('The cert-manager-webhook installation failed.') return 1 - webhook_package = 'mldiagnostics-injection-webhook' - webhook_version = Version('v0.5.0') - webhook_filename = f'{webhook_package}-v{webhook_version}.yaml' - return_code = _download_mldiagnostics_yaml( - package_name=webhook_package, version=webhook_version + package_name=_WEBHOOK_PACKAGE, version=_WEBHOOK_VERSION ) if return_code != 0: return return_code @@ -180,7 +179,7 @@ def install_mldiagnostics_prerequisites() -> int: if return_code != 0: return return_code - return_code = _install_mldiagnostics_yaml(artifact_filename=webhook_filename) + return_code = _install_mldiagnostics_yaml(artifact_filename=_WEBHOOK_FILENAME) if return_code != 0: return return_code @@ -188,17 +187,15 @@ def install_mldiagnostics_prerequisites() -> int: if return_code != 0: return return_code - operator_package = 'mldiagnostics-connection-operator' - operator_version = Version('v0.5.0') - operator_filename = f'{operator_package}-v{operator_version}.yaml' - return_code = _download_mldiagnostics_yaml( - package_name=operator_package, version=operator_version + package_name=_OPERATOR_PACKAGE, version=_OPERATOR_VERSION ) if return_code != 0: return return_code - return_code = _install_mldiagnostics_yaml(artifact_filename=operator_filename) + return_code = _install_mldiagnostics_yaml( + artifact_filename=_OPERATOR_FILENAME + ) if return_code != 0: return return_code diff --git a/src/xpk/commands/managed_ml_diagnostics_test.py b/src/xpk/commands/managed_ml_diagnostics_test.py index 637641291..0482a376b 100644 --- a/src/xpk/commands/managed_ml_diagnostics_test.py +++ b/src/xpk/commands/managed_ml_diagnostics_test.py @@ -51,15 +51,14 @@ def mocks(mocker) -> _Mocks: commands_tester=CommandsTester( mocker, run_command_with_updates_path=( - 'xpk.commands.managed_ml_diagnostics.run_command_with_updates' + 'xpk.commands.cluster.run_command_with_updates' ), run_command_for_value_path=( - 'xpk.commands.managed_ml_diagnostics.run_command_for_value' + 'xpk.commands.cluster.run_command_for_value' ), ), ) - def construct_args(**kwargs: Any) -> Namespace: args_dict = dict( managed_mldiagnostics=False, @@ -67,13 +66,12 @@ def construct_args(**kwargs: Any) -> Namespace: args_dict.update(kwargs) return Namespace(**args_dict) - def test_install_mldiagnostics_prerequisites_commands_executed( mocks: _Mocks, mocker, ): - mocks.commands_tester.set_result_for_command( + (0, ''), 'kubectl', 'rollout', 'status', @@ -81,6 +79,7 @@ def test_install_mldiagnostics_prerequisites_commands_executed( ) mocks.commands_tester.set_result_for_command( + (0, ''), 'kubectl', 'rollout', 'status', @@ -88,6 +87,7 @@ def test_install_mldiagnostics_prerequisites_commands_executed( ) mocks.commands_tester.set_result_for_command( + (0, ''), 'kubectl', 'apply', '-f', @@ -95,6 +95,7 @@ def test_install_mldiagnostics_prerequisites_commands_executed( ) mocks.commands_tester.set_result_for_command( + (0, ''), 'gcloud', 'artifacts', 'generic', @@ -102,6 +103,7 @@ def test_install_mldiagnostics_prerequisites_commands_executed( ) mocks.commands_tester.set_result_for_command( + (0, ''), 'kubectl', 'create', 'namespace', @@ -109,6 +111,7 @@ def test_install_mldiagnostics_prerequisites_commands_executed( ) mocks.commands_tester.set_result_for_command( + (0, ''), 'kubectl', 'apply', '-f', @@ -117,6 +120,7 @@ def test_install_mldiagnostics_prerequisites_commands_executed( ) mocks.commands_tester.set_result_for_command( + (0, ''), 'kubectl', 'label', 'namespace', diff --git a/src/xpk/parser/cluster.py b/src/xpk/parser/cluster.py index 3dda4ac5d..5bcfc9e95 100644 --- a/src/xpk/parser/cluster.py +++ b/src/xpk/parser/cluster.py @@ -692,6 +692,11 @@ def add_shared_cluster_create_optional_arguments( ' regional clusters, all zones must support the machine type.' ), ) + parser_or_group.add_argument( + '--managed-mldiagnostics', + action='store_true', + help='Enables the installation of required ML Diagnostics components.', + ) parser_or_group.add_argument( '--cluster-cpu-machine-type', type=str, From e01756881394ef44d0f86a15b4b7caf64fef9801 Mon Sep 17 00:00:00 2001 From: Danny LI Date: Fri, 21 Nov 2025 04:16:34 +0000 Subject: [PATCH 4/7] test unit test --- .../commands/managed_ml_diagnostics_test.py | 33 +++++++++++++++++-- 1 file changed, 31 insertions(+), 2 deletions(-) diff --git a/src/xpk/commands/managed_ml_diagnostics_test.py b/src/xpk/commands/managed_ml_diagnostics_test.py index 0482a376b..8fefeb570 100644 --- a/src/xpk/commands/managed_ml_diagnostics_test.py +++ b/src/xpk/commands/managed_ml_diagnostics_test.py @@ -51,25 +51,53 @@ def mocks(mocker) -> _Mocks: commands_tester=CommandsTester( mocker, run_command_with_updates_path=( - 'xpk.commands.cluster.run_command_with_updates' + 'xpk.commands.managed_ml_diagnostics.run_command_with_updates' ), run_command_for_value_path=( - 'xpk.commands.cluster.run_command_for_value' + 'xpk.commands.managed_ml_diagnostics.run_command_for_value' ), ), ) + def construct_args(**kwargs: Any) -> Namespace: args_dict = dict( + project='project', + zone='us-central1-a', + reservation='', + default_pool_cpu_machine_type='test-machine-type', + cluster='test-cluster', + default_pool_cpu_num_nodes='100', + sub_slicing=False, + gke_version='', + private=False, + authorized_networks=None, + enable_pathways=False, + enable_ray_cluster=False, + enable_workload_identity=False, + enable_gcsfuse_csi_driver=False, + enable_gcpfilestore_csi_driver=False, + enable_parallelstore_csi_driver=False, + enable_pd_csi_driver=False, + enable_lustre_csi_driver=False, + custom_cluster_arguments='', + num_slices=1, + num_nodes=1, + flex=False, + memory_limit='100Gi', + cpu_limit=100, + cluster_cpu_machine_type='', managed_mldiagnostics=False, ) args_dict.update(kwargs) return Namespace(**args_dict) + def test_install_mldiagnostics_prerequisites_commands_executed( mocks: _Mocks, mocker, ): + mocks.commands_tester.set_result_for_command( (0, ''), 'kubectl', @@ -210,3 +238,4 @@ def test_install_mldiagnostics_prerequisites_commands_executed( mocks.commands_tester.assert_command_run( 'kubectl', 'apply', '-f', '-n', 'gke-mldiagnostics', times=2 ) + \ No newline at end of file From 1e680606552f31e167fe71ab1eb2120833427def Mon Sep 17 00:00:00 2001 From: Danny LI Date: Fri, 21 Nov 2025 04:22:48 +0000 Subject: [PATCH 5/7] test unit test --- .../commands/managed_ml_diagnostics_test.py | 63 +++++++++---------- 1 file changed, 31 insertions(+), 32 deletions(-) diff --git a/src/xpk/commands/managed_ml_diagnostics_test.py b/src/xpk/commands/managed_ml_diagnostics_test.py index 8fefeb570..2e3b7715b 100644 --- a/src/xpk/commands/managed_ml_diagnostics_test.py +++ b/src/xpk/commands/managed_ml_diagnostics_test.py @@ -60,37 +60,37 @@ def mocks(mocker) -> _Mocks: ) -def construct_args(**kwargs: Any) -> Namespace: - args_dict = dict( - project='project', - zone='us-central1-a', - reservation='', - default_pool_cpu_machine_type='test-machine-type', - cluster='test-cluster', - default_pool_cpu_num_nodes='100', - sub_slicing=False, - gke_version='', - private=False, - authorized_networks=None, - enable_pathways=False, - enable_ray_cluster=False, - enable_workload_identity=False, - enable_gcsfuse_csi_driver=False, - enable_gcpfilestore_csi_driver=False, - enable_parallelstore_csi_driver=False, - enable_pd_csi_driver=False, - enable_lustre_csi_driver=False, - custom_cluster_arguments='', - num_slices=1, - num_nodes=1, - flex=False, - memory_limit='100Gi', - cpu_limit=100, - cluster_cpu_machine_type='', - managed_mldiagnostics=False, - ) - args_dict.update(kwargs) - return Namespace(**args_dict) +# def construct_args(**kwargs: Any) -> Namespace: +# args_dict = dict( +# project='project', +# zone='us-central1-a', +# reservation='', +# default_pool_cpu_machine_type='test-machine-type', +# cluster='test-cluster', +# default_pool_cpu_num_nodes='100', +# sub_slicing=False, +# gke_version='', +# private=False, +# authorized_networks=None, +# enable_pathways=False, +# enable_ray_cluster=False, +# enable_workload_identity=False, +# enable_gcsfuse_csi_driver=False, +# enable_gcpfilestore_csi_driver=False, +# enable_parallelstore_csi_driver=False, +# enable_pd_csi_driver=False, +# enable_lustre_csi_driver=False, +# custom_cluster_arguments='', +# num_slices=1, +# num_nodes=1, +# flex=False, +# memory_limit='100Gi', +# cpu_limit=100, +# cluster_cpu_machine_type='', +# managed_mldiagnostics=False, +# ) +# args_dict.update(kwargs) +# return Namespace(**args_dict) def test_install_mldiagnostics_prerequisites_commands_executed( @@ -238,4 +238,3 @@ def test_install_mldiagnostics_prerequisites_commands_executed( mocks.commands_tester.assert_command_run( 'kubectl', 'apply', '-f', '-n', 'gke-mldiagnostics', times=2 ) - \ No newline at end of file From 79ccf19c9f2d8b41dbc742d0cf9f034fd59734ce Mon Sep 17 00:00:00 2001 From: Danny LI Date: Fri, 21 Nov 2025 04:33:10 +0000 Subject: [PATCH 6/7] test unit test --- .../commands/managed_ml_diagnostics_test.py | 46 +++---------------- 1 file changed, 6 insertions(+), 40 deletions(-) diff --git a/src/xpk/commands/managed_ml_diagnostics_test.py b/src/xpk/commands/managed_ml_diagnostics_test.py index 2e3b7715b..fe6b88b82 100644 --- a/src/xpk/commands/managed_ml_diagnostics_test.py +++ b/src/xpk/commands/managed_ml_diagnostics_test.py @@ -50,49 +50,15 @@ def mocks(mocker) -> _Mocks: commands_print_mock=commands_print_mock, commands_tester=CommandsTester( mocker, - run_command_with_updates_path=( - 'xpk.commands.managed_ml_diagnostics.run_command_with_updates' - ), - run_command_for_value_path=( - 'xpk.commands.managed_ml_diagnostics.run_command_for_value' - ), + # run_command_with_updates_path=( + # 'xpk.commands.managed_ml_diagnostics.run_command_with_updates' + # ), + # run_command_for_value_path=( + # 'xpk.commands.managed_ml_diagnostics.run_command_for_value' + # ), ), ) - -# def construct_args(**kwargs: Any) -> Namespace: -# args_dict = dict( -# project='project', -# zone='us-central1-a', -# reservation='', -# default_pool_cpu_machine_type='test-machine-type', -# cluster='test-cluster', -# default_pool_cpu_num_nodes='100', -# sub_slicing=False, -# gke_version='', -# private=False, -# authorized_networks=None, -# enable_pathways=False, -# enable_ray_cluster=False, -# enable_workload_identity=False, -# enable_gcsfuse_csi_driver=False, -# enable_gcpfilestore_csi_driver=False, -# enable_parallelstore_csi_driver=False, -# enable_pd_csi_driver=False, -# enable_lustre_csi_driver=False, -# custom_cluster_arguments='', -# num_slices=1, -# num_nodes=1, -# flex=False, -# memory_limit='100Gi', -# cpu_limit=100, -# cluster_cpu_machine_type='', -# managed_mldiagnostics=False, -# ) -# args_dict.update(kwargs) -# return Namespace(**args_dict) - - def test_install_mldiagnostics_prerequisites_commands_executed( mocks: _Mocks, mocker, From df4c4171c08a8eeaca92eafe846fc1c4968c7f7a Mon Sep 17 00:00:00 2001 From: Danny LI Date: Fri, 21 Nov 2025 05:01:37 +0000 Subject: [PATCH 7/7] deleted set_result_for_command --- .../commands/managed_ml_diagnostics_test.py | 74 ++----------------- src/xpk/parser/cluster.py | 5 -- src/xpk/parser/cluster_test.py | 4 +- 3 files changed, 9 insertions(+), 74 deletions(-) diff --git a/src/xpk/commands/managed_ml_diagnostics_test.py b/src/xpk/commands/managed_ml_diagnostics_test.py index fe6b88b82..48cec68e9 100644 --- a/src/xpk/commands/managed_ml_diagnostics_test.py +++ b/src/xpk/commands/managed_ml_diagnostics_test.py @@ -14,9 +14,7 @@ limitations under the License. """ -from argparse import Namespace from dataclasses import dataclass -from typing import Any from unittest.mock import MagicMock import pytest from xpk.commands.managed_ml_diagnostics import install_mldiagnostics_prerequisites @@ -50,78 +48,20 @@ def mocks(mocker) -> _Mocks: commands_print_mock=commands_print_mock, commands_tester=CommandsTester( mocker, - # run_command_with_updates_path=( - # 'xpk.commands.managed_ml_diagnostics.run_command_with_updates' - # ), - # run_command_for_value_path=( - # 'xpk.commands.managed_ml_diagnostics.run_command_for_value' - # ), + run_command_with_updates_path=( + 'xpk.commands.managed_ml_diagnostics.run_command_with_updates' + ), + run_command_for_value_path=( + 'xpk.commands.managed_ml_diagnostics.run_command_for_value' + ), ), ) + def test_install_mldiagnostics_prerequisites_commands_executed( mocks: _Mocks, - mocker, ): - mocks.commands_tester.set_result_for_command( - (0, ''), - 'kubectl', - 'rollout', - 'status', - 'deployment/kueue-controller-manager', - ) - - mocks.commands_tester.set_result_for_command( - (0, ''), - 'kubectl', - 'rollout', - 'status', - 'deployment/cert-manager-webhook', - ) - - mocks.commands_tester.set_result_for_command( - (0, ''), - 'kubectl', - 'apply', - '-f', - 'https://github.com/cert-manager/cert-manager/releases/', - ) - - mocks.commands_tester.set_result_for_command( - (0, ''), - 'gcloud', - 'artifacts', - 'generic', - 'download', - ) - - mocks.commands_tester.set_result_for_command( - (0, ''), - 'kubectl', - 'create', - 'namespace', - 'gke-mldiagnostics', - ) - - mocks.commands_tester.set_result_for_command( - (0, ''), - 'kubectl', - 'apply', - '-f', - '-n', - 'gke-mldiagnostics', - ) - - mocks.commands_tester.set_result_for_command( - (0, ''), - 'kubectl', - 'label', - 'namespace', - 'default', - 'managed-mldiagnostics-gke=true', - ) - install_mldiagnostics_prerequisites() mocks.commands_tester.assert_command_run( diff --git a/src/xpk/parser/cluster.py b/src/xpk/parser/cluster.py index 5bcfc9e95..4475ce232 100644 --- a/src/xpk/parser/cluster.py +++ b/src/xpk/parser/cluster.py @@ -910,11 +910,6 @@ def add_shared_cluster_create_capacity_arguments( ' types.' ), ) - parser_or_group.add_argument( - '--managed-ml-diagnostics', - action='store_true', - help='Enables the installation of required ML Diagnostics components.', - ) def add_shared_cluster_create_mtc_arguments( diff --git a/src/xpk/parser/cluster_test.py b/src/xpk/parser/cluster_test.py index 23a93a5c7..6af84f331 100644 --- a/src/xpk/parser/cluster_test.py +++ b/src/xpk/parser/cluster_test.py @@ -114,7 +114,7 @@ def test_cluster_create_managed_mldiagnostics(): "test-cluster", "--tpu-type", "v5p-8", - "--managed-ml-diagnostics", + "--managed-mldiagnostics", ]) - assert args.managed_ml_diagnostics is True + assert args.managed_mldiagnostics is True