Skip to content

Commit e109a4d

Browse files
committed
Don't close dview if passed in; add context manager to help
1 parent 6514bc8 commit e109a4d

File tree

4 files changed

+287
-295
lines changed

4 files changed

+287
-295
lines changed
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import caiman as cm
2+
from contextlib import contextmanager
3+
from ipyparallel import DirectView
4+
from multiprocessing.pool import Pool
5+
import os
6+
import psutil
7+
from typing import Union, Optional, Generator
8+
9+
Cluster = Union[Pool, DirectView]
10+
11+
def get_n_processes(dview: Optional[Cluster]) -> int:
12+
"""Infer number of processes in a multiprocessing or ipyparallel cluster"""
13+
if isinstance(dview, Pool) and hasattr(dview, '_processes'):
14+
return dview._processes
15+
elif isinstance(dview, DirectView):
16+
return len(dview)
17+
else:
18+
return 1
19+
20+
21+
@contextmanager
22+
def ensure_server(dview: Optional[Cluster]) -> Generator[tuple[Cluster, int], None, None]:
23+
"""
24+
Context manager that passes through an existing 'dview' or
25+
opens up a multiprocessing server if none is passed in.
26+
If a server was opened, closes it upon exit.
27+
Usage: `with ensure_server(dview) as (dview, n_processes):`
28+
"""
29+
if dview is not None:
30+
yield dview, get_n_processes(dview)
31+
else:
32+
# no cluster passed in, so open one
33+
if "MESMERIZE_N_PROCESSES" in os.environ.keys():
34+
try:
35+
n_processes = int(os.environ["MESMERIZE_N_PROCESSES"])
36+
except:
37+
n_processes = psutil.cpu_count() - 1
38+
else:
39+
n_processes = psutil.cpu_count() - 1
40+
41+
# Start cluster for parallel processing
42+
_, dview, n_processes = cm.cluster.setup_cluster(
43+
backend="multiprocessing", n_processes=n_processes, single_thread=False
44+
)
45+
try:
46+
yield dview, n_processes
47+
finally:
48+
cm.stop_server(dview=dview)

mesmerize_core/algorithms/cnmf.py

Lines changed: 79 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,11 @@
1616
if __name__ in ["__main__", "__mp_main__"]: # when running in subprocess
1717
from mesmerize_core import set_parent_raw_data_path, load_batch
1818
from mesmerize_core.utils import IS_WINDOWS
19+
from mesmerize_core.algorithms._utils import ensure_server
1920
else: # when running with local backend
2021
from ..batch_utils import set_parent_raw_data_path, load_batch
2122
from ..utils import IS_WINDOWS
23+
from ._utils import ensure_server
2224

2325

2426
def run_algo(batch_path, uuid, data_path: str = None, dview=None):
@@ -45,107 +47,84 @@ async def run_algo_async(batch_path, uuid, data_path: str = None, dview=None):
4547
f"Starting CNMF item:\n{item}\nWith params:{params}"
4648
)
4749

