1212from narwhals ._arrow .expr import ArrowExpr
1313from narwhals ._arrow .selectors import ArrowSelectorNamespace
1414from narwhals ._arrow .series import ArrowSeries
15- from narwhals ._arrow .utils import (
16- align_series_full_broadcast ,
17- cast_to_comparable_string_types ,
18- )
15+ from narwhals ._arrow .utils import cast_to_comparable_string_types
1916from narwhals ._compliant import CompliantThen , EagerNamespace , EagerWhen
2017from narwhals ._expression_parsing import (
2118 combine_alias_output_names ,
2623if TYPE_CHECKING :
2724 from collections .abc import Sequence
2825
29- from narwhals ._arrow .typing import Incomplete
26+ from narwhals ._arrow .typing import ArrayOrScalar , ChunkedArrayAny , Incomplete
3027 from narwhals ._utils import Version
3128 from narwhals .typing import IntoDType , NonNestedLiteral
3229
3330
34- class ArrowNamespace (EagerNamespace [ArrowDataFrame , ArrowSeries , ArrowExpr , pa .Table ]):
31+ class ArrowNamespace (
32+ EagerNamespace [ArrowDataFrame , ArrowSeries , ArrowExpr , pa .Table , "ChunkedArrayAny" ]
33+ ):
3534 @property
3635 def _dataframe (self ) -> type [ArrowDataFrame ]:
3736 return ArrowDataFrame
@@ -86,7 +85,8 @@ def _lit_arrow_series(_: ArrowDataFrame) -> ArrowSeries:
8685 def all_horizontal (self , * exprs : ArrowExpr ) -> ArrowExpr :
8786 def func (df : ArrowDataFrame ) -> list [ArrowSeries ]:
8887 series = chain .from_iterable (expr (df ) for expr in exprs )
89- return [reduce (operator .and_ , align_series_full_broadcast (* series ))]
88+ align = self ._series ._align_full_broadcast
89+ return [reduce (operator .and_ , align (* series ))]
9090
9191 return self ._expr ._from_callable (
9292 func = func ,
@@ -100,7 +100,8 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
100100 def any_horizontal (self , * exprs : ArrowExpr ) -> ArrowExpr :
101101 def func (df : ArrowDataFrame ) -> list [ArrowSeries ]:
102102 series = chain .from_iterable (expr (df ) for expr in exprs )
103- return [reduce (operator .or_ , align_series_full_broadcast (* series ))]
103+ align = self ._series ._align_full_broadcast
104+ return [reduce (operator .or_ , align (* series ))]
104105
105106 return self ._expr ._from_callable (
106107 func = func ,
@@ -115,7 +116,8 @@ def sum_horizontal(self, *exprs: ArrowExpr) -> ArrowExpr:
115116 def func (df : ArrowDataFrame ) -> list [ArrowSeries ]:
116117 it = chain .from_iterable (expr (df ) for expr in exprs )
117118 series = (s .fill_null (0 , strategy = None , limit = None ) for s in it )
118- return [reduce (operator .add , align_series_full_broadcast (* series ))]
119+ align = self ._series ._align_full_broadcast
120+ return [reduce (operator .add , align (* series ))]
119121
120122 return self ._expr ._from_callable (
121123 func = func ,
@@ -131,12 +133,11 @@ def mean_horizontal(self, *exprs: ArrowExpr) -> ArrowExpr:
131133
132134 def func (df : ArrowDataFrame ) -> list [ArrowSeries ]:
133135 expr_results = list (chain .from_iterable (expr (df ) for expr in exprs ))
134- series = align_series_full_broadcast (
136+ align = self ._series ._align_full_broadcast
137+ series = align (
135138 * (s .fill_null (0 , strategy = None , limit = None ) for s in expr_results )
136139 )
137- non_na = align_series_full_broadcast (
138- * (1 - s .is_null ().cast (int_64 ) for s in expr_results )
139- )
140+ non_na = align (* (1 - s .is_null ().cast (int_64 ) for s in expr_results ))
140141 return [reduce (operator .add , series ) / reduce (operator .add , non_na )]
141142
142143 return self ._expr ._from_callable (
@@ -150,8 +151,9 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
150151
151152 def min_horizontal (self , * exprs : ArrowExpr ) -> ArrowExpr :
152153 def func (df : ArrowDataFrame ) -> list [ArrowSeries ]:
154+ align = self ._series ._align_full_broadcast
153155 init_series , * series = list (chain .from_iterable (expr (df ) for expr in exprs ))
154- init_series , * series = align_series_full_broadcast (init_series , * series )
156+ init_series , * series = align (init_series , * series )
155157 native_series = reduce (
156158 pc .min_element_wise , [s .native for s in series ], init_series .native
157159 )
@@ -175,8 +177,9 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
175177
176178 def max_horizontal (self , * exprs : ArrowExpr ) -> ArrowExpr :
177179 def func (df : ArrowDataFrame ) -> list [ArrowSeries ]:
180+ align = self ._series ._align_full_broadcast
178181 init_series , * series = list (chain .from_iterable (expr (df ) for expr in exprs ))
179- init_series , * series = align_series_full_broadcast (init_series , * series )
182+ init_series , * series = align (init_series , * series )
180183 native_series = reduce (
181184 pc .max_element_wise , [s .native for s in series ], init_series .native
182185 )
@@ -232,7 +235,8 @@ def concat_str(
232235 self , * exprs : ArrowExpr , separator : str , ignore_nulls : bool
233236 ) -> ArrowExpr :
234237 def func (df : ArrowDataFrame ) -> list [ArrowSeries ]:
235- compliant_series_list = align_series_full_broadcast (
238+ align = self ._series ._align_full_broadcast
239+ compliant_series_list = align (
236240 * (chain .from_iterable (expr (df ) for expr in exprs ))
237241 )
238242 name = compliant_series_list [0 ].name
@@ -263,23 +267,20 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
263267 )
264268
265269
266- class ArrowWhen (EagerWhen [ArrowDataFrame , ArrowSeries , ArrowExpr ]):
270+ class ArrowWhen (EagerWhen [ArrowDataFrame , ArrowSeries , ArrowExpr , "ChunkedArrayAny" ]):
267271 @property
268272 def _then (self ) -> type [ArrowThen ]:
269273 return ArrowThen
270274
271275 def _if_then_else (
272- self , when : ArrowSeries , then : ArrowSeries , otherwise : ArrowSeries | None , /
273- ) -> ArrowSeries :
274- if otherwise is None :
275- when , then = align_series_full_broadcast (when , then )
276- res_native = pc .if_else (
277- when .native , then .native , pa .nulls (len (when .native ), then .native .type )
278- )
279- else :
280- when , then , otherwise = align_series_full_broadcast (when , then , otherwise )
281- res_native = pc .if_else (when .native , then .native , otherwise .native )
282- return then ._with_native (res_native )
276+ self ,
277+ when : ChunkedArrayAny ,
278+ then : ChunkedArrayAny ,
279+ otherwise : ArrayOrScalar | NonNestedLiteral ,
280+ / ,
281+ ) -> ChunkedArrayAny :
282+ otherwise = pa .nulls (len (when ), then .type ) if otherwise is None else otherwise
283+ return pc .if_else (when , then , otherwise )
283284
284285
285286class ArrowThen (CompliantThen [ArrowDataFrame , ArrowSeries , ArrowExpr ], ArrowExpr ): ...
0 commit comments