 import tempfile
 from dataclasses import dataclass
 from datetime import datetime
+from subprocess import CalledProcessError, PIPE
 from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple

 import torchx
@@ -39,6 +40,7 @@
     macros,
     NONE,
     ReplicaStatus,
+    Resource,
     Role,
     RoleStatus,
     runopts,
@@ -66,6 +68,11 @@
     "TIMEOUT": AppState.FAILED,
 }

+
+def appstate_from_slurm_state(slurm_state: str) -> AppState:
+    return SLURM_STATES.get(slurm_state, AppState.UNKNOWN)
+
+
 SBATCH_JOB_OPTIONS = {
     "comment",
     "mail-user",
@@ -483,15 +490,34 @@ def _cancel_existing(self, app_id: str) -> None:

     def describe(self, app_id: str) -> Optional[DescribeAppResponse]:
         try:
-            return self._describe_sacct(app_id)
-        except subprocess.CalledProcessError:
             return self._describe_squeue(app_id)
+        except CalledProcessError as e:
+            # NOTE: squeue errors out with 'slurm_load_jobs error: Invalid job id specified'
+            # if the job does not exist or has finished (e.g. not in PENDING or RUNNING states)
+            # in this case, fall back to the less descriptive but more persistent sacct
+            # (slurm cluster must have accounting storage enabled for sacct to work)
+            log.info(
+                "unable to get job info for `{}` with `squeue` ({}), trying `sacct`".format(
+                    app_id, e.stderr
+                )
+            )
+            return self._describe_sacct(app_id)

     def _describe_sacct(self, app_id: str) -> Optional[DescribeAppResponse]:
-        p = subprocess.run(
-            ["sacct", "--parsable2", "-j", app_id], stdout=subprocess.PIPE, check=True
-        )
-        output = p.stdout.decode("utf-8").split("\n")
+        try:
+            output = subprocess.check_output(
+                ["sacct", "--parsable2", "-j", app_id],
+                stderr=PIPE,
+                encoding="utf-8",
+            ).split("\n")
+        except CalledProcessError as e:
+            log.info(
+                "unable to get job info for `{}` with `sacct` ({})".format(
+                    app_id, e.stderr
+                )
+            )
+            return None
+
         if len(output) <= 1:
             return None

@@ -511,11 +537,7 @@ def _describe_sacct(self, app_id: str) -> Optional[DescribeAppResponse]:

             state = row["State"]
             msg = state
-            state_enum = SLURM_STATES.get(state)
-            assert (
-                state_enum
-            ), f"failed to translate slurm state {state} to torchx state"
-            app_state = state_enum
+            app_state = appstate_from_slurm_state(state)

             role, _, replica_id = row["JobName"].rpartition("-")
             if not replica_id or not role:
@@ -540,46 +562,93 @@ def _describe_sacct(self, app_id: str) -> Optional[DescribeAppResponse]:
                 msg=msg,
             )

-    def _describe_squeue(self, app_id: str) -> Optional[DescribeAppResponse]:
-        p = subprocess.run(
-            ["squeue", "--json", "-j", app_id], stdout=subprocess.PIPE, check=True
+    def _describe_squeue(self, app_id: str) -> DescribeAppResponse:
+        # squeue errors out with 'slurm_load_jobs error: Invalid job id specified'
+        # if the job does not exist or is finished (e.g. not in PENDING or RUNNING state)
+        output = subprocess.check_output(
+            ["squeue", "--json", "-j", app_id], stderr=PIPE, encoding="utf-8"
         )
-        output_json = json.loads(p.stdout.decode("utf-8"))

-        roles = {}
-        roles_statuses = {}
-        msg = ""
-        app_state = AppState.UNKNOWN
-        for job in output_json["jobs"]:
-            state = job["job_state"][0]
-            msg = state
-            state_enum = SLURM_STATES.get(state)
-            assert (
-                state_enum
-            ), f"failed to translate slurm state {state} to torchx state"
-            app_state = state_enum
+        output_json = json.loads(output)
+        jobs = output_json["jobs"]

-            role, _, replica_id = job["name"].rpartition("-")
-            if not replica_id or not role:
-                # name should always have at least 3 parts but sometimes sacct
-                # is slow to update
-                continue
-            if role not in roles:
-                roles[role] = Role(name=role, num_replicas=0, image="")
-                roles_statuses[role] = RoleStatus(role, [])
-            roles[role].num_replicas += 1
-            roles_statuses[role].replicas.append(
-                ReplicaStatus(
-                    id=int(replica_id), role=role, state=app_state, hostname=""
+        roles: dict[str, Role] = {}
+        roles_statuses: dict[str, RoleStatus] = {}
+        state = AppState.UNKNOWN
+
+        for job in jobs:
+            # job name is of the form "{role_name}-{replica_id}"
+            role_name, _, replica_id = job["name"].rpartition("-")
+
+            entrypoint = job["command"]
+            image = job["current_working_directory"]
+            state = appstate_from_slurm_state(job["job_state"][0])
+
+            job_resources = job["job_resources"]
+
+            role = roles.setdefault(
+                role_name,
+                Role(
+                    name=role_name,
+                    image=image,
+                    entrypoint=entrypoint,
+                    num_replicas=0,
                 ),
             )
+            role_status = roles_statuses.setdefault(
+                role_name,
+                RoleStatus(role_name, replicas=[]),
+            )
+
+            if state == AppState.PENDING:
+                # NOTE: torchx-launched jobs point to exactly one host
+                # otherwise, scheduled_nodes could be a node list expression (e.g. 'slurm-compute-node[0-20,21,45-47]')
+                hostname = job_resources["scheduled_nodes"]
+                role.num_replicas += 1
+                role_status.replicas.append(
+                    ReplicaStatus(
+                        id=int(replica_id),
+                        role=role_name,
+                        state=state,
+                        hostname=hostname,
+                    )
+                )
+            else:  # state == AppState.RUNNING
+                # NOTE: torchx schedules on slurm with sbatch + heterogeneous job
+                # where each replica is a "sub-job" so `allocated_nodes` will always be 1
+                # but we deal with jobs that have not been launched with torchx
+                # which can have multiple hosts per sub-job (count them as replicas)
+                node_infos = job_resources.get("allocated_nodes", [])
+
+                for node_info in node_infos:
+                    # NOTE: we expect resource specs for all the nodes to be the same
+                    # NOTE: use allocated (not used/requested) memory since
+                    # users may only specify --cpu, in which case slurm
+                    # uses the (system) configured {mem-per-cpu} * {cpus}
+                    # to allocate memory.
+                    # NOTE: getting gpus is tricky because it is modeled as a trackable resource
+                    # or not configured at all (use total-cpu-on-host as proxy for gpus)
+                    cpu = int(node_info["cpus_used"])
+                    memMB = int(node_info["memory_allocated"])
+
+                    hostname = node_info["nodename"]
+
+                    role.resource = Resource(cpu=cpu, memMB=memMB, gpu=-1)
+                    role.num_replicas += 1
+                    role_status.replicas.append(
+                        ReplicaStatus(
+                            id=int(replica_id),
+                            role=role_name,
+                            state=state,
+                            hostname=hostname,
+                        )
+                    )

         return DescribeAppResponse(
             app_id=app_id,
             roles=list(roles.values()),
             roles_statuses=list(roles_statuses.values()),
-            state=app_state,
-            msg=msg,
+            state=state,
         )

     def log_iter(
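
For reference, below is a minimal sketch (not part of the commit) of the `squeue --json` payload shape that the new `_describe_squeue` assumes. The field names are exactly the ones read in the diff above; the job values and the trimming are illustrative only, since the full Slurm output contains many more keys and varies by Slurm version.

# Illustrative only: a trimmed `squeue --json -j <app_id>` payload containing just
# the fields that `_describe_squeue` reads. All values here are made up.
example_squeue_output = {
    "jobs": [
        {
            "name": "trainer-0",  # parsed as "{role_name}-{replica_id}"
            "command": "/home/user/app/main.py",  # becomes Role.entrypoint
            "current_working_directory": "/home/user/app",  # becomes Role.image
            "job_state": ["RUNNING"],  # mapped via appstate_from_slurm_state()
            "job_resources": {
                "scheduled_nodes": "node-0",  # read when the job is PENDING
                "allocated_nodes": [  # read when the job is RUNNING
                    {
                        "nodename": "node-0",
                        "cpus_used": 8,
                        "memory_allocated": 16384,  # MB
                    }
                ],
            },
        }
    ]
}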