Skip to content

Commit 6eaf0ea

Browse files
authored
Merge branch 'main' into feat/with-row-index-by
2 parents 44a756f + a96e03f commit 6eaf0ea

37 files changed

+613
-233
lines changed

.github/workflows/downstream_tests.yml

Lines changed: 11 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -309,22 +309,26 @@ jobs:
309309
git clone https://github.com/plotly/plotly.py --depth=1
310310
cd plotly.py
311311
git log
312-
- name: install-basics
313-
run: uv pip install --upgrade tox virtualenv setuptools --system
314312
- name: install-deps
315313
run: |
316314
cd plotly.py
317-
uv pip install -r test_requirements/requirements_core.txt -r test_requirements/requirements_optional.txt --system
318-
uv pip install -e . --system
315+
uv venv -p ${{ matrix.python-version }}
316+
uv sync --extra dev_optional
319317
- name: install-narwhals-dev
320318
run: |
321-
uv pip uninstall narwhals --system
322-
uv pip install -e . --system
319+
cd plotly.py
320+
. .venv/bin/activate
321+
uv pip uninstall narwhals
322+
uv pip install -e ./..
323323
- name: show-deps
324-
run: uv pip freeze
324+
run: |
325+
cd plotly.py
326+
. .venv/bin/activate
327+
uv pip freeze
325328
- name: Run pytest on plotly express
326329
run: |
327330
cd plotly.py
331+
. .venv/bin/activate
328332
pytest tests/test_optional/test_px
329333
330334
hierarchicalforecast:

docs/api-reference/expr_str.md

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -18,5 +18,6 @@
1818
- to_datetime
1919
- to_lowercase
2020
- to_uppercase
21+
- zfill
2122
show_source: false
2223
show_bases: false

docs/api-reference/series_str.md

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -18,5 +18,6 @@
1818
- to_datetime
1919
- to_lowercase
2020
- to_uppercase
21+
- zfill
2122
show_source: false
2223
show_bases: false

narwhals/_arrow/dataframe.py

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,7 @@
88
import pyarrow.compute as pc
99

