@@ -929,6 +929,200 @@ def test_min_replicas(self) -> None:
929929 ]
930930 self .assertEqual (min_available , [1 , 1 , 0 ])
931931
932+ def test_apply_pod_overlay_merge (self ) -> None :
933+ from kubernetes .client .models import V1Container , V1ObjectMeta , V1Pod , V1PodSpec
934+ from torchx .schedulers .kubernetes_scheduler import ( # pyre-ignore[21]
935+ apply_pod_overlay ,
936+ )
937+
938+ pod = V1Pod (
939+ spec = V1PodSpec (
940+ containers = [V1Container (name = "test" , image = "test:latest" )],
941+ node_selector = {"existing" : "label" },
942+ ),
943+ metadata = V1ObjectMeta (name = "test-pod" ),
944+ )
945+
946+ overlay = lambda pod_dict : {
947+ ** pod_dict ,
948+ "spec" : {
949+ ** pod_dict ["spec" ],
950+ "nodeSelector" : {
951+ ** pod_dict ["spec" ].get ("nodeSelector" , {}),
952+ "gpu" : "true" ,
953+ },
954+ "tolerations" : [{"key" : "nvidia.com/gpu" , "operator" : "Exists" }],
955+ },
956+ }
957+
958+ pod = apply_pod_overlay (pod , overlay ) # pyre-ignore[16]
959+
960+ self .assertEqual (pod .spec .node_selector , {"existing" : "label" , "gpu" : "true" })
961+ self .assertEqual (len (pod .spec .tolerations ), 1 )
962+ self .assertEqual (pod .spec .tolerations [0 ].key , "nvidia.com/gpu" )
963+
964+ def test_apply_pod_overlay_append_lists (self ) -> None :
965+ from kubernetes .client .models import V1Container , V1ObjectMeta , V1Pod , V1PodSpec
966+ from torchx .schedulers .kubernetes_scheduler import apply_pod_overlay
967+
968+ pod = V1Pod (
969+ spec = V1PodSpec (
970+ containers = [V1Container (name = "test" , image = "test:latest" )],
971+ tolerations = [{"key" : "existing" , "operator" : "Exists" }],
972+ ),
973+ metadata = V1ObjectMeta (name = "test-pod" ),
974+ )
975+
976+ overlay = lambda pod_dict : {
977+ ** pod_dict ,
978+ "spec" : {
979+ ** pod_dict ["spec" ],
980+ "tolerations" : pod_dict ["spec" ].get ("tolerations" , [])
981+ + [{"key" : "nvidia.com/gpu" , "operator" : "Exists" }],
982+ },
983+ }
984+
985+ pod = apply_pod_overlay (pod , overlay ) # pyre-ignore[16]
986+
987+ self .assertEqual (len (pod .spec .tolerations ), 2 )
988+ self .assertEqual (pod .spec .tolerations [0 ].key , "existing" )
989+ self .assertEqual (pod .spec .tolerations [1 ].key , "nvidia.com/gpu" )
990+
991+ def test_submit_dryrun_with_pod_overlay (self ) -> None :
992+ scheduler = create_scheduler ("test" )
993+
994+ trainer_role = specs .Role (
995+ name = "trainer" ,
996+ image = "pytorch/torchx:latest" ,
997+ entrypoint = "main" ,
998+ resource = specs .Resource (cpu = 1 , memMB = 1000 , gpu = 0 ),
999+ metadata = {
1000+ "kubernetes" : lambda pod : {
1001+ ** pod ,
1002+ "spec" : {
1003+ ** pod ["spec" ],
1004+ "nodeSelector" : {
1005+ ** pod ["spec" ].get ("nodeSelector" , {}),
1006+ "gpu" : "true" ,
1007+ },
1008+ },
1009+ }
1010+ },
1011+ )
1012+ app = specs .AppDef ("test" , roles = [trainer_role ])
1013+ cfg = KubernetesOpts ({"queue" : "testqueue" })
1014+
1015+ info = scheduler .submit_dryrun (app , cfg )
1016+ resource = info .request .resource
1017+
1018+ tasks = resource ["spec" ]["tasks" ] # pyre-ignore[16]
1019+ for task in tasks :
1020+ pod = task ["template" ]
1021+ self .assertIn ("gpu" , pod .spec .node_selector )
1022+ self .assertEqual (pod .spec .node_selector ["gpu" ], "true" )
1023+
1024+ def test_submit_dryrun_with_pod_overlay_from_yaml (self ) -> None :
1025+ import tempfile
1026+
1027+ import yaml
1028+
1029+ scheduler = create_scheduler ("test" )
1030+
1031+ overlay_dict = {"spec" : {"nodeSelector" : {"instance-type" : "p4d.24xlarge" }}}
1032+ with tempfile .NamedTemporaryFile (mode = "w" , suffix = ".yaml" , delete = False ) as f :
1033+ yaml .dump (overlay_dict , f )
1034+ overlay_path = f .name
1035+
1036+ try :
1037+ import fsspec
1038+
1039+ with fsspec .open (f"file://{ overlay_path } " , "r" ) as f :
1040+ loaded_overlay = yaml .safe_load (f )
1041+
1042+ trainer_role = specs .Role (
1043+ name = "trainer" ,
1044+ image = "pytorch/torchx:latest" ,
1045+ entrypoint = "main" ,
1046+ resource = specs .Resource (cpu = 1 , memMB = 1000 , gpu = 0 ),
1047+ metadata = {
1048+ "kubernetes" : lambda pod : {
1049+ ** pod ,
1050+ "spec" : {** pod ["spec" ], ** loaded_overlay .get ("spec" , {})},
1051+ }
1052+ },
1053+ )
1054+ app = specs .AppDef ("test" , roles = [trainer_role ])
1055+ cfg = KubernetesOpts ({"queue" : "testqueue" })
1056+
1057+ info = scheduler .submit_dryrun (app , cfg )
1058+ resource = info .request .resource
1059+
1060+ tasks = resource ["spec" ]["tasks" ] # pyre-ignore[16]
1061+ for task in tasks :
1062+ pod = task ["template" ]
1063+ self .assertIn ("instance-type" , pod .spec .node_selector )
1064+ self .assertEqual (
1065+ pod .spec .node_selector ["instance-type" ], "p4d.24xlarge"
1066+ )
1067+ finally :
1068+ import os
1069+
1070+ os .unlink (overlay_path )
1071+
1072+ def test_submit_dryrun_with_pod_overlay_append_from_yaml (self ) -> None :
1073+ import tempfile
1074+
1075+ import yaml
1076+
1077+ scheduler = create_scheduler ("test" )
1078+
1079+ overlay_dict = {
1080+ "spec" : {"tolerations" : [{"key" : "nvidia.com/gpu" , "operator" : "Exists" }]}
1081+ }
1082+ with tempfile .NamedTemporaryFile (mode = "w" , suffix = ".yaml" , delete = False ) as f :
1083+ yaml .dump (overlay_dict , f )
1084+ overlay_path = f .name
1085+
1086+ try :
1087+ import fsspec
1088+
1089+ with fsspec .open (f"file://{ overlay_path } " , "r" ) as f :
1090+ loaded_overlay = yaml .safe_load (f )
1091+
1092+ trainer_role = specs .Role (
1093+ name = "trainer" ,
1094+ image = "pytorch/torchx:latest" ,
1095+ entrypoint = "main" ,
1096+ resource = specs .Resource (cpu = 1 , memMB = 1000 , gpu = 0 ),
1097+ metadata = {
1098+ "kubernetes" : lambda pod : {
1099+ ** pod ,
1100+ "spec" : {
1101+ ** pod ["spec" ],
1102+ "tolerations" : pod ["spec" ].get ("tolerations" , [])
1103+ + loaded_overlay .get ("spec" , {}).get ("tolerations" , []),
1104+ },
1105+ }
1106+ },
1107+ )
1108+ app = specs .AppDef ("test" , roles = [trainer_role ])
1109+ cfg = KubernetesOpts ({"queue" : "testqueue" })
1110+
1111+ info = scheduler .submit_dryrun (app , cfg )
1112+ resource = info .request .resource
1113+
1114+ tasks = resource ["spec" ]["tasks" ] # pyre-ignore[16]
1115+ for task in tasks :
1116+ pod = task ["template" ]
1117+ self .assertIsNotNone (pod .spec .tolerations )
1118+ self .assertTrue (
1119+ any (t .key == "nvidia.com/gpu" for t in pod .spec .tolerations )
1120+ )
1121+ finally :
1122+ import os
1123+
1124+ os .unlink (overlay_path )
1125+
9321126
9331127class KubernetesSchedulerNoImportTest (unittest .TestCase ):
9341128 """
0 commit comments