diff --git a/stdlib/@tests/test_cases/check_compression.py b/stdlib/@tests/test_cases/check_compression.py new file mode 100644 index 000000000000..7fc106f125c7 --- /dev/null +++ b/stdlib/@tests/test_cases/check_compression.py @@ -0,0 +1,58 @@ +from __future__ import annotations + +import io +import sys +from _typeshed import ReadableBuffer +from bz2 import BZ2Decompressor +from lzma import LZMADecompressor +from typing import cast +from typing_extensions import assert_type +from zlib import decompressobj + +if sys.version_info >= (3, 14): + from compression._common._streams import DecompressReader, _Decompressor, _Reader + from compression.zstd import ZstdDecompressor +else: + from _compression import DecompressReader, _Decompressor, _Reader + +### +# Tests for DecompressReader/_Decompressor +### + + +class CustomDecompressor: + def decompress(self, data: ReadableBuffer, max_length: int = -1) -> bytes: + return b"" + + @property + def unused_data(self) -> bytes: + return b"" + + @property + def eof(self) -> bool: + return False + + @property + def needs_input(self) -> bool: + return False + + +def accept_decompressor(d: _Decompressor) -> None: + d.decompress(b"random bytes", 0) + assert_type(d.eof, bool) + assert_type(d.unused_data, bytes) + + +fp = cast(_Reader, io.BytesIO(b"hello world")) +DecompressReader(fp, decompressobj) +DecompressReader(fp, BZ2Decompressor) +DecompressReader(fp, LZMADecompressor) +DecompressReader(fp, CustomDecompressor) +accept_decompressor(decompressobj()) +accept_decompressor(BZ2Decompressor()) +accept_decompressor(LZMADecompressor()) +accept_decompressor(CustomDecompressor()) + +if sys.version_info >= (3, 14): + DecompressReader(fp, ZstdDecompressor) + accept_decompressor(ZstdDecompressor()) diff --git a/stdlib/_compression.pyi b/stdlib/_compression.pyi index aa67df2ab478..6015bcb13f1c 100644 --- a/stdlib/_compression.pyi +++ b/stdlib/_compression.pyi @@ -1,6 +1,6 @@ # _compression is replaced by compression._common._streams on Python 3.14+ (PEP-784) -from _typeshed import Incomplete, WriteableBuffer +from _typeshed import ReadableBuffer, WriteableBuffer from collections.abc import Callable from io import DEFAULT_BUFFER_SIZE, BufferedIOBase, RawIOBase from typing import Any, Protocol, type_check_only @@ -13,13 +13,24 @@ class _Reader(Protocol): def seekable(self) -> bool: ... def seek(self, n: int, /) -> Any: ... +@type_check_only +class _Decompressor(Protocol): + def decompress(self, data: ReadableBuffer, /, max_length: int = ...) -> bytes: ... + @property + def unused_data(self) -> bytes: ... + @property + def eof(self) -> bool: ... + # `zlib._Decompress` does not have next property, but `DecompressReader` calls it: + # @property + # def needs_input(self) -> bool: ... + class BaseStream(BufferedIOBase): ... class DecompressReader(RawIOBase): def __init__( self, fp: _Reader, - decomp_factory: Callable[..., Incomplete], + decomp_factory: Callable[..., _Decompressor], trailing_error: type[Exception] | tuple[type[Exception], ...] = (), **decomp_args: Any, # These are passed to decomp_factory. ) -> None: ... diff --git a/stdlib/compression/_common/_streams.pyi b/stdlib/compression/_common/_streams.pyi index b8463973ec67..96aec24d1c2d 100644 --- a/stdlib/compression/_common/_streams.pyi +++ b/stdlib/compression/_common/_streams.pyi @@ -1,4 +1,4 @@ -from _typeshed import Incomplete, WriteableBuffer +from _typeshed import ReadableBuffer, WriteableBuffer from collections.abc import Callable from io import DEFAULT_BUFFER_SIZE, BufferedIOBase, RawIOBase from typing import Any, Protocol, type_check_only @@ -11,13 +11,24 @@ class _Reader(Protocol): def seekable(self) -> bool: ... def seek(self, n: int, /) -> Any: ... +@type_check_only +class _Decompressor(Protocol): + def decompress(self, data: ReadableBuffer, /, max_length: int = ...) -> bytes: ... + @property + def unused_data(self) -> bytes: ... + @property + def eof(self) -> bool: ... + # `zlib._Decompress` does not have next property, but `DecompressReader` calls it: + # @property + # def needs_input(self) -> bool: ... + class BaseStream(BufferedIOBase): ... class DecompressReader(RawIOBase): def __init__( self, fp: _Reader, - decomp_factory: Callable[..., Incomplete], # Consider backporting changes to _compression + decomp_factory: Callable[..., _Decompressor], # Consider backporting changes to _compression trailing_error: type[Exception] | tuple[type[Exception], ...] = (), **decomp_args: Any, # These are passed to decomp_factory. ) -> None: ...