Skip to content

Commit 027bffc

Browse files
Introduce a work-around for constructing TensorFlow constants by hand
1 parent cdaf7e6 commit 027bffc

File tree

2 files changed

+29
-0
lines changed

2 files changed

+29
-0
lines changed

symbolic_pymc/tensorflow/meta.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,20 @@ class MetaOpDefLibrary(object):
6565
}
6666
opdef_signatures = {}
6767

68+
def __init__(self):
69+
#
70+
# We need this in order to construct "Const" tensors directly, since
71+
# the "value" attr in a meta `NodeDef` is just a NumPy array and not
72+
# the `TensorProto` expected by `raw_ops.Const`.
73+
#
74+
def mt_const(value, dtype, name=None):
75+
return tf.raw_ops.Const(
76+
value=tensor_util.make_tensor_proto(value), dtype=dtype, name=name
77+
)
78+
79+
opdef = op_def_registry.get("Const")
80+
self.opdef_signatures[opdef.name] = self.make_opdef_sig(opdef, mt_const)
81+
6882
@classmethod
6983
def get_op_info(cls, opdef):
7084
"""Return the TF Python API function signature for a given `OpDef`.

tests/tensorflow/test_meta.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -677,6 +677,21 @@ def test_global_options():
677677
assert isvar(a_mt.name)
678678

679679

680+
@run_in_graph_mode
681+
def test_meta_const():
682+
"""Make sure we can create a Const tensor by hand."""
683+
684+
with tf.Graph().as_default():
685+
one_mt = mt.const(1, 'int32', 'Const')
686+
687+
with tf.Graph().as_default():
688+
another_one_mt = mt(1)
689+
690+
assert one_mt == another_one_mt
691+
assert isinstance(one_mt.reify(), tf.Tensor)
692+
assert one_mt.reify().op.type == 'Const'
693+
694+
680695
@run_in_graph_mode
681696
def test_meta_existing_names():
682697

0 commit comments

Comments
 (0)