55from collections import defaultdict , deque
66from collections .abc import Generator , Sequence
77from functools import cache , reduce
8+ from operator import or_
89from typing import Literal
910from warnings import warn
1011
1415from pytensor .compile .mode import get_target_language
1516from pytensor .configdefaults import config
1617from pytensor .graph import FunctionGraph , Op
17- from pytensor .graph .basic import Apply , Variable , ancestors , io_toposort
18+ from pytensor .graph .basic import Apply , Variable , io_toposort
1819from pytensor .graph .destroyhandler import DestroyHandler , inplace_candidates
1920from pytensor .graph .features import ReplaceValidate
2021from pytensor .graph .fg import Output
@@ -660,16 +661,9 @@ def find_fuseable_subgraph(
660661 visited_nodes : set [Apply ],
661662 fuseable_clients : FUSEABLE_MAPPING ,
662663 unfuseable_clients : UNFUSEABLE_MAPPING ,
664+ ancestors_bitset : dict [Apply , int ],
663665 toposort_index : dict [Apply , int ],
664666 ) -> tuple [list [Variable ], list [Variable ]]:
665- def variables_depend_on (
666- variables , depend_on , stop_search_at = None
667- ) -> bool :
668- return any (
669- a in depend_on
670- for a in ancestors (variables , blockers = stop_search_at )
671- )
672-
673667 for starting_node in toposort_index :
674668 if starting_node in visited_nodes :
675669 continue
@@ -681,7 +675,8 @@ def variables_depend_on(
681675
682676 subgraph_inputs : dict [Variable , Literal [None ]] = {} # ordered set
683677 subgraph_outputs : dict [Variable , Literal [None ]] = {} # ordered set
684- unfuseable_clients_subgraph : set [Variable ] = set ()
678+ subgraph_inputs_ancestors_bitset = 0
679+ unfuseable_clients_subgraph_bitset = 0
685680
686681 # If we need to manipulate the maps in place, we'll do a shallow copy later
687682 # For now we query on the original ones
@@ -713,50 +708,32 @@ def variables_depend_on(
713708 if must_become_output :
714709 subgraph_outputs .pop (next_out , None )
715710
716- required_unfuseable_inputs = [
717- inp
718- for inp in next_node .inputs
719- if next_node in unfuseable_clients_clone .get (inp )
720- ]
721- new_required_unfuseable_inputs = [
722- inp
723- for inp in required_unfuseable_inputs
724- if inp not in subgraph_inputs
725- ]
726-
727- must_backtrack = False
728- if new_required_unfuseable_inputs and subgraph_outputs :
729- # We need to check that any new inputs required by this node
730- # do not depend on other outputs of the current subgraph,
731- # via an unfuseable path.
732- if variables_depend_on (
733- [next_out ],
734- depend_on = unfuseable_clients_subgraph ,
735- stop_search_at = subgraph_outputs ,
736- ):
737- must_backtrack = True
711+ # We need to check that any inputs required by this node
712+ # do not depend on other outputs of the current subgraph,
713+ # via an unfuseable path.
714+ must_backtrack = (
715+ ancestors_bitset [next_node ]
716+ & unfuseable_clients_subgraph_bitset
717+ )
738718
739719 if not must_backtrack :
740- implied_unfuseable_clients = {
741- c
742- for client in unfuseable_clients_clone .get (next_out )
743- if not isinstance (client .op , Output )
744- for c in client .outputs
745- }
746-
747- new_implied_unfuseable_clients = (
748- implied_unfuseable_clients - unfuseable_clients_subgraph
720+ implied_unfuseable_clients_bitset = reduce (
721+ or_ ,
722+ (
723+ 1 << toposort_index [client ]
724+ for client in unfuseable_clients_clone .get (next_out )
725+ if not isinstance (client .op , Output )
726+ ),
727+ 0 ,
749728 )
750729
751- if new_implied_unfuseable_clients and subgraph_inputs :
752- # We need to check that any inputs of the current subgraph
753- # do not depend on other clients of this node,
754- # via an unfuseable path.
755- if variables_depend_on (
756- subgraph_inputs ,
757- depend_on = new_implied_unfuseable_clients ,
758- ):
759- must_backtrack = True
730+ # We need to check that any inputs of the current subgraph
731+ # do not depend on other clients of this node,
732+ # via an unfuseable path.
733+ must_backtrack = (
734+ subgraph_inputs_ancestors_bitset
735+ & implied_unfuseable_clients_bitset
736+ )
760737
761738 if must_backtrack :
762739 for inp in next_node .inputs :
@@ -797,29 +774,24 @@ def variables_depend_on(
797774 # immediate dependency problems. Update subgraph
798775 # mappings as if next_node was part of it.
799776 # Useless inputs will be removed by the useless Composite rewrite
800- for inp in new_required_unfuseable_inputs :
801- subgraph_inputs [inp ] = None
802-
803777 if must_become_output :
804778 subgraph_outputs [next_out ] = None
805- unfuseable_clients_subgraph . update (
806- new_implied_unfuseable_clients
779+ unfuseable_clients_subgraph_bitset |= (
780+ implied_unfuseable_clients_bitset
807781 )
808782
809- # Expand through unvisited fuseable ancestors
810- fuseable_nodes_to_visit .extendleft (
811- sorted (
812- (
813- inp .owner
814- for inp in next_node .inputs
815- if (
816- inp not in required_unfuseable_inputs
817- and inp .owner not in visited_nodes
818- )
819- ),
820- key = toposort_index .get ,
821- )
822- )
783+ for inp in sorted (
784+ next_node .inputs ,
785+ key = lambda x : toposort_index .get (x .owner , - 1 ),
786+ ):
787+ if next_node in unfuseable_clients_clone .get (inp , ()):
788+ # This input must become an input of the subgraph, since it is unfuseable with the new node
789+ subgraph_inputs_ancestors_bitset |= (
790+ ancestors_bitset .get (inp .owner , 0 )
791+ )
792+ subgraph_inputs [inp ] = None
793+ elif inp .owner not in visited_nodes :
794+ fuseable_nodes_to_visit .appendleft (inp .owner )
823795
824796 # Expand through unvisited fuseable clients
825797 fuseable_nodes_to_visit .extend (
@@ -856,6 +828,8 @@ def update_fuseable_mappings_after_fg_replace(
856828 visited_nodes : set [Apply ],
857829 fuseable_clients : FUSEABLE_MAPPING ,
858830 unfuseable_clients : UNFUSEABLE_MAPPING ,
831+ toposort_index : dict [Apply , int ],
832+ ancestors_bitset : dict [Apply , int ],
859833 starting_nodes : set [Apply ],
860834 updated_nodes : set [Apply ],
861835 ) -> None :
@@ -866,11 +840,25 @@ def update_fuseable_mappings_after_fg_replace(
866840 dropped_nodes = starting_nodes - updated_nodes
867841
868842 # Remove intermediate Composite nodes from mappings
843+ # And compute the ancestors bitset of the new composite node
844+ # As well as the new toposort index for the new node
845+ new_node_ancestor_bitset = 0
846+ new_node_toposort_index = len (toposort_index )
869847 for dropped_node in dropped_nodes :
870848 (dropped_out ,) = dropped_node .outputs
871849 fuseable_clients .pop (dropped_out , None )
872850 unfuseable_clients .pop (dropped_out , None )
873851 visited_nodes .remove (dropped_node )
852+ # The new composite ancestor bitset is the union
853+ # of the ancestors of all the dropped nodes
854+ new_node_ancestor_bitset |= ancestors_bitset [dropped_node ]
855+ # The new composite node can have the same order as the latest node that was absorbed into it
856+ new_node_toposort_index = max (
857+ new_node_toposort_index , toposort_index [dropped_node ]
858+ )
859+
860+ ancestors_bitset [new_composite_node ] = new_node_ancestor_bitset
861+ toposort_index [new_composite_node ] = new_node_toposort_index
874862
875863 # Update fuseable information for subgraph inputs
876864 for inp in subgraph_inputs :
@@ -902,12 +890,23 @@ def update_fuseable_mappings_after_fg_replace(
902890 fuseable_clients , unfuseable_clients = initialize_fuseable_mappings (fg = fg )
903891 visited_nodes : set [Apply ] = set ()
904892 toposort_index = {node : i for i , node in enumerate (fgraph .toposort ())}
893+ # Create a bitset for each node of all its ancestors
894+ # This makes it quick to check whether a node depends on any node in a given set
895+ ancestors_bitset = {}
896+ for node , index in toposort_index .items ():
897+ node_ancestor_bitset = 1 << index
898+ for inp in node .inputs :
899+ if (inp_node := inp .owner ) is not None :
900+ node_ancestor_bitset |= ancestors_bitset [inp_node ]
901+ ancestors_bitset [node ] = node_ancestor_bitset
902+
905903 while True :
906904 try :
907905 subgraph_inputs , subgraph_outputs = find_fuseable_subgraph (
908906 visited_nodes = visited_nodes ,
909907 fuseable_clients = fuseable_clients ,
910908 unfuseable_clients = unfuseable_clients ,
909+ ancestors_bitset = ancestors_bitset ,
911910 toposort_index = toposort_index ,
912911 )
913912 except ValueError :
@@ -926,6 +925,8 @@ def update_fuseable_mappings_after_fg_replace(
926925 visited_nodes = visited_nodes ,
927926 fuseable_clients = fuseable_clients ,
928927 unfuseable_clients = unfuseable_clients ,
928+ toposort_index = toposort_index ,
929+ ancestors_bitset = ancestors_bitset ,
929930 starting_nodes = starting_nodes ,
930931 updated_nodes = fg .apply_nodes ,
931932 )
0 commit comments