88from operator import or_
99from warnings import warn
1010
11- import pytensor .scalar .basic as ps
12- from pytensor import clone_replace , compile
1311from pytensor .compile .function .types import Supervisor
14- from pytensor .compile .mode import get_target_language
12+ from pytensor .compile .mode import get_target_language , optdb
1513from pytensor .configdefaults import config
1614from pytensor .graph .basic import Apply , Variable
1715from pytensor .graph .destroyhandler import DestroyHandler , inplace_candidates
1816from pytensor .graph .features import ReplaceValidate
1917from pytensor .graph .fg import FunctionGraph , Output
2018from pytensor .graph .op import Op
19+ from pytensor .graph .replace import clone_replace
2120from pytensor .graph .rewriting .basic import (
2221 GraphRewriter ,
2322 copy_stack_trace ,
3029from pytensor .graph .rewriting .unify import OpPattern
3130from pytensor .graph .traversal import toposort
3231from pytensor .graph .utils import InconsistencyError , MethodNotDefined
33- from pytensor .scalar .math import Grad2F1Loop , _grad_2f1_loop
34- from pytensor .tensor .basic import (
35- MakeVector ,
36- constant ,
32+ from pytensor .scalar import (
33+ Add ,
34+ Composite ,
35+ Mul ,
36+ ScalarOp ,
37+ get_scalar_type ,
38+ transfer_type ,
39+ upcast_out ,
40+ upgrade_to_float ,
3741)
42+ from pytensor .scalar import cast as scalar_cast
43+ from pytensor .scalar import constant as scalar_constant
44+ from pytensor .scalar .math import Grad2F1Loop , _grad_2f1_loop
45+ from pytensor .tensor .basic import MakeVector
46+ from pytensor .tensor .basic import constant as tensor_constant
3847from pytensor .tensor .elemwise import CAReduce , DimShuffle , Elemwise
3948from pytensor .tensor .math import add , exp , mul
4049from pytensor .tensor .rewriting .basic import (
@@ -280,7 +289,7 @@ def create_inplace_node(self, node, inplace_pattern):
280289 inplace_pattern = {i : o for i , [o ] in inplace_pattern .items ()}
281290 if hasattr (scalar_op , "make_new_inplace" ):
282291 new_scalar_op = scalar_op .make_new_inplace (
283- ps . transfer_type (
292+ transfer_type (
284293 * [
285294 inplace_pattern .get (i , o .dtype )
286295 for i , o in enumerate (node .outputs )
@@ -289,14 +298,14 @@ def create_inplace_node(self, node, inplace_pattern):
289298 )
290299 else :
291300 new_scalar_op = type (scalar_op )(
292- ps . transfer_type (
301+ transfer_type (
293302 * [inplace_pattern .get (i , None ) for i in range (len (node .outputs ))]
294303 )
295304 )
296305 return type (op )(new_scalar_op , inplace_pattern ).make_node (* node .inputs )
297306
298307
299- compile . optdb .register (
308+ optdb .register (
300309 "inplace_elemwise" ,
301310 InplaceElemwiseOptimizer (),
302311 "inplace_elemwise_opt" , # for historic reason
@@ -428,10 +437,8 @@ def local_useless_dimshuffle_makevector(fgraph, node):
428437@register_canonicalize
429438@node_rewriter (
430439 [
431- elemwise_of (
432- OpPattern (ps .ScalarOp , output_types_preference = ps .upgrade_to_float )
433- ),
434- elemwise_of (OpPattern (ps .ScalarOp , output_types_preference = ps .upcast_out )),
440+ elemwise_of (OpPattern (ScalarOp , output_types_preference = upgrade_to_float )),
441+ elemwise_of (OpPattern (ScalarOp , output_types_preference = upcast_out )),
435442 ]
436443)
437444def local_upcast_elemwise_constant_inputs (fgraph , node ):
@@ -452,7 +459,7 @@ def local_upcast_elemwise_constant_inputs(fgraph, node):
452459 changed = False
453460 for i , inp in enumerate (node .inputs ):
454461 if inp .type .dtype != output_dtype and isinstance (inp , TensorConstant ):
455- new_inputs [i ] = constant (inp .data .astype (output_dtype ))
462+ new_inputs [i ] = tensor_constant (inp .data .astype (output_dtype ))
456463 changed = True
457464
458465 if not changed :
@@ -531,7 +538,7 @@ def add_requirements(self, fgraph):
531538 @staticmethod
532539 def elemwise_to_scalar (inputs , outputs ):
533540 replacement = {
534- inp : ps . get_scalar_type (inp .type .dtype ).make_variable () for inp in inputs
541+ inp : get_scalar_type (inp .type .dtype ).make_variable () for inp in inputs
535542 }
536543 for node in toposort (outputs , blockers = inputs ):
537544 scalar_inputs = [replacement [inp ] for inp in node .inputs ]
@@ -853,7 +860,7 @@ def elemwise_scalar_op_has_c_code(
853860 scalar_inputs , scalar_outputs = self .elemwise_to_scalar (inputs , outputs )
854861 composite_outputs = Elemwise (
855862 # No need to clone Composite graph, because `self.elemwise_to_scalar` creates fresh variables
856- ps . Composite (scalar_inputs , scalar_outputs , clone_graph = False )
863+ Composite (scalar_inputs , scalar_outputs , clone_graph = False )
857864 )(* inputs , return_list = True )
858865 assert len (outputs ) == len (composite_outputs )
859866 for old_out , composite_out in zip (outputs , composite_outputs ):
@@ -913,7 +920,7 @@ def print_profile(stream, prof, level=0):
913920
914921@register_canonicalize
915922@register_specialize
916- @node_rewriter ([elemwise_of (ps . Composite )])
923+ @node_rewriter ([elemwise_of (Composite )])
917924def local_useless_composite_outputs (fgraph , node ):
918925 """Remove inputs and outputs of Composite Ops that are not used anywhere."""
919926 comp = node .op .scalar_op
@@ -934,7 +941,7 @@ def local_useless_composite_outputs(fgraph, node):
934941 node .outputs
935942 ):
936943 used_inputs = [node .inputs [i ] for i in used_inputs_idxs ]
937- c = ps . Composite (inputs = used_inner_inputs , outputs = used_inner_outputs )
944+ c = Composite (inputs = used_inner_inputs , outputs = used_inner_outputs )
938945 e = Elemwise (scalar_op = c )(* used_inputs , return_list = True )
939946 return dict (zip ([node .outputs [i ] for i in used_outputs_idxs ], e , strict = True ))
940947
@@ -948,7 +955,7 @@ def local_careduce_fusion(fgraph, node):
948955
949956 # FIXME: This check is needed because of the faulty logic in the FIXME below!
950957 # Right now, rewrite only works for `Sum`/`Prod`
951- if not isinstance (car_scalar_op , ps . Add | ps . Mul ):
958+ if not isinstance (car_scalar_op , Add | Mul ):
952959 return None
953960
954961 elm_node = car_input .owner
@@ -992,19 +999,19 @@ def local_careduce_fusion(fgraph, node):
992999 car_acc_dtype = node .op .acc_dtype
9931000
9941001 scalar_elm_inputs = [
995- ps . get_scalar_type (inp .type .dtype ).make_variable () for inp in elm_inputs
1002+ get_scalar_type (inp .type .dtype ).make_variable () for inp in elm_inputs
9961003 ]
9971004
9981005 elm_output = elm_scalar_op (* scalar_elm_inputs )
9991006
10001007 # This input represents the previous value in the `CAReduce` binary reduction
1001- carried_car_input = ps . get_scalar_type (car_acc_dtype ).make_variable ()
1008+ carried_car_input = get_scalar_type (car_acc_dtype ).make_variable ()
10021009
10031010 scalar_fused_output = car_scalar_op (carried_car_input , elm_output )
10041011 if scalar_fused_output .type .dtype != car_acc_dtype :
1005- scalar_fused_output = ps . cast (scalar_fused_output , car_acc_dtype )
1012+ scalar_fused_output = scalar_cast (scalar_fused_output , car_acc_dtype )
10061013
1007- fused_scalar_op = ps . Composite (
1014+ fused_scalar_op = Composite (
10081015 inputs = [carried_car_input , * scalar_elm_inputs ], outputs = [scalar_fused_output ]
10091016 )
10101017
@@ -1025,7 +1032,7 @@ def local_careduce_fusion(fgraph, node):
10251032 return [new_car_op (* elm_inputs )]
10261033
10271034
1028- @node_rewriter ([elemwise_of (ps . Composite )])
1035+ @node_rewriter ([elemwise_of (Composite )])
10291036def local_inline_composite_constants (fgraph , node ):
10301037 """Inline scalar constants in Composite graphs."""
10311038 composite_op = node .op .scalar_op
@@ -1041,7 +1048,7 @@ def local_inline_composite_constants(fgraph, node):
10411048 and "complex" not in outer_inp .type .dtype
10421049 ):
10431050 if outer_inp .unique_value is not None :
1044- inner_replacements [inner_inp ] = ps . constant (
1051+ inner_replacements [inner_inp ] = scalar_constant (
10451052 outer_inp .unique_value , dtype = inner_inp .dtype
10461053 )
10471054 continue
@@ -1054,7 +1061,7 @@ def local_inline_composite_constants(fgraph, node):
10541061 new_inner_outs = clone_replace (
10551062 composite_op .fgraph .outputs , replace = inner_replacements
10561063 )
1057- new_composite_op = ps . Composite (new_inner_inputs , new_inner_outs )
1064+ new_composite_op = Composite (new_inner_inputs , new_inner_outs )
10581065 new_outputs = Elemwise (new_composite_op ).make_node (* new_outer_inputs ).outputs
10591066
10601067 # Some of the inlined constants were broadcasting the output shape
@@ -1095,7 +1102,7 @@ def constant_fold_branches_of_add_mul(fgraph, node):
10951102 if other_inps :
10961103 python_op = operator .mul if node .op == mul else operator .add
10971104 folded_inputs = [reference_inp , * other_inps ]
1098- new_inp = constant (
1105+ new_inp = tensor_constant (
10991106 reduce (python_op , (const .data for const in folded_inputs ))
11001107 )
11011108 new_constants = [
@@ -1119,7 +1126,7 @@ def constant_fold_branches_of_add_mul(fgraph, node):
11191126
11201127
11211128add_mul_fusion_seqopt = SequenceDB ()
1122- compile . optdb .register (
1129+ optdb .register (
11231130 "add_mul_fusion" ,
11241131 add_mul_fusion_seqopt ,
11251132 "fast_run" ,
@@ -1140,7 +1147,7 @@ def constant_fold_branches_of_add_mul(fgraph, node):
11401147
11411148# Register fusion database just before AddDestroyHandler(49.5) (inplace rewrites)
11421149fuse_seqopt = SequenceDB ()
1143- compile . optdb .register (
1150+ optdb .register (
11441151 "elemwise_fusion" ,
11451152 fuse_seqopt ,
11461153 "fast_run" ,
@@ -1271,7 +1278,7 @@ def split_2f1grad_loop(fgraph, node):
12711278 return replacements
12721279
12731280
1274- compile . optdb ["py_only" ].register (
1281+ optdb ["py_only" ].register (
12751282 "split_2f1grad_loop" ,
12761283 split_2f1grad_loop ,
12771284 "fast_compile" ,
0 commit comments