Skip to content

Commit d0d4c7b

Browse files
Identifying mixture sub-graphs defined with an IfElse Op
1 parent 8ce3c63 commit d0d4c7b

File tree

2 files changed

+9
-7
lines changed

2 files changed

+9
-7
lines changed

aeppl/mixture.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
node_rewriter,
1111
pre_greedy_node_rewriter,
1212
)
13-
from aesara.ifelse import ifelse
13+
from aesara.ifelse import IfElse, ifelse
1414
from aesara.scalar.basic import Switch
1515
from aesara.tensor.basic import Join, MakeVector
1616
from aesara.tensor.elemwise import Elemwise
@@ -305,14 +305,14 @@ def mixture_replace(fgraph, node):
305305
return [new_mixture_rv]
306306

307307

308-
@node_rewriter((Elemwise,))
309-
def switch_mixture_replace(fgraph, node):
308+
@node_rewriter((Elemwise, IfElse))
309+
def switch_ifelse_mixture_replace(fgraph, node):
310310
rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)
311311

312312
if rv_map_feature is None:
313313
return None # pragma: no cover
314314

315-
if not isinstance(node.op.scalar_op, Switch):
315+
if not isinstance(node.op, IfElse) and not isinstance(node.op.scalar_op, Switch):
316316
return None # pragma: no cover
317317

318318
old_mixture_rv = node.default_output()
@@ -420,7 +420,7 @@ def logprob_MixtureRV(
420420
logprob_rewrites_db.register(
421421
"mixture_replace",
422422
EquilibriumGraphRewriter(
423-
[mixture_replace, switch_mixture_replace],
423+
[mixture_replace, switch_ifelse_mixture_replace],
424424
max_use_ratio=aesara.config.optdb__max_use_ratio,
425425
),
426426
0,

tests/test_mixture.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import numpy as np
44
import pytest
55
import scipy.stats.distributions as sp
6+
from aesara.ifelse import ifelse
67
from aesara.graph.basic import Variable, equal_computations
78
from aesara.tensor.random.basic import CategoricalRV
89
from aesara.tensor.shape import shape_tuple
@@ -715,7 +716,8 @@ def test_mixture_with_DiracDelta():
715716
assert m_vv in logp_res
716717

717718

718-
def test_switch_mixture():
719+
@pytest.mark.parametrize("op", [at.switch, ifelse])
720+
def test_switch_ifelse_mixture(op):
719721
srng = at.random.RandomStream(29833)
720722

721723
X_rv = srng.normal(-10.0, 0.1, name="X")
@@ -725,7 +727,7 @@ def test_switch_mixture():
725727
i_vv = I_rv.clone()
726728
i_vv.name = "i"
727729

728-
Z1_rv = at.switch(I_rv, X_rv, Y_rv)
730+
Z1_rv = op(I_rv, X_rv, Y_rv)
729731
z_vv = Z1_rv.clone()
730732
z_vv.name = "z1"
731733

0 commit comments

Comments
 (0)