@@ -623,10 +623,10 @@ def elemwise_scalar_op_has_c_code(node: Apply) -> bool:
623623
624624 def find_fuseable_subgraph (
625625 * ,
626- fg : FunctionGraph ,
627626 visited_nodes : set [Apply ],
628627 fuseable_clients : FUSEABLE_MAPPING ,
629628 unfuseable_clients : UNFUSEABLE_MAPPING ,
629+ toposort_index : dict [Apply , int ],
630630 ) -> tuple [list [Variable ], list [Variable ]]:
631631 KT = TypeVar ("KT" )
632632 VT = TypeVar ("VT" , list , set )
@@ -646,8 +646,7 @@ def variables_depend_on(
646646 for a in ancestors (variables , blockers = stop_search_at )
647647 )
648648
649- toposort = fg .toposort ()
650- for starting_node in toposort :
649+ for starting_node in toposort_index :
651650 if starting_node in visited_nodes :
652651 continue
653652
@@ -789,7 +788,7 @@ def variables_depend_on(
789788 and inp .owner not in visited_nodes
790789 )
791790 ),
792- key = lambda inp : toposort . index ( inp .owner ) ,
791+ key = lambda inp : toposort_index [ inp .owner ] ,
793792 reverse = True ,
794793 ):
795794 fuseable_nodes_to_visit .appendleft (inp .owner )
@@ -801,7 +800,7 @@ def variables_depend_on(
801800 for node in fuseable_clients_temp .get (next_out , ())
802801 if node not in visited_nodes
803802 ),
804- key = lambda node : toposort . index ( node ) ,
803+ key = lambda node : toposort_index [ node ] ,
805804 ):
806805 fuseable_nodes_to_visit .append (next_node )
807806
@@ -875,20 +874,22 @@ def update_fuseable_mappings_after_fg_replace(
875874 # client (those that don't fit into 1))
876875 fuseable_clients , unfuseable_clients = initialize_fuseable_mappings (fg = fg )
877876 visited_nodes : set [Apply ] = set ()
877+ toposort_index = {node : i for i , node in enumerate (fgraph .toposort ())}
878878 while True :
879- starting_nodes = fg .apply_nodes .copy ()
880879 try :
881880 subgraph_inputs , subgraph_outputs = find_fuseable_subgraph (
882- fg = fg ,
883881 visited_nodes = visited_nodes ,
884882 fuseable_clients = fuseable_clients ,
885883 unfuseable_clients = unfuseable_clients ,
884+ toposort_index = toposort_index ,
886885 )
887886 except ValueError :
888887 return
889888 else :
890889 # The caller is now expected to update fg in place,
891890 # by replacing the subgraph with a Composite Op
891+ starting_nodes = fg .apply_nodes .copy ()
892+
892893 yield subgraph_inputs , subgraph_outputs
893894
894895 # This is where we avoid repeated work by using a stateful
0 commit comments