Skip to content

Commit 4565c0d

Browse files
authored
Merge pull request #4 from seandstewart/seandstewart/ref-detection
fix: remove use of `graphlib.TypeNode` in type context
2 parents 4887d2c + e4742c0 commit 4565c0d

File tree

7 files changed

+58
-74
lines changed

7 files changed

+58
-74
lines changed

src/typelib/ctx.py

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,35 +2,29 @@
22

33
from __future__ import annotations
44

5-
import dataclasses
5+
import typing as tp
6+
import typing_extensions as te
67

7-
from typelib import graph
8-
from typelib.py import inspection, refs
8+
from typelib.py import refs
99

10+
ValueT = tp.TypeVar("ValueT")
11+
KeyT = te.TypeAliasType("KeyT", "type | refs.ForwardRef")
1012

11-
class TypeContext(dict):
13+
14+
class TypeContext(dict[KeyT, ValueT], tp.Generic[ValueT]):
1215
"""A key-value mapping which can map between forward references and real types."""
1316

14-
def __missing__(self, key: graph.TypeNode | type | refs.ForwardRef):
17+
def __missing__(self, key: type | refs.ForwardRef):
1518
"""Hook to handle missing type references.
1619
1720
Allows for sharing lookup results between forward references and real types.
1821
1922
Args:
2023
key: The type or reference.
2124
"""
22-
# Eager wrap in a TypeNode
23-
if not isinstance(key, graph.TypeNode):
24-
key = graph.TypeNode(type=key)
25-
return self[key]
26-
2725
# If we missed a ForwardRef, we've already tried this, bail out.
28-
type = key.type
29-
if isinstance(type, refs.ForwardRef):
26+
if isinstance(key, refs.ForwardRef):
3027
raise KeyError(key)
3128

32-
ref = refs.forwardref(
33-
inspection.qualname(type), module=getattr(type, "__module__", None)
34-
)
35-
node = dataclasses.replace(key, type=ref)
36-
return self[node]
29+
ref = refs.forwardref(key)
30+
return self[ref]

