diff --git a/Lib/functools.py b/Lib/functools.py index 836eb680ccd4d4..7ed3d67f3cf79c 100644 --- a/Lib/functools.py +++ b/Lib/functools.py @@ -19,7 +19,7 @@ # import weakref # Deferred to single_dispatch() from operator import itemgetter from reprlib import recursive_repr -from types import GenericAlias, MethodType, MappingProxyType, UnionType +from types import FunctionType, GenericAlias, MethodType, MappingProxyType, UnionType from _thread import RLock ################################################################################ @@ -888,6 +888,48 @@ def _find_impl(cls, registry): match = t return registry.get(match) +def _get_singledispatch_annotated_param(func, *, _inside_dispatchmethod=False): + """Finds the first positional and user-specified parameter in a callable + or descriptor. + + Used by singledispatch for registration by type annotation of the parameter. + """ + # Pick the first parameter if function had @staticmethod. + if isinstance(func, staticmethod): + idx = 0 + func = func.__func__ + # Pick the second parameter if function had @classmethod or is a bound method. + elif isinstance(func, (classmethod, MethodType)): + idx = 1 + func = func.__func__ + # If it is a regular function: + # Pick the first parameter if registering via singledispatch. + # Pick the second parameter if registering via singledispatchmethod. + else: + idx = int(_inside_dispatchmethod) + + # If it is a simple function, try to read from the code object fast. + if isinstance(func, FunctionType) and not hasattr(func, "__wrapped__"): + # Emulate inspect._signature_from_function to get the desired parameter. + func_code = func.__code__ + try: + return func_code.co_varnames[:func_code.co_argcount][idx] + except IndexError: + pass + + # Fall back to inspect.signature (slower, but complete). + import inspect + params = list(inspect.signature(func).parameters.values()) + try: + param = params[idx] + except IndexError: + pass + else: + # Allow variadic positional "(*args)" parameters for backward compatibility. + if param.kind not in (inspect.Parameter.KEYWORD_ONLY, inspect.Parameter.VAR_KEYWORD): + return param.name + return None + def singledispatch(func): """Single-dispatch generic function decorator. @@ -935,7 +977,7 @@ def _is_valid_dispatch_type(cls): return (isinstance(cls, UnionType) and all(isinstance(arg, type) for arg in cls.__args__)) - def register(cls, func=None): + def register(cls, func=None, _inside_dispatchmethod=False): """generic_func.register(cls, func) -> func Registers a new implementation for the given *cls* on a *generic_func*. @@ -960,10 +1002,28 @@ def register(cls, func=None): ) func = cls + argname = _get_singledispatch_annotated_param( + func, _inside_dispatchmethod=_inside_dispatchmethod) + if argname is None: + raise TypeError( + f"Invalid first argument to `register()`: {func!r} " + f"does not accept positional arguments." + ) + # only import typing if annotation parsing is necessary from typing import get_type_hints from annotationlib import Format, ForwardRef - argname, cls = next(iter(get_type_hints(func, format=Format.FORWARDREF).items())) + annotations = get_type_hints(func, format=Format.FORWARDREF) + + try: + cls = annotations[argname] + except KeyError: + raise TypeError( + f"Invalid first argument to `register()`: {func!r}. " + "Use either `@register(some_class)` or add a type " + f"annotation to parameter {argname!r} of your callable." + ) from None + if not _is_valid_dispatch_type(cls): if isinstance(cls, UnionType): raise TypeError( @@ -1027,7 +1087,7 @@ def register(self, cls, method=None): Registers a new implementation for the given *cls* on a *generic_method*. """ - return self.dispatcher.register(cls, func=method) + return self.dispatcher.register(cls, func=method, _inside_dispatchmethod=True) def __get__(self, obj, cls=None): return _singledispatchmethod_get(self, obj, cls) diff --git a/Lib/test/test_functools.py b/Lib/test/test_functools.py index 94b469397139c7..4b1a2b3e12ad35 100644 --- a/Lib/test/test_functools.py +++ b/Lib/test/test_functools.py @@ -2905,13 +2905,34 @@ def t(self, arg): def _(self, arg: int): return "int" @t.register - def _(self, arg: str): + def _(self, arg: complex, /): + return "complex" + @t.register + def _(self, /, arg: str): return "str" + # See GH-130827. + def wrapped1(self: typing.Self, arg: bytes): + return "bytes" + @t.register + @functools.wraps(wrapped1) + def wrapper1(self, *args, **kwargs): + return self.wrapped1(*args, **kwargs) + + def wrapped2(self, arg: bytearray) -> str: + return "bytearray" + @t.register + @functools.wraps(wrapped2) + def wrapper2(self, *args: typing.Any, **kwargs: typing.Any): + return self.wrapped2(*args, **kwargs) + a = A() self.assertEqual(a.t(0), "int") + self.assertEqual(a.t(0j), "complex") self.assertEqual(a.t(''), "str") self.assertEqual(a.t(0.0), "base") + self.assertEqual(a.t(b''), "bytes") + self.assertEqual(a.t(bytearray()), "bytearray") def test_staticmethod_type_ann_register(self): class A: @@ -3172,12 +3193,27 @@ def test_invalid_registrations(self): @functools.singledispatch def i(arg): return "base" + with self.assertRaises(TypeError) as exc: + @i.register + def _() -> None: + return "My function doesn't take arguments" + self.assertStartsWith(str(exc.exception), msg_prefix) + self.assertEndsWith(str(exc.exception), "does not accept positional arguments.") + + with self.assertRaises(TypeError) as exc: + @i.register + def _(*, foo: str) -> None: + return "My function takes keyword-only arguments" + self.assertStartsWith(str(exc.exception), msg_prefix) + self.assertEndsWith(str(exc.exception), "does not accept positional arguments.") + with self.assertRaises(TypeError) as exc: @i.register(42) def _(arg): return "I annotated with a non-type" self.assertStartsWith(str(exc.exception), msg_prefix + "42") self.assertEndsWith(str(exc.exception), msg_suffix) + with self.assertRaises(TypeError) as exc: @i.register def _(arg): @@ -3187,6 +3223,33 @@ def _(arg): ) self.assertEndsWith(str(exc.exception), msg_suffix) + with self.assertRaises(TypeError) as exc: + @i.register + def _(arg, extra: int): + return "I did not annotate the right param" + self.assertStartsWith(str(exc.exception), msg_prefix + + "._" + ) + self.assertEndsWith(str(exc.exception), + "Use either `@register(some_class)` or add a type annotation " + f"to parameter 'arg' of your callable.") + + with self.assertRaises(TypeError) as exc: + # See GH-84644. + + @functools.singledispatch + def func(arg):... + + @func.register + def _int(arg) -> int:... + + self.assertStartsWith(str(exc.exception), msg_prefix + + "._int" + ) + self.assertEndsWith(str(exc.exception), + "Use either `@register(some_class)` or add a type annotation " + f"to parameter 'arg' of your callable.") + with self.assertRaises(TypeError) as exc: @i.register def _(arg: typing.Iterable[str]): diff --git a/Misc/NEWS.d/next/Library/2026-01-06-09-13-53.gh-issue-84644.V_cYP3.rst b/Misc/NEWS.d/next/Library/2026-01-06-09-13-53.gh-issue-84644.V_cYP3.rst new file mode 100644 index 00000000000000..95190f88b16e60 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2026-01-06-09-13-53.gh-issue-84644.V_cYP3.rst @@ -0,0 +1,5 @@ +:func:`functools.singledispatch` and :func:`functools.singledispatchmethod` +now require callables to be correctly annotated if registering without a type explicitly +specified in the decorator. The first user-specified positional parameter of a callable +must always be annotated. Before, a callable could be registered based on its return type +annotation or based on an irrelevant parameter type annotation. Contributed by Bartosz Sławecki.