Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 124 additions & 12 deletions aesara/graph/rewriting/unify.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,18 @@

from collections.abc import Mapping
from numbers import Number
from typing import Dict, Optional, Tuple, Union
from typing import Callable, Dict, Optional, Tuple, Union

import numpy as np
from cons.core import ConsError, _car, _cdr
from etuples import apply, etuple, etuplize
from etuples.core import ExpressionTuple
from etuples.core import ExpressionTuple, etuple
from etuples.dispatch import etuplize_fn
from unification.core import _unify, assoc
from unification.utils import transitive_get as walk
from unification.variable import Var, isvar, var

from aesara.graph.basic import Constant, Variable
from aesara.graph.basic import Apply, Constant, Variable
from aesara.graph.op import Op
from aesara.graph.type import Type

Expand Down Expand Up @@ -72,19 +73,90 @@ def __repr__(self):
return f"{type(self).__name__}({repr(self.constraint)}, {self.token})"


def car_Variable(x):
class OpExpressionTuple(ExpressionTuple):
    r"""An `ExpressionTuple` specialized for Aesara `Op`s.

    Some `Op.__call__` signatures (e.g. `RandomVariable`s) differ from their
    `Op.make_node` signatures, which makes `ExpressionTuple.eval_if_etuple`
    fail.  This subclass works around that by overriding `_eval_apply_fn` so
    that evaluation goes through `Op.make_node` instead of `Op.__call__`.

    """

    def _eval_apply_fn(self, op: Op) -> Callable:
        """Return a callable that evaluates to the resulting `Apply` node's outputs."""

        def eval_op(*inputs, **kwargs):
            node = op.make_node(*inputs, **kwargs)
            # Mirror `Op.__call__`: a single output is returned unwrapped,
            # multiple outputs are returned as a list.
            return node.outputs[0] if node.nout == 1 else node.outputs

        return eval_op

    def __repr__(self):
        return "Op" + super().__repr__()

    def __str__(self):
        # NOTE(review): this delegates to the parent's `__repr__`, not its
        # `__str__` — presumably intentional, but worth confirming.
        return "o" + super().__repr__()


@etuple.register(Op, [object])
def etuple_Op(*args, **kwargs) -> OpExpressionTuple:
    """Construct an `OpExpressionTuple` when the head of the tuple is an `Op`."""
    et = OpExpressionTuple(args, **kwargs)
    return et


@etuplize_fn.register(Op)
def etuplize_fn_Op(_: Op):
    """Make `etuplize` build `Op`-headed expressions with `etuple_Op`."""
    return etuple_Op


def nth(n: int, node: Apply) -> Variable:
    """Return the ``n``-th output variable of an `Apply` node."""
    outputs = node.outputs
    return outputs[n]


def car_Variable(x: Variable):
    """Return the `car` of a `Variable`.

    The outputs of `Apply` nodes are stored in lists, but the `__call__`
    function of the `Op` that creates a single-output node returns the lone
    output variable instead of the list.  We can thus simply return the
    `Op` as the `car` in that case.

    When there are several outputs, however, `__call__` will return a list,
    so returning the `Op` here would produce an expression tuple that
    evaluates to the list of outputs of the `Apply` node.  We therefore need
    to prepend an `nth` operator that will return the output stored at the
    specified index.

    Raises
    ------
    ConsError
        If `x` has no owner (i.e. it is not the output of an `Apply` node).

    """
    if x.owner:
        if x.owner.nout == 1:
            return x.owner.op
        else:
            return nth
    else:
        raise ConsError("Not a cons pair.")


_car.add((Variable,), car_Variable)


def cdr_Variable(x):
def cdr_Variable(x: Variable):
"""Return the `cdr` of a `Variable`

For a variable created by a single output `Apply` node the `cdr` is defined as the input list. For
a multiple output `Apply` node the `cdr` is the index of the variable in the
node's outputs list, and the `Apply` node.

"""
if x.owner:
x_e = etuple(_car(x), *x.owner.inputs, evaled_obj=x)
if x.owner.nout == 1:
x_e = etuple(_car(x), *x.owner.inputs, evaled_obj=x)
else:
x_e = etuple(_car(x), x.index, x.owner, evaled_obj=x)
else:
raise ConsError("Not a cons pair.")

