Skip to content

Commit 2cdb61b

Browse files
[tensorflow] Allow unspecified NodeDef names
`NodeDef` name values can now be `None`, which means that reification will use the next available unique name in the default graph. Closes #93.
1 parent b504958 commit 2cdb61b

File tree

3 files changed

+53
-49
lines changed

3 files changed

+53
-49
lines changed

symbolic_pymc/tensorflow/meta.py

Lines changed: 32 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -366,10 +366,15 @@ def _protobuf_convert(cls, k, v):
366366
raise TypeError(f"Could not convert {k}")
367367

368368
def __init__(self, op, name, attr, obj=None):
369+
"""Create a TF meta NodeDef.
370+
371+
XXX: Meta NodeDefs with `name == None` have a special meaning;
372+
their names are uniquely generated. We still consider them equal
373+
(when every other property is equal, of course).
374+
"""
369375
super().__init__(obj=obj)
370376
self.op = metatize(op)
371-
assert name is not None
372-
self.name = name if isvar(name) else str(name)
377+
self.name = name if isvar(name) else name
373378

374379
if not isvar(attr):
375380
opdef_sig, _ = op_def_lib.get_op_info(self.op)
@@ -601,7 +606,7 @@ def reify(self):
601606
#
602607
try:
603608
existing_op = ops.get_default_graph().get_operation_by_name(self.name)
604-
except KeyError:
609+
except (KeyError, TypeError):
605610
#
606611
# There is no such `Operation`, so we attempt to create it
607612
#
@@ -613,7 +618,15 @@ def reify(self):
613618
# An `Operation` with this name exists, let's make sure it's
614619
# equivalent to this meta `Operation`
615620
#
616-
if self != mt(existing_op):
621+
existing_op_mt = mt(existing_op)
622+
623+
# # Since we can't exactly reproduce all NodeDef.attr information
624+
# # (e.g. dtypes), we need to remove any unnecessary NodeDef.attr
625+
# # fields from comparisons with same-named nodes in the graph.
626+
# if op_attrs.keys() != node_attr.keys():
627+
# existing_op_mt.node_def.attr = node_attr
628+
629+
if self != existing_op_mt:
617630
raise MetaReificationError(
618631
f"An Operation with the name {self.name}"
619632
" already exists in the graph and is not"
@@ -987,48 +1000,22 @@ def __api_call__(self, *args, **kwargs):
9871000

9881001
if not op_args_unreified:
9891002

990-
res_var = None
991-
# name = op_args.get("name", None)
9921003
#
993-
# if name is not None:
994-
# #
995-
# # An operation with this name might already exist in the graph
996-
# #
1004+
# We create the `Operation` in the graph
9971005
#
998-
# from tensorflow.python.framework import ops
999-
#
1000-
# try:
1001-
# this_op = ops.get_default_graph().get_operation_by_name(name)
1002-
# except KeyError:
1003-
# pass
1004-
# else:
1005-
# # TODO: Make sure the existing `Operation` matches our arguments
1006-
# assert this_op.type == self.op_def.obj.name
1007-
#
1008-
# this_op = mt(this_op)
1009-
# op_inputs, op_node_def = self.op_args_to_operation_inputs(op_args)
1010-
# assert op_inputs == this_op.inputs
1011-
# assert op_node_def == this_op.node_def
1012-
# res_var = this_op.default_output
1013-
1014-
if res_var is None:
1015-
#
1016-
# We create the `Operation` in the graph
1017-
#
1018-
1019-
tf_out = self._apply_func(**op_args)
1020-
1021-
# Ensure that the original meta objects will be available
1022-
# for use in the `metatize` that follows
1023-
tf_metatize_cache.update(
1024-
{
1025-
k: v
1026-
for k, v in zip(op_args.values(), apply_arguments.values())
1027-
if isinstance(k, tf.Tensor)
1028-
}
1029-
)
1006+
tf_out = self._apply_func(**op_args)
1007+
1008+
# Ensure that the original meta objects will be available
1009+
# for use in the `metatize` that follows
1010+
tf_metatize_cache.update(
1011+
{
1012+
k: v
1013+
for k, v in zip(op_args.values(), apply_arguments.values())
1014+
if isinstance(k, tf.Tensor)
1015+
}
1016+
)
10301017

1031-
res_var = metatize(tf_out)
1018+
res_var = metatize(tf_out)
10321019

10331020
if "names" in meta._lvar_defaults_enabled:
10341021
# This should also reset the NodeDef's `obj`
@@ -1073,7 +1060,8 @@ def op_args_to_operation_inputs(self, apply_arguments):
10731060
node_attr = var()
10741061

10751062
if "names" not in meta._lvar_defaults_enabled:
1076-
op_name = apply_arguments.get("name", op_def_tf.name) or op_def_tf.name
1063+
# default_name = ops.get_default_graph().unique_name(op_def_tf.name, mark_as_used=False)
1064+
op_name = apply_arguments.get("name", None)
10771065
else:
10781066
op_name = var()
10791067

tests/tensorflow/test_meta.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -636,7 +636,7 @@ def test_global_options():
636636
with tf.Graph().as_default(), disable_auto_reification():
637637
y_mt = mt.Placeholder('float')
638638
assert y_mt.obj is None
639-
assert y_mt.name == 'Placeholder:0'
639+
assert isvar(y_mt.name)
640640
assert isinstance(y_mt.op.node_def.attr, dict)
641641

642642
with tf.Graph().as_default(), enable_lvar_defaults('names', 'node_attrs'):
@@ -706,7 +706,7 @@ def test_meta_const():
706706
@run_in_graph_mode
707707
def test_meta_existing_names():
708708

709-
with tf.Graph().as_default():
709+
with tf.Graph().as_default() as test_graph:
710710
one_mt = mt(1)
711711
assert one_mt.op.name == 'Const'
712712

@@ -723,6 +723,7 @@ def test_meta_existing_names():
723723
# Make sure it's the first base variable we created
724724
assert orig_one_tf is one_tf
725725

726+
# FYI: This implicitly creates 'Const_1'
726727
two_mt = mt(2)
727728
two_mt.op.node_def.name = 'Const'
728729

@@ -736,3 +737,15 @@ def test_meta_existing_names():
736737

737738
with pytest.raises(MetaReificationError):
738739
two_mt.reify()
740+
741+
another_one_mt = TFlowMetaOperator('Const', None)(3, var())
742+
# The following is something that would happen as a result of
743+
# reification (of the lvar in the meta object, not the meta object
744+
# itself).
745+
another_one_mt.op.node_def.attr['dtype'] = tf.int32
746+
747+
assert another_one_mt.op.name is None
748+
# We need to make sure that the reified meta object actually uses a
749+
# unique name.
750+
assert isinstance(another_one_mt.reify(), tf.Tensor)
751+
assert another_one_mt.reify().op.name == 'Const_2'

tests/tensorflow/test_unify.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -127,8 +127,11 @@ def test_basic_unify_reify():
127127

128128
test_expr = mt.add(tf.constant(1, dtype=tf.float64),
129129
mt.mul(tf.constant(2, dtype=tf.float64),
130-
x_l))
131-
test_reify_res = reify(test_expr, {x_l: a})
130+
x_l, name=var('mul_name')),
131+
name=var('add_name'))
132+
test_reify_res = reify(test_expr, {x_l: a,
133+
var('add_name'): 'Add_10',
134+
var('mul_name'): 'Mul_10'})
132135
test_base_res = test_reify_res.reify()
133136
assert isinstance(test_base_res, tf.Tensor)
134137

@@ -141,7 +144,7 @@ def test_basic_unify_reify():
141144
# Simply make sure that unification succeeds
142145
meta_expected_res = mt(expected_res)
143146
s_test = unify(test_expr, meta_expected_res, {})
144-
assert len(s_test) == 3
147+
assert len(s_test) == 5
145148

146149
assert reify(test_expr, s_test) == meta_expected_res
147150

0 commit comments

Comments
 (0)