diff --git a/mesmerize_core/algorithms/_utils.py b/mesmerize_core/algorithms/_utils.py
new file mode 100644
index 0000000..0bdf1a8
--- /dev/null
+++ b/mesmerize_core/algorithms/_utils.py
@@ -0,0 +1,76 @@
+from contextlib import contextmanager
+import os
+import psutil
+from typing import (Optional, Union, Generator, Protocol,
+                    Callable, TypeVar, Sequence, Iterable, runtime_checkable)
+
+import caiman as cm
+from caiman.cluster import setup_cluster
+from ipyparallel import DirectView
+from multiprocessing.pool import Pool
+
+
+RetVal = TypeVar("RetVal")
+
+
+@runtime_checkable
+class CustomCluster(Protocol):
+    """
+    Protocol for a cluster that is not a multiprocessing pool
+    (including ipyparallel.DirectView)
+    """
+
+    def map_sync(
+        self, fn: Callable[..., RetVal], args: Iterable
+    ) -> Sequence[RetVal]: ...
+
+    def __len__(self) -> int:
+        """return number of workers"""
+        ...
+
+
+Cluster = Union[Pool, CustomCluster, DirectView]
+
+
+def get_n_processes(dview: Optional[Cluster]) -> int:
+    """Infer number of processes in a multiprocessing or ipyparallel cluster"""
+    if isinstance(dview, Pool):
+        assert hasattr(dview, '_processes'), "Pool not keeping track of # of processes?"
+        return dview._processes  # type: ignore
+    elif dview is not None:
+        return len(dview)
+    else:
+        return 1
+
+
+@contextmanager
+def ensure_server(dview: Optional[Cluster]) -> Generator[tuple[Cluster, int], None, None]:
+    """
+    Context manager that passes through an existing 'dview' or
+    opens up a multiprocessing server if none is passed in.
+    If a server was opened, closes it upon exit.
+    Usage: `with ensure_server(dview) as (dview, n_processes):`
+    """
+    if dview is not None:
+        yield dview, get_n_processes(dview)
+    else:
+        # no cluster passed in, so open one
+        procs_available = psutil.cpu_count()
+        if procs_available is None:
+            raise RuntimeError('Cannot determine number of processes')
+
+        if "MESMERIZE_N_PROCESSES" in os.environ.keys():
+            try:
+                n_processes = int(os.environ["MESMERIZE_N_PROCESSES"])
+            except ValueError:  # non-integer value in the environment; fall back
+                n_processes = procs_available - 1
+        else:
+            n_processes = procs_available - 1
+
+        # Start cluster for parallel processing
+        _, dview, n_processes = setup_cluster(
+            backend="multiprocessing", n_processes=n_processes, single_thread=False
+        )
+        assert isinstance(dview, Pool) and isinstance(n_processes, int), \
+            'setup_cluster with multiprocessing did not return a Pool'
+        try:
+            yield dview, n_processes
+        finally:
+            cm.stop_server(dview=dview)
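
For review context, this is how the helper is meant to be consumed — a minimal sketch, where `my_algo` is hypothetical:

    from mesmerize_core.algorithms._utils import ensure_server

    def my_algo(batch_path, uuid, data_path=None, dview=None):
        # reuse the caller's cluster if one was passed in; otherwise open a
        # multiprocessing pool that is closed again when the block exits
        with ensure_server(dview) as (dview, n_processes):
            ...  # all parallel work happens here
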
diff --git a/mesmerize_core/algorithms/cnmf.py b/mesmerize_core/algorithms/cnmf.py
index dd7381a..9328cce 100644
--- a/mesmerize_core/algorithms/cnmf.py
+++ b/mesmerize_core/algorithms/cnmf.py
@@ -4,24 +4,24 @@
 import caiman as cm
 from caiman.source_extraction.cnmf import cnmf as cnmf
 from caiman.source_extraction.cnmf.params import CNMFParams
-import psutil
 import numpy as np
 import traceback
 from pathlib import Path, PurePosixPath
 from shutil import move as move_file
-import os
 import time
 
 # prevent circular import
 if __name__ in ["__main__", "__mp_main__"]:  # when running in subprocess
     from mesmerize_core import set_parent_raw_data_path, load_batch
     from mesmerize_core.utils import IS_WINDOWS
+    from mesmerize_core.algorithms._utils import ensure_server
 else:  # when running with local backend
     from ..batch_utils import set_parent_raw_data_path, load_batch
     from ..utils import IS_WINDOWS
+    from ._utils import ensure_server
 
 
-def run_algo(batch_path, uuid, data_path: str = None):
+def run_algo(batch_path, uuid, data_path: str = None, dview=None):
     algo_start = time.time()
     set_parent_raw_data_path(data_path)
 
@@ -42,103 +42,85 @@ def run_algo(batch_path, uuid, data_path: str = None):
         f"Starting CNMF item:\n{item}\nWith params:{params}"
     )
 
-    # adapted from current demo notebook
-    if "MESMERIZE_N_PROCESSES" in os.environ.keys():
+    with ensure_server(dview) as (dview, n_processes):
+
+        # merge cnmf and eval kwargs into one dict
+        cnmf_params = CNMFParams(params_dict=params["main"])
+        # Run CNMF, denote boolean 'success' if CNMF completes w/out error
         try:
-            n_processes = int(os.environ["MESMERIZE_N_PROCESSES"])
-        except:
-            n_processes = psutil.cpu_count() - 1
-    else:
-        n_processes = psutil.cpu_count() - 1
-    # Start cluster for parallel processing
-    c, dview, n_processes = cm.cluster.setup_cluster(
-        backend="local", n_processes=n_processes, single_thread=False
-    )
+            fname_new = cm.save_memmap(
+                [input_movie_path], base_name=f"{uuid}_cnmf-memmap_", order="C", dview=dview
+            )
 
-    # merge cnmf and eval kwargs into one dict
-    cnmf_params = CNMFParams(params_dict=params["main"])
-    # Run CNMF, denote boolean 'success' if CNMF completes w/out error
-    try:
-        fname_new = cm.save_memmap(
-            [input_movie_path], base_name=f"{uuid}_cnmf-memmap_", order="C", dview=dview
-        )
+            print("making memmap")
 
-        print("making memmap")
+            Yr, dims, T = cm.load_memmap(fname_new)
 
-        Yr, dims, T = cm.load_memmap(fname_new)
-        images = np.reshape(Yr.T, [T] + list(dims), order="F")
+            images = np.reshape(Yr.T, [T] + list(dims), order="F")
 
-        proj_paths = dict()
-        for proj_type in ["mean", "std", "max"]:
-            p_img = getattr(np, f"nan{proj_type}")(images, axis=0)
-            proj_paths[proj_type] = output_dir.joinpath(
-                f"{uuid}_{proj_type}_projection.npy"
-            )
-            np.save(str(proj_paths[proj_type]), p_img)
-
-        # in fname new load in memmap order C
-        cm.stop_server(dview=dview)
-        c, dview, n_processes = cm.cluster.setup_cluster(
-            backend="local", n_processes=None, single_thread=False
-        )
-
-        print("performing CNMF")
-        cnm = cnmf.CNMF(n_processes, params=cnmf_params, dview=dview)
-
-        print("fitting images")
-        cnm.fit(images)
-        #
-        if "refit" in params.keys():
-            if params["refit"] is True:
-                print("refitting")
-                cnm = cnm.refit(images, dview=dview)
-
-        print("performing eval")
-        cnm.estimates.evaluate_components(images, cnm.params, dview=dview)
-
-        output_path = output_dir.joinpath(f"{uuid}.hdf5")
-
-        cnm.save(str(output_path))
-
-        Cn = cm.local_correlations(images, swap_dim=False)
-        Cn[np.isnan(Cn)] = 0
-
-        corr_img_path = output_dir.joinpath(f"{uuid}_cn.npy")
-        np.save(str(corr_img_path), Cn, allow_pickle=False)
-
-        # output dict for dataframe row (pd.Series)
-        d = dict()
-
-        cnmf_memmap_path = output_dir.joinpath(Path(fname_new).name)
-        if IS_WINDOWS:
-            Yr._mmap.close()  # accessing private attr but windows is annoying otherwise
-        move_file(fname_new, cnmf_memmap_path)
-
-        # save paths as relative path strings with forward slashes
-        cnmf_hdf5_path = str(PurePosixPath(output_path.relative_to(output_dir.parent)))
-        cnmf_memmap_path = str(
-            PurePosixPath(cnmf_memmap_path.relative_to(output_dir.parent))
-        )
-        corr_img_path = str(PurePosixPath(corr_img_path.relative_to(output_dir.parent)))
-        for proj_type in proj_paths.keys():
-            d[f"{proj_type}-projection-path"] = str(
-                PurePosixPath(proj_paths[proj_type].relative_to(output_dir.parent))
-            )
+            proj_paths = dict()
+            for proj_type in ["mean", "std", "max"]:
+                p_img = getattr(np, f"nan{proj_type}")(images, axis=0)
+                proj_paths[proj_type] = output_dir.joinpath(
+                    f"{uuid}_{proj_type}_projection.npy"
+                )
+                np.save(str(proj_paths[proj_type]), p_img)
+
+            print("performing CNMF")
+            cnm = cnmf.CNMF(n_processes, params=cnmf_params, dview=dview)
+
+            print("fitting images")
+            cnm.fit(images)
+            #
+            if "refit" in params.keys():
+                if params["refit"] is True:
+                    print("refitting")
+                    cnm = cnm.refit(images, dview=dview)
+
+            print("performing eval")
+            cnm.estimates.evaluate_components(images, cnm.params, dview=dview)
 
-        d.update(
-            {
-                "cnmf-hdf5-path": cnmf_hdf5_path,
-                "cnmf-memmap-path": cnmf_memmap_path,
-                "corr-img-path": corr_img_path,
-                "success": True,
-                "traceback": None,
-            }
-        )
+            output_path = output_dir.joinpath(f"{uuid}.hdf5")
 
-    except:
-        d = {"success": False, "traceback": traceback.format_exc()}
+            cnm.save(str(output_path))
 
-    cm.stop_server(dview=dview)
+            Cn = cm.local_correlations(images, swap_dim=False)
+            Cn[np.isnan(Cn)] = 0
+
+            corr_img_path = output_dir.joinpath(f"{uuid}_cn.npy")
+            np.save(str(corr_img_path), Cn, allow_pickle=False)
+
+            # output dict for dataframe row (pd.Series)
+            d = dict()
+
+            cnmf_memmap_path = output_dir.joinpath(Path(fname_new).name)
+            if IS_WINDOWS:
+                Yr._mmap.close()  # accessing private attr but windows is annoying otherwise
+            move_file(fname_new, cnmf_memmap_path)
+
+            # save paths as relative path strings with forward slashes
+            cnmf_hdf5_path = str(PurePosixPath(output_path.relative_to(output_dir.parent)))
+            cnmf_memmap_path = str(
+                PurePosixPath(cnmf_memmap_path.relative_to(output_dir.parent))
+            )
+            corr_img_path = str(PurePosixPath(corr_img_path.relative_to(output_dir.parent)))
+            for proj_type in proj_paths.keys():
+                d[f"{proj_type}-projection-path"] = str(
+                    PurePosixPath(proj_paths[proj_type].relative_to(output_dir.parent))
+                )
+
+            d.update(
+                {
+                    "cnmf-hdf5-path": cnmf_hdf5_path,
+                    "cnmf-memmap-path": cnmf_memmap_path,
+                    "corr-img-path": corr_img_path,
+                    "success": True,
+                    "traceback": None,
+                }
+            )
+
+        except:
+            d = {"success": False, "traceback": traceback.format_exc()}
 
     runtime = round(time.time() - algo_start, 2)
     df.caiman.update_item_with_results(uuid, d, runtime)
diff --git a/mesmerize_core/algorithms/cnmfe.py b/mesmerize_core/algorithms/cnmfe.py
index d4f1858..3940727 100644
--- a/mesmerize_core/algorithms/cnmfe.py
+++ b/mesmerize_core/algorithms/cnmfe.py
@@ -3,22 +3,22 @@
 import caiman as cm
 from caiman.source_extraction.cnmf import cnmf as cnmf
 from caiman.source_extraction.cnmf.params import CNMFParams
-import psutil
 import traceback
 from pathlib import Path, PurePosixPath
 from shutil import move as move_file
-import os
 import time
 
 if __name__ in ["__main__", "__mp_main__"]:  # when running in subprocess
     from mesmerize_core import set_parent_raw_data_path, load_batch
     from mesmerize_core.utils import IS_WINDOWS
+    from mesmerize_core.algorithms._utils import ensure_server
 else:  # when running with local backend
     from ..batch_utils import set_parent_raw_data_path, load_batch
     from ..utils import IS_WINDOWS
+    from ._utils import ensure_server
 
 
-def run_algo(batch_path, uuid, data_path: str = None):
+def run_algo(batch_path, uuid, data_path: str = None, dview=None):
     algo_start = time.time()
     set_parent_raw_data_path(data_path)
 
@@ -35,91 +35,77 @@ def run_algo(batch_path, uuid, data_path: str = None):
     params = item["params"]
     print("cnmfe params:", params)
 
-    # adapted from current demo notebook
-    if "MESMERIZE_N_PROCESSES" in os.environ.keys():
+    with ensure_server(dview) as (dview, n_processes):
         try:
-            n_processes = int(os.environ["MESMERIZE_N_PROCESSES"])
-        except:
-            n_processes = psutil.cpu_count() - 1
-    else:
-        n_processes = psutil.cpu_count() - 1
-    # Start cluster for parallel processing
-    c, dview, n_processes = cm.cluster.setup_cluster(
-        backend="local", n_processes=n_processes, single_thread=False
-    )
-
-    try:
-        fname_new = cm.save_memmap(
-            [input_movie_path], base_name=f"{uuid}_cnmf-memmap_", order="C", dview=dview
-        )
-
-        print("making memmap")
-        Yr, dims, T = cm.load_memmap(fname_new)
-        images = np.reshape(Yr.T, [T] + list(dims), order="F")
-
-        # TODO: if projections already exist from mcorr we don't
-        # need to waste compute time re-computing them here
-        proj_paths = dict()
-        for proj_type in ["mean", "std", "max"]:
-            p_img = getattr(np, f"nan{proj_type}")(images, axis=0)
-            proj_paths[proj_type] = output_dir.joinpath(
-                f"{uuid}_{proj_type}_projection.npy"
+            fname_new = cm.save_memmap(
+                [input_movie_path], base_name=f"{uuid}_cnmf-memmap_", order="C", dview=dview
             )
-            np.save(str(proj_paths[proj_type]), p_img)
-
-        d = dict()  # for output
 
-        # force the CNMFE params
-        cnmfe_params_dict = {
-            "method_init": "corr_pnr",
-            "n_processes": n_processes,
-            "only_init": True,  # for 1p
-            "center_psf": True,  # for 1p
-            "normalize_init": False,  # for 1p
-        }
+            print("making memmap")
+            Yr, dims, T = cm.load_memmap(fname_new)
+            images = np.reshape(Yr.T, [T] + list(dims), order="F")
+
+            # TODO: if projections already exist from mcorr we don't
+            # need to waste compute time re-computing them here
+            proj_paths = dict()
+            for proj_type in ["mean", "std", "max"]:
+                p_img = getattr(np, f"nan{proj_type}")(images, axis=0)
+                proj_paths[proj_type] = output_dir.joinpath(
+                    f"{uuid}_{proj_type}_projection.npy"
+                )
+                np.save(str(proj_paths[proj_type]), p_img)
+
+            d = dict()  # for output
+
+            # force the CNMFE params
+            cnmfe_params_dict = {
+                "method_init": "corr_pnr",
+                "n_processes": n_processes,
+                "only_init": True,  # for 1p
+                "center_psf": True,  # for 1p
+                "normalize_init": False,  # for 1p
+            }
 
-        params_dict = {**cnmfe_params_dict, **params["main"]}
+            params_dict = {**cnmfe_params_dict, **params["main"]}
 
-        cnmfe_params_dict = CNMFParams(params_dict=params_dict)
-        cnm = cnmf.CNMF(n_processes=n_processes, dview=dview, params=cnmfe_params_dict)
-        print("Performing CNMFE")
-        cnm.fit(images)
-        print("evaluating components")
-        cnm.estimates.evaluate_components(images, cnm.params, dview=dview)
+            cnmfe_params_dict = CNMFParams(params_dict=params_dict)
+            cnm = cnmf.CNMF(n_processes=n_processes, dview=dview, params=cnmfe_params_dict)
+            print("Performing CNMFE")
+            cnm.fit(images)
+            print("evaluating components")
+            cnm.estimates.evaluate_components(images, cnm.params, dview=dview)
 
-        cnmf_hdf5_path = output_dir.joinpath(f"{uuid}.hdf5")
-        cnm.save(str(cnmf_hdf5_path))
+            cnmf_hdf5_path = output_dir.joinpath(f"{uuid}.hdf5")
+            cnm.save(str(cnmf_hdf5_path))
 
-        # save output paths to outputs dict
-        d["cnmf-hdf5-path"] = cnmf_hdf5_path.relative_to(output_dir.parent)
+            # save output paths to outputs dict
+            d["cnmf-hdf5-path"] = cnmf_hdf5_path.relative_to(output_dir.parent)
 
-        for proj_type in proj_paths.keys():
-            d[f"{proj_type}-projection-path"] = proj_paths[proj_type].relative_to(
-                output_dir.parent
-            )
+            for proj_type in proj_paths.keys():
+                d[f"{proj_type}-projection-path"] = proj_paths[proj_type].relative_to(
+                    output_dir.parent
+                )
 
-        cnmf_memmap_path = output_dir.joinpath(Path(fname_new).name)
-        if IS_WINDOWS:
-            Yr._mmap.close()  # accessing private attr but windows is annoying otherwise
-        move_file(fname_new, cnmf_memmap_path)
+            cnmf_memmap_path = output_dir.joinpath(Path(fname_new).name)
+            if IS_WINDOWS:
+                Yr._mmap.close()  # accessing private attr but windows is annoying otherwise
+            move_file(fname_new, cnmf_memmap_path)
 
-        # save path as relative path strings with forward slashes
-        cnmfe_memmap_path = str(
-            PurePosixPath(cnmf_memmap_path.relative_to(output_dir.parent))
-        )
-
-        d.update(
-            {
-                "cnmf-memmap-path": cnmfe_memmap_path,
-                "success": True,
-                "traceback": None,
-            }
-        )
+            # save path as relative path strings with forward slashes
+            cnmfe_memmap_path = str(
+                PurePosixPath(cnmf_memmap_path.relative_to(output_dir.parent))
+            )
 
-    except:
-        d = {"success": False, "traceback": traceback.format_exc()}
+            d.update(
+                {
+                    "cnmf-memmap-path": cnmfe_memmap_path,
+                    "success": True,
+                    "traceback": None,
+                }
+            )
 
-    cm.stop_server(dview=dview)
+        except:
+            d = {"success": False, "traceback": traceback.format_exc()}
 
     runtime = round(time.time() - algo_start, 2)
     df.caiman.update_item_with_results(uuid, d, runtime)
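
A subtlety worth calling out in `params_dict = {**cnmfe_params_dict, **params["main"]}`: the right-hand unpacking wins, so user-supplied keys quietly override the "forced" CNMFE defaults. Illustrative values only:

    forced = {"method_init": "corr_pnr", "only_init": True}
    user = {"only_init": False}
    {**forced, **user}  # -> {"method_init": "corr_pnr", "only_init": False}
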
diff --git a/mesmerize_core/algorithms/mcorr.py b/mesmerize_core/algorithms/mcorr.py
index 484130d..068faa3 100644
--- a/mesmerize_core/algorithms/mcorr.py
+++ b/mesmerize_core/algorithms/mcorr.py
@@ -4,7 +4,6 @@
 from caiman.source_extraction.cnmf.params import CNMFParams
 from caiman.motion_correction import MotionCorrect
 from caiman.summary_images import local_correlations_movie_offline
-import psutil
 import os
 from pathlib import Path, PurePosixPath
 import numpy as np
@@ -14,11 +13,13 @@
 # prevent circular import
 if __name__ in ["__main__", "__mp_main__"]:  # when running in subprocess
     from mesmerize_core import set_parent_raw_data_path, load_batch
+    from mesmerize_core.algorithms._utils import ensure_server
 else:  # when running with local backend
     from ..batch_utils import set_parent_raw_data_path, load_batch
+    from ._utils import ensure_server
 
 
-def run_algo(batch_path, uuid, data_path: str = None):
+def run_algo(batch_path, uuid, data_path: str = None, dview=None):
     algo_start = time.time()
     set_parent_raw_data_path(data_path)
 
@@ -39,115 +40,101 @@ def run_algo(batch_path, uuid, data_path: str = None):
     params = item["params"]
 
-    # adapted from current demo notebook
-    if "MESMERIZE_N_PROCESSES" in os.environ.keys():
+    with ensure_server(dview) as (dview, n_processes):
+        print("starting mc")
+
+        rel_params = dict(params["main"])
+        opts = CNMFParams(params_dict=rel_params)
+        # Run MC, denote boolean 'success' if MC completes w/out error
         try:
-            n_processes = int(os.environ["MESMERIZE_N_PROCESSES"])
-        except:
-            n_processes = psutil.cpu_count() - 1
-    else:
-        n_processes = psutil.cpu_count() - 1
-
-    print("starting mc")
-    # Start cluster for parallel processing
-    c, dview, n_processes = cm.cluster.setup_cluster(
-        backend="local", n_processes=n_processes, single_thread=False
-    )
-
-    rel_params = dict(params["main"])
-    opts = CNMFParams(params_dict=rel_params)
-    # Run MC, denote boolean 'success' if MC completes w/out error
-    try:
-        # Run MC
-        fnames = [input_movie_path]
-        mc = MotionCorrect(fnames, dview=dview, **opts.get_group("motion"))
-        mc.motion_correct(save_movie=True)
-
-        # find path to mmap file
-        memmap_output_path_temp = df.paths.resolve(mc.mmap_file[0])
-
-        # filename to move the output back to data dir
-        mcorr_memmap_path = output_dir.joinpath(
-            f"{uuid}-{memmap_output_path_temp.name}"
-        )
-
-        # move the output file
-        move_file(memmap_output_path_temp, mcorr_memmap_path)
-
-        print("mc finished successfully!")
-
-        print("computing projections")
-        Yr, dims, T = cm.load_memmap(str(mcorr_memmap_path))
-        images = np.reshape(Yr.T, [T] + list(dims), order="F")
-
-        proj_paths = dict()
-        for proj_type in ["mean", "std", "max"]:
-            p_img = getattr(np, f"nan{proj_type}")(images, axis=0)
-            proj_paths[proj_type] = output_dir.joinpath(
-                f"{uuid}_{proj_type}_projection.npy"
+            # Run MC
+            fnames = [input_movie_path]
+            mc = MotionCorrect(fnames, dview=dview, **opts.get_group("motion"))
+            mc.motion_correct(save_movie=True)
+
+            # find path to mmap file
+            memmap_output_path_temp = df.paths.resolve(mc.mmap_file[0])
+
+            # filename to move the output back to data dir
+            mcorr_memmap_path = output_dir.joinpath(
+                f"{uuid}-{memmap_output_path_temp.name}"
             )
-            np.save(str(proj_paths[proj_type]), p_img)
-
-        print("Computing correlation image")
-        Cns = local_correlations_movie_offline(
-            str(mcorr_memmap_path),
-            remove_baseline=True,
-            window=1000,
-            stride=1000,
-            winSize_baseline=100,
-            quantil_min_baseline=10,
-            dview=dview,
-        )
-        Cn = Cns.max(axis=0)
-        Cn[np.isnan(Cn)] = 0
-        cn_path = output_dir.joinpath(f"{uuid}_cn.npy")
-        np.save(str(cn_path), Cn, allow_pickle=False)
-
-        print("finished computing correlation image")
-
-        # Compute shifts
-        if opts.motion["pw_rigid"] == True:
-            x_shifts = mc.x_shifts_els
-            y_shifts = mc.y_shifts_els
-            shifts = [x_shifts, y_shifts]
-            if hasattr(mc, "z_shifts_els"):
-                shifts.append(mc.z_shifts_els)
-            shift_path = output_dir.joinpath(f"{uuid}_shifts.npy")
-            np.save(str(shift_path), shifts)
-        else:
-            shifts = mc.shifts_rig
-            shift_path = output_dir.joinpath(f"{uuid}_shifts.npy")
-            np.save(str(shift_path), shifts)
-
-        # output dict for pandas series for dataframe row
-        d = dict()
-
-        # save paths as relative path strings with forward slashes
-        cn_path = str(PurePosixPath(cn_path.relative_to(output_dir.parent)))
-        mcorr_memmap_path = str(
-            PurePosixPath(mcorr_memmap_path.relative_to(output_dir.parent))
-        )
-        shift_path = str(PurePosixPath(shift_path.relative_to(output_dir.parent)))
-        for proj_type in proj_paths.keys():
-            d[f"{proj_type}-projection-path"] = str(
-                PurePosixPath(proj_paths[proj_type].relative_to(output_dir.parent))
+
+            # move the output file
+            move_file(memmap_output_path_temp, mcorr_memmap_path)
+
+            print("mc finished successfully!")
+
+            print("computing projections")
+            Yr, dims, T = cm.load_memmap(str(mcorr_memmap_path))
+            images = np.reshape(Yr.T, [T] + list(dims), order="F")
+
+            proj_paths = dict()
+            for proj_type in ["mean", "std", "max"]:
+                p_img = getattr(np, f"nan{proj_type}")(images, axis=0)
+                proj_paths[proj_type] = output_dir.joinpath(
+                    f"{uuid}_{proj_type}_projection.npy"
+                )
+                np.save(str(proj_paths[proj_type]), p_img)
+
+            print("Computing correlation image")
+            Cns = local_correlations_movie_offline(
+                str(mcorr_memmap_path),
+                remove_baseline=True,
+                window=1000,
+                stride=1000,
+                winSize_baseline=100,
+                quantil_min_baseline=10,
+                dview=dview,
+            )
+            Cn = Cns.max(axis=0)
+            Cn[np.isnan(Cn)] = 0
+            cn_path = output_dir.joinpath(f"{uuid}_cn.npy")
+            np.save(str(cn_path), Cn, allow_pickle=False)
+
+            print("finished computing correlation image")
+
+            # Compute shifts
+            if opts.motion["pw_rigid"] == True:
+                x_shifts = mc.x_shifts_els
+                y_shifts = mc.y_shifts_els
+                shifts = [x_shifts, y_shifts]
+                if hasattr(mc, "z_shifts_els"):
+                    shifts.append(mc.z_shifts_els)
+                shift_path = output_dir.joinpath(f"{uuid}_shifts.npy")
+                np.save(str(shift_path), shifts)
+            else:
+                shifts = mc.shifts_rig
+                shift_path = output_dir.joinpath(f"{uuid}_shifts.npy")
+                np.save(str(shift_path), shifts)
+
+            # output dict for pandas series for dataframe row
+            d = dict()
+
+            # save paths as relative path strings with forward slashes
+            cn_path = str(PurePosixPath(cn_path.relative_to(output_dir.parent)))
+            mcorr_memmap_path = str(
+                PurePosixPath(mcorr_memmap_path.relative_to(output_dir.parent))
+            )
+            shift_path = str(PurePosixPath(shift_path.relative_to(output_dir.parent)))
+            for proj_type in proj_paths.keys():
+                d[f"{proj_type}-projection-path"] = str(
+                    PurePosixPath(proj_paths[proj_type].relative_to(output_dir.parent))
+                )
+
+            d.update(
+                {
+                    "mcorr-output-path": mcorr_memmap_path,
+                    "corr-img-path": cn_path,
+                    "shifts": shift_path,
+                    "success": True,
+                    "traceback": None,
+                }
             )
 
-        d.update(
-            {
-                "mcorr-output-path": mcorr_memmap_path,
-                "corr-img-path": cn_path,
-                "shifts": shift_path,
-                "success": True,
-                "traceback": None,
-            }
-        )
-
-    except:
-        d = {"success": False, "traceback": traceback.format_exc()}
-        print("mc failed, stored traceback in output")
-
-    cm.stop_server(dview=dview)
+        except:
+            d = {"success": False, "traceback": traceback.format_exc()}
+            print("mc failed, stored traceback in output")
 
     runtime = round(time.time() - algo_start, 2)
     df.caiman.update_item_with_results(uuid, d, runtime)
diff --git a/mesmerize_core/batch_utils.py b/mesmerize_core/batch_utils.py
index 96a2bd7..9766172 100644
--- a/mesmerize_core/batch_utils.py
+++ b/mesmerize_core/batch_utils.py
@@ -11,11 +11,13 @@
 COMPUTE_BACKEND_SUBPROCESS = "subprocess"  #: subprocess backend
 COMPUTE_BACKEND_SLURM = "slurm"  #: SLURM backend
 COMPUTE_BACKEND_LOCAL = "local"
+COMPUTE_BACKEND_ASYNC = "local_async"
 
 COMPUTE_BACKENDS = [
     COMPUTE_BACKEND_SUBPROCESS,
     COMPUTE_BACKEND_SLURM,
     COMPUTE_BACKEND_LOCAL,
+    COMPUTE_BACKEND_ASYNC,
 ]
 
 DATAFRAME_COLUMNS = [
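
Note that the value `"local_async"` is load-bearing: `run()` in common.py (below) dispatches with `getattr(self, f"_run_{backend}")`, so the constant must line up with the `_run_local_async` method name:

    backend = COMPUTE_BACKEND_ASYNC   # "local_async"
    f"_run_{backend}"                 # -> "_run_local_async"
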
diff --git a/mesmerize_core/caiman_extensions/common.py b/mesmerize_core/caiman_extensions/common.py
index 8e2b4cd..72483bc 100644
--- a/mesmerize_core/caiman_extensions/common.py
+++ b/mesmerize_core/caiman_extensions/common.py
@@ -2,13 +2,14 @@
 import shutil
 from pathlib import Path
 import psutil
-from subprocess import Popen
+from subprocess import Popen, CalledProcessError
 from typing import *
 from uuid import UUID, uuid4
 from shutil import rmtree
 from datetime import datetime
 import time
 from copy import deepcopy
+from concurrent.futures import ThreadPoolExecutor, Future
 
 import numpy as np
 import pandas as pd
@@ -25,6 +26,7 @@
     COMPUTE_BACKENDS,
     COMPUTE_BACKEND_SUBPROCESS,
     COMPUTE_BACKEND_LOCAL,
+    COMPUTE_BACKEND_ASYNC,
     get_parent_raw_data_path,
     load_batch,
 )
@@ -513,13 +515,37 @@ def get_parent(self, index: Union[int, str, UUID]) -> Union[UUID, None]:
             return r["uuid"]
 
 
-class DummyProcess:
-    """Dummy process for local backend"""
+class Waitable(Protocol):
+    """An object that we can call "wait" on"""
+    def wait(self) -> None: ...
+
 
-    def wait(self):
+class DummyProcess(Waitable):
+    """Dummy process for local backend"""
+    def wait(self) -> None:
         pass
 
 
+class WaitableFuture(Waitable):
+    """Adaptor for future returned from Executor.submit"""
+    def __init__(self, future: Future[None]):
+        self.future = future
+
+    def wait(self) -> None:
+        return self.future.result()
+
+
+class CheckedSubprocess(Waitable):
+    """Adaptor for Popen that just raises an exception if the return code is nonzero"""
+    def __init__(self, popen: Popen):
+        self.popen = popen
+
+    def wait(self) -> None:
+        rc = self.popen.wait()
+        if rc != 0:
+            raise CalledProcessError(rc, self.popen.args)
+
+
 @pd.api.extensions.register_series_accessor("caiman")
 class CaimanSeriesExtensions:
     """
@@ -528,36 +554,54 @@ class CaimanSeriesExtensions:
 
     def __init__(self, s: pd.Series):
         self._series = s
-        self.process: Popen = None
+        self.process: Optional[Waitable] = None
 
     def _run_local(
-        self,
-        algo: str,
-        batch_path: Path,
-        uuid: UUID,
-        data_path: Union[Path, None],
-    ):
+            self,
+            algo: str,
+            batch_path: Path,
+            uuid: UUID,
+            data_path: Union[Path, None],
+            dview=None
+    ) -> DummyProcess:
         algo_module = getattr(algorithms, algo)
         algo_module.run_algo(
-            batch_path=str(batch_path), uuid=str(uuid), data_path=str(data_path)
+            batch_path=str(batch_path),
+            uuid=str(uuid),
+            data_path=str(data_path),
+            dview=dview
         )
         return DummyProcess()
 
-    def _run_subprocess(self, runfile_path: str, wait: bool, **kwargs):
+    def _run_local_async(
+            self,
+            algo: str,
+            batch_path: Path,
+            uuid: UUID,
+            data_path: Union[Path, None],
+            dview=None
+    ) -> WaitableFuture:
+        algo_module = getattr(algorithms, algo)
+        executor = ThreadPoolExecutor(max_workers=1)
+        future = executor.submit(
+            algo_module.run_algo,
+            batch_path=str(batch_path),
+            uuid=str(uuid),
+            data_path=str(data_path),
+            dview=dview
+        )
+        # shut down without blocking: the submitted run continues in the
+        # background and WaitableFuture.wait() joins it via Future.result()
+        executor.shutdown(wait=False)
+        return WaitableFuture(future)
+
+    def _run_subprocess(self, runfile_path: str, **kwargs) -> CheckedSubprocess:
         # Get the dir that contains the input movie
         parent_path = self._series.paths.resolve(self._series.input_movie_path).parent
-        self.process = Popen([runfile_path], cwd=parent_path)
+        popen = Popen([runfile_path], cwd=parent_path)
+        return CheckedSubprocess(popen)  # so that it throws an exception on failure
 
-        if wait:
-            self.process.wait()
-
-        return self.process
 
     def _run_slurm(
-        self, runfile_path: str, wait: bool, sbatch_opts: str = "", **kwargs
-    ):
+        self, runfile_path: str, sbatch_opts: str = "", **kwargs) -> CheckedSubprocess:
         """
         Run on a cluster using SLURM. Configurable options (to pass to run):
         - sbatch_opts: A single string containing additional options for sbatch.
@@ -589,15 +633,12 @@ def _run_slurm(
             f"--output={output_path}",
             "--wait",
         ] + shlex.split(sbatch_opts)
+
+        return CheckedSubprocess(Popen(["sbatch", *submission_opts, runfile_path]))
 
-        self.process = Popen(["sbatch", *submission_opts, runfile_path])
-        if wait:
-            self.process.wait()
-
-        return self.process
 
     @cnmf_cache.invalidate()
-    def run(self, backend: Optional[str] = None, wait: bool = True, **kwargs):
+    def run(self, backend: Optional[str] = None, wait: bool = True, **kwargs) -> Waitable:
         """
         Run a CaImAn algorithm in an external process using the chosen backend
 
@@ -636,42 +677,47 @@ def run(self, backend: Optional[str] = None, wait: bool = True, **kwargs):
 
         batch_path = self._series.paths.get_batch_path()
 
-        if backend == COMPUTE_BACKEND_LOCAL:
-            print(f"Running {self._series.uuid} with local backend")
-            return self._run_local(
+        if backend in [COMPUTE_BACKEND_LOCAL, COMPUTE_BACKEND_ASYNC]:
+            print(f"Running {self._series.uuid} with {backend} backend")
+            self.process = getattr(self, f"_run_{backend}")(
                 algo=self._series["algo"],
                 batch_path=batch_path,
                 uuid=self._series["uuid"],
                 data_path=get_parent_raw_data_path(),
+                dview=kwargs.get("dview")
             )
-
-        # Create the runfile in the batch dir using this Series' UUID as the filename
-        if IS_WINDOWS:
-            runfile_ext = ".bat"
         else:
-            runfile_ext = ".runfile"
-        runfile_path = str(
-            batch_path.parent.joinpath(self._series["uuid"] + runfile_ext)
-        )
+            # Create the runfile in the batch dir using this Series' UUID as the filename
+            if IS_WINDOWS:
+                runfile_ext = ".bat"
+            else:
+                runfile_ext = ".runfile"
+            runfile_path = str(
+                batch_path.parent.joinpath(self._series["uuid"] + runfile_ext)
+            )
 
-        args_str = (
-            f"--batch-path {lex.quote(str(batch_path))} --uuid {self._series.uuid}"
-        )
-        if get_parent_raw_data_path() is not None:
-            args_str += f" --data-path {lex.quote(str(get_parent_raw_data_path()))}"
-
-        # make the runfile
-        runfile_path = make_runfile(
-            module_path=os.path.abspath(
-                ALGO_MODULES[self._series["algo"]].__file__
-            ),  # caiman algorithm
-            filename=runfile_path,  # path to create runfile
-            args_str=args_str,
-        )
+            args_str = (
+                f"--batch-path {lex.quote(str(batch_path))} --uuid {self._series.uuid}"
+            )
+            if get_parent_raw_data_path() is not None:
+                args_str += f" --data-path {lex.quote(str(get_parent_raw_data_path()))}"
+
+            # make the runfile
+            runfile_path = make_runfile(
+                module_path=os.path.abspath(
+                    ALGO_MODULES[self._series["algo"]].__file__
+                ),  # caiman algorithm
+                filename=runfile_path,  # path to create runfile
+                args_str=args_str,
+            )
 
-        self.process = getattr(self, f"_run_{backend}")(
-            runfile_path, wait=wait, **kwargs
-        )
+            self.process = getattr(self, f"_run_{backend}")(
+                runfile_path, **kwargs
+            )
+
+        assert self.process is not None, 'Process should have been created'
+        if wait:
+            self.process.wait()
 
         return self.process
 
     def get_input_movie_path(self) -> Path:
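
With these changes, `run()` has a uniform contract: every backend returns a `Waitable`, and blocking happens only in `run()` itself when `wait=True`. The non-blocking pattern then looks like this (a sketch; `row` and `dview` assumed to exist):

    proc = row.caiman.run(backend=COMPUTE_BACKEND_ASYNC, dview=dview, wait=False)
    # ... queue up other items or do unrelated work ...
    proc.wait()  # CheckedSubprocess raises CalledProcessError on a nonzero exit;
                 # WaitableFuture re-raises anything run_algo left uncaught
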
diff --git a/tests/test_core.py b/tests/test_core.py
index b8a7419..8b8a920 100644
--- a/tests/test_core.py
+++ b/tests/test_core.py
@@ -1,5 +1,4 @@
 import os
-
 import numpy as np
 from caiman.utils.utils import load_dict_from_hdf5
 from caiman.source_extraction.cnmf.cnmf import CNMF
@@ -15,9 +14,12 @@
 from mesmerize_core.batch_utils import (
     DATAFRAME_COLUMNS,
     COMPUTE_BACKEND_SUBPROCESS,
+    COMPUTE_BACKEND_LOCAL,
+    COMPUTE_BACKEND_ASYNC,
     get_full_raw_data_path,
 )
 from mesmerize_core.utils import IS_WINDOWS
+from mesmerize_core.algorithms._utils import ensure_server
 from uuid import uuid4
 import pytest
 import requests
@@ -33,6 +35,7 @@
 import tifffile
 from copy import deepcopy
 
+
 # don't call "resolve" on these - want to make sure we can handle non-canonical paths correctly
 tmp_dir = Path(os.path.dirname(os.path.abspath(__file__)), "test data", "tmp")
 vid_dir = Path(os.path.dirname(os.path.abspath(__file__)), "test data", "videos")
@@ -1336,3 +1339,48 @@ def test_cache():
     assert hex(id(cnmf.cnmf_cache.get_cache().iloc[-1]["return_val"])) == hex(
         id(output)
     )
+
+
+def test_backends():
+    """test subprocess, local, and local_async backends"""
+    set_parent_raw_data_path(vid_dir)
+    algo = "mcorr"
+    df, batch_path = _create_tmp_batch()
+    input_movie_path = get_datafile(algo)
+
+    # make small version of movie for quick testing
+    movie = tifffile.imread(input_movie_path)
+    small_movie_path = input_movie_path.parent.joinpath("small_movie.tif")
+    tifffile.imwrite(small_movie_path, movie[:1001])
+    print(input_movie_path)
+
+    # put backends that can run in the background first to save time
+    backends = [COMPUTE_BACKEND_SUBPROCESS, COMPUTE_BACKEND_ASYNC, COMPUTE_BACKEND_LOCAL]
+    for backend in backends:
+        df.caiman.add_item(
+            algo="mcorr",
+            item_name=f"test-{backend}",
+            input_movie_path=small_movie_path,
+            params=test_params["mcorr"],
+        )
+
+    # run using each backend
+    procs = []
+    with ensure_server(None) as (dview, _):
+        for backend, (_, item) in zip(backends, df.iterrows()):
+            procs.append(item.caiman.run(backend=backend, dview=dview, wait=False))
+
+        # wait for all to finish (while the shared cluster is still up)
+        for proc in procs:
+            proc.wait()
+
+    # compare results
+    df = load_batch(batch_path)
+    for i, item in df.iterrows():
+        output = item.mcorr.get_output()
+
+        if i == 0:
+            # save to compare to other results
+            first_output = output
+        else:
+            np.testing.assert_array_equal(output, first_output)
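
A final note on scope: through the `CustomCluster` protocol, `ensure_server` and `get_n_processes` should also accept an ipyparallel `DirectView`, which this test does not exercise. A sketch, assuming ipyparallel 7+:

    import ipyparallel as ipp

    rc = ipp.Cluster(n=4).start_and_connect_sync()
    dview = rc[:]  # DirectView; len(dview) == 4
    with ensure_server(dview) as (dview, n_processes):
        # n_processes == 4; the existing cluster is passed through
        # and is not shut down when the block exits
        ...
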