Skip to content

Commit 4e67a6b

Browse files
committed
Make dview compatible with other cluster types (conforming to protocol)
1 parent a88a873 commit 4e67a6b

File tree

1 file changed

+21
-5
lines changed

1 file changed

+21
-5
lines changed

mesmerize_core/algorithms/_utils.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,33 @@
11
from contextlib import contextmanager
22
import os
33
import psutil
4-
from typing import Optional, Union, Generator
4+
from typing import (Optional, Union, Generator, Protocol,
5+
Callable, TypeVar, Sequence, Iterable, runtime_checkable)
56

67
import caiman as cm
78
from caiman.cluster import setup_cluster
89
from ipyparallel import DirectView
910
from multiprocessing.pool import Pool
1011

1112

12-
Cluster = Union[Pool, DirectView]
13+
@runtime_checkable
14+
class CustomCluster(Protocol):
15+
"""Protocol for a cluster that is not a multiprocessing pool"""
16+
RetVal = TypeVar('RetVal')
17+
def map_sync(self, fn: Callable[..., RetVal], args: Iterable) -> Sequence[RetVal]:
18+
...
19+
20+
def __len__(self) -> int:
21+
"""return number of workers"""
22+
...
23+
24+
Cluster = Union[Pool, DirectView, CustomCluster]
1325

1426
def get_n_processes(dview: Optional[Cluster]) -> int:
1527
"""Infer number of processes in a multiprocessing or ipyparallel cluster"""
1628
if isinstance(dview, Pool) and hasattr(dview, '_processes'):
1729
return dview._processes # type: ignore
18-
elif isinstance(dview, DirectView):
30+
elif isinstance(dview, CustomCluster):
1931
return len(dview)
2032
else:
2133
return 1
@@ -33,13 +45,17 @@ def ensure_server(dview: Optional[Cluster]) -> Generator[tuple[Cluster, int], No
3345
yield dview, get_n_processes(dview)
3446
else:
3547
# no cluster passed in, so open one
48+
procs_available = psutil.cpu_count()
49+
if procs_available is None:
50+
raise RuntimeError('Cannot determine number of processes')
51+
3652
if "MESMERIZE_N_PROCESSES" in os.environ.keys():
3753
try:
3854
n_processes = int(os.environ["MESMERIZE_N_PROCESSES"])
3955
except:
40-
n_processes = psutil.cpu_count() - 1
56+
n_processes = procs_available - 1
4157
else:
42-
n_processes = psutil.cpu_count() - 1
58+
n_processes = procs_available - 1
4359

4460
# Start cluster for parallel processing
4561
_, dview, n_processes = setup_cluster(

0 commit comments

Comments
 (0)