Skip to content

Commit 8333bc8

Browse files
committed
switch from using asyncio to concurrent.futures and add test
1 parent e109a4d commit 8333bc8

File tree

5 files changed

+92
-33
lines changed

5 files changed

+92
-33
lines changed

mesmerize_core/algorithms/cnmf.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,12 @@
11
"""Performs CNMF in a separate process"""
2-
import asyncio
32
import click
43
import caiman as cm
54
from caiman.source_extraction.cnmf import cnmf as cnmf
65
from caiman.source_extraction.cnmf.params import CNMFParams
7-
import psutil
86
import numpy as np
97
import traceback
108
from pathlib import Path, PurePosixPath
119
from shutil import move as move_file
12-
import os
1310
import time
1411

1512
# prevent circular import
@@ -24,9 +21,6 @@
2421

2522

2623
def run_algo(batch_path, uuid, data_path: str = None, dview=None):
27-
asyncio.run(run_algo_async(batch_path, uuid, data_path=data_path, dview=dview))
28-
29-
async def run_algo_async(batch_path, uuid, data_path: str = None, dview=None):
3024
algo_start = time.time()
3125
set_parent_raw_data_path(data_path)
3226

mesmerize_core/algorithms/cnmfe.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,11 @@
1-
import asyncio
21
import click
32
import numpy as np
43
import caiman as cm
54
from caiman.source_extraction.cnmf import cnmf as cnmf
65
from caiman.source_extraction.cnmf.params import CNMFParams
7-
import psutil
86
import traceback
97
from pathlib import Path, PurePosixPath
108
from shutil import move as move_file
11-
import os
129
import time
1310

1411
if __name__ in ["__main__", "__mp_main__"]: # when running in subprocess
@@ -22,9 +19,6 @@
2219

2320

2421
def run_algo(batch_path, uuid, data_path: str = None, dview=None):
25-
asyncio.run(run_algo_async(batch_path, uuid, data_path=data_path, dview=dview))
26-
27-
async def run_algo_async(batch_path, uuid, data_path: str = None, dview=None):
2822
algo_start = time.time()
2923
set_parent_raw_data_path(data_path)
3024

mesmerize_core/algorithms/mcorr.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
11
import traceback
2-
import asyncio
32
import click
43
import caiman as cm
54
from caiman.source_extraction.cnmf.params import CNMFParams
65
from caiman.motion_correction import MotionCorrect
76
from caiman.summary_images import local_correlations_movie_offline
8-
import psutil
97
import os
108
from pathlib import Path, PurePosixPath
119
import numpy as np
@@ -22,9 +20,6 @@
2220

2321

2422
def run_algo(batch_path, uuid, data_path: str = None, dview=None):
25-
asyncio.run(run_algo_async(batch_path, uuid, data_path=data_path, dview=dview))
26-
27-
async def run_algo_async(batch_path, uuid, data_path: str = None, dview=None):
2823
algo_start = time.time()
2924
set_parent_raw_data_path(data_path)
3025

mesmerize_core/caiman_extensions/common.py

Lines changed: 38 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import time
1111
from copy import deepcopy
1212
import shlex
13-
import asyncio
13+
from concurrent.futures import ThreadPoolExecutor, Future
1414

1515
import numpy as np
1616
import pandas as pd
@@ -460,12 +460,26 @@ def get_parent(self, index: Union[int, str, UUID]) -> Union[UUID, None]:
460460
return r["uuid"]
461461

462462

463-
class DummyProcess:
463+
class Waitable(Protocol):
464+
"""An object that we can call "wait" on"""
465+
def wait(self) -> None: ...
466+
467+
468+
class DummyProcess(Waitable):
464469
"""Dummy process for local backend"""
465-
def wait(self):
470+
def wait(self) -> None:
466471
pass
467472

468473

