Skip to content
Open
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions sqlmodel/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,16 @@ def get_relationship_to(
# If a list, then also get the real field
elif origin is list:
use_annotation = get_args(annotation)[0]
# If a dict, then use the value (second) type argument
elif origin is dict:
args = get_args(annotation)
if len(args) != 2:
raise ValueError(
f"Dict/Mapping relationship field '{name}' has {len(args)} "
"type arguments. Exactly two required (e.g., dict[str, "
"Model])"
)
use_annotation = args[1]

return get_relationship_to(
name=name, rel_info=rel_info, annotation=use_annotation
Expand Down
109 changes: 109 additions & 0 deletions tests/test_attribute_keyed_dict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import re
from enum import Enum
from typing import Dict, Optional

import pytest
from sqlalchemy.orm.collections import attribute_keyed_dict
from sqlmodel import Field, Relationship, Session, SQLModel, create_engine

from tests.conftest import needs_py310, needs_pydanticv2


def test_attribute_keyed_dict_works(clear_sqlmodel):
class Color(str, Enum):
Orange = "Orange"
Blue = "Blue"

class Child(SQLModel, table=True):
__tablename__ = "children"

id: Optional[int] = Field(primary_key=True, default=None)
parent_id: int = Field(foreign_key="parents.id")
color: Color
value: int

class Parent(SQLModel, table=True):
__tablename__ = "parents"

id: Optional[int] = Field(primary_key=True, default=None)
children_by_color: Dict[Color, Child] = Relationship(
sa_relationship_kwargs={"collection_class": attribute_keyed_dict("color")}
)

engine = create_engine("sqlite://")
SQLModel.metadata.create_all(engine)
with Session(engine) as session:
parent = Parent()
session.add(parent)
session.commit()
session.refresh(parent)
session.add(Child(parent_id=parent.id, color=Color.Orange, value=1))
session.add(Child(parent_id=parent.id, color=Color.Blue, value=2))
session.commit()
session.refresh(parent)
assert parent.children_by_color[Color.Orange].parent_id == parent.id
assert parent.children_by_color[Color.Orange].color == Color.Orange
assert parent.children_by_color[Color.Orange].value == 1
assert parent.children_by_color[Color.Blue].parent_id == parent.id
assert parent.children_by_color[Color.Blue].color == Color.Blue
assert parent.children_by_color[Color.Blue].value == 2


# typing.Dict throws if it receives the wrong number of type arguments, but dict
# (3.10+) does not; and Pydantic v1 fails to process models with dicts with no
# type arguments.
@needs_pydanticv2
@needs_py310
def test_dict_relationship_throws_on_missing_annotation_arg(clear_sqlmodel):
class Color(str, Enum):
Orange = "Orange"
Blue = "Blue"

class Child(SQLModel, table=True):
__tablename__ = "children"

id: Optional[int] = Field(primary_key=True, default=None)
parent_id: int = Field(foreign_key="parents.id")
color: Color
value: int

error_msg_fmt = "Dict/Mapping relationship field 'children_by_color' has {count} type arguments. Exactly two required (e.g., dict[str, Model])"

# No type args
with pytest.raises(ValueError, match=re.escape(error_msg_fmt.format(count=0))):

class Parent(SQLModel, table=True):
__tablename__ = "parents"

id: Optional[int] = Field(primary_key=True, default=None)
children_by_color: dict[()] = Relationship(
sa_relationship_kwargs={
"collection_class": attribute_keyed_dict("color")
}
)

# One type arg
with pytest.raises(ValueError, match=re.escape(error_msg_fmt.format(count=1))):

class Parent(SQLModel, table=True):
__tablename__ = "parents"

id: Optional[int] = Field(primary_key=True, default=None)
children_by_color: dict[Color] = Relationship(
sa_relationship_kwargs={
"collection_class": attribute_keyed_dict("color")
}
)

# Three type args
with pytest.raises(ValueError, match=re.escape(error_msg_fmt.format(count=3))):

class Parent(SQLModel, table=True):
__tablename__ = "parents"

id: Optional[int] = Field(primary_key=True, default=None)
children_by_color: dict[Color, Child, str] = Relationship(
sa_relationship_kwargs={
"collection_class": attribute_keyed_dict("color")
}
)
Loading