Skip to content

Commit 8011195

Browse files
Merge pull request #91 from brandonwillard/create-tf-operator-by-name
Create TFlowMetaOperators using string names of OpDefs
2 parents 7656442 + 2dff08c commit 8011195

File tree

2 files changed

+55
-33
lines changed

2 files changed

+55
-33
lines changed

symbolic_pymc/tensorflow/meta.py

Lines changed: 43 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -820,15 +820,34 @@ class TFlowMetaOperator(TFlowMetaSymbol, MetaOp):
820820
base = None
821821
__slots__ = ("op_def", "node_def", "_apply_func_sig", "_apply_func")
822822

823+
@classmethod
824+
def get_metaopdef(cls, name):
825+
"""Obtain a MetaOpDef for a given string name.
826+
827+
This is more flexible because it ignores things like string case
828+
(when the non-`raw_ops` name differs from the TF user-level API).
829+
"""
830+
raw_op_name = op_def_lib.lower_op_name_to_raw.get(name.lower(), name)
831+
op_def = op_def_registry.get(raw_op_name)
832+
if op_def is not None:
833+
return TFlowMetaOpDef(obj=op_def)
834+
823835
def __init__(self, op_def, node_def=None, obj=None):
824836
assert obj is None
825837
super().__init__(None)
826838

827839
self.op_def = op_def
828-
if not isvar(self.op_def):
829-
self._apply_func_sig, self._apply_func = op_def_lib.get_op_info(self.op_def.obj)
830-
else:
840+
841+
if isinstance(self.op_def, str):
842+
self.op_def = self.get_metaopdef(self.op_def)
843+
844+
if self.op_def is None:
845+
raise ValueError(f"Could not find an OpDef for {op_def}")
846+
847+
if isvar(self.op_def):
831848
self._apply_func_sig, self._apply_func = None, None
849+
else:
850+
self._apply_func_sig, self._apply_func = op_def_lib.get_op_info(self.op_def.obj)
832851

833852
self.node_def = node_def
834853

@@ -1097,18 +1116,6 @@ def __init__(self, namespace=None):
10971116
def __call__(self, x):
10981117
return metatize(x)
10991118

1100-
@classmethod
1101-
def find_operator(cls, name):
1102-
"""Attempt to create a meta operator for a given TF function/`Operation` name."""
1103-
raw_op_name = op_def_lib.lower_op_name_to_raw.get(name.lower(), name)
1104-
op_def = op_def_registry.get(raw_op_name)
1105-
1106-
if op_def is not None:
1107-
meta_obj = TFlowMetaOperator(TFlowMetaOpDef(obj=op_def), None)
1108-
return meta_obj
1109-
1110-
return None
1111-
11121119
def __getattr__(self, obj):
11131120

11141121
ns_obj = next((getattr(ns, obj) for ns in self.namespaces if hasattr(ns, obj)), None)
@@ -1122,13 +1129,23 @@ def __getattr__(self, obj):
11221129
if ns_obj is None:
11231130
ns_obj = f_back.f_globals.get(obj)
11241131

1125-
if isinstance(ns_obj, (types.FunctionType, partial)):
1126-
# We assume that the user requested an `Operation`
1127-
# constructor/helper. Return the meta `OpDef`, because
1128-
# it implements a constructor/helper-like `__call__`.
1129-
meta_obj = self.find_operator(obj)
1132+
if isinstance(ns_obj, types.ModuleType):
1133+
# It's a sub-module, so let's create another
1134+
# `TheanoMetaAccessor` and check within there.
1135+
meta_obj = TFlowMetaAccessor(namespace=ns_obj)
1136+
else:
1137+
1138+
# Check for a an OpDef first
1139+
meta_obj = TFlowMetaOperator.get_metaopdef(obj)
1140+
1141+
if meta_obj is not None:
1142+
# We assume that the user requested an `Operation`
1143+
# constructor/helper. Return the meta `OpDef`, because
1144+
# it implements a constructor/helper-like `__call__`.
1145+
if meta_obj is not None:
1146+
meta_obj = TFlowMetaOperator(meta_obj, None)
11301147

1131-
# if meta_obj is None:
1148+
# elif isinstance(ns_obj, (types.FunctionType, partial)):
11321149
# # It's a function, so let's provide a wrapper that converts
11331150
# # to-and-from theano and meta objects.
11341151
# @wraps(ns_obj)
@@ -1137,19 +1154,12 @@ def __getattr__(self, obj):
11371154
# res = ns_obj(*args, **kwargs)
11381155
# return metatize(res)
11391156

1140-
elif isinstance(ns_obj, types.ModuleType):
1141-
# It's a sub-module, so let's create another
1142-
# `TheanoMetaAccessor` and check within there.
1143-
meta_obj = TFlowMetaAccessor(namespace=ns_obj)
1144-
else:
1145-
1146-
# Hopefully, it's convertible to a meta object...
1147-
meta_obj = metatize(ns_obj)
1148-
1149-
if meta_obj is None:
1150-
# Last resort
1151-
meta_obj = self.find_operator(obj)
1157+
else:
1158+
# Hopefully, it's convertible to a meta object...
1159+
meta_obj = metatize(ns_obj)
11521160

1161+
# Finally, we store the result as a meta namespace attribute, or raise
1162+
# an exception.
11531163
if isinstance(
11541164
meta_obj, (MetaSymbol, MetaSymbolType, TFlowMetaOperator, types.FunctionType)
11551165
):

tests/tensorflow/test_meta.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,18 @@ def test_meta_helpers():
5656
assert lvar_op_mt.output_meta_types() is None
5757
assert lvar_op_mt.op_args_to_operation_inputs({}) is None
5858

59+
add_op_mt = TFlowMetaOperator('add')
60+
assert add_op_mt.node_def is None
61+
assert add_op_mt.op_def.obj.name == 'Add'
62+
assert add_op_mt == mt.add
63+
64+
# Both cases should work
65+
add_op_mt_2 = TFlowMetaOperator('Add')
66+
assert add_op_mt == add_op_mt_2
67+
68+
with pytest.raises(ValueError):
69+
TFlowMetaOperator('anoperatorthatdoesnotexist')
70+
5971

6072
def test_meta_eager():
6173

0 commit comments

Comments
 (0)