Skip to content

Commit f3ba7bc

Browse files
committed
Cleanup FusionOptimizer code
1 parent 2707808 commit f3ba7bc

File tree

1 file changed

+76
-88
lines changed

1 file changed

+76
-88
lines changed

pytensor/tensor/rewriting/elemwise.py

Lines changed: 76 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from collections import defaultdict, deque
66
from collections.abc import Generator, Sequence
77
from functools import cache, reduce
8-
from typing import TypeVar
8+
from typing import Literal
99
from warnings import warn
1010

1111
import pytensor.scalar.basic as ps
@@ -566,8 +566,7 @@ def find_next_fuseable_subgraph(
566566
This generator assumes that such a subgraph is replaced by a single
567567
Elemwise Composite before being accessed again in the next iteration.
568568
"""
569-
570-
FUSEABLE_MAPPING = defaultdict[Variable, list[Apply]]
569+
FUSEABLE_MAPPING = defaultdict[Variable, set[Apply]]
571570
UNFUSEABLE_MAPPING = defaultdict[Variable, set[Apply]]
572571

573572
def initialize_fuseable_mappings(
@@ -589,35 +588,33 @@ def elemwise_scalar_op_has_c_code(node: Apply) -> bool:
589588
# to ensure the rewrite remains deterministic.
590589
# This is not a problem for unfuseable ones, as they can never
591590
# become part of the graph.
592-
fuseable_clients: FUSEABLE_MAPPING = defaultdict(list)
591+
fuseable_clients: FUSEABLE_MAPPING = defaultdict(set)
593592
unfuseable_clients: UNFUSEABLE_MAPPING = defaultdict(set)
594593
for out, clients in fg.clients.items():
595-
# Old FunctionGraph nodes remain in the clients dictionary
596-
# even after they are removed by rewrites
597-
if not clients:
598-
continue
599-
600594
out_maybe_fuseable = (
601-
out.owner
595+
out.owner is not None
602596
and isinstance(out.owner.op, Elemwise)
603597
# and not isinstance(out.owner.op.scalar_op, ps.Composite)
604598
and len(out.owner.outputs) == 1
605599
and elemwise_scalar_op_has_c_code(out.owner)
606600
)
607-
for client, _ in clients:
608-
if (
609-
out_maybe_fuseable
610-
and isinstance(client.op, Elemwise)
611-
# and not isinstance(client.op.scalar_op, ps.Composite)
612-
and len(client.outputs) == 1
613-
and out.type.broadcastable
614-
== client.outputs[0].type.broadcastable
615-
and elemwise_scalar_op_has_c_code(client)
616-
):
617-
if client not in fuseable_clients[out]:
618-
fuseable_clients[out].append(client)
619-
else:
620-
unfuseable_clients[out].add(client)
601+
if out_maybe_fuseable:
602+
out_bcast = (
603+
out.type.broadcastable if out_maybe_fuseable else None
604+
)
605+
for client, _ in clients:
606+
if (
607+
isinstance(client.op, Elemwise)
608+
# and not isinstance(client.op.scalar_op, ps.Composite)
609+
and len(client.outputs) == 1
610+
and out_bcast == client.outputs[0].type.broadcastable
611+
and elemwise_scalar_op_has_c_code(client)
612+
):
613+
fuseable_clients[out].add(client)
614+
else:
615+
unfuseable_clients[out].add(client)
616+
else:
617+
unfuseable_clients[out] = {client for client, _ in clients}
621618

622619
return fuseable_clients, unfuseable_clients
623620

@@ -628,16 +625,6 @@ def find_fuseable_subgraph(
628625
unfuseable_clients: UNFUSEABLE_MAPPING,
629626
toposort_index: dict[Apply, int],
630627
) -> tuple[list[Variable], list[Variable]]:
631-
KT = TypeVar("KT")
632-
VT = TypeVar("VT", list, set)
633-
634-
def shallow_clone_defaultdict(
635-
d: defaultdict[KT, VT],
636-
) -> defaultdict[KT, VT]:
637-
new_dict: defaultdict[KT, VT] = defaultdict(d.default_factory)
638-
new_dict.update({k: v.copy() for k, v in d.items()})
639-
return new_dict
640-
641628
def variables_depend_on(
642629
variables, depend_on, stop_search_at=None
643630
) -> bool:
@@ -655,17 +642,19 @@ def variables_depend_on(
655642
visited_nodes.add(starting_node)
656643
continue
657644

658-
subgraph_inputs: list[Variable] = []
659-
subgraph_outputs: list[Variable] = []
645+
subgraph_inputs: dict[Variable, Literal[None]] = {} # ordered set
646+
subgraph_outputs: dict[Variable, Literal[None]] = {} # ordered set
660647
unfuseable_clients_subgraph: set[Variable] = set()
661648

662649
# Shallow cloning of maps so that they can be manipulated in place
663-
fuseable_clients_temp = shallow_clone_defaultdict(fuseable_clients)
664-
unfuseable_clients_clone = shallow_clone_defaultdict(
665-
unfuseable_clients
650+
fuseable_clients_clone: FUSEABLE_MAPPING = defaultdict(set)
651+
fuseable_clients_clone.update(
652+
{k: v.copy() for k, v in fuseable_clients.items()}
653+
)
654+
unfuseable_clients_clone: UNFUSEABLE_MAPPING = defaultdict(set)
655+
unfuseable_clients_clone.update(
656+
{k: v.copy() for k, v in unfuseable_clients.items()}
666657
)
667-
668-
fuseable_nodes_to_visit = deque([starting_node])
669658

670659
# We now try to expand as much as possible towards the potentially
671660
# fuseable clients and ancestors to detect the largest possible
@@ -674,6 +663,7 @@ def variables_depend_on(
674663
# some inputs or clients may depend on other nodes of the same
675664
# subgraph via a path that cannot be included in the Composite
676665
# (unfuseable)
666+
fuseable_nodes_to_visit = deque([starting_node])
677667
while fuseable_nodes_to_visit:
678668
next_node = fuseable_nodes_to_visit.popleft()
679669
visited_nodes.add(next_node)
@@ -682,15 +672,14 @@ def variables_depend_on(
682672
# If the output variable of next_node has no fuseable clients
683673
# or has unfuseable clients, then next_node must become an output
684674
# if it is to be fused.
685-
must_become_output = (
686-
next_out not in fuseable_clients_temp
687-
or next_out in unfuseable_clients_clone
688-
)
675+
must_become_output = not fuseable_clients_clone.get(
676+
next_out
677+
) or unfuseable_clients_clone.get(next_out)
689678

690679
# We have backtracked to this node, and it may no longer be a viable output,
691680
# so we remove it and check again as if we had never seen this node
692-
if must_become_output and next_out in subgraph_outputs:
693-
subgraph_outputs.remove(next_out)
681+
if must_become_output:
682+
subgraph_outputs.pop(next_out, None)
694683

695684
required_unfuseable_inputs = [
696685
inp
@@ -742,18 +731,19 @@ def variables_depend_on(
742731
if (
743732
inp.owner in visited_nodes
744733
# next_node could have the same input repeated
745-
and next_node in fuseable_clients_temp[inp]
734+
and next_node in fuseable_clients_clone[inp]
746735
):
747-
fuseable_clients_temp[inp].remove(next_node)
736+
fuseable_clients_clone[inp].remove(next_node)
748737
unfuseable_clients_clone[inp].add(next_node)
749738
# This input must become an output of the subgraph,
750739
# because it can't be merged with next_node.
751740
# We will revisit it to make sure this is safe.
752741
fuseable_nodes_to_visit.appendleft(inp.owner)
753742

754-
for client in fuseable_clients_temp[next_out]:
743+
# Iterate over a tuple copy so the set does not change size during iteration
744+
for client in tuple(fuseable_clients_clone[next_out]):
755745
if client in visited_nodes:
756-
fuseable_clients_temp[next_out].remove(client)
746+
fuseable_clients_clone[next_out].remove(client)
757747
unfuseable_clients_clone[next_out].add(client)
758748
# next_out must become an input of the subgraph.
759749
# We will revisit any of its clients currently
@@ -769,74 +759,72 @@ def variables_depend_on(
769759
# mappings as if next_node was part of it.
770760
# Useless inputs will be removed by the useless Composite rewrite
771761
for inp in new_required_unfuseable_inputs:
772-
if inp not in subgraph_inputs:
773-
subgraph_inputs.append(inp)
762+
subgraph_inputs[inp] = None
774763

775764
if must_become_output:
776-
subgraph_outputs.append(next_out)
765+
subgraph_outputs[next_out] = None
777766
unfuseable_clients_subgraph.update(
778767
new_implied_unfuseable_clients
779768
)
780769

781770
# Expand through unvisited fuseable ancestors
782-
for inp in sorted(
783-
(
784-
inp
785-
for inp in next_node.inputs
786-
if (
787-
inp not in required_unfuseable_inputs
788-
and inp.owner not in visited_nodes
789-
)
790-
),
791-
key=lambda inp: toposort_index[inp.owner],
792-
reverse=True,
793-
):
794-
fuseable_nodes_to_visit.appendleft(inp.owner)
771+
fuseable_nodes_to_visit.extendleft(
772+
sorted(
773+
(
774+
inp.owner
775+
for inp in next_node.inputs
776+
if (
777+
inp not in required_unfuseable_inputs
778+
and inp.owner not in visited_nodes
779+
)
780+
),
781+
key=toposort_index.get, # type: ignore[arg-type]
782+
)
783+
)
795784

796785
# Expand through unvisited fuseable clients
797-
for next_node in sorted(
798-
(
799-
node
800-
for node in fuseable_clients_temp.get(next_out, ())
801-
if node not in visited_nodes
802-
),
803-
key=lambda node: toposort_index[node],
804-
):
805-
fuseable_nodes_to_visit.append(next_node)
786+
fuseable_nodes_to_visit.extend(
787+
sorted(
788+
(
789+
node
790+
for node in fuseable_clients_clone.get(next_out, ())
791+
if node not in visited_nodes
792+
),
793+
key=toposort_index.get, # type: ignore[arg-type]
794+
)
795+
)
806796

807797
# Don't return if final subgraph is just the original Elemwise
808798
if len(subgraph_outputs) == 1 and set(
809-
subgraph_outputs[0].owner.inputs
799+
next(iter(subgraph_outputs)).owner.inputs
810800
) == set(subgraph_inputs):
811801
# Update global fuseable mappings
812802
# No input was actually fuseable
813803
for inp in starting_node.inputs:
814-
if starting_node in fuseable_clients.get(inp, ()):
815-
fuseable_clients[inp].remove(starting_node)
816-
unfuseable_clients[inp].add(starting_node)
804+
fuseable_clients[inp].discard(starting_node)
805+
unfuseable_clients[inp].add(starting_node)
817806
# No client was actually fuseable
818807
unfuseable_clients[starting_out].update(
819808
fuseable_clients.pop(starting_out, ())
820809
)
821810
continue
822811

823-
return subgraph_inputs, subgraph_outputs
812+
return list(subgraph_inputs), list(subgraph_outputs)
824813
raise ValueError
825814

826815
def update_fuseable_mappings_after_fg_replace(
827816
*,
828-
fg: FunctionGraph,
829817
visited_nodes: set[Apply],
830818
fuseable_clients: FUSEABLE_MAPPING,
831819
unfuseable_clients: UNFUSEABLE_MAPPING,
832820
starting_nodes: set[Apply],
821+
updated_nodes: set[Apply],
833822
) -> None:
834823
# Find new composite node and dropped intermediate nodes
835824
# by comparing the current fg.apply_nodes with the cached
836825
# original nodes
837-
next_nodes = fg.apply_nodes
838-
(new_composite_node,) = next_nodes - starting_nodes
839-
dropped_nodes = starting_nodes - next_nodes
826+
(new_composite_node,) = updated_nodes - starting_nodes
827+
dropped_nodes = starting_nodes - updated_nodes
840828

841829
# Remove intermediate Composite nodes from mappings
842830
for dropped_node in dropped_nodes:
@@ -848,11 +836,11 @@ def update_fuseable_mappings_after_fg_replace(
848836
# Update fuseable information for subgraph inputs
849837
for inp in subgraph_inputs:
850838
if inp in fuseable_clients:
851-
new_fuseable_clients = [
839+
new_fuseable_clients = {
852840
client
853841
for client in fuseable_clients[inp]
854842
if client not in dropped_nodes
855-
]
843+
}
856844
if new_fuseable_clients:
857845
fuseable_clients[inp] = new_fuseable_clients
858846
else:
@@ -896,11 +884,11 @@ def update_fuseable_mappings_after_fg_replace(
896884
# generator. For large models (as in `TestFusion.test_big_fusion`)
897885
# this can provide huge speedups
898886
update_fuseable_mappings_after_fg_replace(
899-
fg=fg,
900887
visited_nodes=visited_nodes,
901888
fuseable_clients=fuseable_clients,
902889
unfuseable_clients=unfuseable_clients,
903890
starting_nodes=starting_nodes,
891+
updated_nodes=fg.apply_nodes,
904892
)
905893

906894
nb_fused = 0

0 commit comments

Comments
 (0)