2020
2121 from typing_extensions import Any , CapsuleType , Self
2222
23- from cudf_polars . typing import ColumnOptions , DataFrameHeader , PolarsDataType , Slice
23+ from rmm . pylibrmm . stream import Stream
2424
25+ from cudf_polars .typing import ColumnOptions , DataFrameHeader , PolarsDataType , Slice
2526
# Public names exported by this module; only ``DataFrame`` is listed.
__all__: list[str] = ["DataFrame"]
2728
@@ -78,19 +79,21 @@ class DataFrame:
7879 column_map : dict [str , Column ]
7980 table : plc .Table
8081 columns : list [NamedColumn ]
82+ stream : Stream
8183
82- def __init__ (self , columns : Iterable [Column ]) -> None :
84+ def __init__ (self , columns : Iterable [Column ], stream : Stream ) -> None :
8385 columns = list (columns )
8486 if any (c .name is None for c in columns ):
8587 raise ValueError ("All columns must have a name" )
8688 self .columns = [cast (NamedColumn , c ) for c in columns ]
8789 self .dtypes = [c .dtype for c in self .columns ]
8890 self .column_map = {c .name : c for c in self .columns }
8991 self .table = plc .Table ([c .obj for c in self .columns ])
92+ self .stream = stream
9093
9194 def copy (self ) -> Self :
9295 """Return a shallow copy of self."""
93- return type (self )(c .copy () for c in self .columns )
96+ return type (self )(( c .copy () for c in self .columns ), stream = self . stream )
9497
9598 def to_polars (self ) -> pl .DataFrame :
9699 """Convert to a polars DataFrame."""
@@ -135,30 +138,42 @@ def num_rows(self) -> int:
135138 return self .table .num_rows () if self .column_map else 0
136139
137140 @classmethod
138- def from_polars (cls , df : pl .DataFrame ) -> Self :
141+ def from_polars (cls , df : pl .DataFrame , stream : Stream ) -> Self :
139142 """
140143 Create from a polars dataframe.
141144
142145 Parameters
143146 ----------
144147 df
145148 Polars dataframe to convert
149+ stream
150+ CUDA stream used for device memory operations and kernel launches
151+ on this dataframe.
146152
147153 Returns
148154 -------
149155 New dataframe representing the input.
150156 """
151- plc_table = plc .Table .from_arrow (df )
157+ plc_table = plc .Table .from_arrow (df , stream = stream )
152158 return cls (
153- Column (d_col , name = name , dtype = DataType (h_col .dtype )).copy_metadata (h_col )
154- for d_col , h_col , name in zip (
155- plc_table .columns (), df .iter_columns (), df .columns , strict = True
156- )
159+ (
160+ Column (d_col , name = name , dtype = DataType (h_col .dtype )).copy_metadata (
161+ h_col
162+ )
163+ for d_col , h_col , name in zip (
164+ plc_table .columns (), df .iter_columns (), df .columns , strict = True
165+ )
166+ ),
167+ stream = stream ,
157168 )
158169
159170 @classmethod
160171 def from_table (
161- cls , table : plc .Table , names : Sequence [str ], dtypes : Sequence [DataType ]
172+ cls ,
173+ table : plc .Table ,
174+ names : Sequence [str ],
175+ dtypes : Sequence [DataType ],
176+ stream : Stream ,
162177 ) -> Self :
163178 """
164179 Create from a pylibcudf table.
@@ -171,6 +186,10 @@ def from_table(
171186 Names for the columns
172187 dtypes
173188 Dtypes for the columns
189+ stream
190+ CUDA stream used for device memory operations and kernel launches
191+ on this dataframe. The caller is responsible for ensuring that
192+ the data in ``table`` is valid on ``stream``.
174193
175194 Returns
176195 -------
@@ -185,15 +204,19 @@ def from_table(
185204 if table .num_columns () != len (names ):
186205 raise ValueError ("Mismatching name and table length." )
187206 return cls (
188- Column (c , name = name , dtype = dtype )
189- for c , name , dtype in zip (table .columns (), names , dtypes , strict = True )
207+ (
208+ Column (c , name = name , dtype = dtype )
209+ for c , name , dtype in zip (table .columns (), names , dtypes , strict = True )
210+ ),
211+ stream = stream ,
190212 )
191213
192214 @classmethod
193215 def deserialize (
194216 cls ,
195217 header : DataFrameHeader ,
196218 frames : tuple [memoryview [bytes ], plc .gpumemoryview ],
219+ stream : Stream ,
197220 ) -> Self :
198221 """
199222 Create a DataFrame from a serialized representation returned by `.serialize()`.
@@ -204,6 +227,10 @@ def deserialize(
204227 The (unpickled) metadata required to reconstruct the object.
205228 frames
206229 Two-tuple of frames (a memoryview and a gpumemoryview).
230+ stream
231+ CUDA stream used for device memory operations and kernel launches
232+ on this dataframe. The caller is responsible for ensuring that
233+ the data in ``frames`` is valid on ``stream``.
207234
208235 Returns
209236 -------
@@ -212,11 +239,15 @@ def deserialize(
212239 """
213240 packed_metadata , packed_gpu_data = frames
214241 table = plc .contiguous_split .unpack_from_memoryviews (
215- packed_metadata , packed_gpu_data
242+ packed_metadata ,
243+ packed_gpu_data ,
216244 )
217245 return cls (
218- Column (c , ** Column .deserialize_ctor_kwargs (kw ))
219- for c , kw in zip (table .columns (), header ["columns_kwargs" ], strict = True )
246+ (
247+ Column (c , ** Column .deserialize_ctor_kwargs (kw ))
248+ for c , kw in zip (table .columns (), header ["columns_kwargs" ], strict = True )
249+ ),
250+ stream = stream ,
220251 )
221252
222253 def serialize (
@@ -240,7 +271,7 @@ def serialize(
240271 frames
241272 Two-tuple of frames suitable for passing to `plc.contiguous_split.unpack_from_memoryviews`
242273 """
243- packed = plc .contiguous_split .pack (self .table )
274+ packed = plc .contiguous_split .pack (self .table , stream = self . stream )
244275
245276 # Keyword arguments for `Column.__init__`.
246277 columns_kwargs : list [ColumnOptions ] = [
@@ -278,12 +309,19 @@ def sorted_like(
278309 raise ValueError ("Can only copy from identically named frame" )
279310 subset = self .column_names_set if subset is None else subset
280311 return type (self )(
281- c .sorted_like (other ) if c .name in subset else c
282- for c , other in zip (self .columns , like .columns , strict = True )
312+ (
313+ c .sorted_like (other ) if c .name in subset else c
314+ for c , other in zip (self .columns , like .columns , strict = True )
315+ ),
316+ stream = self .stream ,
283317 )
284318
285319 def with_columns (
286- self , columns : Iterable [Column ], * , replace_only : bool = False
320+ self ,
321+ columns : Iterable [Column ],
322+ * ,
323+ replace_only : bool = False ,
324+ stream : Stream ,
287325 ) -> Self :
288326 """
289327 Return a new dataframe with extra columns.
@@ -294,6 +332,13 @@ def with_columns(
294332 Columns to add
295333 replace_only
296334 If true, then only replacements are allowed (matching by name).
335+ stream
336+ CUDA stream used for device memory operations and kernel launches.
337+ The caller is responsible for ensuring that
338+
339+ 1. The data in ``columns`` is valid on ``stream``.
340+ 2. No additional operations occur on ``self.stream`` with the
341+ original data in ``self``.
297342
298343 Returns
299344 -------
@@ -307,33 +352,57 @@ def with_columns(
307352 new = {c .name : c for c in columns }
308353 if replace_only and not self .column_names_set .issuperset (new .keys ()):
309354 raise ValueError ("Cannot replace with non-existing names" )
310- return type (self )((self .column_map | new ).values ())
355+ return type (self )((self .column_map | new ).values (), stream = stream )
311356
312357 def discard_columns (self , names : Set [str ]) -> Self :
313358 """Drop columns by name."""
314- return type (self )(column for column in self .columns if column .name not in names )
359+ return type (self )(
360+ (column for column in self .columns if column .name not in names ),
361+ stream = self .stream ,
362+ )
315363
316364 def select (self , names : Sequence [str ] | Mapping [str , Any ]) -> Self :
317365 """Select columns by name returning DataFrame."""
318366 try :
319- return type (self )(self .column_map [name ] for name in names )
367+ return type (self )(
368+ (self .column_map [name ] for name in names ), stream = self .stream
369+ )
320370 except KeyError as e :
321371 raise ValueError ("Can't select missing names" ) from e
322372
323373 def rename_columns (self , mapping : Mapping [str , str ]) -> Self :
324374 """Rename some columns."""
325- return type (self )(c .rename (mapping .get (c .name , c .name )) for c in self .columns )
375+ return type (self )(
376+ (c .rename (mapping .get (c .name , c .name )) for c in self .columns ),
377+ stream = self .stream ,
378+ )
326379
327380 def select_columns (self , names : Set [str ]) -> list [Column ]:
328381 """Select columns by name."""
329382 return [c for c in self .columns if c .name in names ]
330383
331384 def filter (self , mask : Column ) -> Self :
332- """Return a filtered table given a mask."""
333- table = plc .stream_compaction .apply_boolean_mask (self .table , mask .obj )
385+ """
386+ Return a filtered table given a mask.
387+
388+ Parameters
389+ ----------
390+ mask
391+ Boolean mask to apply to the dataframe. It is the caller's
392+ responsibility to ensure that ``mask`` is valid on ``self.stream``.
393+ A mask that is derived from ``self`` via a computation on ``self.stream``
394+ automatically satisfies this requirement.
395+
396+ Returns
397+ -------
398+ Filtered dataframe
399+ """
400+ table = plc .stream_compaction .apply_boolean_mask (
401+ self .table , mask .obj , stream = self .stream
402+ )
334403 return (
335404 type (self )
336- .from_table (table , self .column_names , self .dtypes )
405+ .from_table (table , self .column_names , self .dtypes , self . stream )
337406 .sorted_like (self )
338407 )
339408
@@ -354,10 +423,12 @@ def slice(self, zlice: Slice | None) -> Self:
354423 if zlice is None :
355424 return self
356425 (table ,) = plc .copying .slice (
357- self .table , conversion .from_polars_slice (zlice , num_rows = self .num_rows )
426+ self .table ,
427+ conversion .from_polars_slice (zlice , num_rows = self .num_rows ),
428+ stream = self .stream ,
358429 )
359430 return (
360431 type (self )
361- .from_table (table , self .column_names , self .dtypes )
432+ .from_table (table , self .column_names , self .dtypes , self . stream )
362433 .sorted_like (self )
363434 )
0 commit comments