Skip to content

Commit e568adf

Browse files
Add CUDA stream to cudf_polars.Column.deserialize (#20396)
`Column.deserialize` uses `plc.contiguous_split.unpack_from_memoryviews`, which accepts a stream argument. This also fixes the type annotations for `unpack_from_memoryviews` and `PackedColumns.unpack`, which were missing a `Stream` argument. Authors: - Tom Augspurger (https://github.com/TomAugspurger) Approvers: - Matthew Murray (https://github.com/Matt711) URL: #20396
1 parent d6e5194 commit e568adf

File tree

5 files changed

+21
-7
lines changed

5 files changed

+21
-7
lines changed

python/cudf_polars/cudf_polars/containers/column.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,10 @@ def __init__(
7171

7272
@classmethod
7373
def deserialize(
74-
cls, header: ColumnHeader, frames: tuple[memoryview[bytes], plc.gpumemoryview]
74+
cls,
75+
header: ColumnHeader,
76+
frames: tuple[memoryview[bytes], plc.gpumemoryview],
77+
stream: Stream,
7578
) -> Self:
7679
"""
7780
Create a Column from a serialized representation returned by `.serialize()`.
@@ -82,6 +85,10 @@ def deserialize(
8285
The (unpickled) metadata required to reconstruct the object.
8386
frames
8487
Two-tuple of frames (a memoryview and a gpumemoryview).
88+
stream
89+
CUDA stream used for device memory operations and kernel launches
90+
on this column. The caller is responsible for ensuring that
91+
the data in ``frames`` is valid on ``stream``.
8592
8693
Returns
8794
-------
@@ -90,7 +97,7 @@ def deserialize(
9097
"""
9198
packed_metadata, packed_gpu_data = frames
9299
(plc_column,) = plc.contiguous_split.unpack_from_memoryviews(
93-
packed_metadata, packed_gpu_data
100+
packed_metadata, packed_gpu_data, stream
94101
).columns()
95102
return cls(plc_column, **cls.deserialize_ctor_kwargs(header["column_kwargs"]))
96103

python/cudf_polars/cudf_polars/containers/dataframe.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,7 @@ def deserialize(
249249
table = plc.contiguous_split.unpack_from_memoryviews(
250250
packed_metadata,
251251
packed_gpu_data,
252+
stream,
252253
)
253254
return cls(
254255
(

python/cudf_polars/cudf_polars/experimental/dask_registers.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,11 @@ def _(
100100
) -> Column:
101101
with log_errors():
102102
metadata, gpudata = frames
103-
return Column.deserialize(header, (metadata, plc.gpumemoryview(gpudata)))
103+
return Column.deserialize(
104+
header,
105+
(metadata, plc.gpumemoryview(gpudata)),
106+
stream=get_dask_cuda_stream(),
107+
)
104108

105109
@overload
106110
def dask_serialize_column_or_frame(
@@ -144,7 +148,7 @@ def _(header: ColumnHeader, frames: tuple[memoryview[bytes], memoryview]) -> Col
144148
frames[0],
145149
plc.gpumemoryview(rmm.DeviceBuffer.to_device(frames[1])),
146150
)
147-
return Column.deserialize(header, new_frames)
151+
return Column.deserialize(header, new_frames, stream=get_dask_cuda_stream())
148152

149153
@dask_serialize.register(DataFrame)
150154
def _(

python/cudf_polars/tests/containers/test_column.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ def test_serialize_cache_miss():
230230
# the same hash, so they cache the same, but have some difference
231231
# in behavior (e.g. isinstance).
232232
cudf_polars.containers.datatype._from_polars.cache_clear()
233-
result = Column.deserialize(header, frames)
233+
result = Column.deserialize(header, frames, stream=stream)
234234
assert result.dtype == dtype
235235

236236

python/pylibcudf/pylibcudf/contiguous_split.pyi

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,11 @@ class PackedColumns:
1515
) -> tuple[memoryview[bytes], gpumemoryview]: ...
1616

1717
def pack(input: Table, stream: Stream | None = None) -> PackedColumns: ...
18-
def unpack(input: PackedColumns) -> Table: ...
18+
def unpack(input: PackedColumns, stream: Stream | None = None) -> Table: ...
1919
def unpack_from_memoryviews(
20-
metadata: memoryview[bytes], gpu_data: gpumemoryview
20+
metadata: memoryview[bytes],
21+
gpu_data: gpumemoryview,
22+
stream: Stream | None = None,
2123
) -> Table: ...
2224

2325
class ChunkedPack:

0 commit comments

Comments
 (0)