Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 63 additions & 19 deletions blacksheep/baseapp.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import inspect
import logging
from collections import UserDict
import typing

from blacksheep.server.errors import ServerErrorDetailsHandler
from blacksheep.server.routing import Router
Expand All @@ -17,8 +18,11 @@
ValidationError = None


class ExceptionHandlersDict(UserDict):
if typing.TYPE_CHECKING:
from .messages import Request


class ExceptionHandlersDict(UserDict):
def __setitem__(self, key, item) -> None:
if not inspect.iscoroutinefunction(item):
raise InvalidExceptionHandler()
Expand All @@ -32,15 +36,15 @@ def __setitem__(self, key, item) -> None:
return super().__setitem__(key, item)


async def handle_not_found(app, request, http_exception):
async def handle_not_found(app, request, http_exception) -> Response:
return Response(404, content=TextContent("Resource not found"))


async def handle_internal_server_error(app, request, exception):
async def handle_internal_server_error(app, request, exception) -> Response:
return Response(500, content=TextContent("Internal Server Error"))


async def handle_bad_request(app, request, http_exception):
async def handle_bad_request(app, request, http_exception) -> Response:
if getattr(http_exception, "__context__", None) is not None and callable(
getattr(http_exception.__context__, "json", None)
):
Expand All @@ -53,20 +57,20 @@ async def handle_bad_request(app, request, http_exception):
return Response(400, content=TextContent(f"Bad Request: {str(http_exception)}"))


async def _default_pydantic_validation_error_handler(app, request, error):
async def _default_pydantic_validation_error_handler(app, request, error) -> Response:
return Response(
400, content=Content(b"application/json", error.json(indent=4).encode("utf-8"))
)


async def common_http_exception_handler(app, request, http_exception):
async def common_http_exception_handler(app, request, http_exception) -> Response:
return Response(
http_exception.status,
content=TextContent(http.HTTPStatus(http_exception.status).phrase),
)


def get_logger():
def get_logger() -> logging.Logger:
logger = logging.getLogger("blacksheep.server")
logger.setLevel(logging.INFO)
return logger
Expand All @@ -82,7 +86,7 @@ def __init__(self, show_error_details, router):
self.logger = get_logger()
self.server_error_details_handler: ServerErrorDetailsHandler

def init_exceptions_handlers(self):
def init_exceptions_handlers(self) -> ExceptionHandlersDict:
default_handlers = ExceptionHandlersDict(
{404: handle_not_found, 400: handle_bad_request}
)
Expand All @@ -92,15 +96,23 @@ def init_exceptions_handlers(self):
)
return default_handlers

async def log_unhandled_exc(self, request, exc):
async def log_unhandled_exc(
self,
request: "Request",
exc: Exception,
):
self.logger.error(
'Unhandled exception - "%s %s"',
request.method,
request.url.value.decode(),
exc_info=exc,
)

async def log_handled_exc(self, request, exc):
async def log_handled_exc(
self,
request: "Request",
exc: Exception,
):
if isinstance(exc, HTTPException):
self.logger.info(
'HTTP %s - "%s %s". %s',
Expand All @@ -117,7 +129,7 @@ async def log_handled_exc(self, request, exc):
str(exc),
)

async def handle(self, request):
async def handle(self, request: "Request") -> Response:
route = self.router.get_match(request)

if not route:
Expand All @@ -133,7 +145,11 @@ async def handle(self, request):
response = await self.handle_request_handler_exception(request, exc)
return response or Response(204)

async def handle_request_handler_exception(self, request, exc):
async def handle_request_handler_exception(
self,
request: "Request",
exc: Exception,
) -> Response:
if isinstance(exc, HTTPException):
await self.log_handled_exc(request, exc)
return await self.handle_http_exception(request, exc)
Expand All @@ -143,29 +159,46 @@ async def handle_request_handler_exception(self, request, exc):
await self.log_unhandled_exc(request, exc)
return await self.handle_exception(request, exc)

def get_http_exception_handler(self, http_exception):
def get_http_exception_handler(
self, http_exception: HTTPException
) -> typing.Callable[
["BaseApplication", "Request", Exception], typing.Awaitable[Response]
]:
handler = self.get_exception_handler(http_exception, stop_at=HTTPException)
if handler:
return handler
return self.exceptions_handlers.get(
getattr(http_exception, "status", None), common_http_exception_handler
)

def is_handled_exception(self, exception):
def is_handled_exception(self, exception) -> bool:
for class_type in get_class_instance_hierarchy(exception):
if class_type in self.exceptions_handlers:
return True
return False

def get_exception_handler(self, exception, stop_at):
def get_exception_handler(
self,
exception: Exception,
stop_at: type | None,
) -> (
typing.Callable[
["BaseApplication", "Request", Exception], typing.Awaitable[Response]
]
| None
):
for class_type in get_class_instance_hierarchy(exception):
if stop_at is not None and stop_at is class_type:
return None
if class_type in self.exceptions_handlers:
return self.exceptions_handlers[class_type]
return None

async def handle_internal_server_error(self, request, exc):
async def handle_internal_server_error(
self,
request: "Request",
exc,
) -> Response:
if self.show_error_details:
return self.server_error_details_handler.produce_response(request, exc)
error = InternalServerError(exc)
Expand All @@ -179,7 +212,14 @@ async def handle_internal_server_error(self, request, exc):
)
return Response(500, content=TextContent("Internal Server Error"))

async def _apply_exception_handler(self, request, exc, exception_handler):
async def _apply_exception_handler(
self,
request: "Request",
exc: Exception,
exception_handler: typing.Callable[
["BaseApplication", "Request", Exception], typing.Awaitable[Response]
],
):
try:
return await exception_handler(self, request, exc)
except Exception as server_ex:
Expand All @@ -194,15 +234,19 @@ async def _apply_exception_handler(self, request, exc, exception_handler):

return await handle_internal_server_error(self, request, server_ex)

async def handle_http_exception(self, request, http_exception):
async def handle_http_exception(
self,
request: "Request",
http_exception: HTTPException,
) -> Response:
exception_handler = self.get_http_exception_handler(http_exception)
if exception_handler:
return await self._apply_exception_handler(
request, http_exception, exception_handler
)
return await self.handle_exception(request, http_exception)

async def handle_exception(self, request, exc):
async def handle_exception(self, request: "Request", exc: Exception) -> Response:
exception_handler = self.get_exception_handler(exc, None)
if exception_handler:
return await self._apply_exception_handler(request, exc, exception_handler)
Expand Down
4 changes: 2 additions & 2 deletions blacksheep/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,11 @@ def join_fragments(*args: AnyStr) -> str:
)


def get_class_hierarchy(cls: Type[T]):
def get_class_hierarchy(cls: Type[T]) -> tuple[Type[T], ...]:
return cls.__mro__


def get_class_instance_hierarchy(instance: T):
def get_class_instance_hierarchy(instance: T) -> tuple[Type[T], ...]:
return get_class_hierarchy(type(instance))


Expand Down
Loading