diff --git a/blacksheep/contents.py b/blacksheep/contents.py index c106b085..fd380114 100644 --- a/blacksheep/contents.py +++ b/blacksheep/contents.py @@ -14,7 +14,7 @@ def __init__(self, content_type: bytes, data: bytes): self.body = data self.length = len(data) - async def read(self): + async def read(self) -> bytes: return self.body diff --git a/blacksheep/headers.py b/blacksheep/headers.py index a0403d50..0d6fc63e 100644 --- a/blacksheep/headers.py +++ b/blacksheep/headers.py @@ -20,7 +20,7 @@ def __eq__(self, other): class Headers: - def __init__(self, values: list[tuple[bytes, bytes]] = None): + def __init__(self, values: list[tuple[bytes, bytes]] | None = None): if values is None: values = [] self.values = values diff --git a/blacksheep/messages.py b/blacksheep/messages.py index 8191ef50..9a193f55 100644 --- a/blacksheep/messages.py +++ b/blacksheep/messages.py @@ -3,7 +3,7 @@ import re from datetime import timedelta from json.decoder import JSONDecodeError -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, TypeAlias from urllib.parse import parse_qs, quote, unquote, urlencode from guardpost import Identity @@ -16,6 +16,8 @@ from .contents import ( ASGIContent, Content, + FormPart, + StreamedContent, multiparts_to_dictionary, parse_www_form_urlencoded, ) @@ -30,14 +32,17 @@ _charset_rx = re.compile(rb"charset=([\w\-]+)", re.I) -def parse_charset(value: bytes): +RawHeader: TypeAlias = tuple[bytes, bytes] + + +def parse_charset(value: bytes) -> str | None: m = _charset_rx.search(value) if m: return m.group(1).decode("ascii") return None -async def _read_stream(request): +async def _read_stream(request: "Request"): async for _ in request.content.stream(): pass @@ -52,8 +57,8 @@ async def _call_soon(coro): class Message: - def __init__(self, headers): - self._raw_headers = headers or [] + def __init__(self, headers: list[RawHeader]): + self._raw_headers: list[RawHeader] = headers or [] @property def headers(self) -> Headers: @@ -63,39 +68,40 @@ def headers(self) -> Headers: self.__dict__[key] = Headers(self._raw_headers) return self.__dict__[key] - def with_content(self, content: Content): - self.content = content + def with_content(self, content: Content | StreamedContent) -> "Message": + self.content: Content | StreamedContent | None = content return self - def get_first_header(self, key: bytes): + def get_first_header(self, key: bytes) -> bytes | None: key = key.lower() for header in self._raw_headers: if header[0].lower() == key: return header[1] + return None - def get_headers(self, key: bytes): - results = [] + def get_headers(self, key: bytes) -> list[bytes]: + results: list[bytes] = [] key = key.lower() for header in self._raw_headers: if header[0].lower() == key: results.append(header[1]) return results - def init_prop(self, name: str, value): + def init_prop(self, name: str, value: Any): try: getattr(self, name) except AttributeError: setattr(self, name, value) - def get_headers_tuples(self, key: bytes): - results = [] + def get_headers_tuples(self, key: bytes) -> list[RawHeader]: + results: list[RawHeader] = [] key = key.lower() for header in self._raw_headers: if header[0].lower() == key: results.append(header) return results - def get_single_header(self, key: bytes): + def get_single_header(self, key: bytes) -> bytes: results = self.get_headers(key) if len(results) > 1: raise ValueError("Headers contains more than one header with the given key") @@ -112,18 +118,18 @@ def remove_header(self, key: bytes): for header in to_remove: self._raw_headers.remove(header) - def remove_headers(self, headers): + def remove_headers(self, headers: list[RawHeader]): for header in headers: self._raw_headers.remove(header) - def _has_header(self, key: bytes): + def _has_header(self, key: bytes) -> bool: key = key.lower() for existing_key, existing_value in self._raw_headers: if existing_key.lower() == key: return True return False - def has_header(self, key: bytes): + def has_header(self, key: bytes) -> bool: return self._has_header(key) def _add_header(self, key: bytes, value: bytes): @@ -140,24 +146,24 @@ def set_header(self, key: bytes, value: bytes): self.remove_header(key) self._raw_headers.append((key, value)) - def content_type(self): + def content_type(self) -> bytes | None: if hasattr(self, "content") and self.content and self.content.type: return self.content.type return self.get_first_header(b"content-type") - async def read(self): + async def read(self) -> bytes | None: if hasattr(self, "content") and self.content: return await self.content.read() return None async def stream(self): - if hasattr(self, "content") and self.content: + if hasattr(self, "content") and self.content and hasattr(self.content, "stream"): async for chunk in self.content.stream(): yield chunk else: yield None - async def text(self): + async def text(self) -> str: body = await self.read() if body is None: return "" @@ -175,19 +181,23 @@ async def form(self): return parse_www_form_urlencoded(text) if b"multipart/form-data;" in content_type_value: body = await self.read() + if body is None: + return None return multiparts_to_dictionary(list(parse_multipart(body))) return None - async def multipart(self): + async def multipart(self) -> list[FormPart] | None: content_type_value = self.content_type() if not content_type_value: return None if b"multipart/form-data;" in content_type_value: body = await self.read() + if body is None: + return None return list(parse_multipart(body)) return None - def declares_content_type(self, type: bytes): + def declares_content_type(self, type: bytes) -> bool: content_type = self.content_type() if not content_type: return False @@ -195,24 +205,26 @@ def declares_content_type(self, type: bytes): return True return False - def declares_json(self): + def declares_json(self) -> bool: return self.declares_content_type(b"json") - def declares_xml(self): + def declares_xml(self) -> bool: return self.declares_content_type(b"xml") - async def files(self, name=None): + async def files(self, name: str | bytes | None = None) -> list[FormPart]: if isinstance(name, str): name = name.encode("ascii") content_type = self.content_type() if not content_type or b"multipart/form-data;" not in content_type: return [] data = await self.multipart() + if data is None: + return [] if name: return [part for part in data if part.file_name and part.name == name] return [part for part in data if part.file_name] - async def json(self, loads=json_settings.loads): + async def json(self, loads=json_settings.loads) -> dict[str, Any] | list[Any] | None: if not self.declares_json(): return None text = await self.text() @@ -230,52 +242,57 @@ async def json(self, loads=json_settings.loads): ) raise BadRequestFormat("Cannot parse content as JSON", decode_error) - def has_body(self): + def has_body(self) -> bool: content = getattr(self, "content", None) if not content or content.length == 0: return False return True @property - def charset(self): + def charset(self) -> str: content_type = self.content_type() if content_type: return parse_charset(content_type) or "utf8" return "utf8" -def method_without_body(method: str): +def method_without_body(method: str) -> bool: return method in ("GET", "HEAD", "TRACE") class Request(Message): - def __init__(self, method: str, url: bytes, headers): + def __init__(self, method: str, url: bytes | None, headers: list[RawHeader]): + self._path: bytes | None + self._raw_query: bytes | None + self._is_disconnected: bool + _url = URL(url) if url else None self._raw_headers = headers or [] self.method = method self._url = _url - self._session = None + self._session: "Session | None" = None if _url: self._path = _url.path self._raw_query = _url.query else: self._path = None self._raw_query = None - self.scope = None + self.scope: dict[str, Any] = {} self.content: Content | None = None + # TODO: deprecate the 'identity' property in the future. This requires a breaking # change in guardpost, too. @property - def identity(self): + def identity(self) -> Identity: return self.user @identity.setter - def identity(self, value): + def identity(self, value: Identity): self.__dict__["_user"] = value @property - def user(self): + def user(self) -> Identity: try: return self.__dict__["_user"] except KeyError: @@ -283,7 +300,7 @@ def user(self): return self.__dict__["_user"] @user.setter - def user(self, value): + def user(self, value: Identity): self.__dict__["_user"] = value @property @@ -313,7 +330,7 @@ def host(self) -> str: return self.__dict__["host"] @host.setter - def host(self, value: str) -> None: + def host(self, value: str): self.__dict__["host"] = value @property @@ -352,7 +369,7 @@ def original_client_ip(self, value: str): self.__dict__["original_client_ip"] = value @property - def session(self): + def session(self) -> "Session": if self._session is None: raise TypeError( "A session is not configured for this request, activate " @@ -365,14 +382,14 @@ def session(self, value: "Session"): self._session = value @classmethod - def incoming(cls, method: str, path: bytes, query: bytes, headers): + def incoming(cls, method: str, path: bytes, query: bytes, headers) -> "Request": request = cls(method, None, headers) request._path = path request._raw_query = query return request @property - def query(self): + def query(self) -> dict[str, list[str]]: if self._raw_query: return parse_qs(self._raw_query.decode("utf8")) return {} @@ -384,10 +401,11 @@ def query(self, value): self.url = self.url.with_query(raw_query) @property - def url(self): + def url(self) -> URL: if self._url: return self._url - if self._raw_query: + # _raw_query and _path must be available together according to the constructor + if self._raw_query and self._path: self._url = URL(self._path + b"?" + self._raw_query) else: self._url = URL(self._path) @@ -416,11 +434,11 @@ def url(self, value): self.__dict__["host"] = None self.remove_header(b"host") - def __repr__(self): + def __repr__(self) -> str: return f"" @property - def cookies(self): + def cookies(self) -> dict[str, str]: cookies = {} cookies_headers = self.get_headers(b"cookie") if cookies_headers: @@ -437,7 +455,7 @@ def cookies(self): ) return cookies - def get_cookie(self, name: str): + def get_cookie(self, name: str) -> str | None: return self.cookies.get(name) def set_cookie(self, name: str, value: str): @@ -449,20 +467,20 @@ def set_cookie(self, name: str, value: str): self._raw_headers.append((b"cookie", new_value)) @property - def etag(self): + def etag(self) -> bytes | None: return self.get_first_header(b"etag") @property - def if_none_match(self): + def if_none_match(self) -> bytes | None: return self.get_first_header(b"if-none-match") - def expect_100_continue(self): + def expect_100_continue(self) -> bool: value = self.get_first_header(b"expect") if value and value.lower() == b"100-continue": return True return False - async def is_disconnected(self): + async def is_disconnected(self) -> bool: if not isinstance(self.content, ASGIContent): raise TypeError( "This method is only supported when a request is bound to " @@ -480,23 +498,23 @@ async def is_disconnected(self): class Response(Message): - def __init__(self, status: int, headers=None, content: Content = None): + def __init__(self, status: int, headers=None, content: Content | None = None): self._raw_headers = headers or [] self.status = status self.content = content - def __repr__(self): + def __repr__(self) -> str: return f"" @property - def cookies(self): + def cookies(self) -> dict[str, Cookie]: return self.get_cookies() @property def reason(self) -> str: return http.HTTPStatus(self.status).phrase - def get_cookies(self): + def get_cookies(self) -> dict[str, Cookie]: cookies = {} set_cookies_headers = self.get_headers(b"set-cookie") if set_cookies_headers: @@ -505,7 +523,7 @@ def get_cookies(self): cookies[cookie.name] = cookie return cookies - def get_cookie(self, name: str): + def get_cookie(self, name: str) -> Cookie | None: set_cookies_headers = self.get_headers(b"set-cookie") if set_cookies_headers: for value in set_cookies_headers: @@ -517,7 +535,7 @@ def get_cookie(self, name: str): def set_cookie(self, cookie: Cookie): self._raw_headers.append((b"set-cookie", write_cookie_for_response(cookie))) - def set_cookies(self, cookies): + def set_cookies(self, cookies: list[Cookie]): for cookie in cookies: self.set_cookie(cookie) @@ -534,7 +552,7 @@ def remove_cookie(self, name: str): to_remove.append(value) self.remove_headers(to_remove) - def is_redirect(self): + def is_redirect(self) -> bool: return self.status in {301, 302, 303, 307, 308} async def raise_for_status(self): @@ -542,18 +560,18 @@ async def raise_for_status(self): raise FailedRequestError(self.status, await self.text()) -def is_cors_request(request: "Request"): +def is_cors_request(request: Request) -> bool: return bool(request.get_first_header(b"Origin")) -def is_cors_preflight_request(request: "Request"): +def is_cors_preflight_request(request: Request) -> bool: if request.method != "OPTIONS" or not is_cors_request(request): return False next_request_method = request.get_first_header(b"Access-Control-Request-Method") return bool(next_request_method) -def ensure_bytes(value): +def ensure_bytes(value) -> bytes: if isinstance(value, str): return value.encode() if isinstance(value, bytes): @@ -561,7 +579,7 @@ def ensure_bytes(value): raise ValueError("Input value must be bytes or str") -def get_request_absolute_url(request: "Request"): +def get_request_absolute_url(request: Request) -> URL: if request.url.is_absolute: return request.url return build_absolute_url( @@ -572,7 +590,7 @@ def get_request_absolute_url(request: "Request"): ) -def get_absolute_url_to_path(request: "Request", path: str): +def get_absolute_url_to_path(request: Request, path: str) -> URL: return build_absolute_url( ensure_bytes(request.scheme), ensure_bytes(request.host), diff --git a/blacksheep/url.py b/blacksheep/url.py index 501ea267..f44e3ea4 100644 --- a/blacksheep/url.py +++ b/blacksheep/url.py @@ -112,7 +112,7 @@ def __eq__(self, other): return NotImplemented -def build_absolute_url(scheme: bytes, host: bytes, base_path: bytes, path: bytes): +def build_absolute_url(scheme: bytes, host: bytes, base_path: bytes, path: bytes) -> URL: scheme_str = scheme.decode() if isinstance(scheme, bytes) else scheme valid_schema(scheme_str) url = (