|  | 
| 11 | 11 | """ | 
| 12 | 12 | 
 | 
| 13 | 13 | from collections.abc import Mapping | 
|  | 14 | +from dataclasses import dataclass | 
| 14 | 15 | from numbers import Number | 
| 15 | 16 | from typing import Dict, Optional, Tuple, Union | 
| 16 | 17 | 
 | 
|  | 
| 25 | 26 | from aesara.graph.basic import Constant, Variable | 
| 26 | 27 | from aesara.graph.op import Op | 
| 27 | 28 | from aesara.graph.type import Type | 
|  | 29 | +from aesara.tensor.random.op import RandomVariable | 
| 28 | 30 | 
 | 
| 29 | 31 | 
 | 
| 30 | 32 | def eval_if_etuple(x): | 
| @@ -72,9 +74,50 @@ def __repr__(self): | 
| 72 | 74 |         return f"{type(self).__name__}({repr(self.constraint)}, {self.token})" | 
| 73 | 75 | 
 | 
| 74 | 76 | 
 | 
|  | 77 | +@dataclass | 
|  | 78 | +class MakeRandomVariableNodeOp: | 
|  | 79 | +    r"""Wrapper around `RandomVariable` `Op`\s. | 
|  | 80 | +
 | 
|  | 81 | +    Some `Op.__call__` signatures (e.g. `RandomVariable`) do not match their | 
|  | 82 | +    `Op.make_node` signatures, causing `ExpressionTuple.eval_if_etuple` to | 
|  | 83 | +    fail. To circumvent this limitation, we wrap `Op`\s with this class, which | 
|  | 84 | +    maps `Op.__call__` directly to `Op.make_node`. | 
|  | 85 | +
 | 
|  | 86 | +    """ | 
|  | 87 | + | 
|  | 88 | +    op: Op | 
|  | 89 | + | 
|  | 90 | +    def __call__(self, *args): | 
|  | 91 | +        return self.op.make_node(*args) | 
|  | 92 | + | 
|  | 93 | + | 
|  | 94 | +def car_MakeRandomVariableNodeOp(x): | 
|  | 95 | +    return type(x) | 
|  | 96 | + | 
|  | 97 | + | 
|  | 98 | +_car.add((MakeRandomVariableNodeOp,), car_MakeRandomVariableNodeOp) | 
|  | 99 | + | 
|  | 100 | + | 
|  | 101 | +def cdr_MakeRandomVariableNodeOp(x): | 
|  | 102 | +    x_e = etuple(_car(x), x.op, evaled_obj=x) | 
|  | 103 | +    return x_e[1:] | 
|  | 104 | + | 
|  | 105 | + | 
|  | 106 | +_cdr.add((MakeRandomVariableNodeOp,), cdr_MakeRandomVariableNodeOp) | 
|  | 107 | + | 
|  | 108 | + | 
|  | 109 | +@etuplize.register(RandomVariable) | 
|  | 110 | +def etuplize_random(*args, **kwargs): | 
|  | 111 | +    """Wrap RandomVariable Ops with a MakeNodeOp object.""" | 
|  | 112 | +    return etuple(MakeRandomVariableNodeOp, etuplize.funcs[(object,)](*args, **kwargs)) | 
|  | 113 | + | 
|  | 114 | + | 
| 75 | 115 | def car_Variable(x): | 
| 76 | 116 |     if x.owner: | 
| 77 |  | -        return x.owner.op | 
|  | 117 | +        if issubclass(type(x.owner.op), RandomVariable): | 
|  | 118 | +            return MakeRandomVariableNodeOp(x.owner.op) | 
|  | 119 | +        else: | 
|  | 120 | +            return x.owner.op | 
| 78 | 121 |     else: | 
| 79 | 122 |         raise ConsError("Not a cons pair.") | 
| 80 | 123 | 
 | 
|  | 
0 commit comments