Skip to content

Commit 2afc727

Browse files
committed
Use bitset to check ancestors more efficiently
1 parent f8ee91c commit 2afc727

File tree

1 file changed

+70
-69
lines changed

1 file changed

+70
-69
lines changed

pytensor/tensor/rewriting/elemwise.py

Lines changed: 70 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from collections import defaultdict, deque
66
from collections.abc import Generator, Sequence
77
from functools import cache, reduce
8+
from operator import or_
89
from typing import Literal
910
from warnings import warn
1011

@@ -14,7 +15,7 @@
1415
from pytensor.compile.mode import get_target_language
1516
from pytensor.configdefaults import config
1617
from pytensor.graph import FunctionGraph, Op
17-
from pytensor.graph.basic import Apply, Variable, ancestors, io_toposort
18+
from pytensor.graph.basic import Apply, Variable, io_toposort
1819
from pytensor.graph.destroyhandler import DestroyHandler, inplace_candidates
1920
from pytensor.graph.features import ReplaceValidate
2021
from pytensor.graph.fg import Output
@@ -660,16 +661,9 @@ def find_fuseable_subgraph(
660661
visited_nodes: set[Apply],
661662
fuseable_clients: FUSEABLE_MAPPING,
662663
unfuseable_clients: UNFUSEABLE_MAPPING,
664+
ancestors_bitset: dict[Apply, int],
663665
toposort_index: dict[Apply, int],
664666
) -> tuple[list[Variable], list[Variable]]:
665-
def variables_depend_on(
666-
variables, depend_on, stop_search_at=None
667-
) -> bool:
668-
return any(
669-
a in depend_on
670-
for a in ancestors(variables, blockers=stop_search_at)
671-
)
672-
673667
for starting_node in toposort_index:
674668
if starting_node in visited_nodes:
675669
continue
@@ -681,7 +675,8 @@ def variables_depend_on(
681675

682676
subgraph_inputs: dict[Variable, Literal[None]] = {} # ordered set
683677
subgraph_outputs: dict[Variable, Literal[None]] = {} # ordered set
684-
unfuseable_clients_subgraph: set[Variable] = set()
678+
subgraph_inputs_ancestors_bitset = 0
679+
unfuseable_clients_subgraph_bitset = 0
685680

686681
# If we need to manipulate the maps in place, we'll do a shallow copy later
687682
# For now we query on the original ones
@@ -713,50 +708,32 @@ def variables_depend_on(
713708
if must_become_output:
714709
subgraph_outputs.pop(next_out, None)
715710

716-
required_unfuseable_inputs = [
717-
inp
718-
for inp in next_node.inputs
719-
if next_node in unfuseable_clients_clone.get(inp)
720-
]
721-
new_required_unfuseable_inputs = [
722-
inp
723-
for inp in required_unfuseable_inputs
724-
if inp not in subgraph_inputs
725-
]
726-
727-
must_backtrack = False
728-
if new_required_unfuseable_inputs and subgraph_outputs:
729-
# We need to check that any new inputs required by this node
730-
# do not depend on other outputs of the current subgraph,
731-
# via an unfuseable path.
732-
if variables_depend_on(
733-
[next_out],
734-
depend_on=unfuseable_clients_subgraph,
735-
stop_search_at=subgraph_outputs,
736-
):
737-
must_backtrack = True
711+
# We need to check that any inputs required by this node
712+
# do not depend on other outputs of the current subgraph,
713+
# via an unfuseable path.
714+
must_backtrack = (
715+
ancestors_bitset[next_node]
716+
& unfuseable_clients_subgraph_bitset
717+
)
738718

739719
if not must_backtrack:
740-
implied_unfuseable_clients = {
741-
c
742-
for client in unfuseable_clients_clone.get(next_out)
743-
if not isinstance(client.op, Output)
744-
for c in client.outputs
745-
}
746-
747-
new_implied_unfuseable_clients = (
748-
implied_unfuseable_clients - unfuseable_clients_subgraph
720+
implied_unfuseable_clients_bitset = reduce(
721+
or_,
722+
(
723+
1 << toposort_index[client]
724+
for client in unfuseable_clients_clone.get(next_out)
725+
if not isinstance(client.op, Output)
726+
),
727+
0,
749728
)
750729

751-
if new_implied_unfuseable_clients and subgraph_inputs:
752-
# We need to check that any inputs of the current subgraph
753-
# do not depend on other clients of this node,
754-
# via an unfuseable path.
755-
if variables_depend_on(
756-
subgraph_inputs,
757-
depend_on=new_implied_unfuseable_clients,
758-
):
759-
must_backtrack = True
730+
# We need to check that any inputs of the current subgraph
731+
# do not depend on other clients of this node,
732+
# via an unfuseable path.
733+
must_backtrack = (
734+
subgraph_inputs_ancestors_bitset
735+
& implied_unfuseable_clients_bitset
736+
)
760737

761738
if must_backtrack:
762739
for inp in next_node.inputs:
@@ -797,29 +774,24 @@ def variables_depend_on(
797774
# immediate dependency problems. Update subgraph
798775
# mappings as if next_node were part of it.
799776
# Useless inputs will be removed by the useless Composite rewrite
800-
for inp in new_required_unfuseable_inputs:
801-
subgraph_inputs[inp] = None
802-
803777
if must_become_output:
804778
subgraph_outputs[next_out] = None
805-
unfuseable_clients_subgraph.update(
806-
new_implied_unfuseable_clients
779+
unfuseable_clients_subgraph_bitset |= (
780+
implied_unfuseable_clients_bitset
807781
)
808782

809-
# Expand through unvisited fuseable ancestors
810-
fuseable_nodes_to_visit.extendleft(
811-
sorted(
812-
(
813-
inp.owner
814-
for inp in next_node.inputs
815-
if (
816-
inp not in required_unfuseable_inputs
817-
and inp.owner not in visited_nodes
818-
)
819-
),
820-
key=toposort_index.get,
821-
)
822-
)
783+
for inp in sorted(
784+
next_node.inputs,
785+
key=lambda x: toposort_index.get(x.owner, -1),
786+
):
787+
if next_node in unfuseable_clients_clone.get(inp, ()):
788+
# This input must become an input of the subgraph, since it is unfuseable with the new node
789+
subgraph_inputs_ancestors_bitset |= (
790+
ancestors_bitset.get(inp.owner, 0)
791+
)
792+
subgraph_inputs[inp] = None
793+
elif inp.owner not in visited_nodes:
794+
fuseable_nodes_to_visit.appendleft(inp.owner)
823795

824796
# Expand through unvisited fuseable clients
825797
fuseable_nodes_to_visit.extend(
@@ -856,6 +828,8 @@ def update_fuseable_mappings_after_fg_replace(
856828
visited_nodes: set[Apply],
857829
fuseable_clients: FUSEABLE_MAPPING,
858830
unfuseable_clients: UNFUSEABLE_MAPPING,
831+
toposort_index: dict[Apply, int],
832+
ancestors_bitset: dict[Apply, int],
859833
starting_nodes: set[Apply],
860834
updated_nodes: set[Apply],
861835
) -> None:
@@ -866,11 +840,25 @@ def update_fuseable_mappings_after_fg_replace(
866840
dropped_nodes = starting_nodes - updated_nodes
867841

868842
# Remove intermediate Composite nodes from mappings
843+
# And compute the ancestors bitset of the new composite node
844+
# As well as the new toposort index for the new node
845+
new_node_ancestor_bitset = 0
846+
new_node_toposort_index = len(toposort_index)
869847
for dropped_node in dropped_nodes:
870848
(dropped_out,) = dropped_node.outputs
871849
fuseable_clients.pop(dropped_out, None)
872850
unfuseable_clients.pop(dropped_out, None)
873851
visited_nodes.remove(dropped_node)
852+
# The new composite ancestor bitset is the union
853+
# of the ancestors of all the dropped nodes
854+
new_node_ancestor_bitset |= ancestors_bitset[dropped_node]
855+
# The new composite node can have the same order as the latest node that was absorbed into it
856+
new_node_toposort_index = max(
857+
new_node_toposort_index, toposort_index[dropped_node]
858+
)
859+
860+
ancestors_bitset[new_composite_node] = new_node_ancestor_bitset
861+
toposort_index[new_composite_node] = new_node_toposort_index
874862

875863
# Update fuseable information for subgraph inputs
876864
for inp in subgraph_inputs:
@@ -902,12 +890,23 @@ def update_fuseable_mappings_after_fg_replace(
902890
fuseable_clients, unfuseable_clients = initialize_fuseable_mappings(fg=fg)
903891
visited_nodes: set[Apply] = set()
904892
toposort_index = {node: i for i, node in enumerate(fgraph.toposort())}
893+
# Create a bitset for each node of all its ancestors
894+
# This allows to quickly check if a variable depends on a set
895+
ancestors_bitset = {}
896+
for node, index in toposort_index.items():
897+
node_ancestor_bitset = 1 << index
898+
for inp in node.inputs:
899+
if (inp_node := inp.owner) is not None:
900+
node_ancestor_bitset |= ancestors_bitset[inp_node]
901+
ancestors_bitset[node] = node_ancestor_bitset
902+
905903
while True:
906904
try:
907905
subgraph_inputs, subgraph_outputs = find_fuseable_subgraph(
908906
visited_nodes=visited_nodes,
909907
fuseable_clients=fuseable_clients,
910908
unfuseable_clients=unfuseable_clients,
909+
ancestors_bitset=ancestors_bitset,
911910
toposort_index=toposort_index,
912911
)
913912
except ValueError:
@@ -926,6 +925,8 @@ def update_fuseable_mappings_after_fg_replace(
926925
visited_nodes=visited_nodes,
927926
fuseable_clients=fuseable_clients,
928927
unfuseable_clients=unfuseable_clients,
928+
toposort_index=toposort_index,
929+
ancestors_bitset=ancestors_bitset,
929930
starting_nodes=starting_nodes,
930931
updated_nodes=fg.apply_nodes,
931932
)

0 commit comments

Comments
 (0)