diff --git a/src/huggingface_hub/utils/_http.py b/src/huggingface_hub/utils/_http.py index f98dfda517..f5533316f8 100644 --- a/src/huggingface_hub/utils/_http.py +++ b/src/huggingface_hub/utils/_http.py @@ -102,6 +102,13 @@ def hf_request_event_hook(request: httpx.Request) -> None: return request_id +async def async_hf_request_event_hook(request: httpx.Request) -> None: + """ + Async version of `hf_request_event_hook`. + """ + return hf_request_event_hook(request) + + def default_client_factory() -> httpx.Client: """ Factory function to create a `httpx.Client` with the default transport. @@ -118,7 +125,7 @@ def default_async_client_factory() -> httpx.AsyncClient: Factory function to create a `httpx.AsyncClient` with the default transport. """ return httpx.AsyncClient( - event_hooks={"request": [hf_request_event_hook]}, + event_hooks={"request": [async_hf_request_event_hook]}, follow_redirects=True, timeout=httpx.Timeout(constants.DEFAULT_REQUEST_TIMEOUT, write=60.0), ) diff --git a/tests/test_utils_http.py b/tests/test_utils_http.py index cfce921497..0b20f78c75 100644 --- a/tests/test_utils_http.py +++ b/tests/test_utils_http.py @@ -17,6 +17,7 @@ _adjust_range_header, default_client_factory, fix_hf_endpoint_in_url, + get_async_session, get_session, http_backoff, set_client_factory, @@ -362,3 +363,18 @@ def test_proxy_env_is_used(monkeypatch): # Reset set_client_factory(default_client_factory) + + +def test_client_get_request(): + # Check that sync client works + client = get_session() + response = client.get("https://huggingface.co") + assert response.status_code == 200 + + +@pytest.mark.asyncio +async def test_async_client_get_request(): + # Check that async client works + client = get_async_session() + response = await client.get("https://huggingface.co") + assert response.status_code == 200