Skip to content

Commit 7656442

Browse files
Merge pull request #90 from brandonwillard/reify-check-names-in-tf-graph
Make reify check for existing names in the current TensorFlow graph
2 parents c84b6c1 + 027bffc commit 7656442

File tree

3 files changed

+141
-67
lines changed

3 files changed

+141
-67
lines changed

symbolic_pymc/meta.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,12 @@ def meta_reify_iter(rands):
139139
return type(rands)(reified_rands), any_unreified
140140

141141

142+
class MetaReificationError(Exception):
143+
"""An exception type for errors encountered during the creation of base objects from meta objects."""
144+
145+
pass
146+
147+
142148
class MetaSymbolType(abc.ABCMeta):
143149
def __new__(cls, name, bases, clsdict):
144150

symbolic_pymc/tensorflow/meta.py

Lines changed: 81 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,16 @@
1616

1717
from google.protobuf.message import Message
1818

19-
from tensorflow.python.framework import tensor_util, op_def_registry, op_def_library, tensor_shape
19+
from tensorflow.python.framework import (
20+
tensor_util,
21+
op_def_registry,
22+
op_def_library,
23+
tensor_shape,
24+
ops,
25+
)
2026
from tensorflow.core.framework.op_def_pb2 import OpDef
2127
from tensorflow.core.framework.node_def_pb2 import NodeDef
2228

23-
# from tensorflow.core.framework.tensor_shape_pb2 import TensorShapeProto
24-
2529
from tensorflow_probability import distributions as tfd
2630

2731

@@ -30,6 +34,7 @@
3034
MetaSymbolType,
3135
MetaOp,
3236
MetaVariable,
37+
MetaReificationError,
3338
meta_reify_iter,
3439
_metatize,
3540
metatize,
@@ -60,52 +65,68 @@ class MetaOpDefLibrary(object):
6065
}
6166
opdef_signatures = {}
6267

68+
def __init__(self):
69+
#
70+
# We need this in order to construct "Const" tensors directly, since
71+
# the "value" attr in a meta `NodeDef` is just a NumPy array and not
72+
# the `TensorProto` expected by `raw_ops.Const`.
73+
#
74+
def mt_const(value, dtype, name=None):
75+
return tf.raw_ops.Const(
76+
value=tensor_util.make_tensor_proto(value), dtype=dtype, name=name
77+
)
78+
79+
opdef = op_def_registry.get("Const")
80+
self.opdef_signatures[opdef.name] = self.make_opdef_sig(opdef, mt_const)
81+
6382
@classmethod
64-
def apply_op(cls, *args, **kwargs):
65-
return op_def_library.apply_op(*args, **kwargs)
83+
def get_op_info(cls, opdef):
84+
"""Return the TF Python API function signature for a given `OpDef`.
85+
86+
Parameter
87+
---------
88+
opdef: str or `OpDef` object (meta or base)
89+
"""
90+
if isinstance(opdef, str):
91+
opdef_name = opdef
92+
opdef = op_def_registry.get(opdef_name)
93+
else:
94+
opdef_name = opdef.name
95+
96+
opdef_sig = cls.opdef_signatures.get(opdef_name, None)
97+
98+
if opdef_sig is None and opdef is not None:
99+
opdef_func = getattr(tf.raw_ops, opdef.name, None)
100+
opdef_sig = cls.make_opdef_sig(opdef, opdef_func)
101+
cls.opdef_signatures[opdef.name] = opdef_sig
102+
103+
return opdef_sig
66104

67105
@classmethod
68106
def make_opdef_sig(cls, opdef, opdef_py_func=None):
69107
"""Create a `Signature` object for an `OpDef`.
70108
71109
Annotations are include so that one can partially verify arguments.
72110
"""
73-
input_args = OrderedDict([(a.name, a.type or a.type_attr) for a in opdef.input_arg])
74-
attrs = OrderedDict([(a.name, a) for a in opdef.attr])
75-
76-
params = OrderedDict()
77111
if opdef_py_func:
112+
#
78113
# We assume we're dealing with a function from `tf.raw_ops`.
79-
# Those functions have only the necessary `input_arg`s and
80-
# `attr` inputs as arguments.
114+
# Those functions have only the necessary `input_arg`s and `attr`
115+
# inputs as arguments.
116+
#
81117
opdef_func_sig = Signature.from_callable(opdef_py_func)
82118
params = opdef_func_sig.parameters
83119

