55from collections import defaultdict , deque
66from collections .abc import Generator , Sequence
77from functools import cache , reduce
8- from typing import TypeVar
8+ from typing import Literal
99from warnings import warn
1010
1111import pytensor .scalar .basic as ps
@@ -566,8 +566,7 @@ def find_next_fuseable_subgraph(
566566 This generator assumes that such a subgraph is replaced by a single
567567 Elemwise Composite before being accessed again in the next iteration.
568568 """
569-
570- FUSEABLE_MAPPING = defaultdict [Variable , list [Apply ]]
569+ FUSEABLE_MAPPING = defaultdict [Variable , set [Apply ]]
571570 UNFUSEABLE_MAPPING = defaultdict [Variable , set [Apply ]]
572571
573572 def initialize_fuseable_mappings (
@@ -589,35 +588,33 @@ def elemwise_scalar_op_has_c_code(node: Apply) -> bool:
589588 # to ensure the rewrite remains deterministic.
590589 # This is not a problem for unfuseable ones, as they can never
591590 # become part of the graph.
592- fuseable_clients : FUSEABLE_MAPPING = defaultdict (list )
591+ fuseable_clients : FUSEABLE_MAPPING = defaultdict (set )
593592 unfuseable_clients : UNFUSEABLE_MAPPING = defaultdict (set )
594593 for out , clients in fg .clients .items ():
595- # Old FunctionGraph nodes remain in the clients dictionary
596- # even after they are removed by rewrites
597- if not clients :
598- continue
599-
600594 out_maybe_fuseable = (
601- out .owner
595+ out .owner is not None
602596 and isinstance (out .owner .op , Elemwise )
603597 # and not isinstance(out.owner.op.scalar_op, ps.Composite)
604598 and len (out .owner .outputs ) == 1
605599 and elemwise_scalar_op_has_c_code (out .owner )
606600 )
607- for client , _ in clients :
608- if (
609- out_maybe_fuseable
610- and isinstance (client .op , Elemwise )
611- # and not isinstance(client.op.scalar_op, ps.Composite)
612- and len (client .outputs ) == 1
613- and out .type .broadcastable
614- == client .outputs [0 ].type .broadcastable
615- and elemwise_scalar_op_has_c_code (client )
616- ):
617- if client not in fuseable_clients [out ]:
618- fuseable_clients [out ].append (client )
619- else :
620- unfuseable_clients [out ].add (client )
601+ if out_maybe_fuseable :
602+ out_bcast = (
603+ out .type .broadcastable if out_maybe_fuseable else None
604+ )
605+ for client , _ in clients :
606+ if (
607+ isinstance (client .op , Elemwise )
608+ # and not isinstance(client.op.scalar_op, ps.Composite)
609+ and len (client .outputs ) == 1
610+ and out_bcast == client .outputs [0 ].type .broadcastable
611+ and elemwise_scalar_op_has_c_code (client )
612+ ):
613+ fuseable_clients [out ].add (client )
614+ else :
615+ unfuseable_clients [out ].add (client )
616+ else :
617+ unfuseable_clients [out ] = {client for client , _ in clients }
621618
622619 return fuseable_clients , unfuseable_clients
623620
@@ -628,16 +625,6 @@ def find_fuseable_subgraph(
628625 unfuseable_clients : UNFUSEABLE_MAPPING ,
629626 toposort_index : dict [Apply , int ],
630627 ) -> tuple [list [Variable ], list [Variable ]]:
631- KT = TypeVar ("KT" )
632- VT = TypeVar ("VT" , list , set )
633-
634- def shallow_clone_defaultdict (
635- d : defaultdict [KT , VT ],
636- ) -> defaultdict [KT , VT ]:
637- new_dict : defaultdict [KT , VT ] = defaultdict (d .default_factory )
638- new_dict .update ({k : v .copy () for k , v in d .items ()})
639- return new_dict
640-
641628 def variables_depend_on (
642629 variables , depend_on , stop_search_at = None
643630 ) -> bool :
@@ -655,17 +642,19 @@ def variables_depend_on(
655642 visited_nodes .add (starting_node )
656643 continue
657644
658- subgraph_inputs : list [Variable ] = []
659- subgraph_outputs : list [Variable ] = []
645+ subgraph_inputs : dict [Variable , Literal [ None ]] = {} # ordered set
646+ subgraph_outputs : dict [Variable , Literal [ None ]] = {} # ordered set
660647 unfuseable_clients_subgraph : set [Variable ] = set ()
661648
662649 # Shallow cloning of maps so that they can be manipulated in place
663- fuseable_clients_temp = shallow_clone_defaultdict (fuseable_clients )
664- unfuseable_clients_clone = shallow_clone_defaultdict (
665- unfuseable_clients
650+ fuseable_clients_clone : FUSEABLE_MAPPING = defaultdict (set )
651+ fuseable_clients_clone .update (
652+ {k : v .copy () for k , v in fuseable_clients .items ()}
653+ )
654+ unfuseable_clients_clone : UNFUSEABLE_MAPPING = defaultdict (set )
655+ unfuseable_clients_clone .update (
656+ {k : v .copy () for k , v in unfuseable_clients .items ()}
666657 )
667-
668- fuseable_nodes_to_visit = deque ([starting_node ])
669658
670659 # We now try to expand as much as possible towards the potentially
671660 # fuseable clients and ancestors to detect the largest possible
@@ -674,6 +663,7 @@ def variables_depend_on(
674663 # some inputs or clients may depend on other nodes of the same
675664 # subgraph via a path that cannot be included in the Composite
676665 # (unfuseable)
666+ fuseable_nodes_to_visit = deque ([starting_node ])
677667 while fuseable_nodes_to_visit :
678668 next_node = fuseable_nodes_to_visit .popleft ()
679669 visited_nodes .add (next_node )
@@ -682,15 +672,14 @@ def variables_depend_on(
682672 # If the output variable of next_node has no fuseable clients
683673 # or has unfuseable clients, then next_node must become an output
684674 # if it is to be fused.
685- must_become_output = (
686- next_out not in fuseable_clients_temp
687- or next_out in unfuseable_clients_clone
688- )
675+ must_become_output = not fuseable_clients_clone .get (
676+ next_out
677+ ) or unfuseable_clients_clone .get (next_out )
689678
690679 # We have backtracked to this node, and it may no longer be a viable output,
691680 # so we remove it and check again as if we had never seen this node
692- if must_become_output and next_out in subgraph_outputs :
693- subgraph_outputs .remove (next_out )
681+ if must_become_output :
682+ subgraph_outputs .pop (next_out , None )
694683
695684 required_unfuseable_inputs = [
696685 inp
@@ -742,18 +731,19 @@ def variables_depend_on(
742731 if (
743732 inp .owner in visited_nodes
744733 # next_node could have the same input repeated
745- and next_node in fuseable_clients_temp [inp ]
734+ and next_node in fuseable_clients_clone [inp ]
746735 ):
747- fuseable_clients_temp [inp ].remove (next_node )
736+ fuseable_clients_clone [inp ].remove (next_node )
748737 unfuseable_clients_clone [inp ].add (next_node )
749738 # This input must become an output of the subgraph,
750739 # because it can't be merged with next_node.
751740 # We will revisit it to make sure this is safe.
752741 fuseable_nodes_to_visit .appendleft (inp .owner )
753742
754- for client in fuseable_clients_temp [next_out ]:
743+ # Iterate over a tuple snapshot, since the set is mutated inside the loop
744+ for client in tuple (fuseable_clients_clone [next_out ]):
755745 if client in visited_nodes :
756- fuseable_clients_temp [next_out ].remove (client )
746+ fuseable_clients_clone [next_out ].remove (client )
757747 unfuseable_clients_clone [next_out ].add (client )
758748 # next_out must become an input of the subgraph.
759749 # We will revisit any of its clients currently
@@ -769,74 +759,72 @@ def variables_depend_on(
769759 # mappings as if next_node was part of it.
770760 # Useless inputs will be removed by the useless Composite rewrite
771761 for inp in new_required_unfuseable_inputs :
772- if inp not in subgraph_inputs :
773- subgraph_inputs .append (inp )
762+ subgraph_inputs [inp ] = None
774763
775764 if must_become_output :
776- subgraph_outputs . append ( next_out )
765+ subgraph_outputs [ next_out ] = None
777766 unfuseable_clients_subgraph .update (
778767 new_implied_unfuseable_clients
779768 )
780769
781770 # Expand through unvisited fuseable ancestors
782- for inp in sorted (
783- (
784- inp
785- for inp in next_node . inputs
786- if (
787- inp not in required_unfuseable_inputs
788- and inp . owner not in visited_nodes
789- )
790- ),
791- key = lambda inp : toposort_index [ inp . owner ] ,
792- reverse = True ,
793- ):
794- fuseable_nodes_to_visit . appendleft ( inp . owner )
771+ fuseable_nodes_to_visit . extendleft (
772+ sorted (
773+ (
774+ inp . owner
775+ for inp in next_node . inputs
776+ if (
777+ inp not in required_unfuseable_inputs
778+ and inp . owner not in visited_nodes
779+ )
780+ ) ,
781+ key = toposort_index . get , # type: ignore[arg-type]
782+ )
783+ )
795784
796785 # Expand through unvisited fuseable clients
797- for next_node in sorted (
798- (
799- node
800- for node in fuseable_clients_temp .get (next_out , ())
801- if node not in visited_nodes
802- ),
803- key = lambda node : toposort_index [node ],
804- ):
805- fuseable_nodes_to_visit .append (next_node )
786+ fuseable_nodes_to_visit .extend (
787+ sorted (
788+ (
789+ node
790+ for node in fuseable_clients_clone .get (next_out , ())
791+ if node not in visited_nodes
792+ ),
793+ key = toposort_index .get , # type: ignore[arg-type]
794+ )
795+ )
806796
807797 # Don't return if final subgraph is just the original Elemwise
808798 if len (subgraph_outputs ) == 1 and set (
809- subgraph_outputs [ 0 ] .owner .inputs
799+ next ( iter ( subgraph_outputs )) .owner .inputs
810800 ) == set (subgraph_inputs ):
811801 # Update global fuseable mappings
812802 # No input was actually fuseable
813803 for inp in starting_node .inputs :
814- if starting_node in fuseable_clients .get (inp , ()):
815- fuseable_clients [inp ].remove (starting_node )
816- unfuseable_clients [inp ].add (starting_node )
804+ fuseable_clients [inp ].discard (starting_node )
805+ unfuseable_clients [inp ].add (starting_node )
817806 # No client was actually fuseable
818807 unfuseable_clients [starting_out ].update (
819808 fuseable_clients .pop (starting_out , ())
820809 )
821810 continue
822811
823- return subgraph_inputs , subgraph_outputs
812+ return list ( subgraph_inputs ), list ( subgraph_outputs )
824813 raise ValueError
825814
826815 def update_fuseable_mappings_after_fg_replace (
827816 * ,
828- fg : FunctionGraph ,
829817 visited_nodes : set [Apply ],
830818 fuseable_clients : FUSEABLE_MAPPING ,
831819 unfuseable_clients : UNFUSEABLE_MAPPING ,
832820 starting_nodes : set [Apply ],
821+ updated_nodes : set [Apply ],
833822 ) -> None :
834823 # Find new composite node and dropped intermediate nodes
835824 # by comparing the updated apply nodes (`updated_nodes`) with the cached
836825 # original nodes
837- next_nodes = fg .apply_nodes
838- (new_composite_node ,) = next_nodes - starting_nodes
839- dropped_nodes = starting_nodes - next_nodes
826+ (new_composite_node ,) = updated_nodes - starting_nodes
827+ dropped_nodes = starting_nodes - updated_nodes
840828
841829 # Remove intermediate Composite nodes from mappings
842830 for dropped_node in dropped_nodes :
@@ -848,11 +836,11 @@ def update_fuseable_mappings_after_fg_replace(
848836 # Update fuseable information for subgraph inputs
849837 for inp in subgraph_inputs :
850838 if inp in fuseable_clients :
851- new_fuseable_clients = [
839+ new_fuseable_clients = {
852840 client
853841 for client in fuseable_clients [inp ]
854842 if client not in dropped_nodes
855- ]
843+ }
856844 if new_fuseable_clients :
857845 fuseable_clients [inp ] = new_fuseable_clients
858846 else :
@@ -896,11 +884,11 @@ def update_fuseable_mappings_after_fg_replace(
896884 # generator. For large models (as in `TestFusion.test_big_fusion`)
897885 # this can provide huge speedups
898886 update_fuseable_mappings_after_fg_replace (
899- fg = fg ,
900887 visited_nodes = visited_nodes ,
901888 fuseable_clients = fuseable_clients ,
902889 unfuseable_clients = unfuseable_clients ,
903890 starting_nodes = starting_nodes ,
891+ updated_nodes = fg .apply_nodes ,
904892 )
905893
906894 nb_fused = 0
0 commit comments