@@ -20,6 +20,7 @@
 import tempfile
 from dataclasses import dataclass
 from datetime import datetime
+from subprocess import CalledProcessError, PIPE
 from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple
 
 import torchx
@@ -39,6 +40,7 @@
     macros,
     NONE,
     ReplicaStatus,
+    Resource,
     Role,
     RoleStatus,
     runopts,
@@ -66,6 +68,11 @@
     "TIMEOUT": AppState.FAILED,
 }
 
+
+def appstate_from_slurm_state(slurm_state: str) -> AppState:
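+    # states not present in SLURM_STATES resolve to AppState.UNKNOWN rather
+    # than raising, since slurm may report states this mapping does not cover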
+    return SLURM_STATES.get(slurm_state, AppState.UNKNOWN)
+
+
 SBATCH_JOB_OPTIONS = {
     "comment",
     "mail-user",
@@ -483,15 +490,34 @@ def _cancel_existing(self, app_id: str) -> None:
 
     def describe(self, app_id: str) -> Optional[DescribeAppResponse]:
         try:
-            return self._describe_sacct(app_id)
-        except subprocess.CalledProcessError:
             return self._describe_squeue(app_id)
+        except CalledProcessError as e:
+            # NOTE: squeue errors out with 'slurm_load_jobs error: Invalid job id specified'
+            #  if the job does not exist or has finished (e.g. is no longer PENDING or RUNNING);
+            #  in that case, fall back to the less descriptive but more persistent sacct
+            #  (the slurm cluster must have accounting storage enabled for sacct to work)
+            log.info(
+                "unable to get job info for `{}` with `squeue` ({}), trying `sacct`".format(
+                    app_id, e.stderr
+                )
+            )
+            return self._describe_sacct(app_id)
 
     def _describe_sacct(self, app_id: str) -> Optional[DescribeAppResponse]:
-        p = subprocess.run(
-            ["sacct", "--parsable2", "-j", app_id], stdout=subprocess.PIPE, check=True
-        )
-        output = p.stdout.decode("utf-8").split("\n")
+        try:
+            output = subprocess.check_output(
+                ["sacct", "--parsable2", "-j", app_id],
+                stderr=PIPE,
+                encoding="utf-8",
+            ).split("\n")
+        except CalledProcessError as e:
+            log.info(
+                "unable to get job info for `{}` with `sacct` ({})".format(
+                    app_id, e.stderr
+                )
+            )
+            return None
+
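+        # sacct --parsable2 prints a '|'-delimited table with a header row,
+        # so a single line of output means no job records were found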
         if len(output) <= 1:
             return None
 
@@ -511,11 +537,7 @@ def _describe_sacct(self, app_id: str) -> Optional[DescribeAppResponse]:
 
             state = row["State"]
             msg = state
-            state_enum = SLURM_STATES.get(state)
-            assert (
-                state_enum
-            ), f"failed to translate slurm state {state} to torchx state"
-            app_state = state_enum
+            app_state = appstate_from_slurm_state(state)
 
             role, _, replica_id = row["JobName"].rpartition("-")
             if not replica_id or not role:
@@ -540,46 +562,92 @@ def _describe_sacct(self, app_id: str) -> Optional[DescribeAppResponse]:
             msg=msg,
         )
 
-    def _describe_squeue(self, app_id: str) -> Optional[DescribeAppResponse]:
-        p = subprocess.run(
-            ["squeue", "--json", "-j", app_id], stdout=subprocess.PIPE, check=True
+    def _describe_squeue(self, app_id: str) -> DescribeAppResponse:
+        # squeue errors out with 'slurm_load_jobs error: Invalid job id specified'
+        # if the job does not exist or is finished (e.g. not in PENDING or RUNNING state)
+        output = subprocess.check_output(
+            ["squeue", "--json", "-j", app_id], stderr=PIPE, encoding="utf-8"
         )
-        output_json = json.loads(p.stdout.decode("utf-8"))
 
-        roles = {}
-        roles_statuses = {}
-        msg = ""
-        app_state = AppState.UNKNOWN
-        for job in output_json["jobs"]:
-            state = job["job_state"][0]
-            msg = state
-            state_enum = SLURM_STATES.get(state)
-            assert (
-                state_enum
-            ), f"failed to translate slurm state {state} to torchx state"
-            app_state = state_enum
+        output_json = json.loads(output)
+        jobs = output_json["jobs"]
 
-            role, _, replica_id = job["name"].rpartition("-")
-            if not replica_id or not role:
-                # name should always have at least 3 parts but sometimes sacct
-                # is slow to update
-                continue
-            if role not in roles:
-                roles[role] = Role(name=role, num_replicas=0, image="")
-                roles_statuses[role] = RoleStatus(role, [])
-            roles[role].num_replicas += 1
-            roles_statuses[role].replicas.append(
-                ReplicaStatus(
-                    id=int(replica_id), role=role, state=app_state, hostname=""
+        roles: dict[str, Role] = {}
+        roles_statuses: dict[str, RoleStatus] = {}
+        state = AppState.UNKNOWN
+
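+        # each (sub-)job entry is grouped into a role keyed by the role name
+        # parsed out of the job name below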
+        for job in jobs:
+            # job name is of the form "{role_name}-{replica_id}"
+            role_name, _, replica_id = job["name"].rpartition("-")
+
+            entrypoint = job["command"]
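+            # slurm jobs have no container image; the job's working
+            # directory is reported as the role's image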
+            image = job["current_working_directory"]
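+            # "job_state" is a list; its first element is taken as the current state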
+            state = appstate_from_slurm_state(job["job_state"][0])
+
+            job_resources = job["job_resources"]
+
+            role = roles.setdefault(
+                role_name,
+                Role(
+                    name=role_name,
+                    image=image,
+                    entrypoint=entrypoint,
+                    num_replicas=0,
                 ),
             )
+            role_status = roles_statuses.setdefault(
+                role_name,
+                RoleStatus(role_name, replicas=[]),
+            )
+
+            if state == AppState.PENDING:
+                # NOTE: torchx-launched jobs point to exactly one host;
+                #  otherwise, scheduled_nodes could be a node-list expression (e.g. 'slurm-compute-node[0-20,21,45-47]')
+                hostname = job_resources["scheduled_nodes"]
+                role.num_replicas += 1
+                role_status.replicas.append(
+                    ReplicaStatus(
+                        id=int(replica_id),
+                        role=role_name,
+                        state=state,
+                        hostname=hostname,
+                    )
+                )
+            else:  # state == AppState.RUNNING
+                # NOTE: torchx schedules on slurm with sbatch + heterogeneous jobs,
+                # where each replica is a "sub-job", so `allocated_nodes` will always be 1;
+                # but we also deal with jobs that were not launched with torchx,
+                # which can have multiple hosts per sub-job (count them as replicas)
+                node_infos = job_resources.get("allocated_nodes", [])
+                for node_info in node_infos:
+                    # NOTE: we expect the resource specs for all nodes to be the same
+                    # NOTE: use allocated (not used/requested) memory since
+                    #  users may only specify --cpu, in which case slurm
+                    #  uses the (system-)configured {mem-per-cpu} * {cpus}
+                    #  to allocate memory
+                    # NOTE: getting gpus is tricky since slurm models them as a
+                    #  trackable resource that may not be configured at all,
+                    #  so the gpu count is reported as -1 (unknown) below
+                    cpu = int(node_info["cpus_used"])
+                    memMB = int(node_info["memory_allocated"])
+
+                    hostname = node_info["nodename"]
+
+                    role.resource = Resource(cpu=cpu, memMB=memMB, gpu=-1)
+                    role.num_replicas += 1
+                    role_status.replicas.append(
+                        ReplicaStatus(
+                            id=int(replica_id),
+                            role=role_name,
+                            state=state,
+                            hostname=hostname,
+                        )
+                    )
 
         return DescribeAppResponse(
             app_id=app_id,
             roles=list(roles.values()),
             roles_statuses=list(roles_statuses.values()),
-            state=app_state,
-            msg=msg,
+            state=state,
         )
 
     def log_iter(
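
For reference, the new `_describe_squeue` parser assumes `squeue --json` output shaped roughly like the sketch below. This is a hand-written illustration limited to the fields the diff actually reads; the job name, command, paths, and resource values are made up, it is not verbatim slurm output, and field availability varies across slurm versions.

# minimal sketch of the squeue --json fields consumed by _describe_squeue
example_squeue_json = {
    "jobs": [
        {
            "name": "trainer-0",  # parsed as "{role_name}-{replica_id}"
            "command": "/usr/bin/python train.py",  # -> Role.entrypoint
            "current_working_directory": "/home/user/proj",  # -> Role.image
            "job_state": ["RUNNING"],  # first element is the current state
            "job_resources": {
                "scheduled_nodes": "node-0",  # read while PENDING
                "allocated_nodes": [  # read while RUNNING
                    {
                        "nodename": "node-0",
                        "cpus_used": 8,
                        "memory_allocated": 16384,
                    }
                ],
            },
        }
    ]
}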