Expand All @@ -94,7 +166,32 @@ def cdr_Variable(x):
_cdr.add((Variable,), cdr_Variable)


def car_Op(x):
def car_Apply(x: Apply):
    """Return the `car` (i.e. the `Op`) of an `Apply` node.

    This will only be called for multiple-output nodes.

    """
    return x.op


_car.add((Apply,), car_Apply)


def cdr_Apply(x: Apply):
    """Return the `cdr` of an `Apply` node.

    This will only be called for multiple-output nodes.  The `cdr` is the
    node's input list; the leading `car` element is dropped from the
    constructed expression tuple.

    """
    x_e = etuple(_car(x), *x.inputs, evaled_obj=x.outputs)
    return x_e[1:]


_cdr.add((Apply,), cdr_Apply)


def car_Op(x: Op):
if hasattr(x, "__props__"):
return type(x)

Expand All @@ -104,7 +201,7 @@ def car_Op(x):
_car.add((Op,), car_Op)


def cdr_Op(x):
def cdr_Op(x: Op):
if not hasattr(x, "__props__"):
raise ConsError("Not a cons pair.")

Expand Down Expand Up @@ -201,12 +298,27 @@ def _unify_Constant_Constant(u, v, s):


def _unify_Variable_ExpressionTuple(u, v, s):
    """Unify a `Variable` with an `ExpressionTuple`.

    If the `Variable`'s owner only has one output, we can etuplize the
    `Variable` and unify both expression tuples directly.

    If the owner has multiple outputs, but the `Op`'s `default_output` is not
    `None`, we unify the etuplized version of the `Variable` with an expanded
    expression tuple that accounts for the output-variable selection.  We only
    do this for nodes with a default output (we otherwise expect the caller to
    use the `nth` operator in the expression tuple).

    """
    if not u.owner:
        # A `Variable` without an owner is "atomic" and cannot match an
        # expression tuple.
        yield False
        return

    if u.owner.nout == 1:
        yield _unify(etuplize(u, shallow=True), v, s)
    elif u.owner.nout == 2 and u.owner.op.default_output is not None:
        u_et = etuplize(u)
        v_et = etuple(nth, u_et[1], v)
        yield _unify(u_et, v_et, s)


