48
48
tf_metatize_cache = Cache (50 )
49
49
50
50
51
+ class DefaultTensorName (str ):
52
+ """A type used to indicate a default tensor name."""
53
+
54
+ pass
55
+
56
+
51
57
class MetaOpDefLibrary (object ):
52
58
"""A singleton-like object that holds correspondences between TF Python API functions and the `OpDef`s they construct.
53
59
@@ -366,10 +372,16 @@ def _protobuf_convert(cls, k, v):
366
372
raise TypeError (f"Could not convert { k } " )
367
373
368
374
def __init__ (self , op , name , attr , obj = None ):
375
+ """Create a TF meta NodeDef.
376
+
377
+ XXX: Meta NodeDefs with `name == None` have a special meaning;
378
+ their names are uniquely generated. We still consider them equal
379
+ (when every other property is equal, of course).
380
+ """
369
381
super ().__init__ (obj = obj )
370
382
self .op = metatize (op )
371
383
assert name is not None
372
- self .name = name if isvar (name ) else str ( name )
384
+ self .name = name if isvar (name ) else name
373
385
374
386
if not isvar (attr ):
375
387
opdef_sig , _ = op_def_lib .get_op_info (self .op )
@@ -600,6 +612,11 @@ def reify(self):
600
612
# An operation with this name might already exist in the graph
601
613
#
602
614
try :
615
+ # FIXME: Lame hack
616
+ if isinstance (self .name , DefaultTensorName ):
617
+ # Use a unique version of the default name.
618
+ raise KeyError ()
619
+
603
620
existing_op = ops .get_default_graph ().get_operation_by_name (self .name )
604
621
except KeyError :
605
622
#
@@ -613,7 +630,15 @@ def reify(self):
613
630
# An `Operation` with this name exists, let's make sure it's
614
631
# equivalent to this meta `Operation`
615
632
#
616
- if self != mt (existing_op ):
633
+ existing_op_mt = mt (existing_op )
634
+
635
+ # # Since we can't exactly reproduce all NodeDef.attr information
636
+ # # (e.g. dtypes), we need to remove any unnecessary NodeDef.attr
637
+ # # fields from comparisons with same-named nodes in the graph.
638
+ # if op_attrs.keys() != node_attr.keys():
639
+ # existing_op_mt.node_def.attr = node_attr
640
+
641
+ if self != existing_op_mt :
617
642
raise MetaReificationError (
618
643
f"An Operation with the name { self .name } "
619
644
" already exists in the graph and is not"
@@ -725,40 +750,40 @@ def reify(self):
725
750
726
751
def __truediv__ (self , y ):
727
752
# TODO: TF performs some dtype logic (using `dtype.base_dtype`) and casting here.
728
- return mt .realdiv (self , y , name = "truediv" )
753
+ return mt .realdiv (self , y , name = DefaultTensorName ( "truediv" ) )
729
754
730
755
def __rtruediv__ (self , x ):
731
756
# TODO: TF performs some dtype logic (using `dtype.base_dtype`) and casting here.
732
- return mt .realdiv (x , self , name = "truediv" )
757
+ return mt .realdiv (x , self , name = DefaultTensorName ( "truediv" ) )
733
758
734
759
def __add__ (self , y ):
735
760
# TODO: If `self.dtype == tf.dtypes.string`, use `mt.add`
736
- return mt .addv2 (self , y , name = "add" )
761
+ return mt .addv2 (self , y , name = DefaultTensorName ( "add" ) )
737
762
738
763
def __radd__ (self , x ):
739
764
# TODO: If `x.dtype == tf.dtypes.string`, use `mt.add`
740
- return mt .addv2 (x , self , name = "add" )
765
+ return mt .addv2 (x , self , name = DefaultTensorName ( "add" ) )
741
766
742
767
def __sub__ (self , y ):
743
- return mt .sub (self , y , name = "sub" )
768
+ return mt .sub (self , y , name = DefaultTensorName ( "sub" ) )
744
769
745
770
def __rsub__ (self , x ):
746
- return mt .sub (x , self , name = "sub" )
771
+ return mt .sub (x , self , name = DefaultTensorName ( "sub" ) )
747
772
748
773
def __mul__ (self , y ):
749
- return mt .mul (self , y , name = "mul" )
774
+ return mt .mul (self , y , name = DefaultTensorName ( "mul" ) )
750
775
751
776
def __rmul__ (self , x ):
752
- return mt .mul (x , self , name = "mul" )
777
+ return mt .mul (x , self , name = DefaultTensorName ( "mul" ) )
753
778
754
779
def __abs__ (self ):
755
- return mt .abs (self , name = "Abs" )
780
+ return mt .abs (self , name = DefaultTensorName ( "Abs" ) )
756
781
757
782
def __pow__ (self , y ):
758
- return mt .pow (self , y , name = "pow" )
783
+ return mt .pow (self , y , name = DefaultTensorName ( "pow" ) )
759
784
760
785
def __neg__ (self ):
761
- return mt .neg (self , name = "Neg" )
786
+ return mt .neg (self , name = DefaultTensorName ( "Neg" ) )
762
787
763
788
764
789
class TFlowMetaTensorShape (TFlowMetaSymbol ):
@@ -987,48 +1012,22 @@ def __api_call__(self, *args, **kwargs):
987
1012
988
1013
if not op_args_unreified :
989
1014
990
- res_var = None
991
- # name = op_args.get("name", None)
992
- #
993
- # if name is not None:
994
- # #
995
- # # An operation with this name might already exist in the graph
996
- # #
997
1015
#
998
- # from tensorflow.python.framework import ops
1016
+ # We create the `Operation` in the graph
999
1017
#
1000
- # try:
1001
- # this_op = ops.get_default_graph().get_operation_by_name(name)
1002
- # except KeyError:
1003
- # pass
1004
- # else:
1005
- # # TODO: Make sure the existing `Operation` matches our arguments
1006
- # assert this_op.type == self.op_def.obj.name
1007
- #
1008
- # this_op = mt(this_op)
1009
- # op_inputs, op_node_def = self.op_args_to_operation_inputs(op_args)
1010
- # assert op_inputs == this_op.inputs
1011
- # assert op_node_def == this_op.node_def
1012
- # res_var = this_op.default_output
1013
-
1014
- if res_var is None :
1015
- #
1016
- # We create the `Operation` in the graph
1017
- #
1018
-
1019
- tf_out = self ._apply_func (** op_args )
1020
-
1021
- # Ensure that the original meta objects will be available
1022
- # for use in the `metatize` that follows
1023
- tf_metatize_cache .update (
1024
- {
1025
- k : v
1026
- for k , v in zip (op_args .values (), apply_arguments .values ())
1027
- if isinstance (k , tf .Tensor )
1028
- }
1029
- )
1018
+ tf_out = self ._apply_func (** op_args )
1019
+
1020
+ # Ensure that the original meta objects will be available
1021
+ # for use in the `metatize` that follows
1022
+ tf_metatize_cache .update (
1023
+ {
1024
+ k : v
1025
+ for k , v in zip (op_args .values (), apply_arguments .values ())
1026
+ if isinstance (k , tf .Tensor )
1027
+ }
1028
+ )
1030
1029
1031
- res_var = metatize (tf_out )
1030
+ res_var = metatize (tf_out )
1032
1031
1033
1032
if "names" in meta ._lvar_defaults_enabled :
1034
1033
# This should also reset the NodeDef's `obj`
@@ -1073,7 +1072,8 @@ def op_args_to_operation_inputs(self, apply_arguments):
1073
1072
node_attr = var ()
1074
1073
1075
1074
if "names" not in meta ._lvar_defaults_enabled :
1076
- op_name = apply_arguments .get ("name" , op_def_tf .name ) or op_def_tf .name
1075
+ default_name = DefaultTensorName (op_def_tf .name )
1076
+ op_name = apply_arguments .get ("name" , default_name ) or default_name
1077
1077
else :
1078
1078
op_name = var ()
1079
1079
0 commit comments