Skip to content

Commit f5b5637

Browse files
authored
Merge pull request #3 from jeylau/main
Minor fixes, typo edits, and other cleanups
2 parents f48666f + 18a0f99 commit f5b5637

File tree

4 files changed

+45
-63
lines changed

4 files changed

+45
-63
lines changed

dlclibrary/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,5 +9,5 @@
99
# Licensed under GNU Lesser General Public License v3.0
1010
#
1111

12-
from dlclibrary.dlcmodelzoo.modelzoo_download import download_hugginface_model
12+
from dlclibrary.dlcmodelzoo.modelzoo_download import download_huggingface_model
1313
from dlclibrary.version import __version__, VERSION

dlclibrary/dlcmodelzoo/modelzoo_download.py

Lines changed: 32 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
"mouse_pupil_vclose",
1919
"horse_sideview",
2020
"full_macaque",
21-
"superanimal_mouse",
21+
"superanimal_mouse_topview",
2222
]
2323

2424

@@ -29,83 +29,59 @@ def _get_dlclibrary_path():
2929
return os.path.split(importlib.util.find_spec("dlclibrary").origin)[0]
3030

3131

32-
def _loadmodelnames():
33-
"""Loads URLs and commit hashes for available models."""
32+
def _load_model_names():
33+
"""Load URLs and commit hashes for available models."""
3434
from ruamel.yaml import YAML
3535

3636
fn = os.path.join(_get_dlclibrary_path(), "modelzoo_urls.yaml")
3737
with open(fn) as file:
3838
return YAML().load(file)
3939

4040

41-
def download_huggingface_model(modelname, target_dir=".", removeHFfolder=True):
41+
def download_huggingface_model(modelname, target_dir=".", remove_hf_folder=True):
4242
"""
43-
Downloads a DeepLabCut Model Zoo Project from Hugging Face
43+
Download a DeepLabCut Model Zoo Project from Hugging Face
4444
4545
Parameters
4646
----------
4747
modelname : string
4848
Name of the ModelZoo model. For visualizations see: http://www.mackenziemathislab.org/dlc-modelzoo
4949
target_dir : directory (as string)
50-
Directory where to store the model weigths and pose_cfg.yaml file
51-
removeHFfolder : bool, default True
50+
Directory where to store the model weights and pose_cfg.yaml file
51+
remove_hf_folder : bool, default True
5252
Whether to remove the directory structure provided by HuggingFace after downloading and decompressing data into DeepLabCut format.
5353
"""
5454
from huggingface_hub import hf_hub_download
55-
import tarfile, os
55+
import tarfile
5656
from pathlib import Path
5757

58-
neturls = _loadmodelnames()
58+
neturls = _load_model_names()
59+
if modelname not in neturls:
60+
raise ValueError(f"`modelname` should be one of: {', '.join(modelname)}.")
5961

60-
if modelname in neturls.keys():
61-
print("Loading....", modelname)
62-
url = neturls[modelname].split("/")
63-
repo_id, targzfn = url[0] + "/" + url[1], str(url[-1])
62+
print("Loading....", modelname)
63+
url = neturls[modelname].split("/")
64+
repo_id, targzfn = url[0] + "/" + url[1], str(url[-1])
6465

65-
hf_hub_download(repo_id, targzfn, cache_dir=str(target_dir))
66-
# creates a new subfolder as indicated below, unzipping from there and deleting this folder
66+
hf_hub_download(repo_id, targzfn, cache_dir=str(target_dir))
6767

68-
# Building the HuggingFaceHub download path:
69-
hf_path = (
70-
"models--"
71-
+ url[0]
72-
+ "--"
73-
+ url[1]
74-
+ "/snapshots/"
75-
+ str(neturls[modelname + "_commit"])
76-
+ "/"
77-
+ targzfn
78-
)
68+
# Create a new subfolder as indicated below, unzipping from there and deleting this folder
69+
hf_folder = f"models--{url[0]}--{url[1]}"
70+
hf_path = os.path.join(
71+
hf_folder,
72+
"snapshots",
73+
str(neturls[modelname + "_commit"]),
74+
targzfn,
75+
)
7976

