Skip to content

Commit f8ee91c

Browse files
committed
Copy on write in FusionOptimizer
1 parent bec1f47 commit f8ee91c

File tree

1 file changed

+61
-22
lines changed

1 file changed

+61
-22
lines changed

pytensor/tensor/rewriting/elemwise.py

Lines changed: 61 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
constant,
3535
)
3636
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
37-
from pytensor.tensor.math import add, exp, mul
37+
from pytensor.tensor.math import Any, add, exp, mul
3838
from pytensor.tensor.rewriting.basic import (
3939
alloc_like,
4040
broadcasted_by,
@@ -520,6 +520,43 @@ def elemwise_max_operands_fct(node) -> int:
520520
return 1024
521521

522522

523+
class CopyOnWriteDictOfSets:
    """Lazily-copied view over a mapping from keys to sets.

    Lookups are answered from the wrapped mapping ``d`` until a key is first
    modified through :meth:`add_to_key` or :meth:`remove_from_key`. At that
    point the key's set is copied into ``d_copy``, and every later read or
    write for that key uses the private copy. The two mutating methods never
    modify ``d`` or the sets stored in it.

    Note that ``__getitem__`` and ``get`` hand back the stored set objects
    themselves (not copies), so callers must not mutate what they return.
    """

    __slots__ = ("d", "d_copy")

    def __init__(self, d: dict[Any, set[Any]]):
        # Wrapped mapping; only read from by this class.
        self.d = d
        # Private per-key copies, created on first modification of a key.
        self.d_copy: dict[Any, set[Any]] = {}

    def __getitem__(self, key):
        """Return the set for ``key``, preferring the private copy.

        Falls back to ``self.d[key]``, so a missing key behaves exactly as it
        does on the wrapped mapping (``KeyError`` for a plain ``dict``).
        """
        try:
            return self.d_copy[key]
        except KeyError:
            return self.d[key]

    def get(self, key, default=frozenset()):
        """Return the set for ``key``, or ``default`` if it is in neither layer.

        The default is an empty ``frozenset`` so the result can always be
        iterated or tested for membership without a ``None`` check.
        """
        try:
            return self.d_copy[key]
        except KeyError:
            try:
                return self.d[key]
            except KeyError:
                return default

    def _writable(self, key, *, create=False):
        """Return the private set for ``key``, copying it on first use.

        Only the lookups are guarded here, so a ``KeyError`` raised later by
        an operation on the returned set can never be mistaken for "this key
        has not been copied yet".

        If ``key`` is in neither layer, a new empty private set is created
        when ``create`` is true; otherwise ``KeyError`` is raised.
        """
        try:
            return self.d_copy[key]
        except KeyError:
            pass
        # ``.get`` rather than ``[]`` so that a wrapped ``defaultdict`` is not
        # mutated as a side effect of the lookup.
        original = self.d.get(key)
        if original is None:
            if not create:
                raise KeyError(key)
            owned = set()
        else:
            owned = original.copy()
        self.d_copy[key] = owned
        return owned

    def remove_from_key(self, key, value):
        """Remove ``value`` from the private copy of the set for ``key``.

        Raises ``KeyError`` if ``key`` is unknown or ``value`` is not in the
        set. A failed removal leaves earlier modifications intact.
        """
        self._writable(key).remove(value)

    def add_to_key(self, key, value):
        """Add ``value`` to the private copy of the set for ``key``.

        A key absent from both layers starts from an empty private set.
        """
        self._writable(key, create=True).add(value)
558+
559+
523560
class FusionOptimizer(GraphRewriter):
524561
"""Graph optimizer that fuses consecutive Elemwise operations."""
525562

@@ -646,15 +683,10 @@ def variables_depend_on(
646683
subgraph_outputs: dict[Variable, Literal[None]] = {} # ordered set
647684
unfuseable_clients_subgraph: set[Variable] = set()
648685

649-
# Shallow cloning of maps so that they can be manipulated in place
650-
fuseable_clients_clone = defaultdict(set)
651-
fuseable_clients_clone.update(
652-
{k: v.copy() for k, v in fuseable_clients.items()}
653-
)
654-
unfuseable_clients_clone = defaultdict(set)
655-
unfuseable_clients_clone.update(
656-
{k: v.copy() for k, v in unfuseable_clients.items()}
657-
)
686+
# If we need to manipulate the maps in place, we'll do a shallow copy later
687+
# For now we query on the original ones
688+
fuseable_clients_clone = CopyOnWriteDictOfSets(fuseable_clients)
689+
unfuseable_clients_clone = CopyOnWriteDictOfSets(unfuseable_clients)
658690

659691
# We now try to expand as much as possible towards the potentially
660692
# fuseable clients and ancestors to detect the largest possible
@@ -684,7 +716,7 @@ def variables_depend_on(
684716
required_unfuseable_inputs = [
685717
inp
686718
for inp in next_node.inputs
687-
if next_node in unfuseable_clients_clone.get(inp, ())
719+
if next_node in unfuseable_clients_clone.get(inp)
688720
]
689721
new_required_unfuseable_inputs = [
690722
inp
@@ -707,7 +739,7 @@ def variables_depend_on(
707739
if not must_backtrack:
708740
implied_unfuseable_clients = {
709741
c
710-
for client in unfuseable_clients_clone.get(next_out, ())
742+
for client in unfuseable_clients_clone.get(next_out)
711743
if not isinstance(client.op, Output)
712744
for c in client.outputs
713745
}
@@ -728,13 +760,15 @@ def variables_depend_on(
728760

729761
if must_backtrack:
730762
for inp in next_node.inputs:
731-
if (
732-
inp.owner in visited_nodes
733-
# next_node could have the same input repeated
734-
and next_node in fuseable_clients_clone[inp]
735-
):
736-
fuseable_clients_clone[inp].remove(next_node)
737-
unfuseable_clients_clone[inp].add(next_node)
763+
if inp.owner in visited_nodes:
764+
if next_node not in fuseable_clients_clone[inp]:
765+
# This can happen when next node has repeated inputs
766+
continue
767+
fuseable_clients_clone.remove_from_key(
768+
inp, next_node
769+
)
770+
unfuseable_clients_clone.add_to_key(inp, next_node)
771+
738772
# This input must become an output of the subgraph,
739773
# because it can't be merged with next_node.
740774
# We will revisit it to make sure this is safe.
@@ -743,8 +777,13 @@ def variables_depend_on(
743777
# need to convert to tuple not to change set size during iteration
744778
for client in tuple(fuseable_clients_clone[next_out]):
745779
if client in visited_nodes:
746-
fuseable_clients_clone[next_out].remove(client)
747-
unfuseable_clients_clone[next_out].add(client)
780+
fuseable_clients_clone.remove_from_key(
781+
next_out, client
782+
)
783+
unfuseable_clients_clone.add_to_key(
784+
next_out, client
785+
)
786+
748787
# next_out must become an input of the subgraph.
749788
# We will revisit any of its clients currently
750789
# in the subgraph to make sure this is safe.
@@ -787,7 +826,7 @@ def variables_depend_on(
787826
sorted(
788827
(
789828
node
790-
for node in fuseable_clients_clone.get(next_out, ())
829+
for node in fuseable_clients_clone.get(next_out)
791830
if node not in visited_nodes
792831
),
793832
key=toposort_index.get,

0 commit comments

Comments
 (0)