Skip to content

Commit cdaf7e6

Browse files
Check TensorFlow Operation names during meta object reification
1 parent c84b6c1 commit cdaf7e6

File tree

3 files changed

+112
-67
lines changed

3 files changed

+112
-67
lines changed

symbolic_pymc/meta.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,12 @@ def meta_reify_iter(rands):
139139
return type(rands)(reified_rands), any_unreified
140140

141141

142+
class MetaReificationError(Exception):
143+
"""An exception type for errors encountered during the creation of base objects from meta objects."""
144+
145+
pass
146+
147+
142148
class MetaSymbolType(abc.ABCMeta):
143149
def __new__(cls, name, bases, clsdict):
144150

symbolic_pymc/tensorflow/meta.py

Lines changed: 67 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,16 @@
1616

1717
from google.protobuf.message import Message
1818

19-
from tensorflow.python.framework import tensor_util, op_def_registry, op_def_library, tensor_shape
19+
from tensorflow.python.framework import (
20+
tensor_util,
21+
op_def_registry,
22+
op_def_library,
23+
tensor_shape,
24+
ops,
25+
)
2026
from tensorflow.core.framework.op_def_pb2 import OpDef
2127
from tensorflow.core.framework.node_def_pb2 import NodeDef
2228

23-
# from tensorflow.core.framework.tensor_shape_pb2 import TensorShapeProto
24-
2529
from tensorflow_probability import distributions as tfd
2630

2731

@@ -30,6 +34,7 @@
3034
MetaSymbolType,
3135
MetaOp,
3236
MetaVariable,
37+
MetaReificationError,
3338
meta_reify_iter,
3439
_metatize,
3540
metatize,
@@ -61,51 +66,53 @@ class MetaOpDefLibrary(object):
6166
opdef_signatures = {}
6267

6368
@classmethod
64-
def apply_op(cls, *args, **kwargs):
65-
return op_def_library.apply_op(*args, **kwargs)
69+
def get_op_info(cls, opdef):
70+
"""Return the TF Python API function signature for a given `OpDef`.
71+
72+
Parameter
73+
---------
74+
opdef: str or `OpDef` object (meta or base)
75+
"""
76+
if isinstance(opdef, str):
77+
opdef_name = opdef
78+
opdef = op_def_registry.get(opdef_name)
79+
else:
80+
opdef_name = opdef.name
81+
82+
opdef_sig = cls.opdef_signatures.get(opdef_name, None)
83+
84+
if opdef_sig is None and opdef is not None:
85+
opdef_func = getattr(tf.raw_ops, opdef.name, None)
86+
opdef_sig = cls.make_opdef_sig(opdef, opdef_func)
87+
cls.opdef_signatures[opdef.name] = opdef_sig
88+
89+
return opdef_sig
6690

6791
@classmethod
6892
def make_opdef_sig(cls, opdef, opdef_py_func=None):
6993
"""Create a `Signature` object for an `OpDef`.
7094
7195
Annotations are include so that one can partially verify arguments.
7296
"""
73-
input_args = OrderedDict([(a.name, a.type or a.type_attr) for a in opdef.input_arg])
74-
attrs = OrderedDict([(a.name, a) for a in opdef.attr])
75-
76-
params = OrderedDict()
7797
if opdef_py_func:
98+
#
7899
# We assume we're dealing with a function from `tf.raw_ops`.
79-
# Those functions have only the necessary `input_arg`s and
80-
# `attr` inputs as arguments.
100+
# Those functions have only the necessary `input_arg`s and `attr`
101+
# inputs as arguments.
102+
#
81103
opdef_func_sig = Signature.from_callable(opdef_py_func)
82104
params = opdef_func_sig.parameters
83105