84-
# for name, param in opdef_func_sig.parameters.items():
85-
# # We make positional parameters permissible (since the
86-
# # functions in `tf.raw_ops` are keyword-only), and we use the
87-
# # `tf.raw_ops` arguments to determine the *actual* required
88-
# # arguments (because `OpDef`'s `input_arg`s and `attrs` aren't
89-
# # exactly clear about that).
90-
# if name in input_args:
91-
# new_default = Parameter.empty
92-
# new_annotation = input_args[name]
93-
# else:
94-
# new_default = None
95-
# new_annotation = attrs.get(name, None)
96-
# if new_annotation is not None:
97-
# new_annotation = new_annotation.type
120+
else:
98121
#
99-
# new_param = param.replace(
100-
# kind=Parameter.POSITIONAL_OR_KEYWORD,
101-
# default=new_default,
102-
# annotation=new_annotation,
103-
# )
104-
# params[name] = new_param
122+
# We're crafting an `Operation` at a low-level via `apply_op`
123+
# (like the functions in `tf.raw_ops` do)
124+
#
125+
input_args = OrderedDict([(a.name, a.type or a.type_attr) for a in opdef.input_arg])
126+
attrs = OrderedDict([(a.name, a) for a in opdef.attr])
127+
params = OrderedDict()
105128

106-
else:
107-
# We're crafting the Operation at a low-level via `apply_op`.
108-
opdef_py_func = partial(op_def_lib.apply_op, opdef.name)
129+
opdef_py_func = partial(op_def_library.apply_op, opdef.name)
109130

110131
for i_name, i_type in input_args.items():
111132
p = Parameter(i_name, Parameter.POSITIONAL_OR_KEYWORD, annotation=i_type)
@@ -144,29 +165,6 @@ def make_opdef_sig(cls, opdef, opdef_py_func=None):
144165
)
145166
return opdef_sig, opdef_py_func
146167

147-
@classmethod
148-
def get_op_info(cls, opdef):
149-
"""Return the TF Python API function signature for a given `OpDef`.
150-
151-
Parameter
152-
---------
153-
opdef: str or `OpDef` object (meta or base)
154-
"""
155-
if isinstance(opdef, str):
156-
opdef_name = opdef
157-
opdef = op_def_registry.get(opdef_name)
158-
else:
159-
opdef_name = opdef.name
160-
161-
opdef_sig = cls.opdef_signatures.get(opdef_name, None)
162-
163-
if opdef_sig is None and opdef is not None:
164-
opdef_func = getattr(tf.raw_ops, opdef.name, None)
165-
opdef_sig = cls.make_opdef_sig(opdef, opdef_func)
166-
cls.opdef_signatures[opdef.name] = cls.make_opdef_sig(opdef, opdef_func)
167-
168-
return opdef_sig
169-
170168

171169
op_def_lib = MetaOpDefLibrary()
172170

@@ -183,7 +181,6 @@ def _metatize_tf_object(obj):
183181
def load_dispatcher():
184182
"""Set/override dispatcher to default to TF objects."""
185183

186-
from tensorflow.python.framework.ops import EagerTensor
187184
from tensorflow.python.ops.gen_linalg_ops import _SvdOutput
188185

189186
def _metatize_tf_svd(obj):
@@ -200,7 +197,7 @@ def _metatize_tf_eager(obj):
200197
" (e.g. within `tensorflow.python.eager.context.graph_mode`)"
201198
)
202199

203-
meta._metatize.add((EagerTensor,), _metatize_tf_eager)
200+
meta._metatize.add((ops.EagerTensor,), _metatize_tf_eager)
204201

205202
meta._metatize.add((object,), _metatize_tf_object)
206203
meta._metatize.add((HashableNDArray,), _metatize_tf_object)
@@ -599,12 +596,30 @@ def reify(self):
599596
)
600597

