Skip to content

Commit 6872e2a

Browse files
authored
feat!: Allow for chaining .name expressions, aligning with Polars' new aliasing behaviour (#2898)
1 parent 60c0481 commit 6872e2a

File tree

11 files changed

+106
-51
lines changed

11 files changed

+106
-51
lines changed

narwhals/_compliant/expr.py

Lines changed: 44 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,44 @@ def _from_series(cls, series: EagerSeriesT) -> Self:
363363
version=series._version,
364364
)
365365

366+
def _with_alias_output_names(self, alias_name: AliasName | None, /) -> Self:
367+
current_alias_output_names = self._alias_output_names
368+
alias_output_names: AliasNames | None = (
369+
None
370+
if alias_name is None
371+
else (
372+
lambda output_names: [
373+
alias_name(x) for x in current_alias_output_names(output_names)
374+
]
375+
)
376+
if current_alias_output_names is not None
377+
else (lambda output_names: [alias_name(x) for x in output_names])
378+
)
379+
380+
def func(df: EagerDataFrameT) -> list[EagerSeriesT]:
381+
if alias_output_names:
382+
return [
383+
series.alias(name)
384+
for series, name in zip(
385+
self(df), alias_output_names(self._evaluate_output_names(df))
386+
)
387+
]
388+
return [
389+
series.alias(name)
390+
for series, name in zip(self(df), self._evaluate_output_names(df))
391+
]
392+
393+
return self.__class__(
394+
func,
395+
depth=self._depth,
396+
function_name=self._function_name,
397+
evaluate_output_names=self._evaluate_output_names,
398+
alias_output_names=alias_output_names,
399+
implementation=self._implementation,
400+
version=self._version,
401+
scalar_kwargs=self._scalar_kwargs,
402+
)
403+
366404
def _reuse_series(
367405
self,
368406
method_name: str,
@@ -1035,7 +1073,7 @@ class CompliantExprNameNamespace( # type: ignore[misc]
10351073
Protocol[CompliantExprT_co],
10361074
):
10371075
def keep(self) -> CompliantExprT_co:
1038-
return self._from_callable(lambda name: name, alias=False)
1076+
return self._from_callable(None)
10391077

10401078
def map(self, function: AliasName) -> CompliantExprT_co:
10411079
return self._from_callable(function)
@@ -1059,41 +1097,27 @@ def fn(output_names: Sequence[str], /) -> Sequence[str]:
10591097

10601098
return fn
10611099

1062-
def _from_callable(
1063-
self, func: AliasName, /, *, alias: bool = True
1064-
) -> CompliantExprT_co: ...
1100+
def _from_callable(self, func: AliasName | None, /) -> CompliantExprT_co: ...
10651101

10661102

10671103
class EagerExprNameNamespace(
10681104
EagerExprNamespace[EagerExprT],
10691105
CompliantExprNameNamespace[EagerExprT],
10701106
Generic[EagerExprT],
10711107
):
1072-
def _from_callable(self, func: AliasName, /, *, alias: bool = True) -> EagerExprT:
1108+
def _from_callable(self, func: AliasName | None) -> EagerExprT:
10731109
expr = self.compliant
1074-
return type(expr)(
1075-
lambda df: [
1076-
series.alias(func(name))
1077-
for series, name in zip(expr(df), expr._evaluate_output_names(df))
1078-
],
1079-
depth=expr._depth,
1080-
function_name=expr._function_name,
1081-
evaluate_output_names=expr._evaluate_output_names,
1082-
alias_output_names=self._alias_output_names(func) if alias else None,
1083-
implementation=expr._implementation,
1084-
version=expr._version,
1085-
scalar_kwargs=expr._scalar_kwargs,
1086-
)
1110+
return expr._with_alias_output_names(func)
10871111

10881112

10891113
class LazyExprNameNamespace(
10901114
LazyExprNamespace[LazyExprT],
10911115
CompliantExprNameNamespace[LazyExprT],
10921116
Generic[LazyExprT],
10931117
):
1094-
def _from_callable(self, func: AliasName, /, *, alias: bool = True) -> LazyExprT:
1118+
def _from_callable(self, func: AliasName | None) -> LazyExprT:
10951119
expr = self.compliant
1096-
output_names = self._alias_output_names(func) if alias else None
1120+
output_names = self._alias_output_names(func) if func else None
10971121
return expr._with_alias_output_names(output_names)
10981122

10991123

narwhals/_dask/expr.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,12 +168,20 @@ def func(df: DaskLazyFrame) -> list[dx.Series]:
168168
)
169169

170170
def _with_alias_output_names(self, func: AliasNames | None, /) -> Self:
171+
current_alias_output_names = self._alias_output_names
172+
alias_output_names = (
173+
None
174+
if func is None
175+
else func
176+
if current_alias_output_names is None
177+
else lambda output_names: func(current_alias_output_names(output_names))
178+
)
171179
return type(self)(
172180
call=self._call,
173181
depth=self._depth,
174182
function_name=self._function_name,
175183
evaluate_output_names=self._evaluate_output_names,
176-
alias_output_names=func,
184+
alias_output_names=alias_output_names,
177185
version=self._version,
178186
scalar_kwargs=self._scalar_kwargs,
179187
)

