Skip to content

Commit adc8b23

Browse files
authored
[cudf-polars] CUDA stream (#20154)
`DataFrame` now has an associated CUDA stream, which must be used by `do_evaluate` et al. At the moment, everything will use the default stream so this should behave the same. Once the CUDA stream is explicitly used everywhere, a future PR will explicitly give different dataframes their own streams. Part of #20228. Closes #20241 Authors: - Mads R. B. Kristensen (https://github.com/madsbk) - Tom Augspurger (https://github.com/TomAugspurger) Approvers: - Tom Augspurger (https://github.com/TomAugspurger) - Lawrence Mitchell (https://github.com/wence-) - Matthew Murray (https://github.com/Matt711) URL: #20154
1 parent cfdf7c4 commit adc8b23

File tree

13 files changed

+518
-143
lines changed

13 files changed

+518
-143
lines changed

python/cudf_polars/cudf_polars/containers/dataframe.py

Lines changed: 99 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -20,8 +20,9 @@
2020

2121
from typing_extensions import Any, CapsuleType, Self
2222

23-
from cudf_polars.typing import ColumnOptions, DataFrameHeader, PolarsDataType, Slice
23+
from rmm.pylibrmm.stream import Stream
2424

25+
from cudf_polars.typing import ColumnOptions, DataFrameHeader, PolarsDataType, Slice
2526

2627
__all__: list[str] = ["DataFrame"]
2728

@@ -78,19 +79,21 @@ class DataFrame:
7879
column_map: dict[str, Column]
7980
table: plc.Table
8081
columns: list[NamedColumn]
82+
stream: Stream
8183

82-
def __init__(self, columns: Iterable[Column]) -> None:
84+
def __init__(self, columns: Iterable[Column], stream: Stream) -> None:
8385
columns = list(columns)
8486
if any(c.name is None for c in columns):
8587
raise ValueError("All columns must have a name")
8688
self.columns = [cast(NamedColumn, c) for c in columns]
8789
self.dtypes = [c.dtype for c in self.columns]
8890
self.column_map = {c.name: c for c in self.columns}
8991
self.table = plc.Table([c.obj for c in self.columns])
92+
self.stream = stream
9093

9194
def copy(self) -> Self:
9295
"""Return a shallow copy of self."""
93-
return type(self)(c.copy() for c in self.columns)
96+
return type(self)((c.copy() for c in self.columns), stream=self.stream)
9497

9598
def to_polars(self) -> pl.DataFrame:
9699
"""Convert to a polars DataFrame."""
@@ -135,30 +138,42 @@ def num_rows(self) -> int:
135138
return self.table.num_rows() if self.column_map else 0
136139

137140
@classmethod
138-
def from_polars(cls, df: pl.DataFrame) -> Self:
141+
def from_polars(cls, df: pl.DataFrame, stream: Stream) -> Self:
139142
"""
140143
Create from a polars dataframe.
141144
142145
Parameters
143146
----------
144147
df
145148
Polars dataframe to convert
149+
stream
150+
CUDA stream used for device memory operations and kernel launches
151+
on this dataframe.
146152
147153
Returns
148154
-------
149155
New dataframe representing the input.
150156
"""
151-
plc_table = plc.Table.from_arrow(df)
157+
plc_table = plc.Table.from_arrow(df, stream=stream)
152158
return cls(
153-
Column(d_col, name=name, dtype=DataType(h_col.dtype)).copy_metadata(h_col)
154-
for d_col, h_col, name in zip(
155-
plc_table.columns(), df.iter_columns(), df.columns, strict=True
156-
)
159+
(
160+
Column(d_col, name=name, dtype=DataType(h_col.dtype)).copy_metadata(
161+
h_col
162+
)
163+
for d_col, h_col, name in zip(
164+
plc_table.columns(), df.iter_columns(), df.columns, strict=True
165+
)
166+
),
167+
stream=stream,
157168
)
158169

159170
@classmethod
160171
def from_table(
161-
cls, table: plc.Table, names: Sequence[str], dtypes: Sequence[DataType]
172+
cls,
173+
table: plc.Table,
174+
names: Sequence[str],
175+
dtypes: Sequence[DataType],
176+
stream: Stream,
162177
) -> Self:
163178
"""
164179
Create from a pylibcudf table.
@@ -171,6 +186,10 @@ def from_table(
171186
Names for the columns
172187
dtypes
173188
Dtypes for the columns
189+
stream
190+
CUDA stream used for device memory operations and kernel launches
191+
on this dataframe. The caller is responsible for ensuring that
192+
the data in ``table`` is valid on ``stream``.
174193
175194
Returns
176195
-------
@@ -185,15 +204,19 @@ def from_table(
185204
if table.num_columns() != len(names):
186205
raise ValueError("Mismatching name and table length.")
187206
return cls(
188-
Column(c, name=name, dtype=dtype)
189-
for c, name, dtype in zip(table.columns(), names, dtypes, strict=True)
207+
(
208+
Column(c, name=name, dtype=dtype)
209+
for c, name, dtype in zip(table.columns(), names, dtypes, strict=True)
210+
),
211+
stream=stream,
190212
)
191213

192214
@classmethod
193215
def deserialize(
194216
cls,
195217
header: DataFrameHeader,
196218
frames: tuple[memoryview[bytes], plc.gpumemoryview],
219+
stream: Stream,
197220
) -> Self:
198221
"""
199222
Create a DataFrame from a serialized representation returned by `.serialize()`.
@@ -204,6 +227,10 @@ def deserialize(
204227
The (unpickled) metadata required to reconstruct the object.
205228
frames
206229
Two-tuple of frames (a memoryview and a gpumemoryview).
230+
stream
231+
CUDA stream used for device memory operations and kernel launches
232+
on this dataframe. The caller is responsible for ensuring that
233+
the data in ``frames`` is valid on ``stream``.
207234
208235
Returns
209236
-------
@@ -212,11 +239,15 @@ def deserialize(
212239
"""
213240
packed_metadata, packed_gpu_data = frames
214241
table = plc.contiguous_split.unpack_from_memoryviews(
215-
packed_metadata, packed_gpu_data
242+
packed_metadata,
243+
packed_gpu_data,
216244
)
217245
return cls(
218-
Column(c, **Column.deserialize_ctor_kwargs(kw))
219-
for c, kw in zip(table.columns(), header["columns_kwargs"], strict=True)
246+
(
247+
Column(c, **Column.deserialize_ctor_kwargs(kw))
248+
for c, kw in zip(table.columns(), header["columns_kwargs"], strict=True)
249+
),
250+
stream=stream,
220251
)
221252

222253
def serialize(
@@ -240,7 +271,7 @@ def serialize(
240271
frames
241272
Two-tuple of frames suitable for passing to `plc.contiguous_split.unpack_from_memoryviews`
242273
"""
243-
packed = plc.contiguous_split.pack(self.table)
274+
packed = plc.contiguous_split.pack(self.table, stream=self.stream)
244275

245276
# Keyword arguments for `Column.__init__`.
246277
columns_kwargs: list[ColumnOptions] = [
@@ -278,12 +309,19 @@ def sorted_like(
278309
raise ValueError("Can only copy from identically named frame")
279310
subset = self.column_names_set if subset is None else subset
280311
return type(self)(
281-
c.sorted_like(other) if c.name in subset else c
282-
for c, other in zip(self.columns, like.columns, strict=True)
312+
(
313+
c.sorted_like(other) if c.name in subset else c
314+
for c, other in zip(self.columns, like.columns, strict=True)
315+
),
316+
stream=self.stream,
283317
)
284318

285319
def with_columns(
286-
self, columns: Iterable[Column], *, replace_only: bool = False
320+
self,
321+
columns: Iterable[Column],
322+
*,
323+
replace_only: bool = False,
324+
stream: Stream,
287325
) -> Self:
288326
"""
289327
Return a new dataframe with extra columns.
@@ -294,6 +332,13 @@ def with_columns(
294332
Columns to add
295333
replace_only
296334
If true, then only replacements are allowed (matching by name).
335+
stream
336+
CUDA stream used for device memory operations and kernel launches.
337+
The caller is responsible for ensuring that
338+
339+
1. The data in ``columns`` is valid on ``stream``.
340+
2. No additional operations occur on ``self.stream`` with the
341+
original data in ``self``.
297342
298343
Returns
299344
-------
@@ -307,33 +352,57 @@ def with_columns(
307352
new = {c.name: c for c in columns}
308353
if replace_only and not self.column_names_set.issuperset(new.keys()):
309354
raise ValueError("Cannot replace with non-existing names")
310-
return type(self)((self.column_map | new).values())
355+
return type(self)((self.column_map | new).values(), stream=stream)
311356

312357
def discard_columns(self, names: Set[str]) -> Self:
313358
"""Drop columns by name."""
314-
return type(self)(column for column in self.columns if column.name not in names)
359+
return type(self)(
360+
(column for column in self.columns if column.name not in names),
361+
stream=self.stream,
362+
)
315363

316364
def select(self, names: Sequence[str] | Mapping[str, Any]) -> Self:
317365
"""Select columns by name returning DataFrame."""
318366
try:
319-
return type(self)(self.column_map[name] for name in names)
367+
return type(self)(
368+
(self.column_map[name] for name in names), stream=self.stream
369+
)
320370
except KeyError as e:
321371
raise ValueError("Can't select missing names") from e
322372

323373
def rename_columns(self, mapping: Mapping[str, str]) -> Self:
324374
"""Rename some columns."""
325-
return type(self)(c.rename(mapping.get(c.name, c.name)) for c in self.columns)
375+
return type(self)(
376+
(c.rename(mapping.get(c.name, c.name)) for c in self.columns),
377+
stream=self.stream,
378+
)
326379

327380
def select_columns(self, names: Set[str]) -> list[Column]:
328381
"""Select columns by name."""
329382
return [c for c in self.columns if c.name in names]
330383

331384
def filter(self, mask: Column) -> Self:
332-
"""Return a filtered table given a mask."""
333-
table = plc.stream_compaction.apply_boolean_mask(self.table, mask.obj)
385+
"""
386+
Return a filtered table given a mask.
387+
388+
Parameters
389+
----------
390+
mask
391+
Boolean mask to apply to the dataframe. It is the caller's
392+
responsibility to ensure that ``mask`` is valid on ``self.stream``.
393+
A mask that is derived from ``self`` via a computation on ``self.stream``
394+
automatically satisfies this requirement.
395+
396+
Returns
397+
-------
398+
Filtered dataframe
399+
"""
400+
table = plc.stream_compaction.apply_boolean_mask(
401+
self.table, mask.obj, stream=self.stream
402+
)
334403
return (
335404
type(self)
336-
.from_table(table, self.column_names, self.dtypes)
405+
.from_table(table, self.column_names, self.dtypes, self.stream)
337406
.sorted_like(self)
338407
)
339408

@@ -354,10 +423,12 @@ def slice(self, zlice: Slice | None) -> Self:
354423
if zlice is None:
355424
return self
356425
(table,) = plc.copying.slice(
357-
self.table, conversion.from_polars_slice(zlice, num_rows=self.num_rows)
426+
self.table,
427+
conversion.from_polars_slice(zlice, num_rows=self.num_rows),
428+
stream=self.stream,
358429
)
359430
return (
360431
type(self)
361-
.from_table(table, self.column_names, self.dtypes)
432+
.from_table(table, self.column_names, self.dtypes, self.stream)
362433
.sorted_like(self)
363434
)

python/cudf_polars/cudf_polars/dsl/expressions/rolling.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -854,5 +854,5 @@ def do_evaluate( # noqa: D102
854854

855855
# Create a temporary DataFrame with the broadcasted columns named by their
856856
# placeholder names from agg decomposition, then evaluate the post-expression.
857-
df = DataFrame(broadcasted_cols)
857+
df = DataFrame(broadcasted_cols, stream=df.stream)
858858
return self.post.value.evaluate(df, context=ExecutionContext.FRAME)

0 commit comments

Comments (0)