From b4141afa3676cb8fa7ce844db4b10756aa3a4643 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Thu, 25 Sep 2025 11:55:40 +0300 Subject: [PATCH 01/95] feat: Make expressions printable, rewrite internals --- docs/how_it_works.md | 111 +++- narwhals/_arrow/dataframe.py | 4 +- narwhals/_arrow/expr.py | 15 +- narwhals/_arrow/group_by.py | 10 +- narwhals/_arrow/namespace.py | 45 +- narwhals/_arrow/selectors.py | 7 - narwhals/_arrow/series.py | 8 +- narwhals/_arrow/series_str.py | 6 +- narwhals/_compliant/__init__.py | 4 - narwhals/_compliant/any_namespace.py | 4 +- narwhals/_compliant/column.py | 60 -- narwhals/_compliant/dataframe.py | 6 +- narwhals/_compliant/expr.py | 153 ++--- narwhals/_compliant/group_by.py | 6 +- narwhals/_compliant/namespace.py | 135 +++-- narwhals/_compliant/series.py | 6 + narwhals/_compliant/when_then.py | 130 ---- narwhals/_dask/dataframe.py | 12 +- narwhals/_dask/expr.py | 164 ++--- narwhals/_dask/expr_dt.py | 71 +-- narwhals/_dask/expr_str.py | 70 +-- narwhals/_dask/group_by.py | 7 +- narwhals/_dask/namespace.py | 120 ++-- narwhals/_dask/selectors.py | 7 - narwhals/_duckdb/dataframe.py | 24 +- narwhals/_duckdb/expr.py | 8 +- narwhals/_duckdb/namespace.py | 20 +- narwhals/_duckdb/utils.py | 8 +- narwhals/_expression_parsing.py | 563 +++++++++++++----- narwhals/_ibis/dataframe.py | 7 - narwhals/_ibis/expr.py | 14 +- narwhals/_ibis/expr_str.py | 2 +- narwhals/_ibis/namespace.py | 31 +- narwhals/_pandas_like/dataframe.py | 22 +- narwhals/_pandas_like/expr.py | 81 ++- narwhals/_pandas_like/group_by.py | 12 +- narwhals/_pandas_like/namespace.py | 63 +- narwhals/_pandas_like/selectors.py | 7 - narwhals/_pandas_like/series_str.py | 6 +- narwhals/_polars/expr.py | 71 +-- narwhals/_polars/namespace.py | 45 +- narwhals/_polars/series.py | 47 +- narwhals/_spark_like/dataframe.py | 6 - narwhals/_spark_like/expr.py | 8 +- narwhals/_spark_like/namespace.py | 26 +- narwhals/_sql/dataframe.py | 11 +- narwhals/_sql/expr.py | 24 +- narwhals/_sql/expr_str.py | 10 +- narwhals/_sql/namespace.py | 29 +- narwhals/_sql/when_then.py | 106 ---- narwhals/dataframe.py | 118 ++-- narwhals/expr.py | 549 ++++++++--------- narwhals/expr_cat.py | 6 +- narwhals/expr_dt.py | 94 ++- narwhals/expr_list.py | 18 +- narwhals/expr_name.py | 26 +- narwhals/expr_str.py | 107 ++-- narwhals/expr_struct.py | 6 +- narwhals/functions.py | 185 ++---- narwhals/group_by.py | 27 +- narwhals/selectors.py | 101 ++-- narwhals/series.py | 46 +- narwhals/series_str.py | 4 +- narwhals/stable/v1/__init__.py | 66 +- narwhals/stable/v2/__init__.py | 4 +- narwhals/typing.py | 21 +- tests/conftest.py | 9 +- .../dt/convert_time_zone_test.py | 1 + .../dt/replace_time_zone_test.py | 1 + tests/expr_and_series/over_test.py | 2 +- tests/expr_and_series/unique_test.py | 10 +- tests/expr_and_series/when_test.py | 15 +- tests/expression_parsing_test.py | 173 +++--- tests/frame/filter_test.py | 2 +- tests/frame/group_by_test.py | 20 +- tests/selectors_test.py | 7 +- tests/utils.py | 2 +- tests/v1_test.py | 3 +- 78 files changed, 1800 insertions(+), 2235 deletions(-) delete mode 100644 narwhals/_compliant/when_then.py delete mode 100644 narwhals/_sql/when_then.py diff --git a/docs/how_it_works.md b/docs/how_it_works.md index add65aacc6..6b79a6f050 100644 --- a/docs/how_it_works.md +++ b/docs/how_it_works.md @@ -76,8 +76,9 @@ pn = PandasLikeNamespace( implementation=Implementation.PANDAS, version=Version.MAIN, ) -print(nw.col("a")._to_compliant_expr(pn)) +print(nw.col("a")(pn)) ``` + The result from the last line above is the same as we'd get from `pn.col('a')`, and it's a `narwhals._pandas_like.expr.PandasLikeExpr` object, which we'll call `PandasLikeExpr` for short. @@ -177,7 +178,7 @@ The way you access the Narwhals-compliant wrapper depends on the object: - `narwhals.DataFrame` and `narwhals.LazyFrame`: use the `._compliant_frame` attribute. - `narwhals.Series`: use the `._compliant_series` attribute. -- `narwhals.Expr`: call the `._to_compliant_expr` method, and pass to it the Narwhals-compliant namespace associated with +- `narwhals.Expr`: call the `.__call__` method, and pass to it the Narwhals-compliant namespace associated with the given backend. 🛑 BUT WAIT! What's a Narwhals-compliant namespace? @@ -212,9 +213,10 @@ pn = PandasLikeNamespace( implementation=Implementation.PANDAS, version=Version.MAIN, ) -expr = (nw.col("a") + 1)._to_compliant_expr(pn) +expr = (nw.col("a") + 1)(pn) print(expr) ``` + If we then extract a Narwhals-compliant dataframe from `df` by calling `._compliant_frame`, we get a `PandasLikeDataFrame` - and that's an object which we can pass `expr` to! @@ -228,6 +230,7 @@ We can then view the underlying pandas Dataframe which was produced by calling ` ```python exec="1" result="python" session="pandas_api_mapping" source="above" print(result._native_frame) ``` + which is the same as we'd have obtained by just using the Narwhals API directly: ```python exec="1" result="python" session="pandas_api_mapping" source="above" @@ -238,10 +241,12 @@ print(nw.to_native(df.select(nw.col("a") + 1))) Group-by is probably one of Polars' most significant innovations (on the syntax side) with respect to pandas. We can write something like + ```python df: pl.DataFrame df.group_by("a").agg((pl.col("c") > pl.col("b").mean()).max()) ``` + To do this in pandas, we need to either use `GroupBy.apply` (sloooow), or do some crazy manual optimisations to get it to work. @@ -249,38 +254,29 @@ In Narwhals, here's what we do: - if somebody uses a simple group-by aggregation (e.g. `df.group_by('a').agg(nw.col('b').mean())`), then on the pandas side we translate it to - ```python - df: pd.DataFrame - df.groupby("a").agg({"b": ["mean"]}) - ``` + + ```python + df: pd.DataFrame + df.groupby("a").agg({"b": ["mean"]}) + ``` + - if somebody passes a complex group-by aggregation, then we use `apply` and raise a `UserWarning`, warning users of the performance penalty and advising them to refactor their code so that the aggregation they perform ends up being a simple one. -In order to tell whether an aggregation is simple, Narwhals uses the private `_depth` attribute of `PandasLikeExpr`: - -```python exec="1" result="python" session="pandas_impl" source="above" -print(pn.col("a").mean()) -print((pn.col("a") + 1).mean()) -``` - -For simple aggregations, Narwhals can just look at `_depth` and `function_name` and figure out -which (efficient) elementary operation this corresponds to in pandas. - ## Expression Metadata -Let's try printing out a few expressions to the console to see what they show us: +Let's try printing out some compliant expressions' metadata to see what it shows us: -```python exec="1" result="python" session="metadata" source="above" +```python exec="1" result="python" session="pandas_impl" source="above" import narwhals as nw -print(nw.col("a")) -print(nw.col("a").mean()) -print(nw.col("a").mean().over("b")) +print(nw.col("a")(pn)._metadata) +print(nw.col("a").mean()(pn)._metadata) +print(nw.col("a").mean().over("b")(pn)._metadata) ``` -Note how they tell us something about their metadata. This section is all about -making sense of what that all means, what the rules are, and what it enables. +This section is all about making sense of what that all means, what the rules are, and what it enables. Here's a brief description of each piece of metadata: @@ -293,8 +289,6 @@ Here's a brief description of each piece of metadata: - `ExpansionKind.MULTI_UNNAMED`: Produces multiple outputs whose names depend on the input dataframe. For example, `nw.nth(0, 1)` or `nw.selectors.numeric()`. -- `last_node`: Kind of the last operation in the expression. See - `narwhals._expression_parsing.ExprKind` for the various options. - `has_windows`: Whether the expression already contains an `over(...)` statement. - `n_orderable_ops`: How many order-dependent operations the expression contains. @@ -311,6 +305,7 @@ Here's a brief description of each piece of metadata: - `is_scalar_like`: Whether the output of the expression is always length-1. - `is_literal`: Whether the expression doesn't depend on any column but instead only on literal values, like `nw.lit(1)`. +- `nodes`: List of operations which this expression applies when evaluated. #### Chaining @@ -377,3 +372,67 @@ Narwhals triggers a broadcast in these situations: Each backend is then responsible for doing its own broadcasting, as defined in each `CompliantExpr.broadcast` method. + +### Elementwise push-down + +SQL is picky about `over` operations. For example: + +- `sum(a) over (partition by b)` is valid. +- `sum(abs(a)) over (partition by b)` is valid. +- `abs(sum(a)) over (partition by b)` is not valid. + +In Polars, however, all three of + +- `pl.col('a').sum().over('b')` is valid. +- `pl.col('a').abs().sum().over('b')` is valid. +- `pl.col('a').sum().abs().over('b')` is valid. + +How can we retain Polars' level of flexibility when translating to SQL engines? + +The answer is: by rewriting expressions. Specifically, we push down `over` nodes past elementwise ones. +To see this, let's try printing the Narwhals equivalent of the last expression above (the one that SQL rejects): + +```python exec="1" result="python" session="pushdown" source="above" +import narwhals as nw + +print(nw.col("a").sum().abs().over("b")) +``` + +Note how Narwhals automatically inserted the `over` operation _before_ the `abs` one. In other words, instead +of doing + +- `sum` -> `abs` -> `over` + +it did + +- `sum` -> `over` -> `abs` + +thus allowing the expression to be valid for SQL engines! + +This is what we refer to as "pushing down `over` nodes". The idea is: + +- Elementwise operations operate row-by-row and don't depend on the rows around them. +- An `over` node partitions or orders a computation. +- Therefore, an elementwise operation followed by an `over` operation is the same + as doing the `over` operation followed by that same elementwise operation! + +Note that the pushdown also applies to any arguments to the elementwise operation. +For example, if we have + +```python +(nw.col("a").sum() + nw.col("b").sum()).over("c") +``` + +then `+` is an elementwise operation and so can be swapped with `over`. We just need +to take care to apply the `over` operation to all the arguments of `+`, so that we +end up with + +```python +nw.col("a").sum().over("c") + nw.col("b").sum().over("c") +``` + +In general, query optimisation is out-of-scope for Narwhals. We consider this +expression rewrite acceptable because: + +- It's simple. +- It allows us to evaluate operations which otherwise wouldn't be allowed for certain backends. diff --git a/narwhals/_arrow/dataframe.py b/narwhals/_arrow/dataframe.py index 9eed006edc..cf7bd383d4 100644 --- a/narwhals/_arrow/dataframe.py +++ b/narwhals/_arrow/dataframe.py @@ -424,7 +424,7 @@ def drop_nulls(self, subset: Sequence[str] | None) -> Self: if subset is None: return self._with_native(self.native.drop_null(), validate_column_names=False) plx = self.__narwhals_namespace__() - mask = ~plx.any_horizontal(plx.col(*subset).is_null(), ignore_nulls=True) + mask = ~plx.any_horizontal(plx.col(subset).is_null(), ignore_nulls=True) return self.filter(mask) def sort(self, *by: str, descending: bool | Sequence[bool], nulls_last: bool) -> Self: @@ -496,7 +496,7 @@ def with_row_index(self, name: str, order_by: Sequence[str] | None) -> Self: plx._series.from_iterable(data, context=self, name=name) ) else: - rank = plx.col(order_by[0]).rank("ordinal", descending=False) + rank = plx.col([order_by[0]]).rank("ordinal", descending=False) row_index = (rank.over(partition_by=[], order_by=order_by) - 1).alias(name) return self.select(row_index, plx.all()) diff --git a/narwhals/_arrow/expr.py b/narwhals/_arrow/expr.py index a0c56783a2..80e7072527 100644 --- a/narwhals/_arrow/expr.py +++ b/narwhals/_arrow/expr.py @@ -32,8 +32,6 @@ def __init__( self, call: EvalSeries[ArrowDataFrame, ArrowSeries], *, - depth: int, - function_name: str, evaluate_output_names: EvalNames[ArrowDataFrame], alias_output_names: AliasNames | None, version: Version, @@ -41,14 +39,10 @@ def __init__( implementation: Implementation | None = None, ) -> None: self._call = call - self._depth = depth - self._function_name = function_name - self._depth = depth self._evaluate_output_names = evaluate_output_names self._alias_output_names = alias_output_names self._version = version - self._scalar_kwargs = scalar_kwargs or {} - self._metadata: ExprMetadata | None = None + self._opt_metadata: ExprMetadata | None = None @classmethod def from_column_names( @@ -57,7 +51,6 @@ def from_column_names( /, *, context: _LimitedContext, - function_name: str = "", ) -> Self: def func(df: ArrowDataFrame) -> list[ArrowSeries]: try: @@ -74,8 +67,6 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]: return cls( func, - depth=0, - function_name=function_name, evaluate_output_names=evaluate_column_names, alias_output_names=None, version=context._version, @@ -93,8 +84,6 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]: return cls( func, - depth=0, - function_name="nth", evaluate_output_names=cls._eval_names_indices(column_indices), alias_output_names=None, version=context._version, @@ -160,8 +149,6 @@ def func(df: ArrowDataFrame) -> Sequence[ArrowSeries]: return self.__class__( func, - depth=self._depth + 1, - function_name=self._function_name + "->over", evaluate_output_names=self._evaluate_output_names, alias_output_names=self._alias_output_names, version=self._version, diff --git a/narwhals/_arrow/group_by.py b/narwhals/_arrow/group_by.py index 325acad692..d205dd283e 100644 --- a/narwhals/_arrow/group_by.py +++ b/narwhals/_arrow/group_by.py @@ -71,10 +71,10 @@ def agg(self, *exprs: ArrowExpr) -> ArrowDataFrame: output_names, aliases = evaluate_output_names_and_aliases( expr, self.compliant, exclude ) - - if expr._depth == 0: + md = expr._metadata + if len(list(md.op_nodes_reversed())) == 1: # e.g. `agg(nw.len())` - if expr._function_name != "len": # pragma: no cover + if next(md.op_nodes_reversed()).name != "len": # pragma: no cover msg = "Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues" raise AssertionError(msg) @@ -85,8 +85,8 @@ def agg(self, *exprs: ArrowExpr) -> ArrowDataFrame: function_name = self._leaf_name(expr) if function_name in {"std", "var"}: - assert "ddof" in expr._scalar_kwargs # noqa: S101 - option: Any = pc.VarianceOptions(ddof=expr._scalar_kwargs["ddof"]) + last_node = next(md.op_nodes_reversed()) + option: Any = pc.VarianceOptions(**last_node.kwargs) elif function_name in {"len", "n_unique"}: option = pc.CountOptions(mode="all") elif function_name == "count": diff --git a/narwhals/_arrow/namespace.py b/narwhals/_arrow/namespace.py index 3799aa87b2..abfd678917 100644 --- a/narwhals/_arrow/namespace.py +++ b/narwhals/_arrow/namespace.py @@ -13,7 +13,7 @@ from narwhals._arrow.selectors import ArrowSelectorNamespace from narwhals._arrow.series import ArrowSeries from narwhals._arrow.utils import cast_to_comparable_string_types -from narwhals._compliant import CompliantThen, EagerNamespace, EagerWhen +from narwhals._compliant import EagerNamespace from narwhals._expression_parsing import ( combine_alias_output_names, combine_evaluate_output_names, @@ -23,8 +23,7 @@ if TYPE_CHECKING: from collections.abc import Iterator, Sequence - from narwhals._arrow.typing import ArrayOrScalar, ChunkedArrayAny, Incomplete - from narwhals._compliant.typing import ScalarKwargs + from narwhals._arrow.typing import ChunkedArrayAny, Incomplete from narwhals._utils import Version from narwhals.typing import IntoDType, NonNestedLiteral @@ -55,8 +54,6 @@ def len(self) -> ArrowExpr: lambda df: [ ArrowSeries.from_iterable([len(df.native)], name="len", context=self) ], - depth=0, - function_name="len", evaluate_output_names=lambda _df: ["len"], alias_output_names=None, version=self._version, @@ -73,8 +70,6 @@ def _lit_arrow_series(_: ArrowDataFrame) -> ArrowSeries: return self._expr( lambda df: [_lit_arrow_series(df)], - depth=0, - function_name="lit", evaluate_output_names=lambda _df: ["literal"], alias_output_names=None, version=self._version, @@ -90,8 +85,6 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]: return self._expr._from_callable( func=func, - depth=max(x._depth for x in exprs) + 1, - function_name="all_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), alias_output_names=combine_alias_output_names(*exprs), context=self, @@ -107,8 +100,6 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]: return self._expr._from_callable( func=func, - depth=max(x._depth for x in exprs) + 1, - function_name="any_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), alias_output_names=combine_alias_output_names(*exprs), context=self, @@ -123,8 +114,6 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]: return self._expr._from_callable( func=func, - depth=max(x._depth for x in exprs) + 1, - function_name="sum_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), alias_output_names=combine_alias_output_names(*exprs), context=self, @@ -144,8 +133,6 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]: return self._expr._from_callable( func=func, - depth=max(x._depth for x in exprs) + 1, - function_name="mean_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), alias_output_names=combine_alias_output_names(*exprs), context=self, @@ -165,8 +152,6 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]: return self._expr._from_callable( func=func, - depth=max(x._depth for x in exprs) + 1, - function_name="min_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), alias_output_names=combine_alias_output_names(*exprs), context=self, @@ -186,8 +171,6 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]: return self._expr._from_callable( func=func, - depth=max(x._depth for x in exprs) + 1, - function_name="max_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), alias_output_names=combine_alias_output_names(*exprs), context=self, @@ -220,9 +203,6 @@ def _concat_vertical(self, dfs: Sequence[pa.Table], /) -> pa.Table: def selectors(self) -> ArrowSelectorNamespace: return ArrowSelectorNamespace.from_namespace(self) - def when(self, predicate: ArrowExpr) -> ArrowWhen: - return ArrowWhen.from_expr(predicate, context=self) - def concat_str( self, *exprs: ArrowExpr, separator: str, ignore_nulls: bool ) -> ArrowExpr: @@ -250,8 +230,6 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]: return self._expr._from_callable( func=func, - depth=max(x._depth for x in exprs) + 1, - function_name="concat_str", evaluate_output_names=combine_evaluate_output_names(*exprs), alias_output_names=combine_alias_output_names(*exprs), context=self, @@ -271,33 +249,16 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]: return self._expr._from_callable( func=func, - depth=max(x._depth for x in exprs) + 1, - function_name="coalesce", evaluate_output_names=combine_evaluate_output_names(*exprs), alias_output_names=combine_alias_output_names(*exprs), context=self, ) - -class ArrowWhen(EagerWhen[ArrowDataFrame, ArrowSeries, ArrowExpr, "ChunkedArrayAny"]): - @property - def _then(self) -> type[ArrowThen]: - return ArrowThen - def _if_then_else( self, when: ChunkedArrayAny, then: ChunkedArrayAny, - otherwise: ArrayOrScalar | NonNestedLiteral, - /, + otherwise: ChunkedArrayAny | None = None, ) -> ChunkedArrayAny: otherwise = pa.nulls(len(when), then.type) if otherwise is None else otherwise return pc.if_else(when, then, otherwise) - - -class ArrowThen( - CompliantThen[ArrowDataFrame, ArrowSeries, ArrowExpr, ArrowWhen], ArrowExpr -): - _depth: int = 0 - _scalar_kwargs: ScalarKwargs = {} # noqa: RUF012 - _function_name: str = "whenthen" diff --git a/narwhals/_arrow/selectors.py b/narwhals/_arrow/selectors.py index 459e0022bb..62bb014083 100644 --- a/narwhals/_arrow/selectors.py +++ b/narwhals/_arrow/selectors.py @@ -8,7 +8,6 @@ if TYPE_CHECKING: from narwhals._arrow.dataframe import ArrowDataFrame # noqa: F401 from narwhals._arrow.series import ArrowSeries # noqa: F401 - from narwhals._compliant.typing import ScalarKwargs class ArrowSelectorNamespace(EagerSelectorNamespace["ArrowDataFrame", "ArrowSeries"]): @@ -18,15 +17,9 @@ def _selector(self) -> type[ArrowSelector]: class ArrowSelector(CompliantSelector["ArrowDataFrame", "ArrowSeries"], ArrowExpr): # type: ignore[misc] - _depth: int = 0 - _scalar_kwargs: ScalarKwargs = {} # noqa: RUF012 - _function_name: str = "selector" - def _to_expr(self) -> ArrowExpr: return ArrowExpr( self._call, - depth=self._depth, - function_name=self._function_name, evaluate_output_names=self._evaluate_output_names, alias_output_names=self._alias_output_names, version=self._version, diff --git a/narwhals/_arrow/series.py b/narwhals/_arrow/series.py index ad35218148..2ae902d442 100644 --- a/narwhals/_arrow/series.py +++ b/narwhals/_arrow/series.py @@ -489,9 +489,7 @@ def to_numpy(self, dtype: Any = None, *, copy: bool | None = None) -> _1DArray: return self.native.to_numpy() def alias(self, name: str) -> Self: - result = self.__class__(self.native, name=name, version=self._version) - result._broadcast = self._broadcast - return result + return self.__class__(self.native, name=name, version=self._version) @property def dtype(self) -> DType: @@ -868,8 +866,8 @@ def mode(self, *, keep: ModeKeepStrategy) -> ArrowSeries: name=col_token, normalize=False, sort=False, parallel=False ) result = counts.filter( - plx.col(col_token) - == plx.col(col_token).max().broadcast(kind=ExprKind.AGGREGATION) + plx.col([col_token]) + == plx.col([col_token]).max().broadcast(kind=ExprKind.AGGREGATION) ).get_column(self.name) return result.head(1) if keep == "any" else result diff --git a/narwhals/_arrow/series_str.py b/narwhals/_arrow/series_str.py index 4b3fe0ee1d..1e1b6e752f 100644 --- a/narwhals/_arrow/series_str.py +++ b/narwhals/_arrow/series_str.py @@ -18,7 +18,7 @@ class ArrowSeriesStringNamespace(ArrowSeriesNamespace, StringNamespace["ArrowSer def len_chars(self) -> ArrowSeries: return self.with_native(pc.utf8_length(self.native)) - def replace(self, pattern: str, value: str, *, literal: bool, n: int) -> ArrowSeries: + def replace(self, value: str, pattern: str, *, literal: bool, n: int) -> ArrowSeries: fn = pc.replace_substring if literal else pc.replace_substring_regex try: arr = fn(self.native, pattern, replacement=value, max_replacements=n) @@ -29,9 +29,9 @@ def replace(self, pattern: str, value: str, *, literal: bool, n: int) -> ArrowSe raise return self.with_native(arr) - def replace_all(self, pattern: str, value: str, *, literal: bool) -> ArrowSeries: + def replace_all(self, value: str, pattern: str, *, literal: bool) -> ArrowSeries: try: - return self.replace(pattern, value, literal=literal, n=-1) + return self.replace(value, pattern, literal=literal, n=-1) except TypeError as e: if not isinstance(value, str): msg = "PyArrow backed `.str.replace_all` only supports str replacement values." diff --git a/narwhals/_compliant/__init__.py b/narwhals/_compliant/__init__.py index 70bd22588b..600ca4e1fd 100644 --- a/narwhals/_compliant/__init__.py +++ b/narwhals/_compliant/__init__.py @@ -53,7 +53,6 @@ NativeFrameT_co, NativeSeriesT_co, ) -from narwhals._compliant.when_then import CompliantThen, CompliantWhen, EagerWhen from narwhals._compliant.window import WindowInputs __all__ = [ @@ -70,8 +69,6 @@ "CompliantSeries", "CompliantSeriesOrNativeExprT_co", "CompliantSeriesT", - "CompliantThen", - "CompliantWhen", "DepthTrackingExpr", "DepthTrackingGroupBy", "DepthTrackingNamespace", @@ -90,7 +87,6 @@ "EagerSeriesStringNamespace", "EagerSeriesStructNamespace", "EagerSeriesT", - "EagerWhen", "EvalNames", "EvalSeries", "LazyExpr", diff --git a/narwhals/_compliant/any_namespace.py b/narwhals/_compliant/any_namespace.py index 54df3160fd..fc78bfdefb 100644 --- a/narwhals/_compliant/any_namespace.py +++ b/narwhals/_compliant/any_namespace.py @@ -86,10 +86,10 @@ class StringNamespace(_StoresCompliant[CompliantT_co], Protocol[CompliantT_co]): def len_chars(self) -> CompliantT_co: ... def replace( - self, pattern: str, value: str, *, literal: bool, n: int + self, value: str, pattern: str, *, literal: bool, n: int ) -> CompliantT_co: ... def replace_all( - self, pattern: str, value: str, *, literal: bool + self, value: str, pattern: str, *, literal: bool ) -> CompliantT_co: ... def strip_chars(self, characters: str | None) -> CompliantT_co: ... def starts_with(self, prefix: str) -> CompliantT_co: ... diff --git a/narwhals/_compliant/column.py b/narwhals/_compliant/column.py index 78748a1c8c..1dc680c023 100644 --- a/narwhals/_compliant/column.py +++ b/narwhals/_compliant/column.py @@ -105,66 +105,6 @@ def is_between( return (self > lower_bound) & (self < upper_bound) return (self >= lower_bound) & (self <= upper_bound) - def is_close( - self, - other: Self | NumericLiteral, - *, - abs_tol: float, - rel_tol: float, - nans_equal: bool, - ) -> Self: - from decimal import Decimal - - other_abs: Self | NumericLiteral - other_is_nan: Self | bool - other_is_inf: Self | bool - other_is_not_inf: Self | bool - - if isinstance(other, (float, int, Decimal)): - from math import isinf, isnan - - # NOTE: See https://discuss.python.org/t/inferred-type-of-function-that-calls-dunder-abs-abs/101447 - other_abs = other.__abs__() - other_is_nan = isnan(other) - other_is_inf = isinf(other) - - # Define the other_is_not_inf variable to prevent triggering the following warning: - # > DeprecationWarning: Bitwise inversion '~' on bool is deprecated and will be - # > removed in Python 3.16. - other_is_not_inf = not other_is_inf - - else: - other_abs, other_is_nan = other.abs(), other.is_nan() - other_is_not_inf = other.is_finite() | other_is_nan - other_is_inf = ~other_is_not_inf - - rel_threshold = self.abs().clip(lower_bound=other_abs, upper_bound=None) * rel_tol - tolerance = rel_threshold.clip(lower_bound=abs_tol, upper_bound=None) - - self_is_nan = self.is_nan() - self_is_not_inf = self.is_finite() | self_is_nan - - # Values are close if abs_diff <= tolerance, and both finite - is_close = ( - ((self - other).abs() <= tolerance) & self_is_not_inf & other_is_not_inf - ) - - # Handle infinity cases: infinities are close/equal if they have the same sign - self_sign, other_sign = self > 0, other > 0 - is_same_inf = (~self_is_not_inf) & other_is_inf & (self_sign == other_sign) - - # Handle nan cases: - # * If any value is NaN, then False (via `& ~either_nan`) - # * However, if `nans_equals = True` and if _both_ values are NaN, then True - either_nan = self_is_nan | other_is_nan - result = (is_close | is_same_inf) & ~either_nan - - if nans_equal: - both_nan = self_is_nan & other_is_nan - result = result | both_nan - - return result - def is_duplicated(self) -> Self: return ~self.is_unique() diff --git a/narwhals/_compliant/dataframe.py b/narwhals/_compliant/dataframe.py index 8e7b8b5234..c369663fe5 100644 --- a/narwhals/_compliant/dataframe.py +++ b/narwhals/_compliant/dataframe.py @@ -37,6 +37,7 @@ is_slice_index, is_slice_none, ) +from narwhals.exceptions import MultiOutputExpressionError if TYPE_CHECKING: from io import BytesIO @@ -153,7 +154,6 @@ def simple_select(self, *column_names: str) -> Self: def sort( self, *by: str, descending: bool | Sequence[bool], nulls_last: bool ) -> Self: ... - def tail(self, n: int) -> Self: ... def unique( self, subset: Sequence[str] | None, @@ -337,7 +337,9 @@ def _check_columns_exist(self, subset: Sequence[str]) -> ColumnNotFoundError | N def _evaluate_expr(self, expr: EagerExprT, /) -> EagerSeriesT: """Evaluate `expr` and ensure it has a **single** output.""" result: Sequence[EagerSeriesT] = expr(self) - assert len(result) == 1 # debug assertion # noqa: S101 + if len(result) != 1: + msg = "multi-output expressions not allowed in this context" + raise MultiOutputExpressionError(msg) return result[0] def _evaluate_into_exprs(self, *exprs: EagerExprT) -> Sequence[EagerSeriesT]: diff --git a/narwhals/_compliant/expr.py b/narwhals/_compliant/expr.py index 55ed2d3e47..54ae5d7f41 100644 --- a/narwhals/_compliant/expr.py +++ b/narwhals/_compliant/expr.py @@ -35,6 +35,7 @@ zip_strict, ) from narwhals.dependencies import is_numpy_array, is_numpy_scalar +from narwhals.exceptions import MultiOutputExpressionError if TYPE_CHECKING: from collections.abc import Mapping, Sequence @@ -43,7 +44,7 @@ from narwhals._compliant.namespace import CompliantNamespace, EagerNamespace from narwhals._compliant.series import CompliantSeries - from narwhals._compliant.typing import AliasNames, EvalNames, EvalSeries, ScalarKwargs + from narwhals._compliant.typing import AliasNames, EvalNames, EvalSeries from narwhals._expression_parsing import ExprKind, ExprMetadata from narwhals._utils import Implementation, Version, _LimitedContext from narwhals.typing import ( @@ -90,7 +91,14 @@ class CompliantExpr( _implementation: Implementation _evaluate_output_names: EvalNames[CompliantFrameT] _alias_output_names: AliasNames | None - _metadata: ExprMetadata | None + _opt_metadata: ExprMetadata | None + + @property + def _metadata(self) -> ExprMetadata: + # This should be set with extreme care, and only at the Narwhals level or in + # `_expression_parsing.py`, and never from within any compliant class. + assert self._opt_metadata is not None # noqa: S101 + return self._opt_metadata def __call__( self, df: CompliantFrameT @@ -116,6 +124,7 @@ def broadcast( ) -> Self: ... # NOTE: `polars` + def alias(self, name: str) -> Self: ... def all(self) -> Self: ... def any(self) -> Self: ... def count(self) -> Self: ... @@ -170,9 +179,6 @@ class DepthTrackingExpr( ImplExpr[CompliantFrameT, CompliantSeriesOrNativeExprT_co], Protocol[CompliantFrameT, CompliantSeriesOrNativeExprT_co], ): - _depth: int - _function_name: str - # NOTE: pyright bug? # Method "from_column_names" overrides class "CompliantExpr" in an incompatible manner # Parameter 2 type mismatch: base parameter is type "EvalNames[CompliantFrameT@DepthTrackingExpr]", override parameter is type "EvalNames[CompliantFrameT@DepthTrackingExpr]" @@ -186,7 +192,6 @@ def from_column_names( # pyright: ignore[reportIncompatibleMethodOverride] /, *, context: _LimitedContext, - function_name: str = "", ) -> Self: ... def _is_elementary(self) -> bool: @@ -204,10 +209,7 @@ def _is_elementary(self) -> bool: Elementary expressions are the only ones supported properly in pandas, PyArrow, and Dask. """ - return self._depth < 2 - - def __repr__(self) -> str: # pragma: no cover - return f"{type(self).__name__}(depth={self._depth}, function_name={self._function_name})" + return len(list(self._metadata.op_nodes_reversed())) <= 2 class EagerExpr( @@ -215,19 +217,15 @@ class EagerExpr( Protocol[EagerDataFrameT, EagerSeriesT], ): _call: EvalSeries[EagerDataFrameT, EagerSeriesT] - _scalar_kwargs: ScalarKwargs def __init__( self, call: EvalSeries[EagerDataFrameT, EagerSeriesT], *, - depth: int, - function_name: str, evaluate_output_names: EvalNames[EagerDataFrameT], alias_output_names: AliasNames | None, implementation: Implementation, version: Version, - scalar_kwargs: ScalarKwargs | None = None, ) -> None: ... def __call__(self, df: EagerDataFrameT) -> Sequence[EagerSeriesT]: @@ -241,30 +239,22 @@ def _from_callable( cls, func: EvalSeries[EagerDataFrameT, EagerSeriesT], *, - depth: int, - function_name: str, evaluate_output_names: EvalNames[EagerDataFrameT], alias_output_names: AliasNames | None, context: _LimitedContext, - scalar_kwargs: ScalarKwargs | None = None, ) -> Self: return cls( func, - depth=depth, - function_name=function_name, evaluate_output_names=evaluate_output_names, alias_output_names=alias_output_names, implementation=context._implementation, version=context._version, - scalar_kwargs=scalar_kwargs, ) @classmethod def _from_series(cls, series: EagerSeriesT) -> Self: return cls( lambda _df: [series], - depth=0, - function_name="series", evaluate_output_names=lambda _df: [series.name], alias_output_names=None, implementation=series._implementation, @@ -300,13 +290,10 @@ def func(df: EagerDataFrameT) -> list[EagerSeriesT]: return self.__class__( func, - depth=self._depth, - function_name=self._function_name, evaluate_output_names=self._evaluate_output_names, alias_output_names=alias_output_names, implementation=self._implementation, version=self._version, - scalar_kwargs=self._scalar_kwargs, ) def _reuse_series( @@ -314,7 +301,6 @@ def _reuse_series( method_name: str, *, returns_scalar: bool = False, - scalar_kwargs: ScalarKwargs | None = None, **expressifiable_args: Any, ) -> Self: """Reuse Series implementation for expression. @@ -326,8 +312,6 @@ def _reuse_series( method_name: name of method. returns_scalar: whether the Series version returns a scalar. In this case, the expression version should return a 1-row Series. - scalar_kwargs: non-expressifiable args which we may need to reuse in `agg` or `over`, - such as `ddof` for `std` and `var`. expressifiable_args: keyword arguments to pass to function, which may be expressifiable (e.g. `nw.col('a').is_between(3, nw.col('b')))`). """ @@ -335,16 +319,12 @@ def _reuse_series( self._reuse_series_inner, method_name=method_name, returns_scalar=returns_scalar, - scalar_kwargs=scalar_kwargs or {}, expressifiable_args=expressifiable_args, ) return self._from_callable( func, - depth=self._depth + 1, - function_name=f"{self._function_name}->{method_name}", evaluate_output_names=self._evaluate_output_names, alias_output_names=self._alias_output_names, - scalar_kwargs=scalar_kwargs, context=self, ) @@ -365,15 +345,13 @@ def _reuse_series_inner( *, method_name: str, returns_scalar: bool, - scalar_kwargs: ScalarKwargs, expressifiable_args: dict[str, Any], ) -> Sequence[EagerSeriesT]: kwargs = { - **scalar_kwargs, **{ name: df._evaluate_expr(value) if self._is_expr(value) else value for name, value in expressifiable_args.items() - }, + } } method = methodcaller( method_name, @@ -424,11 +402,8 @@ def inner(df: EagerDataFrameT) -> list[EagerSeriesT]: return self._from_callable( inner, - depth=self._depth + 1, - function_name=f"{self._function_name}->{series_namespace}.{method_name}", evaluate_output_names=self._evaluate_output_names, alias_output_names=self._alias_output_names, - scalar_kwargs=self._scalar_kwargs, context=self, ) @@ -445,13 +420,10 @@ def func(df: EagerDataFrameT) -> list[EagerSeriesT]: return type(self)( func, - depth=self._depth, - function_name=self._function_name, evaluate_output_names=self._evaluate_output_names, alias_output_names=self._alias_output_names, implementation=self._implementation, version=self._version, - scalar_kwargs=self._scalar_kwargs, ) def cast(self, dtype: IntoDType) -> Self: @@ -547,14 +519,10 @@ def median(self) -> Self: return self._reuse_series("median", returns_scalar=True) def std(self, *, ddof: int) -> Self: - return self._reuse_series( - "std", returns_scalar=True, scalar_kwargs={"ddof": ddof} - ) + return self._reuse_series("std", returns_scalar=True, ddof=ddof) def var(self, *, ddof: int) -> Self: - return self._reuse_series( - "var", returns_scalar=True, scalar_kwargs={"ddof": ddof} - ) + return self._reuse_series("var", returns_scalar=True, ddof=ddof) def skew(self) -> Self: return self._reuse_series("skew", returns_scalar=True) @@ -607,7 +575,7 @@ def fill_null( limit: int | None, ) -> Self: return self._reuse_series( - "fill_null", value=value, scalar_kwargs={"strategy": strategy, "limit": limit} + "fill_null", value=value, strategy=strategy, limit=limit ) def is_in(self, other: Any) -> Self: @@ -663,20 +631,15 @@ def alias(self, name: str) -> Self: def alias_output_names(names: Sequence[str]) -> Sequence[str]: if len(names) != 1: msg = f"Expected function with single output, found output names: {names}" - raise ValueError(msg) + raise MultiOutputExpressionError(msg) return [name] - # Define this one manually, so that we can - # override `output_names` and not increase depth return type(self)( lambda df: [series.alias(name) for series in self(df)], - depth=self._depth, - function_name=self._function_name, evaluate_output_names=self._evaluate_output_names, alias_output_names=alias_output_names, implementation=self._implementation, version=self._version, - scalar_kwargs=self._scalar_kwargs, ) def is_unique(self) -> Self: @@ -694,14 +657,15 @@ def quantile( return self._reuse_series( "quantile", returns_scalar=True, - scalar_kwargs={"quantile": quantile, "interpolation": interpolation}, + quantile=quantile, + interpolation=interpolation, ) def head(self, n: int) -> Self: - return self._reuse_series("head", scalar_kwargs={"n": n}) + return self._reuse_series("head", n=n) def tail(self, n: int) -> Self: - return self._reuse_series("tail", scalar_kwargs={"n": n}) + return self._reuse_series("tail", n=n) def round(self, decimals: int) -> Self: return self._reuse_series("round", decimals=decimals) @@ -713,7 +677,7 @@ def gather_every(self, n: int, offset: int) -> Self: return self._reuse_series("gather_every", n=n, offset=offset) def mode(self, *, keep: ModeKeepStrategy) -> Self: - return self._reuse_series("mode", scalar_kwargs={"keep": keep}) + return self._reuse_series("mode", keep=keep) def is_finite(self) -> Self: return self._reuse_series("is_finite") @@ -721,11 +685,9 @@ def is_finite(self) -> Self: def rolling_mean(self, window_size: int, *, min_samples: int, center: bool) -> Self: return self._reuse_series( "rolling_mean", - scalar_kwargs={ - "window_size": window_size, - "min_samples": min_samples, - "center": center, - }, + window_size=window_size, + min_samples=min_samples, + center=center, ) def rolling_std( @@ -733,22 +695,15 @@ def rolling_std( ) -> Self: return self._reuse_series( "rolling_std", - scalar_kwargs={ - "window_size": window_size, - "min_samples": min_samples, - "center": center, - "ddof": ddof, - }, + window_size=window_size, + min_samples=min_samples, + center=center, + ddof=ddof, ) def rolling_sum(self, window_size: int, *, min_samples: int, center: bool) -> Self: return self._reuse_series( - "rolling_sum", - scalar_kwargs={ - "window_size": window_size, - "min_samples": min_samples, - "center": center, - }, + "rolling_sum", window_size=window_size, min_samples=min_samples, center=center ) def rolling_var( @@ -756,12 +711,10 @@ def rolling_var( ) -> Self: return self._reuse_series( "rolling_var", - scalar_kwargs={ - "window_size": window_size, - "min_samples": min_samples, - "center": center, - "ddof": ddof, - }, + window_size=window_size, + min_samples=min_samples, + center=center, + ddof=ddof, ) def map_batches( @@ -805,35 +758,31 @@ def func(df: EagerDataFrameT) -> Sequence[EagerSeriesT]: return self._from_callable( func, - depth=self._depth + 1, - function_name=self._function_name + "->map_batches", evaluate_output_names=self._evaluate_output_names, alias_output_names=self._alias_output_names, context=self, ) def shift(self, n: int) -> Self: - return self._reuse_series("shift", scalar_kwargs={"n": n}) + return self._reuse_series("shift", n=n) def cum_sum(self, *, reverse: bool) -> Self: - return self._reuse_series("cum_sum", scalar_kwargs={"reverse": reverse}) + return self._reuse_series("cum_sum", reverse=reverse) def cum_count(self, *, reverse: bool) -> Self: - return self._reuse_series("cum_count", scalar_kwargs={"reverse": reverse}) + return self._reuse_series("cum_count", reverse=reverse) def cum_min(self, *, reverse: bool) -> Self: - return self._reuse_series("cum_min", scalar_kwargs={"reverse": reverse}) + return self._reuse_series("cum_min", reverse=reverse) def cum_max(self, *, reverse: bool) -> Self: - return self._reuse_series("cum_max", scalar_kwargs={"reverse": reverse}) + return self._reuse_series("cum_max", reverse=reverse) def cum_prod(self, *, reverse: bool) -> Self: - return self._reuse_series("cum_prod", scalar_kwargs={"reverse": reverse}) + return self._reuse_series("cum_prod", reverse=reverse) def rank(self, method: RankMethod, *, descending: bool) -> Self: - return self._reuse_series( - "rank", scalar_kwargs={"method": method, "descending": descending} - ) + return self._reuse_series("rank", method=method, descending=descending) def log(self, base: float) -> Self: return self._reuse_series("log", base=base) @@ -851,22 +800,6 @@ def is_between( "is_between", lower_bound=lower_bound, upper_bound=upper_bound, closed=closed ) - def is_close( - self, - other: Self | NumericLiteral, - *, - abs_tol: float, - rel_tol: float, - nans_equal: bool, - ) -> Self: - return self._reuse_series( - "is_close", - other=other, - abs_tol=abs_tol, - rel_tol=rel_tol, - nans_equal=nans_equal, - ) - @property def cat(self) -> EagerExprCatNamespace[Self]: return EagerExprCatNamespace(self) @@ -1099,12 +1032,12 @@ class EagerExprStringNamespace( def len_chars(self) -> EagerExprT: return self.compliant._reuse_series_namespace("str", "len_chars") - def replace(self, pattern: str, value: str, *, literal: bool, n: int) -> EagerExprT: + def replace(self, value: str, pattern: str, *, literal: bool, n: int) -> EagerExprT: return self.compliant._reuse_series_namespace( "str", "replace", pattern=pattern, value=value, literal=literal, n=n ) - def replace_all(self, pattern: str, value: str, *, literal: bool) -> EagerExprT: + def replace_all(self, value: str, pattern: str, *, literal: bool) -> EagerExprT: return self.compliant._reuse_series_namespace( "str", "replace_all", pattern=pattern, value=value, literal=literal ) diff --git a/narwhals/_compliant/group_by.py b/narwhals/_compliant/group_by.py index f9529cd442..217209980e 100644 --- a/narwhals/_compliant/group_by.py +++ b/narwhals/_compliant/group_by.py @@ -1,6 +1,5 @@ from __future__ import annotations -import re from itertools import chain from typing import TYPE_CHECKING, Any, Callable, ClassVar, Protocol, TypeVar @@ -31,9 +30,6 @@ ) -_RE_LEAF_NAME: re.Pattern[str] = re.compile(r"(\w+->)") - - def _evaluate_aliases( frame: CompliantFrameT, exprs: Iterable[ImplExpr[CompliantFrameT, Any]], / ) -> list[str]: @@ -170,7 +166,7 @@ def _remap_expr_name( @classmethod def _leaf_name(cls, expr: DepthTrackingExprAny, /) -> NarwhalsAggregation | Any: """Return the last function name in the chain defined by `expr`.""" - return _RE_LEAF_NAME.sub("", expr._function_name) + return next(expr._metadata.op_nodes_reversed()).name class EagerGroupBy( diff --git a/narwhals/_compliant/namespace.py b/narwhals/_compliant/namespace.py index e3c5bb5a34..9d767b5e80 100644 --- a/narwhals/_compliant/namespace.py +++ b/narwhals/_compliant/namespace.py @@ -10,37 +10,35 @@ DepthTrackingExprT, EagerDataFrameT, EagerExprT, - EagerSeriesT, + EagerSeriesT_co, LazyExprT, NativeFrameT, NativeFrameT_co, NativeSeriesT, ) -from narwhals._expression_parsing import is_expr, is_series +from narwhals._expression_parsing import is_expr from narwhals._utils import ( exclude_column_names, get_column_names, + is_compliant_expr, passthrough_column_names, ) -from narwhals.dependencies import is_numpy_array, is_numpy_array_2d +from narwhals.dependencies import is_numpy_array_2d if TYPE_CHECKING: - from collections.abc import Container, Iterable, Sequence + from collections.abc import Iterable, Sequence from typing_extensions import TypeAlias from narwhals._compliant.selectors import CompliantSelectorNamespace - from narwhals._compliant.when_then import CompliantWhen, EagerWhen from narwhals._utils import Implementation, Version from narwhals.expr import Expr - from narwhals.series import Series from narwhals.typing import ( ConcatMethod, Into1DArray, IntoDType, IntoSchema, NonNestedLiteral, - _1DArray, _2DArray, ) @@ -61,33 +59,32 @@ class CompliantNamespace(Protocol[CompliantFrameT, CompliantExprT]): @property def _expr(self) -> type[CompliantExprT]: ... - def parse_into_expr( - self, data: Expr | NonNestedLiteral | Any, /, *, str_as_lit: bool + def evaluate_expr( + self, data: Expr | NonNestedLiteral | Any, / ) -> CompliantExprT | NonNestedLiteral: if is_expr(data): - expr = data._to_compliant_expr(self) + expr = data(self) assert isinstance(expr, self._expr) # noqa: S101 return expr - if isinstance(data, str) and not str_as_lit: - return self.col(data) + # TODO(marco): it would be nice to return `lit(data)` here, + # but for pandas and Dask this causes some issues. return data # NOTE: `polars` def all(self) -> CompliantExprT: return self._expr.from_column_names(get_column_names, context=self) - def col(self, *column_names: str) -> CompliantExprT: - return self._expr.from_column_names( - passthrough_column_names(column_names), context=self - ) + def col(self, names: Sequence[str]) -> CompliantExprT: + assert not isinstance(names, str) # noqa: S101 # debug assertion + return self._expr.from_column_names(passthrough_column_names(names), context=self) - def exclude(self, excluded_names: Container[str]) -> CompliantExprT: + def exclude(self, names: Sequence[str]) -> CompliantExprT: return self._expr.from_column_names( - partial(exclude_column_names, names=excluded_names), context=self + partial(exclude_column_names, names=names), context=self ) - def nth(self, *column_indices: int) -> CompliantExprT: - return self._expr.from_column_indices(*column_indices, context=self) + def nth(self, indices: Sequence[int]) -> CompliantExprT: + return self._expr.from_column_indices(*indices, context=self) def len(self) -> CompliantExprT: ... def lit(self, value: NonNestedLiteral, dtype: IntoDType | None) -> CompliantExprT: ... @@ -104,9 +101,6 @@ def max_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ... def concat( self, items: Iterable[CompliantFrameT], *, how: ConcatMethod ) -> CompliantFrameT: ... - def when( - self, predicate: CompliantExprT - ) -> CompliantWhen[CompliantFrameT, Incomplete, CompliantExprT]: ... def concat_str( self, *exprs: CompliantExprT, separator: str, ignore_nulls: bool ) -> CompliantExprT: ... @@ -122,20 +116,14 @@ class DepthTrackingNamespace( Protocol[CompliantFrameT, DepthTrackingExprT], ): def all(self) -> DepthTrackingExprT: - return self._expr.from_column_names( - get_column_names, function_name="all", context=self - ) + return self._expr.from_column_names(get_column_names, context=self) - def col(self, *column_names: str) -> DepthTrackingExprT: - return self._expr.from_column_names( - passthrough_column_names(column_names), function_name="col", context=self - ) + def col(self, names: Sequence[str]) -> DepthTrackingExprT: + return self._expr.from_column_names(passthrough_column_names(names), context=self) - def exclude(self, excluded_names: Container[str]) -> DepthTrackingExprT: + def exclude(self, names: Sequence[str]) -> DepthTrackingExprT: return self._expr.from_column_names( - partial(exclude_column_names, names=excluded_names), - function_name="exclude", - context=self, + partial(exclude_column_names, names=names), context=self ) @@ -159,7 +147,7 @@ def from_native(self, data: NativeFrameT_co | Any, /) -> CompliantLazyFrameT: class EagerNamespace( DepthTrackingNamespace[EagerDataFrameT, EagerExprT], - Protocol[EagerDataFrameT, EagerSeriesT, EagerExprT, NativeFrameT, NativeSeriesT], + Protocol[EagerDataFrameT, EagerSeriesT_co, EagerExprT, NativeFrameT, NativeSeriesT], ): @property def _backend_version(self) -> tuple[int, ...]: @@ -168,18 +156,64 @@ def _backend_version(self) -> tuple[int, ...]: @property def _dataframe(self) -> type[EagerDataFrameT]: ... @property - def _series(self) -> type[EagerSeriesT]: ... - def when( - self, predicate: EagerExprT - ) -> EagerWhen[EagerDataFrameT, EagerSeriesT, EagerExprT, NativeSeriesT]: ... + def _series(self) -> type[EagerSeriesT_co]: ... + def _if_then_else( + self, + when: NativeSeriesT, + then: NativeSeriesT, + otherwise: NativeSeriesT | None = None, + ) -> NativeSeriesT: ... + def when_then( + self, + predicate: EagerExprT, + then: EagerExprT | NonNestedLiteral, + otherwise: EagerExprT | NonNestedLiteral | None = None, + ) -> EagerExprT: + def func(df: EagerDataFrameT) -> Sequence[EagerSeriesT_co]: + predicate_s = df._evaluate_expr(predicate) + align = predicate_s._align_full_broadcast + + if is_compliant_expr(then): + then_s = df._evaluate_expr(then) + else: + then_s = predicate_s._from_scalar(then).alias("literal") + then_s._broadcast = True + if otherwise is None: + predicate_s, then_s = align(predicate_s, then_s) + result = self._if_then_else(predicate_s.native, then_s.native) + + if is_compliant_expr(otherwise): + otherwise_s = df._evaluate_expr(otherwise) + elif otherwise is not None: + otherwise_s = predicate_s._from_scalar(otherwise).alias("literal") + otherwise_s._broadcast = True + + if otherwise is None: + predicate_s, then_s = align(predicate_s, then_s) + result = self._if_then_else(predicate_s.native, then_s.native) + else: + predicate_s, then_s, otherwise_s = align(predicate_s, then_s, otherwise_s) + result = self._if_then_else( + predicate_s.native, then_s.native, otherwise_s.native + ) + return [then_s._with_native(result)] + + return self._expr._from_callable( + func=func, + evaluate_output_names=getattr( + then, "_evaluate_output_names", lambda _df: ["literal"] + ), + alias_output_names=getattr(then, "_alias_output_names", None), + context=predicate, + ) @overload def from_native(self, data: NativeFrameT, /) -> EagerDataFrameT: ... @overload - def from_native(self, data: NativeSeriesT, /) -> EagerSeriesT: ... + def from_native(self, data: NativeSeriesT, /) -> EagerSeriesT_co: ... def from_native( self, data: NativeFrameT | NativeSeriesT | Any, / - ) -> EagerDataFrameT | EagerSeriesT: + ) -> EagerDataFrameT | EagerSeriesT_co: if self._dataframe._is_native(data): return self._dataframe.from_native(data, context=self) if self._series._is_native(data): @@ -187,23 +221,8 @@ def from_native( msg = f"Unsupported type: {type(data).__name__!r}" raise TypeError(msg) - def parse_into_expr( - self, - data: Expr | Series[NativeSeriesT] | _1DArray | NonNestedLiteral, - /, - *, - str_as_lit: bool, - ) -> EagerExprT | NonNestedLiteral: - if not (is_series(data) or is_numpy_array(data)): - return super().parse_into_expr(data, str_as_lit=str_as_lit) - return self._expr._from_series( - data._compliant_series - if is_series(data) - else self._series.from_numpy(data, context=self) - ) - @overload - def from_numpy(self, data: Into1DArray, /, schema: None = ...) -> EagerSeriesT: ... + def from_numpy(self, data: Into1DArray, /, schema: None = ...) -> EagerSeriesT_co: ... @overload def from_numpy( @@ -215,7 +234,7 @@ def from_numpy( data: Into1DArray | _2DArray, /, schema: IntoSchema | Sequence[str] | None = None, - ) -> EagerDataFrameT | EagerSeriesT: + ) -> EagerDataFrameT | EagerSeriesT_co: if is_numpy_array_2d(data): return self._dataframe.from_numpy(data, schema=schema, context=self) return self._series.from_numpy(data, context=self) diff --git a/narwhals/_compliant/series.py b/narwhals/_compliant/series.py index 561ff3de6d..2b0b02a3f1 100644 --- a/narwhals/_compliant/series.py +++ b/narwhals/_compliant/series.py @@ -37,6 +37,7 @@ from typing_extensions import NotRequired, Self, TypedDict from narwhals._compliant.dataframe import CompliantDataFrame + from narwhals._compliant.expr import CompliantExpr, EagerExpr from narwhals._compliant.namespace import EagerNamespace from narwhals._utils import Implementation, Version, _LimitedContext from narwhals.dtypes import DType @@ -96,6 +97,8 @@ def to_narwhals(self) -> Series[NativeSeriesT]: def _with_native(self, series: Any) -> Self: ... def _with_version(self, version: Version) -> Self: ... + def _to_expr(self) -> CompliantExpr[Any, Self]: ... + # NOTE: `polars` @property def dtype(self) -> DType: ... @@ -243,6 +246,9 @@ def __narwhals_namespace__( self, ) -> EagerNamespace[Any, Self, Any, Any, NativeSeriesT]: ... + def _to_expr(self) -> EagerExpr[Any, Any]: + return self.__narwhals_namespace__()._expr._from_series(self) # type: ignore[no-any-return] + def _gather(self, rows: SizedMultiIndexSelector[NativeSeriesT]) -> Self: ... def _gather_slice(self, rows: _SliceIndex | range) -> Self: ... def __getitem__(self, item: MultiIndexSelector[Self]) -> Self: diff --git a/narwhals/_compliant/when_then.py b/narwhals/_compliant/when_then.py deleted file mode 100644 index bc4db69382..0000000000 --- a/narwhals/_compliant/when_then.py +++ /dev/null @@ -1,130 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Any, Protocol, TypeVar, cast - -from narwhals._compliant.expr import CompliantExpr -from narwhals._compliant.typing import ( - CompliantExprAny, - CompliantFrameAny, - CompliantSeriesOrNativeExprAny, - EagerDataFrameT, - EagerExprT, - EagerSeriesT, - LazyExprAny, - NativeSeriesT, -) - -if TYPE_CHECKING: - from collections.abc import Sequence - - from typing_extensions import Self, TypeAlias - - from narwhals._compliant.typing import EvalSeries - from narwhals._utils import Implementation, Version, _LimitedContext - from narwhals.typing import NonNestedLiteral - - -__all__ = ["CompliantThen", "CompliantWhen", "EagerWhen"] - -ExprT = TypeVar("ExprT", bound=CompliantExprAny) -LazyExprT = TypeVar("LazyExprT", bound=LazyExprAny) -SeriesT = TypeVar("SeriesT", bound=CompliantSeriesOrNativeExprAny) -FrameT = TypeVar("FrameT", bound=CompliantFrameAny) - -Scalar: TypeAlias = Any -"""A native literal value.""" - -IntoExpr: TypeAlias = "SeriesT | ExprT | NonNestedLiteral | Scalar" -"""Anything that is convertible into a `CompliantExpr`.""" - - -class CompliantWhen(Protocol[FrameT, SeriesT, ExprT]): - _condition: ExprT - _then_value: IntoExpr[SeriesT, ExprT] - _otherwise_value: IntoExpr[SeriesT, ExprT] | None - _implementation: Implementation - _version: Version - - @property - def _then(self) -> type[CompliantThen[FrameT, SeriesT, ExprT, Self]]: ... - def __call__(self, compliant_frame: FrameT, /) -> Sequence[SeriesT]: ... - def then( - self, value: IntoExpr[SeriesT, ExprT], / - ) -> CompliantThen[FrameT, SeriesT, ExprT, Self]: - return self._then.from_when(self, value) - - @classmethod - def from_expr(cls, condition: ExprT, /, *, context: _LimitedContext) -> Self: - obj = cls.__new__(cls) - obj._condition = condition - obj._then_value = None - obj._otherwise_value = None - obj._implementation = context._implementation - obj._version = context._version - return obj - - -WhenT_contra = TypeVar( - "WhenT_contra", bound=CompliantWhen[Any, Any, Any], contravariant=True -) - - -class CompliantThen( - CompliantExpr[FrameT, SeriesT], Protocol[FrameT, SeriesT, ExprT, WhenT_contra] -): - _call: EvalSeries[FrameT, SeriesT] - _when_value: CompliantWhen[FrameT, SeriesT, ExprT] - _implementation: Implementation - _version: Version - - @classmethod - def from_when(cls, when: WhenT_contra, then: IntoExpr[SeriesT, ExprT], /) -> Self: - when._then_value = then - obj = cls.__new__(cls) - obj._call = when - obj._when_value = when - obj._evaluate_output_names = getattr( - then, "_evaluate_output_names", lambda _df: ["literal"] - ) - obj._alias_output_names = getattr(then, "_alias_output_names", None) - obj._implementation = when._implementation - obj._version = when._version - return obj - - def otherwise(self, otherwise: IntoExpr[SeriesT, ExprT], /) -> ExprT: - self._when_value._otherwise_value = otherwise - return cast("ExprT", self) - - -class EagerWhen( - CompliantWhen[EagerDataFrameT, EagerSeriesT, EagerExprT], - Protocol[EagerDataFrameT, EagerSeriesT, EagerExprT, NativeSeriesT], -): - def _if_then_else( - self, - when: NativeSeriesT, - then: NativeSeriesT, - otherwise: NativeSeriesT | NonNestedLiteral | Scalar, - /, - ) -> NativeSeriesT: ... - - def __call__(self, df: EagerDataFrameT, /) -> Sequence[EagerSeriesT]: - is_expr = self._condition._is_expr - when: EagerSeriesT = self._condition(df)[0] - then: EagerSeriesT - align = when._align_full_broadcast - - if is_expr(self._then_value): - then = self._then_value(df)[0] - else: - then = when.alias("literal")._from_scalar(self._then_value) - then._broadcast = True - - if is_expr(self._otherwise_value): - otherwise = self._otherwise_value(df)[0] - when, then, otherwise = align(when, then, otherwise) - result = self._if_then_else(when.native, then.native, otherwise.native) - else: - when, then = align(when, then) - result = self._if_then_else(when.native, then.native, self._otherwise_value) - return [then._with_native(result)] diff --git a/narwhals/_dask/dataframe.py b/narwhals/_dask/dataframe.py index 376bf6a887..4da9097f14 100644 --- a/narwhals/_dask/dataframe.py +++ b/narwhals/_dask/dataframe.py @@ -193,7 +193,7 @@ def drop_nulls(self, subset: Sequence[str] | None) -> Self: if subset is None: return self._with_native(self.native.dropna()) plx = self.__narwhals_namespace__() - mask = ~plx.any_horizontal(plx.col(*subset).is_null(), ignore_nulls=True) + mask = ~plx.any_horizontal(plx.col(subset).is_null(), ignore_nulls=True) return self.filter(mask) @property @@ -225,10 +225,12 @@ def with_row_index(self, name: str, order_by: Sequence[str] | None) -> Self: columns = self.columns const_expr = plx.lit(value=1, dtype=None).alias(name).broadcast(ExprKind.LITERAL) row_index_expr = ( - plx.col(name).cum_sum(reverse=False).over(partition_by=[], order_by=order_by) + plx.col([name]) + .cum_sum(reverse=False) + .over(partition_by=[], order_by=order_by) - 1 ) - return self.with_columns(const_expr).select(row_index_expr, plx.col(*columns)) + return self.with_columns(const_expr).select(row_index_expr, plx.col(columns)) def rename(self, mapping: Mapping[str, str]) -> Self: return self._with_native(self.native.rename(columns=mapping)) @@ -482,8 +484,8 @@ def gather_every(self, n: int, offset: int) -> Self: return ( self.with_row_index(row_index_token, order_by=None) .filter( - (plx.col(row_index_token) >= offset) - & ((plx.col(row_index_token) - offset) % n == 0) + (plx.col([row_index_token]) >= offset) + & ((plx.col([row_index_token]) - offset) % n == 0) ) .drop([row_index_token], strict=False) ) diff --git a/narwhals/_dask/expr.py b/narwhals/_dask/expr.py index db636bb504..a98b539902 100644 --- a/narwhals/_dask/expr.py +++ b/narwhals/_dask/expr.py @@ -14,6 +14,7 @@ narwhals_to_native_dtype, ) from narwhals._expression_parsing import ExprKind, evaluate_output_names_and_aliases +from narwhals._pandas_like.expr import window_kwargs_to_pandas_equivalent from narwhals._pandas_like.utils import get_dtype_backend, native_to_narwhals_dtype from narwhals._utils import ( Implementation, @@ -28,7 +29,12 @@ import dask.dataframe.dask_expr as dx from typing_extensions import Self - from narwhals._compliant.typing import AliasNames, EvalNames, EvalSeries, ScalarKwargs + from narwhals._compliant.typing import ( + AliasNames, + EvalNames, + EvalSeries, + NarwhalsAggregation, + ) from narwhals._dask.dataframe import DaskLazyFrame from narwhals._dask.namespace import DaskNamespace from narwhals._expression_parsing import ExprKind, ExprMetadata @@ -54,21 +60,15 @@ def __init__( self, call: EvalSeries[DaskLazyFrame, dx.Series], # pyright: ignore[reportInvalidTypeForm] *, - depth: int, - function_name: str, evaluate_output_names: EvalNames[DaskLazyFrame], alias_output_names: AliasNames | None, version: Version, - scalar_kwargs: ScalarKwargs | None = None, ) -> None: self._call = call - self._depth = depth - self._function_name = function_name self._evaluate_output_names = evaluate_output_names self._alias_output_names = alias_output_names self._version = version - self._scalar_kwargs = scalar_kwargs or {} - self._metadata: ExprMetadata | None = None + self._opt_metadata: ExprMetadata | None = None def __call__(self, df: DaskLazyFrame) -> Sequence[dx.Series]: return self._call(df) @@ -86,12 +86,9 @@ def func(df: DaskLazyFrame) -> list[dx.Series]: return self.__class__( func, - depth=self._depth, - function_name=self._function_name, evaluate_output_names=self._evaluate_output_names, alias_output_names=self._alias_output_names, version=self._version, - scalar_kwargs=self._scalar_kwargs, ) @classmethod @@ -101,7 +98,6 @@ def from_column_names( /, *, context: _LimitedContext, - function_name: str = "", ) -> Self: def func(df: DaskLazyFrame) -> list[dx.Series]: try: @@ -116,8 +112,6 @@ def func(df: DaskLazyFrame) -> list[dx.Series]: return cls( func, - depth=0, - function_name=function_name, evaluate_output_names=evaluate_column_names, alias_output_names=None, version=context._version, @@ -130,8 +124,6 @@ def func(df: DaskLazyFrame) -> list[dx.Series]: return cls( func, - depth=0, - function_name="nth", evaluate_output_names=cls._eval_names_indices(column_indices), alias_output_names=None, version=context._version, @@ -142,8 +134,6 @@ def _with_callable( # First argument to `call` should be `dx.Series` call: Callable[..., dx.Series], /, - expr_name: str = "", - scalar_kwargs: ScalarKwargs | None = None, **expressifiable_args: Self | Any, ) -> Self: def func(df: DaskLazyFrame) -> list[dx.Series]: @@ -160,12 +150,9 @@ def func(df: DaskLazyFrame) -> list[dx.Series]: return self.__class__( func, - depth=self._depth + 1, - function_name=f"{self._function_name}->{expr_name}", evaluate_output_names=self._evaluate_output_names, alias_output_names=self._alias_output_names, version=self._version, - scalar_kwargs=scalar_kwargs, ) def _with_alias_output_names(self, func: AliasNames | None, /) -> Self: @@ -179,12 +166,9 @@ def _with_alias_output_names(self, func: AliasNames | None, /) -> Self: ) return type(self)( call=self._call, - depth=self._depth, - function_name=self._function_name, evaluate_output_names=self._evaluate_output_names, alias_output_names=alias_output_names, version=self._version, - scalar_kwargs=self._scalar_kwargs, ) def _with_binary( @@ -195,9 +179,7 @@ def _with_binary( *, reverse: bool = False, ) -> Self: - result = self._with_callable( - lambda expr, other: call(expr, other), name, other=other - ) + result = self._with_callable(lambda expr, other: call(expr, other), other=other) if reverse: result = result.alias("literal") return result @@ -275,10 +257,10 @@ def __rmod__(self, other: Any) -> Self: return self._reverse_binary_op("__rmod__", lambda a, b: a % b, other) def __invert__(self) -> Self: - return self._with_callable(lambda expr: expr.__invert__(), "__invert__") + return self._with_callable(lambda expr: expr.__invert__()) def mean(self) -> Self: - return self._with_callable(lambda expr: expr.mean().to_series(), "mean") + return self._with_callable(lambda expr: expr.mean().to_series()) def median(self) -> Self: from narwhals.exceptions import InvalidOperationError @@ -290,36 +272,28 @@ def func(s: dx.Series) -> dx.Series: raise InvalidOperationError(msg) return s.median_approximate().to_series() - return self._with_callable(func, "median") + return self._with_callable(func) def min(self) -> Self: - return self._with_callable(lambda expr: expr.min().to_series(), "min") + return self._with_callable(lambda expr: expr.min().to_series()) def max(self) -> Self: - return self._with_callable(lambda expr: expr.max().to_series(), "max") + return self._with_callable(lambda expr: expr.max().to_series()) def std(self, *, ddof: int) -> Self: - return self._with_callable( - lambda expr: expr.std(ddof=ddof).to_series(), - "std", - scalar_kwargs={"ddof": ddof}, - ) + return self._with_callable(lambda expr: expr.std(ddof=ddof).to_series()) def var(self, *, ddof: int) -> Self: - return self._with_callable( - lambda expr: expr.var(ddof=ddof).to_series(), - "var", - scalar_kwargs={"ddof": ddof}, - ) + return self._with_callable(lambda expr: expr.var(ddof=ddof).to_series()) def skew(self) -> Self: - return self._with_callable(lambda expr: expr.skew().to_series(), "skew") + return self._with_callable(lambda expr: expr.skew().to_series()) def kurtosis(self) -> Self: - return self._with_callable(lambda expr: expr.kurtosis().to_series(), "kurtosis") + return self._with_callable(lambda expr: expr.kurtosis().to_series()) def shift(self, n: int) -> Self: - return self._with_callable(lambda expr: expr.shift(n), "shift") + return self._with_callable(lambda expr: expr.shift(n)) def cum_sum(self, *, reverse: bool) -> Self: if reverse: # pragma: no cover @@ -327,52 +301,48 @@ def cum_sum(self, *, reverse: bool) -> Self: msg = "`cum_sum(reverse=True)` is not supported with Dask backend" raise NotImplementedError(msg) - return self._with_callable(lambda expr: expr.cumsum(), "cum_sum") + return self._with_callable(lambda expr: expr.cumsum()) def cum_count(self, *, reverse: bool) -> Self: if reverse: # pragma: no cover msg = "`cum_count(reverse=True)` is not supported with Dask backend" raise NotImplementedError(msg) - return self._with_callable( - lambda expr: (~expr.isna()).astype(int).cumsum(), "cum_count" - ) + return self._with_callable(lambda expr: (~expr.isna()).astype(int).cumsum()) def cum_min(self, *, reverse: bool) -> Self: if reverse: # pragma: no cover msg = "`cum_min(reverse=True)` is not supported with Dask backend" raise NotImplementedError(msg) - return self._with_callable(lambda expr: expr.cummin(), "cum_min") + return self._with_callable(lambda expr: expr.cummin()) def cum_max(self, *, reverse: bool) -> Self: if reverse: # pragma: no cover msg = "`cum_max(reverse=True)` is not supported with Dask backend" raise NotImplementedError(msg) - return self._with_callable(lambda expr: expr.cummax(), "cum_max") + return self._with_callable(lambda expr: expr.cummax()) def cum_prod(self, *, reverse: bool) -> Self: if reverse: # pragma: no cover msg = "`cum_prod(reverse=True)` is not supported with Dask backend" raise NotImplementedError(msg) - return self._with_callable(lambda expr: expr.cumprod(), "cum_prod") + return self._with_callable(lambda expr: expr.cumprod()) def rolling_sum(self, window_size: int, *, min_samples: int, center: bool) -> Self: return self._with_callable( lambda expr: expr.rolling( window=window_size, min_periods=min_samples, center=center - ).sum(), - "rolling_sum", + ).sum() ) def rolling_mean(self, window_size: int, *, min_samples: int, center: bool) -> Self: return self._with_callable( lambda expr: expr.rolling( window=window_size, min_periods=min_samples, center=center - ).mean(), - "rolling_mean", + ).mean() ) def rolling_var( @@ -382,8 +352,7 @@ def rolling_var( return self._with_callable( lambda expr: expr.rolling( window=window_size, min_periods=min_samples, center=center - ).var(), - "rolling_var", + ).var() ) msg = "Dask backend only supports `ddof=1` for `rolling_var`" raise NotImplementedError(msg) @@ -395,42 +364,39 @@ def rolling_std( return self._with_callable( lambda expr: expr.rolling( window=window_size, min_periods=min_samples, center=center - ).std(), - "rolling_std", + ).std() ) msg = "Dask backend only supports `ddof=1` for `rolling_std`" raise NotImplementedError(msg) def sum(self) -> Self: - return self._with_callable(lambda expr: expr.sum().to_series(), "sum") + return self._with_callable(lambda expr: expr.sum().to_series()) def count(self) -> Self: - return self._with_callable(lambda expr: expr.count().to_series(), "count") + return self._with_callable(lambda expr: expr.count().to_series()) def round(self, decimals: int) -> Self: - return self._with_callable(lambda expr: expr.round(decimals), "round") + return self._with_callable(lambda expr: expr.round(decimals)) def unique(self) -> Self: - return self._with_callable(lambda expr: expr.unique(), "unique") + return self._with_callable(lambda expr: expr.unique()) def drop_nulls(self) -> Self: - return self._with_callable(lambda expr: expr.dropna(), "drop_nulls") + return self._with_callable(lambda expr: expr.dropna()) def abs(self) -> Self: - return self._with_callable(lambda expr: expr.abs(), "abs") + return self._with_callable(lambda expr: expr.abs()) def all(self) -> Self: return self._with_callable( lambda expr: expr.all( axis=None, skipna=True, split_every=False, out=None - ).to_series(), - "all", + ).to_series() ) def any(self) -> Self: return self._with_callable( - lambda expr: expr.any(axis=0, skipna=True, split_every=False).to_series(), - "any", + lambda expr: expr.any(axis=0, skipna=True, split_every=False).to_series() ) def fill_nan(self, value: float | None) -> Self: @@ -448,7 +414,7 @@ def func(expr: dx.Series) -> dx.Series: ) return expr.mask(mask, fill) # pyright: ignore[reportArgumentType] - return self._with_callable(func, "fill_nan") + return self._with_callable(func) def fill_null( self, @@ -467,7 +433,7 @@ def func(expr: dx.Series) -> dx.Series: ) return res_ser - return self._with_callable(func, "fill_null") + return self._with_callable(func) def clip( self, @@ -478,21 +444,18 @@ def clip( lambda expr, lower_bound, upper_bound: expr.clip( lower=lower_bound, upper=upper_bound ), - "clip", lower_bound=lower_bound, upper_bound=upper_bound, ) def diff(self) -> Self: - return self._with_callable(lambda expr: expr.diff(), "diff") + return self._with_callable(lambda expr: expr.diff()) def n_unique(self) -> Self: - return self._with_callable( - lambda expr: expr.nunique(dropna=False).to_series(), "n_unique" - ) + return self._with_callable(lambda expr: expr.nunique(dropna=False).to_series()) def is_null(self) -> Self: - return self._with_callable(lambda expr: expr.isna(), "is_null") + return self._with_callable(lambda expr: expr.isna()) def is_nan(self) -> Self: def func(expr: dx.Series) -> dx.Series: @@ -504,17 +467,17 @@ def func(expr: dx.Series) -> dx.Series: msg = f"`.is_nan` only supported for numeric dtypes and not {dtype}, did you mean `.is_null`?" raise InvalidOperationError(msg) - return self._with_callable(func, "is_null") + return self._with_callable(func) def len(self) -> Self: - return self._with_callable(lambda expr: expr.size.to_series(), "len") + return self._with_callable(lambda expr: expr.size.to_series()) def quantile( self, quantile: float, interpolation: RollingInterpolationMethod ) -> Self: if interpolation == "linear": - def func(expr: dx.Series, quantile: float) -> dx.Series: + def func(expr: dx.Series) -> dx.Series: if expr.npartitions > 1: msg = "`Expr.quantile` is not supported for Dask backend with multiple partitions." raise NotImplementedError(msg) @@ -522,7 +485,7 @@ def func(expr: dx.Series, quantile: float) -> dx.Series: q=quantile, method="dask" ).to_series() # pragma: no cover - return self._with_callable(func, "quantile", quantile=quantile) + return self._with_callable(func) msg = "`higher`, `lower`, `midpoint`, `nearest` - interpolation methods are not supported by Dask. Please use `linear` instead." raise NotImplementedError(msg) @@ -534,7 +497,7 @@ def func(expr: dx.Series) -> dx.Series: first_distinct_index = frame.groupby(_name).agg({col_token: "min"})[col_token] return frame[col_token].isin(first_distinct_index) - return self._with_callable(func, "is_first_distinct") + return self._with_callable(func) def is_last_distinct(self) -> Self: def func(expr: dx.Series) -> dx.Series: @@ -544,7 +507,7 @@ def func(expr: dx.Series) -> dx.Series: last_distinct_index = frame.groupby(_name).agg({col_token: "max"})[col_token] return frame[col_token].isin(last_distinct_index) - return self._with_callable(func, "is_last_distinct") + return self._with_callable(func) def is_unique(self) -> Self: def func(expr: dx.Series) -> dx.Series: @@ -556,15 +519,13 @@ def func(expr: dx.Series) -> dx.Series: == 1 ) - return self._with_callable(func, "is_unique") + return self._with_callable(func) def is_in(self, other: Any) -> Self: - return self._with_callable(lambda expr: expr.isin(other), "is_in") + return self._with_callable(lambda expr: expr.isin(other)) def null_count(self) -> Self: - return self._with_callable( - lambda expr: expr.isna().sum().to_series(), "null_count" - ) + return self._with_callable(lambda expr: expr.isna().sum().to_series()) def over(self, partition_by: Sequence[str], order_by: Sequence[str]) -> Self: # pandas is a required dependency of dask so it's safe to import this @@ -589,7 +550,8 @@ def func(df: DaskLazyFrame) -> Sequence[dx.Series]: msg = "`over` with `order_by` is not yet supported in Dask." raise NotImplementedError(msg) else: - function_name = PandasLikeGroupBy._leaf_name(self) + leaf_node = next(self._metadata.op_nodes_reversed()) + function_name = cast("NarwhalsAggregation", leaf_node.name) try: dask_function_name = PandasLikeGroupBy._REMAP_AGGS[function_name] except KeyError: @@ -611,16 +573,20 @@ def func(df: DaskLazyFrame) -> Sequence[dx.Series]: category=UserWarning, ) grouped = df.native.groupby(partition_by) + kwargs = leaf_node.kwargs + pandas_kwargs = window_kwargs_to_pandas_equivalent( + function_name, kwargs + ) if dask_function_name == "size": if len(output_names) != 1: # pragma: no cover msg = "Safety check failed, please report a bug." raise AssertionError(msg) res_native = grouped.transform( - dask_function_name, **self._scalar_kwargs + dask_function_name, **pandas_kwargs ).to_frame(output_names[0]) else: res_native = grouped[list(output_names)].transform( - dask_function_name, **self._scalar_kwargs + dask_function_name, **pandas_kwargs ) result_frame = df._with_native( res_native.rename(columns=dict(zip(output_names, aliases))) @@ -629,8 +595,6 @@ def func(df: DaskLazyFrame) -> Sequence[dx.Series]: return self.__class__( func, - depth=self._depth + 1, - function_name=self._function_name + "->over", evaluate_output_names=self._evaluate_output_names, alias_output_names=self._alias_output_names, version=self._version, @@ -641,12 +605,12 @@ def func(expr: dx.Series) -> dx.Series: native_dtype = narwhals_to_native_dtype(dtype, self._version) return expr.astype(native_dtype) - return self._with_callable(func, "cast") + return self._with_callable(func) def is_finite(self) -> Self: import dask.array as da - return self._with_callable(da.isfinite, "is_finite") + return self._with_callable(da.isfinite) def log(self, base: float) -> Self: import dask.array as da @@ -654,17 +618,17 @@ def log(self, base: float) -> Self: def _log(expr: dx.Series) -> dx.Series: return da.log(expr) / da.log(base) - return self._with_callable(_log, "log") + return self._with_callable(_log) def exp(self) -> Self: import dask.array as da - return self._with_callable(da.exp, "exp") + return self._with_callable(da.exp) def sqrt(self) -> Self: import dask.array as da - return self._with_callable(da.sqrt, "sqrt") + return self._with_callable(da.sqrt) def mode(self, *, keep: ModeKeepStrategy) -> Self: def func(expr: dx.Series) -> dx.Series: @@ -672,7 +636,7 @@ def func(expr: dx.Series) -> dx.Series: result = expr.to_frame().mode()[_name] return result.head(1) if keep == "any" else result - return self._with_callable(func, "mode", scalar_kwargs={"keep": keep}) + return self._with_callable(func) @property def str(self) -> DaskExprStringNamespace: diff --git a/narwhals/_dask/expr_dt.py b/narwhals/_dask/expr_dt.py index a3e3f8eab6..bc75b99747 100644 --- a/narwhals/_dask/expr_dt.py +++ b/narwhals/_dask/expr_dt.py @@ -25,70 +25,59 @@ class DaskExprDateTimeNamespace( LazyExprNamespace["DaskExpr"], DateTimeNamespace["DaskExpr"] ): def date(self) -> DaskExpr: - return self.compliant._with_callable(lambda expr: expr.dt.date, "date") + return self.compliant._with_callable(lambda expr: expr.dt.date) def year(self) -> DaskExpr: - return self.compliant._with_callable(lambda expr: expr.dt.year, "year") + return self.compliant._with_callable(lambda expr: expr.dt.year) def month(self) -> DaskExpr: - return self.compliant._with_callable(lambda expr: expr.dt.month, "month") + return self.compliant._with_callable(lambda expr: expr.dt.month) def day(self) -> DaskExpr: - return self.compliant._with_callable(lambda expr: expr.dt.day, "day") + return self.compliant._with_callable(lambda expr: expr.dt.day) def hour(self) -> DaskExpr: - return self.compliant._with_callable(lambda expr: expr.dt.hour, "hour") + return self.compliant._with_callable(lambda expr: expr.dt.hour) def minute(self) -> DaskExpr: - return self.compliant._with_callable(lambda expr: expr.dt.minute, "minute") + return self.compliant._with_callable(lambda expr: expr.dt.minute) def second(self) -> DaskExpr: - return self.compliant._with_callable(lambda expr: expr.dt.second, "second") + return self.compliant._with_callable(lambda expr: expr.dt.second) def millisecond(self) -> DaskExpr: - return self.compliant._with_callable( - lambda expr: expr.dt.microsecond // 1000, "millisecond" - ) + return self.compliant._with_callable(lambda expr: expr.dt.microsecond // 1000) def microsecond(self) -> DaskExpr: - return self.compliant._with_callable( - lambda expr: expr.dt.microsecond, "microsecond" - ) + return self.compliant._with_callable(lambda expr: expr.dt.microsecond) def nanosecond(self) -> DaskExpr: return self.compliant._with_callable( - lambda expr: expr.dt.microsecond * 1000 + expr.dt.nanosecond, "nanosecond" + lambda expr: expr.dt.microsecond * 1000 + expr.dt.nanosecond ) def ordinal_day(self) -> DaskExpr: - return self.compliant._with_callable( - lambda expr: expr.dt.dayofyear, "ordinal_day" - ) + return self.compliant._with_callable(lambda expr: expr.dt.dayofyear) def weekday(self) -> DaskExpr: return self.compliant._with_callable( - lambda expr: expr.dt.weekday + 1, # Dask is 0-6 - "weekday", + lambda expr: expr.dt.weekday + 1 # Dask is 0-6 ) def to_string(self, format: str) -> DaskExpr: return self.compliant._with_callable( - lambda expr, format: expr.dt.strftime(format.replace("%.f", ".%f")), - "strftime", - format=format, + lambda expr: expr.dt.strftime(format.replace("%.f", ".%f")) ) def replace_time_zone(self, time_zone: str | None) -> DaskExpr: return self.compliant._with_callable( - lambda expr, time_zone: expr.dt.tz_localize(None).dt.tz_localize(time_zone) + lambda expr: expr.dt.tz_localize(None).dt.tz_localize(time_zone) if time_zone is not None - else expr.dt.tz_localize(None), - "tz_localize", - time_zone=time_zone, + else expr.dt.tz_localize(None) ) def convert_time_zone(self, time_zone: str) -> DaskExpr: - def func(s: dx.Series, time_zone: str) -> dx.Series: + def func(s: dx.Series) -> dx.Series: dtype = native_to_narwhals_dtype( s.dtype, self.compliant._version, Implementation.DASK ) @@ -96,11 +85,11 @@ def func(s: dx.Series, time_zone: str) -> dx.Series: return s.dt.tz_localize("UTC").dt.tz_convert(time_zone) # pyright: ignore[reportAttributeAccessIssue] return s.dt.tz_convert(time_zone) # pyright: ignore[reportAttributeAccessIssue] - return self.compliant._with_callable(func, "tz_convert", time_zone=time_zone) + return self.compliant._with_callable(func) # ignoring coverage due to https://github.com/narwhals-dev/narwhals/issues/2808. def timestamp(self, time_unit: TimeUnit) -> DaskExpr: # pragma: no cover - def func(s: dx.Series, time_unit: TimeUnit) -> dx.Series: + def func(s: dx.Series) -> dx.Series: dtype = native_to_narwhals_dtype( s.dtype, self.compliant._version, Implementation.DASK ) @@ -124,33 +113,27 @@ def func(s: dx.Series, time_unit: TimeUnit) -> dx.Series: raise TypeError(msg) return result.where(~mask_na) # pyright: ignore[reportReturnType] - return self.compliant._with_callable(func, "datetime", time_unit=time_unit) + return self.compliant._with_callable(func) def total_minutes(self) -> DaskExpr: - return self.compliant._with_callable( - lambda expr: expr.dt.total_seconds() // 60, "total_minutes" - ) + return self.compliant._with_callable(lambda expr: expr.dt.total_seconds() // 60) def total_seconds(self) -> DaskExpr: - return self.compliant._with_callable( - lambda expr: expr.dt.total_seconds() // 1, "total_seconds" - ) + return self.compliant._with_callable(lambda expr: expr.dt.total_seconds() // 1) def total_milliseconds(self) -> DaskExpr: return self.compliant._with_callable( - lambda expr: expr.dt.total_seconds() * MS_PER_SECOND // 1, - "total_milliseconds", + lambda expr: expr.dt.total_seconds() * MS_PER_SECOND // 1 ) def total_microseconds(self) -> DaskExpr: return self.compliant._with_callable( - lambda expr: expr.dt.total_seconds() * US_PER_SECOND // 1, - "total_microseconds", + lambda expr: expr.dt.total_seconds() * US_PER_SECOND // 1 ) def total_nanoseconds(self) -> DaskExpr: return self.compliant._with_callable( - lambda expr: expr.dt.total_seconds() * NS_PER_SECOND // 1, "total_nanoseconds" + lambda expr: expr.dt.total_seconds() * NS_PER_SECOND // 1 ) def truncate(self, every: str) -> DaskExpr: @@ -160,10 +143,10 @@ def truncate(self, every: str) -> DaskExpr: msg = f"Truncating to {unit} is not yet supported for dask." raise NotImplementedError(msg) freq = f"{interval.multiple}{ALIAS_DICT.get(unit, unit)}" - return self.compliant._with_callable(lambda expr: expr.dt.floor(freq), "truncate") + return self.compliant._with_callable(lambda expr: expr.dt.floor(freq)) def offset_by(self, by: str) -> DaskExpr: - def func(s: dx.Series, by: str) -> dx.Series: + def func(s: dx.Series) -> dx.Series: interval = Interval.parse_no_constraints(by) unit = interval.unit if unit in {"y", "q", "mo", "d", "ns"}: @@ -172,4 +155,4 @@ def func(s: dx.Series, by: str) -> dx.Series: offset = interval.to_timedelta() return s.add(offset) - return self.compliant._with_callable(func, "offset_by", by=by) + return self.compliant._with_callable(func) diff --git a/narwhals/_dask/expr_str.py b/narwhals/_dask/expr_str.py index 953f271cee..846ff9c808 100644 --- a/narwhals/_dask/expr_str.py +++ b/narwhals/_dask/expr_str.py @@ -16,12 +16,10 @@ class DaskExprStringNamespace(LazyExprNamespace["DaskExpr"], StringNamespace["DaskExpr"]): def len_chars(self) -> DaskExpr: - return self.compliant._with_callable(lambda expr: expr.str.len(), "len") + return self.compliant._with_callable(lambda expr: expr.str.len()) - def replace(self, pattern: str, value: str, *, literal: bool, n: int) -> DaskExpr: - def _replace( - expr: dx.Series, pattern: str, value: str, *, literal: bool, n: int - ) -> dx.Series: + def replace(self, value: str, pattern: str, *, literal: bool, n: int) -> DaskExpr: + def _replace(expr: dx.Series, value: str) -> dx.Series: try: return expr.str.replace( # pyright: ignore[reportAttributeAccessIssue] pattern, value, regex=not literal, n=n @@ -32,14 +30,10 @@ def _replace( raise TypeError(msg) from e raise - return self.compliant._with_callable( - _replace, "replace", pattern=pattern, value=value, literal=literal, n=n - ) + return self.compliant._with_callable(_replace, value=value) - def replace_all(self, pattern: str, value: str, *, literal: bool) -> DaskExpr: - def _replace_all( - expr: dx.Series, pattern: str, value: str, *, literal: bool - ) -> dx.Series: + def replace_all(self, value: str, pattern: str, *, literal: bool) -> DaskExpr: + def _replace_all(expr: dx.Series, value: str) -> dx.Series: try: return expr.str.replace( # pyright: ignore[reportAttributeAccessIssue] pattern, value, regex=not literal, n=-1 @@ -50,72 +44,44 @@ def _replace_all( raise TypeError(msg) from e raise - return self.compliant._with_callable( - _replace_all, "replace", pattern=pattern, value=value, literal=literal - ) + return self.compliant._with_callable(_replace_all, value=value) def strip_chars(self, characters: str | None) -> DaskExpr: - return self.compliant._with_callable( - lambda expr, characters: expr.str.strip(characters), - "strip", - characters=characters, - ) + return self.compliant._with_callable(lambda expr: expr.str.strip(characters)) def starts_with(self, prefix: str) -> DaskExpr: - return self.compliant._with_callable( - lambda expr, prefix: expr.str.startswith(prefix), "starts_with", prefix=prefix - ) + return self.compliant._with_callable(lambda expr: expr.str.startswith(prefix)) def ends_with(self, suffix: str) -> DaskExpr: - return self.compliant._with_callable( - lambda expr, suffix: expr.str.endswith(suffix), "ends_with", suffix=suffix - ) + return self.compliant._with_callable(lambda expr: expr.str.endswith(suffix)) def contains(self, pattern: str, *, literal: bool) -> DaskExpr: return self.compliant._with_callable( - lambda expr, pattern, literal: expr.str.contains( - pat=pattern, regex=not literal - ), - "contains", - pattern=pattern, - literal=literal, + lambda expr: expr.str.contains(pat=pattern, regex=not literal) ) def slice(self, offset: int, length: int | None) -> DaskExpr: return self.compliant._with_callable( - lambda expr, offset, length: expr.str.slice( + lambda expr: expr.str.slice( start=offset, stop=offset + length if length else None - ), - "slice", - offset=offset, - length=length, + ) ) def split(self, by: str) -> DaskExpr: - return self.compliant._with_callable( - lambda expr, by: expr.str.split(pat=by), "split", by=by - ) + return self.compliant._with_callable(lambda expr: expr.str.split(pat=by)) def to_datetime(self, format: str | None) -> DaskExpr: return self.compliant._with_callable( - lambda expr, format: dd.to_datetime(expr, format=format), - "to_datetime", - format=format, + lambda expr: dd.to_datetime(expr, format=format) ) def to_uppercase(self) -> DaskExpr: - return self.compliant._with_callable( - lambda expr: expr.str.upper(), "to_uppercase" - ) + return self.compliant._with_callable(lambda expr: expr.str.upper()) def to_lowercase(self) -> DaskExpr: - return self.compliant._with_callable( - lambda expr: expr.str.lower(), "to_lowercase" - ) + return self.compliant._with_callable(lambda expr: expr.str.lower()) def zfill(self, width: int) -> DaskExpr: - return self.compliant._with_callable( - lambda expr, width: expr.str.zfill(width), "zfill", width=width - ) + return self.compliant._with_callable(lambda expr: expr.str.zfill(width)) to_date = not_implemented() diff --git a/narwhals/_dask/group_by.py b/narwhals/_dask/group_by.py index 4427f2d324..65309a2eaa 100644 --- a/narwhals/_dask/group_by.py +++ b/narwhals/_dask/group_by.py @@ -126,17 +126,18 @@ def agg(self, *exprs: DaskExpr) -> DaskLazyFrame: output_names, aliases = evaluate_output_names_and_aliases( expr, self.compliant, exclude ) - if expr._depth == 0: + last_node = next(expr._metadata.op_nodes_reversed()) + if len(list(expr._metadata.op_nodes_reversed())) == 1: # e.g. `agg(nw.len())` column = self._keys[0] - agg_fn = self._remap_expr_name(expr._function_name) + agg_fn = self._remap_expr_name(last_node.name) simple_aggregations.update(dict.fromkeys(aliases, (column, agg_fn))) continue # e.g. `agg(nw.mean('a'))` agg_fn = self._remap_expr_name(self._leaf_name(expr)) # deal with n_unique case in a "lazy" mode to not depend on dask globally - agg_fn = agg_fn(**expr._scalar_kwargs) if callable(agg_fn) else agg_fn + agg_fn = agg_fn(**last_node.kwargs) if callable(agg_fn) else agg_fn simple_aggregations.update( (alias, (output_name, agg_fn)) for alias, output_name in zip_strict(aliases, output_names) diff --git a/narwhals/_dask/namespace.py b/narwhals/_dask/namespace.py index df66a583c4..2133dc235c 100644 --- a/narwhals/_dask/namespace.py +++ b/narwhals/_dask/namespace.py @@ -8,12 +8,7 @@ import dask.dataframe as dd import pandas as pd -from narwhals._compliant import ( - CompliantThen, - CompliantWhen, - DepthTrackingNamespace, - LazyNamespace, -) +from narwhals._compliant import DepthTrackingNamespace, LazyNamespace from narwhals._dask.dataframe import DaskLazyFrame from narwhals._dask.expr import DaskExpr from narwhals._dask.selectors import DaskSelectorNamespace @@ -30,11 +25,10 @@ from narwhals._utils import Implementation, zip_strict if TYPE_CHECKING: - from collections.abc import Iterable, Iterator, Sequence + from collections.abc import Iterable, Iterator import dask.dataframe.dask_expr as dx - from narwhals._compliant.typing import ScalarKwargs from narwhals._utils import Version from narwhals.typing import ConcatMethod, IntoDType, NonNestedLiteral @@ -73,8 +67,6 @@ def func(df: DaskLazyFrame) -> list[dx.Series]: return self._expr( func, - depth=0, - function_name="lit", evaluate_output_names=lambda _df: ["literal"], alias_output_names=None, version=self._version, @@ -87,8 +79,6 @@ def func(df: DaskLazyFrame) -> list[dx.Series]: return self._expr( func, - depth=0, - function_name="len", evaluate_output_names=lambda _df: ["len"], alias_output_names=None, version=self._version, @@ -106,8 +96,6 @@ def func(df: DaskLazyFrame) -> list[dx.Series]: return self._expr( call=func, - depth=max(x._depth for x in exprs) + 1, - function_name="all_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), alias_output_names=combine_alias_output_names(*exprs), version=self._version, @@ -122,8 +110,6 @@ def func(df: DaskLazyFrame) -> list[dx.Series]: return self._expr( call=func, - depth=max(x._depth for x in exprs) + 1, - function_name="any_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), alias_output_names=combine_alias_output_names(*exprs), version=self._version, @@ -138,8 +124,6 @@ def func(df: DaskLazyFrame) -> list[dx.Series]: return self._expr( call=func, - depth=max(x._depth for x in exprs) + 1, - function_name="sum_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), alias_output_names=combine_alias_output_names(*exprs), version=self._version, @@ -188,8 +172,6 @@ def func(df: DaskLazyFrame) -> list[dx.Series]: return self._expr( call=func, - depth=max(x._depth for x in exprs) + 1, - function_name="mean_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), alias_output_names=combine_alias_output_names(*exprs), version=self._version, @@ -205,8 +187,6 @@ def func(df: DaskLazyFrame) -> list[dx.Series]: return self._expr( call=func, - depth=max(x._depth for x in exprs) + 1, - function_name="min_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), alias_output_names=combine_alias_output_names(*exprs), version=self._version, @@ -222,16 +202,11 @@ def func(df: DaskLazyFrame) -> list[dx.Series]: return self._expr( call=func, - depth=max(x._depth for x in exprs) + 1, - function_name="max_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), alias_output_names=combine_alias_output_names(*exprs), version=self._version, ) - def when(self, predicate: DaskExpr) -> DaskWhen: - return DaskWhen.from_expr(predicate, context=self) - def concat_str( self, *exprs: DaskExpr, separator: str, ignore_nulls: bool ) -> DaskExpr: @@ -266,8 +241,6 @@ def func(df: DaskLazyFrame) -> list[dx.Series]: return self._expr( call=func, - depth=max(x._depth for x in exprs) + 1, - function_name="concat_str", evaluate_output_names=getattr( exprs[0], "_evaluate_output_names", lambda _df: ["literal"] ), @@ -284,55 +257,58 @@ def func(df: DaskLazyFrame) -> list[dx.Series]: return self._expr( call=func, - depth=max(x._depth for x in exprs) + 1, - function_name="coalesce", evaluate_output_names=combine_evaluate_output_names(*exprs), alias_output_names=combine_alias_output_names(*exprs), version=self._version, ) + def when_then( + self, predicate: DaskExpr, then: DaskExpr, otherwise: DaskExpr | None = None + ) -> DaskExpr: + def func(df: DaskLazyFrame) -> list[dx.Series]: + then_value = then(df)[0] if isinstance(then, DaskExpr) else then + otherwise_value = ( + otherwise(df)[0] if isinstance(otherwise, DaskExpr) else otherwise + ) -class DaskWhen(CompliantWhen[DaskLazyFrame, "dx.Series", DaskExpr]): # pyright: ignore[reportInvalidTypeArguments] - @property - def _then(self) -> type[DaskThen]: - return DaskThen - - def __call__(self, df: DaskLazyFrame) -> Sequence[dx.Series]: - then_value = ( - self._then_value(df)[0] - if isinstance(self._then_value, DaskExpr) - else self._then_value - ) - otherwise_value = ( - self._otherwise_value(df)[0] - if isinstance(self._otherwise_value, DaskExpr) - else self._otherwise_value - ) - - condition = self._condition(df)[0] - # re-evaluate DataFrame if the condition aggregates to force - # then/otherwise to be evaluated against the aggregated frame - assert self._condition._metadata is not None # noqa: S101 - if self._condition._metadata.is_scalar_like: - new_df = df._with_native(condition.to_frame()) - condition = self._condition.broadcast(ExprKind.AGGREGATION)(df)[0] - df = new_df - - if self._otherwise_value is None: - (condition, then_series) = align_series_full_broadcast( - df, condition, then_value + condition = predicate(df)[0] + # re-evaluate DataFrame if the condition aggregates to force + # then/otherwise to be evaluated against the aggregated frame + assert predicate._metadata is not None # noqa: S101 + if all( + x._metadata.is_scalar_like + for x in ( + (predicate, then) + if (isinstance(then, DaskExpr) and otherwise is None) + else (predicate, then, otherwise) + if isinstance(then, DaskExpr) and isinstance(otherwise, DaskExpr) + else (predicate, otherwise) + if isinstance(otherwise, DaskExpr) + else (predicate,) + ) + ): + new_df = df._with_native(condition.to_frame()) + condition = predicate.broadcast(ExprKind.AGGREGATION)(df)[0] + df = new_df + + if otherwise is None: + (condition, then_series) = align_series_full_broadcast( + df, condition, then_value + ) + validate_comparand(condition, then_series) + return [then_series.where(condition)] # pyright: ignore[reportArgumentType] + (condition, then_series, otherwise_series) = align_series_full_broadcast( + df, condition, then_value, otherwise_value ) validate_comparand(condition, then_series) - return [then_series.where(condition)] # pyright: ignore[reportArgumentType] - (condition, then_series, otherwise_series) = align_series_full_broadcast( - df, condition, then_value, otherwise_value - ) - validate_comparand(condition, then_series) - validate_comparand(condition, otherwise_series) - return [then_series.where(condition, otherwise_series)] # pyright: ignore[reportArgumentType] - + validate_comparand(condition, otherwise_series) + return [then_series.where(condition, otherwise_series)] # pyright: ignore[reportArgumentType] -class DaskThen(CompliantThen[DaskLazyFrame, "dx.Series", DaskExpr, DaskWhen], DaskExpr): # pyright: ignore[reportInvalidTypeArguments] - _depth: int = 0 - _scalar_kwargs: ScalarKwargs = {} # noqa: RUF012 - _function_name: str = "whenthen" + return self._expr( + call=func, + evaluate_output_names=getattr( + then, "_evaluate_output_names", lambda _df: ["literal"] + ), + alias_output_names=getattr(then, "_alias_output_names", None), + version=self._version, + ) diff --git a/narwhals/_dask/selectors.py b/narwhals/_dask/selectors.py index 9fb6eeecb8..501662422d 100644 --- a/narwhals/_dask/selectors.py +++ b/narwhals/_dask/selectors.py @@ -8,7 +8,6 @@ if TYPE_CHECKING: import dask.dataframe.dask_expr as dx # noqa: F401 - from narwhals._compliant.typing import ScalarKwargs from narwhals._dask.dataframe import DaskLazyFrame # noqa: F401 @@ -19,15 +18,9 @@ def _selector(self) -> type[DaskSelector]: class DaskSelector(CompliantSelector["DaskLazyFrame", "dx.Series"], DaskExpr): # type: ignore[misc] - _depth: int = 0 - _scalar_kwargs: ScalarKwargs = {} # noqa: RUF012 - _function_name: str = "selector" - def _to_expr(self) -> DaskExpr: return DaskExpr( self._call, - depth=self._depth, - function_name=self._function_name, evaluate_output_names=self._evaluate_output_names, alias_output_names=self._alias_output_names, version=self._version, diff --git a/narwhals/_duckdb/dataframe.py b/narwhals/_duckdb/dataframe.py index 618bee7e0b..088e8783cf 100644 --- a/narwhals/_duckdb/dataframe.py +++ b/narwhals/_duckdb/dataframe.py @@ -12,7 +12,7 @@ F, catch_duckdb_exception, col, - evaluate_exprs, + evaluate_exprs_and_aliases, join_column_names, lit, native_to_narwhals_dtype, @@ -24,7 +24,6 @@ ValidateBackendVersion, Version, generate_temporary_column_name, - not_implemented, parse_columns_to_drop, requires, to_pyarrow_table, @@ -173,14 +172,18 @@ def simple_select(self, *column_names: str) -> Self: return self._with_native(self.native.select(*column_names)) def aggregate(self, *exprs: DuckDBExpr) -> Self: - selection = [val.alias(name) for name, val in evaluate_exprs(self, *exprs)] + selection = [ + val.alias(name) for name, val in evaluate_exprs_and_aliases(self, *exprs) + ] try: return self._with_native(self.native.aggregate(selection)) # type: ignore[arg-type] except Exception as e: # noqa: BLE001 raise catch_duckdb_exception(e, self) from None def select(self, *exprs: DuckDBExpr) -> Self: - selection = (val.alias(name) for name, val in evaluate_exprs(self, *exprs)) + selection = ( + val.alias(name) for name, val in evaluate_exprs_and_aliases(self, *exprs) + ) try: return self._with_native(self.native.select(*selection)) except Exception as e: # noqa: BLE001 @@ -202,7 +205,7 @@ def lazy(self, backend: None = None, **_: None) -> Self: return self def with_columns(self, *exprs: DuckDBExpr) -> Self: - new_columns_map = dict(evaluate_exprs(self, *exprs)) + new_columns_map = dict(evaluate_exprs_and_aliases(self, *exprs)) result = [ new_columns_map.pop(name).alias(name) if name in new_columns_map @@ -220,8 +223,8 @@ def filter(self, predicate: DuckDBExpr) -> Self: mask = predicate(self)[0] try: return self._with_native(self.native.filter(mask)) - except Exception as e: # noqa: BLE001 - raise catch_duckdb_exception(e, self) from None + except Exception as e: + raise catch_duckdb_exception(e, self) from e @property def schema(self) -> dict[str, DType]: @@ -552,10 +555,3 @@ def sink_parquet(self, file: str | Path | BytesIO) -> None: (FORMAT parquet) """ # noqa: S608 duckdb.sql(query) - - gather_every = not_implemented.deprecated( - "`LazyFrame.gather_every` is deprecated and will be removed in a future version." - ) - tail = not_implemented.deprecated( - "`LazyFrame.tail` is deprecated and will be removed in a future version." - ) diff --git a/narwhals/_duckdb/expr.py b/narwhals/_duckdb/expr.py index cab9d51f22..f86b5b21a3 100644 --- a/narwhals/_duckdb/expr.py +++ b/narwhals/_duckdb/expr.py @@ -66,7 +66,7 @@ def __init__( self._evaluate_output_names = evaluate_output_names self._alias_output_names = alias_output_names self._version = version - self._metadata: ExprMetadata | None = None + self._opt_metadata: ExprMetadata | None = None self._window_function: DuckDBWindowFunction | None = window_function def _count_star(self) -> Expression: @@ -117,8 +117,14 @@ def from_column_names( def func(df: DuckDBLazyFrame) -> list[Expression]: return [col(name) for name in evaluate_column_names(df)] + def window_func( + df: DuckDBLazyFrame, _window_inputs: WindowInputs[Expression] + ) -> list[Expression]: + return [col(name) for name in evaluate_column_names(df)] + return cls( func, + window_func, evaluate_output_names=evaluate_column_names, alias_output_names=None, version=context._version, diff --git a/narwhals/_duckdb/namespace.py b/narwhals/_duckdb/namespace.py index 009404c428..3d5ee2a44f 100644 --- a/narwhals/_duckdb/namespace.py +++ b/narwhals/_duckdb/namespace.py @@ -26,7 +26,6 @@ combine_evaluate_output_names, ) from narwhals._sql.namespace import SQLNamespace -from narwhals._sql.when_then import SQLThen, SQLWhen from narwhals._utils import Implementation if TYPE_CHECKING: @@ -34,6 +33,7 @@ from duckdb import DuckDBPyRelation # noqa: F401 + from narwhals._compliant.window import WindowInputs from narwhals._utils import Version from narwhals.typing import ConcatMethod, IntoDType, NonNestedLiteral @@ -134,9 +134,6 @@ def func(cols: Iterable[Expression]) -> Expression: return self._expr._from_elementwise_horizontal_op(func, *exprs) - def when(self, predicate: DuckDBExpr) -> DuckDBWhen: - return DuckDBWhen.from_expr(predicate, context=self) - def lit(self, value: NonNestedLiteral, dtype: IntoDType | None) -> DuckDBExpr: def func(df: DuckDBLazyFrame) -> list[Expression]: tz = DeferredTimeZone(df.native) @@ -145,8 +142,14 @@ def func(df: DuckDBLazyFrame) -> list[Expression]: return [lit(value).cast(target)] return [lit(value)] + def window_func( + df: DuckDBLazyFrame, _window_inputs: WindowInputs[Expression] + ) -> list[Expression]: + return func(df) + return self._expr( func, + window_func, evaluate_output_names=lambda _df: ["literal"], alias_output_names=None, version=self._version, @@ -162,12 +165,3 @@ def func(_df: DuckDBLazyFrame) -> list[Expression]: alias_output_names=None, version=self._version, ) - - -class DuckDBWhen(SQLWhen["DuckDBLazyFrame", Expression, DuckDBExpr]): - @property - def _then(self) -> type[DuckDBThen]: - return DuckDBThen - - -class DuckDBThen(SQLThen["DuckDBLazyFrame", Expression, DuckDBExpr], DuckDBExpr): ... diff --git a/narwhals/_duckdb/utils.py b/narwhals/_duckdb/utils.py index 413b0cf7fe..026a55fdb5 100644 --- a/narwhals/_duckdb/utils.py +++ b/narwhals/_duckdb/utils.py @@ -68,12 +68,12 @@ def concat_str(*exprs: Expression, separator: str = "") -> Expression: return F("concat_ws", lit(separator), *exprs) if separator else F("concat", *exprs) -def evaluate_exprs( +def evaluate_exprs_and_aliases( df: DuckDBLazyFrame, /, *exprs: DuckDBExpr ) -> list[tuple[str, Expression]]: native_results: list[tuple[str, Expression]] = [] for expr in exprs: - native_series_list = expr._call(df) + native_series_list = expr(df) output_names = expr._evaluate_output_names(df) if expr._alias_output_names is not None: output_names = expr._alias_output_names(output_names) @@ -84,6 +84,10 @@ def evaluate_exprs( return native_results +def evaluate_exprs(df: DuckDBLazyFrame, /, *exprs: DuckDBExpr) -> list[Expression]: + return [item for expr in exprs for item in expr(df)] + + class DeferredTimeZone: """Object which gets passed between `native_to_narwhals_dtype` calls. diff --git a/narwhals/_expression_parsing.py b/narwhals/_expression_parsing.py index b5cbd85f6b..c6c890acae 100644 --- a/narwhals/_expression_parsing.py +++ b/narwhals/_expression_parsing.py @@ -5,15 +5,18 @@ from __future__ import annotations from enum import Enum, auto -from itertools import chain -from typing import TYPE_CHECKING, Any, Callable, Literal, TypeVar +from typing import TYPE_CHECKING, Any, Literal, ParamSpec, TypeVar, cast, overload from narwhals._utils import is_compliant_expr, zip_strict -from narwhals.dependencies import is_narwhals_series, is_numpy_array, is_numpy_array_1d -from narwhals.exceptions import InvalidOperationError, MultiOutputExpressionError +from narwhals.dependencies import is_numpy_array_1d +from narwhals.exceptions import ( + InvalidIntoExprError, + InvalidOperationError, + MultiOutputExpressionError, +) if TYPE_CHECKING: - from collections.abc import Sequence + from collections.abc import Iterator, Sequence from typing_extensions import Never, TypeIs @@ -30,6 +33,8 @@ from narwhals.typing import IntoExpr, NonNestedLiteral, _1DArray T = TypeVar("T") + PS = ParamSpec("PS") + R = TypeVar("R") def is_expr(obj: Any) -> TypeIs[Expr]: @@ -46,13 +51,6 @@ def is_series(obj: Any) -> TypeIs[Series[Any]]: return isinstance(obj, Series) -def is_into_expr_eager(obj: Any) -> TypeIs[Expr | Series[Any] | str | _1DArray]: - from narwhals.expr import Expr - from narwhals.series import Series - - return isinstance(obj, (Series, Expr, str)) or is_numpy_array_1d(obj) - - def combine_evaluate_output_names( *exprs: CompliantExpr[CompliantFrameT, Any], ) -> EvalNames[CompliantFrameT]: @@ -89,16 +87,14 @@ def evaluate_output_names_and_aliases( if expr._alias_output_names is None else expr._alias_output_names(output_names) ) - if exclude: - assert expr._metadata is not None # noqa: S101 - if expr._metadata.expansion_kind.is_multi_unnamed(): - output_names, aliases = zip_strict( - *[ - (x, alias) - for x, alias in zip_strict(output_names, aliases) - if x not in exclude - ] - ) + if exclude and expr._metadata.expansion_kind.is_multi_unnamed(): + output_names, aliases = zip_strict( + *[ + (x, alias) + for x, alias in zip_strict(output_names, aliases) + if x not in exclude + ] + ) return output_names, aliases @@ -132,6 +128,27 @@ class ExprKind(Enum): OVER = auto() """Results from calling `.over` on expression.""" + COL = auto() + """Results from calling `nw.col`.""" + + NTH = auto() + """Results from calling `nw.nth`.""" + + EXCLUDE = auto() + """Results from calling `nw.exclude`.""" + + ALL = auto() + """Results from calling `nw.all`.""" + + SELECTOR = auto() + """Results from creating an expression with a selector.""" + + WHEN_THEN = auto() + """Results from `when/then expression`, possibly followed by `otherwise`.""" + + SERIES = auto() + """Results from converting a Series to Expr.""" + UNKNOWN = auto() """Based on the information we have, we can't determine the ExprKind.""" @@ -140,12 +157,20 @@ def is_scalar_like(self) -> bool: return self in {ExprKind.LITERAL, ExprKind.AGGREGATION} @property - def is_orderable_window(self) -> bool: - return self in {ExprKind.ORDERABLE_WINDOW, ExprKind.ORDERABLE_AGGREGATION} + def is_orderable(self) -> bool: + # Any operation which may be affected by `order_by`, such as `cum_sum`, + # `diff`, `rank`, `arg_max`, ... + return self in { + ExprKind.ORDERABLE_WINDOW, + ExprKind.ORDERABLE_AGGREGATION, + ExprKind.FILTRATION, + ExprKind.WINDOW, + } @classmethod - def from_expr(cls, obj: Expr) -> ExprKind: + def from_expr(cls, obj: CompliantExprAny) -> ExprKind: meta = obj._metadata + assert meta is not None # noqa: S101 if meta.is_literal: return ExprKind.LITERAL if meta.is_scalar_like: @@ -155,17 +180,9 @@ def from_expr(cls, obj: Expr) -> ExprKind: return ExprKind.UNKNOWN @classmethod - def from_into_expr( - cls, obj: IntoExpr | NonNestedLiteral | _1DArray, *, str_as_lit: bool - ) -> ExprKind: - if is_expr(obj): + def from_into_expr(cls, obj: CompliantExprAny | NonNestedLiteral) -> ExprKind: + if is_compliant_expr(obj): return cls.from_expr(obj) - if ( - is_narwhals_series(obj) - or is_numpy_array(obj) - or (isinstance(obj, str) and not str_as_lit) - ): - return ExprKind.ELEMENTWISE return ExprKind.LITERAL @@ -202,6 +219,82 @@ def __and__(self, other: ExpansionKind) -> Literal[ExpansionKind.MULTI_UNNAMED]: raise AssertionError(msg) # pragma: no cover +class ExprNode: + def __init__( + self, + kind: ExprKind, + name: str, + /, + *exprs: IntoExpr | NonNestedLiteral, + str_as_lit: bool = False, + allow_multi_output: bool = False, + **kwargs: Any, + ) -> None: + self.kind: ExprKind = kind + self.name: str = name + self.exprs: Sequence[IntoExpr | NonNestedLiteral] = exprs + self.kwargs: dict[str, Any] = kwargs + self.str_as_lit: bool = str_as_lit + self.allow_multi_output: bool = allow_multi_output + + # Cached methods. + self._is_orderable_cached: bool | None = None + + def __repr__(self) -> str: + if self.name == "col": + names = ", ".join(str(x) for x in self.kwargs["names"]) + return f"col({names})" + arg_str = [] + expr_repr = ", ".join(str(x) for x in self.exprs) + kwargs_repr = ", ".join(f"{key}={value}" for key, value in self.kwargs.items()) + if self.exprs: + arg_str.append(expr_repr) + if self.kwargs: + arg_str.append(kwargs_repr) + return f"{self.name}({', '.join(arg_str)})" + + def _with_kwargs(self, **kwargs: Any) -> ExprNode: + return self.__class__( + self.kind, self.name, *self.exprs, str_as_lit=self.str_as_lit, **kwargs + ) + + def _push_down_over_node_in_place( + self, over_node: ExprNode, over_node_without_order_by: ExprNode + ) -> None: + exprs = [] + # Note: please keep this as a for-loop (rather than a list-comprehension) + # so that pytest-cov highlights any uncovered branches. + for expr in self.exprs: + if not is_expr(expr): + exprs.append(expr) + elif over_node.kwargs["order_by"] and any( + expr_node.is_orderable() for expr_node in expr._nodes + ): + exprs.append(expr._with_node(over_node)) + elif over_node_without_order_by.kwargs["partition_by"]: + exprs.append(expr._with_node(over_node_without_order_by)) + else: + # If there's no `partition_by`, then `over_node_without_order_by` is a no-op. + exprs.append(expr) + self.exprs = exprs + + def is_orderable(self) -> bool: + if self._is_orderable_cached is None: + # Note: don't combine these if/then statements so that pytest-cov shows if + # anything is uncovered. + if self.kind.is_orderable: # noqa: SIM114 + self._is_orderable_cached = True + elif any( + any(node.is_orderable() for node in expr._nodes) + for expr in self.exprs + if is_expr(expr) + ): + self._is_orderable_cached = True + else: + self._is_orderable_cached = False + return self._is_orderable_cached + + class ExprMetadata: """Expression metadata. @@ -212,7 +305,6 @@ class ExprMetadata: of the other rows around it. is_literal: Whether it is just a literal wrapped in an expression. is_scalar_like: Whether it is a literal or an aggregation. - last_node: The ExprKind of the last node. n_orderable_ops: The number of order-dependent operations. In the lazy case, this number must be `0` by the time the expression is evaluated. @@ -225,15 +317,14 @@ class ExprMetadata: "is_elementwise", "is_literal", "is_scalar_like", - "last_node", "n_orderable_ops", + "nodes", "preserves_length", ) def __init__( self, expansion_kind: ExpansionKind, - last_node: ExprKind, *, has_windows: bool = False, n_orderable_ops: int = 0, @@ -241,19 +332,20 @@ def __init__( is_elementwise: bool = True, is_scalar_like: bool = False, is_literal: bool = False, + nodes: tuple[ExprNode, ...], ) -> None: if is_literal: assert is_scalar_like # noqa: S101 # debug assertion if is_elementwise: assert preserves_length # noqa: S101 # debug assertion self.expansion_kind: ExpansionKind = expansion_kind - self.last_node: ExprKind = last_node self.has_windows: bool = has_windows self.n_orderable_ops: int = n_orderable_ops self.is_elementwise: bool = is_elementwise self.preserves_length: bool = preserves_length self.is_scalar_like: bool = is_scalar_like self.is_literal: bool = is_literal + self.nodes: tuple[ExprNode, ...] = nodes def __init_subclass__(cls, /, *args: Any, **kwds: Any) -> Never: # pragma: no cover msg = f"Cannot subclass {cls.__name__!r}" @@ -263,71 +355,188 @@ def __repr__(self) -> str: # pragma: no cover return ( f"ExprMetadata(\n" f" expansion_kind: {self.expansion_kind},\n" - f" last_node: {self.last_node},\n" f" has_windows: {self.has_windows},\n" f" n_orderable_ops: {self.n_orderable_ops},\n" f" is_elementwise: {self.is_elementwise},\n" f" preserves_length: {self.preserves_length},\n" f" is_scalar_like: {self.is_scalar_like},\n" f" is_literal: {self.is_literal},\n" + f" nodes: {self.nodes},\n" ")" ) + @classmethod + def from_node( # noqa: PLR0911 + cls, node: ExprNode, *ces: CompliantExprAny | NonNestedLiteral + ) -> ExprMetadata: + if node.kind is ExprKind.SERIES: + return cls.from_selector_single(node) + if node.kind is ExprKind.COL: + return ( + ExprMetadata.from_selector_single(node) + if len(node.kwargs["names"]) == 1 + else ExprMetadata.from_selector_multi_named(node) + ) + if node.kind is ExprKind.NTH: + return ( + ExprMetadata.from_selector_single(node) + if len(node.kwargs["indices"]) == 1 + else ExprMetadata.from_selector_multi_unnamed(node) + ) + if node.kind in {ExprKind.ALL, ExprKind.EXCLUDE}: + return ExprMetadata.from_selector_multi_unnamed(node) + if node.kind is ExprKind.AGGREGATION: + return ExprMetadata.from_aggregation(node) + if node.kind is ExprKind.LITERAL: + return ExprMetadata.from_literal(node) + if node.kind is ExprKind.SELECTOR: + return ExprMetadata.from_selector_multi_unnamed(node) + if node.kind is ExprKind.ELEMENTWISE: + return ExprMetadata.from_elementwise(node, *ces) + msg = f"Unexpected node kind: {node.kind}" + raise AssertionError(msg) + + def with_node( # noqa: PLR0911,C901 + self, + node: ExprNode, + ce: CompliantExprAny, + *ces: CompliantExprAny | NonNestedLiteral, + ) -> ExprMetadata: + if node.kind is ExprKind.AGGREGATION: + return self.with_aggregation(node) + if node.kind is ExprKind.ELEMENTWISE: + return combine_metadata( + ce, + *ces, + str_as_lit=node.str_as_lit, + allow_multi_output=node.allow_multi_output, + to_single_output=False, + nodes=(*ce._metadata.nodes, node), + ) + if node.kind is ExprKind.FILTRATION: + return self.with_filtration(node) + if node.kind is ExprKind.ORDERABLE_WINDOW: + return self.with_orderable_window(node) + if node.kind is ExprKind.ORDERABLE_FILTRATION: + return self.with_orderable_filtration(node) + if node.kind is ExprKind.ORDERABLE_AGGREGATION: + return self.with_orderable_aggregation(node) + if node.kind is ExprKind.WINDOW: + return self.with_window(node) + if node.kind is ExprKind.SELECTOR: + return self + if node.kind is ExprKind.OVER: + if node.kwargs["order_by"]: + return self.with_ordered_over(node) + if not node.kwargs["partition_by"]: # pragma: no cover + msg = "At least one of `partition_by` or `order_by` must be specified." + raise InvalidOperationError(msg) + return self.with_partitioned_over(node) + msg = f"Unexpected node kind: {node.kind}" + raise AssertionError(msg) + + @classmethod + def from_aggregation(cls, node: ExprNode) -> ExprMetadata: + return cls( + ExpansionKind.SINGLE, + is_elementwise=False, + preserves_length=False, + is_scalar_like=True, + nodes=(node,), + ) + + @classmethod + def from_literal(cls, node: ExprNode) -> ExprMetadata: + return cls( + ExpansionKind.SINGLE, + is_elementwise=False, + preserves_length=False, + is_literal=True, + is_scalar_like=True, + nodes=(node,), + ) + + @classmethod + def from_selector_single(cls, node: ExprNode) -> ExprMetadata: + # e.g. `nw.col('a')`, `nw.nth(0)` + return cls(ExpansionKind.SINGLE, nodes=(node,)) + + @classmethod + def from_selector_multi_named(cls, node: ExprNode) -> ExprMetadata: + # e.g. `nw.col('a', 'b')` + return cls(ExpansionKind.MULTI_NAMED, nodes=(node,)) + + @classmethod + def from_selector_multi_unnamed(cls, node: ExprNode) -> ExprMetadata: + # e.g. `nw.all()` + return cls(ExpansionKind.MULTI_UNNAMED, nodes=(node,)) + + @classmethod + def from_elementwise( + cls, node: ExprNode, *ces: CompliantExprAny | NonNestedLiteral + ) -> ExprMetadata: + return combine_metadata( + *ces, + str_as_lit=False, + allow_multi_output=True, + to_single_output=True, + nodes=(node,), + ) + @property def is_filtration(self) -> bool: return not self.preserves_length and not self.is_scalar_like - def with_aggregation(self) -> ExprMetadata: + def with_aggregation(self, node: ExprNode) -> ExprMetadata: if self.is_scalar_like: msg = "Can't apply aggregations to scalar-like expressions." raise InvalidOperationError(msg) return ExprMetadata( self.expansion_kind, - ExprKind.AGGREGATION, has_windows=self.has_windows, n_orderable_ops=self.n_orderable_ops, preserves_length=False, is_elementwise=False, is_scalar_like=True, is_literal=False, + nodes=(*self.nodes, node), ) - def with_orderable_aggregation(self) -> ExprMetadata: + def with_orderable_aggregation(self, node: ExprNode) -> ExprMetadata: # Deprecated, used only in stable.v1. if self.is_scalar_like: # pragma: no cover msg = "Can't apply aggregations to scalar-like expressions." raise InvalidOperationError(msg) return ExprMetadata( self.expansion_kind, - ExprKind.ORDERABLE_AGGREGATION, has_windows=self.has_windows, n_orderable_ops=self.n_orderable_ops + 1, preserves_length=False, is_elementwise=False, is_scalar_like=True, is_literal=False, + nodes=(*self.nodes, node), ) - def with_elementwise_op(self) -> ExprMetadata: + def with_elementwise_op(self, node: ExprNode) -> ExprMetadata: return ExprMetadata( self.expansion_kind, - ExprKind.ELEMENTWISE, has_windows=self.has_windows, n_orderable_ops=self.n_orderable_ops, preserves_length=self.preserves_length, is_elementwise=self.is_elementwise, is_scalar_like=self.is_scalar_like, is_literal=self.is_literal, + nodes=(*self.nodes, node), ) - def with_window(self) -> ExprMetadata: + def with_window(self, node: ExprNode) -> ExprMetadata: # Window function which may (but doesn't have to) be used with `over(order_by=...)`. if self.is_scalar_like: msg = "Can't apply window (e.g. `rank`) to scalar-like expression." raise InvalidOperationError(msg) return ExprMetadata( self.expansion_kind, - ExprKind.WINDOW, has_windows=self.has_windows, # The function isn't order-dependent (but, users can still use `order_by` if they wish!), # so we don't increment `n_orderable_ops`. @@ -336,25 +545,26 @@ def with_window(self) -> ExprMetadata: is_elementwise=False, is_scalar_like=False, is_literal=False, + nodes=(*self.nodes, node), ) - def with_orderable_window(self) -> ExprMetadata: + def with_orderable_window(self, node: ExprNode) -> ExprMetadata: # Window function which must be used with `over(order_by=...)`. if self.is_scalar_like: msg = "Can't apply orderable window (e.g. `diff`, `shift`) to scalar-like expression." raise InvalidOperationError(msg) return ExprMetadata( self.expansion_kind, - ExprKind.ORDERABLE_WINDOW, has_windows=self.has_windows, n_orderable_ops=self.n_orderable_ops + 1, preserves_length=self.preserves_length, is_elementwise=False, is_scalar_like=False, is_literal=False, + nodes=(*self.nodes, node), ) - def with_ordered_over(self) -> ExprMetadata: + def with_ordered_over(self, node: ExprNode) -> ExprMetadata: if self.has_windows: msg = "Cannot nest `over` statements." raise InvalidOperationError(msg) @@ -365,7 +575,10 @@ def with_ordered_over(self) -> ExprMetadata: ) raise InvalidOperationError(msg) n_orderable_ops = self.n_orderable_ops - if not n_orderable_ops and self.last_node is not ExprKind.WINDOW: + if ( + not n_orderable_ops + and next(self.op_nodes_reversed()).kind is not ExprKind.WINDOW + ): msg = ( "Cannot use `order_by` in `over` on expression which isn't orderable.\n" "If your expression is orderable, then make sure that `over(order_by=...)`\n" @@ -376,20 +589,20 @@ def with_ordered_over(self) -> ExprMetadata: " + `nw.col('price').diff().over(order_by='date') + 1`\n" ) raise InvalidOperationError(msg) - if self.last_node.is_orderable_window: + if next(self.op_nodes_reversed()).kind.is_orderable and n_orderable_ops > 0: n_orderable_ops -= 1 return ExprMetadata( self.expansion_kind, - ExprKind.OVER, has_windows=True, n_orderable_ops=n_orderable_ops, preserves_length=True, is_elementwise=False, is_scalar_like=False, is_literal=False, + nodes=(*self.nodes, node), ) - def with_partitioned_over(self) -> ExprMetadata: + def with_partitioned_over(self, node: ExprNode) -> ExprMetadata: if self.has_windows: msg = "Cannot nest `over` statements." raise InvalidOperationError(msg) @@ -401,94 +614,51 @@ def with_partitioned_over(self) -> ExprMetadata: raise InvalidOperationError(msg) return ExprMetadata( self.expansion_kind, - ExprKind.OVER, has_windows=True, n_orderable_ops=self.n_orderable_ops, preserves_length=True, is_elementwise=False, is_scalar_like=False, is_literal=False, + nodes=(*self.nodes, node), ) - def with_filtration(self) -> ExprMetadata: + def with_filtration(self, node: ExprNode) -> ExprMetadata: if self.is_scalar_like: msg = "Can't apply filtration (e.g. `drop_nulls`) to scalar-like expression." raise InvalidOperationError(msg) return ExprMetadata( self.expansion_kind, - ExprKind.FILTRATION, has_windows=self.has_windows, n_orderable_ops=self.n_orderable_ops, preserves_length=False, is_elementwise=False, is_scalar_like=False, is_literal=False, + nodes=(*self.nodes, node), ) - def with_orderable_filtration(self) -> ExprMetadata: + def with_orderable_filtration(self, node: ExprNode) -> ExprMetadata: if self.is_scalar_like: msg = "Can't apply filtration (e.g. `drop_nulls`) to scalar-like expression." raise InvalidOperationError(msg) return ExprMetadata( self.expansion_kind, - ExprKind.ORDERABLE_FILTRATION, has_windows=self.has_windows, n_orderable_ops=self.n_orderable_ops + 1, preserves_length=False, is_elementwise=False, is_scalar_like=False, is_literal=False, + nodes=(*self.nodes, node), ) - @staticmethod - def aggregation() -> ExprMetadata: - return ExprMetadata( - ExpansionKind.SINGLE, - ExprKind.AGGREGATION, - is_elementwise=False, - preserves_length=False, - is_scalar_like=True, - ) - - @staticmethod - def literal() -> ExprMetadata: - return ExprMetadata( - ExpansionKind.SINGLE, - ExprKind.LITERAL, - is_elementwise=False, - preserves_length=False, - is_literal=True, - is_scalar_like=True, - ) - - @staticmethod - def selector_single() -> ExprMetadata: - # e.g. `nw.col('a')`, `nw.nth(0)` - return ExprMetadata(ExpansionKind.SINGLE, ExprKind.ELEMENTWISE) - - @staticmethod - def selector_multi_named() -> ExprMetadata: - # e.g. `nw.col('a', 'b')` - return ExprMetadata(ExpansionKind.MULTI_NAMED, ExprKind.ELEMENTWISE) - - @staticmethod - def selector_multi_unnamed() -> ExprMetadata: - # e.g. `nw.all()` - return ExprMetadata(ExpansionKind.MULTI_UNNAMED, ExprKind.ELEMENTWISE) - - @classmethod - def from_binary_op(cls, lhs: Expr, rhs: IntoExpr, /) -> ExprMetadata: - # We may be able to allow multi-output rhs in the future: - # https://github.com/narwhals-dev/narwhals/issues/2244. - return combine_metadata( - lhs, rhs, str_as_lit=True, allow_multi_output=False, to_single_output=False - ) - - @classmethod - def from_horizontal_op(cls, *exprs: IntoExpr) -> ExprMetadata: - return combine_metadata( - *exprs, str_as_lit=False, allow_multi_output=True, to_single_output=True - ) + def op_nodes_reversed(self) -> Iterator[ExprNode]: + for node in reversed(self.nodes): + if node.name.startswith("name.") or node.name == "alias": + # Skip nodes which only do aliasing. + continue + yield node def combine_metadata( @@ -496,6 +666,7 @@ def combine_metadata( str_as_lit: bool, allow_multi_output: bool, to_single_output: bool, + nodes: tuple[ExprNode, ...], ) -> ExprMetadata: """Combine metadata from `args`. @@ -505,6 +676,7 @@ def combine_metadata( allow_multi_output: Whether to allow multi-output inputs. to_single_output: Whether the result is always single-output, regardless of the inputs (e.g. `nw.sum_horizontal`). + nodes: Nodes of result node. """ n_filtrations = 0 result_expansion_kind = ExpansionKind.SINGLE @@ -524,8 +696,9 @@ def combine_metadata( result_preserves_length = True result_is_scalar_like = False result_is_literal = False - elif is_expr(arg): + elif is_compliant_expr(arg): metadata = arg._metadata + assert metadata is not None # noqa: S101 if metadata.expansion_kind.is_multi_output(): expansion_kind = metadata.expansion_kind if i > 0 and not allow_multi_output: @@ -549,67 +722,179 @@ def combine_metadata( result_is_scalar_like &= metadata.is_scalar_like result_is_literal &= metadata.is_literal n_filtrations += int(metadata.is_filtration) - if n_filtrations > 1: msg = "Length-changing expressions can only be used in isolation, or followed by an aggregation" raise InvalidOperationError(msg) if result_preserves_length and n_filtrations: msg = "Cannot combine length-changing expressions with length-preserving ones or aggregations" raise InvalidOperationError(msg) - return ExprMetadata( result_expansion_kind, - # n-ary operations align positionally, and so the last node is elementwise. - ExprKind.ELEMENTWISE, has_windows=result_has_windows, n_orderable_ops=result_n_orderable_ops, preserves_length=result_preserves_length, is_elementwise=result_is_elementwise, is_scalar_like=result_is_scalar_like, is_literal=result_is_literal, + nodes=nodes, ) -def check_expressions_preserve_length(*args: IntoExpr, function_name: str) -> None: +def check_expressions_preserve_length( + *args: CompliantExprAny | NonNestedLiteral, function_name: str +) -> None: # Raise if any argument in `args` isn't length-preserving. # For Series input, we don't raise (yet), we let such checks happen later, # as this function works lazily and so can't evaluate lengths. - from narwhals.series import Series - if not all( - (is_expr(x) and x._metadata.preserves_length) or isinstance(x, (str, Series)) - for x in args - ): + if not all((is_compliant_expr(x) and x._metadata.preserves_length) for x in args): msg = f"Expressions which aggregate or change length cannot be passed to '{function_name}'." raise InvalidOperationError(msg) -def all_exprs_are_scalar_like(*args: IntoExpr, **kwargs: IntoExpr) -> bool: +def all_exprs_are_scalar_like(mds: Sequence[ExprMetadata]) -> bool: # Raise if any argument in `args` isn't an aggregation or literal. # For Series input, we don't raise (yet), we let such checks happen later, # as this function works lazily and so can't evaluate lengths. - exprs = chain(args, kwargs.values()) - return all(is_expr(x) and x._metadata.is_scalar_like for x in exprs) + return all(md.is_scalar_like for md in mds) -def apply_n_ary_operation( +def apply_binary( plx: CompliantNamespaceAny, - n_ary_function: Callable[..., CompliantExprAny], - *comparands: IntoExpr | NonNestedLiteral | _1DArray, - str_as_lit: bool, + name: str, + ce: CompliantExprAny, + other: IntoExpr | NonNestedLiteral | _1DArray, ) -> CompliantExprAny: - parse = plx.parse_into_expr - compliant_exprs = (parse(into, str_as_lit=str_as_lit) for into in comparands) - kinds = [ - ExprKind.from_into_expr(comparand, str_as_lit=str_as_lit) - for comparand in comparands - ] + parse = plx.evaluate_expr + other_compliant = parse(other) + compliant_exprs = [ce, other_compliant] + return getattr(compliant_exprs[0], name)(compliant_exprs[1]) + + +@overload +def _parse_into_expr( + arg: IntoExpr | NonNestedLiteral | _1DArray, + *, + str_as_lit: bool = False, + backend: Any = None, + allow_literal: Literal[False], +) -> Expr: ... + + +@overload +def _parse_into_expr( + arg: IntoExpr | NonNestedLiteral | _1DArray, + *, + str_as_lit: bool = False, + backend: Any = None, + allow_literal: Literal[True] = ..., +) -> Expr | NonNestedLiteral: ... + + +def _parse_into_expr( + arg: IntoExpr | NonNestedLiteral | _1DArray, + *, + str_as_lit: bool = False, + backend: Any = None, + allow_literal: bool = True, +) -> Expr | NonNestedLiteral: + from narwhals.functions import col, new_series + + if isinstance(arg, str) and not str_as_lit: + return col(arg) + if is_numpy_array_1d(arg): + return new_series("", arg, backend=backend)._to_expr() + if is_series(arg): + return arg._to_expr() + if is_expr(arg): + return arg + if not allow_literal: + raise InvalidIntoExprError.from_invalid_type(type(arg)) + return arg + + +def evaluate_into_exprs( + *exprs: IntoExpr | NonNestedLiteral | _1DArray, + ns: CompliantNamespaceAny, + str_as_lit: bool, + allow_multi_output: bool, +) -> Iterator[CompliantExprAny | NonNestedLiteral]: + for expr in exprs: + ret = ns.evaluate_expr( + _parse_into_expr(expr, str_as_lit=str_as_lit, backend=ns._implementation) + ) + if ( + not allow_multi_output + and is_compliant_expr(ret) + and ret._metadata.expansion_kind.is_multi_output() + ): + msg = "Multi-output expressions are not allowed in this context." + raise MultiOutputExpressionError(msg) + yield ret + +def maybe_broadcast_ces( + *ces: CompliantExprAny | NonNestedLiteral, +) -> list[CompliantExprAny | NonNestedLiteral]: + kinds = [ExprKind.from_into_expr(comparand) for comparand in ces] broadcast = any(not kind.is_scalar_like for kind in kinds) - compliant_exprs = ( - compliant_expr.broadcast(kind) - if broadcast and is_compliant_expr(compliant_expr) and is_scalar_like(kind) - else compliant_expr - for compliant_expr, kind in zip_strict(compliant_exprs, kinds) + results: list[CompliantExprAny | NonNestedLiteral] = [] + for compliant_expr, kind in zip_strict(ces, kinds): + if broadcast and is_compliant_expr(compliant_expr) and is_scalar_like(kind): + _compliant_expr: CompliantExprAny = compliant_expr.broadcast(kind) + # Make sure to preserve metadata. + _compliant_expr._opt_metadata = compliant_expr._metadata + results.append(_compliant_expr) + else: + results.append(compliant_expr) + return results + + +def evaluate_root_node(node: ExprNode, ns: CompliantNamespaceAny) -> CompliantExprAny: + if "." in node.name: + module, method = node.name.split(".") + func = getattr(getattr(ns, module), method) + else: + func = getattr(ns, node.name) + ces = maybe_broadcast_ces( + *evaluate_into_exprs( + *node.exprs, + ns=ns, + str_as_lit=node.str_as_lit, + allow_multi_output=node.allow_multi_output, + ) ) - return n_ary_function(*compliant_exprs) + ce = cast("CompliantExpr[Any, Any]", func(*ces, **node.kwargs)) + md = ExprMetadata.from_node(node, *ces) + ce._opt_metadata = md + return ce + + +def evaluate_node( + compliant_expr: CompliantExprAny, node: ExprNode, ns: CompliantNamespaceAny +) -> CompliantExprAny: + md = compliant_expr._metadata + ce, *ces = maybe_broadcast_ces( + compliant_expr, + *evaluate_into_exprs( + *node.exprs, + ns=ns, + str_as_lit=node.str_as_lit, + allow_multi_output=node.allow_multi_output, + ), + ) + assert is_compliant_expr(ce) # noqa: S101 + md = md.with_node(node, ce, *ces) + if "." in node.name: + accessor, method = node.name.split(".") + func = getattr(getattr(ce, accessor), method) + else: + func = getattr(ce, node.name) + if not node.allow_multi_output and any( + x._metadata.expansion_kind.is_multi_output() for x in ces if is_compliant_expr(x) + ): + msg = "multi-output expressions are not allowed as arguments to Expr methods." + raise MultiOutputExpressionError(msg) + ret = cast("CompliantExprAny", func(*ces, **node.kwargs)) + ret._opt_metadata = md + return ret diff --git a/narwhals/_ibis/dataframe.py b/narwhals/_ibis/dataframe.py index 7e198ce11a..056d624ff7 100644 --- a/narwhals/_ibis/dataframe.py +++ b/narwhals/_ibis/dataframe.py @@ -423,12 +423,5 @@ def sink_parquet(self, file: str | Path | BytesIO) -> None: raise NotImplementedError(msg) self.native.to_parquet(file) - gather_every = not_implemented.deprecated( - "`LazyFrame.gather_every` is deprecated and will be removed in a future version." - ) - tail = not_implemented.deprecated( - "`LazyFrame.tail` is deprecated and will be removed in a future version." - ) - # Intentionally not implemented, as Ibis does its own expression rewriting. _evaluate_window_expr = not_implemented() diff --git a/narwhals/_ibis/expr.py b/narwhals/_ibis/expr.py index 3471d3c831..22f1174a5d 100644 --- a/narwhals/_ibis/expr.py +++ b/narwhals/_ibis/expr.py @@ -64,7 +64,7 @@ def __init__( self._evaluate_output_names = evaluate_output_names self._alias_output_names = alias_output_names self._version = version - self._metadata: ExprMetadata | None = None + self._opt_metadata: ExprMetadata | None = None self._window_function: IbisWindowFunction | None = window_function @property @@ -199,18 +199,6 @@ def quantile( raise NotImplementedError(msg) return self._with_callable(lambda expr: expr.quantile(quantile)) - def clip(self, lower_bound: Any, upper_bound: Any) -> Self: - def _clip( - expr: ir.NumericValue, lower: Any | None = None, upper: Any | None = None - ) -> ir.NumericValue: - return expr.clip(lower=lower, upper=upper) - - if lower_bound is None: - return self._with_callable(_clip, upper=upper_bound) - if upper_bound is None: - return self._with_callable(_clip, lower=lower_bound) - return self._with_callable(_clip, lower=lower_bound, upper=upper_bound) - def n_unique(self) -> Self: return self._with_callable( lambda expr: expr.nunique() + expr.isnull().any().cast("int8") diff --git a/narwhals/_ibis/expr_str.py b/narwhals/_ibis/expr_str.py index adf84fe4e1..78907de7ef 100644 --- a/narwhals/_ibis/expr_str.py +++ b/narwhals/_ibis/expr_str.py @@ -41,7 +41,7 @@ def fn(expr: ir.StringColumn) -> ir.StringValue: return fn def replace_all( - self, pattern: str, value: str | IbisExpr, *, literal: bool + self, value: str | IbisExpr, pattern: str, *, literal: bool ) -> IbisExpr: fn = self._replace_all_literal if literal else self._replace_all if isinstance(value, str): diff --git a/narwhals/_ibis/namespace.py b/narwhals/_ibis/namespace.py index acf05aa6d0..048745ffba 100644 --- a/narwhals/_ibis/namespace.py +++ b/narwhals/_ibis/namespace.py @@ -17,8 +17,7 @@ from narwhals._ibis.selectors import IbisSelectorNamespace from narwhals._ibis.utils import function, lit, narwhals_to_native_dtype from narwhals._sql.namespace import SQLNamespace -from narwhals._sql.when_then import SQLThen, SQLWhen -from narwhals._utils import Implementation, requires +from narwhals._utils import Implementation if TYPE_CHECKING: from collections.abc import Iterable, Sequence @@ -108,10 +107,6 @@ def func(cols: Iterable[ir.Value]) -> ir.Value: return self._expr._from_elementwise_horizontal_op(func, *exprs) - @requires.backend_version((10, 0)) - def when(self, predicate: IbisExpr) -> IbisWhen: - return IbisWhen.from_expr(predicate, context=self) - def lit(self, value: Any, dtype: IntoDType | None) -> IbisExpr: def func(_df: IbisLazyFrame) -> Sequence[ir.Value]: ibis_dtype = narwhals_to_native_dtype(dtype, self._version) if dtype else None @@ -134,27 +129,3 @@ def func(_df: IbisLazyFrame) -> list[ir.Value]: alias_output_names=None, version=self._version, ) - - -class IbisWhen(SQLWhen["IbisLazyFrame", "ir.Value", IbisExpr]): - lit = lit - - @property - def _then(self) -> type[IbisThen]: - return IbisThen - - def __call__(self, df: IbisLazyFrame) -> Sequence[ir.Value]: - is_expr = self._condition._is_expr - condition = df._evaluate_expr(self._condition) - then_ = self._then_value - then = df._evaluate_expr(then_) if is_expr(then_) else lit(then_) - other_ = self._otherwise_value - if other_ is None: - result = ibis.cases((condition, then)) - else: - otherwise = df._evaluate_expr(other_) if is_expr(other_) else lit(other_) - result = ibis.cases((condition, then), else_=otherwise) - return [result] - - -class IbisThen(SQLThen["IbisLazyFrame", "ir.Value", IbisExpr], IbisExpr): ... diff --git a/narwhals/_pandas_like/dataframe.py b/narwhals/_pandas_like/dataframe.py index 8b516307d4..a8250c9736 100644 --- a/narwhals/_pandas_like/dataframe.py +++ b/narwhals/_pandas_like/dataframe.py @@ -410,7 +410,7 @@ def drop_nulls(self, subset: Sequence[str] | None) -> Self: self.native.dropna(axis=0), validate_column_names=False ) plx = self.__narwhals_namespace__() - mask = ~plx.any_horizontal(plx.col(*subset).is_null(), ignore_nulls=True) + mask = ~plx.any_horizontal(plx.col(subset).is_null(), ignore_nulls=True) return self.filter(mask) def estimated_size(self, unit: SizeUnit) -> int | float: @@ -419,18 +419,18 @@ def estimated_size(self, unit: SizeUnit) -> int | float: def with_row_index(self, name: str, order_by: Sequence[str] | None) -> Self: plx = self.__narwhals_namespace__() - if order_by is None: - size = len(self) - data = self._array_funcs.arange(size) - + size = len(self) + data = self._array_funcs.arange(size) + row_index_s = plx._series.from_iterable( + data, context=self, index=self.native.index, name=name + ) + row_index = plx._expr._from_series(row_index_s) + if order_by: row_index = plx._expr._from_series( - plx._series.from_iterable( - data, context=self, index=self.native.index, name=name - ) + self.with_columns(row_index) + .sort(*order_by, descending=False, nulls_last=False) + .get_column(name) ) - else: - rank = plx.col(order_by[0]).rank(method="ordinal", descending=False) - row_index = (rank.over(partition_by=[], order_by=order_by) - 1).alias(name) return self.select(row_index, plx.all()) def row(self, index: int) -> tuple[Any, ...]: diff --git a/narwhals/_pandas_like/expr.py b/narwhals/_pandas_like/expr.py index 6cab84e3a3..1b06111c14 100644 --- a/narwhals/_pandas_like/expr.py +++ b/narwhals/_pandas_like/expr.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, cast from narwhals._compliant import EagerExpr from narwhals._expression_parsing import evaluate_output_names_and_aliases @@ -13,7 +13,12 @@ from typing_extensions import Self - from narwhals._compliant.typing import AliasNames, EvalNames, EvalSeries, ScalarKwargs + from narwhals._compliant.typing import ( + AliasNames, + EvalNames, + EvalSeries, + NarwhalsAggregation, + ) from narwhals._expression_parsing import ExprMetadata from narwhals._pandas_like.dataframe import PandasLikeDataFrame from narwhals._pandas_like.namespace import PandasLikeNamespace @@ -43,7 +48,7 @@ def window_kwargs_to_pandas_equivalent( - function_name: str, kwargs: ScalarKwargs + function_name: str, kwargs: dict[str, Any] ) -> dict[str, PythonLiteral]: if function_name == "shift": assert "n" in kwargs # noqa: S101 @@ -111,23 +116,17 @@ def __init__( self, call: EvalSeries[PandasLikeDataFrame, PandasLikeSeries], *, - depth: int, - function_name: str, evaluate_output_names: EvalNames[PandasLikeDataFrame], alias_output_names: AliasNames | None, implementation: Implementation, version: Version, - scalar_kwargs: ScalarKwargs | None = None, ) -> None: self._call = call - self._depth = depth - self._function_name = function_name self._evaluate_output_names = evaluate_output_names self._alias_output_names = alias_output_names self._implementation = implementation self._version = version - self._scalar_kwargs = scalar_kwargs or {} - self._metadata: ExprMetadata | None = None + self._opt_metadata: ExprMetadata | None = None def __narwhals_namespace__(self) -> PandasLikeNamespace: from narwhals._pandas_like.namespace import PandasLikeNamespace @@ -141,7 +140,6 @@ def from_column_names( /, *, context: _LimitedContext, - function_name: str = "", ) -> Self: def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: try: @@ -160,8 +158,6 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: return cls( func, - depth=0, - function_name=function_name, evaluate_output_names=evaluate_column_names, alias_output_names=None, implementation=context._implementation, @@ -179,8 +175,6 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: return cls( func, - depth=0, - function_name="nth", evaluate_output_names=cls._eval_names_indices(column_indices), alias_output_names=None, implementation=context._implementation, @@ -200,20 +194,19 @@ def ewm_mean( ) -> Self: return self._reuse_series( "ewm_mean", - scalar_kwargs={ - "com": com, - "span": span, - "half_life": half_life, - "alpha": alpha, - "adjust": adjust, - "min_samples": min_samples, - "ignore_nulls": ignore_nulls, - }, + com=com, + span=span, + half_life=half_life, + alpha=alpha, + adjust=adjust, + min_samples=min_samples, + ignore_nulls=ignore_nulls, ) def over( # noqa: C901, PLR0915 self, partition_by: Sequence[str], order_by: Sequence[str] ) -> Self: + nodes = self._metadata.nodes if not partition_by: # e.g. `nw.col('a').cum_sum().order_by(key)` # We can always easily support this as it doesn't require grouping. @@ -229,7 +222,7 @@ def func(df: PandasLikeDataFrame) -> Sequence[PandasLikeSeries]: for s in results: s._scatter_in_place(sorting_indices, s) return results - elif not self._is_elementary(): + elif len(nodes) > 2: msg = ( "Only elementary expressions are supported for `.over` in pandas-like backends.\n\n" "Please see: " @@ -237,9 +230,14 @@ def func(df: PandasLikeDataFrame) -> Sequence[PandasLikeSeries]: ) raise NotImplementedError(msg) else: - function_name = PandasLikeGroupBy._leaf_name(self) + assert nodes # noqa: S101 + leaf_node = nodes[-1] + function_name = leaf_node.name + pandas_agg = PandasLikeGroupBy._REMAP_AGGS.get( + cast("NarwhalsAggregation", function_name) + ) pandas_function_name = WINDOW_FUNCTIONS_TO_PANDAS_EQUIVALENT.get( - function_name, PandasLikeGroupBy._REMAP_AGGS.get(function_name) + function_name, pandas_agg ) if pandas_function_name is None: msg = ( @@ -248,21 +246,22 @@ def func(df: PandasLikeDataFrame) -> Sequence[PandasLikeSeries]: f"and {', '.join(PandasLikeGroupBy._REMAP_AGGS)}." ) raise NotImplementedError(msg) + scalar_kwargs = leaf_node.kwargs pandas_kwargs = window_kwargs_to_pandas_equivalent( - function_name, self._scalar_kwargs + function_name, scalar_kwargs ) def func(df: PandasLikeDataFrame) -> Sequence[PandasLikeSeries]: # noqa: C901, PLR0912, PLR0914, PLR0915 output_names, aliases = evaluate_output_names_and_aliases(self, df, []) if function_name == "cum_count": plx = self.__narwhals_namespace__() - df = df.with_columns(~plx.col(*output_names).is_null()) + df = df.with_columns(~plx.col(output_names).is_null()) if function_name.startswith("cum_"): - assert "reverse" in self._scalar_kwargs # noqa: S101 - reverse = self._scalar_kwargs["reverse"] + assert "reverse" in scalar_kwargs # noqa: S101 + reverse = scalar_kwargs["reverse"] else: - assert "reverse" not in self._scalar_kwargs # noqa: S101 + assert "reverse" not in scalar_kwargs # noqa: S101 reverse = False if order_by: @@ -282,9 +281,9 @@ def func(df: PandasLikeDataFrame) -> Sequence[PandasLikeSeries]: # noqa: C901, rolling = grouped[list(output_names)].rolling(**pandas_kwargs) assert pandas_function_name is not None # help mypy # noqa: S101 if pandas_function_name in {"std", "var"}: - assert "ddof" in self._scalar_kwargs # noqa: S101 + assert "ddof" in scalar_kwargs # noqa: S101 res_native = getattr(rolling, pandas_function_name)( - ddof=self._scalar_kwargs["ddof"] + ddof=scalar_kwargs["ddof"] ) else: res_native = getattr(rolling, pandas_function_name)() @@ -301,13 +300,13 @@ def func(df: PandasLikeDataFrame) -> Sequence[PandasLikeSeries]: # noqa: C901, assert pandas_function_name is not None # help mypy # noqa: S101 res_native = getattr(ewm, pandas_function_name)() elif function_name == "fill_null": - assert "strategy" in self._scalar_kwargs # noqa: S101 - assert "limit" in self._scalar_kwargs # noqa: S101 + assert "strategy" in scalar_kwargs # noqa: S101 + assert "limit" in scalar_kwargs # noqa: S101 df_grouped = grouped[list(output_names)] - if self._scalar_kwargs["strategy"] == "forward": - res_native = df_grouped.ffill(limit=self._scalar_kwargs["limit"]) - elif self._scalar_kwargs["strategy"] == "backward": - res_native = df_grouped.bfill(limit=self._scalar_kwargs["limit"]) + if scalar_kwargs["strategy"] == "forward": + res_native = df_grouped.ffill(limit=scalar_kwargs["limit"]) + elif scalar_kwargs["strategy"] == "backward": + res_native = df_grouped.bfill(limit=scalar_kwargs["limit"]) else: # pragma: no cover # This is deprecated in pandas. Indeed, `nw.col('a').fill_null(3).over('b')` # does not seem very useful, and DuckDB doesn't support it either. @@ -336,8 +335,6 @@ def func(df: PandasLikeDataFrame) -> Sequence[PandasLikeSeries]: # noqa: C901, return self.__class__( func, - depth=self._depth + 1, - function_name=self._function_name + "->over", evaluate_output_names=self._evaluate_output_names, alias_output_names=self._alias_output_names, implementation=self._implementation, diff --git a/narwhals/_pandas_like/group_by.py b/narwhals/_pandas_like/group_by.py index 24dff76654..9f20f090f3 100644 --- a/narwhals/_pandas_like/group_by.py +++ b/narwhals/_pandas_like/group_by.py @@ -118,7 +118,8 @@ def _getitem_aggs( ) elif self.is_mode(): compliant = group_by.compliant - if (keep := self.kwargs.get("keep")) != "any": # pragma: no cover + node_kwargs = next(self.expr._metadata.op_nodes_reversed()).kwargs + if (keep := node_kwargs.get("keep")) != "any": # pragma: no cover msg = ( f"`Expr.mode(keep='{keep}')` is not implemented in group by context for " f"backend {compliant._implementation}\n\n" @@ -162,11 +163,7 @@ def is_mode(self) -> bool: def is_top_level_function(self) -> bool: # e.g. `nw.len()`. - return self.expr._depth == 0 - - @property - def kwargs(self) -> ScalarKwargs: - return self.expr._scalar_kwargs + return len(list(self.expr._metadata.op_nodes_reversed())) == 1 @property def leaf_name(self) -> NarwhalsAggregation | Any: @@ -177,8 +174,9 @@ def leaf_name(self) -> NarwhalsAggregation | Any: def native_agg(self) -> _NativeAgg: """Return a partial `DataFrameGroupBy` method, missing only `self`.""" + last_node = next(self.expr._metadata.op_nodes_reversed()) return _native_agg( - PandasLikeGroupBy._remap_expr_name(self.leaf_name), **self.kwargs + PandasLikeGroupBy._remap_expr_name(self.leaf_name), **last_node.kwargs ) diff --git a/narwhals/_pandas_like/namespace.py b/narwhals/_pandas_like/namespace.py index d682add939..de65820ba0 100644 --- a/narwhals/_pandas_like/namespace.py +++ b/narwhals/_pandas_like/namespace.py @@ -6,7 +6,7 @@ from itertools import chain from typing import TYPE_CHECKING, Any, Literal, Protocol, overload -from narwhals._compliant import CompliantThen, EagerNamespace, EagerWhen +from narwhals._compliant import EagerNamespace from narwhals._expression_parsing import ( combine_alias_output_names, combine_evaluate_output_names, @@ -24,7 +24,6 @@ from typing_extensions import TypeAlias - from narwhals._compliant.typing import ScalarKwargs from narwhals._utils import Implementation, Version from narwhals.typing import IntoDType, NonNestedLiteral @@ -80,8 +79,6 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: return self._expr._from_callable( func=func, - depth=max(x._depth for x in exprs) + 1, - function_name="coalesce", evaluate_output_names=combine_evaluate_output_names(*exprs), alias_output_names=combine_alias_output_names(*exprs), context=self, @@ -101,8 +98,6 @@ def _lit_pandas_series(df: PandasLikeDataFrame) -> PandasLikeSeries: return PandasLikeExpr( lambda df: [_lit_pandas_series(df)], - depth=0, - function_name="lit", evaluate_output_names=lambda _df: ["literal"], alias_output_names=None, implementation=self._implementation, @@ -116,8 +111,6 @@ def len(self) -> PandasLikeExpr: [len(df._native_frame)], name="len", index=[0], context=self ) ], - depth=0, - function_name="len", evaluate_output_names=lambda _df: ["len"], alias_output_names=None, implementation=self._implementation, @@ -135,8 +128,6 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: return self._expr._from_callable( func=func, - depth=max(x._depth for x in exprs) + 1, - function_name="sum_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), alias_output_names=combine_alias_output_names(*exprs), context=self, @@ -168,8 +159,6 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: return self._expr._from_callable( func=func, - depth=max(x._depth for x in exprs) + 1, - function_name="all_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), alias_output_names=combine_alias_output_names(*exprs), context=self, @@ -201,8 +190,6 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: return self._expr._from_callable( func=func, - depth=max(x._depth for x in exprs) + 1, - function_name="any_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), alias_output_names=combine_alias_output_names(*exprs), context=self, @@ -220,8 +207,6 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: return self._expr._from_callable( func=func, - depth=max(x._depth for x in exprs) + 1, - function_name="mean_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), alias_output_names=combine_alias_output_names(*exprs), context=self, @@ -245,8 +230,6 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: return self._expr._from_callable( func=func, - depth=max(x._depth for x in exprs) + 1, - function_name="min_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), alias_output_names=combine_alias_output_names(*exprs), context=self, @@ -270,8 +253,6 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: return self._expr._from_callable( func=func, - depth=max(x._depth for x in exprs) + 1, - function_name="max_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), alias_output_names=combine_alias_output_names(*exprs), context=self, @@ -322,9 +303,6 @@ def _concat_vertical(self, dfs: Sequence[NativeDataFrameT], /) -> NativeDataFram return self._concat(dfs, axis=VERTICAL, copy=False) return self._concat(dfs, axis=VERTICAL) - def when(self, predicate: PandasLikeExpr) -> PandasWhen[NativeSeriesT]: - return PandasWhen[NativeSeriesT].from_expr(predicate, context=self) - def concat_str( self, *exprs: PandasLikeExpr, separator: str, ignore_nulls: bool ) -> PandasLikeExpr: @@ -366,13 +344,20 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: return self._expr._from_callable( func=func, - depth=max(x._depth for x in exprs) + 1, - function_name="concat_str", evaluate_output_names=combine_evaluate_output_names(*exprs), alias_output_names=combine_alias_output_names(*exprs), context=self, ) + def _if_then_else( + self, + when: NativeSeriesT, + then: NativeSeriesT, + otherwise: NativeSeriesT | None = None, + ) -> NativeSeriesT: + where: Incomplete = then.where + return where(when) if otherwise is None else where(when, otherwise) + class _NativeConcat(Protocol[NativeDataFrameT, NativeSeriesT]): @overload @@ -411,31 +396,3 @@ def __call__( axis: Axis, copy: bool | None = None, ) -> NativeDataFrameT | NativeSeriesT: ... - - -class PandasWhen( - EagerWhen[PandasLikeDataFrame, PandasLikeSeries, PandasLikeExpr, NativeSeriesT] -): - @property - # Signature of "_then" incompatible with supertype "CompliantWhen" - # ArrowWhen seems to follow the same pattern, but no mypy complaint there? - def _then(self) -> type[PandasThen]: # type: ignore[override] - return PandasThen - - def _if_then_else( - self, - when: NativeSeriesT, - then: NativeSeriesT, - otherwise: NativeSeriesT | NonNestedLiteral, - ) -> NativeSeriesT: - where: Incomplete = then.where - return where(when) if otherwise is None else where(when, otherwise) - - -class PandasThen( - CompliantThen[PandasLikeDataFrame, PandasLikeSeries, PandasLikeExpr, PandasWhen], - PandasLikeExpr, -): - _depth: int = 0 - _scalar_kwargs: ScalarKwargs = {} # noqa: RUF012 - _function_name: str = "whenthen" diff --git a/narwhals/_pandas_like/selectors.py b/narwhals/_pandas_like/selectors.py index 7e68108ed3..b2462561cf 100644 --- a/narwhals/_pandas_like/selectors.py +++ b/narwhals/_pandas_like/selectors.py @@ -6,7 +6,6 @@ from narwhals._pandas_like.expr import PandasLikeExpr if TYPE_CHECKING: - from narwhals._compliant.typing import ScalarKwargs from narwhals._pandas_like.dataframe import PandasLikeDataFrame # noqa: F401 from narwhals._pandas_like.series import PandasLikeSeries # noqa: F401 @@ -22,15 +21,9 @@ def _selector(self) -> type[PandasSelector]: class PandasSelector( # type: ignore[misc] CompliantSelector["PandasLikeDataFrame", "PandasLikeSeries"], PandasLikeExpr ): - _depth: int = 0 - _scalar_kwargs: ScalarKwargs = {} # noqa: RUF012 - _function_name: str = "selector" - def _to_expr(self) -> PandasLikeExpr: return PandasLikeExpr( self._call, - depth=self._depth, - function_name=self._function_name, evaluate_output_names=self._evaluate_output_names, alias_output_names=self._alias_output_names, implementation=self._implementation, diff --git a/narwhals/_pandas_like/series_str.py b/narwhals/_pandas_like/series_str.py index 19e16f3c35..6f5e5af560 100644 --- a/narwhals/_pandas_like/series_str.py +++ b/narwhals/_pandas_like/series_str.py @@ -16,7 +16,7 @@ def len_chars(self) -> PandasLikeSeries: return self.with_native(self.native.str.len()) def replace( - self, pattern: str, value: str, *, literal: bool, n: int + self, value: str, pattern: str, *, literal: bool, n: int ) -> PandasLikeSeries: try: series = self.native.str.replace( @@ -29,8 +29,8 @@ def replace( raise return self.with_native(series) - def replace_all(self, pattern: str, value: str, *, literal: bool) -> PandasLikeSeries: - return self.replace(pattern, value, literal=literal, n=-1) + def replace_all(self, value: str, pattern: str, *, literal: bool) -> PandasLikeSeries: + return self.replace(value, pattern, literal=literal, n=-1) def strip_chars(self, characters: str | None) -> PandasLikeSeries: return self.with_native(self.native.str.strip(characters)) diff --git a/narwhals/_polars/expr.py b/narwhals/_polars/expr.py index 8de0ad96c4..a57a02ce2c 100644 --- a/narwhals/_polars/expr.py +++ b/narwhals/_polars/expr.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Callable, ClassVar, Literal +from typing import TYPE_CHECKING, Any, Callable, ClassVar, Literal, cast import polars as pl @@ -26,8 +26,9 @@ from narwhals._expression_parsing import ExprKind, ExprMetadata from narwhals._polars.dataframe import Method from narwhals._polars.namespace import PolarsNamespace + from narwhals._polars.series import PolarsSeries from narwhals._utils import Version - from narwhals.typing import IntoDType, ModeKeepStrategy, NumericLiteral + from narwhals.typing import IntoDType, ModeKeepStrategy class PolarsExpr: @@ -35,11 +36,14 @@ class PolarsExpr: _implementation: Implementation = Implementation.POLARS _version: Version _native_expr: pl.Expr - _metadata: ExprMetadata | None = None _evaluate_output_names: Any _alias_output_names: Any __call__: Any + @classmethod + def _from_series(cls, series: PolarsSeries) -> Self: + return cls(series.native, version=series._version) # type: ignore[arg-type] + # CompliantExpr + builtin descriptor # TODO @dangotbanned: Remove in #2713 @classmethod @@ -76,9 +80,10 @@ def __repr__(self) -> str: # pragma: no cover def _with_native(self, expr: pl.Expr) -> Self: return self.__class__(expr, self._version) - def broadcast(self, kind: Literal[ExprKind.AGGREGATION, ExprKind.LITERAL]) -> Self: - # Let Polars do its thing. - return self + @property + def _metadata(self) -> ExprMetadata: + assert self._opt_metadata is not None # noqa: S101 + return cast("ExprMetadata", self._opt_metadata) def __getattr__(self, attr: str) -> Any: def func(*args: Any, **kwargs: Any) -> Any: @@ -87,6 +92,10 @@ def func(*args: Any, **kwargs: Any) -> Any: return func + def broadcast(self, kind: Literal[ExprKind.AGGREGATION, ExprKind.LITERAL]) -> Self: + # Let Polars do its thing. + return self + def _renamed_min_periods(self, min_samples: int, /) -> dict[str, Any]: name = "min_periods" if self._backend_version < (1, 21, 0) else "min_samples" return {name: min_samples} @@ -254,46 +263,6 @@ def __invert__(self) -> Self: def cum_count(self, *, reverse: bool) -> Self: return self._with_native(self.native.cum_count(reverse=reverse)) - def is_close( - self, - other: Self | NumericLiteral, - *, - abs_tol: float, - rel_tol: float, - nans_equal: bool, - ) -> Self: - left = self.native - right = other.native if isinstance(other, PolarsExpr) else pl.lit(other) - - if self._backend_version < (1, 32, 0): - lower_bound = right.abs() - tolerance = (left.abs().clip(lower_bound) * rel_tol).clip(abs_tol) - - # Values are close if abs_diff <= tolerance, and both finite - abs_diff = (left - right).abs() - all_ = pl.all_horizontal - is_close = all_((abs_diff <= tolerance), left.is_finite(), right.is_finite()) - - # Handle infinity cases: infinities are "close" only if they have the same sign - is_same_inf = all_( - left.is_infinite(), right.is_infinite(), (left.sign() == right.sign()) - ) - - # Handle nan cases: - # * nans_equals = True => if both values are NaN, then True - # * nans_equals = False => if any value is NaN, then False - left_is_nan, right_is_nan = left.is_nan(), right.is_nan() - either_nan = left_is_nan | right_is_nan - result = (is_close | is_same_inf) & either_nan.not_() - - if nans_equal: - result = result | (left_is_nan & right_is_nan) - else: - result = left.is_close( - right, abs_tol=abs_tol, rel_tol=rel_tol, nans_equal=nans_equal - ) - return self._with_native(result) - def mode(self, *, keep: ModeKeepStrategy) -> Self: result = self.native.mode() return self._with_native(result.first() if keep == "any" else result) @@ -423,6 +392,16 @@ def zfill(self, width: int) -> PolarsExpr: return self.compliant._with_native(native_result) + def replace(self, value: str, pattern: str, *, literal: bool, n: int) -> PolarsExpr: + return self.compliant._with_native( + self.native.str.replace(pattern, extract_native(value), literal=literal, n=n) + ) + + def replace_all(self, value: str, pattern: str, *, literal: bool) -> PolarsExpr: + return self.compliant._with_native( + self.native.str.replace_all(pattern, extract_native(value), literal=literal) + ) + class PolarsExprCatNamespace( PolarsExprNamespace, PolarsCatNamespace[PolarsExpr, pl.Expr] diff --git a/narwhals/_polars/namespace.py b/narwhals/_polars/namespace.py index 08eb56dbf2..0599196c7c 100644 --- a/narwhals/_polars/namespace.py +++ b/narwhals/_polars/namespace.py @@ -17,7 +17,7 @@ from collections.abc import Iterable, Sequence from datetime import timezone - from narwhals._compliant import CompliantSelectorNamespace, CompliantWhen + from narwhals._compliant import CompliantSelectorNamespace from narwhals._polars.dataframe import Method, PolarsDataFrame, PolarsLazyFrame from narwhals._polars.typing import FrameT from narwhals._utils import Version, _LimitedContext @@ -37,14 +37,10 @@ class PolarsNamespace: all: Method[PolarsExpr] coalesce: Method[PolarsExpr] - col: Method[PolarsExpr] - exclude: Method[PolarsExpr] sum_horizontal: Method[PolarsExpr] min_horizontal: Method[PolarsExpr] max_horizontal: Method[PolarsExpr] - when: Method[CompliantWhen[PolarsDataFrame, PolarsSeries, PolarsExpr]] - _implementation: Implementation = Implementation.POLARS _version: Version @@ -55,6 +51,15 @@ def _backend_version(self) -> tuple[int, ...]: def __init__(self, *, version: Version) -> None: self._version = version + def evaluate_expr( + self, data: Expr | NonNestedLiteral | Any, / + ) -> PolarsExpr | NonNestedLiteral: + if is_expr(data): + expr = data(self) + assert isinstance(expr, self._expr) # noqa: S101 + return expr + return data + def __getattr__(self, attr: str) -> Any: def func(*args: Any, **kwargs: Any) -> Any: pos, kwds = extract_args_kwargs(args, kwargs) @@ -93,12 +98,14 @@ def parse_into_expr( # NOTE: To avoid `pl.lit(None)` failing this `None` check # https://github.com/pola-rs/polars/blob/58dd8e5770f16a9bef9009a1c05f00e15a5263c7/py-polars/polars/expr/expr.py#L2870-L2872 return data + if isinstance(data, PolarsExpr): + return data if is_expr(data): - expr = data._to_compliant_expr(self) + expr = data(self) assert isinstance(expr, self._expr) # noqa: S101 return expr if isinstance(data, str) and not str_as_lit: - return self.col(data) + return self.col([data]) return self.lit(data.to_native() if is_series(data) else data, None) @overload @@ -137,10 +144,16 @@ def from_numpy( return self._dataframe.from_numpy(data, schema=schema, context=self) return self._series.from_numpy(data, context=self) # pragma: no cover + def col(self, names: Sequence[str]) -> PolarsExpr: + return self._expr(pl.col(*names), version=self._version) + + def exclude(self, names: Sequence[str]) -> PolarsExpr: + return self._expr(pl.exclude(*names), version=self._version) + @requires.backend_version( (1, 0, 0), "Please use `col` for columns selection instead." ) - def nth(self, *indices: int) -> PolarsExpr: + def nth(self, indices: Sequence[int]) -> PolarsExpr: return self._expr(pl.nth(*indices), version=self._version) def len(self) -> PolarsExpr: @@ -225,6 +238,22 @@ def concat_str( version=self._version, ) + def when_then( + self, when: PolarsExpr, then: PolarsExpr, otherwise: PolarsExpr | None = None + ) -> PolarsExpr: + if otherwise is None: + (when_native, then_native), _ = extract_args_kwargs((when, then), {}) + return self._expr( + pl.when(when_native).then(then_native), version=self._version + ) + (when_native, then_native, otherwise_native), _ = extract_args_kwargs( + (when, then, otherwise), {} + ) + return self._expr( + pl.when(when_native).then(then_native).otherwise(otherwise_native), + version=self._version, + ) + # NOTE: Implementation is too different to annotate correctly (vs other `*SelectorNamespace`) # 1. Others have lots of private stuff for code reuse # i. None of that is useful here diff --git a/narwhals/_polars/series.py b/narwhals/_polars/series.py index 8da184e335..222f15723e 100644 --- a/narwhals/_polars/series.py +++ b/narwhals/_polars/series.py @@ -4,6 +4,7 @@ import polars as pl +from narwhals._polars.expr import PolarsExpr from narwhals._polars.utils import ( BACKEND_VERSION, SERIES_ACCEPTS_PD_INDEX, @@ -32,6 +33,7 @@ import pyarrow as pa from typing_extensions import Self, TypeAlias, TypeIs + from narwhals._compliant.typing import CompliantExprAny from narwhals._polars.dataframe import Method, PolarsDataFrame from narwhals._polars.namespace import PolarsNamespace from narwhals._utils import Version, _LimitedContext @@ -43,7 +45,6 @@ ModeKeepStrategy, MultiIndexSelector, NonNestedLiteral, - NumericLiteral, _1DArray, ) @@ -150,6 +151,10 @@ def __init__(self, series: pl.Series, *, version: Version) -> None: self._native_series = series self._version = version + def _to_expr(self) -> CompliantExprAny: + # Polars can treat Series as Expr, so just pass down `self.native`. + return PolarsExpr(self.native, version=self._version) # type: ignore[arg-type] + @property def _backend_version(self) -> tuple[int, ...]: return self._implementation._backend_version() @@ -496,30 +501,6 @@ def __contains__(self, other: Any) -> bool: except Exception as e: # noqa: BLE001 raise catch_polars_exception(e) from None - def is_close( - self, - other: Self | NumericLiteral, - *, - abs_tol: float, - rel_tol: float, - nans_equal: bool, - ) -> PolarsSeries: - if self._backend_version < (1, 32, 0): - name = self.name - ns = self.__narwhals_namespace__() - other_expr = ( - ns.lit(other.native, None) if isinstance(other, PolarsSeries) else other - ) - expr = ns.col(name).is_close( - other_expr, abs_tol=abs_tol, rel_tol=rel_tol, nans_equal=nans_equal - ) - return self.to_frame().select(expr).get_column(name) - other_series = other.native if isinstance(other, PolarsSeries) else other - result = self.native.is_close( - other_series, abs_tol=abs_tol, rel_tol=rel_tol, nans_equal=nans_equal - ) - return self._with_native(result) - def mode(self, *, keep: ModeKeepStrategy) -> Self: result = self.native.mode() return self._with_native(result.head(1) if keep == "any" else result) @@ -768,7 +749,17 @@ class PolarsSeriesStringNamespace( def zfill(self, width: int) -> PolarsSeries: name = self.name ns = self.__narwhals_namespace__() - return self.to_frame().select(ns.col(name).str.zfill(width)).get_column(name) + return self.to_frame().select(ns.col([name]).str.zfill(width)).get_column(name) + + def replace(self, value: str, pattern: str, *, literal: bool, n: int) -> PolarsSeries: + return self.compliant._with_native( + self.native.str.replace(pattern, extract_native(value), literal=literal, n=n) + ) + + def replace_all(self, value: str, pattern: str, *, literal: bool) -> PolarsSeries: + return self.compliant._with_native( + self.native.str.replace_all(pattern, extract_native(value), literal=literal) + ) class PolarsSeriesCatNamespace( @@ -782,12 +773,12 @@ class PolarsSeriesListNamespace( def len(self) -> PolarsSeries: name = self.name ns = self.__narwhals_namespace__() - return self.to_frame().select(ns.col(name).list.len()).get_column(name) + return self.to_frame().select(ns.col([name]).list.len()).get_column(name) def contains(self, item: NonNestedLiteral) -> PolarsSeries: name = self.name ns = self.__narwhals_namespace__() - return self.to_frame().select(ns.col(name).list.contains(item)).get_column(name) + return self.to_frame().select(ns.col([name]).list.contains(item)).get_column(name) class PolarsSeriesStructNamespace( diff --git a/narwhals/_spark_like/dataframe.py b/narwhals/_spark_like/dataframe.py index a349429c59..09457a678e 100644 --- a/narwhals/_spark_like/dataframe.py +++ b/narwhals/_spark_like/dataframe.py @@ -606,10 +606,4 @@ def _from_compliant_dataframe( validate_backend_version=True, ) - gather_every = not_implemented.deprecated( - "`LazyFrame.gather_every` is deprecated and will be removed in a future version." - ) join_asof = not_implemented() - tail = not_implemented.deprecated( - "`LazyFrame.tail` is deprecated and will be removed in a future version." - ) diff --git a/narwhals/_spark_like/expr.py b/narwhals/_spark_like/expr.py index 13dd167fe5..8a78b38b76 100644 --- a/narwhals/_spark_like/expr.py +++ b/narwhals/_spark_like/expr.py @@ -58,7 +58,7 @@ def __init__( self._alias_output_names = alias_output_names self._version = version self._implementation = implementation - self._metadata: ExprMetadata | None = None + self._opt_metadata: ExprMetadata | None = None self._window_function: SparkWindowFunction | None = window_function _REMAP_RANK_METHOD: ClassVar[Mapping[RankMethod, NativeRankMethod]] = { @@ -171,8 +171,14 @@ def from_column_names( def func(df: SparkLikeLazyFrame) -> list[Column]: return [df._F.col(col_name) for col_name in evaluate_column_names(df)] + def window_func( + df: SparkLikeLazyFrame, _window_inputs: WindowInputs[Column] + ) -> list[Column]: + return [df._F.col(col_name) for col_name in evaluate_column_names(df)] + return cls( func, + window_func, evaluate_output_names=evaluate_column_names, alias_output_names=None, version=context._version, diff --git a/narwhals/_spark_like/namespace.py b/narwhals/_spark_like/namespace.py index d70d678338..06e6fdd5cb 100644 --- a/narwhals/_spark_like/namespace.py +++ b/narwhals/_spark_like/namespace.py @@ -18,7 +18,6 @@ true_divide, ) from narwhals._sql.namespace import SQLNamespace -from narwhals._sql.when_then import SQLThen, SQLWhen from narwhals._utils import zip_strict if TYPE_CHECKING: @@ -26,6 +25,7 @@ from sqlframe.base.column import Column + from narwhals._compliant.window import WindowInputs from narwhals._spark_like.dataframe import SQLFrameDataFrame # noqa: F401 from narwhals._utils import Implementation, Version from narwhals.typing import ConcatMethod, IntoDType, NonNestedLiteral, PythonLiteral @@ -92,7 +92,7 @@ def _coalesce(self, *exprs: Column) -> Column: return self._F.coalesce(*exprs) def lit(self, value: NonNestedLiteral, dtype: IntoDType | None) -> SparkLikeExpr: - def _lit(df: SparkLikeLazyFrame) -> list[Column]: + def func(df: SparkLikeLazyFrame) -> list[Column]: column = df._F.lit(value) if dtype: native_dtype = narwhals_to_native_dtype( @@ -102,8 +102,14 @@ def _lit(df: SparkLikeLazyFrame) -> list[Column]: return [column] + def window_func( + df: SparkLikeLazyFrame, _window_inputs: WindowInputs[Column] + ) -> list[Column]: + return func(df) + return self._expr( - call=_lit, + func, + window_func, evaluate_output_names=lambda _df: ["literal"], alias_output_names=None, version=self._version, @@ -214,17 +220,3 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: version=self._version, implementation=self._implementation, ) - - def when(self, predicate: SparkLikeExpr) -> SparkLikeWhen: - return SparkLikeWhen.from_expr(predicate, context=self) - - -class SparkLikeWhen(SQLWhen[SparkLikeLazyFrame, "Column", SparkLikeExpr]): - @property - def _then(self) -> type[SparkLikeThen]: - return SparkLikeThen - - -class SparkLikeThen( - SQLThen[SparkLikeLazyFrame, "Column", SparkLikeExpr], SparkLikeExpr -): ... diff --git a/narwhals/_sql/dataframe.py b/narwhals/_sql/dataframe.py index 356a77373f..41eef9ef9a 100644 --- a/narwhals/_sql/dataframe.py +++ b/narwhals/_sql/dataframe.py @@ -10,6 +10,7 @@ ) from narwhals._translate import ToNarwhalsT_co from narwhals._utils import check_columns_exist +from narwhals.exceptions import MultiOutputExpressionError if TYPE_CHECKING: from collections.abc import Sequence @@ -34,12 +35,16 @@ def _evaluate_window_expr( window_inputs: WindowInputs[NativeExprT], ) -> NativeExprT: result = expr.window_function(self, window_inputs) - assert len(result) == 1 # debug assertion # noqa: S101 + if len(result) != 1: # pragma: no cover + msg = "multi-output expressions not allowed in this context" + raise MultiOutputExpressionError(msg) return result[0] - def _evaluate_expr(self, expr: CompliantExprT_contra, /) -> Any: + def _evaluate_expr(self, expr: SQLExpr[Self, NativeExprT], /) -> NativeExprT: result = expr(self) - assert len(result) == 1 # debug assertion # noqa: S101 + if len(result) != 1: + msg = "multi-output expressions not allowed in this context" + raise MultiOutputExpressionError(msg) return result[0] def _check_columns_exist(self, subset: Sequence[str]) -> ColumnNotFoundError | None: diff --git a/narwhals/_sql/expr.py b/narwhals/_sql/expr.py index f427b73f40..292d822cbb 100644 --- a/narwhals/_sql/expr.py +++ b/narwhals/_sql/expr.py @@ -20,7 +20,7 @@ from narwhals._utils import Implementation, Version, not_implemented if TYPE_CHECKING: - from collections.abc import Iterable, Sequence + from collections.abc import Sequence from typing_extensions import Self, TypeIs @@ -44,7 +44,7 @@ class SQLExpr(LazyExpr[SQLLazyFrameT, NativeExprT], Protocol[SQLLazyFrameT, Nati _alias_output_names: AliasNames | None _version: Version _implementation: Implementation - _metadata: ExprMetadata | None + _opt_metadata: ExprMetadata | None _window_function: WindowFunction[SQLLazyFrameT, NativeExprT] | None def __init__( @@ -72,8 +72,6 @@ def func(df: SQLLazyFrameT) -> list[NativeExprT]: native_series_list = self(df) other_native_series = { key: df._evaluate_expr(value) - if self._is_expr(value) - else self._lit(value) for key, value in expressifiable_args.items() } return [ @@ -97,8 +95,6 @@ def window_f( native_series_list = self.window_function(df, window_inputs) other_native_series = { key: df._evaluate_window_expr(value, window_inputs) - if self._is_expr(value) - else self._lit(value) for key, value in expressifiable_args.items() } return [ @@ -178,8 +174,7 @@ def default_window_func( ) -> Sequence[NativeExprT]: assert not inputs.order_by # noqa: S101 return [ - self._window_expression(expr, inputs.partition_by, inputs.order_by) - for expr in self(df) + self._window_expression(expr, inputs.partition_by) for expr in self(df) ] return self._window_function or default_window_func @@ -308,19 +303,16 @@ def _alias_native(cls, expr: NativeExprT, name: str, /) -> NativeExprT: ... @classmethod def _from_elementwise_horizontal_op( - cls, func: Callable[[Iterable[NativeExprT]], NativeExprT], *exprs: Self + cls, func: Callable[[list[NativeExprT]], NativeExprT], *exprs: Self ) -> Self: def call(df: SQLLazyFrameT) -> Sequence[NativeExprT]: - cols = (col for _expr in exprs for col in _expr(df)) - return [func(cols)] + return [func([e for expr in exprs for e in expr(df)])] def window_function( df: SQLLazyFrameT, window_inputs: WindowInputs[NativeExprT] ) -> Sequence[NativeExprT]: - cols = ( - col for _expr in exprs for col in _expr.window_function(df, window_inputs) - ) - return [func(cols)] + lst = [e for expr in exprs for e in expr.window_function(df, window_inputs)] + return [func(lst)] context = exprs[0] return cls( @@ -343,7 +335,6 @@ def _is_multi_output_unnamed(self) -> bool: nw.all().sum(). """ - assert self._metadata is not None # noqa: S101 return self._metadata.expansion_kind.is_multi_unnamed() # Binary @@ -476,6 +467,7 @@ def f(expr: NativeExprT) -> NativeExprT: def window_f( df: SQLLazyFrameT, inputs: WindowInputs[NativeExprT] ) -> Sequence[NativeExprT]: + assert not inputs.order_by # noqa: S101 return [ self._coalesce( self._window_expression( diff --git a/narwhals/_sql/expr_str.py b/narwhals/_sql/expr_str.py index 7e82ed30ec..0924137fca 100644 --- a/narwhals/_sql/expr_str.py +++ b/narwhals/_sql/expr_str.py @@ -38,21 +38,13 @@ def len_chars(self) -> SQLExprT: ) def replace_all( - self, pattern: str, value: str | SQLExprT, *, literal: bool + self, value: str | SQLExprT, pattern: str, *, literal: bool ) -> SQLExprT: fname: str = "replace" if literal else "regexp_replace" options: list[Any] = [] if not literal and self.compliant._implementation.is_duckdb(): options = [self._lit("g")] - - if isinstance(value, str): - return self.compliant._with_elementwise( - lambda expr: self._function( - fname, expr, self._lit(pattern), self._lit(value), *options - ) - ) - return self.compliant._with_elementwise( lambda expr, value: self._function( fname, expr, self._lit(pattern), value, *options diff --git a/narwhals/_sql/namespace.py b/narwhals/_sql/namespace.py index cc27757a00..e1e453c0df 100644 --- a/narwhals/_sql/namespace.py +++ b/narwhals/_sql/namespace.py @@ -2,16 +2,19 @@ import operator from functools import reduce -from typing import TYPE_CHECKING, Any, Protocol +from typing import TYPE_CHECKING, Any, Protocol, cast from narwhals._compliant import LazyNamespace from narwhals._compliant.typing import NativeExprT, NativeFrameT_co +from narwhals._expression_parsing import is_expr from narwhals._sql.typing import SQLExprT, SQLLazyFrameT +from narwhals.functions import lit if TYPE_CHECKING: from collections.abc import Iterable - from narwhals.typing import PythonLiteral + from narwhals.expr import Expr + from narwhals.typing import NonNestedLiteral, PythonLiteral class SQLNamespace( @@ -28,6 +31,13 @@ def _when( ) -> NativeExprT: ... def _coalesce(self, *exprs: NativeExprT) -> NativeExprT: ... + def evaluate_expr(self, data: Expr | NonNestedLiteral | Any, /) -> SQLExprT: + if is_expr(data): + expr = data(self) + assert isinstance(expr, self._expr) # noqa: S101 + return expr + return cast("SQLExprT", lit(data)(self)) + # Horizontal functions def any_horizontal(self, *exprs: SQLExprT, ignore_nulls: bool) -> SQLExprT: def func(cols: Iterable[NativeExprT]) -> NativeExprT: @@ -71,3 +81,18 @@ def func(cols: Iterable[NativeExprT]) -> NativeExprT: return self._coalesce(*cols) return self._expr._from_elementwise_horizontal_op(func, *exprs) + + def when_then( + self, predicate: SQLExprT, then: SQLExprT, otherwise: SQLExprT | None = None + ) -> SQLExprT: + def func(cols: list[NativeExprT]) -> NativeExprT: + return self._when(cols[1], cols[0]) + + def func_with_otherwise(cols: list[NativeExprT]) -> NativeExprT: + return self._when(cols[1], cols[0], cols[2]) + + if otherwise is None: + return self._expr._from_elementwise_horizontal_op(func, then, predicate) + return self._expr._from_elementwise_horizontal_op( + func_with_otherwise, then, predicate, otherwise + ) diff --git a/narwhals/_sql/when_then.py b/narwhals/_sql/when_then.py deleted file mode 100644 index 11c5bf5e20..0000000000 --- a/narwhals/_sql/when_then.py +++ /dev/null @@ -1,106 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Protocol - -from narwhals._compliant.typing import NativeExprT -from narwhals._compliant.when_then import CompliantThen, CompliantWhen -from narwhals._sql.typing import SQLExprT, SQLLazyFrameT - -if TYPE_CHECKING: - from collections.abc import Sequence - - from typing_extensions import Self - - from narwhals._compliant.typing import WindowFunction - from narwhals._compliant.when_then import IntoExpr - from narwhals._compliant.window import WindowInputs - from narwhals._utils import _LimitedContext - - -class SQLWhen( - CompliantWhen[SQLLazyFrameT, NativeExprT, SQLExprT], - Protocol[SQLLazyFrameT, NativeExprT, SQLExprT], -): - @property - def _then(self) -> type[SQLThen[SQLLazyFrameT, NativeExprT, SQLExprT]]: ... - - def __call__(self, df: SQLLazyFrameT) -> Sequence[NativeExprT]: - is_expr = self._condition._is_expr - when = df.__narwhals_namespace__()._when - lit = df.__narwhals_namespace__()._lit - condition = df._evaluate_expr(self._condition) - then_ = self._then_value - then = df._evaluate_expr(then_) if is_expr(then_) else lit(then_) - other_ = self._otherwise_value - if other_ is None: - result = when(condition, then) - else: - otherwise = df._evaluate_expr(other_) if is_expr(other_) else lit(other_) - result = when(condition, then).otherwise(otherwise) - return [result] - - @classmethod - def from_expr(cls, condition: SQLExprT, /, *, context: _LimitedContext) -> Self: - obj = cls.__new__(cls) - obj._condition = condition - obj._then_value = None - obj._otherwise_value = None - obj._implementation = context._implementation - obj._version = context._version - return obj - - def _window_function( - self, df: SQLLazyFrameT, window_inputs: WindowInputs[NativeExprT] - ) -> Sequence[NativeExprT]: - when = df.__narwhals_namespace__()._when - lit = df.__narwhals_namespace__()._lit - is_expr = self._condition._is_expr - condition = self._condition.window_function(df, window_inputs)[0] - then_ = self._then_value - then = ( - then_.window_function(df, window_inputs)[0] if is_expr(then_) else lit(then_) - ) - - other_ = self._otherwise_value - if other_ is None: - result = when(condition, then) - else: - other = ( - other_.window_function(df, window_inputs)[0] - if is_expr(other_) - else lit(other_) - ) - result = when(condition, then).otherwise(other) - return [result] - - -class SQLThen( - CompliantThen[ - SQLLazyFrameT, - NativeExprT, - SQLExprT, - SQLWhen[SQLLazyFrameT, NativeExprT, SQLExprT], - ], - Protocol[SQLLazyFrameT, NativeExprT, SQLExprT], -): - _window_function: WindowFunction[SQLLazyFrameT, NativeExprT] | None - - @classmethod - def from_when( - cls, - when: SQLWhen[SQLLazyFrameT, NativeExprT, SQLExprT], - then: IntoExpr[NativeExprT, SQLExprT], - /, - ) -> Self: - when._then_value = then - obj = cls.__new__(cls) - obj._call = when - obj._window_function = when._window_function - obj._when_value = when - obj._evaluate_output_names = getattr( - then, "_evaluate_output_names", lambda _df: ["literal"] - ) - obj._alias_output_names = getattr(then, "_alias_output_names", None) - obj._implementation = when._implementation - obj._version = when._version - return obj diff --git a/narwhals/dataframe.py b/narwhals/dataframe.py index 7a3e4ce908..d46a072cbc 100644 --- a/narwhals/dataframe.py +++ b/narwhals/dataframe.py @@ -1,6 +1,7 @@ from __future__ import annotations from abc import abstractmethod +from functools import partial from itertools import chain from typing import ( TYPE_CHECKING, @@ -18,8 +19,8 @@ from narwhals._exceptions import issue_warning from narwhals._expression_parsing import ( ExprKind, + _parse_into_expr, check_expressions_preserve_length, - is_into_expr_eager, is_scalar_like, ) from narwhals._typing import Arrow, Pandas, _LazyAllowedImpl, _LazyFrameCollectImpl @@ -46,7 +47,6 @@ from narwhals.dependencies import is_numpy_array_2d, is_pyarrow_table from narwhals.exceptions import ( ColumnNotFoundError, - InvalidIntoExprError, InvalidOperationError, PerformanceWarning, ) @@ -67,7 +67,8 @@ from typing_extensions import Concatenate, ParamSpec, Self, TypeAlias from narwhals._compliant import CompliantDataFrame, CompliantLazyFrame - from narwhals._compliant.typing import CompliantExprAny, EagerNamespaceAny + from narwhals._compliant.typing import CompliantExprAny + from narwhals._expression_parsing import ExprMetadata from narwhals._translate import IntoArrowTable from narwhals._typing import EagerAllowed, IntoBackend, LazyAllowed, Polars from narwhals.group_by import GroupBy, LazyGroupBy @@ -147,20 +148,21 @@ def _flatten_and_extract( # NOTE: Strings are interpreted as column names. out_exprs = [] out_kinds = [] - for expr in flatten(exprs): - compliant_expr = self._extract_compliant(expr) - out_exprs.append(compliant_expr) - out_kinds.append(ExprKind.from_into_expr(expr, str_as_lit=False)) - for alias, expr in named_exprs.items(): - compliant_expr = self._extract_compliant(expr).alias(alias) - out_exprs.append(compliant_expr) - out_kinds.append(ExprKind.from_into_expr(expr, str_as_lit=False)) + ns = self.__narwhals_namespace__() + parse = partial( + _parse_into_expr, backend=self._compliant._implementation, allow_literal=False + ) + all_exprs = chain( + (parse(x) for x in flatten(exprs)), + (parse(expr).alias(alias) for alias, expr in named_exprs.items()), + ) + for expr in all_exprs: + ce = expr(ns) + out_exprs.append(ce) + out_kinds.append(ExprKind.from_expr(ce)) + self._validate_metadata(ce._metadata) return out_exprs, out_kinds - @abstractmethod - def _extract_compliant(self, arg: Any) -> Any: - raise NotImplementedError - def _extract_compliant_frame(self, other: Self | Any, /) -> Any: if isinstance(other, type(self)): return other._compliant_frame @@ -170,6 +172,10 @@ def _extract_compliant_frame(self, other: Self | Any, /) -> Any: def _check_columns_exist(self, subset: Sequence[str]) -> ColumnNotFoundError | None: return check_columns_exist(subset, available=self.columns) + @abstractmethod + def _validate_metadata(self, metadata: ExprMetadata) -> None: + pass + @property def schema(self) -> Schema: return Schema(self._compliant_frame.schema.items()) @@ -221,7 +227,7 @@ def select( raise error from e raise compliant_exprs, kinds = self._flatten_and_extract(*flat_exprs, **named_exprs) - if compliant_exprs and all(is_scalar_like(kind) for kind in kinds): + if compliant_exprs and all(x.is_scalar_like for x in kinds): return self._with_compliant(self._compliant_frame.aggregate(*compliant_exprs)) compliant_exprs = [ compliant_expr.broadcast(kind) if is_scalar_like(kind) else compliant_expr @@ -250,12 +256,13 @@ def filter( from narwhals.functions import col flat_predicates = flatten(predicates) - check_expressions_preserve_length(*flat_predicates, function_name="filter") plx = self.__narwhals_namespace__() compliant_predicates, _kinds = self._flatten_and_extract(*flat_predicates) - compliant_constraints = ( - (col(name) == v)._to_compliant_expr(plx) - for name, v in constraints.items() + check_expressions_preserve_length( + *compliant_predicates, function_name="filter" + ) + compliant_constraints, _ = self._flatten_and_extract( + *[col(name) == v for name, v in constraints.items()] ) predicate = plx.all_horizontal( *chain(compliant_predicates, compliant_constraints), ignore_nulls=False @@ -475,12 +482,6 @@ class DataFrame(BaseFrame[DataFrameT]): def _compliant(self) -> CompliantDataFrame[Any, Any, DataFrameT, Self]: return self._compliant_frame - def _extract_compliant(self, arg: Any) -> Any: - if is_into_expr_eager(arg): - plx: EagerNamespaceAny = self.__narwhals_namespace__() - return plx.parse_into_expr(arg, str_as_lit=False) - raise InvalidIntoExprError.from_invalid_type(type(arg)) - @property def _series(self) -> type[Series[Any]]: return Series @@ -489,6 +490,10 @@ def _series(self) -> type[Series[Any]]: def _lazyframe(self) -> type[LazyFrame[Any]]: return LazyFrame + def _validate_metadata(self, metadata: ExprMetadata) -> None: + # all is valid in eager case. + pass + def __init__(self, df: Any, *, level: Literal["full", "lazy", "interchange"]) -> None: self._level: Literal["full", "lazy", "interchange"] = level self._compliant_frame: CompliantDataFrame[Any, Any, DataFrameT, Self] @@ -2290,45 +2295,35 @@ class LazyFrame(BaseFrame[LazyFrameT]): def _compliant(self) -> CompliantLazyFrame[Any, LazyFrameT, Self]: return self._compliant_frame - def _extract_compliant(self, arg: Any) -> Any: - from narwhals.expr import Expr - from narwhals.series import Series - - if isinstance(arg, Series): # pragma: no cover - msg = "Binary operations between Series and LazyFrame are not supported." - raise TypeError(msg) - if isinstance(arg, (Expr, str)): - if isinstance(arg, Expr): - if arg._metadata.n_orderable_ops: - msg = ( - "Order-dependent expressions are not supported for use in LazyFrame.\n\n" - "Hint: To make the expression valid, use `.over` with `order_by` specified.\n\n" - "For example, if you wrote `nw.col('price').cum_sum()` and you have a column\n" - "`'date'` which orders your data, then replace:\n\n" - " nw.col('price').cum_sum()\n\n" - " with:\n\n" - " nw.col('price').cum_sum().over(order_by='date')\n" - " ^^^^^^^^^^^^^^^^^^^^^^\n\n" - "See https://narwhals-dev.github.io/narwhals/concepts/order_dependence/." - ) - raise InvalidOperationError(msg) - if arg._metadata.is_filtration: - msg = ( - "Length-changing expressions are not supported for use in LazyFrame, unless\n" - "followed by an aggregation.\n\n" - "Hints:\n" - "- Instead of `lf.select(nw.col('a').head())`, use `lf.select('a').head()\n" - "- Instead of `lf.select(nw.col('a').drop_nulls()).select(nw.sum('a'))`,\n" - " use `lf.select(nw.col('a').drop_nulls().sum())\n" - ) - raise InvalidOperationError(msg) - return self.__narwhals_namespace__().parse_into_expr(arg, str_as_lit=False) - raise InvalidIntoExprError.from_invalid_type(type(arg)) - @property def _dataframe(self) -> type[DataFrame[Any]]: return DataFrame + def _validate_metadata(self, metadata: ExprMetadata) -> None: + if metadata.n_orderable_ops > 0: + msg = ( + "Order-dependent expressions are not supported for use in LazyFrame.\n\n" + "Hint: To make the expression valid, use `.over` with `order_by` specified.\n\n" + "For example, if you wrote `nw.col('price').cum_sum()` and you have a column\n" + "`'date'` which orders your data, then replace:\n\n" + " nw.col('price').cum_sum()\n\n" + " with:\n\n" + " nw.col('price').cum_sum().over(order_by='date')\n" + " ^^^^^^^^^^^^^^^^^^^^^^\n\n" + "See https://narwhals-dev.github.io/narwhals/concepts/order_dependence/." + ) + raise InvalidOperationError(msg) + if metadata.is_filtration: + msg = ( + "Length-changing expressions are not supported for use in LazyFrame, unless\n" + "followed by an aggregation.\n\n" + "Hints:\n" + "- Instead of `lf.select(nw.col('a').head())`, use `lf.select('a').head()\n" + "- Instead of `lf.select(nw.col('a').drop_nulls()).select(nw.sum('a'))`,\n" + " use `lf.select(nw.col('a').drop_nulls().sum())\n" + ) + raise InvalidOperationError(msg) + def __init__(self, df: Any, *, level: Literal["full", "lazy", "interchange"]) -> None: self._level = level self._compliant_frame: CompliantLazyFrame[Any, LazyFrameT, Self] @@ -2962,7 +2957,6 @@ def group_by( k if is_expr else col(k) for k, is_expr in zip_strict(flat_keys, key_is_expr) ] expr_flat_keys, kinds = self._flatten_and_extract(*_keys) - if not all(kind is ExprKind.ELEMENTWISE for kind in kinds): from narwhals.exceptions import ComputeError diff --git a/narwhals/expr.py b/narwhals/expr.py index cca4d523d9..1c8b8f936c 100644 --- a/narwhals/expr.py +++ b/narwhals/expr.py @@ -1,18 +1,18 @@ from __future__ import annotations import math -import operator as op from collections.abc import Iterable, Mapping, Sequence from typing import TYPE_CHECKING, Any, Callable from narwhals._expression_parsing import ( - ExprMetadata, - apply_n_ary_operation, - combine_metadata, + ExprKind, + ExprNode, + evaluate_node, + evaluate_root_node, ) from narwhals._utils import _validate_rolling_arguments, ensure_type, flatten from narwhals.dtypes import _validate_dtype -from narwhals.exceptions import ComputeError, InvalidOperationError +from narwhals.exceptions import ComputeError from narwhals.expr_cat import ExprCatNamespace from narwhals.expr_dt import ExprDateTimeNamespace from narwhals.expr_list import ExprListNamespace @@ -24,7 +24,7 @@ if TYPE_CHECKING: from typing import NoReturn, TypeVar - from typing_extensions import Concatenate, ParamSpec, Self, TypeAlias + from typing_extensions import Concatenate, ParamSpec, Self from narwhals._compliant import CompliantExpr, CompliantNamespace from narwhals.dtypes import DType @@ -39,74 +39,58 @@ RankMethod, RollingInterpolationMethod, TemporalLiteral, - _1DArray, ) PS = ParamSpec("PS") R = TypeVar("R") - _ToCompliant: TypeAlias = Callable[ - [CompliantNamespace[Any, Any]], CompliantExpr[Any, Any] - ] class Expr: - def __init__(self, to_compliant_expr: _ToCompliant, metadata: ExprMetadata) -> None: - # callable from CompliantNamespace to CompliantExpr - def func(plx: CompliantNamespace[Any, Any]) -> CompliantExpr[Any, Any]: - result = to_compliant_expr(plx) - result._metadata = self._metadata - return result - - self._to_compliant_expr: _ToCompliant = func - self._metadata = metadata - - def _with_elementwise(self, to_compliant_expr: Callable[[Any], Any]) -> Self: - return self.__class__(to_compliant_expr, self._metadata.with_elementwise_op()) - - def _with_aggregation(self, to_compliant_expr: Callable[[Any], Any]) -> Self: - return self.__class__(to_compliant_expr, self._metadata.with_aggregation()) - - def _with_orderable_aggregation( - self, to_compliant_expr: Callable[[Any], Any] - ) -> Self: - return self.__class__( - to_compliant_expr, self._metadata.with_orderable_aggregation() - ) - - def _with_orderable_window(self, to_compliant_expr: Callable[[Any], Any]) -> Self: - return self.__class__(to_compliant_expr, self._metadata.with_orderable_window()) - - def _with_window(self, to_compliant_expr: Callable[[Any], Any]) -> Self: - return self.__class__(to_compliant_expr, self._metadata.with_window()) - - def _with_filtration(self, to_compliant_expr: Callable[[Any], Any]) -> Self: - return self.__class__(to_compliant_expr, self._metadata.with_filtration()) - - def _with_orderable_filtration(self, to_compliant_expr: Callable[[Any], Any]) -> Self: - return self.__class__( - to_compliant_expr, self._metadata.with_orderable_filtration() - ) - - def _with_nary( - self, - n_ary_function: Callable[..., Any], - *args: IntoExpr | NonNestedLiteral | _1DArray, - ) -> Self: - return self.__class__( - lambda plx: apply_n_ary_operation( - plx, n_ary_function, self, *args, str_as_lit=False - ), - combine_metadata( - self, - *args, - str_as_lit=False, - allow_multi_output=False, - to_single_output=False, - ), - ) + def __init__(self, *nodes: ExprNode) -> None: + self._nodes = nodes + + def __call__(self, ns: CompliantNamespace[Any, Any]) -> CompliantExpr[Any, Any]: + nodes = self._nodes + ce = evaluate_root_node(nodes[0], ns) + for node in nodes[1:]: + ce = evaluate_node(ce, node, ns) + return ce + + def _with_node(self, node: ExprNode) -> Self: + if node.kind is ExprKind.OVER: + # insert `over` before any elementwise operations. + # check "how it works" page in docs for why we do this. + new_nodes = list(self._nodes) + kwargs_no_order_by = { + key: value if key != "order_by" else [] + for (key, value) in node.kwargs.items() + } + node_without_order_by = node._with_kwargs(**kwargs_no_order_by) + n = len(new_nodes) + i = n + while i > 0 and (_node := new_nodes[i - 1]).kind is ExprKind.ELEMENTWISE: + i -= 1 + _node._push_down_over_node_in_place(node, node_without_order_by) + if i == n: + # node could not be pushed down, just append as-is + new_nodes.append(node) + return self.__class__(*new_nodes) + if i > 0: + if node.kwargs["order_by"] and any( + node.is_orderable() for node in new_nodes[:i] + ): + new_nodes.insert(i, node) + elif node.kwargs["partition_by"]: + new_nodes.insert(i, node_without_order_by) + return self.__class__(*new_nodes) + return self.__class__(*self._nodes, node) def __repr__(self) -> str: - return f"Narwhals Expr\nmetadata: {self._metadata}\n" + """Pretty-print the expression by combining all nodes in the metadata.""" + result: str = repr(self._nodes[0]) + for node in self._nodes[1:]: + result = f"{result}.{node!r}" + return result def __bool__(self) -> NoReturn: msg = ( @@ -124,9 +108,7 @@ def __bool__(self) -> NoReturn: def _taxicab_norm(self) -> Self: # This is just used to test out the stable api feature in a realistic-ish way. # It's not intended to be used. - return self._with_aggregation( - lambda plx: self._to_compliant_expr(plx).abs().sum() - ) + return self.abs().sum() # --- convert --- def alias(self, name: str) -> Self: @@ -138,7 +120,7 @@ def alias(self, name: str) -> Self: Examples: >>> import pandas as pd >>> import narwhals as nw - >>> df_native = pd.DataFrame({"a": [1, 2], "b": [4, 5]}) + >>> df_native = pd.DataFrame({"a": [], "b": [4, 5]}) >>> df = nw.from_native(df_native) >>> df.select((nw.col("b") + 10).alias("c")) ┌──────────────────┐ @@ -149,10 +131,7 @@ def alias(self, name: str) -> Self: | 1 15 | └──────────────────┘ """ - # Don't use `_with_elementwise` so that `_metadata.last_node` is preserved. - return self.__class__( - lambda plx: self._to_compliant_expr(plx).alias(name), self._metadata - ) + return self._with_node(ExprNode(ExprKind.ELEMENTWISE, "alias", name=name)) def pipe( self, @@ -207,102 +186,88 @@ def cast(self, dtype: IntoDType) -> Self: └──────────────────┘ """ _validate_dtype(dtype) - return self._with_elementwise( - lambda plx: self._to_compliant_expr(plx).cast(dtype) - ) + return self._with_node(ExprNode(ExprKind.ELEMENTWISE, "cast", dtype=dtype)) # --- binary --- - def _with_binary( - self, - function: Callable[[Any, Any], Any], - other: Self | Any, - *, - str_as_lit: bool = True, - ) -> Self: - return self.__class__( - lambda plx: apply_n_ary_operation( - plx, function, self, other, str_as_lit=str_as_lit - ), - ExprMetadata.from_binary_op(self, other), - ) + def _with_binary(self, attr: str, other: Self | Any) -> Self: + node = ExprNode(ExprKind.ELEMENTWISE, attr, other, str_as_lit=True) + return self._with_node(node) def __eq__(self, other: Self | Any) -> Self: # type: ignore[override] - return self._with_binary(op.eq, other) + return self._with_binary("__eq__", other) def __ne__(self, other: Self | Any) -> Self: # type: ignore[override] - return self._with_binary(op.ne, other) + return self._with_binary("__ne__", other) def __and__(self, other: Any) -> Self: - return self._with_binary(op.and_, other) + return self._with_binary("__and__", other) def __rand__(self, other: Any) -> Self: return (self & other).alias("literal") # type: ignore[no-any-return] def __or__(self, other: Any) -> Self: - return self._with_binary(op.or_, other) + return self._with_binary("__or__", other) def __ror__(self, other: Any) -> Self: return (self | other).alias("literal") # type: ignore[no-any-return] def __add__(self, other: Any) -> Self: - return self._with_binary(op.add, other) + return self._with_binary("__add__", other) def __radd__(self, other: Any) -> Self: return (self + other).alias("literal") # type: ignore[no-any-return] def __sub__(self, other: Any) -> Self: - return self._with_binary(op.sub, other) + return self._with_binary("__sub__", other) def __rsub__(self, other: Any) -> Self: - return self._with_binary(lambda x, y: x.__rsub__(y), other) + return self._with_binary("__rsub__", other) def __truediv__(self, other: Any) -> Self: - return self._with_binary(op.truediv, other) + return self._with_binary("__truediv__", other) def __rtruediv__(self, other: Any) -> Self: - return self._with_binary(lambda x, y: x.__rtruediv__(y), other) + return self._with_binary("__rtruediv__", other) def __mul__(self, other: Any) -> Self: - return self._with_binary(op.mul, other) + return self._with_binary("__mul__", other) def __rmul__(self, other: Any) -> Self: return (self * other).alias("literal") # type: ignore[no-any-return] def __le__(self, other: Any) -> Self: - return self._with_binary(op.le, other) + return self._with_binary("__le__", other) def __lt__(self, other: Any) -> Self: - return self._with_binary(op.lt, other) + return self._with_binary("__lt__", other) def __gt__(self, other: Any) -> Self: - return self._with_binary(op.gt, other) + return self._with_binary("__gt__", other) def __ge__(self, other: Any) -> Self: - return self._with_binary(op.ge, other) + return self._with_binary("__ge__", other) def __pow__(self, other: Any) -> Self: - return self._with_binary(op.pow, other) + return self._with_binary("__pow__", other) def __rpow__(self, other: Any) -> Self: - return self._with_binary(lambda x, y: x.__rpow__(y), other) + return self._with_binary("__rpow__", other) def __floordiv__(self, other: Any) -> Self: - return self._with_binary(op.floordiv, other) + return self._with_binary("__floordiv__", other) def __rfloordiv__(self, other: Any) -> Self: - return self._with_binary(lambda x, y: x.__rfloordiv__(y), other) + return self._with_binary("__rfloordiv__", other) def __mod__(self, other: Any) -> Self: - return self._with_binary(op.mod, other) + return self._with_binary("__mod__", other) def __rmod__(self, other: Any) -> Self: - return self._with_binary(lambda x, y: x.__rmod__(y), other) + return self._with_binary("__rmod__", other) # --- unary --- def __invert__(self) -> Self: - return self._with_elementwise( - lambda plx: self._to_compliant_expr(plx).__invert__() - ) + return self._with_node(ExprNode(ExprKind.ELEMENTWISE, "__invert__")) def any(self) -> Self: """Return whether any of the values in the column are `True`. @@ -322,7 +287,7 @@ def any(self) -> Self: | 0 True True | └──────────────────┘ """ - return self._with_aggregation(lambda plx: self._to_compliant_expr(plx).any()) + return self._with_node(ExprNode(ExprKind.AGGREGATION, "any")) def all(self) -> Self: """Return whether all values in the column are `True`. @@ -342,7 +307,7 @@ def all(self) -> Self: | 0 False True | └──────────────────┘ """ - return self._with_aggregation(lambda plx: self._to_compliant_expr(plx).all()) + return self._with_node(ExprNode(ExprKind.AGGREGATION, "all")) def ewm_mean( self, @@ -427,8 +392,10 @@ def ewm_mean( │ 2.428571 │ └──────────┘ """ - return self._with_orderable_window( - lambda plx: self._to_compliant_expr(plx).ewm_mean( + return self._with_node( + ExprNode( + ExprKind.ORDERABLE_WINDOW, + "ewm_mean", com=com, span=span, half_life=half_life, @@ -455,7 +422,7 @@ def mean(self) -> Self: | 0 0.0 4.0 | └──────────────────┘ """ - return self._with_aggregation(lambda plx: self._to_compliant_expr(plx).mean()) + return self._with_node(ExprNode(ExprKind.AGGREGATION, "mean")) def median(self) -> Self: """Get median value. @@ -476,7 +443,7 @@ def median(self) -> Self: | 0 3.0 4.0 | └──────────────────┘ """ - return self._with_aggregation(lambda plx: self._to_compliant_expr(plx).median()) + return self._with_node(ExprNode(ExprKind.AGGREGATION, "median")) def std(self, *, ddof: int = 1) -> Self: """Get standard deviation. @@ -498,9 +465,7 @@ def std(self, *, ddof: int = 1) -> Self: |0 17.79513 1.265789| └─────────────────────┘ """ - return self._with_aggregation( - lambda plx: self._to_compliant_expr(plx).std(ddof=ddof) - ) + return self._with_node(ExprNode(ExprKind.AGGREGATION, "std", ddof=ddof)) def var(self, *, ddof: int = 1) -> Self: """Get variance. @@ -522,9 +487,7 @@ def var(self, *, ddof: int = 1) -> Self: |0 316.666667 1.602222| └───────────────────────┘ """ - return self._with_aggregation( - lambda plx: self._to_compliant_expr(plx).var(ddof=ddof) - ) + return self._with_node(ExprNode(ExprKind.AGGREGATION, "var", ddof=ddof)) def map_batches( self, @@ -568,18 +531,20 @@ def map_batches( |2 3 6 4.0 7.0| └───────────────────────────┘ """ - - def compliant_expr(plx: Any) -> Any: - return self._to_compliant_expr(plx).map_batches( + kind = ( + ExprKind.ORDERABLE_AGGREGATION + if returns_scalar + else ExprKind.ORDERABLE_FILTRATION + ) + return self._with_node( + ExprNode( + kind, + "map_batches", function=function, return_dtype=return_dtype, returns_scalar=returns_scalar, ) - - if returns_scalar: - return self._with_orderable_aggregation(compliant_expr) - # safest assumptions - return self._with_orderable_filtration(compliant_expr) + ) def skew(self) -> Self: """Calculate the sample skewness of a column. @@ -597,7 +562,7 @@ def skew(self) -> Self: | 0 0.0 1.472427 | └──────────────────┘ """ - return self._with_aggregation(lambda plx: self._to_compliant_expr(plx).skew()) + return self._with_node(ExprNode(ExprKind.AGGREGATION, "skew")) def kurtosis(self) -> Self: """Compute the kurtosis (Fisher's definition) without bias correction. @@ -618,9 +583,9 @@ def kurtosis(self) -> Self: | 0 -1.3 0.210657 | └──────────────────┘ """ - return self._with_aggregation(lambda plx: self._to_compliant_expr(plx).kurtosis()) + return self._with_node(ExprNode(ExprKind.AGGREGATION, "kurtosis")) - def sum(self) -> Expr: + def sum(self) -> Self: """Return the sum value. If there are no non-null elements, the result is zero. @@ -642,7 +607,7 @@ def sum(self) -> Expr: |└────────┴────────┘| └───────────────────┘ """ - return self._with_aggregation(lambda plx: self._to_compliant_expr(plx).sum()) + return self._with_node(ExprNode(ExprKind.AGGREGATION, "sum")) def min(self) -> Self: """Returns the minimum value(s) from a column(s). @@ -660,7 +625,7 @@ def min(self) -> Self: | 0 1 3 | └──────────────────┘ """ - return self._with_aggregation(lambda plx: self._to_compliant_expr(plx).min()) + return self._with_node(ExprNode(ExprKind.AGGREGATION, "min")) def max(self) -> Self: """Returns the maximum value(s) from a column(s). @@ -678,7 +643,7 @@ def max(self) -> Self: | 0 20 100 | └──────────────────┘ """ - return self._with_aggregation(lambda plx: self._to_compliant_expr(plx).max()) + return self._with_node(ExprNode(ExprKind.AGGREGATION, "max")) def count(self) -> Self: """Returns the number of non-null elements in the column. @@ -696,7 +661,7 @@ def count(self) -> Self: | 0 3 2 | └──────────────────┘ """ - return self._with_aggregation(lambda plx: self._to_compliant_expr(plx).count()) + return self._with_node(ExprNode(ExprKind.AGGREGATION, "count")) def n_unique(self) -> Self: """Returns count of unique values. @@ -714,7 +679,7 @@ def n_unique(self) -> Self: | 0 5 3 | └──────────────────┘ """ - return self._with_aggregation(lambda plx: self._to_compliant_expr(plx).n_unique()) + return self._with_node(ExprNode(ExprKind.AGGREGATION, "n_unique")) def unique(self) -> Self: """Return unique values of this expression. @@ -732,7 +697,7 @@ def unique(self) -> Self: | 0 9 12 | └──────────────────┘ """ - return self._with_filtration(lambda plx: self._to_compliant_expr(plx).unique()) + return self._with_node(ExprNode(ExprKind.FILTRATION, "unique")) def abs(self) -> Self: """Return absolute value of each element. @@ -751,7 +716,7 @@ def abs(self) -> Self: |1 -2 4 2 4| └─────────────────────┘ """ - return self._with_elementwise(lambda plx: self._to_compliant_expr(plx).abs()) + return self._with_node(ExprNode(ExprKind.ELEMENTWISE, "abs")) def cum_sum(self, *, reverse: bool = False) -> Self: """Return cumulative sum. @@ -780,8 +745,8 @@ def cum_sum(self, *, reverse: bool = False) -> Self: |4 5 6 15| └──────────────────┘ """ - return self._with_orderable_window( - lambda plx: self._to_compliant_expr(plx).cum_sum(reverse=reverse) + return self._with_node( + ExprNode(ExprKind.ORDERABLE_WINDOW, "cum_sum", reverse=reverse) ) def diff(self) -> Self: @@ -823,9 +788,7 @@ def diff(self) -> Self: | └─────┴────────┘ | └──────────────────┘ """ - return self._with_orderable_window( - lambda plx: self._to_compliant_expr(plx).diff() - ) + return self._with_node(ExprNode(ExprKind.ORDERABLE_WINDOW, "diff")) def shift(self, n: int) -> Self: """Shift values by `n` positions. @@ -870,10 +833,7 @@ def shift(self, n: int) -> Self: └──────────────────┘ """ ensure_type(n, int, param_name="n") - - return self._with_orderable_window( - lambda plx: self._to_compliant_expr(plx).shift(n) - ) + return self._with_node(ExprNode(ExprKind.ORDERABLE_WINDOW, "shift", n=n)) def replace_strict( self, @@ -925,9 +885,13 @@ def replace_strict( new = list(old.values()) old = list(old.keys()) - return self._with_elementwise( - lambda plx: self._to_compliant_expr(plx).replace_strict( - old, new, return_dtype=return_dtype + return self._with_node( + ExprNode( + ExprKind.ELEMENTWISE, + "replace_strict", + old=old, + new=new, + return_dtype=return_dtype, ) ) @@ -962,11 +926,10 @@ def is_between( | 4 5 False | └──────────────────┘ """ - return self._with_nary( - lambda expr, lb, ub: expr.is_between(lb, ub, closed=closed), - lower_bound, - upper_bound, + node = ExprNode( + ExprKind.ELEMENTWISE, "is_between", lower_bound, upper_bound, closed=closed ) + return self._with_node(node) def is_in(self, other: Any) -> Self: """Check if elements of this expression are present in the other iterable. @@ -991,9 +954,11 @@ def is_in(self, other: Any) -> Self: └──────────────────┘ """ if isinstance(other, Iterable) and not isinstance(other, (str, bytes)): - return self._with_elementwise( - lambda plx: self._to_compliant_expr(plx).is_in( - to_native(other, pass_through=True) + return self._with_node( + ExprNode( + ExprKind.ELEMENTWISE, + "is_in", + other=to_native(other, pass_through=True), ) ) msg = "Narwhals `is_in` doesn't accept expressions as an argument, as opposed to Polars. You should provide an iterable instead." @@ -1025,24 +990,7 @@ def filter(self, *predicates: Any) -> Self: | 5 7 12 | └──────────────────┘ """ - flat_predicates = flatten(predicates) - metadata = combine_metadata( - self, - *flat_predicates, - str_as_lit=False, - allow_multi_output=True, - to_single_output=False, - ).with_filtration() - return self.__class__( - lambda plx: apply_n_ary_operation( - plx, - lambda *exprs: exprs[0].filter(*exprs[1:]), - self, - *flat_predicates, - str_as_lit=False, - ), - metadata, - ) + return self._with_node(ExprNode(ExprKind.FILTRATION, "filter", *predicates)) def is_null(self) -> Self: """Returns a boolean Series indicating which values are null. @@ -1073,7 +1021,7 @@ def is_null(self) -> Self: |└───────┴────────┴───────────┴───────────┘| └──────────────────────────────────────────┘ """ - return self._with_elementwise(lambda plx: self._to_compliant_expr(plx).is_null()) + return self._with_node(ExprNode(ExprKind.ELEMENTWISE, "is_null")) def is_nan(self) -> Self: """Indicate which values are NaN. @@ -1104,7 +1052,7 @@ def is_nan(self) -> Self: |└───────┴────────┴──────────┴──────────┘| └────────────────────────────────────────┘ """ - return self._with_elementwise(lambda plx: self._to_compliant_expr(plx).is_nan()) + return self._with_node(ExprNode(ExprKind.ELEMENTWISE, "is_nan")) def fill_null( self, @@ -1191,20 +1139,24 @@ def fill_null( msg = f"strategy not supported: {strategy}" raise ValueError(msg) - return self.__class__( - lambda plx: apply_n_ary_operation( - plx, - lambda *exprs: exprs[0].fill_null( - exprs[1], strategy=strategy, limit=limit - ), - self, + if strategy is not None: + node = ExprNode( + ExprKind.ORDERABLE_WINDOW, + "fill_null", + value=value, + strategy=strategy, + limit=limit, + ) + else: + node = ExprNode( + ExprKind.ELEMENTWISE, + "fill_null", value, + strategy=strategy, + limit=limit, str_as_lit=True, - ), - self._metadata.with_orderable_window() - if strategy is not None - else self._metadata, - ) + ) + return self._with_node(node) def fill_nan(self, value: float | None) -> Self: """Fill floating point NaN values with given value. @@ -1237,9 +1189,7 @@ def fill_nan(self, value: float | None) -> Self: |└────────┴────────┴───────────────┴───────────────┘| └───────────────────────────────────────────────────┘ """ - return self._with_elementwise( - lambda plx: self._to_compliant_expr(plx).fill_nan(value) - ) + return self._with_node(ExprNode(ExprKind.ELEMENTWISE, "fill_nan", value=value)) # --- partial reduction --- def drop_nulls(self) -> Self: @@ -1272,9 +1222,7 @@ def drop_nulls(self) -> Self: | └─────┘ | └──────────────────┘ """ - return self._with_filtration( - lambda plx: self._to_compliant_expr(plx).drop_nulls() - ) + return self._with_node(ExprNode(ExprKind.FILTRATION, "drop_nulls")) def over( self, @@ -1324,22 +1272,10 @@ def over( if not flat_partition_by and not flat_order_by: # pragma: no cover msg = "At least one of `partition_by` or `order_by` must be specified." raise ValueError(msg) - - current_meta = self._metadata - if flat_order_by: - next_meta = current_meta.with_ordered_over() - elif not flat_partition_by: # pragma: no cover - msg = "At least one of `partition_by` or `order_by` must be specified." - raise InvalidOperationError(msg) - else: - next_meta = current_meta.with_partitioned_over() - - return self.__class__( - lambda plx: self._to_compliant_expr(plx).over( - flat_partition_by, flat_order_by - ), - next_meta, + node = ExprNode( + ExprKind.OVER, "over", partition_by=flat_partition_by, order_by=flat_order_by ) + return self._with_node(node) def is_duplicated(self) -> Self: r"""Return a boolean mask indicating duplicated values. @@ -1360,7 +1296,7 @@ def is_duplicated(self) -> Self: |3 1 c True False| └─────────────────────────────────────────┘ """ - return self._with_window(lambda plx: self._to_compliant_expr(plx).is_duplicated()) + return self._with_node(ExprNode(ExprKind.WINDOW, "is_duplicated")) def is_unique(self) -> Self: r"""Return a boolean mask indicating unique values. @@ -1381,7 +1317,7 @@ def is_unique(self) -> Self: |3 1 c False True| └─────────────────────────────────┘ """ - return self._with_window(lambda plx: self._to_compliant_expr(plx).is_unique()) + return self._with_node(ExprNode(ExprKind.WINDOW, "is_unique")) def null_count(self) -> Self: r"""Count null values. @@ -1405,9 +1341,7 @@ def null_count(self) -> Self: | 0 1 2 | └──────────────────┘ """ - return self._with_aggregation( - lambda plx: self._to_compliant_expr(plx).null_count() - ) + return self._with_node(ExprNode(ExprKind.AGGREGATION, "null_count")) def is_first_distinct(self) -> Self: r"""Return a boolean mask indicating the first occurrence of each distinct value. @@ -1434,9 +1368,7 @@ def is_first_distinct(self) -> Self: |3 1 c False True| └─────────────────────────────────────────────────┘ """ - return self._with_orderable_window( - lambda plx: self._to_compliant_expr(plx).is_first_distinct() - ) + return self._with_node(ExprNode(ExprKind.ORDERABLE_WINDOW, "is_first_distinct")) def is_last_distinct(self) -> Self: r"""Return a boolean mask indicating the last occurrence of each distinct value. @@ -1463,9 +1395,7 @@ def is_last_distinct(self) -> Self: |3 1 c True True| └───────────────────────────────────────────────┘ """ - return self._with_orderable_window( - lambda plx: self._to_compliant_expr(plx).is_last_distinct() - ) + return self._with_node(ExprNode(ExprKind.ORDERABLE_WINDOW, "is_last_distinct")) def quantile( self, quantile: float, interpolation: RollingInterpolationMethod @@ -1498,8 +1428,13 @@ def quantile( | 0 24.5 74.5 | └──────────────────┘ """ - return self._with_aggregation( - lambda plx: self._to_compliant_expr(plx).quantile(quantile, interpolation) + return self._with_node( + ExprNode( + ExprKind.AGGREGATION, + "quantile", + quantile=quantile, + interpolation=interpolation, + ) ) def round(self, decimals: int = 0) -> Self: @@ -1532,9 +1467,7 @@ def round(self, decimals: int = 0) -> Self: |2 3.901234 3.9| └──────────────────────┘ """ - return self._with_elementwise( - lambda plx: self._to_compliant_expr(plx).round(decimals) - ) + return self._with_node(ExprNode(ExprKind.ELEMENTWISE, "round", decimals=decimals)) def len(self) -> Self: r"""Return the number of elements in the column. @@ -1557,7 +1490,7 @@ def len(self) -> Self: | 0 2 1 | └──────────────────┘ """ - return self._with_aggregation(lambda plx: self._to_compliant_expr(plx).len()) + return self._with_node(ExprNode(ExprKind.AGGREGATION, "len")) def clip( self, @@ -1585,13 +1518,8 @@ def clip( | 2 3 3 | └──────────────────┘ """ - return self._with_nary( - lambda *exprs: exprs[0].clip( - exprs[1] if lower_bound is not None else None, - exprs[2] if upper_bound is not None else None, - ), - lower_bound, - upper_bound, + return self._with_node( + ExprNode(ExprKind.ELEMENTWISE, "clip", lower_bound, upper_bound) ) def mode(self, *, keep: ModeKeepStrategy = "all") -> Self: @@ -1620,13 +1548,8 @@ def mode(self, *, keep: ModeKeepStrategy = "all") -> Self: if keep not in _supported_keep_values: # pragma: no cover msg = f"`keep` must be one of {_supported_keep_values}, found '{keep}'" raise ValueError(msg) - - def compliant_expr(plx: Any) -> Any: - return self._to_compliant_expr(plx).mode(keep=keep) - - if keep == "any": - return self._with_aggregation(compliant_expr) - return self._with_filtration(compliant_expr) + kind = ExprKind.AGGREGATION if keep == "any" else ExprKind.FILTRATION + return self._with_node(ExprNode(kind, "mode", keep=keep)) def is_finite(self) -> Self: """Returns boolean values indicating which original values are finite. @@ -1660,9 +1583,7 @@ def is_finite(self) -> Self: |└──────┴─────────────┘| └──────────────────────┘ """ - return self._with_elementwise( - lambda plx: self._to_compliant_expr(plx).is_finite() - ) + return self._with_node(ExprNode(ExprKind.ELEMENTWISE, "is_finite")) def cum_count(self, *, reverse: bool = False) -> Self: r"""Return the cumulative count of the non-null values in the column. @@ -1693,8 +1614,8 @@ def cum_count(self, *, reverse: bool = False) -> Self: |3 d 3 1| └─────────────────────────────────────────┘ """ - return self._with_orderable_window( - lambda plx: self._to_compliant_expr(plx).cum_count(reverse=reverse) + return self._with_node( + ExprNode(ExprKind.ORDERABLE_WINDOW, "cum_count", reverse=reverse) ) def cum_min(self, *, reverse: bool = False) -> Self: @@ -1726,8 +1647,8 @@ def cum_min(self, *, reverse: bool = False) -> Self: |3 2.0 1.0 2.0| └────────────────────────────────────┘ """ - return self._with_orderable_window( - lambda plx: self._to_compliant_expr(plx).cum_min(reverse=reverse) + return self._with_node( + ExprNode(ExprKind.ORDERABLE_WINDOW, "cum_min", reverse=reverse) ) def cum_max(self, *, reverse: bool = False) -> Self: @@ -1759,8 +1680,8 @@ def cum_max(self, *, reverse: bool = False) -> Self: |3 2.0 3.0 2.0| └────────────────────────────────────┘ """ - return self._with_orderable_window( - lambda plx: self._to_compliant_expr(plx).cum_max(reverse=reverse) + return self._with_node( + ExprNode(ExprKind.ORDERABLE_WINDOW, "cum_max", reverse=reverse) ) def cum_prod(self, *, reverse: bool = False) -> Self: @@ -1792,8 +1713,8 @@ def cum_prod(self, *, reverse: bool = False) -> Self: |3 2.0 6.0 2.0| └──────────────────────────────────────┘ """ - return self._with_orderable_window( - lambda plx: self._to_compliant_expr(plx).cum_prod(reverse=reverse) + return self._with_node( + ExprNode(ExprKind.ORDERABLE_WINDOW, "cum_prod", reverse=reverse) ) def rolling_sum( @@ -1838,13 +1759,16 @@ def rolling_sum( |3 4.0 6.0| └─────────────────────┘ """ - window_size, min_samples_int = _validate_rolling_arguments( + window_size, min_samples = _validate_rolling_arguments( window_size=window_size, min_samples=min_samples ) - - return self._with_orderable_window( - lambda plx: self._to_compliant_expr(plx).rolling_sum( - window_size=window_size, min_samples=min_samples_int, center=center + return self._with_node( + ExprNode( + ExprKind.ORDERABLE_WINDOW, + "rolling_sum", + window_size=window_size, + min_samples=min_samples, + center=center, ) ) @@ -1894,9 +1818,13 @@ def rolling_mean( window_size=window_size, min_samples=min_samples ) - return self._with_orderable_window( - lambda plx: self._to_compliant_expr(plx).rolling_mean( - window_size=window_size, min_samples=min_samples, center=center + return self._with_node( + ExprNode( + ExprKind.ORDERABLE_WINDOW, + "rolling_mean", + window_size=window_size, + min_samples=min_samples, + center=center, ) ) @@ -1951,10 +1879,14 @@ def rolling_var( window_size, min_samples = _validate_rolling_arguments( window_size=window_size, min_samples=min_samples ) - - return self._with_orderable_window( - lambda plx: self._to_compliant_expr(plx).rolling_var( - window_size=window_size, min_samples=min_samples, center=center, ddof=ddof + return self._with_node( + ExprNode( + ExprKind.ORDERABLE_WINDOW, + "rolling_var", + ddof=ddof, + window_size=window_size, + min_samples=min_samples, + center=center, ) ) @@ -2009,10 +1941,14 @@ def rolling_std( window_size, min_samples = _validate_rolling_arguments( window_size=window_size, min_samples=min_samples ) - - return self._with_orderable_window( - lambda plx: self._to_compliant_expr(plx).rolling_std( - window_size=window_size, min_samples=min_samples, center=center, ddof=ddof + return self._with_node( + ExprNode( + ExprKind.ORDERABLE_WINDOW, + "rolling_std", + ddof=ddof, + window_size=window_size, + min_samples=min_samples, + center=center, ) ) @@ -2070,10 +2006,8 @@ def rank(self, method: RankMethod = "average", *, descending: bool = False) -> S ) raise ValueError(msg) - return self._with_window( - lambda plx: self._to_compliant_expr(plx).rank( - method=method, descending=descending - ) + return self._with_node( + ExprNode(ExprKind.WINDOW, "rank", method=method, descending=descending) ) def log(self, base: float = math.e) -> Self: @@ -2104,9 +2038,7 @@ def log(self, base: float = math.e) -> Self: |log_2: [[0,1,2]] | └────────────────────────────────────────────────┘ """ - return self._with_elementwise( - lambda plx: self._to_compliant_expr(plx).log(base=base) - ) + return self._with_node(ExprNode(ExprKind.ELEMENTWISE, "log", base=base)) def exp(self) -> Self: r"""Compute the exponent. @@ -2129,7 +2061,7 @@ def exp(self) -> Self: |exp: [[0.36787944117144233,1,2.718281828459045]]| └────────────────────────────────────────────────┘ """ - return self._with_elementwise(lambda plx: self._to_compliant_expr(plx).exp()) + return self._with_node(ExprNode(ExprKind.ELEMENTWISE, "exp")) def sqrt(self) -> Self: r"""Compute the square root. @@ -2152,9 +2084,9 @@ def sqrt(self) -> Self: |sqrt: [[1,2,3]] | └──────────────────┘ """ - return self._with_elementwise(lambda plx: self._to_compliant_expr(plx).sqrt()) + return self._with_node(ExprNode(ExprKind.ELEMENTWISE, "sqrt")) - def is_close( + def is_close( # noqa: PLR0914 self, other: Self | NumericLiteral, *, @@ -2224,11 +2156,58 @@ def is_close( msg = f"`rel_tol` must be in the range [0, 1) but got {rel_tol}" raise ComputeError(msg) - kwargs = {"abs_tol": abs_tol, "rel_tol": rel_tol, "nans_equal": nans_equal} - return self._with_nary( - lambda *exprs: exprs[0].is_close(exprs[1], **kwargs), other + from decimal import Decimal + + other_abs: Self | NumericLiteral + other_is_nan: Self | bool + other_is_inf: Self | bool + other_is_not_inf: Self | bool + + if isinstance(other, (float, int, Decimal)): + from math import isinf, isnan + + # NOTE: See https://discuss.python.org/t/inferred-type-of-function-that-calls-dunder-abs-abs/101447 + other_abs = other.__abs__() + other_is_nan = isnan(other) + other_is_inf = isinf(other) + + # Define the other_is_not_inf variable to prevent triggering the following warning: + # > DeprecationWarning: Bitwise inversion '~' on bool is deprecated and will be + # > removed in Python 3.16. + other_is_not_inf = not other_is_inf + + else: + other_abs, other_is_nan = other.abs(), other.is_nan() + other_is_not_inf = other.is_finite() | other_is_nan + other_is_inf = ~other_is_not_inf + + rel_threshold = self.abs().clip(lower_bound=other_abs, upper_bound=None) * rel_tol + tolerance = rel_threshold.clip(lower_bound=abs_tol, upper_bound=None) + + self_is_nan = self.is_nan() + self_is_not_inf = self.is_finite() | self_is_nan + + # Values are close if abs_diff <= tolerance, and both finite + is_close = ( + ((self - other).abs() <= tolerance) & self_is_not_inf & other_is_not_inf ) + # Handle infinity cases: infinities are close/equal if they have the same sign + self_sign, other_sign = self > 0, other > 0 + is_same_inf = (~self_is_not_inf) & other_is_inf & (self_sign == other_sign) + + # Handle nan cases: + # * If any value is NaN, then False (via `& ~either_nan`) + # * However, if `nans_equals = True` and if _both_ values are NaN, then True + either_nan = self_is_nan | other_is_nan + result = (is_close | is_same_inf) & ~either_nan + + if nans_equal: + both_nan = self_is_nan & other_is_nan + result = result | both_nan + + return result + @property def str(self) -> ExprStringNamespace[Self]: return ExprStringNamespace(self) diff --git a/narwhals/expr_cat.py b/narwhals/expr_cat.py index 32c1f399c4..5c0541a981 100644 --- a/narwhals/expr_cat.py +++ b/narwhals/expr_cat.py @@ -2,6 +2,8 @@ from typing import TYPE_CHECKING, Generic, TypeVar +from narwhals._expression_parsing import ExprKind, ExprNode + if TYPE_CHECKING: from narwhals.expr import Expr @@ -34,6 +36,4 @@ def get_categories(self) -> ExprT: │ mango │ └────────┘ """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).cat.get_categories() - ) + return self._expr._with_node(ExprNode(ExprKind.ELEMENTWISE, "cat.get_categories")) diff --git a/narwhals/expr_dt.py b/narwhals/expr_dt.py index 9ae6eac38f..38e3563d9f 100644 --- a/narwhals/expr_dt.py +++ b/narwhals/expr_dt.py @@ -2,6 +2,8 @@ from typing import TYPE_CHECKING, Generic, TypeVar +from narwhals._expression_parsing import ExprKind, ExprNode + if TYPE_CHECKING: from narwhals.expr import Expr from narwhals.typing import TimeUnit @@ -38,9 +40,7 @@ def date(self) -> ExprT: │ 2027-12-13 │ └────────────┘ """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).dt.date() - ) + return self._expr._with_node(ExprNode(ExprKind.ELEMENTWISE, "dt.date")) def year(self) -> ExprT: """Extract year from underlying DateTime representation. @@ -64,9 +64,7 @@ def year(self) -> ExprT: |1 2065-01-01 2065| └──────────────────┘ """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).dt.year() - ) + return self._expr._with_node(ExprNode(ExprKind.ELEMENTWISE, "dt.year")) def month(self) -> ExprT: """Extract month from underlying DateTime representation. @@ -87,9 +85,7 @@ def month(self) -> ExprT: a: [[1978-06-01 00:00:00.000000,2065-01-01 00:00:00.000000]] month: [[6,1]] """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).dt.month() - ) + return self._expr._with_node(ExprNode(ExprKind.ELEMENTWISE, "dt.month")) def day(self) -> ExprT: """Extract day from underlying DateTime representation. @@ -110,9 +106,7 @@ def day(self) -> ExprT: a: [[1978-06-01 00:00:00.000000,2065-01-01 00:00:00.000000]] day: [[1,1]] """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).dt.day() - ) + return self._expr._with_node(ExprNode(ExprKind.ELEMENTWISE, "dt.day")) def hour(self) -> ExprT: """Extract hour from underlying DateTime representation. @@ -142,9 +136,7 @@ def hour(self) -> ExprT: |└─────────────────────┴──────┘| └──────────────────────────────┘ """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).dt.hour() - ) + return self._expr._with_node(ExprNode(ExprKind.ELEMENTWISE, "dt.hour")) def minute(self) -> ExprT: """Extract minutes from underlying DateTime representation. @@ -164,9 +156,7 @@ def minute(self) -> ExprT: 0 1978-01-01 01:01:00 1 1 2065-01-01 10:20:00 20 """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).dt.minute() - ) + return self._expr._with_node(ExprNode(ExprKind.ELEMENTWISE, "dt.minute")) def second(self) -> ExprT: """Extract seconds from underlying DateTime representation. @@ -192,9 +182,7 @@ def second(self) -> ExprT: a: [[1978-01-01 01:01:01.000000,2065-01-01 10:20:30.000000]] second: [[1,30]] """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).dt.second() - ) + return self._expr._with_node(ExprNode(ExprKind.ELEMENTWISE, "dt.second")) def millisecond(self) -> ExprT: """Extract milliseconds from underlying DateTime representation. @@ -222,9 +210,7 @@ def millisecond(self) -> ExprT: a: [[1978-01-01 01:01:01.000000,2065-01-01 10:20:30.067000]] millisecond: [[0,67]] """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).dt.millisecond() - ) + return self._expr._with_node(ExprNode(ExprKind.ELEMENTWISE, "dt.millisecond")) def microsecond(self) -> ExprT: """Extract microseconds from underlying DateTime representation. @@ -252,9 +238,7 @@ def microsecond(self) -> ExprT: a: [[1978-01-01 01:01:01.000000,2065-01-01 10:20:30.067000]] microsecond: [[0,67000]] """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).dt.microsecond() - ) + return self._expr._with_node(ExprNode(ExprKind.ELEMENTWISE, "dt.microsecond")) def nanosecond(self) -> ExprT: """Extract Nanoseconds from underlying DateTime representation. @@ -282,9 +266,7 @@ def nanosecond(self) -> ExprT: a: [[1978-01-01 01:01:01.000000,2065-01-01 10:20:30.067000]] nanosecond: [[0,67000000]] """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).dt.nanosecond() - ) + return self._expr._with_node(ExprNode(ExprKind.ELEMENTWISE, "dt.nanosecond")) def ordinal_day(self) -> ExprT: """Get ordinal day. @@ -306,9 +288,7 @@ def ordinal_day(self) -> ExprT: |1 2020-08-03 216| └───────────────────────────┘ """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).dt.ordinal_day() - ) + return self._expr._with_node(ExprNode(ExprKind.ELEMENTWISE, "dt.ordinal_day")) def weekday(self) -> ExprT: """Extract the week day from the underlying Date representation. @@ -332,9 +312,7 @@ def weekday(self) -> ExprT: |1 2020-08-03 1| └────────────────────────┘ """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).dt.weekday() - ) + return self._expr._with_node(ExprNode(ExprKind.ELEMENTWISE, "dt.weekday")) def total_minutes(self) -> ExprT: """Get total minutes. @@ -365,9 +343,7 @@ def total_minutes(self) -> ExprT: │ 20m 40s ┆ 20 │ └──────────────┴─────────────────┘ """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).dt.total_minutes() - ) + return self._expr._with_node(ExprNode(ExprKind.ELEMENTWISE, "dt.total_minutes")) def total_seconds(self) -> ExprT: """Get total seconds. @@ -398,9 +374,7 @@ def total_seconds(self) -> ExprT: │ 20s 40ms ┆ 20 │ └──────────────┴─────────────────┘ """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).dt.total_seconds() - ) + return self._expr._with_node(ExprNode(ExprKind.ELEMENTWISE, "dt.total_seconds")) def total_milliseconds(self) -> ExprT: """Get total milliseconds. @@ -436,8 +410,8 @@ def total_milliseconds(self) -> ExprT: │ 20040µs ┆ 20 │ └──────────────┴──────────────────────┘ """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).dt.total_milliseconds() + return self._expr._with_node( + ExprNode(ExprKind.ELEMENTWISE, "dt.total_milliseconds") ) def total_microseconds(self) -> ExprT: @@ -471,8 +445,8 @@ def total_microseconds(self) -> ExprT: a: [[10,1200]] a_total_microseconds: [[10,1200]] """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).dt.total_microseconds() + return self._expr._with_node( + ExprNode(ExprKind.ELEMENTWISE, "dt.total_microseconds") ) def total_nanoseconds(self) -> ExprT: @@ -505,8 +479,8 @@ def total_nanoseconds(self) -> ExprT: 0 2024-01-01 00:00:00.000000001 NaN 1 2024-01-01 00:00:00.000000002 1.0 """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).dt.total_nanoseconds() + return self._expr._with_node( + ExprNode(ExprKind.ELEMENTWISE, "dt.total_nanoseconds") ) def to_string(self, format: str) -> ExprT: @@ -569,8 +543,8 @@ def to_string(self, format: str) -> ExprT: |└─────────────────────┘| └───────────────────────┘ """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).dt.to_string(format) + return self._expr._with_node( + ExprNode(ExprKind.ELEMENTWISE, "dt.to_string", format=format) ) def replace_time_zone(self, time_zone: str | None) -> ExprT: @@ -597,8 +571,8 @@ def replace_time_zone(self, time_zone: str | None) -> ExprT: 0 2024-01-01 00:00:00+05:45 1 2024-01-02 00:00:00+05:45 """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).dt.replace_time_zone(time_zone) + return self._expr._with_node( + ExprNode(ExprKind.ELEMENTWISE, "dt.replace_time_zone", time_zone=time_zone) ) def convert_time_zone(self, time_zone: str) -> ExprT: @@ -631,8 +605,8 @@ def convert_time_zone(self, time_zone: str) -> ExprT: if time_zone is None: msg = "Target `time_zone` cannot be `None` in `convert_time_zone`. Please use `replace_time_zone(None)` if you want to remove the time zone." raise TypeError(msg) - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).dt.convert_time_zone(time_zone) + return self._expr._with_node( + ExprNode(ExprKind.ELEMENTWISE, "dt.convert_time_zone", time_zone=time_zone) ) def timestamp(self, time_unit: TimeUnit = "us") -> ExprT: @@ -671,8 +645,8 @@ def timestamp(self, time_unit: TimeUnit = "us") -> ExprT: f"\n\nExpected one of {{'ns', 'us', 'ms'}}, got {time_unit!r}." ) raise ValueError(msg) - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).dt.timestamp(time_unit) + return self._expr._with_node( + ExprNode(ExprKind.ELEMENTWISE, "dt.timestamp", time_unit=time_unit) ) def truncate(self, every: str) -> ExprT: @@ -715,8 +689,8 @@ def truncate(self, every: str) -> ExprT: |└─────────────────────┴─────────────────────┘| └─────────────────────────────────────────────┘ """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).dt.truncate(every) + return self._expr._with_node( + ExprNode(ExprKind.ELEMENTWISE, "dt.truncate", every=every) ) def offset_by(self, by: str) -> ExprT: @@ -759,6 +733,6 @@ def offset_by(self, by: str) -> ExprT: |└─────────────────────┴───────────────────────┘| └───────────────────────────────────────────────┘ """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).dt.offset_by(by) + return self._expr._with_node( + ExprNode(ExprKind.ELEMENTWISE, "dt.offset_by", by=by) ) diff --git a/narwhals/expr_list.py b/narwhals/expr_list.py index 63d34a52ae..70c28e5f66 100644 --- a/narwhals/expr_list.py +++ b/narwhals/expr_list.py @@ -2,6 +2,8 @@ from typing import TYPE_CHECKING, Generic, TypeVar +from narwhals._expression_parsing import ExprKind, ExprNode + if TYPE_CHECKING: from narwhals.expr import Expr from narwhals.typing import NonNestedLiteral @@ -40,9 +42,7 @@ def len(self) -> ExprT: |└──────────────┴───────┘| └────────────────────────┘ """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).list.len() - ) + return self._expr._with_node(ExprNode(ExprKind.ELEMENTWISE, "list.len")) def unique(self) -> ExprT: """Get the unique/distinct values in the list. @@ -71,9 +71,7 @@ def unique(self) -> ExprT: |└──────────────┴───────────┘| └────────────────────────────┘ """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).list.unique() - ) + return self._expr._with_node(ExprNode(ExprKind.ELEMENTWISE, "list.unique")) def contains(self, item: NonNestedLiteral) -> ExprT: """Check if sublists contain the given item. @@ -102,8 +100,8 @@ def contains(self, item: NonNestedLiteral) -> ExprT: |└───────────┴──────────────┘| └────────────────────────────┘ """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).list.contains(item) + return self._expr._with_node( + ExprNode(ExprKind.ELEMENTWISE, "list.contains", item=item) ) def get(self, index: int) -> ExprT: @@ -145,6 +143,6 @@ def get(self, index: int) -> ExprT: msg = f"Index {index} is out of bounds: should be greater than or equal to 0." raise ValueError(msg) - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).list.get(index) + return self._expr._with_node( + ExprNode(ExprKind.ELEMENTWISE, "list.get", index=index) ) diff --git a/narwhals/expr_name.py b/narwhals/expr_name.py index facda33042..51f67f3a42 100644 --- a/narwhals/expr_name.py +++ b/narwhals/expr_name.py @@ -2,6 +2,8 @@ from typing import TYPE_CHECKING, Callable, Generic, TypeVar +from narwhals._expression_parsing import ExprKind, ExprNode + if TYPE_CHECKING: from narwhals.expr import Expr @@ -26,9 +28,7 @@ def keep(self) -> ExprT: >>> df.select(nw.col("foo").alias("alias_for_foo").name.keep()).columns ['foo'] """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).name.keep() - ) + return self._expr._with_node(ExprNode(ExprKind.ELEMENTWISE, "name.keep")) def map(self, function: Callable[[str], str]) -> ExprT: r"""Rename the output of an expression by mapping a function over the root name. @@ -48,8 +48,8 @@ def map(self, function: Callable[[str], str]) -> ExprT: >>> df.select(nw.col("foo", "BAR").name.map(renaming_func)).columns ['oof', 'RAB'] """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).name.map(function) + return self._expr._with_node( + ExprNode(ExprKind.ELEMENTWISE, "name.map", function=function) ) def prefix(self, prefix: str) -> ExprT: @@ -69,8 +69,8 @@ def prefix(self, prefix: str) -> ExprT: >>> df.select(nw.col("foo", "BAR").name.prefix("with_prefix")).columns ['with_prefixfoo', 'with_prefixBAR'] """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).name.prefix(prefix) + return self._expr._with_node( + ExprNode(ExprKind.ELEMENTWISE, "name.prefix", prefix=prefix) ) def suffix(self, suffix: str) -> ExprT: @@ -90,8 +90,8 @@ def suffix(self, suffix: str) -> ExprT: >>> df.select(nw.col("foo", "BAR").name.suffix("_with_suffix")).columns ['foo_with_suffix', 'BAR_with_suffix'] """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).name.suffix(suffix) + return self._expr._with_node( + ExprNode(ExprKind.ELEMENTWISE, "name.suffix", suffix=suffix) ) def to_lowercase(self) -> ExprT: @@ -108,9 +108,7 @@ def to_lowercase(self) -> ExprT: >>> df.select(nw.col("foo", "BAR").name.to_lowercase()).columns ['foo', 'bar'] """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).name.to_lowercase() - ) + return self._expr._with_node(ExprNode(ExprKind.ELEMENTWISE, "name.to_lowercase")) def to_uppercase(self) -> ExprT: r"""Make the root column name uppercase. @@ -126,6 +124,4 @@ def to_uppercase(self) -> ExprT: >>> df.select(nw.col("foo", "BAR").name.to_uppercase()).columns ['FOO', 'BAR'] """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).name.to_uppercase() - ) + return self._expr._with_node(ExprNode(ExprKind.ELEMENTWISE, "name.to_uppercase")) diff --git a/narwhals/expr_str.py b/narwhals/expr_str.py index 19edb12911..b8cb7ce78d 100644 --- a/narwhals/expr_str.py +++ b/narwhals/expr_str.py @@ -1,12 +1,14 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Generic, TypeVar +from typing import TYPE_CHECKING, Generic, ParamSpec, TypeVar -from narwhals._expression_parsing import apply_n_ary_operation +from narwhals._expression_parsing import ExprKind, ExprNode if TYPE_CHECKING: from narwhals.expr import Expr + PS = ParamSpec("PS") + ExprT = TypeVar("ExprT", bound="Expr") @@ -38,9 +40,7 @@ def len_chars(self) -> ExprT: |└───────┴───────────┘| └─────────────────────┘ """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).str.len_chars() - ) + return self._expr._with_node(ExprNode(ExprKind.ELEMENTWISE, "str.len_chars")) def replace( self, pattern: str, value: str | ExprT, *, literal: bool = False, n: int = 1 @@ -67,17 +67,15 @@ def replace( |1 abc abc123 abc123| └──────────────────────┘ """ - return self._expr._with_elementwise( - lambda plx: ( - apply_n_ary_operation( - plx, - lambda self, value: self.str.replace( - pattern, value, literal=literal, n=n - ), - self._expr, - value, - str_as_lit=True, - ) + return self._expr._with_node( + ExprNode( + ExprKind.ELEMENTWISE, + "str.replace", + value, + pattern=pattern, + literal=literal, + n=n, + str_as_lit=True, ) ) @@ -105,17 +103,14 @@ def replace_all( |1 abc abc123 123| └──────────────────────┘ """ - return self._expr._with_elementwise( - lambda plx: ( - apply_n_ary_operation( - plx, - lambda self, value: self.str.replace_all( - pattern, value, literal=literal - ), - self._expr, - value, - str_as_lit=True, - ) + return self._expr._with_node( + ExprNode( + ExprKind.ELEMENTWISE, + "str.replace_all", + value, + pattern=pattern, + literal=literal, + str_as_lit=True, ) ) @@ -138,8 +133,8 @@ def strip_chars(self, characters: str | None = None) -> ExprT: ... ) {'fruits': ['apple', '\nmango'], 'stripped': ['apple', 'mango']} """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).str.strip_chars(characters) + return self._expr._with_node( + ExprNode(ExprKind.ELEMENTWISE, "str.strip_chars", characters=characters) ) def starts_with(self, prefix: str) -> ExprT: @@ -163,8 +158,8 @@ def starts_with(self, prefix: str) -> ExprT: |2 None None| └───────────────────┘ """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).str.starts_with(prefix) + return self._expr._with_node( + ExprNode(ExprKind.ELEMENTWISE, "str.starts_with", prefix=prefix) ) def ends_with(self, suffix: str) -> ExprT: @@ -188,8 +183,8 @@ def ends_with(self, suffix: str) -> ExprT: |2 None None| └───────────────────┘ """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).str.ends_with(suffix) + return self._expr._with_node( + ExprNode(ExprKind.ELEMENTWISE, "str.ends_with", suffix=suffix) ) def contains(self, pattern: str, *, literal: bool = False) -> ExprT: @@ -218,9 +213,9 @@ def contains(self, pattern: str, *, literal: bool = False) -> ExprT: default_match: [[true,false,true]] case_insensitive_match: [[true,false,true]] """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).str.contains( - pattern, literal=literal + return self._expr._with_node( + ExprNode( + ExprKind.ELEMENTWISE, "str.contains", pattern=pattern, literal=literal ) ) @@ -247,10 +242,8 @@ def slice(self, offset: int, length: int | None = None) -> ExprT: |2 papaya ya| └──────────────────┘ """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).str.slice( - offset=offset, length=length - ) + return self._expr._with_node( + ExprNode(ExprKind.ELEMENTWISE, "str.slice", offset=offset, length=length) ) def split(self, by: str) -> ExprT: @@ -279,9 +272,7 @@ def split(self, by: str) -> ExprT: |└─────────┴────────────────┘| └────────────────────────────┘ """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).str.split(by=by) - ) + return self._expr._with_node(ExprNode(ExprKind.ELEMENTWISE, "str.split", by=by)) def head(self, n: int = 5) -> ExprT: r"""Take the first n elements of each string. @@ -305,8 +296,8 @@ def head(self, n: int = 5) -> ExprT: lyrics: [["taata","taatatata","zukkyun"]] lyrics_head: [["taata","taata","zukky"]] """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).str.slice(0, n) + return self._expr._with_node( + ExprNode(ExprKind.ELEMENTWISE, "str.slice", offset=0, length=n) ) def tail(self, n: int = 5) -> ExprT: @@ -331,10 +322,8 @@ def tail(self, n: int = 5) -> ExprT: lyrics: [["taata","taatatata","zukkyun"]] lyrics_tail: [["taata","atata","kkyun"]] """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).str.slice( - offset=-n, length=None - ) + return self._expr._with_node( + ExprNode(ExprKind.ELEMENTWISE, "str.slice", offset=-n, length=None) ) def to_datetime(self, format: str | None = None) -> ExprT: @@ -375,8 +364,8 @@ def to_datetime(self, format: str | None = None) -> ExprT: |└─────────────────────┘| └───────────────────────┘ """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).str.to_datetime(format=format) + return self._expr._with_node( + ExprNode(ExprKind.ELEMENTWISE, "str.to_datetime", format=format) ) def to_date(self, format: str | None = None) -> ExprT: @@ -404,8 +393,8 @@ def to_date(self, format: str | None = None) -> ExprT: |a: [[2020-01-01,2020-01-02]]| └────────────────────────────┘ """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).str.to_date(format=format) + return self._expr._with_node( + ExprNode(ExprKind.ELEMENTWISE, "str.to_date", format=format) ) def to_uppercase(self) -> ExprT: @@ -430,9 +419,7 @@ def to_uppercase(self) -> ExprT: |1 None None| └──────────────────┘ """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).str.to_uppercase() - ) + return self._expr._with_node(ExprNode(ExprKind.ELEMENTWISE, "str.to_uppercase")) def to_lowercase(self) -> ExprT: r"""Transform string to lowercase variant. @@ -451,9 +438,7 @@ def to_lowercase(self) -> ExprT: |1 None None| └──────────────────┘ """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).str.to_lowercase() - ) + return self._expr._with_node(ExprNode(ExprKind.ELEMENTWISE, "str.to_lowercase")) def zfill(self, width: int) -> ExprT: """Transform string to zero-padded variant. @@ -479,6 +464,6 @@ def zfill(self, width: int) -> ExprT: |3 None None| └──────────────────┘ """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).str.zfill(width) + return self._expr._with_node( + ExprNode(ExprKind.ELEMENTWISE, "str.zfill", width=width) ) diff --git a/narwhals/expr_struct.py b/narwhals/expr_struct.py index fe74cf9f75..83fc64a648 100644 --- a/narwhals/expr_struct.py +++ b/narwhals/expr_struct.py @@ -2,6 +2,8 @@ from typing import TYPE_CHECKING, Generic, TypeVar +from narwhals._expression_parsing import ExprKind, ExprNode + if TYPE_CHECKING: from narwhals.expr import Expr @@ -40,6 +42,6 @@ def field(self, name: str) -> ExprT: |└──────────────┴──────┘| └───────────────────────┘ """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).struct.field(name) + return self._expr._with_node( + ExprNode(ExprKind.ELEMENTWISE, "struct.field", name=name) ) diff --git a/narwhals/functions.py b/narwhals/functions.py index 4eb0bdb43a..75a26eb67a 100644 --- a/narwhals/functions.py +++ b/narwhals/functions.py @@ -3,22 +3,14 @@ import platform import sys from collections.abc import Iterable, Mapping, Sequence -from functools import partial -from typing import TYPE_CHECKING, Any, Callable - -from narwhals._expression_parsing import ( - ExprKind, - ExprMetadata, - apply_n_ary_operation, - combine_metadata, - is_scalar_like, -) +from typing import TYPE_CHECKING, Any + +from narwhals._expression_parsing import ExprKind, ExprNode from narwhals._utils import ( Implementation, Version, deprecate_native_namespace, flatten, - is_compliant_expr, is_eager_allowed, is_sequence_but_not_str, normalize_path, @@ -41,7 +33,6 @@ from typing_extensions import TypeAlias, TypeIs - from narwhals._compliant import CompliantExpr, CompliantNamespace from narwhals._translate import IntoArrowTable from narwhals._typing import Backend, EagerAllowed, IntoBackend from narwhals.dataframe import DataFrame, LazyFrame @@ -56,7 +47,6 @@ NativeLazyFrame, NativeSeries, NonNestedLiteral, - _1DArray, _2DArray, ) @@ -931,16 +921,7 @@ def col(*names: str | Iterable[str]) -> Expr: └──────────────────┘ """ flat_names = flatten(names) - - def func(plx: Any) -> Any: - return plx.col(*flat_names) - - return Expr( - func, - ExprMetadata.selector_single() - if len(flat_names) == 1 - else ExprMetadata.selector_multi_named(), - ) + return Expr(ExprNode(ExprKind.COL, "col", names=flat_names)) def exclude(*names: str | Iterable[str]) -> Expr: @@ -972,12 +953,9 @@ def exclude(*names: str | Iterable[str]) -> Expr: | └─────┘ | └──────────────────┘ """ - exclude_names = frozenset(flatten(names)) - - def func(plx: Any) -> Any: - return plx.exclude(exclude_names) - - return Expr(func, ExprMetadata.selector_multi_unnamed()) + flat_names = flatten(names) + exclude_names = frozenset(flat_names) + return Expr(ExprNode(ExprKind.EXCLUDE, "exclude", names=exclude_names)) def nth(*indices: int | Sequence[int]) -> Expr: @@ -1011,16 +989,8 @@ def nth(*indices: int | Sequence[int]) -> Expr: └──────────────────┘ """ flat_indices = flatten(indices) - - def func(plx: Any) -> Any: - return plx.nth(*flat_indices) - - return Expr( - func, - ExprMetadata.selector_single() - if len(flat_indices) == 1 - else ExprMetadata.selector_multi_unnamed(), - ) + node = ExprNode(ExprKind.NTH, "nth", indices=flat_indices) + return Expr(node) # Add underscore so it doesn't conflict with builtin `all` @@ -1044,7 +1014,8 @@ def all_() -> Expr: | 1 4 0.246 | └──────────────────┘ """ - return Expr(lambda plx: plx.all(), ExprMetadata.selector_multi_unnamed()) + node = ExprNode(ExprKind.ALL, "all") + return Expr(node) # Add underscore so it doesn't conflict with builtin `len` @@ -1073,11 +1044,8 @@ def len_() -> Expr: | └─────┘ | └──────────────────┘ """ - - def func(plx: Any) -> Any: - return plx.len() - - return Expr(func, ExprMetadata.aggregation()) + node = ExprNode(ExprKind.AGGREGATION, "len") + return Expr(node) def sum(*columns: str) -> Expr: @@ -1236,22 +1204,12 @@ def max(*columns: str) -> Expr: return col(*columns).max() -def _expr_with_n_ary_op( - func_name: str, - operation_factory: Callable[ - [CompliantNamespace[Any, Any]], Callable[..., CompliantExpr[Any, Any]] - ], - *exprs: IntoExpr, -) -> Expr: +def _expr_with_horizontal_op(name: str, *exprs: IntoExpr, **kwargs: Any) -> Expr: if not exprs: - msg = f"At least one expression must be passed to `{func_name}`" + msg = f"At least one expression must be passed to `{name}`" raise ValueError(msg) - return Expr( - lambda plx: apply_n_ary_operation( - plx, operation_factory(plx), *exprs, str_as_lit=False - ), - ExprMetadata.from_horizontal_op(*exprs), - ) + node = ExprNode(ExprKind.ELEMENTWISE, name, *exprs, **kwargs, allow_multi_output=True) + return Expr(node) def sum_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: @@ -1288,9 +1246,8 @@ def sum_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: |└─────┴──────┴─────┘| └────────────────────┘ """ - return _expr_with_n_ary_op( - "sum_horizontal", lambda plx: plx.sum_horizontal, *flatten(exprs) - ) + flat_exprs = flatten(exprs) + return _expr_with_horizontal_op("sum_horizontal", *flat_exprs) def min_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: @@ -1325,9 +1282,7 @@ def min_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: | h_min: [[1,5,3]] | └──────────────────┘ """ - return _expr_with_n_ary_op( - "min_horizontal", lambda plx: plx.min_horizontal, *flatten(exprs) - ) + return _expr_with_horizontal_op("min_horizontal", *flatten(exprs)) def max_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: @@ -1364,73 +1319,30 @@ def max_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: |└─────┴──────┴───────┘| └──────────────────────┘ """ - return _expr_with_n_ary_op( - "max_horizontal", lambda plx: plx.max_horizontal, *flatten(exprs) - ) + return _expr_with_horizontal_op("max_horizontal", *flatten(exprs)) class When: def __init__(self, *predicates: IntoExpr | Iterable[IntoExpr]) -> None: self._predicate = all_horizontal(*flatten(predicates), ignore_nulls=False) - def then(self, value: IntoExpr | NonNestedLiteral | _1DArray) -> Then: - kind = ExprKind.from_into_expr(value, str_as_lit=False) - if self._predicate._metadata.is_scalar_like and not kind.is_scalar_like: - msg = ( - "If you pass a scalar-like predicate to `nw.when`, then " - "the `then` value must also be scalar-like." - ) - raise InvalidOperationError(msg) - + def then(self, value: IntoExpr | NonNestedLiteral) -> Then: return Then( - lambda plx: apply_n_ary_operation( - plx, - lambda *args: plx.when(args[0]).then(args[1]), + ExprNode( + ExprKind.ELEMENTWISE, + "when_then", self._predicate, value, - str_as_lit=False, - ), - combine_metadata( - self._predicate, - value, - str_as_lit=False, allow_multi_output=False, - to_single_output=False, - ), + ) ) class Then(Expr): - def otherwise(self, value: IntoExpr | NonNestedLiteral | _1DArray) -> Expr: - kind = ExprKind.from_into_expr(value, str_as_lit=False) - if self._metadata.is_scalar_like and not is_scalar_like(kind): - msg = ( - "If you pass a scalar-like predicate to `nw.when`, then " - "the `otherwise` value must also be scalar-like." - ) - raise InvalidOperationError(msg) - - def func(plx: CompliantNamespace[Any, Any]) -> CompliantExpr[Any, Any]: - compliant_expr = self._to_compliant_expr(plx) - compliant_value = plx.parse_into_expr(value, str_as_lit=False) - if ( - not self._metadata.is_scalar_like - and is_scalar_like(kind) - and is_compliant_expr(compliant_value) - ): - compliant_value = compliant_value.broadcast(kind) - return compliant_expr.otherwise(compliant_value) # type: ignore[attr-defined, no-any-return] - - return Expr( - func, - combine_metadata( - self, - value, - str_as_lit=False, - allow_multi_output=False, - to_single_output=False, - ), - ) + def otherwise(self, value: IntoExpr | NonNestedLiteral) -> Expr: + # eject latest node, replace with `when_then_otherwise` + node = self._nodes[0] + return Expr(ExprNode(ExprKind.ELEMENTWISE, "when_then", *node.exprs, value)) def when(*predicates: IntoExpr | Iterable[IntoExpr]) -> When: @@ -1517,10 +1429,9 @@ def all_horizontal(*exprs: IntoExpr | Iterable[IntoExpr], ignore_nulls: bool) -> └─────────────────────────────────────────┘ """ - return _expr_with_n_ary_op( - "all_horizontal", - lambda plx: partial(plx.all_horizontal, ignore_nulls=ignore_nulls), - *flatten(exprs), + flat_exprs = flatten(exprs) + return _expr_with_horizontal_op( + "all_horizontal", *flat_exprs, ignore_nulls=ignore_nulls ) @@ -1560,7 +1471,8 @@ def lit(value: NonNestedLiteral, dtype: IntoDType | None = None) -> Expr: msg = f"Nested datatypes are not supported yet. Got {value}" raise NotImplementedError(msg) - return Expr(lambda plx: plx.lit(value, dtype), ExprMetadata.literal()) + node = ExprNode(ExprKind.LITERAL, "lit", value=value, dtype=dtype) + return Expr(node) def any_horizontal(*exprs: IntoExpr | Iterable[IntoExpr], ignore_nulls: bool) -> Expr: @@ -1609,10 +1521,9 @@ def any_horizontal(*exprs: IntoExpr | Iterable[IntoExpr], ignore_nulls: bool) -> |└───────┴───────┴───────┘| └─────────────────────────┘ """ - return _expr_with_n_ary_op( - "any_horizontal", - lambda plx: partial(plx.any_horizontal, ignore_nulls=ignore_nulls), - *flatten(exprs), + flat_exprs = flatten(exprs) + return _expr_with_horizontal_op( + "any_horizontal", *flat_exprs, ignore_nulls=ignore_nulls ) @@ -1646,9 +1557,7 @@ def mean_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: | a: [[2.5,6.5,3]] | └──────────────────┘ """ - return _expr_with_n_ary_op( - "mean_horizontal", lambda plx: plx.mean_horizontal, *flatten(exprs) - ) + return _expr_with_horizontal_op("mean_horizontal", *flatten(exprs)) def concat_str( @@ -1700,12 +1609,8 @@ def concat_str( └──────────────────┘ """ flat_exprs = flatten([*flatten([exprs]), *more_exprs]) - return _expr_with_n_ary_op( - "concat_str", - lambda plx: lambda *args: plx.concat_str( - *args, separator=separator, ignore_nulls=ignore_nulls - ), - *flat_exprs, + return _expr_with_horizontal_op( + "concat_str", *flat_exprs, separator=separator, ignore_nulls=ignore_nulls ) @@ -1767,9 +1672,5 @@ def coalesce( ) raise TypeError(msg) - return Expr( - lambda plx: apply_n_ary_operation( - plx, lambda *args: plx.coalesce(*args), *flat_exprs, str_as_lit=False - ), - ExprMetadata.from_horizontal_op(*flat_exprs), - ) + node = ExprNode(ExprKind.ELEMENTWISE, "coalesce", *flat_exprs) + return Expr(node) diff --git a/narwhals/group_by.py b/narwhals/group_by.py index c469ac921e..1bdd8e9ec2 100644 --- a/narwhals/group_by.py +++ b/narwhals/group_by.py @@ -2,8 +2,7 @@ from typing import TYPE_CHECKING, Any, Generic, TypeVar -from narwhals._expression_parsing import all_exprs_are_scalar_like -from narwhals._utils import flatten, tupleify +from narwhals._utils import tupleify from narwhals.exceptions import InvalidOperationError from narwhals.typing import DataFrameT @@ -72,8 +71,8 @@ def agg(self, *aggs: Expr | Iterable[Expr], **named_aggs: Expr) -> DataFrameT: 2 b 3 2 3 c 3 1 """ - flat_aggs = tuple(flatten(aggs)) - if not all_exprs_are_scalar_like(*flat_aggs, **named_aggs): + compliant_aggs, kinds = self._df._flatten_and_extract(*aggs, **named_aggs) + if not all(x.is_scalar_like for x in kinds): msg = ( "Found expression which does not aggregate.\n\n" "All expressions passed to GroupBy.agg must aggregate.\n" @@ -81,14 +80,6 @@ def agg(self, *aggs: Expr | Iterable[Expr], **named_aggs: Expr) -> DataFrameT: "but `df.group_by('a').agg(nw.col('b'))` is not." ) raise InvalidOperationError(msg) - plx = self._df.__narwhals_namespace__() - compliant_aggs = ( - *(x._to_compliant_expr(plx) for x in flat_aggs), - *( - value.alias(key)._to_compliant_expr(plx) - for key, value in named_aggs.items() - ), - ) return self._df._with_compliant(self._grouped.agg(*compliant_aggs)) def __iter__(self) -> Iterator[tuple[Any, DataFrameT]]: @@ -166,8 +157,8 @@ def agg(self, *aggs: Expr | Iterable[Expr], **named_aggs: Expr) -> LazyFrameT: |└─────┴─────┴─────┘| └───────────────────┘ """ - flat_aggs = tuple(flatten(aggs)) - if not all_exprs_are_scalar_like(*flat_aggs, **named_aggs): + compliant_aggs, kinds = self._df._flatten_and_extract(*aggs, **named_aggs) + if not all(x.is_scalar_like for x in kinds): msg = ( "Found expression which does not aggregate.\n\n" "All expressions passed to GroupBy.agg must aggregate.\n" @@ -175,12 +166,4 @@ def agg(self, *aggs: Expr | Iterable[Expr], **named_aggs: Expr) -> LazyFrameT: "but `df.group_by('a').agg(nw.col('b'))` is not." ) raise InvalidOperationError(msg) - plx = self._df.__narwhals_namespace__() - compliant_aggs = ( - *(x._to_compliant_expr(plx) for x in flat_aggs), - *( - value.alias(key)._to_compliant_expr(plx) - for key, value in named_aggs.items() - ), - ) return self._df._with_compliant(self._grouped.agg(*compliant_aggs)) diff --git a/narwhals/selectors.py b/narwhals/selectors.py index b7f774b8df..31db79f86a 100644 --- a/narwhals/selectors.py +++ b/narwhals/selectors.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Any, NoReturn -from narwhals._expression_parsing import ExprMetadata, combine_metadata +from narwhals._expression_parsing import ExprKind, ExprNode from narwhals._utils import flatten from narwhals.expr import Expr @@ -16,50 +16,46 @@ class Selector(Expr): def _to_expr(self) -> Expr: - return Expr(self._to_compliant_expr, self._metadata) + return Expr(*self._nodes) - def __add__(self, other: Any) -> Expr: # type: ignore[override] - if isinstance(other, Selector): - msg = "unsupported operand type(s) for op: ('Selector' + 'Selector')" - raise TypeError(msg) - return self._to_expr() + other # type: ignore[no-any-return] + def __rsub__(self, other: Any) -> NoReturn: + raise NotImplementedError - def __or__(self, other: Any) -> Expr: # type: ignore[override] + def __rand__(self, other: Any) -> NoReturn: + raise NotImplementedError + + def __ror__(self, other: Any) -> NoReturn: + raise NotImplementedError + + def __and__(self, other: Any) -> Expr: # type: ignore[override] if isinstance(other, Selector): - return self.__class__( - lambda plx: self._to_compliant_expr(plx) | other._to_compliant_expr(plx), - combine_metadata( - self, + return self._with_node( + ExprNode( + ExprKind.ELEMENTWISE, + "__and__", other, - str_as_lit=False, + str_as_lit=True, allow_multi_output=True, - to_single_output=False, - ), + ) ) - return self._to_expr() | other # type: ignore[no-any-return] + return self._to_expr()._with_node( + ExprNode(ExprKind.ELEMENTWISE, "__and__", other, str_as_lit=True) + ) - def __and__(self, other: Any) -> Expr: # type: ignore[override] + def __or__(self, other: Any) -> Expr: # type: ignore[override] if isinstance(other, Selector): - return self.__class__( - lambda plx: self._to_compliant_expr(plx) & other._to_compliant_expr(plx), - combine_metadata( - self, + return self._with_node( + ExprNode( + ExprKind.ELEMENTWISE, + "__or__", other, - str_as_lit=False, + str_as_lit=True, allow_multi_output=True, - to_single_output=False, - ), + ) ) - return self._to_expr() & other # type: ignore[no-any-return] - - def __rsub__(self, other: Any) -> NoReturn: - raise NotImplementedError - - def __rand__(self, other: Any) -> NoReturn: - raise NotImplementedError - - def __ror__(self, other: Any) -> NoReturn: - raise NotImplementedError + return self._to_expr()._with_node( + ExprNode(ExprKind.ELEMENTWISE, "__or__", other, str_as_lit=True) + ) def by_dtype(*dtypes: DType | type[DType] | Iterable[DType | type[DType]]) -> Selector: @@ -89,10 +85,7 @@ def by_dtype(*dtypes: DType | type[DType] | Iterable[DType | type[DType]]) -> Se c: [[8.2,4.6]] """ flattened = flatten(dtypes) - return Selector( - lambda plx: plx.selectors.by_dtype(flattened), - ExprMetadata.selector_multi_unnamed(), - ) + return Selector(ExprNode(ExprKind.SELECTOR, "selectors.by_dtype", dtypes=flattened)) def matches(pattern: str) -> Selector: @@ -120,9 +113,7 @@ def matches(pattern: str) -> Selector: 0 123 2.0 1 456 5.5 """ - return Selector( - lambda plx: plx.selectors.matches(pattern), ExprMetadata.selector_multi_unnamed() - ) + return Selector(ExprNode(ExprKind.SELECTOR, "selectors.matches", pattern=pattern)) def numeric() -> Selector: @@ -151,9 +142,7 @@ def numeric() -> Selector: │ 4 ┆ 4.6 │ └─────┴─────┘ """ - return Selector( - lambda plx: plx.selectors.numeric(), ExprMetadata.selector_multi_unnamed() - ) + return Selector(ExprNode(ExprKind.SELECTOR, "selectors.numeric")) def boolean() -> Selector: @@ -186,9 +175,7 @@ def boolean() -> Selector: | └───────┘ | └──────────────────┘ """ - return Selector( - lambda plx: plx.selectors.boolean(), ExprMetadata.selector_multi_unnamed() - ) + return Selector(ExprNode(ExprKind.SELECTOR, "selectors.boolean")) def string() -> Selector: @@ -217,9 +204,7 @@ def string() -> Selector: │ y │ └─────┘ """ - return Selector( - lambda plx: plx.selectors.string(), ExprMetadata.selector_multi_unnamed() - ) + return Selector(ExprNode(ExprKind.SELECTOR, "selectors.string")) def categorical() -> Selector: @@ -250,9 +235,7 @@ def categorical() -> Selector: │ y │ └─────┘ """ - return Selector( - lambda plx: plx.selectors.categorical(), ExprMetadata.selector_multi_unnamed() - ) + return Selector(ExprNode(ExprKind.SELECTOR, "selectors.categorical")) def all() -> Selector: @@ -275,9 +258,7 @@ def all() -> Selector: 0 1 x False 1 2 y True """ - return Selector( - lambda plx: plx.selectors.all(), ExprMetadata.selector_multi_unnamed() - ) + return Selector(ExprNode(ExprKind.SELECTOR, "selectors.all")) def datetime( @@ -336,8 +317,12 @@ def datetime( tstamp_utc: [[2023-04-10 12:14:16.999000Z,2025-08-25 14:18:22.666000Z]] """ return Selector( - lambda plx: plx.selectors.datetime(time_unit=time_unit, time_zone=time_zone), - ExprMetadata.selector_multi_unnamed(), + ExprNode( + ExprKind.SELECTOR, + "selectors.datetime", + time_unit=time_unit, + time_zone=time_zone, + ) ) diff --git a/narwhals/series.py b/narwhals/series.py index c28ff05464..c4709b6dd1 100644 --- a/narwhals/series.py +++ b/narwhals/series.py @@ -2,8 +2,18 @@ import math from collections.abc import Iterable, Iterator, Mapping, Sequence -from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, Literal, overload +from typing import ( + TYPE_CHECKING, + Any, + Callable, + ClassVar, + Generic, + Literal, + cast, + overload, +) +from narwhals._expression_parsing import ExprKind, ExprNode, is_series from narwhals._utils import ( Implementation, Version, @@ -20,6 +30,7 @@ from narwhals.dependencies import is_numpy_array, is_numpy_array_1d, is_numpy_scalar from narwhals.dtypes import _validate_dtype, _validate_into_dtype from narwhals.exceptions import ComputeError, InvalidOperationError +from narwhals.expr import Expr from narwhals.series_cat import SeriesCatNamespace from narwhals.series_dt import SeriesDateTimeNamespace from narwhals.series_list import SeriesListNamespace @@ -89,6 +100,11 @@ def _dataframe(self) -> type[DataFrame[Any]]: return DataFrame + def _to_expr(self) -> Expr: + return Expr( + ExprNode(ExprKind.SERIES, "_expr._from_series", series=self._compliant) + ) + def __init__( self, series: Any, *, level: Literal["full", "lazy", "interchange"] ) -> None: @@ -2687,28 +2703,24 @@ def is_close( ] ] """ + from narwhals.functions import col + if not self.dtype.is_numeric(): msg = ( f"is_close operation not supported for dtype `{self.dtype}`\n\n" "Hint: `is_close` is only supported for numeric types" ) raise InvalidOperationError(msg) - - if abs_tol < 0: - msg = f"`abs_tol` must be non-negative but got {abs_tol}" - raise ComputeError(msg) - - if not (0 <= rel_tol < 1): - msg = f"`rel_tol` must be in the range [0, 1) but got {rel_tol}" - raise ComputeError(msg) - - return self._with_compliant( - self._compliant_series.is_close( - self._extract_native(other), - abs_tol=abs_tol, - rel_tol=rel_tol, - nans_equal=nans_equal, - ) + return cast( + "Self", + self.to_frame().select( + col(self.name).is_close( + other._to_expr() if is_series(other) else other, + abs_tol=abs_tol, + rel_tol=rel_tol, + nans_equal=nans_equal, + ) + )[self.name], ) @property diff --git a/narwhals/series_str.py b/narwhals/series_str.py index 7469ad5c53..c74d0d12ba 100644 --- a/narwhals/series_str.py +++ b/narwhals/series_str.py @@ -57,7 +57,7 @@ def replace( """ return self._narwhals_series._with_compliant( self._narwhals_series._compliant_series.str.replace( - pattern, self._extract_compliant(value), literal=literal, n=n + self._extract_compliant(value), pattern=pattern, literal=literal, n=n ) ) @@ -83,7 +83,7 @@ def replace_all( """ return self._narwhals_series._with_compliant( self._narwhals_series._compliant_series.str.replace_all( - pattern, self._extract_compliant(value), literal=literal + self._extract_compliant(value), pattern, literal=literal ) ) diff --git a/narwhals/stable/v1/__init__.py b/narwhals/stable/v1/__init__.py index ff845bda9c..c68c84d74a 100644 --- a/narwhals/stable/v1/__init__.py +++ b/narwhals/stable/v1/__init__.py @@ -6,6 +6,7 @@ import narwhals as nw from narwhals import exceptions, functions as nw_f from narwhals._exceptions import issue_warning +from narwhals._expression_parsing import ExprKind, ExprNode, is_expr from narwhals._typing_compat import TypeVar, assert_never from narwhals._utils import ( Implementation, @@ -67,6 +68,7 @@ from typing_extensions import ParamSpec, Self + from narwhals._expression_parsing import ExprMetadata from narwhals._translate import IntoArrowTable from narwhals._typing import ( Arrow, @@ -234,18 +236,9 @@ def __init__(self, df: Any, *, level: Literal["full", "lazy", "interchange"]) -> def _dataframe(self) -> type[DataFrame[Any]]: return DataFrame - def _extract_compliant(self, arg: Any) -> Any: - # After v1, we raise when passing order-dependent, length-changing, - # or filtration expressions to LazyFrame - from narwhals.expr import Expr - from narwhals.series import Series - - if isinstance(arg, Series): # pragma: no cover - msg = "Mixing Series with LazyFrame is not supported." - raise TypeError(msg) - if isinstance(arg, (Expr, str)): - return self.__narwhals_namespace__().parse_into_expr(arg, str_as_lit=False) - raise InvalidIntoExprError.from_invalid_type(type(arg)) + def _validate_metadata(self, metadata: ExprMetadata) -> None: + # After v1, we raise for order-dependent operations. + pass def collect( self, backend: IntoBackend[Polars | Pandas | Arrow] | None = None, **kwargs: Any @@ -368,15 +361,11 @@ def _l1_norm(self) -> Self: def head(self, n: int = 10) -> Self: r"""Get the first `n` rows.""" - return self._with_orderable_filtration( - lambda plx: self._to_compliant_expr(plx).head(n) # type: ignore[attr-defined] - ) + return self._with_node(ExprNode(ExprKind.FILTRATION, "head", n=n)) def tail(self, n: int = 10) -> Self: r"""Get the last `n` rows.""" - return self._with_orderable_filtration( - lambda plx: self._to_compliant_expr(plx).tail(n) # type: ignore[attr-defined] - ) + return self._with_node(ExprNode(ExprKind.FILTRATION, "tail", n=n)) def gather_every(self, n: int, offset: int = 0) -> Self: r"""Take every nth value in the Series and return as new Series. @@ -385,8 +374,8 @@ def gather_every(self, n: int, offset: int = 0) -> Self: n: Gather every *n*-th row. offset: Starting index. """ - return self._with_orderable_filtration( - lambda plx: self._to_compliant_expr(plx).gather_every(n=n, offset=offset) # type: ignore[attr-defined] + return self._with_node( + ExprNode(ExprKind.ORDERABLE_FILTRATION, "gather_every", n=n, offset=offset) ) def unique(self, *, maintain_order: bool | None = None) -> Self: @@ -397,33 +386,27 @@ def unique(self, *, maintain_order: bool | None = None) -> Self: "You can safely remove this argument." ) issue_warning(msg, UserWarning) - return self._with_filtration(lambda plx: self._to_compliant_expr(plx).unique()) + return self._with_node(ExprNode(ExprKind.FILTRATION, "unique")) def sort(self, *, descending: bool = False, nulls_last: bool = False) -> Self: """Sort this column. Place null values first.""" - return self._with_window( - lambda plx: self._to_compliant_expr(plx).sort( # type: ignore[attr-defined] - descending=descending, nulls_last=nulls_last + return self._with_node( + ExprNode( + ExprKind.WINDOW, "sort", descending=descending, nulls_last=nulls_last ) ) def arg_max(self) -> Self: """Returns the index of the maximum value.""" - return self._with_orderable_aggregation( - lambda plx: self._to_compliant_expr(plx).arg_max() # type: ignore[attr-defined] - ) + return self._with_node(ExprNode(ExprKind.ORDERABLE_AGGREGATION, "arg_max")) def arg_min(self) -> Self: """Returns the index of the minimum value.""" - return self._with_orderable_aggregation( - lambda plx: self._to_compliant_expr(plx).arg_min() # type: ignore[attr-defined] - ) + return self._with_node(ExprNode(ExprKind.ORDERABLE_AGGREGATION, "arg_min")) def arg_true(self) -> Self: """Find elements where boolean expression is True.""" - return self._with_orderable_filtration( - lambda plx: self._to_compliant_expr(plx).arg_true() # type: ignore[attr-defined] - ) + return self._with_node(ExprNode(ExprKind.ORDERABLE_FILTRATION, "arg_true")) def sample( self, @@ -442,9 +425,14 @@ def sample( seed: Seed for the random number generator. If set to None (default), a random seed is generated for each sample operation. """ - return self._with_filtration( - lambda plx: self._to_compliant_expr(plx).sample( # type: ignore[attr-defined] - n, fraction=fraction, with_replacement=with_replacement, seed=seed + return self._with_node( + ExprNode( + ExprKind.FILTRATION, + "sample", + n=n, + fraction=fraction, + with_replacement=with_replacement, + seed=seed, ) ) @@ -482,7 +470,7 @@ def _stableify( if isinstance(obj, NwSeries): return Series(obj._compliant_series._with_version(Version.V1), level=obj._level) if isinstance(obj, NwExpr): - return Expr(obj._to_compliant_expr, obj._metadata) + return Expr(*obj._nodes) assert_never(obj) @@ -1198,7 +1186,7 @@ def then(self, value: IntoExpr | NonNestedLiteral | _1DArray) -> Then: class Then(nw_f.Then, Expr): @classmethod def from_then(cls, then: nw_f.Then) -> Then: - return cls(then._to_compliant_expr, then._metadata) + return cls(*then._nodes) def otherwise(self, value: IntoExpr | NonNestedLiteral | _1DArray) -> Expr: return _stableify(super().otherwise(value)) @@ -1372,6 +1360,7 @@ def scan_parquet( "Int32", "Int64", "Int128", + "InvalidIntoExprError", "LazyFrame", "List", "Object", @@ -1404,6 +1393,7 @@ def scan_parquet( "generate_temporary_column_name", "get_level", "get_native_namespace", + "is_expr", "is_ordered_categorical", "len", "lit", diff --git a/narwhals/stable/v2/__init__.py b/narwhals/stable/v2/__init__.py index ada1ff20e8..cb97164e46 100644 --- a/narwhals/stable/v2/__init__.py +++ b/narwhals/stable/v2/__init__.py @@ -328,7 +328,7 @@ def _stableify( if isinstance(obj, NwSeries): return Series(obj._compliant_series._with_version(Version.V2), level=obj._level) if isinstance(obj, NwExpr): - return Expr(obj._to_compliant_expr, obj._metadata) + return Expr(*obj._nodes) assert_never(obj) @@ -990,7 +990,7 @@ def then(self, value: IntoExpr | NonNestedLiteral | _1DArray) -> Then: class Then(nw_f.Then, Expr): @classmethod def from_then(cls, then: nw_f.Then) -> Then: - return cls(then._to_compliant_expr, then._metadata) + return cls(*then._nodes) def otherwise(self, value: IntoExpr | NonNestedLiteral | _1DArray) -> Expr: return _stableify(super().otherwise(value)) diff --git a/narwhals/typing.py b/narwhals/typing.py index 356af6e66f..22d4d1d91f 100644 --- a/narwhals/typing.py +++ b/narwhals/typing.py @@ -106,7 +106,17 @@ def Time(self) -> type[dtypes.Time]: ... def Binary(self) -> type[dtypes.Binary]: ... -IntoExpr: TypeAlias = Union["Expr", str, "Series[Any]"] +_ShapeT = TypeVar("_ShapeT", bound="tuple[int, ...]") +_NDArray: TypeAlias = "np.ndarray[_ShapeT, Any]" +_1DArray: TypeAlias = "_NDArray[tuple[int]]" +_1DArrayInt: TypeAlias = "np.ndarray[tuple[int], np.dtype[np.integer[Any]]]" +_2DArray: TypeAlias = "_NDArray[tuple[int, int]]" # noqa: PYI047 +_AnyDArray: TypeAlias = "_NDArray[tuple[int, ...]]" # noqa: PYI047 +_NumpyScalar: TypeAlias = "np.generic[Any]" +Into1DArray: TypeAlias = "_1DArray | _NumpyScalar" +"""A 1-dimensional `numpy.ndarray` or scalar that can be converted into one.""" + +IntoExpr: TypeAlias = Union["Expr", str, "Series[Any]", _1DArray] """Anything which can be converted to an expression. Use this to mean "either a Narwhals expression, or something which can be converted @@ -347,15 +357,6 @@ def Binary(self) -> type[dtypes.Binary]: ... - *"all"*: Keeps all the mode's. """ -_ShapeT = TypeVar("_ShapeT", bound="tuple[int, ...]") -_NDArray: TypeAlias = "np.ndarray[_ShapeT, Any]" -_1DArray: TypeAlias = "_NDArray[tuple[int]]" -_1DArrayInt: TypeAlias = "np.ndarray[tuple[int], np.dtype[np.integer[Any]]]" -_2DArray: TypeAlias = "_NDArray[tuple[int, int]]" # noqa: PYI047 -_AnyDArray: TypeAlias = "_NDArray[tuple[int, ...]]" # noqa: PYI047 -_NumpyScalar: TypeAlias = "np.generic[Any]" -Into1DArray: TypeAlias = "_1DArray | _NumpyScalar" -"""A 1-dimensional `numpy.ndarray` or scalar that can be converted into one.""" PandasLikeDType: TypeAlias = "pd.api.extensions.ExtensionDtype | np.dtype[Any]" diff --git a/tests/conftest.py b/tests/conftest.py index c823432c5c..85bbb8432f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -124,6 +124,7 @@ def cudf_constructor(obj: Data) -> NativeDataFrame: # pragma: no cover def polars_eager_constructor(obj: Data) -> pl.DataFrame: + pytest.importorskip("polars") import polars as pl return pl.DataFrame(obj) @@ -137,11 +138,11 @@ def polars_lazy_constructor(obj: Data) -> pl.LazyFrame: def duckdb_lazy_constructor(obj: Data) -> duckdb.DuckDBPyRelation: import duckdb - import polars as pl + import pyarrow as pa duckdb.sql("""set timezone = 'UTC'""") - _df = pl.LazyFrame(obj) + _df = pa.table(obj) return duckdb.table("_df") @@ -207,9 +208,11 @@ def _ibis_backend() -> IbisDuckDBBackend: # pragma: no cover def ibis_lazy_constructor(obj: Data) -> ibis.Table: # pragma: no cover + pytest.importorskip("polars") + pytest.importorskip("ibis") import polars as pl - ldf = pl.from_dict(obj).lazy() + ldf = pl.LazyFrame(obj) table_name = str(uuid.uuid4()) return _ibis_backend().create_table(table_name, ldf) diff --git a/tests/expr_and_series/dt/convert_time_zone_test.py b/tests/expr_and_series/dt/convert_time_zone_test.py index 40e5f08d77..65d1a6e3b6 100644 --- a/tests/expr_and_series/dt/convert_time_zone_test.py +++ b/tests/expr_and_series/dt/convert_time_zone_test.py @@ -154,6 +154,7 @@ def test_convert_time_zone_to_connection_tz_duckdb() -> None: ) +@pytest.mark.slow def test_convert_time_zone_to_connection_tz_pyspark() -> None: # pragma: no cover pytest.importorskip("pyspark") diff --git a/tests/expr_and_series/dt/replace_time_zone_test.py b/tests/expr_and_series/dt/replace_time_zone_test.py index d0e90cdadd..1c9dff7d59 100644 --- a/tests/expr_and_series/dt/replace_time_zone_test.py +++ b/tests/expr_and_series/dt/replace_time_zone_test.py @@ -142,6 +142,7 @@ def test_replace_time_zone_to_connection_tz_duckdb() -> None: ) +@pytest.mark.slow def test_replace_time_zone_to_connection_tz_pyspark() -> None: # pragma: no cover pytest.importorskip("pyspark") diff --git a/tests/expr_and_series/over_test.py b/tests/expr_and_series/over_test.py index 1fd71d8347..52e353fd1b 100644 --- a/tests/expr_and_series/over_test.py +++ b/tests/expr_and_series/over_test.py @@ -386,7 +386,7 @@ def test_over_cum_reverse( def test_over_raise_len_change(constructor: Constructor) -> None: df = nw.from_native(constructor(data)) - with pytest.raises(InvalidOperationError): + with pytest.raises((InvalidOperationError, NotImplementedError)): nw.from_native(df).select(nw.col("b").drop_nulls().over("a")) diff --git a/tests/expr_and_series/unique_test.py b/tests/expr_and_series/unique_test.py index faeadf8bcb..c54a0356c7 100644 --- a/tests/expr_and_series/unique_test.py +++ b/tests/expr_and_series/unique_test.py @@ -15,7 +15,7 @@ def test_unique_expr(constructor: Constructor) -> None: df = nw.from_native(constructor(data)) context = ( - pytest.raises(InvalidOperationError) + pytest.raises((InvalidOperationError, NotImplementedError)) if isinstance(df, nw.LazyFrame) else does_not_raise() ) @@ -41,10 +41,10 @@ def test_unique_expr_agg( def test_unique_illegal_combination(constructor: Constructor) -> None: df = nw.from_native(constructor(data)) - with pytest.raises(InvalidOperationError): - df.select((nw.col("a").unique() + nw.col("b").unique()).sum()) - with pytest.raises(InvalidOperationError): - df.select(nw.col("a").unique() + nw.col("b")) + with pytest.raises((InvalidOperationError, NotImplementedError)): + df.select((nw.col("a").unique() + nw.col("a").unique()).sum()) + with pytest.raises((InvalidOperationError, NotImplementedError)): + df.select(nw.col("a").unique() + nw.col("a")) def test_unique_series(constructor_eager: ConstructorEager) -> None: diff --git a/tests/expr_and_series/when_test.py b/tests/expr_and_series/when_test.py index cfcaad680c..422651eaee 100644 --- a/tests/expr_and_series/when_test.py +++ b/tests/expr_and_series/when_test.py @@ -6,7 +6,7 @@ import pytest import narwhals as nw -from narwhals.exceptions import InvalidOperationError, MultiOutputExpressionError +from narwhals.exceptions import MultiOutputExpressionError from tests.utils import Constructor, ConstructorEager, assert_equal_data if TYPE_CHECKING: @@ -115,13 +115,14 @@ def test_when_then_otherwise_into_expr(constructor: Constructor) -> None: assert_equal_data(result, expected) -def test_when_then_invalid(constructor: Constructor) -> None: +def test_when_then_broadcasting(constructor: Constructor) -> None: df = nw.from_native(constructor(data)) - with pytest.raises(InvalidOperationError): - df.select(nw.when(nw.col("a").sum() > 1).then("c")) - - with pytest.raises(InvalidOperationError): - df.select(nw.when(nw.col("a").sum() > 1).then(1).otherwise("c")) + result = df.select(nw.when(nw.col("a").sum() > 1).then("c")) + expected = {"c": [4.1, 5, 6]} + assert_equal_data(result, expected) + result = df.select(nw.when(nw.col("a").sum() > 1).then(1).otherwise("c")) + expected = {"literal": [1, 1, 1]} + assert_equal_data(result, expected) def test_when_then_otherwise_lit_str(constructor: Constructor) -> None: diff --git a/tests/expression_parsing_test.py b/tests/expression_parsing_test.py index 79b7f89b70..ee92fc35c6 100644 --- a/tests/expression_parsing_test.py +++ b/tests/expression_parsing_test.py @@ -4,96 +4,125 @@ import narwhals as nw from narwhals.exceptions import InvalidOperationError +from tests.utils import Constructor, assert_equal_data + +pytest.importorskip("polars") + +import polars as pl @pytest.mark.parametrize( - ("expr", "expected"), + ("expr", "pl_expr", "expected"), [ - (nw.col("a"), 0), - (nw.col("a").mean(), 0), - (nw.col("a").cum_sum(), 1), - (nw.col("a").cum_sum().over(order_by="id"), 0), - (nw.col("a").cum_sum().abs().over(order_by="id"), 1), - ((nw.col("a").cum_sum() + 1).over(order_by="id"), 1), - (nw.col("a").cum_sum().cum_sum().over(order_by="id"), 1), - (nw.col("a").cum_sum().cum_sum(), 2), - (nw.sum_horizontal(nw.col("a"), nw.col("a").cum_sum()), 1), - (nw.sum_horizontal(nw.col("a"), nw.col("a").cum_sum()).over(order_by="a"), 1), - (nw.sum_horizontal(nw.col("a"), nw.col("a").cum_sum().over(order_by="i")), 0), + (nw.col("a"), pl.col("a"), [-1, 2, 3]), + (nw.col("a").mean(), pl.col("a").mean(), [4 / 3, 4 / 3, 4 / 3]), ( - nw.sum_horizontal( - nw.col("a").diff(), nw.col("a").cum_sum().over(order_by="i") - ), - 1, + nw.col("a").cum_sum().over(order_by="i"), + pl.col("a").cum_sum().over(order_by="i"), + [-1, 1, 4], + ), + ( + nw.col("a").cum_sum().abs().over(order_by="i"), + pl.col("a").cum_sum().abs().over(order_by="i"), + [1, 1, 4], + ), + ( + (nw.col("a").cum_sum() + 1).over(order_by="i"), + (pl.col("a").cum_sum() + 1).over(order_by="i"), + [0, 2, 5], + ), + ( + nw.sum_horizontal(nw.col("a"), nw.col("a").cum_sum()).over(order_by="a"), + pl.sum_horizontal(pl.col("a"), pl.col("a").cum_sum()).over(order_by="a"), + [-2, 3, 7], + ), + ( + nw.sum_horizontal(nw.col("a"), nw.col("a").cum_sum().over(order_by="i")), + pl.sum_horizontal(pl.col("a"), pl.col("a").cum_sum().over(order_by="i")), + [-2, 3, 7], ), ( nw.sum_horizontal(nw.col("a").diff(), nw.col("a").cum_sum()).over( order_by="i" ), - 2, + pl.sum_horizontal(pl.col("a").diff(), pl.col("a").cum_sum()).over( + order_by="i" + ), + [-1.0, 4.0, 5.0], ), ( nw.sum_horizontal(nw.col("a").diff().abs(), nw.col("a").cum_sum()).over( order_by="i" ), - 2, + pl.sum_horizontal(pl.col("a").diff().abs(), pl.col("a").cum_sum()).over( + order_by="i" + ), + [-1.0, 4.0, 5.0], + ), + ( + (nw.col("a").sum() + nw.col("a").rolling_sum(2, min_samples=1)).over( + order_by="i" + ), + (pl.col("a").sum() + pl.col("a").rolling_sum(2, min_samples=1)).over( + order_by="i" + ), + [3.0, 5.0, 9.0], + ), + ( + (nw.col("a").sum() + nw.col("a").mean()).over("b"), + (pl.col("a").sum() + pl.col("a").mean()).over("b"), + [1.5, 1.5, 6.0], + ), + ( + (nw.col("a").mean().abs() + nw.sum_horizontal(nw.col("a").diff())).over( + order_by="i" + ), + (pl.col("a").mean().abs() + pl.sum_horizontal(pl.col("a").diff())).over( + order_by="i" + ), + [4 / 3, 13 / 3, 7 / 3], ), ], ) -def test_window_kind(expr: nw.Expr, expected: int) -> None: - assert expr._metadata.n_orderable_ops == expected - - -def test_misleading_order_by() -> None: - with pytest.raises(InvalidOperationError): - nw.col("a").mean().over(order_by="b") - - -def test_double_over() -> None: - with pytest.raises(InvalidOperationError): - nw.col("a").mean().over("b").over("c") - - -def test_double_agg() -> None: - with pytest.raises(InvalidOperationError): - nw.col("a").mean().mean() - with pytest.raises(InvalidOperationError): - nw.col("a").mean().sum() - +def test_over_pushdown( + constructor: Constructor, expr: nw.Expr, pl_expr: pl.Expr, expected: list[float] +) -> None: + data = {"a": [-1, 2, 3], "b": [1, 1, 2], "i": [0, 1, 2]} + df = nw.from_native(constructor(data)).lazy() + result = df.select("i", a=expr).sort("i").select("a") + assert_equal_data(result, {"a": expected}) -def test_filter_aggregation() -> None: - with pytest.raises(InvalidOperationError): - nw.col("a").mean().drop_nulls() + # Confirm that doing the calculation in pure-Polars produces the same result. + pl_result = {"a": pl.DataFrame(data).select("i", a=pl_expr).sort("i")["a"].to_list()} + assert_equal_data(result, pl_result) -def test_rank_aggregation() -> None: - with pytest.raises(InvalidOperationError): - nw.col("a").mean().rank() - with pytest.raises(InvalidOperationError): - nw.col("a").mean().is_unique() - - -def test_diff_aggregation() -> None: - with pytest.raises(InvalidOperationError): - nw.col("a").mean().diff() - - -def test_invalid_over() -> None: - with pytest.raises(InvalidOperationError): - nw.col("a").fill_null(3).over("b") - - -def test_nested_over() -> None: - with pytest.raises(InvalidOperationError): - nw.col("a").mean().over("b").over("c") - with pytest.raises(InvalidOperationError): - nw.col("a").mean().over("b").over("c", order_by="i") - - -def test_filtration_over() -> None: - with pytest.raises(InvalidOperationError): - nw.col("a").drop_nulls().over("b") - with pytest.raises(InvalidOperationError): - nw.col("a").drop_nulls().over("b", order_by="i") - with pytest.raises(InvalidOperationError): - nw.col("a").diff().drop_nulls().over("b", order_by="i") +@pytest.mark.parametrize( + "expr", + [ + nw.col("a").cum_sum(), + nw.col("a").cum_sum().cum_sum().over(order_by="i"), + nw.col("a").cum_sum().cum_sum(), + nw.sum_horizontal(nw.col("a"), nw.col("a").cum_sum()), + nw.sum_horizontal(nw.col("a").diff(), nw.col("a").cum_sum().over(order_by="i")), + nw.col("a").mean().over(order_by="i"), + nw.col("a").mean().over("b").over("c"), + nw.col("a").mean().over("b").over("c", order_by="i"), + nw.col("a").mean().mean(), + nw.col("a").mean().sum(), + nw.col("a").mean().drop_nulls(), + nw.col("a").mean().rank(), + nw.col("a").mean().is_unique(), + nw.col("a").mean().diff(), + nw.col("a").fill_null(3).over("b"), + nw.col("a").drop_nulls().over("b"), + nw.col("a").drop_nulls().over("b", order_by="i"), + nw.col("a").diff().drop_nulls().over("b", order_by="i"), + ], +) +def test_invalid_operations(constructor: Constructor, expr: nw.Expr) -> None: + df = nw.from_native( + constructor({"a": [-1, 2, 3], "b": [1, 1, 1], "c": [2, 2, 2], "i": [0, 1, 2]}) + ).lazy() + with pytest.raises((InvalidOperationError, NotImplementedError)): + df.select(a=expr) diff --git a/tests/frame/filter_test.py b/tests/frame/filter_test.py index d2b3fd34f3..48941941a0 100644 --- a/tests/frame/filter_test.py +++ b/tests/frame/filter_test.py @@ -49,7 +49,7 @@ def test_filter_raise_on_agg_predicate(constructor: Constructor) -> None: def test_filter_raise_on_shape_mismatch(constructor: Constructor) -> None: data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8.0, 9.0]} df = nw.from_native(constructor(data)) - with pytest.raises(InvalidOperationError): + with pytest.raises((InvalidOperationError, NotImplementedError)): df.filter(nw.col("b").unique() > 2).lazy().collect() diff --git a/tests/frame/group_by_test.py b/tests/frame/group_by_test.py index 183811dddf..1d2916da91 100644 --- a/tests/frame/group_by_test.py +++ b/tests/frame/group_by_test.py @@ -364,7 +364,7 @@ def test_group_by_shift_raises(constructor: Constructor) -> None: df_native = {"a": [1, 2, 3], "b": [1, 1, 2]} df = nw.from_native(constructor(df_native)) with pytest.raises(InvalidOperationError, match="does not aggregate"): - df.group_by("b").agg(nw.col("a").shift(1)) + df.group_by("b").agg(nw.col("a").abs()) def test_double_same_aggregation( @@ -513,21 +513,27 @@ def test_group_by_expr( @pytest.mark.parametrize( ("keys", "lazy_context"), [ - ([nw.col("a").drop_nulls()], pytest.raises(InvalidOperationError)), # Filtration + ( + [nw.col("a").drop_nulls()], + pytest.raises((InvalidOperationError, NotImplementedError)), + ), # Filtration ( [nw.col("a").alias("foo"), nw.col("a").drop_nulls()], - pytest.raises(InvalidOperationError), + pytest.raises((InvalidOperationError, NotImplementedError)), ), # Transform and Filtration ( [nw.col("a").alias("foo"), nw.col("a").max()], - pytest.raises(ComputeError), + pytest.raises((ComputeError, NotImplementedError)), ), # Transform and Aggregation ( [nw.col("a").alias("foo"), nw.col("a").cum_max()], - pytest.raises(InvalidOperationError), + pytest.raises((InvalidOperationError, NotImplementedError)), ), # Transform and Window - ([nw.lit(42)], pytest.raises(ComputeError)), # Literal - ([nw.lit(42).abs()], pytest.raises(ComputeError)), # Literal + ([nw.lit(42)], pytest.raises((ComputeError, NotImplementedError))), # Literal + ( + [nw.lit(42).abs()], + pytest.raises((ComputeError, NotImplementedError)), + ), # Literal ], ) def test_group_by_raise_if_not_elementwise( diff --git a/tests/selectors_test.py b/tests/selectors_test.py index 63c6b387e5..ba780299b8 100644 --- a/tests/selectors_test.py +++ b/tests/selectors_test.py @@ -1,6 +1,5 @@ from __future__ import annotations -import re from datetime import datetime, timezone from typing import Literal @@ -8,6 +7,7 @@ import narwhals as nw import narwhals.selectors as ncs +from narwhals.exceptions import MultiOutputExpressionError from tests.utils import ( PANDAS_VERSION, POLARS_VERSION, @@ -246,10 +246,7 @@ def test_set_ops_invalid(constructor: Constructor) -> None: with pytest.raises((NotImplementedError, ValueError)): df.select(1 & ncs.numeric()) - with pytest.raises( - TypeError, - match=re.escape("unsupported operand type(s) for op: ('Selector' + 'Selector')"), - ): + with pytest.raises(MultiOutputExpressionError): df.select(ncs.boolean() + ncs.numeric()) diff --git a/tests/utils.py b/tests/utils.py index 95c29537d4..399ee68622 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -147,7 +147,7 @@ def assert_equal_data(result: Any, expected: Mapping[str, Any]) -> None: are_equivalent_values = lhs == rhs assert are_equivalent_values, ( - f"Mismatch at index {i}: {lhs} != {rhs}\nExpected: {expected}\nGot: {result}" + f"Mismatch {key} at index {i}: {lhs} != {rhs}\nExpected: {expected}\nGot: {result}" ) diff --git a/tests/v1_test.py b/tests/v1_test.py index 87c3888bb5..5c7d68775c 100644 --- a/tests/v1_test.py +++ b/tests/v1_test.py @@ -894,8 +894,9 @@ def test_unique_series_v1() -> None: def test_head_aggregation() -> None: + df = nw.from_native(pd.DataFrame({"a": [1, 2]})) with pytest.raises(InvalidOperationError): - nw_v1.col("a").mean().head() + df.select(nw_v1.col("a").mean().head()) def test_deprecated_expr_methods() -> None: From 0cf73caf53cf91a6bb10241abb1e9aa718fa2abf Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Thu, 25 Sep 2025 20:03:15 +0300 Subject: [PATCH 02/95] coverage --- narwhals/_compliant/dataframe.py | 2 +- narwhals/_compliant/series.py | 5 +-- narwhals/_duckdb/expr.py | 6 ---- narwhals/_duckdb/utils.py | 4 --- narwhals/_expression_parsing.py | 62 +++----------------------------- narwhals/_ibis/expr_str.py | 2 -- narwhals/_polars/namespace.py | 25 +------------ narwhals/_polars/series.py | 6 ---- narwhals/_spark_like/expr.py | 6 ---- narwhals/_sql/dataframe.py | 2 +- narwhals/_sql/expr.py | 22 +++--------- narwhals/_utils.py | 2 +- narwhals/dataframe.py | 2 +- 13 files changed, 15 insertions(+), 131 deletions(-) diff --git a/narwhals/_compliant/dataframe.py b/narwhals/_compliant/dataframe.py index c369663fe5..2e5d7e1fa9 100644 --- a/narwhals/_compliant/dataframe.py +++ b/narwhals/_compliant/dataframe.py @@ -337,7 +337,7 @@ def _check_columns_exist(self, subset: Sequence[str]) -> ColumnNotFoundError | N def _evaluate_expr(self, expr: EagerExprT, /) -> EagerSeriesT: """Evaluate `expr` and ensure it has a **single** output.""" result: Sequence[EagerSeriesT] = expr(self) - if len(result) != 1: + if len(result) != 1: # pragma: no cover msg = "multi-output expressions not allowed in this context" raise MultiOutputExpressionError(msg) return result[0] diff --git a/narwhals/_compliant/series.py b/narwhals/_compliant/series.py index 2b0b02a3f1..b083ee0df1 100644 --- a/narwhals/_compliant/series.py +++ b/narwhals/_compliant/series.py @@ -37,7 +37,7 @@ from typing_extensions import NotRequired, Self, TypedDict from narwhals._compliant.dataframe import CompliantDataFrame - from narwhals._compliant.expr import CompliantExpr, EagerExpr + from narwhals._compliant.expr import CompliantExpr from narwhals._compliant.namespace import EagerNamespace from narwhals._utils import Implementation, Version, _LimitedContext from narwhals.dtypes import DType @@ -246,9 +246,6 @@ def __narwhals_namespace__( self, ) -> EagerNamespace[Any, Self, Any, Any, NativeSeriesT]: ... - def _to_expr(self) -> EagerExpr[Any, Any]: - return self.__narwhals_namespace__()._expr._from_series(self) # type: ignore[no-any-return] - def _gather(self, rows: SizedMultiIndexSelector[NativeSeriesT]) -> Self: ... def _gather_slice(self, rows: _SliceIndex | range) -> Self: ... def __getitem__(self, item: MultiIndexSelector[Self]) -> Self: diff --git a/narwhals/_duckdb/expr.py b/narwhals/_duckdb/expr.py index f86b5b21a3..a93058b6df 100644 --- a/narwhals/_duckdb/expr.py +++ b/narwhals/_duckdb/expr.py @@ -117,14 +117,8 @@ def from_column_names( def func(df: DuckDBLazyFrame) -> list[Expression]: return [col(name) for name in evaluate_column_names(df)] - def window_func( - df: DuckDBLazyFrame, _window_inputs: WindowInputs[Expression] - ) -> list[Expression]: - return [col(name) for name in evaluate_column_names(df)] - return cls( func, - window_func, evaluate_output_names=evaluate_column_names, alias_output_names=None, version=context._version, diff --git a/narwhals/_duckdb/utils.py b/narwhals/_duckdb/utils.py index 026a55fdb5..d3693a82f5 100644 --- a/narwhals/_duckdb/utils.py +++ b/narwhals/_duckdb/utils.py @@ -84,10 +84,6 @@ def evaluate_exprs_and_aliases( return native_results -def evaluate_exprs(df: DuckDBLazyFrame, /, *exprs: DuckDBExpr) -> list[Expression]: - return [item for expr in exprs for item in expr(df)] - - class DeferredTimeZone: """Object which gets passed between `native_to_narwhals_dtype` calls. diff --git a/narwhals/_expression_parsing.py b/narwhals/_expression_parsing.py index c6c890acae..982152ad36 100644 --- a/narwhals/_expression_parsing.py +++ b/narwhals/_expression_parsing.py @@ -393,8 +393,8 @@ def from_node( # noqa: PLR0911 return ExprMetadata.from_selector_multi_unnamed(node) if node.kind is ExprKind.ELEMENTWISE: return ExprMetadata.from_elementwise(node, *ces) - msg = f"Unexpected node kind: {node.kind}" - raise AssertionError(msg) + msg = f"Unexpected node kind: {node.kind}" # pragma: no cover + raise AssertionError(msg) # pragma: no cover def with_node( # noqa: PLR0911,C901 self, @@ -409,7 +409,6 @@ def with_node( # noqa: PLR0911,C901 ce, *ces, str_as_lit=node.str_as_lit, - allow_multi_output=node.allow_multi_output, to_single_output=False, nodes=(*ce._metadata.nodes, node), ) @@ -423,8 +422,6 @@ def with_node( # noqa: PLR0911,C901 return self.with_orderable_aggregation(node) if node.kind is ExprKind.WINDOW: return self.with_window(node) - if node.kind is ExprKind.SELECTOR: - return self if node.kind is ExprKind.OVER: if node.kwargs["order_by"]: return self.with_ordered_over(node) @@ -432,8 +429,8 @@ def with_node( # noqa: PLR0911,C901 msg = "At least one of `partition_by` or `order_by` must be specified." raise InvalidOperationError(msg) return self.with_partitioned_over(node) - msg = f"Unexpected node kind: {node.kind}" - raise AssertionError(msg) + msg = f"Unexpected node kind: {node.kind}" # pragma: no cover + raise AssertionError(msg) # pragma: no cover @classmethod def from_aggregation(cls, node: ExprNode) -> ExprMetadata: @@ -476,11 +473,7 @@ def from_elementwise( cls, node: ExprNode, *ces: CompliantExprAny | NonNestedLiteral ) -> ExprMetadata: return combine_metadata( - *ces, - str_as_lit=False, - allow_multi_output=True, - to_single_output=True, - nodes=(node,), + *ces, str_as_lit=False, to_single_output=True, nodes=(node,) ) @property @@ -518,18 +511,6 @@ def with_orderable_aggregation(self, node: ExprNode) -> ExprMetadata: nodes=(*self.nodes, node), ) - def with_elementwise_op(self, node: ExprNode) -> ExprMetadata: - return ExprMetadata( - self.expansion_kind, - has_windows=self.has_windows, - n_orderable_ops=self.n_orderable_ops, - preserves_length=self.preserves_length, - is_elementwise=self.is_elementwise, - is_scalar_like=self.is_scalar_like, - is_literal=self.is_literal, - nodes=(*self.nodes, node), - ) - def with_window(self, node: ExprNode) -> ExprMetadata: # Window function which may (but doesn't have to) be used with `over(order_by=...)`. if self.is_scalar_like: @@ -664,7 +645,6 @@ def op_nodes_reversed(self) -> Iterator[ExprNode]: def combine_metadata( *args: IntoExpr | object | None, str_as_lit: bool, - allow_multi_output: bool, to_single_output: bool, nodes: tuple[ExprNode, ...], ) -> ExprMetadata: @@ -673,7 +653,6 @@ def combine_metadata( Arguments: args: Arguments, maybe expressions, literals, or Series. str_as_lit: Whether to interpret strings as literals or as column names. - allow_multi_output: Whether to allow multi-output inputs. to_single_output: Whether the result is always single-output, regardless of the inputs (e.g. `nw.sum_horizontal`). nodes: Nodes of result node. @@ -701,13 +680,6 @@ def combine_metadata( assert metadata is not None # noqa: S101 if metadata.expansion_kind.is_multi_output(): expansion_kind = metadata.expansion_kind - if i > 0 and not allow_multi_output: - # Left-most argument is always allowed to be multi-output. - msg = ( - "Multi-output expressions (e.g. nw.col('a', 'b'), nw.all()) " - "are not supported in this context." - ) - raise MultiOutputExpressionError(msg) if not to_single_output: result_expansion_kind = ( result_expansion_kind & expansion_kind @@ -752,25 +724,6 @@ def check_expressions_preserve_length( raise InvalidOperationError(msg) -def all_exprs_are_scalar_like(mds: Sequence[ExprMetadata]) -> bool: - # Raise if any argument in `args` isn't an aggregation or literal. - # For Series input, we don't raise (yet), we let such checks happen later, - # as this function works lazily and so can't evaluate lengths. - return all(md.is_scalar_like for md in mds) - - -def apply_binary( - plx: CompliantNamespaceAny, - name: str, - ce: CompliantExprAny, - other: IntoExpr | NonNestedLiteral | _1DArray, -) -> CompliantExprAny: - parse = plx.evaluate_expr - other_compliant = parse(other) - compliant_exprs = [ce, other_compliant] - return getattr(compliant_exprs[0], name)(compliant_exprs[1]) - - @overload def _parse_into_expr( arg: IntoExpr | NonNestedLiteral | _1DArray, @@ -890,11 +843,6 @@ def evaluate_node( func = getattr(getattr(ce, accessor), method) else: func = getattr(ce, node.name) - if not node.allow_multi_output and any( - x._metadata.expansion_kind.is_multi_output() for x in ces if is_compliant_expr(x) - ): - msg = "multi-output expressions are not allowed as arguments to Expr methods." - raise MultiOutputExpressionError(msg) ret = cast("CompliantExprAny", func(*ces, **node.kwargs)) ret._opt_metadata = md return ret diff --git a/narwhals/_ibis/expr_str.py b/narwhals/_ibis/expr_str.py index 78907de7ef..312da1cccb 100644 --- a/narwhals/_ibis/expr_str.py +++ b/narwhals/_ibis/expr_str.py @@ -44,8 +44,6 @@ def replace_all( self, value: str | IbisExpr, pattern: str, *, literal: bool ) -> IbisExpr: fn = self._replace_all_literal if literal else self._replace_all - if isinstance(value, str): - return self.compliant._with_callable(fn(pattern, value)) return self.compliant._with_elementwise( lambda expr, value: fn(pattern, value)(expr), value=value ) diff --git a/narwhals/_polars/namespace.py b/narwhals/_polars/namespace.py index 0599196c7c..2413895f30 100644 --- a/narwhals/_polars/namespace.py +++ b/narwhals/_polars/namespace.py @@ -5,7 +5,7 @@ import polars as pl -from narwhals._expression_parsing import is_expr, is_series +from narwhals._expression_parsing import is_expr from narwhals._polars.expr import PolarsExpr from narwhals._polars.series import PolarsSeries from narwhals._polars.utils import extract_args_kwargs, narwhals_to_native_dtype @@ -22,14 +22,12 @@ from narwhals._polars.typing import FrameT from narwhals._utils import Version, _LimitedContext from narwhals.expr import Expr - from narwhals.series import Series from narwhals.typing import ( Into1DArray, IntoDType, IntoSchema, NonNestedLiteral, TimeUnit, - _1DArray, _2DArray, ) @@ -87,27 +85,6 @@ def _expr(self) -> type[PolarsExpr]: def _series(self) -> type[PolarsSeries]: return PolarsSeries - def parse_into_expr( - self, - data: Expr | NonNestedLiteral | Series[pl.Series] | _1DArray, - /, - *, - str_as_lit: bool, - ) -> PolarsExpr | None: - if data is None: - # NOTE: To avoid `pl.lit(None)` failing this `None` check - # https://github.com/pola-rs/polars/blob/58dd8e5770f16a9bef9009a1c05f00e15a5263c7/py-polars/polars/expr/expr.py#L2870-L2872 - return data - if isinstance(data, PolarsExpr): - return data - if is_expr(data): - expr = data(self) - assert isinstance(expr, self._expr) # noqa: S101 - return expr - if isinstance(data, str) and not str_as_lit: - return self.col([data]) - return self.lit(data.to_native() if is_series(data) else data, None) - @overload def from_native(self, data: pl.DataFrame, /) -> PolarsDataFrame: ... @overload diff --git a/narwhals/_polars/series.py b/narwhals/_polars/series.py index 222f15723e..2b2e7ce2da 100644 --- a/narwhals/_polars/series.py +++ b/narwhals/_polars/series.py @@ -4,7 +4,6 @@ import polars as pl -from narwhals._polars.expr import PolarsExpr from narwhals._polars.utils import ( BACKEND_VERSION, SERIES_ACCEPTS_PD_INDEX, @@ -33,7 +32,6 @@ import pyarrow as pa from typing_extensions import Self, TypeAlias, TypeIs - from narwhals._compliant.typing import CompliantExprAny from narwhals._polars.dataframe import Method, PolarsDataFrame from narwhals._polars.namespace import PolarsNamespace from narwhals._utils import Version, _LimitedContext @@ -151,10 +149,6 @@ def __init__(self, series: pl.Series, *, version: Version) -> None: self._native_series = series self._version = version - def _to_expr(self) -> CompliantExprAny: - # Polars can treat Series as Expr, so just pass down `self.native`. - return PolarsExpr(self.native, version=self._version) # type: ignore[arg-type] - @property def _backend_version(self) -> tuple[int, ...]: return self._implementation._backend_version() diff --git a/narwhals/_spark_like/expr.py b/narwhals/_spark_like/expr.py index 8a78b38b76..3a2b4bbb5b 100644 --- a/narwhals/_spark_like/expr.py +++ b/narwhals/_spark_like/expr.py @@ -171,14 +171,8 @@ def from_column_names( def func(df: SparkLikeLazyFrame) -> list[Column]: return [df._F.col(col_name) for col_name in evaluate_column_names(df)] - def window_func( - df: SparkLikeLazyFrame, _window_inputs: WindowInputs[Column] - ) -> list[Column]: - return [df._F.col(col_name) for col_name in evaluate_column_names(df)] - return cls( func, - window_func, evaluate_output_names=evaluate_column_names, alias_output_names=None, version=context._version, diff --git a/narwhals/_sql/dataframe.py b/narwhals/_sql/dataframe.py index 41eef9ef9a..cd3e2546aa 100644 --- a/narwhals/_sql/dataframe.py +++ b/narwhals/_sql/dataframe.py @@ -42,7 +42,7 @@ def _evaluate_window_expr( def _evaluate_expr(self, expr: SQLExpr[Self, NativeExprT], /) -> NativeExprT: result = expr(self) - if len(result) != 1: + if len(result) != 1: # pragma: no cover msg = "multi-output expressions not allowed in this context" raise MultiOutputExpressionError(msg) return result[0] diff --git a/narwhals/_sql/expr.py b/narwhals/_sql/expr.py index 292d822cbb..ba0c99cc10 100644 --- a/narwhals/_sql/expr.py +++ b/narwhals/_sql/expr.py @@ -22,7 +22,7 @@ if TYPE_CHECKING: from collections.abc import Sequence - from typing_extensions import Self, TypeIs + from typing_extensions import Self from narwhals._compliant.typing import AliasNames, WindowFunction from narwhals._expression_parsing import ExprMetadata @@ -290,10 +290,6 @@ def func( return func - @classmethod - def _is_expr(cls, obj: Self | Any) -> TypeIs[Self]: - return hasattr(obj, "__narwhals_expr__") - @property def _backend_version(self) -> tuple[int, ...]: return self._implementation._backend_version() @@ -489,25 +485,15 @@ def clip( lower_bound: Self | NumericLiteral | TemporalLiteral | None, upper_bound: Self | NumericLiteral | TemporalLiteral | None, ) -> Self: - def _clip_lower(expr: NativeExprT, lower_bound: Any) -> NativeExprT: - return self._function("greatest", expr, lower_bound) - - def _clip_upper(expr: NativeExprT, upper_bound: Any) -> NativeExprT: - return self._function("least", expr, upper_bound) - - def _clip_both( - expr: NativeExprT, lower_bound: Any, upper_bound: Any + def _clip( + expr: NativeExprT, lower_bound: NativeExprT, upper_bound: NativeExprT ) -> NativeExprT: return self._function( "greatest", self._function("least", expr, upper_bound), lower_bound ) - if lower_bound is None: - return self._with_elementwise(_clip_upper, upper_bound=upper_bound) - if upper_bound is None: - return self._with_elementwise(_clip_lower, lower_bound=lower_bound) return self._with_elementwise( - _clip_both, lower_bound=lower_bound, upper_bound=upper_bound + _clip, lower_bound=lower_bound, upper_bound=upper_bound ) def is_null(self) -> Self: diff --git a/narwhals/_utils.py b/narwhals/_utils.py index c9d17e367a..346008cc81 100644 --- a/narwhals/_utils.py +++ b/narwhals/_utils.py @@ -1796,7 +1796,7 @@ def __call__(self, *args: Any, **kwds: Any) -> Any: return self.__get__("raise") @classmethod - def deprecated(cls, message: LiteralString, /) -> Self: + def deprecated(cls, message: LiteralString, /) -> Self: # pragma: no cover """Alt constructor, wraps with `@deprecated`. Arguments: diff --git a/narwhals/dataframe.py b/narwhals/dataframe.py index d46a072cbc..5d22380b1f 100644 --- a/narwhals/dataframe.py +++ b/narwhals/dataframe.py @@ -173,7 +173,7 @@ def _check_columns_exist(self, subset: Sequence[str]) -> ColumnNotFoundError | N return check_columns_exist(subset, available=self.columns) @abstractmethod - def _validate_metadata(self, metadata: ExprMetadata) -> None: + def _validate_metadata(self, metadata: ExprMetadata) -> None: # pragma: no cover pass @property From 52f978ef10615223b0bd7334219c9acf6fc448f1 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Thu, 25 Sep 2025 20:10:50 +0300 Subject: [PATCH 03/95] typing --- narwhals/_compliant/series.py | 3 --- narwhals/_expression_parsing.py | 4 ++-- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/narwhals/_compliant/series.py b/narwhals/_compliant/series.py index b083ee0df1..561ff3de6d 100644 --- a/narwhals/_compliant/series.py +++ b/narwhals/_compliant/series.py @@ -37,7 +37,6 @@ from typing_extensions import NotRequired, Self, TypedDict from narwhals._compliant.dataframe import CompliantDataFrame - from narwhals._compliant.expr import CompliantExpr from narwhals._compliant.namespace import EagerNamespace from narwhals._utils import Implementation, Version, _LimitedContext from narwhals.dtypes import DType @@ -97,8 +96,6 @@ def to_narwhals(self) -> Series[NativeSeriesT]: def _with_native(self, series: Any) -> Self: ... def _with_version(self, version: Version) -> Self: ... - def _to_expr(self) -> CompliantExpr[Any, Self]: ... - # NOTE: `polars` @property def dtype(self) -> DType: ... diff --git a/narwhals/_expression_parsing.py b/narwhals/_expression_parsing.py index 982152ad36..a2bd3824ed 100644 --- a/narwhals/_expression_parsing.py +++ b/narwhals/_expression_parsing.py @@ -5,7 +5,7 @@ from __future__ import annotations from enum import Enum, auto -from typing import TYPE_CHECKING, Any, Literal, ParamSpec, TypeVar, cast, overload +from typing import TYPE_CHECKING, Any, Literal, TypeVar, cast, overload from narwhals._utils import is_compliant_expr, zip_strict from narwhals.dependencies import is_numpy_array_1d @@ -18,7 +18,7 @@ if TYPE_CHECKING: from collections.abc import Iterator, Sequence - from typing_extensions import Never, TypeIs + from typing_extensions import Never, ParamSpec, TypeIs from narwhals._compliant import CompliantExpr, CompliantFrameT from narwhals._compliant.typing import ( From f6ce1968ff709e6f1298bf40f3c16df5f4c20173 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Thu, 25 Sep 2025 20:16:04 +0300 Subject: [PATCH 04/95] coverage --- narwhals/_expression_parsing.py | 59 ++++++++++++--------------------- 1 file changed, 22 insertions(+), 37 deletions(-) diff --git a/narwhals/_expression_parsing.py b/narwhals/_expression_parsing.py index a2bd3824ed..fa5eaed4ec 100644 --- a/narwhals/_expression_parsing.py +++ b/narwhals/_expression_parsing.py @@ -406,11 +406,7 @@ def with_node( # noqa: PLR0911,C901 return self.with_aggregation(node) if node.kind is ExprKind.ELEMENTWISE: return combine_metadata( - ce, - *ces, - str_as_lit=node.str_as_lit, - to_single_output=False, - nodes=(*ce._metadata.nodes, node), + ce, *ces, to_single_output=False, nodes=(*ce._metadata.nodes, node) ) if node.kind is ExprKind.FILTRATION: return self.with_filtration(node) @@ -472,9 +468,7 @@ def from_selector_multi_unnamed(cls, node: ExprNode) -> ExprMetadata: def from_elementwise( cls, node: ExprNode, *ces: CompliantExprAny | NonNestedLiteral ) -> ExprMetadata: - return combine_metadata( - *ces, str_as_lit=False, to_single_output=True, nodes=(node,) - ) + return combine_metadata(*ces, to_single_output=True, nodes=(node,)) @property def is_filtration(self) -> bool: @@ -620,7 +614,7 @@ def with_filtration(self, node: ExprNode) -> ExprMetadata: ) def with_orderable_filtration(self, node: ExprNode) -> ExprMetadata: - if self.is_scalar_like: + if self.is_scalar_like: # pragma: no cover msg = "Can't apply filtration (e.g. `drop_nulls`) to scalar-like expression." raise InvalidOperationError(msg) return ExprMetadata( @@ -643,16 +637,12 @@ def op_nodes_reversed(self) -> Iterator[ExprNode]: def combine_metadata( - *args: IntoExpr | object | None, - str_as_lit: bool, - to_single_output: bool, - nodes: tuple[ExprNode, ...], + *args: IntoExpr | object | None, to_single_output: bool, nodes: tuple[ExprNode, ...] ) -> ExprMetadata: """Combine metadata from `args`. Arguments: args: Arguments, maybe expressions, literals, or Series. - str_as_lit: Whether to interpret strings as literals or as column names. to_single_output: Whether the result is always single-output, regardless of the inputs (e.g. `nw.sum_horizontal`). nodes: Nodes of result node. @@ -671,29 +661,24 @@ def combine_metadata( result_is_literal = True for i, arg in enumerate(args): - if (isinstance(arg, str) and not str_as_lit) or is_series(arg): - result_preserves_length = True - result_is_scalar_like = False - result_is_literal = False - elif is_compliant_expr(arg): - metadata = arg._metadata - assert metadata is not None # noqa: S101 - if metadata.expansion_kind.is_multi_output(): - expansion_kind = metadata.expansion_kind - if not to_single_output: - result_expansion_kind = ( - result_expansion_kind & expansion_kind - if i > 0 - else expansion_kind - ) - - result_has_windows |= metadata.has_windows - result_n_orderable_ops += metadata.n_orderable_ops - result_preserves_length |= metadata.preserves_length - result_is_elementwise &= metadata.is_elementwise - result_is_scalar_like &= metadata.is_scalar_like - result_is_literal &= metadata.is_literal - n_filtrations += int(metadata.is_filtration) + if not is_compliant_expr(arg): + continue + metadata = arg._metadata + assert metadata is not None # noqa: S101 + if metadata.expansion_kind.is_multi_output(): + expansion_kind = metadata.expansion_kind + if not to_single_output: + result_expansion_kind = ( + result_expansion_kind & expansion_kind if i > 0 else expansion_kind + ) + + result_has_windows |= metadata.has_windows + result_n_orderable_ops += metadata.n_orderable_ops + result_preserves_length |= metadata.preserves_length + result_is_elementwise &= metadata.is_elementwise + result_is_scalar_like &= metadata.is_scalar_like + result_is_literal &= metadata.is_literal + n_filtrations += int(metadata.is_filtration) if n_filtrations > 1: msg = "Length-changing expressions can only be used in isolation, or followed by an aggregation" raise InvalidOperationError(msg) From 6169b23fc546b390d3b27973a6876fef299e38bd Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Thu, 25 Sep 2025 20:21:28 +0300 Subject: [PATCH 05/95] typing again --- narwhals/_expression_parsing.py | 8 ++------ narwhals/expr_str.py | 4 +--- 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/narwhals/_expression_parsing.py b/narwhals/_expression_parsing.py index fa5eaed4ec..2d927508c0 100644 --- a/narwhals/_expression_parsing.py +++ b/narwhals/_expression_parsing.py @@ -5,7 +5,7 @@ from __future__ import annotations from enum import Enum, auto -from typing import TYPE_CHECKING, Any, Literal, TypeVar, cast, overload +from typing import TYPE_CHECKING, Any, Literal, cast, overload from narwhals._utils import is_compliant_expr, zip_strict from narwhals.dependencies import is_numpy_array_1d @@ -18,7 +18,7 @@ if TYPE_CHECKING: from collections.abc import Iterator, Sequence - from typing_extensions import Never, ParamSpec, TypeIs + from typing_extensions import Never, TypeIs from narwhals._compliant import CompliantExpr, CompliantFrameT from narwhals._compliant.typing import ( @@ -32,10 +32,6 @@ from narwhals.series import Series from narwhals.typing import IntoExpr, NonNestedLiteral, _1DArray - T = TypeVar("T") - PS = ParamSpec("PS") - R = TypeVar("R") - def is_expr(obj: Any) -> TypeIs[Expr]: """Check whether `obj` is a Narwhals Expr.""" diff --git a/narwhals/expr_str.py b/narwhals/expr_str.py index b8cb7ce78d..50d54c1635 100644 --- a/narwhals/expr_str.py +++ b/narwhals/expr_str.py @@ -1,14 +1,12 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Generic, ParamSpec, TypeVar +from typing import TYPE_CHECKING, Generic, TypeVar from narwhals._expression_parsing import ExprKind, ExprNode if TYPE_CHECKING: from narwhals.expr import Expr - PS = ParamSpec("PS") - ExprT = TypeVar("ExprT", bound="Expr") From c31c5e9a82859ebd936062f7c018f7714bfcdd80 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Thu, 25 Sep 2025 21:38:34 +0300 Subject: [PATCH 06/95] revert accidental change --- narwhals/expr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/narwhals/expr.py b/narwhals/expr.py index 1c8b8f936c..cd7e899475 100644 --- a/narwhals/expr.py +++ b/narwhals/expr.py @@ -120,7 +120,7 @@ def alias(self, name: str) -> Self: Examples: >>> import pandas as pd >>> import narwhals as nw - >>> df_native = pd.DataFrame({"a": [], "b": [4, 5]}) + >>> df_native = pd.DataFrame({"a": [1, 2], "b": [4, 5]}) >>> df = nw.from_native(df_native) >>> df.select((nw.col("b") + 10).alias("c")) ┌──────────────────┐ From ed29d1c76ca27deb7f6e25ea263ebeabf9e90bbc Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Thu, 25 Sep 2025 22:34:08 +0300 Subject: [PATCH 07/95] skip old polars --- narwhals/expr.py | 2 +- narwhals/series.py | 7 ++----- tests/expression_parsing_test.py | 6 +++++- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/narwhals/expr.py b/narwhals/expr.py index cd7e899475..f3595a1baa 100644 --- a/narwhals/expr.py +++ b/narwhals/expr.py @@ -2088,7 +2088,7 @@ def sqrt(self) -> Self: def is_close( # noqa: PLR0914 self, - other: Self | NumericLiteral, + other: IntoExpr | NumericLiteral, *, abs_tol: float = 0.0, rel_tol: float = 1e-09, diff --git a/narwhals/series.py b/narwhals/series.py index c4709b6dd1..6eaf95fddd 100644 --- a/narwhals/series.py +++ b/narwhals/series.py @@ -13,7 +13,7 @@ overload, ) -from narwhals._expression_parsing import ExprKind, ExprNode, is_series +from narwhals._expression_parsing import ExprKind, ExprNode from narwhals._utils import ( Implementation, Version, @@ -2715,10 +2715,7 @@ def is_close( "Self", self.to_frame().select( col(self.name).is_close( - other._to_expr() if is_series(other) else other, - abs_tol=abs_tol, - rel_tol=rel_tol, - nans_equal=nans_equal, + other, abs_tol=abs_tol, rel_tol=rel_tol, nans_equal=nans_equal ) )[self.name], ) diff --git a/tests/expression_parsing_test.py b/tests/expression_parsing_test.py index ee92fc35c6..06be976725 100644 --- a/tests/expression_parsing_test.py +++ b/tests/expression_parsing_test.py @@ -4,7 +4,7 @@ import narwhals as nw from narwhals.exceptions import InvalidOperationError -from tests.utils import Constructor, assert_equal_data +from tests.utils import POLARS_VERSION, Constructor, assert_equal_data pytest.importorskip("polars") @@ -87,6 +87,8 @@ def test_over_pushdown( constructor: Constructor, expr: nw.Expr, pl_expr: pl.Expr, expected: list[float] ) -> None: + if "polars" in str(constructor) and POLARS_VERSION < (1, 10): + pytest.skip() data = {"a": [-1, 2, 3], "b": [1, 1, 2], "i": [0, 1, 2]} df = nw.from_native(constructor(data)).lazy() result = df.select("i", a=expr).sort("i").select("a") @@ -121,6 +123,8 @@ def test_over_pushdown( ], ) def test_invalid_operations(constructor: Constructor, expr: nw.Expr) -> None: + if "polars" in str(constructor) and POLARS_VERSION < (1, 10): + pytest.skip() df = nw.from_native( constructor({"a": [-1, 2, 3], "b": [1, 1, 1], "c": [2, 2, 2], "i": [0, 1, 2]}) ).lazy() From f90c13b90170a1a42eadcc8a15183d1f16aa5b57 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Fri, 26 Sep 2025 09:41:44 +0300 Subject: [PATCH 08/95] old vs --- tests/expression_parsing_test.py | 54 +++++--------------------------- 1 file changed, 8 insertions(+), 46 deletions(-) diff --git a/tests/expression_parsing_test.py b/tests/expression_parsing_test.py index 06be976725..9c0c63f6cc 100644 --- a/tests/expression_parsing_test.py +++ b/tests/expression_parsing_test.py @@ -6,86 +6,52 @@ from narwhals.exceptions import InvalidOperationError from tests.utils import POLARS_VERSION, Constructor, assert_equal_data -pytest.importorskip("polars") - -import polars as pl - @pytest.mark.parametrize( - ("expr", "pl_expr", "expected"), + ("expr", "expected"), [ - (nw.col("a"), pl.col("a"), [-1, 2, 3]), - (nw.col("a").mean(), pl.col("a").mean(), [4 / 3, 4 / 3, 4 / 3]), - ( - nw.col("a").cum_sum().over(order_by="i"), - pl.col("a").cum_sum().over(order_by="i"), - [-1, 1, 4], - ), - ( - nw.col("a").cum_sum().abs().over(order_by="i"), - pl.col("a").cum_sum().abs().over(order_by="i"), - [1, 1, 4], - ), - ( - (nw.col("a").cum_sum() + 1).over(order_by="i"), - (pl.col("a").cum_sum() + 1).over(order_by="i"), - [0, 2, 5], - ), + (nw.col("a"), [-1, 2, 3]), + (nw.col("a").mean(), [4 / 3, 4 / 3, 4 / 3]), + (nw.col("a").cum_sum().over(order_by="i"), [-1, 1, 4]), + (nw.col("a").cum_sum().abs().over(order_by="i"), [1, 1, 4]), + ((nw.col("a").cum_sum() + 1).over(order_by="i"), [0, 2, 5]), ( nw.sum_horizontal(nw.col("a"), nw.col("a").cum_sum()).over(order_by="a"), - pl.sum_horizontal(pl.col("a"), pl.col("a").cum_sum()).over(order_by="a"), [-2, 3, 7], ), ( nw.sum_horizontal(nw.col("a"), nw.col("a").cum_sum().over(order_by="i")), - pl.sum_horizontal(pl.col("a"), pl.col("a").cum_sum().over(order_by="i")), [-2, 3, 7], ), ( nw.sum_horizontal(nw.col("a").diff(), nw.col("a").cum_sum()).over( order_by="i" ), - pl.sum_horizontal(pl.col("a").diff(), pl.col("a").cum_sum()).over( - order_by="i" - ), [-1.0, 4.0, 5.0], ), ( nw.sum_horizontal(nw.col("a").diff().abs(), nw.col("a").cum_sum()).over( order_by="i" ), - pl.sum_horizontal(pl.col("a").diff().abs(), pl.col("a").cum_sum()).over( - order_by="i" - ), [-1.0, 4.0, 5.0], ), ( (nw.col("a").sum() + nw.col("a").rolling_sum(2, min_samples=1)).over( order_by="i" ), - (pl.col("a").sum() + pl.col("a").rolling_sum(2, min_samples=1)).over( - order_by="i" - ), [3.0, 5.0, 9.0], ), - ( - (nw.col("a").sum() + nw.col("a").mean()).over("b"), - (pl.col("a").sum() + pl.col("a").mean()).over("b"), - [1.5, 1.5, 6.0], - ), + ((nw.col("a").sum() + nw.col("a").mean()).over("b"), [1.5, 1.5, 6.0]), ( (nw.col("a").mean().abs() + nw.sum_horizontal(nw.col("a").diff())).over( order_by="i" ), - (pl.col("a").mean().abs() + pl.sum_horizontal(pl.col("a").diff())).over( - order_by="i" - ), [4 / 3, 13 / 3, 7 / 3], ), ], ) def test_over_pushdown( - constructor: Constructor, expr: nw.Expr, pl_expr: pl.Expr, expected: list[float] + constructor: Constructor, expr: nw.Expr, expected: list[float] ) -> None: if "polars" in str(constructor) and POLARS_VERSION < (1, 10): pytest.skip() @@ -94,10 +60,6 @@ def test_over_pushdown( result = df.select("i", a=expr).sort("i").select("a") assert_equal_data(result, {"a": expected}) - # Confirm that doing the calculation in pure-Polars produces the same result. - pl_result = {"a": pl.DataFrame(data).select("i", a=pl_expr).sort("i")["a"].to_list()} - assert_equal_data(result, pl_result) - @pytest.mark.parametrize( "expr", From 29437845d7857da85c2f5090fac9e5b4a5d327c5 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Fri, 26 Sep 2025 11:05:57 +0300 Subject: [PATCH 09/95] fix dataframe to numpy --- narwhals/_pandas_like/dataframe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/narwhals/_pandas_like/dataframe.py b/narwhals/_pandas_like/dataframe.py index a8250c9736..abcbb12c01 100644 --- a/narwhals/_pandas_like/dataframe.py +++ b/narwhals/_pandas_like/dataframe.py @@ -886,7 +886,7 @@ def to_numpy(self, dtype: Any = None, *, copy: bool | None = None) -> _2DArray: if to_convert: df = self.with_columns( self.__narwhals_namespace__() - .col(*to_convert) + .col(to_convert) .dt.convert_time_zone("UTC") .dt.replace_time_zone(None) ).native From dea9e3e901940e68c7e3eeb5cedab96141aa739a Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Fri, 26 Sep 2025 11:35:37 +0300 Subject: [PATCH 10/95] document ExprNode --- narwhals/_arrow/dataframe.py | 4 ++-- narwhals/_arrow/series.py | 4 ++-- narwhals/_compliant/namespace.py | 5 ++-- narwhals/_dask/dataframe.py | 12 ++++------ narwhals/_expression_parsing.py | 41 ++++++++++++++++++++++---------- narwhals/_pandas_like/expr.py | 2 +- narwhals/_polars/namespace.py | 4 +--- narwhals/_polars/series.py | 6 ++--- 8 files changed, 45 insertions(+), 33 deletions(-) diff --git a/narwhals/_arrow/dataframe.py b/narwhals/_arrow/dataframe.py index cf7bd383d4..9eed006edc 100644 --- a/narwhals/_arrow/dataframe.py +++ b/narwhals/_arrow/dataframe.py @@ -424,7 +424,7 @@ def drop_nulls(self, subset: Sequence[str] | None) -> Self: if subset is None: return self._with_native(self.native.drop_null(), validate_column_names=False) plx = self.__narwhals_namespace__() - mask = ~plx.any_horizontal(plx.col(subset).is_null(), ignore_nulls=True) + mask = ~plx.any_horizontal(plx.col(*subset).is_null(), ignore_nulls=True) return self.filter(mask) def sort(self, *by: str, descending: bool | Sequence[bool], nulls_last: bool) -> Self: @@ -496,7 +496,7 @@ def with_row_index(self, name: str, order_by: Sequence[str] | None) -> Self: plx._series.from_iterable(data, context=self, name=name) ) else: - rank = plx.col([order_by[0]]).rank("ordinal", descending=False) + rank = plx.col(order_by[0]).rank("ordinal", descending=False) row_index = (rank.over(partition_by=[], order_by=order_by) - 1).alias(name) return self.select(row_index, plx.all()) diff --git a/narwhals/_arrow/series.py b/narwhals/_arrow/series.py index 2ae902d442..9f63776c40 100644 --- a/narwhals/_arrow/series.py +++ b/narwhals/_arrow/series.py @@ -866,8 +866,8 @@ def mode(self, *, keep: ModeKeepStrategy) -> ArrowSeries: name=col_token, normalize=False, sort=False, parallel=False ) result = counts.filter( - plx.col([col_token]) - == plx.col([col_token]).max().broadcast(kind=ExprKind.AGGREGATION) + plx.col(col_token) + == plx.col(col_token).max().broadcast(kind=ExprKind.AGGREGATION) ).get_column(self.name) return result.head(1) if keep == "any" else result diff --git a/narwhals/_compliant/namespace.py b/narwhals/_compliant/namespace.py index 9d767b5e80..25fd753a28 100644 --- a/narwhals/_compliant/namespace.py +++ b/narwhals/_compliant/namespace.py @@ -74,8 +74,7 @@ def evaluate_expr( def all(self) -> CompliantExprT: return self._expr.from_column_names(get_column_names, context=self) - def col(self, names: Sequence[str]) -> CompliantExprT: - assert not isinstance(names, str) # noqa: S101 # debug assertion + def col(self, *names: str) -> CompliantExprT: return self._expr.from_column_names(passthrough_column_names(names), context=self) def exclude(self, names: Sequence[str]) -> CompliantExprT: @@ -118,7 +117,7 @@ class DepthTrackingNamespace( def all(self) -> DepthTrackingExprT: return self._expr.from_column_names(get_column_names, context=self) - def col(self, names: Sequence[str]) -> DepthTrackingExprT: + def col(self, *names: str) -> DepthTrackingExprT: return self._expr.from_column_names(passthrough_column_names(names), context=self) def exclude(self, names: Sequence[str]) -> DepthTrackingExprT: diff --git a/narwhals/_dask/dataframe.py b/narwhals/_dask/dataframe.py index 4da9097f14..376bf6a887 100644 --- a/narwhals/_dask/dataframe.py +++ b/narwhals/_dask/dataframe.py @@ -193,7 +193,7 @@ def drop_nulls(self, subset: Sequence[str] | None) -> Self: if subset is None: return self._with_native(self.native.dropna()) plx = self.__narwhals_namespace__() - mask = ~plx.any_horizontal(plx.col(subset).is_null(), ignore_nulls=True) + mask = ~plx.any_horizontal(plx.col(*subset).is_null(), ignore_nulls=True) return self.filter(mask) @property @@ -225,12 +225,10 @@ def with_row_index(self, name: str, order_by: Sequence[str] | None) -> Self: columns = self.columns const_expr = plx.lit(value=1, dtype=None).alias(name).broadcast(ExprKind.LITERAL) row_index_expr = ( - plx.col([name]) - .cum_sum(reverse=False) - .over(partition_by=[], order_by=order_by) + plx.col(name).cum_sum(reverse=False).over(partition_by=[], order_by=order_by) - 1 ) - return self.with_columns(const_expr).select(row_index_expr, plx.col(columns)) + return self.with_columns(const_expr).select(row_index_expr, plx.col(*columns)) def rename(self, mapping: Mapping[str, str]) -> Self: return self._with_native(self.native.rename(columns=mapping)) @@ -484,8 +482,8 @@ def gather_every(self, n: int, offset: int) -> Self: return ( self.with_row_index(row_index_token, order_by=None) .filter( - (plx.col([row_index_token]) >= offset) - & ((plx.col([row_index_token]) - offset) % n == 0) + (plx.col(row_index_token) >= offset) + & ((plx.col(row_index_token) - offset) % n == 0) ) .drop([row_index_token], strict=False) ) diff --git a/narwhals/_expression_parsing.py b/narwhals/_expression_parsing.py index 2d927508c0..6e1f4dc936 100644 --- a/narwhals/_expression_parsing.py +++ b/narwhals/_expression_parsing.py @@ -216,6 +216,18 @@ def __and__(self, other: ExpansionKind) -> Literal[ExpansionKind.MULTI_UNNAMED]: class ExprNode: + """An operation to create or modify an expression. + + Parameters: + kind: ExprKind of operation. + name: Name of function, as defined in the compliant protocols. + exprs: Expressifiable arguments to function. + str_as_lit: Whether to interpret strings as literals when they + are present in `exprs`. + allow_multi_output: Whether to allow any of `exprs` to be multi-output. + kwargs: Other (non-expressifiable) arguments to function. + """ + def __init__( self, kind: ExprKind, @@ -785,20 +797,25 @@ def maybe_broadcast_ces( def evaluate_root_node(node: ExprNode, ns: CompliantNamespaceAny) -> CompliantExprAny: - if "." in node.name: - module, method = node.name.split(".") - func = getattr(getattr(ns, module), method) + if node.name == "col": + # There's too much potential for Sequence[str] vs str bugs. + ce = getattr(ns, node.name)(*node.kwargs["names"]) + ces = [] else: - func = getattr(ns, node.name) - ces = maybe_broadcast_ces( - *evaluate_into_exprs( - *node.exprs, - ns=ns, - str_as_lit=node.str_as_lit, - allow_multi_output=node.allow_multi_output, + if "." in node.name: + module, method = node.name.split(".") + func = getattr(getattr(ns, module), method) + else: + func = getattr(ns, node.name) + ces = maybe_broadcast_ces( + *evaluate_into_exprs( + *node.exprs, + ns=ns, + str_as_lit=node.str_as_lit, + allow_multi_output=node.allow_multi_output, + ) ) - ) - ce = cast("CompliantExpr[Any, Any]", func(*ces, **node.kwargs)) + ce = cast("CompliantExprAny", func(*ces, **node.kwargs)) md = ExprMetadata.from_node(node, *ces) ce._opt_metadata = md return ce diff --git a/narwhals/_pandas_like/expr.py b/narwhals/_pandas_like/expr.py index 1b06111c14..ddbe20c7a0 100644 --- a/narwhals/_pandas_like/expr.py +++ b/narwhals/_pandas_like/expr.py @@ -255,7 +255,7 @@ def func(df: PandasLikeDataFrame) -> Sequence[PandasLikeSeries]: # noqa: C901, output_names, aliases = evaluate_output_names_and_aliases(self, df, []) if function_name == "cum_count": plx = self.__narwhals_namespace__() - df = df.with_columns(~plx.col(output_names).is_null()) + df = df.with_columns(~plx.col(*output_names).is_null()) if function_name.startswith("cum_"): assert "reverse" in scalar_kwargs # noqa: S101 diff --git a/narwhals/_polars/namespace.py b/narwhals/_polars/namespace.py index 2413895f30..3ecd7f6029 100644 --- a/narwhals/_polars/namespace.py +++ b/narwhals/_polars/namespace.py @@ -35,6 +35,7 @@ class PolarsNamespace: all: Method[PolarsExpr] coalesce: Method[PolarsExpr] + col: Method[PolarsExpr] sum_horizontal: Method[PolarsExpr] min_horizontal: Method[PolarsExpr] max_horizontal: Method[PolarsExpr] @@ -121,9 +122,6 @@ def from_numpy( return self._dataframe.from_numpy(data, schema=schema, context=self) return self._series.from_numpy(data, context=self) # pragma: no cover - def col(self, names: Sequence[str]) -> PolarsExpr: - return self._expr(pl.col(*names), version=self._version) - def exclude(self, names: Sequence[str]) -> PolarsExpr: return self._expr(pl.exclude(*names), version=self._version) diff --git a/narwhals/_polars/series.py b/narwhals/_polars/series.py index 2b2e7ce2da..fadc4bac76 100644 --- a/narwhals/_polars/series.py +++ b/narwhals/_polars/series.py @@ -743,7 +743,7 @@ class PolarsSeriesStringNamespace( def zfill(self, width: int) -> PolarsSeries: name = self.name ns = self.__narwhals_namespace__() - return self.to_frame().select(ns.col([name]).str.zfill(width)).get_column(name) + return self.to_frame().select(ns.col(name).str.zfill(width)).get_column(name) def replace(self, value: str, pattern: str, *, literal: bool, n: int) -> PolarsSeries: return self.compliant._with_native( @@ -767,12 +767,12 @@ class PolarsSeriesListNamespace( def len(self) -> PolarsSeries: name = self.name ns = self.__narwhals_namespace__() - return self.to_frame().select(ns.col([name]).list.len()).get_column(name) + return self.to_frame().select(ns.col(name).list.len()).get_column(name) def contains(self, item: NonNestedLiteral) -> PolarsSeries: name = self.name ns = self.__narwhals_namespace__() - return self.to_frame().select(ns.col([name]).list.contains(item)).get_column(name) + return self.to_frame().select(ns.col(name).list.contains(item)).get_column(name) class PolarsSeriesStructNamespace( From 4048ae6a39ce9f6cb32481640f1bd5142ce6e5b9 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Fri, 26 Sep 2025 11:44:51 +0300 Subject: [PATCH 11/95] safer `col`, fix typing --- narwhals/_pandas_like/dataframe.py | 4 ++-- narwhals/expr.py | 11 ++++++----- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/narwhals/_pandas_like/dataframe.py b/narwhals/_pandas_like/dataframe.py index abcbb12c01..d4db72b894 100644 --- a/narwhals/_pandas_like/dataframe.py +++ b/narwhals/_pandas_like/dataframe.py @@ -410,7 +410,7 @@ def drop_nulls(self, subset: Sequence[str] | None) -> Self: self.native.dropna(axis=0), validate_column_names=False ) plx = self.__narwhals_namespace__() - mask = ~plx.any_horizontal(plx.col(subset).is_null(), ignore_nulls=True) + mask = ~plx.any_horizontal(plx.col(*subset).is_null(), ignore_nulls=True) return self.filter(mask) def estimated_size(self, unit: SizeUnit) -> int | float: @@ -886,7 +886,7 @@ def to_numpy(self, dtype: Any = None, *, copy: bool | None = None) -> _2DArray: if to_convert: df = self.with_columns( self.__narwhals_namespace__() - .col(to_convert) + .col(*to_convert) .dt.convert_time_zone("UTC") .dt.replace_time_zone(None) ).native diff --git a/narwhals/expr.py b/narwhals/expr.py index f3595a1baa..99e3585551 100644 --- a/narwhals/expr.py +++ b/narwhals/expr.py @@ -28,6 +28,7 @@ from narwhals._compliant import CompliantExpr, CompliantNamespace from narwhals.dtypes import DType + from narwhals.series import Series from narwhals.typing import ( ClosedInterval, FillNullStrategy, @@ -2088,7 +2089,7 @@ def sqrt(self) -> Self: def is_close( # noqa: PLR0914 self, - other: IntoExpr | NumericLiteral, + other: Expr | Series[Any] | NumericLiteral, *, abs_tol: float = 0.0, rel_tol: float = 1e-09, @@ -2158,10 +2159,10 @@ def is_close( # noqa: PLR0914 from decimal import Decimal - other_abs: Self | NumericLiteral - other_is_nan: Self | bool - other_is_inf: Self | bool - other_is_not_inf: Self | bool + other_abs: Expr | Series[Any] | NumericLiteral + other_is_nan: Expr | Series[Any] | bool + other_is_inf: Expr | Series[Any] | bool + other_is_not_inf: Expr | Series[Any] | bool if isinstance(other, (float, int, Decimal)): from math import isinf, isnan From 906f7fbb264506e6c67e7f03d988b98cdb0cfe20 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Fri, 26 Sep 2025 11:46:43 +0300 Subject: [PATCH 12/95] :art: --- narwhals/series.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/narwhals/series.py b/narwhals/series.py index 6eaf95fddd..6e749bffd8 100644 --- a/narwhals/series.py +++ b/narwhals/series.py @@ -2711,14 +2711,12 @@ def is_close( "Hint: `is_close` is only supported for numeric types" ) raise InvalidOperationError(msg) - return cast( - "Self", - self.to_frame().select( - col(self.name).is_close( - other, abs_tol=abs_tol, rel_tol=rel_tol, nans_equal=nans_equal - ) - )[self.name], + ret_df = self.to_frame().select( + col(self.name).is_close( + other, abs_tol=abs_tol, rel_tol=rel_tol, nans_equal=nans_equal + ) ) + return cast("Self", ret_df[self.name]) @property def str(self) -> SeriesStringNamespace[Self]: From a457bf066554044cc5af01b7d3a64149450ae532 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Fri, 26 Sep 2025 11:50:37 +0300 Subject: [PATCH 13/95] exclude too --- narwhals/_compliant/namespace.py | 4 ++-- narwhals/_expression_parsing.py | 5 +++-- narwhals/_polars/namespace.py | 4 +--- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/narwhals/_compliant/namespace.py b/narwhals/_compliant/namespace.py index 25fd753a28..d6554db2fc 100644 --- a/narwhals/_compliant/namespace.py +++ b/narwhals/_compliant/namespace.py @@ -77,7 +77,7 @@ def all(self) -> CompliantExprT: def col(self, *names: str) -> CompliantExprT: return self._expr.from_column_names(passthrough_column_names(names), context=self) - def exclude(self, names: Sequence[str]) -> CompliantExprT: + def exclude(self, *names: str) -> CompliantExprT: return self._expr.from_column_names( partial(exclude_column_names, names=names), context=self ) @@ -120,7 +120,7 @@ def all(self) -> DepthTrackingExprT: def col(self, *names: str) -> DepthTrackingExprT: return self._expr.from_column_names(passthrough_column_names(names), context=self) - def exclude(self, names: Sequence[str]) -> DepthTrackingExprT: + def exclude(self, *names: str) -> DepthTrackingExprT: return self._expr.from_column_names( partial(exclude_column_names, names=names), context=self ) diff --git a/narwhals/_expression_parsing.py b/narwhals/_expression_parsing.py index 6e1f4dc936..c2aded0afd 100644 --- a/narwhals/_expression_parsing.py +++ b/narwhals/_expression_parsing.py @@ -797,8 +797,9 @@ def maybe_broadcast_ces( def evaluate_root_node(node: ExprNode, ns: CompliantNamespaceAny) -> CompliantExprAny: - if node.name == "col": - # There's too much potential for Sequence[str] vs str bugs. + if node.name in {"col", "exclude"}: + # There's too much potential for Sequence[str] vs str bugs, so we pass down + # `names` positionally rather than as a sequence of strings. ce = getattr(ns, node.name)(*node.kwargs["names"]) ces = [] else: diff --git a/narwhals/_polars/namespace.py b/narwhals/_polars/namespace.py index 3ecd7f6029..699f90f110 100644 --- a/narwhals/_polars/namespace.py +++ b/narwhals/_polars/namespace.py @@ -36,6 +36,7 @@ class PolarsNamespace: all: Method[PolarsExpr] coalesce: Method[PolarsExpr] col: Method[PolarsExpr] + exclude: Method[PolarsExpr] sum_horizontal: Method[PolarsExpr] min_horizontal: Method[PolarsExpr] max_horizontal: Method[PolarsExpr] @@ -122,9 +123,6 @@ def from_numpy( return self._dataframe.from_numpy(data, schema=schema, context=self) return self._series.from_numpy(data, context=self) # pragma: no cover - def exclude(self, names: Sequence[str]) -> PolarsExpr: - return self._expr(pl.exclude(*names), version=self._version) - @requires.backend_version( (1, 0, 0), "Please use `col` for columns selection instead." ) From f29d8ad6fcde48992b1f8c8f4e273c48610c242f Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Fri, 26 Sep 2025 14:13:15 +0300 Subject: [PATCH 14/95] typing --- tests/conftest.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 85bbb8432f..e8c8044276 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -138,11 +138,11 @@ def polars_lazy_constructor(obj: Data) -> pl.LazyFrame: def duckdb_lazy_constructor(obj: Data) -> duckdb.DuckDBPyRelation: import duckdb - import pyarrow as pa + import polars as pl duckdb.sql("""set timezone = 'UTC'""") - _df = pa.table(obj) + _df = pl.LazyFrame(obj) return duckdb.table("_df") From 11890a9c4922ad3d3fc36b3f9d532699f05b15f1 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Fri, 26 Sep 2025 14:18:45 +0300 Subject: [PATCH 15/95] mypy --- narwhals/_expression_parsing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/narwhals/_expression_parsing.py b/narwhals/_expression_parsing.py index c2aded0afd..16f9972eaa 100644 --- a/narwhals/_expression_parsing.py +++ b/narwhals/_expression_parsing.py @@ -269,7 +269,7 @@ def _with_kwargs(self, **kwargs: Any) -> ExprNode: def _push_down_over_node_in_place( self, over_node: ExprNode, over_node_without_order_by: ExprNode ) -> None: - exprs = [] + exprs: list[IntoExpr | NonNestedLiteral] = [] # Note: please keep this as a for-loop (rather than a list-comprehension) # so that pytest-cov highlights any uncovered branches. for expr in self.exprs: From 07ed5eeb509c447082d81d92dcb1d47674650d3a Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Fri, 26 Sep 2025 16:15:49 +0300 Subject: [PATCH 16/95] remove unnecessary check --- narwhals/expr.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/narwhals/expr.py b/narwhals/expr.py index 99e3585551..5ab722b0bc 100644 --- a/narwhals/expr.py +++ b/narwhals/expr.py @@ -76,13 +76,12 @@ def _with_node(self, node: ExprNode) -> Self: # node could not be pushed down, just append as-is new_nodes.append(node) return self.__class__(*new_nodes) - if i > 0: - if node.kwargs["order_by"] and any( - node.is_orderable() for node in new_nodes[:i] - ): - new_nodes.insert(i, node) - elif node.kwargs["partition_by"]: - new_nodes.insert(i, node_without_order_by) + if node.kwargs["order_by"] and any( + node.is_orderable() for node in new_nodes[:i] + ): + new_nodes.insert(i, node) + elif node.kwargs["partition_by"]: + new_nodes.insert(i, node_without_order_by) return self.__class__(*new_nodes) return self.__class__(*self._nodes, node) From 06eafaae54044759793984f3531fd88fe22b2a6c Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Fri, 26 Sep 2025 21:28:58 +0300 Subject: [PATCH 17/95] wait how tf doe thi work --- narwhals/_arrow/dataframe.py | 7 ++----- narwhals/_arrow/series.py | 4 +--- narwhals/_compliant/expr.py | 8 +++----- narwhals/_dask/dataframe.py | 3 +-- narwhals/_dask/expr.py | 8 ++++---- narwhals/_duckdb/expr.py | 8 ++++---- narwhals/_expression_parsing.py | 21 ++++++++++++--------- narwhals/_ibis/expr.py | 6 +++--- narwhals/_polars/expr.py | 6 +++--- narwhals/_spark_like/expr.py | 6 +++--- narwhals/dataframe.py | 32 +++++++++++++++++--------------- narwhals/group_by.py | 9 +++++---- 12 files changed, 58 insertions(+), 60 deletions(-) diff --git a/narwhals/_arrow/dataframe.py b/narwhals/_arrow/dataframe.py index 9eed006edc..e8d1a82c84 100644 --- a/narwhals/_arrow/dataframe.py +++ b/narwhals/_arrow/dataframe.py @@ -10,7 +10,6 @@ from narwhals._arrow.series import ArrowSeries from narwhals._arrow.utils import native_to_narwhals_dtype from narwhals._compliant import EagerDataFrame -from narwhals._expression_parsing import ExprKind from narwhals._utils import ( Implementation, Version, @@ -387,12 +386,10 @@ def join( ) return self._with_native( - self.with_columns( - plx.lit(0, None).alias(key_token).broadcast(ExprKind.LITERAL) - ) + self.with_columns(plx.lit(0, None).alias(key_token).broadcast()) .native.join( other.with_columns( - plx.lit(0, None).alias(key_token).broadcast(ExprKind.LITERAL) + plx.lit(0, None).alias(key_token).broadcast() ).native, keys=key_token, right_keys=key_token, diff --git a/narwhals/_arrow/series.py b/narwhals/_arrow/series.py index 9f63776c40..7b402599d4 100644 --- a/narwhals/_arrow/series.py +++ b/narwhals/_arrow/series.py @@ -24,7 +24,6 @@ zeros, ) from narwhals._compliant import EagerSeries, EagerSeriesHist -from narwhals._expression_parsing import ExprKind from narwhals._typing_compat import assert_never from narwhals._utils import ( Implementation, @@ -866,8 +865,7 @@ def mode(self, *, keep: ModeKeepStrategy) -> ArrowSeries: name=col_token, normalize=False, sort=False, parallel=False ) result = counts.filter( - plx.col(col_token) - == plx.col(col_token).max().broadcast(kind=ExprKind.AGGREGATION) + plx.col(col_token) == plx.col(col_token).max().broadcast() ).get_column(self.name) return result.head(1) if keep == "any" else result diff --git a/narwhals/_compliant/expr.py b/narwhals/_compliant/expr.py index 54ae5d7f41..ca5063b1bd 100644 --- a/narwhals/_compliant/expr.py +++ b/narwhals/_compliant/expr.py @@ -45,7 +45,7 @@ from narwhals._compliant.namespace import CompliantNamespace, EagerNamespace from narwhals._compliant.series import CompliantSeries from narwhals._compliant.typing import AliasNames, EvalNames, EvalSeries - from narwhals._expression_parsing import ExprKind, ExprMetadata + from narwhals._expression_parsing import ExprMetadata from narwhals._utils import Implementation, Version, _LimitedContext from narwhals.typing import ( ClosedInterval, @@ -119,9 +119,7 @@ def from_column_names( *, context: _LimitedContext, ) -> Self: ... - def broadcast( - self, kind: Literal[ExprKind.AGGREGATION, ExprKind.LITERAL] - ) -> Self: ... + def broadcast(self) -> Self: ... # NOTE: `polars` def alias(self, name: str) -> Self: ... @@ -407,7 +405,7 @@ def inner(df: EagerDataFrameT) -> list[EagerSeriesT]: context=self, ) - def broadcast(self, kind: Literal[ExprKind.AGGREGATION, ExprKind.LITERAL]) -> Self: + def broadcast(self) -> Self: # Mark the resulting Series with `_broadcast = True`. # Then, when extracting native objects, `extract_native` will # know what to do. diff --git a/narwhals/_dask/dataframe.py b/narwhals/_dask/dataframe.py index 376bf6a887..6bddb729a1 100644 --- a/narwhals/_dask/dataframe.py +++ b/narwhals/_dask/dataframe.py @@ -5,7 +5,6 @@ import dask.dataframe as dd from narwhals._dask.utils import add_row_index, evaluate_exprs -from narwhals._expression_parsing import ExprKind from narwhals._pandas_like.utils import native_to_narwhals_dtype, select_columns_by_name from narwhals._typing_compat import assert_never from narwhals._utils import ( @@ -223,7 +222,7 @@ def with_row_index(self, name: str, order_by: Sequence[str] | None) -> Self: return self._with_native(add_row_index(self.native, name)) plx = self.__narwhals_namespace__() columns = self.columns - const_expr = plx.lit(value=1, dtype=None).alias(name).broadcast(ExprKind.LITERAL) + const_expr = plx.lit(value=1, dtype=None).alias(name).broadcast() row_index_expr = ( plx.col(name).cum_sum(reverse=False).over(partition_by=[], order_by=order_by) - 1 diff --git a/narwhals/_dask/expr.py b/narwhals/_dask/expr.py index a98b539902..b15b08e3a6 100644 --- a/narwhals/_dask/expr.py +++ b/narwhals/_dask/expr.py @@ -1,7 +1,7 @@ from __future__ import annotations import warnings -from typing import TYPE_CHECKING, Any, Callable, Literal, cast +from typing import TYPE_CHECKING, Any, Callable, cast import pandas as pd @@ -13,7 +13,7 @@ maybe_evaluate_expr, narwhals_to_native_dtype, ) -from narwhals._expression_parsing import ExprKind, evaluate_output_names_and_aliases +from narwhals._expression_parsing import evaluate_output_names_and_aliases from narwhals._pandas_like.expr import window_kwargs_to_pandas_equivalent from narwhals._pandas_like.utils import get_dtype_backend, native_to_narwhals_dtype from narwhals._utils import ( @@ -37,7 +37,7 @@ ) from narwhals._dask.dataframe import DaskLazyFrame from narwhals._dask.namespace import DaskNamespace - from narwhals._expression_parsing import ExprKind, ExprMetadata + from narwhals._expression_parsing import ExprMetadata from narwhals._utils import Version, _LimitedContext from narwhals.typing import ( FillNullStrategy, @@ -78,7 +78,7 @@ def __narwhals_namespace__(self) -> DaskNamespace: # pragma: no cover return DaskNamespace(version=self._version) - def broadcast(self, kind: Literal[ExprKind.AGGREGATION, ExprKind.LITERAL]) -> Self: + def broadcast(self) -> Self: def func(df: DaskLazyFrame) -> list[dx.Series]: # result.loc[0][0] is a workaround for dask~<=2024.10.0/dask_expr~<=1.1.16 # that raised a KeyError for result[0] during collection. diff --git a/narwhals/_duckdb/expr.py b/narwhals/_duckdb/expr.py index a93058b6df..b4b811a9ee 100644 --- a/narwhals/_duckdb/expr.py +++ b/narwhals/_duckdb/expr.py @@ -1,7 +1,7 @@ from __future__ import annotations import operator -from typing import TYPE_CHECKING, Any, Callable, Literal, cast +from typing import TYPE_CHECKING, Any, Callable, cast from duckdb import CoalesceOperator, StarExpression @@ -18,7 +18,6 @@ when, window_expression, ) -from narwhals._expression_parsing import ExprKind, ExprMetadata from narwhals._sql.expr import SQLExpr from narwhals._utils import Implementation, Version @@ -37,6 +36,7 @@ ) from narwhals._duckdb.dataframe import DuckDBLazyFrame from narwhals._duckdb.namespace import DuckDBNamespace + from narwhals._expression_parsing import ExprMetadata from narwhals._utils import _LimitedContext from narwhals.typing import ( FillNullStrategy, @@ -98,8 +98,8 @@ def __narwhals_namespace__(self) -> DuckDBNamespace: # pragma: no cover return DuckDBNamespace(version=self._version) - def broadcast(self, kind: Literal[ExprKind.AGGREGATION, ExprKind.LITERAL]) -> Self: - if kind is ExprKind.LITERAL: + def broadcast(self) -> Self: + if self._metadata.is_literal: return self if self._backend_version < (1, 3): msg = "At least version 1.3 of DuckDB is required for binary operations between aggregates and columns." diff --git a/narwhals/_expression_parsing.py b/narwhals/_expression_parsing.py index 16f9972eaa..d2de0460c0 100644 --- a/narwhals/_expression_parsing.py +++ b/narwhals/_expression_parsing.py @@ -182,10 +182,10 @@ def from_into_expr(cls, obj: CompliantExprAny | NonNestedLiteral) -> ExprKind: return ExprKind.LITERAL -def is_scalar_like( - obj: ExprKind, -) -> TypeIs[Literal[ExprKind.LITERAL, ExprKind.AGGREGATION]]: - return obj.is_scalar_like +def is_scalar_like(obj: CompliantExprAny | NonNestedLiteral) -> bool: + if is_compliant_expr(obj): + return obj._metadata.is_scalar_like + return True class ExpansionKind(Enum): @@ -782,12 +782,15 @@ def evaluate_into_exprs( def maybe_broadcast_ces( *ces: CompliantExprAny | NonNestedLiteral, ) -> list[CompliantExprAny | NonNestedLiteral]: - kinds = [ExprKind.from_into_expr(comparand) for comparand in ces] - broadcast = any(not kind.is_scalar_like for kind in kinds) + broadcast = any(not is_scalar_like(ce) for ce in ces) results: list[CompliantExprAny | NonNestedLiteral] = [] - for compliant_expr, kind in zip_strict(ces, kinds): - if broadcast and is_compliant_expr(compliant_expr) and is_scalar_like(kind): - _compliant_expr: CompliantExprAny = compliant_expr.broadcast(kind) + for compliant_expr in ces: + if ( + broadcast + and is_compliant_expr(compliant_expr) + and is_scalar_like(compliant_expr) + ): + _compliant_expr: CompliantExprAny = compliant_expr.broadcast() # Make sure to preserve metadata. _compliant_expr._opt_metadata = compliant_expr._metadata results.append(_compliant_expr) diff --git a/narwhals/_ibis/expr.py b/narwhals/_ibis/expr.py index 22f1174a5d..a282ce78ef 100644 --- a/narwhals/_ibis/expr.py +++ b/narwhals/_ibis/expr.py @@ -1,7 +1,7 @@ from __future__ import annotations import operator -from typing import TYPE_CHECKING, Any, Callable, Literal, TypeVar, cast +from typing import TYPE_CHECKING, Any, Callable, TypeVar, cast import ibis @@ -36,7 +36,7 @@ EvalSeries, WindowFunction, ) - from narwhals._expression_parsing import ExprKind, ExprMetadata + from narwhals._expression_parsing import ExprMetadata from narwhals._ibis.dataframe import IbisLazyFrame from narwhals._ibis.namespace import IbisNamespace from narwhals._utils import _LimitedContext @@ -117,7 +117,7 @@ def __narwhals_namespace__(self) -> IbisNamespace: # pragma: no cover return IbisNamespace(version=self._version) - def broadcast(self, kind: Literal[ExprKind.AGGREGATION, ExprKind.LITERAL]) -> Self: + def broadcast(self) -> Self: # Ibis does its own broadcasting. return self diff --git a/narwhals/_polars/expr.py b/narwhals/_polars/expr.py index a57a02ce2c..bcfa9b4008 100644 --- a/narwhals/_polars/expr.py +++ b/narwhals/_polars/expr.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Callable, ClassVar, Literal, cast +from typing import TYPE_CHECKING, Any, Callable, ClassVar, cast import polars as pl @@ -23,7 +23,7 @@ from typing_extensions import Self from narwhals._compliant.typing import Accessor - from narwhals._expression_parsing import ExprKind, ExprMetadata + from narwhals._expression_parsing import ExprMetadata from narwhals._polars.dataframe import Method from narwhals._polars.namespace import PolarsNamespace from narwhals._polars.series import PolarsSeries @@ -92,7 +92,7 @@ def func(*args: Any, **kwargs: Any) -> Any: return func - def broadcast(self, kind: Literal[ExprKind.AGGREGATION, ExprKind.LITERAL]) -> Self: + def broadcast(self) -> Self: # Let Polars do its thing. return self diff --git a/narwhals/_spark_like/expr.py b/narwhals/_spark_like/expr.py index 3a2b4bbb5b..afae4ef01b 100644 --- a/narwhals/_spark_like/expr.py +++ b/narwhals/_spark_like/expr.py @@ -3,7 +3,6 @@ import operator from typing import TYPE_CHECKING, Any, Callable, ClassVar, Literal, cast -from narwhals._expression_parsing import ExprKind, ExprMetadata from narwhals._spark_like.expr_dt import SparkLikeExprDateTimeNamespace from narwhals._spark_like.expr_list import SparkLikeExprListNamespace from narwhals._spark_like.expr_str import SparkLikeExprStringNamespace @@ -32,6 +31,7 @@ EvalSeries, WindowFunction, ) + from narwhals._expression_parsing import ExprMetadata from narwhals._spark_like.dataframe import SparkLikeLazyFrame from narwhals._spark_like.namespace import SparkLikeNamespace from narwhals._utils import _LimitedContext @@ -96,8 +96,8 @@ def _window_expression( window = window.rowsBetween(rows_start, self._Window.unboundedFollowing) return expr.over(window) - def broadcast(self, kind: Literal[ExprKind.AGGREGATION, ExprKind.LITERAL]) -> Self: - if kind is ExprKind.LITERAL: + def broadcast(self) -> Self: + if self._metadata.is_literal: return self return self.over([self._F.lit(1)], []) diff --git a/narwhals/dataframe.py b/narwhals/dataframe.py index 5d22380b1f..00a2339542 100644 --- a/narwhals/dataframe.py +++ b/narwhals/dataframe.py @@ -143,11 +143,10 @@ def _with_compliant(self, df: Any) -> Self: def _flatten_and_extract( self, *exprs: IntoExpr | Iterable[IntoExpr], **named_exprs: IntoExpr - ) -> tuple[list[CompliantExprAny], list[ExprKind]]: + ) -> list[CompliantExprAny]: # Process `args` and `kwargs`, extracting underlying objects as we go. # NOTE: Strings are interpreted as column names. out_exprs = [] - out_kinds = [] ns = self.__narwhals_namespace__() parse = partial( _parse_into_expr, backend=self._compliant._implementation, allow_literal=False @@ -159,9 +158,8 @@ def _flatten_and_extract( for expr in all_exprs: ce = expr(ns) out_exprs.append(ce) - out_kinds.append(ExprKind.from_expr(ce)) self._validate_metadata(ce._metadata) - return out_exprs, out_kinds + return out_exprs def _extract_compliant_frame(self, other: Self | Any, /) -> Any: if isinstance(other, type(self)): @@ -204,10 +202,12 @@ def columns(self) -> list[str]: def with_columns( self, *exprs: IntoExpr | Iterable[IntoExpr], **named_exprs: IntoExpr ) -> Self: - compliant_exprs, kinds = self._flatten_and_extract(*exprs, **named_exprs) + compliant_exprs = self._flatten_and_extract(*exprs, **named_exprs) compliant_exprs = [ - compliant_expr.broadcast(kind) if is_scalar_like(kind) else compliant_expr - for compliant_expr, kind in zip_strict(compliant_exprs, kinds) + compliant_expr.broadcast() + if is_scalar_like(compliant_expr) + else compliant_expr + for compliant_expr in compliant_exprs ] return self._with_compliant(self._compliant_frame.with_columns(*compliant_exprs)) @@ -226,12 +226,14 @@ def select( if error := self._check_columns_exist(flat_exprs): raise error from e raise - compliant_exprs, kinds = self._flatten_and_extract(*flat_exprs, **named_exprs) - if compliant_exprs and all(x.is_scalar_like for x in kinds): + compliant_exprs = self._flatten_and_extract(*flat_exprs, **named_exprs) + if compliant_exprs and all(is_scalar_like(x) for x in compliant_exprs): return self._with_compliant(self._compliant_frame.aggregate(*compliant_exprs)) compliant_exprs = [ - compliant_expr.broadcast(kind) if is_scalar_like(kind) else compliant_expr - for compliant_expr, kind in zip_strict(compliant_exprs, kinds) + compliant_expr.broadcast() + if is_scalar_like(compliant_expr) + else compliant_expr + for compliant_expr in compliant_exprs ] return self._with_compliant(self._compliant_frame.select(*compliant_exprs)) @@ -257,11 +259,11 @@ def filter( flat_predicates = flatten(predicates) plx = self.__narwhals_namespace__() - compliant_predicates, _kinds = self._flatten_and_extract(*flat_predicates) + compliant_predicates = self._flatten_and_extract(*flat_predicates) check_expressions_preserve_length( *compliant_predicates, function_name="filter" ) - compliant_constraints, _ = self._flatten_and_extract( + compliant_constraints = self._flatten_and_extract( *[col(name) == v for name, v in constraints.items()] ) predicate = plx.all_horizontal( @@ -1720,9 +1722,9 @@ def group_by( k if is_expr else col(k) for k, is_expr in zip_strict(flat_keys, key_is_expr_or_series) ] - expr_flat_keys, kinds = self._flatten_and_extract(*_keys) + expr_flat_keys = self._flatten_and_extract(*_keys) - if not all(kind is ExprKind.ELEMENTWISE for kind in kinds): + if not all(x._metadata.is_elementwise for x in expr_flat_keys): from narwhals.exceptions import ComputeError msg = ( diff --git a/narwhals/group_by.py b/narwhals/group_by.py index 1bdd8e9ec2..65b9a9b48b 100644 --- a/narwhals/group_by.py +++ b/narwhals/group_by.py @@ -2,6 +2,7 @@ from typing import TYPE_CHECKING, Any, Generic, TypeVar +from narwhals._expression_parsing import is_scalar_like from narwhals._utils import tupleify from narwhals.exceptions import InvalidOperationError from narwhals.typing import DataFrameT @@ -71,8 +72,8 @@ def agg(self, *aggs: Expr | Iterable[Expr], **named_aggs: Expr) -> DataFrameT: 2 b 3 2 3 c 3 1 """ - compliant_aggs, kinds = self._df._flatten_and_extract(*aggs, **named_aggs) - if not all(x.is_scalar_like for x in kinds): + compliant_aggs = self._df._flatten_and_extract(*aggs, **named_aggs) + if not all(is_scalar_like(x) for x in compliant_aggs): msg = ( "Found expression which does not aggregate.\n\n" "All expressions passed to GroupBy.agg must aggregate.\n" @@ -157,8 +158,8 @@ def agg(self, *aggs: Expr | Iterable[Expr], **named_aggs: Expr) -> LazyFrameT: |└─────┴─────┴─────┘| └───────────────────┘ """ - compliant_aggs, kinds = self._df._flatten_and_extract(*aggs, **named_aggs) - if not all(x.is_scalar_like for x in kinds): + compliant_aggs = self._df._flatten_and_extract(*aggs, **named_aggs) + if not all(is_scalar_like(x) for x in compliant_aggs): msg = ( "Found expression which does not aggregate.\n\n" "All expressions passed to GroupBy.agg must aggregate.\n" From 53c048ff71b7483736b3cf798fb57edf3528de5e Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Fri, 26 Sep 2025 19:31:21 +0100 Subject: [PATCH 18/95] grossly simplify broadcast --- narwhals/_dask/namespace.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/narwhals/_dask/namespace.py b/narwhals/_dask/namespace.py index 2133dc235c..7611169fd0 100644 --- a/narwhals/_dask/namespace.py +++ b/narwhals/_dask/namespace.py @@ -18,7 +18,6 @@ validate_comparand, ) from narwhals._expression_parsing import ( - ExprKind, combine_alias_output_names, combine_evaluate_output_names, ) @@ -288,7 +287,7 @@ def func(df: DaskLazyFrame) -> list[dx.Series]: ) ): new_df = df._with_native(condition.to_frame()) - condition = predicate.broadcast(ExprKind.AGGREGATION)(df)[0] + condition = predicate.broadcast()(df)[0] df = new_df if otherwise is None: From 1bc6c95c2cd985f841768ea95bd920d9b55dcd61 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Fri, 26 Sep 2025 19:39:30 +0100 Subject: [PATCH 19/95] simplify --- narwhals/_expression_parsing.py | 28 ++++++---------------------- narwhals/dataframe.py | 6 +++--- 2 files changed, 9 insertions(+), 25 deletions(-) diff --git a/narwhals/_expression_parsing.py b/narwhals/_expression_parsing.py index d2de0460c0..6179c4e978 100644 --- a/narwhals/_expression_parsing.py +++ b/narwhals/_expression_parsing.py @@ -148,10 +148,6 @@ class ExprKind(Enum): UNKNOWN = auto() """Based on the information we have, we can't determine the ExprKind.""" - @property - def is_scalar_like(self) -> bool: - return self in {ExprKind.LITERAL, ExprKind.AGGREGATION} - @property def is_orderable(self) -> bool: # Any operation which may be affected by `order_by`, such as `cum_sum`, @@ -163,24 +159,6 @@ def is_orderable(self) -> bool: ExprKind.WINDOW, } - @classmethod - def from_expr(cls, obj: CompliantExprAny) -> ExprKind: - meta = obj._metadata - assert meta is not None # noqa: S101 - if meta.is_literal: - return ExprKind.LITERAL - if meta.is_scalar_like: - return ExprKind.AGGREGATION - if meta.is_elementwise: - return ExprKind.ELEMENTWISE - return ExprKind.UNKNOWN - - @classmethod - def from_into_expr(cls, obj: CompliantExprAny | NonNestedLiteral) -> ExprKind: - if is_compliant_expr(obj): - return cls.from_expr(obj) - return ExprKind.LITERAL - def is_scalar_like(obj: CompliantExprAny | NonNestedLiteral) -> bool: if is_compliant_expr(obj): @@ -188,6 +166,12 @@ def is_scalar_like(obj: CompliantExprAny | NonNestedLiteral) -> bool: return True +def is_elementwise(obj: CompliantExprAny | NonNestedLiteral) -> bool: + if is_compliant_expr(obj): + return obj._metadata.is_elementwise + return False + + class ExpansionKind(Enum): """Describe what kind of expansion the expression performs.""" diff --git a/narwhals/dataframe.py b/narwhals/dataframe.py index 00a2339542..9c71ba29ae 100644 --- a/narwhals/dataframe.py +++ b/narwhals/dataframe.py @@ -18,9 +18,9 @@ from narwhals._exceptions import issue_warning from narwhals._expression_parsing import ( - ExprKind, _parse_into_expr, check_expressions_preserve_length, + is_elementwise, is_scalar_like, ) from narwhals._typing import Arrow, Pandas, _LazyAllowedImpl, _LazyFrameCollectImpl @@ -2958,8 +2958,8 @@ def group_by( _keys = [ k if is_expr else col(k) for k, is_expr in zip_strict(flat_keys, key_is_expr) ] - expr_flat_keys, kinds = self._flatten_and_extract(*_keys) - if not all(kind is ExprKind.ELEMENTWISE for kind in kinds): + expr_flat_keys = self._flatten_and_extract(*_keys) + if not all(is_elementwise(x) for x in expr_flat_keys): from narwhals.exceptions import ComputeError msg = ( From c2209d2151b430b20ca992a15d7d43a7781cdecf Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sat, 27 Sep 2025 13:38:30 +0100 Subject: [PATCH 20/95] cov --- narwhals/_expression_parsing.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/narwhals/_expression_parsing.py b/narwhals/_expression_parsing.py index 6179c4e978..cfcb8bf812 100644 --- a/narwhals/_expression_parsing.py +++ b/narwhals/_expression_parsing.py @@ -166,12 +166,6 @@ def is_scalar_like(obj: CompliantExprAny | NonNestedLiteral) -> bool: return True -def is_elementwise(obj: CompliantExprAny | NonNestedLiteral) -> bool: - if is_compliant_expr(obj): - return obj._metadata.is_elementwise - return False - - class ExpansionKind(Enum): """Describe what kind of expansion the expression performs.""" From 68108135ecdb141bb44892816b36603ae75befed Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sat, 27 Sep 2025 14:30:54 +0100 Subject: [PATCH 21/95] post merge fixup --- narwhals/dataframe.py | 8 ++++++-- tests/frame/group_by_test.py | 2 +- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/narwhals/dataframe.py b/narwhals/dataframe.py index 3ccc880ce5..177cdd18c5 100644 --- a/narwhals/dataframe.py +++ b/narwhals/dataframe.py @@ -1722,7 +1722,9 @@ def group_by( for k, is_expr in zip_strict(flat_keys, key_is_expr_or_series) ] expr_flat_keys = self._flatten_and_extract(*_keys) - check_expressions_preserve_length(*_keys, function_name="DataFrame.group_by") + check_expressions_preserve_length( + *expr_flat_keys, function_name="DataFrame.group_by" + ) return GroupBy(self, expr_flat_keys, drop_null_keys=drop_null_keys) def sort( @@ -2950,7 +2952,9 @@ def group_by( k if is_expr else col(k) for k, is_expr in zip_strict(flat_keys, key_is_expr) ] expr_flat_keys = self._flatten_and_extract(*_keys) - check_expressions_preserve_length(*_keys, function_name="LazyFrame.group_by") + check_expressions_preserve_length( + *expr_flat_keys, function_name="LazyFrame.group_by" + ) return LazyGroupBy(self, expr_flat_keys, drop_null_keys=drop_null_keys) def sort( diff --git a/tests/frame/group_by_test.py b/tests/frame/group_by_test.py index 143898b4fb..f37b1418fe 100644 --- a/tests/frame/group_by_test.py +++ b/tests/frame/group_by_test.py @@ -525,7 +525,7 @@ def test_group_by_raise_if_not_preserves_length( ) -> None: data = {"a": [1, 2, 2, None], "b": [0, 1, 2, 3], "x": [1, 2, 3, 4]} df = nw.from_native(constructor(data)) - with pytest.raises(InvalidOperationError): + with pytest.raises((InvalidOperationError, NotImplementedError)): df.group_by(keys).agg(nw.col("x").max()) From 48a9dfdefb0fe5e3164aac476b050af5b821e3b4 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sat, 27 Sep 2025 15:04:29 +0100 Subject: [PATCH 22/95] even simpler! --- narwhals/_dask/namespace.py | 1 - narwhals/_duckdb/expr.py | 5 ----- narwhals/_spark_like/expr.py | 2 -- 3 files changed, 8 deletions(-) diff --git a/narwhals/_dask/namespace.py b/narwhals/_dask/namespace.py index 7611169fd0..f0b937e543 100644 --- a/narwhals/_dask/namespace.py +++ b/narwhals/_dask/namespace.py @@ -273,7 +273,6 @@ def func(df: DaskLazyFrame) -> list[dx.Series]: condition = predicate(df)[0] # re-evaluate DataFrame if the condition aggregates to force # then/otherwise to be evaluated against the aggregated frame - assert predicate._metadata is not None # noqa: S101 if all( x._metadata.is_scalar_like for x in ( diff --git a/narwhals/_duckdb/expr.py b/narwhals/_duckdb/expr.py index b4b811a9ee..6aa6310514 100644 --- a/narwhals/_duckdb/expr.py +++ b/narwhals/_duckdb/expr.py @@ -99,11 +99,6 @@ def __narwhals_namespace__(self) -> DuckDBNamespace: # pragma: no cover return DuckDBNamespace(version=self._version) def broadcast(self) -> Self: - if self._metadata.is_literal: - return self - if self._backend_version < (1, 3): - msg = "At least version 1.3 of DuckDB is required for binary operations between aggregates and columns." - raise NotImplementedError(msg) return self.over([lit(1)], []) @classmethod diff --git a/narwhals/_spark_like/expr.py b/narwhals/_spark_like/expr.py index afae4ef01b..dacb01c382 100644 --- a/narwhals/_spark_like/expr.py +++ b/narwhals/_spark_like/expr.py @@ -97,8 +97,6 @@ def _window_expression( return expr.over(window) def broadcast(self) -> Self: - if self._metadata.is_literal: - return self return self.over([self._F.lit(1)], []) @property From 5ba10edf2a8f59103084c4d3e34e2c2381fa67da Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sat, 27 Sep 2025 15:06:17 +0100 Subject: [PATCH 23/95] assign variable --- narwhals/_arrow/group_by.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/narwhals/_arrow/group_by.py b/narwhals/_arrow/group_by.py index d205dd283e..2544b0e378 100644 --- a/narwhals/_arrow/group_by.py +++ b/narwhals/_arrow/group_by.py @@ -72,9 +72,10 @@ def agg(self, *exprs: ArrowExpr) -> ArrowDataFrame: expr, self.compliant, exclude ) md = expr._metadata - if len(list(md.op_nodes_reversed())) == 1: + op_nodes_reversed = list(md.op_nodes_reversed()) + if len(op_nodes_reversed) == 1: # e.g. `agg(nw.len())` - if next(md.op_nodes_reversed()).name != "len": # pragma: no cover + if op_nodes_reversed[0].name != "len": # pragma: no cover msg = "Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues" raise AssertionError(msg) @@ -85,7 +86,7 @@ def agg(self, *exprs: ArrowExpr) -> ArrowDataFrame: function_name = self._leaf_name(expr) if function_name in {"std", "var"}: - last_node = next(md.op_nodes_reversed()) + last_node = op_nodes_reversed[0] option: Any = pc.VarianceOptions(**last_node.kwargs) elif function_name in {"len", "n_unique"}: option = pc.CountOptions(mode="all") From 8cde5d217888c05a36c5dc1546358a04bfb29b10 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sat, 27 Sep 2025 15:38:58 +0100 Subject: [PATCH 24/95] replace/replace_all typing --- narwhals/_arrow/series_str.py | 41 ++++++++++++++++------------ narwhals/_compliant/any_namespace.py | 38 ++++++++++++-------------- narwhals/_compliant/expr.py | 4 ++- narwhals/_compliant/series.py | 13 +++++---- narwhals/_dask/expr_str.py | 8 ++++-- narwhals/_ibis/expr_str.py | 4 +-- narwhals/_pandas_like/series_str.py | 6 ++-- narwhals/_polars/expr.py | 14 +++++++--- narwhals/_polars/series.py | 14 +++++++--- narwhals/_sql/expr_str.py | 4 +-- 10 files changed, 84 insertions(+), 62 deletions(-) diff --git a/narwhals/_arrow/series_str.py b/narwhals/_arrow/series_str.py index 1e1b6e752f..dcd86d8835 100644 --- a/narwhals/_arrow/series_str.py +++ b/narwhals/_arrow/series_str.py @@ -6,7 +6,12 @@ import pyarrow as pa import pyarrow.compute as pc -from narwhals._arrow.utils import ArrowSeriesNamespace, lit, parse_datetime_format +from narwhals._arrow.utils import ( + ArrowSeriesNamespace, + extract_native, + lit, + parse_datetime_format, +) from narwhals._compliant.any_namespace import StringNamespace if TYPE_CHECKING: @@ -18,25 +23,27 @@ class ArrowSeriesStringNamespace(ArrowSeriesNamespace, StringNamespace["ArrowSer def len_chars(self) -> ArrowSeries: return self.with_native(pc.utf8_length(self.native)) - def replace(self, value: str, pattern: str, *, literal: bool, n: int) -> ArrowSeries: + def replace( + self, value: ArrowSeries | str, pattern: str, *, literal: bool, n: int + ) -> ArrowSeries: fn = pc.replace_substring if literal else pc.replace_substring_regex - try: - arr = fn(self.native, pattern, replacement=value, max_replacements=n) - except TypeError as e: - if not isinstance(value, str): - msg = "PyArrow backed `.str.replace` only supports str replacement values" - raise TypeError(msg) from e - raise + value_native = extract_native(self.compliant, value) + if not isinstance(value_native, str): + msg = "PyArrow backed `.str.replace` only supports str replacement values" + raise TypeError(msg) + arr = fn(self.native, pattern, replacement=value_native, max_replacements=n) return self.with_native(arr) - def replace_all(self, value: str, pattern: str, *, literal: bool) -> ArrowSeries: - try: - return self.replace(value, pattern, literal=literal, n=-1) - except TypeError as e: - if not isinstance(value, str): - msg = "PyArrow backed `.str.replace_all` only supports str replacement values." - raise TypeError(msg) from e - raise + def replace_all( + self, value: ArrowSeries | str, pattern: str, *, literal: bool + ) -> ArrowSeries: + value_native = extract_native(self.compliant, value) + if not isinstance(value_native, str): + msg = ( + "PyArrow backed `.str.replace_all` only supports str replacement values." + ) + raise TypeError(msg) + return self.replace(value_native, pattern, literal=literal, n=-1) def strip_chars(self, characters: str | None) -> ArrowSeries: return self.with_native( diff --git a/narwhals/_compliant/any_namespace.py b/narwhals/_compliant/any_namespace.py index fc78bfdefb..503b873604 100644 --- a/narwhals/_compliant/any_namespace.py +++ b/narwhals/_compliant/any_namespace.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, ClassVar, Protocol +from typing import TYPE_CHECKING, ClassVar, Protocol, TypeVar from narwhals._utils import CompliantT_co, _StoresCompliant @@ -12,6 +12,8 @@ from narwhals._compliant.typing import Accessor from narwhals.typing import NonNestedLiteral, TimeUnit +T = TypeVar("T") + __all__ = [ "CatNamespace", "DateTimeNamespace", @@ -81,27 +83,23 @@ def to_lowercase(self) -> CompliantT_co: ... def to_uppercase(self) -> CompliantT_co: ... -class StringNamespace(_StoresCompliant[CompliantT_co], Protocol[CompliantT_co]): +class StringNamespace(_StoresCompliant[T], Protocol[T]): _accessor: ClassVar[Accessor] = "str" - def len_chars(self) -> CompliantT_co: ... - def replace( - self, value: str, pattern: str, *, literal: bool, n: int - ) -> CompliantT_co: ... - def replace_all( - self, value: str, pattern: str, *, literal: bool - ) -> CompliantT_co: ... - def strip_chars(self, characters: str | None) -> CompliantT_co: ... - def starts_with(self, prefix: str) -> CompliantT_co: ... - def ends_with(self, suffix: str) -> CompliantT_co: ... - def contains(self, pattern: str, *, literal: bool) -> CompliantT_co: ... - def slice(self, offset: int, length: int | None) -> CompliantT_co: ... - def split(self, by: str) -> CompliantT_co: ... - def to_datetime(self, format: str | None) -> CompliantT_co: ... - def to_date(self, format: str | None) -> CompliantT_co: ... - def to_lowercase(self) -> CompliantT_co: ... - def to_uppercase(self) -> CompliantT_co: ... - def zfill(self, width: int) -> CompliantT_co: ... + def len_chars(self) -> T: ... + def replace(self, value: str, pattern: str, *, literal: bool, n: int) -> T: ... + def replace_all(self, value: T, pattern: str, *, literal: bool) -> T: ... + def strip_chars(self, characters: str | None) -> T: ... + def starts_with(self, prefix: str) -> T: ... + def ends_with(self, suffix: str) -> T: ... + def contains(self, pattern: str, *, literal: bool) -> T: ... + def slice(self, offset: int, length: int | None) -> T: ... + def split(self, by: str) -> T: ... + def to_datetime(self, format: str | None) -> T: ... + def to_date(self, format: str | None) -> T: ... + def to_lowercase(self) -> T: ... + def to_uppercase(self) -> T: ... + def zfill(self, width: int) -> T: ... class StructNamespace(_StoresCompliant[CompliantT_co], Protocol[CompliantT_co]): diff --git a/narwhals/_compliant/expr.py b/narwhals/_compliant/expr.py index ca5063b1bd..cf74ed17b2 100644 --- a/narwhals/_compliant/expr.py +++ b/narwhals/_compliant/expr.py @@ -1035,7 +1035,9 @@ def replace(self, value: str, pattern: str, *, literal: bool, n: int) -> EagerEx "str", "replace", pattern=pattern, value=value, literal=literal, n=n ) - def replace_all(self, value: str, pattern: str, *, literal: bool) -> EagerExprT: + def replace_all( + self, value: EagerExprT | str, pattern: str, *, literal: bool + ) -> EagerExprT: return self.compliant._reuse_series_namespace( "str", "replace_all", pattern=pattern, value=value, literal=literal ) diff --git a/narwhals/_compliant/series.py b/narwhals/_compliant/series.py index 561ff3de6d..e5ae7d4533 100644 --- a/narwhals/_compliant/series.py +++ b/narwhals/_compliant/series.py @@ -13,6 +13,7 @@ from narwhals._compliant.typing import ( CompliantSeriesT_co, EagerDataFrameAny, + EagerSeriesT, EagerSeriesT_co, NativeSeriesT, NativeSeriesT_co, @@ -322,16 +323,16 @@ class EagerSeriesDateTimeNamespace( # type: ignore[misc] class EagerSeriesListNamespace( # type: ignore[misc] - _SeriesNamespace[EagerSeriesT_co, NativeSeriesT_co], - ListNamespace[EagerSeriesT_co], - Protocol[EagerSeriesT_co, NativeSeriesT_co], + _SeriesNamespace[EagerSeriesT, NativeSeriesT_co], + ListNamespace[EagerSeriesT], + Protocol[EagerSeriesT, NativeSeriesT_co], ): ... class EagerSeriesStringNamespace( # type: ignore[misc] - _SeriesNamespace[EagerSeriesT_co, NativeSeriesT_co], - StringNamespace[EagerSeriesT_co], - Protocol[EagerSeriesT_co, NativeSeriesT_co], + _SeriesNamespace[EagerSeriesT, NativeSeriesT_co], + StringNamespace[EagerSeriesT], + Protocol[EagerSeriesT, NativeSeriesT_co], ): ... diff --git a/narwhals/_dask/expr_str.py b/narwhals/_dask/expr_str.py index 846ff9c808..735deb9ac7 100644 --- a/narwhals/_dask/expr_str.py +++ b/narwhals/_dask/expr_str.py @@ -18,7 +18,9 @@ class DaskExprStringNamespace(LazyExprNamespace["DaskExpr"], StringNamespace["Da def len_chars(self) -> DaskExpr: return self.compliant._with_callable(lambda expr: expr.str.len()) - def replace(self, value: str, pattern: str, *, literal: bool, n: int) -> DaskExpr: + def replace( + self, value: DaskExpr | str, pattern: str, *, literal: bool, n: int + ) -> DaskExpr: def _replace(expr: dx.Series, value: str) -> dx.Series: try: return expr.str.replace( # pyright: ignore[reportAttributeAccessIssue] @@ -32,7 +34,9 @@ def _replace(expr: dx.Series, value: str) -> dx.Series: return self.compliant._with_callable(_replace, value=value) - def replace_all(self, value: str, pattern: str, *, literal: bool) -> DaskExpr: + def replace_all( + self, value: DaskExpr | str, pattern: str, *, literal: bool + ) -> DaskExpr: def _replace_all(expr: dx.Series, value: str) -> dx.Series: try: return expr.str.replace( # pyright: ignore[reportAttributeAccessIssue] diff --git a/narwhals/_ibis/expr_str.py b/narwhals/_ibis/expr_str.py index 312da1cccb..fdde174f2b 100644 --- a/narwhals/_ibis/expr_str.py +++ b/narwhals/_ibis/expr_str.py @@ -40,9 +40,7 @@ def fn(expr: ir.StringColumn) -> ir.StringValue: return fn - def replace_all( - self, value: str | IbisExpr, pattern: str, *, literal: bool - ) -> IbisExpr: + def replace_all(self, value: IbisExpr, pattern: str, *, literal: bool) -> IbisExpr: fn = self._replace_all_literal if literal else self._replace_all return self.compliant._with_elementwise( lambda expr, value: fn(pattern, value)(expr), value=value diff --git a/narwhals/_pandas_like/series_str.py b/narwhals/_pandas_like/series_str.py index 6f5e5af560..0597ea1a01 100644 --- a/narwhals/_pandas_like/series_str.py +++ b/narwhals/_pandas_like/series_str.py @@ -16,7 +16,7 @@ def len_chars(self) -> PandasLikeSeries: return self.with_native(self.native.str.len()) def replace( - self, value: str, pattern: str, *, literal: bool, n: int + self, value: PandasLikeSeries | str, pattern: str, *, literal: bool, n: int ) -> PandasLikeSeries: try: series = self.native.str.replace( @@ -29,7 +29,9 @@ def replace( raise return self.with_native(series) - def replace_all(self, value: str, pattern: str, *, literal: bool) -> PandasLikeSeries: + def replace_all( + self, value: PandasLikeSeries | str, pattern: str, *, literal: bool + ) -> PandasLikeSeries: return self.replace(value, pattern, literal=literal, n=-1) def strip_chars(self, characters: str | None) -> PandasLikeSeries: diff --git a/narwhals/_polars/expr.py b/narwhals/_polars/expr.py index bcfa9b4008..c6f51a7737 100644 --- a/narwhals/_polars/expr.py +++ b/narwhals/_polars/expr.py @@ -392,14 +392,20 @@ def zfill(self, width: int) -> PolarsExpr: return self.compliant._with_native(native_result) - def replace(self, value: str, pattern: str, *, literal: bool, n: int) -> PolarsExpr: + def replace( + self, value: PolarsExpr | str, pattern: str, *, literal: bool, n: int + ) -> PolarsExpr: + value_native = value if isinstance(value, str) else extract_native(value) return self.compliant._with_native( - self.native.str.replace(pattern, extract_native(value), literal=literal, n=n) + self.native.str.replace(pattern, value_native, literal=literal, n=n) ) - def replace_all(self, value: str, pattern: str, *, literal: bool) -> PolarsExpr: + def replace_all( + self, value: PolarsExpr | str, pattern: str, *, literal: bool + ) -> PolarsExpr: + value_native = value if isinstance(value, str) else extract_native(value) return self.compliant._with_native( - self.native.str.replace_all(pattern, extract_native(value), literal=literal) + self.native.str.replace_all(pattern, value_native, literal=literal) ) diff --git a/narwhals/_polars/series.py b/narwhals/_polars/series.py index fadc4bac76..e51516fc91 100644 --- a/narwhals/_polars/series.py +++ b/narwhals/_polars/series.py @@ -745,14 +745,20 @@ def zfill(self, width: int) -> PolarsSeries: ns = self.__narwhals_namespace__() return self.to_frame().select(ns.col(name).str.zfill(width)).get_column(name) - def replace(self, value: str, pattern: str, *, literal: bool, n: int) -> PolarsSeries: + def replace( + self, value: PolarsSeries | str, pattern: str, *, literal: bool, n: int + ) -> PolarsSeries: + value_native = value if isinstance(value, str) else extract_native(value) return self.compliant._with_native( - self.native.str.replace(pattern, extract_native(value), literal=literal, n=n) + self.native.str.replace(pattern, value_native, literal=literal, n=n) # type: ignore[arg-type] ) - def replace_all(self, value: str, pattern: str, *, literal: bool) -> PolarsSeries: + def replace_all( + self, value: PolarsSeries | str, pattern: str, *, literal: bool + ) -> PolarsSeries: + value_native = value if isinstance(value, str) else extract_native(value) return self.compliant._with_native( - self.native.str.replace_all(pattern, extract_native(value), literal=literal) + self.native.str.replace_all(pattern, value_native, literal=literal) # type: ignore[arg-type] ) diff --git a/narwhals/_sql/expr_str.py b/narwhals/_sql/expr_str.py index 0924137fca..c1b5db9b53 100644 --- a/narwhals/_sql/expr_str.py +++ b/narwhals/_sql/expr_str.py @@ -37,9 +37,7 @@ def len_chars(self) -> SQLExprT: lambda expr: self._function("length", expr) ) - def replace_all( - self, value: str | SQLExprT, pattern: str, *, literal: bool - ) -> SQLExprT: + def replace_all(self, value: SQLExprT, pattern: str, *, literal: bool) -> SQLExprT: fname: str = "replace" if literal else "regexp_replace" options: list[Any] = [] From 19a5c99aaf4f368b3923d52c28d70c1c90313dbd Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sat, 27 Sep 2025 15:57:13 +0100 Subject: [PATCH 25/95] yay remove type ignore --- narwhals/_arrow/series_str.py | 17 +++++---------- narwhals/_compliant/series.py | 4 ++-- narwhals/_dask/expr_str.py | 34 +++++++++++++---------------- narwhals/_pandas_like/series_str.py | 13 ++++------- 4 files changed, 26 insertions(+), 42 deletions(-) diff --git a/narwhals/_arrow/series_str.py b/narwhals/_arrow/series_str.py index dcd86d8835..b9b7976058 100644 --- a/narwhals/_arrow/series_str.py +++ b/narwhals/_arrow/series_str.py @@ -6,12 +6,7 @@ import pyarrow as pa import pyarrow.compute as pc -from narwhals._arrow.utils import ( - ArrowSeriesNamespace, - extract_native, - lit, - parse_datetime_format, -) +from narwhals._arrow.utils import ArrowSeriesNamespace, lit, parse_datetime_format from narwhals._compliant.any_namespace import StringNamespace if TYPE_CHECKING: @@ -27,23 +22,21 @@ def replace( self, value: ArrowSeries | str, pattern: str, *, literal: bool, n: int ) -> ArrowSeries: fn = pc.replace_substring if literal else pc.replace_substring_regex - value_native = extract_native(self.compliant, value) - if not isinstance(value_native, str): + if not isinstance(value, str): msg = "PyArrow backed `.str.replace` only supports str replacement values" raise TypeError(msg) - arr = fn(self.native, pattern, replacement=value_native, max_replacements=n) + arr = fn(self.native, pattern, replacement=value, max_replacements=n) return self.with_native(arr) def replace_all( self, value: ArrowSeries | str, pattern: str, *, literal: bool ) -> ArrowSeries: - value_native = extract_native(self.compliant, value) - if not isinstance(value_native, str): + if not isinstance(value, str): msg = ( "PyArrow backed `.str.replace_all` only supports str replacement values." ) raise TypeError(msg) - return self.replace(value_native, pattern, literal=literal, n=-1) + return self.replace(value, pattern, literal=literal, n=-1) def strip_chars(self, characters: str | None) -> ArrowSeries: return self.with_native( diff --git a/narwhals/_compliant/series.py b/narwhals/_compliant/series.py index e5ae7d4533..6d456e944a 100644 --- a/narwhals/_compliant/series.py +++ b/narwhals/_compliant/series.py @@ -322,14 +322,14 @@ class EagerSeriesDateTimeNamespace( # type: ignore[misc] ): ... -class EagerSeriesListNamespace( # type: ignore[misc] +class EagerSeriesListNamespace( _SeriesNamespace[EagerSeriesT, NativeSeriesT_co], ListNamespace[EagerSeriesT], Protocol[EagerSeriesT, NativeSeriesT_co], ): ... -class EagerSeriesStringNamespace( # type: ignore[misc] +class EagerSeriesStringNamespace( _SeriesNamespace[EagerSeriesT, NativeSeriesT_co], StringNamespace[EagerSeriesT], Protocol[EagerSeriesT, NativeSeriesT_co], diff --git a/narwhals/_dask/expr_str.py b/narwhals/_dask/expr_str.py index 735deb9ac7..af884b4306 100644 --- a/narwhals/_dask/expr_str.py +++ b/narwhals/_dask/expr_str.py @@ -21,32 +21,28 @@ def len_chars(self) -> DaskExpr: def replace( self, value: DaskExpr | str, pattern: str, *, literal: bool, n: int ) -> DaskExpr: - def _replace(expr: dx.Series, value: str) -> dx.Series: - try: - return expr.str.replace( # pyright: ignore[reportAttributeAccessIssue] - pattern, value, regex=not literal, n=n + def _replace(expr: dx.Series, value: dx.Series | str) -> dx.Series: + if not isinstance(value, str): + msg = ( + "dask backed `Expr.str.replace` only supports str replacement values" ) - except TypeError as e: - if not isinstance(value, str): - msg = "dask backed `Expr.str.replace` only supports str replacement values" - raise TypeError(msg) from e - raise + raise TypeError(msg) + return expr.str.replace( # pyright: ignore[reportAttributeAccessIssue] + pattern, value, regex=not literal, n=n + ) return self.compliant._with_callable(_replace, value=value) def replace_all( self, value: DaskExpr | str, pattern: str, *, literal: bool ) -> DaskExpr: - def _replace_all(expr: dx.Series, value: str) -> dx.Series: - try: - return expr.str.replace( # pyright: ignore[reportAttributeAccessIssue] - pattern, value, regex=not literal, n=-1 - ) - except TypeError as e: - if not isinstance(value, str): - msg = "dask backed `Expr.str.replace_all` only supports str replacement values." - raise TypeError(msg) from e - raise + def _replace_all(expr: dx.Series, value: dx.Series | str) -> dx.Series: + if not isinstance(value, str): + msg = "dask backed `Expr.str.replace_all` only supports str replacement values." + raise TypeError(msg) + return expr.str.replace( # pyright: ignore[reportAttributeAccessIssue] + pattern, value, regex=not literal, n=-1 + ) return self.compliant._with_callable(_replace_all, value=value) diff --git a/narwhals/_pandas_like/series_str.py b/narwhals/_pandas_like/series_str.py index 0597ea1a01..c162bede4c 100644 --- a/narwhals/_pandas_like/series_str.py +++ b/narwhals/_pandas_like/series_str.py @@ -18,15 +18,10 @@ def len_chars(self) -> PandasLikeSeries: def replace( self, value: PandasLikeSeries | str, pattern: str, *, literal: bool, n: int ) -> PandasLikeSeries: - try: - series = self.native.str.replace( - pat=pattern, repl=value, n=n, regex=not literal - ) - except TypeError as e: - if not isinstance(value, str): - msg = f"{self.compliant._implementation} backed `.str.replace` only supports str replacement values" - raise TypeError(msg) from e - raise + if not isinstance(value, str): + msg = f"{self.compliant._implementation} backed `.str.replace` only supports str replacement values" + raise TypeError(msg) + series = self.native.str.replace(pat=pattern, repl=value, n=n, regex=not literal) return self.with_native(series) def replace_all( From b75871054b075238fb8dae82e7e25e81d6a5332a Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sat, 27 Sep 2025 16:46:12 +0100 Subject: [PATCH 26/95] wooah we can support per-group broadcasting --- narwhals/_expression_parsing.py | 38 ++++++++++++++++++++++++++++---- narwhals/expr.py | 9 ++++++-- tests/expression_parsing_test.py | 8 ++++++- 3 files changed, 48 insertions(+), 7 deletions(-) diff --git a/narwhals/_expression_parsing.py b/narwhals/_expression_parsing.py index cfcb8bf812..68d71ead65 100644 --- a/narwhals/_expression_parsing.py +++ b/narwhals/_expression_parsing.py @@ -145,9 +145,6 @@ class ExprKind(Enum): SERIES = auto() """Results from converting a Series to Expr.""" - UNKNOWN = auto() - """Based on the information we have, we can't determine the ExprKind.""" - @property def is_orderable(self) -> bool: # Any operation which may be affected by `order_by`, such as `cum_sum`, @@ -159,6 +156,22 @@ def is_orderable(self) -> bool: ExprKind.WINDOW, } + @property + def is_elementwise(self) -> bool: + # Any operation which can operate on each row independently + # of the rows around it, e.g. `abs(), __add__, sum_horizontal, ...` + return self in { + ExprKind.ALL, + ExprKind.COL, + ExprKind.ELEMENTWISE, + ExprKind.EXCLUDE, + ExprKind.LITERAL, + ExprKind.NTH, + ExprKind.SELECTOR, + ExprKind.SERIES, + ExprKind.WHEN_THEN, + } + def is_scalar_like(obj: CompliantExprAny | NonNestedLiteral) -> bool: if is_compliant_expr(obj): @@ -225,6 +238,7 @@ def __init__( # Cached methods. self._is_orderable_cached: bool | None = None + self._is_elementwise_cached: bool | None = None def __repr__(self) -> str: if self.name == "col": @@ -257,7 +271,9 @@ def _push_down_over_node_in_place( expr_node.is_orderable() for expr_node in expr._nodes ): exprs.append(expr._with_node(over_node)) - elif over_node_without_order_by.kwargs["partition_by"]: + elif over_node_without_order_by.kwargs["partition_by"] and not all( + expr_node.is_elementwise() for expr_node in expr._nodes + ): exprs.append(expr._with_node(over_node_without_order_by)) else: # If there's no `partition_by`, then `over_node_without_order_by` is a no-op. @@ -280,6 +296,20 @@ def is_orderable(self) -> bool: self._is_orderable_cached = False return self._is_orderable_cached + def is_elementwise(self) -> bool: + if self._is_elementwise_cached is None: + # Note: don't combine these if/then statements so that pytest-cov shows if + # anything is uncovered. + if not self.kind.is_elementwise or not all( + all(node.is_elementwise() for node in expr._nodes) + for expr in self.exprs + if is_expr(expr) + ): + self._is_elementwise_cached = False + else: + self._is_elementwise_cached = True + return self._is_elementwise_cached + class ExprMetadata: """Expression metadata. diff --git a/narwhals/expr.py b/narwhals/expr.py index 5ab722b0bc..65ebf2dad4 100644 --- a/narwhals/expr.py +++ b/narwhals/expr.py @@ -12,7 +12,7 @@ ) from narwhals._utils import _validate_rolling_arguments, ensure_type, flatten from narwhals.dtypes import _validate_dtype -from narwhals.exceptions import ComputeError +from narwhals.exceptions import ComputeError, InvalidOperationError from narwhals.expr_cat import ExprCatNamespace from narwhals.expr_dt import ExprDateTimeNamespace from narwhals.expr_list import ExprListNamespace @@ -80,8 +80,13 @@ def _with_node(self, node: ExprNode) -> Self: node.is_orderable() for node in new_nodes[:i] ): new_nodes.insert(i, node) - elif node.kwargs["partition_by"]: + elif node.kwargs["partition_by"] and not all( + node.is_elementwise() for node in new_nodes[:i] + ): new_nodes.insert(i, node_without_order_by) + elif all(node.is_elementwise() for node in new_nodes): + msg = "Cannot apply `over` to elementwise expression." + raise InvalidOperationError(msg) return self.__class__(*new_nodes) return self.__class__(*self._nodes, node) diff --git a/tests/expression_parsing_test.py b/tests/expression_parsing_test.py index 9c0c63f6cc..9c8b8cd510 100644 --- a/tests/expression_parsing_test.py +++ b/tests/expression_parsing_test.py @@ -48,6 +48,7 @@ ), [4 / 3, 13 / 3, 7 / 3], ), + ((nw.col("a") - nw.col("a").mean()).over("b"), [-1.5, 1.5, 0]), ], ) def test_over_pushdown( @@ -78,7 +79,6 @@ def test_over_pushdown( nw.col("a").mean().rank(), nw.col("a").mean().is_unique(), nw.col("a").mean().diff(), - nw.col("a").fill_null(3).over("b"), nw.col("a").drop_nulls().over("b"), nw.col("a").drop_nulls().over("b", order_by="i"), nw.col("a").diff().drop_nulls().over("b", order_by="i"), @@ -92,3 +92,9 @@ def test_invalid_operations(constructor: Constructor, expr: nw.Expr) -> None: ).lazy() with pytest.raises((InvalidOperationError, NotImplementedError)): df.select(a=expr) + + +def test_invalid_elementwise_over() -> None: + # This one raises before it's even evaluated. + with pytest.raises(InvalidOperationError): + nw.col("a").fill_null(3).over("b") From 64ccba40530ce149147bdd57112c67aa87644520 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sat, 27 Sep 2025 16:48:55 +0100 Subject: [PATCH 27/95] test repr --- tests/repr_test.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tests/repr_test.py b/tests/repr_test.py index 2ef9f77271..62100bee98 100644 --- a/tests/repr_test.py +++ b/tests/repr_test.py @@ -100,3 +100,19 @@ def test_polars_series_repr() -> None: "└────────────────────┘" ) assert result == expected + + +@pytest.mark.parametrize( + ("expr", "expected"), + [ + (nw.col("a"), "col(a)"), + (nw.col("a").abs(), "col(a).abs()"), + (nw.col("a").std(ddof=2), "col(a).std(ddof=2)"), + ( + nw.sum_horizontal(nw.col("a").rolling_mean(2), "b"), + "sum_horizontal(col(a).rolling_mean(window_size=2, min_samples=2, center=False), b)", + ), + ], +) +def test_expr_repr(expr: nw.Expr, expected: str) -> None: + assert repr(expr) == expected From 474a2df46c63f20d707926a919922a469942b419 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sat, 27 Sep 2025 16:59:52 +0100 Subject: [PATCH 28/95] dask cmon man --- tests/expression_parsing_test.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/tests/expression_parsing_test.py b/tests/expression_parsing_test.py index 9c8b8cd510..78b71f6000 100644 --- a/tests/expression_parsing_test.py +++ b/tests/expression_parsing_test.py @@ -48,7 +48,6 @@ ), [4 / 3, 13 / 3, 7 / 3], ), - ((nw.col("a") - nw.col("a").mean()).over("b"), [-1.5, 1.5, 0]), ], ) def test_over_pushdown( @@ -62,6 +61,24 @@ def test_over_pushdown( assert_equal_data(result, {"a": expected}) +@pytest.mark.parametrize( + ("expr", "expected"), [((nw.col("a") - nw.col("a").mean()).over("b"), [-1.5, 1.5, 0])] +) +def test_per_group_broadcasting( + constructor: Constructor, + expr: nw.Expr, + expected: list[float], + request: pytest.FixtureRequest, +) -> None: + if "dask" in str(constructor): + # sigh... + request.applymarker(pytest.mark.xfail) + data = {"a": [-1, 2, 3], "b": [1, 1, 2], "i": [0, 1, 2]} + df = nw.from_native(constructor(data)).lazy() + result = df.select("i", a=expr).sort("i").select("a") + assert_equal_data(result, {"a": expected}) + + @pytest.mark.parametrize( "expr", [ From 12a6637a77a03b6ba258cfbbe3e744015cc84a5f Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sat, 27 Sep 2025 17:59:28 +0100 Subject: [PATCH 29/95] minor things --- narwhals/_pandas_like/expr.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/narwhals/_pandas_like/expr.py b/narwhals/_pandas_like/expr.py index ddbe20c7a0..8ef72904f7 100644 --- a/narwhals/_pandas_like/expr.py +++ b/narwhals/_pandas_like/expr.py @@ -206,7 +206,7 @@ def ewm_mean( def over( # noqa: C901, PLR0915 self, partition_by: Sequence[str], order_by: Sequence[str] ) -> Self: - nodes = self._metadata.nodes + op_nodes_reversed = list(self._metadata.op_nodes_reversed()) if not partition_by: # e.g. `nw.col('a').cum_sum().order_by(key)` # We can always easily support this as it doesn't require grouping. @@ -222,7 +222,7 @@ def func(df: PandasLikeDataFrame) -> Sequence[PandasLikeSeries]: for s in results: s._scatter_in_place(sorting_indices, s) return results - elif len(nodes) > 2: + elif len(op_nodes_reversed) > 2: msg = ( "Only elementary expressions are supported for `.over` in pandas-like backends.\n\n" "Please see: " @@ -230,8 +230,8 @@ def func(df: PandasLikeDataFrame) -> Sequence[PandasLikeSeries]: ) raise NotImplementedError(msg) else: - assert nodes # noqa: S101 - leaf_node = nodes[-1] + assert op_nodes_reversed # noqa: S101 + leaf_node = op_nodes_reversed[0] function_name = leaf_node.name pandas_agg = PandasLikeGroupBy._REMAP_AGGS.get( cast("NarwhalsAggregation", function_name) From c602133e5578c78a4d21522da42d5259087e6c0a Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sat, 27 Sep 2025 18:11:43 +0100 Subject: [PATCH 30/95] minor things --- narwhals/_polars/expr.py | 8 ++++---- narwhals/functions.py | 21 ++++++++------------- 2 files changed, 12 insertions(+), 17 deletions(-) diff --git a/narwhals/_polars/expr.py b/narwhals/_polars/expr.py index c6f51a7737..2e599b960d 100644 --- a/narwhals/_polars/expr.py +++ b/narwhals/_polars/expr.py @@ -80,6 +80,10 @@ def __repr__(self) -> str: # pragma: no cover def _with_native(self, expr: pl.Expr) -> Self: return self.__class__(expr, self._version) + def broadcast(self) -> Self: + # Let Polars do its thing. + return self + @property def _metadata(self) -> ExprMetadata: assert self._opt_metadata is not None # noqa: S101 @@ -92,10 +96,6 @@ def func(*args: Any, **kwargs: Any) -> Any: return func - def broadcast(self) -> Self: - # Let Polars do its thing. - return self - def _renamed_min_periods(self, min_samples: int, /) -> dict[str, Any]: name = "min_periods" if self._backend_version < (1, 21, 0) else "min_samples" return {name: min_samples} diff --git a/narwhals/functions.py b/narwhals/functions.py index 75a26eb67a..0338f87340 100644 --- a/narwhals/functions.py +++ b/narwhals/functions.py @@ -989,8 +989,7 @@ def nth(*indices: int | Sequence[int]) -> Expr: └──────────────────┘ """ flat_indices = flatten(indices) - node = ExprNode(ExprKind.NTH, "nth", indices=flat_indices) - return Expr(node) + return Expr(ExprNode(ExprKind.NTH, "nth", indices=flat_indices)) # Add underscore so it doesn't conflict with builtin `all` @@ -1014,8 +1013,7 @@ def all_() -> Expr: | 1 4 0.246 | └──────────────────┘ """ - node = ExprNode(ExprKind.ALL, "all") - return Expr(node) + return Expr(ExprNode(ExprKind.ALL, "all")) # Add underscore so it doesn't conflict with builtin `len` @@ -1044,8 +1042,7 @@ def len_() -> Expr: | └─────┘ | └──────────────────┘ """ - node = ExprNode(ExprKind.AGGREGATION, "len") - return Expr(node) + return Expr(ExprNode(ExprKind.AGGREGATION, "len")) def sum(*columns: str) -> Expr: @@ -1208,8 +1205,9 @@ def _expr_with_horizontal_op(name: str, *exprs: IntoExpr, **kwargs: Any) -> Expr if not exprs: msg = f"At least one expression must be passed to `{name}`" raise ValueError(msg) - node = ExprNode(ExprKind.ELEMENTWISE, name, *exprs, **kwargs, allow_multi_output=True) - return Expr(node) + return Expr( + ExprNode(ExprKind.ELEMENTWISE, name, *exprs, **kwargs, allow_multi_output=True) + ) def sum_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: @@ -1340,7 +1338,6 @@ def then(self, value: IntoExpr | NonNestedLiteral) -> Then: class Then(Expr): def otherwise(self, value: IntoExpr | NonNestedLiteral) -> Expr: - # eject latest node, replace with `when_then_otherwise` node = self._nodes[0] return Expr(ExprNode(ExprKind.ELEMENTWISE, "when_then", *node.exprs, value)) @@ -1471,8 +1468,7 @@ def lit(value: NonNestedLiteral, dtype: IntoDType | None = None) -> Expr: msg = f"Nested datatypes are not supported yet. Got {value}" raise NotImplementedError(msg) - node = ExprNode(ExprKind.LITERAL, "lit", value=value, dtype=dtype) - return Expr(node) + return Expr(ExprNode(ExprKind.LITERAL, "lit", value=value, dtype=dtype)) def any_horizontal(*exprs: IntoExpr | Iterable[IntoExpr], ignore_nulls: bool) -> Expr: @@ -1672,5 +1668,4 @@ def coalesce( ) raise TypeError(msg) - node = ExprNode(ExprKind.ELEMENTWISE, "coalesce", *flat_exprs) - return Expr(node) + return Expr(ExprNode(ExprKind.ELEMENTWISE, "coalesce", *flat_exprs)) From 0996ffe9eb9921ae4b1fcd6ddb7076ef8af644b5 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sat, 27 Sep 2025 18:26:43 +0100 Subject: [PATCH 31/95] simplify --- narwhals/_arrow/expr.py | 7 ++----- narwhals/_compliant/expr.py | 6 +++--- narwhals/_compliant/series.py | 2 +- narwhals/_dask/expr.py | 2 -- narwhals/_duckdb/expr.py | 2 -- narwhals/_ibis/expr.py | 2 -- narwhals/_pandas_like/expr.py | 2 -- narwhals/_spark_like/expr.py | 2 -- narwhals/_sql/expr.py | 2 -- 9 files changed, 6 insertions(+), 21 deletions(-) diff --git a/narwhals/_arrow/expr.py b/narwhals/_arrow/expr.py index 80e7072527..c293d1b8f0 100644 --- a/narwhals/_arrow/expr.py +++ b/narwhals/_arrow/expr.py @@ -20,8 +20,7 @@ from narwhals._arrow.dataframe import ArrowDataFrame from narwhals._arrow.namespace import ArrowNamespace - from narwhals._compliant.typing import AliasNames, EvalNames, EvalSeries, ScalarKwargs - from narwhals._expression_parsing import ExprMetadata + from narwhals._compliant.typing import AliasNames, EvalNames, EvalSeries from narwhals._utils import Version, _LimitedContext @@ -35,14 +34,12 @@ def __init__( evaluate_output_names: EvalNames[ArrowDataFrame], alias_output_names: AliasNames | None, version: Version, - scalar_kwargs: ScalarKwargs | None = None, - implementation: Implementation | None = None, + implementation: Implementation = Implementation.PYARROW, ) -> None: self._call = call self._evaluate_output_names = evaluate_output_names self._alias_output_names = alias_output_names self._version = version - self._opt_metadata: ExprMetadata | None = None @classmethod def from_column_names( diff --git a/narwhals/_compliant/expr.py b/narwhals/_compliant/expr.py index cf74ed17b2..a7c534469e 100644 --- a/narwhals/_compliant/expr.py +++ b/narwhals/_compliant/expr.py @@ -91,12 +91,12 @@ class CompliantExpr( _implementation: Implementation _evaluate_output_names: EvalNames[CompliantFrameT] _alias_output_names: AliasNames | None - _opt_metadata: ExprMetadata | None + # This should be set with extreme care, only in `_expression_parsing.py`, + # and never from within any compliant class. + _opt_metadata: ExprMetadata | None = None @property def _metadata(self) -> ExprMetadata: - # This should be set with extreme care, and only at the Narwhals level or in - # `_expression_parsing.py`, and never from within any compliant class. assert self._opt_metadata is not None # noqa: S101 return self._opt_metadata diff --git a/narwhals/_compliant/series.py b/narwhals/_compliant/series.py index 6d456e944a..2ed7c6989f 100644 --- a/narwhals/_compliant/series.py +++ b/narwhals/_compliant/series.py @@ -322,7 +322,7 @@ class EagerSeriesDateTimeNamespace( # type: ignore[misc] ): ... -class EagerSeriesListNamespace( +class EagerSeriesListNamespace( # pyright: ignore[reportInvalidTypeVarUse] _SeriesNamespace[EagerSeriesT, NativeSeriesT_co], ListNamespace[EagerSeriesT], Protocol[EagerSeriesT, NativeSeriesT_co], diff --git a/narwhals/_dask/expr.py b/narwhals/_dask/expr.py index b15b08e3a6..7d8e9c7449 100644 --- a/narwhals/_dask/expr.py +++ b/narwhals/_dask/expr.py @@ -37,7 +37,6 @@ ) from narwhals._dask.dataframe import DaskLazyFrame from narwhals._dask.namespace import DaskNamespace - from narwhals._expression_parsing import ExprMetadata from narwhals._utils import Version, _LimitedContext from narwhals.typing import ( FillNullStrategy, @@ -68,7 +67,6 @@ def __init__( self._evaluate_output_names = evaluate_output_names self._alias_output_names = alias_output_names self._version = version - self._opt_metadata: ExprMetadata | None = None def __call__(self, df: DaskLazyFrame) -> Sequence[dx.Series]: return self._call(df) diff --git a/narwhals/_duckdb/expr.py b/narwhals/_duckdb/expr.py index 6aa6310514..76ca00b198 100644 --- a/narwhals/_duckdb/expr.py +++ b/narwhals/_duckdb/expr.py @@ -36,7 +36,6 @@ ) from narwhals._duckdb.dataframe import DuckDBLazyFrame from narwhals._duckdb.namespace import DuckDBNamespace - from narwhals._expression_parsing import ExprMetadata from narwhals._utils import _LimitedContext from narwhals.typing import ( FillNullStrategy, @@ -66,7 +65,6 @@ def __init__( self._evaluate_output_names = evaluate_output_names self._alias_output_names = alias_output_names self._version = version - self._opt_metadata: ExprMetadata | None = None self._window_function: DuckDBWindowFunction | None = window_function def _count_star(self) -> Expression: diff --git a/narwhals/_ibis/expr.py b/narwhals/_ibis/expr.py index a282ce78ef..8f11a30bae 100644 --- a/narwhals/_ibis/expr.py +++ b/narwhals/_ibis/expr.py @@ -36,7 +36,6 @@ EvalSeries, WindowFunction, ) - from narwhals._expression_parsing import ExprMetadata from narwhals._ibis.dataframe import IbisLazyFrame from narwhals._ibis.namespace import IbisNamespace from narwhals._utils import _LimitedContext @@ -64,7 +63,6 @@ def __init__( self._evaluate_output_names = evaluate_output_names self._alias_output_names = alias_output_names self._version = version - self._opt_metadata: ExprMetadata | None = None self._window_function: IbisWindowFunction | None = window_function @property diff --git a/narwhals/_pandas_like/expr.py b/narwhals/_pandas_like/expr.py index 8ef72904f7..14f8592645 100644 --- a/narwhals/_pandas_like/expr.py +++ b/narwhals/_pandas_like/expr.py @@ -19,7 +19,6 @@ EvalSeries, NarwhalsAggregation, ) - from narwhals._expression_parsing import ExprMetadata from narwhals._pandas_like.dataframe import PandasLikeDataFrame from narwhals._pandas_like.namespace import PandasLikeNamespace from narwhals._utils import Implementation, Version, _LimitedContext @@ -126,7 +125,6 @@ def __init__( self._alias_output_names = alias_output_names self._implementation = implementation self._version = version - self._opt_metadata: ExprMetadata | None = None def __narwhals_namespace__(self) -> PandasLikeNamespace: from narwhals._pandas_like.namespace import PandasLikeNamespace diff --git a/narwhals/_spark_like/expr.py b/narwhals/_spark_like/expr.py index dacb01c382..dfbe8b2247 100644 --- a/narwhals/_spark_like/expr.py +++ b/narwhals/_spark_like/expr.py @@ -31,7 +31,6 @@ EvalSeries, WindowFunction, ) - from narwhals._expression_parsing import ExprMetadata from narwhals._spark_like.dataframe import SparkLikeLazyFrame from narwhals._spark_like.namespace import SparkLikeNamespace from narwhals._utils import _LimitedContext @@ -58,7 +57,6 @@ def __init__( self._alias_output_names = alias_output_names self._version = version self._implementation = implementation - self._opt_metadata: ExprMetadata | None = None self._window_function: SparkWindowFunction | None = window_function _REMAP_RANK_METHOD: ClassVar[Mapping[RankMethod, NativeRankMethod]] = { diff --git a/narwhals/_sql/expr.py b/narwhals/_sql/expr.py index ba0c99cc10..c68853f1e5 100644 --- a/narwhals/_sql/expr.py +++ b/narwhals/_sql/expr.py @@ -25,7 +25,6 @@ from typing_extensions import Self from narwhals._compliant.typing import AliasNames, WindowFunction - from narwhals._expression_parsing import ExprMetadata from narwhals._sql.expr_dt import SQLExprDateTimeNamesSpace from narwhals._sql.expr_str import SQLExprStringNamespace from narwhals._sql.namespace import SQLNamespace @@ -44,7 +43,6 @@ class SQLExpr(LazyExpr[SQLLazyFrameT, NativeExprT], Protocol[SQLLazyFrameT, Nati _alias_output_names: AliasNames | None _version: Version _implementation: Implementation - _opt_metadata: ExprMetadata | None _window_function: WindowFunction[SQLLazyFrameT, NativeExprT] | None def __init__( From 0b6d2e56643ed376fbf60653a021072108b91e5f Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sun, 28 Sep 2025 22:14:38 +0100 Subject: [PATCH 32/95] post merge fixup --- narwhals/_dask/expr.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/narwhals/_dask/expr.py b/narwhals/_dask/expr.py index 5c45cdbe7a..1501fde04b 100644 --- a/narwhals/_dask/expr.py +++ b/narwhals/_dask/expr.py @@ -560,7 +560,7 @@ def func(df: DaskLazyFrame) -> Sequence[dx.Series]: ) raise NotImplementedError(msg) from None dask_kwargs = window_kwargs_to_pandas_equivalent( - function_name, self._scalar_kwargs + function_name, leaf_node.kwargs ) def func(df: DaskLazyFrame) -> Sequence[dx.Series]: @@ -574,10 +574,6 @@ def func(df: DaskLazyFrame) -> Sequence[dx.Series]: category=UserWarning, ) grouped = df.native.groupby(partition_by) - kwargs = leaf_node.kwargs - pandas_kwargs = window_kwargs_to_pandas_equivalent( - function_name, kwargs - ) if dask_function_name == "size": if len(output_names) != 1: # pragma: no cover msg = "Safety check failed, please report a bug." From dfbdeeed849bd96d351dce3678ac5b76d65f51b5 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Tue, 30 Sep 2025 07:58:44 +0100 Subject: [PATCH 33/95] reduce selectors diff --- narwhals/selectors.py | 36 ++++++++++++++++++++++-------------- tests/selectors_test.py | 7 +++++-- 2 files changed, 27 insertions(+), 16 deletions(-) diff --git a/narwhals/selectors.py b/narwhals/selectors.py index 31db79f86a..32dbdd5d2c 100644 --- a/narwhals/selectors.py +++ b/narwhals/selectors.py @@ -18,45 +18,53 @@ class Selector(Expr): def _to_expr(self) -> Expr: return Expr(*self._nodes) - def __rsub__(self, other: Any) -> NoReturn: - raise NotImplementedError - - def __rand__(self, other: Any) -> NoReturn: - raise NotImplementedError - - def __ror__(self, other: Any) -> NoReturn: - raise NotImplementedError + def __add__(self, other: Any) -> Expr: # type: ignore[override] + if isinstance(other, Selector): + msg = "unsupported operand type(s) for op: ('Selector' + 'Selector')" + raise TypeError(msg) + return self._to_expr()._with_node( + ExprNode(ExprKind.ELEMENTWISE, "__add__", other, str_as_lit=True) + ) - def __and__(self, other: Any) -> Expr: # type: ignore[override] + def __or__(self, other: Any) -> Expr: # type: ignore[override] if isinstance(other, Selector): return self._with_node( ExprNode( ExprKind.ELEMENTWISE, - "__and__", + "__or__", other, str_as_lit=True, allow_multi_output=True, ) ) return self._to_expr()._with_node( - ExprNode(ExprKind.ELEMENTWISE, "__and__", other, str_as_lit=True) + ExprNode(ExprKind.ELEMENTWISE, "__or__", other, str_as_lit=True) ) - def __or__(self, other: Any) -> Expr: # type: ignore[override] + def __and__(self, other: Any) -> Expr: # type: ignore[override] if isinstance(other, Selector): return self._with_node( ExprNode( ExprKind.ELEMENTWISE, - "__or__", + "__and__", other, str_as_lit=True, allow_multi_output=True, ) ) return self._to_expr()._with_node( - ExprNode(ExprKind.ELEMENTWISE, "__or__", other, str_as_lit=True) + ExprNode(ExprKind.ELEMENTWISE, "__and__", other, str_as_lit=True) ) + def __rsub__(self, other: Any) -> NoReturn: + raise NotImplementedError + + def __rand__(self, other: Any) -> NoReturn: + raise NotImplementedError + + def __ror__(self, other: Any) -> NoReturn: + raise NotImplementedError + def by_dtype(*dtypes: DType | type[DType] | Iterable[DType | type[DType]]) -> Selector: """Select columns based on their dtype. diff --git a/tests/selectors_test.py b/tests/selectors_test.py index ba780299b8..63c6b387e5 100644 --- a/tests/selectors_test.py +++ b/tests/selectors_test.py @@ -1,5 +1,6 @@ from __future__ import annotations +import re from datetime import datetime, timezone from typing import Literal @@ -7,7 +8,6 @@ import narwhals as nw import narwhals.selectors as ncs -from narwhals.exceptions import MultiOutputExpressionError from tests.utils import ( PANDAS_VERSION, POLARS_VERSION, @@ -246,7 +246,10 @@ def test_set_ops_invalid(constructor: Constructor) -> None: with pytest.raises((NotImplementedError, ValueError)): df.select(1 & ncs.numeric()) - with pytest.raises(MultiOutputExpressionError): + with pytest.raises( + TypeError, + match=re.escape("unsupported operand type(s) for op: ('Selector' + 'Selector')"), + ): df.select(ncs.boolean() + ncs.numeric()) From 1ebf11e49b5aa9ff1bb5b610819c12471a8ec4b8 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Tue, 30 Sep 2025 08:25:45 +0100 Subject: [PATCH 34/95] document ExprNodes --- docs/how_it_works.md | 62 +++++++++++++++++++++++++++++++-- narwhals/_expression_parsing.py | 11 ++++++ 2 files changed, 70 insertions(+), 3 deletions(-) diff --git a/docs/how_it_works.md b/docs/how_it_works.md index 6b79a6f050..97dd528d87 100644 --- a/docs/how_it_works.md +++ b/docs/how_it_works.md @@ -264,6 +264,62 @@ In Narwhals, here's what we do: users of the performance penalty and advising them to refactor their code so that the aggregation they perform ends up being a simple one. +## Nodes + +If we have a Narwhals expression, we can look at the operations which make it up by accessing `_nodes`: + +```python exec="1" result="python" session="pandas_impl" source="above" +import narwhals as nw + +expr = nw.col("a").abs().std(ddof=1) + nw.col("b") +print(expr._nodes) +``` + +Each node represents an operation. Here, we have 4 operations: + +1. Given some dataframe, select column `'a'`. +2. Take its absolute value. +3. Take its standard deviation, with `ddof=1`. +4. Sum column `'b'`. + +Let's take a look at a couple of these nodes. Let's start with the third one: + +```python exec="1" result="python" session="pandas_impl" source="above" +print(expr._nodes[2].as_dict()) +``` + +This tells us a few things: + +- We're performing an aggregation. +- The name of the function is `'std'`. This will be looked up in the compliant object. +- It takes keyword arguments `ddof=1`. +- We'll look at the others later. + +In order for the evaluation to succeed, then `PandasLikeExpr` must have a `std` method defined +on it, which takes a `ddof` argument. And this is what the `CompliantExpr` Protocol is for: so +long as a backend's implementation complies with the protocol, then Narwhals will be able to +unpack a `ExprNode` and turn it into a valid call. + +Let's take a look at the fourth node: + +```python exec="1" result="python" session="pandas_impl" source="above" +print(expr._nodes[3].as_dict()) +``` + +Note how now, the `exprs` attribute is populated. Indeed, we are summing another expression: `col('b')`. +The `exprs` parameter holds arguments which are either expressions, or should be interpreted as expressions. +The `str_as_lit` parameter tells us whether string literals should be interpreted as literals (e.g. `lit('foo')`) +or columns (e.g. `col('foo')`). Finally `allow_multi_output` tells us whether multi-outuput expressions +(more on this in the next section) are allowed to appear in `exprs`. + +Node that the expression in `exprs` also has its own nodes: + +```python exec="1" result="python" session="pandas_impl" source="above" +print(expr._nodes[3].exprs[0]._nodes) +``` + +It's nodes all the way down! + ## Expression Metadata Let's try printing out some compliant expressions' metadata to see what it shows us: @@ -307,7 +363,7 @@ Here's a brief description of each piece of metadata: only on literal values, like `nw.lit(1)`. - `nodes`: List of operations which this expression applies when evaluated. -#### Chaining +### Chaining Say we have `expr.expr_method()`. How does `expr`'s `ExprMetadata` change? This depends on `expr_method`. Details can be found in `narwhals/_expression_parsing`, @@ -351,7 +407,7 @@ is: then `n_orderable_ops` is decreased by 1. This is the only way that `n_orderable_ops` can decrease. -### Broadcasting +## Broadcasting When performing comparisons between columns and aggregations or scalars, we operate as if the aggregation or scalar was broadcasted to the length of the whole column. For example, if we @@ -373,7 +429,7 @@ Narwhals triggers a broadcast in these situations: Each backend is then responsible for doing its own broadcasting, as defined in each `CompliantExpr.broadcast` method. -### Elementwise push-down +## Elementwise push-down SQL is picky about `over` operations. For example: diff --git a/narwhals/_expression_parsing.py b/narwhals/_expression_parsing.py index 68d71ead65..9b7251e22d 100644 --- a/narwhals/_expression_parsing.py +++ b/narwhals/_expression_parsing.py @@ -253,6 +253,17 @@ def __repr__(self) -> str: arg_str.append(kwargs_repr) return f"{self.name}({', '.join(arg_str)})" + def as_dict(self) -> dict[str, Any]: # pragma: no cover + # Just for debugging. + return { + "kind": self.kind, + "name": self.name, + "exprs": self.exprs, + "kwargs": self.kwargs, + "str_as_lit": self.str_as_lit, + "allow_multi_output": self.allow_multi_output, + } + def _with_kwargs(self, **kwargs: Any) -> ExprNode: return self.__class__( self.kind, self.name, *self.exprs, str_as_lit=self.str_as_lit, **kwargs From bd1ccea0607b18b9d30e97a1af6cd3fd48e889af Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Tue, 30 Sep 2025 08:51:40 +0100 Subject: [PATCH 35/95] dask fix --- narwhals/_dask/expr.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/narwhals/_dask/expr.py b/narwhals/_dask/expr.py index 4bea031b8b..788406a983 100644 --- a/narwhals/_dask/expr.py +++ b/narwhals/_dask/expr.py @@ -220,8 +220,6 @@ def func(df: DaskLazyFrame) -> list[dx.Series]: return self.__class__( func, - depth=self._depth + 1, - function_name=self._function_name + "->__floordiv__", evaluate_output_names=self._evaluate_output_names, alias_output_names=self._alias_output_names, version=self._version, @@ -275,8 +273,6 @@ def func(df: DaskLazyFrame) -> list[dx.Series]: return self.__class__( func, - depth=self._depth + 1, - function_name=self._function_name + "->__rfloordiv__", evaluate_output_names=self._evaluate_output_names, alias_output_names=self._alias_output_names, version=self._version, From 18137e1437a8c92ddff6baebebde994481888b9e Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Tue, 30 Sep 2025 15:29:00 +0100 Subject: [PATCH 36/95] simplify --- narwhals/_arrow/series.py | 12 +++------ narwhals/_arrow/series_str.py | 19 ++++++++++---- narwhals/_compliant/any_namespace.py | 2 +- narwhals/_compliant/column.py | 8 +----- narwhals/_compliant/expr.py | 14 ++++------ narwhals/_compliant/namespace.py | 18 ++++++------- narwhals/_dask/dataframe.py | 11 +++++--- narwhals/_dask/expr.py | 39 ++++++++++++++++------------ narwhals/_dask/expr_str.py | 33 +++++++++++------------ narwhals/_dask/utils.py | 12 +++------ narwhals/_pandas_like/series.py | 8 +----- narwhals/_pandas_like/series_str.py | 13 +++++++--- narwhals/_sql/expr.py | 14 ++-------- 13 files changed, 95 insertions(+), 108 deletions(-) diff --git a/narwhals/_arrow/series.py b/narwhals/_arrow/series.py index 7b402599d4..0634c94cfc 100644 --- a/narwhals/_arrow/series.py +++ b/narwhals/_arrow/series.py @@ -67,12 +67,10 @@ IntoDType, ModeKeepStrategy, NonNestedLiteral, - NumericLiteral, PythonLiteral, RankMethod, RollingInterpolationMethod, SizedMultiIndexSelector, - TemporalLiteral, _1DArray, _2DArray, _SliceIndex, @@ -835,11 +833,7 @@ def quantile( def gather_every(self, n: int, offset: int = 0) -> Self: return self._with_native(self.native[offset::n]) - def clip( - self, - lower_bound: Self | NumericLiteral | TemporalLiteral | None, - upper_bound: Self | NumericLiteral | TemporalLiteral | None, - ) -> Self: + def clip(self, lower_bound: Self, upper_bound: Self) -> Self: _, lower = ( extract_native(self, lower_bound) if lower_bound is not None else (None, None) ) @@ -847,9 +841,9 @@ def clip( extract_native(self, upper_bound) if upper_bound is not None else (None, None) ) - if lower is None: + if lower is None or isinstance(lower, pa.NullScalar): return self._with_native(pc.min_element_wise(self.native, upper)) - if upper is None: + if upper is None or isinstance(upper, pa.NullScalar): return self._with_native(pc.max_element_wise(self.native, lower)) return self._with_native( pc.max_element_wise(pc.min_element_wise(self.native, upper), lower) diff --git a/narwhals/_arrow/series_str.py b/narwhals/_arrow/series_str.py index b9b7976058..a872fed192 100644 --- a/narwhals/_arrow/series_str.py +++ b/narwhals/_arrow/series_str.py @@ -6,7 +6,12 @@ import pyarrow as pa import pyarrow.compute as pc -from narwhals._arrow.utils import ArrowSeriesNamespace, lit, parse_datetime_format +from narwhals._arrow.utils import ( + ArrowSeriesNamespace, + extract_native, + lit, + parse_datetime_format, +) from narwhals._compliant.any_namespace import StringNamespace if TYPE_CHECKING: @@ -22,21 +27,25 @@ def replace( self, value: ArrowSeries | str, pattern: str, *, literal: bool, n: int ) -> ArrowSeries: fn = pc.replace_substring if literal else pc.replace_substring_regex - if not isinstance(value, str): + _, value_native = extract_native(self.compliant, value) + if not isinstance(value_native, pa.StringScalar): msg = "PyArrow backed `.str.replace` only supports str replacement values" raise TypeError(msg) - arr = fn(self.native, pattern, replacement=value, max_replacements=n) + arr = fn( + self.native, pattern, replacement=value_native.as_py(), max_replacements=n + ) return self.with_native(arr) def replace_all( self, value: ArrowSeries | str, pattern: str, *, literal: bool ) -> ArrowSeries: - if not isinstance(value, str): + _, value_native = extract_native(self.compliant, value) + if not isinstance(value_native, pa.StringScalar): msg = ( "PyArrow backed `.str.replace_all` only supports str replacement values." ) raise TypeError(msg) - return self.replace(value, pattern, literal=literal, n=-1) + return self.replace(value_native.as_py(), pattern, literal=literal, n=-1) def strip_chars(self, characters: str | None) -> ArrowSeries: return self.with_native( diff --git a/narwhals/_compliant/any_namespace.py b/narwhals/_compliant/any_namespace.py index 503b873604..c3c2f2e16f 100644 --- a/narwhals/_compliant/any_namespace.py +++ b/narwhals/_compliant/any_namespace.py @@ -87,7 +87,7 @@ class StringNamespace(_StoresCompliant[T], Protocol[T]): _accessor: ClassVar[Accessor] = "str" def len_chars(self) -> T: ... - def replace(self, value: str, pattern: str, *, literal: bool, n: int) -> T: ... + def replace(self, value: T, pattern: str, *, literal: bool, n: int) -> T: ... def replace_all(self, value: T, pattern: str, *, literal: bool) -> T: ... def strip_chars(self, characters: str | None) -> T: ... def starts_with(self, prefix: str) -> T: ... diff --git a/narwhals/_compliant/column.py b/narwhals/_compliant/column.py index 1dc680c023..fc21045ddc 100644 --- a/narwhals/_compliant/column.py +++ b/narwhals/_compliant/column.py @@ -22,9 +22,7 @@ IntoDType, ModeKeepStrategy, NonNestedLiteral, - NumericLiteral, RankMethod, - TemporalLiteral, ) __all__ = ["CompliantColumn"] @@ -62,11 +60,7 @@ def __narwhals_namespace__(self) -> CompliantNamespace[Any, Any]: ... def abs(self) -> Self: ... def alias(self, name: str) -> Self: ... def cast(self, dtype: IntoDType) -> Self: ... - def clip( - self, - lower_bound: Self | NumericLiteral | TemporalLiteral | None, - upper_bound: Self | NumericLiteral | TemporalLiteral | None, - ) -> Self: ... + def clip(self, lower_bound: Self, upper_bound: Self) -> Self: ... def cum_count(self, *, reverse: bool) -> Self: ... def cum_max(self, *, reverse: bool) -> Self: ... def cum_min(self, *, reverse: bool) -> Self: ... diff --git a/narwhals/_compliant/expr.py b/narwhals/_compliant/expr.py index a7c534469e..ff9439d7b5 100644 --- a/narwhals/_compliant/expr.py +++ b/narwhals/_compliant/expr.py @@ -53,10 +53,8 @@ IntoDType, ModeKeepStrategy, NonNestedLiteral, - NumericLiteral, RankMethod, RollingInterpolationMethod, - TemporalLiteral, TimeUnit, ) @@ -548,11 +546,7 @@ def arg_max(self) -> Self: # Other - def clip( - self, - lower_bound: Self | NumericLiteral | TemporalLiteral | None, - upper_bound: Self | NumericLiteral | TemporalLiteral | None, - ) -> Self: + def clip(self, lower_bound: Self, upper_bound: Self) -> Self: return self._reuse_series( "clip", lower_bound=lower_bound, upper_bound=upper_bound ) @@ -1030,13 +1024,15 @@ class EagerExprStringNamespace( def len_chars(self) -> EagerExprT: return self.compliant._reuse_series_namespace("str", "len_chars") - def replace(self, value: str, pattern: str, *, literal: bool, n: int) -> EagerExprT: + def replace( + self, value: EagerExprT, pattern: str, *, literal: bool, n: int + ) -> EagerExprT: return self.compliant._reuse_series_namespace( "str", "replace", pattern=pattern, value=value, literal=literal, n=n ) def replace_all( - self, value: EagerExprT | str, pattern: str, *, literal: bool + self, value: EagerExprT, pattern: str, *, literal: bool ) -> EagerExprT: return self.compliant._reuse_series_namespace( "str", "replace_all", pattern=pattern, value=value, literal=literal diff --git a/narwhals/_compliant/namespace.py b/narwhals/_compliant/namespace.py index d6554db2fc..c3e3de31a1 100644 --- a/narwhals/_compliant/namespace.py +++ b/narwhals/_compliant/namespace.py @@ -1,7 +1,7 @@ from __future__ import annotations from functools import partial -from typing import TYPE_CHECKING, Any, Protocol, overload +from typing import TYPE_CHECKING, Any, Protocol, cast, overload from narwhals._compliant.typing import ( CompliantExprT, @@ -59,16 +59,14 @@ class CompliantNamespace(Protocol[CompliantFrameT, CompliantExprT]): @property def _expr(self) -> type[CompliantExprT]: ... - def evaluate_expr( - self, data: Expr | NonNestedLiteral | Any, / - ) -> CompliantExprT | NonNestedLiteral: + def evaluate_expr(self, data: Expr | NonNestedLiteral, /) -> CompliantExprT: if is_expr(data): - expr = data(self) - assert isinstance(expr, self._expr) # noqa: S101 - return expr - # TODO(marco): it would be nice to return `lit(data)` here, - # but for pandas and Dask this causes some issues. - return data + ret = data(self) + else: + from narwhals.functions import lit + + ret = lit(data)(self) + return cast("CompliantExprT", ret) # NOTE: `polars` def all(self) -> CompliantExprT: diff --git a/narwhals/_dask/dataframe.py b/narwhals/_dask/dataframe.py index 6bddb729a1..283cf6a963 100644 --- a/narwhals/_dask/dataframe.py +++ b/narwhals/_dask/dataframe.py @@ -222,10 +222,10 @@ def with_row_index(self, name: str, order_by: Sequence[str] | None) -> Self: return self._with_native(add_row_index(self.native, name)) plx = self.__narwhals_namespace__() columns = self.columns - const_expr = plx.lit(value=1, dtype=None).alias(name).broadcast() + const_expr = plx.lit(1, dtype=None).alias(name).broadcast() row_index_expr = ( plx.col(name).cum_sum(reverse=False).over(partition_by=[], order_by=order_by) - - 1 + - plx.lit(1, dtype=None).broadcast() ) return self.with_columns(const_expr).select(row_index_expr, plx.col(*columns)) @@ -478,11 +478,14 @@ def tail(self, n: int) -> Self: # pragma: no cover def gather_every(self, n: int, offset: int) -> Self: row_index_token = generate_temporary_column_name(n_bytes=8, columns=self.columns) plx = self.__narwhals_namespace__() + offset_expr = plx.lit(offset, dtype=None).broadcast() + n_expr = plx.lit(n, dtype=None).broadcast() + zero_expr = plx.lit(0, dtype=None).broadcast() return ( self.with_row_index(row_index_token, order_by=None) .filter( - (plx.col(row_index_token) >= offset) - & ((plx.col(row_index_token) - offset) % n == 0) + (plx.col(row_index_token) >= offset_expr) + & ((plx.col(row_index_token) - offset_expr) % n_expr == zero_expr) ) .drop([row_index_token], strict=False) ) diff --git a/narwhals/_dask/expr.py b/narwhals/_dask/expr.py index 788406a983..9876f985bf 100644 --- a/narwhals/_dask/expr.py +++ b/narwhals/_dask/expr.py @@ -11,7 +11,7 @@ from narwhals._dask.utils import ( add_row_index, align_series_full_broadcast, - maybe_evaluate_expr, + evaluate_expr, narwhals_to_native_dtype, ) from narwhals._expression_parsing import evaluate_output_names_and_aliases @@ -44,9 +44,7 @@ IntoDType, ModeKeepStrategy, NonNestedLiteral, - NumericLiteral, RollingInterpolationMethod, - TemporalLiteral, ) @@ -139,7 +137,7 @@ def func(df: DaskLazyFrame) -> list[dx.Series]: native_results: list[dx.Series] = [] native_series_list = self._call(df) other_native_series = { - key: maybe_evaluate_expr(df, value) + key: evaluate_expr(df, value) for key, value in expressifiable_args.items() } for native_series in native_series_list: @@ -215,7 +213,7 @@ def _floordiv( return (series.__floordiv__(other)).where(other != 0, None) def func(df: DaskLazyFrame) -> list[dx.Series]: - other_series = maybe_evaluate_expr(df, other) + other_series = evaluate_expr(df, other) return [_floordiv(df, series, other_series) for series in self(df)] return self.__class__( @@ -269,7 +267,8 @@ def _rfloordiv( return (other.__floordiv__(series)).where(series != 0, None) def func(df: DaskLazyFrame) -> list[dx.Series]: - return [_rfloordiv(df, series, other) for series in self(df)] + other_native = evaluate_expr(df, other) + return [_rfloordiv(df, series, other_native) for series in self(df)] return self.__class__( func, @@ -463,17 +462,23 @@ def func(expr: dx.Series) -> dx.Series: return self._with_callable(func) - def clip( - self, - lower_bound: Self | NumericLiteral | TemporalLiteral | None, - upper_bound: Self | NumericLiteral | TemporalLiteral | None, - ) -> Self: - return self._with_callable( - lambda expr, lower_bound, upper_bound: expr.clip( - lower=lower_bound, upper=upper_bound - ), - lower_bound=lower_bound, - upper_bound=upper_bound, + def clip(self, lower_bound: Self, upper_bound: Self) -> Self: + def func(df: DaskLazyFrame) -> list[dx.Series]: + lower_native: dx.Series = evaluate_expr(df, lower_bound) + upper_native: dx.Series = evaluate_expr(df, upper_bound) + if lower_native.dtype == "O": + # `lower_bound` was `lit(None)` + return [series.clip(upper=upper_native) for series in self(df)] + if upper_native.dtype == "O": + # `lower_bound` was `lit(None)` + return [series.clip(lower_native) for series in self(df)] + return [series.clip(lower_native, upper_native) for series in self(df)] + + return self.__class__( + func, + evaluate_output_names=self._evaluate_output_names, + alias_output_names=self._alias_output_names, + version=self._version, ) def diff(self) -> Self: diff --git a/narwhals/_dask/expr_str.py b/narwhals/_dask/expr_str.py index af884b4306..894054722e 100644 --- a/narwhals/_dask/expr_str.py +++ b/narwhals/_dask/expr_str.py @@ -19,29 +19,30 @@ def len_chars(self) -> DaskExpr: return self.compliant._with_callable(lambda expr: expr.str.len()) def replace( - self, value: DaskExpr | str, pattern: str, *, literal: bool, n: int + self, value: DaskExpr, pattern: str, *, literal: bool, n: int ) -> DaskExpr: - def _replace(expr: dx.Series, value: dx.Series | str) -> dx.Series: - if not isinstance(value, str): - msg = ( - "dask backed `Expr.str.replace` only supports str replacement values" - ) - raise TypeError(msg) + if not value._metadata.is_literal: + msg = "dask backed `Expr.str.replace` only supports str replacement values" + raise TypeError(msg) + + def _replace(expr: dx.Series, value: dx.Series) -> dx.Series: + # OK to call `compute` here as `value` is just a literal expression. return expr.str.replace( # pyright: ignore[reportAttributeAccessIssue] - pattern, value, regex=not literal, n=n + pattern, value.compute(), regex=not literal, n=n ) return self.compliant._with_callable(_replace, value=value) - def replace_all( - self, value: DaskExpr | str, pattern: str, *, literal: bool - ) -> DaskExpr: - def _replace_all(expr: dx.Series, value: dx.Series | str) -> dx.Series: - if not isinstance(value, str): - msg = "dask backed `Expr.str.replace_all` only supports str replacement values." - raise TypeError(msg) + def replace_all(self, value: DaskExpr, pattern: str, *, literal: bool) -> DaskExpr: + if not value._metadata.is_literal: + msg = ( + "dask backed `Expr.str.replace_all` only supports str replacement values" + ) + raise TypeError(msg) + + def _replace_all(expr: dx.Series, value: dx.Series) -> dx.Series: return expr.str.replace( # pyright: ignore[reportAttributeAccessIssue] - pattern, value, regex=not literal, n=-1 + pattern, value.compute(), regex=not literal, n=-1 ) return self.compliant._with_callable(_replace_all, value=value) diff --git a/narwhals/_dask/utils.py b/narwhals/_dask/utils.py index dd27940a2f..1cc1d337c4 100644 --- a/narwhals/_dask/utils.py +++ b/narwhals/_dask/utils.py @@ -23,14 +23,10 @@ import dask_expr as dx -def maybe_evaluate_expr(df: DaskLazyFrame, obj: DaskExpr | object) -> dx.Series | object: - from narwhals._dask.expr import DaskExpr - - if isinstance(obj, DaskExpr): - results = obj._call(df) - assert len(results) == 1 # debug assertion # noqa: S101 - return results[0] - return obj +def evaluate_expr(df: DaskLazyFrame, obj: DaskExpr) -> dx.Series: + results = obj._call(df) + assert len(results) == 1 # debug assertion # noqa: S101 + return results[0] def evaluate_exprs(df: DaskLazyFrame, /, *exprs: DaskExpr) -> list[tuple[str, dx.Series]]: diff --git a/narwhals/_pandas_like/series.py b/narwhals/_pandas_like/series.py index cdba4174f0..8b21532481 100644 --- a/narwhals/_pandas_like/series.py +++ b/narwhals/_pandas_like/series.py @@ -49,11 +49,9 @@ IntoDType, ModeKeepStrategy, NonNestedLiteral, - NumericLiteral, RankMethod, RollingInterpolationMethod, SizedMultiIndexSelector, - TemporalLiteral, _1DArray, _SliceIndex, ) @@ -823,11 +821,7 @@ def to_dummies(self, *, separator: str, drop_first: bool) -> PandasLikeDataFrame def gather_every(self, n: int, offset: int) -> Self: return self._with_native(self.native.iloc[offset::n]) - def clip( - self, - lower_bound: Self | NumericLiteral | TemporalLiteral | None, - upper_bound: Self | NumericLiteral | TemporalLiteral | None, - ) -> Self: + def clip(self, lower_bound: Self, upper_bound: Self) -> Self: _, lower = ( align_and_extract_native(self, lower_bound) if lower_bound is not None diff --git a/narwhals/_pandas_like/series_str.py b/narwhals/_pandas_like/series_str.py index c162bede4c..420058949d 100644 --- a/narwhals/_pandas_like/series_str.py +++ b/narwhals/_pandas_like/series_str.py @@ -3,7 +3,11 @@ from typing import TYPE_CHECKING, Any from narwhals._compliant.any_namespace import StringNamespace -from narwhals._pandas_like.utils import PandasLikeSeriesNamespace, is_dtype_pyarrow +from narwhals._pandas_like.utils import ( + PandasLikeSeriesNamespace, + align_and_extract_native, + is_dtype_pyarrow, +) if TYPE_CHECKING: from narwhals._pandas_like.series import PandasLikeSeries @@ -18,10 +22,13 @@ def len_chars(self) -> PandasLikeSeries: def replace( self, value: PandasLikeSeries | str, pattern: str, *, literal: bool, n: int ) -> PandasLikeSeries: - if not isinstance(value, str): + _, value_native = align_and_extract_native(self.compliant, value) + if not isinstance(value_native, str): msg = f"{self.compliant._implementation} backed `.str.replace` only supports str replacement values" raise TypeError(msg) - series = self.native.str.replace(pat=pattern, repl=value, n=n, regex=not literal) + series = self.native.str.replace( + pat=pattern, repl=value_native, n=n, regex=not literal + ) return self.with_native(series) def replace_all( diff --git a/narwhals/_sql/expr.py b/narwhals/_sql/expr.py index db6e4d2716..032c7da2aa 100644 --- a/narwhals/_sql/expr.py +++ b/narwhals/_sql/expr.py @@ -28,13 +28,7 @@ from narwhals._sql.expr_dt import SQLExprDateTimeNamesSpace from narwhals._sql.expr_str import SQLExprStringNamespace from narwhals._sql.namespace import SQLNamespace - from narwhals.typing import ( - ModeKeepStrategy, - NumericLiteral, - PythonLiteral, - RankMethod, - TemporalLiteral, - ) + from narwhals.typing import ModeKeepStrategy, PythonLiteral, RankMethod class SQLExpr(LazyExpr[SQLLazyFrameT, NativeExprT], Protocol[SQLLazyFrameT, NativeExprT]): @@ -518,11 +512,7 @@ def window_f( def abs(self) -> Self: return self._with_elementwise(lambda expr: self._function("abs", expr)) - def clip( - self, - lower_bound: Self | NumericLiteral | TemporalLiteral | None, - upper_bound: Self | NumericLiteral | TemporalLiteral | None, - ) -> Self: + def clip(self, lower_bound: Self, upper_bound: Self) -> Self: def _clip( expr: NativeExprT, lower_bound: NativeExprT, upper_bound: NativeExprT ) -> NativeExprT: From 255e828218902087e1a50825766d850d00d2ce59 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Tue, 30 Sep 2025 15:39:27 +0100 Subject: [PATCH 37/95] wip simpler --- narwhals/_compliant/namespace.py | 10 ++-------- narwhals/_expression_parsing.py | 4 ++-- 2 files changed, 4 insertions(+), 10 deletions(-) diff --git a/narwhals/_compliant/namespace.py b/narwhals/_compliant/namespace.py index c3e3de31a1..68dec2786b 100644 --- a/narwhals/_compliant/namespace.py +++ b/narwhals/_compliant/namespace.py @@ -16,7 +16,6 @@ NativeFrameT_co, NativeSeriesT, ) -from narwhals._expression_parsing import is_expr from narwhals._utils import ( exclude_column_names, get_column_names, @@ -59,13 +58,8 @@ class CompliantNamespace(Protocol[CompliantFrameT, CompliantExprT]): @property def _expr(self) -> type[CompliantExprT]: ... - def evaluate_expr(self, data: Expr | NonNestedLiteral, /) -> CompliantExprT: - if is_expr(data): - ret = data(self) - else: - from narwhals.functions import lit - - ret = lit(data)(self) + def evaluate_expr(self, data: Expr) -> CompliantExprT: + ret = data(self) return cast("CompliantExprT", ret) # NOTE: `polars` diff --git a/narwhals/_expression_parsing.py b/narwhals/_expression_parsing.py index 9b7251e22d..0c676f5434 100644 --- a/narwhals/_expression_parsing.py +++ b/narwhals/_expression_parsing.py @@ -763,7 +763,7 @@ def _parse_into_expr( backend: Any = None, allow_literal: bool = True, ) -> Expr | NonNestedLiteral: - from narwhals.functions import col, new_series + from narwhals.functions import col, lit, new_series if isinstance(arg, str) and not str_as_lit: return col(arg) @@ -775,7 +775,7 @@ def _parse_into_expr( return arg if not allow_literal: raise InvalidIntoExprError.from_invalid_type(type(arg)) - return arg + return lit(arg) def evaluate_into_exprs( From ae48754e4fb23bfe715e0a443b3ab75ebf5d28ff Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Tue, 30 Sep 2025 15:47:29 +0100 Subject: [PATCH 38/95] simpler typing --- narwhals/_compliant/namespace.py | 2 +- narwhals/_expression_parsing.py | 38 ++++++-------------------------- narwhals/_polars/namespace.py | 22 +++++------------- narwhals/_sql/namespace.py | 14 ++---------- 4 files changed, 15 insertions(+), 61 deletions(-) diff --git a/narwhals/_compliant/namespace.py b/narwhals/_compliant/namespace.py index 68dec2786b..fe6ada7d75 100644 --- a/narwhals/_compliant/namespace.py +++ b/narwhals/_compliant/namespace.py @@ -58,7 +58,7 @@ class CompliantNamespace(Protocol[CompliantFrameT, CompliantExprT]): @property def _expr(self) -> type[CompliantExprT]: ... - def evaluate_expr(self, data: Expr) -> CompliantExprT: + def evaluate_expr(self, data: Expr, /) -> CompliantExprT: ret = data(self) return cast("CompliantExprT", ret) diff --git a/narwhals/_expression_parsing.py b/narwhals/_expression_parsing.py index 0c676f5434..46e917f695 100644 --- a/narwhals/_expression_parsing.py +++ b/narwhals/_expression_parsing.py @@ -5,7 +5,7 @@ from __future__ import annotations from enum import Enum, auto -from typing import TYPE_CHECKING, Any, Literal, cast, overload +from typing import TYPE_CHECKING, Any, Literal, cast from narwhals._utils import is_compliant_expr, zip_strict from narwhals.dependencies import is_numpy_array_1d @@ -173,10 +173,8 @@ def is_elementwise(self) -> bool: } -def is_scalar_like(obj: CompliantExprAny | NonNestedLiteral) -> bool: - if is_compliant_expr(obj): - return obj._metadata.is_scalar_like - return True +def is_scalar_like(obj: CompliantExprAny) -> bool: + return obj._metadata.is_scalar_like class ExpansionKind(Enum): @@ -736,33 +734,13 @@ def check_expressions_preserve_length( raise InvalidOperationError(msg) -@overload -def _parse_into_expr( - arg: IntoExpr | NonNestedLiteral | _1DArray, - *, - str_as_lit: bool = False, - backend: Any = None, - allow_literal: Literal[False], -) -> Expr: ... - - -@overload -def _parse_into_expr( - arg: IntoExpr | NonNestedLiteral | _1DArray, - *, - str_as_lit: bool = False, - backend: Any = None, - allow_literal: Literal[True] = ..., -) -> Expr | NonNestedLiteral: ... - - def _parse_into_expr( arg: IntoExpr | NonNestedLiteral | _1DArray, *, str_as_lit: bool = False, backend: Any = None, allow_literal: bool = True, -) -> Expr | NonNestedLiteral: +) -> Expr: from narwhals.functions import col, lit, new_series if isinstance(arg, str) and not str_as_lit: @@ -783,7 +761,7 @@ def evaluate_into_exprs( ns: CompliantNamespaceAny, str_as_lit: bool, allow_multi_output: bool, -) -> Iterator[CompliantExprAny | NonNestedLiteral]: +) -> Iterator[CompliantExprAny]: for expr in exprs: ret = ns.evaluate_expr( _parse_into_expr(expr, str_as_lit=str_as_lit, backend=ns._implementation) @@ -798,11 +776,9 @@ def evaluate_into_exprs( yield ret -def maybe_broadcast_ces( - *ces: CompliantExprAny | NonNestedLiteral, -) -> list[CompliantExprAny | NonNestedLiteral]: +def maybe_broadcast_ces(*ces: CompliantExprAny) -> list[CompliantExprAny]: broadcast = any(not is_scalar_like(ce) for ce in ces) - results: list[CompliantExprAny | NonNestedLiteral] = [] + results: list[CompliantExprAny] = [] for compliant_expr in ces: if ( broadcast diff --git a/narwhals/_polars/namespace.py b/narwhals/_polars/namespace.py index 699f90f110..6ee35911d0 100644 --- a/narwhals/_polars/namespace.py +++ b/narwhals/_polars/namespace.py @@ -5,7 +5,6 @@ import polars as pl -from narwhals._expression_parsing import is_expr from narwhals._polars.expr import PolarsExpr from narwhals._polars.series import PolarsSeries from narwhals._polars.utils import extract_args_kwargs, narwhals_to_native_dtype @@ -22,14 +21,7 @@ from narwhals._polars.typing import FrameT from narwhals._utils import Version, _LimitedContext from narwhals.expr import Expr - from narwhals.typing import ( - Into1DArray, - IntoDType, - IntoSchema, - NonNestedLiteral, - TimeUnit, - _2DArray, - ) + from narwhals.typing import Into1DArray, IntoDType, IntoSchema, TimeUnit, _2DArray class PolarsNamespace: @@ -51,14 +43,10 @@ def _backend_version(self) -> tuple[int, ...]: def __init__(self, *, version: Version) -> None: self._version = version - def evaluate_expr( - self, data: Expr | NonNestedLiteral | Any, / - ) -> PolarsExpr | NonNestedLiteral: - if is_expr(data): - expr = data(self) - assert isinstance(expr, self._expr) # noqa: S101 - return expr - return data + def evaluate_expr(self, data: Expr, /) -> PolarsExpr: + expr = data(self) + assert isinstance(expr, self._expr) # noqa: S101 + return expr def __getattr__(self, attr: str) -> Any: def func(*args: Any, **kwargs: Any) -> Any: diff --git a/narwhals/_sql/namespace.py b/narwhals/_sql/namespace.py index e1e453c0df..c1e7ccb817 100644 --- a/narwhals/_sql/namespace.py +++ b/narwhals/_sql/namespace.py @@ -2,19 +2,16 @@ import operator from functools import reduce -from typing import TYPE_CHECKING, Any, Protocol, cast +from typing import TYPE_CHECKING, Any, Protocol from narwhals._compliant import LazyNamespace from narwhals._compliant.typing import NativeExprT, NativeFrameT_co -from narwhals._expression_parsing import is_expr from narwhals._sql.typing import SQLExprT, SQLLazyFrameT -from narwhals.functions import lit if TYPE_CHECKING: from collections.abc import Iterable - from narwhals.expr import Expr - from narwhals.typing import NonNestedLiteral, PythonLiteral + from narwhals.typing import PythonLiteral class SQLNamespace( @@ -31,13 +28,6 @@ def _when( ) -> NativeExprT: ... def _coalesce(self, *exprs: NativeExprT) -> NativeExprT: ... - def evaluate_expr(self, data: Expr | NonNestedLiteral | Any, /) -> SQLExprT: - if is_expr(data): - expr = data(self) - assert isinstance(expr, self._expr) # noqa: S101 - return expr - return cast("SQLExprT", lit(data)(self)) - # Horizontal functions def any_horizontal(self, *exprs: SQLExprT, ignore_nulls: bool) -> SQLExprT: def func(cols: Iterable[NativeExprT]) -> NativeExprT: From 6c054ccffe9af4559d3ccaa7fd693dcaf83648a0 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Tue, 30 Sep 2025 15:57:27 +0100 Subject: [PATCH 39/95] who even needs `is_compliant_expr` anymore? --- narwhals/_compliant/namespace.py | 19 +++------------- narwhals/_expression_parsing.py | 38 ++++++++------------------------ narwhals/_utils.py | 15 +------------ 3 files changed, 13 insertions(+), 59 deletions(-) diff --git a/narwhals/_compliant/namespace.py b/narwhals/_compliant/namespace.py index fe6ada7d75..e6a4480785 100644 --- a/narwhals/_compliant/namespace.py +++ b/narwhals/_compliant/namespace.py @@ -19,7 +19,6 @@ from narwhals._utils import ( exclude_column_names, get_column_names, - is_compliant_expr, passthrough_column_names, ) from narwhals.dependencies import is_numpy_array_2d @@ -155,34 +154,22 @@ def _if_then_else( otherwise: NativeSeriesT | None = None, ) -> NativeSeriesT: ... def when_then( - self, - predicate: EagerExprT, - then: EagerExprT | NonNestedLiteral, - otherwise: EagerExprT | NonNestedLiteral | None = None, + self, predicate: EagerExprT, then: EagerExprT, otherwise: EagerExprT | None = None ) -> EagerExprT: def func(df: EagerDataFrameT) -> Sequence[EagerSeriesT_co]: predicate_s = df._evaluate_expr(predicate) align = predicate_s._align_full_broadcast - if is_compliant_expr(then): - then_s = df._evaluate_expr(then) - else: - then_s = predicate_s._from_scalar(then).alias("literal") - then_s._broadcast = True + then_s = df._evaluate_expr(then) if otherwise is None: predicate_s, then_s = align(predicate_s, then_s) result = self._if_then_else(predicate_s.native, then_s.native) - if is_compliant_expr(otherwise): - otherwise_s = df._evaluate_expr(otherwise) - elif otherwise is not None: - otherwise_s = predicate_s._from_scalar(otherwise).alias("literal") - otherwise_s._broadcast = True - if otherwise is None: predicate_s, then_s = align(predicate_s, then_s) result = self._if_then_else(predicate_s.native, then_s.native) else: + otherwise_s = df._evaluate_expr(otherwise) predicate_s, then_s, otherwise_s = align(predicate_s, then_s, otherwise_s) result = self._if_then_else( predicate_s.native, then_s.native, otherwise_s.native diff --git a/narwhals/_expression_parsing.py b/narwhals/_expression_parsing.py index 46e917f695..478f7a9999 100644 --- a/narwhals/_expression_parsing.py +++ b/narwhals/_expression_parsing.py @@ -7,7 +7,7 @@ from enum import Enum, auto from typing import TYPE_CHECKING, Any, Literal, cast -from narwhals._utils import is_compliant_expr, zip_strict +from narwhals._utils import zip_strict from narwhals.dependencies import is_numpy_array_1d from narwhals.exceptions import ( InvalidIntoExprError, @@ -52,10 +52,6 @@ def combine_evaluate_output_names( ) -> EvalNames[CompliantFrameT]: # Follow left-hand-rule for naming. E.g. `nw.sum_horizontal(expr1, expr2)` takes the # first name of `expr1`. - if not is_compliant_expr(exprs[0]): # pragma: no cover - msg = f"Safety assertion failed, expected expression, got: {type(exprs[0])}. Please report a bug." - raise AssertionError(msg) - def evaluate_output_names(df: CompliantFrameT) -> Sequence[str]: return exprs[0]._evaluate_output_names(df)[:1] @@ -392,7 +388,7 @@ def __repr__(self) -> str: # pragma: no cover @classmethod def from_node( # noqa: PLR0911 - cls, node: ExprNode, *ces: CompliantExprAny | NonNestedLiteral + cls, node: ExprNode, *ces: CompliantExprAny ) -> ExprMetadata: if node.kind is ExprKind.SERIES: return cls.from_selector_single(node) @@ -422,10 +418,7 @@ def from_node( # noqa: PLR0911 raise AssertionError(msg) # pragma: no cover def with_node( # noqa: PLR0911,C901 - self, - node: ExprNode, - ce: CompliantExprAny, - *ces: CompliantExprAny | NonNestedLiteral, + self, node: ExprNode, ce: CompliantExprAny, *ces: CompliantExprAny ) -> ExprMetadata: if node.kind is ExprKind.AGGREGATION: return self.with_aggregation(node) @@ -490,9 +483,7 @@ def from_selector_multi_unnamed(cls, node: ExprNode) -> ExprMetadata: return cls(ExpansionKind.MULTI_UNNAMED, nodes=(node,)) @classmethod - def from_elementwise( - cls, node: ExprNode, *ces: CompliantExprAny | NonNestedLiteral - ) -> ExprMetadata: + def from_elementwise(cls, node: ExprNode, *ces: CompliantExprAny) -> ExprMetadata: return combine_metadata(*ces, to_single_output=True, nodes=(node,)) @property @@ -662,7 +653,7 @@ def op_nodes_reversed(self) -> Iterator[ExprNode]: def combine_metadata( - *args: IntoExpr | object | None, to_single_output: bool, nodes: tuple[ExprNode, ...] + *args: CompliantExprAny, to_single_output: bool, nodes: tuple[ExprNode, ...] ) -> ExprMetadata: """Combine metadata from `args`. @@ -686,8 +677,6 @@ def combine_metadata( result_is_literal = True for i, arg in enumerate(args): - if not is_compliant_expr(arg): - continue metadata = arg._metadata assert metadata is not None # noqa: S101 if metadata.expansion_kind.is_multi_output(): @@ -723,13 +712,13 @@ def combine_metadata( def check_expressions_preserve_length( - *args: CompliantExprAny | NonNestedLiteral, function_name: str + *args: CompliantExprAny, function_name: str ) -> None: # Raise if any argument in `args` isn't length-preserving. # For Series input, we don't raise (yet), we let such checks happen later, # as this function works lazily and so can't evaluate lengths. - if not all((is_compliant_expr(x) and x._metadata.preserves_length) for x in args): + if not all(x._metadata.preserves_length for x in args): msg = f"Expressions which aggregate or change length cannot be passed to '{function_name}'." raise InvalidOperationError(msg) @@ -766,11 +755,7 @@ def evaluate_into_exprs( ret = ns.evaluate_expr( _parse_into_expr(expr, str_as_lit=str_as_lit, backend=ns._implementation) ) - if ( - not allow_multi_output - and is_compliant_expr(ret) - and ret._metadata.expansion_kind.is_multi_output() - ): + if not allow_multi_output and ret._metadata.expansion_kind.is_multi_output(): msg = "Multi-output expressions are not allowed in this context." raise MultiOutputExpressionError(msg) yield ret @@ -780,11 +765,7 @@ def maybe_broadcast_ces(*ces: CompliantExprAny) -> list[CompliantExprAny]: broadcast = any(not is_scalar_like(ce) for ce in ces) results: list[CompliantExprAny] = [] for compliant_expr in ces: - if ( - broadcast - and is_compliant_expr(compliant_expr) - and is_scalar_like(compliant_expr) - ): + if broadcast and is_scalar_like(compliant_expr): _compliant_expr: CompliantExprAny = compliant_expr.broadcast() # Make sure to preserve metadata. _compliant_expr._opt_metadata = compliant_expr._metadata @@ -833,7 +814,6 @@ def evaluate_node( allow_multi_output=node.allow_multi_output, ), ) - assert is_compliant_expr(ce) # noqa: S101 md = md.with_node(node, ce, *ces) if "." in node.name: accessor, method = node.name.split(".") diff --git a/narwhals/_utils.py b/narwhals/_utils.py index 346008cc81..0405fe7d6f 100644 --- a/narwhals/_utils.py +++ b/narwhals/_utils.py @@ -65,14 +65,7 @@ TypeIs, ) - from narwhals._compliant import ( - CompliantExpr, - CompliantExprT, - CompliantFrameT, - CompliantSeriesOrNativeExprT_co, - CompliantSeriesT, - NativeSeriesT_co, - ) + from narwhals._compliant import CompliantExprT, CompliantSeriesT, NativeSeriesT_co from narwhals._compliant.any_namespace import NamespaceAccessor from narwhals._compliant.typing import ( Accessor, @@ -1591,12 +1584,6 @@ def is_compliant_series_int( return is_compliant_series(obj) and obj.dtype.is_integer() -def is_compliant_expr( - obj: CompliantExpr[CompliantFrameT, CompliantSeriesOrNativeExprT_co] | Any, -) -> TypeIs[CompliantExpr[CompliantFrameT, CompliantSeriesOrNativeExprT_co]]: - return hasattr(obj, "__narwhals_expr__") - - def _is_namespace_accessor(obj: _IntoContext) -> TypeIs[NamespaceAccessor[_FullContext]]: # NOTE: Only `compliant` has false positives **internally** # - https://github.com/narwhals-dev/narwhals/blob/cc69bac35eb8c81a1106969c49bfba9fd569b856/narwhals/_compliant/group_by.py#L44-L49 From 7d37394f6b178d6b9903b6b56b23e8b5672b9d31 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Tue, 30 Sep 2025 16:21:07 +0100 Subject: [PATCH 40/95] dask fixup --- narwhals/_dask/namespace.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/narwhals/_dask/namespace.py b/narwhals/_dask/namespace.py index f0b937e543..9084575836 100644 --- a/narwhals/_dask/namespace.py +++ b/narwhals/_dask/namespace.py @@ -1,6 +1,7 @@ from __future__ import annotations import operator +from datetime import date, datetime from functools import reduce from itertools import chain from typing import TYPE_CHECKING, cast @@ -58,6 +59,11 @@ def func(df: DaskLazyFrame) -> list[dx.Series]: if dtype is not None: native_dtype = narwhals_to_native_dtype(dtype, self._version) native_pd_series = pd.Series([value], dtype=native_dtype, name="literal") + elif isinstance(value, date) and not isinstance(value, datetime): + # Dask auto-infers this as object type, which causes issues down the line. + native_pd_series = pd.Series( + [value], dtype="date32[pyarrow]", name="literal" + ) else: native_pd_series = pd.Series([value], name="literal") npartitions = df._native_frame.npartitions From 1771c8e8ae9bc35e6881a4367ee9afc3f00851ed Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Tue, 30 Sep 2025 16:32:17 +0100 Subject: [PATCH 41/95] polars compat --- narwhals/_polars/expr.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/narwhals/_polars/expr.py b/narwhals/_polars/expr.py index 27705b3046..d85e0a6853 100644 --- a/narwhals/_polars/expr.py +++ b/narwhals/_polars/expr.py @@ -104,6 +104,17 @@ def cast(self, dtype: IntoDType) -> Self: dtype_pl = narwhals_to_native_dtype(dtype, self._version) return self._with_native(self.native.cast(dtype_pl)) + def clip(self, lower_bound: PolarsExpr, upper_bound: PolarsExpr) -> Self: + lower_native = extract_native(lower_bound) + upper_native = extract_native(upper_bound) + if self._backend_version < (1,): + # Work around a bug in old Polars versions. + if lower_bound._metadata.is_literal: + lower_native = pl.select(lower_native).item() + if upper_bound._metadata.is_literal: + upper_native = pl.select(upper_native).item() + return self._with_native(self.native.clip(lower_native, upper_native)) + def ewm_mean( self, *, @@ -307,7 +318,6 @@ def struct(self) -> PolarsExprStructNamespace: arg_max: Method[Self] arg_min: Method[Self] arg_true: Method[Self] - clip: Method[Self] count: Method[Self] cum_max: Method[Self] cum_min: Method[Self] From 8815b191fd0794c7ef7e81859be7c41f0c8a7b29 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Tue, 30 Sep 2025 16:42:39 +0100 Subject: [PATCH 42/95] split clip into clip_lower,clip_upper,clip --- narwhals/_compliant/column.py | 2 ++ narwhals/_compliant/expr.py | 6 ++++++ narwhals/_dask/expr.py | 32 ++++++++++++++------------- narwhals/_pandas_like/series.py | 38 +++++++++++++++++++++++++++++++++ narwhals/_polars/expr.py | 13 +++++------ narwhals/_sql/expr.py | 12 +++++++++++ narwhals/expr.py | 8 +++++++ 7 files changed, 88 insertions(+), 23 deletions(-) diff --git a/narwhals/_compliant/column.py b/narwhals/_compliant/column.py index fc21045ddc..e5b87e8869 100644 --- a/narwhals/_compliant/column.py +++ b/narwhals/_compliant/column.py @@ -61,6 +61,8 @@ def abs(self) -> Self: ... def alias(self, name: str) -> Self: ... def cast(self, dtype: IntoDType) -> Self: ... def clip(self, lower_bound: Self, upper_bound: Self) -> Self: ... + def clip_lower(self, lower_bound: Self) -> Self: ... + def clip_upper(self, upper_bound: Self) -> Self: ... def cum_count(self, *, reverse: bool) -> Self: ... def cum_max(self, *, reverse: bool) -> Self: ... def cum_min(self, *, reverse: bool) -> Self: ... diff --git a/narwhals/_compliant/expr.py b/narwhals/_compliant/expr.py index ff9439d7b5..1af4f1111c 100644 --- a/narwhals/_compliant/expr.py +++ b/narwhals/_compliant/expr.py @@ -551,6 +551,12 @@ def clip(self, lower_bound: Self, upper_bound: Self) -> Self: "clip", lower_bound=lower_bound, upper_bound=upper_bound ) + def clip_lower(self, lower_bound: Self) -> Self: + return self._reuse_series("clip_lower", lower_bound=lower_bound) + + def clip_upper(self, upper_bound: Self) -> Self: + return self._reuse_series("clip_upper", upper_bound=upper_bound) + def is_null(self) -> Self: return self._reuse_series("is_null") diff --git a/narwhals/_dask/expr.py b/narwhals/_dask/expr.py index 9876f985bf..f193a3d9c6 100644 --- a/narwhals/_dask/expr.py +++ b/narwhals/_dask/expr.py @@ -463,22 +463,24 @@ def func(expr: dx.Series) -> dx.Series: return self._with_callable(func) def clip(self, lower_bound: Self, upper_bound: Self) -> Self: - def func(df: DaskLazyFrame) -> list[dx.Series]: - lower_native: dx.Series = evaluate_expr(df, lower_bound) - upper_native: dx.Series = evaluate_expr(df, upper_bound) - if lower_native.dtype == "O": - # `lower_bound` was `lit(None)` - return [series.clip(upper=upper_native) for series in self(df)] - if upper_native.dtype == "O": - # `lower_bound` was `lit(None)` - return [series.clip(lower_native) for series in self(df)] - return [series.clip(lower_native, upper_native) for series in self(df)] + return self._with_callable( + lambda expr, lower_bound, upper_bound: expr.clip( + lower=lower_bound, upper=upper_bound + ), + lower_bound=lower_bound, + upper_bound=upper_bound, + ) - return self.__class__( - func, - evaluate_output_names=self._evaluate_output_names, - alias_output_names=self._alias_output_names, - version=self._version, + def clip_lower(self, lower_bound: Self) -> Self: + return self._with_callable( + lambda expr, lower_bound: expr.clip(lower=lower_bound), + lower_bound=lower_bound, + ) + + def clip_upper(self, upper_bound: Self) -> Self: + return self._with_callable( + lambda expr, upper_bound: expr.clip(upper=upper_bound), + upper_bound=upper_bound, ) def diff(self) -> Self: diff --git a/narwhals/_pandas_like/series.py b/narwhals/_pandas_like/series.py index 8b21532481..deb8ffb453 100644 --- a/narwhals/_pandas_like/series.py +++ b/narwhals/_pandas_like/series.py @@ -849,6 +849,44 @@ def clip(self, lower_bound: Self, upper_bound: Self) -> Self: return self._with_native(result.clip(lower, upper, **kwargs)) + def clip_lower(self, lower_bound: Self) -> Self: + _, lower = ( + align_and_extract_native(self, lower_bound) + if lower_bound is not None + else (None, None) + ) + impl = self._implementation + kwargs: dict[str, Any] = {"axis": 0} if impl.is_modin() else {} + result = self.native + + if not impl.is_pandas() and self._is_native(lower): + # Workaround for both cudf and modin when clipping with a series + # * cudf: https://github.com/rapidsai/cudf/issues/17682 + # * modin: https://github.com/modin-project/modin/issues/7415 + result = result.where(result >= lower, lower) + lower = None + + return self._with_native(result.clip(lower, **kwargs)) + + def clip_upper(self, upper_bound: Self) -> Self: + _, upper = ( + align_and_extract_native(self, upper_bound) + if upper_bound is not None + else (None, None) + ) + impl = self._implementation + kwargs: dict[str, Any] = {"axis": 0} if impl.is_modin() else {} + result = self.native + + if not impl.is_pandas() and self._is_native(upper): + # Workaround for both cudf and modin when clipping with a series + # * cudf: https://github.com/rapidsai/cudf/issues/17682 + # * modin: https://github.com/modin-project/modin/issues/7415 + result = result.where(result <= upper, upper) + upper = None + + return self._with_native(result.clip(upper=upper, **kwargs)) + def to_arrow(self) -> pa.Array[Any]: if self._implementation is Implementation.CUDF: return self.native.to_arrow() diff --git a/narwhals/_polars/expr.py b/narwhals/_polars/expr.py index d85e0a6853..de4637840e 100644 --- a/narwhals/_polars/expr.py +++ b/narwhals/_polars/expr.py @@ -104,16 +104,13 @@ def cast(self, dtype: IntoDType) -> Self: dtype_pl = narwhals_to_native_dtype(dtype, self._version) return self._with_native(self.native.cast(dtype_pl)) - def clip(self, lower_bound: PolarsExpr, upper_bound: PolarsExpr) -> Self: + def clip_lower(self, lower_bound: PolarsExpr) -> Self: lower_native = extract_native(lower_bound) + return self._with_native(self.native.clip(lower_native)) + + def clip_upper(self, upper_bound: PolarsExpr) -> Self: upper_native = extract_native(upper_bound) - if self._backend_version < (1,): - # Work around a bug in old Polars versions. - if lower_bound._metadata.is_literal: - lower_native = pl.select(lower_native).item() - if upper_bound._metadata.is_literal: - upper_native = pl.select(upper_native).item() - return self._with_native(self.native.clip(lower_native, upper_native)) + return self._with_native(self.native.clip(None, upper_native)) def ewm_mean( self, diff --git a/narwhals/_sql/expr.py b/narwhals/_sql/expr.py index 032c7da2aa..946a39fa9e 100644 --- a/narwhals/_sql/expr.py +++ b/narwhals/_sql/expr.py @@ -524,6 +524,18 @@ def _clip( _clip, lower_bound=lower_bound, upper_bound=upper_bound ) + def clip_lower(self, lower_bound: Self) -> Self: + def _clip(expr: NativeExprT, lower_bound: NativeExprT) -> NativeExprT: + return self._function("greatest", expr, lower_bound) + + return self._with_elementwise(_clip, lower_bound=lower_bound) + + def clip_upper(self, upper_bound: Self) -> Self: + def _clip(expr: NativeExprT, upper_bound: NativeExprT) -> NativeExprT: + return self._function("least", expr, upper_bound) + + return self._with_elementwise(_clip, upper_bound=upper_bound) + def is_null(self) -> Self: return self._with_elementwise(lambda expr: self._function("isnull", expr)) diff --git a/narwhals/expr.py b/narwhals/expr.py index 65ebf2dad4..29091c8c7e 100644 --- a/narwhals/expr.py +++ b/narwhals/expr.py @@ -1523,6 +1523,14 @@ def clip( | 2 3 3 | └──────────────────┘ """ + if upper_bound is None: + return self._with_node( + ExprNode(ExprKind.ELEMENTWISE, "clip_lower", lower_bound) + ) + if lower_bound is None: + return self._with_node( + ExprNode(ExprKind.ELEMENTWISE, "clip_upper", upper_bound) + ) return self._with_node( ExprNode(ExprKind.ELEMENTWISE, "clip", lower_bound, upper_bound) ) From 3ed89a91dc911da86fabfeee9b8c6aaf69a91fa2 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Tue, 30 Sep 2025 16:52:02 +0100 Subject: [PATCH 43/95] complete the split --- narwhals/_arrow/series.py | 21 ++++++++++----------- narwhals/_polars/series.py | 13 +++++++++++++ narwhals/series.py | 12 ++++++++++++ 3 files changed, 35 insertions(+), 11 deletions(-) diff --git a/narwhals/_arrow/series.py b/narwhals/_arrow/series.py index 0634c94cfc..f9a45848ea 100644 --- a/narwhals/_arrow/series.py +++ b/narwhals/_arrow/series.py @@ -834,21 +834,20 @@ def gather_every(self, n: int, offset: int = 0) -> Self: return self._with_native(self.native[offset::n]) def clip(self, lower_bound: Self, upper_bound: Self) -> Self: - _, lower = ( - extract_native(self, lower_bound) if lower_bound is not None else (None, None) - ) - _, upper = ( - extract_native(self, upper_bound) if upper_bound is not None else (None, None) - ) - - if lower is None or isinstance(lower, pa.NullScalar): - return self._with_native(pc.min_element_wise(self.native, upper)) - if upper is None or isinstance(upper, pa.NullScalar): - return self._with_native(pc.max_element_wise(self.native, lower)) + _, lower = extract_native(self, lower_bound) + _, upper = extract_native(self, upper_bound) return self._with_native( pc.max_element_wise(pc.min_element_wise(self.native, upper), lower) ) + def clip_lower(self, lower_bound: Self) -> Self: + _, lower = extract_native(self, lower_bound) + return self._with_native(pc.max_element_wise(self.native, lower)) + + def clip_upper(self, upper_bound: Self) -> Self: + _, upper = extract_native(self, upper_bound) + return self._with_native(pc.min_element_wise(self.native, upper)) + def to_arrow(self) -> ArrayAny: return self.native.combine_chunks() diff --git a/narwhals/_polars/series.py b/narwhals/_polars/series.py index b45172d108..966570ce54 100644 --- a/narwhals/_polars/series.py +++ b/narwhals/_polars/series.py @@ -285,6 +285,19 @@ def cast(self, dtype: IntoDType) -> Self: dtype_pl = narwhals_to_native_dtype(dtype, self._version) return self._with_native(self.native.cast(dtype_pl)) + def clip(self, lower_bound: PolarsSeries, upper_bound: PolarsSeries) -> Self: + return self._with_native( + self.native.clip(extract_native(lower_bound), extract_native(upper_bound)) + ) + + def clip_lower(self, lower_bound: PolarsSeries) -> Self: + return self._with_native(self.native.clip(extract_native(lower_bound))) + + def clip_upper(self, upper_bound: PolarsSeries) -> Self: + return self._with_native( + self.native.clip(upper_bound=extract_native(upper_bound)) + ) + @requires.backend_version((1,)) def replace_strict( self, diff --git a/narwhals/series.py b/narwhals/series.py index 6e749bffd8..463a395c2a 100644 --- a/narwhals/series.py +++ b/narwhals/series.py @@ -899,6 +899,18 @@ def clip( 5 3 dtype: int64 """ + if lower_bound is None: + return self._with_compliant( + self._compliant_series.clip_upper( + upper_bound=self._extract_native(upper_bound) + ) + ) + if upper_bound is None: + return self._with_compliant( + self._compliant_series.clip_lower( + lower_bound=self._extract_native(lower_bound) + ) + ) return self._with_compliant( self._compliant_series.clip( lower_bound=self._extract_native(lower_bound), From a7a3bb2025afdc8ff3bb4bf9868d65df4bced82d Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Tue, 30 Sep 2025 16:57:03 +0100 Subject: [PATCH 44/95] :art: --- narwhals/_dask/namespace.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/narwhals/_dask/namespace.py b/narwhals/_dask/namespace.py index 9084575836..df1ea0ac40 100644 --- a/narwhals/_dask/namespace.py +++ b/narwhals/_dask/namespace.py @@ -61,9 +61,8 @@ def func(df: DaskLazyFrame) -> list[dx.Series]: native_pd_series = pd.Series([value], dtype=native_dtype, name="literal") elif isinstance(value, date) and not isinstance(value, datetime): # Dask auto-infers this as object type, which causes issues down the line. - native_pd_series = pd.Series( - [value], dtype="date32[pyarrow]", name="literal" - ) + native_dtype = "date32[pyarrow]" + native_pd_series = pd.Series([value], dtype=native_dtype, name="literal") else: native_pd_series = pd.Series([value], name="literal") npartitions = df._native_frame.npartitions From e982fd0a1c537f506a0c5c3aecfd3124f68506b3 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Tue, 30 Sep 2025 16:59:48 +0100 Subject: [PATCH 45/95] typing --- narwhals/_polars/series.py | 1 - 1 file changed, 1 deletion(-) diff --git a/narwhals/_polars/series.py b/narwhals/_polars/series.py index 966570ce54..a9bdbe1209 100644 --- a/narwhals/_polars/series.py +++ b/narwhals/_polars/series.py @@ -676,7 +676,6 @@ def struct(self) -> PolarsSeriesStructNamespace: arg_max: Method[int] arg_min: Method[int] arg_true: Method[Self] - clip: Method[Self] count: Method[int] cum_max: Method[Self] cum_min: Method[Self] From 90a33020f6bff9012a0dc123de602d6e2057d580 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Tue, 30 Sep 2025 17:01:45 +0100 Subject: [PATCH 46/95] skip old dask for fill_null --- tests/expr_and_series/fill_null_test.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/expr_and_series/fill_null_test.py b/tests/expr_and_series/fill_null_test.py index 014e92ccfb..e6bd003344 100644 --- a/tests/expr_and_series/fill_null_test.py +++ b/tests/expr_and_series/fill_null_test.py @@ -8,6 +8,7 @@ import narwhals as nw from tests.utils import ( + DASK_VERSION, DUCKDB_VERSION, POLARS_VERSION, Constructor, @@ -17,6 +18,9 @@ def test_fill_null(constructor: Constructor) -> None: + if "dask" in str(constructor) and DASK_VERSION <= (2024, 10): + # Bug in old version of Dask. + pytest.skip() data = { "a": [0.0, None, 2.0, 3.0, 4.0], "b": [1.0, None, None, 5.0, 3.0], From e8fa7f1845de7ab8fe858122ac2ee378ba581d1b Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Tue, 30 Sep 2025 17:11:38 +0100 Subject: [PATCH 47/95] coverage, simplify --- narwhals/_dask/namespace.py | 5 ++++- narwhals/_pandas_like/series.py | 28 ++++++------------------- tests/expr_and_series/fill_null_test.py | 6 +++--- 3 files changed, 13 insertions(+), 26 deletions(-) diff --git a/narwhals/_dask/namespace.py b/narwhals/_dask/namespace.py index df1ea0ac40..5453214a89 100644 --- a/narwhals/_dask/namespace.py +++ b/narwhals/_dask/namespace.py @@ -59,8 +59,11 @@ def func(df: DaskLazyFrame) -> list[dx.Series]: if dtype is not None: native_dtype = narwhals_to_native_dtype(dtype, self._version) native_pd_series = pd.Series([value], dtype=native_dtype, name="literal") - elif isinstance(value, date) and not isinstance(value, datetime): + elif isinstance(value, date) and not isinstance( + value, datetime + ): # pragma: no cover # Dask auto-infers this as object type, which causes issues down the line. + # This shows up in TPC-H q8. native_dtype = "date32[pyarrow]" native_pd_series = pd.Series([value], dtype=native_dtype, name="literal") else: diff --git a/narwhals/_pandas_like/series.py b/narwhals/_pandas_like/series.py index deb8ffb453..b308b4be79 100644 --- a/narwhals/_pandas_like/series.py +++ b/narwhals/_pandas_like/series.py @@ -822,16 +822,8 @@ def gather_every(self, n: int, offset: int) -> Self: return self._with_native(self.native.iloc[offset::n]) def clip(self, lower_bound: Self, upper_bound: Self) -> Self: - _, lower = ( - align_and_extract_native(self, lower_bound) - if lower_bound is not None - else (None, None) - ) - _, upper = ( - align_and_extract_native(self, upper_bound) - if upper_bound is not None - else (None, None) - ) + _, lower = align_and_extract_native(self, lower_bound) + _, upper = align_and_extract_native(self, upper_bound) impl = self._implementation kwargs: dict[str, Any] = {"axis": 0} if impl.is_modin() else {} result = self.native @@ -850,16 +842,12 @@ def clip(self, lower_bound: Self, upper_bound: Self) -> Self: return self._with_native(result.clip(lower, upper, **kwargs)) def clip_lower(self, lower_bound: Self) -> Self: - _, lower = ( - align_and_extract_native(self, lower_bound) - if lower_bound is not None - else (None, None) - ) + _, lower = align_and_extract_native(self, lower_bound) impl = self._implementation kwargs: dict[str, Any] = {"axis": 0} if impl.is_modin() else {} result = self.native - if not impl.is_pandas() and self._is_native(lower): + if not impl.is_pandas() and self._is_native(lower): # pragma: no cover # Workaround for both cudf and modin when clipping with a series # * cudf: https://github.com/rapidsai/cudf/issues/17682 # * modin: https://github.com/modin-project/modin/issues/7415 @@ -869,16 +857,12 @@ def clip_lower(self, lower_bound: Self) -> Self: return self._with_native(result.clip(lower, **kwargs)) def clip_upper(self, upper_bound: Self) -> Self: - _, upper = ( - align_and_extract_native(self, upper_bound) - if upper_bound is not None - else (None, None) - ) + _, upper = align_and_extract_native(self, upper_bound) impl = self._implementation kwargs: dict[str, Any] = {"axis": 0} if impl.is_modin() else {} result = self.native - if not impl.is_pandas() and self._is_native(upper): + if not impl.is_pandas() and self._is_native(upper): # pragma: no cover # Workaround for both cudf and modin when clipping with a series # * cudf: https://github.com/rapidsai/cudf/issues/17682 # * modin: https://github.com/modin-project/modin/issues/7415 diff --git a/tests/expr_and_series/fill_null_test.py b/tests/expr_and_series/fill_null_test.py index e6bd003344..d5f279e30e 100644 --- a/tests/expr_and_series/fill_null_test.py +++ b/tests/expr_and_series/fill_null_test.py @@ -18,9 +18,6 @@ def test_fill_null(constructor: Constructor) -> None: - if "dask" in str(constructor) and DASK_VERSION <= (2024, 10): - # Bug in old version of Dask. - pytest.skip() data = { "a": [0.0, None, 2.0, 3.0, 4.0], "b": [1.0, None, None, 5.0, 3.0], @@ -38,6 +35,9 @@ def test_fill_null(constructor: Constructor) -> None: def test_fill_null_w_aggregate(constructor: Constructor) -> None: + if "dask" in str(constructor) and DASK_VERSION <= (2024, 10): + # Bug in old version of Dask. + pytest.skip() data = {"a": [0.5, None, 2.0, 3.0, 4.5], "b": ["xx", "yy", "zz", None, "yy"]} df = nw.from_native(constructor(data)) From 7a238762dda0e5f1656db73f261b6d47a9fbeda5 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Tue, 30 Sep 2025 17:32:01 +0100 Subject: [PATCH 48/95] old dask --- tests/expr_and_series/fill_null_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/expr_and_series/fill_null_test.py b/tests/expr_and_series/fill_null_test.py index d5f279e30e..a9cda707cf 100644 --- a/tests/expr_and_series/fill_null_test.py +++ b/tests/expr_and_series/fill_null_test.py @@ -35,7 +35,7 @@ def test_fill_null(constructor: Constructor) -> None: def test_fill_null_w_aggregate(constructor: Constructor) -> None: - if "dask" in str(constructor) and DASK_VERSION <= (2024, 10): + if "dask" in str(constructor) and DASK_VERSION < (2024, 12): # Bug in old version of Dask. pytest.skip() data = {"a": [0.5, None, 2.0, 3.0, 4.5], "b": ["xx", "yy", "zz", None, "yy"]} From b37dc66428f963e7087a162f00e13b666c163143 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Tue, 30 Sep 2025 17:37:47 +0100 Subject: [PATCH 49/95] it gets simpler --- narwhals/_polars/expr.py | 8 ++++---- narwhals/_polars/series.py | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/narwhals/_polars/expr.py b/narwhals/_polars/expr.py index de4637840e..7861ee0681 100644 --- a/narwhals/_polars/expr.py +++ b/narwhals/_polars/expr.py @@ -407,17 +407,17 @@ def zfill(self, width: int) -> PolarsExpr: return self.compliant._with_native(native_result) def replace( - self, value: PolarsExpr | str, pattern: str, *, literal: bool, n: int + self, value: PolarsExpr, pattern: str, *, literal: bool, n: int ) -> PolarsExpr: - value_native = value if isinstance(value, str) else extract_native(value) + value_native = extract_native(value) return self.compliant._with_native( self.native.str.replace(pattern, value_native, literal=literal, n=n) ) def replace_all( - self, value: PolarsExpr | str, pattern: str, *, literal: bool + self, value: PolarsExpr, pattern: str, *, literal: bool ) -> PolarsExpr: - value_native = value if isinstance(value, str) else extract_native(value) + value_native = extract_native(value) return self.compliant._with_native( self.native.str.replace_all(pattern, value_native, literal=literal) ) diff --git a/narwhals/_polars/series.py b/narwhals/_polars/series.py index a9bdbe1209..d2840e55e7 100644 --- a/narwhals/_polars/series.py +++ b/narwhals/_polars/series.py @@ -768,17 +768,17 @@ def zfill(self, width: int) -> PolarsSeries: return self.to_frame().select(ns.col(name).str.zfill(width)).get_column(name) def replace( - self, value: PolarsSeries | str, pattern: str, *, literal: bool, n: int + self, value: PolarsSeries, pattern: str, *, literal: bool, n: int ) -> PolarsSeries: - value_native = value if isinstance(value, str) else extract_native(value) + value_native = extract_native(value) return self.compliant._with_native( self.native.str.replace(pattern, value_native, literal=literal, n=n) # type: ignore[arg-type] ) def replace_all( - self, value: PolarsSeries | str, pattern: str, *, literal: bool + self, value: PolarsSeries, pattern: str, *, literal: bool ) -> PolarsSeries: - value_native = value if isinstance(value, str) else extract_native(value) + value_native = extract_native(value) return self.compliant._with_native( self.native.str.replace_all(pattern, value_native, literal=literal) # type: ignore[arg-type] ) From c2089ded6ce4397f358a2b77daf1a8a4526a8f7a Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Tue, 30 Sep 2025 17:43:47 +0100 Subject: [PATCH 50/95] remove even more chaff --- narwhals/_dask/namespace.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/narwhals/_dask/namespace.py b/narwhals/_dask/namespace.py index 5453214a89..1f39f74fd4 100644 --- a/narwhals/_dask/namespace.py +++ b/narwhals/_dask/namespace.py @@ -273,10 +273,8 @@ def when_then( self, predicate: DaskExpr, then: DaskExpr, otherwise: DaskExpr | None = None ) -> DaskExpr: def func(df: DaskLazyFrame) -> list[dx.Series]: - then_value = then(df)[0] if isinstance(then, DaskExpr) else then - otherwise_value = ( - otherwise(df)[0] if isinstance(otherwise, DaskExpr) else otherwise - ) + then_value = then(df)[0] + otherwise_value = otherwise(df)[0] if otherwise is not None else otherwise condition = predicate(df)[0] # re-evaluate DataFrame if the condition aggregates to force @@ -285,12 +283,8 @@ def func(df: DaskLazyFrame) -> list[dx.Series]: x._metadata.is_scalar_like for x in ( (predicate, then) - if (isinstance(then, DaskExpr) and otherwise is None) + if otherwise is None else (predicate, then, otherwise) - if isinstance(then, DaskExpr) and isinstance(otherwise, DaskExpr) - else (predicate, otherwise) - if isinstance(otherwise, DaskExpr) - else (predicate,) ) ): new_df = df._with_native(condition.to_frame()) From ed800885db89a7a695eccc73eb9aa20636c1f8ec Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Tue, 30 Sep 2025 18:10:35 +0100 Subject: [PATCH 51/95] remove several Any --- narwhals/_arrow/dataframe.py | 3 ++- narwhals/_compliant/expr.py | 49 +++++++++++++++++------------------- narwhals/_dask/expr.py | 12 +++------ narwhals/_duckdb/expr.py | 12 ++------- narwhals/_ibis/expr.py | 6 ++--- narwhals/_polars/expr.py | 4 +-- narwhals/_spark_like/expr.py | 15 +++++------ narwhals/_sql/expr.py | 10 ++++---- 8 files changed, 47 insertions(+), 64 deletions(-) diff --git a/narwhals/_arrow/dataframe.py b/narwhals/_arrow/dataframe.py index 7e31e1088d..37bde19aec 100644 --- a/narwhals/_arrow/dataframe.py +++ b/narwhals/_arrow/dataframe.py @@ -512,7 +512,8 @@ def with_row_index(self, name: str, order_by: Sequence[str] | None) -> Self: ) else: rank = plx.col(order_by[0]).rank("ordinal", descending=False) - row_index = (rank.over(partition_by=[], order_by=order_by) - 1).alias(name) + one = plx.lit(1, None).broadcast() + row_index = (rank.over(partition_by=[], order_by=order_by) - one).alias(name) return self.select(row_index, plx.all()) def filter(self, predicate: ArrowExpr | list[bool | None]) -> Self: diff --git a/narwhals/_compliant/expr.py b/narwhals/_compliant/expr.py index 1af4f1111c..e029e3d582 100644 --- a/narwhals/_compliant/expr.py +++ b/narwhals/_compliant/expr.py @@ -425,70 +425,70 @@ def func(df: EagerDataFrameT) -> list[EagerSeriesT]: def cast(self, dtype: IntoDType) -> Self: return self._reuse_series("cast", dtype=dtype) - def _with_binary(self, operator: str, other: Self | Any, /) -> Self: + def _with_binary(self, operator: str, other: Self, /) -> Self: return self._reuse_series(operator, other=other) - def _with_binary_right(self, operator: str, other: Self | Any, /) -> Self: + def _with_binary_right(self, operator: str, other: Self, /) -> Self: return self.alias("literal")._reuse_series(operator, other=other) - def __eq__(self, other: Self | Any) -> Self: # type: ignore[override] + def __eq__(self, other: Self) -> Self: # type: ignore[override] return self._with_binary("__eq__", other) - def __ne__(self, other: Self | Any) -> Self: # type: ignore[override] + def __ne__(self, other: Self) -> Self: # type: ignore[override] return self._with_binary("__ne__", other) - def __ge__(self, other: Self | Any) -> Self: + def __ge__(self, other: Self) -> Self: return self._with_binary("__ge__", other) - def __gt__(self, other: Self | Any) -> Self: + def __gt__(self, other: Self) -> Self: return self._with_binary("__gt__", other) - def __le__(self, other: Self | Any) -> Self: + def __le__(self, other: Self) -> Self: return self._with_binary("__le__", other) - def __lt__(self, other: Self | Any) -> Self: + def __lt__(self, other: Self) -> Self: return self._with_binary("__lt__", other) - def __and__(self, other: Self | bool | Any) -> Self: + def __and__(self, other: Self) -> Self: return self._with_binary("__and__", other) - def __or__(self, other: Self | bool | Any) -> Self: + def __or__(self, other: Self) -> Self: return self._with_binary("__or__", other) - def __add__(self, other: Self | Any) -> Self: + def __add__(self, other: Self) -> Self: return self._with_binary("__add__", other) - def __sub__(self, other: Self | Any) -> Self: + def __sub__(self, other: Self) -> Self: return self._with_binary("__sub__", other) - def __rsub__(self, other: Self | Any) -> Self: + def __rsub__(self, other: Self) -> Self: return self._with_binary_right("__rsub__", other) - def __mul__(self, other: Self | Any) -> Self: + def __mul__(self, other: Self) -> Self: return self._with_binary("__mul__", other) - def __truediv__(self, other: Self | Any) -> Self: + def __truediv__(self, other: Self) -> Self: return self._with_binary("__truediv__", other) - def __rtruediv__(self, other: Self | Any) -> Self: + def __rtruediv__(self, other: Self) -> Self: return self._with_binary_right("__rtruediv__", other) - def __floordiv__(self, other: Self | Any) -> Self: + def __floordiv__(self, other: Self) -> Self: return self._with_binary("__floordiv__", other) - def __rfloordiv__(self, other: Self | Any) -> Self: + def __rfloordiv__(self, other: Self) -> Self: return self._with_binary_right("__rfloordiv__", other) - def __pow__(self, other: Self | Any) -> Self: + def __pow__(self, other: Self) -> Self: return self._with_binary("__pow__", other) - def __rpow__(self, other: Self | Any) -> Self: + def __rpow__(self, other: Self) -> Self: return self._with_binary_right("__rpow__", other) - def __mod__(self, other: Self | Any) -> Self: + def __mod__(self, other: Self) -> Self: return self._with_binary("__mod__", other) - def __rmod__(self, other: Self | Any) -> Self: + def __rmod__(self, other: Self) -> Self: return self._with_binary_right("__rmod__", other) # Unary @@ -567,10 +567,7 @@ def fill_nan(self, value: float | None) -> Self: return self._reuse_series("fill_nan", value=value) def fill_null( - self, - value: Self | NonNestedLiteral, - strategy: FillNullStrategy | None, - limit: int | None, + self, value: Self | Any, strategy: FillNullStrategy | None, limit: int | None ) -> Self: return self._reuse_series( "fill_null", value=value, strategy=strategy, limit=limit diff --git a/narwhals/_dask/expr.py b/narwhals/_dask/expr.py index f193a3d9c6..fc9b5bef50 100644 --- a/narwhals/_dask/expr.py +++ b/narwhals/_dask/expr.py @@ -43,7 +43,6 @@ FillNullStrategy, IntoDType, ModeKeepStrategy, - NonNestedLiteral, RollingInterpolationMethod, ) @@ -131,7 +130,7 @@ def _with_callable( # First argument to `call` should be `dx.Series` call: Callable[..., dx.Series], /, - **expressifiable_args: Self | Any, + **expressifiable_args: Self, ) -> Self: def func(df: DaskLazyFrame) -> list[dx.Series]: native_results: list[dx.Series] = [] @@ -207,7 +206,7 @@ def __truediv__(self, other: Any) -> Self: def __floordiv__(self, other: Any) -> Self: def _floordiv( - df: DaskLazyFrame, series: dx.Series, other: dx.Series | Any + df: DaskLazyFrame, series: dx.Series, other: dx.Series ) -> dx.Series: series, other = align_series_full_broadcast(df, series, other) return (series.__floordiv__(other)).where(other != 0, None) @@ -261,7 +260,7 @@ def __rtruediv__(self, other: Any) -> Self: def __rfloordiv__(self, other: Any) -> Self: def _rfloordiv( - df: DaskLazyFrame, series: dx.Series, other: dx.Series | Any + df: DaskLazyFrame, series: dx.Series, other: dx.Series ) -> dx.Series: series, other = align_series_full_broadcast(df, series, other) return (other.__floordiv__(series)).where(series != 0, None) @@ -444,10 +443,7 @@ def func(expr: dx.Series) -> dx.Series: return self._with_callable(func) def fill_null( - self, - value: Self | NonNestedLiteral, - strategy: FillNullStrategy | None, - limit: int | None, + self, value: Self | Any, strategy: FillNullStrategy | None, limit: int | None ) -> Self: def func(expr: dx.Series) -> dx.Series: if value is not None: diff --git a/narwhals/_duckdb/expr.py b/narwhals/_duckdb/expr.py index e5647f63dd..955ae2c992 100644 --- a/narwhals/_duckdb/expr.py +++ b/narwhals/_duckdb/expr.py @@ -37,12 +37,7 @@ from narwhals._duckdb.dataframe import DuckDBLazyFrame from narwhals._duckdb.namespace import DuckDBNamespace from narwhals._utils import _LimitedContext - from narwhals.typing import ( - FillNullStrategy, - IntoDType, - NonNestedLiteral, - RollingInterpolationMethod, - ) + from narwhals.typing import FillNullStrategy, IntoDType, RollingInterpolationMethod DuckDBWindowFunction = WindowFunction[DuckDBLazyFrame, Expression] DuckDBWindowInputs = WindowInputs[Expression] @@ -234,10 +229,7 @@ def is_in(self, other: Sequence[Any]) -> Self: return self._with_elementwise(lambda expr: F("contains", lit(other), expr)) def fill_null( - self, - value: Self | NonNestedLiteral, - strategy: FillNullStrategy | None, - limit: int | None, + self, value: Self | Any, strategy: FillNullStrategy | None, limit: int | None ) -> Self: if strategy is not None: if self._backend_version < (1, 3): # pragma: no cover diff --git a/narwhals/_ibis/expr.py b/narwhals/_ibis/expr.py index 8f11a30bae..9143cabc79 100644 --- a/narwhals/_ibis/expr.py +++ b/narwhals/_ibis/expr.py @@ -167,11 +167,11 @@ def func(df: IbisLazyFrame) -> Sequence[ir.Column]: version=context._version, ) - def _with_binary(self, op: Callable[..., ir.Value], other: Self | Any) -> Self: + def _with_binary(self, op: Callable[..., ir.Value], other: Self) -> Self: return self._with_callable(op, other=other) def _with_elementwise( - self, op: Callable[..., ir.Value], /, **expressifiable_args: Self | Any + self, op: Callable[..., ir.Value], /, **expressifiable_args: Self ) -> Self: return self._with_callable(op, **expressifiable_args) @@ -243,7 +243,7 @@ def null_count(self) -> Self: return self._with_callable(lambda expr: expr.isnull().sum()) def is_nan(self) -> Self: - def func(expr: ir.FloatingValue | Any) -> ir.Value: + def func(expr: ir.FloatingValue) -> ir.Value: otherwise = expr.isnan() if is_floating(expr.type()) else False return ibis.ifelse(expr.isnull(), None, otherwise) diff --git a/narwhals/_polars/expr.py b/narwhals/_polars/expr.py index 7861ee0681..05de269740 100644 --- a/narwhals/_polars/expr.py +++ b/narwhals/_polars/expr.py @@ -238,10 +238,10 @@ def __le__(self, other: Any) -> Self: def __lt__(self, other: Any) -> Self: return self._with_native(self.native.__lt__(extract_native(other))) - def __and__(self, other: PolarsExpr | bool | Any) -> Self: + def __and__(self, other: PolarsExpr) -> Self: return self._with_native(self.native.__and__(extract_native(other))) # type: ignore[operator] - def __or__(self, other: PolarsExpr | bool | Any) -> Self: + def __or__(self, other: PolarsExpr) -> Self: return self._with_native(self.native.__or__(extract_native(other))) # type: ignore[operator] def __add__(self, other: Any) -> Self: diff --git a/narwhals/_spark_like/expr.py b/narwhals/_spark_like/expr.py index 6ae333b03d..ff3cdb39d2 100644 --- a/narwhals/_spark_like/expr.py +++ b/narwhals/_spark_like/expr.py @@ -34,7 +34,7 @@ from narwhals._spark_like.dataframe import SparkLikeLazyFrame from narwhals._spark_like.namespace import SparkLikeNamespace from narwhals._utils import _LimitedContext - from narwhals.typing import FillNullStrategy, IntoDType, NonNestedLiteral, RankMethod + from narwhals.typing import FillNullStrategy, IntoDType, RankMethod NativeRankMethod: TypeAlias = Literal["rank", "dense_rank", "row_number"] SparkWindowFunction = WindowFunction[SparkLikeLazyFrame, Column] @@ -189,19 +189,19 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: implementation=context._implementation, ) - def __truediv__(self, other: SparkLikeExpr) -> Self: + def __truediv__(self, other: Self) -> Self: def _truediv(expr: Column, other: Column) -> Column: return true_divide(self._F, expr, other) return self._with_binary(_truediv, other) - def __rtruediv__(self, other: SparkLikeExpr) -> Self: + def __rtruediv__(self, other: Self) -> Self: def _rtruediv(expr: Column, other: Column) -> Column: return true_divide(self._F, other, expr) return self._with_binary(_rtruediv, other).alias("literal") - def __floordiv__(self, other: SparkLikeExpr) -> Self: + def __floordiv__(self, other: Self) -> Self: def _floordiv(expr: Column, other: Column) -> Column: F = self._F return F.when( @@ -210,7 +210,7 @@ def _floordiv(expr: Column, other: Column) -> Column: return self._with_binary(_floordiv, other) - def __rfloordiv__(self, other: SparkLikeExpr) -> Self: + def __rfloordiv__(self, other: Self) -> Self: def _rfloordiv(expr: Column, other: Column) -> Column: F = self._F return F.when( @@ -333,10 +333,7 @@ def _is_nan(expr: Column) -> Column: return self._with_elementwise(_is_nan) def fill_null( - self, - value: Self | NonNestedLiteral, - strategy: FillNullStrategy | None, - limit: int | None, + self, value: Self | Any, strategy: FillNullStrategy | None, limit: int | None ) -> Self: if strategy is not None: diff --git a/narwhals/_sql/expr.py b/narwhals/_sql/expr.py index 946a39fa9e..89e1bcd3ec 100644 --- a/narwhals/_sql/expr.py +++ b/narwhals/_sql/expr.py @@ -58,7 +58,7 @@ def __narwhals_namespace__( ) -> SQLNamespace[SQLLazyFrameT, Self, Any, NativeExprT]: ... def _callable_to_eval_series( - self, call: Callable[..., NativeExprT], /, **expressifiable_args: Self | Any + self, call: Callable[..., NativeExprT], /, **expressifiable_args: Self ) -> EvalSeries[SQLLazyFrameT, NativeExprT]: def func(df: SQLLazyFrameT) -> list[NativeExprT]: native_series_list = self(df) @@ -74,7 +74,7 @@ def func(df: SQLLazyFrameT) -> list[NativeExprT]: return func def _push_down_window_function( - self, call: Callable[..., NativeExprT], /, **expressifiable_args: Self | Any + self, call: Callable[..., NativeExprT], /, **expressifiable_args: Self ) -> WindowFunction[SQLLazyFrameT, NativeExprT]: def window_f( df: SQLLazyFrameT, window_inputs: WindowInputs[NativeExprT] @@ -113,7 +113,7 @@ def _with_callable( call: Callable[..., NativeExprT], window_func: WindowFunction[SQLLazyFrameT, NativeExprT] | None = None, /, - **expressifiable_args: Self | Any, + **expressifiable_args: Self, ) -> Self: return self.__class__( self._callable_to_eval_series(call, **expressifiable_args), @@ -125,7 +125,7 @@ def _with_callable( ) def _with_elementwise( - self, call: Callable[..., NativeExprT], /, **expressifiable_args: Self | Any + self, call: Callable[..., NativeExprT], /, **expressifiable_args: Self ) -> Self: return self.__class__( self._callable_to_eval_series(call, **expressifiable_args), @@ -136,7 +136,7 @@ def _with_elementwise( implementation=self._implementation, ) - def _with_binary(self, op: Callable[..., NativeExprT], other: Self | Any) -> Self: + def _with_binary(self, op: Callable[..., NativeExprT], other: Self) -> Self: return self.__class__( self._callable_to_eval_series(op, other=other), self._push_down_window_function(op, other=other), From 74985f4a158eba298ef793e99303f8bf94d01707 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Tue, 30 Sep 2025 18:15:31 +0100 Subject: [PATCH 52/95] typing --- narwhals/_polars/expr.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/narwhals/_polars/expr.py b/narwhals/_polars/expr.py index 05de269740..ea95206529 100644 --- a/narwhals/_polars/expr.py +++ b/narwhals/_polars/expr.py @@ -220,11 +220,11 @@ def replace_strict( native = self.native.replace_strict(old, new, return_dtype=return_dtype_pl) return self._with_native(native) - def __eq__(self, other: object) -> Self: # type: ignore[override] - return self._with_native(self.native.__eq__(extract_native(other))) # type: ignore[operator] + def __eq__(self, other: PolarsExpr) -> Self: # type: ignore[override] + return self._with_native(self.native.__eq__(extract_native(other))) - def __ne__(self, other: object) -> Self: # type: ignore[override] - return self._with_native(self.native.__ne__(extract_native(other))) # type: ignore[operator] + def __ne__(self, other: PolarsExpr) -> Self: # type: ignore[override] + return self._with_native(self.native.__ne__(extract_native(other))) def __ge__(self, other: Any) -> Self: return self._with_native(self.native.__ge__(extract_native(other))) @@ -239,10 +239,10 @@ def __lt__(self, other: Any) -> Self: return self._with_native(self.native.__lt__(extract_native(other))) def __and__(self, other: PolarsExpr) -> Self: - return self._with_native(self.native.__and__(extract_native(other))) # type: ignore[operator] + return self._with_native(self.native.__and__(extract_native(other))) def __or__(self, other: PolarsExpr) -> Self: - return self._with_native(self.native.__or__(extract_native(other))) # type: ignore[operator] + return self._with_native(self.native.__or__(extract_native(other))) def __add__(self, other: Any) -> Self: return self._with_native(self.native.__add__(extract_native(other))) From 5aa3d71ff0255a550b9740c459c920b5ce70ca2f Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Wed, 1 Oct 2025 09:45:44 +0100 Subject: [PATCH 53/95] expressifiable_args -> kwargs --- narwhals/_compliant/expr.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/narwhals/_compliant/expr.py b/narwhals/_compliant/expr.py index e029e3d582..4ba96f0fa5 100644 --- a/narwhals/_compliant/expr.py +++ b/narwhals/_compliant/expr.py @@ -293,11 +293,7 @@ def func(df: EagerDataFrameT) -> list[EagerSeriesT]: ) def _reuse_series( - self, - method_name: str, - *, - returns_scalar: bool = False, - **expressifiable_args: Any, + self, method_name: str, *, returns_scalar: bool = False, **kwargs: Any ) -> Self: """Reuse Series implementation for expression. @@ -308,14 +304,13 @@ def _reuse_series( method_name: name of method. returns_scalar: whether the Series version returns a scalar. In this case, the expression version should return a 1-row Series. - expressifiable_args: keyword arguments to pass to function, which may - be expressifiable (e.g. `nw.col('a').is_between(3, nw.col('b')))`). + kwargs: keyword arguments to pass to function. """ func = partial( self._reuse_series_inner, method_name=method_name, returns_scalar=returns_scalar, - expressifiable_args=expressifiable_args, + **kwargs, ) return self._from_callable( func, @@ -341,12 +336,12 @@ def _reuse_series_inner( *, method_name: str, returns_scalar: bool, - expressifiable_args: dict[str, Any], + **kwargs: Any, ) -> Sequence[EagerSeriesT]: kwargs = { **{ name: df._evaluate_expr(value) if self._is_expr(value) else value - for name, value in expressifiable_args.items() + for name, value in kwargs.items() } } method = methodcaller( From 3c075e62b1489677f8084f3351febf440551d4d7 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Wed, 1 Oct 2025 10:00:25 +0100 Subject: [PATCH 54/95] more precise typing, remove another `Any` --- narwhals/_compliant/column.py | 6 +----- narwhals/_compliant/expr.py | 2 +- narwhals/_dask/expr.py | 2 +- narwhals/_duckdb/expr.py | 3 ++- narwhals/_ibis/expr.py | 3 ++- narwhals/_spark_like/expr.py | 3 ++- 6 files changed, 9 insertions(+), 10 deletions(-) diff --git a/narwhals/_compliant/column.py b/narwhals/_compliant/column.py index e5b87e8869..5329c2f8f1 100644 --- a/narwhals/_compliant/column.py +++ b/narwhals/_compliant/column.py @@ -21,7 +21,6 @@ FillNullStrategy, IntoDType, ModeKeepStrategy, - NonNestedLiteral, RankMethod, ) @@ -85,10 +84,7 @@ def exp(self) -> Self: ... def sqrt(self) -> Self: ... def fill_nan(self, value: float | None) -> Self: ... def fill_null( - self, - value: Self | NonNestedLiteral, - strategy: FillNullStrategy | None, - limit: int | None, + self, value: Self | None, strategy: FillNullStrategy | None, limit: int | None ) -> Self: ... def is_between( self, lower_bound: Self, upper_bound: Self, closed: ClosedInterval diff --git a/narwhals/_compliant/expr.py b/narwhals/_compliant/expr.py index 4ba96f0fa5..548b1cec04 100644 --- a/narwhals/_compliant/expr.py +++ b/narwhals/_compliant/expr.py @@ -562,7 +562,7 @@ def fill_nan(self, value: float | None) -> Self: return self._reuse_series("fill_nan", value=value) def fill_null( - self, value: Self | Any, strategy: FillNullStrategy | None, limit: int | None + self, value: Self | None, strategy: FillNullStrategy | None, limit: int | None ) -> Self: return self._reuse_series( "fill_null", value=value, strategy=strategy, limit=limit diff --git a/narwhals/_dask/expr.py b/narwhals/_dask/expr.py index fc9b5bef50..ed4651f721 100644 --- a/narwhals/_dask/expr.py +++ b/narwhals/_dask/expr.py @@ -443,7 +443,7 @@ def func(expr: dx.Series) -> dx.Series: return self._with_callable(func) def fill_null( - self, value: Self | Any, strategy: FillNullStrategy | None, limit: int | None + self, value: Self | None, strategy: FillNullStrategy | None, limit: int | None ) -> Self: def func(expr: dx.Series) -> dx.Series: if value is not None: diff --git a/narwhals/_duckdb/expr.py b/narwhals/_duckdb/expr.py index 955ae2c992..c601fe58de 100644 --- a/narwhals/_duckdb/expr.py +++ b/narwhals/_duckdb/expr.py @@ -229,7 +229,7 @@ def is_in(self, other: Sequence[Any]) -> Self: return self._with_elementwise(lambda expr: F("contains", lit(other), expr)) def fill_null( - self, value: Self | Any, strategy: FillNullStrategy | None, limit: int | None + self, value: Self | None, strategy: FillNullStrategy | None, limit: int | None ) -> Self: if strategy is not None: if self._backend_version < (1, 3): # pragma: no cover @@ -262,6 +262,7 @@ def _fill_with_strategy( def _fill_constant(expr: Expression, value: Any) -> Expression: return CoalesceOperator(expr, value) + assert value is not None # noqa: S101 return self._with_elementwise(_fill_constant, value=value) def cast(self, dtype: IntoDType) -> Self: diff --git a/narwhals/_ibis/expr.py b/narwhals/_ibis/expr.py index 9143cabc79..bdd514b89c 100644 --- a/narwhals/_ibis/expr.py +++ b/narwhals/_ibis/expr.py @@ -255,7 +255,7 @@ def is_finite(self) -> Self: def is_in(self, other: Sequence[Any]) -> Self: return self._with_callable(lambda expr: expr.isin(other)) - def fill_null(self, value: Self | Any, strategy: Any, limit: int | None) -> Self: + def fill_null(self, value: Self | None, strategy: Any, limit: int | None) -> Self: # Ibis doesn't yet allow ignoring nulls in first/last with window functions, which makes forward/backward # strategies inconsistent when there are nulls present: https://github.com/ibis-project/ibis/issues/9539 if strategy is not None: @@ -268,6 +268,7 @@ def fill_null(self, value: Self | Any, strategy: Any, limit: int | None) -> Self def _fill_null(expr: ir.Value, value: ir.Scalar) -> ir.Value: return expr.fill_null(value) + assert value is not None # noqa: S101 return self._with_callable(_fill_null, value=value) def cast(self, dtype: IntoDType) -> Self: diff --git a/narwhals/_spark_like/expr.py b/narwhals/_spark_like/expr.py index ff3cdb39d2..ad6fd00f0e 100644 --- a/narwhals/_spark_like/expr.py +++ b/narwhals/_spark_like/expr.py @@ -333,7 +333,7 @@ def _is_nan(expr: Column) -> Column: return self._with_elementwise(_is_nan) def fill_null( - self, value: Self | Any, strategy: FillNullStrategy | None, limit: int | None + self, value: Self | None, strategy: FillNullStrategy | None, limit: int | None ) -> Self: if strategy is not None: @@ -361,6 +361,7 @@ def _fill_with_strategy( def _fill_constant(expr: Column, value: Column) -> Column: return self._F.ifnull(expr, value) + assert value is not None # noqa: S101 return self._with_elementwise(_fill_constant, value=value) @property From f022d1c599701a4311d99a2a42e81cc61405831c Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Wed, 1 Oct 2025 10:31:55 +0100 Subject: [PATCH 55/95] more typing fixes --- narwhals/_compliant/column.py | 40 ++++++++++++++++---------------- narwhals/_compliant/expr.py | 2 +- narwhals/_compliant/selectors.py | 6 ++--- narwhals/_sql/expr_dt.py | 10 +++++--- narwhals/_sql/expr_str.py | 29 +++++++++++++++-------- 5 files changed, 50 insertions(+), 37 deletions(-) diff --git a/narwhals/_compliant/column.py b/narwhals/_compliant/column.py index 5329c2f8f1..9fad5c26d1 100644 --- a/narwhals/_compliant/column.py +++ b/narwhals/_compliant/column.py @@ -32,27 +32,27 @@ class CompliantColumn(Protocol): _version: Version - def __add__(self, other: Any) -> Self: ... - def __and__(self, other: Any) -> Self: ... - def __eq__(self, other: object) -> Self: ... # type: ignore[override] - def __floordiv__(self, other: Any) -> Self: ... - def __ge__(self, other: Any) -> Self: ... - def __gt__(self, other: Any) -> Self: ... + def __add__(self, other: Self) -> Self: ... + def __and__(self, other: Self) -> Self: ... + def __eq__(self, other: Self) -> Self: ... # type: ignore[override] + def __floordiv__(self, other: Self) -> Self: ... + def __ge__(self, other: Self) -> Self: ... + def __gt__(self, other: Self) -> Self: ... def __invert__(self) -> Self: ... - def __le__(self, other: Any) -> Self: ... - def __lt__(self, other: Any) -> Self: ... - def __mod__(self, other: Any) -> Self: ... - def __mul__(self, other: Any) -> Self: ... - def __ne__(self, other: object) -> Self: ... # type: ignore[override] - def __or__(self, other: Any) -> Self: ... - def __pow__(self, other: Any) -> Self: ... - def __rfloordiv__(self, other: Any) -> Self: ... - def __rmod__(self, other: Any) -> Self: ... - def __rpow__(self, other: Any) -> Self: ... - def __rsub__(self, other: Any) -> Self: ... - def __rtruediv__(self, other: Any) -> Self: ... - def __sub__(self, other: Any) -> Self: ... - def __truediv__(self, other: Any) -> Self: ... + def __le__(self, other: Self) -> Self: ... + def __lt__(self, other: Self) -> Self: ... + def __mod__(self, other: Self) -> Self: ... + def __mul__(self, other: Self) -> Self: ... + def __ne__(self, other: Self) -> Self: ... # type: ignore[override] + def __or__(self, other: Self) -> Self: ... + def __pow__(self, other: Self) -> Self: ... + def __rfloordiv__(self, other: Self) -> Self: ... + def __rmod__(self, other: Self) -> Self: ... + def __rpow__(self, other: Self) -> Self: ... + def __rsub__(self, other: Self) -> Self: ... + def __rtruediv__(self, other: Self) -> Self: ... + def __sub__(self, other: Self) -> Self: ... + def __truediv__(self, other: Self) -> Self: ... def __narwhals_namespace__(self) -> CompliantNamespace[Any, Any]: ... diff --git a/narwhals/_compliant/expr.py b/narwhals/_compliant/expr.py index 548b1cec04..60dff8090b 100644 --- a/narwhals/_compliant/expr.py +++ b/narwhals/_compliant/expr.py @@ -784,7 +784,7 @@ def sqrt(self) -> Self: return self._reuse_series("sqrt") def is_between( - self, lower_bound: Any, upper_bound: Any, closed: ClosedInterval + self, lower_bound: Self, upper_bound: Self, closed: ClosedInterval ) -> Self: return self._reuse_series( "is_between", lower_bound=lower_bound, upper_bound=upper_bound, closed=closed diff --git a/narwhals/_compliant/selectors.py b/narwhals/_compliant/selectors.py index 515d0ccd15..bb4f612d10 100644 --- a/narwhals/_compliant/selectors.py +++ b/narwhals/_compliant/selectors.py @@ -229,7 +229,7 @@ def _is_selector( ) -> TypeIs[CompliantSelector[FrameT, SeriesOrExprT]]: return isinstance(other, type(self)) - @overload + @overload # type: ignore[override] def __sub__(self, other: Self) -> Self: ... @overload def __sub__( @@ -255,7 +255,7 @@ def names(df: FrameT) -> Sequence[str]: return self.selectors._selector.from_callables(series, names, context=self) return self._to_expr() - other - @overload + @overload # type: ignore[override] def __or__(self, other: Self) -> Self: ... @overload def __or__( @@ -284,7 +284,7 @@ def names(df: FrameT) -> Sequence[str]: return self.selectors._selector.from_callables(series, names, context=self) return self._to_expr() | other - @overload + @overload # type: ignore[override] def __and__(self, other: Self) -> Self: ... @overload def __and__( diff --git a/narwhals/_sql/expr_dt.py b/narwhals/_sql/expr_dt.py index 85b65aaf05..5625e643d4 100644 --- a/narwhals/_sql/expr_dt.py +++ b/narwhals/_sql/expr_dt.py @@ -1,17 +1,21 @@ from __future__ import annotations -from typing import Any, Generic +from typing import TYPE_CHECKING, Any, Generic, TypeAlias from narwhals._compliant import LazyExprNamespace from narwhals._compliant.any_namespace import DateTimeNamespace from narwhals._sql.typing import SQLExprT +if TYPE_CHECKING: + # TODO(unassigned): Make string namespace generic in NativeExprT too. + NativeExpr: TypeAlias = Any + class SQLExprDateTimeNamesSpace( LazyExprNamespace[SQLExprT], DateTimeNamespace[SQLExprT], Generic[SQLExprT] ): - def _function(self, name: str, *args: Any) -> SQLExprT: - return self.compliant._function(name, *args) # type: ignore[no-any-return] + def _function(self, name: str, *args: Any) -> NativeExpr: + return self.compliant._function(name, *args) def year(self) -> SQLExprT: return self.compliant._with_elementwise(lambda expr: self._function("year", expr)) diff --git a/narwhals/_sql/expr_str.py b/narwhals/_sql/expr_str.py index c1b5db9b53..0d9c483711 100644 --- a/narwhals/_sql/expr_str.py +++ b/narwhals/_sql/expr_str.py @@ -1,26 +1,35 @@ from __future__ import annotations -from typing import Any, Generic +from typing import TYPE_CHECKING, Any, Generic, TypeAlias from narwhals._compliant import LazyExprNamespace from narwhals._compliant.any_namespace import StringNamespace from narwhals._sql.typing import SQLExprT +if TYPE_CHECKING: + # TODO(unassigned): Make string namespace generic in NativeExprT too. + NativeExpr: TypeAlias = Any + class SQLExprStringNamespace( LazyExprNamespace[SQLExprT], StringNamespace[SQLExprT], Generic[SQLExprT] ): - def _lit(self, value: Any) -> SQLExprT: - return self.compliant._lit(value) # type: ignore[no-any-return] + def _lit(self, value: Any) -> NativeExpr: + return self.compliant._lit(value) - def _function(self, name: str, *args: Any) -> SQLExprT: - return self.compliant._function(name, *args) # type: ignore[no-any-return] + def _function(self, name: str, *args: Any) -> NativeExpr: + return self.compliant._function(name, *args) - def _when(self, condition: Any, value: Any, otherwise: Any | None = None) -> SQLExprT: - return self.compliant._when(condition, value, otherwise) # type: ignore[no-any-return] + def _when( + self, + condition: NativeExpr, + value: NativeExpr, + otherwise: NativeExpr | None = None, + ) -> NativeExpr: + return self.compliant._when(condition, value, otherwise) def contains(self, pattern: str, *, literal: bool) -> SQLExprT: - def func(expr: Any) -> Any: + def func(expr: NativeExpr) -> NativeExpr: if literal: return self._function("contains", expr, self._lit(pattern)) return self._function("regexp_matches", expr, self._lit(pattern)) @@ -51,7 +60,7 @@ def replace_all(self, value: SQLExprT, pattern: str, *, literal: bool) -> SQLExp ) def slice(self, offset: int, length: int | None) -> SQLExprT: - def func(expr: SQLExprT) -> SQLExprT: + def func(expr: NativeExpr) -> NativeExpr: col_length = self._function("length", expr) _offset = ( @@ -99,7 +108,7 @@ def zfill(self, width: int) -> SQLExprT: # There is no built-in zfill function, so we need to implement it manually # using string manipulation functions. - def func(expr: Any) -> Any: + def func(expr: NativeExpr) -> NativeExpr: less_than_width = self._function("length", expr) < self._lit(width) zero, hyphen, plus = self._lit("0"), self._lit("-"), self._lit("+") From 9bcb61d002e8ef8e6b77bf62a661e4de594edbb9 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Wed, 1 Oct 2025 10:45:06 +0100 Subject: [PATCH 56/95] more typing fixes --- narwhals/_sql/expr_dt.py | 7 +++---- narwhals/_sql/expr_str.py | 23 ++++++++++------------- 2 files changed, 13 insertions(+), 17 deletions(-) diff --git a/narwhals/_sql/expr_dt.py b/narwhals/_sql/expr_dt.py index 5625e643d4..8c660bc500 100644 --- a/narwhals/_sql/expr_dt.py +++ b/narwhals/_sql/expr_dt.py @@ -1,21 +1,20 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Generic, TypeAlias +from typing import TYPE_CHECKING, Any, Generic from narwhals._compliant import LazyExprNamespace from narwhals._compliant.any_namespace import DateTimeNamespace from narwhals._sql.typing import SQLExprT if TYPE_CHECKING: - # TODO(unassigned): Make string namespace generic in NativeExprT too. - NativeExpr: TypeAlias = Any + from narwhals._compliant.expr import NativeExpr class SQLExprDateTimeNamesSpace( LazyExprNamespace[SQLExprT], DateTimeNamespace[SQLExprT], Generic[SQLExprT] ): def _function(self, name: str, *args: Any) -> NativeExpr: - return self.compliant._function(name, *args) + return self.compliant._function(name, *args) # type: ignore[no-any-return] def year(self) -> SQLExprT: return self.compliant._with_elementwise(lambda expr: self._function("year", expr)) diff --git a/narwhals/_sql/expr_str.py b/narwhals/_sql/expr_str.py index 0d9c483711..db43531823 100644 --- a/narwhals/_sql/expr_str.py +++ b/narwhals/_sql/expr_str.py @@ -1,32 +1,29 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Generic, TypeAlias +import operator +from typing import TYPE_CHECKING, Any, Generic from narwhals._compliant import LazyExprNamespace from narwhals._compliant.any_namespace import StringNamespace from narwhals._sql.typing import SQLExprT if TYPE_CHECKING: - # TODO(unassigned): Make string namespace generic in NativeExprT too. - NativeExpr: TypeAlias = Any + from narwhals._compliant.expr import NativeExpr class SQLExprStringNamespace( LazyExprNamespace[SQLExprT], StringNamespace[SQLExprT], Generic[SQLExprT] ): def _lit(self, value: Any) -> NativeExpr: - return self.compliant._lit(value) + return self.compliant._lit(value) # type: ignore[no-any-return] def _function(self, name: str, *args: Any) -> NativeExpr: - return self.compliant._function(name, *args) + return self.compliant._function(name, *args) # type: ignore[no-any-return] def _when( - self, - condition: NativeExpr, - value: NativeExpr, - otherwise: NativeExpr | None = None, + self, condition: Any, value: Any, otherwise: Any | None = None ) -> NativeExpr: - return self.compliant._when(condition, value, otherwise) + return self.compliant._when(condition, value, otherwise) # type: ignore[no-any-return] def contains(self, pattern: str, *, literal: bool) -> SQLExprT: def func(expr: NativeExpr) -> NativeExpr: @@ -64,7 +61,7 @@ def func(expr: NativeExpr) -> NativeExpr: col_length = self._function("length", expr) _offset = ( - col_length + self._lit(offset + 1) + operator.add(col_length, self._lit(offset + 1)) if offset < 0 else self._lit(offset + 1) ) @@ -119,10 +116,10 @@ def func(expr: NativeExpr) -> NativeExpr: "lpad", substring, self._lit(width - 1), zero ) return self._when( - starts_with_minus & less_than_width, + operator.and_(starts_with_minus, less_than_width), self._function("concat", hyphen, padded_substring), self._when( - starts_with_plus & less_than_width, + operator.and_(starts_with_plus, less_than_width), self._function("concat", plus, padded_substring), self._when( less_than_width, From 32346ab2d7f9f3b0d4c876294eb7ce13c43dee7f Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Mon, 6 Oct 2025 09:17:42 +0100 Subject: [PATCH 57/95] groupby fix --- narwhals/_arrow/group_by.py | 3 ++- narwhals/_compliant/group_by.py | 5 +++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/narwhals/_arrow/group_by.py b/narwhals/_arrow/group_by.py index 69a5983207..22785f0ebe 100644 --- a/narwhals/_arrow/group_by.py +++ b/narwhals/_arrow/group_by.py @@ -76,8 +76,9 @@ def _configure_agg( ) -> tuple[pa.TableGroupBy, Aggregation, AggregateOptions | None]: option: AggregateOptions | None = None function_name = self._leaf_name(expr) + kwargs = self._kwargs(expr) if function_name in self._OPTION_VARIANCE: - ddof = expr._scalar_kwargs.get("ddof", 1) + ddof = kwargs["ddof"] option = pc.VarianceOptions(ddof=ddof) elif function_name in self._OPTION_COUNT_ALL: option = pc.CountOptions(mode="all") diff --git a/narwhals/_compliant/group_by.py b/narwhals/_compliant/group_by.py index 217209980e..de9cf867e8 100644 --- a/narwhals/_compliant/group_by.py +++ b/narwhals/_compliant/group_by.py @@ -168,6 +168,11 @@ def _leaf_name(cls, expr: DepthTrackingExprAny, /) -> NarwhalsAggregation | Any: """Return the last function name in the chain defined by `expr`.""" return next(expr._metadata.op_nodes_reversed()).name + @classmethod + def _kwargs(cls, expr: DepthTrackingExprAny, /) -> dict[str, Any]: + """Return the last function name in the chain defined by `expr`.""" + return next(expr._metadata.op_nodes_reversed()).kwargs + class EagerGroupBy( DepthTrackingGroupBy[CompliantDataFrameT, EagerExprT_contra, NativeAggregationT_co], From 51d258e413ddefce1ec1816f1078f632be49c51b Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Mon, 6 Oct 2025 09:21:52 +0100 Subject: [PATCH 58/95] first last --- narwhals/_expression_parsing.py | 2 +- narwhals/expr.py | 8 ++------ 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/narwhals/_expression_parsing.py b/narwhals/_expression_parsing.py index 478f7a9999..e817ae21ce 100644 --- a/narwhals/_expression_parsing.py +++ b/narwhals/_expression_parsing.py @@ -804,7 +804,7 @@ def evaluate_root_node(node: ExprNode, ns: CompliantNamespaceAny) -> CompliantEx def evaluate_node( compliant_expr: CompliantExprAny, node: ExprNode, ns: CompliantNamespaceAny ) -> CompliantExprAny: - md = compliant_expr._metadata + md: ExprMetadata = compliant_expr._metadata ce, *ces = maybe_broadcast_ces( compliant_expr, *evaluate_into_exprs( diff --git a/narwhals/expr.py b/narwhals/expr.py index f6184e12e1..18cdb5ea6f 100644 --- a/narwhals/expr.py +++ b/narwhals/expr.py @@ -1565,9 +1565,7 @@ def first(self) -> Self: | 1 2 None | └──────────────────┘ """ - return self._with_orderable_aggregation( - lambda plx: self._to_compliant_expr(plx).first() - ) + return self._with_node(ExprNode(ExprKind.ORDERABLE_AGGREGATION, "first")) def last(self) -> Self: """Get the last value. @@ -1606,9 +1604,7 @@ def last(self) -> Self: |b: [[null,"baz"]] | └──────────────────┘ """ - return self._with_orderable_aggregation( - lambda plx: self._to_compliant_expr(plx).last() - ) + return self._with_node(ExprNode(ExprKind.ORDERABLE_AGGREGATION, "last")) def mode(self, *, keep: ModeKeepStrategy = "all") -> Self: r"""Compute the most occurring value(s). From f6ab8d633feef77867778ea4ce985acf80ed9680 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Mon, 6 Oct 2025 10:04:14 +0100 Subject: [PATCH 59/95] fix is_native, with_row_index --- narwhals/_arrow/dataframe.py | 19 ++++++++++--------- narwhals/_arrow/expr.py | 2 +- narwhals/_polars/namespace.py | 9 +++++++++ 3 files changed, 20 insertions(+), 10 deletions(-) diff --git a/narwhals/_arrow/dataframe.py b/narwhals/_arrow/dataframe.py index 37bde19aec..49adfa5e57 100644 --- a/narwhals/_arrow/dataframe.py +++ b/narwhals/_arrow/dataframe.py @@ -502,18 +502,19 @@ def to_dict( return {ser.name: ser.to_list() for ser in it} def with_row_index(self, name: str, order_by: Sequence[str] | None) -> Self: - plx = self.__narwhals_namespace__() - if order_by is None: - import numpy as np # ignore-banned-import + import numpy as np # ignore-banned-import - data = pa.array(np.arange(len(self), dtype=np.int64)) + plx = self.__narwhals_namespace__() + size = len(self) + data = pa.array(np.arange(size)) + row_index_s = plx._series.from_iterable(data, context=self, name=name) + row_index = plx._expr._from_series(row_index_s) + if order_by: row_index = plx._expr._from_series( - plx._series.from_iterable(data, context=self, name=name) + self.with_columns(row_index) + .sort(*order_by, descending=False, nulls_last=False) + .get_column(name) ) - else: - rank = plx.col(order_by[0]).rank("ordinal", descending=False) - one = plx.lit(1, None).broadcast() - row_index = (rank.over(partition_by=[], order_by=order_by) - one).alias(name) return self.select(row_index, plx.all()) def filter(self, predicate: ArrowExpr | list[bool | None]) -> Self: diff --git a/narwhals/_arrow/expr.py b/narwhals/_arrow/expr.py index 73a739bc8a..0e59cf039b 100644 --- a/narwhals/_arrow/expr.py +++ b/narwhals/_arrow/expr.py @@ -99,7 +99,7 @@ def _reuse_series_extra_kwargs( def over(self, partition_by: Sequence[str], order_by: Sequence[str]) -> Self: meta = self._metadata - if partition_by and meta is not None and not meta.is_scalar_like: + if partition_by and not meta.is_scalar_like: msg = "Only aggregation or literal operations are supported in grouped `over` context for PyArrow." raise NotImplementedError(msg) diff --git a/narwhals/_polars/namespace.py b/narwhals/_polars/namespace.py index 6ee35911d0..1c32c3f3cf 100644 --- a/narwhals/_polars/namespace.py +++ b/narwhals/_polars/namespace.py @@ -16,6 +16,8 @@ from collections.abc import Iterable, Sequence from datetime import timezone + from typing_extensions import TypeIs + from narwhals._compliant import CompliantSelectorNamespace from narwhals._polars.dataframe import Method, PolarsDataFrame, PolarsLazyFrame from narwhals._polars.typing import FrameT @@ -215,6 +217,13 @@ def when_then( version=self._version, ) + def is_native(self, obj: Any, /) -> TypeIs[pl.DataFrame | pl.LazyFrame | pl.Series]: + return ( + self._dataframe._is_native(obj) + or self._series._is_native(obj) + or self._lazyframe._is_native(obj) + ) + # NOTE: Implementation is too different to annotate correctly (vs other `*SelectorNamespace`) # 1. Others have lots of private stuff for code reuse # i. None of that is useful here From 068e75eb76944a6c770aae3b36282b42f3a9b581 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Mon, 6 Oct 2025 10:33:14 +0100 Subject: [PATCH 60/95] typing, docs (thanks Francesco!) --- docs/how_it_works.md | 14 +++++++------- narwhals/_pandas_like/group_by.py | 10 ++++++---- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/docs/how_it_works.md b/docs/how_it_works.md index 97dd528d87..8bb7da8e6b 100644 --- a/docs/how_it_works.md +++ b/docs/how_it_works.md @@ -293,7 +293,7 @@ This tells us a few things: - We're performing an aggregation. - The name of the function is `'std'`. This will be looked up in the compliant object. - It takes keyword arguments `ddof=1`. -- We'll look at the others later. +- We'll look at `exprs`, `str_as_lit`, and `allow_multi_output` later. In order for the evaluation to succeed, then `PandasLikeExpr` must have a `std` method defined on it, which takes a `ddof` argument. And this is what the `CompliantExpr` Protocol is for: so @@ -312,7 +312,7 @@ The `str_as_lit` parameter tells us whether string literals should be interprete or columns (e.g. `col('foo')`). Finally `allow_multi_output` tells us whether multi-outuput expressions (more on this in the next section) are allowed to appear in `exprs`. -Node that the expression in `exprs` also has its own nodes: +Note that the expression in `exprs` also has its own nodes: ```python exec="1" result="python" session="pandas_impl" source="above" print(expr._nodes[3].exprs[0]._nodes) @@ -487,8 +487,8 @@ end up with nw.col("a").sum().over("c") + nw.col("b").sum().over("c") ``` -In general, query optimisation is out-of-scope for Narwhals. We consider this -expression rewrite acceptable because: - -- It's simple. -- It allows us to evaluate operations which otherwise wouldn't be allowed for certain backends. +!!! info + In general, query optimisation is out-of-scope for Narwhals. We consider this + expression rewrite acceptable because: + - It's simple. + - It allows us to evaluate operations which otherwise wouldn't be allowed for certain backends. diff --git a/narwhals/_pandas_like/group_by.py b/narwhals/_pandas_like/group_by.py index 3816d8876a..8eadb4f824 100644 --- a/narwhals/_pandas_like/group_by.py +++ b/narwhals/_pandas_like/group_by.py @@ -133,7 +133,7 @@ def _getitem_aggs( cols = list(names) native = compliant.native - keys, kwargs = group_by._keys, group_by._kwargs + keys, kwargs = group_by._keys, group_by._group_by_kwargs # Implementation based on the following suggestion: # https://github.com/pandas-dev/pandas/issues/19254#issuecomment-778661578 @@ -224,7 +224,7 @@ class PandasLikeGroupBy( _output_key_names: list[str] """Stores the **original** version of group keys.""" - _kwargs: Mapping[str, bool] + _group_by_kwargs: Mapping[str, bool] """Stores keyword arguments for `DataFrame.groupby` other than `by`.""" @property @@ -252,13 +252,15 @@ def __init__( if set(native.index.names).intersection(self.compliant.columns): native = native.reset_index(drop=True) - self._kwargs = { + self._group_by_kwargs = { "sort": False, "as_index": True, "dropna": drop_null_keys, "observed": True, } - self._grouped: NativeGroupBy = native.groupby(self._keys.copy(), **self._kwargs) + self._grouped: NativeGroupBy = native.groupby( + self._keys.copy(), **self._group_by_kwargs + ) def agg(self, *exprs: PandasLikeExpr) -> PandasLikeDataFrame: all_aggs_are_simple = True From db0303a7465c11a97b3d9fa32dbf59faeae571a5 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Mon, 6 Oct 2025 13:17:52 +0100 Subject: [PATCH 61/95] more slow testv --- narwhals/_arrow/dataframe.py | 6 +++--- narwhals/_compliant/dataframe.py | 13 ++++++++----- narwhals/_compliant/expr.py | 8 ++++++-- narwhals/_compliant/namespace.py | 6 +++--- narwhals/_pandas_like/dataframe.py | 6 +++--- narwhals/_sql/dataframe.py | 4 +++- narwhals/_sql/expr.py | 2 +- tests/dtypes_test.py | 1 + tests/frame/lazy_test.py | 1 + 9 files changed, 29 insertions(+), 18 deletions(-) diff --git a/narwhals/_arrow/dataframe.py b/narwhals/_arrow/dataframe.py index 49adfa5e57..08bbb3547e 100644 --- a/narwhals/_arrow/dataframe.py +++ b/narwhals/_arrow/dataframe.py @@ -332,7 +332,7 @@ def simple_select(self, *column_names: str) -> Self: ) def select(self, *exprs: ArrowExpr) -> Self: - new_series = self._evaluate_into_exprs(*exprs) + new_series = self._evaluate_exprs(*exprs) if not new_series: # return empty dataframe, like Polars does return self._with_native( @@ -359,7 +359,7 @@ def with_columns(self, *exprs: ArrowExpr) -> Self: # NOTE: We use a faux-mutable variable and repeatedly "overwrite" (native_frame) # All `pyarrow` data is immutable, so this is fine native_frame = self.native - new_columns = self._evaluate_into_exprs(*exprs) + new_columns = self._evaluate_exprs(*exprs) columns = self.columns for col_value in new_columns: @@ -522,7 +522,7 @@ def filter(self, predicate: ArrowExpr | list[bool | None]) -> Self: mask_native: Mask | ChunkedArrayAny = predicate else: # `[0]` is safe as the predicate's expression only returns a single column - mask_native = self._evaluate_into_exprs(predicate)[0].native + mask_native = self._evaluate_exprs(predicate)[0].native return self._with_native( self.native.filter(mask_native), validate_column_names=False ) diff --git a/narwhals/_compliant/dataframe.py b/narwhals/_compliant/dataframe.py index d88c8c3fc8..dcc8bc6a58 100644 --- a/narwhals/_compliant/dataframe.py +++ b/narwhals/_compliant/dataframe.py @@ -347,21 +347,24 @@ def _with_native( def _check_columns_exist(self, subset: Sequence[str]) -> ColumnNotFoundError | None: return check_columns_exist(subset, available=self.columns) - def _evaluate_expr(self, expr: EagerExprT, /) -> EagerSeriesT: + def _evaluate_single_output_expr(self, expr: EagerExprT, /) -> EagerSeriesT: """Evaluate `expr` and ensure it has a **single** output.""" - result: Sequence[EagerSeriesT] = expr(self) + # NOTE: Ignore intermittent [False Negative] + # Argument of type "EagerExprT@EagerDataFrame" cannot be assigned to parameter "expr" of type "EagerExprT@EagerDataFrame" in function "_evaluate_into_expr" + # Type "EagerExprT@EagerDataFrame" is not assignable to type "EagerExprT@EagerDataFrame" + result = self._evaluate_expr(expr) # pyright: ignore[reportArgumentType] if len(result) != 1: # pragma: no cover msg = "multi-output expressions not allowed in this context" raise MultiOutputExpressionError(msg) return result[0] - def _evaluate_into_exprs(self, *exprs: EagerExprT) -> Sequence[EagerSeriesT]: + def _evaluate_exprs(self, *exprs: EagerExprT) -> Sequence[EagerSeriesT]: # NOTE: Ignore intermittent [False Negative] # Argument of type "EagerExprT@EagerDataFrame" cannot be assigned to parameter "expr" of type "EagerExprT@EagerDataFrame" in function "_evaluate_into_expr" # Type "EagerExprT@EagerDataFrame" is not assignable to type "EagerExprT@EagerDataFrame" - return list(chain.from_iterable(self._evaluate_into_expr(expr) for expr in exprs)) # pyright: ignore[reportArgumentType] + return list(chain.from_iterable(self._evaluate_expr(expr) for expr in exprs)) # pyright: ignore[reportArgumentType] - def _evaluate_into_expr(self, expr: EagerExprT, /) -> Sequence[EagerSeriesT]: + def _evaluate_expr(self, expr: EagerExprT, /) -> Sequence[EagerSeriesT]: """Return list of raw columns. For eager backends we alias operations at each step. diff --git a/narwhals/_compliant/expr.py b/narwhals/_compliant/expr.py index a9ee2b2703..de73061228 100644 --- a/narwhals/_compliant/expr.py +++ b/narwhals/_compliant/expr.py @@ -343,7 +343,9 @@ def _reuse_series_inner( ) -> Sequence[EagerSeriesT]: kwargs = { **{ - name: df._evaluate_expr(value) if self._is_expr(value) else value + name: df._evaluate_single_output_expr(value) + if self._is_expr(value) + else value for name, value in kwargs.items() } } @@ -386,7 +388,9 @@ def _reuse_series_namespace( def inner(df: EagerDataFrameT) -> list[EagerSeriesT]: kwargs = { - name: df._evaluate_expr(value) if self._is_expr(value) else value + name: df._evaluate_single_output_expr(value) + if self._is_expr(value) + else value for name, value in expressifiable_args.items() } return [ diff --git a/narwhals/_compliant/namespace.py b/narwhals/_compliant/namespace.py index 1fa5c702f9..cec203ac0a 100644 --- a/narwhals/_compliant/namespace.py +++ b/narwhals/_compliant/namespace.py @@ -161,10 +161,10 @@ def when_then( self, predicate: EagerExprT, then: EagerExprT, otherwise: EagerExprT | None = None ) -> EagerExprT: def func(df: EagerDataFrameT) -> Sequence[EagerSeriesT_co]: - predicate_s = df._evaluate_expr(predicate) + predicate_s = df._evaluate_single_output_expr(predicate) align = predicate_s._align_full_broadcast - then_s = df._evaluate_expr(then) + then_s = df._evaluate_single_output_expr(then) if otherwise is None: predicate_s, then_s = align(predicate_s, then_s) result = self._if_then_else(predicate_s.native, then_s.native) @@ -173,7 +173,7 @@ def func(df: EagerDataFrameT) -> Sequence[EagerSeriesT_co]: predicate_s, then_s = align(predicate_s, then_s) result = self._if_then_else(predicate_s.native, then_s.native) else: - otherwise_s = df._evaluate_expr(otherwise) + otherwise_s = df._evaluate_single_output_expr(otherwise) predicate_s, then_s, otherwise_s = align(predicate_s, then_s, otherwise_s) result = self._if_then_else( predicate_s.native, then_s.native, otherwise_s.native diff --git a/narwhals/_pandas_like/dataframe.py b/narwhals/_pandas_like/dataframe.py index 574a90efe6..3ad7bf0a9c 100644 --- a/narwhals/_pandas_like/dataframe.py +++ b/narwhals/_pandas_like/dataframe.py @@ -418,7 +418,7 @@ def simple_select(self, *column_names: str) -> Self: ) def select(self, *exprs: PandasLikeExpr) -> Self: - new_series = self._evaluate_into_exprs(*exprs) + new_series = self._evaluate_exprs(*exprs) if not new_series: # return empty dataframe, like Polars does return self._with_native(type(self.native)(), validate_column_names=False) @@ -466,14 +466,14 @@ def filter(self, predicate: PandasLikeExpr | list[bool]) -> Self: mask_native: pd.Series[Any] | list[bool] = predicate else: # `[0]` is safe as the predicate's expression only returns a single column - mask = self._evaluate_into_exprs(predicate)[0] + mask = self._evaluate_exprs(predicate)[0] mask_native = self._extract_comparand(mask) return self._with_native( self.native.loc[mask_native], validate_column_names=False ) def with_columns(self, *exprs: PandasLikeExpr) -> Self: - columns = self._evaluate_into_exprs(*exprs) + columns = self._evaluate_exprs(*exprs) if not columns and len(self) == 0: return self name_columns: dict[str, PandasLikeSeries] = {s.name: s for s in columns} diff --git a/narwhals/_sql/dataframe.py b/narwhals/_sql/dataframe.py index cd3e2546aa..f4a03dc541 100644 --- a/narwhals/_sql/dataframe.py +++ b/narwhals/_sql/dataframe.py @@ -40,7 +40,9 @@ def _evaluate_window_expr( raise MultiOutputExpressionError(msg) return result[0] - def _evaluate_expr(self, expr: SQLExpr[Self, NativeExprT], /) -> NativeExprT: + def _evaluate_single_output_expr( + self, expr: SQLExpr[Self, NativeExprT], / + ) -> NativeExprT: result = expr(self) if len(result) != 1: # pragma: no cover msg = "multi-output expressions not allowed in this context" diff --git a/narwhals/_sql/expr.py b/narwhals/_sql/expr.py index 608fb8f198..f286cc33f3 100644 --- a/narwhals/_sql/expr.py +++ b/narwhals/_sql/expr.py @@ -63,7 +63,7 @@ def _callable_to_eval_series( def func(df: SQLLazyFrameT) -> list[NativeExprT]: native_series_list = self(df) other_native_series = { - key: df._evaluate_expr(value) + key: df._evaluate_single_output_expr(value) for key, value in expressifiable_args.items() } return [ diff --git a/tests/dtypes_test.py b/tests/dtypes_test.py index 4ff9134c21..d63384647f 100644 --- a/tests/dtypes_test.py +++ b/tests/dtypes_test.py @@ -517,6 +517,7 @@ def test_datetime_w_tz_duckdb() -> None: assert result["b"] == nw.List(nw.List(nw.Datetime("us", "Asia/Kathmandu"))) +@pytest.mark.slow def test_datetime_w_tz_pyspark() -> None: # pragma: no cover pytest.importorskip("pyspark") session = pyspark_session() diff --git a/tests/frame/lazy_test.py b/tests/frame/lazy_test.py index 3b59f42d82..a8c3d4733c 100644 --- a/tests/frame/lazy_test.py +++ b/tests/frame/lazy_test.py @@ -55,6 +55,7 @@ def test_lazy_to_default(constructor_eager: ConstructorEager) -> None: assert isinstance(result.to_native(), expected_cls) +@pytest.mark.slow @pytest.mark.parametrize( "backend", [ From d0b3acde95438682c6d07086fd596737700d988e Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Thu, 9 Oct 2025 14:25:23 +0100 Subject: [PATCH 62/95] remove unnecessary extract_native --- narwhals/_pandas_like/utils.py | 5 +++-- narwhals/expr_str.py | 9 ++++----- narwhals/series_str.py | 24 +++++++++++++----------- 3 files changed, 20 insertions(+), 18 deletions(-) diff --git a/narwhals/_pandas_like/utils.py b/narwhals/_pandas_like/utils.py index ced44e3318..df9f44e127 100644 --- a/narwhals/_pandas_like/utils.py +++ b/narwhals/_pandas_like/utils.py @@ -167,9 +167,10 @@ def align_and_extract_native( ) return (lhs.native, rhs.native) - if isinstance(rhs, list): - msg = "Expected Series or scalar, got list." + if True: # isinstance(rhs, list): + msg = f"Expected Series or scalar, got {type(rhs)}." raise TypeError(msg) + # `rhs` must be scalar, so just leave it as-is return lhs.native, rhs diff --git a/narwhals/expr_str.py b/narwhals/expr_str.py index 6feb92e950..8054d04686 100644 --- a/narwhals/expr_str.py +++ b/narwhals/expr_str.py @@ -6,6 +6,7 @@ if TYPE_CHECKING: from narwhals.expr import Expr + from narwhals.typing import IntoExpr ExprT = TypeVar("ExprT", bound="Expr") @@ -41,7 +42,7 @@ def len_chars(self) -> ExprT: return self._expr._with_node(ExprNode(ExprKind.ELEMENTWISE, "str.len_chars")) def replace( - self, pattern: str, value: str | ExprT, *, literal: bool = False, n: int = 1 + self, pattern: str, value: str | IntoExpr, *, literal: bool = False, n: int = 1 ) -> ExprT: r"""Replace first matching regex/literal substring with a new string value. @@ -78,7 +79,7 @@ def replace( ) def replace_all( - self, pattern: str, value: str | ExprT, *, literal: bool = False + self, pattern: str, value: IntoExpr, *, literal: bool = False ) -> ExprT: r"""Replace all matching regex/literal substring with a new string value. @@ -490,9 +491,7 @@ def to_titlecase(self) -> ExprT: |└─────────────────────────┴─────────────────────────┘| └─────────────────────────────────────────────────────┘ """ - return self._expr._with_elementwise( - lambda plx: self._expr._to_compliant_expr(plx).str.to_titlecase() - ) + return self._expr._with_node(ExprNode(ExprKind.ELEMENTWISE, "str.to_titlecase")) def zfill(self, width: int) -> ExprT: """Transform string to zero-padded variant. diff --git a/narwhals/series_str.py b/narwhals/series_str.py index d4c32dc97e..ac1270e84c 100644 --- a/narwhals/series_str.py +++ b/narwhals/series_str.py @@ -1,8 +1,7 @@ from __future__ import annotations -from typing import Any, Generic +from typing import Generic -from narwhals.dependencies import is_narwhals_series from narwhals.typing import SeriesT @@ -31,9 +30,6 @@ def len_chars(self) -> SeriesT: self._narwhals_series._compliant_series.str.len_chars() ) - def _extract_compliant(self, arg: Any) -> Any: - return arg._compliant_series if is_narwhals_series(arg) else arg - def replace( self, pattern: str, value: str | SeriesT, *, literal: bool = False, n: int = 1 ) -> SeriesT: @@ -55,11 +51,14 @@ def replace( 1 abc123 dtype: object """ - return self._narwhals_series._with_compliant( - self._narwhals_series._compliant_series.str.replace( - self._extract_compliant(value), pattern=pattern, literal=literal, n=n + from narwhals.functions import col + + df = self._narwhals_series.to_frame().select( + col(self._narwhals_series.name).str.replace( + pattern=pattern, value=value, literal=literal, n=n ) ) + return df[self._narwhals_series.name] # type: ignore[return-value] def replace_all( self, pattern: str, value: str | SeriesT, *, literal: bool = False @@ -81,11 +80,14 @@ def replace_all( 1 123 dtype: object """ - return self._narwhals_series._with_compliant( - self._narwhals_series._compliant_series.str.replace_all( - self._extract_compliant(value), pattern, literal=literal + from narwhals.functions import col + + df = self._narwhals_series.to_frame().select( + col(self._narwhals_series.name).str.replace_all( + pattern=pattern, value=value, literal=literal ) ) + return df[self._narwhals_series.name] # type: ignore[return-value] def strip_chars(self, characters: str | None = None) -> SeriesT: r"""Remove leading and trailing characters. From a50395ac08568ec2e46f5996ee1372777c9f0d2f Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Thu, 9 Oct 2025 14:54:16 +0100 Subject: [PATCH 63/95] typing --- narwhals/_arrow/series_str.py | 12 +++--------- narwhals/_compliant/any_namespace.py | 4 ++-- narwhals/_dask/expr_str.py | 4 +--- narwhals/_duckdb/dataframe.py | 5 ++--- narwhals/_duckdb/namespace.py | 5 +++-- narwhals/_duckdb/utils.py | 6 ++++-- narwhals/_pandas_like/series_str.py | 4 ++-- narwhals/_pandas_like/utils.py | 2 +- 8 files changed, 18 insertions(+), 24 deletions(-) diff --git a/narwhals/_arrow/series_str.py b/narwhals/_arrow/series_str.py index 50f668d44f..a571d4eabd 100644 --- a/narwhals/_arrow/series_str.py +++ b/narwhals/_arrow/series_str.py @@ -24,7 +24,7 @@ def len_chars(self) -> ArrowSeries: return self.with_native(pc.utf8_length(self.native)) def replace( - self, value: ArrowSeries | str, pattern: str, *, literal: bool, n: int + self, value: ArrowSeries, pattern: str, *, literal: bool, n: int ) -> ArrowSeries: fn = pc.replace_substring if literal else pc.replace_substring_regex _, value_native = extract_native(self.compliant, value) @@ -37,15 +37,9 @@ def replace( return self.with_native(arr) def replace_all( - self, value: ArrowSeries | str, pattern: str, *, literal: bool + self, value: ArrowSeries, pattern: str, *, literal: bool ) -> ArrowSeries: - _, value_native = extract_native(self.compliant, value) - if not isinstance(value_native, pa.StringScalar): - msg = ( - "PyArrow backed `.str.replace_all` only supports str replacement values." - ) - raise TypeError(msg) - return self.replace(value_native.as_py(), pattern, literal=literal, n=-1) + return self.replace(value, pattern, literal=literal, n=-1) def strip_chars(self, characters: str | None) -> ArrowSeries: return self.with_native( diff --git a/narwhals/_compliant/any_namespace.py b/narwhals/_compliant/any_namespace.py index c2afb35e32..b7e48a273f 100644 --- a/narwhals/_compliant/any_namespace.py +++ b/narwhals/_compliant/any_namespace.py @@ -87,8 +87,8 @@ class StringNamespace(_StoresCompliant[T], Protocol[T]): _accessor: ClassVar[Accessor] = "str" def len_chars(self) -> T: ... - def replace(self, pattern: str, value: str, *, literal: bool, n: int) -> T: ... - def replace_all(self, pattern: str, value: str, *, literal: bool) -> T: ... + def replace(self, value: T, pattern: str, *, literal: bool, n: int) -> T: ... + def replace_all(self, value: T, pattern: str, *, literal: bool) -> T: ... def strip_chars(self, characters: str | None) -> T: ... def starts_with(self, prefix: str) -> T: ... def ends_with(self, suffix: str) -> T: ... diff --git a/narwhals/_dask/expr_str.py b/narwhals/_dask/expr_str.py index 04178a4f84..677761c58a 100644 --- a/narwhals/_dask/expr_str.py +++ b/narwhals/_dask/expr_str.py @@ -83,9 +83,7 @@ def to_lowercase(self) -> DaskExpr: return self.compliant._with_callable(lambda expr: expr.str.lower()) def to_titlecase(self) -> DaskExpr: - return self.compliant._with_callable( - lambda expr: expr.str.title(), "to_titlecase" - ) + return self.compliant._with_callable(lambda expr: expr.str.title()) def zfill(self, width: int) -> DaskExpr: return self.compliant._with_callable(lambda expr: expr.str.zfill(width)) diff --git a/narwhals/_duckdb/dataframe.py b/narwhals/_duckdb/dataframe.py index b1c87e9d97..e56c1bb3c8 100644 --- a/narwhals/_duckdb/dataframe.py +++ b/narwhals/_duckdb/dataframe.py @@ -488,9 +488,8 @@ def explode(self, columns: Sequence[str]) -> Self: rel = self.native original_columns = self.columns - not_null_condition = col_to_explode.isnotnull() & F("len", col_to_explode) > lit( - 0 - ) + zero = lit(0) + not_null_condition = col_to_explode.isnotnull() & F("len", col_to_explode) > zero non_null_rel = rel.filter(not_null_condition).select( *( F("unnest", col_to_explode).alias(name) if name in columns else name diff --git a/narwhals/_duckdb/namespace.py b/narwhals/_duckdb/namespace.py index 607aa53f94..8c406fb00b 100644 --- a/narwhals/_duckdb/namespace.py +++ b/narwhals/_duckdb/namespace.py @@ -139,10 +139,11 @@ def func(cols: Iterable[Expression]) -> Expression: def lit(self, value: NonNestedLiteral, dtype: IntoDType | None) -> DuckDBExpr: def func(df: DuckDBLazyFrame) -> list[Expression]: tz = DeferredTimeZone(df.native) + ret = lit(value) if dtype is not None: target = narwhals_to_native_dtype(dtype, self._version, tz) - return [lit(value).cast(target)] - return [lit(value)] + return [ret.cast(target)] + return [ret] def window_func( df: DuckDBLazyFrame, _window_inputs: WindowInputs[Expression] diff --git a/narwhals/_duckdb/utils.py b/narwhals/_duckdb/utils.py index 579ba14829..140b019995 100644 --- a/narwhals/_duckdb/utils.py +++ b/narwhals/_duckdb/utils.py @@ -45,8 +45,10 @@ col = duckdb.ColumnExpression """Alias for `duckdb.ColumnExpression`.""" -lit = duckdb.ConstantExpression -"""Alias for `duckdb.ConstantExpression`.""" + +def lit(value: object) -> duckdb.Expression: + return duckdb.ConstantExpression(value) + when = duckdb.CaseExpression """Alias for `duckdb.CaseExpression`.""" diff --git a/narwhals/_pandas_like/series_str.py b/narwhals/_pandas_like/series_str.py index 1c26c50fe4..d73c46dc20 100644 --- a/narwhals/_pandas_like/series_str.py +++ b/narwhals/_pandas_like/series_str.py @@ -20,7 +20,7 @@ def len_chars(self) -> PandasLikeSeries: return self.with_native(self.native.str.len()) def replace( - self, value: PandasLikeSeries | str, pattern: str, *, literal: bool, n: int + self, value: PandasLikeSeries, pattern: str, *, literal: bool, n: int ) -> PandasLikeSeries: _, value_native = align_and_extract_native(self.compliant, value) if not isinstance(value_native, str): @@ -32,7 +32,7 @@ def replace( return self.with_native(series) def replace_all( - self, value: PandasLikeSeries | str, pattern: str, *, literal: bool + self, value: PandasLikeSeries, pattern: str, *, literal: bool ) -> PandasLikeSeries: return self.replace(value, pattern, literal=literal, n=-1) diff --git a/narwhals/_pandas_like/utils.py b/narwhals/_pandas_like/utils.py index df9f44e127..61ca3926c4 100644 --- a/narwhals/_pandas_like/utils.py +++ b/narwhals/_pandas_like/utils.py @@ -167,7 +167,7 @@ def align_and_extract_native( ) return (lhs.native, rhs.native) - if True: # isinstance(rhs, list): + if isinstance(rhs, list): msg = f"Expected Series or scalar, got {type(rhs)}." raise TypeError(msg) From 9a356c7c197bdfb097682ba6c23269ffb0d10ed7 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Fri, 10 Oct 2025 08:24:01 +0100 Subject: [PATCH 64/95] revert `to_frame()/select/get_column` change (can do it later) --- narwhals/_duckdb/namespace.py | 5 ++--- narwhals/series_str.py | 24 +++++++++++------------- 2 files changed, 13 insertions(+), 16 deletions(-) diff --git a/narwhals/_duckdb/namespace.py b/narwhals/_duckdb/namespace.py index 8c406fb00b..607aa53f94 100644 --- a/narwhals/_duckdb/namespace.py +++ b/narwhals/_duckdb/namespace.py @@ -139,11 +139,10 @@ def func(cols: Iterable[Expression]) -> Expression: def lit(self, value: NonNestedLiteral, dtype: IntoDType | None) -> DuckDBExpr: def func(df: DuckDBLazyFrame) -> list[Expression]: tz = DeferredTimeZone(df.native) - ret = lit(value) if dtype is not None: target = narwhals_to_native_dtype(dtype, self._version, tz) - return [ret.cast(target)] - return [ret] + return [lit(value).cast(target)] + return [lit(value)] def window_func( df: DuckDBLazyFrame, _window_inputs: WindowInputs[Expression] diff --git a/narwhals/series_str.py b/narwhals/series_str.py index ac1270e84c..d4c32dc97e 100644 --- a/narwhals/series_str.py +++ b/narwhals/series_str.py @@ -1,7 +1,8 @@ from __future__ import annotations -from typing import Generic +from typing import Any, Generic +from narwhals.dependencies import is_narwhals_series from narwhals.typing import SeriesT @@ -30,6 +31,9 @@ def len_chars(self) -> SeriesT: self._narwhals_series._compliant_series.str.len_chars() ) + def _extract_compliant(self, arg: Any) -> Any: + return arg._compliant_series if is_narwhals_series(arg) else arg + def replace( self, pattern: str, value: str | SeriesT, *, literal: bool = False, n: int = 1 ) -> SeriesT: @@ -51,14 +55,11 @@ def replace( 1 abc123 dtype: object """ - from narwhals.functions import col - - df = self._narwhals_series.to_frame().select( - col(self._narwhals_series.name).str.replace( - pattern=pattern, value=value, literal=literal, n=n + return self._narwhals_series._with_compliant( + self._narwhals_series._compliant_series.str.replace( + self._extract_compliant(value), pattern=pattern, literal=literal, n=n ) ) - return df[self._narwhals_series.name] # type: ignore[return-value] def replace_all( self, pattern: str, value: str | SeriesT, *, literal: bool = False @@ -80,14 +81,11 @@ def replace_all( 1 123 dtype: object """ - from narwhals.functions import col - - df = self._narwhals_series.to_frame().select( - col(self._narwhals_series.name).str.replace_all( - pattern=pattern, value=value, literal=literal + return self._narwhals_series._with_compliant( + self._narwhals_series._compliant_series.str.replace_all( + self._extract_compliant(value), pattern, literal=literal ) ) - return df[self._narwhals_series.name] # type: ignore[return-value] def strip_chars(self, characters: str | None = None) -> SeriesT: r"""Remove leading and trailing characters. From 6e9eab463d51b765bc030e465501f0e1e7ecabc9 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Fri, 10 Oct 2025 08:50:40 +0100 Subject: [PATCH 65/95] __call__ -> _to_compliant_expr --- narwhals/_compliant/namespace.py | 7 +------ narwhals/_dask/dataframe.py | 8 ++++++++ narwhals/_dask/expr.py | 7 +++---- narwhals/_dask/utils.py | 6 ------ narwhals/_expression_parsing.py | 6 +++--- narwhals/_pandas_like/utils.py | 2 +- narwhals/_polars/namespace.py | 6 ------ narwhals/dataframe.py | 2 +- narwhals/expr.py | 4 +++- 9 files changed, 20 insertions(+), 28 deletions(-) diff --git a/narwhals/_compliant/namespace.py b/narwhals/_compliant/namespace.py index cec203ac0a..6f7062be35 100644 --- a/narwhals/_compliant/namespace.py +++ b/narwhals/_compliant/namespace.py @@ -1,7 +1,7 @@ from __future__ import annotations from functools import partial -from typing import TYPE_CHECKING, Any, Protocol, cast, overload +from typing import TYPE_CHECKING, Any, Protocol, overload from narwhals._compliant.typing import ( CompliantExprT, @@ -29,7 +29,6 @@ from narwhals._compliant.selectors import CompliantSelectorNamespace from narwhals._utils import Implementation, Version - from narwhals.expr import Expr from narwhals.typing import ( ConcatMethod, Into1DArray, @@ -56,10 +55,6 @@ class CompliantNamespace(Protocol[CompliantFrameT, CompliantExprT]): @property def _expr(self) -> type[CompliantExprT]: ... - def evaluate_expr(self, data: Expr, /) -> CompliantExprT: - ret = data(self) - return cast("CompliantExprT", ret) - # NOTE: `polars` def all(self) -> CompliantExprT: return self._expr.from_column_names(get_column_names, context=self) diff --git a/narwhals/_dask/dataframe.py b/narwhals/_dask/dataframe.py index 108cbf3b96..2f2a9e374b 100644 --- a/narwhals/_dask/dataframe.py +++ b/narwhals/_dask/dataframe.py @@ -18,6 +18,7 @@ parse_columns_to_drop, zip_strict, ) +from narwhals.exceptions import MultiOutputExpressionError from narwhals.typing import CompliantLazyFrame if TYPE_CHECKING: @@ -107,6 +108,13 @@ def _iter_columns(self) -> Iterator[dx.Series]: for _col, ser in self.native.items(): # noqa: PERF102 yield ser + def _evaluate_single_output_expr(self, obj: DaskExpr) -> dx.Series: + results = obj._call(self) + if len(results) != 1: # pragma: no cover + msg = "multi-output expressions not allowed in this context" + raise MultiOutputExpressionError(msg) + return results[0] + def with_columns(self, *exprs: DaskExpr) -> Self: new_series = evaluate_exprs(self, *exprs) return self._with_native(self.native.assign(**dict(new_series))) diff --git a/narwhals/_dask/expr.py b/narwhals/_dask/expr.py index e295ad35d1..aa4fc525b0 100644 --- a/narwhals/_dask/expr.py +++ b/narwhals/_dask/expr.py @@ -11,7 +11,6 @@ from narwhals._dask.utils import ( add_row_index, align_series_full_broadcast, - evaluate_expr, narwhals_to_native_dtype, ) from narwhals._expression_parsing import evaluate_output_names_and_aliases @@ -136,7 +135,7 @@ def func(df: DaskLazyFrame) -> list[dx.Series]: native_results: list[dx.Series] = [] native_series_list = self._call(df) other_native_series = { - key: evaluate_expr(df, value) + key: df._evaluate_single_output_expr(value) for key, value in expressifiable_args.items() } for native_series in native_series_list: @@ -212,7 +211,7 @@ def _floordiv( return (series.__floordiv__(other)).where(other != 0, None) def func(df: DaskLazyFrame) -> list[dx.Series]: - other_series = evaluate_expr(df, other) + other_series = df._evaluate_single_output_expr(other) return [_floordiv(df, series, other_series) for series in self(df)] return self.__class__( @@ -266,7 +265,7 @@ def _rfloordiv( return (other.__floordiv__(series)).where(series != 0, None) def func(df: DaskLazyFrame) -> list[dx.Series]: - other_native = evaluate_expr(df, other) + other_native = df._evaluate_single_output_expr(other) return [_rfloordiv(df, series, other_native) for series in self(df)] return self.__class__( diff --git a/narwhals/_dask/utils.py b/narwhals/_dask/utils.py index 1cc1d337c4..e43a45c9d4 100644 --- a/narwhals/_dask/utils.py +++ b/narwhals/_dask/utils.py @@ -23,12 +23,6 @@ import dask_expr as dx -def evaluate_expr(df: DaskLazyFrame, obj: DaskExpr) -> dx.Series: - results = obj._call(df) - assert len(results) == 1 # debug assertion # noqa: S101 - return results[0] - - def evaluate_exprs(df: DaskLazyFrame, /, *exprs: DaskExpr) -> list[tuple[str, dx.Series]]: native_results: list[tuple[str, dx.Series]] = [] for expr in exprs: diff --git a/narwhals/_expression_parsing.py b/narwhals/_expression_parsing.py index e817ae21ce..99018ed56b 100644 --- a/narwhals/_expression_parsing.py +++ b/narwhals/_expression_parsing.py @@ -752,9 +752,9 @@ def evaluate_into_exprs( allow_multi_output: bool, ) -> Iterator[CompliantExprAny]: for expr in exprs: - ret = ns.evaluate_expr( - _parse_into_expr(expr, str_as_lit=str_as_lit, backend=ns._implementation) - ) + ret = _parse_into_expr( + expr, str_as_lit=str_as_lit, backend=ns._implementation + )._to_compliant_expr(ns) if not allow_multi_output and ret._metadata.expansion_kind.is_multi_output(): msg = "Multi-output expressions are not allowed in this context." raise MultiOutputExpressionError(msg) diff --git a/narwhals/_pandas_like/utils.py b/narwhals/_pandas_like/utils.py index 61ca3926c4..d5a0d592fe 100644 --- a/narwhals/_pandas_like/utils.py +++ b/narwhals/_pandas_like/utils.py @@ -168,7 +168,7 @@ def align_and_extract_native( return (lhs.native, rhs.native) if isinstance(rhs, list): - msg = f"Expected Series or scalar, got {type(rhs)}." + msg = "Expected Series or scalar, got list." raise TypeError(msg) # `rhs` must be scalar, so just leave it as-is diff --git a/narwhals/_polars/namespace.py b/narwhals/_polars/namespace.py index 1c32c3f3cf..145fa89ed1 100644 --- a/narwhals/_polars/namespace.py +++ b/narwhals/_polars/namespace.py @@ -22,7 +22,6 @@ from narwhals._polars.dataframe import Method, PolarsDataFrame, PolarsLazyFrame from narwhals._polars.typing import FrameT from narwhals._utils import Version, _LimitedContext - from narwhals.expr import Expr from narwhals.typing import Into1DArray, IntoDType, IntoSchema, TimeUnit, _2DArray @@ -45,11 +44,6 @@ def _backend_version(self) -> tuple[int, ...]: def __init__(self, *, version: Version) -> None: self._version = version - def evaluate_expr(self, data: Expr, /) -> PolarsExpr: - expr = data(self) - assert isinstance(expr, self._expr) # noqa: S101 - return expr - def __getattr__(self, attr: str) -> Any: def func(*args: Any, **kwargs: Any) -> Any: pos, kwds = extract_args_kwargs(args, kwargs) diff --git a/narwhals/dataframe.py b/narwhals/dataframe.py index 02d1dec0bc..2deb6a2200 100644 --- a/narwhals/dataframe.py +++ b/narwhals/dataframe.py @@ -155,7 +155,7 @@ def _flatten_and_extract( (parse(expr).alias(alias) for alias, expr in named_exprs.items()), ) for expr in all_exprs: - ce = expr(ns) + ce = expr._to_compliant_expr(ns) out_exprs.append(ce) self._validate_metadata(ce._metadata) return out_exprs diff --git a/narwhals/expr.py b/narwhals/expr.py index 18cdb5ea6f..d01890763f 100644 --- a/narwhals/expr.py +++ b/narwhals/expr.py @@ -50,7 +50,9 @@ class Expr: def __init__(self, *nodes: ExprNode) -> None: self._nodes = nodes - def __call__(self, ns: CompliantNamespace[Any, Any]) -> CompliantExpr[Any, Any]: + def _to_compliant_expr( + self, ns: CompliantNamespace[Any, Any] + ) -> CompliantExpr[Any, Any]: nodes = self._nodes ce = evaluate_root_node(nodes[0], ns) for node in nodes[1:]: From a8bfefe75d75e435056e454cb5c963765fe1a521 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Fri, 10 Oct 2025 09:40:27 +0100 Subject: [PATCH 66/95] skip as appropriate --- tests/expr_and_series/when_test.py | 2 ++ tests/expression_parsing_test.py | 4 +++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/expr_and_series/when_test.py b/tests/expr_and_series/when_test.py index 46d036bdbf..cbb908bf64 100644 --- a/tests/expr_and_series/when_test.py +++ b/tests/expr_and_series/when_test.py @@ -116,6 +116,8 @@ def test_when_then_otherwise_into_expr(constructor: Constructor) -> None: def test_when_then_broadcasting(constructor: Constructor) -> None: + if "duckdb" in str(constructor) and DUCKDB_VERSION < (1, 3): + pytest.skip() df = nw.from_native(constructor(data)) result = df.select(nw.when(nw.col("a").sum() > 1).then("c")) expected = {"c": [4.1, 5, 6]} diff --git a/tests/expression_parsing_test.py b/tests/expression_parsing_test.py index 5dca4f8c16..2c21935df7 100644 --- a/tests/expression_parsing_test.py +++ b/tests/expression_parsing_test.py @@ -4,7 +4,7 @@ import narwhals as nw from narwhals.exceptions import InvalidOperationError -from tests.utils import POLARS_VERSION, Constructor, assert_equal_data +from tests.utils import DUCKDB_VERSION, POLARS_VERSION, Constructor, assert_equal_data @pytest.mark.parametrize( @@ -56,6 +56,8 @@ def test_over_pushdown( ) -> None: if "polars" in str(constructor) and POLARS_VERSION < (1, 10): pytest.skip() + if "duckdb" in str(constructor) and DUCKDB_VERSION < (1, 3): + pytest.skip() data = {"a": [-1, 2, 3], "b": [1, 1, 2], "i": [0, 1, 2]} df = nw.from_native(constructor(data)).lazy() result = df.select("i", a=expr).sort("i").select("a") From 8ee0100013ff7c130c876eabded1d054ff3be4fb Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Fri, 10 Oct 2025 09:44:35 +0100 Subject: [PATCH 67/95] coverage, docs --- docs/how_it_works.md | 10 +++++----- narwhals/_compliant/expr.py | 17 ----------------- 2 files changed, 5 insertions(+), 22 deletions(-) diff --git a/docs/how_it_works.md b/docs/how_it_works.md index 8bb7da8e6b..97a2291e53 100644 --- a/docs/how_it_works.md +++ b/docs/how_it_works.md @@ -76,7 +76,7 @@ pn = PandasLikeNamespace( implementation=Implementation.PANDAS, version=Version.MAIN, ) -print(nw.col("a")(pn)) +print(nw.col("a")._to_compliant_expr(pn)) ``` The result from the last line above is the same as we'd get from `pn.col('a')`, and it's @@ -213,7 +213,7 @@ pn = PandasLikeNamespace( implementation=Implementation.PANDAS, version=Version.MAIN, ) -expr = (nw.col("a") + 1)(pn) +expr = (nw.col("a") + 1)._to_compliant_expr(pn) print(expr) ``` @@ -327,9 +327,9 @@ Let's try printing out some compliant expressions' metadata to see what it shows ```python exec="1" result="python" session="pandas_impl" source="above" import narwhals as nw -print(nw.col("a")(pn)._metadata) -print(nw.col("a").mean()(pn)._metadata) -print(nw.col("a").mean().over("b")(pn)._metadata) +print(nw.col("a")._to_compliant_expr(pn)._metadata) +print(nw.col("a").mean()._to_compliant_expr(pn)._metadata) +print(nw.col("a").mean().over("b")._to_compliant_expr(pn)._metadata) ``` This section is all about making sense of what that all means, what the rules are, and what it enables. diff --git a/narwhals/_compliant/expr.py b/narwhals/_compliant/expr.py index 73d9f9481e..77deca4133 100644 --- a/narwhals/_compliant/expr.py +++ b/narwhals/_compliant/expr.py @@ -53,7 +53,6 @@ IntoDType, ModeKeepStrategy, NonNestedLiteral, - NumericLiteral, RankMethod, RollingInterpolationMethod, TimeUnit, @@ -797,22 +796,6 @@ def is_between( "is_between", lower_bound=lower_bound, upper_bound=upper_bound, closed=closed ) - def is_close( - self, - other: Self | NumericLiteral, - *, - abs_tol: float, - rel_tol: float, - nans_equal: bool, - ) -> Self: - return self._reuse_series( - "is_close", - other=other, - abs_tol=abs_tol, - rel_tol=rel_tol, - nans_equal=nans_equal, - ) - def first(self) -> Self: return self._reuse_series("first", returns_scalar=True) From 52d58c85e666713af4ef55b23f08dde3345545ce Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Fri, 10 Oct 2025 09:50:11 +0100 Subject: [PATCH 68/95] one more --- tests/expression_parsing_test.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/expression_parsing_test.py b/tests/expression_parsing_test.py index 2c21935df7..b5c6b722f1 100644 --- a/tests/expression_parsing_test.py +++ b/tests/expression_parsing_test.py @@ -76,6 +76,8 @@ def test_per_group_broadcasting( if "dask" in str(constructor): # sigh... request.applymarker(pytest.mark.xfail) + if "duckdb" in str(constructor) and DUCKDB_VERSION < (1, 3): + pytest.skip() data = {"a": [-1, 2, 3], "b": [1, 1, 2], "i": [0, 1, 2]} df = nw.from_native(constructor(data)).lazy() result = df.select("i", a=expr).sort("i").select("a") From e987618cec4baca38eff40488992b6bbce0307ac Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sat, 11 Oct 2025 13:18:06 +0100 Subject: [PATCH 69/95] update docs, use `cls`, coverage, take over_node_order_by and partition by out of loop, list comprehension in repr --- docs/how_it_works.md | 2 +- narwhals/_expression_parsing.py | 26 ++++++++++++++------------ narwhals/expr.py | 5 +---- tests/v1_test.py | 4 +++- 4 files changed, 19 insertions(+), 18 deletions(-) diff --git a/docs/how_it_works.md b/docs/how_it_works.md index 97a2291e53..7adff7a0a5 100644 --- a/docs/how_it_works.md +++ b/docs/how_it_works.md @@ -178,7 +178,7 @@ The way you access the Narwhals-compliant wrapper depends on the object: - `narwhals.DataFrame` and `narwhals.LazyFrame`: use the `._compliant_frame` attribute. - `narwhals.Series`: use the `._compliant_series` attribute. -- `narwhals.Expr`: call the `.__call__` method, and pass to it the Narwhals-compliant namespace associated with +- `narwhals.Expr`: call the `._to_compliant_expr` method, and pass to it the Narwhals-compliant namespace associated with the given backend. 🛑 BUT WAIT! What's a Narwhals-compliant namespace? diff --git a/narwhals/_expression_parsing.py b/narwhals/_expression_parsing.py index 99018ed56b..47aa09b54c 100644 --- a/narwhals/_expression_parsing.py +++ b/narwhals/_expression_parsing.py @@ -269,14 +269,16 @@ def _push_down_over_node_in_place( exprs: list[IntoExpr | NonNestedLiteral] = [] # Note: please keep this as a for-loop (rather than a list-comprehension) # so that pytest-cov highlights any uncovered branches. + over_node_order_by = over_node.kwargs["order_by"] + over_node_partition_by = over_node.kwargs["partition_by"] for expr in self.exprs: if not is_expr(expr): exprs.append(expr) - elif over_node.kwargs["order_by"] and any( + elif over_node_order_by and any( expr_node.is_orderable() for expr_node in expr._nodes ): exprs.append(expr._with_node(over_node)) - elif over_node_without_order_by.kwargs["partition_by"] and not all( + elif over_node_partition_by and not all( expr_node.is_elementwise() for expr_node in expr._nodes ): exprs.append(expr._with_node(over_node_without_order_by)) @@ -394,26 +396,26 @@ def from_node( # noqa: PLR0911 return cls.from_selector_single(node) if node.kind is ExprKind.COL: return ( - ExprMetadata.from_selector_single(node) + cls.from_selector_single(node) if len(node.kwargs["names"]) == 1 - else ExprMetadata.from_selector_multi_named(node) + else cls.from_selector_multi_named(node) ) if node.kind is ExprKind.NTH: return ( - ExprMetadata.from_selector_single(node) + cls.from_selector_single(node) if len(node.kwargs["indices"]) == 1 - else ExprMetadata.from_selector_multi_unnamed(node) + else cls.from_selector_multi_unnamed(node) ) if node.kind in {ExprKind.ALL, ExprKind.EXCLUDE}: - return ExprMetadata.from_selector_multi_unnamed(node) + return cls.from_selector_multi_unnamed(node) if node.kind is ExprKind.AGGREGATION: - return ExprMetadata.from_aggregation(node) + return cls.from_aggregation(node) if node.kind is ExprKind.LITERAL: - return ExprMetadata.from_literal(node) + return cls.from_literal(node) if node.kind is ExprKind.SELECTOR: - return ExprMetadata.from_selector_multi_unnamed(node) + return cls.from_selector_multi_unnamed(node) if node.kind is ExprKind.ELEMENTWISE: - return ExprMetadata.from_elementwise(node, *ces) + return cls.from_elementwise(node, *ces) msg = f"Unexpected node kind: {node.kind}" # pragma: no cover raise AssertionError(msg) # pragma: no cover @@ -630,7 +632,7 @@ def with_filtration(self, node: ExprNode) -> ExprMetadata: ) def with_orderable_filtration(self, node: ExprNode) -> ExprMetadata: - if self.is_scalar_like: # pragma: no cover + if self.is_scalar_like: msg = "Can't apply filtration (e.g. `drop_nulls`) to scalar-like expression." raise InvalidOperationError(msg) return ExprMetadata( diff --git a/narwhals/expr.py b/narwhals/expr.py index d01890763f..8a608e4370 100644 --- a/narwhals/expr.py +++ b/narwhals/expr.py @@ -94,10 +94,7 @@ def _with_node(self, node: ExprNode) -> Self: def __repr__(self) -> str: """Pretty-print the expression by combining all nodes in the metadata.""" - result: str = repr(self._nodes[0]) - for node in self._nodes[1:]: - result = f"{result}.{node!r}" - return result + return ".".join(repr(node) for node in self._nodes) def __bool__(self) -> NoReturn: msg = ( diff --git a/tests/v1_test.py b/tests/v1_test.py index 2a9e9660d2..e656fbbe3c 100644 --- a/tests/v1_test.py +++ b/tests/v1_test.py @@ -901,10 +901,12 @@ def test_unique_series_v1() -> None: series.to_frame().select(nw_v1.col("a").unique(maintain_order=False).sum()) -def test_head_aggregation() -> None: +def test_invalid() -> None: df = nw.from_native(pd.DataFrame({"a": [1, 2]})) with pytest.raises(InvalidOperationError): df.select(nw_v1.col("a").mean().head()) + with pytest.raises(InvalidOperationError): + df.select(nw_v1.col("a").mean().arg_true()) def test_deprecated_expr_methods() -> None: From edc117deff9598e940f712d76b3547314ce8bf83 Mon Sep 17 00:00:00 2001 From: FBruzzesi Date: Sat, 11 Oct 2025 16:40:19 +0200 Subject: [PATCH 70/95] WIP: refactor ExprMetadata --- narwhals/_expression_parsing.py | 87 +++++++++++++++++++++++---------- 1 file changed, 61 insertions(+), 26 deletions(-) diff --git a/narwhals/_expression_parsing.py b/narwhals/_expression_parsing.py index 99018ed56b..324b438667 100644 --- a/narwhals/_expression_parsing.py +++ b/narwhals/_expression_parsing.py @@ -330,17 +330,20 @@ class ExprMetadata: lazy case, this number must be `0` by the time the expression is evaluated. preserves_length: Whether the expression preserves the input length. + current_node: The current ExprNode in the linked list. + prev: Reference to the previous ExprMetadata in the linked list (None for root). """ __slots__ = ( + "current_node", "expansion_kind", "has_windows", "is_elementwise", "is_literal", "is_scalar_like", "n_orderable_ops", - "nodes", "preserves_length", + "prev", ) def __init__( @@ -353,7 +356,8 @@ def __init__( is_elementwise: bool = True, is_scalar_like: bool = False, is_literal: bool = False, - nodes: tuple[ExprNode, ...], + current_node: ExprNode, + prev: ExprMetadata | None = None, ) -> None: if is_literal: assert is_scalar_like # noqa: S101 # debug assertion @@ -366,7 +370,8 @@ def __init__( self.preserves_length: bool = preserves_length self.is_scalar_like: bool = is_scalar_like self.is_literal: bool = is_literal - self.nodes: tuple[ExprNode, ...] = nodes + self.current_node: ExprNode = current_node + self.prev: ExprMetadata | None = prev def __init_subclass__(cls, /, *args: Any, **kwds: Any) -> Never: # pragma: no cover msg = f"Cannot subclass {cls.__name__!r}" @@ -382,10 +387,27 @@ def __repr__(self) -> str: # pragma: no cover f" preserves_length: {self.preserves_length},\n" f" is_scalar_like: {self.is_scalar_like},\n" f" is_literal: {self.is_literal},\n" - f" nodes: {self.nodes},\n" + f" nodes: {tuple(self.iter_nodes())},\n" ")" ) + def iter_nodes(self) -> Iterator[ExprNode]: + """Iterate through all nodes from root to current.""" + nodes: list[ExprNode] = [] + current: ExprMetadata | None = self + while current is not None: + nodes.append(current.current_node) + current = current.prev + # Reverse to get from root to current + return iter(reversed(nodes)) + + def iter_nodes_reversed(self) -> Iterator[ExprNode]: + """Iterate through all nodes from current to root.""" + current: ExprMetadata | None = self + while current is not None: + yield current.current_node + current = current.prev + @classmethod def from_node( # noqa: PLR0911 cls, node: ExprNode, *ces: CompliantExprAny @@ -423,9 +445,7 @@ def with_node( # noqa: PLR0911,C901 if node.kind is ExprKind.AGGREGATION: return self.with_aggregation(node) if node.kind is ExprKind.ELEMENTWISE: - return combine_metadata( - ce, *ces, to_single_output=False, nodes=(*ce._metadata.nodes, node) - ) + return combine_metadata(ce, *ces, to_single_output=False, current_node=node) if node.kind is ExprKind.FILTRATION: return self.with_filtration(node) if node.kind is ExprKind.ORDERABLE_WINDOW: @@ -453,7 +473,8 @@ def from_aggregation(cls, node: ExprNode) -> ExprMetadata: is_elementwise=False, preserves_length=False, is_scalar_like=True, - nodes=(node,), + current_node=node, + prev=None, ) @classmethod @@ -464,27 +485,28 @@ def from_literal(cls, node: ExprNode) -> ExprMetadata: preserves_length=False, is_literal=True, is_scalar_like=True, - nodes=(node,), + current_node=node, + prev=None, ) @classmethod def from_selector_single(cls, node: ExprNode) -> ExprMetadata: # e.g. `nw.col('a')`, `nw.nth(0)` - return cls(ExpansionKind.SINGLE, nodes=(node,)) + return cls(ExpansionKind.SINGLE, current_node=node, prev=None) @classmethod def from_selector_multi_named(cls, node: ExprNode) -> ExprMetadata: # e.g. `nw.col('a', 'b')` - return cls(ExpansionKind.MULTI_NAMED, nodes=(node,)) + return cls(ExpansionKind.MULTI_NAMED, current_node=node, prev=None) @classmethod def from_selector_multi_unnamed(cls, node: ExprNode) -> ExprMetadata: # e.g. `nw.all()` - return cls(ExpansionKind.MULTI_UNNAMED, nodes=(node,)) + return cls(ExpansionKind.MULTI_UNNAMED, current_node=node, prev=None) @classmethod def from_elementwise(cls, node: ExprNode, *ces: CompliantExprAny) -> ExprMetadata: - return combine_metadata(*ces, to_single_output=True, nodes=(node,)) + return combine_metadata(*ces, to_single_output=True, current_node=node) @property def is_filtration(self) -> bool: @@ -502,7 +524,8 @@ def with_aggregation(self, node: ExprNode) -> ExprMetadata: is_elementwise=False, is_scalar_like=True, is_literal=False, - nodes=(*self.nodes, node), + current_node=node, + prev=self, ) def with_orderable_aggregation(self, node: ExprNode) -> ExprMetadata: @@ -518,7 +541,8 @@ def with_orderable_aggregation(self, node: ExprNode) -> ExprMetadata: is_elementwise=False, is_scalar_like=True, is_literal=False, - nodes=(*self.nodes, node), + current_node=node, + prev=self, ) def with_window(self, node: ExprNode) -> ExprMetadata: @@ -536,7 +560,8 @@ def with_window(self, node: ExprNode) -> ExprMetadata: is_elementwise=False, is_scalar_like=False, is_literal=False, - nodes=(*self.nodes, node), + current_node=node, + prev=self, ) def with_orderable_window(self, node: ExprNode) -> ExprMetadata: @@ -552,7 +577,8 @@ def with_orderable_window(self, node: ExprNode) -> ExprMetadata: is_elementwise=False, is_scalar_like=False, is_literal=False, - nodes=(*self.nodes, node), + current_node=node, + prev=self, ) def with_ordered_over(self, node: ExprNode) -> ExprMetadata: @@ -590,7 +616,8 @@ def with_ordered_over(self, node: ExprNode) -> ExprMetadata: is_elementwise=False, is_scalar_like=False, is_literal=False, - nodes=(*self.nodes, node), + current_node=node, + prev=self, ) def with_partitioned_over(self, node: ExprNode) -> ExprMetadata: @@ -611,7 +638,8 @@ def with_partitioned_over(self, node: ExprNode) -> ExprMetadata: is_elementwise=False, is_scalar_like=False, is_literal=False, - nodes=(*self.nodes, node), + current_node=node, + prev=self, ) def with_filtration(self, node: ExprNode) -> ExprMetadata: @@ -626,7 +654,8 @@ def with_filtration(self, node: ExprNode) -> ExprMetadata: is_elementwise=False, is_scalar_like=False, is_literal=False, - nodes=(*self.nodes, node), + current_node=node, + prev=self, ) def with_orderable_filtration(self, node: ExprNode) -> ExprMetadata: @@ -641,19 +670,20 @@ def with_orderable_filtration(self, node: ExprNode) -> ExprMetadata: is_elementwise=False, is_scalar_like=False, is_literal=False, - nodes=(*self.nodes, node), + current_node=node, + prev=self, ) def op_nodes_reversed(self) -> Iterator[ExprNode]: - for node in reversed(self.nodes): - if node.name.startswith("name.") or node.name == "alias": + for node in self.iter_nodes_reversed(): + if node.name.startswith(("name.", "alias")): # Skip nodes which only do aliasing. continue yield node def combine_metadata( - *args: CompliantExprAny, to_single_output: bool, nodes: tuple[ExprNode, ...] + *args: CompliantExprAny, to_single_output: bool, current_node: ExprNode ) -> ExprMetadata: """Combine metadata from `args`. @@ -661,7 +691,7 @@ def combine_metadata( args: Arguments, maybe expressions, literals, or Series. to_single_output: Whether the result is always single-output, regardless of the inputs (e.g. `nw.sum_horizontal`). - nodes: Nodes of result node. + current_node: The current node being added. """ n_filtrations = 0 result_expansion_kind = ExpansionKind.SINGLE @@ -675,10 +705,14 @@ def combine_metadata( result_is_scalar_like = True # result is literal if all inputs are literal result_is_literal = True + # Keep reference to first argument's metadata to use as prev + first_metadata: ExprMetadata | None = None for i, arg in enumerate(args): metadata = arg._metadata assert metadata is not None # noqa: S101 + if i == 0: + first_metadata = metadata if metadata.expansion_kind.is_multi_output(): expansion_kind = metadata.expansion_kind if not to_single_output: @@ -707,7 +741,8 @@ def combine_metadata( is_elementwise=result_is_elementwise, is_scalar_like=result_is_scalar_like, is_literal=result_is_literal, - nodes=nodes, + current_node=current_node, + prev=first_metadata, ) From 639b9965a20498befb896b6e319c795c2e49fb63 Mon Sep 17 00:00:00 2001 From: FBruzzesi Date: Sat, 11 Oct 2025 17:03:56 +0200 Subject: [PATCH 71/95] kind = node.kind in from_node and with_node --- narwhals/_expression_parsing.py | 56 +++++++++++++++++---------------- 1 file changed, 29 insertions(+), 27 deletions(-) diff --git a/narwhals/_expression_parsing.py b/narwhals/_expression_parsing.py index 324b438667..e06c9d3983 100644 --- a/narwhals/_expression_parsing.py +++ b/narwhals/_expression_parsing.py @@ -412,58 +412,60 @@ def iter_nodes_reversed(self) -> Iterator[ExprNode]: def from_node( # noqa: PLR0911 cls, node: ExprNode, *ces: CompliantExprAny ) -> ExprMetadata: - if node.kind is ExprKind.SERIES: + kind = node.kind + if kind is ExprKind.SERIES: return cls.from_selector_single(node) - if node.kind is ExprKind.COL: + if kind is ExprKind.COL: return ( - ExprMetadata.from_selector_single(node) + cls.from_selector_single(node) if len(node.kwargs["names"]) == 1 - else ExprMetadata.from_selector_multi_named(node) + else cls.from_selector_multi_named(node) ) - if node.kind is ExprKind.NTH: + if kind is ExprKind.NTH: return ( - ExprMetadata.from_selector_single(node) + cls.from_selector_single(node) if len(node.kwargs["indices"]) == 1 - else ExprMetadata.from_selector_multi_unnamed(node) + else cls.from_selector_multi_unnamed(node) ) - if node.kind in {ExprKind.ALL, ExprKind.EXCLUDE}: - return ExprMetadata.from_selector_multi_unnamed(node) - if node.kind is ExprKind.AGGREGATION: - return ExprMetadata.from_aggregation(node) - if node.kind is ExprKind.LITERAL: - return ExprMetadata.from_literal(node) - if node.kind is ExprKind.SELECTOR: - return ExprMetadata.from_selector_multi_unnamed(node) - if node.kind is ExprKind.ELEMENTWISE: - return ExprMetadata.from_elementwise(node, *ces) - msg = f"Unexpected node kind: {node.kind}" # pragma: no cover + if kind in {ExprKind.ALL, ExprKind.EXCLUDE}: + return cls.from_selector_multi_unnamed(node) + if kind is ExprKind.AGGREGATION: + return cls.from_aggregation(node) + if kind is ExprKind.LITERAL: + return cls.from_literal(node) + if kind is ExprKind.SELECTOR: + return cls.from_selector_multi_unnamed(node) + if kind is ExprKind.ELEMENTWISE: + return cls.from_elementwise(node, *ces) + msg = f"Unexpected node kind: {kind}" # pragma: no cover raise AssertionError(msg) # pragma: no cover def with_node( # noqa: PLR0911,C901 self, node: ExprNode, ce: CompliantExprAny, *ces: CompliantExprAny ) -> ExprMetadata: - if node.kind is ExprKind.AGGREGATION: + kind = node.kind + if kind is ExprKind.AGGREGATION: return self.with_aggregation(node) - if node.kind is ExprKind.ELEMENTWISE: + if kind is ExprKind.ELEMENTWISE: return combine_metadata(ce, *ces, to_single_output=False, current_node=node) - if node.kind is ExprKind.FILTRATION: + if kind is ExprKind.FILTRATION: return self.with_filtration(node) - if node.kind is ExprKind.ORDERABLE_WINDOW: + if kind is ExprKind.ORDERABLE_WINDOW: return self.with_orderable_window(node) - if node.kind is ExprKind.ORDERABLE_FILTRATION: + if kind is ExprKind.ORDERABLE_FILTRATION: return self.with_orderable_filtration(node) - if node.kind is ExprKind.ORDERABLE_AGGREGATION: + if kind is ExprKind.ORDERABLE_AGGREGATION: return self.with_orderable_aggregation(node) - if node.kind is ExprKind.WINDOW: + if kind is ExprKind.WINDOW: return self.with_window(node) - if node.kind is ExprKind.OVER: + if kind is ExprKind.OVER: if node.kwargs["order_by"]: return self.with_ordered_over(node) if not node.kwargs["partition_by"]: # pragma: no cover msg = "At least one of `partition_by` or `order_by` must be specified." raise InvalidOperationError(msg) return self.with_partitioned_over(node) - msg = f"Unexpected node kind: {node.kind}" # pragma: no cover + msg = f"Unexpected node kind: {kind}" # pragma: no cover raise AssertionError(msg) # pragma: no cover @classmethod From 52dd7f6501e4f5ca0ace8ad21692df44ddcb67b2 Mon Sep 17 00:00:00 2001 From: Francesco Bruzzesi <42817048+FBruzzesi@users.noreply.github.com> Date: Sat, 11 Oct 2025 17:55:11 +0200 Subject: [PATCH 72/95] no cover `iter_nodes` --- narwhals/_expression_parsing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/narwhals/_expression_parsing.py b/narwhals/_expression_parsing.py index 41e05d9874..bf57376c1f 100644 --- a/narwhals/_expression_parsing.py +++ b/narwhals/_expression_parsing.py @@ -393,7 +393,7 @@ def __repr__(self) -> str: # pragma: no cover ")" ) - def iter_nodes(self) -> Iterator[ExprNode]: + def iter_nodes(self) -> Iterator[ExprNode]: # pragma: no cover """Iterate through all nodes from root to current.""" nodes: list[ExprNode] = [] current: ExprMetadata | None = self From e8ab3f6ef44e7b26fae5aad1441357e36a612b35 Mon Sep 17 00:00:00 2001 From: FBruzzesi Date: Sat, 11 Oct 2025 18:45:03 +0200 Subject: [PATCH 73/95] rm iter_nodes method --- narwhals/_expression_parsing.py | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/narwhals/_expression_parsing.py b/narwhals/_expression_parsing.py index bf57376c1f..5314461e76 100644 --- a/narwhals/_expression_parsing.py +++ b/narwhals/_expression_parsing.py @@ -380,6 +380,7 @@ def __init_subclass__(cls, /, *args: Any, **kwds: Any) -> Never: # pragma: no c raise TypeError(msg) def __repr__(self) -> str: # pragma: no cover + nodes = tuple(reversed(tuple(self.iter_nodes_reversed()))) return ( f"ExprMetadata(\n" f" expansion_kind: {self.expansion_kind},\n" @@ -389,20 +390,10 @@ def __repr__(self) -> str: # pragma: no cover f" preserves_length: {self.preserves_length},\n" f" is_scalar_like: {self.is_scalar_like},\n" f" is_literal: {self.is_literal},\n" - f" nodes: {tuple(self.iter_nodes())},\n" + f" nodes: {nodes},\n" ")" ) - def iter_nodes(self) -> Iterator[ExprNode]: # pragma: no cover - """Iterate through all nodes from root to current.""" - nodes: list[ExprNode] = [] - current: ExprMetadata | None = self - while current is not None: - nodes.append(current.current_node) - current = current.prev - # Reverse to get from root to current - return iter(reversed(nodes)) - def iter_nodes_reversed(self) -> Iterator[ExprNode]: """Iterate through all nodes from current to root.""" current: ExprMetadata | None = self From 3ff84acc17242796da742b01b91deb7194f95473 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sat, 11 Oct 2025 19:20:35 +0100 Subject: [PATCH 74/95] pass `prev` to `combine_metadata` --- narwhals/_expression_parsing.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/narwhals/_expression_parsing.py b/narwhals/_expression_parsing.py index 5314461e76..5729601b06 100644 --- a/narwhals/_expression_parsing.py +++ b/narwhals/_expression_parsing.py @@ -440,7 +440,9 @@ def with_node( # noqa: PLR0911,C901 if kind is ExprKind.AGGREGATION: return self.with_aggregation(node) if kind is ExprKind.ELEMENTWISE: - return combine_metadata(ce, *ces, to_single_output=False, current_node=node) + return combine_metadata( + ce, *ces, to_single_output=False, current_node=node, prev=ce._metadata + ) if kind is ExprKind.FILTRATION: return self.with_filtration(node) if kind is ExprKind.ORDERABLE_WINDOW: @@ -501,7 +503,7 @@ def from_selector_multi_unnamed(cls, node: ExprNode) -> ExprMetadata: @classmethod def from_elementwise(cls, node: ExprNode, *ces: CompliantExprAny) -> ExprMetadata: - return combine_metadata(*ces, to_single_output=True, current_node=node) + return combine_metadata(*ces, to_single_output=True, current_node=node, prev=None) @property def is_filtration(self) -> bool: @@ -678,7 +680,10 @@ def op_nodes_reversed(self) -> Iterator[ExprNode]: def combine_metadata( - *args: CompliantExprAny, to_single_output: bool, current_node: ExprNode + *args: CompliantExprAny, + to_single_output: bool, + current_node: ExprNode, + prev: ExprMetadata | None, ) -> ExprMetadata: """Combine metadata from `args`. @@ -687,6 +692,7 @@ def combine_metadata( to_single_output: Whether the result is always single-output, regardless of the inputs (e.g. `nw.sum_horizontal`). current_node: The current node being added. + prev: ExprMetadata of previous node. """ n_filtrations = 0 result_expansion_kind = ExpansionKind.SINGLE @@ -700,14 +706,10 @@ def combine_metadata( result_is_scalar_like = True # result is literal if all inputs are literal result_is_literal = True - # Keep reference to first argument's metadata to use as prev - first_metadata: ExprMetadata | None = None for i, arg in enumerate(args): metadata = arg._metadata assert metadata is not None # noqa: S101 - if i == 0: - first_metadata = metadata if metadata.expansion_kind.is_multi_output(): expansion_kind = metadata.expansion_kind if not to_single_output: @@ -737,7 +739,7 @@ def combine_metadata( is_scalar_like=result_is_scalar_like, is_literal=result_is_literal, current_node=current_node, - prev=first_metadata, + prev=prev, ) From b0a78e31c3b796fd41f7536d342e729a17d047b4 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sun, 12 Oct 2025 19:19:47 +0100 Subject: [PATCH 75/95] fixup --- narwhals/_arrow/series.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/narwhals/_arrow/series.py b/narwhals/_arrow/series.py index 058107f0c9..52c10bd863 100644 --- a/narwhals/_arrow/series.py +++ b/narwhals/_arrow/series.py @@ -488,7 +488,9 @@ def to_numpy(self, dtype: Any = None, *, copy: bool | None = None) -> _1DArray: return self.native.to_numpy() def alias(self, name: str) -> Self: - return self.__class__(self.native, name=name, version=self._version) + ret = self.__class__(self.native, name=name, version=self._version) + ret._broadcast = self._broadcast + return ret @property def dtype(self) -> DType: From 1962d5329de5f956b080b91587ffc82de745050d Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Thu, 16 Oct 2025 10:45:27 +0100 Subject: [PATCH 76/95] ceil, floor --- narwhals/expr.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/narwhals/expr.py b/narwhals/expr.py index bb700b9c2c..d5f6e1a5f7 100644 --- a/narwhals/expr.py +++ b/narwhals/expr.py @@ -1494,7 +1494,7 @@ def floor(self) -> Self: |floor: [[1,4,-2]] | └────────────────────────┘ """ - return self._with_elementwise(lambda plx: self._to_compliant_expr(plx).floor()) + return self._with_node(ExprNode(ExprKind.ELEMENTWISE, "floor")) def ceil(self) -> Self: r"""Compute the numerical ceiling. @@ -1517,7 +1517,7 @@ def ceil(self) -> Self: |ceil: [[2,5,-1]] | └────────────────────────┘ """ - return self._with_elementwise(lambda plx: self._to_compliant_expr(plx).ceil()) + return self._with_node(ExprNode(ExprKind.ELEMENTWISE, "ceil")) def len(self) -> Self: r"""Return the number of elements in the column. From eb1f74ec0e9d7b01850476138528f8d8eff34c93 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Thu, 16 Oct 2025 10:51:17 +0100 Subject: [PATCH 77/95] split out `_with_node` into `_with_over_node` and `_append_node` --- narwhals/_expression_parsing.py | 4 +- narwhals/expr.py | 189 ++++++++++++++++---------------- narwhals/expr_cat.py | 4 +- narwhals/expr_dt.py | 46 ++++---- narwhals/expr_list.py | 8 +- narwhals/expr_name.py | 16 ++- narwhals/expr_str.py | 34 +++--- narwhals/expr_struct.py | 2 +- narwhals/selectors.py | 10 +- narwhals/stable/v1/__init__.py | 18 +-- 10 files changed, 169 insertions(+), 162 deletions(-) diff --git a/narwhals/_expression_parsing.py b/narwhals/_expression_parsing.py index 5729601b06..34ef3dd5d1 100644 --- a/narwhals/_expression_parsing.py +++ b/narwhals/_expression_parsing.py @@ -277,11 +277,11 @@ def _push_down_over_node_in_place( elif over_node_order_by and any( expr_node.is_orderable() for expr_node in expr._nodes ): - exprs.append(expr._with_node(over_node)) + exprs.append(expr._append_node(over_node)) elif over_node_partition_by and not all( expr_node.is_elementwise() for expr_node in expr._nodes ): - exprs.append(expr._with_node(over_node_without_order_by)) + exprs.append(expr._append_node(over_node_without_order_by)) else: # If there's no `partition_by`, then `over_node_without_order_by` is a no-op. exprs.append(expr) diff --git a/narwhals/expr.py b/narwhals/expr.py index d5f6e1a5f7..d0f38b7419 100644 --- a/narwhals/expr.py +++ b/narwhals/expr.py @@ -59,39 +59,38 @@ def _to_compliant_expr( ce = evaluate_node(ce, node, ns) return ce - def _with_node(self, node: ExprNode) -> Self: - if node.kind is ExprKind.OVER: - # insert `over` before any elementwise operations. - # check "how it works" page in docs for why we do this. - new_nodes = list(self._nodes) - kwargs_no_order_by = { - key: value if key != "order_by" else [] - for (key, value) in node.kwargs.items() - } - node_without_order_by = node._with_kwargs(**kwargs_no_order_by) - n = len(new_nodes) - i = n - while i > 0 and (_node := new_nodes[i - 1]).kind is ExprKind.ELEMENTWISE: - i -= 1 - _node._push_down_over_node_in_place(node, node_without_order_by) - if i == n: - # node could not be pushed down, just append as-is - new_nodes.append(node) - return self.__class__(*new_nodes) - if node.kwargs["order_by"] and any( - node.is_orderable() for node in new_nodes[:i] - ): - new_nodes.insert(i, node) - elif node.kwargs["partition_by"] and not all( - node.is_elementwise() for node in new_nodes[:i] - ): - new_nodes.insert(i, node_without_order_by) - elif all(node.is_elementwise() for node in new_nodes): - msg = "Cannot apply `over` to elementwise expression." - raise InvalidOperationError(msg) - return self.__class__(*new_nodes) + def _append_node(self, node: ExprNode) -> Self: return self.__class__(*self._nodes, node) + def _with_over_node(self, node: ExprNode) -> Self: + # insert `over` before any elementwise operations. + # check "how it works" page in docs for why we do this. + new_nodes = list(self._nodes) + kwargs_no_order_by = { + key: value if key != "order_by" else [] + for (key, value) in node.kwargs.items() + } + node_without_order_by = node._with_kwargs(**kwargs_no_order_by) + n = len(new_nodes) + i = n + while i > 0 and (_node := new_nodes[i - 1]).kind is ExprKind.ELEMENTWISE: + i -= 1 + _node._push_down_over_node_in_place(node, node_without_order_by) + if i == n: + # node could not be pushed down, just append as-is + new_nodes.append(node) + return self.__class__(*new_nodes) + if node.kwargs["order_by"] and any(node.is_orderable() for node in new_nodes[:i]): + new_nodes.insert(i, node) + elif node.kwargs["partition_by"] and not all( + node.is_elementwise() for node in new_nodes[:i] + ): + new_nodes.insert(i, node_without_order_by) + elif all(node.is_elementwise() for node in new_nodes): + msg = "Cannot apply `over` to elementwise expression." + raise InvalidOperationError(msg) + return self.__class__(*new_nodes) + def __repr__(self) -> str: """Pretty-print the expression by combining all nodes in the metadata.""" return ".".join(repr(node) for node in self._nodes) @@ -135,7 +134,7 @@ def alias(self, name: str) -> Self: | 1 15 | └──────────────────┘ """ - return self._with_node(ExprNode(ExprKind.ELEMENTWISE, "alias", name=name)) + return self._append_node(ExprNode(ExprKind.ELEMENTWISE, "alias", name=name)) def pipe( self, @@ -190,12 +189,12 @@ def cast(self, dtype: IntoDType) -> Self: └──────────────────┘ """ _validate_dtype(dtype) - return self._with_node(ExprNode(ExprKind.ELEMENTWISE, "cast", dtype=dtype)) + return self._append_node(ExprNode(ExprKind.ELEMENTWISE, "cast", dtype=dtype)) # --- binary --- def _with_binary(self, attr: str, other: Self | Any) -> Self: node = ExprNode(ExprKind.ELEMENTWISE, attr, other, str_as_lit=True) - return self._with_node(node) + return self._append_node(node) def __eq__(self, other: Self | Any) -> Self: # type: ignore[override] return self._with_binary("__eq__", other) @@ -271,7 +270,7 @@ def __rmod__(self, other: Any) -> Self: # --- unary --- def __invert__(self) -> Self: - return self._with_node(ExprNode(ExprKind.ELEMENTWISE, "__invert__")) + return self._append_node(ExprNode(ExprKind.ELEMENTWISE, "__invert__")) def any(self) -> Self: """Return whether any of the values in the column are `True`. @@ -291,7 +290,7 @@ def any(self) -> Self: | 0 True True | └──────────────────┘ """ - return self._with_node(ExprNode(ExprKind.AGGREGATION, "any")) + return self._append_node(ExprNode(ExprKind.AGGREGATION, "any")) def all(self) -> Self: """Return whether all values in the column are `True`. @@ -311,7 +310,7 @@ def all(self) -> Self: | 0 False True | └──────────────────┘ """ - return self._with_node(ExprNode(ExprKind.AGGREGATION, "all")) + return self._append_node(ExprNode(ExprKind.AGGREGATION, "all")) def ewm_mean( self, @@ -396,7 +395,7 @@ def ewm_mean( │ 2.428571 │ └──────────┘ """ - return self._with_node( + return self._append_node( ExprNode( ExprKind.ORDERABLE_WINDOW, "ewm_mean", @@ -426,7 +425,7 @@ def mean(self) -> Self: | 0 0.0 4.0 | └──────────────────┘ """ - return self._with_node(ExprNode(ExprKind.AGGREGATION, "mean")) + return self._append_node(ExprNode(ExprKind.AGGREGATION, "mean")) def median(self) -> Self: """Get median value. @@ -447,7 +446,7 @@ def median(self) -> Self: | 0 3.0 4.0 | └──────────────────┘ """ - return self._with_node(ExprNode(ExprKind.AGGREGATION, "median")) + return self._append_node(ExprNode(ExprKind.AGGREGATION, "median")) def std(self, *, ddof: int = 1) -> Self: """Get standard deviation. @@ -469,7 +468,7 @@ def std(self, *, ddof: int = 1) -> Self: |0 17.79513 1.265789| └─────────────────────┘ """ - return self._with_node(ExprNode(ExprKind.AGGREGATION, "std", ddof=ddof)) + return self._append_node(ExprNode(ExprKind.AGGREGATION, "std", ddof=ddof)) def var(self, *, ddof: int = 1) -> Self: """Get variance. @@ -491,7 +490,7 @@ def var(self, *, ddof: int = 1) -> Self: |0 316.666667 1.602222| └───────────────────────┘ """ - return self._with_node(ExprNode(ExprKind.AGGREGATION, "var", ddof=ddof)) + return self._append_node(ExprNode(ExprKind.AGGREGATION, "var", ddof=ddof)) def map_batches( self, @@ -540,7 +539,7 @@ def map_batches( if returns_scalar else ExprKind.ORDERABLE_FILTRATION ) - return self._with_node( + return self._append_node( ExprNode( kind, "map_batches", @@ -566,7 +565,7 @@ def skew(self) -> Self: | 0 0.0 1.472427 | └──────────────────┘ """ - return self._with_node(ExprNode(ExprKind.AGGREGATION, "skew")) + return self._append_node(ExprNode(ExprKind.AGGREGATION, "skew")) def kurtosis(self) -> Self: """Compute the kurtosis (Fisher's definition) without bias correction. @@ -587,7 +586,7 @@ def kurtosis(self) -> Self: | 0 -1.3 0.210657 | └──────────────────┘ """ - return self._with_node(ExprNode(ExprKind.AGGREGATION, "kurtosis")) + return self._append_node(ExprNode(ExprKind.AGGREGATION, "kurtosis")) def sum(self) -> Self: """Return the sum value. @@ -611,7 +610,7 @@ def sum(self) -> Self: |└────────┴────────┘| └───────────────────┘ """ - return self._with_node(ExprNode(ExprKind.AGGREGATION, "sum")) + return self._append_node(ExprNode(ExprKind.AGGREGATION, "sum")) def min(self) -> Self: """Returns the minimum value(s) from a column(s). @@ -629,7 +628,7 @@ def min(self) -> Self: | 0 1 3 | └──────────────────┘ """ - return self._with_node(ExprNode(ExprKind.AGGREGATION, "min")) + return self._append_node(ExprNode(ExprKind.AGGREGATION, "min")) def max(self) -> Self: """Returns the maximum value(s) from a column(s). @@ -647,7 +646,7 @@ def max(self) -> Self: | 0 20 100 | └──────────────────┘ """ - return self._with_node(ExprNode(ExprKind.AGGREGATION, "max")) + return self._append_node(ExprNode(ExprKind.AGGREGATION, "max")) def count(self) -> Self: """Returns the number of non-null elements in the column. @@ -665,7 +664,7 @@ def count(self) -> Self: | 0 3 2 | └──────────────────┘ """ - return self._with_node(ExprNode(ExprKind.AGGREGATION, "count")) + return self._append_node(ExprNode(ExprKind.AGGREGATION, "count")) def n_unique(self) -> Self: """Returns count of unique values. @@ -683,7 +682,7 @@ def n_unique(self) -> Self: | 0 5 3 | └──────────────────┘ """ - return self._with_node(ExprNode(ExprKind.AGGREGATION, "n_unique")) + return self._append_node(ExprNode(ExprKind.AGGREGATION, "n_unique")) def unique(self) -> Self: """Return unique values of this expression. @@ -701,7 +700,7 @@ def unique(self) -> Self: | 0 9 12 | └──────────────────┘ """ - return self._with_node(ExprNode(ExprKind.FILTRATION, "unique")) + return self._append_node(ExprNode(ExprKind.FILTRATION, "unique")) def abs(self) -> Self: """Return absolute value of each element. @@ -720,7 +719,7 @@ def abs(self) -> Self: |1 -2 4 2 4| └─────────────────────┘ """ - return self._with_node(ExprNode(ExprKind.ELEMENTWISE, "abs")) + return self._append_node(ExprNode(ExprKind.ELEMENTWISE, "abs")) def cum_sum(self, *, reverse: bool = False) -> Self: """Return cumulative sum. @@ -749,7 +748,7 @@ def cum_sum(self, *, reverse: bool = False) -> Self: |4 5 6 15| └──────────────────┘ """ - return self._with_node( + return self._append_node( ExprNode(ExprKind.ORDERABLE_WINDOW, "cum_sum", reverse=reverse) ) @@ -792,7 +791,7 @@ def diff(self) -> Self: | └─────┴────────┘ | └──────────────────┘ """ - return self._with_node(ExprNode(ExprKind.ORDERABLE_WINDOW, "diff")) + return self._append_node(ExprNode(ExprKind.ORDERABLE_WINDOW, "diff")) def shift(self, n: int) -> Self: """Shift values by `n` positions. @@ -837,7 +836,7 @@ def shift(self, n: int) -> Self: └──────────────────┘ """ ensure_type(n, int, param_name="n") - return self._with_node(ExprNode(ExprKind.ORDERABLE_WINDOW, "shift", n=n)) + return self._append_node(ExprNode(ExprKind.ORDERABLE_WINDOW, "shift", n=n)) def replace_strict( self, @@ -889,7 +888,7 @@ def replace_strict( new = list(old.values()) old = list(old.keys()) - return self._with_node( + return self._append_node( ExprNode( ExprKind.ELEMENTWISE, "replace_strict", @@ -933,7 +932,7 @@ def is_between( node = ExprNode( ExprKind.ELEMENTWISE, "is_between", lower_bound, upper_bound, closed=closed ) - return self._with_node(node) + return self._append_node(node) def is_in(self, other: Any) -> Self: """Check if elements of this expression are present in the other iterable. @@ -958,7 +957,7 @@ def is_in(self, other: Any) -> Self: └──────────────────┘ """ if isinstance(other, Iterable) and not isinstance(other, (str, bytes)): - return self._with_node( + return self._append_node( ExprNode( ExprKind.ELEMENTWISE, "is_in", @@ -994,7 +993,7 @@ def filter(self, *predicates: Any) -> Self: | 5 7 12 | └──────────────────┘ """ - return self._with_node(ExprNode(ExprKind.FILTRATION, "filter", *predicates)) + return self._append_node(ExprNode(ExprKind.FILTRATION, "filter", *predicates)) def is_null(self) -> Self: """Returns a boolean Series indicating which values are null. @@ -1025,7 +1024,7 @@ def is_null(self) -> Self: |└───────┴────────┴───────────┴───────────┘| └──────────────────────────────────────────┘ """ - return self._with_node(ExprNode(ExprKind.ELEMENTWISE, "is_null")) + return self._append_node(ExprNode(ExprKind.ELEMENTWISE, "is_null")) def is_nan(self) -> Self: """Indicate which values are NaN. @@ -1056,7 +1055,7 @@ def is_nan(self) -> Self: |└───────┴────────┴──────────┴──────────┘| └────────────────────────────────────────┘ """ - return self._with_node(ExprNode(ExprKind.ELEMENTWISE, "is_nan")) + return self._append_node(ExprNode(ExprKind.ELEMENTWISE, "is_nan")) def fill_null( self, @@ -1160,7 +1159,7 @@ def fill_null( limit=limit, str_as_lit=True, ) - return self._with_node(node) + return self._append_node(node) def fill_nan(self, value: float | None) -> Self: """Fill floating point NaN values with given value. @@ -1193,7 +1192,7 @@ def fill_nan(self, value: float | None) -> Self: |└────────┴────────┴───────────────┴───────────────┘| └───────────────────────────────────────────────────┘ """ - return self._with_node(ExprNode(ExprKind.ELEMENTWISE, "fill_nan", value=value)) + return self._append_node(ExprNode(ExprKind.ELEMENTWISE, "fill_nan", value=value)) # --- partial reduction --- def drop_nulls(self) -> Self: @@ -1226,7 +1225,7 @@ def drop_nulls(self) -> Self: | └─────┘ | └──────────────────┘ """ - return self._with_node(ExprNode(ExprKind.FILTRATION, "drop_nulls")) + return self._append_node(ExprNode(ExprKind.FILTRATION, "drop_nulls")) def over( self, @@ -1279,7 +1278,7 @@ def over( node = ExprNode( ExprKind.OVER, "over", partition_by=flat_partition_by, order_by=flat_order_by ) - return self._with_node(node) + return self._with_over_node(node) def is_duplicated(self) -> Self: r"""Return a boolean mask indicating duplicated values. @@ -1300,7 +1299,7 @@ def is_duplicated(self) -> Self: |3 1 c True False| └─────────────────────────────────────────┘ """ - return self._with_node(ExprNode(ExprKind.WINDOW, "is_duplicated")) + return self._append_node(ExprNode(ExprKind.WINDOW, "is_duplicated")) def is_unique(self) -> Self: r"""Return a boolean mask indicating unique values. @@ -1321,7 +1320,7 @@ def is_unique(self) -> Self: |3 1 c False True| └─────────────────────────────────┘ """ - return self._with_node(ExprNode(ExprKind.WINDOW, "is_unique")) + return self._append_node(ExprNode(ExprKind.WINDOW, "is_unique")) def null_count(self) -> Self: r"""Count null values. @@ -1345,7 +1344,7 @@ def null_count(self) -> Self: | 0 1 2 | └──────────────────┘ """ - return self._with_node(ExprNode(ExprKind.AGGREGATION, "null_count")) + return self._append_node(ExprNode(ExprKind.AGGREGATION, "null_count")) def is_first_distinct(self) -> Self: r"""Return a boolean mask indicating the first occurrence of each distinct value. @@ -1372,7 +1371,7 @@ def is_first_distinct(self) -> Self: |3 1 c False True| └─────────────────────────────────────────────────┘ """ - return self._with_node(ExprNode(ExprKind.ORDERABLE_WINDOW, "is_first_distinct")) + return self._append_node(ExprNode(ExprKind.ORDERABLE_WINDOW, "is_first_distinct")) def is_last_distinct(self) -> Self: r"""Return a boolean mask indicating the last occurrence of each distinct value. @@ -1399,7 +1398,7 @@ def is_last_distinct(self) -> Self: |3 1 c True True| └───────────────────────────────────────────────┘ """ - return self._with_node(ExprNode(ExprKind.ORDERABLE_WINDOW, "is_last_distinct")) + return self._append_node(ExprNode(ExprKind.ORDERABLE_WINDOW, "is_last_distinct")) def quantile( self, quantile: float, interpolation: RollingInterpolationMethod @@ -1432,7 +1431,7 @@ def quantile( | 0 24.5 74.5 | └──────────────────┘ """ - return self._with_node( + return self._append_node( ExprNode( ExprKind.AGGREGATION, "quantile", @@ -1471,7 +1470,9 @@ def round(self, decimals: int = 0) -> Self: |2 3.901234 3.9| └──────────────────────┘ """ - return self._with_node(ExprNode(ExprKind.ELEMENTWISE, "round", decimals=decimals)) + return self._append_node( + ExprNode(ExprKind.ELEMENTWISE, "round", decimals=decimals) + ) def floor(self) -> Self: r"""Compute the numerical floor. @@ -1494,7 +1495,7 @@ def floor(self) -> Self: |floor: [[1,4,-2]] | └────────────────────────┘ """ - return self._with_node(ExprNode(ExprKind.ELEMENTWISE, "floor")) + return self._append_node(ExprNode(ExprKind.ELEMENTWISE, "floor")) def ceil(self) -> Self: r"""Compute the numerical ceiling. @@ -1517,7 +1518,7 @@ def ceil(self) -> Self: |ceil: [[2,5,-1]] | └────────────────────────┘ """ - return self._with_node(ExprNode(ExprKind.ELEMENTWISE, "ceil")) + return self._append_node(ExprNode(ExprKind.ELEMENTWISE, "ceil")) def len(self) -> Self: r"""Return the number of elements in the column. @@ -1540,7 +1541,7 @@ def len(self) -> Self: | 0 2 1 | └──────────────────┘ """ - return self._with_node(ExprNode(ExprKind.AGGREGATION, "len")) + return self._append_node(ExprNode(ExprKind.AGGREGATION, "len")) def clip( self, @@ -1569,14 +1570,14 @@ def clip( └──────────────────┘ """ if upper_bound is None: - return self._with_node( + return self._append_node( ExprNode(ExprKind.ELEMENTWISE, "clip_lower", lower_bound) ) if lower_bound is None: - return self._with_node( + return self._append_node( ExprNode(ExprKind.ELEMENTWISE, "clip_upper", upper_bound) ) - return self._with_node( + return self._append_node( ExprNode(ExprKind.ELEMENTWISE, "clip", lower_bound, upper_bound) ) @@ -1610,7 +1611,7 @@ def first(self) -> Self: | 1 2 None | └──────────────────┘ """ - return self._with_node(ExprNode(ExprKind.ORDERABLE_AGGREGATION, "first")) + return self._append_node(ExprNode(ExprKind.ORDERABLE_AGGREGATION, "first")) def last(self) -> Self: """Get the last value. @@ -1649,7 +1650,7 @@ def last(self) -> Self: |b: [[null,"baz"]] | └──────────────────┘ """ - return self._with_node(ExprNode(ExprKind.ORDERABLE_AGGREGATION, "last")) + return self._append_node(ExprNode(ExprKind.ORDERABLE_AGGREGATION, "last")) def mode(self, *, keep: ModeKeepStrategy = "all") -> Self: r"""Compute the most occurring value(s). @@ -1678,7 +1679,7 @@ def mode(self, *, keep: ModeKeepStrategy = "all") -> Self: msg = f"`keep` must be one of {_supported_keep_values}, found '{keep}'" raise ValueError(msg) kind = ExprKind.AGGREGATION if keep == "any" else ExprKind.FILTRATION - return self._with_node(ExprNode(kind, "mode", keep=keep)) + return self._append_node(ExprNode(kind, "mode", keep=keep)) def is_finite(self) -> Self: """Returns boolean values indicating which original values are finite. @@ -1712,7 +1713,7 @@ def is_finite(self) -> Self: |└──────┴─────────────┘| └──────────────────────┘ """ - return self._with_node(ExprNode(ExprKind.ELEMENTWISE, "is_finite")) + return self._append_node(ExprNode(ExprKind.ELEMENTWISE, "is_finite")) def cum_count(self, *, reverse: bool = False) -> Self: r"""Return the cumulative count of the non-null values in the column. @@ -1743,7 +1744,7 @@ def cum_count(self, *, reverse: bool = False) -> Self: |3 d 3 1| └─────────────────────────────────────────┘ """ - return self._with_node( + return self._append_node( ExprNode(ExprKind.ORDERABLE_WINDOW, "cum_count", reverse=reverse) ) @@ -1776,7 +1777,7 @@ def cum_min(self, *, reverse: bool = False) -> Self: |3 2.0 1.0 2.0| └────────────────────────────────────┘ """ - return self._with_node( + return self._append_node( ExprNode(ExprKind.ORDERABLE_WINDOW, "cum_min", reverse=reverse) ) @@ -1809,7 +1810,7 @@ def cum_max(self, *, reverse: bool = False) -> Self: |3 2.0 3.0 2.0| └────────────────────────────────────┘ """ - return self._with_node( + return self._append_node( ExprNode(ExprKind.ORDERABLE_WINDOW, "cum_max", reverse=reverse) ) @@ -1842,7 +1843,7 @@ def cum_prod(self, *, reverse: bool = False) -> Self: |3 2.0 6.0 2.0| └──────────────────────────────────────┘ """ - return self._with_node( + return self._append_node( ExprNode(ExprKind.ORDERABLE_WINDOW, "cum_prod", reverse=reverse) ) @@ -1891,7 +1892,7 @@ def rolling_sum( window_size, min_samples = _validate_rolling_arguments( window_size=window_size, min_samples=min_samples ) - return self._with_node( + return self._append_node( ExprNode( ExprKind.ORDERABLE_WINDOW, "rolling_sum", @@ -1947,7 +1948,7 @@ def rolling_mean( window_size=window_size, min_samples=min_samples ) - return self._with_node( + return self._append_node( ExprNode( ExprKind.ORDERABLE_WINDOW, "rolling_mean", @@ -2008,7 +2009,7 @@ def rolling_var( window_size, min_samples = _validate_rolling_arguments( window_size=window_size, min_samples=min_samples ) - return self._with_node( + return self._append_node( ExprNode( ExprKind.ORDERABLE_WINDOW, "rolling_var", @@ -2070,7 +2071,7 @@ def rolling_std( window_size, min_samples = _validate_rolling_arguments( window_size=window_size, min_samples=min_samples ) - return self._with_node( + return self._append_node( ExprNode( ExprKind.ORDERABLE_WINDOW, "rolling_std", @@ -2135,7 +2136,7 @@ def rank(self, method: RankMethod = "average", *, descending: bool = False) -> S ) raise ValueError(msg) - return self._with_node( + return self._append_node( ExprNode(ExprKind.WINDOW, "rank", method=method, descending=descending) ) @@ -2167,7 +2168,7 @@ def log(self, base: float = math.e) -> Self: |log_2: [[0,1,2]] | └────────────────────────────────────────────────┘ """ - return self._with_node(ExprNode(ExprKind.ELEMENTWISE, "log", base=base)) + return self._append_node(ExprNode(ExprKind.ELEMENTWISE, "log", base=base)) def exp(self) -> Self: r"""Compute the exponent. @@ -2190,7 +2191,7 @@ def exp(self) -> Self: |exp: [[0.36787944117144233,1,2.718281828459045]]| └────────────────────────────────────────────────┘ """ - return self._with_node(ExprNode(ExprKind.ELEMENTWISE, "exp")) + return self._append_node(ExprNode(ExprKind.ELEMENTWISE, "exp")) def sqrt(self) -> Self: r"""Compute the square root. @@ -2213,7 +2214,7 @@ def sqrt(self) -> Self: |sqrt: [[1,2,3]] | └──────────────────┘ """ - return self._with_node(ExprNode(ExprKind.ELEMENTWISE, "sqrt")) + return self._append_node(ExprNode(ExprKind.ELEMENTWISE, "sqrt")) def is_close( # noqa: PLR0914 self, diff --git a/narwhals/expr_cat.py b/narwhals/expr_cat.py index 5c0541a981..c5ba549001 100644 --- a/narwhals/expr_cat.py +++ b/narwhals/expr_cat.py @@ -36,4 +36,6 @@ def get_categories(self) -> ExprT: │ mango │ └────────┘ """ - return self._expr._with_node(ExprNode(ExprKind.ELEMENTWISE, "cat.get_categories")) + return self._expr._append_node( + ExprNode(ExprKind.ELEMENTWISE, "cat.get_categories") + ) diff --git a/narwhals/expr_dt.py b/narwhals/expr_dt.py index 38e3563d9f..3b8c31055e 100644 --- a/narwhals/expr_dt.py +++ b/narwhals/expr_dt.py @@ -40,7 +40,7 @@ def date(self) -> ExprT: │ 2027-12-13 │ └────────────┘ """ - return self._expr._with_node(ExprNode(ExprKind.ELEMENTWISE, "dt.date")) + return self._expr._append_node(ExprNode(ExprKind.ELEMENTWISE, "dt.date")) def year(self) -> ExprT: """Extract year from underlying DateTime representation. @@ -64,7 +64,7 @@ def year(self) -> ExprT: |1 2065-01-01 2065| └──────────────────┘ """ - return self._expr._with_node(ExprNode(ExprKind.ELEMENTWISE, "dt.year")) + return self._expr._append_node(ExprNode(ExprKind.ELEMENTWISE, "dt.year")) def month(self) -> ExprT: """Extract month from underlying DateTime representation. @@ -85,7 +85,7 @@ def month(self) -> ExprT: a: [[1978-06-01 00:00:00.000000,2065-01-01 00:00:00.000000]] month: [[6,1]] """ - return self._expr._with_node(ExprNode(ExprKind.ELEMENTWISE, "dt.month")) + return self._expr._append_node(ExprNode(ExprKind.ELEMENTWISE, "dt.month")) def day(self) -> ExprT: """Extract day from underlying DateTime representation. @@ -106,7 +106,7 @@ def day(self) -> ExprT: a: [[1978-06-01 00:00:00.000000,2065-01-01 00:00:00.000000]] day: [[1,1]] """ - return self._expr._with_node(ExprNode(ExprKind.ELEMENTWISE, "dt.day")) + return self._expr._append_node(ExprNode(ExprKind.ELEMENTWISE, "dt.day")) def hour(self) -> ExprT: """Extract hour from underlying DateTime representation. @@ -136,7 +136,7 @@ def hour(self) -> ExprT: |└─────────────────────┴──────┘| └──────────────────────────────┘ """ - return self._expr._with_node(ExprNode(ExprKind.ELEMENTWISE, "dt.hour")) + return self._expr._append_node(ExprNode(ExprKind.ELEMENTWISE, "dt.hour")) def minute(self) -> ExprT: """Extract minutes from underlying DateTime representation. @@ -156,7 +156,7 @@ def minute(self) -> ExprT: 0 1978-01-01 01:01:00 1 1 2065-01-01 10:20:00 20 """ - return self._expr._with_node(ExprNode(ExprKind.ELEMENTWISE, "dt.minute")) + return self._expr._append_node(ExprNode(ExprKind.ELEMENTWISE, "dt.minute")) def second(self) -> ExprT: """Extract seconds from underlying DateTime representation. @@ -182,7 +182,7 @@ def second(self) -> ExprT: a: [[1978-01-01 01:01:01.000000,2065-01-01 10:20:30.000000]] second: [[1,30]] """ - return self._expr._with_node(ExprNode(ExprKind.ELEMENTWISE, "dt.second")) + return self._expr._append_node(ExprNode(ExprKind.ELEMENTWISE, "dt.second")) def millisecond(self) -> ExprT: """Extract milliseconds from underlying DateTime representation. @@ -210,7 +210,7 @@ def millisecond(self) -> ExprT: a: [[1978-01-01 01:01:01.000000,2065-01-01 10:20:30.067000]] millisecond: [[0,67]] """ - return self._expr._with_node(ExprNode(ExprKind.ELEMENTWISE, "dt.millisecond")) + return self._expr._append_node(ExprNode(ExprKind.ELEMENTWISE, "dt.millisecond")) def microsecond(self) -> ExprT: """Extract microseconds from underlying DateTime representation. @@ -238,7 +238,7 @@ def microsecond(self) -> ExprT: a: [[1978-01-01 01:01:01.000000,2065-01-01 10:20:30.067000]] microsecond: [[0,67000]] """ - return self._expr._with_node(ExprNode(ExprKind.ELEMENTWISE, "dt.microsecond")) + return self._expr._append_node(ExprNode(ExprKind.ELEMENTWISE, "dt.microsecond")) def nanosecond(self) -> ExprT: """Extract Nanoseconds from underlying DateTime representation. @@ -266,7 +266,7 @@ def nanosecond(self) -> ExprT: a: [[1978-01-01 01:01:01.000000,2065-01-01 10:20:30.067000]] nanosecond: [[0,67000000]] """ - return self._expr._with_node(ExprNode(ExprKind.ELEMENTWISE, "dt.nanosecond")) + return self._expr._append_node(ExprNode(ExprKind.ELEMENTWISE, "dt.nanosecond")) def ordinal_day(self) -> ExprT: """Get ordinal day. @@ -288,7 +288,7 @@ def ordinal_day(self) -> ExprT: |1 2020-08-03 216| └───────────────────────────┘ """ - return self._expr._with_node(ExprNode(ExprKind.ELEMENTWISE, "dt.ordinal_day")) + return self._expr._append_node(ExprNode(ExprKind.ELEMENTWISE, "dt.ordinal_day")) def weekday(self) -> ExprT: """Extract the week day from the underlying Date representation. @@ -312,7 +312,7 @@ def weekday(self) -> ExprT: |1 2020-08-03 1| └────────────────────────┘ """ - return self._expr._with_node(ExprNode(ExprKind.ELEMENTWISE, "dt.weekday")) + return self._expr._append_node(ExprNode(ExprKind.ELEMENTWISE, "dt.weekday")) def total_minutes(self) -> ExprT: """Get total minutes. @@ -343,7 +343,7 @@ def total_minutes(self) -> ExprT: │ 20m 40s ┆ 20 │ └──────────────┴─────────────────┘ """ - return self._expr._with_node(ExprNode(ExprKind.ELEMENTWISE, "dt.total_minutes")) + return self._expr._append_node(ExprNode(ExprKind.ELEMENTWISE, "dt.total_minutes")) def total_seconds(self) -> ExprT: """Get total seconds. @@ -374,7 +374,7 @@ def total_seconds(self) -> ExprT: │ 20s 40ms ┆ 20 │ └──────────────┴─────────────────┘ """ - return self._expr._with_node(ExprNode(ExprKind.ELEMENTWISE, "dt.total_seconds")) + return self._expr._append_node(ExprNode(ExprKind.ELEMENTWISE, "dt.total_seconds")) def total_milliseconds(self) -> ExprT: """Get total milliseconds. @@ -410,7 +410,7 @@ def total_milliseconds(self) -> ExprT: │ 20040µs ┆ 20 │ └──────────────┴──────────────────────┘ """ - return self._expr._with_node( + return self._expr._append_node( ExprNode(ExprKind.ELEMENTWISE, "dt.total_milliseconds") ) @@ -445,7 +445,7 @@ def total_microseconds(self) -> ExprT: a: [[10,1200]] a_total_microseconds: [[10,1200]] """ - return self._expr._with_node( + return self._expr._append_node( ExprNode(ExprKind.ELEMENTWISE, "dt.total_microseconds") ) @@ -479,7 +479,7 @@ def total_nanoseconds(self) -> ExprT: 0 2024-01-01 00:00:00.000000001 NaN 1 2024-01-01 00:00:00.000000002 1.0 """ - return self._expr._with_node( + return self._expr._append_node( ExprNode(ExprKind.ELEMENTWISE, "dt.total_nanoseconds") ) @@ -543,7 +543,7 @@ def to_string(self, format: str) -> ExprT: |└─────────────────────┘| └───────────────────────┘ """ - return self._expr._with_node( + return self._expr._append_node( ExprNode(ExprKind.ELEMENTWISE, "dt.to_string", format=format) ) @@ -571,7 +571,7 @@ def replace_time_zone(self, time_zone: str | None) -> ExprT: 0 2024-01-01 00:00:00+05:45 1 2024-01-02 00:00:00+05:45 """ - return self._expr._with_node( + return self._expr._append_node( ExprNode(ExprKind.ELEMENTWISE, "dt.replace_time_zone", time_zone=time_zone) ) @@ -605,7 +605,7 @@ def convert_time_zone(self, time_zone: str) -> ExprT: if time_zone is None: msg = "Target `time_zone` cannot be `None` in `convert_time_zone`. Please use `replace_time_zone(None)` if you want to remove the time zone." raise TypeError(msg) - return self._expr._with_node( + return self._expr._append_node( ExprNode(ExprKind.ELEMENTWISE, "dt.convert_time_zone", time_zone=time_zone) ) @@ -645,7 +645,7 @@ def timestamp(self, time_unit: TimeUnit = "us") -> ExprT: f"\n\nExpected one of {{'ns', 'us', 'ms'}}, got {time_unit!r}." ) raise ValueError(msg) - return self._expr._with_node( + return self._expr._append_node( ExprNode(ExprKind.ELEMENTWISE, "dt.timestamp", time_unit=time_unit) ) @@ -689,7 +689,7 @@ def truncate(self, every: str) -> ExprT: |└─────────────────────┴─────────────────────┘| └─────────────────────────────────────────────┘ """ - return self._expr._with_node( + return self._expr._append_node( ExprNode(ExprKind.ELEMENTWISE, "dt.truncate", every=every) ) @@ -733,6 +733,6 @@ def offset_by(self, by: str) -> ExprT: |└─────────────────────┴───────────────────────┘| └───────────────────────────────────────────────┘ """ - return self._expr._with_node( + return self._expr._append_node( ExprNode(ExprKind.ELEMENTWISE, "dt.offset_by", by=by) ) diff --git a/narwhals/expr_list.py b/narwhals/expr_list.py index 0839c92201..8f9c94c6ab 100644 --- a/narwhals/expr_list.py +++ b/narwhals/expr_list.py @@ -42,7 +42,7 @@ def len(self) -> ExprT: |└──────────────┴───────┘| └────────────────────────┘ """ - return self._expr._with_node(ExprNode(ExprKind.ELEMENTWISE, "list.len")) + return self._expr._append_node(ExprNode(ExprKind.ELEMENTWISE, "list.len")) def unique(self) -> ExprT: """Get the unique/distinct values in the list. @@ -71,7 +71,7 @@ def unique(self) -> ExprT: |└──────────────┴───────────┘| └────────────────────────────┘ """ - return self._expr._with_node(ExprNode(ExprKind.ELEMENTWISE, "list.unique")) + return self._expr._append_node(ExprNode(ExprKind.ELEMENTWISE, "list.unique")) def contains(self, item: NonNestedLiteral) -> ExprT: """Check if sublists contain the given item. @@ -100,7 +100,7 @@ def contains(self, item: NonNestedLiteral) -> ExprT: |└───────────┴──────────────┘| └────────────────────────────┘ """ - return self._expr._with_node( + return self._expr._append_node( ExprNode(ExprKind.ELEMENTWISE, "list.contains", item=item) ) @@ -140,6 +140,6 @@ def get(self, index: int) -> ExprT: msg = f"Index {index} is out of bounds: should be greater than or equal to 0." raise ValueError(msg) - return self._expr._with_node( + return self._expr._append_node( ExprNode(ExprKind.ELEMENTWISE, "list.get", index=index) ) diff --git a/narwhals/expr_name.py b/narwhals/expr_name.py index 51f67f3a42..64525aae56 100644 --- a/narwhals/expr_name.py +++ b/narwhals/expr_name.py @@ -28,7 +28,7 @@ def keep(self) -> ExprT: >>> df.select(nw.col("foo").alias("alias_for_foo").name.keep()).columns ['foo'] """ - return self._expr._with_node(ExprNode(ExprKind.ELEMENTWISE, "name.keep")) + return self._expr._append_node(ExprNode(ExprKind.ELEMENTWISE, "name.keep")) def map(self, function: Callable[[str], str]) -> ExprT: r"""Rename the output of an expression by mapping a function over the root name. @@ -48,7 +48,7 @@ def map(self, function: Callable[[str], str]) -> ExprT: >>> df.select(nw.col("foo", "BAR").name.map(renaming_func)).columns ['oof', 'RAB'] """ - return self._expr._with_node( + return self._expr._append_node( ExprNode(ExprKind.ELEMENTWISE, "name.map", function=function) ) @@ -69,7 +69,7 @@ def prefix(self, prefix: str) -> ExprT: >>> df.select(nw.col("foo", "BAR").name.prefix("with_prefix")).columns ['with_prefixfoo', 'with_prefixBAR'] """ - return self._expr._with_node( + return self._expr._append_node( ExprNode(ExprKind.ELEMENTWISE, "name.prefix", prefix=prefix) ) @@ -90,7 +90,7 @@ def suffix(self, suffix: str) -> ExprT: >>> df.select(nw.col("foo", "BAR").name.suffix("_with_suffix")).columns ['foo_with_suffix', 'BAR_with_suffix'] """ - return self._expr._with_node( + return self._expr._append_node( ExprNode(ExprKind.ELEMENTWISE, "name.suffix", suffix=suffix) ) @@ -108,7 +108,9 @@ def to_lowercase(self) -> ExprT: >>> df.select(nw.col("foo", "BAR").name.to_lowercase()).columns ['foo', 'bar'] """ - return self._expr._with_node(ExprNode(ExprKind.ELEMENTWISE, "name.to_lowercase")) + return self._expr._append_node( + ExprNode(ExprKind.ELEMENTWISE, "name.to_lowercase") + ) def to_uppercase(self) -> ExprT: r"""Make the root column name uppercase. @@ -124,4 +126,6 @@ def to_uppercase(self) -> ExprT: >>> df.select(nw.col("foo", "BAR").name.to_uppercase()).columns ['FOO', 'BAR'] """ - return self._expr._with_node(ExprNode(ExprKind.ELEMENTWISE, "name.to_uppercase")) + return self._expr._append_node( + ExprNode(ExprKind.ELEMENTWISE, "name.to_uppercase") + ) diff --git a/narwhals/expr_str.py b/narwhals/expr_str.py index 8054d04686..fe914e8375 100644 --- a/narwhals/expr_str.py +++ b/narwhals/expr_str.py @@ -39,7 +39,7 @@ def len_chars(self) -> ExprT: |└───────┴───────────┘| └─────────────────────┘ """ - return self._expr._with_node(ExprNode(ExprKind.ELEMENTWISE, "str.len_chars")) + return self._expr._append_node(ExprNode(ExprKind.ELEMENTWISE, "str.len_chars")) def replace( self, pattern: str, value: str | IntoExpr, *, literal: bool = False, n: int = 1 @@ -66,7 +66,7 @@ def replace( |1 abc abc123 abc123| └──────────────────────┘ """ - return self._expr._with_node( + return self._expr._append_node( ExprNode( ExprKind.ELEMENTWISE, "str.replace", @@ -102,7 +102,7 @@ def replace_all( |1 abc abc123 123| └──────────────────────┘ """ - return self._expr._with_node( + return self._expr._append_node( ExprNode( ExprKind.ELEMENTWISE, "str.replace_all", @@ -132,7 +132,7 @@ def strip_chars(self, characters: str | None = None) -> ExprT: ... ) {'fruits': ['apple', '\nmango'], 'stripped': ['apple', 'mango']} """ - return self._expr._with_node( + return self._expr._append_node( ExprNode(ExprKind.ELEMENTWISE, "str.strip_chars", characters=characters) ) @@ -157,7 +157,7 @@ def starts_with(self, prefix: str) -> ExprT: |2 None None| └───────────────────┘ """ - return self._expr._with_node( + return self._expr._append_node( ExprNode(ExprKind.ELEMENTWISE, "str.starts_with", prefix=prefix) ) @@ -182,7 +182,7 @@ def ends_with(self, suffix: str) -> ExprT: |2 None None| └───────────────────┘ """ - return self._expr._with_node( + return self._expr._append_node( ExprNode(ExprKind.ELEMENTWISE, "str.ends_with", suffix=suffix) ) @@ -212,7 +212,7 @@ def contains(self, pattern: str, *, literal: bool = False) -> ExprT: default_match: [[true,false,true]] case_insensitive_match: [[true,false,true]] """ - return self._expr._with_node( + return self._expr._append_node( ExprNode( ExprKind.ELEMENTWISE, "str.contains", pattern=pattern, literal=literal ) @@ -241,7 +241,7 @@ def slice(self, offset: int, length: int | None = None) -> ExprT: |2 papaya ya| └──────────────────┘ """ - return self._expr._with_node( + return self._expr._append_node( ExprNode(ExprKind.ELEMENTWISE, "str.slice", offset=offset, length=length) ) @@ -271,7 +271,7 @@ def split(self, by: str) -> ExprT: |└─────────┴────────────────┘| └────────────────────────────┘ """ - return self._expr._with_node(ExprNode(ExprKind.ELEMENTWISE, "str.split", by=by)) + return self._expr._append_node(ExprNode(ExprKind.ELEMENTWISE, "str.split", by=by)) def head(self, n: int = 5) -> ExprT: r"""Take the first n elements of each string. @@ -295,7 +295,7 @@ def head(self, n: int = 5) -> ExprT: lyrics: [["taata","taatatata","zukkyun"]] lyrics_head: [["taata","taata","zukky"]] """ - return self._expr._with_node( + return self._expr._append_node( ExprNode(ExprKind.ELEMENTWISE, "str.slice", offset=0, length=n) ) @@ -321,7 +321,7 @@ def tail(self, n: int = 5) -> ExprT: lyrics: [["taata","taatatata","zukkyun"]] lyrics_tail: [["taata","atata","kkyun"]] """ - return self._expr._with_node( + return self._expr._append_node( ExprNode(ExprKind.ELEMENTWISE, "str.slice", offset=-n, length=None) ) @@ -363,7 +363,7 @@ def to_datetime(self, format: str | None = None) -> ExprT: |└─────────────────────┘| └───────────────────────┘ """ - return self._expr._with_node( + return self._expr._append_node( ExprNode(ExprKind.ELEMENTWISE, "str.to_datetime", format=format) ) @@ -392,7 +392,7 @@ def to_date(self, format: str | None = None) -> ExprT: |a: [[2020-01-01,2020-01-02]]| └────────────────────────────┘ """ - return self._expr._with_node( + return self._expr._append_node( ExprNode(ExprKind.ELEMENTWISE, "str.to_date", format=format) ) @@ -418,7 +418,7 @@ def to_uppercase(self) -> ExprT: |1 None None| └──────────────────┘ """ - return self._expr._with_node(ExprNode(ExprKind.ELEMENTWISE, "str.to_uppercase")) + return self._expr._append_node(ExprNode(ExprKind.ELEMENTWISE, "str.to_uppercase")) def to_lowercase(self) -> ExprT: r"""Transform string to lowercase variant. @@ -437,7 +437,7 @@ def to_lowercase(self) -> ExprT: |1 None None| └──────────────────┘ """ - return self._expr._with_node(ExprNode(ExprKind.ELEMENTWISE, "str.to_lowercase")) + return self._expr._append_node(ExprNode(ExprKind.ELEMENTWISE, "str.to_lowercase")) def to_titlecase(self) -> ExprT: """Modify strings to their titlecase equivalent. @@ -491,7 +491,7 @@ def to_titlecase(self) -> ExprT: |└─────────────────────────┴─────────────────────────┘| └─────────────────────────────────────────────────────┘ """ - return self._expr._with_node(ExprNode(ExprKind.ELEMENTWISE, "str.to_titlecase")) + return self._expr._append_node(ExprNode(ExprKind.ELEMENTWISE, "str.to_titlecase")) def zfill(self, width: int) -> ExprT: """Transform string to zero-padded variant. @@ -517,6 +517,6 @@ def zfill(self, width: int) -> ExprT: |3 None None| └──────────────────┘ """ - return self._expr._with_node( + return self._expr._append_node( ExprNode(ExprKind.ELEMENTWISE, "str.zfill", width=width) ) diff --git a/narwhals/expr_struct.py b/narwhals/expr_struct.py index 83fc64a648..7d734732f9 100644 --- a/narwhals/expr_struct.py +++ b/narwhals/expr_struct.py @@ -42,6 +42,6 @@ def field(self, name: str) -> ExprT: |└──────────────┴──────┘| └───────────────────────┘ """ - return self._expr._with_node( + return self._expr._append_node( ExprNode(ExprKind.ELEMENTWISE, "struct.field", name=name) ) diff --git a/narwhals/selectors.py b/narwhals/selectors.py index 52861c036b..bd34a00199 100644 --- a/narwhals/selectors.py +++ b/narwhals/selectors.py @@ -22,13 +22,13 @@ def __add__(self, other: Any) -> Expr: # type: ignore[override] if isinstance(other, Selector): msg = "unsupported operand type(s) for op: ('Selector' + 'Selector')" raise TypeError(msg) - return self._to_expr()._with_node( + return self._to_expr()._append_node( ExprNode(ExprKind.ELEMENTWISE, "__add__", other, str_as_lit=True) ) def __or__(self, other: Any) -> Expr: # type: ignore[override] if isinstance(other, Selector): - return self._with_node( + return self._append_node( ExprNode( ExprKind.ELEMENTWISE, "__or__", @@ -37,13 +37,13 @@ def __or__(self, other: Any) -> Expr: # type: ignore[override] allow_multi_output=True, ) ) - return self._to_expr()._with_node( + return self._to_expr()._append_node( ExprNode(ExprKind.ELEMENTWISE, "__or__", other, str_as_lit=True) ) def __and__(self, other: Any) -> Expr: # type: ignore[override] if isinstance(other, Selector): - return self._with_node( + return self._append_node( ExprNode( ExprKind.ELEMENTWISE, "__and__", @@ -52,7 +52,7 @@ def __and__(self, other: Any) -> Expr: # type: ignore[override] allow_multi_output=True, ) ) - return self._to_expr()._with_node( + return self._to_expr()._append_node( ExprNode(ExprKind.ELEMENTWISE, "__and__", other, str_as_lit=True) ) diff --git a/narwhals/stable/v1/__init__.py b/narwhals/stable/v1/__init__.py index ac87dbbf3b..3e48fa6bc5 100644 --- a/narwhals/stable/v1/__init__.py +++ b/narwhals/stable/v1/__init__.py @@ -373,11 +373,11 @@ def _l1_norm(self) -> Self: def head(self, n: int = 10) -> Self: r"""Get the first `n` rows.""" - return self._with_node(ExprNode(ExprKind.FILTRATION, "head", n=n)) + return self._append_node(ExprNode(ExprKind.FILTRATION, "head", n=n)) def tail(self, n: int = 10) -> Self: r"""Get the last `n` rows.""" - return self._with_node(ExprNode(ExprKind.FILTRATION, "tail", n=n)) + return self._append_node(ExprNode(ExprKind.FILTRATION, "tail", n=n)) def gather_every(self, n: int, offset: int = 0) -> Self: r"""Take every nth value in the Series and return as new Series. @@ -386,7 +386,7 @@ def gather_every(self, n: int, offset: int = 0) -> Self: n: Gather every *n*-th row. offset: Starting index. """ - return self._with_node( + return self._append_node( ExprNode(ExprKind.ORDERABLE_FILTRATION, "gather_every", n=n, offset=offset) ) @@ -398,11 +398,11 @@ def unique(self, *, maintain_order: bool | None = None) -> Self: "You can safely remove this argument." ) issue_warning(msg, UserWarning) - return self._with_node(ExprNode(ExprKind.FILTRATION, "unique")) + return self._append_node(ExprNode(ExprKind.FILTRATION, "unique")) def sort(self, *, descending: bool = False, nulls_last: bool = False) -> Self: """Sort this column. Place null values first.""" - return self._with_node( + return self._append_node( ExprNode( ExprKind.WINDOW, "sort", descending=descending, nulls_last=nulls_last ) @@ -410,15 +410,15 @@ def sort(self, *, descending: bool = False, nulls_last: bool = False) -> Self: def arg_max(self) -> Self: """Returns the index of the maximum value.""" - return self._with_node(ExprNode(ExprKind.ORDERABLE_AGGREGATION, "arg_max")) + return self._append_node(ExprNode(ExprKind.ORDERABLE_AGGREGATION, "arg_max")) def arg_min(self) -> Self: """Returns the index of the minimum value.""" - return self._with_node(ExprNode(ExprKind.ORDERABLE_AGGREGATION, "arg_min")) + return self._append_node(ExprNode(ExprKind.ORDERABLE_AGGREGATION, "arg_min")) def arg_true(self) -> Self: """Find elements where boolean expression is True.""" - return self._with_node(ExprNode(ExprKind.ORDERABLE_FILTRATION, "arg_true")) + return self._append_node(ExprNode(ExprKind.ORDERABLE_FILTRATION, "arg_true")) def sample( self, @@ -437,7 +437,7 @@ def sample( seed: Seed for the random number generator. If set to None (default), a random seed is generated for each sample operation. """ - return self._with_node( + return self._append_node( ExprNode( ExprKind.FILTRATION, "sample", From fd8789847e8d68bb8aea58f8505e2e4ffb53050e Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Thu, 16 Oct 2025 11:00:26 +0100 Subject: [PATCH 78/95] simplify ExprMetadata.from_node --- narwhals/_expression_parsing.py | 33 ++++++++++++++++----------------- 1 file changed, 16 insertions(+), 17 deletions(-) diff --git a/narwhals/_expression_parsing.py b/narwhals/_expression_parsing.py index 34ef3dd5d1..dda1089cd9 100644 --- a/narwhals/_expression_parsing.py +++ b/narwhals/_expression_parsing.py @@ -277,11 +277,11 @@ def _push_down_over_node_in_place( elif over_node_order_by and any( expr_node.is_orderable() for expr_node in expr._nodes ): - exprs.append(expr._append_node(over_node)) + exprs.append(expr._with_over_node(over_node)) elif over_node_partition_by and not all( expr_node.is_elementwise() for expr_node in expr._nodes ): - exprs.append(expr._append_node(over_node_without_order_by)) + exprs.append(expr._with_over_node(over_node_without_order_by)) else: # If there's no `partition_by`, then `over_node_without_order_by` is a no-op. exprs.append(expr) @@ -402,12 +402,10 @@ def iter_nodes_reversed(self) -> Iterator[ExprNode]: current = current.prev @classmethod - def from_node( # noqa: PLR0911 - cls, node: ExprNode, *ces: CompliantExprAny - ) -> ExprMetadata: + def from_node(cls, node: ExprNode, *ces: CompliantExprAny) -> ExprMetadata: kind = node.kind - if kind is ExprKind.SERIES: - return cls.from_selector_single(node) + if kind in KIND_TO_METADATA_CONSTRUCTOR: + return KIND_TO_METADATA_CONSTRUCTOR[kind](node, *ces) if kind is ExprKind.COL: return ( cls.from_selector_single(node) @@ -420,16 +418,6 @@ def from_node( # noqa: PLR0911 if len(node.kwargs["indices"]) == 1 else cls.from_selector_multi_unnamed(node) ) - if kind in {ExprKind.ALL, ExprKind.EXCLUDE}: - return cls.from_selector_multi_unnamed(node) - if kind is ExprKind.AGGREGATION: - return cls.from_aggregation(node) - if kind is ExprKind.LITERAL: - return cls.from_literal(node) - if kind is ExprKind.SELECTOR: - return cls.from_selector_multi_unnamed(node) - if kind is ExprKind.ELEMENTWISE: - return cls.from_elementwise(node, *ces) msg = f"Unexpected node kind: {kind}" # pragma: no cover raise AssertionError(msg) # pragma: no cover @@ -679,6 +667,17 @@ def op_nodes_reversed(self) -> Iterator[ExprNode]: yield node +KIND_TO_METADATA_CONSTRUCTOR = { + ExprKind.SERIES: ExprMetadata.from_selector_single, + ExprKind.ALL: ExprMetadata.from_selector_multi_unnamed, + ExprKind.EXCLUDE: ExprMetadata.from_selector_multi_unnamed, + ExprKind.LITERAL: ExprMetadata.from_literal, + ExprKind.SELECTOR: ExprMetadata.from_selector_multi_unnamed, + ExprKind.ELEMENTWISE: ExprMetadata.from_elementwise, + ExprKind.AGGREGATION: ExprMetadata.from_aggregation, +} + + def combine_metadata( *args: CompliantExprAny, to_single_output: bool, From 0840868621e488efad687a01d4c6c9908787038d Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Thu, 16 Oct 2025 11:10:46 +0100 Subject: [PATCH 79/95] simplify further --- narwhals/_expression_parsing.py | 123 ++++++++++++++++++-------------- 1 file changed, 68 insertions(+), 55 deletions(-) diff --git a/narwhals/_expression_parsing.py b/narwhals/_expression_parsing.py index dda1089cd9..8354d92b24 100644 --- a/narwhals/_expression_parsing.py +++ b/narwhals/_expression_parsing.py @@ -403,53 +403,12 @@ def iter_nodes_reversed(self) -> Iterator[ExprNode]: @classmethod def from_node(cls, node: ExprNode, *ces: CompliantExprAny) -> ExprMetadata: - kind = node.kind - if kind in KIND_TO_METADATA_CONSTRUCTOR: - return KIND_TO_METADATA_CONSTRUCTOR[kind](node, *ces) - if kind is ExprKind.COL: - return ( - cls.from_selector_single(node) - if len(node.kwargs["names"]) == 1 - else cls.from_selector_multi_named(node) - ) - if kind is ExprKind.NTH: - return ( - cls.from_selector_single(node) - if len(node.kwargs["indices"]) == 1 - else cls.from_selector_multi_unnamed(node) - ) - msg = f"Unexpected node kind: {kind}" # pragma: no cover - raise AssertionError(msg) # pragma: no cover + return KIND_TO_METADATA_CONSTRUCTOR[node.kind](node, *ces) - def with_node( # noqa: PLR0911,C901 + def with_node( self, node: ExprNode, ce: CompliantExprAny, *ces: CompliantExprAny ) -> ExprMetadata: - kind = node.kind - if kind is ExprKind.AGGREGATION: - return self.with_aggregation(node) - if kind is ExprKind.ELEMENTWISE: - return combine_metadata( - ce, *ces, to_single_output=False, current_node=node, prev=ce._metadata - ) - if kind is ExprKind.FILTRATION: - return self.with_filtration(node) - if kind is ExprKind.ORDERABLE_WINDOW: - return self.with_orderable_window(node) - if kind is ExprKind.ORDERABLE_FILTRATION: - return self.with_orderable_filtration(node) - if kind is ExprKind.ORDERABLE_AGGREGATION: - return self.with_orderable_aggregation(node) - if kind is ExprKind.WINDOW: - return self.with_window(node) - if kind is ExprKind.OVER: - if node.kwargs["order_by"]: - return self.with_ordered_over(node) - if not node.kwargs["partition_by"]: # pragma: no cover - msg = "At least one of `partition_by` or `order_by` must be specified." - raise InvalidOperationError(msg) - return self.with_partitioned_over(node) - msg = f"Unexpected node kind: {kind}" # pragma: no cover - raise AssertionError(msg) # pragma: no cover + return KIND_TO_METADATA_UPDATER[node.kind](self, node, ce, *ces) @classmethod def from_aggregation(cls, node: ExprNode) -> ExprMetadata: @@ -475,10 +434,26 @@ def from_literal(cls, node: ExprNode) -> ExprMetadata: ) @classmethod - def from_selector_single(cls, node: ExprNode) -> ExprMetadata: - # e.g. `nw.col('a')`, `nw.nth(0)` + def from_series(cls, node: ExprNode) -> ExprMetadata: return cls(ExpansionKind.SINGLE, current_node=node, prev=None) + @classmethod + def from_col(cls, node: ExprNode) -> ExprMetadata: + # e.g. `nw.col('a')`, `nw.nth(0)` + return ( + cls(ExpansionKind.SINGLE, current_node=node, prev=None) + if len(node.kwargs["names"]) == 1 + else cls.from_selector_multi_named(node) + ) + + @classmethod + def from_nth(cls, node: ExprNode) -> ExprMetadata: + return ( + cls(ExpansionKind.SINGLE, current_node=node, prev=None) + if len(node.kwargs["indices"]) == 1 + else cls.from_selector_multi_unnamed(node) + ) + @classmethod def from_selector_multi_named(cls, node: ExprNode) -> ExprMetadata: # e.g. `nw.col('a', 'b')` @@ -497,7 +472,7 @@ def from_elementwise(cls, node: ExprNode, *ces: CompliantExprAny) -> ExprMetadat def is_filtration(self) -> bool: return not self.preserves_length and not self.is_scalar_like - def with_aggregation(self, node: ExprNode) -> ExprMetadata: + def with_aggregation(self, node: ExprNode, _ce: CompliantExprAny) -> ExprMetadata: if self.is_scalar_like: msg = "Can't apply aggregations to scalar-like expressions." raise InvalidOperationError(msg) @@ -513,7 +488,16 @@ def with_aggregation(self, node: ExprNode) -> ExprMetadata: prev=self, ) - def with_orderable_aggregation(self, node: ExprNode) -> ExprMetadata: + def with_elementwise( + self, node: ExprNode, ce: CompliantExprAny, *ces: CompliantExprAny + ) -> ExprMetadata: + return combine_metadata( + ce, *ces, to_single_output=False, current_node=node, prev=ce._metadata + ) + + def with_orderable_aggregation( + self, node: ExprNode, _ce: CompliantExprAny + ) -> ExprMetadata: # Deprecated, used only in stable.v1. if self.is_scalar_like: # pragma: no cover msg = "Can't apply aggregations to scalar-like expressions." @@ -530,7 +514,7 @@ def with_orderable_aggregation(self, node: ExprNode) -> ExprMetadata: prev=self, ) - def with_window(self, node: ExprNode) -> ExprMetadata: + def with_window(self, node: ExprNode, _ce: CompliantExprAny) -> ExprMetadata: # Window function which may (but doesn't have to) be used with `over(order_by=...)`. if self.is_scalar_like: msg = "Can't apply window (e.g. `rank`) to scalar-like expression." @@ -549,7 +533,17 @@ def with_window(self, node: ExprNode) -> ExprMetadata: prev=self, ) - def with_orderable_window(self, node: ExprNode) -> ExprMetadata: + def with_over(self, node: ExprNode, _ce: CompliantExprAny) -> ExprMetadata: + if node.kwargs["order_by"]: + return self.with_ordered_over(node, _ce) + if not node.kwargs["partition_by"]: # pragma: no cover + msg = "At least one of `partition_by` or `order_by` must be specified." + raise InvalidOperationError(msg) + return self.with_partitioned_over(node, _ce) + + def with_orderable_window( + self, node: ExprNode, _ce: CompliantExprAny + ) -> ExprMetadata: # Window function which must be used with `over(order_by=...)`. if self.is_scalar_like: msg = "Can't apply orderable window (e.g. `diff`, `shift`) to scalar-like expression." @@ -566,7 +560,7 @@ def with_orderable_window(self, node: ExprNode) -> ExprMetadata: prev=self, ) - def with_ordered_over(self, node: ExprNode) -> ExprMetadata: + def with_ordered_over(self, node: ExprNode, _ce: CompliantExprAny) -> ExprMetadata: if self.has_windows: msg = "Cannot nest `over` statements." raise InvalidOperationError(msg) @@ -605,7 +599,9 @@ def with_ordered_over(self, node: ExprNode) -> ExprMetadata: prev=self, ) - def with_partitioned_over(self, node: ExprNode) -> ExprMetadata: + def with_partitioned_over( + self, node: ExprNode, _ce: CompliantExprAny + ) -> ExprMetadata: if self.has_windows: msg = "Cannot nest `over` statements." raise InvalidOperationError(msg) @@ -627,7 +623,9 @@ def with_partitioned_over(self, node: ExprNode) -> ExprMetadata: prev=self, ) - def with_filtration(self, node: ExprNode) -> ExprMetadata: + def with_filtration( + self, node: ExprNode, _ce: CompliantExprAny, *_ces: CompliantExprAny + ) -> ExprMetadata: if self.is_scalar_like: msg = "Can't apply filtration (e.g. `drop_nulls`) to scalar-like expression." raise InvalidOperationError(msg) @@ -643,7 +641,9 @@ def with_filtration(self, node: ExprNode) -> ExprMetadata: prev=self, ) - def with_orderable_filtration(self, node: ExprNode) -> ExprMetadata: + def with_orderable_filtration( + self, node: ExprNode, _ce: CompliantExprAny + ) -> ExprMetadata: if self.is_scalar_like: msg = "Can't apply filtration (e.g. `drop_nulls`) to scalar-like expression." raise InvalidOperationError(msg) @@ -668,7 +668,9 @@ def op_nodes_reversed(self) -> Iterator[ExprNode]: KIND_TO_METADATA_CONSTRUCTOR = { - ExprKind.SERIES: ExprMetadata.from_selector_single, + ExprKind.SERIES: ExprMetadata.from_series, + ExprKind.COL: ExprMetadata.from_col, + ExprKind.NTH: ExprMetadata.from_nth, ExprKind.ALL: ExprMetadata.from_selector_multi_unnamed, ExprKind.EXCLUDE: ExprMetadata.from_selector_multi_unnamed, ExprKind.LITERAL: ExprMetadata.from_literal, @@ -677,6 +679,17 @@ def op_nodes_reversed(self) -> Iterator[ExprNode]: ExprKind.AGGREGATION: ExprMetadata.from_aggregation, } +KIND_TO_METADATA_UPDATER = { + ExprKind.AGGREGATION: ExprMetadata.with_aggregation, + ExprKind.ELEMENTWISE: ExprMetadata.with_elementwise, + ExprKind.FILTRATION: ExprMetadata.with_filtration, + ExprKind.ORDERABLE_WINDOW: ExprMetadata.with_orderable_window, + ExprKind.ORDERABLE_FILTRATION: ExprMetadata.with_orderable_filtration, + ExprKind.ORDERABLE_AGGREGATION: ExprMetadata.with_orderable_aggregation, + ExprKind.WINDOW: ExprMetadata.with_window, + ExprKind.OVER: ExprMetadata.with_over, +} + def combine_metadata( *args: CompliantExprAny, From a2dbd2e7deab2ba75d391447edb1412f20ccb354 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Thu, 16 Oct 2025 11:39:25 +0100 Subject: [PATCH 80/95] clearer names --- narwhals/_expression_parsing.py | 58 +++++++++++++++++++++------------ 1 file changed, 38 insertions(+), 20 deletions(-) diff --git a/narwhals/_expression_parsing.py b/narwhals/_expression_parsing.py index 8354d92b24..79b122a15b 100644 --- a/narwhals/_expression_parsing.py +++ b/narwhals/_expression_parsing.py @@ -402,13 +402,20 @@ def iter_nodes_reversed(self) -> Iterator[ExprNode]: current = current.prev @classmethod - def from_node(cls, node: ExprNode, *ces: CompliantExprAny) -> ExprMetadata: - return KIND_TO_METADATA_CONSTRUCTOR[node.kind](node, *ces) + def from_node( + cls, node: ExprNode, *compliant_exprs: CompliantExprAny + ) -> ExprMetadata: + return KIND_TO_METADATA_CONSTRUCTOR[node.kind](node, *compliant_exprs) def with_node( - self, node: ExprNode, ce: CompliantExprAny, *ces: CompliantExprAny + self, + node: ExprNode, + compliant_expr: CompliantExprAny, + *compliant_expr_args: CompliantExprAny, ) -> ExprMetadata: - return KIND_TO_METADATA_UPDATER[node.kind](self, node, ce, *ces) + return KIND_TO_METADATA_UPDATER[node.kind]( + self, node, compliant_expr, *compliant_expr_args + ) @classmethod def from_aggregation(cls, node: ExprNode) -> ExprMetadata: @@ -465,8 +472,12 @@ def from_selector_multi_unnamed(cls, node: ExprNode) -> ExprMetadata: return cls(ExpansionKind.MULTI_UNNAMED, current_node=node, prev=None) @classmethod - def from_elementwise(cls, node: ExprNode, *ces: CompliantExprAny) -> ExprMetadata: - return combine_metadata(*ces, to_single_output=True, current_node=node, prev=None) + def from_elementwise( + cls, node: ExprNode, *compliant_exprs: CompliantExprAny + ) -> ExprMetadata: + return combine_metadata( + *compliant_exprs, to_single_output=True, current_node=node, prev=None + ) @property def is_filtration(self) -> bool: @@ -489,10 +500,17 @@ def with_aggregation(self, node: ExprNode, _ce: CompliantExprAny) -> ExprMetadat ) def with_elementwise( - self, node: ExprNode, ce: CompliantExprAny, *ces: CompliantExprAny + self, + node: ExprNode, + compliant_expr: CompliantExprAny, + *compliant_expr_args: CompliantExprAny, ) -> ExprMetadata: return combine_metadata( - ce, *ces, to_single_output=False, current_node=node, prev=ce._metadata + compliant_expr, + *compliant_expr_args, + to_single_output=False, + current_node=node, + prev=compliant_expr._metadata, ) def with_orderable_aggregation( @@ -692,7 +710,7 @@ def op_nodes_reversed(self) -> Iterator[ExprNode]: def combine_metadata( - *args: CompliantExprAny, + *compliant_exprs: CompliantExprAny, to_single_output: bool, current_node: ExprNode, prev: ExprMetadata | None, @@ -700,7 +718,7 @@ def combine_metadata( """Combine metadata from `args`. Arguments: - args: Arguments, maybe expressions, literals, or Series. + compliant_exprs: Expression arguments. to_single_output: Whether the result is always single-output, regardless of the inputs (e.g. `nw.sum_horizontal`). current_node: The current node being added. @@ -719,8 +737,8 @@ def combine_metadata( # result is literal if all inputs are literal result_is_literal = True - for i, arg in enumerate(args): - metadata = arg._metadata + for i, ce in enumerate(compliant_exprs): + metadata = ce._metadata assert metadata is not None # noqa: S101 if metadata.expansion_kind.is_multi_output(): expansion_kind = metadata.expansion_kind @@ -805,10 +823,10 @@ def evaluate_into_exprs( yield ret -def maybe_broadcast_ces(*ces: CompliantExprAny) -> list[CompliantExprAny]: - broadcast = any(not is_scalar_like(ce) for ce in ces) +def maybe_broadcast_ces(*compliant_exprs: CompliantExprAny) -> list[CompliantExprAny]: + broadcast = any(not is_scalar_like(ce) for ce in compliant_exprs) results: list[CompliantExprAny] = [] - for compliant_expr in ces: + for compliant_expr in compliant_exprs: if broadcast and is_scalar_like(compliant_expr): _compliant_expr: CompliantExprAny = compliant_expr.broadcast() # Make sure to preserve metadata. @@ -849,7 +867,7 @@ def evaluate_node( compliant_expr: CompliantExprAny, node: ExprNode, ns: CompliantNamespaceAny ) -> CompliantExprAny: md: ExprMetadata = compliant_expr._metadata - ce, *ces = maybe_broadcast_ces( + compliant_expr, *compliant_expr_args = maybe_broadcast_ces( compliant_expr, *evaluate_into_exprs( *node.exprs, @@ -858,12 +876,12 @@ def evaluate_node( allow_multi_output=node.allow_multi_output, ), ) - md = md.with_node(node, ce, *ces) + md = md.with_node(node, compliant_expr, *compliant_expr_args) if "." in node.name: accessor, method = node.name.split(".") - func = getattr(getattr(ce, accessor), method) + func = getattr(getattr(compliant_expr, accessor), method) else: - func = getattr(ce, node.name) - ret = cast("CompliantExprAny", func(*ces, **node.kwargs)) + func = getattr(compliant_expr, node.name) + ret = cast("CompliantExprAny", func(*compliant_expr_args, **node.kwargs)) ret._opt_metadata = md return ret From c9c46c04fdbc751bc7dcd795ae67e446046d2873 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Thu, 16 Oct 2025 11:45:45 +0100 Subject: [PATCH 81/95] dask fixup --- narwhals/_dask/expr.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/narwhals/_dask/expr.py b/narwhals/_dask/expr.py index 9998ba887f..0a1abe929d 100644 --- a/narwhals/_dask/expr.py +++ b/narwhals/_dask/expr.py @@ -406,12 +406,12 @@ def round(self, decimals: int) -> Self: def floor(self) -> Self: import dask.array as da - return self._with_callable(da.floor, "floor") + return self._with_callable(da.floor) def ceil(self) -> Self: import dask.array as da - return self._with_callable(da.ceil, "ceil") + return self._with_callable(da.ceil) def unique(self) -> Self: return self._with_callable(lambda expr: expr.unique()) From 33e3078d2cd0ed33224a9a6dc243ccf879a0d254 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Thu, 16 Oct 2025 12:14:10 +0100 Subject: [PATCH 82/95] typing --- narwhals/_expression_parsing.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/narwhals/_expression_parsing.py b/narwhals/_expression_parsing.py index 79b122a15b..f229e11e22 100644 --- a/narwhals/_expression_parsing.py +++ b/narwhals/_expression_parsing.py @@ -5,7 +5,7 @@ from __future__ import annotations from enum import Enum, auto -from typing import TYPE_CHECKING, Any, Literal, cast +from typing import TYPE_CHECKING, Any, Callable, Literal, cast from narwhals._utils import zip_strict from narwhals.dependencies import is_numpy_array_1d @@ -685,27 +685,27 @@ def op_nodes_reversed(self) -> Iterator[ExprNode]: yield node -KIND_TO_METADATA_CONSTRUCTOR = { - ExprKind.SERIES: ExprMetadata.from_series, - ExprKind.COL: ExprMetadata.from_col, - ExprKind.NTH: ExprMetadata.from_nth, +KIND_TO_METADATA_CONSTRUCTOR: dict[ExprKind, Callable[[ExprNode], ExprMetadata]] = { + ExprKind.AGGREGATION: ExprMetadata.from_aggregation, ExprKind.ALL: ExprMetadata.from_selector_multi_unnamed, + ExprKind.ELEMENTWISE: ExprMetadata.from_elementwise, ExprKind.EXCLUDE: ExprMetadata.from_selector_multi_unnamed, + ExprKind.SERIES: ExprMetadata.from_series, + ExprKind.COL: ExprMetadata.from_col, ExprKind.LITERAL: ExprMetadata.from_literal, + ExprKind.NTH: ExprMetadata.from_nth, ExprKind.SELECTOR: ExprMetadata.from_selector_multi_unnamed, - ExprKind.ELEMENTWISE: ExprMetadata.from_elementwise, - ExprKind.AGGREGATION: ExprMetadata.from_aggregation, } -KIND_TO_METADATA_UPDATER = { +KIND_TO_METADATA_UPDATER: dict[ExprKind, Callable[..., ExprMetadata]] = { ExprKind.AGGREGATION: ExprMetadata.with_aggregation, ExprKind.ELEMENTWISE: ExprMetadata.with_elementwise, ExprKind.FILTRATION: ExprMetadata.with_filtration, - ExprKind.ORDERABLE_WINDOW: ExprMetadata.with_orderable_window, - ExprKind.ORDERABLE_FILTRATION: ExprMetadata.with_orderable_filtration, ExprKind.ORDERABLE_AGGREGATION: ExprMetadata.with_orderable_aggregation, - ExprKind.WINDOW: ExprMetadata.with_window, + ExprKind.ORDERABLE_FILTRATION: ExprMetadata.with_orderable_filtration, ExprKind.OVER: ExprMetadata.with_over, + ExprKind.ORDERABLE_WINDOW: ExprMetadata.with_orderable_window, + ExprKind.WINDOW: ExprMetadata.with_window, } From 90468deef299118a9a80568e8552e858eb65f866 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Thu, 16 Oct 2025 13:57:38 +0100 Subject: [PATCH 83/95] raise developer-facing assertionerror in _metadata --- narwhals/_compliant/expr.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/narwhals/_compliant/expr.py b/narwhals/_compliant/expr.py index 249ad4d3b6..9c8df85864 100644 --- a/narwhals/_compliant/expr.py +++ b/narwhals/_compliant/expr.py @@ -95,7 +95,15 @@ class CompliantExpr( @property def _metadata(self) -> ExprMetadata: - assert self._opt_metadata is not None # noqa: S101 + if self._opt_metadata is None: + msg = ( + "`_opt_metadata` is None. This is usually the result of trying to do " + "some operation (such as `over`) which requires access to the metadata " + "at the compliant level. You may want to consider rewriting your logic " + "so that this operation is not necessary. Ideally you should avoid " + "setting `_opt_metadata` manually." + ) + raise AssertionError(msg) return self._opt_metadata def __call__( From a657cbbf19b0023ec48ba36429a99c0dcdc719e6 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Thu, 16 Oct 2025 14:05:35 +0100 Subject: [PATCH 84/95] cvg --- narwhals/_compliant/expr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/narwhals/_compliant/expr.py b/narwhals/_compliant/expr.py index 9c8df85864..997746d429 100644 --- a/narwhals/_compliant/expr.py +++ b/narwhals/_compliant/expr.py @@ -95,7 +95,7 @@ class CompliantExpr( @property def _metadata(self) -> ExprMetadata: - if self._opt_metadata is None: + if self._opt_metadata is None: # pragma: no cover msg = ( "`_opt_metadata` is None. This is usually the result of trying to do " "some operation (such as `over`) which requires access to the metadata " From 05928049abffb9de651afbffd5c443d681f2e66c Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Fri, 17 Oct 2025 08:47:19 +0100 Subject: [PATCH 85/95] correctly respect arguments metadata in `with_filtration`, add test --- narwhals/_expression_parsing.py | 8 +++++--- narwhals/_sql/expr.py | 1 + tests/expression_parsing_test.py | 1 + 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/narwhals/_expression_parsing.py b/narwhals/_expression_parsing.py index f229e11e22..17348e9951 100644 --- a/narwhals/_expression_parsing.py +++ b/narwhals/_expression_parsing.py @@ -642,15 +642,17 @@ def with_partitioned_over( ) def with_filtration( - self, node: ExprNode, _ce: CompliantExprAny, *_ces: CompliantExprAny + self, node: ExprNode, *compliant_exprs: CompliantExprAny ) -> ExprMetadata: if self.is_scalar_like: msg = "Can't apply filtration (e.g. `drop_nulls`) to scalar-like expression." raise InvalidOperationError(msg) + result_has_windows = any(x._metadata.has_windows for x in compliant_exprs) + result_n_orderable_ops = sum(x._metadata.n_orderable_ops for x in compliant_exprs) return ExprMetadata( self.expansion_kind, - has_windows=self.has_windows, - n_orderable_ops=self.n_orderable_ops, + has_windows=result_has_windows, + n_orderable_ops=result_n_orderable_ops, preserves_length=False, is_elementwise=False, is_scalar_like=False, diff --git a/narwhals/_sql/expr.py b/narwhals/_sql/expr.py index a428bb0b78..7881b56324 100644 --- a/narwhals/_sql/expr.py +++ b/narwhals/_sql/expr.py @@ -845,4 +845,5 @@ def str(self) -> SQLExprStringNamespace[Self]: ... def dt(self) -> SQLExprDateTimeNamesSpace[Self]: ... drop_nulls = not_implemented() # type: ignore[misc] + filter = not_implemented() # type: ignore[misc] unique = not_implemented() # type: ignore[misc] diff --git a/tests/expression_parsing_test.py b/tests/expression_parsing_test.py index b5c6b722f1..046bbd1c16 100644 --- a/tests/expression_parsing_test.py +++ b/tests/expression_parsing_test.py @@ -104,6 +104,7 @@ def test_per_group_broadcasting( nw.col("a").drop_nulls().over("b"), nw.col("a").drop_nulls().over("b", order_by="i"), nw.col("a").diff().drop_nulls().over("b", order_by="i"), + nw.col("a").filter(nw.col("b").sum().over("c") > 1).sum().over("d"), ], ) def test_invalid_operations(constructor: Constructor, expr: nw.Expr) -> None: From 5d150f38085653ada8aab455bd1bae0a709933b4 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Fri, 17 Oct 2025 09:12:40 +0100 Subject: [PATCH 86/95] mark `filter` not implemented for dask --- narwhals/_compliant/column.py | 1 + narwhals/_compliant/series.py | 1 - narwhals/_dask/expr.py | 3 ++- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/narwhals/_compliant/column.py b/narwhals/_compliant/column.py index 5561a48a4d..96fd9040ac 100644 --- a/narwhals/_compliant/column.py +++ b/narwhals/_compliant/column.py @@ -86,6 +86,7 @@ def fill_nan(self, value: float | None) -> Self: ... def fill_null( self, value: Self | None, strategy: FillNullStrategy | None, limit: int | None ) -> Self: ... + def filter(self, predicate: Self) -> Self: ... def is_between( self, lower_bound: Self, upper_bound: Self, closed: ClosedInterval ) -> Self: diff --git a/narwhals/_compliant/series.py b/narwhals/_compliant/series.py index 434b2dc760..f9c750b50a 100644 --- a/narwhals/_compliant/series.py +++ b/narwhals/_compliant/series.py @@ -132,7 +132,6 @@ def arg_max(self) -> int: ... def arg_min(self) -> int: ... def arg_true(self) -> Self: ... def count(self) -> int: ... - def filter(self, predicate: Any) -> Self: ... def first(self) -> PythonLiteral: ... def last(self) -> PythonLiteral: ... def gather_every(self, n: int, offset: int) -> Self: ... diff --git a/narwhals/_dask/expr.py b/narwhals/_dask/expr.py index 0a1abe929d..0bbb89b6d8 100644 --- a/narwhals/_dask/expr.py +++ b/narwhals/_dask/expr.py @@ -689,8 +689,9 @@ def str(self) -> DaskExprStringNamespace: def dt(self) -> DaskExprDateTimeNamespace: return DaskExprDateTimeNamespace(self) - rank = not_implemented() + filter = not_implemented() first = not_implemented() + rank = not_implemented() last = not_implemented() # namespaces From 739e0d5a39d3c0579ef7d9a5b4b3fb351067c337 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Fri, 17 Oct 2025 09:15:27 +0100 Subject: [PATCH 87/95] fixup --- narwhals/_compliant/column.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/narwhals/_compliant/column.py b/narwhals/_compliant/column.py index 96fd9040ac..c945706fdd 100644 --- a/narwhals/_compliant/column.py +++ b/narwhals/_compliant/column.py @@ -86,7 +86,7 @@ def fill_nan(self, value: float | None) -> Self: ... def fill_null( self, value: Self | None, strategy: FillNullStrategy | None, limit: int | None ) -> Self: ... - def filter(self, predicate: Self) -> Self: ... + def filter(self, *predicates: Self) -> Self: ... def is_between( self, lower_bound: Self, upper_bound: Self, closed: ClosedInterval ) -> Self: From 34229c1e6fd5d0b3cf87badcf7254809d1390ba6 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Fri, 17 Oct 2025 09:55:47 +0100 Subject: [PATCH 88/95] keep `filter` in `CompliantSeries` for now --- narwhals/_compliant/column.py | 1 - narwhals/_compliant/series.py | 1 + 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/narwhals/_compliant/column.py b/narwhals/_compliant/column.py index c945706fdd..5561a48a4d 100644 --- a/narwhals/_compliant/column.py +++ b/narwhals/_compliant/column.py @@ -86,7 +86,6 @@ def fill_nan(self, value: float | None) -> Self: ... def fill_null( self, value: Self | None, strategy: FillNullStrategy | None, limit: int | None ) -> Self: ... - def filter(self, *predicates: Self) -> Self: ... def is_between( self, lower_bound: Self, upper_bound: Self, closed: ClosedInterval ) -> Self: diff --git a/narwhals/_compliant/series.py b/narwhals/_compliant/series.py index f9c750b50a..434b2dc760 100644 --- a/narwhals/_compliant/series.py +++ b/narwhals/_compliant/series.py @@ -132,6 +132,7 @@ def arg_max(self) -> int: ... def arg_min(self) -> int: ... def arg_true(self) -> Self: ... def count(self) -> int: ... + def filter(self, predicate: Any) -> Self: ... def first(self) -> PythonLiteral: ... def last(self) -> PythonLiteral: ... def gather_every(self, n: int, offset: int) -> Self: ... From 4e8f9ac566443b28415932f161324103704b5100 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sun, 26 Oct 2025 12:59:39 +0000 Subject: [PATCH 89/95] assorted simplifications --- narwhals/_compliant/dataframe.py | 2 +- narwhals/_compliant/group_by.py | 2 +- narwhals/_dask/expr_str.py | 13 +------------ narwhals/_dask/namespace.py | 12 ++++++++---- narwhals/_duckdb/utils.py | 6 ++---- narwhals/_pandas_like/group_by.py | 2 +- narwhals/_polars/namespace.py | 10 +++------- narwhals/stable/v1/__init__.py | 4 ++-- 8 files changed, 19 insertions(+), 32 deletions(-) diff --git a/narwhals/_compliant/dataframe.py b/narwhals/_compliant/dataframe.py index dcc8bc6a58..03361545bf 100644 --- a/narwhals/_compliant/dataframe.py +++ b/narwhals/_compliant/dataframe.py @@ -362,7 +362,7 @@ def _evaluate_exprs(self, *exprs: EagerExprT) -> Sequence[EagerSeriesT]: # NOTE: Ignore intermittent [False Negative] # Argument of type "EagerExprT@EagerDataFrame" cannot be assigned to parameter "expr" of type "EagerExprT@EagerDataFrame" in function "_evaluate_into_expr" # Type "EagerExprT@EagerDataFrame" is not assignable to type "EagerExprT@EagerDataFrame" - return list(chain.from_iterable(self._evaluate_expr(expr) for expr in exprs)) # pyright: ignore[reportArgumentType] + return tuple(chain.from_iterable(self._evaluate_expr(expr) for expr in exprs)) # pyright: ignore[reportArgumentType] def _evaluate_expr(self, expr: EagerExprT, /) -> Sequence[EagerSeriesT]: """Return list of raw columns. diff --git a/narwhals/_compliant/group_by.py b/narwhals/_compliant/group_by.py index de9cf867e8..cb62edf7ac 100644 --- a/narwhals/_compliant/group_by.py +++ b/narwhals/_compliant/group_by.py @@ -170,7 +170,7 @@ def _leaf_name(cls, expr: DepthTrackingExprAny, /) -> NarwhalsAggregation | Any: @classmethod def _kwargs(cls, expr: DepthTrackingExprAny, /) -> dict[str, Any]: - """Return the last function name in the chain defined by `expr`.""" + """Return the last function kwargs in the chain defined by `expr`.""" return next(expr._metadata.op_nodes_reversed()).kwargs diff --git a/narwhals/_dask/expr_str.py b/narwhals/_dask/expr_str.py index 677761c58a..0a5e036e3c 100644 --- a/narwhals/_dask/expr_str.py +++ b/narwhals/_dask/expr_str.py @@ -34,18 +34,7 @@ def _replace(expr: dx.Series, value: dx.Series) -> dx.Series: return self.compliant._with_callable(_replace, value=value) def replace_all(self, value: DaskExpr, pattern: str, *, literal: bool) -> DaskExpr: - if not value._metadata.is_literal: - msg = ( - "dask backed `Expr.str.replace_all` only supports str replacement values" - ) - raise TypeError(msg) - - def _replace_all(expr: dx.Series, value: dx.Series) -> dx.Series: - return expr.str.replace( # pyright: ignore[reportAttributeAccessIssue] - pattern, value.compute(), regex=not literal, n=-1 - ) - - return self.compliant._with_callable(_replace_all, value=value) + return self.replace(value, pattern, literal=literal, n=-1) def strip_chars(self, characters: str | None) -> DaskExpr: return self.compliant._with_callable(lambda expr: expr.str.strip(characters)) diff --git a/narwhals/_dask/namespace.py b/narwhals/_dask/namespace.py index 1f39f74fd4..c4791a7e0d 100644 --- a/narwhals/_dask/namespace.py +++ b/narwhals/_dask/namespace.py @@ -273,10 +273,14 @@ def when_then( self, predicate: DaskExpr, then: DaskExpr, otherwise: DaskExpr | None = None ) -> DaskExpr: def func(df: DaskLazyFrame) -> list[dx.Series]: - then_value = then(df)[0] - otherwise_value = otherwise(df)[0] if otherwise is not None else otherwise + then_value = df._evaluate_single_output_expr(then) + otherwise_value = ( + df._evaluate_single_output_expr(otherwise) + if otherwise is not None + else otherwise + ) - condition = predicate(df)[0] + condition = df._evaluate_single_output_expr(predicate) # re-evaluate DataFrame if the condition aggregates to force # then/otherwise to be evaluated against the aggregated frame if all( @@ -288,7 +292,7 @@ def func(df: DaskLazyFrame) -> list[dx.Series]: ) ): new_df = df._with_native(condition.to_frame()) - condition = predicate.broadcast()(df)[0] + condition = df._evaluate_single_output_expr(predicate.broadcast()) df = new_df if otherwise is None: diff --git a/narwhals/_duckdb/utils.py b/narwhals/_duckdb/utils.py index 8312acb1ba..37d86a88ba 100644 --- a/narwhals/_duckdb/utils.py +++ b/narwhals/_duckdb/utils.py @@ -45,10 +45,8 @@ col = duckdb.ColumnExpression """Alias for `duckdb.ColumnExpression`.""" - -def lit(value: object) -> duckdb.Expression: - return duckdb.ConstantExpression(value) - +lit = duckdb.ConstantExpression +"""Alias for `duckdb.ConstantExpression`.""" when = duckdb.CaseExpression """Alias for `duckdb.CaseExpression`.""" diff --git a/narwhals/_pandas_like/group_by.py b/narwhals/_pandas_like/group_by.py index 8eadb4f824..def77bd15b 100644 --- a/narwhals/_pandas_like/group_by.py +++ b/narwhals/_pandas_like/group_by.py @@ -122,7 +122,7 @@ def _getitem_aggs( ) elif self.is_mode(): compliant = group_by.compliant - node_kwargs = next(self.expr._metadata.op_nodes_reversed()).kwargs + node_kwargs = group_by._kwargs(self.expr) if (keep := node_kwargs.get("keep")) != "any": # pragma: no cover msg = ( f"`Expr.mode(keep='{keep}')` is not implemented in group by context for " diff --git a/narwhals/_polars/namespace.py b/narwhals/_polars/namespace.py index 145fa89ed1..ac8da364be 100644 --- a/narwhals/_polars/namespace.py +++ b/narwhals/_polars/namespace.py @@ -71,6 +71,9 @@ def _expr(self) -> type[PolarsExpr]: def _series(self) -> type[PolarsSeries]: return PolarsSeries + def is_native(self, obj: Any) -> TypeIs[pl.DataFrame | pl.LazyFrame | pl.Series]: + return isinstance(obj, (pl.DataFrame, pl.LazyFrame, pl.Series)) + @overload def from_native(self, data: pl.DataFrame, /) -> PolarsDataFrame: ... @overload @@ -211,13 +214,6 @@ def when_then( version=self._version, ) - def is_native(self, obj: Any, /) -> TypeIs[pl.DataFrame | pl.LazyFrame | pl.Series]: - return ( - self._dataframe._is_native(obj) - or self._series._is_native(obj) - or self._lazyframe._is_native(obj) - ) - # NOTE: Implementation is too different to annotate correctly (vs other `*SelectorNamespace`) # 1. Others have lots of private stuff for code reuse # i. None of that is useful here diff --git a/narwhals/stable/v1/__init__.py b/narwhals/stable/v1/__init__.py index 3e48fa6bc5..42d06ce928 100644 --- a/narwhals/stable/v1/__init__.py +++ b/narwhals/stable/v1/__init__.py @@ -373,11 +373,11 @@ def _l1_norm(self) -> Self: def head(self, n: int = 10) -> Self: r"""Get the first `n` rows.""" - return self._append_node(ExprNode(ExprKind.FILTRATION, "head", n=n)) + return self._append_node(ExprNode(ExprKind.ORDERABLE_FILTRATION, "head", n=n)) def tail(self, n: int = 10) -> Self: r"""Get the last `n` rows.""" - return self._append_node(ExprNode(ExprKind.FILTRATION, "tail", n=n)) + return self._append_node(ExprNode(ExprKind.ORDERABLE_FILTRATION, "tail", n=n)) def gather_every(self, n: int, offset: int = 0) -> Self: r"""Take every nth value in the Series and return as new Series. From 55a6b53d9edad5de1ba99b2ce5f3432316ba4b2e Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sun, 26 Oct 2025 14:55:30 +0000 Subject: [PATCH 90/95] import `col` top-level --- narwhals/functions.py | 25 ++++++++++++------------- narwhals/series.py | 3 +-- 2 files changed, 13 insertions(+), 15 deletions(-) diff --git a/narwhals/functions.py b/narwhals/functions.py index 1bbf178dea..28ab30a641 100644 --- a/narwhals/functions.py +++ b/narwhals/functions.py @@ -5,7 +5,7 @@ from collections.abc import Iterable, Mapping, Sequence from typing import TYPE_CHECKING, Any -from narwhals._expression_parsing import ExprKind, ExprNode +from narwhals._expression_parsing import ExprKind, ExprNode, is_expr, is_series from narwhals._utils import ( Implementation, Version, @@ -25,7 +25,6 @@ ) from narwhals.exceptions import InvalidOperationError from narwhals.expr import Expr -from narwhals.series import Series from narwhals.translate import from_native, to_native if TYPE_CHECKING: @@ -37,6 +36,7 @@ from narwhals._translate import IntoArrowTable from narwhals._typing import Backend, EagerAllowed, IntoBackend from narwhals.dataframe import DataFrame, LazyFrame + from narwhals.series import Series from narwhals.typing import ( ConcatMethod, FileSource, @@ -976,9 +976,7 @@ def exclude(*names: str | Iterable[str]) -> Expr: | └─────┘ | └──────────────────┘ """ - flat_names = flatten(names) - exclude_names = frozenset(flat_names) - return Expr(ExprNode(ExprKind.EXCLUDE, "exclude", names=exclude_names)) + return Expr(ExprNode(ExprKind.EXCLUDE, "exclude", names=frozenset(flatten(names)))) def nth(*indices: int | Sequence[int]) -> Expr: @@ -1240,8 +1238,7 @@ def sum_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: |└─────┴──────┴─────┘| └────────────────────┘ """ - flat_exprs = flatten(exprs) - return _expr_with_horizontal_op("sum_horizontal", *flat_exprs) + return _expr_with_horizontal_op("sum_horizontal", *flatten(exprs)) def min_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: @@ -1412,9 +1409,8 @@ def all_horizontal(*exprs: IntoExpr | Iterable[IntoExpr], ignore_nulls: bool) -> |all: [[false,false,true,null,false,null]]| └─────────────────────────────────────────┘ """ - flat_exprs = flatten(exprs) return _expr_with_horizontal_op( - "all_horizontal", *flat_exprs, ignore_nulls=ignore_nulls + "all_horizontal", *flatten(exprs), ignore_nulls=ignore_nulls ) @@ -1497,9 +1493,8 @@ def any_horizontal(*exprs: IntoExpr | Iterable[IntoExpr], ignore_nulls: bool) -> |└───────┴───────┴───────┘| └─────────────────────────┘ """ - flat_exprs = flatten(exprs) return _expr_with_horizontal_op( - "any_horizontal", *flat_exprs, ignore_nulls=ignore_nulls + "any_horizontal", *flatten(exprs), ignore_nulls=ignore_nulls ) @@ -1630,10 +1625,14 @@ def coalesce( """ flat_exprs = flatten([*flatten([exprs]), *more_exprs]) - non_exprs = [expr for expr in flat_exprs if not isinstance(expr, (str, Expr, Series))] + non_exprs = [ + expr + for expr in flat_exprs + if not (isinstance(expr, str) or is_expr(expr) or is_series(expr)) + ] if non_exprs: msg = ( - f"All arguments to `coalesce` must be of type {str!r}, {Expr!r}, or {Series!r}." + f"All arguments to `coalesce` must be of type {str!r}, Expr, or Series." "\nGot the following invalid arguments (type, value):" f"\n {', '.join(repr((type(e), e)) for e in non_exprs)}" ) diff --git a/narwhals/series.py b/narwhals/series.py index 270ec7a432..2e34b49913 100644 --- a/narwhals/series.py +++ b/narwhals/series.py @@ -31,6 +31,7 @@ from narwhals.dtypes import _validate_dtype, _validate_into_dtype from narwhals.exceptions import ComputeError, InvalidOperationError from narwhals.expr import Expr +from narwhals.functions import col from narwhals.series_cat import SeriesCatNamespace from narwhals.series_dt import SeriesDateTimeNamespace from narwhals.series_list import SeriesListNamespace @@ -2791,8 +2792,6 @@ def is_close( ] ] """ - from narwhals.functions import col - if not self.dtype.is_numeric(): msg = ( f"is_close operation not supported for dtype `{self.dtype}`\n\n" From 704280901d4194604d58117fd399757320284717 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sun, 26 Oct 2025 15:36:14 +0000 Subject: [PATCH 91/95] add extra mean rank test --- narwhals/_expression_parsing.py | 4 ++-- tests/expression_parsing_test.py | 13 +++++++++++++ 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/narwhals/_expression_parsing.py b/narwhals/_expression_parsing.py index 17348e9951..4d89e5a56e 100644 --- a/narwhals/_expression_parsing.py +++ b/narwhals/_expression_parsing.py @@ -147,9 +147,9 @@ def is_orderable(self) -> bool: # `diff`, `rank`, `arg_max`, ... return self in { ExprKind.ORDERABLE_WINDOW, - ExprKind.ORDERABLE_AGGREGATION, - ExprKind.FILTRATION, ExprKind.WINDOW, + ExprKind.ORDERABLE_AGGREGATION, + ExprKind.ORDERABLE_FILTRATION, } @property diff --git a/tests/expression_parsing_test.py b/tests/expression_parsing_test.py index 046bbd1c16..138063eb06 100644 --- a/tests/expression_parsing_test.py +++ b/tests/expression_parsing_test.py @@ -121,3 +121,16 @@ def test_invalid_elementwise_over() -> None: # This one raises before it's even evaluated. with pytest.raises(InvalidOperationError): nw.col("a").fill_null(3).over("b") + + +def test_rank_with_order_by_pushdown() -> None: + pytest.importorskip("pandas") + import pandas as pd + + df = nw.from_native(pd.DataFrame({"a": [1, 1, 2], "i": [2, 1, 0]})) + result = df.select( + "a", + res=nw.sum_horizontal(nw.col("a").rank("ordinal"), nw.lit(1)).over(order_by="i"), + ) + expected = {"a": [1, 1, 2], "res": [3.0, 2.0, 4.0]} + assert_equal_data(result, expected) From 18da6a59250abcd8c3aaf282b2f7e3472e956875 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sun, 26 Oct 2025 15:43:29 +0000 Subject: [PATCH 92/95] more coverage --- narwhals/_expression_parsing.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/narwhals/_expression_parsing.py b/narwhals/_expression_parsing.py index 4d89e5a56e..e1b63d8d2b 100644 --- a/narwhals/_expression_parsing.py +++ b/narwhals/_expression_parsing.py @@ -307,7 +307,9 @@ def is_elementwise(self) -> bool: if self._is_elementwise_cached is None: # Note: don't combine these if/then statements so that pytest-cov shows if # anything is uncovered. - if not self.kind.is_elementwise or not all( + if not self.kind.is_elementwise: # noqa: SIM114 + self._is_elementwise_cached = False + elif not all( all(node.is_elementwise() for node in expr._nodes) for expr in self.exprs if is_expr(expr) From 4d5347d3acd07f60fd83e2e7eb4510574dc1222d Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sun, 26 Oct 2025 15:44:36 +0000 Subject: [PATCH 93/95] de morgans laws simplification (i think) --- narwhals/_expression_parsing.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/narwhals/_expression_parsing.py b/narwhals/_expression_parsing.py index e1b63d8d2b..2bd68d7062 100644 --- a/narwhals/_expression_parsing.py +++ b/narwhals/_expression_parsing.py @@ -309,8 +309,8 @@ def is_elementwise(self) -> bool: # anything is uncovered. if not self.kind.is_elementwise: # noqa: SIM114 self._is_elementwise_cached = False - elif not all( - all(node.is_elementwise() for node in expr._nodes) + elif any( + any(not node.is_elementwise() for node in expr._nodes) for expr in self.exprs if is_expr(expr) ): From 97cc126e20b551ee0aea081f3816b645c99c5658 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sun, 26 Oct 2025 16:05:49 +0000 Subject: [PATCH 94/95] reduce diff --- narwhals/_expression_parsing.py | 44 ++++++++++++++++----------------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/narwhals/_expression_parsing.py b/narwhals/_expression_parsing.py index 2bd68d7062..6f1c3205b9 100644 --- a/narwhals/_expression_parsing.py +++ b/narwhals/_expression_parsing.py @@ -501,20 +501,6 @@ def with_aggregation(self, node: ExprNode, _ce: CompliantExprAny) -> ExprMetadat prev=self, ) - def with_elementwise( - self, - node: ExprNode, - compliant_expr: CompliantExprAny, - *compliant_expr_args: CompliantExprAny, - ) -> ExprMetadata: - return combine_metadata( - compliant_expr, - *compliant_expr_args, - to_single_output=False, - current_node=node, - prev=compliant_expr._metadata, - ) - def with_orderable_aggregation( self, node: ExprNode, _ce: CompliantExprAny ) -> ExprMetadata: @@ -534,6 +520,20 @@ def with_orderable_aggregation( prev=self, ) + def with_elementwise( + self, + node: ExprNode, + compliant_expr: CompliantExprAny, + *compliant_expr_args: CompliantExprAny, + ) -> ExprMetadata: + return combine_metadata( + compliant_expr, + *compliant_expr_args, + to_single_output=False, + current_node=node, + prev=compliant_expr._metadata, + ) + def with_window(self, node: ExprNode, _ce: CompliantExprAny) -> ExprMetadata: # Window function which may (but doesn't have to) be used with `over(order_by=...)`. if self.is_scalar_like: @@ -553,14 +553,6 @@ def with_window(self, node: ExprNode, _ce: CompliantExprAny) -> ExprMetadata: prev=self, ) - def with_over(self, node: ExprNode, _ce: CompliantExprAny) -> ExprMetadata: - if node.kwargs["order_by"]: - return self.with_ordered_over(node, _ce) - if not node.kwargs["partition_by"]: # pragma: no cover - msg = "At least one of `partition_by` or `order_by` must be specified." - raise InvalidOperationError(msg) - return self.with_partitioned_over(node, _ce) - def with_orderable_window( self, node: ExprNode, _ce: CompliantExprAny ) -> ExprMetadata: @@ -643,6 +635,14 @@ def with_partitioned_over( prev=self, ) + def with_over(self, node: ExprNode, _ce: CompliantExprAny) -> ExprMetadata: + if node.kwargs["order_by"]: + return self.with_ordered_over(node, _ce) + if not node.kwargs["partition_by"]: # pragma: no cover + msg = "At least one of `partition_by` or `order_by` must be specified." + raise InvalidOperationError(msg) + return self.with_partitioned_over(node, _ce) + def with_filtration( self, node: ExprNode, *compliant_exprs: CompliantExprAny ) -> ExprMetadata: From 2a595c09340b3407f9db415881c68ae174eca0d6 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sun, 26 Oct 2025 16:29:08 +0000 Subject: [PATCH 95/95] de morgan again --- narwhals/expr.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/narwhals/expr.py b/narwhals/expr.py index d0f38b7419..887f277cc6 100644 --- a/narwhals/expr.py +++ b/narwhals/expr.py @@ -82,8 +82,8 @@ def _with_over_node(self, node: ExprNode) -> Self: return self.__class__(*new_nodes) if node.kwargs["order_by"] and any(node.is_orderable() for node in new_nodes[:i]): new_nodes.insert(i, node) - elif node.kwargs["partition_by"] and not all( - node.is_elementwise() for node in new_nodes[:i] + elif node.kwargs["partition_by"] and any( + not node.is_elementwise() for node in new_nodes[:i] ): new_nodes.insert(i, node_without_order_by) elif all(node.is_elementwise() for node in new_nodes):