Skip to content

Commit de5fca0

Browse files
committed
Wrap Ops during etuplization
Evaluating etuplized objects fails for some `RandomVariable` ops whose `__call__` function does not defer to `make_node`. In this commit we wrap `RandomVariable` ops during etuplization with a class that always defers `__call__` to `make_node`. We also add a dispatch rule for `etuplize` so it also wraps `RandomVariable` with the same class.
1 parent 9f176da commit de5fca0

File tree

2 files changed

+55
-1
lines changed

2 files changed

+55
-1
lines changed

aesara/graph/rewriting/unify.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
"""
1212

1313
from collections.abc import Mapping
14+
from dataclasses import dataclass
1415
from numbers import Number
1516
from typing import Dict, Optional, Tuple, Union
1617

@@ -25,6 +26,7 @@
2526
from aesara.graph.basic import Constant, Variable
2627
from aesara.graph.op import Op
2728
from aesara.graph.type import Type
29+
from aesara.tensor.random.op import RandomVariable
2830

2931

3032
def eval_if_etuple(x):
@@ -72,9 +74,50 @@ def __repr__(self):
7274
return f"{type(self).__name__}({repr(self.constraint)}, {self.token})"
7375

7476

77+
@dataclass
78+
class MakeRandomVariableNodeOp:
79+
r"""Wrapper around `RandomVariable` `Op`\s.
80+
81+
Some `Op.__call__` signatures (e.g. `RandomVariable`) do not match their
82+
`Op.make_node` signatures, causing `ExpressionTuple.eval_if_etuple` to
83+
fail. To circumvent this limitation, we wrap `Op`\s with this class, which
84+
maps `Op.__call__` directly to `Op.make_node`.
85+
86+
"""
87+
88+
op: Op
89+
90+
def __call__(self, *args):
91+
return self.op.make_node(*args)
92+
93+
94+
def car_MakeRandomVariableNodeOp(x):
95+
return type(x)
96+
97+
98+
_car.add((MakeRandomVariableNodeOp,), car_MakeRandomVariableNodeOp)
99+
100+
101+
def cdr_MakeRandomVariableNodeOp(x):
102+
x_e = etuple(_car(x), x.op, evaled_obj=x)
103+
return x_e[1:]
104+
105+
106+
_cdr.add((MakeRandomVariableNodeOp,), cdr_MakeRandomVariableNodeOp)
107+
108+
109+
@etuplize.register(RandomVariable)
110+
def etuplize_random(*args, **kwargs):
111+
"""Wrap RandomVariable Ops with a MakeNodeOp object."""
112+
return etuple(MakeRandomVariableNodeOp, etuplize.funcs[(object,)](*args, **kwargs))
113+
114+
75115
def car_Variable(x):
76116
if x.owner:
77-
return x.owner.op
117+
if issubclass(type(x.owner.op), RandomVariable):
118+
return MakeRandomVariableNodeOp(x.owner.op)
119+
else:
120+
return x.owner.op
78121
else:
79122
raise ConsError("Not a cons pair.")
80123

tests/graph/rewriting/test_unify.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,17 @@ def perform(self, node, inputs, outputs):
144144
assert res[0].owner.op == op1_np
145145
assert res[1].owner.op == op1_np
146146

147+
# Etuplize RandomVariable ops and the Variables they create
148+
x_rv = at.random.normal(0, 1, name="x")
149+
x_et = etuplize(x_rv)
150+
res = x_et.evaled_obj
151+
assert res.owner.op == at.random.normal
152+
153+
y_et = etuple(etuplize(at.random.normal), *x_et[1:])
154+
res = y_et.evaled_obj
155+
assert x_et == y_et
156+
assert res.op == at.random.normal
157+
147158

148159
def test_unify_Variable():
149160
x_at = at.vector("x")

0 commit comments

Comments
 (0)