@@ -172,6 +172,12 @@ def _req_keepvars_default(self):
172172 "specification."
173173 ).tag (config = True )
174174
175+ connect_to_job_cmd = Unicode ('' ,
176+ help = "Command to connect to running batch job and forward the port "
177+ "of the running notebook to the Hub. If empty, direct connectivity is assumed. "
178+ "Uses self.job_id as {job_id} and the self.port as {port}."
179+ ).tag (config = True )
180+
175181 # Raw output of job submission command unless overridden
176182 job_id = Unicode ()
177183
@@ -200,6 +206,18 @@ def cmd_formatted_for_batch(self):
200206 """The command which is substituted inside of the batch script"""
201207 return ' ' .join ([self .batchspawner_singleuser_cmd ] + self .cmd + self .get_args ())
202208
209+ async def connect_to_job (self ):
210+ """This command ensures the port of the singleuser server is reachable from the
211+ Batchspawner machine. By default, it does nothing, i.e. direct connectivity
212+ is assumed.
213+ """
214+ subvars = self .get_req_subvars ()
215+ subvars ['job_id' ] = self .job_id
216+ subvars ['port' ] = self .port
217+ cmd = ' ' .join ((format_template (self .exec_prefix , ** subvars ),
218+ format_template (self .connect_to_job_cmd , ** subvars )))
219+ await self .run_background_command (cmd )
220+
203221 async def run_command (self , cmd , input = None , env = None ):
204222 proc = await asyncio .create_subprocess_shell (cmd , env = env ,
205223 stdin = asyncio .subprocess .PIPE ,
@@ -243,6 +261,46 @@ async def run_command(self, cmd, input=None, env=None):
243261 out = out .decode ().strip ()
244262 return out
245263
264+ # List of running background processes, e.g. used by connect_to_job.
265+ background_processes = []
266+
267+ async def _async_wait_process (self , sleep_time ):
268+ """Asynchronously sleeping process for delayed checks"""
269+ await asyncio .sleep (sleep_time )
270+
271+ async def run_background_command (self , cmd , startup_check_delay = 1 , input = None , env = None ):
272+ """Runs the given background command, adds it to background_processes,
273+ and checks if the command is still running after startup_check_delay."""
274+ background_process = self .run_command (cmd , input , env )
275+ success_check_delay = self ._async_wait_process (startup_check_delay )
276+
277+ # Start up both the success check process and the actual process.
278+ done , pending = await asyncio .wait ([background_process , success_check_delay ], return_when = asyncio .FIRST_COMPLETED )
279+
280+ # If the success check process is the one which exited first, all is good, else fail.
281+ if list (done )[0 ]._coro == success_check_delay :
282+ background_task = list (pending )[0 ]
283+ self .background_processes .append (background_task )
284+ return background_task
285+ else :
286+ self .log .error ("Background command exited early: %s" % cmd )
287+ gather_pending = asyncio .gather (* pending )
288+ gather_pending .cancel ()
289+ try :
290+ self .log .debug ("Cancelling pending success check task..." )
291+ await gather_pending
292+ except asyncio .CancelledError :
293+ self .log .debug ("Cancel was successful." )
294+ pass
295+
296+ # Retrieve exception from "done" process.
297+ try :
298+ gather_done = asyncio .gather (* done )
299+ await gather_done
300+ except :
301+ self .log .debug ("Retrieving exception from failed background task..." )
302+ raise RuntimeError ('{} failed!' .format (cmd ))
303+
246304 async def _get_batch_script (self , ** subvars ):
247305 """Format batch script from vars"""
248306 # Could be overridden by subclasses, but mainly useful for testing
@@ -270,6 +328,27 @@ async def submit_batch_script(self):
270328 self .job_id = ''
271329 return self .job_id
272330
331+ def background_tasks_ok (self ):
332+ # Check background processes.
333+ if self .background_processes :
334+ self .log .debug ('Checking background processes...' )
335+ for background_process in self .background_processes :
336+ if background_process .done ():
337+ self .log .debug ('Found a background process in state "done"...' )
338+ try :
339+ background_exception = background_process .exception ()
340+ except asyncio .CancelledError :
341+ self .log .error ('Background process was cancelled!' )
342+ if background_exception :
343+ self .log .error ('Background process exited with an exception:' )
344+ self .log .error (background_exception )
345+ self .log .error ('At least one background process exited!' )
346+ return False
347+ else :
348+ self .log .debug ('Found a not-yet-done background process...' )
349+ self .log .debug ('All background processes still running.' )
350+ return True
351+
273352 # Override if your batch system needs something more elaborate to query the job status
274353 batch_query_cmd = Unicode ('' ,
275354 help = "Command to run to query job status. Formatted using req_xyz traits as {xyz} "
@@ -314,6 +393,29 @@ async def cancel_batch_job(self):
314393 cmd = ' ' .join ((format_template (self .exec_prefix , ** subvars ),
315394 format_template (self .batch_cancel_cmd , ** subvars )))
316395 self .log .info ('Cancelling job ' + self .job_id + ': ' + cmd )
396+
397+ if self .background_processes :
398+ self .log .debug ('Job being cancelled, cancelling background processes...' )
399+ for background_process in self .background_processes :
400+ if not background_process .cancelled ():
401+ try :
402+ background_process .cancel ()
403+ except :
404+ self .log .error ('Encountered an exception cancelling background process...' )
405+ self .log .debug ('Cancelled background process, waiting for it to finish...' )
406+ try :
407+ await asyncio .wait ([background_process ])
408+ except asyncio .CancelledError :
409+ self .log .error ('Successfully cancelled background process.' )
410+ pass
411+ except :
412+ self .log .error ('Background process exited with another exception!' )
413+ raise
414+ else :
415+ self .log .debug ('Background process already cancelled...' )
416+ self .background_processes .clear ()
417+ self .log .debug ('All background processes cancelled.' )
418+
317419 await self .run_command (cmd )
318420
319421 def load_state (self , state ):
@@ -361,6 +463,13 @@ async def poll(self):
361463 """Poll the process"""
362464 status = await self .query_job_status ()
363465 if status in (JobStatus .PENDING , JobStatus .RUNNING , JobStatus .UNKNOWN ):
466+ if not self .background_tasks_ok ():
467+ self .log .debug ('Going to stop job, since background tasks have failed!' )
468+ await self .stop (now = True )
469+ status = await self .query_job_status ()
470+ if status not in (JobStatus .PENDING , JobStatus .RUNNING , JobStatus .UNKNOWN ):
471+ self .clear_state ()
472+ return 1
364473 return None
365474 else :
366475 self .clear_state ()
@@ -420,6 +529,9 @@ async def start(self):
420529 self .job_id , self .ip , self .port )
421530 )
422531
532+ if self .connect_to_job_cmd :
533+ await self .connect_to_job ()
534+
423535 return self .ip , self .port
424536
425537 async def stop (self , now = False ):
0 commit comments