diff --git a/great_tables/_pipe.py b/great_tables/_pipe.py index f7dff7a30..0bb350375 100644 --- a/great_tables/_pipe.py +++ b/great_tables/_pipe.py @@ -1,7 +1,8 @@ from __future__ import annotations from typing import TYPE_CHECKING, Callable -from typing_extensions import ParamSpec + +from typing_extensions import Concatenate, ParamSpec if TYPE_CHECKING: from .gt import GT @@ -10,7 +11,9 @@ P = ParamSpec("P") -def pipe(self: "GT", func: Callable[P, "GT"], *args: P.args, **kwargs: P.kwargs) -> "GT": +def pipe( + self: "GT", func: Callable[Concatenate["GT", P], "GT"], *args: P.args, **kwargs: P.kwargs +) -> "GT": """ Provide a structured way to chain a function for a GT object. diff --git a/tests/test_pipe.py b/tests/test_pipe.py index 484b60070..6c8ca25ea 100644 --- a/tests/test_pipe.py +++ b/tests/test_pipe.py @@ -1,14 +1,19 @@ +from __future__ import annotations + import polars as pl from great_tables import GT +from great_tables._text import BaseText +from great_tables._tbl_data import SelectExpr + -def test_pipe(): +def test_pipe() -> None: columns = ["x", "y"] label = "a spanner" df = pl.DataFrame({"x": [1, 2, 3], "y": [3, 2, 1]}) - def tab_spanner2(gt, label, columns): + def tab_spanner2(gt: GT, label: str | BaseText, columns: SelectExpr) -> GT: return gt.tab_spanner(label=label, columns=columns) gt1 = GT(df).tab_spanner(label, columns=columns)