|
44 | 44 | as_buffer, |
45 | 45 | cuda_array_interface_wrapper, |
46 | 46 | ) |
| 47 | +from cudf.core.buffer.spillable_buffer import SpillableBuffer |
47 | 48 | from cudf.core.copy_types import GatherMap |
48 | 49 | from cudf.core.dtypes import ( |
49 | 50 | CategoricalDtype, |
@@ -121,6 +122,41 @@ def _can_values_be_equal(left: DtypeObj, right: DtypeObj) -> bool: |
121 | 122 | return False |
122 | 123 |
|
123 | 124 |
|
| 125 | +class spillable_gpumemoryview(plc.gpumemoryview): |
| 126 | + """ |
| 127 | + HACK: Prevent automatic unspilling of `SpillableBuffer` objects |
| 128 | + when constructing `plc.Column`. |
| 129 | +
|
| 130 | + The `plc.Column()` constructor expects a `gpumemoryview` object, |
| 131 | + but wrapping a `SpillableBuffer` directly in a `gpumemoryview` |
| 132 | + forces the buffer to unspill (materialize) its device data prematurely. |
| 133 | +
|
| 134 | + To avoid this, we wrap spillable buffers in this subclass that implements |
| 135 | + only the `.obj` attribute; the only attribute actually accessed by |
| 136 | + `.from_pylibcudf()`. All other attributes intentionally raise errors to |
| 137 | + prevent accidental usage paths that would cause unspilling. |
| 138 | + """ |
| 139 | + |
| 140 | + def __init__(self, buf: SpillableBuffer) -> None: |
| 141 | + self._buf = buf |
| 142 | + |
| 143 | + @property |
| 144 | + def obj(self) -> SpillableBuffer: |
| 145 | + return self._buf |
| 146 | + |
| 147 | + @property |
| 148 | + def cai(self) -> None: # type: ignore[override] |
| 149 | + assert False |
| 150 | + |
| 151 | + @property |
| 152 | + def ptr(self) -> None: # type: ignore[override] |
| 153 | + assert False |
| 154 | + |
| 155 | + @property |
| 156 | + def nbytes(self) -> None: # type: ignore[override] |
| 157 | + assert False |
| 158 | + |
| 159 | + |
124 | 160 | class ColumnBase(Serializable, BinaryOperand, Reducible): |
125 | 161 | """ |
126 | 162 | A ColumnBase stores columnar data in device memory. |
@@ -411,9 +447,13 @@ def set_mask(self, value) -> Self: |
411 | 447 | new_plc_column = self.to_pylibcudf( |
412 | 448 | mode="read", use_base=False |
413 | 449 | ).with_mask(new_mask, new_null_count) |
414 | | - return self.from_pylibcudf( # type: ignore[return-value] |
415 | | - new_plc_column, |
416 | | - )._with_type_metadata(self.dtype) |
| 450 | + return ( |
| 451 | + type(self) |
| 452 | + .from_pylibcudf( # type: ignore[return-value] |
| 453 | + new_plc_column, |
| 454 | + ) |
| 455 | + ._with_type_metadata(self.dtype) |
| 456 | + ) |
417 | 457 |
|
418 | 458 | @property |
419 | 459 | def null_count(self) -> int: |
@@ -1948,6 +1988,8 @@ def serialize(self) -> tuple[dict, list]: |
1948 | 1988 | ) |
1949 | 1989 | header["subheaders"] = list(child_headers) |
1950 | 1990 | frames.extend(chain(*child_frames)) |
| 1991 | + if isinstance(self.dtype, CategoricalDtype): |
| 1992 | + header["codes_dtype"] = self.codes.dtype.str # type: ignore[attr-defined] |
1951 | 1993 | header["size"] = self.size |
1952 | 1994 | header["frame_count"] = len(frames) |
1953 | 1995 | return header, frames |
@@ -1984,13 +2026,45 @@ def unpack(header, frames) -> tuple[Any, list]: |
1984 | 2026 | child, frames = unpack(h, frames) |
1985 | 2027 | children.append(child) |
1986 | 2028 | assert len(frames) == 0, "Deserialization did not consume all frames" |
1987 | | - return build_column( |
1988 | | - data=data, |
1989 | | - dtype=dtype, |
1990 | | - mask=mask, |
1991 | | - size=header.get("size", None), |
1992 | | - children=tuple(children), |
| 2029 | + if "codes_dtype" in header: |
| 2030 | + codes_dtype = np.dtype(header["codes_dtype"]) |
| 2031 | + else: |
| 2032 | + codes_dtype = None |
| 2033 | + if mask is None: |
| 2034 | + null_count = 0 |
| 2035 | + else: |
| 2036 | + null_count = plc.null_mask.null_count( |
| 2037 | + plc.gpumemoryview(mask), 0, header["size"] |
| 2038 | + ) |
| 2039 | + if isinstance(dtype, IntervalDtype): |
| 2040 | + # TODO: Handle in dtype_to_pylibcudf_type? |
| 2041 | + plc_type = plc.DataType(plc.TypeId.STRUCT) |
| 2042 | + else: |
| 2043 | + plc_type = dtype_to_pylibcudf_type( |
| 2044 | + codes_dtype if codes_dtype is not None else dtype |
| 2045 | + ) |
| 2046 | + if isinstance(dtype, CategoricalDtype): |
| 2047 | + data = children.pop(0) |
| 2048 | + |
| 2049 | + if isinstance(data, SpillableBuffer): |
| 2050 | + data = spillable_gpumemoryview(data) |
| 2051 | + elif data is not None: |
| 2052 | + data = plc.gpumemoryview(data) |
| 2053 | + if isinstance(mask, SpillableBuffer): |
| 2054 | + mask = spillable_gpumemoryview(mask) |
| 2055 | + elif mask is not None: |
| 2056 | + mask = plc.gpumemoryview(mask) |
| 2057 | + |
| 2058 | + plc_column = plc.Column( |
| 2059 | + plc_type, |
| 2060 | + header["size"], |
| 2061 | + data, |
| 2062 | + mask, |
| 2063 | + null_count, |
| 2064 | + 0, |
| 2065 | + [child.to_pylibcudf(mode="read") for child in children], |
1993 | 2066 | ) |
| 2067 | + return cls.from_pylibcudf(plc_column)._with_type_metadata(dtype) |
1994 | 2068 |
|
1995 | 2069 | def unary_operator(self, unaryop: str): |
1996 | 2070 | raise TypeError( |
|
0 commit comments