diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 5bf374746..4c18079be 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -26,7 +26,7 @@ """ from typing import (Mapping, Dict, Union, Set, Tuple, Any, FrozenSet, - TYPE_CHECKING) + Type, TYPE_CHECKING) from pytato.array import (Array, IndexLambda, Stack, Concatenate, Einsum, DictOfNamedArrays, NamedArray, IndexBase, IndexRemappingBase, InputArgumentBase, @@ -381,23 +381,38 @@ def map_named_call_result(self, expr: NamedCallResult) -> FrozenSet[Array]: @optimize_mapper(drop_args=True, drop_kwargs=True, inline_get_cache_key=True) class NodeCountMapper(CachedWalkMapper): """ - Counts the number of nodes in a DAG. + Counts the number of nodes of a given type in a DAG. - .. attribute:: count + .. attribute:: counts - The number of nodes. + Dictionary mapping node types to number of nodes of that type. """ def __init__(self) -> None: + from collections import defaultdict super().__init__() - self.count = 0 + self.counts = defaultdict(int) def get_cache_key(self, expr: ArrayOrNames) -> int: return id(expr) def post_visit(self, expr: Any) -> None: - self.count += 1 + self.counts[type(expr)] += 1 + + +def get_node_type_counts(outputs: Union[Array, DictOfNamedArrays]) -> Dict[Type, int]: + """ + Returns a dictionary mapping node types to node count for that type + in DAG *outputs*. + """ + + from pytato.codegen import normalize_outputs + outputs = normalize_outputs(outputs) + + ncm = NodeTypeCountMapper() + ncm(outputs) + return ncm.counts def get_num_nodes(outputs: Union[Array, DictOfNamedArrays]) -> int: """Returns the number of nodes in DAG *outputs*.""" @@ -408,7 +423,7 @@ def get_num_nodes(outputs: Union[Array, DictOfNamedArrays]) -> int: ncm = NodeCountMapper() ncm(outputs) - return ncm.count + return sum(ncm.counts.values()) # }}}