Skip to content

Commit 85f5f53

Browse files
committed
Don't close dview if passed in; add context manager to help
1 parent bf4dabd commit 85f5f53

File tree

4 files changed

+281
-289
lines changed

4 files changed

+281
-289
lines changed
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import caiman as cm
2+
from contextlib import contextmanager
3+
from ipyparallel import DirectView
4+
from multiprocessing.pool import Pool
5+
import os
6+
import psutil
7+
from typing import Union, Optional, Generator
8+
9+
Cluster = Union[Pool, DirectView]
10+
11+
def get_n_processes(dview: Optional[Cluster]) -> int:
12+
"""Infer number of processes in a multiprocessing or ipyparallel cluster"""
13+
if isinstance(dview, Pool) and hasattr(dview, '_processes'):
14+
return dview._processes
15+
elif isinstance(dview, DirectView):
16+
return len(dview)
17+
else:
18+
return 1
19+
20+
21+
@contextmanager
22+
def ensure_server(dview: Optional[Cluster]) -> Generator[tuple[Cluster, int], None, None]:
23+
"""
24+
Context manager that passes through an existing 'dview' or
25+
opens up a multiprocessing server if none is passed in.
26+
If a server was opened, closes it upon exit.
27+
Usage: `with ensure_server(dview) as (dview, n_processes):`
28+
"""
29+
if dview is not None:
30+
yield dview, get_n_processes(dview)
31+
else:
32+
# no cluster passed in, so open one
33+
if "MESMERIZE_N_PROCESSES" in os.environ.keys():
34+
try:
35+
n_processes = int(os.environ["MESMERIZE_N_PROCESSES"])
36+
except:
37+
n_processes = psutil.cpu_count() - 1
38+
else:
39+
n_processes = psutil.cpu_count() - 1
40+
41+
# Start cluster for parallel processing
42+
_, dview, n_processes = cm.cluster.setup_cluster(
43+
backend="multiprocessing", n_processes=n_processes, single_thread=False
44+
)
45+
try:
46+
yield dview, n_processes
47+
finally:
48+
cm.stop_server(dview=dview)

mesmerize_core/algorithms/cnmf.py

Lines changed: 77 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,11 @@
1717
if __name__ in ["__main__", "__mp_main__"]: # when running in subprocess
1818
from mesmerize_core import set_parent_raw_data_path, load_batch
1919
from mesmerize_core.utils import IS_WINDOWS
20+
from mesmerize_core.algorithms._utils import ensure_server
2021
else: # when running with local backend
2122
from ..batch_utils import set_parent_raw_data_path, load_batch
2223
from ..utils import IS_WINDOWS
24+
from ._utils import ensure_server
2325

2426

2527
def run_algo(batch_path, uuid, data_path: str = None, dview=None):
@@ -46,108 +48,85 @@ async def run_algo_async(batch_path, uuid, data_path: str = None, dview=None):
4648
f"Starting CNMF item:\n{item}\nWith params:{params}"
4749
)
4850

