Commit 483f7af

feat: pod overlay for kubernetes scheduler (#1067,#1068)
1 parent 1d26b39 commit 483f7af

2 files changed: +307 -1 lines changed

torchx/schedulers/kubernetes_scheduler.py

Lines changed: 113 additions & 1 deletion

@@ -27,6 +27,90 @@
 See the
 `Volcano Quickstart <https://github.com/volcano-sh/volcano>`_
 for more information.
+
+Pod Overlay
+===========
+
+You can overlay arbitrary Kubernetes Pod fields on generated pods by providing
+a callable that receives the generated pod dict and returns the modified pod dict.
+
+.. code:: python
+
+    from torchx.specs import AppDef, Role
+
+    # Simple merge - replaces lists
+    role = Role(
+        name="trainer",
+        image="my-image:latest",
+        entrypoint="train.py",
+        metadata={
+            "kubernetes": lambda pod: {
+                **pod,
+                "spec": {
+                    **pod["spec"],
+                    "nodeSelector": {**pod["spec"].get("nodeSelector", {}), "gpu": "true"},
+                    "tolerations": [{"key": "nvidia.com/gpu", "operator": "Exists"}]
+                }
+            }
+        }
+    )
+
+    # Append to lists
+    role = Role(
+        name="trainer",
+        image="my-image:latest",
+        entrypoint="train.py",
+        metadata={
+            "kubernetes": lambda pod: {
+                **pod,
+                "spec": {
+                    **pod["spec"],
+                    "tolerations": pod["spec"].get("tolerations", []) + [
+                        {"key": "nvidia.com/gpu", "operator": "Exists"}
+                    ]
+                }
+            }
+        }
+    )
+
+    # Load from YAML file
+    import yaml
+    import fsspec
+
+    with fsspec.open("file:///path/to/overlay.yaml", "r") as f:
+        overlay_dict = yaml.safe_load(f)
+
+    role = Role(
+        name="trainer",
+        image="my-image:latest",
+        entrypoint="train.py",
+        metadata={
+            "kubernetes": lambda pod: {
+                **pod,
+                "spec": {
+                    **pod["spec"],
+                    **overlay_dict.get("spec", {}),
+                    "tolerations": pod["spec"].get("tolerations", []) +
+                        overlay_dict.get("spec", {}).get("tolerations", [])
+                }
+            }
+        }
+    )
+
+Example ``overlay.yaml``:
+
+.. code:: yaml
+
+    spec:
+      nodeSelector:
+        node.kubernetes.io/instance-type: p4d.24xlarge
+      tolerations:
+        - key: nvidia.com/gpu
+          operator: Exists
+          effect: NoSchedule
+
+The overlay callable merges its changes into the generated pod, preserving the
+existing fields and adding or overriding only the ones it specifies.
 """
 
 import json
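
The docstring examples above repeat the same spread-merge pattern by hand. When several
roles share one overlay, that pattern can be factored into a recursive helper; the sketch
below is illustrative only (deep_merge is a hypothetical helper, not part of this commit)
and concatenates lists, matching the "append" example:

    import copy
    from typing import Any, Dict

    from torchx.specs import Role

    def deep_merge(base: Dict[str, Any], overlay: Dict[str, Any]) -> Dict[str, Any]:
        """Recursively merge ``overlay`` into ``base``: dicts merge, lists concatenate."""
        merged = copy.deepcopy(base)
        for key, value in overlay.items():
            if isinstance(value, dict) and isinstance(merged.get(key), dict):
                merged[key] = deep_merge(merged[key], value)
            elif isinstance(value, list) and isinstance(merged.get(key), list):
                merged[key] = merged[key] + value
            else:
                merged[key] = value
        return merged

    # An overlay dict, e.g. loaded from a YAML file as shown above.
    overlay_dict = {"spec": {"nodeSelector": {"gpu": "true"}}}

    role = Role(
        name="trainer",
        image="my-image:latest",
        entrypoint="train.py",
        metadata={"kubernetes": lambda pod: deep_merge(pod, overlay_dict)},
    )
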
@@ -36,6 +120,7 @@
 from datetime import datetime
 from typing import (
     Any,
+    Callable,
     cast,
     Dict,
     Iterable,
@@ -45,6 +130,7 @@
     Tuple,
     TYPE_CHECKING,
     TypedDict,
+    Union,
 )
 
 import torchx
@@ -97,6 +183,30 @@
 RESERVED_MILLICPU = 100
 RESERVED_MEMMB = 1024
 
+
+def apply_pod_overlay(
+    pod: "V1Pod",
+    overlay: Callable[[Dict[str, Any]], Dict[str, Any]],
+) -> "V1Pod":
+    """Apply overlay function to V1Pod object.
+
+    Args:
+        pod: Kubernetes V1Pod object to modify
+        overlay: Callable that receives pod dict and returns modified pod dict
+
+    Returns:
+        Modified V1Pod object
+    """
+    from kubernetes import client
+
+    assert callable(overlay), f"overlay must be callable, got {type(overlay)}"
+
+    api = client.ApiClient()
+    return api._ApiClient__deserialize(
+        overlay(api.sanitize_for_serialization(pod)), "V1Pod"
+    )
+
+
 RETRY_POLICIES: Mapping[str, Iterable[Mapping[str, str]]] = {
     RetryPolicy.REPLICA: [],
     RetryPolicy.APPLICATION: [
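
For reference, apply_pod_overlay can also be called directly. A minimal sketch of the
round trip it performs (the pod construction here is illustrative, not taken from this
diff): the pod is serialized to a plain dict via the client's sanitize_for_serialization,
handed to the overlay, and the returned dict is deserialized back into a V1Pod.

    from kubernetes.client.models import V1Container, V1ObjectMeta, V1Pod, V1PodSpec

    from torchx.schedulers.kubernetes_scheduler import apply_pod_overlay

    pod = V1Pod(
        metadata=V1ObjectMeta(name="example-pod"),
        spec=V1PodSpec(
            containers=[V1Container(name="trainer", image="my-image:latest")]
        ),
    )

    # The overlay operates on the serialized dict, so it sees camelCase Kubernetes
    # field names (e.g. "nodeSelector"), not the Python attribute names.
    pod = apply_pod_overlay(
        pod,
        lambda d: {**d, "spec": {**d["spec"], "nodeSelector": {"gpu": "true"}}},
    )
    assert pod.spec.node_selector == {"gpu": "true"}
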
@@ -402,6 +512,8 @@ def app_to_resource(
             replica_role.env["TORCHX_IMAGE"] = replica_role.image
 
             pod = role_to_pod(name, replica_role, service_account)
+            if k8s_overlay := role.metadata.get("kubernetes"):
+                pod = apply_pod_overlay(pod, k8s_overlay)
             pod.metadata.labels.update(
                 pod_labels(
                     app=app,
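
Note the ordering: the overlay runs right after role_to_pod and before
pod.metadata.labels.update(...), so labels added by an overlay are kept while torchx's
own pod labels are layered on top afterwards (and win on conflicting keys). A
hypothetical overlay that adds a custom label, following the same pattern as the
docstring examples (the "team" label is illustrative, not part of this commit):

    from torchx.specs import Role

    role = Role(
        name="trainer",
        image="my-image:latest",
        entrypoint="train.py",
        metadata={
            "kubernetes": lambda pod: {
                **pod,
                "metadata": {
                    **pod["metadata"],
                    "labels": {**pod["metadata"].get("labels", {}), "team": "ml-infra"},
                },
            }
        },
    )
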
@@ -636,7 +748,7 @@ def schedule(self, dryrun_info: AppDryRunInfo[KubernetesJob]) -> str:
             else:
                 raise
 
-        return f'{namespace}:{resp["metadata"]["name"]}'
+        return f"{namespace}:{resp['metadata']['name']}"
 
     def _submit_dryrun(
         self, app: AppDef, cfg: KubernetesOpts

torchx/schedulers/test/kubernetes_scheduler_test.py

Lines changed: 194 additions & 0 deletions

@@ -929,6 +929,200 @@ def test_min_replicas(self) -> None:
         ]
         self.assertEqual(min_available, [1, 1, 0])
 
+    def test_apply_pod_overlay_merge(self) -> None:
+        from kubernetes.client.models import V1Container, V1ObjectMeta, V1Pod, V1PodSpec
+        from torchx.schedulers.kubernetes_scheduler import (  # pyre-ignore[21]
+            apply_pod_overlay,
+        )
+
+        pod = V1Pod(
+            spec=V1PodSpec(
+                containers=[V1Container(name="test", image="test:latest")],
+                node_selector={"existing": "label"},
+            ),
+            metadata=V1ObjectMeta(name="test-pod"),
+        )
+
+        overlay = lambda pod_dict: {
+            **pod_dict,
+            "spec": {
+                **pod_dict["spec"],
+                "nodeSelector": {
+                    **pod_dict["spec"].get("nodeSelector", {}),
+                    "gpu": "true",
+                },
+                "tolerations": [{"key": "nvidia.com/gpu", "operator": "Exists"}],
+            },
+        }
+
+        pod = apply_pod_overlay(pod, overlay)  # pyre-ignore[16]
+
+        self.assertEqual(pod.spec.node_selector, {"existing": "label", "gpu": "true"})
+        self.assertEqual(len(pod.spec.tolerations), 1)
+        self.assertEqual(pod.spec.tolerations[0].key, "nvidia.com/gpu")
+
+    def test_apply_pod_overlay_append_lists(self) -> None:
+        from kubernetes.client.models import V1Container, V1ObjectMeta, V1Pod, V1PodSpec
+        from torchx.schedulers.kubernetes_scheduler import apply_pod_overlay
+
+        pod = V1Pod(
+            spec=V1PodSpec(
+                containers=[V1Container(name="test", image="test:latest")],
+                tolerations=[{"key": "existing", "operator": "Exists"}],
+            ),
+            metadata=V1ObjectMeta(name="test-pod"),
+        )
+
+        overlay = lambda pod_dict: {
+            **pod_dict,
+            "spec": {
+                **pod_dict["spec"],
+                "tolerations": pod_dict["spec"].get("tolerations", [])
+                + [{"key": "nvidia.com/gpu", "operator": "Exists"}],
+            },
+        }
+
+        pod = apply_pod_overlay(pod, overlay)  # pyre-ignore[16]
+
+        self.assertEqual(len(pod.spec.tolerations), 2)
+        self.assertEqual(pod.spec.tolerations[0].key, "existing")
+        self.assertEqual(pod.spec.tolerations[1].key, "nvidia.com/gpu")
+
+    def test_submit_dryrun_with_pod_overlay(self) -> None:
+        scheduler = create_scheduler("test")
+
+        trainer_role = specs.Role(
+            name="trainer",
+            image="pytorch/torchx:latest",
+            entrypoint="main",
+            resource=specs.Resource(cpu=1, memMB=1000, gpu=0),
+            metadata={
+                "kubernetes": lambda pod: {
+                    **pod,
+                    "spec": {
+                        **pod["spec"],
+                        "nodeSelector": {
+                            **pod["spec"].get("nodeSelector", {}),
+                            "gpu": "true",
+                        },
+                    },
+                }
+            },
+        )
+        app = specs.AppDef("test", roles=[trainer_role])
+        cfg = KubernetesOpts({"queue": "testqueue"})
+
+        info = scheduler.submit_dryrun(app, cfg)
+        resource = info.request.resource
+
+        tasks = resource["spec"]["tasks"]  # pyre-ignore[16]
+        for task in tasks:
+            pod = task["template"]
+            self.assertIn("gpu", pod.spec.node_selector)
+            self.assertEqual(pod.spec.node_selector["gpu"], "true")
+
+    def test_submit_dryrun_with_pod_overlay_from_yaml(self) -> None:
+        import tempfile
+
+        import yaml
+
+        scheduler = create_scheduler("test")
+
+        overlay_dict = {"spec": {"nodeSelector": {"instance-type": "p4d.24xlarge"}}}
+        with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
+            yaml.dump(overlay_dict, f)
+            overlay_path = f.name
+
+        try:
+            import fsspec
+
+            with fsspec.open(f"file://{overlay_path}", "r") as f:
+                loaded_overlay = yaml.safe_load(f)
+
+            trainer_role = specs.Role(
+                name="trainer",
+                image="pytorch/torchx:latest",
+                entrypoint="main",
+                resource=specs.Resource(cpu=1, memMB=1000, gpu=0),
+                metadata={
+                    "kubernetes": lambda pod: {
+                        **pod,
+                        "spec": {**pod["spec"], **loaded_overlay.get("spec", {})},
+                    }
+                },
+            )
+            app = specs.AppDef("test", roles=[trainer_role])
+            cfg = KubernetesOpts({"queue": "testqueue"})
+
+            info = scheduler.submit_dryrun(app, cfg)
+            resource = info.request.resource
+
+            tasks = resource["spec"]["tasks"]  # pyre-ignore[16]
+            for task in tasks:
+                pod = task["template"]
+                self.assertIn("instance-type", pod.spec.node_selector)
+                self.assertEqual(
+                    pod.spec.node_selector["instance-type"], "p4d.24xlarge"
+                )
+        finally:
+            import os
+
+            os.unlink(overlay_path)
+
+    def test_submit_dryrun_with_pod_overlay_append_from_yaml(self) -> None:
+        import tempfile
+
+        import yaml
+
+        scheduler = create_scheduler("test")
+
+        overlay_dict = {
+            "spec": {"tolerations": [{"key": "nvidia.com/gpu", "operator": "Exists"}]}
+        }
+        with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
+            yaml.dump(overlay_dict, f)
+            overlay_path = f.name
+
+        try:
+            import fsspec
+
+            with fsspec.open(f"file://{overlay_path}", "r") as f:
+                loaded_overlay = yaml.safe_load(f)
+
+            trainer_role = specs.Role(
+                name="trainer",
+                image="pytorch/torchx:latest",
+                entrypoint="main",
+                resource=specs.Resource(cpu=1, memMB=1000, gpu=0),
+                metadata={
+                    "kubernetes": lambda pod: {
+                        **pod,
+                        "spec": {
+                            **pod["spec"],
+                            "tolerations": pod["spec"].get("tolerations", [])
+                            + loaded_overlay.get("spec", {}).get("tolerations", []),
+                        },
+                    }
+                },
+            )
+            app = specs.AppDef("test", roles=[trainer_role])
+            cfg = KubernetesOpts({"queue": "testqueue"})
+
+            info = scheduler.submit_dryrun(app, cfg)
+            resource = info.request.resource
+
+            tasks = resource["spec"]["tasks"]  # pyre-ignore[16]
+            for task in tasks:
+                pod = task["template"]
+                self.assertIsNotNone(pod.spec.tolerations)
+                self.assertTrue(
+                    any(t.key == "nvidia.com/gpu" for t in pod.spec.tolerations)
+                )
+        finally:
+            import os
+
+            os.unlink(overlay_path)
+
 
 class KubernetesSchedulerNoImportTest(unittest.TestCase):
     """

0 commit comments
