3434 constant ,
3535)
3636from pytensor .tensor .elemwise import CAReduce , DimShuffle , Elemwise
37- from pytensor .tensor .math import add , exp , mul
37+ from pytensor .tensor .math import Any , add , exp , mul
3838from pytensor .tensor .rewriting .basic import (
3939 alloc_like ,
4040 broadcasted_by ,
@@ -520,6 +520,43 @@ def elemwise_max_operands_fct(node) -> int:
520520 return 1024
521521
522522
class CopyOnWriteDictOfSets:
    """Overlay on a ``dict`` of ``set``s that copies a set only when it is first modified.

    Reads are served from the wrapped mapping ``d`` until a key is written to.
    The first write to a key stores a private copy of its set in ``d_copy``;
    every later read or write of that key uses the private copy. The sets held
    by ``d`` are never modified by the write methods of this class.
    """

    __slots__ = ("d", "d_copy")

    def __init__(self, d: dict[Any, set[Any]]):
        # Wrapped mapping; queried for keys that have not been written to.
        self.d = d
        # Private per-key copies, created lazily on the first write to a key.
        self.d_copy: dict[Any, set[Any]] = {}

    def __getitem__(self, key):
        """Return the current set for ``key``, preferring the private copy.

        Falls back to ``self.d[key]``, so a missing key behaves exactly like
        indexing the wrapped mapping (``KeyError`` for a plain ``dict``).
        """
        try:
            return self.d_copy[key]
        except KeyError:
            return self.d[key]

    def get(self, key, default=frozenset()):
        """Return the current set for ``key``, or ``default`` if the key is unknown.

        The default is an immutable empty ``frozenset`` so that callers can
        iterate over, or test membership in, the result unconditionally.
        """
        try:
            return self.d_copy[key]
        except KeyError:
            pass
        try:
            return self.d[key]
        except KeyError:
            return default

    def _writable(self, key):
        """Return the private set for ``key``, creating it on first use.

        A key that is absent from the wrapped mapping starts from an empty set.
        ``dict.get`` is used so that a wrapped ``defaultdict`` is not asked to
        insert a new entry into itself.
        """
        try:
            return self.d_copy[key]
        except KeyError:
            source = self.d.get(key)
            private = set() if source is None else source.copy()
            self.d_copy[key] = private
            return private

    def remove_from_key(self, key, value):
        """Remove ``value`` from the set stored under ``key``.

        Raises
        ------
        KeyError
            If ``key`` is unknown, or if ``value`` is not in its current set.
        """
        if key not in self.d_copy and key not in self.d:
            raise KeyError(key)
        # The lookup and the mutation are kept apart: ``set.remove`` raises
        # KeyError for a missing value, and that must not be mistaken for
        # "no private copy yet" (which would re-copy the original set and
        # throw away earlier modifications of this key).
        self._writable(key).remove(value)

    def add_to_key(self, key, value):
        """Add ``value`` to the set stored under ``key``.

        A key that is not present yet is created with a fresh set.
        """
        self._writable(key).add(value)
559+
523560class FusionOptimizer (GraphRewriter ):
524561 """Graph optimizer that fuses consecutive Elemwise operations."""
525562
@@ -646,15 +683,10 @@ def variables_depend_on(
646683 subgraph_outputs : dict [Variable , Literal [None ]] = {} # ordered set
647684 unfuseable_clients_subgraph : set [Variable ] = set ()
648685
649- # Shallow cloning of maps so that they can be manipulated in place
650- fuseable_clients_clone = defaultdict (set )
651- fuseable_clients_clone .update (
652- {k : v .copy () for k , v in fuseable_clients .items ()}
653- )
654- unfuseable_clients_clone = defaultdict (set )
655- unfuseable_clients_clone .update (
656- {k : v .copy () for k , v in unfuseable_clients .items ()}
657- )
686+ # If we need to manipulate the maps in place, we'll do a shallow copy later
687+ # For now we query on the original ones
688+ fuseable_clients_clone = CopyOnWriteDictOfSets (fuseable_clients )
689+ unfuseable_clients_clone = CopyOnWriteDictOfSets (unfuseable_clients )
658690
659691 # We now try to expand as much as possible towards the potentially
660692 # fuseable clients and ancestors to detect the largest possible
@@ -684,7 +716,7 @@ def variables_depend_on(
684716 required_unfuseable_inputs = [
685717 inp
686718 for inp in next_node .inputs
687- if next_node in unfuseable_clients_clone .get (inp , () )
719+ if next_node in unfuseable_clients_clone .get (inp )
688720 ]
689721 new_required_unfuseable_inputs = [
690722 inp
@@ -707,7 +739,7 @@ def variables_depend_on(
707739 if not must_backtrack :
708740 implied_unfuseable_clients = {
709741 c
710- for client in unfuseable_clients_clone .get (next_out , () )
742+ for client in unfuseable_clients_clone .get (next_out )
711743 if not isinstance (client .op , Output )
712744 for c in client .outputs
713745 }
@@ -728,13 +760,15 @@ def variables_depend_on(
728760
729761 if must_backtrack :
730762 for inp in next_node .inputs :
731- if (
732- inp .owner in visited_nodes
733- # next_node could have the same input repeated
734- and next_node in fuseable_clients_clone [inp ]
735- ):
736- fuseable_clients_clone [inp ].remove (next_node )
737- unfuseable_clients_clone [inp ].add (next_node )
763+ if inp .owner in visited_nodes :
764+ if next_node not in fuseable_clients_clone [inp ]:
765+ # This can happen when next node has repeated inputs
766+ continue
767+ fuseable_clients_clone .remove_from_key (
768+ inp , next_node
769+ )
770+ unfuseable_clients_clone .add_to_key (inp , next_node )
771+
738772 # This input must become an output of the subgraph,
739773 # because it can't be merged with next_node.
740774 # We will revisit it to make sure this is safe.
@@ -743,8 +777,13 @@ def variables_depend_on(
743777 # need to convert to tuple not to change set size during iteration
744778 for client in tuple (fuseable_clients_clone [next_out ]):
745779 if client in visited_nodes :
746- fuseable_clients_clone [next_out ].remove (client )
747- unfuseable_clients_clone [next_out ].add (client )
780+ fuseable_clients_clone .remove_from_key (
781+ next_out , client
782+ )
783+ unfuseable_clients_clone .add_to_key (
784+ next_out , client
785+ )
786+
748787 # next_out must become an input of the subgraph.
749788 # We will revisit any of its clients currently
750789 # in the subgraph to make sure this is safe.
@@ -787,7 +826,7 @@ def variables_depend_on(
787826 sorted (
788827 (
789828 node
790- for node in fuseable_clients_clone .get (next_out , () )
829+ for node in fuseable_clients_clone .get (next_out )
791830 if node not in visited_nodes
792831 ),
793832 key = toposort_index .get ,
0 commit comments