From d8778e71d8fb48dca0c6f0e2d968de67f25be0e9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Tue, 6 Sep 2022 09:43:19 +0200 Subject: [PATCH 1/4] Allow etuplization of `Op`s --- aesara/graph/rewriting/unify.py | 71 +++++++++++++++++++++++++++-- environment.yml | 2 +- tests/graph/rewriting/test_unify.py | 39 ++++++++++++++-- 3 files changed, 102 insertions(+), 10 deletions(-) diff --git a/aesara/graph/rewriting/unify.py b/aesara/graph/rewriting/unify.py index 4a56212fc8..9183676d9a 100644 --- a/aesara/graph/rewriting/unify.py +++ b/aesara/graph/rewriting/unify.py @@ -12,12 +12,13 @@ from collections.abc import Mapping from numbers import Number -from typing import Dict, Optional, Tuple, Union +from typing import Callable, Dict, Optional, Tuple, Union import numpy as np from cons.core import ConsError, _car, _cdr from etuples import apply, etuple, etuplize -from etuples.core import ExpressionTuple +from etuples.core import ExpressionTuple, etuple +from etuples.dispatch import etuplize_fn from unification.core import _unify, assoc from unification.utils import transitive_get as walk from unification.variable import Var, isvar, var @@ -72,6 +73,43 @@ def __repr__(self): return f"{type(self).__name__}({repr(self.constraint)}, {self.token})" +class OpExpressionTuple(ExpressionTuple): + r"""Etuple form for `Op`s. + + Some `Op.__call__` signatures (e.g `RandomVariable`s) do not match their + `Op.make_node` signatures, causing `ExpressionTuple.eval_if_etuple` to + fail. To circumvent this constraint we subclass `ExpressionTuple`, and + overload the `_eval_apply` method to use `Op.make_node` instead of + `Op.__call__`. 
+ + """ + + def _eval_apply_fn(self, op: Op) -> Callable: + """We evaluate the etuple to the resulting `Apply` node's outputs.""" + + def eval_op(*inputs, **kwargs): + node = op.make_node(*inputs, **kwargs) + return node.outputs + + return eval_op + + def __repr__(self): + return "Op" + super().__repr__() + + def __str__(self): + return "o" + super().__repr__() + + +@etuple.register(Op, [object]) +def etuple_Op(*args, **kwargs) -> OpExpressionTuple: + return OpExpressionTuple(args, **kwargs) + + +@etuplize_fn.register(Op) +def etuplize_fn_Op(_: Op): + return etuple_Op + + def car_Variable(x): if x.owner: return x.owner.op @@ -94,7 +132,32 @@ def cdr_Variable(x): _cdr.add((Variable,), cdr_Variable) -def car_Op(x): +def car_Apply(x: Apply): + """Return the `car` of an `Apply` node. + + This will only be called for multiple-output nodes. + + """ + return x.op + + +_car.add((Apply,), car_Apply) + + +def cdr_Apply(x: Apply): + """Return the `cdr` of an `Apply` node. + + This will only be called for multiple-output nodes. 
+ + """ + x_e = etuple(_car(x), *x.inputs, evaled_obj=x.outputs) + return x_e[1:] + + +_cdr.add((Apply,), cdr_Apply) + + +def car_Op(x: Op): if hasattr(x, "__props__"): return type(x) @@ -104,7 +167,7 @@ def car_Op(x): _car.add((Op,), car_Op) -def cdr_Op(x): +def cdr_Op(x: Op): if not hasattr(x, "__props__"): raise ConsError("Not a cons pair.") diff --git a/environment.yml b/environment.yml index 3bffffc618..712f9f9501 100644 --- a/environment.yml +++ b/environment.yml @@ -12,7 +12,7 @@ dependencies: - numpy>=1.17.0 - scipy>=0.14 - filelock - - etuples + - etuples>=0.3.8 - logical-unification - miniKanren - cons diff --git a/tests/graph/rewriting/test_unify.py b/tests/graph/rewriting/test_unify.py index 6ce1284794..f57e55c6ad 100644 --- a/tests/graph/rewriting/test_unify.py +++ b/tests/graph/rewriting/test_unify.py @@ -11,7 +11,11 @@ import aesara.tensor as at from aesara.graph.basic import Apply, Constant, equal_computations from aesara.graph.op import Op -from aesara.graph.rewriting.unify import ConstrainedVar, convert_strs_to_vars +from aesara.graph.rewriting.unify import ( + ConstrainedVar, + OpExpressionTuple, + convert_strs_to_vars, +) from aesara.tensor.type import TensorType from tests.graph.utils import MyType @@ -110,10 +114,14 @@ def test_etuples(): assert res.owner.inputs == [x_at, y_at] w_at = etuple(at.add, x_at, y_at) + assert isinstance(w_at, ExpressionTuple) res = w_at.evaled_obj - assert res.owner.op == at.add - assert res.owner.inputs == [x_at, y_at] + assert len(res) == 1 + + output = res[0] + assert output.owner.op == at.add + assert output.owner.inputs == [x_at, y_at] # This `Op` doesn't expand into an `etuple` (i.e. 
it's "atomic") op1_np = CustomOpNoProps(1) @@ -123,6 +131,8 @@ def test_etuples(): q_at = op1_np(x_at, y_at) res = etuplize(q_at) + assert isinstance(res, ExpressionTuple) + assert isinstance(res[0], CustomOpNoProps) assert res[0] == op1_np with pytest.raises(TypeError): @@ -144,6 +154,25 @@ def perform(self, node, inputs, outputs): assert res[0].owner.op == op1_np assert res[1].owner.op == op1_np + mu_at = at.scalar("mu") + sigma_at = at.scalar("sigma") + + w_rv = at.random.normal(mu_at, sigma_at) + w_at = etuplize(w_rv) + assert isinstance(w_at, OpExpressionTuple) + assert isinstance(w_at[0], ExpressionTuple) + + z_at = etuple(at.random.normal, mu_at, sigma_at) + assert isinstance(z_at, OpExpressionTuple) + + z_at = etuple(at.random.normal, *w_at[1:]) + assert isinstance(z_at, OpExpressionTuple) + + res = z_at.evaled_obj + assert len(res) == 2 + assert res[1].owner.op == at.random.normal + assert res[1].owner.inputs[-2:] == [mu_at, sigma_at] + def test_unify_Variable(): x_at = at.vector("x") @@ -176,7 +205,7 @@ def test_unify_Variable(): res = reify(z_pat_et, s) assert isinstance(res, ExpressionTuple) - assert equal_computations([res.evaled_obj], [z_at]) + assert equal_computations([res.evaled_obj[0]], [z_at]) z_et = etuple(at.add, x_at, y_at) @@ -189,7 +218,7 @@ def test_unify_Variable(): res = reify(z_pat_et, s) assert isinstance(res, ExpressionTuple) - assert equal_computations([res.evaled_obj], [z_et.evaled_obj]) + assert equal_computations([res.evaled_obj[0]], [z_et.evaled_obj[0]]) # `ExpressionTuple`, `Variable` s = unify(z_et, x_at, {}) From 63af9960745d8837a90cabf6e785549b4189ebce Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Mon, 12 Sep 2022 23:24:30 +0200 Subject: [PATCH 2/4] Add nth operator to select the output of an Apply node --- aesara/graph/rewriting/unify.py | 46 +++++++++++++++++++++++++---- tests/graph/rewriting/test_unify.py | 18 +++++------ 2 files changed, 48 insertions(+), 16 deletions(-) diff --git 
a/aesara/graph/rewriting/unify.py b/aesara/graph/rewriting/unify.py index 9183676d9a..af38eec21d 100644 --- a/aesara/graph/rewriting/unify.py +++ b/aesara/graph/rewriting/unify.py @@ -23,7 +23,7 @@ from unification.utils import transitive_get as walk from unification.variable import Var, isvar, var -from aesara.graph.basic import Constant, Variable +from aesara.graph.basic import Apply, Constant, Variable from aesara.graph.op import Op from aesara.graph.type import Type @@ -89,7 +89,10 @@ def _eval_apply_fn(self, op: Op) -> Callable: def eval_op(*inputs, **kwargs): node = op.make_node(*inputs, **kwargs) - return node.outputs + if node.nout == 1: + return node.outputs[0] + else: + return node.outputs return eval_op @@ -110,9 +113,30 @@ def etuplize_fn_Op(_: Op): return etuple_Op -def car_Variable(x): +def nth(n: int, node: Apply) -> Variable: + """Function that selects the nth output of an `Apply` node.""" + return node.outputs[n] + + +def car_Variable(x: Variable): + """Return the `car` of a `Variable`. + + The outputs of `Apply` nodes are stored in lists, but the `__call__` + function of the `Op` that creates this `Op` will return the single + output variable instead of the list. We can thus simply return the + `Op` as `car`. + + When there are several outputs, however, `__call__` will return a list, so + returning the `Op` here would return an expression tuple that evaluates to + the list of outputs of the `Apply` node. We thus need to prepend a `nth` + operator that will return the output stored at the specified index. + + """ if x.owner: - return x.owner.op + if x.owner.nout == 1: + return x.owner.op + else: + return nth else: raise ConsError("Not a cons pair.") @@ -120,9 +144,19 @@ def car_Variable(x): _car.add((Variable,), car_Variable) -def cdr_Variable(x): +def cdr_Variable(x: Variable): + """Return the `cdr` of a `Variable` + + For a variable created by a single output `Apply` node the `cdr` is defined as the input list. 
For + a multiple output `Apply` node the `cdr` is the index of the variable in the + node's outputs list, and the `Apply` node. + + """ if x.owner: - x_e = etuple(_car(x), *x.owner.inputs, evaled_obj=x) + if x.owner.nout == 1: + x_e = etuple(_car(x), *x.owner.inputs, evaled_obj=x) + else: + x_e = etuple(_car(x), x.index, x.owner, evaled_obj=x) else: raise ConsError("Not a cons pair.") diff --git a/tests/graph/rewriting/test_unify.py b/tests/graph/rewriting/test_unify.py index f57e55c6ad..5debd7ec30 100644 --- a/tests/graph/rewriting/test_unify.py +++ b/tests/graph/rewriting/test_unify.py @@ -114,14 +114,11 @@ def test_etuples(): assert res.owner.inputs == [x_at, y_at] w_at = etuple(at.add, x_at, y_at) - assert isinstance(w_at, ExpressionTuple) + assert isinstance(w_at, OpExpressionTuple) res = w_at.evaled_obj - assert len(res) == 1 - - output = res[0] - assert output.owner.op == at.add - assert output.owner.inputs == [x_at, y_at] + assert res.owner.op == at.add + assert res.owner.inputs == [x_at, y_at] # This `Op` doesn't expand into an `etuple` (i.e. 
it's "atomic") op1_np = CustomOpNoProps(1) @@ -159,13 +156,14 @@ def perform(self, node, inputs, outputs): w_rv = at.random.normal(mu_at, sigma_at) w_at = etuplize(w_rv) - assert isinstance(w_at, OpExpressionTuple) - assert isinstance(w_at[0], ExpressionTuple) + assert isinstance(w_at, ExpressionTuple) + assert w_at[1] == 1 + assert isinstance(w_at[2], OpExpressionTuple) z_at = etuple(at.random.normal, mu_at, sigma_at) assert isinstance(z_at, OpExpressionTuple) - z_at = etuple(at.random.normal, *w_at[1:]) + z_at = etuple(at.random.normal, *w_at[2][1:]) assert isinstance(z_at, OpExpressionTuple) res = z_at.evaled_obj @@ -205,7 +203,7 @@ def test_unify_Variable(): res = reify(z_pat_et, s) assert isinstance(res, ExpressionTuple) - assert equal_computations([res.evaled_obj[0]], [z_at]) + assert equal_computations([res.evaled_obj], [z_at]) z_et = etuple(at.add, x_at, y_at) From 96d60ceab87279e4cb4140bd24a4f2ce2068defe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Thu, 29 Sep 2022 15:11:56 +0200 Subject: [PATCH 3/4] Test eval of variables created by multiple output and default output `Op`s --- tests/graph/rewriting/test_unify.py | 80 +++++++++++++++++++++-------- 1 file changed, 60 insertions(+), 20 deletions(-) diff --git a/tests/graph/rewriting/test_unify.py b/tests/graph/rewriting/test_unify.py index 5debd7ec30..17b747bae9 100644 --- a/tests/graph/rewriting/test_unify.py +++ b/tests/graph/rewriting/test_unify.py @@ -102,7 +102,7 @@ def test_cons(): assert cdr_res == [atype_at.dtype, atype_at.shape] -def test_etuples(): +def test_etuples_Op(): x_at = at.vector("x") y_at = at.vector("y") @@ -116,10 +116,18 @@ def test_etuples(): w_at = etuple(at.add, x_at, y_at) assert isinstance(w_at, OpExpressionTuple) + w_at._evaled_obj = w_at.null res = w_at.evaled_obj assert res.owner.op == at.add assert res.owner.inputs == [x_at, y_at] + +def test_etuples_atomic_Op(): + x_at = at.vector("x") + y_at = at.vector("y") + + z_at = etuple(x_at, y_at) + # This `Op` 
doesn't expand into an `etuple` (i.e. it's "atomic") op1_np = CustomOpNoProps(1) @@ -127,14 +135,20 @@ def test_etuples(): assert res.owner.op == op1_np q_at = op1_np(x_at, y_at) - res = etuplize(q_at) - assert isinstance(res, ExpressionTuple) - assert isinstance(res[0], CustomOpNoProps) - assert res[0] == op1_np + q_et = etuplize(q_at) + assert isinstance(q_et, ExpressionTuple) + assert isinstance(q_et[0], CustomOpNoProps) + assert q_et[0] == op1_np + + q_et._evaled_obj = q_et.null + res = q_et.evaled_obj + assert isinstance(res.owner.op, CustomOpNoProps) with pytest.raises(TypeError): etuplize(op1_np) + +def test_etuples_multioutput_Op(): class MyMultiOutOp(Op): def make_node(self, *inputs): outputs = [MyType()(), MyType()()] @@ -151,25 +165,51 @@ def perform(self, node, inputs, outputs): assert res[0].owner.op == op1_np assert res[1].owner.op == op1_np - mu_at = at.scalar("mu") - sigma_at = at.scalar("sigma") + # If we etuplize one of the outputs we should recover this output when + # evaluating + q_at, _ = op1_np(x_at) + q_et = etuplize(q_at) + q_et._evaled_obj = q_et.null + res = q_et.evaled_obj + assert res == q_at - w_rv = at.random.normal(mu_at, sigma_at) - w_at = etuplize(w_rv) - assert isinstance(w_at, ExpressionTuple) - assert w_at[1] == 1 - assert isinstance(w_at[2], OpExpressionTuple) + # TODO: If the caller etuplizes the output list, it should recover the list + # when evaluating. 
+ # q_at = op1_np(x_at) + # q_et = etuplize(q_at) + # q_et._evaled_obj = q_et.null + # res = q_et.evaled_obj - z_at = etuple(at.random.normal, mu_at, sigma_at) - assert isinstance(z_at, OpExpressionTuple) - z_at = etuple(at.random.normal, *w_at[2][1:]) - assert isinstance(z_at, OpExpressionTuple) +def test_etuples_default_output_op(): + class MyDefaultOutputOp(Op): + default_output = 1 - res = z_at.evaled_obj - assert len(res) == 2 - assert res[1].owner.op == at.random.normal - assert res[1].owner.inputs[-2:] == [mu_at, sigma_at] + def make_node(self, *inputs): + outputs = [MyType()(), MyType()()] + return Apply(self, list(inputs), outputs) + + def perform(self, node, inputs, outputs): + outputs[0] = np.array(np.array(inputs[0])) + outputs[1] = np.array(np.array(inputs[1])) + + x_at = at.vector("x") + y_at = at.vector("y") + op1_np = MyDefaultOutputOp() + res = apply(op1_np, etuple(x_at, y_at)) + assert res.owner.op == op1_np + assert res.owner.inputs[0] == x_at + assert res.owner.inputs[1] == y_at + + # We should recover the default output when evaluating its etuplized + # counterpart. + q_at = op1_np(x_at, y_at) + q_et = etuplize(q_at) + q_et._evaled_obj = q_et.null + res = q_et.evaled_obj + assert isinstance(res.owner.op, MyDefaultOutputOp) + assert res.owner.inputs[0] == x_at + assert res.owner.inputs[1] == y_at def test_unify_Variable(): From c84149c781fdb8df3a01a28fdb72b0bd1bda295b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Wed, 5 Oct 2022 17:23:03 +0200 Subject: [PATCH 4/4] Unify `Variable`s that are default outputs with `ExpressionTuple`s The difficulty comes from the fact that the `Variable` is etuplized as `ExpressionTuple(nth, default_output, oExpressionTuple(...))` but we cannot expect the caller to use `nth` here since this `default_output` mechanism is hidden. We expand the latter before unifying. 
--- aesara/graph/rewriting/unify.py | 21 ++++++++++++++++++--- tests/graph/rewriting/test_unify.py | 27 +++++++++++++++++++++++++++ 2 files changed, 45 insertions(+), 3 deletions(-) diff --git a/aesara/graph/rewriting/unify.py b/aesara/graph/rewriting/unify.py index af38eec21d..db432dcf25 100644 --- a/aesara/graph/rewriting/unify.py +++ b/aesara/graph/rewriting/unify.py @@ -298,12 +298,27 @@ def _unify_Constant_Constant(u, v, s): def _unify_Variable_ExpressionTuple(u, v, s): - # `Constant`s are "atomic" + """Unify a `Variable` with an `ExpressionTuple`. + + If the `Variable`'s owner only has one output we can etuplize the `Variable` + and unify both expression tuples. + + If the owner has multiple outputs, but the `Op`'s `default_output` is not + `None` we unify the etuplized version of the `Variable` with an expanded + expression tuple that accounts for the variable selection. We only do this + for nodes with a default output (we otherwise expect the caller to use + the `nth` operator in the expression tuple). 
+ + """ if not u.owner: yield False return - - yield _unify(etuplize(u, shallow=True), v, s) + if u.owner.nout == 1: + yield _unify(etuplize(u, shallow=True), v, s) + elif u.owner.nout == 2 and u.owner.op.default_output is not None: + u_et = etuplize(u) + v_et = etuple(nth, u_et[1], v) + yield _unify(u_et, v_et, s) _unify.add( diff --git a/tests/graph/rewriting/test_unify.py b/tests/graph/rewriting/test_unify.py index 17b747bae9..32ff40269c 100644 --- a/tests/graph/rewriting/test_unify.py +++ b/tests/graph/rewriting/test_unify.py @@ -276,6 +276,33 @@ def test_unify_Variable(): assert s[b_lv] == y_at +def test_unify_default_output_Variable(): + """Make sure that we can unify with the default output of an Apply node.""" + + class MyDefaultOutputOp(Op): + default_output = 1 + + def make_node(self, *inputs): + outputs = [MyType()(), MyType()()] + return Apply(self, list(inputs), outputs) + + def perform(self, node, inputs, outputs): + outputs[0] = np.array(np.array(inputs[0])) + outputs[1] = np.array(np.array(inputs[1])) + + x_at = at.vector("x") + y_at = at.vector("y") + op1_np = MyDefaultOutputOp() + q_at = op1_np(x_at, y_at) + + x_lv, y_lv = var("x"), var("y") + q_et = etuple(op1_np, x_lv, y_lv) + + s = unify(q_et, q_at) + assert s[x_lv] == x_at + assert s[y_lv] == y_at + + def test_unify_Op(): # These `Op`s expand into `ExpressionTuple`s op1 = CustomOp(1)