Skip to content

Commit a17a14d

Browse files
[tensorflow] Introduce default tensor names
When the default tensor string name type, `DefaultTensorName`, is used, the `TFlowMetaTensor.reify` will not assume that the string name is "strict" and rely upon the base graph to assign a unique name derived from the default one. Closes #93.
1 parent b504958 commit a17a14d

File tree

3 files changed

+100
-58
lines changed

3 files changed

+100
-58
lines changed

symbolic_pymc/tensorflow/meta.py

Lines changed: 53 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,12 @@
4848
tf_metatize_cache = Cache(50)
4949

5050

51+
class DefaultTensorName(str):
52+
"""A type used to indicate a default tensor name."""
53+
54+
pass
55+
56+
5157
class MetaOpDefLibrary(object):
5258
"""A singleton-like object that holds correspondences between TF Python API functions and the `OpDef`s they construct.
5359
@@ -366,10 +372,16 @@ def _protobuf_convert(cls, k, v):
366372
raise TypeError(f"Could not convert {k}")
367373

368374
def __init__(self, op, name, attr, obj=None):
375+
"""Create a TF meta NodeDef.
376+
377+
XXX: Meta NodeDefs with `name == None` have a special meaning;
378+
their names are uniquely generated. We still consider them equal
379+
(when every other property is equal, of course).
380+
"""
369381
super().__init__(obj=obj)
370382
self.op = metatize(op)
371383
assert name is not None
372-
self.name = name if isvar(name) else str(name)
384+
self.name = name if isvar(name) else name
373385

374386
if not isvar(attr):
375387
opdef_sig, _ = op_def_lib.get_op_info(self.op)
@@ -600,6 +612,11 @@ def reify(self):
600612
# An operation with this name might already exist in the graph
601613
#
602614
try:
615+
# FIXME: Lame hack
616+
if isinstance(self.name, DefaultTensorName):
617+
# Use a unique version of the default name.
618+
raise KeyError()
619+
603620
existing_op = ops.get_default_graph().get_operation_by_name(self.name)
604621
except KeyError:
605622
#
@@ -613,7 +630,15 @@ def reify(self):
613630
# An `Operation` with this name exists, let's make sure it's
614631
# equivalent to this meta `Operation`
615632
#
616-
if self != mt(existing_op):
633+
existing_op_mt = mt(existing_op)
634+
635+
# # Since we can't exactly reproduce all NodeDef.attr information
636+
# # (e.g. dtypes), we need to remove any unnecessary NodeDef.attr
637+
# # fields from comparisons with same-named nodes in the graph.
638+
# if op_attrs.keys() != node_attr.keys():
639+
# existing_op_mt.node_def.attr = node_attr
640+
641+
if self != existing_op_mt:
617642
raise MetaReificationError(
618643
f"An Operation with the name {self.name}"
619644
" already exists in the graph and is not"
@@ -725,40 +750,40 @@ def reify(self):
725750

726751
def __truediv__(self, y):
727752
# TODO: TF performs some dtype logic (using `dtype.base_dtype`) and casting here.
728-
return mt.realdiv(self, y, name="truediv")
753+
return mt.realdiv(self, y, name=DefaultTensorName("truediv"))
729754

730755
def __rtruediv__(self, x):
731756
# TODO: TF performs some dtype logic (using `dtype.base_dtype`) and casting here.
732-
return mt.realdiv(x, self, name="truediv")
757+
return mt.realdiv(x, self, name=DefaultTensorName("truediv"))
733758

734759
def __add__(self, y):
735760
# TODO: If `self.dtype == tf.dtypes.string`, use `mt.add`
736-
return mt.addv2(self, y, name="add")
761+
return mt.addv2(self, y, name=DefaultTensorName("add"))
737762

738763
def __radd__(self, x):
739764
# TODO: If `x.dtype == tf.dtypes.string`, use `mt.add`
740-
return mt.addv2(x, self, name="add")
765+
return mt.addv2(x, self, name=DefaultTensorName("add"))
741766

742767
def __sub__(self, y):
743-
return mt.sub(self, y, name="sub")
768+
return mt.sub(self, y, name=DefaultTensorName("sub"))
744769

745770
def __rsub__(self, x):
746-
return mt.sub(x, self, name="sub")
771+
return mt.sub(x, self, name=DefaultTensorName("sub"))
747772

748773
def __mul__(self, y):
749-
return mt.mul(self, y, name="mul")
774+
return mt.mul(self, y, name=DefaultTensorName("mul"))
750775

751776
def __rmul__(self, x):
752-
return mt.mul(x, self, name="mul")
777+
return mt.mul(x, self, name=DefaultTensorName("mul"))
753778

754779
def __abs__(self):
755-
return mt.abs(self, name="Abs")
780+
return mt.abs(self, name=DefaultTensorName("Abs"))
756781

757782
def __pow__(self, y):
758-
return mt.pow(self, y, name="pow")
783+
return mt.pow(self, y, name=DefaultTensorName("pow"))
759784

760785
def __neg__(self):
761-
return mt.neg(self, name="Neg")
786+
return mt.neg(self, name=DefaultTensorName("Neg"))
762787

763788

764789
class TFlowMetaTensorShape(TFlowMetaSymbol):
@@ -987,48 +1012,22 @@ def __api_call__(self, *args, **kwargs):
9871012

9881013
if not op_args_unreified:
9891014

990-
res_var = None
991-
# name = op_args.get("name", None)
992-
#
993-
# if name is not None:
994-
# #
995-
# # An operation with this name might already exist in the graph
996-
# #
9971015
#
998-
# from tensorflow.python.framework import ops
1016+
# We create the `Operation` in the graph
9991017
#
1000-
# try:
1001-
# this_op = ops.get_default_graph().get_operation_by_name(name)
1002-
# except KeyError:
1003-
# pass
1004-
# else:
1005-
# # TODO: Make sure the existing `Operation` matches our arguments
1006-
# assert this_op.type == self.op_def.obj.name
1007-
#
1008-
# this_op = mt(this_op)
1009-
# op_inputs, op_node_def = self.op_args_to_operation_inputs(op_args)
1010-
# assert op_inputs == this_op.inputs
1011-
# assert op_node_def == this_op.node_def
1012-
# res_var = this_op.default_output
1013-
1014-
if res_var is None:
1015-
#
1016-
# We create the `Operation` in the graph
1017-
#
1018-
1019-
tf_out = self._apply_func(**op_args)
1020-
1021-
# Ensure that the original meta objects will be available
1022-
# for use in the `metatize` that follows
1023-
tf_metatize_cache.update(
1024-
{
1025-
k: v
1026-
for k, v in zip(op_args.values(), apply_arguments.values())
1027-
if isinstance(k, tf.Tensor)
1028-
}
1029-
)
1018+
tf_out = self._apply_func(**op_args)
1019+
1020+
# Ensure that the original meta objects will be available
1021+
# for use in the `metatize` that follows
1022+
tf_metatize_cache.update(
1023+
{
1024+
k: v
1025+
for k, v in zip(op_args.values(), apply_arguments.values())
1026+
if isinstance(k, tf.Tensor)
1027+
}
1028+
)
10301029

1031-
res_var = metatize(tf_out)
1030+
res_var = metatize(tf_out)
10321031

10331032
if "names" in meta._lvar_defaults_enabled:
10341033
# This should also reset the NodeDef's `obj`
@@ -1073,7 +1072,8 @@ def op_args_to_operation_inputs(self, apply_arguments):
10731072
node_attr = var()
10741073

10751074
if "names" not in meta._lvar_defaults_enabled:
1076-
op_name = apply_arguments.get("name", op_def_tf.name) or op_def_tf.name
1075+
default_name = DefaultTensorName(op_def_tf.name)
1076+
op_name = apply_arguments.get("name", default_name) or default_name
10771077
else:
10781078
op_name = var()
10791079

tests/tensorflow/test_meta.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
TFlowMetaOperator,
2626
MetaOpDefLibrary,
2727
MetaReificationError,
28+
DefaultTensorName,
2829
mt)
2930

3031
from tests.tensorflow import run_in_graph_mode
@@ -636,7 +637,7 @@ def test_global_options():
636637
with tf.Graph().as_default(), disable_auto_reification():
637638
y_mt = mt.Placeholder('float')
638639
assert y_mt.obj is None
639-
assert y_mt.name == 'Placeholder:0'
640+
assert isinstance(y_mt.op.name, DefaultTensorName)
640641
assert isinstance(y_mt.op.node_def.attr, dict)
641642

642643
with tf.Graph().as_default(), enable_lvar_defaults('names', 'node_attrs'):
@@ -706,7 +707,7 @@ def test_meta_const():
706707
@run_in_graph_mode
707708
def test_meta_existing_names():
708709

709-
with tf.Graph().as_default():
710+
with tf.Graph().as_default() as test_graph:
710711
one_mt = mt(1)
711712
assert one_mt.op.name == 'Const'
712713

@@ -723,6 +724,7 @@ def test_meta_existing_names():
723724
# Make sure it's the first base variable we created
724725
assert orig_one_tf is one_tf
725726

727+
# FYI: This implicitly creates 'Const_1'
726728
two_mt = mt(2)
727729
two_mt.op.node_def.name = 'Const'
728730

@@ -736,3 +738,16 @@ def test_meta_existing_names():
736738

737739
with pytest.raises(MetaReificationError):
738740
two_mt.reify()
741+
742+
another_one_mt = TFlowMetaOperator('Const', None)(3, var())
743+
# The following is something that would happen as a result of
744+
# reification (of the lvar in the meta object, not the meta object
745+
# itself).
746+
another_one_mt.op.node_def.attr['dtype'] = tf.int32
747+
748+
assert another_one_mt.op.name == 'Const'
749+
assert isinstance(another_one_mt.op.name, DefaultTensorName)
750+
# We need to make sure that the reified meta object actually uses a
751+
# unique name.
752+
assert isinstance(another_one_mt.reify(), tf.Tensor)
753+
assert another_one_mt.reify().op.name == 'Const_2'

tests/tensorflow/test_unify.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@ def test_etuple_term():
114114
# TODO FIXME: Because of the above two, this errs
115115
# add_lvar_et = etuplize(add_lvar_mt)
116116

117+
117118
@run_in_graph_mode
118119
def test_basic_unify_reify():
119120
# Test reification with manually constructed replacements
@@ -127,8 +128,11 @@ def test_basic_unify_reify():
127128

128129
test_expr = mt.add(tf.constant(1, dtype=tf.float64),
129130
mt.mul(tf.constant(2, dtype=tf.float64),
130-
x_l))
131-
test_reify_res = reify(test_expr, {x_l: a})
131+
x_l, name=var('mul_name')),
132+
name=var('add_name'))
133+
test_reify_res = reify(test_expr, {x_l: a,
134+
var('add_name'): 'Add_10',
135+
var('mul_name'): 'Mul_10'})
132136
test_base_res = test_reify_res.reify()
133137
assert isinstance(test_base_res, tf.Tensor)
134138

@@ -141,7 +145,7 @@ def test_basic_unify_reify():
141145
# Simply make sure that unification succeeds
142146
meta_expected_res = mt(expected_res)
143147
s_test = unify(test_expr, meta_expected_res, {})
144-
assert len(s_test) == 3
148+
assert len(s_test) == 5
145149

146150
assert reify(test_expr, s_test) == meta_expected_res
147151

@@ -199,3 +203,26 @@ def test_sexp_unify_reify():
199203
# Now, the second, `A . y`
200204
assert z_dist_tf.op.inputs[1].op.inputs[0] == A
201205
assert z_dist_tf.op.inputs[1].op.inputs[1] == y
206+
207+
208+
@run_in_graph_mode
209+
@pytest.mark.xfail(strict=True)
210+
def test_unique_names():
211+
212+
first_div_mt = mt(1) / mt(2)
213+
214+
assert first_div_mt.op.name == 'truediv'
215+
assert first_div_mt.reify().op.name
216+
217+
div_lv = mt.realdiv(var('b'), var('c'), name=var('name'))
218+
# Unify with the TF graph, then reify
219+
s = unify(first_div_mt.reify(), div_lv)
220+
221+
s[var('b')] = 1
222+
s[var('b')] = 3
223+
224+
div_mt = reify(div_lv, s)
225+
226+
assert div_mt.op.name == 'truediv'
227+
assert isinstance(div_mt.reify(), tf.Tensor)
228+
assert first_div_mt.reify() != div_mt.reify()

0 commit comments

Comments
 (0)