From 18a9aaafaa5a0c653814a045045114345876670d Mon Sep 17 00:00:00 2001 From: cycledriver Date: Fri, 14 Oct 2022 07:43:46 -0400 Subject: [PATCH 1/2] [sqlalchemy] allow create/update with object for one/many 2 many For one-2-many and many-2-many relationships, allow the create and update routes to accept a partial object in the foreign key attribute. For example: client.post("/heros", json={ "name": Bob, "team": {"name": "Avengers"} } Assuming there is already a team called Avengers, Bob will be created, the Team with name "Avengers" will be looked up and used to populate Bob's team_id foreign key attribute. The only setup required is for the input model for the foreign object to specify the Table class that can be used to lookup the object. For example: class Team(Base): """Team DTO.""" __tablename__ = "teams" id = Column(Integer, primary_key=True, index=True) name = Column(String, unique=True) class TeamUpdate(Model): name: str class Meta: orm_model = Team --- fastapi_crudrouter/core/sqlalchemy.py | 54 ++++- tests/test_sqlalchemy_nested_obj.py | 314 ++++++++++++++++++++++++++ 2 files changed, 362 insertions(+), 6 deletions(-) create mode 100644 tests/test_sqlalchemy_nested_obj.py diff --git a/fastapi_crudrouter/core/sqlalchemy.py b/fastapi_crudrouter/core/sqlalchemy.py index 58270f34..b7f08b8f 100644 --- a/fastapi_crudrouter/core/sqlalchemy.py +++ b/fastapi_crudrouter/core/sqlalchemy.py @@ -1,4 +1,5 @@ from typing import Any, Callable, List, Type, Generator, Optional, Union +from collections.abc import Sequence from fastapi import Depends, HTTPException @@ -9,10 +10,12 @@ from sqlalchemy.orm import Session from sqlalchemy.ext.declarative import DeclarativeMeta as Model from sqlalchemy.exc import IntegrityError + from sqlalchemy import column except ImportError: Model = None Session = None IntegrityError = None + column = None sqlalchemy_installed = False else: sqlalchemy_installed = True @@ -39,7 +42,7 @@ def __init__( update_route: Union[bool, DEPENDENCIES] = True, delete_one_route: Union[bool, DEPENDENCIES] = True, delete_all_route: Union[bool, DEPENDENCIES] = True, - **kwargs: Any + **kwargs: Any, ) -> None: assert ( sqlalchemy_installed @@ -63,7 +66,7 @@ def __init__( update_route=update_route, delete_one_route=delete_one_route, delete_all_route=delete_all_route, - **kwargs + **kwargs, ) def _get_all(self, *args: Any, **kwargs: Any) -> CALLABLE_LIST: @@ -97,13 +100,49 @@ def route( return route + def _get_orm_object(self, db: Session, orm_model: Model, model: Model) -> Any: + query = db.query(orm_model) + filter_items = 0 + for key, val in model.dict().items(): + if val: + filter_items += 1 + query = query.filter(column(key) == val) + if filter_items == 0: + raise Exception("No attributes for filter found") + return query.one() + + def _get_orm_object_or_value(self, db: Session, val: Any) -> Any: + """Return an inflated database object or a plain value. + + If a `val` is a SqlModel type and has defined a Meta.orm model + attribute, lookup the object from the `db` and return it. + Otherwise, just return the `val`. If `val` is a sequence of + objects, return the sequence of objects from the db. + """ + # we want to iterate through sequences but not strings + if not val or isinstance(val, str): + return val + + if isinstance(val, Sequence): + return [self._get_orm_object_or_value(db, v) for v in val] + else: + if meta_class := getattr(val, "Meta", None): + if orm_model := getattr(meta_class, "orm_model", None): + return self._get_orm_object(db, orm_model, val) + return val + def _create(self, *args: Any, **kwargs: Any) -> CALLABLE: def route( model: self.create_schema, # type: ignore db: Session = Depends(self.db_func), ) -> Model: try: - db_model: Model = self.db_model(**model.dict()) + db_model: Model = self.db_model() + + for key, val in model: + if val: + setattr(db_model, key, self._get_orm_object_or_value(db, val)) + db.add(db_model) db.commit() db.refresh(db_model) @@ -123,9 +162,12 @@ def route( try: db_model: Model = self._get_one()(item_id, db) - for key, value in model.dict(exclude={self._pk}).items(): - if hasattr(db_model, key): - setattr(db_model, key, value) + for key, val in model: + if key != self._pk: + if hasattr(db_model, key): + setattr( + db_model, key, self._get_orm_object_or_value(db, val) + ) db.commit() db.refresh(db_model) diff --git a/tests/test_sqlalchemy_nested_obj.py b/tests/test_sqlalchemy_nested_obj.py new file mode 100644 index 00000000..198f3d7e --- /dev/null +++ b/tests/test_sqlalchemy_nested_obj.py @@ -0,0 +1,314 @@ +from typing import TYPE_CHECKING, Callable, Iterator, List, Optional + +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient +from fastapi_crudrouter import SQLAlchemyCRUDRouter +from sqlalchemy import Column, ForeignKey, Integer, String, Table, inspect +from sqlalchemy.orm import relationship +from sqlalchemy.ext.declarative import declarative_base + +from tests.implementations.sqlalchemy_ import _setup_base_app +from tests import ORMModel + +if TYPE_CHECKING: + typeguard = True +else: + typeguard = False + +HEROES_URL = "/heroes" +TEAMS_URL = "/teams" +SCHOOLS_URL = "/schools" + +Base = declarative_base() + + +hero_school_link = Table( + "hero_school_link", + Base.metadata, + Column("school_id", Integer, ForeignKey("schools.id"), primary_key=True), + Column("hero_id", Integer, ForeignKey("heroes.id"), primary_key=True), +) + + +class School(Base): + """School DTO.""" + + __tablename__ = "schools" + + id = Column(Integer, primary_key=True, index=True) + name = Column(String, unique=True) + # heroes = relationship("Hero", secondary="hero_school_link", backref="schools") + + +class Team(Base): + """Team DTO.""" + + __tablename__ = "teams" + + id = Column(Integer, primary_key=True, index=True) + name = Column(String, unique=True) + headquarters = Column(String) + heroes = relationship("Hero", back_populates="team") + + +class Hero(Base): + """Hero DTO.""" + + __tablename__ = "heroes" + + id = Column(Integer, primary_key=True, index=True) + name = Column(String) + team_id = Column(Integer, ForeignKey("teams.id")) + team = relationship("Team", back_populates="heroes") + schools = relationship("School", secondary="hero_school_link", backref="heroes") + + +class TeamRead(ORMModel): + """Team Read View.""" + + name: Optional[str] = None + headquarters: Optional[str] = None + + +class TeamCreateUpdate(ORMModel): + """Team Update View.""" + + # TODO: id shouldn't be here + id: Optional[int] = None + name: Optional[str] = None + headquarters: Optional[str] = None + + class Meta: + """Meta info.""" + + orm_model = Team + + +class SchoolRead(ORMModel): + """School Read View.""" + + name: Optional[str] = None + + +class SchoolCreateUpdate(ORMModel): + """School Update/Create View.""" + + id: Optional[int] = None + name: Optional[str] = None + + class Meta: + """Meta info.""" + + orm_model = School + + +class HeroRead(ORMModel): + """Hero Read View.""" + + name: Optional[str] = None + team_id: Optional[int] = None + team: Optional[TeamRead] = None + schools: Optional[List[SchoolRead]] = [] + + +class HeroCreateUpdate(ORMModel): + """Hero Update View.""" + + # TODO: id shouldn't be here + id: Optional[int] = None + name: Optional[str] = None + team_id: Optional[int] = None + team: Optional[TeamCreateUpdate] = None + schools: Optional[List[SchoolCreateUpdate]] = [] + + +def hero_app() -> Callable: + """Fastapi application.""" + app, engine, _, session = _setup_base_app() + hero_router = SQLAlchemyCRUDRouter( + db=session, + schema=HeroRead, + update_schema=HeroCreateUpdate, + create_schema=HeroCreateUpdate, + db_model=Hero, + prefix=HEROES_URL, + ) + app.include_router(hero_router) + team_router = SQLAlchemyCRUDRouter( + db=session, + schema=TeamRead, + update_schema=TeamCreateUpdate, + db_model=Team, + prefix=TEAMS_URL, + ) + app.include_router(team_router) + school_router = SQLAlchemyCRUDRouter( + db=session, + schema=SchoolRead, + update_schema=SchoolCreateUpdate, + create_schema=SchoolCreateUpdate, + db_model=School, + prefix=SCHOOLS_URL, + ) + app.include_router(school_router) + Base.metadata.create_all(bind=engine) + return app, session + + +def object_as_dict(obj): + return {c.key: getattr(obj, c.key) for c in inspect(obj).mapper.column_attrs} + + +def test_get(): + """Get all and get one.""" + app, get_session = hero_app() + client = TestClient(app) + team = Team(name="Avengers", headquarters="Avengers Mansion") + session = next(get_session()) + session.add(team) + hero = Hero(name="Bob", team_id=team.id) + session.add(hero) + session.commit() + session.refresh(hero) + + res = client.get(HEROES_URL) + assert res.status_code == 200 + assert res.json() == [{**HeroRead(**object_as_dict(hero)).dict()}] + + res = client.get(f"/heroes/{hero.id}") + assert res.status_code == 200 + assert res.json() == HeroRead(**object_as_dict(hero)) + + +def test_insert() -> None: + """Test basic sqlmodel insert with relationship attribute as object.hero_client + + This just illustrates what we are trying to do with the crudrouter + from a sqlmodel perspective. + """ + _, get_session = hero_app() + session = next(get_session()) + school_obj1 = School(name="Hero Primary School") + session.add(school_obj1) + + school_obj2 = School(name="Hero High") + session.add(school_obj2) + + team = dict(name="Avengers", headquarters="Avengers Mansion") + team_obj = Team(**team) + session.add(team_obj) + + session.commit() + + session.refresh(team_obj) + session.refresh(school_obj1) + session.refresh(school_obj2) + + hero = dict(name="Bob", team=team_obj, schools=[school_obj1, school_obj2]) + hero_obj = Hero(**hero) + session.add(hero_obj) + session.commit() + session.refresh(hero_obj) + + assert hero["name"] == hero_obj.name + assert hero["team"] == hero_obj.team + assert hero["schools"] == [school_obj1, school_obj2] + + +def test_post_one2many_object(): + """Create an object with a one-to-many relation as object.""" + app, _ = hero_app() + client = TestClient(app) + team = dict(name="Avengers", headquarters="Avengers Mansion") + res = client.post(TEAMS_URL, json=team) + team_return = res.json() + assert res.status_code == 200, res.json() + + hero = dict(name="Bob", team=team) + res = client.post("/heroes", json=hero) + hero_return = res.json() + assert res.status_code == 200, hero_return + assert hero_return["team_id"] == team_return["id"] + + +def test_post_many2many_object() -> None: + """Create an object with a many2many relation value as object.""" + app, _ = hero_app() + client = TestClient(app) + school = dict(name="Hero Primary School") + res = client.post(SCHOOLS_URL, json=school) + school1_return = res.json() + assert res.status_code == 200, school1_return + + school = dict(name="Hero High") + res = client.post(SCHOOLS_URL, json=school) + school2_return = res.json() + assert res.status_code == 200, school2_return + + hero = dict( + name="Bob", + schools=[ + {"id": school1_return["id"]}, + {"id": school2_return["id"]}, + ], + ) + + res = client.post("/heroes", json=hero) + hero_return = res.json() + assert res.status_code == 200, hero_return + assert hero_return["schools"] == [school1_return, school2_return] + + +def test_update_one2many_object(): + """Update an object with a one-to-many relation as object.""" + app, _ = hero_app() + client = TestClient(app) + team = dict(name="Avengers", headquarters="Avengers Mansion") + res = client.post(TEAMS_URL, json=team) + team_return = res.json() + assert res.status_code == 200, res.json() + + hero = dict(name="Bob") + res = client.post("/heroes", json=hero) + hero_return = res.json() + assert res.status_code == 200, hero_return + assert hero_return["team_id"] is None + + hero_update = dict(team={"name": team_return["name"]}) + res = client.put(f"/heroes/{team_return['id']}", json=hero_update) + hero_return = res.json() + assert res.status_code == 200, hero_return + assert hero_return["team_id"] == team_return["id"] + + +def test_update_many2many_object() -> None: + """Create an object and update a many2man relation value as object.""" + app, _ = hero_app() + client = TestClient(app) + school1 = dict(name="Hero Primary School") + res = client.post(SCHOOLS_URL, json=school1) + school1_return = res.json() + assert res.status_code == 200, school1_return + + school2 = dict(name="Hero High") + res = client.post(SCHOOLS_URL, json=school2) + school2_return = res.json() + assert res.status_code == 200, school2_return + + hero = dict(name="Bob") + res = client.post("/heroes", json=hero) + hero_return = res.json() + assert res.status_code == 200, hero_return + assert hero_return["schools"] == [] + + hero_update = dict( + schools=[ + {"id": school1_return["id"]}, + {"id": school2_return["id"]}, + ] + ) + res = client.put(f"/heroes/{hero_return['id']}", json=hero_update) + hero_return = res.json() + assert res.status_code == 200, hero_return + assert hero_return["schools"] == [school1_return, school2_return] From 9578caaf84404a8c974f09b16fa26f47f69e3dc1 Mon Sep 17 00:00:00 2001 From: cycledriver Date: Fri, 14 Oct 2022 08:32:49 -0400 Subject: [PATCH 2/2] [sqlalchemy] cleanup unused code --- tests/test_sqlalchemy_nested_obj.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/tests/test_sqlalchemy_nested_obj.py b/tests/test_sqlalchemy_nested_obj.py index 198f3d7e..62dbcf7a 100644 --- a/tests/test_sqlalchemy_nested_obj.py +++ b/tests/test_sqlalchemy_nested_obj.py @@ -1,7 +1,5 @@ -from typing import TYPE_CHECKING, Callable, Iterator, List, Optional +from typing import TYPE_CHECKING, Callable, List, Optional -import pytest -from fastapi import FastAPI from fastapi.testclient import TestClient from fastapi_crudrouter import SQLAlchemyCRUDRouter from sqlalchemy import Column, ForeignKey, Integer, String, Table, inspect @@ -74,7 +72,6 @@ class TeamRead(ORMModel): class TeamCreateUpdate(ORMModel): """Team Update View.""" - # TODO: id shouldn't be here id: Optional[int] = None name: Optional[str] = None headquarters: Optional[str] = None @@ -115,7 +112,6 @@ class HeroRead(ORMModel): class HeroCreateUpdate(ORMModel): """Hero Update View.""" - # TODO: id shouldn't be here id: Optional[int] = None name: Optional[str] = None team_id: Optional[int] = None