diff --git a/fastapi_crudrouter/core/_utils.py b/fastapi_crudrouter/core/_utils.py index ef3562e4..7bcb16eb 100644 --- a/fastapi_crudrouter/core/_utils.py +++ b/fastapi_crudrouter/core/_utils.py @@ -1,4 +1,4 @@ -from typing import Optional, Type, Any +from typing import Optional, Type, Any, Tuple from fastapi import Depends, HTTPException from pydantic import create_model @@ -12,13 +12,28 @@ def __init__(self, *args, **kwargs) -> None: # type: ignore self.__dict__ = self -def get_pk_type(schema: Type[PYDANTIC_SCHEMA], pk_field: str) -> Any: +def get_pk_type(schema: Type[PYDANTIC_SCHEMA], pk_field: str = "id") -> Any: try: return schema.__fields__[pk_field].type_ except KeyError: return int +def create_schema_default_factory( + schema_cls: Type[T], create_schema_instance: T, pk_field_name: str = "id" +) -> Tuple[T, bool]: + """ + Is used to check for default_factory for the pk on a Schema, + passing the CreateSchema values into the Schema if a + default_factory on the pk exists + """ + + if callable(schema_cls.__fields__[pk_field_name].default_factory): + return schema_cls(**create_schema_instance.dict()), True + else: + return create_schema_instance, False + + def schema_factory( schema_cls: Type[T], pk_field_name: str = "id", name: str = "Create" ) -> Type[T]: diff --git a/fastapi_crudrouter/core/databases.py b/fastapi_crudrouter/core/databases.py index 7ea3c711..777e2539 100644 --- a/fastapi_crudrouter/core/databases.py +++ b/fastapi_crudrouter/core/databases.py @@ -13,7 +13,7 @@ from . import CRUDGenerator, NOT_FOUND from ._types import PAGINATION, PYDANTIC_SCHEMA, DEPENDENCIES -from ._utils import AttrDict, get_pk_type +from ._utils import AttrDict, get_pk_type, create_schema_default_factory try: from sqlalchemy.sql.schema import Table @@ -111,6 +111,11 @@ def _create(self, *args: Any, **kwargs: Any) -> CALLABLE: async def route( schema: self.create_schema, # type: ignore ) -> Model: + schema, _ = create_schema_default_factory( + schema_cls=self.schema, + create_schema_instance=schema, + pk_field_name=self._pk, + ) query = self.table.insert() try: diff --git a/fastapi_crudrouter/core/gino_starlette.py b/fastapi_crudrouter/core/gino_starlette.py index d07d893e..55894177 100644 --- a/fastapi_crudrouter/core/gino_starlette.py +++ b/fastapi_crudrouter/core/gino_starlette.py @@ -5,6 +5,7 @@ from . import NOT_FOUND, CRUDGenerator, _utils from ._types import DEPENDENCIES, PAGINATION from ._types import PYDANTIC_SCHEMA as SCHEMA +from ._utils import create_schema_default_factory try: from asyncpg.exceptions import UniqueViolationError @@ -94,6 +95,12 @@ def _create(self, *args: Any, **kwargs: Any) -> CALLABLE: async def route( model: self.create_schema, # type: ignore ) -> Model: + model, _ = create_schema_default_factory( + schema_cls=self.schema, + create_schema_instance=model, + pk_field_name=self._pk, + ) + try: async with self.db.transaction(): db_model: Model = await self.db_model.create(**model.dict()) diff --git a/fastapi_crudrouter/core/mem.py b/fastapi_crudrouter/core/mem.py index d4e13c11..e7854a8a 100644 --- a/fastapi_crudrouter/core/mem.py +++ b/fastapi_crudrouter/core/mem.py @@ -1,7 +1,10 @@ from typing import Any, Callable, List, Type, cast, Optional, Union +from fastapi import HTTPException + from . import CRUDGenerator, NOT_FOUND from ._types import DEPENDENCIES, PAGINATION, PYDANTIC_SCHEMA as SCHEMA +from ._utils import get_pk_type, create_schema_default_factory CALLABLE = Callable[..., SCHEMA] CALLABLE_LIST = Callable[..., List[SCHEMA]] @@ -24,6 +27,8 @@ def __init__( delete_all_route: Union[bool, DEPENDENCIES] = True, **kwargs: Any ) -> None: + self._pk_type: type = get_pk_type(schema) + super().__init__( schema=schema, create_schema=create_schema, @@ -57,7 +62,7 @@ def route(pagination: PAGINATION = self.pagination) -> List[SCHEMA]: return route def _get_one(self, *args: Any, **kwargs: Any) -> CALLABLE: - def route(item_id: int) -> SCHEMA: + def route(item_id: self._pk_type) -> SCHEMA: for model in self.models: if model.id == item_id: # type: ignore return model @@ -68,8 +73,18 @@ def route(item_id: int) -> SCHEMA: def _create(self, *args: Any, **kwargs: Any) -> CALLABLE: def route(model: self.create_schema) -> SCHEMA: # type: ignore + model, using_default_factory = create_schema_default_factory( + schema_cls=self.schema, + create_schema_instance=model, + pk_field_name=self._pk, + ) model_dict = model.dict() - model_dict["id"] = self._get_next_id() + if using_default_factory: + for _model in self.models: + if _model.id == model.id: # type: ignore + raise HTTPException(422, "Key already exists") from None + else: + model_dict["id"] = self._get_next_id() ready_model = self.schema(**model_dict) self.models.append(ready_model) return ready_model @@ -77,7 +92,7 @@ def route(model: self.create_schema) -> SCHEMA: # type: ignore return route def _update(self, *args: Any, **kwargs: Any) -> CALLABLE: - def route(item_id: int, model: self.update_schema) -> SCHEMA: # type: ignore + def route(item_id: self._pk_type, model: self.update_schema) -> SCHEMA: # type: ignore for ind, model_ in enumerate(self.models): if model_.id == item_id: # type: ignore self.models[ind] = self.schema( @@ -97,7 +112,7 @@ def route() -> List[SCHEMA]: return route def _delete_one(self, *args: Any, **kwargs: Any) -> CALLABLE: - def route(item_id: int) -> SCHEMA: + def route(item_id: self._pk_type) -> SCHEMA: for ind, model in enumerate(self.models): if model.id == item_id: # type: ignore del self.models[ind] diff --git a/fastapi_crudrouter/core/ormar.py b/fastapi_crudrouter/core/ormar.py index 99952600..a6abf404 100644 --- a/fastapi_crudrouter/core/ormar.py +++ b/fastapi_crudrouter/core/ormar.py @@ -13,6 +13,7 @@ from . import CRUDGenerator, NOT_FOUND, _utils from ._types import DEPENDENCIES, PAGINATION +from ._utils import create_schema_default_factory try: from ormar import Model, NoMatch @@ -33,6 +34,7 @@ def __init__( schema: Type[Model], create_schema: Optional[Type[Model]] = None, update_schema: Optional[Type[Model]] = None, + default_factory_schema: Optional[Type[Model]] = None, prefix: Optional[str] = None, tags: Optional[List[str]] = None, paginate: Optional[int] = None, @@ -48,6 +50,9 @@ def __init__( self._pk: str = schema.Meta.pkname self._pk_type: type = _utils.get_pk_type(schema, self._pk) + self.default_factory_schema = ( + default_factory_schema if default_factory_schema else schema + ) super().__init__( schema=schema, @@ -94,6 +99,11 @@ async def route(item_id: self._pk_type) -> Model: # type: ignore def _create(self, *args: Any, **kwargs: Any) -> CALLABLE: async def route(model: self.create_schema) -> Model: # type: ignore + model, _ = create_schema_default_factory( + schema_cls=self.default_factory_schema, + create_schema_instance=model, + pk_field_name=self._pk, + ) model_dict = model.dict() if self.schema.Meta.model_fields[self._pk].autoincrement: model_dict.pop(self._pk, None) diff --git a/fastapi_crudrouter/core/sqlalchemy.py b/fastapi_crudrouter/core/sqlalchemy.py index 58270f34..b56069b9 100644 --- a/fastapi_crudrouter/core/sqlalchemy.py +++ b/fastapi_crudrouter/core/sqlalchemy.py @@ -4,6 +4,7 @@ from . import CRUDGenerator, NOT_FOUND, _utils from ._types import DEPENDENCIES, PAGINATION, PYDANTIC_SCHEMA as SCHEMA +from ._utils import create_schema_default_factory try: from sqlalchemy.orm import Session @@ -102,6 +103,12 @@ def route( model: self.create_schema, # type: ignore db: Session = Depends(self.db_func), ) -> Model: + model, _ = create_schema_default_factory( + schema_cls=self.schema, + create_schema_instance=model, + pk_field_name=self._pk, + ) + try: db_model: Model = self.db_model(**model.dict()) db.add(db_model) diff --git a/fastapi_crudrouter/core/tortoise.py b/fastapi_crudrouter/core/tortoise.py index 52972a48..75a7970c 100644 --- a/fastapi_crudrouter/core/tortoise.py +++ b/fastapi_crudrouter/core/tortoise.py @@ -2,6 +2,7 @@ from . import CRUDGenerator, NOT_FOUND from ._types import DEPENDENCIES, PAGINATION, PYDANTIC_SCHEMA as SCHEMA +from ._utils import get_pk_type, create_schema_default_factory try: from tortoise.models import Model @@ -40,6 +41,7 @@ def __init__( self.db_model = db_model self._pk: str = db_model.describe()["pk_field"]["db_column"] + self._pk_type: type = get_pk_type(schema, self._pk) super().__init__( schema=schema, @@ -68,7 +70,7 @@ async def route(pagination: PAGINATION = self.pagination) -> List[Model]: return route def _get_one(self, *args: Any, **kwargs: Any) -> CALLABLE: - async def route(item_id: int) -> Model: + async def route(item_id: self._pk_type) -> Model: model = await self.db_model.filter(id=item_id).first() if model: @@ -80,6 +82,11 @@ async def route(item_id: int) -> Model: def _create(self, *args: Any, **kwargs: Any) -> CALLABLE: async def route(model: self.create_schema) -> Model: # type: ignore + model, _ = create_schema_default_factory( + schema_cls=self.schema, + create_schema_instance=model, + pk_field_name=self._pk, + ) db_model = self.db_model(**model.dict()) await db_model.save() @@ -89,7 +96,7 @@ async def route(model: self.create_schema) -> Model: # type: ignore def _update(self, *args: Any, **kwargs: Any) -> CALLABLE: async def route( - item_id: int, model: self.update_schema # type: ignore + item_id: self._pk_type, model: self.update_schema # type: ignore ) -> Model: await self.db_model.filter(id=item_id).update( **model.dict(exclude_unset=True) @@ -106,7 +113,7 @@ async def route() -> List[Model]: return route def _delete_one(self, *args: Any, **kwargs: Any) -> CALLABLE: - async def route(item_id: int) -> Model: + async def route(item_id: self._pk_type) -> Model: model: Model = await self._get_one()(item_id) await self.db_model.filter(id=item_id).delete() diff --git a/tests/__init__.py b/tests/__init__.py index 4603469d..3bc622cc 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,9 +1,11 @@ -from pydantic import BaseModel +from uuid import uuid4 +from pydantic import BaseModel, Field from .conf import config PAGINATION_SIZE = 10 CUSTOM_TAGS = ["Tag1", "Tag2"] +POTATO_TAGS = ["Potato"] class ORMModel(BaseModel): @@ -24,6 +26,16 @@ class Potato(PotatoCreate, ORMModel): pass +class DefaultFactoryPotatoCreate(BaseModel): + color: str + mass: float + + +class DefaultFactoryPotato(DefaultFactoryPotatoCreate, ORMModel): + id: str = Field(default_factory=lambda: str(uuid4())) + pass + + class CustomPotato(PotatoCreate): potato_id: int diff --git a/tests/implementations/databases_.py b/tests/implementations/databases_.py index 3f1fce09..ed38dc1a 100644 --- a/tests/implementations/databases_.py +++ b/tests/implementations/databases_.py @@ -11,7 +11,9 @@ CustomPotato, PAGINATION_SIZE, Potato, + DefaultFactoryPotato, PotatoType, + POTATO_TAGS, CUSTOM_TAGS, config, ) @@ -47,6 +49,13 @@ def databases_implementation(db_uri: str): Column("color", String), Column("type", String), ) + defaultfactorypotatoes = Table( + "defaultfactorypotatoes", + metadata, + Column("id", String, primary_key=True), + Column("color", String), + Column("mass", Float), + ) carrots = Table( "carrots", metadata, @@ -74,6 +83,14 @@ async def shutdown(): prefix="potato", paginate=PAGINATION_SIZE, ), + dict( + database=database, + table=defaultfactorypotatoes, + schema=DefaultFactoryPotato, + prefix="defaultfactorypotato", + tags=POTATO_TAGS, + paginate=PAGINATION_SIZE, + ), dict( database=database, table=carrots, diff --git a/tests/implementations/gino_.py b/tests/implementations/gino_.py index c2f979aa..93c3e4c3 100644 --- a/tests/implementations/gino_.py +++ b/tests/implementations/gino_.py @@ -10,6 +10,8 @@ CarrotCreate, CarrotUpdate, CustomPotato, + DefaultFactoryPotato, + POTATO_TAGS, Potato, PotatoType, config, @@ -47,6 +49,12 @@ class PotatoModel(db.Model): color = db.Column(db.String) type = db.Column(db.String) + class DefaultFactoryPotatoModel(db.Model): + __tablename__ = "defaultfactorypotatoes" + id = db.Column(db.String, primary_key=True, index=True) + mass = db.Column(db.Float) + color = db.Column(db.String) + class CarrotModel(db.Model): __tablename__ = "carrots" id = db.Column(db.Integer, primary_key=True, index=True) @@ -63,6 +71,14 @@ class CarrotModel(db.Model): prefix="potato", paginate=PAGINATION_SIZE, ), + dict( + schema=DefaultFactoryPotato, + db_model=DefaultFactoryPotatoModel, + db=db, + prefix="defaultfactorypotato", + tags=POTATO_TAGS, + paginate=PAGINATION_SIZE, + ), dict( schema=Carrot, db_model=CarrotModel, diff --git a/tests/implementations/memory.py b/tests/implementations/memory.py index 0a53f014..478c55cb 100644 --- a/tests/implementations/memory.py +++ b/tests/implementations/memory.py @@ -1,13 +1,22 @@ from fastapi import FastAPI from fastapi_crudrouter import MemoryCRUDRouter -from tests import Potato, Carrot, CarrotUpdate, PAGINATION_SIZE, CUSTOM_TAGS +from tests import ( + Potato, + DefaultFactoryPotato, + Carrot, + CarrotUpdate, + PAGINATION_SIZE, + CUSTOM_TAGS, + POTATO_TAGS, +) def memory_implementation(**kwargs): app = FastAPI() router_settings = [ dict(schema=Potato, paginate=PAGINATION_SIZE), + dict(schema=DefaultFactoryPotato, paginate=PAGINATION_SIZE, tags=POTATO_TAGS), dict(schema=Carrot, update_schema=CarrotUpdate, tags=CUSTOM_TAGS), ] @@ -17,4 +26,7 @@ def memory_implementation(**kwargs): if __name__ == "__main__": import uvicorn - uvicorn.run(memory_implementation(), port=5000) + app, route_type, routes = memory_implementation() + for route in routes: + app.include_router(route_type(**route)) + uvicorn.run(app, port=5000) diff --git a/tests/implementations/ormar_.py b/tests/implementations/ormar_.py index 6fef389f..207f8084 100644 --- a/tests/implementations/ormar_.py +++ b/tests/implementations/ormar_.py @@ -2,12 +2,21 @@ import databases import ormar +from pydantic import Field import pytest import sqlalchemy from fastapi import FastAPI from fastapi_crudrouter import OrmarCRUDRouter -from tests import CarrotCreate, CarrotUpdate, PAGINATION_SIZE, CUSTOM_TAGS +from tests import ( + CarrotCreate, + CarrotUpdate, + PAGINATION_SIZE, + CUSTOM_TAGS, + POTATO_TAGS, + DefaultFactoryPotato, + DefaultFactoryPotatoCreate, +) DATABASE_URL = "sqlite:///./test.db" database = databases.Database(DATABASE_URL) @@ -44,6 +53,15 @@ class Meta(BaseMeta): type = ormar.String(max_length=255) +class DefaultFactoryPotatoModel(ormar.Model): + class Meta(BaseMeta): + tablename = "defaultfactorypotatoes" + + id = ormar.String(primary_key=True, max_length=300) + mass = ormar.Float() + color = ormar.String(max_length=255) + + class CarrotModel(ormar.Model): class Meta(BaseMeta): pass @@ -114,6 +132,15 @@ def ormar_implementation(**kwargs): prefix="potato", paginate=PAGINATION_SIZE, ), + dict( + schema=DefaultFactoryPotatoModel, + default_factory_schema=DefaultFactoryPotato, + create_schema=DefaultFactoryPotatoCreate, + update_schema=DefaultFactoryPotatoCreate, + prefix="defaultfactorypotato", + tags=POTATO_TAGS, + paginate=PAGINATION_SIZE, + ), dict( schema=CarrotModel, update_schema=CarrotUpdate, diff --git a/tests/implementations/sqlalchemy_.py b/tests/implementations/sqlalchemy_.py index e2295ab8..219e4bbb 100644 --- a/tests/implementations/sqlalchemy_.py +++ b/tests/implementations/sqlalchemy_.py @@ -13,6 +13,8 @@ CustomPotato, PAGINATION_SIZE, Potato, + DefaultFactoryPotato, + POTATO_TAGS, PotatoType, CUSTOM_TAGS, config, @@ -59,6 +61,12 @@ class PotatoModel(Base): color = Column(String) type = Column(String) + class DefaultFactoryPotatoModel(Base): + __tablename__ = "defaultfactorypotatoes" + id = Column(String, primary_key=True, index=True) + mass = Column(Float) + color = Column(String) + class CarrotModel(Base): __tablename__ = "carrots" id = Column(Integer, primary_key=True, index=True) @@ -74,6 +82,14 @@ class CarrotModel(Base): prefix="potato", paginate=PAGINATION_SIZE, ), + dict( + schema=DefaultFactoryPotato, + db_model=DefaultFactoryPotatoModel, + db=session, + prefix="defaultfactorypotato", + tags=POTATO_TAGS, + paginate=PAGINATION_SIZE, + ), dict( schema=Carrot, db_model=CarrotModel, diff --git a/tests/implementations/tortoise_.py b/tests/implementations/tortoise_.py index 198e8054..61312fb2 100644 --- a/tests/implementations/tortoise_.py +++ b/tests/implementations/tortoise_.py @@ -10,6 +10,8 @@ CarrotUpdate, PAGINATION_SIZE, Potato, + DefaultFactoryPotato, + POTATO_TAGS, CUSTOM_TAGS, ) @@ -21,6 +23,12 @@ class PotatoModel(Model): type = fields.CharField(max_length=255) +class DefaultFactoryPotatoModel(Model): + id = fields.CharField(pk=True, index=True, max_length=255) + mass = fields.FloatField() + color = fields.CharField(max_length=255) + + class CarrotModel(Model): length = fields.FloatField() color = fields.CharField(max_length=255) @@ -60,6 +68,13 @@ def tortoise_implementation(**kwargs): prefix="potato", paginate=PAGINATION_SIZE, ), + dict( + schema=DefaultFactoryPotato, + db_model=DefaultFactoryPotatoModel, + prefix="defaultfactorypotato", + tags=POTATO_TAGS, + paginate=PAGINATION_SIZE, + ), dict( schema=Carrot, db_model=CarrotModel, diff --git a/tests/test_default_factory.py b/tests/test_default_factory.py new file mode 100644 index 00000000..80d7d32b --- /dev/null +++ b/tests/test_default_factory.py @@ -0,0 +1,38 @@ +import pytest + +from . import test_router + +basic_potato = dict(mass=1.2, color="Brown") + +PotatoUrl = "/defaultfactorypotato" + + +def test_get(client): + test_router.test_get(client, PotatoUrl) + + +def test_post(client): + test_router.test_post(client, PotatoUrl, basic_potato) + + +def test_get_one(client): + test_router.test_get_one(client, PotatoUrl, basic_potato) + + +def test_update(client): + test_router.test_update(client, PotatoUrl, basic_potato) + + +def test_delete_one(client): + test_router.test_delete_one(client, PotatoUrl, basic_potato) + + +def test_delete_all(client): + test_router.test_delete_all(client, PotatoUrl, basic_potato) + + +@pytest.mark.parametrize( + "id_", [-1, 0, 4, "14", "4802ee13-6f04-40ae-b6bc-be8e9eb6ba82-dxg"] +) +def test_not_found(client, id_): + test_router.test_not_found(client, id_, PotatoUrl, basic_potato) diff --git a/tests/test_openapi_schema.py b/tests/test_openapi_schema.py index 61faf752..b7a30254 100644 --- a/tests/test_openapi_schema.py +++ b/tests/test_openapi_schema.py @@ -1,12 +1,13 @@ from pytest import mark -from tests import CUSTOM_TAGS +from tests import CUSTOM_TAGS, POTATO_TAGS -POTATO_TAGS = ["Potato"] PATHS = ["/potato", "/carrot"] PATH_TAGS = { "/potato": POTATO_TAGS, "/potato/{item_id}": POTATO_TAGS, + "/defaultfactorypotato": POTATO_TAGS, + "/defaultfactorypotato/{item_id}": POTATO_TAGS, "/carrot": CUSTOM_TAGS, "/carrot/{item_id}": CUSTOM_TAGS, }