Skip to content

Commit 1d32cdc

Browse files
committed
Add CUDA streams to cudf-polars
This adds CUDA streams to all pylibcudf calls in cudf-polars. At the moment, we continue to use the default stream for all operations, but we now pass it *explicitly* rather than relying on the implicit default. A future PR will update things to use non-default streams.
1 parent e534472 commit 1d32cdc

33 files changed

+1097
-358
lines changed

python/cudf_polars/cudf_polars/containers/column.py

Lines changed: 61 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66
from __future__ import annotations
77

8-
import functools
98
from typing import TYPE_CHECKING
109

1110
import polars as pl
@@ -28,6 +27,8 @@
2827
if TYPE_CHECKING:
2928
from typing_extensions import Self
3029

30+
from rmm.pylibrmm.stream import Stream
31+
3132
from cudf_polars.typing import (
3233
ColumnHeader,
3334
ColumnOptions,
@@ -82,6 +83,8 @@ def __init__(
8283
self.name = name
8384
self.dtype = dtype
8485
self.set_sorted(is_sorted=is_sorted, order=order, null_order=null_order)
86+
self._nan_count: int | None = None
87+
self._obj_scalar: plc.Scalar | None = None
8588

8689
@classmethod
8790
def deserialize(
@@ -126,6 +129,7 @@ def deserialize_ctor_kwargs(
126129

127130
def serialize(
128131
self,
132+
stream: Stream,
129133
) -> tuple[ColumnHeader, tuple[memoryview[bytes], plc.gpumemoryview]]:
130134
"""
131135
Serialize the Column into header and frames.
@@ -145,7 +149,7 @@ def serialize(
145149
frames
146150
Two-tuple of frames suitable for passing to `plc.contiguous_split.unpack_from_memoryviews`
147151
"""
148-
packed = plc.contiguous_split.pack(plc.Table([self.obj]))
152+
packed = plc.contiguous_split.pack(plc.Table([self.obj]), stream=stream)
149153
header: ColumnHeader = {
150154
"column_kwargs": self.serialize_ctor_kwargs(),
151155
"frame_count": 2,
@@ -162,8 +166,7 @@ def serialize_ctor_kwargs(self) -> ColumnOptions:
162166
"dtype": pl.polars.dtype_str_repr(self.dtype.polars_type),
163167
}
164168

165-
@functools.cached_property
166-
def obj_scalar(self) -> plc.Scalar:
169+
def obj_scalar(self, stream: Stream) -> plc.Scalar:
167170
"""
168171
A copy of the column object as a pylibcudf Scalar.
169172
@@ -178,7 +181,9 @@ def obj_scalar(self) -> plc.Scalar:
178181
"""
179182
if not self.is_scalar:
180183
raise ValueError(f"Cannot convert a column of length {self.size} to scalar")
181-
return plc.copying.get_element(self.obj, 0)
184+
if self._obj_scalar is None:
185+
self._obj_scalar = plc.copying.get_element(self.obj, 0, stream=stream)
186+
return self._obj_scalar
182187

183188
def rename(self, name: str | None, /) -> Self:
184189
"""
@@ -228,6 +233,7 @@ def check_sorted(
228233
*,
229234
order: plc.types.Order,
230235
null_order: plc.types.NullOrder,
236+
stream: Stream,
231237
) -> bool:
232238
"""
233239
Check if the column is sorted.
@@ -238,6 +244,9 @@ def check_sorted(
238244
The requested sort order.
239245
null_order
240246
Where nulls sort to.
247+
stream
248+
CUDA stream used for device memory operations and kernel launches
249+
on this dataframe. The data in ``self.obj`` must be valid on this stream.
241250
242251
Returns
243252
-------
@@ -254,21 +263,26 @@ def check_sorted(
254263
return self.order == order and (
255264
self.null_count == 0 or self.null_order == null_order
256265
)
257-
if plc.sorting.is_sorted(plc.Table([self.obj]), [order], [null_order]):
266+
if plc.sorting.is_sorted(
267+
plc.Table([self.obj]), [order], [null_order], stream=stream
268+
):
258269
self.sorted = plc.types.Sorted.YES
259270
self.order = order
260271
self.null_order = null_order
261272
return True
262273
return False
263274

264-
def astype(self, dtype: DataType) -> Column:
275+
def astype(self, dtype: DataType, stream: Stream) -> Column:
265276
"""
266277
Cast the column to the requested dtype.
267278
268279
Parameters
269280
----------
270281
dtype
271282
Datatype to cast to.
283+
stream
284+
CUDA stream used for device memory operations and kernel launches
285+
on this dataframe. The data in ``self.obj`` must be valid on this stream.
272286
273287
Returns
274288
-------
@@ -292,11 +306,15 @@ def astype(self, dtype: DataType) -> Column:
292306
plc_dtype.id() == plc.TypeId.STRING
293307
or self.obj.type().id() == plc.TypeId.STRING
294308
):
295-
return Column(self._handle_string_cast(plc_dtype), dtype=dtype)
309+
return Column(
310+
self._handle_string_cast(plc_dtype, stream=stream), dtype=dtype
311+
)
296312
elif plc.traits.is_integral_not_bool(
297313
self.obj.type()
298314
) and plc.traits.is_timestamp(plc_dtype):
299-
upcasted = plc.unary.cast(self.obj, plc.DataType(plc.TypeId.INT64))
315+
upcasted = plc.unary.cast(
316+
self.obj, plc.DataType(plc.TypeId.INT64), stream=stream
317+
)
300318
plc_col = plc.column.Column(
301319
plc_dtype,
302320
upcasted.size(),
@@ -319,40 +337,44 @@ def astype(self, dtype: DataType) -> Column:
319337
self.obj.offset(),
320338
self.obj.children(),
321339
)
322-
return Column(plc.unary.cast(plc_col, plc_dtype), dtype=dtype).sorted_like(
323-
self
324-
)
340+
return Column(
341+
plc.unary.cast(plc_col, plc_dtype, stream=stream), dtype=dtype
342+
).sorted_like(self)
325343
else:
326-
result = Column(plc.unary.cast(self.obj, plc_dtype), dtype=dtype)
344+
result = Column(
345+
plc.unary.cast(self.obj, plc_dtype, stream=stream), dtype=dtype
346+
)
327347
if is_order_preserving_cast(self.obj.type(), plc_dtype):
328348
return result.sorted_like(self)
329349
return result
330350

331-
def _handle_string_cast(self, dtype: plc.DataType) -> plc.Column:
351+
def _handle_string_cast(self, dtype: plc.DataType, stream: Stream) -> plc.Column:
332352
if dtype.id() == plc.TypeId.STRING:
333353
if is_floating_point(self.obj.type()):
334-
return from_floats(self.obj)
354+
return from_floats(self.obj, stream=stream)
335355
else:
336-
return from_integers(self.obj)
356+
return from_integers(self.obj, stream=stream)
337357
else:
338358
if is_floating_point(dtype):
339-
floats = is_float(self.obj)
359+
floats = is_float(self.obj, stream=stream)
340360
if not plc.reduce.reduce(
341361
floats,
342362
plc.aggregation.all(),
343363
plc.DataType(plc.TypeId.BOOL8),
364+
stream=stream,
344365
).to_py():
345366
raise InvalidOperationError("Conversion from `str` failed.")
346367
return to_floats(self.obj, dtype)
347368
else:
348-
integers = is_integer(self.obj)
369+
integers = is_integer(self.obj, stream=stream)
349370
if not plc.reduce.reduce(
350371
integers,
351372
plc.aggregation.all(),
352373
plc.DataType(plc.TypeId.BOOL8),
374+
stream=stream,
353375
).to_py():
354376
raise InvalidOperationError("Conversion from `str` failed.")
355-
return to_integers(self.obj, dtype)
377+
return to_integers(self.obj, dtype, stream=stream)
356378

357379
def copy_metadata(self, from_: pl.Series, /) -> Self:
358380
"""
@@ -439,28 +461,31 @@ def copy(self) -> Self:
439461
dtype=self.dtype,
440462
)
441463

442-
def mask_nans(self) -> Self:
464+
def mask_nans(self, stream: Stream) -> Self:
443465
"""Return a shallow copy of self with nans masked out."""
444466
if plc.traits.is_floating_point(self.obj.type()):
445467
old_count = self.null_count
446-
mask, new_count = plc.transform.nans_to_nulls(self.obj)
468+
mask, new_count = plc.transform.nans_to_nulls(self.obj, stream=stream)
447469
result = type(self)(self.obj.with_mask(mask, new_count), self.dtype)
448470
if old_count == new_count:
449471
return result.sorted_like(self)
450472
return result
451473
return self.copy()
452474

453-
@functools.cached_property
454-
def nan_count(self) -> int:
475+
def nan_count(self, stream: Stream) -> int:
455476
"""Return the number of NaN values in the column."""
456-
if self.size > 0 and plc.traits.is_floating_point(self.obj.type()):
457-
# See https://github.com/rapidsai/cudf/issues/20202 for why we type-ignore
458-
return plc.reduce.reduce(
459-
plc.unary.is_nan(self.obj),
460-
plc.aggregation.sum(),
461-
plc.types.SIZE_TYPE,
462-
).to_py() # type: ignore[return-value]
463-
return 0
477+
if self._nan_count is None:
478+
if self.size > 0 and plc.traits.is_floating_point(self.obj.type()):
479+
# See https://github.com/rapidsai/cudf/issues/20202 for why we type-ignore
480+
self._nan_count = plc.reduce.reduce( # type: ignore[assignment]
481+
plc.unary.is_nan(self.obj, stream),
482+
plc.aggregation.sum(),
483+
plc.types.SIZE_TYPE,
484+
stream=stream,
485+
).to_py()
486+
else:
487+
self._nan_count = 0
488+
return self._nan_count # type: ignore[return-value]
464489

465490
@property
466491
def size(self) -> int:
@@ -472,7 +497,7 @@ def null_count(self) -> int:
472497
"""Return the number of Null values in the column."""
473498
return self.obj.null_count()
474499

475-
def slice(self, zlice: Slice | None) -> Self:
500+
def slice(self, zlice: Slice | None, stream: Stream) -> Self:
476501
"""
477502
Slice a column.
478503
@@ -481,6 +506,9 @@ def slice(self, zlice: Slice | None) -> Self:
481506
zlice
482507
optional, tuple of start and length, negative values of start
483508
treated as for python indexing. If not provided, returns self.
509+
stream
510+
CUDA stream used for device memory operations and kernel launches
511+
on this dataframe. The data in ``self.obj`` must be valid on this stream.
484512
485513
Returns
486514
-------
@@ -491,6 +519,7 @@ def slice(self, zlice: Slice | None) -> Self:
491519
(table,) = plc.copying.slice(
492520
plc.Table([self.obj]),
493521
conversion.from_polars_slice(zlice, num_rows=self.size),
522+
stream=stream,
494523
)
495524
(column,) = table.columns()
496525
return type(self)(column, name=self.name, dtype=self.dtype).sorted_like(self)

python/cudf_polars/cudf_polars/containers/dataframe.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,7 @@ def deserialize(
252252

253253
def serialize(
254254
self,
255+
stream: Stream | None = None,
255256
) -> tuple[DataFrameHeader, tuple[memoryview[bytes], plc.gpumemoryview]]:
256257
"""
257258
Serialize the table into header and frames.
@@ -264,14 +265,20 @@ def serialize(
264265
>>> from cudf_polars.experimental.dask_serialize import register
265266
>>> register()
266267
268+
Parameters
269+
----------
270+
stream
271+
CUDA stream used for device memory operations and kernel launches
272+
on this dataframe.
273+
267274
Returns
268275
-------
269276
header
270277
A dict containing any picklable metadata required to reconstruct the object.
271278
frames
272279
Two-tuple of frames suitable for passing to `plc.contiguous_split.unpack_from_memoryviews`
273280
"""
274-
packed = plc.contiguous_split.pack(self.table, stream=self.stream)
281+
packed = plc.contiguous_split.pack(self.table, stream=stream)
275282

276283
# Keyword arguments for `Column.__init__`.
277284
columns_kwargs: list[ColumnOptions] = [

0 commit comments

Comments
 (0)