Skip to content

Commit 74a3649

Browse files
larryshamalamaLarry Dong
authored andcommitted
Allow vector-valued indices for switch/ifelse mixture sub-graphs
1 parent 817d2fd commit 74a3649

File tree

2 files changed

+126
-30
lines changed

2 files changed

+126
-30
lines changed

aeppl/mixture.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -339,9 +339,14 @@ def switch_mixture_replace(fgraph, node):
339339
old_mixture_rv.dtype,
340340
old_mixture_rv.broadcastable,
341341
)
342-
new_node = mix_op.make_node(
343-
*([NoneConst, as_nontensor_scalar(node.inputs[0])] + mixture_rvs)
344-
)
342+
if node.inputs[0].ndim == 0:
343+
# as_nontensor_scalar to allow graphs to be identical to mixture sub-graphs
344+
# created using at.stack and Subtensor indexing
345+
new_node = mix_op.make_node(
346+
*([NoneConst, as_nontensor_scalar(node.inputs[0])] + mixture_rvs)
347+
)
348+
else:
349+
new_node = mix_op.make_node(*([NoneConst, node.inputs[0]] + mixture_rvs))
345350

346351
new_mixture_rv = new_node.default_output()
347352

@@ -394,9 +399,15 @@ def ifelse_mixture_replace(fgraph, node):
394399
old_mixture_rv.dtype,
395400
old_mixture_rv.broadcastable,
396401
)
397-
new_node = mix_op.make_node(
398-
*([NoneConst, as_nontensor_scalar(node.inputs[0])] + mixture_rvs)
399-
)
402+
403+
if node.inputs[0].ndim == 0:
404+
# as_nontensor_scalar to allow graphs to be identical to mixture sub-graphs
405+
# created using at.stack and Subtensor indexing
406+
new_node = mix_op.make_node(
407+
*([NoneConst, as_nontensor_scalar(node.inputs[0])] + mixture_rvs)
408+
)
409+
else:
410+
new_node = mix_op.make_node(*([at.constant(0), node.inputs[0]] + mixture_rvs))
400411

401412
new_mixture_rv = new_node.default_output()
402413

tests/test_mixture.py

Lines changed: 109 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -232,25 +232,6 @@ def test_hetero_mixture_binomial(p_val, size):
232232
(),
233233
0,
234234
),
235-
(
236-
(
237-
np.array(0, dtype=aesara.config.floatX),
238-
np.array(1, dtype=aesara.config.floatX),
239-
),
240-
(
241-
np.array(0.5, dtype=aesara.config.floatX),
242-
np.array(0.5, dtype=aesara.config.floatX),
243-
),
244-
(
245-
np.array(100, dtype=aesara.config.floatX),
246-
np.array(1, dtype=aesara.config.floatX),
247-
),
248-
np.array([0.1, 0.5, 0.4], dtype=aesara.config.floatX),
249-
(),
250-
(),
251-
(),
252-
0,
253-
),
254235
(
255236
(
256237
np.array(0, dtype=aesara.config.floatX),
@@ -684,14 +665,118 @@ def test_mixture_with_DiracDelta():
684665
assert M_rv in logp_res
685666

686667

687-
@pytest.mark.parametrize("op", [at.switch, ifelse])
688-
def test_switch_ifelse_mixture(op):
668+
@pytest.mark.parametrize(
669+
"op, X_args, Y_args, p_val, comp_size, idx_size",
670+
[
671+
[op] + list(test_args)
672+
for op in [at.switch, ifelse]
673+
for test_args in [
674+
(
675+
(
676+
np.array(-10, dtype=aesara.config.floatX),
677+
np.array(0.1, dtype=aesara.config.floatX),
678+
),
679+
(
680+
np.array(10, dtype=aesara.config.floatX),
681+
np.array(0.1, dtype=aesara.config.floatX),
682+
),
683+
np.array(0.5, dtype=aesara.config.floatX),
684+
(),
685+
(),
686+
),
687+
(
688+
(
689+
np.array(-10, dtype=aesara.config.floatX),
690+
np.array(0.1, dtype=aesara.config.floatX),
691+
),
692+
(
693+
np.array(10, dtype=aesara.config.floatX),
694+
np.array(0.1, dtype=aesara.config.floatX),
695+
),
696+
np.array(0.5, dtype=aesara.config.floatX),
697+
(),
698+
(6,),
699+
),
700+
(
701+
(
702+
np.array([10, 20], dtype=aesara.config.floatX),
703+
np.array(0.1, dtype=aesara.config.floatX),
704+
),
705+
(
706+
np.array([-10, -20], dtype=aesara.config.floatX),
707+
np.array(0.1, dtype=aesara.config.floatX),
708+
),
709+
np.array([0.9, 0.1], dtype=aesara.config.floatX),
710+
(2,),
711+
(2,),
712+
),
713+
(
714+
(
715+
np.array([10, 20], dtype=aesara.config.floatX),
716+
np.array(0.1, dtype=aesara.config.floatX),
717+
),
718+
(
719+
np.array([-10, -20], dtype=aesara.config.floatX),
720+
np.array(0.1, dtype=aesara.config.floatX),
721+
),
722+
np.array([0.9, 0.1], dtype=aesara.config.floatX),
723+
None,
724+
None,
725+
),
726+
(
727+
(
728+
np.array(-10, dtype=aesara.config.floatX),
729+
np.array(0.1, dtype=aesara.config.floatX),
730+
),
731+
(
732+
np.array(10, dtype=aesara.config.floatX),
733+
np.array(0.1, dtype=aesara.config.floatX),
734+
),
735+
np.array(0.5, dtype=aesara.config.floatX),
736+
(2, 3),
737+
(2, 3),
738+
),
739+
(
740+
(
741+
np.array(10, dtype=aesara.config.floatX),
742+
np.array(0.1, dtype=aesara.config.floatX),
743+
),
744+
(
745+
np.array(-10, dtype=aesara.config.floatX),
746+
np.array(0.1, dtype=aesara.config.floatX),
747+
),
748+
np.array(0.5, dtype=aesara.config.floatX),
749+
(2, 3),
750+
(),
751+
),
752+
(
753+
(
754+
np.array(10, dtype=aesara.config.floatX),
755+
np.array(0.1, dtype=aesara.config.floatX),
756+
),
757+
(
758+
np.array(-10, dtype=aesara.config.floatX),
759+
np.array(0.1, dtype=aesara.config.floatX),
760+
),
761+
np.array(0.5, dtype=aesara.config.floatX),
762+
(3,),
763+
(3,),
764+
),
765+
]
766+
if not ((test_args[-1] is None or len(test_args[-1]) > 0) and op == ifelse)
767+
],
768+
)
769+
def test_switch_ifelse_mixture(op, X_args, Y_args, p_val, comp_size, idx_size):
770+
"""
771+
The argument size is both the input to srng.normal and the expected
772+
size of the mixture RV Z1_rv
773+
"""
689774
srng = at.random.RandomStream(29833)
690775

691-
X_rv = srng.normal(-10.0, 0.1, name="X")
692-
Y_rv = srng.normal(10.0, 0.1, name="Y")
776+
X_rv = srng.normal(*X_args, size=comp_size, name="X")
777+
Y_rv = srng.normal(*Y_args, size=comp_size, name="Y")
693778

694-
I_rv = srng.bernoulli(0.5, name="I")
779+
I_rv = srng.bernoulli(p_val, size=idx_size, name="I")
695780
i_vv = I_rv.clone()
696781
i_vv.name = "i"
697782

0 commit comments

Comments
 (0)