Skip to content

Commit 0693083

Browse files
Add join_streams to pylibcudf API (#20316)
This adds a `plc.experimental.join_streams` API to pylibcudf. It wraps `cudf::detail::join_streams`. We've used the `.experimental` namespace in part to reflect the fact that the C++ implementation we're wrapping is in a `detail` namespace. This change will support our work on making cudf-polars use non-default streams for operations, which is important for performance with the in-progress rapidsmpf streaming network runtime. We only need the ability to join some stream to a sequence of upstream streams, so only `join_streams` is wrapped, nothing to do with `stream_pools` that are also provided in that module. Closes #20315 Authors: - Tom Augspurger (https://github.com/TomAugspurger) Approvers: - Lawrence Mitchell (https://github.com/wence-) - Matthew Murray (https://github.com/Matt711) - Bradley Dice (https://github.com/bdice) URL: #20316
1 parent 6f1966d commit 0693083

File tree

16 files changed

+174
-0
lines changed

16 files changed

+174
-0
lines changed
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
============
2+
Experimental
3+
============
4+
5+
APIs in this namespace are experimental and may change without warning in the future.
6+
7+
.. automodule:: pylibcudf.experimental
8+
:members:

docs/cudf/source/pylibcudf/api_docs/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ This page provides API documentation for pylibcudf.
5151
:maxdepth: 2
5252
:caption: Subpackages
5353

54+
experimental/index.rst
5455
io/index.rst
5556
strings/index.rst
5657
nvtext/index.rst

python/pylibcudf/pylibcudf/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,3 +67,4 @@ add_subdirectory(libcudf)
6767
add_subdirectory(strings)
6868
add_subdirectory(io)
6969
add_subdirectory(nvtext)
70+
add_subdirectory(experimental)

python/pylibcudf/pylibcudf/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
contiguous_split,
2020
copying,
2121
datetime,
22+
experimental,
2223
expressions,
2324
filling,
2425
groupby,
@@ -72,6 +73,7 @@
7273
"contiguous_split",
7374
"copying",
7475
"datetime",
76+
"experimental",
7577
"expressions",
7678
"filling",
7779
"gpumemoryview",
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# =============================================================================
2+
# cmake-format: off
3+
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION.
4+
# SPDX-License-Identifier: Apache-2.0
5+
# cmake-format: on
6+
# =============================================================================
7+
8+
set(cython_sources _join_streams.pyx)
9+
10+
set(linked_libraries cudf::cudf)
11+
rapids_cython_create_modules(
12+
CXX
13+
SOURCE_FILES "${cython_sources}"
14+
LINKED_LIBRARIES "${linked_libraries}" MODULE_PREFIX pylibcudf_experimental_ ASSOCIATED_TARGETS
15+
cudf
16+
)
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
from ._join_streams cimport join_streams
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
from ._join_streams import join_streams
5+
6+
__all__ = [
7+
"join_streams",
8+
]
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
from rmm.pylibrmm.stream cimport Stream
5+
6+
cpdef void join_streams(list streams, Stream stream)
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
from rmm.pylibrmm.stream import Stream
5+
6+
def join_streams(streams: list[Stream], stream: Stream) -> None: ...
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
from libcpp.vector cimport vector
5+
6+
from pylibcudf.libcudf.detail.utilities cimport stream_pool as cpp_stream_pool
7+
from pylibcudf.libcudf.utilities.span cimport host_span
8+
9+
from rmm.librmm.cuda_stream_view cimport cuda_stream_view
10+
from rmm.pylibrmm.stream cimport Stream
11+
12+
ctypedef const cuda_stream_view const_cuda_stream_view
13+
14+
15+
__all__ = ["join_streams"]
16+
17+
18+
cpdef void join_streams(list streams, Stream stream):
19+
"""Synchronize a stream to an event on a set of streams.
20+
21+
This function synchronizes the joined stream with the waited-on streams
22+
by placing events on each of the waited-on streams and having the joined
23+
stream wait on those events.
24+
25+
Parameters
26+
----------
27+
streams : list
28+
A list of Stream objects to wait on.
29+
stream : Stream
30+
The joined stream that synchronizes with the waited-on streams.
31+
32+
Examples
33+
--------
34+
>>> import pylibcudf as plc
35+
>>> from rmm.pylibrmm.stream import Stream
36+
>>> # Create streams
37+
>>> stream1 = Stream()
38+
>>> stream2 = Stream()
39+
>>> join_stream = Stream()
40+
>>> # ... do work on stream1 and stream2 ...
41+
>>> # Wait for both streams before continuing work on join_stream
42+
>>> plc.experimental.join_streams([stream1, stream2], join_stream)
43+
>>> # ... continue work on join_stream ...
44+
"""
45+
cdef Stream c_stream = <Stream?>stream
46+
cdef vector[cuda_stream_view] c_streams
47+
48+
c_streams.reserve(len(streams))
49+
for s in streams:
50+
c_streams.push_back((<Stream?>s).view())
51+
52+
with nogil:
53+
cpp_stream_pool.join_streams(
54+
host_span[const_cuda_stream_view](c_streams.data(), c_streams.size()),
55+
c_stream.view()
56+
)

0 commit comments

Comments
 (0)