Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 24 additions & 4 deletions pyzoo/zoo/util/modulepickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import hashlib
import sys
from logging import getLogger
import types

__all__ = ('extend', 'extend_ray', 'extend_cloudpickle')

Expand Down Expand Up @@ -74,11 +75,16 @@ def compress(packagename, path):
return tar.getvalue()


def import_compressed(name, package, class_name):
def import_compressed(name, package, class_name,is_anyfunc):
res = package.load(name)
if getattr(res, class_name, None):
class_type = getattr(res, class_name)
return class_type.__new__(class_type)
obj_type = getattr(res, class_name)

return obj_type.__new__(obj_type) if not is_anyfunc else types.FunctionType(getattr(obj_type, "__code__", ""),
getattr(obj_type, "__globals__", ""),
name=getattr(obj_type, "__name__", ""),
argdefs=getattr(obj_type, "__defaults__", ""),
closure=getattr(obj_type, "__closure__", ""))
else:
return res

Expand All @@ -101,6 +107,9 @@ def is_local(module):
if path is None:
return False

# if your zoo is not installed by whl,
# to debug codes you may exclude your az path from loacl path
# to avoid infinite resursion.
if path.startswith(python_lib_path):
return False

Expand Down Expand Up @@ -157,7 +166,18 @@ def reducer_override(self, obj):
else:
print("get local {} in save_module, path is {}".format(module.__name__, module.__file__))
package = self.compress_package(packagename(module), get_path(module))
args = (module.__name__, package, obj.__class__.__name__)

try:
# todo:Should check class type first
is_anyfunc=isinstance(obj, types.FunctionType)
except TypeError: # t is not a class (old Boost; see SF #502085)
is_anyfunc = False

if is_anyfunc:
args = (module.__name__, package, obj.__name__,is_anyfunc)
else:
args = (module.__name__, package, obj.__class__.__name__,is_anyfunc)

return import_compressed, args, obj.__dict__
return super().reducer_override(obj)

Expand Down