_unify.add(
Expand Down
2 changes: 1 addition & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ dependencies:
- numpy>=1.17.0
- scipy>=0.14
- filelock
- etuples
- etuples>=0.3.8
- logical-unification
- miniKanren
- cons
Expand Down
104 changes: 99 additions & 5 deletions tests/graph/rewriting/test_unify.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,11 @@
import aesara.tensor as at
from aesara.graph.basic import Apply, Constant, equal_computations
from aesara.graph.op import Op
from aesara.graph.rewriting.unify import ConstrainedVar, convert_strs_to_vars
from aesara.graph.rewriting.unify import (
ConstrainedVar,
OpExpressionTuple,
convert_strs_to_vars,
)
from aesara.tensor.type import TensorType
from tests.graph.utils import MyType

Expand Down Expand Up @@ -98,7 +102,7 @@ def test_cons():
assert cdr_res == [atype_at.dtype, atype_at.shape]


def test_etuples():
def test_etuples_Op():
x_at = at.vector("x")
y_at = at.vector("y")

Expand All @@ -110,24 +114,41 @@ def test_etuples():
assert res.owner.inputs == [x_at, y_at]

w_at = etuple(at.add, x_at, y_at)
assert isinstance(w_at, OpExpressionTuple)

w_at._evaled_obj = w_at.null
res = w_at.evaled_obj
assert res.owner.op == at.add
assert res.owner.inputs == [x_at, y_at]


def test_etuples_atomic_Op():
x_at = at.vector("x")
y_at = at.vector("y")

z_at = etuple(x_at, y_at)

# This `Op` doesn't expand into an `etuple` (i.e. it's "atomic")
op1_np = CustomOpNoProps(1)

res = apply(op1_np, z_at)
assert res.owner.op == op1_np

q_at = op1_np(x_at, y_at)
res = etuplize(q_at)
assert res[0] == op1_np
q_et = etuplize(q_at)
assert isinstance(q_et, ExpressionTuple)
assert isinstance(q_et[0], CustomOpNoProps)
assert q_et[0] == op1_np

q_et._evaled_obj = q_et.null
res = q_et.evaled_obj
assert isinstance(res.owner.op, CustomOpNoProps)

with pytest.raises(TypeError):
etuplize(op1_np)


def test_etuples_multioutput_Op():
class MyMultiOutOp(Op):
def make_node(self, *inputs):
outputs = [MyType()(), MyType()()]
Expand All @@ -144,6 +165,52 @@ def perform(self, node, inputs, outputs):
assert res[0].owner.op == op1_np
assert res[1].owner.op == op1_np

# If we etuplize one of the outputs we should recover this output when
# evaluating
q_at, _ = op1_np(x_at)
q_et = etuplize(q_at)
q_et._evaled_obj = q_et.null
res = q_et.evaled_obj
assert res == q_at

# TODO: If the caller etuplizes the output list, it should recover the list
# when evaluating.
# q_at = op1_np(x_at)
# q_et = etuplize(q_at)
# q_et._evaled_obj = q_et.null
# res = q_et.evaled_obj


def test_etuples_default_output_op():
    """Check `apply`/`etuplize` round-trips for an `Op` with a `default_output`."""

    class MyDefaultOutputOp(Op):
        default_output = 1

        def make_node(self, *inputs):
            return Apply(self, list(inputs), [MyType()(), MyType()()])

        def perform(self, node, inputs, outputs):
            outputs[0] = np.array(np.array(inputs[0]))
            outputs[1] = np.array(np.array(inputs[1]))

    x = at.vector("x")
    y = at.vector("y")
    op = MyDefaultOutputOp()

    # `apply` should build the node through `make_node`.
    applied = apply(op, etuple(x, y))
    assert applied.owner.op == op
    assert applied.owner.inputs[0] == x
    assert applied.owner.inputs[1] == y

    # We should recover the default output when evaluating its etuplized
    # counterpart.
    out = op(x, y)
    out_et = etuplize(out)
    out_et._evaled_obj = out_et.null
    evaled = out_et.evaled_obj
    assert isinstance(evaled.owner.op, MyDefaultOutputOp)
    assert evaled.owner.inputs[0] == x
    assert evaled.owner.inputs[1] == y


def test_unify_Variable():
x_at = at.vector("x")
Expand Down Expand Up @@ -189,7 +256,7 @@ def test_unify_Variable():
res = reify(z_pat_et, s)

assert isinstance(res, ExpressionTuple)
assert equal_computations([res.evaled_obj], [z_et.evaled_obj])
assert equal_computations([res.evaled_obj[0]], [z_et.evaled_obj[0]])

# `ExpressionTuple`, `Variable`
s = unify(z_et, x_at, {})
Expand All @@ -209,6 +276,33 @@ def test_unify_Variable():
assert s[b_lv] == y_at


def test_unify_default_output_Variable():
    """Make sure that we can unify with the default output of an Apply node."""

    class MyDefaultOutputOp(Op):
        default_output = 1

        def make_node(self, *inputs):
            return Apply(self, list(inputs), [MyType()(), MyType()()])

        def perform(self, node, inputs, outputs):
            outputs[0] = np.array(np.array(inputs[0]))
            outputs[1] = np.array(np.array(inputs[1]))

    x = at.vector("x")
    y = at.vector("y")
    op = MyDefaultOutputOp()
    out = op(x, y)

    x_lv, y_lv = var("x"), var("y")
    pattern = etuple(op, x_lv, y_lv)

    # Unification should bind the logic variables to the node's inputs.
    s = unify(pattern, out)
    assert s[x_lv] == x
    assert s[y_lv] == y


def test_unify_Op():
# These `Op`s expand into `ExpressionTuple`s
op1 = CustomOp(1)
Expand Down