474+
class WaitableFuture(Waitable):
475+
"""Adaptor for future returned from Executor.submit"""
476+
def __init__(self, future: Future[None]):
477+
self.future = future
478+
479+
def wait(self) -> None:
480+
return self.future.result()
481+
482+
469483
@pd.api.extensions.register_series_accessor("caiman")
470484
class CaimanSeriesExtensions:
471485
"""
@@ -474,7 +488,7 @@ class CaimanSeriesExtensions:
474488

475489
def __init__(self, s: pd.Series):
476490
self._series = s
477-
self.process: Popen = None
491+
self.process: Optional[Waitable] = None
478492

479493
def _run_local(
480494
self,
@@ -484,9 +498,15 @@ def _run_local(
484498
data_path: Union[Path, None],
485499
dview=None
486500
) -> DummyProcess:
487-
coroutine = self._run_local_async(algo, batch_path, uuid, data_path, dview)
488-
asyncio.run(coroutine)
489-
return DummyProcess()
501+
algo_module = getattr(algorithms, algo)
502+
algo_module.run_algo(
503+
batch_path=str(batch_path),
504+
uuid=str(uuid),
505+
data_path=str(data_path),
506+
dview=dview
507+
)
508+
self.process = DummyProcess()
509+
return self.process
490510

491511
def _run_local_async(
492512
self,
@@ -495,14 +515,18 @@ def _run_local_async(
495515
uuid: UUID,
496516
data_path: Union[Path, None],
497517
dview=None
498-
) -> Coroutine:
518+
) -> WaitableFuture:
499519
algo_module = getattr(algorithms, algo)
500-
return algo_module.run_algo_async(
501-
batch_path=str(batch_path),
502-
uuid=str(uuid),
503-
data_path=str(data_path),
504-
dview=dview
505-
)
520+
with ThreadPoolExecutor(max_workers=1) as executor:
521+
future = executor.submit(
522+
algo_module.run_algo,
523+
batch_path=str(batch_path),
524+
uuid=str(uuid),
525+
data_path=str(data_path),
526+
dview=dview
527+
)
528+
self.process = WaitableFuture(future)
529+
return self.process
506530

507531
def _run_subprocess(
508532
self,

tests/test_core.py

Lines changed: 54 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import os
2-
32
import numpy as np
43
from caiman.utils.utils import load_dict_from_hdf5
54
from caiman.source_extraction.cnmf import cnmf
@@ -12,8 +11,14 @@
1211
CaimanSeriesExtensions,
1312
set_parent_raw_data_path,
1413
)
15-
from mesmerize_core.batch_utils import DATAFRAME_COLUMNS, COMPUTE_BACKEND_SUBPROCESS, get_full_raw_data_path
14+
from mesmerize_core.batch_utils import (
15+
DATAFRAME_COLUMNS,
16+
COMPUTE_BACKEND_SUBPROCESS,
17+
COMPUTE_BACKEND_LOCAL,
18+
COMPUTE_BACKEND_ASYNC,
19+
get_full_raw_data_path)
1620
from mesmerize_core.utils import IS_WINDOWS
21+
from mesmerize_core.algorithms._utils import ensure_server
1722
from uuid import uuid4
1823
from typing import *
1924
import pytest
@@ -30,6 +35,8 @@
3035
import tifffile
3136
from copy import deepcopy
3237

38+
pytest_plugins = ('pytest_asyncio',)
39+
3340
tmp_dir = Path(os.path.dirname(os.path.abspath(__file__)), "tmp")
3441
vid_dir = Path(os.path.dirname(os.path.abspath(__file__)), "videos")
3542
ground_truths_dir = Path(os.path.dirname(os.path.abspath(__file__)), "ground_truths")
@@ -1254,3 +1261,48 @@ def test_cache():
12541261
output2 = df.iloc[1].cnmf.get_output(return_copy=False)
12551262
assert(hex(id(output)) == hex(id(output2)))
12561263
assert(hex(id(cnmf.cnmf_cache.get_cache().iloc[-1]["return_val"])) == hex(id(output)))
1264+
1265+
1266+
def test_backends():
1267+
"""test subprocess, local, and async_local backend"""
1268+
set_parent_raw_data_path(vid_dir)
1269+
algo = "mcorr"
1270+
df, batch_path = _create_tmp_batch()
1271+
input_movie_path = get_datafile(algo)
1272+
1273+
# make small version of movie for quick testing
1274+
movie = tifffile.imread(input_movie_path)
1275+
small_movie_path = input_movie_path.parent.joinpath("small_movie.tif")
1276+
tifffile.imwrite(small_movie_path, movie[:1001])
1277+
print(input_movie_path)
1278+
1279+
# put backends that can run in the background first to save time
1280+
backends = [COMPUTE_BACKEND_SUBPROCESS, COMPUTE_BACKEND_ASYNC, COMPUTE_BACKEND_LOCAL]
1281+
for backend in backends:
1282+
df.caiman.add_item(
1283+
algo="mcorr",
1284+
item_name=f"test-{backend}",
1285+
input_movie_path=small_movie_path,
1286+
params=test_params["mcorr"],
1287+
)
1288+
1289+
# run using each backend
1290+
procs = []
1291+
with ensure_server(None) as (dview, _):
1292+
for backend, (_, item) in zip(backends, df.iterrows()):
1293+
procs.append(item.caiman.run(backend=backend, dview=dview, wait=False))
1294+
1295+
# wait for all to finish
1296+
for proc in procs:
1297+
proc.wait()
1298+
1299+
# compare results
1300+
df = load_batch(batch_path)
1301+
for i, item in df.iterrows():
1302+
output = item.mcorr.get_output()
1303+
1304+
if i == 0:
1305+
# save to compare to other results
1306+
first_output = output
1307+
else:
1308+
numpy.testing.assert_array_equal(output, first_output)

0 commit comments

Comments
 (0)