Skip to content

Commit 0820fce

Browse files
committed
refactor: add _SchemaContext.get_type_mapping
1 parent 127ed10 commit 0820fce

File tree

1 file changed

+49
-17
lines changed

1 file changed

+49
-17
lines changed

marshmallow_dataclass/__init__.py

Lines changed: 49 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -433,6 +433,35 @@ def class_schema(
433433
return _internal_class_schema(clazz, base_schema)
434434

435435

436+
class _TypeMapping:
437+
"""Helper for looking up field types in a chained list of TYPE_MAPPINGs"""
438+
439+
def __init__(self, *mappings: Mapping[Any, Type[marshmallow.fields.Field]]) -> None:
440+
self.mappings = mappings
441+
442+
_Field = TypeVar("_Field", bound=marshmallow.fields.Field)
443+
444+
@overload
445+
def get(self, typ: object, default: Type[_Field]) -> Type[_Field]:
446+
...
447+
448+
@overload
449+
def get(
450+
self, typ: object, default: None = None
451+
) -> Optional[Type[marshmallow.fields.Field]]:
452+
...
453+
454+
def get(
455+
self, typ: object, default: Optional[Type[_Field]] = None
456+
) -> Optional[Type[marshmallow.fields.Field]]:
457+
for mapping in self.mappings:
458+
try:
459+
return mapping[typ]
460+
except KeyError:
461+
pass
462+
return default
463+
464+
436465
@dataclasses.dataclass
437466
class _SchemaContext:
438467
"""Global context for an invocation of class_schema."""
@@ -442,6 +471,18 @@ class _SchemaContext:
442471
base_schema: Optional[Type[marshmallow.Schema]] = None
443472
seen_classes: Dict[type, str] = dataclasses.field(default_factory=dict)
444473

474+
def get_type_mapping(
475+
self, include_marshmallow_default: bool = False
476+
) -> _TypeMapping:
477+
default_mapping = marshmallow.Schema.TYPE_MAPPING
478+
if self.base_schema is not None:
479+
mappings = [self.base_schema.TYPE_MAPPING]
480+
if include_marshmallow_default:
481+
mappings.append(default_mapping)
482+
else:
483+
mappings = [default_mapping]
484+
return _TypeMapping(*mappings)
485+
445486
def __enter__(self) -> "_SchemaContext":
446487
_schema_ctx_stack.push(self)
447488
return self
@@ -534,10 +575,9 @@ def _internal_class_schema(
534575

535576
def _field_by_type(typ: Union[type, Any]) -> Optional[Type[marshmallow.fields.Field]]:
536577
# FIXME: remove this function
537-
base_schema = _schema_ctx_stack.top.base_schema
538-
return (
539-
base_schema and base_schema.TYPE_MAPPING.get(typ)
540-
) or marshmallow.Schema.TYPE_MAPPING.get(typ)
578+
schema_ctx = _schema_ctx_stack.top
579+
type_mapping = schema_ctx.get_type_mapping(include_marshmallow_default=True)
580+
return type_mapping.get(typ)
541581

542582

543583
def _field_by_supertype(
@@ -605,15 +645,12 @@ def _field_for_generic_type(
605645
arguments = typing_inspect.get_args(typ, True)
606646
if origin:
607647
# Override base_schema.TYPE_MAPPING to change the class used for generic types below
608-
base_schema = _schema_ctx_stack.top.base_schema
609-
type_mapping = base_schema.TYPE_MAPPING if base_schema else {}
648+
schema_ctx = _schema_ctx_stack.top
649+
type_mapping = schema_ctx.get_type_mapping()
610650

611651
if origin in (list, List):
612652
child_type = _field_for_schema(arguments[0])
613-
list_type = cast(
614-
Type[marshmallow.fields.List],
615-
type_mapping.get(List, marshmallow.fields.List),
616-
)
653+
list_type = type_mapping.get(List, default=marshmallow.fields.List)
617654
return list_type(child_type, **metadata)
618655
if origin in (collections.abc.Sequence, Sequence) or (
619656
origin in (tuple, Tuple)
@@ -640,15 +677,10 @@ def _field_for_generic_type(
640677
)
641678
if origin in (tuple, Tuple):
642679
children = tuple(_field_for_schema(arg) for arg in arguments)
643-
tuple_type = cast(
644-
Type[marshmallow.fields.Tuple],
645-
type_mapping.get( # type:ignore[call-overload]
646-
Tuple, marshmallow.fields.Tuple
647-
),
648-
)
680+
tuple_type = type_mapping.get(Tuple, default=marshmallow.fields.Tuple)
649681
return tuple_type(children, **metadata)
650682
elif origin in (dict, Dict, collections.abc.Mapping, Mapping):
651-
dict_type = type_mapping.get(Dict, marshmallow.fields.Dict)
683+
dict_type = type_mapping.get(Dict, default=marshmallow.fields.Dict)
652684
return dict_type(
653685
keys=_field_for_schema(arguments[0]),
654686
values=_field_for_schema(arguments[1]),

0 commit comments

Comments
 (0)