Skip to content

Commit b6336d4

Browse files
authored
use safe pickle_load for resume_file (#2329)
Signed-off-by: chensuyue <suyue.chen@intel.com>
1 parent 907c538 commit b6336d4

File tree

5 files changed

+10
-12
lines changed

5 files changed

+10
-12
lines changed

neural_compressor/adaptor/tensorflow.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1354,7 +1354,7 @@ def inspect_tensor(
13541354
model = model.graph_def
13551355
if not quantization_cfg:
13561356
# TODO get config from graph if config is None
1357-
quantization_cfg = load_data_from_pkl("./nc_workspace/", "cfg.pkl")
1357+
quantization_cfg = load_data_from_pkl("./nc_workspace/cfg.pkl")
13581358
node_list = op_list
13591359
# create the mapping between node name and node, key: node_name, val: node
13601360
graph_node_name_mapping = {}

neural_compressor/mix_precision.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from .model import Model
3030
from .strategy import STRATEGIES
3131
from .utils import alias_param, logger
32-
from .utils.utility import CpuInfo, secure_check_eval_func, time_limit
32+
from .utils.utility import CpuInfo, load_data_from_pkl, secure_check_eval_func, time_limit
3333

3434

3535
@alias_param("conf", param_alias="config")
@@ -142,7 +142,7 @@ def fit(model, conf, eval_func=None, eval_dataloader=None, eval_metric=None, **k
142142
if resume_file:
143143
assert os.path.exists(resume_file), "The specified resume file {} doesn't exist!".format(resume_file)
144144
with open(resume_file, "rb") as f:
145-
_resume = pickle.load(f).__dict__
145+
_resume = load_data_from_pkl(f).__dict__
146146

147147
strategy = STRATEGIES["automixedprecision"](
148148
model=wrapped_model,

neural_compressor/quantization.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from .model import Model
2828
from .strategy import STRATEGIES
2929
from .utils import logger
30-
from .utils.utility import dump_class_attrs, secure_check_eval_func, time_limit
30+
from .utils.utility import dump_class_attrs, load_data_from_pkl, secure_check_eval_func, time_limit
3131

3232

3333
def fit(
@@ -182,7 +182,7 @@ def eval_func(model):
182182
if resume_file:
183183
assert os.path.exists(resume_file), "The specified resume file {} doesn't exist!".format(resume_file)
184184
with open(resume_file, "rb") as f:
185-
_resume = pickle.load(f).__dict__
185+
_resume = load_data_from_pkl(f).__dict__
186186

187187
if eval_func is None and eval_dataloader is None: # pragma: no cover
188188
logger.info("Quantize model without tuning!")

neural_compressor/training.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
from .metric import register_customer_metric
3333
from .model.model import Model
3434
from .utils import logger
35-
from .utils.utility import time_limit
35+
from .utils.utility import load_data_from_pkl, time_limit
3636

3737

3838
class CompressionManager:
@@ -312,7 +312,7 @@ def eval_func(model):
312312
if resume_file:
313313
assert os.path.exists(resume_file), "The specified resume file {} doesn't exist!".format(resume_file)
314314
with open(resume_file, "rb") as f:
315-
_resume = pickle.load(f).__dict__
315+
_resume = load_data_from_pkl(f).__dict__
316316

317317
if eval_func is None and eval_dataloader is None: # pragma: no cover
318318
logger.info("Quantize model without tuning!")

neural_compressor/utils/utility.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -638,20 +638,18 @@ class GLOBAL_STATE:
638638
STATE = MODE.QUANTIZATION
639639

640640

641-
def load_data_from_pkl(path, filename):
641+
def load_data_from_pkl(file_path):
642642
"""Load data from local pkl file.
643643
644644
Args:
645-
path: The directory to load data
646-
filename: The filename to load
645+
file_path: The directory to load data
647646
"""
648647
try:
649-
file_path = os.path.join(path, filename)
650648
with open(file_path, "rb") as fp:
651649
data = _safe_pickle_load(fp)
652650
return data
653651
except FileExistsError:
654-
logging.getLogger("neural_compressor").info("Can not open %s." % path)
652+
logging.getLogger("neural_compressor").info("Can not open %s." % file_path)
655653

656654

657655
def dump_data_to_local(data, path, filename):

0 commit comments

Comments
 (0)