16
16
17
17
from google .protobuf .message import Message
18
18
19
- from tensorflow .python .framework import tensor_util , op_def_registry , op_def_library , tensor_shape
19
+ from tensorflow .python .framework import (
20
+ tensor_util ,
21
+ op_def_registry ,
22
+ op_def_library ,
23
+ tensor_shape ,
24
+ ops ,
25
+ )
20
26
from tensorflow .core .framework .op_def_pb2 import OpDef
21
27
from tensorflow .core .framework .node_def_pb2 import NodeDef
22
28
23
- # from tensorflow.core.framework.tensor_shape_pb2 import TensorShapeProto
24
-
25
29
from tensorflow_probability import distributions as tfd
26
30
27
31
30
34
MetaSymbolType ,
31
35
MetaOp ,
32
36
MetaVariable ,
37
+ MetaReificationError ,
33
38
meta_reify_iter ,
34
39
_metatize ,
35
40
metatize ,
@@ -60,52 +65,68 @@ class MetaOpDefLibrary(object):
60
65
}
61
66
opdef_signatures = {}
62
67
68
+ def __init__ (self ):
69
+ #
70
+ # We need this in order to construct "Const" tensors directly, since
71
+ # the "value" attr in a meta `NodeDef` is just a NumPy array and not
72
+ # the `TensorProto` expected by `raw_ops.Const`.
73
+ #
74
+ def mt_const (value , dtype , name = None ):
75
+ return tf .raw_ops .Const (
76
+ value = tensor_util .make_tensor_proto (value ), dtype = dtype , name = name
77
+ )
78
+
79
+ opdef = op_def_registry .get ("Const" )
80
+ self .opdef_signatures [opdef .name ] = self .make_opdef_sig (opdef , mt_const )
81
+
63
82
@classmethod
64
- def apply_op (cls , * args , ** kwargs ):
65
- return op_def_library .apply_op (* args , ** kwargs )
83
+ def get_op_info (cls , opdef ):
84
+ """Return the TF Python API function signature for a given `OpDef`.
85
+
86
+ Parameter
87
+ ---------
88
+ opdef: str or `OpDef` object (meta or base)
89
+ """
90
+ if isinstance (opdef , str ):
91
+ opdef_name = opdef
92
+ opdef = op_def_registry .get (opdef_name )
93
+ else :
94
+ opdef_name = opdef .name
95
+
96
+ opdef_sig = cls .opdef_signatures .get (opdef_name , None )
97
+
98
+ if opdef_sig is None and opdef is not None :
99
+ opdef_func = getattr (tf .raw_ops , opdef .name , None )
100
+ opdef_sig = cls .make_opdef_sig (opdef , opdef_func )
101
+ cls .opdef_signatures [opdef .name ] = opdef_sig
102
+
103
+ return opdef_sig
66
104
67
105
@classmethod
68
106
def make_opdef_sig (cls , opdef , opdef_py_func = None ):
69
107
"""Create a `Signature` object for an `OpDef`.
70
108
71
109
Annotations are include so that one can partially verify arguments.
72
110
"""
73
- input_args = OrderedDict ([(a .name , a .type or a .type_attr ) for a in opdef .input_arg ])
74
- attrs = OrderedDict ([(a .name , a ) for a in opdef .attr ])
75
-
76
- params = OrderedDict ()
77
111
if opdef_py_func :
112
+ #
78
113
# We assume we're dealing with a function from `tf.raw_ops`.
79
- # Those functions have only the necessary `input_arg`s and
80
- # `attr` inputs as arguments.
114
+ # Those functions have only the necessary `input_arg`s and `attr`
115
+ # inputs as arguments.
116
+ #
81
117
opdef_func_sig = Signature .from_callable (opdef_py_func )
82
118
params = opdef_func_sig .parameters
83
119
84
- # for name, param in opdef_func_sig.parameters.items():
85
- # # We make positional parameters permissible (since the
86
- # # functions in `tf.raw_ops` are keyword-only), and we use the
87
- # # `tf.raw_ops` arguments to determine the *actual* required
88
- # # arguments (because `OpDef`'s `input_arg`s and `attrs` aren't
89
- # # exactly clear about that).
90
- # if name in input_args:
91
- # new_default = Parameter.empty
92
- # new_annotation = input_args[name]
93
- # else:
94
- # new_default = None
95
- # new_annotation = attrs.get(name, None)
96
- # if new_annotation is not None:
97
- # new_annotation = new_annotation.type
120
+ else :
98
121
#
99
- # new_param = param.replace(
100
- # kind=Parameter.POSITIONAL_OR_KEYWORD,
101
- # default=new_default,
102
- # annotation=new_annotation,
103
- # )
104
- # params[name] = new_param
122
+ # We're crafting an `Operation` at a low-level via `apply_op`
123
+ # (like the functions in `tf.raw_ops` do)
124
+ #
125
+ input_args = OrderedDict ([( a . name , a . type or a . type_attr ) for a in opdef . input_arg ])
126
+ attrs = OrderedDict ([( a . name , a ) for a in opdef . attr ] )
127
+ params = OrderedDict ()
105
128
106
- else :
107
- # We're crafting the Operation at a low-level via `apply_op`.
108
- opdef_py_func = partial (op_def_lib .apply_op , opdef .name )
129
+ opdef_py_func = partial (op_def_library .apply_op , opdef .name )
109
130
110
131
for i_name , i_type in input_args .items ():
111
132
p = Parameter (i_name , Parameter .POSITIONAL_OR_KEYWORD , annotation = i_type )
@@ -144,29 +165,6 @@ def make_opdef_sig(cls, opdef, opdef_py_func=None):
144
165
)
145
166
return opdef_sig , opdef_py_func
146
167
147
- @classmethod
148
- def get_op_info (cls , opdef ):
149
- """Return the TF Python API function signature for a given `OpDef`.
150
-
151
- Parameter
152
- ---------
153
- opdef: str or `OpDef` object (meta or base)
154
- """
155
- if isinstance (opdef , str ):
156
- opdef_name = opdef
157
- opdef = op_def_registry .get (opdef_name )
158
- else :
159
- opdef_name = opdef .name
160
-
161
- opdef_sig = cls .opdef_signatures .get (opdef_name , None )
162
-
163
- if opdef_sig is None and opdef is not None :
164
- opdef_func = getattr (tf .raw_ops , opdef .name , None )
165
- opdef_sig = cls .make_opdef_sig (opdef , opdef_func )
166
- cls .opdef_signatures [opdef .name ] = cls .make_opdef_sig (opdef , opdef_func )
167
-
168
- return opdef_sig
169
-
170
168
171
169
op_def_lib = MetaOpDefLibrary ()
172
170
@@ -183,7 +181,6 @@ def _metatize_tf_object(obj):
183
181
def load_dispatcher ():
184
182
"""Set/override dispatcher to default to TF objects."""
185
183
186
- from tensorflow .python .framework .ops import EagerTensor
187
184
from tensorflow .python .ops .gen_linalg_ops import _SvdOutput
188
185
189
186
def _metatize_tf_svd (obj ):
@@ -200,7 +197,7 @@ def _metatize_tf_eager(obj):
200
197
" (e.g. within `tensorflow.python.eager.context.graph_mode`)"
201
198
)
202
199
203
- meta ._metatize .add ((EagerTensor ,), _metatize_tf_eager )
200
+ meta ._metatize .add ((ops . EagerTensor ,), _metatize_tf_eager )
204
201
205
202
meta ._metatize .add ((object ,), _metatize_tf_object )
206
203
meta ._metatize .add ((HashableNDArray ,), _metatize_tf_object )
@@ -599,12 +596,30 @@ def reify(self):
599
596
)
600
597
601
598
if not (op_inputs_unreified or op_attrs_unreified or isvar (self .name )):
602
-
603
- apply_arguments = operator .input_args (* op_inputs , name = self .name , ** op_attrs )
604
- tf_out = operator ._apply_func (** apply_arguments )
605
- op_tf = tf_out .op
606
-
607
- # TODO: Update NodeDef attrs?
599
+ #
600
+ # An operation with this name might already exist in the graph
601
+ #
602
+ try :
603
+ existing_op = ops .get_default_graph ().get_operation_by_name (self .name )
604
+ except KeyError :
605
+ #
606
+ # There is no such `Operation`, so we attempt to create it
607
+ #
608
+ apply_arguments = operator .input_args (* op_inputs , name = self .name , ** op_attrs )
609
+ tf_out = operator ._apply_func (** apply_arguments )
610
+ op_tf = tf_out .op
611
+ else :
612
+ #
613
+ # An `Operation` with this name exists, let's make sure it's
614
+ # equivalent to this meta `Operation`
615
+ #
616
+ if self != mt (existing_op ):
617
+ raise MetaReificationError (
618
+ f"An Operation with the name { self .name } "
619
+ " already exists in the graph and is not"
620
+ " equal to this meta object."
621
+ )
622
+ op_tf = existing_op
608
623
609
624
assert op_tf is not None
610
625
self ._obj = op_tf
@@ -1149,4 +1164,5 @@ def __getattr__(self, obj):
1149
1164
1150
1165
mt = TFlowMetaAccessor ()
1151
1166
1167
+
1152
1168
load_dispatcher ()
0 commit comments