80-
filename = os.path.join(target_dir, hf_path)
81-
with tarfile.open(filename, mode="r:gz") as tar:
82-
for member in tar:
83-
if not member.isdir():
84-
fname = Path(member.name).name # getting the filename
85-
tar.makefile(member, target_dir + "/" + fname)
86-
# tar.extractall(target_dir, members=tarfilenamecutting(tar))
77+
filename = os.path.join(target_dir, hf_path)
78+
with tarfile.open(filename, mode="r:gz") as tar:
79+
for member in tar:
80+
if not member.isdir():
81+
fname = Path(member.name).name
82+
tar.makefile(member, os.path.join(target_dir, fname))
8783

88-
if removeHFfolder:
89-
# Removing folder
90-
import shutil
84+
if remove_hf_folder:
85+
import shutil
9186

92-
shutil.rmtree(
93-
Path(os.path.join(target_dir, "models--" + url[0] + "--" + url[1]))
94-
)
95-
96-
else:
97-
models = [fn for fn in neturls.keys()]
98-
print("Model does not exist: ", modelname)
99-
print("Pick one of the following: ", MODELOPTIONS)
100-
101-
102-
if __name__ == "__main__":
103-
print("Randomly downloading a model for testing...")
104-
105-
import random
106-
107-
# modelname = 'full_cat'
108-
modelname = random.choice(MODELOPTIONS)
109-
110-
target_dir = "/Users/alex/Downloads" # folder has to exist!
111-
download_hugginface_model(modelname, target_dir)
87+
shutil.rmtree(os.path.join(target_dir, hf_folder))

dlclibrary/modelzoo_urls.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,5 +31,5 @@ horse_sideview_commit: fd0329b2ffc8fe7a5e6eb3d4850ebca75987e92c
3131
full_macaque: mwmathis/DeepLabCutModelZoo-macaque_full/DLC_macaque_full_resnet50.tar.gz
3232
full_macaque_commit: 4c7ebf2628d5b7eb0483356595256fb01b7e1a9e
3333

34-
superanimal_mouse: mwmathis/DeepLabCutModelZoo-SuperAnimal-TopViewMouse/DLC_ma_supertopview5k_resnet_50_iteration-0_shuffle-1.tar.gz
35-
superanimal_mouse_commit: a7d7df40c3307a3c7a0ceeb2593d46a783235b28
34+
superanimal_mouse_topview: mwmathis/DeepLabCutModelZoo-SuperAnimal-TopViewMouse/DLC_ma_supertopview5k_resnet_50_iteration-0_shuffle-1.tar.gz
35+
superanimal_mouse_topview_commit: a7d7df40c3307a3c7a0ceeb2593d46a783235b28

tests/test_modeldownload.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,21 @@
88
#
99
# Licensed under GNU Lesser General Public License v3.0
1010
#
11+
import dlclibrary
12+
import os
1113
import pytest
1214

1315

14-
def test_catdownload(tmp_path_factory):
15-
# TODO: just download the lightweight stuff..
16-
import dlclibrary, os
17-
16+
def test_download_huggingface_model(tmp_path_factory):
1817
folder = tmp_path_factory.mktemp("cat")
1918
dlclibrary.download_huggingface_model("full_cat", str(folder))
2019

2120
assert os.path.exists(folder / "pose_cfg.yaml")
2221
assert os.path.exists(folder / "snapshot-75000.meta")
22+
# Verify that the Hugging Face folder was removed
23+
assert not any(f.startswith("models--") for f in os.listdir(folder))
24+
25+
26+
def test_download_huggingface_wrong_model():
27+
with pytest.raises(ValueError):
28+
dlclibrary.download_huggingface_model("wrong_model_name")

0 commit comments

Comments
 (0)