@@ -433,6 +433,35 @@ def class_schema(
433
433
return _internal_class_schema (clazz , base_schema )
434
434
435
435
436
+ class _TypeMapping :
437
+ """Helper for looking up field types in a chained list of TYPE_MAPPINGs"""
438
+
439
+ def __init__ (self , * mappings : Mapping [Any , Type [marshmallow .fields .Field ]]) -> None :
440
+ self .mappings = mappings
441
+
442
+ _Field = TypeVar ("_Field" , bound = marshmallow .fields .Field )
443
+
444
+ @overload
445
+ def get (self , typ : object , default : Type [_Field ]) -> Type [_Field ]:
446
+ ...
447
+
448
+ @overload
449
+ def get (
450
+ self , typ : object , default : None = None
451
+ ) -> Optional [Type [marshmallow .fields .Field ]]:
452
+ ...
453
+
454
+ def get (
455
+ self , typ : object , default : Optional [Type [_Field ]] = None
456
+ ) -> Optional [Type [marshmallow .fields .Field ]]:
457
+ for mapping in self .mappings :
458
+ try :
459
+ return mapping [typ ]
460
+ except KeyError :
461
+ pass
462
+ return default
463
+
464
+
436
465
@dataclasses .dataclass
437
466
class _SchemaContext :
438
467
"""Global context for an invocation of class_schema."""
@@ -442,6 +471,18 @@ class _SchemaContext:
442
471
base_schema : Optional [Type [marshmallow .Schema ]] = None
443
472
seen_classes : Dict [type , str ] = dataclasses .field (default_factory = dict )
444
473
474
+ def get_type_mapping (
475
+ self , include_marshmallow_default : bool = False
476
+ ) -> _TypeMapping :
477
+ default_mapping = marshmallow .Schema .TYPE_MAPPING
478
+ if self .base_schema is not None :
479
+ mappings = [self .base_schema .TYPE_MAPPING ]
480
+ if include_marshmallow_default :
481
+ mappings .append (default_mapping )
482
+ else :
483
+ mappings = [default_mapping ]
484
+ return _TypeMapping (* mappings )
485
+
445
486
def __enter__ (self ) -> "_SchemaContext" :
446
487
_schema_ctx_stack .push (self )
447
488
return self
@@ -534,10 +575,9 @@ def _internal_class_schema(
534
575
535
576
def _field_by_type (typ : Union [type , Any ]) -> Optional [Type [marshmallow .fields .Field ]]:
536
577
# FIXME: remove this function
537
- base_schema = _schema_ctx_stack .top .base_schema
538
- return (
539
- base_schema and base_schema .TYPE_MAPPING .get (typ )
540
- ) or marshmallow .Schema .TYPE_MAPPING .get (typ )
578
+ schema_ctx = _schema_ctx_stack .top
579
+ type_mapping = schema_ctx .get_type_mapping (include_marshmallow_default = True )
580
+ return type_mapping .get (typ )
541
581
542
582
543
583
def _field_by_supertype (
@@ -605,15 +645,12 @@ def _field_for_generic_type(
605
645
arguments = typing_inspect .get_args (typ , True )
606
646
if origin :
607
647
# Override base_schema.TYPE_MAPPING to change the class used for generic types below
608
- base_schema = _schema_ctx_stack .top . base_schema
609
- type_mapping = base_schema . TYPE_MAPPING if base_schema else {}
648
+ schema_ctx = _schema_ctx_stack .top
649
+ type_mapping = schema_ctx . get_type_mapping ()
610
650
611
651
if origin in (list , List ):
612
652
child_type = _field_for_schema (arguments [0 ])
613
- list_type = cast (
614
- Type [marshmallow .fields .List ],
615
- type_mapping .get (List , marshmallow .fields .List ),
616
- )
653
+ list_type = type_mapping .get (List , default = marshmallow .fields .List )
617
654
return list_type (child_type , ** metadata )
618
655
if origin in (collections .abc .Sequence , Sequence ) or (
619
656
origin in (tuple , Tuple )
@@ -640,15 +677,10 @@ def _field_for_generic_type(
640
677
)
641
678
if origin in (tuple , Tuple ):
642
679
children = tuple (_field_for_schema (arg ) for arg in arguments )
643
- tuple_type = cast (
644
- Type [marshmallow .fields .Tuple ],
645
- type_mapping .get ( # type:ignore[call-overload]
646
- Tuple , marshmallow .fields .Tuple
647
- ),
648
- )
680
+ tuple_type = type_mapping .get (Tuple , default = marshmallow .fields .Tuple )
649
681
return tuple_type (children , ** metadata )
650
682
elif origin in (dict , Dict , collections .abc .Mapping , Mapping ):
651
- dict_type = type_mapping .get (Dict , marshmallow .fields .Dict )
683
+ dict_type = type_mapping .get (Dict , default = marshmallow .fields .Dict )
652
684
return dict_type (
653
685
keys = _field_for_schema (arguments [0 ]),
654
686
values = _field_for_schema (arguments [1 ]),
0 commit comments