Skip to content

Commit 2707808

Browse files
committed
Do not recompute toposort in every iteration of FusionOptimizer
It's not really needed as we never expand on the new nodes
1 parent c3cdb83 commit 2707808

File tree

1 file changed

+8
-7
lines changed

1 file changed

+8
-7
lines changed

pytensor/tensor/rewriting/elemwise.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -623,10 +623,10 @@ def elemwise_scalar_op_has_c_code(node: Apply) -> bool:
623623

624624
def find_fuseable_subgraph(
625625
*,
626-
fg: FunctionGraph,
627626
visited_nodes: set[Apply],
628627
fuseable_clients: FUSEABLE_MAPPING,
629628
unfuseable_clients: UNFUSEABLE_MAPPING,
629+
toposort_index: dict[Apply, int],
630630
) -> tuple[list[Variable], list[Variable]]:
631631
KT = TypeVar("KT")
632632
VT = TypeVar("VT", list, set)
@@ -646,8 +646,7 @@ def variables_depend_on(
646646
for a in ancestors(variables, blockers=stop_search_at)
647647
)
648648

649-
toposort = fg.toposort()
650-
for starting_node in toposort:
649+
for starting_node in toposort_index:
651650
if starting_node in visited_nodes:
652651
continue
653652

@@ -789,7 +788,7 @@ def variables_depend_on(
789788
and inp.owner not in visited_nodes
790789
)
791790
),
792-
key=lambda inp: toposort.index(inp.owner),
791+
key=lambda inp: toposort_index[inp.owner],
793792
reverse=True,
794793
):
795794
fuseable_nodes_to_visit.appendleft(inp.owner)
@@ -801,7 +800,7 @@ def variables_depend_on(
801800
for node in fuseable_clients_temp.get(next_out, ())
802801
if node not in visited_nodes
803802
),
804-
key=lambda node: toposort.index(node),
803+
key=lambda node: toposort_index[node],
805804
):
806805
fuseable_nodes_to_visit.append(next_node)
807806

@@ -875,20 +874,22 @@ def update_fuseable_mappings_after_fg_replace(
875874
# client (those that don't fit into 1))
876875
fuseable_clients, unfuseable_clients = initialize_fuseable_mappings(fg=fg)
877876
visited_nodes: set[Apply] = set()
877+
toposort_index = {node: i for i, node in enumerate(fgraph.toposort())}
878878
while True:
879-
starting_nodes = fg.apply_nodes.copy()
880879
try:
881880
subgraph_inputs, subgraph_outputs = find_fuseable_subgraph(
882-
fg=fg,
883881
visited_nodes=visited_nodes,
884882
fuseable_clients=fuseable_clients,
885883
unfuseable_clients=unfuseable_clients,
884+
toposort_index=toposort_index,
886885
)
887886
except ValueError:
888887
return
889888
else:
890889
# The caller is now expected to update fg in place,
891890
# by replacing the subgraph with a Composite Op
891+
starting_nodes = fg.apply_nodes.copy()
892+
892893
yield subgraph_inputs, subgraph_outputs
893894

894895
# This is where we avoid repeated work by using a stateful

0 commit comments

Comments
 (0)