Skip to content

Commit 73d722e

Browse files
authored
Make ColumnBase.deserialize construct via pylibcudf (#20142)
Broken off from #20087 Authors: - Matthew Roeschke (https://github.com/mroeschke) - Mads R. B. Kristensen (https://github.com/madsbk) Approvers: - Vyas Ramasubramani (https://github.com/vyasr) URL: #20142
1 parent adc8b23 commit 73d722e

File tree

1 file changed

+83
-9
lines changed

1 file changed

+83
-9
lines changed

python/cudf/cudf/core/column/column.py

Lines changed: 83 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
as_buffer,
4545
cuda_array_interface_wrapper,
4646
)
47+
from cudf.core.buffer.spillable_buffer import SpillableBuffer
4748
from cudf.core.copy_types import GatherMap
4849
from cudf.core.dtypes import (
4950
CategoricalDtype,
@@ -121,6 +122,41 @@ def _can_values_be_equal(left: DtypeObj, right: DtypeObj) -> bool:
121122
return False
122123

123124

125+
class spillable_gpumemoryview(plc.gpumemoryview):
126+
"""
127+
HACK: Prevent automatic unspilling of `SpillableBuffer` objects
128+
when constructing `plc.Column`.
129+
130+
The `plc.Column()` constructor expects a `gpumemoryview` object,
131+
but wrapping a `SpillableBuffer` directly in a `gpumemoryview`
132+
forces the buffer to unspill (materialize) its device data prematurely.
133+
134+
To avoid this, we wrap spillable buffers in this subclass that implements
135+
only the `.obj` attribute; the only attribute actually accessed by
136+
`.from_pylibcudf()`. All other attributes intentionally raise errors to
137+
prevent accidental usage paths that would cause unspilling.
138+
"""
139+
140+
def __init__(self, buf: SpillableBuffer) -> None:
141+
self._buf = buf
142+
143+
@property
144+
def obj(self) -> SpillableBuffer:
145+
return self._buf
146+
147+
@property
148+
def cai(self) -> None: # type: ignore[override]
149+
assert False
150+
151+
@property
152+
def ptr(self) -> None: # type: ignore[override]
153+
assert False
154+
155+
@property
156+
def nbytes(self) -> None: # type: ignore[override]
157+
assert False
158+
159+
124160
class ColumnBase(Serializable, BinaryOperand, Reducible):
125161
"""
126162
A ColumnBase stores columnar data in device memory.
@@ -411,9 +447,13 @@ def set_mask(self, value) -> Self:
411447
new_plc_column = self.to_pylibcudf(
412448
mode="read", use_base=False
413449
).with_mask(new_mask, new_null_count)
414-
return self.from_pylibcudf( # type: ignore[return-value]
415-
new_plc_column,
416-
)._with_type_metadata(self.dtype)
450+
return (
451+
type(self)
452+
.from_pylibcudf( # type: ignore[return-value]
453+
new_plc_column,
454+
)
455+
._with_type_metadata(self.dtype)
456+
)
417457

418458
@property
419459
def null_count(self) -> int:
@@ -1948,6 +1988,8 @@ def serialize(self) -> tuple[dict, list]:
19481988
)
19491989
header["subheaders"] = list(child_headers)
19501990
frames.extend(chain(*child_frames))
1991+
if isinstance(self.dtype, CategoricalDtype):
1992+
header["codes_dtype"] = self.codes.dtype.str # type: ignore[attr-defined]
19511993
header["size"] = self.size
19521994
header["frame_count"] = len(frames)
19531995
return header, frames
@@ -1984,13 +2026,45 @@ def unpack(header, frames) -> tuple[Any, list]:
19842026
child, frames = unpack(h, frames)
19852027
children.append(child)
19862028
assert len(frames) == 0, "Deserialization did not consume all frames"
1987-
return build_column(
1988-
data=data,
1989-
dtype=dtype,
1990-
mask=mask,
1991-
size=header.get("size", None),
1992-
children=tuple(children),
2029+
if "codes_dtype" in header:
2030+
codes_dtype = np.dtype(header["codes_dtype"])
2031+
else:
2032+
codes_dtype = None
2033+
if mask is None:
2034+
null_count = 0
2035+
else:
2036+
null_count = plc.null_mask.null_count(
2037+
plc.gpumemoryview(mask), 0, header["size"]
2038+
)
2039+
if isinstance(dtype, IntervalDtype):
2040+
# TODO: Handle in dtype_to_pylibcudf_type?
2041+
plc_type = plc.DataType(plc.TypeId.STRUCT)
2042+
else:
2043+
plc_type = dtype_to_pylibcudf_type(
2044+
codes_dtype if codes_dtype is not None else dtype
2045+
)
2046+
if isinstance(dtype, CategoricalDtype):
2047+
data = children.pop(0)
2048+
2049+
if isinstance(data, SpillableBuffer):
2050+
data = spillable_gpumemoryview(data)
2051+
elif data is not None:
2052+
data = plc.gpumemoryview(data)
2053+
if isinstance(mask, SpillableBuffer):
2054+
mask = spillable_gpumemoryview(mask)
2055+
elif mask is not None:
2056+
mask = plc.gpumemoryview(mask)
2057+
2058+
plc_column = plc.Column(
2059+
plc_type,
2060+
header["size"],
2061+
data,
2062+
mask,
2063+
null_count,
2064+
0,
2065+
[child.to_pylibcudf(mode="read") for child in children],
19932066
)
2067+
return cls.from_pylibcudf(plc_column)._with_type_metadata(dtype)
19942068

19952069
def unary_operator(self, unaryop: str):
19962070
raise TypeError(

0 commit comments

Comments
 (0)