diff --git a/fastapi_crudrouter/core/_base.py b/fastapi_crudrouter/core/_base.py index e45d33f..2bd0fc3 100644 --- a/fastapi_crudrouter/core/_base.py +++ b/fastapi_crudrouter/core/_base.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from typing import Any, Callable, Generic, List, Optional, Type, Union -from fastapi import APIRouter, HTTPException +from fastapi import APIRouter, HTTPException, status from fastapi.types import DecoratedCallable from ._types import T, DEPENDENCIES @@ -30,6 +30,11 @@ def __init__( update_route: Union[bool, DEPENDENCIES] = True, delete_one_route: Union[bool, DEPENDENCIES] = True, delete_all_route: Union[bool, DEPENDENCIES] = True, + create_status_code: Optional[int] = status.HTTP_201_CREATED, + get_all_status_code: Optional[int] = status.HTTP_200_OK, + get_one_status_code: Optional[int] = status.HTTP_200_OK, + update_status_code: Optional[int] = status.HTTP_200_OK, + delete_status_code: Optional[int] = status.HTTP_200_OK, **kwargs: Any, ) -> None: @@ -61,6 +66,7 @@ def __init__( response_model=Optional[List[self.schema]], # type: ignore summary="Get All", dependencies=get_all_route, + status_code=get_all_status_code, ) if create_route: @@ -68,9 +74,10 @@ def __init__( "", self._create(), methods=["POST"], - response_model=self.schema, + response_model=None if create_status_code == 204 else self.schema, summary="Create One", dependencies=create_route, + status_code=create_status_code, ) if delete_all_route: @@ -92,6 +99,7 @@ def __init__( summary="Get One", dependencies=get_one_route, error_responses=[NOT_FOUND], + status_code=get_one_status_code, ) if update_route: @@ -99,10 +107,11 @@ def __init__( "/{item_id}", self._update(), methods=["PUT"], - response_model=self.schema, + response_model=None if update_status_code == 204 else self.schema, summary="Update One", dependencies=update_route, error_responses=[NOT_FOUND], + status_code=update_status_code, ) if delete_one_route: @@ -110,10 +119,11 @@ def __init__( "/{item_id}", self._delete_one(), methods=["DELETE"], - response_model=self.schema, + response_model=None if delete_status_code == 204 else self.schema, summary="Delete One", dependencies=delete_one_route, error_responses=[NOT_FOUND], + status_code=delete_status_code, ) def _add_api_route( @@ -122,6 +132,7 @@ def _add_api_route( endpoint: Callable[..., Any], dependencies: Union[bool, DEPENDENCIES], error_responses: Optional[List[HTTPException]] = None, + status_code: Optional[int] = 200, **kwargs: Any, ) -> None: dependencies = [] if isinstance(dependencies, bool) else dependencies @@ -132,7 +143,7 @@ def _add_api_route( ) super().add_api_route( - path, endpoint, dependencies=dependencies, responses=responses, **kwargs + path, endpoint, dependencies=dependencies, responses=responses, status_code=status_code, **kwargs ) def api_route( diff --git a/fastapi_crudrouter/core/sqlalchemy.py b/fastapi_crudrouter/core/sqlalchemy.py index 58270f3..0fd5b8f 100644 --- a/fastapi_crudrouter/core/sqlalchemy.py +++ b/fastapi_crudrouter/core/sqlalchemy.py @@ -1,6 +1,6 @@ from typing import Any, Callable, List, Type, Generator, Optional, Union -from fastapi import Depends, HTTPException +from fastapi import Depends, HTTPException, status from . import CRUDGenerator, NOT_FOUND, _utils from ._types import DEPENDENCIES, PAGINATION, PYDANTIC_SCHEMA as SCHEMA @@ -39,6 +39,11 @@ def __init__( update_route: Union[bool, DEPENDENCIES] = True, delete_one_route: Union[bool, DEPENDENCIES] = True, delete_all_route: Union[bool, DEPENDENCIES] = True, + create_status_code: Optional[int] = status.HTTP_201_CREATED, + get_all_status_code: Optional[int] = status.HTTP_200_OK, + get_one_status_code: Optional[int] = status.HTTP_200_OK, + update_status_code: Optional[int] = status.HTTP_200_OK, + delete_status_code: Optional[int] = status.HTTP_200_OK, **kwargs: Any ) -> None: assert ( @@ -49,6 +54,9 @@ def __init__( self.db_func = db self._pk: str = db_model.__table__.primary_key.columns.keys()[0] self._pk_type: type = _utils.get_pk_type(schema, self._pk) + self.create_status_code = create_status_code + self.update_status_code = update_status_code + self.delete_status_code = delete_status_code super().__init__( schema=schema, @@ -63,6 +71,11 @@ def __init__( update_route=update_route, delete_one_route=delete_one_route, delete_all_route=delete_all_route, + create_status_code=create_status_code, + get_all_status_code=get_all_status_code, + get_one_status_code=get_one_status_code, + update_status_code=update_status_code, + delete_status_code=delete_status_code, **kwargs ) @@ -107,7 +120,9 @@ def route( db.add(db_model) db.commit() db.refresh(db_model) - return db_model + + if not self.create_status_code == status.HTTP_204_NO_CONTENT: + return db_model except IntegrityError: db.rollback() raise HTTPException(422, "Key already exists") from None @@ -130,7 +145,8 @@ def route( db.commit() db.refresh(db_model) - return db_model + if not self.update_status_code == status.HTTP_204_NO_CONTENT: + return db_model except IntegrityError as e: db.rollback() self._raise(e) @@ -154,6 +170,7 @@ def route( db.delete(db_model) db.commit() - return db_model + if not self.delete_status_code == status.HTTP_204_NO_CONTENT: + return db_model return route