diff --git a/aesara/graph/rewriting/unify.py b/aesara/graph/rewriting/unify.py index 4a56212fc8..db432dcf25 100644 --- a/aesara/graph/rewriting/unify.py +++ b/aesara/graph/rewriting/unify.py @@ -12,17 +12,18 @@ from collections.abc import Mapping from numbers import Number -from typing import Dict, Optional, Tuple, Union +from typing import Callable, Dict, Optional, Tuple, Union import numpy as np from cons.core import ConsError, _car, _cdr from etuples import apply, etuple, etuplize -from etuples.core import ExpressionTuple +from etuples.core import ExpressionTuple, etuple +from etuples.dispatch import etuplize_fn from unification.core import _unify, assoc from unification.utils import transitive_get as walk from unification.variable import Var, isvar, var -from aesara.graph.basic import Constant, Variable +from aesara.graph.basic import Apply, Constant, Variable from aesara.graph.op import Op from aesara.graph.type import Type @@ -72,9 +73,70 @@ def __repr__(self): return f"{type(self).__name__}({repr(self.constraint)}, {self.token})" -def car_Variable(x): +class OpExpressionTuple(ExpressionTuple): + r"""Etuple form for `Op`s. + + Some `Op.__call__` signatures (e.g `RandomVariable`s) do not match their + `Op.make_node` signatures, causing `ExpressionTuple.eval_if_etuple` to + fail. To circumvent this constraint we subclass `ExpressionTuple`, and + overload the `_eval_apply` method to use `Op.make_node` instead of + `Op.__call__`. 
+
+    """
+
+    def _eval_apply_fn(self, op: Op) -> Callable:
+        """We evaluate the etuple to the resulting `Apply` node's outputs."""
+
+        def eval_op(*inputs, **kwargs):
+            node = op.make_node(*inputs, **kwargs)
+            if node.nout == 1:
+                return node.outputs[0]
+            else:
+                return node.outputs
+
+        return eval_op
+
+    def __repr__(self):
+        return "Op" + super().__repr__()
+
+    def __str__(self):
+        return "o" + super().__repr__()
+
+
+@etuple.register(Op, [object])
+def etuple_Op(*args, **kwargs) -> OpExpressionTuple:
+    return OpExpressionTuple(args, **kwargs)
+
+
+@etuplize_fn.register(Op)
+def etuplize_fn_Op(_: Op):
+    return etuple_Op
+
+
+def nth(n: int, node: Apply) -> Variable:
+    """Function that selects the nth output of an `Apply` node."""
+    return node.outputs[n]
+
+
+def car_Variable(x: Variable):
+    """Return the `car` of a `Variable`.
+
+    The outputs of `Apply` nodes are stored in lists, but the `__call__`
+    function of the `Op` that creates this `Variable` will return the single
+    output variable instead of the list. We can thus simply return the
+    `Op` as `car`.
+
+    When there are several outputs, however, `__call__` will return a list, so
+    returning the `Op` here would return an expression tuple that evaluates to
+    the list of outputs of the `Apply` node. We thus need to prepend an `nth`
+    operator that will return the output stored at the specified index.
+
+    """
     if x.owner:
-        return x.owner.op
+        if x.owner.nout == 1:
+            return x.owner.op
+        else:
+            return nth
     else:
         raise ConsError("Not a cons pair.")
 
@@ -82,9 +144,19 @@ def car_Variable(x):
 _car.add((Variable,), car_Variable)
 
 
-def cdr_Variable(x):
+def cdr_Variable(x: Variable):
+    """Return the `cdr` of a `Variable`.
+
+    For a variable created by a single output `Apply` node the `cdr` is defined as the input list. For
+    a multiple output `Apply` node the `cdr` is the index of the variable in the
+    node's outputs list, and the `Apply` node. 
+
+    """
     if x.owner:
-        x_e = etuple(_car(x), *x.owner.inputs, evaled_obj=x)
+        if x.owner.nout == 1:
+            x_e = etuple(_car(x), *x.owner.inputs, evaled_obj=x)
+        else:
+            x_e = etuple(_car(x), x.index, x.owner, evaled_obj=x)
     else:
         raise ConsError("Not a cons pair.")
 