src/typelib/marshals/api.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -46,26 +46,24 @@ def marshaller(
4646
[`typing.ForwardRef`][], or string reference.
4747
"""
4848
nodes = graph.static_order(t)
49-
context: dict[type | graph.TypeNode, routines.AbstractMarshaller] = (
50-
ctx.TypeContext()
51-
)
49+
context: ctx.TypeContext[routines.AbstractMarshaller] = ctx.TypeContext()
5250
if not nodes:
5351
return routines.NoOpMarshaller(t=t, context=context, var=None) # type: ignore[arg-type]
5452

5553
# "root" type will always be the final node in the sequence.
5654
root = nodes[-1]
5755
for node in nodes:
58-
context[node] = _get_unmarshaller(node, context=context)
56+
context[node.type] = _get_unmarshaller(node, context=context)
5957

60-
return context[root]
58+
return context[root.type]
6159

6260

6361
def _get_unmarshaller( # type: ignore[return]
6462
node: graph.TypeNode,
6563
context: routines.ContextT,
6664
) -> routines.AbstractMarshaller[T]:
67-
if node in context:
68-
return context[node]
65+
if node.type in context:
66+
return context[node.type]
6967

7068
for check, unmarshaller_cls in _HANDLERS.items():
7169
if check(node.type):

src/typelib/marshals/routines.py

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import uuid
1515
import warnings
1616

17-
from typelib import graph, serdes
17+
from typelib import ctx, serdes
1818
from typelib.py import compat, inspection, refs
1919

2020
T = tp.TypeVar("T")
@@ -93,7 +93,7 @@ def __init__(self, t: type[T], context: ContextT, *, var: str | None = None):
9393
def __call__(self, val: T) -> serdes.MarshalledValueT: ...
9494

9595

96-
ContextT: tp.TypeAlias = "tp.Mapping[type | graph.TypeNode, AbstractMarshaller]"
96+
ContextT: tp.TypeAlias = "ctx.TypeContext[AbstractMarshaller]"
9797

9898

9999
class NoOpMarshaller(AbstractMarshaller[T], tp.Generic[T]):
@@ -463,25 +463,21 @@ def __init__(self, t: type[_ST], context: ContextT, *, var: str | None = None):
463463

464464
def _fields_by_var(self):
465465
fields_by_var = {}
466-
tp_var_map = {(t.type, t.var): m for t, m in self.context.items()}
467466
hints = inspection.cached_type_hints(self.t)
468467
for name, hint in hints.items():
469468
resolved = refs.evaluate(hint)
470-
fkey = (hint, name)
471-
rkey = (resolved, name)
472-
if fkey in tp_var_map:
473-
fields_by_var[name] = tp_var_map[fkey]
474-
continue
475-
if rkey in tp_var_map:
476-
fields_by_var[name] = tp_var_map[rkey]
469+
m = self.context.get(hint) or self.context.get(resolved)
470+
if m is None:
471+
warnings.warn(
472+
"Failed to identify an unmarshaller for the associated type-variable pair: "
473+
f"Original ref: {hint}, Resolved ref: {resolved}. Will default to no-op.",
474+
stacklevel=4,
475+
)
476+
fields_by_var[name] = NoOpMarshaller(hint, self.context, var=name)
477477
continue
478478

479-
warnings.warn( # pragma: no cover
480-
"Failed to identify an unmarshaller for the associated type-variable pair: "
481-
f"Original ref: {fkey}, Resolved ref: {resolved}. Will default to no-op.",
482-
stacklevel=3,
483-
)
484-
fields_by_var[name] = NoOpMarshaller(hint, self.context, var=name)
479+
fields_by_var[name] = m
480+
485481
return fields_by_var
486482

487483
def __call__(self, val: _ST) -> MarshalledMappingT:

src/typelib/py/refs.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,15 @@
1818
import sys
1919
import typing
2020

21-
from typelib.py import frames, future
21+
from typelib.py import frames, future, inspection
2222

2323
__all__ = ("ForwardRef", "evaluate", "forwardref")
2424

2525
ForwardRef: typing.TypeAlias = typing.ForwardRef
2626

2727

2828
def forwardref(
29-
ref: str,
29+
ref: str | type,
3030
*,
3131
is_argument: bool = False,
3232
module: typing.Any | None = None,
@@ -43,11 +43,18 @@ def forwardref(
4343
module: The python module in which the reference string is defined (optional)
4444
is_class: Whether the reference string is a class (default True).
4545
"""
46+
if not isinstance(ref, str):
47+
name = inspection.qualname(ref)
48+
module = module or getattr(ref, "__module__", None)
49+
else:
50+
name = typing.cast(str, ref)
51+
4652
module = _resolve_module_name(ref, module)
4753
if module is not None:
48-
ref = ref.replace(f"{module}.", "")
54+
name = name.replace(f"{module}.", "")
55+
4956
return ForwardRef(
50-
ref,
57+
name,
5158
is_argument=is_argument,
5259
module=module,
5360
is_class=is_class,

src/typelib/unmarshals/api.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -40,26 +40,24 @@ def unmarshaller(
4040
May be a type, type alias, [`typing.ForwardRef`][], or string reference.
4141
"""
4242
nodes = graph.static_order(t)
43-
context: dict[type | graph.TypeNode, routines.AbstractUnmarshaller] = (
44-
ctx.TypeContext()
45-
)
43+
context: ctx.TypeContext[routines.AbstractUnmarshaller] = ctx.TypeContext()
4644
if not nodes:
4745
return routines.NoOpUnmarshaller(t=t, context=context, var=None) # type: ignore[arg-type]
4846

4947
# "root" type will always be the final node in the sequence.
5048
root = nodes[-1]
5149
for node in nodes:
52-
context[node] = _get_unmarshaller(node, context=context)
50+
context[node.type] = _get_unmarshaller(node, context=context)
5351

54-
return context[root]
52+
return context[root.type]
5553

5654

5755
def _get_unmarshaller( # type: ignore[return]
5856
node: graph.TypeNode,
5957
context: routines.ContextT,
6058
) -> routines.AbstractUnmarshaller[T]:
61-
if node in context:
62-
return context[node]
59+
if node.type in context:
60+
return context[node.type]
6361

6462
for check, unmarshaller_cls in _HANDLERS.items():
6563
if check(node.type):

src/typelib/unmarshals/routines.py

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import uuid
1717
import warnings
1818

19-
from typelib import constants, graph, serdes
19+
from typelib import constants, ctx, serdes
2020
from typelib.py import compat, inspection, refs
2121

2222
T = tp.TypeVar("T")
@@ -134,7 +134,7 @@ def __call__(self, val: tp.Any) -> None:
134134
return None
135135

136136

137-
ContextT: tp.TypeAlias = "tp.Mapping[type | graph.TypeNode, AbstractUnmarshaller]"
137+
ContextT: tp.TypeAlias = "ctx.TypeContext[AbstractUnmarshaller]"
138138
BytesT = tp.TypeVar("BytesT", bound=bytes)
139139

140140

@@ -977,25 +977,21 @@ def __init__(self, t: type[_ST], context: ContextT, *, var: str | None = None):
977977

978978
def _fields_by_var(self):
979979
fields_by_var = {}
980-
tp_var_map = {(t.type, t.var): m for t, m in self.context.items()}
981980
hints = inspection.cached_type_hints(self.t)
982981
for name, hint in hints.items():
983982
resolved = refs.evaluate(hint)
984-
fkey = (hint, name)
985-
rkey = (resolved, name)
986-
if fkey in tp_var_map:
987-
fields_by_var[name] = tp_var_map[fkey]
988-
continue
989-
if rkey in tp_var_map:
990-
fields_by_var[name] = tp_var_map[rkey]
983+
m = self.context.get(hint) or self.context.get(resolved)
984+
if m is None:
985+
warnings.warn(
986+
"Failed to identify an unmarshaller for the associated type-variable pair: "
987+
f"Original ref: {hint}, Resolved ref: {resolved}. Will default to no-op.",
988+
stacklevel=4,
989+
)
990+
fields_by_var[name] = NoOpUnmarshaller(hint, self.context, var=name)
991991
continue
992992

993-
warnings.warn(
994-
"Failed to identify an unmarshaller for the associated type-variable pair: "
995-
f"Original ref: {fkey}, Resolved ref: {resolved}. Will default to no-op.",
996-
stacklevel=3,
997-
)
998-
fields_by_var[name] = NoOpUnmarshaller(hint, self.context, var=name)
993+
fields_by_var[name] = m
994+
999995
return fields_by_var
1000996

1001997
def __call__(self, val: tp.Any) -> _ST:

tests/unit/unmarshals/test_routines.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010

1111
import pytest
1212

13-
from typelib import graph
1413
from typelib.unmarshals import routines
1514

1615
from tests import models
@@ -739,12 +738,8 @@ def test_fixed_tuple_unmarshaller(
739738
@pytest.mark.suite(
740739
context=dict(
741740
given_context={
742-
graph.TypeNode(int, var="value"): routines.NumberUnmarshaller(
743-
int, {}, var="value"
744-
),
745-
graph.TypeNode(str, var="field"): routines.StringUnmarshaller(
746-
str, {}, var="field"
747-
),
741+
int: routines.NumberUnmarshaller(int, {}, var="value"),
742+
str: routines.StringUnmarshaller(str, {}, var="field"),
748743
},
749744
),
750745
)

0 commit comments

Comments
 (0)