Skip to content

Commit 6fdbd4e

Browse files
Add context to IR.do_evaluate (#20322)
This adds a keyword-only `context` argument to cudf_polars IR.do_evaluate method. The purpose to provide access to special pieces of data that might be necessary for controlling an IR nodes' execution, but doesn't belong on the IR node itself as a non-child argument. Specifically, we'd like to provide a CUDA `stream` argument as part of #20228, but we generalize that slightly and provide a system for providing arbitrary data. A few notes on the implementation: - For now, the context is just an empty dataclass. I suspect its design might change in the future. - I've opted to push the creation of the context as high as possible. For now it's created in `_callback` and passed into `ir.evaluate` / `evaluate_streaming` and from there to all the methods that require it. - There's some awkwardness between how our IR nodes and Dask's task graph treat arguments. I've opted to make `context` keyword only in `IR.do_evaluate(..., context)`. However, Dask's task graph doesn't really deal with that. It wants a tuple of `(function, arg1, arg2, ...)`. So that requires using `functools.partial(function, context=context)(arg1, arg2, ...)`. - After implementing this, I realized that `Expr.evaluate` *also* takes a context, and its a different type `ExecutionContext` :( I can rename the IR variant if we want. Just a draft for now, and probably not worth reviewing until I have a branch somewhere that combines CUDA streams with this to verify it meets our needs. Authors: - Tom Augspurger (https://github.com/TomAugspurger) Approvers: - Matthew Murray (https://github.com/Matt711) URL: #20322
1 parent 46670ea commit 6fdbd4e

File tree

19 files changed

+243
-67
lines changed

19 files changed

+243
-67
lines changed

python/cudf_polars/cudf_polars/callback.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from rmm._cuda import gpu
2424

2525
import cudf_polars.dsl.tracing
26+
from cudf_polars.dsl.ir import IRExecutionContext
2627
from cudf_polars.dsl.tracing import CUDF_POLARS_NVTX_DOMAIN
2728
from cudf_polars.dsl.translate import Translator
2829
from cudf_polars.utils.config import _env_get_int, get_total_device_memory
@@ -218,14 +219,17 @@ def _callback(
218219
assert n_rows is None
219220
if timer is not None:
220221
assert should_time
222+
223+
context = IRExecutionContext()
224+
221225
with (
222226
nvtx.annotate(message="ExecuteIR", domain=CUDF_POLARS_NVTX_DOMAIN),
223227
# Device must be set before memory resource is obtained.
224228
set_device(config_options.device),
225229
set_memory_resource(memory_resource),
226230
):
227231
if config_options.executor.name == "in-memory":
228-
df = ir.evaluate(cache={}, timer=timer).to_polars()
232+
df = ir.evaluate(cache={}, timer=timer, context=context).to_polars()
229233
if timer is None:
230234
return df
231235
else:
@@ -243,7 +247,7 @@ def _callback(
243247
""")
244248
raise NotImplementedError(msg)
245249

246-
return evaluate_streaming(ir, config_options).to_polars()
250+
return evaluate_streaming(ir, config_options, context=context).to_polars()
247251
assert_never(f"Unknown executor '{config_options.executor}'")
248252

249253

python/cudf_polars/cudf_polars/dsl/ir.py

Lines changed: 89 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import json
1818
import random
1919
import time
20+
from dataclasses import dataclass
2021
from functools import cache
2122
from pathlib import Path
2223
from typing import TYPE_CHECKING, Any, ClassVar, overload
@@ -72,6 +73,7 @@
7273
"GroupBy",
7374
"HConcat",
7475
"HStack",
76+
"IRExecutionContext",
7577
"Join",
7678
"MapFunction",
7779
"MergeSorted",
@@ -88,6 +90,16 @@
8890
]
8991

9092

93+
@dataclass(frozen=True)
94+
class IRExecutionContext:
95+
"""
96+
Runtime context for IR node execution.
97+
98+
This dataclass holds runtime information and configuration needed
99+
during the evaluation of IR nodes.
100+
"""
101+
102+
91103
_BINOPS = {
92104
plc.binaryop.BinaryOperator.EQUAL,
93105
plc.binaryop.BinaryOperator.NOT_EQUAL,
@@ -158,7 +170,9 @@ def get_hashable(self) -> Hashable:
158170
translation phase should fail earlier.
159171
"""
160172

161-
def evaluate(self, *, cache: CSECache, timer: Timer | None) -> DataFrame:
173+
def evaluate(
174+
self, *, cache: CSECache, timer: Timer | None, context: IRExecutionContext
175+
) -> DataFrame:
162176
"""
163177
Evaluate the node (recursively) and return a dataframe.
164178
@@ -170,6 +184,8 @@ def evaluate(self, *, cache: CSECache, timer: Timer | None) -> DataFrame:
170184
timer
171185
If not None, a Timer object to record timings for the
172186
evaluation of the node.
187+
context
188+
The execution context for the node.
173189
174190
Notes
175191
-----
@@ -188,16 +204,19 @@ def evaluate(self, *, cache: CSECache, timer: Timer | None) -> DataFrame:
188204
If evaluation fails. Ideally this should not occur, since the
189205
translation phase should fail earlier.
190206
"""
191-
children = [child.evaluate(cache=cache, timer=timer) for child in self.children]
207+
children = [
208+
child.evaluate(cache=cache, timer=timer, context=context)
209+
for child in self.children
210+
]
192211
if timer is not None:
193212
start = time.monotonic_ns()
194-
result = self.do_evaluate(*self._non_child_args, *children)
213+
result = self.do_evaluate(*self._non_child_args, *children, context=context)
195214
end = time.monotonic_ns()
196215
# TODO: Set better names on each class object.
197216
timer.store(start, end, type(self).__name__)
198217
return result
199218
else:
200-
return self.do_evaluate(*self._non_child_args, *children)
219+
return self.do_evaluate(*self._non_child_args, *children, context=context)
201220

202221

203222
class ErrorNode(IR):
@@ -587,6 +606,8 @@ def do_evaluate(
587606
include_file_paths: str | None,
588607
predicate: expr.NamedExpr | None,
589608
parquet_options: ParquetOptions,
609+
*,
610+
context: IRExecutionContext,
590611
) -> DataFrame:
591612
"""Evaluate and return a dataframe."""
592613
stream = get_cuda_stream()
@@ -1111,6 +1132,8 @@ def do_evaluate(
11111132
parquet_options: ParquetOptions,
11121133
options: dict[str, Any],
11131134
df: DataFrame,
1135+
*,
1136+
context: IRExecutionContext,
11141137
) -> DataFrame:
11151138
"""Write the dataframe to a file."""
11161139
target = plc.io.SinkInfo([path])
@@ -1164,22 +1187,29 @@ def is_equal(self, other: Self) -> bool: # noqa: D102
11641187
@log_do_evaluate
11651188
@nvtx_annotate_cudf_polars(message="Cache")
11661189
def do_evaluate(
1167-
cls, key: int, refcount: int | None, df: DataFrame
1190+
cls,
1191+
key: int,
1192+
refcount: int | None,
1193+
df: DataFrame,
1194+
*,
1195+
context: IRExecutionContext,
11681196
) -> DataFrame: # pragma: no cover; basic evaluation never calls this
11691197
"""Evaluate and return a dataframe."""
11701198
# Our value has already been computed for us, so let's just
11711199
# return it.
11721200
return df
11731201

1174-
def evaluate(self, *, cache: CSECache, timer: Timer | None) -> DataFrame:
1202+
def evaluate(
1203+
self, *, cache: CSECache, timer: Timer | None, context: IRExecutionContext
1204+
) -> DataFrame:
11751205
"""Evaluate and return a dataframe."""
11761206
# We must override the recursion scheme because we don't want
11771207
# to recurse if we're in the cache.
11781208
try:
11791209
(result, hits) = cache[self.key]
11801210
except KeyError:
11811211
(value,) = self.children
1182-
result = value.evaluate(cache=cache, timer=timer)
1212+
result = value.evaluate(cache=cache, timer=timer, context=context)
11831213
cache[self.key] = (result, 0)
11841214
return result
11851215
else:
@@ -1249,6 +1279,8 @@ def do_evaluate(
12491279
schema: Schema,
12501280
df: Any,
12511281
projection: tuple[str, ...] | None,
1282+
*,
1283+
context: IRExecutionContext,
12521284
) -> DataFrame:
12531285
"""Evaluate and return a dataframe."""
12541286
if projection is not None:
@@ -1309,6 +1341,8 @@ def do_evaluate(
13091341
exprs: tuple[expr.NamedExpr, ...],
13101342
should_broadcast: bool, # noqa: FBT001
13111343
df: DataFrame,
1344+
*,
1345+
context: IRExecutionContext,
13121346
) -> DataFrame:
13131347
"""Evaluate and return a dataframe."""
13141348
# Handle any broadcasting
@@ -1317,7 +1351,9 @@ def do_evaluate(
13171351
columns = broadcast(*columns)
13181352
return DataFrame(columns, stream=df.stream)
13191353

1320-
def evaluate(self, *, cache: CSECache, timer: Timer | None) -> DataFrame:
1354+
def evaluate(
1355+
self, *, cache: CSECache, timer: Timer | None, context: IRExecutionContext
1356+
) -> DataFrame:
13211357
"""
13221358
Evaluate the Select node with special handling for fast count queries.
13231359
@@ -1329,6 +1365,8 @@ def evaluate(self, *, cache: CSECache, timer: Timer | None) -> DataFrame:
13291365
timer
13301366
If not None, a Timer object to record timings for the
13311367
evaluation of the node.
1368+
context
1369+
The execution context for the node.
13321370
13331371
Returns
13341372
-------
@@ -1364,7 +1402,7 @@ def evaluate(self, *, cache: CSECache, timer: Timer | None) -> DataFrame:
13641402
)
13651403
return DataFrame([col], stream=stream)
13661404

1367-
return super().evaluate(cache=cache, timer=timer)
1405+
return super().evaluate(cache=cache, timer=timer, context=context)
13681406

13691407

13701408
class Reduce(IR):
@@ -1394,6 +1432,8 @@ def do_evaluate(
13941432
cls,
13951433
exprs: tuple[expr.NamedExpr, ...],
13961434
df: DataFrame,
1435+
*,
1436+
context: IRExecutionContext,
13971437
) -> DataFrame: # pragma: no cover; not exposed by polars yet
13981438
"""Evaluate and return a dataframe."""
13991439
columns = broadcast(*(e.evaluate(df) for e in exprs))
@@ -1497,6 +1537,8 @@ def do_evaluate(
14971537
aggs: Sequence[expr.NamedExpr],
14981538
zlice: Zlice | None,
14991539
df: DataFrame,
1540+
*,
1541+
context: IRExecutionContext,
15001542
) -> DataFrame:
15011543
"""Evaluate and return a dataframe."""
15021544
keys = broadcast(*(k.evaluate(df) for k in keys_in), target_length=df.num_rows)
@@ -1627,6 +1669,8 @@ def do_evaluate(
16271669
maintain_order: bool, # noqa: FBT001
16281670
zlice: Zlice | None,
16291671
df: DataFrame,
1672+
*,
1673+
context: IRExecutionContext,
16301674
) -> DataFrame:
16311675
"""Evaluate and return a dataframe."""
16321676
keys = broadcast(*(k.evaluate(df) for k in keys_in), target_length=df.num_rows)
@@ -1904,6 +1948,8 @@ def do_evaluate(
19041948
options: tuple,
19051949
left: DataFrame,
19061950
right: DataFrame,
1951+
*,
1952+
context: IRExecutionContext,
19071953
) -> DataFrame:
19081954
"""Evaluate and return a dataframe."""
19091955
stream = get_joined_cuda_stream(upstreams=(left.stream, right.stream))
@@ -2187,6 +2233,8 @@ def do_evaluate(
21872233
],
21882234
left: DataFrame,
21892235
right: DataFrame,
2236+
*,
2237+
context: IRExecutionContext,
21902238
) -> DataFrame:
21912239
"""Evaluate and return a dataframe."""
21922240
stream = get_joined_cuda_stream(upstreams=(left.stream, right.stream))
@@ -2358,6 +2406,8 @@ def do_evaluate(
23582406
exprs: Sequence[expr.NamedExpr],
23592407
should_broadcast: bool, # noqa: FBT001
23602408
df: DataFrame,
2409+
*,
2410+
context: IRExecutionContext,
23612411
) -> DataFrame:
23622412
"""Evaluate and return a dataframe."""
23632413
columns = [c.evaluate(df) for c in exprs]
@@ -2426,6 +2476,8 @@ def do_evaluate(
24262476
zlice: Zlice | None,
24272477
stable: bool, # noqa: FBT001
24282478
df: DataFrame,
2479+
*,
2480+
context: IRExecutionContext,
24292481
) -> DataFrame:
24302482
"""Evaluate and return a dataframe."""
24312483
if subset is None:
@@ -2519,6 +2571,8 @@ def do_evaluate(
25192571
stable: bool, # noqa: FBT001
25202572
zlice: Zlice | None,
25212573
df: DataFrame,
2574+
*,
2575+
context: IRExecutionContext,
25222576
) -> DataFrame:
25232577
"""Evaluate and return a dataframe."""
25242578
sort_keys = broadcast(*(k.evaluate(df) for k in by), target_length=df.num_rows)
@@ -2565,7 +2619,9 @@ def __init__(self, schema: Schema, offset: int, length: int | None, df: IR):
25652619
@classmethod
25662620
@log_do_evaluate
25672621
@nvtx_annotate_cudf_polars(message="Slice")
2568-
def do_evaluate(cls, offset: int, length: int, df: DataFrame) -> DataFrame:
2622+
def do_evaluate(
2623+
cls, offset: int, length: int, df: DataFrame, *, context: IRExecutionContext
2624+
) -> DataFrame:
25692625
"""Evaluate and return a dataframe."""
25702626
return df.slice((offset, length))
25712627

@@ -2587,7 +2643,9 @@ def __init__(self, schema: Schema, mask: expr.NamedExpr, df: IR):
25872643
@classmethod
25882644
@log_do_evaluate
25892645
@nvtx_annotate_cudf_polars(message="Filter")
2590-
def do_evaluate(cls, mask_expr: expr.NamedExpr, df: DataFrame) -> DataFrame:
2646+
def do_evaluate(
2647+
cls, mask_expr: expr.NamedExpr, df: DataFrame, *, context: IRExecutionContext
2648+
) -> DataFrame:
25912649
"""Evaluate and return a dataframe."""
25922650
(mask,) = broadcast(mask_expr.evaluate(df), target_length=df.num_rows)
25932651
return df.filter(mask)
@@ -2607,7 +2665,9 @@ def __init__(self, schema: Schema, df: IR):
26072665
@classmethod
26082666
@log_do_evaluate
26092667
@nvtx_annotate_cudf_polars(message="Projection")
2610-
def do_evaluate(cls, schema: Schema, df: DataFrame) -> DataFrame:
2668+
def do_evaluate(
2669+
cls, schema: Schema, df: DataFrame, *, context: IRExecutionContext
2670+
) -> DataFrame:
26112671
"""Evaluate and return a dataframe."""
26122672
# This can reorder things.
26132673
columns = broadcast(
@@ -2641,7 +2701,9 @@ def __init__(self, schema: Schema, key: str, left: IR, right: IR):
26412701
@classmethod
26422702
@log_do_evaluate
26432703
@nvtx_annotate_cudf_polars(message="MergeSorted")
2644-
def do_evaluate(cls, key: str, *dfs: DataFrame) -> DataFrame:
2704+
def do_evaluate(
2705+
cls, key: str, *dfs: DataFrame, context: IRExecutionContext
2706+
) -> DataFrame:
26452707
"""Evaluate and return a dataframe."""
26462708
stream = get_joined_cuda_stream(upstreams=(df.stream for df in dfs))
26472709
left, right = dfs
@@ -2766,7 +2828,13 @@ def get_hashable(self) -> Hashable:
27662828
@log_do_evaluate
27672829
@nvtx_annotate_cudf_polars(message="MapFunction")
27682830
def do_evaluate(
2769-
cls, schema: Schema, name: str, options: Any, df: DataFrame
2831+
cls,
2832+
schema: Schema,
2833+
name: str,
2834+
options: Any,
2835+
df: DataFrame,
2836+
*,
2837+
context: IRExecutionContext,
27702838
) -> DataFrame:
27712839
"""Evaluate and return a dataframe."""
27722840
if name == "rechunk":
@@ -2872,7 +2940,9 @@ def __init__(self, schema: Schema, zlice: Zlice | None, *children: IR):
28722940
@classmethod
28732941
@log_do_evaluate
28742942
@nvtx_annotate_cudf_polars(message="Union")
2875-
def do_evaluate(cls, zlice: Zlice | None, *dfs: DataFrame) -> DataFrame:
2943+
def do_evaluate(
2944+
cls, zlice: Zlice | None, *dfs: DataFrame, context: IRExecutionContext
2945+
) -> DataFrame:
28762946
"""Evaluate and return a dataframe."""
28772947
stream = get_joined_cuda_stream(upstreams=(df.stream for df in dfs))
28782948

@@ -2942,6 +3012,7 @@ def do_evaluate(
29423012
cls,
29433013
should_broadcast: bool, # noqa: FBT001
29443014
*dfs: DataFrame,
3015+
context: IRExecutionContext,
29453016
) -> DataFrame:
29463017
"""Evaluate and return a dataframe."""
29473018
stream = get_joined_cuda_stream(upstreams=(df.stream for df in dfs))
@@ -2991,7 +3062,9 @@ def __init__(self, schema: Schema):
29913062
@classmethod
29923063
@log_do_evaluate
29933064
@nvtx_annotate_cudf_polars(message="Empty")
2994-
def do_evaluate(cls, schema: Schema) -> DataFrame: # pragma: no cover
3065+
def do_evaluate(
3066+
cls, schema: Schema, *, context: IRExecutionContext
3067+
) -> DataFrame: # pragma: no cover
29953068
"""Evaluate and return a dataframe."""
29963069
stream = get_cuda_stream()
29973070
return DataFrame(

python/cudf_polars/cudf_polars/dsl/tracing.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -166,9 +166,9 @@ def wrapper(
166166
log = structlog.get_logger()
167167

168168
# By convention, all non-dataframe arguments (non_child) come first.
169-
# Anything remaining is a dataframe.
169+
# Anything remaining is a dataframe, except for 'context' kwarg.
170170
frames: list[cudf_polars.containers.DataFrame] = (
171-
list(args) + list(kwargs.values())
171+
list(args) + [v for k, v in kwargs.items() if k != "context"]
172172
)[len(cls._non_child) :] # type: ignore[assignment]
173173

174174
before_start = time.monotonic_ns()

0 commit comments

Comments
 (0)