 import tempfile
 from dataclasses import dataclass
 from datetime import datetime
+from subprocess import CalledProcessError, PIPE
 from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple

 import torchx
     macros,
     NONE,
     ReplicaStatus,
+    Resource,
     Role,
     RoleStatus,
     runopts,
     "TIMEOUT": AppState.FAILED,
 }

+
+def appstate_from_slurm_state(slurm_state: str) -> AppState:
+    return SLURM_STATES.get(slurm_state, AppState.UNKNOWN)
+
+
 SBATCH_JOB_OPTIONS = {
     "comment",
     "mail-user",
@@ -483,15 +490,34 @@ def _cancel_existing(self, app_id: str) -> None:

     def describe(self, app_id: str) -> Optional[DescribeAppResponse]:
         try:
-            return self._describe_sacct(app_id)
-        except subprocess.CalledProcessError:
             return self._describe_squeue(app_id)
+        except CalledProcessError as e:
+            # NOTE: squeue errors out with 'slurm_load_jobs error: Invalid job id specified'
+            # if the job does not exist or has finished (e.g. not in PENDING or RUNNING states)
+            # in this case, fall back to the less descriptive but more persistent sacct
+            # (slurm cluster must have accounting storage enabled for sacct to work)
+            log.info(
+                "unable to get job info for `{}` with `squeue` ({}), trying `sacct`".format(
+                    app_id, e.stderr
+                )
+            )
+            return self._describe_sacct(app_id)

     def _describe_sacct(self, app_id: str) -> Optional[DescribeAppResponse]:
-        p = subprocess.run(
-            ["sacct", "--parsable2", "-j", app_id], stdout=subprocess.PIPE, check=True
-        )
-        output = p.stdout.decode("utf-8").split("\n")
+        try:
+            output = subprocess.check_output(
+                ["sacct", "--parsable2", "-j", app_id],
+                stderr=PIPE,
+                encoding="utf-8",
+            ).split("\n")
+        except CalledProcessError as e:
+            log.info(
+                "unable to get job info for `{}` with `sacct` ({})".format(
+                    app_id, e.stderr
+                )
+            )
+            return None
+
         if len(output) <= 1:
             return None

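A detail worth noting in the new error handling: `subprocess.check_output(..., stderr=PIPE, encoding="utf-8")` raises `CalledProcessError` on a non-zero exit, and because stderr is piped the exception carries the command's error text as a string, which is what the `e.stderr` in the log messages above relies on. A minimal sketch of that behavior (the failing command here is just a stand-in, not part of the change):

import subprocess
from subprocess import PIPE, CalledProcessError

try:
    # stand-in for a failing `squeue`/`sacct` invocation; any command exiting non-zero works
    subprocess.check_output(["ls", "--no-such-flag"], stderr=PIPE, encoding="utf-8")
except CalledProcessError as e:
    # e.stderr holds the captured stderr text (a str, since encoding was given)
    print(e.returncode, e.stderr)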
@@ -511,11 +537,7 @@ def _describe_sacct(self, app_id: str) -> Optional[DescribeAppResponse]:

             state = row["State"]
             msg = state
-            state_enum = SLURM_STATES.get(state)
-            assert (
-                state_enum
-            ), f"failed to translate slurm state {state} to torchx state"
-            app_state = state_enum
+            app_state = appstate_from_slurm_state(state)

             role, _, replica_id = row["JobName"].rpartition("-")
             if not replica_id or not role:
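The `appstate_from_slurm_state` helper also changes failure behavior here: a slurm state that is missing from `SLURM_STATES` no longer trips an assert, it simply maps to `AppState.UNKNOWN`. Roughly, assuming the existing "COMPLETED" -> AppState.SUCCEEDED entry in `SLURM_STATES`:

from torchx.specs import AppState
from torchx.schedulers.slurm_scheduler import appstate_from_slurm_state

assert appstate_from_slurm_state("COMPLETED") == AppState.SUCCEEDED
# an unmapped / future slurm state degrades gracefully instead of asserting
assert appstate_from_slurm_state("SOME_FUTURE_STATE") == AppState.UNKNOWN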
@@ -540,46 +562,107 @@ def _describe_sacct(self, app_id: str) -> Optional[DescribeAppResponse]:
             msg=msg,
         )

-    def _describe_squeue(self, app_id: str) -> Optional[DescribeAppResponse]:
-        p = subprocess.run(
-            ["squeue", "--json", "-j", app_id], stdout=subprocess.PIPE, check=True
+    def _describe_squeue(self, app_id: str) -> DescribeAppResponse:
+        # squeue errors out with 'slurm_load_jobs error: Invalid job id specified'
+        # if the job does not exist or is finished (e.g. not in PENDING or RUNNING state)
+        output = subprocess.check_output(
+            ["squeue", "--json", "-j", app_id], stderr=PIPE, encoding="utf-8"
         )
-        output_json = json.loads(p.stdout.decode("utf-8"))
-
-        roles = {}
-        roles_statuses = {}
-        msg = ""
-        app_state = AppState.UNKNOWN
-        for job in output_json["jobs"]:
-            state = job["job_state"][0]
-            msg = state
-            state_enum = SLURM_STATES.get(state)
-            assert (
-                state_enum
-            ), f"failed to translate slurm state {state} to torchx state"
-            app_state = state_enum
-
-            role, _, replica_id = job["name"].rpartition("-")
-            if not replica_id or not role:
-                # name should always have at least 3 parts but sometimes sacct
-                # is slow to update
-                continue
-            if role not in roles:
-                roles[role] = Role(name=role, num_replicas=0, image="")
-                roles_statuses[role] = RoleStatus(role, [])
-            roles[role].num_replicas += 1
-            roles_statuses[role].replicas.append(
-                ReplicaStatus(
-                    id=int(replica_id), role=role, state=app_state, hostname=""
+        output_json = json.loads(output)
+        jobs = output_json["jobs"]
+
+        roles: dict[str, Role] = {}
+        roles_statuses: dict[str, RoleStatus] = {}
+        state = AppState.UNKNOWN
+
+        for job in jobs:
+            # job name is of the form "{role_name}-{replica_id}"
+            role_name, _, replica_id = job["name"].rpartition("-")
+
+            entrypoint = job["command"]
+            image = job["current_working_directory"]
+            state = appstate_from_slurm_state(job["job_state"][0])
+
+            job_resources = job["job_resources"]
+
+            role = roles.setdefault(
+                role_name,
+                Role(
+                    name=role_name,
+                    image=image,
+                    entrypoint=entrypoint,
+                    num_replicas=0,
                 ),
             )
+            role_status = roles_statuses.setdefault(
+                role_name,
+                RoleStatus(role_name, replicas=[]),
+            )
+
+            if state == AppState.PENDING:
+                # NOTE: torchx-launched jobs point to exactly one host
+                # otherwise, scheduled_nodes could be a node list expression (e.g. 'slurm-compute-node[0-20,21,45-47]')
+                hostname = job_resources["scheduled_nodes"]
+                role.num_replicas += 1
+                role_status.replicas.append(
+                    ReplicaStatus(
+                        id=int(replica_id),
+                        role=role_name,
+                        state=state,
+                        hostname=hostname,
+                    )
+                )
+            else:  # state == AppState.RUNNING
+                # NOTE: torchx schedules on slurm with sbatch + a heterogeneous job,
+                # where each replica is a "sub-job", so `allocated_nodes` will always be 1;
+                # but we also deal with jobs that were not launched with torchx,
+                # which can have multiple hosts per sub-job (count them as replicas)
+                node_infos = job_resources.get("allocated_nodes", [])
+
+                if not isinstance(node_infos, list):
+                    # NOTE: in some versions of slurm jobs[].job_resources.allocated_nodes
+                    # is not a list of individual nodes, but a map of the nodelist specs
+                    # in this case just use jobs[].job_resources.nodes
+                    hostname = job_resources.get("nodes")
+                    role.num_replicas += 1
+                    role_status.replicas.append(
+                        ReplicaStatus(
+                            id=int(replica_id),
+                            role=role_name,
+                            state=state,
+                            hostname=hostname,
+                        )
+                    )
+                else:
+                    for node_info in node_infos:
+                        # NOTE: we expect resource specs for all the nodes to be the same
+                        # NOTE: use allocated (not used/requested) memory since
+                        # users may only specify --cpu, in which case slurm
+                        # uses the (system) configured {mem-per-cpu} * {cpus}
+                        # to allocate memory.
+                        # NOTE: getting gpus is tricky because they are modeled as a trackable resource
+                        # or not configured at all (use total-cpu-on-host as a proxy for gpus)
+                        cpu = int(node_info["cpus_used"])
+                        memMB = int(node_info["memory_allocated"])
+
+                        hostname = node_info["nodename"]
+
+                        role.resource = Resource(cpu=cpu, memMB=memMB, gpu=-1)
+                        role.num_replicas += 1
+                        role_status.replicas.append(
+                            ReplicaStatus(
+                                id=int(replica_id),
+                                role=role_name,
+                                state=state,
+                                hostname=hostname,
+                            )
+                        )

         return DescribeAppResponse(
             app_id=app_id,
             roles=list(roles.values()),
             roles_statuses=list(roles_statuses.values()),
-            state=app_state,
-            msg=msg,
+            state=state,
         )

     def log_iter(
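For reference, a minimal sketch of the `squeue --json` payload shape the new `_describe_squeue` parser assumes. The values are hypothetical; only the keys read in the diff above matter:

# one RUNNING torchx replica: role "trainer", replica id 0
payload = {
    "jobs": [
        {
            "name": "trainer-0",  # "{role_name}-{replica_id}"
            "command": "/usr/bin/python -m train",  # becomes Role.entrypoint
            "current_working_directory": "/home/user/proj",  # becomes Role.image
            "job_state": ["RUNNING"],
            "job_resources": {
                # "scheduled_nodes" is read instead when the job is PENDING
                "allocated_nodes": [
                    {"nodename": "node-0", "cpus_used": 8, "memory_allocated": 16384}
                ],
            },
        }
    ]
}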