Skip to content

Commit 6bd7023

Browse files
authored
Merge pull request #3 from seandstewart/seandstewart/optional-types
fix: correct handling optional types
2 parents cba3aa8 + 79e431a commit 6bd7023

File tree

6 files changed

+72
-15
lines changed

6 files changed

+72
-15
lines changed

src/typelib/marshals/routines.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,7 @@ class UnionMarshaller(AbstractMarshaller[UnionT], tp.Generic[UnionT]):
262262
- [`UnionUnmarshaller`][typelib.unmarshals.routines.UnionUnmarshaller]
263263
"""
264264

265-
__slots__ = ("stack", "ordered_routines")
265+
__slots__ = ("stack", "ordered_routines", "nullable")
266266

267267
def __init__(self, t: type[UnionT], context: ContextT, *, var: str | None = None):
268268
"""Constructor.
@@ -274,19 +274,25 @@ def __init__(self, t: type[UnionT], context: ContextT, *, var: str | None = None
274274
"""
275275
super().__init__(t, context, var=var)
276276
self.stack = inspection.args(t)
277+
self.nullable = inspection.isoptionaltype(t)
277278
self.ordered_routines = [self.context[typ] for typ in self.stack]
278279

279280
def __call__(self, val: UnionT) -> serdes.MarshalledValueT:
280-
"""Unmarshal a value into the bound `UnionT`.
281+
"""Marshal a value into the bound `UnionT`.
281282
282283
Args:
283284
val: The input value to unmarshal.
284285
285286
Raises:
286287
ValueError: If `val` cannot be marshalled via any member type.
287288
"""
289+
if self.nullable and val is None:
290+
return val
291+
288292
for routine in self.ordered_routines:
289-
with contextlib.suppress(ValueError, TypeError, SyntaxError):
293+
with contextlib.suppress(
294+
ValueError, TypeError, SyntaxError, AttributeError
295+
):
290296
unmarshalled = routine(val)
291297
return unmarshalled
292298

src/typelib/unmarshals/routines.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -678,6 +678,9 @@ def __init__(self, t: type[UnionT], context: ContextT, *, var: str | None = None
678678
"""
679679
super().__init__(t, context, var=var)
680680
self.stack = inspection.args(t)
681+
if inspection.isoptionaltype(t):
682+
self.stack = (self.stack[-1], *self.stack[:-1])
683+
681684
self.ordered_routines = [self.context[typ] for typ in self.stack]
682685

683686
def __call__(self, val: tp.Any) -> UnionT:
@@ -690,7 +693,9 @@ def __call__(self, val: tp.Any) -> UnionT:
690693
ValueError: If `val` cannot be unmarshalled into any member type.
691694
"""
692695
for routine in self.ordered_routines:
693-
with contextlib.suppress(ValueError, TypeError, SyntaxError):
696+
with contextlib.suppress(
697+
ValueError, TypeError, SyntaxError, AttributeError
698+
):
694699
unmarshalled = routine(val)
695700
return unmarshalled
696701

tests/unit/marshals/test_api.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,11 @@
4242
given_input=2,
4343
expected_output=2,
4444
),
45+
optional_none=dict(
46+
given_type=typing.Optional[typing.Union[int, str]],
47+
given_input=None,
48+
expected_output=None,
49+
),
4550
datetime=dict(
4651
given_type=datetime.datetime,
4752
given_input=datetime.datetime.fromtimestamp(0, datetime.timezone.utc),

tests/unit/marshals/test_routines.py

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ def test_date_marshaller(given_input, expected_output):
126126
expected_output=datetime.datetime(1969, 12, 31).isoformat(),
127127
),
128128
)
129-
def test_datetime_unmarshaller(given_input, expected_output):
129+
def test_datetime_marshaller(given_input, expected_output):
130130
# Given
131131
given_marshaller = routines.DateTimeMarshaller(datetime.datetime, {})
132132
# When
@@ -141,7 +141,7 @@ def test_datetime_unmarshaller(given_input, expected_output):
141141
expected_output="00:00:00+00:00",
142142
),
143143
)
144-
def test_time_unmarshaller(given_input, expected_output):
144+
def test_time_marshaller(given_input, expected_output):
145145
# Given
146146
given_marshaller = routines.TimeMarshaller(datetime.time, {})
147147
# When
@@ -153,7 +153,7 @@ def test_time_unmarshaller(given_input, expected_output):
153153
@pytest.mark.suite(
154154
timedelta=dict(given_input=datetime.timedelta(seconds=1), expected_output="PT1S"),
155155
)
156-
def test_timedelta_unmarshaller(given_input, expected_output):
156+
def test_timedelta_marshaller(given_input, expected_output):
157157
# Given
158158
given_marshaller = routines.TimeDeltaMarshaller(datetime.timedelta, {})
159159
# When
@@ -187,7 +187,7 @@ def test_mapping_marshaller(given_input, expected_output):
187187
expected_output=["field", "value"],
188188
),
189189
)
190-
def test_iterable_unmarshaller(given_input, expected_output):
190+
def test_iterable_marshaller(given_input, expected_output):
191191
# Given
192192
given_marshaller = routines.IterableMarshaller(typing.Iterable, {})
193193
# When
@@ -259,8 +259,26 @@ def test_literal_marshaller(given_input, given_literal, given_context, expected_
259259
},
260260
expected_output=1,
261261
),
262+
optional_date_none=dict(
263+
given_input=None,
264+
given_union=typing.Optional[datetime.date],
265+
given_context={
266+
datetime.date: routines.DateMarshaller(datetime.date, {}),
267+
type(None): routines.NoOpMarshaller(type(None), {}),
268+
},
269+
expected_output=None,
270+
),
271+
optional_date_date=dict(
272+
given_input=datetime.date.today(),
273+
given_union=typing.Optional[datetime.date],
274+
given_context={
275+
datetime.date: routines.DateMarshaller(datetime.date, {}),
276+
type(None): routines.NoOpMarshaller(type(None), {}),
277+
},
278+
expected_output=datetime.date.today().isoformat(),
279+
),
262280
)
263-
def test_union_unmarshaller(given_input, given_union, given_context, expected_output):
281+
def test_union_marshaller(given_input, given_union, given_context, expected_output):
264282
# Given
265283
given_marshaller = routines.UnionMarshaller(given_union, given_context)
266284
# When
@@ -280,7 +298,7 @@ def test_union_unmarshaller(given_input, given_union, given_context, expected_ou
280298
expected_output={"field": 1},
281299
),
282300
)
283-
def test_subscripted_mapping_unmarshaller(
301+
def test_subscripted_mapping_marshaller(
284302
given_input, given_mapping, given_context, expected_output
285303
):
286304
# Given
@@ -373,7 +391,7 @@ def test_subscripted_iterable_marshaller(
373391
expected_output=["field", 1],
374392
),
375393
)
376-
def test_fixed_tuple_unmarshaller(
394+
def test_fixed_tuple_marshaller(
377395
given_input, given_tuple, given_context, expected_output
378396
):
379397
# Given
@@ -419,7 +437,7 @@ def test_fixed_tuple_unmarshaller(
419437
given_input=models.TDict(field="data", value=1),
420438
),
421439
)
422-
def test_structured_type_unmarshaller(
440+
def test_structured_type_marshaller(
423441
given_input, given_cls, given_context, expected_output
424442
):
425443
# Given
@@ -456,12 +474,12 @@ def test_invalid_union():
456474
given_marshaller(given_value)
457475

458476

459-
def test_enum_unmarshaller():
477+
def test_enum_marshaller():
460478
# Given
461-
given_unmarshaller = routines.EnumMarshaller(models.GivenEnum, {})
479+
given_marshaller = routines.EnumMarshaller(models.GivenEnum, {})
462480
given_value = models.GivenEnum.one
463481
expected_value = models.GivenEnum.one.value
464482
# When
465-
unmarshalled = given_unmarshaller(given_value)
483+
unmarshalled = given_marshaller(given_value)
466484
# Then
467485
assert unmarshalled == expected_value

tests/unit/unmarshals/test_api.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,11 @@
149149
timestamp=datetime.datetime.fromtimestamp(0, datetime.timezone.utc)
150150
),
151151
),
152+
optional_none=dict(
153+
given_type=typing.Optional[typing.Union[int, str]],
154+
given_input=None,
155+
expected_output=None,
156+
),
152157
attrib_conflict=dict(
153158
given_type=models.Parent,
154159
given_input={"intersection": {"a": 0}, "child": {"intersection": {"b": 0}}},

tests/unit/unmarshals/test_routines.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -500,6 +500,24 @@ def test_literal_unmarshaller(
500500
},
501501
expected_output=1,
502502
),
503+
optional_date_none=dict(
504+
given_input=None,
505+
given_union=typing.Optional[datetime.date],
506+
given_context={
507+
datetime.date: routines.DateUnmarshaller(datetime.date, {}),
508+
type(None): routines.NoOpUnmarshaller(type(None), {}),
509+
},
510+
expected_output=None,
511+
),
512+
optional_date_date=dict(
513+
given_input=datetime.date.today().isoformat(),
514+
given_union=typing.Optional[datetime.date],
515+
given_context={
516+
datetime.date: routines.DateUnmarshaller(datetime.date, {}),
517+
type(None): routines.NoneTypeUnmarshaller(type(None), {}),
518+
},
519+
expected_output=datetime.date.today(),
520+
),
503521
)
504522
def test_union_unmarshaller(given_input, given_union, given_context, expected_output):
505523
# Given

0 commit comments

Comments
 (0)