Skip to content

Commit 6f1966d

Browse files
Use stream in cudf_polars.DataFrame.to_polars (#20323)
This ensures that the device-to-host transfer used in `cudf_polars.DataFrame.to_polars` performed on the stream associated with the dataframe. As a rough test, I updated our stream generation function to always return a new stream: ```diff diff --git a/python/cudf_polars/cudf_polars/utils/cuda_stream.py b/python/cudf_polars/cudf_polars/utils/cuda_stream.py index 4f8d540..a3cf6c9 100644 --- a/python/cudf_polars/cudf_polars/utils/cuda_stream.py +++ b/python/cudf_polars/cudf_polars/utils/cuda_stream.py @@ -7,7 +7,7 @@ from __future__ import annotations from typing import TYPE_CHECKING -from rmm.pylibrmm.stream import DEFAULT_STREAM +from rmm.pylibrmm.stream import DEFAULT_STREAM, Stream if TYPE_CHECKING: from collections.abc import Iterable @@ -22,7 +22,7 @@ def get_dask_cuda_stream() -> Stream: def get_cuda_stream() -> Stream: """Get the default CUDA stream for the current thread.""" - return DEFAULT_STREAM + return Stream() def join_cuda_streams( ``` And then ran that under nsys. Here's a screenshot of the device-to-host copy that was previously on the default stream. <img width="1427" height="257" alt="Screenshot 2025-10-21 at 10 31 22 AM" src="https://github.com/user-attachments/assets/5a210e06-9e00-469c-a9e1-1000989b2bb0" /> It's not visible in the screenshot, but there are other `to_arrow_host` calls that still show up on the default stream. I'm tracking those down still. Closes #20309 Authors: - Tom Augspurger (https://github.com/TomAugspurger) Approvers: - Matthew Roeschke (https://github.com/mroeschke) URL: #20323
1 parent e10364d commit 6f1966d

File tree

9 files changed

+58
-27
lines changed

9 files changed

+58
-27
lines changed

python/cudf_polars/cudf_polars/containers/dataframe.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,15 +56,21 @@ def _create_polars_column_metadata(
5656
# This is also defined in pylibcudf.interop
5757
class _ObjectWithArrowMetadata:
5858
def __init__(
59-
self, obj: plc.Table | plc.Column, metadata: list[plc.interop.ColumnMetadata]
59+
self,
60+
obj: plc.Table | plc.Column,
61+
metadata: list[plc.interop.ColumnMetadata],
62+
stream: Stream,
6063
) -> None:
6164
self.obj = obj
6265
self.metadata = metadata
66+
self.stream = stream
6367

6468
def __arrow_c_array__(
6569
self, requested_schema: None = None
6670
) -> tuple[CapsuleType, CapsuleType]:
67-
return self.obj._to_schema(self.metadata), self.obj._to_host_array()
71+
return self.obj._to_schema(self.metadata), self.obj._to_host_array(
72+
stream=self.stream
73+
)
6874

6975

7076
# Pacify the type checker. DataFrame init asserts that all the columns
@@ -108,7 +114,9 @@ def to_polars(self) -> pl.DataFrame:
108114
_create_polars_column_metadata(name, dtype.polars_type)
109115
for name, dtype in zip(name_map, self.dtypes, strict=True)
110116
]
111-
table_with_metadata = _ObjectWithArrowMetadata(self.table, metadata)
117+
table_with_metadata = _ObjectWithArrowMetadata(
118+
self.table, metadata, self.stream
119+
)
112120
df = pl.DataFrame(table_with_metadata)
113121
return df.rename(name_map).with_columns(
114122
pl.col(c.name).set_sorted(descending=c.order == plc.types.Order.DESCENDING)

python/pylibcudf/pylibcudf/_interop_helpers.pyx

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ from pylibcudf.libcudf.interop cimport (
1212
release_arrow_device_array_raw,
1313
release_arrow_schema_raw,
1414
)
15+
from .utils cimport _get_stream
1516

1617
from dataclasses import dataclass, field
1718

@@ -33,12 +34,16 @@ class ArrowLike(metaclass=_ArrowLikeMeta):
3334

3435

3536
class _ObjectWithArrowMetadata:
36-
def __init__(self, obj, metadata=None):
37+
def __init__(self, obj, metadata=None, stream=None):
3738
self.obj = obj
3839
self.metadata = metadata
40+
self.stream = _get_stream(stream)
3941

4042
def __arrow_c_array__(self, requested_schema=None):
41-
return self.obj._to_schema(self.metadata), self.obj._to_host_array()
43+
return (
44+
self.obj._to_schema(self.metadata),
45+
self.obj._to_host_array(stream=self.stream),
46+
)
4247

4348

4449
@dataclass

python/pylibcudf/pylibcudf/column.pyi

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION.
1+
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION.
22
# SPDX-License-Identifier: Apache-2.0
33

44
from collections.abc import Iterable, Sequence
@@ -89,11 +89,13 @@ class Column:
8989
def from_rmm_buffer(
9090
buff: DeviceBuffer, dtype: DataType, size: int, children: list[Column]
9191
) -> Column: ...
92-
def to_arrow(self, metadata: list | str | None = None) -> ArrowLike: ...
92+
def to_arrow(
93+
self, metadata: list | str | None = None, stream: Stream | None = None
94+
) -> ArrowLike: ...
9395
# Private methods below are included because polars is currently using them,
9496
# but we want to remove stubs for these private methods eventually
9597
def _to_schema(self, metadata: Any = None) -> Any: ...
96-
def _to_host_array(self) -> Any: ...
98+
def _to_host_array(self, stream: Stream) -> Any: ...
9799
@staticmethod
98100
def from_arrow(
99101
obj: ArrowLike,

python/pylibcudf/pylibcudf/column.pyx

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -349,14 +349,17 @@ cdef class Column:
349349

350350
def to_arrow(
351351
self,
352-
metadata: ColumnMetadata | str | None = None
352+
metadata: ColumnMetadata | str | None = None,
353+
stream: Stream = None,
353354
) -> ArrowLike:
354355
"""Create a pyarrow array from a pylibcudf column.
355356

356357
Parameters
357358
----------
358359
metadata : ColumnMetadata | str | None
359360
The metadata to attach to the column.
361+
stream : Stream | None
362+
CUDA stream on which to perform the operation.
360363

361364
Returns
362365
-------
@@ -371,7 +374,7 @@ cdef class Column:
371374
# TODO: Once the arrow C device interface registers more
372375
# types that it supports, we can call pa.array(self) if
373376
# no metadata is passed.
374-
return pa.array(_ObjectWithArrowMetadata(self, metadata))
377+
return pa.array(_ObjectWithArrowMetadata(self, metadata, stream))
375378

376379
@staticmethod
377380
def from_arrow(
@@ -1324,10 +1327,10 @@ cdef class Column:
13241327

13251328
return PyCapsule_New(<void*>raw_schema_ptr, 'arrow_schema', _release_schema)
13261329

1327-
def _to_host_array(self):
1330+
def _to_host_array(self, Stream stream):
13281331
cdef ArrowArray* raw_host_array_ptr
13291332
with nogil:
1330-
raw_host_array_ptr = to_arrow_host_raw(self.view())
1333+
raw_host_array_ptr = to_arrow_host_raw(self.view(), stream.view())
13311334

13321335
return PyCapsule_New(<void*>raw_host_array_ptr, "arrow_array", _release_array)
13331336

@@ -1346,7 +1349,7 @@ cdef class Column:
13461349
if requested_schema is not None:
13471350
raise ValueError("pylibcudf.Column does not support alternative schema")
13481351

1349-
return self._to_schema(), self._to_host_array()
1352+
return self._to_schema(), self._to_host_array(_get_stream(None))
13501353

13511354
def __arrow_c_device_array__(self, requested_schema=None, **kwargs):
13521355
if requested_schema is not None:

python/pylibcudf/pylibcudf/libcudf/interop.pxd

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ cdef extern from *:
135135
template <typename ViewType>
136136
ArrowArray* to_arrow_host_raw(
137137
ViewType const& obj,
138-
rmm::cuda_stream_view stream = cudf::get_default_stream(),
138+
rmm::cuda_stream_view stream,
139139
rmm::device_async_resource_ref mr = cudf::get_current_device_resource_ref()) {
140140
ArrowArray *arr = new ArrowArray();
141141
auto device_arr = cudf::to_arrow_host(obj, stream, mr);
@@ -221,10 +221,12 @@ cdef extern from *:
221221
ArrowSchema *
222222
) except +libcudf_exception_handler nogil
223223
cdef ArrowArray* to_arrow_host_raw(
224-
const table_view& tbl
224+
const table_view& tbl,
225+
cuda_stream_view stream,
225226
) except +libcudf_exception_handler nogil
226227
cdef ArrowArray* to_arrow_host_raw(
227-
const column_view& tbl
228+
const column_view& tbl,
229+
cuda_stream_view stream,
228230
) except +libcudf_exception_handler nogil
229231
cdef void release_arrow_array_raw(
230232
ArrowArray *

python/pylibcudf/pylibcudf/table.pyi

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION.
1+
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION.
22
# SPDX-License-Identifier: Apache-2.0
33

44
from typing import Any
@@ -20,7 +20,7 @@ class Table:
2020
# Private methods below are included because polars is currently using them,
2121
# but we want to remove stubs for these private methods eventually
2222
def _to_schema(self, metadata: Any = None) -> Any: ...
23-
def _to_host_array(self) -> Any: ...
23+
def _to_host_array(self, stream: Stream) -> Any: ...
2424
@staticmethod
2525
def from_arrow(
2626
arrow_like: ArrowLike,

python/pylibcudf/pylibcudf/table.pyx

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -74,14 +74,17 @@ cdef class Table:
7474

7575
def to_arrow(
7676
self,
77-
metadata: list[ColumnMetadata | str] | None = None
77+
metadata: list[ColumnMetadata | str] | None = None,
78+
stream: Stream = None,
7879
) -> ArrowLike:
7980
"""Create a pyarrow table from a pylibcudf table.
8081

8182
Parameters
8283
----------
8384
metadata : list[ColumnMetadata | str] | None
8485
The metadata to attach to the columns of the table.
86+
stream : Stream | None
87+
CUDA stream on which to perform the operation.
8588

8689
Returns
8790
-------
@@ -96,7 +99,7 @@ cdef class Table:
9699
# TODO: Once the arrow C device interface registers more
97100
# types that it supports, we can call pa.table(self) if
98101
# no metadata is passed.
99-
return pa.table(_ObjectWithArrowMetadata(self, metadata))
102+
return pa.table(_ObjectWithArrowMetadata(self, metadata, stream))
100103

101104
@staticmethod
102105
def from_arrow(
@@ -320,10 +323,11 @@ cdef class Table:
320323

321324
return PyCapsule_New(<void*>raw_schema_ptr, "arrow_schema", _release_schema)
322325

323-
def _to_host_array(self):
326+
def _to_host_array(self, Stream stream):
324327
cdef ArrowArray* raw_host_array_ptr
328+
325329
with nogil:
326-
raw_host_array_ptr = to_arrow_host_raw(self.view())
330+
raw_host_array_ptr = to_arrow_host_raw(self.view(), stream.view())
327331

328332
return PyCapsule_New(<void*>raw_host_array_ptr, "arrow_array", _release_array)
329333

@@ -342,7 +346,7 @@ cdef class Table:
342346
if requested_schema is not None:
343347
raise ValueError("pylibcudf.Table does not support alternative schema")
344348

345-
return self._to_schema(), self._to_host_array()
349+
return self._to_schema(), self._to_host_array(_get_stream(None))
346350

347351
def __arrow_c_device_array__(self, requested_schema=None, **kwargs):
348352
if requested_schema is not None:
Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.
22
# SPDX-License-Identifier: Apache-2.0
33

4+
import pytest
45
from utils import assert_column_eq
56

7+
from rmm.pylibrmm.stream import Stream
68

7-
def test_column_to_arrow(table_data):
9+
10+
@pytest.mark.parametrize("stream", [None, Stream()])
11+
def test_column_to_arrow(table_data, stream):
812
plc_tbl, _ = table_data
913
for col in plc_tbl.tbl.columns():
10-
assert_column_eq(col, col.to_arrow())
14+
assert_column_eq(col, col.to_arrow(stream=stream))

python/pylibcudf/tests/test_table.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
import pytest
66
from utils import assert_table_eq
77

8+
from rmm.pylibrmm.stream import Stream
9+
810
import pylibcudf as plc
911

1012

@@ -23,10 +25,11 @@ def test_table_shape(arrow_tbl):
2325
assert plc_tbl.shape() == arrow_tbl.shape
2426

2527

26-
def test_table_to_arrow(table_data):
28+
@pytest.mark.parametrize("stream", [None, Stream()])
29+
def test_table_to_arrow(table_data, stream):
2730
plc_tbl, _ = table_data
2831
expect = plc_tbl.tbl
29-
got = expect.to_arrow()
32+
got = expect.to_arrow(stream=stream)
3033
# The order of `got` and `expect` is reversed here
3134
# because in almost all pylibcudf tests the `expect`
3235
# is a pyarrow object while `got` is a pylibcudf object,

0 commit comments

Comments
 (0)