diff --git a/blacksheep/baseapp.py b/blacksheep/baseapp.py index 34dcfeeb..67b1a40c 100644 --- a/blacksheep/baseapp.py +++ b/blacksheep/baseapp.py @@ -2,6 +2,7 @@ import inspect import logging from collections import UserDict +import typing from blacksheep.server.errors import ServerErrorDetailsHandler from blacksheep.server.routing import Router @@ -17,8 +18,11 @@ ValidationError = None -class ExceptionHandlersDict(UserDict): +if typing.TYPE_CHECKING: + from .messages import Request + +class ExceptionHandlersDict(UserDict): def __setitem__(self, key, item) -> None: if not inspect.iscoroutinefunction(item): raise InvalidExceptionHandler() @@ -32,15 +36,15 @@ def __setitem__(self, key, item) -> None: return super().__setitem__(key, item) -async def handle_not_found(app, request, http_exception): +async def handle_not_found(app, request, http_exception) -> Response: return Response(404, content=TextContent("Resource not found")) -async def handle_internal_server_error(app, request, exception): +async def handle_internal_server_error(app, request, exception) -> Response: return Response(500, content=TextContent("Internal Server Error")) -async def handle_bad_request(app, request, http_exception): +async def handle_bad_request(app, request, http_exception) -> Response: if getattr(http_exception, "__context__", None) is not None and callable( getattr(http_exception.__context__, "json", None) ): @@ -53,20 +57,20 @@ async def handle_bad_request(app, request, http_exception): return Response(400, content=TextContent(f"Bad Request: {str(http_exception)}")) -async def _default_pydantic_validation_error_handler(app, request, error): +async def _default_pydantic_validation_error_handler(app, request, error) -> Response: return Response( 400, content=Content(b"application/json", error.json(indent=4).encode("utf-8")) ) -async def common_http_exception_handler(app, request, http_exception): +async def common_http_exception_handler(app, request, http_exception) -> Response: return Response( http_exception.status, content=TextContent(http.HTTPStatus(http_exception.status).phrase), ) -def get_logger(): +def get_logger() -> logging.Logger: logger = logging.getLogger("blacksheep.server") logger.setLevel(logging.INFO) return logger @@ -82,7 +86,7 @@ def __init__(self, show_error_details, router): self.logger = get_logger() self.server_error_details_handler: ServerErrorDetailsHandler - def init_exceptions_handlers(self): + def init_exceptions_handlers(self) -> ExceptionHandlersDict: default_handlers = ExceptionHandlersDict( {404: handle_not_found, 400: handle_bad_request} ) @@ -92,7 +96,11 @@ def init_exceptions_handlers(self): ) return default_handlers - async def log_unhandled_exc(self, request, exc): + async def log_unhandled_exc( + self, + request: "Request", + exc: Exception, + ): self.logger.error( 'Unhandled exception - "%s %s"', request.method, @@ -100,7 +108,11 @@ async def log_unhandled_exc(self, request, exc): exc_info=exc, ) - async def log_handled_exc(self, request, exc): + async def log_handled_exc( + self, + request: "Request", + exc: Exception, + ): if isinstance(exc, HTTPException): self.logger.info( 'HTTP %s - "%s %s". %s', @@ -117,7 +129,7 @@ async def log_handled_exc(self, request, exc): str(exc), ) - async def handle(self, request): + async def handle(self, request: "Request") -> Response: route = self.router.get_match(request) if not route: @@ -133,7 +145,11 @@ async def handle(self, request): response = await self.handle_request_handler_exception(request, exc) return response or Response(204) - async def handle_request_handler_exception(self, request, exc): + async def handle_request_handler_exception( + self, + request: "Request", + exc: Exception, + ) -> Response: if isinstance(exc, HTTPException): await self.log_handled_exc(request, exc) return await self.handle_http_exception(request, exc) @@ -143,7 +159,11 @@ async def handle_request_handler_exception(self, request, exc): await self.log_unhandled_exc(request, exc) return await self.handle_exception(request, exc) - def get_http_exception_handler(self, http_exception): + def get_http_exception_handler( + self, http_exception: HTTPException + ) -> typing.Callable[ + ["BaseApplication", "Request", Exception], typing.Awaitable[Response] + ]: handler = self.get_exception_handler(http_exception, stop_at=HTTPException) if handler: return handler @@ -151,13 +171,22 @@ def get_http_exception_handler(self, http_exception): getattr(http_exception, "status", None), common_http_exception_handler ) - def is_handled_exception(self, exception): + def is_handled_exception(self, exception) -> bool: for class_type in get_class_instance_hierarchy(exception): if class_type in self.exceptions_handlers: return True return False - def get_exception_handler(self, exception, stop_at): + def get_exception_handler( + self, + exception: Exception, + stop_at: type | None, + ) -> ( + typing.Callable[ + ["BaseApplication", "Request", Exception], typing.Awaitable[Response] + ] + | None + ): for class_type in get_class_instance_hierarchy(exception): if stop_at is not None and stop_at is class_type: return None @@ -165,7 +194,11 @@ def get_exception_handler(self, exception, stop_at): return self.exceptions_handlers[class_type] return None - async def handle_internal_server_error(self, request, exc): + async def handle_internal_server_error( + self, + request: "Request", + exc, + ) -> Response: if self.show_error_details: return self.server_error_details_handler.produce_response(request, exc) error = InternalServerError(exc) @@ -179,7 +212,14 @@ async def handle_internal_server_error(self, request, exc): ) return Response(500, content=TextContent("Internal Server Error")) - async def _apply_exception_handler(self, request, exc, exception_handler): + async def _apply_exception_handler( + self, + request: "Request", + exc: Exception, + exception_handler: typing.Callable[ + ["BaseApplication", "Request", Exception], typing.Awaitable[Response] + ], + ): try: return await exception_handler(self, request, exc) except Exception as server_ex: @@ -194,7 +234,11 @@ async def _apply_exception_handler(self, request, exc, exception_handler): return await handle_internal_server_error(self, request, server_ex) - async def handle_http_exception(self, request, http_exception): + async def handle_http_exception( + self, + request: "Request", + http_exception: HTTPException, + ) -> Response: exception_handler = self.get_http_exception_handler(http_exception) if exception_handler: return await self._apply_exception_handler( @@ -202,7 +246,7 @@ async def handle_http_exception(self, request, http_exception): ) return await self.handle_exception(request, http_exception) - async def handle_exception(self, request, exc): + async def handle_exception(self, request: "Request", exc: Exception) -> Response: exception_handler = self.get_exception_handler(exc, None) if exception_handler: return await self._apply_exception_handler(request, exc, exception_handler) diff --git a/blacksheep/utils/__init__.py b/blacksheep/utils/__init__.py index 54cd168c..c527a67a 100644 --- a/blacksheep/utils/__init__.py +++ b/blacksheep/utils/__init__.py @@ -31,11 +31,11 @@ def join_fragments(*args: AnyStr) -> str: ) -def get_class_hierarchy(cls: Type[T]): +def get_class_hierarchy(cls: Type[T]) -> tuple[Type[T], ...]: return cls.__mro__ -def get_class_instance_hierarchy(instance: T): +def get_class_instance_hierarchy(instance: T) -> tuple[Type[T], ...]: return get_class_hierarchy(type(instance))