48-
if 'multiprocessing' in str(type(dview)) and hasattr(dview, '_processes'):
49-
n_processes = dview._processes
50-
elif 'ipyparallel' in str(type(dview)):
51-
n_processes = len(dview)
52-
else:
53-
# adapted from current demo notebook
54-
if "MESMERIZE_N_PROCESSES" in os.environ.keys():
55-
try:
56-
n_processes = int(os.environ["MESMERIZE_N_PROCESSES"])
57-
except:
58-
n_processes = psutil.cpu_count() - 1
59-
else:
60-
n_processes = psutil.cpu_count() - 1
61-
# Start cluster for parallel processing
62-
c, dview, n_processes = cm.cluster.setup_cluster(
63-
backend="multiprocessing", n_processes=n_processes, single_thread=False
64-
)
65-
66-
# merge cnmf and eval kwargs into one dict
67-
cnmf_params = CNMFParams(params_dict=params["main"])
68-
# Run CNMF, denote boolean 'success' if CNMF completes w/out error
69-
try:
70-
fname_new = cm.save_memmap(
71-
[input_movie_path], base_name=f"{uuid}_cnmf-memmap_", order="C", dview=dview
72-
)
73-
74-
print("making memmap")
75-
76-
Yr, dims, T = cm.load_memmap(fname_new)
77-
images = np.reshape(Yr.T, [T] + list(dims), order="F")
78-
79-
proj_paths = dict()
80-
for proj_type in ["mean", "std", "max"]:
81-
p_img = getattr(np, f"nan{proj_type}")(images, axis=0)
82-
proj_paths[proj_type] = output_dir.joinpath(
83-
f"{uuid}_{proj_type}_projection.npy"
50+
with ensure_server(dview) as (dview, n_processes):
51+
52+
# merge cnmf and eval kwargs into one dict
53+
cnmf_params = CNMFParams(params_dict=params["main"])
54+
# Run CNMF, denote boolean 'success' if CNMF completes w/out error
55+
try:
56+
fname_new = cm.save_memmap(
57+
[input_movie_path], base_name=f"{uuid}_cnmf-memmap_", order="C", dview=dview
58+
)
59+
60+
print("making memmap")
61+
62+
Yr, dims, T = cm.load_memmap(fname_new)
63+
64+
images = np.reshape(Yr.T, [T] + list(dims), order="F")
65+
66+
proj_paths = dict()
67+
for proj_type in ["mean", "std", "max"]:
68+
p_img = getattr(np, f"nan{proj_type}")(images, axis=0)
69+
proj_paths[proj_type] = output_dir.joinpath(
70+
f"{uuid}_{proj_type}_projection.npy"
71+
)
72+
np.save(str(proj_paths[proj_type]), p_img)
73+
74+
print("performing CNMF")
75+
cnm = cnmf.CNMF(n_processes, params=cnmf_params, dview=dview)
76+
77+
print("fitting images")
78+
cnm = cnm.fit(images)
79+
#
80+
if "refit" in params.keys():
81+
if params["refit"] is True:
82+
print("refitting")
83+
cnm = cnm.refit(images, dview=dview)
84+
85+
print("performing eval")
86+
cnm.estimates.evaluate_components(images, cnm.params, dview=dview)
87+
88+
output_path = output_dir.joinpath(f"{uuid}.hdf5").resolve()
89+
90+
cnm.save(str(output_path))
91+
92+
Cn = cm.local_correlations(images.transpose(1, 2, 0))
93+
Cn[np.isnan(Cn)] = 0
94+
95+
corr_img_path = output_dir.joinpath(f"{uuid}_cn.npy").resolve()
96+
np.save(str(corr_img_path), Cn, allow_pickle=False)
97+
98+
# output dict for dataframe row (pd.Series)
99+
d = dict()
100+
101+
cnmf_memmap_path = output_dir.joinpath(Path(fname_new).name)
102+
if IS_WINDOWS:
103+
Yr._mmap.close() # accessing private attr but windows is annoying otherwise
104+
move_file(fname_new, cnmf_memmap_path)
105+
106+
# save paths as relative path strings with forward slashes
107+
cnmf_hdf5_path = str(PurePosixPath(output_path.relative_to(output_dir.parent)))
108+
cnmf_memmap_path = str(PurePosixPath(cnmf_memmap_path.relative_to(output_dir.parent)))
109+
corr_img_path = str(PurePosixPath(corr_img_path.relative_to(output_dir.parent)))
110+
for proj_type in proj_paths.keys():
111+
d[f"{proj_type}-projection-path"] = str(PurePosixPath(proj_paths[proj_type].relative_to(
112+
output_dir.parent
113+
)))
114+
115+
d.update(
116+
{
117+
"cnmf-hdf5-path": cnmf_hdf5_path,
118+
"cnmf-memmap-path": cnmf_memmap_path,
119+
"corr-img-path": corr_img_path,
120+
"success": True,
121+
"traceback": None,
122+
}
84123
)
85-
np.save(str(proj_paths[proj_type]), p_img)
86-
87-
# in fname new load in memmap order C
88-
cm.stop_server(dview=dview)
89-
c, dview, n_processes = cm.cluster.setup_cluster(
90-
backend="local", n_processes=None, single_thread=False
91-
)
92-
93-
print("performing CNMF")
94-
cnm = cnmf.CNMF(n_processes, params=cnmf_params, dview=dview)
95-
96-
print("fitting images")
97-
cnm = cnm.fit(images)
98-
#
99-
if "refit" in params.keys():
100-
if params["refit"] is True:
101-
print("refitting")
102-
cnm = cnm.refit(images, dview=dview)
103-
104-
print("performing eval")
105-
cnm.estimates.evaluate_components(images, cnm.params, dview=dview)
106-
107-
output_path = output_dir.joinpath(f"{uuid}.hdf5").resolve()
108-
109-
cnm.save(str(output_path))
110-
111-
Cn = cm.local_correlations(images.transpose(1, 2, 0))
112-
Cn[np.isnan(Cn)] = 0
113-
114-
corr_img_path = output_dir.joinpath(f"{uuid}_cn.npy").resolve()
115-
np.save(str(corr_img_path), Cn, allow_pickle=False)
116-
117-
# output dict for dataframe row (pd.Series)
118-
d = dict()
119-
120-
cnmf_memmap_path = output_dir.joinpath(Path(fname_new).name)
121-
if IS_WINDOWS:
122-
Yr._mmap.close() # accessing private attr but windows is annoying otherwise
123-
move_file(fname_new, cnmf_memmap_path)
124-
125-
# save paths as relative path strings with forward slashes
126-
cnmf_hdf5_path = str(PurePosixPath(output_path.relative_to(output_dir.parent)))
127-
cnmf_memmap_path = str(PurePosixPath(cnmf_memmap_path.relative_to(output_dir.parent)))
128-
corr_img_path = str(PurePosixPath(corr_img_path.relative_to(output_dir.parent)))
129-
for proj_type in proj_paths.keys():
130-
d[f"{proj_type}-projection-path"] = str(PurePosixPath(proj_paths[proj_type].relative_to(
131-
output_dir.parent
132-
)))
133-
134-
d.update(
135-
{
136-
"cnmf-hdf5-path": cnmf_hdf5_path,
137-
"cnmf-memmap-path": cnmf_memmap_path,
138-
"corr-img-path": corr_img_path,
139-
"success": True,
140-
"traceback": None,
141-
}
142-
)
143-
144-
except:
145-
d = {"success": False, "traceback": traceback.format_exc()}
146-
147-
cm.stop_server(dview=dview)
148-
124+
125+
except:
126+
d = {"success": False, "traceback": traceback.format_exc()}
127+
149128
runtime = round(time.time() - algo_start, 2)
150129
df.caiman.update_item_with_results(uuid, d, runtime)
151130

0 commit comments

Comments
 (0)