2929import pytensor .scalar .basic as ps
3030from pytensor import compile , config
3131from pytensor .compile .ops import ViewOp
32- from pytensor .graph import FunctionGraph
32+ from pytensor .graph import FunctionGraph , Op
3333from pytensor .graph .basic import Constant
3434from pytensor .graph .rewriting .basic import (
3535 NodeProcessingGraphRewriter ,
4040 node_rewriter ,
4141)
4242from pytensor .graph .rewriting .db import RewriteDatabase
43+ from pytensor .graph .rewriting .unify import OpPattern
4344from pytensor .npy_2_compat import normalize_axis_index
4445from pytensor .raise_op import Assert , CheckAndRaise , assert_op
45- from pytensor .scalar .basic import Second
4646from pytensor .tensor .basic import (
4747 Alloc ,
4848 AllocEmpty ,
@@ -225,6 +225,12 @@ def register(inner_rewriter: RewriteDatabase | Rewriter):
225225 return node_rewriter
226226
227227
228+ def elemwise_of (scalar_op ) -> OpPattern :
229+ if not isinstance (scalar_op , Op | OpPattern ):
230+ scalar_op = OpPattern (scalar_op )
231+ return OpPattern (Elemwise , scalar_op = scalar_op )
232+
233+
228234@register_canonicalize
229235@register_specialize
230236@node_rewriter ([TensorFromScalar ])
@@ -324,15 +330,12 @@ def dimshuffled_alloc(i):
324330 return new_outs
325331
326332
327- @node_rewriter ([Elemwise ])
333+ @node_rewriter ([fill ])
328334def local_fill_sink (fgraph , node ):
329335 """
330336 f(fill(a, b), fill(c, d), e) -> fill(c, fill(a, f(b, d, e)))
331337 f need to be an elemwise that isn't a fill.
332338 """
333- if isinstance (node .op .scalar_op , Second ):
334- return False
335-
336339 models = []
337340 inputs = []
338341 for inp in node .inputs :
@@ -653,7 +656,7 @@ def local_alloc_unary(fgraph, node):
653656
654657@register_canonicalize
655658@register_specialize
656- @node_rewriter ([Elemwise ])
659+ @node_rewriter ([elemwise_of ( ps . Cast ) ])
657660def local_cast_cast (fgraph , node ):
658661 """cast(cast(x, dtype1), dtype2)
659662
@@ -663,8 +666,6 @@ def local_cast_cast(fgraph, node):
663666 and the first cast cause an upcast.
664667
665668 """
666- if not (isinstance (node .op , Elemwise ) and isinstance (node .op .scalar_op , ps .Cast )):
667- return
668669 x = node .inputs [0 ]
669670 if not (
670671 x .owner
@@ -1031,19 +1032,13 @@ def local_useless_switch(fgraph, node):
10311032
10321033
10331034@register_canonicalize
1034- @node_rewriter ([Elemwise ])
1035+ @node_rewriter ([elemwise_of ( ps . BinaryScalarOp | ps . Add | ps . Mul ) ])
10351036def local_merge_switch_same_cond (fgraph , node ):
10361037 """
10371038 Merge add/sub/mul/div/minimum/maximum/... of switches sharing the same
10381039 condition, to enable further simplification of their branches
10391040 Example: switch(c, a, b) + switch(c, x, y) -> switch(c, a+x, b+y)
10401041 """
1041- # node must be binary elemwise or add or mul
1042- if not (
1043- isinstance (node .op , Elemwise )
1044- and isinstance (node .op .scalar_op , ps .BinaryScalarOp | ps .Add | ps .Mul )
1045- ):
1046- return
10471042 # all inputs must be switch
10481043 if not all (
10491044 s .owner
@@ -1174,10 +1169,9 @@ def constant_folding(fgraph, node):
11741169@register_infer_shape
11751170@register_canonicalize ("fast_compile" )
11761171@register_useless ("fast_compile" )
1177- @node_rewriter (None )
1172+ @node_rewriter ([ ViewOp ] )
11781173def local_view_op (fgraph , node ):
1179- if isinstance (node .op , ViewOp ):
1180- return node .inputs
1174+ return node .inputs
11811175
11821176
11831177@register_infer_shape
0 commit comments