66from collections import defaultdict , deque
77from collections .abc import Generator , Sequence
88from functools import cache , reduce
9+ from operator import or_
910from typing import Literal
1011from warnings import warn
1112
1516from pytensor .compile .mode import get_target_language
1617from pytensor .configdefaults import config
1718from pytensor .graph import FunctionGraph , Op
18- from pytensor .graph .basic import Apply , Variable , ancestors , io_toposort
19+ from pytensor .graph .basic import Apply , Variable , io_toposort
1920from pytensor .graph .destroyhandler import DestroyHandler , inplace_candidates
2021from pytensor .graph .features import ReplaceValidate
2122from pytensor .graph .fg import Output
@@ -661,16 +662,9 @@ def find_fuseable_subgraph(
661662 visited_nodes : set [Apply ],
662663 fuseable_clients : FUSEABLE_MAPPING ,
663664 unfuseable_clients : UNFUSEABLE_MAPPING ,
665+ ancestors_bitset : dict [Apply , int ],
664666 toposort_index : dict [Apply , int ],
665667 ) -> tuple [list [Variable ], list [Variable ]]:
666- def variables_depend_on (
667- variables , depend_on , stop_search_at = None
668- ) -> bool :
669- return any (
670- a in depend_on
671- for a in ancestors (variables , blockers = stop_search_at )
672- )
673-
674668 for starting_node in toposort_index :
675669 if starting_node in visited_nodes :
676670 continue
@@ -682,7 +676,8 @@ def variables_depend_on(
682676
683677 subgraph_inputs : dict [Variable , Literal [None ]] = {} # ordered set
684678 subgraph_outputs : dict [Variable , Literal [None ]] = {} # ordered set
685- unfuseable_clients_subgraph : set [Variable ] = set ()
679+ subgraph_inputs_ancestors_bitset = 0
680+ unfuseable_clients_subgraph_bitset = 0
686681
687682 # If we need to manipulate the maps in place, we'll do a shallow copy later
688683 # For now we query on the original ones
@@ -714,50 +709,32 @@ def variables_depend_on(
714709 if must_become_output :
715710 subgraph_outputs .pop (next_out , None )
716711
717- required_unfuseable_inputs = [
718- inp
719- for inp in next_node .inputs
720- if next_node in unfuseable_clients_clone .get (inp )
721- ]
722- new_required_unfuseable_inputs = [
723- inp
724- for inp in required_unfuseable_inputs
725- if inp not in subgraph_inputs
726- ]
727-
728- must_backtrack = False
729- if new_required_unfuseable_inputs and subgraph_outputs :
730- # We need to check that any new inputs required by this node
731- # do not depend on other outputs of the current subgraph,
732- # via an unfuseable path.
733- if variables_depend_on (
734- [next_out ],
735- depend_on = unfuseable_clients_subgraph ,
736- stop_search_at = subgraph_outputs ,
737- ):
738- must_backtrack = True
712+ # We need to check that any inputs required by this node
713+ # do not depend on other outputs of the current subgraph,
714+ # via an unfuseable path.
715+ must_backtrack = (
716+ ancestors_bitset [next_node ]
717+ & unfuseable_clients_subgraph_bitset
718+ )
739719
740720 if not must_backtrack :
741- implied_unfuseable_clients = {
742- c
743- for client in unfuseable_clients_clone .get (next_out )
744- if not isinstance (client .op , Output )
745- for c in client .outputs
746- }
747-
748- new_implied_unfuseable_clients = (
749- implied_unfuseable_clients - unfuseable_clients_subgraph
721+ implied_unfuseable_clients_bitset = reduce (
722+ or_ ,
723+ (
724+ 1 << toposort_index [client ]
725+ for client in unfuseable_clients_clone .get (next_out )
726+ if not isinstance (client .op , Output )
727+ ),
728+ 0 ,
750729 )
751730
752- if new_implied_unfuseable_clients and subgraph_inputs :
753- # We need to check that any inputs of the current subgraph
754- # do not depend on other clients of this node,
755- # via an unfuseable path.
756- if variables_depend_on (
757- subgraph_inputs ,
758- depend_on = new_implied_unfuseable_clients ,
759- ):
760- must_backtrack = True
731+ # We need to check that any inputs of the current subgraph
732+ # do not depend on other clients of this node,
733+ # via an unfuseable path.
734+ must_backtrack = (
735+ subgraph_inputs_ancestors_bitset
736+ & implied_unfuseable_clients_bitset
737+ )
761738
762739 if must_backtrack :
763740 for inp in next_node .inputs :
@@ -798,29 +775,24 @@ def variables_depend_on(
798775 # immediate dependency problems. Update subgraph
799776 # mappings as if next_node was part of it.
800777 # Useless inputs will be removed by the useless Composite rewrite
801- for inp in new_required_unfuseable_inputs :
802- subgraph_inputs [inp ] = None
803-
804778 if must_become_output :
805779 subgraph_outputs [next_out ] = None
806- unfuseable_clients_subgraph . update (
807- new_implied_unfuseable_clients
780+ unfuseable_clients_subgraph_bitset |= (
781+ implied_unfuseable_clients_bitset
808782 )
809783
810- # Expand through unvisited fuseable ancestors
811- fuseable_nodes_to_visit .extendleft (
812- sorted (
813- (
814- inp .owner
815- for inp in next_node .inputs
816- if (
817- inp not in required_unfuseable_inputs
818- and inp .owner not in visited_nodes
819- )
820- ),
821- key = toposort_index .get , # type: ignore[arg-type]
822- )
823- )
784+ for inp in sorted (
785+ next_node .inputs ,
786+ key = lambda x : toposort_index .get (x .owner , - 1 ),
787+ ):
788+ if next_node in unfuseable_clients_clone .get (inp , ()):
789 # This input must become a subgraph input, since it is unfuseable with the new node
790+ subgraph_inputs_ancestors_bitset |= (
791+ ancestors_bitset .get (inp .owner , 0 )
792+ )
793+ subgraph_inputs [inp ] = None
794+ elif inp .owner not in visited_nodes :
795+ fuseable_nodes_to_visit .appendleft (inp .owner )
824796
825797 # Expand through unvisited fuseable clients
826798 fuseable_nodes_to_visit .extend (
@@ -857,6 +829,8 @@ def update_fuseable_mappings_after_fg_replace(
857829 visited_nodes : set [Apply ],
858830 fuseable_clients : FUSEABLE_MAPPING ,
859831 unfuseable_clients : UNFUSEABLE_MAPPING ,
832+ toposort_index : dict [Apply , int ],
833+ ancestors_bitset : dict [Apply , int ],
860834 starting_nodes : set [Apply ],
861835 updated_nodes : set [Apply ],
862836 ) -> None :
@@ -867,11 +841,25 @@ def update_fuseable_mappings_after_fg_replace(
867841 dropped_nodes = starting_nodes - updated_nodes
868842
869843 # Remove intermediate Composite nodes from mappings
844+ # And compute the ancestors bitset of the new composite node
845+ # As well as the new toposort index for the new node
846+ new_node_ancestor_bitset = 0
847+ new_node_toposort_index = len (toposort_index )
870848 for dropped_node in dropped_nodes :
871849 (dropped_out ,) = dropped_node .outputs
872850 fuseable_clients .pop (dropped_out , None )
873851 unfuseable_clients .pop (dropped_out , None )
874852 visited_nodes .remove (dropped_node )
853+ # The new composite ancestor bitset is the union
854+ # of the ancestors of all the dropped nodes
855+ new_node_ancestor_bitset |= ancestors_bitset [dropped_node ]
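856+ # Intent: the new composite node takes the order of the latest node absorbed into it. NOTE(review): the index starts at len(toposort_index), which appears to exceed every stored index, so this max() would always keep that initial value - confirm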
857+ new_node_toposort_index = max (
858+ new_node_toposort_index , toposort_index [dropped_node ]
859+ )
860+
861+ ancestors_bitset [new_composite_node ] = new_node_ancestor_bitset
862+ toposort_index [new_composite_node ] = new_node_toposort_index
875863
876864 # Update fuseable information for subgraph inputs
877865 for inp in subgraph_inputs :
@@ -903,12 +891,23 @@ def update_fuseable_mappings_after_fg_replace(
903891 fuseable_clients , unfuseable_clients = initialize_fuseable_mappings (fg = fg )
904892 visited_nodes : set [Apply ] = set ()
905893 toposort_index = {node : i for i , node in enumerate (fgraph .toposort ())}
894+ # For each node, build a bitset of all its ancestors (including the node itself)
895+ # This allows quickly checking whether a node depends on any node in a given set
896+ ancestors_bitset : dict [Apply , int ] = {}
897+ for node , index in toposort_index .items ():
898+ node_ancestor_bitset = 1 << index
899+ for inp in node .inputs :
900+ if (inp_node := inp .owner ) is not None :
901+ node_ancestor_bitset |= ancestors_bitset [inp_node ]
902+ ancestors_bitset [node ] = node_ancestor_bitset
903+
906904 while True :
907905 try :
908906 subgraph_inputs , subgraph_outputs = find_fuseable_subgraph (
909907 visited_nodes = visited_nodes ,
910908 fuseable_clients = fuseable_clients ,
911909 unfuseable_clients = unfuseable_clients ,
910+ ancestors_bitset = ancestors_bitset ,
912911 toposort_index = toposort_index ,
913912 )
914913 except ValueError :
@@ -927,6 +926,8 @@ def update_fuseable_mappings_after_fg_replace(
927926 visited_nodes = visited_nodes ,
928927 fuseable_clients = fuseable_clients ,
929928 unfuseable_clients = unfuseable_clients ,
929+ toposort_index = toposort_index ,
930+ ancestors_bitset = ancestors_bitset ,
930931 starting_nodes = starting_nodes ,
931932 updated_nodes = fg .apply_nodes ,
932933 )
0 commit comments