narwhals/_sql/expr.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,11 +145,19 @@ def _with_binary(self, op: Callable[..., NativeExprT], other: Self | Any) -> Sel
145145
)
146146

147147
def _with_alias_output_names(self, func: AliasNames | None, /) -> Self:
148+
current_alias_output_names = self._alias_output_names
149+
alias_output_names = (
150+
None
151+
if func is None
152+
else func
153+
if current_alias_output_names is None
154+
else lambda output_names: func(current_alias_output_names(output_names))
155+
)
148156
return type(self)(
149157
self._call,
150158
self._window_function,
151159
evaluate_output_names=self._evaluate_output_names,
152-
alias_output_names=func,
160+
alias_output_names=alias_output_names,
153161
version=self._version,
154162
implementation=self._implementation,
155163
)

narwhals/expr_name.py

Lines changed: 6 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,7 @@ def keep(self) -> ExprT:
1919
A new expression.
2020
2121
Notes:
22-
This will undo any previous renaming operations on the expression.
23-
Due to implementation constraints, this method can only be called as the last
24-
expression in a chain. Only one name operation per expression will work.
22+
For Polars versions prior to 1.32, this will undo any previous renaming operations on the expression.
2523
2624
Examples:
2725
>>> import pandas as pd
@@ -45,9 +43,7 @@ def map(self, function: Callable[[str], str]) -> ExprT:
4543
A new expression.
4644
4745
Notes:
48-
This will undo any previous renaming operations on the expression.
49-
Due to implementation constraints, this method can only be called as the last
50-
expression in a chain. Only one name operation per expression will work.
46+
For Polars versions prior to 1.32, this will undo any previous renaming operations on the expression.
5147
5248
Examples:
5349
>>> import pandas as pd
@@ -72,9 +68,7 @@ def prefix(self, prefix: str) -> ExprT:
7268
A new expression.
7369
7470
Notes:
75-
This will undo any previous renaming operations on the expression.
76-
Due to implementation constraints, this method can only be called as the last
77-
expression in a chain. Only one name operation per expression will work.
71+
For Polars versions prior to 1.32, this will undo any previous renaming operations on the expression.
7872
7973
Examples:
8074
>>> import polars as pl
@@ -98,9 +92,7 @@ def suffix(self, suffix: str) -> ExprT:
9892
A new expression.
9993
10094
Notes:
101-
This will undo any previous renaming operations on the expression.
102-
Due to implementation constraints, this method can only be called as the last
103-
expression in a chain. Only one name operation per expression will work.
95+
For Polars versions prior to 1.32, this will undo any previous renaming operations on the expression.
10496
10597
Examples:
10698
>>> import polars as pl
@@ -121,9 +113,7 @@ def to_lowercase(self) -> ExprT:
121113
A new expression.
122114
123115
Notes:
124-
This will undo any previous renaming operations on the expression.
125-
Due to implementation constraints, this method can only be called as the last
126-
expression in a chain. Only one name operation per expression will work.
116+
For Polars versions prior to 1.32, this will undo any previous renaming operations on the expression.
127117
128118
Examples:
129119
>>> import pyarrow as pa
@@ -144,9 +134,7 @@ def to_uppercase(self) -> ExprT:
144134
A new expression.
145135
146136
Notes:
147-
This will undo any previous renaming operations on the expression.
148-
Due to implementation constraints, this method can only be called as the last
149-
expression in a chain. Only one name operation per expression will work.
137+
For Polars versions prior to 1.32, this will undo any previous renaming operations on the expression.
150138
151139
Examples:
152140
>>> import pyarrow as pa

tests/expr_and_series/name/map_test.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
from __future__ import annotations
22

3+
import pytest
4+
35
import narwhals as nw
4-
from tests.utils import Constructor, assert_equal_data
6+
from tests.utils import POLARS_VERSION, Constructor, assert_equal_data
57

68
data = {"foo": [1, 2, 3], "BAR": [4, 5, 6]}
79

@@ -18,9 +20,11 @@ def test_map(constructor: Constructor) -> None:
1820

1921

2022
def test_map_after_alias(constructor: Constructor) -> None:
23+
if "polars" in str(constructor) and POLARS_VERSION < (1, 32):
24+
pytest.skip(reason="https://github.com/pola-rs/polars/issues/23765")
2125
df = nw.from_native(constructor(data))
2226
result = df.select((nw.col("foo")).alias("alias_for_foo").name.map(function=map_func))
23-
expected = {"oof": data["foo"]}
27+
expected = {"oof_rof_saila": data["foo"]}
2428
assert_equal_data(result, expected)
2529

2630

tests/expr_and_series/name/prefix_test.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
from __future__ import annotations
22

