Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 31 additions & 20 deletions mim/commands/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,26 +41,33 @@
is_flag=True,
default=True,
help='Ignore ssl certificate check')
@click.option(
'--no-checkpoint',
'ignore_checkpoint',
is_flag=True,
help='Ignore ssl certificate check')
@click.option(
'--dest', 'dest_root', type=str, help='Destination of saving checkpoints.')
def cli(package: str,
configs: List[str],
dest_root: Optional[str] = None,
check_certificate: bool = True) -> None:
check_certificate: bool = True,
ignore_checkpoint: bool = False) -> None:
"""Download checkpoints from url and parse configs from package.

\b
Example:
> mim download mmcls --config resnet18_8xb16_cifar10
> mim download mmcls --config resnet18_8xb16_cifar10 --dest .
> mim download mmpretrain --config resnet18_8xb16_cifar10
> mim download mmpretrain --config resnet18_8xb16_cifar10 --dest .
"""
download(package, configs, dest_root, check_certificate)
download(package, configs, dest_root, check_certificate, ignore_checkpoint)


def download(package: str,
configs: List[str],
dest_root: Optional[str] = None,
check_certificate: bool = True) -> List[str]:
check_certificate: bool = True,
ignore_checkpoint: bool = False) -> List[str]:
"""Download checkpoints from url and parse configs from package.

Args:
Expand All @@ -70,6 +77,8 @@ def download(package: str,
config. Default: None.
check_certificate (bool): Whether to check the ssl certificate.
Default: True.
ignore_checkpoint (bool): Whether to download checkpoints. If True,
only config will be downloaded. Default: False.
"""
if dest_root is None:
dest_root = DEFAULT_CACHE_DIR
Expand Down Expand Up @@ -114,21 +123,24 @@ def download(package: str,
for config in configs:
click.echo(f'processing {config}...')

checkpoint_urls = model_info[config]['weight']
for checkpoint_url in checkpoint_urls.split(','):
filename = checkpoint_url.split('/')[-1]
checkpoint_path = osp.join(dest_root, filename)
if osp.exists(checkpoint_path):
echo_success(f'{filename} exists in {dest_root}')
else:
# TODO: check checkpoint hash when all the models are ready.
download_from_file(
checkpoint_url,
checkpoint_path,
check_certificate=check_certificate)
if not ignore_checkpoint:
checkpoint_urls = model_info[config]['weight']
for checkpoint_url in checkpoint_urls.split(','):
filename = checkpoint_url.split('/')[-1]
checkpoint_path = osp.join(dest_root, filename)
if osp.exists(checkpoint_path):
echo_success(f'{filename} exists in {dest_root}')
else:
# TODO: check checkpoint hash when all the models are ready
download_from_file(
checkpoint_url,
checkpoint_path,
check_certificate=check_certificate)

echo_success(
f'Successfully downloaded {filename} to {dest_root}')

echo_success(
f'Successfully downloaded {filename} to {dest_root}')
checkpoints.append(filename)

config_paths = model_info[config]['config']
for config_path in config_paths.split(','):
Expand All @@ -145,7 +157,6 @@ def download(package: str,
config_obj.dump(saved_config_path)
echo_success(
f'Successfully dumped {config}.py to {dest_root}')
checkpoints.append(filename)
break
else:
raise ValueError(
Expand Down
35 changes: 16 additions & 19 deletions tests/test_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,56 +9,53 @@

def setup_module():
runner = CliRunner()
result = runner.invoke(uninstall, ['mmcv-full', '--yes'])
result = runner.invoke(uninstall, ['mmcv', '--yes'])
assert result.exit_code == 0
result = runner.invoke(uninstall, ['mmengine', '--yes'])
assert result.exit_code == 0
result = runner.invoke(uninstall, ['mmcls', '--yes'])
result = runner.invoke(uninstall, ['mmpretrain', '--yes'])
assert result.exit_code == 0


def test_download(tmp_path):
runner = CliRunner()
result = runner.invoke(install, ['mmcv-full', '--yes'])
result = runner.invoke(install, ['mmcv', '--yes'])
assert result.exit_code == 0
result = runner.invoke(install, ['mmengine', '--yes'])
assert result.exit_code == 0

with pytest.raises(ValueError):
# version is not allowed
download('mmcls==0.11.0', ['resnet18_8xb16_cifar10'])
download('mmpretrain==0.11.0', ['resnet18_8xb16_cifar10'])

with pytest.raises(RuntimeError):
# mmcls is not installed
download('mmcls', ['resnet18_8xb16_cifar10'])
# mmpretrain is not installed
download('mmpretrain', ['resnet18_8xb16_cifar10'])

with pytest.raises(ValueError):
# invalid config
download('mmcls==0.11.0', ['resnet18_b16x8_cifar1'])
download('mmpretrain', ['resnet18_b16x8_cifar1'])

runner = CliRunner()
# mim install mmcls --yes
result = runner.invoke(install, [
'mmcls', '--yes', '-f',
'https://github.com/open-mmlab/mmclassification.git'
])
# mim install mmpretrain
result = runner.invoke(install, ['mmpretrain'])
assert result.exit_code == 0

# mim download mmcls --config resnet18_8xb16_cifar10
checkpoints = download('mmcls', ['resnet18_8xb16_cifar10'])
# mim download mmpretrain --config resnet18_8xb16_cifar10
checkpoints = download('mmpretrain', ['resnet18_8xb16_cifar10'])
assert checkpoints == ['resnet18_b16x8_cifar10_20210528-bd6371c8.pth']
checkpoints = download('mmcls', ['resnet18_8xb16_cifar10'])
checkpoints = download('mmpretrain', ['resnet18_8xb16_cifar10'])

# mim download mmcls --config resnet18_8xb16_cifar10 --dest tmp_path
checkpoints = download('mmcls', ['resnet18_8xb16_cifar10'], tmp_path)
# mim download mmpretrain --config resnet18_8xb16_cifar10 --dest tmp_path
checkpoints = download('mmpretrain', ['resnet18_8xb16_cifar10'], tmp_path)
assert checkpoints == ['resnet18_b16x8_cifar10_20210528-bd6371c8.pth']


def teardown_module():
runner = CliRunner()
result = runner.invoke(uninstall, ['mmcv-full', '--yes'])
result = runner.invoke(uninstall, ['mmcv', '--yes'])
assert result.exit_code == 0
result = runner.invoke(uninstall, ['mmengine', '--yes'])
assert result.exit_code == 0
result = runner.invoke(uninstall, ['mmcls', '--yes'])
result = runner.invoke(uninstall, ['mmpretrain', '--yes'])
assert result.exit_code == 0