1010
from narwhals._arrow.series import ArrowSeries
11-
from narwhals._arrow.utils import align_series_full_broadcast, native_to_narwhals_dtype
11+
from narwhals._arrow.utils import native_to_narwhals_dtype
1212
from narwhals._compliant import EagerDataFrame
1313
from narwhals._expression_parsing import ExprKind
1414
from narwhals._utils import (
@@ -74,7 +74,9 @@
7474
PromoteOptions: TypeAlias = Literal["none", "default", "permissive"]
7575

7676

77-
class ArrowDataFrame(EagerDataFrame["ArrowSeries", "ArrowExpr", "pa.Table"]):
77+
class ArrowDataFrame(
78+
EagerDataFrame["ArrowSeries", "ArrowExpr", "pa.Table", "ChunkedArrayAny"]
79+
):
7880
def __init__(
7981
self,
8082
native_dataframe: pa.Table,
@@ -330,7 +332,8 @@ def select(self: ArrowDataFrame, *exprs: ArrowExpr) -> ArrowDataFrame:
330332
self.native.__class__.from_arrays([]), validate_column_names=False
331333
)
332334
names = [s.name for s in new_series]
333-
reshaped = align_series_full_broadcast(*new_series)
335+
align = new_series[0]._align_full_broadcast
336+
reshaped = align(*new_series)
334337
df = pa.Table.from_arrays([s.native for s in reshaped], names=names)
335338
return self._with_native(df, validate_column_names=True)
336339

narwhals/_arrow/namespace.py

Lines changed: 29 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -12,10 +12,7 @@
1212
from narwhals._arrow.expr import ArrowExpr
1313
from narwhals._arrow.selectors import ArrowSelectorNamespace
1414
from narwhals._arrow.series import ArrowSeries
15-
from narwhals._arrow.utils import (
16-
align_series_full_broadcast,
17-
cast_to_comparable_string_types,
18-
)
15+
from narwhals._arrow.utils import cast_to_comparable_string_types
1916
from narwhals._compliant import CompliantThen, EagerNamespace, EagerWhen
2017
from narwhals._expression_parsing import (
2118
combine_alias_output_names,
@@ -26,12 +23,14 @@
2623
if TYPE_CHECKING:
2724
from collections.abc import Sequence
2825

29-
from narwhals._arrow.typing import Incomplete
26+
from narwhals._arrow.typing import ArrayOrScalar, ChunkedArrayAny, Incomplete
3027
from narwhals._utils import Version
3128
from narwhals.typing import IntoDType, NonNestedLiteral
3229

3330

34-
class ArrowNamespace(EagerNamespace[ArrowDataFrame, ArrowSeries, ArrowExpr, pa.Table]):
31+
class ArrowNamespace(
32+
EagerNamespace[ArrowDataFrame, ArrowSeries, ArrowExpr, pa.Table, "ChunkedArrayAny"]
33+
):
3534
@property
3635
def _dataframe(self) -> type[ArrowDataFrame]:
3736
return ArrowDataFrame
@@ -86,7 +85,8 @@ def _lit_arrow_series(_: ArrowDataFrame) -> ArrowSeries:
8685
def all_horizontal(self, *exprs: ArrowExpr) -> ArrowExpr:
8786
def func(df: ArrowDataFrame) -> list[ArrowSeries]:
8887
series = chain.from_iterable(expr(df) for expr in exprs)
89-
return [reduce(operator.and_, align_series_full_broadcast(*series))]
88+
align = self._series._align_full_broadcast
89+
return [reduce(operator.and_, align(*series))]
9090

9191
return self._expr._from_callable(
9292
func=func,
@@ -100,7 +100,8 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
100100
def any_horizontal(self, *exprs: ArrowExpr) -> ArrowExpr:
101101
def func(df: ArrowDataFrame) -> list[ArrowSeries]:
102102
series = chain.from_iterable(expr(df) for expr in exprs)
103-
return [reduce(operator.or_, align_series_full_broadcast(*series))]
103+
align = self._series._align_full_broadcast
104+
return [reduce(operator.or_, align(*series))]
104105

105106
return self._expr._from_callable(
106107
func=func,
@@ -115,7 +116,8 @@ def sum_horizontal(self, *exprs: ArrowExpr) -> ArrowExpr:
115116
def func(df: ArrowDataFrame) -> list[ArrowSeries]:
116117
it = chain.from_iterable(expr(df) for expr in exprs)
117118
series = (s.fill_null(0, strategy=None, limit=None) for s in it)
118-
return [reduce(operator.add, align_series_full_broadcast(*series))]
119+
align = self._series._align_full_broadcast
120+
return [reduce(operator.add, align(*series))]
119121

120122
return self._expr._from_callable(
121123
func=func,
@@ -131,12 +133,11 @@ def mean_horizontal(self, *exprs: ArrowExpr) -> ArrowExpr:
131133

132134
def func(df: ArrowDataFrame) -> list[ArrowSeries]:
133135
expr_results = list(chain.from_iterable(expr(df) for expr in exprs))
134-
series = align_series_full_broadcast(
136+
align = self._series._align_full_broadcast
137+
series = align(
135138
*(s.fill_null(0, strategy=None, limit=None) for s in expr_results)
136139
)
137-
non_na = align_series_full_broadcast(
138-
*(1 - s.is_null().cast(int_64) for s in expr_results)
139-
)
140+
non_na = align(*(1 - s.is_null().cast(int_64) for s in expr_results))
140141
return [reduce(operator.add, series) / reduce(operator.add, non_na)]
141142

142143
return self._expr._from_callable(
@@ -150,8 +151,9 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
150151

151152
def min_horizontal(self, *exprs: ArrowExpr) -> ArrowExpr:
152153
def func(df: ArrowDataFrame) -> list[ArrowSeries]:
154+
align = self._series._align_full_broadcast
153155
init_series, *series = list(chain.from_iterable(expr(df) for expr in exprs))
154-
init_series, *series = align_series_full_broadcast(init_series, *series)
156+
init_series, *series = align(init_series, *series)
155157
native_series = reduce(
156158
pc.min_element_wise, [s.native for s in series], init_series.native
157159
)
@@ -175,8 +177,9 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
175177

176178
def max_horizontal(self, *exprs: ArrowExpr) -> ArrowExpr:
177179
def func(df: ArrowDataFrame) -> list[ArrowSeries]:
180+
align = self._series._align_full_broadcast
178181
init_series, *series = list(chain.from_iterable(expr(df) for expr in exprs))
179-
init_series, *series = align_series_full_broadcast(init_series, *series)
182+
init_series, *series = align(init_series, *series)
180183
native_series = reduce(
181184
pc.max_element_wise, [s.native for s in series], init_series.native
182185
)
@@ -232,7 +235,8 @@ def concat_str(
232235
self, *exprs: ArrowExpr, separator: str, ignore_nulls: bool
233236
) -> ArrowExpr:
234237
def func(df: ArrowDataFrame) -> list[ArrowSeries]:
235-
compliant_series_list = align_series_full_broadcast(
238+
align = self._series._align_full_broadcast
239+
compliant_series_list = align(
236240
*(chain.from_iterable(expr(df) for expr in exprs))
237241
)
238242
name = compliant_series_list[0].name
@@ -263,23 +267,20 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
263267
)
264268

265269

266-
class ArrowWhen(EagerWhen[ArrowDataFrame, ArrowSeries, ArrowExpr]):
270+
class ArrowWhen(EagerWhen[ArrowDataFrame, ArrowSeries, ArrowExpr, "ChunkedArrayAny"]):
267271
@property
268272
def _then(self) -> type[ArrowThen]:
269273
return ArrowThen
270274

271275
def _if_then_else(
272-
self, when: ArrowSeries, then: ArrowSeries, otherwise: ArrowSeries | None, /
273-
) -> ArrowSeries:
274-
if otherwise is None:
275-
when, then = align_series_full_broadcast(when, then)
276-
res_native = pc.if_else(
277-
when.native, then.native, pa.nulls(len(when.native), then.native.type)
278-
)
279-
else:
280-
when, then, otherwise = align_series_full_broadcast(when, then, otherwise)
281-
res_native = pc.if_else(when.native, then.native, otherwise.native)
282-
return then._with_native(res_native)
276+
self,
277+
when: ChunkedArrayAny,
278+
then: ChunkedArrayAny,
279+
otherwise: ArrayOrScalar | NonNestedLiteral,
280+
/,
281+
) -> ChunkedArrayAny:
282+
otherwise = pa.nulls(len(when), then.type) if otherwise is None else otherwise
283+
return pc.if_else(when, then, otherwise)
283284

284285

285286
class ArrowThen(CompliantThen[ArrowDataFrame, ArrowSeries, ArrowExpr], ArrowExpr): ...

narwhals/_arrow/series.py

Lines changed: 20 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -32,7 +32,7 @@
3232
validate_backend_version,
3333
)
3434
from narwhals.dependencies import is_numpy_array_1d
35-
from narwhals.exceptions import InvalidOperationError
35+
from narwhals.exceptions import InvalidOperationError, ShapeError
3636

3737
if TYPE_CHECKING:
3838
from collections.abc import Iterable, Iterator, Mapping, Sequence
@@ -192,6 +192,25 @@ def from_numpy(cls, data: Into1DArray, /, *, context: _FullContext) -> Self:
192192
data if is_numpy_array_1d(data) else [data], context=context
193193
)
194194

195+
@classmethod
196+
def _align_full_broadcast(cls, *series: Self) -> Sequence[Self]:
197+
lengths = [len(s) for s in series]
198+
max_length = max(lengths)
199+
fast_path = all(_len == max_length for _len in lengths)
200+
if fast_path:
201+
return series
202+
reshaped = []
203+
for s in series:
204+
if s._broadcast:
205+
compliant = s._with_native(pa.repeat(s.native[0], max_length))
206+
elif (actual_len := len(s)) != max_length:
207+
msg = f"Expected object of length {max_length}, got {actual_len}."
208+
raise ShapeError(msg)
209+
else:
210+
compliant = s
211+
reshaped.append(compliant)
212+
return reshaped
213+
195214
def __narwhals_namespace__(self) -> ArrowNamespace:
196215
from narwhals._arrow.namespace import ArrowNamespace
197216

narwhals/_arrow/series_str.py

Lines changed: 35 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -3,12 +3,14 @@
33
import string
44
from typing import TYPE_CHECKING
55

6+
import pyarrow as pa
67
import pyarrow.compute as pc
78

89
from narwhals._arrow.utils import ArrowSeriesNamespace, lit, parse_datetime_format
910

1011
if TYPE_CHECKING:
1112
from narwhals._arrow.series import ArrowSeries
13+
from narwhals._arrow.typing import Incomplete
1214

1315

1416
class ArrowSeriesStringNamespace(ArrowSeriesNamespace):
@@ -60,3 +62,36 @@ def to_uppercase(self) -> ArrowSeries:
6062

6163
def to_lowercase(self) -> ArrowSeries:
6264
return self.with_native(pc.utf8_lower(self.native))
65+
66+
def zfill(self, width: int) -> ArrowSeries:
67+
binary_join: Incomplete = pc.binary_join_element_wise
68+
native = self.native
69+
hyphen, plus = lit("-"), lit("+")
70+
first_char, remaining_chars = self.slice(0, 1).native, self.slice(1, None).native
71+
72+
# Conditions
73+
less_than_width = pc.less(pc.utf8_length(native), lit(width))
74+
starts_with_hyphen = pc.equal(first_char, hyphen)
75+
starts_with_plus = pc.equal(first_char, plus)
76+
77+
conditions = pc.make_struct(
78+
pc.and_(starts_with_hyphen, less_than_width),
79+
pc.and_(starts_with_plus, less_than_width),
80+
less_than_width,
81+
)
82+
83+
# Cases
84+
padded_remaining_chars = pc.utf8_lpad(remaining_chars, width - 1, padding="0")
85+
86+
result = pc.case_when(
87+
conditions,
88+
binary_join(
89+
pa.repeat(hyphen, len(native)), padded_remaining_chars, ""
90+
), # starts with hyphen and less than width
91+
binary_join(
92+
pa.repeat(plus, len(native)), padded_remaining_chars, ""
93+
), # starts with plus and less than width
94+
pc.utf8_lpad(native, width=width, padding="0"), # less than width
95+
native,
96+
)
97+
return self.with_native(result)

narwhals/_arrow/utils.py

Lines changed: 1 addition & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -8,10 +8,9 @@
88

99
from narwhals._compliant.series import _SeriesNamespace
1010
from narwhals._utils import isinstance_or_issubclass
11-
from narwhals.exceptions import ShapeError
1211

1312
if TYPE_CHECKING:
14-
from collections.abc import Iterable, Iterator, Mapping, Sequence
13+
from collections.abc import Iterable, Iterator, Mapping
1514

1615
from typing_extensions import TypeAlias, TypeIs
1716

@@ -261,31 +260,6 @@ def extract_native(
261260
return lhs.native, rhs if isinstance(rhs, pa.Scalar) else lit(rhs)
262261

263262

264-
def align_series_full_broadcast(*series: ArrowSeries) -> Sequence[ArrowSeries]:
265-
# Ensure all of `series` are of the same length.
266-
lengths = [len(s) for s in series]
267-
max_length = max(lengths)
268-
fast_path = all(_len == max_length for _len in lengths)
269-
270-
if fast_path:
271-
return series
272-
273-
reshaped = []
274-
for s in series:
275-
if s._broadcast:
276-
value = s.native[0]
277-
if s._backend_version < (13,) and hasattr(value, "as_py"):
278-
value = value.as_py()
279-
reshaped.append(s._with_native(pa.array([value] * max_length, type=s._type)))
280-
else:
281-
if (actual_len := len(s)) != max_length:
282-
msg = f"Expected object of length {max_length}, got {actual_len}."
283-
raise ShapeError(msg)
284-
reshaped.append(s)
285-
286-
return reshaped
287-
288-
289263
def floordiv_compat(left: ArrayOrScalar, right: ArrayOrScalar, /) -> Any:
290264
# The following lines are adapted from pandas' pyarrow implementation.
291265
# Ref: https://github.com/pandas-dev/pandas/blob/262fcfbffcee5c3116e86a951d8b693f90411e68/pandas/core/arrays/arrow/array.py#L124-L154

narwhals/_compliant/any_namespace.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -79,6 +79,7 @@ def split(self, by: str) -> CompliantT_co: ...
7979
def to_datetime(self, format: str | None) -> CompliantT_co: ...
8080
def to_lowercase(self) -> CompliantT_co: ...
8181
def to_uppercase(self) -> CompliantT_co: ...
82+
def zfill(self, width: int) -> CompliantT_co: ...
8283

8384

8485
class StructNamespace(_StoresCompliant[CompliantT_co], Protocol[CompliantT_co]):

narwhals/_compliant/dataframe.py

Lines changed: 10 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,7 @@
1313
EagerSeriesT,
1414
NativeExprT,
1515
NativeFrameT,
16+
NativeSeriesT,
1617
)
1718
from narwhals._translate import (
1819
ArrowConvertible,
@@ -392,11 +393,11 @@ def _check_columns_exist(self, subset: Sequence[str]) -> ColumnNotFoundError | N
392393
class EagerDataFrame(
393394
CompliantDataFrame[EagerSeriesT, EagerExprT, NativeFrameT, "DataFrame[NativeFrameT]"],
394395
CompliantLazyFrame[EagerExprT, NativeFrameT, "DataFrame[NativeFrameT]"],
395-
Protocol[EagerSeriesT, EagerExprT, NativeFrameT],
396+
Protocol[EagerSeriesT, EagerExprT, NativeFrameT, NativeSeriesT],
396397
):
397398
def __narwhals_namespace__(
398399
self,
399-
) -> EagerNamespace[Self, EagerSeriesT, EagerExprT, NativeFrameT]: ...
400+
) -> EagerNamespace[Self, EagerSeriesT, EagerExprT, NativeFrameT, NativeSeriesT]: ...
400401

401402
def to_narwhals(self) -> DataFrame[NativeFrameT]:
402403
return self._version.dataframe(self, level="full")
@@ -444,10 +445,14 @@ def _numpy_column_names(
444445
) -> list[str]:
445446
return list(columns or (f"column_{x}" for x in range(data.shape[1])))
446447

447-
def _gather(self, rows: SizedMultiIndexSelector[Any]) -> Self: ...
448+
def _gather(self, rows: SizedMultiIndexSelector[NativeSeriesT]) -> Self: ...
448449
def _gather_slice(self, rows: _SliceIndex | range) -> Self: ...
449-
def _select_multi_index(self, columns: SizedMultiIndexSelector[Any]) -> Self: ...
450-
def _select_multi_name(self, columns: SizedMultiNameSelector[Any]) -> Self: ...
450+
def _select_multi_index(
451+
self, columns: SizedMultiIndexSelector[NativeSeriesT]
452+
) -> Self: ...
453+
def _select_multi_name(
454+
self, columns: SizedMultiNameSelector[NativeSeriesT]
455+
) -> Self: ...
451456
def _select_slice_index(self, columns: _SliceIndex | range) -> Self: ...
452457
def _select_slice_name(self, columns: _SliceName) -> Self: ...
453458
def __getitem__( # noqa: C901, PLR0912

0 commit comments

Comments
 (0)