 import tempfile
 from dataclasses import dataclass
 from datetime import datetime
+from subprocess import CalledProcessError, PIPE
 from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple

 import torchx
     macros,
     NONE,
     ReplicaStatus,
+    Resource,
     Role,
     RoleStatus,
     runopts,
6668 "TIMEOUT" : AppState .FAILED ,
6769}
6870
71+
72+ def appstate_from_slurm_state (slurm_state : str ) -> AppState :
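+    """Maps a raw slurm state string to a torchx ``AppState``, defaulting to UNKNOWN."""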
+    return SLURM_STATES.get(slurm_state, AppState.UNKNOWN)
+
+
 SBATCH_JOB_OPTIONS = {
     "comment",
     "mail-user",
@@ -482,16 +489,36 @@ def _cancel_existing(self, app_id: str) -> None:
         subprocess.run(["scancel", app_id], check=True)

     def describe(self, app_id: str) -> Optional[DescribeAppResponse]:
+        # NOTE: depending on the version of slurm, querying for job info
+        # with `squeue` for finished (or non-existent) jobs either:
+        #   1. errors out with 'slurm_load_jobs error: Invalid job id specified'
+        #   2. -- or -- returns an empty jobs list
+        # In either case, fall back to the less descriptive but more persistent `sacct`
+        # (the slurm cluster must have accounting storage enabled for `sacct` to work)
         try:
-            return self._describe_sacct(app_id)
-        except subprocess.CalledProcessError:
-            return self._describe_squeue(app_id)
+            if desc := self._describe_squeue(app_id):
+                return desc
+        except CalledProcessError as e:
+            log.info(
+                f"unable to get job info for `{app_id}` with `squeue` ({e.stderr}), trying `sacct`"
+            )
+        return self._describe_sacct(app_id)

     def _describe_sacct(self, app_id: str) -> Optional[DescribeAppResponse]:
-        p = subprocess.run(
-            ["sacct", "--parsable2", "-j", app_id], stdout=subprocess.PIPE, check=True
-        )
-        output = p.stdout.decode("utf-8").split("\n")
+        try:
+            output = subprocess.check_output(
+                ["sacct", "--parsable2", "-j", app_id],
+                stderr=PIPE,
+                encoding="utf-8",
+            ).split("\n")
+        except CalledProcessError as e:
+            log.info(
+                "unable to get job info for `{}` with `sacct` ({})".format(
+                    app_id, e.stderr
+                )
+            )
+            return None
+
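+        # `sacct --parsable2` prints a pipe-delimited header row followed by one
+        # row per (sub-)job; the exact field set depends on the cluster's
+        # accounting configuration. Illustrative example (values hypothetical):
+        #   JobID|JobName|Partition|Account|AllocCPUS|State|ExitCode
+        #   1234|trainer-0|compute|acct|8|COMPLETED|0:0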
         if len(output) <= 1:
             return None

@@ -511,11 +538,7 @@ def _describe_sacct(self, app_id: str) -> Optional[DescribeAppResponse]:

         state = row["State"]
         msg = state
-        state_enum = SLURM_STATES.get(state)
-        assert (
-            state_enum
-        ), f"failed to translate slurm state {state} to torchx state"
-        app_state = state_enum
+        app_state = appstate_from_slurm_state(state)

         role, _, replica_id = row["JobName"].rpartition("-")
         if not replica_id or not role:
@@ -541,45 +564,109 @@ def _describe_sacct(self, app_id: str) -> Optional[DescribeAppResponse]:
         )

     def _describe_squeue(self, app_id: str) -> Optional[DescribeAppResponse]:
-        p = subprocess.run(
-            ["squeue", "--json", "-j", app_id], stdout=subprocess.PIPE, check=True
+        # squeue errors out with 'slurm_load_jobs error: Invalid job id specified'
+        # if the job does not exist or is finished (e.g. not in PENDING or RUNNING state)
+        output = subprocess.check_output(
+            ["squeue", "--json", "-j", app_id], stderr=PIPE, encoding="utf-8"
         )
-        output_json = json.loads(p.stdout.decode("utf-8"))
+        output_json = json.loads(output)
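+        # `squeue --json` output is roughly of the form (exact shape varies
+        # across slurm versions; only the fields used below are assumed):
+        #   {"jobs": [{"name": ..., "job_state": [...], "command": ...,
+        #              "current_working_directory": ..., "job_resources": {...}}, ...]}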
+        jobs = output_json["jobs"]
+        if not jobs:
+            return None

-        roles = {}
-        roles_statuses = {}
-        msg = ""
-        app_state = AppState.UNKNOWN
-        for job in output_json["jobs"]:
-            state = job["job_state"][0]
-            msg = state
-            state_enum = SLURM_STATES.get(state)
-            assert (
-                state_enum
-            ), f"failed to translate slurm state {state} to torchx state"
-            app_state = state_enum
+        roles: dict[str, Role] = {}
+        roles_statuses: dict[str, RoleStatus] = {}
+        state = AppState.UNKNOWN

-            role, _, replica_id = job["name"].rpartition("-")
-            if not replica_id or not role:
-                # name should always have at least 3 parts but sometimes sacct
-                # is slow to update
-                continue
-            if role not in roles:
-                roles[role] = Role(name=role, num_replicas=0, image="")
-                roles_statuses[role] = RoleStatus(role, [])
-            roles[role].num_replicas += 1
-            roles_statuses[role].replicas.append(
-                ReplicaStatus(
-                    id=int(replica_id), role=role, state=app_state, hostname=""
+        for job in jobs:
+            # job name is of the form "{role_name}-{replica_id}"
+            role_name, _, replica_id = job["name"].rpartition("-")
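+            # rpartition splits on the *last* "-", so role names may themselves
+            # contain dashes, e.g. "data-loader-3" -> ("data-loader", "3")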
+
+            entrypoint = job["command"]
+            image = job["current_working_directory"]
+            state = appstate_from_slurm_state(job["job_state"][0])
+
+            job_resources = job["job_resources"]
+
+            role = roles.setdefault(
+                role_name,
+                Role(
+                    name=role_name,
+                    image=image,
+                    entrypoint=entrypoint,
+                    num_replicas=0,
                 ),
             )
+            role_status = roles_statuses.setdefault(
+                role_name,
+                RoleStatus(role_name, replicas=[]),
+            )
+
+            if state == AppState.PENDING:
+                # NOTE: torchx-launched jobs point to exactly one host;
+                # otherwise, scheduled_nodes could be a node list expression
+                # (e.g. 'slurm-compute-node[0-20,21,45-47]')
+                hostname = job_resources.get("scheduled_nodes", "")
+
+                role.num_replicas += 1
+                role_status.replicas.append(
+                    ReplicaStatus(
+                        id=int(replica_id),
+                        role=role_name,
+                        state=state,
+                        hostname=hostname,
+                    )
+                )
+            else:  # state == AppState.RUNNING
+                # NOTE: torchx schedules on slurm with sbatch + heterogeneous job,
+                # where each replica is a "sub-job", so `allocated_nodes` will
+                # always be 1; but we also deal with jobs that were not launched
+                # with torchx, which can have multiple hosts per sub-job
+                # (count them as replicas)
+                node_infos = job_resources.get("allocated_nodes", [])
+
+                if not isinstance(node_infos, list):
+                    # NOTE: in some versions of slurm, jobs[].job_resources.allocated_nodes
+                    # is not a list of individual nodes but a map of the nodelist specs;
+                    # in this case just use jobs[].job_resources.nodes
+                    hostname = job_resources.get("nodes")
+                    role.num_replicas += 1
+                    role_status.replicas.append(
+                        ReplicaStatus(
+                            id=int(replica_id),
+                            role=role_name,
+                            state=state,
+                            hostname=hostname,
+                        )
+                    )
+                else:
+                    for node_info in node_infos:
+                        # NOTE: we expect the resource specs for all nodes to be the same
+                        # NOTE: use allocated (not used/requested) memory since users
+                        # may only specify --cpu, in which case slurm uses the
+                        # (system-)configured {mem-per-cpu} * {cpus} to allocate memory
+                        # NOTE: getting gpus is tricky because gpu is modeled as a
+                        # trackable resource or not configured at all
+                        # (use total-cpu-on-host as a proxy for gpus)
+                        cpu = int(node_info["cpus_used"])
+                        memMB = int(node_info["memory_allocated"])
+
+                        hostname = node_info["nodename"]
+
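+                        # gpu=-1 leaves the gpu count unspecified, since it
+                        # cannot be reliably determined from job_resources
+                        # (see NOTE above)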
+                        role.resource = Resource(cpu=cpu, memMB=memMB, gpu=-1)
+                        role.num_replicas += 1
+                        role_status.replicas.append(
+                            ReplicaStatus(
+                                id=int(replica_id),
+                                role=role_name,
+                                state=state,
+                                hostname=hostname,
+                            )
+                        )

         return DescribeAppResponse(
             app_id=app_id,
             roles=list(roles.values()),
             roles_statuses=list(roles_statuses.values()),
-            state=app_state,
-            msg=msg,
+            state=state,
         )

     def log_iter(