2020 from ecephys_spike_sorting .scripts .create_input_json import createInputJson
2121 from ecephys_spike_sorting .scripts .helpers import SpikeGLX_utils
2222except Exception as e :
23- print (f'Error in loading "ecephys_spike_sorting" package - { str (e )} ' )
23+ print (f'Warning: Failed loading "ecephys_spike_sorting" package - { str (e )} ' )
2424
2525# import pykilosort package
2626try :
2727 import pykilosort
2828except Exception as e :
29- print (f'Error in loading "pykilosort" package - { str (e )} ' )
29+ print (f'Warning: Failed loading "pykilosort" package - { str (e )} ' )
3030
3131
3232class SGLXKilosortPipeline :
@@ -67,7 +67,6 @@ def __init__(
6767 ni_present = False ,
6868 ni_extract_string = None ,
6969 ):
70-
7170 self ._npx_input_dir = pathlib .Path (npx_input_dir )
7271
7372 self ._ks_output_dir = pathlib .Path (ks_output_dir )
@@ -85,6 +84,13 @@ def __init__(
8584 self ._json_directory = self ._ks_output_dir / "json_configs"
8685 self ._json_directory .mkdir (parents = True , exist_ok = True )
8786
87+ self ._module_input_json = (
88+ self ._json_directory / f"{ self ._npx_input_dir .name } -input.json"
89+ )
90+ self ._module_logfile = (
91+ self ._json_directory / f"{ self ._npx_input_dir .name } -run_modules-log.txt"
92+ )
93+
8894 self ._CatGT_finished = False
8995 self .ks_input_params = None
9096 self ._modules_input_hash = None
@@ -223,20 +229,20 @@ def generate_modules_input_json(self):
223229 ** params ,
224230 )
225231
226- self ._modules_input_hash = dict_to_uuid (self .ks_input_params )
232+ self ._modules_input_hash = dict_to_uuid (dict ( self ._params , KS2ver = self . _KS2ver ) )
227233
228- def run_modules (self ):
234+ def run_modules (self , modules_to_run = None ):
229235 if self ._run_CatGT and not self ._CatGT_finished :
230236 self .run_CatGT ()
231237
232238 print ("---- Running Modules ----" )
233239 self .generate_modules_input_json ()
234240 module_input_json = self ._module_input_json .as_posix ()
235- module_logfile = module_input_json . replace (
236- "-input.json" , "-run_modules-log.txt"
237- )
241+ module_logfile = self . _module_logfile . as_posix ()
242+
243+ modules = modules_to_run or self . _modules
238244
239- for module in self . _modules :
245+ for module in modules :
240246 module_status = self ._get_module_status (module )
241247 if module_status ["completion_time" ] is not None :
242248 continue
@@ -312,13 +318,11 @@ def _update_module_status(self, updated_module_status={}):
312318 else :
313319 # handle cases of processing rerun on different parameters (the hash changes)
314320 # delete outdated files
315- outdated_files = [
316- f
321+ [
322+ f . unlink ()
317323 for f in self ._json_directory .glob ("*" )
318324 if f .is_file () and f .name != self ._module_input_json .name
319325 ]
320- for f in outdated_files :
321- f .unlink ()
322326
323327 modules_status = {
324328 module : {"start_time" : None , "completion_time" : None , "duration" : None }
@@ -371,14 +375,26 @@ def _update_total_duration(self):
371375 for k , v in modules_status .items ()
372376 if k not in ("cumulative_execution_duration" , "total_duration" )
373377 )
378+
379+ for m in self ._modules :
380+ first_start_time = modules_status [m ]["start_time" ]
381+ if first_start_time is not None :
382+ break
383+
384+ for m in self ._modules [::- 1 ]:
385+ last_completion_time = modules_status [m ]["completion_time" ]
386+ if last_completion_time is not None :
387+ break
388+
389+ if first_start_time is None or last_completion_time is None :
390+ return
391+
374392 total_duration = (
375393 datetime .strptime (
376- modules_status [ self . _modules [ - 1 ]][ "completion_time" ] ,
394+ last_completion_time ,
377395 "%Y-%m-%d %H:%M:%S.%f" ,
378396 )
379- - datetime .strptime (
380- modules_status [self ._modules [0 ]]["start_time" ], "%Y-%m-%d %H:%M:%S.%f"
381- )
397+ - datetime .strptime (first_start_time , "%Y-%m-%d %H:%M:%S.%f" )
382398 ).total_seconds ()
383399 self ._update_module_status (
384400 {
@@ -414,7 +430,6 @@ class OpenEphysKilosortPipeline:
414430 def __init__ (
415431 self , npx_input_dir : str , ks_output_dir : str , params : dict , KS2ver : str
416432 ):
417-
418433 self ._npx_input_dir = pathlib .Path (npx_input_dir )
419434
420435 self ._ks_output_dir = pathlib .Path (ks_output_dir )
@@ -426,7 +441,13 @@ def __init__(
426441 self ._json_directory = self ._ks_output_dir / "json_configs"
427442 self ._json_directory .mkdir (parents = True , exist_ok = True )
428443
429- self ._median_subtraction_status = {}
444+ self ._module_input_json = (
445+ self ._json_directory / f"{ self ._npx_input_dir .name } -input.json"
446+ )
447+ self ._module_logfile = (
448+ self ._json_directory / f"{ self ._npx_input_dir .name } -run_modules-log.txt"
449+ )
450+
430451 self .ks_input_params = None
431452 self ._modules_input_hash = None
432453 self ._modules_input_hash_fp = None
@@ -451,9 +472,6 @@ def make_chanmap_file(self):
451472
452473 def generate_modules_input_json (self ):
453474 self .make_chanmap_file ()
454- self ._module_input_json = (
455- self ._json_directory / f"{ self ._npx_input_dir .name } -input.json"
456- )
457475
458476 continuous_file = self ._get_raw_data_filepaths ()
459477
@@ -497,35 +515,37 @@ def generate_modules_input_json(self):
497515 ** params ,
498516 )
499517
500- self ._modules_input_hash = dict_to_uuid (self .ks_input_params )
518+ self ._modules_input_hash = dict_to_uuid (dict ( self ._params , KS2ver = self . _KS2ver ) )
501519
502- def run_modules (self ):
520+ def run_modules (self , modules_to_run = None ):
503521 print ("---- Running Modules ----" )
504522 self .generate_modules_input_json ()
505523 module_input_json = self ._module_input_json .as_posix ()
506- module_logfile = module_input_json .replace (
507- "-input.json" , "-run_modules-log.txt"
508- )
524+ module_logfile = self ._module_logfile .as_posix ()
509525
510- for module in self ._modules :
526+ modules = modules_to_run or self ._modules
527+
528+ for module in modules :
511529 module_status = self ._get_module_status (module )
512530 if module_status ["completion_time" ] is not None :
513531 continue
514532
515- if module == "median_subtraction" and self ._median_subtraction_status :
516- median_subtraction_status = self ._get_module_status (
517- "median_subtraction"
518- )
519- median_subtraction_status ["duration" ] = self ._median_subtraction_status [
520- "duration"
521- ]
522- median_subtraction_status ["completion_time" ] = datetime .strptime (
523- median_subtraction_status ["start_time" ], "%Y-%m-%d %H:%M:%S.%f"
524- ) + timedelta (seconds = median_subtraction_status ["duration" ])
525- self ._update_module_status (
526- {"median_subtraction" : median_subtraction_status }
533+ if module == "median_subtraction" :
534+ median_subtraction_duration = (
535+ self ._get_median_subtraction_duration_from_log ()
527536 )
528- continue
537+ if median_subtraction_duration is not None :
538+ median_subtraction_status = self ._get_module_status (
539+ "median_subtraction"
540+ )
541+ median_subtraction_status ["duration" ] = median_subtraction_duration
542+ median_subtraction_status ["completion_time" ] = datetime .strptime (
543+ median_subtraction_status ["start_time" ], "%Y-%m-%d %H:%M:%S.%f"
544+ ) + timedelta (seconds = median_subtraction_status ["duration" ])
545+ self ._update_module_status (
546+ {"median_subtraction" : median_subtraction_status }
547+ )
548+ continue
529549
530550 module_output_json = self ._get_module_output_json_filename (module )
531551 command = [
@@ -576,26 +596,11 @@ def _get_raw_data_filepaths(self):
576596 assert "depth_estimation" in self ._modules
577597 continuous_file = self ._ks_output_dir / "continuous.dat"
578598 if continuous_file .exists ():
579- if raw_ap_fp .stat ().st_mtime < continuous_file .stat ().st_mtime :
580- # if the copied continuous.dat was actually modified,
581- # median_subtraction may have been completed - let's check
582- module_input_json = self ._module_input_json .as_posix ()
583- module_logfile = module_input_json .replace (
584- "-input.json" , "-run_modules-log.txt"
585- )
586- with open (module_logfile , "r" ) as f :
587- previous_line = ""
588- for line in f .readlines ():
589- if line .startswith (
590- "ecephys spike sorting: median subtraction module"
591- ) and previous_line .startswith ("Total processing time:" ):
592- # regex to search for the processing duration - a float value
593- duration = int (
594- re .search ("\d+\.?\d+" , previous_line ).group ()
595- )
596- self ._median_subtraction_status ["duration" ] = duration
597- return continuous_file
598- previous_line = line
599+ if raw_ap_fp .stat ().st_mtime == continuous_file .stat ().st_mtime :
600+ return continuous_file
601+ else :
602+ if self ._module_logfile .exists ():
603+ return continuous_file
599604
600605 shutil .copy2 (raw_ap_fp , continuous_file )
601606 return continuous_file
@@ -614,13 +619,11 @@ def _update_module_status(self, updated_module_status={}):
614619 else :
615620 # handle cases of processing rerun on different parameters (the hash changes)
616621 # delete outdated files
617- outdated_files = [
618- f
622+ [
623+ f . unlink ()
619624 for f in self ._json_directory .glob ("*" )
620625 if f .is_file () and f .name != self ._module_input_json .name
621626 ]
622- for f in outdated_files :
623- f .unlink ()
624627
625628 modules_status = {
626629 module : {"start_time" : None , "completion_time" : None , "duration" : None }
@@ -673,14 +676,26 @@ def _update_total_duration(self):
673676 for k , v in modules_status .items ()
674677 if k not in ("cumulative_execution_duration" , "total_duration" )
675678 )
679+
680+ for m in self ._modules :
681+ first_start_time = modules_status [m ]["start_time" ]
682+ if first_start_time is not None :
683+ break
684+
685+ for m in self ._modules [::- 1 ]:
686+ last_completion_time = modules_status [m ]["completion_time" ]
687+ if last_completion_time is not None :
688+ break
689+
690+ if first_start_time is None or last_completion_time is None :
691+ return
692+
676693 total_duration = (
677694 datetime .strptime (
678- modules_status [ self . _modules [ - 1 ]][ "completion_time" ] ,
695+ last_completion_time ,
679696 "%Y-%m-%d %H:%M:%S.%f" ,
680697 )
681- - datetime .strptime (
682- modules_status [self ._modules [0 ]]["start_time" ], "%Y-%m-%d %H:%M:%S.%f"
683- )
698+ - datetime .strptime (first_start_time , "%Y-%m-%d %H:%M:%S.%f" )
684699 ).total_seconds ()
685700 self ._update_module_status (
686701 {
@@ -689,6 +704,26 @@ def _update_total_duration(self):
689704 }
690705 )
691706
707+ def _get_median_subtraction_duration_from_log (self ):
708+ raw_ap_fp = self ._npx_input_dir / "continuous.dat"
709+ continuous_file = self ._ks_output_dir / "continuous.dat"
710+ if raw_ap_fp .stat ().st_mtime < continuous_file .stat ().st_mtime :
711+ # if the copied continuous.dat was actually modified,
712+ # median_subtraction may have been completed - let's check
713+ if self ._module_logfile .exists ():
714+ with open (self ._module_logfile , "r" ) as f :
715+ previous_line = ""
716+ for line in f .readlines ():
717+ if line .startswith (
718+ "ecephys spike sorting: median subtraction module"
719+ ) and previous_line .startswith ("Total processing time:" ):
720+ # regex to search for the processing duration - a float value
721+ duration = int (
722+ re .search ("\d+\.?\d+" , previous_line ).group ()
723+ )
724+ return duration
725+ previous_line = line
726+
692727
693728def run_pykilosort (
694729 continuous_file ,
0 commit comments