|
27 | 27 | See the |
28 | 28 | `Volcano Quickstart <https://github.com/volcano-sh/volcano>`_ |
29 | 29 | for more information. |
| 30 | +
|
| 31 | +Pod Overlay |
| 32 | +=========== |
| 33 | +
|
| 34 | +You can overlay arbitrary Kubernetes Pod fields on generated pods by setting |
| 35 | +the ``kubernetes`` metadata on your role. The value can be: |
| 36 | +
|
| 37 | +- A dict with the overlay structure |
| 38 | +- A resource URI pointing to a YAML file (e.g. ``file://``, ``s3://``, ``gs://``) |
| 39 | +
|
Merge semantics:

- **dict**: recursive merge (upsert)
- **list**: append by default, replace if tuple (Python) or ``!!python/tuple`` tag (YAML)
- **primitives**: replace
| 44 | +
|
| 45 | +.. code:: python |
| 46 | +
|
| 47 | + from torchx.specs import Role |
| 48 | +
|
| 49 | + # Dict overlay - lists append, tuples replace |
| 50 | + role = Role( |
| 51 | + name="trainer", |
| 52 | + image="my-image:latest", |
| 53 | + entrypoint="train.py", |
| 54 | + metadata={ |
| 55 | + "kubernetes": { |
| 56 | + "spec": { |
| 57 | + "nodeSelector": {"gpu": "true"}, |
| 58 | + "tolerations": [{"key": "nvidia.com/gpu", "operator": "Exists"}], # appends |
| 59 | + "volumes": ({"name": "my-volume", "emptyDir": {}},) # replaces |
| 60 | + } |
| 61 | + } |
| 62 | + } |
| 63 | + ) |
| 64 | +
|
| 65 | + # File URI overlay |
| 66 | + role = Role( |
| 67 | + name="trainer", |
| 68 | + image="my-image:latest", |
| 69 | + entrypoint="train.py", |
| 70 | + metadata={ |
| 71 | + "kubernetes": "file:///path/to/pod_overlay.yaml" |
| 72 | + } |
| 73 | + ) |
| 74 | +
|
| 75 | +CLI usage with builtin components: |
| 76 | +
|
| 77 | +.. code:: bash |
| 78 | +
|
| 79 | + $ torchx run --scheduler kubernetes dist.ddp \\ |
| 80 | + --metadata kubernetes=file:///path/to/pod_overlay.yaml \\ |
| 81 | + --script train.py |
| 82 | +
|
| 83 | +Example ``pod_overlay.yaml``: |
| 84 | +
|
| 85 | +.. code:: yaml |
| 86 | +
|
| 87 | + spec: |
| 88 | + nodeSelector: |
| 89 | + node.kubernetes.io/instance-type: p4d.24xlarge |
| 90 | + tolerations: |
| 91 | + - key: nvidia.com/gpu |
| 92 | + operator: Exists |
| 93 | + effect: NoSchedule |
| 94 | + volumes: !!python/tuple |
| 95 | + - name: my-volume |
| 96 | + emptyDir: {} |
| 97 | +
|
| 98 | +The overlay is deep-merged with the generated pod, preserving existing fields |
| 99 | +and adding or overriding specified ones. |
30 | 100 | """ |
31 | 101 |
|
32 | 102 | import json |
|
45 | 115 | Tuple, |
46 | 116 | TYPE_CHECKING, |
47 | 117 | TypedDict, |
| 118 | + Union, |
48 | 119 | ) |
49 | 120 |
|
50 | 121 | import torchx |
|
97 | 168 | RESERVED_MILLICPU = 100 |
98 | 169 | RESERVED_MEMMB = 1024 |
99 | 170 |
|
| 171 | + |
def _apply_pod_overlay(pod: "V1Pod", overlay: Dict[str, Any]) -> None:
    """Apply ``overlay`` onto ``pod`` in place, deep-merging nested fields.

    The pod is round-tripped through the kubernetes client's serializer so
    that the overlay can be written with plain Kubernetes wire-format field
    names (e.g. ``nodeSelector``) instead of the python attribute names of
    the ``V1Pod`` model classes.

    Merge semantics:
      - dict: recursive merge (upsert)
      - list: append to the existing list; a tuple replaces it wholesale
      - primitives: replace

    Args:
        pod: the generated pod to mutate; only ``spec`` and ``metadata``
            are rewritten from the merged result.
        overlay: overlay structure keyed by Kubernetes field names.
    """
    from kubernetes import client

    api = client.ApiClient()
    # Typed V1Pod model -> plain dict keyed by camelCase wire names.
    pod_dict = api.sanitize_for_serialization(pod)

    def deep_merge(base: Dict[str, Any], patch: Dict[str, Any]) -> None:
        # ``patch`` (not ``overlay``) to avoid shadowing the outer argument.
        for key, value in patch.items():
            if isinstance(value, dict) and isinstance(base.get(key), dict):
                deep_merge(base[key], value)
            elif isinstance(value, tuple):
                # A tuple signals "replace the whole list" (see module docs).
                base[key] = list(value)
            elif isinstance(value, list) and isinstance(base.get(key), list):
                base[key].extend(value)
            else:
                base[key] = value

    deep_merge(pod_dict, overlay)

    # NOTE: ApiClient exposes no public dict->model API; this calls the
    # name-mangled private deserializer and may need updating if the
    # kubernetes client changes its internals.
    merged_pod = api._ApiClient__deserialize(pod_dict, "V1Pod")
    pod.spec = merged_pod.spec
    pod.metadata = merged_pod.metadata
| 204 | + |
100 | 205 | RETRY_POLICIES: Mapping[str, Iterable[Mapping[str, str]]] = { |
101 | 206 | RetryPolicy.REPLICA: [], |
102 | 207 | RetryPolicy.APPLICATION: [ |
@@ -402,6 +507,17 @@ def app_to_resource( |
402 | 507 | replica_role.env["TORCHX_IMAGE"] = replica_role.image |
403 | 508 |
|
404 | 509 | pod = role_to_pod(name, replica_role, service_account) |
        # Optional per-role pod overlay: the "kubernetes" metadata entry may be
        # an overlay dict or a resource URI (file://, s3://, ...) pointing at a
        # YAML overlay file.  Merge rules are described in the module docstring.
        if k8s_metadata := role.metadata.get("kubernetes"):
            if isinstance(k8s_metadata, str):
                import fsspec

                with fsspec.open(k8s_metadata, "r") as f:
                    # unsafe_load is needed so the ``!!python/tuple`` tag can
                    # mark list replacement.  NOTE(review): it also executes
                    # arbitrary python object tags, so the overlay file must
                    # come from a trusted source.
                    k8s_metadata = yaml.unsafe_load(f)
            elif not isinstance(k8s_metadata, dict):
                raise ValueError(
                    f"metadata['kubernetes'] must be a dict or resource URI, got {type(k8s_metadata)}"
                )
            _apply_pod_overlay(pod, k8s_metadata)
405 | 521 | pod.metadata.labels.update( |
406 | 522 | pod_labels( |
407 | 523 | app=app, |
@@ -637,7 +753,7 @@ def schedule(self, dryrun_info: AppDryRunInfo[KubernetesJob]) -> str: |
637 | 753 | else: |
638 | 754 | raise |
639 | 755 |
|
640 | | - return f'{namespace}:{resp["metadata"]["name"]}' |
| 756 | + return f"{namespace}:{resp['metadata']['name']}" |
641 | 757 |
|
642 | 758 | def _submit_dryrun( |
643 | 759 | self, app: AppDef, cfg: KubernetesOpts |
|
0 commit comments