diff --git a/torchx/runner/api.py b/torchx/runner/api.py index 4a03a8aad..3e194a6ca 100644 --- a/torchx/runner/api.py +++ b/torchx/runner/api.py @@ -433,7 +433,7 @@ def dryrun( " Either a patch was built or no changes to workspace was detected." ) - sched._validate(app, scheduler) + sched._validate(app, scheduler, resolved_cfg) dryrun_info = sched.submit_dryrun(app, resolved_cfg) dryrun_info._scheduler = scheduler return dryrun_info diff --git a/torchx/schedulers/api.py b/torchx/schedulers/api.py index 437c4e23a..3ef3c5899 100644 --- a/torchx/schedulers/api.py +++ b/torchx/schedulers/api.py @@ -337,7 +337,7 @@ def log_iter( f"{self.__class__.__qualname__} does not support application log iteration" ) - def _validate(self, app: AppDef, scheduler: str) -> None: + def _validate(self, app: AppDef, scheduler: str, cfg: T) -> None: """ Validates whether application is consistent with the scheduler. diff --git a/torchx/schedulers/docker_scheduler.py b/torchx/schedulers/docker_scheduler.py index 48954a8b9..59e524e65 100644 --- a/torchx/schedulers/docker_scheduler.py +++ b/torchx/schedulers/docker_scheduler.py @@ -327,7 +327,7 @@ def _submit_dryrun(self, app: AppDef, cfg: DockerOpts) -> AppDryRunInfo[DockerJo return AppDryRunInfo(req, repr) - def _validate(self, app: AppDef, scheduler: str) -> None: + def _validate(self, app: AppDef, scheduler: str, cfg: DockerOpts) -> None: # Skip validation step pass diff --git a/torchx/schedulers/gcp_batch_scheduler.py b/torchx/schedulers/gcp_batch_scheduler.py index 36399eac1..f4d0ef09c 100644 --- a/torchx/schedulers/gcp_batch_scheduler.py +++ b/torchx/schedulers/gcp_batch_scheduler.py @@ -464,7 +464,7 @@ def list(self) -> List[ListAppResponse]: for job in all_jobs ] - def _validate(self, app: AppDef, scheduler: str) -> None: + def _validate(self, app: AppDef, scheduler: str, cfg: GCPBatchOpts) -> None: # Skip validation step pass diff --git a/torchx/schedulers/kubernetes_mcad_scheduler.py b/torchx/schedulers/kubernetes_mcad_scheduler.py index 2ac31952e..467c6363e 100644 --- a/torchx/schedulers/kubernetes_mcad_scheduler.py +++ b/torchx/schedulers/kubernetes_mcad_scheduler.py @@ -1033,7 +1033,7 @@ def _submit_dryrun( info._cfg = cfg return info - def _validate(self, app: AppDef, scheduler: str) -> None: + def _validate(self, app: AppDef, scheduler: str, cfg: KubernetesMCADOpts) -> None: # Skip validation step pass diff --git a/torchx/schedulers/kubernetes_scheduler.py b/torchx/schedulers/kubernetes_scheduler.py index c06eebc8a..97d57b8cc 100644 --- a/torchx/schedulers/kubernetes_scheduler.py +++ b/torchx/schedulers/kubernetes_scheduler.py @@ -661,7 +661,7 @@ def _submit_dryrun( ) return AppDryRunInfo(req, repr) - def _validate(self, app: AppDef, scheduler: str) -> None: + def _validate(self, app: AppDef, scheduler: str, cfg: KubernetesOpts) -> None: # Skip validation step pass diff --git a/torchx/schedulers/local_scheduler.py b/torchx/schedulers/local_scheduler.py index 9390902c4..aa899b1d2 100644 --- a/torchx/schedulers/local_scheduler.py +++ b/torchx/schedulers/local_scheduler.py @@ -630,7 +630,7 @@ def _run_opts(self) -> runopts: ) return opts - def _validate(self, app: AppDef, scheduler: str) -> None: + def _validate(self, app: AppDef, scheduler: str, cfg: LocalOpts) -> None: # Skip validation step for local application pass diff --git a/torchx/schedulers/lsf_scheduler.py b/torchx/schedulers/lsf_scheduler.py index 102a8f417..fdc915431 100644 --- a/torchx/schedulers/lsf_scheduler.py +++ b/torchx/schedulers/lsf_scheduler.py @@ -488,7 +488,7 @@ def schedule(self, dryrun_info: AppDryRunInfo[LsfBsub]) -> str: subprocess.run(req.cmd, stdout=subprocess.PIPE, check=True) return req.app_id - def _validate(self, app: AppDef, scheduler: str) -> None: + def _validate(self, app: AppDef, scheduler: str, cfg: LsfOpts) -> None: # Skip validation step for lsf pass diff --git a/torchx/schedulers/ray_scheduler.py b/torchx/schedulers/ray_scheduler.py index f0b98f6da..47767a406 100644 --- a/torchx/schedulers/ray_scheduler.py +++ b/torchx/schedulers/ray_scheduler.py @@ -318,7 +318,7 @@ def _submit_dryrun(self, app: AppDef, cfg: RayOpts) -> AppDryRunInfo[RayJob]: return AppDryRunInfo(job, repr) - def _validate(self, app: AppDef, scheduler: str) -> None: + def _validate(self, app: AppDef, scheduler: str, cfg: RayOpts) -> None: if scheduler != "ray": raise ValueError( f"An unknown scheduler backend '{scheduler}' has been passed to the Ray scheduler." diff --git a/torchx/schedulers/slurm_scheduler.py b/torchx/schedulers/slurm_scheduler.py index 331fd8611..e89b2b063 100644 --- a/torchx/schedulers/slurm_scheduler.py +++ b/torchx/schedulers/slurm_scheduler.py @@ -472,7 +472,7 @@ def _submit_dryrun( return AppDryRunInfo(req, repr) - def _validate(self, app: AppDef, scheduler: str) -> None: + def _validate(self, app: AppDef, scheduler: str, cfg: SlurmOpts) -> None: # Skip validation step for slurm pass diff --git a/torchx/schedulers/test/api_test.py b/torchx/schedulers/test/api_test.py index 26013e364..c45767c56 100644 --- a/torchx/schedulers/test/api_test.py +++ b/torchx/schedulers/test/api_test.py @@ -166,7 +166,7 @@ def test_validate(self) -> None: app_mock.roles[0].resource = NULL_RESOURCE with self.assertRaises(ValueError): - scheduler_mock._validate(app_mock, "local") + scheduler_mock._validate(app_mock, "local", cfg={}) def test_cancel_not_exists(self) -> None: scheduler_mock = SchedulerTest.MockScheduler("test_session") diff --git a/torchx/schedulers/test/kubernetes_mcad_scheduler_test.py b/torchx/schedulers/test/kubernetes_mcad_scheduler_test.py index 25c96a55f..6101642b0 100644 --- a/torchx/schedulers/test/kubernetes_mcad_scheduler_test.py +++ b/torchx/schedulers/test/kubernetes_mcad_scheduler_test.py @@ -167,8 +167,8 @@ def test_create_scheduler(self) -> None: self.assertIsInstance( scheduler, kubernetes_mcad_scheduler.KubernetesMCADScheduler ) - self.assertEquals(client, scheduler._client) - self.assertEquals(docker_client, scheduler._docker_client) + self.assertEqual(client, scheduler._client) + self.assertEqual(docker_client, scheduler._docker_client) def test_app_to_resource_resolved_macros(self) -> None: app = _test_app() @@ -468,7 +468,11 @@ def test_create_mcad_service(self) -> None: def test_validate(self) -> None: scheduler = create_scheduler("test") app = _test_app() - scheduler._validate(app, "kubernetes_mcad") + scheduler._validate( + app, + "kubernetes_mcad", + cfg=KubernetesMCADOpts({"namespace": "test_namespace"}), + ) def test_cleanup_str(self) -> None: self.assertEqual("abcd123", cleanup_str("abcd123")) diff --git a/torchx/schedulers/test/lsf_scheduler_test.py b/torchx/schedulers/test/lsf_scheduler_test.py index 7d4032e54..144d23356 100644 --- a/torchx/schedulers/test/lsf_scheduler_test.py +++ b/torchx/schedulers/test/lsf_scheduler_test.py @@ -515,7 +515,7 @@ def test_submit_dryrun(self) -> None: def test_validate(self) -> None: scheduler = create_scheduler("foo") app = simple_app() - scheduler._validate(app, "lsf") + scheduler._validate(app, "lsf", cfg={}) @patch("subprocess.run") def test_schedule(self, run: MagicMock) -> None: diff --git a/torchx/schedulers/test/ray_scheduler_test.py b/torchx/schedulers/test/ray_scheduler_test.py index 1e7d181b5..f8bde7e7c 100644 --- a/torchx/schedulers/test/ray_scheduler_test.py +++ b/torchx/schedulers/test/ray_scheduler_test.py @@ -139,7 +139,9 @@ def assert_option(expected_opt: Option) -> None: def test_validate_does_not_raise_error_and_does_not_log_warning(self) -> None: with self.assertLogs(_logger, "WARNING") as cm: - self._scheduler._validate(self._app_def, scheduler="ray") + self._scheduler._validate( + self._app_def, scheduler="ray", cfg=self._run_cfg + ) _logger.warning("dummy log") @@ -150,7 +152,9 @@ def test_validate_raises_error_if_backend_name_is_not_ray(self) -> None: ValueError, r"^An unknown scheduler backend 'dummy' has been passed to the Ray scheduler.$", ): - self._scheduler._validate(self._app_def, scheduler="dummy") + self._scheduler._validate( + self._app_def, scheduler="dummy", cfg=self._run_cfg + ) @contextmanager def _assert_log_message(self, level: str, msg: str) -> Iterator[None]: @@ -170,7 +174,9 @@ def test_validate_warns_when_app_def_contains_metadata(self) -> None: with self._assert_log_message( "WARNING", "The Ray scheduler does not use metadata information." ): - self._scheduler._validate(self._app_def, scheduler="ray") + self._scheduler._validate( + self._app_def, scheduler="ray", cfg=self._run_cfg + ) def test_validate_warns_when_role_contains_resource_capability(self) -> None: self._app_def.roles[1].resource.capabilities["dummy_cap1"] = 1 @@ -180,7 +186,9 @@ def test_validate_warns_when_role_contains_resource_capability(self) -> None: "WARNING", "The Ray scheduler does not support custom resource capabilities.", ): - self._scheduler._validate(self._app_def, scheduler="ray") + self._scheduler._validate( + self._app_def, scheduler="ray", cfg=self._run_cfg + ) def test_validate_warns_when_role_contains_port_map(self) -> None: self._app_def.roles[1].port_map["dummy_map1"] = 1