Skip to content

Commit d2a699a

Browse files
committed
feat: support enum.Enum subtypes
1 parent 34d521a commit d2a699a

File tree

11 files changed

+64
-2
lines changed

11 files changed

+64
-2
lines changed

src/typelib/marshals/api.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,8 @@ def __call__(self, val: T) -> serdes.MarshalledValueT:
114114
inspection.isliteral: routines.LiteralMarshaller,
115115
# Special handler for Unions...
116116
inspection.isuniontype: routines.UnionMarshaller,
117+
# Special handling for Enums
118+
inspection.isenumtype: routines.EnumMarshaller,
117119
# Non-intersecting types (order doesn't matter here.
118120
inspection.isdatetimetype: routines.DateTimeMarshaller,
119121
inspection.isdatetype: routines.DateMarshaller,

src/typelib/marshals/routines.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import contextlib
77
import datetime
88
import decimal
9+
import enum
910
import fractions
1011
import pathlib
1112
import re
@@ -41,6 +42,7 @@
4142
"SubscriptedMappingMarshaller",
4243
"FixedTupleMarshaller",
4344
"StructuredTypeMarshaller",
45+
"EnumMarshaller",
4446
)
4547

4648

@@ -151,6 +153,20 @@ def __call__(self, val: T) -> str:
151153
PathT = tp.TypeVar("PathT", bound=pathlib.Path)
152154
PathMarshaller = ToStringMarshaller[PathT]
153155

156+
EnumT = tp.TypeVar("EnumT", bound=enum.Enum)
157+
158+
159+
class EnumMarshaller(AbstractMarshaller[EnumT], tp.Generic[EnumT]):
160+
"""A marshaller that converts an [`enum.Enum`][] instance to its assigned value."""
161+
162+
def __call__(self, val: EnumT) -> serdes.MarshalledValueT:
163+
"""Marshal an [`enum.Enum`][] instance into a [`serdes.MarshalledValueT`][].
164+
165+
Args:
166+
val: The enum instance to marshal.
167+
"""
168+
return val.value
169+
154170

155171
PatternT = tp.TypeVar("PatternT", bound=re.Pattern)
156172

src/typelib/py/inspection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -903,7 +903,7 @@ def isenumtype(obj: type) -> compat.TypeIs[type[enum.Enum]]:
903903
>>> isenumtype(FooNum)
904904
True
905905
"""
906-
return issubclass(obj, enum.Enum)
906+
return _safe_issubclass(obj, enum.Enum)
907907

908908

909909
@compat.cache

src/typelib/unmarshals/api.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,8 @@ def __call__(self, val: tp.Any) -> T:
108108
inspection.isliteral: routines.LiteralUnmarshaller,
109109
# Special handler for Unions...
110110
inspection.isuniontype: routines.UnionUnmarshaller,
111+
# Special handling for Enums
112+
inspection.isenumtype: routines.EnumUnmarshaller,
111113
# Non-intersecting types (order doesn't matter here.
112114
inspection.isdatetimetype: routines.DateTimeUnmarshaller,
113115
inspection.isdatetype: routines.DateUnmarshaller,

src/typelib/unmarshals/routines.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import contextlib
77
import datetime
88
import decimal
9+
import enum
910
import fractions
1011
import numbers
1112
import pathlib
@@ -46,6 +47,7 @@
4647
"SubscriptedMappingUnmarshaller",
4748
"FixedTupleUnmarshaller",
4849
"StructuredTypeUnmarshaller",
50+
"EnumUnmarshaller",
4951
)
5052

5153

@@ -537,7 +539,7 @@ def __call__(self, val: tp.Any) -> UUIDT:
537539

538540

539541
class PatternUnmarshaller(AbstractUnmarshaller[PatternT], tp.Generic[PatternT]):
540-
"""Unmarshaller that converts an input to a[`re.Pattern`][].
542+
"""Unmarshaller that converts an input to a [`re.Pattern`][].
541543
542544
Note:
543545
You can't instantiate a [`re.Pattern`][] directly, so we don't have a good
@@ -596,6 +598,9 @@ def __call__(self, val: tp.Any) -> T:
596598
MappingUnmarshaller = CastUnmarshaller[tp.Mapping]
597599
IterableUnmarshaller = CastUnmarshaller[tp.Iterable]
598600

601+
EnumT = tp.TypeVar("EnumT", bound=enum.Enum)
602+
EnumUnmarshaller = CastUnmarshaller[EnumT]
603+
599604

600605
LiteralT = tp.TypeVar("LiteralT")
601606

tests/models.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import dataclasses
4+
import enum
45
import typing
56

67

@@ -49,3 +50,7 @@ class NTuple(typing.NamedTuple):
4950
class TDict(typing.TypedDict):
5051
field: str
5152
value: int
53+
54+
55+
class GivenEnum(enum.Enum):
56+
one = "one"

tests/unit/marshals/test_api.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,11 @@
137137
),
138138
expected_output={"indirect": {"cycle": {"indirect": {"cycle": None}}}},
139139
),
140+
enum_type=dict(
141+
given_type=models.GivenEnum,
142+
given_input=models.GivenEnum.one,
143+
expected_output=models.GivenEnum.one.value,
144+
),
140145
)
141146
def test_marshal(given_type, given_input, expected_output):
142147
# When

tests/unit/marshals/test_routines.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -449,3 +449,14 @@ def test_invalid_union():
449449
# When/Then
450450
with pytest.raises(expected_exception):
451451
given_marshaller(given_value)
452+
453+
454+
def test_enum_unmarshaller():
455+
# Given
456+
given_unmarshaller = routines.EnumMarshaller(models.GivenEnum, {})
457+
given_value = models.GivenEnum.one
458+
expected_value = models.GivenEnum.one.value
459+
# When
460+
unmarshalled = given_unmarshaller(given_value)
461+
# Then
462+
assert unmarshalled == expected_value
File renamed without changes.

tests/unit/unmarshal/test_api.py renamed to tests/unit/unmarshals/test_api.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,11 @@
137137
given_input='["1", "2"]',
138138
expected_output=[1, 2],
139139
),
140+
enum_type=dict(
141+
given_type=models.GivenEnum,
142+
given_input="one",
143+
expected_output=models.GivenEnum.one,
144+
),
140145
)
141146
def test_unmarshal(given_type, given_input, expected_output):
142147
# When

0 commit comments

Comments
 (0)