Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .pre-commit-config-zh-cn.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ repos:
rev: 5.11.5
hooks:
- id: isort
- repo: https://gitee.com/openmmlab/mirrors-yapf
rev: v0.32.0
- repo: https://github.com/google/yapf
rev: v0.43.0
hooks:
- id: yapf
- repo: https://gitee.com/openmmlab/mirrors-pre-commit-hooks
Expand Down
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ repos:
rev: 5.11.5
hooks:
- id: isort
- repo: https://github.com/pre-commit/mirrors-yapf
rev: v0.32.0
- repo: https://github.com/google/yapf
rev: v0.43.0
hooks:
- id: yapf
- repo: https://github.com/pre-commit/pre-commit-hooks
Expand Down
8 changes: 4 additions & 4 deletions mmengine/_strategy/deepspeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,10 +310,10 @@ def __init__(
self.config.setdefault('gradient_accumulation_steps', 1)
self.config['steps_per_print'] = steps_per_print
self._inputs_to_half = inputs_to_half
assert (exclude_frozen_parameters is None or
digit_version(deepspeed.__version__) >= digit_version('0.13.2')
), ('DeepSpeed >= 0.13.2 is required to enable '
'exclude_frozen_parameters')
assert (exclude_frozen_parameters is None or digit_version(
deepspeed.__version__) >= digit_version('0.13.2')), (
'DeepSpeed >= 0.13.2 is required to enable '
'exclude_frozen_parameters')
self.exclude_frozen_parameters = exclude_frozen_parameters

register_deepspeed_optimizers()
Expand Down
21 changes: 12 additions & 9 deletions mmengine/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,10 @@
def _lazy2string(cfg_dict, dict_type=None):
if isinstance(cfg_dict, dict):
dict_type = dict_type or type(cfg_dict)
return dict_type(
{k: _lazy2string(v, dict_type)
for k, v in dict.items(cfg_dict)})
return dict_type({
k: _lazy2string(v, dict_type)
for k, v in dict.items(cfg_dict)
})
elif isinstance(cfg_dict, (tuple, list)):
return type(cfg_dict)(_lazy2string(v, dict_type) for v in cfg_dict)
elif isinstance(cfg_dict, (LazyAttr, LazyObject)):
Expand Down Expand Up @@ -273,13 +274,15 @@ def __reduce_ex__(self, proto):
# called by CPython interpreter during pickling. See more details in
# https://github.com/python/cpython/blob/8d61a71f9c81619e34d4a30b625922ebc83c561b/Objects/typeobject.c#L6196 # noqa: E501
if digit_version(platform.python_version()) < digit_version('3.8'):
return (self.__class__, ({k: v
for k, v in super().items()}, ), None,
None, None)
return (self.__class__, ({
k: v
for k, v in super().items()
}, ), None, None, None)
else:
return (self.__class__, ({k: v
for k, v in super().items()}, ), None,
None, None, None)
return (self.__class__, ({
k: v
for k, v in super().items()
}, ), None, None, None, None)

def __eq__(self, other):
if isinstance(other, ConfigDict):
Expand Down
3 changes: 2 additions & 1 deletion mmengine/dataset/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,8 @@ def default_collate(data_batch: Sequence) -> Any:
return [default_collate(samples) for samples in transposed]
elif isinstance(data_item, Mapping):
return data_item_type({
key: default_collate([d[key] for d in data_batch])
key:
default_collate([d[key] for d in data_batch])
for key in data_item
})
else:
Expand Down
4 changes: 2 additions & 2 deletions mmengine/fileio/backends/local_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,8 +156,8 @@ def isfile(self, filepath: Union[str, Path]) -> bool:
"""
return osp.isfile(filepath)

def join_path(self, filepath: Union[str, Path],
*filepaths: Union[str, Path]) -> str:
def join_path(self, filepath: Union[str, Path], *filepaths:
Union[str, Path]) -> str:
r"""Concatenate all file paths.

Join one or more filepath components intelligently. The return value
Expand Down
4 changes: 2 additions & 2 deletions mmengine/fileio/file_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,8 +385,8 @@ def isfile(self, filepath: Union[str, Path]) -> bool:
"""
return self.client.isfile(filepath)

def join_path(self, filepath: Union[str, Path],
*filepaths: Union[str, Path]) -> str:
def join_path(self, filepath: Union[str, Path], *filepaths:
Union[str, Path]) -> str:
r"""Concatenate all file paths.

Join one or more filepath components intelligently. The return value
Expand Down
8 changes: 4 additions & 4 deletions mmengine/hooks/checkpoint_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,10 +196,10 @@ def __init__(self,
self.save_best = save_best

# rule logic
assert (isinstance(rule, str) or is_list_of(rule, str)
or (rule is None)), (
'"rule" should be a str or list of str or None, '
f'but got {type(rule)}')
assert (isinstance(rule, str) or is_list_of(rule, str) or
(rule
is None)), ('"rule" should be a str or list of str or None, '
f'but got {type(rule)}')
if isinstance(rule, list):
# check the length of rule list
assert len(rule) in [
Expand Down
7 changes: 4 additions & 3 deletions mmengine/model/test_time_aug.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,10 @@ def test_step(self, data):
data_list: Union[List[dict], List[list]]
if isinstance(data, dict):
num_augs = len(data[next(iter(data))])
data_list = [{key: value[idx]
for key, value in data.items()}
for idx in range(num_augs)]
data_list = [{
key: value[idx]
for key, value in data.items()
} for idx in range(num_augs)]
elif isinstance(data, (tuple, list)):
num_augs = len(data[0])
data_list = [[_data[idx] for _data in data]
Expand Down
19 changes: 12 additions & 7 deletions mmengine/runner/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,8 @@ def load_from_local(filename, map_location):
filename = osp.expanduser(filename)
if not osp.isfile(filename):
raise FileNotFoundError(f'{filename} can not be found.')
checkpoint = torch.load(filename, map_location=map_location)
checkpoint = torch.load(
filename, map_location=map_location, weights_only=False)
return checkpoint


Expand Down Expand Up @@ -412,7 +413,8 @@ def load_from_pavi(filename, map_location=None):
with TemporaryDirectory() as tmp_dir:
downloaded_file = osp.join(tmp_dir, model.name)
model.download(downloaded_file)
checkpoint = torch.load(downloaded_file, map_location=map_location)
checkpoint = torch.load(
downloaded_file, map_location=map_location, weights_only=False)
return checkpoint


Expand All @@ -435,7 +437,8 @@ def load_from_ceph(filename, map_location=None, backend='petrel'):
file_backend = get_file_backend(
filename, backend_args={'backend': backend})
with io.BytesIO(file_backend.get(filename)) as buffer:
checkpoint = torch.load(buffer, map_location=map_location)
checkpoint = torch.load(
buffer, map_location=map_location, weights_only=False)
return checkpoint


Expand Down Expand Up @@ -504,7 +507,8 @@ def load_from_openmmlab(filename, map_location=None):
filename = osp.join(_get_mmengine_home(), model_url)
if not osp.isfile(filename):
raise FileNotFoundError(f'{filename} can not be found.')
checkpoint = torch.load(filename, map_location=map_location)
checkpoint = torch.load(
filename, map_location=map_location, weights_only=False)
return checkpoint


Expand Down Expand Up @@ -597,9 +601,10 @@ def _load_checkpoint_to_model(model,
# strip prefix of state_dict
metadata = getattr(state_dict, '_metadata', OrderedDict())
for p, r in revise_keys:
state_dict = OrderedDict(
{re.sub(p, r, k): v
for k, v in state_dict.items()})
state_dict = OrderedDict({
re.sub(p, r, k): v
for k, v in state_dict.items()
})
# Keep metadata in state_dict
state_dict._metadata = metadata

Expand Down
6 changes: 4 additions & 2 deletions mmengine/utils/dl_utils/hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ def _legacy_zip_load(filename, model_dir, map_location):
f.extractall(model_dir)
extraced_name = members[0].filename
extracted_file = os.path.join(model_dir, extraced_name)
return torch.load(extracted_file, map_location=map_location)
return torch.load(
extracted_file, map_location=map_location, weights_only=False)

def load_url(url,
model_dir=None,
Expand Down Expand Up @@ -114,7 +115,8 @@ def load_url(url,
return _legacy_zip_load(cached_file, model_dir, map_location)

try:
return torch.load(cached_file, map_location=map_location)
return torch.load(
cached_file, map_location=map_location, weights_only=False)
except RuntimeError as error:
if digit_version(TORCH_VERSION) < digit_version('1.5.0'):
warnings.warn(
Expand Down
6 changes: 3 additions & 3 deletions mmengine/utils/dl_utils/torch_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
from ..version_utils import digit_version
from .parrots_wrapper import TORCH_VERSION

_torch_version_meshgrid_indexing = (
'parrots' not in TORCH_VERSION
and digit_version(TORCH_VERSION) >= digit_version('1.10.0a0'))
_torch_version_meshgrid_indexing = ('parrots' not in TORCH_VERSION
and digit_version(TORCH_VERSION)
>= digit_version('1.10.0a0'))


def torch_meshgrid(*tensors):
Expand Down
5 changes: 3 additions & 2 deletions mmengine/visualization/visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -754,8 +754,9 @@ def draw_bboxes(
assert bboxes.shape[-1] == 4, (
f'The shape of `bboxes` should be (N, 4), but got {bboxes.shape}')

assert (bboxes[:, 0] <= bboxes[:, 2]).all() and (bboxes[:, 1] <=
bboxes[:, 3]).all()
assert (bboxes[:, 0] <= bboxes[:, 2]).all() and (bboxes[:, 1]
<= bboxes[:,
3]).all()
if not self._is_posion_valid(bboxes.reshape((-1, 2, 2))):
warnings.warn(
'Warning: The bbox is out of bounds,'
Expand Down
7 changes: 4 additions & 3 deletions tests/test_analysis/test_jit_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,9 +634,10 @@ def dummy_ops_handle(inputs: List[Any],

dummy_flops = {}
for name, counts in model.flops.items():
dummy_flops[name] = Counter(
{op: flop
for op, flop in counts.items() if op != self.lin_op})
dummy_flops[name] = Counter({
op: flop
for op, flop in counts.items() if op != self.lin_op
})
dummy_flops[''][dummy_name] = 2 * dummy_out
dummy_flops['fc'][dummy_name] = dummy_out
dummy_flops['submod'][dummy_name] = dummy_out
Expand Down
8 changes: 4 additions & 4 deletions tests/test_dataset/test_base_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -733,13 +733,13 @@ def test_length(self):
def test_getitem(self):
assert (
self.cat_datasets[0]['imgs'] == self.dataset_a[0]['imgs']).all()
assert (self.cat_datasets[0]['imgs'] !=
self.dataset_b[0]['imgs']).all()
assert (self.cat_datasets[0]['imgs']
!= self.dataset_b[0]['imgs']).all()

assert (
self.cat_datasets[-1]['imgs'] == self.dataset_b[-1]['imgs']).all()
assert (self.cat_datasets[-1]['imgs'] !=
self.dataset_a[-1]['imgs']).all()
assert (self.cat_datasets[-1]['imgs']
!= self.dataset_a[-1]['imgs']).all()

def test_get_data_info(self):
assert self.cat_datasets.get_data_info(
Expand Down
32 changes: 23 additions & 9 deletions tests/test_hooks/test_checkpoint_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,13 +458,17 @@ def test_with_runner(self, training_type):
cfg = copy.deepcopy(common_cfg)
runner = self.build_runner(cfg)
runner.train()
ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_11.pth'))
ckpt = torch.load(
osp.join(cfg.work_dir, f'{training_type}_11.pth'),
weights_only=False)
self.assertIn('optimizer', ckpt)

cfg.default_hooks.checkpoint.save_optimizer = False
runner = self.build_runner(cfg)
runner.train()
ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_11.pth'))
ckpt = torch.load(
osp.join(cfg.work_dir, f'{training_type}_11.pth'),
weights_only=False)
self.assertNotIn('optimizer', ckpt)

# Test save_param_scheduler=False
Expand All @@ -479,13 +483,17 @@ def test_with_runner(self, training_type):
]
runner = self.build_runner(cfg)
runner.train()
ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_11.pth'))
ckpt = torch.load(
osp.join(cfg.work_dir, f'{training_type}_11.pth'),
weights_only=False)
self.assertIn('param_schedulers', ckpt)

cfg.default_hooks.checkpoint.save_param_scheduler = False
runner = self.build_runner(cfg)
runner.train()
ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_11.pth'))
ckpt = torch.load(
osp.join(cfg.work_dir, f'{training_type}_11.pth'),
weights_only=False)
self.assertNotIn('param_schedulers', ckpt)

self.clear_work_dir()
Expand Down Expand Up @@ -533,7 +541,9 @@ def test_with_runner(self, training_type):
self.assertFalse(
osp.isfile(osp.join(cfg.work_dir, f'{training_type}_{i}.pth')))

ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_11.pth'))
ckpt = torch.load(
osp.join(cfg.work_dir, f'{training_type}_11.pth'),
weights_only=False)
self.assertEqual(ckpt['message_hub']['runtime_info']['keep_ckpt_ids'],
[9, 10, 11])

Expand Down Expand Up @@ -574,9 +584,11 @@ def test_with_runner(self, training_type):
runner.train()
best_ckpt_path = osp.join(cfg.work_dir,
f'best_test_acc_{training_type}_5.pth')
best_ckpt = torch.load(best_ckpt_path)
best_ckpt = torch.load(best_ckpt_path, weights_only=False)

ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_5.pth'))
ckpt = torch.load(
osp.join(cfg.work_dir, f'{training_type}_5.pth'),
weights_only=False)
self.assertEqual(best_ckpt_path,
ckpt['message_hub']['runtime_info']['best_ckpt'])

Expand All @@ -603,11 +615,13 @@ def test_with_runner(self, training_type):
runner.train()
best_ckpt_path = osp.join(cfg.work_dir,
f'best_test_acc_{training_type}_5.pth')
best_ckpt = torch.load(best_ckpt_path)
best_ckpt = torch.load(best_ckpt_path, weights_only=False)

# if the current ckpt is the best, the interval will be ignored the
# the ckpt will also be saved
ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_5.pth'))
ckpt = torch.load(
osp.join(cfg.work_dir, f'{training_type}_5.pth'),
weights_only=False)
self.assertEqual(best_ckpt_path,
ckpt['message_hub']['runtime_info']['best_ckpt'])

Expand Down
18 changes: 13 additions & 5 deletions tests/test_hooks/test_ema_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,8 @@ def test_with_runner(self):
self.assertTrue(
isinstance(ema_hook.ema_model, ExponentialMovingAverage))

checkpoint = torch.load(osp.join(self.temp_dir.name, 'epoch_2.pth'))
checkpoint = torch.load(
osp.join(self.temp_dir.name, 'epoch_2.pth'), weights_only=False)
self.assertTrue('ema_state_dict' in checkpoint)
self.assertTrue(checkpoint['ema_state_dict']['steps'] == 8)

Expand All @@ -245,7 +246,8 @@ def test_with_runner(self):
runner.test()

# Test load checkpoint without ema_state_dict
checkpoint = torch.load(osp.join(self.temp_dir.name, 'epoch_2.pth'))
checkpoint = torch.load(
osp.join(self.temp_dir.name, 'epoch_2.pth'), weights_only=False)
checkpoint.pop('ema_state_dict')
torch.save(checkpoint,
osp.join(self.temp_dir.name, 'without_ema_state_dict.pth'))
Expand Down Expand Up @@ -274,7 +276,9 @@ def test_with_runner(self):
runner = self.build_runner(cfg)
runner.train()
state_dict = torch.load(
osp.join(self.temp_dir.name, 'epoch_4.pth'), map_location='cpu')
osp.join(self.temp_dir.name, 'epoch_4.pth'),
map_location='cpu',
weights_only=False)
self.assertIn('ema_state_dict', state_dict)
for k, v in state_dict['state_dict'].items():
assert_allclose(v, state_dict['ema_state_dict']['module.' + k])
Expand All @@ -287,12 +291,16 @@ def test_with_runner(self):
runner = self.build_runner(cfg)
runner.train()
state_dict = torch.load(
osp.join(self.temp_dir.name, 'iter_4.pth'), map_location='cpu')
osp.join(self.temp_dir.name, 'iter_4.pth'),
map_location='cpu',
weights_only=False)
self.assertIn('ema_state_dict', state_dict)
for k, v in state_dict['state_dict'].items():
assert_allclose(v, state_dict['ema_state_dict']['module.' + k])
state_dict = torch.load(
osp.join(self.temp_dir.name, 'iter_5.pth'), map_location='cpu')
osp.join(self.temp_dir.name, 'iter_5.pth'),
map_location='cpu',
weights_only=False)
self.assertIn('ema_state_dict', state_dict)

def _test_swap_parameters(self, func_name, *args, **kwargs):
Expand Down
Loading