Skip to content

Commit 53931d3

Browse files
committed
Allow etuplization of RandomVariables
1 parent e40c827 commit 53931d3

File tree

3 files changed

+66
-4
lines changed

3 files changed

+66
-4
lines changed

aesara/graph/rewriting/unify.py

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,19 +12,21 @@
1212

1313
from collections.abc import Mapping
1414
from numbers import Number
15-
from typing import Dict, Optional, Tuple, Union
15+
from typing import Callable, Dict, Optional, Tuple, Union
1616

1717
import numpy as np
1818
from cons.core import ConsError, _car, _cdr
1919
from etuples import apply, etuple, etuplize
20-
from etuples.core import ExpressionTuple
20+
from etuples.core import ExpressionTuple, etuple
21+
from etuples.dispatch import etuplize_fn
2122
from unification.core import _unify, assoc
2223
from unification.utils import transitive_get as walk
2324
from unification.variable import Var, isvar, var
2425

2526
from aesara.graph.basic import Constant, Variable
2627
from aesara.graph.op import Op
2728
from aesara.graph.type import Type
29+
from aesara.tensor.random.op import RandomVariable
2830

2931

3032
def eval_if_etuple(x):
@@ -72,6 +74,41 @@ def __repr__(self):
7274
return f"{type(self).__name__}({repr(self.constraint)}, {self.token})"
7375

7476

77+
class RVExpressionTuple(ExpressionTuple):
78+
r"""Etuple form for `RandomVariables`s.
79+
80+
Some `RandomVariable.__call__` signatures do not match their
81+
`RandomVariable.make_node` signatures, causing `ExpressionTuple.eval_if_etuple` to
82+
fail. To circumvent this limitation we subclass `ExpressionTuple`, and
83+
overload the `_eval_apply` method to use `Op.make_node` instead of
84+
`Op.__call__`.
85+
86+
"""
87+
88+
def _eval_apply_fn(self, op: RandomVariable) -> Callable:
89+
def eval_fn(*inputs, **kwargs):
90+
node = op.make_node(*inputs, **kwargs)
91+
return node.outputs[1]
92+
93+
return eval_fn
94+
95+
def __repr__(self):
96+
return "RV" + super().__repr__()
97+
98+
def __str__(self):
99+
return "rv" + super().__repr__()
100+
101+
102+
@etuple.register(RandomVariable, [object])
103+
def etuple_RandomVariable(*args, **kwargs) -> RVExpressionTuple:
104+
return RVExpressionTuple(args, **kwargs)
105+
106+
107+
@etuplize_fn.register(RandomVariable)
108+
def etuplize_fn_RandomVariable(_: RandomVariable):
109+
return etuple_RandomVariable
110+
111+
75112
def car_Variable(x):
76113
if x.owner:
77114
return x.owner.op

environment.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ dependencies:
1212
- numpy>=1.17.0
1313
- scipy>=0.14
1414
- filelock
15-
- etuples
15+
- etuples>=0.3.8
1616
- logical-unification
1717
- miniKanren
1818
- cons

tests/graph/rewriting/test_unify.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,11 @@
1111
import aesara.tensor as at
1212
from aesara.graph.basic import Apply, Constant, equal_computations
1313
from aesara.graph.op import Op
14-
from aesara.graph.rewriting.unify import ConstrainedVar, convert_strs_to_vars
14+
from aesara.graph.rewriting.unify import (
15+
ConstrainedVar,
16+
RVExpressionTuple,
17+
convert_strs_to_vars,
18+
)
1519
from aesara.tensor.type import TensorType
1620
from tests.graph.utils import MyType
1721

@@ -110,6 +114,7 @@ def test_etuples():
110114
assert res.owner.inputs == [x_at, y_at]
111115

112116
w_at = etuple(at.add, x_at, y_at)
117+
assert isinstance(w_at, ExpressionTuple)
113118

114119
res = w_at.evaled_obj
115120
assert res.owner.op == at.add
@@ -123,6 +128,8 @@ def test_etuples():
123128

124129
q_at = op1_np(x_at, y_at)
125130
res = etuplize(q_at)
131+
assert isinstance(res, ExpressionTuple)
132+
assert isinstance(res[0], CustomOpNoProps)
126133
assert res[0] == op1_np
127134

128135
with pytest.raises(TypeError):
@@ -144,6 +151,24 @@ def perform(self, node, inputs, outputs):
144151
assert res[0].owner.op == op1_np
145152
assert res[1].owner.op == op1_np
146153

154+
mu_at = at.scalar("mu")
155+
sigma_at = at.scalar("sigma")
156+
157+
w_rv = at.random.normal(mu_at, sigma_at)
158+
w_at = etuplize(w_rv)
159+
assert isinstance(w_at, RVExpressionTuple)
160+
assert isinstance(w_at[0], ExpressionTuple)
161+
162+
z_at = etuple(at.random.normal, mu_at, sigma_at)
163+
assert isinstance(z_at, RVExpressionTuple)
164+
165+
z_at = etuple(at.random.normal, *w_at[1:])
166+
assert isinstance(z_at, RVExpressionTuple)
167+
168+
res = z_at.evaled_obj
169+
assert res.owner.op == at.random.normal
170+
assert res.owner.inputs[-2:] == [mu_at, sigma_at]
171+
147172

148173
def test_unify_Variable():
149174
x_at = at.vector("x")

0 commit comments

Comments
 (0)