Skip to content

Commit e31fbb6

Browse files
committed
Fix and test code
1 parent fb11133 commit e31fbb6

File tree

3 files changed

+46
-51
lines changed

3 files changed

+46
-51
lines changed

dlclibrary/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,5 +9,5 @@
99
# Licensed under GNU Lesser General Public License v3.0
1010
#
1111

12-
from dlclibrary.dlcmodelzoo.modelzoo_download import download_hugginface_model
12+
from dlclibrary.dlcmodelzoo.modelzoo_download import download_huggingface_model
1313
from dlclibrary.version import __version__, VERSION

dlclibrary/dlcmodelzoo/modelzoo_download.py

Lines changed: 34 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def _get_dlclibrary_path():
2929
return os.path.split(importlib.util.find_spec("dlclibrary").origin)[0]
3030

3131

32-
def _loadmodelnames():
32+
def _load_model_names():
3333
"""Loads URLs and commit hashes for available models."""
3434
from ruamel.yaml import YAML
3535

@@ -38,7 +38,7 @@ def _loadmodelnames():
3838
return YAML().load(file)
3939

4040

41-
def download_huggingface_model(modelname, target_dir=".", removeHFfolder=True):
41+
def download_huggingface_model(modelname, target_dir=".", remove_hf_folder=True):
4242
"""
4343
Downloads a DeepLabCut Model Zoo Project from Hugging Face
4444
@@ -48,52 +48,40 @@ def download_huggingface_model(modelname, target_dir=".", removeHFfolder=True):
4848
Name of the ModelZoo model. For visualizations see: http://www.mackenziemathislab.org/dlc-modelzoo
4949
target_dir : directory (as string)
5050
Directory where to store the model weigths and pose_cfg.yaml file
51-
removeHFfolder : bool, default True
51+
remove_hf_folder : bool, default True
5252
Whether to remove the directory structure provided by HuggingFace after downloading and decompressing data into DeepLabCut format.
5353
"""
5454
from huggingface_hub import hf_hub_download
55-
import tarfile, os
55+
import tarfile
5656
from pathlib import Path
5757

58-
neturls = _loadmodelnames()
59-
60-
if modelname in neturls.keys():
61-
print("Loading....", modelname)
62-
url = neturls[modelname].split("/")
63-
repo_id, targzfn = url[0] + "/" + url[1], str(url[-1])
64-
65-
hf_hub_download(repo_id, targzfn, cache_dir=str(target_dir))
66-
# creates a new subfolder as indicated below, unzipping from there and deleting this folder
67-
68-
# Building the HuggingFaceHub download path:
69-
hf_path = (
70-
"models--"
71-
+ url[0]
72-
+ "--"
73-
+ url[1]
74-
+ "/snapshots/"
75-
+ str(neturls[modelname + "_commit"])
76-
+ "/"
77-
+ targzfn
78-
)
79-
80-
filename = os.path.join(target_dir, hf_path)
81-
with tarfile.open(filename, mode="r:gz") as tar:
82-
for member in tar:
83-
if not member.isdir():
84-
fname = Path(member.name).name # getting the filename
85-
tar.makefile(member, target_dir + "/" + fname)
86-
# tar.extractall(target_dir, members=tarfilenamecutting(tar))
87-
88-
if removeHFfolder:
89-
# Removing folder
90-
import shutil
91-
92-
shutil.rmtree(
93-
Path(os.path.join(target_dir, "models--" + url[0] + "--" + url[1]))
94-
)
95-
96-
else:
97-
models = [fn for fn in neturls.keys()]
98-
print("Model does not exist: ", modelname)
99-
print("Pick one of the following: ", MODELOPTIONS)
58+
neturls = _load_model_names()
59+
if modelname not in neturls:
60+
raise ValueError(f"`modelname` should be one of: {', '.join(modelname)}.")
61+
62+
print("Loading....", modelname)
63+
url = neturls[modelname].split("/")
64+
repo_id, targzfn = url[0] + "/" + url[1], str(url[-1])
65+
66+
hf_hub_download(repo_id, targzfn, cache_dir=str(target_dir))
67+
68+
# Create a new subfolder as indicated below, unzipping from there and deleting this folder
69+
hf_folder = f"models--{url[0]}--{url[1]}"
70+
hf_path = os.path.join(
71+
hf_folder,
72+
"snapshots",
73+
str(neturls[modelname + "_commit"]),
74+
targzfn,
75+
)
76+
77+
filename = os.path.join(target_dir, hf_path)
78+
with tarfile.open(filename, mode="r:gz") as tar:
79+
for member in tar:
80+
if not member.isdir():
81+
fname = Path(member.name).name
82+
tar.makefile(member, os.path.join(target_dir, fname))
83+
84+
if remove_hf_folder:
85+
import shutil
86+
87+
shutil.rmtree(os.path.join(target_dir, hf_folder))

tests/test_modeldownload.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,21 @@
88
#
99
# Licensed under GNU Lesser General Public License v3.0
1010
#
11+
import dlclibrary
12+
import os
13+
import pytest
1114

1215

13-
def test_catdownload(tmp_path_factory):
14-
# TODO: just download the lightweight stuff..
15-
import dlclibrary, os
16-
16+
def test_download_huggingface_model(tmp_path_factory):
1717
folder = tmp_path_factory.mktemp("cat")
1818
dlclibrary.download_huggingface_model("full_cat", str(folder))
1919

2020
assert os.path.exists(folder / "pose_cfg.yaml")
2121
assert os.path.exists(folder / "snapshot-75000.meta")
22+
# Verify that the Hugging Face folder was removed
23+
assert not any(f.startswith("models--") for f in os.listdir(folder))
24+
25+
26+
def test_download_huggingface_wrong_model():
27+
with pytest.raises(ValueError):
28+
dlclibrary.download_huggingface_model("wrong_model_name")

0 commit comments

Comments
 (0)