84-
# for name, param in opdef_func_sig.parameters.items():
85-
# # We make positional parameters permissible (since the
86-
# # functions in `tf.raw_ops` are keyword-only), and we use the
87-
# # `tf.raw_ops` arguments to determine the *actual* required
88-
# # arguments (because `OpDef`'s `input_arg`s and `attrs` aren't
89-
# # exactly clear about that).
90-
# if name in input_args:
91-
# new_default = Parameter.empty
92-
# new_annotation = input_args[name]
93-
# else:
94-
# new_default = None
95-
# new_annotation = attrs.get(name, None)
96-
# if new_annotation is not None:
97-
# new_annotation = new_annotation.type
106+
else:
107+
#
108+
# We're crafting an `Operation` at a low-level via `apply_op`
109+
# (like the functions in `tf.raw_ops` do)
98110
#
99-
# new_param = param.replace(
100-
# kind=Parameter.POSITIONAL_OR_KEYWORD,
101-
# default=new_default,
102-
# annotation=new_annotation,
103-
# )
104-
# params[name] = new_param
111+
input_args = OrderedDict([(a.name, a.type or a.type_attr) for a in opdef.input_arg])
112+
attrs = OrderedDict([(a.name, a) for a in opdef.attr])
113+
params = OrderedDict()
105114

106-
else:
107-
# We're crafting the Operation at a low-level via `apply_op`.
108-
opdef_py_func = partial(op_def_lib.apply_op, opdef.name)
115+
opdef_py_func = partial(op_def_library.apply_op, opdef.name)
109116

110117
for i_name, i_type in input_args.items():
111118
p = Parameter(i_name, Parameter.POSITIONAL_OR_KEYWORD, annotation=i_type)
@@ -144,29 +151,6 @@ def make_opdef_sig(cls, opdef, opdef_py_func=None):
144151
)
145152
return opdef_sig, opdef_py_func
146153

147-
@classmethod
148-
def get_op_info(cls, opdef):
149-
"""Return the TF Python API function signature for a given `OpDef`.
150-
151-
Parameter
152-
---------
153-
opdef: str or `OpDef` object (meta or base)
154-
"""
155-
if isinstance(opdef, str):
156-
opdef_name = opdef
157-
opdef = op_def_registry.get(opdef_name)
158-
else:
159-
opdef_name = opdef.name
160-
161-
opdef_sig = cls.opdef_signatures.get(opdef_name, None)
162-
163-
if opdef_sig is None and opdef is not None:
164-
opdef_func = getattr(tf.raw_ops, opdef.name, None)
165-
opdef_sig = cls.make_opdef_sig(opdef, opdef_func)
166-
cls.opdef_signatures[opdef.name] = cls.make_opdef_sig(opdef, opdef_func)
167-
168-
return opdef_sig
169-
170154

171155
op_def_lib = MetaOpDefLibrary()
172156

@@ -183,7 +167,6 @@ def _metatize_tf_object(obj):
183167
def load_dispatcher():
184168
"""Set/override dispatcher to default to TF objects."""
185169

186-
from tensorflow.python.framework.ops import EagerTensor
187170
from tensorflow.python.ops.gen_linalg_ops import _SvdOutput
188171

189172
def _metatize_tf_svd(obj):
@@ -200,7 +183,7 @@ def _metatize_tf_eager(obj):
200183
" (e.g. within `tensorflow.python.eager.context.graph_mode`)"
201184
)
202185

203-
meta._metatize.add((EagerTensor,), _metatize_tf_eager)
186+
meta._metatize.add((ops.EagerTensor,), _metatize_tf_eager)
204187

205188
meta._metatize.add((object,), _metatize_tf_object)
206189
meta._metatize.add((HashableNDArray,), _metatize_tf_object)
@@ -599,12 +582,30 @@ def reify(self):
599582
)
600583

