Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions torchx/schedulers/slurm_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,15 @@ def appstate_from_slurm_state(slurm_state: str) -> AppState:
return SLURM_STATES.get(slurm_state, AppState.UNKNOWN)


def get_appstate_from_job(job: dict[str, object]) -> AppState:
# Prior to slurm-23.11, job_state was a string and not a list
job_state = job.get("job_state", None)
if isinstance(job_state, list):
return appstate_from_slurm_state(job_state[0])
else:
return appstate_from_slurm_state(str(job_state))


def version() -> Tuple[int, int]:
"""
Uses ``sinfo --version`` to get the slurm version. If the command fails, it
Expand Down Expand Up @@ -666,7 +675,7 @@ def _describe_squeue(self, app_id: str) -> Optional[DescribeAppResponse]:

entrypoint = job["command"]
image = job["current_working_directory"]
state = appstate_from_slurm_state(job["job_state"][0])
state = get_appstate_from_job(job)

job_resources = job["job_resources"]

Expand Down Expand Up @@ -881,7 +890,7 @@ def _list_squeue(self) -> List[ListAppResponse]:
out.append(
ListAppResponse(
app_id=str(job["job_id"]),
state=SLURM_STATES[job["job_state"][0]],
state=get_appstate_from_job(job),
name=job["name"],
)
)
Expand Down
34 changes: 34 additions & 0 deletions torchx/schedulers/test/slurm_scheduler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,8 +547,14 @@ def test_list_sacct(self, run: MagicMock) -> None:

@patch("subprocess.run")
def test_list_squeue(self, run: MagicMock) -> None:
# First job is patched with a string-type job state
run.return_value.stdout = b"""{
"jobs": [
{
"job_id": 1233,
"name": "bar",
"job_state": "FAILED"
},
{
"job_id": 1234,
"name": "foo",
Expand Down Expand Up @@ -588,6 +594,7 @@ def test_list_squeue(self, run: MagicMock) -> None:
}"""
scheduler = create_scheduler("foo")
expected_apps = [
ListAppResponse(app_id="1233", state=AppState.FAILED, name="bar"),
ListAppResponse(app_id="1234", state=AppState.FAILED, name="foo"),
ListAppResponse(app_id="1235", state=AppState.FAILED, name="foo"),
ListAppResponse(app_id="1236", state=AppState.RUNNING, name="foo-0"),
Expand Down Expand Up @@ -1128,3 +1135,30 @@ def test_describe_squeue_nodes_as_string(self) -> None:

assert result is not None
assert result.roles_statuses[0].replicas[0].hostname == "compute-node-123"

def test_describe_squeue_handles_string_state(self) -> None:
"""Test that describe handles job state as string (i.e. for SLURM <= 23.02)."""

# Mock legacy slurm response with job_state as a string
mock_job_data = {
"jobs": [
{
"name": "test-job-0",
"job_state": "TIMEOUT",
"job_resources": {"nodes": "compute-node-123"},
"command": "/bin/echo",
"current_working_directory": "/tmp",
}
]
}

with patch("subprocess.check_output") as mock_subprocess:
mock_subprocess.return_value = json.dumps(mock_job_data)

scheduler = SlurmScheduler("test")
result = scheduler._describe_squeue("123")

assert result is not None
assert result.app_id == "123"
# should have a valid parsed state
assert result.state == AppState.FAILED
Loading