22import itertools
33import operator
44import sys
5+ import typing
56from collections import defaultdict , deque
67from collections .abc import Generator , Sequence
78from functools import cache , reduce
@@ -520,6 +521,43 @@ def elemwise_max_operands_fct(node) -> int:
520521 return 1024
521522
522523
524+ class CopyOnWriteDictOfSets :
525+ __slots__ = ("d" , "d_copy" )
526+
527+ def __init__ (self , d : dict [typing .Any , set ]):
528+ self .d = d
529+ self .d_copy : dict [typing .Any , set ] = {}
530+
531+ def __getitem__ (self , key ):
532+ try :
533+ return self .d_copy [key ]
534+ except KeyError :
535+ return self .d [key ]
536+
537+ def get (self , key , default = frozenset ()):
538+ try :
539+ return self .d_copy [key ]
540+ except KeyError :
541+ try :
542+ return self .d [key ]
543+ except KeyError :
544+ return default
545+
546+ def remove_from_key (self , key , value ):
547+ try :
548+ self .d_copy [key ].remove (value )
549+ except KeyError :
550+ self .d_copy [key ] = copied_value = self .d [key ].copy ()
551+ copied_value .remove (value )
552+
553+ def add_to_key (self , key , value ):
554+ try :
555+ self .d_copy [key ].add (value )
556+ except KeyError :
557+ self .d_copy [key ] = copied_value = self .d [key ].copy ()
558+ copied_value .add (value )
559+
560+
523561class FusionOptimizer (GraphRewriter ):
524562 """Graph optimizer that fuses consecutive Elemwise operations."""
525563
@@ -646,15 +684,10 @@ def variables_depend_on(
646684 subgraph_outputs : dict [Variable , Literal [None ]] = {} # ordered set
647685 unfuseable_clients_subgraph : set [Variable ] = set ()
648686
649- # Shallow cloning of maps so that they can be manipulated in place
650- fuseable_clients_clone : FUSEABLE_MAPPING = defaultdict (set )
651- fuseable_clients_clone .update (
652- {k : v .copy () for k , v in fuseable_clients .items ()}
653- )
654- unfuseable_clients_clone : UNFUSEABLE_MAPPING = defaultdict (set )
655- unfuseable_clients_clone .update (
656- {k : v .copy () for k , v in unfuseable_clients .items ()}
657- )
687+ # If we need to manipulate the maps in place, we'll do a shallow copy later
688+ # For now we query on the original ones
689+ fuseable_clients_clone = CopyOnWriteDictOfSets (fuseable_clients )
690+ unfuseable_clients_clone = CopyOnWriteDictOfSets (unfuseable_clients )
658691
659692 # We now try to expand as much as possible towards the potentially
660693 # fuseable clients and ancestors to detect the largest possible
@@ -684,7 +717,7 @@ def variables_depend_on(
684717 required_unfuseable_inputs = [
685718 inp
686719 for inp in next_node .inputs
687- if next_node in unfuseable_clients_clone .get (inp , () )
720+ if next_node in unfuseable_clients_clone .get (inp )
688721 ]
689722 new_required_unfuseable_inputs = [
690723 inp
@@ -707,7 +740,7 @@ def variables_depend_on(
707740 if not must_backtrack :
708741 implied_unfuseable_clients = {
709742 c
710- for client in unfuseable_clients_clone .get (next_out , () )
743+ for client in unfuseable_clients_clone .get (next_out )
711744 if not isinstance (client .op , Output )
712745 for c in client .outputs
713746 }
@@ -728,13 +761,15 @@ def variables_depend_on(
728761
729762 if must_backtrack :
730763 for inp in next_node .inputs :
731- if (
732- inp .owner in visited_nodes
733- # next_node could have the same input repeated
734- and next_node in fuseable_clients_clone [inp ]
735- ):
736- fuseable_clients_clone [inp ].remove (next_node )
737- unfuseable_clients_clone [inp ].add (next_node )
764+ if inp .owner in visited_nodes :
765+ if next_node not in fuseable_clients_clone [inp ]:
766+ # This can happen when next node has repeated inputs
767+ continue
768+ fuseable_clients_clone .remove_from_key (
769+ inp , next_node
770+ )
771+ unfuseable_clients_clone .add_to_key (inp , next_node )
772+
738773 # This input must become an output of the subgraph,
739774 # because it can't be merged with next_node.
740775 # We will revisit it to make sure this is safe.
@@ -743,8 +778,13 @@ def variables_depend_on(
743778 # need to convert to tuple not to change set size during iteration
744779 for client in tuple (fuseable_clients_clone [next_out ]):
745780 if client in visited_nodes :
746- fuseable_clients_clone [next_out ].remove (client )
747- unfuseable_clients_clone [next_out ].add (client )
781+ fuseable_clients_clone .remove_from_key (
782+ next_out , client
783+ )
784+ unfuseable_clients_clone .add_to_key (
785+ next_out , client
786+ )
787+
748788 # next_out must become an input of the subgraph.
749789 # We will revisit any of its clients currently
750790 # in the subgraph to make sure this is safe.
@@ -787,7 +827,7 @@ def variables_depend_on(
787827 sorted (
788828 (
789829 node
790- for node in fuseable_clients_clone .get (next_out , () )
830+ for node in fuseable_clients_clone .get (next_out )
791831 if node not in visited_nodes
792832 ),
793833 key = toposort_index .get , # type: ignore[arg-type]
0 commit comments