2020import tempfile
2121from dataclasses import dataclass
2222from datetime import datetime
23+ from subprocess import CalledProcessError , PIPE
2324from typing import Any , Dict , Iterable , List , Mapping , Optional , Tuple
2425
2526import torchx
6667 "TIMEOUT" : AppState .FAILED ,
6768}
6869
70+
def appstate_from_slurm_state(slurm_state: str) -> AppState:
    """Translate a slurm job state string into the corresponding ``AppState``.

    The state is first looked up verbatim in ``SLURM_STATES``. If that misses,
    the lookup is retried with only the leading state name, because ``sacct``
    may decorate the state (e.g. ``"CANCELLED by 1000"``, or a trailing ``+``
    when the field was truncated).

    Args:
        slurm_state: state string as reported by a slurm command.

    Returns:
        The mapped ``AppState``, or ``AppState.UNKNOWN`` when the state is
        not present in ``SLURM_STATES``.
    """
    exact = SLURM_STATES.get(slurm_state)
    if exact is not None:
        return exact

    # keep only the bare state name: drop " by <uid>" and the truncation marker
    tokens = slurm_state.split()
    base_state = tokens[0].rstrip("+") if tokens else slurm_state
    return SLURM_STATES.get(base_state, AppState.UNKNOWN)
73+
74+
6975SBATCH_JOB_OPTIONS = {
7076 "comment" ,
7177 "mail-user" ,
@@ -482,10 +488,82 @@ def _cancel_existing(self, app_id: str) -> None:
482488 subprocess .run (["scancel" , app_id ], check = True )
483489
def describe(self, app_id: str) -> Optional[DescribeAppResponse]:
    """Describe the job ``app_id`` using the first slurm command that succeeds.

    The describe implementations are tried in order; one that raises
    ``CalledProcessError`` is skipped and the next one is tried.

    Args:
        app_id: the slurm job id to describe.

    Returns:
        Whatever the first non-failing implementation returns (which may
        itself be ``None``), or ``None`` if every implementation raised
        ``CalledProcessError``.
    """
    # fallback to using different slurm commands for describing the job
    describe_fns = (
        self._describe_scontrol,  # NOTE: only scontrol fills hostnames
        self._describe_sacct,
        self._describe_squeue,
    )
    for describe_fn in describe_fns:
        try:
            return describe_fn(app_id)
        except CalledProcessError:
            continue

    # every command failed
    return None
501+
def _describe_scontrol(self, app_id: str) -> Optional[DescribeAppResponse]:
    """Describe the job by parsing ``scontrol show --json job <app_id>``.

    Each entry under ``"jobs"`` in the output is treated as one replica of a
    role, with role name and replica id taken from the job name.

    Args:
        app_id: the slurm job id to describe.

    Returns:
        A ``DescribeAppResponse`` built from the reported jobs, or ``None``
        when ``scontrol`` reports no jobs for ``app_id``. The response's
        ``state`` is the state of the last job that was processed.

    Raises:
        CalledProcessError: if ``scontrol`` exits with a non-zero status.
    """
    # NOTE: app_id for slurm_scheduler is the job_id (not the heterogeneous_job_id).
    #   For heterogeneous jobs, querying slurm by the base job id returns all the
    #   "sub-jobs" in it.
    #   We launch each role's replica on its own srun command where the job_name is set
    #   to `{role.name}-{replica_id}` (e.g. `worker-0`, `worker-1`, ...).
    #   So each sub-job maps to a replica in the role.

    output = subprocess.check_output(
        ["scontrol", "show", "--json", "job", app_id], stderr=PIPE, encoding="utf-8"
    )
    output_json = json.loads(output)
    jobs = output_json["jobs"]
    if not jobs:
        # job either finished or does not exist
        return None

    roles: dict[str, Role] = {}
    roles_statuses: dict[str, RoleStatus] = {}
    state = AppState.UNKNOWN

    for job in jobs:
        # job name is of the form "{role_name}-{replica_id}"
        role_name, _, replica_id = job["name"].rpartition("-")
        if not replica_id or not role_name:
            # name does not follow the "{role_name}-{replica_id}" convention;
            # skip it, same as the sacct/squeue based describe methods do
            continue

        state = appstate_from_slurm_state(job["job_state"][0])

        # nodes is a hostlist expression (e.g. slurm-compute-node[200-210,212])
        # but we schedule a job per replica so it will always be a single host.
        # NOTE(review): presumably jobs that have not been allocated yet carry
        #   no job_resources/nodes; fall back to an empty hostname for those.
        job_resources = job.get("job_resources") or {}
        hostname = job_resources.get("nodes", "")

        role = roles.get(role_name)
        if role is None:
            # first replica seen for this role: create the role entry once
            role = Role(
                name=role_name,
                image=job["current_working_directory"],
                entrypoint=job["command"],
                num_replicas=0,
            )
            roles[role_name] = role
        role.num_replicas += 1

        role_status = roles_statuses.get(role_name)
        if role_status is None:
            role_status = RoleStatus(role_name, replicas=[])
            roles_statuses[role_name] = role_status

        role_status.replicas.append(
            ReplicaStatus(
                id=int(replica_id),
                role=role_name,
                state=state,
                hostname=hostname,
            )
        )

    return DescribeAppResponse(
        app_id=app_id,
        roles=list(roles.values()),
        roles_statuses=list(roles_statuses.values()),
        state=state,
    )
489567
490568 def _describe_sacct (self , app_id : str ) -> Optional [DescribeAppResponse ]:
491569 p = subprocess .run (
@@ -511,11 +589,7 @@ def _describe_sacct(self, app_id: str) -> Optional[DescribeAppResponse]:
511589
512590 state = row ["State" ]
513591 msg = state
514- state_enum = SLURM_STATES .get (state )
515- assert (
516- state_enum
517- ), f"failed to translate slurm state { state } to torchx state"
518- app_state = state_enum
592+ app_state = appstate_from_slurm_state (state )
519593
520594 role , _ , replica_id = row ["JobName" ].rpartition ("-" )
521595 if not replica_id or not role :
@@ -553,11 +627,7 @@ def _describe_squeue(self, app_id: str) -> Optional[DescribeAppResponse]:
553627 for job in output_json ["jobs" ]:
554628 state = job ["job_state" ][0 ]
555629 msg = state
556- state_enum = SLURM_STATES .get (state )
557- assert (
558- state_enum
559- ), f"failed to translate slurm state { state } to torchx state"
560- app_state = state_enum
630+ app_state = appstate_from_slurm_state (state )
561631
562632 role , _ , replica_id = job ["name" ].rpartition ("-" )
563633 if not replica_id or not role :
0 commit comments