|
| 1 | +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. |
| 2 | +# SPDX-License-Identifier: Apache-2.0 |
| 3 | + |
| 4 | +from libcpp.vector cimport vector |
| 5 | + |
| 6 | +from pylibcudf.libcudf.detail.utilities cimport stream_pool as cpp_stream_pool |
| 7 | +from pylibcudf.libcudf.utilities.span cimport host_span |
| 8 | + |
| 9 | +from rmm.librmm.cuda_stream_view cimport cuda_stream_view |
| 10 | +from rmm.pylibrmm.stream cimport Stream |
| 11 | + |
| 12 | +ctypedef const cuda_stream_view const_cuda_stream_view |
| 13 | + |
| 14 | + |
| 15 | +__all__ = ["join_streams"] |
| 16 | + |
| 17 | + |
| 18 | +cpdef void join_streams(list streams, Stream stream): |
| 19 | + """Synchronize a stream to an event on a set of streams. |
| 20 | +
|
| 21 | + This function synchronizes the joined stream with the waited-on streams |
| 22 | + by placing events on each of the waited-on streams and having the joined |
| 23 | + stream wait on those events. |
| 24 | +
|
| 25 | + Parameters |
| 26 | + ---------- |
| 27 | + streams : list |
| 28 | + A list of Stream objects to wait on. |
| 29 | + stream : Stream |
| 30 | + The joined stream that synchronizes with the waited-on streams. |
| 31 | +
|
| 32 | + Examples |
| 33 | + -------- |
| 34 | + >>> import pylibcudf as plc |
| 35 | + >>> from rmm.pylibrmm.stream import Stream |
| 36 | + >>> # Create streams |
| 37 | + >>> stream1 = Stream() |
| 38 | + >>> stream2 = Stream() |
| 39 | + >>> join_stream = Stream() |
| 40 | + >>> # ... do work on stream1 and stream2 ... |
| 41 | + >>> # Wait for both streams before continuing work on join_stream |
| 42 | + >>> plc.experimental.join_streams([stream1, stream2], join_stream) |
| 43 | + >>> # ... continue work on join_stream ... |
| 44 | + """ |
| 45 | + cdef Stream c_stream = <Stream?>stream |
| 46 | + cdef vector[cuda_stream_view] c_streams |
| 47 | + |
| 48 | + c_streams.reserve(len(streams)) |
| 49 | + for s in streams: |
| 50 | + c_streams.push_back((<Stream?>s).view()) |
| 51 | + |
| 52 | + with nogil: |
| 53 | + cpp_stream_pool.join_streams( |
| 54 | + host_span[const_cuda_stream_view](c_streams.data(), c_streams.size()), |
| 55 | + c_stream.view() |
| 56 | + ) |
0 commit comments