From 6f629d79ba587ca97843871f3497be363bcdc6ff Mon Sep 17 00:00:00 2001 From: Natarajan Krishnaswami Date: Wed, 5 Feb 2025 18:27:00 -0500 Subject: [PATCH 1/4] Minimal change to allow `attribute_keyed_dict` + test --- sqlmodel/_compat.py | 3 ++ tests/test_attribute_keyed_dict.py | 47 ++++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+) create mode 100644 tests/test_attribute_keyed_dict.py diff --git a/sqlmodel/_compat.py b/sqlmodel/_compat.py index 4e80cdc374..a23d544f4b 100644 --- a/sqlmodel/_compat.py +++ b/sqlmodel/_compat.py @@ -156,6 +156,9 @@ def get_relationship_to( # If a list, then also get the real field elif origin is list: use_annotation = get_args(annotation)[0] + # If a dict, then use the value type + elif origin is dict: + use_annotation = get_args(annotation)[1] return get_relationship_to( name=name, rel_info=rel_info, annotation=use_annotation diff --git a/tests/test_attribute_keyed_dict.py b/tests/test_attribute_keyed_dict.py new file mode 100644 index 0000000000..6dfe5ffeab --- /dev/null +++ b/tests/test_attribute_keyed_dict.py @@ -0,0 +1,47 @@ +from enum import StrEnum + +from sqlalchemy.orm.collections import attribute_keyed_dict +from sqlmodel import Field, Index, Relationship, Session, SQLModel, create_engine + + +def test_attribute_keyed_dict_works(clear_sqlmodel): + class Color(StrEnum): + Orange = "Orange" + Blue = "Blue" + + class Child(SQLModel, table=True): + __tablename__ = "children" + __table_args__ = ( + Index("ix_children_parent_id_color", "parent_id", "color", unique=True), + ) + + id: int | None = Field(primary_key=True, default=None) + parent_id: int = Field(foreign_key="parents.id") + color: Color + value: int + + class Parent(SQLModel, table=True): + __tablename__ = "parents" + + id: int | None = Field(primary_key=True, default=None) + children_by_color: dict[Color, Child] = Relationship( + sa_relationship_kwargs={"collection_class": attribute_keyed_dict("color")} + ) + + engine = create_engine("sqlite://") + SQLModel.metadata.create_all(engine) + with Session(engine) as session: + parent = Parent() + session.add(parent) + session.commit() + session.refresh(parent) + session.add(Child(parent_id=parent.id, color=Color.Orange, value=1)) + session.add(Child(parent_id=parent.id, color=Color.Blue, value=2)) + session.commit() + session.refresh(parent) + assert parent.children_by_color[Color.Orange].parent_id == parent.id + assert parent.children_by_color[Color.Orange].color == Color.Orange + assert parent.children_by_color[Color.Orange].value == 1 + assert parent.children_by_color[Color.Blue].parent_id == parent.id + assert parent.children_by_color[Color.Blue].color == Color.Blue + assert parent.children_by_color[Color.Blue].value == 2 From 7f1d08587b6c90070f5e0f23d59373e7fae71d36 Mon Sep 17 00:00:00 2001 From: Natarajan Krishnaswami Date: Thu, 6 Feb 2025 17:11:22 -0500 Subject: [PATCH 2/4] Remove `StrEnum` --- tests/test_attribute_keyed_dict.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_attribute_keyed_dict.py b/tests/test_attribute_keyed_dict.py index 6dfe5ffeab..5e8c61ba28 100644 --- a/tests/test_attribute_keyed_dict.py +++ b/tests/test_attribute_keyed_dict.py @@ -1,11 +1,11 @@ -from enum import StrEnum +from enum import Enum from sqlalchemy.orm.collections import attribute_keyed_dict from sqlmodel import Field, Index, Relationship, Session, SQLModel, create_engine def test_attribute_keyed_dict_works(clear_sqlmodel): - class Color(StrEnum): + class Color(str, Enum): Orange = "Orange" Blue = "Blue" From ad7f6bbc1a96ca48fa50307ba4c49b99b0392d32 Mon Sep 17 00:00:00 2001 From: Natarajan Krishnaswami Date: Thu, 6 Feb 2025 17:30:45 -0500 Subject: [PATCH 3/4] Remove type union pipe syntax --- tests/test_attribute_keyed_dict.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/test_attribute_keyed_dict.py b/tests/test_attribute_keyed_dict.py index 5e8c61ba28..a55a927f4a 100644 --- a/tests/test_attribute_keyed_dict.py +++ b/tests/test_attribute_keyed_dict.py @@ -1,4 +1,5 @@ from enum import Enum +from typing import Optional from sqlalchemy.orm.collections import attribute_keyed_dict from sqlmodel import Field, Index, Relationship, Session, SQLModel, create_engine @@ -15,7 +16,7 @@ class Child(SQLModel, table=True): Index("ix_children_parent_id_color", "parent_id", "color", unique=True), ) - id: int | None = Field(primary_key=True, default=None) + id: Optional[int] = Field(primary_key=True, default=None) parent_id: int = Field(foreign_key="parents.id") color: Color value: int @@ -23,7 +24,7 @@ class Child(SQLModel, table=True): class Parent(SQLModel, table=True): __tablename__ = "parents" - id: int | None = Field(primary_key=True, default=None) + id: Optional[int] = Field(primary_key=True, default=None) children_by_color: dict[Color, Child] = Relationship( sa_relationship_kwargs={"collection_class": attribute_keyed_dict("color")} ) From 563886a03e7e5e87af95a2f7ee4c079f3127ed3d Mon Sep 17 00:00:00 2001 From: Natarajan Krishnaswami Date: Thu, 6 Feb 2025 19:02:28 -0500 Subject: [PATCH 4/4] Use `Dict[]` instead of `dict[]` --- tests/test_attribute_keyed_dict.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_attribute_keyed_dict.py b/tests/test_attribute_keyed_dict.py index a55a927f4a..9d06196396 100644 --- a/tests/test_attribute_keyed_dict.py +++ b/tests/test_attribute_keyed_dict.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Optional +from typing import Dict, Optional from sqlalchemy.orm.collections import attribute_keyed_dict from sqlmodel import Field, Index, Relationship, Session, SQLModel, create_engine @@ -25,7 +25,7 @@ class Parent(SQLModel, table=True): __tablename__ = "parents" id: Optional[int] = Field(primary_key=True, default=None) - children_by_color: dict[Color, Child] = Relationship( + children_by_color: Dict[Color, Child] = Relationship( sa_relationship_kwargs={"collection_class": attribute_keyed_dict("color")} )