Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion neural_compressor/adaptor/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1354,7 +1354,7 @@ def inspect_tensor(
model = model.graph_def
if not quantization_cfg:
# TODO get config from graph if config is None
quantization_cfg = load_data_from_pkl("./nc_workspace/", "cfg.pkl")
quantization_cfg = load_data_from_pkl("./nc_workspace/cfg.pkl")
node_list = op_list
# create the mapping between node name and node, key: node_name, val: node
graph_node_name_mapping = {}
Expand Down
4 changes: 2 additions & 2 deletions neural_compressor/mix_precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from .model import Model
from .strategy import STRATEGIES
from .utils import alias_param, logger
from .utils.utility import CpuInfo, secure_check_eval_func, time_limit
from .utils.utility import CpuInfo, load_data_from_pkl, secure_check_eval_func, time_limit


@alias_param("conf", param_alias="config")
Expand Down Expand Up @@ -142,7 +142,7 @@ def fit(model, conf, eval_func=None, eval_dataloader=None, eval_metric=None, **k
if resume_file:
assert os.path.exists(resume_file), "The specified resume file {} doesn't exist!".format(resume_file)
with open(resume_file, "rb") as f:
_resume = pickle.load(f).__dict__
_resume = load_data_from_pkl(resume_file).__dict__

strategy = STRATEGIES["automixedprecision"](
model=wrapped_model,
Expand Down
4 changes: 2 additions & 2 deletions neural_compressor/quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from .model import Model
from .strategy import STRATEGIES
from .utils import logger
from .utils.utility import dump_class_attrs, secure_check_eval_func, time_limit
from .utils.utility import dump_class_attrs, load_data_from_pkl, secure_check_eval_func, time_limit


def fit(
Expand Down Expand Up @@ -182,7 +182,7 @@ def eval_func(model):
if resume_file:
assert os.path.exists(resume_file), "The specified resume file {} doesn't exist!".format(resume_file)
with open(resume_file, "rb") as f:
_resume = pickle.load(f).__dict__
_resume = load_data_from_pkl(resume_file).__dict__

if eval_func is None and eval_dataloader is None: # pragma: no cover
logger.info("Quantize model without tuning!")
Expand Down
4 changes: 2 additions & 2 deletions neural_compressor/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from .metric import register_customer_metric
from .model.model import Model
from .utils import logger
from .utils.utility import time_limit
from .utils.utility import load_data_from_pkl, time_limit


class CompressionManager:
Expand Down Expand Up @@ -312,7 +312,7 @@ def eval_func(model):
if resume_file:
assert os.path.exists(resume_file), "The specified resume file {} doesn't exist!".format(resume_file)
with open(resume_file, "rb") as f:
_resume = pickle.load(f).__dict__
_resume = load_data_from_pkl(resume_file).__dict__

if eval_func is None and eval_dataloader is None: # pragma: no cover
logger.info("Quantize model without tuning!")
Expand Down
8 changes: 3 additions & 5 deletions neural_compressor/utils/utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -638,20 +638,18 @@ class GLOBAL_STATE:
STATE = MODE.QUANTIZATION


def load_data_from_pkl(path, filename):
def load_data_from_pkl(file_path):
"""Load data from local pkl file.

Args:
path: The directory to load data
filename: The filename to load
file_path: The path of the pkl file to load.
"""
try:
file_path = os.path.join(path, filename)
with open(file_path, "rb") as fp:
data = _safe_pickle_load(fp)
return data
except FileNotFoundError:
logging.getLogger("neural_compressor").info("Can not open %s." % path)
logging.getLogger("neural_compressor").info("Can not open %s." % file_path)


def dump_data_to_local(data, path, filename):
Expand Down
Loading