16
16
17
17
from google .protobuf .message import Message
18
18
19
- from tensorflow .python .framework import tensor_util , op_def_registry , op_def_library , tensor_shape
19
+ from tensorflow .python .framework import (
20
+ tensor_util ,
21
+ op_def_registry ,
22
+ op_def_library ,
23
+ tensor_shape ,
24
+ ops ,
25
+ )
20
26
from tensorflow .core .framework .op_def_pb2 import OpDef
21
27
from tensorflow .core .framework .node_def_pb2 import NodeDef
22
28
23
- # from tensorflow.core.framework.tensor_shape_pb2 import TensorShapeProto
24
-
25
29
from tensorflow_probability import distributions as tfd
26
30
27
31
30
34
MetaSymbolType ,
31
35
MetaOp ,
32
36
MetaVariable ,
37
+ MetaReificationError ,
33
38
meta_reify_iter ,
34
39
_metatize ,
35
40
metatize ,
@@ -61,51 +66,53 @@ class MetaOpDefLibrary(object):
61
66
opdef_signatures = {}
62
67
63
68
@classmethod
64
- def apply_op (cls , * args , ** kwargs ):
65
- return op_def_library .apply_op (* args , ** kwargs )
69
+ def get_op_info (cls , opdef ):
70
+ """Return the TF Python API function signature for a given `OpDef`.
71
+
72
+ Parameter
73
+ ---------
74
+ opdef: str or `OpDef` object (meta or base)
75
+ """
76
+ if isinstance (opdef , str ):
77
+ opdef_name = opdef
78
+ opdef = op_def_registry .get (opdef_name )
79
+ else :
80
+ opdef_name = opdef .name
81
+
82
+ opdef_sig = cls .opdef_signatures .get (opdef_name , None )
83
+
84
+ if opdef_sig is None and opdef is not None :
85
+ opdef_func = getattr (tf .raw_ops , opdef .name , None )
86
+ opdef_sig = cls .make_opdef_sig (opdef , opdef_func )
87
+ cls .opdef_signatures [opdef .name ] = opdef_sig
88
+
89
+ return opdef_sig
66
90
67
91
@classmethod
68
92
def make_opdef_sig (cls , opdef , opdef_py_func = None ):
69
93
"""Create a `Signature` object for an `OpDef`.
70
94
71
95
Annotations are include so that one can partially verify arguments.
72
96
"""
73
- input_args = OrderedDict ([(a .name , a .type or a .type_attr ) for a in opdef .input_arg ])
74
- attrs = OrderedDict ([(a .name , a ) for a in opdef .attr ])
75
-
76
- params = OrderedDict ()
77
97
if opdef_py_func :
98
+ #
78
99
# We assume we're dealing with a function from `tf.raw_ops`.
79
- # Those functions have only the necessary `input_arg`s and
80
- # `attr` inputs as arguments.
100
+ # Those functions have only the necessary `input_arg`s and `attr`
101
+ # inputs as arguments.
102
+ #
81
103
opdef_func_sig = Signature .from_callable (opdef_py_func )
82
104
params = opdef_func_sig .parameters
83
105
84
- # for name, param in opdef_func_sig.parameters.items():
85
- # # We make positional parameters permissible (since the
86
- # # functions in `tf.raw_ops` are keyword-only), and we use the
87
- # # `tf.raw_ops` arguments to determine the *actual* required
88
- # # arguments (because `OpDef`'s `input_arg`s and `attrs` aren't
89
- # # exactly clear about that).
90
- # if name in input_args:
91
- # new_default = Parameter.empty
92
- # new_annotation = input_args[name]
93
- # else:
94
- # new_default = None
95
- # new_annotation = attrs.get(name, None)
96
- # if new_annotation is not None:
97
- # new_annotation = new_annotation.type
106
+ else :
107
+ #
108
+ # We're crafting an `Operation` at a low-level via `apply_op`
109
+ # (like the functions in `tf.raw_ops` do)
98
110
#
99
- # new_param = param.replace(
100
- # kind=Parameter.POSITIONAL_OR_KEYWORD,
101
- # default=new_default,
102
- # annotation=new_annotation,
103
- # )
104
- # params[name] = new_param
111
+ input_args = OrderedDict ([(a .name , a .type or a .type_attr ) for a in opdef .input_arg ])
112
+ attrs = OrderedDict ([(a .name , a ) for a in opdef .attr ])
113
+ params = OrderedDict ()
105
114
106
- else :
107
- # We're crafting the Operation at a low-level via `apply_op`.
108
- opdef_py_func = partial (op_def_lib .apply_op , opdef .name )
115
+ opdef_py_func = partial (op_def_library .apply_op , opdef .name )
109
116
110
117
for i_name , i_type in input_args .items ():
111
118
p = Parameter (i_name , Parameter .POSITIONAL_OR_KEYWORD , annotation = i_type )
@@ -144,29 +151,6 @@ def make_opdef_sig(cls, opdef, opdef_py_func=None):
144
151
)
145
152
return opdef_sig , opdef_py_func
146
153
147
- @classmethod
148
- def get_op_info (cls , opdef ):
149
- """Return the TF Python API function signature for a given `OpDef`.
150
-
151
- Parameter
152
- ---------
153
- opdef: str or `OpDef` object (meta or base)
154
- """
155
- if isinstance (opdef , str ):
156
- opdef_name = opdef
157
- opdef = op_def_registry .get (opdef_name )
158
- else :
159
- opdef_name = opdef .name
160
-
161
- opdef_sig = cls .opdef_signatures .get (opdef_name , None )
162
-
163
- if opdef_sig is None and opdef is not None :
164
- opdef_func = getattr (tf .raw_ops , opdef .name , None )
165
- opdef_sig = cls .make_opdef_sig (opdef , opdef_func )
166
- cls .opdef_signatures [opdef .name ] = cls .make_opdef_sig (opdef , opdef_func )
167
-
168
- return opdef_sig
169
-
170
154
171
155
op_def_lib = MetaOpDefLibrary ()
172
156
@@ -183,7 +167,6 @@ def _metatize_tf_object(obj):
183
167
def load_dispatcher ():
184
168
"""Set/override dispatcher to default to TF objects."""
185
169
186
- from tensorflow .python .framework .ops import EagerTensor
187
170
from tensorflow .python .ops .gen_linalg_ops import _SvdOutput
188
171
189
172
def _metatize_tf_svd (obj ):
@@ -200,7 +183,7 @@ def _metatize_tf_eager(obj):
200
183
" (e.g. within `tensorflow.python.eager.context.graph_mode`)"
201
184
)
202
185
203
- meta ._metatize .add ((EagerTensor ,), _metatize_tf_eager )
186
+ meta ._metatize .add ((ops . EagerTensor ,), _metatize_tf_eager )
204
187
205
188
meta ._metatize .add ((object ,), _metatize_tf_object )
206
189
meta ._metatize .add ((HashableNDArray ,), _metatize_tf_object )
@@ -599,12 +582,30 @@ def reify(self):
599
582
)
600
583
601
584
if not (op_inputs_unreified or op_attrs_unreified or isvar (self .name )):
602
-
603
- apply_arguments = operator .input_args (* op_inputs , name = self .name , ** op_attrs )
604
- tf_out = operator ._apply_func (** apply_arguments )
605
- op_tf = tf_out .op
606
-
607
- # TODO: Update NodeDef attrs?
585
+ #
586
+ # An operation with this name might already exist in the graph
587
+ #
588
+ try :
589
+ existing_op = ops .get_default_graph ().get_operation_by_name (self .name )
590
+ except KeyError :
591
+ #
592
+ # There is no such `Operation`, so we attempt to create it
593
+ #
594
+ apply_arguments = operator .input_args (* op_inputs , name = self .name , ** op_attrs )
595
+ tf_out = operator ._apply_func (** apply_arguments )
596
+ op_tf = tf_out .op
597
+ else :
598
+ #
599
+ # An `Operation` with this name exists, let's make sure it's
600
+ # equivalent to this meta `Operation`
601
+ #
602
+ if self != mt (existing_op ):
603
+ raise MetaReificationError (
604
+ f"An Operation with the name { self .name } "
605
+ " already exists in the graph and is not"
606
+ " equal to this meta object."
607
+ )
608
+ op_tf = existing_op
608
609
609
610
assert op_tf is not None
610
611
self ._obj = op_tf
@@ -1149,4 +1150,5 @@ def __getattr__(self, obj):
1149
1150
1150
1151
mt = TFlowMetaAccessor ()
1151
1152
1153
+
1152
1154
load_dispatcher ()
0 commit comments