diff --git a/torchx/schedulers/slurm_scheduler.py b/torchx/schedulers/slurm_scheduler.py index e74afb6ac..7fd926d99 100644 --- a/torchx/schedulers/slurm_scheduler.py +++ b/torchx/schedulers/slurm_scheduler.py @@ -73,6 +73,15 @@ def appstate_from_slurm_state(slurm_state: str) -> AppState: return SLURM_STATES.get(slurm_state, AppState.UNKNOWN) +def get_appstate_from_job(job: dict[str, object]) -> AppState: + # Prior to slurm-23.11, job_state was a string and not a list + job_state = job.get("job_state", None) + if isinstance(job_state, list): + return appstate_from_slurm_state(job_state[0]) + else: + return appstate_from_slurm_state(str(job_state)) + + def version() -> Tuple[int, int]: """ Uses ``sinfo --version`` to get the slurm version. If the command fails, it @@ -666,7 +675,7 @@ def _describe_squeue(self, app_id: str) -> Optional[DescribeAppResponse]: entrypoint = job["command"] image = job["current_working_directory"] - state = appstate_from_slurm_state(job["job_state"][0]) + state = get_appstate_from_job(job) job_resources = job["job_resources"] @@ -881,7 +890,7 @@ def _list_squeue(self) -> List[ListAppResponse]: out.append( ListAppResponse( app_id=str(job["job_id"]), - state=SLURM_STATES[job["job_state"][0]], + state=get_appstate_from_job(job), name=job["name"], ) ) diff --git a/torchx/schedulers/test/slurm_scheduler_test.py b/torchx/schedulers/test/slurm_scheduler_test.py index fb1096454..16e457b39 100644 --- a/torchx/schedulers/test/slurm_scheduler_test.py +++ b/torchx/schedulers/test/slurm_scheduler_test.py @@ -547,8 +547,14 @@ def test_list_sacct(self, run: MagicMock) -> None: @patch("subprocess.run") def test_list_squeue(self, run: MagicMock) -> None: + # First job is patched with a string-type job state run.return_value.stdout = b"""{ "jobs": [ + { + "job_id": 1233, + "name": "bar", + "job_state": "FAILED" + }, { "job_id": 1234, "name": "foo", @@ -588,6 +594,7 @@ def test_list_squeue(self, run: MagicMock) -> None: }""" scheduler = create_scheduler("foo") expected_apps = [ + ListAppResponse(app_id="1233", state=AppState.FAILED, name="bar"), ListAppResponse(app_id="1234", state=AppState.FAILED, name="foo"), ListAppResponse(app_id="1235", state=AppState.FAILED, name="foo"), ListAppResponse(app_id="1236", state=AppState.RUNNING, name="foo-0"), @@ -1128,3 +1135,30 @@ def test_describe_squeue_nodes_as_string(self) -> None: assert result is not None assert result.roles_statuses[0].replicas[0].hostname == "compute-node-123" + + def test_describe_squeue_handles_string_state(self) -> None: + """Test that describe handles job state as string (i.e. for SLURM <= 23.02).""" + + # Mock legacy slurm response with job_state as a string + mock_job_data = { + "jobs": [ + { + "name": "test-job-0", + "job_state": "TIMEOUT", + "job_resources": {"nodes": "compute-node-123"}, + "command": "/bin/echo", + "current_working_directory": "/tmp", + } + ] + } + + with patch("subprocess.check_output") as mock_subprocess: + mock_subprocess.return_value = json.dumps(mock_job_data) + + scheduler = SlurmScheduler("test") + result = scheduler._describe_squeue("123") + + assert result is not None + assert result.app_id == "123" + # should have a valid parsed state + assert result.state == AppState.FAILED