diff --git a/RELEASE.md b/RELEASE.md new file mode 100644 index 0000000..ce2e561 --- /dev/null +++ b/RELEASE.md @@ -0,0 +1,3 @@ +Release type: minor + +Add support for secondary table relationships in SQLAlchemy mapper, addressing a bug and enhancing the loader to handle these relationships efficiently. diff --git a/src/strawberry_sqlalchemy_mapper/exc.py b/src/strawberry_sqlalchemy_mapper/exc.py index 0213276..6c687af 100644 --- a/src/strawberry_sqlalchemy_mapper/exc.py +++ b/src/strawberry_sqlalchemy_mapper/exc.py @@ -35,3 +35,12 @@ def __init__(self, model): f"Model `{model}` is not polymorphic or is not the base model of its " "inheritance chain, and thus cannot be used as an interface." ) + + +class InvalidLocalRemotePairs(Exception): + def __init__(self, relationship_name): + super().__init__( + f"The `local_remote_pairs` for the relationship `{relationship_name}` is invalid or " + "missing. This is likely an issue with the library. " + "Please report this error to the maintainers." + ) diff --git a/src/strawberry_sqlalchemy_mapper/loader.py b/src/strawberry_sqlalchemy_mapper/loader.py index c78e563..b093c0d 100644 --- a/src/strawberry_sqlalchemy_mapper/loader.py +++ b/src/strawberry_sqlalchemy_mapper/loader.py @@ -18,6 +18,8 @@ from sqlalchemy.orm import RelationshipProperty, Session from strawberry.dataloader import DataLoader +from strawberry_sqlalchemy_mapper.exc import InvalidLocalRemotePairs + class StrawberrySQLAlchemyLoader: """ @@ -45,13 +47,22 @@ def __init__( "One of bind or async_bind_factory must be set for loader to function properly." ) - async def _scalars_all(self, *args, **kwargs): + async def _scalars_all(self, *args, query_secondary_tables=False, **kwargs): + # query_secondary_tables explanation: + # We need to retrieve values from both the self_model and related_model. + # To achieve this, we must disable the default SQLAlchemy optimization + # that returns only related_model values. + # This is necessary because we use the keys variable + # to match both related_model and self_model. if self._async_bind_factory: async with self._async_bind_factory() as bind: + if query_secondary_tables: + return (await bind.execute(*args, **kwargs)).all() return (await bind.scalars(*args, **kwargs)).all() - else: - assert self._bind is not None - return self._bind.scalars(*args, **kwargs).all() + assert self._bind is not None + if query_secondary_tables: + return self._bind.execute(*args, **kwargs).all() + return self._bind.scalars(*args, **kwargs).all() def loader_for(self, relationship: RelationshipProperty) -> DataLoader: """ @@ -63,14 +74,72 @@ def loader_for(self, relationship: RelationshipProperty) -> DataLoader: related_model = relationship.entity.entity async def load_fn(keys: List[Tuple]) -> List[Any]: - query = select(related_model).filter( - tuple_(*[remote for _, remote in relationship.local_remote_pairs or []]).in_( - keys + def _build_normal_relationship_query(related_model, relationship, keys): + return select(related_model).filter( + tuple_( + *[remote for _, remote in relationship.local_remote_pairs or []] + ).in_(keys) + ) + + def _build_relationship_with_secondary_table_query( + related_model, relationship, keys + ): + # Use another query when relationship uses a secondary table + self_model = relationship.parent.entity + + if not relationship.local_remote_pairs: + raise InvalidLocalRemotePairs( + f"{related_model.__name__} -- {self_model.__name__}" + ) + + self_model_key_label = str(relationship.local_remote_pairs[0][1].key) + related_model_key_label = str(relationship.local_remote_pairs[1][1].key) + + self_model_key = str(relationship.local_remote_pairs[0][0].key) + related_model_key = str(relationship.local_remote_pairs[1][0].key) + + remote_to_use = relationship.local_remote_pairs[0][1] + query_keys = tuple([item[0] for item in keys]) + + # This query returns rows in this format -> (self_model.key, related_model) + return ( + select( + getattr(self_model, self_model_key).label(self_model_key_label), + related_model, + ) + .join( + relationship.secondary, + getattr(relationship.secondary.c, related_model_key_label) + == getattr(related_model, related_model_key), + ) + .join( + self_model, + getattr(relationship.secondary.c, self_model_key_label) + == getattr(self_model, self_model_key), + ) + .filter(remote_to_use.in_(query_keys)) + ) + + query = ( + _build_normal_relationship_query(related_model, relationship, keys) + if relationship.secondary is None + else _build_relationship_with_secondary_table_query( + related_model, relationship, keys ) ) + if relationship.order_by: query = query.order_by(*relationship.order_by) - rows = await self._scalars_all(query) + + if relationship.secondary is not None: + # We need to retrieve values from both the self_model and related_model. + # To achieve this, we must disable the default SQLAlchemy optimization + # that returns only related_model values. + # This is necessary because we use the keys variable + # to match both related_model and self_model. + rows = await self._scalars_all(query, query_secondary_tables=True) + else: + rows = await self._scalars_all(query) def group_by_remote_key(row: Any) -> Tuple: return tuple( @@ -82,8 +151,13 @@ def group_by_remote_key(row: Any) -> Tuple: ) grouped_keys: Mapping[Tuple, List[Any]] = defaultdict(list) - for row in rows: - grouped_keys[group_by_remote_key(row)].append(row) + if relationship.secondary is None: + for row in rows: + grouped_keys[group_by_remote_key(row)].append(row) + else: + for row in rows: + grouped_keys[(row[0],)].append(row[1]) + if relationship.uselist: return [grouped_keys[key] for key in keys] else: diff --git a/src/strawberry_sqlalchemy_mapper/mapper.py b/src/strawberry_sqlalchemy_mapper/mapper.py index 0b8f426..72154e3 100644 --- a/src/strawberry_sqlalchemy_mapper/mapper.py +++ b/src/strawberry_sqlalchemy_mapper/mapper.py @@ -81,6 +81,7 @@ from strawberry_sqlalchemy_mapper.exc import ( HybridPropertyNotAnnotated, InterfaceModelNotPolymorphic, + InvalidLocalRemotePairs, UnsupportedAssociationProxyTarget, UnsupportedColumnType, UnsupportedDescriptorType, @@ -493,13 +494,28 @@ async def resolve(self, info: Info): if relationship.key not in instance_state.unloaded: related_objects = getattr(self, relationship.key) else: - relationship_key = tuple( - [ + if relationship.secondary is None: + relationship_key = tuple( getattr(self, local.key) for local, _ in relationship.local_remote_pairs or [] if local.key - ] - ) + ) + else: + # If has a secondary table, gets only the first ID + # as additional IDs require a separate query + if not relationship.local_remote_pairs: + raise InvalidLocalRemotePairs( + f"{relationship.entity.entity.__name__} -" + f"- {relationship.parent.entity.__name__}" + ) + + local_remote_pairs_secondary_table_local = relationship.local_remote_pairs[0][0] + relationship_key = tuple( + [ + getattr(self, str(local_remote_pairs_secondary_table_local.key)), + ] + ) + if any(item is None for item in relationship_key): if relationship.uselist: return [] @@ -527,7 +543,9 @@ def connection_resolver_for( if relationship.uselist and not use_list: return self.make_connection_wrapper_resolver( relationship_resolver, - self.model_to_type_or_interface_name(relationship.entity.entity), # type: ignore[arg-type] + self.model_to_type_or_interface_name( + relationship.entity.entity # type: ignore[arg-type] + ), ) else: return relationship_resolver diff --git a/tests/conftest.py b/tests/conftest.py index 09c9c6e..7d65e50 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -118,3 +118,471 @@ def base(): @pytest.fixture def mapper(): return StrawberrySQLAlchemyMapper() + + +@pytest.fixture +def default_employee_department_join_table(base): + EmployeeDepartmentJoinTable = sqlalchemy.Table( + "employee_department_join_table", + base.metadata, + sqlalchemy.Column("employee_id", sqlalchemy.ForeignKey("employee.id"), primary_key=True), + sqlalchemy.Column( + "department_id", sqlalchemy.ForeignKey("department.id"), primary_key=True + ), + ) + return EmployeeDepartmentJoinTable + + +@pytest.fixture +def secondary_tables(base, default_employee_department_join_table): + class Employee(base): + __tablename__ = "employee" + id = sqlalchemy.Column( + sqlalchemy.Integer, autoincrement=True, primary_key=True, nullable=False + ) + name = sqlalchemy.Column(sqlalchemy.String, nullable=False) + role = sqlalchemy.Column(sqlalchemy.String, nullable=True) + department = orm.relationship( + "Department", + secondary="employee_department_join_table", + back_populates="employees", + ) + + class Department(base): + __tablename__ = "department" + id = sqlalchemy.Column(sqlalchemy.Integer, autoincrement=True, primary_key=True) + name = sqlalchemy.Column(sqlalchemy.String, nullable=True) + employees = orm.relationship( + "Employee", + secondary="employee_department_join_table", + back_populates="department", + ) + + return Employee, Department + + +@pytest.fixture +def secondary_tables_with_another_foreign_key(base): + EmployeeDepartmentJoinTable = sqlalchemy.Table( + "employee_department_join_table", + base.metadata, + sqlalchemy.Column( + "employee_name", sqlalchemy.ForeignKey("employee.name"), primary_key=True + ), + sqlalchemy.Column( + "department_id", sqlalchemy.ForeignKey("department.id"), primary_key=True + ), + ) + + class Employee(base): + __tablename__ = "employee" + id = sqlalchemy.Column(sqlalchemy.Integer, autoincrement=True, nullable=False) + name = sqlalchemy.Column(sqlalchemy.String, nullable=False, primary_key=True) + role = sqlalchemy.Column(sqlalchemy.String, nullable=True) + department = orm.relationship( + "Department", + secondary="employee_department_join_table", + back_populates="employees", + ) + + class Department(base): + __tablename__ = "department" + id = sqlalchemy.Column(sqlalchemy.Integer, autoincrement=True, primary_key=True) + name = sqlalchemy.Column(sqlalchemy.String, nullable=True) + employees = orm.relationship( + "Employee", + secondary="employee_department_join_table", + back_populates="department", + ) + + return Employee, Department + + +@pytest.fixture +def secondary_tables_with_more_secondary_tables(base, default_employee_department_join_table): + EmployeeBuildingJoinTable = sqlalchemy.Table( + "employee_building_join_table", + base.metadata, + sqlalchemy.Column("employee_id", sqlalchemy.ForeignKey("employee.id"), primary_key=True), + sqlalchemy.Column("building_id", sqlalchemy.ForeignKey("building.id"), primary_key=True), + ) + + class Employee(base): + __tablename__ = "employee" + id = sqlalchemy.Column(sqlalchemy.Integer, autoincrement=True, primary_key=True) + name = sqlalchemy.Column(sqlalchemy.String, nullable=False) + role = sqlalchemy.Column(sqlalchemy.String, nullable=True) + department = orm.relationship( + "Department", + secondary="employee_department_join_table", + back_populates="employees", + ) + building = orm.relationship( + "Building", + secondary="employee_building_join_table", + back_populates="employees", + ) + + class Department(base): + __tablename__ = "department" + id = sqlalchemy.Column(sqlalchemy.Integer, autoincrement=True, primary_key=True) + name = sqlalchemy.Column(sqlalchemy.String, nullable=False) + employees = orm.relationship( + "Employee", + secondary="employee_department_join_table", + back_populates="department", + ) + + class Building(base): + __tablename__ = "building" + id = sqlalchemy.Column(sqlalchemy.Integer, autoincrement=True, primary_key=True) + name = sqlalchemy.Column(sqlalchemy.String, nullable=False) + employees = orm.relationship( + "Employee", + secondary="employee_building_join_table", + back_populates="building", + ) + + return Employee, Department, Building + + +@pytest.fixture +def secondary_tables_with_use_list_false(base, default_employee_department_join_table): + class Employee(base): + __tablename__ = "employee" + id = sqlalchemy.Column(sqlalchemy.Integer, autoincrement=True, primary_key=True) + name = sqlalchemy.Column(sqlalchemy.String, nullable=False) + role = sqlalchemy.Column(sqlalchemy.String, nullable=True) + department = orm.relationship( + "Department", + secondary="employee_department_join_table", + back_populates="employees", + ) + + class Department(base): + __tablename__ = "department" + id = sqlalchemy.Column(sqlalchemy.Integer, autoincrement=True, primary_key=True) + name = sqlalchemy.Column(sqlalchemy.String, nullable=False) + employees = orm.relationship( + "Employee", + secondary="employee_department_join_table", + back_populates="department", + uselist=False, + ) + + return Employee, Department + + +@pytest.fixture +def secondary_tables_with_normal_relationship(base, default_employee_department_join_table): + class Employee(base): + __tablename__ = "employee" + id = sqlalchemy.Column(sqlalchemy.Integer, autoincrement=True, primary_key=True) + name = sqlalchemy.Column(sqlalchemy.String, nullable=False) + role = sqlalchemy.Column(sqlalchemy.String, nullable=True) + department = orm.relationship( + "Department", + secondary="employee_department_join_table", + back_populates="employees", + ) + building_id = sqlalchemy.Column(sqlalchemy.Integer, sqlalchemy.ForeignKey("building.id")) + building = orm.relationship( + "Building", + back_populates="employees", + ) + + class Department(base): + __tablename__ = "department" + id = sqlalchemy.Column(sqlalchemy.Integer, autoincrement=True, primary_key=True) + name = sqlalchemy.Column(sqlalchemy.String, nullable=False) + employees = orm.relationship( + "Employee", + secondary="employee_department_join_table", + back_populates="department", + ) + + class Building(base): + __tablename__ = "building" + id = sqlalchemy.Column(sqlalchemy.Integer, autoincrement=True, primary_key=True) + name = sqlalchemy.Column(sqlalchemy.String, nullable=False) + employees = orm.relationship( + "Employee", + back_populates="building", + ) + + return Employee, Department, Building + + +@pytest.fixture +def expected_schema_from_secondary_tables(): + return ''' + type Department { + id: Int! + name: String + employees: EmployeeConnection! + } + + type DepartmentConnection { + """Pagination data for this connection""" + pageInfo: PageInfo! + edges: [DepartmentEdge!]! + } + + type DepartmentEdge { + """A cursor for use in pagination""" + cursor: String! + + """The item at the end of the edge""" + node: Department! + } + + type Employee { + id: Int! + name: String! + role: String + department: DepartmentConnection! + } + + type EmployeeConnection { + """Pagination data for this connection""" + pageInfo: PageInfo! + edges: [EmployeeEdge!]! + } + + type EmployeeEdge { + """A cursor for use in pagination""" + cursor: String! + + """The item at the end of the edge""" + node: Employee! + } + + """Information to aid in pagination.""" + type PageInfo { + """When paginating forwards, are there more items?""" + hasNextPage: Boolean! + + """When paginating backwards, are there more items?""" + hasPreviousPage: Boolean! + + """When paginating backwards, the cursor to continue.""" + startCursor: String + + """When paginating forwards, the cursor to continue.""" + endCursor: String + } + + type Query { + departments: [Department!]! + } + ''' + + +@pytest.fixture +def expected_schema_from_secondary_tables_with_more_secondary_tables(): + return ''' + type Building { + id: Int! + name: String! + employees: EmployeeConnection! + } + + type BuildingConnection { + """Pagination data for this connection""" + pageInfo: PageInfo! + edges: [BuildingEdge!]! + } + + type BuildingEdge { + """A cursor for use in pagination""" + cursor: String! + + """The item at the end of the edge""" + node: Building! + } + + type Department { + id: Int! + name: String! + employees: EmployeeConnection! + } + + type DepartmentConnection { + """Pagination data for this connection""" + pageInfo: PageInfo! + edges: [DepartmentEdge!]! + } + + type DepartmentEdge { + """A cursor for use in pagination""" + cursor: String! + + """The item at the end of the edge""" + node: Department! + } + + type Employee { + id: Int! + name: String! + role: String + department: DepartmentConnection! + building: BuildingConnection! + } + + type EmployeeConnection { + """Pagination data for this connection""" + pageInfo: PageInfo! + edges: [EmployeeEdge!]! + } + + type EmployeeEdge { + """A cursor for use in pagination""" + cursor: String! + + """The item at the end of the edge""" + node: Employee! + } + + """Information to aid in pagination.""" + type PageInfo { + """When paginating forwards, are there more items?""" + hasNextPage: Boolean! + + """When paginating backwards, are there more items?""" + hasPreviousPage: Boolean! + + """When paginating backwards, the cursor to continue.""" + startCursor: String + + """When paginating forwards, the cursor to continue.""" + endCursor: String + } + + type Query { + departments: [Department!]! + } + ''' + + +@pytest.fixture +def expected_schema_from_secondary_tables_with_more_secondary_tables_with_use_list_false(): + return ''' + type Department { + id: Int! + name: String! + employees: Employee + } + + type DepartmentConnection { + """Pagination data for this connection""" + pageInfo: PageInfo! + edges: [DepartmentEdge!]! + } + + type DepartmentEdge { + """A cursor for use in pagination""" + cursor: String! + + """The item at the end of the edge""" + node: Department! + } + + type Employee { + id: Int! + name: String! + role: String + department: DepartmentConnection! + } + + """Information to aid in pagination.""" + type PageInfo { + """When paginating forwards, are there more items?""" + hasNextPage: Boolean! + + """When paginating backwards, are there more items?""" + hasPreviousPage: Boolean! + + """When paginating backwards, the cursor to continue.""" + startCursor: String + + """When paginating forwards, the cursor to continue.""" + endCursor: String + } + + type Query { + departments: [Department!]! + } + ''' + + +@pytest.fixture +def expected_schema_from_secondary_tables_with_more_secondary_tables_with_normal_relationship(): + return ''' + type Building { + id: Int! + name: String! + employees: EmployeeConnection! + } + + type Department { + id: Int! + name: String! + employees: EmployeeConnection! + } + + type DepartmentConnection { + """Pagination data for this connection""" + pageInfo: PageInfo! + edges: [DepartmentEdge!]! + } + + type DepartmentEdge { + """A cursor for use in pagination""" + cursor: String! + + """The item at the end of the edge""" + node: Department! + } + + type Employee { + id: Int! + name: String! + role: String + buildingId: Int + department: DepartmentConnection! + building: Building + } + + type EmployeeConnection { + """Pagination data for this connection""" + pageInfo: PageInfo! + edges: [EmployeeEdge!]! + } + + type EmployeeEdge { + """A cursor for use in pagination""" + cursor: String! + + """The item at the end of the edge""" + node: Employee! + } + + """Information to aid in pagination.""" + type PageInfo { + """When paginating forwards, are there more items?""" + hasNextPage: Boolean! + + """When paginating backwards, are there more items?""" + hasPreviousPage: Boolean! + + """When paginating backwards, the cursor to continue.""" + startCursor: String + + """When paginating forwards, the cursor to continue.""" + endCursor: String + } + + type Query { + departments: [Department!]! + } + ''' diff --git a/tests/relay/test_connection.py b/tests/relay/test_connection.py index a46594d..4c10eb5 100644 --- a/tests/relay/test_connection.py +++ b/tests/relay/test_connection.py @@ -7,7 +7,11 @@ from sqlalchemy.ext.asyncio.engine import AsyncEngine from sqlalchemy.orm import sessionmaker from strawberry import relay -from strawberry_sqlalchemy_mapper import StrawberrySQLAlchemyMapper, connection +from strawberry_sqlalchemy_mapper import ( + StrawberrySQLAlchemyLoader, + StrawberrySQLAlchemyMapper, + connection, +) from strawberry_sqlalchemy_mapper.relay import KeysetConnection @@ -751,3 +755,163 @@ class Query: }, } } + + +@pytest.mark.asyncio +async def test_query_with_secondary_table_with_values_list( + secondary_tables, base, async_engine, async_sessionmaker +): + async with async_engine.begin() as conn: + await conn.run_sync(base.metadata.create_all) + + mapper = StrawberrySQLAlchemyMapper() + EmployeeModel, DepartmentModel = secondary_tables + + @mapper.type(DepartmentModel) + class Department: + pass + + @mapper.type(EmployeeModel) + class Employee: + pass + + @strawberry.type + class Query: + departments: relay.ListConnection[Department] = connection(sessionmaker=async_sessionmaker) + + mapper.finalize() + schema = strawberry.Schema(query=Query) + + query = """\ + query { + departments { + edges { + node { + id + name + employees { + edges { + node { + id + name + role + department { + edges { + node { + id + name + } + } + } + } + } + } + } + } + } + } + """ + + # Create test data + async with async_sessionmaker(expire_on_commit=False) as session: + department1 = DepartmentModel(id=10, name="Department Test 1") + department2 = DepartmentModel(id=3, name="Department Test 2") + e1 = EmployeeModel(id=1, name="John", role="Developer") + e2 = EmployeeModel(id=5, name="Bill", role="Doctor") + e3 = EmployeeModel(id=4, name="Maria", role="Teacher") + department1.employees.append(e1) + department1.employees.append(e2) + department2.employees.append(e3) + session.add_all([department1, department2, e1, e2, e3]) + await session.commit() + + result = await schema.execute( + query, + context_value={ + "sqlalchemy_loader": StrawberrySQLAlchemyLoader( + async_bind_factory=async_sessionmaker + ) + }, + ) + assert result.errors is None + assert result.data == { + "departments": { + "edges": [ + { + "node": { + "id": 10, + "name": "Department Test 1", + "employees": { + "edges": [ + { + "node": { + "id": 5, + "name": "Bill", + "role": "Doctor", + "department": { + "edges": [ + { + "node": { + "id": 10, + "name": "Department Test 1", + } + } + ] + }, + } + }, + { + "node": { + "id": 1, + "name": "John", + "role": "Developer", + "department": { + "edges": [ + { + "node": { + "id": 10, + "name": "Department Test 1", + } + } + ] + }, + } + }, + ] + }, + } + }, + { + "node": { + "id": 3, + "name": "Department Test 2", + "employees": { + "edges": [ + { + "node": { + "id": 4, + "name": "Maria", + "role": "Teacher", + "department": { + "edges": [ + { + "node": { + "id": 3, + "name": "Department Test 2", + } + } + ] + }, + } + } + ] + }, + } + }, + ] + } + } + + +# TODO: Add a test with keyset connection with secondary tables +# TODO: Add a test only to check the duplication of connections on mapper diff --git a/tests/test_loader.py b/tests/test_loader.py index 7a3c9a9..2529051 100644 --- a/tests/test_loader.py +++ b/tests/test_loader.py @@ -1,7 +1,8 @@ import pytest -from sqlalchemy import Column, ForeignKey, Integer, String, Table +from sqlalchemy import Column, ForeignKey, Integer, String from sqlalchemy.orm import relationship from strawberry_sqlalchemy_mapper import StrawberrySQLAlchemyLoader +from strawberry_sqlalchemy_mapper.exc import InvalidLocalRemotePairs pytest_plugins = ("pytest_asyncio",) @@ -26,38 +27,6 @@ class Department(base): return Employee, Department -@pytest.fixture -def secondary_tables(base): - EmployeeDepartmentJoinTable = Table( - "employee_department_join_table", - base.metadata, - Column("employee_id", ForeignKey("employee.e_id"), primary_key=True), - Column("department_id", ForeignKey("department.d_id"), primary_key=True), - ) - - class Employee(base): - __tablename__ = "employee" - e_id = Column(Integer, autoincrement=True, primary_key=True) - name = Column(String, nullable=False) - departments = relationship( - "Department", - secondary="employee_department_join_table", - back_populates="employees", - ) - - class Department(base): - __tablename__ = "department" - d_id = Column(Integer, autoincrement=True, primary_key=True) - name = Column(String, nullable=False) - employees = relationship( - "Employee", - secondary="employee_department_join_table", - back_populates="departments", - ) - - return Employee, Department - - def test_loader_init(): loader = StrawberrySQLAlchemyLoader(bind=None) assert loader._bind is None @@ -140,36 +109,171 @@ async def test_loader_with_async_session( assert {e.name for e in employees} == {"e1"} -@pytest.mark.xfail +def create_default_data_on_secondary_table_tests(session, employee, department): + e1 = employee(name="e1", id=1) + e2 = employee(name="e2", id=2) + d1 = department(name="d1") + d2 = department(name="d2") + d3 = department(name="d3") + session.add_all([e1, e2, d1, d2, d3]) + session.flush() + + e1.department.append(d1) + e1.department.append(d2) + e2.department.append(d2) + return e1, e2, d1, d2, d3 + + @pytest.mark.asyncio -async def test_loader_for_secondary(engine, base, sessionmaker, secondary_tables): +async def test_loader_for_secondary_table(engine, base, sessionmaker, secondary_tables): Employee, Department = secondary_tables base.metadata.create_all(engine) with sessionmaker() as session: - e1 = Employee(name="e1") - e2 = Employee(name="e2") - d1 = Department(name="d1") - d2 = Department(name="d2") - session.add(e1) - session.add(e2) - session.add(d1) - session.add(d2) - session.flush() + e1, _, _, _, _ = create_default_data_on_secondary_table_tests( + session=session, employee=Employee, department=Department + ) + session.commit() - e1.departments.append(d1) - e1.departments.append(d2) - e2.departments.append(d2) + base_loader = StrawberrySQLAlchemyLoader(bind=session) + loader = base_loader.loader_for(Employee.department.property) + + key = tuple( + [ + getattr(e1, str(Employee.department.property.local_remote_pairs[0][0].key)), + ] + ) + + departments = await loader.load(key) + assert {d.name for d in departments} == {"d1", "d2"} + + +@pytest.mark.asyncio +async def test_loader_for_secondary_tables_with_another_foreign_key( + engine, base, sessionmaker, secondary_tables_with_another_foreign_key +): + Employee, Department = secondary_tables_with_another_foreign_key + base.metadata.create_all(engine) + + with sessionmaker() as session: + e1, _, _, _, _ = create_default_data_on_secondary_table_tests( + session=session, employee=Employee, department=Department + ) session.commit() base_loader = StrawberrySQLAlchemyLoader(bind=session) - loader = base_loader.loader_for(Employee.departments.property) + loader = base_loader.loader_for(Employee.department.property) key = tuple( [ - getattr(e1, local.key) - for local, _ in Employee.departments.property.local_remote_pairs + getattr(e1, str(Employee.department.property.local_remote_pairs[0][0].key)), ] ) + departments = await loader.load(key) assert {d.name for d in departments} == {"d1", "d2"} + + +@pytest.mark.asyncio +async def test_loader_for_secondary_tables_with_more_secondary_tables( + engine, base, sessionmaker, secondary_tables_with_more_secondary_tables +): + Employee, Department, Building = secondary_tables_with_more_secondary_tables + base.metadata.create_all(engine) + + with sessionmaker() as session: + e1, e2, _, _, _ = create_default_data_on_secondary_table_tests( + session=session, employee=Employee, department=Department + ) + + b1 = Building(id=2, name="Building 1") + b1.employees.append(e1) + b1.employees.append(e2) + session.add(b1) + session.commit() + + base_loader = StrawberrySQLAlchemyLoader(bind=session) + loader = base_loader.loader_for(Employee.department.property) + + key = tuple( + [ + getattr(e1, str(Employee.department.property.local_remote_pairs[0][0].key)), + ] + ) + + departments = await loader.load(key) + assert {d.name for d in departments} == {"d1", "d2"} + + +@pytest.mark.asyncio +async def test_loader_for_secondary_tables_with_use_list_false( + engine, base, sessionmaker, secondary_tables_with_use_list_false +): + Employee, Department = secondary_tables_with_use_list_false + base.metadata.create_all(engine) + + with sessionmaker() as session: + e1, _, _, _, _ = create_default_data_on_secondary_table_tests( + session=session, employee=Employee, department=Department + ) + session.commit() + + base_loader = StrawberrySQLAlchemyLoader(bind=session) + loader = base_loader.loader_for(Employee.department.property) + + key = tuple( + [ + getattr(e1, str(Employee.department.property.local_remote_pairs[0][0].key)), + ] + ) + + departments = await loader.load(key) + assert {d.name for d in departments} == {"d1"} + + +@pytest.mark.asyncio +async def test_loader_for_secondary_tables_with_normal_relationship( + engine, base, sessionmaker, secondary_tables_with_normal_relationship +): + Employee, Department, Building = secondary_tables_with_normal_relationship + base.metadata.create_all(engine) + + with sessionmaker() as session: + e1, e2, _, _, _ = create_default_data_on_secondary_table_tests( + session=session, employee=Employee, department=Department + ) + + b1 = Building(id=2, name="Building 1") + b1.employees.append(e1) + b1.employees.append(e2) + session.add(b1) + session.commit() + + base_loader = StrawberrySQLAlchemyLoader(bind=session) + loader = base_loader.loader_for(Employee.department.property) + + key = tuple( + [ + getattr(e1, str(Employee.department.property.local_remote_pairs[0][0].key)), + ] + ) + + departments = await loader.load(key) + assert {d.name for d in departments} == {"d1", "d2"} + + +@pytest.mark.asyncio +async def test_loader_secondary_tables_should_raise_exc_if_relationship_dont_has_local_remote_pairs( + engine, base, sessionmaker, secondary_tables_with_normal_relationship +): + Employee, Department, Building = secondary_tables_with_normal_relationship + base.metadata.create_all(engine) + + with sessionmaker() as session: + base_loader = StrawberrySQLAlchemyLoader(bind=session) + + Employee.department.property.local_remote_pairs = [] + loader = base_loader.loader_for(Employee.department.property) + + with pytest.raises(expected_exception=InvalidLocalRemotePairs): + await loader.load((1,)) diff --git a/tests/test_mapper.py b/tests/test_mapper.py index 1207404..5d67c46 100644 --- a/tests/test_mapper.py +++ b/tests/test_mapper.py @@ -365,6 +365,155 @@ def departments(self) -> Department: ... assert str(schema) == textwrap.dedent(expected).strip() +def test_relationships_schema_with_secondary_tables( + secondary_tables, mapper, expected_schema_from_secondary_tables +): + EmployeeModel, DepartmentModel = secondary_tables + + @mapper.type(EmployeeModel) + class Employee: + pass + + @mapper.type(DepartmentModel) + class Department: + pass + + @strawberry.type + class Query: + @strawberry.field + def departments(self) -> List[Department]: ... + + mapper.finalize() + schema = strawberry.Schema(query=Query) + + assert str(schema) == textwrap.dedent(expected_schema_from_secondary_tables).strip() + + +def test_relationships_schema_with_secondary_tables_with_another_foreign_key( + secondary_tables_with_another_foreign_key, + mapper, + expected_schema_from_secondary_tables, +): + EmployeeModel, DepartmentModel = secondary_tables_with_another_foreign_key + + @mapper.type(EmployeeModel) + class Employee: + pass + + @mapper.type(DepartmentModel) + class Department: + pass + + @strawberry.type + class Query: + @strawberry.field + def departments(self) -> List[Department]: ... + + mapper.finalize() + schema = strawberry.Schema(query=Query) + + assert str(schema) == textwrap.dedent(expected_schema_from_secondary_tables).strip() + + +def test_relationships_schema_with_secondary_tables_with_more_secondary_tables( + secondary_tables_with_more_secondary_tables, + mapper, + expected_schema_from_secondary_tables_with_more_secondary_tables, +): + EmployeeModel, DepartmentModel, BuildingModel = secondary_tables_with_more_secondary_tables + + @mapper.type(EmployeeModel) + class Employee: + pass + + @mapper.type(DepartmentModel) + class Department: + pass + + @mapper.type(BuildingModel) + class Building: + pass + + @strawberry.type + class Query: + @strawberry.field + def departments(self) -> List[Department]: ... + + mapper.finalize() + schema = strawberry.Schema(query=Query) + + assert ( + str(schema) + == textwrap.dedent(expected_schema_from_secondary_tables_with_more_secondary_tables).strip() + ) + + +def test_relationships_schema_with_secondary_tables_with_use_list_false( + secondary_tables_with_use_list_false, + mapper, + expected_schema_from_secondary_tables_with_more_secondary_tables_with_use_list_false, +): + EmployeeModel, DepartmentModel = secondary_tables_with_use_list_false + + @mapper.type(EmployeeModel) + class Employee: + pass + + @mapper.type(DepartmentModel) + class Department: + pass + + @strawberry.type + class Query: + @strawberry.field + def departments(self) -> List[Department]: ... + + mapper.finalize() + schema = strawberry.Schema(query=Query) + + assert ( + str(schema) + == textwrap.dedent( + expected_schema_from_secondary_tables_with_more_secondary_tables_with_use_list_false + ).strip() + ) + + +def test_relationships_schema_with_secondary_tables_with_normal_relationship( + secondary_tables_with_normal_relationship, + mapper, + expected_schema_from_secondary_tables_with_more_secondary_tables_with_normal_relationship, +): + EmployeeModel, DepartmentModel, BuildingModel = secondary_tables_with_normal_relationship + + @mapper.type(EmployeeModel) + class Employee: + pass + + @mapper.type(DepartmentModel) + class Department: + pass + + @mapper.type(BuildingModel) + class Building: + pass + + @strawberry.type + class Query: + @strawberry.field + def departments(self) -> List[Department]: ... + + mapper.finalize() + schema = strawberry.Schema(query=Query) + + assert ( + str(schema) + == textwrap.dedent( + expected_schema_from_secondary_tables_with_more_secondary_tables_with_normal_relationship + ).strip() + ) + + @pytest.mark.parametrize( "directives", [ diff --git a/tests/test_secondary_tables_query.py b/tests/test_secondary_tables_query.py new file mode 100644 index 0000000..3e4b330 --- /dev/null +++ b/tests/test_secondary_tables_query.py @@ -0,0 +1,872 @@ +from typing import List + +import pytest +import strawberry +from sqlalchemy import select +from strawberry import relay +from strawberry_sqlalchemy_mapper import ( + StrawberrySQLAlchemyLoader, + StrawberrySQLAlchemyMapper, + connection, +) + + +@pytest.fixture +def default_query_secondary_table(): + return """ + query { + departments { + id + name + employees { + edges { + node { + id + name + role + department { + edges { + node { + id + name + } + } + } + } + } + } + } + } + """ + + +def created_default_secondary_table_data(session, employee_model, department_model): + department1 = department_model(id=10, name="Department Test 1") + department2 = department_model(id=3, name="Department Test 2") + e1 = employee_model(id=1, name="John", role="Developer") + e2 = employee_model(id=5, name="Bill", role="Doctor") + e3 = employee_model(id=4, name="Maria", role="Teacher") + department1.employees.append(e1) + department1.employees.append(e2) + department2.employees.append(e3) + session.add_all([department1, department2, e1, e2, e3]) + return e1, e2, e3, department1, department2 + + +@pytest.mark.asyncio +async def test_query_with_secondary_table_with_values_list_without_list_connection( + secondary_tables, + base, + async_engine, + async_sessionmaker, + default_query_secondary_table, +): + async with async_engine.begin() as conn: + await conn.run_sync(base.metadata.create_all) + + mapper = StrawberrySQLAlchemyMapper() + EmployeeModel, DepartmentModel = secondary_tables + + @mapper.type(DepartmentModel) + class Department: + pass + + @mapper.type(EmployeeModel) + class Employee: + pass + + @strawberry.type + class Query: + @strawberry.field + async def departments(self) -> List[Department]: + async with async_sessionmaker() as session: + result = await session.execute(select(DepartmentModel)) + return result.scalars().all() + + mapper.finalize() + schema = strawberry.Schema(query=Query) + + # Create test data + async with async_sessionmaker(expire_on_commit=False) as session: + created_default_secondary_table_data( + session=session, + employee_model=EmployeeModel, + department_model=DepartmentModel, + ) + await session.commit() + + result = await schema.execute( + default_query_secondary_table, + context_value={ + "sqlalchemy_loader": StrawberrySQLAlchemyLoader( + async_bind_factory=async_sessionmaker + ) + }, + ) + assert result.errors is None + assert result.data == { + "departments": [ + { + "id": 10, + "name": "Department Test 1", + "employees": { + "edges": [ + { + "node": { + "id": 5, + "name": "Bill", + "role": "Doctor", + "department": { + "edges": [ + { + "node": { + "id": 10, + "name": "Department Test 1", + } + } + ] + }, + } + }, + { + "node": { + "id": 1, + "name": "John", + "role": "Developer", + "department": { + "edges": [ + { + "node": { + "id": 10, + "name": "Department Test 1", + } + } + ] + }, + } + }, + ] + }, + }, + { + "id": 3, + "name": "Department Test 2", + "employees": { + "edges": [ + { + "node": { + "id": 4, + "name": "Maria", + "role": "Teacher", + "department": { + "edges": [ + { + "node": { + "id": 3, + "name": "Department Test 2", + } + } + ] + }, + } + } + ] + }, + }, + ] + } + + +@pytest.mark.asyncio +async def test_query_with_secondary_table_with_values_list_with_foreign_key_different_than_id( + secondary_tables_with_another_foreign_key, + base, + async_engine, + async_sessionmaker, + default_query_secondary_table, +): + async with async_engine.begin() as conn: + await conn.run_sync(base.metadata.create_all) + + mapper = StrawberrySQLAlchemyMapper() + EmployeeModel, DepartmentModel = secondary_tables_with_another_foreign_key + + @mapper.type(DepartmentModel) + class Department: + pass + + @mapper.type(EmployeeModel) + class Employee: + pass + + @strawberry.type + class Query: + @strawberry.field + async def departments(self) -> List[Department]: + async with async_sessionmaker() as session: + result = await session.execute(select(DepartmentModel)) + return result.scalars().all() + + mapper.finalize() + schema = strawberry.Schema(query=Query) + + # Create test data + async with async_sessionmaker(expire_on_commit=False) as session: + created_default_secondary_table_data( + session=session, + employee_model=EmployeeModel, + department_model=DepartmentModel, + ) + await session.commit() + + result = await schema.execute( + default_query_secondary_table, + context_value={ + "sqlalchemy_loader": StrawberrySQLAlchemyLoader( + async_bind_factory=async_sessionmaker + ) + }, + ) + assert result.errors is None + assert result.data == { + "departments": [ + { + "id": 10, + "name": "Department Test 1", + "employees": { + "edges": [ + { + "node": { + "id": 5, + "name": "Bill", + "role": "Doctor", + "department": { + "edges": [ + { + "node": { + "id": 10, + "name": "Department Test 1", + } + } + ] + }, + } + }, + { + "node": { + "id": 1, + "name": "John", + "role": "Developer", + "department": { + "edges": [ + { + "node": { + "id": 10, + "name": "Department Test 1", + } + } + ] + }, + } + }, + ] + }, + }, + { + "id": 3, + "name": "Department Test 2", + "employees": { + "edges": [ + { + "node": { + "id": 4, + "name": "Maria", + "role": "Teacher", + "department": { + "edges": [ + { + "node": { + "id": 3, + "name": "Department Test 2", + } + } + ] + }, + } + } + ] + }, + }, + ] + } + + +@pytest.mark.asyncio +async def test_query_with_secondary_tables_with_more_than_2_colluns_values_list( + secondary_tables_with_more_secondary_tables, base, async_engine, async_sessionmaker +): + async with async_engine.begin() as conn: + await conn.run_sync(base.metadata.create_all) + + mapper = StrawberrySQLAlchemyMapper() + EmployeeModel, DepartmentModel, BuildingModel = ( + secondary_tables_with_more_secondary_tables + ) + + @mapper.type(DepartmentModel) + class Department: + pass + + @mapper.type(EmployeeModel) + class Employee: + pass + + @mapper.type(BuildingModel) + class Building: + pass + + @strawberry.type + class Query: + @strawberry.field + async def departments(self) -> List[Department]: + async with async_sessionmaker() as session: + result = await session.execute(select(DepartmentModel)) + return result.scalars().all() + + mapper.finalize() + schema = strawberry.Schema(query=Query) + + query = """ + query { + departments { + id + name + employees { + edges { + node { + id + name + role + department { + edges { + node { + id + name + } + } + }, + building { + edges { + node { + id + name + } + } + } + } + } + } + } + } + """ + + # Create test data + async with async_sessionmaker(expire_on_commit=False) as session: + building = BuildingModel(id=2, name="Building 1") + e1, e2, e3, _, _ = created_default_secondary_table_data( + session=session, + employee_model=EmployeeModel, + department_model=DepartmentModel, + ) + building.employees.append(e1) + building.employees.append(e2) + building.employees.append(e3) + session.add(building) + await session.commit() + + result = await schema.execute( + query, + context_value={ + "sqlalchemy_loader": StrawberrySQLAlchemyLoader( + async_bind_factory=async_sessionmaker + ) + }, + ) + assert result.errors is None + assert result.data == { + "departments": [ + { + "id": 10, + "name": "Department Test 1", + "employees": { + "edges": [ + { + "node": { + "id": 5, + "name": "Bill", + "role": "Doctor", + "department": { + "edges": [ + { + "node": { + "id": 10, + "name": "Department Test 1", + } + } + ] + }, + "building": { + "edges": [ + {"node": {"id": 2, "name": "Building 1"}} + ] + }, + } + }, + { + "node": { + "id": 1, + "name": "John", + "role": "Developer", + "department": { + "edges": [ + { + "node": { + "id": 10, + "name": "Department Test 1", + } + } + ] + }, + "building": { + "edges": [ + {"node": {"id": 2, "name": "Building 1"}} + ] + }, + } + }, + ] + }, + }, + { + "id": 3, + "name": "Department Test 2", + "employees": { + "edges": [ + { + "node": { + "id": 4, + "name": "Maria", + "role": "Teacher", + "department": { + "edges": [ + { + "node": { + "id": 3, + "name": "Department Test 2", + } + } + ] + }, + "building": { + "edges": [ + {"node": {"id": 2, "name": "Building 1"}} + ] + }, + } + } + ] + }, + }, + ] + } + + +@pytest.mark.asyncio +async def test_query_with_secondary_table( + secondary_tables_with_use_list_false, base, async_engine, async_sessionmaker +): + async with async_engine.begin() as conn: + await conn.run_sync(base.metadata.create_all) + mapper = StrawberrySQLAlchemyMapper() + EmployeeModel, DepartmentModel = secondary_tables_with_use_list_false + + @mapper.type(DepartmentModel) + class Department: + pass + + @mapper.type(EmployeeModel) + class Employee: + pass + + @strawberry.type + class Query: + employees: relay.ListConnection[Employee] = connection( + sessionmaker=async_sessionmaker + ) + + mapper.finalize() + schema = strawberry.Schema(query=Query) + + query = """ + query { + employees { + edges { + node { + id + name + role + department { + edges { + node { + id + name + employees { + id + name + role + } + } + } + } + } + } + } + } + """ + + # Create test data + async with async_sessionmaker(expire_on_commit=False) as session: + department = DepartmentModel(name="Department Test") + e1 = EmployeeModel(name="John", role="Developer") + e2 = EmployeeModel(name="Bill", role="Doctor") + e3 = EmployeeModel(name="Maria", role="Teacher") + e1.department.append(department) + session.add_all([department, e1, e2, e3]) + await session.commit() + + result = await schema.execute( + query, + context_value={ + "sqlalchemy_loader": StrawberrySQLAlchemyLoader( + async_bind_factory=async_sessionmaker + ) + }, + ) + assert result.errors is None + assert result.data == { + "employees": { + "edges": [ + { + "node": { + "id": 1, + "name": "John", + "role": "Developer", + "department": { + "edges": [ + { + "node": { + "id": 1, + "name": "Department Test", + "employees": { + "id": 1, + "name": "John", + "role": "Developer", + }, + } + } + ] + }, + } + }, + { + "node": { + "id": 2, + "name": "Bill", + "role": "Doctor", + "department": {"edges": []}, + } + }, + { + "node": { + "id": 3, + "name": "Maria", + "role": "Teacher", + "department": {"edges": []}, + } + }, + ] + } + } + + +@pytest.mark.asyncio +async def test_query_with_secondary_table_without_list_connection( + secondary_tables_with_use_list_false, base, async_engine, async_sessionmaker +): + async with async_engine.begin() as conn: + await conn.run_sync(base.metadata.create_all) + + mapper = StrawberrySQLAlchemyMapper() + EmployeeModel, DepartmentModel = secondary_tables_with_use_list_false + + @mapper.type(DepartmentModel) + class Department: + pass + + @mapper.type(EmployeeModel) + class Employee: + pass + + @strawberry.type + class Query: + @strawberry.field + async def employees(self) -> List[Employee]: + async with async_sessionmaker() as session: + result = await session.execute(select(EmployeeModel)) + return result.scalars().all() + + mapper.finalize() + schema = strawberry.Schema(query=Query) + + query = """ + query { + employees { + id + name + role + department { + edges { + node { + id + name + employees { + id + name + role + } + } + } + } + } + } + """ + + # Create test data + async with async_sessionmaker(expire_on_commit=False) as session: + department = DepartmentModel(name="Department Test") + e1 = EmployeeModel(name="John", role="Developer") + e2 = EmployeeModel(name="Bill", role="Doctor") + e3 = EmployeeModel(name="Maria", role="Teacher") + e1.department.append(department) + session.add_all([department, e1, e2, e3]) + await session.commit() + + result = await schema.execute( + query, + context_value={ + "sqlalchemy_loader": StrawberrySQLAlchemyLoader( + async_bind_factory=async_sessionmaker + ) + }, + ) + assert result.errors is None + assert result.data == { + "employees": [ + { + "id": 1, + "name": "John", + "role": "Developer", + "department": { + "edges": [ + { + "node": { + "id": 1, + "name": "Department Test", + "employees": { + "id": 1, + "name": "John", + "role": "Developer", + }, + } + } + ] + }, + }, + { + "id": 2, + "name": "Bill", + "role": "Doctor", + "department": {"edges": []}, + }, + { + "id": 3, + "name": "Maria", + "role": "Teacher", + "department": {"edges": []}, + }, + ] + } + + +@pytest.mark.asyncio +async def test_query_with_secondary_table_with_values_list_and_normal_relationship( + secondary_tables_with_normal_relationship, base, async_engine, async_sessionmaker +): + async with async_engine.begin() as conn: + await conn.run_sync(base.metadata.create_all) + + mapper = StrawberrySQLAlchemyMapper() + EmployeeModel, DepartmentModel, BuildingModel = ( + secondary_tables_with_normal_relationship + ) + + @mapper.type(DepartmentModel) + class Department: + pass + + @mapper.type(EmployeeModel) + class Employee: + pass + + @mapper.type(BuildingModel) + class Building: + pass + + @strawberry.type + class Query: + @strawberry.field + async def departments(self) -> List[Department]: + async with async_sessionmaker() as session: + result = await session.execute(select(DepartmentModel)) + return result.scalars().all() + + mapper.finalize() + schema = strawberry.Schema(query=Query) + + query = """ + query { + departments { + id + name + employees { + edges { + node { + id + name + role + department { + edges { + node { + id + name + } + } + }, + building { + id + name + } + } + } + } + } + } + """ + + # Create test data + async with async_sessionmaker(expire_on_commit=False) as session: + e1, e2, e3, _, _ = created_default_secondary_table_data( + session=session, + employee_model=EmployeeModel, + department_model=DepartmentModel, + ) + building = BuildingModel(id=2, name="Building 1") + building.employees.append(e1) + building.employees.append(e2) + building.employees.append(e3) + session.add(building) + await session.commit() + + result = await schema.execute( + query, + context_value={ + "sqlalchemy_loader": StrawberrySQLAlchemyLoader( + async_bind_factory=async_sessionmaker + ) + }, + ) + assert result.errors is None + assert result.data == { + "departments": [ + { + "id": 10, + "name": "Department Test 1", + "employees": { + "edges": [ + { + "node": { + "id": 5, + "name": "Bill", + "role": "Doctor", + "department": { + "edges": [ + { + "node": { + "id": 10, + "name": "Department Test 1", + } + } + ] + }, + "building": {"id": 2, "name": "Building 1"}, + } + }, + { + "node": { + "id": 1, + "name": "John", + "role": "Developer", + "department": { + "edges": [ + { + "node": { + "id": 10, + "name": "Department Test 1", + } + } + ] + }, + "building": {"id": 2, "name": "Building 1"}, + } + }, + ] + }, + }, + { + "id": 3, + "name": "Department Test 2", + "employees": { + "edges": [ + { + "node": { + "id": 4, + "name": "Maria", + "role": "Teacher", + "department": { + "edges": [ + { + "node": { + "id": 3, + "name": "Department Test 2", + } + } + ] + }, + "building": {"id": 2, "name": "Building 1"}, + } + } + ] + }, + }, + ] + }