Skip to content
This repository was archived by the owner on Nov 17, 2025. It is now read-only.

Commit d4d10d3

Browse files
committed
Wrap Ops during etuplization
Evaluating etuplized objects fails for some `RandomVariable` ops whose `__call__` function does not defer to `make_node`. In this commit we wrap `RandomVariable` ops during etuplization with a class that always defers `__call__` to `make_node`. We also add a dispatch rule for `etuplize` so it also wraps `RandomVariable` with the same class.
1 parent 77bb152 commit d4d10d3

File tree

1 file changed

+44
-1
lines changed

1 file changed

+44
-1
lines changed

aesara/graph/unify.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
"""
1212

1313
from collections.abc import Mapping
14+
from dataclasses import dataclass
1415
from numbers import Number
1516
from typing import Dict, Optional, Tuple, Union
1617

@@ -25,6 +26,7 @@
2526
from aesara.graph.basic import Constant, Variable
2627
from aesara.graph.op import Op
2728
from aesara.graph.type import Type
29+
from aesara.tensor.random.op import RandomVariable
2830

2931

3032
def eval_if_etuple(x):
@@ -72,9 +74,50 @@ def __repr__(self):
7274
return f"ConstrainedVar({repr(self.constraint)}, {self.token})"
7375

7476

77+
@dataclass
78+
class MakeNodeOp:
79+
"""Wrapper around Ops.
80+
81+
Some RandomVariable Ops's `__call__` method does not defer to
82+
`make_node`, and `eval_if_etuple` fails on their etuplized version. To
83+
circumvent this limitation we wrap Ops with this object that always
84+
defers `__call__` to `make_node`.
85+
86+
"""
87+
88+
op: Op
89+
90+
def __call__(self, *args):
91+
return self.op.make_node(*args)
92+
93+
94+
def car_MakeNodeOp(x):
95+
return type(x)
96+
97+
98+
_car.add((MakeNodeOp,), car_MakeNodeOp)
99+
100+
101+
def cdr_MakeNodeOp(x):
102+
x_e = etuple(_car(x), x.op, evaled_obj=x)
103+
return x_e[1:]
104+
105+
106+
_cdr.add((MakeNodeOp,), cdr_MakeNodeOp)
107+
108+
109+
@etuplize.register(RandomVariable)
110+
def etuplize_random(*args, **kwargs):
111+
"""Wrap RandomVariable Ops with a MakeNodeOp object."""
112+
return etuple(MakeNodeOp, etuplize.funcs[(object,)](*args, **kwargs))
113+
114+
75115
def car_Variable(x):
76116
if x.owner:
77-
return x.owner.op
117+
if issubclass(type(x.owner.op), RandomVariable):
118+
return MakeNodeOp(x.owner.op)
119+
else:
120+
return x.owner.op
78121
else:
79122
raise ConsError("Not a cons pair.")
80123

0 commit comments

Comments
 (0)