Skip to content

Commit 5a35922

Browse files
clumsyazzhipa
andauthored
feat: pod overlay for kubernetes scheduler (#1067,#1068) (#1148)
Co-authored-by: Alexander Zhipa <azzhipa@amazon.com>
1 parent a01f925 commit 5a35922

File tree

2 files changed

+407
-75
lines changed

2 files changed

+407
-75
lines changed

torchx/schedulers/kubernetes_scheduler.py

Lines changed: 117 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,76 @@
2727
See the
2828
`Volcano Quickstart <https://github.com/volcano-sh/volcano>`_
2929
for more information.
30+
31+
Pod Overlay
32+
===========
33+
34+
You can overlay arbitrary Kubernetes Pod fields on generated pods by setting
35+
the ``kubernetes`` metadata on your role. The value can be:
36+
37+
- A dict with the overlay structure
38+
- A resource URI pointing to a YAML file (e.g. ``file://``, ``s3://``, ``gs://``)
39+
40+
Merge semantics:
41+
- **dict**: recursive merge (upsert)
42+
- **list**: append by default, replace if tuple (Python) or ``!!python/tuple`` tag (YAML)
43+
- **primitives**: replace
44+
45+
.. code:: python
46+
47+
from torchx.specs import Role
48+
49+
# Dict overlay - lists append, tuples replace
50+
role = Role(
51+
name="trainer",
52+
image="my-image:latest",
53+
entrypoint="train.py",
54+
metadata={
55+
"kubernetes": {
56+
"spec": {
57+
"nodeSelector": {"gpu": "true"},
58+
"tolerations": [{"key": "nvidia.com/gpu", "operator": "Exists"}], # appends
59+
"volumes": ({"name": "my-volume", "emptyDir": {}},) # replaces
60+
}
61+
}
62+
}
63+
)
64+
65+
# File URI overlay
66+
role = Role(
67+
name="trainer",
68+
image="my-image:latest",
69+
entrypoint="train.py",
70+
metadata={
71+
"kubernetes": "file:///path/to/pod_overlay.yaml"
72+
}
73+
)
74+
75+
CLI usage with builtin components:
76+
77+
.. code:: bash
78+
79+
$ torchx run --scheduler kubernetes dist.ddp \\
80+
--metadata kubernetes=file:///path/to/pod_overlay.yaml \\
81+
--script train.py
82+
83+
Example ``pod_overlay.yaml``:
84+
85+
.. code:: yaml
86+
87+
spec:
88+
nodeSelector:
89+
node.kubernetes.io/instance-type: p4d.24xlarge
90+
tolerations:
91+
- key: nvidia.com/gpu
92+
operator: Exists
93+
effect: NoSchedule
94+
volumes: !!python/tuple
95+
- name: my-volume
96+
emptyDir: {}
97+
98+
The overlay is deep-merged with the generated pod, preserving existing fields
99+
and adding or overriding specified ones.
30100
"""
31101

32102
import json
@@ -45,6 +115,7 @@
45115
Tuple,
46116
TYPE_CHECKING,
47117
TypedDict,
118+
Union,
48119
)
49120

50121
import torchx
@@ -97,6 +168,40 @@
97168
RESERVED_MILLICPU = 100
98169
RESERVED_MEMMB = 1024
99170

171+
172+
def _apply_pod_overlay(pod: "V1Pod", overlay: Dict[str, Any]) -> None:
173+
"""Apply overlay dict to V1Pod object, merging nested fields.
174+
175+
Merge semantics:
176+
- dict: upsert (recursive merge)
177+
- list: append by default, replace if tuple
178+
- primitives: replace
179+
"""
180+
from kubernetes import client
181+
182+
api = client.ApiClient()
183+
pod_dict = api.sanitize_for_serialization(pod)
184+
185+
def deep_merge(base: Dict[str, Any], overlay: Dict[str, Any]) -> None:
186+
for key, value in overlay.items():
187+
if isinstance(value, dict) and key in base and isinstance(base[key], dict):
188+
deep_merge(base[key], value)
189+
elif isinstance(value, tuple):
190+
base[key] = list(value)
191+
elif (
192+
isinstance(value, list) and key in base and isinstance(base[key], list)
193+
):
194+
base[key].extend(value)
195+
else:
196+
base[key] = value
197+
198+
deep_merge(pod_dict, overlay)
199+
200+
merged_pod = api._ApiClient__deserialize(pod_dict, "V1Pod")
201+
pod.spec = merged_pod.spec
202+
pod.metadata = merged_pod.metadata
203+
204+
100205
RETRY_POLICIES: Mapping[str, Iterable[Mapping[str, str]]] = {
101206
RetryPolicy.REPLICA: [],
102207
RetryPolicy.APPLICATION: [
@@ -402,6 +507,17 @@ def app_to_resource(
402507
replica_role.env["TORCHX_IMAGE"] = replica_role.image
403508

404509
pod = role_to_pod(name, replica_role, service_account)
510+
if k8s_metadata := role.metadata.get("kubernetes"):
511+
if isinstance(k8s_metadata, str):
512+
import fsspec
513+
514+
with fsspec.open(k8s_metadata, "r") as f:
515+
k8s_metadata = yaml.unsafe_load(f)
516+
elif not isinstance(k8s_metadata, dict):
517+
raise ValueError(
518+
f"metadata['kubernetes'] must be a dict or resource URI, got {type(k8s_metadata)}"
519+
)
520+
_apply_pod_overlay(pod, k8s_metadata)
405521
pod.metadata.labels.update(
406522
pod_labels(
407523
app=app,
@@ -637,7 +753,7 @@ def schedule(self, dryrun_info: AppDryRunInfo[KubernetesJob]) -> str:
637753
else:
638754
raise
639755

640-
return f'{namespace}:{resp["metadata"]["name"]}'
756+
return f"{namespace}:{resp['metadata']['name']}"
641757

642758
def _submit_dryrun(
643759
self, app: AppDef, cfg: KubernetesOpts

0 commit comments

Comments
 (0)