3+
import pytest
4+
35
import narwhals as nw
4-
from tests.utils import Constructor, assert_equal_data
6+
from tests.utils import POLARS_VERSION, Constructor, assert_equal_data
57

68
data = {"foo": [1, 2, 3], "BAR": [4, 5, 6]}
79
prefix = "with_prefix_"
@@ -15,9 +17,11 @@ def test_prefix(constructor: Constructor) -> None:
1517

1618

1719
def test_suffix_after_alias(constructor: Constructor) -> None:
20+
if "polars" in str(constructor) and POLARS_VERSION < (1, 32):
21+
pytest.skip(reason="https://github.com/pola-rs/polars/issues/23765")
1822
df = nw.from_native(constructor(data))
1923
result = df.select((nw.col("foo")).alias("alias_for_foo").name.prefix(prefix))
20-
expected = {"with_prefix_foo": [1, 2, 3]}
24+
expected = {"with_prefix_alias_for_foo": [1, 2, 3]}
2125
assert_equal_data(result, expected)
2226

2327

tests/expr_and_series/name/suffix_test.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
from __future__ import annotations
22

3+
import pytest
4+
35
import narwhals as nw
4-
from tests.utils import Constructor, assert_equal_data
6+
from tests.utils import POLARS_VERSION, Constructor, assert_equal_data
57

68
data = {"foo": [1, 2, 3], "BAR": [4, 5, 6]}
79
suffix = "_with_suffix"
@@ -15,9 +17,11 @@ def test_suffix(constructor: Constructor) -> None:
1517

1618

1719
def test_suffix_after_alias(constructor: Constructor) -> None:
20+
if "polars" in str(constructor) and POLARS_VERSION < (1, 32):
21+
pytest.skip(reason="https://github.com/pola-rs/polars/issues/23765")
1822
df = nw.from_native(constructor(data))
1923
result = df.select((nw.col("foo")).alias("alias_for_foo").name.suffix(suffix))
20-
expected = {"foo_with_suffix": [1, 2, 3]}
24+
expected = {"alias_for_foo_with_suffix": [1, 2, 3]}
2125
assert_equal_data(result, expected)
2226

2327

tests/expr_and_series/name/to_lowercase_test.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
from __future__ import annotations
22

3+
import pytest
4+
35
import narwhals as nw
4-
from tests.utils import Constructor, assert_equal_data
6+
from tests.utils import POLARS_VERSION, Constructor, assert_equal_data
57

68
data = {"foo": [1, 2, 3], "BAR": [4, 5, 6]}
79

@@ -14,9 +16,11 @@ def test_to_lowercase(constructor: Constructor) -> None:
1416

1517

1618
def test_to_lowercase_after_alias(constructor: Constructor) -> None:
19+
if "polars" in str(constructor) and POLARS_VERSION < (1, 32):
20+
pytest.skip(reason="https://github.com/pola-rs/polars/issues/23765")
1721
df = nw.from_native(constructor(data))
1822
result = df.select((nw.col("BAR")).alias("ALIAS_FOR_BAR").name.to_lowercase())
19-
expected = {"bar": [4, 5, 6]}
23+
expected = {"alias_for_bar": [4, 5, 6]}
2024
assert_equal_data(result, expected)
2125

2226

tests/expr_and_series/name/to_uppercase_test.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
from __future__ import annotations
22

3+
import pytest
4+
35
import narwhals as nw
4-
from tests.utils import Constructor, assert_equal_data
6+
from tests.utils import POLARS_VERSION, Constructor, assert_equal_data
57

68
data = {"foo": [1, 2, 3], "BAR": [4, 5, 6]}
79

@@ -14,9 +16,11 @@ def test_to_uppercase(constructor: Constructor) -> None:
1416

1517

1618
def test_to_uppercase_after_alias(constructor: Constructor) -> None:
19+
if "polars" in str(constructor) and POLARS_VERSION < (1, 32):
20+
pytest.skip(reason="https://github.com/pola-rs/polars/issues/23765")
1721
df = nw.from_native(constructor(data))
1822
result = df.select((nw.col("foo")).alias("alias_for_foo").name.to_uppercase())
19-
expected = {"FOO": [1, 2, 3]}
23+
expected = {"ALIAS_FOR_FOO": [1, 2, 3]}
2024
assert_equal_data(result, expected)
2125

2226

tests/selectors_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,8 @@ def test_datetime(constructor: Constructor, request: pytest.FixtureRequest) -> N
8888
or ("pyarrow" in str(constructor) and is_windows())
8989
or ("pandas" in str(constructor) and PANDAS_VERSION < (2,))
9090
or "ibis" in str(constructor)
91+
# https://github.com/pola-rs/polars/issues/23767
92+
or ("polars" in str(constructor) and POLARS_VERSION == (1, 32, 0, 1))
9193
):
9294
request.applymarker(pytest.mark.xfail)
9395
if "modin" in str(constructor):

0 commit comments

Comments
 (0)