```diff
@@ -1,6 +1,6 @@
 import os
 from pathlib import Path
-from typing import Dict, List, Literal, Optional, Union
+from typing import Dict, Iterable, List, Literal, Optional, Union
 
 import requests
 from tqdm.auto import tqdm as base_tqdm
@@ -15,13 +15,15 @@
     RevisionNotFoundError,
 )
 from .file_download import REGEX_COMMIT_HASH, hf_hub_download, repo_folder_name
-from .hf_api import DatasetInfo, HfApi, ModelInfo, SpaceInfo
+from .hf_api import DatasetInfo, HfApi, ModelInfo, RepoFile, SpaceInfo
 from .utils import OfflineModeIsEnabled, filter_repo_objects, logging, validate_hf_hub_args
 from .utils import tqdm as hf_tqdm
 
 
 logger = logging.get_logger(__name__)
 
+VERY_LARGE_REPO_THRESHOLD = 50000  # After this limit, we don't consider `repo_info.siblings` to be reliable enough
+
 
 @validate_hf_hub_args
 def snapshot_download(
@@ -145,20 +147,22 @@ def snapshot_download(
 
     storage_folder = os.path.join(cache_dir, repo_folder_name(repo_id=repo_id, repo_type=repo_type))
 
+    api = HfApi(
+        library_name=library_name,
+        library_version=library_version,
+        user_agent=user_agent,
+        endpoint=endpoint,
+        headers=headers,
+        token=token,
+    )
+
     repo_info: Union[ModelInfo, DatasetInfo, SpaceInfo, None] = None
     api_call_error: Optional[Exception] = None
     if not local_files_only:
         # try/except logic to handle different errors => taken from `hf_hub_download`
         try:
             # if we have internet connection we want to list files to download
-            api = HfApi(
-                library_name=library_name,
-                library_version=library_version,
-                user_agent=user_agent,
-                endpoint=endpoint,
-                headers=headers,
-            )
-            repo_info = api.repo_info(repo_id=repo_id, repo_type=repo_type, revision=revision, token=token)
+            repo_info = api.repo_info(repo_id=repo_id, repo_type=repo_type, revision=revision)
         except (requests.exceptions.SSLError, requests.exceptions.ProxyError):
             # Actually raise for those subclasses of ConnectionError
             raise
```
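In the hunk above, the `HfApi` client is now built once, outside the `try` block, with `token` passed to the constructor instead of to each method call, so the same authenticated client can later serve `list_repo_tree`. A minimal sketch of that pattern, using hypothetical caller metadata and a well-known repo id:

```python
from huggingface_hub import HfApi

# Configure credentials and client metadata once on the instance...
api = HfApi(
    library_name="my-lib",  # hypothetical caller metadata
    library_version="0.1.0",
    token=None,  # or a user access token string
)
# ...then every call reuses them; no per-call `token=` needed.
info = api.repo_info(repo_id="gpt2", repo_type="model")
print(info.sha)
```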
```diff
@@ -251,13 +255,31 @@ def snapshot_download(
     # => let's download the files!
     assert repo_info.sha is not None, "Repo info returned from server must have a revision sha."
     assert repo_info.siblings is not None, "Repo info returned from server must have a siblings list."
-    filtered_repo_files = list(
-        filter_repo_objects(
-            items=[f.rfilename for f in repo_info.siblings],
-            allow_patterns=allow_patterns,
-            ignore_patterns=ignore_patterns,
+
+    # Corner case: on very large repos, the siblings list in `repo_info` might not contain all files.
+    # In that case, we need to use the `list_repo_tree` method to prevent caching issues.
+    repo_files: Iterable[str] = [f.rfilename for f in repo_info.siblings]
+    has_many_files = len(repo_info.siblings) > VERY_LARGE_REPO_THRESHOLD
+    if has_many_files:
+        logger.info("The repo has more than 50,000 files. Using `list_repo_tree` to ensure all files are listed.")
+        repo_files = (
+            f.rfilename
+            for f in api.list_repo_tree(repo_id=repo_id, recursive=True, revision=revision, repo_type=repo_type)
+            if isinstance(f, RepoFile)
         )
+
+    filtered_repo_files: Iterable[str] = filter_repo_objects(
+        items=repo_files,
+        allow_patterns=allow_patterns,
+        ignore_patterns=ignore_patterns,
     )
+
+    if not has_many_files:
+        filtered_repo_files = list(filtered_repo_files)
+        tqdm_desc = f"Fetching {len(filtered_repo_files)} files"
+    else:
+        tqdm_desc = "Fetching ... files"
+
     commit_hash = repo_info.sha
     snapshot_folder = os.path.join(storage_folder, "snapshots", commit_hash)
     # if passed revision is not identical to commit_hash
```
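This hunk is the heart of the fix: past `VERY_LARGE_REPO_THRESHOLD`, file names are streamed from `list_repo_tree` rather than read from `repo_info.siblings`, which may not contain every file on very large repos, and `filter_repo_objects` (itself a generator) keeps the pipeline lazy. A standalone sketch of the same technique — the repo id is hypothetical, and network access is assumed:

```python
from huggingface_hub import HfApi
from huggingface_hub.hf_api import RepoFile
from huggingface_hub.utils import filter_repo_objects

api = HfApi()

# Stream the full tree; `list_repo_tree` pages through results lazily,
# so the complete file list is never held in memory at once.
repo_files = (
    entry.rfilename  # same attribute the snapshot code reads
    for entry in api.list_repo_tree(
        "some-org/very-large-dataset",  # hypothetical repo id
        repo_type="dataset",
        recursive=True,
    )
    if isinstance(entry, RepoFile)  # skip RepoFolder entries
)

# `filter_repo_objects` yields matches one by one, keeping the chain lazy.
for name in filter_repo_objects(items=repo_files, allow_patterns=["*.parquet"]):
    print(name)
    break  # first match arrives without consuming the whole listing
```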
```diff
@@ -305,7 +327,7 @@ def _inner_hf_hub_download(repo_file: str):
         thread_map(
             _inner_hf_hub_download,
             filtered_repo_files,
-            desc=f"Fetching {len(filtered_repo_files)} files",
+            desc=tqdm_desc,
             max_workers=max_workers,
             # User can use its own tqdm class or the default one from `huggingface_hub.utils`
             tqdm_class=tqdm_class or hf_tqdm,
```
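Because a streamed listing has no known length, the progress-bar description is now precomputed: the file count is only shown when the list was small enough to materialize, otherwise it falls back to `"Fetching ... files"`. From the caller's side nothing changes; a hedged usage sketch against a hypothetical oversized dataset:

```python
from huggingface_hub import snapshot_download

# Patterns are applied to the streamed listing, so the file list is no
# longer read from `repo_info.siblings` (unreliable past ~50,000 files).
local_path = snapshot_download(
    repo_id="some-org/very-large-dataset",  # hypothetical repo id
    repo_type="dataset",
    allow_patterns=["*.json"],
)
print(local_path)  # path of the snapshot folder inside the local cache
```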