diff --git a/.gitignore b/.gitignore index d5dede4..0aed8f9 100644 --- a/.gitignore +++ b/.gitignore @@ -136,6 +136,3 @@ dmypy.json # test files tests/test data - -# VSCode -.vscode/ diff --git a/mesmerize_core/algorithms/_utils.py b/mesmerize_core/algorithms/_utils.py new file mode 100644 index 0000000..21b50b6 --- /dev/null +++ b/mesmerize_core/algorithms/_utils.py @@ -0,0 +1,387 @@ +from contextlib import contextmanager +import logging +import math +import os +from pathlib import Path +import psutil +from typing import ( + Optional, + Union, + Generator, + Protocol, + Callable, + Generic, + TypeVar, + Sequence, + Iterable, + runtime_checkable, +) + +import caiman as cm +from caiman.base.movies import get_file_size +from caiman.cluster import setup_cluster +from caiman.paths import generate_fname_tot, fn_relocated +from caiman.summary_images import local_correlations +from ipyparallel import DirectView +from multiprocessing.pool import Pool +import numpy as np +import scipy.stats + + +RetVal = TypeVar("RetVal") +@runtime_checkable +class CustomCluster(Protocol): + """ + Protocol for a cluster that is not a multiprocessing pool + (including ipyparallel.DirectView) + """ + + def map_sync( + self, fn: Callable[..., RetVal], args: Iterable + ) -> Sequence[RetVal]: ... + + def __len__(self) -> int: + """return number of workers""" + ... + + +Cluster = Union[Pool, CustomCluster, DirectView] + + +def get_n_processes(dview: Optional[Cluster]) -> int: + """Infer number of processes in a multiprocessing or ipyparallel cluster""" + if isinstance(dview, Pool): + assert hasattr(dview, '_processes'), "Pool not keeping track of # of processes?" + return dview._processes # type: ignore + elif dview is not None: + return len(dview) + else: + return 1 + + +@contextmanager +def ensure_server(dview: Optional[Cluster]) -> Generator[tuple[Cluster, int], None, None]: + """ + Context manager that passes through an existing 'dview' or + opens up a multiprocessing server if none is passed in. + If a server was opened, closes it upon exit. + Usage: `with ensure_server(dview) as (dview, n_processes):` + """ + if dview is not None: + yield dview, get_n_processes(dview) + else: + # no cluster passed in, so open one + procs_available = psutil.cpu_count() + if procs_available is None: + raise RuntimeError('Cannot determine number of processes') + + if "MESMERIZE_N_PROCESSES" in os.environ.keys(): + try: + n_processes = int(os.environ["MESMERIZE_N_PROCESSES"]) + except: + n_processes = procs_available - 1 + else: + n_processes = procs_available - 1 + + # Start cluster for parallel processing + _, dview, n_processes = setup_cluster( + backend="multiprocessing", n_processes=n_processes, single_thread=False + ) + assert isinstance(dview, Pool) and isinstance(n_processes, int), 'setup_cluster with multiprocessing did not return a Pool' + try: + yield dview, n_processes + finally: + cm.stop_server(dview=dview) + + +def avail_bytes_per_process(n_processes: int): + return psutil.virtual_memory()[1] / n_processes + + +def estimate_n_pixels_per_process(n_processes: int, T: int, dims: tuple[int, ...]) -> int: + """ + Estimate a safe number of pixels to allocate to each parallel process at a time + Taken from CNMF.fit (TODO factor this out in caiman and just import it) + """ + avail_memory_per_process = avail_bytes_per_process(n_processes) / 2.0**30 + mem_per_pix = 3.6977678498329843e-09 + npx_per_proc = int(avail_memory_per_process / 8. / mem_per_pix / T) + npx_per_proc = int(np.minimum(npx_per_proc, np.prod(dims) // n_processes)) + return npx_per_proc + + +def fix_multid_subindices(movie_path: str, subindices: Union[list, tuple]) -> Union[list, tuple]: + """ + Make multidimensional subindices that work for the given file type for caiman.load, given that + some file types expect subindices as a list and others don't explicitly support multi-D subindices + and therefore they must be passed as a tuple to work correctly. + """ + _, ext = os.path.splitext(movie_path) + if ext in ['.tif', '.tiff', '.btf', '.avi', '.mkv']: + # formats that expect multi-D subindices as a list + return list(subindices) + else: + return tuple(subindices) + + +R = TypeVar('R') +class ColumnMappingFunction(Generic[R]): + """ + Object to map an operation over columns of a movie to avoid running out of memory + Construct with the kernel function which takes pixels x time matrices and returns anything. + """ + def __init__(self, kernel: Callable[..., R]): + self.kernel = kernel + + def _helper(self, args: tuple) -> R: + movie_path: str = args[0] + var_name_hdf5: str = args[1] + row_slice: slice = args[2] + col_slice: slice = args[3] + page: Optional[int] = args[4] + + subindices = (slice(None), row_slice, col_slice) + if page is not None: + subindices += (page,) + logging.debug(f'In column mapping kernel, page = {page}, cols = {col_slice.start} to {col_slice.stop}') + else: + logging.debug(f'In column mapping kernel, cols = {col_slice.start} to {col_slice.stop}') + + mov: cm.movie = cm.load(movie_path, subindices=fix_multid_subindices(movie_path, subindices), var_name_hdf5=var_name_hdf5) + T, *dims = mov.shape + + # flatten to pixels x time + Yr = mov.reshape((T, -1), order='F').T + + # get slice of flattened pixels + page_offset = page * dims[0] * dims[1] if page else 0 + pixel_slice = slice(page_offset + col_slice.start * dims[0] + row_slice.start, + page_offset + (col_slice.stop-1) * dims[0] + row_slice.stop) + + # apply the kernel + return self.kernel(Yr, pixel_slice, *args[5:]) + + def __call__(self, movie_path: str, dview: Optional[Cluster], var_name_hdf5='mov', *args) -> Iterable[R]: + """Perform an operation on non-overlapping chunks of columns that fit into memory""" + dims, T = get_file_size(movie_path, var_name_hdf5=var_name_hdf5) + assert isinstance(T, int) # non-type-stable interface... + + # use n_pixels_per_process from CNMF to avoid running out of memory + chunk_size = estimate_n_pixels_per_process(get_n_processes(dview), T, dims) + n_chunks = math.ceil(dims[0] * dims[1] / chunk_size) + n_column_chunks = min(math.ceil(dims[0] * dims[1] / chunk_size), dims[1]) + + # divide movie into chunks of columns + chunk_col_edges = np.linspace(0, dims[1], n_column_chunks+1).astype(int) + chunk_col_slices = [slice(start, end) for start, end in zip(chunk_col_edges[:-1], chunk_col_edges[1:])] + + if (n_row_chunks := math.ceil(n_chunks / n_column_chunks)) > 1: + # subdivide rows as well + chunk_row_edges = np.linspace(0, dims[0], n_row_chunks+1).astype(int) + chunk_row_slices = [slice(start, end) for start, end in zip(chunk_row_edges[:-1], chunk_row_edges[1:])] + else: + chunk_row_slices = [slice(0, dims[0])] + + if len(dims) > 2 and dims[2] > 1: + pages = range(dims[2]) + else: + pages = [None] + + map_args = [ + (movie_path, var_name_hdf5, row_slice, col_slice, page) + args + for page in pages for col_slice in chunk_col_slices for row_slice in chunk_row_slices + ] + + if dview is None: + map_fn = map + elif isinstance(dview, Pool): + map_fn = dview.map + else: + map_fn = dview.map_sync + + return map_fn(self._helper, map_args) + + +def _save_c_order_mmap_in_chunks_kernel(Yr_chunk: np.ndarray, pixel_slice: slice, mmap_fname: str, add_to_movie: float): + """ + Alternative to cm.save_memmap that can load from non-mmap files and uses chunks + Based on caiman.mmapping.save_portion + """ + c_order_copy = np.ascontiguousarray(Yr_chunk, dtype=np.float32) + np.float32(add_to_movie) # pixels x time + tot_frames = c_order_copy.shape[1] + + with open(mmap_fname, 'r+b') as f: + # seek to the start of the chunk + idx_start, idx_end = pixel_slice.start, pixel_slice.stop + f.seek(idx_start * c_order_copy.dtype.itemsize * tot_frames) + f.write(c_order_copy.data) + computed_position = np.uint64(idx_end * np.uint64(c_order_copy.dtype.itemsize) * tot_frames) + if f.tell() != computed_position: + logging.critical(f"Error in mmap portion write: at position {f.tell()}") + logging.critical( + f"But should be at position {idx_end} * {c_order_copy.dtype.itemsize} * {tot_frames} = {computed_position}" + ) + f.close() + raise Exception('Internal error in mmapping: Actual position does not match computed position') +save_c_order_mmap_in_chunks = ColumnMappingFunction(_save_c_order_mmap_in_chunks_kernel) + +def save_c_order_mmap_parallel(movie_path: str, base_name: str, dview: Optional[Cluster], + var_name_hdf5='mov', add_to_movie=0.0001) -> str: + """ + Alternative to cm.save_memmap that hopefully does better with memory + add_to_movie=0.0001 emulates default behavior of save_memmap + """ + # get name of mmap file and create it + dims, tot_frames = get_file_size(movie_path, var_name_hdf5=var_name_hdf5) + assert isinstance(tot_frames, int) # non-type-stable interface... + + # use generate_fname_tot to emulate behavior of save_memmap + mmap_fname_start = generate_fname_tot(base_name, list(dims), order='C') + mmap_fname = f'{mmap_fname_start}_frames_{tot_frames}.mmap' + mmap_fname = os.path.join(os.path.split(movie_path)[0], mmap_fname) + mmap_fname = fn_relocated(mmap_fname) + logging.info(f'Creating mmap file: {mmap_fname}') + + d = int(np.prod(dims)) + big_mov = np.memmap(mmap_fname, mode='w+', dtype=np.float32, shape=(d, tot_frames), order='C') + + # parallel load/save call + save_c_order_mmap_in_chunks(movie_path, dview, var_name_hdf5, mmap_fname, add_to_movie) + + # clean up + del big_mov + return mmap_fname + + +def _make_chunk_projections_kernel(Yr_chunk: np.ndarray, _pixel_slice: slice, proj_type: str, ignore_nan=False) -> np.ndarray: + if hasattr(scipy.stats, proj_type): + return getattr(scipy.stats, proj_type)(Yr_chunk, axis=1, nan_policy='omit' if ignore_nan else 'propagate') + + if hasattr(np, proj_type): + if ignore_nan: + if hasattr(np, "nan" + proj_type): + proj_type = "nan" + proj_type + else: + logging.warning(f"NaN-ignoring version of {proj_type} function does not exist; not ignoring NaNs") + return getattr(np, proj_type)(Yr_chunk, axis=1) + + raise NotImplementedError(f"Projection type '{proj_type}' not implemented") +make_chunk_projections = ColumnMappingFunction(_make_chunk_projections_kernel) + + +def make_projection_parallel(movie_path: str, proj_type: str, dview: Optional[Cluster], ignore_nan=False, + var_name_hdf5='mov') -> np.ndarray: + """ + Compute projection in chunks that are small enough to fit in memory + movie_path: path to movie that can be memory-mapped using caiman.load + """ + dims, _ = get_file_size(movie_path, var_name_hdf5=var_name_hdf5) + chunk_projs = make_chunk_projections(movie_path, dview, var_name_hdf5, proj_type, ignore_nan) + p_img_flat = np.concatenate(list(chunk_projs), axis=0) + return np.reshape(p_img_flat, dims, order="F") + + +def save_projections_parallel( + uuid, movie_path: Union[str, Path], output_dir: Path, dview: Optional[Cluster], var_name_hdf5='mov' +) -> dict[str, Path]: + proj_paths = dict() + for proj_type in ["mean", "std", "max"]: + p_img = make_projection_parallel( + str(movie_path), proj_type, dview=dview, ignore_nan=True, var_name_hdf5=var_name_hdf5 + ) + proj_paths[proj_type] = output_dir.joinpath( + f"{uuid}_{proj_type}_projection.npy" + ) + np.save(str(proj_paths[proj_type]), p_img) + return proj_paths + + +ChunkDims = tuple[slice, slice] +ChunkSpec = tuple[ChunkDims, ChunkDims, ChunkDims] # input, output, patch subinds + +def make_correlation_parallel(movie_path: Union[str, Path], dview: Optional[Cluster]) -> np.ndarray: + """ + Compute local correlations in chunks that are small enough to fit in memory + movie_path: path to movie that can be memory-mapped using caiman.load + """ + dims, T = get_file_size(movie_path) + assert isinstance(T, int) # non-type-stable interface... + + # use n_pixels_per_process from CNMF to avoid running out of memory + chunk_size = estimate_n_pixels_per_process(get_n_processes(dview), T, dims) + patches = make_correlation_patches(dims, chunk_size) + + # do correlation calculation in parallel + args = [(str(movie_path), p[0]) for p in patches] + if dview is None: + map_fn = map + elif isinstance(dview, Pool): + map_fn = dview.map + else: + map_fn = dview.map_sync + + patch_corrs = map_fn(chunk_correlation_helper, args) + output_img = np.empty(dims, dtype=np.float32) + for (_, output_coords, subinds), patch_corr in zip(patches, patch_corrs): + output_img[output_coords] = patch_corr[subinds] + + return output_img + + +def save_correlation_parallel(uuid, movie_path: Union[str, Path], output_dir: Path, dview: Optional[Cluster]) -> Path: + """Compute and save local correlations in chunks that are small enough to fit in memory""" + corr_img = make_correlation_parallel(movie_path, dview) + corr_img_path = output_dir.joinpath(f"{uuid}_cn.npy") + np.save(str(corr_img_path), corr_img, allow_pickle=False) + return corr_img_path + + +def chunk_correlation_helper(args: tuple[str, ChunkDims]) -> np.ndarray: + movie_path, dims_input = args + mov = cm.load(movie_path, subindices=fix_multid_subindices(movie_path, (slice(None),) + dims_input)) + return local_correlations(mov, swap_dim=False) + + +def make_correlation_patches(dims: tuple[int, ...], chunk_size: int) -> list[ChunkSpec]: + """ + Compute dimensions for dividing movie (ideally C-order) into patches for correlation calculation. + Overlap = 2 to avoid edge effects except on the edge. + Each entry of the returned list contains 3 (Y, X) tuples of slices: + - input coordinates (for getting sub-movie to compute correlation on) + - output coordinates (for assigning result to full correlation image, excludes inner borders) + - patch sub-indices (to index result for assignment to output) + """ + window_size = math.floor(math.sqrt(chunk_size)) + + # first get patch starts and sizes for each dimension + patch_coords_y = make_correlation_patches_for_dim(dims[0], window_size) + patch_coords_x = make_correlation_patches_for_dim(dims[1], window_size) + return [ + ((input_y, input_x), (output_y, output_x), (subind_y, subind_x)) + for input_y, output_y, subind_y in patch_coords_y + for input_x, output_x, subind_x in patch_coords_x + ] + + +def make_correlation_patches_for_dim(dim: int, window_size: int) -> list[tuple[slice, slice, slice]]: + """ + Like make_correlation_patches but for just one dimension + """ + overlap = 2 # so that edge pixel in one patch is a non-edge pixel in the next + window_size = max(window_size, overlap + 1) + stride = window_size - overlap + + patch_starts = range(0, dim - overlap, stride) # last pixels are covered by last window + patch_ends = [start + window_size for start in patch_starts[:-1]] + [dim] + patch_coords: list[tuple[slice, slice, slice]] = [] + + for start, end in zip(patch_starts, patch_ends): + is_first = start == patch_starts[0] + is_last = start == patch_starts[-1] + patch_coords.append(( + slice(start, end), + slice(start if is_first else start + 1, end if is_last else end-1), + slice(0 if is_first else 1, None if is_last else -1) + )) + + return patch_coords diff --git a/mesmerize_core/algorithms/cnmf.py b/mesmerize_core/algorithms/cnmf.py index dd7381a..fbb5c60 100644 --- a/mesmerize_core/algorithms/cnmf.py +++ b/mesmerize_core/algorithms/cnmf.py @@ -4,24 +4,35 @@ import caiman as cm from caiman.source_extraction.cnmf import cnmf as cnmf from caiman.source_extraction.cnmf.params import CNMFParams -import psutil +from caiman.paths import decode_mmap_filename_dict import numpy as np import traceback from pathlib import Path, PurePosixPath from shutil import move as move_file -import os import time # prevent circular import if __name__ in ["__main__", "__mp_main__"]: # when running in subprocess from mesmerize_core import set_parent_raw_data_path, load_batch from mesmerize_core.utils import IS_WINDOWS + from mesmerize_core.algorithms._utils import ( + ensure_server, + save_projections_parallel, + save_correlation_parallel, + save_c_order_mmap_parallel, + ) else: # when running with local backend from ..batch_utils import set_parent_raw_data_path, load_batch from ..utils import IS_WINDOWS + from ._utils import ( + ensure_server, + save_projections_parallel, + save_correlation_parallel, + save_c_order_mmap_parallel, + ) -def run_algo(batch_path, uuid, data_path: str = None): +def run_algo(batch_path, uuid, data_path: str = None, dview=None): algo_start = time.time() set_parent_raw_data_path(data_path) @@ -42,103 +53,95 @@ def run_algo(batch_path, uuid, data_path: str = None): f"Starting CNMF item:\n{item}\nWith params:{params}" ) - # adapted from current demo notebook - if "MESMERIZE_N_PROCESSES" in os.environ.keys(): - try: - n_processes = int(os.environ["MESMERIZE_N_PROCESSES"]) - except: - n_processes = psutil.cpu_count() - 1 - else: - n_processes = psutil.cpu_count() - 1 - # Start cluster for parallel processing - c, dview, n_processes = cm.cluster.setup_cluster( - backend="local", n_processes=n_processes, single_thread=False - ) - - # merge cnmf and eval kwargs into one dict - cnmf_params = CNMFParams(params_dict=params["main"]) - # Run CNMF, denote boolean 'success' if CNMF completes w/out error - try: - fname_new = cm.save_memmap( - [input_movie_path], base_name=f"{uuid}_cnmf-memmap_", order="C", dview=dview - ) - - print("making memmap") + with ensure_server(dview) as (dview, n_processes): - Yr, dims, T = cm.load_memmap(fname_new) - images = np.reshape(Yr.T, [T] + list(dims), order="F") - - proj_paths = dict() - for proj_type in ["mean", "std", "max"]: - p_img = getattr(np, f"nan{proj_type}")(images, axis=0) - proj_paths[proj_type] = output_dir.joinpath( - f"{uuid}_{proj_type}_projection.npy" - ) - np.save(str(proj_paths[proj_type]), p_img) - - # in fname new load in memmap order C - cm.stop_server(dview=dview) - c, dview, n_processes = cm.cluster.setup_cluster( - backend="local", n_processes=None, single_thread=False - ) - - print("performing CNMF") - cnm = cnmf.CNMF(n_processes, params=cnmf_params, dview=dview) - - print("fitting images") - cnm.fit(images) - # - if "refit" in params.keys(): - if params["refit"] is True: - print("refitting") - cnm = cnm.refit(images, dview=dview) - - print("performing eval") - cnm.estimates.evaluate_components(images, cnm.params, dview=dview) - - output_path = output_dir.joinpath(f"{uuid}.hdf5") - - cnm.save(str(output_path)) - - Cn = cm.local_correlations(images, swap_dim=False) - Cn[np.isnan(Cn)] = 0 - - corr_img_path = output_dir.joinpath(f"{uuid}_cn.npy") - np.save(str(corr_img_path), Cn, allow_pickle=False) - - # output dict for dataframe row (pd.Series) - d = dict() - - cnmf_memmap_path = output_dir.joinpath(Path(fname_new).name) - if IS_WINDOWS: - Yr._mmap.close() # accessing private attr but windows is annoying otherwise - move_file(fname_new, cnmf_memmap_path) - - # save paths as relative path strings with forward slashes - cnmf_hdf5_path = str(PurePosixPath(output_path.relative_to(output_dir.parent))) - cnmf_memmap_path = str( - PurePosixPath(cnmf_memmap_path.relative_to(output_dir.parent)) - ) - corr_img_path = str(PurePosixPath(corr_img_path.relative_to(output_dir.parent))) - for proj_type in proj_paths.keys(): - d[f"{proj_type}-projection-path"] = str( - PurePosixPath(proj_paths[proj_type].relative_to(output_dir.parent)) + # merge cnmf and eval kwargs into one dict + cnmf_params = CNMFParams(params_dict=params["main"]) + # Run CNMF, denote boolean 'success' if CNMF completes w/out error + try: + # only re-save memmap if necessary + save_new_mmap = True + if Path(input_movie_path).suffix == ".mmap": + mmap_info = decode_mmap_filename_dict(input_movie_path) + save_new_mmap = "order" not in mmap_info or mmap_info["order"] != "C" + + if save_new_mmap: + print("making memmap") + fname_new = save_c_order_mmap_parallel( + input_movie_path, + base_name=f"{uuid}_cnmf-memmap_", + dview=dview, + var_name_hdf5=cnmf_params.data['var_name_hdf5'] + ) + cnmf_memmap_path = output_dir.joinpath(Path(fname_new).name) + move_file(fname_new, cnmf_memmap_path) + else: + cnmf_memmap_path = Path(input_movie_path) + + Yr, dims, T = cm.load_memmap(str(cnmf_memmap_path)) + images = np.reshape(Yr.T, [T] + list(dims), order="F") + + print("computing projections") + proj_paths = save_projections_parallel( + uuid=uuid, + movie_path=cnmf_memmap_path, + output_dir=output_dir, + dview=dview, ) - d.update( - { - "cnmf-hdf5-path": cnmf_hdf5_path, - "cnmf-memmap-path": cnmf_memmap_path, - "corr-img-path": corr_img_path, - "success": True, - "traceback": None, - } - ) + print("computing correlation image") + corr_img_path = save_correlation_parallel( + uuid=uuid, + movie_path=cnmf_memmap_path, + output_dir=output_dir, + dview=dview, + ) - except: - d = {"success": False, "traceback": traceback.format_exc()} + print("performing CNMF") + cnm = cnmf.CNMF(n_processes, params=cnmf_params, dview=dview) + + print("fitting images") + cnm.fit(images) + # + if "refit" in params.keys(): + if params["refit"] is True: + print("refitting") + cnm = cnm.refit(images, dview=dview) + + print("performing eval") + cnm.estimates.evaluate_components(images, cnm.params, dview=dview) + + output_path = output_dir.joinpath(f"{uuid}.hdf5") + + cnm.save(str(output_path)) + + # output dict for dataframe row (pd.Series) + d = dict() + + if IS_WINDOWS: + Yr._mmap.close() # accessing private attr but windows is annoying otherwise + + # save paths as relative path strings with forward slashes + cnmf_hdf5_path = str(PurePosixPath(output_path.relative_to(output_dir.parent))) + cnmf_memmap_path = str(PurePosixPath(df.paths.split(cnmf_memmap_path)[1])) # still work if outside output dir + corr_img_path = str(PurePosixPath(corr_img_path.relative_to(output_dir.parent))) + for proj_type in proj_paths.keys(): + d[f"{proj_type}-projection-path"] = str( + PurePosixPath(proj_paths[proj_type].relative_to(output_dir.parent)) + ) + + d.update( + { + "cnmf-hdf5-path": cnmf_hdf5_path, + "cnmf-memmap-path": cnmf_memmap_path, + "corr-img-path": corr_img_path, + "success": True, + "traceback": None, + } + ) - cm.stop_server(dview=dview) + except: + d = {"success": False, "traceback": traceback.format_exc()} runtime = round(time.time() - algo_start, 2) df.caiman.update_item_with_results(uuid, d, runtime) diff --git a/mesmerize_core/algorithms/cnmfe.py b/mesmerize_core/algorithms/cnmfe.py index d4f1858..04da26e 100644 --- a/mesmerize_core/algorithms/cnmfe.py +++ b/mesmerize_core/algorithms/cnmfe.py @@ -3,22 +3,27 @@ import caiman as cm from caiman.source_extraction.cnmf import cnmf as cnmf from caiman.source_extraction.cnmf.params import CNMFParams -import psutil +from caiman.paths import decode_mmap_filename_dict import traceback from pathlib import Path, PurePosixPath from shutil import move as move_file -import os import time if __name__ in ["__main__", "__mp_main__"]: # when running in subprocess from mesmerize_core import set_parent_raw_data_path, load_batch from mesmerize_core.utils import IS_WINDOWS + from mesmerize_core.algorithms._utils import ( + ensure_server, + save_projections_parallel, + save_c_order_mmap_parallel, + ) else: # when running with local backend from ..batch_utils import set_parent_raw_data_path, load_batch from ..utils import IS_WINDOWS + from ._utils import ensure_server, save_projections_parallel, save_c_order_mmap_parallel -def run_algo(batch_path, uuid, data_path: str = None): +def run_algo(batch_path, uuid, data_path: str = None, dview=None): algo_start = time.time() set_parent_raw_data_path(data_path) @@ -35,91 +40,86 @@ def run_algo(batch_path, uuid, data_path: str = None): params = item["params"] print("cnmfe params:", params) - # adapted from current demo notebook - if "MESMERIZE_N_PROCESSES" in os.environ.keys(): + with ensure_server(dview) as (dview, n_processes): try: - n_processes = int(os.environ["MESMERIZE_N_PROCESSES"]) - except: - n_processes = psutil.cpu_count() - 1 - else: - n_processes = psutil.cpu_count() - 1 - # Start cluster for parallel processing - c, dview, n_processes = cm.cluster.setup_cluster( - backend="local", n_processes=n_processes, single_thread=False - ) + # force the CNMFE params + cnmfe_params_dict = { + "method_init": "corr_pnr", + "n_processes": n_processes, + "only_init": True, # for 1p + "center_psf": True, # for 1p + "normalize_init": False, # for 1p + } - try: - fname_new = cm.save_memmap( - [input_movie_path], base_name=f"{uuid}_cnmf-memmap_", order="C", dview=dview - ) - - print("making memmap") - Yr, dims, T = cm.load_memmap(fname_new) - images = np.reshape(Yr.T, [T] + list(dims), order="F") - - # TODO: if projections already exist from mcorr we don't - # need to waste compute time re-computing them here - proj_paths = dict() - for proj_type in ["mean", "std", "max"]: - p_img = getattr(np, f"nan{proj_type}")(images, axis=0) - proj_paths[proj_type] = output_dir.joinpath( - f"{uuid}_{proj_type}_projection.npy" + params_dict = {**cnmfe_params_dict, **params["main"]} + + cnmfe_params_dict = CNMFParams(params_dict=params_dict) + + # only re-save memmap if necessary + save_new_mmap = True + if Path(input_movie_path).suffix == ".mmap": + mmap_info = decode_mmap_filename_dict(input_movie_path) + save_new_mmap = "order" not in mmap_info or mmap_info["order"] != "C" + + if save_new_mmap: + print("making memmap") + fname_new = save_c_order_mmap_parallel( + input_movie_path, + base_name=f"{uuid}_cnmf-memmap_", + dview=dview, + var_name_hdf5=cnmfe_params_dict.data['var_name_hdf5'] + ) + cnmf_memmap_path = output_dir.joinpath(Path(fname_new).name) + move_file(fname_new, cnmf_memmap_path) + else: + cnmf_memmap_path = Path(input_movie_path) + + Yr, dims, T = cm.load_memmap(str(cnmf_memmap_path)) + images = np.reshape(Yr.T, [T] + list(dims), order="F") + + # TODO: if projections already exist from mcorr we don't + # need to waste compute time re-computing them here + proj_paths = save_projections_parallel( + uuid=uuid, movie_path=cnmf_memmap_path, output_dir=output_dir, dview=dview ) - np.save(str(proj_paths[proj_type]), p_img) - d = dict() # for output + d = dict() # for output - # force the CNMFE params - cnmfe_params_dict = { - "method_init": "corr_pnr", - "n_processes": n_processes, - "only_init": True, # for 1p - "center_psf": True, # for 1p - "normalize_init": False, # for 1p - } + cnm = cnmf.CNMF(n_processes=n_processes, dview=dview, params=cnmfe_params_dict) + print("Performing CNMFE") + cnm.fit(images) + print("evaluating components") + cnm.estimates.evaluate_components(images, cnm.params, dview=dview) - params_dict = {**cnmfe_params_dict, **params["main"]} + cnmf_hdf5_path = output_dir.joinpath(f"{uuid}.hdf5") + cnm.save(str(cnmf_hdf5_path)) - cnmfe_params_dict = CNMFParams(params_dict=params_dict) - cnm = cnmf.CNMF(n_processes=n_processes, dview=dview, params=cnmfe_params_dict) - print("Performing CNMFE") - cnm.fit(images) - print("evaluating components") - cnm.estimates.evaluate_components(images, cnm.params, dview=dview) + # save output paths to outputs dict + d["cnmf-hdf5-path"] = cnmf_hdf5_path.relative_to(output_dir.parent) - cnmf_hdf5_path = output_dir.joinpath(f"{uuid}.hdf5") - cnm.save(str(cnmf_hdf5_path)) + for proj_type in proj_paths.keys(): + d[f"{proj_type}-projection-path"] = proj_paths[proj_type].relative_to( + output_dir.parent + ) - # save output paths to outputs dict - d["cnmf-hdf5-path"] = cnmf_hdf5_path.relative_to(output_dir.parent) + if IS_WINDOWS: + Yr._mmap.close() # accessing private attr but windows is annoying otherwise - for proj_type in proj_paths.keys(): - d[f"{proj_type}-projection-path"] = proj_paths[proj_type].relative_to( - output_dir.parent + # save path as relative path strings with forward slashes + cnmfe_memmap_path = str( + PurePosixPath(df.paths.split(cnmf_memmap_path)[1]) ) - cnmf_memmap_path = output_dir.joinpath(Path(fname_new).name) - if IS_WINDOWS: - Yr._mmap.close() # accessing private attr but windows is annoying otherwise - move_file(fname_new, cnmf_memmap_path) - - # save path as relative path strings with forward slashes - cnmfe_memmap_path = str( - PurePosixPath(cnmf_memmap_path.relative_to(output_dir.parent)) - ) - - d.update( - { - "cnmf-memmap-path": cnmfe_memmap_path, - "success": True, - "traceback": None, - } - ) - - except: - d = {"success": False, "traceback": traceback.format_exc()} + d.update( + { + "cnmf-memmap-path": cnmfe_memmap_path, + "success": True, + "traceback": None, + } + ) - cm.stop_server(dview=dview) + except: + d = {"success": False, "traceback": traceback.format_exc()} runtime = round(time.time() - algo_start, 2) df.caiman.update_item_with_results(uuid, d, runtime) diff --git a/mesmerize_core/algorithms/mcorr.py b/mesmerize_core/algorithms/mcorr.py index 484130d..f72435a 100644 --- a/mesmerize_core/algorithms/mcorr.py +++ b/mesmerize_core/algorithms/mcorr.py @@ -3,8 +3,6 @@ import caiman as cm from caiman.source_extraction.cnmf.params import CNMFParams from caiman.motion_correction import MotionCorrect -from caiman.summary_images import local_correlations_movie_offline -import psutil import os from pathlib import Path, PurePosixPath import numpy as np @@ -14,11 +12,17 @@ # prevent circular import if __name__ in ["__main__", "__mp_main__"]: # when running in subprocess from mesmerize_core import set_parent_raw_data_path, load_batch + from mesmerize_core.algorithms._utils import ( + ensure_server, + save_projections_parallel, + save_correlation_parallel, + ) else: # when running with local backend from ..batch_utils import set_parent_raw_data_path, load_batch + from ._utils import ensure_server, save_projections_parallel, save_correlation_parallel -def run_algo(batch_path, uuid, data_path: str = None): +def run_algo(batch_path, uuid, data_path: str = None, dview=None): algo_start = time.time() set_parent_raw_data_path(data_path) @@ -39,115 +43,93 @@ def run_algo(batch_path, uuid, data_path: str = None): params = item["params"] - # adapted from current demo notebook - if "MESMERIZE_N_PROCESSES" in os.environ.keys(): + with ensure_server(dview) as (dview, n_processes): + print("starting mc") + + rel_params = dict(params["main"]) + opts = CNMFParams(params_dict=rel_params) + # Run MC, denote boolean 'success' if MC completes w/out error try: - n_processes = int(os.environ["MESMERIZE_N_PROCESSES"]) - except: - n_processes = psutil.cpu_count() - 1 - else: - n_processes = psutil.cpu_count() - 1 - - print("starting mc") - # Start cluster for parallel processing - c, dview, n_processes = cm.cluster.setup_cluster( - backend="local", n_processes=n_processes, single_thread=False - ) + # Run MC + fnames = [input_movie_path] + mc = MotionCorrect(fnames, dview=dview, **opts.get_group("motion")) + mc.motion_correct(save_movie=True) - rel_params = dict(params["main"]) - opts = CNMFParams(params_dict=rel_params) - # Run MC, denote boolean 'success' if MC completes w/out error - try: - # Run MC - fnames = [input_movie_path] - mc = MotionCorrect(fnames, dview=dview, **opts.get_group("motion")) - mc.motion_correct(save_movie=True) - - # find path to mmap file - memmap_output_path_temp = df.paths.resolve(mc.mmap_file[0]) - - # filename to move the output back to data dir - mcorr_memmap_path = output_dir.joinpath( - f"{uuid}-{memmap_output_path_temp.name}" - ) - - # move the output file - move_file(memmap_output_path_temp, mcorr_memmap_path) - - print("mc finished successfully!") - - print("computing projections") - Yr, dims, T = cm.load_memmap(str(mcorr_memmap_path)) - images = np.reshape(Yr.T, [T] + list(dims), order="F") - - proj_paths = dict() - for proj_type in ["mean", "std", "max"]: - p_img = getattr(np, f"nan{proj_type}")(images, axis=0) - proj_paths[proj_type] = output_dir.joinpath( - f"{uuid}_{proj_type}_projection.npy" + # find path to mmap file + memmap_output_path_temp = df.paths.resolve(mc.mmap_file[0]) + + # filename to move the output back to data dir + mcorr_memmap_path = output_dir.joinpath( + f"{uuid}-{memmap_output_path_temp.name}" ) - np.save(str(proj_paths[proj_type]), p_img) - - print("Computing correlation image") - Cns = local_correlations_movie_offline( - str(mcorr_memmap_path), - remove_baseline=True, - window=1000, - stride=1000, - winSize_baseline=100, - quantil_min_baseline=10, - dview=dview, - ) - Cn = Cns.max(axis=0) - Cn[np.isnan(Cn)] = 0 - cn_path = output_dir.joinpath(f"{uuid}_cn.npy") - np.save(str(cn_path), Cn, allow_pickle=False) - - print("finished computing correlation image") - - # Compute shifts - if opts.motion["pw_rigid"] == True: - x_shifts = mc.x_shifts_els - y_shifts = mc.y_shifts_els - shifts = [x_shifts, y_shifts] - if hasattr(mc, "z_shifts_els"): - shifts.append(mc.z_shifts_els) - shift_path = output_dir.joinpath(f"{uuid}_shifts.npy") - np.save(str(shift_path), shifts) - else: - shifts = mc.shifts_rig - shift_path = output_dir.joinpath(f"{uuid}_shifts.npy") - np.save(str(shift_path), shifts) - - # output dict for pandas series for dataframe row - d = dict() - - # save paths as relative path strings with forward slashes - cn_path = str(PurePosixPath(cn_path.relative_to(output_dir.parent))) - mcorr_memmap_path = str( - PurePosixPath(mcorr_memmap_path.relative_to(output_dir.parent)) - ) - shift_path = str(PurePosixPath(shift_path.relative_to(output_dir.parent))) - for proj_type in proj_paths.keys(): - d[f"{proj_type}-projection-path"] = str( - PurePosixPath(proj_paths[proj_type].relative_to(output_dir.parent)) + + # move the output file + move_file(memmap_output_path_temp, mcorr_memmap_path) + + print("mc finished successfully!") + + print("computing projections") + Yr, dims, T = cm.load_memmap(str(mcorr_memmap_path)) + images = np.reshape(Yr.T, [T] + list(dims), order="F") + + proj_paths = save_projections_parallel( + uuid=uuid, + movie_path=mcorr_memmap_path, + output_dir=output_dir, + dview=dview, + ) + + print("Computing correlation image") + cn_path = save_correlation_parallel( + uuid=uuid, + movie_path=mcorr_memmap_path, + output_dir=output_dir, + dview=dview, ) - d.update( - { - "mcorr-output-path": mcorr_memmap_path, - "corr-img-path": cn_path, - "shifts": shift_path, - "success": True, - "traceback": None, - } - ) - - except: - d = {"success": False, "traceback": traceback.format_exc()} - print("mc failed, stored traceback in output") - - cm.stop_server(dview=dview) + print("finished computing correlation image") + + # Compute shifts + if opts.motion["pw_rigid"] == True: + x_shifts = mc.x_shifts_els + y_shifts = mc.y_shifts_els + shifts = [x_shifts, y_shifts] + if hasattr(mc, "z_shifts_els"): + shifts.append(mc.z_shifts_els) + shift_path = output_dir.joinpath(f"{uuid}_shifts.npy") + np.save(str(shift_path), shifts) + else: + shifts = mc.shifts_rig + shift_path = output_dir.joinpath(f"{uuid}_shifts.npy") + np.save(str(shift_path), shifts) + + # output dict for pandas series for dataframe row + d = dict() + + # save paths as relative path strings with forward slashes + cn_path = str(PurePosixPath(cn_path.relative_to(output_dir.parent))) + mcorr_memmap_path = str( + PurePosixPath(mcorr_memmap_path.relative_to(output_dir.parent)) + ) + shift_path = str(PurePosixPath(shift_path.relative_to(output_dir.parent))) + for proj_type in proj_paths.keys(): + d[f"{proj_type}-projection-path"] = str( + PurePosixPath(proj_paths[proj_type].relative_to(output_dir.parent)) + ) + + d.update( + { + "mcorr-output-path": mcorr_memmap_path, + "corr-img-path": cn_path, + "shifts": shift_path, + "success": True, + "traceback": None, + } + ) + + except: + d = {"success": False, "traceback": traceback.format_exc()} + print("mc failed, stored traceback in output") runtime = round(time.time() - algo_start, 2) df.caiman.update_item_with_results(uuid, d, runtime) diff --git a/mesmerize_core/batch_utils.py b/mesmerize_core/batch_utils.py index 96a2bd7..9766172 100644 --- a/mesmerize_core/batch_utils.py +++ b/mesmerize_core/batch_utils.py @@ -11,11 +11,13 @@ COMPUTE_BACKEND_SUBPROCESS = "subprocess" #: subprocess backend COMPUTE_BACKEND_SLURM = "slurm" #: SLURM backend COMPUTE_BACKEND_LOCAL = "local" +COMPUTE_BACKEND_ASYNC = "local_async" COMPUTE_BACKENDS = [ COMPUTE_BACKEND_SUBPROCESS, COMPUTE_BACKEND_SLURM, COMPUTE_BACKEND_LOCAL, + COMPUTE_BACKEND_ASYNC, ] DATAFRAME_COLUMNS = [ diff --git a/mesmerize_core/caiman_extensions/common.py b/mesmerize_core/caiman_extensions/common.py index 8e2b4cd..72483bc 100644 --- a/mesmerize_core/caiman_extensions/common.py +++ b/mesmerize_core/caiman_extensions/common.py @@ -2,13 +2,14 @@ import shutil from pathlib import Path import psutil -from subprocess import Popen +from subprocess import Popen, CalledProcessError from typing import * from uuid import UUID, uuid4 from shutil import rmtree from datetime import datetime import time from copy import deepcopy +from concurrent.futures import ThreadPoolExecutor, Future import numpy as np import pandas as pd @@ -25,6 +26,7 @@ COMPUTE_BACKENDS, COMPUTE_BACKEND_SUBPROCESS, COMPUTE_BACKEND_LOCAL, + COMPUTE_BACKEND_ASYNC, get_parent_raw_data_path, load_batch, ) @@ -513,13 +515,37 @@ def get_parent(self, index: Union[int, str, UUID]) -> Union[UUID, None]: return r["uuid"] -class DummyProcess: - """Dummy process for local backend""" +class Waitable(Protocol): + """An object that we can call "wait" on""" + def wait(self) -> None: ... + - def wait(self): +class DummyProcess(Waitable): + """Dummy process for local backend""" + def wait(self) -> None: pass +class WaitableFuture(Waitable): + """Adaptor for future returned from Executor.submit""" + def __init__(self, future: Future[None]): + self.future = future + + def wait(self) -> None: + return self.future.result() + + +class CheckedSubprocess(Waitable): + """Adaptor for Popen that just raises an exception if the return code is nonzero""" + def __init__(self, popen: Popen): + self.popen = popen + + def wait(self) -> None: + rc = self.popen.wait() + if rc != 0: + raise CalledProcessError(rc, self.popen.args) + + @pd.api.extensions.register_series_accessor("caiman") class CaimanSeriesExtensions: """ @@ -528,36 +554,54 @@ class CaimanSeriesExtensions: def __init__(self, s: pd.Series): self._series = s - self.process: Popen = None + self.process: Optional[Waitable] = None def _run_local( - self, - algo: str, - batch_path: Path, - uuid: UUID, - data_path: Union[Path, None], - ): + self, + algo: str, + batch_path: Path, + uuid: UUID, + data_path: Union[Path, None], + dview=None + ) -> DummyProcess: algo_module = getattr(algorithms, algo) algo_module.run_algo( - batch_path=str(batch_path), uuid=str(uuid), data_path=str(data_path) + batch_path=str(batch_path), + uuid=str(uuid), + data_path=str(data_path), + dview=dview ) - return DummyProcess() - def _run_subprocess(self, runfile_path: str, wait: bool, **kwargs): + def _run_local_async( + self, + algo: str, + batch_path: Path, + uuid: UUID, + data_path: Union[Path, None], + dview=None + ) -> WaitableFuture: + algo_module = getattr(algorithms, algo) + with ThreadPoolExecutor(max_workers=1) as executor: + future = executor.submit( + algo_module.run_algo, + batch_path=str(batch_path), + uuid=str(uuid), + data_path=str(data_path), + dview=dview + ) + return WaitableFuture(future) + + def _run_subprocess(self, runfile_path: str, **kwargs) -> CheckedSubprocess: # Get the dir that contains the input movie parent_path = self._series.paths.resolve(self._series.input_movie_path).parent - self.process = Popen([runfile_path], cwd=parent_path) + popen = Popen([runfile_path], cwd=parent_path) + return CheckedSubprocess(popen) # so that it throws an exception on failure - if wait: - self.process.wait() - - return self.process def _run_slurm( - self, runfile_path: str, wait: bool, sbatch_opts: str = "", **kwargs - ): + self, runfile_path: str, sbatch_opts: str = "", **kwargs) -> CheckedSubprocess: """ Run on a cluster using SLURM. Configurable options (to pass to run): - sbatch_opts: A single string containing additional options for sbatch. @@ -589,15 +633,12 @@ def _run_slurm( f"--output={output_path}", "--wait", ] + shlex.split(sbatch_opts) + + return CheckedSubprocess(Popen(["sbatch", *submission_opts, runfile_path])) - self.process = Popen(["sbatch", *submission_opts, runfile_path]) - if wait: - self.process.wait() - - return self.process @cnmf_cache.invalidate() - def run(self, backend: Optional[str] = None, wait: bool = True, **kwargs): + def run(self, backend: Optional[str] = None, wait: bool = True, **kwargs) -> Waitable: """ Run a CaImAn algorithm in an external process using the chosen backend @@ -636,42 +677,47 @@ def run(self, backend: Optional[str] = None, wait: bool = True, **kwargs): batch_path = self._series.paths.get_batch_path() - if backend == COMPUTE_BACKEND_LOCAL: - print(f"Running {self._series.uuid} with local backend") - return self._run_local( + if backend in [COMPUTE_BACKEND_LOCAL, COMPUTE_BACKEND_ASYNC]: + print(f"Running {self._series.uuid} with {backend} backend") + self.process = getattr(self, f"_run_{backend}")( algo=self._series["algo"], batch_path=batch_path, uuid=self._series["uuid"], data_path=get_parent_raw_data_path(), + dview=kwargs.get("dview") ) - - # Create the runfile in the batch dir using this Series' UUID as the filename - if IS_WINDOWS: - runfile_ext = ".bat" else: - runfile_ext = ".runfile" - runfile_path = str( - batch_path.parent.joinpath(self._series["uuid"] + runfile_ext) - ) + # Create the runfile in the batch dir using this Series' UUID as the filename + if IS_WINDOWS: + runfile_ext = ".bat" + else: + runfile_ext = ".runfile" + runfile_path = str( + batch_path.parent.joinpath(self._series["uuid"] + runfile_ext) + ) - args_str = ( - f"--batch-path {lex.quote(str(batch_path))} --uuid {self._series.uuid}" - ) - if get_parent_raw_data_path() is not None: - args_str += f" --data-path {lex.quote(str(get_parent_raw_data_path()))}" - - # make the runfile - runfile_path = make_runfile( - module_path=os.path.abspath( - ALGO_MODULES[self._series["algo"]].__file__ - ), # caiman algorithm - filename=runfile_path, # path to create runfile - args_str=args_str, - ) + args_str = ( + f"--batch-path {lex.quote(str(batch_path))} --uuid {self._series.uuid}" + ) + if get_parent_raw_data_path() is not None: + args_str += f" --data-path {lex.quote(str(get_parent_raw_data_path()))}" + + # make the runfile + runfile_path = make_runfile( + module_path=os.path.abspath( + ALGO_MODULES[self._series["algo"]].__file__ + ), # caiman algorithm + filename=runfile_path, # path to create runfile + args_str=args_str, + ) - self.process = getattr(self, f"_run_{backend}")( - runfile_path, wait=wait, **kwargs - ) + self.process = getattr(self, f"_run_{backend}")( + runfile_path, **kwargs + ) + + assert self.process is not None, 'Process should have been created' + if wait: + self.process.wait() return self.process def get_input_movie_path(self) -> Path: diff --git a/tests/test_core.py b/tests/test_core.py index b8a7419..d977eb0 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -1,5 +1,4 @@ import os - import numpy as np from caiman.utils.utils import load_dict_from_hdf5 from caiman.source_extraction.cnmf.cnmf import CNMF @@ -15,9 +14,12 @@ from mesmerize_core.batch_utils import ( DATAFRAME_COLUMNS, COMPUTE_BACKEND_SUBPROCESS, + COMPUTE_BACKEND_LOCAL, + COMPUTE_BACKEND_ASYNC, get_full_raw_data_path, ) from mesmerize_core.utils import IS_WINDOWS +from mesmerize_core.algorithms._utils import ensure_server from uuid import uuid4 import pytest import requests @@ -33,6 +35,7 @@ import tifffile from copy import deepcopy + # don't call "resolve" on these - want to make sure we can handle non-canonical paths correctly tmp_dir = Path(os.path.dirname(os.path.abspath(__file__)), "test data", "tmp") vid_dir = Path(os.path.dirname(os.path.abspath(__file__)), "test data", "videos") @@ -50,7 +53,7 @@ def _download_ground_truths(): print(f"Downloading ground truths") - url = f"https://zenodo.org/record/14934525/files/ground_truths.zip" + url = f"https://zenodo.org/record/17253218/files/ground_truths.zip" # basically from https://stackoverflow.com/questions/37573483/progress-bar-while-download-file-over-http-with-requests/37573701 response = requests.get(url, stream=True) @@ -1336,3 +1339,48 @@ def test_cache(): assert hex(id(cnmf.cnmf_cache.get_cache().iloc[-1]["return_val"])) == hex( id(output) ) + + +def test_backends(): + """test subprocess, local, and async_local backend""" + set_parent_raw_data_path(vid_dir) + algo = "mcorr" + df, batch_path = _create_tmp_batch() + input_movie_path = get_datafile(algo) + + # make small version of movie for quick testing + movie = tifffile.imread(input_movie_path) + small_movie_path = input_movie_path.parent.joinpath("small_movie.tif") + tifffile.imwrite(small_movie_path, movie[:1001]) + print(input_movie_path) + + # put backends that can run in the background first to save time + backends = [COMPUTE_BACKEND_SUBPROCESS, COMPUTE_BACKEND_ASYNC, COMPUTE_BACKEND_LOCAL] + for backend in backends: + df.caiman.add_item( + algo="mcorr", + item_name=f"test-{backend}", + input_movie_path=small_movie_path, + params=test_params["mcorr"], + ) + + # run using each backend + procs = [] + with ensure_server(None) as (dview, _): + for backend, (_, item) in zip(backends, df.iterrows()): + procs.append(item.caiman.run(backend=backend, dview=dview, wait=False)) + + # wait for all to finish + for proc in procs: + proc.wait() + + # compare results + df = load_batch(batch_path) + for i, item in df.iterrows(): + output = item.mcorr.get_output() + + if i == 0: + # save to compare to other results + first_output = output + else: + numpy.testing.assert_array_equal(output, first_output)