diff --git a/httpx/_config.py b/httpx/_config.py index 467a6c90ae..ef5d83e178 100644 --- a/httpx/_config.py +++ b/httpx/_config.py @@ -2,6 +2,9 @@ import os import typing +from datetime import timedelta + +from httpx._utils import opt_timedelta_to_seconds from ._models import Headers from ._types import CertTypes, HeaderTypes, TimeoutTypes @@ -87,10 +90,10 @@ def __init__( self, timeout: TimeoutTypes | UnsetType = UNSET, *, - connect: None | float | UnsetType = UNSET, - read: None | float | UnsetType = UNSET, - write: None | float | UnsetType = UNSET, - pool: None | float | UnsetType = UNSET, + connect: None | float | timedelta | UnsetType = UNSET, + read: None | float | timedelta | UnsetType = UNSET, + write: None | float | timedelta | UnsetType = UNSET, + pool: None | float | timedelta | UnsetType = UNSET, ) -> None: if isinstance(timeout, Timeout): # Passed as a single explicit Timeout. @@ -104,30 +107,47 @@ def __init__( self.pool = timeout.pool # type: typing.Optional[float] elif isinstance(timeout, tuple): # Passed as a tuple. - self.connect = timeout[0] - self.read = timeout[1] - self.write = None if len(timeout) < 3 else timeout[2] - self.pool = None if len(timeout) < 4 else timeout[3] + assert connect is UNSET + assert read is UNSET + assert write is UNSET + assert pool is UNSET + self.connect = opt_timedelta_to_seconds(timeout[0]) + self.read = opt_timedelta_to_seconds(timeout[1]) + self.write = opt_timedelta_to_seconds( + None if len(timeout) < 3 else timeout[2] + ) + self.pool = opt_timedelta_to_seconds( + None if len(timeout) < 4 else timeout[3] + ) elif not ( isinstance(connect, UnsetType) or isinstance(read, UnsetType) or isinstance(write, UnsetType) or isinstance(pool, UnsetType) ): - self.connect = connect - self.read = read - self.write = write - self.pool = pool + self.connect = opt_timedelta_to_seconds(connect) + self.read = opt_timedelta_to_seconds(read) + self.write = opt_timedelta_to_seconds(write) + self.pool = opt_timedelta_to_seconds(pool) else: if isinstance(timeout, UnsetType): raise ValueError( "httpx.Timeout must either include a default, or set all " "four parameters explicitly." ) - self.connect = timeout if isinstance(connect, UnsetType) else connect - self.read = timeout if isinstance(read, UnsetType) else read - self.write = timeout if isinstance(write, UnsetType) else write - self.pool = timeout if isinstance(pool, UnsetType) else pool + + self.connect = opt_timedelta_to_seconds( + timeout if isinstance(connect, UnsetType) else connect + ) + self.read = opt_timedelta_to_seconds( + timeout if isinstance(read, UnsetType) else read + ) + self.write = opt_timedelta_to_seconds( + timeout if isinstance(write, UnsetType) else write + ) + self.pool = opt_timedelta_to_seconds( + timeout if isinstance(pool, UnsetType) else pool + ) def as_dict(self) -> dict[str, float | None]: return { diff --git a/httpx/_types.py b/httpx/_types.py index 704dfdffc8..0e5f5ef5c9 100644 --- a/httpx/_types.py +++ b/httpx/_types.py @@ -2,6 +2,7 @@ Type definitions for type checking purposes. """ +from datetime import timedelta from http.cookiejar import CookieJar from typing import ( IO, @@ -52,8 +53,13 @@ CookieTypes = Union["Cookies", CookieJar, Dict[str, str], List[Tuple[str, str]]] TimeoutTypes = Union[ - Optional[float], - Tuple[Optional[float], Optional[float], Optional[float], Optional[float]], + Optional[Union[float, timedelta]], + Tuple[ + Optional[Union[float, timedelta]], + Optional[Union[float, timedelta]], + Optional[Union[float, timedelta]], + Optional[Union[float, timedelta]], + ], "Timeout", ] ProxyTypes = Union["URL", str, "Proxy"] diff --git a/httpx/_utils.py b/httpx/_utils.py index 7fe827da4d..2a8dc2fc85 100644 --- a/httpx/_utils.py +++ b/httpx/_utils.py @@ -4,6 +4,7 @@ import os import re import typing +from datetime import timedelta from urllib.request import getproxies from ._types import PrimitiveData @@ -12,6 +13,14 @@ from ._urls import URL +def opt_timedelta_to_seconds( + value: typing.Union[float, timedelta, None], +) -> float | None: + if isinstance(value, timedelta): + return value.total_seconds() + return value + + def primitive_value_to_str(value: PrimitiveData) -> str: """ Coerce a primitive data type into a string value. diff --git a/tests/test_config.py b/tests/test_config.py index 22abd4c22c..abb4ee7a0f 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,5 +1,6 @@ import ssl import typing +from datetime import timedelta from pathlib import Path import certifi @@ -105,11 +106,26 @@ def test_timeout_eq(): assert timeout == httpx.Timeout(timeout=5.0) +def test_timeout_timedelta_eq(): + timeout = httpx.Timeout(timeout=timedelta(seconds=5.0)) + assert timeout == httpx.Timeout(timeout=5.0) + + def test_timeout_all_parameters_set(): timeout = httpx.Timeout(connect=5.0, read=5.0, write=5.0, pool=5.0) assert timeout == httpx.Timeout(timeout=5.0) +def test_timeout_all_parameters_timedelta_set(): + timeout = httpx.Timeout( + connect=timedelta(seconds=5.0), + read=timedelta(seconds=5.0), + write=timedelta(seconds=5.0), + pool=timedelta(seconds=5.0), + ) + assert timeout == httpx.Timeout(timeout=5.0) + + def test_timeout_from_nothing(): timeout = httpx.Timeout(None) assert timeout.connect is None @@ -133,11 +149,21 @@ def test_timeout_from_one_value(): assert timeout == httpx.Timeout(timeout=(None, 5.0, None, None)) +def test_timeout_from_one_timedelta_value(): + timeout = httpx.Timeout(None, read=timedelta(seconds=5.0)) + assert timeout == httpx.Timeout(timeout=(None, 5.0, None, None)) + + def test_timeout_from_one_value_and_default(): timeout = httpx.Timeout(5.0, pool=60.0) assert timeout == httpx.Timeout(timeout=(5.0, 5.0, 5.0, 60.0)) +def test_timeout_from_one_value_and_default_timedelta(): + timeout = httpx.Timeout(timedelta(seconds=5.0), pool=timedelta(seconds=60.0)) + assert timeout == httpx.Timeout(timeout=(5.0, 5.0, 5.0, 60.0)) + + def test_timeout_missing_default(): with pytest.raises(ValueError): httpx.Timeout(pool=60.0) @@ -148,6 +174,18 @@ def test_timeout_from_tuple(): assert timeout == httpx.Timeout(timeout=5.0) +def test_timeout_from_timedelta_tuple(): + timeout = httpx.Timeout( + timeout=( + timedelta(seconds=5.0), + timedelta(seconds=5.0), + timedelta(seconds=5.0), + timedelta(seconds=5.0), + ) + ) + assert timeout == httpx.Timeout(timeout=5.0) + + def test_timeout_from_config_instance(): timeout = httpx.Timeout(timeout=5.0) assert httpx.Timeout(timeout) == httpx.Timeout(timeout=5.0)