From 1bb36a62c0d7966a1d9600af783bc973d58f7d34 Mon Sep 17 00:00:00 2001 From: "Volm, David" Date: Fri, 26 Aug 2022 12:16:08 -0500 Subject: [PATCH 01/10] Support default_factory on primary keys of schemas for create routes --- fastapi_crudrouter/core/_utils.py | 15 +++++++++++++++ fastapi_crudrouter/core/databases.py | 3 ++- fastapi_crudrouter/core/gino_starlette.py | 3 +++ fastapi_crudrouter/core/mem.py | 2 ++ fastapi_crudrouter/core/ormar.py | 2 ++ fastapi_crudrouter/core/sqlalchemy.py | 3 +++ fastapi_crudrouter/core/tortoise.py | 2 ++ 7 files changed, 29 insertions(+), 1 deletion(-) diff --git a/fastapi_crudrouter/core/_utils.py b/fastapi_crudrouter/core/_utils.py index ef3562e4..f1ced566 100644 --- a/fastapi_crudrouter/core/_utils.py +++ b/fastapi_crudrouter/core/_utils.py @@ -19,6 +19,21 @@ def get_pk_type(schema: Type[PYDANTIC_SCHEMA], pk_field: str) -> Any: return int +def create_schema_default_factory( + schema_cls: Type[T], create_schema_instance: T, pk_field_name: str = "id" +) -> T: + """ + Is used to check for default_factory for the pk on a Schema, + passing the CreateSchema values into the Schema if a + default_factory on the pk exists + """ + + if callable(schema_cls.__fields__[pk_field_name].default_factory): + return schema_cls(**create_schema_instance.dict()) + else: + return create_schema_instance + + def schema_factory( schema_cls: Type[T], pk_field_name: str = "id", name: str = "Create" ) -> Type[T]: diff --git a/fastapi_crudrouter/core/databases.py b/fastapi_crudrouter/core/databases.py index 7ea3c711..df8e46e3 100644 --- a/fastapi_crudrouter/core/databases.py +++ b/fastapi_crudrouter/core/databases.py @@ -13,7 +13,7 @@ from . import CRUDGenerator, NOT_FOUND from ._types import PAGINATION, PYDANTIC_SCHEMA, DEPENDENCIES -from ._utils import AttrDict, get_pk_type +from ._utils import AttrDict, get_pk_type, create_schema_default_factory try: from sqlalchemy.sql.schema import Table @@ -111,6 +111,7 @@ def _create(self, *args: Any, **kwargs: Any) -> CALLABLE: async def route( schema: self.create_schema, # type: ignore ) -> Model: + schema = create_schema_default_factory(schema_cls=self.schema, create_schema_instance=schema, pk_field_name=self._pk) query = self.table.insert() try: diff --git a/fastapi_crudrouter/core/gino_starlette.py b/fastapi_crudrouter/core/gino_starlette.py index d07d893e..273b4746 100644 --- a/fastapi_crudrouter/core/gino_starlette.py +++ b/fastapi_crudrouter/core/gino_starlette.py @@ -5,6 +5,7 @@ from . import NOT_FOUND, CRUDGenerator, _utils from ._types import DEPENDENCIES, PAGINATION from ._types import PYDANTIC_SCHEMA as SCHEMA +from ._utils import create_schema_default_factory try: from asyncpg.exceptions import UniqueViolationError @@ -94,6 +95,8 @@ def _create(self, *args: Any, **kwargs: Any) -> CALLABLE: async def route( model: self.create_schema, # type: ignore ) -> Model: + model = create_schema_default_factory(schema_cls=self.schema, create_schema_instance=model, pk_field_name=self._pk) + try: async with self.db.transaction(): db_model: Model = await self.db_model.create(**model.dict()) diff --git a/fastapi_crudrouter/core/mem.py b/fastapi_crudrouter/core/mem.py index d4e13c11..d75d554b 100644 --- a/fastapi_crudrouter/core/mem.py +++ b/fastapi_crudrouter/core/mem.py @@ -2,6 +2,7 @@ from . import CRUDGenerator, NOT_FOUND from ._types import DEPENDENCIES, PAGINATION, PYDANTIC_SCHEMA as SCHEMA +from ._utils import create_schema_default_factory CALLABLE = Callable[..., SCHEMA] CALLABLE_LIST = Callable[..., List[SCHEMA]] @@ -68,6 +69,7 @@ def route(item_id: int) -> SCHEMA: def _create(self, *args: Any, **kwargs: Any) -> CALLABLE: def route(model: self.create_schema) -> SCHEMA: # type: ignore + model = create_schema_default_factory(schema_cls=self.schema, create_schema_instance=model, pk_field_name=self._pk) model_dict = model.dict() model_dict["id"] = self._get_next_id() ready_model = self.schema(**model_dict) diff --git a/fastapi_crudrouter/core/ormar.py b/fastapi_crudrouter/core/ormar.py index 99952600..7e8bdf86 100644 --- a/fastapi_crudrouter/core/ormar.py +++ b/fastapi_crudrouter/core/ormar.py @@ -13,6 +13,7 @@ from . import CRUDGenerator, NOT_FOUND, _utils from ._types import DEPENDENCIES, PAGINATION +from ._utils import create_schema_default_factory try: from ormar import Model, NoMatch @@ -94,6 +95,7 @@ async def route(item_id: self._pk_type) -> Model: # type: ignore def _create(self, *args: Any, **kwargs: Any) -> CALLABLE: async def route(model: self.create_schema) -> Model: # type: ignore + model = create_schema_default_factory(schema_cls=self.schema, create_schema_instance=model, pk_field_name=self._pk) model_dict = model.dict() if self.schema.Meta.model_fields[self._pk].autoincrement: model_dict.pop(self._pk, None) diff --git a/fastapi_crudrouter/core/sqlalchemy.py b/fastapi_crudrouter/core/sqlalchemy.py index 58270f34..2342d36b 100644 --- a/fastapi_crudrouter/core/sqlalchemy.py +++ b/fastapi_crudrouter/core/sqlalchemy.py @@ -4,6 +4,7 @@ from . import CRUDGenerator, NOT_FOUND, _utils from ._types import DEPENDENCIES, PAGINATION, PYDANTIC_SCHEMA as SCHEMA +from ._utils import create_schema_default_factory try: from sqlalchemy.orm import Session @@ -102,6 +103,8 @@ def route( model: self.create_schema, # type: ignore db: Session = Depends(self.db_func), ) -> Model: + model = create_schema_default_factory(schema_cls=self.schema, create_schema_instance=model, pk_field_name=self._pk) + try: db_model: Model = self.db_model(**model.dict()) db.add(db_model) diff --git a/fastapi_crudrouter/core/tortoise.py b/fastapi_crudrouter/core/tortoise.py index 52972a48..1c044f00 100644 --- a/fastapi_crudrouter/core/tortoise.py +++ b/fastapi_crudrouter/core/tortoise.py @@ -2,6 +2,7 @@ from . import CRUDGenerator, NOT_FOUND from ._types import DEPENDENCIES, PAGINATION, PYDANTIC_SCHEMA as SCHEMA +from ._utils import create_schema_default_factory try: from tortoise.models import Model @@ -80,6 +81,7 @@ async def route(item_id: int) -> Model: def _create(self, *args: Any, **kwargs: Any) -> CALLABLE: async def route(model: self.create_schema) -> Model: # type: ignore + model = create_schema_default_factory(schema_cls=self.schema, create_schema_instance=model, pk_field_name=self._pk) db_model = self.db_model(**model.dict()) await db_model.save() From d6f9254e5bf258ca6f4c627fbe81327b0ce4daab Mon Sep 17 00:00:00 2001 From: "Volm, David" Date: Fri, 26 Aug 2022 12:39:48 -0500 Subject: [PATCH 02/10] better default_factory support for In Memory database --- fastapi_crudrouter/core/_utils.py | 8 ++++---- fastapi_crudrouter/core/databases.py | 2 +- fastapi_crudrouter/core/gino_starlette.py | 2 +- fastapi_crudrouter/core/mem.py | 11 +++++++++-- fastapi_crudrouter/core/ormar.py | 2 +- fastapi_crudrouter/core/sqlalchemy.py | 2 +- fastapi_crudrouter/core/tortoise.py | 2 +- 7 files changed, 18 insertions(+), 11 deletions(-) diff --git a/fastapi_crudrouter/core/_utils.py b/fastapi_crudrouter/core/_utils.py index f1ced566..390036b2 100644 --- a/fastapi_crudrouter/core/_utils.py +++ b/fastapi_crudrouter/core/_utils.py @@ -1,4 +1,4 @@ -from typing import Optional, Type, Any +from typing import Optional, Type, Any, Tuple from fastapi import Depends, HTTPException from pydantic import create_model @@ -21,7 +21,7 @@ def get_pk_type(schema: Type[PYDANTIC_SCHEMA], pk_field: str) -> Any: def create_schema_default_factory( schema_cls: Type[T], create_schema_instance: T, pk_field_name: str = "id" -) -> T: +) -> Tuple[T, bool]: """ Is used to check for default_factory for the pk on a Schema, passing the CreateSchema values into the Schema if a @@ -29,9 +29,9 @@ def create_schema_default_factory( """ if callable(schema_cls.__fields__[pk_field_name].default_factory): - return schema_cls(**create_schema_instance.dict()) + return schema_cls(**create_schema_instance.dict()), True else: - return create_schema_instance + return create_schema_instance, False def schema_factory( diff --git a/fastapi_crudrouter/core/databases.py b/fastapi_crudrouter/core/databases.py index df8e46e3..4042371c 100644 --- a/fastapi_crudrouter/core/databases.py +++ b/fastapi_crudrouter/core/databases.py @@ -111,7 +111,7 @@ def _create(self, *args: Any, **kwargs: Any) -> CALLABLE: async def route( schema: self.create_schema, # type: ignore ) -> Model: - schema = create_schema_default_factory(schema_cls=self.schema, create_schema_instance=schema, pk_field_name=self._pk) + schema, _ = create_schema_default_factory(schema_cls=self.schema, create_schema_instance=schema, pk_field_name=self._pk) query = self.table.insert() try: diff --git a/fastapi_crudrouter/core/gino_starlette.py b/fastapi_crudrouter/core/gino_starlette.py index 273b4746..b1526e5e 100644 --- a/fastapi_crudrouter/core/gino_starlette.py +++ b/fastapi_crudrouter/core/gino_starlette.py @@ -95,7 +95,7 @@ def _create(self, *args: Any, **kwargs: Any) -> CALLABLE: async def route( model: self.create_schema, # type: ignore ) -> Model: - model = create_schema_default_factory(schema_cls=self.schema, create_schema_instance=model, pk_field_name=self._pk) + model, _ = create_schema_default_factory(schema_cls=self.schema, create_schema_instance=model, pk_field_name=self._pk) try: async with self.db.transaction(): diff --git a/fastapi_crudrouter/core/mem.py b/fastapi_crudrouter/core/mem.py index d75d554b..d32e14d0 100644 --- a/fastapi_crudrouter/core/mem.py +++ b/fastapi_crudrouter/core/mem.py @@ -1,5 +1,7 @@ from typing import Any, Callable, List, Type, cast, Optional, Union +from fastapi import HTTPException + from . import CRUDGenerator, NOT_FOUND from ._types import DEPENDENCIES, PAGINATION, PYDANTIC_SCHEMA as SCHEMA from ._utils import create_schema_default_factory @@ -69,9 +71,14 @@ def route(item_id: int) -> SCHEMA: def _create(self, *args: Any, **kwargs: Any) -> CALLABLE: def route(model: self.create_schema) -> SCHEMA: # type: ignore - model = create_schema_default_factory(schema_cls=self.schema, create_schema_instance=model, pk_field_name=self._pk) + model, using_default_factory = create_schema_default_factory(schema_cls=self.schema, create_schema_instance=model, pk_field_name=self._pk) model_dict = model.dict() - model_dict["id"] = self._get_next_id() + if using_default_factory: + for _model in self.models: + if _model.id == model.id: # type: ignore + raise HTTPException(422, "Key already exists") from None + else: + model_dict["id"] = self._get_next_id() ready_model = self.schema(**model_dict) self.models.append(ready_model) return ready_model diff --git a/fastapi_crudrouter/core/ormar.py b/fastapi_crudrouter/core/ormar.py index 7e8bdf86..612b871c 100644 --- a/fastapi_crudrouter/core/ormar.py +++ b/fastapi_crudrouter/core/ormar.py @@ -95,7 +95,7 @@ async def route(item_id: self._pk_type) -> Model: # type: ignore def _create(self, *args: Any, **kwargs: Any) -> CALLABLE: async def route(model: self.create_schema) -> Model: # type: ignore - model = create_schema_default_factory(schema_cls=self.schema, create_schema_instance=model, pk_field_name=self._pk) + model, _ = create_schema_default_factory(schema_cls=self.schema, create_schema_instance=model, pk_field_name=self._pk) model_dict = model.dict() if self.schema.Meta.model_fields[self._pk].autoincrement: model_dict.pop(self._pk, None) diff --git a/fastapi_crudrouter/core/sqlalchemy.py b/fastapi_crudrouter/core/sqlalchemy.py index 2342d36b..547a5d6c 100644 --- a/fastapi_crudrouter/core/sqlalchemy.py +++ b/fastapi_crudrouter/core/sqlalchemy.py @@ -103,7 +103,7 @@ def route( model: self.create_schema, # type: ignore db: Session = Depends(self.db_func), ) -> Model: - model = create_schema_default_factory(schema_cls=self.schema, create_schema_instance=model, pk_field_name=self._pk) + model, _ = create_schema_default_factory(schema_cls=self.schema, create_schema_instance=model, pk_field_name=self._pk) try: db_model: Model = self.db_model(**model.dict()) diff --git a/fastapi_crudrouter/core/tortoise.py b/fastapi_crudrouter/core/tortoise.py index 1c044f00..d41a2bd2 100644 --- a/fastapi_crudrouter/core/tortoise.py +++ b/fastapi_crudrouter/core/tortoise.py @@ -81,7 +81,7 @@ async def route(item_id: int) -> Model: def _create(self, *args: Any, **kwargs: Any) -> CALLABLE: async def route(model: self.create_schema) -> Model: # type: ignore - model = create_schema_default_factory(schema_cls=self.schema, create_schema_instance=model, pk_field_name=self._pk) + model, _ = create_schema_default_factory(schema_cls=self.schema, create_schema_instance=model, pk_field_name=self._pk) db_model = self.db_model(**model.dict()) await db_model.save() From 4d3ef8f4b8286ac36025a197d43bc906834c343f Mon Sep 17 00:00:00 2001 From: "Volm, David" Date: Fri, 16 Sep 2022 06:35:00 -0500 Subject: [PATCH 03/10] Ran `black fastapi_crudrouter tests` --- fastapi_crudrouter/core/databases.py | 6 +++++- fastapi_crudrouter/core/gino_starlette.py | 6 +++++- fastapi_crudrouter/core/mem.py | 6 +++++- fastapi_crudrouter/core/ormar.py | 6 +++++- fastapi_crudrouter/core/sqlalchemy.py | 6 +++++- fastapi_crudrouter/core/tortoise.py | 6 +++++- 6 files changed, 30 insertions(+), 6 deletions(-) diff --git a/fastapi_crudrouter/core/databases.py b/fastapi_crudrouter/core/databases.py index 4042371c..777e2539 100644 --- a/fastapi_crudrouter/core/databases.py +++ b/fastapi_crudrouter/core/databases.py @@ -111,7 +111,11 @@ def _create(self, *args: Any, **kwargs: Any) -> CALLABLE: async def route( schema: self.create_schema, # type: ignore ) -> Model: - schema, _ = create_schema_default_factory(schema_cls=self.schema, create_schema_instance=schema, pk_field_name=self._pk) + schema, _ = create_schema_default_factory( + schema_cls=self.schema, + create_schema_instance=schema, + pk_field_name=self._pk, + ) query = self.table.insert() try: diff --git a/fastapi_crudrouter/core/gino_starlette.py b/fastapi_crudrouter/core/gino_starlette.py index b1526e5e..55894177 100644 --- a/fastapi_crudrouter/core/gino_starlette.py +++ b/fastapi_crudrouter/core/gino_starlette.py @@ -95,7 +95,11 @@ def _create(self, *args: Any, **kwargs: Any) -> CALLABLE: async def route( model: self.create_schema, # type: ignore ) -> Model: - model, _ = create_schema_default_factory(schema_cls=self.schema, create_schema_instance=model, pk_field_name=self._pk) + model, _ = create_schema_default_factory( + schema_cls=self.schema, + create_schema_instance=model, + pk_field_name=self._pk, + ) try: async with self.db.transaction(): diff --git a/fastapi_crudrouter/core/mem.py b/fastapi_crudrouter/core/mem.py index d32e14d0..a2d6078e 100644 --- a/fastapi_crudrouter/core/mem.py +++ b/fastapi_crudrouter/core/mem.py @@ -71,7 +71,11 @@ def route(item_id: int) -> SCHEMA: def _create(self, *args: Any, **kwargs: Any) -> CALLABLE: def route(model: self.create_schema) -> SCHEMA: # type: ignore - model, using_default_factory = create_schema_default_factory(schema_cls=self.schema, create_schema_instance=model, pk_field_name=self._pk) + model, using_default_factory = create_schema_default_factory( + schema_cls=self.schema, + create_schema_instance=model, + pk_field_name=self._pk, + ) model_dict = model.dict() if using_default_factory: for _model in self.models: diff --git a/fastapi_crudrouter/core/ormar.py b/fastapi_crudrouter/core/ormar.py index 612b871c..7a54e60b 100644 --- a/fastapi_crudrouter/core/ormar.py +++ b/fastapi_crudrouter/core/ormar.py @@ -95,7 +95,11 @@ async def route(item_id: self._pk_type) -> Model: # type: ignore def _create(self, *args: Any, **kwargs: Any) -> CALLABLE: async def route(model: self.create_schema) -> Model: # type: ignore - model, _ = create_schema_default_factory(schema_cls=self.schema, create_schema_instance=model, pk_field_name=self._pk) + model, _ = create_schema_default_factory( + schema_cls=self.schema, + create_schema_instance=model, + pk_field_name=self._pk, + ) model_dict = model.dict() if self.schema.Meta.model_fields[self._pk].autoincrement: model_dict.pop(self._pk, None) diff --git a/fastapi_crudrouter/core/sqlalchemy.py b/fastapi_crudrouter/core/sqlalchemy.py index 547a5d6c..b56069b9 100644 --- a/fastapi_crudrouter/core/sqlalchemy.py +++ b/fastapi_crudrouter/core/sqlalchemy.py @@ -103,7 +103,11 @@ def route( model: self.create_schema, # type: ignore db: Session = Depends(self.db_func), ) -> Model: - model, _ = create_schema_default_factory(schema_cls=self.schema, create_schema_instance=model, pk_field_name=self._pk) + model, _ = create_schema_default_factory( + schema_cls=self.schema, + create_schema_instance=model, + pk_field_name=self._pk, + ) try: db_model: Model = self.db_model(**model.dict()) diff --git a/fastapi_crudrouter/core/tortoise.py b/fastapi_crudrouter/core/tortoise.py index d41a2bd2..5c95194b 100644 --- a/fastapi_crudrouter/core/tortoise.py +++ b/fastapi_crudrouter/core/tortoise.py @@ -81,7 +81,11 @@ async def route(item_id: int) -> Model: def _create(self, *args: Any, **kwargs: Any) -> CALLABLE: async def route(model: self.create_schema) -> Model: # type: ignore - model, _ = create_schema_default_factory(schema_cls=self.schema, create_schema_instance=model, pk_field_name=self._pk) + model, _ = create_schema_default_factory( + schema_cls=self.schema, + create_schema_instance=model, + pk_field_name=self._pk, + ) db_model = self.db_model(**model.dict()) await db_model.save() From 8000bacadecb126a1149ff23f23073e248e96131 Mon Sep 17 00:00:00 2001 From: "Volm, David" Date: Sat, 17 Sep 2022 11:35:26 -0500 Subject: [PATCH 04/10] Added tests for default factory on all implementations --- tests/__init__.py | 14 +++++++++++++- tests/implementations/databases_.py | 17 +++++++++++++++++ tests/implementations/gino_.py | 16 ++++++++++++++++ tests/implementations/memory.py | 11 ++++++++++- tests/implementations/ormar_.py | 26 +++++++++++++++++++++++++- tests/implementations/sqlalchemy_.py | 16 ++++++++++++++++ tests/implementations/tortoise_.py | 15 +++++++++++++++ tests/test_openapi_schema.py | 5 +++-- 8 files changed, 115 insertions(+), 5 deletions(-) diff --git a/tests/__init__.py b/tests/__init__.py index 4603469d..f1763d37 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,9 +1,11 @@ -from pydantic import BaseModel +from uuid import uuid4 +from pydantic import BaseModel, Field from .conf import config PAGINATION_SIZE = 10 CUSTOM_TAGS = ["Tag1", "Tag2"] +POTATO_TAGS = ["Potato"] class ORMModel(BaseModel): @@ -24,6 +26,16 @@ class Potato(PotatoCreate, ORMModel): pass +class DefaultFactoryPotatoCreate(BaseModel): + id: str = Field(default_factory=lambda: str(uuid4())) + color: str + mass: float + + +class DefaultFactoryPotato(DefaultFactoryPotatoCreate, ORMModel): + pass + + class CustomPotato(PotatoCreate): potato_id: int diff --git a/tests/implementations/databases_.py b/tests/implementations/databases_.py index 3f1fce09..ed38dc1a 100644 --- a/tests/implementations/databases_.py +++ b/tests/implementations/databases_.py @@ -11,7 +11,9 @@ CustomPotato, PAGINATION_SIZE, Potato, + DefaultFactoryPotato, PotatoType, + POTATO_TAGS, CUSTOM_TAGS, config, ) @@ -47,6 +49,13 @@ def databases_implementation(db_uri: str): Column("color", String), Column("type", String), ) + defaultfactorypotatoes = Table( + "defaultfactorypotatoes", + metadata, + Column("id", String, primary_key=True), + Column("color", String), + Column("mass", Float), + ) carrots = Table( "carrots", metadata, @@ -74,6 +83,14 @@ async def shutdown(): prefix="potato", paginate=PAGINATION_SIZE, ), + dict( + database=database, + table=defaultfactorypotatoes, + schema=DefaultFactoryPotato, + prefix="defaultfactorypotato", + tags=POTATO_TAGS, + paginate=PAGINATION_SIZE, + ), dict( database=database, table=carrots, diff --git a/tests/implementations/gino_.py b/tests/implementations/gino_.py index c2f979aa..93c3e4c3 100644 --- a/tests/implementations/gino_.py +++ b/tests/implementations/gino_.py @@ -10,6 +10,8 @@ CarrotCreate, CarrotUpdate, CustomPotato, + DefaultFactoryPotato, + POTATO_TAGS, Potato, PotatoType, config, @@ -47,6 +49,12 @@ class PotatoModel(db.Model): color = db.Column(db.String) type = db.Column(db.String) + class DefaultFactoryPotatoModel(db.Model): + __tablename__ = "defaultfactorypotatoes" + id = db.Column(db.String, primary_key=True, index=True) + mass = db.Column(db.Float) + color = db.Column(db.String) + class CarrotModel(db.Model): __tablename__ = "carrots" id = db.Column(db.Integer, primary_key=True, index=True) @@ -63,6 +71,14 @@ class CarrotModel(db.Model): prefix="potato", paginate=PAGINATION_SIZE, ), + dict( + schema=DefaultFactoryPotato, + db_model=DefaultFactoryPotatoModel, + db=db, + prefix="defaultfactorypotato", + tags=POTATO_TAGS, + paginate=PAGINATION_SIZE, + ), dict( schema=Carrot, db_model=CarrotModel, diff --git a/tests/implementations/memory.py b/tests/implementations/memory.py index 0a53f014..5c51888e 100644 --- a/tests/implementations/memory.py +++ b/tests/implementations/memory.py @@ -1,13 +1,22 @@ from fastapi import FastAPI from fastapi_crudrouter import MemoryCRUDRouter -from tests import Potato, Carrot, CarrotUpdate, PAGINATION_SIZE, CUSTOM_TAGS +from tests import ( + Potato, + DefaultFactoryPotato, + Carrot, + CarrotUpdate, + PAGINATION_SIZE, + CUSTOM_TAGS, + POTATO_TAGS, +) def memory_implementation(**kwargs): app = FastAPI() router_settings = [ dict(schema=Potato, paginate=PAGINATION_SIZE), + dict(schema=DefaultFactoryPotato, paginate=PAGINATION_SIZE, tags=POTATO_TAGS), dict(schema=Carrot, update_schema=CarrotUpdate, tags=CUSTOM_TAGS), ] diff --git a/tests/implementations/ormar_.py b/tests/implementations/ormar_.py index 6fef389f..1a68474d 100644 --- a/tests/implementations/ormar_.py +++ b/tests/implementations/ormar_.py @@ -2,12 +2,20 @@ import databases import ormar +from pydantic import Field import pytest import sqlalchemy from fastapi import FastAPI from fastapi_crudrouter import OrmarCRUDRouter -from tests import CarrotCreate, CarrotUpdate, PAGINATION_SIZE, CUSTOM_TAGS +from tests import ( + CarrotCreate, + CarrotUpdate, + PAGINATION_SIZE, + CUSTOM_TAGS, + POTATO_TAGS, + DefaultFactoryPotatoCreate, +) DATABASE_URL = "sqlite:///./test.db" database = databases.Database(DATABASE_URL) @@ -44,6 +52,15 @@ class Meta(BaseMeta): type = ormar.String(max_length=255) +class DefaultFactoryPotato(ormar.Model): + class Meta(BaseMeta): + tablename = "defaultfactorypotatoes" + + id = ormar.String(primary_key=True, max_length=300) + mass = ormar.Float() + color = ormar.String(max_length=255) + + class CarrotModel(ormar.Model): class Meta(BaseMeta): pass @@ -114,6 +131,13 @@ def ormar_implementation(**kwargs): prefix="potato", paginate=PAGINATION_SIZE, ), + dict( + schema=DefaultFactoryPotato, + create_schema=DefaultFactoryPotatoCreate, + prefix="defaultfactorypotato", + tags=POTATO_TAGS, + paginate=PAGINATION_SIZE, + ), dict( schema=CarrotModel, update_schema=CarrotUpdate, diff --git a/tests/implementations/sqlalchemy_.py b/tests/implementations/sqlalchemy_.py index e2295ab8..219e4bbb 100644 --- a/tests/implementations/sqlalchemy_.py +++ b/tests/implementations/sqlalchemy_.py @@ -13,6 +13,8 @@ CustomPotato, PAGINATION_SIZE, Potato, + DefaultFactoryPotato, + POTATO_TAGS, PotatoType, CUSTOM_TAGS, config, @@ -59,6 +61,12 @@ class PotatoModel(Base): color = Column(String) type = Column(String) + class DefaultFactoryPotatoModel(Base): + __tablename__ = "defaultfactorypotatoes" + id = Column(String, primary_key=True, index=True) + mass = Column(Float) + color = Column(String) + class CarrotModel(Base): __tablename__ = "carrots" id = Column(Integer, primary_key=True, index=True) @@ -74,6 +82,14 @@ class CarrotModel(Base): prefix="potato", paginate=PAGINATION_SIZE, ), + dict( + schema=DefaultFactoryPotato, + db_model=DefaultFactoryPotatoModel, + db=session, + prefix="defaultfactorypotato", + tags=POTATO_TAGS, + paginate=PAGINATION_SIZE, + ), dict( schema=Carrot, db_model=CarrotModel, diff --git a/tests/implementations/tortoise_.py b/tests/implementations/tortoise_.py index 198e8054..61312fb2 100644 --- a/tests/implementations/tortoise_.py +++ b/tests/implementations/tortoise_.py @@ -10,6 +10,8 @@ CarrotUpdate, PAGINATION_SIZE, Potato, + DefaultFactoryPotato, + POTATO_TAGS, CUSTOM_TAGS, ) @@ -21,6 +23,12 @@ class PotatoModel(Model): type = fields.CharField(max_length=255) +class DefaultFactoryPotatoModel(Model): + id = fields.CharField(pk=True, index=True, max_length=255) + mass = fields.FloatField() + color = fields.CharField(max_length=255) + + class CarrotModel(Model): length = fields.FloatField() color = fields.CharField(max_length=255) @@ -60,6 +68,13 @@ def tortoise_implementation(**kwargs): prefix="potato", paginate=PAGINATION_SIZE, ), + dict( + schema=DefaultFactoryPotato, + db_model=DefaultFactoryPotatoModel, + prefix="defaultfactorypotato", + tags=POTATO_TAGS, + paginate=PAGINATION_SIZE, + ), dict( schema=Carrot, db_model=CarrotModel, diff --git a/tests/test_openapi_schema.py b/tests/test_openapi_schema.py index 61faf752..b7a30254 100644 --- a/tests/test_openapi_schema.py +++ b/tests/test_openapi_schema.py @@ -1,12 +1,13 @@ from pytest import mark -from tests import CUSTOM_TAGS +from tests import CUSTOM_TAGS, POTATO_TAGS -POTATO_TAGS = ["Potato"] PATHS = ["/potato", "/carrot"] PATH_TAGS = { "/potato": POTATO_TAGS, "/potato/{item_id}": POTATO_TAGS, + "/defaultfactorypotato": POTATO_TAGS, + "/defaultfactorypotato/{item_id}": POTATO_TAGS, "/carrot": CUSTOM_TAGS, "/carrot/{item_id}": CUSTOM_TAGS, } From a8e419152ad80a0974d3562b90024004e77e38c3 Mon Sep 17 00:00:00 2001 From: "Volm, David" Date: Sat, 17 Sep 2022 14:33:51 -0500 Subject: [PATCH 05/10] Allow other types of PKs in memory and tortoise implementations --- fastapi_crudrouter/core/_utils.py | 2 +- fastapi_crudrouter/core/mem.py | 10 ++++++---- fastapi_crudrouter/core/tortoise.py | 9 +++++---- 3 files changed, 12 insertions(+), 9 deletions(-) diff --git a/fastapi_crudrouter/core/_utils.py b/fastapi_crudrouter/core/_utils.py index 390036b2..7bcb16eb 100644 --- a/fastapi_crudrouter/core/_utils.py +++ b/fastapi_crudrouter/core/_utils.py @@ -12,7 +12,7 @@ def __init__(self, *args, **kwargs) -> None: # type: ignore self.__dict__ = self -def get_pk_type(schema: Type[PYDANTIC_SCHEMA], pk_field: str) -> Any: +def get_pk_type(schema: Type[PYDANTIC_SCHEMA], pk_field: str = "id") -> Any: try: return schema.__fields__[pk_field].type_ except KeyError: diff --git a/fastapi_crudrouter/core/mem.py b/fastapi_crudrouter/core/mem.py index a2d6078e..e7854a8a 100644 --- a/fastapi_crudrouter/core/mem.py +++ b/fastapi_crudrouter/core/mem.py @@ -4,7 +4,7 @@ from . import CRUDGenerator, NOT_FOUND from ._types import DEPENDENCIES, PAGINATION, PYDANTIC_SCHEMA as SCHEMA -from ._utils import create_schema_default_factory +from ._utils import get_pk_type, create_schema_default_factory CALLABLE = Callable[..., SCHEMA] CALLABLE_LIST = Callable[..., List[SCHEMA]] @@ -27,6 +27,8 @@ def __init__( delete_all_route: Union[bool, DEPENDENCIES] = True, **kwargs: Any ) -> None: + self._pk_type: type = get_pk_type(schema) + super().__init__( schema=schema, create_schema=create_schema, @@ -60,7 +62,7 @@ def route(pagination: PAGINATION = self.pagination) -> List[SCHEMA]: return route def _get_one(self, *args: Any, **kwargs: Any) -> CALLABLE: - def route(item_id: int) -> SCHEMA: + def route(item_id: self._pk_type) -> SCHEMA: for model in self.models: if model.id == item_id: # type: ignore return model @@ -90,7 +92,7 @@ def route(model: self.create_schema) -> SCHEMA: # type: ignore return route def _update(self, *args: Any, **kwargs: Any) -> CALLABLE: - def route(item_id: int, model: self.update_schema) -> SCHEMA: # type: ignore + def route(item_id: self._pk_type, model: self.update_schema) -> SCHEMA: # type: ignore for ind, model_ in enumerate(self.models): if model_.id == item_id: # type: ignore self.models[ind] = self.schema( @@ -110,7 +112,7 @@ def route() -> List[SCHEMA]: return route def _delete_one(self, *args: Any, **kwargs: Any) -> CALLABLE: - def route(item_id: int) -> SCHEMA: + def route(item_id: self._pk_type) -> SCHEMA: for ind, model in enumerate(self.models): if model.id == item_id: # type: ignore del self.models[ind] diff --git a/fastapi_crudrouter/core/tortoise.py b/fastapi_crudrouter/core/tortoise.py index 5c95194b..75a7970c 100644 --- a/fastapi_crudrouter/core/tortoise.py +++ b/fastapi_crudrouter/core/tortoise.py @@ -2,7 +2,7 @@ from . import CRUDGenerator, NOT_FOUND from ._types import DEPENDENCIES, PAGINATION, PYDANTIC_SCHEMA as SCHEMA -from ._utils import create_schema_default_factory +from ._utils import get_pk_type, create_schema_default_factory try: from tortoise.models import Model @@ -41,6 +41,7 @@ def __init__( self.db_model = db_model self._pk: str = db_model.describe()["pk_field"]["db_column"] + self._pk_type: type = get_pk_type(schema, self._pk) super().__init__( schema=schema, @@ -69,7 +70,7 @@ async def route(pagination: PAGINATION = self.pagination) -> List[Model]: return route def _get_one(self, *args: Any, **kwargs: Any) -> CALLABLE: - async def route(item_id: int) -> Model: + async def route(item_id: self._pk_type) -> Model: model = await self.db_model.filter(id=item_id).first() if model: @@ -95,7 +96,7 @@ async def route(model: self.create_schema) -> Model: # type: ignore def _update(self, *args: Any, **kwargs: Any) -> CALLABLE: async def route( - item_id: int, model: self.update_schema # type: ignore + item_id: self._pk_type, model: self.update_schema # type: ignore ) -> Model: await self.db_model.filter(id=item_id).update( **model.dict(exclude_unset=True) @@ -112,7 +113,7 @@ async def route() -> List[Model]: return route def _delete_one(self, *args: Any, **kwargs: Any) -> CALLABLE: - async def route(item_id: int) -> Model: + async def route(item_id: self._pk_type) -> Model: model: Model = await self._get_one()(item_id) await self.db_model.filter(id=item_id).delete() From 61d8412f01cbb7cadf96b16663d074391d9ce11d Mon Sep 17 00:00:00 2001 From: "Volm, David" Date: Sat, 17 Sep 2022 14:35:14 -0500 Subject: [PATCH 06/10] allow override of default factory schema in ormar implementation --- fastapi_crudrouter/core/ormar.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/fastapi_crudrouter/core/ormar.py b/fastapi_crudrouter/core/ormar.py index 7a54e60b..a6abf404 100644 --- a/fastapi_crudrouter/core/ormar.py +++ b/fastapi_crudrouter/core/ormar.py @@ -34,6 +34,7 @@ def __init__( schema: Type[Model], create_schema: Optional[Type[Model]] = None, update_schema: Optional[Type[Model]] = None, + default_factory_schema: Optional[Type[Model]] = None, prefix: Optional[str] = None, tags: Optional[List[str]] = None, paginate: Optional[int] = None, @@ -49,6 +50,9 @@ def __init__( self._pk: str = schema.Meta.pkname self._pk_type: type = _utils.get_pk_type(schema, self._pk) + self.default_factory_schema = ( + default_factory_schema if default_factory_schema else schema + ) super().__init__( schema=schema, @@ -96,7 +100,7 @@ async def route(item_id: self._pk_type) -> Model: # type: ignore def _create(self, *args: Any, **kwargs: Any) -> CALLABLE: async def route(model: self.create_schema) -> Model: # type: ignore model, _ = create_schema_default_factory( - schema_cls=self.schema, + schema_cls=self.default_factory_schema, create_schema_instance=model, pk_field_name=self._pk, ) From 5a11d1d45ddbc92d3471a9ab375e649dd852abdf Mon Sep 17 00:00:00 2001 From: "Volm, David" Date: Sat, 17 Sep 2022 14:36:04 -0500 Subject: [PATCH 07/10] don't have ID in default factory create model --- tests/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/__init__.py b/tests/__init__.py index f1763d37..3bc622cc 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -27,12 +27,12 @@ class Potato(PotatoCreate, ORMModel): class DefaultFactoryPotatoCreate(BaseModel): - id: str = Field(default_factory=lambda: str(uuid4())) color: str mass: float class DefaultFactoryPotato(DefaultFactoryPotatoCreate, ORMModel): + id: str = Field(default_factory=lambda: str(uuid4())) pass From b5a99feeafd787918a12582f6314700684e104b2 Mon Sep 17 00:00:00 2001 From: "Volm, David" Date: Sat, 17 Sep 2022 14:36:55 -0500 Subject: [PATCH 08/10] fix memory __main__ --- tests/implementations/memory.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/implementations/memory.py b/tests/implementations/memory.py index 5c51888e..478c55cb 100644 --- a/tests/implementations/memory.py +++ b/tests/implementations/memory.py @@ -26,4 +26,7 @@ def memory_implementation(**kwargs): if __name__ == "__main__": import uvicorn - uvicorn.run(memory_implementation(), port=5000) + app, route_type, routes = memory_implementation() + for route in routes: + app.include_router(route_type(**route)) + uvicorn.run(app, port=5000) From 139825523b7c717e75a752ec305da8aecf0f1d01 Mon Sep 17 00:00:00 2001 From: "Volm, David" Date: Sat, 17 Sep 2022 14:37:51 -0500 Subject: [PATCH 09/10] fix ormar test for default_factory --- tests/implementations/ormar_.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/implementations/ormar_.py b/tests/implementations/ormar_.py index 1a68474d..207f8084 100644 --- a/tests/implementations/ormar_.py +++ b/tests/implementations/ormar_.py @@ -14,6 +14,7 @@ PAGINATION_SIZE, CUSTOM_TAGS, POTATO_TAGS, + DefaultFactoryPotato, DefaultFactoryPotatoCreate, ) @@ -52,7 +53,7 @@ class Meta(BaseMeta): type = ormar.String(max_length=255) -class DefaultFactoryPotato(ormar.Model): +class DefaultFactoryPotatoModel(ormar.Model): class Meta(BaseMeta): tablename = "defaultfactorypotatoes" @@ -132,8 +133,10 @@ def ormar_implementation(**kwargs): paginate=PAGINATION_SIZE, ), dict( - schema=DefaultFactoryPotato, + schema=DefaultFactoryPotatoModel, + default_factory_schema=DefaultFactoryPotato, create_schema=DefaultFactoryPotatoCreate, + update_schema=DefaultFactoryPotatoCreate, prefix="defaultfactorypotato", tags=POTATO_TAGS, paginate=PAGINATION_SIZE, From 62045ca82df927d2c7cde9a74cb76d9ed4148e3c Mon Sep 17 00:00:00 2001 From: "Volm, David" Date: Sat, 17 Sep 2022 14:40:20 -0500 Subject: [PATCH 10/10] Added default_factory tests --- tests/test_default_factory.py | 38 +++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) create mode 100644 tests/test_default_factory.py diff --git a/tests/test_default_factory.py b/tests/test_default_factory.py new file mode 100644 index 00000000..80d7d32b --- /dev/null +++ b/tests/test_default_factory.py @@ -0,0 +1,38 @@ +import pytest + +from . import test_router + +basic_potato = dict(mass=1.2, color="Brown") + +PotatoUrl = "/defaultfactorypotato" + + +def test_get(client): + test_router.test_get(client, PotatoUrl) + + +def test_post(client): + test_router.test_post(client, PotatoUrl, basic_potato) + + +def test_get_one(client): + test_router.test_get_one(client, PotatoUrl, basic_potato) + + +def test_update(client): + test_router.test_update(client, PotatoUrl, basic_potato) + + +def test_delete_one(client): + test_router.test_delete_one(client, PotatoUrl, basic_potato) + + +def test_delete_all(client): + test_router.test_delete_all(client, PotatoUrl, basic_potato) + + +@pytest.mark.parametrize( + "id_", [-1, 0, 4, "14", "4802ee13-6f04-40ae-b6bc-be8e9eb6ba82-dxg"] +) +def test_not_found(client, id_): + test_router.test_not_found(client, id_, PotatoUrl, basic_potato)