Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 0 additions & 6 deletions src/huggingface_hub/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,8 +513,6 @@
"CorruptedCacheException",
"DeleteCacheStrategy",
"HFCacheInfo",
"HfHubAsyncTransport",
"HfHubTransport",
"cached_assets_path",
"close_session",
"dump_environment_info",
Expand Down Expand Up @@ -645,8 +643,6 @@
"HfFileSystemFile",
"HfFileSystemResolvedPath",
"HfFileSystemStreamFile",
"HfHubAsyncTransport",
"HfHubTransport",
"ImageClassificationInput",
"ImageClassificationOutputElement",
"ImageClassificationOutputTransform",
Expand Down Expand Up @@ -1515,8 +1511,6 @@ def __dir__():
CorruptedCacheException, # noqa: F401
DeleteCacheStrategy, # noqa: F401
HFCacheInfo, # noqa: F401
HfHubAsyncTransport, # noqa: F401
HfHubTransport, # noqa: F401
cached_assets_path, # noqa: F401
close_session, # noqa: F401
dump_environment_info, # noqa: F401
Expand Down
2 changes: 0 additions & 2 deletions src/huggingface_hub/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,6 @@
from ._http import (
ASYNC_CLIENT_FACTORY_T,
CLIENT_FACTORY_T,
HfHubAsyncTransport,
HfHubTransport,
close_session,
fix_hf_endpoint_in_url,
get_async_session,
Expand Down
46 changes: 9 additions & 37 deletions src/huggingface_hub/utils/_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,49 +69,21 @@
)


class HfHubTransport(httpx.HTTPTransport):
def hf_request_event_hook(request: httpx.Request) -> None:
"""
Transport that will be used to make HTTP requests to the Hugging Face Hub.
Event hook that will be used to make HTTP requests to the Hugging Face Hub.

What it does:
- Block requests if offline mode is enabled
- Add a request ID to the request headers
- Log the request if debug mode is enabled
"""
if constants.HF_HUB_OFFLINE:
raise OfflineModeIsEnabled(
f"Cannot reach {request.url}: offline mode is enabled. To disable it, please unset the `HF_HUB_OFFLINE` environment variable."
)

def handle_request(self, request: httpx.Request) -> httpx.Response:
if constants.HF_HUB_OFFLINE:
raise OfflineModeIsEnabled(
f"Cannot reach {request.url}: offline mode is enabled. To disable it, please unset the `HF_HUB_OFFLINE` environment variable."
)
request_id = _add_request_id(request)
try:
return super().handle_request(request)
except httpx.RequestError as e:
if request_id is not None:
# Taken from https://stackoverflow.com/a/58270258
e.args = (*e.args, f"(Request ID: {request_id})")
raise


class HfHubAsyncTransport(httpx.AsyncHTTPTransport):
async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
if constants.HF_HUB_OFFLINE:
raise OfflineModeIsEnabled(
f"Cannot reach {request.url}: offline mode is enabled. To disable it, please unset the `HF_HUB_OFFLINE` environment variable."
)
request_id = _add_request_id(request)
try:
return await super().handle_async_request(request)
except httpx.RequestError as e:
if request_id is not None:
# Taken from https://stackoverflow.com/a/58270258
e.args = (*e.args, f"(Request ID: {request_id})")
raise


def _add_request_id(request: httpx.Request) -> Optional[str]:
# Add random request ID => easier for server-side debug
# Add random request ID => easier for server-side debugging
if X_AMZN_TRACE_ID not in request.headers:
request.headers[X_AMZN_TRACE_ID] = request.headers.get(X_REQUEST_ID) or str(uuid.uuid4())
request_id = request.headers.get(X_AMZN_TRACE_ID)
Expand All @@ -135,7 +107,7 @@ def default_client_factory() -> httpx.Client:
Factory function to create a `httpx.Client` with the default transport.
"""
return httpx.Client(
transport=HfHubTransport(),
event_hooks={"request": [hf_request_event_hook]},
follow_redirects=True,
timeout=httpx.Timeout(constants.DEFAULT_REQUEST_TIMEOUT, write=60.0),
)
Expand All @@ -146,7 +118,7 @@ def default_async_client_factory() -> httpx.AsyncClient:
Factory function to create a `httpx.AsyncClient` with the default transport.
"""
return httpx.AsyncClient(
transport=HfHubAsyncTransport(),
event_hooks={"request": [hf_request_event_hook]},
follow_redirects=True,
timeout=httpx.Timeout(constants.DEFAULT_REQUEST_TIMEOUT, write=60.0),
)
Expand Down
32 changes: 29 additions & 3 deletions tests/test_utils_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from huggingface_hub.constants import ENDPOINT
from huggingface_hub.errors import OfflineModeIsEnabled
from huggingface_hub.utils._http import (
HfHubTransport,
_adjust_range_header,
default_client_factory,
fix_hf_endpoint_in_url,
Expand Down Expand Up @@ -170,8 +169,6 @@ def test_default_configuration(self) -> None:
# Check httpx.Client default configuration
self.assertTrue(client.follow_redirects)
self.assertIsNotNone(client.timeout)
# Check that it's using the HfHubTransport
self.assertIsInstance(client._transport, HfHubTransport)

def test_set_configuration(self) -> None:
set_client_factory(self._factory)
Expand Down Expand Up @@ -332,3 +329,32 @@ def test_adjust_range_header():
_adjust_range_header("bytes=0-100", 150)
with pytest.raises(RuntimeError):
_adjust_range_header("bytes=-50", 100)


def test_proxy_env_is_used(monkeypatch):
"""Regression test for https://github.com/huggingface/transformers/issues/41301.

Test is hacky and uses httpx internal attributes, but it works.
We just need to test that proxies from env vars are used when creating the client.
"""
monkeypatch.setenv("HTTP_PROXY", "http://proxy.example1.com:8080")
monkeypatch.setenv("HTTPS_PROXY", "http://proxy.example2.com:8181")

client = get_session()
mounts = client._mounts
url_patterns = list(mounts.keys())
assert len(url_patterns) == 2 # http and https

http_url_pattern = next(url for url in url_patterns if url.pattern == "http://")
http_proxy_url = mounts[http_url_pattern]._pool._proxy_url
assert http_proxy_url.scheme == b"http"
assert http_proxy_url.host == b"proxy.example1.com"
assert http_proxy_url.port == 8080
assert http_proxy_url.target == b"/"

https_url_pattern = next(url for url in url_patterns if url.pattern == "https://")
https_proxy_url = mounts[https_url_pattern]._pool._proxy_url
assert https_proxy_url.scheme == b"http"
assert https_proxy_url.host == b"proxy.example2.com"
assert https_proxy_url.port == 8181
assert https_proxy_url.target == b"/"
Loading