49-
if 'multiprocessing' in str(type(dview)) and hasattr(dview, '_processes'):
50-
n_processes = dview._processes
51-
elif 'ipyparallel' in str(type(dview)):
52-
n_processes = len(dview)
53-
else:
54-
# adapted from current demo notebook
55-
if "MESMERIZE_N_PROCESSES" in os.environ.keys():
56-
try:
57-
n_processes = int(os.environ["MESMERIZE_N_PROCESSES"])
58-
except:
59-
n_processes = psutil.cpu_count() - 1
60-
else:
61-
n_processes = psutil.cpu_count() - 1
62-
# Start cluster for parallel processing
63-
c, dview, n_processes = cm.cluster.setup_cluster(
64-
backend="multiprocessing", n_processes=n_processes, single_thread=False
65-
)
66-
67-
# merge cnmf and eval kwargs into one dict
68-
cnmf_params = CNMFParams(params_dict=params["main"])
69-
# Run CNMF, denote boolean 'success' if CNMF completes w/out error
70-
try:
71-
fname_new = cm.save_memmap(
72-
[input_movie_path], base_name=f"{uuid}_cnmf-memmap_", order="C", dview=dview
73-
)
74-
75-
print("making memmap")
76-
77-
Yr, dims, T = cm.load_memmap(fname_new)
78-
images = np.reshape(Yr.T, [T] + list(dims), order="F")
79-
80-
proj_paths = dict()
81-
for proj_type in ["mean", "std", "max"]:
82-
p_img = getattr(np, f"nan{proj_type}")(images, axis=0)
83-
proj_paths[proj_type] = output_dir.joinpath(
84-
f"{uuid}_{proj_type}_projection.npy"
85-
)
86-
np.save(str(proj_paths[proj_type]), p_img)
87-
88-
# in fname new load in memmap order C
89-
cm.stop_server(dview=dview)
90-
c, dview, n_processes = cm.cluster.setup_cluster(
91-
backend="local", n_processes=None, single_thread=False
92-
)
93-
94-
print("performing CNMF")
95-
cnm = cnmf.CNMF(n_processes, params=cnmf_params, dview=dview)
96-
97-
print("fitting images")
98-
cnm.fit(images)
99-
#
100-
if "refit" in params.keys():
101-
if params["refit"] is True:
102-
print("refitting")
103-
cnm = cnm.refit(images, dview=dview)
104-
105-
print("performing eval")
106-
cnm.estimates.evaluate_components(images, cnm.params, dview=dview)
107-
108-
output_path = output_dir.joinpath(f"{uuid}.hdf5")
109-
110-
cnm.save(str(output_path))
111-
112-
Cn = cm.local_correlations(images, swap_dim=False)
113-
Cn[np.isnan(Cn)] = 0
114-
115-
corr_img_path = output_dir.joinpath(f"{uuid}_cn.npy")
116-
np.save(str(corr_img_path), Cn, allow_pickle=False)
117-
118-
# output dict for dataframe row (pd.Series)
119-
d = dict()
120-
121-
cnmf_memmap_path = output_dir.joinpath(Path(fname_new).name)
122-
if IS_WINDOWS:
123-
Yr._mmap.close() # accessing private attr but windows is annoying otherwise
124-
move_file(fname_new, cnmf_memmap_path)
125-
126-
# save paths as relative path strings with forward slashes
127-
cnmf_hdf5_path = str(PurePosixPath(output_path.relative_to(output_dir.parent)))
128-
cnmf_memmap_path = str(
129-
PurePosixPath(cnmf_memmap_path.relative_to(output_dir.parent))
130-
)
131-
corr_img_path = str(PurePosixPath(corr_img_path.relative_to(output_dir.parent)))
132-
for proj_type in proj_paths.keys():
133-
d[f"{proj_type}-projection-path"] = str(
134-
PurePosixPath(proj_paths[proj_type].relative_to(output_dir.parent))
51+
with ensure_server(dview) as (dview, n_processes):
52+
53+
# merge cnmf and eval kwargs into one dict
54+
cnmf_params = CNMFParams(params_dict=params["main"])
55+
# Run CNMF, denote boolean 'success' if CNMF completes w/out error
56+
try:
57+
fname_new = cm.save_memmap(
58+
[input_movie_path], base_name=f"{uuid}_cnmf-memmap_", order="C", dview=dview
13559
)
13660

137-
d.update(
138-
{
139-
"cnmf-hdf5-path": cnmf_hdf5_path,
140-
"cnmf-memmap-path": cnmf_memmap_path,
141-
"corr-img-path": corr_img_path,
142-
"success": True,
143-
"traceback": None,
144-
}
145-
)
61+
print("making memmap")
62+
63+
Yr, dims, T = cm.load_memmap(fname_new)
64+
65+
images = np.reshape(Yr.T, [T] + list(dims), order="F")
66+
67+
proj_paths = dict()
68+
for proj_type in ["mean", "std", "max"]:
69+
p_img = getattr(np, f"nan{proj_type}")(images, axis=0)
70+
proj_paths[proj_type] = output_dir.joinpath(
71+
f"{uuid}_{proj_type}_projection.npy"
72+
)
73+
np.save(str(proj_paths[proj_type]), p_img)
74+
75+
print("performing CNMF")
76+
cnm = cnmf.CNMF(n_processes, params=cnmf_params, dview=dview)
77+
78+
print("fitting images")
79+
cnm.fit(images)
80+
#
81+
if "refit" in params.keys():
82+
if params["refit"] is True:
83+
print("refitting")
84+
cnm = cnm.refit(images, dview=dview)
85+
86+
print("performing eval")
87+
cnm.estimates.evaluate_components(images, cnm.params, dview=dview)
14688

147-
except:
148-
d = {"success": False, "traceback": traceback.format_exc()}
89+
output_path = output_dir.joinpath(f"{uuid}.hdf5")
90+
91+
cnm.save(str(output_path))
92+
93+
Cn = cm.local_correlations(images, swap_dim=False)
94+
Cn[np.isnan(Cn)] = 0
95+
96+
corr_img_path = output_dir.joinpath(f"{uuid}_cn.npy")
97+
np.save(str(corr_img_path), Cn, allow_pickle=False)
98+
99+
# output dict for dataframe row (pd.Series)
100+
d = dict()
101+
102+
cnmf_memmap_path = output_dir.joinpath(Path(fname_new).name)
103+
if IS_WINDOWS:
104+
Yr._mmap.close() # accessing private attr but windows is annoying otherwise
105+
move_file(fname_new, cnmf_memmap_path)
106+
107+
# save paths as relative path strings with forward slashes
108+
cnmf_hdf5_path = str(PurePosixPath(output_path.relative_to(output_dir.parent)))
109+
cnmf_memmap_path = str(
110+
PurePosixPath(cnmf_memmap_path.relative_to(output_dir.parent))
111+
)
112+
corr_img_path = str(PurePosixPath(corr_img_path.relative_to(output_dir.parent)))
113+
for proj_type in proj_paths.keys():
114+
d[f"{proj_type}-projection-path"] = str(
115+
PurePosixPath(proj_paths[proj_type].relative_to(output_dir.parent))
116+
)
117+
118+
d.update(
119+
{
120+
"cnmf-hdf5-path": cnmf_hdf5_path,
121+
"cnmf-memmap-path": cnmf_memmap_path,
122+
"corr-img-path": corr_img_path,
123+
"success": True,
124+
"traceback": None,
125+
}
126+
)
149127

150-
cm.stop_server(dview=dview)
128+
except:
129+
d = {"success": False, "traceback": traceback.format_exc()}
151130

152131
runtime = round(time.time() - algo_start, 2)
153132
df.caiman.update_item_with_results(uuid, d, runtime)

0 commit comments

Comments
 (0)