Skip to content

Commit 8c9c0f3

Browse files
larryshamalama authored and brandonwillard committed
Allow vector-valued indices for switch/ifelse mixture sub-graphs
1 parent 21787e9 commit 8c9c0f3

File tree

2 files changed

+129
-65
lines changed

2 files changed

+129
-65
lines changed

aeppl/mixture.py

Lines changed: 20 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -313,48 +313,21 @@ def switch_mixture_replace(fgraph, node):
313313
return None # pragma: no cover
314314

315315
old_mixture_rv = node.default_output()
316-
# idx, component_1, component_2 = node.inputs
317-
318-
mixture_rvs = []
319-
320-
for component_rv in node.inputs[1:]:
321-
if not (
322-
component_rv.owner
323-
and isinstance(component_rv.owner.op, MeasurableVariable)
324-
and component_rv not in rv_map_feature.rv_values
325-
):
326-
return None
327-
new_node = assign_custom_measurable_outputs(component_rv.owner)
328-
out_idx = component_rv.owner.outputs.index(component_rv)
329-
new_comp_rv = new_node.outputs[out_idx]
330-
mixture_rvs.append(new_comp_rv)
331316

332-
"""
333-
Unlike mixtures generated via at.stack, there is only one condition, i.e. index
334-
for switch/ifelse-defined mixture sub-graphs. However, this condition can be
335-
non-scalar for Switch Ops.
336-
"""
337-
mix_op = MixtureRV(
338-
2,
339-
old_mixture_rv.dtype,
340-
old_mixture_rv.broadcastable,
341-
)
342-
new_node = mix_op.make_node(
343-
*([NoneConst, as_nontensor_scalar(node.inputs[0])] + mixture_rvs)
317+
# Add an extra dimension to the indices so that the `MixtureRV` we
318+
# construct represents a valid
319+
# `at.stack(node.inputs[1:])[f(node.inputs[0])]`, for some function `f`,
320+
# that's equivalent to `at.switch(*node.inputs)`.
321+
out_shape = at.broadcast_shape(
322+
*(tuple(v.shape) for v in node.inputs[1:]), arrays_are_shapes=True
344323
)
324+
switch_indices = (node.inputs[0],) + tuple(at.arange(s) for s in out_shape)
345325

346-
new_mixture_rv = new_node.default_output()
347-
348-
if aesara.config.compute_test_value != "off":
349-
if not hasattr(old_mixture_rv.tag, "test_value"):
350-
compute_test_value(node)
351-
352-
new_mixture_rv.tag.test_value = old_mixture_rv.tag.test_value
353-
354-
if old_mixture_rv.name:
355-
new_mixture_rv.name = f"{old_mixture_rv.name}-mixture"
326+
# Construct the proxy/intermediate mixture representation
327+
switch_stack = at.stack(node.inputs[::-1])[switch_indices]
328+
switch_stack.name = old_mixture_rv.name
356329

357-
return [new_mixture_rv]
330+
return mixture_replace.transform(fgraph, switch_stack.owner)
358331

359332

360333
@node_rewriter((IfElse,))
@@ -394,9 +367,15 @@ def ifelse_mixture_replace(fgraph, node):
394367
old_mixture_rv.dtype,
395368
old_mixture_rv.broadcastable,
396369
)
397-
new_node = mix_op.make_node(
398-
*([NoneConst, as_nontensor_scalar(node.inputs[0])] + mixture_rvs)
399-
)
370+
371+
if node.inputs[0].ndim == 0:
372+
# as_nontensor_scalar to allow graphs to be identical to mixture sub-graphs
373+
# created using at.stack and Subtensor indexing
374+
new_node = mix_op.make_node(
375+
*([NoneConst, as_nontensor_scalar(node.inputs[0])] + mixture_rvs)
376+
)
377+
else:
378+
new_node = mix_op.make_node(*([at.constant(0), node.inputs[0]] + mixture_rvs))
400379

401380
new_mixture_rv = new_node.default_output()
402381

tests/test_mixture.py

