Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion torchx/runner/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,7 @@ def dryrun(
" Either a patch was built or no changes to workspace was detected."
)

sched._validate(app, scheduler)
sched._validate(app, scheduler, resolved_cfg)
dryrun_info = sched.submit_dryrun(app, resolved_cfg)
dryrun_info._scheduler = scheduler
return dryrun_info
Expand Down
2 changes: 1 addition & 1 deletion torchx/schedulers/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ def log_iter(
f"{self.__class__.__qualname__} does not support application log iteration"
)

def _validate(self, app: AppDef, scheduler: str) -> None:
def _validate(self, app: AppDef, scheduler: str, cfg: T) -> None:
"""
Validates whether application is consistent with the scheduler.

Expand Down
2 changes: 1 addition & 1 deletion torchx/schedulers/docker_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ def _submit_dryrun(self, app: AppDef, cfg: DockerOpts) -> AppDryRunInfo[DockerJo

return AppDryRunInfo(req, repr)

def _validate(self, app: AppDef, scheduler: str) -> None:
def _validate(self, app: AppDef, scheduler: str, cfg: DockerOpts) -> None:
# Skip validation step
pass

Expand Down
2 changes: 1 addition & 1 deletion torchx/schedulers/gcp_batch_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,7 @@ def list(self) -> List[ListAppResponse]:
for job in all_jobs
]

def _validate(self, app: AppDef, scheduler: str) -> None:
def _validate(self, app: AppDef, scheduler: str, cfg: GCPBatchOpts) -> None:
# Skip validation step
pass

Expand Down
2 changes: 1 addition & 1 deletion torchx/schedulers/kubernetes_mcad_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1033,7 +1033,7 @@ def _submit_dryrun(
info._cfg = cfg
return info

def _validate(self, app: AppDef, scheduler: str) -> None:
def _validate(self, app: AppDef, scheduler: str, cfg: KubernetesMCADOpts) -> None:
# Skip validation step
pass

Expand Down
2 changes: 1 addition & 1 deletion torchx/schedulers/kubernetes_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -661,7 +661,7 @@ def _submit_dryrun(
)
return AppDryRunInfo(req, repr)

def _validate(self, app: AppDef, scheduler: str) -> None:
def _validate(self, app: AppDef, scheduler: str, cfg: KubernetesOpts) -> None:
# Skip validation step
pass

Expand Down
2 changes: 1 addition & 1 deletion torchx/schedulers/local_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -630,7 +630,7 @@ def _run_opts(self) -> runopts:
)
return opts

def _validate(self, app: AppDef, scheduler: str) -> None:
def _validate(self, app: AppDef, scheduler: str, cfg: LocalOpts) -> None:
# Skip validation step for local application
pass

Expand Down
2 changes: 1 addition & 1 deletion torchx/schedulers/lsf_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,7 @@ def schedule(self, dryrun_info: AppDryRunInfo[LsfBsub]) -> str:
subprocess.run(req.cmd, stdout=subprocess.PIPE, check=True)
return req.app_id

def _validate(self, app: AppDef, scheduler: str) -> None:
def _validate(self, app: AppDef, scheduler: str, cfg: LsfOpts) -> None:
# Skip validation step for lsf
pass

Expand Down
2 changes: 1 addition & 1 deletion torchx/schedulers/ray_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ def _submit_dryrun(self, app: AppDef, cfg: RayOpts) -> AppDryRunInfo[RayJob]:

return AppDryRunInfo(job, repr)

def _validate(self, app: AppDef, scheduler: str) -> None:
def _validate(self, app: AppDef, scheduler: str, cfg: RayOpts) -> None:
if scheduler != "ray":
raise ValueError(
f"An unknown scheduler backend '{scheduler}' has been passed to the Ray scheduler."
Expand Down
2 changes: 1 addition & 1 deletion torchx/schedulers/slurm_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,7 @@ def _submit_dryrun(

return AppDryRunInfo(req, repr)

def _validate(self, app: AppDef, scheduler: str) -> None:
def _validate(self, app: AppDef, scheduler: str, cfg: SlurmOpts) -> None:
# Skip validation step for slurm
pass

Expand Down
2 changes: 1 addition & 1 deletion torchx/schedulers/test/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def test_validate(self) -> None:
app_mock.roles[0].resource = NULL_RESOURCE

with self.assertRaises(ValueError):
scheduler_mock._validate(app_mock, "local")
scheduler_mock._validate(app_mock, "local", cfg={})

def test_cancel_not_exists(self) -> None:
scheduler_mock = SchedulerTest.MockScheduler("test_session")
Expand Down
10 changes: 7 additions & 3 deletions torchx/schedulers/test/kubernetes_mcad_scheduler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,8 @@ def test_create_scheduler(self) -> None:
self.assertIsInstance(
scheduler, kubernetes_mcad_scheduler.KubernetesMCADScheduler
)
self.assertEquals(client, scheduler._client)
self.assertEquals(docker_client, scheduler._docker_client)
self.assertEqual(client, scheduler._client)
self.assertEqual(docker_client, scheduler._docker_client)

def test_app_to_resource_resolved_macros(self) -> None:
app = _test_app()
Expand Down Expand Up @@ -468,7 +468,11 @@ def test_create_mcad_service(self) -> None:
def test_validate(self) -> None:
scheduler = create_scheduler("test")
app = _test_app()
scheduler._validate(app, "kubernetes_mcad")
scheduler._validate(
app,
"kubernetes_mcad",
cfg=KubernetesMCADOpts({"namespace": "test_namespace"}),
)

def test_cleanup_str(self) -> None:
self.assertEqual("abcd123", cleanup_str("abcd123"))
Expand Down
2 changes: 1 addition & 1 deletion torchx/schedulers/test/lsf_scheduler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,7 +515,7 @@ def test_submit_dryrun(self) -> None:
def test_validate(self) -> None:
scheduler = create_scheduler("foo")
app = simple_app()
scheduler._validate(app, "lsf")
scheduler._validate(app, "lsf", cfg={})

@patch("subprocess.run")
def test_schedule(self, run: MagicMock) -> None:
Expand Down
16 changes: 12 additions & 4 deletions torchx/schedulers/test/ray_scheduler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,9 @@ def assert_option(expected_opt: Option) -> None:

def test_validate_does_not_raise_error_and_does_not_log_warning(self) -> None:
with self.assertLogs(_logger, "WARNING") as cm:
self._scheduler._validate(self._app_def, scheduler="ray")
self._scheduler._validate(
self._app_def, scheduler="ray", cfg=self._run_cfg
)

_logger.warning("dummy log")

Expand All @@ -150,7 +152,9 @@ def test_validate_raises_error_if_backend_name_is_not_ray(self) -> None:
ValueError,
r"^An unknown scheduler backend 'dummy' has been passed to the Ray scheduler.$",
):
self._scheduler._validate(self._app_def, scheduler="dummy")
self._scheduler._validate(
self._app_def, scheduler="dummy", cfg=self._run_cfg
)

@contextmanager
def _assert_log_message(self, level: str, msg: str) -> Iterator[None]:
Expand All @@ -170,7 +174,9 @@ def test_validate_warns_when_app_def_contains_metadata(self) -> None:
with self._assert_log_message(
"WARNING", "The Ray scheduler does not use metadata information."
):
self._scheduler._validate(self._app_def, scheduler="ray")
self._scheduler._validate(
self._app_def, scheduler="ray", cfg=self._run_cfg
)

def test_validate_warns_when_role_contains_resource_capability(self) -> None:
self._app_def.roles[1].resource.capabilities["dummy_cap1"] = 1
Expand All @@ -180,7 +186,9 @@ def test_validate_warns_when_role_contains_resource_capability(self) -> None:
"WARNING",
"The Ray scheduler does not support custom resource capabilities.",
):
self._scheduler._validate(self._app_def, scheduler="ray")
self._scheduler._validate(
self._app_def, scheduler="ray", cfg=self._run_cfg
)

def test_validate_warns_when_role_contains_port_map(self) -> None:
self._app_def.roles[1].port_map["dummy_map1"] = 1
Expand Down
Loading