|
6 | 6 |
|
7 | 7 | import caiman as cm |
8 | 8 | from caiman.cluster import setup_cluster |
9 | | -from ipyparallel import DirectView |
10 | 9 | from multiprocessing.pool import Pool |
11 | 10 |
|
12 | 11 |
|
| 12 | +RetVal = TypeVar("RetVal") |
13 | 13 | @runtime_checkable |
14 | 14 | class CustomCluster(Protocol): |
15 | | - """Protocol for a cluster that is not a multiprocessing pool""" |
16 | | - RetVal = TypeVar('RetVal') |
17 | | - def map_sync(self, fn: Callable[..., RetVal], args: Iterable) -> Sequence[RetVal]: |
18 | | - ... |
19 | | - |
| 15 | + """ |
| 16 | + Protocol for a cluster that is not a multiprocessing pool |
| 17 | + (including ipyparallel.DirectView) |
| 18 | + """ |
| 19 | + |
| 20 | + def map_sync( |
| 21 | + self, fn: Callable[..., RetVal], args: Iterable |
| 22 | + ) -> Sequence[RetVal]: ... |
| 23 | + |
20 | 24 | def __len__(self) -> int: |
21 | 25 | """return number of workers""" |
22 | 26 | ... |
23 | 27 |
|
24 | | -Cluster = Union[Pool, DirectView, CustomCluster] |
| 28 | + |
| 29 | +Cluster = Union[Pool, CustomCluster] |
| 30 | + |
25 | 31 |
|
26 | 32 | def get_n_processes(dview: Optional[Cluster]) -> int: |
27 | 33 | """Infer number of processes in a multiprocessing or ipyparallel cluster""" |
28 | | - if isinstance(dview, Pool) and hasattr(dview, '_processes'): |
| 34 | + if isinstance(dview, Pool): |
| 35 | + assert hasattr(dview, '_processes'), "Pool not keeping track of # of processes?" |
29 | 36 | return dview._processes # type: ignore |
30 | | - elif isinstance(dview, CustomCluster): |
| 37 | + elif dview is not None: |
31 | 38 | return len(dview) |
32 | 39 | else: |
33 | 40 | return 1 |
|
0 commit comments