From 8d1076405fa4bb59b80d8ab332ad208f8f6c05b5 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Wed, 10 Sep 2025 12:42:14 +0200 Subject: [PATCH 1/4] Show errorcode in `run_mypy` --- scripts/run_mypy.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/scripts/run_mypy.py b/scripts/run_mypy.py index 34cc810647..7b9ac9af59 100644 --- a/scripts/run_mypy.py +++ b/scripts/run_mypy.py @@ -157,7 +157,9 @@ def check_no_unexpected_results(mypy_lines: Iterable[str]): for section, sdf in df.reset_index().groupby(args.groupby): print(f"\n\n[{section}]") for row in sdf.itertuples(): - print(f"{row.file}:{row.line}: {row.type}: {row.message}") + print( + f"{row.file}:{row.line}: {row.type} [{row.errorcode}]: {row.message}" + ) print() else: print( From 08cfb61e20b90aa73387f14819301ba8745a9f6c Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Wed, 27 Aug 2025 15:49:48 +0200 Subject: [PATCH 2/4] Allow unifying with OpPattern --- pytensor/graph/rewriting/basic.py | 151 ++++++++++++++++++------ pytensor/graph/rewriting/kanren.py | 2 +- pytensor/graph/rewriting/unify.py | 177 +++++++++++++++++++++++++++- tests/graph/rewriting/test_basic.py | 37 ++++++ tests/graph/rewriting/test_unify.py | 28 ++++- 5 files changed, 354 insertions(+), 41 deletions(-) diff --git a/pytensor/graph/rewriting/basic.py b/pytensor/graph/rewriting/basic.py index 66d5f844b1..fcfb40ab3e 100644 --- a/pytensor/graph/rewriting/basic.py +++ b/pytensor/graph/rewriting/basic.py @@ -29,7 +29,7 @@ from pytensor.graph.features import AlreadyThere, Feature from pytensor.graph.fg import FunctionGraph, Output from pytensor.graph.op import Op -from pytensor.graph.rewriting.unify import Var, convert_strs_to_vars +from pytensor.graph.rewriting.unify import OpPattern, Var, convert_strs_to_vars from pytensor.graph.utils import AssocList, InconsistencyError from pytensor.misc.ordered_set import OrderedSet from pytensor.utils import flatten @@ -1312,6 +1312,7 @@ class PatternNodeRewriter(NodeRewriter): The input and output patterns have the following syntax: input_pattern ::= (op, , , ...) + input_pattern ::= (OpPattern(type(op), {: , ...}), , , ...) input_pattern ::= dict(pattern = , constraint = ) sub_pattern ::= input_pattern @@ -1325,6 +1326,7 @@ class PatternNodeRewriter(NodeRewriter): output_pattern ::= string output_pattern ::= int output_pattern ::= float + output_pattern ::= callable Each string in the input pattern is a variable that will be set to whatever expression is found in its place. If the same string is @@ -1350,20 +1352,73 @@ class PatternNodeRewriter(NodeRewriter): Examples -------- - PatternNodeRewriter((add, 'x', 'y'), (add, 'y', 'x')) - PatternNodeRewriter((multiply, 'x', 'x'), (square, 'x')) - PatternNodeRewriter((subtract, (add, 'x', 'y'), 'y'), 'x') - PatternNodeRewriter((power, 'x', Constant(double, 2.0)), (square, 'x')) - PatternNodeRewriter((boggle, {'pattern': 'x', - 'constraint': lambda expr: expr.type == scrabble}), - (scrabble, 'x')) + .. 
code-block:: python + from pytensor.graph.rewriting.basic import PatternNodeRewriter + from pytensor.tensor import add, mul, sub, pow, square + + PatternNodeRewriter((add, "x", "y"), (add, "y", "x")) + PatternNodeRewriter((mul, "x", "x"), (square, "x")) + PatternNodeRewriter((sub, (add, "x", "y"), "y"), "x") + PatternNodeRewriter((pow, "x", 2.0), (square, "x")) + PatternNodeRewriter( + (mul, {"pattern": "x", "constraint": lambda expr: expr.ndim == 0}, "y"), + (mul, "y", "x"), + ) + + You can use OpPattern to match a subtype of an Op, with some parameter constraints + You can also specify a callable as the output pattern, which will be called with (fgraph, node, subs_dict) as arguments. + + + .. code-block:: python + + from pytensor.graph.rewriting.basic import PatternNodeRewriter + from pytensor.graph.rewriting.unify import OpPattern + from pytensor.tensor.basic import Join + from pytensor.tensor.elemwise import CAReduce, Elemwise + + + def output_fn(fgraph, node, s): + reduce_op = node.op + reduced_a = reduce_op(s["a"]) + reduced_b = reduce_op(s["b"]) + return Elemwise(s["scalar_op"])(reduced_a, reduced_b) + + + PatternNodeRewriter( + ( + OpPattern(CAReduce, scalar_op="scalar_op", axis=None), + (Join(), "join_axis", "a", "b"), + ), + output_fn, + ) + + + If you want to test a string parameter, you must use LiteralString to avoid it being interpreted as a unification variable. + + .. code-block:: python + + + from pytensor.graph.rewriting.basic import PatternNodeRewriter + from pytensor.graph.rewriting.unify import OpPattern, LiteralString + from pytensor.tensor.blockwise import Blockwise + from pytensor.tensor.slinalg import Solve + + PatternNodeRewriter( + ( + OpPattern( + Blockwise, core_op=OpPattern(Solve, assume_a=LiteralString("gen")) + ), + "A", + "b", + ) + ) """ def __init__( self, - in_pattern, - out_pattern, + in_pattern: tuple, + out_pattern: tuple | Callable | str, allow_multiple_clients: bool = False, name: str | None = None, tracks=(), @@ -1378,7 +1433,8 @@ def __init__( in_pattern The input pattern that we want to replace. out_pattern - The replacement pattern. + The replacement pattern. Or a callable that takes (fgraph, node, subs_dict) as inputs, + and returns the replacement variable (or None/False to reject the rewrite). allow_multiple_clients If ``False``, the pattern matching will fail if one of the subpatterns has more than one client. @@ -1407,26 +1463,40 @@ def __init__( self.out_pattern = convert_strs_to_vars(out_pattern, var_map=var_map) self.values_eq_approx = values_eq_approx self.allow_cast = allow_cast - if isinstance(in_pattern, list | tuple): - self.op = self.in_pattern[0] - elif isinstance(in_pattern, dict): - self.op = self.in_pattern["pattern"][0] - else: - raise TypeError( - "The pattern to search for must start with a specific Op instance." 
-            )
         self.allow_multiple_clients = allow_multiple_clients
         if name:
             self.__name__ = name
-        self._tracks = tracks
         self.get_nodes = get_nodes
         if tracks != ():
-            assert get_nodes
+            if not get_nodes:
+                raise ValueError("Custom `tracks` requires `get_nodes` to be provided.")
+            self._tracks = tracks
+        else:
+            if isinstance(in_pattern, list | tuple):
+                op = self.in_pattern[0]
+            elif isinstance(in_pattern, dict):
+                op = self.in_pattern["pattern"][0]
+            else:
+                raise TypeError(
+                    f"The in_pattern must be a sequence or a dict, but got {in_pattern} of type {type(in_pattern)}"
+                )
+            if isinstance(op, Op):
+                self._tracks = [op]
+            elif isinstance(op, type) and issubclass(op, Op):
+                raise ValueError(
+                    f"The in_pattern starts with an Op class {op}, not an instance.\n"
+                    "You can use pytensor.graph.rewriting.unify.OpPattern instead if you want to match instances of a class."
+                )
+            elif isinstance(op, OpPattern):
+                self._tracks = [op.op_type]
+            else:
+                raise ValueError(
+                    f"The in_pattern must start with a specific Op or an OpPattern instance. "
+                    f"Got {op}, with type {type(op)}."
+                )
 
     def tracks(self):
-        if self._tracks != ():
-            return self._tracks
-        return [self.op]
+        return self._tracks
 
     def transform(self, fgraph, node, get_nodes=True):
         """Check if the graph from node corresponds to ``in_pattern``.
 
         If it does, it constructs ``out_pattern`` and performs the replacement.
 
@@ -1447,28 +1517,39 @@ def transform(self, fgraph, node, get_nodes=True):
             # PatternNodeRewriter doesn't support replacing multi-output nodes
             return False
 
-        s = unify(self.in_pattern, node.out)
+        s = unify(self.in_pattern, node.out, {})
         if s is False:
             return False
 
-        ret = reify(self.out_pattern, s)
-
-        if isinstance(ret, ExpressionTuple):
-            ret = ret.evaled_obj
-
-        if self.values_eq_approx:
-            ret.tag.values_eq_approx = self.values_eq_approx
-
         if not self.allow_multiple_clients:
-            input_vars = list(s.values())
+            input_vars = set(s.values())
+            clients = fgraph.clients
             if any(
-                len(fgraph.clients[v]) > 1
+                len(clients[v]) > 1
                 for v in vars_between(input_vars, node.inputs)
                 if v not in input_vars
             ):
                 return False
 
+        if callable(self.out_pattern):
+            # token is the variable name used in the original pattern
+            ret = self.out_pattern(fgraph, node, {k.token: v for k, v in s.items()})
+            if ret is None or ret is False:
+                # The output function is still allowed to reject the rewrite
+                return False
+            if not isinstance(ret, Variable):
+                raise ValueError(
+                    f"The output of the PatternNodeRewriter callable must be a Variable, got {ret} of type {type(ret)}."
+ ) + else: + ret = reify(self.out_pattern, s) + if isinstance(ret, ExpressionTuple): + ret = ret.evaled_obj + + if self.values_eq_approx: + ret.tag.values_eq_approx = self.values_eq_approx + [old_out] = node.outputs if not old_out.type.is_super(ret.type): from pytensor.tensor.type import TensorType diff --git a/pytensor/graph/rewriting/kanren.py b/pytensor/graph/rewriting/kanren.py index e71303a169..46a6a8d510 100644 --- a/pytensor/graph/rewriting/kanren.py +++ b/pytensor/graph/rewriting/kanren.py @@ -86,7 +86,7 @@ def transform(self, fgraph, node): q = var() kanren_results = run(None, q, self.kanren_relation(input_expr, q)) - chosen_res = self.results_filter(kanren_results) + chosen_res = self.results_filter(kanren_results) # type: ignore[arg-type] if chosen_res: if isinstance(chosen_res, list): diff --git a/pytensor/graph/rewriting/unify.py b/pytensor/graph/rewriting/unify.py index e9361d62c2..aa1b2e26a3 100644 --- a/pytensor/graph/rewriting/unify.py +++ b/pytensor/graph/rewriting/unify.py @@ -10,8 +10,11 @@ """ -from collections.abc import Mapping +from collections.abc import Mapping, Sequence +from dataclasses import dataclass from numbers import Number +from types import UnionType +from typing import Any, TypeAlias import numpy as np from cons.core import ConsError, _car, _cdr @@ -254,6 +257,164 @@ def _unify_ConstrainedVar_object(u, v, s): _unify.add((object, ConstrainedVar, Mapping), _unify_ConstrainedVar_object) +@dataclass(frozen=True) +class LiteralString: + value: str + + +OpPatternOpTypeType: TypeAlias = type[Op] | tuple[type[Op], ...] | UnionType + + +@dataclass(unsafe_hash=True) +class OpPattern: + """Class that can be unified with Op instances of a given type (or instance) and parameters. + + Parameters that are not specified in the OpPattern are ignored during unification. + + This is needed because some Ops can be complex to parametrize fully, + and not all parameters are relevant for a given pattern. + + + Examples + -------- + + OpPattern can be used with `PatternNodeRewriter` to define graph rewrites that match Ops with specific parameters. + The example below matches two nested CAReduce Ops with the same `scalar_op`, + the outer with `axis=None` (full reduction) and fuses them into a single CAReduce. + Note, that because we didn't specify it, the axis of the inner CAReduce can be anything. + The same goes for other properties of the Op that are not specified in the OpPattern. + + .. testcode:: + + from pytensor.graph.rewriting.basic import PatternNodeRewriter + from pytensor.graph.rewriting.unify import OpPattern + from pytensor.tensor.basic import Join + from pytensor.tensor.elemwise import CAReduce, Elemwise + + def output_fn(fgraph, node, s): + reduce_op = node.op + reduced_a = reduce_op(s["a"]) + reduced_b = reduce_op(s["b"]) + return Elemwise(s["scalar_op"])(reduced_a, reduced_b) + + + PatternNodeRewriter( + in_pattern=(OpPattern(CAReduce, scalar_op="scalar_op", axis=None), + (OpPattern(CAReduce, scalar_op="scalar_op",), "x")), + out_pattern=(OpPattern(CAReduce, scalar_op="scalar_op", axis=None), "x"), + ) + + + OpPattern can also be used with `unification.unify` to match Ops with specific parameters. + This is used by PatternNodeRewriter but can also be used directly. + + .. 
testcode:: + + from unification import var, unify + from etuples import etuple + + import pytensor.tensor as pt + from pytensor.graph.rewriting.unify import OpPattern + from pytensor.tensor.blockwise import Blockwise + from pytensor.tensor.slinalg import Solve + + A = var("A") + b = var("b") + pattern = etuple( + OpPattern(Blockwise, core_op=OpPattern(Solve, assume_a="gen")), + A, + b, + ) + + A_pt = pt.tensor3("A") + b_pt = pt.tensor3("b") + out1 = pt.linalg.solve(A_pt, b_pt) + out2 = pt.linalg.solve(A_pt, b_pt, assume_a="pos") + + assert unify(pattern, out1) == {A: A_pt, b: b_pt} + assert unify(pattern, out2) is False + + assume_a = var("assume_a") + pattern = etuple( + OpPattern(Blockwise, core_op=OpPattern(Solve, assume_a=assume_a)), + A, + b, + ) + assert unify(pattern, out1) == {A: A_pt, b: b_pt, assume_a: "gen"} + assert unify(pattern, out2) == {A: A_pt, b: b_pt, assume_a: "pos"} + + + """ + + op_type: OpPatternOpTypeType + parameters: tuple[tuple[str, Any]] + + def __init__( + self, + op_type: OpPatternOpTypeType, + parameters: dict[str, Any] | Sequence[tuple[str, Any]] | None = None, + **kwargs, + ): + if kwargs: + if parameters is not None: + raise ValueError( + "Cannot provide both parameters dict and keyword arguments" + ) + parameters = kwargs + if isinstance(parameters, dict): + parameters = tuple(sorted(parameters.items())) + elif isinstance(parameters, list | tuple): + parameters = tuple(sorted(parameters)) + elif parameters is None: + parameters = () + self.op_type = op_type + self.parameters = parameters # type: ignore[assignment] + + def match_op(self, op: Op): + if not isinstance(op, self.op_type): + return False + return self.match_parameters(op) + + def match_parameters(self, op): + # This is used by methods that already check the op_type is satisfied + # Some methods may index on the op_type and know in advance the op is matched + # Also recursive calls to OpPattern.match_parameters do the op check outside to exit early (see below) + for key, param in self.parameters: + if isinstance(param, OpPattern): + # Parameters can itself be other OpPatterns + # We check the op_type to avoid a nested call in cases we can reject early + sub_op = getattr(op, key) + if not isinstance(sub_op, param.op_type): + return False + # Match the pattern of the inner Op + # Skip if there are no parameters + if param.parameters and not param.match_parameters(sub_op): + return False + elif getattr(op, key) != param: + return False + return True + + def __str__(self): + return f"OpPattern({self.op_type}, {', '.join(f'{k}={v}' for k, v in self.parameters)})" + + +def _unify_parametrized_op(v: Op, u: OpPattern, s: Mapping): + if not isinstance(v, u.op_type): + yield False + return + for parameter_key, parameter_pattern in u.parameters: + parameter_value = getattr(v, parameter_key) + new_s = yield _unify(parameter_value, parameter_pattern, s) + if new_s is False: + yield False + return + s = new_s + yield s + + +_unify.add((Op, OpPattern, Mapping), _unify_parametrized_op) + + def convert_strs_to_vars( x: tuple | str | dict, var_map: dict[str, Var] | None = None ) -> ExpressionTuple | Var: @@ -266,11 +427,13 @@ def convert_strs_to_vars( if var_map is None: var_map = {} - def _convert(y): + def _convert(y, op_prop=False): if isinstance(y, str): v = var_map.get(y, var(y)) var_map[y] = v return v + if isinstance(y, LiteralString): + return y.value elif isinstance(y, dict): pattern = y["pattern"] if not isinstance(pattern, str): @@ -282,8 +445,14 @@ def _convert(y): var_map[pattern] = v return v elif 
isinstance(y, tuple): - return etuple(*(_convert(e) for e in y)) - elif isinstance(y, Number | np.ndarray): + return etuple(*(_convert(e, op_prop=op_prop) for e in y)) + elif isinstance(y, OpPattern): + return OpPattern( + y.op_type, + {k: _convert(v, op_prop=True) for k, v in y.parameters}, + ) + elif (not op_prop) and isinstance(y, Number | np.ndarray): + # If we are converting an Op property, we don't want to convert numbers to PyTensor constants from pytensor.tensor import as_tensor_variable return as_tensor_variable(y) diff --git a/tests/graph/rewriting/test_basic.py b/tests/graph/rewriting/test_basic.py index 228c93a8c8..07c518af93 100644 --- a/tests/graph/rewriting/test_basic.py +++ b/tests/graph/rewriting/test_basic.py @@ -18,6 +18,7 @@ pre_constant_merge, pre_greedy_node_rewriter, ) +from pytensor.graph.rewriting.unify import LiteralString, OpPattern from pytensor.raise_op import assert_op from pytensor.tensor.math import Dot, add, dot, exp from pytensor.tensor.rewriting.basic import constant_folding @@ -283,6 +284,42 @@ def test_eq(self): str_g = str(g) assert str_g == "FunctionGraph(Op4(z, y))" + def test_op_pattern(self): + a = MyVariable("a") + e1 = MyOp(name="MyOp(x=1)", x=1)(a) + e2 = MyOp(name="MyOp(x=2)", x=2)(a) + e_hello = MyOp(name="MyOp(x='hello')", x="hello")(a) + op_x3 = MyOp(name="MyOp(x=3)", x=3) + assert not equal_computations([e1], [op_x3(a)]) + assert not equal_computations([e2], [op_x3(a)]) + + rewriter = WalkingPatternNodeRewriter( + (OpPattern(MyOp, x=1), "a"), + "a", + ) + g = FunctionGraph([a], [e1, e2, e1], copy_inputs=False) + rewriter.rewrite(g) + assert equal_computations(g.outputs, [a, e2, a]) + + rewriter = WalkingPatternNodeRewriter( + (OpPattern(MyOp, x="x"), "a"), + lambda fgraph, node, subs: ( + MyOp(name="MyOp(x+=10)", x=subs["x"] + 10)(subs["a"]) + if subs["x"] < 10 + else False + ), + ) + g = FunctionGraph([a], [e1], copy_inputs=False) + rewriter.rewrite(g) + assert equal_computations(g.outputs, [MyOp(name="x=11", x=11)(a)]) + + rewriter = WalkingPatternNodeRewriter( + (OpPattern(MyOp, x=LiteralString("hello")), "a"), "a" + ) + g = FunctionGraph([a], [e1, e_hello], copy_inputs=False) + rewriter.rewrite(g) + assert equal_computations(g.outputs, [e1, a]) + class NoInputOp(Op): __props__ = ("param",) diff --git a/tests/graph/rewriting/test_unify.py b/tests/graph/rewriting/test_unify.py index da430a1587..5ce8d04105 100644 --- a/tests/graph/rewriting/test_unify.py +++ b/tests/graph/rewriting/test_unify.py @@ -11,7 +11,11 @@ import pytensor.tensor as pt from pytensor.graph.basic import Apply, Constant, equal_computations from pytensor.graph.op import Op -from pytensor.graph.rewriting.unify import ConstrainedVar, convert_strs_to_vars +from pytensor.graph.rewriting.unify import ( + ConstrainedVar, + OpPattern, + convert_strs_to_vars, +) from pytensor.tensor.type import TensorType from tests.graph.utils import MyType @@ -348,3 +352,25 @@ def constraint(x): res = convert_strs_to_vars((val,)) assert isinstance(res[0], Constant) assert np.array_equal(res[0].data, val) + + +def test_unify_OpPattern(): + x_pt = MyType()("x_pt") + y_pt = MyType()("y_pt") + out1 = CustomOp(a=1)(x_pt, y_pt) + out2 = CustomOp(a=2)(x_pt, y_pt) + + x = var("x") + y = var("y") + pattern = etuple(OpPattern(CustomOp), x, y) + assert unify(pattern, out1) == {x: x_pt, y: y_pt} + assert unify(pattern, out2) == {x: x_pt, y: y_pt} + + pattern = etuple(OpPattern(CustomOp, a=1), x, y) + assert unify(pattern, out1) == {x: x_pt, y: y_pt} + assert unify(pattern, out2) is False + + a = var("a") 
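+    # `a` is a plain unification variable here, so unify should also capture the Op's `a` parameter below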
+ pattern = etuple(OpPattern(CustomOp, a=a), x, y) + assert unify(pattern, out1) == {x: x_pt, y: y_pt, a: 1} + assert unify(pattern, out2) == {x: x_pt, y: y_pt, a: 2} From 7e1cf637e860de9408aceb98136c4422a8e40951 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 28 Aug 2025 15:25:57 +0200 Subject: [PATCH 3/4] Allow OpPattern in tracks Also avoid repeated checks when an outer rewriter enforces tracks before calling individual node rewriters --- doc/gallery/rewrites/graph_rewrites.ipynb | 6 +- pytensor/graph/rewriting/basic.py | 179 +++++++++++++++++----- pytensor/graph/rewriting/kanren.py | 4 +- pytensor/graph/rewriting/unify.py | 36 +++++ pytensor/tensor/rewriting/math.py | 4 +- 5 files changed, 181 insertions(+), 48 deletions(-) diff --git a/doc/gallery/rewrites/graph_rewrites.ipynb b/doc/gallery/rewrites/graph_rewrites.ipynb index 298e13b95e..5aaf7f383b 100644 --- a/doc/gallery/rewrites/graph_rewrites.ipynb +++ b/doc/gallery/rewrites/graph_rewrites.ipynb @@ -583,7 +583,7 @@ " def tracks(self):\n", " return [pt.log]\n", " \n", - " def transform(self, fgraph, node):\n", + " def transform(self, fgraph, node, enforce_tracks=True):\n", " return local_log1p(node) \n", " \n", " def __str__(self):\n", @@ -669,8 +669,8 @@ "@node_rewriter(tracks=[pt.abs])\n", "def local_useless_abs_exp(fgraph, node):\n", " # Because of the tracks we don't need to check \n", - " # that `node` has a `Sign` Op.\n", - " # We still need to check whether it's input is an `Abs` Op\n", + " # that `node` has a `Abs` Op.\n", + " # We still need to check whether it's input is an `Exp` Op\n", " exp_node = node.inputs[0].owner\n", " if exp_node is None or exp_node.op != pt.exp:\n", " return None\n", diff --git a/pytensor/graph/rewriting/basic.py b/pytensor/graph/rewriting/basic.py index fcfb40ab3e..f76c9770a9 100644 --- a/pytensor/graph/rewriting/basic.py +++ b/pytensor/graph/rewriting/basic.py @@ -141,7 +141,12 @@ def tracks(self) -> Sequence[Op] | None: @abc.abstractmethod def transform( - self, fgraph: FunctionGraph, node: Apply, *args, **kwargs + self, + fgraph: FunctionGraph, + node: Apply, + enforce_tracks: bool = True, + *args, + **kwargs, ) -> TransformOutputType: r"""Rewrite the sub-graph given by `node`. @@ -159,7 +164,9 @@ def transform( A `FunctionGraph` containing `node`. node An `Apply` node to be rewritten. - + enforce_tracks: bool + Whether the transform method should enforce tracks, or it can be assumed the caller already enforced them in a pre-filter stage. + See `node_rewriter` tracks argument for more details. 
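+            Rewriters that already pre-filter candidate nodes by their tracks (for example via `OpToRewriterTracker.get_trackers`) can pass ``enforce_tracks=False`` to skip the redundant check.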
""" raise NotImplementedError() @@ -935,15 +942,48 @@ class FromFunctionNodeRewriter(NodeRewriter): def __init__(self, fn, tracks=None, requirements=()): self.fn = fn self._tracks = tracks - self._tracked_types = ( - tuple(t for t in tracks if isinstance(t, type)) if tracks else () - ) + self._tracked_ops = set() + self._tracked_types = type(None) + self._tracked_op_pattern_types = type(None) + self._tracked_op_patterns: list[OpPattern] = [] + if tracks is not None: + if not tracks: + raise ValueError( + "To specify a general rewrite leave tracks as None instead of an empty container" + ) + for t in tracks: + if isinstance(t, Op): + self._tracked_ops.add(t) + elif isinstance(t, type): + self._tracked_types |= t + elif isinstance(t, OpPattern): + if t.parameters: + self._tracked_op_patterns.append(t) + self._tracked_op_pattern_types |= t.op_type + else: + # An OpPattern without parameters behaves like a regular tracked_type + self._tracked_types |= t + else: + raise TypeError( + "`tracks` must consist of `Op` classes, `Op` instances or `OpPattern` instances. " + f"Got {t} of type {type(t)}" + ) self.requirements = requirements - def transform(self, fgraph, node): - if self._tracks: + def transform(self, fgraph, node, enforce_tracks: bool = True): + if enforce_tracks and self._tracks: + node_op = node.op if not ( - node.op in self._tracks or isinstance(node.op, self._tracked_types) + node_op in self._tracked_ops + or isinstance(node_op, self._tracked_types) + or ( + isinstance(node.op, self._tracked_op_pattern_types) + and any( + t.match_parameters(node_op) + for t in self._tracked_op_patterns + if isinstance(node_op, t.op_type) + ) + ) ): return False @@ -967,7 +1007,7 @@ def print_summary(self, stream=sys.stdout, level=0, depth=-1): def node_rewriter( - tracks: Sequence[Op | type] | None, + tracks: Sequence[Op | type, OpPattern] | None, inplace: bool = False, requirements: tuple[type, ...] | None = (), ): @@ -976,7 +1016,7 @@ def node_rewriter( Parameters ---------- tracks - The `Op` types or instances to which this rewrite applies. + The `Op` type, instances or `OpPattern` to which this rewrite applies. Use ``None`` instead of an empty list to have the rewrite apply to all `Op`\s. inplace @@ -995,14 +1035,16 @@ def decorator(f): if tracks is not None: if len(tracks) == 0: raise ValueError( - "Use `None` instead of an empty list to make an rewrite apply to all nodes." + "Use `None` instead of an empty list to make a rewrite apply to all nodes." ) for t in tracks: if not ( - isinstance(t, Op) or (isinstance(t, type) and issubclass(t, Op)) + isinstance(t, Op | OpPattern) + or (isinstance(t, type) and issubclass(t, Op)) ): raise TypeError( - "`tracks` must consist of `Op` classes or instances." + "`tracks` must consist of `Op` classes, `Op` instances or `OpPattern` instances. 
" + f"Got {t} of type {type(t)}" ) req = requirements if inplace: @@ -1024,47 +1066,93 @@ class OpToRewriterTracker: def __init__(self): self.tracked_instances: dict[Op, list[NodeRewriter]] = defaultdict(list) self.tracked_types: dict[type, list[NodeRewriter]] = defaultdict(list) + self.tracked_pattern_types: dict[type, dict[OpPattern, list[NodeRewriter]]] = ( + defaultdict(lambda: defaultdict(list)) + ) self.untracked_rewrites: list[NodeRewriter] = [] + self._cached_composed_mro = None def add_tracker(self, rw: NodeRewriter): """Add a `NodeRewriter` to be keyed by its `NodeRewriter.tracks` or applied generally.""" + if self._cached_composed_mro is not None: + # We shouldn't actually add_trackers after the first call to get_trackers + # But just to be safe we kill the cache here + self._cached_composed_mro = None + tracks = rw.tracks() if tracks is None: self.untracked_rewrites.append(rw) else: for c in tracks: + if isinstance(c, OpPattern): + if not isinstance(c.op_type, type): + # OpPattern allows anything that you can check with isinstance(op, op_type), + # including tuples or union types. But for OpToRewriterTracker we need a single type. + raise NotImplementedError( + "OpToRewriterTracker requires the outermost `OpPattern.op_type` to be a type. " + f"Got {c.op_type} of type {type(c.op_type)}" + ) + + if c.parameters: + self.tracked_pattern_types[c.op_type][c].append(rw) + else: + # An OpPattern without parameters behaves like a regular tracked_type + self.tracked_types[c.op_type].append(rw) if isinstance(c, type): self.tracked_types[c].append(rw) else: self.tracked_instances[c].append(rw) - def _find_impl(self, cls) -> list[NodeRewriter]: - r"""Returns the `NodeRewriter`\s that apply to `cls` based on inheritance. + @functools.cache + def get_trackers(self, op: Op) -> list[NodeRewriter]: + """Get all the rewrites applicable to an `Op`.""" + + if self._cached_composed_mro is None: + # Cache the mro call on the Op type. We have a small subset of op_types we actually care about + # like Elemwise, Blockwise, and so on, which we don't need to repeatedly investigate + tracked_types = ( + self.tracked_types.keys() | self.tracked_pattern_types.keys() + ) + + @functools.cache + def cached_composed_mro(op_type, tracked_types=tracked_types): + return _compose_mro(op_type, tracked_types) + + self._cached_composed_mro = cached_composed_mro - This based on `functools._find_impl`. 
- """ - mro = _compose_mro(cls, self.tracked_types.keys()) matches = [] - for t in mro: - match = self.tracked_types.get(t, None) - if match: - matches.extend(match) + if self.tracked_types or self.tracked_pattern_types: + # Find matches for type(op) (and their subclasses) using the same approach that functools.singledispatch uses + mro = self._cached_composed_mro(type(op)) + for t in mro: + if (match := self.tracked_types.get(t, None)) is not None: + matches.extend(match) + if ( + potential_matches := self.tracked_pattern_types.get(t, None) + ) is not None: + # We still need to check if the Op parameters match the constraints + matches.extend( + [ + item + for op_pattern, r_list in potential_matches.items() + if op_pattern.match_parameters(op) + for item in r_list + ] + ) + matches.extend(self.tracked_instances.get(op, [])) + matches.extend(self.untracked_rewrites) return matches - @functools.lru_cache - def get_trackers(self, op: Op) -> list[NodeRewriter]: - """Get all the rewrites applicable to `op`.""" - return ( - self._find_impl(type(op)) - + self.tracked_instances.get(op, []) - + self.untracked_rewrites - ) - - def get_rewriters(self): + def get_rewriters(self) -> Iterable[NodeRewriter]: + """Get all the registered rewriters.""" return chain( + chain.from_iterable(self.tracked_types.values()), + chain.from_iterable(self.tracked_instances.values()), chain.from_iterable( - chain(self.tracked_types.values(), self.tracked_instances.values()) + item + for sub_dict in self.tracked_pattern_types.values() + for item in sub_dict.values() ), self.untracked_rewrites, ) @@ -1138,7 +1226,7 @@ def tracks(self): t.extend(at) return t - def transform(self, fgraph, node): + def transform(self, fgraph, node, enforce_tracks=False): if len(self.rewrites) == 0: return @@ -1150,7 +1238,8 @@ def transform(self, fgraph, node): new_repl = None for rewrite in rewrites: rewrite_start = time.perf_counter() - new_repl = rewrite.transform(fgraph, node) + # Tracks are already enforced by `self.tracker.get_trackers` + new_repl = rewrite.transform(fgraph, node, enforce_tracks=False) rewrite_finish = time.perf_counter() if self.profile: self.time_rewrites[rewrite] += rewrite_start - rewrite_finish @@ -1292,8 +1381,8 @@ def __init__(self, op1, op2, transfer_tags=True): def tracks(self): return [self.op1] - def transform(self, fgraph, node): - if node.op != self.op1: + def transform(self, fgraph, node, enforce_tracks=True): + if enforce_tracks and (node.op != self.op1): return False repl = self.op2.make_node(*node.inputs) if self.transfer_tags: @@ -1498,7 +1587,7 @@ def __init__( def tracks(self): return self._tracks - def transform(self, fgraph, node, get_nodes=True): + def transform(self, fgraph, node, enforce_tracks: bool = False, get_nodes=True): """Check if the graph from node corresponds to ``in_pattern``. If it does, it constructs ``out_pattern`` and performs the replacement. @@ -1788,6 +1877,7 @@ def process_node( fgraph: FunctionGraph, node: Apply, node_rewriter: NodeRewriter | None = None, + enforce_tracks: bool = True, ): r"""Apply `node_rewriter` to `node`. @@ -1805,6 +1895,9 @@ def process_node( node_rewriter A `NodeRewriter` instance that may have a better idea for how to compute node's outputs. + enforce_tracks: bool + Whether the transform method should enforce tracks, + or it can be assumed the caller already enforced them in a pre-filter stage. 
         Returns
         -------
@@ -1820,7 +1913,9 @@ def process_node(
         # TODO FIXME: This class's interface is broken
         assert node_rewriter is not None
         try:
-            replacements = node_rewriter.transform(fgraph, node)
+            replacements = node_rewriter.transform(
+                fgraph, node, enforce_tracks=enforce_tracks
+            )
         except Exception as e:
             if self.failure_callback is not None:
                 self.failure_callback(
@@ -1938,7 +2033,8 @@ def importer(node):
                 if node not in fgraph.apply_nodes:
                     continue
                 current_node = node
-                nb += self.process_node(fgraph, node)
+                # This rewriter does not enforce tracks itself
+                nb += self.process_node(fgraph, node, enforce_tracks=True)
             loop_t = time.perf_counter() - t0
         finally:
             self.detach_updater(fgraph, u)
@@ -2279,8 +2375,9 @@ def chin_(node, i, r, new_r, reason):
                 for node_rewriter in self.node_tracker.get_trackers(node.op):
                     nb = change_tracker.nb_imported
                     t_rewrite = time.perf_counter()
+                    # Tracks are already enforced by `self.node_tracker.get_trackers`
                     node_rewriter_change = self.process_node(
-                        fgraph, node, node_rewriter
+                        fgraph, node, node_rewriter, enforce_tracks=False
                     )
                     time_rewriters[node_rewriter] += time.perf_counter() - t_rewrite
                     if not node_rewriter_change:
diff --git a/pytensor/graph/rewriting/kanren.py b/pytensor/graph/rewriting/kanren.py
index 46a6a8d510..8b45d85da8 100644
--- a/pytensor/graph/rewriting/kanren.py
+++ b/pytensor/graph/rewriting/kanren.py
@@ -74,7 +74,7 @@ def results_filter(
         self.node_filter = node_filter
         super().__init__()
 
-    def transform(self, fgraph, node):
+    def transform(self, fgraph, node, enforce_tracks: bool = True):
         if self.node_filter(node) is False:
             return False
 
@@ -92,7 +92,7 @@ def transform(self, fgraph, node):
             if isinstance(chosen_res, list):
                 new_outputs = [eval_if_etuple(v) for v in chosen_res]
             else:
-                new_outputs = [eval_if_etuple(chosen_res)]
+                new_outputs = [eval_if_etuple(chosen_res)]  # type: ignore[unreachable]
 
             return new_outputs
         else:
diff --git a/pytensor/graph/rewriting/unify.py b/pytensor/graph/rewriting/unify.py
index aa1b2e26a3..27c916284a 100644
--- a/pytensor/graph/rewriting/unify.py
+++ b/pytensor/graph/rewriting/unify.py
@@ -278,6 +278,42 @@ class OpPattern:
     Examples
     --------
+    OpPattern can be used in the `tracks` functionality of `node_rewriter` to filter out nodes more flexibly.
+    For Ops that are parametrized by other Ops, it's possible to use nested OpPatterns.
+
+    .. testcode::
+
+        from pytensor.graph.rewriting.basic import node_rewriter
+        from pytensor.graph.rewriting.unify import OpPattern
+        from pytensor.tensor.elemwise import CAReduce
+        from pytensor.tensor.blockwise import Blockwise
+        from pytensor.tensor.slinalg import Solve
+
+        @node_rewriter(tracks=[OpPattern(CAReduce, axis=None)])
+        def local_car_reduce_all_rewriter(fgraph, node):
+            # This will always be true!
+            assert isinstance(node.op, CAReduce) and node.op.axis is None
+            ...
+
+        # Any Blockwise whose core_op is a Solve Op (or subclass) instance
+        @node_rewriter(tracks=[OpPattern(Blockwise, core_op=OpPattern(Solve))])
+        def local_blockwise_solve_triangular_rewriter(fgraph, node):
+            # This will always be true!
+            assert isinstance(node.op, Blockwise) and isinstance(node.op.core_op, Solve)
+            ...
+
+        # Any Blockwise whose core_op is a Solve Op (or subclass) instance with b_ndim==1
+        @node_rewriter(tracks=[OpPattern(Blockwise, core_op=OpPattern(Solve, b_ndim=1))])
+        def local_blockwise_vector_solve_rewriter(fgraph, node):
+            # This will always be true!
+            assert (
+                isinstance(node.op, Blockwise)
+                and isinstance(node.op.core_op, Solve)
+                and node.op.core_op.b_ndim == 1
+            )
+            ... 
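+
+    An OpPattern constructed without any parameters (e.g. ``OpPattern(CAReduce)``)
+    simply tracks every instance of the given Op type, just like passing the class itself to `tracks`.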
+ + OpPattern can be used with `PatternNodeRewriter` to define graph rewrites that match Ops with specific parameters. The example below matches two nested CAReduce Ops with the same `scalar_op`, the outer with `axis=None` (full reduction) and fuses them into a single CAReduce. diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index df2355ca12..28f1f87fff 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -1338,9 +1338,9 @@ def simplify_constants(self, orig_num, orig_denum, out_type=None): return ct + num, denum - def transform(self, fgraph, node): + def transform(self, fgraph, node, enforce_tracks=True): op = node.op - if op not in [self.main, self.inverse, self.reciprocal]: + if enforce_tracks and (op not in {self.main, self.inverse, self.reciprocal}): return False assert len(node.outputs) == 1 From e12fcd404fe2a7c5f6ab5c1107390d19d2afa782 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 28 Aug 2025 18:02:25 +0200 Subject: [PATCH 4/4] Use OpPattern in tracks --- pytensor/scalar/basic.py | 8 +- pytensor/tensor/_linalg/solve/rewriting.py | 3 +- pytensor/tensor/rewriting/basic.py | 62 +++++---- pytensor/tensor/rewriting/blockwise.py | 46 ++++--- pytensor/tensor/rewriting/elemwise.py | 43 ++---- pytensor/tensor/rewriting/linalg.py | 145 +++++++-------------- pytensor/tensor/rewriting/math.py | 91 +++++-------- pytensor/tensor/rewriting/subtensor.py | 58 ++++----- 8 files changed, 189 insertions(+), 267 deletions(-) diff --git a/pytensor/scalar/basic.py b/pytensor/scalar/basic.py index de92555401..31dab272f6 100644 --- a/pytensor/scalar/basic.py +++ b/pytensor/scalar/basic.py @@ -1228,6 +1228,8 @@ def __init__(self, output_types_preference=None, name=None): f"(got: {output_types_preference})" ) self.output_types_preference = output_types_preference + elif not hasattr(self, "output_types_preference"): + self.output_types_preference = None def make_node(self, *inputs): if self.nin >= 0: @@ -1247,7 +1249,7 @@ def make_node(self, *inputs): return Apply(self, inputs, outputs) def output_types(self, types): - if hasattr(self, "output_types_preference"): + if self.output_types_preference is not None: variables = self.output_types_preference(*types) if not isinstance(variables, list | tuple) or any( not isinstance(x, CType) for x in variables @@ -2696,7 +2698,7 @@ class Sign(UnaryScalarOp): nfunc_spec = ("sign", 1, 1) @staticmethod - def output_types_preference(x): + def _output_types_preference(x): if x == bool: raise TypeError(x) return same_out_nocomplex(x) @@ -2737,7 +2739,7 @@ def c_code_cache_version(self): return s -sign = Sign(name="sign") +sign = Sign(name="sign", output_types_preference=Sign._output_types_preference) class Ceil(UnaryScalarOp): diff --git a/pytensor/tensor/_linalg/solve/rewriting.py b/pytensor/tensor/_linalg/solve/rewriting.py index c0a1c5cce8..3e06604ddb 100644 --- a/pytensor/tensor/_linalg/solve/rewriting.py +++ b/pytensor/tensor/_linalg/solve/rewriting.py @@ -14,6 +14,7 @@ from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.rewriting.basic import register_specialize +from pytensor.tensor.rewriting.blockwise import blockwise_of from pytensor.tensor.rewriting.linalg import is_matrix_transpose from pytensor.tensor.slinalg import Solve, cho_solve, cholesky, lu_factor, lu_solve from pytensor.tensor.variable import TensorVariable @@ -227,7 +228,7 @@ def _scan_split_non_sequence_decomposition_and_solve( @register_specialize 
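+# `blockwise_of(Solve)` only tracks Blockwise nodes whose core_op is a Solve instance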
-@node_rewriter([Blockwise]) +@node_rewriter([blockwise_of(Solve)]) def reuse_decomposition_multiple_solves(fgraph, node): return _split_decomp_and_solve_steps( fgraph, node, eager=False, allowed_assume_a={"gen", "tridiagonal", "pos"} diff --git a/pytensor/tensor/rewriting/basic.py b/pytensor/tensor/rewriting/basic.py index e9c2c8e47e..a0e873212c 100644 --- a/pytensor/tensor/rewriting/basic.py +++ b/pytensor/tensor/rewriting/basic.py @@ -26,10 +26,9 @@ import numpy as np -import pytensor.scalar.basic as ps from pytensor import compile, config from pytensor.compile.ops import ViewOp -from pytensor.graph import FunctionGraph +from pytensor.graph import FunctionGraph, Op from pytensor.graph.basic import Constant from pytensor.graph.rewriting.basic import ( NodeProcessingGraphRewriter, @@ -40,9 +39,24 @@ node_rewriter, ) from pytensor.graph.rewriting.db import RewriteDatabase +from pytensor.graph.rewriting.unify import OpPattern, OpPatternOpTypeType from pytensor.npy_2_compat import normalize_axis_index from pytensor.raise_op import Assert, CheckAndRaise, assert_op -from pytensor.scalar.basic import Second +from pytensor.scalar import ( + AND, + EQ, + LE, + NEQ, + OR, + XOR, + Add, + BinaryScalarOp, + Cast, + Identity, + Mul, + Second, + Switch, +) from pytensor.tensor.basic import ( Alloc, AllocEmpty, @@ -225,6 +239,12 @@ def register(inner_rewriter: RewriteDatabase | Rewriter): return node_rewriter +def elemwise_of(scalar_op: OpPatternOpTypeType | OpPattern) -> OpPattern: + if not isinstance(scalar_op, Op | OpPattern): + scalar_op = OpPattern(scalar_op) + return OpPattern(Elemwise, scalar_op=scalar_op) + + @register_canonicalize @register_specialize @node_rewriter([TensorFromScalar]) @@ -551,7 +571,7 @@ def local_useless_elemwise(fgraph, node): dtype = node.outputs[0].type.dtype scalar_op = node.op.scalar_op - if isinstance(scalar_op, ps.EQ) and len(node.inputs) == 2: + if isinstance(scalar_op, EQ) and len(node.inputs) == 2: if node.inputs[0] is node.inputs[1]: # it is the same var in the graph. That will always be true ret = ones_like(node.inputs[0], dtype=dtype, opt=True) @@ -559,7 +579,7 @@ def local_useless_elemwise(fgraph, node): # Copy stack trace from input to constant output copy_stack_trace(node.outputs[0], ret) return [ret] - elif isinstance(scalar_op, ps.NEQ | ps.XOR) and len(node.inputs) == 2: + elif isinstance(scalar_op, NEQ | XOR) and len(node.inputs) == 2: if node.inputs[0] is node.inputs[1]: # it is the same var in the graph. 
That will always be false ret = zeros_like(node.inputs[0], dtype=dtype, opt=True) @@ -568,14 +588,11 @@ def local_useless_elemwise(fgraph, node): copy_stack_trace(node.outputs[0], ret) return [ret] - elif ( - isinstance(node.op.scalar_op, ps.Mul | ps.Add | ps.Identity) - and len(node.inputs) == 1 - ): + elif isinstance(node.op.scalar_op, Mul | Add | Identity) and len(node.inputs) == 1: # No need to copy over any stack trace return [node.inputs[0]] - elif isinstance(node.op.scalar_op, ps.AND) and len(node.inputs) == 2: + elif isinstance(node.op.scalar_op, AND) and len(node.inputs) == 2: if ( isinstance(node.inputs[0], TensorConstant) and node.inputs[1].type.broadcastable == out_bcast @@ -602,7 +619,7 @@ def local_useless_elemwise(fgraph, node): # and this rewrite would be wrong return [node.inputs[0].astype(node.outputs[0].dtype)] - elif isinstance(node.op.scalar_op, ps.OR) and len(node.inputs) == 2: + elif isinstance(node.op.scalar_op, OR) and len(node.inputs) == 2: if ( isinstance(node.inputs[0], TensorConstant) and node.inputs[1].type.broadcastable == out_bcast @@ -653,7 +670,7 @@ def local_alloc_unary(fgraph, node): @register_canonicalize @register_specialize -@node_rewriter([Elemwise]) +@node_rewriter([elemwise_of(Cast)]) def local_cast_cast(fgraph, node): """cast(cast(x, dtype1), dtype2) @@ -663,13 +680,11 @@ def local_cast_cast(fgraph, node): and the first cast cause an upcast. """ - if not (isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, ps.Cast)): - return x = node.inputs[0] if not ( x.owner and isinstance(x.owner.op, Elemwise) - and isinstance(x.owner.op.scalar_op, ps.Cast) + and isinstance(x.owner.op.scalar_op, Cast) ): return @@ -1009,7 +1024,7 @@ def local_useless_switch(fgraph, node): node.outputs[0].type.ndim == 0 and cond_var.owner and isinstance(cond_var.owner.op, Elemwise) - and isinstance(cond_var.owner.op.scalar_op, ps.LE) + and isinstance(cond_var.owner.op.scalar_op, LE) and cond_var.owner.inputs[0].owner and isinstance(cond_var.owner.inputs[0].owner.op, Shape_i) and get_scalar_constant_value( @@ -1031,24 +1046,18 @@ def local_useless_switch(fgraph, node): @register_canonicalize -@node_rewriter([Elemwise]) +@node_rewriter([elemwise_of(BinaryScalarOp | Add | Mul)]) def local_merge_switch_same_cond(fgraph, node): """ Merge add/sub/mul/div/minimum/maximum/... 
of switches sharing the same condition, to enable further simplification of their branches Example: switch(c, a, b) + switch(c, x, y) -> switch(c, a+x, b+y) """ - # node must be binary elemwise or add or mul - if not ( - isinstance(node.op, Elemwise) - and isinstance(node.op.scalar_op, ps.BinaryScalarOp | ps.Add | ps.Mul) - ): - return # all inputs must be switch if not all( s.owner and isinstance(s.owner.op, Elemwise) - and isinstance(s.owner.op.scalar_op, ps.Switch) + and isinstance(s.owner.op.scalar_op, Switch) for s in node.inputs ): return @@ -1174,10 +1183,9 @@ def constant_folding(fgraph, node): @register_infer_shape @register_canonicalize("fast_compile") @register_useless("fast_compile") -@node_rewriter(None) +@node_rewriter([ViewOp]) def local_view_op(fgraph, node): - if isinstance(node.op, ViewOp): - return node.inputs + return node.inputs @register_infer_shape diff --git a/pytensor/tensor/rewriting/blockwise.py b/pytensor/tensor/rewriting/blockwise.py index 023c8aae51..a5afe2fcd1 100644 --- a/pytensor/tensor/rewriting/blockwise.py +++ b/pytensor/tensor/rewriting/blockwise.py @@ -1,8 +1,9 @@ from pytensor.compile.mode import optdb -from pytensor.graph import Constant, node_rewriter +from pytensor.graph import Constant, Op, node_rewriter from pytensor.graph.destroyhandler import inplace_candidates from pytensor.graph.replace import vectorize_node from pytensor.graph.rewriting.basic import copy_stack_trace, out2in +from pytensor.graph.rewriting.unify import OpPattern, OpPatternOpTypeType from pytensor.tensor.basic import Alloc, ARange, alloc, shape_padleft from pytensor.tensor.blockwise import Blockwise, _squeeze_left from pytensor.tensor.math import Dot @@ -20,6 +21,12 @@ ) +def blockwise_of(core_op: OpPatternOpTypeType | OpPattern) -> OpPattern: + if not isinstance(core_op, Op | OpPattern): + core_op = OpPattern(core_op) + return OpPattern(Blockwise, core_op=core_op) + + @node_rewriter([Blockwise]) def local_useless_blockwise(fgraph, node): """ @@ -71,22 +78,24 @@ def local_useless_unbatched_blockwise(fgraph, node): @register_canonicalize @register_stabilize @register_specialize -@node_rewriter(tracks=[Blockwise]) +@node_rewriter( + tracks=[ + blockwise_of( + Dot + | Alloc + | ARange + | Subtensor + | AdvancedSubtensor + | AdvancedIncSubtensor + | Reshape + ) + ] +) def local_eager_useless_unbatched_blockwise(fgraph, node): - if isinstance( - node.op.core_op, - Dot - | Alloc - | ARange - | Subtensor - | AdvancedSubtensor - | AdvancedIncSubtensor - | Reshape, - ): - # Many Dot-related rewrites (eg, all of BlasOpt) happen before specialize - # These other Ops can't always be trivially vectorized at runtime, - # since their inputs may imply non-rectangular shapes. - return local_useless_unbatched_blockwise.fn(fgraph, node) + # Many Dot-related rewrites (eg, all of BlasOpt) happen before specialize + # These other Ops can't always be trivially vectorized at runtime, + # since their inputs may imply non-rectangular shapes. + return local_useless_unbatched_blockwise.fn(fgraph, node) @register_specialize("shape_unsafe") @@ -204,7 +213,7 @@ def local_blockwise_alloc(fgraph, node): @register_specialize -@node_rewriter([Blockwise]) +@node_rewriter([blockwise_of(Reshape)]) def local_blockwise_reshape(fgraph, node): """Rewrite away square Blockwise reshapes. 
@@ -215,9 +224,6 @@ def local_blockwise_reshape(fgraph, node): For the square Reshape case, we must wait for all the intermediate operations to be lifted as Allocs """ - if not isinstance(node.op.core_op, Reshape): - return None - x, output_shape = node.inputs batch_ndim = node.op.batch_ndim(node) if all(output_shape.type.broadcastable[:batch_ndim]): diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py index f08f19f06c..8ab49860fb 100644 --- a/pytensor/tensor/rewriting/elemwise.py +++ b/pytensor/tensor/rewriting/elemwise.py @@ -26,6 +26,7 @@ out2in, ) from pytensor.graph.rewriting.db import SequenceDB +from pytensor.graph.rewriting.unify import OpPattern from pytensor.graph.utils import InconsistencyError, MethodNotDefined from pytensor.scalar.math import Grad2F1Loop, _grad_2f1_loop from pytensor.tensor.basic import ( @@ -37,6 +38,7 @@ from pytensor.tensor.rewriting.basic import ( alloc_like, broadcasted_by, + elemwise_of, register_canonicalize, register_specialize, register_stabilize, @@ -422,7 +424,14 @@ def local_useless_dimshuffle_makevector(fgraph, node): @register_canonicalize -@node_rewriter([Elemwise]) +@node_rewriter( + [ + elemwise_of( + OpPattern(ps.ScalarOp, output_types_preference=ps.upgrade_to_float) + ), + elemwise_of(OpPattern(ps.ScalarOp, output_types_preference=ps.upcast_out)), + ] +) def local_upcast_elemwise_constant_inputs(fgraph, node): """This explicitly upcasts constant inputs to elemwise Ops, when those Ops do implicit upcasting anyway. @@ -433,12 +442,6 @@ def local_upcast_elemwise_constant_inputs(fgraph, node): if len(node.outputs) > 1: return None - if getattr(node.op.scalar_op, "output_types_preference", None) not in ( - ps.upgrade_to_float, - ps.upcast_out, - ): - return None - # this is the kind of op that we can screw with the input # dtypes by upcasting explicitly [old_out] = node.outputs @@ -988,13 +991,9 @@ def print_profile(stream, prof, level=0): @register_canonicalize @register_specialize -@node_rewriter([Elemwise]) +@node_rewriter([elemwise_of(ps.Composite)]) def local_useless_composite_outputs(fgraph, node): """Remove inputs and outputs of Composite Ops that are not used anywhere.""" - if not ( - isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, ps.Composite) - ): - return comp = node.op.scalar_op used_outputs_idxs = [ i for i, o_extern in enumerate(node.outputs) if fgraph.clients[o_extern] @@ -1104,14 +1103,10 @@ def local_careduce_fusion(fgraph, node): return [new_car_op(*elm_inputs)] -@node_rewriter([Elemwise]) +@node_rewriter([elemwise_of(ps.Composite)]) def local_inline_composite_constants(fgraph, node): """Inline scalar constants in Composite graphs.""" composite_op = node.op.scalar_op - - if not isinstance(composite_op, ps.Composite): - return None - new_outer_inputs = [] new_inner_inputs = [] inner_replacements = {} @@ -1287,14 +1282,9 @@ def _rebuild_partial_2f1grad_loop(node, wrt): @register_specialize -@node_rewriter([Elemwise]) +@node_rewriter([elemwise_of(Grad2F1Loop)]) def local_useless_2f1grad_loop(fgraph, node): # Remove unused terms from the hyp2f1 grad loop - - loop_op = node.op.scalar_op - if not isinstance(loop_op, Grad2F1Loop): - return - grad_related_vars = node.outputs[:-4] # Rewrite was already applied if len(grad_related_vars) // 3 != 3: @@ -1326,18 +1316,13 @@ def local_useless_2f1grad_loop(fgraph, node): return replacements -@node_rewriter([Elemwise]) +@node_rewriter([elemwise_of(Grad2F1Loop)]) def split_2f1grad_loop(fgraph, node): """ 2f1grad loop has too many operands 
for Numpy frompyfunc code used by Elemwise nodes on python mode. This rewrite splits it across 3 different operations. It is not needed if `local_useless_2f1grad_loop` was applied """ - loop_op = node.op.scalar_op - - if not isinstance(loop_op, Grad2F1Loop): - return None - grad_related_vars = node.outputs[:-4] # local_useless_2f1grad_loop was used, we should be safe if len(grad_related_vars) // 3 != 3: diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index 8367642c4c..a09d994228 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -13,6 +13,7 @@ in2out, node_rewriter, ) +from pytensor.graph.rewriting.unify import OpPattern from pytensor.scalar.basic import Abs, Log, Mul, Sign from pytensor.tensor.basic import ( AllocDiag, @@ -43,6 +44,7 @@ register_specialize, register_stabilize, ) +from pytensor.tensor.rewriting.blockwise import blockwise_of from pytensor.tensor.slinalg import ( BlockDiagonal, Cholesky, @@ -60,7 +62,7 @@ logger = logging.getLogger(__name__) -ALL_INVERSE_OPS = (MatrixInverse, MatrixPinv) +MATRIX_INVERSE_OPS = (MatrixInverse, MatrixPinv) def is_matrix_transpose(x: TensorVariable) -> bool: @@ -129,69 +131,48 @@ def inv_as_solve(fgraph, node): @register_stabilize @register_canonicalize -@node_rewriter([Blockwise]) +@node_rewriter([blockwise_of(OpPattern(Solve, assume_a="gen"))]) def generic_solve_to_solve_triangular(fgraph, node): """ If any solve() is applied to the output of a cholesky op, then replace it with a triangular solve. """ - if isinstance(node.op.core_op, Solve): - if node.op.core_op.assume_a == "gen": - A, b = node.inputs # result is solution Ax=b - if ( - A.owner - and isinstance(A.owner.op, Blockwise) - and isinstance(A.owner.op.core_op, Cholesky) - ): - if A.owner.op.core_op.lower: - return [ - solve_triangular( - A, b, lower=True, b_ndim=node.op.core_op.b_ndim - ) - ] - else: - return [ - solve_triangular( - A, b, lower=False, b_ndim=node.op.core_op.b_ndim - ) - ] - if is_matrix_transpose(A): - (A_T,) = A.owner.inputs - if ( - A_T.owner - and isinstance(A_T.owner.op, Blockwise) - and isinstance(A_T.owner.op, Cholesky) - ): - if A_T.owner.op.lower: - return [ - solve_triangular( - A, b, lower=False, b_ndim=node.op.core_op.b_ndim - ) - ] - else: - return [ - solve_triangular( - A, b, lower=True, b_ndim=node.op.core_op.b_ndim - ) - ] + A, b = node.inputs # result is the solution to Ax=b + if ( + A.owner + and isinstance(A.owner.op, Blockwise) + and isinstance(A.owner.op.core_op, Cholesky) + ): + if A.owner.op.core_op.lower: + return [solve_triangular(A, b, lower=True, b_ndim=node.op.core_op.b_ndim)] + else: + return [solve_triangular(A, b, lower=False, b_ndim=node.op.core_op.b_ndim)] + if is_matrix_transpose(A): + (A_T,) = A.owner.inputs + if ( + A_T.owner + and isinstance(A_T.owner.op, Blockwise) + and isinstance(A_T.owner.op, Cholesky) + ): + if A_T.owner.op.lower: + return [ + solve_triangular(A, b, lower=False, b_ndim=node.op.core_op.b_ndim) + ] + else: + return [ + solve_triangular(A, b, lower=True, b_ndim=node.op.core_op.b_ndim) + ] @register_specialize -@node_rewriter([Blockwise]) +@node_rewriter([blockwise_of(OpPattern(SolveBase, b_ndim=1))]) def batched_vector_b_solve_to_matrix_b_solve(fgraph, node): """Replace a batched Solve(a, b, b_ndim=1) by Solve(a, b.T, b_ndim=2).T `a` must have no batched dimensions, while `b` can have arbitrary batched dimensions. 
""" core_op = node.op.core_op - - if not isinstance(core_op, SolveBase): - return None - - if node.op.core_op.b_ndim != 1: - return None - [a, b] = node.inputs # Check `b` is actually batched @@ -242,26 +223,24 @@ def no_transpose_symmetric(fgraph, node): @register_stabilize -@node_rewriter([Blockwise]) +@node_rewriter([blockwise_of(OpPattern(Solve, b_ndim=2))]) def psd_solve_with_chol(fgraph, node): """ This utilizes a boolean `psd` tag on matrices. """ - if isinstance(node.op.core_op, Solve) and node.op.core_op.b_ndim == 2: - A, b = node.inputs # result is solution Ax=b - if getattr(A.tag, "psd", None) is True: - L = cholesky(A) - # N.B. this can be further reduced to a yet-unwritten cho_solve Op - # __if__ no other Op makes use of the L matrix during the - # stabilization - Li_b = solve_triangular(L, b, lower=True, b_ndim=2) - x = solve_triangular((L.mT), Li_b, lower=False, b_ndim=2) - return [x] + A, b = node.inputs # result is the solution to Ax=b + if getattr(A.tag, "psd", None) is True: + L = cholesky(A) + # N.B. this can be further reduced to cho_solve Op + # if no other Op makes use of the L matrix + Li_b = solve_triangular(L, b, lower=True, b_ndim=2) + x = solve_triangular((L.mT), Li_b, lower=False, b_ndim=2) + return [x] @register_canonicalize @register_stabilize -@node_rewriter([Blockwise]) +@node_rewriter([blockwise_of(Cholesky)]) def cholesky_ldotlt(fgraph, node): """ rewrite cholesky(dot(L, L.T), lower=True) = L, where L is lower triangular, @@ -271,9 +250,6 @@ def cholesky_ldotlt(fgraph, node): This utilizes a boolean `lower_triangular` or `upper_triangular` tag on matrices. """ - if not isinstance(node.op.core_op, Cholesky): - return - A = node.inputs[0] if not ( A.owner is not None and (isinstance(A.owner.op, Dot) or (A.owner.op == _matmul)) @@ -342,7 +318,7 @@ def local_log_prod_sqr(fgraph, node): @register_specialize -@node_rewriter([Blockwise]) +@node_rewriter([blockwise_of(MatrixInverse | Cholesky | MatrixPinv)]) def local_lift_through_linalg( fgraph: FunctionGraph, node: Apply ) -> list[Variable] | None: @@ -370,9 +346,6 @@ def local_lift_through_linalg( """ # TODO: Simplify this if we end up Blockwising KroneckerProduct - if not isinstance(node.op.core_op, MatrixInverse | Cholesky | MatrixPinv): - return None - y = node.inputs[0] outer_op = node.op @@ -534,15 +507,12 @@ def rewrite_det_diag_to_prod_diag(fgraph, node): @register_canonicalize @register_stabilize @register_specialize -@node_rewriter([Blockwise]) +@node_rewriter([blockwise_of(SVD)]) def svd_uv_merge(fgraph, node): """If we have more than one `SVD` `Op`s and at least one has keyword argument `compute_uv=True`, then we can change `compute_uv = False` to `True` everywhere and allow `pytensor` to re-use the decomposition outputs instead of recomputing. """ - if not isinstance(node.op.core_op, SVD): - return - (x,) = node.inputs if node.op.core_op.compute_uv: @@ -585,7 +555,7 @@ def svd_uv_merge(fgraph, node): @register_canonicalize @register_stabilize -@node_rewriter([Blockwise]) +@node_rewriter([blockwise_of(MATRIX_INVERSE_OPS)]) def rewrite_inv_inv(fgraph, node): """ This rewrite takes advantage of the fact that if there are two consecutive inverse operations (inv(inv(input))), we get back our original input without having to compute inverse once. 
@@ -607,9 +577,6 @@ def rewrite_inv_inv(fgraph, node): # Check if its a valid inverse operation (either inv/pinv) # In case the outer operation is an inverse, it directly goes to the next step of finding inner operation # If the outer operation is not a valid inverse, we do not apply this rewrite - if not isinstance(node.op.core_op, ALL_INVERSE_OPS): - return None - potential_inner_inv = node.inputs[0].owner if potential_inner_inv is None or potential_inner_inv.op is None: return None @@ -618,7 +585,7 @@ def rewrite_inv_inv(fgraph, node): if not ( potential_inner_inv and isinstance(potential_inner_inv.op, Blockwise) - and isinstance(potential_inner_inv.op.core_op, ALL_INVERSE_OPS) + and isinstance(potential_inner_inv.op.core_op, MATRIX_INVERSE_OPS) ): return None return [potential_inner_inv.inputs[0]] @@ -626,7 +593,7 @@ def rewrite_inv_inv(fgraph, node): @register_canonicalize @register_stabilize -@node_rewriter([Blockwise]) +@node_rewriter([blockwise_of(MATRIX_INVERSE_OPS)]) def rewrite_inv_eye_to_eye(fgraph, node): """ This rewrite takes advantage of the fact that the inverse of an identity matrix is the matrix itself @@ -642,10 +609,6 @@ def rewrite_inv_eye_to_eye(fgraph, node): list of Variable, optional List of optimized variables, or None if no optimization was performed """ - core_op = node.op.core_op - if not (isinstance(core_op, ALL_INVERSE_OPS)): - return None - # Check whether input to inverse is Eye and the 1's are on main diagonal potential_eye = node.inputs[0] if not ( @@ -659,7 +622,7 @@ def rewrite_inv_eye_to_eye(fgraph, node): @register_canonicalize @register_stabilize -@node_rewriter([Blockwise]) +@node_rewriter([blockwise_of(MATRIX_INVERSE_OPS)]) def rewrite_inv_diag_to_diag_reciprocal(fgraph, node): """ This rewrite takes advantage of the fact that for a diagonal matrix, the inverse is a diagonal matrix with the new diagonal entries as reciprocals of the original diagonal elements. 
@@ -677,10 +640,6 @@ def rewrite_inv_diag_to_diag_reciprocal(fgraph, node): list of Variable, optional List of optimized variables, or None if no optimization was performed """ - core_op = node.op.core_op - if not (isinstance(core_op, ALL_INVERSE_OPS)): - return None - inputs = node.inputs[0] # Check for use of pt.diag first if ( @@ -857,7 +816,7 @@ def rewrite_det_kronecker(fgraph, node): @register_canonicalize @register_stabilize -@node_rewriter([Blockwise]) +@node_rewriter([blockwise_of(Cholesky)]) def rewrite_remove_useless_cholesky(fgraph, node): """ This rewrite takes advantage of the fact that the cholesky decomposition of an identity matrix is the matrix itself @@ -877,8 +836,6 @@ def rewrite_remove_useless_cholesky(fgraph, node): List of optimized variables, or None if no optimization was performed """ # Find whether cholesky op is being applied - if not isinstance(node.op.core_op, Cholesky): - return None # Check whether input to Cholesky is Eye and the 1's are on main diagonal potential_eye = node.inputs[0] @@ -894,12 +851,8 @@ def rewrite_remove_useless_cholesky(fgraph, node): @register_canonicalize @register_stabilize -@node_rewriter([Blockwise]) +@node_rewriter([blockwise_of(Cholesky)]) def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node): - # Find whether cholesky op is being applied - if not isinstance(node.op.core_op, Cholesky): - return None - [input] = node.inputs # Check if input is a (1, 1) matrix @@ -1022,7 +975,7 @@ def slogdet_specialization(fgraph, node): @register_stabilize @register_canonicalize -@node_rewriter([Blockwise]) +@node_rewriter([blockwise_of(SolveBase)]) def scalar_solve_to_division(fgraph, node): """ Replace solve(a, b) with b / a if a is a (1, 1) matrix diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index 28f1f87fff..e313163b6e 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -37,7 +37,6 @@ zeros, zeros_like, ) -from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.extra_ops import broadcast_arrays, concat_with_broadcast @@ -49,6 +48,11 @@ _dot, _matmul, add, + arccosh, + arcsinh, + arctanh, + cosh, + deg2rad, digamma, dot, erf, @@ -70,13 +74,16 @@ neg, polygamma, prod, + rad2deg, reciprocal, sigmoid, sign, + sinh, softplus, sqr, sqrt, sub, + tanh, tri_gamma, true_div, variadic_add, @@ -96,6 +103,7 @@ register_uncanonicalize, register_useless, ) +from pytensor.tensor.rewriting.blockwise import blockwise_of from pytensor.tensor.rewriting.elemwise import apply_local_dimshuffle_lift from pytensor.tensor.rewriting.linalg import is_matrix_transpose from pytensor.tensor.shape import Shape, Shape_i @@ -151,7 +159,7 @@ def local_0_dot_x(fgraph, node): @register_stabilize -@node_rewriter([Blockwise]) +@node_rewriter([blockwise_of(BlockDiagonal)]) def local_block_diag_dot_to_dot_block_diag(fgraph, node): r""" Perform the rewrite ``dot(block_diag(A, B), C) -> concat(dot(A, C), dot(B, C))`` @@ -160,9 +168,6 @@ def local_block_diag_dot_to_dot_block_diag(fgraph, node): of approximately O(n^3), it's always better to perform two dot products on the smaller matrices, rather than a single dot on the larger matrix. 
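# Illustrative sketch, not part of the patch, of the case `scalar_solve_to_division`
# covers: solving a linear system with a (1, 1) matrix is just a division, so the
# Blockwise solve node can be replaced by cheaper elementwise ops. The shapes and
# values below are arbitrary.
import numpy as np

import pytensor
import pytensor.tensor as pt
from pytensor.tensor.slinalg import solve

a = pt.matrix("a", shape=(1, 1))
b = pt.vector("b", shape=(1,))
x = solve(a, b)

f = pytensor.function([a, b], x)
np.testing.assert_allclose(
    f(np.array([[2.0]], dtype=a.dtype), np.array([4.0], dtype=b.dtype)),
    [2.0],
)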
""" - if not isinstance(node.op.core_op, BlockDiagonal): - return - # Check that the BlockDiagonal is an input to a Dot node: for client in itertools.chain.from_iterable( get_clients_at_depth(fgraph, node, depth=i) for i in [1, 2] @@ -424,60 +429,30 @@ def local_dot_to_mul(fgraph, node): return [new_out] -def is_inverse_pair(node_op, prev_op, inv_pair): - """ - Given two consecutive operations, check if they are the - provided pair of inverse functions. - - """ - node_is_op0 = isinstance(node_op, inv_pair[0]) - node_is_op1 = isinstance(node_op, inv_pair[1]) - prev_is_op0 = isinstance(prev_op, inv_pair[0]) - prev_is_op1 = isinstance(prev_op, inv_pair[1]) - - return (node_is_op0 and prev_is_op1) or (node_is_op1 and prev_is_op0) - - -@register_canonicalize -@register_specialize -@node_rewriter([Elemwise]) -def local_func_inv(fgraph, node): - """ - Check for two consecutive operations that are functional inverses - and remove them from the function graph. - - """ - inv_pairs = ( - (ps.Deg2Rad, ps.Rad2Deg), - (ps.Cosh, ps.ArcCosh), - (ps.Tanh, ps.ArcTanh), - (ps.Sinh, ps.ArcSinh), - (ps.Conj, ps.Conj), - (ps.Neg, ps.Neg), - (ps.Reciprocal, ps.Reciprocal), - ) - x = node.inputs[0] - - if not isinstance(node.op, Elemwise): - return - if not (x.owner and isinstance(x.owner.op, Elemwise)): - return - - prev_op = x.owner.op.scalar_op - node_op = node.op.scalar_op - - for inv_pair in inv_pairs: - if is_inverse_pair(node_op, prev_op, inv_pair): - # We don't need to copy stack trace, because the rewrite - # is trivial and maintains the earlier stack trace - ottype = node.out.dtype - inp = x.owner.inputs[0] - # Functions may have casted integer input to float - if inp.dtype != ottype: - inp = cast(inp, ottype) - return [inp] +for pair in ( + (deg2rad, rad2deg), + (cosh, arccosh), + (tanh, arctanh), + (sinh, arcsinh), + (_conj, _conj), + (neg, neg), + (reciprocal, reciprocal), +): + # Create a simple PatternNodeRewriter for each pair of opposite ops + # instead of a general Op that is called to often for very few hits + for op, inv_op in (pair, reversed(pair)): + rewrite = PatternNodeRewriter( + (op, (inv_op, "x")), + "x", + allow_multiple_clients=True, + allow_cast=True, + name=f"useless_{op}_of_{inv_op}", + ) + register_canonicalize(rewrite) + register_specialize(rewrite) - return + if op is inv_op: + break # Same Op, no need to define two rewrites @register_canonicalize diff --git a/pytensor/tensor/rewriting/subtensor.py b/pytensor/tensor/rewriting/subtensor.py index 93a94fac09..6206cdcd0a 100644 --- a/pytensor/tensor/rewriting/subtensor.py +++ b/pytensor/tensor/rewriting/subtensor.py @@ -35,7 +35,7 @@ switch, ) from pytensor.tensor.basic import constant as tensor_constant -from pytensor.tensor.blockwise import Blockwise, _squeeze_left +from pytensor.tensor.blockwise import _squeeze_left from pytensor.tensor.elemwise import Elemwise from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.extra_ops import broadcast_to @@ -58,6 +58,7 @@ register_specialize, register_stabilize, ) +from pytensor.tensor.rewriting.blockwise import blockwise_of from pytensor.tensor.shape import ( shape_padleft, shape_padright, @@ -974,33 +975,30 @@ def movable(i): and not i.owner.op.set_instead_of_inc ) - if node.op == add: - o_type = node.outputs[0].type + o_type = node.outputs[0].type - movable_inputs = [i for i in node.inputs if movable(i)] + movable_inputs = [i for i in node.inputs if movable(i)] - if movable_inputs: - new_inputs = [i for i in node.inputs if not movable(i)] + [ - 
mi.owner.inputs[0] for mi in movable_inputs - ] - new_add = variadic_add(*new_inputs) - # Copy over stacktrace from original output, as an error - # (e.g. an index error) in this add operation should - # correspond to an error in the original add operation. - copy_stack_trace(node.outputs[0], new_add) - - # stack up the new incsubtensors - tip = new_add - for mi in movable_inputs: - assert o_type.is_super(tip.type) - tip = mi.owner.op(tip, *mi.owner.inputs[1:]) - # Copy over stacktrace from outputs of the original - # "movable" operation to the new operation. - copy_stack_trace(node.outputs + mi.owner.outputs, tip) + if movable_inputs: + new_inputs = [i for i in node.inputs if not movable(i)] + [ + mi.owner.inputs[0] for mi in movable_inputs + ] + new_add = variadic_add(*new_inputs) + # Copy over stacktrace from original output, as an error + # (e.g. an index error) in this add operation should + # correspond to an error in the original add operation. + copy_stack_trace(node.outputs[0], new_add) - return [tip] + # stack up the new incsubtensors + tip = new_add + for mi in movable_inputs: + assert o_type.is_super(tip.type) + tip = mi.owner.op(tip, *mi.owner.inputs[1:]) + # Copy over stacktrace from outputs of the original + # "movable" operation to the new operation. + copy_stack_trace(node.outputs + mi.owner.outputs, tip) - # print incsub_inputs, [id(i.owner.inputs[0]) for i in incsub_inputs] + return [tip] # We register it in a WalkingGraphRewriter inside the canonizer EQ optimizer. @@ -1576,7 +1574,7 @@ def local_uint_constant_indices(fgraph, node): @register_stabilize @register_specialize -@node_rewriter([Blockwise]) +@node_rewriter([blockwise_of(Subtensor)]) def local_blockwise_of_subtensor(fgraph, node): """Rewrite Blockwise of Subtensor, where the only batch input is the indexed tensor. @@ -1585,9 +1583,6 @@ def local_blockwise_of_subtensor(fgraph, node): TODO: Handle batched indices like we do with blockwise of inc_subtensor TODO: Extend to AdvanceSubtensor """ - if not isinstance(node.op.core_op, Subtensor): - return - x, *idxs = node.inputs if not all(all(idx.type.broadcastable) for idx in idxs): return @@ -1603,7 +1598,7 @@ def local_blockwise_of_subtensor(fgraph, node): @register_canonicalize("shape_unsafe") @register_stabilize("shape_unsafe") @register_specialize("shape_unsafe") -@node_rewriter([Blockwise]) +@node_rewriter([blockwise_of(IncSubtensor | AdvancedIncSubtensor)]) def local_blockwise_inc_subtensor(fgraph, node): """Rewrite blockwised inc_subtensors. @@ -1614,12 +1609,9 @@ def local_blockwise_inc_subtensor(fgraph, node): and can be safely rewritten without Blockwise. """ core_op = node.op.core_op - if not isinstance(core_op, AdvancedIncSubtensor | IncSubtensor): - return None - x, y, *idxs = node.inputs [out] = node.outputs - if isinstance(node.op.core_op, AdvancedIncSubtensor): + if isinstance(core_op, AdvancedIncSubtensor): if any( ( # Blockwise requires all inputs to be tensors so it is not possible
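# Illustrative sketch, not part of the patch: one iteration of the inverse-pair loop
# added in pytensor/tensor/rewriting/math.py above builds a rewrite equivalent to the
# following, which replaces deg2rad(rad2deg(x)) with x (casting back if needed).
from pytensor.graph.rewriting.basic import PatternNodeRewriter
from pytensor.tensor import deg2rad, rad2deg

useless_deg2rad_of_rad2deg = PatternNodeRewriter(
    (deg2rad, (rad2deg, "x")),
    "x",
    allow_multiple_clients=True,
    allow_cast=True,
    name="useless_deg2rad_of_rad2deg",
)
# The loop then registers each such rewrite with register_canonicalize and
# register_specialize, just as the removed local_func_inv used to be registered.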