diff --git a/src/huggingface_hub/_commit_api.py b/src/huggingface_hub/_commit_api.py
index 229c19026f..52c8f06080 100644
--- a/src/huggingface_hub/_commit_api.py
+++ b/src/huggingface_hub/_commit_api.py
@@ -15,7 +15,7 @@
 from tqdm.contrib.concurrent import thread_map
 
-from .constants import ENDPOINT, HF_HUB_ENABLE_HF_TRANSFER
+from . import constants
 from .errors import EntryNotFoundError
 from .file_download import hf_hub_url
 from .lfs import UploadInfo, lfs_upload, post_lfs_batch_info
@@ -432,7 +432,7 @@ def _wrapped_lfs_upload(batch_action) -> None:
         except Exception as exc:
             raise RuntimeError(f"Error while uploading '{operation.path_in_repo}' to the Hub.") from exc
 
-    if HF_HUB_ENABLE_HF_TRANSFER:
+    if constants.HF_HUB_ENABLE_HF_TRANSFER:
         logger.debug(f"Uploading {len(filtered_actions)} LFS files to the Hub using `hf_transfer`.")
         for action in hf_tqdm(filtered_actions, name="huggingface_hub.lfs_upload"):
             _wrapped_lfs_upload(action)
@@ -506,7 +506,7 @@ def _fetch_upload_modes(
         [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
             If the Hub API response is improperly formatted.
     """
-    endpoint = endpoint if endpoint is not None else ENDPOINT
+    endpoint = endpoint if endpoint is not None else constants.ENDPOINT
 
     # Fetch upload mode (LFS or regular) chunk by chunk.
     upload_modes: Dict[str, UploadMode] = {}
diff --git a/src/huggingface_hub/_snapshot_download.py b/src/huggingface_hub/_snapshot_download.py
index 6fbdce7b0c..39f5f09c11 100644
--- a/src/huggingface_hub/_snapshot_download.py
+++ b/src/huggingface_hub/_snapshot_download.py
@@ -6,13 +6,7 @@
 from tqdm.auto import tqdm as base_tqdm
 from tqdm.contrib.concurrent import thread_map
 
-from .constants import (
-    DEFAULT_ETAG_TIMEOUT,
-    DEFAULT_REVISION,
-    HF_HUB_CACHE,
-    HF_HUB_ENABLE_HF_TRANSFER,
-    REPO_TYPES,
-)
+from . import constants
 from .errors import GatedRepoError, LocalEntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
 from .file_download import REGEX_COMMIT_HASH, hf_hub_download, repo_folder_name
 from .hf_api import DatasetInfo, HfApi, ModelInfo, SpaceInfo
@@ -40,7 +34,7 @@ def snapshot_download(
     library_version: Optional[str] = None,
     user_agent: Optional[Union[Dict, str]] = None,
     proxies: Optional[Dict] = None,
-    etag_timeout: float = DEFAULT_ETAG_TIMEOUT,
+    etag_timeout: float = constants.DEFAULT_ETAG_TIMEOUT,
     force_download: bool = False,
     token: Optional[Union[bool, str]] = None,
     local_files_only: bool = False,
@@ -137,16 +131,16 @@ def snapshot_download(
         if some parameter value is invalid.
     """
     if cache_dir is None:
-        cache_dir = HF_HUB_CACHE
+        cache_dir = constants.HF_HUB_CACHE
     if revision is None:
-        revision = DEFAULT_REVISION
+        revision = constants.DEFAULT_REVISION
     if isinstance(cache_dir, Path):
         cache_dir = str(cache_dir)
 
     if repo_type is None:
         repo_type = "model"
-    if repo_type not in REPO_TYPES:
-        raise ValueError(f"Invalid repo type: {repo_type}. Accepted repo types are: {str(REPO_TYPES)}")
+    if repo_type not in constants.REPO_TYPES:
+        raise ValueError(f"Invalid repo type: {repo_type}. Accepted repo types are: {str(constants.REPO_TYPES)}")
 
     storage_folder = os.path.join(cache_dir, repo_folder_name(repo_id=repo_id, repo_type=repo_type))
 
@@ -287,7 +281,7 @@ def _inner_hf_hub_download(repo_file: str):
             headers=headers,
         )
 
-    if HF_HUB_ENABLE_HF_TRANSFER:
+    if constants.HF_HUB_ENABLE_HF_TRANSFER:
         # when using hf_transfer we don't want extra parallelism
         # from the one hf_transfer provides
         for file in filtered_repo_files:
diff --git a/src/huggingface_hub/community.py b/src/huggingface_hub/community.py
index 387b0cc121..16f2f02428 100644
--- a/src/huggingface_hub/community.py
+++ b/src/huggingface_hub/community.py
@@ -9,7 +9,7 @@
 from datetime import datetime
 from typing import List, Literal, Optional, Union
 
-from .constants import REPO_TYPE_MODEL
+from . import constants
 from .utils import parse_datetime
 
@@ -79,7 +79,7 @@ def git_reference(self) -> Optional[str]:
     @property
     def url(self) -> str:
         """Returns the URL of the discussion on the Hub."""
-        if self.repo_type is None or self.repo_type == REPO_TYPE_MODEL:
+        if self.repo_type is None or self.repo_type == constants.REPO_TYPE_MODEL:
             return f"{self.endpoint}/{self.repo_id}/discussions/{self.num}"
         return f"{self.endpoint}/{self.repo_type}s/{self.repo_id}/discussions/{self.num}"
 
diff --git a/src/huggingface_hub/fastai_utils.py b/src/huggingface_hub/fastai_utils.py
index e586e8663c..3a9bf25f44 100644
--- a/src/huggingface_hub/fastai_utils.py
+++ b/src/huggingface_hub/fastai_utils.py
@@ -6,8 +6,7 @@
 
 from packaging import version
 
-from huggingface_hub import snapshot_download
-from huggingface_hub.constants import CONFIG_NAME
+from huggingface_hub import constants, snapshot_download
 from huggingface_hub.hf_api import HfApi
 from huggingface_hub.utils import (
     SoftTemporaryDirectory,
@@ -272,7 +271,7 @@ def _save_pretrained_fastai(
     if config is not None:
         if not isinstance(config, dict):
             raise RuntimeError(f"Provided config should be a dict. Got: '{type(config)}'")
-        path = os.path.join(save_directory, CONFIG_NAME)
+        path = os.path.join(save_directory, constants.CONFIG_NAME)
         with open(path, "w") as f:
             json.dump(config, f)
 
diff --git a/src/huggingface_hub/file_download.py b/src/huggingface_hub/file_download.py
index 52bfdd2c73..6ed745f4a4 100644
--- a/src/huggingface_hub/file_download.py
+++ b/src/huggingface_hub/file_download.py
@@ -18,32 +18,18 @@
 
 import requests
 
-from . import __version__  # noqa: F401 # for backward compatibility
+from . import (
+    __version__,  # noqa: F401 # for backward compatibility
+    constants,
+)
 from ._local_folder import (
     get_local_download_paths,
     read_download_metadata,
     write_download_metadata,
 )
 from .constants import (
-    DEFAULT_ETAG_TIMEOUT,
-    DEFAULT_REQUEST_TIMEOUT,
-    DEFAULT_REVISION,
-    DOWNLOAD_CHUNK_SIZE,
-    ENDPOINT,
-    HF_HUB_CACHE,
-    HF_HUB_DISABLE_SYMLINKS_WARNING,
-    HF_HUB_DOWNLOAD_TIMEOUT,
-    HF_HUB_ENABLE_HF_TRANSFER,
-    HF_HUB_ETAG_TIMEOUT,
-    HF_TRANSFER_CONCURRENCY,
-    HUGGINGFACE_CO_URL_TEMPLATE,
-    HUGGINGFACE_HEADER_X_LINKED_ETAG,
-    HUGGINGFACE_HEADER_X_LINKED_SIZE,
-    HUGGINGFACE_HEADER_X_REPO_COMMIT,
+    HUGGINGFACE_CO_URL_TEMPLATE,  # noqa: F401 # for backward compatibility
     HUGGINGFACE_HUB_CACHE,  # noqa: F401 # for backward compatibility
-    REPO_ID_SEPARATOR,
-    REPO_TYPES,
-    REPO_TYPES_URL_PREFIXES,
 )
 from .errors import (
     EntryNotFoundError,
@@ -118,7 +104,7 @@ def are_symlinks_supported(cache_dir: Union[str, Path, None] = None) -> bool:
     """
     # Defaults to HF cache
     if cache_dir is None:
-        cache_dir = HF_HUB_CACHE
+        cache_dir = constants.HF_HUB_CACHE
     cache_dir = str(Path(cache_dir).expanduser().resolve())  # make it unique
 
     # Check symlink compatibility only once (per cache directory) at first time use
@@ -139,7 +125,7 @@ def are_symlinks_supported(cache_dir: Union[str, Path, None] = None) -> bool:
             # Likely running on Windows
             _are_symlinks_supported_in_dir[cache_dir] = False
 
-            if not HF_HUB_DISABLE_SYMLINKS_WARNING:
+            if not constants.HF_HUB_DISABLE_SYMLINKS_WARNING:
                 message = (
                     "`huggingface_hub` cache-system uses symlinks by default to"
                     " efficiently store duplicated files but your machine does not"
@@ -259,20 +245,20 @@ def hf_hub_url(
     if subfolder is not None:
         filename = f"{subfolder}/{filename}"
 
-    if repo_type not in REPO_TYPES:
+    if repo_type not in constants.REPO_TYPES:
         raise ValueError("Invalid repo type")
 
-    if repo_type in REPO_TYPES_URL_PREFIXES:
-        repo_id = REPO_TYPES_URL_PREFIXES[repo_type] + repo_id
+    if repo_type in constants.REPO_TYPES_URL_PREFIXES:
+        repo_id = constants.REPO_TYPES_URL_PREFIXES[repo_type] + repo_id
 
     if revision is None:
-        revision = DEFAULT_REVISION
+        revision = constants.DEFAULT_REVISION
     url = HUGGINGFACE_CO_URL_TEMPLATE.format(
         repo_id=repo_id, revision=quote(revision, safe=""), filename=quote(filename)
     )
     # Update endpoint if provided
-    if endpoint is not None and url.startswith(ENDPOINT):
-        url = endpoint + url[len(ENDPOINT) :]
+    if endpoint is not None and url.startswith(constants.ENDPOINT):
+        url = endpoint + url[len(constants.ENDPOINT) :]
     return url
 
@@ -335,7 +321,7 @@ def filename_to_url(
     )
 
     if cache_dir is None:
-        cache_dir = HF_HUB_CACHE
+        cache_dir = constants.HF_HUB_CACHE
     if isinstance(cache_dir, Path):
         cache_dir = str(cache_dir)
 
@@ -442,7 +428,7 @@ def http_get(
             not set, the filename is guessed from the URL or the `Content-Disposition` header.
     """
     hf_transfer = None
-    if HF_HUB_ENABLE_HF_TRANSFER:
+    if constants.HF_HUB_ENABLE_HF_TRANSFER:
         if resume_size != 0:
             warnings.warn("'hf_transfer' does not support `resume_size`: falling back to regular download method")
         elif proxies is not None:
@@ -463,7 +449,7 @@ def http_get(
         headers["Range"] = "bytes=%d-" % (resume_size,)
 
     r = _request_wrapper(
-        method="GET", url=url, stream=True, proxies=proxies, headers=headers, timeout=HF_HUB_DOWNLOAD_TIMEOUT
+        method="GET", url=url, stream=True, proxies=proxies, headers=headers, timeout=constants.HF_HUB_DOWNLOAD_TIMEOUT
     )
     hf_raise_for_status(r)
     content_length = r.headers.get("Content-Length")
@@ -513,7 +499,7 @@ def http_get(
 
     with progress_cm as progress:
-        if hf_transfer and total is not None and total > 5 * DOWNLOAD_CHUNK_SIZE:
+        if hf_transfer and total is not None and total > 5 * constants.DOWNLOAD_CHUNK_SIZE:
             supports_callback = "callback" in inspect.signature(hf_transfer.download).parameters
             if not supports_callback:
                 warnings.warn(
@@ -525,8 +511,8 @@ def http_get(
             hf_transfer.download(
                 url=url,
                 filename=temp_file.name,
-                max_files=HF_TRANSFER_CONCURRENCY,
-                chunk_size=DOWNLOAD_CHUNK_SIZE,
+                max_files=constants.HF_TRANSFER_CONCURRENCY,
+                chunk_size=constants.DOWNLOAD_CHUNK_SIZE,
                 headers=headers,
                 parallel_failures=3,
                 max_retries=5,
@@ -548,7 +534,7 @@ def http_get(
             return
         new_resume_size = resume_size
         try:
-            for chunk in r.iter_content(chunk_size=DOWNLOAD_CHUNK_SIZE):
+            for chunk in r.iter_content(chunk_size=constants.DOWNLOAD_CHUNK_SIZE):
                 if chunk:  # filter out keep-alive new chunks
                     progress.update(len(chunk))
                     temp_file.write(chunk)
@@ -596,7 +582,7 @@ def cached_download(
     force_download: bool = False,
     force_filename: Optional[str] = None,
    proxies: Optional[Dict] = None,
-    etag_timeout: float = DEFAULT_ETAG_TIMEOUT,
+    etag_timeout: float = constants.DEFAULT_ETAG_TIMEOUT,
     resume_download: Optional[bool] = None,
     token: Union[bool, str, None] = None,
     local_files_only: bool = False,
@@ -674,9 +660,9 @@ def cached_download(
     """
 
-    if HF_HUB_ETAG_TIMEOUT != DEFAULT_ETAG_TIMEOUT:
+    if constants.HF_HUB_ETAG_TIMEOUT != constants.DEFAULT_ETAG_TIMEOUT:
         # Respect environment variable above user value
-        etag_timeout = HF_HUB_ETAG_TIMEOUT
+        etag_timeout = constants.HF_HUB_ETAG_TIMEOUT
 
     if not legacy_cache_layout:
         warnings.warn(
@@ -693,7 +679,7 @@ def cached_download(
         )
 
     if cache_dir is None:
-        cache_dir = HF_HUB_CACHE
+        cache_dir = constants.HF_HUB_CACHE
     if isinstance(cache_dir, Path):
         cache_dir = str(cache_dir)
 
@@ -725,7 +711,7 @@ def cached_download(
             )
             headers.pop("Accept-Encoding", None)
             hf_raise_for_status(r)
-            etag = r.headers.get(HUGGINGFACE_HEADER_X_LINKED_ETAG) or r.headers.get("ETag")
+            etag = r.headers.get(constants.HUGGINGFACE_HEADER_X_LINKED_ETAG) or r.headers.get("ETag")
             # We favor a custom header indicating the etag of the linked resource, and
             # we fallback to the regular etag header.
             # If we don't have any of those, raise an error.
@@ -972,7 +958,7 @@ def repo_folder_name(*, repo_id: str, repo_type: str) -> str:
     """
     # remove all `/` occurrences to correctly convert repo to directory name
    parts = [f"{repo_type}s", *repo_id.split("/")]
-    return REPO_ID_SEPARATOR.join(parts)
+    return constants.REPO_ID_SEPARATOR.join(parts)
 
 
 def _check_disk_space(expected_size: int, target_dir: Union[str, Path]) -> None:
@@ -1023,7 +1009,7 @@ def hf_hub_download(
     user_agent: Union[Dict, str, None] = None,
     force_download: bool = False,
     proxies: Optional[Dict] = None,
-    etag_timeout: float = DEFAULT_ETAG_TIMEOUT,
+    etag_timeout: float = constants.DEFAULT_ETAG_TIMEOUT,
     token: Union[bool, str, None] = None,
     local_files_only: bool = False,
     headers: Optional[Dict[str, str]] = None,
@@ -1137,9 +1123,9 @@ def hf_hub_download(
         If some parameter value is invalid.
     """
-    if HF_HUB_ETAG_TIMEOUT != DEFAULT_ETAG_TIMEOUT:
+    if constants.HF_HUB_ETAG_TIMEOUT != constants.DEFAULT_ETAG_TIMEOUT:
         # Respect environment variable above user value
-        etag_timeout = HF_HUB_ETAG_TIMEOUT
+        etag_timeout = constants.HF_HUB_ETAG_TIMEOUT
 
     if force_filename is not None:
         warnings.warn(
@@ -1182,9 +1168,9 @@ def hf_hub_download(
         )
 
     if cache_dir is None:
-        cache_dir = HF_HUB_CACHE
+        cache_dir = constants.HF_HUB_CACHE
     if revision is None:
-        revision = DEFAULT_REVISION
+        revision = constants.DEFAULT_REVISION
     if isinstance(cache_dir, Path):
         cache_dir = str(cache_dir)
     if isinstance(local_dir, Path):
@@ -1198,8 +1184,8 @@ def hf_hub_download(
 
     if repo_type is None:
         repo_type = "model"
-    if repo_type not in REPO_TYPES:
-        raise ValueError(f"Invalid repo type: {repo_type}. Accepted repo types are: {str(REPO_TYPES)}")
+    if repo_type not in constants.REPO_TYPES:
+        raise ValueError(f"Invalid repo type: {repo_type}. Accepted repo types are: {str(constants.REPO_TYPES)}")
 
     headers = build_hf_headers(
         token=token,
@@ -1583,10 +1569,10 @@ def try_to_load_from_cache(
         revision = "main"
     if repo_type is None:
         repo_type = "model"
-    if repo_type not in REPO_TYPES:
-        raise ValueError(f"Invalid repo type: {repo_type}. Accepted repo types are: {str(REPO_TYPES)}")
+    if repo_type not in constants.REPO_TYPES:
+        raise ValueError(f"Invalid repo type: {repo_type}. Accepted repo types are: {str(constants.REPO_TYPES)}")
     if cache_dir is None:
-        cache_dir = HF_HUB_CACHE
+        cache_dir = constants.HF_HUB_CACHE
 
     object_id = repo_id.replace("/", "--")
     repo_cache = os.path.join(cache_dir, f"{repo_type}s--{object_id}")
@@ -1627,7 +1613,7 @@ def get_hf_file_metadata(
     url: str,
     token: Union[bool, str, None] = None,
     proxies: Optional[Dict] = None,
-    timeout: Optional[float] = DEFAULT_REQUEST_TIMEOUT,
+    timeout: Optional[float] = constants.DEFAULT_REQUEST_TIMEOUT,
     library_name: Optional[str] = None,
     library_version: Optional[str] = None,
     user_agent: Union[Dict, str, None] = None,
@@ -1685,15 +1671,17 @@ def get_hf_file_metadata(
 
     # Return
     return HfFileMetadata(
-        commit_hash=r.headers.get(HUGGINGFACE_HEADER_X_REPO_COMMIT),
+        commit_hash=r.headers.get(constants.HUGGINGFACE_HEADER_X_REPO_COMMIT),
         # We favor a custom header indicating the etag of the linked resource, and
         # we fallback to the regular etag header.
-        etag=_normalize_etag(r.headers.get(HUGGINGFACE_HEADER_X_LINKED_ETAG) or r.headers.get("ETag")),
+        etag=_normalize_etag(r.headers.get(constants.HUGGINGFACE_HEADER_X_LINKED_ETAG) or r.headers.get("ETag")),
         # Either from response headers (if redirected) or defaults to request url
         # Do not use directly `url`, as `_request_wrapper` might have followed relative
         # redirects.
         location=r.headers.get("Location") or r.request.url,  # type: ignore
-        size=_int_or_none(r.headers.get(HUGGINGFACE_HEADER_X_LINKED_SIZE) or r.headers.get("Content-Length")),
+        size=_int_or_none(
+            r.headers.get(constants.HUGGINGFACE_HEADER_X_LINKED_SIZE) or r.headers.get("Content-Length")
+        ),
     )
 
@@ -1756,7 +1744,7 @@ def _get_metadata_or_catch_error(
     except EntryNotFoundError as http_error:
         if storage_folder is not None and relative_filename is not None:
             # Cache the non-existence of the file
-            commit_hash = http_error.response.headers.get(HUGGINGFACE_HEADER_X_REPO_COMMIT)
+            commit_hash = http_error.response.headers.get(constants.HUGGINGFACE_HEADER_X_REPO_COMMIT)
             if commit_hash is not None:
                 no_exist_file_path = Path(storage_folder) / ".no_exist" / commit_hash / relative_filename
                 no_exist_file_path.parent.mkdir(parents=True, exist_ok=True)
@@ -1890,14 +1878,14 @@ def _download_to_tmp_and_move(
         # Do nothing if already exists (except if force_download=True)
         return
 
-    if incomplete_path.exists() and (force_download or (HF_HUB_ENABLE_HF_TRANSFER and not proxies)):
+    if incomplete_path.exists() and (force_download or (constants.HF_HUB_ENABLE_HF_TRANSFER and not proxies)):
         # By default, we will try to resume the download if possible.
         # However, if the user has set `force_download=True` or if `hf_transfer` is enabled, then we should
         # not resume the download => delete the incomplete file.
         message = f"Removing incomplete file '{incomplete_path}'"
         if force_download:
             message += " (force_download=True)"
-        elif HF_HUB_ENABLE_HF_TRANSFER and not proxies:
+        elif constants.HF_HUB_ENABLE_HF_TRANSFER and not proxies:
             message += " (hf_transfer=True)"
         logger.info(message)
         incomplete_path.unlink(missing_ok=True)
diff --git a/src/huggingface_hub/hf_file_system.py b/src/huggingface_hub/hf_file_system.py
index 1b0e971787..a831b6c929 100644
--- a/src/huggingface_hub/hf_file_system.py
+++ b/src/huggingface_hub/hf_file_system.py
@@ -15,16 +15,8 @@
 from fsspec.utils import isfilelike
 from requests import Response
 
+from . import constants
 from ._commit_api import CommitOperationCopy, CommitOperationDelete
-from .constants import (
-    DEFAULT_REVISION,
-    ENDPOINT,
-    HF_HUB_DOWNLOAD_TIMEOUT,
-    HF_HUB_ETAG_TIMEOUT,
-    REPO_TYPE_MODEL,
-    REPO_TYPES_MAPPING,
-    REPO_TYPES_URL_PREFIXES,
-)
 from .errors import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
 from .file_download import hf_hub_url, http_get
 from .hf_api import HfApi, LastCommitInfo, RepoFile
@@ -59,10 +51,10 @@ class HfFileSystemResolvedPath:
     _raw_revision: Optional[str] = field(default=None, repr=False)
 
     def unresolve(self) -> str:
-        repo_path = REPO_TYPES_URL_PREFIXES.get(self.repo_type, "") + self.repo_id
+        repo_path = constants.REPO_TYPES_URL_PREFIXES.get(self.repo_type, "") + self.repo_id
         if self._raw_revision:
             return f"{repo_path}@{self._raw_revision}/{self.path_in_repo}".rstrip("/")
-        elif self.revision != DEFAULT_REVISION:
+        elif self.revision != constants.DEFAULT_REVISION:
             return f"{repo_path}@{safe_revision(self.revision)}/{self.path_in_repo}".rstrip("/")
         else:
             return f"{repo_path}/{self.path_in_repo}".rstrip("/")
@@ -111,7 +103,7 @@ def __init__(
         **storage_options,
     ):
         super().__init__(*args, **storage_options)
-        self.endpoint = endpoint or ENDPOINT
+        self.endpoint = endpoint or constants.ENDPOINT
         self.token = token
         self._api = HfApi(endpoint=endpoint, token=token)
         # Maps (repo_type, repo_id, revision) to a 2-tuple with:
@@ -126,7 +118,9 @@ def _repo_and_revision_exist(
     ) -> Tuple[bool, Optional[Exception]]:
         if (repo_type, repo_id, revision) not in self._repo_and_revision_exists_cache:
             try:
-                self._api.repo_info(repo_id, revision=revision, repo_type=repo_type, timeout=HF_HUB_ETAG_TIMEOUT)
+                self._api.repo_info(
+                    repo_id, revision=revision, repo_type=repo_type, timeout=constants.HF_HUB_ETAG_TIMEOUT
+                )
             except (RepositoryNotFoundError, HFValidationError) as e:
                 self._repo_and_revision_exists_cache[(repo_type, repo_id, revision)] = False, e
                 self._repo_and_revision_exists_cache[(repo_type, repo_id, None)] = False, e
@@ -156,14 +150,14 @@ def _align_revision_in_path_with_revision(
         if not path:
             # can't list repositories at root
             raise NotImplementedError("Access to repositories lists is not implemented.")
-        elif path.split("/")[0] + "/" in REPO_TYPES_URL_PREFIXES.values():
+        elif path.split("/")[0] + "/" in constants.REPO_TYPES_URL_PREFIXES.values():
             if "/" not in path:
                 # can't list repositories at the repository type level
                 raise NotImplementedError("Access to repositories lists is not implemented.")
             repo_type, path = path.split("/", 1)
-            repo_type = REPO_TYPES_MAPPING[repo_type]
+            repo_type = constants.REPO_TYPES_MAPPING[repo_type]
         else:
-            repo_type = REPO_TYPE_MODEL
+            repo_type = constants.REPO_TYPE_MODEL
         if path.count("/") > 0:
             if "@" in path:
                 repo_id, revision_in_path = path.split("@", 1)
@@ -211,7 +205,7 @@ def _align_revision_in_path_with_revision(
             if not repo_and_revision_exist:
                 raise NotImplementedError("Access to repositories lists is not implemented.")
 
-        revision = revision if revision is not None else DEFAULT_REVISION
+        revision = revision if revision is not None else constants.DEFAULT_REVISION
         return HfFileSystemResolvedPath(repo_type, repo_id, revision, path_in_repo, _raw_revision=revision_in_path)
 
     def invalidate_cache(self, path: Optional[str] = None) -> None:
@@ -721,7 +715,7 @@ def _fetch_range(self, start: int, end: int) -> bytes:
             url,
             headers=headers,
             retry_on_status_codes=(502, 503, 504),
-            timeout=HF_HUB_DOWNLOAD_TIMEOUT,
+            timeout=constants.HF_HUB_DOWNLOAD_TIMEOUT,
         )
         hf_raise_for_status(r)
         return r.content
@@ -821,7 +815,7 @@ def read(self, length: int = -1):
             headers=self.fs._api._build_hf_headers(),
             retry_on_status_codes=(502, 503, 504),
             stream=True,
-            timeout=HF_HUB_DOWNLOAD_TIMEOUT,
+            timeout=constants.HF_HUB_DOWNLOAD_TIMEOUT,
         )
         hf_raise_for_status(self.response)
         try:
@@ -843,7 +837,7 @@ def read(self, length: int = -1):
             headers={"Range": "bytes=%d-" % self.loc, **self.fs._api._build_hf_headers()},
             retry_on_status_codes=(502, 503, 504),
             stream=True,
-            timeout=HF_HUB_DOWNLOAD_TIMEOUT,
+            timeout=constants.HF_HUB_DOWNLOAD_TIMEOUT,
         )
         hf_raise_for_status(self.response)
         try:
diff --git a/src/huggingface_hub/hub_mixin.py b/src/huggingface_hub/hub_mixin.py
index 88cec2b446..ceb3ee235a 100644
--- a/src/huggingface_hub/hub_mixin.py
+++ b/src/huggingface_hub/hub_mixin.py
@@ -17,7 +17,7 @@
     Union,
 )
 
-from .constants import CONFIG_NAME, PYTORCH_WEIGHTS_NAME, SAFETENSORS_SINGLE_FILE
+from . import constants
 from .errors import EntryNotFoundError, HfHubHTTPError
 from .file_download import hf_hub_download
 from .hf_api import HfApi
@@ -416,7 +416,7 @@ def save_pretrained(
         # Remove config.json if already exists. After `_save_pretrained` we don't want to overwrite config.json
         # as it might have been saved by the custom `_save_pretrained` already. However we do want to overwrite
         # an existing config.json if it was not saved by `_save_pretrained`.
-        config_path = save_directory / CONFIG_NAME
+        config_path = save_directory / constants.CONFIG_NAME
         config_path.unlink(missing_ok=True)
 
         # save model weights/files (framework-specific)
@@ -504,15 +504,15 @@ def from_pretrained(
         model_id = str(pretrained_model_name_or_path)
         config_file: Optional[str] = None
         if os.path.isdir(model_id):
-            if CONFIG_NAME in os.listdir(model_id):
-                config_file = os.path.join(model_id, CONFIG_NAME)
+            if constants.CONFIG_NAME in os.listdir(model_id):
+                config_file = os.path.join(model_id, constants.CONFIG_NAME)
             else:
-                logger.warning(f"{CONFIG_NAME} not found in {Path(model_id).resolve()}")
+                logger.warning(f"{constants.CONFIG_NAME} not found in {Path(model_id).resolve()}")
         else:
             try:
                 config_file = hf_hub_download(
                     repo_id=model_id,
-                    filename=CONFIG_NAME,
+                    filename=constants.CONFIG_NAME,
                     revision=revision,
                     cache_dir=cache_dir,
                     force_download=force_download,
@@ -522,7 +522,7 @@ def from_pretrained(
                     local_files_only=local_files_only,
                 )
             except HfHubHTTPError as e:
-                logger.info(f"{CONFIG_NAME} not found on the HuggingFace Hub: {str(e)}")
+                logger.info(f"{constants.CONFIG_NAME} not found on the HuggingFace Hub: {str(e)}")
 
         # Read config
         config = None
@@ -766,7 +766,7 @@ def __init_subclass__(cls, *args, tags: Optional[List[str]] = None, **kwargs) -> None:
     def _save_pretrained(self, save_directory: Path) -> None:
         """Save weights from a Pytorch model to a local directory."""
         model_to_save = self.module if hasattr(self, "module") else self  # type: ignore
-        save_model_as_safetensor(model_to_save, str(save_directory / SAFETENSORS_SINGLE_FILE))
+        save_model_as_safetensor(model_to_save, str(save_directory / constants.SAFETENSORS_SINGLE_FILE))
 
     @classmethod
     def _from_pretrained(
@@ -788,13 +788,13 @@ def _from_pretrained(
         model = cls(**model_kwargs)
         if os.path.isdir(model_id):
             print("Loading weights from local directory")
-            model_file = os.path.join(model_id, SAFETENSORS_SINGLE_FILE)
+            model_file = os.path.join(model_id, constants.SAFETENSORS_SINGLE_FILE)
             return cls._load_as_safetensor(model, model_file, map_location, strict)
         else:
             try:
                 model_file = hf_hub_download(
                     repo_id=model_id,
-                    filename=SAFETENSORS_SINGLE_FILE,
+                    filename=constants.SAFETENSORS_SINGLE_FILE,
                     revision=revision,
                     cache_dir=cache_dir,
                     force_download=force_download,
@@ -807,7 +807,7 @@ def _from_pretrained(
             except EntryNotFoundError:
                 model_file = hf_hub_download(
                     repo_id=model_id,
-                    filename=PYTORCH_WEIGHTS_NAME,
+                    filename=constants.PYTORCH_WEIGHTS_NAME,
                     revision=revision,
                     cache_dir=cache_dir,
                     force_download=force_download,
diff --git a/src/huggingface_hub/inference_api.py b/src/huggingface_hub/inference_api.py
index c889a6d872..f895fcc61c 100644
--- a/src/huggingface_hub/inference_api.py
+++ b/src/huggingface_hub/inference_api.py
@@ -1,7 +1,7 @@
 import io
 from typing import Any, Dict, List, Optional, Union
 
-from .constants import INFERENCE_ENDPOINT
+from . import constants
 from .hf_api import HfApi
 from .utils import build_hf_headers, get_session, is_pillow_available, logging, validate_hf_hub_args
 from .utils._deprecation import _deprecate_method
@@ -149,7 +149,7 @@ def __init__(
             assert model_info.pipeline_tag is not None, "Pipeline tag cannot be None"
             self.task = model_info.pipeline_tag
 
-        self.api_url = f"{INFERENCE_ENDPOINT}/pipeline/{self.task}/{repo_id}"
+        self.api_url = f"{constants.INFERENCE_ENDPOINT}/pipeline/{self.task}/{repo_id}"
 
     def __repr__(self):
         # Do not add headers to repr to avoid leaking token.
diff --git a/src/huggingface_hub/keras_mixin.py b/src/huggingface_hub/keras_mixin.py
index e1c9e09fac..f5d9edf37a 100644
--- a/src/huggingface_hub/keras_mixin.py
+++ b/src/huggingface_hub/keras_mixin.py
@@ -16,7 +16,7 @@
     yaml_dump,
 )
 
-from .constants import CONFIG_NAME
+from . import constants
 from .hf_api import HfApi
 from .utils import SoftTemporaryDirectory, logging, validate_hf_hub_args
 from .utils._typing import CallableT
@@ -202,7 +202,7 @@ def save_pretrained_keras(
         if not isinstance(config, dict):
             raise RuntimeError(f"Provided config to save_pretrained_keras should be a dict. Got: '{type(config)}'")
 
-        with (save_directory / CONFIG_NAME).open("w") as f:
+        with (save_directory / constants.CONFIG_NAME).open("w") as f:
             json.dump(config, f)
 
     metadata = {}
diff --git a/src/huggingface_hub/lfs.py b/src/huggingface_hub/lfs.py
index bff32b3307..c96ed3d7be 100644
--- a/src/huggingface_hub/lfs.py
+++ b/src/huggingface_hub/lfs.py
@@ -27,7 +27,7 @@
 from typing import TYPE_CHECKING, BinaryIO, Dict, Iterable, List, Optional, Tuple, TypedDict
 from urllib.parse import unquote
 
-from huggingface_hub.constants import ENDPOINT, HF_HUB_ENABLE_HF_TRANSFER, REPO_TYPES_URL_PREFIXES
+from huggingface_hub import constants
 
 from .utils import (
     build_hf_headers,
@@ -139,10 +139,10 @@ def post_lfs_batch_info(
         [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError)
             If the server returned an error.
     """
-    endpoint = endpoint if endpoint is not None else ENDPOINT
+    endpoint = endpoint if endpoint is not None else constants.ENDPOINT
     url_prefix = ""
-    if repo_type in REPO_TYPES_URL_PREFIXES:
-        url_prefix = REPO_TYPES_URL_PREFIXES[repo_type]
+    if repo_type in constants.REPO_TYPES_URL_PREFIXES:
+        url_prefix = constants.REPO_TYPES_URL_PREFIXES[repo_type]
     batch_url = f"{endpoint}/{url_prefix}{repo_id}.git/info/lfs/objects/batch"
     payload: Dict = {
         "operation": "upload",
@@ -328,9 +328,9 @@ def _upload_multi_part(operation: "CommitOperationAdd", header: Dict, chunk_size
     sorted_parts_urls = _get_sorted_parts_urls(header=header, upload_info=operation.upload_info, chunk_size=chunk_size)
 
     # 2. Upload parts (either with hf_transfer or in pure Python)
-    use_hf_transfer = HF_HUB_ENABLE_HF_TRANSFER
+    use_hf_transfer = constants.HF_HUB_ENABLE_HF_TRANSFER
     if (
-        HF_HUB_ENABLE_HF_TRANSFER
+        constants.HF_HUB_ENABLE_HF_TRANSFER
         and not isinstance(operation.path_or_fileobj, str)
         and not isinstance(operation.path_or_fileobj, Path)
     ):
diff --git a/src/huggingface_hub/repocard.py b/src/huggingface_hub/repocard.py
index 0a767e6303..f6ae591f40 100644
--- a/src/huggingface_hub/repocard.py
+++ b/src/huggingface_hub/repocard.py
@@ -19,7 +19,7 @@
 )
 from huggingface_hub.utils import get_session, is_jinja_available, yaml_dump
 
-from .constants import REPOCARD_NAME
+from . import constants
 from .errors import EntryNotFoundError
 from .utils import SoftTemporaryDirectory, logging, validate_hf_hub_args
 
@@ -176,7 +176,7 @@ def load(
             card_path = Path(
                 hf_hub_download(
                     repo_id_or_path,
-                    REPOCARD_NAME,
+                    constants.REPOCARD_NAME,
                     repo_type=repo_type or cls.repo_type,
                     token=token,
                 )
@@ -274,11 +274,11 @@ def push_to_hub(
         self.validate(repo_type=repo_type)
 
         with SoftTemporaryDirectory() as tmpdir:
-            tmp_path = Path(tmpdir) / REPOCARD_NAME
+            tmp_path = Path(tmpdir) / constants.REPOCARD_NAME
             tmp_path.write_text(str(self))
             url = upload_file(
                 path_or_fileobj=str(tmp_path),
-                path_in_repo=REPOCARD_NAME,
+                path_in_repo=constants.REPOCARD_NAME,
                 repo_id=repo_id,
                 token=token,
                 repo_type=repo_type,
diff --git a/src/huggingface_hub/repository.py b/src/huggingface_hub/repository.py
index 09b4bbc777..af1ab72fb4 100644
--- a/src/huggingface_hub/repository.py
+++ b/src/huggingface_hub/repository.py
@@ -9,7 +9,7 @@
 from typing import Callable, Dict, Iterator, List, Optional, Tuple, TypedDict, Union
 from urllib.parse import urlparse
 
-from huggingface_hub.constants import REPO_TYPES_URL_PREFIXES, REPOCARD_NAME
+from huggingface_hub import constants
 from huggingface_hub.repocard import metadata_load, metadata_save
 
 from .hf_api import HfApi, repo_type_and_id_from_hf_id
@@ -659,8 +659,8 @@ def clone_from(self, repo_url: str, token: Union[bool, str, None] = None):
 
             repo_url = hub_url + "/"
 
-            if self._repo_type in REPO_TYPES_URL_PREFIXES:
-                repo_url += REPO_TYPES_URL_PREFIXES[self._repo_type]
+            if self._repo_type in constants.REPO_TYPES_URL_PREFIXES:
+                repo_url += constants.REPO_TYPES_URL_PREFIXES[self._repo_type]
 
             if token is not None:
                 # Add token in git url when provided
@@ -1434,13 +1434,13 @@ def commit(
             os.chdir(current_working_directory)
 
     def repocard_metadata_load(self) -> Optional[Dict]:
-        filepath = os.path.join(self.local_dir, REPOCARD_NAME)
+        filepath = os.path.join(self.local_dir, constants.REPOCARD_NAME)
         if os.path.isfile(filepath):
             return metadata_load(filepath)
         return None
 
     def repocard_metadata_save(self, data: Dict) -> None:
-        return metadata_save(os.path.join(self.local_dir, REPOCARD_NAME), data)
+        return metadata_save(os.path.join(self.local_dir, constants.REPOCARD_NAME), data)
 
     @property
     def commands_failed(self):
diff --git a/tests/test_file_download.py b/tests/test_file_download.py
index e2aafde3c9..73b5d3bc72 100644
--- a/tests/test_file_download.py
+++ b/tests/test_file_download.py
@@ -28,14 +28,8 @@
 from requests import Response
 
 import huggingface_hub.file_download
-from huggingface_hub import HfApi, RepoUrl
+from huggingface_hub import HfApi, RepoUrl, constants
 from huggingface_hub._local_folder import write_download_metadata
-from huggingface_hub.constants import (
-    CONFIG_NAME,
-    HUGGINGFACE_HEADER_X_LINKED_ETAG,
-    PYTORCH_WEIGHTS_NAME,
-    REPO_TYPE_DATASET,
-)
 from huggingface_hub.errors import (
     EntryNotFoundError,
     GatedRepoError,
@@ -204,10 +198,10 @@ def test_bogus_url(self):
 
     def test_no_connection(self):
         invalid_url = hf_hub_url(
             DUMMY_MODEL_ID,
-            filename=CONFIG_NAME,
+            filename=constants.CONFIG_NAME,
             revision=DUMMY_MODEL_ID_REVISION_INVALID,
         )
-        valid_url = hf_hub_url(DUMMY_MODEL_ID, filename=CONFIG_NAME, revision=REVISION_ID_DEFAULT)
+        valid_url = hf_hub_url(DUMMY_MODEL_ID, filename=constants.CONFIG_NAME, revision=REVISION_ID_DEFAULT)
         self.assertIsNotNone(cached_download(valid_url, force_download=True, legacy_cache_layout=True))
         for offline_mode in OfflineSimulationMode:
             with offline(mode=offline_mode):
@@ -233,7 +227,7 @@ def test_file_not_found_locally_and_network_disabled(self):
             # Download a first time to get the refs ok
             filepath = hf_hub_download(
                 DUMMY_MODEL_ID,
-                filename=CONFIG_NAME,
+                filename=constants.CONFIG_NAME,
                 cache_dir=tmpdir,
                 local_files_only=False,
             )
@@ -245,7 +239,7 @@ def test_file_not_found_locally_and_network_disabled(self):
             with pytest.raises(LocalEntryNotFoundError):
                 hf_hub_download(
                     DUMMY_MODEL_ID,
-                    filename=CONFIG_NAME,
+                    filename=constants.CONFIG_NAME,
                     cache_dir=tmpdir,
                     local_files_only=True,
                 )
@@ -254,7 +248,7 @@ def test_file_not_found_locally_and_network_disabled_legacy(self):
     @expect_deprecation("url_to_filename")
     def test_file_not_found_locally_and_network_disabled_legacy(self):
         # Valid file but missing locally and network is disabled.
-        url = hf_hub_url(DUMMY_MODEL_ID, filename=CONFIG_NAME)
+        url = hf_hub_url(DUMMY_MODEL_ID, filename=constants.CONFIG_NAME)
         with SoftTemporaryDirectory() as tmpdir:
             # Get without network must fail
             with pytest.raises(LocalEntryNotFoundError):
@@ -268,14 +262,14 @@ def test_private_repo_and_file_cached_locally(self):
         api = HfApi(endpoint=ENDPOINT_STAGING, token=TOKEN)
         repo_id = api.create_repo(repo_id=repo_name(), private=True).repo_id
-        api.upload_file(path_or_fileobj=b"content", path_in_repo=CONFIG_NAME, repo_id=repo_id)
+        api.upload_file(path_or_fileobj=b"content", path_in_repo=constants.CONFIG_NAME, repo_id=repo_id)
 
         with SoftTemporaryDirectory() as tmpdir:
             # Download a first time with token => file is cached
-            filepath_1 = hf_hub_download(DUMMY_MODEL_ID, filename=CONFIG_NAME, cache_dir=tmpdir, token=TOKEN)
+            filepath_1 = hf_hub_download(DUMMY_MODEL_ID, filename=constants.CONFIG_NAME, cache_dir=tmpdir, token=TOKEN)
 
             # Download without token => return cached file
-            filepath_2 = hf_hub_download(DUMMY_MODEL_ID, filename=CONFIG_NAME, cache_dir=tmpdir)
+            filepath_2 = hf_hub_download(DUMMY_MODEL_ID, filename=constants.CONFIG_NAME, cache_dir=tmpdir)
 
             self.assertEqual(filepath_1, filepath_2)
@@ -287,13 +281,13 @@ def test_file_cached_and_read_only_access(self):
         # Valid file but missing locally and network is disabled.
         with SoftTemporaryDirectory() as tmpdir:
             # Download a first time to get the refs ok
-            hf_hub_download(DUMMY_MODEL_ID, filename=CONFIG_NAME, cache_dir=tmpdir)
+            hf_hub_download(DUMMY_MODEL_ID, filename=constants.CONFIG_NAME, cache_dir=tmpdir)
 
             # Set read-only permission recursively
             _recursive_chmod(tmpdir, 0o555)
 
             # Get without write-access must succeed
-            hf_hub_download(DUMMY_MODEL_ID, filename=CONFIG_NAME, cache_dir=tmpdir)
+            hf_hub_download(DUMMY_MODEL_ID, filename=constants.CONFIG_NAME, cache_dir=tmpdir)
 
             # Set permission back for cleanup
             _recursive_chmod(tmpdir, 0o777)
@@ -303,7 +297,7 @@ def test_revision_not_found(self):
         # Valid file but missing revision
         url = hf_hub_url(
             DUMMY_MODEL_ID,
-            filename=CONFIG_NAME,
+            filename=constants.CONFIG_NAME,
             revision=DUMMY_MODEL_ID_REVISION_INVALID,
         )
         with self.assertRaisesRegex(
@@ -326,7 +320,7 @@ def test_repo_not_found(self):
     @expect_deprecation("url_to_filename")
     @expect_deprecation("filename_to_url")
     def test_standard_object(self):
-        url = hf_hub_url(DUMMY_MODEL_ID, filename=CONFIG_NAME, revision=REVISION_ID_DEFAULT)
+        url = hf_hub_url(DUMMY_MODEL_ID, filename=constants.CONFIG_NAME, revision=REVISION_ID_DEFAULT)
         filepath = cached_download(url, force_download=True, legacy_cache_layout=True)
         metadata = filename_to_url(filepath, legacy_cache_layout=True)
         self.assertEqual(metadata, (url, f'"{DUMMY_MODEL_ID_PINNED_SHA1}"'))
@@ -338,7 +332,7 @@ def test_standard_object_rev(self):
         # Same object, but different revision
         url = hf_hub_url(
             DUMMY_MODEL_ID,
-            filename=CONFIG_NAME,
+            filename=constants.CONFIG_NAME,
             revision=DUMMY_MODEL_ID_REVISION_ONE_SPECIFIC_COMMIT,
         )
         filepath = cached_download(url, force_download=True, legacy_cache_layout=True)
@@ -350,7 +344,7 @@ def test_lfs_object(self):
     @expect_deprecation("url_to_filename")
     @expect_deprecation("filename_to_url")
     def test_lfs_object(self):
-        url = hf_hub_url(DUMMY_MODEL_ID, filename=PYTORCH_WEIGHTS_NAME, revision=REVISION_ID_DEFAULT)
+        url = hf_hub_url(DUMMY_MODEL_ID, filename=constants.PYTORCH_WEIGHTS_NAME, revision=REVISION_ID_DEFAULT)
         filepath = cached_download(url, force_download=True, legacy_cache_layout=True)
         metadata = filename_to_url(filepath, legacy_cache_layout=True)
         self.assertEqual(metadata, (url, f'"{DUMMY_MODEL_ID_PINNED_SHA256}"'))
@@ -362,7 +356,7 @@ def test_dataset_standard_object_rev(self):
         url = hf_hub_url(
             DATASET_ID,
             filename=DATASET_SAMPLE_PY_FILE,
-            repo_type=REPO_TYPE_DATASET,
+            repo_type=constants.REPO_TYPE_DATASET,
             revision=DATASET_REVISION_ID_ONE_SPECIFIC_COMMIT,
         )
         # now let's download
@@ -377,7 +371,7 @@ def test_dataset_lfs_object(self):
         url = hf_hub_url(
             DATASET_ID,
             filename="dev-v1.1.json",
-            repo_type=REPO_TYPE_DATASET,
+            repo_type=constants.REPO_TYPE_DATASET,
             revision=DATASET_REVISION_ID_ONE_SPECIFIC_COMMIT,
         )
         filepath = cached_download(url, force_download=True, legacy_cache_layout=True)
@@ -442,7 +436,7 @@ def test_hf_hub_download_with_empty_subfolder(self):
         filepath = Path(
             hf_hub_download(
                 DUMMY_MODEL_ID,
-                filename=CONFIG_NAME,
+                filename=constants.CONFIG_NAME,
                 subfolder="",  # Subfolder should be processed as `None`
             )
         )
 
         # Check file exists and is not in a subfolder in cache
         # e.g: "(...)/snapshots/<commit-id>/config.json"
         self.assertTrue(filepath.is_file())
-        self.assertEqual(filepath.name, CONFIG_NAME)
+        self.assertEqual(filepath.name, constants.CONFIG_NAME)
         self.assertEqual(Path(filepath).parent.parent.name, "snapshots")
 
     def test_hf_hub_download_offline_no_refs(self):
@@ -466,7 +460,7 @@ def test_hf_hub_download_offline_no_refs(self):
         with self.assertRaises(LocalEntryNotFoundError):
             hf_hub_download(
                 DUMMY_MODEL_ID,
-                filename=CONFIG_NAME,
+                filename=constants.CONFIG_NAME,
                 local_files_only=True,
                 cache_dir=cache_dir,
             )
@@ -489,7 +483,7 @@ def _check_user_agent(headers: dict):
             # First download
             hf_hub_download(
                 DUMMY_MODEL_ID,
-                filename=CONFIG_NAME,
+                filename=constants.CONFIG_NAME,
                 cache_dir=cache_dir,
                 library_name="test",
                 library_version="1.0.0",
@@ -504,7 +498,7 @@ def _check_user_agent(headers: dict):
             # Second download: no GET call
             hf_hub_download(
                 DUMMY_MODEL_ID,
-                filename=CONFIG_NAME,
+                filename=constants.CONFIG_NAME,
                 cache_dir=cache_dir,
                 library_name="test",
                 library_version="1.0.0",
@@ -524,7 +518,7 @@ def test_hf_hub_url_with_empty_subfolder(self):
         """
         url = hf_hub_url(
             DUMMY_MODEL_ID,
-            filename=CONFIG_NAME,
+            filename=constants.CONFIG_NAME,
             subfolder="",  # Subfolder should be processed as `None`
         )
         self.assertTrue(
             )
         )
 
-    @patch("huggingface_hub.file_download.ENDPOINT", "https://huggingface.co")
+    @patch("huggingface_hub.file_download.constants.ENDPOINT", "https://huggingface.co")
     @patch(
         "huggingface_hub.file_download.HUGGINGFACE_CO_URL_TEMPLATE",
         "https://huggingface.co/{repo_id}/resolve/{revision}/{filename}",
     )
@@ -543,7 +537,7 @@ def test_hf_hub_url_with_endpoint(self):
         self.assertEqual(
             hf_hub_url(
                 DUMMY_MODEL_ID,
-                filename=CONFIG_NAME,
+                filename=constants.CONFIG_NAME,
                 endpoint="https://hf-ci.co",
             ),
             "https://hf-ci.co/julien-c/dummy-unknown/resolve/main/config.json",
         )
@@ -556,7 +550,7 @@ def test_hf_hub_download_legacy(self):
         filepath = hf_hub_download(
             DUMMY_MODEL_ID,
-            filename=CONFIG_NAME,
+            filename=constants.CONFIG_NAME,
             revision=REVISION_ID_DEFAULT,
             force_download=True,
             legacy_cache_layout=True,
         )
@@ -566,12 +560,12 @@ def test_try_to_load_from_cache_exist(self):
         # Make sure the file is cached
-        filepath = hf_hub_download(DUMMY_MODEL_ID, filename=CONFIG_NAME)
+        filepath = hf_hub_download(DUMMY_MODEL_ID, filename=constants.CONFIG_NAME)
 
-        new_file_path = try_to_load_from_cache(DUMMY_MODEL_ID, filename=CONFIG_NAME)
+        new_file_path = try_to_load_from_cache(DUMMY_MODEL_ID, filename=constants.CONFIG_NAME)
         self.assertEqual(filepath, new_file_path)
 
-        new_file_path = try_to_load_from_cache(DUMMY_MODEL_ID, filename=CONFIG_NAME, revision="main")
+        new_file_path = try_to_load_from_cache(DUMMY_MODEL_ID, filename=constants.CONFIG_NAME, revision="main")
         self.assertEqual(filepath, new_file_path)
 
         # If file is not cached, returns None
@@ -580,25 +574,27 @@ def test_try_to_load_from_cache_exist(self):
         self.assertIsNone(
             try_to_load_from_cache(
                 DUMMY_MODEL_ID,
-                filename=CONFIG_NAME,
+                filename=constants.CONFIG_NAME,
                 revision="aaa",
             )
         )
 
         # Same for uncached models
-        self.assertIsNone(try_to_load_from_cache("bert-base", filename=CONFIG_NAME))
+        self.assertIsNone(try_to_load_from_cache("bert-base", filename=constants.CONFIG_NAME))
 
     def test_try_to_load_from_cache_specific_pr_revision_exists(self):
         # Make sure the file is cached
-        file_path = hf_hub_download(DUMMY_MODEL_ID, filename=CONFIG_NAME, revision="refs/pr/1")
+        file_path = hf_hub_download(DUMMY_MODEL_ID, filename=constants.CONFIG_NAME, revision="refs/pr/1")
 
-        new_file_path = try_to_load_from_cache(DUMMY_MODEL_ID, filename=CONFIG_NAME, revision="refs/pr/1")
+        new_file_path = try_to_load_from_cache(DUMMY_MODEL_ID, filename=constants.CONFIG_NAME, revision="refs/pr/1")
         self.assertEqual(file_path, new_file_path)
 
         # If file is not cached, returns None
         self.assertIsNone(try_to_load_from_cache(DUMMY_MODEL_ID, filename="conf.json", revision="refs/pr/1"))
 
         # If revision does not exist, returns None
-        self.assertIsNone(try_to_load_from_cache(DUMMY_MODEL_ID, filename=CONFIG_NAME, revision="does-not-exist"))
+        self.assertIsNone(
+            try_to_load_from_cache(DUMMY_MODEL_ID, filename=constants.CONFIG_NAME, revision="does-not-exist")
+        )
 
     def test_try_to_load_from_cache_no_exist(self):
         # Make sure the file is cached
@@ -623,7 +619,7 @@ def test_try_to_load_from_cache_specific_commit_id_exist(self):
             commit_id = HfApi().model_info(DUMMY_MODEL_ID).sha
             filepath = hf_hub_download(
                 DUMMY_MODEL_ID,
-                filename=CONFIG_NAME,
+                filename=constants.CONFIG_NAME,
                 revision=commit_id,
                 cache_dir=cache_dir,
             )
@@ -631,7 +627,7 @@ def test_try_to_load_from_cache_specific_commit_id_exist(self):
             # Must be able to retrieve it "offline"
             attempt = try_to_load_from_cache(
                 DUMMY_MODEL_ID,
-                filename=CONFIG_NAME,
+                filename=constants.CONFIG_NAME,
                 revision=commit_id,
                 cache_dir=cache_dir,
             )
@@ -665,7 +661,7 @@ def test_get_hf_file_metadata_basic(self) -> None:
         """Test getting metadata from a file on the Hub."""
         url = hf_hub_url(
             DUMMY_MODEL_ID,
-            filename=CONFIG_NAME,
+            filename=constants.CONFIG_NAME,
             revision=DUMMY_MODEL_ID_REVISION_ONE_SPECIFIC_COMMIT,
         )
         metadata = get_hf_file_metadata(url)
@@ -680,7 +676,7 @@ def test_get_hf_file_metadata_from_a_renamed_repo(self) -> None:
         """Test getting metadata from a file in a renamed repo on the Hub."""
         url = hf_hub_url(
             DUMMY_RENAMED_OLD_MODEL_ID,
-            filename=CONFIG_NAME,
+            filename=constants.CONFIG_NAME,
             subfolder="",  # Subfolder should be processed as `None`
         )
         metadata = get_hf_file_metadata(url)
@@ -721,7 +717,7 @@ def _mocked_hf_file_metadata(*args, **kwargs):
         with patch("huggingface_hub.file_download.get_hf_file_metadata", _mocked_hf_file_metadata):
             with self.assertRaises(EnvironmentError):
-                hf_hub_download(DUMMY_MODEL_ID, filename=CONFIG_NAME, cache_dir=cache_dir)
+                hf_hub_download(DUMMY_MODEL_ID, filename=constants.CONFIG_NAME, cache_dir=cache_dir)
 
     def test_file_consistency_check_fails_LFS_file(self):
         """Regression test for #1396 (LFS file).
@@ -767,7 +763,7 @@ def test_cached_download_from_github(self):
 
     def test_keep_lock_file(self):
         """Lock files should not be deleted on Linux."""
         with SoftTemporaryDirectory() as tmpdir:
-            hf_hub_download(DUMMY_MODEL_ID, filename=CONFIG_NAME, cache_dir=tmpdir)
+            hf_hub_download(DUMMY_MODEL_ID, filename=constants.CONFIG_NAME, cache_dir=tmpdir)
             lock_file_exist = False
             locks_dir = os.path.join(tmpdir, ".locks")
             for subdir, dirs, files in os.walk(locks_dir):
@@ -1261,13 +1257,15 @@ def test_resolve_endpoint_on_lfs_file(self):
     @staticmethod
     def _get_etag_and_normalize(response: Response) -> str:
         response.raise_for_status()
-        return _normalize_etag(response.headers.get(HUGGINGFACE_HEADER_X_LINKED_ETAG) or response.headers.get("ETag"))
+        return _normalize_etag(
+            response.headers.get(constants.HUGGINGFACE_HEADER_X_LINKED_ETAG) or response.headers.get("ETag")
+        )
 
 
 @with_production_testing
 class TestEtagTimeoutConfig(unittest.TestCase):
-    @patch("huggingface_hub.file_download.DEFAULT_ETAG_TIMEOUT", 10)
-    @patch("huggingface_hub.file_download.HF_HUB_ETAG_TIMEOUT", 10)
+    @patch("huggingface_hub.file_download.constants.DEFAULT_ETAG_TIMEOUT", 10)
+    @patch("huggingface_hub.file_download.constants.HF_HUB_ETAG_TIMEOUT", 10)
     def test_etag_timeout_default_value(self):
         with SoftTemporaryDirectory() as cache_dir:
             with patch.object(
@@ -1275,13 +1273,13 @@ def test_etag_timeout_default_value(self):
                 "get_hf_file_metadata",
                 wraps=huggingface_hub.file_download.get_hf_file_metadata,
             ) as mock_etag_call:
-                hf_hub_download(DUMMY_MODEL_ID, filename=CONFIG_NAME, cache_dir=cache_dir)
+                hf_hub_download(DUMMY_MODEL_ID, filename=constants.CONFIG_NAME, cache_dir=cache_dir)
                 kwargs = mock_etag_call.call_args.kwargs
                 self.assertIn("timeout", kwargs)
                 self.assertEqual(kwargs["timeout"], 10)
 
-    @patch("huggingface_hub.file_download.DEFAULT_ETAG_TIMEOUT", 10)
-    @patch("huggingface_hub.file_download.HF_HUB_ETAG_TIMEOUT", 10)
+    @patch("huggingface_hub.file_download.constants.DEFAULT_ETAG_TIMEOUT", 10)
+    @patch("huggingface_hub.file_download.constants.HF_HUB_ETAG_TIMEOUT", 10)
     def test_etag_timeout_parameter_value(self):
         with SoftTemporaryDirectory() as cache_dir:
             with patch.object(
@@ -1289,13 +1287,13 @@ def test_etag_timeout_parameter_value(self):
                 "get_hf_file_metadata",
                 wraps=huggingface_hub.file_download.get_hf_file_metadata,
             ) as mock_etag_call:
-                hf_hub_download(DUMMY_MODEL_ID, filename=CONFIG_NAME, cache_dir=cache_dir, etag_timeout=12)
+                hf_hub_download(DUMMY_MODEL_ID, filename=constants.CONFIG_NAME, cache_dir=cache_dir, etag_timeout=12)
                 kwargs = mock_etag_call.call_args.kwargs
                 self.assertIn("timeout", kwargs)
                 self.assertEqual(kwargs["timeout"], 12)  # passed as parameter, takes priority
 
-    @patch("huggingface_hub.file_download.DEFAULT_ETAG_TIMEOUT", 10)
-    @patch("huggingface_hub.file_download.HF_HUB_ETAG_TIMEOUT", 15)  # takes priority
+    @patch("huggingface_hub.file_download.constants.DEFAULT_ETAG_TIMEOUT", 10)
+    @patch("huggingface_hub.file_download.constants.HF_HUB_ETAG_TIMEOUT", 15)  # takes priority
     def test_etag_timeout_set_as_env_variable(self):
         with SoftTemporaryDirectory() as cache_dir:
             with patch.object(
@@ -1303,13 +1301,13 @@ def test_etag_timeout_set_as_env_variable(self):
                 "get_hf_file_metadata",
                 wraps=huggingface_hub.file_download.get_hf_file_metadata,
             ) as mock_etag_call:
-                hf_hub_download(DUMMY_MODEL_ID, filename=CONFIG_NAME, cache_dir=cache_dir)
+                hf_hub_download(DUMMY_MODEL_ID, filename=constants.CONFIG_NAME, cache_dir=cache_dir)
                 kwargs = mock_etag_call.call_args.kwargs
                 self.assertIn("timeout", kwargs)
                 self.assertEqual(kwargs["timeout"], 15)
 
-    @patch("huggingface_hub.file_download.DEFAULT_ETAG_TIMEOUT", 10)
-    @patch("huggingface_hub.file_download.HF_HUB_ETAG_TIMEOUT", 15)  # takes priority
+    @patch("huggingface_hub.file_download.constants.DEFAULT_ETAG_TIMEOUT", 10)
+    @patch("huggingface_hub.file_download.constants.HF_HUB_ETAG_TIMEOUT", 15)  # takes priority
     def test_etag_timeout_set_as_env_variable_parameter_ignored(self):
         with SoftTemporaryDirectory() as cache_dir:
             with patch.object(
@@ -1317,10 +1315,10 @@ def test_etag_timeout_set_as_env_variable_parameter_ignored(self):
                 "get_hf_file_metadata",
                 wraps=huggingface_hub.file_download.get_hf_file_metadata,
             ) as mock_etag_call:
-                hf_hub_download(DUMMY_MODEL_ID, filename=CONFIG_NAME, cache_dir=cache_dir, etag_timeout=12)
+                hf_hub_download(DUMMY_MODEL_ID, filename=constants.CONFIG_NAME, cache_dir=cache_dir, etag_timeout=12)
                 kwargs = mock_etag_call.call_args.kwargs
                 self.assertIn("timeout", kwargs)
                 self.assertEqual(kwargs["timeout"], 15)  # passed value ignored, HF_HUB_ETAG_TIMEOUT takes priority
 
 
 def _recursive_chmod(path: str, mode: int) -> None:
diff --git a/tests/test_hf_api.py b/tests/test_hf_api.py
index aab4a6e48b..5beec1ffea 100644
--- a/tests/test_hf_api.py
+++ b/tests/test_hf_api.py
@@ -164,9 +164,9 @@ def test_revision_exists(self):
         assert not self._api.revision_exists(self.repo_id, "main", token=False)  # private repo
         assert not self._api.revision_exists("repo-that-does-not-exist", "main")  # missing repo
 
-    @patch("huggingface_hub.file_download.ENDPOINT", "https://hub-ci.huggingface.co")
+    @patch("huggingface_hub.constants.ENDPOINT", "https://hub-ci.huggingface.co")
     @patch(
-        "huggingface_hub.file_download.HUGGINGFACE_CO_URL_TEMPLATE",
+        "huggingface_hub.constants.HUGGINGFACE_CO_URL_TEMPLATE",
         "https://hub-ci.huggingface.co/{repo_id}/resolve/{revision}/{filename}",
     )
     def test_file_exists(self):
diff --git a/tests/test_hub_mixin_pytorch.py b/tests/test_hub_mixin_pytorch.py
index f6f746362c..aaabc6b610 100644
--- a/tests/test_hub_mixin_pytorch.py
+++ b/tests/test_hub_mixin_pytorch.py
@@ -9,8 +9,7 @@
 
 import pytest
 
-from huggingface_hub import HfApi, ModelCard, hf_hub_download
-from huggingface_hub.constants import PYTORCH_WEIGHTS_NAME
+from huggingface_hub import HfApi, ModelCard, constants, hf_hub_download
 from huggingface_hub.errors import EntryNotFoundError, HfHubHTTPError
 from huggingface_hub.hub_mixin import ModelHubMixin, PyTorchModelHubMixin
 from huggingface_hub.serialization._torch import storage_ptr
@@ -218,10 +217,10 @@ def pretend_file_download_fallback(self, **kwargs):
 
         class TestMixin(ModelHubMixin):
             def _save_pretrained(self, save_directory: Path) -> None:
-                torch.save(DummyModel().state_dict(), save_directory / PYTORCH_WEIGHTS_NAME)
+                torch.save(DummyModel().state_dict(), save_directory / constants.PYTORCH_WEIGHTS_NAME)
 
         TestMixin().save_pretrained(self.cache_dir)
-        return self.cache_dir / PYTORCH_WEIGHTS_NAME
+        return self.cache_dir / constants.PYTORCH_WEIGHTS_NAME
 
     @patch("huggingface_hub.hub_mixin.hf_hub_download")
     def test_from_pretrained_model_from_hub_fallback_pickle(self, hf_hub_download_mock: Mock) -> None:
diff --git a/tests/test_repocard.py b/tests/test_repocard.py
index 204c6ed9a7..307b93a772 100644
--- a/tests/test_repocard.py
+++ b/tests/test_repocard.py
@@ -29,6 +29,7 @@
     RepoCard,
     SpaceCard,
     SpaceCardData,
+    constants,
     get_hf_file_metadata,
     hf_hub_url,
     metadata_eval_result,
@@ -36,7 +37,6 @@
     metadata_save,
     metadata_update,
 )
-from huggingface_hub.constants import REPOCARD_NAME
 from huggingface_hub.errors import EntryNotFoundError
 from huggingface_hub.file_download import hf_hub_download
 from huggingface_hub.hf_api import HfApi
@@ -206,7 +206,7 @@ class RepocardMetadataTest(unittest.TestCase):
     cache_dir: Path
 
     def setUp(self) -> None:
-        self.filepath = self.cache_dir / REPOCARD_NAME
+        self.filepath = self.cache_dir / constants.REPOCARD_NAME
 
     def test_metadata_load(self):
         self.filepath.write_text(DUMMY_MODELCARD)
@@ -278,7 +278,9 @@ def setUp(self) -> None:
         self.repo_id = self.api.create_repo(repo_name()).repo_id
 
         self.api.upload_file(
-            path_or_fileobj=DUMMY_MODELCARD_EVAL_RESULT.encode(), repo_id=self.repo_id, path_in_repo=REPOCARD_NAME
+            path_or_fileobj=DUMMY_MODELCARD_EVAL_RESULT.encode(),
+            repo_id=self.repo_id,
+            path_in_repo=constants.REPOCARD_NAME,
         )
         self.existing_metadata = yaml.safe_load(DUMMY_MODELCARD_EVAL_RESULT.strip().strip("-"))
 
@@ -286,13 +288,13 @@ def tearDown(self) -> None:
         self.api.delete_repo(repo_id=self.repo_id)
 
     def _get_remote_card(self) -> str:
-        return hf_hub_download(repo_id=self.repo_id, filename=REPOCARD_NAME)
+        return hf_hub_download(repo_id=self.repo_id, filename=constants.REPOCARD_NAME)
 
     def test_update_dataset_name(self):
         new_datasets_data = {"datasets": ["test/test_dataset"]}
         metadata_update(self.repo_id, new_datasets_data, token=self.token)
 
-        hf_hub_download(repo_id=self.repo_id, filename=REPOCARD_NAME)
+        hf_hub_download(repo_id=self.repo_id, filename=constants.REPOCARD_NAME)
         updated_metadata = metadata_load(self._get_remote_card())
         expected_metadata = copy.deepcopy(self.existing_metadata)
         expected_metadata.update(new_datasets_data)
@@ -413,7 +415,9 @@ def test_update_metadata_on_empty_text_content(self) -> None:
         """
         # Create modelcard with metadata but empty text content
         self.api.upload_file(
-            path_or_fileobj=DUMMY_MODELCARD_NO_TEXT_CONTENT.encode(), path_in_repo=REPOCARD_NAME, repo_id=self.repo_id
+            path_or_fileobj=DUMMY_MODELCARD_NO_TEXT_CONTENT.encode(),
+            path_in_repo=constants.REPOCARD_NAME,
+            repo_id=self.repo_id,
         )
 
         metadata_update(self.repo_id, {"tag": "test"}, token=self.token)
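
Every change in this patch follows one pattern: an import-time copy of a constant (`from .constants import ENDPOINT`) is replaced by a call-time attribute lookup on the `constants` module (`constants.ENDPOINT`). Because the attribute is read on every call, `unittest.mock.patch` and environment-driven overrides of `huggingface_hub.constants` become visible to modules that were imported earlier, which is why the test decorators switch from patching `huggingface_hub.file_download.ENDPOINT` to `huggingface_hub.file_download.constants.ENDPOINT`. A minimal sketch of the difference (assumes `huggingface_hub` is installed and the `HF_ENDPOINT` env var is unset; the `url_late`/`url_early` helpers are hypothetical, not part of the library):

# late_vs_early_binding.py
from unittest.mock import patch

from huggingface_hub import constants
from huggingface_hub.constants import ENDPOINT  # value copied once, at import time


def url_late(repo_id: str) -> str:
    # Attribute lookup happens on every call -> sees the patch below.
    return f"{constants.ENDPOINT}/{repo_id}"


def url_early(repo_id: str) -> str:
    # Uses the copy bound at import -> the patch has no effect here.
    return f"{ENDPOINT}/{repo_id}"


with patch("huggingface_hub.constants.ENDPOINT", "https://hub-ci.huggingface.co"):
    assert url_late("julien-c/dummy-unknown") == "https://hub-ci.huggingface.co/julien-c/dummy-unknown"
    assert url_early("julien-c/dummy-unknown") == "https://huggingface.co/julien-c/dummy-unknown"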