|
26 | 26 | """ |
27 | 27 |
|
28 | 28 | from typing import (Mapping, Dict, Union, Set, Tuple, Any, FrozenSet, |
29 | | - TYPE_CHECKING) |
| 29 | + Type, TYPE_CHECKING) |
30 | 30 | from pytato.array import (Array, IndexLambda, Stack, Concatenate, Einsum, |
31 | 31 | DictOfNamedArrays, NamedArray, |
32 | 32 | IndexBase, IndexRemappingBase, InputArgumentBase, |
@@ -426,18 +426,24 @@ class NodeTypeCountMapper(CachedWalkMapper): |
426 | 426 | """ |
427 | 427 |
|
428 | 428 | def __init__(self) -> None: |
| 429 | + from collections import defaultdict |
429 | 430 | super().__init__() |
430 | | - self.counts = {} |
| 431 | + self.counts = defaultdict(int) |
431 | 432 |
|
432 | 433 | def get_cache_key(self, expr: ArrayOrNames) -> int: |
433 | 434 | return id(expr) |
434 | 435 |
|
435 | 436 | def post_visit(self, expr: Any) -> None: |
| 437 | + if type(expr) not in self.counts: |
| 438 | + self.counts[type(expr)] = 0 |
436 | 439 | self.counts[type(expr)] += 1 |
437 | 440 |
|
438 | 441 |
|
439 | | -def get_num_node_types(outputs: Union[Array, DictOfNamedArrays]) -> int: |
440 | | - """Returns the number of nodes of each given type in DAG *outputs*.""" |
| 442 | +def get_num_node_types(outputs: Union[Array, DictOfNamedArrays]) -> Dict[Type, int]: |
| 443 | + """ |
| 444 | + Returns a dictionary mapping node types to node count for that type |
| 445 | + in DAG *outputs*. |
| 446 | + """ |
441 | 447 |
|
442 | 448 | from pytato.codegen import normalize_outputs |
443 | 449 | outputs = normalize_outputs(outputs) |
|
0 commit comments