File tree Expand file tree Collapse file tree 1 file changed +37
-0
lines changed Expand file tree Collapse file tree 1 file changed +37
-0
lines changed Original file line number Diff line number Diff line change @@ -413,6 +413,43 @@ def get_num_nodes(outputs: Union[Array, DictOfNamedArrays]) -> int:
413413# }}}
414414
415415
416+ # {{{ NodeTypeCountMapper
417+
418+ @optimize_mapper (drop_args = True , drop_kwargs = True , inline_get_cache_key = True )
419+ class NodeTypeCountMapper (CachedWalkMapper ):
420+ """
421+ Counts the number of nodes of a given type in a DAG.
422+
423+ .. attribute:: counts
424+
425+ Dictionary mapping node types to number of nodes of that type.
426+ """
427+
428+ def __init__ (self ) -> None :
429+ super ().__init__ ()
430+ self .counts = {}
431+
432+ def get_cache_key (self , expr : ArrayOrNames ) -> int :
433+ return id (expr )
434+
435+ def post_visit (self , expr : Any ) -> None :
436+ self .counts [type (expr )] += 1
437+
438+
439+ def get_num_node_types (outputs : Union [Array , DictOfNamedArrays ]) -> int :
440+ """Returns the number of nodes of each given type in DAG *outputs*."""
441+
442+ from pytato .codegen import normalize_outputs
443+ outputs = normalize_outputs (outputs )
444+
445+ ncm = NodeTypeCountMapper ()
446+ ncm (outputs )
447+
448+ return ncm .counts
449+
450+ # }}}
451+
452+
416453# {{{ CallSiteCountMapper
417454
418455@optimize_mapper (drop_args = True , drop_kwargs = True , inline_get_cache_key = True )
You can’t perform that action at this time.
0 commit comments