diff --git a/torchx/schedulers/kubernetes_scheduler.py b/torchx/schedulers/kubernetes_scheduler.py index 648e020ce..c037da2c9 100644 --- a/torchx/schedulers/kubernetes_scheduler.py +++ b/torchx/schedulers/kubernetes_scheduler.py @@ -27,6 +27,76 @@ See the `Volcano Quickstart `_ for more information. + +Pod Overlay +=========== + +You can overlay arbitrary Kubernetes Pod fields on generated pods by setting +the ``kubernetes`` metadata on your role. The value can be: + +- A dict with the overlay structure +- A resource URI pointing to a YAML file (e.g. ``file://``, ``s3://``, ``gs://``) + +Merge semantics: +- **dict**: recursive merge (upsert) +- **list**: append by default, replace if tuple (Python) or ``!!python/tuple`` tag (YAML) +- **primitives**: replace + +.. code:: python + + from torchx.specs import Role + + # Dict overlay - lists append, tuples replace + role = Role( + name="trainer", + image="my-image:latest", + entrypoint="train.py", + metadata={ + "kubernetes": { + "spec": { + "nodeSelector": {"gpu": "true"}, + "tolerations": [{"key": "nvidia.com/gpu", "operator": "Exists"}], # appends + "volumes": ({"name": "my-volume", "emptyDir": {}},) # replaces + } + } + } + ) + + # File URI overlay + role = Role( + name="trainer", + image="my-image:latest", + entrypoint="train.py", + metadata={ + "kubernetes": "file:///path/to/pod_overlay.yaml" + } + ) + +CLI usage with builtin components: + +.. code:: bash + + $ torchx run --scheduler kubernetes dist.ddp \\ + --metadata kubernetes=file:///path/to/pod_overlay.yaml \\ + --script train.py + +Example ``pod_overlay.yaml``: + +.. code:: yaml + + spec: + nodeSelector: + node.kubernetes.io/instance-type: p4d.24xlarge + tolerations: + - key: nvidia.com/gpu + operator: Exists + effect: NoSchedule + volumes: !!python/tuple + - name: my-volume + emptyDir: {} + +The overlay is deep-merged with the generated pod, preserving existing fields +and adding or overriding specified ones. """ import json @@ -45,6 +115,7 @@ Tuple, TYPE_CHECKING, TypedDict, + Union, ) import torchx @@ -97,6 +168,40 @@ RESERVED_MILLICPU = 100 RESERVED_MEMMB = 1024 + +def _apply_pod_overlay(pod: "V1Pod", overlay: Dict[str, Any]) -> None: + """Apply overlay dict to V1Pod object, merging nested fields. 
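+
+    Illustrative example (hypothetical values): merging a pod whose serialized spec
+    contains ``{"nodeSelector": {"a": "1"}, "tolerations": [t1]}`` with an overlay of
+    ``{"spec": {"nodeSelector": {"b": "2"}, "tolerations": [t2]}}`` yields
+    ``nodeSelector == {"a": "1", "b": "2"}`` and ``tolerations == [t1, t2]``;
+    passing the tolerations as a tuple ``(t2,)`` would replace the list instead.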
+ + Merge semantics: + - dict: upsert (recursive merge) + - list: append by default, replace if tuple + - primitives: replace + """ + from kubernetes import client + + api = client.ApiClient() + pod_dict = api.sanitize_for_serialization(pod) + + def deep_merge(base: Dict[str, Any], overlay: Dict[str, Any]) -> None: + for key, value in overlay.items(): + if isinstance(value, dict) and key in base and isinstance(base[key], dict): + deep_merge(base[key], value) + elif isinstance(value, tuple): + base[key] = list(value) + elif ( + isinstance(value, list) and key in base and isinstance(base[key], list) + ): + base[key].extend(value) + else: + base[key] = value + + deep_merge(pod_dict, overlay) + + merged_pod = api._ApiClient__deserialize(pod_dict, "V1Pod") + pod.spec = merged_pod.spec + pod.metadata = merged_pod.metadata + + RETRY_POLICIES: Mapping[str, Iterable[Mapping[str, str]]] = { RetryPolicy.REPLICA: [], RetryPolicy.APPLICATION: [ @@ -402,6 +507,17 @@ def app_to_resource( replica_role.env["TORCHX_IMAGE"] = replica_role.image pod = role_to_pod(name, replica_role, service_account) + if k8s_metadata := role.metadata.get("kubernetes"): + if isinstance(k8s_metadata, str): + import fsspec + + with fsspec.open(k8s_metadata, "r") as f: + k8s_metadata = yaml.unsafe_load(f) + elif not isinstance(k8s_metadata, dict): + raise ValueError( + f"metadata['kubernetes'] must be a dict or resource URI, got {type(k8s_metadata)}" + ) + _apply_pod_overlay(pod, k8s_metadata) pod.metadata.labels.update( pod_labels( app=app, @@ -637,7 +753,7 @@ def schedule(self, dryrun_info: AppDryRunInfo[KubernetesJob]) -> str: else: raise - return f'{namespace}:{resp["metadata"]["name"]}' + return f"{namespace}:{resp['metadata']['name']}" def _submit_dryrun( self, app: AppDef, cfg: KubernetesOpts diff --git a/torchx/schedulers/test/kubernetes_scheduler_test.py b/torchx/schedulers/test/kubernetes_scheduler_test.py index dee7a6f9e..d3f2c2346 100644 --- a/torchx/schedulers/test/kubernetes_scheduler_test.py +++ b/torchx/schedulers/test/kubernetes_scheduler_test.py @@ -94,6 +94,18 @@ def _test_app(num_replicas: int = 1) -> specs.AppDef: class KubernetesSchedulerTest(unittest.TestCase): + def setUp(self) -> None: + # Mock create_namespaced_custom_object for validation calls in submit_dryrun + # This prevents tests from calling real k8s endpoint during validation + self.mock_create_patcher = patch( + "kubernetes.client.CustomObjectsApi.create_namespaced_custom_object" + ) + self.mock_create = self.mock_create_patcher.start() + self.mock_create.return_value = {} + + def tearDown(self) -> None: + self.mock_create_patcher.stop() + def test_create_scheduler(self) -> None: client = MagicMock() docker_client = MagicMock @@ -247,11 +259,7 @@ def test_role_to_pod(self) -> None: want, ) - @patch( - "torchx.schedulers.kubernetes_scheduler.KubernetesScheduler._custom_objects_api" - ) - def test_submit_dryrun(self, mock_api: MagicMock) -> None: - mock_api.return_value.create_namespaced_custom_object.return_value = {} + def test_submit_dryrun(self) -> None: scheduler = create_scheduler("test") app = _test_app() cfg = KubernetesOpts({"queue": "testqueue"}) @@ -262,8 +270,8 @@ def test_submit_dryrun(self, mock_api: MagicMock) -> None: info = scheduler.submit_dryrun(app, cfg) resource = str(info.request) - mock_api.return_value.create_namespaced_custom_object.assert_called_once() - call_kwargs = mock_api.return_value.create_namespaced_custom_object.call_args[1] + self.mock_create.assert_called_once() + call_kwargs = self.mock_create.call_args[1] 
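+        # dry_run="All" instructs the Kubernetes API server to validate the
+        # request without persisting the object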
self.assertEqual(call_kwargs["dry_run"], "All") print(resource) @@ -508,11 +516,7 @@ def test_instance_type(self) -> None: }, ) - @patch( - "torchx.schedulers.kubernetes_scheduler.KubernetesScheduler._custom_objects_api" - ) - def test_rank0_env(self, mock_api: MagicMock) -> None: - mock_api.return_value.create_namespaced_custom_object.return_value = {} + def test_rank0_env(self) -> None: from kubernetes.client.models import V1EnvVar scheduler = create_scheduler("test") @@ -535,16 +539,12 @@ def test_rank0_env(self, mock_api: MagicMock) -> None: ) container1 = tasks[1]["template"].spec.containers[0] self.assertIn("VC_TRAINERFOO_0_HOSTS", container1.command) - mock_api.return_value.create_namespaced_custom_object.assert_called_once() - call_kwargs = mock_api.return_value.create_namespaced_custom_object.call_args[1] + self.mock_create.assert_called_once() + call_kwargs = self.mock_create.call_args[1] self.assertEqual(call_kwargs["dry_run"], "All") self.assertEqual(call_kwargs["namespace"], "default") - @patch( - "torchx.schedulers.kubernetes_scheduler.KubernetesScheduler._custom_objects_api" - ) - def test_submit_dryrun_patch(self, mock_api: MagicMock) -> None: - mock_api.return_value.create_namespaced_custom_object.return_value = {} + def test_submit_dryrun_patch(self) -> None: scheduler = create_scheduler("test") app = _test_app() app.roles[0].image = "sha256:testhash" @@ -570,15 +570,11 @@ def test_submit_dryrun_patch(self, mock_api: MagicMock) -> None: ), }, ) - mock_api.return_value.create_namespaced_custom_object.assert_called_once() - call_kwargs = mock_api.return_value.create_namespaced_custom_object.call_args[1] + self.mock_create.assert_called_once() + call_kwargs = self.mock_create.call_args[1] self.assertEqual(call_kwargs["dry_run"], "All") - @patch( - "torchx.schedulers.kubernetes_scheduler.KubernetesScheduler._custom_objects_api" - ) - def test_submit_dryrun_service_account(self, mock_api: MagicMock) -> None: - mock_api.return_value.create_namespaced_custom_object.return_value = {} + def test_submit_dryrun_service_account(self) -> None: scheduler = create_scheduler("test") self.assertIn("service_account", scheduler.run_opts()._opts) app = _test_app() @@ -595,17 +591,11 @@ def test_submit_dryrun_service_account(self, mock_api: MagicMock) -> None: info = scheduler.submit_dryrun(app, cfg) self.assertIn("service_account_name': None", str(info.request.resource)) - self.assertEqual( - mock_api.return_value.create_namespaced_custom_object.call_count, 2 - ) - call_kwargs = mock_api.return_value.create_namespaced_custom_object.call_args[1] + self.assertEqual(self.mock_create.call_count, 2) + call_kwargs = self.mock_create.call_args[1] self.assertEqual(call_kwargs["dry_run"], "All") - @patch( - "torchx.schedulers.kubernetes_scheduler.KubernetesScheduler._custom_objects_api" - ) - def test_submit_dryrun_priority_class(self, mock_api: MagicMock) -> None: - mock_api.return_value.create_namespaced_custom_object.return_value = {} + def test_submit_dryrun_priority_class(self) -> None: scheduler = create_scheduler("test") self.assertIn("priority_class", scheduler.run_opts()._opts) app = _test_app() @@ -623,10 +613,8 @@ def test_submit_dryrun_priority_class(self, mock_api: MagicMock) -> None: info = scheduler.submit_dryrun(app, cfg) self.assertNotIn("'priorityClassName'", str(info.request.resource)) - self.assertEqual( - mock_api.return_value.create_namespaced_custom_object.call_count, 2 - ) - call_kwargs = mock_api.return_value.create_namespaced_custom_object.call_args[1] + 
self.assertEqual(self.mock_create.call_count, 2) + call_kwargs = self.mock_create.call_args[1] self.assertEqual(call_kwargs["dry_run"], "All") @patch("kubernetes.client.CustomObjectsApi.create_namespaced_custom_object") @@ -970,6 +958,9 @@ def test_log_iter(self, read_namespaced_pod_log: MagicMock) -> None: ) def test_push_patches(self) -> None: + # Configure mock to return proper response for schedule() call + self.mock_create.return_value = {"metadata": {"name": "testjob"}} + client = MagicMock() scheduler = KubernetesScheduler( "foo", @@ -999,22 +990,18 @@ def test_min_replicas(self) -> None: min_available = [task["minAvailable"] for task in resource["spec"]["tasks"]] self.assertEqual(min_available, [1, 1, 0]) - @patch( - "torchx.schedulers.kubernetes_scheduler.KubernetesScheduler._custom_objects_api" - ) - def test_validate_spec_invalid_name(self, mock_api: MagicMock) -> None: + def test_validate_spec_invalid_name(self) -> None: from kubernetes.client.rest import ApiException scheduler = create_scheduler("test") app = _test_app() app.name = "Invalid_Name" - mock_api_instance = MagicMock() - mock_api_instance.create_namespaced_custom_object.side_effect = ApiException( + # Override the default mock behavior for this test + self.mock_create.side_effect = ApiException( status=422, reason="Invalid", ) - mock_api.return_value = mock_api_instance cfg = cast(KubernetesOpts, {"queue": "testqueue", "validate_spec": True}) @@ -1022,8 +1009,8 @@ def test_validate_spec_invalid_name(self, mock_api: MagicMock) -> None: scheduler.submit_dryrun(app, cfg) self.assertIn("Invalid job spec", str(ctx.exception)) - mock_api_instance.create_namespaced_custom_object.assert_called_once() - call_kwargs = mock_api_instance.create_namespaced_custom_object.call_args[1] + self.mock_create.assert_called_once() + call_kwargs = self.mock_create.call_args[1] self.assertEqual(call_kwargs["dry_run"], "All") def test_validate_spec_disabled(self) -> None: @@ -1032,55 +1019,284 @@ def test_validate_spec_disabled(self) -> None: cfg = KubernetesOpts({"queue": "testqueue", "validate_spec": False}) - with patch( - "torchx.schedulers.kubernetes_scheduler.KubernetesScheduler._custom_objects_api" - ) as mock_api: - mock_api_instance = MagicMock() - mock_api_instance.create_namespaced_custom_object.return_value = {} - mock_api.return_value = mock_api_instance - - info = scheduler.submit_dryrun(app, cfg) + info = scheduler.submit_dryrun(app, cfg) self.assertIsNotNone(info) - mock_api_instance.create_namespaced_custom_object.assert_not_called() + self.mock_create.assert_not_called() - @patch( - "torchx.schedulers.kubernetes_scheduler.KubernetesScheduler._custom_objects_api" - ) - def test_validate_spec_invalid_task_name(self, mock_api: MagicMock) -> None: + def test_validate_spec_invalid_task_name(self) -> None: from kubernetes.client.rest import ApiException scheduler = create_scheduler("test") app = _test_app() app.roles[0].name = "Invalid-Task-Name" - mock_api_instance = MagicMock() - mock_api_instance.create_namespaced_custom_object.side_effect = ApiException( + # Override the default mock behavior for this test + self.mock_create.side_effect = ApiException( status=422, reason="Invalid", ) - mock_api.return_value = mock_api_instance cfg = cast(KubernetesOpts, {"queue": "testqueue", "validate_spec": True}) - with self.assertRaises(ValueError) as ctx: scheduler.submit_dryrun(app, cfg) + self.assertIn("Invalid job spec", str(ctx.exception)) - self.assertIn("Invalid job spec", str(ctx.exception)) + def test_apply_pod_overlay(self) -> 
None: + from kubernetes.client.models import V1Container, V1ObjectMeta, V1Pod, V1PodSpec + from torchx.schedulers.kubernetes_scheduler import _apply_pod_overlay - @patch( - "torchx.schedulers.kubernetes_scheduler.KubernetesScheduler._custom_objects_api" - ) - def test_validate_spec_long_pod_name(self, mock_api: MagicMock) -> None: + pod = V1Pod( + spec=V1PodSpec( + containers=[V1Container(name="test", image="test:latest")], + node_selector={"existing": "label"}, + ), + metadata=V1ObjectMeta(name="test-pod"), + ) + + overlay = { + "spec": { + "nodeSelector": {"gpu": "true"}, + "tolerations": [{"key": "nvidia.com/gpu", "operator": "Exists"}], + } + } + + _apply_pod_overlay(pod, overlay) + + self.assertEqual(pod.spec.node_selector, {"existing": "label", "gpu": "true"}) + self.assertEqual(len(pod.spec.tolerations), 1) + self.assertEqual(pod.spec.tolerations[0].key, "nvidia.com/gpu") + + def test_apply_pod_overlay_new_fields(self) -> None: + from kubernetes.client.models import V1Container, V1ObjectMeta, V1Pod, V1PodSpec + from torchx.schedulers.kubernetes_scheduler import _apply_pod_overlay + + # Pod without nodeSelector or tolerations + pod = V1Pod( + spec=V1PodSpec(containers=[V1Container(name="test", image="test:latest")]), + metadata=V1ObjectMeta(name="test-pod"), + ) + + # Overlay adds fields not present in original + overlay = { + "spec": { + "nodeSelector": {"gpu": "true"}, + "tolerations": [{"key": "nvidia.com/gpu", "operator": "Exists"}], + "affinity": { + "nodeAffinity": { + "requiredDuringSchedulingIgnoredDuringExecution": { + "nodeSelectorTerms": [ + { + "matchExpressions": [ + { + "key": "gpu", + "operator": "In", + "values": ["true"], + } + ] + } + ] + } + } + }, + } + } + + _apply_pod_overlay(pod, overlay) + + self.assertEqual(pod.spec.node_selector, {"gpu": "true"}) + self.assertEqual(len(pod.spec.tolerations), 1) + self.assertIsNotNone(pod.spec.affinity) + self.assertIsNotNone(pod.spec.affinity.node_affinity) + + def test_submit_dryrun_with_pod_overlay(self) -> None: + scheduler = create_scheduler("test") + + # Create app with metadata + trainer_role = specs.Role( + name="trainer", + image="pytorch/torchx:latest", + entrypoint="main", + resource=specs.Resource(cpu=1, memMB=1000, gpu=0), + metadata={"kubernetes": {"spec": {"nodeSelector": {"gpu": "true"}}}}, + ) + app = specs.AppDef("test", roles=[trainer_role]) + cfg = KubernetesOpts({"queue": "testqueue"}) + + info = scheduler.submit_dryrun(app, cfg) + resource = info.request.resource + + # Check that overlay was applied to all pods + tasks = resource["spec"]["tasks"] + for task in tasks: + pod = task["template"] + self.assertIn("gpu", pod.spec.node_selector) + self.assertEqual(pod.spec.node_selector["gpu"], "true") + + def test_submit_dryrun_with_pod_overlay_file_uri(self) -> None: + import tempfile + + import yaml + + scheduler = create_scheduler("test") + + # Create overlay file + overlay = {"spec": {"nodeSelector": {"instance-type": "p4d.24xlarge"}}} + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: + yaml.dump(overlay, f) + overlay_path = f.name + + try: + # Create app with file URI + trainer_role = specs.Role( + name="trainer", + image="pytorch/torchx:latest", + entrypoint="main", + resource=specs.Resource(cpu=1, memMB=1000, gpu=0), + metadata={"kubernetes": f"file://{overlay_path}"}, + ) + app = specs.AppDef("test", roles=[trainer_role]) + cfg = KubernetesOpts({"queue": "testqueue"}) + + info = scheduler.submit_dryrun(app, cfg) + resource = info.request.resource + + # Check that overlay 
was applied + tasks = resource["spec"]["tasks"] + for task in tasks: + pod = task["template"] + self.assertIn("instance-type", pod.spec.node_selector) + self.assertEqual( + pod.spec.node_selector["instance-type"], "p4d.24xlarge" + ) + finally: + import os + + os.unlink(overlay_path) + + def test_apply_pod_overlay_list_append(self) -> None: + from kubernetes.client.models import ( + V1Container, + V1ObjectMeta, + V1Pod, + V1PodSpec, + V1Toleration, + ) + from torchx.schedulers.kubernetes_scheduler import _apply_pod_overlay + + pod = V1Pod( + spec=V1PodSpec( + containers=[V1Container(name="test", image="test:latest")], + tolerations=[V1Toleration(key="existing", operator="Exists")], + ), + metadata=V1ObjectMeta(name="test-pod"), + ) + + overlay = { + "spec": { + "tolerations": [{"key": "nvidia.com/gpu", "operator": "Exists"}], + } + } + + _apply_pod_overlay(pod, overlay) + + self.assertEqual(len(pod.spec.tolerations), 2) + self.assertEqual(pod.spec.tolerations[0].key, "existing") + self.assertEqual(pod.spec.tolerations[1].key, "nvidia.com/gpu") + + def test_apply_pod_overlay_list_replace_tuple(self) -> None: + from kubernetes.client.models import ( + V1Container, + V1ObjectMeta, + V1Pod, + V1PodSpec, + V1Toleration, + ) + from torchx.schedulers.kubernetes_scheduler import _apply_pod_overlay + + pod = V1Pod( + spec=V1PodSpec( + containers=[V1Container(name="test", image="test:latest")], + tolerations=[V1Toleration(key="existing", operator="Exists")], + ), + metadata=V1ObjectMeta(name="test-pod"), + ) + + overlay = { + "spec": { + "tolerations": ({"key": "nvidia.com/gpu", "operator": "Exists"},), + } + } + + _apply_pod_overlay(pod, overlay) + + self.assertEqual(len(pod.spec.tolerations), 1) + self.assertEqual(pod.spec.tolerations[0].key, "nvidia.com/gpu") + + def test_submit_dryrun_with_pod_overlay_yaml_replace_tag(self) -> None: + import tempfile + + import yaml + + scheduler = create_scheduler("test") + + overlay_yaml = """ +spec: + tolerations: !!python/tuple + - key: nvidia.com/gpu + operator: Exists +""" + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: + f.write(overlay_yaml) + overlay_path = f.name + + try: + trainer_role = specs.Role( + name="trainer", + image="pytorch/torchx:latest", + entrypoint="main", + resource=specs.Resource(cpu=1, memMB=1000, gpu=0), + metadata={"kubernetes": f"file://{overlay_path}"}, + ) + app = specs.AppDef("test", roles=[trainer_role]) + cfg = KubernetesOpts({"queue": "testqueue"}) + + info = scheduler.submit_dryrun(app, cfg) + resource = info.request.resource + + tasks = resource["spec"]["tasks"] + for task in tasks: + pod = task["template"] + self.assertEqual(len(pod.spec.tolerations), 1) + self.assertEqual(pod.spec.tolerations[0].key, "nvidia.com/gpu") + finally: + import os + + os.unlink(overlay_path) + + def test_submit_dryrun_with_pod_overlay_invalid_type(self) -> None: + scheduler = create_scheduler("test") + + # Create app with invalid metadata type + trainer_role = specs.Role( + name="trainer", + image="pytorch/torchx:latest", + entrypoint="main", + resource=specs.Resource(cpu=1, memMB=1000, gpu=0), + metadata={"kubernetes": 123}, # Invalid type + ) + app = specs.AppDef("test", roles=[trainer_role]) + cfg = KubernetesOpts({"queue": "testqueue"}) + + with self.assertRaises(ValueError) as ctx: + scheduler.submit_dryrun(app, cfg) + + def test_validate_spec_long_pod_name(self) -> None: scheduler = create_scheduler("test") app = _test_app() app.name = "x" * 50 app.roles[0].name = "y" * 20 - mock_api_instance = MagicMock() - 
mock_api_instance.create_namespaced_custom_object.return_value = {} - mock_api.return_value = mock_api_instance - cfg = cast(KubernetesOpts, {"queue": "testqueue", "validate_spec": True}) with patch(