diff --git a/mim/commands/download.py b/mim/commands/download.py index a49921a..f77e71b 100644 --- a/mim/commands/download.py +++ b/mim/commands/download.py @@ -41,26 +41,33 @@ is_flag=True, default=True, help='Ignore ssl certificate check') +@click.option( + '--no-checkpoint', + 'ignore_checkpoint', + is_flag=True, + help='Ignore ssl certificate check') @click.option( '--dest', 'dest_root', type=str, help='Destination of saving checkpoints.') def cli(package: str, configs: List[str], dest_root: Optional[str] = None, - check_certificate: bool = True) -> None: + check_certificate: bool = True, + ignore_checkpoint: bool = False) -> None: """Download checkpoints from url and parse configs from package. \b Example: - > mim download mmcls --config resnet18_8xb16_cifar10 - > mim download mmcls --config resnet18_8xb16_cifar10 --dest . + > mim download mmpretrain --config resnet18_8xb16_cifar10 + > mim download mmpretrain --config resnet18_8xb16_cifar10 --dest . """ - download(package, configs, dest_root, check_certificate) + download(package, configs, dest_root, check_certificate, ignore_checkpoint) def download(package: str, configs: List[str], dest_root: Optional[str] = None, - check_certificate: bool = True) -> List[str]: + check_certificate: bool = True, + ignore_checkpoint: bool = False) -> List[str]: """Download checkpoints from url and parse configs from package. Args: @@ -70,6 +77,8 @@ def download(package: str, config. Default: None. check_certificate (bool): Whether to check the ssl certificate. Default: True. + ignore_checkpoint (bool): Whether to download checkpoints. If True, + only config will be downloaded. Default: False. """ if dest_root is None: dest_root = DEFAULT_CACHE_DIR @@ -114,21 +123,24 @@ def download(package: str, for config in configs: click.echo(f'processing {config}...') - checkpoint_urls = model_info[config]['weight'] - for checkpoint_url in checkpoint_urls.split(','): - filename = checkpoint_url.split('/')[-1] - checkpoint_path = osp.join(dest_root, filename) - if osp.exists(checkpoint_path): - echo_success(f'{filename} exists in {dest_root}') - else: - # TODO: check checkpoint hash when all the models are ready. - download_from_file( - checkpoint_url, - checkpoint_path, - check_certificate=check_certificate) + if not ignore_checkpoint: + checkpoint_urls = model_info[config]['weight'] + for checkpoint_url in checkpoint_urls.split(','): + filename = checkpoint_url.split('/')[-1] + checkpoint_path = osp.join(dest_root, filename) + if osp.exists(checkpoint_path): + echo_success(f'{filename} exists in {dest_root}') + else: + # TODO: check checkpoint hash when all the models are ready + download_from_file( + checkpoint_url, + checkpoint_path, + check_certificate=check_certificate) + + echo_success( + f'Successfully downloaded {filename} to {dest_root}') - echo_success( - f'Successfully downloaded {filename} to {dest_root}') + checkpoints.append(filename) config_paths = model_info[config]['config'] for config_path in config_paths.split(','): @@ -145,7 +157,6 @@ def download(package: str, config_obj.dump(saved_config_path) echo_success( f'Successfully dumped {config}.py to {dest_root}') - checkpoints.append(filename) break else: raise ValueError( diff --git a/tests/test_download.py b/tests/test_download.py index 7413303..8585a27 100644 --- a/tests/test_download.py +++ b/tests/test_download.py @@ -9,56 +9,53 @@ def setup_module(): runner = CliRunner() - result = runner.invoke(uninstall, ['mmcv-full', '--yes']) + result = runner.invoke(uninstall, ['mmcv', '--yes']) assert result.exit_code == 0 result = runner.invoke(uninstall, ['mmengine', '--yes']) assert result.exit_code == 0 - result = runner.invoke(uninstall, ['mmcls', '--yes']) + result = runner.invoke(uninstall, ['mmpretrain', '--yes']) assert result.exit_code == 0 def test_download(tmp_path): runner = CliRunner() - result = runner.invoke(install, ['mmcv-full', '--yes']) + result = runner.invoke(install, ['mmcv', '--yes']) assert result.exit_code == 0 result = runner.invoke(install, ['mmengine', '--yes']) assert result.exit_code == 0 with pytest.raises(ValueError): # version is not allowed - download('mmcls==0.11.0', ['resnet18_8xb16_cifar10']) + download('mmpretrain==0.11.0', ['resnet18_8xb16_cifar10']) with pytest.raises(RuntimeError): - # mmcls is not installed - download('mmcls', ['resnet18_8xb16_cifar10']) + # mmpretrain is not installed + download('mmpretrain', ['resnet18_8xb16_cifar10']) with pytest.raises(ValueError): # invalid config - download('mmcls==0.11.0', ['resnet18_b16x8_cifar1']) + download('mmpretrain', ['resnet18_b16x8_cifar1']) runner = CliRunner() - # mim install mmcls --yes - result = runner.invoke(install, [ - 'mmcls', '--yes', '-f', - 'https://github.com/open-mmlab/mmclassification.git' - ]) + # mim install mmpretrain + result = runner.invoke(install, ['mmpretrain']) assert result.exit_code == 0 - # mim download mmcls --config resnet18_8xb16_cifar10 - checkpoints = download('mmcls', ['resnet18_8xb16_cifar10']) + # mim download mmpretrain --config resnet18_8xb16_cifar10 + checkpoints = download('mmpretrain', ['resnet18_8xb16_cifar10']) assert checkpoints == ['resnet18_b16x8_cifar10_20210528-bd6371c8.pth'] - checkpoints = download('mmcls', ['resnet18_8xb16_cifar10']) + checkpoints = download('mmpretrain', ['resnet18_8xb16_cifar10']) - # mim download mmcls --config resnet18_8xb16_cifar10 --dest tmp_path - checkpoints = download('mmcls', ['resnet18_8xb16_cifar10'], tmp_path) + # mim download mmpretrain --config resnet18_8xb16_cifar10 --dest tmp_path + checkpoints = download('mmpretrain', ['resnet18_8xb16_cifar10'], tmp_path) assert checkpoints == ['resnet18_b16x8_cifar10_20210528-bd6371c8.pth'] def teardown_module(): runner = CliRunner() - result = runner.invoke(uninstall, ['mmcv-full', '--yes']) + result = runner.invoke(uninstall, ['mmcv', '--yes']) assert result.exit_code == 0 result = runner.invoke(uninstall, ['mmengine', '--yes']) assert result.exit_code == 0 - result = runner.invoke(uninstall, ['mmcls', '--yes']) + result = runner.invoke(uninstall, ['mmpretrain', '--yes']) assert result.exit_code == 0