Lines changed: 109 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -233,25 +233,6 @@ def test_hetero_mixture_binomial(p_val, size):
233233
(),
234234
0,
235235
),
236-
(
237-
(
238-
np.array(0, dtype=aesara.config.floatX),
239-
np.array(1, dtype=aesara.config.floatX),
240-
),
241-
(
242-
np.array(0.5, dtype=aesara.config.floatX),
243-
np.array(0.5, dtype=aesara.config.floatX),
244-
),
245-
(
246-
np.array(100, dtype=aesara.config.floatX),
247-
np.array(1, dtype=aesara.config.floatX),
248-
),
249-
np.array([0.1, 0.5, 0.4], dtype=aesara.config.floatX),
250-
(),
251-
(),
252-
(),
253-
0,
254-
),
255236
(
256237
(
257238
np.array(0, dtype=aesara.config.floatX),
@@ -683,14 +664,118 @@ def test_mixture_with_DiracDelta():
683664
assert M_rv in logp_res
684665

685666

686-
@pytest.mark.parametrize("op", [at.switch, ifelse])
687-
def test_switch_ifelse_mixture(op):
667+
@pytest.mark.parametrize(
668+
"op, X_args, Y_args, p_val, comp_size, idx_size",
669+
[
670+
[op] + list(test_args)
671+
for op in [at.switch, ifelse]
672+
for test_args in [
673+
(
674+
(
675+
np.array(-10, dtype=aesara.config.floatX),
676+
np.array(0.1, dtype=aesara.config.floatX),
677+
),
678+
(
679+
np.array(10, dtype=aesara.config.floatX),
680+
np.array(0.1, dtype=aesara.config.floatX),
681+
),
682+
np.array(0.5, dtype=aesara.config.floatX),
683+
(),
684+
(),
685+
),
686+
(
687+
(
688+
np.array(-10, dtype=aesara.config.floatX),
689+
np.array(0.1, dtype=aesara.config.floatX),
690+
),
691+
(
692+
np.array(10, dtype=aesara.config.floatX),
693+
np.array(0.1, dtype=aesara.config.floatX),
694+
),
695+
np.array(0.5, dtype=aesara.config.floatX),
696+
(),
697+
(6,),
698+
),
699+
(
700+
(
701+
np.array([10, 20], dtype=aesara.config.floatX),
702+
np.array(0.1, dtype=aesara.config.floatX),
703+
),
704+
(
705+
np.array([-10, -20], dtype=aesara.config.floatX),
706+
np.array(0.1, dtype=aesara.config.floatX),
707+
),
708+
np.array([0.9, 0.1], dtype=aesara.config.floatX),
709+
(2,),
710+
(2,),
711+
),
712+
(
713+
(
714+
np.array([10, 20], dtype=aesara.config.floatX),
715+
np.array(0.1, dtype=aesara.config.floatX),
716+
),
717+
(
718+
np.array([-10, -20], dtype=aesara.config.floatX),
719+
np.array(0.1, dtype=aesara.config.floatX),
720+
),
721+
np.array([0.9, 0.1], dtype=aesara.config.floatX),
722+
None,
723+
None,
724+
),
725+
(
726+
(
727+
np.array(-10, dtype=aesara.config.floatX),
728+
np.array(0.1, dtype=aesara.config.floatX),
729+
),
730+
(
731+
np.array(10, dtype=aesara.config.floatX),
732+
np.array(0.1, dtype=aesara.config.floatX),
733+
),
734+
np.array(0.5, dtype=aesara.config.floatX),
735+
(2, 3),
736+
(2, 3),
737+
),
738+
(
739+
(
740+
np.array(10, dtype=aesara.config.floatX),
741+
np.array(0.1, dtype=aesara.config.floatX),
742+
),
743+
(
744+
np.array(-10, dtype=aesara.config.floatX),
745+
np.array(0.1, dtype=aesara.config.floatX),
746+
),
747+
np.array(0.5, dtype=aesara.config.floatX),
748+
(2, 3),
749+
(),
750+
),
751+
(
752+
(
753+
np.array(10, dtype=aesara.config.floatX),
754+
np.array(0.1, dtype=aesara.config.floatX),
755+
),
756+
(
757+
np.array(-10, dtype=aesara.config.floatX),
758+
np.array(0.1, dtype=aesara.config.floatX),
759+
),
760+
np.array(0.5, dtype=aesara.config.floatX),
761+
(3,),
762+
(3,),
763+
),
764+
]
765+
if not ((test_args[-1] is None or len(test_args[-1]) > 0) and op == ifelse)
766+
],
767+
)
768+
def test_switch_ifelse_mixture(op, X_args, Y_args, p_val, comp_size, idx_size):
769+
"""
770+
The argument size is both the input to srng.normal and the expected
771+
size of the mixture RV Z1_rv
772+
"""
688773
srng = at.random.RandomStream(29833)
689774

690-
X_rv = srng.normal(-10.0, 0.1, name="X")
691-
Y_rv = srng.normal(10.0, 0.1, name="Y")
775+
X_rv = srng.normal(*X_args, size=comp_size, name="X")
776+
Y_rv = srng.normal(*Y_args, size=comp_size, name="Y")
692777

693-
I_rv = srng.bernoulli(0.5, name="I")
778+
I_rv = srng.bernoulli(p_val, size=idx_size, name="I")
694779
i_vv = I_rv.clone()
695780
i_vv.name = "i"
696781

0 commit comments

Comments
 (0)