 from aesara.compile.ops import ViewOp
 from aesara.configdefaults import config
 from aesara.graph.basic import (
+    Apply,
     Constant,
     Variable,
     ancestors,
 )
 from aesara.graph.features import AlreadyThere, Feature, ReplaceValidate
 from aesara.graph.fg import FunctionGraph
-from aesara.graph.op import get_test_value
+from aesara.graph.op import compute_test_value, get_test_value
 from aesara.graph.opt import (
     GlobalOptimizer,
     OpRemove,
@@ -3003,7 +3004,7 @@ def local_fuse(fgraph, node):
     fused = False
 
     for i in node.inputs:
-        do_fusion = False
+        scalar_node: Optional[Apply] = None
         # Will store inputs of the fused node that are not currently inputs
         # of the node we want to create (to avoid duplicating inputs).
         tmp_input = []
@@ -3034,36 +3035,45 @@ def local_fuse(fgraph, node):
                         tmp_s_input.append(tmp_scalar[tmp_input.index(ii)])
                     else:
                         tmp = aes.get_scalar_type(ii.type.dtype).make_variable()
+
                         try:
                             tv = get_test_value(ii)
-                            if tv.size > 0:
-                                tmp.tag.test_value = tv.flatten()[0]
-                            else:
-                                _logger.warning(
-                                    "Cannot construct a scalar test value"
-                                    " from a test value with no size: {}".format(ii)
-                                )
-                        except TestValueError:
+                            # Sometimes the original inputs have
+                            # zero-valued shapes in some dimensions, which
+                            # implies that this whole scalar thing doesn't
+                            # make sense (i.e. we're asking for the scalar
+                            # value of an entry in a zero-dimensional
+                            # array).
+                            # This will eventually lead to an error in the
+                            # `compute_test_value` call below when/if
+                            # `config.compute_test_value_opt` is enabled
+                            # (for debugging, more or less)
+                            tmp.tag.test_value = tv.item()
+                        except (TestValueError, ValueError):
                             pass
 
                         tmp_s_input.append(tmp)
                         tmp_input.append(ii)
                         tmp_scalar.append(tmp_s_input[-1])
 
-                s_op = i.owner.op.scalar_op(*tmp_s_input, return_list=True)
+                # Use the `Op.make_node` interface in case `Op.__call__`
+                # has been customized
+                scalar_node = i.owner.op.scalar_op.make_node(*tmp_s_input)
+
+                if config.compute_test_value_opt != "off":
+                    # This is required because `Op.make_node` won't do it
+                    compute_test_value(scalar_node)
 
                 # If the scalar_op doesn't have a C implementation, we skip
                 # its fusion to allow fusion of the other ops
                 i.owner.op.scalar_op.c_code(
-                    s_op[0].owner,
+                    scalar_node,
                     "test_presence_of_c_code",
                     ["x" for x in i.owner.inputs],
                     ["z" for z in i.owner.outputs],
                     {"fail": "%(fail)s"},
                 )
 
-                do_fusion = True
-
             except (NotImplementedError, MethodNotDefined):
                 _logger.warning(
                     (
@@ -3073,7 +3083,7 @@ def local_fuse(fgraph, node):
                         "loop fusion."
                     )
                 )
-                do_fusion = False
+                scalar_node = None
 
         # Compute the number of inputs in case we fuse this input.
         # We subtract 1 because we replace the existing input with the new
@@ -3089,26 +3099,27 @@ def local_fuse(fgraph, node):
             if x in node.inputs:
                 new_nb_input_ -= 1
 
-        if do_fusion and (new_nb_input_ <= max_nb_input):
+        if scalar_node and (new_nb_input_ <= max_nb_input):
             fused = True
             new_nb_input = new_nb_input_
             inputs.extend(tmp_input)
             s_inputs.extend(tmp_scalar)
-            s_g.extend(s_op)
+            s_g.extend(scalar_node.outputs)
         else:
             # We must support the case where the same variable appears many
             # times within the inputs
             if inputs.count(i) == node.inputs.count(i):
                 s = s_inputs[inputs.index(i)]
             else:
                 s = aes.get_scalar_type(i.type.dtype).make_variable()
-                try:
-                    if config.compute_test_value != "off":
+                if config.compute_test_value_opt != "off":
+                    try:
                         v = get_test_value(i)
-                        if v.size > 0:
-                            s.tag.test_value = v.flatten()[0]
-                except TestValueError:
-                    pass
+                        # See the zero-dimensional test value situation
+                        # described above.
+                        s.tag.test_value = v.item()
+                    except (TestValueError, ValueError):
+                        pass
 
             inputs.append(i)
             s_inputs.append(s)
@@ -3157,7 +3168,8 @@ def local_fuse(fgraph, node):
 
     if len(new_node.inputs) > max_nb_input:
         _logger.warning(
-            "loop fusion failed because Op would exceed" " kernel argument limit."
+            "Loop fusion failed because the resulting node "
+            "would exceed the kernel argument limit."
         )
         return False
 
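The move from `tv.flatten()[0]` to `tv.item()` in this diff leans on standard NumPy behavior: `ndarray.item()` returns a Python scalar only for size-1 arrays and raises `ValueError` otherwise (including for test values with a zero-length dimension), which is why `ValueError` is now caught alongside `TestValueError`. A minimal sketch of that behavior, using plain NumPy and independent of the Aesara code above:

import numpy as np

# Size-1 test values convert cleanly to a Python scalar.
assert np.array(5.0).item() == 5.0    # 0-d array
assert np.array([5.0]).item() == 5.0  # shape (1,)

# Anything else raises ValueError, e.g. a test value with a
# zero-length dimension (the situation the added comments describe)
# or one with more than one element.
for bad in (np.empty((0, 3)), np.ones(3)):
    try:
        bad.item()
    except ValueError:
        pass  # the optimization catches this alongside TestValueError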