601598
if not (op_inputs_unreified or op_attrs_unreified or isvar(self.name)):
602-
603-
apply_arguments = operator.input_args(*op_inputs, name=self.name, **op_attrs)
604-
tf_out = operator._apply_func(**apply_arguments)
605-
op_tf = tf_out.op
606-
607-
# TODO: Update NodeDef attrs?
599+
#
600+
# An operation with this name might already exist in the graph
601+
#
602+
try:
603+
existing_op = ops.get_default_graph().get_operation_by_name(self.name)
604+
except KeyError:
605+
#
606+
# There is no such `Operation`, so we attempt to create it
607+
#
608+
apply_arguments = operator.input_args(*op_inputs, name=self.name, **op_attrs)
609+
tf_out = operator._apply_func(**apply_arguments)
610+
op_tf = tf_out.op
611+
else:
612+
#
613+
# An `Operation` with this name exists, let's make sure it's
614+
# equivalent to this meta `Operation`
615+
#
616+
if self != mt(existing_op):
617+
raise MetaReificationError(
618+
f"An Operation with the name {self.name}"
619+
" already exists in the graph and is not"
620+
" equal to this meta object."
621+
)
622+
op_tf = existing_op
608623

609624
assert op_tf is not None
610625
self._obj = op_tf
@@ -1149,4 +1164,5 @@ def __getattr__(self, obj):
11491164

11501165
mt = TFlowMetaAccessor()
11511166

1167+
11521168
load_dispatcher()

tests/tensorflow/test_meta.py

Lines changed: 54 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,16 @@
1515
from unification import var, isvar
1616

1717
from symbolic_pymc.utils import HashableNDArray
18-
from symbolic_pymc.meta import MetaSymbol, disable_auto_reification, enable_lvar_defaults
18+
from symbolic_pymc.meta import (MetaSymbol, disable_auto_reification,
19+
enable_lvar_defaults)
1920
from symbolic_pymc.tensorflow.meta import (TFlowMetaTensor,
2021
TFlowMetaTensorShape,
2122
TFlowMetaOp,
2223
TFlowMetaOpDef,
2324
TFlowMetaNodeDef,
2425
TFlowMetaOperator,
2526
MetaOpDefLibrary,
27+
MetaReificationError,
2628
mt)
2729

2830
from tests.tensorflow import run_in_graph_mode
@@ -212,7 +214,7 @@ def test_meta_basic():
212214

213215

214216
@run_in_graph_mode
215-
def test_meta_Op():
217+
def test_meta_operation():
216218

217219
t1_tf = tf.convert_to_tensor([[1, 2, 3], [4, 5, 6]])
218220
t2_tf = tf.convert_to_tensor([[7, 8, 9], [10, 11, 12]])
@@ -673,3 +675,53 @@ def test_global_options():
673675
with tf.Graph().as_default(), enable_lvar_defaults('names'):
674676
a_mt = mt(1.0)
675677
assert isvar(a_mt.name)
678+
679+
680+
@run_in_graph_mode
681+
def test_meta_const():
682+
"""Make sure we can create a Const tensor by hand."""
683+
684+
with tf.Graph().as_default():
685+
one_mt = mt.const(1, 'int32', 'Const')
686+
687+
with tf.Graph().as_default():
688+
another_one_mt = mt(1)
689+
690+
assert one_mt == another_one_mt
691+
assert isinstance(one_mt.reify(), tf.Tensor)
692+
assert one_mt.reify().op.type == 'Const'
693+
694+
695+
@run_in_graph_mode
696+
def test_meta_existing_names():
697+
698+
with tf.Graph().as_default():
699+
one_mt = mt(1)
700+
assert one_mt.op.name == 'Const'
701+
702+
# Clear-out the associated base variable
703+
orig_one_tf = one_mt._obj
704+
one_mt.reset()
705+
one_mt.op.reset()
706+
assert one_mt.obj is None
707+
assert one_mt.op.obj is None
708+
709+
# Attempt to reify to a base variable
710+
one_tf = one_mt.reify()
711+
assert one_tf.op.name == 'Const'
712+
# Make sure it's the first base variable we created
713+
assert orig_one_tf is one_tf
714+
715+
two_mt = mt(2)
716+
two_mt.op.node_def.name = 'Const'
717+
718+
# TODO FIXME: We shouldn't have to do this manually after changing a
719+
# dependency.
720+
two_mt.reset()
721+
two_mt.op.reset()
722+
assert two_mt.obj is None
723+
assert two_mt.op.obj is None
724+
assert two_mt.op.name == 'Const'
725+
726+
with pytest.raises(MetaReificationError):
727+
two_mt.reify()

0 commit comments

Comments
 (0)