From d40e4ac721fedef1ac559998c6f78f2a0ba83630 Mon Sep 17 00:00:00 2001 From: Lucain Date: Wed, 10 Sep 2025 15:29:32 +0200 Subject: [PATCH 01/19] [1.0] Httpx migration (#3328) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * first httpx integration * more migration * some fixes * download workflow should work * Fix repocard and error utils tests * fix hf-file-system * gix http utils tests * more fixes * fix some inference tests * fix test_file_download tests * async inference client * async code should be good * Define RemoteEntryFileNotFound explicitly (+some fixes) * fix async code quality * torch ok * fix hf_file_system * fix errors tests * mock * fix test_cli mock * fix commit scheduler * add fileno test * no more requests anywhere * fix test_file_download * tmp requests * Update src/huggingface_hub/utils/_http.py Co-authored-by: célina * Update src/huggingface_hub/utils/_http.py Co-authored-by: célina * Update src/huggingface_hub/hf_file_system.py Co-authored-by: célina * not async * fix tests --------- Co-authored-by: célina --- .github/conda/meta.yaml | 4 +- docs/source/en/guides/cli.md | 2 +- .../environment_variables.md | 2 +- docs/source/en/package_reference/utilities.md | 22 +- docs/source/ko/package_reference/utilities.md | 36 +- setup.py | 4 +- src/huggingface_hub/__init__.py | 30 +- src/huggingface_hub/_commit_api.py | 6 +- src/huggingface_hub/_commit_scheduler.py | 5 +- src/huggingface_hub/_snapshot_download.py | 21 +- src/huggingface_hub/cli/auth.py | 5 +- src/huggingface_hub/cli/jobs.py | 7 +- src/huggingface_hub/cli/repo.py | 4 +- src/huggingface_hub/commands/tag.py | 4 +- src/huggingface_hub/commands/user.py | 5 +- src/huggingface_hub/errors.py | 71 ++- src/huggingface_hub/file_download.py | 320 +++++----- src/huggingface_hub/hf_api.py | 188 +++--- src/huggingface_hub/hf_file_system.py | 132 ++-- src/huggingface_hub/hub_mixin.py | 14 - src/huggingface_hub/inference/_client.py | 123 ++-- src/huggingface_hub/inference/_common.py | 61 +- .../inference/_generated/_async_client.py | 262 ++++---- src/huggingface_hub/inference_api.py | 4 +- src/huggingface_hub/keras_mixin.py | 5 - src/huggingface_hub/lfs.py | 21 +- src/huggingface_hub/repocard.py | 15 +- src/huggingface_hub/utils/__init__.py | 11 +- src/huggingface_hub/utils/_fixes.py | 10 - src/huggingface_hub/utils/_http.py | 564 +++++++++++------- src/huggingface_hub/utils/_pagination.py | 4 +- src/huggingface_hub/utils/_validators.py | 33 + src/huggingface_hub/utils/_xet.py | 6 +- src/huggingface_hub/utils/tqdm.py | 2 +- tests/test_cli.py | 46 +- tests/test_commit_scheduler.py | 11 +- tests/test_file_download.py | 65 +- tests/test_hf_api.py | 38 +- tests/test_hf_file_system.py | 10 +- tests/test_hub_mixin.py | 2 - tests/test_hub_mixin_pytorch.py | 10 +- tests/test_inference_async_client.py | 36 +- tests/test_inference_client.py | 38 +- tests/test_inference_text_generation.py | 4 +- tests/test_oauth.py | 6 +- tests/test_offline_utils.py | 26 +- tests/test_repository.py | 6 +- tests/test_utils_cache.py | 18 +- tests/test_utils_errors.py | 79 +-- tests/test_utils_http.py | 112 ++-- tests/test_xet_upload.py | 1 - tests/testing_utils.py | 50 +- utils/generate_async_inference_client.py | 292 +++------ 53 files changed, 1424 insertions(+), 1429 deletions(-) diff --git a/.github/conda/meta.yaml b/.github/conda/meta.yaml index 6e72641382..830b147805 100644 --- a/.github/conda/meta.yaml +++ b/.github/conda/meta.yaml @@ -16,7 +16,7 @@ requirements: - pip - fsspec - filelock - - requests + - httpx - tqdm - typing-extensions - packaging @@ -26,7 +26,7 @@ requirements: - python - pip - filelock - - requests + - httpx - tqdm - typing-extensions - packaging diff --git a/docs/source/en/guides/cli.md b/docs/source/en/guides/cli.md index bfeaeeffb8..a754e010b4 100644 --- a/docs/source/en/guides/cli.md +++ b/docs/source/en/guides/cli.md @@ -278,7 +278,7 @@ By default, the `hf download` command will be verbose. It will print details suc On machines with slow connections, you might encounter timeout issues like this one: ```bash -`requests.exceptions.ReadTimeout: (ReadTimeoutError("HTTPSConnectionPool(host='cdn-lfs-us-1.huggingface.co', port=443): Read timed out. (read timeout=10)"), '(Request ID: a33d910c-84c6-4514-8362-c705e2039d38)')` +`httpx.TimeoutException: (TimeoutException("HTTPSConnectionPool(host='cdn-lfs-us-1.huggingface.co', port=443): Read timed out. (read timeout=10)"), '(Request ID: a33d910c-84c6-4514-8362-c705e2039d38)')` ``` To mitigate this issue, you can set the `HF_HUB_DOWNLOAD_TIMEOUT` environment variable to a higher value (default is 10): diff --git a/docs/source/en/package_reference/environment_variables.md b/docs/source/en/package_reference/environment_variables.md index 974d611208..cc2dd9cda1 100644 --- a/docs/source/en/package_reference/environment_variables.md +++ b/docs/source/en/package_reference/environment_variables.md @@ -179,7 +179,7 @@ Set to disable using `hf-xet`, even if it is available in your Python environmen Set to `True` for faster uploads and downloads from the Hub using `hf_transfer`. -By default, `huggingface_hub` uses the Python-based `requests.get` and `requests.post` functions. +By default, `huggingface_hub` uses the Python-based `httpx.get` and `httpx.post` functions. Although these are reliable and versatile, they may not be the most efficient choice for machines with high bandwidth. [`hf_transfer`](https://github.com/huggingface/hf_transfer) is a Rust-based package developed to diff --git a/docs/source/en/package_reference/utilities.md b/docs/source/en/package_reference/utilities.md index 80fe3148ff..df6537297a 100644 --- a/docs/source/en/package_reference/utilities.md +++ b/docs/source/en/package_reference/utilities.md @@ -177,35 +177,39 @@ Here is a list of HTTP errors thrown in `huggingface_hub`. the server response and format the error message to provide as much information to the user as possible. -[[autodoc]] huggingface_hub.utils.HfHubHTTPError +[[autodoc]] huggingface_hub.errors.HfHubHTTPError #### RepositoryNotFoundError -[[autodoc]] huggingface_hub.utils.RepositoryNotFoundError +[[autodoc]] huggingface_hub.errors.RepositoryNotFoundError #### GatedRepoError -[[autodoc]] huggingface_hub.utils.GatedRepoError +[[autodoc]] huggingface_hub.errors.GatedRepoError #### RevisionNotFoundError -[[autodoc]] huggingface_hub.utils.RevisionNotFoundError +[[autodoc]] huggingface_hub.errors.RevisionNotFoundError + +#### BadRequestError + +[[autodoc]] huggingface_hub.errors.BadRequestError #### EntryNotFoundError -[[autodoc]] huggingface_hub.utils.EntryNotFoundError +[[autodoc]] huggingface_hub.errors.EntryNotFoundError -#### BadRequestError +#### RemoteEntryNotFoundError -[[autodoc]] huggingface_hub.utils.BadRequestError +[[autodoc]] huggingface_hub.errors.RemoteEntryNotFoundError #### LocalEntryNotFoundError -[[autodoc]] huggingface_hub.utils.LocalEntryNotFoundError +[[autodoc]] huggingface_hub.errors.LocalEntryNotFoundError #### OfflineModeIsEnabled -[[autodoc]] huggingface_hub.utils.OfflineModeIsEnabled +[[autodoc]] huggingface_hub.errors.OfflineModeIsEnabled ## Telemetry diff --git a/docs/source/ko/package_reference/utilities.md b/docs/source/ko/package_reference/utilities.md index a76e9d474b..96ac88e432 100644 --- a/docs/source/ko/package_reference/utilities.md +++ b/docs/source/ko/package_reference/utilities.md @@ -125,39 +125,43 @@ except HfHubHTTPError as e: 여기에는 `huggingface_hub`에서 발생하는 HTTP 오류 목록이 있습니다. -#### HfHubHTTPError[[huggingface_hub.utils.HfHubHTTPError]] +#### HfHubHTTPError[[huggingface_hub.errors.HfHubHTTPError]] `HfHubHTTPError`는 HF Hub HTTP 오류에 대한 부모 클래스입니다. 이 클래스는 서버 응답을 구문 분석하고 오류 메시지를 형식화하여 사용자에게 가능한 많은 정보를 제공합니다. -[[autodoc]] huggingface_hub.utils.HfHubHTTPError +[[autodoc]] huggingface_hub.errors.HfHubHTTPError -#### RepositoryNotFoundError[[huggingface_hub.utils.RepositoryNotFoundError]] +#### RepositoryNotFoundError[[huggingface_hub.errors.RepositoryNotFoundError]] -[[autodoc]] huggingface_hub.utils.RepositoryNotFoundError +[[autodoc]] huggingface_hub.errors.RepositoryNotFoundError -#### GatedRepoError[[huggingface_hub.utils.GatedRepoError]] +#### GatedRepoError[[huggingface_hub.errors.GatedRepoError]] -[[autodoc]] huggingface_hub.utils.GatedRepoError +[[autodoc]] huggingface_hub.errors.GatedRepoError -#### RevisionNotFoundError[[huggingface_hub.utils.RevisionNotFoundError]] +#### RevisionNotFoundError[[huggingface_hub.errors.RevisionNotFoundError]] -[[autodoc]] huggingface_hub.utils.RevisionNotFoundError +[[autodoc]] huggingface_hub.errors.RevisionNotFoundError -#### EntryNotFoundError[[huggingface_hub.utils.EntryNotFoundError]] +#### BadRequestError[[huggingface_hub.errors.BadRequestError]] -[[autodoc]] huggingface_hub.utils.EntryNotFoundError +[[autodoc]] huggingface_hub.errors.BadRequestError -#### BadRequestError[[huggingface_hub.utils.BadRequestError]] +#### EntryNotFoundError[[huggingface_hub.errors.EntryNotFoundError]] -[[autodoc]] huggingface_hub.utils.BadRequestError +[[autodoc]] huggingface_hub.errors.EntryNotFoundError -#### LocalEntryNotFoundError[[huggingface_hub.utils.LocalEntryNotFoundError]] +#### RemoteEntryNotFoundError[[huggingface_hub.errors.RemoteEntryNotFoundError]] -[[autodoc]] huggingface_hub.utils.LocalEntryNotFoundError +[[autodoc]] huggingface_hub.errors.RemoteEntryNotFoundError -#### OfflineModeIsEnabledd[[huggingface_hub.utils.OfflineModeIsEnabled]] +#### LocalEntryNotFoundError[[huggingface_hub.errors.LocalEntryNotFoundError]] -[[autodoc]] huggingface_hub.utils.OfflineModeIsEnabled +[[autodoc]] huggingface_hub.errors.LocalEntryNotFoundError + +#### OfflineModeIsEnabledd[[huggingface_hub.errors.OfflineModeIsEnabled]] + +[[autodoc]] huggingface_hub.errors.OfflineModeIsEnabled ## 원격 측정[[huggingface_hub.utils.send_telemetry]] diff --git a/setup.py b/setup.py index 028c67be08..3fd35880fa 100644 --- a/setup.py +++ b/setup.py @@ -17,7 +17,7 @@ def get_version() -> str: "hf-xet>=1.1.3,<2.0.0; platform_machine=='x86_64' or platform_machine=='amd64' or platform_machine=='arm64' or platform_machine=='aarch64'", "packaging>=20.9", "pyyaml>=5.1", - "requests", + "httpx>=0.23.0, <1", "tqdm>=4.42.1", "typing-extensions>=3.7.4.3", # to be able to import TypeAlias ] @@ -89,6 +89,7 @@ def get_version() -> str: "soundfile", "Pillow", "gradio>=4.0.0", # to test webhooks # pin to avoid issue on Python3.12 + "requests", # for gradio "numpy", # for embeddings "fastapi", # To build the documentation ] @@ -99,7 +100,6 @@ def get_version() -> str: extras["typing"] = [ "typing-extensions>=4.8.0", "types-PyYAML", - "types-requests", "types-simplejson", "types-toml", "types-tqdm", diff --git a/src/huggingface_hub/__init__.py b/src/huggingface_hub/__init__.py index 0f4d0d1598..c1d2c6658f 100644 --- a/src/huggingface_hub/__init__.py +++ b/src/huggingface_hub/__init__.py @@ -520,6 +520,8 @@ "read_dduf_file", ], "utils": [ + "ASYNC_CLIENT_FACTORY_T", + "CLIENT_FACTORY_T", "CacheNotFound", "CachedFileInfo", "CachedRepoInfo", @@ -528,13 +530,19 @@ "DeleteCacheStrategy", "HFCacheInfo", "HfFolder", + "HfHubAsyncTransport", + "HfHubTransport", "cached_assets_path", - "configure_http_backend", + "close_client", "dump_environment_info", + "get_async_session", "get_session", "get_token", + "hf_raise_for_status", "logging", "scan_cache_dir", + "set_async_client_factory", + "set_client_factory", ], } @@ -550,6 +558,7 @@ # ``` __all__ = [ + "ASYNC_CLIENT_FACTORY_T", "Agent", "AsyncInferenceClient", "AudioClassificationInput", @@ -564,6 +573,7 @@ "AutomaticSpeechRecognitionOutput", "AutomaticSpeechRecognitionOutputChunk", "AutomaticSpeechRecognitionParameters", + "CLIENT_FACTORY_T", "CONFIG_NAME", "CacheNotFound", "CachedFileInfo", @@ -653,6 +663,8 @@ "HfFileSystemResolvedPath", "HfFileSystemStreamFile", "HfFolder", + "HfHubAsyncTransport", + "HfHubTransport", "ImageClassificationInput", "ImageClassificationOutputElement", "ImageClassificationOutputTransform", @@ -824,8 +836,8 @@ "cancel_access_request", "cancel_job", "change_discussion_status", + "close_client", "comment_discussion", - "configure_http_backend", "create_branch", "create_collection", "create_commit", @@ -863,6 +875,7 @@ "file_exists", "from_pretrained_fastai", "from_pretrained_keras", + "get_async_session", "get_collection", "get_dataset_tags", "get_discussion_details", @@ -886,6 +899,7 @@ "grant_access", "hf_hub_download", "hf_hub_url", + "hf_raise_for_status", "inspect_job", "inspect_scheduled_job", "interpreter_login", @@ -953,6 +967,8 @@ "save_torch_state_dict", "scale_to_zero_inference_endpoint", "scan_cache_dir", + "set_async_client_factory", + "set_client_factory", "set_space_sleep_time", "snapshot_download", "space_info", @@ -1530,6 +1546,8 @@ def __dir__(): read_dduf_file, # noqa: F401 ) from .utils import ( + ASYNC_CLIENT_FACTORY_T, # noqa: F401 + CLIENT_FACTORY_T, # noqa: F401 CachedFileInfo, # noqa: F401 CachedRepoInfo, # noqa: F401 CachedRevisionInfo, # noqa: F401 @@ -1538,11 +1556,17 @@ def __dir__(): DeleteCacheStrategy, # noqa: F401 HFCacheInfo, # noqa: F401 HfFolder, # noqa: F401 + HfHubAsyncTransport, # noqa: F401 + HfHubTransport, # noqa: F401 cached_assets_path, # noqa: F401 - configure_http_backend, # noqa: F401 + close_client, # noqa: F401 dump_environment_info, # noqa: F401 + get_async_session, # noqa: F401 get_session, # noqa: F401 get_token, # noqa: F401 + hf_raise_for_status, # noqa: F401 logging, # noqa: F401 scan_cache_dir, # noqa: F401 + set_async_client_factory, # noqa: F401 + set_client_factory, # noqa: F401 ) diff --git a/src/huggingface_hub/_commit_api.py b/src/huggingface_hub/_commit_api.py index 9e8fa86e6c..58e082b307 100644 --- a/src/huggingface_hub/_commit_api.py +++ b/src/huggingface_hub/_commit_api.py @@ -235,7 +235,7 @@ def as_file(self, with_tqdm: bool = False) -> Iterator[BinaryIO]: config.json: 100%|█████████████████████████| 8.19k/8.19k [00:02<00:00, 3.72kB/s] >>> with operation.as_file(with_tqdm=True) as file: - ... requests.put(..., data=file) + ... httpx.put(..., data=file) config.json: 100%|█████████████████████████| 8.19k/8.19k [00:02<00:00, 3.72kB/s] ``` """ @@ -389,7 +389,7 @@ def _upload_lfs_files( If an upload failed for any reason [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) If the server returns malformed responses - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError) + [`HfHubHTTPError`] If the LFS batch endpoint returned an HTTP error. """ # Step 1: retrieve upload instructions from the LFS batch endpoint. @@ -500,7 +500,7 @@ def _upload_xet_files( If an upload failed for any reason. [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) If the server returns malformed responses or if the user is unauthorized to upload to xet storage. - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError) + [`HfHubHTTPError`] If the LFS batch endpoint returned an HTTP error. **How it works:** diff --git a/src/huggingface_hub/_commit_scheduler.py b/src/huggingface_hub/_commit_scheduler.py index f1f20339e7..f28180fd68 100644 --- a/src/huggingface_hub/_commit_scheduler.py +++ b/src/huggingface_hub/_commit_scheduler.py @@ -315,10 +315,13 @@ def __len__(self) -> int: return self._size_limit def __getattribute__(self, name: str): - if name.startswith("_") or name in ("read", "tell", "seek"): # only 3 public methods supported + if name.startswith("_") or name in ("read", "tell", "seek", "fileno"): # only 4 public methods supported return super().__getattribute__(name) raise NotImplementedError(f"PartialFileIO does not support '{name}'.") + def fileno(self): + raise AttributeError("PartialFileIO does not have a fileno.") + def tell(self) -> int: """Return the current file position.""" return self._file.tell() diff --git a/src/huggingface_hub/_snapshot_download.py b/src/huggingface_hub/_snapshot_download.py index 0db8a29f7e..aa65d561da 100644 --- a/src/huggingface_hub/_snapshot_download.py +++ b/src/huggingface_hub/_snapshot_download.py @@ -2,7 +2,7 @@ from pathlib import Path from typing import Dict, Iterable, List, Literal, Optional, Type, Union -import requests +import httpx from tqdm.auto import tqdm as base_tqdm from tqdm.contrib.concurrent import thread_map @@ -36,7 +36,6 @@ def snapshot_download( library_name: Optional[str] = None, library_version: Optional[str] = None, user_agent: Optional[Union[Dict, str]] = None, - proxies: Optional[Dict] = None, etag_timeout: float = constants.DEFAULT_ETAG_TIMEOUT, force_download: bool = False, token: Optional[Union[bool, str]] = None, @@ -85,12 +84,9 @@ def snapshot_download( The version of the library. user_agent (`str`, `dict`, *optional*): The user-agent info in the form of a dictionary or a string. - proxies (`dict`, *optional*): - Dictionary mapping protocol to the URL of the proxy passed to - `requests.request`. etag_timeout (`float`, *optional*, defaults to `10`): When fetching ETag, how many seconds to wait for the server to send - data before giving up which is passed to `requests.request`. + data before giving up which is passed to `httpx.request`. force_download (`bool`, *optional*, defaults to `False`): Whether the file should be downloaded even if it already exists in the local cache. token (`str`, `bool`, *optional*): @@ -163,14 +159,10 @@ def snapshot_download( try: # if we have internet connection we want to list files to download repo_info = api.repo_info(repo_id=repo_id, repo_type=repo_type, revision=revision) - except (requests.exceptions.SSLError, requests.exceptions.ProxyError): - # Actually raise for those subclasses of ConnectionError + except httpx.ProxyError: + # Actually raise on proxy error raise - except ( - requests.exceptions.ConnectionError, - requests.exceptions.Timeout, - OfflineModeIsEnabled, - ) as error: + except (httpx.ConnectError, httpx.TimeoutException, OfflineModeIsEnabled) as error: # Internet connection is down # => will try to use local files only api_call_error = error @@ -178,7 +170,7 @@ def snapshot_download( except RevisionNotFoundError: # The repo was found but the revision doesn't exist on the Hub (never existed or got deleted) raise - except requests.HTTPError as error: + except HfHubHTTPError as error: # Multiple reasons for an http error: # - Repository is private and invalid/missing token sent # - Repository is gated and invalid/missing token sent @@ -315,7 +307,6 @@ def _inner_hf_hub_download(repo_file: str): library_name=library_name, library_version=library_version, user_agent=user_agent, - proxies=proxies, etag_timeout=etag_timeout, resume_download=resume_download, force_download=force_download, diff --git a/src/huggingface_hub/cli/auth.py b/src/huggingface_hub/cli/auth.py index bbf475a4f8..91e6b3c18d 100644 --- a/src/huggingface_hub/cli/auth.py +++ b/src/huggingface_hub/cli/auth.py @@ -33,10 +33,9 @@ from argparse import _SubParsersAction from typing import List, Optional -from requests.exceptions import HTTPError - from huggingface_hub.commands import BaseHuggingfaceCLICommand from huggingface_hub.constants import ENDPOINT +from huggingface_hub.errors import HfHubHTTPError from huggingface_hub.hf_api import HfApi from .._login import auth_list, auth_switch, login, logout @@ -207,7 +206,7 @@ def run(self): if ENDPOINT != "https://huggingface.co": print(f"Authenticated through private endpoint: {ENDPOINT}") - except HTTPError as e: + except HfHubHTTPError as e: print(e) print(ANSI.red(e.response.text)) exit(1) diff --git a/src/huggingface_hub/cli/jobs.py b/src/huggingface_hub/cli/jobs.py index 3a661c7df7..5b8d355c6f 100644 --- a/src/huggingface_hub/cli/jobs.py +++ b/src/huggingface_hub/cli/jobs.py @@ -38,9 +38,8 @@ from pathlib import Path from typing import Dict, List, Optional, Union -import requests - from huggingface_hub import HfApi, SpaceHardware, get_token +from huggingface_hub.errors import HfHubHTTPError from huggingface_hub.utils import logging from huggingface_hub.utils._dotenv import load_dotenv @@ -329,7 +328,7 @@ def run(self) -> None: # Apply custom format if provided or use default tabular format self._print_output(rows, table_headers) - except requests.RequestException as e: + except HfHubHTTPError as e: print(f"Error fetching jobs data: {e}") except (KeyError, ValueError, TypeError) as e: print(f"Error processing jobs data: {e}") @@ -815,7 +814,7 @@ def run(self) -> None: # Apply custom format if provided or use default tabular format self._print_output(rows, table_headers) - except requests.RequestException as e: + except HfHubHTTPError as e: print(f"Error fetching scheduled jobs data: {e}") except (KeyError, ValueError, TypeError) as e: print(f"Error processing scheduled jobs data: {e}") diff --git a/src/huggingface_hub/cli/repo.py b/src/huggingface_hub/cli/repo.py index ef0e331358..8f5a330a9f 100644 --- a/src/huggingface_hub/cli/repo.py +++ b/src/huggingface_hub/cli/repo.py @@ -25,8 +25,6 @@ from argparse import _SubParsersAction from typing import Optional -from requests.exceptions import HTTPError - from huggingface_hub.commands import BaseHuggingfaceCLICommand from huggingface_hub.commands._cli_utils import ANSI from huggingface_hub.constants import REPO_TYPES, SPACES_SDK_TYPES @@ -218,7 +216,7 @@ def run(self): except RepositoryNotFoundError: print(f"{self.repo_type.capitalize()} {ANSI.bold(self.repo_id)} not found.") exit(1) - except HTTPError as e: + except HfHubHTTPError as e: print(e) print(ANSI.red(e.response.text)) exit(1) diff --git a/src/huggingface_hub/commands/tag.py b/src/huggingface_hub/commands/tag.py index 405d407f81..a961791155 100644 --- a/src/huggingface_hub/commands/tag.py +++ b/src/huggingface_hub/commands/tag.py @@ -32,8 +32,6 @@ from argparse import Namespace, _SubParsersAction -from requests.exceptions import HTTPError - from huggingface_hub.commands import BaseHuggingfaceCLICommand from huggingface_hub.constants import ( REPO_TYPES, @@ -129,7 +127,7 @@ def run(self): except RepositoryNotFoundError: print(f"{self.repo_type.capitalize()} {ANSI.bold(self.repo_id)} not found.") exit(1) - except HTTPError as e: + except HfHubHTTPError as e: print(e) print(ANSI.red(e.response.text)) exit(1) diff --git a/src/huggingface_hub/commands/user.py b/src/huggingface_hub/commands/user.py index 3f4da0f45d..61cbc4c9e1 100644 --- a/src/huggingface_hub/commands/user.py +++ b/src/huggingface_hub/commands/user.py @@ -33,10 +33,9 @@ from argparse import _SubParsersAction from typing import List, Optional -from requests.exceptions import HTTPError - from huggingface_hub.commands import BaseHuggingfaceCLICommand from huggingface_hub.constants import ENDPOINT +from huggingface_hub.errors import HfHubHTTPError from huggingface_hub.hf_api import HfApi from .._login import auth_list, auth_switch, login, logout @@ -202,7 +201,7 @@ def run(self): if ENDPOINT != "https://huggingface.co": print(f"Authenticated through private endpoint: {ENDPOINT}") - except HTTPError as e: + except HfHubHTTPError as e: print(e) print(ANSI.red(e.response.text)) exit(1) diff --git a/src/huggingface_hub/errors.py b/src/huggingface_hub/errors.py index a0f7ed80e3..4426d7576b 100644 --- a/src/huggingface_hub/errors.py +++ b/src/huggingface_hub/errors.py @@ -3,7 +3,7 @@ from pathlib import Path from typing import Optional, Union -from requests import HTTPError, Response +from httpx import HTTPError, Response # CACHE ERRORS @@ -51,7 +51,7 @@ class HfHubHTTPError(HTTPError): Example: ```py - import requests + import httpx from huggingface_hub.utils import get_session, hf_raise_for_status, HfHubHTTPError response = get_session().post(...) @@ -67,19 +67,18 @@ class HfHubHTTPError(HTTPError): ``` """ - def __init__(self, message: str, response: Optional[Response] = None, *, server_message: Optional[str] = None): - self.request_id = ( - response.headers.get("x-request-id") or response.headers.get("X-Amzn-Trace-Id") - if response is not None - else None - ) + def __init__( + self, + message: str, + *, + response: Response, + server_message: Optional[str] = None, + ): + self.request_id = response.headers.get("x-request-id") or response.headers.get("X-Amzn-Trace-Id") self.server_message = server_message - - super().__init__( - message, - response=response, # type: ignore [arg-type] - request=response.request if response is not None else None, # type: ignore [arg-type] - ) + self.response = response + self.request = response.request + super().__init__(message) def append_to_message(self, additional_message: str) -> None: """Append additional information to the `HfHubHTTPError` initial message.""" @@ -182,7 +181,7 @@ class RepositoryNotFoundError(HfHubHTTPError): >>> from huggingface_hub import model_info >>> model_info("") (...) - huggingface_hub.utils._errors.RepositoryNotFoundError: 401 Client Error. (Request ID: PvMw_VjBMjVdMz53WKIzP) + huggingface_hub.errors.RepositoryNotFoundError: 401 Client Error. (Request ID: PvMw_VjBMjVdMz53WKIzP) Repository Not Found for url: https://huggingface.co/api/models/%3Cnon_existent_repository%3E. Please make sure you specified the correct `repo_id` and `repo_type`. @@ -205,7 +204,7 @@ class GatedRepoError(RepositoryNotFoundError): >>> from huggingface_hub import model_info >>> model_info("") (...) - huggingface_hub.utils._errors.GatedRepoError: 403 Client Error. (Request ID: ViT1Bf7O_026LGSQuVqfa) + huggingface_hub.errors.GatedRepoError: 403 Client Error. (Request ID: ViT1Bf7O_026LGSQuVqfa) Cannot access gated repo for url https://huggingface.co/api/models/ardent-figment/gated-model. Access to model ardent-figment/gated-model is restricted and you are not in the authorized list. @@ -224,7 +223,7 @@ class DisabledRepoError(HfHubHTTPError): >>> from huggingface_hub import dataset_info >>> dataset_info("laion/laion-art") (...) - huggingface_hub.utils._errors.DisabledRepoError: 403 Client Error. (Request ID: Root=1-659fc3fa-3031673e0f92c71a2260dbe2;bc6f4dfb-b30a-4862-af0a-5cfe827610d8) + huggingface_hub.errors.DisabledRepoError: 403 Client Error. (Request ID: Root=1-659fc3fa-3031673e0f92c71a2260dbe2;bc6f4dfb-b30a-4862-af0a-5cfe827610d8) Cannot access repository for url https://huggingface.co/api/datasets/laion/laion-art. Access to this resource is disabled. @@ -246,7 +245,7 @@ class RevisionNotFoundError(HfHubHTTPError): >>> from huggingface_hub import hf_hub_download >>> hf_hub_download('bert-base-cased', 'config.json', revision='') (...) - huggingface_hub.utils._errors.RevisionNotFoundError: 404 Client Error. (Request ID: Mwhe_c3Kt650GcdKEFomX) + huggingface_hub.errors.RevisionNotFoundError: 404 Client Error. (Request ID: Mwhe_c3Kt650GcdKEFomX) Revision Not Found for url: https://huggingface.co/bert-base-cased/resolve/%3Cnon-existent-revision%3E/config.json. ``` @@ -254,7 +253,25 @@ class RevisionNotFoundError(HfHubHTTPError): # ENTRY ERRORS -class EntryNotFoundError(HfHubHTTPError): +class EntryNotFoundError(Exception): + """ + Raised when entry not found, either locally or remotely. + + Example: + + ```py + >>> from huggingface_hub import hf_hub_download + >>> hf_hub_download('bert-base-cased', '') + (...) + huggingface_hub.errors.RemoteEntryNotFoundError (...) + >>> hf_hub_download('bert-base-cased', '', local_files_only=True) + (...) + huggingface_hub.utils.errors.LocalEntryNotFoundError (...) + ``` + """ + + +class RemoteEntryNotFoundError(HfHubHTTPError, EntryNotFoundError): """ Raised when trying to access a hf.co URL with a valid repository and revision but an invalid filename. @@ -265,34 +282,30 @@ class EntryNotFoundError(HfHubHTTPError): >>> from huggingface_hub import hf_hub_download >>> hf_hub_download('bert-base-cased', '') (...) - huggingface_hub.utils._errors.EntryNotFoundError: 404 Client Error. (Request ID: 53pNl6M0MxsnG5Sw8JA6x) + huggingface_hub.errors.EntryNotFoundError: 404 Client Error. (Request ID: 53pNl6M0MxsnG5Sw8JA6x) Entry Not Found for url: https://huggingface.co/bert-base-cased/resolve/main/%3Cnon-existent-file%3E. ``` """ -class LocalEntryNotFoundError(EntryNotFoundError, FileNotFoundError, ValueError): +class LocalEntryNotFoundError(FileNotFoundError, EntryNotFoundError): """ Raised when trying to access a file or snapshot that is not on the disk when network is disabled or unavailable (connection issue). The entry may exist on the Hub. - Note: `ValueError` type is to ensure backward compatibility. - Note: `LocalEntryNotFoundError` derives from `HTTPError` because of `EntryNotFoundError` - even when it is not a network issue. - Example: ```py >>> from huggingface_hub import hf_hub_download >>> hf_hub_download('bert-base-cased', '', local_files_only=True) (...) - huggingface_hub.utils._errors.LocalEntryNotFoundError: Cannot find the requested files in the disk cache and outgoing traffic has been disabled. To enable hf.co look-ups and downloads online, set 'local_files_only' to False. + huggingface_hub.errors.LocalEntryNotFoundError: Cannot find the requested files in the disk cache and outgoing traffic has been disabled. To enable hf.co look-ups and downloads online, set 'local_files_only' to False. ``` """ def __init__(self, message: str): - super().__init__(message, response=None) + super().__init__(message) # REQUEST ERROR @@ -303,9 +316,9 @@ class BadRequestError(HfHubHTTPError, ValueError): Example: ```py - >>> resp = requests.post("hf.co/api/check", ...) + >>> resp = httpx.post("hf.co/api/check", ...) >>> hf_raise_for_status(resp, endpoint_name="check") - huggingface_hub.utils._errors.BadRequestError: Bad request for check endpoint: {details} (Request ID: XXX) + huggingface_hub.errors.BadRequestError: Bad request for check endpoint: {details} (Request ID: XXX) ``` """ diff --git a/src/huggingface_hub/file_download.py b/src/huggingface_hub/file_download.py index 4fc063796a..89e2ad74e4 100644 --- a/src/huggingface_hub/file_download.py +++ b/src/huggingface_hub/file_download.py @@ -1,6 +1,5 @@ import copy import errno -import inspect import os import re import shutil @@ -13,7 +12,7 @@ from typing import Any, BinaryIO, Dict, Literal, NoReturn, Optional, Tuple, Union from urllib.parse import quote, urlparse -import requests +import httpx from . import ( __version__, # noqa: F401 # for backward compatibility @@ -25,11 +24,11 @@ HUGGINGFACE_HUB_CACHE, # noqa: F401 # for backward compatibility ) from .errors import ( - EntryNotFoundError, FileMetadataError, GatedRepoError, HfHubHTTPError, LocalEntryNotFoundError, + RemoteEntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError, ) @@ -57,11 +56,10 @@ logging, parse_xet_file_data_from_response, refresh_xet_connection_info, - reset_sessions, tqdm, validate_hf_hub_args, ) -from .utils._http import _adjust_range_header, http_backoff +from .utils._http import _adjust_range_header, http_backoff, http_stream_backoff from .utils._runtime import _PY_VERSION, is_xet_available # noqa: F401 # for backward compatibility from .utils._typing import HTTP_METHOD_T from .utils.sha import sha_fileobj @@ -261,11 +259,10 @@ def hf_hub_url( return url -def _request_wrapper( - method: HTTP_METHOD_T, url: str, *, follow_relative_redirects: bool = False, **params -) -> requests.Response: - """Wrapper around requests methods to follow relative redirects if `follow_relative_redirects=True` even when - `allow_redirection=False`. +def _httpx_follow_relative_redirects(method: HTTP_METHOD_T, url: str, **httpx_kwargs) -> httpx.Response: + """Perform an HTTP request with backoff and follow relative redirects only. + + This is useful to follow a redirection to a renamed repository without following redirection to a CDN. A backoff mechanism retries the HTTP call on 429, 503 and 504 errors. @@ -274,44 +271,36 @@ def _request_wrapper( HTTP method, such as 'GET' or 'HEAD'. url (`str`): The URL of the resource to fetch. - follow_relative_redirects (`bool`, *optional*, defaults to `False`) - If True, relative redirection (redirection to the same site) will be resolved even when `allow_redirection` - kwarg is set to False. Useful when we want to follow a redirection to a renamed repository without - following redirection to a CDN. - **params (`dict`, *optional*): - Params to pass to `requests.request`. + **httpx_kwargs (`dict`, *optional*): + Params to pass to `httpx.request`. """ - # Recursively follow relative redirects - if follow_relative_redirects: - response = _request_wrapper( + while True: + # Make the request + response = http_backoff( method=method, url=url, - follow_relative_redirects=False, - **params, + **httpx_kwargs, + follow_redirects=False, + retry_on_exceptions=(), + retry_on_status_codes=(429,), ) + hf_raise_for_status(response) - # If redirection, we redirect only relative paths. - # This is useful in case of a renamed repository. + # Check if response is a relative redirect if 300 <= response.status_code <= 399: parsed_target = urlparse(response.headers["Location"]) if parsed_target.netloc == "": - # This means it is a relative 'location' headers, as allowed by RFC 7231. - # (e.g. '/path/to/resource' instead of 'http://domain.tld/path/to/resource') - # We want to follow this relative redirect ! - # - # Highly inspired by `resolve_redirects` from requests library. - # See https://github.com/psf/requests/blob/main/requests/sessions.py#L159 - next_url = urlparse(url)._replace(path=parsed_target.path).geturl() - return _request_wrapper(method=method, url=next_url, follow_relative_redirects=True, **params) - return response - - # Perform request and return if status_code is not in the retry list. - response = http_backoff(method=method, url=url, **params, retry_on_exceptions=(), retry_on_status_codes=(429,)) - hf_raise_for_status(response) + # Relative redirect -> update URL and retry + url = urlparse(url)._replace(path=parsed_target.path).geturl() + continue + + # Break if no relative redirect + break + return response -def _get_file_length_from_http_response(response: requests.Response) -> Optional[int]: +def _get_file_length_from_http_response(response: httpx.Response) -> Optional[int]: """ Get the length of the file from the HTTP response headers. @@ -319,7 +308,7 @@ def _get_file_length_from_http_response(response: requests.Response) -> Optional `Content-Range` or `Content-Length` header, if available (in that order). Args: - response (`requests.Response`): + response (`httpx.Response`): The HTTP response object. Returns: @@ -346,11 +335,11 @@ def _get_file_length_from_http_response(response: requests.Response) -> Optional return None +@validate_hf_hub_args def http_get( url: str, temp_file: BinaryIO, *, - proxies: Optional[Dict] = None, resume_size: int = 0, headers: Optional[Dict[str, Any]] = None, expected_size: Optional[int] = None, @@ -370,8 +359,6 @@ def http_get( The URL of the file to download. temp_file (`BinaryIO`): The file-like object where to save the file. - proxies (`dict`, *optional*): - Dictionary mapping protocol to the URL of the proxy passed to `requests.request`. resume_size (`int`, *optional*): The number of bytes already downloaded. If set to 0 (default), the whole file is download. If set to a positive number, the download will resume at the given position. @@ -393,8 +380,6 @@ def http_get( if constants.HF_HUB_ENABLE_HF_TRANSFER: if resume_size != 0: warnings.warn("'hf_transfer' does not support `resume_size`: falling back to regular download method") - elif proxies is not None: - warnings.warn("'hf_transfer' does not support `proxies`: falling back to regular download method") elif has_custom_range_header: warnings.warn("'hf_transfer' ignores custom 'Range' headers; falling back to regular download method") else: @@ -423,103 +408,97 @@ def http_get( " Try `pip install hf_transfer` or `pip install hf_xet`." ) - r = _request_wrapper( - method="GET", url=url, stream=True, proxies=proxies, headers=headers, timeout=constants.HF_HUB_DOWNLOAD_TIMEOUT - ) - - hf_raise_for_status(r) - total: Optional[int] = _get_file_length_from_http_response(r) - - if displayed_filename is None: - displayed_filename = url - content_disposition = r.headers.get("Content-Disposition") - if content_disposition is not None: - match = HEADER_FILENAME_PATTERN.search(content_disposition) - if match is not None: - # Means file is on CDN - displayed_filename = match.groupdict()["filename"] - - # Truncate filename if too long to display - if len(displayed_filename) > 40: - displayed_filename = f"(…){displayed_filename[-40:]}" + with http_stream_backoff( + method="GET", + url=url, + headers=headers, + timeout=constants.HF_HUB_DOWNLOAD_TIMEOUT, + retry_on_exceptions=(), + retry_on_status_codes=(429,), + ) as response: + hf_raise_for_status(response) + total: Optional[int] = _get_file_length_from_http_response(response) + + if displayed_filename is None: + displayed_filename = url + content_disposition = response.headers.get("Content-Disposition") + if content_disposition is not None: + match = HEADER_FILENAME_PATTERN.search(content_disposition) + if match is not None: + # Means file is on CDN + displayed_filename = match.groupdict()["filename"] + + # Truncate filename if too long to display + if len(displayed_filename) > 40: + displayed_filename = f"(…){displayed_filename[-40:]}" + + consistency_error_message = ( + f"Consistency check failed: file should be of size {expected_size} but has size" + f" {{actual_size}} ({displayed_filename}).\nThis is usually due to network issues while downloading the file." + " Please retry with `force_download=True`." + ) + progress_cm = _get_progress_bar_context( + desc=displayed_filename, + log_level=logger.getEffectiveLevel(), + total=total, + initial=resume_size, + name="huggingface_hub.http_get", + _tqdm_bar=_tqdm_bar, + ) - consistency_error_message = ( - f"Consistency check failed: file should be of size {expected_size} but has size" - f" {{actual_size}} ({displayed_filename}).\nThis is usually due to network issues while downloading the file." - " Please retry with `force_download=True`." - ) - progress_cm = _get_progress_bar_context( - desc=displayed_filename, - log_level=logger.getEffectiveLevel(), - total=total, - initial=resume_size, - name="huggingface_hub.http_get", - _tqdm_bar=_tqdm_bar, - ) + with progress_cm as progress: + if hf_transfer and total is not None and total > 5 * constants.DOWNLOAD_CHUNK_SIZE: + try: + hf_transfer.download( + url=url, + filename=temp_file.name, + max_files=constants.HF_TRANSFER_CONCURRENCY, + chunk_size=constants.DOWNLOAD_CHUNK_SIZE, + headers=initial_headers, + parallel_failures=3, + max_retries=5, + callback=progress.update, + ) + except Exception as e: + raise RuntimeError( + "An error occurred while downloading using `hf_transfer`. Consider" + " disabling HF_HUB_ENABLE_HF_TRANSFER for better error handling." + ) from e + if expected_size is not None and expected_size != os.path.getsize(temp_file.name): + raise EnvironmentError( + consistency_error_message.format( + actual_size=os.path.getsize(temp_file.name), + ) + ) + return - with progress_cm as progress: - if hf_transfer and total is not None and total > 5 * constants.DOWNLOAD_CHUNK_SIZE: - supports_callback = "callback" in inspect.signature(hf_transfer.download).parameters - if not supports_callback: - warnings.warn( - "You are using an outdated version of `hf_transfer`. " - "Consider upgrading to latest version to enable progress bars " - "using `pip install -U hf_transfer`." - ) + new_resume_size = resume_size try: - hf_transfer.download( + for chunk in response.iter_bytes(chunk_size=constants.DOWNLOAD_CHUNK_SIZE): + if chunk: # filter out keep-alive new chunks + progress.update(len(chunk)) + temp_file.write(chunk) + new_resume_size += len(chunk) + # Some data has been downloaded from the server so we reset the number of retries. + _nb_retries = 5 + except (httpx.ConnectError, httpx.TimeoutException) as e: + # If ConnectionError (SSLError) or ReadTimeout happen while streaming data from the server, it is most likely + # a transient error (network outage?). We log a warning message and try to resume the download a few times + # before giving up. Tre retry mechanism is basic but should be enough in most cases. + if _nb_retries <= 0: + logger.warning("Error while downloading from %s: %s\nMax retries exceeded.", url, str(e)) + raise + logger.warning("Error while downloading from %s: %s\nTrying to resume download...", url, str(e)) + time.sleep(1) + return http_get( url=url, - filename=temp_file.name, - max_files=constants.HF_TRANSFER_CONCURRENCY, - chunk_size=constants.DOWNLOAD_CHUNK_SIZE, + temp_file=temp_file, + resume_size=new_resume_size, headers=initial_headers, - parallel_failures=3, - max_retries=5, - **({"callback": progress.update} if supports_callback else {}), - ) - except Exception as e: - raise RuntimeError( - "An error occurred while downloading using `hf_transfer`. Consider" - " disabling HF_HUB_ENABLE_HF_TRANSFER for better error handling." - ) from e - if not supports_callback: - progress.update(total) - if expected_size is not None and expected_size != os.path.getsize(temp_file.name): - raise EnvironmentError( - consistency_error_message.format( - actual_size=os.path.getsize(temp_file.name), - ) + expected_size=expected_size, + _nb_retries=_nb_retries - 1, + _tqdm_bar=_tqdm_bar, ) - return - new_resume_size = resume_size - try: - for chunk in r.iter_content(chunk_size=constants.DOWNLOAD_CHUNK_SIZE): - if chunk: # filter out keep-alive new chunks - progress.update(len(chunk)) - temp_file.write(chunk) - new_resume_size += len(chunk) - # Some data has been downloaded from the server so we reset the number of retries. - _nb_retries = 5 - except (requests.ConnectionError, requests.ReadTimeout) as e: - # If ConnectionError (SSLError) or ReadTimeout happen while streaming data from the server, it is most likely - # a transient error (network outage?). We log a warning message and try to resume the download a few times - # before giving up. Tre retry mechanism is basic but should be enough in most cases. - if _nb_retries <= 0: - logger.warning("Error while downloading from %s: %s\nMax retries exceeded.", url, str(e)) - raise - logger.warning("Error while downloading from %s: %s\nTrying to resume download...", url, str(e)) - time.sleep(1) - reset_sessions() # In case of SSLError it's best to reset the shared requests.Session objects - return http_get( - url=url, - temp_file=temp_file, - proxies=proxies, - resume_size=new_resume_size, - headers=initial_headers, - expected_size=expected_size, - _nb_retries=_nb_retries - 1, - _tqdm_bar=_tqdm_bar, - ) if expected_size is not None and expected_size != temp_file.tell(): raise EnvironmentError( @@ -822,7 +801,6 @@ def hf_hub_download( local_dir: Union[str, Path, None] = None, user_agent: Union[Dict, str, None] = None, force_download: bool = False, - proxies: Optional[Dict] = None, etag_timeout: float = constants.DEFAULT_ETAG_TIMEOUT, token: Union[bool, str, None] = None, local_files_only: bool = False, @@ -893,9 +871,6 @@ def hf_hub_download( force_download (`bool`, *optional*, defaults to `False`): Whether the file should be downloaded even if it already exists in the local cache. - proxies (`dict`, *optional*): - Dictionary mapping protocol to the URL of the proxy passed to - `requests.request`. etag_timeout (`float`, *optional*, defaults to `10`): When fetching ETag, how many seconds to wait for the server to send data before giving up which is passed to `requests.request`. @@ -919,7 +894,7 @@ def hf_hub_download( or because it is set to `private` and you do not have access. [`~utils.RevisionNotFoundError`] If the revision to download from cannot be found. - [`~utils.EntryNotFoundError`] + [`~utils.RemoteEntryNotFoundError`] If the file to download cannot be found. [`~utils.LocalEntryNotFoundError`] If network is disabled or unavailable and file is not found in cache. @@ -999,7 +974,6 @@ def hf_hub_download( endpoint=endpoint, etag_timeout=etag_timeout, headers=hf_headers, - proxies=proxies, token=token, # Additional options cache_dir=cache_dir, @@ -1019,7 +993,6 @@ def hf_hub_download( endpoint=endpoint, etag_timeout=etag_timeout, headers=hf_headers, - proxies=proxies, token=token, # Additional options local_files_only=local_files_only, @@ -1040,7 +1013,6 @@ def _hf_hub_download_to_cache_dir( endpoint: Optional[str], etag_timeout: float, headers: Dict[str, str], - proxies: Optional[Dict], token: Optional[Union[bool, str]], # Additional options local_files_only: bool, @@ -1076,7 +1048,6 @@ def _hf_hub_download_to_cache_dir( repo_type=repo_type, revision=revision, endpoint=endpoint, - proxies=proxies, etag_timeout=etag_timeout, headers=headers, token=token, @@ -1172,7 +1143,6 @@ def _hf_hub_download_to_cache_dir( incomplete_path=Path(blob_path + ".incomplete"), destination_path=Path(blob_path), url_to_download=url_to_download, - proxies=proxies, headers=headers, expected_size=expected_size, filename=filename, @@ -1199,7 +1169,6 @@ def _hf_hub_download_to_local_dir( endpoint: Optional[str], etag_timeout: float, headers: Dict[str, str], - proxies: Optional[Dict], token: Union[bool, str, None], # Additional options cache_dir: str, @@ -1235,7 +1204,6 @@ def _hf_hub_download_to_local_dir( repo_type=repo_type, revision=revision, endpoint=endpoint, - proxies=proxies, etag_timeout=etag_timeout, headers=headers, token=token, @@ -1301,7 +1269,6 @@ def _hf_hub_download_to_local_dir( incomplete_path=paths.incomplete_path(etag), destination_path=paths.file_path, url_to_download=url_to_download, - proxies=proxies, headers=headers, expected_size=expected_size, filename=filename, @@ -1411,7 +1378,6 @@ def try_to_load_from_cache( def get_hf_file_metadata( url: str, token: Union[bool, str, None] = None, - proxies: Optional[Dict] = None, timeout: Optional[float] = constants.DEFAULT_REQUEST_TIMEOUT, library_name: Optional[str] = None, library_version: Optional[str] = None, @@ -1430,9 +1396,6 @@ def get_hf_file_metadata( folder. - If `False` or `None`, no token is provided. - If a string, it's used as the authentication token. - proxies (`dict`, *optional*): - Dictionary mapping protocol to the URL of the proxy passed to - `requests.request`. timeout (`float`, *optional*, defaults to 10): How many seconds to wait for the server to send metadata before giving up. library_name (`str`, *optional*): @@ -1460,31 +1423,23 @@ def get_hf_file_metadata( hf_headers["Accept-Encoding"] = "identity" # prevent any compression => we want to know the real size of the file # Retrieve metadata - r = _request_wrapper( - method="HEAD", - url=url, - headers=hf_headers, - allow_redirects=False, - follow_relative_redirects=True, - proxies=proxies, - timeout=timeout, - ) - hf_raise_for_status(r) + response = _httpx_follow_relative_redirects(method="HEAD", url=url, headers=hf_headers, timeout=timeout) + hf_raise_for_status(response) # Return return HfFileMetadata( - commit_hash=r.headers.get(constants.HUGGINGFACE_HEADER_X_REPO_COMMIT), - # We favor a custom header indicating the etag of the linked resource, and - # we fallback to the regular etag header. - etag=_normalize_etag(r.headers.get(constants.HUGGINGFACE_HEADER_X_LINKED_ETAG) or r.headers.get("ETag")), + commit_hash=response.headers.get(constants.HUGGINGFACE_HEADER_X_REPO_COMMIT), + # We favor a custom header indicating the etag of the linked resource, and we fallback to the regular etag header. + etag=_normalize_etag( + response.headers.get(constants.HUGGINGFACE_HEADER_X_LINKED_ETAG) or response.headers.get("ETag") + ), # Either from response headers (if redirected) or defaults to request url - # Do not use directly `url`, as `_request_wrapper` might have followed relative - # redirects. - location=r.headers.get("Location") or r.request.url, # type: ignore + # Do not use directly `url` as we might have followed relative redirects. + location=response.headers.get("Location") or str(response.request.url), # type: ignore size=_int_or_none( - r.headers.get(constants.HUGGINGFACE_HEADER_X_LINKED_SIZE) or r.headers.get("Content-Length") + response.headers.get(constants.HUGGINGFACE_HEADER_X_LINKED_SIZE) or response.headers.get("Content-Length") ), - xet_file_data=parse_xet_file_data_from_response(r, endpoint=endpoint), # type: ignore + xet_file_data=parse_xet_file_data_from_response(response, endpoint=endpoint), # type: ignore ) @@ -1495,7 +1450,6 @@ def _get_metadata_or_catch_error( repo_type: str, revision: str, endpoint: Optional[str], - proxies: Optional[Dict], etag_timeout: Optional[float], headers: Dict[str, str], # mutated inplace! token: Union[bool, str, None], @@ -1544,9 +1498,9 @@ def _get_metadata_or_catch_error( try: try: metadata = get_hf_file_metadata( - url=url, proxies=proxies, timeout=etag_timeout, headers=headers, token=token, endpoint=endpoint + url=url, timeout=etag_timeout, headers=headers, token=token, endpoint=endpoint ) - except EntryNotFoundError as http_error: + except RemoteEntryNotFoundError as http_error: if storage_folder is not None and relative_filename is not None: # Cache the non-existence of the file commit_hash = http_error.response.headers.get(constants.HUGGINGFACE_HEADER_X_REPO_COMMIT) @@ -1597,21 +1551,17 @@ def _get_metadata_or_catch_error( if urlparse(url).netloc != urlparse(metadata.location).netloc: # Remove authorization header when downloading a LFS blob headers.pop("authorization", None) - except (requests.exceptions.SSLError, requests.exceptions.ProxyError): - # Actually raise for those subclasses of ConnectionError + except httpx.ProxyError: + # Actually raise on proxy error raise - except ( - requests.exceptions.ConnectionError, - requests.exceptions.Timeout, - OfflineModeIsEnabled, - ) as error: + except (httpx.ConnectError, httpx.TimeoutException, OfflineModeIsEnabled) as error: # Otherwise, our Internet connection is down. # etag is None head_error_call = error - except (RevisionNotFoundError, EntryNotFoundError): + except (RevisionNotFoundError, RemoteEntryNotFoundError): # The repo was found but the revision or entry doesn't exist on the Hub (never existed or got deleted) raise - except requests.HTTPError as error: + except HfHubHTTPError as error: # Multiple reasons for an http error: # - Repository is private and invalid/missing token sent # - Repository is gated and invalid/missing token sent @@ -1669,7 +1619,6 @@ def _download_to_tmp_and_move( incomplete_path: Path, destination_path: Path, url_to_download: str, - proxies: Optional[Dict], headers: Dict[str, str], expected_size: Optional[int], filename: str, @@ -1694,14 +1643,14 @@ def _download_to_tmp_and_move( # Do nothing if already exists (except if force_download=True) return - if incomplete_path.exists() and (force_download or (constants.HF_HUB_ENABLE_HF_TRANSFER and not proxies)): + if incomplete_path.exists() and (force_download or constants.HF_HUB_ENABLE_HF_TRANSFER): # By default, we will try to resume the download if possible. # However, if the user has set `force_download=True` or if `hf_transfer` is enabled, then we should # not resume the download => delete the incomplete file. message = f"Removing incomplete file '{incomplete_path}'" if force_download: message += " (force_download=True)" - elif constants.HF_HUB_ENABLE_HF_TRANSFER and not proxies: + elif constants.HF_HUB_ENABLE_HF_TRANSFER: message += " (hf_transfer=True)" logger.info(message) incomplete_path.unlink(missing_ok=True) @@ -1738,7 +1687,6 @@ def _download_to_tmp_and_move( http_get( url_to_download, f, - proxies=proxies, resume_size=resume_size, headers=headers, expected_size=expected_size, diff --git a/src/huggingface_hub/hf_api.py b/src/huggingface_hub/hf_api.py index e2827a6f19..a5aef6bbc5 100644 --- a/src/huggingface_hub/hf_api.py +++ b/src/huggingface_hub/hf_api.py @@ -48,8 +48,7 @@ ) from urllib.parse import quote, unquote -import requests -from requests.exceptions import HTTPError +import httpx from tqdm.auto import tqdm as base_tqdm from tqdm.contrib.concurrent import thread_map @@ -101,9 +100,9 @@ ) from .errors import ( BadRequestError, - EntryNotFoundError, GatedRepoError, HfHubHTTPError, + RemoteEntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError, ) @@ -1780,7 +1779,7 @@ def whoami(self, token: Union[bool, str, None] = None) -> Dict: ) try: hf_raise_for_status(r) - except HTTPError as e: + except HfHubHTTPError as e: if e.response.status_code == 401: error_message = "Invalid user token." # Check which token is the effective one and generate the error message accordingly @@ -1793,7 +1792,7 @@ def whoami(self, token: Union[bool, str, None] = None) -> Dict: ) elif effective_token == _get_token_from_file(): error_message += " The token stored is invalid. Please run `hf auth login` to update it." - raise HTTPError(error_message, request=e.request, response=e.response) from e + raise HfHubHTTPError(error_message, response=e.response) from e raise return r.json() @@ -1834,7 +1833,7 @@ def get_token_permission( """ try: return self.whoami(token=token)["auth"]["accessToken"]["role"] - except (LocalTokenNotFoundError, HTTPError, KeyError): + except (LocalTokenNotFoundError, HfHubHTTPError, KeyError): return None def get_model_tags(self) -> Dict: @@ -3016,7 +3015,7 @@ def file_exists( return True except GatedRepoError: # raise specifically on gated repo raise - except (RepositoryNotFoundError, EntryNotFoundError, RevisionNotFoundError): + except (RepositoryNotFoundError, RemoteEntryNotFoundError, RevisionNotFoundError): return False @validate_hf_hub_args @@ -3106,7 +3105,7 @@ def list_repo_tree( does not exist. [`~utils.RevisionNotFoundError`]: If revision is not found (error 404) on the repo. - [`~utils.EntryNotFoundError`]: + [`~utils.RemoteEntryNotFoundError`]: If the tree (folder) does not exist (error 404) on the repo. Examples: @@ -3764,7 +3763,7 @@ def create_repo( try: hf_raise_for_status(r) - except HTTPError as err: + except HfHubHTTPError as err: if exist_ok and err.response.status_code == 409: # Repo already exists and `exist_ok=True` pass @@ -3826,7 +3825,7 @@ def delete_repo( json["type"] = repo_type headers = self._build_hf_headers(token=token) - r = get_session().delete(path, headers=headers, json=json) + r = get_session().request("DELETE", path, headers=headers, json=json) try: hf_raise_for_status(r) except RepositoryNotFoundError: @@ -4338,12 +4337,12 @@ def _payload_as_ndjson() -> Iterable[bytes]: params = {"create_pr": "1"} if create_pr else None try: - commit_resp = get_session().post(url=commit_url, headers=headers, data=data, params=params) + commit_resp = get_session().post(url=commit_url, headers=headers, content=data, params=params) hf_raise_for_status(commit_resp, endpoint_name="commit") except RepositoryNotFoundError as e: e.append_to_message(_CREATE_COMMIT_NO_REPO_ERROR_MESSAGE) raise - except EntryNotFoundError as e: + except RemoteEntryNotFoundError as e: if nb_deletions > 0 and "A file with this name doesn't exist" in str(e): e.append_to_message( "\nMake sure to differentiate file and folder paths in delete" @@ -4653,7 +4652,7 @@ def upload_file( Raises the following errors: - - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError) + - [`HfHubHTTPError`] if the HuggingFace API returned an error - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) if some parameter value is invalid @@ -4889,7 +4888,7 @@ def upload_folder( Raises the following errors: - - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError) + - [`HfHubHTTPError`] if the HuggingFace API returned an error - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) if some parameter value is invalid @@ -5077,7 +5076,7 @@ def delete_file( Raises the following errors: - - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError) + - [`HfHubHTTPError`] if the HuggingFace API returned an error - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) if some parameter value is invalid @@ -5086,7 +5085,7 @@ def delete_file( or because it is set to `private` and you do not have access. - [`~utils.RevisionNotFoundError`] If the revision to download from cannot be found. - - [`~utils.EntryNotFoundError`] + - [`~utils.RemoteEntryNotFoundError`] If the file to download cannot be found. @@ -5383,7 +5382,6 @@ def get_hf_file_metadata( *, url: str, token: Union[bool, str, None] = None, - proxies: Optional[Dict] = None, timeout: Optional[float] = constants.DEFAULT_REQUEST_TIMEOUT, ) -> HfFileMetadata: """Fetch metadata of a file versioned on the Hub for a given url. @@ -5396,8 +5394,6 @@ def get_hf_file_metadata( token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. - proxies (`dict`, *optional*): - Dictionary mapping protocol to the URL of the proxy passed to `requests.request`. timeout (`float`, *optional*, defaults to 10): How many seconds to wait for the server to send metadata before giving up. @@ -5411,7 +5407,6 @@ def get_hf_file_metadata( return get_hf_file_metadata( url=url, token=token, - proxies=proxies, timeout=timeout, library_name=self.library_name, library_version=self.library_version, @@ -5431,7 +5426,6 @@ def hf_hub_download( cache_dir: Union[str, Path, None] = None, local_dir: Union[str, Path, None] = None, force_download: bool = False, - proxies: Optional[Dict] = None, etag_timeout: float = constants.DEFAULT_ETAG_TIMEOUT, token: Union[bool, str, None] = None, local_files_only: bool = False, @@ -5495,12 +5489,9 @@ def hf_hub_download( force_download (`bool`, *optional*, defaults to `False`): Whether the file should be downloaded even if it already exists in the local cache. - proxies (`dict`, *optional*): - Dictionary mapping protocol to the URL of the proxy passed to - `requests.request`. etag_timeout (`float`, *optional*, defaults to `10`): When fetching ETag, how many seconds to wait for the server to send - data before giving up which is passed to `requests.request`. + data before giving up which is passed to `httpx.request`. token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see @@ -5519,7 +5510,7 @@ def hf_hub_download( or because it is set to `private` and you do not have access. [`~utils.RevisionNotFoundError`] If the revision to download from cannot be found. - [`~utils.EntryNotFoundError`] + [`~utils.RemoteEntryNotFoundError`] If the file to download cannot be found. [`~utils.LocalEntryNotFoundError`] If network is disabled or unavailable and file is not found in cache. @@ -5551,7 +5542,6 @@ def hf_hub_download( user_agent=self.user_agent, force_download=force_download, force_filename=force_filename, - proxies=proxies, etag_timeout=etag_timeout, resume_download=resume_download, token=token, @@ -5568,7 +5558,6 @@ def snapshot_download( revision: Optional[str] = None, cache_dir: Union[str, Path, None] = None, local_dir: Union[str, Path, None] = None, - proxies: Optional[Dict] = None, etag_timeout: float = constants.DEFAULT_ETAG_TIMEOUT, force_download: bool = False, token: Union[bool, str, None] = None, @@ -5609,12 +5598,9 @@ def snapshot_download( Path to the folder where cached files are stored. local_dir (`str` or `Path`, *optional*): If provided, the downloaded files will be placed under this directory. - proxies (`dict`, *optional*): - Dictionary mapping protocol to the URL of the proxy passed to - `requests.request`. etag_timeout (`float`, *optional*, defaults to `10`): When fetching ETag, how many seconds to wait for the server to send - data before giving up which is passed to `requests.request`. + data before giving up which is passed to `httpx.request`. force_download (`bool`, *optional*, defaults to `False`): Whether the file should be downloaded even if it already exists in the local cache. token (Union[bool, str, None], optional): @@ -5672,7 +5658,6 @@ def snapshot_download( library_name=self.library_name, library_version=self.library_version, user_agent=self.user_agent, - proxies=proxies, etag_timeout=etag_timeout, resume_download=resume_download, force_download=force_download, @@ -6361,7 +6346,7 @@ def get_discussion_details( Raises the following errors: - - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError) + - [`HfHubHTTPError`] if the HuggingFace API returned an error - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) if some parameter value is invalid @@ -6454,7 +6439,7 @@ def create_discussion( Raises the following errors: - - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError) + - [`HfHubHTTPError`] if the HuggingFace API returned an error - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) if some parameter value is invalid @@ -6542,7 +6527,7 @@ def create_pull_request( Raises the following errors: - - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError) + - [`HfHubHTTPError`] if the HuggingFace API returned an error - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) if some parameter value is invalid @@ -6569,7 +6554,7 @@ def _post_discussion_changes( body: Optional[dict] = None, token: Union[bool, str, None] = None, repo_type: Optional[str] = None, - ) -> requests.Response: + ) -> httpx.Response: """Internal utility to POST changes to a Discussion or Pull Request""" if not isinstance(discussion_num, int) or discussion_num <= 0: raise ValueError("Invalid discussion_num, must be a positive integer") @@ -6582,7 +6567,7 @@ def _post_discussion_changes( path = f"{self.endpoint}/api/{repo_id}/discussions/{discussion_num}/{resource}" headers = self._build_hf_headers(token=token) - resp = requests.post(path, headers=headers, json=body) + resp = get_session().post(path, headers=headers, json=body) hf_raise_for_status(resp) return resp @@ -6645,7 +6630,7 @@ def comment_discussion( Raises the following errors: - - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError) + - [`HfHubHTTPError`] if the HuggingFace API returned an error - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) if some parameter value is invalid @@ -6715,7 +6700,7 @@ def rename_discussion( Raises the following errors: - - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError) + - [`HfHubHTTPError`] if the HuggingFace API returned an error - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) if some parameter value is invalid @@ -6788,7 +6773,7 @@ def change_discussion_status( Raises the following errors: - - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError) + - [`HfHubHTTPError`] if the HuggingFace API returned an error - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) if some parameter value is invalid @@ -6850,7 +6835,7 @@ def merge_pull_request( Raises the following errors: - - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError) + - [`HfHubHTTPError`] if the HuggingFace API returned an error - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) if some parameter value is invalid @@ -6909,7 +6894,7 @@ def edit_discussion_comment( Raises the following errors: - - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError) + - [`HfHubHTTPError`] if the HuggingFace API returned an error - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) if some parameter value is invalid @@ -6970,7 +6955,7 @@ def hide_discussion_comment( Raises the following errors: - - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError) + - [`HfHubHTTPError`] if the HuggingFace API returned an error - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) if some parameter value is invalid @@ -7051,7 +7036,8 @@ def delete_space_secret(self, repo_id: str, key: str, *, token: Union[bool, str, https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. """ - r = get_session().delete( + r = get_session().request( + "DELETE", f"{self.endpoint}/api/spaces/{repo_id}/secrets", headers=self._build_hf_headers(token=token), json={"key": key}, @@ -7142,7 +7128,8 @@ def delete_space_variable( https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass `False`. """ - r = get_session().delete( + r = get_session().request( + "DELETE", f"{self.endpoint}/api/spaces/{repo_id}/variables", headers=self._build_hf_headers(token=token), json={"key": key}, @@ -7419,7 +7406,7 @@ def duplicate_space( [`~utils.RepositoryNotFoundError`]: If one of `from_id` or `to_id` cannot be found. This may be because it doesn't exist, or because it is set to `private` and you do not have access. - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError): + [`HfHubHTTPError`]: If the HuggingFace API returned an error Example: @@ -7469,7 +7456,7 @@ def duplicate_space( try: hf_raise_for_status(r) - except HTTPError as err: + except HfHubHTTPError as err: if exist_ok and err.response.status_code == 409: # Repo already exists and `exist_ok=True` pass @@ -8426,7 +8413,7 @@ def create_collection( ) try: hf_raise_for_status(r) - except HTTPError as err: + except HfHubHTTPError as err: if exists_ok and err.response.status_code == 409: # Collection already exists and `exists_ok=True` slug = r.json()["slug"] @@ -8537,7 +8524,7 @@ def delete_collection( ) try: hf_raise_for_status(r) - except HTTPError as err: + except HfHubHTTPError as err: if missing_ok and err.response.status_code == 404: # Collection doesn't exists and `missing_ok=True` return @@ -8577,12 +8564,12 @@ def add_collection_item( Returns: [`Collection`] Raises: - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError): + [`HfHubHTTPError`]: HTTP 403 if you only have read-only access to the repo. This can be the case if you don't have `write` or `admin` role in the organization the repo belongs to or if you passed a `read` token. - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError): + [`HfHubHTTPError`]: HTTP 404 if the item you try to add to the collection does not exist on the Hub. - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError): + [`HfHubHTTPError`]: HTTP 409 if the item you try to add to the collection is already in the collection (and exists_ok=False) Example: @@ -8618,7 +8605,7 @@ def add_collection_item( ) try: hf_raise_for_status(r) - except HTTPError as err: + except HfHubHTTPError as err: if exists_ok and err.response.status_code == 409: # Item already exists and `exists_ok=True` return self.get_collection(collection_slug, token=token) @@ -8724,7 +8711,7 @@ def delete_collection_item( ) try: hf_raise_for_status(r) - except HTTPError as err: + except HfHubHTTPError as err: if missing_ok and err.response.status_code == 404: # Item already deleted and `missing_ok=True` return @@ -8766,9 +8753,9 @@ def list_pending_access_requests( be populated with user's answers. Raises: - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError): + [`HfHubHTTPError`]: HTTP 400 if the repo is not gated. - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError): + [`HfHubHTTPError`]: HTTP 403 if you only have read-only access to the repo. This can be the case if you don't have `write` or `admin` role in the organization the repo belongs to or if you passed a `read` token. @@ -8832,9 +8819,9 @@ def list_accepted_access_requests( be populated with user's answers. Raises: - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError): + [`HfHubHTTPError`]: HTTP 400 if the repo is not gated. - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError): + [`HfHubHTTPError`]: HTTP 403 if you only have read-only access to the repo. This can be the case if you don't have `write` or `admin` role in the organization the repo belongs to or if you passed a `read` token. @@ -8894,9 +8881,9 @@ def list_rejected_access_requests( be populated with user's answers. Raises: - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError): + [`HfHubHTTPError`]: HTTP 400 if the repo is not gated. - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError): + [`HfHubHTTPError`]: HTTP 403 if you only have read-only access to the repo. This can be the case if you don't have `write` or `admin` role in the organization the repo belongs to or if you passed a `read` token. @@ -8978,16 +8965,16 @@ def cancel_access_request( To disable authentication, pass `False`. Raises: - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError): + [`HfHubHTTPError`]: HTTP 400 if the repo is not gated. - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError): + [`HfHubHTTPError`]: HTTP 403 if you only have read-only access to the repo. This can be the case if you don't have `write` or `admin` role in the organization the repo belongs to or if you passed a `read` token. - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError): + [`HfHubHTTPError`]: HTTP 404 if the user does not exist on the Hub. - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError): + [`HfHubHTTPError`]: HTTP 404 if the user access request cannot be found. - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError): + [`HfHubHTTPError`]: HTTP 404 if the user access request is already in the pending list. """ self._handle_access_request(repo_id, user, "pending", repo_type=repo_type, token=token) @@ -9020,16 +9007,16 @@ def accept_access_request( To disable authentication, pass `False`. Raises: - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError): + [`HfHubHTTPError`]: HTTP 400 if the repo is not gated. - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError): + [`HfHubHTTPError`]: HTTP 403 if you only have read-only access to the repo. This can be the case if you don't have `write` or `admin` role in the organization the repo belongs to or if you passed a `read` token. - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError): + [`HfHubHTTPError`]: HTTP 404 if the user does not exist on the Hub. - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError): + [`HfHubHTTPError`]: HTTP 404 if the user access request cannot be found. - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError): + [`HfHubHTTPError`]: HTTP 404 if the user access request is already in the accepted list. """ self._handle_access_request(repo_id, user, "accepted", repo_type=repo_type, token=token) @@ -9070,16 +9057,16 @@ def reject_access_request( To disable authentication, pass `False`. Raises: - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError): + [`HfHubHTTPError`]: HTTP 400 if the repo is not gated. - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError): + [`HfHubHTTPError`]: HTTP 403 if you only have read-only access to the repo. This can be the case if you don't have `write` or `admin` role in the organization the repo belongs to or if you passed a `read` token. - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError): + [`HfHubHTTPError`]: HTTP 404 if the user does not exist on the Hub. - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError): + [`HfHubHTTPError`]: HTTP 404 if the user access request cannot be found. - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError): + [`HfHubHTTPError`]: HTTP 404 if the user access request is already in the rejected list. """ self._handle_access_request( @@ -9143,14 +9130,14 @@ def grant_access( To disable authentication, pass `False`. Raises: - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError): + [`HfHubHTTPError`]: HTTP 400 if the repo is not gated. - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError): + [`HfHubHTTPError`]: HTTP 400 if the user already has access to the repo. - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError): + [`HfHubHTTPError`]: HTTP 403 if you only have read-only access to the repo. This can be the case if you don't have `write` or `admin` role in the organization the repo belongs to or if you passed a `read` token. - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError): + [`HfHubHTTPError`]: HTTP 404 if the user does not exist on the Hub. """ if repo_type not in constants.REPO_TYPES: @@ -9741,7 +9728,7 @@ def get_user_overview(self, username: str, token: Union[bool, str, None] = None) `User`: A [`User`] object with the user's overview. Raises: - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError): + [`HfHubHTTPError`]: HTTP 404 If the user does not exist on the Hub. """ r = get_session().get( @@ -9767,7 +9754,7 @@ def list_organization_members(self, organization: str, token: Union[bool, str, N `Iterable[User]`: A list of [`User`] objects with the members of the organization. Raises: - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError): + [`HfHubHTTPError`]: HTTP 404 If the organization does not exist on the Hub. """ @@ -9795,7 +9782,7 @@ def list_user_followers(self, username: str, token: Union[bool, str, None] = Non `Iterable[User]`: A list of [`User`] objects with the followers of the user. Raises: - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError): + [`HfHubHTTPError`]: HTTP 404 If the user does not exist on the Hub. """ @@ -9823,7 +9810,7 @@ def list_user_following(self, username: str, token: Union[bool, str, None] = Non `Iterable[User]`: A list of [`User`] objects with the users followed by the user. Raises: - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError): + [`HfHubHTTPError`]: HTTP 404 If the user does not exist on the Hub. """ @@ -9892,7 +9879,7 @@ def paper_info(self, id: str) -> PaperInfo: `PaperInfo`: A `PaperInfo` object. Raises: - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError): + [`HfHubHTTPError`]: HTTP 404 If the paper does not exist on the Hub. """ path = f"{self.endpoint}/api/papers/{id}" @@ -10100,29 +10087,28 @@ def fetch_job_logs( time.sleep(sleep_time) sleep_time = min(max_wait_time, max(min_wait_time, sleep_time * 2)) try: - resp = get_session().get( + with get_session().stream( + "GET", f"https://huggingface.co/api/jobs/{namespace}/{job_id}/logs", headers=self._build_hf_headers(token=token), - stream=True, timeout=120, - ) - log = None - for line in resp.iter_lines(chunk_size=1): - line = line.decode("utf-8") - if line and line.startswith("data: {"): - data = json.loads(line[len("data: ") :]) - # timestamp = data["timestamp"] - if not data["data"].startswith("===== Job started"): - logging_started = True - log = data["data"] - yield log - logging_finished = logging_started - except requests.exceptions.ChunkedEncodingError: + ) as response: + log = None + for line in response.iter_lines(): + if line and line.startswith("data: {"): + data = json.loads(line[len("data: ") :]) + # timestamp = data["timestamp"] + if not data["data"].startswith("===== Job started"): + logging_started = True + log = data["data"] + yield log + logging_finished = logging_started + except httpx.DecodingError: # Response ended prematurely break except KeyboardInterrupt: break - except requests.exceptions.ConnectionError as err: + except httpx.NetworkError as err: is_timeout = err.__context__ and isinstance(getattr(err.__context__, "__cause__", None), TimeoutError) if logging_started or not is_timeout: raise diff --git a/src/huggingface_hub/hf_file_system.py b/src/huggingface_hub/hf_file_system.py index b8d1d5841c..e82365e3ce 100644 --- a/src/huggingface_hub/hf_file_system.py +++ b/src/huggingface_hub/hf_file_system.py @@ -2,6 +2,7 @@ import re import tempfile from collections import deque +from contextlib import ExitStack from dataclasses import dataclass, field from datetime import datetime from itertools import chain @@ -10,16 +11,16 @@ from urllib.parse import quote, unquote import fsspec +import httpx from fsspec.callbacks import _DEFAULT_CALLBACK, NoOpCallback, TqdmCallback from fsspec.utils import isfilelike -from requests import Response from . import constants from ._commit_api import CommitOperationCopy, CommitOperationDelete -from .errors import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError +from .errors import EntryNotFoundError, HfHubHTTPError, RepositoryNotFoundError, RevisionNotFoundError from .file_download import hf_hub_url, http_get from .hf_api import HfApi, LastCommitInfo, RepoFile -from .utils import HFValidationError, hf_raise_for_status, http_backoff +from .utils import HFValidationError, hf_raise_for_status, http_backoff, http_stream_backoff # Regex used to match special revisions with "/" in them (see #1710) @@ -1039,8 +1040,9 @@ def __init__( super().__init__( fs, self.resolved_path.unresolve(), mode=mode, block_size=block_size, cache_type=cache_type, **kwargs ) - self.response: Optional[Response] = None + self.response: Optional[httpx.Response] = None self.fs: HfFileSystem + self._exit_stack = ExitStack() def seek(self, loc: int, whence: int = 0): if loc == 0 and whence == 1: @@ -1050,55 +1052,32 @@ def seek(self, loc: int, whence: int = 0): raise ValueError("Cannot seek streaming HF file") def read(self, length: int = -1): - read_args = (length,) if length >= 0 else () + """Read the remote file. + + If the file is already open, we reuse the connection. + Otherwise, open a new connection and read from it. + + If reading the stream fails, we retry with a new connection. + """ if self.response is None: - url = hf_hub_url( - repo_id=self.resolved_path.repo_id, - revision=self.resolved_path.revision, - filename=self.resolved_path.path_in_repo, - repo_type=self.resolved_path.repo_type, - endpoint=self.fs.endpoint, - ) - self.response = http_backoff( - "GET", - url, - headers=self.fs._api._build_hf_headers(), - retry_on_status_codes=(500, 502, 503, 504), - stream=True, - timeout=constants.HF_HUB_DOWNLOAD_TIMEOUT, - ) - hf_raise_for_status(self.response) - try: - self.response.raw.decode_content = True - out = self.response.raw.read(*read_args) - except Exception: - self.response.close() + self._open_connection() - # Retry by recreating the connection - url = hf_hub_url( - repo_id=self.resolved_path.repo_id, - revision=self.resolved_path.revision, - filename=self.resolved_path.path_in_repo, - repo_type=self.resolved_path.repo_type, - endpoint=self.fs.endpoint, - ) - self.response = http_backoff( - "GET", - url, - headers={"Range": "bytes=%d-" % self.loc, **self.fs._api._build_hf_headers()}, - retry_on_status_codes=(500, 502, 503, 504), - stream=True, - timeout=constants.HF_HUB_DOWNLOAD_TIMEOUT, - ) - hf_raise_for_status(self.response) + retried_once = False + while True: try: - self.response.raw.decode_content = True - out = self.response.raw.read(*read_args) + if self.response is None: + return b"" # Already read the entire file + out = _partial_read(self.response, length) + self.loc += len(out) + return out except Exception: - self.response.close() - raise - self.loc += len(out) - return out + if self.response is not None: + self.response.close() + if retried_once: # Already retried once, give up + raise + # First failure, retry with range header + self._open_connection() + retried_once = True def url(self) -> str: return self.fs.url(self.path) @@ -1107,11 +1086,43 @@ def __del__(self): if not hasattr(self, "resolved_path"): # Means that the constructor failed. Nothing to do. return + self._exit_stack.close() return super().__del__() def __reduce__(self): return reopen, (self.fs, self.path, self.mode, self.blocksize, self.cache.name) + def _open_connection(self): + """Open a connection to the remote file.""" + url = hf_hub_url( + repo_id=self.resolved_path.repo_id, + revision=self.resolved_path.revision, + filename=self.resolved_path.path_in_repo, + repo_type=self.resolved_path.repo_type, + endpoint=self.fs.endpoint, + ) + headers = self.fs._api._build_hf_headers() + if self.loc > 0: + headers["Range"] = f"bytes={self.loc}-" + self.response = self._exit_stack.enter_context( + http_stream_backoff( + "GET", + url, + headers=headers, + retry_on_status_codes=(500, 502, 503, 504), + timeout=constants.HF_HUB_DOWNLOAD_TIMEOUT, + ) + ) + + try: + hf_raise_for_status(self.response) + except HfHubHTTPError as e: + if e.response.status_code == 416: + # Range not satisfiable => means that we have already read the entire file + self.response = None + return + raise + def safe_revision(revision: str) -> str: return revision if SPECIAL_REFS_REVISION_REGEX.match(revision) else safe_quote(revision) @@ -1134,3 +1145,26 @@ def _raise_file_not_found(path: str, err: Optional[Exception]) -> NoReturn: def reopen(fs: HfFileSystem, path: str, mode: str, block_size: int, cache_type: str): return fs.open(path, mode=mode, block_size=block_size, cache_type=cache_type) + + +def _partial_read(response: httpx.Response, length: int = -1) -> bytes: + """ + Read up to `length` bytes from a streamed response. + If length == -1, read until EOF. + """ + buf = bytearray() + if length < -1: + raise ValueError("length must be -1 or >= 0") + if length == 0: + return b"" + if length == -1: + for chunk in response.iter_bytes(): + buf.extend(chunk) + return bytes(buf) + + for chunk in response.iter_bytes(chunk_size=length): + buf.extend(chunk) + if len(buf) >= length: + return bytes(buf[:length]) + + return bytes(buf) # may be < length if response ended diff --git a/src/huggingface_hub/hub_mixin.py b/src/huggingface_hub/hub_mixin.py index 9fa702ceda..d1ddee213f 100644 --- a/src/huggingface_hub/hub_mixin.py +++ b/src/huggingface_hub/hub_mixin.py @@ -150,7 +150,6 @@ class ModelHubMixin: ... *, ... force_download: bool = False, ... resume_download: Optional[bool] = None, - ... proxies: Optional[Dict] = None, ... token: Optional[Union[str, bool]] = None, ... cache_dir: Optional[Union[str, Path]] = None, ... local_files_only: bool = False, @@ -467,7 +466,6 @@ def from_pretrained( *, force_download: bool = False, resume_download: Optional[bool] = None, - proxies: Optional[Dict] = None, token: Optional[Union[str, bool]] = None, cache_dir: Optional[Union[str, Path]] = None, local_files_only: bool = False, @@ -488,9 +486,6 @@ def from_pretrained( force_download (`bool`, *optional*, defaults to `False`): Whether to force (re-)downloading the model weights and configuration files from the Hub, overriding the existing cache. - proxies (`Dict[str, str]`, *optional*): - A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', - 'http://hostname': 'foo.bar:4012'}`. The proxies are used on every request. token (`str` or `bool`, *optional*): The token to use as HTTP bearer authorization for remote files. By default, it will use the token cached when running `hf auth login`. @@ -516,7 +511,6 @@ def from_pretrained( revision=revision, cache_dir=cache_dir, force_download=force_download, - proxies=proxies, resume_download=resume_download, token=token, local_files_only=local_files_only, @@ -570,7 +564,6 @@ def from_pretrained( revision=revision, cache_dir=cache_dir, force_download=force_download, - proxies=proxies, resume_download=resume_download, local_files_only=local_files_only, token=token, @@ -592,7 +585,6 @@ def _from_pretrained( revision: Optional[str], cache_dir: Optional[Union[str, Path]], force_download: bool, - proxies: Optional[Dict], resume_download: Optional[bool], local_files_only: bool, token: Optional[Union[str, bool]], @@ -616,9 +608,6 @@ def _from_pretrained( force_download (`bool`, *optional*, defaults to `False`): Whether to force (re-)downloading the model weights and configuration files from the Hub, overriding the existing cache. - proxies (`Dict[str, str]`, *optional*): - A dictionary of proxy servers to use by protocol or endpoint (e.g., `{'http': 'foo.bar:3128', - 'http://hostname': 'foo.bar:4012'}`). token (`str` or `bool`, *optional*): The token to use as HTTP bearer authorization for remote files. By default, it will use the token cached when running `hf auth login`. @@ -779,7 +768,6 @@ def _from_pretrained( revision: Optional[str], cache_dir: Optional[Union[str, Path]], force_download: bool, - proxies: Optional[Dict], resume_download: Optional[bool], local_files_only: bool, token: Union[str, bool, None], @@ -801,7 +789,6 @@ def _from_pretrained( revision=revision, cache_dir=cache_dir, force_download=force_download, - proxies=proxies, resume_download=resume_download, token=token, local_files_only=local_files_only, @@ -814,7 +801,6 @@ def _from_pretrained( revision=revision, cache_dir=cache_dir, force_download=force_download, - proxies=proxies, resume_download=resume_download, token=token, local_files_only=local_files_only, diff --git a/src/huggingface_hub/inference/_client.py b/src/huggingface_hub/inference/_client.py index 5e39dee55c..22a979509c 100644 --- a/src/huggingface_hub/inference/_client.py +++ b/src/huggingface_hub/inference/_client.py @@ -34,14 +34,14 @@ # - Only the main parameters are publicly exposed. Power users can always read the docs for more options. import base64 import logging +import os import re import warnings +from contextlib import ExitStack from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Literal, Optional, Union, overload -from requests import HTTPError - from huggingface_hub import constants -from huggingface_hub.errors import BadRequestError, InferenceTimeoutError +from huggingface_hub.errors import BadRequestError, HfHubHTTPError, InferenceTimeoutError from huggingface_hub.inference._common import ( TASKS_EXPECTING_IMAGES, ContentT, @@ -101,7 +101,12 @@ ZeroShotImageClassificationOutputElement, ) from huggingface_hub.inference._providers import PROVIDER_OR_POLICY_T, get_provider_helper -from huggingface_hub.utils import build_hf_headers, get_session, hf_raise_for_status +from huggingface_hub.utils import ( + build_hf_headers, + get_session, + hf_raise_for_status, + validate_hf_hub_args, +) from huggingface_hub.utils._auth import get_token @@ -147,8 +152,6 @@ class InferenceClient: Requests can only be billed to an organization the user is a member of, and which has subscribed to Enterprise Hub. cookies (`Dict[str, str]`, `optional`): Additional cookies to send to the server. - proxies (`Any`, `optional`): - Proxies to use for the request. base_url (`str`, `optional`): Base URL to run inference. This is a duplicated argument from `model` to make [`InferenceClient`] follow the same pattern as `openai.OpenAI` client. Cannot be used if `model` is set. Defaults to None. @@ -157,6 +160,7 @@ class InferenceClient: follow the same pattern as `openai.OpenAI` client. Cannot be used if `token` is set. Defaults to None. """ + @validate_hf_hub_args def __init__( self, model: Optional[str] = None, @@ -166,7 +170,6 @@ def __init__( timeout: Optional[float] = None, headers: Optional[Dict[str, str]] = None, cookies: Optional[Dict[str, str]] = None, - proxies: Optional[Any] = None, bill_to: Optional[str] = None, # OpenAI compatibility base_url: Optional[str] = None, @@ -228,11 +231,21 @@ def __init__( self.cookies = cookies self.timeout = timeout - self.proxies = proxies + + self.exit_stack = ExitStack() def __repr__(self): return f"" + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.exit_stack.close() + + def close(self): + self.exit_stack.close() + @overload def _inner_post( # type: ignore[misc] self, request_parameters: RequestParameters, *, stream: Literal[False] = ... @@ -241,44 +254,46 @@ def _inner_post( # type: ignore[misc] @overload def _inner_post( # type: ignore[misc] self, request_parameters: RequestParameters, *, stream: Literal[True] = ... - ) -> Iterable[bytes]: ... + ) -> Iterable[str]: ... @overload def _inner_post( self, request_parameters: RequestParameters, *, stream: bool = False - ) -> Union[bytes, Iterable[bytes]]: ... + ) -> Union[bytes, Iterable[str]]: ... def _inner_post( self, request_parameters: RequestParameters, *, stream: bool = False - ) -> Union[bytes, Iterable[bytes]]: + ) -> Union[bytes, Iterable[str]]: """Make a request to the inference server.""" # TODO: this should be handled in provider helpers directly if request_parameters.task in TASKS_EXPECTING_IMAGES and "Accept" not in request_parameters.headers: request_parameters.headers["Accept"] = "image/png" try: - response = get_session().post( - request_parameters.url, - json=request_parameters.json, - data=request_parameters.data, - headers=request_parameters.headers, - cookies=self.cookies, - timeout=self.timeout, - stream=stream, - proxies=self.proxies, + response = self.exit_stack.enter_context( + get_session().stream( + "POST", + request_parameters.url, + json=request_parameters.json, + content=request_parameters.data, + headers=request_parameters.headers, + cookies=self.cookies, + timeout=self.timeout, + ) ) + hf_raise_for_status(response) + if stream: + return response.iter_lines() + else: + return response.read() except TimeoutError as error: # Convert any `TimeoutError` to a `InferenceTimeoutError` raise InferenceTimeoutError(f"Inference call timed out: {request_parameters.url}") from error # type: ignore - - try: - hf_raise_for_status(response) - return response.iter_lines() if stream else response.content - except HTTPError as error: + except HfHubHTTPError as error: if error.response.status_code == 422 and request_parameters.task != "unknown": msg = str(error.args[0]) if len(error.response.text) > 0: - msg += f"\n{error.response.text}\n" + msg += f"{os.linesep}{error.response.text}{os.linesep}" error.args = (msg,) + error.args[1:] raise @@ -312,7 +327,7 @@ def audio_classification( Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. - `HTTPError`: + [`HfHubHTTPError`]: If the request fails with an HTTP error status code other than HTTP 503. Example: @@ -363,7 +378,7 @@ def audio_to_audio( Raises: `InferenceTimeoutError`: If the model is unavailable or the request times out. - `HTTPError`: + [`HfHubHTTPError`]: If the request fails with an HTTP error status code other than HTTP 503. Example: @@ -416,7 +431,7 @@ def automatic_speech_recognition( Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. - `HTTPError`: + [`HfHubHTTPError`]: If the request fails with an HTTP error status code other than HTTP 503. Example: @@ -622,7 +637,7 @@ def chat_completion( Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. - `HTTPError`: + [`HfHubHTTPError`]: If the request fails with an HTTP error status code other than HTTP 503. Example: @@ -976,7 +991,7 @@ def document_question_answering( Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. - `HTTPError`: + [`HfHubHTTPError`]: If the request fails with an HTTP error status code other than HTTP 503. @@ -1051,7 +1066,7 @@ def feature_extraction( Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. - `HTTPError`: + [`HfHubHTTPError`]: If the request fails with an HTTP error status code other than HTTP 503. Example: @@ -1113,7 +1128,7 @@ def fill_mask( Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. - `HTTPError`: + [`HfHubHTTPError`]: If the request fails with an HTTP error status code other than HTTP 503. Example: @@ -1166,7 +1181,7 @@ def image_classification( Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. - `HTTPError`: + [`HfHubHTTPError`]: If the request fails with an HTTP error status code other than HTTP 503. Example: @@ -1228,7 +1243,7 @@ def image_segmentation( Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. - `HTTPError`: + [`HfHubHTTPError`]: If the request fails with an HTTP error status code other than HTTP 503. Example: @@ -1305,7 +1320,7 @@ def image_to_image( Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. - `HTTPError`: + [`HfHubHTTPError`]: If the request fails with an HTTP error status code other than HTTP 503. Example: @@ -1435,7 +1450,7 @@ def image_to_text(self, image: ContentT, *, model: Optional[str] = None) -> Imag Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. - `HTTPError`: + [`HfHubHTTPError`]: If the request fails with an HTTP error status code other than HTTP 503. Example: @@ -1487,7 +1502,7 @@ def object_detection( Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. - `HTTPError`: + [`HfHubHTTPError`]: If the request fails with an HTTP error status code other than HTTP 503. `ValueError`: If the request output is not a List. @@ -1563,7 +1578,7 @@ def question_answering( Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. - `HTTPError`: + [`HfHubHTTPError`]: If the request fails with an HTTP error status code other than HTTP 503. Example: @@ -1618,7 +1633,7 @@ def sentence_similarity( Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. - `HTTPError`: + [`HfHubHTTPError`]: If the request fails with an HTTP error status code other than HTTP 503. Example: @@ -1679,7 +1694,7 @@ def summarization( Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. - `HTTPError`: + [`HfHubHTTPError`]: If the request fails with an HTTP error status code other than HTTP 503. Example: @@ -1744,7 +1759,7 @@ def table_question_answering( Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. - `HTTPError`: + [`HfHubHTTPError`]: If the request fails with an HTTP error status code other than HTTP 503. Example: @@ -1787,7 +1802,7 @@ def tabular_classification(self, table: Dict[str, Any], *, model: Optional[str] Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. - `HTTPError`: + [`HfHubHTTPError`]: If the request fails with an HTTP error status code other than HTTP 503. Example: @@ -1842,7 +1857,7 @@ def tabular_regression(self, table: Dict[str, Any], *, model: Optional[str] = No Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. - `HTTPError`: + [`HfHubHTTPError`]: If the request fails with an HTTP error status code other than HTTP 503. Example: @@ -1903,7 +1918,7 @@ def text_classification( Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. - `HTTPError`: + [`HfHubHTTPError`]: If the request fails with an HTTP error status code other than HTTP 503. Example: @@ -2192,7 +2207,7 @@ def text_generation( If input values are not valid. No HTTP call is made to the server. [`InferenceTimeoutError`]: If the model is unavailable or the request times out. - `HTTPError`: + [`HfHubHTTPError`]: If the request fails with an HTTP error status code other than HTTP 503. Example: @@ -2381,7 +2396,7 @@ def text_generation( # Handle errors separately for more precise error messages try: bytes_output = self._inner_post(request_parameters, stream=stream or False) - except HTTPError as e: + except HfHubHTTPError as e: match = MODEL_KWARGS_NOT_USED_REGEX.search(str(e)) if isinstance(e, BadRequestError) and match: unused_params = [kwarg.strip("' ") for kwarg in match.group(1).split(",")] @@ -2484,7 +2499,7 @@ def text_to_image( Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. - `HTTPError`: + [`HfHubHTTPError`]: If the request fails with an HTTP error status code other than HTTP 503. Example: @@ -2756,7 +2771,7 @@ def text_to_speech( Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. - `HTTPError`: + [`HfHubHTTPError`]: If the request fails with an HTTP error status code other than HTTP 503. Example: @@ -2906,7 +2921,7 @@ def token_classification( Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. - `HTTPError`: + [`HfHubHTTPError`]: If the request fails with an HTTP error status code other than HTTP 503. Example: @@ -2991,7 +3006,7 @@ def translation( Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. - `HTTPError`: + [`HfHubHTTPError`]: If the request fails with an HTTP error status code other than HTTP 503. `ValueError`: If only one of the `src_lang` and `tgt_lang` arguments are provided. @@ -3066,7 +3081,7 @@ def visual_question_answering( Raises: `InferenceTimeoutError`: If the model is unavailable or the request times out. - `HTTPError`: + [`HfHubHTTPError`]: If the request fails with an HTTP error status code other than HTTP 503. Example: @@ -3133,7 +3148,7 @@ def zero_shot_classification( Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. - `HTTPError`: + [`HfHubHTTPError`]: If the request fails with an HTTP error status code other than HTTP 503. Example with `multi_label=False`: @@ -3235,7 +3250,7 @@ def zero_shot_image_classification( Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. - `HTTPError`: + [`HfHubHTTPError`]: If the request fails with an HTTP error status code other than HTTP 503. Example: diff --git a/src/huggingface_hub/inference/_common.py b/src/huggingface_hub/inference/_common.py index c7803d14ee..aca297df34 100644 --- a/src/huggingface_hub/inference/_common.py +++ b/src/huggingface_hub/inference/_common.py @@ -36,10 +36,11 @@ overload, ) -from requests import HTTPError +import httpx from huggingface_hub.errors import ( GenerationError, + HfHubHTTPError, IncompleteGenerationError, OverloadedError, TextGenerationError, @@ -52,7 +53,6 @@ if TYPE_CHECKING: - from aiohttp import ClientResponse, ClientSession from PIL.Image import Image # TYPES @@ -279,13 +279,13 @@ def _as_dict(response: Union[bytes, Dict]) -> Dict: def _stream_text_generation_response( - bytes_output_as_lines: Iterable[bytes], details: bool + output_lines: Iterable[str], details: bool ) -> Union[Iterable[str], Iterable[TextGenerationStreamOutput]]: """Used in `InferenceClient.text_generation`.""" # Parse ServerSentEvents - for byte_payload in bytes_output_as_lines: + for line in output_lines: try: - output = _format_text_generation_stream_output(byte_payload, details) + output = _format_text_generation_stream_output(line, details) except StopIteration: break if output is not None: @@ -293,13 +293,13 @@ def _stream_text_generation_response( async def _async_stream_text_generation_response( - bytes_output_as_lines: AsyncIterable[bytes], details: bool + output_lines: AsyncIterable[str], details: bool ) -> Union[AsyncIterable[str], AsyncIterable[TextGenerationStreamOutput]]: """Used in `AsyncInferenceClient.text_generation`.""" # Parse ServerSentEvents - async for byte_payload in bytes_output_as_lines: + async for line in output_lines: try: - output = _format_text_generation_stream_output(byte_payload, details) + output = _format_text_generation_stream_output(line, details) except StopIteration: break if output is not None: @@ -307,17 +307,17 @@ async def _async_stream_text_generation_response( def _format_text_generation_stream_output( - byte_payload: bytes, details: bool + line: str, details: bool ) -> Optional[Union[str, TextGenerationStreamOutput]]: - if not byte_payload.startswith(b"data:"): + if not line.startswith("data:"): return None # empty line - if byte_payload.strip() == b"data: [DONE]": + if line.strip() == "data: [DONE]": raise StopIteration("[DONE] signal received.") # Decode payload - payload = byte_payload.decode("utf-8") - json_payload = json.loads(payload.lstrip("data:").rstrip("/n")) + payload = line.lstrip("data:").rstrip("/n") + json_payload = json.loads(payload) # Either an error as being returned if json_payload.get("error") is not None: @@ -329,12 +329,12 @@ def _format_text_generation_stream_output( def _stream_chat_completion_response( - bytes_lines: Iterable[bytes], + lines: Iterable[str], ) -> Iterable[ChatCompletionStreamOutput]: """Used in `InferenceClient.chat_completion` if model is served with TGI.""" - for item in bytes_lines: + for line in lines: try: - output = _format_chat_completion_stream_output(item) + output = _format_chat_completion_stream_output(line) except StopIteration: break if output is not None: @@ -342,12 +342,12 @@ def _stream_chat_completion_response( async def _async_stream_chat_completion_response( - bytes_lines: AsyncIterable[bytes], + lines: AsyncIterable[str], ) -> AsyncIterable[ChatCompletionStreamOutput]: """Used in `AsyncInferenceClient.chat_completion`.""" - async for item in bytes_lines: + async for line in lines: try: - output = _format_chat_completion_stream_output(item) + output = _format_chat_completion_stream_output(line) except StopIteration: break if output is not None: @@ -355,17 +355,16 @@ async def _async_stream_chat_completion_response( def _format_chat_completion_stream_output( - byte_payload: bytes, + line: str, ) -> Optional[ChatCompletionStreamOutput]: - if not byte_payload.startswith(b"data:"): + if not line.startswith("data:"): return None # empty line - if byte_payload.strip() == b"data: [DONE]": + if line.strip() == "data: [DONE]": raise StopIteration("[DONE] signal received.") # Decode payload - payload = byte_payload.decode("utf-8") - json_payload = json.loads(payload.lstrip("data:").rstrip("/n")) + json_payload = json.loads(line.lstrip("data:").strip()) # Either an error as being returned if json_payload.get("error") is not None: @@ -375,13 +374,9 @@ def _format_chat_completion_stream_output( return ChatCompletionStreamOutput.parse_obj_as_instance(json_payload) -async def _async_yield_from(client: "ClientSession", response: "ClientResponse") -> AsyncIterable[bytes]: - try: - async for byte_payload in response.content: - yield byte_payload.strip() - finally: - # Always close the underlying HTTP session to avoid resource leaks - await client.close() +async def _async_yield_from(client: httpx.AsyncClient, response: httpx.Response) -> AsyncIterable[str]: + async for line in response.aiter_lines(): + yield line.strip() # "TGI servers" are servers running with the `text-generation-inference` backend. @@ -420,7 +415,7 @@ def _get_unsupported_text_generation_kwargs(model: Optional[str]) -> List[str]: # ---------------------- -def raise_text_generation_error(http_error: HTTPError) -> NoReturn: +def raise_text_generation_error(http_error: HfHubHTTPError) -> NoReturn: """ Try to parse text-generation-inference error message and raise HTTPError in any case. @@ -429,6 +424,8 @@ def raise_text_generation_error(http_error: HTTPError) -> NoReturn: The HTTPError that have been raised. """ # Try to parse a Text Generation Inference error + if http_error.response is None: + raise http_error try: # Hacky way to retrieve payload in case of aiohttp error diff --git a/src/huggingface_hub/inference/_generated/_async_client.py b/src/huggingface_hub/inference/_generated/_async_client.py index 95eaf3e7e5..b25a231052 100644 --- a/src/huggingface_hub/inference/_generated/_async_client.py +++ b/src/huggingface_hub/inference/_generated/_async_client.py @@ -21,12 +21,16 @@ import asyncio import base64 import logging +import os import re import warnings -from typing import TYPE_CHECKING, Any, AsyncIterable, Dict, List, Literal, Optional, Set, Union, overload +from contextlib import AsyncExitStack +from typing import TYPE_CHECKING, Any, AsyncIterable, Dict, List, Literal, Optional, Union, overload + +import httpx from huggingface_hub import constants -from huggingface_hub.errors import InferenceTimeoutError +from huggingface_hub.errors import BadRequestError, HfHubHTTPError, InferenceTimeoutError from huggingface_hub.inference._common import ( TASKS_EXPECTING_IMAGES, ContentT, @@ -86,15 +90,19 @@ ZeroShotImageClassificationOutputElement, ) from huggingface_hub.inference._providers import PROVIDER_OR_POLICY_T, get_provider_helper -from huggingface_hub.utils import build_hf_headers +from huggingface_hub.utils import ( + build_hf_headers, + get_async_session, + hf_raise_for_status, + validate_hf_hub_args, +) from huggingface_hub.utils._auth import get_token -from .._common import _async_yield_from, _import_aiohttp +from .._common import _async_yield_from if TYPE_CHECKING: import numpy as np - from aiohttp import ClientResponse, ClientSession from PIL.Image import Image logger = logging.getLogger(__name__) @@ -135,10 +143,6 @@ class AsyncInferenceClient: Requests can only be billed to an organization the user is a member of, and which has subscribed to Enterprise Hub. cookies (`Dict[str, str]`, `optional`): Additional cookies to send to the server. - trust_env ('bool', 'optional'): - Trust environment settings for proxy configuration if the parameter is `True` (`False` by default). - proxies (`Any`, `optional`): - Proxies to use for the request. base_url (`str`, `optional`): Base URL to run inference. This is a duplicated argument from `model` to make [`InferenceClient`] follow the same pattern as `openai.OpenAI` client. Cannot be used if `model` is set. Defaults to None. @@ -147,6 +151,7 @@ class AsyncInferenceClient: follow the same pattern as `openai.OpenAI` client. Cannot be used if `token` is set. Defaults to None. """ + @validate_hf_hub_args def __init__( self, model: Optional[str] = None, @@ -156,8 +161,6 @@ def __init__( timeout: Optional[float] = None, headers: Optional[Dict[str, str]] = None, cookies: Optional[Dict[str, str]] = None, - trust_env: bool = False, - proxies: Optional[Any] = None, bill_to: Optional[str] = None, # OpenAI compatibility base_url: Optional[str] = None, @@ -219,15 +222,36 @@ def __init__( self.cookies = cookies self.timeout = timeout - self.trust_env = trust_env - self.proxies = proxies - # Keep track of the sessions to close them properly - self._sessions: Dict["ClientSession", Set["ClientResponse"]] = dict() + self.exit_stack = AsyncExitStack() + self._async_client: Optional[httpx.AsyncClient] = None def __repr__(self): return f"" + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_value, traceback): + await self.close() + + async def close(self): + """Close the client. + + This method is automatically called when using the client as a context manager. + """ + await self.exit_stack.aclose() + + async def _get_async_client(self): + """Get a unique async client for this AsyncInferenceClient instance. + + Returns the same client instance on subsequent calls, ensuring proper + connection reuse and resource management through the exit stack. + """ + if self._async_client is None: + self._async_client = await self.exit_stack.enter_async_context(get_async_session()) + return self._async_client + @overload async def _inner_post( # type: ignore[misc] self, request_parameters: RequestParameters, *, stream: Literal[False] = ... @@ -236,83 +260,60 @@ async def _inner_post( # type: ignore[misc] @overload async def _inner_post( # type: ignore[misc] self, request_parameters: RequestParameters, *, stream: Literal[True] = ... - ) -> AsyncIterable[bytes]: ... + ) -> AsyncIterable[str]: ... @overload async def _inner_post( self, request_parameters: RequestParameters, *, stream: bool = False - ) -> Union[bytes, AsyncIterable[bytes]]: ... + ) -> Union[bytes, AsyncIterable[str]]: ... async def _inner_post( self, request_parameters: RequestParameters, *, stream: bool = False - ) -> Union[bytes, AsyncIterable[bytes]]: + ) -> Union[bytes, AsyncIterable[str]]: """Make a request to the inference server.""" - aiohttp = _import_aiohttp() - # TODO: this should be handled in provider helpers directly if request_parameters.task in TASKS_EXPECTING_IMAGES and "Accept" not in request_parameters.headers: request_parameters.headers["Accept"] = "image/png" - # Do not use context manager as we don't want to close the connection immediately when returning - # a stream - session = self._get_client_session(headers=request_parameters.headers) - try: - response = await session.post( - request_parameters.url, json=request_parameters.json, data=request_parameters.data, proxy=self.proxies - ) - response_error_payload = None - if response.status != 200: - try: - response_error_payload = await response.json() # get payload before connection closed - except Exception: - pass - response.raise_for_status() + client = await self._get_async_client() if stream: - return _async_yield_from(session, response) + response = await self.exit_stack.enter_async_context( + client.stream( + "POST", + request_parameters.url, + json=request_parameters.json, + data=request_parameters.data, + headers=request_parameters.headers, + cookies=self.cookies, + timeout=self.timeout, + ) + ) + hf_raise_for_status(response) + return _async_yield_from(client, response) else: - content = await response.read() - await session.close() - return content + response = await client.post( + request_parameters.url, + json=request_parameters.json, + data=request_parameters.data, + headers=request_parameters.headers, + cookies=self.cookies, + timeout=self.timeout, + ) + hf_raise_for_status(response) + return response.content except asyncio.TimeoutError as error: - await session.close() # Convert any `TimeoutError` to a `InferenceTimeoutError` raise InferenceTimeoutError(f"Inference call timed out: {request_parameters.url}") from error # type: ignore - except aiohttp.ClientResponseError as error: - error.response_error_payload = response_error_payload - await session.close() - raise error - except Exception: - await session.close() + except HfHubHTTPError as error: + if error.response.status_code == 422 and request_parameters.task != "unknown": + msg = str(error.args[0]) + if len(error.response.text) > 0: + msg += f"{os.linesep}{error.response.text}{os.linesep}" + error.args = (msg,) + error.args[1:] raise - async def __aenter__(self): - return self - - async def __aexit__(self, exc_type, exc_value, traceback): - await self.close() - - def __del__(self): - if len(self._sessions) > 0: - warnings.warn( - "Deleting 'AsyncInferenceClient' client but some sessions are still open. " - "This can happen if you've stopped streaming data from the server before the stream was complete. " - "To close the client properly, you must call `await client.close()` " - "or use an async context (e.g. `async with AsyncInferenceClient(): ...`." - ) - - async def close(self): - """Close all open sessions. - - By default, 'aiohttp.ClientSession' objects are closed automatically when a call is completed. However, if you - are streaming data from the server and you stop before the stream is complete, you must call this method to - close the session properly. - - Another possibility is to use an async context (e.g. `async with AsyncInferenceClient(): ...`). - """ - await asyncio.gather(*[session.close() for session in self._sessions.keys()]) - async def audio_classification( self, audio: ContentT, @@ -343,7 +344,7 @@ async def audio_classification( Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. - `aiohttp.ClientResponseError`: + [`HfHubHTTPError`]: If the request fails with an HTTP error status code other than HTTP 503. Example: @@ -395,7 +396,7 @@ async def audio_to_audio( Raises: `InferenceTimeoutError`: If the model is unavailable or the request times out. - `aiohttp.ClientResponseError`: + [`HfHubHTTPError`]: If the request fails with an HTTP error status code other than HTTP 503. Example: @@ -449,7 +450,7 @@ async def automatic_speech_recognition( Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. - `aiohttp.ClientResponseError`: + [`HfHubHTTPError`]: If the request fails with an HTTP error status code other than HTTP 503. Example: @@ -656,7 +657,7 @@ async def chat_completion( Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. - `aiohttp.ClientResponseError`: + [`HfHubHTTPError`]: If the request fails with an HTTP error status code other than HTTP 503. Example: @@ -1016,7 +1017,7 @@ async def document_question_answering( Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. - `aiohttp.ClientResponseError`: + [`HfHubHTTPError`]: If the request fails with an HTTP error status code other than HTTP 503. @@ -1092,7 +1093,7 @@ async def feature_extraction( Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. - `aiohttp.ClientResponseError`: + [`HfHubHTTPError`]: If the request fails with an HTTP error status code other than HTTP 503. Example: @@ -1155,7 +1156,7 @@ async def fill_mask( Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. - `aiohttp.ClientResponseError`: + [`HfHubHTTPError`]: If the request fails with an HTTP error status code other than HTTP 503. Example: @@ -1209,7 +1210,7 @@ async def image_classification( Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. - `aiohttp.ClientResponseError`: + [`HfHubHTTPError`]: If the request fails with an HTTP error status code other than HTTP 503. Example: @@ -1272,7 +1273,7 @@ async def image_segmentation( Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. - `aiohttp.ClientResponseError`: + [`HfHubHTTPError`]: If the request fails with an HTTP error status code other than HTTP 503. Example: @@ -1350,7 +1351,7 @@ async def image_to_image( Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. - `aiohttp.ClientResponseError`: + [`HfHubHTTPError`]: If the request fails with an HTTP error status code other than HTTP 503. Example: @@ -1482,7 +1483,7 @@ async def image_to_text(self, image: ContentT, *, model: Optional[str] = None) - Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. - `aiohttp.ClientResponseError`: + [`HfHubHTTPError`]: If the request fails with an HTTP error status code other than HTTP 503. Example: @@ -1535,7 +1536,7 @@ async def object_detection( Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. - `aiohttp.ClientResponseError`: + [`HfHubHTTPError`]: If the request fails with an HTTP error status code other than HTTP 503. `ValueError`: If the request output is not a List. @@ -1612,7 +1613,7 @@ async def question_answering( Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. - `aiohttp.ClientResponseError`: + [`HfHubHTTPError`]: If the request fails with an HTTP error status code other than HTTP 503. Example: @@ -1668,7 +1669,7 @@ async def sentence_similarity( Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. - `aiohttp.ClientResponseError`: + [`HfHubHTTPError`]: If the request fails with an HTTP error status code other than HTTP 503. Example: @@ -1730,7 +1731,7 @@ async def summarization( Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. - `aiohttp.ClientResponseError`: + [`HfHubHTTPError`]: If the request fails with an HTTP error status code other than HTTP 503. Example: @@ -1796,7 +1797,7 @@ async def table_question_answering( Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. - `aiohttp.ClientResponseError`: + [`HfHubHTTPError`]: If the request fails with an HTTP error status code other than HTTP 503. Example: @@ -1840,7 +1841,7 @@ async def tabular_classification(self, table: Dict[str, Any], *, model: Optional Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. - `aiohttp.ClientResponseError`: + [`HfHubHTTPError`]: If the request fails with an HTTP error status code other than HTTP 503. Example: @@ -1896,7 +1897,7 @@ async def tabular_regression(self, table: Dict[str, Any], *, model: Optional[str Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. - `aiohttp.ClientResponseError`: + [`HfHubHTTPError`]: If the request fails with an HTTP error status code other than HTTP 503. Example: @@ -1958,7 +1959,7 @@ async def text_classification( Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. - `aiohttp.ClientResponseError`: + [`HfHubHTTPError`]: If the request fails with an HTTP error status code other than HTTP 503. Example: @@ -2236,10 +2237,10 @@ async def text_generation( Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) Returns: - `Union[str, TextGenerationOutput, Iterable[str], Iterable[TextGenerationStreamOutput]]`: + `Union[str, TextGenerationOutput, AsyncIterable[str], AsyncIterable[TextGenerationStreamOutput]]`: Generated text returned from the server: - if `stream=False` and `details=False`, the generated text is returned as a `str` (default) - - if `stream=True` and `details=False`, the generated text is returned token by token as a `Iterable[str]` + - if `stream=True` and `details=False`, the generated text is returned token by token as a `AsyncIterable[str]` - if `stream=False` and `details=True`, the generated text is returned with more details as a [`~huggingface_hub.TextGenerationOutput`] - if `details=True` and `stream=True`, the generated text is returned token by token as a iterable of [`~huggingface_hub.TextGenerationStreamOutput`] @@ -2248,7 +2249,7 @@ async def text_generation( If input values are not valid. No HTTP call is made to the server. [`InferenceTimeoutError`]: If the model is unavailable or the request times out. - `aiohttp.ClientResponseError`: + [`HfHubHTTPError`]: If the request fails with an HTTP error status code other than HTTP 503. Example: @@ -2438,9 +2439,9 @@ async def text_generation( # Handle errors separately for more precise error messages try: bytes_output = await self._inner_post(request_parameters, stream=stream or False) - except _import_aiohttp().ClientResponseError as e: - match = MODEL_KWARGS_NOT_USED_REGEX.search(e.response_error_payload["error"]) - if e.status == 400 and match: + except HfHubHTTPError as e: + match = MODEL_KWARGS_NOT_USED_REGEX.search(str(e)) + if isinstance(e, BadRequestError) and match: unused_params = [kwarg.strip("' ") for kwarg in match.group(1).split(",")] _set_unsupported_text_generation_kwargs(model, unused_params) return await self.text_generation( # type: ignore @@ -2541,7 +2542,7 @@ async def text_to_image( Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. - `aiohttp.ClientResponseError`: + [`HfHubHTTPError`]: If the request fails with an HTTP error status code other than HTTP 503. Example: @@ -2814,7 +2815,7 @@ async def text_to_speech( Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. - `aiohttp.ClientResponseError`: + [`HfHubHTTPError`]: If the request fails with an HTTP error status code other than HTTP 503. Example: @@ -2965,7 +2966,7 @@ async def token_classification( Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. - `aiohttp.ClientResponseError`: + [`HfHubHTTPError`]: If the request fails with an HTTP error status code other than HTTP 503. Example: @@ -3051,7 +3052,7 @@ async def translation( Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. - `aiohttp.ClientResponseError`: + [`HfHubHTTPError`]: If the request fails with an HTTP error status code other than HTTP 503. `ValueError`: If only one of the `src_lang` and `tgt_lang` arguments are provided. @@ -3127,7 +3128,7 @@ async def visual_question_answering( Raises: `InferenceTimeoutError`: If the model is unavailable or the request times out. - `aiohttp.ClientResponseError`: + [`HfHubHTTPError`]: If the request fails with an HTTP error status code other than HTTP 503. Example: @@ -3195,7 +3196,7 @@ async def zero_shot_classification( Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. - `aiohttp.ClientResponseError`: + [`HfHubHTTPError`]: If the request fails with an HTTP error status code other than HTTP 503. Example with `multi_label=False`: @@ -3299,7 +3300,7 @@ async def zero_shot_image_classification( Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. - `aiohttp.ClientResponseError`: + [`HfHubHTTPError`]: If the request fails with an HTTP error status code other than HTTP 503. Example: @@ -3334,47 +3335,6 @@ async def zero_shot_image_classification( response = await self._inner_post(request_parameters) return ZeroShotImageClassificationOutputElement.parse_obj_as_list(response) - def _get_client_session(self, headers: Optional[Dict] = None) -> "ClientSession": - aiohttp = _import_aiohttp() - client_headers = self.headers.copy() - if headers is not None: - client_headers.update(headers) - - # Return a new aiohttp ClientSession with correct settings. - session = aiohttp.ClientSession( - headers=client_headers, - cookies=self.cookies, - timeout=aiohttp.ClientTimeout(self.timeout), - trust_env=self.trust_env, - ) - - # Keep track of sessions to close them later - self._sessions[session] = set() - - # Override the `._request` method to register responses to be closed - session._wrapped_request = session._request - - async def _request(method, url, **kwargs): - response = await session._wrapped_request(method, url, **kwargs) - self._sessions[session].add(response) - return response - - session._request = _request - - # Override the 'close' method to - # 1. close ongoing responses - # 2. deregister the session when closed - session._close = session.close - - async def close_session(): - for response in self._sessions[session]: - response.close() - await session._close() - self._sessions.pop(session, None) - - session.close = close_session - return session - async def get_endpoint_info(self, *, model: Optional[str] = None) -> Dict[str, Any]: """ Get information about the deployed endpoint. @@ -3430,10 +3390,10 @@ async def get_endpoint_info(self, *, model: Optional[str] = None) -> Dict[str, A else: url = f"{constants.INFERENCE_ENDPOINT}/models/{model}/info" - async with self._get_client_session(headers=build_hf_headers(token=self.token)) as client: - response = await client.get(url, proxy=self.proxies) - response.raise_for_status() - return await response.json() + client = await self._get_async_client() + response = await client.get(url, headers=build_hf_headers(token=self.token)) + hf_raise_for_status(response) + return response.json() async def health_check(self, model: Optional[str] = None) -> bool: """ @@ -3467,9 +3427,9 @@ async def health_check(self, model: Optional[str] = None) -> bool: raise ValueError("Model must be an Inference Endpoint URL.") url = model.rstrip("/") + "/health" - async with self._get_client_session(headers=build_hf_headers(token=self.token)) as client: - response = await client.get(url, proxy=self.proxies) - return response.status == 200 + client = await self._get_async_client() + response = await client.get(url, headers=build_hf_headers(token=self.token)) + return response.status_code == 200 @property def chat(self) -> "ProxyClientChat": diff --git a/src/huggingface_hub/inference_api.py b/src/huggingface_hub/inference_api.py index f895fcc61c..333fa0e5de 100644 --- a/src/huggingface_hub/inference_api.py +++ b/src/huggingface_hub/inference_api.py @@ -44,7 +44,7 @@ class InferenceApi: - """Client to configure requests and make calls to the HuggingFace Inference API. + """Client to configure httpx and make calls to the HuggingFace Inference API. Example: @@ -187,7 +187,7 @@ def __call__( payload["parameters"] = params # Make API call - response = get_session().post(self.api_url, headers=self.headers, json=payload, data=data) + response = get_session().post(self.api_url, headers=self.headers, json=payload, content=data) # Let the user handle the response if raw_response: diff --git a/src/huggingface_hub/keras_mixin.py b/src/huggingface_hub/keras_mixin.py index 45d0eaf8a7..53290dc858 100644 --- a/src/huggingface_hub/keras_mixin.py +++ b/src/huggingface_hub/keras_mixin.py @@ -265,10 +265,6 @@ def from_pretrained_keras(*args, **kwargs) -> "KerasModelHubMixin": force_download (`bool`, *optional*, defaults to `False`): Whether to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist. - proxies (`Dict[str, str]`, *optional*): - A dictionary of proxy servers to use by protocol or endpoint, e.g., - `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The - proxies are used on each request. token (`str` or `bool`, *optional*): The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated when running `transformers-cli @@ -463,7 +459,6 @@ def _from_pretrained( revision, cache_dir, force_download, - proxies, resume_download, local_files_only, token, diff --git a/src/huggingface_hub/lfs.py b/src/huggingface_hub/lfs.py index c2d4f36829..3ff465f9c0 100644 --- a/src/huggingface_hub/lfs.py +++ b/src/huggingface_hub/lfs.py @@ -14,7 +14,6 @@ # limitations under the License. """Git LFS related type definitions and utilities""" -import inspect import io import re import warnings @@ -136,7 +135,7 @@ def post_lfs_batch_info( Raises: [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) If an argument is invalid or the server response is malformed. - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError) + [`HfHubHTTPError`] If the server returned an error. """ endpoint = endpoint if endpoint is not None else constants.ENDPOINT @@ -214,7 +213,7 @@ def lfs_upload( Raises: [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) If `lfs_batch_action` is improperly formatted - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError) + [`HfHubHTTPError`] If the upload resulted in an error """ # 0. If LFS file is already present, skip upload @@ -308,11 +307,9 @@ def _upload_single_part(operation: "CommitOperationAdd", upload_url: str) -> Non fileobj: The file-like object holding the data to upload. - Returns: `requests.Response` - Raises: - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError) - If the upload resulted in an error. + [`HfHubHTTPError`] + If the upload resulted in an error. """ with operation.as_file(with_tqdm=True) as fileobj: # S3 might raise a transient 500 error -> let's retry if that happens @@ -420,12 +417,6 @@ def _upload_parts_hf_transfer( " not available in your environment. Try `pip install hf_transfer`." ) - supports_callback = "callback" in inspect.signature(multipart_upload).parameters - if not supports_callback: - warnings.warn( - "You are using an outdated version of `hf_transfer`. Consider upgrading to latest version to enable progress bars using `pip install -U hf_transfer`." - ) - total = operation.upload_info.size desc = operation.path_in_repo if len(desc) > 40: @@ -448,13 +439,11 @@ def _upload_parts_hf_transfer( max_files=128, parallel_failures=127, # could be removed max_retries=5, - **({"callback": progress.update} if supports_callback else {}), + callback=progress.update, ) except Exception as e: raise RuntimeError( "An error occurred while uploading using `hf_transfer`. Consider disabling HF_HUB_ENABLE_HF_TRANSFER for" " better error handling." ) from e - if not supports_callback: - progress.update(total) return output diff --git a/src/huggingface_hub/repocard.py b/src/huggingface_hub/repocard.py index bb7de8c59a..c8c9a28a17 100644 --- a/src/huggingface_hub/repocard.py +++ b/src/huggingface_hub/repocard.py @@ -3,7 +3,6 @@ from pathlib import Path from typing import Any, Dict, Literal, Optional, Type, Union -import requests import yaml from huggingface_hub.file_download import hf_hub_download @@ -17,7 +16,7 @@ eval_results_to_model_index, model_index_to_eval_results, ) -from huggingface_hub.utils import get_session, is_jinja_available, yaml_dump +from huggingface_hub.utils import HfHubHTTPError, get_session, hf_raise_for_status, is_jinja_available, yaml_dump from . import constants from .errors import EntryNotFoundError @@ -204,7 +203,7 @@ def validate(self, repo_type: Optional[str] = None): - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) if the card fails validation checks. - - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError) + - [`HfHubHTTPError`] if the request to the Hub API fails for any other reason. @@ -220,11 +219,11 @@ def validate(self, repo_type: Optional[str] = None): headers = {"Accept": "text/plain"} try: - r = get_session().post("https://huggingface.co/api/validate-yaml", body, headers=headers) - r.raise_for_status() - except requests.exceptions.HTTPError as exc: - if r.status_code == 400: - raise ValueError(r.text) + response = get_session().post("https://huggingface.co/api/validate-yaml", json=body, headers=headers) + hf_raise_for_status(response) + except HfHubHTTPError as exc: + if response.status_code == 400: + raise ValueError(response.text) else: raise exc diff --git a/src/huggingface_hub/utils/__init__.py b/src/huggingface_hub/utils/__init__.py index 992eac104b..52838fe000 100644 --- a/src/huggingface_hub/utils/__init__.py +++ b/src/huggingface_hub/utils/__init__.py @@ -52,12 +52,19 @@ from ._headers import build_hf_headers, get_token_to_send from ._hf_folder import HfFolder from ._http import ( - configure_http_backend, + ASYNC_CLIENT_FACTORY_T, + CLIENT_FACTORY_T, + HfHubAsyncTransport, + HfHubTransport, + close_client, fix_hf_endpoint_in_url, + get_async_session, get_session, hf_raise_for_status, http_backoff, - reset_sessions, + http_stream_backoff, + set_async_client_factory, + set_client_factory, ) from ._pagination import paginate from ._paths import DEFAULT_IGNORE_PATTERNS, FORBIDDEN_FOLDERS, filter_repo_objects diff --git a/src/huggingface_hub/utils/_fixes.py b/src/huggingface_hub/utils/_fixes.py index 560003b622..a1cacc0907 100644 --- a/src/huggingface_hub/utils/_fixes.py +++ b/src/huggingface_hub/utils/_fixes.py @@ -1,13 +1,3 @@ -# JSONDecodeError was introduced in requests=2.27 released in 2022. -# This allows us to support older requests for users -# More information: https://github.com/psf/requests/pull/5856 -try: - from requests import JSONDecodeError # type: ignore # noqa: F401 -except ImportError: - try: - from simplejson import JSONDecodeError # type: ignore # noqa: F401 - except ImportError: - from json import JSONDecodeError # type: ignore # noqa: F401 import contextlib import os import shutil diff --git a/src/huggingface_hub/utils/_http.py b/src/huggingface_hub/utils/_http.py index 5baceb8f8f..b3a545c722 100644 --- a/src/huggingface_hub/utils/_http.py +++ b/src/huggingface_hub/utils/_http.py @@ -12,23 +12,21 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Contains utilities to handle HTTP requests in Huggingface Hub.""" +"""Contains utilities to handle HTTP requests in huggingface_hub.""" +import atexit import io -import os +import json import re import threading import time import uuid -from functools import lru_cache +from contextlib import contextmanager from http import HTTPStatus from shlex import quote -from typing import Any, Callable, List, Optional, Tuple, Type, Union +from typing import Any, Callable, Generator, List, Optional, Tuple, Type, Union -import requests -from requests import HTTPError, Response -from requests.adapters import HTTPAdapter -from requests.models import PreparedRequest +import httpx from huggingface_hub.errors import OfflineModeIsEnabled @@ -36,14 +34,13 @@ from ..errors import ( BadRequestError, DisabledRepoError, - EntryNotFoundError, GatedRepoError, HfHubHTTPError, + RemoteEntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError, ) from . import logging -from ._fixes import JSONDecodeError from ._lfs import SliceFileObj from ._typing import HTTP_METHOD_T @@ -72,142 +69,273 @@ ) -class UniqueRequestIdAdapter(HTTPAdapter): - X_AMZN_TRACE_ID = "X-Amzn-Trace-Id" +class HfHubTransport(httpx.HTTPTransport): + """ + Transport that will be used to make HTTP requests to the Hugging Face Hub. - def add_headers(self, request, **kwargs): - super().add_headers(request, **kwargs) + What it does: + - Block requests if offline mode is enabled + - Add a request ID to the request headers + - Log the request if debug mode is enabled + """ - # Add random request ID => easier for server-side debug - if X_AMZN_TRACE_ID not in request.headers: - request.headers[X_AMZN_TRACE_ID] = request.headers.get(X_REQUEST_ID) or str(uuid.uuid4()) + def handle_request(self, request: httpx.Request) -> httpx.Response: + if constants.HF_HUB_OFFLINE: + raise OfflineModeIsEnabled( + f"Cannot reach {request.url}: offline mode is enabled. To disable it, please unset the `HF_HUB_OFFLINE` environment variable." + ) + request_id = _add_request_id(request) + try: + return super().handle_request(request) + except httpx.RequestError as e: + if request_id is not None: + # Taken from https://stackoverflow.com/a/58270258 + e.args = (*e.args, f"(Request ID: {request_id})") + raise - # Add debug log - has_token = len(str(request.headers.get("authorization", ""))) > 0 - logger.debug( - f"Request {request.headers[X_AMZN_TRACE_ID]}: {request.method} {request.url} (authenticated: {has_token})" - ) - def send(self, request: PreparedRequest, *args, **kwargs) -> Response: - """Catch any RequestException to append request id to the error message for debugging.""" - if constants.HF_DEBUG: - logger.debug(f"Send: {_curlify(request)}") +class HfHubAsyncTransport(httpx.AsyncHTTPTransport): + async def handle_async_request(self, request: httpx.Request) -> httpx.Response: + if constants.HF_HUB_OFFLINE: + raise OfflineModeIsEnabled( + f"Cannot reach {request.url}: offline mode is enabled. To disable it, please unset the `HF_HUB_OFFLINE` environment variable." + ) + request_id = _add_request_id(request) try: - return super().send(request, *args, **kwargs) - except requests.RequestException as e: - request_id = request.headers.get(X_AMZN_TRACE_ID) + return await super().handle_async_request(request) + except httpx.RequestError as e: if request_id is not None: # Taken from https://stackoverflow.com/a/58270258 e.args = (*e.args, f"(Request ID: {request_id})") raise -class OfflineAdapter(HTTPAdapter): - def send(self, request: PreparedRequest, *args, **kwargs) -> Response: - raise OfflineModeIsEnabled( - f"Cannot reach {request.url}: offline mode is enabled. To disable it, please unset the `HF_HUB_OFFLINE` environment variable." - ) +def _add_request_id(request: httpx.Request) -> Optional[str]: + # Add random request ID => easier for server-side debug + if X_AMZN_TRACE_ID not in request.headers: + request.headers[X_AMZN_TRACE_ID] = request.headers.get(X_REQUEST_ID) or str(uuid.uuid4()) + request_id = request.headers.get(X_AMZN_TRACE_ID) + # Debug log + logger.debug( + "Request %s: %s %s (authenticated: %s)", + request_id, + request.method, + request.url, + request.headers.get("authorization") is not None, + ) + if constants.HF_DEBUG: + logger.debug("Send: %s", _curlify(request)) -def _default_backend_factory() -> requests.Session: - session = requests.Session() - if constants.HF_HUB_OFFLINE: - session.mount("http://", OfflineAdapter()) - session.mount("https://", OfflineAdapter()) - else: - session.mount("http://", UniqueRequestIdAdapter()) - session.mount("https://", UniqueRequestIdAdapter()) - return session + return request_id -BACKEND_FACTORY_T = Callable[[], requests.Session] -_GLOBAL_BACKEND_FACTORY: BACKEND_FACTORY_T = _default_backend_factory +def default_client_factory() -> httpx.Client: + """ + Factory function to create a `httpx.Client` with the default transport. + """ + return httpx.Client( + transport=HfHubTransport(), + follow_redirects=True, + timeout=httpx.Timeout(constants.DEFAULT_REQUEST_TIMEOUT, write=60.0), + ) -def configure_http_backend(backend_factory: BACKEND_FACTORY_T = _default_backend_factory) -> None: +def default_async_client_factory() -> httpx.AsyncClient: + """ + Factory function to create a `httpx.AsyncClient` with the default transport. """ - Configure the HTTP backend by providing a `backend_factory`. Any HTTP calls made by `huggingface_hub` will use a - Session object instantiated by this factory. This can be useful if you are running your scripts in a specific - environment requiring custom configuration (e.g. custom proxy or certifications). + return httpx.AsyncClient( + transport=HfHubAsyncTransport(), + follow_redirects=True, + timeout=httpx.Timeout(constants.DEFAULT_REQUEST_TIMEOUT, write=60.0), + ) - Use [`get_session`] to get a configured Session. Since `requests.Session` is not guaranteed to be thread-safe, - `huggingface_hub` creates 1 Session instance per thread. They are all instantiated using the same `backend_factory` - set in [`configure_http_backend`]. A LRU cache is used to cache the created sessions (and connections) between - calls. Max size is 128 to avoid memory leaks if thousands of threads are spawned. - See [this issue](https://github.com/psf/requests/issues/2766) to know more about thread-safety in `requests`. +CLIENT_FACTORY_T = Callable[[], httpx.Client] +ASYNC_CLIENT_FACTORY_T = Callable[[], httpx.AsyncClient] - Example: - ```py - import requests - from huggingface_hub import configure_http_backend, get_session +_CLIENT_LOCK = threading.Lock() +_GLOBAL_CLIENT_FACTORY: CLIENT_FACTORY_T = default_client_factory +_GLOBAL_ASYNC_CLIENT_FACTORY: ASYNC_CLIENT_FACTORY_T = default_async_client_factory +_GLOBAL_CLIENT: Optional[httpx.Client] = None - # Create a factory function that returns a Session with configured proxies - def backend_factory() -> requests.Session: - session = requests.Session() - session.proxies = {"http": "http://10.10.1.10:3128", "https": "https://10.10.1.11:1080"} - return session - # Set it as the default session factory - configure_http_backend(backend_factory=backend_factory) +def set_client_factory(client_factory: CLIENT_FACTORY_T) -> None: + """ + Set the HTTP client factory to be used by `huggingface_hub`. - # In practice, this is mostly done internally in `huggingface_hub` - session = get_session() - ``` + The client factory is a method that returns a `httpx.Client` object. On the first call to [`get_client`] the client factory + will be used to create a new `httpx.Client` object that will be shared between all calls made by `huggingface_hub`. + + This can be useful if you are running your scripts in a specific environment requiring custom configuration (e.g. custom proxy or certifications). + + Use [`get_client`] to get a correctly configured `httpx.Client`. """ - global _GLOBAL_BACKEND_FACTORY - _GLOBAL_BACKEND_FACTORY = backend_factory - reset_sessions() + global _GLOBAL_CLIENT_FACTORY + with _CLIENT_LOCK: + close_client() + _GLOBAL_CLIENT_FACTORY = client_factory -def get_session() -> requests.Session: +def set_async_client_factory(async_client_factory: ASYNC_CLIENT_FACTORY_T) -> None: """ - Get a `requests.Session` object, using the session factory from the user. + Set the HTTP async client factory to be used by `huggingface_hub`. - Use [`get_session`] to get a configured Session. Since `requests.Session` is not guaranteed to be thread-safe, - `huggingface_hub` creates 1 Session instance per thread. They are all instantiated using the same `backend_factory` - set in [`configure_http_backend`]. A LRU cache is used to cache the created sessions (and connections) between - calls. Max size is 128 to avoid memory leaks if thousands of threads are spawned. + The async client factory is a method that returns a `httpx.AsyncClient` object. + This can be useful if you are running your scripts in a specific environment requiring custom configuration (e.g. custom proxy or certifications). + Use [`get_async_client`] to get a correctly configured `httpx.AsyncClient`. - See [this issue](https://github.com/psf/requests/issues/2766) to know more about thread-safety in `requests`. + + + Contrary to the `httpx.Client` that is shared between all calls made by `huggingface_hub`, the `httpx.AsyncClient` is not shared. + It is recommended to use an async context manager to ensure the client is properly closed when the context is exited. + + + """ + global _GLOBAL_ASYNC_CLIENT_FACTORY + _GLOBAL_ASYNC_CLIENT_FACTORY = async_client_factory - Example: - ```py - import requests - from huggingface_hub import configure_http_backend, get_session - # Create a factory function that returns a Session with configured proxies - def backend_factory() -> requests.Session: - session = requests.Session() - session.proxies = {"http": "http://10.10.1.10:3128", "https": "https://10.10.1.11:1080"} - return session +def get_session() -> httpx.Client: + """ + Get a `httpx.Client` object, using the transport factory from the user. - # Set it as the default session factory - configure_http_backend(backend_factory=backend_factory) + This client is shared between all calls made by `huggingface_hub`. Therefore you should not close it manually. - # In practice, this is mostly done internally in `huggingface_hub` - session = get_session() - ``` + Use [`set_client_factory`] to customize the `httpx.Client`. + """ + global _GLOBAL_CLIENT + if _GLOBAL_CLIENT is None: + with _CLIENT_LOCK: + _GLOBAL_CLIENT = _GLOBAL_CLIENT_FACTORY() + return _GLOBAL_CLIENT + + +def get_async_session() -> httpx.AsyncClient: """ - return _get_session_from_cache(process_id=os.getpid(), thread_id=threading.get_ident()) + Return a `httpx.AsyncClient` object, using the transport factory from the user. + + Use [`set_async_client_factory`] to customize the `httpx.AsyncClient`. + -def reset_sessions() -> None: - """Reset the cache of sessions. + Contrary to the `httpx.Client` that is shared between all calls made by `huggingface_hub`, the `httpx.AsyncClient` is not shared. + It is recommended to use an async context manager to ensure the client is properly closed when the context is exited. - Mostly used internally when sessions are reconfigured or an SSLError is raised. - See [`configure_http_backend`] for more details. + """ - _get_session_from_cache.cache_clear() + return _GLOBAL_ASYNC_CLIENT_FACTORY() -@lru_cache -def _get_session_from_cache(process_id: int, thread_id: int) -> requests.Session: +def close_client() -> None: """ - Create a new session per thread using global factory. Using LRU cache (maxsize 128) to avoid memory leaks when - using thousands of threads. Cache is cleared when `configure_http_backend` is called. + Close the global httpx.Client used by `huggingface_hub`. + + If a Client is closed, it will be recreated on the next call to [`get_client`]. + + Can be useful if e.g. an SSL certificate has been updated. """ - return _GLOBAL_BACKEND_FACTORY() + global _GLOBAL_CLIENT + client = _GLOBAL_CLIENT + + # First, set global client to None + _GLOBAL_CLIENT = None + + # Then, close the clients + if client is not None: + try: + client.close() + except Exception as e: + logger.warning(f"Error closing client: {e}") + + +atexit.register(close_client) + + +def _http_backoff_base( + method: HTTP_METHOD_T, + url: str, + *, + max_retries: int = 5, + base_wait_time: float = 1, + max_wait_time: float = 8, + retry_on_exceptions: Union[Type[Exception], Tuple[Type[Exception], ...]] = ( + httpx.TimeoutException, + httpx.NetworkError, + ), + retry_on_status_codes: Union[int, Tuple[int, ...]] = HTTPStatus.SERVICE_UNAVAILABLE, + stream: bool = False, + **kwargs, +) -> Generator[httpx.Response, None, None]: + """Internal implementation of HTTP backoff logic shared between `http_backoff` and `http_stream_backoff`.""" + if isinstance(retry_on_exceptions, type): # Tuple from single exception type + retry_on_exceptions = (retry_on_exceptions,) + + if isinstance(retry_on_status_codes, int): # Tuple from single status code + retry_on_status_codes = (retry_on_status_codes,) + + nb_tries = 0 + sleep_time = base_wait_time + + # If `data` is used and is a file object (or any IO), it will be consumed on the + # first HTTP request. We need to save the initial position so that the full content + # of the file is re-sent on http backoff. See warning tip in docstring. + io_obj_initial_pos = None + if "data" in kwargs and isinstance(kwargs["data"], (io.IOBase, SliceFileObj)): + io_obj_initial_pos = kwargs["data"].tell() + + client = get_session() + while True: + nb_tries += 1 + try: + # If `data` is used and is a file object (or any IO), set back cursor to + # initial position. + if io_obj_initial_pos is not None: + kwargs["data"].seek(io_obj_initial_pos) + + # Perform request and handle response + def _should_retry(response: httpx.Response) -> bool: + """Handle response and return True if should retry, False if should return/yield.""" + if response.status_code not in retry_on_status_codes: + return False # Success, don't retry + + # Wrong status code returned (HTTP 503 for instance) + logger.warning(f"HTTP Error {response.status_code} thrown while requesting {method} {url}") + if nb_tries > max_retries: + hf_raise_for_status(response) # Will raise uncaught exception + # Return/yield response to avoid infinite loop in the corner case where the + # user ask for retry on a status code that doesn't raise_for_status. + return False # Don't retry, return/yield response + + return True # Should retry + + if stream: + with client.stream(method=method, url=url, **kwargs) as response: + if not _should_retry(response): + yield response + return + else: + response = client.request(method=method, url=url, **kwargs) + if not _should_retry(response): + yield response + return + + except retry_on_exceptions as err: + logger.warning(f"'{err}' thrown while requesting {method} {url}") + + if isinstance(err, httpx.ConnectError): + close_client() # In case of SSLError it's best to close the shared httpx.Client objects + + if nb_tries > max_retries: + raise err + + # Sleep for X seconds + logger.warning(f"Retrying in {sleep_time}s [Retry {nb_tries}/{max_retries}].") + time.sleep(sleep_time) + + # Update sleep time for next retry + sleep_time = min(max_wait_time, sleep_time * 2) # Exponential backoff def http_backoff( @@ -218,13 +346,13 @@ def http_backoff( base_wait_time: float = 1, max_wait_time: float = 8, retry_on_exceptions: Union[Type[Exception], Tuple[Type[Exception], ...]] = ( - requests.Timeout, - requests.ConnectionError, + httpx.TimeoutException, + httpx.NetworkError, ), retry_on_status_codes: Union[int, Tuple[int, ...]] = HTTPStatus.SERVICE_UNAVAILABLE, **kwargs, -) -> Response: - """Wrapper around requests to retry calls on an endpoint, with exponential backoff. +) -> httpx.Response: + """Wrapper around httpx to retry calls on an endpoint, with exponential backoff. Endpoint call is retried on exceptions (ex: connection timeout, proxy error,...) and/or on specific status codes (ex: service unavailable). If the call failed more @@ -249,18 +377,18 @@ def http_backoff( Maximum duration (in seconds) to wait before retrying. retry_on_exceptions (`Type[Exception]` or `Tuple[Type[Exception]]`, *optional*): Define which exceptions must be caught to retry the request. Can be a single type or a tuple of types. - By default, retry on `requests.Timeout` and `requests.ConnectionError`. + By default, retry on `httpx.TimeoutException` and `httpx.NetworkError`. retry_on_status_codes (`int` or `Tuple[int]`, *optional*, defaults to `503`): Define on which status codes the request must be retried. By default, only HTTP 503 Service Unavailable is retried. **kwargs (`dict`, *optional*): - kwargs to pass to `requests.request`. + kwargs to pass to `httpx.request`. Example: ``` >>> from huggingface_hub.utils import http_backoff - # Same usage as "requests.request". + # Same usage as "httpx.request". >>> response = http_backoff("GET", "https://www.google.com") >>> response.raise_for_status() @@ -271,7 +399,7 @@ def http_backoff( - When using `requests` it is possible to stream data by passing an iterator to the + When using `httpx` it is possible to stream data by passing an iterator to the `data` argument. On http backoff this is a problem as the iterator is not reset after a failed call. This issue is mitigated for file objects or any IO streams by saving the initial position of the cursor (with `data.tell()`) and resetting the @@ -281,59 +409,105 @@ def http_backoff( """ - if isinstance(retry_on_exceptions, type): # Tuple from single exception type - retry_on_exceptions = (retry_on_exceptions,) + return next( + _http_backoff_base( + method=method, + url=url, + max_retries=max_retries, + base_wait_time=base_wait_time, + max_wait_time=max_wait_time, + retry_on_exceptions=retry_on_exceptions, + retry_on_status_codes=retry_on_status_codes, + stream=False, + **kwargs, + ) + ) - if isinstance(retry_on_status_codes, int): # Tuple from single status code - retry_on_status_codes = (retry_on_status_codes,) - nb_tries = 0 - sleep_time = base_wait_time +@contextmanager +def http_stream_backoff( + method: HTTP_METHOD_T, + url: str, + *, + max_retries: int = 5, + base_wait_time: float = 1, + max_wait_time: float = 8, + retry_on_exceptions: Union[Type[Exception], Tuple[Type[Exception], ...]] = ( + httpx.TimeoutException, + httpx.NetworkError, + ), + retry_on_status_codes: Union[int, Tuple[int, ...]] = HTTPStatus.SERVICE_UNAVAILABLE, + **kwargs, +) -> Generator[httpx.Response, None, None]: + """Wrapper around httpx to retry calls on an endpoint, with exponential backoff. - # If `data` is used and is a file object (or any IO), it will be consumed on the - # first HTTP request. We need to save the initial position so that the full content - # of the file is re-sent on http backoff. See warning tip in docstring. - io_obj_initial_pos = None - if "data" in kwargs and isinstance(kwargs["data"], (io.IOBase, SliceFileObj)): - io_obj_initial_pos = kwargs["data"].tell() + Endpoint call is retried on exceptions (ex: connection timeout, proxy error,...) + and/or on specific status codes (ex: service unavailable). If the call failed more + than `max_retries`, the exception is thrown or `raise_for_status` is called on the + response object. - session = get_session() - while True: - nb_tries += 1 - try: - # If `data` is used and is a file object (or any IO), set back cursor to - # initial position. - if io_obj_initial_pos is not None: - kwargs["data"].seek(io_obj_initial_pos) + Re-implement mechanisms from the `backoff` library to avoid adding an external + dependencies to `hugging_face_hub`. See https://github.com/litl/backoff. - # Perform request and return if status_code is not in the retry list. - response = session.request(method=method, url=url, **kwargs) - if response.status_code not in retry_on_status_codes: - return response + Args: + method (`Literal["GET", "OPTIONS", "HEAD", "POST", "PUT", "PATCH", "DELETE"]`): + HTTP method to perform. + url (`str`): + The URL of the resource to fetch. + max_retries (`int`, *optional*, defaults to `5`): + Maximum number of retries, defaults to 5 (no retries). + base_wait_time (`float`, *optional*, defaults to `1`): + Duration (in seconds) to wait before retrying the first time. + Wait time between retries then grows exponentially, capped by + `max_wait_time`. + max_wait_time (`float`, *optional*, defaults to `8`): + Maximum duration (in seconds) to wait before retrying. + retry_on_exceptions (`Type[Exception]` or `Tuple[Type[Exception]]`, *optional*): + Define which exceptions must be caught to retry the request. Can be a single type or a tuple of types. + By default, retry on `httpx.Timeout` and `httpx.NetworkError`. + retry_on_status_codes (`int` or `Tuple[int]`, *optional*, defaults to `503`): + Define on which status codes the request must be retried. By default, only + HTTP 503 Service Unavailable is retried. + **kwargs (`dict`, *optional*): + kwargs to pass to `httpx.request`. - # Wrong status code returned (HTTP 503 for instance) - logger.warning(f"HTTP Error {response.status_code} thrown while requesting {method} {url}") - if nb_tries > max_retries: - response.raise_for_status() # Will raise uncaught exception - # We return response to avoid infinite loop in the corner case where the - # user ask for retry on a status code that doesn't raise_for_status. - return response + Example: + ``` + >>> from huggingface_hub.utils import http_stream_backoff - except retry_on_exceptions as err: - logger.warning(f"'{err}' thrown while requesting {method} {url}") + # Same usage as "httpx.stream". + >>> with http_stream_backoff("GET", "https://www.google.com") as response: + ... for chunk in response.iter_bytes(): + ... print(chunk) - if isinstance(err, requests.ConnectionError): - reset_sessions() # In case of SSLError it's best to reset the shared requests.Session objects + # If you expect a Gateway Timeout from time to time + >>> with http_stream_backoff("PUT", upload_url, data=data, retry_on_status_codes=504) as response: + ... response.raise_for_status() + ``` - if nb_tries > max_retries: - raise err + - # Sleep for X seconds - logger.warning(f"Retrying in {sleep_time}s [Retry {nb_tries}/{max_retries}].") - time.sleep(sleep_time) + When using `httpx` it is possible to stream data by passing an iterator to the + `data` argument. On http backoff this is a problem as the iterator is not reset + after a failed call. This issue is mitigated for file objects or any IO streams + by saving the initial position of the cursor (with `data.tell()`) and resetting the + cursor between each call (with `data.seek()`). For arbitrary iterators, http backoff + will fail. If this is a hard constraint for you, please let us know by opening an + issue on [Github](https://github.com/huggingface/huggingface_hub). - # Update sleep time for next retry - sleep_time = min(max_wait_time, sleep_time * 2) # Exponential backoff + + """ + yield from _http_backoff_base( + method=method, + url=url, + max_retries=max_retries, + base_wait_time=base_wait_time, + max_wait_time=max_wait_time, + retry_on_exceptions=retry_on_exceptions, + retry_on_status_codes=retry_on_status_codes, + stream=True, + **kwargs, + ) def fix_hf_endpoint_in_url(url: str, endpoint: Optional[str]) -> str: @@ -349,55 +523,32 @@ def fix_hf_endpoint_in_url(url: str, endpoint: Optional[str]) -> str: return url -def hf_raise_for_status(response: Response, endpoint_name: Optional[str] = None) -> None: +def hf_raise_for_status(response: httpx.Response, endpoint_name: Optional[str] = None) -> None: """ - Internal version of `response.raise_for_status()` that will refine a - potential HTTPError. Raised exception will be an instance of `HfHubHTTPError`. + Internal version of `response.raise_for_status()` that will refine a potential HTTPError. + Raised exception will be an instance of [`~errors.HfHubHTTPError`]. - This helper is meant to be the unique method to raise_for_status when making a call - to the Hugging Face Hub. - - - Example: - ```py - import requests - from huggingface_hub.utils import get_session, hf_raise_for_status, HfHubHTTPError - - response = get_session().post(...) - try: - hf_raise_for_status(response) - except HfHubHTTPError as e: - print(str(e)) # formatted message - e.request_id, e.server_message # details returned by server - - # Complete the error message with additional information once it's raised - e.append_to_message("\n`create_commit` expects the repository to exist.") - raise - ``` + This helper is meant to be the unique method to raise_for_status when making a call to the Hugging Face Hub. Args: response (`Response`): Response from the server. endpoint_name (`str`, *optional*): - Name of the endpoint that has been called. If provided, the error message - will be more complete. + Name of the endpoint that has been called. If provided, the error message will be more complete. Raises when the request has failed: - [`~utils.RepositoryNotFoundError`] - If the repository to download from cannot be found. This may be because it - doesn't exist, because `repo_type` is not set correctly, or because the repo - is `private` and you do not have access. + If the repository to download from cannot be found. This may be because it doesn't exist, because `repo_type` + is not set correctly, or because the repo is `private` and you do not have access. - [`~utils.GatedRepoError`] - If the repository exists but is gated and the user is not on the authorized - list. + If the repository exists but is gated and the user is not on the authorized list. - [`~utils.RevisionNotFoundError`] If the repository exists but the revision couldn't be find. - - [`~utils.EntryNotFoundError`] - If the repository exists but the entry (e.g. the requested file) couldn't be - find. + - [`~utils.RemoteEntryNotFoundError`] + If the repository exists but the entry (e.g. the requested file) couldn't be find. - [`~utils.BadRequestError`] If request failed with a HTTP 400 BadRequest error. - [`~utils.HfHubHTTPError`] @@ -407,7 +558,10 @@ def hf_raise_for_status(response: Response, endpoint_name: Optional[str] = None) """ try: response.raise_for_status() - except HTTPError as e: + except httpx.HTTPStatusError as e: + if response.status_code // 100 == 3: + return # Do not raise on redirects to stay consistent with `requests` + error_code = response.headers.get("X-Error-Code") error_message = response.headers.get("X-Error-Message") @@ -417,7 +571,7 @@ def hf_raise_for_status(response: Response, endpoint_name: Optional[str] = None) elif error_code == "EntryNotFound": message = f"{response.status_code} Client Error." + "\n\n" + f"Entry Not Found for url: {response.url}." - raise _format(EntryNotFoundError, message, response) from e + raise _format(RemoteEntryNotFoundError, message, response) from e elif error_code == "GatedRepo": message = ( @@ -440,7 +594,7 @@ def hf_raise_for_status(response: Response, endpoint_name: Optional[str] = None) and error_message != "Invalid credentials in Authorization header" and response.request is not None and response.request.url is not None - and REPO_API_REGEX.search(response.request.url) is not None + and REPO_API_REGEX.search(str(response.request.url)) is not None ): # 401 is misleading as it is returned for: # - private and gated repos if user is not authenticated @@ -482,7 +636,7 @@ def hf_raise_for_status(response: Response, endpoint_name: Optional[str] = None) raise _format(HfHubHTTPError, str(e), response) from e -def _format(error_type: Type[HfHubHTTPError], custom_message: str, response: Response) -> HfHubHTTPError: +def _format(error_type: Type[HfHubHTTPError], custom_message: str, response: httpx.Response) -> HfHubHTTPError: server_errors = [] # Retrieve server error from header @@ -493,7 +647,11 @@ def _format(error_type: Type[HfHubHTTPError], custom_message: str, response: Res # Retrieve server error from body try: # Case errors are returned in a JSON format - data = response.json() + try: + data = response.json() + except httpx.ResponseNotRead: + response.read() # In case of streaming response, we need to read the response first + data = response.json() error = data.get("error") if error is not None: @@ -511,7 +669,7 @@ def _format(error_type: Type[HfHubHTTPError], custom_message: str, response: Res if "message" in error: server_errors.append(error["message"]) - except JSONDecodeError: + except json.JSONDecodeError: # If content is not JSON and not HTML, append the text content_type = response.headers.get("Content-Type", "") if response.text and "html" not in content_type.lower(): @@ -556,8 +714,8 @@ def _format(error_type: Type[HfHubHTTPError], custom_message: str, response: Res return error_type(final_error_message.strip(), response=response, server_message=server_message or None) -def _curlify(request: requests.PreparedRequest) -> str: - """Convert a `requests.PreparedRequest` into a curl command (str). +def _curlify(request: httpx.Request) -> str: + """Convert a `httpx.Request` into a curl command (str). Used for debug purposes only. @@ -572,16 +730,16 @@ def _curlify(request: requests.PreparedRequest) -> str: for k, v in sorted(request.headers.items()): if k.lower() == "authorization": v = "" # Hide authorization header, no matter its value (can be Bearer, Key, etc.) - parts += [("-H", "{0}: {1}".format(k, v))] - - if request.body: - body = request.body - if isinstance(body, bytes): - body = body.decode("utf-8", errors="ignore") - elif hasattr(body, "read"): - body = "" # Don't try to read it to avoid consuming the stream + parts += [("-H", f"{k}: {v}")] + + body: Optional[str] = None + if request.content is not None: + body = request.content.decode("utf-8", errors="ignore") if len(body) > 1000: - body = body[:1000] + " ... [truncated]" + body = f"{body[:1000]} ... [truncated]" + elif request.stream is not None: + body = "" + if body is not None: parts += [("-d", body.replace("\n", ""))] parts += [(None, request.url)] diff --git a/src/huggingface_hub/utils/_pagination.py b/src/huggingface_hub/utils/_pagination.py index 3ef2b6668b..1d63ad4b49 100644 --- a/src/huggingface_hub/utils/_pagination.py +++ b/src/huggingface_hub/utils/_pagination.py @@ -16,7 +16,7 @@ from typing import Dict, Iterable, Optional -import requests +import httpx from . import get_session, hf_raise_for_status, http_backoff, logging @@ -48,5 +48,5 @@ def paginate(path: str, params: Dict, headers: Dict) -> Iterable: next_page = _get_next_page(r) -def _get_next_page(response: requests.Response) -> Optional[str]: +def _get_next_page(response: httpx.Response) -> Optional[str]: return response.links.get("next", {}).get("url") diff --git a/src/huggingface_hub/utils/_validators.py b/src/huggingface_hub/utils/_validators.py index 27833f28e3..2a1b473446 100644 --- a/src/huggingface_hub/utils/_validators.py +++ b/src/huggingface_hub/utils/_validators.py @@ -111,6 +111,8 @@ def _inner_fn(*args, **kwargs): if check_use_auth_token: kwargs = smoothly_deprecate_use_auth_token(fn_name=fn.__name__, has_token=has_token, kwargs=kwargs) + kwargs = smoothly_deprecate_proxies(fn_name=fn.__name__, kwargs=kwargs) + return fn(*args, **kwargs) return _inner_fn # type: ignore @@ -170,6 +172,37 @@ def validate_repo_id(repo_id: str) -> None: raise HFValidationError(f"Repo_id cannot end by '.git': '{repo_id}'.") +def smoothly_deprecate_proxies(fn_name: str, kwargs: Dict[str, Any]) -> Dict[str, Any]: + """Smoothly deprecate `proxies` in the `huggingface_hub` codebase. + + This function removes the `proxies` key from the kwargs and warns the user that the `proxies` argument is ignored. + To set up proxies, user must either use the HTTP_PROXY environment variable or configure the `httpx.Client` manually + using the [`set_client_factory`] function. + + In huggingface_hub 0.x, `proxies` was a dictionary directly passed to `requests.request`. + In huggingface_hub 1.x, we migrated to `httpx` which does not support `proxies` the same way. + In particular, it is not possible to configure proxies on a per-request basis. The solution is to configure + it globally using the [`set_client_factory`] function or using the HTTP_PROXY environment variable. + + More more details, see: + - https://www.python-httpx.org/advanced/proxies/ + - https://www.python-httpx.org/compatibility/#proxy-keys. + + We did not want to completely remove the `proxies` argument to avoid breaking existing code. + """ + new_kwargs = kwargs.copy() # do not mutate input ! + + proxies = new_kwargs.pop("proxies", None) # remove from kwargs + if proxies is not None: + warnings.warn( + f"The `proxies` argument is ignored in `{fn_name}`. To set up proxies, use the HTTP_PROXY / HTTPS_PROXY" + " environment variables or configure the `httpx.Client` manually using `huggingface_hub.set_client_factory`." + " See https://www.python-httpx.org/advanced/proxies/ for more details." + ) + + return new_kwargs + + def smoothly_deprecate_use_auth_token(fn_name: str, has_token: bool, kwargs: Dict[str, Any]) -> Dict[str, Any]: """Smoothly deprecate `use_auth_token` in the `huggingface_hub` codebase. diff --git a/src/huggingface_hub/utils/_xet.py b/src/huggingface_hub/utils/_xet.py index 3dcf99068f..c49c8f88f0 100644 --- a/src/huggingface_hub/utils/_xet.py +++ b/src/huggingface_hub/utils/_xet.py @@ -2,7 +2,7 @@ from enum import Enum from typing import Dict, Optional -import requests +import httpx from .. import constants from . import get_session, hf_raise_for_status, validate_hf_hub_args @@ -27,7 +27,7 @@ class XetConnectionInfo: def parse_xet_file_data_from_response( - response: requests.Response, endpoint: Optional[str] = None + response: httpx.Response, endpoint: Optional[str] = None ) -> Optional[XetFileData]: """ Parse XET file metadata from an HTTP response. @@ -36,7 +36,7 @@ def parse_xet_file_data_from_response( of a given response object. If the required metadata is not found, it returns `None`. Args: - response (`requests.Response`): + response (`httpx.Response`): The HTTP response object containing headers dict and links dict to extract the XET metadata from. Returns: `Optional[XetFileData]`: diff --git a/src/huggingface_hub/utils/tqdm.py b/src/huggingface_hub/utils/tqdm.py index 4c1fcef4be..46bd0ace67 100644 --- a/src/huggingface_hub/utils/tqdm.py +++ b/src/huggingface_hub/utils/tqdm.py @@ -248,7 +248,7 @@ def tqdm_stream_file(path: Union[Path, str]) -> Iterator[io.BufferedReader]: Example: ```py >>> with tqdm_stream_file("config.json") as f: - >>> requests.put(url, data=f) + >>> httpx.put(url, data=f) config.json: 100%|█████████████████████████| 8.19k/8.19k [00:02<00:00, 3.72kB/s] ``` """ diff --git a/tests/test_cli.py b/tests/test_cli.py index ae9ebc7886..ab7d819ff0 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -331,7 +331,7 @@ def test_upload_file_no_revision_mock(self, create_mock: Mock, upload_mock: Mock def test_upload_file_with_revision_mock( self, create_mock: Mock, upload_mock: Mock, repo_info_mock: Mock, create_branch_mock: Mock ) -> None: - repo_info_mock.side_effect = RevisionNotFoundError("revision not found") + repo_info_mock.side_effect = RevisionNotFoundError("revision not found", response=Mock()) with SoftTemporaryDirectory() as cache_dir: file_path = Path(cache_dir) / "file.txt" @@ -853,8 +853,8 @@ def setUp(self) -> None: commands_parser = self.parser.add_subparsers() JobsCommands.register_subcommand(commands_parser) - patch_requests_post = patch( - "requests.Session.post", + patch_httpx_post = patch( + "httpx.Client.post", return_value=DummyResponse( { "id": "my-job-id", @@ -872,14 +872,14 @@ def setUp(self) -> None: patch_repo_info = patch("huggingface_hub.hf_api.HfApi.repo_info") patch_upload_file = patch("huggingface_hub.hf_api.HfApi.upload_file", return_value=DummyCommit(oid="ae068f")) - @patch_requests_post + @patch_httpx_post @patch_whoami - def test_run(self, whoami: Mock, requests_post: Mock) -> None: + def test_run(self, whoami: Mock, httpx_post: Mock) -> None: input_args = ["jobs", "run", "--detach", "ubuntu", "echo", "hello"] cmd = RunCommand(self.parser.parse_args(input_args)) cmd.run() - assert requests_post.call_count == 1 - args, kwargs = requests_post.call_args_list[0] + assert httpx_post.call_count == 1 + args, kwargs = httpx_post.call_args_list[0] assert args == ("https://huggingface.co/api/jobs/my-username",) assert kwargs["json"] == { "command": ["echo", "hello"], @@ -890,7 +890,7 @@ def test_run(self, whoami: Mock, requests_post: Mock) -> None: } @patch( - "requests.Session.post", + "httpx.Client.post", return_value=DummyResponse( { "id": "my-job-id", @@ -905,12 +905,12 @@ def test_run(self, whoami: Mock, requests_post: Mock) -> None: ), ) @patch("huggingface_hub.hf_api.HfApi.whoami", return_value={"name": "my-username"}) - def test_create_scheduled_job(self, whoami: Mock, requests_post: Mock) -> None: + def test_create_scheduled_job(self, whoami: Mock, httpx_mock: Mock) -> None: input_args = ["jobs", "scheduled", "run", "@hourly", "ubuntu", "echo", "hello"] cmd = ScheduledRunCommand(self.parser.parse_args(input_args)) cmd.run() - assert requests_post.call_count == 1 - args, kwargs = requests_post.call_args_list[0] + assert httpx_mock.call_count == 1 + args, kwargs = httpx_mock.call_args_list[0] assert args == ("https://huggingface.co/api/scheduled-jobs/my-username",) assert kwargs["json"] == { "jobSpec": { @@ -923,14 +923,14 @@ def test_create_scheduled_job(self, whoami: Mock, requests_post: Mock) -> None: "schedule": "@hourly", } - @patch_requests_post + @patch_httpx_post @patch_whoami - def test_uv_command(self, whoami: Mock, requests_post: Mock) -> None: + def test_uv_command(self, whoami: Mock, httpx_post: Mock) -> None: input_args = ["jobs", "uv", "run", "--detach", "echo", "hello"] cmd = UvCommand(self.parser.parse_args(input_args)) cmd.run() - assert requests_post.call_count == 1 - args, kwargs = requests_post.call_args_list[0] + assert httpx_post.call_count == 1 + args, kwargs = httpx_post.call_args_list[0] assert args == ("https://huggingface.co/api/jobs/my-username",) assert kwargs["json"] == { "command": ["uv", "run", "echo", "hello"], @@ -940,14 +940,14 @@ def test_uv_command(self, whoami: Mock, requests_post: Mock) -> None: "dockerImage": "ghcr.io/astral-sh/uv:python3.12-bookworm", } - @patch_requests_post + @patch_httpx_post @patch_whoami - def test_uv_remote_script(self, whoami: Mock, requests_post: Mock) -> None: + def test_uv_remote_script(self, whoami: Mock, httpx_post: Mock) -> None: input_args = ["jobs", "uv", "run", "--detach", "https://.../script.py"] cmd = UvCommand(self.parser.parse_args(input_args)) cmd.run() - assert requests_post.call_count == 1 - args, kwargs = requests_post.call_args_list[0] + assert httpx_post.call_count == 1 + args, kwargs = httpx_post.call_args_list[0] assert args == ("https://huggingface.co/api/jobs/my-username",) assert kwargs["json"] == { "command": ["uv", "run", "https://.../script.py"], @@ -957,19 +957,19 @@ def test_uv_remote_script(self, whoami: Mock, requests_post: Mock) -> None: "dockerImage": "ghcr.io/astral-sh/uv:python3.12-bookworm", } - @patch_requests_post + @patch_httpx_post @patch_whoami @patch_get_token @patch_repo_info @patch_upload_file def test_uv_local_script( - self, upload_file: Mock, repo_info: Mock, get_token: Mock, whoami: Mock, requests_post: Mock + self, upload_file: Mock, repo_info: Mock, get_token: Mock, whoami: Mock, httpx_post: Mock ) -> None: input_args = ["jobs", "uv", "run", "--detach", __file__] cmd = UvCommand(self.parser.parse_args(input_args)) cmd.run() - assert requests_post.call_count == 1 - args, kwargs = requests_post.call_args_list[0] + assert httpx_post.call_count == 1 + args, kwargs = httpx_post.call_args_list[0] assert args == ("https://huggingface.co/api/jobs/my-username",) command = kwargs["json"].pop("command") assert "UV_SCRIPT_URL" in " ".join(command) diff --git a/tests/test_commit_scheduler.py b/tests/test_commit_scheduler.py index a38d8cb947..872f5c6e44 100644 --- a/tests/test_commit_scheduler.py +++ b/tests/test_commit_scheduler.py @@ -206,13 +206,22 @@ def test_read_partial_file_too_much(self) -> None: self.assertEqual(file.read(20), b"12345") def test_partial_file_len(self) -> None: - """Useful for `requests` internally.""" + """Useful for httpx internally.""" file = PartialFileIO(self.file_path, size_limit=5) self.assertEqual(len(file), 5) file = PartialFileIO(self.file_path, size_limit=50) self.assertEqual(len(file), 9) + def test_partial_file_fileno(self) -> None: + """We explicitly do not implement fileno() to avoid misuse. + + httpx tries to use it to check file size which we don't want for PartialFileIO. + """ + file = PartialFileIO(self.file_path, size_limit=5) + with self.assertRaises(AttributeError): + file.fileno() + def test_partial_file_seek_and_tell(self) -> None: file = PartialFileIO(self.file_path, size_limit=5) diff --git a/tests/test_file_download.py b/tests/test_file_download.py index f5ab794a0c..bb76af9c47 100644 --- a/tests/test_file_download.py +++ b/tests/test_file_download.py @@ -22,9 +22,8 @@ from typing import Iterable, List from unittest.mock import Mock, patch +import httpx import pytest -import requests -from requests import Response import huggingface_hub.file_download from huggingface_hub import HfApi, RepoUrl, constants @@ -37,7 +36,6 @@ _create_symlink, _get_pointer_path, _normalize_etag, - _request_wrapper, get_hf_file_metadata, hf_hub_download, hf_hub_url, @@ -46,6 +44,7 @@ ) from huggingface_hub.utils import SoftTemporaryDirectory, get_session, hf_raise_for_status, is_hf_transfer_available from huggingface_hub.utils._headers import build_hf_headers +from huggingface_hub.utils._http import _http_backoff_base from .testing_constants import ENDPOINT_STAGING, OTHER_TOKEN, TOKEN from .testing_utils import ( @@ -307,7 +306,7 @@ def _check_user_agent(headers: dict): assert "foo/bar" in headers["user-agent"] with SoftTemporaryDirectory() as cache_dir: - with patch("huggingface_hub.file_download._request_wrapper", wraps=_request_wrapper) as mock_request: + with patch("huggingface_hub.utils._http._http_backoff_base", wraps=_http_backoff_base) as mock_request: # First download hf_hub_download( DUMMY_MODEL_ID, @@ -322,7 +321,7 @@ def _check_user_agent(headers: dict): for call in calls: _check_user_agent(call.kwargs["headers"]) - with patch("huggingface_hub.file_download._request_wrapper", wraps=_request_wrapper) as mock_request: + with patch("huggingface_hub.utils._http._http_backoff_base", wraps=_http_backoff_base) as mock_request: # Second download: no GET call hf_hub_download( DUMMY_MODEL_ID, @@ -926,17 +925,17 @@ def test_http_get_with_ssl_and_timeout_error(self, caplog): def _iter_content_1() -> Iterable[bytes]: yield b"0" * 10 yield b"0" * 10 - raise requests.exceptions.SSLError("Fake SSLError") + raise httpx.ConnectError("Fake ConnectError") def _iter_content_2() -> Iterable[bytes]: yield b"0" * 10 - raise requests.ReadTimeout("Fake ReadTimeout") + raise httpx.TimeoutException("Fake TimeoutException") def _iter_content_3() -> Iterable[bytes]: yield b"0" * 10 yield b"0" * 10 yield b"0" * 10 - raise requests.ConnectionError("Fake ConnectionError") + raise httpx.ConnectError("Fake ConnectionError") def _iter_content_4() -> Iterable[bytes]: yield b"0" * 10 @@ -944,15 +943,21 @@ def _iter_content_4() -> Iterable[bytes]: yield b"0" * 10 yield b"0" * 10 - with patch("huggingface_hub.file_download._request_wrapper") as mock: - mock.return_value.headers = {"Content-Length": 100} - mock.return_value.iter_content.side_effect = [ + with patch("huggingface_hub.file_download.http_stream_backoff") as mock_stream_backoff: + # Create a mock response object + mock_response = Mock() + mock_response.headers = {"Content-Length": "100"} + mock_response.iter_bytes.side_effect = [ _iter_content_1(), _iter_content_2(), _iter_content_3(), _iter_content_4(), ] + # Mock the context manager behavior + mock_stream_backoff.return_value.__enter__.return_value = mock_response + mock_stream_backoff.return_value.__exit__.return_value = None + temp_file = io.BytesIO() http_get("fake_url", temp_file=temp_file) @@ -964,11 +969,9 @@ def _iter_content_4() -> Iterable[bytes]: assert temp_file.getvalue() == b"0" * 100 # Check number of calls + correct range headers - assert len(mock.call_args_list) == 4 - assert mock.call_args_list[0].kwargs["headers"] == {} - assert mock.call_args_list[1].kwargs["headers"] == {"Range": "bytes=20-"} - assert mock.call_args_list[2].kwargs["headers"] == {"Range": "bytes=30-"} - assert mock.call_args_list[3].kwargs["headers"] == {"Range": "bytes=60-"} + assert len(mock_response.iter_bytes.call_args_list) == 4 + # Note: The range headers are now handled internally by http_get's retry mechanism + # The test verifies that the download completed successfully after retries @pytest.mark.parametrize( "initial_range,expected_ranges", @@ -1009,17 +1012,17 @@ def test_http_get_with_range_headers(self, caplog, initial_range: str, expected_ def _iter_content_1() -> Iterable[bytes]: yield b"0" * 10 yield b"0" * 10 - raise requests.exceptions.SSLError("Fake SSLError") + raise httpx.ConnectError("Fake ConnectError") def _iter_content_2() -> Iterable[bytes]: yield b"0" * 10 - raise requests.ReadTimeout("Fake ReadTimeout") + raise httpx.TimeoutException("Fake TimeoutException") def _iter_content_3() -> Iterable[bytes]: yield b"0" * 10 yield b"0" * 10 yield b"0" * 10 - raise requests.ConnectionError("Fake ConnectionError") + raise httpx.ConnectError("Fake ConnectionError") def _iter_content_4() -> Iterable[bytes]: yield b"0" * 10 @@ -1027,15 +1030,21 @@ def _iter_content_4() -> Iterable[bytes]: yield b"0" * 10 yield b"0" * 10 - with patch("huggingface_hub.file_download._request_wrapper") as mock: - mock.return_value.headers = {"Content-Length": 100} - mock.return_value.iter_content.side_effect = [ + with patch("huggingface_hub.file_download.http_stream_backoff") as mock_stream_backoff: + # Create a mock response object + mock_response = Mock() + mock_response.headers = {"Content-Length": "100"} + mock_response.iter_bytes.side_effect = [ _iter_content_1(), _iter_content_2(), _iter_content_3(), _iter_content_4(), ] + # Mock the context manager behavior + mock_stream_backoff.return_value.__enter__.return_value = mock_response + mock_stream_backoff.return_value.__exit__.return_value = None + temp_file = io.BytesIO() http_get("fake_url", temp_file=temp_file, headers={"Range": initial_range}) @@ -1045,9 +1054,10 @@ def _iter_content_4() -> Iterable[bytes]: assert temp_file.tell() == 100 assert temp_file.getvalue() == b"0" * 100 - assert len(mock.call_args_list) == 4 + # Check that http_stream_backoff was called with the correct range headers + assert len(mock_stream_backoff.call_args_list) == 4 for i, expected_range in enumerate(expected_ranges): - assert mock.call_args_list[i].kwargs["headers"] == {"Range": expected_range} + assert mock_stream_backoff.call_args_list[i].kwargs["headers"] == {"Range": expected_range} class CreateSymlinkTest(unittest.TestCase): @@ -1125,20 +1135,19 @@ def test_weak_reference(self): @with_production_testing def test_resolve_endpoint_on_regular_file(self): url = "https://huggingface.co/gpt2/resolve/e7da7f221d5bf496a48136c0cd264e630fe9fcc8/README.md" - response = requests.head(url, headers=build_hf_headers(user_agent="is_ci/true")) + response = httpx.head(url, headers=build_hf_headers(user_agent="is_ci/true")) self.assertEqual(self._get_etag_and_normalize(response), "a16a55fda99d2f2e7b69cce5cf93ff4ad3049930") @with_production_testing def test_resolve_endpoint_on_lfs_file(self): url = "https://huggingface.co/gpt2/resolve/e7da7f221d5bf496a48136c0cd264e630fe9fcc8/pytorch_model.bin" - response = requests.head(url, headers=build_hf_headers(user_agent="is_ci/true")) + response = httpx.head(url, headers=build_hf_headers(user_agent="is_ci/true")) self.assertEqual( self._get_etag_and_normalize(response), "7c5d3f4b8b76583b422fcb9189ad6c89d5d97a094541ce8932dce3ecabde1421" ) @staticmethod - def _get_etag_and_normalize(response: Response) -> str: - response.raise_for_status() + def _get_etag_and_normalize(response: httpx.Response) -> str: return _normalize_etag( response.headers.get(constants.HUGGINGFACE_HEADER_X_LINKED_ETAG) or response.headers.get("ETag") ) diff --git a/tests/test_hf_api.py b/tests/test_hf_api.py index abb2f5e3f0..ce5c08d2e2 100644 --- a/tests/test_hf_api.py +++ b/tests/test_hf_api.py @@ -30,8 +30,6 @@ from urllib.parse import quote, urlparse import pytest -import requests -from requests.exceptions import HTTPError import huggingface_hub.lfs from huggingface_hub import HfApi, SpaceHardware, SpaceStage, SpaceStorage, constants @@ -197,7 +195,7 @@ def test_delete_repo_error_message(self): # test for #751 # See https://github.com/huggingface/huggingface_hub/issues/751 with self.assertRaisesRegex( - requests.exceptions.HTTPError, + HfHubHTTPError, re.compile( r"404 Client Error(.+)\(Request ID: .+\)(.*)Repository Not Found", flags=re.DOTALL, @@ -607,7 +605,7 @@ def test_create_commit_create_pr(self, repo_url: RepoUrl) -> None: self.assertEqual(resp.pr_revision, "refs/pr/1") # File doesn't exist on main... - with self.assertRaises(HTTPError) as ctx: + with self.assertRaises(HfHubHTTPError) as ctx: # Should raise a 404 self._api.hf_hub_download(repo_id, "buffer") self.assertEqual(ctx.exception.response.status_code, 404) @@ -708,7 +706,7 @@ def test_create_commit(self, repo_url: RepoUrl) -> None: self.assertIsNone(resp.pr_num) self.assertIsNone(resp.pr_revision) - with self.assertRaises(HTTPError): + with self.assertRaises(HfHubHTTPError): # Should raise a 404 hf_hub_download(repo_id, "temp/new_file.md") @@ -737,7 +735,7 @@ def test_create_commit_conflict(self, repo_url: RepoUrl) -> None: operations = [ CommitOperationAdd(path_in_repo="buffer", path_or_fileobj=b"Buffer data"), ] - with self.assertRaises(HTTPError) as exc_ctx: + with self.assertRaises(HfHubHTTPError) as exc_ctx: self._api.create_commit( operations=operations, commit_message="Test create_commit", @@ -1592,7 +1590,7 @@ def test_create_tag_on_commit_oid(self, repo_url: RepoUrl) -> None: @use_tmp_repo("model") def test_invalid_tag_name(self, repo_url: RepoUrl) -> None: """Check `create_tag` with an invalid tag name.""" - with self.assertRaises(HTTPError): + with self.assertRaises(HfHubHTTPError): self._api.create_tag(repo_url.repo_id, tag="invalid tag") @use_tmp_repo("model") @@ -2572,7 +2570,7 @@ def test_model_info(self, mock_get_token: Mock) -> None: with patch.object(self._api, "token", None): # no default token # Test we cannot access model info without a token with self.assertRaisesRegex( - requests.exceptions.HTTPError, + HfHubHTTPError, re.compile( r"401 Client Error(.+)\(Request ID: .+\)(.*)Repository Not Found", flags=re.DOTALL, @@ -2588,7 +2586,7 @@ def test_dataset_info(self, mock_get_token: Mock) -> None: with patch.object(self._api, "token", None): # no default token # Test we cannot access model info without a token with self.assertRaisesRegex( - requests.exceptions.HTTPError, + HfHubHTTPError, re.compile( r"401 Client Error(.+)\(Request ID: .+\)(.*)Repository Not Found", flags=re.DOTALL, @@ -3456,7 +3454,6 @@ def test_hf_hub_download_alias(self, mock: Mock) -> None: local_dir_use_symlinks="auto", force_download=False, force_filename=None, - proxies=None, etag_timeout=10, resume_download=None, local_files_only=False, @@ -3481,7 +3478,6 @@ def test_snapshot_download_alias(self, mock: Mock) -> None: cache_dir=None, local_dir=None, local_dir_use_symlinks="auto", - proxies=None, etag_timeout=10, resume_download=None, force_download=False, @@ -4055,7 +4051,7 @@ def test_create_collection_exists_ok(self) -> None: self.slug = collection_1.slug # Cannot create twice with same title - with self.assertRaises(HTTPError): # already exists + with self.assertRaises(HfHubHTTPError): # already exists self._api.create_collection(self.title) # Can ignore error @@ -4071,7 +4067,7 @@ def test_create_private_collection(self) -> None: # Get private collection self._api.get_collection(collection.slug) # no error - with self.assertRaises(HTTPError): + with self.assertRaises(HfHubHTTPError): self._api.get_collection(collection.slug, token=OTHER_TOKEN) # not authorized # Get public collection @@ -4113,7 +4109,7 @@ def test_delete_collection(self) -> None: self._api.delete_collection(collection.slug) # Cannot delete twice the same collection - with self.assertRaises(HTTPError): # already exists + with self.assertRaises(HfHubHTTPError): # already exists self._api.delete_collection(collection.slug) # Possible to ignore error @@ -4141,12 +4137,12 @@ def test_collection_items(self) -> None: self.assertIsNone(collection.items[1].note) # Add existing item fails (except if ignore error) - with self.assertRaises(HTTPError): + with self.assertRaises(HfHubHTTPError): self._api.add_collection_item(collection.slug, model_id, "model") self._api.add_collection_item(collection.slug, model_id, "model", exists_ok=True) # Add inexistent item fails - with self.assertRaises(HTTPError): + with self.assertRaises(HfHubHTTPError): self._api.add_collection_item(collection.slug, model_id, "dataset") # Update first item @@ -4247,21 +4243,21 @@ def test_access_request_error(self): self._api.grant_access(self.repo_id, OTHER_USER) # Cannot grant twice - with self.assertRaises(HTTPError): + with self.assertRaises(HfHubHTTPError): self._api.grant_access(self.repo_id, OTHER_USER) # Cannot accept to already accepted - with self.assertRaises(HTTPError): + with self.assertRaises(HfHubHTTPError): self._api.accept_access_request(self.repo_id, OTHER_USER) # Cannot reject to already rejected self._api.reject_access_request(self.repo_id, OTHER_USER, rejection_reason="This is a rejection reason") - with self.assertRaises(HTTPError): + with self.assertRaises(HfHubHTTPError): self._api.reject_access_request(self.repo_id, OTHER_USER, rejection_reason="This is a rejection reason") # Cannot cancel to already cancelled self._api.cancel_access_request(self.repo_id, OTHER_USER) - with self.assertRaises(HTTPError): + with self.assertRaises(HfHubHTTPError): self._api.cancel_access_request(self.repo_id, OTHER_USER) @@ -4379,7 +4375,7 @@ def test_delete_webhook(self) -> None: url=self.webhook_url, watched=self.watched_items, domains=self.domains, secret=self.secret ) self._api.delete_webhook(webhook_to_delete.id) - with self.assertRaises(HTTPError): + with self.assertRaises(HfHubHTTPError): self._api.get_webhook(webhook_to_delete.id) diff --git a/tests/test_hf_file_system.py b/tests/test_hf_file_system.py index d30151d5fd..fa7ad0419d 100644 --- a/tests/test_hf_file_system.py +++ b/tests/test_hf_file_system.py @@ -6,7 +6,7 @@ import unittest from pathlib import Path from typing import Optional -from unittest.mock import patch +from unittest.mock import Mock, patch import fsspec import pytest @@ -192,9 +192,9 @@ def test_stream_file_retry(self): self.assertIsInstance(f, HfFileSystemStreamFile) self.assertEqual(f.read(6), b"dummy ") # Simulate that streaming fails mid-way - f.response.raw.read = None + f.response = None self.assertEqual(f.read(6), b"binary") - self.assertIsNotNone(f.response.raw.read) # a new connection has been created + self.assertIsNotNone(f.response) # a new connection has been created def test_read_file_with_revision(self): with self.hffs.open(self.hf_path + "/data/binary_data_for_pr.bin", "rb", revision="refs/pr/1") as f: @@ -577,9 +577,9 @@ def test_resolve_path_with_refs_revision() -> None: def mock_repo_info(fs: HfFileSystem): def _inner(repo_id: str, *, revision: str, repo_type: str, **kwargs): if repo_id not in ["gpt2", "squad", "username/my_dataset", "username/my_model"]: - raise RepositoryNotFoundError(repo_id) + raise RepositoryNotFoundError(repo_id, response=Mock()) if revision is not None and revision not in ["main", "dev", "refs"] and not revision.startswith("refs/"): - raise RevisionNotFoundError(revision) + raise RevisionNotFoundError(revision, response=Mock()) return patch.object(fs._api, "repo_info", _inner) diff --git a/tests/test_hub_mixin.py b/tests/test_hub_mixin.py index 4dbf888c61..90582e846d 100644 --- a/tests/test_hub_mixin.py +++ b/tests/test_hub_mixin.py @@ -126,7 +126,6 @@ def _from_pretrained( revision: Optional[str], cache_dir: Optional[Union[str, Path]], force_download: bool, - proxies: Optional[Dict], resume_download: bool, local_files_only: bool, token: Optional[Union[str, bool]], @@ -341,7 +340,6 @@ def test_from_pretrained_model_id_and_revision(self, from_pretrained_mock: Mock) revision="123456789", # Revision is passed correctly! cache_dir=None, force_download=False, - proxies=None, resume_download=None, local_files_only=False, token=None, diff --git a/tests/test_hub_mixin_pytorch.py b/tests/test_hub_mixin_pytorch.py index c9494accbc..dd965189fe 100644 --- a/tests/test_hub_mixin_pytorch.py +++ b/tests/test_hub_mixin_pytorch.py @@ -10,7 +10,7 @@ import pytest from huggingface_hub import HfApi, ModelCard, constants, hf_hub_download -from huggingface_hub.errors import EntryNotFoundError, HfHubHTTPError +from huggingface_hub.errors import RemoteEntryNotFoundError from huggingface_hub.hub_mixin import ModelHubMixin, PyTorchModelHubMixin from huggingface_hub.serialization._torch import storage_ptr from huggingface_hub.utils import SoftTemporaryDirectory, is_torch_available @@ -195,7 +195,7 @@ def test_from_pretrained_model_id_only(self, from_pretrained_mock: Mock) -> None def pretend_file_download(self, **kwargs): if kwargs.get("filename") == "config.json": - raise HfHubHTTPError("no config") + raise RemoteEntryNotFoundError("no config", response=Mock()) DummyModel().save_pretrained(self.cache_dir) return self.cache_dir / "model.safetensors" @@ -209,7 +209,6 @@ def test_from_pretrained_model_from_hub_prefer_safetensor(self, hf_hub_download_ revision=None, cache_dir=None, force_download=False, - proxies=None, resume_download=None, token=None, local_files_only=False, @@ -219,7 +218,7 @@ def test_from_pretrained_model_from_hub_prefer_safetensor(self, hf_hub_download_ def pretend_file_download_fallback(self, **kwargs): filename = kwargs.get("filename") if filename == "model.safetensors" or filename == "config.json": - raise EntryNotFoundError("not found") + raise RemoteEntryNotFoundError("not found", response=Mock()) class TestMixin(ModelHubMixin): def _save_pretrained(self, save_directory: Path) -> None: @@ -238,7 +237,6 @@ def test_from_pretrained_model_from_hub_fallback_pickle(self, hf_hub_download_mo revision=None, cache_dir=None, force_download=False, - proxies=None, resume_download=None, token=None, local_files_only=False, @@ -249,7 +247,6 @@ def test_from_pretrained_model_from_hub_fallback_pickle(self, hf_hub_download_mo revision=None, cache_dir=None, force_download=False, - proxies=None, resume_download=None, token=None, local_files_only=False, @@ -266,7 +263,6 @@ def test_from_pretrained_model_id_and_revision(self, from_pretrained_mock: Mock) revision="123456789", # Revision is passed correctly! cache_dir=None, force_download=False, - proxies=None, resume_download=None, local_files_only=False, token=None, diff --git a/tests/test_inference_async_client.py b/tests/test_inference_async_client.py index cf60c9e2ad..ec2ee85dc3 100644 --- a/tests/test_inference_async_client.py +++ b/tests/test_inference_async_client.py @@ -299,7 +299,7 @@ def test_sync_vs_async_signatures() -> None: @pytest.mark.asyncio async def test_async_generate_timeout_error(monkeypatch: pytest.MonkeyPatch) -> None: - def _mock_aiohttp_client_timeout(*args, **kwargs): + async def _mock_client_post(*args, **kwargs): raise asyncio.TimeoutError def mock_check_supported_task(*args, **kwargs): @@ -308,9 +308,10 @@ def mock_check_supported_task(*args, **kwargs): monkeypatch.setattr( "huggingface_hub.inference._providers.hf_inference._check_supported_task", mock_check_supported_task ) - monkeypatch.setattr("aiohttp.ClientSession.post", _mock_aiohttp_client_timeout) + client = AsyncInferenceClient(timeout=1) + client._async_client = Mock(post=_mock_client_post) with pytest.raises(InferenceTimeoutError): - await AsyncInferenceClient(timeout=1).text_generation("test") + await client.text_generation("test") class CustomException(Exception): @@ -415,32 +416,3 @@ async def test_use_async_with_inference_client(): async with AsyncInferenceClient(): pass mock_close.assert_called_once() - - -@pytest.mark.asyncio -@patch("aiohttp.ClientSession._request") -async def test_client_responses_correctly_closed(request_mock: Mock) -> None: - """ - Regression test for #2521. - Async client must close the ClientResponse objects when exiting the async context manager. - Fixed by closing the response objects when the session is closed. - - See https://github.com/huggingface/huggingface_hub/issues/2521. - """ - async with AsyncInferenceClient() as client: - session = client._get_client_session() - response1 = await session.get("http://this-is-a-fake-url.com") - response2 = await session.post("http://this-is-a-fake-url.com", json={}) - - # Response objects are closed when the AsyncInferenceClient is closed - response1.close.assert_called_once() - response2.close.assert_called_once() - - -@pytest.mark.asyncio -async def test_warns_if_client_deleted_with_opened_sessions(): - client = AsyncInferenceClient() - session = client._get_client_session() - with pytest.warns(UserWarning): - client.__del__() - await session.close() diff --git a/tests/test_inference_client.py b/tests/test_inference_client.py index cf384db0d1..e2370aa708 100644 --- a/tests/test_inference_client.py +++ b/tests/test_inference_client.py @@ -894,7 +894,7 @@ def test_accept_header_image( response = client.text_to_image("An astronaut riding a horse") assert response == bytes_to_image_mock.return_value - headers = get_session_mock().post.call_args_list[0].kwargs["headers"] + headers = get_session_mock().stream.call_args_list[0].kwargs["headers"] assert headers["Accept"] == "image/png" @@ -993,20 +993,20 @@ def test_token_initialization_cannot_be_token_false(self): @pytest.mark.parametrize( "stop_signal", [ - b"data: [DONE]", - b"data: [DONE]\n", - b"data: [DONE] ", + "data: [DONE]", + "data: [DONE]\n", + "data: [DONE] ", ], ) def test_stream_text_generation_response(stop_signal: bytes): data = [ - b'data: {"index":1,"token":{"id":4560,"text":" trying","logprob":-2.078125,"special":false},"generated_text":null,"details":null}', - b"", # Empty line is skipped - b"\n", # Newline is skipped - b'data: {"index":2,"token":{"id":311,"text":" to","logprob":-0.026245117,"special":false},"generated_text":" trying to","details":null}', + 'data: {"index":1,"token":{"id":4560,"text":" trying","logprob":-2.078125,"special":false},"generated_text":null,"details":null}', + "", # Empty line is skipped + "\n", # Newline is skipped + 'data: {"index":2,"token":{"id":311,"text":" to","logprob":-0.026245117,"special":false},"generated_text":" trying to","details":null}', stop_signal, # Stop signal # Won't parse after - b'data: {"index":2,"token":{"id":311,"text":" to","logprob":-0.026245117,"special":false},"generated_text":" trying to","details":null}', + 'data: {"index":2,"token":{"id":311,"text":" to","logprob":-0.026245117,"special":false},"generated_text":" trying to","details":null}', ] output = list(_stream_text_generation_response(data, details=False)) assert len(output) == 2 @@ -1016,20 +1016,20 @@ def test_stream_text_generation_response(stop_signal: bytes): @pytest.mark.parametrize( "stop_signal", [ - b"data: [DONE]", - b"data: [DONE]\n", - b"data: [DONE] ", + "data: [DONE]", + "data: [DONE]\n", + "data: [DONE] ", ], ) def test_stream_chat_completion_response(stop_signal: bytes): data = [ - b'data: {"object":"chat.completion.chunk","id":"","created":1721737661,"model":"","system_fingerprint":"2.1.2-dev0-sha-5fca30e","choices":[{"index":0,"delta":{"role":"assistant","content":"Both"},"logprobs":null,"finish_reason":null}]}', - b"", # Empty line is skipped - b"\n", # Newline is skipped - b'data: {"object":"chat.completion.chunk","id":"","created":1721737661,"model":"","system_fingerprint":"2.1.2-dev0-sha-5fca30e","choices":[{"index":0,"delta":{"role":"assistant","content":" Rust"},"logprobs":null,"finish_reason":null}]}', + 'data: {"object":"chat.completion.chunk","id":"","created":1721737661,"model":"","system_fingerprint":"2.1.2-dev0-sha-5fca30e","choices":[{"index":0,"delta":{"role":"assistant","content":"Both"},"logprobs":null,"finish_reason":null}]}', + "", # Empty line is skipped + "\n", # Newline is skipped + 'data: {"object":"chat.completion.chunk","id":"","created":1721737661,"model":"","system_fingerprint":"2.1.2-dev0-sha-5fca30e","choices":[{"index":0,"delta":{"role":"assistant","content":" Rust"},"logprobs":null,"finish_reason":null}]}', stop_signal, # Stop signal # Won't parse after - b'data: {"index":2,"token":{"id":311,"text":" to","logprob":-0.026245117,"special":false},"generated_text":" trying to","details":null}', + 'data: {"index":2,"token":{"id":311,"text":" to","logprob":-0.026245117,"special":false},"generated_text":" trying to","details":null}', ] output = list(_stream_chat_completion_response(data)) assert len(output) == 2 @@ -1043,8 +1043,8 @@ def test_chat_completion_error_in_stream(): When an error is encountered in the stream, it should raise a TextGenerationError (e.g. a ValidationError). """ data = [ - b'data: {"object":"chat.completion.chunk","id":"","created":1721737661,"model":"","system_fingerprint":"2.1.2-dev0-sha-5fca30e","choices":[{"index":0,"delta":{"role":"assistant","content":"Both"},"logprobs":null,"finish_reason":null}]}', - b'data: {"error":"Input validation error: `inputs` tokens + `max_new_tokens` must be <= 4096. Given: 6 `inputs` tokens and 4091 `max_new_tokens`","error_type":"validation"}', + 'data: {"object":"chat.completion.chunk","id":"","created":1721737661,"model":"","system_fingerprint":"2.1.2-dev0-sha-5fca30e","choices":[{"index":0,"delta":{"role":"assistant","content":"Both"},"logprobs":null,"finish_reason":null}]}', + 'data: {"error":"Input validation error: `inputs` tokens + `max_new_tokens` must be <= 4096. Given: 6 `inputs` tokens and 4091 `max_new_tokens`","error_type":"validation"}', ] with pytest.raises(ValidationError): for token in _stream_chat_completion_response(data): diff --git a/tests/test_inference_text_generation.py b/tests/test_inference_text_generation.py index 1015f81327..3135172e9d 100644 --- a/tests/test_inference_text_generation.py +++ b/tests/test_inference_text_generation.py @@ -8,9 +8,9 @@ from unittest.mock import MagicMock, patch import pytest -from requests import HTTPError from huggingface_hub import InferenceClient, TextGenerationOutputPrefillToken +from huggingface_hub.errors import HfHubHTTPError from huggingface_hub.inference._common import ( _UNSUPPORTED_TEXT_GENERATION_KWARGS, GenerationError, @@ -46,7 +46,7 @@ def test_validation_error(self): def _mocked_error(payload: Dict) -> MagicMock: - error = HTTPError(response=MagicMock()) + error = HfHubHTTPError("message", response=MagicMock()) error.response.json.return_value = payload return error diff --git a/tests/test_oauth.py b/tests/test_oauth.py index 156069ec63..0bf0a98e74 100644 --- a/tests/test_oauth.py +++ b/tests/test_oauth.py @@ -18,8 +18,8 @@ from dataclasses import asdict from unittest.mock import patch +import httpx import pytest -import requests import starlette.datastructures from fastapi import FastAPI, Request from fastapi.testclient import TestClient @@ -98,8 +98,8 @@ def test_oauth_workflow(client: TestClient): # Make call to HF Hub assert location.startswith("https://hub-ci.huggingface.co/oauth/authorize") location_authorize = location - response_authorize = requests.get( - location_authorize, headers={"cookie": "token=huggingface-hub.js-cookie"}, allow_redirects=False + response_authorize = httpx.get( + location_authorize, headers={"cookie": "token=huggingface-hub.js-cookie"}, follow_redirects=False ) assert response_authorize.status_code == 303 assert "location" in response_authorize.headers diff --git a/tests/test_offline_utils.py b/tests/test_offline_utils.py index cb9bf28fa2..52bf3862be 100644 --- a/tests/test_offline_utils.py +++ b/tests/test_offline_utils.py @@ -1,36 +1,34 @@ from io import BytesIO +import httpx import pytest -import requests from huggingface_hub.file_download import http_get -from .testing_utils import ( - OfflineSimulationMode, - RequestWouldHangIndefinitelyError, - offline, -) +from .testing_utils import OfflineSimulationMode, RequestWouldHangIndefinitelyError, offline def test_offline_with_timeout(): with offline(OfflineSimulationMode.CONNECTION_TIMES_OUT): with pytest.raises(RequestWouldHangIndefinitelyError): - requests.request("GET", "https://huggingface.co") - with pytest.raises(requests.exceptions.ConnectTimeout): - requests.request("GET", "https://huggingface.co", timeout=1.0) - with pytest.raises(requests.exceptions.ConnectTimeout): + httpx.request("GET", "https://huggingface.co") + with pytest.raises(httpx.ConnectTimeout): + httpx.request("GET", "https://huggingface.co", timeout=1.0) + with pytest.raises(httpx.ConnectTimeout): http_get("https://huggingface.co", BytesIO()) def test_offline_with_connection_error(): with offline(OfflineSimulationMode.CONNECTION_FAILS): - with pytest.raises(requests.exceptions.ConnectionError): - requests.request("GET", "https://huggingface.co") - with pytest.raises(requests.exceptions.ConnectionError): + with pytest.raises(httpx.ConnectError): + httpx.request("GET", "https://huggingface.co") + with pytest.raises(httpx.ConnectError): http_get("https://huggingface.co", BytesIO()) def test_offline_with_datasets_offline_mode_enabled(): with offline(OfflineSimulationMode.HF_HUB_OFFLINE_SET_TO_1): - with pytest.raises(ConnectionError): + from huggingface_hub.errors import OfflineModeIsEnabled + + with pytest.raises(OfflineModeIsEnabled): http_get("https://huggingface.co", BytesIO()) diff --git a/tests/test_repository.py b/tests/test_repository.py index b000d74ab3..772dc9850f 100644 --- a/tests/test_repository.py +++ b/tests/test_repository.py @@ -17,8 +17,8 @@ import unittest from pathlib import Path +import httpx import pytest -import requests from huggingface_hub import RepoUrl from huggingface_hub.hf_api import HfApi @@ -280,7 +280,7 @@ def test_add_commit_push(self): # Check that the returned commit url # actually exists. - r = requests.head(url) + r = httpx.head(url) r.raise_for_status() def test_add_commit_push_non_blocking(self): @@ -302,7 +302,7 @@ def test_add_commit_push_non_blocking(self): # Check that the returned commit url # actually exists. - r = requests.head(url) + r = httpx.head(url) r.raise_for_status() def test_context_manager_non_blocking(self): diff --git a/tests/test_utils_cache.py b/tests/test_utils_cache.py index 2609867abd..efd8a961f3 100644 --- a/tests/test_utils_cache.py +++ b/tests/test_utils_cache.py @@ -772,13 +772,8 @@ def test_delete_path_on_missing_file(self) -> None: _try_delete_path(file_path, path_type="TYPE") # Assert warning message with traceback for debug purposes - self.assertEqual(len(captured.output), 1) - self.assertTrue( - captured.output[0].startswith( - "WARNING:huggingface_hub.utils._cache_manager:Couldn't delete TYPE:" - f" file not found ({file_path})\nTraceback (most recent call last):" - ) - ) + assert len(captured.output) > 0 + assert any(f"Couldn't delete TYPE: file not found ({file_path})" in log for log in captured.output) def test_delete_path_on_missing_folder(self) -> None: """Try delete a missing folder.""" @@ -788,13 +783,8 @@ def test_delete_path_on_missing_folder(self) -> None: _try_delete_path(dir_path, path_type="TYPE") # Assert warning message with traceback for debug purposes - self.assertEqual(len(captured.output), 1) - self.assertTrue( - captured.output[0].startswith( - "WARNING:huggingface_hub.utils._cache_manager:Couldn't delete TYPE:" - f" file not found ({dir_path})\nTraceback (most recent call last):" - ) - ) + assert len(captured.output) > 0 + assert any(f"Couldn't delete TYPE: file not found ({dir_path})" in log for log in captured.output) @xfail_on_windows(reason="Permissions are handled differently on Windows.") def test_delete_path_on_local_folder_with_wrong_permission(self) -> None: diff --git a/tests/test_utils_errors.py b/tests/test_utils_errors.py index a08b4e543e..84f250117a 100644 --- a/tests/test_utils_errors.py +++ b/tests/test_utils_errors.py @@ -1,7 +1,7 @@ import unittest import pytest -from requests.models import PreparedRequest, Response +from httpx import Request, Response from huggingface_hub.errors import ( BadRequestError, @@ -16,9 +16,8 @@ class TestErrorUtils(unittest.TestCase): def test_hf_raise_for_status_repo_not_found(self) -> None: - response = Response() - response.headers = {"X-Error-Code": "RepoNotFound", X_REQUEST_ID: 123} - response.status_code = 404 + response = Response(status_code=404, headers={"X-Error-Code": "RepoNotFound", X_REQUEST_ID: "123"}) + response.request = Request(method="GET", url="https://huggingface.co/fake") with self.assertRaisesRegex(RepositoryNotFoundError, "Repository Not Found") as context: hf_raise_for_status(response) @@ -26,10 +25,11 @@ def test_hf_raise_for_status_repo_not_found(self) -> None: assert "Request ID: 123" in str(context.exception) def test_hf_raise_for_status_disabled_repo(self) -> None: - response = Response() - response.headers = {"X-Error-Message": "Access to this resource is disabled.", X_REQUEST_ID: 123} + response = Response( + status_code=403, headers={"X-Error-Message": "Access to this resource is disabled.", X_REQUEST_ID: "123"} + ) + response.request = Request(method="GET", url="https://huggingface.co/fake") - response.status_code = 403 with self.assertRaises(DisabledRepoError) as context: hf_raise_for_status(response) @@ -37,11 +37,8 @@ def test_hf_raise_for_status_disabled_repo(self) -> None: assert "Request ID: 123" in str(context.exception) def test_hf_raise_for_status_401_repo_url_not_invalid_token(self) -> None: - response = Response() - response.headers = {X_REQUEST_ID: 123} - response.status_code = 401 - response.request = PreparedRequest() - response.request.url = "https://huggingface.co/api/models/username/reponame" + response = Response(status_code=401, headers={X_REQUEST_ID: "123"}) + response.request = Request(method="GET", url="https://huggingface.co/api/models/username/reponame") with self.assertRaisesRegex(RepositoryNotFoundError, "Repository Not Found") as context: hf_raise_for_status(response) @@ -49,11 +46,11 @@ def test_hf_raise_for_status_401_repo_url_not_invalid_token(self) -> None: assert "Request ID: 123" in str(context.exception) def test_hf_raise_for_status_401_repo_url_invalid_token(self) -> None: - response = Response() - response.headers = {X_REQUEST_ID: 123, "X-Error-Message": "Invalid credentials in Authorization header"} - response.status_code = 401 - response.request = PreparedRequest() - response.request.url = "https://huggingface.co/api/models/username/reponame" + response = Response( + status_code=401, + headers={X_REQUEST_ID: "123", "X-Error-Message": "Invalid credentials in Authorization header"}, + ) + response.request = Request(method="GET", url="https://huggingface.co/api/models/username/reponame") with self.assertRaisesRegex(HfHubHTTPError, "Invalid credentials in Authorization header") as context: hf_raise_for_status(response) @@ -61,11 +58,10 @@ def test_hf_raise_for_status_401_repo_url_invalid_token(self) -> None: assert "Request ID: 123" in str(context.exception) def test_hf_raise_for_status_403_wrong_token_scope(self) -> None: - response = Response() - response.headers = {X_REQUEST_ID: 123, "X-Error-Message": "specific error message"} - response.status_code = 403 - response.request = PreparedRequest() - response.request.url = "https://huggingface.co/api/repos/create" + response = Response( + status_code=403, headers={X_REQUEST_ID: "123", "X-Error-Message": "specific error message"} + ) + response.request = Request(method="GET", url="https://huggingface.co/api/repos/create") expected_message_part = "403 Forbidden: specific error message" with self.assertRaisesRegex(HfHubHTTPError, expected_message_part) as context: hf_raise_for_status(response) @@ -74,11 +70,8 @@ def test_hf_raise_for_status_403_wrong_token_scope(self) -> None: assert "Request ID: 123" in str(context.exception) def test_hf_raise_for_status_401_not_repo_url(self) -> None: - response = Response() - response.headers = {X_REQUEST_ID: 123} - response.status_code = 401 - response.request = PreparedRequest() - response.request.url = "https://huggingface.co/api/collections" + response = Response(status_code=401, headers={X_REQUEST_ID: "123"}) + response.request = Request(method="GET", url="https://huggingface.co/api/collections") with self.assertRaises(HfHubHTTPError) as context: hf_raise_for_status(response) @@ -86,9 +79,8 @@ def test_hf_raise_for_status_401_not_repo_url(self) -> None: assert "Request ID: 123" in str(context.exception) def test_hf_raise_for_status_revision_not_found(self) -> None: - response = Response() - response.headers = {"X-Error-Code": "RevisionNotFound", X_REQUEST_ID: 123} - response.status_code = 404 + response = Response(status_code=404, headers={"X-Error-Code": "RevisionNotFound", X_REQUEST_ID: "123"}) + response.request = Request(method="GET", url="https://huggingface.co/fake") with self.assertRaisesRegex(RevisionNotFoundError, "Revision Not Found") as context: hf_raise_for_status(response) @@ -96,9 +88,8 @@ def test_hf_raise_for_status_revision_not_found(self) -> None: assert "Request ID: 123" in str(context.exception) def test_hf_raise_for_status_entry_not_found(self) -> None: - response = Response() - response.headers = {"X-Error-Code": "EntryNotFound", X_REQUEST_ID: 123} - response.status_code = 404 + response = Response(status_code=404, headers={"X-Error-Code": "EntryNotFound", X_REQUEST_ID: "123"}) + response.request = Request(method="GET", url="https://huggingface.co/fake") with self.assertRaisesRegex(EntryNotFoundError, "Entry Not Found") as context: hf_raise_for_status(response) @@ -107,33 +98,29 @@ def test_hf_raise_for_status_entry_not_found(self) -> None: def test_hf_raise_for_status_bad_request_no_endpoint_name(self) -> None: """Test HTTPError converted to BadRequestError if error 400.""" - response = Response() - response.status_code = 400 + response = Response(status_code=400) + response.request = Request(method="GET", url="https://huggingface.co/fake") with self.assertRaisesRegex(BadRequestError, "Bad request:") as context: hf_raise_for_status(response) assert context.exception.response.status_code == 400 def test_hf_raise_for_status_bad_request_with_endpoint_name(self) -> None: """Test endpoint name is added to BadRequestError message.""" - response = Response() - response.status_code = 400 + response = Response(status_code=400) + response.request = Request(method="GET", url="https://huggingface.co/fake") with self.assertRaisesRegex(BadRequestError, "Bad request for preupload endpoint:") as context: hf_raise_for_status(response, endpoint_name="preupload") assert context.exception.response.status_code == 400 def test_hf_raise_for_status_fallback(self) -> None: """Test HTTPError is converted to HfHubHTTPError.""" - response = Response() - response.status_code = 404 - response.headers = { - X_REQUEST_ID: "test-id", - } - response.url = "test_URL" + response = Response(status_code=404, headers={X_REQUEST_ID: "test-id"}) + response.request = Request(method="GET", url="https://huggingface.co/fake") with self.assertRaisesRegex(HfHubHTTPError, "Request ID: test-id") as context: hf_raise_for_status(response) assert context.exception.response.status_code == 404 - assert context.exception.response.url == "test_URL" + assert context.exception.response.url == "https://huggingface.co/fake" class TestHfHubHTTPError(unittest.TestCase): @@ -141,9 +128,7 @@ class TestHfHubHTTPError(unittest.TestCase): def setUp(self) -> None: """Setup with a default response.""" - self.response = Response() - self.response.status_code = 404 - self.response.url = "test_URL" + self.response = Response(status_code=404, request=Request(method="GET", url="https://huggingface.co/fake")) def test_hf_hub_http_error_initialization(self) -> None: """Test HfHubHTTPError is initialized properly.""" diff --git a/tests/test_utils_http.py b/tests/test_utils_http.py index 07037e6aba..c35628f83a 100644 --- a/tests/test_utils_http.py +++ b/tests/test_utils_http.py @@ -7,19 +7,20 @@ from unittest.mock import Mock, call, patch from uuid import UUID +import httpx import pytest -import requests -from requests import ConnectTimeout, HTTPError +from httpx import ConnectTimeout, HTTPError from huggingface_hub.constants import ENDPOINT +from huggingface_hub.errors import OfflineModeIsEnabled from huggingface_hub.utils._http import ( - OfflineModeIsEnabled, + HfHubTransport, _adjust_range_header, - configure_http_backend, + default_client_factory, fix_hf_endpoint_in_url, get_session, http_backoff, - reset_sessions, + set_client_factory, ) @@ -63,7 +64,7 @@ def test_backoff_3_calls(self) -> None: def test_backoff_on_exception_until_max(self) -> None: """Test `http_backoff` until max limit is reached with exceptions.""" - self.mock_request.side_effect = ConnectTimeout() + self.mock_request.side_effect = ConnectTimeout("Connection timeout") with self.assertRaises(ConnectTimeout): http_backoff("GET", URL, base_wait_time=0.0, max_retries=3) @@ -76,7 +77,7 @@ def test_backoff_on_status_code_until_max(self) -> None: mock_503.status_code = 503 mock_504 = Mock() mock_504.status_code = 504 - mock_504.raise_for_status.side_effect = HTTPError() + mock_504.raise_for_status.side_effect = HTTPError("HTTP Error") self.mock_request.side_effect = (mock_503, mock_504, mock_503, mock_504) with self.assertRaises(HTTPError): @@ -94,7 +95,7 @@ def test_backoff_on_exceptions_and_status_codes(self) -> None: """Test `http_backoff` until max limit with status codes and exceptions.""" mock_503 = Mock() mock_503.status_code = 503 - self.mock_request.side_effect = (mock_503, ConnectTimeout()) + self.mock_request.side_effect = (mock_503, ConnectTimeout("Connection timeout")) with self.assertRaises(ConnectTimeout): http_backoff("GET", URL, base_wait_time=0.0, max_retries=1) @@ -131,7 +132,7 @@ def test_backoff_sleep_time(self) -> None: def _side_effect_timer() -> Generator[ConnectTimeout, None, None]: t0 = time.time() while True: - yield ConnectTimeout() + yield ConnectTimeout("Connection timeout") t1 = time.time() sleep_times.append(round(t1 - t0, 1)) t0 = t1 @@ -151,65 +152,62 @@ def _side_effect_timer() -> Generator[ConnectTimeout, None, None]: class TestConfigureSession(unittest.TestCase): def setUp(self) -> None: # Reconfigure + clear session cache between each test - configure_http_backend() + set_client_factory(default_client_factory) @classmethod def tearDownClass(cls) -> None: # Clear all sessions after tests - configure_http_backend() + set_client_factory(default_client_factory) @staticmethod - def _factory() -> requests.Session: - session = requests.Session() - session.headers.update({"x-test-header": 4}) - return session + def _factory() -> httpx.Client: + client = httpx.Client() + client.headers.update({"x-test-header": "4"}) + return client def test_default_configuration(self) -> None: - session = get_session() - self.assertEqual(session.headers["connection"], "keep-alive") # keep connection alive by default - self.assertIsNone(session.auth) - self.assertEqual(session.proxies, {}) - self.assertEqual(session.verify, True) - self.assertIsNone(session.cert) - self.assertEqual(session.max_redirects, 30) - self.assertEqual(session.trust_env, True) - self.assertEqual(session.hooks, {"response": []}) + client = get_session() + # Check httpx.Client default configuration + self.assertTrue(client.follow_redirects) + self.assertIsNotNone(client.timeout) + # Check that it's using the HfHubTransport + self.assertIsInstance(client._transport, HfHubTransport) def test_set_configuration(self) -> None: - configure_http_backend(backend_factory=self._factory) + set_client_factory(self._factory) # Check headers have been set correctly - session = get_session() - self.assertNotEqual(session.headers, {"x-test-header": 4}) - self.assertEqual(session.headers["x-test-header"], 4) + client = get_session() + self.assertNotEqual(client.headers, {"x-test-header": "4"}) + self.assertEqual(client.headers["x-test-header"], "4") def test_get_session_twice(self): - session_1 = get_session() - session_2 = get_session() - self.assertIs(session_1, session_2) # exact same instance + client_1 = get_session() + client_2 = get_session() + self.assertIs(client_1, client_2) # exact same instance def test_get_session_twice_but_reconfigure_in_between(self): """Reconfiguring the session clears the cache.""" - session_1 = get_session() - configure_http_backend(backend_factory=self._factory) + client_1 = get_session() + set_client_factory(self._factory) - session_2 = get_session() - self.assertIsNot(session_1, session_2) - self.assertIsNone(session_1.headers.get("x-test-header")) - self.assertEqual(session_2.headers["x-test-header"], 4) + client_2 = get_session() + self.assertIsNot(client_1, client_2) + self.assertIsNone(client_1.headers.get("x-test-header")) + self.assertEqual(client_2.headers["x-test-header"], "4") def test_get_session_multiple_threads(self): N = 3 - sessions = [None] * N + clients = [None] * N def _get_session_in_thread(index: int) -> None: time.sleep(0.1) - sessions[index] = get_session() + clients[index] = get_session() - # Get main thread session - main_session = get_session() + # Get main thread client + main_client = get_session() - # Start 3 threads and get sessions in each of them + # Start 3 threads and get clients in each of them threads = [threading.Thread(target=_get_session_in_thread, args=(index,)) for index in range(N)] for th in threads: th.start() @@ -217,43 +215,41 @@ def _get_session_in_thread(index: int) -> None: for th in threads: th.join() - # Check all sessions are different + # Check all clients are the same instance (httpx is thread-safe) for i in range(N): - self.assertIsNot(main_session, sessions[i]) + self.assertIs(main_client, clients[i]) for j in range(N): - if i != j: - self.assertIsNot(sessions[i], sessions[j]) + self.assertIs(clients[i], clients[j]) @unittest.skipIf(os.name == "nt", "Works differently on Windows.") def test_get_session_in_forked_process(self): - # Get main process session - main_session = get_session() + # Get main process client + main_client = get_session() def _child_target(): - # Put `repr(session)` in queue because putting the `Session` object directly would duplicate it. - # Repr looks like this: "" + # Put `repr(client)` in queue because putting the `Client` object directly would duplicate it. + # Repr looks like this: "" process_queue.put(repr(get_session())) - # Fork a new process and get session in it + # Fork a new process and get client in it process_queue = Queue() Process(target=_child_target).start() - child_session = process_queue.get() + child_client = process_queue.get() - # Check sessions are different - self.assertNotEqual(repr(main_session), child_session) + # Check clients are the same instance + self.assertEqual(repr(main_client), child_client) class OfflineModeSessionTest(unittest.TestCase): def tearDown(self) -> None: - reset_sessions() return super().tearDown() @patch("huggingface_hub.constants.HF_HUB_OFFLINE", True) def test_offline_mode(self): - configure_http_backend() - session = get_session() + set_client_factory(default_client_factory) + client = get_session() with self.assertRaises(OfflineModeIsEnabled): - session.get("https://huggingface.co") + client.get("https://huggingface.co") class TestUniqueRequestId(unittest.TestCase): diff --git a/tests/test_xet_upload.py b/tests/test_xet_upload.py index f66a0fd850..d2f4a8b55f 100644 --- a/tests/test_xet_upload.py +++ b/tests/test_xet_upload.py @@ -357,7 +357,6 @@ def test_hf_xet_with_token_refresher(self, api, tmp_path, repo_url): headers=headers, endpoint=api.endpoint, token=TOKEN, - proxies=None, etag_timeout=None, local_files_only=False, ) diff --git a/tests/testing_utils.py b/tests/testing_utils.py index eeb9d6611e..792f08ad17 100644 --- a/tests/testing_utils.py +++ b/tests/testing_utils.py @@ -12,10 +12,10 @@ from typing import Callable, Optional, Type, TypeVar, Union from unittest.mock import Mock, patch +import httpx import pytest -import requests -from huggingface_hub.utils import is_package_available, logging, reset_sessions +from huggingface_hub.utils import is_package_available, logging from tests.testing_constants import ENDPOINT_PRODUCTION, ENDPOINT_PRODUCTION_URL_SCHEME @@ -161,13 +161,14 @@ def offline(mode=OfflineSimulationMode.CONNECTION_FAILS, timeout=1e-16): Connection errors are created by mocking socket.socket CONNECTION_TIMES_OUT: the connection hangs until it times out. The default timeout value is low (1e-16) to speed up the tests. - Timeout errors are created by mocking requests.request + Timeout errors are created by mocking httpx.request HF_HUB_OFFLINE_SET_TO_1: the HF_HUB_OFFLINE_SET_TO_1 environment variable is set to 1. This makes the http/ftp calls of the library instantly fail and raise an OfflineModeEnabled error. """ import socket - from requests import request as online_request + # Store the original httpx.request to avoid recursion + original_httpx_request = httpx.request def timeout_request(method, url, **kwargs): # Change the url to an invalid url so that the connection hangs @@ -178,13 +179,16 @@ def timeout_request(method, url, **kwargs): ) kwargs["timeout"] = timeout try: - return online_request(method, invalid_url, **kwargs) + return original_httpx_request(method, invalid_url, **kwargs) except Exception as e: # The following changes in the error are just here to make the offline timeout error prettier - e.request.url = url - max_retry_error = e.args[0] - max_retry_error.args = (max_retry_error.args[0].replace("10.255.255.1", f"OfflineMock[{url}]"),) - e.args = (max_retry_error,) + if hasattr(e, "request"): + e.request.url = url + if hasattr(e, "args") and e.args: + max_retry_error = e.args[0] + if hasattr(max_retry_error, "args"): + max_retry_error.args = (max_retry_error.args[0].replace("10.255.255.1", f"OfflineMock[{url}]"),) + e.args = (max_retry_error,) raise def offline_socket(*args, **kwargs): @@ -194,19 +198,37 @@ def offline_socket(*args, **kwargs): # inspired from https://stackoverflow.com/a/18601897 with patch("socket.socket", offline_socket): with patch("huggingface_hub.utils._http.get_session") as get_session_mock: - get_session_mock.return_value = requests.Session() # not an existing one + mock_client = Mock() + + # Mock the request method to raise connection error + def mock_request(*args, **kwargs): + raise httpx.ConnectError("Connection failed") + + # Mock the stream method to raise connection error + def mock_stream(*args, **kwargs): + raise httpx.ConnectError("Connection failed") + + mock_client.request = mock_request + mock_client.stream = mock_stream + get_session_mock.return_value = mock_client yield elif mode is OfflineSimulationMode.CONNECTION_TIMES_OUT: # inspired from https://stackoverflow.com/a/904609 - with patch("requests.request", timeout_request): + with patch("httpx.request", timeout_request): with patch("huggingface_hub.utils._http.get_session") as get_session_mock: - get_session_mock().request = timeout_request + mock_client = Mock() + mock_client.request = timeout_request + + # Mock the stream method to raise timeout + def mock_stream(*args, **kwargs): + raise httpx.ConnectTimeout("Connection timed out") + + mock_client.stream = mock_stream + get_session_mock.return_value = mock_client yield elif mode is OfflineSimulationMode.HF_HUB_OFFLINE_SET_TO_1: with patch("huggingface_hub.constants.HF_HUB_OFFLINE", True): - reset_sessions() yield - reset_sessions() else: raise ValueError("Please use a value from the OfflineSimulationMode enum.") diff --git a/utils/generate_async_inference_client.py b/utils/generate_async_inference_client.py index 61705b51c4..af699affa4 100644 --- a/utils/generate_async_inference_client.py +++ b/utils/generate_async_inference_client.py @@ -42,6 +42,11 @@ def generate_async_client_code(code: str) -> str: # Refactor `.post` method to be async + adapt calls code = _make_inner_post_async(code) code = _await_inner_post_method_call(code) + + # Handle __enter__, __exit__, close + code = _remove_enter_exit_stack(code) + + # Use _async_stream_text_generation_response code = _use_async_streaming_util(code) # Make all tasks-method async @@ -54,15 +59,11 @@ def generate_async_client_code(code: str) -> str: code = _adapt_chat_completion_to_async(code) # Update some docstrings - code = _rename_HTTPError_to_ClientResponseError_in_docstring(code) code = _update_examples_in_public_methods(code) # Adapt /info and /health endpoints code = _adapt_info_and_health_endpoints(code) - # Add _get_client_session - code = _add_get_client_session(code) - # Adapt the proxy client (for client.chat.completions.create) code = _adapt_proxy_client(code) @@ -136,10 +137,13 @@ def _add_imports(code: str) -> str: r"(\nimport .*?\n)", repl=( r"\1" - + "from .._common import _async_yield_from, _import_aiohttp\n" + + "from .._common import _async_yield_from\n" + + "from huggingface_hub.utils import get_async_session\n" + "from typing import AsyncIterable\n" + + "from contextlib import AsyncExitStack\n" + "from typing import Set\n" + "import asyncio\n" + + "import httpx\n" ), string=code, count=1, @@ -163,72 +167,52 @@ def _rename_to_AsyncInferenceClient(code: str) -> str: ASYNC_INNER_POST_CODE = """ - aiohttp = _import_aiohttp() - # TODO: this should be handled in provider helpers directly if request_parameters.task in TASKS_EXPECTING_IMAGES and "Accept" not in request_parameters.headers: request_parameters.headers["Accept"] = "image/png" - # Do not use context manager as we don't want to close the connection immediately when returning - # a stream - session = self._get_client_session(headers=request_parameters.headers) - try: - response = await session.post(request_parameters.url, json=request_parameters.json, data=request_parameters.data, proxy=self.proxies) - response_error_payload = None - if response.status != 200: - try: - response_error_payload = await response.json() # get payload before connection closed - except Exception: - pass - response.raise_for_status() + client = await self._get_async_client() if stream: - return _async_yield_from(session, response) + response = await self.exit_stack.enter_async_context( + client.stream( + "POST", + request_parameters.url, + json=request_parameters.json, + data=request_parameters.data, + headers=request_parameters.headers, + cookies=self.cookies, + timeout=self.timeout, + ) + ) + hf_raise_for_status(response) + return _async_yield_from(client, response) else: - content = await response.read() - await session.close() - return content + response = await client.post( + request_parameters.url, + json=request_parameters.json, + data=request_parameters.data, + headers=request_parameters.headers, + cookies=self.cookies, + timeout=self.timeout, + ) + hf_raise_for_status(response) + return response.content except asyncio.TimeoutError as error: - await session.close() # Convert any `TimeoutError` to a `InferenceTimeoutError` raise InferenceTimeoutError(f"Inference call timed out: {request_parameters.url}") from error # type: ignore - except aiohttp.ClientResponseError as error: - error.response_error_payload = response_error_payload - await session.close() - raise error - except Exception: - await session.close() + except HfHubHTTPError as error: + if error.response.status_code == 422 and request_parameters.task != "unknown": + msg = str(error.args[0]) + if len(error.response.text) > 0: + msg += f"{os.linesep}{error.response.text}{os.linesep}" + error.args = (msg,) + error.args[1:] raise - - async def __aenter__(self): - return self - - async def __aexit__(self, exc_type, exc_value, traceback): - await self.close() - - def __del__(self): - if len(self._sessions) > 0: - warnings.warn( - "Deleting 'AsyncInferenceClient' client but some sessions are still open. " - "This can happen if you've stopped streaming data from the server before the stream was complete. " - "To close the client properly, you must call `await client.close()` " - "or use an async context (e.g. `async with AsyncInferenceClient(): ...`." - ) - - async def close(self): - \"""Close all open sessions. - - By default, 'aiohttp.ClientSession' objects are closed automatically when a call is completed. However, if you - are streaming data from the server and you stop before the stream is complete, you must call this method to - close the session properly. - - Another possibility is to use an async context (e.g. `async with AsyncInferenceClient(): ...`). - \""" - await asyncio.gather(*[session.close() for session in self._sessions.keys()])""" + """ def _make_inner_post_async(code: str) -> str: - # Update AsyncInferenceClient._inner_post() implementation (use aiohttp instead of requests) + # Update AsyncInferenceClient._inner_post() implementation code = re.sub( r""" def[ ]_inner_post\( # definition @@ -243,12 +227,52 @@ def _make_inner_post_async(code: str) -> str: ) # Update `post`'s type annotations code = code.replace(" def _inner_post(", " async def _inner_post(") - return code.replace("Iterable[bytes]", "AsyncIterable[bytes]") + return code + + +ENTER_EXIT_STACK_SYNC_CODE = """ + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.exit_stack.close() + + def close(self): + self.exit_stack.close()""" + +ENTER_EXIT_STACK_ASYNC_CODE = """ + async def __aenter__(self): + return self + async def __aexit__(self, exc_type, exc_value, traceback): + await self.close() -def _rename_HTTPError_to_ClientResponseError_in_docstring(code: str) -> str: - # Update `raises`-part in docstrings - return code.replace("`HTTPError`:", "`aiohttp.ClientResponseError`:") + async def close(self): + \"""Close the client. + + This method is automatically called when using the client as a context manager. + \""" + await self.exit_stack.aclose() + + async def _get_async_client(self): + \"""Get a unique async client for this AsyncInferenceClient instance. + + Returns the same client instance on subsequent calls, ensuring proper + connection reuse and resource management through the exit stack. + \""" + if self._async_client is None: + self._async_client = await self.exit_stack.enter_async_context(get_async_session()) + return self._async_client +""" + + +def _remove_enter_exit_stack(code: str) -> str: + code = code.replace( + "exit_stack = ExitStack()", + "exit_stack = AsyncExitStack()\n self._async_client: Optional[httpx.AsyncClient] = None", + ) + code = code.replace(ENTER_EXIT_STACK_SYNC_CODE, ENTER_EXIT_STACK_ASYNC_CODE) + return code def _make_tasks_methods_async(code: str) -> str: @@ -272,22 +296,7 @@ def _make_tasks_methods_async(code: str) -> str: def _adapt_text_generation_to_async(code: str) -> str: - # Text-generation task has to be handled specifically since it has a recursive call mechanism (to retry on non-tgi - # servers) - - # Catch `aiohttp` error instead of `requests` error - code = code.replace( - """ - except HTTPError as e: - match = MODEL_KWARGS_NOT_USED_REGEX.search(str(e)) - if isinstance(e, BadRequestError) and match: - """, - """ - except _import_aiohttp().ClientResponseError as e: - match = MODEL_KWARGS_NOT_USED_REGEX.search(e.response_error_payload["error"]) - if e.status == 400 and match: - """, - ) + # Text-generation task has to be handled specifically since it has a recursive call mechanism (to retry on non-tgi servers) # Await recursive call code = code.replace( @@ -301,24 +310,8 @@ def _adapt_text_generation_to_async(code: str) -> str: # Update return types: Iterable -> AsyncIterable code = code.replace( - ") -> Iterable[str]:", - ") -> AsyncIterable[str]:", - ) - code = code.replace( - ") -> Union[bytes, Iterable[bytes]]:", - ") -> Union[bytes, AsyncIterable[bytes]]:", - ) - code = code.replace( - ") -> Iterable[TextGenerationStreamOutput]:", - ") -> AsyncIterable[TextGenerationStreamOutput]:", - ) - code = code.replace( - ") -> Union[TextGenerationOutput, Iterable[TextGenerationStreamOutput]]:", - ") -> Union[TextGenerationOutput, AsyncIterable[TextGenerationStreamOutput]]:", - ) - code = code.replace( - ") -> Union[str, TextGenerationOutput, Iterable[str], Iterable[TextGenerationStreamOutput]]:", - ") -> Union[str, TextGenerationOutput, AsyncIterable[str], AsyncIterable[TextGenerationStreamOutput]]:", + "Iterable[", + "AsyncIterable[", ) return code @@ -331,16 +324,6 @@ def _adapt_chat_completion_to_async(code: str) -> str: "text_generation_output = await self.text_generation(", ) - # Update return types: Iterable -> AsyncIterable - code = code.replace( - ") -> Iterable[ChatCompletionStreamOutput]:", - ") -> AsyncIterable[ChatCompletionStreamOutput]:", - ) - code = code.replace( - ") -> Union[ChatCompletionOutput, Iterable[ChatCompletionStreamOutput]]:", - ") -> Union[ChatCompletionOutput, AsyncIterable[ChatCompletionStreamOutput]]:", - ) - return code @@ -395,101 +378,14 @@ def _use_async_streaming_util(code: str) -> str: def _adapt_info_and_health_endpoints(code: str) -> str: - info_sync_snippet = """ - response = get_session().get(url, headers=build_hf_headers(token=self.token)) - hf_raise_for_status(response) - return response.json()""" - - info_async_snippet = """ - async with self._get_client_session(headers=build_hf_headers(token=self.token)) as client: - response = await client.get(url, proxy=self.proxies) - response.raise_for_status() - return await response.json()""" - - code = code.replace(info_sync_snippet, info_async_snippet) - - health_sync_snippet = """ - response = get_session().get(url, headers=build_hf_headers(token=self.token)) - return response.status_code == 200""" - - health_async_snippet = """ - async with self._get_client_session(headers=build_hf_headers(token=self.token)) as client: - response = await client.get(url, proxy=self.proxies) - return response.status == 200""" - - return code.replace(health_sync_snippet, health_async_snippet) - - -def _add_get_client_session(code: str) -> str: - # Add trust_env as parameter - code = _add_before(code, "proxies: Optional[Any] = None,", "trust_env: bool = False,") - code = _add_before(code, "\n self.proxies = proxies\n", "\n self.trust_env = trust_env") - - # Document `trust_env` parameter - code = _add_before( - code, - "\n proxies (`Any`, `optional`):", - """ - trust_env ('bool', 'optional'): - Trust environment settings for proxy configuration if the parameter is `True` (`False` by default).""", - ) - - # insert `_get_client_session` before `get_endpoint_info` method - client_session_code = """ - - def _get_client_session(self, headers: Optional[Dict] = None) -> "ClientSession": - aiohttp = _import_aiohttp() - client_headers = self.headers.copy() - if headers is not None: - client_headers.update(headers) - - # Return a new aiohttp ClientSession with correct settings. - session = aiohttp.ClientSession( - headers=client_headers, - cookies=self.cookies, - timeout=aiohttp.ClientTimeout(self.timeout), - trust_env=self.trust_env, - ) - - # Keep track of sessions to close them later - self._sessions[session] = set() - - # Override the `._request` method to register responses to be closed - session._wrapped_request = session._request + get_url_sync_snippet = """ + response = get_session().get(url, headers=build_hf_headers(token=self.token))""" - async def _request(method, url, **kwargs): - response = await session._wrapped_request(method, url, **kwargs) - self._sessions[session].add(response) - return response + get_url_async_snippet = """ + client = await self._get_async_client() + response = await client.get(url, headers=build_hf_headers(token=self.token))""" - session._request = _request - - # Override the 'close' method to - # 1. close ongoing responses - # 2. deregister the session when closed - session._close = session.close - - async def close_session(): - for response in self._sessions[session]: - response.close() - await session._close() - self._sessions.pop(session, None) - - session.close = close_session - return session - -""" - code = _add_before(code, "\n async def get_endpoint_info(", client_session_code) - - # Add self._sessions attribute in __init__ - code = _add_before( - code, - "\n def __repr__(self):\n", - "\n # Keep track of the sessions to close them properly" - "\n self._sessions: Dict['ClientSession', Set['ClientResponse']] = dict()", - ) - - return code + return code.replace(get_url_sync_snippet, get_url_async_snippet) def _adapt_proxy_client(code: str) -> str: From 63278df9ef81b27877d7d7db43e6b0d5dc8619c4 Mon Sep 17 00:00:00 2001 From: Lucain Date: Wed, 10 Sep 2025 17:17:41 +0200 Subject: [PATCH 02/19] Bump minimal version to Python3.9 (#3343) * Bump minimal version to Python3.9 * use built-in generics * code quality * new batch * yet another btach * fix dataclass_with_extra * fix * keep Type for strict dataclasses * fix test --- .github/workflows/contrib-tests.yml | 4 +- .github/workflows/python-tests.yml | 18 +- .github/workflows/release-conda.yml | 2 +- docs/source/cn/installation.md | 2 +- docs/source/de/guides/integrations.md | 2 +- docs/source/de/installation.md | 2 +- docs/source/en/guides/integrations.md | 2 +- docs/source/en/installation.md | 2 +- docs/source/fr/guides/integrations.md | 14 +- docs/source/fr/installation.md | 2 +- docs/source/hi/installation.md | 2 +- docs/source/ko/guides/integrations.md | 18 +- docs/source/ko/installation.md | 2 +- docs/source/tm/installation.md | 2 +- setup.py | 8 +- src/huggingface_hub/_commit_api.py | 50 +- src/huggingface_hub/_commit_scheduler.py | 14 +- src/huggingface_hub/_inference_endpoints.py | 16 +- src/huggingface_hub/_jobs_api.py | 40 +- src/huggingface_hub/_oauth.py | 16 +- src/huggingface_hub/_snapshot_download.py | 16 +- src/huggingface_hub/_space_api.py | 8 +- src/huggingface_hub/_tensorboard_logger.py | 10 +- src/huggingface_hub/_upload_large_folder.py | 30 +- src/huggingface_hub/_webhooks_payload.py | 6 +- src/huggingface_hub/_webhooks_server.py | 4 +- src/huggingface_hub/cli/_cli_utils.py | 4 +- src/huggingface_hub/cli/auth.py | 4 +- src/huggingface_hub/cli/cache.py | 20 +- src/huggingface_hub/cli/download.py | 8 +- src/huggingface_hub/cli/jobs.py | 22 +- src/huggingface_hub/cli/lfs.py | 8 +- src/huggingface_hub/cli/repo_files.py | 4 +- src/huggingface_hub/cli/upload.py | 8 +- .../cli/upload_large_folder.py | 6 +- src/huggingface_hub/commands/_cli_utils.py | 4 +- src/huggingface_hub/commands/delete_cache.py | 22 +- src/huggingface_hub/commands/download.py | 8 +- src/huggingface_hub/commands/lfs.py | 8 +- src/huggingface_hub/commands/repo_files.py | 4 +- src/huggingface_hub/commands/upload.py | 8 +- .../commands/upload_large_folder.py | 6 +- src/huggingface_hub/commands/user.py | 4 +- src/huggingface_hub/community.py | 10 +- src/huggingface_hub/constants.py | 6 +- src/huggingface_hub/dataclasses.py | 35 +- src/huggingface_hub/fastai_utils.py | 16 +- src/huggingface_hub/file_download.py | 32 +- src/huggingface_hub/hf_api.py | 486 +++++++++--------- src/huggingface_hub/hf_file_system.py | 28 +- src/huggingface_hub/hub_mixin.py | 64 +-- src/huggingface_hub/inference/_client.py | 216 ++++---- src/huggingface_hub/inference/_common.py | 31 +- .../inference/_generated/_async_client.py | 216 ++++---- .../types/automatic_speech_recognition.py | 6 +- .../inference/_generated/types/base.py | 17 +- .../_generated/types/chat_completion.py | 32 +- .../_generated/types/depth_estimation.py | 4 +- .../types/document_question_answering.py | 4 +- .../_generated/types/feature_extraction.py | 4 +- .../inference/_generated/types/fill_mask.py | 4 +- .../_generated/types/sentence_similarity.py | 6 +- .../_generated/types/summarization.py | 4 +- .../types/table_question_answering.py | 8 +- .../_generated/types/text2text_generation.py | 4 +- .../_generated/types/text_generation.py | 20 +- .../_generated/types/text_to_video.py | 4 +- .../_generated/types/token_classification.py | 4 +- .../inference/_generated/types/translation.py | 4 +- .../types/zero_shot_classification.py | 4 +- .../types/zero_shot_image_classification.py | 4 +- .../types/zero_shot_object_detection.py | 4 +- src/huggingface_hub/inference/_mcp/agent.py | 6 +- .../inference/_mcp/constants.py | 3 +- .../inference/_mcp/mcp_client.py | 55 +- src/huggingface_hub/inference/_mcp/types.py | 20 +- src/huggingface_hub/inference/_mcp/utils.py | 8 +- .../inference/_providers/__init__.py | 4 +- .../inference/_providers/_common.py | 48 +- .../inference/_providers/black_forest_labs.py | 12 +- .../inference/_providers/cohere.py | 6 +- .../inference/_providers/fal_ai.py | 50 +- .../inference/_providers/featherless_ai.py | 8 +- .../inference/_providers/fireworks_ai.py | 6 +- .../inference/_providers/hf_inference.py | 26 +- .../inference/_providers/hyperbolic.py | 8 +- .../inference/_providers/nebius.py | 20 +- .../inference/_providers/new_provider.md | 10 +- .../inference/_providers/novita.py | 10 +- .../inference/_providers/nscale.py | 8 +- .../inference/_providers/replicate.py | 30 +- .../inference/_providers/sambanova.py | 12 +- .../inference/_providers/together.py | 14 +- src/huggingface_hub/inference_api.py | 12 +- src/huggingface_hub/keras_mixin.py | 20 +- src/huggingface_hub/lfs.py | 30 +- src/huggingface_hub/repocard.py | 18 +- src/huggingface_hub/repocard_data.py | 112 ++-- src/huggingface_hub/repository.py | 38 +- src/huggingface_hub/serialization/_base.py | 18 +- src/huggingface_hub/serialization/_dduf.py | 14 +- .../serialization/_tensorflow.py | 6 +- src/huggingface_hub/serialization/_torch.py | 56 +- src/huggingface_hub/utils/_auth.py | 10 +- src/huggingface_hub/utils/_cache_manager.py | 62 +-- src/huggingface_hub/utils/_deprecation.py | 2 +- src/huggingface_hub/utils/_dotenv.py | 6 +- src/huggingface_hub/utils/_git_credential.py | 6 +- src/huggingface_hub/utils/_headers.py | 12 +- src/huggingface_hub/utils/_http.py | 26 +- src/huggingface_hub/utils/_pagination.py | 4 +- src/huggingface_hub/utils/_paths.py | 10 +- src/huggingface_hub/utils/_runtime.py | 6 +- src/huggingface_hub/utils/_safetensors.py | 42 +- src/huggingface_hub/utils/_subprocess.py | 18 +- src/huggingface_hub/utils/_telemetry.py | 6 +- src/huggingface_hub/utils/_typing.py | 4 +- src/huggingface_hub/utils/_validators.py | 6 +- src/huggingface_hub/utils/_xet.py | 26 +- .../utils/_xet_progress_reporting.py | 2 +- src/huggingface_hub/utils/insecure_hashlib.py | 12 +- src/huggingface_hub/utils/tqdm.py | 4 +- tests/test_dduf.py | 6 +- tests/test_file_download.py | 4 +- tests/test_hf_api.py | 10 +- tests/test_hub_mixin.py | 8 +- tests/test_hub_mixin_pytorch.py | 4 +- tests/test_inference_client.py | 3 +- tests/test_inference_providers.py | 5 +- tests/test_inference_text_generation.py | 3 +- tests/test_inference_types.py | 5 +- tests/test_serialization.py | 46 +- tests/test_utils_cache.py | 4 +- tests/test_utils_paths.py | 10 +- tests/test_utils_strict_dataclass.py | 48 +- tests/test_xet_download.py | 3 +- tests/test_xet_upload.py | 3 +- tests/testing_utils.py | 6 +- utils/check_all_variable.py | 6 +- utils/check_task_parameters.py | 52 +- utils/generate_inference_types.py | 14 +- 141 files changed, 1444 insertions(+), 1465 deletions(-) diff --git a/.github/workflows/contrib-tests.yml b/.github/workflows/contrib-tests.yml index df663ce975..d294b7b530 100644 --- a/.github/workflows/contrib-tests.yml +++ b/.github/workflows/contrib-tests.yml @@ -26,10 +26,10 @@ jobs: steps: - uses: actions/checkout@v2 - - name: Set up Python 3.8 + - name: Set up Python 3.9 uses: actions/setup-python@v2 with: - python-version: 3.8 + python-version: 3.9 # Install pip - name: Install pip diff --git a/.github/workflows/python-tests.yml b/.github/workflows/python-tests.yml index 11bfcc806f..6c9bacf656 100644 --- a/.github/workflows/python-tests.yml +++ b/.github/workflows/python-tests.yml @@ -21,7 +21,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.8", "3.13"] + python-version: ["3.9", "3.13"] test_name: [ "Repository only", @@ -30,18 +30,18 @@ jobs: "Xet only" ] include: - - python-version: "3.13" # LFS not ran on 3.8 + - python-version: "3.13" # LFS not ran on 3.9 test_name: "lfs" - - python-version: "3.8" + - python-version: "3.9" test_name: "fastai" - python-version: "3.10" # fastai not supported on 3.12 and 3.11 -> test it on 3.10 test_name: "fastai" - - python-version: "3.8" + - python-version: "3.9" test_name: "tensorflow" - python-version: "3.10" # tensorflow not supported on 3.12 -> test it on 3.10 test_name: "tensorflow" - - python-version: "3.8" # test torch~=1.11 on python 3.8 only. - test_name: "Python 3.8, torch_1.11" + - python-version: "3.9" # test torch~=1.11 on python 3.9 only. + test_name: "Python 3.9, torch_1.11" - python-version: "3.12" # test torch latest on python 3.12 only. test_name: "torch_latest" steps: @@ -84,7 +84,7 @@ jobs: uv pip install --upgrade torch ;; - "Python 3.8, torch_1.11") + "Python 3.9, torch_1.11") uv pip install "huggingface_hub[torch] @ ." uv pip install torch~=1.11 ;; @@ -147,7 +147,7 @@ jobs: eval "$PYTEST ../tests/test_serialization.py" ;; - "Python 3.8, torch_1.11" | torch_latest) + "Python 3.9, torch_1.11" | torch_latest) eval "$PYTEST ../tests/test_hub_mixin*" eval "$PYTEST ../tests/test_serialization.py" ;; @@ -178,7 +178,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.8", "3.11"] + python-version: ["3.9", "3.11"] test_name: ["Everything else", "Xet only"] steps: diff --git a/.github/workflows/release-conda.yml b/.github/workflows/release-conda.yml index 135d988809..b6ead02950 100644 --- a/.github/workflows/release-conda.yml +++ b/.github/workflows/release-conda.yml @@ -26,7 +26,7 @@ jobs: with: auto-update-conda: true auto-activate-base: false - python-version: 3.8 + python-version: 3.9 activate-environment: "build-hub" - name: Setup conda env diff --git a/docs/source/cn/installation.md b/docs/source/cn/installation.md index c800b4b173..ec899a2305 100644 --- a/docs/source/cn/installation.md +++ b/docs/source/cn/installation.md @@ -6,7 +6,7 @@ rendered properly in your Markdown viewer. 在开始之前,您需要通过安装适当的软件包来设置您的环境 -huggingface_hub 在 Python 3.8 或更高版本上进行了测试,可以保证在这些版本上正常运行。如果您使用的是 Python 3.7 或更低版本,可能会出现兼容性问题 +huggingface_hub 在 Python 3.9 或更高版本上进行了测试,可以保证在这些版本上正常运行。如果您使用的是 Python 3.7 或更低版本,可能会出现兼容性问题 ## 使用 pip 安装 diff --git a/docs/source/de/guides/integrations.md b/docs/source/de/guides/integrations.md index 06384c80da..3d792c3b5e 100644 --- a/docs/source/de/guides/integrations.md +++ b/docs/source/de/guides/integrations.md @@ -202,7 +202,7 @@ class PyTorchModelHubMixin(ModelHubMixin): revision: str, cache_dir: str, force_download: bool, - proxies: Optional[Dict], + proxies: Optional[dict], resume_download: bool, local_files_only: bool, token: Union[str, bool, None], diff --git a/docs/source/de/installation.md b/docs/source/de/installation.md index 3ba965bd4b..4c2a907f04 100644 --- a/docs/source/de/installation.md +++ b/docs/source/de/installation.md @@ -6,7 +6,7 @@ rendered properly in your Markdown viewer. Bevor Sie beginnen, müssen Sie Ihre Umgebung vorbereiten, indem Sie die entsprechenden Pakete installieren. -`huggingface_hub` wurde für **Python 3.8+** getestet. +`huggingface_hub` wurde für **Python 3.9+** getestet. ## Installation mit pip diff --git a/docs/source/en/guides/integrations.md b/docs/source/en/guides/integrations.md index e5ac9aaa87..cc4431923d 100644 --- a/docs/source/en/guides/integrations.md +++ b/docs/source/en/guides/integrations.md @@ -244,7 +244,7 @@ class PyTorchModelHubMixin(ModelHubMixin): revision: str, cache_dir: str, force_download: bool, - proxies: Optional[Dict], + proxies: Optional[dict], resume_download: bool, local_files_only: bool, token: Union[str, bool, None], diff --git a/docs/source/en/installation.md b/docs/source/en/installation.md index 9af8a32676..7d86b715d0 100644 --- a/docs/source/en/installation.md +++ b/docs/source/en/installation.md @@ -6,7 +6,7 @@ rendered properly in your Markdown viewer. Before you start, you will need to setup your environment by installing the appropriate packages. -`huggingface_hub` is tested on **Python 3.8+**. +`huggingface_hub` is tested on **Python 3.9+**. ## Install with pip diff --git a/docs/source/fr/guides/integrations.md b/docs/source/fr/guides/integrations.md index 5a9736667f..f2c81a3d17 100644 --- a/docs/source/fr/guides/integrations.md +++ b/docs/source/fr/guides/integrations.md @@ -223,7 +223,7 @@ class PyTorchModelHubMixin(ModelHubMixin): revision: str, cache_dir: str, force_download: bool, - proxies: Optional[Dict], + proxies: Optional[dict], resume_download: bool, local_files_only: bool, token: Union[str, bool, None], @@ -266,9 +266,9 @@ est ici pour vous donner des indications et des idées sur comment gérer l'int n'hésitez pas à nous contacter si vous avez une question ! -| Intégration | Utilisant des helpers | Utilisant [`ModelHubMixin`] | -|:---:|:---:|:---:| -| Expérience utilisateur | `model = load_from_hub(...)`
`push_to_hub(model, ...)` | `model = MyModel.from_pretrained(...)`
`model.push_to_hub(...)` | -| Flexible | Très flexible.
Vous controllez complètement l'implémentation. | Moins flexible.
Votre framework doit avoir une classe de modèle. | -| Maintenance | Plus de maintenance pour ajouter du support pour la configuration, et de nouvelles fonctionnalités. Peut aussi nécessiter de fixx des problèmes signalés par les utilisateurs.| Moins de maintenance vu que la plupart des intégrations avec le Hub sont implémentés dans `huggingface_hub` | -| Documentation / Anotation de type| A écrire à la main | Géré partiellement par `huggingface_hub`. | +| Intégration | Utilisant des helpers | Utilisant [`ModelHubMixin`] | +| :-------------------------------: | :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | :---------------------------------------------------------------------------------------------------------: | +| Expérience utilisateur | `model = load_from_hub(...)`
`push_to_hub(model, ...)` | `model = MyModel.from_pretrained(...)`
`model.push_to_hub(...)` | +| Flexible | Très flexible.
Vous controllez complètement l'implémentation. | Moins flexible.
Votre framework doit avoir une classe de modèle. | +| Maintenance | Plus de maintenance pour ajouter du support pour la configuration, et de nouvelles fonctionnalités. Peut aussi nécessiter de fixx des problèmes signalés par les utilisateurs. | Moins de maintenance vu que la plupart des intégrations avec le Hub sont implémentés dans `huggingface_hub` | +| Documentation / Anotation de type | A écrire à la main | Géré partiellement par `huggingface_hub`. | diff --git a/docs/source/fr/installation.md b/docs/source/fr/installation.md index eb4b2ee9b4..6e0f41ee6e 100644 --- a/docs/source/fr/installation.md +++ b/docs/source/fr/installation.md @@ -7,7 +7,7 @@ rendered properly in your Markdown viewer. Avant de commencer, vous allez avoir besoin de préparer votre environnement en installant les packages appropriés. -`huggingface_hub` est testée sur **Python 3.8+**. +`huggingface_hub` est testée sur **Python 3.9+**. ## Installation avec pip diff --git a/docs/source/hi/installation.md b/docs/source/hi/installation.md index 1659e85fd7..c5974a32f7 100644 --- a/docs/source/hi/installation.md +++ b/docs/source/hi/installation.md @@ -6,7 +6,7 @@ rendered properly in your Markdown viewer. आरंभ करने से पहले, आपको उपयुक्त पैकेज स्थापित करके अपना परिवेश सेटअप करना होगा। -`huggingface_hub` का परीक्षण **Python 3.8+** पर किया गया है। +`huggingface_hub` का परीक्षण **Python 3.9+** पर किया गया है। ## पिप के साथ स्थापित करें diff --git a/docs/source/ko/guides/integrations.md b/docs/source/ko/guides/integrations.md index f0946bc298..a3ff1750f6 100644 --- a/docs/source/ko/guides/integrations.md +++ b/docs/source/ko/guides/integrations.md @@ -211,7 +211,7 @@ class PyTorchModelHubMixin(ModelHubMixin): revision: str, cache_dir: str, force_download: bool, - proxies: Optional[Dict], + proxies: Optional[dict], resume_download: bool, local_files_only: bool, token: Union[str, bool, None], @@ -393,11 +393,11 @@ class VoiceCraft( 두 가지 접근 방법에 대한 장단점을 간단히 정리해보겠습니다. 아래 표는 단순히 예시일 뿐입니다. 각자 다른 프레임워크에는 고려해야 할 특정 사항이 있을 수 있습니다. 이 가이드는 통합을 다루는 아이디어와 지침을 제공하기 위한 것입니다. 언제든지 궁금한 점이 있으면 문의해 주세요! -| 통합 | helpers 사용 시 | [`ModelHubMixin`] 사용 시 | -|:---:|:---:|:---:| -| 사용자 경험 | `model = load_from_hub(...)`
`push_to_hub(model, ...)` | `model = MyModel.from_pretrained(...)`
`model.push_to_hub(...)` | -| 유연성 | 매우 유연합니다.
구현을 완전히 제어합니다. | 유연성이 떨어집니다.
프레임워크에는 모델 클래스가 있어야 합니다. | -| 유지 관리 | 구성 및 새로운 기능에 대한 지원을 추가하기 위한 유지 관리가 더 필요합니다. 사용자가 보고한 문제를 해결해야할 수도 있습니다. | Hub와의 대부분의 상호 작용이 `huggingface_hub`에서 구현되므로 유지 관리가 줄어듭니다. | -| 문서화 / 타입 주석 | 수동으로 작성해야 합니다. | `huggingface_hub`에서 부분적으로 처리됩니다. | -| 다운로드 횟수 표시기 | 수동으로 처리해야 합니다. | 클래스에 `config` 속성이 있다면 기본적으로 활성화됩니다. | -| 모델 카드 | 수동으로 처리해야 합니다. | library_name, tags 등을 활용하여 기본적으로 생성됩니다. | +| 통합 | helpers 사용 시 | [`ModelHubMixin`] 사용 시 | +| :------------------: | :-------------------------------------------------------------------------------------------------------------------------: | :-----------------------------------------------------------------------------------: | +| 사용자 경험 | `model = load_from_hub(...)`
`push_to_hub(model, ...)` | `model = MyModel.from_pretrained(...)`
`model.push_to_hub(...)` | +| 유연성 | 매우 유연합니다.
구현을 완전히 제어합니다. | 유연성이 떨어집니다.
프레임워크에는 모델 클래스가 있어야 합니다. | +| 유지 관리 | 구성 및 새로운 기능에 대한 지원을 추가하기 위한 유지 관리가 더 필요합니다. 사용자가 보고한 문제를 해결해야할 수도 있습니다. | Hub와의 대부분의 상호 작용이 `huggingface_hub`에서 구현되므로 유지 관리가 줄어듭니다. | +| 문서화 / 타입 주석 | 수동으로 작성해야 합니다. | `huggingface_hub`에서 부분적으로 처리됩니다. | +| 다운로드 횟수 표시기 | 수동으로 처리해야 합니다. | 클래스에 `config` 속성이 있다면 기본적으로 활성화됩니다. | +| 모델 카드 | 수동으로 처리해야 합니다. | library_name, tags 등을 활용하여 기본적으로 생성됩니다. | diff --git a/docs/source/ko/installation.md b/docs/source/ko/installation.md index 720346b1a1..b222bef630 100644 --- a/docs/source/ko/installation.md +++ b/docs/source/ko/installation.md @@ -6,7 +6,7 @@ rendered properly in your Markdown viewer. 시작하기 전에 적절한 패키지를 설치하여 환경을 설정해야 합니다. -`huggingface_hub`는 **Python 3.8+**에서 테스트되었습니다. +`huggingface_hub`는 **Python 3.9+**에서 테스트되었습니다. ## pip로 설치하기 [[install-with-pip]] diff --git a/docs/source/tm/installation.md b/docs/source/tm/installation.md index f16ac74667..28134ed5b7 100644 --- a/docs/source/tm/installation.md +++ b/docs/source/tm/installation.md @@ -2,7 +2,7 @@ நீங்கள் தொடங்குவதற்கு முன், தகுந்த தொகுப்புகளை நிறுவுவதன் மூலம் உங்கள் சூழலை அமைக்க வேண்டும். -`huggingface_hub` **Python 3.8+** மின்பொருள்களில் சோதிக்கப்பட்டுள்ளது. +`huggingface_hub` **Python 3.9+** மின்பொருள்களில் சோதிக்கப்பட்டுள்ளது. ### பிப் மூலம் நிறுவு diff --git a/setup.py b/setup.py index 3fd35880fa..ec5ebfbb39 100644 --- a/setup.py +++ b/setup.py @@ -89,7 +89,7 @@ def get_version() -> str: "soundfile", "Pillow", "gradio>=4.0.0", # to test webhooks # pin to avoid issue on Python3.12 - "requests", # for gradio + "requests", # for gradio "numpy", # for embeddings "fastapi", # To build the documentation ] @@ -108,8 +108,7 @@ def get_version() -> str: extras["quality"] = [ "ruff>=0.9.0", - "mypy>=1.14.1,<1.15.0; python_version=='3.8'", - "mypy==1.15.0; python_version>='3.9'", + "mypy==1.15.0", "libcst>=1.4.0", "ty", ] @@ -140,7 +139,7 @@ def get_version() -> str: ], "fsspec.specs": "hf=huggingface_hub.HfFileSystem", }, - python_requires=">=3.8.0", + python_requires=">=3.9.0", install_requires=install_requires, classifiers=[ "Intended Audience :: Developers", @@ -150,7 +149,6 @@ def get_version() -> str: "Operating System :: OS Independent", "Programming Language :: Python :: 3", "Programming Language :: Python :: 3 :: Only", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", diff --git a/src/huggingface_hub/_commit_api.py b/src/huggingface_hub/_commit_api.py index 58e082b307..70fe8f78c3 100644 --- a/src/huggingface_hub/_commit_api.py +++ b/src/huggingface_hub/_commit_api.py @@ -11,7 +11,7 @@ from dataclasses import dataclass, field from itertools import groupby from pathlib import Path, PurePosixPath -from typing import TYPE_CHECKING, Any, BinaryIO, Dict, Iterable, Iterator, List, Literal, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, BinaryIO, Iterable, Iterator, Literal, Optional, Union from tqdm.contrib.concurrent import thread_map @@ -306,7 +306,7 @@ def _validate_path_in_repo(path_in_repo: str) -> str: CommitOperation = Union[CommitOperationAdd, CommitOperationCopy, CommitOperationDelete] -def _warn_on_overwriting_operations(operations: List[CommitOperation]) -> None: +def _warn_on_overwriting_operations(operations: list[CommitOperation]) -> None: """ Warn user when a list of operations is expected to overwrite itself in a single commit. @@ -321,7 +321,7 @@ def _warn_on_overwriting_operations(operations: List[CommitOperation]) -> None: delete before upload) but can happen if a user deletes an entire folder and then add new files to it. """ - nb_additions_per_path: Dict[str, int] = defaultdict(int) + nb_additions_per_path: dict[str, int] = defaultdict(int) for operation in operations: path_in_repo = operation.path_in_repo if isinstance(operation, CommitOperationAdd): @@ -355,10 +355,10 @@ def _warn_on_overwriting_operations(operations: List[CommitOperation]) -> None: @validate_hf_hub_args def _upload_lfs_files( *, - additions: List[CommitOperationAdd], + additions: list[CommitOperationAdd], repo_type: str, repo_id: str, - headers: Dict[str, str], + headers: dict[str, str], endpoint: Optional[str] = None, num_threads: int = 5, revision: Optional[str] = None, @@ -377,7 +377,7 @@ def _upload_lfs_files( repo_id (`str`): A namespace (user or an organization) and a repo name separated by a `/`. - headers (`Dict[str, str]`): + headers (`dict[str, str]`): Headers to use for the request, including authorization headers and user agent. num_threads (`int`, *optional*): The number of concurrent threads to use when uploading. Defaults to 5. @@ -395,7 +395,7 @@ def _upload_lfs_files( # Step 1: retrieve upload instructions from the LFS batch endpoint. # Upload instructions are retrieved by chunk of 256 files to avoid reaching # the payload limit. - batch_actions: List[Dict] = [] + batch_actions: list[dict] = [] for chunk in chunk_iterable(additions, chunk_size=UPLOAD_BATCH_MAX_NUM_FILES): batch_actions_chunk, batch_errors_chunk = post_lfs_batch_info( upload_infos=[op.upload_info for op in chunk], @@ -466,10 +466,10 @@ def _wrapped_lfs_upload(batch_action) -> None: @validate_hf_hub_args def _upload_xet_files( *, - additions: List[CommitOperationAdd], + additions: list[CommitOperationAdd], repo_type: str, repo_id: str, - headers: Dict[str, str], + headers: dict[str, str], endpoint: Optional[str] = None, revision: Optional[str] = None, create_pr: Optional[bool] = None, @@ -486,7 +486,7 @@ def _upload_xet_files( repo_id (`str`): A namespace (user or an organization) and a repo name separated by a `/`. - headers (`Dict[str, str]`): + headers (`dict[str, str]`): Headers to use for the request, including authorization headers and user agent. endpoint: (`str`, *optional*): The endpoint to use for the xetcas service. Defaults to `constants.ENDPOINT`. @@ -555,7 +555,7 @@ def _upload_xet_files( xet_endpoint = xet_connection_info.endpoint access_token_info = (xet_connection_info.access_token, xet_connection_info.expiration_unix_epoch) - def token_refresher() -> Tuple[str, int]: + def token_refresher() -> tuple[str, int]: new_xet_connection = fetch_xet_connection_info_from_repo_info( token_type=XetTokenType.WRITE, repo_id=repo_id, @@ -628,7 +628,7 @@ def _fetch_upload_modes( additions: Iterable[CommitOperationAdd], repo_type: str, repo_id: str, - headers: Dict[str, str], + headers: dict[str, str], revision: str, endpoint: Optional[str] = None, create_pr: bool = False, @@ -647,7 +647,7 @@ def _fetch_upload_modes( repo_id (`str`): A namespace (user or an organization) and a repo name separated by a `/`. - headers (`Dict[str, str]`): + headers (`dict[str, str]`): Headers to use for the request, including authorization headers and user agent. revision (`str`): The git revision to upload the files to. Can be any valid git revision. @@ -665,12 +665,12 @@ def _fetch_upload_modes( endpoint = endpoint if endpoint is not None else constants.ENDPOINT # Fetch upload mode (LFS or regular) chunk by chunk. - upload_modes: Dict[str, UploadMode] = {} - should_ignore_info: Dict[str, bool] = {} - oid_info: Dict[str, Optional[str]] = {} + upload_modes: dict[str, UploadMode] = {} + should_ignore_info: dict[str, bool] = {} + oid_info: dict[str, Optional[str]] = {} for chunk in chunk_iterable(additions, 256): - payload: Dict = { + payload: dict = { "files": [ { "path": op.path_in_repo, @@ -713,10 +713,10 @@ def _fetch_files_to_copy( copies: Iterable[CommitOperationCopy], repo_type: str, repo_id: str, - headers: Dict[str, str], + headers: dict[str, str], revision: str, endpoint: Optional[str] = None, -) -> Dict[Tuple[str, Optional[str]], Union["RepoFile", bytes]]: +) -> dict[tuple[str, Optional[str]], Union["RepoFile", bytes]]: """ Fetch information about the files to copy. @@ -732,12 +732,12 @@ def _fetch_files_to_copy( repo_id (`str`): A namespace (user or an organization) and a repo name separated by a `/`. - headers (`Dict[str, str]`): + headers (`dict[str, str]`): Headers to use for the request, including authorization headers and user agent. revision (`str`): The git revision to upload the files to. Can be any valid git revision. - Returns: `Dict[Tuple[str, Optional[str]], Union[RepoFile, bytes]]]` + Returns: `dict[tuple[str, Optional[str]], Union[RepoFile, bytes]]]` Key is the file path and revision of the file to copy. Value is the raw content as bytes (for regular files) or the file information as a RepoFile (for LFS files). @@ -750,9 +750,9 @@ def _fetch_files_to_copy( from .hf_api import HfApi, RepoFolder hf_api = HfApi(endpoint=endpoint, headers=headers) - files_to_copy: Dict[Tuple[str, Optional[str]], Union["RepoFile", bytes]] = {} + files_to_copy: dict[tuple[str, Optional[str]], Union["RepoFile", bytes]] = {} # Store (path, revision) -> oid mapping - oid_info: Dict[Tuple[str, Optional[str]], Optional[str]] = {} + oid_info: dict[tuple[str, Optional[str]], Optional[str]] = {} # 1. Fetch OIDs for destination paths in batches. dest_paths = [op.path_in_repo for op in copies] for offset in range(0, len(dest_paths), FETCH_LFS_BATCH_SIZE): @@ -812,11 +812,11 @@ def _fetch_files_to_copy( def _prepare_commit_payload( operations: Iterable[CommitOperation], - files_to_copy: Dict[Tuple[str, Optional[str]], Union["RepoFile", bytes]], + files_to_copy: dict[tuple[str, Optional[str]], Union["RepoFile", bytes]], commit_message: str, commit_description: Optional[str] = None, parent_commit: Optional[str] = None, -) -> Iterable[Dict[str, Any]]: +) -> Iterable[dict[str, Any]]: """ Builds the payload to POST to the `/commit` API of the Hub. diff --git a/src/huggingface_hub/_commit_scheduler.py b/src/huggingface_hub/_commit_scheduler.py index f28180fd68..3aee068fd8 100644 --- a/src/huggingface_hub/_commit_scheduler.py +++ b/src/huggingface_hub/_commit_scheduler.py @@ -7,7 +7,7 @@ from io import SEEK_END, SEEK_SET, BytesIO from pathlib import Path from threading import Lock, Thread -from typing import Dict, List, Optional, Union +from typing import Optional, Union from .hf_api import DEFAULT_IGNORE_PATTERNS, CommitInfo, CommitOperationAdd, HfApi from .utils import filter_repo_objects @@ -53,9 +53,9 @@ class CommitScheduler: Whether to make the repo private. If `None` (default), the repo will be public unless the organization's default is private. This value is ignored if the repo already exists. token (`str`, *optional*): The token to use to commit to the repo. Defaults to the token saved on the machine. - allow_patterns (`List[str]` or `str`, *optional*): + allow_patterns (`list[str]` or `str`, *optional*): If provided, only files matching at least one pattern are uploaded. - ignore_patterns (`List[str]` or `str`, *optional*): + ignore_patterns (`list[str]` or `str`, *optional*): If provided, files matching any of the patterns are not uploaded. squash_history (`bool`, *optional*): Whether to squash the history of the repo after each commit. Defaults to `False`. Squashing commits is @@ -108,8 +108,8 @@ def __init__( revision: Optional[str] = None, private: Optional[bool] = None, token: Optional[str] = None, - allow_patterns: Optional[Union[List[str], str]] = None, - ignore_patterns: Optional[Union[List[str], str]] = None, + allow_patterns: Optional[Union[list[str], str]] = None, + ignore_patterns: Optional[Union[list[str], str]] = None, squash_history: bool = False, hf_api: Optional["HfApi"] = None, ) -> None: @@ -138,7 +138,7 @@ def __init__( self.token = token # Keep track of already uploaded files - self.last_uploaded: Dict[Path, float] = {} # key is local path, value is timestamp + self.last_uploaded: dict[Path, float] = {} # key is local path, value is timestamp # Scheduler if not every > 0: @@ -232,7 +232,7 @@ def push_to_hub(self) -> Optional[CommitInfo]: prefix = f"{self.path_in_repo.strip('/')}/" if self.path_in_repo else "" # Filter with pattern + filter out unchanged files + retrieve current file size - files_to_upload: List[_FileToUpload] = [] + files_to_upload: list[_FileToUpload] = [] for relpath in filter_repo_objects( relpath_to_abspath.keys(), allow_patterns=self.allow_patterns, ignore_patterns=self.ignore_patterns ): diff --git a/src/huggingface_hub/_inference_endpoints.py b/src/huggingface_hub/_inference_endpoints.py index 37f772bfbe..4422cac7c3 100644 --- a/src/huggingface_hub/_inference_endpoints.py +++ b/src/huggingface_hub/_inference_endpoints.py @@ -2,7 +2,7 @@ from dataclasses import dataclass, field from datetime import datetime from enum import Enum -from typing import TYPE_CHECKING, Dict, Optional, Union +from typing import TYPE_CHECKING, Optional, Union from huggingface_hub.errors import InferenceEndpointError, InferenceEndpointTimeoutError @@ -62,7 +62,7 @@ class InferenceEndpoint: The timestamp of the last update of the Inference Endpoint. type ([`InferenceEndpointType`]): The type of the Inference Endpoint (public, protected, private). - raw (`Dict`): + raw (`dict`): The raw dictionary data returned from the API. token (`str` or `bool`, *optional*): Authentication token for the Inference Endpoint, if set when requesting the API. Will default to the @@ -112,7 +112,7 @@ class InferenceEndpoint: type: InferenceEndpointType = field(repr=False, init=False) # Raw dict from the API - raw: Dict = field(repr=False) + raw: dict = field(repr=False) # Internal fields _token: Union[str, bool, None] = field(repr=False, compare=False) @@ -120,7 +120,7 @@ class InferenceEndpoint: @classmethod def from_raw( - cls, raw: Dict, namespace: str, token: Union[str, bool, None] = None, api: Optional["HfApi"] = None + cls, raw: dict, namespace: str, token: Union[str, bool, None] = None, api: Optional["HfApi"] = None ) -> "InferenceEndpoint": """Initialize object from raw dictionary.""" if api is None: @@ -260,8 +260,8 @@ def update( framework: Optional[str] = None, revision: Optional[str] = None, task: Optional[str] = None, - custom_image: Optional[Dict] = None, - secrets: Optional[Dict[str, str]] = None, + custom_image: Optional[dict] = None, + secrets: Optional[dict[str, str]] = None, ) -> "InferenceEndpoint": """Update the Inference Endpoint. @@ -293,10 +293,10 @@ def update( The specific model revision to deploy on the Inference Endpoint (e.g. `"6c0e6080953db56375760c0471a8c5f2929baf11"`). task (`str`, *optional*): The task on which to deploy the model (e.g. `"text-classification"`). - custom_image (`Dict`, *optional*): + custom_image (`dict`, *optional*): A custom Docker image to use for the Inference Endpoint. This is useful if you want to deploy an Inference Endpoint running on the `text-generation-inference` (TGI) framework (see examples). - secrets (`Dict[str, str]`, *optional*): + secrets (`dict[str, str]`, *optional*): Secret values to inject in the container environment. Returns: [`InferenceEndpoint`]: the same Inference Endpoint, mutated in place with the latest data. diff --git a/src/huggingface_hub/_jobs_api.py b/src/huggingface_hub/_jobs_api.py index 623fd9dc9d..c85324ce1c 100644 --- a/src/huggingface_hub/_jobs_api.py +++ b/src/huggingface_hub/_jobs_api.py @@ -15,7 +15,7 @@ from dataclasses import dataclass from datetime import datetime from enum import Enum -from typing import Any, Dict, List, Optional, Union +from typing import Any, Optional, Union from huggingface_hub import constants from huggingface_hub._space_api import SpaceHardware @@ -71,13 +71,13 @@ class JobInfo: space_id (`str` or `None`): The Docker image from Hugging Face Spaces used for the Job. Can be None if docker_image is present instead. - command (`List[str]` or `None`): + command (`list[str]` or `None`): Command of the Job, e.g. `["python", "-c", "print('hello world')"]` - arguments (`List[str]` or `None`): + arguments (`list[str]` or `None`): Arguments passed to the command - environment (`Dict[str]` or `None`): + environment (`dict[str]` or `None`): Environment variables of the Job as a dictionary. - secrets (`Dict[str]` or `None`): + secrets (`dict[str]` or `None`): Secret environment variables of the Job (encrypted). flavor (`str` or `None`): Flavor for the hardware, as in Hugging Face Spaces. See [`SpaceHardware`] for possible values. @@ -111,10 +111,10 @@ class JobInfo: created_at: Optional[datetime] docker_image: Optional[str] space_id: Optional[str] - command: Optional[List[str]] - arguments: Optional[List[str]] - environment: Optional[Dict[str, Any]] - secrets: Optional[Dict[str, Any]] + command: Optional[list[str]] + arguments: Optional[list[str]] + environment: Optional[dict[str, Any]] + secrets: Optional[dict[str, Any]] flavor: Optional[SpaceHardware] status: JobStatus owner: JobOwner @@ -148,13 +148,13 @@ def __init__(self, **kwargs) -> None: class JobSpec: docker_image: Optional[str] space_id: Optional[str] - command: Optional[List[str]] - arguments: Optional[List[str]] - environment: Optional[Dict[str, Any]] - secrets: Optional[Dict[str, Any]] + command: Optional[list[str]] + arguments: Optional[list[str]] + environment: Optional[dict[str, Any]] + secrets: Optional[dict[str, Any]] flavor: Optional[SpaceHardware] timeout: Optional[int] - tags: Optional[List[str]] + tags: Optional[list[str]] arch: Optional[str] def __init__(self, **kwargs) -> None: @@ -202,7 +202,7 @@ class ScheduledJobInfo: Scheduled Job ID. created_at (`datetime` or `None`): When the scheduled Job was created. - tags (`List[str]` or `None`): + tags (`list[str]` or `None`): The tags of the scheduled Job. schedule (`str` or `None`): One of "@annually", "@yearly", "@monthly", "@weekly", "@daily", "@hourly", or a @@ -263,14 +263,14 @@ def __init__(self, **kwargs) -> None: def _create_job_spec( *, image: str, - command: List[str], - env: Optional[Dict[str, Any]], - secrets: Optional[Dict[str, Any]], + command: list[str], + env: Optional[dict[str, Any]], + secrets: Optional[dict[str, Any]], flavor: Optional[SpaceHardware], timeout: Optional[Union[int, float, str]], -) -> Dict[str, Any]: +) -> dict[str, Any]: # prepare job spec to send to HF Jobs API - job_spec: Dict[str, Any] = { + job_spec: dict[str, Any] = { "command": command, "arguments": [], "environment": env or {}, diff --git a/src/huggingface_hub/_oauth.py b/src/huggingface_hub/_oauth.py index 9f8eb60796..7bdfa6a058 100644 --- a/src/huggingface_hub/_oauth.py +++ b/src/huggingface_hub/_oauth.py @@ -6,7 +6,7 @@ import urllib.parse import warnings from dataclasses import dataclass -from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union +from typing import TYPE_CHECKING, Literal, Optional, Union from . import constants from .hf_api import whoami @@ -39,7 +39,7 @@ class OAuthOrgInfo: Whether the org has a payment method set up. Hugging Face field. role_in_org (`Optional[str]`, *optional*): The user's role in the org. Hugging Face field. - security_restrictions (`Optional[List[Literal["ip", "token-policy", "mfa", "sso"]]]`, *optional*): + security_restrictions (`Optional[list[Literal["ip", "token-policy", "mfa", "sso"]]]`, *optional*): Array of security restrictions that the user hasn't completed for this org. Possible values: "ip", "token-policy", "mfa", "sso". Hugging Face field. """ @@ -50,7 +50,7 @@ class OAuthOrgInfo: is_enterprise: bool can_pay: Optional[bool] = None role_in_org: Optional[str] = None - security_restrictions: Optional[List[Literal["ip", "token-policy", "mfa", "sso"]]] = None + security_restrictions: Optional[list[Literal["ip", "token-policy", "mfa", "sso"]]] = None @dataclass @@ -79,7 +79,7 @@ class OAuthUserInfo: Whether the user is a pro user. Hugging Face field. can_pay (`Optional[bool]`, *optional*): Whether the user has a payment method set up. Hugging Face field. - orgs (`Optional[List[OrgInfo]]`, *optional*): + orgs (`Optional[list[OrgInfo]]`, *optional*): List of organizations the user is part of. Hugging Face field. """ @@ -93,7 +93,7 @@ class OAuthUserInfo: website: Optional[str] is_pro: bool can_pay: Optional[bool] - orgs: Optional[List[OAuthOrgInfo]] + orgs: Optional[list[OAuthOrgInfo]] @dataclass @@ -306,7 +306,7 @@ async def oauth_redirect_callback(request: fastapi.Request) -> RedirectResponse: target_url = request.query_params.get("_target_url") # Build redirect URI with the same query params as before and bump nb_redirects count - query_params: Dict[str, Union[int, str]] = {"_nb_redirects": nb_redirects + 1} + query_params: dict[str, Union[int, str]] = {"_nb_redirects": nb_redirects + 1} if target_url: query_params["_target_url"] = target_url @@ -406,7 +406,7 @@ def _get_redirect_target(request: "fastapi.Request", default_target: str = "/") return request.query_params.get("_target_url", default_target) -def _get_mocked_oauth_info() -> Dict: +def _get_mocked_oauth_info() -> dict: token = get_token() if token is None: raise ValueError( @@ -449,7 +449,7 @@ def _get_mocked_oauth_info() -> Dict: } -def _get_oauth_uris(route_prefix: str = "/") -> Tuple[str, str, str]: +def _get_oauth_uris(route_prefix: str = "/") -> tuple[str, str, str]: route_prefix = route_prefix.strip("/") if route_prefix: route_prefix = f"/{route_prefix}" diff --git a/src/huggingface_hub/_snapshot_download.py b/src/huggingface_hub/_snapshot_download.py index aa65d561da..200ed7cc2e 100644 --- a/src/huggingface_hub/_snapshot_download.py +++ b/src/huggingface_hub/_snapshot_download.py @@ -1,6 +1,6 @@ import os from pathlib import Path -from typing import Dict, Iterable, List, Literal, Optional, Type, Union +from typing import Iterable, Literal, Optional, Union import httpx from tqdm.auto import tqdm as base_tqdm @@ -35,16 +35,16 @@ def snapshot_download( local_dir: Union[str, Path, None] = None, library_name: Optional[str] = None, library_version: Optional[str] = None, - user_agent: Optional[Union[Dict, str]] = None, + user_agent: Optional[Union[dict, str]] = None, etag_timeout: float = constants.DEFAULT_ETAG_TIMEOUT, force_download: bool = False, token: Optional[Union[bool, str]] = None, local_files_only: bool = False, - allow_patterns: Optional[Union[List[str], str]] = None, - ignore_patterns: Optional[Union[List[str], str]] = None, + allow_patterns: Optional[Union[list[str], str]] = None, + ignore_patterns: Optional[Union[list[str], str]] = None, max_workers: int = 8, - tqdm_class: Optional[Type[base_tqdm]] = None, - headers: Optional[Dict[str, str]] = None, + tqdm_class: Optional[type[base_tqdm]] = None, + headers: Optional[dict[str, str]] = None, endpoint: Optional[str] = None, # Deprecated args local_dir_use_symlinks: Union[bool, Literal["auto"]] = "auto", @@ -99,9 +99,9 @@ def snapshot_download( local_files_only (`bool`, *optional*, defaults to `False`): If `True`, avoid downloading the file and return the path to the local cached file if it exists. - allow_patterns (`List[str]` or `str`, *optional*): + allow_patterns (`list[str]` or `str`, *optional*): If provided, only files matching at least one pattern are downloaded. - ignore_patterns (`List[str]` or `str`, *optional*): + ignore_patterns (`list[str]` or `str`, *optional*): If provided, files matching any of the patterns are not downloaded. max_workers (`int`, *optional*): Number of concurrent threads to download files (1 thread = 1 file download). diff --git a/src/huggingface_hub/_space_api.py b/src/huggingface_hub/_space_api.py index 05fccfbc1e..6dd7976329 100644 --- a/src/huggingface_hub/_space_api.py +++ b/src/huggingface_hub/_space_api.py @@ -15,7 +15,7 @@ from dataclasses import dataclass from datetime import datetime from enum import Enum -from typing import Dict, Optional +from typing import Optional from huggingface_hub.utils import parse_datetime @@ -128,9 +128,9 @@ class SpaceRuntime: requested_hardware: Optional[SpaceHardware] sleep_time: Optional[int] storage: Optional[SpaceStorage] - raw: Dict + raw: dict - def __init__(self, data: Dict) -> None: + def __init__(self, data: dict) -> None: self.stage = data["stage"] self.hardware = data.get("hardware", {}).get("current") self.requested_hardware = data.get("hardware", {}).get("requested") @@ -160,7 +160,7 @@ class SpaceVariable: description: Optional[str] updated_at: Optional[datetime] - def __init__(self, key: str, values: Dict) -> None: + def __init__(self, key: str, values: dict) -> None: self.key = key self.value = values["value"] self.description = values.get("description") diff --git a/src/huggingface_hub/_tensorboard_logger.py b/src/huggingface_hub/_tensorboard_logger.py index fb172acceb..5b15468cfd 100644 --- a/src/huggingface_hub/_tensorboard_logger.py +++ b/src/huggingface_hub/_tensorboard_logger.py @@ -14,7 +14,7 @@ """Contains a logger to push training logs to the Hub, using Tensorboard.""" from pathlib import Path -from typing import List, Optional, Union +from typing import Optional, Union from ._commit_scheduler import CommitScheduler from .errors import EntryNotFoundError @@ -77,10 +77,10 @@ class HFSummaryWriter(_RuntimeSummaryWriter): Whether to make the repo private. If `None` (default), the repo will be public unless the organization's default is private. This value is ignored if the repo already exists. path_in_repo (`str`, *optional*): The path to the folder in the repo where the logs will be pushed. Defaults to "tensorboard/". - repo_allow_patterns (`List[str]` or `str`, *optional*): + repo_allow_patterns (`list[str]` or `str`, *optional*): A list of patterns to include in the upload. Defaults to `"*.tfevents.*"`. Check out the [upload guide](https://huggingface.co/docs/huggingface_hub/guides/upload#upload-a-folder) for more details. - repo_ignore_patterns (`List[str]` or `str`, *optional*): + repo_ignore_patterns (`list[str]` or `str`, *optional*): A list of patterns to exclude in the upload. Check out the [upload guide](https://huggingface.co/docs/huggingface_hub/guides/upload#upload-a-folder) for more details. token (`str`, *optional*): @@ -137,8 +137,8 @@ def __init__( repo_revision: Optional[str] = None, repo_private: Optional[bool] = None, path_in_repo: Optional[str] = "tensorboard", - repo_allow_patterns: Optional[Union[List[str], str]] = "*.tfevents.*", - repo_ignore_patterns: Optional[Union[List[str], str]] = None, + repo_allow_patterns: Optional[Union[list[str], str]] = "*.tfevents.*", + repo_ignore_patterns: Optional[Union[list[str], str]] = None, token: Optional[str] = None, **kwargs, ): diff --git a/src/huggingface_hub/_upload_large_folder.py b/src/huggingface_hub/_upload_large_folder.py index 1ccbc07d39..a4eaeaf250 100644 --- a/src/huggingface_hub/_upload_large_folder.py +++ b/src/huggingface_hub/_upload_large_folder.py @@ -24,7 +24,7 @@ from datetime import datetime from pathlib import Path from threading import Lock -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Optional, Union from urllib.parse import quote from . import constants @@ -44,7 +44,7 @@ WAITING_TIME_IF_NO_TASKS = 10 # seconds MAX_NB_FILES_FETCH_UPLOAD_MODE = 100 -COMMIT_SIZE_SCALE: List[int] = [20, 50, 75, 100, 125, 200, 250, 400, 600, 1000] +COMMIT_SIZE_SCALE: list[int] = [20, 50, 75, 100, 125, 200, 250, 400, 600, 1000] UPLOAD_BATCH_SIZE_XET = 256 # Max 256 files per upload batch for XET-enabled repos UPLOAD_BATCH_SIZE_LFS = 1 # Otherwise, batches of 1 for regular LFS upload @@ -56,7 +56,7 @@ RECOMMENDED_FILE_SIZE_GB = 20 # Recommended maximum for individual file size -def _validate_upload_limits(paths_list: List[LocalUploadFilePaths]) -> None: +def _validate_upload_limits(paths_list: list[LocalUploadFilePaths]) -> None: """ Validate upload against repository limits and warn about potential issues. @@ -85,7 +85,7 @@ def _validate_upload_limits(paths_list: List[LocalUploadFilePaths]) -> None: # Track immediate children (files and subdirs) for each folder from collections import defaultdict - entries_per_folder: Dict[str, Any] = defaultdict(lambda: {"files": 0, "subdirs": set()}) + entries_per_folder: dict[str, Any] = defaultdict(lambda: {"files": 0, "subdirs": set()}) for paths in paths_list: path = Path(paths.path_in_repo) @@ -160,8 +160,8 @@ def upload_large_folder_internal( repo_type: str, # Repo type is required! revision: Optional[str] = None, private: Optional[bool] = None, - allow_patterns: Optional[Union[List[str], str]] = None, - ignore_patterns: Optional[Union[List[str], str]] = None, + allow_patterns: Optional[Union[list[str], str]] = None, + ignore_patterns: Optional[Union[list[str], str]] = None, num_workers: Optional[int] = None, print_report: bool = True, print_report_every: int = 60, @@ -284,13 +284,13 @@ class WorkerJob(enum.Enum): WAIT = enum.auto() # if no tasks are available but we don't want to exit -JOB_ITEM_T = Tuple[LocalUploadFilePaths, LocalUploadFileMetadata] +JOB_ITEM_T = tuple[LocalUploadFilePaths, LocalUploadFileMetadata] class LargeUploadStatus: """Contains information, queues and tasks for a large upload process.""" - def __init__(self, items: List[JOB_ITEM_T], upload_batch_size: int = 1): + def __init__(self, items: list[JOB_ITEM_T], upload_batch_size: int = 1): self.items = items self.queue_sha256: "queue.Queue[JOB_ITEM_T]" = queue.Queue() self.queue_get_upload_mode: "queue.Queue[JOB_ITEM_T]" = queue.Queue() @@ -423,7 +423,7 @@ def _worker_job( Read `upload_large_folder` docstring for more information on how tasks are prioritized. """ while True: - next_job: Optional[Tuple[WorkerJob, List[JOB_ITEM_T]]] = None + next_job: Optional[tuple[WorkerJob, list[JOB_ITEM_T]]] = None # Determine next task next_job = _determine_next_job(status) @@ -516,7 +516,7 @@ def _worker_job( status.nb_workers_waiting -= 1 -def _determine_next_job(status: LargeUploadStatus) -> Optional[Tuple[WorkerJob, List[JOB_ITEM_T]]]: +def _determine_next_job(status: LargeUploadStatus) -> Optional[tuple[WorkerJob, list[JOB_ITEM_T]]]: with status.lock: # 1. Commit if more than 5 minutes since last commit attempt (and at least 1 file) if ( @@ -639,7 +639,7 @@ def _compute_sha256(item: JOB_ITEM_T) -> None: metadata.save(paths) -def _get_upload_mode(items: List[JOB_ITEM_T], api: "HfApi", repo_id: str, repo_type: str, revision: str) -> None: +def _get_upload_mode(items: list[JOB_ITEM_T], api: "HfApi", repo_id: str, repo_type: str, revision: str) -> None: """Get upload mode for each file and update metadata. Also receive info if the file should be ignored. @@ -661,7 +661,7 @@ def _get_upload_mode(items: List[JOB_ITEM_T], api: "HfApi", repo_id: str, repo_t metadata.save(paths) -def _preupload_lfs(items: List[JOB_ITEM_T], api: "HfApi", repo_id: str, repo_type: str, revision: str) -> None: +def _preupload_lfs(items: list[JOB_ITEM_T], api: "HfApi", repo_id: str, repo_type: str, revision: str) -> None: """Preupload LFS files and update metadata.""" additions = [_build_hacky_operation(item) for item in items] api.preupload_lfs_files( @@ -676,7 +676,7 @@ def _preupload_lfs(items: List[JOB_ITEM_T], api: "HfApi", repo_id: str, repo_typ metadata.save(paths) -def _commit(items: List[JOB_ITEM_T], api: "HfApi", repo_id: str, repo_type: str, revision: str) -> None: +def _commit(items: list[JOB_ITEM_T], api: "HfApi", repo_id: str, repo_type: str, revision: str) -> None: """Commit files to the repo.""" additions = [_build_hacky_operation(item) for item in items] api.create_commit( @@ -721,11 +721,11 @@ def _build_hacky_operation(item: JOB_ITEM_T) -> HackyCommitOperationAdd: #################### -def _get_one(queue: "queue.Queue[JOB_ITEM_T]") -> List[JOB_ITEM_T]: +def _get_one(queue: "queue.Queue[JOB_ITEM_T]") -> list[JOB_ITEM_T]: return [queue.get()] -def _get_n(queue: "queue.Queue[JOB_ITEM_T]", n: int) -> List[JOB_ITEM_T]: +def _get_n(queue: "queue.Queue[JOB_ITEM_T]", n: int) -> list[JOB_ITEM_T]: return [queue.get() for _ in range(min(queue.qsize(), n))] diff --git a/src/huggingface_hub/_webhooks_payload.py b/src/huggingface_hub/_webhooks_payload.py index 288f4b08b9..90f12425cb 100644 --- a/src/huggingface_hub/_webhooks_payload.py +++ b/src/huggingface_hub/_webhooks_payload.py @@ -14,7 +14,7 @@ # limitations under the License. """Contains data structures to parse the webhooks payload.""" -from typing import List, Literal, Optional +from typing import Literal, Optional from .utils import is_pydantic_available @@ -116,7 +116,7 @@ class WebhookPayloadRepo(ObjectId): name: str private: bool subdomain: Optional[str] = None - tags: Optional[List[str]] = None + tags: Optional[list[str]] = None type: Literal["dataset", "model", "space"] url: WebhookPayloadUrl @@ -134,4 +134,4 @@ class WebhookPayload(BaseModel): comment: Optional[WebhookPayloadComment] = None webhook: WebhookPayloadWebhook movedTo: Optional[WebhookPayloadMovedTo] = None - updatedRefs: Optional[List[WebhookPayloadUpdatedRef]] = None + updatedRefs: Optional[list[WebhookPayloadUpdatedRef]] = None diff --git a/src/huggingface_hub/_webhooks_server.py b/src/huggingface_hub/_webhooks_server.py index a7bd6c8626..f8ca539af0 100644 --- a/src/huggingface_hub/_webhooks_server.py +++ b/src/huggingface_hub/_webhooks_server.py @@ -18,7 +18,7 @@ import inspect import os from functools import wraps -from typing import TYPE_CHECKING, Any, Callable, Dict, Optional +from typing import TYPE_CHECKING, Any, Callable, Optional from .utils import experimental, is_fastapi_available, is_gradio_available @@ -115,7 +115,7 @@ def __init__( self._ui = ui self.webhook_secret = webhook_secret or os.getenv("WEBHOOK_SECRET") - self.registered_webhooks: Dict[str, Callable] = {} + self.registered_webhooks: dict[str, Callable] = {} _warn_on_empty_secret(self.webhook_secret) def add_webhook(self, path: Optional[str] = None) -> Callable: diff --git a/src/huggingface_hub/cli/_cli_utils.py b/src/huggingface_hub/cli/_cli_utils.py index bd56ad6896..d0f0b98d8a 100644 --- a/src/huggingface_hub/cli/_cli_utils.py +++ b/src/huggingface_hub/cli/_cli_utils.py @@ -14,7 +14,7 @@ """Contains a utility for good-looking prints.""" import os -from typing import List, Union +from typing import Union class ANSI: @@ -52,7 +52,7 @@ def _format(cls, s: str, code: str) -> str: return f"{code}{s}{cls._reset}" -def tabulate(rows: List[List[Union[str, int]]], headers: List[str]) -> str: +def tabulate(rows: list[list[Union[str, int]]], headers: list[str]) -> str: """ Inspired by: diff --git a/src/huggingface_hub/cli/auth.py b/src/huggingface_hub/cli/auth.py index 91e6b3c18d..8740a86423 100644 --- a/src/huggingface_hub/cli/auth.py +++ b/src/huggingface_hub/cli/auth.py @@ -31,7 +31,7 @@ """ from argparse import _SubParsersAction -from typing import List, Optional +from typing import Optional from huggingface_hub.commands import BaseHuggingfaceCLICommand from huggingface_hub.constants import ENDPOINT @@ -172,7 +172,7 @@ def _select_token_name(self) -> Optional[str]: except ValueError: print("Invalid input. Please enter a number or 'q' to quit.") - def _select_token_name_tui(self, token_names: List[str]) -> Optional[str]: + def _select_token_name_tui(self, token_names: list[str]) -> Optional[str]: choices = [Choice(token_name, name=token_name) for token_name in token_names] try: return inquirer.select( diff --git a/src/huggingface_hub/cli/cache.py b/src/huggingface_hub/cli/cache.py index cc36ef5efd..7eb3e82509 100644 --- a/src/huggingface_hub/cli/cache.py +++ b/src/huggingface_hub/cli/cache.py @@ -19,7 +19,7 @@ from argparse import Namespace, _SubParsersAction from functools import wraps from tempfile import mkstemp -from typing import Any, Callable, Iterable, List, Literal, Optional, Union +from typing import Any, Callable, Iterable, Literal, Optional, Union from ..utils import CachedRepoInfo, CachedRevisionInfo, CacheNotFound, HFCacheInfo, scan_cache_dir from . import BaseHuggingfaceCLICommand @@ -243,8 +243,8 @@ def _get_repo_sorting_key(repo: CachedRepoInfo, sort_by: Optional[SortingOption_ @require_inquirer_py def _manual_review_tui( - hf_cache_info: HFCacheInfo, preselected: List[str], sort_by: Optional[SortingOption_T] = None -) -> List[str]: + hf_cache_info: HFCacheInfo, preselected: list[str], sort_by: Optional[SortingOption_T] = None +) -> list[str]: choices = _get_tui_choices_from_scan(repos=hf_cache_info.repos, preselected=preselected, sort_by=sort_by) checkbox = inquirer.checkbox( message="Select revisions to delete:", @@ -277,9 +277,9 @@ def _ask_for_confirmation_tui(message: str, default: bool = True) -> bool: def _get_tui_choices_from_scan( - repos: Iterable[CachedRepoInfo], preselected: List[str], sort_by: Optional[SortingOption_T] = None -) -> List: - choices: List[Union["Choice", "Separator"]] = [] + repos: Iterable[CachedRepoInfo], preselected: list[str], sort_by: Optional[SortingOption_T] = None +) -> list: + choices: list[Union["Choice", "Separator"]] = [] choices.append( Choice( _CANCEL_DELETION_STR, name="None of the following (if selected, nothing will be deleted).", enabled=False @@ -306,8 +306,8 @@ def _get_tui_choices_from_scan( def _manual_review_no_tui( - hf_cache_info: HFCacheInfo, preselected: List[str], sort_by: Optional[SortingOption_T] = None -) -> List[str]: + hf_cache_info: HFCacheInfo, preselected: list[str], sort_by: Optional[SortingOption_T] = None +) -> list[str]: fd, tmp_path = mkstemp(suffix=".txt") os.close(fd) lines = [] @@ -358,14 +358,14 @@ def _ask_for_confirmation_no_tui(message: str, default: bool = True) -> bool: print(f"Invalid input. Must be one of {ALL}") -def _get_expectations_str(hf_cache_info: HFCacheInfo, selected_hashes: List[str]) -> str: +def _get_expectations_str(hf_cache_info: HFCacheInfo, selected_hashes: list[str]) -> str: if _CANCEL_DELETION_STR in selected_hashes: return "Nothing will be deleted." strategy = hf_cache_info.delete_revisions(*selected_hashes) return f"{len(selected_hashes)} revisions selected counting for {strategy.expected_freed_size_str}." -def _read_manual_review_tmp_file(tmp_path: str) -> List[str]: +def _read_manual_review_tmp_file(tmp_path: str) -> list[str]: with open(tmp_path) as f: content = f.read() lines = [line.strip() for line in content.split("\n")] diff --git a/src/huggingface_hub/cli/download.py b/src/huggingface_hub/cli/download.py index 3e59233da1..ea6714d124 100644 --- a/src/huggingface_hub/cli/download.py +++ b/src/huggingface_hub/cli/download.py @@ -38,7 +38,7 @@ import warnings from argparse import Namespace, _SubParsersAction -from typing import List, Optional +from typing import Optional from huggingface_hub import logging from huggingface_hub._snapshot_download import snapshot_download @@ -113,11 +113,11 @@ def register_subcommand(parser: _SubParsersAction): def __init__(self, args: Namespace) -> None: self.token = args.token self.repo_id: str = args.repo_id - self.filenames: List[str] = args.filenames + self.filenames: list[str] = args.filenames self.repo_type: str = args.repo_type self.revision: Optional[str] = args.revision - self.include: Optional[List[str]] = args.include - self.exclude: Optional[List[str]] = args.exclude + self.include: Optional[list[str]] = args.include + self.exclude: Optional[list[str]] = args.exclude self.cache_dir: Optional[str] = args.cache_dir self.local_dir: Optional[str] = args.local_dir self.force_download: bool = args.force_download diff --git a/src/huggingface_hub/cli/jobs.py b/src/huggingface_hub/cli/jobs.py index 5b8d355c6f..b4f8385e1d 100644 --- a/src/huggingface_hub/cli/jobs.py +++ b/src/huggingface_hub/cli/jobs.py @@ -36,7 +36,7 @@ from argparse import Namespace, _SubParsersAction from dataclasses import asdict from pathlib import Path -from typing import Dict, List, Optional, Union +from typing import Optional, Union from huggingface_hub import HfApi, SpaceHardware, get_token from huggingface_hub.errors import HfHubHTTPError @@ -118,7 +118,7 @@ def register_subcommand(parser: _SubParsersAction) -> None: def __init__(self, args: Namespace) -> None: self.image: str = args.image - self.command: List[str] = args.command + self.command: list[str] = args.command self.env: dict[str, Optional[str]] = {} if args.env_file: self.env.update(load_dotenv(Path(args.env_file).read_text(), environ=os.environ.copy())) @@ -185,7 +185,7 @@ def run(self) -> None: print(log) -def _tabulate(rows: List[List[Union[str, int]]], headers: List[str]) -> str: +def _tabulate(rows: list[list[Union[str, int]]], headers: list[str]) -> str: """ Inspired by: @@ -253,7 +253,7 @@ def __init__(self, args: Namespace) -> None: self.namespace: Optional[str] = args.namespace self.token: Optional[str] = args.token self.format: Optional[str] = args.format - self.filters: Dict[str, str] = {} + self.filters: dict[str, str] = {} # Parse filter arguments (key=value pairs) for f in args.filter: @@ -335,7 +335,7 @@ def run(self) -> None: except Exception as e: print(f"Unexpected error - {type(e).__name__}: {e}") - def _matches_filters(self, job_properties: Dict[str, str]) -> bool: + def _matches_filters(self, job_properties: dict[str, str]) -> bool: """Check if job matches all specified filters.""" for key, pattern in self.filters.items(): # Check if property exists @@ -394,7 +394,7 @@ def register_subcommand(parser: _SubParsersAction) -> None: def __init__(self, args: Namespace) -> None: self.namespace: Optional[str] = args.namespace self.token: Optional[str] = args.token - self.job_ids: List[str] = args.job_ids + self.job_ids: list[str] = args.job_ids def run(self) -> None: api = HfApi(token=self.token) @@ -543,7 +543,7 @@ def run(self) -> None: print(log) -def _get_extended_environ() -> Dict[str, str]: +def _get_extended_environ() -> dict[str, str]: extended_environ = os.environ.copy() if (token := get_token()) is not None: extended_environ["HF_TOKEN"] = token @@ -631,7 +631,7 @@ def register_subcommand(parser: _SubParsersAction) -> None: def __init__(self, args: Namespace) -> None: self.schedule: str = args.schedule self.image: str = args.image - self.command: List[str] = args.command + self.command: list[str] = args.command self.suspend: Optional[bool] = args.suspend self.concurrency: Optional[bool] = args.concurrency self.env: dict[str, Optional[str]] = {} @@ -709,7 +709,7 @@ def __init__(self, args: Namespace) -> None: self.namespace: Optional[str] = args.namespace self.token: Optional[str] = args.token self.format: Optional[str] = args.format - self.filters: Dict[str, str] = {} + self.filters: dict[str, str] = {} # Parse filter arguments (key=value pairs) for f in args.filter: @@ -821,7 +821,7 @@ def run(self) -> None: except Exception as e: print(f"Unexpected error - {type(e).__name__}: {e}") - def _matches_filters(self, job_properties: Dict[str, str]) -> bool: + def _matches_filters(self, job_properties: dict[str, str]) -> bool: """Check if scheduled job matches all specified filters.""" for key, pattern in self.filters.items(): # Check if property exists @@ -882,7 +882,7 @@ def register_subcommand(parser: _SubParsersAction) -> None: def __init__(self, args: Namespace) -> None: self.namespace: Optional[str] = args.namespace self.token: Optional[str] = args.token - self.scheduled_job_ids: List[str] = args.scheduled_job_ids + self.scheduled_job_ids: list[str] = args.scheduled_job_ids def run(self) -> None: api = HfApi(token=self.token) diff --git a/src/huggingface_hub/cli/lfs.py b/src/huggingface_hub/cli/lfs.py index e4c5b900c8..ec1b19107e 100644 --- a/src/huggingface_hub/cli/lfs.py +++ b/src/huggingface_hub/cli/lfs.py @@ -21,7 +21,7 @@ import subprocess import sys from argparse import _SubParsersAction -from typing import Dict, List, Optional +from typing import Optional from huggingface_hub.commands import BaseHuggingfaceCLICommand from huggingface_hub.lfs import LFS_MULTIPART_UPLOAD_COMMAND @@ -87,14 +87,14 @@ def run(self): print("Local repo set up for largefiles") -def write_msg(msg: Dict): +def write_msg(msg: dict): """Write out the message in Line delimited JSON.""" msg_str = json.dumps(msg) + "\n" sys.stdout.write(msg_str) sys.stdout.flush() -def read_msg() -> Optional[Dict]: +def read_msg() -> Optional[dict]: """Read Line delimited JSON from stdin.""" msg = json.loads(sys.stdin.readline().strip()) @@ -144,7 +144,7 @@ def run(self) -> None: completion_url = msg["action"]["href"] header = msg["action"]["header"] chunk_size = int(header.pop("chunk_size")) - presigned_urls: List[str] = list(header.values()) + presigned_urls: list[str] = list(header.values()) # Send a "started" progress event to allow other workers to start. # Otherwise they're delayed until first "progress" event is reported, diff --git a/src/huggingface_hub/cli/repo_files.py b/src/huggingface_hub/cli/repo_files.py index 34fbeb09c2..ba3259e576 100644 --- a/src/huggingface_hub/cli/repo_files.py +++ b/src/huggingface_hub/cli/repo_files.py @@ -35,7 +35,7 @@ """ from argparse import _SubParsersAction -from typing import List, Optional +from typing import Optional from huggingface_hub import logging from huggingface_hub.commands import BaseHuggingfaceCLICommand @@ -52,7 +52,7 @@ def __init__(self, args) -> None: self.repo_type: Optional[str] = args.repo_type self.revision: Optional[str] = args.revision self.api: HfApi = HfApi(token=args.token, library_name="hf") - self.patterns: List[str] = args.patterns + self.patterns: list[str] = args.patterns self.commit_message: Optional[str] = args.commit_message self.commit_description: Optional[str] = args.commit_description self.create_pr: bool = args.create_pr diff --git a/src/huggingface_hub/cli/upload.py b/src/huggingface_hub/cli/upload.py index 07ab79bf24..80a0623743 100644 --- a/src/huggingface_hub/cli/upload.py +++ b/src/huggingface_hub/cli/upload.py @@ -50,7 +50,7 @@ import time import warnings from argparse import Namespace, _SubParsersAction -from typing import List, Optional +from typing import Optional from huggingface_hub import logging from huggingface_hub._commit_scheduler import CommitScheduler @@ -144,9 +144,9 @@ def __init__(self, args: Namespace) -> None: self.revision: Optional[str] = args.revision self.private: bool = args.private - self.include: Optional[List[str]] = args.include - self.exclude: Optional[List[str]] = args.exclude - self.delete: Optional[List[str]] = args.delete + self.include: Optional[list[str]] = args.include + self.exclude: Optional[list[str]] = args.exclude + self.delete: Optional[list[str]] = args.delete self.commit_message: Optional[str] = args.commit_message self.commit_description: Optional[str] = args.commit_description diff --git a/src/huggingface_hub/cli/upload_large_folder.py b/src/huggingface_hub/cli/upload_large_folder.py index 618cd21b52..f923abc0a7 100644 --- a/src/huggingface_hub/cli/upload_large_folder.py +++ b/src/huggingface_hub/cli/upload_large_folder.py @@ -16,7 +16,7 @@ import os from argparse import Namespace, _SubParsersAction -from typing import List, Optional +from typing import Optional from huggingface_hub import logging from huggingface_hub.commands import BaseHuggingfaceCLICommand @@ -76,8 +76,8 @@ def __init__(self, args: Namespace) -> None: self.revision: Optional[str] = args.revision self.private: bool = args.private - self.include: Optional[List[str]] = args.include - self.exclude: Optional[List[str]] = args.exclude + self.include: Optional[list[str]] = args.include + self.exclude: Optional[list[str]] = args.exclude self.api: HfApi = HfApi(token=args.token, library_name="hf") diff --git a/src/huggingface_hub/commands/_cli_utils.py b/src/huggingface_hub/commands/_cli_utils.py index bf4a1c0373..3f4819c26c 100644 --- a/src/huggingface_hub/commands/_cli_utils.py +++ b/src/huggingface_hub/commands/_cli_utils.py @@ -14,7 +14,7 @@ """Contains a utility for good-looking prints.""" import os -from typing import List, Union +from typing import Union class ANSI: @@ -52,7 +52,7 @@ def _format(cls, s: str, code: str) -> str: return f"{code}{s}{cls._reset}" -def tabulate(rows: List[List[Union[str, int]]], headers: List[str]) -> str: +def tabulate(rows: list[list[Union[str, int]]], headers: list[str]) -> str: """ Inspired by: diff --git a/src/huggingface_hub/commands/delete_cache.py b/src/huggingface_hub/commands/delete_cache.py index 78ea117967..2983ab80b0 100644 --- a/src/huggingface_hub/commands/delete_cache.py +++ b/src/huggingface_hub/commands/delete_cache.py @@ -60,7 +60,7 @@ from argparse import Namespace, _SubParsersAction from functools import wraps from tempfile import mkstemp -from typing import Any, Callable, Iterable, List, Literal, Optional, Union +from typing import Any, Callable, Iterable, Literal, Optional, Union from ..utils import CachedRepoInfo, CachedRevisionInfo, HFCacheInfo, scan_cache_dir from . import BaseHuggingfaceCLICommand @@ -197,9 +197,9 @@ def _get_repo_sorting_key(repo: CachedRepoInfo, sort_by: Optional[SortingOption_ @require_inquirer_py def _manual_review_tui( hf_cache_info: HFCacheInfo, - preselected: List[str], + preselected: list[str], sort_by: Optional[SortingOption_T] = None, -) -> List[str]: +) -> list[str]: """Ask the user for a manual review of the revisions to delete. Displays a multi-select menu in the terminal (TUI). @@ -254,15 +254,15 @@ def _ask_for_confirmation_tui(message: str, default: bool = True) -> bool: def _get_tui_choices_from_scan( repos: Iterable[CachedRepoInfo], - preselected: List[str], + preselected: list[str], sort_by: Optional[SortingOption_T] = None, -) -> List: +) -> list: """Build a list of choices from the scanned repos. Args: repos (*Iterable[`CachedRepoInfo`]*): List of scanned repos on which we want to delete revisions. - preselected (*List[`str`]*): + preselected (*list[`str`]*): List of revision hashes that will be preselected. sort_by (*Optional[SortingOption_T]*): Sorting direction. Choices: "alphabetical", "lastUpdated", "lastUsed", "size". @@ -270,7 +270,7 @@ def _get_tui_choices_from_scan( Return: The list of choices to pass to `inquirer.checkbox`. """ - choices: List[Union[Choice, Separator]] = [] + choices: list[Union[Choice, Separator]] = [] # First choice is to cancel the deletion choices.append( @@ -312,9 +312,9 @@ def _get_tui_choices_from_scan( def _manual_review_no_tui( hf_cache_info: HFCacheInfo, - preselected: List[str], + preselected: list[str], sort_by: Optional[SortingOption_T] = None, -) -> List[str]: +) -> list[str]: """Ask the user for a manual review of the revisions to delete. Used when TUI is disabled. Manual review happens in a separate tmp file that the @@ -390,7 +390,7 @@ def _ask_for_confirmation_no_tui(message: str, default: bool = True) -> bool: print(f"Invalid input. Must be one of {ALL}") -def _get_expectations_str(hf_cache_info: HFCacheInfo, selected_hashes: List[str]) -> str: +def _get_expectations_str(hf_cache_info: HFCacheInfo, selected_hashes: list[str]) -> str: """Format a string to display to the user how much space would be saved. Example: @@ -405,7 +405,7 @@ def _get_expectations_str(hf_cache_info: HFCacheInfo, selected_hashes: List[str] return f"{len(selected_hashes)} revisions selected counting for {strategy.expected_freed_size_str}." -def _read_manual_review_tmp_file(tmp_path: str) -> List[str]: +def _read_manual_review_tmp_file(tmp_path: str) -> list[str]: """Read the manually reviewed instruction file and return a list of revision hash. Example: diff --git a/src/huggingface_hub/commands/download.py b/src/huggingface_hub/commands/download.py index 0dd2c1070e..103f2a52b5 100644 --- a/src/huggingface_hub/commands/download.py +++ b/src/huggingface_hub/commands/download.py @@ -38,7 +38,7 @@ import warnings from argparse import Namespace, _SubParsersAction -from typing import List, Optional +from typing import Optional from huggingface_hub import logging from huggingface_hub._snapshot_download import snapshot_download @@ -125,11 +125,11 @@ def register_subcommand(parser: _SubParsersAction): def __init__(self, args: Namespace) -> None: self.token = args.token self.repo_id: str = args.repo_id - self.filenames: List[str] = args.filenames + self.filenames: list[str] = args.filenames self.repo_type: str = args.repo_type self.revision: Optional[str] = args.revision - self.include: Optional[List[str]] = args.include - self.exclude: Optional[List[str]] = args.exclude + self.include: Optional[list[str]] = args.include + self.exclude: Optional[list[str]] = args.exclude self.cache_dir: Optional[str] = args.cache_dir self.local_dir: Optional[str] = args.local_dir self.force_download: bool = args.force_download diff --git a/src/huggingface_hub/commands/lfs.py b/src/huggingface_hub/commands/lfs.py index e510e345e6..2133bf1f00 100644 --- a/src/huggingface_hub/commands/lfs.py +++ b/src/huggingface_hub/commands/lfs.py @@ -21,7 +21,7 @@ import subprocess import sys from argparse import _SubParsersAction -from typing import Dict, List, Optional +from typing import Optional from huggingface_hub.commands import BaseHuggingfaceCLICommand from huggingface_hub.lfs import LFS_MULTIPART_UPLOAD_COMMAND @@ -89,14 +89,14 @@ def run(self): print("Local repo set up for largefiles") -def write_msg(msg: Dict): +def write_msg(msg: dict): """Write out the message in Line delimited JSON.""" msg_str = json.dumps(msg) + "\n" sys.stdout.write(msg_str) sys.stdout.flush() -def read_msg() -> Optional[Dict]: +def read_msg() -> Optional[dict]: """Read Line delimited JSON from stdin.""" msg = json.loads(sys.stdin.readline().strip()) @@ -146,7 +146,7 @@ def run(self) -> None: completion_url = msg["action"]["href"] header = msg["action"]["header"] chunk_size = int(header.pop("chunk_size")) - presigned_urls: List[str] = list(header.values()) + presigned_urls: list[str] = list(header.values()) # Send a "started" progress event to allow other workers to start. # Otherwise they're delayed until first "progress" event is reported, diff --git a/src/huggingface_hub/commands/repo_files.py b/src/huggingface_hub/commands/repo_files.py index da9685315e..b914a2dc92 100644 --- a/src/huggingface_hub/commands/repo_files.py +++ b/src/huggingface_hub/commands/repo_files.py @@ -35,7 +35,7 @@ """ from argparse import _SubParsersAction -from typing import List, Optional +from typing import Optional from huggingface_hub import logging from huggingface_hub.commands import BaseHuggingfaceCLICommand @@ -54,7 +54,7 @@ def __init__(self, args) -> None: self.repo_type: Optional[str] = args.repo_type self.revision: Optional[str] = args.revision self.api: HfApi = HfApi(token=args.token, library_name="huggingface-cli") - self.patterns: List[str] = args.patterns + self.patterns: list[str] = args.patterns self.commit_message: Optional[str] = args.commit_message self.commit_description: Optional[str] = args.commit_description self.create_pr: bool = args.create_pr diff --git a/src/huggingface_hub/commands/upload.py b/src/huggingface_hub/commands/upload.py index c778555cda..180f8ef58b 100644 --- a/src/huggingface_hub/commands/upload.py +++ b/src/huggingface_hub/commands/upload.py @@ -50,7 +50,7 @@ import time import warnings from argparse import Namespace, _SubParsersAction -from typing import List, Optional +from typing import Optional from huggingface_hub import logging from huggingface_hub._commit_scheduler import CommitScheduler @@ -144,9 +144,9 @@ def __init__(self, args: Namespace) -> None: self.revision: Optional[str] = args.revision self.private: bool = args.private - self.include: Optional[List[str]] = args.include - self.exclude: Optional[List[str]] = args.exclude - self.delete: Optional[List[str]] = args.delete + self.include: Optional[list[str]] = args.include + self.exclude: Optional[list[str]] = args.exclude + self.delete: Optional[list[str]] = args.delete self.commit_message: Optional[str] = args.commit_message self.commit_description: Optional[str] = args.commit_description diff --git a/src/huggingface_hub/commands/upload_large_folder.py b/src/huggingface_hub/commands/upload_large_folder.py index 3105ba3f57..b0597868ea 100644 --- a/src/huggingface_hub/commands/upload_large_folder.py +++ b/src/huggingface_hub/commands/upload_large_folder.py @@ -16,7 +16,7 @@ import os from argparse import Namespace, _SubParsersAction -from typing import List, Optional +from typing import Optional from huggingface_hub import logging from huggingface_hub.commands import BaseHuggingfaceCLICommand @@ -73,8 +73,8 @@ def __init__(self, args: Namespace) -> None: self.revision: Optional[str] = args.revision self.private: bool = args.private - self.include: Optional[List[str]] = args.include - self.exclude: Optional[List[str]] = args.exclude + self.include: Optional[list[str]] = args.include + self.exclude: Optional[list[str]] = args.exclude self.api: HfApi = HfApi(token=args.token, library_name="huggingface-cli") diff --git a/src/huggingface_hub/commands/user.py b/src/huggingface_hub/commands/user.py index 61cbc4c9e1..b187876328 100644 --- a/src/huggingface_hub/commands/user.py +++ b/src/huggingface_hub/commands/user.py @@ -31,7 +31,7 @@ """ from argparse import _SubParsersAction -from typing import List, Optional +from typing import Optional from huggingface_hub.commands import BaseHuggingfaceCLICommand from huggingface_hub.constants import ENDPOINT @@ -163,7 +163,7 @@ def _select_token_name(self) -> Optional[str]: except ValueError: print("Invalid input. Please enter a number or 'q' to quit.") - def _select_token_name_tui(self, token_names: List[str]) -> Optional[str]: + def _select_token_name_tui(self, token_names: list[str]) -> Optional[str]: choices = [Choice(token_name, name=token_name) for token_name in token_names] try: return inquirer.select( diff --git a/src/huggingface_hub/community.py b/src/huggingface_hub/community.py index 16f2f02428..3bb81e8e2e 100644 --- a/src/huggingface_hub/community.py +++ b/src/huggingface_hub/community.py @@ -7,7 +7,7 @@ from dataclasses import dataclass from datetime import datetime -from typing import List, Literal, Optional, Union +from typing import Literal, Optional, Union from . import constants from .utils import parse_datetime @@ -116,7 +116,7 @@ class DiscussionWithDetails(Discussion): The `datetime` of creation of the Discussion / Pull Request. events (`list` of [`DiscussionEvent`]) The list of [`DiscussionEvents`] in this Discussion or Pull Request. - conflicting_files (`Union[List[str], bool, None]`, *optional*): + conflicting_files (`Union[list[str], bool, None]`, *optional*): A list of conflicting files if this is a Pull Request. `None` if `self.is_pull_request` is `False`. `True` if there are conflicting files but the list can't be retrieved. @@ -136,8 +136,8 @@ class DiscussionWithDetails(Discussion): (property) URL of the discussion on the Hub. """ - events: List["DiscussionEvent"] - conflicting_files: Union[List[str], bool, None] + events: list["DiscussionEvent"] + conflicting_files: Union[list[str], bool, None] target_branch: Optional[str] merge_commit_oid: Optional[str] diff: Optional[str] @@ -222,7 +222,7 @@ def last_edited_by(self) -> str: return self._event["data"]["latest"].get("author", {}).get("name", "deleted") @property - def edit_history(self) -> List[dict]: + def edit_history(self) -> list[dict]: """The edit history of the comment""" return self._event["data"]["history"] diff --git a/src/huggingface_hub/constants.py b/src/huggingface_hub/constants.py index b30b2c01d9..c1445ffc9d 100644 --- a/src/huggingface_hub/constants.py +++ b/src/huggingface_hub/constants.py @@ -1,7 +1,7 @@ import os import re import typing -from typing import Literal, Optional, Tuple +from typing import Literal, Optional # Possible values for env variables @@ -118,9 +118,9 @@ def _as_int(value: Optional[str]) -> Optional[int]: } DiscussionTypeFilter = Literal["all", "discussion", "pull_request"] -DISCUSSION_TYPES: Tuple[DiscussionTypeFilter, ...] = typing.get_args(DiscussionTypeFilter) +DISCUSSION_TYPES: tuple[DiscussionTypeFilter, ...] = typing.get_args(DiscussionTypeFilter) DiscussionStatusFilter = Literal["all", "open", "closed"] -DISCUSSION_STATUS: Tuple[DiscussionTypeFilter, ...] = typing.get_args(DiscussionStatusFilter) +DISCUSSION_STATUS: tuple[DiscussionTypeFilter, ...] = typing.get_args(DiscussionStatusFilter) # Webhook subscription types WEBHOOK_DOMAIN_T = Literal["repo", "discussions"] diff --git a/src/huggingface_hub/dataclasses.py b/src/huggingface_hub/dataclasses.py index bf81c522d5..cc78945e3a 100644 --- a/src/huggingface_hub/dataclasses.py +++ b/src/huggingface_hub/dataclasses.py @@ -4,11 +4,8 @@ from typing import ( Any, Callable, - Dict, - List, Literal, Optional, - Tuple, Type, TypeVar, Union, @@ -102,7 +99,7 @@ def wrap(cls: Type[T]) -> Type[T]: ) # List and store validators - field_validators: Dict[str, List[Validator_T]] = {} + field_validators: dict[str, list[Validator_T]] = {} for f in fields(cls): # type: ignore [arg-type] validators = [] validators.append(_create_type_validator(f)) @@ -238,14 +235,14 @@ def init_with_validate(self, *args, **kwargs) -> None: def validated_field( - validator: Union[List[Validator_T], Validator_T], + validator: Union[list[Validator_T], Validator_T], default: Union[Any, _MISSING_TYPE] = MISSING, default_factory: Union[Callable[[], Any], _MISSING_TYPE] = MISSING, init: bool = True, repr: bool = True, hash: Optional[bool] = None, compare: bool = True, - metadata: Optional[Dict] = None, + metadata: Optional[dict] = None, **kwargs: Any, ) -> Any: """ @@ -254,7 +251,7 @@ def validated_field( Useful to apply several checks to a field. If only applying one rule, check out the [`as_validated_field`] decorator. Args: - validator (`Callable` or `List[Callable]`): + validator (`Callable` or `list[Callable]`): A method that takes a value as input and raises ValueError/TypeError if the value is invalid. Can be a list of validators to apply multiple checks. **kwargs: @@ -296,7 +293,7 @@ def _inner( repr: bool = True, hash: Optional[bool] = None, compare: bool = True, - metadata: Optional[Dict] = None, + metadata: Optional[dict] = None, **kwargs: Any, ): return validated_field( @@ -329,7 +326,7 @@ def type_validator(name: str, value: Any, expected_type: Any) -> None: raise TypeError(f"Unsupported type for field '{name}': {expected_type}") -def _validate_union(name: str, value: Any, args: Tuple[Any, ...]) -> None: +def _validate_union(name: str, value: Any, args: tuple[Any, ...]) -> None: """Validate that value matches one of the types in a Union.""" errors = [] for t in args: @@ -344,14 +341,14 @@ def _validate_union(name: str, value: Any, args: Tuple[Any, ...]) -> None: ) -def _validate_literal(name: str, value: Any, args: Tuple[Any, ...]) -> None: +def _validate_literal(name: str, value: Any, args: tuple[Any, ...]) -> None: """Validate Literal type.""" if value not in args: raise TypeError(f"Field '{name}' expected one of {args}, got {value}") -def _validate_list(name: str, value: Any, args: Tuple[Any, ...]) -> None: - """Validate List[T] type.""" +def _validate_list(name: str, value: Any, args: tuple[Any, ...]) -> None: + """Validate list[T] type.""" if not isinstance(value, list): raise TypeError(f"Field '{name}' expected a list, got {type(value).__name__}") @@ -364,8 +361,8 @@ def _validate_list(name: str, value: Any, args: Tuple[Any, ...]) -> None: raise TypeError(f"Invalid item at index {i} in list '{name}'") from e -def _validate_dict(name: str, value: Any, args: Tuple[Any, ...]) -> None: - """Validate Dict[K, V] type.""" +def _validate_dict(name: str, value: Any, args: tuple[Any, ...]) -> None: + """Validate dict[K, V] type.""" if not isinstance(value, dict): raise TypeError(f"Field '{name}' expected a dict, got {type(value).__name__}") @@ -379,19 +376,19 @@ def _validate_dict(name: str, value: Any, args: Tuple[Any, ...]) -> None: raise TypeError(f"Invalid key or value in dict '{name}'") from e -def _validate_tuple(name: str, value: Any, args: Tuple[Any, ...]) -> None: +def _validate_tuple(name: str, value: Any, args: tuple[Any, ...]) -> None: """Validate Tuple type.""" if not isinstance(value, tuple): raise TypeError(f"Field '{name}' expected a tuple, got {type(value).__name__}") - # Handle variable-length tuples: Tuple[T, ...] + # Handle variable-length tuples: tuple[T, ...] if len(args) == 2 and args[1] is Ellipsis: for i, item in enumerate(value): try: type_validator(f"{name}[{i}]", item, args[0]) except TypeError as e: raise TypeError(f"Invalid item at index {i} in tuple '{name}'") from e - # Handle fixed-length tuples: Tuple[T1, T2, ...] + # Handle fixed-length tuples: tuple[T1, T2, ...] elif len(args) != len(value): raise TypeError(f"Field '{name}' expected a tuple of length {len(args)}, got {len(value)}") else: @@ -402,8 +399,8 @@ def _validate_tuple(name: str, value: Any, args: Tuple[Any, ...]) -> None: raise TypeError(f"Invalid item at index {i} in tuple '{name}'") from e -def _validate_set(name: str, value: Any, args: Tuple[Any, ...]) -> None: - """Validate Set[T] type.""" +def _validate_set(name: str, value: Any, args: tuple[Any, ...]) -> None: + """Validate set[T] type.""" if not isinstance(value, set): raise TypeError(f"Field '{name}' expected a set, got {type(value).__name__}") diff --git a/src/huggingface_hub/fastai_utils.py b/src/huggingface_hub/fastai_utils.py index e75eba2a8b..de36ff3b36 100644 --- a/src/huggingface_hub/fastai_utils.py +++ b/src/huggingface_hub/fastai_utils.py @@ -2,7 +2,7 @@ import os from pathlib import Path from pickle import DEFAULT_PROTOCOL, PicklingError -from typing import Any, Dict, List, Optional, Union +from typing import Any, Optional, Union from packaging import version @@ -241,7 +241,7 @@ def _create_model_pyproject(repo_dir: Path): def _save_pretrained_fastai( learner, save_directory: Union[str, Path], - config: Optional[Dict[str, Any]] = None, + config: Optional[dict[str, Any]] = None, ): """ Saves a fastai learner to `save_directory` in pickle format using the default pickle protocol for the version of python used. @@ -350,9 +350,9 @@ def push_to_hub_fastai( config: Optional[dict] = None, branch: Optional[str] = None, create_pr: Optional[bool] = None, - allow_patterns: Optional[Union[List[str], str]] = None, - ignore_patterns: Optional[Union[List[str], str]] = None, - delete_patterns: Optional[Union[List[str], str]] = None, + allow_patterns: Optional[Union[list[str], str]] = None, + ignore_patterns: Optional[Union[list[str], str]] = None, + delete_patterns: Optional[Union[list[str], str]] = None, api_endpoint: Optional[str] = None, ): """ @@ -385,11 +385,11 @@ def push_to_hub_fastai( Defaults to `False`. api_endpoint (`str`, *optional*): The API endpoint to use when pushing the model to the hub. - allow_patterns (`List[str]` or `str`, *optional*): + allow_patterns (`list[str]` or `str`, *optional*): If provided, only files matching at least one pattern are pushed. - ignore_patterns (`List[str]` or `str`, *optional*): + ignore_patterns (`list[str]` or `str`, *optional*): If provided, files matching any of the patterns are not pushed. - delete_patterns (`List[str]` or `str`, *optional*): + delete_patterns (`list[str]` or `str`, *optional*): If provided, remote files matching any of the patterns will be deleted from the repo. Returns: diff --git a/src/huggingface_hub/file_download.py b/src/huggingface_hub/file_download.py index 89e2ad74e4..c3f57fbd78 100644 --- a/src/huggingface_hub/file_download.py +++ b/src/huggingface_hub/file_download.py @@ -9,7 +9,7 @@ import warnings from dataclasses import dataclass from pathlib import Path -from typing import Any, BinaryIO, Dict, Literal, NoReturn, Optional, Tuple, Union +from typing import Any, BinaryIO, Literal, NoReturn, Optional, Union from urllib.parse import quote, urlparse import httpx @@ -81,7 +81,7 @@ # Regex to check if the file etag IS a valid sha256 REGEX_SHA256 = re.compile(r"^[0-9a-f]{64}$") -_are_symlinks_supported_in_dir: Dict[str, bool] = {} +_are_symlinks_supported_in_dir: dict[str, bool] = {} def are_symlinks_supported(cache_dir: Union[str, Path, None] = None) -> bool: @@ -341,7 +341,7 @@ def http_get( temp_file: BinaryIO, *, resume_size: int = 0, - headers: Optional[Dict[str, Any]] = None, + headers: Optional[dict[str, Any]] = None, expected_size: Optional[int] = None, displayed_filename: Optional[str] = None, _nb_retries: int = 5, @@ -512,7 +512,7 @@ def xet_get( *, incomplete_path: Path, xet_file_data: XetFileData, - headers: Dict[str, str], + headers: dict[str, str], expected_size: Optional[int] = None, displayed_filename: Optional[str] = None, _tqdm_bar: Optional[tqdm] = None, @@ -525,7 +525,7 @@ def xet_get( The path to the file to download. xet_file_data (`XetFileData`): The file metadata needed to make the request to the xet storage service. - headers (`Dict[str, str]`): + headers (`dict[str, str]`): The headers to send to the xet storage service. expected_size (`int`, *optional*): The expected size of the file to download. If set, the download will raise an error if the size of the @@ -572,7 +572,7 @@ def xet_get( connection_info = refresh_xet_connection_info(file_data=xet_file_data, headers=headers) - def token_refresher() -> Tuple[str, int]: + def token_refresher() -> tuple[str, int]: connection_info = refresh_xet_connection_info(file_data=xet_file_data, headers=headers) if connection_info is None: raise ValueError("Failed to refresh token using xet metadata.") @@ -799,12 +799,12 @@ def hf_hub_download( library_version: Optional[str] = None, cache_dir: Union[str, Path, None] = None, local_dir: Union[str, Path, None] = None, - user_agent: Union[Dict, str, None] = None, + user_agent: Union[dict, str, None] = None, force_download: bool = False, etag_timeout: float = constants.DEFAULT_ETAG_TIMEOUT, token: Union[bool, str, None] = None, local_files_only: bool = False, - headers: Optional[Dict[str, str]] = None, + headers: Optional[dict[str, str]] = None, endpoint: Optional[str] = None, resume_download: Optional[bool] = None, force_filename: Optional[str] = None, @@ -1012,7 +1012,7 @@ def _hf_hub_download_to_cache_dir( # HTTP info endpoint: Optional[str], etag_timeout: float, - headers: Dict[str, str], + headers: dict[str, str], token: Optional[Union[bool, str]], # Additional options local_files_only: bool, @@ -1168,7 +1168,7 @@ def _hf_hub_download_to_local_dir( # HTTP info endpoint: Optional[str], etag_timeout: float, - headers: Dict[str, str], + headers: dict[str, str], token: Union[bool, str, None], # Additional options cache_dir: str, @@ -1381,8 +1381,8 @@ def get_hf_file_metadata( timeout: Optional[float] = constants.DEFAULT_REQUEST_TIMEOUT, library_name: Optional[str] = None, library_version: Optional[str] = None, - user_agent: Union[Dict, str, None] = None, - headers: Optional[Dict[str, str]] = None, + user_agent: Union[dict, str, None] = None, + headers: Optional[dict[str, str]] = None, endpoint: Optional[str] = None, ) -> HfFileMetadata: """Fetch metadata of a file versioned on the Hub for a given url. @@ -1451,17 +1451,17 @@ def _get_metadata_or_catch_error( revision: str, endpoint: Optional[str], etag_timeout: Optional[float], - headers: Dict[str, str], # mutated inplace! + headers: dict[str, str], # mutated inplace! token: Union[bool, str, None], local_files_only: bool, relative_filename: Optional[str] = None, # only used to store `.no_exists` in cache storage_folder: Optional[str] = None, # only used to store `.no_exists` in cache ) -> Union[ # Either an exception is caught and returned - Tuple[None, None, None, None, None, Exception], + tuple[None, None, None, None, None, Exception], # Or the metadata is returned as # `(url_to_download, etag, commit_hash, expected_size, xet_file_data, None)` - Tuple[str, str, str, int, Optional[XetFileData], None], + tuple[str, str, str, int, Optional[XetFileData], None], ]: """Get metadata for a file on the Hub, safely handling network issues. @@ -1619,7 +1619,7 @@ def _download_to_tmp_and_move( incomplete_path: Path, destination_path: Path, url_to_download: str, - headers: Dict[str, str], + headers: dict[str, str], expected_size: Optional[int], filename: str, force_download: bool, diff --git a/src/huggingface_hub/hf_api.py b/src/huggingface_hub/hf_api.py index a5aef6bbc5..951a19dcc0 100644 --- a/src/huggingface_hub/hf_api.py +++ b/src/huggingface_hub/hf_api.py @@ -34,14 +34,10 @@ Any, BinaryIO, Callable, - Dict, Iterable, Iterator, - List, Literal, Optional, - Tuple, - Type, TypeVar, Union, overload, @@ -242,7 +238,7 @@ logger = logging.get_logger(__name__) -def repo_type_and_id_from_hf_id(hf_id: str, hub_url: Optional[str] = None) -> Tuple[Optional[str], Optional[str], str]: +def repo_type_and_id_from_hf_id(hf_id: str, hub_url: Optional[str] = None) -> tuple[Optional[str], Optional[str], str]: """ Returns the repo type and ID from a huggingface.co URL linking to a repository @@ -352,8 +348,8 @@ def __post_init__(self): # hack to make BlobLfsInfo backward compatible class BlobSecurityInfo(dict): safe: bool # duplicate information with "status" field, keeping it for backward compatibility status: str - av_scan: Optional[Dict] - pickle_import_scan: Optional[Dict] + av_scan: Optional[dict] + pickle_import_scan: Optional[dict] def __post_init__(self): # hack to make BlogSecurityInfo backward compatible self.update(asdict(self)) @@ -373,7 +369,7 @@ def __post_init__(self): # hack to make TransformersInfo backward compatible @dataclass class SafeTensorsInfo(dict): - parameters: Dict[str, int] + parameters: dict[str, int] total: int def __post_init__(self): # hack to make SafeTensorsInfo backward compatible @@ -476,7 +472,7 @@ class AccessRequest: Timestamp of the request. status (`Literal["pending", "accepted", "rejected"]`): Status of the request. Can be one of `["pending", "accepted", "rejected"]`. - fields (`Dict[str, Any]`, *optional*): + fields (`dict[str, Any]`, *optional*): Additional fields filled by the user in the gate form. """ @@ -487,7 +483,7 @@ class AccessRequest: status: Literal["pending", "accepted", "rejected"] # Additional fields filled by the user in the gate form - fields: Optional[Dict[str, Any]] = None + fields: Optional[dict[str, Any]] = None @dataclass @@ -514,9 +510,9 @@ class WebhookInfo: ID of the webhook. url (`str`): URL of the webhook. - watched (`List[WebhookWatchedItem]`): + watched (`list[WebhookWatchedItem]`): List of items watched by the webhook, see [`WebhookWatchedItem`]. - domains (`List[WEBHOOK_DOMAIN_T]`): + domains (`list[WEBHOOK_DOMAIN_T]`): List of domains the webhook is watching. Can be one of `["repo", "discussions"]`. secret (`str`, *optional*): Secret of the webhook. @@ -526,8 +522,8 @@ class WebhookInfo: id: str url: str - watched: List[WebhookWatchedItem] - domains: List[constants.WEBHOOK_DOMAIN_T] + watched: list[WebhookWatchedItem] + domains: list[constants.WEBHOOK_DOMAIN_T] secret: Optional[str] disabled: bool @@ -778,17 +774,17 @@ class ModelInfo: gated (`Literal["auto", "manual", False]`, *optional*): Is the repo gated. If so, whether there is manual or automatic approval. - gguf (`Dict`, *optional*): + gguf (`dict`, *optional*): GGUF information of the model. inference (`Literal["warm"]`, *optional*): Status of the model on Inference Providers. Warm if the model is served by at least one provider. - inference_provider_mapping (`List[InferenceProviderMapping]`, *optional*): + inference_provider_mapping (`list[InferenceProviderMapping]`, *optional*): A list of [`InferenceProviderMapping`] ordered after the user's provider order. likes (`int`): Number of likes of the model. library_name (`str`, *optional*): Library associated with the model. - tags (`List[str]`): + tags (`list[str]`): List of tags of the model. Compared to `card_data.tags`, contains extra tags computed by the Hub (e.g. supported libraries, model's arXiv). pipeline_tag (`str`, *optional*): @@ -797,9 +793,9 @@ class ModelInfo: Mask token used by the model. widget_data (`Any`, *optional*): Widget data associated with the model. - model_index (`Dict`, *optional*): + model_index (`dict`, *optional*): Model index for evaluation. - config (`Dict`, *optional*): + config (`dict`, *optional*): Model configuration. transformers_info (`TransformersInfo`, *optional*): Transformers-specific info (auto class, processor, etc.) associated with the model. @@ -807,13 +803,13 @@ class ModelInfo: Trending score of the model. card_data (`ModelCardData`, *optional*): Model Card Metadata as a [`huggingface_hub.repocard_data.ModelCardData`] object. - siblings (`List[RepoSibling]`): + siblings (`list[RepoSibling]`): List of [`huggingface_hub.hf_api.RepoSibling`] objects that constitute the model. - spaces (`List[str]`, *optional*): + spaces (`list[str]`, *optional*): List of spaces using the model. safetensors (`SafeTensorsInfo`, *optional*): Model's safetensors information. - security_repo_status (`Dict`, *optional*): + security_repo_status (`dict`, *optional*): Model's security scan status. """ @@ -827,24 +823,24 @@ class ModelInfo: downloads: Optional[int] downloads_all_time: Optional[int] gated: Optional[Literal["auto", "manual", False]] - gguf: Optional[Dict] + gguf: Optional[dict] inference: Optional[Literal["warm"]] - inference_provider_mapping: Optional[List[InferenceProviderMapping]] + inference_provider_mapping: Optional[list[InferenceProviderMapping]] likes: Optional[int] library_name: Optional[str] - tags: Optional[List[str]] + tags: Optional[list[str]] pipeline_tag: Optional[str] mask_token: Optional[str] card_data: Optional[ModelCardData] widget_data: Optional[Any] - model_index: Optional[Dict] - config: Optional[Dict] + model_index: Optional[dict] + config: Optional[dict] transformers_info: Optional[TransformersInfo] trending_score: Optional[int] - siblings: Optional[List[RepoSibling]] - spaces: Optional[List[str]] + siblings: Optional[list[RepoSibling]] + spaces: Optional[list[str]] safetensors: Optional[SafeTensorsInfo] - security_repo_status: Optional[Dict] + security_repo_status: Optional[dict] xet_enabled: Optional[bool] def __init__(self, **kwargs): @@ -979,11 +975,11 @@ class DatasetInfo: Cumulated number of downloads of the model since its creation. likes (`int`): Number of likes of the dataset. - tags (`List[str]`): + tags (`list[str]`): List of tags of the dataset. card_data (`DatasetCardData`, *optional*): Model Card Metadata as a [`huggingface_hub.repocard_data.DatasetCardData`] object. - siblings (`List[RepoSibling]`): + siblings (`list[RepoSibling]`): List of [`huggingface_hub.hf_api.RepoSibling`] objects that constitute the dataset. paperswithcode_id (`str`, *optional*): Papers with code ID of the dataset. @@ -1003,10 +999,10 @@ class DatasetInfo: downloads_all_time: Optional[int] likes: Optional[int] paperswithcode_id: Optional[str] - tags: Optional[List[str]] + tags: Optional[list[str]] trending_score: Optional[int] card_data: Optional[DatasetCardData] - siblings: Optional[List[RepoSibling]] + siblings: Optional[list[RepoSibling]] xet_enabled: Optional[bool] def __init__(self, **kwargs): @@ -1098,9 +1094,9 @@ class SpaceInfo: Subdomain of the Space. likes (`int`): Number of likes of the Space. - tags (`List[str]`): + tags (`list[str]`): List of tags of the Space. - siblings (`List[RepoSibling]`): + siblings (`list[RepoSibling]`): List of [`huggingface_hub.hf_api.RepoSibling`] objects that constitute the Space. card_data (`SpaceCardData`, *optional*): Space Card Metadata as a [`huggingface_hub.repocard_data.SpaceCardData`] object. @@ -1108,9 +1104,9 @@ class SpaceInfo: Space runtime information as a [`huggingface_hub.hf_api.SpaceRuntime`] object. sdk (`str`, *optional*): SDK used by the Space. - models (`List[str]`, *optional*): + models (`list[str]`, *optional*): List of models used by the Space. - datasets (`List[str]`, *optional*): + datasets (`list[str]`, *optional*): List of datasets used by the Space. trending_score (`int`, *optional*): Trending score of the Space. @@ -1128,13 +1124,13 @@ class SpaceInfo: subdomain: Optional[str] likes: Optional[int] sdk: Optional[str] - tags: Optional[List[str]] - siblings: Optional[List[RepoSibling]] + tags: Optional[list[str]] + siblings: Optional[list[RepoSibling]] trending_score: Optional[int] card_data: Optional[SpaceCardData] runtime: Optional[SpaceRuntime] - models: Optional[List[str]] - datasets: Optional[List[str]] + models: Optional[list[str]] + datasets: Optional[list[str]] xet_enabled: Optional[bool] def __init__(self, **kwargs): @@ -1222,7 +1218,7 @@ def __init__( id: str, type: CollectionItemType_T, position: int, - note: Optional[Dict] = None, + note: Optional[dict] = None, **kwargs, ) -> None: self.item_object_id: str = _id # id in database @@ -1248,7 +1244,7 @@ class Collection: Title of the collection. E.g. `"Recent models"`. owner (`str`): Owner of the collection. E.g. `"TheBloke"`. - items (`List[CollectionItem]`): + items (`list[CollectionItem]`): List of items in the collection. last_updated (`datetime`): Date of the last update of the collection. @@ -1269,7 +1265,7 @@ class Collection: slug: str title: str owner: str - items: List[CollectionItem] + items: list[CollectionItem] last_updated: datetime position: int private: bool @@ -1326,22 +1322,22 @@ class GitRefs: Object is returned by [`list_repo_refs`]. Attributes: - branches (`List[GitRefInfo]`): + branches (`list[GitRefInfo]`): A list of [`GitRefInfo`] containing information about branches on the repo. - converts (`List[GitRefInfo]`): + converts (`list[GitRefInfo]`): A list of [`GitRefInfo`] containing information about "convert" refs on the repo. Converts are refs used (internally) to push preprocessed data in Dataset repos. - tags (`List[GitRefInfo]`): + tags (`list[GitRefInfo]`): A list of [`GitRefInfo`] containing information about tags on the repo. - pull_requests (`List[GitRefInfo]`, *optional*): + pull_requests (`list[GitRefInfo]`, *optional*): A list of [`GitRefInfo`] containing information about pull requests on the repo. Only returned if `include_prs=True` is set. """ - branches: List[GitRefInfo] - converts: List[GitRefInfo] - tags: List[GitRefInfo] - pull_requests: Optional[List[GitRefInfo]] = None + branches: list[GitRefInfo] + converts: list[GitRefInfo] + tags: list[GitRefInfo] + pull_requests: Optional[list[GitRefInfo]] = None @dataclass @@ -1352,7 +1348,7 @@ class GitCommitInfo: Attributes: commit_id (`str`): OID of the commit (e.g. `"e7da7f221d5bf496a48136c0cd264e630fe9fcc8"`) - authors (`List[str]`): + authors (`list[str]`): List of authors of the commit. created_at (`datetime`): Datetime when the commit was created. @@ -1368,7 +1364,7 @@ class GitCommitInfo: commit_id: str - authors: List[str] + authors: list[str] created_at: datetime title: str message: str @@ -1387,11 +1383,11 @@ class UserLikes: Name of the user for which we fetched the likes. total (`int`): Total number of likes. - datasets (`List[str]`): + datasets (`list[str]`): List of datasets liked by the user (as repo_ids). - models (`List[str]`): + models (`list[str]`): List of models liked by the user (as repo_ids). - spaces (`List[str]`): + spaces (`list[str]`): List of spaces liked by the user (as repo_ids). """ @@ -1400,9 +1396,9 @@ class UserLikes: total: int # User likes - datasets: List[str] - models: List[str] - spaces: List[str] + datasets: list[str] + models: list[str] + spaces: list[str] @dataclass @@ -1488,7 +1484,7 @@ class User: num_likes: Optional[int] = None num_following: Optional[int] = None num_followers: Optional[int] = None - orgs: List[Organization] = field(default_factory=list) + orgs: list[Organization] = field(default_factory=list) def __init__(self, **kwargs) -> None: self.username = kwargs.pop("user", "") @@ -1521,7 +1517,7 @@ class PaperInfo: Attributes: id (`str`): arXiv paper ID. - authors (`List[str]`, **optional**): + authors (`list[str]`, **optional**): Names of paper authors published_at (`datetime`, **optional**): Date paper published. @@ -1544,7 +1540,7 @@ class PaperInfo: """ id: str - authors: Optional[List[str]] + authors: Optional[list[str]] published_at: Optional[datetime] title: Optional[str] summary: Optional[str] @@ -1708,8 +1704,8 @@ def __init__( token: Union[str, bool, None] = None, library_name: Optional[str] = None, library_version: Optional[str] = None, - user_agent: Union[Dict, str, None] = None, - headers: Optional[Dict[str, str]] = None, + user_agent: Union[dict, str, None] = None, + headers: Optional[dict[str, str]] = None, ) -> None: self.endpoint = endpoint if endpoint is not None else constants.ENDPOINT self.token = token @@ -1760,7 +1756,7 @@ def run_as_future(self, fn: Callable[..., R], *args, **kwargs) -> Future[R]: return self._thread_pool.submit(fn, *args, **kwargs) @validate_hf_hub_args - def whoami(self, token: Union[bool, str, None] = None) -> Dict: + def whoami(self, token: Union[bool, str, None] = None) -> dict: """ Call HF API to know "whoami". @@ -1836,7 +1832,7 @@ def get_token_permission( except (LocalTokenNotFoundError, HfHubHTTPError, KeyError): return None - def get_model_tags(self) -> Dict: + def get_model_tags(self) -> dict: """ List all valid model tags as a nested namespace object """ @@ -1845,7 +1841,7 @@ def get_model_tags(self) -> Dict: hf_raise_for_status(r) return r.json() - def get_dataset_tags(self) -> Dict: + def get_dataset_tags(self) -> dict: """ List all valid dataset tags as a nested namespace object. """ @@ -1864,30 +1860,30 @@ def list_models( # Search-query parameter filter: Union[str, Iterable[str], None] = None, author: Optional[str] = None, - apps: Optional[Union[str, List[str]]] = None, + apps: Optional[Union[str, list[str]]] = None, gated: Optional[bool] = None, inference: Optional[Literal["warm"]] = None, - inference_provider: Optional[Union[Literal["all"], "PROVIDER_T", List["PROVIDER_T"]]] = None, + inference_provider: Optional[Union[Literal["all"], "PROVIDER_T", list["PROVIDER_T"]]] = None, model_name: Optional[str] = None, - trained_dataset: Optional[Union[str, List[str]]] = None, + trained_dataset: Optional[Union[str, list[str]]] = None, search: Optional[str] = None, pipeline_tag: Optional[str] = None, - emissions_thresholds: Optional[Tuple[float, float]] = None, + emissions_thresholds: Optional[tuple[float, float]] = None, # Sorting and pagination parameters sort: Union[Literal["last_modified"], str, None] = None, direction: Optional[Literal[-1]] = None, limit: Optional[int] = None, # Additional data to fetch - expand: Optional[List[ExpandModelProperty_T]] = None, + expand: Optional[list[ExpandModelProperty_T]] = None, full: Optional[bool] = None, cardData: bool = False, fetch_config: bool = False, token: Union[bool, str, None] = None, # Deprecated arguments - use `filter` instead - language: Optional[Union[str, List[str]]] = None, - library: Optional[Union[str, List[str]]] = None, - tags: Optional[Union[str, List[str]]] = None, - task: Optional[Union[str, List[str]]] = None, + language: Optional[Union[str, list[str]]] = None, + library: Optional[Union[str, list[str]]] = None, + tags: Optional[Union[str, list[str]]] = None, + task: Optional[Union[str, list[str]]] = None, ) -> Iterable[ModelInfo]: """ List models hosted on the Huggingface Hub, given some filters. @@ -1941,7 +1937,7 @@ def list_models( limit (`int`, *optional*): The limit on the number of models fetched. Leaving this option to `None` fetches all models. - expand (`List[ExpandModelProperty_T]`, *optional*): + expand (`list[ExpandModelProperty_T]`, *optional*): List properties to return in the response. When used, only the properties in the list will be returned. This parameter cannot be used if `full`, `cardData` or `fetch_config` are passed. Possible values are `"author"`, `"cardData"`, `"config"`, `"createdAt"`, `"disabled"`, `"downloads"`, `"downloadsAllTime"`, `"gated"`, `"gguf"`, `"inference"`, `"inferenceProviderMapping"`, `"lastModified"`, `"library_name"`, `"likes"`, `"mask_token"`, `"model-index"`, `"pipeline_tag"`, `"private"`, `"safetensors"`, `"sha"`, `"siblings"`, `"spaces"`, `"tags"`, `"transformersInfo"`, `"trendingScore"`, `"widgetData"`, `"resourceGroup"` and `"xetEnabled"`. @@ -2000,10 +1996,10 @@ def list_models( path = f"{self.endpoint}/api/models" headers = self._build_hf_headers(token=token) - params: Dict[str, Any] = {} + params: dict[str, Any] = {} # Build the filter list - filter_list: List[str] = [] + filter_list: list[str] = [] if filter: filter_list.extend([filter] if isinstance(filter, str) else filter) if library: @@ -2090,26 +2086,26 @@ def list_datasets( # Search-query parameter filter: Union[str, Iterable[str], None] = None, author: Optional[str] = None, - benchmark: Optional[Union[str, List[str]]] = None, + benchmark: Optional[Union[str, list[str]]] = None, dataset_name: Optional[str] = None, gated: Optional[bool] = None, - language_creators: Optional[Union[str, List[str]]] = None, - language: Optional[Union[str, List[str]]] = None, - multilinguality: Optional[Union[str, List[str]]] = None, - size_categories: Optional[Union[str, List[str]]] = None, - task_categories: Optional[Union[str, List[str]]] = None, - task_ids: Optional[Union[str, List[str]]] = None, + language_creators: Optional[Union[str, list[str]]] = None, + language: Optional[Union[str, list[str]]] = None, + multilinguality: Optional[Union[str, list[str]]] = None, + size_categories: Optional[Union[str, list[str]]] = None, + task_categories: Optional[Union[str, list[str]]] = None, + task_ids: Optional[Union[str, list[str]]] = None, search: Optional[str] = None, # Sorting and pagination parameters sort: Optional[Union[Literal["last_modified"], str]] = None, direction: Optional[Literal[-1]] = None, limit: Optional[int] = None, # Additional data to fetch - expand: Optional[List[ExpandDatasetProperty_T]] = None, + expand: Optional[list[ExpandDatasetProperty_T]] = None, full: Optional[bool] = None, token: Union[bool, str, None] = None, # Deprecated arguments - use `filter` instead - tags: Optional[Union[str, List[str]]] = None, + tags: Optional[Union[str, list[str]]] = None, ) -> Iterable[DatasetInfo]: """ List datasets hosted on the Huggingface Hub, given some filters. @@ -2164,7 +2160,7 @@ def list_datasets( limit (`int`, *optional*): The limit on the number of datasets fetched. Leaving this option to `None` fetches all datasets. - expand (`List[ExpandDatasetProperty_T]`, *optional*): + expand (`list[ExpandDatasetProperty_T]`, *optional*): List properties to return in the response. When used, only the properties in the list will be returned. This parameter cannot be used if `full` is passed. Possible values are `"author"`, `"cardData"`, `"citation"`, `"createdAt"`, `"disabled"`, `"description"`, `"downloads"`, `"downloadsAllTime"`, `"gated"`, `"lastModified"`, `"likes"`, `"paperswithcode_id"`, `"private"`, `"siblings"`, `"sha"`, `"tags"`, `"trendingScore"`, `"usedStorage"`, `"resourceGroup"` and `"xetEnabled"`. @@ -2224,7 +2220,7 @@ def list_datasets( path = f"{self.endpoint}/api/datasets" headers = self._build_hf_headers(token=token) - params: Dict[str, Any] = {} + params: dict[str, Any] = {} # Build `filter` list filter_list = [] @@ -2311,7 +2307,7 @@ def list_spaces( direction: Optional[Literal[-1]] = None, limit: Optional[int] = None, # Additional data to fetch - expand: Optional[List[ExpandSpaceProperty_T]] = None, + expand: Optional[list[ExpandSpaceProperty_T]] = None, full: Optional[bool] = None, token: Union[bool, str, None] = None, ) -> Iterable[SpaceInfo]: @@ -2342,7 +2338,7 @@ def list_spaces( limit (`int`, *optional*): The limit on the number of Spaces fetched. Leaving this option to `None` fetches all Spaces. - expand (`List[ExpandSpaceProperty_T]`, *optional*): + expand (`list[ExpandSpaceProperty_T]`, *optional*): List properties to return in the response. When used, only the properties in the list will be returned. This parameter cannot be used if `full` is passed. Possible values are `"author"`, `"cardData"`, `"datasets"`, `"disabled"`, `"lastModified"`, `"createdAt"`, `"likes"`, `"models"`, `"private"`, `"runtime"`, `"sdk"`, `"siblings"`, `"sha"`, `"subdomain"`, `"tags"`, `"trendingScore"`, `"usedStorage"`, `"resourceGroup"` and `"xetEnabled"`. @@ -2363,7 +2359,7 @@ def list_spaces( path = f"{self.endpoint}/api/spaces" headers = self._build_hf_headers(token=token) - params: Dict[str, Any] = {} + params: dict[str, Any] = {} if filter is not None: params["filter"] = filter if author is not None: @@ -2580,7 +2576,7 @@ def model_info( timeout: Optional[float] = None, securityStatus: Optional[bool] = None, files_metadata: bool = False, - expand: Optional[List[ExpandModelProperty_T]] = None, + expand: Optional[list[ExpandModelProperty_T]] = None, token: Union[bool, str, None] = None, ) -> ModelInfo: """ @@ -2603,7 +2599,7 @@ def model_info( files_metadata (`bool`, *optional*): Whether or not to retrieve metadata for files in the repository (size, LFS metadata, etc). Defaults to `False`. - expand (`List[ExpandModelProperty_T]`, *optional*): + expand (`list[ExpandModelProperty_T]`, *optional*): List properties to return in the response. When used, only the properties in the list will be returned. This parameter cannot be used if `securityStatus` or `files_metadata` are passed. Possible values are `"author"`, `"baseModels"`, `"cardData"`, `"childrenModelCount"`, `"config"`, `"createdAt"`, `"disabled"`, `"downloads"`, `"downloadsAllTime"`, `"gated"`, `"gguf"`, `"inference"`, `"inferenceProviderMapping"`, `"lastModified"`, `"library_name"`, `"likes"`, `"mask_token"`, `"model-index"`, `"pipeline_tag"`, `"private"`, `"safetensors"`, `"sha"`, `"siblings"`, `"spaces"`, `"tags"`, `"transformersInfo"`, `"trendingScore"`, `"widgetData"`, `"usedStorage"`, `"resourceGroup"` and `"xetEnabled"`. @@ -2637,7 +2633,7 @@ def model_info( if revision is None else (f"{self.endpoint}/api/models/{repo_id}/revision/{quote(revision, safe='')}") ) - params: Dict = {} + params: dict = {} if securityStatus: params["securityStatus"] = True if files_metadata: @@ -2657,7 +2653,7 @@ def dataset_info( revision: Optional[str] = None, timeout: Optional[float] = None, files_metadata: bool = False, - expand: Optional[List[ExpandDatasetProperty_T]] = None, + expand: Optional[list[ExpandDatasetProperty_T]] = None, token: Union[bool, str, None] = None, ) -> DatasetInfo: """ @@ -2677,7 +2673,7 @@ def dataset_info( files_metadata (`bool`, *optional*): Whether or not to retrieve metadata for files in the repository (size, LFS metadata, etc). Defaults to `False`. - expand (`List[ExpandDatasetProperty_T]`, *optional*): + expand (`list[ExpandDatasetProperty_T]`, *optional*): List properties to return in the response. When used, only the properties in the list will be returned. This parameter cannot be used if `files_metadata` is passed. Possible values are `"author"`, `"cardData"`, `"citation"`, `"createdAt"`, `"disabled"`, `"description"`, `"downloads"`, `"downloadsAllTime"`, `"gated"`, `"lastModified"`, `"likes"`, `"paperswithcode_id"`, `"private"`, `"siblings"`, `"sha"`, `"tags"`, `"trendingScore"`,`"usedStorage"`, `"resourceGroup"` and `"xetEnabled"`. @@ -2711,7 +2707,7 @@ def dataset_info( if revision is None else (f"{self.endpoint}/api/datasets/{repo_id}/revision/{quote(revision, safe='')}") ) - params: Dict = {} + params: dict = {} if files_metadata: params["blobs"] = True if expand: @@ -2730,7 +2726,7 @@ def space_info( revision: Optional[str] = None, timeout: Optional[float] = None, files_metadata: bool = False, - expand: Optional[List[ExpandSpaceProperty_T]] = None, + expand: Optional[list[ExpandSpaceProperty_T]] = None, token: Union[bool, str, None] = None, ) -> SpaceInfo: """ @@ -2750,7 +2746,7 @@ def space_info( files_metadata (`bool`, *optional*): Whether or not to retrieve metadata for files in the repository (size, LFS metadata, etc). Defaults to `False`. - expand (`List[ExpandSpaceProperty_T]`, *optional*): + expand (`list[ExpandSpaceProperty_T]`, *optional*): List properties to return in the response. When used, only the properties in the list will be returned. This parameter cannot be used if `full` is passed. Possible values are `"author"`, `"cardData"`, `"createdAt"`, `"datasets"`, `"disabled"`, `"lastModified"`, `"likes"`, `"models"`, `"private"`, `"runtime"`, `"sdk"`, `"siblings"`, `"sha"`, `"subdomain"`, `"tags"`, `"trendingScore"`, `"usedStorage"`, `"resourceGroup"` and `"xetEnabled"`. @@ -2784,7 +2780,7 @@ def space_info( if revision is None else (f"{self.endpoint}/api/spaces/{repo_id}/revision/{quote(revision, safe='')}") ) - params: Dict = {} + params: dict = {} if files_metadata: params["blobs"] = True if expand: @@ -3026,7 +3022,7 @@ def list_repo_files( revision: Optional[str] = None, repo_type: Optional[str] = None, token: Union[str, bool, None] = None, - ) -> List[str]: + ) -> list[str]: """ Get the list of files in a given repo. @@ -3045,7 +3041,7 @@ def list_repo_files( To disable authentication, pass `False`. Returns: - `List[str]`: the list of files in a given repository. + `list[str]`: the list of files in a given repository. """ return [ f.rfilename @@ -3247,7 +3243,7 @@ def list_repo_refs( hf_raise_for_status(response) data = response.json() - def _format_as_git_ref_info(item: Dict) -> GitRefInfo: + def _format_as_git_ref_info(item: dict) -> GitRefInfo: return GitRefInfo(name=item["name"], ref=item["ref"], target_commit=item["targetCommit"]) return GitRefs( @@ -3268,7 +3264,7 @@ def list_repo_commits( token: Union[bool, str, None] = None, revision: Optional[str] = None, formatted: bool = False, - ) -> List[GitCommitInfo]: + ) -> list[GitCommitInfo]: """ Get the list of commits of a given revision for a repo on the Hub. @@ -3315,7 +3311,7 @@ def list_repo_commits( ``` Returns: - List[[`GitCommitInfo`]]: list of objects containing information about the commits for a repo on the Hub. + list[[`GitCommitInfo`]]: list of objects containing information about the commits for a repo on the Hub. Raises: [`~utils.RepositoryNotFoundError`]: @@ -3349,20 +3345,20 @@ def list_repo_commits( def get_paths_info( self, repo_id: str, - paths: Union[List[str], str], + paths: Union[list[str], str], *, expand: bool = False, revision: Optional[str] = None, repo_type: Optional[str] = None, token: Union[str, bool, None] = None, - ) -> List[Union[RepoFile, RepoFolder]]: + ) -> list[Union[RepoFile, RepoFolder]]: """ Get information about a repo's paths. Args: repo_id (`str`): A namespace (user or an organization) and a repo name separated by a `/`. - paths (`Union[List[str], str]`, *optional*): + paths (`Union[list[str], str]`, *optional*): The paths to get information about. If a path do not exist, it is ignored without raising an exception. expand (`bool`, *optional*, defaults to `False`): @@ -3382,7 +3378,7 @@ def get_paths_info( To disable authentication, pass `False`. Returns: - `List[Union[RepoFile, RepoFolder]]`: + `list[Union[RepoFile, RepoFolder]]`: The information about the paths, as a list of [`RepoFile`] and [`RepoFolder`] objects. Raises: @@ -3647,8 +3643,8 @@ def create_repo( space_hardware: Optional[SpaceHardware] = None, space_storage: Optional[SpaceStorage] = None, space_sleep_time: Optional[int] = None, - space_secrets: Optional[List[Dict[str, str]]] = None, - space_variables: Optional[List[Dict[str, str]]] = None, + space_secrets: Optional[list[dict[str, str]]] = None, + space_variables: Optional[list[dict[str, str]]] = None, ) -> RepoUrl: """Create an empty repo on the HuggingFace Hub. @@ -3685,10 +3681,10 @@ def create_repo( your Space to sleep (default behavior for upgraded hardware). For free hardware, you can't configure the sleep time (value is fixed to 48 hours of inactivity). See https://huggingface.co/docs/hub/spaces-gpus#sleep-time for more details. - space_secrets (`List[Dict[str, str]]`, *optional*): + space_secrets (`list[dict[str, str]]`, *optional*): A list of secret keys to set in your Space. Each item is in the form `{"key": ..., "value": ..., "description": ...}` where description is optional. For more details, see https://huggingface.co/docs/hub/spaces-overview#managing-secrets. - space_variables (`List[Dict[str, str]]`, *optional*): + space_variables (`list[dict[str, str]]`, *optional*): A list of public environment variables to set in your Space. Each item is in the form `{"key": ..., "value": ..., "description": ...}` where description is optional. For more details, see https://huggingface.co/docs/hub/spaces-overview#managing-secrets-and-environment-variables. @@ -3703,7 +3699,7 @@ def create_repo( if repo_type not in constants.REPO_TYPES: raise ValueError("Invalid repo type") - json: Dict[str, Any] = {"name": name, "organization": organization} + json: dict[str, Any] = {"name": name, "organization": organization} if private is not None: json["private"] = private if repo_type is not None: @@ -3841,7 +3837,7 @@ def update_repo_visibility( *, token: Union[str, bool, None] = None, repo_type: Optional[str] = None, - ) -> Dict[str, bool]: + ) -> dict[str, bool]: """Update the visibility setting of a repository. Deprecated. Use `update_repo_settings` instead. @@ -3942,7 +3938,7 @@ def update_repo_settings( repo_type = constants.REPO_TYPE_MODEL # default repo type # Prepare the JSON payload for the PUT request - payload: Dict = {} + payload: dict = {} if gated is not None: if gated not in ["auto", "manual", False]: @@ -4760,9 +4756,9 @@ def upload_folder( # type: ignore revision: Optional[str] = None, create_pr: Optional[bool] = None, parent_commit: Optional[str] = None, - allow_patterns: Optional[Union[List[str], str]] = None, - ignore_patterns: Optional[Union[List[str], str]] = None, - delete_patterns: Optional[Union[List[str], str]] = None, + allow_patterns: Optional[Union[list[str], str]] = None, + ignore_patterns: Optional[Union[list[str], str]] = None, + delete_patterns: Optional[Union[list[str], str]] = None, run_as_future: Literal[False] = ..., ) -> CommitInfo: ... @@ -4780,9 +4776,9 @@ def upload_folder( # type: ignore revision: Optional[str] = None, create_pr: Optional[bool] = None, parent_commit: Optional[str] = None, - allow_patterns: Optional[Union[List[str], str]] = None, - ignore_patterns: Optional[Union[List[str], str]] = None, - delete_patterns: Optional[Union[List[str], str]] = None, + allow_patterns: Optional[Union[list[str], str]] = None, + ignore_patterns: Optional[Union[list[str], str]] = None, + delete_patterns: Optional[Union[list[str], str]] = None, run_as_future: Literal[True] = ..., ) -> Future[CommitInfo]: ... @@ -4801,9 +4797,9 @@ def upload_folder( revision: Optional[str] = None, create_pr: Optional[bool] = None, parent_commit: Optional[str] = None, - allow_patterns: Optional[Union[List[str], str]] = None, - ignore_patterns: Optional[Union[List[str], str]] = None, - delete_patterns: Optional[Union[List[str], str]] = None, + allow_patterns: Optional[Union[list[str], str]] = None, + ignore_patterns: Optional[Union[list[str], str]] = None, + delete_patterns: Optional[Union[list[str], str]] = None, run_as_future: bool = False, ) -> Union[CommitInfo, Future[CommitInfo]]: """ @@ -4865,11 +4861,11 @@ def upload_folder( If specified and `create_pr` is `True`, the pull request will be created from `parent_commit`. Specifying `parent_commit` ensures the repo has not changed before committing the changes, and can be especially useful if the repo is updated / committed to concurrently. - allow_patterns (`List[str]` or `str`, *optional*): + allow_patterns (`list[str]` or `str`, *optional*): If provided, only files matching at least one pattern are uploaded. - ignore_patterns (`List[str]` or `str`, *optional*): + ignore_patterns (`list[str]` or `str`, *optional*): If provided, files matching any of the patterns are not uploaded. - delete_patterns (`List[str]` or `str`, *optional*): + delete_patterns (`list[str]` or `str`, *optional*): If provided, remote files matching any of the patterns will be deleted from the repo while committing new files. This is useful if you don't know which files have already been uploaded. Note: to avoid discrepancies the `.gitattributes` file is not deleted even if it matches the pattern. @@ -5113,7 +5109,7 @@ def delete_file( def delete_files( self, repo_id: str, - delete_patterns: List[str], + delete_patterns: list[str], *, token: Union[bool, str, None] = None, repo_type: Optional[str] = None, @@ -5133,7 +5129,7 @@ def delete_files( repo_id (`str`): The repository from which the folder will be deleted, for example: `"username/custom_transformers"` - delete_patterns (`List[str]`): + delete_patterns (`list[str]`): List of files or folders to delete. Each string can either be a file path, a folder path or a Unix shell-style wildcard. E.g. `["file.txt", "folder/", "data/*.parquet"]` @@ -5261,8 +5257,8 @@ def upload_large_folder( repo_type: str, # Repo type is required! revision: Optional[str] = None, private: Optional[bool] = None, - allow_patterns: Optional[Union[List[str], str]] = None, - ignore_patterns: Optional[Union[List[str], str]] = None, + allow_patterns: Optional[Union[list[str], str]] = None, + ignore_patterns: Optional[Union[list[str], str]] = None, num_workers: Optional[int] = None, print_report: bool = True, print_report_every: int = 60, @@ -5290,9 +5286,9 @@ def upload_large_folder( private (`bool`, `optional`): Whether the repository should be private. If `None` (default), the repo will be public unless the organization's default is private. - allow_patterns (`List[str]` or `str`, *optional*): + allow_patterns (`list[str]` or `str`, *optional*): If provided, only files matching at least one pattern are uploaded. - ignore_patterns (`List[str]` or `str`, *optional*): + ignore_patterns (`list[str]` or `str`, *optional*): If provided, files matching any of the patterns are not uploaded. num_workers (`int`, *optional*): Number of workers to start. Defaults to `os.cpu_count() - 2` (minimum 2). @@ -5562,10 +5558,10 @@ def snapshot_download( force_download: bool = False, token: Union[bool, str, None] = None, local_files_only: bool = False, - allow_patterns: Optional[Union[List[str], str]] = None, - ignore_patterns: Optional[Union[List[str], str]] = None, + allow_patterns: Optional[Union[list[str], str]] = None, + ignore_patterns: Optional[Union[list[str], str]] = None, max_workers: int = 8, - tqdm_class: Optional[Type[base_tqdm]] = None, + tqdm_class: Optional[type[base_tqdm]] = None, # Deprecated args local_dir_use_symlinks: Union[bool, Literal["auto"]] = "auto", resume_download: Optional[bool] = None, @@ -5611,9 +5607,9 @@ def snapshot_download( local_files_only (`bool`, *optional*, defaults to `False`): If `True`, avoid downloading the file and return the path to the local cached file if it exists. - allow_patterns (`List[str]` or `str`, *optional*): + allow_patterns (`list[str]` or `str`, *optional*): If provided, only files matching at least one pattern are downloaded. - ignore_patterns (`List[str]` or `str`, *optional*): + ignore_patterns (`list[str]` or `str`, *optional*): If provided, files matching any of the patterns are not downloaded. max_workers (`int`, *optional*): Number of concurrent threads to download files (1 thread = 1 file download). @@ -6276,7 +6272,7 @@ def get_repo_discussions( headers = self._build_hf_headers(token=token) path = f"{self.endpoint}/api/{repo_type}s/{repo_id}/discussions" - params: Dict[str, Union[str, int]] = {} + params: dict[str, Union[str, int]] = {} if discussion_type is not None: params["type"] = discussion_type if discussion_status is not None: @@ -6785,7 +6781,7 @@ def change_discussion_status( """ if new_status not in ["open", "closed"]: raise ValueError("Invalid status, valid statuses are: 'open' and 'closed'") - body: Dict[str, str] = {"status": new_status} + body: dict[str, str] = {"status": new_status} if comment and comment.strip(): body["comment"] = comment.strip() resp = self._post_discussion_changes( @@ -7045,7 +7041,7 @@ def delete_space_secret(self, repo_id: str, key: str, *, token: Union[bool, str, hf_raise_for_status(r) @validate_hf_hub_args - def get_space_variables(self, repo_id: str, *, token: Union[bool, str, None] = None) -> Dict[str, SpaceVariable]: + def get_space_variables(self, repo_id: str, *, token: Union[bool, str, None] = None) -> dict[str, SpaceVariable]: """Gets all variables from a Space. Variables allow to set environment variables to a Space without hardcoding them. @@ -7076,7 +7072,7 @@ def add_space_variable( *, description: Optional[str] = None, token: Union[bool, str, None] = None, - ) -> Dict[str, SpaceVariable]: + ) -> dict[str, SpaceVariable]: """Adds or updates a variable in a Space. Variables allow to set environment variables to a Space without hardcoding them. @@ -7111,7 +7107,7 @@ def add_space_variable( @validate_hf_hub_args def delete_space_variable( self, repo_id: str, key: str, *, token: Union[bool, str, None] = None - ) -> Dict[str, SpaceVariable]: + ) -> dict[str, SpaceVariable]: """Deletes a variable from a Space. Variables allow to set environment variables to a Space without hardcoding them. @@ -7200,7 +7196,7 @@ def request_space_hardware( " you want to set a custom sleep time, you need to upgrade to a paid Hardware.", UserWarning, ) - payload: Dict[str, Any] = {"flavor": hardware} + payload: dict[str, Any] = {"flavor": hardware} if sleep_time is not None: payload["sleepTimeSeconds"] = sleep_time r = get_session().post( @@ -7359,8 +7355,8 @@ def duplicate_space( hardware: Optional[SpaceHardware] = None, storage: Optional[SpaceStorage] = None, sleep_time: Optional[int] = None, - secrets: Optional[List[Dict[str, str]]] = None, - variables: Optional[List[Dict[str, str]]] = None, + secrets: Optional[list[dict[str, str]]] = None, + variables: Optional[list[dict[str, str]]] = None, ) -> RepoUrl: """Duplicate a Space. @@ -7391,10 +7387,10 @@ def duplicate_space( your Space to sleep (default behavior for upgraded hardware). For free hardware, you can't configure the sleep time (value is fixed to 48 hours of inactivity). See https://huggingface.co/docs/hub/spaces-gpus#sleep-time for more details. - secrets (`List[Dict[str, str]]`, *optional*): + secrets (`list[dict[str, str]]`, *optional*): A list of secret keys to set in your Space. Each item is in the form `{"key": ..., "value": ..., "description": ...}` where description is optional. For more details, see https://huggingface.co/docs/hub/spaces-overview#managing-secrets. - variables (`List[Dict[str, str]]`, *optional*): + variables (`list[dict[str, str]]`, *optional*): A list of public environment variables to set in your Space. Each item is in the form `{"key": ..., "value": ..., "description": ...}` where description is optional. For more details, see https://huggingface.co/docs/hub/spaces-overview#managing-secrets-and-environment-variables. @@ -7434,7 +7430,7 @@ def duplicate_space( to_repo_name = parsed_to_id.repo_name if to_id is not None else RepoUrl(from_id).repo_name # type: ignore # repository must be a valid repo_id (namespace/repo_name). - payload: Dict[str, Any] = {"repository": f"{to_namespace}/{to_repo_name}"} + payload: dict[str, Any] = {"repository": f"{to_namespace}/{to_repo_name}"} keys = ["private", "hardware", "storageTier", "sleepTimeSeconds", "secrets", "variables"] values = [private, hardware, storage, sleep_time, secrets, variables] @@ -7495,7 +7491,7 @@ def request_space_storage(
""" - payload: Dict[str, SpaceStorage] = {"tier": storage} + payload: dict[str, SpaceStorage] = {"tier": storage} r = get_session().post( f"{self.endpoint}/api/spaces/{repo_id}/storage", headers=self._build_hf_headers(token=token), @@ -7541,7 +7537,7 @@ def delete_space_storage( def list_inference_endpoints( self, namespace: Optional[str] = None, *, token: Union[bool, str, None] = None - ) -> List[InferenceEndpoint]: + ) -> list[InferenceEndpoint]: """Lists all inference endpoints for the given namespace. Args: @@ -7555,7 +7551,7 @@ def list_inference_endpoints( To disable authentication, pass `False`. Returns: - List[`InferenceEndpoint`]: A list of all inference endpoints for the given namespace. + list[`InferenceEndpoint`]: A list of all inference endpoints for the given namespace. Example: ```python @@ -7570,7 +7566,7 @@ def list_inference_endpoints( user = self.whoami(token=token) # List personal endpoints first - endpoints: List[InferenceEndpoint] = list_inference_endpoints(namespace=self._get_namespace(token=token)) + endpoints: list[InferenceEndpoint] = list_inference_endpoints(namespace=self._get_namespace(token=token)) # Then list endpoints for all orgs the user belongs to and ignore 401 errors (no billing or no access) for org in user.get("orgs", []): @@ -7614,14 +7610,14 @@ def create_inference_endpoint( scale_to_zero_timeout: Optional[int] = None, revision: Optional[str] = None, task: Optional[str] = None, - custom_image: Optional[Dict] = None, - env: Optional[Dict[str, str]] = None, - secrets: Optional[Dict[str, str]] = None, + custom_image: Optional[dict] = None, + env: Optional[dict[str, str]] = None, + secrets: Optional[dict[str, str]] = None, type: InferenceEndpointType = InferenceEndpointType.PROTECTED, domain: Optional[str] = None, path: Optional[str] = None, cache_http_responses: Optional[bool] = None, - tags: Optional[List[str]] = None, + tags: Optional[list[str]] = None, namespace: Optional[str] = None, token: Union[bool, str, None] = None, ) -> InferenceEndpoint: @@ -7658,12 +7654,12 @@ def create_inference_endpoint( The specific model revision to deploy on the Inference Endpoint (e.g. `"6c0e6080953db56375760c0471a8c5f2929baf11"`). task (`str`, *optional*): The task on which to deploy the model (e.g. `"text-classification"`). - custom_image (`Dict`, *optional*): + custom_image (`dict`, *optional*): A custom Docker image to use for the Inference Endpoint. This is useful if you want to deploy an Inference Endpoint running on the `text-generation-inference` (TGI) framework (see examples). - env (`Dict[str, str]`, *optional*): + env (`dict[str, str]`, *optional*): Non-secret environment variables to inject in the container environment. - secrets (`Dict[str, str]`, *optional*): + secrets (`dict[str, str]`, *optional*): Secret values to inject in the container environment. type ([`InferenceEndpointType]`, *optional*): The type of the Inference Endpoint, which can be `"protected"` (default), `"public"` or `"private"`. @@ -7673,7 +7669,7 @@ def create_inference_endpoint( The custom path to the deployed model, should start with a `/` (e.g. `"/models/google-bert/bert-base-uncased"`). cache_http_responses (`bool`, *optional*): Whether to cache HTTP responses from the Inference Endpoint. Defaults to `False`. - tags (`List[str]`, *optional*): + tags (`list[str]`, *optional*): A list of tags to associate with the Inference Endpoint. namespace (`str`, *optional*): The namespace where the Inference Endpoint will be created. Defaults to the current user's namespace. @@ -7776,7 +7772,7 @@ def create_inference_endpoint( else: image = {"huggingface": {}} - payload: Dict = { + payload: dict = { "accountId": account_id, "compute": { "accelerator": accelerator, @@ -7865,7 +7861,7 @@ def create_inference_endpoint_from_catalog( """ token = token or self.token or get_token() - payload: Dict = { + payload: dict = { "namespace": namespace or self._get_namespace(token=token), "repoId": repo_id, } @@ -7883,7 +7879,7 @@ def create_inference_endpoint_from_catalog( @experimental @validate_hf_hub_args - def list_inference_catalog(self, *, token: Union[bool, str, None] = None) -> List[str]: + def list_inference_catalog(self, *, token: Union[bool, str, None] = None) -> list[str]: """List models available in the Hugging Face Inference Catalog. The goal of the Inference Catalog is to provide a curated list of models that are optimized for inference @@ -7899,7 +7895,7 @@ def list_inference_catalog(self, *, token: Union[bool, str, None] = None) -> Lis https://huggingface.co/docs/huggingface_hub/quick-start#authentication). Returns: - List[`str`]: A list of model IDs available in the catalog. + list[`str`]: A list of model IDs available in the catalog. `list_inference_catalog` is experimental. Its API is subject to change in the future. Please provide feedback @@ -7977,15 +7973,15 @@ def update_inference_endpoint( framework: Optional[str] = None, revision: Optional[str] = None, task: Optional[str] = None, - custom_image: Optional[Dict] = None, - env: Optional[Dict[str, str]] = None, - secrets: Optional[Dict[str, str]] = None, + custom_image: Optional[dict] = None, + env: Optional[dict[str, str]] = None, + secrets: Optional[dict[str, str]] = None, # Route update domain: Optional[str] = None, path: Optional[str] = None, # Other cache_http_responses: Optional[bool] = None, - tags: Optional[List[str]] = None, + tags: Optional[list[str]] = None, namespace: Optional[str] = None, token: Union[bool, str, None] = None, ) -> InferenceEndpoint: @@ -8021,12 +8017,12 @@ def update_inference_endpoint( The specific model revision to deploy on the Inference Endpoint (e.g. `"6c0e6080953db56375760c0471a8c5f2929baf11"`). task (`str`, *optional*): The task on which to deploy the model (e.g. `"text-classification"`). - custom_image (`Dict`, *optional*): + custom_image (`dict`, *optional*): A custom Docker image to use for the Inference Endpoint. This is useful if you want to deploy an Inference Endpoint running on the `text-generation-inference` (TGI) framework (see examples). - env (`Dict[str, str]`, *optional*): + env (`dict[str, str]`, *optional*): Non-secret environment variables to inject in the container environment - secrets (`Dict[str, str]`, *optional*): + secrets (`dict[str, str]`, *optional*): Secret values to inject in the container environment. domain (`str`, *optional*): @@ -8036,7 +8032,7 @@ def update_inference_endpoint( cache_http_responses (`bool`, *optional*): Whether to cache HTTP responses from the Inference Endpoint. - tags (`List[str]`, *optional*): + tags (`list[str]`, *optional*): A list of tags to associate with the Inference Endpoint. namespace (`str`, *optional*): @@ -8053,7 +8049,7 @@ def update_inference_endpoint( namespace = namespace or self._get_namespace(token=token) # Populate only the fields that are not None - payload: Dict = defaultdict(lambda: defaultdict(dict)) + payload: dict = defaultdict(lambda: defaultdict(dict)) if accelerator is not None: payload["compute"]["accelerator"] = accelerator if instance_size is not None: @@ -8260,8 +8256,8 @@ def _get_namespace(self, token: Union[bool, str, None] = None) -> str: def list_collections( self, *, - owner: Union[List[str], str, None] = None, - item: Union[List[str], str, None] = None, + owner: Union[list[str], str, None] = None, + item: Union[list[str], str, None] = None, sort: Optional[Literal["lastModified", "trending", "upvotes"]] = None, limit: Optional[int] = None, token: Union[bool, str, None] = None, @@ -8276,9 +8272,9 @@ def list_collections( Args: - owner (`List[str]` or `str`, *optional*): + owner (`list[str]` or `str`, *optional*): Filter by owner's username. - item (`List[str]` or `str`, *optional*): + item (`list[str]` or `str`, *optional*): Filter collections containing a particular items. Example: `"models/teknium/OpenHermes-2.5-Mistral-7B"`, `"datasets/squad"` or `"papers/2311.12983"`. sort (`Literal["lastModified", "trending", "upvotes"]`, *optional*): Sort collections by last modified, trending or upvotes. @@ -8296,7 +8292,7 @@ def list_collections( # Construct the API endpoint path = f"{self.endpoint}/api/collections" headers = self._build_hf_headers(token=token) - params: Dict = {} + params: dict = {} if owner is not None: params.update({"owner": owner}) if item is not None: @@ -8595,7 +8591,7 @@ def add_collection_item( (...) ``` """ - payload: Dict[str, Any] = {"item": {"id": item_id, "type": item_type}} + payload: dict[str, Any] = {"item": {"id": item_id, "type": item_type}} if note is not None: payload["note"] = note r = get_session().post( @@ -8725,7 +8721,7 @@ def delete_collection_item( @validate_hf_hub_args def list_pending_access_requests( self, repo_id: str, *, repo_type: Optional[str] = None, token: Union[bool, str, None] = None - ) -> List[AccessRequest]: + ) -> list[AccessRequest]: """ Get pending access requests for a given gated repo. @@ -8748,7 +8744,7 @@ def list_pending_access_requests( To disable authentication, pass `False`. Returns: - `List[AccessRequest]`: A list of [`AccessRequest`] objects. Each time contains a `username`, `email`, + `list[AccessRequest]`: A list of [`AccessRequest`] objects. Each time contains a `username`, `email`, `status` and `timestamp` attribute. If the gated repo has a custom form, the `fields` attribute will be populated with user's answers. @@ -8789,7 +8785,7 @@ def list_pending_access_requests( @validate_hf_hub_args def list_accepted_access_requests( self, repo_id: str, *, repo_type: Optional[str] = None, token: Union[bool, str, None] = None - ) -> List[AccessRequest]: + ) -> list[AccessRequest]: """ Get accepted access requests for a given gated repo. @@ -8814,7 +8810,7 @@ def list_accepted_access_requests( To disable authentication, pass `False`. Returns: - `List[AccessRequest]`: A list of [`AccessRequest`] objects. Each time contains a `username`, `email`, + `list[AccessRequest]`: A list of [`AccessRequest`] objects. Each time contains a `username`, `email`, `status` and `timestamp` attribute. If the gated repo has a custom form, the `fields` attribute will be populated with user's answers. @@ -8851,7 +8847,7 @@ def list_accepted_access_requests( @validate_hf_hub_args def list_rejected_access_requests( self, repo_id: str, *, repo_type: Optional[str] = None, token: Union[bool, str, None] = None - ) -> List[AccessRequest]: + ) -> list[AccessRequest]: """ Get rejected access requests for a given gated repo. @@ -8876,7 +8872,7 @@ def list_rejected_access_requests( To disable authentication, pass `False`. Returns: - `List[AccessRequest]`: A list of [`AccessRequest`] objects. Each time contains a `username`, `email`, + `list[AccessRequest]`: A list of [`AccessRequest`] objects. Each time contains a `username`, `email`, `status` and `timestamp` attribute. If the gated repo has a custom form, the `fields` attribute will be populated with user's answers. @@ -8916,7 +8912,7 @@ def _list_access_requests( status: Literal["accepted", "rejected", "pending"], repo_type: Optional[str] = None, token: Union[bool, str, None] = None, - ) -> List[AccessRequest]: + ) -> list[AccessRequest]: if repo_type not in constants.REPO_TYPES: raise ValueError(f"Invalid repo type, must be one of {constants.REPO_TYPES}") if repo_type is None: @@ -9209,7 +9205,7 @@ def get_webhook(self, webhook_id: str, *, token: Union[bool, str, None] = None) return webhook @validate_hf_hub_args - def list_webhooks(self, *, token: Union[bool, str, None] = None) -> List[WebhookInfo]: + def list_webhooks(self, *, token: Union[bool, str, None] = None) -> list[WebhookInfo]: """List all configured webhooks. Args: @@ -9219,7 +9215,7 @@ def list_webhooks(self, *, token: Union[bool, str, None] = None) -> List[Webhook To disable authentication, pass `False`. Returns: - `List[WebhookInfo]`: + `list[WebhookInfo]`: List of webhook info objects. Example: @@ -9263,8 +9259,8 @@ def create_webhook( self, *, url: str, - watched: List[Union[Dict, WebhookWatchedItem]], - domains: Optional[List[constants.WEBHOOK_DOMAIN_T]] = None, + watched: list[Union[dict, WebhookWatchedItem]], + domains: Optional[list[constants.WEBHOOK_DOMAIN_T]] = None, secret: Optional[str] = None, token: Union[bool, str, None] = None, ) -> WebhookInfo: @@ -9273,10 +9269,10 @@ def create_webhook( Args: url (`str`): URL to send the payload to. - watched (`List[WebhookWatchedItem]`): + watched (`list[WebhookWatchedItem]`): List of [`WebhookWatchedItem`] to be watched by the webhook. It can be users, orgs, models, datasets or spaces. Watched items can also be provided as plain dictionaries. - domains (`List[Literal["repo", "discussion"]]`, optional): + domains (`list[Literal["repo", "discussion"]]`, optional): List of domains to watch. It can be "repo", "discussion" or both. secret (`str`, optional): A secret to sign the payload with. @@ -9337,8 +9333,8 @@ def update_webhook( webhook_id: str, *, url: Optional[str] = None, - watched: Optional[List[Union[Dict, WebhookWatchedItem]]] = None, - domains: Optional[List[constants.WEBHOOK_DOMAIN_T]] = None, + watched: Optional[list[Union[dict, WebhookWatchedItem]]] = None, + domains: Optional[list[constants.WEBHOOK_DOMAIN_T]] = None, secret: Optional[str] = None, token: Union[bool, str, None] = None, ) -> WebhookInfo: @@ -9349,10 +9345,10 @@ def update_webhook( The unique identifier of the webhook to be updated. url (`str`, optional): The URL to which the payload will be sent. - watched (`List[WebhookWatchedItem]`, optional): + watched (`list[WebhookWatchedItem]`, optional): List of items to watch. It can be users, orgs, models, datasets, or spaces. Refer to [`WebhookWatchedItem`] for more details. Watched items can also be provided as plain dictionaries. - domains (`List[Literal["repo", "discussion"]]`, optional): + domains (`list[Literal["repo", "discussion"]]`, optional): The domains to watch. This can include "repo", "discussion", or both. secret (`str`, optional): A secret to sign the payload with, providing an additional layer of security. @@ -9548,8 +9544,8 @@ def _build_hf_headers( token: Union[bool, str, None] = None, library_name: Optional[str] = None, library_version: Optional[str] = None, - user_agent: Union[Dict, str, None] = None, - ) -> Dict[str, str]: + user_agent: Union[dict, str, None] = None, + ) -> dict[str, str]: """ Alias for [`build_hf_headers`] that uses the token from [`HfApi`] client when `token` is not provided. @@ -9571,9 +9567,9 @@ def _prepare_folder_deletions( repo_type: Optional[str], revision: Optional[str], path_in_repo: str, - delete_patterns: Optional[Union[List[str], str]], + delete_patterns: Optional[Union[list[str], str]], token: Union[bool, str, None] = None, - ) -> List[CommitOperationDelete]: + ) -> list[CommitOperationDelete]: """Generate the list of Delete operations for a commit to delete files from a repo. List remote files and match them against the `delete_patterns` constraints. Returns a list of [`CommitOperationDelete`] @@ -9609,11 +9605,11 @@ def _prepare_upload_folder_additions( self, folder_path: Union[str, Path], path_in_repo: str, - allow_patterns: Optional[Union[List[str], str]] = None, - ignore_patterns: Optional[Union[List[str], str]] = None, + allow_patterns: Optional[Union[list[str], str]] = None, + ignore_patterns: Optional[Union[list[str], str]] = None, repo_type: Optional[str] = None, token: Union[bool, str, None] = None, - ) -> List[CommitOperationAdd]: + ) -> list[CommitOperationAdd]: """Generate the list of Add operations for a commit to upload a folder. Files not matching the `allow_patterns` (allowlist) and `ignore_patterns` (denylist) @@ -9955,9 +9951,9 @@ def run_job( self, *, image: str, - command: List[str], - env: Optional[Dict[str, Any]] = None, - secrets: Optional[Dict[str, Any]] = None, + command: list[str], + env: Optional[dict[str, Any]] = None, + secrets: Optional[dict[str, Any]] = None, flavor: Optional[SpaceHardware] = None, timeout: Optional[Union[int, float, str]] = None, namespace: Optional[str] = None, @@ -9972,13 +9968,13 @@ def run_job( Examples: `"ubuntu"`, `"python:3.12"`, `"pytorch/pytorch:2.6.0-cuda12.4-cudnn9-devel"`. Example with an image from a Space: `"hf.co/spaces/lhoestq/duckdb"`. - command (`List[str]`): + command (`list[str]`): The command to run. Example: `["echo", "hello"]`. - env (`Dict[str, Any]`, *optional*): + env (`dict[str, Any]`, *optional*): Defines the environment variables for the Job. - secrets (`Dict[str, Any]`, *optional*): + secrets (`dict[str, Any]`, *optional*): Defines the secret environment variables for the Job. flavor (`str`, *optional*): @@ -10131,7 +10127,7 @@ def list_jobs( timeout: Optional[int] = None, namespace: Optional[str] = None, token: Union[bool, str, None] = None, - ) -> List[JobInfo]: + ) -> list[JobInfo]: """ List compute Jobs on Hugging Face infrastructure. @@ -10242,12 +10238,12 @@ def run_uv_job( self, script: str, *, - script_args: Optional[List[str]] = None, - dependencies: Optional[List[str]] = None, + script_args: Optional[list[str]] = None, + dependencies: Optional[list[str]] = None, python: Optional[str] = None, image: Optional[str] = None, - env: Optional[Dict[str, Any]] = None, - secrets: Optional[Dict[str, Any]] = None, + env: Optional[dict[str, Any]] = None, + secrets: Optional[dict[str, Any]] = None, flavor: Optional[SpaceHardware] = None, timeout: Optional[Union[int, float, str]] = None, namespace: Optional[str] = None, @@ -10261,10 +10257,10 @@ def run_uv_job( script (`str`): Path or URL of the UV script, or a command. - script_args (`List[str]`, *optional*) + script_args (`list[str]`, *optional*) Arguments to pass to the script or command. - dependencies (`List[str]`, *optional*) + dependencies (`list[str]`, *optional*) Dependencies to use to run the UV script. python (`str`, *optional*) @@ -10273,10 +10269,10 @@ def run_uv_job( image (`str`, *optional*, defaults to "ghcr.io/astral-sh/uv:python3.12-bookworm"): Use a custom Docker image with `uv` installed. - env (`Dict[str, Any]`, *optional*): + env (`dict[str, Any]`, *optional*): Defines the environment variables for the Job. - secrets (`Dict[str, Any]`, *optional*): + secrets (`dict[str, Any]`, *optional*): Defines the secret environment variables for the Job. flavor (`str`, *optional*): @@ -10356,12 +10352,12 @@ def create_scheduled_job( self, *, image: str, - command: List[str], + command: list[str], schedule: str, suspend: Optional[bool] = None, concurrency: Optional[bool] = None, - env: Optional[Dict[str, Any]] = None, - secrets: Optional[Dict[str, Any]] = None, + env: Optional[dict[str, Any]] = None, + secrets: Optional[dict[str, Any]] = None, flavor: Optional[SpaceHardware] = None, timeout: Optional[Union[int, float, str]] = None, namespace: Optional[str] = None, @@ -10376,7 +10372,7 @@ def create_scheduled_job( Examples: `"ubuntu"`, `"python:3.12"`, `"pytorch/pytorch:2.6.0-cuda12.4-cudnn9-devel"`. Example with an image from a Space: `"hf.co/spaces/lhoestq/duckdb"`. - command (`List[str]`): + command (`list[str]`): The command to run. Example: `["echo", "hello"]`. schedule (`str`): @@ -10389,10 +10385,10 @@ def create_scheduled_job( concurrency (`bool`, *optional*): If True, multiple instances of this Job can run concurrently. Defaults to False. - env (`Dict[str, Any]`, *optional*): + env (`dict[str, Any]`, *optional*): Defines the environment variables for the Job. - secrets (`Dict[str, Any]`, *optional*): + secrets (`dict[str, Any]`, *optional*): Defines the secret environment variables for the Job. flavor (`str`, *optional*): @@ -10448,7 +10444,7 @@ def create_scheduled_job( flavor=flavor, timeout=timeout, ) - input_json: Dict[str, Any] = { + input_json: dict[str, Any] = { "jobSpec": job_spec, "schedule": schedule, } @@ -10471,7 +10467,7 @@ def list_scheduled_jobs( timeout: Optional[int] = None, namespace: Optional[str] = None, token: Union[bool, str, None] = None, - ) -> List[ScheduledJobInfo]: + ) -> list[ScheduledJobInfo]: """ List scheduled compute Jobs on Hugging Face infrastructure. @@ -10629,15 +10625,15 @@ def create_scheduled_uv_job( self, script: str, *, - script_args: Optional[List[str]] = None, + script_args: Optional[list[str]] = None, schedule: str, suspend: Optional[bool] = None, concurrency: Optional[bool] = None, - dependencies: Optional[List[str]] = None, + dependencies: Optional[list[str]] = None, python: Optional[str] = None, image: Optional[str] = None, - env: Optional[Dict[str, Any]] = None, - secrets: Optional[Dict[str, Any]] = None, + env: Optional[dict[str, Any]] = None, + secrets: Optional[dict[str, Any]] = None, flavor: Optional[SpaceHardware] = None, timeout: Optional[Union[int, float, str]] = None, namespace: Optional[str] = None, @@ -10651,7 +10647,7 @@ def create_scheduled_uv_job( script (`str`): Path or URL of the UV script, or a command. - script_args (`List[str]`, *optional*) + script_args (`list[str]`, *optional*) Arguments to pass to the script, or a command. schedule (`str`): @@ -10664,7 +10660,7 @@ def create_scheduled_uv_job( concurrency (`bool`, *optional*): If True, multiple instances of this Job can run concurrently. Defaults to False. - dependencies (`List[str]`, *optional*) + dependencies (`list[str]`, *optional*) Dependencies to use to run the UV script. python (`str`, *optional*) @@ -10673,10 +10669,10 @@ def create_scheduled_uv_job( image (`str`, *optional*, defaults to "ghcr.io/astral-sh/uv:python3.12-bookworm"): Use a custom Docker image with `uv` installed. - env (`Dict[str, Any]`, *optional*): + env (`dict[str, Any]`, *optional*): Defines the environment variables for the Job. - secrets (`Dict[str, Any]`, *optional*): + secrets (`dict[str, Any]`, *optional*): Defines the secret environment variables for the Job. flavor (`str`, *optional*): @@ -10756,15 +10752,15 @@ def _create_uv_command_env_and_secrets( self, *, script: str, - script_args: Optional[List[str]], - dependencies: Optional[List[str]], + script_args: Optional[list[str]], + dependencies: Optional[list[str]], python: Optional[str], - env: Optional[Dict[str, Any]], - secrets: Optional[Dict[str, Any]], + env: Optional[dict[str, Any]], + secrets: Optional[dict[str, Any]], namespace: Optional[str], token: Union[bool, str, None], _repo: Optional[str], - ) -> Tuple[List[str], Dict[str, Any], Dict[str, Any]]: + ) -> tuple[list[str], dict[str, Any], dict[str, Any]]: env = env or {} secrets = secrets or {} diff --git a/src/huggingface_hub/hf_file_system.py b/src/huggingface_hub/hf_file_system.py index e82365e3ce..ff3bb02661 100644 --- a/src/huggingface_hub/hf_file_system.py +++ b/src/huggingface_hub/hf_file_system.py @@ -7,7 +7,7 @@ from datetime import datetime from itertools import chain from pathlib import Path -from typing import Any, Dict, Iterator, List, NoReturn, Optional, Tuple, Union +from typing import Any, Iterator, NoReturn, Optional, Union from urllib.parse import quote, unquote import fsspec @@ -114,13 +114,13 @@ def __init__( # Maps (repo_type, repo_id, revision) to a 2-tuple with: # * the 1st element indicating whether the repositoy and the revision exist # * the 2nd element being the exception raised if the repository or revision doesn't exist - self._repo_and_revision_exists_cache: Dict[ - Tuple[str, str, Optional[str]], Tuple[bool, Optional[Exception]] + self._repo_and_revision_exists_cache: dict[ + tuple[str, str, Optional[str]], tuple[bool, Optional[Exception]] ] = {} def _repo_and_revision_exist( self, repo_type: str, repo_id: str, revision: Optional[str] - ) -> Tuple[bool, Optional[Exception]]: + ) -> tuple[bool, Optional[Exception]]: if (repo_type, repo_id, revision) not in self._repo_and_revision_exists_cache: try: self._api.repo_info( @@ -339,7 +339,7 @@ def rm( def ls( self, path: str, detail: bool = True, refresh: bool = False, revision: Optional[str] = None, **kwargs - ) -> List[Union[str, Dict[str, Any]]]: + ) -> list[Union[str, dict[str, Any]]]: """ List the contents of a directory. @@ -363,7 +363,7 @@ def ls( The git revision to list from. Returns: - `List[Union[str, Dict[str, Any]]]`: List of file paths (if detail=False) or list of file information + `list[Union[str, dict[str, Any]]]`: List of file paths (if detail=False) or list of file information dictionaries (if detail=True). """ resolved_path = self.resolve_path(path, revision=revision) @@ -484,7 +484,7 @@ def _ls_tree( out.append(cache_path_info) return out - def walk(self, path: str, *args, **kwargs) -> Iterator[Tuple[str, List[str], List[str]]]: + def walk(self, path: str, *args, **kwargs) -> Iterator[tuple[str, list[str], list[str]]]: """ Return all files below the given path. @@ -495,12 +495,12 @@ def walk(self, path: str, *args, **kwargs) -> Iterator[Tuple[str, List[str], Lis Root path to list files from. Returns: - `Iterator[Tuple[str, List[str], List[str]]]`: An iterator of (path, list of directory names, list of file names) tuples. + `Iterator[tuple[str, list[str], list[str]]]`: An iterator of (path, list of directory names, list of file names) tuples. """ path = self.resolve_path(path, revision=kwargs.get("revision")).unresolve() yield from super().walk(path, *args, **kwargs) - def glob(self, path: str, **kwargs) -> List[str]: + def glob(self, path: str, **kwargs) -> list[str]: """ Find files by glob-matching. @@ -511,7 +511,7 @@ def glob(self, path: str, **kwargs) -> List[str]: Path pattern to match. Returns: - `List[str]`: List of paths matching the pattern. + `list[str]`: List of paths matching the pattern. """ path = self.resolve_path(path, revision=kwargs.get("revision")).unresolve() return super().glob(path, **kwargs) @@ -525,7 +525,7 @@ def find( refresh: bool = False, revision: Optional[str] = None, **kwargs, - ) -> Union[List[str], Dict[str, Dict[str, Any]]]: + ) -> Union[list[str], dict[str, dict[str, Any]]]: """ List all files below path. @@ -546,7 +546,7 @@ def find( The git revision to list from. Returns: - `Union[List[str], Dict[str, Dict[str, Any]]]`: List of paths or dict of file information. + `Union[list[str], dict[str, dict[str, Any]]]`: List of paths or dict of file information. """ if maxdepth: return super().find( @@ -651,7 +651,7 @@ def modified(self, path: str, **kwargs) -> datetime: info = self.info(path, **{**kwargs, "expand_info": True}) return info["last_commit"]["date"] - def info(self, path: str, refresh: bool = False, revision: Optional[str] = None, **kwargs) -> Dict[str, Any]: + def info(self, path: str, refresh: bool = False, revision: Optional[str] = None, **kwargs) -> dict[str, Any]: """ Get information about a file or directory. @@ -672,7 +672,7 @@ def info(self, path: str, refresh: bool = False, revision: Optional[str] = None, The git revision to get info from. Returns: - `Dict[str, Any]`: Dictionary containing file information (type, size, commit info, etc.). + `dict[str, Any]`: Dictionary containing file information (type, size, commit info, etc.). """ resolved_path = self.resolve_path(path, revision=revision) diff --git a/src/huggingface_hub/hub_mixin.py b/src/huggingface_hub/hub_mixin.py index d1ddee213f..c297026d35 100644 --- a/src/huggingface_hub/hub_mixin.py +++ b/src/huggingface_hub/hub_mixin.py @@ -3,7 +3,7 @@ import os from dataclasses import Field, asdict, dataclass, is_dataclass from pathlib import Path -from typing import Any, Callable, ClassVar, Dict, List, Optional, Protocol, Tuple, Type, TypeVar, Union +from typing import Any, Callable, ClassVar, Optional, Protocol, Type, TypeVar, Union import packaging.version @@ -38,7 +38,7 @@ # Type alias for dataclass instances, copied from https://github.com/python/typeshed/blob/9f28171658b9ca6c32a7cb93fbb99fc92b17858b/stdlib/_typeshed/__init__.pyi#L349 class DataclassInstance(Protocol): - __dataclass_fields__: ClassVar[Dict[str, Field]] + __dataclass_fields__: ClassVar[dict[str, Field]] # Generic variable that is either ModelHubMixin or a subclass thereof @@ -47,7 +47,7 @@ class DataclassInstance(Protocol): ARGS_T = TypeVar("ARGS_T") ENCODER_T = Callable[[ARGS_T], Any] DECODER_T = Callable[[Any], ARGS_T] -CODER_T = Tuple[ENCODER_T, DECODER_T] +CODER_T = tuple[ENCODER_T, DECODER_T] DEFAULT_MODEL_CARD = """ @@ -96,7 +96,7 @@ class ModelHubMixin: URL of the library documentation. Used to generate model card. model_card_template (`str`, *optional*): Template of the model card. Used to generate model card. Defaults to a generic template. - language (`str` or `List[str]`, *optional*): + language (`str` or `list[str]`, *optional*): Language supported by the library. Used to generate model card. library_name (`str`, *optional*): Name of the library integrating ModelHubMixin. Used to generate model card. @@ -113,9 +113,9 @@ class ModelHubMixin: E.g: "https://coqui.ai/cpml". pipeline_tag (`str`, *optional*): Tag of the pipeline. Used to generate model card. E.g. "text-classification". - tags (`List[str]`, *optional*): + tags (`list[str]`, *optional*): Tags to be added to the model card. Used to generate model card. E.g. ["computer-vision"] - coders (`Dict[Type, Tuple[Callable, Callable]]`, *optional*): + coders (`dict[Type, tuple[Callable, Callable]]`, *optional*): Dictionary of custom types and their encoders/decoders. Used to encode/decode arguments that are not jsonable by default. E.g dataclasses, argparse.Namespace, OmegaConf, etc. @@ -145,7 +145,7 @@ class ModelHubMixin: ... ... @classmethod ... def from_pretrained( - ... cls: Type[T], + ... cls: type[T], ... pretrained_model_name_or_path: Union[str, Path], ... *, ... force_download: bool = False, @@ -187,10 +187,10 @@ class ModelHubMixin: _hub_mixin_info: MixinInfo # ^ information about the library integrating ModelHubMixin (used to generate model card) _hub_mixin_inject_config: bool # whether `_from_pretrained` expects `config` or not - _hub_mixin_init_parameters: Dict[str, inspect.Parameter] # __init__ parameters - _hub_mixin_jsonable_default_values: Dict[str, Any] # default values for __init__ parameters - _hub_mixin_jsonable_custom_types: Tuple[Type, ...] # custom types that can be encoded/decoded - _hub_mixin_coders: Dict[Type, CODER_T] # encoders/decoders for custom types + _hub_mixin_init_parameters: dict[str, inspect.Parameter] # __init__ parameters + _hub_mixin_jsonable_default_values: dict[str, Any] # default values for __init__ parameters + _hub_mixin_jsonable_custom_types: tuple[Type, ...] # custom types that can be encoded/decoded + _hub_mixin_coders: dict[Type, CODER_T] # encoders/decoders for custom types # ^ internal values to handle config def __init_subclass__( @@ -203,16 +203,16 @@ def __init_subclass__( # Model card template model_card_template: str = DEFAULT_MODEL_CARD, # Model card metadata - language: Optional[List[str]] = None, + language: Optional[list[str]] = None, library_name: Optional[str] = None, license: Optional[str] = None, license_name: Optional[str] = None, license_link: Optional[str] = None, pipeline_tag: Optional[str] = None, - tags: Optional[List[str]] = None, + tags: Optional[list[str]] = None, # How to encode/decode arguments with custom type into a JSON config? coders: Optional[ - Dict[Type, CODER_T] + dict[Type, CODER_T] # Key is a type. # Value is a tuple (encoder, decoder). # Example: {MyCustomType: (lambda x: x.value, lambda data: MyCustomType(data))} @@ -287,7 +287,7 @@ def __init_subclass__( } cls._hub_mixin_inject_config = "config" in inspect.signature(cls._from_pretrained).parameters - def __new__(cls: Type[T], *args, **kwargs) -> T: + def __new__(cls: type[T], *args, **kwargs) -> T: """Create a new instance of the class and handle config. 3 cases: @@ -363,7 +363,7 @@ def _encode_arg(cls, arg: Any) -> Any: return arg @classmethod - def _decode_arg(cls, expected_type: Type[ARGS_T], value: Any) -> Optional[ARGS_T]: + def _decode_arg(cls, expected_type: type[ARGS_T], value: Any) -> Optional[ARGS_T]: """Decode a JSON serializable value into an argument.""" if is_simple_optional_type(expected_type): if value is None: @@ -386,7 +386,7 @@ def save_pretrained( config: Optional[Union[dict, DataclassInstance]] = None, repo_id: Optional[str] = None, push_to_hub: bool = False, - model_card_kwargs: Optional[Dict[str, Any]] = None, + model_card_kwargs: Optional[dict[str, Any]] = None, **push_to_hub_kwargs, ) -> Optional[str]: """ @@ -402,7 +402,7 @@ def save_pretrained( repo_id (`str`, *optional*): ID of your repository on the Hub. Used only if `push_to_hub=True`. Will default to the folder name if not provided. - model_card_kwargs (`Dict[str, Any]`, *optional*): + model_card_kwargs (`dict[str, Any]`, *optional*): Additional arguments passed to the model card template to customize the model card. push_to_hub_kwargs: Additional key word arguments passed along to the [`~ModelHubMixin.push_to_hub`] method. @@ -461,7 +461,7 @@ def _save_pretrained(self, save_directory: Path) -> None: @classmethod @validate_hf_hub_args def from_pretrained( - cls: Type[T], + cls: type[T], pretrained_model_name_or_path: Union[str, Path], *, force_download: bool = False, @@ -493,7 +493,7 @@ def from_pretrained( Path to the folder where cached files are stored. local_files_only (`bool`, *optional*, defaults to `False`): If `True`, avoid downloading the file and return the path to the local cached file if it exists. - model_kwargs (`Dict`, *optional*): + model_kwargs (`dict`, *optional*): Additional kwargs to pass to the model during initialization. """ model_id = str(pretrained_model_name_or_path) @@ -551,7 +551,7 @@ def from_pretrained( if key not in model_kwargs and key in config: model_kwargs[key] = config[key] elif any(param.kind == inspect.Parameter.VAR_KEYWORD for param in cls._hub_mixin_init_parameters.values()): - for key, value in config.items(): + for key, value in config.items(): # type: ignore[union-attr] if key not in model_kwargs: model_kwargs[key] = value @@ -579,7 +579,7 @@ def from_pretrained( @classmethod def _from_pretrained( - cls: Type[T], + cls: type[T], *, model_id: str, revision: Optional[str], @@ -631,10 +631,10 @@ def push_to_hub( token: Optional[str] = None, branch: Optional[str] = None, create_pr: Optional[bool] = None, - allow_patterns: Optional[Union[List[str], str]] = None, - ignore_patterns: Optional[Union[List[str], str]] = None, - delete_patterns: Optional[Union[List[str], str]] = None, - model_card_kwargs: Optional[Dict[str, Any]] = None, + allow_patterns: Optional[Union[list[str], str]] = None, + ignore_patterns: Optional[Union[list[str], str]] = None, + delete_patterns: Optional[Union[list[str], str]] = None, + model_card_kwargs: Optional[dict[str, Any]] = None, ) -> str: """ Upload model checkpoint to the Hub. @@ -660,13 +660,13 @@ def push_to_hub( The git branch on which to push the model. This defaults to `"main"`. create_pr (`boolean`, *optional*): Whether or not to create a Pull Request from `branch` with that commit. Defaults to `False`. - allow_patterns (`List[str]` or `str`, *optional*): + allow_patterns (`list[str]` or `str`, *optional*): If provided, only files matching at least one pattern are pushed. - ignore_patterns (`List[str]` or `str`, *optional*): + ignore_patterns (`list[str]` or `str`, *optional*): If provided, files matching any of the patterns are not pushed. - delete_patterns (`List[str]` or `str`, *optional*): + delete_patterns (`list[str]` or `str`, *optional*): If provided, remote files matching any of the patterns will be deleted from the repo. - model_card_kwargs (`Dict[str, Any]`, *optional*): + model_card_kwargs (`dict[str, Any]`, *optional*): Additional arguments passed to the model card template to customize the model card. Returns: @@ -749,7 +749,7 @@ class PyTorchModelHubMixin(ModelHubMixin): ``` """ - def __init_subclass__(cls, *args, tags: Optional[List[str]] = None, **kwargs) -> None: + def __init_subclass__(cls, *args, tags: Optional[list[str]] = None, **kwargs) -> None: tags = tags or [] tags.append("pytorch_model_hub_mixin") kwargs["tags"] = tags @@ -831,7 +831,7 @@ def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, stric return model -def _load_dataclass(datacls: Type[DataclassInstance], data: dict) -> DataclassInstance: +def _load_dataclass(datacls: type[DataclassInstance], data: dict) -> DataclassInstance: """Load a dataclass instance from a dictionary. Fields not expected by the dataclass are ignored. diff --git a/src/huggingface_hub/inference/_client.py b/src/huggingface_hub/inference/_client.py index 22a979509c..ef78d93eee 100644 --- a/src/huggingface_hub/inference/_client.py +++ b/src/huggingface_hub/inference/_client.py @@ -38,7 +38,7 @@ import re import warnings from contextlib import ExitStack -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Literal, Optional, Union, overload +from typing import TYPE_CHECKING, Any, Iterable, Literal, Optional, Union, overload from huggingface_hub import constants from huggingface_hub.errors import BadRequestError, HfHubHTTPError, InferenceTimeoutError @@ -144,13 +144,13 @@ class InferenceClient: arguments are mutually exclusive and have the exact same behavior. timeout (`float`, `optional`): The maximum number of seconds to wait for a response from the server. Defaults to None, meaning it will loop until the server is available. - headers (`Dict[str, str]`, `optional`): + headers (`dict[str, str]`, `optional`): Additional headers to send to the server. By default only the authorization and user-agent headers are sent. Values in this dictionary will override the default values. bill_to (`str`, `optional`): The billing account to use for the requests. By default the requests are billed on the user's account. Requests can only be billed to an organization the user is a member of, and which has subscribed to Enterprise Hub. - cookies (`Dict[str, str]`, `optional`): + cookies (`dict[str, str]`, `optional`): Additional cookies to send to the server. base_url (`str`, `optional`): Base URL to run inference. This is a duplicated argument from `model` to make [`InferenceClient`] @@ -168,8 +168,8 @@ def __init__( provider: Optional[PROVIDER_OR_POLICY_T] = None, token: Optional[str] = None, timeout: Optional[float] = None, - headers: Optional[Dict[str, str]] = None, - cookies: Optional[Dict[str, str]] = None, + headers: Optional[dict[str, str]] = None, + cookies: Optional[dict[str, str]] = None, bill_to: Optional[str] = None, # OpenAI compatibility base_url: Optional[str] = None, @@ -304,7 +304,7 @@ def audio_classification( model: Optional[str] = None, top_k: Optional[int] = None, function_to_apply: Optional["AudioClassificationOutputTransform"] = None, - ) -> List[AudioClassificationOutputElement]: + ) -> list[AudioClassificationOutputElement]: """ Perform audio classification on the provided audio content. @@ -322,7 +322,7 @@ def audio_classification( The function to apply to the model outputs in order to retrieve the scores. Returns: - `List[AudioClassificationOutputElement]`: List of [`AudioClassificationOutputElement`] items containing the predicted labels and their confidence. + `list[AudioClassificationOutputElement]`: List of [`AudioClassificationOutputElement`] items containing the predicted labels and their confidence. Raises: [`InferenceTimeoutError`]: @@ -359,7 +359,7 @@ def audio_to_audio( audio: ContentT, *, model: Optional[str] = None, - ) -> List[AudioToAudioOutputElement]: + ) -> list[AudioToAudioOutputElement]: """ Performs multiple tasks related to audio-to-audio depending on the model (eg: speech enhancement, source separation). @@ -373,7 +373,7 @@ def audio_to_audio( audio_to_audio will be used. Returns: - `List[AudioToAudioOutputElement]`: A list of [`AudioToAudioOutputElement`] items containing audios label, content-type, and audio content in blob. + `list[AudioToAudioOutputElement]`: A list of [`AudioToAudioOutputElement`] items containing audios label, content-type, and audio content in blob. Raises: `InferenceTimeoutError`: @@ -411,7 +411,7 @@ def automatic_speech_recognition( audio: ContentT, *, model: Optional[str] = None, - extra_body: Optional[Dict] = None, + extra_body: Optional[dict] = None, ) -> AutomaticSpeechRecognitionOutput: """ Perform automatic speech recognition (ASR or audio-to-text) on the given audio content. @@ -422,7 +422,7 @@ def automatic_speech_recognition( model (`str`, *optional*): The model to use for ASR. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. If not provided, the default recommended model for ASR will be used. - extra_body (`Dict`, *optional*): + extra_body (`dict`, *optional*): Additional provider-specific parameters to pass to the model. Refer to the provider's documentation for supported parameters. Returns: @@ -457,105 +457,105 @@ def automatic_speech_recognition( @overload def chat_completion( # type: ignore self, - messages: List[Union[Dict, ChatCompletionInputMessage]], + messages: list[Union[dict, ChatCompletionInputMessage]], *, model: Optional[str] = None, stream: Literal[False] = False, frequency_penalty: Optional[float] = None, - logit_bias: Optional[List[float]] = None, + logit_bias: Optional[list[float]] = None, logprobs: Optional[bool] = None, max_tokens: Optional[int] = None, n: Optional[int] = None, presence_penalty: Optional[float] = None, response_format: Optional[ChatCompletionInputGrammarType] = None, seed: Optional[int] = None, - stop: Optional[List[str]] = None, + stop: Optional[list[str]] = None, stream_options: Optional[ChatCompletionInputStreamOptions] = None, temperature: Optional[float] = None, tool_choice: Optional[Union[ChatCompletionInputToolChoiceClass, "ChatCompletionInputToolChoiceEnum"]] = None, tool_prompt: Optional[str] = None, - tools: Optional[List[ChatCompletionInputTool]] = None, + tools: Optional[list[ChatCompletionInputTool]] = None, top_logprobs: Optional[int] = None, top_p: Optional[float] = None, - extra_body: Optional[Dict] = None, + extra_body: Optional[dict] = None, ) -> ChatCompletionOutput: ... @overload def chat_completion( # type: ignore self, - messages: List[Union[Dict, ChatCompletionInputMessage]], + messages: list[Union[dict, ChatCompletionInputMessage]], *, model: Optional[str] = None, stream: Literal[True] = True, frequency_penalty: Optional[float] = None, - logit_bias: Optional[List[float]] = None, + logit_bias: Optional[list[float]] = None, logprobs: Optional[bool] = None, max_tokens: Optional[int] = None, n: Optional[int] = None, presence_penalty: Optional[float] = None, response_format: Optional[ChatCompletionInputGrammarType] = None, seed: Optional[int] = None, - stop: Optional[List[str]] = None, + stop: Optional[list[str]] = None, stream_options: Optional[ChatCompletionInputStreamOptions] = None, temperature: Optional[float] = None, tool_choice: Optional[Union[ChatCompletionInputToolChoiceClass, "ChatCompletionInputToolChoiceEnum"]] = None, tool_prompt: Optional[str] = None, - tools: Optional[List[ChatCompletionInputTool]] = None, + tools: Optional[list[ChatCompletionInputTool]] = None, top_logprobs: Optional[int] = None, top_p: Optional[float] = None, - extra_body: Optional[Dict] = None, + extra_body: Optional[dict] = None, ) -> Iterable[ChatCompletionStreamOutput]: ... @overload def chat_completion( self, - messages: List[Union[Dict, ChatCompletionInputMessage]], + messages: list[Union[dict, ChatCompletionInputMessage]], *, model: Optional[str] = None, stream: bool = False, frequency_penalty: Optional[float] = None, - logit_bias: Optional[List[float]] = None, + logit_bias: Optional[list[float]] = None, logprobs: Optional[bool] = None, max_tokens: Optional[int] = None, n: Optional[int] = None, presence_penalty: Optional[float] = None, response_format: Optional[ChatCompletionInputGrammarType] = None, seed: Optional[int] = None, - stop: Optional[List[str]] = None, + stop: Optional[list[str]] = None, stream_options: Optional[ChatCompletionInputStreamOptions] = None, temperature: Optional[float] = None, tool_choice: Optional[Union[ChatCompletionInputToolChoiceClass, "ChatCompletionInputToolChoiceEnum"]] = None, tool_prompt: Optional[str] = None, - tools: Optional[List[ChatCompletionInputTool]] = None, + tools: Optional[list[ChatCompletionInputTool]] = None, top_logprobs: Optional[int] = None, top_p: Optional[float] = None, - extra_body: Optional[Dict] = None, + extra_body: Optional[dict] = None, ) -> Union[ChatCompletionOutput, Iterable[ChatCompletionStreamOutput]]: ... def chat_completion( self, - messages: List[Union[Dict, ChatCompletionInputMessage]], + messages: list[Union[dict, ChatCompletionInputMessage]], *, model: Optional[str] = None, stream: bool = False, # Parameters from ChatCompletionInput (handled manually) frequency_penalty: Optional[float] = None, - logit_bias: Optional[List[float]] = None, + logit_bias: Optional[list[float]] = None, logprobs: Optional[bool] = None, max_tokens: Optional[int] = None, n: Optional[int] = None, presence_penalty: Optional[float] = None, response_format: Optional[ChatCompletionInputGrammarType] = None, seed: Optional[int] = None, - stop: Optional[List[str]] = None, + stop: Optional[list[str]] = None, stream_options: Optional[ChatCompletionInputStreamOptions] = None, temperature: Optional[float] = None, tool_choice: Optional[Union[ChatCompletionInputToolChoiceClass, "ChatCompletionInputToolChoiceEnum"]] = None, tool_prompt: Optional[str] = None, - tools: Optional[List[ChatCompletionInputTool]] = None, + tools: Optional[list[ChatCompletionInputTool]] = None, top_logprobs: Optional[int] = None, top_p: Optional[float] = None, - extra_body: Optional[Dict] = None, + extra_body: Optional[dict] = None, ) -> Union[ChatCompletionOutput, Iterable[ChatCompletionStreamOutput]]: """ A method for completing conversations using a specified language model. @@ -585,7 +585,7 @@ def chat_completion( frequency_penalty (`float`, *optional*): Penalizes new tokens based on their existing frequency in the text so far. Range: [-2.0, 2.0]. Defaults to 0.0. - logit_bias (`List[float]`, *optional*): + logit_bias (`list[float]`, *optional*): Adjusts the likelihood of specific tokens appearing in the generated output. logprobs (`bool`, *optional*): Whether to return log probabilities of the output tokens or not. If true, returns the log @@ -601,7 +601,7 @@ def chat_completion( Grammar constraints. Can be either a JSONSchema or a regex. seed (Optional[`int`], *optional*): Seed for reproducible control flow. Defaults to None. - stop (`List[str]`, *optional*): + stop (`list[str]`, *optional*): Up to four strings which trigger the end of the response. Defaults to None. stream (`bool`, *optional*): @@ -625,7 +625,7 @@ def chat_completion( tools (List of [`ChatCompletionInputTool`], *optional*): A list of tools the model may call. Currently, only functions are supported as a tool. Use this to provide a list of functions the model may generate JSON inputs for. - extra_body (`Dict`, *optional*): + extra_body (`dict`, *optional*): Additional provider-specific parameters to pass to the model. Refer to the provider's documentation for supported parameters. Returns: @@ -951,8 +951,8 @@ def document_question_answering( max_question_len: Optional[int] = None, max_seq_len: Optional[int] = None, top_k: Optional[int] = None, - word_boxes: Optional[List[Union[List[float], str]]] = None, - ) -> List[DocumentQuestionAnsweringOutputElement]: + word_boxes: Optional[list[Union[list[float], str]]] = None, + ) -> list[DocumentQuestionAnsweringOutputElement]: """ Answer questions on document images. @@ -982,11 +982,11 @@ def document_question_answering( top_k (`int`, *optional*): The number of answers to return (will be chosen by order of likelihood). Can return less than top_k answers if there are not enough options available within the context. - word_boxes (`List[Union[List[float], str`, *optional*): + word_boxes (`list[Union[list[float], str`, *optional*): A list of words and bounding boxes (normalized 0->1000). If provided, the inference will skip the OCR step and use the provided bounding boxes instead. Returns: - `List[DocumentQuestionAnsweringOutputElement]`: a list of [`DocumentQuestionAnsweringOutputElement`] items containing the predicted label, associated probability, word ids, and page number. + `list[DocumentQuestionAnsweringOutputElement]`: a list of [`DocumentQuestionAnsweringOutputElement`] items containing the predicted label, associated probability, word ids, and page number. Raises: [`InferenceTimeoutError`]: @@ -1005,7 +1005,7 @@ def document_question_answering( """ model_id = model or self.model provider_helper = get_provider_helper(self.provider, task="document-question-answering", model=model_id) - inputs: Dict[str, Any] = {"question": question, "image": _b64_encode(image)} + inputs: dict[str, Any] = {"question": question, "image": _b64_encode(image)} request_parameters = provider_helper.prepare_request( inputs=inputs, parameters={ @@ -1103,9 +1103,9 @@ def fill_mask( text: str, *, model: Optional[str] = None, - targets: Optional[List[str]] = None, + targets: Optional[list[str]] = None, top_k: Optional[int] = None, - ) -> List[FillMaskOutputElement]: + ) -> list[FillMaskOutputElement]: """ Fill in a hole with a missing word (token to be precise). @@ -1115,14 +1115,14 @@ def fill_mask( model (`str`, *optional*): The model to use for the fill mask task. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. If not provided, the default recommended fill mask model will be used. - targets (`List[str`, *optional*): + targets (`list[str`, *optional*): When passed, the model will limit the scores to the passed targets instead of looking up in the whole vocabulary. If the provided targets are not in the model vocab, they will be tokenized and the first resulting token will be used (with a warning, and that might be slower). top_k (`int`, *optional*): When passed, overrides the number of predictions to return. Returns: - `List[FillMaskOutputElement]`: a list of [`FillMaskOutputElement`] items containing the predicted label, associated + `list[FillMaskOutputElement]`: a list of [`FillMaskOutputElement`] items containing the predicted label, associated probability, token reference, and completed text. Raises: @@ -1161,7 +1161,7 @@ def image_classification( model: Optional[str] = None, function_to_apply: Optional["ImageClassificationOutputTransform"] = None, top_k: Optional[int] = None, - ) -> List[ImageClassificationOutputElement]: + ) -> list[ImageClassificationOutputElement]: """ Perform image classification on the given image using the specified model. @@ -1176,7 +1176,7 @@ def image_classification( top_k (`int`, *optional*): When specified, limits the output to the top K most probable classes. Returns: - `List[ImageClassificationOutputElement]`: a list of [`ImageClassificationOutputElement`] items containing the predicted label and associated probability. + `list[ImageClassificationOutputElement]`: a list of [`ImageClassificationOutputElement`] items containing the predicted label and associated probability. Raises: [`InferenceTimeoutError`]: @@ -1213,7 +1213,7 @@ def image_segmentation( overlap_mask_area_threshold: Optional[float] = None, subtask: Optional["ImageSegmentationSubtask"] = None, threshold: Optional[float] = None, - ) -> List[ImageSegmentationOutputElement]: + ) -> list[ImageSegmentationOutputElement]: """ Perform image segmentation on the given image using the specified model. @@ -1238,7 +1238,7 @@ def image_segmentation( threshold (`float`, *optional*): Probability threshold to filter out predicted masks. Returns: - `List[ImageSegmentationOutputElement]`: A list of [`ImageSegmentationOutputElement`] items containing the segmented masks and associated attributes. + `list[ImageSegmentationOutputElement]`: A list of [`ImageSegmentationOutputElement`] items containing the segmented masks and associated attributes. Raises: [`InferenceTimeoutError`]: @@ -1473,12 +1473,12 @@ def image_to_text(self, image: ContentT, *, model: Optional[str] = None) -> Imag api_key=self.token, ) response = self._inner_post(request_parameters) - output_list: List[ImageToTextOutput] = ImageToTextOutput.parse_obj_as_list(response) + output_list: list[ImageToTextOutput] = ImageToTextOutput.parse_obj_as_list(response) return output_list[0] def object_detection( self, image: ContentT, *, model: Optional[str] = None, threshold: Optional[float] = None - ) -> List[ObjectDetectionOutputElement]: + ) -> list[ObjectDetectionOutputElement]: """ Perform object detection on the given image using the specified model. @@ -1497,7 +1497,7 @@ def object_detection( threshold (`float`, *optional*): The probability necessary to make a prediction. Returns: - `List[ObjectDetectionOutputElement]`: A list of [`ObjectDetectionOutputElement`] items containing the bounding boxes and associated attributes. + `list[ObjectDetectionOutputElement]`: A list of [`ObjectDetectionOutputElement`] items containing the bounding boxes and associated attributes. Raises: [`InferenceTimeoutError`]: @@ -1540,7 +1540,7 @@ def question_answering( max_question_len: Optional[int] = None, max_seq_len: Optional[int] = None, top_k: Optional[int] = None, - ) -> Union[QuestionAnsweringOutputElement, List[QuestionAnsweringOutputElement]]: + ) -> Union[QuestionAnsweringOutputElement, list[QuestionAnsweringOutputElement]]: """ Retrieve the answer to a question from a given text. @@ -1572,7 +1572,7 @@ def question_answering( topk answers if there are not enough options available within the context. Returns: - Union[`QuestionAnsweringOutputElement`, List[`QuestionAnsweringOutputElement`]]: + Union[`QuestionAnsweringOutputElement`, list[`QuestionAnsweringOutputElement`]]: When top_k is 1 or not provided, it returns a single `QuestionAnsweringOutputElement`. When top_k is greater than 1, it returns a list of `QuestionAnsweringOutputElement`. Raises: @@ -1612,15 +1612,15 @@ def question_answering( return output def sentence_similarity( - self, sentence: str, other_sentences: List[str], *, model: Optional[str] = None - ) -> List[float]: + self, sentence: str, other_sentences: list[str], *, model: Optional[str] = None + ) -> list[float]: """ Compute the semantic similarity between a sentence and a list of other sentences by comparing their embeddings. Args: sentence (`str`): The main sentence to compare to others. - other_sentences (`List[str]`): + other_sentences (`list[str]`): The list of sentences to compare to. model (`str`, *optional*): The model to use for the sentence similarity task. Can be a model ID hosted on the Hugging Face Hub or a URL to @@ -1628,7 +1628,7 @@ def sentence_similarity( Defaults to None. Returns: - `List[float]`: The embedding representing the input text. + `list[float]`: The embedding representing the input text. Raises: [`InferenceTimeoutError`]: @@ -1670,7 +1670,7 @@ def summarization( *, model: Optional[str] = None, clean_up_tokenization_spaces: Optional[bool] = None, - generate_parameters: Optional[Dict[str, Any]] = None, + generate_parameters: Optional[dict[str, Any]] = None, truncation: Optional["SummarizationTruncationStrategy"] = None, ) -> SummarizationOutput: """ @@ -1684,7 +1684,7 @@ def summarization( Inference Endpoint. If not provided, the default recommended model for summarization will be used. clean_up_tokenization_spaces (`bool`, *optional*): Whether to clean up the potential extra spaces in the text output. - generate_parameters (`Dict[str, Any]`, *optional*): + generate_parameters (`dict[str, Any]`, *optional*): Additional parametrization of the text generation algorithm. truncation (`"SummarizationTruncationStrategy"`, *optional*): The truncation strategy to use. @@ -1724,7 +1724,7 @@ def summarization( def table_question_answering( self, - table: Dict[str, Any], + table: dict[str, Any], query: str, *, model: Optional[str] = None, @@ -1784,12 +1784,12 @@ def table_question_answering( response = self._inner_post(request_parameters) return TableQuestionAnsweringOutputElement.parse_obj_as_instance(response) - def tabular_classification(self, table: Dict[str, Any], *, model: Optional[str] = None) -> List[str]: + def tabular_classification(self, table: dict[str, Any], *, model: Optional[str] = None) -> list[str]: """ Classifying a target category (a group) based on a set of attributes. Args: - table (`Dict[str, Any]`): + table (`dict[str, Any]`): Set of attributes to classify. model (`str`, *optional*): The model to use for the tabular classification task. Can be a model ID hosted on the Hugging Face Hub or a URL to @@ -1839,12 +1839,12 @@ def tabular_classification(self, table: Dict[str, Any], *, model: Optional[str] response = self._inner_post(request_parameters) return _bytes_to_list(response) - def tabular_regression(self, table: Dict[str, Any], *, model: Optional[str] = None) -> List[float]: + def tabular_regression(self, table: dict[str, Any], *, model: Optional[str] = None) -> list[float]: """ Predicting a numerical target value given a set of attributes/features in a table. Args: - table (`Dict[str, Any]`): + table (`dict[str, Any]`): Set of attributes stored in a table. The attributes used to predict the target can be both numerical and categorical. model (`str`, *optional*): The model to use for the tabular regression task. Can be a model ID hosted on the Hugging Face Hub or a URL to @@ -1896,7 +1896,7 @@ def text_classification( model: Optional[str] = None, top_k: Optional[int] = None, function_to_apply: Optional["TextClassificationOutputTransform"] = None, - ) -> List[TextClassificationOutputElement]: + ) -> list[TextClassificationOutputElement]: """ Perform text classification (e.g. sentiment-analysis) on the given text. @@ -1913,7 +1913,7 @@ def text_classification( The function to apply to the model outputs in order to retrieve the scores. Returns: - `List[TextClassificationOutputElement]`: a list of [`TextClassificationOutputElement`] items containing the predicted label and associated probability. + `list[TextClassificationOutputElement]`: a list of [`TextClassificationOutputElement`] items containing the predicted label and associated probability. Raises: [`InferenceTimeoutError`]: @@ -1966,8 +1966,8 @@ def text_generation( repetition_penalty: Optional[float] = None, return_full_text: Optional[bool] = None, seed: Optional[int] = None, - stop: Optional[List[str]] = None, - stop_sequences: Optional[List[str]] = None, # Deprecated, use `stop` instead + stop: Optional[list[str]] = None, + stop_sequences: Optional[list[str]] = None, # Deprecated, use `stop` instead temperature: Optional[float] = None, top_k: Optional[int] = None, top_n_tokens: Optional[int] = None, @@ -1996,8 +1996,8 @@ def text_generation( repetition_penalty: Optional[float] = None, return_full_text: Optional[bool] = None, seed: Optional[int] = None, - stop: Optional[List[str]] = None, - stop_sequences: Optional[List[str]] = None, # Deprecated, use `stop` instead + stop: Optional[list[str]] = None, + stop_sequences: Optional[list[str]] = None, # Deprecated, use `stop` instead temperature: Optional[float] = None, top_k: Optional[int] = None, top_n_tokens: Optional[int] = None, @@ -2026,8 +2026,8 @@ def text_generation( repetition_penalty: Optional[float] = None, return_full_text: Optional[bool] = None, # Manual default value seed: Optional[int] = None, - stop: Optional[List[str]] = None, - stop_sequences: Optional[List[str]] = None, # Deprecated, use `stop` instead + stop: Optional[list[str]] = None, + stop_sequences: Optional[list[str]] = None, # Deprecated, use `stop` instead temperature: Optional[float] = None, top_k: Optional[int] = None, top_n_tokens: Optional[int] = None, @@ -2056,8 +2056,8 @@ def text_generation( repetition_penalty: Optional[float] = None, return_full_text: Optional[bool] = None, seed: Optional[int] = None, - stop: Optional[List[str]] = None, - stop_sequences: Optional[List[str]] = None, # Deprecated, use `stop` instead + stop: Optional[list[str]] = None, + stop_sequences: Optional[list[str]] = None, # Deprecated, use `stop` instead temperature: Optional[float] = None, top_k: Optional[int] = None, top_n_tokens: Optional[int] = None, @@ -2086,8 +2086,8 @@ def text_generation( repetition_penalty: Optional[float] = None, return_full_text: Optional[bool] = None, seed: Optional[int] = None, - stop: Optional[List[str]] = None, - stop_sequences: Optional[List[str]] = None, # Deprecated, use `stop` instead + stop: Optional[list[str]] = None, + stop_sequences: Optional[list[str]] = None, # Deprecated, use `stop` instead temperature: Optional[float] = None, top_k: Optional[int] = None, top_n_tokens: Optional[int] = None, @@ -2115,8 +2115,8 @@ def text_generation( repetition_penalty: Optional[float] = None, return_full_text: Optional[bool] = None, seed: Optional[int] = None, - stop: Optional[List[str]] = None, - stop_sequences: Optional[List[str]] = None, # Deprecated, use `stop` instead + stop: Optional[list[str]] = None, + stop_sequences: Optional[list[str]] = None, # Deprecated, use `stop` instead temperature: Optional[float] = None, top_k: Optional[int] = None, top_n_tokens: Optional[int] = None, @@ -2172,9 +2172,9 @@ def text_generation( Whether to prepend the prompt to the generated text seed (`int`, *optional*): Random sampling seed - stop (`List[str]`, *optional*): + stop (`list[str]`, *optional*): Stop generating tokens if a member of `stop` is generated. - stop_sequences (`List[str]`, *optional*): + stop_sequences (`list[str]`, *optional*): Deprecated argument. Use `stop` instead. temperature (`float`, *optional*): The value used to module the logits distribution. @@ -2451,7 +2451,7 @@ def text_to_image( model: Optional[str] = None, scheduler: Optional[str] = None, seed: Optional[int] = None, - extra_body: Optional[Dict[str, Any]] = None, + extra_body: Optional[dict[str, Any]] = None, ) -> "Image": """ Generate an image based on a given text using a specified model. @@ -2489,7 +2489,7 @@ def text_to_image( Override the scheduler with a compatible one. seed (`int`, *optional*): Seed for the random number generator. - extra_body (`Dict[str, Any]`, *optional*): + extra_body (`dict[str, Any]`, *optional*): Additional provider-specific parameters to pass to the model. Refer to the provider's documentation for supported parameters. @@ -2588,11 +2588,11 @@ def text_to_video( *, model: Optional[str] = None, guidance_scale: Optional[float] = None, - negative_prompt: Optional[List[str]] = None, + negative_prompt: Optional[list[str]] = None, num_frames: Optional[float] = None, num_inference_steps: Optional[int] = None, seed: Optional[int] = None, - extra_body: Optional[Dict[str, Any]] = None, + extra_body: Optional[dict[str, Any]] = None, ) -> bytes: """ Generate a video based on a given text. @@ -2611,7 +2611,7 @@ def text_to_video( guidance_scale (`float`, *optional*): A higher guidance scale value encourages the model to generate videos closely linked to the text prompt, but values too high may cause saturation and other artifacts. - negative_prompt (`List[str]`, *optional*): + negative_prompt (`list[str]`, *optional*): One or several prompt to guide what NOT to include in video generation. num_frames (`float`, *optional*): The num_frames parameter determines how many video frames are generated. @@ -2620,7 +2620,7 @@ def text_to_video( expense of slower inference. seed (`int`, *optional*): Seed for the random number generator. - extra_body (`Dict[str, Any]`, *optional*): + extra_body (`dict[str, Any]`, *optional*): Additional provider-specific parameters to pass to the model. Refer to the provider's documentation for supported parameters. @@ -2700,7 +2700,7 @@ def text_to_speech( top_p: Optional[float] = None, typical_p: Optional[float] = None, use_cache: Optional[bool] = None, - extra_body: Optional[Dict[str, Any]] = None, + extra_body: Optional[dict[str, Any]] = None, ) -> bytes: """ Synthesize an audio of a voice pronouncing a given text. @@ -2762,7 +2762,7 @@ def text_to_speech( paper](https://hf.co/papers/2202.00666) for more details. use_cache (`bool`, *optional*): Whether the model should use the past last key/values attentions to speed up decoding - extra_body (`Dict[str, Any]`, *optional*): + extra_body (`dict[str, Any]`, *optional*): Additional provider-specific parameters to pass to the model. Refer to the provider's documentation for supported parameters. Returns: @@ -2894,9 +2894,9 @@ def token_classification( *, model: Optional[str] = None, aggregation_strategy: Optional["TokenClassificationAggregationStrategy"] = None, - ignore_labels: Optional[List[str]] = None, + ignore_labels: Optional[list[str]] = None, stride: Optional[int] = None, - ) -> List[TokenClassificationOutputElement]: + ) -> list[TokenClassificationOutputElement]: """ Perform token classification on the given text. Usually used for sentence parsing, either grammatical, or Named Entity Recognition (NER) to understand keywords contained within text. @@ -2910,13 +2910,13 @@ def token_classification( Defaults to None. aggregation_strategy (`"TokenClassificationAggregationStrategy"`, *optional*): The strategy used to fuse tokens based on model predictions - ignore_labels (`List[str`, *optional*): + ignore_labels (`list[str`, *optional*): A list of labels to ignore stride (`int`, *optional*): The number of overlapping tokens between chunks when splitting the input text. Returns: - `List[TokenClassificationOutputElement]`: List of [`TokenClassificationOutputElement`] items containing the entity group, confidence score, word, start and end index. + `list[TokenClassificationOutputElement]`: List of [`TokenClassificationOutputElement`] items containing the entity group, confidence score, word, start and end index. Raises: [`InferenceTimeoutError`]: @@ -2972,7 +2972,7 @@ def translation( tgt_lang: Optional[str] = None, clean_up_tokenization_spaces: Optional[bool] = None, truncation: Optional["TranslationTruncationStrategy"] = None, - generate_parameters: Optional[Dict[str, Any]] = None, + generate_parameters: Optional[dict[str, Any]] = None, ) -> TranslationOutput: """ Convert text from one language to another. @@ -2997,7 +2997,7 @@ def translation( Whether to clean up the potential extra spaces in the text output. truncation (`"TranslationTruncationStrategy"`, *optional*): The truncation strategy to use. - generate_parameters (`Dict[str, Any]`, *optional*): + generate_parameters (`dict[str, Any]`, *optional*): Additional parametrization of the text generation algorithm. Returns: @@ -3059,7 +3059,7 @@ def visual_question_answering( *, model: Optional[str] = None, top_k: Optional[int] = None, - ) -> List[VisualQuestionAnsweringOutputElement]: + ) -> list[VisualQuestionAnsweringOutputElement]: """ Answering open-ended questions based on an image. @@ -3076,7 +3076,7 @@ def visual_question_answering( The number of answers to return (will be chosen by order of likelihood). Note that we return less than topk answers if there are not enough options available within the context. Returns: - `List[VisualQuestionAnsweringOutputElement]`: a list of [`VisualQuestionAnsweringOutputElement`] items containing the predicted label and associated probability. + `list[VisualQuestionAnsweringOutputElement]`: a list of [`VisualQuestionAnsweringOutputElement`] items containing the predicted label and associated probability. Raises: `InferenceTimeoutError`: @@ -3114,21 +3114,21 @@ def visual_question_answering( def zero_shot_classification( self, text: str, - candidate_labels: List[str], + candidate_labels: list[str], *, multi_label: Optional[bool] = False, hypothesis_template: Optional[str] = None, model: Optional[str] = None, - ) -> List[ZeroShotClassificationOutputElement]: + ) -> list[ZeroShotClassificationOutputElement]: """ Provide as input a text and a set of candidate labels to classify the input text. Args: text (`str`): The input text to classify. - candidate_labels (`List[str]`): + candidate_labels (`list[str]`): The set of possible class labels to classify the text into. - labels (`List[str]`, *optional*): + labels (`list[str]`, *optional*): (deprecated) List of strings. Each string is the verbalization of a possible label for the input text. multi_label (`bool`, *optional*): Whether multiple candidate labels can be true. If false, the scores are normalized such that the sum of @@ -3143,7 +3143,7 @@ def zero_shot_classification( Returns: - `List[ZeroShotClassificationOutputElement]`: List of [`ZeroShotClassificationOutputElement`] items containing the predicted labels and their confidence. + `list[ZeroShotClassificationOutputElement]`: List of [`ZeroShotClassificationOutputElement`] items containing the predicted labels and their confidence. Raises: [`InferenceTimeoutError`]: @@ -3220,22 +3220,22 @@ def zero_shot_classification( def zero_shot_image_classification( self, image: ContentT, - candidate_labels: List[str], + candidate_labels: list[str], *, model: Optional[str] = None, hypothesis_template: Optional[str] = None, # deprecated argument - labels: List[str] = None, # type: ignore - ) -> List[ZeroShotImageClassificationOutputElement]: + labels: list[str] = None, # type: ignore + ) -> list[ZeroShotImageClassificationOutputElement]: """ Provide input image and text labels to predict text labels for the image. Args: image (`Union[str, Path, bytes, BinaryIO, PIL.Image.Image]`): The input image to caption. It can be raw bytes, an image file, a URL to an online image, or a PIL Image. - candidate_labels (`List[str]`): + candidate_labels (`list[str]`): The candidate labels for this image - labels (`List[str]`, *optional*): + labels (`list[str]`, *optional*): (deprecated) List of string possible labels. There must be at least 2 labels. model (`str`, *optional*): The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed @@ -3245,7 +3245,7 @@ def zero_shot_image_classification( replacing the placeholder with the candidate labels. Returns: - `List[ZeroShotImageClassificationOutputElement]`: List of [`ZeroShotImageClassificationOutputElement`] items containing the predicted labels and their confidence. + `list[ZeroShotImageClassificationOutputElement]`: List of [`ZeroShotImageClassificationOutputElement`] items containing the predicted labels and their confidence. Raises: [`InferenceTimeoutError`]: @@ -3284,7 +3284,7 @@ def zero_shot_image_classification( response = self._inner_post(request_parameters) return ZeroShotImageClassificationOutputElement.parse_obj_as_list(response) - def get_endpoint_info(self, *, model: Optional[str] = None) -> Dict[str, Any]: + def get_endpoint_info(self, *, model: Optional[str] = None) -> dict[str, Any]: """ Get information about the deployed endpoint. @@ -3297,7 +3297,7 @@ def get_endpoint_info(self, *, model: Optional[str] = None) -> Dict[str, Any]: Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None. Returns: - `Dict[str, Any]`: Information about the endpoint. + `dict[str, Any]`: Information about the endpoint. Example: ```py diff --git a/src/huggingface_hub/inference/_common.py b/src/huggingface_hub/inference/_common.py index aca297df34..b79713a934 100644 --- a/src/huggingface_hub/inference/_common.py +++ b/src/huggingface_hub/inference/_common.py @@ -21,20 +21,7 @@ import mimetypes from dataclasses import dataclass from pathlib import Path -from typing import ( - TYPE_CHECKING, - Any, - AsyncIterable, - BinaryIO, - Dict, - Iterable, - List, - Literal, - NoReturn, - Optional, - Union, - overload, -) +from typing import TYPE_CHECKING, Any, AsyncIterable, BinaryIO, Iterable, Literal, NoReturn, Optional, Union, overload import httpx @@ -71,9 +58,9 @@ class RequestParameters: url: str task: str model: Optional[str] - json: Optional[Union[str, Dict, List]] + json: Optional[Union[str, dict, list]] data: Optional[bytes] - headers: Dict[str, Any] + headers: dict[str, Any] class MimeBytes(bytes): @@ -240,7 +227,7 @@ def _b64_to_image(encoded_image: str) -> "Image": return Image.open(io.BytesIO(base64.b64decode(encoded_image))) -def _bytes_to_list(content: bytes) -> List: +def _bytes_to_list(content: bytes) -> list: """Parse bytes from a Response object into a Python list. Expects the response body to be JSON-encoded data. @@ -251,7 +238,7 @@ def _bytes_to_list(content: bytes) -> List: return json.loads(content.decode()) -def _bytes_to_dict(content: bytes) -> Dict: +def _bytes_to_dict(content: bytes) -> dict: """Parse bytes from a Response object into a Python dictionary. Expects the response body to be JSON-encoded data. @@ -271,7 +258,7 @@ def _bytes_to_image(content: bytes) -> "Image": return Image.open(io.BytesIO(content)) -def _as_dict(response: Union[bytes, Dict]) -> Dict: +def _as_dict(response: Union[bytes, dict]) -> dict: return json.loads(response) if isinstance(response, bytes) else response @@ -397,14 +384,14 @@ async def _async_yield_from(client: httpx.AsyncClient, response: httpx.Response) # For more details, see https://github.com/huggingface/text-generation-inference and # https://huggingface.co/docs/api-inference/detailed_parameters#text-generation-task. -_UNSUPPORTED_TEXT_GENERATION_KWARGS: Dict[Optional[str], List[str]] = {} +_UNSUPPORTED_TEXT_GENERATION_KWARGS: dict[Optional[str], list[str]] = {} -def _set_unsupported_text_generation_kwargs(model: Optional[str], unsupported_kwargs: List[str]) -> None: +def _set_unsupported_text_generation_kwargs(model: Optional[str], unsupported_kwargs: list[str]) -> None: _UNSUPPORTED_TEXT_GENERATION_KWARGS.setdefault(model, []).extend(unsupported_kwargs) -def _get_unsupported_text_generation_kwargs(model: Optional[str]) -> List[str]: +def _get_unsupported_text_generation_kwargs(model: Optional[str]) -> list[str]: return _UNSUPPORTED_TEXT_GENERATION_KWARGS.get(model, []) diff --git a/src/huggingface_hub/inference/_generated/_async_client.py b/src/huggingface_hub/inference/_generated/_async_client.py index b25a231052..6d1acab8fb 100644 --- a/src/huggingface_hub/inference/_generated/_async_client.py +++ b/src/huggingface_hub/inference/_generated/_async_client.py @@ -25,7 +25,7 @@ import re import warnings from contextlib import AsyncExitStack -from typing import TYPE_CHECKING, Any, AsyncIterable, Dict, List, Literal, Optional, Union, overload +from typing import TYPE_CHECKING, Any, AsyncIterable, Literal, Optional, Union, overload import httpx @@ -135,13 +135,13 @@ class AsyncInferenceClient: arguments are mutually exclusive and have the exact same behavior. timeout (`float`, `optional`): The maximum number of seconds to wait for a response from the server. Defaults to None, meaning it will loop until the server is available. - headers (`Dict[str, str]`, `optional`): + headers (`dict[str, str]`, `optional`): Additional headers to send to the server. By default only the authorization and user-agent headers are sent. Values in this dictionary will override the default values. bill_to (`str`, `optional`): The billing account to use for the requests. By default the requests are billed on the user's account. Requests can only be billed to an organization the user is a member of, and which has subscribed to Enterprise Hub. - cookies (`Dict[str, str]`, `optional`): + cookies (`dict[str, str]`, `optional`): Additional cookies to send to the server. base_url (`str`, `optional`): Base URL to run inference. This is a duplicated argument from `model` to make [`InferenceClient`] @@ -159,8 +159,8 @@ def __init__( provider: Optional[PROVIDER_OR_POLICY_T] = None, token: Optional[str] = None, timeout: Optional[float] = None, - headers: Optional[Dict[str, str]] = None, - cookies: Optional[Dict[str, str]] = None, + headers: Optional[dict[str, str]] = None, + cookies: Optional[dict[str, str]] = None, bill_to: Optional[str] = None, # OpenAI compatibility base_url: Optional[str] = None, @@ -321,7 +321,7 @@ async def audio_classification( model: Optional[str] = None, top_k: Optional[int] = None, function_to_apply: Optional["AudioClassificationOutputTransform"] = None, - ) -> List[AudioClassificationOutputElement]: + ) -> list[AudioClassificationOutputElement]: """ Perform audio classification on the provided audio content. @@ -339,7 +339,7 @@ async def audio_classification( The function to apply to the model outputs in order to retrieve the scores. Returns: - `List[AudioClassificationOutputElement]`: List of [`AudioClassificationOutputElement`] items containing the predicted labels and their confidence. + `list[AudioClassificationOutputElement]`: List of [`AudioClassificationOutputElement`] items containing the predicted labels and their confidence. Raises: [`InferenceTimeoutError`]: @@ -377,7 +377,7 @@ async def audio_to_audio( audio: ContentT, *, model: Optional[str] = None, - ) -> List[AudioToAudioOutputElement]: + ) -> list[AudioToAudioOutputElement]: """ Performs multiple tasks related to audio-to-audio depending on the model (eg: speech enhancement, source separation). @@ -391,7 +391,7 @@ async def audio_to_audio( audio_to_audio will be used. Returns: - `List[AudioToAudioOutputElement]`: A list of [`AudioToAudioOutputElement`] items containing audios label, content-type, and audio content in blob. + `list[AudioToAudioOutputElement]`: A list of [`AudioToAudioOutputElement`] items containing audios label, content-type, and audio content in blob. Raises: `InferenceTimeoutError`: @@ -430,7 +430,7 @@ async def automatic_speech_recognition( audio: ContentT, *, model: Optional[str] = None, - extra_body: Optional[Dict] = None, + extra_body: Optional[dict] = None, ) -> AutomaticSpeechRecognitionOutput: """ Perform automatic speech recognition (ASR or audio-to-text) on the given audio content. @@ -441,7 +441,7 @@ async def automatic_speech_recognition( model (`str`, *optional*): The model to use for ASR. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. If not provided, the default recommended model for ASR will be used. - extra_body (`Dict`, *optional*): + extra_body (`dict`, *optional*): Additional provider-specific parameters to pass to the model. Refer to the provider's documentation for supported parameters. Returns: @@ -477,105 +477,105 @@ async def automatic_speech_recognition( @overload async def chat_completion( # type: ignore self, - messages: List[Union[Dict, ChatCompletionInputMessage]], + messages: list[Union[dict, ChatCompletionInputMessage]], *, model: Optional[str] = None, stream: Literal[False] = False, frequency_penalty: Optional[float] = None, - logit_bias: Optional[List[float]] = None, + logit_bias: Optional[list[float]] = None, logprobs: Optional[bool] = None, max_tokens: Optional[int] = None, n: Optional[int] = None, presence_penalty: Optional[float] = None, response_format: Optional[ChatCompletionInputGrammarType] = None, seed: Optional[int] = None, - stop: Optional[List[str]] = None, + stop: Optional[list[str]] = None, stream_options: Optional[ChatCompletionInputStreamOptions] = None, temperature: Optional[float] = None, tool_choice: Optional[Union[ChatCompletionInputToolChoiceClass, "ChatCompletionInputToolChoiceEnum"]] = None, tool_prompt: Optional[str] = None, - tools: Optional[List[ChatCompletionInputTool]] = None, + tools: Optional[list[ChatCompletionInputTool]] = None, top_logprobs: Optional[int] = None, top_p: Optional[float] = None, - extra_body: Optional[Dict] = None, + extra_body: Optional[dict] = None, ) -> ChatCompletionOutput: ... @overload async def chat_completion( # type: ignore self, - messages: List[Union[Dict, ChatCompletionInputMessage]], + messages: list[Union[dict, ChatCompletionInputMessage]], *, model: Optional[str] = None, stream: Literal[True] = True, frequency_penalty: Optional[float] = None, - logit_bias: Optional[List[float]] = None, + logit_bias: Optional[list[float]] = None, logprobs: Optional[bool] = None, max_tokens: Optional[int] = None, n: Optional[int] = None, presence_penalty: Optional[float] = None, response_format: Optional[ChatCompletionInputGrammarType] = None, seed: Optional[int] = None, - stop: Optional[List[str]] = None, + stop: Optional[list[str]] = None, stream_options: Optional[ChatCompletionInputStreamOptions] = None, temperature: Optional[float] = None, tool_choice: Optional[Union[ChatCompletionInputToolChoiceClass, "ChatCompletionInputToolChoiceEnum"]] = None, tool_prompt: Optional[str] = None, - tools: Optional[List[ChatCompletionInputTool]] = None, + tools: Optional[list[ChatCompletionInputTool]] = None, top_logprobs: Optional[int] = None, top_p: Optional[float] = None, - extra_body: Optional[Dict] = None, + extra_body: Optional[dict] = None, ) -> AsyncIterable[ChatCompletionStreamOutput]: ... @overload async def chat_completion( self, - messages: List[Union[Dict, ChatCompletionInputMessage]], + messages: list[Union[dict, ChatCompletionInputMessage]], *, model: Optional[str] = None, stream: bool = False, frequency_penalty: Optional[float] = None, - logit_bias: Optional[List[float]] = None, + logit_bias: Optional[list[float]] = None, logprobs: Optional[bool] = None, max_tokens: Optional[int] = None, n: Optional[int] = None, presence_penalty: Optional[float] = None, response_format: Optional[ChatCompletionInputGrammarType] = None, seed: Optional[int] = None, - stop: Optional[List[str]] = None, + stop: Optional[list[str]] = None, stream_options: Optional[ChatCompletionInputStreamOptions] = None, temperature: Optional[float] = None, tool_choice: Optional[Union[ChatCompletionInputToolChoiceClass, "ChatCompletionInputToolChoiceEnum"]] = None, tool_prompt: Optional[str] = None, - tools: Optional[List[ChatCompletionInputTool]] = None, + tools: Optional[list[ChatCompletionInputTool]] = None, top_logprobs: Optional[int] = None, top_p: Optional[float] = None, - extra_body: Optional[Dict] = None, + extra_body: Optional[dict] = None, ) -> Union[ChatCompletionOutput, AsyncIterable[ChatCompletionStreamOutput]]: ... async def chat_completion( self, - messages: List[Union[Dict, ChatCompletionInputMessage]], + messages: list[Union[dict, ChatCompletionInputMessage]], *, model: Optional[str] = None, stream: bool = False, # Parameters from ChatCompletionInput (handled manually) frequency_penalty: Optional[float] = None, - logit_bias: Optional[List[float]] = None, + logit_bias: Optional[list[float]] = None, logprobs: Optional[bool] = None, max_tokens: Optional[int] = None, n: Optional[int] = None, presence_penalty: Optional[float] = None, response_format: Optional[ChatCompletionInputGrammarType] = None, seed: Optional[int] = None, - stop: Optional[List[str]] = None, + stop: Optional[list[str]] = None, stream_options: Optional[ChatCompletionInputStreamOptions] = None, temperature: Optional[float] = None, tool_choice: Optional[Union[ChatCompletionInputToolChoiceClass, "ChatCompletionInputToolChoiceEnum"]] = None, tool_prompt: Optional[str] = None, - tools: Optional[List[ChatCompletionInputTool]] = None, + tools: Optional[list[ChatCompletionInputTool]] = None, top_logprobs: Optional[int] = None, top_p: Optional[float] = None, - extra_body: Optional[Dict] = None, + extra_body: Optional[dict] = None, ) -> Union[ChatCompletionOutput, AsyncIterable[ChatCompletionStreamOutput]]: """ A method for completing conversations using a specified language model. @@ -605,7 +605,7 @@ async def chat_completion( frequency_penalty (`float`, *optional*): Penalizes new tokens based on their existing frequency in the text so far. Range: [-2.0, 2.0]. Defaults to 0.0. - logit_bias (`List[float]`, *optional*): + logit_bias (`list[float]`, *optional*): Adjusts the likelihood of specific tokens appearing in the generated output. logprobs (`bool`, *optional*): Whether to return log probabilities of the output tokens or not. If true, returns the log @@ -621,7 +621,7 @@ async def chat_completion( Grammar constraints. Can be either a JSONSchema or a regex. seed (Optional[`int`], *optional*): Seed for reproducible control flow. Defaults to None. - stop (`List[str]`, *optional*): + stop (`list[str]`, *optional*): Up to four strings which trigger the end of the response. Defaults to None. stream (`bool`, *optional*): @@ -645,7 +645,7 @@ async def chat_completion( tools (List of [`ChatCompletionInputTool`], *optional*): A list of tools the model may call. Currently, only functions are supported as a tool. Use this to provide a list of functions the model may generate JSON inputs for. - extra_body (`Dict`, *optional*): + extra_body (`dict`, *optional*): Additional provider-specific parameters to pass to the model. Refer to the provider's documentation for supported parameters. Returns: @@ -977,8 +977,8 @@ async def document_question_answering( max_question_len: Optional[int] = None, max_seq_len: Optional[int] = None, top_k: Optional[int] = None, - word_boxes: Optional[List[Union[List[float], str]]] = None, - ) -> List[DocumentQuestionAnsweringOutputElement]: + word_boxes: Optional[list[Union[list[float], str]]] = None, + ) -> list[DocumentQuestionAnsweringOutputElement]: """ Answer questions on document images. @@ -1008,11 +1008,11 @@ async def document_question_answering( top_k (`int`, *optional*): The number of answers to return (will be chosen by order of likelihood). Can return less than top_k answers if there are not enough options available within the context. - word_boxes (`List[Union[List[float], str`, *optional*): + word_boxes (`list[Union[list[float], str`, *optional*): A list of words and bounding boxes (normalized 0->1000). If provided, the inference will skip the OCR step and use the provided bounding boxes instead. Returns: - `List[DocumentQuestionAnsweringOutputElement]`: a list of [`DocumentQuestionAnsweringOutputElement`] items containing the predicted label, associated probability, word ids, and page number. + `list[DocumentQuestionAnsweringOutputElement]`: a list of [`DocumentQuestionAnsweringOutputElement`] items containing the predicted label, associated probability, word ids, and page number. Raises: [`InferenceTimeoutError`]: @@ -1032,7 +1032,7 @@ async def document_question_answering( """ model_id = model or self.model provider_helper = get_provider_helper(self.provider, task="document-question-answering", model=model_id) - inputs: Dict[str, Any] = {"question": question, "image": _b64_encode(image)} + inputs: dict[str, Any] = {"question": question, "image": _b64_encode(image)} request_parameters = provider_helper.prepare_request( inputs=inputs, parameters={ @@ -1131,9 +1131,9 @@ async def fill_mask( text: str, *, model: Optional[str] = None, - targets: Optional[List[str]] = None, + targets: Optional[list[str]] = None, top_k: Optional[int] = None, - ) -> List[FillMaskOutputElement]: + ) -> list[FillMaskOutputElement]: """ Fill in a hole with a missing word (token to be precise). @@ -1143,14 +1143,14 @@ async def fill_mask( model (`str`, *optional*): The model to use for the fill mask task. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. If not provided, the default recommended fill mask model will be used. - targets (`List[str`, *optional*): + targets (`list[str`, *optional*): When passed, the model will limit the scores to the passed targets instead of looking up in the whole vocabulary. If the provided targets are not in the model vocab, they will be tokenized and the first resulting token will be used (with a warning, and that might be slower). top_k (`int`, *optional*): When passed, overrides the number of predictions to return. Returns: - `List[FillMaskOutputElement]`: a list of [`FillMaskOutputElement`] items containing the predicted label, associated + `list[FillMaskOutputElement]`: a list of [`FillMaskOutputElement`] items containing the predicted label, associated probability, token reference, and completed text. Raises: @@ -1190,7 +1190,7 @@ async def image_classification( model: Optional[str] = None, function_to_apply: Optional["ImageClassificationOutputTransform"] = None, top_k: Optional[int] = None, - ) -> List[ImageClassificationOutputElement]: + ) -> list[ImageClassificationOutputElement]: """ Perform image classification on the given image using the specified model. @@ -1205,7 +1205,7 @@ async def image_classification( top_k (`int`, *optional*): When specified, limits the output to the top K most probable classes. Returns: - `List[ImageClassificationOutputElement]`: a list of [`ImageClassificationOutputElement`] items containing the predicted label and associated probability. + `list[ImageClassificationOutputElement]`: a list of [`ImageClassificationOutputElement`] items containing the predicted label and associated probability. Raises: [`InferenceTimeoutError`]: @@ -1243,7 +1243,7 @@ async def image_segmentation( overlap_mask_area_threshold: Optional[float] = None, subtask: Optional["ImageSegmentationSubtask"] = None, threshold: Optional[float] = None, - ) -> List[ImageSegmentationOutputElement]: + ) -> list[ImageSegmentationOutputElement]: """ Perform image segmentation on the given image using the specified model. @@ -1268,7 +1268,7 @@ async def image_segmentation( threshold (`float`, *optional*): Probability threshold to filter out predicted masks. Returns: - `List[ImageSegmentationOutputElement]`: A list of [`ImageSegmentationOutputElement`] items containing the segmented masks and associated attributes. + `list[ImageSegmentationOutputElement]`: A list of [`ImageSegmentationOutputElement`] items containing the segmented masks and associated attributes. Raises: [`InferenceTimeoutError`]: @@ -1507,12 +1507,12 @@ async def image_to_text(self, image: ContentT, *, model: Optional[str] = None) - api_key=self.token, ) response = await self._inner_post(request_parameters) - output_list: List[ImageToTextOutput] = ImageToTextOutput.parse_obj_as_list(response) + output_list: list[ImageToTextOutput] = ImageToTextOutput.parse_obj_as_list(response) return output_list[0] async def object_detection( self, image: ContentT, *, model: Optional[str] = None, threshold: Optional[float] = None - ) -> List[ObjectDetectionOutputElement]: + ) -> list[ObjectDetectionOutputElement]: """ Perform object detection on the given image using the specified model. @@ -1531,7 +1531,7 @@ async def object_detection( threshold (`float`, *optional*): The probability necessary to make a prediction. Returns: - `List[ObjectDetectionOutputElement]`: A list of [`ObjectDetectionOutputElement`] items containing the bounding boxes and associated attributes. + `list[ObjectDetectionOutputElement]`: A list of [`ObjectDetectionOutputElement`] items containing the bounding boxes and associated attributes. Raises: [`InferenceTimeoutError`]: @@ -1575,7 +1575,7 @@ async def question_answering( max_question_len: Optional[int] = None, max_seq_len: Optional[int] = None, top_k: Optional[int] = None, - ) -> Union[QuestionAnsweringOutputElement, List[QuestionAnsweringOutputElement]]: + ) -> Union[QuestionAnsweringOutputElement, list[QuestionAnsweringOutputElement]]: """ Retrieve the answer to a question from a given text. @@ -1607,7 +1607,7 @@ async def question_answering( topk answers if there are not enough options available within the context. Returns: - Union[`QuestionAnsweringOutputElement`, List[`QuestionAnsweringOutputElement`]]: + Union[`QuestionAnsweringOutputElement`, list[`QuestionAnsweringOutputElement`]]: When top_k is 1 or not provided, it returns a single `QuestionAnsweringOutputElement`. When top_k is greater than 1, it returns a list of `QuestionAnsweringOutputElement`. Raises: @@ -1648,15 +1648,15 @@ async def question_answering( return output async def sentence_similarity( - self, sentence: str, other_sentences: List[str], *, model: Optional[str] = None - ) -> List[float]: + self, sentence: str, other_sentences: list[str], *, model: Optional[str] = None + ) -> list[float]: """ Compute the semantic similarity between a sentence and a list of other sentences by comparing their embeddings. Args: sentence (`str`): The main sentence to compare to others. - other_sentences (`List[str]`): + other_sentences (`list[str]`): The list of sentences to compare to. model (`str`, *optional*): The model to use for the sentence similarity task. Can be a model ID hosted on the Hugging Face Hub or a URL to @@ -1664,7 +1664,7 @@ async def sentence_similarity( Defaults to None. Returns: - `List[float]`: The embedding representing the input text. + `list[float]`: The embedding representing the input text. Raises: [`InferenceTimeoutError`]: @@ -1707,7 +1707,7 @@ async def summarization( *, model: Optional[str] = None, clean_up_tokenization_spaces: Optional[bool] = None, - generate_parameters: Optional[Dict[str, Any]] = None, + generate_parameters: Optional[dict[str, Any]] = None, truncation: Optional["SummarizationTruncationStrategy"] = None, ) -> SummarizationOutput: """ @@ -1721,7 +1721,7 @@ async def summarization( Inference Endpoint. If not provided, the default recommended model for summarization will be used. clean_up_tokenization_spaces (`bool`, *optional*): Whether to clean up the potential extra spaces in the text output. - generate_parameters (`Dict[str, Any]`, *optional*): + generate_parameters (`dict[str, Any]`, *optional*): Additional parametrization of the text generation algorithm. truncation (`"SummarizationTruncationStrategy"`, *optional*): The truncation strategy to use. @@ -1762,7 +1762,7 @@ async def summarization( async def table_question_answering( self, - table: Dict[str, Any], + table: dict[str, Any], query: str, *, model: Optional[str] = None, @@ -1823,12 +1823,12 @@ async def table_question_answering( response = await self._inner_post(request_parameters) return TableQuestionAnsweringOutputElement.parse_obj_as_instance(response) - async def tabular_classification(self, table: Dict[str, Any], *, model: Optional[str] = None) -> List[str]: + async def tabular_classification(self, table: dict[str, Any], *, model: Optional[str] = None) -> list[str]: """ Classifying a target category (a group) based on a set of attributes. Args: - table (`Dict[str, Any]`): + table (`dict[str, Any]`): Set of attributes to classify. model (`str`, *optional*): The model to use for the tabular classification task. Can be a model ID hosted on the Hugging Face Hub or a URL to @@ -1879,12 +1879,12 @@ async def tabular_classification(self, table: Dict[str, Any], *, model: Optional response = await self._inner_post(request_parameters) return _bytes_to_list(response) - async def tabular_regression(self, table: Dict[str, Any], *, model: Optional[str] = None) -> List[float]: + async def tabular_regression(self, table: dict[str, Any], *, model: Optional[str] = None) -> list[float]: """ Predicting a numerical target value given a set of attributes/features in a table. Args: - table (`Dict[str, Any]`): + table (`dict[str, Any]`): Set of attributes stored in a table. The attributes used to predict the target can be both numerical and categorical. model (`str`, *optional*): The model to use for the tabular regression task. Can be a model ID hosted on the Hugging Face Hub or a URL to @@ -1937,7 +1937,7 @@ async def text_classification( model: Optional[str] = None, top_k: Optional[int] = None, function_to_apply: Optional["TextClassificationOutputTransform"] = None, - ) -> List[TextClassificationOutputElement]: + ) -> list[TextClassificationOutputElement]: """ Perform text classification (e.g. sentiment-analysis) on the given text. @@ -1954,7 +1954,7 @@ async def text_classification( The function to apply to the model outputs in order to retrieve the scores. Returns: - `List[TextClassificationOutputElement]`: a list of [`TextClassificationOutputElement`] items containing the predicted label and associated probability. + `list[TextClassificationOutputElement]`: a list of [`TextClassificationOutputElement`] items containing the predicted label and associated probability. Raises: [`InferenceTimeoutError`]: @@ -2008,8 +2008,8 @@ async def text_generation( repetition_penalty: Optional[float] = None, return_full_text: Optional[bool] = None, seed: Optional[int] = None, - stop: Optional[List[str]] = None, - stop_sequences: Optional[List[str]] = None, # Deprecated, use `stop` instead + stop: Optional[list[str]] = None, + stop_sequences: Optional[list[str]] = None, # Deprecated, use `stop` instead temperature: Optional[float] = None, top_k: Optional[int] = None, top_n_tokens: Optional[int] = None, @@ -2038,8 +2038,8 @@ async def text_generation( repetition_penalty: Optional[float] = None, return_full_text: Optional[bool] = None, seed: Optional[int] = None, - stop: Optional[List[str]] = None, - stop_sequences: Optional[List[str]] = None, # Deprecated, use `stop` instead + stop: Optional[list[str]] = None, + stop_sequences: Optional[list[str]] = None, # Deprecated, use `stop` instead temperature: Optional[float] = None, top_k: Optional[int] = None, top_n_tokens: Optional[int] = None, @@ -2068,8 +2068,8 @@ async def text_generation( repetition_penalty: Optional[float] = None, return_full_text: Optional[bool] = None, # Manual default value seed: Optional[int] = None, - stop: Optional[List[str]] = None, - stop_sequences: Optional[List[str]] = None, # Deprecated, use `stop` instead + stop: Optional[list[str]] = None, + stop_sequences: Optional[list[str]] = None, # Deprecated, use `stop` instead temperature: Optional[float] = None, top_k: Optional[int] = None, top_n_tokens: Optional[int] = None, @@ -2098,8 +2098,8 @@ async def text_generation( repetition_penalty: Optional[float] = None, return_full_text: Optional[bool] = None, seed: Optional[int] = None, - stop: Optional[List[str]] = None, - stop_sequences: Optional[List[str]] = None, # Deprecated, use `stop` instead + stop: Optional[list[str]] = None, + stop_sequences: Optional[list[str]] = None, # Deprecated, use `stop` instead temperature: Optional[float] = None, top_k: Optional[int] = None, top_n_tokens: Optional[int] = None, @@ -2128,8 +2128,8 @@ async def text_generation( repetition_penalty: Optional[float] = None, return_full_text: Optional[bool] = None, seed: Optional[int] = None, - stop: Optional[List[str]] = None, - stop_sequences: Optional[List[str]] = None, # Deprecated, use `stop` instead + stop: Optional[list[str]] = None, + stop_sequences: Optional[list[str]] = None, # Deprecated, use `stop` instead temperature: Optional[float] = None, top_k: Optional[int] = None, top_n_tokens: Optional[int] = None, @@ -2157,8 +2157,8 @@ async def text_generation( repetition_penalty: Optional[float] = None, return_full_text: Optional[bool] = None, seed: Optional[int] = None, - stop: Optional[List[str]] = None, - stop_sequences: Optional[List[str]] = None, # Deprecated, use `stop` instead + stop: Optional[list[str]] = None, + stop_sequences: Optional[list[str]] = None, # Deprecated, use `stop` instead temperature: Optional[float] = None, top_k: Optional[int] = None, top_n_tokens: Optional[int] = None, @@ -2214,9 +2214,9 @@ async def text_generation( Whether to prepend the prompt to the generated text seed (`int`, *optional*): Random sampling seed - stop (`List[str]`, *optional*): + stop (`list[str]`, *optional*): Stop generating tokens if a member of `stop` is generated. - stop_sequences (`List[str]`, *optional*): + stop_sequences (`list[str]`, *optional*): Deprecated argument. Use `stop` instead. temperature (`float`, *optional*): The value used to module the logits distribution. @@ -2494,7 +2494,7 @@ async def text_to_image( model: Optional[str] = None, scheduler: Optional[str] = None, seed: Optional[int] = None, - extra_body: Optional[Dict[str, Any]] = None, + extra_body: Optional[dict[str, Any]] = None, ) -> "Image": """ Generate an image based on a given text using a specified model. @@ -2532,7 +2532,7 @@ async def text_to_image( Override the scheduler with a compatible one. seed (`int`, *optional*): Seed for the random number generator. - extra_body (`Dict[str, Any]`, *optional*): + extra_body (`dict[str, Any]`, *optional*): Additional provider-specific parameters to pass to the model. Refer to the provider's documentation for supported parameters. @@ -2632,11 +2632,11 @@ async def text_to_video( *, model: Optional[str] = None, guidance_scale: Optional[float] = None, - negative_prompt: Optional[List[str]] = None, + negative_prompt: Optional[list[str]] = None, num_frames: Optional[float] = None, num_inference_steps: Optional[int] = None, seed: Optional[int] = None, - extra_body: Optional[Dict[str, Any]] = None, + extra_body: Optional[dict[str, Any]] = None, ) -> bytes: """ Generate a video based on a given text. @@ -2655,7 +2655,7 @@ async def text_to_video( guidance_scale (`float`, *optional*): A higher guidance scale value encourages the model to generate videos closely linked to the text prompt, but values too high may cause saturation and other artifacts. - negative_prompt (`List[str]`, *optional*): + negative_prompt (`list[str]`, *optional*): One or several prompt to guide what NOT to include in video generation. num_frames (`float`, *optional*): The num_frames parameter determines how many video frames are generated. @@ -2664,7 +2664,7 @@ async def text_to_video( expense of slower inference. seed (`int`, *optional*): Seed for the random number generator. - extra_body (`Dict[str, Any]`, *optional*): + extra_body (`dict[str, Any]`, *optional*): Additional provider-specific parameters to pass to the model. Refer to the provider's documentation for supported parameters. @@ -2744,7 +2744,7 @@ async def text_to_speech( top_p: Optional[float] = None, typical_p: Optional[float] = None, use_cache: Optional[bool] = None, - extra_body: Optional[Dict[str, Any]] = None, + extra_body: Optional[dict[str, Any]] = None, ) -> bytes: """ Synthesize an audio of a voice pronouncing a given text. @@ -2806,7 +2806,7 @@ async def text_to_speech( paper](https://hf.co/papers/2202.00666) for more details. use_cache (`bool`, *optional*): Whether the model should use the past last key/values attentions to speed up decoding - extra_body (`Dict[str, Any]`, *optional*): + extra_body (`dict[str, Any]`, *optional*): Additional provider-specific parameters to pass to the model. Refer to the provider's documentation for supported parameters. Returns: @@ -2939,9 +2939,9 @@ async def token_classification( *, model: Optional[str] = None, aggregation_strategy: Optional["TokenClassificationAggregationStrategy"] = None, - ignore_labels: Optional[List[str]] = None, + ignore_labels: Optional[list[str]] = None, stride: Optional[int] = None, - ) -> List[TokenClassificationOutputElement]: + ) -> list[TokenClassificationOutputElement]: """ Perform token classification on the given text. Usually used for sentence parsing, either grammatical, or Named Entity Recognition (NER) to understand keywords contained within text. @@ -2955,13 +2955,13 @@ async def token_classification( Defaults to None. aggregation_strategy (`"TokenClassificationAggregationStrategy"`, *optional*): The strategy used to fuse tokens based on model predictions - ignore_labels (`List[str`, *optional*): + ignore_labels (`list[str`, *optional*): A list of labels to ignore stride (`int`, *optional*): The number of overlapping tokens between chunks when splitting the input text. Returns: - `List[TokenClassificationOutputElement]`: List of [`TokenClassificationOutputElement`] items containing the entity group, confidence score, word, start and end index. + `list[TokenClassificationOutputElement]`: List of [`TokenClassificationOutputElement`] items containing the entity group, confidence score, word, start and end index. Raises: [`InferenceTimeoutError`]: @@ -3018,7 +3018,7 @@ async def translation( tgt_lang: Optional[str] = None, clean_up_tokenization_spaces: Optional[bool] = None, truncation: Optional["TranslationTruncationStrategy"] = None, - generate_parameters: Optional[Dict[str, Any]] = None, + generate_parameters: Optional[dict[str, Any]] = None, ) -> TranslationOutput: """ Convert text from one language to another. @@ -3043,7 +3043,7 @@ async def translation( Whether to clean up the potential extra spaces in the text output. truncation (`"TranslationTruncationStrategy"`, *optional*): The truncation strategy to use. - generate_parameters (`Dict[str, Any]`, *optional*): + generate_parameters (`dict[str, Any]`, *optional*): Additional parametrization of the text generation algorithm. Returns: @@ -3106,7 +3106,7 @@ async def visual_question_answering( *, model: Optional[str] = None, top_k: Optional[int] = None, - ) -> List[VisualQuestionAnsweringOutputElement]: + ) -> list[VisualQuestionAnsweringOutputElement]: """ Answering open-ended questions based on an image. @@ -3123,7 +3123,7 @@ async def visual_question_answering( The number of answers to return (will be chosen by order of likelihood). Note that we return less than topk answers if there are not enough options available within the context. Returns: - `List[VisualQuestionAnsweringOutputElement]`: a list of [`VisualQuestionAnsweringOutputElement`] items containing the predicted label and associated probability. + `list[VisualQuestionAnsweringOutputElement]`: a list of [`VisualQuestionAnsweringOutputElement`] items containing the predicted label and associated probability. Raises: `InferenceTimeoutError`: @@ -3162,21 +3162,21 @@ async def visual_question_answering( async def zero_shot_classification( self, text: str, - candidate_labels: List[str], + candidate_labels: list[str], *, multi_label: Optional[bool] = False, hypothesis_template: Optional[str] = None, model: Optional[str] = None, - ) -> List[ZeroShotClassificationOutputElement]: + ) -> list[ZeroShotClassificationOutputElement]: """ Provide as input a text and a set of candidate labels to classify the input text. Args: text (`str`): The input text to classify. - candidate_labels (`List[str]`): + candidate_labels (`list[str]`): The set of possible class labels to classify the text into. - labels (`List[str]`, *optional*): + labels (`list[str]`, *optional*): (deprecated) List of strings. Each string is the verbalization of a possible label for the input text. multi_label (`bool`, *optional*): Whether multiple candidate labels can be true. If false, the scores are normalized such that the sum of @@ -3191,7 +3191,7 @@ async def zero_shot_classification( Returns: - `List[ZeroShotClassificationOutputElement]`: List of [`ZeroShotClassificationOutputElement`] items containing the predicted labels and their confidence. + `list[ZeroShotClassificationOutputElement]`: List of [`ZeroShotClassificationOutputElement`] items containing the predicted labels and their confidence. Raises: [`InferenceTimeoutError`]: @@ -3270,22 +3270,22 @@ async def zero_shot_classification( async def zero_shot_image_classification( self, image: ContentT, - candidate_labels: List[str], + candidate_labels: list[str], *, model: Optional[str] = None, hypothesis_template: Optional[str] = None, # deprecated argument - labels: List[str] = None, # type: ignore - ) -> List[ZeroShotImageClassificationOutputElement]: + labels: list[str] = None, # type: ignore + ) -> list[ZeroShotImageClassificationOutputElement]: """ Provide input image and text labels to predict text labels for the image. Args: image (`Union[str, Path, bytes, BinaryIO, PIL.Image.Image]`): The input image to caption. It can be raw bytes, an image file, a URL to an online image, or a PIL Image. - candidate_labels (`List[str]`): + candidate_labels (`list[str]`): The candidate labels for this image - labels (`List[str]`, *optional*): + labels (`list[str]`, *optional*): (deprecated) List of string possible labels. There must be at least 2 labels. model (`str`, *optional*): The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed @@ -3295,7 +3295,7 @@ async def zero_shot_image_classification( replacing the placeholder with the candidate labels. Returns: - `List[ZeroShotImageClassificationOutputElement]`: List of [`ZeroShotImageClassificationOutputElement`] items containing the predicted labels and their confidence. + `list[ZeroShotImageClassificationOutputElement]`: List of [`ZeroShotImageClassificationOutputElement`] items containing the predicted labels and their confidence. Raises: [`InferenceTimeoutError`]: @@ -3335,7 +3335,7 @@ async def zero_shot_image_classification( response = await self._inner_post(request_parameters) return ZeroShotImageClassificationOutputElement.parse_obj_as_list(response) - async def get_endpoint_info(self, *, model: Optional[str] = None) -> Dict[str, Any]: + async def get_endpoint_info(self, *, model: Optional[str] = None) -> dict[str, Any]: """ Get information about the deployed endpoint. @@ -3348,7 +3348,7 @@ async def get_endpoint_info(self, *, model: Optional[str] = None) -> Dict[str, A Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None. Returns: - `Dict[str, Any]`: Information about the endpoint. + `dict[str, Any]`: Information about the endpoint. Example: ```py diff --git a/src/huggingface_hub/inference/_generated/types/automatic_speech_recognition.py b/src/huggingface_hub/inference/_generated/types/automatic_speech_recognition.py index f6bfd28256..2e6afc4411 100644 --- a/src/huggingface_hub/inference/_generated/types/automatic_speech_recognition.py +++ b/src/huggingface_hub/inference/_generated/types/automatic_speech_recognition.py @@ -3,7 +3,7 @@ # See: # - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts # - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. -from typing import List, Literal, Optional, Union +from typing import Literal, Optional, Union from .base import BaseInferenceType, dataclass_with_extra @@ -97,7 +97,7 @@ class AutomaticSpeechRecognitionInput(BaseInferenceType): class AutomaticSpeechRecognitionOutputChunk(BaseInferenceType): text: str """A chunk of text identified by the model""" - timestamp: List[float] + timestamp: list[float] """The start and end timestamps corresponding with the text""" @@ -107,7 +107,7 @@ class AutomaticSpeechRecognitionOutput(BaseInferenceType): text: str """The recognized text.""" - chunks: Optional[List[AutomaticSpeechRecognitionOutputChunk]] = None + chunks: Optional[list[AutomaticSpeechRecognitionOutputChunk]] = None """When returnTimestamps is enabled, chunks contains a list of audio chunks identified by the model. """ diff --git a/src/huggingface_hub/inference/_generated/types/base.py b/src/huggingface_hub/inference/_generated/types/base.py index 1f0c4687ce..2c6df61c0e 100644 --- a/src/huggingface_hub/inference/_generated/types/base.py +++ b/src/huggingface_hub/inference/_generated/types/base.py @@ -15,8 +15,9 @@ import inspect import json +import types from dataclasses import asdict, dataclass -from typing import Any, Dict, List, Type, TypeVar, Union, get_args +from typing import Any, TypeVar, Union, get_args T = TypeVar("T", bound="BaseInferenceType") @@ -28,7 +29,7 @@ def _repr_with_extra(self): return f"{self.__class__.__name__}({', '.join(f'{k}={self.__dict__[k]!r}' for k in fields + other_fields)})" -def dataclass_with_extra(cls: Type[T]) -> Type[T]: +def dataclass_with_extra(cls: type[T]) -> type[T]: """Decorator to add a custom __repr__ method to a dataclass, showing all fields, including extra ones. This decorator only works with dataclasses that inherit from `BaseInferenceType`. @@ -49,7 +50,7 @@ class BaseInferenceType(dict): """ @classmethod - def parse_obj_as_list(cls: Type[T], data: Union[bytes, str, List, Dict]) -> List[T]: + def parse_obj_as_list(cls: type[T], data: Union[bytes, str, list, dict]) -> list[T]: """Alias to parse server response and return a single instance. See `parse_obj` for more details. @@ -60,7 +61,7 @@ def parse_obj_as_list(cls: Type[T], data: Union[bytes, str, List, Dict]) -> List return output @classmethod - def parse_obj_as_instance(cls: Type[T], data: Union[bytes, str, List, Dict]) -> T: + def parse_obj_as_instance(cls: type[T], data: Union[bytes, str, list, dict]) -> T: """Alias to parse server response and return a single instance. See `parse_obj` for more details. @@ -71,7 +72,7 @@ def parse_obj_as_instance(cls: Type[T], data: Union[bytes, str, List, Dict]) -> return output @classmethod - def parse_obj(cls: Type[T], data: Union[bytes, str, List, Dict]) -> Union[List[T], T]: + def parse_obj(cls: type[T], data: Union[bytes, str, list, dict]) -> Union[list[T], T]: """Parse server response as a dataclass or list of dataclasses. To enable future-compatibility, we want to handle cases where the server return more fields than expected. @@ -85,7 +86,7 @@ def parse_obj(cls: Type[T], data: Union[bytes, str, List, Dict]) -> Union[List[T data = json.loads(data) # If a list, parse each item individually - if isinstance(data, List): + if isinstance(data, list): return [cls.parse_obj(d) for d in data] # type: ignore [misc] # At this point, we expect a dict @@ -109,7 +110,9 @@ def parse_obj(cls: Type[T], data: Union[bytes, str, List, Dict]) -> Union[List[T else: expected_types = get_args(field_type) for expected_type in expected_types: - if getattr(expected_type, "_name", None) == "List": + if ( + isinstance(expected_type, types.GenericAlias) and expected_type.__origin__ is list + ) or getattr(expected_type, "_name", None) == "List": expected_type = get_args(expected_type)[ 0 ] # assume same type for all items in the list diff --git a/src/huggingface_hub/inference/_generated/types/chat_completion.py b/src/huggingface_hub/inference/_generated/types/chat_completion.py index ba708a7009..db814b01ae 100644 --- a/src/huggingface_hub/inference/_generated/types/chat_completion.py +++ b/src/huggingface_hub/inference/_generated/types/chat_completion.py @@ -3,7 +3,7 @@ # See: # - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts # - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. -from typing import Any, Dict, List, Literal, Optional, Union +from typing import Any, Literal, Optional, Union from .base import BaseInferenceType, dataclass_with_extra @@ -40,9 +40,9 @@ class ChatCompletionInputToolCall(BaseInferenceType): @dataclass_with_extra class ChatCompletionInputMessage(BaseInferenceType): role: str - content: Optional[Union[List[ChatCompletionInputMessageChunk], str]] = None + content: Optional[Union[list[ChatCompletionInputMessageChunk], str]] = None name: Optional[str] = None - tool_calls: Optional[List[ChatCompletionInputToolCall]] = None + tool_calls: Optional[list[ChatCompletionInputToolCall]] = None @dataclass_with_extra @@ -56,7 +56,7 @@ class ChatCompletionInputJSONSchema(BaseInferenceType): A description of what the response format is for, used by the model to determine how to respond in the format. """ - schema: Optional[Dict[str, object]] = None + schema: Optional[dict[str, object]] = None """ The schema for the response format, described as a JSON Schema object. Learn how to build JSON schemas [here](https://json-schema.org/). @@ -129,14 +129,14 @@ class ChatCompletionInput(BaseInferenceType): https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-tgi-import.ts. """ - messages: List[ChatCompletionInputMessage] + messages: list[ChatCompletionInputMessage] """A list of messages comprising the conversation so far.""" frequency_penalty: Optional[float] = None """Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim. """ - logit_bias: Optional[List[float]] = None + logit_bias: Optional[list[float]] = None """UNUSED Modify the likelihood of specified tokens appearing in the completion. Accepts a JSON object that maps tokens @@ -172,7 +172,7 @@ class ChatCompletionInput(BaseInferenceType): """ response_format: Optional[ChatCompletionInputGrammarType] = None seed: Optional[int] = None - stop: Optional[List[str]] = None + stop: Optional[list[str]] = None """Up to 4 sequences where the API will stop generating further tokens.""" stream: Optional[bool] = None stream_options: Optional[ChatCompletionInputStreamOptions] = None @@ -185,7 +185,7 @@ class ChatCompletionInput(BaseInferenceType): tool_choice: Optional[Union[ChatCompletionInputToolChoiceClass, "ChatCompletionInputToolChoiceEnum"]] = None tool_prompt: Optional[str] = None """A prompt to be appended before the tools""" - tools: Optional[List[ChatCompletionInputTool]] = None + tools: Optional[list[ChatCompletionInputTool]] = None """A list of tools the model may call. Currently, only functions are supported as a tool. Use this to provide a list of functions the model may generate JSON inputs for. @@ -213,12 +213,12 @@ class ChatCompletionOutputTopLogprob(BaseInferenceType): class ChatCompletionOutputLogprob(BaseInferenceType): logprob: float token: str - top_logprobs: List[ChatCompletionOutputTopLogprob] + top_logprobs: list[ChatCompletionOutputTopLogprob] @dataclass_with_extra class ChatCompletionOutputLogprobs(BaseInferenceType): - content: List[ChatCompletionOutputLogprob] + content: list[ChatCompletionOutputLogprob] @dataclass_with_extra @@ -241,7 +241,7 @@ class ChatCompletionOutputMessage(BaseInferenceType): content: Optional[str] = None reasoning: Optional[str] = None tool_call_id: Optional[str] = None - tool_calls: Optional[List[ChatCompletionOutputToolCall]] = None + tool_calls: Optional[list[ChatCompletionOutputToolCall]] = None @dataclass_with_extra @@ -267,7 +267,7 @@ class ChatCompletionOutput(BaseInferenceType): https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-tgi-import.ts. """ - choices: List[ChatCompletionOutputComplete] + choices: list[ChatCompletionOutputComplete] created: int id: str model: str @@ -295,7 +295,7 @@ class ChatCompletionStreamOutputDelta(BaseInferenceType): content: Optional[str] = None reasoning: Optional[str] = None tool_call_id: Optional[str] = None - tool_calls: Optional[List[ChatCompletionStreamOutputDeltaToolCall]] = None + tool_calls: Optional[list[ChatCompletionStreamOutputDeltaToolCall]] = None @dataclass_with_extra @@ -308,12 +308,12 @@ class ChatCompletionStreamOutputTopLogprob(BaseInferenceType): class ChatCompletionStreamOutputLogprob(BaseInferenceType): logprob: float token: str - top_logprobs: List[ChatCompletionStreamOutputTopLogprob] + top_logprobs: list[ChatCompletionStreamOutputTopLogprob] @dataclass_with_extra class ChatCompletionStreamOutputLogprobs(BaseInferenceType): - content: List[ChatCompletionStreamOutputLogprob] + content: list[ChatCompletionStreamOutputLogprob] @dataclass_with_extra @@ -339,7 +339,7 @@ class ChatCompletionStreamOutput(BaseInferenceType): https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-tgi-import.ts. """ - choices: List[ChatCompletionStreamOutputChoice] + choices: list[ChatCompletionStreamOutputChoice] created: int id: str model: str diff --git a/src/huggingface_hub/inference/_generated/types/depth_estimation.py b/src/huggingface_hub/inference/_generated/types/depth_estimation.py index 1e09bdffa1..765c3635f9 100644 --- a/src/huggingface_hub/inference/_generated/types/depth_estimation.py +++ b/src/huggingface_hub/inference/_generated/types/depth_estimation.py @@ -3,7 +3,7 @@ # See: # - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts # - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. -from typing import Any, Dict, Optional +from typing import Any, Optional from .base import BaseInferenceType, dataclass_with_extra @@ -14,7 +14,7 @@ class DepthEstimationInput(BaseInferenceType): inputs: Any """The input image data""" - parameters: Optional[Dict[str, Any]] = None + parameters: Optional[dict[str, Any]] = None """Additional inference parameters for Depth Estimation""" diff --git a/src/huggingface_hub/inference/_generated/types/document_question_answering.py b/src/huggingface_hub/inference/_generated/types/document_question_answering.py index 2457d2c8c2..e3886041d6 100644 --- a/src/huggingface_hub/inference/_generated/types/document_question_answering.py +++ b/src/huggingface_hub/inference/_generated/types/document_question_answering.py @@ -3,7 +3,7 @@ # See: # - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts # - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. -from typing import Any, List, Optional, Union +from typing import Any, Optional, Union from .base import BaseInferenceType, dataclass_with_extra @@ -46,7 +46,7 @@ class DocumentQuestionAnsweringParameters(BaseInferenceType): """The number of answers to return (will be chosen by order of likelihood). Can return less than top_k answers if there are not enough options available within the context. """ - word_boxes: Optional[List[Union[List[float], str]]] = None + word_boxes: Optional[list[Union[list[float], str]]] = None """A list of words and bounding boxes (normalized 0->1000). If provided, the inference will skip the OCR step and use the provided bounding boxes instead. """ diff --git a/src/huggingface_hub/inference/_generated/types/feature_extraction.py b/src/huggingface_hub/inference/_generated/types/feature_extraction.py index e965ddbac2..a6b9aa1937 100644 --- a/src/huggingface_hub/inference/_generated/types/feature_extraction.py +++ b/src/huggingface_hub/inference/_generated/types/feature_extraction.py @@ -3,7 +3,7 @@ # See: # - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts # - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. -from typing import List, Literal, Optional, Union +from typing import Literal, Optional, Union from .base import BaseInferenceType, dataclass_with_extra @@ -19,7 +19,7 @@ class FeatureExtractionInput(BaseInferenceType): https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-tei-import.ts. """ - inputs: Union[List[str], str] + inputs: Union[list[str], str] """The text or list of texts to embed.""" normalize: Optional[bool] = None prompt_name: Optional[str] = None diff --git a/src/huggingface_hub/inference/_generated/types/fill_mask.py b/src/huggingface_hub/inference/_generated/types/fill_mask.py index dfcdc56bc5..848421dc13 100644 --- a/src/huggingface_hub/inference/_generated/types/fill_mask.py +++ b/src/huggingface_hub/inference/_generated/types/fill_mask.py @@ -3,7 +3,7 @@ # See: # - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts # - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. -from typing import Any, List, Optional +from typing import Any, Optional from .base import BaseInferenceType, dataclass_with_extra @@ -12,7 +12,7 @@ class FillMaskParameters(BaseInferenceType): """Additional inference parameters for Fill Mask""" - targets: Optional[List[str]] = None + targets: Optional[list[str]] = None """When passed, the model will limit the scores to the passed targets instead of looking up in the whole vocabulary. If the provided targets are not in the model vocab, they will be tokenized and the first resulting token will be used (with a warning, and that might be diff --git a/src/huggingface_hub/inference/_generated/types/sentence_similarity.py b/src/huggingface_hub/inference/_generated/types/sentence_similarity.py index 66e8bb4d93..4dd42c0bd8 100644 --- a/src/huggingface_hub/inference/_generated/types/sentence_similarity.py +++ b/src/huggingface_hub/inference/_generated/types/sentence_similarity.py @@ -3,14 +3,14 @@ # See: # - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts # - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. -from typing import Any, Dict, List, Optional +from typing import Any, Optional from .base import BaseInferenceType, dataclass_with_extra @dataclass_with_extra class SentenceSimilarityInputData(BaseInferenceType): - sentences: List[str] + sentences: list[str] """A list of strings which will be compared against the source_sentence.""" source_sentence: str """The string that you wish to compare the other strings with. This can be a phrase, @@ -23,5 +23,5 @@ class SentenceSimilarityInput(BaseInferenceType): """Inputs for Sentence similarity inference""" inputs: SentenceSimilarityInputData - parameters: Optional[Dict[str, Any]] = None + parameters: Optional[dict[str, Any]] = None """Additional inference parameters for Sentence Similarity""" diff --git a/src/huggingface_hub/inference/_generated/types/summarization.py b/src/huggingface_hub/inference/_generated/types/summarization.py index 33eae6fcba..0103853aa6 100644 --- a/src/huggingface_hub/inference/_generated/types/summarization.py +++ b/src/huggingface_hub/inference/_generated/types/summarization.py @@ -3,7 +3,7 @@ # See: # - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts # - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. -from typing import Any, Dict, Literal, Optional +from typing import Any, Literal, Optional from .base import BaseInferenceType, dataclass_with_extra @@ -17,7 +17,7 @@ class SummarizationParameters(BaseInferenceType): clean_up_tokenization_spaces: Optional[bool] = None """Whether to clean up the potential extra spaces in the text output.""" - generate_parameters: Optional[Dict[str, Any]] = None + generate_parameters: Optional[dict[str, Any]] = None """Additional parametrization of the text generation algorithm.""" truncation: Optional["SummarizationTruncationStrategy"] = None """The truncation strategy to use.""" diff --git a/src/huggingface_hub/inference/_generated/types/table_question_answering.py b/src/huggingface_hub/inference/_generated/types/table_question_answering.py index 10e208eeeb..cceb59fde9 100644 --- a/src/huggingface_hub/inference/_generated/types/table_question_answering.py +++ b/src/huggingface_hub/inference/_generated/types/table_question_answering.py @@ -3,7 +3,7 @@ # See: # - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts # - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. -from typing import Dict, List, Literal, Optional +from typing import Literal, Optional from .base import BaseInferenceType, dataclass_with_extra @@ -14,7 +14,7 @@ class TableQuestionAnsweringInputData(BaseInferenceType): question: str """The question to be answered about the table""" - table: Dict[str, List[str]] + table: dict[str, list[str]] """The table to serve as context for the questions""" @@ -54,9 +54,9 @@ class TableQuestionAnsweringOutputElement(BaseInferenceType): """The answer of the question given the table. If there is an aggregator, the answer will be preceded by `AGGREGATOR >`. """ - cells: List[str] + cells: list[str] """List of strings made up of the answer cell values.""" - coordinates: List[List[int]] + coordinates: list[list[int]] """Coordinates of the cells of the answers.""" aggregator: Optional[str] = None """If the model has an aggregator, this returns the aggregator.""" diff --git a/src/huggingface_hub/inference/_generated/types/text2text_generation.py b/src/huggingface_hub/inference/_generated/types/text2text_generation.py index 34ac74e21e..bda2211902 100644 --- a/src/huggingface_hub/inference/_generated/types/text2text_generation.py +++ b/src/huggingface_hub/inference/_generated/types/text2text_generation.py @@ -3,7 +3,7 @@ # See: # - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts # - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. -from typing import Any, Dict, Literal, Optional +from typing import Any, Literal, Optional from .base import BaseInferenceType, dataclass_with_extra @@ -17,7 +17,7 @@ class Text2TextGenerationParameters(BaseInferenceType): clean_up_tokenization_spaces: Optional[bool] = None """Whether to clean up the potential extra spaces in the text output.""" - generate_parameters: Optional[Dict[str, Any]] = None + generate_parameters: Optional[dict[str, Any]] = None """Additional parametrization of the text generation algorithm""" truncation: Optional["Text2TextGenerationTruncationStrategy"] = None """The truncation strategy to use""" diff --git a/src/huggingface_hub/inference/_generated/types/text_generation.py b/src/huggingface_hub/inference/_generated/types/text_generation.py index 9b79cc691d..b470198b40 100644 --- a/src/huggingface_hub/inference/_generated/types/text_generation.py +++ b/src/huggingface_hub/inference/_generated/types/text_generation.py @@ -3,7 +3,7 @@ # See: # - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts # - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. -from typing import Any, List, Literal, Optional +from typing import Any, Literal, Optional from .base import BaseInferenceType, dataclass_with_extra @@ -49,7 +49,7 @@ class TextGenerationInputGenerateParameters(BaseInferenceType): """Whether to prepend the prompt to the generated text""" seed: Optional[int] = None """Random sampling seed.""" - stop: Optional[List[str]] = None + stop: Optional[list[str]] = None """Stop generating tokens if a member of `stop` is generated.""" temperature: Optional[float] = None """The value used to module the logits distribution.""" @@ -108,21 +108,21 @@ class TextGenerationOutputBestOfSequence(BaseInferenceType): finish_reason: "TextGenerationOutputFinishReason" generated_text: str generated_tokens: int - prefill: List[TextGenerationOutputPrefillToken] - tokens: List[TextGenerationOutputToken] + prefill: list[TextGenerationOutputPrefillToken] + tokens: list[TextGenerationOutputToken] seed: Optional[int] = None - top_tokens: Optional[List[List[TextGenerationOutputToken]]] = None + top_tokens: Optional[list[list[TextGenerationOutputToken]]] = None @dataclass_with_extra class TextGenerationOutputDetails(BaseInferenceType): finish_reason: "TextGenerationOutputFinishReason" generated_tokens: int - prefill: List[TextGenerationOutputPrefillToken] - tokens: List[TextGenerationOutputToken] - best_of_sequences: Optional[List[TextGenerationOutputBestOfSequence]] = None + prefill: list[TextGenerationOutputPrefillToken] + tokens: list[TextGenerationOutputToken] + best_of_sequences: Optional[list[TextGenerationOutputBestOfSequence]] = None seed: Optional[int] = None - top_tokens: Optional[List[List[TextGenerationOutputToken]]] = None + top_tokens: Optional[list[list[TextGenerationOutputToken]]] = None @dataclass_with_extra @@ -165,4 +165,4 @@ class TextGenerationStreamOutput(BaseInferenceType): token: TextGenerationStreamOutputToken details: Optional[TextGenerationStreamOutputStreamDetails] = None generated_text: Optional[str] = None - top_tokens: Optional[List[TextGenerationStreamOutputToken]] = None + top_tokens: Optional[list[TextGenerationStreamOutputToken]] = None diff --git a/src/huggingface_hub/inference/_generated/types/text_to_video.py b/src/huggingface_hub/inference/_generated/types/text_to_video.py index e54a1bc094..a7e9637821 100644 --- a/src/huggingface_hub/inference/_generated/types/text_to_video.py +++ b/src/huggingface_hub/inference/_generated/types/text_to_video.py @@ -3,7 +3,7 @@ # See: # - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts # - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. -from typing import Any, List, Optional +from typing import Any, Optional from .base import BaseInferenceType, dataclass_with_extra @@ -16,7 +16,7 @@ class TextToVideoParameters(BaseInferenceType): """A higher guidance scale value encourages the model to generate videos closely linked to the text prompt, but values too high may cause saturation and other artifacts. """ - negative_prompt: Optional[List[str]] = None + negative_prompt: Optional[list[str]] = None """One or several prompt to guide what NOT to include in video generation.""" num_frames: Optional[float] = None """The num_frames parameter determines how many video frames are generated.""" diff --git a/src/huggingface_hub/inference/_generated/types/token_classification.py b/src/huggingface_hub/inference/_generated/types/token_classification.py index e039b6a1db..b40f4b5f6f 100644 --- a/src/huggingface_hub/inference/_generated/types/token_classification.py +++ b/src/huggingface_hub/inference/_generated/types/token_classification.py @@ -3,7 +3,7 @@ # See: # - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts # - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. -from typing import List, Literal, Optional +from typing import Literal, Optional from .base import BaseInferenceType, dataclass_with_extra @@ -17,7 +17,7 @@ class TokenClassificationParameters(BaseInferenceType): aggregation_strategy: Optional["TokenClassificationAggregationStrategy"] = None """The strategy used to fuse tokens based on model predictions""" - ignore_labels: Optional[List[str]] = None + ignore_labels: Optional[list[str]] = None """A list of labels to ignore""" stride: Optional[int] = None """The number of overlapping tokens between chunks when splitting the input text.""" diff --git a/src/huggingface_hub/inference/_generated/types/translation.py b/src/huggingface_hub/inference/_generated/types/translation.py index df95b7dbb1..59619e9a90 100644 --- a/src/huggingface_hub/inference/_generated/types/translation.py +++ b/src/huggingface_hub/inference/_generated/types/translation.py @@ -3,7 +3,7 @@ # See: # - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts # - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. -from typing import Any, Dict, Literal, Optional +from typing import Any, Literal, Optional from .base import BaseInferenceType, dataclass_with_extra @@ -17,7 +17,7 @@ class TranslationParameters(BaseInferenceType): clean_up_tokenization_spaces: Optional[bool] = None """Whether to clean up the potential extra spaces in the text output.""" - generate_parameters: Optional[Dict[str, Any]] = None + generate_parameters: Optional[dict[str, Any]] = None """Additional parametrization of the text generation algorithm.""" src_lang: Optional[str] = None """The source language of the text. Required for models that can translate from multiple diff --git a/src/huggingface_hub/inference/_generated/types/zero_shot_classification.py b/src/huggingface_hub/inference/_generated/types/zero_shot_classification.py index 47b32492e3..7b0dd13237 100644 --- a/src/huggingface_hub/inference/_generated/types/zero_shot_classification.py +++ b/src/huggingface_hub/inference/_generated/types/zero_shot_classification.py @@ -3,7 +3,7 @@ # See: # - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts # - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. -from typing import List, Optional +from typing import Optional from .base import BaseInferenceType, dataclass_with_extra @@ -12,7 +12,7 @@ class ZeroShotClassificationParameters(BaseInferenceType): """Additional inference parameters for Zero Shot Classification""" - candidate_labels: List[str] + candidate_labels: list[str] """The set of possible class labels to classify the text into.""" hypothesis_template: Optional[str] = None """The sentence used in conjunction with `candidate_labels` to attempt the text diff --git a/src/huggingface_hub/inference/_generated/types/zero_shot_image_classification.py b/src/huggingface_hub/inference/_generated/types/zero_shot_image_classification.py index 998d66b6b4..ed138eada5 100644 --- a/src/huggingface_hub/inference/_generated/types/zero_shot_image_classification.py +++ b/src/huggingface_hub/inference/_generated/types/zero_shot_image_classification.py @@ -3,7 +3,7 @@ # See: # - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts # - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. -from typing import List, Optional +from typing import Optional from .base import BaseInferenceType, dataclass_with_extra @@ -12,7 +12,7 @@ class ZeroShotImageClassificationParameters(BaseInferenceType): """Additional inference parameters for Zero Shot Image Classification""" - candidate_labels: List[str] + candidate_labels: list[str] """The candidate labels for this image""" hypothesis_template: Optional[str] = None """The sentence used in conjunction with `candidate_labels` to attempt the image diff --git a/src/huggingface_hub/inference/_generated/types/zero_shot_object_detection.py b/src/huggingface_hub/inference/_generated/types/zero_shot_object_detection.py index 8ef76b5fcb..e981463b25 100644 --- a/src/huggingface_hub/inference/_generated/types/zero_shot_object_detection.py +++ b/src/huggingface_hub/inference/_generated/types/zero_shot_object_detection.py @@ -3,8 +3,6 @@ # See: # - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts # - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. -from typing import List - from .base import BaseInferenceType, dataclass_with_extra @@ -12,7 +10,7 @@ class ZeroShotObjectDetectionParameters(BaseInferenceType): """Additional inference parameters for Zero Shot Object Detection""" - candidate_labels: List[str] + candidate_labels: list[str] """The candidate labels for this image""" diff --git a/src/huggingface_hub/inference/_mcp/agent.py b/src/huggingface_hub/inference/_mcp/agent.py index 4f88016ba7..0a372608d0 100644 --- a/src/huggingface_hub/inference/_mcp/agent.py +++ b/src/huggingface_hub/inference/_mcp/agent.py @@ -1,7 +1,7 @@ from __future__ import annotations import asyncio -from typing import AsyncGenerator, Dict, Iterable, List, Optional, Union +from typing import AsyncGenerator, Iterable, Optional, Union from huggingface_hub import ChatCompletionInputMessage, ChatCompletionStreamOutput, MCPClient @@ -24,7 +24,7 @@ class Agent(MCPClient): model (`str`, *optional*): The model to run inference with. Can be a model id hosted on the Hugging Face Hub, e.g. `meta-llama/Meta-Llama-3-8B-Instruct` or a URL to a deployed Inference Endpoint or other local or remote endpoint. - servers (`Iterable[Dict]`): + servers (`Iterable[dict]`): MCP servers to connect to. Each server is a dictionary containing a `type` key and a `config` key. The `type` key can be `"stdio"` or `"sse"`, and the `config` key is a dictionary of arguments for the server. provider (`str`, *optional*): Name of the provider to use for inference. Defaults to "auto" i.e. the first of the providers available for the model, sorted by the user's order in https://hf.co/settings/inference-providers. @@ -49,7 +49,7 @@ def __init__( ): super().__init__(model=model, provider=provider, base_url=base_url, api_key=api_key) self._servers_cfg = list(servers) - self.messages: List[Union[Dict, ChatCompletionInputMessage]] = [ + self.messages: list[Union[dict, ChatCompletionInputMessage]] = [ {"role": "system", "content": prompt or DEFAULT_SYSTEM_PROMPT} ] diff --git a/src/huggingface_hub/inference/_mcp/constants.py b/src/huggingface_hub/inference/_mcp/constants.py index 1ccade43b1..737a9ae549 100644 --- a/src/huggingface_hub/inference/_mcp/constants.py +++ b/src/huggingface_hub/inference/_mcp/constants.py @@ -2,7 +2,6 @@ import sys from pathlib import Path -from typing import List from huggingface_hub import ChatCompletionInputTool @@ -76,7 +75,7 @@ } ) -EXIT_LOOP_TOOLS: List[ChatCompletionInputTool] = [TASK_COMPLETE_TOOL, ASK_QUESTION_TOOL] +EXIT_LOOP_TOOLS: list[ChatCompletionInputTool] = [TASK_COMPLETE_TOOL, ASK_QUESTION_TOOL] DEFAULT_REPO_ID = "tiny-agents/tiny-agents" diff --git a/src/huggingface_hub/inference/_mcp/mcp_client.py b/src/huggingface_hub/inference/_mcp/mcp_client.py index 2383303450..51368bce1f 100644 --- a/src/huggingface_hub/inference/_mcp/mcp_client.py +++ b/src/huggingface_hub/inference/_mcp/mcp_client.py @@ -3,9 +3,9 @@ from contextlib import AsyncExitStack from datetime import timedelta from pathlib import Path -from typing import TYPE_CHECKING, Any, AsyncIterable, Dict, List, Literal, Optional, Union, overload +from typing import TYPE_CHECKING, Any, AsyncIterable, Literal, Optional, TypedDict, Union, overload -from typing_extensions import NotRequired, TypeAlias, TypedDict, Unpack +from typing_extensions import NotRequired, TypeAlias, Unpack from ...utils._runtime import get_hf_hub_version from .._generated._async_client import AsyncInferenceClient @@ -32,14 +32,14 @@ class StdioServerParameters_T(TypedDict): command: str - args: NotRequired[List[str]] - env: NotRequired[Dict[str, str]] + args: NotRequired[list[str]] + env: NotRequired[dict[str, str]] cwd: NotRequired[Union[str, Path, None]] class SSEServerParameters_T(TypedDict): url: str - headers: NotRequired[Dict[str, Any]] + headers: NotRequired[dict[str, Any]] timeout: NotRequired[float] sse_read_timeout: NotRequired[float] @@ -84,9 +84,9 @@ def __init__( api_key: Optional[str] = None, ): # Initialize MCP sessions as a dictionary of ClientSession objects - self.sessions: Dict[ToolName, "ClientSession"] = {} + self.sessions: dict[ToolName, "ClientSession"] = {} self.exit_stack = AsyncExitStack() - self.available_tools: List[ChatCompletionInputTool] = [] + self.available_tools: list[ChatCompletionInputTool] = [] # To be able to send the model in the payload if `base_url` is provided if model is None and base_url is None: raise ValueError("At least one of `model` or `base_url` should be set in `MCPClient`.") @@ -132,27 +132,27 @@ async def add_mcp_server(self, type: ServerType, **params: Any): - "stdio": Standard input/output server (local) - "sse": Server-sent events (SSE) server - "http": StreamableHTTP server - **params (`Dict[str, Any]`): + **params (`dict[str, Any]`): Server parameters that can be either: - For stdio servers: - command (str): The command to run the MCP server - - args (List[str], optional): Arguments for the command - - env (Dict[str, str], optional): Environment variables for the command + - args (list[str], optional): Arguments for the command + - env (dict[str, str], optional): Environment variables for the command - cwd (Union[str, Path, None], optional): Working directory for the command - - allowed_tools (List[str], optional): List of tool names to allow from this server + - allowed_tools (list[str], optional): List of tool names to allow from this server - For SSE servers: - url (str): The URL of the SSE server - - headers (Dict[str, Any], optional): Headers for the SSE connection + - headers (dict[str, Any], optional): Headers for the SSE connection - timeout (float, optional): Connection timeout - sse_read_timeout (float, optional): SSE read timeout - - allowed_tools (List[str], optional): List of tool names to allow from this server + - allowed_tools (list[str], optional): List of tool names to allow from this server - For StreamableHTTP servers: - url (str): The URL of the StreamableHTTP server - - headers (Dict[str, Any], optional): Headers for the StreamableHTTP connection + - headers (dict[str, Any], optional): Headers for the StreamableHTTP connection - timeout (timedelta, optional): Connection timeout - sse_read_timeout (timedelta, optional): SSE read timeout - terminate_on_close (bool, optional): Whether to terminate on close - - allowed_tools (List[str], optional): List of tool names to allow from this server + - allowed_tools (list[str], optional): List of tool names to allow from this server """ from mcp import ClientSession, StdioServerParameters from mcp import types as mcp_types @@ -249,16 +249,16 @@ async def add_mcp_server(self, type: ServerType, **params: Any): async def process_single_turn_with_tools( self, - messages: List[Union[Dict, ChatCompletionInputMessage]], - exit_loop_tools: Optional[List[ChatCompletionInputTool]] = None, + messages: list[Union[dict, ChatCompletionInputMessage]], + exit_loop_tools: Optional[list[ChatCompletionInputTool]] = None, exit_if_first_chunk_no_tool: bool = False, ) -> AsyncIterable[Union[ChatCompletionStreamOutput, ChatCompletionInputMessage]]: """Process a query using `self.model` and available tools, yielding chunks and tool outputs. Args: - messages (`List[Dict]`): + messages (`list[dict]`): List of message objects representing the conversation history - exit_loop_tools (`List[ChatCompletionInputTool]`, *optional*): + exit_loop_tools (`list[ChatCompletionInputTool]`, *optional*): List of tools that should exit the generator when called exit_if_first_chunk_no_tool (`bool`, *optional*): Exit if no tool is present in the first chunks. Default to False. @@ -280,8 +280,8 @@ async def process_single_turn_with_tools( stream=True, ) - message: Dict[str, Any] = {"role": "unknown", "content": ""} - final_tool_calls: Dict[int, ChatCompletionStreamOutputDeltaToolCall] = {} + message: dict[str, Any] = {"role": "unknown", "content": ""} + final_tool_calls: dict[int, ChatCompletionStreamOutputDeltaToolCall] = {} num_of_chunks = 0 # Read from stream @@ -328,7 +328,7 @@ async def process_single_turn_with_tools( message["role"] = "assistant" # Convert final_tool_calls to the format expected by OpenAI if final_tool_calls: - tool_calls_list: List[Dict[str, Any]] = [] + tool_calls_list: list[dict[str, Any]] = [] for tc in final_tool_calls.values(): tool_calls_list.append( { @@ -346,6 +346,17 @@ async def process_single_turn_with_tools( # Process tool calls one by one for tool_call in final_tool_calls.values(): function_name = tool_call.function.name + if function_name is None: + message = ChatCompletionInputMessage.parse_obj_as_instance( + { + "role": "tool", + "tool_call_id": tool_call.id, + "content": "Invalid tool call with no function name.", + } + ) + messages.append(message) + yield message + continue # move to next tool call try: function_args = json.loads(tool_call.function.arguments or "{}") except json.JSONDecodeError as err: diff --git a/src/huggingface_hub/inference/_mcp/types.py b/src/huggingface_hub/inference/_mcp/types.py index 100f67832e..a531929a8e 100644 --- a/src/huggingface_hub/inference/_mcp/types.py +++ b/src/huggingface_hub/inference/_mcp/types.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Literal, TypedDict, Union +from typing import Literal, TypedDict, Union from typing_extensions import NotRequired @@ -13,24 +13,24 @@ class InputConfig(TypedDict, total=False): class StdioServerConfig(TypedDict): type: Literal["stdio"] command: str - args: List[str] - env: Dict[str, str] + args: list[str] + env: dict[str, str] cwd: str - allowed_tools: NotRequired[List[str]] + allowed_tools: NotRequired[list[str]] class HTTPServerConfig(TypedDict): type: Literal["http"] url: str - headers: Dict[str, str] - allowed_tools: NotRequired[List[str]] + headers: dict[str, str] + allowed_tools: NotRequired[list[str]] class SSEServerConfig(TypedDict): type: Literal["sse"] url: str - headers: Dict[str, str] - allowed_tools: NotRequired[List[str]] + headers: dict[str, str] + allowed_tools: NotRequired[list[str]] ServerConfig = Union[StdioServerConfig, HTTPServerConfig, SSEServerConfig] @@ -41,5 +41,5 @@ class AgentConfig(TypedDict): model: str provider: str apiKey: NotRequired[str] - inputs: List[InputConfig] - servers: List[ServerConfig] + inputs: list[InputConfig] + servers: list[ServerConfig] diff --git a/src/huggingface_hub/inference/_mcp/utils.py b/src/huggingface_hub/inference/_mcp/utils.py index ddab10d677..09c902815b 100644 --- a/src/huggingface_hub/inference/_mcp/utils.py +++ b/src/huggingface_hub/inference/_mcp/utils.py @@ -6,7 +6,7 @@ import json from pathlib import Path -from typing import TYPE_CHECKING, List, Optional, Tuple +from typing import TYPE_CHECKING, Optional from huggingface_hub import snapshot_download from huggingface_hub.errors import EntryNotFoundError @@ -36,7 +36,7 @@ def format_result(result: "mcp_types.CallToolResult") -> str: if len(content) == 0: return "[No content]" - formatted_parts: List[str] = [] + formatted_parts: list[str] = [] for item in content: if item.type == "text": @@ -84,10 +84,10 @@ def _get_base64_size(base64_str: str) -> int: return (len(base64_str) * 3) // 4 - padding -def _load_agent_config(agent_path: Optional[str]) -> Tuple[AgentConfig, Optional[str]]: +def _load_agent_config(agent_path: Optional[str]) -> tuple[AgentConfig, Optional[str]]: """Load server config and prompt.""" - def _read_dir(directory: Path) -> Tuple[AgentConfig, Optional[str]]: + def _read_dir(directory: Path) -> tuple[AgentConfig, Optional[str]]: cfg_file = directory / FILENAME_CONFIG if not cfg_file.exists(): raise FileNotFoundError(f" Config file not found in {directory}! Please make sure it exists locally") diff --git a/src/huggingface_hub/inference/_providers/__init__.py b/src/huggingface_hub/inference/_providers/__init__.py index ec4866c30d..2e40f32d19 100644 --- a/src/huggingface_hub/inference/_providers/__init__.py +++ b/src/huggingface_hub/inference/_providers/__init__.py @@ -1,4 +1,4 @@ -from typing import Dict, Literal, Optional, Union +from typing import Literal, Optional, Union from huggingface_hub.inference._providers.featherless_ai import ( FeatherlessConversationalTask, @@ -65,7 +65,7 @@ PROVIDER_OR_POLICY_T = Union[PROVIDER_T, Literal["auto"]] -PROVIDERS: Dict[PROVIDER_T, Dict[str, TaskProviderHelper]] = { +PROVIDERS: dict[PROVIDER_T, dict[str, TaskProviderHelper]] = { "black-forest-labs": { "text-to-image": BlackForestLabsTextToImageTask(), }, diff --git a/src/huggingface_hub/inference/_providers/_common.py b/src/huggingface_hub/inference/_providers/_common.py index 687464a934..e7f3eb1d96 100644 --- a/src/huggingface_hub/inference/_providers/_common.py +++ b/src/huggingface_hub/inference/_providers/_common.py @@ -1,5 +1,5 @@ from functools import lru_cache -from typing import Any, Dict, List, Optional, Union, overload +from typing import Any, Optional, Union, overload from huggingface_hub import constants from huggingface_hub.hf_api import InferenceProviderMapping @@ -14,7 +14,7 @@ # Dev purposes only. # If you want to try to run inference for a new model locally before it's registered on huggingface.co # for a given Inference Provider, you can add it to the following dictionary. -HARDCODED_MODEL_INFERENCE_MAPPING: Dict[str, Dict[str, InferenceProviderMapping]] = { +HARDCODED_MODEL_INFERENCE_MAPPING: dict[str, dict[str, InferenceProviderMapping]] = { # "HF model ID" => InferenceProviderMapping object initialized with "Model ID on Inference Provider's side" # # Example: @@ -38,14 +38,14 @@ @overload -def filter_none(obj: Dict[str, Any]) -> Dict[str, Any]: ... +def filter_none(obj: dict[str, Any]) -> dict[str, Any]: ... @overload -def filter_none(obj: List[Any]) -> List[Any]: ... +def filter_none(obj: list[Any]) -> list[Any]: ... -def filter_none(obj: Union[Dict[str, Any], List[Any]]) -> Union[Dict[str, Any], List[Any]]: +def filter_none(obj: Union[dict[str, Any], list[Any]]) -> Union[dict[str, Any], list[Any]]: if isinstance(obj, dict): - cleaned: Dict[str, Any] = {} + cleaned: dict[str, Any] = {} for k, v in obj.items(): if v is None: continue @@ -72,11 +72,11 @@ def prepare_request( self, *, inputs: Any, - parameters: Dict[str, Any], - headers: Dict, + parameters: dict[str, Any], + headers: dict, model: Optional[str], api_key: Optional[str], - extra_payload: Optional[Dict[str, Any]] = None, + extra_payload: Optional[dict[str, Any]] = None, ) -> RequestParameters: """ Prepare the request to be sent to the provider. @@ -123,7 +123,7 @@ def prepare_request( def get_response( self, - response: Union[bytes, Dict], + response: Union[bytes, dict], request_params: Optional[RequestParameters] = None, ) -> Any: """ @@ -182,8 +182,8 @@ def _prepare_mapping_info(self, model: Optional[str]) -> InferenceProviderMappin return provider_mapping def _normalize_headers( - self, headers: Dict[str, Any], payload: Optional[Dict[str, Any]], data: Optional[MimeBytes] - ) -> Dict[str, Any]: + self, headers: dict[str, Any], payload: Optional[dict[str, Any]], data: Optional[MimeBytes] + ) -> dict[str, Any]: """Normalize the headers to use for the request. Override this method in subclasses for customized headers. @@ -196,7 +196,7 @@ def _normalize_headers( normalized_headers["content-type"] = "application/json" return normalized_headers - def _prepare_headers(self, headers: Dict, api_key: str) -> Dict[str, Any]: + def _prepare_headers(self, headers: dict, api_key: str) -> dict[str, Any]: """Return the headers to use for the request. Override this method in subclasses for customized headers. @@ -231,8 +231,8 @@ def _prepare_route(self, mapped_model: str, api_key: str) -> str: return "" def _prepare_payload_as_dict( - self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping - ) -> Optional[Dict]: + self, inputs: Any, parameters: dict, provider_mapping_info: InferenceProviderMapping + ) -> Optional[dict]: """Return the payload to use for the request, as a dict. Override this method in subclasses for customized payloads. @@ -243,9 +243,9 @@ def _prepare_payload_as_dict( def _prepare_payload_as_bytes( self, inputs: Any, - parameters: Dict, + parameters: dict, provider_mapping_info: InferenceProviderMapping, - extra_payload: Optional[Dict], + extra_payload: Optional[dict], ) -> Optional[MimeBytes]: """Return the body to use for the request, as bytes. @@ -269,10 +269,10 @@ def _prepare_route(self, mapped_model: str, api_key: str) -> str: def _prepare_payload_as_dict( self, - inputs: List[Union[Dict, ChatCompletionInputMessage]], - parameters: Dict, + inputs: list[Union[dict, ChatCompletionInputMessage]], + parameters: dict, provider_mapping_info: InferenceProviderMapping, - ) -> Optional[Dict]: + ) -> Optional[dict]: return filter_none({"messages": inputs, **parameters, "model": provider_mapping_info.provider_id}) @@ -289,13 +289,13 @@ def _prepare_route(self, mapped_model: str, api_key: str) -> str: return "/v1/completions" def _prepare_payload_as_dict( - self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping - ) -> Optional[Dict]: + self, inputs: Any, parameters: dict, provider_mapping_info: InferenceProviderMapping + ) -> Optional[dict]: return filter_none({"prompt": inputs, **parameters, "model": provider_mapping_info.provider_id}) @lru_cache(maxsize=None) -def _fetch_inference_provider_mapping(model: str) -> List["InferenceProviderMapping"]: +def _fetch_inference_provider_mapping(model: str) -> list["InferenceProviderMapping"]: """ Fetch provider mappings for a model from the Hub. """ @@ -308,7 +308,7 @@ def _fetch_inference_provider_mapping(model: str) -> List["InferenceProviderMapp return provider_mapping -def recursive_merge(dict1: Dict, dict2: Dict) -> Dict: +def recursive_merge(dict1: dict, dict2: dict) -> dict: return { **dict1, **{ diff --git a/src/huggingface_hub/inference/_providers/black_forest_labs.py b/src/huggingface_hub/inference/_providers/black_forest_labs.py index a5d9683225..1d91b0b842 100644 --- a/src/huggingface_hub/inference/_providers/black_forest_labs.py +++ b/src/huggingface_hub/inference/_providers/black_forest_labs.py @@ -1,5 +1,5 @@ import time -from typing import Any, Dict, Optional, Union +from typing import Any, Optional, Union from huggingface_hub.hf_api import InferenceProviderMapping from huggingface_hub.inference._common import RequestParameters, _as_dict @@ -18,7 +18,7 @@ class BlackForestLabsTextToImageTask(TaskProviderHelper): def __init__(self): super().__init__(provider="black-forest-labs", base_url="https://api.us1.bfl.ai", task="text-to-image") - def _prepare_headers(self, headers: Dict, api_key: str) -> Dict[str, Any]: + def _prepare_headers(self, headers: dict, api_key: str) -> dict[str, Any]: headers = super()._prepare_headers(headers, api_key) if not api_key.startswith("hf_"): _ = headers.pop("authorization") @@ -29,8 +29,8 @@ def _prepare_route(self, mapped_model: str, api_key: str) -> str: return f"/v1/{mapped_model}" def _prepare_payload_as_dict( - self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping - ) -> Optional[Dict]: + self, inputs: Any, parameters: dict, provider_mapping_info: InferenceProviderMapping + ) -> Optional[dict]: parameters = filter_none(parameters) if "num_inference_steps" in parameters: parameters["steps"] = parameters.pop("num_inference_steps") @@ -39,7 +39,7 @@ def _prepare_payload_as_dict( return {"prompt": inputs, **parameters} - def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any: + def get_response(self, response: Union[bytes, dict], request_params: Optional[RequestParameters] = None) -> Any: """ Polling mechanism for Black Forest Labs since the API is asynchronous. """ @@ -50,7 +50,7 @@ def get_response(self, response: Union[bytes, Dict], request_params: Optional[Re response = session.get(url, headers={"Content-Type": "application/json"}) # type: ignore response.raise_for_status() # type: ignore - response_json: Dict = response.json() # type: ignore + response_json: dict = response.json() # type: ignore status = response_json.get("status") logger.info( f"Polling generation result from {url}. Current status: {status}. " diff --git a/src/huggingface_hub/inference/_providers/cohere.py b/src/huggingface_hub/inference/_providers/cohere.py index a5e9191cae..0190d5449b 100644 --- a/src/huggingface_hub/inference/_providers/cohere.py +++ b/src/huggingface_hub/inference/_providers/cohere.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional +from typing import Any, Optional from huggingface_hub.hf_api import InferenceProviderMapping @@ -17,8 +17,8 @@ def _prepare_route(self, mapped_model: str, api_key: str) -> str: return "/compatibility/v1/chat/completions" def _prepare_payload_as_dict( - self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping - ) -> Optional[Dict]: + self, inputs: Any, parameters: dict, provider_mapping_info: InferenceProviderMapping + ) -> Optional[dict]: payload = super()._prepare_payload_as_dict(inputs, parameters, provider_mapping_info) response_format = parameters.get("response_format") if isinstance(response_format, dict) and response_format.get("type") == "json_schema": diff --git a/src/huggingface_hub/inference/_providers/fal_ai.py b/src/huggingface_hub/inference/_providers/fal_ai.py index b39b33f616..17cb0168b7 100644 --- a/src/huggingface_hub/inference/_providers/fal_ai.py +++ b/src/huggingface_hub/inference/_providers/fal_ai.py @@ -1,7 +1,7 @@ import base64 import time from abc import ABC -from typing import Any, Dict, Optional, Union +from typing import Any, Optional, Union from urllib.parse import urlparse from huggingface_hub import constants @@ -22,7 +22,7 @@ class FalAITask(TaskProviderHelper, ABC): def __init__(self, task: str): super().__init__(provider="fal-ai", base_url="https://fal.run", task=task) - def _prepare_headers(self, headers: Dict, api_key: str) -> Dict[str, Any]: + def _prepare_headers(self, headers: dict, api_key: str) -> dict[str, Any]: headers = super()._prepare_headers(headers, api_key) if not api_key.startswith("hf_"): headers["authorization"] = f"Key {api_key}" @@ -36,7 +36,7 @@ class FalAIQueueTask(TaskProviderHelper, ABC): def __init__(self, task: str): super().__init__(provider="fal-ai", base_url="https://queue.fal.run", task=task) - def _prepare_headers(self, headers: Dict, api_key: str) -> Dict[str, Any]: + def _prepare_headers(self, headers: dict, api_key: str) -> dict[str, Any]: headers = super()._prepare_headers(headers, api_key) if not api_key.startswith("hf_"): headers["authorization"] = f"Key {api_key}" @@ -50,7 +50,7 @@ def _prepare_route(self, mapped_model: str, api_key: str) -> str: def get_response( self, - response: Union[bytes, Dict], + response: Union[bytes, dict], request_params: Optional[RequestParameters] = None, ) -> Any: response_dict = _as_dict(response) @@ -91,8 +91,8 @@ def __init__(self): super().__init__("automatic-speech-recognition") def _prepare_payload_as_dict( - self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping - ) -> Optional[Dict]: + self, inputs: Any, parameters: dict, provider_mapping_info: InferenceProviderMapping + ) -> Optional[dict]: if isinstance(inputs, str) and inputs.startswith(("http://", "https://")): # If input is a URL, pass it directly audio_url = inputs @@ -108,7 +108,7 @@ def _prepare_payload_as_dict( return {"audio_url": audio_url, **filter_none(parameters)} - def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any: + def get_response(self, response: Union[bytes, dict], request_params: Optional[RequestParameters] = None) -> Any: text = _as_dict(response)["text"] if not isinstance(text, str): raise ValueError(f"Unexpected output format from FalAI API. Expected string, got {type(text)}.") @@ -120,9 +120,9 @@ def __init__(self): super().__init__("text-to-image") def _prepare_payload_as_dict( - self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping - ) -> Optional[Dict]: - payload: Dict[str, Any] = { + self, inputs: Any, parameters: dict, provider_mapping_info: InferenceProviderMapping + ) -> Optional[dict]: + payload: dict[str, Any] = { "prompt": inputs, **filter_none(parameters), } @@ -145,7 +145,7 @@ def _prepare_payload_as_dict( return payload - def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any: + def get_response(self, response: Union[bytes, dict], request_params: Optional[RequestParameters] = None) -> Any: url = _as_dict(response)["images"][0]["url"] return get_session().get(url).content @@ -155,11 +155,11 @@ def __init__(self): super().__init__("text-to-speech") def _prepare_payload_as_dict( - self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping - ) -> Optional[Dict]: + self, inputs: Any, parameters: dict, provider_mapping_info: InferenceProviderMapping + ) -> Optional[dict]: return {"text": inputs, **filter_none(parameters)} - def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any: + def get_response(self, response: Union[bytes, dict], request_params: Optional[RequestParameters] = None) -> Any: url = _as_dict(response)["audio"]["url"] return get_session().get(url).content @@ -169,13 +169,13 @@ def __init__(self): super().__init__("text-to-video") def _prepare_payload_as_dict( - self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping - ) -> Optional[Dict]: + self, inputs: Any, parameters: dict, provider_mapping_info: InferenceProviderMapping + ) -> Optional[dict]: return {"prompt": inputs, **filter_none(parameters)} def get_response( self, - response: Union[bytes, Dict], + response: Union[bytes, dict], request_params: Optional[RequestParameters] = None, ) -> Any: output = super().get_response(response, request_params) @@ -188,10 +188,10 @@ def __init__(self): super().__init__("image-to-image") def _prepare_payload_as_dict( - self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping - ) -> Optional[Dict]: + self, inputs: Any, parameters: dict, provider_mapping_info: InferenceProviderMapping + ) -> Optional[dict]: image_url = _as_url(inputs, default_mime_type="image/jpeg") - payload: Dict[str, Any] = { + payload: dict[str, Any] = { "image_url": image_url, **filter_none(parameters), } @@ -207,7 +207,7 @@ def _prepare_payload_as_dict( def get_response( self, - response: Union[bytes, Dict], + response: Union[bytes, dict], request_params: Optional[RequestParameters] = None, ) -> Any: output = super().get_response(response, request_params) @@ -220,10 +220,10 @@ def __init__(self): super().__init__("image-to-video") def _prepare_payload_as_dict( - self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping - ) -> Optional[Dict]: + self, inputs: Any, parameters: dict, provider_mapping_info: InferenceProviderMapping + ) -> Optional[dict]: image_url = _as_url(inputs, default_mime_type="image/jpeg") - payload: Dict[str, Any] = { + payload: dict[str, Any] = { "image_url": image_url, **filter_none(parameters), } @@ -238,7 +238,7 @@ def _prepare_payload_as_dict( def get_response( self, - response: Union[bytes, Dict], + response: Union[bytes, dict], request_params: Optional[RequestParameters] = None, ) -> Any: output = super().get_response(response, request_params) diff --git a/src/huggingface_hub/inference/_providers/featherless_ai.py b/src/huggingface_hub/inference/_providers/featherless_ai.py index 6ad1c48134..ab119636c0 100644 --- a/src/huggingface_hub/inference/_providers/featherless_ai.py +++ b/src/huggingface_hub/inference/_providers/featherless_ai.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional, Union +from typing import Any, Optional, Union from huggingface_hub.hf_api import InferenceProviderMapping from huggingface_hub.inference._common import RequestParameters, _as_dict @@ -15,14 +15,14 @@ def __init__(self): super().__init__(provider=_PROVIDER, base_url=_BASE_URL) def _prepare_payload_as_dict( - self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping - ) -> Optional[Dict]: + self, inputs: Any, parameters: dict, provider_mapping_info: InferenceProviderMapping + ) -> Optional[dict]: params = filter_none(parameters.copy()) params["max_tokens"] = params.pop("max_new_tokens", None) return {"prompt": inputs, **params, "model": provider_mapping_info.provider_id} - def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any: + def get_response(self, response: Union[bytes, dict], request_params: Optional[RequestParameters] = None) -> Any: output = _as_dict(response)["choices"][0] return { "generated_text": output["text"], diff --git a/src/huggingface_hub/inference/_providers/fireworks_ai.py b/src/huggingface_hub/inference/_providers/fireworks_ai.py index b4cc19a570..d76c58478b 100644 --- a/src/huggingface_hub/inference/_providers/fireworks_ai.py +++ b/src/huggingface_hub/inference/_providers/fireworks_ai.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional +from typing import Any, Optional from huggingface_hub.hf_api import InferenceProviderMapping @@ -13,8 +13,8 @@ def _prepare_route(self, mapped_model: str, api_key: str) -> str: return "/inference/v1/chat/completions" def _prepare_payload_as_dict( - self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping - ) -> Optional[Dict]: + self, inputs: Any, parameters: dict, provider_mapping_info: InferenceProviderMapping + ) -> Optional[dict]: payload = super()._prepare_payload_as_dict(inputs, parameters, provider_mapping_info) response_format = parameters.get("response_format") if isinstance(response_format, dict) and response_format.get("type") == "json_schema": diff --git a/src/huggingface_hub/inference/_providers/hf_inference.py b/src/huggingface_hub/inference/_providers/hf_inference.py index d90d00c4f3..dddfaaea85 100644 --- a/src/huggingface_hub/inference/_providers/hf_inference.py +++ b/src/huggingface_hub/inference/_providers/hf_inference.py @@ -1,7 +1,7 @@ import json from functools import lru_cache from pathlib import Path -from typing import Any, Dict, Optional, Union +from typing import Any, Optional, Union from urllib.parse import urlparse, urlunparse from huggingface_hub import constants @@ -60,8 +60,8 @@ def _prepare_url(self, api_key: str, mapped_model: str) -> str: ) def _prepare_payload_as_dict( - self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping - ) -> Optional[Dict]: + self, inputs: Any, parameters: dict, provider_mapping_info: InferenceProviderMapping + ) -> Optional[dict]: if isinstance(inputs, bytes): raise ValueError(f"Unexpected binary input for task {self.task}.") if isinstance(inputs, Path): @@ -71,16 +71,16 @@ def _prepare_payload_as_dict( class HFInferenceBinaryInputTask(HFInferenceTask): def _prepare_payload_as_dict( - self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping - ) -> Optional[Dict]: + self, inputs: Any, parameters: dict, provider_mapping_info: InferenceProviderMapping + ) -> Optional[dict]: return None def _prepare_payload_as_bytes( self, inputs: Any, - parameters: Dict, + parameters: dict, provider_mapping_info: InferenceProviderMapping, - extra_payload: Optional[Dict], + extra_payload: Optional[dict], ) -> Optional[MimeBytes]: parameters = filter_none(parameters) extra_payload = extra_payload or {} @@ -106,8 +106,8 @@ def __init__(self): super().__init__("conversational") def _prepare_payload_as_dict( - self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping - ) -> Optional[Dict]: + self, inputs: Any, parameters: dict, provider_mapping_info: InferenceProviderMapping + ) -> Optional[dict]: payload = filter_none(parameters) mapped_model = provider_mapping_info.provider_id payload_model = parameters.get("model") or mapped_model @@ -156,7 +156,7 @@ def _build_chat_completion_url(model_url: str) -> str: @lru_cache(maxsize=1) -def _fetch_recommended_models() -> Dict[str, Optional[str]]: +def _fetch_recommended_models() -> dict[str, Optional[str]]: response = get_session().get(f"{constants.ENDPOINT}/api/tasks", headers=build_hf_headers()) hf_raise_for_status(response) return {task: next(iter(details["widgetModels"]), None) for task, details in response.json().items()} @@ -211,8 +211,8 @@ def __init__(self): super().__init__("feature-extraction") def _prepare_payload_as_dict( - self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping - ) -> Optional[Dict]: + self, inputs: Any, parameters: dict, provider_mapping_info: InferenceProviderMapping + ) -> Optional[dict]: if isinstance(inputs, bytes): raise ValueError(f"Unexpected binary input for task {self.task}.") if isinstance(inputs, Path): @@ -222,7 +222,7 @@ def _prepare_payload_as_dict( # See specs: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/src/tasks/feature-extraction/spec/input.json return {"inputs": inputs, **filter_none(parameters)} - def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any: + def get_response(self, response: Union[bytes, dict], request_params: Optional[RequestParameters] = None) -> Any: if isinstance(response, bytes): return _bytes_to_dict(response) return response diff --git a/src/huggingface_hub/inference/_providers/hyperbolic.py b/src/huggingface_hub/inference/_providers/hyperbolic.py index 6dcb14cc27..af512b1624 100644 --- a/src/huggingface_hub/inference/_providers/hyperbolic.py +++ b/src/huggingface_hub/inference/_providers/hyperbolic.py @@ -1,5 +1,5 @@ import base64 -from typing import Any, Dict, Optional, Union +from typing import Any, Optional, Union from huggingface_hub.hf_api import InferenceProviderMapping from huggingface_hub.inference._common import RequestParameters, _as_dict @@ -14,8 +14,8 @@ def _prepare_route(self, mapped_model: str, api_key: str) -> str: return "/v1/images/generations" def _prepare_payload_as_dict( - self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping - ) -> Optional[Dict]: + self, inputs: Any, parameters: dict, provider_mapping_info: InferenceProviderMapping + ) -> Optional[dict]: mapped_model = provider_mapping_info.provider_id parameters = filter_none(parameters) if "num_inference_steps" in parameters: @@ -29,7 +29,7 @@ def _prepare_payload_as_dict( parameters["height"] = 512 return {"prompt": inputs, "model_name": mapped_model, **parameters} - def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any: + def get_response(self, response: Union[bytes, dict], request_params: Optional[RequestParameters] = None) -> Any: response_dict = _as_dict(response) return base64.b64decode(response_dict["images"][0]["image"]) diff --git a/src/huggingface_hub/inference/_providers/nebius.py b/src/huggingface_hub/inference/_providers/nebius.py index 85ad67c4c8..6731855049 100644 --- a/src/huggingface_hub/inference/_providers/nebius.py +++ b/src/huggingface_hub/inference/_providers/nebius.py @@ -1,5 +1,5 @@ import base64 -from typing import Any, Dict, Optional, Union +from typing import Any, Optional, Union from huggingface_hub.hf_api import InferenceProviderMapping from huggingface_hub.inference._common import RequestParameters, _as_dict @@ -15,7 +15,7 @@ class NebiusTextGenerationTask(BaseTextGenerationTask): def __init__(self): super().__init__(provider="nebius", base_url="https://api.studio.nebius.ai") - def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any: + def get_response(self, response: Union[bytes, dict], request_params: Optional[RequestParameters] = None) -> Any: output = _as_dict(response)["choices"][0] return { "generated_text": output["text"], @@ -31,8 +31,8 @@ def __init__(self): super().__init__(provider="nebius", base_url="https://api.studio.nebius.ai") def _prepare_payload_as_dict( - self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping - ) -> Optional[Dict]: + self, inputs: Any, parameters: dict, provider_mapping_info: InferenceProviderMapping + ) -> Optional[dict]: payload = super()._prepare_payload_as_dict(inputs, parameters, provider_mapping_info) response_format = parameters.get("response_format") if isinstance(response_format, dict) and response_format.get("type") == "json_schema": @@ -50,8 +50,8 @@ def _prepare_route(self, mapped_model: str, api_key: str) -> str: return "/v1/images/generations" def _prepare_payload_as_dict( - self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping - ) -> Optional[Dict]: + self, inputs: Any, parameters: dict, provider_mapping_info: InferenceProviderMapping + ) -> Optional[dict]: mapped_model = provider_mapping_info.provider_id parameters = filter_none(parameters) if "guidance_scale" in parameters: @@ -61,7 +61,7 @@ def _prepare_payload_as_dict( return {"prompt": inputs, **parameters, "model": mapped_model} - def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any: + def get_response(self, response: Union[bytes, dict], request_params: Optional[RequestParameters] = None) -> Any: response_dict = _as_dict(response) return base64.b64decode(response_dict["data"][0]["b64_json"]) @@ -74,10 +74,10 @@ def _prepare_route(self, mapped_model: str, api_key: str) -> str: return "/v1/embeddings" def _prepare_payload_as_dict( - self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping - ) -> Optional[Dict]: + self, inputs: Any, parameters: dict, provider_mapping_info: InferenceProviderMapping + ) -> Optional[dict]: return {"input": inputs, "model": provider_mapping_info.provider_id} - def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any: + def get_response(self, response: Union[bytes, dict], request_params: Optional[RequestParameters] = None) -> Any: embeddings = _as_dict(response)["data"] return [embedding["embedding"] for embedding in embeddings] diff --git a/src/huggingface_hub/inference/_providers/new_provider.md b/src/huggingface_hub/inference/_providers/new_provider.md index 923463284a..4a488df6bb 100644 --- a/src/huggingface_hub/inference/_providers/new_provider.md +++ b/src/huggingface_hub/inference/_providers/new_provider.md @@ -13,7 +13,7 @@ If the provider supports multiple tasks that require different implementations, For `text-generation` and `conversational` tasks, one can just inherit from `BaseTextGenerationTask` and `BaseConversationalTask` respectively (defined in `_common.py`) and override the methods if needed. Examples can be found in `fireworks_ai.py` and `together.py`. ```py -from typing import Any, Dict, Optional, Union +from typing import Any, Optional, Union from ._common import TaskProviderHelper, MimeBytes @@ -25,7 +25,7 @@ class MyNewProviderTaskProviderHelper(TaskProviderHelper): def get_response( self, - response: Union[bytes, Dict], + response: Union[bytes, dict], request_params: Optional[RequestParameters] = None, ) -> Any: """ @@ -34,7 +34,7 @@ class MyNewProviderTaskProviderHelper(TaskProviderHelper): Override this method in subclasses for customized response handling.""" return super().get_response(response) - def _prepare_headers(self, headers: Dict, api_key: str) -> Dict[str, Any]: + def _prepare_headers(self, headers: dict, api_key: str) -> dict[str, Any]: """Return the headers to use for the request. Override this method in subclasses for customized headers. @@ -48,7 +48,7 @@ class MyNewProviderTaskProviderHelper(TaskProviderHelper): """ return super()._prepare_route(mapped_model) - def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]: + def _prepare_payload_as_dict(self, inputs: Any, parameters: dict, mapped_model: str) -> Optional[dict]: """Return the payload to use for the request, as a dict. Override this method in subclasses for customized payloads. @@ -57,7 +57,7 @@ class MyNewProviderTaskProviderHelper(TaskProviderHelper): return super()._prepare_payload_as_dict(inputs, parameters, mapped_model) def _prepare_payload_as_bytes( - self, inputs: Any, parameters: Dict, mapped_model: str, extra_payload: Optional[Dict] + self, inputs: Any, parameters: dict, mapped_model: str, extra_payload: Optional[dict] ) -> Optional[MimeBytes]: """Return the body to use for the request, as bytes. diff --git a/src/huggingface_hub/inference/_providers/novita.py b/src/huggingface_hub/inference/_providers/novita.py index 44adc9017b..301d7a589d 100644 --- a/src/huggingface_hub/inference/_providers/novita.py +++ b/src/huggingface_hub/inference/_providers/novita.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional, Union +from typing import Any, Optional, Union from huggingface_hub.hf_api import InferenceProviderMapping from huggingface_hub.inference._common import RequestParameters, _as_dict @@ -23,7 +23,7 @@ def _prepare_route(self, mapped_model: str, api_key: str) -> str: # there is no v1/ route for novita return "/v3/openai/completions" - def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any: + def get_response(self, response: Union[bytes, dict], request_params: Optional[RequestParameters] = None) -> Any: output = _as_dict(response)["choices"][0] return { "generated_text": output["text"], @@ -51,11 +51,11 @@ def _prepare_route(self, mapped_model: str, api_key: str) -> str: return f"/v3/hf/{mapped_model}" def _prepare_payload_as_dict( - self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping - ) -> Optional[Dict]: + self, inputs: Any, parameters: dict, provider_mapping_info: InferenceProviderMapping + ) -> Optional[dict]: return {"prompt": inputs, **filter_none(parameters)} - def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any: + def get_response(self, response: Union[bytes, dict], request_params: Optional[RequestParameters] = None) -> Any: response_dict = _as_dict(response) if not ( isinstance(response_dict, dict) diff --git a/src/huggingface_hub/inference/_providers/nscale.py b/src/huggingface_hub/inference/_providers/nscale.py index ce5b20e354..65b15147a2 100644 --- a/src/huggingface_hub/inference/_providers/nscale.py +++ b/src/huggingface_hub/inference/_providers/nscale.py @@ -1,5 +1,5 @@ import base64 -from typing import Any, Dict, Optional, Union +from typing import Any, Optional, Union from huggingface_hub.hf_api import InferenceProviderMapping from huggingface_hub.inference._common import RequestParameters, _as_dict @@ -20,8 +20,8 @@ def _prepare_route(self, mapped_model: str, api_key: str) -> str: return "/v1/images/generations" def _prepare_payload_as_dict( - self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping - ) -> Optional[Dict]: + self, inputs: Any, parameters: dict, provider_mapping_info: InferenceProviderMapping + ) -> Optional[dict]: mapped_model = provider_mapping_info.provider_id # Combine all parameters except inputs and parameters parameters = filter_none(parameters) @@ -39,6 +39,6 @@ def _prepare_payload_as_dict( } return payload - def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any: + def get_response(self, response: Union[bytes, dict], request_params: Optional[RequestParameters] = None) -> Any: response_dict = _as_dict(response) return base64.b64decode(response_dict["data"][0]["b64_json"]) diff --git a/src/huggingface_hub/inference/_providers/replicate.py b/src/huggingface_hub/inference/_providers/replicate.py index 139582cc80..5a1d1b71f0 100644 --- a/src/huggingface_hub/inference/_providers/replicate.py +++ b/src/huggingface_hub/inference/_providers/replicate.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional, Union +from typing import Any, Optional, Union from huggingface_hub.hf_api import InferenceProviderMapping from huggingface_hub.inference._common import RequestParameters, _as_dict, _as_url @@ -14,7 +14,7 @@ class ReplicateTask(TaskProviderHelper): def __init__(self, task: str): super().__init__(provider=_PROVIDER, base_url=_BASE_URL, task=task) - def _prepare_headers(self, headers: Dict, api_key: str) -> Dict[str, Any]: + def _prepare_headers(self, headers: dict, api_key: str) -> dict[str, Any]: headers = super()._prepare_headers(headers, api_key) headers["Prefer"] = "wait" return headers @@ -25,16 +25,16 @@ def _prepare_route(self, mapped_model: str, api_key: str) -> str: return f"/v1/models/{mapped_model}/predictions" def _prepare_payload_as_dict( - self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping - ) -> Optional[Dict]: + self, inputs: Any, parameters: dict, provider_mapping_info: InferenceProviderMapping + ) -> Optional[dict]: mapped_model = provider_mapping_info.provider_id - payload: Dict[str, Any] = {"input": {"prompt": inputs, **filter_none(parameters)}} + payload: dict[str, Any] = {"input": {"prompt": inputs, **filter_none(parameters)}} if ":" in mapped_model: version = mapped_model.split(":", 1)[1] payload["version"] = version return payload - def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any: + def get_response(self, response: Union[bytes, dict], request_params: Optional[RequestParameters] = None) -> Any: response_dict = _as_dict(response) if response_dict.get("output") is None: raise TimeoutError( @@ -52,9 +52,9 @@ def __init__(self): super().__init__("text-to-image") def _prepare_payload_as_dict( - self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping - ) -> Optional[Dict]: - payload: Dict = super()._prepare_payload_as_dict(inputs, parameters, provider_mapping_info) # type: ignore[assignment] + self, inputs: Any, parameters: dict, provider_mapping_info: InferenceProviderMapping + ) -> Optional[dict]: + payload: dict = super()._prepare_payload_as_dict(inputs, parameters, provider_mapping_info) # type: ignore[assignment] if provider_mapping_info.adapter_weights_path is not None: payload["input"]["lora_weights"] = f"https://huggingface.co/{provider_mapping_info.hf_model_id}" return payload @@ -65,9 +65,9 @@ def __init__(self): super().__init__("text-to-speech") def _prepare_payload_as_dict( - self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping - ) -> Optional[Dict]: - payload: Dict = super()._prepare_payload_as_dict(inputs, parameters, provider_mapping_info) # type: ignore[assignment] + self, inputs: Any, parameters: dict, provider_mapping_info: InferenceProviderMapping + ) -> Optional[dict]: + payload: dict = super()._prepare_payload_as_dict(inputs, parameters, provider_mapping_info) # type: ignore[assignment] payload["input"]["text"] = payload["input"].pop("prompt") # rename "prompt" to "text" for TTS return payload @@ -77,11 +77,11 @@ def __init__(self): super().__init__("image-to-image") def _prepare_payload_as_dict( - self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping - ) -> Optional[Dict]: + self, inputs: Any, parameters: dict, provider_mapping_info: InferenceProviderMapping + ) -> Optional[dict]: image_url = _as_url(inputs, default_mime_type="image/jpeg") - payload: Dict[str, Any] = {"input": {"input_image": image_url, **filter_none(parameters)}} + payload: dict[str, Any] = {"input": {"input_image": image_url, **filter_none(parameters)}} mapped_model = provider_mapping_info.provider_id if ":" in mapped_model: diff --git a/src/huggingface_hub/inference/_providers/sambanova.py b/src/huggingface_hub/inference/_providers/sambanova.py index ed96fb766c..4b7b1ee57b 100644 --- a/src/huggingface_hub/inference/_providers/sambanova.py +++ b/src/huggingface_hub/inference/_providers/sambanova.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional, Union +from typing import Any, Optional, Union from huggingface_hub.hf_api import InferenceProviderMapping from huggingface_hub.inference._common import RequestParameters, _as_dict @@ -10,8 +10,8 @@ def __init__(self): super().__init__(provider="sambanova", base_url="https://api.sambanova.ai") def _prepare_payload_as_dict( - self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping - ) -> Optional[Dict]: + self, inputs: Any, parameters: dict, provider_mapping_info: InferenceProviderMapping + ) -> Optional[dict]: response_format_config = parameters.get("response_format") if isinstance(response_format_config, dict): if response_format_config.get("type") == "json_schema": @@ -32,11 +32,11 @@ def _prepare_route(self, mapped_model: str, api_key: str) -> str: return "/v1/embeddings" def _prepare_payload_as_dict( - self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping - ) -> Optional[Dict]: + self, inputs: Any, parameters: dict, provider_mapping_info: InferenceProviderMapping + ) -> Optional[dict]: parameters = filter_none(parameters) return {"input": inputs, "model": provider_mapping_info.provider_id, **parameters} - def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any: + def get_response(self, response: Union[bytes, dict], request_params: Optional[RequestParameters] = None) -> Any: embeddings = _as_dict(response)["data"] return [embedding["embedding"] for embedding in embeddings] diff --git a/src/huggingface_hub/inference/_providers/together.py b/src/huggingface_hub/inference/_providers/together.py index de166b7baf..338057d438 100644 --- a/src/huggingface_hub/inference/_providers/together.py +++ b/src/huggingface_hub/inference/_providers/together.py @@ -1,6 +1,6 @@ import base64 from abc import ABC -from typing import Any, Dict, Optional, Union +from typing import Any, Optional, Union from huggingface_hub.hf_api import InferenceProviderMapping from huggingface_hub.inference._common import RequestParameters, _as_dict @@ -36,7 +36,7 @@ class TogetherTextGenerationTask(BaseTextGenerationTask): def __init__(self): super().__init__(provider=_PROVIDER, base_url=_BASE_URL) - def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any: + def get_response(self, response: Union[bytes, dict], request_params: Optional[RequestParameters] = None) -> Any: output = _as_dict(response)["choices"][0] return { "generated_text": output["text"], @@ -52,8 +52,8 @@ def __init__(self): super().__init__(provider=_PROVIDER, base_url=_BASE_URL) def _prepare_payload_as_dict( - self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping - ) -> Optional[Dict]: + self, inputs: Any, parameters: dict, provider_mapping_info: InferenceProviderMapping + ) -> Optional[dict]: payload = super()._prepare_payload_as_dict(inputs, parameters, provider_mapping_info) response_format = parameters.get("response_format") if isinstance(response_format, dict) and response_format.get("type") == "json_schema": @@ -72,8 +72,8 @@ def __init__(self): super().__init__("text-to-image") def _prepare_payload_as_dict( - self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping - ) -> Optional[Dict]: + self, inputs: Any, parameters: dict, provider_mapping_info: InferenceProviderMapping + ) -> Optional[dict]: mapped_model = provider_mapping_info.provider_id parameters = filter_none(parameters) if "num_inference_steps" in parameters: @@ -83,6 +83,6 @@ def _prepare_payload_as_dict( return {"prompt": inputs, "response_format": "base64", **parameters, "model": mapped_model} - def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any: + def get_response(self, response: Union[bytes, dict], request_params: Optional[RequestParameters] = None) -> Any: response_dict = _as_dict(response) return base64.b64decode(response_dict["data"][0]["b64_json"]) diff --git a/src/huggingface_hub/inference_api.py b/src/huggingface_hub/inference_api.py index 333fa0e5de..16c2812864 100644 --- a/src/huggingface_hub/inference_api.py +++ b/src/huggingface_hub/inference_api.py @@ -1,5 +1,5 @@ import io -from typing import Any, Dict, List, Optional, Union +from typing import Any, Optional, Union from . import constants from .hf_api import HfApi @@ -157,17 +157,17 @@ def __repr__(self): def __call__( self, - inputs: Optional[Union[str, Dict, List[str], List[List[str]]]] = None, - params: Optional[Dict] = None, + inputs: Optional[Union[str, dict, list[str], list[list[str]]]] = None, + params: Optional[dict] = None, data: Optional[bytes] = None, raw_response: bool = False, ) -> Any: """Make a call to the Inference API. Args: - inputs (`str` or `Dict` or `List[str]` or `List[List[str]]`, *optional*): + inputs (`str` or `dict` or `list[str]` or `list[list[str]]`, *optional*): Inputs for the prediction. - params (`Dict`, *optional*): + params (`dict`, *optional*): Additional parameters for the models. Will be sent as `parameters` in the payload. data (`bytes`, *optional*): @@ -178,7 +178,7 @@ def __call__( (json dictionary or PIL Image for example). """ # Build payload - payload: Dict[str, Any] = { + payload: dict[str, Any] = { "options": self.options, } if inputs: diff --git a/src/huggingface_hub/keras_mixin.py b/src/huggingface_hub/keras_mixin.py index 53290dc858..fa38b5dfba 100644 --- a/src/huggingface_hub/keras_mixin.py +++ b/src/huggingface_hub/keras_mixin.py @@ -5,7 +5,7 @@ from functools import wraps from pathlib import Path from shutil import copytree -from typing import Any, Dict, List, Optional, Union +from typing import Any, Optional, Union from huggingface_hub import ModelHubMixin, snapshot_download from huggingface_hub.utils import ( @@ -157,7 +157,7 @@ def _create_model_card( def save_pretrained_keras( model, save_directory: Union[str, Path], - config: Optional[Dict[str, Any]] = None, + config: Optional[dict[str, Any]] = None, include_optimizer: bool = False, plot_model: bool = True, tags: Optional[Union[list, str]] = None, @@ -276,7 +276,7 @@ def from_pretrained_keras(*args, **kwargs) -> "KerasModelHubMixin": local_files_only(`bool`, *optional*, defaults to `False`): Whether to only look at local files (i.e., do not try to download the model). - model_kwargs (`Dict`, *optional*): + model_kwargs (`dict`, *optional*): model_kwargs will be passed to the model during initialization @@ -302,9 +302,9 @@ def push_to_hub_keras( token: Optional[str] = None, branch: Optional[str] = None, create_pr: Optional[bool] = None, - allow_patterns: Optional[Union[List[str], str]] = None, - ignore_patterns: Optional[Union[List[str], str]] = None, - delete_patterns: Optional[Union[List[str], str]] = None, + allow_patterns: Optional[Union[list[str], str]] = None, + ignore_patterns: Optional[Union[list[str], str]] = None, + delete_patterns: Optional[Union[list[str], str]] = None, log_dir: Optional[str] = None, include_optimizer: bool = False, tags: Optional[Union[list, str]] = None, @@ -344,11 +344,11 @@ def push_to_hub_keras( Defaults to `False`. config (`dict`, *optional*): Configuration object to be saved alongside the model weights. - allow_patterns (`List[str]` or `str`, *optional*): + allow_patterns (`list[str]` or `str`, *optional*): If provided, only files matching at least one pattern are pushed. - ignore_patterns (`List[str]` or `str`, *optional*): + ignore_patterns (`list[str]` or `str`, *optional*): If provided, files matching any of the patterns are not pushed. - delete_patterns (`List[str]` or `str`, *optional*): + delete_patterns (`list[str]` or `str`, *optional*): If provided, remote files matching any of the patterns will be deleted from the repo. log_dir (`str`, *optional*): TensorBoard logging directory to be pushed. The Hub automatically @@ -462,7 +462,7 @@ def _from_pretrained( resume_download, local_files_only, token, - config: Optional[Dict[str, Any]] = None, + config: Optional[dict[str, Any]] = None, **model_kwargs, ): """Here we just call [`from_pretrained_keras`] function so both the mixin and diff --git a/src/huggingface_hub/lfs.py b/src/huggingface_hub/lfs.py index 3ff465f9c0..a626ef28d6 100644 --- a/src/huggingface_hub/lfs.py +++ b/src/huggingface_hub/lfs.py @@ -21,7 +21,7 @@ from math import ceil from os.path import getsize from pathlib import Path -from typing import TYPE_CHECKING, BinaryIO, Dict, Iterable, List, Optional, Tuple, TypedDict +from typing import TYPE_CHECKING, BinaryIO, Iterable, Optional, TypedDict from urllib.parse import unquote from huggingface_hub import constants @@ -106,8 +106,8 @@ def post_lfs_batch_info( repo_id: str, revision: Optional[str] = None, endpoint: Optional[str] = None, - headers: Optional[Dict[str, str]] = None, -) -> Tuple[List[dict], List[dict]]: + headers: Optional[dict[str, str]] = None, +) -> tuple[list[dict], list[dict]]: """ Requests the LFS batch endpoint to retrieve upload instructions @@ -143,7 +143,7 @@ def post_lfs_batch_info( if repo_type in constants.REPO_TYPES_URL_PREFIXES: url_prefix = constants.REPO_TYPES_URL_PREFIXES[repo_type] batch_url = f"{endpoint}/{url_prefix}{repo_id}.git/info/lfs/objects/batch" - payload: Dict = { + payload: dict = { "operation": "upload", "transfers": ["basic", "multipart"], "objects": [ @@ -186,14 +186,14 @@ class CompletionPayloadT(TypedDict): """Payload that will be sent to the Hub when uploading multi-part.""" oid: str - parts: List[PayloadPartT] + parts: list[PayloadPartT] def lfs_upload( operation: "CommitOperationAdd", - lfs_batch_action: Dict, + lfs_batch_action: dict, token: Optional[str] = None, - headers: Optional[Dict[str, str]] = None, + headers: Optional[dict[str, str]] = None, endpoint: Optional[str] = None, ) -> None: """ @@ -317,7 +317,7 @@ def _upload_single_part(operation: "CommitOperationAdd", upload_url: str) -> Non hf_raise_for_status(response) -def _upload_multi_part(operation: "CommitOperationAdd", header: Dict, chunk_size: int, upload_url: str) -> None: +def _upload_multi_part(operation: "CommitOperationAdd", header: dict, chunk_size: int, upload_url: str) -> None: """ Uploads file using HF multipart LFS transfer protocol. """ @@ -352,7 +352,7 @@ def _upload_multi_part(operation: "CommitOperationAdd", header: Dict, chunk_size hf_raise_for_status(completion_res) -def _get_sorted_parts_urls(header: Dict, upload_info: UploadInfo, chunk_size: int) -> List[str]: +def _get_sorted_parts_urls(header: dict, upload_info: UploadInfo, chunk_size: int) -> list[str]: sorted_part_upload_urls = [ upload_url for _, upload_url in sorted( @@ -370,8 +370,8 @@ def _get_sorted_parts_urls(header: Dict, upload_info: UploadInfo, chunk_size: in return sorted_part_upload_urls -def _get_completion_payload(response_headers: List[Dict], oid: str) -> CompletionPayloadT: - parts: List[PayloadPartT] = [] +def _get_completion_payload(response_headers: list[dict], oid: str) -> CompletionPayloadT: + parts: list[PayloadPartT] = [] for part_number, header in enumerate(response_headers): etag = header.get("etag") if etag is None or etag == "": @@ -386,8 +386,8 @@ def _get_completion_payload(response_headers: List[Dict], oid: str) -> Completio def _upload_parts_iteratively( - operation: "CommitOperationAdd", sorted_parts_urls: List[str], chunk_size: int -) -> List[Dict]: + operation: "CommitOperationAdd", sorted_parts_urls: list[str], chunk_size: int +) -> list[dict]: headers = [] with operation.as_file(with_tqdm=True) as fileobj: for part_idx, part_upload_url in enumerate(sorted_parts_urls): @@ -406,8 +406,8 @@ def _upload_parts_iteratively( def _upload_parts_hf_transfer( - operation: "CommitOperationAdd", sorted_parts_urls: List[str], chunk_size: int -) -> List[Dict]: + operation: "CommitOperationAdd", sorted_parts_urls: list[str], chunk_size: int +) -> list[dict]: # Upload file using an external Rust-based package. Upload is faster but support less features (no progress bars). try: from hf_transfer import multipart_upload diff --git a/src/huggingface_hub/repocard.py b/src/huggingface_hub/repocard.py index c8c9a28a17..4e0f775ddd 100644 --- a/src/huggingface_hub/repocard.py +++ b/src/huggingface_hub/repocard.py @@ -1,7 +1,7 @@ import os import re from pathlib import Path -from typing import Any, Dict, Literal, Optional, Type, Union +from typing import Any, Literal, Optional, Union import yaml @@ -335,7 +335,7 @@ def from_template( class ModelCard(RepoCard): - card_data_class = ModelCardData + card_data_class = ModelCardData # type: ignore[assignment] default_template_path = TEMPLATE_MODELCARD_PATH repo_type = "model" @@ -416,7 +416,7 @@ def from_template( # type: ignore # violates Liskov property but easier to use class DatasetCard(RepoCard): - card_data_class = DatasetCardData + card_data_class = DatasetCardData # type: ignore[assignment] default_template_path = TEMPLATE_DATASETCARD_PATH repo_type = "dataset" @@ -481,7 +481,7 @@ def from_template( # type: ignore # violates Liskov property but easier to use class SpaceCard(RepoCard): - card_data_class = SpaceCardData + card_data_class = SpaceCardData # type: ignore[assignment] default_template_path = TEMPLATE_MODELCARD_PATH repo_type = "space" @@ -507,7 +507,7 @@ def _detect_line_ending(content: str) -> Literal["\r", "\n", "\r\n", None]: # n return "\n" -def metadata_load(local_path: Union[str, Path]) -> Optional[Dict]: +def metadata_load(local_path: Union[str, Path]) -> Optional[dict]: content = Path(local_path).read_text() match = REGEX_YAML_BLOCK.search(content) if match: @@ -520,7 +520,7 @@ def metadata_load(local_path: Union[str, Path]) -> Optional[Dict]: return None -def metadata_save(local_path: Union[str, Path], data: Dict) -> None: +def metadata_save(local_path: Union[str, Path], data: dict) -> None: """ Save the metadata dict in the upper YAML part Trying to preserve newlines as in the existing file. Docs about open() with newline="" parameter: @@ -568,7 +568,7 @@ def metadata_eval_result( dataset_split: Optional[str] = None, dataset_revision: Optional[str] = None, metrics_verification_token: Optional[str] = None, -) -> Dict: +) -> dict: """ Creates a metadata dict with the result from a model evaluated on a dataset. @@ -683,7 +683,7 @@ def metadata_eval_result( @validate_hf_hub_args def metadata_update( repo_id: str, - metadata: Dict, + metadata: dict, *, repo_type: Optional[str] = None, overwrite: bool = False, @@ -751,7 +751,7 @@ def metadata_update( commit_message = commit_message if commit_message is not None else "Update metadata with huggingface_hub" # Card class given repo_type - card_class: Type[RepoCard] + card_class: type[RepoCard] if repo_type is None or repo_type == "model": card_class = ModelCard elif repo_type == "dataset": diff --git a/src/huggingface_hub/repocard_data.py b/src/huggingface_hub/repocard_data.py index 62215f2274..1d283f5baa 100644 --- a/src/huggingface_hub/repocard_data.py +++ b/src/huggingface_hub/repocard_data.py @@ -1,7 +1,7 @@ import copy from collections import defaultdict from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Optional, Union from huggingface_hub.utils import logging, yaml_dump @@ -38,7 +38,7 @@ class EvalResult: dataset_revision (`str`, *optional*): The revision (AKA Git Sha) of the dataset used in `load_dataset()`. Example: 5503434ddd753f426f4b38109466949a1217c2bb - dataset_args (`Dict[str, Any]`, *optional*): + dataset_args (`dict[str, Any]`, *optional*): The arguments passed during `Metric.compute()`. Example for `bleu`: `{"max_order": 4}` metric_name (`str`, *optional*): A pretty name for the metric. Example: "Test WER". @@ -46,7 +46,7 @@ class EvalResult: The name of the metric configuration used in `load_metric()`. Example: bleurt-large-512 in `load_metric("bleurt", "bleurt-large-512")`. See the `datasets` docs for more info: https://huggingface.co/docs/datasets/v2.1.0/en/loading#load-configurations - metric_args (`Dict[str, Any]`, *optional*): + metric_args (`dict[str, Any]`, *optional*): The arguments passed during `Metric.compute()`. Example for `bleu`: max_order: 4 verified (`bool`, *optional*): Indicates whether the metrics originate from Hugging Face's [evaluation service](https://huggingface.co/spaces/autoevaluate/model-evaluator) or not. Automatically computed by Hugging Face, do not set. @@ -102,7 +102,7 @@ class EvalResult: # The arguments passed during `Metric.compute()`. # Example for `bleu`: max_order: 4 - dataset_args: Optional[Dict[str, Any]] = None + dataset_args: Optional[dict[str, Any]] = None # A pretty name for the metric. # Example: Test WER @@ -115,7 +115,7 @@ class EvalResult: # The arguments passed during `Metric.compute()`. # Example for `bleu`: max_order: 4 - metric_args: Optional[Dict[str, Any]] = None + metric_args: Optional[dict[str, Any]] = None # Indicates whether the metrics originate from Hugging Face's [evaluation service](https://huggingface.co/spaces/autoevaluate/model-evaluator) or not. Automatically computed by Hugging Face, do not set. verified: Optional[bool] = None @@ -195,7 +195,7 @@ def _to_dict(self, data_dict): """ pass - def to_yaml(self, line_break=None, original_order: Optional[List[str]] = None) -> str: + def to_yaml(self, line_break=None, original_order: Optional[list[str]] = None) -> str: """Dumps CardData to a YAML block for inclusion in a README.md file. Args: @@ -246,9 +246,9 @@ def __len__(self) -> int: def _validate_eval_results( - eval_results: Optional[Union[EvalResult, List[EvalResult]]], + eval_results: Optional[Union[EvalResult, list[EvalResult]]], model_name: Optional[str], -) -> List[EvalResult]: +) -> list[EvalResult]: if eval_results is None: return [] if isinstance(eval_results, EvalResult): @@ -266,17 +266,17 @@ class ModelCardData(CardData): """Model Card Metadata that is used by Hugging Face Hub when included at the top of your README.md Args: - base_model (`str` or `List[str]`, *optional*): + base_model (`str` or `list[str]`, *optional*): The identifier of the base model from which the model derives. This is applicable for example if your model is a fine-tune or adapter of an existing model. The value must be the ID of a model on the Hub (or a list of IDs if your model derives from multiple models). Defaults to None. - datasets (`Union[str, List[str]]`, *optional*): + datasets (`Union[str, list[str]]`, *optional*): Dataset or list of datasets that were used to train this model. Should be a dataset ID found on https://hf.co/datasets. Defaults to None. - eval_results (`Union[List[EvalResult], EvalResult]`, *optional*): + eval_results (`Union[list[EvalResult], EvalResult]`, *optional*): List of `huggingface_hub.EvalResult` that define evaluation results of the model. If provided, `model_name` is used to as a name on PapersWithCode's leaderboards. Defaults to `None`. - language (`Union[str, List[str]]`, *optional*): + language (`Union[str, list[str]]`, *optional*): Language of model's training data or metadata. It must be an ISO 639-1, 639-2 or 639-3 code (two/three letters), or a special value like "code", "multilingual". Defaults to `None`. library_name (`str`, *optional*): @@ -292,7 +292,7 @@ class ModelCardData(CardData): license_link (`str`, *optional*): Link to the license of this model. Defaults to None. To be used in conjunction with `license_name`. Common licenses (Apache-2.0, MIT, CC-BY-SA-4.0) do not need a link. In that case, use `license` instead. - metrics (`List[str]`, *optional*): + metrics (`list[str]`, *optional*): List of metrics used to evaluate this model. Should be a metric name that can be found at https://hf.co/metrics. Example: 'accuracy'. Defaults to None. model_name (`str`, *optional*): @@ -302,7 +302,7 @@ class ModelCardData(CardData): then the repo name is used as a default. Defaults to None. pipeline_tag (`str`, *optional*): The pipeline tag associated with the model. Example: "text-classification". - tags (`List[str]`, *optional*): + tags (`list[str]`, *optional*): List of tags to add to your model that can be used when filtering on the Hugging Face Hub. Defaults to None. ignore_metadata_errors (`str`): @@ -329,18 +329,18 @@ class ModelCardData(CardData): def __init__( self, *, - base_model: Optional[Union[str, List[str]]] = None, - datasets: Optional[Union[str, List[str]]] = None, - eval_results: Optional[List[EvalResult]] = None, - language: Optional[Union[str, List[str]]] = None, + base_model: Optional[Union[str, list[str]]] = None, + datasets: Optional[Union[str, list[str]]] = None, + eval_results: Optional[list[EvalResult]] = None, + language: Optional[Union[str, list[str]]] = None, library_name: Optional[str] = None, license: Optional[str] = None, license_name: Optional[str] = None, license_link: Optional[str] = None, - metrics: Optional[List[str]] = None, + metrics: Optional[list[str]] = None, model_name: Optional[str] = None, pipeline_tag: Optional[str] = None, - tags: Optional[List[str]] = None, + tags: Optional[list[str]] = None, ignore_metadata_errors: bool = False, **kwargs, ): @@ -395,58 +395,58 @@ class DatasetCardData(CardData): """Dataset Card Metadata that is used by Hugging Face Hub when included at the top of your README.md Args: - language (`List[str]`, *optional*): + language (`list[str]`, *optional*): Language of dataset's data or metadata. It must be an ISO 639-1, 639-2 or 639-3 code (two/three letters), or a special value like "code", "multilingual". - license (`Union[str, List[str]]`, *optional*): + license (`Union[str, list[str]]`, *optional*): License(s) of this dataset. Example: apache-2.0 or any license from https://huggingface.co/docs/hub/repositories-licenses. - annotations_creators (`Union[str, List[str]]`, *optional*): + annotations_creators (`Union[str, list[str]]`, *optional*): How the annotations for the dataset were created. Options are: 'found', 'crowdsourced', 'expert-generated', 'machine-generated', 'no-annotation', 'other'. - language_creators (`Union[str, List[str]]`, *optional*): + language_creators (`Union[str, list[str]]`, *optional*): How the text-based data in the dataset was created. Options are: 'found', 'crowdsourced', 'expert-generated', 'machine-generated', 'other' - multilinguality (`Union[str, List[str]]`, *optional*): + multilinguality (`Union[str, list[str]]`, *optional*): Whether the dataset is multilingual. Options are: 'monolingual', 'multilingual', 'translation', 'other'. - size_categories (`Union[str, List[str]]`, *optional*): + size_categories (`Union[str, list[str]]`, *optional*): The number of examples in the dataset. Options are: 'n<1K', '1K1T', and 'other'. - source_datasets (`List[str]]`, *optional*): + source_datasets (`list[str]]`, *optional*): Indicates whether the dataset is an original dataset or extended from another existing dataset. Options are: 'original' and 'extended'. - task_categories (`Union[str, List[str]]`, *optional*): + task_categories (`Union[str, list[str]]`, *optional*): What categories of task does the dataset support? - task_ids (`Union[str, List[str]]`, *optional*): + task_ids (`Union[str, list[str]]`, *optional*): What specific tasks does the dataset support? paperswithcode_id (`str`, *optional*): ID of the dataset on PapersWithCode. pretty_name (`str`, *optional*): A more human-readable name for the dataset. (ex. "Cats vs. Dogs") - train_eval_index (`Dict`, *optional*): + train_eval_index (`dict`, *optional*): A dictionary that describes the necessary spec for doing evaluation on the Hub. If not provided, it will be gathered from the 'train-eval-index' key of the kwargs. - config_names (`Union[str, List[str]]`, *optional*): + config_names (`Union[str, list[str]]`, *optional*): A list of the available dataset configs for the dataset. """ def __init__( self, *, - language: Optional[Union[str, List[str]]] = None, - license: Optional[Union[str, List[str]]] = None, - annotations_creators: Optional[Union[str, List[str]]] = None, - language_creators: Optional[Union[str, List[str]]] = None, - multilinguality: Optional[Union[str, List[str]]] = None, - size_categories: Optional[Union[str, List[str]]] = None, - source_datasets: Optional[List[str]] = None, - task_categories: Optional[Union[str, List[str]]] = None, - task_ids: Optional[Union[str, List[str]]] = None, + language: Optional[Union[str, list[str]]] = None, + license: Optional[Union[str, list[str]]] = None, + annotations_creators: Optional[Union[str, list[str]]] = None, + language_creators: Optional[Union[str, list[str]]] = None, + multilinguality: Optional[Union[str, list[str]]] = None, + size_categories: Optional[Union[str, list[str]]] = None, + source_datasets: Optional[list[str]] = None, + task_categories: Optional[Union[str, list[str]]] = None, + task_ids: Optional[Union[str, list[str]]] = None, paperswithcode_id: Optional[str] = None, pretty_name: Optional[str] = None, - train_eval_index: Optional[Dict] = None, - config_names: Optional[Union[str, List[str]]] = None, + train_eval_index: Optional[dict] = None, + config_names: Optional[Union[str, list[str]]] = None, ignore_metadata_errors: bool = False, **kwargs, ): @@ -495,11 +495,11 @@ class SpaceCardData(CardData): https://huggingface.co/docs/hub/repositories-licenses. duplicated_from (`str`, *optional*) ID of the original Space if this is a duplicated Space. - models (List[`str`], *optional*) + models (list[`str`], *optional*) List of models related to this Space. Should be a dataset ID found on https://hf.co/models. - datasets (`List[str]`, *optional*) + datasets (`list[str]`, *optional*) List of datasets related to this Space. Should be a dataset ID found on https://hf.co/datasets. - tags (`List[str]`, *optional*) + tags (`list[str]`, *optional*) List of tags to add to your Space that can be used when filtering on the Hub. ignore_metadata_errors (`str`): If True, errors while parsing the metadata section will be ignored. Some information might be lost during @@ -532,9 +532,9 @@ def __init__( app_port: Optional[int] = None, license: Optional[str] = None, duplicated_from: Optional[str] = None, - models: Optional[List[str]] = None, - datasets: Optional[List[str]] = None, - tags: Optional[List[str]] = None, + models: Optional[list[str]] = None, + datasets: Optional[list[str]] = None, + tags: Optional[list[str]] = None, ignore_metadata_errors: bool = False, **kwargs, ): @@ -552,14 +552,14 @@ def __init__( super().__init__(**kwargs) -def model_index_to_eval_results(model_index: List[Dict[str, Any]]) -> Tuple[str, List[EvalResult]]: +def model_index_to_eval_results(model_index: list[dict[str, Any]]) -> tuple[str, list[EvalResult]]: """Takes in a model index and returns the model name and a list of `huggingface_hub.EvalResult` objects. A detailed spec of the model index can be found here: https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1 Args: - model_index (`List[Dict[str, Any]]`): + model_index (`list[dict[str, Any]]`): A model index data structure, likely coming from a README.md file on the Hugging Face Hub. @@ -567,7 +567,7 @@ def model_index_to_eval_results(model_index: List[Dict[str, Any]]) -> Tuple[str, model_name (`str`): The name of the model as found in the model index. This is used as the identifier for the model on leaderboards like PapersWithCode. - eval_results (`List[EvalResult]`): + eval_results (`list[EvalResult]`): A list of `huggingface_hub.EvalResult` objects containing the metrics reported in the provided model_index. @@ -668,7 +668,7 @@ def _remove_none(obj): return obj -def eval_results_to_model_index(model_name: str, eval_results: List[EvalResult]) -> List[Dict[str, Any]]: +def eval_results_to_model_index(model_name: str, eval_results: list[EvalResult]) -> list[dict[str, Any]]: """Takes in given model name and list of `huggingface_hub.EvalResult` and returns a valid model-index that will be compatible with the format expected by the Hugging Face Hub. @@ -677,12 +677,12 @@ def eval_results_to_model_index(model_name: str, eval_results: List[EvalResult]) model_name (`str`): Name of the model (ex. "my-cool-model"). This is used as the identifier for the model on leaderboards like PapersWithCode. - eval_results (`List[EvalResult]`): + eval_results (`list[EvalResult]`): List of `huggingface_hub.EvalResult` objects containing the metrics to be reported in the model-index. Returns: - model_index (`List[Dict[str, Any]]`): The eval_results converted to a model-index. + model_index (`list[dict[str, Any]]`): The eval_results converted to a model-index. Example: ```python @@ -705,7 +705,7 @@ def eval_results_to_model_index(model_name: str, eval_results: List[EvalResult]) # Metrics are reported on a unique task-and-dataset basis. # Here, we make a map of those pairs and the associated EvalResults. - task_and_ds_types_map: Dict[Any, List[EvalResult]] = defaultdict(list) + task_and_ds_types_map: dict[Any, list[EvalResult]] = defaultdict(list) for eval_result in eval_results: task_and_ds_types_map[eval_result.unique_identifier].append(eval_result) @@ -760,7 +760,7 @@ def eval_results_to_model_index(model_name: str, eval_results: List[EvalResult]) return _remove_none(model_index) -def _to_unique_list(tags: Optional[List[str]]) -> Optional[List[str]]: +def _to_unique_list(tags: Optional[list[str]]) -> Optional[list[str]]: if tags is None: return tags unique_tags = [] # make tags unique + keep order explicitly diff --git a/src/huggingface_hub/repository.py b/src/huggingface_hub/repository.py index d4a904f458..387761cedc 100644 --- a/src/huggingface_hub/repository.py +++ b/src/huggingface_hub/repository.py @@ -6,7 +6,7 @@ import time from contextlib import contextmanager from pathlib import Path -from typing import Callable, Dict, Iterator, List, Optional, Tuple, TypedDict, Union +from typing import Callable, Iterator, Optional, TypedDict, Union from urllib.parse import urlparse from huggingface_hub import constants @@ -238,7 +238,7 @@ def is_binary_file(filename: Union[str, Path]) -> bool: return True -def files_to_be_staged(pattern: str = ".", folder: Union[str, Path, None] = None) -> List[str]: +def files_to_be_staged(pattern: str = ".", folder: Union[str, Path, None] = None) -> list[str]: """ Returns a list of filenames that are to be staged. @@ -249,7 +249,7 @@ def files_to_be_staged(pattern: str = ".", folder: Union[str, Path, None] = None The folder in which to run the command. Returns: - `List[str]`: List of files that are to be staged. + `list[str]`: List of files that are to be staged. """ try: p = run_subprocess("git ls-files --exclude-standard -mo".split() + [pattern], folder) @@ -333,7 +333,7 @@ def output_progress(stopping_event: threading.Event): the tail. """ # Key is tuple(state, filename), value is a dict(tqdm bar and a previous value) - pbars: Dict[Tuple[str, str], PbarT] = {} + pbars: dict[tuple[str, str], PbarT] = {} def close_pbars(): for pbar in pbars.values(): @@ -441,7 +441,7 @@ class Repository: """ - command_queue: List[CommandInProgress] + command_queue: list[CommandInProgress] @validate_hf_hub_args @_deprecate_method( @@ -796,13 +796,13 @@ def git_head_commit_url(self) -> str: url = url[:-1] return f"{url}/commit/{sha}" - def list_deleted_files(self) -> List[str]: + def list_deleted_files(self) -> list[str]: """ Returns a list of the files that are deleted in the working directory or index. Returns: - `List[str]`: A list of files that have been deleted in the working + `list[str]`: A list of files that have been deleted in the working directory or index. """ try: @@ -831,7 +831,7 @@ def list_deleted_files(self) -> List[str]: return deleted_files - def lfs_track(self, patterns: Union[str, List[str]], filename: bool = False): + def lfs_track(self, patterns: Union[str, list[str]], filename: bool = False): """ Tell git-lfs to track files according to a pattern. @@ -840,7 +840,7 @@ def lfs_track(self, patterns: Union[str, List[str]], filename: bool = False): filename will be escaped when writing to the `.gitattributes` file. Args: - patterns (`Union[str, List[str]]`): + patterns (`Union[str, list[str]]`): The pattern, or list of patterns, to track with git-lfs. filename (`bool`, *optional*, defaults to `False`): Whether to use the patterns as literal filenames. @@ -856,12 +856,12 @@ def lfs_track(self, patterns: Union[str, List[str]], filename: bool = False): except subprocess.CalledProcessError as exc: raise EnvironmentError(exc.stderr) - def lfs_untrack(self, patterns: Union[str, List[str]]): + def lfs_untrack(self, patterns: Union[str, list[str]]): """ Tell git-lfs to untrack those files. Args: - patterns (`Union[str, List[str]]`): + patterns (`Union[str, list[str]]`): The pattern, or list of patterns, to untrack with git-lfs. """ if isinstance(patterns, str): @@ -886,7 +886,7 @@ def lfs_enable_largefiles(self): except subprocess.CalledProcessError as exc: raise EnvironmentError(exc.stderr) - def auto_track_binary_files(self, pattern: str = ".") -> List[str]: + def auto_track_binary_files(self, pattern: str = ".") -> list[str]: """ Automatically track binary files with git-lfs. @@ -895,7 +895,7 @@ def auto_track_binary_files(self, pattern: str = ".") -> List[str]: The pattern with which to track files that are binary. Returns: - `List[str]`: List of filenames that are now tracked due to being + `list[str]`: List of filenames that are now tracked due to being binary files """ files_to_be_tracked_with_lfs = [] @@ -929,7 +929,7 @@ def auto_track_binary_files(self, pattern: str = ".") -> List[str]: return files_to_be_tracked_with_lfs - def auto_track_large_files(self, pattern: str = ".") -> List[str]: + def auto_track_large_files(self, pattern: str = ".") -> list[str]: """ Automatically track large files (files that weigh more than 10MBs) with git-lfs. @@ -939,7 +939,7 @@ def auto_track_large_files(self, pattern: str = ".") -> List[str]: The pattern with which to track files that are above 10MBs. Returns: - `List[str]`: List of filenames that are now tracked due to their + `list[str]`: List of filenames that are now tracked due to their size. """ files_to_be_tracked_with_lfs = [] @@ -1060,7 +1060,7 @@ def git_push( upstream: Optional[str] = None, blocking: bool = True, auto_lfs_prune: bool = False, - ) -> Union[str, Tuple[str, CommandInProgress]]: + ) -> Union[str, tuple[str, CommandInProgress]]: """ git push @@ -1298,7 +1298,7 @@ def push_to_hub( blocking: bool = True, clean_ok: bool = True, auto_lfs_prune: bool = False, - ) -> Union[None, str, Tuple[str, CommandInProgress]]: + ) -> Union[None, str, tuple[str, CommandInProgress]]: """ Helper to add, commit, and push files to remote repository on the HuggingFace Hub. Will automatically track large files (>10MB). @@ -1433,13 +1433,13 @@ def commit( os.chdir(current_working_directory) - def repocard_metadata_load(self) -> Optional[Dict]: + def repocard_metadata_load(self) -> Optional[dict]: filepath = os.path.join(self.local_dir, constants.REPOCARD_NAME) if os.path.isfile(filepath): return metadata_load(filepath) return None - def repocard_metadata_save(self, data: Dict) -> None: + def repocard_metadata_save(self, data: dict) -> None: return metadata_save(os.path.join(self.local_dir, constants.REPOCARD_NAME), data) @property diff --git a/src/huggingface_hub/serialization/_base.py b/src/huggingface_hub/serialization/_base.py index b7b6454a90..53b72f6c4a 100644 --- a/src/huggingface_hub/serialization/_base.py +++ b/src/huggingface_hub/serialization/_base.py @@ -14,7 +14,7 @@ """Contains helpers to split tensors into shards.""" from dataclasses import dataclass, field -from typing import Any, Callable, Dict, List, Optional, TypeVar, Union +from typing import Any, Callable, Optional, TypeVar, Union from .. import logging @@ -38,16 +38,16 @@ @dataclass class StateDictSplit: is_sharded: bool = field(init=False) - metadata: Dict[str, Any] - filename_to_tensors: Dict[str, List[str]] - tensor_to_filename: Dict[str, str] + metadata: dict[str, Any] + filename_to_tensors: dict[str, list[str]] + tensor_to_filename: dict[str, str] def __post_init__(self): self.is_sharded = len(self.filename_to_tensors) > 1 def split_state_dict_into_shards_factory( - state_dict: Dict[str, TensorT], + state_dict: dict[str, TensorT], *, get_storage_size: TensorSizeFn_T, filename_pattern: str, @@ -70,7 +70,7 @@ def split_state_dict_into_shards_factory( Args: - state_dict (`Dict[str, Tensor]`): + state_dict (`dict[str, Tensor]`): The state dictionary to save. get_storage_size (`Callable[[Tensor], int]`): A function that returns the size of a tensor when saved on disk in bytes. @@ -87,10 +87,10 @@ def split_state_dict_into_shards_factory( Returns: [`StateDictSplit`]: A `StateDictSplit` object containing the shards and the index to retrieve them. """ - storage_id_to_tensors: Dict[Any, List[str]] = {} + storage_id_to_tensors: dict[Any, list[str]] = {} - shard_list: List[Dict[str, TensorT]] = [] - current_shard: Dict[str, TensorT] = {} + shard_list: list[dict[str, TensorT]] = [] + current_shard: dict[str, TensorT] = {} current_shard_size = 0 total_size = 0 diff --git a/src/huggingface_hub/serialization/_dduf.py b/src/huggingface_hub/serialization/_dduf.py index a1debadb3a..c184509c63 100644 --- a/src/huggingface_hub/serialization/_dduf.py +++ b/src/huggingface_hub/serialization/_dduf.py @@ -7,7 +7,7 @@ from contextlib import contextmanager from dataclasses import dataclass, field from pathlib import Path -from typing import Any, Dict, Generator, Iterable, Tuple, Union +from typing import Any, Generator, Iterable, Union from ..errors import DDUFCorruptedFileError, DDUFExportError, DDUFInvalidEntryNameError @@ -87,7 +87,7 @@ def read_text(self, encoding: str = "utf-8") -> str: return f.read(self.length).decode(encoding=encoding) -def read_dduf_file(dduf_path: Union[os.PathLike, str]) -> Dict[str, DDUFEntry]: +def read_dduf_file(dduf_path: Union[os.PathLike, str]) -> dict[str, DDUFEntry]: """ Read a DDUF file and return a dictionary of entries. @@ -98,7 +98,7 @@ def read_dduf_file(dduf_path: Union[os.PathLike, str]) -> Dict[str, DDUFEntry]: The path to the DDUF file to read. Returns: - `Dict[str, DDUFEntry]`: + `dict[str, DDUFEntry]`: A dictionary of [`DDUFEntry`] indexed by filename. Raises: @@ -157,7 +157,7 @@ def read_dduf_file(dduf_path: Union[os.PathLike, str]) -> Dict[str, DDUFEntry]: def export_entries_as_dduf( - dduf_path: Union[str, os.PathLike], entries: Iterable[Tuple[str, Union[str, Path, bytes]]] + dduf_path: Union[str, os.PathLike], entries: Iterable[tuple[str, Union[str, Path, bytes]]] ) -> None: """Write a DDUF file from an iterable of entries. @@ -167,7 +167,7 @@ def export_entries_as_dduf( Args: dduf_path (`str` or `os.PathLike`): The path to the DDUF file to write. - entries (`Iterable[Tuple[str, Union[str, Path, bytes]]]`): + entries (`Iterable[tuple[str, Union[str, Path, bytes]]]`): An iterable of entries to write in the DDUF file. Each entry is a tuple with the filename and the content. The filename should be the path to the file in the DDUF archive. The content can be a string or a pathlib.Path representing a path to a file on the local disk or directly the content as bytes. @@ -201,7 +201,7 @@ def export_entries_as_dduf( >>> pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4") ... # ... do some work with the pipeline - >>> def as_entries(pipe: DiffusionPipeline) -> Generator[Tuple[str, bytes], None, None]: + >>> def as_entries(pipe: DiffusionPipeline) -> Generator[tuple[str, bytes], None, None]: ... # Build an generator that yields the entries to add to the DDUF file. ... # The first element of the tuple is the filename in the DDUF archive (must use UNIX separator!). The second element is the content of the file. ... # Entries will be evaluated lazily when the DDUF file is created (only 1 entry is loaded in memory at a time) @@ -267,7 +267,7 @@ def export_folder_as_dduf(dduf_path: Union[str, os.PathLike], folder_path: Union """ folder_path = Path(folder_path) - def _iterate_over_folder() -> Iterable[Tuple[str, Path]]: + def _iterate_over_folder() -> Iterable[tuple[str, Path]]: for path in Path(folder_path).glob("**/*"): if not path.is_file(): continue diff --git a/src/huggingface_hub/serialization/_tensorflow.py b/src/huggingface_hub/serialization/_tensorflow.py index 59ed8110b2..affcaf4834 100644 --- a/src/huggingface_hub/serialization/_tensorflow.py +++ b/src/huggingface_hub/serialization/_tensorflow.py @@ -15,7 +15,7 @@ import math import re -from typing import TYPE_CHECKING, Dict, Union +from typing import TYPE_CHECKING, Union from .. import constants from ._base import MAX_SHARD_SIZE, StateDictSplit, split_state_dict_into_shards_factory @@ -26,7 +26,7 @@ def split_tf_state_dict_into_shards( - state_dict: Dict[str, "tf.Tensor"], + state_dict: dict[str, "tf.Tensor"], *, filename_pattern: str = constants.TF2_WEIGHTS_FILE_PATTERN, max_shard_size: Union[int, str] = MAX_SHARD_SIZE, @@ -47,7 +47,7 @@ def split_tf_state_dict_into_shards( Args: - state_dict (`Dict[str, Tensor]`): + state_dict (`dict[str, Tensor]`): The state dictionary to save. filename_pattern (`str`, *optional*): The pattern to generate the files names in which the model will be saved. Pattern must be a string that diff --git a/src/huggingface_hub/serialization/_torch.py b/src/huggingface_hub/serialization/_torch.py index c5c70fc89b..00c591134b 100644 --- a/src/huggingface_hub/serialization/_torch.py +++ b/src/huggingface_hub/serialization/_torch.py @@ -20,7 +20,7 @@ from collections import defaultdict, namedtuple from functools import lru_cache from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, NamedTuple, Optional, Set, Tuple, Union +from typing import TYPE_CHECKING, Any, Iterable, NamedTuple, Optional, Union from packaging import version @@ -43,10 +43,10 @@ def save_torch_model( filename_pattern: Optional[str] = None, force_contiguous: bool = True, max_shard_size: Union[int, str] = MAX_SHARD_SIZE, - metadata: Optional[Dict[str, str]] = None, + metadata: Optional[dict[str, str]] = None, safe_serialization: bool = True, is_main_process: bool = True, - shared_tensors_to_discard: Optional[List[str]] = None, + shared_tensors_to_discard: Optional[list[str]] = None, ): """ Saves a given torch model to disk, handling sharding and shared tensors issues. @@ -92,7 +92,7 @@ def save_torch_model( that reason. Defaults to `True`. max_shard_size (`int` or `str`, *optional*): The maximum size of each shard, in bytes. Defaults to 5GB. - metadata (`Dict[str, str]`, *optional*): + metadata (`dict[str, str]`, *optional*): Extra information to save along with the model. Some metadata will be added for each dropped tensors. This information will not be enough to recover the entire shared structure but might help understanding things. @@ -104,7 +104,7 @@ def save_torch_model( Whether the process calling this is the main process or not. Useful when in distributed training like TPUs and need to call this function from all processes. In this case, set `is_main_process=True` only on the main process to avoid race conditions. Defaults to True. - shared_tensors_to_discard (`List[str]`, *optional*): + shared_tensors_to_discard (`list[str]`, *optional*): List of tensor names to drop when saving shared tensors. If not provided and shared tensors are detected, it will drop the first name alphabetically. @@ -137,16 +137,16 @@ def save_torch_model( def save_torch_state_dict( - state_dict: Dict[str, "torch.Tensor"], + state_dict: dict[str, "torch.Tensor"], save_directory: Union[str, Path], *, filename_pattern: Optional[str] = None, force_contiguous: bool = True, max_shard_size: Union[int, str] = MAX_SHARD_SIZE, - metadata: Optional[Dict[str, str]] = None, + metadata: Optional[dict[str, str]] = None, safe_serialization: bool = True, is_main_process: bool = True, - shared_tensors_to_discard: Optional[List[str]] = None, + shared_tensors_to_discard: Optional[list[str]] = None, ) -> None: """ Save a model state dictionary to the disk, handling sharding and shared tensors issues. @@ -177,7 +177,7 @@ def save_torch_state_dict( Args: - state_dict (`Dict[str, torch.Tensor]`): + state_dict (`dict[str, torch.Tensor]`): The state dictionary to save. save_directory (`str` or `Path`): The directory in which the model will be saved. @@ -192,7 +192,7 @@ def save_torch_state_dict( that reason. Defaults to `True`. max_shard_size (`int` or `str`, *optional*): The maximum size of each shard, in bytes. Defaults to 5GB. - metadata (`Dict[str, str]`, *optional*): + metadata (`dict[str, str]`, *optional*): Extra information to save along with the model. Some metadata will be added for each dropped tensors. This information will not be enough to recover the entire shared structure but might help understanding things. @@ -204,7 +204,7 @@ def save_torch_state_dict( Whether the process calling this is the main process or not. Useful when in distributed training like TPUs and need to call this function from all processes. In this case, set `is_main_process=True` only on the main process to avoid race conditions. Defaults to True. - shared_tensors_to_discard (`List[str]`, *optional*): + shared_tensors_to_discard (`list[str]`, *optional*): List of tensor names to drop when saving shared tensors. If not provided and shared tensors are detected, it will drop the first name alphabetically. @@ -300,7 +300,7 @@ def save_torch_state_dict( def split_torch_state_dict_into_shards( - state_dict: Dict[str, "torch.Tensor"], + state_dict: dict[str, "torch.Tensor"], *, filename_pattern: str = constants.SAFETENSORS_WEIGHTS_FILE_PATTERN, max_shard_size: Union[int, str] = MAX_SHARD_SIZE, @@ -329,7 +329,7 @@ def split_torch_state_dict_into_shards( Args: - state_dict (`Dict[str, torch.Tensor]`): + state_dict (`dict[str, torch.Tensor]`): The state dictionary to save. filename_pattern (`str`, *optional*): The pattern to generate the files names in which the model will be saved. Pattern must be a string that @@ -348,7 +348,7 @@ def split_torch_state_dict_into_shards( >>> from safetensors.torch import save_file as safe_save_file >>> from huggingface_hub import split_torch_state_dict_into_shards - >>> def save_state_dict(state_dict: Dict[str, torch.Tensor], save_directory: str): + >>> def save_state_dict(state_dict: dict[str, torch.Tensor], save_directory: str): ... state_dict_split = split_torch_state_dict_into_shards(state_dict) ... for filename, tensors in state_dict_split.filename_to_tensors.items(): ... shard = {tensor: state_dict[tensor] for tensor in tensors} @@ -560,7 +560,7 @@ def load_state_dict_from_file( map_location: Optional[Union[str, "torch.device"]] = None, weights_only: bool = False, mmap: bool = False, -) -> Union[Dict[str, "torch.Tensor"], Any]: +) -> Union[dict[str, "torch.Tensor"], Any]: """ Loads a checkpoint file, handling both safetensors and pickle checkpoint formats. @@ -580,7 +580,7 @@ def load_state_dict_from_file( loading safetensors files, as the `safetensors` library uses memory mapping by default. Returns: - `Union[Dict[str, "torch.Tensor"], Any]`: The loaded checkpoint. + `Union[dict[str, "torch.Tensor"], Any]`: The loaded checkpoint. - For safetensors files: always returns a dictionary mapping parameter names to tensors. - For pickle files: returns any Python object that was pickled (commonly a state dict, but could be an entire model, optimizer state, or any other Python object). @@ -700,7 +700,7 @@ def _validate_keys_for_strict_loading( raise RuntimeError(error_message) -def _get_unique_id(tensor: "torch.Tensor") -> Union[int, Tuple[Any, ...]]: +def _get_unique_id(tensor: "torch.Tensor") -> Union[int, tuple[Any, ...]]: """Returns a unique id for plain tensor or a (potentially nested) Tuple of unique id for the flattened Tensor if the input is a wrapper tensor subclass Tensor @@ -741,7 +741,7 @@ def _get_unique_id(tensor: "torch.Tensor") -> Union[int, Tuple[Any, ...]]: return unique_id -def get_torch_storage_id(tensor: "torch.Tensor") -> Optional[Tuple["torch.device", Union[int, Tuple[Any, ...]], int]]: +def get_torch_storage_id(tensor: "torch.Tensor") -> Optional[tuple["torch.device", Union[int, tuple[Any, ...]], int]]: """ Return unique identifier to a tensor storage. @@ -815,7 +815,7 @@ def is_torch_tpu_available(check_device=True): return False -def storage_ptr(tensor: "torch.Tensor") -> Union[int, Tuple[Any, ...]]: +def storage_ptr(tensor: "torch.Tensor") -> Union[int, tuple[Any, ...]]: """ Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L11. """ @@ -841,10 +841,10 @@ def storage_ptr(tensor: "torch.Tensor") -> Union[int, Tuple[Any, ...]]: def _clean_state_dict_for_safetensors( - state_dict: Dict[str, "torch.Tensor"], - metadata: Dict[str, str], + state_dict: dict[str, "torch.Tensor"], + metadata: dict[str, str], force_contiguous: bool = True, - shared_tensors_to_discard: Optional[List[str]] = None, + shared_tensors_to_discard: Optional[list[str]] = None, ): """Remove shared tensors from state_dict and update metadata accordingly (for reloading). @@ -878,7 +878,7 @@ def _end_ptr(tensor: "torch.Tensor") -> int: return stop -def _filter_shared_not_shared(tensors: List[Set[str]], state_dict: Dict[str, "torch.Tensor"]) -> List[Set[str]]: +def _filter_shared_not_shared(tensors: list[set[str]], state_dict: dict[str, "torch.Tensor"]) -> list[set[str]]: """ Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L44 """ @@ -906,7 +906,7 @@ def _filter_shared_not_shared(tensors: List[Set[str]], state_dict: Dict[str, "to return filtered_tensors -def _find_shared_tensors(state_dict: Dict[str, "torch.Tensor"]) -> List[Set[str]]: +def _find_shared_tensors(state_dict: dict[str, "torch.Tensor"]) -> list[set[str]]: """ Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L69. """ @@ -943,11 +943,11 @@ def _is_complete(tensor: "torch.Tensor") -> bool: def _remove_duplicate_names( - state_dict: Dict[str, "torch.Tensor"], + state_dict: dict[str, "torch.Tensor"], *, - preferred_names: Optional[List[str]] = None, - discard_names: Optional[List[str]] = None, -) -> Dict[str, List[str]]: + preferred_names: Optional[list[str]] = None, + discard_names: Optional[list[str]] = None, +) -> dict[str, list[str]]: """ Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L80 """ diff --git a/src/huggingface_hub/utils/_auth.py b/src/huggingface_hub/utils/_auth.py index 72be4dedbd..f19ac3e5f6 100644 --- a/src/huggingface_hub/utils/_auth.py +++ b/src/huggingface_hub/utils/_auth.py @@ -19,7 +19,7 @@ import warnings from pathlib import Path from threading import Lock -from typing import Dict, Optional +from typing import Optional from .. import constants from ._runtime import is_colab_enterprise, is_google_colab @@ -125,13 +125,13 @@ def _get_token_from_file() -> Optional[str]: return None -def get_stored_tokens() -> Dict[str, str]: +def get_stored_tokens() -> dict[str, str]: """ Returns the parsed INI file containing the access tokens. The file is located at `HF_STORED_TOKENS_PATH`, defaulting to `~/.cache/huggingface/stored_tokens`. If the file does not exist, an empty dictionary is returned. - Returns: `Dict[str, str]` + Returns: `dict[str, str]` Key is the token name and value is the token. """ tokens_path = Path(constants.HF_STORED_TOKENS_PATH) @@ -147,12 +147,12 @@ def get_stored_tokens() -> Dict[str, str]: return stored_tokens -def _save_stored_tokens(stored_tokens: Dict[str, str]) -> None: +def _save_stored_tokens(stored_tokens: dict[str, str]) -> None: """ Saves the given configuration to the stored tokens file. Args: - stored_tokens (`Dict[str, str]`): + stored_tokens (`dict[str, str]`): The stored tokens to save. Key is the token name and value is the token. """ stored_tokens_path = Path(constants.HF_STORED_TOKENS_PATH) diff --git a/src/huggingface_hub/utils/_cache_manager.py b/src/huggingface_hub/utils/_cache_manager.py index 311e164a4f..656e548585 100644 --- a/src/huggingface_hub/utils/_cache_manager.py +++ b/src/huggingface_hub/utils/_cache_manager.py @@ -20,7 +20,7 @@ from collections import defaultdict from dataclasses import dataclass from pathlib import Path -from typing import Dict, FrozenSet, List, Literal, Optional, Set, Union +from typing import Literal, Optional, Union from huggingface_hub.errors import CacheNotFound, CorruptedCacheException @@ -119,9 +119,9 @@ class CachedRevisionInfo: snapshot_path (`Path`): Path to the revision directory in the `snapshots` folder. It contains the exact tree structure as the repo on the Hub. - files: (`FrozenSet[CachedFileInfo]`): + files: (`frozenset[CachedFileInfo]`): Set of [`~CachedFileInfo`] describing all files contained in the snapshot. - refs (`FrozenSet[str]`): + refs (`frozenset[str]`): Set of `refs` pointing to this revision. If the revision has no `refs`, it is considered detached. Example: `{"main", "2.4.0"}` or `{"refs/pr/1"}`. @@ -149,8 +149,8 @@ class CachedRevisionInfo: commit_hash: str snapshot_path: Path size_on_disk: int - files: FrozenSet[CachedFileInfo] - refs: FrozenSet[str] + files: frozenset[CachedFileInfo] + refs: frozenset[str] last_modified: float @@ -196,7 +196,7 @@ class CachedRepoInfo: Sum of the blob file sizes in the cached repo. nb_files (`int`): Total number of blob files in the cached repo. - revisions (`FrozenSet[CachedRevisionInfo]`): + revisions (`frozenset[CachedRevisionInfo]`): Set of [`~CachedRevisionInfo`] describing all revisions cached in the repo. last_accessed (`float`): Timestamp of the last time a blob file of the repo has been accessed. @@ -225,7 +225,7 @@ class CachedRepoInfo: repo_path: Path size_on_disk: int nb_files: int - revisions: FrozenSet[CachedRevisionInfo] + revisions: frozenset[CachedRevisionInfo] last_accessed: float last_modified: float @@ -260,7 +260,7 @@ def size_on_disk_str(self) -> str: return _format_size(self.size_on_disk) @property - def refs(self) -> Dict[str, CachedRevisionInfo]: + def refs(self) -> dict[str, CachedRevisionInfo]: """ (property) Mapping between `refs` and revision data structures. """ @@ -277,21 +277,21 @@ class DeleteCacheStrategy: Args: expected_freed_size (`float`): Expected freed size once strategy is executed. - blobs (`FrozenSet[Path]`): + blobs (`frozenset[Path]`): Set of blob file paths to be deleted. - refs (`FrozenSet[Path]`): + refs (`frozenset[Path]`): Set of reference file paths to be deleted. - repos (`FrozenSet[Path]`): + repos (`frozenset[Path]`): Set of entire repo paths to be deleted. - snapshots (`FrozenSet[Path]`): + snapshots (`frozenset[Path]`): Set of snapshots to be deleted (directory of symlinks). """ expected_freed_size: int - blobs: FrozenSet[Path] - refs: FrozenSet[Path] - repos: FrozenSet[Path] - snapshots: FrozenSet[Path] + blobs: frozenset[Path] + refs: frozenset[Path] + repos: frozenset[Path] + snapshots: frozenset[Path] @property def expected_freed_size_str(self) -> str: @@ -352,10 +352,10 @@ class HFCacheInfo: Args: size_on_disk (`int`): Sum of all valid repo sizes in the cache-system. - repos (`FrozenSet[CachedRepoInfo]`): + repos (`frozenset[CachedRepoInfo]`): Set of [`~CachedRepoInfo`] describing all valid cached repos found on the cache-system while scanning. - warnings (`List[CorruptedCacheException]`): + warnings (`list[CorruptedCacheException]`): List of [`~CorruptedCacheException`] that occurred while scanning the cache. Those exceptions are captured so that the scan can continue. Corrupted repos are skipped from the scan. @@ -369,8 +369,8 @@ class HFCacheInfo: """ size_on_disk: int - repos: FrozenSet[CachedRepoInfo] - warnings: List[CorruptedCacheException] + repos: frozenset[CachedRepoInfo] + warnings: list[CorruptedCacheException] @property def size_on_disk_str(self) -> str: @@ -420,9 +420,9 @@ def delete_revisions(self, *revisions: str) -> DeleteCacheStrategy: """ - hashes_to_delete: Set[str] = set(revisions) + hashes_to_delete: set[str] = set(revisions) - repos_with_revisions: Dict[CachedRepoInfo, Set[CachedRevisionInfo]] = defaultdict(set) + repos_with_revisions: dict[CachedRepoInfo, set[CachedRevisionInfo]] = defaultdict(set) for repo in self.repos: for revision in repo.revisions: @@ -433,10 +433,10 @@ def delete_revisions(self, *revisions: str) -> DeleteCacheStrategy: if len(hashes_to_delete) > 0: logger.warning(f"Revision(s) not found - cannot delete them: {', '.join(hashes_to_delete)}") - delete_strategy_blobs: Set[Path] = set() - delete_strategy_refs: Set[Path] = set() - delete_strategy_repos: Set[Path] = set() - delete_strategy_snapshots: Set[Path] = set() + delete_strategy_blobs: set[Path] = set() + delete_strategy_refs: set[Path] = set() + delete_strategy_repos: set[Path] = set() + delete_strategy_snapshots: set[Path] = set() delete_strategy_expected_freed_size = 0 for affected_repo, revisions_to_delete in repos_with_revisions.items(): @@ -681,8 +681,8 @@ def scan_cache_dir(cache_dir: Optional[Union[str, Path]] = None) -> HFCacheInfo: f"Scan cache expects a directory but found a file: {cache_dir}. Please use `cache_dir` argument or set `HF_HUB_CACHE` environment variable." ) - repos: Set[CachedRepoInfo] = set() - warnings: List[CorruptedCacheException] = [] + repos: set[CachedRepoInfo] = set() + warnings: list[CorruptedCacheException] = [] for repo_path in cache_dir.iterdir(): if repo_path.name == ".locks": # skip './.locks/' folder continue @@ -718,7 +718,7 @@ def _scan_cached_repo(repo_path: Path) -> CachedRepoInfo: f"Repo type must be `dataset`, `model` or `space`, found `{repo_type}` ({repo_path})." ) - blob_stats: Dict[Path, os.stat_result] = {} # Key is blob_path, value is blob stats + blob_stats: dict[Path, os.stat_result] = {} # Key is blob_path, value is blob stats snapshots_path = repo_path / "snapshots" refs_path = repo_path / "refs" @@ -729,7 +729,7 @@ def _scan_cached_repo(repo_path: Path) -> CachedRepoInfo: # Scan over `refs` directory # key is revision hash, value is set of refs - refs_by_hash: Dict[str, Set[str]] = defaultdict(set) + refs_by_hash: dict[str, set[str]] = defaultdict(set) if refs_path.exists(): # Example of `refs` directory # ── refs @@ -752,7 +752,7 @@ def _scan_cached_repo(repo_path: Path) -> CachedRepoInfo: refs_by_hash[commit_hash].add(ref_name) # Scan snapshots directory - cached_revisions: Set[CachedRevisionInfo] = set() + cached_revisions: set[CachedRevisionInfo] = set() for revision_path in snapshots_path.iterdir(): # Ignore OS-created helper files if revision_path.name in FILES_TO_IGNORE: diff --git a/src/huggingface_hub/utils/_deprecation.py b/src/huggingface_hub/utils/_deprecation.py index 4cb8d6e418..51063879db 100644 --- a/src/huggingface_hub/utils/_deprecation.py +++ b/src/huggingface_hub/utils/_deprecation.py @@ -62,7 +62,7 @@ def _deprecate_arguments( Args: version (`str`): The version when deprecated arguments will result in error. - deprecated_args (`List[str]`): + deprecated_args (`list[str]`): List of the arguments to be deprecated. custom_message (`str`, *optional*): Warning message that is raised. If not passed, a default warning message diff --git a/src/huggingface_hub/utils/_dotenv.py b/src/huggingface_hub/utils/_dotenv.py index 23b8a1b70a..97e3b885be 100644 --- a/src/huggingface_hub/utils/_dotenv.py +++ b/src/huggingface_hub/utils/_dotenv.py @@ -1,14 +1,14 @@ # AI-generated module (ChatGPT) import re -from typing import Dict, Optional +from typing import Optional -def load_dotenv(dotenv_str: str, environ: Optional[Dict[str, str]] = None) -> Dict[str, str]: +def load_dotenv(dotenv_str: str, environ: Optional[dict[str, str]] = None) -> dict[str, str]: """ Parse a DOTENV-format string and return a dictionary of key-value pairs. Handles quoted values, comments, export keyword, and blank lines. """ - env: Dict[str, str] = {} + env: dict[str, str] = {} line_pattern = re.compile( r""" ^\s* diff --git a/src/huggingface_hub/utils/_git_credential.py b/src/huggingface_hub/utils/_git_credential.py index 5ad84648a0..7aa03727d4 100644 --- a/src/huggingface_hub/utils/_git_credential.py +++ b/src/huggingface_hub/utils/_git_credential.py @@ -16,7 +16,7 @@ import re import subprocess -from typing import List, Optional +from typing import Optional from ..constants import ENDPOINT from ._subprocess import run_interactive_subprocess, run_subprocess @@ -34,7 +34,7 @@ ) -def list_credential_helpers(folder: Optional[str] = None) -> List[str]: +def list_credential_helpers(folder: Optional[str] = None) -> list[str]: """Return the list of git credential helpers configured. See https://git-scm.com/docs/gitcredentials. @@ -104,7 +104,7 @@ def unset_git_credential(username: str = "hf_user", folder: Optional[str] = None stdin.flush() -def _parse_credential_output(output: str) -> List[str]: +def _parse_credential_output(output: str) -> list[str]: """Parse the output of `git credential fill` to extract the password. Args: diff --git a/src/huggingface_hub/utils/_headers.py b/src/huggingface_hub/utils/_headers.py index 053a92a398..23726b56cc 100644 --- a/src/huggingface_hub/utils/_headers.py +++ b/src/huggingface_hub/utils/_headers.py @@ -14,7 +14,7 @@ # limitations under the License. """Contains utilities to handle headers to send in calls to Huggingface Hub.""" -from typing import Dict, Optional, Union +from typing import Optional, Union from huggingface_hub.errors import LocalTokenNotFoundError @@ -47,10 +47,10 @@ def build_hf_headers( token: Optional[Union[bool, str]] = None, library_name: Optional[str] = None, library_version: Optional[str] = None, - user_agent: Union[Dict, str, None] = None, - headers: Optional[Dict[str, str]] = None, + user_agent: Union[dict, str, None] = None, + headers: Optional[dict[str, str]] = None, is_write_action: bool = False, -) -> Dict[str, str]: +) -> dict[str, str]: """ Build headers dictionary to send in a HF Hub call. @@ -90,7 +90,7 @@ def build_hf_headers( Ignored and deprecated argument. Returns: - A `Dict` of headers to pass in your API call. + A `dict` of headers to pass in your API call. Example: ```py @@ -176,7 +176,7 @@ def _http_user_agent( *, library_name: Optional[str] = None, library_version: Optional[str] = None, - user_agent: Union[Dict, str, None] = None, + user_agent: Union[dict, str, None] = None, ) -> str: """Format a user-agent string containing information about the installed packages. diff --git a/src/huggingface_hub/utils/_http.py b/src/huggingface_hub/utils/_http.py index b3a545c722..15484ec10d 100644 --- a/src/huggingface_hub/utils/_http.py +++ b/src/huggingface_hub/utils/_http.py @@ -24,7 +24,7 @@ from contextlib import contextmanager from http import HTTPStatus from shlex import quote -from typing import Any, Callable, Generator, List, Optional, Tuple, Type, Union +from typing import Any, Callable, Generator, Optional, Union import httpx @@ -260,11 +260,11 @@ def _http_backoff_base( max_retries: int = 5, base_wait_time: float = 1, max_wait_time: float = 8, - retry_on_exceptions: Union[Type[Exception], Tuple[Type[Exception], ...]] = ( + retry_on_exceptions: Union[type[Exception], tuple[type[Exception], ...]] = ( httpx.TimeoutException, httpx.NetworkError, ), - retry_on_status_codes: Union[int, Tuple[int, ...]] = HTTPStatus.SERVICE_UNAVAILABLE, + retry_on_status_codes: Union[int, tuple[int, ...]] = HTTPStatus.SERVICE_UNAVAILABLE, stream: bool = False, **kwargs, ) -> Generator[httpx.Response, None, None]: @@ -345,11 +345,11 @@ def http_backoff( max_retries: int = 5, base_wait_time: float = 1, max_wait_time: float = 8, - retry_on_exceptions: Union[Type[Exception], Tuple[Type[Exception], ...]] = ( + retry_on_exceptions: Union[type[Exception], tuple[type[Exception], ...]] = ( httpx.TimeoutException, httpx.NetworkError, ), - retry_on_status_codes: Union[int, Tuple[int, ...]] = HTTPStatus.SERVICE_UNAVAILABLE, + retry_on_status_codes: Union[int, tuple[int, ...]] = HTTPStatus.SERVICE_UNAVAILABLE, **kwargs, ) -> httpx.Response: """Wrapper around httpx to retry calls on an endpoint, with exponential backoff. @@ -375,10 +375,10 @@ def http_backoff( `max_wait_time`. max_wait_time (`float`, *optional*, defaults to `8`): Maximum duration (in seconds) to wait before retrying. - retry_on_exceptions (`Type[Exception]` or `Tuple[Type[Exception]]`, *optional*): + retry_on_exceptions (`type[Exception]` or `tuple[type[Exception]]`, *optional*): Define which exceptions must be caught to retry the request. Can be a single type or a tuple of types. By default, retry on `httpx.TimeoutException` and `httpx.NetworkError`. - retry_on_status_codes (`int` or `Tuple[int]`, *optional*, defaults to `503`): + retry_on_status_codes (`int` or `tuple[int]`, *optional*, defaults to `503`): Define on which status codes the request must be retried. By default, only HTTP 503 Service Unavailable is retried. **kwargs (`dict`, *optional*): @@ -432,11 +432,11 @@ def http_stream_backoff( max_retries: int = 5, base_wait_time: float = 1, max_wait_time: float = 8, - retry_on_exceptions: Union[Type[Exception], Tuple[Type[Exception], ...]] = ( + retry_on_exceptions: Union[type[Exception], tuple[type[Exception], ...]] = ( httpx.TimeoutException, httpx.NetworkError, ), - retry_on_status_codes: Union[int, Tuple[int, ...]] = HTTPStatus.SERVICE_UNAVAILABLE, + retry_on_status_codes: Union[int, tuple[int, ...]] = HTTPStatus.SERVICE_UNAVAILABLE, **kwargs, ) -> Generator[httpx.Response, None, None]: """Wrapper around httpx to retry calls on an endpoint, with exponential backoff. @@ -462,10 +462,10 @@ def http_stream_backoff( `max_wait_time`. max_wait_time (`float`, *optional*, defaults to `8`): Maximum duration (in seconds) to wait before retrying. - retry_on_exceptions (`Type[Exception]` or `Tuple[Type[Exception]]`, *optional*): + retry_on_exceptions (`type[Exception]` or `tuple[type[Exception]]`, *optional*): Define which exceptions must be caught to retry the request. Can be a single type or a tuple of types. By default, retry on `httpx.Timeout` and `httpx.NetworkError`. - retry_on_status_codes (`int` or `Tuple[int]`, *optional*, defaults to `503`): + retry_on_status_codes (`int` or `tuple[int]`, *optional*, defaults to `503`): Define on which status codes the request must be retried. By default, only HTTP 503 Service Unavailable is retried. **kwargs (`dict`, *optional*): @@ -636,7 +636,7 @@ def hf_raise_for_status(response: httpx.Response, endpoint_name: Optional[str] = raise _format(HfHubHTTPError, str(e), response) from e -def _format(error_type: Type[HfHubHTTPError], custom_message: str, response: httpx.Response) -> HfHubHTTPError: +def _format(error_type: type[HfHubHTTPError], custom_message: str, response: httpx.Response) -> HfHubHTTPError: server_errors = [] # Retrieve server error from header @@ -722,7 +722,7 @@ def _curlify(request: httpx.Request) -> str: Implementation vendored from https://github.com/ofw/curlify/blob/master/curlify.py. MIT License Copyright (c) 2016 Egor. """ - parts: List[Tuple[Any, Any]] = [ + parts: list[tuple[Any, Any]] = [ ("curl", None), ("-X", request.method), ] diff --git a/src/huggingface_hub/utils/_pagination.py b/src/huggingface_hub/utils/_pagination.py index 1d63ad4b49..275d5d5f5a 100644 --- a/src/huggingface_hub/utils/_pagination.py +++ b/src/huggingface_hub/utils/_pagination.py @@ -14,7 +14,7 @@ # limitations under the License. """Contains utilities to handle pagination on Huggingface Hub.""" -from typing import Dict, Iterable, Optional +from typing import Iterable, Optional import httpx @@ -24,7 +24,7 @@ logger = logging.get_logger(__name__) -def paginate(path: str, params: Dict, headers: Dict) -> Iterable: +def paginate(path: str, params: dict, headers: dict) -> Iterable: """Fetch a list of models/datasets/spaces and paginate through results. This is using the same "Link" header format as GitHub. diff --git a/src/huggingface_hub/utils/_paths.py b/src/huggingface_hub/utils/_paths.py index 4f2c0ebce0..f4d48c2cfe 100644 --- a/src/huggingface_hub/utils/_paths.py +++ b/src/huggingface_hub/utils/_paths.py @@ -16,7 +16,7 @@ from fnmatch import fnmatch from pathlib import Path -from typing import Callable, Generator, Iterable, List, Optional, TypeVar, Union +from typing import Callable, Generator, Iterable, Optional, TypeVar, Union T = TypeVar("T") @@ -39,8 +39,8 @@ def filter_repo_objects( items: Iterable[T], *, - allow_patterns: Optional[Union[List[str], str]] = None, - ignore_patterns: Optional[Union[List[str], str]] = None, + allow_patterns: Optional[Union[list[str], str]] = None, + ignore_patterns: Optional[Union[list[str], str]] = None, key: Optional[Callable[[T], str]] = None, ) -> Generator[T, None, None]: """Filter repo objects based on an allowlist and a denylist. @@ -55,10 +55,10 @@ def filter_repo_objects( Args: items (`Iterable`): List of items to filter. - allow_patterns (`str` or `List[str]`, *optional*): + allow_patterns (`str` or `list[str]`, *optional*): Patterns constituting the allowlist. If provided, item paths must match at least one pattern from the allowlist. - ignore_patterns (`str` or `List[str]`, *optional*): + ignore_patterns (`str` or `list[str]`, *optional*): Patterns constituting the denylist. If provided, item paths must not match any patterns from the denylist. key (`Callable[[T], str]`, *optional*): diff --git a/src/huggingface_hub/utils/_runtime.py b/src/huggingface_hub/utils/_runtime.py index 9e38e6da74..9d52091fc9 100644 --- a/src/huggingface_hub/utils/_runtime.py +++ b/src/huggingface_hub/utils/_runtime.py @@ -19,7 +19,7 @@ import platform import sys import warnings -from typing import Any, Dict +from typing import Any from .. import __version__, constants @@ -312,7 +312,7 @@ def is_colab_enterprise() -> bool: return os.environ.get("VERTEX_PRODUCT") == "COLAB_ENTERPRISE" -def dump_environment_info() -> Dict[str, Any]: +def dump_environment_info() -> dict[str, Any]: """Dump information about the machine to help debugging issues. Similar helper exist in: @@ -326,7 +326,7 @@ def dump_environment_info() -> Dict[str, Any]: token = get_token() # Generic machine info - info: Dict[str, Any] = { + info: dict[str, Any] = { "huggingface_hub version": get_hf_hub_version(), "Platform": platform.platform(), "Python version": get_python_version(), diff --git a/src/huggingface_hub/utils/_safetensors.py b/src/huggingface_hub/utils/_safetensors.py index 38546c6d34..8b9c257055 100644 --- a/src/huggingface_hub/utils/_safetensors.py +++ b/src/huggingface_hub/utils/_safetensors.py @@ -2,7 +2,7 @@ import operator from collections import defaultdict from dataclasses import dataclass, field -from typing import Dict, List, Literal, Optional, Tuple +from typing import Literal, Optional FILENAME_T = str @@ -19,17 +19,17 @@ class TensorInfo: Attributes: dtype (`str`): The data type of the tensor ("F64", "F32", "F16", "BF16", "I64", "I32", "I16", "I8", "U8", "BOOL"). - shape (`List[int]`): + shape (`list[int]`): The shape of the tensor. - data_offsets (`Tuple[int, int]`): + data_offsets (`tuple[int, int]`): The offsets of the data in the file as a tuple `[BEGIN, END]`. parameter_count (`int`): The number of parameters in the tensor. """ dtype: DTYPE_T - shape: List[int] - data_offsets: Tuple[int, int] + shape: list[int] + data_offsets: tuple[int, int] parameter_count: int = field(init=False) def __post_init__(self) -> None: @@ -49,22 +49,22 @@ class SafetensorsFileMetadata: For more details regarding the safetensors format, check out https://huggingface.co/docs/safetensors/index#format. Attributes: - metadata (`Dict`): + metadata (`dict`): The metadata contained in the file. - tensors (`Dict[str, TensorInfo]`): + tensors (`dict[str, TensorInfo]`): A map of all tensors. Keys are tensor names and values are information about the corresponding tensor, as a [`TensorInfo`] object. - parameter_count (`Dict[str, int]`): + parameter_count (`dict[str, int]`): A map of the number of parameters per data type. Keys are data types and values are the number of parameters of that data type. """ - metadata: Dict[str, str] - tensors: Dict[TENSOR_NAME_T, TensorInfo] - parameter_count: Dict[DTYPE_T, int] = field(init=False) + metadata: dict[str, str] + tensors: dict[TENSOR_NAME_T, TensorInfo] + parameter_count: dict[DTYPE_T, int] = field(init=False) def __post_init__(self) -> None: - parameter_count: Dict[DTYPE_T, int] = defaultdict(int) + parameter_count: dict[DTYPE_T, int] = defaultdict(int) for tensor in self.tensors.values(): parameter_count[tensor.dtype] += tensor.parameter_count self.parameter_count = dict(parameter_count) @@ -82,29 +82,29 @@ class SafetensorsRepoMetadata: For more details regarding the safetensors format, check out https://huggingface.co/docs/safetensors/index#format. Attributes: - metadata (`Dict`, *optional*): + metadata (`dict`, *optional*): The metadata contained in the 'model.safetensors.index.json' file, if it exists. Only populated for sharded models. sharded (`bool`): Whether the repo contains a sharded model or not. - weight_map (`Dict[str, str]`): + weight_map (`dict[str, str]`): A map of all weights. Keys are tensor names and values are filenames of the files containing the tensors. - files_metadata (`Dict[str, SafetensorsFileMetadata]`): + files_metadata (`dict[str, SafetensorsFileMetadata]`): A map of all files metadata. Keys are filenames and values are the metadata of the corresponding file, as a [`SafetensorsFileMetadata`] object. - parameter_count (`Dict[str, int]`): + parameter_count (`dict[str, int]`): A map of the number of parameters per data type. Keys are data types and values are the number of parameters of that data type. """ - metadata: Optional[Dict] + metadata: Optional[dict] sharded: bool - weight_map: Dict[TENSOR_NAME_T, FILENAME_T] # tensor name -> filename - files_metadata: Dict[FILENAME_T, SafetensorsFileMetadata] # filename -> metadata - parameter_count: Dict[DTYPE_T, int] = field(init=False) + weight_map: dict[TENSOR_NAME_T, FILENAME_T] # tensor name -> filename + files_metadata: dict[FILENAME_T, SafetensorsFileMetadata] # filename -> metadata + parameter_count: dict[DTYPE_T, int] = field(init=False) def __post_init__(self) -> None: - parameter_count: Dict[DTYPE_T, int] = defaultdict(int) + parameter_count: dict[DTYPE_T, int] = defaultdict(int) for file_metadata in self.files_metadata.values(): for dtype, nb_parameters_ in file_metadata.parameter_count.items(): parameter_count[dtype] += nb_parameters_ diff --git a/src/huggingface_hub/utils/_subprocess.py b/src/huggingface_hub/utils/_subprocess.py index fdabf1c4df..e2b9a4f2f1 100644 --- a/src/huggingface_hub/utils/_subprocess.py +++ b/src/huggingface_hub/utils/_subprocess.py @@ -20,7 +20,7 @@ from contextlib import contextmanager from io import StringIO from pathlib import Path -from typing import IO, Generator, List, Optional, Tuple, Union +from typing import IO, Generator, Optional, Union from .logging import get_logger @@ -51,7 +51,7 @@ def capture_output() -> Generator[StringIO, None, None]: def run_subprocess( - command: Union[str, List[str]], + command: Union[str, list[str]], folder: Optional[Union[str, Path]] = None, check=True, **kwargs, @@ -62,7 +62,7 @@ def run_subprocess( be captured. Args: - command (`str` or `List[str]`): + command (`str` or `list[str]`): The command to execute as a string or list of strings. folder (`str`, *optional*): The folder in which to run the command. Defaults to current working @@ -70,7 +70,7 @@ def run_subprocess( check (`bool`, *optional*, defaults to `True`): Setting `check` to `True` will raise a `subprocess.CalledProcessError` when the subprocess has a non-zero exit code. - kwargs (`Dict[str]`): + kwargs (`dict[str]`): Keyword arguments to be passed to the `subprocess.run` underlying command. Returns: @@ -96,23 +96,23 @@ def run_subprocess( @contextmanager def run_interactive_subprocess( - command: Union[str, List[str]], + command: Union[str, list[str]], folder: Optional[Union[str, Path]] = None, **kwargs, -) -> Generator[Tuple[IO[str], IO[str]], None, None]: +) -> Generator[tuple[IO[str], IO[str]], None, None]: """Run a subprocess in an interactive mode in a context manager. Args: - command (`str` or `List[str]`): + command (`str` or `list[str]`): The command to execute as a string or list of strings. folder (`str`, *optional*): The folder in which to run the command. Defaults to current working directory (from `os.getcwd()`). - kwargs (`Dict[str]`): + kwargs (`dict[str]`): Keyword arguments to be passed to the `subprocess.run` underlying command. Returns: - `Tuple[IO[str], IO[str]]`: A tuple with `stdin` and `stdout` to interact + `tuple[IO[str], IO[str]]`: A tuple with `stdin` and `stdout` to interact with the process (input and output are utf-8 encoded). Example: diff --git a/src/huggingface_hub/utils/_telemetry.py b/src/huggingface_hub/utils/_telemetry.py index 2ba4a6349a..e8f0bd0345 100644 --- a/src/huggingface_hub/utils/_telemetry.py +++ b/src/huggingface_hub/utils/_telemetry.py @@ -1,6 +1,6 @@ from queue import Queue from threading import Lock, Thread -from typing import Dict, Optional, Union +from typing import Optional, Union from urllib.parse import quote from .. import constants, logging @@ -22,7 +22,7 @@ def send_telemetry( *, library_name: Optional[str] = None, library_version: Optional[str] = None, - user_agent: Union[Dict, str, None] = None, + user_agent: Union[dict, str, None] = None, ) -> None: """ Sends telemetry that helps tracking usage of different HF libraries. @@ -98,7 +98,7 @@ def _send_telemetry_in_thread( *, library_name: Optional[str] = None, library_version: Optional[str] = None, - user_agent: Union[Dict, str, None] = None, + user_agent: Union[dict, str, None] = None, ) -> None: """Contains the actual data sending data to the Hub. diff --git a/src/huggingface_hub/utils/_typing.py b/src/huggingface_hub/utils/_typing.py index b8388ca0c0..6fcbe4a530 100644 --- a/src/huggingface_hub/utils/_typing.py +++ b/src/huggingface_hub/utils/_typing.py @@ -15,10 +15,10 @@ """Handle typing imports based on system compatibility.""" import sys -from typing import Any, Callable, List, Literal, Type, TypeVar, Union, get_args, get_origin +from typing import Any, Callable, Literal, Type, TypeVar, Union, get_args, get_origin -UNION_TYPES: List[Any] = [Union] +UNION_TYPES: list[Any] = [Union] if sys.version_info >= (3, 10): from types import UnionType diff --git a/src/huggingface_hub/utils/_validators.py b/src/huggingface_hub/utils/_validators.py index 2a1b473446..8bbb16d87e 100644 --- a/src/huggingface_hub/utils/_validators.py +++ b/src/huggingface_hub/utils/_validators.py @@ -19,7 +19,7 @@ import warnings from functools import wraps from itertools import chain -from typing import Any, Dict +from typing import Any from huggingface_hub.errors import HFValidationError @@ -172,7 +172,7 @@ def validate_repo_id(repo_id: str) -> None: raise HFValidationError(f"Repo_id cannot end by '.git': '{repo_id}'.") -def smoothly_deprecate_proxies(fn_name: str, kwargs: Dict[str, Any]) -> Dict[str, Any]: +def smoothly_deprecate_proxies(fn_name: str, kwargs: dict[str, Any]) -> dict[str, Any]: """Smoothly deprecate `proxies` in the `huggingface_hub` codebase. This function removes the `proxies` key from the kwargs and warns the user that the `proxies` argument is ignored. @@ -203,7 +203,7 @@ def smoothly_deprecate_proxies(fn_name: str, kwargs: Dict[str, Any]) -> Dict[str return new_kwargs -def smoothly_deprecate_use_auth_token(fn_name: str, has_token: bool, kwargs: Dict[str, Any]) -> Dict[str, Any]: +def smoothly_deprecate_use_auth_token(fn_name: str, has_token: bool, kwargs: dict[str, Any]) -> dict[str, Any]: """Smoothly deprecate `use_auth_token` in the `huggingface_hub` codebase. The long-term goal is to remove any mention of `use_auth_token` in the codebase in diff --git a/src/huggingface_hub/utils/_xet.py b/src/huggingface_hub/utils/_xet.py index c49c8f88f0..473c451251 100644 --- a/src/huggingface_hub/utils/_xet.py +++ b/src/huggingface_hub/utils/_xet.py @@ -1,6 +1,6 @@ from dataclasses import dataclass from enum import Enum -from typing import Dict, Optional +from typing import Optional import httpx @@ -63,11 +63,11 @@ def parse_xet_file_data_from_response( ) -def parse_xet_connection_info_from_headers(headers: Dict[str, str]) -> Optional[XetConnectionInfo]: +def parse_xet_connection_info_from_headers(headers: dict[str, str]) -> Optional[XetConnectionInfo]: """ Parse XET connection info from the HTTP headers or return None if not found. Args: - headers (`Dict`): + headers (`dict`): HTTP headers to extract the XET metadata from. Returns: `XetConnectionInfo` or `None`: @@ -92,7 +92,7 @@ def parse_xet_connection_info_from_headers(headers: Dict[str, str]) -> Optional[ def refresh_xet_connection_info( *, file_data: XetFileData, - headers: Dict[str, str], + headers: dict[str, str], ) -> XetConnectionInfo: """ Utilizes the information in the parsed metadata to request the Hub xet connection information. @@ -100,7 +100,7 @@ def refresh_xet_connection_info( Args: file_data: (`XetFileData`): The file data needed to refresh the xet connection information. - headers (`Dict[str, str]`): + headers (`dict[str, str]`): Headers to use for the request, including authorization headers and user agent. Returns: `XetConnectionInfo`: @@ -123,9 +123,9 @@ def fetch_xet_connection_info_from_repo_info( repo_id: str, repo_type: str, revision: Optional[str] = None, - headers: Dict[str, str], + headers: dict[str, str], endpoint: Optional[str] = None, - params: Optional[Dict[str, str]] = None, + params: Optional[dict[str, str]] = None, ) -> XetConnectionInfo: """ Uses the repo info to request a xet access token from Hub. @@ -138,11 +138,11 @@ def fetch_xet_connection_info_from_repo_info( Type of the repo to upload to: `"model"`, `"dataset"` or `"space"`. revision (`str`, `optional`): The revision of the repo to get the token for. - headers (`Dict[str, str]`): + headers (`dict[str, str]`): Headers to use for the request, including authorization headers and user agent. endpoint (`str`, `optional`): The endpoint to use for the request. Defaults to the Hub endpoint. - params (`Dict[str, str]`, `optional`): + params (`dict[str, str]`, `optional`): Additional parameters to pass with the request. Returns: `XetConnectionInfo`: @@ -161,8 +161,8 @@ def fetch_xet_connection_info_from_repo_info( @validate_hf_hub_args def _fetch_xet_connection_info_with_url( url: str, - headers: Dict[str, str], - params: Optional[Dict[str, str]] = None, + headers: dict[str, str], + params: Optional[dict[str, str]] = None, ) -> XetConnectionInfo: """ Requests the xet connection info from the supplied URL. This includes the @@ -170,9 +170,9 @@ def _fetch_xet_connection_info_with_url( Args: url: (`str`): The access token endpoint URL. - headers (`Dict[str, str]`): + headers (`dict[str, str]`): Headers to use for the request, including authorization headers and user agent. - params (`Dict[str, str]`, `optional`): + params (`dict[str, str]`, `optional`): Additional parameters to pass with the request. Returns: `XetConnectionInfo`: diff --git a/src/huggingface_hub/utils/_xet_progress_reporting.py b/src/huggingface_hub/utils/_xet_progress_reporting.py index e47740d5c5..05c87b835a 100644 --- a/src/huggingface_hub/utils/_xet_progress_reporting.py +++ b/src/huggingface_hub/utils/_xet_progress_reporting.py @@ -64,7 +64,7 @@ def format_desc(self, name: str, indent: bool) -> str: return f"{padding}{name.ljust(width)}" - def update_progress(self, total_update: PyTotalProgressUpdate, item_updates: List[PyItemProgressUpdate]): + def update_progress(self, total_update: PyTotalProgressUpdate, item_updates: list[PyItemProgressUpdate]): # Update all the per-item values. for item in item_updates: item_name = item.item_name diff --git a/src/huggingface_hub/utils/insecure_hashlib.py b/src/huggingface_hub/utils/insecure_hashlib.py index 6901b6d647..639e04460b 100644 --- a/src/huggingface_hub/utils/insecure_hashlib.py +++ b/src/huggingface_hub/utils/insecure_hashlib.py @@ -25,14 +25,8 @@ # ``` import functools import hashlib -import sys -if sys.version_info >= (3, 9): - md5 = functools.partial(hashlib.md5, usedforsecurity=False) - sha1 = functools.partial(hashlib.sha1, usedforsecurity=False) - sha256 = functools.partial(hashlib.sha256, usedforsecurity=False) -else: - md5 = hashlib.md5 - sha1 = hashlib.sha1 - sha256 = hashlib.sha256 +md5 = functools.partial(hashlib.md5, usedforsecurity=False) +sha1 = functools.partial(hashlib.sha1, usedforsecurity=False) +sha256 = functools.partial(hashlib.sha256, usedforsecurity=False) diff --git a/src/huggingface_hub/utils/tqdm.py b/src/huggingface_hub/utils/tqdm.py index 46bd0ace67..4d47cafc8f 100644 --- a/src/huggingface_hub/utils/tqdm.py +++ b/src/huggingface_hub/utils/tqdm.py @@ -86,7 +86,7 @@ import warnings from contextlib import contextmanager, nullcontext from pathlib import Path -from typing import ContextManager, Dict, Iterator, Optional, Union +from typing import ContextManager, Iterator, Optional, Union from tqdm.auto import tqdm as old_tqdm @@ -102,7 +102,7 @@ # progress bar visibility through code. By default, progress bars are turned on. -progress_bar_states: Dict[str, bool] = {} +progress_bar_states: dict[str, bool] = {} def disable_progress_bars(name: Optional[str] = None) -> None: diff --git a/tests/test_dduf.py b/tests/test_dduf.py index 7c4b5afc9f..c4d509ec85 100644 --- a/tests/test_dduf.py +++ b/tests/test_dduf.py @@ -1,7 +1,7 @@ import json import zipfile from pathlib import Path -from typing import Iterable, Tuple, Union +from typing import Iterable, Union import pytest from pytest_mock import MockerFixture @@ -146,7 +146,7 @@ def test_export_folder(self, dummy_folder: Path, mocker: MockerFixture): class TestExportEntries: @pytest.fixture - def dummy_entries(self, tmp_path: Path) -> Iterable[Tuple[str, Union[str, Path, bytes]]]: + def dummy_entries(self, tmp_path: Path) -> Iterable[tuple[str, Union[str, Path, bytes]]]: (tmp_path / "model_index.json").write_text(json.dumps({"foo": "bar"})) (tmp_path / "doesnt_have_to_be_same_name.safetensors").write_bytes(b"this is safetensors content") @@ -157,7 +157,7 @@ def dummy_entries(self, tmp_path: Path) -> Iterable[Tuple[str, Union[str, Path, ] def test_export_entries( - self, tmp_path: Path, dummy_entries: Iterable[Tuple[str, Union[str, Path, bytes]]], mocker: MockerFixture + self, tmp_path: Path, dummy_entries: Iterable[tuple[str, Union[str, Path, bytes]]], mocker: MockerFixture ): mock = mocker.patch("huggingface_hub.serialization._dduf._validate_dduf_structure") export_entries_as_dduf(tmp_path / "dummy.dduf", dummy_entries) diff --git a/tests/test_file_download.py b/tests/test_file_download.py index bb76af9c47..87a2a645b9 100644 --- a/tests/test_file_download.py +++ b/tests/test_file_download.py @@ -19,7 +19,7 @@ import warnings from contextlib import contextmanager from pathlib import Path -from typing import Iterable, List +from typing import Iterable from unittest.mock import Mock, patch import httpx @@ -1008,7 +1008,7 @@ def _iter_content_4() -> Iterable[bytes]: ), ], ) - def test_http_get_with_range_headers(self, caplog, initial_range: str, expected_ranges: List[str]): + def test_http_get_with_range_headers(self, caplog, initial_range: str, expected_ranges: list[str]): def _iter_content_1() -> Iterable[bytes]: yield b"0" * 10 yield b"0" * 10 diff --git a/tests/test_hf_api.py b/tests/test_hf_api.py index ce5c08d2e2..ff7a829baf 100644 --- a/tests/test_hf_api.py +++ b/tests/test_hf_api.py @@ -25,7 +25,7 @@ from dataclasses import fields from io import BytesIO from pathlib import Path -from typing import List, Optional, Set, Union, get_args +from typing import Optional, Union, get_args from unittest.mock import Mock, patch from urllib.parse import quote, urlparse @@ -1758,7 +1758,7 @@ def tearDown(self) -> None: self._api.delete_repo(repo_id=self.repo_id) super().tearDown() - def remote_files(self) -> Set[set]: + def remote_files(self) -> set[set]: return set(self._api.list_repo_files(repo_id=self.repo_id)) def test_delete_single_file(self): @@ -2196,7 +2196,7 @@ def test_dataset_info_with_file_metadata(self): assert files is not None self._check_siblings_metadata(files) - def _check_siblings_metadata(self, files: List[RepoSibling]): + def _check_siblings_metadata(self, files: list[RepoSibling]): """Check requested metadata has been received from the server.""" at_least_one_lfs = False for file in files: @@ -2657,7 +2657,7 @@ def setUp(self) -> None: self.create_commit_mock.return_value.pr_url = None self.api.create_commit = self.create_commit_mock - def _upload_folder_alias(self, **kwargs) -> List[Union[CommitOperationAdd, CommitOperationDelete]]: + def _upload_folder_alias(self, **kwargs) -> list[Union[CommitOperationAdd, CommitOperationDelete]]: """Alias to call `upload_folder` + retrieve the CommitOperation list passed to `create_commit`.""" if "folder_path" not in kwargs: kwargs["folder_path"] = self.cache_dir @@ -4488,7 +4488,7 @@ class HfApiInferenceCatalogTest(HfApiCommonTest): def test_list_inference_catalog(self) -> None: models = self._api.list_inference_catalog() # note: @experimental api # Check that server returns a list[str] => at least if it changes in the future, we'll notice - assert isinstance(models, List) + assert isinstance(models, list) assert len(models) > 0 assert all(isinstance(model, str) for model in models) diff --git a/tests/test_hub_mixin.py b/tests/test_hub_mixin.py index 90582e846d..5b1ccdd574 100644 --- a/tests/test_hub_mixin.py +++ b/tests/test_hub_mixin.py @@ -4,7 +4,7 @@ import unittest from dataclasses import dataclass from pathlib import Path -from typing import Dict, Optional, Union, get_type_hints +from typing import Optional, Union, get_type_hints from unittest.mock import Mock, patch import jedi @@ -58,7 +58,7 @@ def __init__(self, config: ConfigAsDataclass): class DummyModelConfigAsDict(BaseModel, ModelHubMixin): - def __init__(self, config: Dict): + def __init__(self, config: dict): pass @@ -68,7 +68,7 @@ def __init__(self, config: Optional[ConfigAsDataclass] = None): class DummyModelConfigAsOptionalDict(BaseModel, ModelHubMixin): - def __init__(self, config: Optional[Dict] = None): + def __init__(self, config: Optional[dict] = None): pass @@ -85,7 +85,7 @@ def _save_pretrained(self, save_directory: Path) -> None: def _from_pretrained( cls, model_id: Union[str, Path], - config: Optional[Dict] = None, + config: Optional[dict] = None, **kwargs, ) -> "BaseModel": return cls(**kwargs) diff --git a/tests/test_hub_mixin_pytorch.py b/tests/test_hub_mixin_pytorch.py index dd965189fe..1ebde60d42 100644 --- a/tests/test_hub_mixin_pytorch.py +++ b/tests/test_hub_mixin_pytorch.py @@ -4,7 +4,7 @@ import unittest from argparse import Namespace from pathlib import Path -from typing import Any, Dict, Optional, TypeVar +from typing import Any, Optional, TypeVar from unittest.mock import Mock, patch import pytest @@ -89,7 +89,7 @@ def __init__( self.not_jsonable = not_jsonable class DummyModelWithConfigAndKwargs(nn.Module, PyTorchModelHubMixin): - def __init__(self, num_classes: int = 42, state: str = "layernorm", config: Optional[Dict] = None, **kwargs): + def __init__(self, num_classes: int = 42, state: str = "layernorm", config: Optional[dict] = None, **kwargs): super().__init__() class DummyModelWithModelCardAndCustomKwargs( diff --git a/tests/test_inference_client.py b/tests/test_inference_client.py index e2370aa708..cedf6b4b89 100644 --- a/tests/test_inference_client.py +++ b/tests/test_inference_client.py @@ -18,7 +18,6 @@ import string import time from pathlib import Path -from typing import List from unittest.mock import MagicMock, patch import numpy as np @@ -215,7 +214,7 @@ } -def list_clients(task: str) -> List[pytest.param]: +def list_clients(task: str) -> list[pytest.param]: """Get list of clients for a specific task, with proper skip handling.""" clients = [] for provider, tasks in _RECOMMENDED_MODELS_FOR_VCR.items(): diff --git a/tests/test_inference_providers.py b/tests/test_inference_providers.py index 333eb57d33..73346177ae 100644 --- a/tests/test_inference_providers.py +++ b/tests/test_inference_providers.py @@ -1,6 +1,5 @@ import base64 import logging -from typing import Dict from unittest.mock import MagicMock, patch import pytest @@ -1529,7 +1528,7 @@ def test_prepare_payload(self): ), ], ) -def test_recursive_merge(dict1: Dict, dict2: Dict, expected: Dict): +def test_recursive_merge(dict1: dict, dict2: dict, expected: dict): initial_dict1 = dict1.copy() initial_dict2 = dict2.copy() assert recursive_merge(dict1, dict2) == expected @@ -1563,7 +1562,7 @@ def test_recursive_merge(dict1: Dict, dict2: Dict, expected: Dict): ({"a": [None, {"x": None}]}, {"a": [None, {}]}), ], ) -def test_filter_none(data: Dict, expected: Dict): +def test_filter_none(data: dict, expected: dict): """Test that filter_none removes None values from nested dictionaries.""" assert filter_none(data) == expected diff --git a/tests/test_inference_text_generation.py b/tests/test_inference_text_generation.py index 3135172e9d..4bb30e5665 100644 --- a/tests/test_inference_text_generation.py +++ b/tests/test_inference_text_generation.py @@ -4,7 +4,6 @@ # See './src/huggingface_hub/inference/_text_generation.py' for details. import json import unittest -from typing import Dict from unittest.mock import MagicMock, patch import pytest @@ -45,7 +44,7 @@ def test_validation_error(self): raise_text_generation_error(error) -def _mocked_error(payload: Dict) -> MagicMock: +def _mocked_error(payload: dict) -> MagicMock: error = HfHubHTTPError("message", response=MagicMock()) error.response.json.return_value = payload return error diff --git a/tests/test_inference_types.py b/tests/test_inference_types.py index 5f7a5b2a6e..c164877fd7 100644 --- a/tests/test_inference_types.py +++ b/tests/test_inference_types.py @@ -18,8 +18,8 @@ class DummyType(BaseInferenceType): @dataclass_with_extra class DummyNestedType(BaseInferenceType): item: DummyType - items: List[DummyType] - maybe_items: Optional[List[DummyType]] = None + items: List[DummyType] # works both with List and list + maybe_items: Optional[list[DummyType]] = None DUMMY_AS_DICT = {"foo": 42, "bar": "baz"} @@ -97,6 +97,7 @@ def test_parse_nested_class(): def test_all_fields_are_optional(): # all fields are optional => silently accept None if server returns less data than expected instance = DummyNestedType.parse_obj({"maybe_items": [{}, DUMMY_AS_BYTES]}) + assert isinstance(instance, DummyNestedType) assert instance.item is None assert instance.items is None assert len(instance.maybe_items) == 2 diff --git a/tests/test_serialization.py b/tests/test_serialization.py index dad7065de6..6bc74b9962 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -1,7 +1,7 @@ import json import struct from pathlib import Path -from typing import TYPE_CHECKING, Dict, List +from typing import TYPE_CHECKING from unittest.mock import Mock import pytest @@ -57,7 +57,7 @@ def is_dtensor_available(): @pytest.fixture -def dummy_state_dict() -> Dict[str, List[int]]: +def dummy_state_dict() -> dict[str, list[int]]: return { "layer_1": [6], "layer_2": [10], @@ -68,7 +68,7 @@ def dummy_state_dict() -> Dict[str, List[int]]: @pytest.fixture -def torch_state_dict() -> Dict[str, "torch.Tensor"]: +def torch_state_dict() -> dict[str, "torch.Tensor"]: try: import torch @@ -105,7 +105,7 @@ def __init__(self): @pytest.fixture -def torch_state_dict_tensor_subclass() -> Dict[str, "torch.Tensor"]: +def torch_state_dict_tensor_subclass() -> dict[str, "torch.Tensor"]: try: import torch # type: ignore[import] from torch.testing._internal.two_tensor import TwoTensor # type: ignore[import] @@ -124,7 +124,7 @@ def torch_state_dict_tensor_subclass() -> Dict[str, "torch.Tensor"]: @pytest.fixture -def torch_state_dict_shared_layers() -> Dict[str, "torch.Tensor"]: +def torch_state_dict_shared_layers() -> dict[str, "torch.Tensor"]: try: import torch # type: ignore[import] @@ -141,7 +141,7 @@ def torch_state_dict_shared_layers() -> Dict[str, "torch.Tensor"]: @pytest.fixture -def torch_state_dict_shared_layers_tensor_subclass() -> Dict[str, "torch.Tensor"]: +def torch_state_dict_shared_layers_tensor_subclass() -> dict[str, "torch.Tensor"]: try: import torch # type: ignore[import] from torch.testing._internal.two_tensor import TwoTensor # type: ignore[import] @@ -341,14 +341,14 @@ def test_save_torch_model(mocker: MockerFixture, tmp_path: Path) -> None: ) -def test_save_torch_state_dict_not_sharded(tmp_path: Path, torch_state_dict: Dict[str, "torch.Tensor"]) -> None: +def test_save_torch_state_dict_not_sharded(tmp_path: Path, torch_state_dict: dict[str, "torch.Tensor"]) -> None: """Save as safetensors without sharding.""" save_torch_state_dict(torch_state_dict, tmp_path, max_shard_size="1GB") assert (tmp_path / "model.safetensors").is_file() assert not (tmp_path / "model.safetensors.index.json").is_file() -def test_save_torch_state_dict_sharded(tmp_path: Path, torch_state_dict: Dict[str, "torch.Tensor"]) -> None: +def test_save_torch_state_dict_sharded(tmp_path: Path, torch_state_dict: dict[str, "torch.Tensor"]) -> None: """Save as safetensors with sharding.""" save_torch_state_dict(torch_state_dict, tmp_path, max_shard_size=30) assert not (tmp_path / "model.safetensors").is_file() @@ -369,7 +369,7 @@ def test_save_torch_state_dict_sharded(tmp_path: Path, torch_state_dict: Dict[st def test_save_torch_state_dict_unsafe_not_sharded( - tmp_path: Path, caplog: pytest.LogCaptureFixture, torch_state_dict: Dict[str, "torch.Tensor"] + tmp_path: Path, caplog: pytest.LogCaptureFixture, torch_state_dict: dict[str, "torch.Tensor"] ) -> None: """Save as pickle without sharding.""" with caplog.at_level("WARNING"): @@ -382,7 +382,7 @@ def test_save_torch_state_dict_unsafe_not_sharded( @pytest.mark.skipif(not is_wrapper_tensor_subclass_available(), reason="requires torch 2.1 or higher") def test_save_torch_state_dict_tensor_subclass_unsafe_not_sharded( - tmp_path: Path, caplog: pytest.LogCaptureFixture, torch_state_dict_tensor_subclass: Dict[str, "torch.Tensor"] + tmp_path: Path, caplog: pytest.LogCaptureFixture, torch_state_dict_tensor_subclass: dict[str, "torch.Tensor"] ) -> None: """Save as pickle without sharding.""" with caplog.at_level("WARNING"): @@ -399,7 +399,7 @@ def test_save_torch_state_dict_tensor_subclass_unsafe_not_sharded( def test_save_torch_state_dict_shared_layers_tensor_subclass_unsafe_not_sharded( tmp_path: Path, caplog: pytest.LogCaptureFixture, - torch_state_dict_shared_layers_tensor_subclass: Dict[str, "torch.Tensor"], + torch_state_dict_shared_layers_tensor_subclass: dict[str, "torch.Tensor"], ) -> None: """Save as pickle without sharding.""" with caplog.at_level("WARNING"): @@ -413,7 +413,7 @@ def test_save_torch_state_dict_shared_layers_tensor_subclass_unsafe_not_sharded( def test_save_torch_state_dict_unsafe_sharded( - tmp_path: Path, caplog: pytest.LogCaptureFixture, torch_state_dict: Dict[str, "torch.Tensor"] + tmp_path: Path, caplog: pytest.LogCaptureFixture, torch_state_dict: dict[str, "torch.Tensor"] ) -> None: """Save as pickle with sharding.""" # Check logs @@ -439,7 +439,7 @@ def test_save_torch_state_dict_unsafe_sharded( def test_save_torch_state_dict_shared_layers_not_sharded( - tmp_path: Path, torch_state_dict_shared_layers: Dict[str, "torch.Tensor"] + tmp_path: Path, torch_state_dict_shared_layers: dict[str, "torch.Tensor"] ) -> None: from safetensors.torch import load_file @@ -461,7 +461,7 @@ def test_save_torch_state_dict_shared_layers_not_sharded( def test_save_torch_state_dict_shared_layers_sharded( - tmp_path: Path, torch_state_dict_shared_layers: Dict[str, "torch.Tensor"] + tmp_path: Path, torch_state_dict_shared_layers: dict[str, "torch.Tensor"] ) -> None: from safetensors.torch import load_file @@ -480,7 +480,7 @@ def test_save_torch_state_dict_shared_layers_sharded( def test_save_torch_state_dict_discard_selected_sharded( - tmp_path: Path, torch_state_dict_shared_layers: Dict[str, "torch.Tensor"] + tmp_path: Path, torch_state_dict_shared_layers: dict[str, "torch.Tensor"] ) -> None: from safetensors.torch import load_file @@ -502,7 +502,7 @@ def test_save_torch_state_dict_discard_selected_sharded( def test_save_torch_state_dict_discard_selected_not_sharded( - tmp_path: Path, torch_state_dict_shared_layers: Dict[str, "torch.Tensor"] + tmp_path: Path, torch_state_dict_shared_layers: dict[str, "torch.Tensor"] ) -> None: from safetensors.torch import load_file @@ -529,7 +529,7 @@ def test_save_torch_state_dict_discard_selected_not_sharded( def test_split_torch_state_dict_into_shards( - tmp_path: Path, torch_state_dict_shared_layers_tensor_subclass: Dict[str, "torch.Tensor"] + tmp_path: Path, torch_state_dict_shared_layers_tensor_subclass: dict[str, "torch.Tensor"] ): # the model size is 72, setting max_shard_size to 32 means we'll shard the file state_dict_split = split_torch_state_dict_into_shards( @@ -540,7 +540,7 @@ def test_split_torch_state_dict_into_shards( assert state_dict_split.is_sharded -def test_save_torch_state_dict_custom_filename(tmp_path: Path, torch_state_dict: Dict[str, "torch.Tensor"]) -> None: +def test_save_torch_state_dict_custom_filename(tmp_path: Path, torch_state_dict: dict[str, "torch.Tensor"]) -> None: """Custom filename pattern is respected.""" # Not sharded save_torch_state_dict(torch_state_dict, tmp_path, filename_pattern="model.variant{suffix}.safetensors") @@ -556,7 +556,7 @@ def test_save_torch_state_dict_custom_filename(tmp_path: Path, torch_state_dict: def test_save_torch_state_dict_delete_existing_files( - tmp_path: Path, torch_state_dict: Dict[str, "torch.Tensor"] + tmp_path: Path, torch_state_dict: dict[str, "torch.Tensor"] ) -> None: """Directory is cleaned before saving new files.""" (tmp_path / "model.safetensors").touch() @@ -590,7 +590,7 @@ def test_save_torch_state_dict_delete_existing_files( def test_save_torch_state_dict_not_main_process( tmp_path: Path, - torch_state_dict: Dict[str, "torch.Tensor"], + torch_state_dict: dict[str, "torch.Tensor"], ) -> None: """ Test that previous files in the directory are not deleted when is_main_process=False. @@ -613,7 +613,7 @@ def test_save_torch_state_dict_not_main_process( @requires("torch") -def test_load_state_dict_from_file(tmp_path: Path, torch_state_dict: Dict[str, "torch.Tensor"]): +def test_load_state_dict_from_file(tmp_path: Path, torch_state_dict: dict[str, "torch.Tensor"]): """Test saving and loading a state dict with both safetensors and pickle formats.""" import torch # type: ignore[import] @@ -637,7 +637,7 @@ def test_load_state_dict_from_file(tmp_path: Path, torch_state_dict: Dict[str, " @requires("torch") def test_load_sharded_state_dict( tmp_path: Path, - torch_state_dict: Dict[str, "torch.Tensor"], + torch_state_dict: dict[str, "torch.Tensor"], dummy_model: "torch.nn.Module", ): """Test saving and loading a sharded state dict.""" @@ -666,7 +666,7 @@ def test_load_sharded_state_dict( @requires("torch") def test_load_from_directory_not_sharded( - tmp_path: Path, torch_state_dict: Dict[str, "torch.Tensor"], dummy_model: "torch.nn.Module" + tmp_path: Path, torch_state_dict: dict[str, "torch.Tensor"], dummy_model: "torch.nn.Module" ): import torch diff --git a/tests/test_utils_cache.py b/tests/test_utils_cache.py index efd8a961f3..a08ad14391 100644 --- a/tests/test_utils_cache.py +++ b/tests/test_utils_cache.py @@ -3,7 +3,7 @@ import time import unittest from pathlib import Path -from typing import Any, List +from typing import Any from unittest.mock import Mock import pytest @@ -854,6 +854,6 @@ def test_format_timesince(self) -> None: ) -def is_sublist(sub: List[Any], full: List[Any]) -> bool: +def is_sublist(sub: list[Any], full: list[Any]) -> bool: it = iter(full) return all(item in it for item in sub) diff --git a/tests/test_utils_paths.py b/tests/test_utils_paths.py index 82ffa174bd..39f311dd5f 100644 --- a/tests/test_utils_paths.py +++ b/tests/test_utils_paths.py @@ -1,7 +1,7 @@ import unittest from dataclasses import dataclass from pathlib import Path -from typing import Any, Callable, List, Optional, Union +from typing import Any, Callable, Optional, Union from huggingface_hub.utils import DEFAULT_IGNORE_PATTERNS, filter_repo_objects @@ -97,10 +97,10 @@ def test_filter_object_with_folder(self) -> None: def _check( self, - items: List[Any], - expected_items: List[Any], - allow_patterns: Optional[Union[List[str], str]] = None, - ignore_patterns: Optional[Union[List[str], str]] = None, + items: list[Any], + expected_items: list[Any], + allow_patterns: Optional[Union[list[str], str]] = None, + ignore_patterns: Optional[Union[list[str], str]] = None, key: Optional[Callable[[Any], str]] = None, ) -> None: """Run `filter_repo_objects` and check output against expected result.""" diff --git a/tests/test_utils_strict_dataclass.py b/tests/test_utils_strict_dataclass.py index 4a4cd6d56c..5eb68161b6 100644 --- a/tests/test_utils_strict_dataclass.py +++ b/tests/test_utils_strict_dataclass.py @@ -1,6 +1,6 @@ import inspect from dataclasses import asdict, astuple, dataclass, is_dataclass -from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Union, get_type_hints +from typing import Any, Literal, Optional, Union, get_type_hints import jedi import pytest @@ -136,18 +136,18 @@ class Config: ("John", Literal["John", "Doe"]), (5, Literal[4, 5, 6]), # List - ([1, 2, 3], List[int]), - ([1, 2, "3"], List[Union[int, str]]), + ([1, 2, 3], list[int]), + ([1, 2, "3"], list[Union[int, str]]), # Tuple - ((1, 2, 3), Tuple[int, int, int]), - ((1, 2, "3"), Tuple[int, int, str]), - ((1, 2, 3, 4), Tuple[int, ...]), + ((1, 2, 3), tuple[int, int, int]), + ((1, 2, "3"), tuple[int, int, str]), + ((1, 2, 3, 4), tuple[int, ...]), # Dict - ({"a": 1, "b": 2}, Dict[str, int]), - ({"a": 1, "b": "2"}, Dict[str, Union[int, str]]), + ({"a": 1, "b": 2}, dict[str, int]), + ({"a": 1, "b": "2"}, dict[str, Union[int, str]]), # Set - ({1, 2, 3}, Set[int]), - ({1, 2, "3"}, Set[Union[int, str]]), + ({1, 2, 3}, set[int]), + ({1, 2, "3"}, set[Union[int, str]]), # Custom classes (DummyClass(), DummyClass), # Any @@ -162,13 +162,13 @@ class Config: (2, DummyClass(), None), ], }, - Dict[ + dict[ str, - List[ - Tuple[ + list[ + tuple[ int, DummyClass, - Optional[Set[Union[int, str],]], + Optional[set[Union[int, str],]], ] ], ], @@ -197,19 +197,19 @@ def test_type_validator_valid(value, type_annotation): ("Ada", Literal["John", "Doe"]), (3, Literal[4, 5, 6]), # List - (5, List[int]), - ([1, 2, "3"], List[int]), + (5, list[int]), + ([1, 2, "3"], list[int]), # Tuple - (5, Tuple[int, int, int]), - ((1, 2, "3"), Tuple[int, int, int]), - ((1, 2, 3, 4), Tuple[int, int, int]), - ((1, 2, "3", 4), Tuple[int, ...]), + (5, tuple[int, int, int]), + ((1, 2, "3"), tuple[int, int, int]), + ((1, 2, 3, 4), tuple[int, int, int]), + ((1, 2, "3", 4), tuple[int, ...]), # Dict - (5, Dict[str, int]), - ({"a": 1, "b": "2"}, Dict[str, int]), + (5, dict[str, int]), + ({"a": 1, "b": "2"}, dict[str, int]), # Set - (5, Set[int]), - ({1, 2, "3"}, Set[int]), + (5, set[int]), + ({1, 2, "3"}, set[int]), # Custom classes (5, DummyClass), ("John", DummyClass), diff --git a/tests/test_xet_download.py b/tests/test_xet_download.py index be59bf125e..24e6877738 100644 --- a/tests/test_xet_download.py +++ b/tests/test_xet_download.py @@ -1,7 +1,6 @@ import os from contextlib import contextmanager from pathlib import Path -from typing import Tuple from unittest.mock import DEFAULT, Mock, patch from huggingface_hub import snapshot_download @@ -318,7 +317,7 @@ def test_download_backward_compatibility(self, tmp_path): connection_info = refresh_xet_connection_info(file_data=xet_file_data, headers={}) - def token_refresher() -> Tuple[str, int]: + def token_refresher() -> tuple[str, int]: connection_info = refresh_xet_connection_info(file_data=xet_file_data, headers={}) return connection_info.access_token, connection_info.expiration_unix_epoch diff --git a/tests/test_xet_upload.py b/tests/test_xet_upload.py index d2f4a8b55f..2db0279c97 100644 --- a/tests/test_xet_upload.py +++ b/tests/test_xet_upload.py @@ -15,7 +15,6 @@ from contextlib import contextmanager from io import BytesIO from pathlib import Path -from typing import Tuple from unittest.mock import MagicMock, patch import pytest @@ -366,7 +365,7 @@ def test_hf_xet_with_token_refresher(self, api, tmp_path, repo_url): # manually construct parameters to hf_xet.download_files and use a locally defined token_refresher function # to verify that token refresh works as expected. - def token_refresher() -> Tuple[str, int]: + def token_refresher() -> tuple[str, int]: # Issue a token refresh by returning a new access token and expiration time new_connection = refresh_xet_connection_info(file_data=xet_filedata, headers=headers) return new_connection.access_token, new_connection.expiration_unix_epoch diff --git a/tests/testing_utils.py b/tests/testing_utils.py index 792f08ad17..bc88840aee 100644 --- a/tests/testing_utils.py +++ b/tests/testing_utils.py @@ -9,7 +9,7 @@ from enum import Enum from functools import wraps from pathlib import Path -from typing import Callable, Optional, Type, TypeVar, Union +from typing import Callable, Optional, TypeVar, Union from unittest.mock import Mock, patch import httpx @@ -302,7 +302,7 @@ def _inner_test_function(*args, **kwargs): return _inner_decorator -def xfail_on_windows(reason: str, raises: Optional[Type[Exception]] = None): +def xfail_on_windows(reason: str, raises: Optional[type[Exception]] = None): """ Decorator to flag tests that we expect to fail on Windows. @@ -312,7 +312,7 @@ def xfail_on_windows(reason: str, raises: Optional[Type[Exception]] = None): Args: reason (`str`): Reason why it should fail. - raises (`Type[Exception]`): + raises (`type[Exception]`): The error type we except to happen. """ diff --git a/utils/check_all_variable.py b/utils/check_all_variable.py index 0754fae7e9..6b093740a9 100644 --- a/utils/check_all_variable.py +++ b/utils/check_all_variable.py @@ -18,7 +18,7 @@ import argparse import re from pathlib import Path -from typing import Dict, List, NoReturn +from typing import NoReturn from huggingface_hub import _SUBMOD_ATTRS @@ -26,7 +26,7 @@ INIT_FILE_PATH = Path(__file__).parents[1] / "src" / "huggingface_hub" / "__init__.py" -def format_all_definition(submod_attrs: Dict[str, List[str]]) -> str: +def format_all_definition(submod_attrs: dict[str, list[str]]) -> str: """ Generate a formatted static __all__ definition with grouped comments. """ @@ -39,7 +39,7 @@ def format_all_definition(submod_attrs: Dict[str, List[str]]) -> str: return "\n".join(lines) -def parse_all_definition(content: str) -> List[str]: +def parse_all_definition(content: str) -> list[str]: """ Extract the current __all__ contents from file content. diff --git a/utils/check_task_parameters.py b/utils/check_task_parameters.py index eec18aeaa9..d8948ef278 100644 --- a/utils/check_task_parameters.py +++ b/utils/check_task_parameters.py @@ -42,7 +42,7 @@ import textwrap from collections import defaultdict from pathlib import Path -from typing import Any, Dict, List, NoReturn, Optional, Set, Tuple +from typing import Any, NoReturn, Optional import libcst as cst from helpers import format_source_code @@ -101,7 +101,7 @@ class DataclassFieldCollector(cst.CSTVisitor): def __init__(self, dataclass_name: str): self.dataclass_name = dataclass_name - self.parameters: Dict[str, Dict[str, str]] = {} + self.parameters: dict[str, dict[str, str]] = {} def visit_ClassDef(self, node: cst.ClassDef) -> None: """Visit class definitions to find the target dataclass.""" @@ -130,7 +130,7 @@ def visit_ClassDef(self, node: cst.ClassDef) -> None: @staticmethod def _extract_docstring( - body_statements: List[cst.CSTNode], + body_statements: list[cst.CSTNode], field_index: int, ) -> str: """Extract the docstring following a field definition.""" @@ -169,7 +169,7 @@ class MethodArgumentsCollector(cst.CSTVisitor): def __init__(self, method_name: str): self.method_name = method_name - self.parameters: Dict[str, Dict[str, str]] = {} + self.parameters: dict[str, dict[str, str]] = {} def visit_FunctionDef(self, node: cst.FunctionDef) -> None: if node.name.value != self.method_name: @@ -194,7 +194,7 @@ def _extract_docstring(self, node: cst.FunctionDef) -> str: return node.body.body[0].body[0].value.evaluated_value return "" - def _parse_docstring_params(self, docstring: str) -> Dict[str, str]: + def _parse_docstring_params(self, docstring: str) -> dict[str, str]: """Parse parameter descriptions from docstring.""" param_docs = {} lines = docstring.split("\n") @@ -230,7 +230,7 @@ def _parse_docstring_params(self, docstring: str) -> Dict[str, str]: class AddImports(cst.CSTTransformer): """Transformer that adds import statements to the module.""" - def __init__(self, imports_to_add: List[cst.BaseStatement]): + def __init__(self, imports_to_add: list[cst.BaseStatement]): self.imports_to_add = imports_to_add self.added = False @@ -265,7 +265,7 @@ def leave_Module( class UpdateParameters(cst.CSTTransformer): """Updates a method's parameters, types, and docstrings.""" - def __init__(self, method_name: str, param_updates: Dict[str, Dict[str, str]]): + def __init__(self, method_name: str, param_updates: dict[str, dict[str, str]]): self.method_name = method_name self.param_updates = param_updates self.found_method = False # Flag to check if the method is found @@ -383,10 +383,10 @@ def _update_docstring_content(self, docstring: str) -> str: def _format_param_docstring( self, param_name: str, - param_info: Dict[str, str], + param_info: dict[str, str], param_indent: str, desc_indent: str, - ) -> List[str]: + ) -> list[str]: """Format the docstring lines for a single parameter.""" # Extract and format the parameter type param_type = param_info["type"] @@ -417,12 +417,12 @@ def _format_param_docstring( def _process_existing_params( self, - docstring_lines: List[str], - params_to_update: Dict[str, Dict[str, str]], + docstring_lines: list[str], + params_to_update: dict[str, dict[str, str]], args_index: int, param_indent: str, desc_indent: str, - ) -> Tuple[List[str], Dict[str, Dict[str, str]]]: + ) -> tuple[list[str], dict[str, dict[str, str]]]: """Update existing parameters in the docstring.""" # track the params that are updated params_updated = params_to_update.copy() @@ -473,12 +473,12 @@ def _process_existing_params( def _add_new_params( self, - docstring_lines: List[str], - new_params: Dict[str, Dict[str, str]], + docstring_lines: list[str], + new_params: dict[str, dict[str, str]], args_index: int, param_indent: str, desc_indent: str, - ) -> List[str]: + ) -> list[str]: """Add new parameters to the docstring.""" # Find the insertion point after existing parameters insertion_index = args_index + 1 @@ -521,7 +521,7 @@ def _check_parameters( parameters_module: cst.Module, method_name: str, parameter_type_name: str, -) -> Dict[str, Dict[str, Any]]: +) -> dict[str, dict[str, Any]]: """ Check for missing parameters and outdated types/docstrings. @@ -571,7 +571,7 @@ def _check_parameters( def _update_parameters( module: cst.Module, method_name: str, - param_updates: Dict[str, Dict[str, str]], + param_updates: dict[str, dict[str, str]], ) -> cst.Module: """ Update method parameters, types and docstrings. @@ -590,21 +590,21 @@ def _update_parameters( def _get_imports_to_add( - parameters: Dict[str, Dict[str, str]], + parameters: dict[str, dict[str, str]], parameters_module: cst.Module, inference_client_module: cst.Module, -) -> Dict[str, List[str]]: +) -> dict[str, list[str]]: """ Get the needed imports for missing parameters. Args: - parameters (Dict[str, Dict[str, str]]): Dictionary of parameters with their type and docstring. + parameters (dict[str, dict[str, str]]): Dictionary of parameters with their type and docstring. eg: {"function_to_apply": {"type": "ClassificationOutputTransform", "docstring": "Function to apply to the input."}} parameters_module (cst.Module): The module where the parameters are defined. inference_client_module (cst.Module): The module of the inference client. Returns: - Dict[str, List[str]]: A dictionary mapping modules to list of types to import. + dict[str, list[str]]: A dictionary mapping modules to list of types to import. eg: {"huggingface_hub.inference._generated.types": ["ClassificationOutputTransform"]} """ # Collect all type names from parameter annotations @@ -630,12 +630,12 @@ def _get_imports_to_add( return needed_imports -def _generate_import_statements(import_dict: Dict[str, List[str]]) -> str: +def _generate_import_statements(import_dict: dict[str, list[str]]) -> str: """ Generate import statements from a dictionary of needed imports. Args: - import_dict (Dict[str, List[str]]): Dictionary mapping modules to list of types to import. + import_dict (dict[str, list[str]]): Dictionary mapping modules to list of types to import. eg: {"typing": ["List", "Dict"], "huggingface_hub.inference._generated.types": ["ClassificationOutputTransform"]} Returns: @@ -658,7 +658,7 @@ def _normalize_docstring(docstring: str) -> str: # TODO: Needs to be improved, maybe using `typing.get_type_hints` instead (we gonna need to access the method though)? -def _collect_type_hints_from_annotation(annotation_str: str) -> Set[str]: +def _collect_type_hints_from_annotation(annotation_str: str) -> set[str]: """ Collect type hints from an annotation string. @@ -666,7 +666,7 @@ def _collect_type_hints_from_annotation(annotation_str: str) -> Set[str]: annotation_str (str): The annotation string. Returns: - Set[str]: A set of type hints. + set[str]: A set of type hints. """ type_string = annotation_str.replace(" ", "") builtin_types = {d for d in dir(builtins) if isinstance(getattr(builtins, d), type)} @@ -699,7 +699,7 @@ def _parse_module_from_file(filepath: Path) -> Optional[cst.Module]: def _check_and_update_parameters( - method_params: Dict[str, str], + method_params: dict[str, str], update: bool, ) -> NoReturn: """ diff --git a/utils/generate_inference_types.py b/utils/generate_inference_types.py index 823e814323..aaa071b8f3 100644 --- a/utils/generate_inference_types.py +++ b/utils/generate_inference_types.py @@ -17,7 +17,7 @@ import argparse import re from pathlib import Path -from typing import Dict, List, Literal, NoReturn, Optional +from typing import Literal, NoReturn, Optional import libcst as cst from helpers import check_and_update_file_content, format_source_code @@ -219,12 +219,12 @@ def _make_optional_fields_default_to_none(content: str): return "\n".join(lines) -def _list_dataclasses(content: str) -> List[str]: +def _list_dataclasses(content: str) -> list[str]: """List all dataclasses defined in the module.""" return INHERITED_DATACLASS_REGEX.findall(content) -def _list_type_aliases(content: str) -> List[str]: +def _list_type_aliases(content: str) -> list[str]: """List all type aliases defined in the module.""" return [alias_class for alias_class, _ in TYPE_ALIAS_REGEX.findall(content)] @@ -234,7 +234,7 @@ def is_deprecated(self, docstring: Optional[str]) -> bool: """Check if a docstring contains @deprecated.""" return docstring is not None and "@deprecated" in docstring.lower() - def get_docstring(self, body: List[cst.BaseStatement]) -> Optional[str]: + def get_docstring(self, body: list[cst.BaseStatement]) -> Optional[str]: """Extract docstring from a body of statements.""" if not body: return None @@ -294,7 +294,7 @@ def fix_inference_classes(content: str, module_name: str) -> str: return content -def create_init_py(dataclasses: Dict[str, List[str]]): +def create_init_py(dataclasses: dict[str, list[str]]): """Create __init__.py file with all dataclasses.""" content = INIT_PY_HEADER content += "\n" @@ -304,14 +304,14 @@ def create_init_py(dataclasses: Dict[str, List[str]]): return content -def add_dataclasses_to_main_init(content: str, dataclasses: Dict[str, List[str]]): +def add_dataclasses_to_main_init(content: str, dataclasses: dict[str, list[str]]): dataclasses_list = sorted({cls for classes in dataclasses.values() for cls in classes}) dataclasses_str = ", ".join(f"'{cls}'" for cls in dataclasses_list) return MAIN_INIT_PY_REGEX.sub(f'"inference._generated.types": [{dataclasses_str}]', content) -def generate_reference_package(dataclasses: Dict[str, List[str]], language: Literal["en", "ko"]) -> str: +def generate_reference_package(dataclasses: dict[str, list[str]], language: Literal["en", "ko"]) -> str: """Generate the reference package content.""" per_task_docs = [] From 7213f97664c3baa96b7a0a7e8bc77de603a16eec Mon Sep 17 00:00:00 2001 From: Lucain Date: Thu, 11 Sep 2025 11:40:09 +0200 Subject: [PATCH 03/19] Remove `HfFolder` and `InferenceAPI` classes (#3344) * Remove HfFolder * Remove InferenceAPI * more recent gradio * bump pytest * fix python 3.9? * install gradio only on python 3.10+ * fix tests * fix tests * fix --- docs/source/de/guides/inference.md | 147 ++--------- docs/source/en/guides/inference.md | 4 - .../en/package_reference/inference_client.md | 13 - docs/source/ko/guides/inference.md | 129 +++------- .../ko/package_reference/inference_client.md | 10 - setup.py | 11 +- src/huggingface_hub/README.md | 239 ------------------ src/huggingface_hub/__init__.py | 8 - src/huggingface_hub/constants.py | 37 --- src/huggingface_hub/hf_api.py | 1 - src/huggingface_hub/inference_api.py | 217 ---------------- src/huggingface_hub/utils/__init__.py | 1 - src/huggingface_hub/utils/_hf_folder.py | 68 ----- tests/test_inference_api.py | 140 ---------- tests/test_utils_headers.py | 2 - tests/test_utils_hf_folder.py | 53 ---- tests/test_webhooks_server.py | 44 ++-- 17 files changed, 88 insertions(+), 1036 deletions(-) delete mode 100644 src/huggingface_hub/inference_api.py delete mode 100644 src/huggingface_hub/utils/_hf_folder.py delete mode 100644 tests/test_inference_api.py delete mode 100644 tests/test_utils_hf_folder.py diff --git a/docs/source/de/guides/inference.md b/docs/source/de/guides/inference.md index 6bd3e111dd..21e06cf070 100644 --- a/docs/source/de/guides/inference.md +++ b/docs/source/de/guides/inference.md @@ -8,7 +8,6 @@ Inferenz ist der Prozess, bei dem ein trainiertes Modell verwendet wird, um Vorh - [Inferenz API](https://huggingface.co/docs/api-inference/index): ein Service, der Ihnen ermöglicht, beschleunigte Inferenz auf der Infrastruktur von Hugging Face kostenlos auszuführen. Dieser Service ist eine schnelle Möglichkeit, um anzufangen, verschiedene Modelle zu testen und AI-Produkte zu prototypisieren. - [Inferenz Endpunkte](https://huggingface.co/inference-endpoints/index): ein Produkt zur einfachen Bereitstellung von Modellen im Produktivbetrieb. Die Inferenz wird von Hugging Face in einer dedizierten, vollständig verwalteten Infrastruktur auf einem Cloud-Anbieter Ihrer Wahl durchgeführt. -Diese Dienste können mit dem [`InferenceClient`] Objekt aufgerufen werden. Dieser fungiert als Ersatz für den älteren [`InferenceApi`] Client und fügt spezielle Unterstützung für Aufgaben und das Ausführen von Inferenz hinzu, sowohl auf [Inferenz API](https://huggingface.co/docs/api-inference/index) als auch auf [Inferenz Endpunkten](https://huggingface.co/docs/inference-endpoints/index). Im Abschnitt [Legacy InferenceAPI client](#legacy-inferenceapi-client) erfahren Sie, wie Sie zum neuen Client migrieren können. @@ -89,34 +88,34 @@ Die Authentifizierung ist NICHT zwingend erforderlich, wenn Sie die Inferenz API Das Ziel von [`InferenceClient`] ist es, die einfachste Schnittstelle zum Ausführen von Inferenzen auf Hugging Face-Modellen bereitzustellen. Es verfügt über eine einfache API, die die gebräuchlichsten Aufgaben unterstützt. Hier ist eine Liste der derzeit unterstützten Aufgaben: -| Domäne | Aufgabe | Unterstützt | Dokumentation | -|--------|--------------------------------|--------------|------------------------------------| -| Audio | [Audio Classification](https://huggingface.co/tasks/audio-classification) | ✅ | [`~InferenceClient.audio_classification`] | -| | [Automatic Speech Recognition](https://huggingface.co/tasks/automatic-speech-recognition) | ✅ | [`~InferenceClient.automatic_speech_recognition`] | -| | [Text-to-Speech](https://huggingface.co/tasks/text-to-speech) | ✅ | [`~InferenceClient.text_to_speech`] | -| Computer Vision | [Image Classification](https://huggingface.co/tasks/image-classification) | ✅ | [`~InferenceClient.image_classification`] | -| | [Image Segmentation](https://huggingface.co/tasks/image-segmentation) | ✅ | [`~InferenceClient.image_segmentation`] | -| | [Image-to-Image](https://huggingface.co/tasks/image-to-image) | ✅ | [`~InferenceClient.image_to_image`] | -| | [Image-to-Text](https://huggingface.co/tasks/image-to-text) | ✅ | [`~InferenceClient.image_to_text`] | -| | [Object Detection](https://huggingface.co/tasks/object-detection) | ✅ | [`~InferenceClient.object_detection`] | -| | [Text-to-Image](https://huggingface.co/tasks/text-to-image) | ✅ | [`~InferenceClient.text_to_image`] | -| | [Zero-Shot-Image-Classification](https://huggingface.co/tasks/zero-shot-image-classification) | ✅ | [`~InferenceClient.zero_shot_image_classification`] | -| Multimodal | [Documentation Question Answering](https://huggingface.co/tasks/document-question-answering) | ✅ | [`~InferenceClient.document_question_answering`] | -| | [Visual Question Answering](https://huggingface.co/tasks/visual-question-answering) | ✅ | [`~InferenceClient.visual_question_answering`] | -| NLP | [Conversational](https://huggingface.co/tasks/conversational) | ✅ | [`~InferenceClient.conversational`] | -| | [Feature Extraction](https://huggingface.co/tasks/feature-extraction) | ✅ | [`~InferenceClient.feature_extraction`] | -| | [Fill Mask](https://huggingface.co/tasks/fill-mask) | ✅ | [`~InferenceClient.fill_mask`] | -| | [Question Answering](https://huggingface.co/tasks/question-answering) | ✅ | [`~InferenceClient.question_answering`] | -| | [Sentence Similarity](https://huggingface.co/tasks/sentence-similarity) | ✅ | [`~InferenceClient.sentence_similarity`] | -| | [Summarization](https://huggingface.co/tasks/summarization) | ✅ | [`~InferenceClient.summarization`] | -| | [Table Question Answering](https://huggingface.co/tasks/table-question-answering) | ✅ | [`~InferenceClient.table_question_answering`] | -| | [Text Classification](https://huggingface.co/tasks/text-classification) | ✅ | [`~InferenceClient.text_classification`] | -| | [Text Generation](https://huggingface.co/tasks/text-generation) | ✅ | [`~InferenceClient.text_generation`] | -| | [Token Classification](https://huggingface.co/tasks/token-classification) | ✅ | [`~InferenceClient.token_classification`] | -| | [Translation](https://huggingface.co/tasks/translation) | ✅ | [`~InferenceClient.translation`] | -| | [Zero Shot Classification](https://huggingface.co/tasks/zero-shot-classification) | ✅ | [`~InferenceClient.zero_shot_classification`] | -| Tabular | [Tabular Classification](https://huggingface.co/tasks/tabular-classification) | ✅ | [`~InferenceClient.tabular_classification`] | -| | [Tabular Regression](https://huggingface.co/tasks/tabular-regression) | ✅ | [`~InferenceClient.tabular_regression`] | +| Domäne | Aufgabe | Unterstützt | Dokumentation | +| --------------- | --------------------------------------------------------------------------------------------- | ----------- | --------------------------------------------------- | +| Audio | [Audio Classification](https://huggingface.co/tasks/audio-classification) | ✅ | [`~InferenceClient.audio_classification`] | +| | [Automatic Speech Recognition](https://huggingface.co/tasks/automatic-speech-recognition) | ✅ | [`~InferenceClient.automatic_speech_recognition`] | +| | [Text-to-Speech](https://huggingface.co/tasks/text-to-speech) | ✅ | [`~InferenceClient.text_to_speech`] | +| Computer Vision | [Image Classification](https://huggingface.co/tasks/image-classification) | ✅ | [`~InferenceClient.image_classification`] | +| | [Image Segmentation](https://huggingface.co/tasks/image-segmentation) | ✅ | [`~InferenceClient.image_segmentation`] | +| | [Image-to-Image](https://huggingface.co/tasks/image-to-image) | ✅ | [`~InferenceClient.image_to_image`] | +| | [Image-to-Text](https://huggingface.co/tasks/image-to-text) | ✅ | [`~InferenceClient.image_to_text`] | +| | [Object Detection](https://huggingface.co/tasks/object-detection) | ✅ | [`~InferenceClient.object_detection`] | +| | [Text-to-Image](https://huggingface.co/tasks/text-to-image) | ✅ | [`~InferenceClient.text_to_image`] | +| | [Zero-Shot-Image-Classification](https://huggingface.co/tasks/zero-shot-image-classification) | ✅ | [`~InferenceClient.zero_shot_image_classification`] | +| Multimodal | [Documentation Question Answering](https://huggingface.co/tasks/document-question-answering) | ✅ | [`~InferenceClient.document_question_answering`] | +| | [Visual Question Answering](https://huggingface.co/tasks/visual-question-answering) | ✅ | [`~InferenceClient.visual_question_answering`] | +| NLP | [Conversational](https://huggingface.co/tasks/conversational) | ✅ | [`~InferenceClient.conversational`] | +| | [Feature Extraction](https://huggingface.co/tasks/feature-extraction) | ✅ | [`~InferenceClient.feature_extraction`] | +| | [Fill Mask](https://huggingface.co/tasks/fill-mask) | ✅ | [`~InferenceClient.fill_mask`] | +| | [Question Answering](https://huggingface.co/tasks/question-answering) | ✅ | [`~InferenceClient.question_answering`] | +| | [Sentence Similarity](https://huggingface.co/tasks/sentence-similarity) | ✅ | [`~InferenceClient.sentence_similarity`] | +| | [Summarization](https://huggingface.co/tasks/summarization) | ✅ | [`~InferenceClient.summarization`] | +| | [Table Question Answering](https://huggingface.co/tasks/table-question-answering) | ✅ | [`~InferenceClient.table_question_answering`] | +| | [Text Classification](https://huggingface.co/tasks/text-classification) | ✅ | [`~InferenceClient.text_classification`] | +| | [Text Generation](https://huggingface.co/tasks/text-generation) | ✅ | [`~InferenceClient.text_generation`] | +| | [Token Classification](https://huggingface.co/tasks/token-classification) | ✅ | [`~InferenceClient.token_classification`] | +| | [Translation](https://huggingface.co/tasks/translation) | ✅ | [`~InferenceClient.translation`] | +| | [Zero Shot Classification](https://huggingface.co/tasks/zero-shot-classification) | ✅ | [`~InferenceClient.zero_shot_classification`] | +| Tabular | [Tabular Classification](https://huggingface.co/tasks/tabular-classification) | ✅ | [`~InferenceClient.tabular_classification`] | +| | [Tabular Regression](https://huggingface.co/tasks/tabular-regression) | ✅ | [`~InferenceClient.tabular_regression`] | @@ -190,93 +189,3 @@ Einige Aufgaben erfordern binäre Eingaben, zum Beispiel bei der Arbeit mit Bild [{'score': 0.9779096841812134, 'label': 'Blenheim spaniel'}, ...] ``` -## Legacy InferenceAPI client - -Der [`InferenceClient`] dient als Ersatz für den veralteten [`InferenceApi`]-Client. Er bietet spezifische Unterstützung für Aufgaben und behandelt Inferenz sowohl auf der [Inferenz API](https://huggingface.co/docs/api-inference/index) als auch auf den [Inferenz Endpunkten](https://huggingface.co/docs/inference-endpoints/index). - -Hier finden Sie eine kurze Anleitung, die Ihnen hilft, von [`InferenceApi`] zu [`InferenceClient`] zu migrieren. - -### Initialisierung - -Ändern Sie von - -```python ->>> from huggingface_hub import InferenceApi ->>> inference = InferenceApi(repo_id="bert-base-uncased", token=API_TOKEN) -``` - -zu - -```python ->>> from huggingface_hub import InferenceClient ->>> inference = InferenceClient(model="bert-base-uncased", token=API_TOKEN) -``` - -### Ausführen einer bestimmten Aufgabe - -Ändern Sie von - -```python ->>> from huggingface_hub import InferenceApi ->>> inference = InferenceApi(repo_id="paraphrase-xlm-r-multilingual-v1", task="feature-extraction") ->>> inference(...) -``` - -zu - -```python ->>> from huggingface_hub import InferenceClient ->>> inference = InferenceClient() ->>> inference.feature_extraction(..., model="paraphrase-xlm-r-multilingual-v1") -``` - - - -Dies ist der empfohlene Weg, um Ihren Code an [`InferenceClient`] anzupassen. Dadurch können Sie von den aufgabenspezifischen Methoden wie `feature_extraction` profitieren. - - - -### Eigene Anfragen ausführen - -Ändern Sie von - -```python ->>> from huggingface_hub import InferenceApi ->>> inference = InferenceApi(repo_id="bert-base-uncased") ->>> inference(inputs="The goal of life is [MASK].") -[{'sequence': 'the goal of life is life.', 'score': 0.10933292657136917, 'token': 2166, 'token_str': 'life'}] -``` -zu - -```python ->>> from huggingface_hub import InferenceClient ->>> client = InferenceClient() ->>> response = client.post(json={"inputs": "The goal of life is [MASK]."}, model="bert-base-uncased") ->>> response.json() -[{'sequence': 'the goal of life is life.', 'score': 0.10933292657136917, 'token': 2166, 'token_str': 'life'}] -``` - -### Mit Parametern ausführen - -Ändern Sie von - -```python ->>> from huggingface_hub import InferenceApi ->>> inference = InferenceApi(repo_id="typeform/distilbert-base-uncased-mnli") ->>> inputs = "Hi, I recently bought a device from your company but it is not working as advertised and I would like to get reimbursed!" ->>> params = {"candidate_labels":["refund", "legal", "faq"]} ->>> inference(inputs, params) -{'sequence': 'Hi, I recently bought a device from your company but it is not working as advertised and I would like to get reimbursed!', 'labels': ['refund', 'faq', 'legal'], 'scores': [0.9378499388694763, 0.04914155602455139, 0.013008488342165947]} -``` - -zu - -```python ->>> from huggingface_hub import InferenceClient ->>> client = InferenceClient() ->>> inputs = "Hi, I recently bought a device from your company but it is not working as advertised and I would like to get reimbursed!" ->>> params = {"candidate_labels":["refund", "legal", "faq"]} ->>> response = client.post(json={"inputs": inputs, "parameters": params}, model="typeform/distilbert-base-uncased-mnli") ->>> response.json() -{'sequence': 'Hi, I recently bought a device from your company but it is not working as advertised and I would like to get reimbursed!', 'labels': ['refund', 'faq', 'legal'], 'scores': [0.9378499388694763, 0.04914155602455139, 0.013008488342165947]} -``` diff --git a/docs/source/en/guides/inference.md b/docs/source/en/guides/inference.md index 23cbfdd8c5..a5e31e833d 100644 --- a/docs/source/en/guides/inference.md +++ b/docs/source/en/guides/inference.md @@ -11,10 +11,6 @@ The `huggingface_hub` library provides a unified interface to run inference acro 2. [Inference Endpoints](https://huggingface.co/docs/inference-endpoints/index): a product to easily deploy models to production. Inference is run by Hugging Face in a dedicated, fully managed infrastructure on a cloud provider of your choice. 3. Local endpoints: you can also run inference with local inference servers like [llama.cpp](https://github.com/ggerganov/llama.cpp), [Ollama](https://ollama.com/), [vLLM](https://github.com/vllm-project/vllm), [LiteLLM](https://docs.litellm.ai/docs/simple_proxy), or [Text Generation Inference (TGI)](https://github.com/huggingface/text-generation-inference) by connecting the client to these local endpoints. -These services can all be called from the [`InferenceClient`] object. It acts as a replacement for the legacy -[`InferenceApi`] client, adding specific support for tasks and third-party providers. -Learn how to migrate to the new client in the [Legacy InferenceAPI client](#legacy-inferenceapi-client) section. - [`InferenceClient`] is a Python client making HTTP calls to our APIs. If you want to make the HTTP calls directly using diff --git a/docs/source/en/package_reference/inference_client.md b/docs/source/en/package_reference/inference_client.md index eae0edc755..1a92641077 100644 --- a/docs/source/en/package_reference/inference_client.md +++ b/docs/source/en/package_reference/inference_client.md @@ -34,16 +34,3 @@ pip install --upgrade huggingface_hub[inference] ## InferenceTimeoutError [[autodoc]] InferenceTimeoutError - -## InferenceAPI - -[`InferenceAPI`] is the legacy way to call the Inference API. The interface is more simplistic and requires knowing -the input parameters and output format for each task. It also lacks the ability to connect to other services like -Inference Endpoints or AWS SageMaker. [`InferenceAPI`] will soon be deprecated so we recommend using [`InferenceClient`] -whenever possible. Check out [this guide](../guides/inference#legacy-inferenceapi-client) to learn how to switch from -[`InferenceAPI`] to [`InferenceClient`] in your scripts. - -[[autodoc]] InferenceApi - - __init__ - - __call__ - - all diff --git a/docs/source/ko/guides/inference.md b/docs/source/ko/guides/inference.md index f3ddb3e795..a6f9e5f0d1 100644 --- a/docs/source/ko/guides/inference.md +++ b/docs/source/ko/guides/inference.md @@ -8,7 +8,6 @@ rendered properly in your Markdown viewer. - [추론 API](https://huggingface.co/docs/api-inference/index): Hugging Face의 인프라에서 가속화된 추론을 실행할 수 있는 서비스로 무료로 제공됩니다. 이 서비스는 추론을 시작하고 다양한 모델을 테스트하며 AI 제품의 프로토타입을 만드는 빠른 방법입니다. - [추론 엔드포인트](https://huggingface.co/docs/inference-endpoints/index): 모델을 제품 환경에 쉽게 배포할 수 있는 제품입니다. 사용자가 선택한 클라우드 환경에서 완전 관리되는 전용 인프라에서 Hugging Face를 통해 추론이 실행됩니다. -이러한 서비스들은 [`InferenceClient`] 객체를 사용하여 호출할 수 있습니다. 이는 이전의 [`InferenceApi`] 클라이언트를 대체하는 역할을 하며, 작업에 대한 특별한 지원을 추가하고 [추론 API](https://huggingface.co/docs/api-inference/index) 및 [추론 엔드포인트](https://huggingface.co/docs/inference-endpoints/index)에서 추론 작업을 처리합니다. 새 클라이언트로의 마이그레이션에 대한 자세한 내용은 [레거시 InferenceAPI 클라이언트](#legacy-inferenceapi-client) 섹션을 참조하세요. @@ -89,35 +88,35 @@ Hugging Face Hub에는 20만 개가 넘는 모델이 있습니다! [`InferenceCl [`InferenceClient`]의 목표는 Hugging Face 모델에서 추론을 실행하기 위한 가장 쉬운 인터페이스를 제공하는 것입니다. 이는 가장 일반적인 작업들을 지원하는 간단한 API를 가지고 있습니다. 현재 지원되는 작업 목록은 다음과 같습니다: -| 도메인 | 작업 | 지원 여부 | 문서 | -|--------|--------------------------------|--------------|------------------------------------| -| 오디오 | [오디오 분류](https://huggingface.co/tasks/audio-classification) | ✅ | [`~InferenceClient.audio_classification`] | -| 오디오 | [오디오 투 오디오](https://huggingface.co/tasks/audio-to-audio) | ✅ | [`~InferenceClient.audio_to_audio`] | -| | [자동 음성 인식](https://huggingface.co/tasks/automatic-speech-recognition) | ✅ | [`~InferenceClient.automatic_speech_recognition`] | -| | [텍스트 투 스피치](https://huggingface.co/tasks/text-to-speech) | ✅ | [`~InferenceClient.text_to_speech`] | -| 컴퓨터 비전 | [이미지 분류](https://huggingface.co/tasks/image-classification) | ✅ | [`~InferenceClient.image_classification`] | -| | [이미지 분할](https://huggingface.co/tasks/image-segmentation) | ✅ | [`~InferenceClient.image_segmentation`] | -| | [이미지 투 이미지](https://huggingface.co/tasks/image-to-image) | ✅ | [`~InferenceClient.image_to_image`] | -| | [이미지 투 텍스트](https://huggingface.co/tasks/image-to-text) | ✅ | [`~InferenceClient.image_to_text`] | -| | [객체 탐지](https://huggingface.co/tasks/object-detection) | ✅ | [`~InferenceClient.object_detection`] | -| | [텍스트 투 이미지](https://huggingface.co/tasks/text-to-image) | ✅ | [`~InferenceClient.text_to_image`] | -| | [제로샷 이미지 분류](https://huggingface.co/tasks/zero-shot-image-classification) | ✅ | [`~InferenceClient.zero_shot_image_classification`] | -| 멀티모달 | [문서 질의 응답](https://huggingface.co/tasks/document-question-answering) | ✅ | [`~InferenceClient.document_question_answering`] | -| | [시각적 질의 응답](https://huggingface.co/tasks/visual-question-answering) | ✅ | [`~InferenceClient.visual_question_answering`] | -| 자연어 처리 | [대화형](https://huggingface.co/tasks/conversational) | ✅ | [`~InferenceClient.conversational`] | -| | [특성 추출](https://huggingface.co/tasks/feature-extraction) | ✅ | [`~InferenceClient.feature_extraction`] | -| | [마스크 채우기](https://huggingface.co/tasks/fill-mask) | ✅ | [`~InferenceClient.fill_mask`] | -| | [질의 응답](https://huggingface.co/tasks/question-answering) | ✅ | [`~InferenceClient.question_answering`] | -| | [문장 유사도](https://huggingface.co/tasks/sentence-similarity) | ✅ | [`~InferenceClient.sentence_similarity`] | -| | [요약](https://huggingface.co/tasks/summarization) | ✅ | [`~InferenceClient.summarization`] | -| | [테이블 질의 응답](https://huggingface.co/tasks/table-question-answering) | ✅ | [`~InferenceClient.table_question_answering`] | -| | [텍스트 분류](https://huggingface.co/tasks/text-classification) | ✅ | [`~InferenceClient.text_classification`] | -| | [텍스트 생성](https://huggingface.co/tasks/text-generation) | ✅ | [`~InferenceClient.text_generation`] | -| | [토큰 분류](https://huggingface.co/tasks/token-classification) | ✅ | [`~InferenceClient.token_classification`] | -| | [번역](https://huggingface.co/tasks/translation) | ✅ | [`~InferenceClient.translation`] | -| | [제로샷 분류](https://huggingface.co/tasks/zero-shot-classification) | ✅ | [`~InferenceClient.zero_shot_classification`] | -| 타블로 | [타블로 작업 분류](https://huggingface.co/tasks/tabular-classification) | ✅ | [`~InferenceClient.tabular_classification`] | -| | [타블로 회귀](https://huggingface.co/tasks/tabular-regression) | ✅ | [`~InferenceClient.tabular_regression`] | +| 도메인 | 작업 | 지원 여부 | 문서 | +| ----------- | --------------------------------------------------------------------------------- | --------- | --------------------------------------------------- | +| 오디오 | [오디오 분류](https://huggingface.co/tasks/audio-classification) | ✅ | [`~InferenceClient.audio_classification`] | +| 오디오 | [오디오 투 오디오](https://huggingface.co/tasks/audio-to-audio) | ✅ | [`~InferenceClient.audio_to_audio`] | +| | [자동 음성 인식](https://huggingface.co/tasks/automatic-speech-recognition) | ✅ | [`~InferenceClient.automatic_speech_recognition`] | +| | [텍스트 투 스피치](https://huggingface.co/tasks/text-to-speech) | ✅ | [`~InferenceClient.text_to_speech`] | +| 컴퓨터 비전 | [이미지 분류](https://huggingface.co/tasks/image-classification) | ✅ | [`~InferenceClient.image_classification`] | +| | [이미지 분할](https://huggingface.co/tasks/image-segmentation) | ✅ | [`~InferenceClient.image_segmentation`] | +| | [이미지 투 이미지](https://huggingface.co/tasks/image-to-image) | ✅ | [`~InferenceClient.image_to_image`] | +| | [이미지 투 텍스트](https://huggingface.co/tasks/image-to-text) | ✅ | [`~InferenceClient.image_to_text`] | +| | [객체 탐지](https://huggingface.co/tasks/object-detection) | ✅ | [`~InferenceClient.object_detection`] | +| | [텍스트 투 이미지](https://huggingface.co/tasks/text-to-image) | ✅ | [`~InferenceClient.text_to_image`] | +| | [제로샷 이미지 분류](https://huggingface.co/tasks/zero-shot-image-classification) | ✅ | [`~InferenceClient.zero_shot_image_classification`] | +| 멀티모달 | [문서 질의 응답](https://huggingface.co/tasks/document-question-answering) | ✅ | [`~InferenceClient.document_question_answering`] | +| | [시각적 질의 응답](https://huggingface.co/tasks/visual-question-answering) | ✅ | [`~InferenceClient.visual_question_answering`] | +| 자연어 처리 | [대화형](https://huggingface.co/tasks/conversational) | ✅ | [`~InferenceClient.conversational`] | +| | [특성 추출](https://huggingface.co/tasks/feature-extraction) | ✅ | [`~InferenceClient.feature_extraction`] | +| | [마스크 채우기](https://huggingface.co/tasks/fill-mask) | ✅ | [`~InferenceClient.fill_mask`] | +| | [질의 응답](https://huggingface.co/tasks/question-answering) | ✅ | [`~InferenceClient.question_answering`] | +| | [문장 유사도](https://huggingface.co/tasks/sentence-similarity) | ✅ | [`~InferenceClient.sentence_similarity`] | +| | [요약](https://huggingface.co/tasks/summarization) | ✅ | [`~InferenceClient.summarization`] | +| | [테이블 질의 응답](https://huggingface.co/tasks/table-question-answering) | ✅ | [`~InferenceClient.table_question_answering`] | +| | [텍스트 분류](https://huggingface.co/tasks/text-classification) | ✅ | [`~InferenceClient.text_classification`] | +| | [텍스트 생성](https://huggingface.co/tasks/text-generation) | ✅ | [`~InferenceClient.text_generation`] | +| | [토큰 분류](https://huggingface.co/tasks/token-classification) | ✅ | [`~InferenceClient.token_classification`] | +| | [번역](https://huggingface.co/tasks/translation) | ✅ | [`~InferenceClient.translation`] | +| | [제로샷 분류](https://huggingface.co/tasks/zero-shot-classification) | ✅ | [`~InferenceClient.zero_shot_classification`] | +| 타블로 | [타블로 작업 분류](https://huggingface.co/tasks/tabular-classification) | ✅ | [`~InferenceClient.tabular_classification`] | +| | [타블로 회귀](https://huggingface.co/tasks/tabular-regression) | ✅ | [`~InferenceClient.tabular_regression`] | @@ -190,73 +189,3 @@ pip install --upgrade huggingface_hub[inference] >>> client.image_classification("https://upload.wikimedia.org/wikipedia/commons/thumb/4/43/Cute_dog.jpg/320px-Cute_dog.jpg") [{'score': 0.9779096841812134, 'label': 'Blenheim spaniel'}, ...] ``` - -## 레거시 InferenceAPI 클라이언트[[legacy-inferenceapi-client]] - -[`InferenceClient`]는 레거시 [`InferenceApi`] 클라이언트를 대체하여 작동합니다. 특정 작업에 대한 지원을 제공하고 [추론 API](https://huggingface.co/docs/api-inference/index) 및 [추론 엔드포인트](https://huggingface.co/docs/inference-endpoints/index)에서 추론을 처리합니다. - -아래는 [`InferenceApi`]에서 [`InferenceClient`]로 마이그레이션하는 데 도움이 되는 간단한 가이드입니다. - -### 초기화[[initialization]] - -변경 전: - -```python ->>> from huggingface_hub import InferenceApi ->>> inference = InferenceApi(repo_id="bert-base-uncased", token=API_TOKEN) -``` - -변경 후: - -```python ->>> from huggingface_hub import InferenceClient ->>> inference = InferenceClient(model="bert-base-uncased", token=API_TOKEN) -``` - -### 특정 작업에서 실행하기[[run-on-a-specific-task]] - -변경 전: - -```python ->>> from huggingface_hub import InferenceApi ->>> inference = InferenceApi(repo_id="paraphrase-xlm-r-multilingual-v1", task="feature-extraction") ->>> inference(...) -``` - -변경 후: - -```python ->>> from huggingface_hub import InferenceClient ->>> inference = InferenceClient() ->>> inference.feature_extraction(..., model="paraphrase-xlm-r-multilingual-v1") -``` - - - -위의 방법은 코드를 [`InferenceClient`]에 맞게 조정하는 권장 방법입니다. 이렇게 하면 `feature_extraction`과 같이 작업에 특화된 메소드를 활용할 수 있습니다. - - - -### 사용자 정의 요청 실행[[run-custom-request]] - -변경 전: - -```python ->>> from huggingface_hub import InferenceApi ->>> inference = InferenceApi(repo_id="bert-base-uncased") ->>> inference(inputs="The goal of life is [MASK].") -[{'sequence': 'the goal of life is life.', 'score': 0.10933292657136917, 'token': 2166, 'token_str': 'life'}] -``` - -### 매개변수와 함께 실행하기[[run-with-parameters]] - -변경 전: - -```python ->>> from huggingface_hub import InferenceApi ->>> inference = InferenceApi(repo_id="typeform/distilbert-base-uncased-mnli") ->>> inputs = "Hi, I recently bought a device from your company but it is not working as advertised and I would like to get reimbursed!" ->>> params = {"candidate_labels":["refund", "legal", "faq"]} ->>> inference(inputs, params) -{'sequence': 'Hi, I recently bought a device from your company but it is not working as advertised and I would like to get reimbursed!', 'labels': ['refund', 'faq', 'legal'], 'scores': [0.9378499388694763, 0.04914155602455139, 0.013008488342165947]} -``` diff --git a/docs/source/ko/package_reference/inference_client.md b/docs/source/ko/package_reference/inference_client.md index 686c9282a9..0930a75351 100644 --- a/docs/source/ko/package_reference/inference_client.md +++ b/docs/source/ko/package_reference/inference_client.md @@ -35,13 +35,3 @@ pip install --upgrade huggingface_hub[inference] ## 반환 유형[[return-types]] 대부분의 작업에 대해, 반환 값은 내장된 유형(string, list, image...)을 갖습니다. 보다 복잡한 유형을 위한 목록은 다음과 같습니다. - - -## 추론 API[[huggingface_hub.InferenceApi]] - -[`InferenceAPI`]는 추론 API를 호출하는 레거시 방식입니다. 이 인터페이스는 더 간단하며 각 작업의 입력 매개변수와 출력 형식을 알아야 합니다. 또한 추론 엔드포인트나 AWS SageMaker와 같은 다른 서비스에 연결할 수 있는 기능이 없습니다. [`InferenceAPI`]는 곧 폐지될 예정이므로 가능한 경우 [`InferenceClient`]를 사용하는 것을 권장합니다. 스크립트에서 [`InferenceAPI`]를 [`InferenceClient`]로 전환하는 방법에 대해 알아보려면 [이 가이드](../guides/inference#legacy-inferenceapi-client)를 참조하세요. - -[[autodoc]] InferenceApi - - __init__ - - __call__ - - all diff --git a/setup.py b/setup.py index ec5ebfbb39..9a755682a6 100644 --- a/setup.py +++ b/setup.py @@ -1,3 +1,5 @@ +import sys + from setuptools import find_packages, setup @@ -77,7 +79,7 @@ def get_version() -> str: + [ "jedi", "Jinja2", - "pytest>=8.1.1,<8.2.2", # at least until 8.2.3 is released with https://github.com/pytest-dev/pytest/pull/12436 + "pytest>=8.4.2", # we need https://github.com/pytest-dev/pytest/pull/12436 "pytest-cov", "pytest-env", "pytest-xdist", @@ -88,13 +90,18 @@ def get_version() -> str: "urllib3<2.0", # VCR.py broken with urllib3 2.0 (see https://urllib3.readthedocs.io/en/stable/v2-migration-guide.html) "soundfile", "Pillow", - "gradio>=4.0.0", # to test webhooks # pin to avoid issue on Python3.12 "requests", # for gradio "numpy", # for embeddings "fastapi", # To build the documentation ] ) +if sys.version_info >= (3, 10): + # We need gradio to test webhooks server + # But gradio 5.0+ only supports python 3.10+ so we don't want to test earlier versions + extras["testing"].append("gradio>=5.0.0") + extras["testing"].append("requests") # see https://github.com/gradio-app/gradio/pull/11830 + # Typing extra dependencies list is duplicated in `.pre-commit-config.yaml` # Please make sure to update the list there when adding a new typing dependency. extras["typing"] = [ diff --git a/src/huggingface_hub/README.md b/src/huggingface_hub/README.md index cd5c1e2beb..b0e5cd65d9 100644 --- a/src/huggingface_hub/README.md +++ b/src/huggingface_hub/README.md @@ -112,242 +112,3 @@ With the `HfApi` class there are methods to query models, datasets, and Spaces b - `space_info()` These lightly wrap around the API Endpoints. Documentation for valid parameters and descriptions can be found [here](https://huggingface.co/docs/hub/endpoints). - - -### Advanced programmatic repository management - -The `Repository` class helps manage both offline Git repositories and Hugging -Face Hub repositories. Using the `Repository` class requires `git` and `git-lfs` -to be installed. - -Instantiate a `Repository` object by calling it with a path to a local Git -clone/repository: - -```python ->>> from huggingface_hub import Repository ->>> repo = Repository("//") -``` - -The `Repository` takes a `clone_from` string as parameter. This can stay as -`None` for offline management, but can also be set to any URL pointing to a Git -repo to clone that repository in the specified directory: - -```python ->>> repo = Repository("huggingface-hub", clone_from="https://github.com/huggingface/huggingface_hub") -``` - -The `clone_from` method can also take any Hugging Face model ID as input, and -will clone that repository: - -```python ->>> repo = Repository("w2v2", clone_from="facebook/wav2vec2-large-960h-lv60") -``` - -If the repository you're cloning is one of yours or one of your organisation's, then having the ability to commit and push to that repository is important. In order to do that, you should make sure to be logged-in using `hf auth login`,: - -```python ->>> repo = Repository("my-model", clone_from="/") -``` - -This works for models, datasets and spaces repositories; but you will need to -explicitely specify the type for the last two options: - -```python ->>> repo = Repository("my-dataset", clone_from="/", repo_type="dataset") -``` - -You can also change between branches: - -```python ->>> repo = Repository("huggingface-hub", clone_from="/", revision='branch1') ->>> repo.git_checkout("branch2") -``` - -The `clone_from` method can also take any Hugging Face model ID as input, and -will clone that repository: - -```python ->>> repo = Repository("w2v2", clone_from="facebook/wav2vec2-large-960h-lv60") -``` - -Finally, you can choose to specify the Git username and email attributed to that -clone directly by using the `git_user` and `git_email` parameters. When -committing to that repository, Git will therefore be aware of who you are and -who will be the author of the commits: - -```python ->>> repo = Repository( -... "my-dataset", -... clone_from="/", -... repo_type="dataset", -... git_user="MyName", -... git_email="me@cool.mail" -... ) -``` - -The repository can be managed through this object, through wrappers of -traditional Git methods: - -- `git_add(pattern: str, auto_lfs_track: bool)`. The `auto_lfs_track` flag - triggers auto tracking of large files (>10MB) with `git-lfs` -- `git_commit(commit_message: str)` -- `git_pull(rebase: bool)` -- `git_push()` -- `git_checkout(branch)` - -The `git_push` method has a parameter `blocking` which is `True` by default. When set to `False`, the push will -happen behind the scenes - which can be helpful if you would like your script to continue on while the push is -happening. - -LFS-tracking methods: - -- `lfs_track(pattern: Union[str, List[str]], filename: bool)`. Setting - `filename` to `True` will use the `--filename` parameter, which will consider - the pattern(s) as filenames, even if they contain special glob characters. -- `lfs_untrack()`. -- `auto_track_large_files()`: automatically tracks files that are larger than - 10MB. Make sure to call this after adding files to the index. - -On top of these unitary methods lie some useful additional methods: - -- `push_to_hub(commit_message)`: consecutively does `git_add`, `git_commit` and - `git_push`. -- `commit(commit_message: str, track_large_files: bool)`: this is a context - manager utility that handles committing to a repository. This automatically - tracks large files (>10Mb) with `git-lfs`. The `track_large_files` argument can - be set to `False` if you wish to ignore that behavior. - -These two methods also have support for the `blocking` parameter. - -Examples using the `commit` context manager: -```python ->>> with Repository("text-files", clone_from="/text-files").commit("My first file :)"): -... with open("file.txt", "w+") as f: -... f.write(json.dumps({"hey": 8})) -``` - -```python ->>> import torch ->>> model = torch.nn.Transformer() ->>> with Repository("torch-model", clone_from="/torch-model").commit("My cool model :)"): -... torch.save(model.state_dict(), "model.pt") - ``` - -### Non-blocking behavior - -The pushing methods have access to a `blocking` boolean parameter to indicate whether the push should happen -asynchronously. - -In order to see if the push has finished or its status code (to spot a failure), one should use the `command_queue` -property on the `Repository` object. - -For example: - -```python -from huggingface_hub import Repository - -repo = Repository("", clone_from="/") - -with repo.commit("Commit message", blocking=False): - # Save data - -last_command = repo.command_queue[-1] - -# Status of the push command -last_command.status -# Will return the status code -# -> -1 will indicate the push is still ongoing -# -> 0 will indicate the push has completed successfully -# -> non-zero code indicates the error code if there was an error - -# if there was an error, the stderr may be inspected -last_command.stderr - -# Whether the command finished or if it is still ongoing -last_command.is_done - -# Whether the command errored-out. -last_command.failed -``` - -When using `blocking=False`, the commands will be tracked and your script will exit only when all pushes are done, even -if other errors happen in your script (a failed push counts as done). - - -### Need to upload very large (>5GB) files? - -To upload large files (>5GB 🔥) from git command-line, you need to install the custom transfer agent -for git-lfs, bundled in this package. - -To install, just run: - -```bash -$ hf lfs-enable-largefiles . -``` - -This should be executed once for each model repo that contains a model file ->5GB. If you just try to push a file bigger than 5GB without running that -command, you will get an error with a message reminding you to run it. - -Finally, there's a `hf lfs-multipart-upload` command but that one -is internal (called by lfs directly) and is not meant to be called by the user. - -
- -## Using the Inference API wrapper - -`huggingface_hub` comes with a wrapper client to make calls to the Inference -API! You can find some examples below, but we encourage you to visit the -Inference API -[documentation](https://api-inference.huggingface.co/docs/python/html/detailed_parameters.html) -to review the specific parameters for the different tasks. - -When you instantiate the wrapper to the Inference API, you specify the model -repository id. The pipeline (`text-classification`, `text-to-speech`, etc) is -automatically extracted from the -[repository](https://huggingface.co/docs/hub/main#how-is-a-models-type-of-inference-api-and-widget-determined), -but you can also override it as shown below. - - -### Examples - -Here is a basic example of calling the Inference API for a `fill-mask` task -using the `bert-base-uncased` model. The `fill-mask` task only expects a string -(or list of strings) as input. - -```python -from huggingface_hub.inference_api import InferenceApi -inference = InferenceApi("bert-base-uncased", token=API_TOKEN) -inference(inputs="The goal of life is [MASK].") ->> [{'sequence': 'the goal of life is life.', 'score': 0.10933292657136917, 'token': 2166, 'token_str': 'life'}] -``` - -This is an example of a task (`question-answering`) which requires a dictionary -as input thas has the `question` and `context` keys. - -```python -inference = InferenceApi("deepset/roberta-base-squad2", token=API_TOKEN) -inputs = {"question":"What's my name?", "context":"My name is Clara and I live in Berkeley."} -inference(inputs) ->> {'score': 0.9326569437980652, 'start': 11, 'end': 16, 'answer': 'Clara'} -``` - -Some tasks might also require additional params in the request. Here is an -example using a `zero-shot-classification` model. - -```python -inference = InferenceApi("typeform/distilbert-base-uncased-mnli", token=API_TOKEN) -inputs = "Hi, I recently bought a device from your company but it is not working as advertised and I would like to get reimbursed!" -params = {"candidate_labels":["refund", "legal", "faq"]} -inference(inputs, params) ->> {'sequence': 'Hi, I recently bought a device from your company but it is not working as advertised and I would like to get reimbursed!', 'labels': ['refund', 'faq', 'legal'], 'scores': [0.9378499388694763, 0.04914155602455139, 0.013008488342165947]} -``` - -Finally, there are some models that might support multiple tasks. For example, -`sentence-transformers` models can do `sentence-similarity` and -`feature-extraction`. You can override the configured task when initializing the -API. - -```python -inference = InferenceApi("bert-base-uncased", task="feature-extraction", token=API_TOKEN) -``` diff --git a/src/huggingface_hub/__init__.py b/src/huggingface_hub/__init__.py index c1d2c6658f..472b0020f5 100644 --- a/src/huggingface_hub/__init__.py +++ b/src/huggingface_hub/__init__.py @@ -471,9 +471,6 @@ "inference._mcp.mcp_client": [ "MCPClient", ], - "inference_api": [ - "InferenceApi", - ], "keras_mixin": [ "KerasModelHubMixin", "from_pretrained_keras", @@ -529,7 +526,6 @@ "CorruptedCacheException", "DeleteCacheStrategy", "HFCacheInfo", - "HfFolder", "HfHubAsyncTransport", "HfHubTransport", "cached_assets_path", @@ -662,7 +658,6 @@ "HfFileSystemFile", "HfFileSystemResolvedPath", "HfFileSystemStreamFile", - "HfFolder", "HfHubAsyncTransport", "HfHubTransport", "ImageClassificationInput", @@ -686,7 +681,6 @@ "ImageToVideoOutput", "ImageToVideoParameters", "ImageToVideoTargetSize", - "InferenceApi", "InferenceClient", "InferenceEndpoint", "InferenceEndpointError", @@ -1501,7 +1495,6 @@ def __dir__(): ) from .inference._mcp.agent import Agent # noqa: F401 from .inference._mcp.mcp_client import MCPClient # noqa: F401 - from .inference_api import InferenceApi # noqa: F401 from .keras_mixin import ( KerasModelHubMixin, # noqa: F401 from_pretrained_keras, # noqa: F401 @@ -1555,7 +1548,6 @@ def __dir__(): CorruptedCacheException, # noqa: F401 DeleteCacheStrategy, # noqa: F401 HFCacheInfo, # noqa: F401 - HfFolder, # noqa: F401 HfHubAsyncTransport, # noqa: F401 HfHubTransport, # noqa: F401 cached_assets_path, # noqa: F401 diff --git a/src/huggingface_hub/constants.py b/src/huggingface_hub/constants.py index c1445ffc9d..20c5b5d970 100644 --- a/src/huggingface_hub/constants.py +++ b/src/huggingface_hub/constants.py @@ -234,43 +234,6 @@ def _as_int(value: Optional[str]) -> Optional[int]: # Allows to add information about the requester in the user-agent (eg. partner name) HF_HUB_USER_AGENT_ORIGIN: Optional[str] = os.environ.get("HF_HUB_USER_AGENT_ORIGIN") -# List frameworks that are handled by the InferenceAPI service. Useful to scan endpoints and check which models are -# deployed and running. Since 95% of the models are using the top 4 frameworks listed below, we scan only those by -# default. We still keep the full list of supported frameworks in case we want to scan all of them. -MAIN_INFERENCE_API_FRAMEWORKS = [ - "diffusers", - "sentence-transformers", - "text-generation-inference", - "transformers", -] - -ALL_INFERENCE_API_FRAMEWORKS = MAIN_INFERENCE_API_FRAMEWORKS + [ - "adapter-transformers", - "allennlp", - "asteroid", - "bertopic", - "doctr", - "espnet", - "fairseq", - "fastai", - "fasttext", - "flair", - "k2", - "keras", - "mindspore", - "nemo", - "open_clip", - "paddlenlp", - "peft", - "pyannote-audio", - "sklearn", - "spacy", - "span-marker", - "speechbrain", - "stanza", - "timm", -] - # If OAuth didn't work after 2 redirects, there's likely a third-party cookie issue in the Space iframe view. # In this case, we redirect the user to the non-iframe view. OAUTH_MAX_REDIRECTS = 2 diff --git a/src/huggingface_hub/hf_api.py b/src/huggingface_hub/hf_api.py index 951a19dcc0..6e209d9055 100644 --- a/src/huggingface_hub/hf_api.py +++ b/src/huggingface_hub/hf_api.py @@ -106,7 +106,6 @@ from .repocard_data import DatasetCardData, ModelCardData, SpaceCardData from .utils import ( DEFAULT_IGNORE_PATTERNS, - HfFolder, # noqa: F401 # kept for backward compatibility LocalTokenNotFoundError, NotASafetensorsRepoError, SafetensorsFileMetadata, diff --git a/src/huggingface_hub/inference_api.py b/src/huggingface_hub/inference_api.py deleted file mode 100644 index 16c2812864..0000000000 --- a/src/huggingface_hub/inference_api.py +++ /dev/null @@ -1,217 +0,0 @@ -import io -from typing import Any, Optional, Union - -from . import constants -from .hf_api import HfApi -from .utils import build_hf_headers, get_session, is_pillow_available, logging, validate_hf_hub_args -from .utils._deprecation import _deprecate_method - - -logger = logging.get_logger(__name__) - - -ALL_TASKS = [ - # NLP - "text-classification", - "token-classification", - "table-question-answering", - "question-answering", - "zero-shot-classification", - "translation", - "summarization", - "conversational", - "feature-extraction", - "text-generation", - "text2text-generation", - "fill-mask", - "sentence-similarity", - # Audio - "text-to-speech", - "automatic-speech-recognition", - "audio-to-audio", - "audio-classification", - "voice-activity-detection", - # Computer vision - "image-classification", - "object-detection", - "image-segmentation", - "text-to-image", - "image-to-image", - # Others - "tabular-classification", - "tabular-regression", -] - - -class InferenceApi: - """Client to configure httpx and make calls to the HuggingFace Inference API. - - Example: - - ```python - >>> from huggingface_hub.inference_api import InferenceApi - - >>> # Mask-fill example - >>> inference = InferenceApi("bert-base-uncased") - >>> inference(inputs="The goal of life is [MASK].") - [{'sequence': 'the goal of life is life.', 'score': 0.10933292657136917, 'token': 2166, 'token_str': 'life'}] - - >>> # Question Answering example - >>> inference = InferenceApi("deepset/roberta-base-squad2") - >>> inputs = { - ... "question": "What's my name?", - ... "context": "My name is Clara and I live in Berkeley.", - ... } - >>> inference(inputs) - {'score': 0.9326569437980652, 'start': 11, 'end': 16, 'answer': 'Clara'} - - >>> # Zero-shot example - >>> inference = InferenceApi("typeform/distilbert-base-uncased-mnli") - >>> inputs = "Hi, I recently bought a device from your company but it is not working as advertised and I would like to get reimbursed!" - >>> params = {"candidate_labels": ["refund", "legal", "faq"]} - >>> inference(inputs, params) - {'sequence': 'Hi, I recently bought a device from your company but it is not working as advertised and I would like to get reimbursed!', 'labels': ['refund', 'faq', 'legal'], 'scores': [0.9378499388694763, 0.04914155602455139, 0.013008488342165947]} - - >>> # Overriding configured task - >>> inference = InferenceApi("bert-base-uncased", task="feature-extraction") - - >>> # Text-to-image - >>> inference = InferenceApi("stabilityai/stable-diffusion-2-1") - >>> inference("cat") - - - >>> # Return as raw response to parse the output yourself - >>> inference = InferenceApi("mio/amadeus") - >>> response = inference("hello world", raw_response=True) - >>> response.headers - {"Content-Type": "audio/flac", ...} - >>> response.content # raw bytes from server - b'(...)' - ``` - """ - - @validate_hf_hub_args - @_deprecate_method( - version="1.0", - message=( - "`InferenceApi` client is deprecated in favor of the more feature-complete `InferenceClient`. Check out" - " this guide to learn how to convert your script to use it:" - " https://huggingface.co/docs/huggingface_hub/guides/inference#legacy-inferenceapi-client." - ), - ) - def __init__( - self, - repo_id: str, - task: Optional[str] = None, - token: Optional[str] = None, - gpu: bool = False, - ): - """Inits headers and API call information. - - Args: - repo_id (``str``): - Id of repository (e.g. `user/bert-base-uncased`). - task (``str``, `optional`, defaults ``None``): - Whether to force a task instead of using task specified in the - repository. - token (`str`, `optional`): - The API token to use as HTTP bearer authorization. This is not - the authentication token. You can find the token in - https://huggingface.co/settings/token. Alternatively, you can - find both your organizations and personal API tokens using - `HfApi().whoami(token)`. - gpu (`bool`, `optional`, defaults `False`): - Whether to use GPU instead of CPU for inference(requires Startup - plan at least). - """ - self.options = {"wait_for_model": True, "use_gpu": gpu} - self.headers = build_hf_headers(token=token) - - # Configure task - model_info = HfApi(token=token).model_info(repo_id=repo_id) - if not model_info.pipeline_tag and not task: - raise ValueError( - "Task not specified in the repository. Please add it to the model card" - " using pipeline_tag" - " (https://huggingface.co/docs#how-is-a-models-type-of-inference-api-and-widget-determined)" - ) - - if task and task != model_info.pipeline_tag: - if task not in ALL_TASKS: - raise ValueError(f"Invalid task {task}. Make sure it's valid.") - - logger.warning( - "You're using a different task than the one specified in the" - " repository. Be sure to know what you're doing :)" - ) - self.task = task - else: - assert model_info.pipeline_tag is not None, "Pipeline tag cannot be None" - self.task = model_info.pipeline_tag - - self.api_url = f"{constants.INFERENCE_ENDPOINT}/pipeline/{self.task}/{repo_id}" - - def __repr__(self): - # Do not add headers to repr to avoid leaking token. - return f"InferenceAPI(api_url='{self.api_url}', task='{self.task}', options={self.options})" - - def __call__( - self, - inputs: Optional[Union[str, dict, list[str], list[list[str]]]] = None, - params: Optional[dict] = None, - data: Optional[bytes] = None, - raw_response: bool = False, - ) -> Any: - """Make a call to the Inference API. - - Args: - inputs (`str` or `dict` or `list[str]` or `list[list[str]]`, *optional*): - Inputs for the prediction. - params (`dict`, *optional*): - Additional parameters for the models. Will be sent as `parameters` in the - payload. - data (`bytes`, *optional*): - Bytes content of the request. In this case, leave `inputs` and `params` empty. - raw_response (`bool`, defaults to `False`): - If `True`, the raw `Response` object is returned. You can parse its content - as preferred. By default, the content is parsed into a more practical format - (json dictionary or PIL Image for example). - """ - # Build payload - payload: dict[str, Any] = { - "options": self.options, - } - if inputs: - payload["inputs"] = inputs - if params: - payload["parameters"] = params - - # Make API call - response = get_session().post(self.api_url, headers=self.headers, json=payload, content=data) - - # Let the user handle the response - if raw_response: - return response - - # By default, parse the response for the user. - content_type = response.headers.get("Content-Type") or "" - if content_type.startswith("image"): - if not is_pillow_available(): - raise ImportError( - f"Task '{self.task}' returned as image but Pillow is not installed." - " Please install it (`pip install Pillow`) or pass" - " `raw_response=True` to get the raw `Response` object and parse" - " the image by yourself." - ) - - from PIL import Image - - return Image.open(io.BytesIO(response.content)) - elif content_type == "application/json": - return response.json() - else: - raise NotImplementedError( - f"{content_type} output type is not implemented yet. You can pass" - " `raw_response=True` to get the raw `Response` object and parse the" - " output by yourself." - ) diff --git a/src/huggingface_hub/utils/__init__.py b/src/huggingface_hub/utils/__init__.py index 52838fe000..bf25d66950 100644 --- a/src/huggingface_hub/utils/__init__.py +++ b/src/huggingface_hub/utils/__init__.py @@ -50,7 +50,6 @@ from ._fixes import SoftTemporaryDirectory, WeakFileLock, yaml_dump from ._git_credential import list_credential_helpers, set_git_credential, unset_git_credential from ._headers import build_hf_headers, get_token_to_send -from ._hf_folder import HfFolder from ._http import ( ASYNC_CLIENT_FACTORY_T, CLIENT_FACTORY_T, diff --git a/src/huggingface_hub/utils/_hf_folder.py b/src/huggingface_hub/utils/_hf_folder.py deleted file mode 100644 index 6418bf2fd2..0000000000 --- a/src/huggingface_hub/utils/_hf_folder.py +++ /dev/null @@ -1,68 +0,0 @@ -# coding=utf-8 -# Copyright 2022-present, the HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Contain helper class to retrieve/store token from/to local cache.""" - -from pathlib import Path -from typing import Optional - -from .. import constants -from ._auth import get_token - - -class HfFolder: - # TODO: deprecate when adapted in transformers/datasets/gradio - # @_deprecate_method(version="1.0", message="Use `huggingface_hub.login` instead.") - @classmethod - def save_token(cls, token: str) -> None: - """ - Save token, creating folder as needed. - - Token is saved in the huggingface home folder. You can configure it by setting - the `HF_HOME` environment variable. - - Args: - token (`str`): - The token to save to the [`HfFolder`] - """ - path_token = Path(constants.HF_TOKEN_PATH) - path_token.parent.mkdir(parents=True, exist_ok=True) - path_token.write_text(token) - - # TODO: deprecate when adapted in transformers/datasets/gradio - # @_deprecate_method(version="1.0", message="Use `huggingface_hub.get_token` instead.") - @classmethod - def get_token(cls) -> Optional[str]: - """ - Get token or None if not existent. - - This method is deprecated in favor of [`huggingface_hub.get_token`] but is kept for backward compatibility. - Its behavior is the same as [`huggingface_hub.get_token`]. - - Returns: - `str` or `None`: The token, `None` if it doesn't exist. - """ - return get_token() - - # TODO: deprecate when adapted in transformers/datasets/gradio - # @_deprecate_method(version="1.0", message="Use `huggingface_hub.logout` instead.") - @classmethod - def delete_token(cls) -> None: - """ - Deletes the token from storage. Does not fail if token does not exist. - """ - try: - Path(constants.HF_TOKEN_PATH).unlink() - except FileNotFoundError: - pass diff --git a/tests/test_inference_api.py b/tests/test_inference_api.py deleted file mode 100644 index a057ec4450..0000000000 --- a/tests/test_inference_api.py +++ /dev/null @@ -1,140 +0,0 @@ -# Copyright 2020 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import unittest -from pathlib import Path -from unittest.mock import patch - -import pytest -from PIL import Image - -from huggingface_hub import hf_hub_download -from huggingface_hub.inference_api import InferenceApi - -from .testing_utils import expect_deprecation, with_production_testing - - -@pytest.mark.vcr -@with_production_testing -class InferenceApiTest(unittest.TestCase): - def read(self, filename: str) -> bytes: - return Path(filename).read_bytes() - - @classmethod - @with_production_testing - def setUpClass(cls) -> None: - cls.image_file = hf_hub_download(repo_id="Narsil/image_dummy", repo_type="dataset", filename="lena.png") - return super().setUpClass() - - @expect_deprecation("huggingface_hub.inference_api") - def test_simple_inference(self): - api = InferenceApi("bert-base-uncased") - inputs = "Hi, I think [MASK] is cool" - results = api(inputs) - self.assertIsInstance(results, list) - - result = results[0] - self.assertIsInstance(result, dict) - self.assertTrue("sequence" in result) - self.assertTrue("score" in result) - - @unittest.skip("Model often not loaded") - @expect_deprecation("huggingface_hub.inference_api") - def test_inference_with_params(self): - api = InferenceApi("typeform/distilbert-base-uncased-mnli") - inputs = "I bought a device but it is not working and I would like to get reimbursed!" - params = {"candidate_labels": ["refund", "legal", "faq"]} - result = api(inputs, params) - self.assertIsInstance(result, dict) - self.assertTrue("sequence" in result) - self.assertTrue("scores" in result) - - @unittest.skip("Model often not loaded") - @expect_deprecation("huggingface_hub.inference_api") - def test_inference_with_dict_inputs(self): - api = InferenceApi("distilbert-base-cased-distilled-squad") - inputs = { - "question": "What's my name?", - "context": "My name is Clara and I live in Berkeley.", - } - result = api(inputs) - self.assertIsInstance(result, dict) - self.assertTrue("score" in result) - self.assertTrue("answer" in result) - - @unittest.skip("Model often not loaded") - @expect_deprecation("huggingface_hub.inference_api") - def test_inference_with_audio(self): - api = InferenceApi("facebook/wav2vec2-base-960h") - file = hf_hub_download( - repo_id="hf-internal-testing/dummy-flac-single-example", - repo_type="dataset", - filename="example.flac", - ) - data = self.read(file) - result = api(data=data) - self.assertIsInstance(result, dict) - self.assertTrue("text" in result, f"We received {result} instead") - - @unittest.skip("Model often not loaded") - @expect_deprecation("huggingface_hub.inference_api") - def test_inference_with_image(self): - api = InferenceApi("google/vit-base-patch16-224") - data = self.read(self.image_file) - result = api(data=data) - self.assertIsInstance(result, list) - for classification in result: - self.assertIsInstance(classification, dict) - self.assertTrue("score" in classification) - self.assertTrue("label" in classification) - - @expect_deprecation("huggingface_hub.inference_api") - def test_text_to_image(self): - api = InferenceApi("stabilityai/stable-diffusion-2-1") - with patch("huggingface_hub.inference_api.get_session") as mock: - mock().post.return_value.headers = {"Content-Type": "image/jpeg"} - mock().post.return_value.content = self.read(self.image_file) - output = api("cat") - self.assertIsInstance(output, Image.Image) - - @expect_deprecation("huggingface_hub.inference_api") - def test_text_to_image_raw_response(self): - api = InferenceApi("stabilityai/stable-diffusion-2-1") - with patch("huggingface_hub.inference_api.get_session") as mock: - mock().post.return_value.headers = {"Content-Type": "image/jpeg"} - mock().post.return_value.content = self.read(self.image_file) - output = api("cat", raw_response=True) - # Raw response is returned - self.assertEqual(output, mock().post.return_value) - - @expect_deprecation("huggingface_hub.inference_api") - def test_inference_overriding_task(self): - api = InferenceApi( - "sentence-transformers/paraphrase-albert-small-v2", - task="feature-extraction", - ) - inputs = "This is an example again" - result = api(inputs) - self.assertIsInstance(result, list) - - @expect_deprecation("huggingface_hub.inference_api") - def test_inference_overriding_invalid_task(self): - with self.assertRaises(ValueError, msg="Invalid task invalid-task. Make sure it's valid."): - InferenceApi("bert-base-uncased", task="invalid-task") - - @expect_deprecation("huggingface_hub.inference_api") - def test_inference_missing_input(self): - api = InferenceApi("deepset/roberta-base-squad2") - result = api({"question": "What's my name?"}) - self.assertIsInstance(result, dict) - self.assertTrue("error" in result) diff --git a/tests/test_utils_headers.py b/tests/test_utils_headers.py index d6c00874e4..ff61cec932 100644 --- a/tests/test_utils_headers.py +++ b/tests/test_utils_headers.py @@ -19,8 +19,6 @@ NO_AUTH_HEADER = {"user-agent": DEFAULT_USER_AGENT} -# @patch("huggingface_hub.utils._headers.HfFolder") -# @handle_injection class TestAuthHeadersUtil(unittest.TestCase): def test_use_auth_token_str(self) -> None: self.assertEqual(build_hf_headers(use_auth_token=FAKE_TOKEN), FAKE_TOKEN_HEADER) diff --git a/tests/test_utils_hf_folder.py b/tests/test_utils_hf_folder.py deleted file mode 100644 index 5857fa4df8..0000000000 --- a/tests/test_utils_hf_folder.py +++ /dev/null @@ -1,53 +0,0 @@ -# Copyright 2020 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Contain tests for `HfFolder` utility.""" - -import os -import unittest -from uuid import uuid4 - -from huggingface_hub.utils import HfFolder - - -def _generate_token() -> str: - return f"token-{uuid4()}" - - -class HfFolderTest(unittest.TestCase): - def test_token_workflow(self): - """ - Test the whole token save/get/delete workflow, - with the desired behavior with respect to non-existent tokens. - """ - token = _generate_token() - HfFolder.save_token(token) - self.assertEqual(HfFolder.get_token(), token) - HfFolder.delete_token() - HfFolder.delete_token() - # ^^ not an error, we test that the - # second call does not fail. - self.assertEqual(HfFolder.get_token(), None) - # test TOKEN in env - self.assertEqual(HfFolder.get_token(), None) - with unittest.mock.patch.dict(os.environ, {"HF_TOKEN": token}): - self.assertEqual(HfFolder.get_token(), token) - - def test_token_strip(self): - """ - Test the workflow when the token is mistakenly finishing with new-line or space character. - """ - token = _generate_token() - HfFolder.save_token(" " + token + "\n") - self.assertEqual(HfFolder.get_token(), token) - HfFolder.delete_token() diff --git a/tests/test_webhooks_server.py b/tests/test_webhooks_server.py index 8284b14bea..c8d5c4c2db 100644 --- a/tests/test_webhooks_server.py +++ b/tests/test_webhooks_server.py @@ -110,28 +110,28 @@ } -def test_deserialize_payload_example_with_comment() -> None: - """Confirm that the test stub can actually be deserialized.""" - payload = WebhookPayload.model_validate(WEBHOOK_PAYLOAD_CREATE_DISCUSSION) - assert payload.event.scope == WEBHOOK_PAYLOAD_CREATE_DISCUSSION["event"]["scope"] - assert payload.comment is not None - assert payload.comment.content == "Add co2 emissions information to the model card" - - -def test_deserialize_payload_example_without_comment() -> None: - """Confirm that the test stub can actually be deserialized.""" - payload = WebhookPayload.model_validate(WEBHOOK_PAYLOAD_UPDATE_DISCUSSION) - assert payload.event.scope == WEBHOOK_PAYLOAD_UPDATE_DISCUSSION["event"]["scope"] - assert payload.comment is None - - -def test_deserialize_payload_example_with_updated_refs() -> None: - """Confirm that the test stub can actually be deserialized.""" - payload = WebhookPayload.model_validate(WEBHOOK_PAYLOAD_WITH_UPDATED_REFS) - assert payload.updatedRefs is not None - assert payload.updatedRefs[0].ref == "refs/pr/5" - assert payload.updatedRefs[0].oldSha is None - assert payload.updatedRefs[0].newSha == "227c78346870a85e5de4fff8a585db68df975406" +@requires("gradio") +class TestWebhookPayload(unittest.TestCase): + def test_deserialize_payload_example_with_comment(self) -> None: + """Confirm that the test stub can actually be deserialized.""" + payload = WebhookPayload.model_validate(WEBHOOK_PAYLOAD_CREATE_DISCUSSION) + assert payload.event.scope == WEBHOOK_PAYLOAD_CREATE_DISCUSSION["event"]["scope"] + assert payload.comment is not None + assert payload.comment.content == "Add co2 emissions information to the model card" + + def test_deserialize_payload_example_without_comment(self) -> None: + """Confirm that the test stub can actually be deserialized.""" + payload = WebhookPayload.model_validate(WEBHOOK_PAYLOAD_UPDATE_DISCUSSION) + assert payload.event.scope == WEBHOOK_PAYLOAD_UPDATE_DISCUSSION["event"]["scope"] + assert payload.comment is None + + def test_deserialize_payload_example_with_updated_refs(self) -> None: + """Confirm that the test stub can actually be deserialized.""" + payload = WebhookPayload.model_validate(WEBHOOK_PAYLOAD_WITH_UPDATED_REFS) + assert payload.updatedRefs is not None + assert payload.updatedRefs[0].ref == "refs/pr/5" + assert payload.updatedRefs[0].oldSha is None + assert payload.updatedRefs[0].newSha == "227c78346870a85e5de4fff8a585db68df975406" @requires("gradio") From 966df2979037144669ce73d09a6624049c573a4e Mon Sep 17 00:00:00 2001 From: Lucain Date: Thu, 11 Sep 2025 17:02:13 +0200 Subject: [PATCH 04/19] [v1.0] Remove more deprecated stuff (#3345) * remove constants.-hf_cache_home * remove smoothly_deprecate_use_auth_token * remove get_token_permission * remove update_repo_visibility * remove is_write_action arg * remove write_permission arg from login methods * new parameter skip_if_logged_in in login methods * Remove resume_download / force_filename parameters * Remove deprecated local_dir_use_symlinks parameter * Remove deprecated language, library, task, tags from list_models * Return commit URL in upload_file/upload_folder (previously url to file/folder on the Hub) * fix upload_file/upload_folder tests * smoothly_deprecate_legacy_arguments everywhere * code quality * fix tests * fix xet tests --- docs/source/de/guides/integrations.md | 17 +- docs/source/en/guides/integrations.md | 20 +- docs/source/en/package_reference/utilities.md | 18 +- docs/source/fr/guides/integrations.md | 4 - docs/source/ko/guides/integrations.md | 5 +- docs/source/ko/package_reference/utilities.md | 20 -- src/huggingface_hub/__init__.py | 6 - src/huggingface_hub/_login.py | 50 ++--- src/huggingface_hub/_snapshot_download.py | 7 +- src/huggingface_hub/commands/download.py | 9 - src/huggingface_hub/constants.py | 1 - src/huggingface_hub/file_download.py | 28 +-- src/huggingface_hub/hf_api.py | 203 +----------------- src/huggingface_hub/hub_mixin.py | 8 - src/huggingface_hub/keras_mixin.py | 1 - src/huggingface_hub/utils/__init__.py | 2 +- src/huggingface_hub/utils/_headers.py | 9 - src/huggingface_hub/utils/_validators.py | 138 ++++-------- tests/test_cli.py | 1 - tests/test_hf_api.py | 108 +++------- tests/test_hub_mixin.py | 7 +- tests/test_hub_mixin_pytorch.py | 9 +- tests/test_keras_integration.py | 4 +- tests/test_utils_headers.py | 22 +- tests/test_utils_validators.py | 59 ----- tests/test_xet_upload.py | 6 +- 26 files changed, 137 insertions(+), 625 deletions(-) diff --git a/docs/source/de/guides/integrations.md b/docs/source/de/guides/integrations.md index 3d792c3b5e..34e9bae3ce 100644 --- a/docs/source/de/guides/integrations.md +++ b/docs/source/de/guides/integrations.md @@ -82,7 +82,7 @@ Obwohl dieser Ansatz flexibel ist, hat er einige Nachteile, insbesondere in Bezu - `token`: zum Herunterladen aus einem privaten Repository - `revision`: zum Herunterladen von einem spezifischen Branch - `cache_dir`: um Dateien in einem spezifischen Verzeichnis zu cachen -- `force_download`/`resume_download`/`local_files_only`: um den Cache wieder zu verwenden oder nicht +- `force_download`/`local_files_only`: um den Cache wieder zu verwenden oder nicht - `api_endpoint`/`proxies`: HTTP-Session konfigurieren Beim Pushen von Modellen werden ähnliche Parameter unterstützt: @@ -203,7 +203,6 @@ class PyTorchModelHubMixin(ModelHubMixin): cache_dir: str, force_download: bool, proxies: Optional[dict], - resume_download: bool, local_files_only: bool, token: Union[str, bool, None], map_location: str = "cpu", # zusätzliches Argument @@ -221,8 +220,6 @@ class PyTorchModelHubMixin(ModelHubMixin): revision=revision, cache_dir=cache_dir, force_download=force_download, - proxies=proxies, - resume_download=resume_download, token=token, local_files_only=local_files_only, ) @@ -242,9 +239,9 @@ Und das war's! Ihre Bibliothek ermöglicht es Benutzern nun, Dateien vom und zum Lassen Sie uns die beiden Ansätze, die wir gesehen haben, schnell mit ihren Vor- und Nachteilen zusammenfassen. Die untenstehende Tabelle ist nur indikativ. Ihr Framework könnte einige Besonderheiten haben, die Sie berücksichtigen müssen. Dieser Leitfaden soll nur Richtlinien und Ideen geben, wie Sie die Integration handhaben können. Kontaktieren Sie uns in jedem Fall, wenn Sie Fragen haben! -| Integration | Mit Helfern | Mit [`ModelHubMixin`] | -|:---:|:---:|:---:| -| Benutzererfahrung | `model = load_from_hub(...)`
`push_to_hub(model, ...)` | `model = MyModel.from_pretrained(...)`
`model.push_to_hub(...)` | -| Flexibilität | Sehr flexibel.
Sie haben die volle Kontrolle über die Implementierung. | Weniger flexibel.
Ihr Framework muss eine Modellklasse haben. | -| Wartung | Mehr Wartung, um Unterstützung für Konfiguration und neue Funktionen hinzuzufügen. Könnte auch das Beheben von Benutzerproblemen erfordern. | Weniger Wartung, da die meisten Interaktionen mit dem Hub in `huggingface_hub` implementiert sind. | -| Dokumentation/Typ-Annotation| Manuell zu schreiben. | Teilweise durch `huggingface_hub` behandelt. | +| Integration | Mit Helfern | Mit [`ModelHubMixin`] | +| :--------------------------: | :-----------------------------------------------------------------------------------------------------------------------------------------: | :------------------------------------------------------------------------------------------------: | +| Benutzererfahrung | `model = load_from_hub(...)`
`push_to_hub(model, ...)` | `model = MyModel.from_pretrained(...)`
`model.push_to_hub(...)` | +| Flexibilität | Sehr flexibel.
Sie haben die volle Kontrolle über die Implementierung. | Weniger flexibel.
Ihr Framework muss eine Modellklasse haben. | +| Wartung | Mehr Wartung, um Unterstützung für Konfiguration und neue Funktionen hinzuzufügen. Könnte auch das Beheben von Benutzerproblemen erfordern. | Weniger Wartung, da die meisten Interaktionen mit dem Hub in `huggingface_hub` implementiert sind. | +| Dokumentation/Typ-Annotation | Manuell zu schreiben. | Teilweise durch `huggingface_hub` behandelt. | diff --git a/docs/source/en/guides/integrations.md b/docs/source/en/guides/integrations.md index cc4431923d..61dace2df4 100644 --- a/docs/source/en/guides/integrations.md +++ b/docs/source/en/guides/integrations.md @@ -244,8 +244,6 @@ class PyTorchModelHubMixin(ModelHubMixin): revision: str, cache_dir: str, force_download: bool, - proxies: Optional[dict], - resume_download: bool, local_files_only: bool, token: Union[str, bool, None], map_location: str = "cpu", # additional argument @@ -265,8 +263,6 @@ class PyTorchModelHubMixin(ModelHubMixin): revision=revision, cache_dir=cache_dir, force_download=force_download, - proxies=proxies, - resume_download=resume_download, token=token, local_files_only=local_files_only, ) @@ -428,11 +424,11 @@ Your framework might have some specificities that you need to address. This guid ideas on how to handle integration. In any case, feel free to contact us if you have any questions! -| Integration | Using helpers | Using [`ModelHubMixin`] | -|:---:|:---:|:---:| -| User experience | `model = load_from_hub(...)`
`push_to_hub(model, ...)` | `model = MyModel.from_pretrained(...)`
`model.push_to_hub(...)` | -| Flexibility | Very flexible.
You fully control the implementation. | Less flexible.
Your framework must have a model class. | -| Maintenance | More maintenance to add support for configuration, and new features. Might also require fixing issues reported by users. | Less maintenance as most of the interactions with the Hub are implemented in `huggingface_hub`. | -| Documentation / Type annotation | To be written manually. | Partially handled by `huggingface_hub`. | -| Download counter | To be handled manually. | Enabled by default if class has a `config` attribute. | -| Model card | To be handled manually | Generated by default with library_name, tags, etc. | +| Integration | Using helpers | Using [`ModelHubMixin`] | +| :-----------------------------: | :----------------------------------------------------------------------------------------------------------------------: | :---------------------------------------------------------------------------------------------: | +| User experience | `model = load_from_hub(...)`
`push_to_hub(model, ...)` | `model = MyModel.from_pretrained(...)`
`model.push_to_hub(...)` | +| Flexibility | Very flexible.
You fully control the implementation. | Less flexible.
Your framework must have a model class. | +| Maintenance | More maintenance to add support for configuration, and new features. Might also require fixing issues reported by users. | Less maintenance as most of the interactions with the Hub are implemented in `huggingface_hub`. | +| Documentation / Type annotation | To be written manually. | Partially handled by `huggingface_hub`. | +| Download counter | To be handled manually. | Enabled by default if class has a `config` attribute. | +| Model card | To be handled manually | Generated by default with library_name, tags, etc. | diff --git a/docs/source/en/package_reference/utilities.md b/docs/source/en/package_reference/utilities.md index df6537297a..a7cc46315d 100644 --- a/docs/source/en/package_reference/utilities.md +++ b/docs/source/en/package_reference/utilities.md @@ -255,20 +255,6 @@ huggingface_hub.utils._validators.HFValidationError: Cannot have -- or .. in rep >>> my_cool_method(repo_id="other..repo..id") huggingface_hub.utils._validators.HFValidationError: Cannot have -- or .. in repo_id: 'other..repo..id'. - ->>> @validate_hf_hub_args -... def my_cool_auth_method(token: str): -... print(token) - ->>> my_cool_auth_method(token="a token") -"a token" - ->>> my_cool_auth_method(use_auth_token="a use_auth_token") -"a use_auth_token" - ->>> my_cool_auth_method(token="a token", use_auth_token="a use_auth_token") -UserWarning: Both `token` and `use_auth_token` are passed (...). `use_auth_token` value will be ignored. -"a token" ``` #### validate_hf_hub_args @@ -288,8 +274,8 @@ validated. [[autodoc]] utils.validate_repo_id -#### smoothly_deprecate_use_auth_token +#### smoothly_deprecate_legacy_arguments Not exactly a validator, but ran as well. -[[autodoc]] utils.smoothly_deprecate_use_auth_token +[[autodoc]] utils.smoothly_deprecate_legacy_arguments diff --git a/docs/source/fr/guides/integrations.md b/docs/source/fr/guides/integrations.md index f2c81a3d17..20dff4a73f 100644 --- a/docs/source/fr/guides/integrations.md +++ b/docs/source/fr/guides/integrations.md @@ -223,8 +223,6 @@ class PyTorchModelHubMixin(ModelHubMixin): revision: str, cache_dir: str, force_download: bool, - proxies: Optional[dict], - resume_download: bool, local_files_only: bool, token: Union[str, bool, None], map_location: str = "cpu", # argument supplémentaire @@ -242,8 +240,6 @@ class PyTorchModelHubMixin(ModelHubMixin): revision=revision, cache_dir=cache_dir, force_download=force_download, - proxies=proxies, - resume_download=resume_download, token=token, local_files_only=local_files_only, ) diff --git a/docs/source/ko/guides/integrations.md b/docs/source/ko/guides/integrations.md index a3ff1750f6..b595a3c630 100644 --- a/docs/source/ko/guides/integrations.md +++ b/docs/source/ko/guides/integrations.md @@ -81,7 +81,7 @@ def push_to_hub(model: MyModelClass, repo_name: str) -> None: - `token`: 개인 리포지토리에서 다운로드하기 위한 토큰 - `revision`: 특정 브랜치에서 다운로드하기 위한 리비전 - `cache_dir`: 특정 디렉터리에 파일을 캐시하기 위한 디렉터리 -- `force_download`/`resume_download`/`local_files_only`: 캐시를 재사용할 것인지 여부를 결정하는 매개변수 +- `force_download`/`local_files_only`: 캐시를 재사용할 것인지 여부를 결정하는 매개변수 - `proxies`: HTTP 세션 구성 모델을 푸시할 때는 유사한 매개변수가 지원됩니다: @@ -212,7 +212,6 @@ class PyTorchModelHubMixin(ModelHubMixin): cache_dir: str, force_download: bool, proxies: Optional[dict], - resume_download: bool, local_files_only: bool, token: Union[str, bool, None], map_location: str = "cpu", # 추가 인자 @@ -232,8 +231,6 @@ class PyTorchModelHubMixin(ModelHubMixin): revision=revision, cache_dir=cache_dir, force_download=force_download, - proxies=proxies, - resume_download=resume_download, token=token, local_files_only=local_files_only, ) diff --git a/docs/source/ko/package_reference/utilities.md b/docs/source/ko/package_reference/utilities.md index 96ac88e432..5743d12015 100644 --- a/docs/source/ko/package_reference/utilities.md +++ b/docs/source/ko/package_reference/utilities.md @@ -199,20 +199,6 @@ huggingface_hub.utils._validators.HFValidationError: Cannot have -- or .. in rep >>> my_cool_method(repo_id="other..repo..id") huggingface_hub.utils._validators.HFValidationError: Cannot have -- or .. in repo_id: 'other..repo..id'. - ->>> @validate_hf_hub_args -... def my_cool_auth_method(token: str): -... print(token) - ->>> my_cool_auth_method(token="a token") -"a token" - ->>> my_cool_auth_method(use_auth_token="a use_auth_token") -"a use_auth_token" - ->>> my_cool_auth_method(token="a token", use_auth_token="a use_auth_token") -UserWarning: Both `token` and `use_auth_token` are passed (...). `use_auth_token` value will be ignored. -"a token" ``` #### validate_hf_hub_args[[huggingface_hub.utils.validate_hf_hub_args]] @@ -230,9 +216,3 @@ UserWarning: Both `token` and `use_auth_token` are passed (...). `use_auth_token #### repo_id[[huggingface_hub.utils.validate_repo_id]] [[autodoc]] utils.validate_repo_id - -#### smoothly_deprecate_use_auth_token[[huggingface_hub.utils.smoothly_deprecate_use_auth_token]] - -정확히 검증기는 아니지만, 잘 실행됩니다. - -[[autodoc]] utils.smoothly_deprecate_use_auth_token diff --git a/src/huggingface_hub/__init__.py b/src/huggingface_hub/__init__.py index 472b0020f5..b534281a5e 100644 --- a/src/huggingface_hub/__init__.py +++ b/src/huggingface_hub/__init__.py @@ -217,7 +217,6 @@ "get_safetensors_metadata", "get_space_runtime", "get_space_variables", - "get_token_permission", "get_user_overview", "get_webhook", "grant_access", @@ -278,7 +277,6 @@ "update_collection_metadata", "update_inference_endpoint", "update_repo_settings", - "update_repo_visibility", "update_webhook", "upload_file", "upload_folder", @@ -885,7 +883,6 @@ "get_space_variables", "get_tf_storage_size", "get_token", - "get_token_permission", "get_torch_storage_id", "get_torch_storage_size", "get_user_overview", @@ -977,7 +974,6 @@ "update_collection_metadata", "update_inference_endpoint", "update_repo_settings", - "update_repo_visibility", "update_webhook", "upload_file", "upload_folder", @@ -1247,7 +1243,6 @@ def __dir__(): get_safetensors_metadata, # noqa: F401 get_space_runtime, # noqa: F401 get_space_variables, # noqa: F401 - get_token_permission, # noqa: F401 get_user_overview, # noqa: F401 get_webhook, # noqa: F401 grant_access, # noqa: F401 @@ -1308,7 +1303,6 @@ def __dir__(): update_collection_metadata, # noqa: F401 update_inference_endpoint, # noqa: F401 update_repo_settings, # noqa: F401 - update_repo_visibility, # noqa: F401 update_webhook, # noqa: F401 upload_file, # noqa: F401 upload_folder, # noqa: F401 diff --git a/src/huggingface_hub/_login.py b/src/huggingface_hub/_login.py index 303cd2b35d..946fd18af2 100644 --- a/src/huggingface_hub/_login.py +++ b/src/huggingface_hub/_login.py @@ -41,7 +41,7 @@ _save_token, get_stored_tokens, ) -from .utils._deprecation import _deprecate_arguments, _deprecate_positional_args +from .utils._deprecation import _deprecate_positional_args logger = logging.get_logger(__name__) @@ -55,18 +55,12 @@ """ -@_deprecate_arguments( - version="1.0", - deprecated_args="write_permission", - custom_message="Fine-grained tokens added complexity to the permissions, making it irrelevant to check if a token has 'write' access.", -) @_deprecate_positional_args(version="1.0") def login( token: Optional[str] = None, *, add_to_git_credential: bool = False, - new_session: bool = True, - write_permission: bool = False, + skip_if_logged_in: bool = False, ) -> None: """Login the machine to access the Hub. @@ -102,10 +96,8 @@ def login( is configured, a warning will be displayed to the user. If `token` is `None`, the value of `add_to_git_credential` is ignored and will be prompted again to the end user. - new_session (`bool`, defaults to `True`): - If `True`, will request a token even if one is already saved on the machine. - write_permission (`bool`): - Ignored and deprecated argument. + skip_if_logged_in (`bool`, defaults to `False`): + If `True`, do not prompt for token if user is already logged in. Raises: [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) If an organization token is passed. Only personal account tokens are valid @@ -125,9 +117,9 @@ def login( ) _login(token, add_to_git_credential=add_to_git_credential) elif is_notebook(): - notebook_login(new_session=new_session) + notebook_login(skip_if_logged_in=skip_if_logged_in) else: - interpreter_login(new_session=new_session) + interpreter_login(skip_if_logged_in=skip_if_logged_in) def logout(token_name: Optional[str] = None) -> None: @@ -242,13 +234,8 @@ def auth_list() -> None: ### -@_deprecate_arguments( - version="1.0", - deprecated_args="write_permission", - custom_message="Fine-grained tokens added complexity to the permissions, making it irrelevant to check if a token has 'write' access.", -) @_deprecate_positional_args(version="1.0") -def interpreter_login(*, new_session: bool = True, write_permission: bool = False) -> None: +def interpreter_login(*, skip_if_logged_in: bool = False) -> None: """ Displays a prompt to log in to the HF website and store the token. @@ -259,12 +246,10 @@ def interpreter_login(*, new_session: bool = True, write_permission: bool = Fals For more details, see [`login`]. Args: - new_session (`bool`, defaults to `True`): - If `True`, will request a token even if one is already saved on the machine. - write_permission (`bool`): - Ignored and deprecated argument. + skip_if_logged_in (`bool`, defaults to `False`): + If `True`, do not prompt for token if user is already logged in. """ - if not new_session and get_token() is not None: + if not skip_if_logged_in and get_token() is not None: logger.info("User is already logged in.") return @@ -314,13 +299,8 @@ def interpreter_login(*, new_session: bool = True, write_permission: bool = Fals notebooks. """ -@_deprecate_arguments( - version="1.0", - deprecated_args="write_permission", - custom_message="Fine-grained tokens added complexity to the permissions, making it irrelevant to check if a token has 'write' access.", -) @_deprecate_positional_args(version="1.0") -def notebook_login(*, new_session: bool = True, write_permission: bool = False) -> None: +def notebook_login(*, skip_if_logged_in: bool = False) -> None: """ Displays a widget to log in to the HF website and store the token. @@ -331,10 +311,8 @@ def notebook_login(*, new_session: bool = True, write_permission: bool = False) For more details, see [`login`]. Args: - new_session (`bool`, defaults to `True`): - If `True`, will request a token even if one is already saved on the machine. - write_permission (`bool`): - Ignored and deprecated argument. + skip_if_logged_in (`bool`, defaults to `False`): + If `True`, do not prompt for token if user is already logged in. """ try: import ipywidgets.widgets as widgets # type: ignore @@ -344,7 +322,7 @@ def notebook_login(*, new_session: bool = True, write_permission: bool = False) "The `notebook_login` function can only be used in a notebook (Jupyter or" " Colab) and you need the `ipywidgets` module: `pip install ipywidgets`." ) - if not new_session and get_token() is not None: + if not skip_if_logged_in and get_token() is not None: logger.info("User is already logged in.") return diff --git a/src/huggingface_hub/_snapshot_download.py b/src/huggingface_hub/_snapshot_download.py index 200ed7cc2e..8c82f101cb 100644 --- a/src/huggingface_hub/_snapshot_download.py +++ b/src/huggingface_hub/_snapshot_download.py @@ -1,6 +1,6 @@ import os from pathlib import Path -from typing import Iterable, Literal, Optional, Union +from typing import Iterable, Optional, Union import httpx from tqdm.auto import tqdm as base_tqdm @@ -46,9 +46,6 @@ def snapshot_download( tqdm_class: Optional[type[base_tqdm]] = None, headers: Optional[dict[str, str]] = None, endpoint: Optional[str] = None, - # Deprecated args - local_dir_use_symlinks: Union[bool, Literal["auto"]] = "auto", - resume_download: Optional[bool] = None, ) -> str: """Download repo files. @@ -303,12 +300,10 @@ def _inner_hf_hub_download(repo_file: str): endpoint=endpoint, cache_dir=cache_dir, local_dir=local_dir, - local_dir_use_symlinks=local_dir_use_symlinks, library_name=library_name, library_version=library_version, user_agent=user_agent, etag_timeout=etag_timeout, - resume_download=resume_download, force_download=force_download, token=token, headers=headers, diff --git a/src/huggingface_hub/commands/download.py b/src/huggingface_hub/commands/download.py index 103f2a52b5..06ce8905c1 100644 --- a/src/huggingface_hub/commands/download.py +++ b/src/huggingface_hub/commands/download.py @@ -133,16 +133,9 @@ def __init__(self, args: Namespace) -> None: self.cache_dir: Optional[str] = args.cache_dir self.local_dir: Optional[str] = args.local_dir self.force_download: bool = args.force_download - self.resume_download: Optional[bool] = args.resume_download or None self.quiet: bool = args.quiet self.max_workers: int = args.max_workers - if args.local_dir_use_symlinks is not None: - warnings.warn( - "Ignoring --local-dir-use-symlinks. Downloading to a local directory does not use symlinks anymore.", - FutureWarning, - ) - def run(self) -> None: show_deprecation_warning("huggingface-cli download", "hf download") @@ -173,7 +166,6 @@ def _download(self) -> str: revision=self.revision, filename=self.filenames[0], cache_dir=self.cache_dir, - resume_download=self.resume_download, force_download=self.force_download, token=self.token, local_dir=self.local_dir, @@ -194,7 +186,6 @@ def _download(self) -> str: revision=self.revision, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns, - resume_download=self.resume_download, force_download=self.force_download, cache_dir=self.cache_dir, token=self.token, diff --git a/src/huggingface_hub/constants.py b/src/huggingface_hub/constants.py index 20c5b5d970..2ca29cf294 100644 --- a/src/huggingface_hub/constants.py +++ b/src/huggingface_hub/constants.py @@ -135,7 +135,6 @@ def _as_int(value: Optional[str]) -> Optional[int]: ) ) ) -hf_cache_home = HF_HOME # for backward compatibility. TODO: remove this in 1.0.0 default_cache_path = os.path.join(HF_HOME, "hub") default_assets_cache_path = os.path.join(HF_HOME, "assets") diff --git a/src/huggingface_hub/file_download.py b/src/huggingface_hub/file_download.py index c3f57fbd78..26efe85b59 100644 --- a/src/huggingface_hub/file_download.py +++ b/src/huggingface_hub/file_download.py @@ -9,7 +9,7 @@ import warnings from dataclasses import dataclass from pathlib import Path -from typing import Any, BinaryIO, Literal, NoReturn, Optional, Union +from typing import Any, BinaryIO, NoReturn, Optional, Union from urllib.parse import quote, urlparse import httpx @@ -806,9 +806,6 @@ def hf_hub_download( local_files_only: bool = False, headers: Optional[dict[str, str]] = None, endpoint: Optional[str] = None, - resume_download: Optional[bool] = None, - force_filename: Optional[str] = None, - local_dir_use_symlinks: Union[bool, Literal["auto"]] = "auto", ) -> str: """Download a given file if it's not already present in the local cache. @@ -910,20 +907,6 @@ def hf_hub_download( # Respect environment variable above user value etag_timeout = constants.HF_HUB_ETAG_TIMEOUT - if force_filename is not None: - warnings.warn( - "The `force_filename` parameter is deprecated as a new caching system, " - "which keeps the filenames as they are on the Hub, is now in place.", - FutureWarning, - ) - if resume_download is not None: - warnings.warn( - "`resume_download` is deprecated and will be removed in version 1.0.0. " - "Downloads always resume when possible. " - "If you want to force a new download, use `force_download=True`.", - FutureWarning, - ) - if cache_dir is None: cache_dir = constants.HF_HUB_CACHE if revision is None: @@ -953,15 +936,6 @@ def hf_hub_download( ) if local_dir is not None: - if local_dir_use_symlinks != "auto": - warnings.warn( - "`local_dir_use_symlinks` parameter is deprecated and will be ignored. " - "The process to download files to a local folder has been updated and do " - "not rely on symlinks anymore. You only need to pass a destination folder " - "as`local_dir`.\n" - "For more details, check out https://huggingface.co/docs/huggingface_hub/main/en/guides/download#download-files-to-local-folder." - ) - return _hf_hub_download_to_local_dir( # Destination local_dir=local_dir, diff --git a/src/huggingface_hub/hf_api.py b/src/huggingface_hub/hf_api.py index 6e209d9055..d1297c8dc1 100644 --- a/src/huggingface_hub/hf_api.py +++ b/src/huggingface_hub/hf_api.py @@ -106,7 +106,6 @@ from .repocard_data import DatasetCardData, ModelCardData, SpaceCardData from .utils import ( DEFAULT_IGNORE_PATTERNS, - LocalTokenNotFoundError, NotASafetensorsRepoError, SafetensorsFileMetadata, SafetensorsParsingError, @@ -131,7 +130,7 @@ _get_token_from_file, _get_token_from_google_colab, ) -from .utils._deprecation import _deprecate_arguments, _deprecate_method +from .utils._deprecation import _deprecate_arguments from .utils._runtime import is_xet_available from .utils._typing import CallableT from .utils.endpoint_helpers import _is_emission_within_threshold @@ -411,12 +410,6 @@ class CommitInfo(str): repo_url (`RepoUrl`): Repo URL of the commit containing info like repo_id, repo_type, etc. - - _url (`str`, *optional*): - Legacy url for `str` compatibility. Can be the url to the uploaded file on the Hub (if returned by - [`upload_file`]), to the uploaded folder on the Hub (if returned by [`upload_folder`]) or to the commit on - the Hub (if returned by [`create_commit`]). Defaults to `commit_url`. It is deprecated to use this - attribute. Please use `commit_url` instead. """ commit_url: str @@ -432,11 +425,8 @@ class CommitInfo(str): pr_revision: Optional[str] = field(init=False) pr_num: Optional[str] = field(init=False) - # legacy url for `str` compatibility (ex: url to uploaded file, url to uploaded folder, url to PR, etc.) - _url: str = field(repr=False, default=None) # type: ignore # defaults to `commit_url` - - def __new__(cls, *args, commit_url: str, _url: Optional[str] = None, **kwargs): - return str.__new__(cls, _url or commit_url) + def __new__(cls, *args, commit_url: str, **kwargs): + return str.__new__(cls, commit_url) def __post_init__(self): """Populate pr-related fields after initialization. @@ -1791,46 +1781,6 @@ def whoami(self, token: Union[bool, str, None] = None) -> dict: raise return r.json() - @_deprecate_method( - version="1.0", - message=( - "Permissions are more complex than when `get_token_permission` was first introduced. " - "OAuth and fine-grain tokens allows for more detailed permissions. " - "If you need to know the permissions associated with a token, please use `whoami` and check the `'auth'` key." - ), - ) - def get_token_permission( - self, token: Union[bool, str, None] = None - ) -> Literal["read", "write", "fineGrained", None]: - """ - Check if a given `token` is valid and return its permissions. - - - - This method is deprecated and will be removed in version 1.0. Permissions are more complex than when - `get_token_permission` was first introduced. OAuth and fine-grain tokens allows for more detailed permissions. - If you need to know the permissions associated with a token, please use `whoami` and check the `'auth'` key. - - - - For more details about tokens, please refer to https://huggingface.co/docs/hub/security-tokens#what-are-user-access-tokens. - - Args: - token (Union[bool, str, None], optional): - A valid user access token (string). Defaults to the locally saved - token, which is the recommended method for authentication (see - https://huggingface.co/docs/huggingface_hub/quick-start#authentication). - To disable authentication, pass `False`. - - Returns: - `Literal["read", "write", "fineGrained", None]`: Permission granted by the token ("read" or "write"). Returns `None` if no - token passed, if token is invalid or if role is not returned by the server. This typically happens when the token is an OAuth token. - """ - try: - return self.whoami(token=token)["auth"]["accessToken"]["role"] - except (LocalTokenNotFoundError, HfHubHTTPError, KeyError): - return None - def get_model_tags(self) -> dict: """ List all valid model tags as a nested namespace object @@ -1849,9 +1799,6 @@ def get_dataset_tags(self) -> dict: hf_raise_for_status(r) return r.json() - @_deprecate_arguments( - version="1.0", deprecated_args=["language", "library", "task", "tags"], custom_message="Use `filter` instead." - ) @validate_hf_hub_args def list_models( self, @@ -1878,11 +1825,6 @@ def list_models( cardData: bool = False, fetch_config: bool = False, token: Union[bool, str, None] = None, - # Deprecated arguments - use `filter` instead - language: Optional[Union[str, list[str]]] = None, - library: Optional[Union[str, list[str]]] = None, - tags: Optional[Union[str, list[str]]] = None, - task: Optional[Union[str, list[str]]] = None, ) -> Iterable[ModelInfo]: """ List models hosted on the Huggingface Hub, given some filters. @@ -1906,20 +1848,12 @@ def list_models( inference_provider (`Literal["all"]` or `str`, *optional*): A string to filter models on the Hub that are served by a specific provider. Pass `"all"` to get all models served by at least one provider. - library (`str` or `List`, *optional*): - Deprecated. Pass a library name in `filter` to filter models by library. - language (`str` or `List`, *optional*): - Deprecated. Pass a language in `filter` to filter models by language. model_name (`str`, *optional*): A string that contain complete or partial names for models on the Hub, such as "bert" or "bert-base-cased" - task (`str` or `List`, *optional*): - Deprecated. Pass a task in `filter` to filter models by task. trained_dataset (`str` or `List`, *optional*): A string tag or a list of string tags of the trained dataset for a model on the Hub. - tags (`str` or `List`, *optional*): - Deprecated. Pass tags in `filter` to filter models by tags. search (`str`, *optional*): A string that will be contained in the returned model ids. pipeline_tag (`str`, *optional*): @@ -2001,21 +1935,9 @@ def list_models( filter_list: list[str] = [] if filter: filter_list.extend([filter] if isinstance(filter, str) else filter) - if library: - filter_list.extend([library] if isinstance(library, str) else library) - if task: - filter_list.extend([task] if isinstance(task, str) else task) if trained_dataset: - if isinstance(trained_dataset, str): - trained_dataset = [trained_dataset] - for dataset in trained_dataset: - if not dataset.startswith("dataset:"): - dataset = f"dataset:{dataset}" - filter_list.append(dataset) - if language: - filter_list.extend([language] if isinstance(language, str) else language) - if tags: - filter_list.extend([tags] if isinstance(tags, str) else tags) + datasets = [trained_dataset] if isinstance(trained_dataset, str) else trained_dataset + filter_list.extend(f"dataset:{d}" if not d.startswith("dataset:") else d for d in datasets) if len(filter_list) > 0: params["filter"] = filter_list @@ -3827,61 +3749,6 @@ def delete_repo( if not missing_ok: raise - @_deprecate_method(version="0.32", message="Please use `update_repo_settings` instead.") - @validate_hf_hub_args - def update_repo_visibility( - self, - repo_id: str, - private: bool = False, - *, - token: Union[str, bool, None] = None, - repo_type: Optional[str] = None, - ) -> dict[str, bool]: - """Update the visibility setting of a repository. - - Deprecated. Use `update_repo_settings` instead. - - Args: - repo_id (`str`, *optional*): - A namespace (user or an organization) and a repo name separated by a `/`. - private (`bool`, *optional*, defaults to `False`): - Whether the repository should be private. - token (Union[bool, str, None], optional): - A valid user access token (string). Defaults to the locally saved - token, which is the recommended method for authentication (see - https://huggingface.co/docs/huggingface_hub/quick-start#authentication). - To disable authentication, pass `False`. - repo_type (`str`, *optional*): - Set to `"dataset"` or `"space"` if uploading to a dataset or - space, `None` or `"model"` if uploading to a model. Default is - `None`. - - Returns: - The HTTP response in json. - - - - Raises the following errors: - - - [`~utils.RepositoryNotFoundError`] - If the repository to download from cannot be found. This may be because it doesn't exist, - or because it is set to `private` and you do not have access. - - - """ - if repo_type not in constants.REPO_TYPES: - raise ValueError(f"Invalid repo type, must be one of {constants.REPO_TYPES}") - if repo_type is None: - repo_type = constants.REPO_TYPE_MODEL # default repo type - - r = get_session().put( - url=f"{self.endpoint}/api/{repo_type}s/{repo_id}/settings", - headers=self._build_hf_headers(token=token), - json={"private": private}, - ) - hf_raise_for_status(r) - return r.json() - @validate_hf_hub_args def update_repo_settings( self, @@ -4681,7 +4548,6 @@ def upload_file( ... repo_type="dataset", ... token="my_token", ... ) - "https://huggingface.co/datasets/username/my-dataset/blob/main/remote/file/path.h5" >>> upload_file( ... path_or_fileobj=".\\\\local\\\\file\\\\path", @@ -4689,7 +4555,6 @@ def upload_file( ... repo_id="username/my-model", ... token="my_token", ... ) - "https://huggingface.co/username/my-model/blob/main/remote/file/path.h5" >>> upload_file( ... path_or_fileobj=".\\\\local\\\\file\\\\path", @@ -4698,7 +4563,6 @@ def upload_file( ... token="my_token", ... create_pr=True, ... ) - "https://huggingface.co/username/my-model/blob/refs%2Fpr%2F1/remote/file/path.h5" ``` """ if repo_type not in constants.REPO_TYPES: @@ -4712,7 +4576,7 @@ def upload_file( path_in_repo=path_in_repo, ) - commit_info = self.create_commit( + return self.create_commit( repo_id=repo_id, repo_type=repo_type, operations=[operation], @@ -4724,23 +4588,6 @@ def upload_file( parent_commit=parent_commit, ) - if commit_info.pr_url is not None: - revision = quote(_parse_revision_from_pr_url(commit_info.pr_url), safe="") - if repo_type in constants.REPO_TYPES_URL_PREFIXES: - repo_id = constants.REPO_TYPES_URL_PREFIXES[repo_type] + repo_id - revision = revision if revision is not None else constants.DEFAULT_REVISION - - return CommitInfo( - commit_url=commit_info.commit_url, - commit_message=commit_info.commit_message, - commit_description=commit_info.commit_description, - oid=commit_info.oid, - pr_url=commit_info.pr_url, - # Similar to `hf_hub_url` but it's "blob" instead of "resolve" - # TODO: remove this in v1.0 - _url=f"{self.endpoint}/{repo_id}/blob/{revision}/{path_in_repo}", - ) - @overload def upload_folder( # type: ignore self, @@ -4916,7 +4763,6 @@ def upload_folder( ... token="my_token", ... ignore_patterns="**/logs/*.txt", ... ) - # "https://huggingface.co/datasets/username/my-dataset/tree/main/remote/experiment/checkpoints" # Upload checkpoints folder including logs while deleting existing logs from the repo # Useful if you don't know exactly which log files have already being pushed @@ -4928,7 +4774,6 @@ def upload_folder( ... token="my_token", ... delete_patterns="**/logs/*.txt", ... ) - "https://huggingface.co/datasets/username/my-dataset/tree/main/remote/experiment/checkpoints" # Upload checkpoints folder while creating a PR >>> upload_folder( @@ -4939,8 +4784,6 @@ def upload_folder( ... token="my_token", ... create_pr=True, ... ) - "https://huggingface.co/datasets/username/my-dataset/tree/refs%2Fpr%2F1/remote/experiment/checkpoints" - ``` """ if repo_type not in constants.REPO_TYPES: @@ -4984,7 +4827,7 @@ def upload_folder( commit_message = commit_message or "Upload folder using huggingface_hub" - commit_info = self.create_commit( + return self.create_commit( repo_type=repo_type, repo_id=repo_id, operations=commit_operations, @@ -4996,24 +4839,6 @@ def upload_folder( parent_commit=parent_commit, ) - # Create url to uploaded folder (for legacy return value) - if create_pr and commit_info.pr_url is not None: - revision = quote(_parse_revision_from_pr_url(commit_info.pr_url), safe="") - if repo_type in constants.REPO_TYPES_URL_PREFIXES: - repo_id = constants.REPO_TYPES_URL_PREFIXES[repo_type] + repo_id - revision = revision if revision is not None else constants.DEFAULT_REVISION - - return CommitInfo( - commit_url=commit_info.commit_url, - commit_message=commit_info.commit_message, - commit_description=commit_info.commit_description, - oid=commit_info.oid, - pr_url=commit_info.pr_url, - # Similar to `hf_hub_url` but it's "tree" instead of "resolve" - # TODO: remove this in v1.0 - _url=f"{self.endpoint}/{repo_id}/tree/{revision}/{path_in_repo}", - ) - @validate_hf_hub_args def delete_file( self, @@ -5424,10 +5249,6 @@ def hf_hub_download( etag_timeout: float = constants.DEFAULT_ETAG_TIMEOUT, token: Union[bool, str, None] = None, local_files_only: bool = False, - # Deprecated args - resume_download: Optional[bool] = None, - force_filename: Optional[str] = None, - local_dir_use_symlinks: Union[bool, Literal["auto"]] = "auto", ) -> str: """Download a given file if it's not already present in the local cache. @@ -5533,12 +5354,9 @@ def hf_hub_download( library_version=self.library_version, cache_dir=cache_dir, local_dir=local_dir, - local_dir_use_symlinks=local_dir_use_symlinks, user_agent=self.user_agent, force_download=force_download, - force_filename=force_filename, etag_timeout=etag_timeout, - resume_download=resume_download, token=token, headers=self.headers, local_files_only=local_files_only, @@ -5561,9 +5379,6 @@ def snapshot_download( ignore_patterns: Optional[Union[list[str], str]] = None, max_workers: int = 8, tqdm_class: Optional[type[base_tqdm]] = None, - # Deprecated args - local_dir_use_symlinks: Union[bool, Literal["auto"]] = "auto", - resume_download: Optional[bool] = None, ) -> str: """Download repo files. @@ -5649,12 +5464,10 @@ def snapshot_download( endpoint=self.endpoint, cache_dir=cache_dir, local_dir=local_dir, - local_dir_use_symlinks=local_dir_use_symlinks, library_name=self.library_name, library_version=self.library_version, user_agent=self.user_agent, etag_timeout=etag_timeout, - resume_download=resume_download, force_download=force_download, token=token, local_files_only=local_files_only, @@ -10891,7 +10704,6 @@ def _parse_revision_from_pr_url(pr_url: str) -> str: whoami = api.whoami auth_check = api.auth_check -get_token_permission = api.get_token_permission list_models = api.list_models model_info = api.model_info @@ -10921,7 +10733,6 @@ def _parse_revision_from_pr_url(pr_url: str) -> str: create_commit = api.create_commit create_repo = api.create_repo delete_repo = api.delete_repo -update_repo_visibility = api.update_repo_visibility update_repo_settings = api.update_repo_settings move_repo = api.move_repo upload_file = api.upload_file diff --git a/src/huggingface_hub/hub_mixin.py b/src/huggingface_hub/hub_mixin.py index c297026d35..6397bde121 100644 --- a/src/huggingface_hub/hub_mixin.py +++ b/src/huggingface_hub/hub_mixin.py @@ -149,7 +149,6 @@ class ModelHubMixin: ... pretrained_model_name_or_path: Union[str, Path], ... *, ... force_download: bool = False, - ... resume_download: Optional[bool] = None, ... token: Optional[Union[str, bool]] = None, ... cache_dir: Optional[Union[str, Path]] = None, ... local_files_only: bool = False, @@ -465,7 +464,6 @@ def from_pretrained( pretrained_model_name_or_path: Union[str, Path], *, force_download: bool = False, - resume_download: Optional[bool] = None, token: Optional[Union[str, bool]] = None, cache_dir: Optional[Union[str, Path]] = None, local_files_only: bool = False, @@ -511,7 +509,6 @@ def from_pretrained( revision=revision, cache_dir=cache_dir, force_download=force_download, - resume_download=resume_download, token=token, local_files_only=local_files_only, ) @@ -564,7 +561,6 @@ def from_pretrained( revision=revision, cache_dir=cache_dir, force_download=force_download, - resume_download=resume_download, local_files_only=local_files_only, token=token, **model_kwargs, @@ -585,7 +581,6 @@ def _from_pretrained( revision: Optional[str], cache_dir: Optional[Union[str, Path]], force_download: bool, - resume_download: Optional[bool], local_files_only: bool, token: Optional[Union[str, bool]], **model_kwargs, @@ -768,7 +763,6 @@ def _from_pretrained( revision: Optional[str], cache_dir: Optional[Union[str, Path]], force_download: bool, - resume_download: Optional[bool], local_files_only: bool, token: Union[str, bool, None], map_location: str = "cpu", @@ -789,7 +783,6 @@ def _from_pretrained( revision=revision, cache_dir=cache_dir, force_download=force_download, - resume_download=resume_download, token=token, local_files_only=local_files_only, ) @@ -801,7 +794,6 @@ def _from_pretrained( revision=revision, cache_dir=cache_dir, force_download=force_download, - resume_download=resume_download, token=token, local_files_only=local_files_only, ) diff --git a/src/huggingface_hub/keras_mixin.py b/src/huggingface_hub/keras_mixin.py index fa38b5dfba..78c239acfe 100644 --- a/src/huggingface_hub/keras_mixin.py +++ b/src/huggingface_hub/keras_mixin.py @@ -459,7 +459,6 @@ def _from_pretrained( revision, cache_dir, force_download, - resume_download, local_files_only, token, config: Optional[dict[str, Any]] = None, diff --git a/src/huggingface_hub/utils/__init__.py b/src/huggingface_hub/utils/__init__.py index bf25d66950..6fc8c0ed7e 100644 --- a/src/huggingface_hub/utils/__init__.py +++ b/src/huggingface_hub/utils/__init__.py @@ -111,7 +111,7 @@ from ._subprocess import capture_output, run_interactive_subprocess, run_subprocess from ._telemetry import send_telemetry from ._typing import is_jsonable, is_simple_optional_type, unwrap_simple_optional_type -from ._validators import smoothly_deprecate_use_auth_token, validate_hf_hub_args, validate_repo_id +from ._validators import validate_hf_hub_args, validate_repo_id from ._xet import ( XetConnectionInfo, XetFileData, diff --git a/src/huggingface_hub/utils/_headers.py b/src/huggingface_hub/utils/_headers.py index 23726b56cc..d952d97121 100644 --- a/src/huggingface_hub/utils/_headers.py +++ b/src/huggingface_hub/utils/_headers.py @@ -20,7 +20,6 @@ from .. import constants from ._auth import get_token -from ._deprecation import _deprecate_arguments from ._runtime import ( get_fastai_version, get_fastcore_version, @@ -36,11 +35,6 @@ from ._validators import validate_hf_hub_args -@_deprecate_arguments( - version="1.0", - deprecated_args="is_write_action", - custom_message="This argument is ignored and we let the server handle the permission error instead (if any).", -) @validate_hf_hub_args def build_hf_headers( *, @@ -49,7 +43,6 @@ def build_hf_headers( library_version: Optional[str] = None, user_agent: Union[dict, str, None] = None, headers: Optional[dict[str, str]] = None, - is_write_action: bool = False, ) -> dict[str, str]: """ Build headers dictionary to send in a HF Hub call. @@ -86,8 +79,6 @@ def build_hf_headers( headers (`dict`, *optional*): Additional headers to include in the request. Those headers take precedence over the ones generated by this function. - is_write_action (`bool`): - Ignored and deprecated argument. Returns: A `dict` of headers to pass in your API call. diff --git a/src/huggingface_hub/utils/_validators.py b/src/huggingface_hub/utils/_validators.py index 8bbb16d87e..89efd396c3 100644 --- a/src/huggingface_hub/utils/_validators.py +++ b/src/huggingface_hub/utils/_validators.py @@ -48,9 +48,7 @@ def validate_hf_hub_args(fn: CallableT) -> CallableT: Validators: - [`~utils.validate_repo_id`]: `repo_id` must be `"repo_name"` or `"namespace/repo_name"`. Namespace is a username or an organization. - - [`~utils.smoothly_deprecate_use_auth_token`]: Use `token` instead of - `use_auth_token` (only if `use_auth_token` is not expected by the decorated - function - in practice, always the case in `huggingface_hub`). + - [`~utils.smoothly_deprecate_legacy_arguments`]: Ignore `proxies` when downloading files (should be set globally). Example: ```py @@ -68,20 +66,6 @@ def validate_hf_hub_args(fn: CallableT) -> CallableT: >>> my_cool_method(repo_id="other..repo..id") huggingface_hub.utils._validators.HFValidationError: Cannot have -- or .. in repo_id: 'other..repo..id'. - - >>> @validate_hf_hub_args - ... def my_cool_auth_method(token: str): - ... print(token) - - >>> my_cool_auth_method(token="a token") - "a token" - - >>> my_cool_auth_method(use_auth_token="a use_auth_token") - "a use_auth_token" - - >>> my_cool_auth_method(token="a token", use_auth_token="a use_auth_token") - UserWarning: Both `token` and `use_auth_token` are passed (...) - "a token" ``` Raises: @@ -91,13 +75,8 @@ def validate_hf_hub_args(fn: CallableT) -> CallableT: # TODO: add an argument to opt-out validation for specific argument? signature = inspect.signature(fn) - # Should the validator switch `use_auth_token` values to `token`? In practice, always - # True in `huggingface_hub`. Might not be the case in a downstream library. - check_use_auth_token = "use_auth_token" not in signature.parameters and "token" in signature.parameters - @wraps(fn) def _inner_fn(*args, **kwargs): - has_token = False for arg_name, arg_value in chain( zip(signature.parameters, args), # Args values kwargs.items(), # Kwargs values @@ -105,13 +84,7 @@ def _inner_fn(*args, **kwargs): if arg_name in ["repo_id", "from_id", "to_id"]: validate_repo_id(arg_value) - elif arg_name == "token" and arg_value is not None: - has_token = True - - if check_use_auth_token: - kwargs = smoothly_deprecate_use_auth_token(fn_name=fn.__name__, has_token=has_token, kwargs=kwargs) - - kwargs = smoothly_deprecate_proxies(fn_name=fn.__name__, kwargs=kwargs) + kwargs = smoothly_deprecate_legacy_arguments(fn_name=fn.__name__, kwargs=kwargs) return fn(*args, **kwargs) @@ -172,26 +145,33 @@ def validate_repo_id(repo_id: str) -> None: raise HFValidationError(f"Repo_id cannot end by '.git': '{repo_id}'.") -def smoothly_deprecate_proxies(fn_name: str, kwargs: dict[str, Any]) -> dict[str, Any]: - """Smoothly deprecate `proxies` in the `huggingface_hub` codebase. +def smoothly_deprecate_legacy_arguments(fn_name: str, kwargs: dict[str, Any]) -> dict[str, Any]: + """Smoothly deprecate legacy arguments in the `huggingface_hub` codebase. + + This function ignores some deprecated arguments from the kwargs and warns the user they are ignored. + The goal is to avoid breaking existing code while guiding the user to the new way of doing things. - This function removes the `proxies` key from the kwargs and warns the user that the `proxies` argument is ignored. - To set up proxies, user must either use the HTTP_PROXY environment variable or configure the `httpx.Client` manually - using the [`set_client_factory`] function. + List of deprecated arguments: + - `proxies`: + To set up proxies, user must either use the HTTP_PROXY environment variable or configure the `httpx.Client` + manually using the [`set_client_factory`] function. - In huggingface_hub 0.x, `proxies` was a dictionary directly passed to `requests.request`. - In huggingface_hub 1.x, we migrated to `httpx` which does not support `proxies` the same way. - In particular, it is not possible to configure proxies on a per-request basis. The solution is to configure - it globally using the [`set_client_factory`] function or using the HTTP_PROXY environment variable. + In huggingface_hub 0.x, `proxies` was a dictionary directly passed to `requests.request`. + In huggingface_hub 1.x, we migrated to `httpx` which does not support `proxies` the same way. + In particular, it is not possible to configure proxies on a per-request basis. The solution is to configure + it globally using the [`set_client_factory`] function or using the HTTP_PROXY environment variable. - More more details, see: - - https://www.python-httpx.org/advanced/proxies/ - - https://www.python-httpx.org/compatibility/#proxy-keys. + More more details, see: + - https://www.python-httpx.org/advanced/proxies/ + - https://www.python-httpx.org/compatibility/#proxy-keys. - We did not want to completely remove the `proxies` argument to avoid breaking existing code. + - `resume_download`: deprecated without replacement. `huggingface_hub` always resumes downloads whenever possible. + - `force_filename`: deprecated without replacement. Filename is always the same as on the Hub. + - `local_dir_use_symlinks`: deprecated without replacement. Downloading to a local directory does not use symlinks anymore. """ new_kwargs = kwargs.copy() # do not mutate input ! + # proxies proxies = new_kwargs.pop("proxies", None) # remove from kwargs if proxies is not None: warnings.warn( @@ -200,60 +180,28 @@ def smoothly_deprecate_proxies(fn_name: str, kwargs: dict[str, Any]) -> dict[str " See https://www.python-httpx.org/advanced/proxies/ for more details." ) - return new_kwargs - - -def smoothly_deprecate_use_auth_token(fn_name: str, has_token: bool, kwargs: dict[str, Any]) -> dict[str, Any]: - """Smoothly deprecate `use_auth_token` in the `huggingface_hub` codebase. - - The long-term goal is to remove any mention of `use_auth_token` in the codebase in - favor of a unique and less verbose `token` argument. This will be done a few steps: - - 0. Step 0: methods that require a read-access to the Hub use the `use_auth_token` - argument (`str`, `bool` or `None`). Methods requiring write-access have a `token` - argument (`str`, `None`). This implicit rule exists to be able to not send the - token when not necessary (`use_auth_token=False`) even if logged in. - - 1. Step 1: we want to harmonize everything and use `token` everywhere (supporting - `token=False` for read-only methods). In order not to break existing code, if - `use_auth_token` is passed to a function, the `use_auth_token` value is passed - as `token` instead, without any warning. - a. Corner case: if both `use_auth_token` and `token` values are passed, a warning - is thrown and the `use_auth_token` value is ignored. - - 2. Step 2: Once it is release, we should push downstream libraries to switch from - `use_auth_token` to `token` as much as possible, but without throwing a warning - (e.g. manually create issues on the corresponding repos). - - 3. Step 3: After a transitional period (6 months e.g. until April 2023?), we update - `huggingface_hub` to throw a warning on `use_auth_token`. Hopefully, very few - users will be impacted as it would have already been fixed. - In addition, unit tests in `huggingface_hub` must be adapted to expect warnings - to be thrown (but still use `use_auth_token` as before). - - 4. Step 4: After a normal deprecation cycle (3 releases ?), remove this validator. - `use_auth_token` will definitely not be supported. - In addition, we update unit tests in `huggingface_hub` to use `token` everywhere. + # resume_download + resume_download = new_kwargs.pop("resume_download", None) # remove from kwargs + if resume_download is not None: + warnings.warn( + f"The `resume_download` argument is deprecated and ignored in `{fn_name}`. Downloads always resume" + " whenever possible." + ) - This has been discussed in: - - https://github.com/huggingface/huggingface_hub/issues/1094. - - https://github.com/huggingface/huggingface_hub/pull/928 - - (related) https://github.com/huggingface/huggingface_hub/pull/1064 - """ - new_kwargs = kwargs.copy() # do not mutate input ! + # force_filename + force_filename = new_kwargs.pop("force_filename", None) # remove from kwargs + if force_filename is not None: + warnings.warn( + f"The `force_filename` argument is deprecated and ignored in `{fn_name}`. Filename is always the same " + "as on the Hub." + ) - use_auth_token = new_kwargs.pop("use_auth_token", None) # remove from kwargs - if use_auth_token is not None: - if has_token: - warnings.warn( - "Both `token` and `use_auth_token` are passed to" - f" `{fn_name}` with non-None values. `token` is now the" - " preferred argument to pass a User Access Token." - " `use_auth_token` value will be ignored." - ) - else: - # `token` argument is not passed and a non-None value is passed in - # `use_auth_token` => use `use_auth_token` value as `token` kwarg. - new_kwargs["token"] = use_auth_token + # local_dir_use_symlinks + local_dir_use_symlinks = new_kwargs.pop("local_dir_use_symlinks", None) # remove from kwargs + if local_dir_use_symlinks is not None: + warnings.warn( + f"The `local_dir_use_symlinks` argument is deprecated and ignored in `{fn_name}`. Downloading to a local" + " directory does not use symlinks anymore." + ) return new_kwargs diff --git a/tests/test_cli.py b/tests/test_cli.py index ab7d819ff0..e256ebce5b 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -558,7 +558,6 @@ def test_download_with_ignored_patterns(self, mock: Mock) -> None: include=["*.json"], exclude=["data/*"], force_download=True, - resume_download=True, cache_dir=None, quiet=False, local_dir=None, diff --git a/tests/test_hf_api.py b/tests/test_hf_api.py index ff7a829baf..bbd8decff1 100644 --- a/tests/test_hf_api.py +++ b/tests/test_hf_api.py @@ -27,7 +27,7 @@ from pathlib import Path from typing import Optional, Union, get_args from unittest.mock import Mock, patch -from urllib.parse import quote, urlparse +from urllib.parse import urlparse import pytest @@ -206,17 +206,6 @@ def test_delete_repo_error_message(self): def test_delete_repo_missing_ok(self) -> None: self._api.delete_repo("repo-that-does-not-exist", missing_ok=True) - def test_update_repo_visibility(self): - repo_id = self._api.create_repo(repo_id=repo_name()).repo_id - - self._api.update_repo_settings(repo_id=repo_id, private=True) - assert self._api.model_info(repo_id).private - - self._api.update_repo_settings(repo_id=repo_id, private=False) - assert not self._api.model_info(repo_id).private - - self._api.delete_repo(repo_id=repo_id) - def test_move_repo_normal_usage(self): repo_id = f"{USER}/{repo_name()}" new_repo_id = f"{USER}/{repo_name()}" @@ -280,17 +269,6 @@ def test_update_repo_settings_xet_enabled(self, repo_url: RepoUrl): info = self._api.model_info(repo_id, expand="xetEnabled") assert info.xet_enabled - @expect_deprecation("get_token_permission") - def test_get_token_permission_on_oauth_token(self): - whoami = { - "type": "user", - "auth": {"type": "oauth", "expiresAt": "2024-10-24T19:43:43.000Z"}, - # ... - # other values are ignored as we only need to check the "auth" value - } - with patch.object(self._api, "whoami", return_value=whoami): - assert self._api.get_token_permission() is None - class CommitApiTest(HfApiCommonTest): def setUp(self) -> None: @@ -338,8 +316,8 @@ def test_upload_file_str_path(self, repo_url: RepoUrl) -> None: path_in_repo="temp/new_file.md", repo_id=repo_id, ) - self.assertEqual(return_val, f"{repo_url}/blob/main/temp/new_file.md") - self.assertIsInstance(return_val, CommitInfo) + assert isinstance(return_val, CommitInfo) + assert return_val.startswith(f"{repo_url}/commit/") with SoftTemporaryDirectory() as cache_dir: with open(hf_hub_download(repo_id=repo_id, filename="temp/new_file.md", cache_dir=cache_dir)) as f: @@ -360,7 +338,8 @@ def test_upload_file_fileobj(self, repo_url: RepoUrl) -> None: path_in_repo="temp/new_file.md", repo_id=repo_id, ) - self.assertEqual(return_val, f"{repo_url}/blob/main/temp/new_file.md") + assert isinstance(return_val, CommitInfo) + assert return_val.startswith(f"{repo_url}/commit/") with SoftTemporaryDirectory() as cache_dir: with open(hf_hub_download(repo_id=repo_id, filename="temp/new_file.md", cache_dir=cache_dir)) as f: @@ -375,7 +354,8 @@ def test_upload_file_bytesio(self, repo_url: RepoUrl) -> None: path_in_repo="temp/new_file.md", repo_id=repo_id, ) - self.assertEqual(return_val, f"{repo_url}/blob/main/temp/new_file.md") + assert isinstance(return_val, CommitInfo) + assert return_val.startswith(f"{repo_url}/commit/") with SoftTemporaryDirectory() as cache_dir: with open(hf_hub_download(repo_id=repo_id, filename="temp/new_file.md", cache_dir=cache_dir)) as f: @@ -444,8 +424,9 @@ def test_upload_file_create_pr(self, repo_url: RepoUrl) -> None: repo_id=repo_id, create_pr=True, ) - self.assertEqual(return_val, f"{repo_url}/blob/{quote('refs/pr/1', safe='')}/temp/new_file.md") - self.assertIsInstance(return_val, CommitInfo) + assert isinstance(return_val, CommitInfo) + assert return_val.startswith(f"{repo_url}/commit/") + assert return_val.pr_revision == "refs/pr/1" with SoftTemporaryDirectory() as cache_dir: with open( @@ -482,19 +463,14 @@ def test_upload_folder(self, repo_url: RepoUrl) -> None: # Upload folder url = self._api.upload_folder(folder_path=self.tmp_dir, path_in_repo="temp/dir", repo_id=repo_id) - self.assertEqual( - url, - f"{self._api.endpoint}/{repo_id}/tree/main/temp/dir", - ) - self.assertIsInstance(url, CommitInfo) + assert isinstance(url, CommitInfo) + assert url.startswith(f"{repo_url}/commit/") # Check files are uploaded for rpath in ["temp", "nested/file.bin"]: local_path = os.path.join(self.tmp_dir, rpath) remote_path = f"temp/dir/{rpath}" - filepath = hf_hub_download( - repo_id=repo_id, filename=remote_path, revision="main", use_auth_token=self._token - ) + filepath = hf_hub_download(repo_id=repo_id, filename=remote_path, revision="main", token=self._token) assert filepath is not None with open(filepath, "rb") as downloaded_file: content = downloaded_file.read() @@ -514,7 +490,9 @@ def test_upload_folder_create_pr(self, repo_url: RepoUrl) -> None: return_val = self._api.upload_folder( folder_path=self.tmp_dir, path_in_repo="temp/dir", repo_id=repo_id, create_pr=True ) - self.assertEqual(return_val, f"{self._api.endpoint}/{repo_id}/tree/refs%2Fpr%2F1/temp/dir") + assert isinstance(return_val, CommitInfo) + assert return_val.startswith(f"{repo_url}/commit/") + assert return_val.pr_revision == "refs/pr/1" # Check files are uploaded for rpath in ["temp", "nested/file.bin"]: @@ -522,13 +500,6 @@ def test_upload_folder_create_pr(self, repo_url: RepoUrl) -> None: filepath = hf_hub_download(repo_id=repo_id, filename=f"temp/dir/{rpath}", revision="refs/pr/1") assert Path(local_path).read_bytes() == Path(filepath).read_bytes() - def test_upload_folder_default_path_in_repo(self): - REPO_NAME = repo_name("upload_folder_to_root") - self._api.create_repo(repo_id=REPO_NAME, exist_ok=False) - url = self._api.upload_folder(folder_path=self.tmp_dir, repo_id=f"{USER}/{REPO_NAME}") - # URL to root of repository - self.assertEqual(url, f"{self._api.endpoint}/{USER}/{REPO_NAME}/tree/main/") - @use_tmp_repo() def test_upload_folder_git_folder_excluded(self, repo_url: RepoUrl) -> None: # Simulate a folder with a .git folder @@ -1348,18 +1319,18 @@ def test_create_commit_delete_folder_implicit(self): ) with self.assertRaises(EntryNotFoundError): - hf_hub_download(self.repo_id, "1/file_1.md", use_auth_token=self._token) + hf_hub_download(self.repo_id, "1/file_1.md", token=self._token) with self.assertRaises(EntryNotFoundError): - hf_hub_download(self.repo_id, "1/file_2.md", use_auth_token=self._token) + hf_hub_download(self.repo_id, "1/file_2.md", token=self._token) # Still exists - hf_hub_download(self.repo_id, "2/file_3.md", use_auth_token=self._token) + hf_hub_download(self.repo_id, "2/file_3.md", token=self._token) def test_create_commit_delete_folder_explicit(self): self._api.delete_folder(path_in_repo="1", repo_id=self.repo_id) with self.assertRaises(EntryNotFoundError): - hf_hub_download(self.repo_id, "1/file_1.md", use_auth_token=self._token) + hf_hub_download(self.repo_id, "1/file_1.md", token=self._token) def test_create_commit_implicit_delete_folder_is_ok(self): self._api.create_commit( @@ -2283,38 +2254,34 @@ def test_failing_filter_models_by_author_and_model_name(self): models = list(self._api.list_models(author="muellerzr", model_name="testme")) assert len(models) == 0 - @expect_deprecation("list_models") def test_filter_models_with_library(self): - models = list(self._api.list_models(author="microsoft", model_name="wavlm-base-sd", library="tensorflow")) + models = list(self._api.list_models(author="microsoft", model_name="wavlm-base-sd", filter="tensorflow")) assert len(models) == 0 - models = list(self._api.list_models(author="microsoft", model_name="wavlm-base-sd", library="pytorch")) + models = list(self._api.list_models(author="microsoft", model_name="wavlm-base-sd", filter="pytorch")) assert len(models) > 0 - @expect_deprecation("list_models") def test_filter_models_with_task(self): - models = list(self._api.list_models(task="fill-mask", model_name="albert-base-v2")) + models = list(self._api.list_models(filter="fill-mask", model_name="albert-base-v2")) assert models[0].pipeline_tag == "fill-mask" assert "albert" in models[0].id assert "base" in models[0].id assert "v2" in models[0].id - models = list(self._api.list_models(task="dummytask")) + models = list(self._api.list_models(filter="dummytask")) assert len(models) == 0 - @expect_deprecation("list_models") def test_filter_models_by_language(self): for language in ["en", "fr", "zh"]: - for model in self._api.list_models(language=language, limit=5): + for model in self._api.list_models(filter=language, limit=5): assert language in model.tags - @expect_deprecation("list_models") def test_filter_models_with_tag(self): - models = list(self._api.list_models(author="HuggingFaceBR4", tags=["tensorboard"])) + models = list(self._api.list_models(author="HuggingFaceBR4", filter=["tensorboard"])) assert models[0].id.startswith("HuggingFaceBR4/") assert "tensorboard" in models[0].tags - models = list(self._api.list_models(tags="dummytag")) + models = list(self._api.list_models(filter=["dummytag"])) assert len(models) == 0 def test_filter_models_with_card_data(self): @@ -2578,7 +2545,7 @@ def test_model_info(self, mock_get_token: Mock) -> None: ): _ = self._api.model_info(repo_id=f"{USER}/{self.REPO_NAME}") - model_info = self._api.model_info(repo_id=f"{USER}/{self.REPO_NAME}", use_auth_token=self._token) + model_info = self._api.model_info(repo_id=f"{USER}/{self.REPO_NAME}", token=self._token) self.assertIsInstance(model_info, ModelInfo) @patch("huggingface_hub.utils._headers.get_token", return_value=None) @@ -2594,23 +2561,23 @@ def test_dataset_info(self, mock_get_token: Mock) -> None: ): _ = self._api.dataset_info(repo_id=f"{USER}/{self.REPO_NAME}") - dataset_info = self._api.dataset_info(repo_id=f"{USER}/{self.REPO_NAME}", use_auth_token=self._token) + dataset_info = self._api.dataset_info(repo_id=f"{USER}/{self.REPO_NAME}", token=self._token) self.assertIsInstance(dataset_info, DatasetInfo) def test_list_private_datasets(self): - orig = len(list(self._api.list_datasets(use_auth_token=False))) - new = len(list(self._api.list_datasets(use_auth_token=self._token))) + orig = len(list(self._api.list_datasets(token=False))) + new = len(list(self._api.list_datasets(token=self._token))) self.assertGreater(new, orig) def test_list_private_models(self): - orig = len(list(self._api.list_models(use_auth_token=False))) - new = len(list(self._api.list_models(use_auth_token=self._token))) + orig = len(list(self._api.list_models(token=False))) + new = len(list(self._api.list_models(token=self._token))) self.assertGreater(new, orig) @with_production_testing def test_list_private_spaces(self): - orig = len(list(self._api.list_spaces(use_auth_token=False))) - new = len(list(self._api.list_spaces(use_auth_token=self._token))) + orig = len(list(self._api.list_spaces(token=False))) + new = len(list(self._api.list_spaces(token=self._token))) self.assertGreaterEqual(new, orig) @@ -3451,11 +3418,8 @@ def test_hf_hub_download_alias(self, mock: Mock) -> None: revision=None, cache_dir=None, local_dir=None, - local_dir_use_symlinks="auto", force_download=False, - force_filename=None, etag_timeout=10, - resume_download=None, local_files_only=False, headers=None, ) @@ -3477,9 +3441,7 @@ def test_snapshot_download_alias(self, mock: Mock) -> None: revision=None, cache_dir=None, local_dir=None, - local_dir_use_symlinks="auto", etag_timeout=10, - resume_download=None, force_download=False, local_files_only=False, allow_patterns=None, diff --git a/tests/test_hub_mixin.py b/tests/test_hub_mixin.py index 5b1ccdd574..e6a410dd40 100644 --- a/tests/test_hub_mixin.py +++ b/tests/test_hub_mixin.py @@ -126,7 +126,6 @@ def _from_pretrained( revision: Optional[str], cache_dir: Optional[Union[str, Path]], force_download: bool, - resume_download: bool, local_files_only: bool, token: Optional[Union[str, bool]], **model_kwargs, @@ -340,7 +339,6 @@ def test_from_pretrained_model_id_and_revision(self, from_pretrained_mock: Mock) revision="123456789", # Revision is passed correctly! cache_dir=None, force_download=False, - resume_download=None, local_files_only=False, token=None, ) @@ -374,10 +372,7 @@ def test_push_to_hub(self): # Test config has been pushed to hub tmp_config_path = hf_hub_download( - repo_id=repo_id, - filename="config.json", - use_auth_token=TOKEN, - cache_dir=self.cache_dir, + repo_id=repo_id, filename="config.json", token=TOKEN, cache_dir=self.cache_dir ) with open(tmp_config_path) as f: assert json.load(f) == CONFIG_AS_DICT diff --git a/tests/test_hub_mixin_pytorch.py b/tests/test_hub_mixin_pytorch.py index 1ebde60d42..e7dcf47201 100644 --- a/tests/test_hub_mixin_pytorch.py +++ b/tests/test_hub_mixin_pytorch.py @@ -209,7 +209,6 @@ def test_from_pretrained_model_from_hub_prefer_safetensor(self, hf_hub_download_ revision=None, cache_dir=None, force_download=False, - resume_download=None, token=None, local_files_only=False, ) @@ -237,7 +236,6 @@ def test_from_pretrained_model_from_hub_fallback_pickle(self, hf_hub_download_mo revision=None, cache_dir=None, force_download=False, - resume_download=None, token=None, local_files_only=False, ) @@ -247,7 +245,6 @@ def test_from_pretrained_model_from_hub_fallback_pickle(self, hf_hub_download_mo revision=None, cache_dir=None, force_download=False, - resume_download=None, token=None, local_files_only=False, ) @@ -263,7 +260,6 @@ def test_from_pretrained_model_id_and_revision(self, from_pretrained_mock: Mock) revision="123456789", # Revision is passed correctly! cache_dir=None, force_download=False, - resume_download=None, local_files_only=False, token=None, ) @@ -314,10 +310,7 @@ def test_push_to_hub(self): # Test config has been pushed to hub tmp_config_path = hf_hub_download( - repo_id=repo_id, - filename="config.json", - use_auth_token=TOKEN, - cache_dir=self.cache_dir, + repo_id=repo_id, filename="config.json", token=TOKEN, cache_dir=self.cache_dir ) with open(tmp_config_path) as f: self.assertDictEqual(json.load(f), CONFIG) diff --git a/tests/test_keras_integration.py b/tests/test_keras_integration.py index 5601a6335f..c7f020200d 100644 --- a/tests/test_keras_integration.py +++ b/tests/test_keras_integration.py @@ -120,9 +120,7 @@ def test_push_to_hub_keras_mixin_via_http_basic(self): assert self._api.model_info(repo_id).id == repo_id # Test config has been pushed to hub - config_path = hf_hub_download( - repo_id=repo_id, filename="config.json", use_auth_token=TOKEN, cache_dir=self.cache_dir - ) + config_path = hf_hub_download(repo_id=repo_id, filename="config.json", token=TOKEN, cache_dir=self.cache_dir) with open(config_path) as f: assert json.load(f) == {"num": 7, "act": "gelu_fast"} diff --git a/tests/test_utils_headers.py b/tests/test_utils_headers.py index ff61cec932..d03f545095 100644 --- a/tests/test_utils_headers.py +++ b/tests/test_utils_headers.py @@ -20,28 +20,28 @@ class TestAuthHeadersUtil(unittest.TestCase): - def test_use_auth_token_str(self) -> None: - self.assertEqual(build_hf_headers(use_auth_token=FAKE_TOKEN), FAKE_TOKEN_HEADER) + def test_token_str(self) -> None: + self.assertEqual(build_hf_headers(token=FAKE_TOKEN), FAKE_TOKEN_HEADER) @patch("huggingface_hub.utils._headers.get_token", return_value=None) - def test_use_auth_token_true_no_cached_token(self, mock_get_token: Mock) -> None: + def test_token_true_no_cached_token(self, mock_get_token: Mock) -> None: with self.assertRaises(EnvironmentError): - build_hf_headers(use_auth_token=True) + build_hf_headers(token=True) @patch("huggingface_hub.utils._headers.get_token", return_value=FAKE_TOKEN) - def test_use_auth_token_true_has_cached_token(self, mock_get_token: Mock) -> None: - self.assertEqual(build_hf_headers(use_auth_token=True), FAKE_TOKEN_HEADER) + def test_token_true_has_cached_token(self, mock_get_token: Mock) -> None: + self.assertEqual(build_hf_headers(token=True), FAKE_TOKEN_HEADER) @patch("huggingface_hub.utils._headers.get_token", return_value=FAKE_TOKEN) - def test_use_auth_token_false(self, mock_get_token: Mock) -> None: - self.assertEqual(build_hf_headers(use_auth_token=False), NO_AUTH_HEADER) + def test_token_false(self, mock_get_token: Mock) -> None: + self.assertEqual(build_hf_headers(token=False), NO_AUTH_HEADER) @patch("huggingface_hub.utils._headers.get_token", return_value=None) - def test_use_auth_token_none_no_cached_token(self, mock_get_token: Mock) -> None: + def test_token_none_no_cached_token(self, mock_get_token: Mock) -> None: self.assertEqual(build_hf_headers(), NO_AUTH_HEADER) @patch("huggingface_hub.utils._headers.get_token", return_value=FAKE_TOKEN) - def test_use_auth_token_none_has_cached_token(self, mock_get_token: Mock) -> None: + def test_token_none_has_cached_token(self, mock_get_token: Mock) -> None: self.assertEqual(build_hf_headers(), FAKE_TOKEN_HEADER) @patch("huggingface_hub.utils._headers.get_token", return_value=FAKE_TOKEN) @@ -57,7 +57,7 @@ def test_implicit_use_disabled_but_explicit_use(self, mock_get_token: Mock) -> N "huggingface_hub.constants.HF_HUB_DISABLE_IMPLICIT_TOKEN", True ): # This is not an implicit use so we still send it - self.assertEqual(build_hf_headers(use_auth_token=True), FAKE_TOKEN_HEADER) + self.assertEqual(build_hf_headers(token=True), FAKE_TOKEN_HEADER) class TestUserAgentHeadersUtil(unittest.TestCase): diff --git a/tests/test_utils_validators.py b/tests/test_utils_validators.py index 66161daf18..07c1950094 100644 --- a/tests/test_utils_validators.py +++ b/tests/test_utils_validators.py @@ -4,7 +4,6 @@ from huggingface_hub.utils import ( HFValidationError, - smoothly_deprecate_use_auth_token, validate_hf_hub_args, validate_repo_id, ) @@ -58,61 +57,3 @@ def test_not_valid_repo_ids(self) -> None: for repo_id in self.NOT_VALID_VALUES: with self.assertRaises(HFValidationError, msg=f"'{repo_id}' must not be valid"): validate_repo_id(repo_id) - - -class TestSmoothlyDeprecateUseAuthToken(unittest.TestCase): - def test_token_normal_usage_as_arg(self) -> None: - self.assertEqual( - self.dummy_token_function("this_is_a_token"), - ("this_is_a_token", {}), - ) - - def test_token_normal_usage_as_kwarg(self) -> None: - self.assertEqual( - self.dummy_token_function(token="this_is_a_token"), - ("this_is_a_token", {}), - ) - - def test_token_normal_usage_with_more_kwargs(self) -> None: - self.assertEqual( - self.dummy_token_function(token="this_is_a_token", foo="bar"), - ("this_is_a_token", {"foo": "bar"}), - ) - - def test_token_with_smoothly_deprecated_use_auth_token(self) -> None: - self.assertEqual( - self.dummy_token_function(use_auth_token="this_is_a_use_auth_token"), - ("this_is_a_use_auth_token", {}), - ) - - def test_input_kwargs_not_mutated_by_smooth_deprecation(self) -> None: - initial_kwargs = {"a": "b", "use_auth_token": "token"} - kwargs = smoothly_deprecate_use_auth_token(fn_name="name", has_token=False, kwargs=initial_kwargs) - self.assertEqual(kwargs, {"a": "b", "token": "token"}) - self.assertEqual(initial_kwargs, {"a": "b", "use_auth_token": "token"}) # not mutated! - - def test_with_both_token_and_use_auth_token(self) -> None: - with self.assertWarns(UserWarning): - # `use_auth_token` is ignored ! - self.assertEqual( - self.dummy_token_function(token="this_is_a_token", use_auth_token="this_is_a_use_auth_token"), - ("this_is_a_token", {}), - ) - - def test_not_deprecated_use_auth_token(self) -> None: - # `use_auth_token` is accepted by `dummy_use_auth_token_function` - # => `smoothly_deprecate_use_auth_token` is not called - self.assertEqual( - self.dummy_use_auth_token_function(use_auth_token="this_is_a_use_auth_token"), - ("this_is_a_use_auth_token", {}), - ) - - @staticmethod - @validate_hf_hub_args - def dummy_token_function(token: str, **kwargs) -> None: - return token, kwargs - - @staticmethod - @validate_hf_hub_args - def dummy_use_auth_token_function(use_auth_token: str, **kwargs) -> None: - return use_auth_token, kwargs diff --git a/tests/test_xet_upload.py b/tests/test_xet_upload.py index 2db0279c97..a71e45db7e 100644 --- a/tests/test_xet_upload.py +++ b/tests/test_xet_upload.py @@ -102,8 +102,8 @@ def test_upload_file(self, api, tmp_path, repo_url): path_in_repo=filename_in_repo, repo_id=repo_id, ) + assert return_val.startswith(f"{api.endpoint}/{repo_id}/commit") - assert return_val == f"{api.endpoint}/{repo_id}/blob/main/{filename_in_repo}" # Download and verify content downloaded_file = hf_hub_download(repo_id=repo_id, filename=filename_in_repo, cache_dir=tmp_path) with open(downloaded_file, "rb") as f: @@ -193,7 +193,7 @@ def test_upload_folder(self, api, repo_url): repo_id=repo_id, ) - assert return_val == f"{api.endpoint}/{repo_id}/tree/main/{folder_in_repo}" + assert return_val.startswith(f"{api.endpoint}/{repo_id}/commit") files_in_repo = set(api.list_repo_files(repo_id=repo_id)) files = { f"{folder_in_repo}/text_file.txt", @@ -220,7 +220,7 @@ def test_upload_folder_create_pr(self, api, repo_url) -> None: create_pr=True, ) - assert return_val == f"{api.endpoint}/{repo_id}/tree/refs%2Fpr%2F1/{folder_in_repo}" + assert return_val.startswith(f"{api.endpoint}/{repo_id}/commit") for rpath in ["text_file.txt", "nested/nested_binary.safetensors"]: local_path = self.folder_path / rpath From 72c53f4ef106fdc44ae046b91ba16342ef0e1dd0 Mon Sep 17 00:00:00 2001 From: Lucain Date: Fri, 12 Sep 2025 09:53:44 +0200 Subject: [PATCH 05/19] [v1.0] Remove `Repository` class (#3346) * Remove Repository class + adapt docs * remove fr git_vs_http --- .github/workflows/python-tests.yml | 17 +- docs/source/cn/_toctree.yml | 5 +- docs/source/cn/concepts/git_vs_http.md | 40 - docs/source/cn/guides/repository.md | 88 - docs/source/de/_toctree.yml | 4 - docs/source/de/concepts/git_vs_http.md | 69 - docs/source/en/_toctree.yml | 2 - docs/source/en/guides/repository.md | 82 - docs/source/en/guides/upload.md | 118 +- .../source/en/package_reference/repository.md | 51 - docs/source/fr/_toctree.yml | 4 - docs/source/fr/concepts/git_vs_http.md | 67 - docs/source/hi/_toctree.yml | 4 - docs/source/hi/concepts/git_vs_http.md | 33 - docs/source/ko/_toctree.yml | 6 - docs/source/ko/concepts/git_vs_http.md | 53 - docs/source/ko/guides/repository.md | 223 --- docs/source/ko/guides/upload.md | 122 +- .../source/ko/package_reference/repository.md | 49 - src/huggingface_hub/__init__.py | 5 - src/huggingface_hub/repository.py | 1477 ----------------- tests/test_repository.py | 895 ---------- 22 files changed, 5 insertions(+), 3409 deletions(-) delete mode 100644 docs/source/cn/concepts/git_vs_http.md delete mode 100644 docs/source/de/concepts/git_vs_http.md delete mode 100644 docs/source/en/package_reference/repository.md delete mode 100644 docs/source/fr/concepts/git_vs_http.md delete mode 100644 docs/source/hi/concepts/git_vs_http.md delete mode 100644 docs/source/ko/concepts/git_vs_http.md delete mode 100644 docs/source/ko/guides/repository.md delete mode 100644 docs/source/ko/package_reference/repository.md delete mode 100644 src/huggingface_hub/repository.py delete mode 100644 tests/test_repository.py diff --git a/.github/workflows/python-tests.yml b/.github/workflows/python-tests.yml index 6c9bacf656..336e44ef54 100644 --- a/.github/workflows/python-tests.yml +++ b/.github/workflows/python-tests.yml @@ -22,13 +22,7 @@ jobs: fail-fast: false matrix: python-version: ["3.9", "3.13"] - test_name: - [ - "Repository only", - "Everything else", - "Inference only", - "Xet only" - ] + test_name: ["Everything else", "Inference only", "Xet only"] include: - python-version: "3.13" # LFS not ran on 3.9 test_name: "lfs" @@ -65,7 +59,7 @@ jobs: case "${{ matrix.test_name }}" in - "Repository only" | "Everything else" | "Inference only") + "Everything else" | "Inference only") sudo apt update sudo apt install -y libsndfile1-dev ;; @@ -112,13 +106,6 @@ jobs: case "${{ matrix.test_name }}" in - "Repository only") - # Run repo tests concurrently - PYTEST="$PYTEST ../tests -k 'TestRepository' -n 4" - echo $PYTEST - eval $PYTEST - ;; - "Inference only") # Run inference tests concurrently PYTEST="$PYTEST ../tests -k 'test_inference' -n 4" diff --git a/docs/source/cn/_toctree.yml b/docs/source/cn/_toctree.yml index b4949efa35..db6d3244a9 100644 --- a/docs/source/cn/_toctree.yml +++ b/docs/source/cn/_toctree.yml @@ -20,7 +20,4 @@ title: 概览 - local: guides/hf_file_system title: Hugging Face 文件系统 -- title: "concepts" - sections: - - local: concepts/git_vs_http - title: Git vs HTTP 范式 + diff --git a/docs/source/cn/concepts/git_vs_http.md b/docs/source/cn/concepts/git_vs_http.md deleted file mode 100644 index b582b5f991..0000000000 --- a/docs/source/cn/concepts/git_vs_http.md +++ /dev/null @@ -1,40 +0,0 @@ - - -# Git 与 HTTP 范式 - -`huggingface_hub`库是用于与Hugging Face Hub进行交互的库,Hugging Face Hub是一组基于Git的存储库(模型、数据集或Spaces)。使用 `huggingface_hub`有两种主要方式来访问Hub。 - -第一种方法,即所谓的“基于git”的方法,由[`Repository`]类驱动。这种方法使用了一个包装器,它在 `git`命令的基础上增加了专门与Hub交互的额外函数。第二种选择,称为“基于HTTP”的方法,涉及使用[`HfApi`]客户端进行HTTP请求。让我们来看一看每种方法的优缺点。 - -## 存储库:基于历史的 Git 方法 - -最初,`huggingface_hub`主要围绕 [`Repository`] 类构建。它为常见的 `git` 命令(如 `"git add"`、`"git commit"`、`"git push"`、`"git tag"`、`"git checkout"` 等)提供了 Python 包装器 - -该库还可以帮助设置凭据和跟踪大型文件,这些文件通常在机器学习存储库中使用。此外,该库允许您在后台执行其方法,使其在训练期间上传数据很有用。 - -使用 [`Repository`] 的最大优点是它允许你在本地机器上维护整个存储库的本地副本。这也可能是一个缺点,因为它需要你不断更新和维护这个本地副本。这类似于传统软件开发中,每个开发人员都维护自己的本地副本,并在开发功能时推送更改。但是,在机器学习的上下文中,这可能并不总是必要的,因为用户可能只需要下载推理所需的权重,或将权重从一种格式转换为另一种格式,而无需克隆整个存储库。 - -## HfApi: 一个功能强大且方便的HTTP客户端 - -`HfApi` 被开发为本地 git 存储库的替代方案,因为本地 git 存储库在处理大型模型或数据集时可能会很麻烦。`HfApi` 提供与基于 git 的方法相同的功能,例如下载和推送文件以及创建分支和标签,但无需本地文件夹来保持同步。 - -`HfApi`除了提供 `git` 已经提供的功能外,还提供其他功能,例如: - -* 管理存储库 -* 使用缓存下载文件以进行有效的重复使用 -* 在 Hub 中搜索存储库和元数据 -* 访问社区功能,如讨论、PR和评论 -* 配置Spaces - -## 我应该使用什么?以及何时使用? - -总的来说,在大多数情况下,`HTTP 方法`是使用 huggingface_hub 的推荐方法。但是,在以下几种情况下,维护本地 git 克隆(使用 `Repository`)可能更有益: - -如果您在本地机器上训练模型,使用传统的 git 工作流程并定期推送更新可能更有效。`Repository` 被优化为此类情况,因为它能够在后台运行。 -如果您需要手动编辑大型文件,`git `是最佳选择,因为它只会将文件的差异发送到服务器。使用 `HfAPI` 客户端,每次编辑都会上传整个文件。请记住,大多数大型文件是二进制文件,因此无法从 git 差异中受益。 - -并非所有 git 命令都通过 [`HfApi`] 提供。有些可能永远不会被实现,但我们一直在努力改进并缩小差距。如果您没有看到您的用例被覆盖。 - -请在[Github](https://github.com/huggingface/huggingface_hub)打开一个 issue!我们欢迎反馈,以帮助我们与我们的用户一起构建 🤗 生态系统。 diff --git a/docs/source/cn/guides/repository.md b/docs/source/cn/guides/repository.md index d84d37938b..e414cfc3e9 100644 --- a/docs/source/cn/guides/repository.md +++ b/docs/source/cn/guides/repository.md @@ -156,91 +156,3 @@ GitRefs( >>> from huggingface_hub import move_repo >>> move_repo(from_id="Wauplin/cool-model", to_id="huggingface/cool-model") ``` - -## 管理存储库的本地副本 - -上述所有操作都可以通过HTTP请求完成。然而,在某些情况下,您可能希望在本地拥有存储库的副本,并使用您熟悉的Git命令与之交互。 - -[`Repository`] 类允许您使用类似于Git命令的函数与Hub上的文件和存储库进行交互。它是对Git和Git-LFS方法的包装,以使用您已经了解和喜爱的Git命令。在开始之前,请确保已安装Git-LFS(请参阅[此处](https://git-lfs.github.com/)获取安装说明)。 - -### 使用本地存储库 - -使用本地存储库路径实例化一个 [`Repository`] 对象: - -请运行以下代码: - -```py ->>> from huggingface_hub import Repository ->>> repo = Repository(local_dir="//") -``` - -### 克隆 - -`clone_from`参数将一个存储库从Hugging Face存储库ID克隆到由 `local_dir`参数指定的本地目录: - -请运行以下代码: - -```py ->>> from huggingface_hub import Repository ->>> repo = Repository(local_dir="w2v2", clone_from="facebook/wav2vec2-large-960h-lv60") -``` -`clone_from`还可以使用URL克隆存储库: - -请运行以下代码: - -```py ->>> repo = Repository(local_dir="huggingface-hub", clone_from="https://huggingface.co/facebook/wav2vec2-large-960h-lv60") -``` - -你可以将`clone_from`参数与[`create_repo`]结合使用,以创建并克隆一个存储库: - -请运行以下代码: - -```py ->>> repo_url = create_repo(repo_id="repo_name") ->>> repo = Repository(local_dir="repo_local_path", clone_from=repo_url) -``` - -当你克隆一个存储库时,通过在克隆时指定`git_user`和`git_email`参数,你还可以为克隆的存储库配置Git用户名和电子邮件。当用户提交到该存储库时,Git将知道提交的作者是谁。 - -请运行以下代码: - -```py ->>> repo = Repository( -... "my-dataset", -... clone_from="/", -... token=True, -... repo_type="dataset", -... git_user="MyName", -... git_email="me@cool.mail" -... ) -``` - -### 分支 - -分支对于协作和实验而不影响当前文件和代码非常重要。使用[`~Repository.git_checkout`]来在不同的分支之间切换。例如,如果你想从 `branch1`切换到 `branch2`: - -请运行以下代码: - -```py ->>> from huggingface_hub import Repository ->>> repo = Repository(local_dir="huggingface-hub", clone_from="/", revision='branch1') ->>> repo.git_checkout("branch2") -``` - -### 拉取 - -[`~Repository.git_pull`] 允许你使用远程存储库的更改更新当前本地分支: - -请运行以下代码: - -```py ->>> from huggingface_hub import Repository ->>> repo.git_pull() -``` - -如果你希望本地的提交发生在你的分支被远程的新提交更新之后,请设置`rebase=True`: - -```py ->>> repo.git_pull(rebase=True) -``` diff --git a/docs/source/de/_toctree.yml b/docs/source/de/_toctree.yml index 48807ba0d8..2b994c7cc6 100644 --- a/docs/source/de/_toctree.yml +++ b/docs/source/de/_toctree.yml @@ -34,7 +34,3 @@ title: Integrieren einer Bibliothek - local: guides/webhooks_server title: Webhooks server -- title: "Konzeptionelle Anleitungen" - sections: - - local: concepts/git_vs_http - title: Git vs. HTTP-Paradigma diff --git a/docs/source/de/concepts/git_vs_http.md b/docs/source/de/concepts/git_vs_http.md deleted file mode 100644 index 978123762a..0000000000 --- a/docs/source/de/concepts/git_vs_http.md +++ /dev/null @@ -1,69 +0,0 @@ - - -# Git vs. HTTP-Paradigma - -Die `huggingface_hub`-Bibliothek ist eine Bibliothek zur Interaktion mit dem Hugging Face -Hub, einer Sammlung von auf Git basierenden Repositories (Modelle, Datensätze oder -Spaces). Es gibt zwei Hauptmethoden, um auf den Hub mit `huggingface_hub` zuzugreifen. - -Der erste Ansatz, der sogenannte "Git-basierte" Ansatz, wird von der [`Repository`] Klasse -geleitet. Diese Methode verwendet einen Wrapper um den `git`-Befehl mit zusätzlichen -Funktionen, die speziell für die Interaktion mit dem Hub entwickelt wurden. Die zweite -Option, die als "HTTP-basierter" Ansatz bezeichnet wird, umfasst das Senden von -HTTP-Anfragen mit dem [`HfApi`] Client. Schauen wir uns die Vor- und Nachteile jeder -Methode an. - -## Repository: Der historische git-basierte Ansatz - -Ursprünglich wurde `huggingface_hub` größtenteils um die [`Repository`] Klasse herum -entwickelt. Sie bietet Python-Wrapper für gängige git-Befehle wie `"git add"`, `"git commit"`, -`"git push"`, `"git tag"`, `"git checkout"` usw. - -Die Bibliothek hilft auch beim Festlegen von Zugangsdaten und beim Tracking von großen -Dateien, die in Machine-Learning-Repositories häufig verwendet werden. Darüber hinaus -ermöglicht die Bibliothek das Ausführen ihrer Methoden im Hintergrund, was nützlich ist, -um Daten während des Trainings hochzuladen. - -Der Hauptvorteil bei der Verwendung einer [`Repository`] besteht darin, dass Sie eine -lokale Kopie des gesamten Repositorys auf Ihrem Computer pflegen können. Dies kann jedoch -auch ein Nachteil sein, da es erfordert, diese lokale Kopie ständig zu aktualisieren und -zu pflegen. Dies ähnelt der traditionellen Softwareentwicklung, bei der jeder Entwickler -eine eigene lokale Kopie pflegt und Änderungen überträgt, wenn an einer Funktion -gearbeitet wird. Im Kontext des Machine Learning ist dies jedoch nicht immer erforderlich, -da Benutzer möglicherweise nur Gewichte für die Inferenz herunterladen oder Gewichte von -einem Format in ein anderes konvertieren müssen, ohne das gesamte Repository zu klonen. - -## HfApi: Ein flexibler und praktischer HTTP-Client - -Die [`HfApi`] Klasse wurde entwickelt, um eine Alternative zu lokalen Git-Repositories -bereitzustellen, die besonders bei der Arbeit mit großen Modellen oder Datensätzen -umständlich zu pflegen sein können. Die [`HfApi`] Klasse bietet die gleiche Funktionalität -wie git-basierte Ansätze, wie das Herunterladen und Hochladen von Dateien sowie das -Erstellen von Branches und Tags, jedoch ohne die Notwendigkeit eines lokalen Ordners, der -synchronisiert werden muss. - -Zusätzlich zu den bereits von `git` bereitgestellten Funktionen bietet die [`HfApi`] -Klasse zusätzliche Features wie die Möglichkeit, Repositories zu verwalten, Dateien mit -Caching für effiziente Wiederverwendung herunterzuladen, im Hub nach Repositories und -Metadaten zu suchen, auf Community-Funktionen wie Diskussionen, Pull Requests und -Kommentare zuzugreifen und Spaces-Hardware und Geheimnisse zu konfigurieren. - -## Was sollte ich verwenden ? Und wann ? - -Insgesamt ist der **HTTP-basierte Ansatz in den meisten Fällen die empfohlene Methode zur Verwendung von** -`huggingface_hub`. Es gibt jedoch einige Situationen, in denen es vorteilhaft sein kann, -eine lokale Git-Kopie (mit [`Repository`]) zu pflegen: -- Wenn Sie ein Modell auf Ihrem Computer trainieren, kann es effizienter sein, einen -herkömmlichen git-basierten Workflow zu verwenden und regelmäßige Updates zu pushen. -[`Repository`] ist für diese Art von Situation mit seiner Fähigkeit zur Hintergrundarbeit optimiert. -- Wenn Sie große Dateien manuell bearbeiten müssen, ist `git` die beste Option, da es nur -die Differenz an den Server sendet. Mit dem [`HfAPI`] Client wird die gesamte Datei bei -jeder Bearbeitung hochgeladen. Beachten Sie jedoch, dass die meisten großen Dateien binär -sind und daher sowieso nicht von Git-Diffs profitieren. - -Nicht alle Git-Befehle sind über [`HfApi`] verfügbar. Einige werden vielleicht nie -implementiert, aber wir bemühen uns ständig, die Lücken zu schließen und zu verbessern. -Wenn Sie Ihren Anwendungsfall nicht abgedeckt sehen, öffnen Sie bitte [ein Issue auf -Github](https://github.com/huggingface/huggingface_hub)! Wir freuen uns über Feedback, um das 🤗-Ökosystem mit und für unsere Benutzer aufzubauen. diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 4c03a41c7b..3f930fb448 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -54,8 +54,6 @@ title: Authentication - local: package_reference/environment_variables title: Environment variables - - local: package_reference/repository - title: Managing local and online repositories - local: package_reference/hf_api title: Hugging Face Hub API - local: package_reference/file_download diff --git a/docs/source/en/guides/repository.md b/docs/source/en/guides/repository.md index 9943514ddf..6a9ce1f9ff 100644 --- a/docs/source/en/guides/repository.md +++ b/docs/source/en/guides/repository.md @@ -178,85 +178,3 @@ that you should be aware of. For example, you can't transfer your repo to anothe >>> from huggingface_hub import move_repo >>> move_repo(from_id="Wauplin/cool-model", to_id="huggingface/cool-model") ``` - -## Manage a local copy of your repository - -All the actions described above can be done using HTTP requests. However, in some cases you might be interested in having -a local copy of your repository and interact with it using the Git commands you are familiar with. - -The [`Repository`] class allows you to interact with files and repositories on the Hub with functions similar to Git commands. It is a wrapper over Git and Git-LFS methods to use the Git commands you already know and love. Before starting, please make sure you have Git-LFS installed (see [here](https://git-lfs.github.com/) for installation instructions). - - - -[`Repository`] is deprecated in favor of the http-based alternatives implemented in [`HfApi`]. Given its large adoption in legacy code, the complete removal of [`Repository`] will only happen in release `v1.0`. For more details, please read [this explanation page](./concepts/git_vs_http). - - - -### Use a local repository - -Instantiate a [`Repository`] object with a path to a local repository: - -```py ->>> from huggingface_hub import Repository ->>> repo = Repository(local_dir="//") -``` - -### Clone - -The `clone_from` parameter clones a repository from a Hugging Face repository ID to a local directory specified by the `local_dir` argument: - -```py ->>> from huggingface_hub import Repository ->>> repo = Repository(local_dir="w2v2", clone_from="facebook/wav2vec2-large-960h-lv60") -``` - -`clone_from` can also clone a repository using a URL: - -```py ->>> repo = Repository(local_dir="huggingface-hub", clone_from="https://huggingface.co/facebook/wav2vec2-large-960h-lv60") -``` - -You can combine the `clone_from` parameter with [`create_repo`] to create and clone a repository: - -```py ->>> repo_url = create_repo(repo_id="repo_name") ->>> repo = Repository(local_dir="repo_local_path", clone_from=repo_url) -``` - -You can also configure a Git username and email to a cloned repository by specifying the `git_user` and `git_email` parameters when you clone a repository. When users commit to that repository, Git will be aware of the commit author. - -```py ->>> repo = Repository( -... "my-dataset", -... clone_from="/", -... token=True, -... repo_type="dataset", -... git_user="MyName", -... git_email="me@cool.mail" -... ) -``` - -### Branch - -Branches are important for collaboration and experimentation without impacting your current files and code. Switch between branches with [`~Repository.git_checkout`]. For example, if you want to switch from `branch1` to `branch2`: - -```py ->>> from huggingface_hub import Repository ->>> repo = Repository(local_dir="huggingface-hub", clone_from="/", revision='branch1') ->>> repo.git_checkout("branch2") -``` - -### Pull - -[`~Repository.git_pull`] allows you to update a current local branch with changes from a remote repository: - -```py ->>> from huggingface_hub import Repository ->>> repo.git_pull() -``` - -Set `rebase=True` if you want your local commits to occur after your branch is updated with the new commits from the remote: - -```py ->>> repo.git_pull(rebase=True) -``` diff --git a/docs/source/en/guides/upload.md b/docs/source/en/guides/upload.md index 69930887eb..ff91031eba 100644 --- a/docs/source/en/guides/upload.md +++ b/docs/source/en/guides/upload.md @@ -4,12 +4,7 @@ rendered properly in your Markdown viewer. # Upload files to the Hub -Sharing your files and work is an important aspect of the Hub. The `huggingface_hub` offers several options for uploading your files to the Hub. You can use these functions independently or integrate them into your library, making it more convenient for your users to interact with the Hub. This guide will show you how to push files: - -- without using Git. -- that are very large with [Git LFS](https://git-lfs.github.com/). -- with the `commit` context manager. -- with the [`~Repository.push_to_hub`] function. +Sharing your files and work is an important aspect of the Hub. The `huggingface_hub` offers several options for uploading your files to the Hub. You can use these functions independently or integrate them into your library, making it more convenient for your users to interact with the Hub. Whenever you want to upload files to the Hub, you need to log in to your Hugging Face account. For more details about authentication, check out [this section](../quick-start#authentication). @@ -486,114 +481,3 @@ update of the object is that **the binary content is removed** from it, meaning you don't store another reference to it. This is expected as we don't want to keep in memory the content that is already uploaded. Finally we create the commit by passing all the operations to [`create_commit`]. You can pass additional operations (add, delete or copy) that have not been processed yet and they will be handled correctly. - -## (legacy) Upload files with Git LFS - -All the methods described above use the Hub's API to upload files. This is the recommended way to upload files to the Hub. -However, we also provide [`Repository`], a wrapper around the git tool to manage a local repository. - - - -Although [`Repository`] is not formally deprecated, we recommend using the HTTP-based methods described above instead. -For more details about this recommendation, please have a look at [this guide](../concepts/git_vs_http) explaining the -core differences between HTTP-based and Git-based approaches. - - - -Git LFS automatically handles files larger than 10MB. But for very large files (>5GB), you need to install a custom transfer agent for Git LFS: - -```bash -hf lfs-enable-largefiles -``` - -You should install this for each repository that has a very large file. Once installed, you'll be able to push files larger than 5GB. - -### commit context manager - -The `commit` context manager handles four of the most common Git commands: pull, add, commit, and push. `git-lfs` automatically tracks any file larger than 10MB. In the following example, the `commit` context manager: - -1. Pulls from the `text-files` repository. -2. Adds a change made to `file.txt`. -3. Commits the change. -4. Pushes the change to the `text-files` repository. - -```python ->>> from huggingface_hub import Repository ->>> with Repository(local_dir="text-files", clone_from="/text-files").commit(commit_message="My first file :)"): -... with open("file.txt", "w+") as f: -... f.write(json.dumps({"hey": 8})) -``` - -Here is another example of how to use the `commit` context manager to save and upload a file to a repository: - -```python ->>> import torch ->>> model = torch.nn.Transformer() ->>> with Repository("torch-model", clone_from="/torch-model", token=True).commit(commit_message="My cool model :)"): -... torch.save(model.state_dict(), "model.pt") -``` - -Set `blocking=False` if you would like to push your commits asynchronously. Non-blocking behavior is helpful when you want to continue running your script while your commits are being pushed. - -```python ->>> with repo.commit(commit_message="My cool model :)", blocking=False) -``` - -You can check the status of your push with the `command_queue` method: - -```python ->>> last_command = repo.command_queue[-1] ->>> last_command.status -``` - -Refer to the table below for the possible statuses: - -| Status | Description | -| -------- | ------------------------------------ | -| -1 | The push is ongoing. | -| 0 | The push has completed successfully. | -| Non-zero | An error has occurred. | - -When `blocking=False`, commands are tracked, and your script will only exit when all pushes are completed, even if other errors occur in your script. Some additional useful commands for checking the status of a push include: - -```python -# Inspect an error. ->>> last_command.stderr - -# Check whether a push is completed or ongoing. ->>> last_command.is_done - -# Check whether a push command has errored. ->>> last_command.failed -``` - -### push_to_hub - -The [`Repository`] class has a [`~Repository.push_to_hub`] function to add files, make a commit, and push them to a repository. Unlike the `commit` context manager, you'll need to pull from a repository first before calling [`~Repository.push_to_hub`]. - -For example, if you've already cloned a repository from the Hub, then you can initialize the `repo` from the local directory: - -```python ->>> from huggingface_hub import Repository ->>> repo = Repository(local_dir="path/to/local/repo") -``` - -Update your local clone with [`~Repository.git_pull`] and then push your file to the Hub: - -```py ->>> repo.git_pull() ->>> repo.push_to_hub(commit_message="Commit my-awesome-file to the Hub") -``` - -However, if you aren't ready to push a file yet, you can use [`~Repository.git_add`] and [`~Repository.git_commit`] to only add and commit your file: - -```py ->>> repo.git_add("path/to/file") ->>> repo.git_commit(commit_message="add my first model config file :)") -``` - -When you're ready, push the file to your repository with [`~Repository.git_push`]: - -```py ->>> repo.git_push() -``` diff --git a/docs/source/en/package_reference/repository.md b/docs/source/en/package_reference/repository.md deleted file mode 100644 index de7851d6a9..0000000000 --- a/docs/source/en/package_reference/repository.md +++ /dev/null @@ -1,51 +0,0 @@ - - -# Managing local and online repositories - -The `Repository` class is a helper class that wraps `git` and `git-lfs` commands. It provides tooling adapted -for managing repositories which can be very large. - -It is the recommended tool as soon as any `git` operation is involved, or when collaboration will be a point -of focus with the repository itself. - -## The Repository class - -[[autodoc]] Repository - - __init__ - - current_branch - - all - -## Helper methods - -[[autodoc]] huggingface_hub.repository.is_git_repo - -[[autodoc]] huggingface_hub.repository.is_local_clone - -[[autodoc]] huggingface_hub.repository.is_tracked_with_lfs - -[[autodoc]] huggingface_hub.repository.is_git_ignored - -[[autodoc]] huggingface_hub.repository.files_to_be_staged - -[[autodoc]] huggingface_hub.repository.is_tracked_upstream - -[[autodoc]] huggingface_hub.repository.commits_to_push - -## Following asynchronous commands - -The `Repository` utility offers several methods which can be launched asynchronously: -- `git_push` -- `git_pull` -- `push_to_hub` -- The `commit` context manager - -See below for utilities to manage such asynchronous methods. - -[[autodoc]] Repository - - commands_failed - - commands_in_progress - - wait_for_commands - -[[autodoc]] huggingface_hub.repository.CommandInProgress diff --git a/docs/source/fr/_toctree.yml b/docs/source/fr/_toctree.yml index f6c76ff6f5..d9ed776e0a 100644 --- a/docs/source/fr/_toctree.yml +++ b/docs/source/fr/_toctree.yml @@ -6,10 +6,6 @@ title: Démarrage rapide - local: installation title: Installation -- title: "Concepts" - sections: - - local: concepts/git_vs_http - title: Git ou HTTP? - title: "Guides" sections: - local: guides/integrations diff --git a/docs/source/fr/concepts/git_vs_http.md b/docs/source/fr/concepts/git_vs_http.md deleted file mode 100644 index 8ccc31b69c..0000000000 --- a/docs/source/fr/concepts/git_vs_http.md +++ /dev/null @@ -1,67 +0,0 @@ - - -# Git ou HTTP? - -`huggingface_hub` est une librairie qui permet d'interagir avec le Hugging Face Hub, -qui est une collection de dépots Git (modèles, datasets ou spaces). -Il y a deux manières principales pour accéder au Hub en utilisant `huggingface_hub`. - -La première approche, basée sur Git, appelée approche "git-based", est rendue possible par la classe [`Repository`]. -Cette méthode utilise un wrapper autour de la commande `git` avec des fonctionnalités supplémentaires conçues pour interagir avec le Hub. La deuxième option, appelée approche "HTTP-based" , consiste à faire des requêtes HTTP en utilisant le client [`HfApi`]. Examinons -les avantages et les inconvénients de ces deux méthodes. - -## Repository: l'approche historique basée sur git - -Initialement, `huggingface_hub` était principalement construite autour de la classe [`Repository`]. Elle fournit des -wrappers Python pour les commandes `git` usuelles, telles que `"git add"`, `"git commit"`, `"git push"`, -`"git tag"`, `"git checkout"`, etc. - -Cette librairie permet aussi de gérer l'authentification et les fichiers volumineux, souvent présents dans les dépôts Git de machine learning. De plus, ses méthodes sont exécutables en arrière-plan, ce qui est utile pour upload des données durant l'entrainement d'un modèle. - -L'avantage principal de l'approche [`Repository`] est qu'elle permet de garder une -copie en local du dépot Git sur votre machine. Cela peut aussi devenir un désavantage, -car cette copie locale doit être mise à jour et maintenue constamment. C'est une méthode -analogue au développement de logiciel classique où chaque développeur maintient sa propre copie locale -et push ses changements lorsqu'il travaille sur une nouvelle fonctionnalité. -Toutefois, dans le contexte du machine learning la taille des fichiers rend peu pertinente cette approche car -les utilisateurs ont parfois besoin d'avoir -uniquement les poids des modèles pour l'inférence ou de convertir ces poids d'un format à un autre sans avoir à cloner -tout le dépôt. - - - -[`Repository`] est maintenant obsolète et remplacée par les alternatives basées sur des requêtes HTTP. Étant donné son adoption massive par les utilisateurs, -la suppression complète de [`Repository`] ne sera faite que pour la version `v1.0`. - - - -## HfApi: Un client HTTP plus flexible - -La classe [`HfApi`] a été développée afin de fournir une alternative aux dépôts git locaux, -qui peuvent être encombrant à maintenir, en particulier pour des modèles ou datasets volumineux. -La classe [`HfApi`] offre les mêmes fonctionnalités que les approches basées sur Git, -telles que le téléchargement et le push de fichiers ainsi que la création de branches et de tags, mais sans -avoir besoin d'un fichier local qui doit être constamment synchronisé. - -En plus des fonctionnalités déjà fournies par `git`, La classe [`HfApi`] offre des fonctionnalités -additionnelles, telles que la capacité à gérer des dépôts, le téléchargement des fichiers -dans le cache (permettant une réutilisation), la recherche dans le Hub pour trouver -des dépôts et des métadonnées, l'accès aux fonctionnalités communautaires telles que, les discussions, -les pull requests et les commentaires. - -## Quelle méthode utiliser et quand ? - -En général, **l'approche HTTP est la méthode recommandée** pour utiliser `huggingface_hub` -[`HfApi`] permet de pull et push des changements, de travailler avec les pull requests, les tags et les branches, l'interaction avec les discussions -et bien plus encore. Depuis la version `0.16`, les méthodes HTTP-based peuvent aussi être exécutées en arrière-plan, ce qui constituait le -dernier gros avantage de la classe [`Repository`]. - -Toutefois, certaines commandes restent indisponibles en utilisant [`HfApi`]. -Peut être que certaines ne le seront jamais, mais nous essayons toujours de réduire le fossé entre ces deux approches. -Si votre cas d'usage n'est pas couvert, nous serions ravis de vous aider. Pour cela, ouvrez -[une issue sur Github](https://github.com/huggingface/huggingface_hub)! Nous écoutons tous les retours nous permettant de construire -l'écosystème 🤗 avec les utilisateurs et pour les utilisateurs. - -Cette préférence pour l'approche basée sur [`HfApi`] plutôt que [`Repository`] ne signifie pas que les dépôts stopperons d'être versionnés avec git sur le Hugging Face Hub. Il sera toujours possible d'utiliser les commandes `git` en local lorsque nécessaire. \ No newline at end of file diff --git a/docs/source/hi/_toctree.yml b/docs/source/hi/_toctree.yml index 5b9e412c50..f8b3606536 100644 --- a/docs/source/hi/_toctree.yml +++ b/docs/source/hi/_toctree.yml @@ -6,7 +6,3 @@ title: जल्दी शुरू - local: installation title: इंस्टालेशन -- title: "संकल्पना मार्गदर्शिकाएँ" - sections: - - local: concepts/git_vs_http - title: "संकल्पनाएँ/गिट_बनाम_एचटीटीपी" diff --git a/docs/source/hi/concepts/git_vs_http.md b/docs/source/hi/concepts/git_vs_http.md deleted file mode 100644 index ebb3574352..0000000000 --- a/docs/source/hi/concepts/git_vs_http.md +++ /dev/null @@ -1,33 +0,0 @@ -# Git vs HTTP पैराडाइम - -`huggingface_hub` लाइब्रेरी Hugging Face Hub के साथ आदान-प्रदान करने के लिए एक लाइब्रेरी है, जो git-आधारित repositories (models, datasets या Spaces) का एक संग्रह है। `huggingface_hub` का उपयोग करके Hub तक पहुंचने के दो मुख्य तरीके हैं। - -पहला तरीका, जिसे "git-आधारित" तरीका कहा जाता है, [`Repository`] क्लास द्वारा संचालित है। यह विधि `git` कमांड के चारों ओर एक आवरण का उपयोग करती है जिसमें Hub के साथ आदान-प्रदान करने के लिए विशेष रूप से डिज़ाइन किए गए अतिरिक्त functions हैं। दूसरा विकल्प, जिसे "HTTP-आधारित" तरीका कहा जाता है, [`HfApi`] client का उपयोग करके HTTP requests बनाने में शामिल है। आइए प्रत्येक तरीका के फायदे और नुकसान की जांच करते हैं। - -## Repository: ऐतिहासिक git-आधारित तरीका - -शुरुआत में, `huggingface_hub` मुख्य रूप से [`Repository`] क्लास के चारों ओर बनाया गया था। यह सामान्य `git` कमांड जैसे `"git add"`, `"git commit"`, `"git push"`, `"git tag"`, `"git checkout"`, आदि के लिए Python wrappers प्रदान करता है। - -लाइब्रेरी विवरण सेट करने और बड़ी फाइलों को track करने में भी मदद करती है, जो अक्सर machine learning repositories में उपयोग की जाती हैं। इसके अतिरिक्त, लाइब्रेरी आपको अपनी विधियों को पृष्ठभूमि में कार्यान्वित करने की अनुमति देती है, जो training के दौरान डेटा अपलोड करने के लिए उपयोगी है। - -[`Repository`] का उपयोग करने का मुख्य फायदा यह है कि यह आपको अपनी मशीन पर संपूर्ण repository की एक local copy बनाए रखने की अनुमति देता है। यह एक नुकसान भी हो सकता है क्योंकि इसके लिए आपको इस local copy को लगातार update और maintain करना होता है। यह पारंपरिक software development के समान है जहां प्रत्येक developer अपनी स्वयं की local copy maintain करता है और feature पर काम करते समय changes push करता है। हालांकि, machine learning के संदर्भ में, यह हमेशा आवश्यक नहीं हो सकता क्योंकि users को केवल inference के लिए weights download करने या weights को एक format से दूसरे में convert करने की आवश्यकता हो सकती है, बिना पूरी repository को clone करने की आवश्यकता के। - - - -[`Repository`] अब http-आधारित विकल्पों के पक्ष में deprecated है। legacy code में इसकी बड़ी अपनाई जाने के कारण, [`Repository`] का पूर्ण removal केवल `v1.0` release में होगा। - - - -## HfApi: एक लचीला और सुविधाजनक HTTP client - -[`HfApi`] क्लास को local git repositories का एक विकल्प प्रदान करने के लिए विकसित किया गया था, जो maintain करना मुश्किल हो सकता है, विशेष रूप से बड़े models या datasets के साथ व्यवहार करते समय। [`HfApi`] क्लास git-आधारित तरीकाों की समान functionality प्रदान करती है, जैसे files download और push करना और branches तथा tags बनाना, लेकिन एक local folder की आवश्यकता के बिना जिसे sync में रखना पड़ता है। - -`git` द्वारा पहले से प्रदान की गई functionalities के अलावा, [`HfApi`] क्लास अतिरिक्त features प्रदान करती है, जैसे repos manage करने की क्षमता, efficient reuse के लिए caching का उपयोग करके files download करना, repos और metadata के लिए Hub को search करना, discussions, PRs, और comments जैसी community features तक पहुंच, और Spaces hardware और secrets को configure करना। - -## मुझे क्या उपयोग करना चाहिए? और कब? - -कुल मिलाकर, **HTTP-आधारित तरीका सभी cases में** `huggingface_hub` का उपयोग करने का **अनुशंसित तरीका है**। [`HfApi`] changes को pull और push करने, PRs, tags और branches के साथ काम करने, discussions के साथ interact करने और बहुत कुछ करने की अनुमति देता है। `0.16` release के बाद से, http-आधारित methods भी पृष्ठभूमि में चल सकती हैं, जो [`Repository`] क्लास का अंतिम प्रमुख फायदा था। - -हालांकि, सभी git commands [`HfApi`] के माध्यम से उपलब्ध नहीं हैं। कुछ को कभी भी implement नहीं किया जा सकता है, लेकिन हम हमेशा सुधार करने और gap को बंद करने की कोशिश कर रहे हैं। यदि आपको अपना use case covered नहीं दिखता है, तो कृपया [Github पर एक issue खोलें](https://github.com/huggingface/huggingface_hub)! हम अपने users के साथ और उनके लिए 🤗 ecosystem बनाने में मदद करने के लिए feedback का स्वागत करते हैं। - -git-आधारित [`Repository`] पर http-आधारित [`HfApi`] की यह प्राथमिकता का मतलब यह नहीं है कि git versioning Hugging Face Hub से जल्द ही गायब हो जाएगी। workflows में जहां यह समझ में आता है, वहां `git` commands का locally उपयोग करना हमेशा संभव होगा। \ No newline at end of file diff --git a/docs/source/ko/_toctree.yml b/docs/source/ko/_toctree.yml index 2c7a4da702..0a82cd72db 100644 --- a/docs/source/ko/_toctree.yml +++ b/docs/source/ko/_toctree.yml @@ -40,10 +40,6 @@ title: 라이브러리 통합 - local: guides/webhooks_server title: 웹훅 서버 -- title: "개념 가이드" - sections: - - local: concepts/git_vs_http - title: Git 대 HTTP 패러다임 - title: "라이브러리 레퍼런스" sections: - local: package_reference/overview @@ -52,8 +48,6 @@ title: 로그인 및 로그아웃 - local: package_reference/environment_variables title: 환경 변수 - - local: package_reference/repository - title: 로컬 및 온라인 리포지토리 관리 - local: package_reference/hf_api title: 허깅페이스 Hub API - local: package_reference/file_download diff --git a/docs/source/ko/concepts/git_vs_http.md b/docs/source/ko/concepts/git_vs_http.md deleted file mode 100644 index 7f2bd9933f..0000000000 --- a/docs/source/ko/concepts/git_vs_http.md +++ /dev/null @@ -1,53 +0,0 @@ - - -# Git 대 HTTP 패러다임 - -`huggingface_hub` 라이브러리는 git 기반의 저장소(Models, Datasets 또는 Spaces)로 구성된 Hugging Face Hub과 상호 작용하기 위한 라이브러리입니다. -`huggingface_hub`를 사용하여 Hub에 접근하는 방법은 크게 두 가지입니다. - -첫 번째 접근 방식인 소위 "git 기반" 접근 방식은 [`Repository`] 클래스가 주도합니다. -이 방법은 허브와 상호 작용하도록 특별히 설계된 추가 기능이 있는 `git` 명령에 랩퍼를 사용합니다. -두 번째 방법은 "HTTP 기반" 접근 방식이며, [`HfApi`] 클라이언트를 사용하여 HTTP 요청을 수행합니다. -각 방법의 장단점을 살펴보겠습니다. - -## Repository: 역사적인 Git 기반 접근 방식 - -먼저, `huggingface_hub`는 주로 [`Repository`] 클래스를 기반으로 구축되었습니다. -이 클래스는 `"git add"`, `"git commit"`, `"git push"`, `"git tag"`, `"git checkout"` 등과 같은 일반적인 `git` 명령에 대한 Python 랩퍼를 제공합니다. - -이 라이브러리는 머신러닝 저장소에서 자주 사용되는 큰 파일을 추적하고 자격 증명을 설정하는 데 도움이 됩니다. -또한, 이 라이브러리는 백그라운드에서 메소드를 실행할 수 있어, 훈련 중에 데이터를 업로드할 때 유용합니다. - -로컬 머신에 전체 저장소의 로컬 복사본을 유지할 수 있다는 것은 [`Repository`]를 사용하는 가장 큰 장점입니다. -하지만 동시에 로컬 복사본을 지속적으로 업데이트하고 유지해야 한다는 단점이 될 수도 있습니다. -이는 각 개발자가 자체 로컬 복사본을 유지하고 기능을 개발할 때 변경 사항을 push하는 전통적인 소프트웨어 개발과 유사합니다. -그러나 머신러닝의 경우, 사용자가 전체 저장소를 복제할 필요 없이 추론을 위해 가중치만 다운로드하거나 가중치를 한 형식에서 다른 형식으로 변환하기만 하면 되기 때문에 이런 방식이 항상 필요한 것은 아닙니다. - - - -[`Repository`]는 지원이 중단될 예정이므로 HTTP 기반 대안을 사용하는 것을 권장합니다. 기존 코드에서 널리 사용되기 때문에 [`Repository`]의 완전한 제거는 릴리스 `v1.0`에서 이루어질 예정입니다. - - - -## HfApi: 유연하고 편리한 HTTP 클라이언트 - -[`HfApi`] 클래스는 특히 큰 모델이나 데이터셋을 처리할 때 유지하기 어려운 로컬 git 저장소의 대안으로 개발되었습니다. -[`HfApi`] 클래스는 파일 다운로드 및 push, 브랜치 및 태그 생성과 같은 git 기반 접근 방식과 동일한 기능을 제공하지만, 동기화 상태를 유지해야 하는 로컬 폴더가 필요하지 않습니다. - -[`HfApi`] 클래스는 `git`이 제공하는 기능 외에도 추가적인 기능을 제공합니다. -저장소를 관리하고, 효율적인 재사용을 위해 캐싱을 사용하여 파일을 다운로드하고, Hub에서 저장소 및 메타데이터를 검색하고, 토론, PR 및 코멘트와 같은 커뮤니티 기능에 접근하고, Spaces 하드웨어 및 시크릿을 구성할 수 있습니다. - -## 무엇을 사용해야 하나요? 언제 사용하나요? - -전반적으로, **HTTP 기반 접근 방식은 모든 경우에** `huggingface_hub`를 사용하는 것이 좋습니다. -[`HfApi`]를 사용하면 변경 사항을 pull하고 push하고, PR, 태그 및 브랜치로 작업하고, 토론과 상호 작용하는 등의 작업을 할 수 있습니다. -`0.16` 릴리스부터는 [`Repository`] 클래스의 마지막 주요 장점이었던 http 기반 메소드도 백그라운드에서 실행할 수 있습니다. - -그러나 모든 git 명령이 [`HfApi`]를 통해 사용 가능한 것은 아닙니다. 일부는 구현되지 않을 수도 있지만, 저희는 항상 개선하고 격차를 줄이기 위해 노력하고 있습니다. -사용 사례에 해당되지 않는 경우, [Github에서 이슈](https://github.com/huggingface/huggingface_hub)를 개설해 주세요! -사용자와 함께, 사용자를 위한 🤗 생태계를 구축하는 데 도움이 되는 피드백을 환영합니다. - -git 기반 [`Repository`]보다 http 기반 [`HfApi`]를 선호한다고 해서 Hugging Face Hub에서 git 버전 관리가 바로 사라지는 것은 아닙니다. -워크플로우 상 합당하다면 언제든 로컬에서 `git` 명령을 사용할 수 있습니다. diff --git a/docs/source/ko/guides/repository.md b/docs/source/ko/guides/repository.md deleted file mode 100644 index 7544608d1c..0000000000 --- a/docs/source/ko/guides/repository.md +++ /dev/null @@ -1,223 +0,0 @@ - - -# 리포지토리 생성과 관리[[create-and-manage-a-repository]] - -Hugging Face Hub는 Git 리포지토리 모음입니다. [Git](https://git-scm.com/)은 협업을 할 때 여러 프로젝트 버전을 쉽게 관리하기 위해 널리 사용되는 소프트웨어 개발 도구입니다. 이 가이드에서는 Hub의 리포지토리 사용법인 다음 내용을 다룹니다: - -- 리포지토리 생성과 삭제. -- 태그 및 브랜치 관리. -- 리포지토리 이름 변경. -- 리포지토리 공개 여부. -- 리포지토리 복사본 관리. - - - -GitLab/GitHub/Bitbucket과 같은 플랫폼을 사용해 본 경험이 있다면, 모델 리포지토리를 관리하기 위해 `git` CLI를 사용해 git 리포지토리를 클론(`git clone`)하고 변경 사항을 커밋(`git add, git commit`)하고 커밋한 내용을 푸시(`git push`) 하는것이 가장 먼저 떠오를 것입니다. 이 명령어들은 Hugging Face Hub에서도 사용할 수 있습니다. 하지만 소프트웨어 엔지니어링과 머신러닝은 동일한 요구 사항과 워크플로우를 공유하지 않습니다. 모델 리포지토리는 다양한 프레임워크와 도구를 위한 대규모 모델 가중치 파일을 유지관리 할 수 있으므로, 리포지토리를 복제하면 대규모 로컬 폴더를 유지관리하고 막대한 크기의 파일을 다루게 될 수 있습니다. 결과적으로 Hugging Face의 커스텀 HTTP 방법을 사용하는 것이 더욱 효율적일 수 있습니다. 더 자세한 내용은 [Git vs HTTP paradigm](../concepts/git_vs_http) 문서를 참조하세요. - - - -Hub에 리포지토리를 생성하고 관리하려면, 로그인이 되어 있어야 합니다. 로그인이 안 되어있다면 [이 문서](../quick-start#authentication)를 참고해 주세요. 이 가이드에서는 로그인이 되어있다는 가정하에 진행됩니다. - -## 리포지토리 생성 및 삭제[[repo-creation-and-deletion]] - -첫 번째 단계는 어떻게 리포지토리를 생성하고 삭제하는지를 알아야 합니다. 사용자 이름 네임스페이스 아래에 소유한 리포지토리 또는 쓰기 권한이 있는 조직의 리포지토리만 관리할 수 있습니다. - -### 리포지토리 생성[[create-a-repository]] - -[`create_repo`] 함수로 함께 빈 리포지토리를 만들고 `repo_id` 매개변수를 사용하여 이름을 정하세요. `repo_id`는 사용자 이름 또는 조직 이름 뒤에 리포지토리 이름이 따라옵니다: `username_or_org/repo_name`. - -```py ->>> from huggingface_hub import create_repo ->>> create_repo("lysandre/test-model") -'https://huggingface.co/lysandre/test-model' -``` - -기본적으로 [`create_repo`]는 모델 리포지토리를 만듭니다. 하지만 `repo_type` 매개변수를 사용하여 다른 유형의 리포지토리를 지정할 수 있습니다. 예를 들어 데이터셋 리포지토리를 만들고 싶다면: - -```py ->>> from huggingface_hub import create_repo ->>> create_repo("lysandre/test-dataset", repo_type="dataset") -'https://huggingface.co/datasets/lysandre/test-dataset' -``` - -리포지토리를 만들 때, `private` 매개변수를 사용하여 가시성을 설정할 수 있습니다. - -```py ->>> from huggingface_hub import create_repo ->>> create_repo("lysandre/test-private", private=True) -``` - -추후 리포지토리 가시성을 변경하고 싶다면, [`update_repo_settings`] 함수를 이용해 바꿀 수 있습니다. - -### 리포지토리 삭제[[delete-a-repository]] - -[`delete_repo`]를 사용하여 리포지토리를 삭제할 수 있습니다. 리포지토리를 삭제하기 전에 신중히 결정하세요. 왜냐하면, 삭제하고 나서 다시 되돌릴 수 없는 프로세스이기 때문입니다! - -삭제하려는 리포지토리의 `repo_id`를 지정하세요: - -```py ->>> delete_repo(repo_id="lysandre/my-corrupted-dataset", repo_type="dataset") -``` - -### 리포지토리 복제(Spaces 전용)[[duplicate-a-repository-only-for-spaces]] - -가끔 다른 누군가의 리포지토리를 복사하여, 상황에 맞게 수정하고 싶을 때가 있습니다. 이는 [`duplicate_space`]를 사용하여 Space에 복사할 수 있습니다. 이 함수를 사용하면 리포지토리 전체를 복제할 수 있습니다. 그러나 여전히 하드웨어, 절전 시간, 리포지토리, 변수 및 비밀번호와 같은 자체 설정을 구성해야 합니다. 자세한 내용은 [Manage your Space](./manage-spaces) 문서를 참조하십시오. - -```py ->>> from huggingface_hub import duplicate_space ->>> duplicate_space("multimodalart/dreambooth-training", private=False) -RepoUrl('https://huggingface.co/spaces/nateraw/dreambooth-training',...) -``` - -## 파일 다운로드와 업로드[[upload-and-download-files]] - -이제 리포지토리를 생성했으므로, 변경 사항을 푸시하고 파일을 다운로드하는 것에 관심이 있을 것입니다. - -이 두 가지 주제는 각각 자체 가이드가 필요합니다. 리포지토리 사용하는 방법에 대해 알아보려면 [업로드](./upload) 및 [다운로드](./download) 문서를 참조하세요. - -## 브랜치와 태그[[branches-and-tags]] - -Git 리포지토리는 동일한 리포지토리의 다른 버전을 저장하기 위해 브랜치들을 사용합니다. 태그는 버전을 출시할 때와 같이 리포지토리의 특정 상태를 표시하는 데 사용될 수도 있습니다. 일반적으로 브랜치와 태그는 [git 참조](https://git-scm.com/book/en/v2/Git-Internals-Git-References) -로 참조됩니다. - -### 브랜치 생성과 태그[[create-branches-and-tags]] - -[`create_branch`]와 [`create_tag`]를 이용하여 새로운 브랜치와 태그를 생성할 수 있습니다. - -```py ->>> from huggingface_hub import create_branch, create_tag - -# `main` 브랜치를 기반으로 Space 저장소에 새 브랜치를 생성합니다. ->>> create_branch("Matthijs/speecht5-tts-demo", repo_type="space", branch="handle-dog-speaker") - -# `v0.1-release` 브랜치를 기반으로 Dataset 저장소에 태그를 생성합니다. ->>> create_tag("bigcode/the-stack", repo_type="dataset", revision="v0.1-release", tag="v0.1.1", tag_message="Bump release version.") -``` - -같은 방식으로 [`delete_branch`]와 [`delete_tag`] 함수를 사용하여 브랜치 또는 태그를 삭제할 수 있습니다. - -### 모든 브랜치와 태그 나열[[list-all-branches-and-tags]] - -[`list_repo_refs`]를 사용하여 리포지토리로부터 현재 존재하는 git 참조를 나열할 수 있습니다: - -```py ->>> from huggingface_hub import list_repo_refs ->>> list_repo_refs("bigcode/the-stack", repo_type="dataset") -GitRefs( - branches=[ - GitRefInfo(name='main', ref='refs/heads/main', target_commit='18edc1591d9ce72aa82f56c4431b3c969b210ae3'), - GitRefInfo(name='v1.1.a1', ref='refs/heads/v1.1.a1', target_commit='f9826b862d1567f3822d3d25649b0d6d22ace714') - ], - converts=[], - tags=[ - GitRefInfo(name='v1.0', ref='refs/tags/v1.0', target_commit='c37a8cd1e382064d8aced5e05543c5f7753834da') - ] -) -``` - -## 리포지토리 설정 변경[[change-repository-settings]] - -리포지토리는 구성할 수 있는 몇 가지 설정이 있습니다. 대부분의 경우 브라우저의 리포지토리 설정 페이지에서 직접 설정할 것입니다. 설정을 바꾸려면 리포지토리에 대한 쓰기 액세스 권한이 있어야 합니다(사용자 리포지토리거나, 조직의 구성원이어야 함). 이 주제에서는 `huggingface_hub`를 사용하여 프로그래밍 방식으로 구성할 수 있는 설정을 알아보겠습니다. - -Spaces를 위한 특정 설정들(하드웨어, 환경변수 등)을 구성하기 위해서는 [Manage your Spaces](../guides/manage-spaces) 문서를 참조하세요. - -### 가시성 업데이트[[update-visibility]] - -리포지토리는 공개 또는 비공개로 설정할 수 있습니다. 비공개 리포지토리는 해당 저장소의 사용자 혹은 소속된 조직의 구성원만 볼 수 있습니다. 다음과 같이 리포지토리를 비공개로 변경할 수 있습니다. - -```py ->>> from huggingface_hub import update_repo_settings ->>> update_repo_settings(repo_id=repo_id, private=True) -``` - -### 리포지토리 이름 변경[[rename-your-repository]] - -[`move_repo`]를 사용하여 Hub에 있는 리포지토리 이름을 변경할 수 있습니다. 이 함수를 사용하여 개인에서 조직 리포지토리로 이동할 수도 있습니다. 이렇게 하면 [일부 제한 사항](https://hf.co/docs/hub/repositories-settings#renaming-or-transferring-a-repo)이 있으므로 주의해야 합니다. 예를 들어, 다른 사용자에게 리포지토리를 이전할 수는 없습니다. - -```py ->>> from huggingface_hub import move_repo ->>> move_repo(from_id="Wauplin/cool-model", to_id="huggingface/cool-model") -``` - -## 리포지토리의 로컬 복사본 관리[[manage-a-local-copy-of-your-repository]] - -위에 설명한 모든 작업은 HTTP 요청을 사용하여 작업할 수 있습니다. 그러나 경우에 따라 로컬 복사본을 가지고 익숙한 Git 명령어를 사용하여 상호 작용하는 것이 편리할 수 있습니다. - -[`Repository`] 클래스는 Git 명령어와 유사한 기능을 제공하는 함수를 사용하여 Hub의 파일 및 리포지토리와 상호 작용할 수 있습니다. 이는 이미 알고 있고 좋아하는 Git 및 Git-LFS 방법을 사용하는 래퍼(wrapper)입니다. 시작하기 전에 Git-LFS가 설치되어 있는지 확인하세요([여기서](https://git-lfs.github.com/) 설치 지침을 확인할 수 있습니다). - - - -[`Repository`]는 [`HfApi`]에 구현된 HTTP 기반 대안을 선호하여 중단되었습니다. 아직 많은 레거시 코드에서 사용되고 있기 때문에 [`Repository`]가 완전히 제거되는 건 `v1.0` 릴리스에서만 이루어집니다. 자세한 내용은 [해당 설명 페이지](./concepts/git_vs_http)를 참조하세요. - - - -### 로컬 리포지토리 사용[[use-a-local-repository]] - -로컬 리포지토리 경로를 사용하여 [`Repository`] 객체를 생성하세요: - -```py ->>> from huggingface_hub import Repository ->>> repo = Repository(local_dir="//") -``` - -### 복제[[clone]] - -`clone_from` 매개변수는 Hugging Face 리포지토리 ID에서 로컬 디렉터리로 리포지토리를 복제합니다. 이때 `local_dir` 매개변수를 사용하여 로컬 디렉터리에 저장합니다: - -```py ->>> from huggingface_hub import Repository ->>> repo = Repository(local_dir="w2v2", clone_from="facebook/wav2vec2-large-960h-lv60") -``` - -`clone_from`은 URL을 사용해 리포지토리를 복제할 수 있습니다. - -```py ->>> repo = Repository(local_dir="huggingface-hub", clone_from="https://huggingface.co/facebook/wav2vec2-large-960h-lv60") -``` - -`clone_from` 매개변수를 [`create_repo`]와 결합하여 리포지토리를 만들고 복제할 수 있습니다. - -```py ->>> repo_url = create_repo(repo_id="repo_name") ->>> repo = Repository(local_dir="repo_local_path", clone_from=repo_url) -``` - -리포지토리를 복제할 때 `git_user` 및 `git_email` 매개변수를 지정함으로써 복제한 리포지토리에 Git 사용자 이름과 이메일을 설정할 수 있습니다. 사용자가 해당 리포지토리에 커밋하면 Git은 커밋 작성자를 인식합니다. - -```py ->>> repo = Repository( -... "my-dataset", -... clone_from="/", -... token=True, -... repo_type="dataset", -... git_user="MyName", -... git_email="me@cool.mail" -... ) -``` - -### 브랜치[[branch]] - -브랜치는 현재 코드와 파일에 영향을 미치지 않으면서 협업과 실험에 중요합니다.[`~Repository.git_checkout`]을 사용하여 브랜치 간에 전환할 수 있습니다. 예를 들어, `branch1`에서 `branch2`로 전환하려면: - -```py ->>> from huggingface_hub import Repository ->>> repo = Repository(local_dir="huggingface-hub", clone_from="/", revision='branch1') ->>> repo.git_checkout("branch2") -``` - -### 끌어오기[[pull]] - -[`~Repository.git_pull`]은 원격 리포지토리로부터의 변경사항을 현재 로컬 브랜치에 업데이트하게 합니다. - -```py ->>> from huggingface_hub import Repository ->>> repo.git_pull() -``` - -브랜치가 원격에서의 새 커밋으로 업데이트 된 후에 로컬 커밋을 수행하고자 한다면 `rebase=True`를 설정하세요: - -```py ->>> repo.git_pull(rebase=True) -``` diff --git a/docs/source/ko/guides/upload.md b/docs/source/ko/guides/upload.md index 4f61f81faf..ced6093c5d 100644 --- a/docs/source/ko/guides/upload.md +++ b/docs/source/ko/guides/upload.md @@ -4,12 +4,7 @@ rendered properly in your Markdown viewer. # Hub에 파일 업로드하기[[upload-files-to-the-hub]] -파일과 작업물을 공유하는 것은 Hub의 주요 특성 중 하나입니다. `huggingface_hub`는 Hub에 파일을 업로드하기 위한 몇 가지 옵션을 제공합니다. 이러한 기능을 단독으로 사용하거나 라이브러리에 통합하여 해당 라이브러리의 사용자가 Hub와 더 편리하게 상호작용할 수 있도록 도울 수 있습니다. 이 가이드에서는 파일을 푸시하는 다양한 방법에 대해 설명합니다: - -- Git을 사용하지 않고 푸시하기. -- [Git LFS](https://git-lfs.github.com/)를 사용하여 매우 큰 파일을 푸시하기. -- `commit` 컨텍스트 매니저를 사용하여 푸시하기. -- [`~Repository.push_to_hub`] 함수를 사용하여 푸시하기. +파일과 작업물을 공유하는 것은 Hub의 주요 특성 중 하나입니다. `huggingface_hub`는 Hub에 파일을 업로드하기 위한 몇 가지 옵션을 제공합니다. 이러한 기능을 단독으로 사용하거나 라이브러리에 통합하여 해당 라이브러리의 사용자가 Hub와 더 편리하게 상호작용할 수 있도록 도울 수 있습니다. Hub에 파일을 업로드 하려면 허깅페이스 계정으로 로그인해야 합니다. 인증에 대한 자세한 내용은 [이 페이지](../quick-start#authentication)를 참조해 주세요. @@ -435,118 +430,3 @@ Hub에서 리포지토리를 구성하는 방법에 대한 모범 사례는 [리 자세한 내용은 [이 섹션](https://huggingface.co/docs/huggingface_hub/hf_transfer)을 참조하세요.
- -## (레거시) Git LFS로 파일 업로드하기[[legacy-upload-files-with-git-lfs]] - -위에서 설명한 모든 방법은 Hub의 API를 사용하여 파일을 업로드하며, 이는 Hub에 파일을 업로드하는 데 권장되는 방법입니다. -이뿐만 아니라 로컬 리포지토리를 관리하기 위하여 git 도구의 래퍼인 [`Repository`]또한 제공합니다. - - - -[`Repository`]는 공식적으로 지원 종료된 것은 아니지만, 가급적이면 위에서 설명한 HTTP 기반 방법들을 사용할 것을 권장합니다. -이 권장 사항에 대한 자세한 내용은 HTTP 기반 방식과 Git 기반 방식 간의 핵심적인 차이점을 설명하는 [이 가이드](../concepts/git_vs_http)를 참조하세요. - - - -Git LFS는 10MB보다 큰 파일을 자동으로 처리합니다. 하지만 매우 큰 파일(5GB 이상)의 경우, Git LFS용 사용자 지정 전송 에이전트를 설치해야 합니다: - -```bash -hf lfs-enable-largefiles -``` - -매우 큰 파일이 있는 각 리포지토리에 대해 이 옵션을 설치해야 합니다. -설치가 완료되면 5GB보다 큰 파일을 푸시할 수 있습니다. - -### 커밋 컨텍스트 관리자[[commit-context-manager]] - -`commit` 컨텍스트 관리자는 가장 일반적인 네 가지 Git 명령인 pull, add, commit, push를 처리합니다. -`git-lfs`는 10MB보다 큰 파일을 자동으로 추적합니다. -다음 예제에서는 `commit` 컨텍스트 관리자가 다음과 같은 작업을 수행합니다: - -1. `text-files` 리포지토리에서 pull. -2. `file.txt`에 변경 내용을 add. -3. 변경 내용을 commit. -4. 변경 내용을 `text-files` 리포지토리에 push. - -```python ->>> from huggingface_hub import Repository ->>> with Repository(local_dir="text-files", clone_from="/text-files").commit(commit_message="My first file :)"): -... with open("file.txt", "w+") as f: -... f.write(json.dumps({"hey": 8})) -``` - -다음은 `commit` 컨텍스트 관리자를 사용하여 파일을 저장하고 리포지토리에 업로드하는 방법의 또 다른 예입니다: - -```python ->>> import torch ->>> model = torch.nn.Transformer() ->>> with Repository("torch-model", clone_from="/torch-model", token=True).commit(commit_message="My cool model :)"): -... torch.save(model.state_dict(), "model.pt") -``` - -커밋을 비동기적으로 푸시하려면 `blocking=False`를 설정하세요. -커밋을 푸시하는 동안 스크립트를 계속 실행하고 싶을 때 논 블로킹 동작이 유용합니다. - -```python ->>> with repo.commit(commit_message="My cool model :)", blocking=False) -``` - -`command_queue` 메서드로 푸시 상태를 확인할 수 있습니다: - -```python ->>> last_command = repo.command_queue[-1] ->>> last_command.status -``` - -가능한 상태는 아래 표를 참조하세요: - -| 상태 | 설명 | -| -------- | ----------------------------- | -| -1 | 푸시가 진행 중입니다. | -| 0 | 푸시가 성공적으로 완료되었습니다.| -| Non-zero | 오류가 발생했습니다. | - -`blocking=False`인 경우, 명령이 추적되며 스크립트에서 다른 오류가 발생하더라도 모든 푸시가 완료된 경우에만 스크립트가 종료됩니다. -푸시 상태를 확인하는 데 유용한 몇 가지 추가 명령은 다음과 같습니다: - -```python -# 오류를 검사합니다. ->>> last_command.stderr - -# 푸시 진행여부를 확인합니다. ->>> last_command.is_done - -# 푸시 명령의 에러여부를 파악합니다. ->>> last_command.failed -``` - -### push_to_hub[[pushtohub]] - -[`Repository`] 클래스에는 파일을 추가하고 커밋한 후 리포지토리로 푸시하는 [`~Repository.push_to_hub`] 함수가 있습니다. [`~Repository.push_to_hub`]는 `commit` 컨텍스트 관리자와는 달리 호출하기 전에 먼저 리포지토리에서 업데이트(pull) 작업을 수행 해야 합니다. - -예를 들어 Hub에서 리포지토리를 이미 복제했다면 로컬 디렉터리에서 `repo`를 초기화할 수 있습니다: - -```python ->>> from huggingface_hub import Repository ->>> repo = Repository(local_dir="path/to/local/repo") -``` - -로컬 클론을 [`~Repository.git_pull`]로 업데이트한 다음 파일을 Hub로 푸시합니다: - -```py ->>> repo.git_pull() ->>> repo.push_to_hub(commit_message="Commit my-awesome-file to the Hub") -``` - -그러나 아직 파일을 푸시할 준비가 되지 않았다면 [`~Repository.git_add`] 와 [`~Repository.git_commit`]을 사용하여 파일만 추가하고 커밋할 수 있습니다: - -```py ->>> repo.git_add("path/to/file") ->>> repo.git_commit(commit_message="add my first model config file :)") -``` - -준비가 완료되면 [`~Repository.git_push`]를 사용하여 파일을 리포지토리에 푸시합니다: - -```py ->>> repo.git_push() -``` diff --git a/docs/source/ko/package_reference/repository.md b/docs/source/ko/package_reference/repository.md deleted file mode 100644 index fc70e3e203..0000000000 --- a/docs/source/ko/package_reference/repository.md +++ /dev/null @@ -1,49 +0,0 @@ - - -# 로컬 및 온라인 리포지토리 관리[[managing-local-and-online-repositories]] - -`Repository` 클래스는 `git` 및 `git-lfs` 명령을 감싸는 도우미 클래스로, 매우 큰 리포지토리를 관리하는 데 적합한 툴링을 제공합니다. - -`git` 작업이 포함되거나 리포지토리에서의 협업이 중점이 될 때 권장되는 도구입니다. - -## 리포지토리 클래스[[the-repository-class]] - -[[autodoc]] Repository - - __init__ - - current_branch - - all - -## 도우미 메소드[[helper-methods]] - -[[autodoc]] huggingface_hub.repository.is_git_repo - -[[autodoc]] huggingface_hub.repository.is_local_clone - -[[autodoc]] huggingface_hub.repository.is_tracked_with_lfs - -[[autodoc]] huggingface_hub.repository.is_git_ignored - -[[autodoc]] huggingface_hub.repository.files_to_be_staged - -[[autodoc]] huggingface_hub.repository.is_tracked_upstream - -[[autodoc]] huggingface_hub.repository.commits_to_push - -## 후속 비동기 명령[[following-asynchronous-commands]] - -`Repository` 유틸리티는 비동기적으로 시작할 수 있는 여러 메소드를 제공합니다. -- `git_push` -- `git_pull` -- `push_to_hub` -- `commit` 컨텍스트 관리자 - -이러한 비동기 메소드를 관리하는 유틸리티는 아래를 참조하세요. - -[[autodoc]] Repository - - commands_failed - - commands_in_progress - - wait_for_commands - -[[autodoc]] huggingface_hub.repository.CommandInProgress diff --git a/src/huggingface_hub/__init__.py b/src/huggingface_hub/__init__.py index b534281a5e..13ecb24cf2 100644 --- a/src/huggingface_hub/__init__.py +++ b/src/huggingface_hub/__init__.py @@ -492,9 +492,6 @@ "ModelCardData", "SpaceCardData", ], - "repository": [ - "Repository", - ], "serialization": [ "StateDictSplit", "get_tf_storage_size", @@ -715,7 +712,6 @@ "REPO_TYPE_SPACE", "RepoCard", "RepoUrl", - "Repository", "SentenceSimilarityInput", "SentenceSimilarityInputData", "SpaceCard", @@ -1512,7 +1508,6 @@ def __dir__(): ModelCardData, # noqa: F401 SpaceCardData, # noqa: F401 ) - from .repository import Repository # noqa: F401 from .serialization import ( StateDictSplit, # noqa: F401 get_tf_storage_size, # noqa: F401 diff --git a/src/huggingface_hub/repository.py b/src/huggingface_hub/repository.py deleted file mode 100644 index 387761cedc..0000000000 --- a/src/huggingface_hub/repository.py +++ /dev/null @@ -1,1477 +0,0 @@ -import atexit -import os -import re -import subprocess -import threading -import time -from contextlib import contextmanager -from pathlib import Path -from typing import Callable, Iterator, Optional, TypedDict, Union -from urllib.parse import urlparse - -from huggingface_hub import constants -from huggingface_hub.repocard import metadata_load, metadata_save - -from .hf_api import HfApi, repo_type_and_id_from_hf_id -from .lfs import LFS_MULTIPART_UPLOAD_COMMAND -from .utils import ( - SoftTemporaryDirectory, - get_token, - logging, - run_subprocess, - tqdm, - validate_hf_hub_args, -) -from .utils._deprecation import _deprecate_method - - -logger = logging.get_logger(__name__) - - -class CommandInProgress: - """ - Utility to follow commands launched asynchronously. - """ - - def __init__( - self, - title: str, - is_done_method: Callable, - status_method: Callable, - process: subprocess.Popen, - post_method: Optional[Callable] = None, - ): - self.title = title - self._is_done = is_done_method - self._status = status_method - self._process = process - self._stderr = "" - self._stdout = "" - self._post_method = post_method - - @property - def is_done(self) -> bool: - """ - Whether the process is done. - """ - result = self._is_done() - - if result and self._post_method is not None: - self._post_method() - self._post_method = None - - return result - - @property - def status(self) -> int: - """ - The exit code/status of the current action. Will return `0` if the - command has completed successfully, and a number between 1 and 255 if - the process errored-out. - - Will return -1 if the command is still ongoing. - """ - return self._status() - - @property - def failed(self) -> bool: - """ - Whether the process errored-out. - """ - return self.status > 0 - - @property - def stderr(self) -> str: - """ - The current output message on the standard error. - """ - if self._process.stderr is not None: - self._stderr += self._process.stderr.read() - return self._stderr - - @property - def stdout(self) -> str: - """ - The current output message on the standard output. - """ - if self._process.stdout is not None: - self._stdout += self._process.stdout.read() - return self._stdout - - def __repr__(self): - status = self.status - - if status == -1: - status = "running" - - return ( - f"[{self.title} command, status code: {status}," - f" {'in progress.' if not self.is_done else 'finished.'} PID:" - f" {self._process.pid}]" - ) - - -def is_git_repo(folder: Union[str, Path]) -> bool: - """ - Check if the folder is the root or part of a git repository - - Args: - folder (`str`): - The folder in which to run the command. - - Returns: - `bool`: `True` if the repository is part of a repository, `False` - otherwise. - """ - folder_exists = os.path.exists(os.path.join(folder, ".git")) - git_branch = subprocess.run("git branch".split(), cwd=folder, stdout=subprocess.PIPE, stderr=subprocess.PIPE) - return folder_exists and git_branch.returncode == 0 - - -def is_local_clone(folder: Union[str, Path], remote_url: str) -> bool: - """ - Check if the folder is a local clone of the remote_url - - Args: - folder (`str` or `Path`): - The folder in which to run the command. - remote_url (`str`): - The url of a git repository. - - Returns: - `bool`: `True` if the repository is a local clone of the remote - repository specified, `False` otherwise. - """ - if not is_git_repo(folder): - return False - - remotes = run_subprocess("git remote -v", folder).stdout - - # Remove token for the test with remotes. - remote_url = re.sub(r"https://.*@", "https://", remote_url) - remotes = [re.sub(r"https://.*@", "https://", remote) for remote in remotes.split()] - return remote_url in remotes - - -def is_tracked_with_lfs(filename: Union[str, Path]) -> bool: - """ - Check if the file passed is tracked with git-lfs. - - Args: - filename (`str` or `Path`): - The filename to check. - - Returns: - `bool`: `True` if the file passed is tracked with git-lfs, `False` - otherwise. - """ - folder = Path(filename).parent - filename = Path(filename).name - - try: - p = run_subprocess("git check-attr -a".split() + [filename], folder) - attributes = p.stdout.strip() - except subprocess.CalledProcessError as exc: - if not is_git_repo(folder): - return False - else: - raise OSError(exc.stderr) - - if len(attributes) == 0: - return False - - found_lfs_tag = {"diff": False, "merge": False, "filter": False} - - for attribute in attributes.split("\n"): - for tag in found_lfs_tag.keys(): - if tag in attribute and "lfs" in attribute: - found_lfs_tag[tag] = True - - return all(found_lfs_tag.values()) - - -def is_git_ignored(filename: Union[str, Path]) -> bool: - """ - Check if file is git-ignored. Supports nested .gitignore files. - - Args: - filename (`str` or `Path`): - The filename to check. - - Returns: - `bool`: `True` if the file passed is ignored by `git`, `False` - otherwise. - """ - folder = Path(filename).parent - filename = Path(filename).name - - try: - p = run_subprocess("git check-ignore".split() + [filename], folder, check=False) - # Will return exit code 1 if not gitignored - is_ignored = not bool(p.returncode) - except subprocess.CalledProcessError as exc: - raise OSError(exc.stderr) - - return is_ignored - - -def is_binary_file(filename: Union[str, Path]) -> bool: - """ - Check if file is a binary file. - - Args: - filename (`str` or `Path`): - The filename to check. - - Returns: - `bool`: `True` if the file passed is a binary file, `False` otherwise. - """ - try: - with open(filename, "rb") as f: - content = f.read(10 * (1024**2)) # Read a maximum of 10MB - - # Code sample taken from the following stack overflow thread - # https://stackoverflow.com/questions/898669/how-can-i-detect-if-a-file-is-binary-non-text-in-python/7392391#7392391 - text_chars = bytearray({7, 8, 9, 10, 12, 13, 27} | set(range(0x20, 0x100)) - {0x7F}) - return bool(content.translate(None, text_chars)) - except UnicodeDecodeError: - return True - - -def files_to_be_staged(pattern: str = ".", folder: Union[str, Path, None] = None) -> list[str]: - """ - Returns a list of filenames that are to be staged. - - Args: - pattern (`str` or `Path`): - The pattern of filenames to check. Put `.` to get all files. - folder (`str` or `Path`): - The folder in which to run the command. - - Returns: - `list[str]`: List of files that are to be staged. - """ - try: - p = run_subprocess("git ls-files --exclude-standard -mo".split() + [pattern], folder) - if len(p.stdout.strip()): - files = p.stdout.strip().split("\n") - else: - files = [] - except subprocess.CalledProcessError as exc: - raise EnvironmentError(exc.stderr) - - return files - - -def is_tracked_upstream(folder: Union[str, Path]) -> bool: - """ - Check if the current checked-out branch is tracked upstream. - - Args: - folder (`str` or `Path`): - The folder in which to run the command. - - Returns: - `bool`: `True` if the current checked-out branch is tracked upstream, - `False` otherwise. - """ - try: - run_subprocess("git rev-parse --symbolic-full-name --abbrev-ref @{u}", folder) - return True - except subprocess.CalledProcessError as exc: - if "HEAD" in exc.stderr: - raise OSError("No branch checked out") - - return False - - -def commits_to_push(folder: Union[str, Path], upstream: Optional[str] = None) -> int: - """ - Check the number of commits that would be pushed upstream - - Args: - folder (`str` or `Path`): - The folder in which to run the command. - upstream (`str`, *optional*): - The name of the upstream repository with which the comparison should be - made. - - Returns: - `int`: Number of commits that would be pushed upstream were a `git - push` to proceed. - """ - try: - result = run_subprocess(f"git cherry -v {upstream or ''}", folder) - return len(result.stdout.split("\n")) - 1 - except subprocess.CalledProcessError as exc: - raise EnvironmentError(exc.stderr) - - -class PbarT(TypedDict): - # Used to store an opened progress bar in `_lfs_log_progress` - bar: tqdm - past_bytes: int - - -@contextmanager -def _lfs_log_progress(): - """ - This is a context manager that will log the Git LFS progress of cleaning, - smudging, pulling and pushing. - """ - - if logger.getEffectiveLevel() >= logging.ERROR: - try: - yield - except Exception: - pass - return - - def output_progress(stopping_event: threading.Event): - """ - To be launched as a separate thread with an event meaning it should stop - the tail. - """ - # Key is tuple(state, filename), value is a dict(tqdm bar and a previous value) - pbars: dict[tuple[str, str], PbarT] = {} - - def close_pbars(): - for pbar in pbars.values(): - pbar["bar"].update(pbar["bar"].total - pbar["past_bytes"]) - pbar["bar"].refresh() - pbar["bar"].close() - - def tail_file(filename) -> Iterator[str]: - """ - Creates a generator to be iterated through, which will return each - line one by one. Will stop tailing the file if the stopping_event is - set. - """ - with open(filename, "r") as file: - current_line = "" - while True: - if stopping_event.is_set(): - close_pbars() - break - - line_bit = file.readline() - if line_bit is not None and not len(line_bit.strip()) == 0: - current_line += line_bit - if current_line.endswith("\n"): - yield current_line - current_line = "" - else: - time.sleep(1) - - # If the file isn't created yet, wait for a few seconds before trying again. - # Can be interrupted with the stopping_event. - while not os.path.exists(os.environ["GIT_LFS_PROGRESS"]): - if stopping_event.is_set(): - close_pbars() - return - - time.sleep(2) - - for line in tail_file(os.environ["GIT_LFS_PROGRESS"]): - try: - state, file_progress, byte_progress, filename = line.split() - except ValueError as error: - # Try/except to ease debugging. See https://github.com/huggingface/huggingface_hub/issues/1373. - raise ValueError(f"Cannot unpack LFS progress line:\n{line}") from error - description = f"{state.capitalize()} file {filename}" - - current_bytes, total_bytes = byte_progress.split("/") - current_bytes_int = int(current_bytes) - total_bytes_int = int(total_bytes) - - pbar = pbars.get((state, filename)) - if pbar is None: - # Initialize progress bar - pbars[(state, filename)] = { - "bar": tqdm( - desc=description, - initial=current_bytes_int, - total=total_bytes_int, - unit="B", - unit_scale=True, - unit_divisor=1024, - name="huggingface_hub.lfs_upload", - ), - "past_bytes": int(current_bytes), - } - else: - # Update progress bar - pbar["bar"].update(current_bytes_int - pbar["past_bytes"]) - pbar["past_bytes"] = current_bytes_int - - current_lfs_progress_value = os.environ.get("GIT_LFS_PROGRESS", "") - - with SoftTemporaryDirectory() as tmpdir: - os.environ["GIT_LFS_PROGRESS"] = os.path.join(tmpdir, "lfs_progress") - logger.debug(f"Following progress in {os.environ['GIT_LFS_PROGRESS']}") - - exit_event = threading.Event() - x = threading.Thread(target=output_progress, args=(exit_event,), daemon=True) - x.start() - - try: - yield - finally: - exit_event.set() - x.join() - - os.environ["GIT_LFS_PROGRESS"] = current_lfs_progress_value - - -class Repository: - """ - Helper class to wrap the git and git-lfs commands. - - The aim is to facilitate interacting with huggingface.co hosted model or - dataset repos, though not a lot here (if any) is actually specific to - huggingface.co. - - - - [`Repository`] is deprecated in favor of the http-based alternatives implemented in - [`HfApi`]. Given its large adoption in legacy code, the complete removal of - [`Repository`] will only happen in release `v1.0`. For more details, please read - https://huggingface.co/docs/huggingface_hub/concepts/git_vs_http. - - - """ - - command_queue: list[CommandInProgress] - - @validate_hf_hub_args - @_deprecate_method( - version="1.0", - message=( - "Please prefer the http-based alternatives instead. Given its large adoption in legacy code, the complete" - " removal is only planned on next major release.\nFor more details, please read" - " https://huggingface.co/docs/huggingface_hub/concepts/git_vs_http." - ), - ) - def __init__( - self, - local_dir: Union[str, Path], - clone_from: Optional[str] = None, - repo_type: Optional[str] = None, - token: Union[bool, str] = True, - git_user: Optional[str] = None, - git_email: Optional[str] = None, - revision: Optional[str] = None, - skip_lfs_files: bool = False, - client: Optional[HfApi] = None, - ): - """ - Instantiate a local clone of a git repo. - - If `clone_from` is set, the repo will be cloned from an existing remote repository. - If the remote repo does not exist, a `EnvironmentError` exception will be thrown. - Please create the remote repo first using [`create_repo`]. - - `Repository` uses the local git credentials by default. If explicitly set, the `token` - or the `git_user`/`git_email` pair will be used instead. - - Args: - local_dir (`str` or `Path`): - path (e.g. `'my_trained_model/'`) to the local directory, where - the `Repository` will be initialized. - clone_from (`str`, *optional*): - Either a repository url or `repo_id`. - Example: - - `"https://huggingface.co/philschmid/playground-tests"` - - `"philschmid/playground-tests"` - repo_type (`str`, *optional*): - To set when cloning a repo from a repo_id. Default is model. - token (`bool` or `str`, *optional*): - A valid authentication token (see https://huggingface.co/settings/token). - If `None` or `True` and machine is logged in (through `hf auth login` - or [`~huggingface_hub.login`]), token will be retrieved from the cache. - If `False`, token is not sent in the request header. - git_user (`str`, *optional*): - will override the `git config user.name` for committing and - pushing files to the hub. - git_email (`str`, *optional*): - will override the `git config user.email` for committing and - pushing files to the hub. - revision (`str`, *optional*): - Revision to checkout after initializing the repository. If the - revision doesn't exist, a branch will be created with that - revision name from the default branch's current HEAD. - skip_lfs_files (`bool`, *optional*, defaults to `False`): - whether to skip git-LFS files or not. - client (`HfApi`, *optional*): - Instance of [`HfApi`] to use when calling the HF Hub API. A new - instance will be created if this is left to `None`. - - Raises: - [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError) - If the remote repository set in `clone_from` does not exist. - """ - if isinstance(local_dir, Path): - local_dir = str(local_dir) - os.makedirs(local_dir, exist_ok=True) - self.local_dir = os.path.join(os.getcwd(), local_dir) - self._repo_type = repo_type - self.command_queue = [] - self.skip_lfs_files = skip_lfs_files - self.client = client if client is not None else HfApi() - - self.check_git_versions() - - if isinstance(token, str): - self.huggingface_token: Optional[str] = token - elif token is False: - self.huggingface_token = None - else: - # if `True` -> explicit use of the cached token - # if `None` -> implicit use of the cached token - self.huggingface_token = get_token() - - if clone_from is not None: - self.clone_from(repo_url=clone_from) - else: - if is_git_repo(self.local_dir): - logger.debug("[Repository] is a valid git repo") - else: - raise ValueError("If not specifying `clone_from`, you need to pass Repository a valid git clone.") - - if self.huggingface_token is not None and (git_email is None or git_user is None): - user = self.client.whoami(self.huggingface_token) - - if git_email is None: - git_email = user.get("email") - - if git_user is None: - git_user = user.get("fullname") - - if git_user is not None or git_email is not None: - self.git_config_username_and_email(git_user, git_email) - - self.lfs_enable_largefiles() - self.git_credential_helper_store() - - if revision is not None: - self.git_checkout(revision, create_branch_ok=True) - - # This ensures that all commands exit before exiting the Python runtime. - # This will ensure all pushes register on the hub, even if other errors happen in subsequent operations. - atexit.register(self.wait_for_commands) - - @property - def current_branch(self) -> str: - """ - Returns the current checked out branch. - - Returns: - `str`: Current checked out branch. - """ - try: - result = run_subprocess("git rev-parse --abbrev-ref HEAD", self.local_dir).stdout.strip() - except subprocess.CalledProcessError as exc: - raise EnvironmentError(exc.stderr) - - return result - - def check_git_versions(self): - """ - Checks that `git` and `git-lfs` can be run. - - Raises: - [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError) - If `git` or `git-lfs` are not installed. - """ - try: - git_version = run_subprocess("git --version", self.local_dir).stdout.strip() - except FileNotFoundError: - raise EnvironmentError("Looks like you do not have git installed, please install.") - - try: - lfs_version = run_subprocess("git-lfs --version", self.local_dir).stdout.strip() - except FileNotFoundError: - raise EnvironmentError( - "Looks like you do not have git-lfs installed, please install." - " You can install from https://git-lfs.github.com/." - " Then run `git lfs install` (you only have to do this once)." - ) - logger.info(git_version + "\n" + lfs_version) - - @validate_hf_hub_args - def clone_from(self, repo_url: str, token: Union[bool, str, None] = None): - """ - Clone from a remote. If the folder already exists, will try to clone the - repository within it. - - If this folder is a git repository with linked history, will try to - update the repository. - - Args: - repo_url (`str`): - The URL from which to clone the repository - token (`Union[str, bool]`, *optional*): - Whether to use the authentication token. It can be: - - a string which is the token itself - - `False`, which would not use the authentication token - - `True`, which would fetch the authentication token from the - local folder and use it (you should be logged in for this to - work). - - `None`, which would retrieve the value of - `self.huggingface_token`. - - - - Raises the following error: - - - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) - if an organization token (starts with "api_org") is passed. Use must use - your own personal access token (see https://hf.co/settings/tokens). - - - [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError) - if you are trying to clone the repository in a non-empty folder, or if the - `git` operations raise errors. - - - """ - token = ( - token # str -> use it - if isinstance(token, str) - else ( - None # `False` -> explicit no token - if token is False - else self.huggingface_token # `None` or `True` -> use default - ) - ) - if token is not None and token.startswith("api_org"): - raise ValueError( - "You must use your personal access token, not an Organization token" - " (see https://hf.co/settings/tokens)." - ) - - hub_url = self.client.endpoint - if hub_url in repo_url or ("http" not in repo_url and len(repo_url.split("/")) <= 2): - repo_type, namespace, repo_name = repo_type_and_id_from_hf_id(repo_url, hub_url=hub_url) - repo_id = f"{namespace}/{repo_name}" if namespace is not None else repo_name - - if repo_type is not None: - self._repo_type = repo_type - - repo_url = hub_url + "/" - - if self._repo_type in constants.REPO_TYPES_URL_PREFIXES: - repo_url += constants.REPO_TYPES_URL_PREFIXES[self._repo_type] - - if token is not None: - # Add token in git url when provided - scheme = urlparse(repo_url).scheme - repo_url = repo_url.replace(f"{scheme}://", f"{scheme}://user:{token}@") - - repo_url += repo_id - - # For error messages, it's cleaner to show the repo url without the token. - clean_repo_url = re.sub(r"(https?)://.*@", r"\1://", repo_url) - try: - run_subprocess("git lfs install", self.local_dir) - - # checks if repository is initialized in a empty repository or in one with files - if len(os.listdir(self.local_dir)) == 0: - logger.warning(f"Cloning {clean_repo_url} into local empty directory.") - - with _lfs_log_progress(): - env = os.environ.copy() - - if self.skip_lfs_files: - env.update({"GIT_LFS_SKIP_SMUDGE": "1"}) - - run_subprocess( - # 'git lfs clone' is deprecated (will display a warning in the terminal) - # but we still use it as it provides a nicer UX when downloading large - # files (shows progress). - f"{'git clone' if self.skip_lfs_files else 'git lfs clone'} {repo_url} .", - self.local_dir, - env=env, - ) - else: - # Check if the folder is the root of a git repository - if not is_git_repo(self.local_dir): - raise EnvironmentError( - "Tried to clone a repository in a non-empty folder that isn't" - f" a git repository ('{self.local_dir}'). If you really want to" - f" do this, do it manually:\n cd {self.local_dir} && git init" - " && git remote add origin && git pull origin main\n or clone" - " repo to a new folder and move your existing files there" - " afterwards." - ) - - if is_local_clone(self.local_dir, repo_url): - logger.warning( - f"{self.local_dir} is already a clone of {clean_repo_url}." - " Make sure you pull the latest changes with" - " `repo.git_pull()`." - ) - else: - output = run_subprocess("git remote get-url origin", self.local_dir, check=False) - - error_msg = ( - f"Tried to clone {clean_repo_url} in an unrelated git" - " repository.\nIf you believe this is an error, please add" - f" a remote with the following URL: {clean_repo_url}." - ) - if output.returncode == 0: - clean_local_remote_url = re.sub(r"https://.*@", "https://", output.stdout) - error_msg += f"\nLocal path has its origin defined as: {clean_local_remote_url}" - raise EnvironmentError(error_msg) - - except subprocess.CalledProcessError as exc: - raise EnvironmentError(exc.stderr) - - def git_config_username_and_email(self, git_user: Optional[str] = None, git_email: Optional[str] = None): - """ - Sets git username and email (only in the current repo). - - Args: - git_user (`str`, *optional*): - The username to register through `git`. - git_email (`str`, *optional*): - The email to register through `git`. - """ - try: - if git_user is not None: - run_subprocess("git config user.name".split() + [git_user], self.local_dir) - - if git_email is not None: - run_subprocess(f"git config user.email {git_email}".split(), self.local_dir) - except subprocess.CalledProcessError as exc: - raise EnvironmentError(exc.stderr) - - def git_credential_helper_store(self): - """ - Sets the git credential helper to `store` - """ - try: - run_subprocess("git config credential.helper store", self.local_dir) - except subprocess.CalledProcessError as exc: - raise EnvironmentError(exc.stderr) - - def git_head_hash(self) -> str: - """ - Get commit sha on top of HEAD. - - Returns: - `str`: The current checked out commit SHA. - """ - try: - p = run_subprocess("git rev-parse HEAD", self.local_dir) - return p.stdout.strip() - except subprocess.CalledProcessError as exc: - raise EnvironmentError(exc.stderr) - - def git_remote_url(self) -> str: - """ - Get URL to origin remote. - - Returns: - `str`: The URL of the `origin` remote. - """ - try: - p = run_subprocess("git config --get remote.origin.url", self.local_dir) - url = p.stdout.strip() - # Strip basic auth info. - return re.sub(r"https://.*@", "https://", url) - except subprocess.CalledProcessError as exc: - raise EnvironmentError(exc.stderr) - - def git_head_commit_url(self) -> str: - """ - Get URL to last commit on HEAD. We assume it's been pushed, and the url - scheme is the same one as for GitHub or HuggingFace. - - Returns: - `str`: The URL to the current checked-out commit. - """ - sha = self.git_head_hash() - url = self.git_remote_url() - if url.endswith("/"): - url = url[:-1] - return f"{url}/commit/{sha}" - - def list_deleted_files(self) -> list[str]: - """ - Returns a list of the files that are deleted in the working directory or - index. - - Returns: - `list[str]`: A list of files that have been deleted in the working - directory or index. - """ - try: - git_status = run_subprocess("git status -s", self.local_dir).stdout.strip() - except subprocess.CalledProcessError as exc: - raise EnvironmentError(exc.stderr) - - if len(git_status) == 0: - return [] - - # Receives a status like the following - # D .gitignore - # D new_file.json - # AD new_file1.json - # ?? new_file2.json - # ?? new_file4.json - - # Strip each line of whitespaces - modified_files_statuses = [status.strip() for status in git_status.split("\n")] - - # Only keep files that are deleted using the D prefix - deleted_files_statuses = [status for status in modified_files_statuses if "D" in status.split()[0]] - - # Remove the D prefix and strip to keep only the relevant filename - deleted_files = [status.split()[-1].strip() for status in deleted_files_statuses] - - return deleted_files - - def lfs_track(self, patterns: Union[str, list[str]], filename: bool = False): - """ - Tell git-lfs to track files according to a pattern. - - Setting the `filename` argument to `True` will treat the arguments as - literal filenames, not as patterns. Any special glob characters in the - filename will be escaped when writing to the `.gitattributes` file. - - Args: - patterns (`Union[str, list[str]]`): - The pattern, or list of patterns, to track with git-lfs. - filename (`bool`, *optional*, defaults to `False`): - Whether to use the patterns as literal filenames. - """ - if isinstance(patterns, str): - patterns = [patterns] - try: - for pattern in patterns: - run_subprocess( - f"git lfs track {'--filename' if filename else ''} {pattern}", - self.local_dir, - ) - except subprocess.CalledProcessError as exc: - raise EnvironmentError(exc.stderr) - - def lfs_untrack(self, patterns: Union[str, list[str]]): - """ - Tell git-lfs to untrack those files. - - Args: - patterns (`Union[str, list[str]]`): - The pattern, or list of patterns, to untrack with git-lfs. - """ - if isinstance(patterns, str): - patterns = [patterns] - try: - for pattern in patterns: - run_subprocess("git lfs untrack".split() + [pattern], self.local_dir) - except subprocess.CalledProcessError as exc: - raise EnvironmentError(exc.stderr) - - def lfs_enable_largefiles(self): - """ - HF-specific. This enables upload support of files >5GB. - """ - try: - lfs_config = "git config lfs.customtransfer.multipart" - run_subprocess(f"{lfs_config}.path hf", self.local_dir) - run_subprocess( - f"{lfs_config}.args {LFS_MULTIPART_UPLOAD_COMMAND}", - self.local_dir, - ) - except subprocess.CalledProcessError as exc: - raise EnvironmentError(exc.stderr) - - def auto_track_binary_files(self, pattern: str = ".") -> list[str]: - """ - Automatically track binary files with git-lfs. - - Args: - pattern (`str`, *optional*, defaults to "."): - The pattern with which to track files that are binary. - - Returns: - `list[str]`: List of filenames that are now tracked due to being - binary files - """ - files_to_be_tracked_with_lfs = [] - - deleted_files = self.list_deleted_files() - - for filename in files_to_be_staged(pattern, folder=self.local_dir): - if filename in deleted_files: - continue - - path_to_file = os.path.join(os.getcwd(), self.local_dir, filename) - - if not (is_tracked_with_lfs(path_to_file) or is_git_ignored(path_to_file)): - size_in_mb = os.path.getsize(path_to_file) / (1024 * 1024) - - if size_in_mb >= 10: - logger.warning( - "Parsing a large file to check if binary or not. Tracking large" - " files using `repository.auto_track_large_files` is" - " recommended so as to not load the full file in memory." - ) - - is_binary = is_binary_file(path_to_file) - - if is_binary: - self.lfs_track(filename) - files_to_be_tracked_with_lfs.append(filename) - - # Cleanup the .gitattributes if files were deleted - self.lfs_untrack(deleted_files) - - return files_to_be_tracked_with_lfs - - def auto_track_large_files(self, pattern: str = ".") -> list[str]: - """ - Automatically track large files (files that weigh more than 10MBs) with - git-lfs. - - Args: - pattern (`str`, *optional*, defaults to "."): - The pattern with which to track files that are above 10MBs. - - Returns: - `list[str]`: List of filenames that are now tracked due to their - size. - """ - files_to_be_tracked_with_lfs = [] - - deleted_files = self.list_deleted_files() - - for filename in files_to_be_staged(pattern, folder=self.local_dir): - if filename in deleted_files: - continue - - path_to_file = os.path.join(os.getcwd(), self.local_dir, filename) - size_in_mb = os.path.getsize(path_to_file) / (1024 * 1024) - - if size_in_mb >= 10 and not is_tracked_with_lfs(path_to_file) and not is_git_ignored(path_to_file): - self.lfs_track(filename) - files_to_be_tracked_with_lfs.append(filename) - - # Cleanup the .gitattributes if files were deleted - self.lfs_untrack(deleted_files) - - return files_to_be_tracked_with_lfs - - def lfs_prune(self, recent=False): - """ - git lfs prune - - Args: - recent (`bool`, *optional*, defaults to `False`): - Whether to prune files even if they were referenced by recent - commits. See the following - [link](https://github.com/git-lfs/git-lfs/blob/f3d43f0428a84fc4f1e5405b76b5a73ec2437e65/docs/man/git-lfs-prune.1.ronn#recent-files) - for more information. - """ - try: - with _lfs_log_progress(): - result = run_subprocess(f"git lfs prune {'--recent' if recent else ''}", self.local_dir) - logger.info(result.stdout) - except subprocess.CalledProcessError as exc: - raise EnvironmentError(exc.stderr) - - def git_pull(self, rebase: bool = False, lfs: bool = False): - """ - git pull - - Args: - rebase (`bool`, *optional*, defaults to `False`): - Whether to rebase the current branch on top of the upstream - branch after fetching. - lfs (`bool`, *optional*, defaults to `False`): - Whether to fetch the LFS files too. This option only changes the - behavior when a repository was cloned without fetching the LFS - files; calling `repo.git_pull(lfs=True)` will then fetch the LFS - file from the remote repository. - """ - command = "git pull" if not lfs else "git lfs pull" - if rebase: - command += " --rebase" - try: - with _lfs_log_progress(): - result = run_subprocess(command, self.local_dir) - logger.info(result.stdout) - except subprocess.CalledProcessError as exc: - raise EnvironmentError(exc.stderr) - - def git_add(self, pattern: str = ".", auto_lfs_track: bool = False): - """ - git add - - Setting the `auto_lfs_track` parameter to `True` will automatically - track files that are larger than 10MB with `git-lfs`. - - Args: - pattern (`str`, *optional*, defaults to "."): - The pattern with which to add files to staging. - auto_lfs_track (`bool`, *optional*, defaults to `False`): - Whether to automatically track large and binary files with - git-lfs. Any file over 10MB in size, or in binary format, will - be automatically tracked. - """ - if auto_lfs_track: - # Track files according to their size (>=10MB) - tracked_files = self.auto_track_large_files(pattern) - - # Read the remaining files and track them if they're binary - tracked_files.extend(self.auto_track_binary_files(pattern)) - - if tracked_files: - logger.warning( - f"Adding files tracked by Git LFS: {tracked_files}. This may take a" - " bit of time if the files are large." - ) - - try: - result = run_subprocess("git add -v".split() + [pattern], self.local_dir) - logger.info(f"Adding to index:\n{result.stdout}\n") - except subprocess.CalledProcessError as exc: - raise EnvironmentError(exc.stderr) - - def git_commit(self, commit_message: str = "commit files to HF hub"): - """ - git commit - - Args: - commit_message (`str`, *optional*, defaults to "commit files to HF hub"): - The message attributed to the commit. - """ - try: - result = run_subprocess("git commit -v -m".split() + [commit_message], self.local_dir) - logger.info(f"Committed:\n{result.stdout}\n") - except subprocess.CalledProcessError as exc: - if len(exc.stderr) > 0: - raise EnvironmentError(exc.stderr) - else: - raise EnvironmentError(exc.stdout) - - def git_push( - self, - upstream: Optional[str] = None, - blocking: bool = True, - auto_lfs_prune: bool = False, - ) -> Union[str, tuple[str, CommandInProgress]]: - """ - git push - - If used without setting `blocking`, will return url to commit on remote - repo. If used with `blocking=True`, will return a tuple containing the - url to commit and the command object to follow for information about the - process. - - Args: - upstream (`str`, *optional*): - Upstream to which this should push. If not specified, will push - to the lastly defined upstream or to the default one (`origin - main`). - blocking (`bool`, *optional*, defaults to `True`): - Whether the function should return only when the push has - finished. Setting this to `False` will return an - `CommandInProgress` object which has an `is_done` property. This - property will be set to `True` when the push is finished. - auto_lfs_prune (`bool`, *optional*, defaults to `False`): - Whether to automatically prune files once they have been pushed - to the remote. - """ - command = "git push" - - if upstream: - command += f" --set-upstream {upstream}" - - number_of_commits = commits_to_push(self.local_dir, upstream) - - if number_of_commits > 1: - logger.warning(f"Several commits ({number_of_commits}) will be pushed upstream.") - if blocking: - logger.warning("The progress bars may be unreliable.") - - try: - with _lfs_log_progress(): - process = subprocess.Popen( - command.split(), - stderr=subprocess.PIPE, - stdout=subprocess.PIPE, - encoding="utf-8", - cwd=self.local_dir, - ) - - if blocking: - stdout, stderr = process.communicate() - return_code = process.poll() - process.kill() - - if len(stderr): - logger.warning(stderr) - - if return_code: - raise subprocess.CalledProcessError(return_code, process.args, output=stdout, stderr=stderr) - - except subprocess.CalledProcessError as exc: - raise EnvironmentError(exc.stderr) - - if not blocking: - - def status_method(): - status = process.poll() - if status is None: - return -1 - else: - return status - - command_in_progress = CommandInProgress( - "push", - is_done_method=lambda: process.poll() is not None, - status_method=status_method, - process=process, - post_method=self.lfs_prune if auto_lfs_prune else None, - ) - - self.command_queue.append(command_in_progress) - - return self.git_head_commit_url(), command_in_progress - - if auto_lfs_prune: - self.lfs_prune() - - return self.git_head_commit_url() - - def git_checkout(self, revision: str, create_branch_ok: bool = False): - """ - git checkout a given revision - - Specifying `create_branch_ok` to `True` will create the branch to the - given revision if that revision doesn't exist. - - Args: - revision (`str`): - The revision to checkout. - create_branch_ok (`str`, *optional*, defaults to `False`): - Whether creating a branch named with the `revision` passed at - the current checked-out reference if `revision` isn't an - existing revision is allowed. - """ - try: - result = run_subprocess(f"git checkout {revision}", self.local_dir) - logger.warning(f"Checked out {revision} from {self.current_branch}.") - logger.warning(result.stdout) - except subprocess.CalledProcessError as exc: - if not create_branch_ok: - raise EnvironmentError(exc.stderr) - else: - try: - result = run_subprocess(f"git checkout -b {revision}", self.local_dir) - logger.warning( - f"Revision `{revision}` does not exist. Created and checked out branch `{revision}`." - ) - logger.warning(result.stdout) - except subprocess.CalledProcessError as exc: - raise EnvironmentError(exc.stderr) - - def tag_exists(self, tag_name: str, remote: Optional[str] = None) -> bool: - """ - Check if a tag exists or not. - - Args: - tag_name (`str`): - The name of the tag to check. - remote (`str`, *optional*): - Whether to check if the tag exists on a remote. This parameter - should be the identifier of the remote. - - Returns: - `bool`: Whether the tag exists. - """ - if remote: - try: - result = run_subprocess(f"git ls-remote origin refs/tags/{tag_name}", self.local_dir).stdout.strip() - except subprocess.CalledProcessError as exc: - raise EnvironmentError(exc.stderr) - - return len(result) != 0 - else: - try: - git_tags = run_subprocess("git tag", self.local_dir).stdout.strip() - except subprocess.CalledProcessError as exc: - raise EnvironmentError(exc.stderr) - - git_tags = git_tags.split("\n") - return tag_name in git_tags - - def delete_tag(self, tag_name: str, remote: Optional[str] = None) -> bool: - """ - Delete a tag, both local and remote, if it exists - - Args: - tag_name (`str`): - The tag name to delete. - remote (`str`, *optional*): - The remote on which to delete the tag. - - Returns: - `bool`: `True` if deleted, `False` if the tag didn't exist. - If remote is not passed, will just be updated locally - """ - delete_locally = True - delete_remotely = True - - if not self.tag_exists(tag_name): - delete_locally = False - - if not self.tag_exists(tag_name, remote=remote): - delete_remotely = False - - if delete_locally: - try: - run_subprocess(["git", "tag", "-d", tag_name], self.local_dir).stdout.strip() - except subprocess.CalledProcessError as exc: - raise EnvironmentError(exc.stderr) - - if remote and delete_remotely: - try: - run_subprocess(f"git push {remote} --delete {tag_name}", self.local_dir).stdout.strip() - except subprocess.CalledProcessError as exc: - raise EnvironmentError(exc.stderr) - - return True - - def add_tag(self, tag_name: str, message: Optional[str] = None, remote: Optional[str] = None): - """ - Add a tag at the current head and push it - - If remote is None, will just be updated locally - - If no message is provided, the tag will be lightweight. if a message is - provided, the tag will be annotated. - - Args: - tag_name (`str`): - The name of the tag to be added. - message (`str`, *optional*): - The message that accompanies the tag. The tag will turn into an - annotated tag if a message is passed. - remote (`str`, *optional*): - The remote on which to add the tag. - """ - if message: - tag_args = ["git", "tag", "-a", tag_name, "-m", message] - else: - tag_args = ["git", "tag", tag_name] - - try: - run_subprocess(tag_args, self.local_dir).stdout.strip() - except subprocess.CalledProcessError as exc: - raise EnvironmentError(exc.stderr) - - if remote: - try: - run_subprocess(f"git push {remote} {tag_name}", self.local_dir).stdout.strip() - except subprocess.CalledProcessError as exc: - raise EnvironmentError(exc.stderr) - - def is_repo_clean(self) -> bool: - """ - Return whether or not the git status is clean or not - - Returns: - `bool`: `True` if the git status is clean, `False` otherwise. - """ - try: - git_status = run_subprocess("git status --porcelain", self.local_dir).stdout.strip() - except subprocess.CalledProcessError as exc: - raise EnvironmentError(exc.stderr) - - return len(git_status) == 0 - - def push_to_hub( - self, - commit_message: str = "commit files to HF hub", - blocking: bool = True, - clean_ok: bool = True, - auto_lfs_prune: bool = False, - ) -> Union[None, str, tuple[str, CommandInProgress]]: - """ - Helper to add, commit, and push files to remote repository on the - HuggingFace Hub. Will automatically track large files (>10MB). - - Args: - commit_message (`str`): - Message to use for the commit. - blocking (`bool`, *optional*, defaults to `True`): - Whether the function should return only when the `git push` has - finished. - clean_ok (`bool`, *optional*, defaults to `True`): - If True, this function will return None if the repo is - untouched. Default behavior is to fail because the git command - fails. - auto_lfs_prune (`bool`, *optional*, defaults to `False`): - Whether to automatically prune files once they have been pushed - to the remote. - """ - if clean_ok and self.is_repo_clean(): - logger.info("Repo currently clean. Ignoring push_to_hub") - return None - self.git_add(auto_lfs_track=True) - self.git_commit(commit_message) - return self.git_push( - upstream=f"origin {self.current_branch}", - blocking=blocking, - auto_lfs_prune=auto_lfs_prune, - ) - - @contextmanager - def commit( - self, - commit_message: str, - branch: Optional[str] = None, - track_large_files: bool = True, - blocking: bool = True, - auto_lfs_prune: bool = False, - ): - """ - Context manager utility to handle committing to a repository. This - automatically tracks large files (>10Mb) with git-lfs. Set the - `track_large_files` argument to `False` if you wish to ignore that - behavior. - - Args: - commit_message (`str`): - Message to use for the commit. - branch (`str`, *optional*): - The branch on which the commit will appear. This branch will be - checked-out before any operation. - track_large_files (`bool`, *optional*, defaults to `True`): - Whether to automatically track large files or not. Will do so by - default. - blocking (`bool`, *optional*, defaults to `True`): - Whether the function should return only when the `git push` has - finished. - auto_lfs_prune (`bool`, defaults to `True`): - Whether to automatically prune files once they have been pushed - to the remote. - - Examples: - - ```python - >>> with Repository( - ... "text-files", - ... clone_from="/text-files", - ... token=True, - >>> ).commit("My first file :)"): - ... with open("file.txt", "w+") as f: - ... f.write(json.dumps({"hey": 8})) - - >>> import torch - - >>> model = torch.nn.Transformer() - >>> with Repository( - ... "torch-model", - ... clone_from="/torch-model", - ... token=True, - >>> ).commit("My cool model :)"): - ... torch.save(model.state_dict(), "model.pt") - ``` - - """ - - files_to_stage = files_to_be_staged(".", folder=self.local_dir) - - if len(files_to_stage): - files_in_msg = str(files_to_stage[:5])[:-1] + ", ...]" if len(files_to_stage) > 5 else str(files_to_stage) - logger.error( - "There exists some updated files in the local repository that are not" - f" committed: {files_in_msg}. This may lead to errors if checking out" - " a branch. These files and their modifications will be added to the" - " current commit." - ) - - if branch is not None: - self.git_checkout(branch, create_branch_ok=True) - - if is_tracked_upstream(self.local_dir): - logger.warning("Pulling changes ...") - self.git_pull(rebase=True) - else: - logger.warning(f"The current branch has no upstream branch. Will push to 'origin {self.current_branch}'") - - current_working_directory = os.getcwd() - os.chdir(os.path.join(current_working_directory, self.local_dir)) - - try: - yield self - finally: - self.git_add(auto_lfs_track=track_large_files) - - try: - self.git_commit(commit_message) - except OSError as e: - # If no changes are detected, there is nothing to commit. - if "nothing to commit" not in str(e): - raise e - - try: - self.git_push( - upstream=f"origin {self.current_branch}", - blocking=blocking, - auto_lfs_prune=auto_lfs_prune, - ) - except OSError as e: - # If no changes are detected, there is nothing to commit. - if "could not read Username" in str(e): - raise OSError("Couldn't authenticate user for push. Did you set `token` to `True`?") from e - else: - raise e - - os.chdir(current_working_directory) - - def repocard_metadata_load(self) -> Optional[dict]: - filepath = os.path.join(self.local_dir, constants.REPOCARD_NAME) - if os.path.isfile(filepath): - return metadata_load(filepath) - return None - - def repocard_metadata_save(self, data: dict) -> None: - return metadata_save(os.path.join(self.local_dir, constants.REPOCARD_NAME), data) - - @property - def commands_failed(self): - """ - Returns the asynchronous commands that failed. - """ - return [c for c in self.command_queue if c.status > 0] - - @property - def commands_in_progress(self): - """ - Returns the asynchronous commands that are currently in progress. - """ - return [c for c in self.command_queue if not c.is_done] - - def wait_for_commands(self): - """ - Blocking method: blocks all subsequent execution until all commands have - been processed. - """ - index = 0 - for command_failed in self.commands_failed: - logger.error(f"The {command_failed.title} command with PID {command_failed._process.pid} failed.") - logger.error(command_failed.stderr) - - while self.commands_in_progress: - if index % 10 == 0: - logger.warning( - f"Waiting for the following commands to finish before shutting down: {self.commands_in_progress}." - ) - - index += 1 - - time.sleep(1) diff --git a/tests/test_repository.py b/tests/test_repository.py deleted file mode 100644 index 772dc9850f..0000000000 --- a/tests/test_repository.py +++ /dev/null @@ -1,895 +0,0 @@ -# Copyright 2020 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import json -import os -import time -import unittest -from pathlib import Path - -import httpx -import pytest - -from huggingface_hub import RepoUrl -from huggingface_hub.hf_api import HfApi -from huggingface_hub.repository import ( - Repository, - is_tracked_upstream, - is_tracked_with_lfs, -) -from huggingface_hub.utils import SoftTemporaryDirectory, logging, run_subprocess - -from .testing_constants import ENDPOINT_STAGING, TOKEN -from .testing_utils import ( - expect_deprecation, - repo_name, - use_tmp_repo, - with_production_testing, -) - - -logger = logging.get_logger(__name__) - - -@pytest.mark.usefixtures("fx_cache_dir") -class RepositoryTestAbstract(unittest.TestCase): - cache_dir: Path - repo_path: Path - - # This content is 5MB (under 10MB) - small_content = json.dumps([100] * int(1e6)) - - # This content is 20MB (over 10MB) - large_content = json.dumps([100] * int(4e6)) - - # This content is binary (contains the null character) - binary_content = "\x00\x00\x00\x00" - - _api = HfApi(endpoint=ENDPOINT_STAGING, token=TOKEN) - - @classmethod - def setUp(self) -> None: - self.repo_path = self.cache_dir / "working_dir" - self.repo_path.mkdir() - - def _create_dummy_files(self): - # Create dummy files - # one is lfs-tracked, the other is not. - small_file = self.repo_path / "dummy.txt" - small_file.write_text(self.small_content) - - binary_file = self.repo_path / "model.bin" - binary_file.write_text(self.binary_content) - - -class TestRepositoryShared(RepositoryTestAbstract): - """Tests in this class shares a single repo on the Hub (common to all tests). - - These tests must not push data to it. - """ - - @classmethod - def setUpClass(cls): - """ - Share this valid token in all tests below. - """ - super().setUpClass() - cls.repo_url = cls._api.create_repo(repo_id=repo_name()) - cls.repo_id = cls.repo_url.repo_id - cls._api.upload_file( - path_or_fileobj=cls.binary_content.encode(), - path_in_repo="random_file.txt", - repo_id=cls.repo_id, - ) - - @classmethod - def tearDownClass(cls): - cls._api.delete_repo(repo_id=cls.repo_id) - - @expect_deprecation("Repository") - def test_clone_from_repo_url(self): - Repository(self.repo_path, clone_from=self.repo_url) - - @expect_deprecation("Repository") - def test_clone_from_repo_id(self): - Repository(self.repo_path, clone_from=self.repo_id) - - @expect_deprecation("Repository") - def test_clone_from_repo_name_no_namespace_fails(self): - with self.assertRaises(EnvironmentError): - Repository(self.repo_path, clone_from=self.repo_id.split("/")[1], token=TOKEN) - - @expect_deprecation("Repository") - def test_clone_from_not_hf_url(self): - # Should not error out - Repository(self.repo_path, clone_from="https://hf.co/hf-internal-testing/huggingface-hub-dummy-repository") - - @expect_deprecation("Repository") - def test_clone_from_missing_repo(self): - """If the repo does not exist an EnvironmentError is raised.""" - with self.assertRaises(EnvironmentError): - Repository(self.repo_path, clone_from="missing_repo") - - @expect_deprecation("Repository") - @with_production_testing - def test_clone_from_prod_canonical_repo_id(self): - Repository(self.repo_path, clone_from="bert-base-cased", skip_lfs_files=True) - - @expect_deprecation("Repository") - @with_production_testing - def test_clone_from_prod_canonical_repo_url(self): - Repository(self.repo_path, clone_from="https://huggingface.co/bert-base-cased", skip_lfs_files=True) - - @expect_deprecation("Repository") - def test_init_from_existing_local_clone(self): - run_subprocess(["git", "clone", self.repo_url, str(self.repo_path)]) - - repo = Repository(self.repo_path) - repo.lfs_track(["*.pdf"]) - repo.lfs_enable_largefiles() - repo.git_pull() - - @expect_deprecation("Repository") - def test_init_failure(self): - with self.assertRaises(ValueError): - Repository(self.repo_path) - - @expect_deprecation("Repository") - def test_init_clone_in_empty_folder(self): - repo = Repository(self.repo_path, clone_from=self.repo_url) - repo.lfs_track(["*.pdf"]) - repo.lfs_enable_largefiles() - repo.git_pull() - self.assertIn("random_file.txt", os.listdir(self.repo_path)) - - @expect_deprecation("Repository") - def test_git_lfs_filename(self): - run_subprocess("git init", folder=self.repo_path) - - repo = Repository(self.repo_path) - large_file = self.repo_path / "large_file[].txt" - large_file.write_text(self.large_content) - - repo.git_add() - - repo.lfs_track([large_file.name]) - self.assertFalse(is_tracked_with_lfs(large_file)) - - repo.lfs_track([large_file.name], filename=True) - self.assertTrue(is_tracked_with_lfs(large_file)) - - @expect_deprecation("Repository") - def test_init_clone_in_nonempty_folder(self): - self._create_dummy_files() - with self.assertRaises(EnvironmentError): - Repository(self.repo_path, clone_from=self.repo_url) - - @expect_deprecation("Repository") - def test_init_clone_in_nonempty_linked_git_repo_with_token(self): - Repository(self.repo_path, clone_from=self.repo_url, token=TOKEN) - Repository(self.repo_path, clone_from=self.repo_url, token=TOKEN) - - @expect_deprecation("Repository") - def test_is_tracked_upstream(self): - Repository(self.repo_path, clone_from=self.repo_id) - self.assertTrue(is_tracked_upstream(self.repo_path)) - - @expect_deprecation("Repository") - def test_push_errors_on_wrong_checkout(self): - repo = Repository(self.repo_path, clone_from=self.repo_id) - - head_commit_ref = run_subprocess("git show --oneline -s", folder=self.repo_path).stdout.split()[0] - - repo.git_checkout(head_commit_ref) - - with self.assertRaises(OSError): - with repo.commit("New commit"): - with open("new_file", "w+") as f: - f.write("Ok") - - -class TestRepositoryUniqueRepos(RepositoryTestAbstract): - """Tests in this class use separated repos on the Hub (i.e. 1 test = 1 repo). - - These tests can push data to it. - """ - - def setUp(self): - super().setUp() - self.repo_url = self._api.create_repo(repo_id=repo_name()) - self.repo_id = self.repo_url.repo_id - self._api.upload_file( - path_or_fileobj=self.binary_content.encode(), path_in_repo="random_file.txt", repo_id=self.repo_id - ) - - def tearDown(self): - self._api.delete_repo(repo_id=self.repo_id) - - @expect_deprecation("Repository") - def clone_repo(self, **kwargs) -> Repository: - if "local_dir" not in kwargs: - kwargs["local_dir"] = self.repo_path - if "clone_from" not in kwargs: - kwargs["clone_from"] = self.repo_url - if "token" not in kwargs: - kwargs["token"] = TOKEN - if "git_user" not in kwargs: - kwargs["git_user"] = "ci" - if "git_email" not in kwargs: - kwargs["git_email"] = "ci@dummy.com" - return Repository(**kwargs) - - @use_tmp_repo() - @expect_deprecation("Repository") - def test_init_clone_in_nonempty_non_linked_git_repo(self, repo_url: RepoUrl): - self.clone_repo() - - # Try and clone another repository within the same directory. - # Should error out due to mismatched remotes. - with self.assertRaises(EnvironmentError): - Repository(self.repo_path, clone_from=repo_url) - - def test_init_clone_in_nonempty_linked_git_repo(self): - # Clone the repository to disk - self.clone_repo() - - # Add to the remote repository without doing anything to the local repository. - self._api.upload_file( - path_or_fileobj=self.binary_content.encode(), path_in_repo="random_file_3.txt", repo_id=self.repo_id - ) - - # Cloning the repository in the same directory should not result in a git pull. - self.clone_repo(clone_from=self.repo_url) - self.assertNotIn("random_file_3.txt", os.listdir(self.repo_path)) - - def test_init_clone_in_nonempty_linked_git_repo_unrelated_histories(self): - # Clone the repository to disk - repo = self.clone_repo() - - # Create and commit file locally - (self.repo_path / "random_file_3.txt").write_text("hello world") - repo.git_add() - repo.git_commit("Unrelated commit") - - # Add to the remote repository without doing anything to the local repository. - self._api.upload_file( - path_or_fileobj=self.binary_content.encode(), - path_in_repo="random_file_3.txt", - repo_id=self.repo_url.repo_id, - ) - - # The repo should initialize correctly as the remote is the same, even with unrelated historied - self.clone_repo() - - def test_add_commit_push(self): - repo = self.clone_repo() - self._create_dummy_files() - repo.git_add() - repo.git_commit() - url = repo.git_push() - - # Check that the returned commit url - # actually exists. - r = httpx.head(url) - r.raise_for_status() - - def test_add_commit_push_non_blocking(self): - repo = self.clone_repo() - self._create_dummy_files() - repo.git_add() - repo.git_commit() - url, result = repo.git_push(blocking=False) - - # Check background process - if result._process.poll() is None: - self.assertEqual(result.status, -1) - - while not result.is_done: - time.sleep(0.5) - - self.assertTrue(result.is_done) - self.assertEqual(result.status, 0) - - # Check that the returned commit url - # actually exists. - r = httpx.head(url) - r.raise_for_status() - - def test_context_manager_non_blocking(self): - repo = self.clone_repo() - - with repo.commit("New commit", blocking=False): - (self.repo_path / "dummy.txt").write_text("hello world") - - while repo.commands_in_progress: - time.sleep(1) - - self.assertEqual(len(repo.commands_in_progress), 0) - self.assertEqual(len(repo.command_queue), 1) - self.assertEqual(repo.command_queue[-1].status, 0) - self.assertEqual(repo.command_queue[-1].is_done, True) - self.assertEqual(repo.command_queue[-1].title, "push") - - @unittest.skip("This is a flaky and legacy test") - def test_add_commit_push_non_blocking_process_killed(self): - repo = self.clone_repo() - - # Far too big file: will take forever - (self.repo_path / "dummy.txt").write_text(str([[[1] * 10000] * 1000] * 10)) - repo.git_add(auto_lfs_track=True) - repo.git_commit() - _, result = repo.git_push(blocking=False) - - result._process.kill() - - while result._process.poll() is None: - time.sleep(0.5) - - self.assertTrue(result.is_done) - self.assertEqual(result.status, -9) - - def test_commit_context_manager(self): - # Clone and commit from a first folder - folder_1 = self.repo_path / "folder_1" - clone = self.clone_repo(local_dir=folder_1) - with clone.commit("Commit"): - with open("dummy.txt", "w") as f: - f.write("hello") - with open("model.bin", "w") as f: - f.write("hello") - - # Clone in second folder. Check existence of committed files - folder_2 = self.repo_path / "folder_2" - self.clone_repo(local_dir=folder_2) - files = os.listdir(folder_2) - self.assertTrue("dummy.txt" in files) - self.assertTrue("model.bin" in files) - - def test_clone_skip_lfs_files(self): - # Upload LFS file - self._api.upload_file(path_or_fileobj=b"Bin file", path_in_repo="file.bin", repo_id=self.repo_id) - - repo = self.clone_repo(skip_lfs_files=True) - file_bin = self.repo_path / "file.bin" - - self.assertTrue(file_bin.read_text().startswith("version")) - - repo.git_pull(lfs=True) - - self.assertEqual(file_bin.read_text(), "Bin file") - - def test_commits_on_correct_branch(self): - repo = self.clone_repo() - branch = repo.current_branch - repo.git_checkout("new-branch", create_branch_ok=True) - repo.git_checkout(branch) - - with repo.commit("New commit"): - with open("file.txt", "w+") as f: - f.write("Ok") - - repo.git_checkout("new-branch") - - with repo.commit("New commit"): - with open("new_file.txt", "w+") as f: - f.write("Ok") - - with SoftTemporaryDirectory() as tmp: - clone = self.clone_repo(local_dir=tmp) - files = os.listdir(clone.local_dir) - self.assertTrue("file.txt" in files) - self.assertFalse("new_file.txt" in files) - - clone.git_checkout("new-branch") - files = os.listdir(clone.local_dir) - self.assertFalse("file.txt" in files) - self.assertTrue("new_file.txt" in files) - - def test_repo_checkout_push(self): - repo = self.clone_repo() - - repo.git_checkout("new-branch", create_branch_ok=True) - repo.git_checkout("main") - - (self.repo_path / "file.txt").write_text("OK") - - repo.push_to_hub("Commit #1") - repo.git_checkout("new-branch", create_branch_ok=True) - - (self.repo_path / "new_file.txt").write_text("OK") - - repo.push_to_hub("Commit #2") - - with SoftTemporaryDirectory() as tmp: - clone = self.clone_repo(local_dir=tmp) - files = os.listdir(clone.local_dir) - self.assertTrue("file.txt" in files) - self.assertFalse("new_file.txt" in files) - - clone.git_checkout("new-branch") - files = os.listdir(clone.local_dir) - self.assertFalse("file.txt" in files) - self.assertTrue("new_file.txt" in files) - - def test_repo_checkout_commit_context_manager(self): - repo = self.clone_repo() - - with repo.commit("Commit #1", branch="new-branch"): - with open(os.path.join(repo.local_dir, "file.txt"), "w+") as f: - f.write("Ok") - - with repo.commit("Commit #2", branch="main"): - with open(os.path.join(repo.local_dir, "new_file.txt"), "w+") as f: - f.write("Ok") - - # Maintains lastly used branch - with repo.commit("Commit #3"): - with open(os.path.join(repo.local_dir, "new_file-2.txt"), "w+") as f: - f.write("Ok") - - with SoftTemporaryDirectory() as tmp: - clone = self.clone_repo(local_dir=tmp) - files = os.listdir(clone.local_dir) - self.assertFalse("file.txt" in files) - self.assertTrue("new_file-2.txt" in files) - self.assertTrue("new_file.txt" in files) - - clone.git_checkout("new-branch") - files = os.listdir(clone.local_dir) - self.assertTrue("file.txt" in files) - self.assertFalse("new_file.txt" in files) - self.assertFalse("new_file-2.txt" in files) - - def test_add_tag(self): - repo = self.clone_repo() - repo.add_tag("v4.6.0", remote="origin") - self.assertTrue(repo.tag_exists("v4.6.0", remote="origin")) - - def test_add_annotated_tag(self): - repo = self.clone_repo() - repo.add_tag("v4.5.0", message="This is an annotated tag", remote="origin") - - # Unfortunately git offers no built-in way to check the annotated - # message of a remote tag. - # In order to check that the remote tag was correctly annotated, - # we delete the local tag before pulling the remote tag (which - # should be the same). We then check that this tag is correctly - # annotated. - repo.delete_tag("v4.5.0") - - self.assertTrue(repo.tag_exists("v4.5.0", remote="origin")) - self.assertFalse(repo.tag_exists("v4.5.0")) - - # Tag still exists on remote - run_subprocess("git pull --tags", folder=self.repo_path) - self.assertTrue(repo.tag_exists("v4.5.0")) - - # Tag is annotated - result = run_subprocess("git tag -n9", folder=self.repo_path).stdout.strip() - self.assertIn("This is an annotated tag", result) - - def test_delete_tag(self): - repo = self.clone_repo() - - repo.add_tag("v4.6.0", message="This is an annotated tag", remote="origin") - self.assertTrue(repo.tag_exists("v4.6.0", remote="origin")) - - repo.delete_tag("v4.6.0") - self.assertFalse(repo.tag_exists("v4.6.0")) - self.assertTrue(repo.tag_exists("v4.6.0", remote="origin")) - - repo.delete_tag("v4.6.0", remote="origin") - self.assertFalse(repo.tag_exists("v4.6.0", remote="origin")) - - def test_lfs_prune(self): - repo = self.clone_repo() - - with repo.commit("Committing LFS file"): - with open("file.bin", "w+") as f: - f.write("Random string 1") - - with repo.commit("Committing LFS file"): - with open("file.bin", "w+") as f: - f.write("Random string 2") - - root_directory = self.repo_path / ".git" / "lfs" - git_lfs_files_size = sum(f.stat().st_size for f in root_directory.glob("**/*") if f.is_file()) - repo.lfs_prune() - post_prune_git_lfs_files_size = sum(f.stat().st_size for f in root_directory.glob("**/*") if f.is_file()) - - # Size of the directory holding LFS files was reduced - self.assertLess(post_prune_git_lfs_files_size, git_lfs_files_size) - - def test_lfs_prune_git_push(self): - repo = self.clone_repo() - with repo.commit("Committing LFS file"): - with open("file.bin", "w+") as f: - f.write("Random string 1") - - root_directory = self.repo_path / ".git" / "lfs" - git_lfs_files_size = sum(f.stat().st_size for f in root_directory.glob("**/*") if f.is_file()) - - with open(os.path.join(repo.local_dir, "file.bin"), "w+") as f: - f.write("Random string 2") - - repo.git_add() - repo.git_commit("New commit") - repo.git_push(auto_lfs_prune=True) - - post_prune_git_lfs_files_size = sum(f.stat().st_size for f in root_directory.glob("**/*") if f.is_file()) - - # Size of the directory holding LFS files is the exact same - self.assertEqual(post_prune_git_lfs_files_size, git_lfs_files_size) - - -class TestRepositoryOffline(RepositoryTestAbstract): - """Class to test `Repository` object on local folders only (no cloning from Hub).""" - - repo: Repository - - @classmethod - @expect_deprecation("Repository") - def setUp(self) -> None: - super().setUp() - - run_subprocess("git init", folder=self.repo_path) - - self.repo = Repository(self.repo_path, git_user="ci", git_email="ci@dummy.ci") - - git_attributes_path = self.repo_path / ".gitattributes" - git_attributes_path.write_text("*.pt filter=lfs diff=lfs merge=lfs -text") - - self.repo.git_add(".gitattributes") - self.repo.git_commit("Add .gitattributes") - - def test_is_tracked_with_lfs(self): - txt_1 = self.repo_path / "small_file_1.txt" - txt_2 = self.repo_path / "small_file_2.txt" - pt_1 = self.repo_path / "model.pt" - - txt_1.write_text(self.small_content) - txt_2.write_text(self.small_content) - pt_1.write_text(self.small_content) - - self.repo.lfs_track("small_file_1.txt") - - self.assertTrue(is_tracked_with_lfs(txt_1)) - self.assertFalse(is_tracked_with_lfs(txt_2)) - self.assertTrue(pt_1) - - def test_is_tracked_with_lfs_with_pattern(self): - txt_small_file = self.repo_path / "small_file.txt" - txt_small_file.write_text(self.small_content) - - txt_large_file = self.repo_path / "large_file.txt" - txt_large_file.write_text(self.large_content) - - (self.repo_path / "dir").mkdir() - txt_small_file_in_dir = self.repo_path / "dir" / "small_file.txt" - txt_small_file_in_dir.write_text(self.small_content) - - txt_large_file_in_dir = self.repo_path / "dir" / "large_file.txt" - txt_large_file_in_dir.write_text(self.large_content) - - self.repo.auto_track_large_files("dir") - - self.assertFalse(is_tracked_with_lfs(txt_large_file)) - self.assertFalse(is_tracked_with_lfs(txt_small_file)) - self.assertTrue(is_tracked_with_lfs(txt_large_file_in_dir)) - self.assertFalse(is_tracked_with_lfs(txt_small_file_in_dir)) - - def test_auto_track_large_files(self): - txt_small_file = self.repo_path / "small_file.txt" - txt_small_file.write_text(self.small_content) - - txt_large_file = self.repo_path / "large_file.txt" - txt_large_file.write_text(self.large_content) - - self.repo.auto_track_large_files() - - self.assertTrue(is_tracked_with_lfs(txt_large_file)) - self.assertFalse(is_tracked_with_lfs(txt_small_file)) - - def test_auto_track_binary_files(self): - non_binary_file = self.repo_path / "non_binary_file.txt" - non_binary_file.write_text(self.small_content) - - binary_file = self.repo_path / "binary_file.txt" - binary_file.write_text(self.binary_content) - - self.repo.auto_track_binary_files() - - self.assertFalse(is_tracked_with_lfs(non_binary_file)) - self.assertTrue(is_tracked_with_lfs(binary_file)) - - def test_auto_track_large_files_ignored_with_gitignore(self): - (self.repo_path / "dir").mkdir() - - # Test nested gitignores - gitignore_file = self.repo_path / ".gitignore" - gitignore_file.write_text("large_file.txt") - - gitignore_file_in_dir = self.repo_path / "dir" / ".gitignore" - gitignore_file_in_dir.write_text("large_file_3.txt") - - large_file = self.repo_path / "large_file.txt" - large_file.write_text(self.large_content) - - large_file_2 = self.repo_path / "large_file_2.txt" - large_file_2.write_text(self.large_content) - - large_file_3 = self.repo_path / "dir" / "large_file_3.txt" - large_file_3.write_text(self.large_content) - - large_file_4 = self.repo_path / "dir" / "large_file_4.txt" - large_file_4.write_text(self.large_content) - - self.repo.auto_track_large_files() - - # Large files - self.assertFalse(is_tracked_with_lfs(large_file)) - self.assertTrue(is_tracked_with_lfs(large_file_2)) - - self.assertFalse(is_tracked_with_lfs(large_file_3)) - self.assertTrue(is_tracked_with_lfs(large_file_4)) - - def test_auto_track_binary_files_ignored_with_gitignore(self): - (self.repo_path / "dir").mkdir() - - # Test nested gitignores - gitignore_file = self.repo_path / ".gitignore" - gitignore_file.write_text("binary_file.txt") - - gitignore_file_in_dir = self.repo_path / "dir" / ".gitignore" - gitignore_file_in_dir.write_text("binary_file_3.txt") - - binary_file = self.repo_path / "binary_file.txt" - binary_file.write_text(self.binary_content) - - binary_file_2 = self.repo_path / "binary_file_2.txt" - binary_file_2.write_text(self.binary_content) - - binary_file_3 = self.repo_path / "dir" / "binary_file_3.txt" - binary_file_3.write_text(self.binary_content) - - binary_file_4 = self.repo_path / "dir" / "binary_file_4.txt" - binary_file_4.write_text(self.binary_content) - - self.repo.auto_track_binary_files() - - # Binary files - self.assertFalse(is_tracked_with_lfs(binary_file)) - self.assertTrue(is_tracked_with_lfs(binary_file_2)) - self.assertFalse(is_tracked_with_lfs(binary_file_3)) - self.assertTrue(is_tracked_with_lfs(binary_file_4)) - - def test_auto_track_large_files_through_git_add(self): - txt_small_file = self.repo_path / "small_file.txt" - txt_small_file.write_text(self.small_content) - - txt_large_file = self.repo_path / "large_file.txt" - txt_large_file.write_text(self.large_content) - - self.repo.git_add(auto_lfs_track=True) - - self.assertTrue(is_tracked_with_lfs(txt_large_file)) - self.assertFalse(is_tracked_with_lfs(txt_small_file)) - - def test_auto_track_binary_files_through_git_add(self): - non_binary_file = self.repo_path / "small_file.txt" - non_binary_file.write_text(self.small_content) - - binary_file = self.repo_path / "binary.txt" - binary_file.write_text(self.binary_content) - - self.repo.git_add(auto_lfs_track=True) - - self.assertTrue(is_tracked_with_lfs(binary_file)) - self.assertFalse(is_tracked_with_lfs(non_binary_file)) - - def test_auto_no_track_large_files_through_git_add(self): - txt_small_file = self.repo_path / "small_file.txt" - txt_small_file.write_text(self.small_content) - - txt_large_file = self.repo_path / "large_file.txt" - txt_large_file.write_text(self.large_content) - - self.repo.git_add(auto_lfs_track=False) - - self.assertFalse(is_tracked_with_lfs(txt_large_file)) - self.assertFalse(is_tracked_with_lfs(txt_small_file)) - - def test_auto_no_track_binary_files_through_git_add(self): - non_binary_file = self.repo_path / "small_file.txt" - non_binary_file.write_text(self.small_content) - - binary_file = self.repo_path / "binary.txt" - binary_file.write_text(self.binary_content) - - self.repo.git_add(auto_lfs_track=False) - - self.assertFalse(is_tracked_with_lfs(binary_file)) - self.assertFalse(is_tracked_with_lfs(non_binary_file)) - - def test_auto_track_updates_removed_gitattributes(self): - txt_small_file = self.repo_path / "small_file.txt" - txt_small_file.write_text(self.small_content) - - txt_large_file = self.repo_path / "large_file.txt" - txt_large_file.write_text(self.large_content) - - self.repo.git_add(auto_lfs_track=True) - - self.assertTrue(is_tracked_with_lfs(txt_large_file)) - self.assertFalse(is_tracked_with_lfs(txt_small_file)) - - # Remove large file - txt_large_file.unlink() - - # Auto track should remove the entry from .gitattributes - self.repo.auto_track_large_files() - - # Recreate the large file with smaller contents - txt_large_file.write_text(self.small_content) - - # Ensure the file is not LFS tracked anymore - self.repo.auto_track_large_files() - self.assertFalse(is_tracked_with_lfs(txt_large_file)) - - def test_checkout_non_existing_branch(self): - self.assertRaises(EnvironmentError, self.repo.git_checkout, "brand-new-branch") - - def test_checkout_new_branch(self): - self.repo.git_checkout("new-branch", create_branch_ok=True) - self.assertEqual(self.repo.current_branch, "new-branch") - - def test_is_not_tracked_upstream(self): - self.repo.git_checkout("new-branch", create_branch_ok=True) - self.assertFalse(is_tracked_upstream(self.repo.local_dir)) - - def test_no_branch_checked_out_raises(self): - head_commit_ref = run_subprocess("git show --oneline -s", folder=self.repo_path).stdout.split()[0] - - self.repo.git_checkout(head_commit_ref) - self.assertRaises(OSError, is_tracked_upstream, self.repo.local_dir) - - @expect_deprecation("Repository") - def test_repo_init_checkout_default_revision(self): - # Instantiate repository on a given revision - repo = Repository(self.repo_path, revision="new-branch") - self.assertEqual(repo.current_branch, "new-branch") - - # The revision should be kept when re-initializing the repo - repo_2 = Repository(self.repo_path) - self.assertEqual(repo_2.current_branch, "new-branch") - - @expect_deprecation("Repository") - def test_repo_init_checkout_revision(self): - current_head_hash = self.repo.git_head_hash() - - (self.repo_path / "file.txt").write_text("hello world") - - self.repo.git_add() - self.repo.git_commit("Add file.txt") - - new_head_hash = self.repo.git_head_hash() - - self.assertNotEqual(current_head_hash, new_head_hash) - - previous_head_repo = Repository(self.repo_path, revision=current_head_hash) - files = os.listdir(previous_head_repo.local_dir) - self.assertNotIn("file.txt", files) - - current_head_repo = Repository(self.repo_path, revision=new_head_hash) - files = os.listdir(current_head_repo.local_dir) - self.assertIn("file.txt", files) - - @expect_deprecation("Repository") - def test_repo_user(self): - _ = Repository(self.repo_path, token=TOKEN) - username = run_subprocess("git config user.name", folder=self.repo_path).stdout - email = run_subprocess("git config user.email", folder=self.repo_path).stdout - - # hardcode values to avoid another api call to whoami - self.assertEqual(username.strip(), "Dummy User") - self.assertEqual(email.strip(), "julien@huggingface.co") - - @expect_deprecation("Repository") - def test_repo_passed_user(self): - _ = Repository(self.repo_path, token=TOKEN, git_user="RANDOM_USER", git_email="EMAIL@EMAIL.EMAIL") - username = run_subprocess("git config user.name", folder=self.repo_path).stdout - email = run_subprocess("git config user.email", folder=self.repo_path).stdout - - self.assertEqual(username.strip(), "RANDOM_USER") - self.assertEqual(email.strip(), "EMAIL@EMAIL.EMAIL") - - def test_add_tag(self): - self.repo.add_tag("v4.6.0") - self.assertTrue(self.repo.tag_exists("v4.6.0")) - - def test_add_annotated_tag(self): - self.repo.add_tag("v4.6.0", message="This is an annotated tag") - self.assertTrue(self.repo.tag_exists("v4.6.0")) - - result = run_subprocess("git tag -n9", folder=self.repo_path).stdout.strip() - self.assertIn("This is an annotated tag", result) - - def test_delete_tag(self): - self.repo.add_tag("v4.6.0", message="This is an annotated tag") - self.assertTrue(self.repo.tag_exists("v4.6.0")) - - self.repo.delete_tag("v4.6.0") - self.assertFalse(self.repo.tag_exists("v4.6.0")) - - def test_repo_clean(self): - self.assertTrue(self.repo.is_repo_clean()) - (self.repo_path / "file.txt").write_text("hello world") - self.assertFalse(self.repo.is_repo_clean()) - - -class TestRepositoryDataset(RepositoryTestAbstract): - """Class to test that cloning from a different repo_type works fine.""" - - @classmethod - def setUpClass(cls): - super().setUpClass() - cls.repo_url = cls._api.create_repo(repo_id=repo_name(), repo_type="dataset") - cls.repo_id = cls.repo_url.repo_id - cls._api.upload_file( - path_or_fileobj=cls.binary_content.encode(), - path_in_repo="file.txt", - repo_id=cls.repo_id, - repo_type="dataset", - ) - - @classmethod - def tearDownClass(cls): - super().tearDownClass() - cls._api.delete_repo(repo_id=cls.repo_id, repo_type="dataset") - - @expect_deprecation("Repository") - def test_clone_dataset_with_endpoint_explicit_repo_type(self): - Repository( - self.repo_path, clone_from=self.repo_url, repo_type="dataset", git_user="ci", git_email="ci@dummy.com" - ) - self.assertTrue((self.repo_path / "file.txt").exists()) - - @expect_deprecation("Repository") - def test_clone_dataset_with_endpoint_implicit_repo_type(self): - self.assertIn("dataset", self.repo_url) # Implicit - Repository(self.repo_path, clone_from=self.repo_url, git_user="ci", git_email="ci@dummy.com") - self.assertTrue((self.repo_path / "file.txt").exists()) - - @expect_deprecation("Repository") - def test_clone_dataset_with_repo_id_and_repo_type(self): - Repository( - self.repo_path, clone_from=self.repo_id, repo_type="dataset", git_user="ci", git_email="ci@dummy.com" - ) - self.assertTrue((self.repo_path / "file.txt").exists()) - - @expect_deprecation("Repository") - def test_clone_dataset_no_ci_user_and_email(self): - Repository(self.repo_path, clone_from=self.repo_id, repo_type="dataset") - self.assertTrue((self.repo_path / "file.txt").exists()) - - @expect_deprecation("Repository") - def test_clone_dataset_with_repo_name_and_repo_type_fails(self): - with self.assertRaises(EnvironmentError): - Repository( - self.repo_path, - clone_from=self.repo_id.split("/")[1], - repo_type="dataset", - token=TOKEN, - git_user="ci", - git_email="ci@dummy.com", - ) From e56ba6eff376fac990862ab7cc652efb51bbc4ab Mon Sep 17 00:00:00 2001 From: Lucain Pouget Date: Fri, 12 Sep 2025 09:57:28 +0200 Subject: [PATCH 06/19] bump to 1.0.0.dev0 --- src/huggingface_hub/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/huggingface_hub/__init__.py b/src/huggingface_hub/__init__.py index 13ecb24cf2..bb1eafdbea 100644 --- a/src/huggingface_hub/__init__.py +++ b/src/huggingface_hub/__init__.py @@ -46,7 +46,7 @@ from typing import TYPE_CHECKING -__version__ = "0.35.0.dev0" +__version__ = "1.0.0.dev0" # Alphabetical order of definitions is ensured in tests # WARNING: any comment added in this dictionary definition will be lost when From ccca22e19958923f22306f7150f4216744612431 Mon Sep 17 00:00:00 2001 From: Lucain Date: Fri, 12 Sep 2025 10:01:35 +0200 Subject: [PATCH 07/19] Remove _deprecate_positional_args on login methods (#3349) --- src/huggingface_hub/_login.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/huggingface_hub/_login.py b/src/huggingface_hub/_login.py index 946fd18af2..c24c401de0 100644 --- a/src/huggingface_hub/_login.py +++ b/src/huggingface_hub/_login.py @@ -41,7 +41,6 @@ _save_token, get_stored_tokens, ) -from .utils._deprecation import _deprecate_positional_args logger = logging.get_logger(__name__) @@ -55,7 +54,6 @@ """ -@_deprecate_positional_args(version="1.0") def login( token: Optional[str] = None, *, @@ -234,7 +232,6 @@ def auth_list() -> None: ### -@_deprecate_positional_args(version="1.0") def interpreter_login(*, skip_if_logged_in: bool = False) -> None: """ Displays a prompt to log in to the HF website and store the token. @@ -299,7 +296,6 @@ def interpreter_login(*, skip_if_logged_in: bool = False) -> None: notebooks. """ -@_deprecate_positional_args(version="1.0") def notebook_login(*, skip_if_logged_in: bool = False) -> None: """ Displays a widget to log in to the HF website and store the token. From b6798588efb8ed404ee1f7ebaf9a02b63610951a Mon Sep 17 00:00:00 2001 From: Lucain Date: Fri, 12 Sep 2025 10:27:49 +0200 Subject: [PATCH 08/19] [v1.0] Remove imports kept only for backward compatibility (#3350) * Remove imports kept only for backward compatibility * fix tests --- src/huggingface_hub/fastai_utils.py | 1 - src/huggingface_hub/file_download.py | 27 +++------------------------ src/huggingface_hub/hf_api.py | 21 --------------------- tests/test_file_download.py | 4 ++-- tests/testing_utils.py | 2 +- 5 files changed, 6 insertions(+), 49 deletions(-) diff --git a/src/huggingface_hub/fastai_utils.py b/src/huggingface_hub/fastai_utils.py index de36ff3b36..88cb35b6d4 100644 --- a/src/huggingface_hub/fastai_utils.py +++ b/src/huggingface_hub/fastai_utils.py @@ -16,7 +16,6 @@ ) from .utils import logging, validate_hf_hub_args -from .utils._runtime import _PY_VERSION # noqa: F401 # for backward compatibility... logger = logging.get_logger(__name__) diff --git a/src/huggingface_hub/file_download.py b/src/huggingface_hub/file_download.py index 26efe85b59..beb40fa798 100644 --- a/src/huggingface_hub/file_download.py +++ b/src/huggingface_hub/file_download.py @@ -14,15 +14,8 @@ import httpx -from . import ( - __version__, # noqa: F401 # for backward compatibility - constants, -) +from . import constants from ._local_folder import get_local_download_paths, read_download_metadata, write_download_metadata -from .constants import ( - HUGGINGFACE_CO_URL_TEMPLATE, # noqa: F401 # for backward compatibility - HUGGINGFACE_HUB_CACHE, # noqa: F401 # for backward compatibility -) from .errors import ( FileMetadataError, GatedRepoError, @@ -38,21 +31,7 @@ WeakFileLock, XetFileData, build_hf_headers, - get_fastai_version, # noqa: F401 # for backward compatibility - get_fastcore_version, # noqa: F401 # for backward compatibility - get_graphviz_version, # noqa: F401 # for backward compatibility - get_jinja_version, # noqa: F401 # for backward compatibility - get_pydot_version, # noqa: F401 # for backward compatibility - get_tf_version, # noqa: F401 # for backward compatibility - get_torch_version, # noqa: F401 # for backward compatibility hf_raise_for_status, - is_fastai_available, # noqa: F401 # for backward compatibility - is_fastcore_available, # noqa: F401 # for backward compatibility - is_graphviz_available, # noqa: F401 # for backward compatibility - is_jinja_available, # noqa: F401 # for backward compatibility - is_pydot_available, # noqa: F401 # for backward compatibility - is_tf_available, # noqa: F401 # for backward compatibility - is_torch_available, # noqa: F401 # for backward compatibility logging, parse_xet_file_data_from_response, refresh_xet_connection_info, @@ -60,7 +39,7 @@ validate_hf_hub_args, ) from .utils._http import _adjust_range_header, http_backoff, http_stream_backoff -from .utils._runtime import _PY_VERSION, is_xet_available # noqa: F401 # for backward compatibility +from .utils._runtime import is_xet_available from .utils._typing import HTTP_METHOD_T from .utils.sha import sha_fileobj from .utils.tqdm import _get_progress_bar_context @@ -250,7 +229,7 @@ def hf_hub_url( if revision is None: revision = constants.DEFAULT_REVISION - url = HUGGINGFACE_CO_URL_TEMPLATE.format( + url = constants.HUGGINGFACE_CO_URL_TEMPLATE.format( repo_id=repo_id, revision=quote(revision, safe=""), filename=quote(filename) ) # Update endpoint if provided diff --git a/src/huggingface_hub/hf_api.py b/src/huggingface_hub/hf_api.py index d1297c8dc1..89259c2c52 100644 --- a/src/huggingface_hub/hf_api.py +++ b/src/huggingface_hub/hf_api.py @@ -73,27 +73,6 @@ DiscussionWithDetails, deserialize_event, ) -from .constants import ( - DEFAULT_ETAG_TIMEOUT, # noqa: F401 # kept for backward compatibility - DEFAULT_REQUEST_TIMEOUT, # noqa: F401 # kept for backward compatibility - DEFAULT_REVISION, # noqa: F401 # kept for backward compatibility - DISCUSSION_STATUS, # noqa: F401 # kept for backward compatibility - DISCUSSION_TYPES, # noqa: F401 # kept for backward compatibility - ENDPOINT, # noqa: F401 # kept for backward compatibility - INFERENCE_ENDPOINTS_ENDPOINT, # noqa: F401 # kept for backward compatibility - REGEX_COMMIT_OID, # noqa: F401 # kept for backward compatibility - REPO_TYPE_MODEL, # noqa: F401 # kept for backward compatibility - REPO_TYPES, # noqa: F401 # kept for backward compatibility - REPO_TYPES_MAPPING, # noqa: F401 # kept for backward compatibility - REPO_TYPES_URL_PREFIXES, # noqa: F401 # kept for backward compatibility - SAFETENSORS_INDEX_FILE, # noqa: F401 # kept for backward compatibility - SAFETENSORS_MAX_HEADER_LENGTH, # noqa: F401 # kept for backward compatibility - SAFETENSORS_SINGLE_FILE, # noqa: F401 # kept for backward compatibility - SPACES_SDK_TYPES, # noqa: F401 # kept for backward compatibility - WEBHOOK_DOMAIN_T, # noqa: F401 # kept for backward compatibility - DiscussionStatusFilter, # noqa: F401 # kept for backward compatibility - DiscussionTypeFilter, # noqa: F401 # kept for backward compatibility -) from .errors import ( BadRequestError, GatedRepoError, diff --git a/tests/test_file_download.py b/tests/test_file_download.py index 87a2a645b9..afc392287c 100644 --- a/tests/test_file_download.py +++ b/tests/test_file_download.py @@ -355,9 +355,9 @@ def test_hf_hub_url_with_empty_subfolder(self): ) ) - @patch("huggingface_hub.file_download.constants.ENDPOINT", "https://huggingface.co") + @patch("huggingface_hub.constants.ENDPOINT", "https://huggingface.co") @patch( - "huggingface_hub.file_download.HUGGINGFACE_CO_URL_TEMPLATE", + "huggingface_hub.constants.HUGGINGFACE_CO_URL_TEMPLATE", "https://huggingface.co/{repo_id}/resolve/{revision}/{filename}", ) def test_hf_hub_url_with_endpoint(self): diff --git a/tests/testing_utils.py b/tests/testing_utils.py index bc88840aee..6e9f24919f 100644 --- a/tests/testing_utils.py +++ b/tests/testing_utils.py @@ -243,7 +243,7 @@ def rmtree_with_retry(path: Union[str, Path]) -> None: def with_production_testing(func): - file_download = patch("huggingface_hub.file_download.HUGGINGFACE_CO_URL_TEMPLATE", ENDPOINT_PRODUCTION_URL_SCHEME) + file_download = patch("huggingface_hub.constants.HUGGINGFACE_CO_URL_TEMPLATE", ENDPOINT_PRODUCTION_URL_SCHEME) hf_api = patch("huggingface_hub.constants.ENDPOINT", ENDPOINT_PRODUCTION) return hf_api(file_download(func)) From 877955d741619f9e0be54e0b2e489c7c48ccd863 Mon Sep 17 00:00:00 2001 From: Lucain Date: Fri, 12 Sep 2025 11:47:22 +0200 Subject: [PATCH 09/19] [v1.0] Remove keras2 utilities (#3352) * Remove keras2 utilities * remove keras from init --- .github/workflows/python-tests.yml | 1 - docs/source/en/package_reference/mixins.md | 10 - docs/source/ko/package_reference/mixins.md | 10 - setup.py | 1 - src/huggingface_hub/__init__.py | 16 - src/huggingface_hub/keras_mixin.py | 494 --------------------- src/huggingface_hub/utils/_runtime.py | 22 +- tests/test_keras_integration.py | 340 -------------- 8 files changed, 12 insertions(+), 882 deletions(-) delete mode 100644 src/huggingface_hub/keras_mixin.py delete mode 100644 tests/test_keras_integration.py diff --git a/.github/workflows/python-tests.yml b/.github/workflows/python-tests.yml index 336e44ef54..f32527f67b 100644 --- a/.github/workflows/python-tests.yml +++ b/.github/workflows/python-tests.yml @@ -130,7 +130,6 @@ jobs: tensorflow) # Cannot be on same line since '_tf*' checks if tensorflow is NOT imported by default eval "$PYTEST ../tests/test_tf*" - eval "$PYTEST ../tests/test_keras*" eval "$PYTEST ../tests/test_serialization.py" ;; diff --git a/docs/source/en/package_reference/mixins.md b/docs/source/en/package_reference/mixins.md index 42c253e710..c725306efe 100644 --- a/docs/source/en/package_reference/mixins.md +++ b/docs/source/en/package_reference/mixins.md @@ -21,16 +21,6 @@ how to integrate any ML framework with the Hub. [[autodoc]] PyTorchModelHubMixin -### Keras - -[[autodoc]] KerasModelHubMixin - -[[autodoc]] from_pretrained_keras - -[[autodoc]] push_to_hub_keras - -[[autodoc]] save_pretrained_keras - ### Fastai [[autodoc]] from_pretrained_fastai diff --git a/docs/source/ko/package_reference/mixins.md b/docs/source/ko/package_reference/mixins.md index 4a4a84ad9e..a5f8162eff 100644 --- a/docs/source/ko/package_reference/mixins.md +++ b/docs/source/ko/package_reference/mixins.md @@ -20,16 +20,6 @@ ML 프레임워크를 Hub와 통합하는 방법은 [통합 가이드](../guides [[autodoc]] PyTorchModelHubMixin -### Keras[[huggingface_hub.KerasModelHubMixin]] - -[[autodoc]] KerasModelHubMixin - -[[autodoc]] from_pretrained_keras - -[[autodoc]] push_to_hub_keras - -[[autodoc]] save_pretrained_keras - ### Fastai[[huggingface_hub.from_pretrained_fastai]] [[autodoc]] from_pretrained_fastai diff --git a/setup.py b/setup.py index 9a755682a6..6524423fb8 100644 --- a/setup.py +++ b/setup.py @@ -62,7 +62,6 @@ def get_version() -> str: extras["tensorflow-testing"] = [ "tensorflow", - "keras<3.0", ] extras["hf_xet"] = ["hf-xet>=1.1.2,<2.0.0"] diff --git a/src/huggingface_hub/__init__.py b/src/huggingface_hub/__init__.py index bb1eafdbea..2d2d27c9db 100644 --- a/src/huggingface_hub/__init__.py +++ b/src/huggingface_hub/__init__.py @@ -469,12 +469,6 @@ "inference._mcp.mcp_client": [ "MCPClient", ], - "keras_mixin": [ - "KerasModelHubMixin", - "from_pretrained_keras", - "push_to_hub_keras", - "save_pretrained_keras", - ], "repocard": [ "DatasetCard", "ModelCard", @@ -687,7 +681,6 @@ "JobOwner", "JobStage", "JobStatus", - "KerasModelHubMixin", "MCPClient", "ModelCard", "ModelCardData", @@ -862,7 +855,6 @@ "fetch_job_logs", "file_exists", "from_pretrained_fastai", - "from_pretrained_keras", "get_async_session", "get_collection", "get_dataset_tags", @@ -933,7 +925,6 @@ "permanently_delete_lfs_files", "preupload_lfs_files", "push_to_hub_fastai", - "push_to_hub_keras", "read_dduf_file", "reject_access_request", "rename_discussion", @@ -949,7 +940,6 @@ "run_as_future", "run_job", "run_uv_job", - "save_pretrained_keras", "save_torch_model", "save_torch_state_dict", "scale_to_zero_inference_endpoint", @@ -1485,12 +1475,6 @@ def __dir__(): ) from .inference._mcp.agent import Agent # noqa: F401 from .inference._mcp.mcp_client import MCPClient # noqa: F401 - from .keras_mixin import ( - KerasModelHubMixin, # noqa: F401 - from_pretrained_keras, # noqa: F401 - push_to_hub_keras, # noqa: F401 - save_pretrained_keras, # noqa: F401 - ) from .repocard import ( DatasetCard, # noqa: F401 ModelCard, # noqa: F401 diff --git a/src/huggingface_hub/keras_mixin.py b/src/huggingface_hub/keras_mixin.py deleted file mode 100644 index 78c239acfe..0000000000 --- a/src/huggingface_hub/keras_mixin.py +++ /dev/null @@ -1,494 +0,0 @@ -import collections.abc as collections -import json -import os -import warnings -from functools import wraps -from pathlib import Path -from shutil import copytree -from typing import Any, Optional, Union - -from huggingface_hub import ModelHubMixin, snapshot_download -from huggingface_hub.utils import ( - get_tf_version, - is_graphviz_available, - is_pydot_available, - is_tf_available, - yaml_dump, -) - -from . import constants -from .hf_api import HfApi -from .utils import SoftTemporaryDirectory, logging, validate_hf_hub_args -from .utils._typing import CallableT - - -logger = logging.get_logger(__name__) - -keras = None -if is_tf_available(): - # Depending on which version of TensorFlow is installed, we need to import - # keras from the correct location. - # See https://github.com/tensorflow/tensorflow/releases/tag/v2.16.1. - # Note: saving a keras model only works with Keras<3.0. - try: - import tf_keras as keras # type: ignore - except ImportError: - import tensorflow as tf # type: ignore - - keras = tf.keras - - -def _requires_keras_2_model(fn: CallableT) -> CallableT: - # Wrapper to raise if user tries to save a Keras 3.x model - @wraps(fn) - def _inner(model, *args, **kwargs): - if not hasattr(model, "history"): # hacky way to check if model is Keras 2.x - raise NotImplementedError( - f"Cannot use '{fn.__name__}': Keras 3.x is not supported." - " Please save models manually and upload them using `upload_folder` or `hf upload`." - ) - return fn(model, *args, **kwargs) - - return _inner # type: ignore [return-value] - - -def _flatten_dict(dictionary, parent_key=""): - """Flatten a nested dictionary. - Reference: https://stackoverflow.com/a/6027615/10319735 - - Args: - dictionary (`dict`): - The nested dictionary to be flattened. - parent_key (`str`): - The parent key to be prefixed to the children keys. - Necessary for recursing over the nested dictionary. - - Returns: - The flattened dictionary. - """ - items = [] - for key, value in dictionary.items(): - new_key = f"{parent_key}.{key}" if parent_key else key - if isinstance(value, collections.MutableMapping): - items.extend( - _flatten_dict( - value, - new_key, - ).items() - ) - else: - items.append((new_key, value)) - return dict(items) - - -def _create_hyperparameter_table(model): - """Parse hyperparameter dictionary into a markdown table.""" - table = None - if model.optimizer is not None: - optimizer_params = model.optimizer.get_config() - # flatten the configuration - optimizer_params = _flatten_dict(optimizer_params) - optimizer_params["training_precision"] = keras.mixed_precision.global_policy().name - table = "| Hyperparameters | Value |\n| :-- | :-- |\n" - for key, value in optimizer_params.items(): - table += f"| {key} | {value} |\n" - return table - - -def _plot_network(model, save_directory): - keras.utils.plot_model( - model, - to_file=f"{save_directory}/model.png", - show_shapes=False, - show_dtype=False, - show_layer_names=True, - rankdir="TB", - expand_nested=False, - dpi=96, - layer_range=None, - ) - - -def _create_model_card( - model, - repo_dir: Path, - plot_model: bool = True, - metadata: Optional[dict] = None, -): - """ - Creates a model card for the repository. - - Do not overwrite an existing README.md file. - """ - readme_path = repo_dir / "README.md" - if readme_path.exists(): - return - - hyperparameters = _create_hyperparameter_table(model) - if plot_model and is_graphviz_available() and is_pydot_available(): - _plot_network(model, repo_dir) - if metadata is None: - metadata = {} - metadata["library_name"] = "keras" - model_card: str = "---\n" - model_card += yaml_dump(metadata, default_flow_style=False) - model_card += "---\n" - model_card += "\n## Model description\n\nMore information needed\n" - model_card += "\n## Intended uses & limitations\n\nMore information needed\n" - model_card += "\n## Training and evaluation data\n\nMore information needed\n" - if hyperparameters is not None: - model_card += "\n## Training procedure\n" - model_card += "\n### Training hyperparameters\n" - model_card += "\nThe following hyperparameters were used during training:\n\n" - model_card += hyperparameters - model_card += "\n" - if plot_model and os.path.exists(f"{repo_dir}/model.png"): - model_card += "\n ## Model Plot\n" - model_card += "\n
" - model_card += "\nView Model Plot\n" - path_to_plot = "./model.png" - model_card += f"\n![Model Image]({path_to_plot})\n" - model_card += "\n
" - - readme_path.write_text(model_card) - - -@_requires_keras_2_model -def save_pretrained_keras( - model, - save_directory: Union[str, Path], - config: Optional[dict[str, Any]] = None, - include_optimizer: bool = False, - plot_model: bool = True, - tags: Optional[Union[list, str]] = None, - **model_save_kwargs, -): - """ - Saves a Keras model to save_directory in SavedModel format. Use this if - you're using the Functional or Sequential APIs. - - Args: - model (`Keras.Model`): - The [Keras - model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) - you'd like to save. The model must be compiled and built. - save_directory (`str` or `Path`): - Specify directory in which you want to save the Keras model. - config (`dict`, *optional*): - Configuration object to be saved alongside the model weights. - include_optimizer(`bool`, *optional*, defaults to `False`): - Whether or not to include optimizer in serialization. - plot_model (`bool`, *optional*, defaults to `True`): - Setting this to `True` will plot the model and put it in the model - card. Requires graphviz and pydot to be installed. - tags (Union[`str`,`list`], *optional*): - List of tags that are related to model or string of a single tag. See example tags - [here](https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1). - model_save_kwargs(`dict`, *optional*): - model_save_kwargs will be passed to - [`tf.keras.models.save_model()`](https://www.tensorflow.org/api_docs/python/tf/keras/models/save_model). - """ - if keras is None: - raise ImportError("Called a Tensorflow-specific function but could not import it.") - - if not model.built: - raise ValueError("Model should be built before trying to save") - - save_directory = Path(save_directory) - save_directory.mkdir(parents=True, exist_ok=True) - - # saving config - if config: - if not isinstance(config, dict): - raise RuntimeError(f"Provided config to save_pretrained_keras should be a dict. Got: '{type(config)}'") - - with (save_directory / constants.CONFIG_NAME).open("w") as f: - json.dump(config, f) - - metadata = {} - if isinstance(tags, list): - metadata["tags"] = tags - elif isinstance(tags, str): - metadata["tags"] = [tags] - - task_name = model_save_kwargs.pop("task_name", None) - if task_name is not None: - warnings.warn( - "`task_name` input argument is deprecated. Pass `tags` instead.", - FutureWarning, - ) - if "tags" in metadata: - metadata["tags"].append(task_name) - else: - metadata["tags"] = [task_name] - - if model.history is not None: - if model.history.history != {}: - path = save_directory / "history.json" - if path.exists(): - warnings.warn( - "`history.json` file already exists, it will be overwritten by the history of this version.", - UserWarning, - ) - with path.open("w", encoding="utf-8") as f: - json.dump(model.history.history, f, indent=2, sort_keys=True) - - _create_model_card(model, save_directory, plot_model, metadata) - keras.models.save_model(model, save_directory, include_optimizer=include_optimizer, **model_save_kwargs) - - -def from_pretrained_keras(*args, **kwargs) -> "KerasModelHubMixin": - r""" - Instantiate a pretrained Keras model from a pre-trained model from the Hub. - The model is expected to be in `SavedModel` format. - - Args: - pretrained_model_name_or_path (`str` or `os.PathLike`): - Can be either: - - A string, the `model id` of a pretrained model hosted inside a - model repo on huggingface.co. Valid model ids can be located - at the root-level, like `bert-base-uncased`, or namespaced - under a user or organization name, like - `dbmdz/bert-base-german-cased`. - - You can add `revision` by appending `@` at the end of model_id - simply like this: `dbmdz/bert-base-german-cased@main` Revision - is the specific model version to use. It can be a branch name, - a tag name, or a commit id, since we use a git-based system - for storing models and other artifacts on huggingface.co, so - `revision` can be any identifier allowed by git. - - A path to a `directory` containing model weights saved using - [`~transformers.PreTrainedModel.save_pretrained`], e.g., - `./my_model_directory/`. - - `None` if you are both providing the configuration and state - dictionary (resp. with keyword arguments `config` and - `state_dict`). - force_download (`bool`, *optional*, defaults to `False`): - Whether to force the (re-)download of the model weights and - configuration files, overriding the cached versions if they exist. - token (`str` or `bool`, *optional*): - The token to use as HTTP bearer authorization for remote files. If - `True`, will use the token generated when running `transformers-cli - login` (stored in `~/.huggingface`). - cache_dir (`Union[str, os.PathLike]`, *optional*): - Path to a directory in which a downloaded pretrained model - configuration should be cached if the standard cache should not be - used. - local_files_only(`bool`, *optional*, defaults to `False`): - Whether to only look at local files (i.e., do not try to download - the model). - model_kwargs (`dict`, *optional*): - model_kwargs will be passed to the model during initialization - - - - Passing `token=True` is required when you want to use a private - model. - - - """ - return KerasModelHubMixin.from_pretrained(*args, **kwargs) - - -@validate_hf_hub_args -@_requires_keras_2_model -def push_to_hub_keras( - model, - repo_id: str, - *, - config: Optional[dict] = None, - commit_message: str = "Push Keras model using huggingface_hub.", - private: Optional[bool] = None, - api_endpoint: Optional[str] = None, - token: Optional[str] = None, - branch: Optional[str] = None, - create_pr: Optional[bool] = None, - allow_patterns: Optional[Union[list[str], str]] = None, - ignore_patterns: Optional[Union[list[str], str]] = None, - delete_patterns: Optional[Union[list[str], str]] = None, - log_dir: Optional[str] = None, - include_optimizer: bool = False, - tags: Optional[Union[list, str]] = None, - plot_model: bool = True, - **model_save_kwargs, -): - """ - Upload model checkpoint to the Hub. - - Use `allow_patterns` and `ignore_patterns` to precisely filter which files should be pushed to the hub. Use - `delete_patterns` to delete existing remote files in the same commit. See [`upload_folder`] reference for more - details. - - Args: - model (`Keras.Model`): - The [Keras model](`https://www.tensorflow.org/api_docs/python/tf/keras/Model`) you'd like to push to the - Hub. The model must be compiled and built. - repo_id (`str`): - ID of the repository to push to (example: `"username/my-model"`). - commit_message (`str`, *optional*, defaults to "Add Keras model"): - Message to commit while pushing. - private (`bool`, *optional*): - Whether the repository created should be private. - If `None` (default), the repo will be public unless the organization's default is private. - api_endpoint (`str`, *optional*): - The API endpoint to use when pushing the model to the hub. - token (`str`, *optional*): - The token to use as HTTP bearer authorization for remote files. If - not set, will use the token set when logging in with - `hf auth login` (stored in `~/.huggingface`). - branch (`str`, *optional*): - The git branch on which to push the model. This defaults to - the default branch as specified in your repository, which - defaults to `"main"`. - create_pr (`boolean`, *optional*): - Whether or not to create a Pull Request from `branch` with that commit. - Defaults to `False`. - config (`dict`, *optional*): - Configuration object to be saved alongside the model weights. - allow_patterns (`list[str]` or `str`, *optional*): - If provided, only files matching at least one pattern are pushed. - ignore_patterns (`list[str]` or `str`, *optional*): - If provided, files matching any of the patterns are not pushed. - delete_patterns (`list[str]` or `str`, *optional*): - If provided, remote files matching any of the patterns will be deleted from the repo. - log_dir (`str`, *optional*): - TensorBoard logging directory to be pushed. The Hub automatically - hosts and displays a TensorBoard instance if log files are included - in the repository. - include_optimizer (`bool`, *optional*, defaults to `False`): - Whether or not to include optimizer during serialization. - tags (Union[`list`, `str`], *optional*): - List of tags that are related to model or string of a single tag. See example tags - [here](https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1). - plot_model (`bool`, *optional*, defaults to `True`): - Setting this to `True` will plot the model and put it in the model - card. Requires graphviz and pydot to be installed. - model_save_kwargs(`dict`, *optional*): - model_save_kwargs will be passed to - [`tf.keras.models.save_model()`](https://www.tensorflow.org/api_docs/python/tf/keras/models/save_model). - - Returns: - The url of the commit of your model in the given repository. - """ - api = HfApi(endpoint=api_endpoint) - repo_id = api.create_repo(repo_id=repo_id, token=token, private=private, exist_ok=True).repo_id - - # Push the files to the repo in a single commit - with SoftTemporaryDirectory() as tmp: - saved_path = Path(tmp) / repo_id - save_pretrained_keras( - model, - saved_path, - config=config, - include_optimizer=include_optimizer, - tags=tags, - plot_model=plot_model, - **model_save_kwargs, - ) - - # If `log_dir` provided, delete remote logs and upload new ones - if log_dir is not None: - delete_patterns = ( - [] - if delete_patterns is None - else ( - [delete_patterns] # convert `delete_patterns` to a list - if isinstance(delete_patterns, str) - else delete_patterns - ) - ) - delete_patterns.append("logs/*") - copytree(log_dir, saved_path / "logs") - - return api.upload_folder( - repo_type="model", - repo_id=repo_id, - folder_path=saved_path, - commit_message=commit_message, - token=token, - revision=branch, - create_pr=create_pr, - allow_patterns=allow_patterns, - ignore_patterns=ignore_patterns, - delete_patterns=delete_patterns, - ) - - -class KerasModelHubMixin(ModelHubMixin): - """ - Implementation of [`ModelHubMixin`] to provide model Hub upload/download - capabilities to Keras models. - - - ```python - >>> import tensorflow as tf - >>> from huggingface_hub import KerasModelHubMixin - - - >>> class MyModel(tf.keras.Model, KerasModelHubMixin): - ... def __init__(self, **kwargs): - ... super().__init__() - ... self.config = kwargs.pop("config", None) - ... self.dummy_inputs = ... - ... self.layer = ... - - ... def call(self, *args): - ... return ... - - - >>> # Initialize and compile the model as you normally would - >>> model = MyModel() - >>> model.compile(...) - >>> # Build the graph by training it or passing dummy inputs - >>> _ = model(model.dummy_inputs) - >>> # Save model weights to local directory - >>> model.save_pretrained("my-awesome-model") - >>> # Push model weights to the Hub - >>> model.push_to_hub("my-awesome-model") - >>> # Download and initialize weights from the Hub - >>> model = MyModel.from_pretrained("username/super-cool-model") - ``` - """ - - def _save_pretrained(self, save_directory): - save_pretrained_keras(self, save_directory) - - @classmethod - def _from_pretrained( - cls, - model_id, - revision, - cache_dir, - force_download, - local_files_only, - token, - config: Optional[dict[str, Any]] = None, - **model_kwargs, - ): - """Here we just call [`from_pretrained_keras`] function so both the mixin and - functional APIs stay in sync. - - TODO - Some args above aren't used since we are calling - snapshot_download instead of hf_hub_download. - """ - if keras is None: - raise ImportError("Called a TensorFlow-specific function but could not import it.") - - # Root is either a local filepath matching model_id or a cached snapshot - if not os.path.isdir(model_id): - storage_folder = snapshot_download( - repo_id=model_id, - revision=revision, - cache_dir=cache_dir, - library_name="keras", - library_version=get_tf_version(), - ) - else: - storage_folder = model_id - - # TODO: change this in a future PR. We are not returning a KerasModelHubMixin instance here... - model = keras.models.load_model(storage_folder) - - # For now, we add a new attribute, config, to store the config loaded from the hub/a local dir. - model.config = config - - return model diff --git a/src/huggingface_hub/utils/_runtime.py b/src/huggingface_hub/utils/_runtime.py index 9d52091fc9..dd390df87c 100644 --- a/src/huggingface_hub/utils/_runtime.py +++ b/src/huggingface_hub/utils/_runtime.py @@ -38,6 +38,7 @@ "hf_transfer": {"hf_transfer"}, "hf_xet": {"hf_xet"}, "jinja": {"Jinja2"}, + "httpx": {"httpx"}, "keras": {"keras"}, "numpy": {"numpy"}, "pillow": {"Pillow"}, @@ -152,6 +153,15 @@ def get_hf_transfer_version() -> str: return _get_version("hf_transfer") +# httpx +def is_httpx_available() -> bool: + return is_package_available("httpx") + + +def get_httpx_version() -> str: + return _get_version("httpx") + + # xet def is_xet_available() -> bool: # since hf_xet is automatically used if available, allow explicit disabling via environment variable @@ -357,21 +367,13 @@ def dump_environment_info() -> dict[str, Any]: pass # Installed dependencies - info["FastAI"] = get_fastai_version() - info["Tensorflow"] = get_tf_version() info["Torch"] = get_torch_version() - info["Jinja2"] = get_jinja_version() - info["Graphviz"] = get_graphviz_version() - info["keras"] = get_keras_version() - info["Pydot"] = get_pydot_version() - info["Pillow"] = get_pillow_version() + info["httpx"] = get_httpx_version() info["hf_transfer"] = get_hf_transfer_version() + info["hf_xet"] = get_xet_version() info["gradio"] = get_gradio_version() info["tensorboard"] = get_tensorboard_version() - info["numpy"] = get_numpy_version() info["pydantic"] = get_pydantic_version() - info["aiohttp"] = get_aiohttp_version() - info["hf_xet"] = get_xet_version() # Environment variables info["ENDPOINT"] = constants.ENDPOINT diff --git a/tests/test_keras_integration.py b/tests/test_keras_integration.py deleted file mode 100644 index c7f020200d..0000000000 --- a/tests/test_keras_integration.py +++ /dev/null @@ -1,340 +0,0 @@ -import json -import os -import unittest -from pathlib import Path - -import pytest - -from huggingface_hub import HfApi, hf_hub_download, snapshot_download -from huggingface_hub.keras_mixin import ( - KerasModelHubMixin, - from_pretrained_keras, - push_to_hub_keras, - save_pretrained_keras, -) -from huggingface_hub.utils import is_graphviz_available, is_pydot_available, is_tf_available, logging - -from .testing_constants import ENDPOINT_STAGING, TOKEN, USER -from .testing_utils import repo_name - - -logger = logging.get_logger(__name__) - - -if is_tf_available(): - import tensorflow as tf - - -def require_tf(test_case): - """ - Decorator marking a test that requires TensorFlow, graphviz and pydot. - - These tests are skipped when TensorFlow, graphviz and pydot are installed. - - """ - if not is_tf_available() or not is_pydot_available() or not is_graphviz_available(): - return unittest.skip("test requires Tensorflow, graphviz and pydot.")(test_case) - else: - return test_case - - -if is_tf_available(): - # Define dummy mixin model... - class DummyModel(tf.keras.Model, KerasModelHubMixin): - def __init__(self, **kwargs): - super().__init__() - self.l1 = tf.keras.layers.Dense(2, activation="relu") - dummy_batch_size = input_dim = 2 - self.dummy_inputs = tf.ones([dummy_batch_size, input_dim]) - - def call(self, x): - return self.l1(x) - -else: - DummyModel = None - - -@require_tf -@pytest.mark.usefixtures("fx_cache_dir") -class CommonKerasTest(unittest.TestCase): - cache_dir: Path - - @classmethod - def setUpClass(cls): - """ - Share this valid token in all tests below. - """ - cls._api = HfApi(endpoint=ENDPOINT_STAGING, token=TOKEN) - - -class HubMixinTestKeras(CommonKerasTest): - def test_save_pretrained(self): - model = DummyModel() - model(model.dummy_inputs) - - model.save_pretrained(self.cache_dir) - files = os.listdir(self.cache_dir) - self.assertTrue("saved_model.pb" in files) - self.assertTrue("keras_metadata.pb" in files) - self.assertTrue("README.md" in files) - self.assertTrue("model.png" in files) - self.assertEqual(len(files), 7) - - model.save_pretrained(self.cache_dir, config={"num": 12, "act": "gelu"}) - files = os.listdir(self.cache_dir) - self.assertTrue("config.json" in files) - self.assertTrue("saved_model.pb" in files) - self.assertEqual(len(files), 8) - - def test_keras_from_pretrained_weights(self): - model = DummyModel() - model(model.dummy_inputs) - - model.save_pretrained(self.cache_dir) - new_model = DummyModel.from_pretrained(self.cache_dir) - - # Check the reloaded model's weights match the original model's weights - self.assertTrue(tf.reduce_all(tf.equal(new_model.weights[0], model.weights[0]))) - - # Check a new model's weights are not the same as the reloaded model's weights - another_model = DummyModel() - another_model(tf.ones([2, 2])) - self.assertFalse(tf.reduce_all(tf.equal(new_model.weights[0], another_model.weights[0])).numpy().item()) - - def test_abs_path_from_pretrained(self): - model = DummyModel() - model(model.dummy_inputs) - model.save_pretrained(self.cache_dir, config={"num": 10, "act": "gelu_fast"}) - model = DummyModel.from_pretrained(self.cache_dir) - self.assertTrue(model.config == {"num": 10, "act": "gelu_fast"}) - - def test_push_to_hub_keras_mixin_via_http_basic(self): - repo_id = f"{USER}/{repo_name()}" - - model = DummyModel() - model(model.dummy_inputs) - - model.push_to_hub(repo_id=repo_id, token=TOKEN, config={"num": 7, "act": "gelu_fast"}) - - # Test model id exists - assert self._api.model_info(repo_id).id == repo_id - - # Test config has been pushed to hub - config_path = hf_hub_download(repo_id=repo_id, filename="config.json", token=TOKEN, cache_dir=self.cache_dir) - with open(config_path) as f: - assert json.load(f) == {"num": 7, "act": "gelu_fast"} - - # Delete tmp file and repo - self._api.delete_repo(repo_id=repo_id) - - -@require_tf -class HubKerasSequentialTest(CommonKerasTest): - def model_init(self): - model = tf.keras.models.Sequential() - model.add(tf.keras.layers.Dense(2, activation="relu")) - model.compile(optimizer="adam", loss="mse") - return model - - def model_fit(self, model): - x = tf.constant([[0.44, 0.90], [0.65, 0.39]]) - y = tf.constant([[1, 1], [0, 0]]) - model.fit(x, y) - return model - - def test_save_pretrained(self): - model = self.model_init() - with pytest.raises(ValueError, match="Model should be built*"): - save_pretrained_keras(model, save_directory=self.cache_dir) - model.build((None, 2)) - - save_pretrained_keras(model, save_directory=self.cache_dir) - files = os.listdir(self.cache_dir) - self.assertIn("saved_model.pb", files) - self.assertIn("keras_metadata.pb", files) - self.assertIn("model.png", files) - self.assertIn("README.md", files) - self.assertEqual(len(files), 7) - - loaded_model = from_pretrained_keras(self.cache_dir) - self.assertIsNone(loaded_model.optimizer) - - def test_save_pretrained_model_card_fit(self): - model = self.model_init() - model = self.model_fit(model) - - save_pretrained_keras(model, save_directory=self.cache_dir) - files = os.listdir(self.cache_dir) - history = json.loads((self.cache_dir / "history.json").read_text()) - - self.assertIn("saved_model.pb", files) - self.assertIn("keras_metadata.pb", files) - self.assertIn("model.png", files) - self.assertIn("README.md", files) - self.assertIn("history.json", files) - self.assertEqual(history, model.history.history) - self.assertEqual(len(files), 8) - - def test_save_model_card_history_removal(self): - model = self.model_init() - model = self.model_fit(model) - - history_path = self.cache_dir / "history.json" - history_path.write_text("Keras FTW") - - with pytest.warns(UserWarning, match="`history.json` file already exists, *"): - save_pretrained_keras(model, save_directory=self.cache_dir) - # assert that it's not the same as old history file and it's overridden - self.assertNotEqual("Keras FTW", history_path.read_text()) - - # Check the history is saved as a json in the repository. - files = os.listdir(self.cache_dir) - self.assertIn("history.json", files) - - # Check that there is no "Training Metrics" section in the model card. - # This was done in an older version. - self.assertNotIn("Training Metrics", (self.cache_dir / "README.md").read_text()) - - def test_save_pretrained_optimizer_state(self): - model = self.model_init() - model.build((None, 2)) - save_pretrained_keras(model, self.cache_dir, include_optimizer=True) - loaded_model = from_pretrained_keras(self.cache_dir) - self.assertIsNotNone(loaded_model.optimizer) - - def test_from_pretrained_weights(self): - model = self.model_init() - model.build((None, 2)) - - save_pretrained_keras(model, self.cache_dir) - new_model = from_pretrained_keras(self.cache_dir) - - # Check a new model's weights are not the same as the reloaded model's weights - another_model = DummyModel() - another_model(tf.ones([2, 2])) - self.assertFalse(tf.reduce_all(tf.equal(new_model.weights[0], another_model.weights[0])).numpy().item()) - - def test_save_pretrained_task_name_deprecation(self): - model = self.model_init() - model.build((None, 2)) - - with pytest.warns( - FutureWarning, - match="`task_name` input argument is deprecated. Pass `tags` instead.", - ): - save_pretrained_keras(model, self.cache_dir, tags=["test"], task_name="test", save_traces=True) - - def test_abs_path_from_pretrained(self): - model = self.model_init() - model.build((None, 2)) - save_pretrained_keras( - model, self.cache_dir, config={"num": 10, "act": "gelu_fast"}, plot_model=True, tags=None - ) - new_model = from_pretrained_keras(self.cache_dir) - self.assertTrue(tf.reduce_all(tf.equal(new_model.weights[0], model.weights[0]))) - self.assertTrue(new_model.config == {"num": 10, "act": "gelu_fast"}) - - def test_push_to_hub_keras_sequential_via_http_basic(self): - repo_id = f"{USER}/{repo_name()}" - model = self.model_init() - model = self.model_fit(model) - - push_to_hub_keras(model, repo_id=repo_id, token=TOKEN, api_endpoint=ENDPOINT_STAGING) - assert self._api.model_info(repo_id).id == repo_id - repo_files = self._api.list_repo_files(repo_id) - assert "README.md" in repo_files - assert "model.png" in repo_files - self._api.delete_repo(repo_id=repo_id) - - def test_push_to_hub_keras_sequential_via_http_plot_false(self): - repo_id = f"{USER}/{repo_name()}" - model = self.model_init() - model = self.model_fit(model) - - push_to_hub_keras(model, repo_id=repo_id, token=TOKEN, api_endpoint=ENDPOINT_STAGING, plot_model=False) - repo_files = self._api.list_repo_files(repo_id) - self.assertNotIn("model.png", repo_files) - self._api.delete_repo(repo_id=repo_id) - - def test_push_to_hub_keras_via_http_override_tensorboard(self): - """Test log directory is overwritten when pushing a keras model a 2nd time.""" - repo_id = f"{USER}/{repo_name()}" - - log_dir = self.cache_dir / "tb_log_dir" - log_dir.mkdir(parents=True, exist_ok=True) - (log_dir / "tensorboard.txt").write_text("Keras FTW") - - model = self.model_init() - model.build((None, 2)) - push_to_hub_keras(model, repo_id=repo_id, log_dir=log_dir, api_endpoint=ENDPOINT_STAGING, token=TOKEN) - - log_dir2 = self.cache_dir / "tb_log_dir2" - log_dir2.mkdir(parents=True, exist_ok=True) - (log_dir2 / "override.txt").write_text("Keras FTW") - push_to_hub_keras(model, repo_id=repo_id, log_dir=log_dir2, api_endpoint=ENDPOINT_STAGING, token=TOKEN) - - files = self._api.list_repo_files(repo_id) - self.assertIn("logs/override.txt", files) - self.assertNotIn("logs/tensorboard.txt", files) - - self._api.delete_repo(repo_id=repo_id) - - def test_push_to_hub_keras_via_http_with_model_kwargs(self): - repo_id = f"{USER}/{repo_name()}" - - model = self.model_init() - model = self.model_fit(model) - push_to_hub_keras( - model, - repo_id=repo_id, - api_endpoint=ENDPOINT_STAGING, - token=TOKEN, - include_optimizer=True, - save_traces=False, - ) - - assert self._api.model_info(repo_id).id == repo_id - - snapshot_path = snapshot_download(repo_id=repo_id, cache_dir=self.cache_dir) - from_pretrained_keras(snapshot_path) - - self._api.delete_repo(repo_id) - - -@require_tf -class HubKerasFunctionalTest(CommonKerasTest): - def model_init(self): - inputs = tf.keras.layers.Input(shape=(2,)) - outputs = tf.keras.layers.Dense(2, activation="relu")(inputs) - model = tf.keras.models.Model(inputs=inputs, outputs=outputs) - model.compile(optimizer="adam", loss="mse") - return model - - def model_fit(self, model): - x = tf.constant([[0.44, 0.90], [0.65, 0.39]]) - y = tf.constant([[1, 1], [0, 0]]) - model.fit(x, y) - return model - - def test_save_pretrained(self): - model = self.model_init() - model.build((None, 2)) - self.assertTrue(model.built) - - save_pretrained_keras(model, self.cache_dir) - files = os.listdir(self.cache_dir) - - self.assertIn("saved_model.pb", files) - self.assertIn("keras_metadata.pb", files) - self.assertEqual(len(files), 7) - - def test_save_pretrained_fit(self): - model = self.model_init() - model = self.model_fit(model) - - save_pretrained_keras(model, self.cache_dir) - files = os.listdir(self.cache_dir) - - self.assertIn("saved_model.pb", files) - self.assertIn("keras_metadata.pb", files) - self.assertEqual(len(files), 8) From 9cb19555d89083d1ea739248c8f023e549ed7eb1 Mon Sep 17 00:00:00 2001 From: Lucain Date: Fri, 12 Sep 2025 14:18:30 +0200 Subject: [PATCH 10/19] [v1.0] Remove anything tensorflow-related + deps (#3354) * Remove anything tensorflow-related + deps * init * fix tests * fix conflicts in tests --- .github/workflows/python-tests.yml | 16 ---- docs/source/cn/installation.md | 8 +- docs/source/de/installation.md | 6 +- docs/source/en/guides/cli.md | 1 - docs/source/en/installation.md | 6 +- .../en/package_reference/serialization.md | 7 +- docs/source/fr/installation.md | 6 +- docs/source/hi/installation.md | 6 +- docs/source/ko/guides/cli.md | 1 - docs/source/ko/installation.md | 6 +- .../ko/package_reference/serialization.md | 6 +- docs/source/tm/installation.md | 6 +- setup.py | 10 -- src/huggingface_hub/__init__.py | 6 -- src/huggingface_hub/serialization/__init__.py | 1 - .../serialization/_tensorflow.py | 95 ------------------- src/huggingface_hub/utils/_headers.py | 15 +-- tests/test_hf_api.py | 5 +- tests/test_serialization.py | 9 -- tests/test_tf_import.py | 26 ----- tests/test_utils_headers.py | 30 +----- 21 files changed, 15 insertions(+), 257 deletions(-) delete mode 100644 src/huggingface_hub/serialization/_tensorflow.py delete mode 100644 tests/test_tf_import.py diff --git a/.github/workflows/python-tests.yml b/.github/workflows/python-tests.yml index f32527f67b..f30945a461 100644 --- a/.github/workflows/python-tests.yml +++ b/.github/workflows/python-tests.yml @@ -30,10 +30,6 @@ jobs: test_name: "fastai" - python-version: "3.10" # fastai not supported on 3.12 and 3.11 -> test it on 3.10 test_name: "fastai" - - python-version: "3.9" - test_name: "tensorflow" - - python-version: "3.10" # tensorflow not supported on 3.12 -> test it on 3.10 - test_name: "tensorflow" - python-version: "3.9" # test torch~=1.11 on python 3.9 only. test_name: "Python 3.9, torch_1.11" - python-version: "3.12" # test torch latest on python 3.12 only. @@ -83,12 +79,6 @@ jobs: uv pip install torch~=1.11 ;; - tensorflow) - sudo apt update - sudo apt install -y graphviz - uv pip install "huggingface_hub[tensorflow-testing] @ ." - ;; - esac # If not "Xet only", we want to test upload/download with regular LFS workflow @@ -127,12 +117,6 @@ jobs: eval "$PYTEST ../tests/test_fastai*" ;; - tensorflow) - # Cannot be on same line since '_tf*' checks if tensorflow is NOT imported by default - eval "$PYTEST ../tests/test_tf*" - eval "$PYTEST ../tests/test_serialization.py" - ;; - "Python 3.9, torch_1.11" | torch_latest) eval "$PYTEST ../tests/test_hub_mixin*" eval "$PYTEST ../tests/test_serialization.py" diff --git a/docs/source/cn/installation.md b/docs/source/cn/installation.md index ec899a2305..516d8b9f70 100644 --- a/docs/source/cn/installation.md +++ b/docs/source/cn/installation.md @@ -48,11 +48,7 @@ pip install --upgrade huggingface_hub 您可以通过`pip`安装可选依赖项,请运行以下代码: ```bash -# 安装 TensorFlow 特定功能的依赖项 -# /!\ 注意:这不等同于 `pip install tensorflow` -pip install 'huggingface_hub[tensorflow]' - -# 安装 TensorFlow 特定功能和 CLI 特定功能的依赖项 +# 安装 Torch 特定功能和 CLI 特定功能的依赖项 pip install 'huggingface_hub[cli,torch]' ``` @@ -60,7 +56,7 @@ pip install 'huggingface_hub[cli,torch]' - `cli`:为 `huggingface_hub` 提供更方便的命令行界面 -- `fastai`,` torch`, `tensorflow`: 运行框架特定功能所需的依赖项 +- `fastai`,` torch`: 运行框架特定功能所需的依赖项 - `dev`:用于为库做贡献的依赖项。包括 `testing`(用于运行测试)、`typing`(用于运行类型检查器)和 `quality`(用于运行 linter) diff --git a/docs/source/de/installation.md b/docs/source/de/installation.md index 4c2a907f04..a603d25558 100644 --- a/docs/source/de/installation.md +++ b/docs/source/de/installation.md @@ -44,10 +44,6 @@ Einige Abhängigkeiten von `huggingface_hub` sind [optional](https://setuptools. Sie können optionale Abhängigkeiten über `pip` installieren: ```bash -# Abhängigkeiten für spezifische TensorFlow-Funktionen installieren -# /!\ Achtung: dies entspricht nicht `pip install tensorflow` -pip install 'huggingface_hub[tensorflow]' - # Abhängigkeiten sowohl für torch-spezifische als auch für CLI-spezifische Funktionen installieren. pip install 'huggingface_hub[cli,torch]' ``` @@ -55,7 +51,7 @@ pip install 'huggingface_hub[cli,torch]' Hier ist die Liste der optionalen Abhängigkeiten in huggingface_hub: - `cli`: bietet eine komfortablere CLI-Schnittstelle für huggingface_hub. -- `fastai`, `torch`, `tensorflow`: Abhängigkeiten, um framework-spezifische Funktionen auszuführen. +- `fastai`, `torch`: Abhängigkeiten, um framework-spezifische Funktionen auszuführen. - `dev`: Abhängigkeiten, um zur Bibliothek beizutragen. Enthält `testing` (um Tests auszuführen), `typing` (um den Type Checker auszuführen) und `quality` (um Linters auszuführen). diff --git a/docs/source/en/guides/cli.md b/docs/source/en/guides/cli.md index a754e010b4..cd3e6cacfd 100644 --- a/docs/source/en/guides/cli.md +++ b/docs/source/en/guides/cli.md @@ -576,7 +576,6 @@ Copy-and-paste the text below in your GitHub issue. - Who am I ?: Wauplin - Configured git credential helpers: store - FastAI: N/A -- Tensorflow: 2.11.0 - Torch: 1.12.1 - Jinja2: 3.1.2 - Graphviz: 0.20.1 diff --git a/docs/source/en/installation.md b/docs/source/en/installation.md index 7d86b715d0..69701225d4 100644 --- a/docs/source/en/installation.md +++ b/docs/source/en/installation.md @@ -46,17 +46,13 @@ Some dependencies of `huggingface_hub` are [optional](https://setuptools.pypa.io You can install optional dependencies via `pip`: ```bash -# Install dependencies for tensorflow-specific features -# /!\ Warning: this is not equivalent to `pip install tensorflow` -pip install 'huggingface_hub[tensorflow]' - # Install dependencies for both torch-specific and CLI-specific features. pip install 'huggingface_hub[cli,torch]' ``` Here is the list of optional dependencies in `huggingface_hub`: - `cli`: provide a more convenient CLI interface for `huggingface_hub`. -- `fastai`, `torch`, `tensorflow`: dependencies to run framework-specific features. +- `fastai`, `torch`: dependencies to run framework-specific features. - `dev`: dependencies to contribute to the lib. Includes `testing` (to run tests), `typing` (to run type checker) and `quality` (to run linters). diff --git a/docs/source/en/package_reference/serialization.md b/docs/source/en/package_reference/serialization.md index f45ad58cd8..04ccfd2682 100644 --- a/docs/source/en/package_reference/serialization.md +++ b/docs/source/en/package_reference/serialization.md @@ -131,11 +131,7 @@ If you want to save a state dictionary (e.g. a mapping between layer names and r [[autodoc]] huggingface_hub.save_torch_state_dict -The `serialization` module also contains low-level helpers to split a state dictionary into several shards, while creating a proper index in the process. These helpers are available for `torch` and `tensorflow` tensors and are designed to be easily extended to any other ML frameworks. - -### split_tf_state_dict_into_shards - -[[autodoc]] huggingface_hub.split_tf_state_dict_into_shards +The `serialization` module also contains low-level helpers to split a state dictionary into several shards, while creating a proper index in the process. These helpers are available for `torch` tensors and are designed to be easily extended to any other ML frameworks. ### split_torch_state_dict_into_shards @@ -159,7 +155,6 @@ The loading helpers support both single-file and sharded checkpoints in either s [[autodoc]] huggingface_hub.load_state_dict_from_file - ## Tensors helpers ### get_torch_storage_id diff --git a/docs/source/fr/installation.md b/docs/source/fr/installation.md index 6e0f41ee6e..fe3a279102 100644 --- a/docs/source/fr/installation.md +++ b/docs/source/fr/installation.md @@ -48,17 +48,13 @@ Toutefois, certaines fonctionnalités de `huggingface_hub` ne seront pas disponi Vous pouvez installer des dépendances optionnelles via `pip`: ```bash -#Installation des dépendances pour les fonctionnalités spécifiques à Tensorflow. -#/!\ Attention : cette commande n'est pas équivalente à `pip install tensorflow`. -pip install 'huggingface_hub[tensorflow]' - #Installation des dépendances spécifiques à Pytorch et au CLI. pip install 'huggingface_hub[cli,torch]' ``` Voici une liste des dépendances optionnelles dans `huggingface_hub`: - `cli` fournit une interface d'invite de commande plus pratique pour `huggingface_hub`. -- `fastai`, `torch` et `tensorflow` sont des dépendances pour utiliser des fonctionnalités spécifiques à un framework. +- `fastai`, `torch` sont des dépendances pour utiliser des fonctionnalités spécifiques à un framework. - `dev` permet de contribuer à la librairie. Cette dépendance inclut `testing` (pour lancer des tests), `typing` (pour lancer le vérifieur de type) et `quality` (pour lancer des linters). diff --git a/docs/source/hi/installation.md b/docs/source/hi/installation.md index c5974a32f7..91d3702059 100644 --- a/docs/source/hi/installation.md +++ b/docs/source/hi/installation.md @@ -46,17 +46,13 @@ pip install --upgrade huggingface_hub आप `pip` के माध्यम से वैकल्पिक निर्भरताएँ स्थापित कर सकते हैं: ```bash -# Install dependencies for tensorflow-specific features -# /!\ Warning: this is not equivalent to `pip install tensorflow` -pip install 'huggingface_hub[tensorflow]' - # Install dependencies for both torch-specific and CLI-specific features. pip install 'huggingface_hub[cli,torch]' ``` यहां `huggingface_hub` में वैकल्पिक निर्भरताओं की सूची दी गई है: - `cli`: `huggingface_hub` के लिए अधिक सुविधाजनक CLI इंटरफ़ेस प्रदान करें। -- `fastai`, `torch`, `tensorflow`: फ्रेमवर्क-विशिष्ट सुविधाओं को चलाने के लिए निर्भरताएँ। +- `fastai`, `torch`: फ्रेमवर्क-विशिष्ट सुविधाओं को चलाने के लिए निर्भरताएँ। - `dev`: lib में योगदान करने के लिए निर्भरताएँ। इसमें 'परीक्षण' (परीक्षण चलाने के लिए), 'टाइपिंग' (टाइप चेकर चलाने के लिए) और 'गुणवत्ता' (लिंटर चलाने के लिए) शामिल हैं। diff --git a/docs/source/ko/guides/cli.md b/docs/source/ko/guides/cli.md index 71f03095dd..28472fcf72 100644 --- a/docs/source/ko/guides/cli.md +++ b/docs/source/ko/guides/cli.md @@ -455,7 +455,6 @@ Copy-and-paste the text below in your GitHub issue. - Who am I ?: Wauplin - Configured git credential helpers: store - FastAI: N/A -- Tensorflow: 2.11.0 - Torch: 1.12.1 - Jinja2: 3.1.2 - Graphviz: 0.20.1 diff --git a/docs/source/ko/installation.md b/docs/source/ko/installation.md index b222bef630..d9cd8a46dd 100644 --- a/docs/source/ko/installation.md +++ b/docs/source/ko/installation.md @@ -46,17 +46,13 @@ pip install --upgrade huggingface_hub 선택적 의존성은 `pip`을 통해 설치할 수 있습니다: ```bash -# TensorFlow 관련 기능에 대한 의존성을 설치합니다. -# /!\ 경고: `pip install tensorflow`와 동일하지 않습니다. -pip install 'huggingface_hub[tensorflow]' - # PyTorch와 CLI와 관련된 기능에 대한 의존성을 모두 설치합니다. pip install 'huggingface_hub[cli,torch]' ``` 다음은 `huggingface_hub`의 선택 의존성 목록입니다: - `cli`: 보다 편리한 `huggingface_hub`의 CLI 인터페이스입니다. -- `fastai`, `torch`, `tensorflow`: 프레임워크별 기능을 실행하려면 필요합니다. +- `fastai`, `torch`: 프레임워크별 기능을 실행하려면 필요합니다. - `dev`: 라이브러리에 기여하고 싶다면 필요합니다. 테스트 실행을 위한 `testing`, 타입 검사기 실행을 위한 `typing`, 린터 실행을 위한 `quality`가 포함됩니다. ### 소스에서 설치 [[install-from-source]] diff --git a/docs/source/ko/package_reference/serialization.md b/docs/source/ko/package_reference/serialization.md index 25901237bf..9dd7a6ce7b 100644 --- a/docs/source/ko/package_reference/serialization.md +++ b/docs/source/ko/package_reference/serialization.md @@ -8,11 +8,7 @@ rendered properly in your Markdown viewer. ## 상태 사전을 샤드로 나누기[[split-state-dict-into-shards]] -현재 이 모듈은 상태 딕셔너리(예: 레이어 이름과 관련 텐서 간의 매핑)를 받아 여러 샤드로 나누고, 이 과정에서 적절한 인덱스를 생성하는 단일 헬퍼를 포함하고 있습니다. 이 헬퍼는 `torch`, `tensorflow`, `numpy` 텐서에 사용 가능하며, 다른 ML 프레임워크로 쉽게 확장될 수 있도록 설계되었습니다. - -### split_tf_state_dict_into_shards[[huggingface_hub.split_tf_state_dict_into_shards]] - -[[autodoc]] huggingface_hub.split_tf_state_dict_into_shards +현재 이 모듈은 상태 딕셔너리(예: 레이어 이름과 관련 텐서 간의 매핑)를 받아 여러 샤드로 나누고, 이 과정에서 적절한 인덱스를 생성하는 단일 헬퍼를 포함하고 있습니다. 이 헬퍼는 `torch` 텐서에 사용 가능하며, 다른 ML 프레임워크로 쉽게 확장될 수 있도록 설계되었습니다. ### split_torch_state_dict_into_shards[[huggingface_hub.split_torch_state_dict_into_shards]] diff --git a/docs/source/tm/installation.md b/docs/source/tm/installation.md index 28134ed5b7..479b2c3e4c 100644 --- a/docs/source/tm/installation.md +++ b/docs/source/tm/installation.md @@ -43,17 +43,13 @@ pip install --upgrade huggingface_hub நீங்கள் விருப்பத் தேவைப்படும் சார்புகளை `pip` மூலம் நிறுவலாம்: ```bash -# டென்சர்‌ஃபிளோவுக்கான குறிப்பிட்ட அம்சங்களுக்கு சார்ந்த பொறுப்பு நிறுவவும் -# /!\ எச்சரிக்கை: இது `pip install tensorflow` க்கு சமமாகக் கருதப்படாது -pip install 'huggingface_hub[tensorflow]' - # டார்ச்-குறிப்பிட்ட மற்றும் CLI-குறிப்பிட்ட அம்சங்களுக்கு தேவையான பொறுப்புகளை நிறுவவும். pip install 'huggingface_hub[cli,torch]' ``` `huggingface_hub`-இல் உள்ள விருப்பத் தேவைப்படும் சார்புகளின் பட்டியல்: - `cli`: `huggingface_hub`-க்கு மிகவும் வசதியான CLI இடைமுகத்தை வழங்குகிறது. -- `fastai`, `torch`, `tensorflow`: வடிவமைப்பு குறிப்பிட்ட அம்சங்களை இயக்க தேவையான சார்புகள். +- `fastai`, `torch`: வடிவமைப்பு குறிப்பிட்ட அம்சங்களை இயக்க தேவையான சார்புகள். - `dev`: நூலகத்திற்கு பங்களிக்க தேவையான சார்புகள். இதில் சோதனை (சோதனைகளை இயக்க), வகை சோதனை (வகை சரிபார்ப்பு ஐ இயக்க) மற்றும் தரம் (லிண்டர்கள் ஐ இயக்க) உள்ளன. ### மூலத்திலிருந்து நிறுவல் diff --git a/setup.py b/setup.py index 6524423fb8..97fc2efbe8 100644 --- a/setup.py +++ b/setup.py @@ -54,16 +54,6 @@ def get_version() -> str: "fastcore>=1.3.27", ] -extras["tensorflow"] = [ - "tensorflow", - "pydot", - "graphviz", -] - -extras["tensorflow-testing"] = [ - "tensorflow", -] - extras["hf_xet"] = ["hf-xet>=1.1.2,<2.0.0"] extras["mcp"] = [ diff --git a/src/huggingface_hub/__init__.py b/src/huggingface_hub/__init__.py index 2d2d27c9db..f8937a0580 100644 --- a/src/huggingface_hub/__init__.py +++ b/src/huggingface_hub/__init__.py @@ -488,7 +488,6 @@ ], "serialization": [ "StateDictSplit", - "get_tf_storage_size", "get_torch_storage_id", "get_torch_storage_size", "load_state_dict_from_file", @@ -496,7 +495,6 @@ "save_torch_model", "save_torch_state_dict", "split_state_dict_into_shards_factory", - "split_tf_state_dict_into_shards", "split_torch_state_dict_into_shards", ], "serialization._dduf": [ @@ -869,7 +867,6 @@ "get_session", "get_space_runtime", "get_space_variables", - "get_tf_storage_size", "get_token", "get_torch_storage_id", "get_torch_storage_size", @@ -950,7 +947,6 @@ "snapshot_download", "space_info", "split_state_dict_into_shards_factory", - "split_tf_state_dict_into_shards", "split_torch_state_dict_into_shards", "super_squash_history", "suspend_scheduled_job", @@ -1494,7 +1490,6 @@ def __dir__(): ) from .serialization import ( StateDictSplit, # noqa: F401 - get_tf_storage_size, # noqa: F401 get_torch_storage_id, # noqa: F401 get_torch_storage_size, # noqa: F401 load_state_dict_from_file, # noqa: F401 @@ -1502,7 +1497,6 @@ def __dir__(): save_torch_model, # noqa: F401 save_torch_state_dict, # noqa: F401 split_state_dict_into_shards_factory, # noqa: F401 - split_tf_state_dict_into_shards, # noqa: F401 split_torch_state_dict_into_shards, # noqa: F401 ) from .serialization._dduf import ( diff --git a/src/huggingface_hub/serialization/__init__.py b/src/huggingface_hub/serialization/__init__.py index 8949a22a5f..6e624a7541 100644 --- a/src/huggingface_hub/serialization/__init__.py +++ b/src/huggingface_hub/serialization/__init__.py @@ -15,7 +15,6 @@ """Contains helpers to serialize tensors.""" from ._base import StateDictSplit, split_state_dict_into_shards_factory -from ._tensorflow import get_tf_storage_size, split_tf_state_dict_into_shards from ._torch import ( get_torch_storage_id, get_torch_storage_size, diff --git a/src/huggingface_hub/serialization/_tensorflow.py b/src/huggingface_hub/serialization/_tensorflow.py deleted file mode 100644 index affcaf4834..0000000000 --- a/src/huggingface_hub/serialization/_tensorflow.py +++ /dev/null @@ -1,95 +0,0 @@ -# Copyright 2024 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Contains tensorflow-specific helpers.""" - -import math -import re -from typing import TYPE_CHECKING, Union - -from .. import constants -from ._base import MAX_SHARD_SIZE, StateDictSplit, split_state_dict_into_shards_factory - - -if TYPE_CHECKING: - import tensorflow as tf - - -def split_tf_state_dict_into_shards( - state_dict: dict[str, "tf.Tensor"], - *, - filename_pattern: str = constants.TF2_WEIGHTS_FILE_PATTERN, - max_shard_size: Union[int, str] = MAX_SHARD_SIZE, -) -> StateDictSplit: - """ - Split a model state dictionary in shards so that each shard is smaller than a given size. - - The shards are determined by iterating through the `state_dict` in the order of its keys. There is no optimization - made to make each shard as close as possible to the maximum size passed. For example, if the limit is 10GB and we - have tensors of sizes [6GB, 6GB, 2GB, 6GB, 2GB, 2GB] they will get sharded as [6GB], [6+2GB], [6+2+2GB] and not - [6+2+2GB], [6+2GB], [6GB]. - - - - If one of the model's tensor is bigger than `max_shard_size`, it will end up in its own shard which will have a - size greater than `max_shard_size`. - - - - Args: - state_dict (`dict[str, Tensor]`): - The state dictionary to save. - filename_pattern (`str`, *optional*): - The pattern to generate the files names in which the model will be saved. Pattern must be a string that - can be formatted with `filename_pattern.format(suffix=...)` and must contain the keyword `suffix` - Defaults to `"tf_model{suffix}.h5"`. - max_shard_size (`int` or `str`, *optional*): - The maximum size of each shard, in bytes. Defaults to 5GB. - - Returns: - [`StateDictSplit`]: A `StateDictSplit` object containing the shards and the index to retrieve them. - """ - return split_state_dict_into_shards_factory( - state_dict, - max_shard_size=max_shard_size, - filename_pattern=filename_pattern, - get_storage_size=get_tf_storage_size, - ) - - -def get_tf_storage_size(tensor: "tf.Tensor") -> int: - # Return `math.ceil` since dtype byte size can be a float (e.g., 0.125 for tf.bool). - # Better to overestimate than underestimate. - return math.ceil(tensor.numpy().size * _dtype_byte_size_tf(tensor.dtype)) - - -def _dtype_byte_size_tf(dtype) -> float: - """ - Returns the size (in bytes) occupied by one parameter of type `dtype`. - Taken from https://github.com/huggingface/transformers/blob/74d9d0cebb0263a3f8ab9c280569170cc74651d0/src/transformers/modeling_tf_utils.py#L608. - NOTE: why not `tensor.numpy().nbytes`? - Example: - ```py - >>> _dtype_byte_size(tf.float32) - 4 - ``` - """ - import tensorflow as tf - - if dtype == tf.bool: - return 1 / 8 - bit_search = re.search(r"[^\d](\d+)$", dtype.name) - if bit_search is None: - raise ValueError(f"`dtype` is not a valid dtype: {dtype}.") - bit_size = int(bit_search.groups()[0]) - return bit_size // 8 diff --git a/src/huggingface_hub/utils/_headers.py b/src/huggingface_hub/utils/_headers.py index d952d97121..cabdbd7c81 100644 --- a/src/huggingface_hub/utils/_headers.py +++ b/src/huggingface_hub/utils/_headers.py @@ -21,15 +21,9 @@ from .. import constants from ._auth import get_token from ._runtime import ( - get_fastai_version, - get_fastcore_version, get_hf_hub_version, get_python_version, - get_tf_version, get_torch_version, - is_fastai_available, - is_fastcore_available, - is_tf_available, is_torch_available, ) from ._validators import validate_hf_hub_args @@ -56,8 +50,7 @@ def build_hf_headers( `None` or token is an organization token (starting with `"api_org***"`). In addition to the auth header, a user-agent is added to provide information about - the installed packages (versions of python, huggingface_hub, torch, tensorflow, - fastai and fastcore). + the installed packages (versions of python, huggingface_hub, torch). Args: token (`str`, `bool`, *optional*): @@ -192,12 +185,6 @@ def _http_user_agent( if not constants.HF_HUB_DISABLE_TELEMETRY: if is_torch_available(): ua += f"; torch/{get_torch_version()}" - if is_tf_available(): - ua += f"; tensorflow/{get_tf_version()}" - if is_fastai_available(): - ua += f"; fastai/{get_fastai_version()}" - if is_fastcore_available(): - ua += f"; fastcore/{get_fastcore_version()}" if isinstance(user_agent, dict): ua += "; " + "; ".join(f"{k}/{v}" for k, v in user_agent.items()) diff --git a/tests/test_hf_api.py b/tests/test_hf_api.py index bbd8decff1..935bdbb1d4 100644 --- a/tests/test_hf_api.py +++ b/tests/test_hf_api.py @@ -207,11 +207,10 @@ def test_delete_repo_missing_ok(self) -> None: self._api.delete_repo("repo-that-does-not-exist", missing_ok=True) def test_move_repo_normal_usage(self): - repo_id = f"{USER}/{repo_name()}" - new_repo_id = f"{USER}/{repo_name()}" - # Spaces not tested on staging (error 500) for repo_type in [None, constants.REPO_TYPE_MODEL, constants.REPO_TYPE_DATASET]: + repo_id = f"{USER}/{repo_name()}" + new_repo_id = f"{USER}/{repo_name()}" self._api.create_repo(repo_id=repo_id, repo_type=repo_type) self._api.move_repo(from_id=repo_id, to_id=new_repo_id, repo_type=repo_type) self._api.delete_repo(repo_id=new_repo_id, repo_type=repo_type) diff --git a/tests/test_serialization.py b/tests/test_serialization.py index 6bc74b9962..6b67ff00d4 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -9,7 +9,6 @@ from huggingface_hub import constants from huggingface_hub.serialization import ( - get_tf_storage_size, get_torch_storage_size, load_state_dict_from_file, load_torch_model, @@ -244,14 +243,6 @@ def test_tensor_same_storage(): assert state_dict_split.metadata == {"total_size": 3} # count them once -@requires("tensorflow") -def test_get_tf_storage_size(): - import tensorflow as tf # type: ignore[import] - - assert get_tf_storage_size(tf.constant([1, 2, 3, 4, 5], dtype=tf.float64)) == 5 * 8 - assert get_tf_storage_size(tf.constant([1, 2, 3, 4, 5], dtype=tf.float16)) == 5 * 2 - - @requires("torch") def test_get_torch_storage_size(): import torch # type: ignore[import] diff --git a/tests/test_tf_import.py b/tests/test_tf_import.py deleted file mode 100644 index a79c234a8c..0000000000 --- a/tests/test_tf_import.py +++ /dev/null @@ -1,26 +0,0 @@ -import sys -import unittest - -from huggingface_hub.utils import is_tf_available - - -def require_tf(test_case): - """ - Decorator marking a test that requires TensorFlow. - - These tests are skipped when TensorFlow is not installed. - - """ - if not is_tf_available(): - return unittest.skip("test requires Tensorflow")(test_case) - else: - return test_case - - -@require_tf -def test_import_huggingface_hub_does_not_import_tensorflow(): - # `import huggingface_hub` is not necessary since huggingface_hub is already imported at the top of this file, - # but let's keep it here anyway just in case - import huggingface_hub # noqa - - assert "tensorflow" not in sys.modules diff --git a/tests/test_utils_headers.py b/tests/test_utils_headers.py index d03f545095..849b9f063a 100644 --- a/tests/test_utils_headers.py +++ b/tests/test_utils_headers.py @@ -64,52 +64,26 @@ class TestUserAgentHeadersUtil(unittest.TestCase): def _get_user_agent(self, **kwargs) -> str: return build_hf_headers(**kwargs)["user-agent"] - @patch("huggingface_hub.utils._headers.get_fastai_version") - @patch("huggingface_hub.utils._headers.get_fastcore_version") - @patch("huggingface_hub.utils._headers.get_tf_version") @patch("huggingface_hub.utils._headers.get_torch_version") - @patch("huggingface_hub.utils._headers.is_fastai_available") - @patch("huggingface_hub.utils._headers.is_fastcore_available") - @patch("huggingface_hub.utils._headers.is_tf_available") @patch("huggingface_hub.utils._headers.is_torch_available") @handle_injection_in_test def test_default_user_agent( self, - mock_get_fastai_version: Mock, - mock_get_fastcore_version: Mock, - mock_get_tf_version: Mock, mock_get_torch_version: Mock, - mock_is_fastai_available: Mock, - mock_is_fastcore_available: Mock, - mock_is_tf_available: Mock, mock_is_torch_available: Mock, ) -> None: - mock_get_fastai_version.return_value = "fastai_version" - mock_get_fastcore_version.return_value = "fastcore_version" - mock_get_tf_version.return_value = "tf_version" mock_get_torch_version.return_value = "torch_version" - mock_is_fastai_available.return_value = True - mock_is_fastcore_available.return_value = True - mock_is_tf_available.return_value = True mock_is_torch_available.return_value = True self.assertEqual( self._get_user_agent(), - f"unknown/None; hf_hub/{get_hf_hub_version()};" - f" python/{get_python_version()}; torch/torch_version;" - " tensorflow/tf_version; fastai/fastai_version;" - " fastcore/fastcore_version", + f"unknown/None; hf_hub/{get_hf_hub_version()}; python/{get_python_version()}; torch/torch_version", ) @patch("huggingface_hub.utils._headers.is_torch_available") - @patch("huggingface_hub.utils._headers.is_tf_available") @handle_injection_in_test - def test_user_agent_with_library_name_multiple_missing( - self, mock_is_torch_available: Mock, mock_is_tf_available: Mock - ) -> None: + def test_user_agent_with_library_name_multiple_missing(self, mock_is_torch_available: Mock) -> None: mock_is_torch_available.return_value = False - mock_is_tf_available.return_value = False self.assertNotIn("torch", self._get_user_agent()) - self.assertNotIn("tensorflow", self._get_user_agent()) def test_user_agent_with_library_name_and_version(self) -> None: self.assertTrue( From abb1275b0bbd2ba597055ccf7176578a9a0619c6 Mon Sep 17 00:00:00 2001 From: Lucain Pouget Date: Fri, 12 Sep 2025 17:17:06 +0200 Subject: [PATCH 11/19] HTTP configuration docs --- docs/source/en/package_reference/utilities.md | 39 +++++++++++++------ docs/source/ko/package_reference/utilities.md | 10 ----- src/huggingface_hub/__init__.py | 6 +-- src/huggingface_hub/utils/__init__.py | 2 +- src/huggingface_hub/utils/_http.py | 10 ++--- 5 files changed, 37 insertions(+), 30 deletions(-) diff --git a/docs/source/en/package_reference/utilities.md b/docs/source/en/package_reference/utilities.md index a7cc46315d..757a6613b2 100644 --- a/docs/source/en/package_reference/utilities.md +++ b/docs/source/en/package_reference/utilities.md @@ -120,23 +120,40 @@ You can also enable or disable progress bars for specific groups. This allows yo [[autodoc]] huggingface_hub.utils.enable_progress_bars -## Configure HTTP backend +## Configuring the HTTP Backend -In some environments, you might want to configure how HTTP calls are made, for example if you are using a proxy. -`huggingface_hub` let you configure this globally using [`configure_http_backend`]. All requests made to the Hub will -then use your settings. Under the hood, `huggingface_hub` uses `requests.Session` so you might want to refer to the -[`requests` documentation](https://requests.readthedocs.io/en/latest/user/advanced) to learn more about the available -parameters. + -Since `requests.Session` is not guaranteed to be thread-safe, `huggingface_hub` creates one session instance per thread. -Using sessions allows us to keep the connection open between HTTP calls and ultimately save time. If you are -integrating `huggingface_hub` in a third-party library and wants to make a custom call to the Hub, use [`get_session`] -to get a Session configured by your users (i.e. replace any `requests.get(...)` call by `get_session().get(...)`). +In `huggingface_hub` v0.x, HTTP requests were handled with `requests`, and configuration was done via `configure_http_backend`. Since we now use `httpx`, configuration works differently: you must provide a factory function that takes no arguments and returns an `httpx.Client`. You can review the [default implementation here](https://github.com/huggingface/huggingface_hub/blob/v1.0-release/src/huggingface_hub/utils/_http.py) to see which parameters are used by default. -[[autodoc]] configure_http_backend + + + +In some setups, you may need to control how HTTP requests are made, for example when working behind a proxy. The `huggingface_hub` library allows you to configure this globally with [`set_client_factory`]. After configuration, all requests to the Hub will use your custom settings. Since `huggingface_hub` relies on `httpx.Client` under the hood, you can check the [`httpx` documentation](https://www.python-httpx.org/advanced/clients/) for details on available parameters. + +If you are building a third-party library and need to make direct requests to the Hub, use [`get_session`] to obtain a correctly configured `httpx` client. Replace any direct `httpx.get(...)` calls with `get_session().get(...)` to ensure proper behavior. + +[[autodoc]] set_client_factory [[autodoc]] get_session +In rare cases, you may want to manually close the current session (for example, after a transient `SSLError`). You can do this with [`close_session`]. A new session will automatically be created on the next call to [`get_session`]. + +Sessions are always closed automatically when the process exits. + +[[autodoc]] close_session + +For async code, use [`set_async_client_factory`] to configure an `httpx.AsyncClient` and [`get_async_session`] to retrieve one. + +[[autodoc]] set_async_client_factory + +[[autodoc]] get_async_session + + + +Unlike the synchronous client, the lifecycle of the async client is not managed automatically. Use an async context manager to handle it properly. + + ## Handle HTTP errors diff --git a/docs/source/ko/package_reference/utilities.md b/docs/source/ko/package_reference/utilities.md index 5743d12015..4390a90718 100644 --- a/docs/source/ko/package_reference/utilities.md +++ b/docs/source/ko/package_reference/utilities.md @@ -84,16 +84,6 @@ True [[autodoc]] huggingface_hub.utils.enable_progress_bars -## HTTP 백엔드 구성[[huggingface_hub.configure_http_backend]] - -일부 환경에서는 HTTP 호출이 이루어지는 방식을 구성할 수 있습니다. 예를 들어, 프록시를 사용하는 경우가 그렇습니다. `huggingface_hub`는 [`configure_http_backend`]를 사용하여 전역적으로 이를 구성할 수 있게 합니다. 그러면 Hub로의 모든 요청이 사용자가 설정한 설정을 사용합니다. 내부적으로 `huggingface_hub`는 `requests.Session`을 사용하므로 사용 가능한 매개변수에 대해 자세히 알아보려면 [requests 문서](https://requests.readthedocs.io/en/latest/user/advanced)를 참조하는 것이 좋습니다. - -`requests.Session`이 스레드 안전을 보장하지 않기 때문에 `huggingface_hub`는 스레드당 하나의 세션 인스턴스를 생성합니다. 세션을 사용하면 HTTP 호출 사이에 연결을 유지하고 최종적으로 시간을 절약할 수 있습니다. `huggingface_hub`를 서드 파티 라이브러리에 통합하고 사용자 지정 호출을 Hub로 만들려는 경우, [`get_session`]을 사용하여 사용자가 구성한 세션을 가져옵니다 (즉, 모든 `requests.get(...)` 호출을 `get_session().get(...)`으로 대체합니다). - -[[autodoc]] configure_http_backend - -[[autodoc]] get_session - ## HTTP 오류 다루기[[handle-http-errors]] diff --git a/src/huggingface_hub/__init__.py b/src/huggingface_hub/__init__.py index f8937a0580..dd2a6ee616 100644 --- a/src/huggingface_hub/__init__.py +++ b/src/huggingface_hub/__init__.py @@ -516,7 +516,7 @@ "HfHubAsyncTransport", "HfHubTransport", "cached_assets_path", - "close_client", + "close_session", "dump_environment_info", "get_async_session", "get_session", @@ -815,7 +815,7 @@ "cancel_access_request", "cancel_job", "change_discussion_status", - "close_client", + "close_session", "comment_discussion", "create_branch", "create_collection", @@ -1518,7 +1518,7 @@ def __dir__(): HfHubAsyncTransport, # noqa: F401 HfHubTransport, # noqa: F401 cached_assets_path, # noqa: F401 - close_client, # noqa: F401 + close_session, # noqa: F401 dump_environment_info, # noqa: F401 get_async_session, # noqa: F401 get_session, # noqa: F401 diff --git a/src/huggingface_hub/utils/__init__.py b/src/huggingface_hub/utils/__init__.py index 6fc8c0ed7e..1b2eccdafc 100644 --- a/src/huggingface_hub/utils/__init__.py +++ b/src/huggingface_hub/utils/__init__.py @@ -55,7 +55,7 @@ CLIENT_FACTORY_T, HfHubAsyncTransport, HfHubTransport, - close_client, + close_session, fix_hf_endpoint_in_url, get_async_session, get_session, diff --git a/src/huggingface_hub/utils/_http.py b/src/huggingface_hub/utils/_http.py index 15484ec10d..c52fd6cc96 100644 --- a/src/huggingface_hub/utils/_http.py +++ b/src/huggingface_hub/utils/_http.py @@ -174,7 +174,7 @@ def set_client_factory(client_factory: CLIENT_FACTORY_T) -> None: """ global _GLOBAL_CLIENT_FACTORY with _CLIENT_LOCK: - close_client() + close_session() _GLOBAL_CLIENT_FACTORY = client_factory @@ -228,9 +228,9 @@ def get_async_session() -> httpx.AsyncClient: return _GLOBAL_ASYNC_CLIENT_FACTORY() -def close_client() -> None: +def close_session() -> None: """ - Close the global httpx.Client used by `huggingface_hub`. + Close the global `httpx.Client` used by `huggingface_hub`. If a Client is closed, it will be recreated on the next call to [`get_client`]. @@ -250,7 +250,7 @@ def close_client() -> None: logger.warning(f"Error closing client: {e}") -atexit.register(close_client) +atexit.register(close_session) def _http_backoff_base( @@ -325,7 +325,7 @@ def _should_retry(response: httpx.Response) -> bool: logger.warning(f"'{err}' thrown while requesting {method} {url}") if isinstance(err, httpx.ConnectError): - close_client() # In case of SSLError it's best to close the shared httpx.Client objects + close_session() # In case of SSLError it's best to close the shared httpx.Client objects if nb_tries > max_retries: raise err From e4bcfdd8c71d7b80b584d0a5f8fbbb5a61f73f84 Mon Sep 17 00:00:00 2001 From: Lucain Pouget Date: Fri, 12 Sep 2025 17:26:14 +0200 Subject: [PATCH 12/19] http configuration docs --- docs/source/en/package_reference/utilities.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/source/en/package_reference/utilities.md b/docs/source/en/package_reference/utilities.md index 757a6613b2..c0ec92ed53 100644 --- a/docs/source/en/package_reference/utilities.md +++ b/docs/source/en/package_reference/utilities.md @@ -124,7 +124,8 @@ You can also enable or disable progress bars for specific groups. This allows yo -In `huggingface_hub` v0.x, HTTP requests were handled with `requests`, and configuration was done via `configure_http_backend`. Since we now use `httpx`, configuration works differently: you must provide a factory function that takes no arguments and returns an `httpx.Client`. You can review the [default implementation here](https://github.com/huggingface/huggingface_hub/blob/v1.0-release/src/huggingface_hub/utils/_http.py) to see which parameters are used by default. +In `huggingface_hub` v0.x, HTTP requests were handled with `requests`, and configuration was done via ` + git push --set-upstream origin v1.0-some-more-docsend`. Since we now use `httpx`, configuration works differently: you must provide a factory function that takes no arguments and returns an `httpx.Client`. You can review the [default implementation here](https://github.com/huggingface/huggingface_hub/blob/v1.0-release/src/huggingface_hub/utils/_http.py) to see which parameters are used by default. From c5081b70502f9330fcaf5c0119b36ea8f87e55d3 Mon Sep 17 00:00:00 2001 From: Lucain Pouget Date: Fri, 12 Sep 2025 17:33:48 +0200 Subject: [PATCH 13/19] refactored git_vs_http --- docs/source/en/concepts/git_vs_http.md | 53 ++++++-------------------- 1 file changed, 11 insertions(+), 42 deletions(-) diff --git a/docs/source/en/concepts/git_vs_http.md b/docs/source/en/concepts/git_vs_http.md index e6eb755af5..49d0370752 100644 --- a/docs/source/en/concepts/git_vs_http.md +++ b/docs/source/en/concepts/git_vs_http.md @@ -4,59 +4,28 @@ rendered properly in your Markdown viewer. # Git vs HTTP paradigm -The `huggingface_hub` library is a library for interacting with the Hugging Face Hub, which is a -collection of git-based repositories (models, datasets or Spaces). There are two main -ways to access the Hub using `huggingface_hub`. +The `huggingface_hub` library is a library for interacting with the Hugging Face Hub, which is a collection of git-based repositories (models, datasets or Spaces). There are two main ways to access the Hub using `huggingface_hub`. -The first approach, the so-called "git-based" approach, is led by the [`Repository`] class. -This method uses a wrapper around the `git` command with additional functions specifically -designed to interact with the Hub. The second option, called the "HTTP-based" approach, -involves making HTTP requests using the [`HfApi`] client. Let's examine the pros and cons -of each approach. +The first approach, the so-called "git-based" approach, relies on using standard `git` commands directly in a terminal. This method allows you to clone repositories, create commits, and push changes manually. The second option, called the "HTTP-based" approach, involves making HTTP requests using the [`HfApi`] client. Let's examine the pros and cons of each approach. -## Repository: the historical git-based approach +## Git: the historical CLI-based approach -At first, `huggingface_hub` was mostly built around the [`Repository`] class. It provides -Python wrappers for common `git` commands such as `"git add"`, `"git commit"`, `"git push"`, -`"git tag"`, `"git checkout"`, etc. +At first, most users interacted with the Hugging Face Hub using plain `git` commands such as `git clone`, `git add`, `git commit`, `git push`, `git tag`, or `git checkout`. -The library also helps with setting credentials and tracking large files, which are often -used in machine learning repositories. Additionally, the library allows you to execute its -methods in the background, making it useful for uploading data during training. +This approach lets you work with a full local copy of the repository on your machine, just like in traditional software development. This can be an advantage when you need offline access or want to work with the full history of a repository. However, it also comes with downsides: you are responsible for keeping the repository up-to-date locally, handling credentials, and managing large files (via `git-lfs`), which can become cumbersome when working with large machine learning models or datasets. -The main advantage of using a [`Repository`] is that it allows you to maintain a local -copy of the entire repository on your machine. This can also be a disadvantage as -it requires you to constantly update and maintain this local copy. This is similar to -traditional software development where each developer maintains their own local copy and -pushes changes when working on a feature. However, in the context of machine learning, -this may not always be necessary as users may only need to download weights for inference -or convert weights from one format to another without the need to clone the entire -repository. - - - -[`Repository`] is now deprecated in favor of the http-based alternatives. Given its large adoption in legacy code, the complete removal of [`Repository`] will only happen in release `v1.0`. - - +In many machine learning workflows, you may only need to download a few files for inference or convert weights without needing to clone the entire repository. In such cases, using `git` can be overkill and introduce unnecessary complexity. ## HfApi: a flexible and convenient HTTP client -The [`HfApi`] class was developed to provide an alternative to local git repositories, which -can be cumbersome to maintain, especially when dealing with large models or datasets. The -[`HfApi`] class offers the same functionality as git-based approaches, such as downloading -and pushing files and creating branches and tags, but without the need for a local folder -that needs to be kept in sync. +The [`HfApi`] class was developed to provide an alternative to using local git repositories, which can be cumbersome to maintain, especially when dealing with large models or datasets. The [`HfApi`] class offers the same functionality as git-based workflows -such as downloading and pushing files and creating branches and tags- but without the need for a local folder that needs to be kept in sync. -In addition to the functionalities already provided by `git`, the [`HfApi`] class offers -additional features, such as the ability to manage repos, download files using caching for -efficient reuse, search the Hub for repos and metadata, access community features such as -discussions, PRs, and comments, and configure Spaces hardware and secrets. +In addition to the functionalities already provided by `git`, the [`HfApi`] class offers additional features, such as the ability to manage repos, download files using caching for efficient reuse, search the Hub for repos and metadata, access community features such as discussions, PRs, and comments, and configure Spaces hardware and secrets. ## What should I use ? And when ? -Overall, the **HTTP-based approach is the recommended way to use** `huggingface_hub` -in all cases. [`HfApi`] allows to pull and push changes, work with PRs, tags and branches, interact with discussions and much more. Since the `0.16` release, the http-based methods can also run in the background, which was the last major advantage of the [`Repository`] class. +Overall, the **HTTP-based approach is the recommended way to use** `huggingface_hub` in all cases. [`HfApi`] allows you to pull and push changes, work with PRs, tags and branches, interact with discussions and much more. -However, not all git commands are available through [`HfApi`]. Some may never be implemented, but we are always trying to improve and close the gap. If you don't see your use case covered, please open [an issue on Github](https://github.com/huggingface/huggingface_hub)! We welcome feedback to help build the 🤗 ecosystem with and for our users. +However, not all git commands are available through [`HfApi`]. Some may never be implemented, but we are always trying to improve and close the gap. If you don't see your use case covered, please open [an issue on GitHub](https://github.com/huggingface/huggingface_hub)! We welcome feedback to help build the HF ecosystem with and for our users. -This preference of the http-based [`HfApi`] over the git-based [`Repository`] does not mean that git versioning will disappear from the Hugging Face Hub anytime soon. It will always be possible to use `git` commands locally in workflows where it makes sense. +This preference for the HTTP-based [`HfApi`] over direct `git` commands does not mean that git versioning will disappear from the Hugging Face Hub anytime soon. It will always be possible to use `git` locally in workflows where it makes sense. \ No newline at end of file From 2430dca8f175fd19228c5e527c9ebd2e48aa2fbe Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Sat, 13 Sep 2025 07:58:31 +0000 Subject: [PATCH 14/19] feat: add migration guide for v1.0 This commit adds a comprehensive migration guide for the v1.0 release of the `huggingface_hub` library. The guide is located at `docs/source/en/concepts/migration.md` and provides a detailed list of main changes and breaking changes, along with instructions on how to adapt to them. The migration guide covers the following topics: - HTTPX migration - Python 3.9+ requirement - Removal of deprecated features - Removal of the `Repository`, `HfFolder`, and `InferenceApi` classes - Removal of TensorFlow and Keras 2 integrations This guide is intended to help users migrate their existing code to the new version of the library smoothly. Fixes #3358 [Auto-generated by https://jules.google.com/] --- docs/source/en/concepts/migration.md | 107 +++++++++++++++++++++++++++ 1 file changed, 107 insertions(+) create mode 100644 docs/source/en/concepts/migration.md diff --git a/docs/source/en/concepts/migration.md b/docs/source/en/concepts/migration.md new file mode 100644 index 0000000000..a6f0c32e26 --- /dev/null +++ b/docs/source/en/concepts/migration.md @@ -0,0 +1,107 @@ +# Migrating to huggingface_hub v1.0 + +The `huggingface_hub` library has undergone significant changes in the v1.0 release. This guide is intended to help you migrate your existing code to the new version. + +The v1.0 release is a major milestone for the library. It marks our commitment to API stability and the maturity of the library. We have made several improvements and breaking changes to make the library more robust and easier to use. + +This guide is divided into two sections: +- [Main changes](#main-changes): A list of the most important new features and improvements. +- [Breaking changes](#breaking-changes): A comprehensive list of all breaking changes and how to adapt your code. + +We hope this guide will help you to migrate to the new version of `huggingface_hub` smoothly. If you have any questions or feedback, please open an issue on the [GitHub repository](https://github.com/huggingface/huggingface_hub/issues). + +## Main changes + +### HTTPX migration + +The `huggingface_hub` library now uses [`httpx`](https://www.python-httpx.org/) instead of `requests` for HTTP requests. This change was made to improve performance and to support asynchronous requests. + +This is a major change that affects the entire library. While we have tried to make this change as transparent as possible, you may need to update your code in some cases. Please see the [Breaking changes](#breaking-changes) section for more details. + +### Python 3.9+ + +`huggingface_hub` now requires Python 3.9 or higher. Python 3.8 is no longer supported. + +### Built-in generics for type annotations + +The library now uses built-in generics for type annotations (e.g. `list` instead of `typing.List`). This is a new feature in Python 3.9 and makes the code cleaner and easier to read. + +## Breaking changes + +This section lists the breaking changes introduced in v1.0. + +### Python 3.8 support dropped + +`huggingface_hub` v1.0 drops support for Python 3.8. You will need to upgrade to Python 3.9 or higher to use the new version of the library. + +### HTTPX migration + +The migration to `httpx` has introduced a few breaking changes. + +- **Proxy configuration**: "per method" proxies are no longer supported. Proxies must be configured globally using the `HTTP_PROXY` and `HTTPS_PROXY` environment variables. +- **Custom HTTP backend**: The `configure_http_backend` function has been removed. You can now use `set_client_factory` and `set_async_client_factory` to configure the HTTP client. +- **Error handling**: `requests.HTTPError` is no longer raised. Instead, `httpx.HTTPError` is raised. We recommend catching `HfHubHttpError` which is a subclass of `httpx.HTTPError` and will ensure your code is compatible with both old and new versions of the library. +- **SSLError**: `httpx` does not have the concept of `SSLError`. It is now a generic `httpx.ConnectError`. +- **`LocalEntryNotFoundError`**: This error no longer inherits from `HTTPError`. +- **`InferenceClient`**: The `InferenceClient` can now be used as a context manager. This is especially useful when streaming tokens from a language model to ensure that the connection is closed properly. +- **`AsyncInferenceClient`**: The `trust_env` parameter has been removed from the `AsyncInferenceClient`'s constructor. + +### Deprecated features removed + +A number of deprecated functions and parameters have been removed in v1.0. + +- `hf_cache_home` is removed. Please use `HF_HOME` instead. +- `use_auth_token` is removed. Please use `token` instead. +- `get_token_permission` is removed. +- `update_repo_visibility` is removed. Please use `update_repo_settings` instead. +- `is_write_action` parameter is removed from `build_hf_headers`. +- `write_permission` parameter is removed from `login`. +- `new_session` parameter in `login` is renamed to `skip_if_logged_in`. +- `resume_download`, `force_filename`, and `local_dir_use_symlinks` parameters are removed from `hf_hub_download` and `snapshot_download`. +- `library`, `language`, `tags`, and `task` parameters are removed from `list_models`. + +### Return value of `upload_file` and `upload_folder` + +The `upload_file` and `upload_folder` functions now return the URL of the commit created on the Hub. Previously, they returned the URL of the file or folder. + +### `Repository` class removed + +The `Repository` class has been removed in v1.0. This class was a git-based wrapper to manage repositories. The recommended way to interact with the Hub is now to use the HTTP-based functions in the `huggingface_hub` library. + +The `Repository` class was mostly a wrapper around the `git` command-line interface. You can still use `git` directly to manage your repositories. However, we recommend using the `huggingface_hub` library's HTTP-based API for a better experience, especially when dealing with large files. + +Here is a mapping from the old `Repository` methods to the new functions: + +| `Repository` method | New function | +| --- | --- | +| `repo.clone_from(...)` | `snapshot_download(...)` | +| `repo.git_add(...)` | `upload_file(...)` or `upload_folder(...)` | +| `repo.git_commit(...)` | `create_commit(...)` | +| `repo.git_push(...)` | `create_commit(...)` | +| `repo.git_pull(...)` | `snapshot_download(...)` | +| `repo.git_checkout(...)` | `snapshot_download(..., revision=...)` | +| `repo.git_tag(...)` | `create_tag(...)` | +| `repo.git_branch(...)` | `create_branch(...)` | + +### `HfFolder` and `InferenceApi` classes removed + +The `HfFolder` and `InferenceApi` classes have been removed in v1.0. + +- `HfFolder` was used to manage the Hugging Face cache directory and the user's token. It is now recommended to use the following functions instead: + - [`login`] and [`logout`] to manage the user's token. + - [`hf_hub_download`] and [`snapshot_download`] to download and cache files from the Hub. + +- `InferenceApi` was a class to interact with the Inference API. It is now recommended to use the [`InferenceClient`] class instead. + +### TensorFlow support removed + +All TensorFlow-related code and dependencies have been removed in v1.0. This includes the following breaking changes: + +- The `split_tf_state_dict_into_shards` and `get_tf_storage_size` utility functions have been removed. +- The `tensorflow`, `fastai`, and `fastcore` versions are no longer included in the built-in headers. + +### Keras 2 integration removed + +The Keras 2 integration has been removed in v1.0. This includes the `KerasModelHubMixin` class and the `save_pretrained_keras`, `from_pretrained_keras`, and `push_to_hub_keras` functions. + +Keras 3 is now tightly integrated with the Hub. You can use the [`ModelHubMixin`] to integrate your Keras 3 models with the Hub. Please refer to the [Integrate any ML framework with the Hub](./integrations.md) guide for more details. From 80314a51d8d36ba8cdd2709df139b2c627fde7b6 Mon Sep 17 00:00:00 2001 From: Lucain Pouget Date: Mon, 15 Sep 2025 14:46:19 +0200 Subject: [PATCH 15/19] rewrite migration guide --- docs/source/en/_toctree.yml | 2 + docs/source/en/concepts/migration.md | 111 ++++++++++----------------- 2 files changed, 41 insertions(+), 72 deletions(-) diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 3f930fb448..5407e0374a 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -46,6 +46,8 @@ sections: - local: concepts/git_vs_http title: Git vs HTTP paradigm + - local: concepts/migration + title: Migrating to huggingface_hub v1.0 - title: 'Reference' sections: - local: package_reference/overview diff --git a/docs/source/en/concepts/migration.md b/docs/source/en/concepts/migration.md index a6f0c32e26..d10469095f 100644 --- a/docs/source/en/concepts/migration.md +++ b/docs/source/en/concepts/migration.md @@ -1,107 +1,74 @@ # Migrating to huggingface_hub v1.0 -The `huggingface_hub` library has undergone significant changes in the v1.0 release. This guide is intended to help you migrate your existing code to the new version. +The v1.0 release is a major milestone for the `huggingface_hub` library. It marks our commitment to API stability and the maturity of the library. We have made several improvements and breaking changes to make the library more robust and easier to use. -The v1.0 release is a major milestone for the library. It marks our commitment to API stability and the maturity of the library. We have made several improvements and breaking changes to make the library more robust and easier to use. +This guide is intended to help you migrate your existing code to the new version. If you have any questions or feedback, please let us know by [opening an issue on GitHub](https://github.com/huggingface/huggingface_hub/issues). -This guide is divided into two sections: -- [Main changes](#main-changes): A list of the most important new features and improvements. -- [Breaking changes](#breaking-changes): A comprehensive list of all breaking changes and how to adapt your code. - -We hope this guide will help you to migrate to the new version of `huggingface_hub` smoothly. If you have any questions or feedback, please open an issue on the [GitHub repository](https://github.com/huggingface/huggingface_hub/issues). - -## Main changes - -### HTTPX migration - -The `huggingface_hub` library now uses [`httpx`](https://www.python-httpx.org/) instead of `requests` for HTTP requests. This change was made to improve performance and to support asynchronous requests. - -This is a major change that affects the entire library. While we have tried to make this change as transparent as possible, you may need to update your code in some cases. Please see the [Breaking changes](#breaking-changes) section for more details. - -### Python 3.9+ +## Python 3.9+ `huggingface_hub` now requires Python 3.9 or higher. Python 3.8 is no longer supported. -### Built-in generics for type annotations +## HTTPX migration -The library now uses built-in generics for type annotations (e.g. `list` instead of `typing.List`). This is a new feature in Python 3.9 and makes the code cleaner and easier to read. +The `huggingface_hub` library now uses [`httpx`](https://www.python-httpx.org/) instead of `requests` for HTTP requests. This change was made to improve performance and to support both synchronous and asynchronous requests the same way. We therefore dropped both `requests` and `aiohttp` dependencies. -## Breaking changes - -This section lists the breaking changes introduced in v1.0. - -### Python 3.8 support dropped - -`huggingface_hub` v1.0 drops support for Python 3.8. You will need to upgrade to Python 3.9 or higher to use the new version of the library. - -### HTTPX migration - -The migration to `httpx` has introduced a few breaking changes. +This is a major change that affects the entire library. While we have tried to make this change as transparent as possible, you may need to update your code in some cases. Here is a list of breaking changes introduced in the process: - **Proxy configuration**: "per method" proxies are no longer supported. Proxies must be configured globally using the `HTTP_PROXY` and `HTTPS_PROXY` environment variables. -- **Custom HTTP backend**: The `configure_http_backend` function has been removed. You can now use `set_client_factory` and `set_async_client_factory` to configure the HTTP client. -- **Error handling**: `requests.HTTPError` is no longer raised. Instead, `httpx.HTTPError` is raised. We recommend catching `HfHubHttpError` which is a subclass of `httpx.HTTPError` and will ensure your code is compatible with both old and new versions of the library. +- **Custom HTTP backend**: The `configure_http_backend` function has been removed. You should now use [`set_client_factory`] and [`set_async_client_factory`] to configure the HTTP clients. +- **Error handling**: HTTP errors are not inherited from `requests.HTTPError` anymore, but from `httpx.HTTPError`. We recommend catching `huggingface_hub.HfHubHttpError` which is a subclass of `requests.HTTPError` in v0.x and of `httpx.HTTPError` in v1.x. Catching from the `huggingface_hub` error ensures your code is compatible with both the old and new versions of the library. - **SSLError**: `httpx` does not have the concept of `SSLError`. It is now a generic `httpx.ConnectError`. -- **`LocalEntryNotFoundError`**: This error no longer inherits from `HTTPError`. +- **`LocalEntryNotFoundError`**: This error no longer inherits from `HTTPError`. We now define a `EntryNotFoundError` (new) that is inherited by both [`LocalEntryNotFoundError`] (if file not found in local cache) and [`RemoteEntryNotFoundError`] (if file not found in repo on the Hub). Only the remote error inherits from `HTTPError`. - **`InferenceClient`**: The `InferenceClient` can now be used as a context manager. This is especially useful when streaming tokens from a language model to ensure that the connection is closed properly. -- **`AsyncInferenceClient`**: The `trust_env` parameter has been removed from the `AsyncInferenceClient`'s constructor. - -### Deprecated features removed +- **`AsyncInferenceClient`**: The `trust_env` parameter has been removed from the `AsyncInferenceClient`'s constructor. Environment variables are trusted by default by `httpx`. If you explicitly don't want to trust the environment, you must configure it with [`set_client_factory`]. -A number of deprecated functions and parameters have been removed in v1.0. +For more details, you can check [PR #3328](https://github.com/huggingface/huggingface_hub/pull/3328) that introduced `httpx`. -- `hf_cache_home` is removed. Please use `HF_HOME` instead. -- `use_auth_token` is removed. Please use `token` instead. -- `get_token_permission` is removed. -- `update_repo_visibility` is removed. Please use `update_repo_settings` instead. -- `is_write_action` parameter is removed from `build_hf_headers`. -- `write_permission` parameter is removed from `login`. -- `new_session` parameter in `login` is renamed to `skip_if_logged_in`. -- `resume_download`, `force_filename`, and `local_dir_use_symlinks` parameters are removed from `hf_hub_download` and `snapshot_download`. -- `library`, `language`, `tags`, and `task` parameters are removed from `list_models`. +## `Repository` class -### Return value of `upload_file` and `upload_folder` +The `Repository` class has been removed in v1.0. It was a thin wrapper around the `git` CLI for managing repositories. You can still use `git` directly in the terminal, but the recommended approach is to use the HTTP-based API in the `huggingface_hub` library for a smoother experience, especially when dealing with large files. -The `upload_file` and `upload_folder` functions now return the URL of the commit created on the Hub. Previously, they returned the URL of the file or folder. +Here is a mapping from the legacy `Repository` class to the new `HfApi` one: -### `Repository` class removed +| `Repository` method | `HfApi` method | +| ------------------------------------------ | ----------------------------------------------------- | +| `repo.clone_from` | `snapshot_download` | +| `repo.git_add` + `git_commit` + `git_push` | [`upload_file`], [`upload_folder`], [`create_commit`] | +| `repo.git_tag` | `create_tag` | +| `repo.git_branch` | `create_branch` | -The `Repository` class has been removed in v1.0. This class was a git-based wrapper to manage repositories. The recommended way to interact with the Hub is now to use the HTTP-based functions in the `huggingface_hub` library. +## `HfFolder` class -The `Repository` class was mostly a wrapper around the `git` command-line interface. You can still use `git` directly to manage your repositories. However, we recommend using the `huggingface_hub` library's HTTP-based API for a better experience, especially when dealing with large files. +`HfFolder` was used to manage the user access token. Use [`login`] to save a new token, [`logout`] to delete it and [`whoami`] to check the user associated to the current token. Finally, use [`get_token`] to retrieve user's token in a script. -Here is a mapping from the old `Repository` methods to the new functions: -| `Repository` method | New function | -| --- | --- | -| `repo.clone_from(...)` | `snapshot_download(...)` | -| `repo.git_add(...)` | `upload_file(...)` or `upload_folder(...)` | -| `repo.git_commit(...)` | `create_commit(...)` | -| `repo.git_push(...)` | `create_commit(...)` | -| `repo.git_pull(...)` | `snapshot_download(...)` | -| `repo.git_checkout(...)` | `snapshot_download(..., revision=...)` | -| `repo.git_tag(...)` | `create_tag(...)` | -| `repo.git_branch(...)` | `create_branch(...)` | +## `InferenceApi` class -### `HfFolder` and `InferenceApi` classes removed +`InferenceApi` was a class to interact with the Inference API. It is now recommended to use the [`InferenceClient`] class instead. -The `HfFolder` and `InferenceApi` classes have been removed in v1.0. +## Other deprecated features -- `HfFolder` was used to manage the Hugging Face cache directory and the user's token. It is now recommended to use the following functions instead: - - [`login`] and [`logout`] to manage the user's token. - - [`hf_hub_download`] and [`snapshot_download`] to download and cache files from the Hub. +Some methods and parameters have been removed in v1.0. The ones listed below have already been deprecated with a warning message in v0.x. -- `InferenceApi` was a class to interact with the Inference API. It is now recommended to use the [`InferenceClient`] class instead. +- `constants.hf_cache_home` has been removed. Please use `HF_HOME` instead. +- `use_auth_token` parameters have been removed from all methods. Please use `token` instead. +- `get_token_permission` method has been removed. +- `update_repo_visibility` method has been removed. Please use `update_repo_settings` instead. +- `is_write_action` parameter has been removed from `build_hf_headers` as well as `write_permission` from `login`. The concept of "write permission" has been removed and is no longer relevant now that fine-grained tokens are the recommended approach. +- `new_session` parameter in `login` has been renamed to `skip_if_logged_in` for better clarity. +- `resume_download`, `force_filename`, and `local_dir_use_symlinks` parameters have been removed from `hf_hub_download` and `snapshot_download`. +- `library`, `language`, `tags`, and `task` parameters have been removed from `list_models`. -### TensorFlow support removed +## TensorFlow and Keras 2.x support All TensorFlow-related code and dependencies have been removed in v1.0. This includes the following breaking changes: +- `huggingface_hub[tensorflow]` is no longer a supported extra dependency - The `split_tf_state_dict_into_shards` and `get_tf_storage_size` utility functions have been removed. - The `tensorflow`, `fastai`, and `fastcore` versions are no longer included in the built-in headers. -### Keras 2 integration removed +The Keras 2.x integration has also been removed. This includes the `KerasModelHubMixin` class and the `save_pretrained_keras`, `from_pretrained_keras`, and `push_to_hub_keras` utilities. Keras 2.x is a legacy and unmaintained library. The recommended approach is to use Keras 3.x which is tightly integrated with the Hub (i.e. it contains built-in method to load/push to Hub). If you still want to work with Keras 2.x, you should downgrade `huggingface_hub` to v0.x version. -The Keras 2 integration has been removed in v1.0. This includes the `KerasModelHubMixin` class and the `save_pretrained_keras`, `from_pretrained_keras`, and `push_to_hub_keras` functions. +## `upload_file` and `upload_folder` return values -Keras 3 is now tightly integrated with the Hub. You can use the [`ModelHubMixin`] to integrate your Keras 3 models with the Hub. Please refer to the [Integrate any ML framework with the Hub](./integrations.md) guide for more details. +The [`upload_file`] and [`upload_folder`] functions now return the URL of the commit created on the Hub. Previously, they returned the URL of the file or folder. This is to align with the return value of [`create_commit`], [`delete_file`] and [`delete_folder`]. \ No newline at end of file From b22263bad3978d50f2f4ccae759b07518fa22e44 Mon Sep 17 00:00:00 2001 From: Lucain Pouget Date: Mon, 15 Sep 2025 14:47:41 +0200 Subject: [PATCH 16/19] fix import --- docs/source/en/package_reference/utilities.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/package_reference/utilities.md b/docs/source/en/package_reference/utilities.md index c0ec92ed53..d87bd6d0af 100644 --- a/docs/source/en/package_reference/utilities.md +++ b/docs/source/en/package_reference/utilities.md @@ -296,4 +296,4 @@ validated. Not exactly a validator, but ran as well. -[[autodoc]] utils.smoothly_deprecate_legacy_arguments +[[autodoc]] utils._validators.smoothly_deprecate_legacy_arguments From a10d76bb4fa5d7391e3a1998c0b39da4a8cfeb08 Mon Sep 17 00:00:00 2001 From: Lucain Pouget Date: Mon, 15 Sep 2025 14:59:40 +0200 Subject: [PATCH 17/19] fix docs? --- docs/source/ko/_toctree.yml | 2 -- 1 file changed, 2 deletions(-) diff --git a/docs/source/ko/_toctree.yml b/docs/source/ko/_toctree.yml index 0a82cd72db..e67d69af38 100644 --- a/docs/source/ko/_toctree.yml +++ b/docs/source/ko/_toctree.yml @@ -18,8 +18,6 @@ title: 명령줄 인터페이스(CLI) 사용하기 - local: guides/hf_file_system title: Hf파일시스템 - - local: guides/repository - title: 리포지토리 - local: guides/search title: Hub에서 검색하기 - local: guides/inference From 481eb49d9d86bf200a88e2ba9f48454a0fd9738f Mon Sep 17 00:00:00 2001 From: Lucain Pouget Date: Thu, 25 Sep 2025 17:29:17 +0200 Subject: [PATCH 18/19] add why httpx section --- docs/source/en/concepts/migration.md | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/docs/source/en/concepts/migration.md b/docs/source/en/concepts/migration.md index d10469095f..3af9c0ab63 100644 --- a/docs/source/en/concepts/migration.md +++ b/docs/source/en/concepts/migration.md @@ -12,6 +12,8 @@ This guide is intended to help you migrate your existing code to the new version The `huggingface_hub` library now uses [`httpx`](https://www.python-httpx.org/) instead of `requests` for HTTP requests. This change was made to improve performance and to support both synchronous and asynchronous requests the same way. We therefore dropped both `requests` and `aiohttp` dependencies. +### Breaking changes + This is a major change that affects the entire library. While we have tried to make this change as transparent as possible, you may need to update your code in some cases. Here is a list of breaking changes introduced in the process: - **Proxy configuration**: "per method" proxies are no longer supported. Proxies must be configured globally using the `HTTP_PROXY` and `HTTPS_PROXY` environment variables. @@ -24,6 +26,26 @@ This is a major change that affects the entire library. While we have tried to m For more details, you can check [PR #3328](https://github.com/huggingface/huggingface_hub/pull/3328) that introduced `httpx`. +### Why `httpx`? + +## Why `httpx`? + +The migration from `requests` to `httpx` brings several key improvements that enhance the library's performance, reliability, and maintainability: + +**Thread Safety and Connection Reuse**: `httpx` is thread-safe by design, allowing us to safely reuse the same client across multiple threads. This connection reuse reduces the overhead of establishing new connections for each HTTP request, improving performance especially when making frequent requests to the Hub. + +**HTTP/2 Support**: `httpx` provides native HTTP/2 support, which offers better efficiency when making multiple requests to the same server (exactly our use case). This translates to lower latency and reduced resource consumption compared to HTTP/1.1. + +**Unified Sync/Async API**: Unlike our previous setup with separate `requests` (sync) and `aiohttp` (async) dependencies, `httpx` provides both synchronous and asynchronous clients with identical behavior. This ensures that `InferenceClient` and `AsyncInferenceClient` have consistent functionality and eliminates subtle behavioral differences that previously existed between the two implementations. + +**Improved SSL Error Handling**: `httpx` handles SSL errors more gracefully, making debugging connection issues easier and more reliable. + +**Future-Proof Architecture**: `httpx` is actively maintained and designed for modern Python applications. In contrast, `requests` is in maintenance mode and won't receive major updates like thread-safety improvements or HTTP/2 support. + +**Better Environment Variable Handling**: `httpx` provides more consistent handling of environment variables across both sync and async contexts, eliminating previous inconsistencies where `requests` would read local environment variables by default while `aiohttp` would not. + +The transition to `httpx` positions `huggingface_hub` with a modern, efficient, and maintainable HTTP backend. While most users should experience seamless operation, the underlying improvements provide better performance and reliability for all Hub interactions. + ## `Repository` class The `Repository` class has been removed in v1.0. It was a thin wrapper around the `git` CLI for managing repositories. You can still use `git` directly in the terminal, but the recommended approach is to use the HTTP-based API in the `huggingface_hub` library for a smoother experience, especially when dealing with large files. From a56b960a331623b3b1d4053222891b2c26e7bec6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?c=C3=A9lina?= Date: Wed, 1 Oct 2025 13:27:49 +0200 Subject: [PATCH 19/19] Update docs/source/en/concepts/migration.md --- docs/source/en/concepts/migration.md | 1 - 1 file changed, 1 deletion(-) diff --git a/docs/source/en/concepts/migration.md b/docs/source/en/concepts/migration.md index 3af9c0ab63..2f60edc53d 100644 --- a/docs/source/en/concepts/migration.md +++ b/docs/source/en/concepts/migration.md @@ -28,7 +28,6 @@ For more details, you can check [PR #3328](https://github.com/huggingface/huggin ### Why `httpx`? -## Why `httpx`? The migration from `requests` to `httpx` brings several key improvements that enhance the library's performance, reliability, and maintainability: