From d7bc5ae0c4707d5934172f048fe8cdb42d4cf8a4 Mon Sep 17 00:00:00 2001 From: dianyo Date: Fri, 21 Feb 2025 08:04:12 +0000 Subject: [PATCH 1/4] update yapf including version and link in precommit --- .pre-commit-config.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c8edd013c6..81ddf48216 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -12,8 +12,8 @@ repos: rev: 5.11.5 hooks: - id: isort - - repo: https://github.com/pre-commit/mirrors-yapf - rev: v0.32.0 + - repo: https://github.com/google/yapf + rev: v0.43.0 hooks: - id: yapf - repo: https://github.com/pre-commit/pre-commit-hooks From d57862ecf8f49e9e14bb8884e381ac43df441bf8 Mon Sep 17 00:00:00 2001 From: dianyo Date: Fri, 21 Feb 2025 08:04:30 +0000 Subject: [PATCH 2/4] add weights_only argument for torch load to support new version of torch --- mmengine/runner/checkpoint.py | 19 +++++++---- mmengine/utils/dl_utils/hub.py | 6 ++-- tests/test_hooks/test_checkpoint_hook.py | 32 +++++++++++++------ tests/test_hooks/test_ema_hook.py | 18 ++++++++--- .../test_scheduler/test_param_scheduler.py | 3 +- tests/test_runner/test_runner.py | 6 ++-- 6 files changed, 57 insertions(+), 27 deletions(-) diff --git a/mmengine/runner/checkpoint.py b/mmengine/runner/checkpoint.py index 2bf5f50f7c..f061cf5cac 100644 --- a/mmengine/runner/checkpoint.py +++ b/mmengine/runner/checkpoint.py @@ -344,7 +344,8 @@ def load_from_local(filename, map_location): filename = osp.expanduser(filename) if not osp.isfile(filename): raise FileNotFoundError(f'{filename} can not be found.') - checkpoint = torch.load(filename, map_location=map_location) + checkpoint = torch.load( + filename, map_location=map_location, weights_only=False) return checkpoint @@ -412,7 +413,8 @@ def load_from_pavi(filename, map_location=None): with TemporaryDirectory() as tmp_dir: downloaded_file = osp.join(tmp_dir, model.name) model.download(downloaded_file) - checkpoint = torch.load(downloaded_file, map_location=map_location) + checkpoint = torch.load( + downloaded_file, map_location=map_location, weights_only=False) return checkpoint @@ -435,7 +437,8 @@ def load_from_ceph(filename, map_location=None, backend='petrel'): file_backend = get_file_backend( filename, backend_args={'backend': backend}) with io.BytesIO(file_backend.get(filename)) as buffer: - checkpoint = torch.load(buffer, map_location=map_location) + checkpoint = torch.load( + buffer, map_location=map_location, weights_only=False) return checkpoint @@ -504,7 +507,8 @@ def load_from_openmmlab(filename, map_location=None): filename = osp.join(_get_mmengine_home(), model_url) if not osp.isfile(filename): raise FileNotFoundError(f'{filename} can not be found.') - checkpoint = torch.load(filename, map_location=map_location) + checkpoint = torch.load( + filename, map_location=map_location, weights_only=False) return checkpoint @@ -597,9 +601,10 @@ def _load_checkpoint_to_model(model, # strip prefix of state_dict metadata = getattr(state_dict, '_metadata', OrderedDict()) for p, r in revise_keys: - state_dict = OrderedDict( - {re.sub(p, r, k): v - for k, v in state_dict.items()}) + state_dict = OrderedDict({ + re.sub(p, r, k): v + for k, v in state_dict.items() + }) # Keep metadata in state_dict state_dict._metadata = metadata diff --git a/mmengine/utils/dl_utils/hub.py b/mmengine/utils/dl_utils/hub.py index 7f7f1a087d..55a022782b 100644 --- a/mmengine/utils/dl_utils/hub.py +++ b/mmengine/utils/dl_utils/hub.py @@ -48,7 +48,8 @@ def _legacy_zip_load(filename, model_dir, map_location): f.extractall(model_dir) extraced_name = members[0].filename extracted_file = os.path.join(model_dir, extraced_name) - return torch.load(extracted_file, map_location=map_location) + return torch.load( + extracted_file, map_location=map_location, weights_only=False) def load_url(url, model_dir=None, @@ -114,7 +115,8 @@ def load_url(url, return _legacy_zip_load(cached_file, model_dir, map_location) try: - return torch.load(cached_file, map_location=map_location) + return torch.load( + cached_file, map_location=map_location, weights_only=False) except RuntimeError as error: if digit_version(TORCH_VERSION) < digit_version('1.5.0'): warnings.warn( diff --git a/tests/test_hooks/test_checkpoint_hook.py b/tests/test_hooks/test_checkpoint_hook.py index d731a42b76..13914341f7 100644 --- a/tests/test_hooks/test_checkpoint_hook.py +++ b/tests/test_hooks/test_checkpoint_hook.py @@ -458,13 +458,17 @@ def test_with_runner(self, training_type): cfg = copy.deepcopy(common_cfg) runner = self.build_runner(cfg) runner.train() - ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_11.pth')) + ckpt = torch.load( + osp.join(cfg.work_dir, f'{training_type}_11.pth'), + weights_only=False) self.assertIn('optimizer', ckpt) cfg.default_hooks.checkpoint.save_optimizer = False runner = self.build_runner(cfg) runner.train() - ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_11.pth')) + ckpt = torch.load( + osp.join(cfg.work_dir, f'{training_type}_11.pth'), + weights_only=False) self.assertNotIn('optimizer', ckpt) # Test save_param_scheduler=False @@ -479,13 +483,17 @@ def test_with_runner(self, training_type): ] runner = self.build_runner(cfg) runner.train() - ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_11.pth')) + ckpt = torch.load( + osp.join(cfg.work_dir, f'{training_type}_11.pth'), + weights_only=False) self.assertIn('param_schedulers', ckpt) cfg.default_hooks.checkpoint.save_param_scheduler = False runner = self.build_runner(cfg) runner.train() - ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_11.pth')) + ckpt = torch.load( + osp.join(cfg.work_dir, f'{training_type}_11.pth'), + weights_only=False) self.assertNotIn('param_schedulers', ckpt) self.clear_work_dir() @@ -533,7 +541,9 @@ def test_with_runner(self, training_type): self.assertFalse( osp.isfile(osp.join(cfg.work_dir, f'{training_type}_{i}.pth'))) - ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_11.pth')) + ckpt = torch.load( + osp.join(cfg.work_dir, f'{training_type}_11.pth'), + weights_only=False) self.assertEqual(ckpt['message_hub']['runtime_info']['keep_ckpt_ids'], [9, 10, 11]) @@ -574,9 +584,11 @@ def test_with_runner(self, training_type): runner.train() best_ckpt_path = osp.join(cfg.work_dir, f'best_test_acc_{training_type}_5.pth') - best_ckpt = torch.load(best_ckpt_path) + best_ckpt = torch.load(best_ckpt_path, weights_only=False) - ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_5.pth')) + ckpt = torch.load( + osp.join(cfg.work_dir, f'{training_type}_5.pth'), + weights_only=False) self.assertEqual(best_ckpt_path, ckpt['message_hub']['runtime_info']['best_ckpt']) @@ -603,11 +615,13 @@ def test_with_runner(self, training_type): runner.train() best_ckpt_path = osp.join(cfg.work_dir, f'best_test_acc_{training_type}_5.pth') - best_ckpt = torch.load(best_ckpt_path) + best_ckpt = torch.load(best_ckpt_path, weights_only=False) # if the current ckpt is the best, the interval will be ignored the # the ckpt will also be saved - ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_5.pth')) + ckpt = torch.load( + osp.join(cfg.work_dir, f'{training_type}_5.pth'), + weights_only=False) self.assertEqual(best_ckpt_path, ckpt['message_hub']['runtime_info']['best_ckpt']) diff --git a/tests/test_hooks/test_ema_hook.py b/tests/test_hooks/test_ema_hook.py index 6dad7ba4f0..9467da45dc 100644 --- a/tests/test_hooks/test_ema_hook.py +++ b/tests/test_hooks/test_ema_hook.py @@ -230,7 +230,8 @@ def test_with_runner(self): self.assertTrue( isinstance(ema_hook.ema_model, ExponentialMovingAverage)) - checkpoint = torch.load(osp.join(self.temp_dir.name, 'epoch_2.pth')) + checkpoint = torch.load( + osp.join(self.temp_dir.name, 'epoch_2.pth'), weights_only=False) self.assertTrue('ema_state_dict' in checkpoint) self.assertTrue(checkpoint['ema_state_dict']['steps'] == 8) @@ -245,7 +246,8 @@ def test_with_runner(self): runner.test() # Test load checkpoint without ema_state_dict - checkpoint = torch.load(osp.join(self.temp_dir.name, 'epoch_2.pth')) + checkpoint = torch.load( + osp.join(self.temp_dir.name, 'epoch_2.pth'), weights_only=False) checkpoint.pop('ema_state_dict') torch.save(checkpoint, osp.join(self.temp_dir.name, 'without_ema_state_dict.pth')) @@ -274,7 +276,9 @@ def test_with_runner(self): runner = self.build_runner(cfg) runner.train() state_dict = torch.load( - osp.join(self.temp_dir.name, 'epoch_4.pth'), map_location='cpu') + osp.join(self.temp_dir.name, 'epoch_4.pth'), + map_location='cpu', + weights_only=False) self.assertIn('ema_state_dict', state_dict) for k, v in state_dict['state_dict'].items(): assert_allclose(v, state_dict['ema_state_dict']['module.' + k]) @@ -287,12 +291,16 @@ def test_with_runner(self): runner = self.build_runner(cfg) runner.train() state_dict = torch.load( - osp.join(self.temp_dir.name, 'iter_4.pth'), map_location='cpu') + osp.join(self.temp_dir.name, 'iter_4.pth'), + map_location='cpu', + weights_only=False) self.assertIn('ema_state_dict', state_dict) for k, v in state_dict['state_dict'].items(): assert_allclose(v, state_dict['ema_state_dict']['module.' + k]) state_dict = torch.load( - osp.join(self.temp_dir.name, 'iter_5.pth'), map_location='cpu') + osp.join(self.temp_dir.name, 'iter_5.pth'), + map_location='cpu', + weights_only=False) self.assertIn('ema_state_dict', state_dict) def _test_swap_parameters(self, func_name, *args, **kwargs): diff --git a/tests/test_optim/test_scheduler/test_param_scheduler.py b/tests/test_optim/test_scheduler/test_param_scheduler.py index a13072dc6e..6dc9587ed3 100644 --- a/tests/test_optim/test_scheduler/test_param_scheduler.py +++ b/tests/test_optim/test_scheduler/test_param_scheduler.py @@ -685,7 +685,8 @@ def _check_scheduler_state_dict(self, scheduler_copy = construct2() torch.save(scheduler.state_dict(), osp.join(self.temp_dir.name, 'tmp.pth')) - state_dict = torch.load(osp.join(self.temp_dir.name, 'tmp.pth')) + state_dict = torch.load( + osp.join(self.temp_dir.name, 'tmp.pth'), weights_only=False) scheduler_copy.load_state_dict(state_dict) for key in scheduler.__dict__.keys(): if key != 'optimizer': diff --git a/tests/test_runner/test_runner.py b/tests/test_runner/test_runner.py index e7668054bb..7e105f0895 100644 --- a/tests/test_runner/test_runner.py +++ b/tests/test_runner/test_runner.py @@ -2272,7 +2272,7 @@ def test_checkpoint(self): self.assertTrue(osp.exists(path)) self.assertFalse(osp.exists(osp.join(self.temp_dir, 'epoch_4.pth'))) - ckpt = torch.load(path) + ckpt = torch.load(path, weights_only=False) self.assertEqual(ckpt['meta']['epoch'], 3) self.assertEqual(ckpt['meta']['iter'], 12) self.assertEqual(ckpt['meta']['experiment_name'], @@ -2444,7 +2444,7 @@ def test_checkpoint(self): self.assertTrue(osp.exists(path)) self.assertFalse(osp.exists(osp.join(self.temp_dir, 'epoch_13.pth'))) - ckpt = torch.load(path) + ckpt = torch.load(path, weights_only=False) self.assertEqual(ckpt['meta']['epoch'], 0) self.assertEqual(ckpt['meta']['iter'], 12) assert isinstance(ckpt['optimizer'], dict) @@ -2455,7 +2455,7 @@ def test_checkpoint(self): self.assertEqual(message_hub.get_info('iter'), 11) # 2.1.2 check class attribute _statistic_methods can be saved HistoryBuffer._statistics_methods.clear() - ckpt = torch.load(path) + ckpt = torch.load(path, weights_only=False) self.assertIn('min', HistoryBuffer._statistics_methods) # 2.2 test `load_checkpoint` From ca52af691679fad10b7fb150e072eb594eb7cabb Mon Sep 17 00:00:00 2001 From: dianyo Date: Fri, 21 Feb 2025 08:17:44 +0000 Subject: [PATCH 3/4] update pre-commit zh-cn as well --- .pre-commit-config-zh-cn.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.pre-commit-config-zh-cn.yaml b/.pre-commit-config-zh-cn.yaml index 02e009fd74..e895187b07 100644 --- a/.pre-commit-config-zh-cn.yaml +++ b/.pre-commit-config-zh-cn.yaml @@ -12,8 +12,8 @@ repos: rev: 5.11.5 hooks: - id: isort - - repo: https://gitee.com/openmmlab/mirrors-yapf - rev: v0.32.0 + - repo: https://github.com/google/yapf + rev: v0.43.0 hooks: - id: yapf - repo: https://gitee.com/openmmlab/mirrors-pre-commit-hooks From 3594b5e642f6e664305ae31daa3dfc52f0a71634 Mon Sep 17 00:00:00 2001 From: dianyo Date: Fri, 21 Feb 2025 08:24:18 +0000 Subject: [PATCH 4/4] new yapf lint --- mmengine/_strategy/deepspeed.py | 8 +++---- mmengine/config/config.py | 21 +++++++++++-------- mmengine/dataset/utils.py | 3 ++- mmengine/fileio/backends/local_backend.py | 4 ++-- mmengine/fileio/file_client.py | 4 ++-- mmengine/hooks/checkpoint_hook.py | 8 +++---- mmengine/model/test_time_aug.py | 7 ++++--- mmengine/utils/dl_utils/torch_ops.py | 6 +++--- mmengine/visualization/visualizer.py | 5 +++-- tests/test_analysis/test_jit_analysis.py | 7 ++++--- tests/test_dataset/test_base_dataset.py | 8 +++---- .../test_optimizer/test_optimizer_wrapper.py | 12 +++++------ 12 files changed, 50 insertions(+), 43 deletions(-) diff --git a/mmengine/_strategy/deepspeed.py b/mmengine/_strategy/deepspeed.py index 3f89ff760d..3d945a6a54 100644 --- a/mmengine/_strategy/deepspeed.py +++ b/mmengine/_strategy/deepspeed.py @@ -310,10 +310,10 @@ def __init__( self.config.setdefault('gradient_accumulation_steps', 1) self.config['steps_per_print'] = steps_per_print self._inputs_to_half = inputs_to_half - assert (exclude_frozen_parameters is None or - digit_version(deepspeed.__version__) >= digit_version('0.13.2') - ), ('DeepSpeed >= 0.13.2 is required to enable ' - 'exclude_frozen_parameters') + assert (exclude_frozen_parameters is None or digit_version( + deepspeed.__version__) >= digit_version('0.13.2')), ( + 'DeepSpeed >= 0.13.2 is required to enable ' + 'exclude_frozen_parameters') self.exclude_frozen_parameters = exclude_frozen_parameters register_deepspeed_optimizers() diff --git a/mmengine/config/config.py b/mmengine/config/config.py index 36f92f0b3a..1b6efd5d37 100644 --- a/mmengine/config/config.py +++ b/mmengine/config/config.py @@ -48,9 +48,10 @@ def _lazy2string(cfg_dict, dict_type=None): if isinstance(cfg_dict, dict): dict_type = dict_type or type(cfg_dict) - return dict_type( - {k: _lazy2string(v, dict_type) - for k, v in dict.items(cfg_dict)}) + return dict_type({ + k: _lazy2string(v, dict_type) + for k, v in dict.items(cfg_dict) + }) elif isinstance(cfg_dict, (tuple, list)): return type(cfg_dict)(_lazy2string(v, dict_type) for v in cfg_dict) elif isinstance(cfg_dict, (LazyAttr, LazyObject)): @@ -273,13 +274,15 @@ def __reduce_ex__(self, proto): # called by CPython interpreter during pickling. See more details in # https://github.com/python/cpython/blob/8d61a71f9c81619e34d4a30b625922ebc83c561b/Objects/typeobject.c#L6196 # noqa: E501 if digit_version(platform.python_version()) < digit_version('3.8'): - return (self.__class__, ({k: v - for k, v in super().items()}, ), None, - None, None) + return (self.__class__, ({ + k: v + for k, v in super().items() + }, ), None, None, None) else: - return (self.__class__, ({k: v - for k, v in super().items()}, ), None, - None, None, None) + return (self.__class__, ({ + k: v + for k, v in super().items() + }, ), None, None, None, None) def __eq__(self, other): if isinstance(other, ConfigDict): diff --git a/mmengine/dataset/utils.py b/mmengine/dataset/utils.py index 2c9cf96497..d140cc8dc4 100644 --- a/mmengine/dataset/utils.py +++ b/mmengine/dataset/utils.py @@ -158,7 +158,8 @@ def default_collate(data_batch: Sequence) -> Any: return [default_collate(samples) for samples in transposed] elif isinstance(data_item, Mapping): return data_item_type({ - key: default_collate([d[key] for d in data_batch]) + key: + default_collate([d[key] for d in data_batch]) for key in data_item }) else: diff --git a/mmengine/fileio/backends/local_backend.py b/mmengine/fileio/backends/local_backend.py index c7d5f04621..84ebe95514 100644 --- a/mmengine/fileio/backends/local_backend.py +++ b/mmengine/fileio/backends/local_backend.py @@ -156,8 +156,8 @@ def isfile(self, filepath: Union[str, Path]) -> bool: """ return osp.isfile(filepath) - def join_path(self, filepath: Union[str, Path], - *filepaths: Union[str, Path]) -> str: + def join_path(self, filepath: Union[str, Path], *filepaths: + Union[str, Path]) -> str: r"""Concatenate all file paths. Join one or more filepath components intelligently. The return value diff --git a/mmengine/fileio/file_client.py b/mmengine/fileio/file_client.py index 61551d3d1d..29730e7564 100644 --- a/mmengine/fileio/file_client.py +++ b/mmengine/fileio/file_client.py @@ -385,8 +385,8 @@ def isfile(self, filepath: Union[str, Path]) -> bool: """ return self.client.isfile(filepath) - def join_path(self, filepath: Union[str, Path], - *filepaths: Union[str, Path]) -> str: + def join_path(self, filepath: Union[str, Path], *filepaths: + Union[str, Path]) -> str: r"""Concatenate all file paths. Join one or more filepath components intelligently. The return value diff --git a/mmengine/hooks/checkpoint_hook.py b/mmengine/hooks/checkpoint_hook.py index 92a4867bb9..3adb78c7dc 100644 --- a/mmengine/hooks/checkpoint_hook.py +++ b/mmengine/hooks/checkpoint_hook.py @@ -196,10 +196,10 @@ def __init__(self, self.save_best = save_best # rule logic - assert (isinstance(rule, str) or is_list_of(rule, str) - or (rule is None)), ( - '"rule" should be a str or list of str or None, ' - f'but got {type(rule)}') + assert (isinstance(rule, str) or is_list_of(rule, str) or + (rule + is None)), ('"rule" should be a str or list of str or None, ' + f'but got {type(rule)}') if isinstance(rule, list): # check the length of rule list assert len(rule) in [ diff --git a/mmengine/model/test_time_aug.py b/mmengine/model/test_time_aug.py index c623eec8bc..2f19248c2c 100644 --- a/mmengine/model/test_time_aug.py +++ b/mmengine/model/test_time_aug.py @@ -124,9 +124,10 @@ def test_step(self, data): data_list: Union[List[dict], List[list]] if isinstance(data, dict): num_augs = len(data[next(iter(data))]) - data_list = [{key: value[idx] - for key, value in data.items()} - for idx in range(num_augs)] + data_list = [{ + key: value[idx] + for key, value in data.items() + } for idx in range(num_augs)] elif isinstance(data, (tuple, list)): num_augs = len(data[0]) data_list = [[_data[idx] for _data in data] diff --git a/mmengine/utils/dl_utils/torch_ops.py b/mmengine/utils/dl_utils/torch_ops.py index 2550ae6986..85dc3100d2 100644 --- a/mmengine/utils/dl_utils/torch_ops.py +++ b/mmengine/utils/dl_utils/torch_ops.py @@ -4,9 +4,9 @@ from ..version_utils import digit_version from .parrots_wrapper import TORCH_VERSION -_torch_version_meshgrid_indexing = ( - 'parrots' not in TORCH_VERSION - and digit_version(TORCH_VERSION) >= digit_version('1.10.0a0')) +_torch_version_meshgrid_indexing = ('parrots' not in TORCH_VERSION + and digit_version(TORCH_VERSION) + >= digit_version('1.10.0a0')) def torch_meshgrid(*tensors): diff --git a/mmengine/visualization/visualizer.py b/mmengine/visualization/visualizer.py index 6979395aca..6653497d6e 100644 --- a/mmengine/visualization/visualizer.py +++ b/mmengine/visualization/visualizer.py @@ -754,8 +754,9 @@ def draw_bboxes( assert bboxes.shape[-1] == 4, ( f'The shape of `bboxes` should be (N, 4), but got {bboxes.shape}') - assert (bboxes[:, 0] <= bboxes[:, 2]).all() and (bboxes[:, 1] <= - bboxes[:, 3]).all() + assert (bboxes[:, 0] <= bboxes[:, 2]).all() and (bboxes[:, 1] + <= bboxes[:, + 3]).all() if not self._is_posion_valid(bboxes.reshape((-1, 2, 2))): warnings.warn( 'Warning: The bbox is out of bounds,' diff --git a/tests/test_analysis/test_jit_analysis.py b/tests/test_analysis/test_jit_analysis.py index be10309d0f..4b1dfaf595 100644 --- a/tests/test_analysis/test_jit_analysis.py +++ b/tests/test_analysis/test_jit_analysis.py @@ -634,9 +634,10 @@ def dummy_ops_handle(inputs: List[Any], dummy_flops = {} for name, counts in model.flops.items(): - dummy_flops[name] = Counter( - {op: flop - for op, flop in counts.items() if op != self.lin_op}) + dummy_flops[name] = Counter({ + op: flop + for op, flop in counts.items() if op != self.lin_op + }) dummy_flops[''][dummy_name] = 2 * dummy_out dummy_flops['fc'][dummy_name] = dummy_out dummy_flops['submod'][dummy_name] = dummy_out diff --git a/tests/test_dataset/test_base_dataset.py b/tests/test_dataset/test_base_dataset.py index f4ec815ec2..48bba665fe 100644 --- a/tests/test_dataset/test_base_dataset.py +++ b/tests/test_dataset/test_base_dataset.py @@ -733,13 +733,13 @@ def test_length(self): def test_getitem(self): assert ( self.cat_datasets[0]['imgs'] == self.dataset_a[0]['imgs']).all() - assert (self.cat_datasets[0]['imgs'] != - self.dataset_b[0]['imgs']).all() + assert (self.cat_datasets[0]['imgs'] + != self.dataset_b[0]['imgs']).all() assert ( self.cat_datasets[-1]['imgs'] == self.dataset_b[-1]['imgs']).all() - assert (self.cat_datasets[-1]['imgs'] != - self.dataset_a[-1]['imgs']).all() + assert (self.cat_datasets[-1]['imgs'] + != self.dataset_a[-1]['imgs']).all() def test_get_data_info(self): assert self.cat_datasets.get_data_info( diff --git a/tests/test_optim/test_optimizer/test_optimizer_wrapper.py b/tests/test_optim/test_optimizer/test_optimizer_wrapper.py index ef1db241dd..8a6e57d456 100644 --- a/tests/test_optim/test_optimizer/test_optimizer_wrapper.py +++ b/tests/test_optim/test_optimizer/test_optimizer_wrapper.py @@ -455,8 +455,8 @@ def test_init(self): not torch.cuda.is_available(), reason='`torch.cuda.amp` is only available when pytorch-gpu installed') def test_step(self, dtype): - if dtype is not None and (digit_version(TORCH_VERSION) < - digit_version('1.10.0')): + if dtype is not None and (digit_version(TORCH_VERSION) + < digit_version('1.10.0')): raise unittest.SkipTest('Require PyTorch version >= 1.10.0 to ' 'support `dtype` argument in autocast') if dtype == 'bfloat16' and not bf16_supported(): @@ -478,8 +478,8 @@ def test_step(self, dtype): not torch.cuda.is_available(), reason='`torch.cuda.amp` is only available when pytorch-gpu installed') def test_backward(self, dtype): - if dtype is not None and (digit_version(TORCH_VERSION) < - digit_version('1.10.0')): + if dtype is not None and (digit_version(TORCH_VERSION) + < digit_version('1.10.0')): raise unittest.SkipTest('Require PyTorch version >= 1.10.0 to ' 'support `dtype` argument in autocast') if dtype == 'bfloat16' and not bf16_supported(): @@ -539,8 +539,8 @@ def test_load_state_dict(self): not torch.cuda.is_available(), reason='`torch.cuda.amp` is only available when pytorch-gpu installed') def test_optim_context(self, dtype, target_dtype): - if dtype is not None and (digit_version(TORCH_VERSION) < - digit_version('1.10.0')): + if dtype is not None and (digit_version(TORCH_VERSION) + < digit_version('1.10.0')): raise unittest.SkipTest('Require PyTorch version >= 1.10.0 to ' 'support `dtype` argument in autocast') if dtype == 'bfloat16' and not bf16_supported():