Skip to content

Commit bd74d29

Browse files
authored
Merge pull request #7 from seandstewart/seandstewart/type-alias-type-origin
fix: Handle `TypeAliasType` within `TypeContext` lookups
2 parents d4e8a25 + 1295bed commit bd74d29

File tree

7 files changed

+46
-7
lines changed

7 files changed

+46
-7
lines changed

src/typelib/ctx.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,22 +2,30 @@
22

33
from __future__ import annotations
44

5+
import contextlib
56
import typing as tp
67
import typing_extensions as te
78

8-
from typelib.py import refs
9+
from typelib.py import inspection, refs
910

1011
ValueT = tp.TypeVar("ValueT")
12+
DefaultT = tp.TypeVar("DefaultT")
1113
KeyT = te.TypeAliasType("KeyT", "type | refs.ForwardRef")
1214

1315

1416
class TypeContext(dict[KeyT, ValueT], tp.Generic[ValueT]):
1517
"""A key-value mapping which can map between forward references and real types."""
1618

19+
def get(self, key: KeyT, default: ValueT | DefaultT = None) -> ValueT | DefaultT:
20+
with contextlib.suppress(KeyError):
21+
return self[key]
22+
23+
return default
24+
1725
def __missing__(self, key: type | refs.ForwardRef):
1826
"""Hook to handle missing type references.
1927
20-
Allows for sharing lookup results between forward references and real types.
28+
Allows for sharing lookup results between forward references, type aliases, real types.
2129
2230
Args:
2331
key: The type or reference.
@@ -26,5 +34,12 @@ def __missing__(self, key: type | refs.ForwardRef):
2634
if isinstance(key, refs.ForwardRef):
2735
raise KeyError(key)
2836

37+
unwrapped = inspection.unwrap(key)
38+
if unwrapped in self:
39+
val = self[unwrapped]
40+
# Store the value at the original key to short-circuit in future
41+
self[key] = val
42+
return val
43+
2944
ref = refs.forwardref(key)
3045
return self[ref]

src/typelib/marshals/routines.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -469,9 +469,9 @@ def _fields_by_var(self):
469469
m = self.context.get(hint) or self.context.get(resolved)
470470
if m is None:
471471
warnings.warn(
472-
"Failed to identify an unmarshaller for the associated type-variable pair: "
472+
"Failed to identify a marshaller for the associated type-variable pair: "
473473
f"Original ref: {hint}, Resolved ref: {resolved}. Will default to no-op.",
474-
stacklevel=4,
474+
stacklevel=5,
475475
)
476476
fields_by_var[name] = NoOpMarshaller(hint, self.context, var=name)
477477
continue

src/typelib/py/inspection.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -933,7 +933,6 @@ def isclassvartype(obj: type) -> bool:
933933

934934
_UNWRAPPABLE = (
935935
isclassvartype,
936-
isoptionaltype,
937936
isfinal,
938937
)
939938

@@ -1484,6 +1483,22 @@ def istypealiastype(t: tp.Any) -> compat.TypeIs[compat.TypeAliasType]:
14841483
return isinstance(t, compat.TypeAliasType)
14851484

14861485

1486+
@compat.cache
1487+
def unwrap(t: tp.Any) -> tp.Any:
1488+
while True:
1489+
if should_unwrap(t):
1490+
t = t.__args__[0]
1491+
continue
1492+
if istypealiastype(t):
1493+
t = t.__value__
1494+
continue
1495+
1496+
if hasattr(t, "__supertype__"):
1497+
t = t.__supertype__
1498+
continue
1499+
return t
1500+
1501+
14871502
def _safe_issubclass(__cls: type, __class_or_tuple: type | tuple[type, ...]) -> bool:
14881503
try:
14891504
return issubclass(__cls, __class_or_tuple)

src/typelib/unmarshals/routines.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -985,7 +985,7 @@ def _fields_by_var(self):
985985
warnings.warn(
986986
"Failed to identify an unmarshaller for the associated type-variable pair: "
987987
f"Original ref: {hint}, Resolved ref: {resolved}. Will default to no-op.",
988-
stacklevel=4,
988+
stacklevel=6,
989989
)
990990
fields_by_var[name] = NoOpUnmarshaller(hint, self.context, var=name)
991991
continue

tests/unit/marshals/test_api.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,11 @@
155155
),
156156
expected_output={"intersection": {"a": 0}, "child": {"intersection": {"b": 0}}},
157157
),
158+
nested_type_alias=dict(
159+
given_type=models.NestedTypeAliasType,
160+
given_input=models.NestedTypeAliasType(alias=[1]),
161+
expected_output={"alias": [1]},
162+
),
158163
)
159164
def test_marshal(given_type, given_input, expected_output):
160165
# When

tests/unit/py/test_inspection.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -592,7 +592,6 @@ def test_isclassvartype(given_type, expected_is_classvar_type):
592592

593593
@pytest.mark.suite(
594594
classvar=dict(given_type=t.ClassVar[int], expected_should_unwrap=True),
595-
optional=dict(given_type=t.Optional[str], expected_should_unwrap=True),
596595
final=dict(given_type=t.Final[str], expected_should_unwrap=True),
597596
literal=dict(given_type=t.Literal[1], expected_should_unwrap=False),
598597
)

tests/unit/unmarshals/test_api.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,11 @@
162162
child=models.Child(intersection=models.ChildIntersect(b=0)),
163163
),
164164
),
165+
nested_type_alias=dict(
166+
given_type=models.NestedTypeAliasType,
167+
given_input={"alias": ["1"]},
168+
expected_output=models.NestedTypeAliasType(alias=[1]),
169+
),
165170
)
166171
def test_unmarshal(given_type, given_input, expected_output):
167172
# When

0 commit comments

Comments
 (0)