@@ -94,7 +166,32 @@ def cdr_Variable(x):
 _cdr.add((Variable,), cdr_Variable)
 
 
-def car_Op(x):
+def car_Apply(x: Apply):
+    """Return the `car` of an `Apply` node.
+
+    This will only be called for multiple-output nodes.
+
+    """
+    return x.op
+
+
+_car.add((Apply,), car_Apply)
+
+
+def cdr_Apply(x: Apply):
+    """Return the `cdr` of an `Apply` node.
+
+    This will only be called for multiple-output nodes.
+
+    """
+    x_e = etuple(_car(x), *x.inputs, evaled_obj=x.outputs)
+    return x_e[1:]
+
+
+_cdr.add((Apply,), cdr_Apply)
+
+
+def car_Op(x: Op):
     if hasattr(x, "__props__"):
         return type(x)
 
@@ -104,7 +201,7 @@ def car_Op(x):
 _car.add((Op,), car_Op)
 
 
-def cdr_Op(x):
+def cdr_Op(x: Op):
     if not hasattr(x, "__props__"):
         raise ConsError("Not a cons pair.")
 
@@ -201,12 +298,27 @@ def _unify_Constant_Constant(u, v, s):
 
 
 def _unify_Variable_ExpressionTuple(u, v, s):
-    # `Constant`s are "atomic"
+    """Unify a `Variable` with an `ExpressionTuple`.
+
+    If the `Variable`'s owner only has one output we can etuplize the `Variable`
+    and unify both expression tuples.
+
+    If the owner has multiple outputs, but the `Op`'s `default_output` is not
+    `None` we unify the etuplized version of the `Variable` with an expanded
+    expression tuple that accounts for the variable selection. We only do this
+    for nodes with a default output (we otherwise expect the caller to use
+    the `nth` operator in the expression tuple). 
+ + """ if not u.owner: yield False return - - yield _unify(etuplize(u, shallow=True), v, s) + if u.owner.nout == 1: + yield _unify(etuplize(u, shallow=True), v, s) + elif u.owner.nout == 2 and u.owner.op.default_output is not None: + u_et = etuplize(u) + v_et = etuple(nth, u_et[1], v) + yield _unify(u_et, v_et, s) _unify.add( diff --git a/environment.yml b/environment.yml index 3bffffc618..712f9f9501 100644 --- a/environment.yml +++ b/environment.yml @@ -12,7 +12,7 @@ dependencies: - numpy>=1.17.0 - scipy>=0.14 - filelock - - etuples + - etuples>=0.3.8 - logical-unification - miniKanren - cons diff --git a/tests/graph/rewriting/test_unify.py b/tests/graph/rewriting/test_unify.py index 6ce1284794..32ff40269c 100644 --- a/tests/graph/rewriting/test_unify.py +++ b/tests/graph/rewriting/test_unify.py @@ -11,7 +11,11 @@ import aesara.tensor as at from aesara.graph.basic import Apply, Constant, equal_computations from aesara.graph.op import Op -from aesara.graph.rewriting.unify import ConstrainedVar, convert_strs_to_vars +from aesara.graph.rewriting.unify import ( + ConstrainedVar, + OpExpressionTuple, + convert_strs_to_vars, +) from aesara.tensor.type import TensorType from tests.graph.utils import MyType @@ -98,7 +102,7 @@ def test_cons(): assert cdr_res == [atype_at.dtype, atype_at.shape] -def test_etuples(): +def test_etuples_Op(): x_at = at.vector("x") y_at = at.vector("y") @@ -110,11 +114,20 @@ def test_etuples(): assert res.owner.inputs == [x_at, y_at] w_at = etuple(at.add, x_at, y_at) + assert isinstance(w_at, OpExpressionTuple) + w_at._evaled_obj = w_at.null res = w_at.evaled_obj assert res.owner.op == at.add assert res.owner.inputs == [x_at, y_at] + +def test_etuples_atomic_Op(): + x_at = at.vector("x") + y_at = at.vector("y") + + z_at = etuple(x_at, y_at) + # This `Op` doesn't expand into an `etuple` (i.e. 
it's "atomic") op1_np = CustomOpNoProps(1) @@ -122,12 +135,20 @@ def test_etuples(): assert res.owner.op == op1_np q_at = op1_np(x_at, y_at) - res = etuplize(q_at) - assert res[0] == op1_np + q_et = etuplize(q_at) + assert isinstance(q_et, ExpressionTuple) + assert isinstance(q_et[0], CustomOpNoProps) + assert q_et[0] == op1_np + + q_et._evaled_obj = q_et.null + res = q_et.evaled_obj + assert isinstance(res.owner.op, CustomOpNoProps) with pytest.raises(TypeError): etuplize(op1_np) + +def test_etuples_multioutput_Op(): class MyMultiOutOp(Op): def make_node(self, *inputs): outputs = [MyType()(), MyType()()] @@ -144,6 +165,52 @@ def perform(self, node, inputs, outputs): assert res[0].owner.op == op1_np assert res[1].owner.op == op1_np + # If we etuplize one of the outputs we should recover this output when + # evaluating + q_at, _ = op1_np(x_at) + q_et = etuplize(q_at) + q_et._evaled_obj = q_et.null + res = q_et.evaled_obj + assert res == q_at + + # TODO: If the caller etuplizes the output list, it should recover the list + # when evaluating. + # q_at = op1_np(x_at) + # q_et = etuplize(q_at) + # q_et._evaled_obj = q_et.null + # res = q_et.evaled_obj + + +def test_etuples_default_output_op(): + class MyDefaultOutputOp(Op): + default_output = 1 + + def make_node(self, *inputs): + outputs = [MyType()(), MyType()()] + return Apply(self, list(inputs), outputs) + + def perform(self, node, inputs, outputs): + outputs[0] = np.array(np.array(inputs[0])) + outputs[1] = np.array(np.array(inputs[1])) + + x_at = at.vector("x") + y_at = at.vector("y") + op1_np = MyDefaultOutputOp() + res = apply(op1_np, etuple(x_at, y_at)) + assert res.owner.op == op1_np + assert res.owner.inputs[0] == x_at + assert res.owner.inputs[1] == y_at + + # We should recover the default output when evaluting its etuplized + # counterpart. 
+ q_at = op1_np(x_at, y_at) + q_et = etuplize(q_at) + q_et._evaled_obj = q_et.null + res = q_et.evaled_obj + assert isinstance(res.owner.op, MyDefaultOutputOp) + assert res.owner.inputs[0] == x_at + assert res.owner.inputs[1] == y_at + def test_unify_Variable(): x_at = at.vector("x") @@ -189,7 +256,7 @@ def test_unify_Variable(): res = reify(z_pat_et, s) assert isinstance(res, ExpressionTuple) - assert equal_computations([res.evaled_obj], [z_et.evaled_obj]) + assert equal_computations([res.evaled_obj[0]], [z_et.evaled_obj[0]]) # `ExpressionTuple`, `Variable` s = unify(z_et, x_at, {}) @@ -209,6 +276,33 @@ def test_unify_Variable(): assert s[b_lv] == y_at +def test_unify_default_output_Variable(): + """Make sure that we can unify with the default output of an Apply node.""" + + class MyDefaultOutputOp(Op): + default_output = 1 + + def make_node(self, *inputs): + outputs = [MyType()(), MyType()()] + return Apply(self, list(inputs), outputs) + + def perform(self, node, inputs, outputs): + outputs[0] = np.array(np.array(inputs[0])) + outputs[1] = np.array(np.array(inputs[1])) + + x_at = at.vector("x") + y_at = at.vector("y") + op1_np = MyDefaultOutputOp() + q_at = op1_np(x_at, y_at) + + x_lv, y_lv = var("x"), var("y") + q_et = etuple(op1_np, x_lv, y_lv) + + s = unify(q_et, q_at) + assert s[x_lv] == x_at + assert s[y_lv] == y_at + + def test_unify_Op(): # These `Op`s expand into `ExpressionTuple`s op1 = CustomOp(1)