11from contextlib import contextmanager
22import os
33import psutil
4- from typing import Optional , Union , Generator
4+ from typing import (Optional , Union , Generator , Protocol ,
5+ Callable , TypeVar , Sequence , Iterable , runtime_checkable )
56
67import caiman as cm
78from caiman .cluster import setup_cluster
89from ipyparallel import DirectView
910from multiprocessing .pool import Pool
1011
1112
12- Cluster = Union [Pool , DirectView ]
13+ @runtime_checkable
14+ class CustomCluster (Protocol ):
15+ """Protocol for a cluster that is not a multiprocessing pool"""
16+ RetVal = TypeVar ('RetVal' )
17+ def map_sync (self , fn : Callable [..., RetVal ], args : Iterable ) -> Sequence [RetVal ]:
18+ ...
19+
20+ def __len__ (self ) -> int :
21+ """return number of workers"""
22+ ...
23+
24+ Cluster = Union [Pool , DirectView , CustomCluster ]
1325
1426def get_n_processes (dview : Optional [Cluster ]) -> int :
1527 """Infer number of processes in a multiprocessing or ipyparallel cluster"""
1628 if isinstance (dview , Pool ) and hasattr (dview , '_processes' ):
1729 return dview ._processes # type: ignore
18- elif isinstance (dview , DirectView ):
30+ elif isinstance (dview , CustomCluster ):
1931 return len (dview )
2032 else :
2133 return 1
@@ -33,13 +45,17 @@ def ensure_server(dview: Optional[Cluster]) -> Generator[tuple[Cluster, int], No
3345 yield dview , get_n_processes (dview )
3446 else :
3547 # no cluster passed in, so open one
48+ procs_available = psutil .cpu_count ()
49+ if procs_available is None :
50+ raise RuntimeError ('Cannot determine number of processes' )
51+
3652 if "MESMERIZE_N_PROCESSES" in os .environ .keys ():
3753 try :
3854 n_processes = int (os .environ ["MESMERIZE_N_PROCESSES" ])
3955 except :
40- n_processes = psutil . cpu_count () - 1
56+ n_processes = procs_available - 1
4157 else :
42- n_processes = psutil . cpu_count () - 1
58+ n_processes = procs_available - 1
4359
4460 # Start cluster for parallel processing
4561 _ , dview , n_processes = setup_cluster (
0 commit comments