File tree Expand file tree Collapse file tree 1 file changed +9
-3
lines changed Expand file tree Collapse file tree 1 file changed +9
-3
lines changed Original file line number Diff line number Diff line change @@ -426,18 +426,24 @@ class NodeTypeCountMapper(CachedWalkMapper):
426426 """
427427
428428 def __init__ (self ) -> None :
429+ from collections import defaultdict
429430 super ().__init__ ()
430- self .counts = {}
431+ self .counts = defaultdict ( int )
431432
432433 def get_cache_key (self , expr : ArrayOrNames ) -> int :
433434 return id (expr )
434435
435436 def post_visit (self , expr : Any ) -> None :
437+ if type (expr ) not in counts :
438+ self .counts [type (expr )] = 0
436439 self .counts [type (expr )] += 1
437440
438441
439- def get_num_node_types (outputs : Union [Array , DictOfNamedArrays ]) -> int :
440- """Returns the number of nodes of each given type in DAG *outputs*."""
442+ def get_num_node_types (outputs : Union [Array , DictOfNamedArrays ]) -> Dict [Type , int ]:
443+ """
444+ Returns a dictionary mapping node types to node count for that type
445+ in DAG *outputs*.
446+ """
441447
442448 from pytato .codegen import normalize_outputs
443449 outputs = normalize_outputs (outputs )
You can’t perform that action at this time.
0 commit comments