Skip to content
This repository was archived by the owner on Nov 17, 2025. It is now read-only.

Commit 8d4c5f4

Browse files
Remove superfluous test value warning in Elemwise fusion rewrite
1 parent 584496d commit 8d4c5f4

File tree

2 files changed

+86
-24
lines changed

2 files changed

+86
-24
lines changed

aesara/tensor/basic_opt.py

Lines changed: 36 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from aesara.compile.ops import ViewOp
1717
from aesara.configdefaults import config
1818
from aesara.graph.basic import (
19+
Apply,
1920
Constant,
2021
Variable,
2122
ancestors,
@@ -24,7 +25,7 @@
2425
)
2526
from aesara.graph.features import AlreadyThere, Feature, ReplaceValidate
2627
from aesara.graph.fg import FunctionGraph
27-
from aesara.graph.op import get_test_value
28+
from aesara.graph.op import compute_test_value, get_test_value
2829
from aesara.graph.opt import (
2930
GlobalOptimizer,
3031
OpRemove,
@@ -3003,7 +3004,7 @@ def local_fuse(fgraph, node):
30033004
fused = False
30043005

30053006
for i in node.inputs:
3006-
do_fusion = False
3007+
scalar_node: Optional[Apply] = None
30073008
# Will store inputs of the fused node that are not currently inputs
30083009
# of the node we want to create (to avoid duplicating inputs).
30093010
tmp_input = []
@@ -3034,36 +3035,45 @@ def local_fuse(fgraph, node):
30343035
tmp_s_input.append(tmp_scalar[tmp_input.index(ii)])
30353036
else:
30363037
tmp = aes.get_scalar_type(ii.type.dtype).make_variable()
3038+
30373039
try:
30383040
tv = get_test_value(ii)
3039-
if tv.size > 0:
3040-
tmp.tag.test_value = tv.flatten()[0]
3041-
else:
3042-
_logger.warning(
3043-
"Cannot construct a scalar test value"
3044-
" from a test value with no size: {}".format(ii)
3045-
)
3046-
except TestValueError:
3041+
# Sometimes the original inputs have
3042+
# zero-valued shapes in some dimensions, which
3043+
# implies that this whole scalar thing doesn't
3044+
# make sense (i.e. we're asking for the scalar
3045+
# value of an entry in a zero-dimensional
3046+
# array).
3047+
# This will eventually lead to an error in the
3048+
# `compute_test_value` call below when/if
3049+
# `config.compute_test_value_opt` is enabled
3050+
# (for debugging, more or less)
3051+
tmp.tag.test_value = tv.item()
3052+
except (TestValueError, ValueError):
30473053
pass
30483054

30493055
tmp_s_input.append(tmp)
30503056
tmp_input.append(ii)
30513057
tmp_scalar.append(tmp_s_input[-1])
30523058

3053-
s_op = i.owner.op.scalar_op(*tmp_s_input, return_list=True)
3059+
# Use the `Op.make_node` interface in case `Op.__call__`
3060+
# has been customized
3061+
scalar_node = i.owner.op.scalar_op.make_node(*tmp_s_input)
3062+
3063+
if config.compute_test_value_opt != "off":
3064+
# This is required because `Op.make_node` won't do it
3065+
compute_test_value(scalar_node)
30543066

30553067
# If the scalar_op doesn't have a C implementation, we skip
30563068
# its fusion to allow fusion of the other ops
30573069
i.owner.op.scalar_op.c_code(
3058-
s_op[0].owner,
3070+
scalar_node,
30593071
"test_presence_of_c_code",
30603072
["x" for x in i.owner.inputs],
30613073
["z" for z in i.owner.outputs],
30623074
{"fail": "%(fail)s"},
30633075
)
30643076

3065-
do_fusion = True
3066-
30673077
except (NotImplementedError, MethodNotDefined):
30683078
_logger.warning(
30693079
(
@@ -3073,7 +3083,7 @@ def local_fuse(fgraph, node):
30733083
"loop fusion."
30743084
)
30753085
)
3076-
do_fusion = False
3086+
scalar_node = None
30773087

30783088
# Compute the number of inputs in case we fuse this input.
30793089
# We subtract 1 because we replace the existing input with the new
@@ -3089,26 +3099,27 @@ def local_fuse(fgraph, node):
30893099
if x in node.inputs:
30903100
new_nb_input_ -= 1
30913101

3092-
if do_fusion and (new_nb_input_ <= max_nb_input):
3102+
if scalar_node and (new_nb_input_ <= max_nb_input):
30933103
fused = True
30943104
new_nb_input = new_nb_input_
30953105
inputs.extend(tmp_input)
30963106
s_inputs.extend(tmp_scalar)
3097-
s_g.extend(s_op)
3107+
s_g.extend(scalar_node.outputs)
30983108
else:
30993109
# We must support the case where the same variable appears many
31003110
# times within the inputs
31013111
if inputs.count(i) == node.inputs.count(i):
31023112
s = s_inputs[inputs.index(i)]
31033113
else:
31043114
s = aes.get_scalar_type(i.type.dtype).make_variable()
3105-
try:
3106-
if config.compute_test_value != "off":
3115+
if config.compute_test_value_opt != "off":
3116+
try:
31073117
v = get_test_value(i)
3108-
if v.size > 0:
3109-
s.tag.test_value = v.flatten()[0]
3110-
except TestValueError:
3111-
pass
3118+
# See the zero-dimensional test value situation
3119+
# described above.
3120+
s.tag.test_value = v.item()
3121+
except (TestValueError, ValueError):
3122+
pass
31123123

31133124
inputs.append(i)
31143125
s_inputs.append(s)
@@ -3157,7 +3168,8 @@ def local_fuse(fgraph, node):
31573168

31583169
if len(new_node.inputs) > max_nb_input:
31593170
_logger.warning(
3160-
"loop fusion failed because Op would exceed" " kernel argument limit."
3171+
"Loop fusion failed because the resulting node "
3172+
"would exceed the kernel argument limit."
31613173
)
31623174
return False
31633175

tests/tensor/test_basic_opt.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import contextlib
12
import copy
23

34
import numpy as np
@@ -22,6 +23,7 @@
2223
from aesara.misc.safe_asarray import _asarray
2324
from aesara.printing import pprint
2425
from aesara.raise_op import Assert, CheckAndRaise
26+
from aesara.scalar.basic import Composite
2527
from aesara.tensor.basic import (
2628
Alloc,
2729
Join,
@@ -1152,6 +1154,54 @@ def impl(self, x):
11521154
for n in f.maker.fgraph.toposort()
11531155
)
11541156

1157+
@pytest.mark.parametrize("test_value", [np.c_[[1.0]], np.c_[[]]])
1158+
def test_test_values(self, test_value):
1159+
"""Make sure that `local_elemwise_fusion_op` uses test values correctly when they have zero dimensions.
1160+
1161+
The test values we're talking about are the ones used when C implementations
1162+
are checked.
1163+
1164+
"""
1165+
1166+
opts = OptimizationQuery(
1167+
include=[
1168+
"local_elemwise_fusion",
1169+
"composite_elemwise_fusion",
1170+
"canonicalize",
1171+
],
1172+
exclude=["cxx_only", "BlasOpt"],
1173+
)
1174+
1175+
mode = Mode(self.mode.linker, opts)
1176+
1177+
x, y, z = dmatrices("xyz")
1178+
1179+
x.tag.test_value = test_value
1180+
y.tag.test_value = test_value
1181+
z.tag.test_value = test_value
1182+
1183+
if test_value.size == 0:
1184+
cm = pytest.raises(ValueError)
1185+
else:
1186+
cm = contextlib.suppress()
1187+
1188+
with config.change_flags(
1189+
compute_test_value="raise", compute_test_value_opt="raise"
1190+
):
1191+
out = x * y + z
1192+
with cm:
1193+
f = function([x, y, z], out, mode=mode)
1194+
1195+
if test_value.size != 0:
1196+
# Confirm that the fusion happened
1197+
assert isinstance(f.maker.fgraph.outputs[0].owner.op.scalar_op, Composite)
1198+
assert len(f.maker.fgraph.toposort()) == 1
1199+
1200+
x_c, y_c, z_c = f.maker.fgraph.outputs[0].owner.inputs
1201+
assert np.array_equal(
1202+
f.maker.fgraph.outputs[0].tag.test_value, np.c_[[2.0]]
1203+
)
1204+
11551205

11561206
class TimesN(aes.basic.UnaryScalarOp):
11571207
"""

0 commit comments

Comments
 (0)