@@ -29,7 +29,7 @@ def _get_dlclibrary_path():
2929 return os .path .split (importlib .util .find_spec ("dlclibrary" ).origin )[0 ]
3030
3131
32- def _loadmodelnames ():
32+ def _load_model_names ():
3333 """Loads URLs and commit hashes for available models."""
3434 from ruamel .yaml import YAML
3535
@@ -38,7 +38,7 @@ def _loadmodelnames():
3838 return YAML ().load (file )
3939
4040
def download_huggingface_model(modelname, target_dir=".", remove_hf_folder=True):
    """
    Downloads a DeepLabCut Model Zoo Project from Hugging Face.

    Parameters
    ----------
    modelname : string
        Name of the ModelZoo model. For visualizations see: http://www.mackenziemathislab.org/dlc-modelzoo
    target_dir : directory (as string)
        Directory where to store the model weights and pose_cfg.yaml file
    remove_hf_folder : bool, default True
        Whether to remove the directory structure provided by HuggingFace after downloading and decompressing data into DeepLabCut format.

    Raises
    ------
    ValueError
        If `modelname` is not one of the available ModelZoo model names.
    """
    from huggingface_hub import hf_hub_download
    import tarfile
    from pathlib import Path

    neturls = _load_model_names()
    if modelname not in neturls:
        # Bug fix: list the valid model names. Joining `modelname` itself
        # would enumerate the characters of the bad input, not the options.
        raise ValueError(f"`modelname` should be one of: {', '.join(neturls)}.")

    print("Loading....", modelname)
    url = neturls[modelname].split("/")
    repo_id, targzfn = url[0] + "/" + url[1], str(url[-1])

    hf_hub_download(repo_id, targzfn, cache_dir=str(target_dir))

    # hf_hub_download caches under models--<org>--<repo>/snapshots/<commit>/<file>;
    # rebuild that path to locate the downloaded archive, then clean it up below.
    hf_folder = f"models--{url[0]}--{url[1]}"
    hf_path = os.path.join(
        hf_folder,
        "snapshots",
        str(neturls[modelname + "_commit"]),
        targzfn,
    )

    filename = os.path.join(target_dir, hf_path)
    # Decompress every file flat into target_dir, discarding the archive's
    # internal directory structure (DeepLabCut expects a flat layout).
    with tarfile.open(filename, mode="r:gz") as tar:
        for member in tar:
            if not member.isdir():
                fname = Path(member.name).name  # basename only
                tar.makefile(member, os.path.join(target_dir, fname))

    if remove_hf_folder:
        import shutil

        shutil.rmtree(os.path.join(target_dir, hf_folder))
0 commit comments