|
10 | 10 | node_rewriter,
|
11 | 11 | pre_greedy_node_rewriter,
|
12 | 12 | )
|
13 |
| -from aesara.ifelse import ifelse |
| 13 | +from aesara.ifelse import IfElse, ifelse |
14 | 14 | from aesara.scalar.basic import Switch
|
15 | 15 | from aesara.tensor.basic import Join, MakeVector
|
16 | 16 | from aesara.tensor.elemwise import Elemwise
|
@@ -305,14 +305,14 @@ def mixture_replace(fgraph, node):
|
305 | 305 | return [new_mixture_rv]
|
306 | 306 |
|
307 | 307 |
|
308 |
| -@node_rewriter((Elemwise,)) |
309 |
| -def switch_mixture_replace(fgraph, node): |
| 308 | +@node_rewriter((Elemwise, IfElse)) |
| 309 | +def switch_ifelse_mixture_replace(fgraph, node): |
310 | 310 | rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)
|
311 | 311 |
|
312 | 312 | if rv_map_feature is None:
|
313 | 313 | return None # pragma: no cover
|
314 | 314 |
|
315 |
| - if not isinstance(node.op.scalar_op, Switch): |
| 315 | + if not isinstance(node.op, IfElse) and not isinstance(node.op.scalar_op, Switch): |
316 | 316 | return None # pragma: no cover
|
317 | 317 |
|
318 | 318 | old_mixture_rv = node.default_output()
|
@@ -420,7 +420,7 @@ def logprob_MixtureRV(
|
420 | 420 | logprob_rewrites_db.register(
|
421 | 421 | "mixture_replace",
|
422 | 422 | EquilibriumGraphRewriter(
|
423 |
| - [mixture_replace, switch_mixture_replace], |
| 423 | + [mixture_replace, switch_ifelse_mixture_replace], |
424 | 424 | max_use_ratio=aesara.config.optdb__max_use_ratio,
|
425 | 425 | ),
|
426 | 426 | 0,
|
|
0 commit comments