601584
if not (op_inputs_unreified or op_attrs_unreified or isvar(self.name)):
602-
603-
apply_arguments = operator.input_args(*op_inputs, name=self.name, **op_attrs)
604-
tf_out = operator._apply_func(**apply_arguments)
605-
op_tf = tf_out.op
606-
607-
# TODO: Update NodeDef attrs?
585+
#
586+
# An operation with this name might already exist in the graph
587+
#
588+
try:
589+
existing_op = ops.get_default_graph().get_operation_by_name(self.name)
590+
except KeyError:
591+
#
592+
# There is no such `Operation`, so we attempt to create it
593+
#
594+
apply_arguments = operator.input_args(*op_inputs, name=self.name, **op_attrs)
595+
tf_out = operator._apply_func(**apply_arguments)
596+
op_tf = tf_out.op
597+
else:
598+
#
599+
# An `Operation` with this name exists, let's make sure it's
600+
# equivalent to this meta `Operation`
601+
#
602+
if self != mt(existing_op):
603+
raise MetaReificationError(
604+
f"An Operation with the name {self.name}"
605+
" already exists in the graph and is not"
606+
" equal to this meta object."
607+
)
608+
op_tf = existing_op
608609

609610
assert op_tf is not None
610611
self._obj = op_tf
@@ -1149,4 +1150,5 @@ def __getattr__(self, obj):
11491150

11501151
mt = TFlowMetaAccessor()
11511152

1153+
11521154
load_dispatcher()

tests/tensorflow/test_meta.py

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,16 @@
1515
from unification import var, isvar
1616

1717
from symbolic_pymc.utils import HashableNDArray
18-
from symbolic_pymc.meta import MetaSymbol, disable_auto_reification, enable_lvar_defaults
18+
from symbolic_pymc.meta import (MetaSymbol, disable_auto_reification,
19+
enable_lvar_defaults)
1920
from symbolic_pymc.tensorflow.meta import (TFlowMetaTensor,
2021
TFlowMetaTensorShape,
2122
TFlowMetaOp,
2223
TFlowMetaOpDef,
2324
TFlowMetaNodeDef,
2425
TFlowMetaOperator,
2526
MetaOpDefLibrary,
27+
MetaReificationError,
2628
mt)
2729

2830
from tests.tensorflow import run_in_graph_mode
@@ -212,7 +214,7 @@ def test_meta_basic():
212214

213215

214216
@run_in_graph_mode
215-
def test_meta_Op():
217+
def test_meta_operation():
216218

217219
t1_tf = tf.convert_to_tensor([[1, 2, 3], [4, 5, 6]])
218220
t2_tf = tf.convert_to_tensor([[7, 8, 9], [10, 11, 12]])
@@ -673,3 +675,38 @@ def test_global_options():
673675
with tf.Graph().as_default(), enable_lvar_defaults('names'):
674676
a_mt = mt(1.0)
675677
assert isvar(a_mt.name)
678+
679+
680+
@run_in_graph_mode
681+
def test_meta_existing_names():
682+
683+
with tf.Graph().as_default():
684+
one_mt = mt(1)
685+
assert one_mt.op.name == 'Const'
686+
687+
# Clear-out the associated base variable
688+
orig_one_tf = one_mt._obj
689+
one_mt.reset()
690+
one_mt.op.reset()
691+
assert one_mt.obj is None
692+
assert one_mt.op.obj is None
693+
694+
# Attempt to reify to a base variable
695+
one_tf = one_mt.reify()
696+
assert one_tf.op.name == 'Const'
697+
# Make sure it's the first base variable we created
698+
assert orig_one_tf is one_tf
699+
700+
two_mt = mt(2)
701+
two_mt.op.node_def.name = 'Const'
702+
703+
# TODO FIXME: We shouldn't have to do this manually after changing a
704+
# dependency.
705+
two_mt.reset()
706+
two_mt.op.reset()
707+
assert two_mt.obj is None
708+
assert two_mt.op.obj is None
709+
assert two_mt.op.name == 'Const'
710+
711+
with pytest.raises(MetaReificationError):
712+
two_mt.reify()

0 commit comments

Comments
 (0)