File tree Expand file tree Collapse file tree 1 file changed +37
-0
lines changed Expand file tree Collapse file tree 1 file changed +37
-0
lines changed Original file line number Diff line number Diff line change @@ -413,6 +413,43 @@ def get_num_nodes(outputs: Union[Array, DictOfNamedArrays]) -> int:
413413# }}} 
414414
415415
416+ # {{{ NodeTypeCountMapper 
417+ 
418+ @optimize_mapper (drop_args = True , drop_kwargs = True , inline_get_cache_key = True ) 
419+ class  NodeTypeCountMapper (CachedWalkMapper ):
420+     """ 
421+     Counts the number of nodes of a given type in a DAG. 
422+ 
423+     .. attribute:: counts 
424+ 
425+        Dictionary mapping node types to number of nodes of that type. 
426+     """ 
427+ 
428+     def  __init__ (self ) ->  None :
429+         super ().__init__ ()
430+         self .counts  =  {}
431+ 
432+     def  get_cache_key (self , expr : ArrayOrNames ) ->  int :
433+         return  id (expr )
434+ 
435+     def  post_visit (self , expr : Any ) ->  None :
436+         self .counts [type (expr )] +=  1 
437+ 
438+ 
439+ def  get_num_node_types (outputs : Union [Array , DictOfNamedArrays ]) ->  int :
440+     """Returns the number of nodes of each given type in DAG *outputs*.""" 
441+ 
442+     from  pytato .codegen  import  normalize_outputs 
443+     outputs  =  normalize_outputs (outputs )
444+ 
445+     ncm  =  NodeTypeCountMapper ()
446+     ncm (outputs )
447+ 
448+     return  ncm .counts 
449+ 
450+ # }}} 
451+ 
452+ 
416453# {{{ CallSiteCountMapper 
417454
418455@optimize_mapper (drop_args = True , drop_kwargs = True , inline_get_cache_key = True ) 
    
 
   
 
     
   
   
          
     
  
    
     
 
    
      
     
 
     
    You can’t perform that action at this time.
  
 
    
  
     
    
      
        
     
 
       
      
     
   
 
    
    
  
 
  
 
     
    
0 commit comments