55
66from __future__ import annotations
77
8- import functools
98from typing import TYPE_CHECKING
109
1110import polars as pl
2827if TYPE_CHECKING :
2928 from typing_extensions import Self
3029
30+ from rmm .pylibrmm .stream import Stream
31+
3132 from cudf_polars .typing import (
3233 ColumnHeader ,
3334 ColumnOptions ,
@@ -82,6 +83,8 @@ def __init__(
8283 self .name = name
8384 self .dtype = dtype
8485 self .set_sorted (is_sorted = is_sorted , order = order , null_order = null_order )
86+ self ._nan_count : int | None = None
87+ self ._obj_scalar : plc .Scalar | None = None
8588
8689 @classmethod
8790 def deserialize (
@@ -126,6 +129,7 @@ def deserialize_ctor_kwargs(
126129
127130 def serialize (
128131 self ,
132+ stream : Stream ,
129133 ) -> tuple [ColumnHeader , tuple [memoryview [bytes ], plc .gpumemoryview ]]:
130134 """
131135 Serialize the Column into header and frames.
@@ -145,7 +149,7 @@ def serialize(
145149 frames
146150 Two-tuple of frames suitable for passing to `plc.contiguous_split.unpack_from_memoryviews`
147151 """
148- packed = plc .contiguous_split .pack (plc .Table ([self .obj ]))
152+ packed = plc .contiguous_split .pack (plc .Table ([self .obj ]), stream = stream )
149153 header : ColumnHeader = {
150154 "column_kwargs" : self .serialize_ctor_kwargs (),
151155 "frame_count" : 2 ,
@@ -162,8 +166,7 @@ def serialize_ctor_kwargs(self) -> ColumnOptions:
162166 "dtype" : pl .polars .dtype_str_repr (self .dtype .polars_type ),
163167 }
164168
165- @functools .cached_property
166- def obj_scalar (self ) -> plc .Scalar :
169+ def obj_scalar (self , stream : Stream ) -> plc .Scalar :
167170 """
168171 A copy of the column object as a pylibcudf Scalar.
169172
@@ -178,7 +181,9 @@ def obj_scalar(self) -> plc.Scalar:
178181 """
179182 if not self .is_scalar :
180183 raise ValueError (f"Cannot convert a column of length { self .size } to scalar" )
181- return plc .copying .get_element (self .obj , 0 )
184+ if self ._obj_scalar is None :
185+ self ._obj_scalar = plc .copying .get_element (self .obj , 0 , stream = stream )
186+ return self ._obj_scalar
182187
183188 def rename (self , name : str | None , / ) -> Self :
184189 """
@@ -228,6 +233,7 @@ def check_sorted(
228233 * ,
229234 order : plc .types .Order ,
230235 null_order : plc .types .NullOrder ,
236+ stream : Stream ,
231237 ) -> bool :
232238 """
233239 Check if the column is sorted.
@@ -238,6 +244,9 @@ def check_sorted(
238244 The requested sort order.
239245 null_order
240246 Where nulls sort to.
247+ stream
248+ CUDA stream used for device memory operations and kernel launches
249+ on this column. The data in ``self.obj`` must be valid on this stream.
241250
242251 Returns
243252 -------
@@ -254,21 +263,26 @@ def check_sorted(
254263 return self .order == order and (
255264 self .null_count == 0 or self .null_order == null_order
256265 )
257- if plc .sorting .is_sorted (plc .Table ([self .obj ]), [order ], [null_order ]):
266+ if plc .sorting .is_sorted (
267+ plc .Table ([self .obj ]), [order ], [null_order ], stream = stream
268+ ):
258269 self .sorted = plc .types .Sorted .YES
259270 self .order = order
260271 self .null_order = null_order
261272 return True
262273 return False
263274
264- def astype (self , dtype : DataType ) -> Column :
275+ def astype (self , dtype : DataType , stream : Stream ) -> Column :
265276 """
266277 Cast the column to the requested dtype.
267278
268279 Parameters
269280 ----------
270281 dtype
271282 Datatype to cast to.
283+ stream
284+ CUDA stream used for device memory operations and kernel launches
285+ on this column. The data in ``self.obj`` must be valid on this stream.
272286
273287 Returns
274288 -------
@@ -292,11 +306,15 @@ def astype(self, dtype: DataType) -> Column:
292306 plc_dtype .id () == plc .TypeId .STRING
293307 or self .obj .type ().id () == plc .TypeId .STRING
294308 ):
295- return Column (self ._handle_string_cast (plc_dtype ), dtype = dtype )
309+ return Column (
310+ self ._handle_string_cast (plc_dtype , stream = stream ), dtype = dtype
311+ )
296312 elif plc .traits .is_integral_not_bool (
297313 self .obj .type ()
298314 ) and plc .traits .is_timestamp (plc_dtype ):
299- upcasted = plc .unary .cast (self .obj , plc .DataType (plc .TypeId .INT64 ))
315+ upcasted = plc .unary .cast (
316+ self .obj , plc .DataType (plc .TypeId .INT64 ), stream = stream
317+ )
300318 plc_col = plc .column .Column (
301319 plc_dtype ,
302320 upcasted .size (),
@@ -319,40 +337,44 @@ def astype(self, dtype: DataType) -> Column:
319337 self .obj .offset (),
320338 self .obj .children (),
321339 )
322- return Column (plc . unary . cast ( plc_col , plc_dtype ), dtype = dtype ). sorted_like (
323- self
324- )
340+ return Column (
341+ plc . unary . cast ( plc_col , plc_dtype , stream = stream ), dtype = dtype
342+ ). sorted_like ( self )
325343 else :
326- result = Column (plc .unary .cast (self .obj , plc_dtype ), dtype = dtype )
344+ result = Column (
345+ plc .unary .cast (self .obj , plc_dtype , stream = stream ), dtype = dtype
346+ )
327347 if is_order_preserving_cast (self .obj .type (), plc_dtype ):
328348 return result .sorted_like (self )
329349 return result
330350
def _handle_string_cast(self, dtype: plc.DataType, stream: Stream) -> plc.Column:
    """
    Cast between a string column and a numeric column.

    Parameters
    ----------
    dtype
        Target pylibcudf datatype.
    stream
        CUDA stream forwarded to every pylibcudf call made here.

    Returns
    -------
    pylibcudf column produced by the conversion.

    Raises
    ------
    InvalidOperationError
        If the target is numeric and the ``all`` reduction over the
        ``is_float``/``is_integer`` check of ``self.obj`` is falsy.
    """
    if dtype.id() == plc.TypeId.STRING:
        # Numeric -> string: pick the formatter from the *source* type.
        if is_floating_point(self.obj.type()):
            return from_floats(self.obj, stream=stream)
        return from_integers(self.obj, stream=stream)

    # String -> numeric: pick the validator/converter pair from the *target*
    # type once, so both kinds of target share a single code path and the
    # stream is forwarded identically for floats and integers.
    if is_floating_point(dtype):
        is_valid, convert = is_float, to_floats
    else:
        is_valid, convert = is_integer, to_integers

    # Validate before converting; a falsy reduction means the cast must fail.
    all_valid = plc.reduce.reduce(
        is_valid(self.obj, stream=stream),
        plc.aggregation.all(),
        plc.DataType(plc.TypeId.BOOL8),
        stream=stream,
    ).to_py()
    if not all_valid:
        raise InvalidOperationError("Conversion from `str` failed.")
    return convert(self.obj, dtype, stream=stream)
356378
357379 def copy_metadata (self , from_ : pl .Series , / ) -> Self :
358380 """
@@ -439,28 +461,31 @@ def copy(self) -> Self:
439461 dtype = self .dtype ,
440462 )
441463
def mask_nans(self, stream: Stream) -> Self:
    """
    Return a shallow copy of self with nans masked out.

    Parameters
    ----------
    stream
        CUDA stream passed to ``plc.transform.nans_to_nulls``.

    Returns
    -------
    A plain copy for non floating-point columns. Otherwise a new column
    whose mask comes from ``nans_to_nulls``; sortedness is carried over
    from self only when the null count did not change.
    """
    # Non floating-point columns need no masking.
    if not plc.traits.is_floating_point(self.obj.type()):
        return self.copy()

    nulls_before = self.null_count
    nan_mask, nulls_after = plc.transform.nans_to_nulls(self.obj, stream=stream)
    masked = type(self)(self.obj.with_mask(nan_mask, nulls_after), self.dtype)
    # An unchanged null count means no new nulls were introduced, so the
    # existing sortedness metadata is reused.
    return masked.sorted_like(self) if nulls_before == nulls_after else masked
452474
def nan_count(self, stream: Stream) -> int:
    """
    Return the number of NaN values in the column.

    The count is computed on the first call and stored in
    ``self._nan_count``; later calls return the stored value without
    using ``stream``.

    Parameters
    ----------
    stream
        CUDA stream used for the ``is_nan`` and reduction calls when the
        count has not been computed yet.

    Returns
    -------
    Number of NaN values; 0 for empty or non floating-point columns.
    """
    if self._nan_count is None:
        if self.size > 0 and plc.traits.is_floating_point(self.obj.type()):
            # See https://github.com/rapidsai/cudf/issues/20202 for why we type ignore
            self._nan_count = plc.reduce.reduce(  # type: ignore[assignment]
                # Keyword form, matching every other stream-aware call here.
                plc.unary.is_nan(self.obj, stream=stream),
                plc.aggregation.sum(),
                plc.types.SIZE_TYPE,
                stream=stream,
            ).to_py()
        else:
            self._nan_count = 0
    return self._nan_count  # type: ignore[return-value]
464489
465490 @property
466491 def size (self ) -> int :
@@ -472,7 +497,7 @@ def null_count(self) -> int:
472497 """Return the number of Null values in the column."""
473498 return self .obj .null_count ()
474499
475- def slice (self , zlice : Slice | None ) -> Self :
500+ def slice (self , zlice : Slice | None , stream : Stream ) -> Self :
476501 """
477502 Slice a column.
478503
@@ -481,6 +506,9 @@ def slice(self, zlice: Slice | None) -> Self:
481506 zlice
482507 optional, tuple of start and length, negative values of start
483508 treated as for python indexing. If not provided, returns self.
509+ stream
510+ CUDA stream used for device memory operations and kernel launches
511+ on this column. The data in ``self.obj`` must be valid on this stream.
484512
485513 Returns
486514 -------
@@ -491,6 +519,7 @@ def slice(self, zlice: Slice | None) -> Self:
491519 (table ,) = plc .copying .slice (
492520 plc .Table ([self .obj ]),
493521 conversion .from_polars_slice (zlice , num_rows = self .size ),
522+ stream = stream ,
494523 )
495524 (column ,) = table .columns ()
496525 return type (self )(column , name = self .name , dtype = self .dtype ).sorted_like (self )
0 commit comments