1717import json
1818import random
1919import time
20+ from dataclasses import dataclass
2021from functools import cache
2122from pathlib import Path
2223from typing import TYPE_CHECKING , Any , ClassVar , overload
7273 "GroupBy" ,
7374 "HConcat" ,
7475 "HStack" ,
76+ "IRExecutionContext" ,
7577 "Join" ,
7678 "MapFunction" ,
7779 "MergeSorted" ,
8890]
8991
9092
93+ @dataclass (frozen = True )
94+ class IRExecutionContext :
95+ """
96+ Runtime context for IR node execution.
97+
98+ This dataclass holds runtime information and configuration needed
99+ during the evaluation of IR nodes.
100+ """
101+
102+
91103_BINOPS = {
92104 plc .binaryop .BinaryOperator .EQUAL ,
93105 plc .binaryop .BinaryOperator .NOT_EQUAL ,
@@ -158,7 +170,9 @@ def get_hashable(self) -> Hashable:
158170 translation phase should fail earlier.
159171 """
160172
161- def evaluate (self , * , cache : CSECache , timer : Timer | None ) -> DataFrame :
173+ def evaluate (
174+ self , * , cache : CSECache , timer : Timer | None , context : IRExecutionContext
175+ ) -> DataFrame :
162176 """
163177 Evaluate the node (recursively) and return a dataframe.
164178
@@ -170,6 +184,8 @@ def evaluate(self, *, cache: CSECache, timer: Timer | None) -> DataFrame:
170184 timer
171185 If not None, a Timer object to record timings for the
172186 evaluation of the node.
187+ context
188+ The execution context for the node.
173189
174190 Notes
175191 -----
@@ -188,16 +204,19 @@ def evaluate(self, *, cache: CSECache, timer: Timer | None) -> DataFrame:
188204 If evaluation fails. Ideally this should not occur, since the
189205 translation phase should fail earlier.
190206 """
191- children = [child .evaluate (cache = cache , timer = timer ) for child in self .children ]
207+ children = [
208+ child .evaluate (cache = cache , timer = timer , context = context )
209+ for child in self .children
210+ ]
192211 if timer is not None :
193212 start = time .monotonic_ns ()
194- result = self .do_evaluate (* self ._non_child_args , * children )
213+ result = self .do_evaluate (* self ._non_child_args , * children , context = context )
195214 end = time .monotonic_ns ()
196215 # TODO: Set better names on each class object.
197216 timer .store (start , end , type (self ).__name__ )
198217 return result
199218 else :
200- return self .do_evaluate (* self ._non_child_args , * children )
219+ return self .do_evaluate (* self ._non_child_args , * children , context = context )
201220
202221
203222class ErrorNode (IR ):
@@ -587,6 +606,8 @@ def do_evaluate(
587606 include_file_paths : str | None ,
588607 predicate : expr .NamedExpr | None ,
589608 parquet_options : ParquetOptions ,
609+ * ,
610+ context : IRExecutionContext ,
590611 ) -> DataFrame :
591612 """Evaluate and return a dataframe."""
592613 stream = get_cuda_stream ()
@@ -1111,6 +1132,8 @@ def do_evaluate(
11111132 parquet_options : ParquetOptions ,
11121133 options : dict [str , Any ],
11131134 df : DataFrame ,
1135+ * ,
1136+ context : IRExecutionContext ,
11141137 ) -> DataFrame :
11151138 """Write the dataframe to a file."""
11161139 target = plc .io .SinkInfo ([path ])
@@ -1164,22 +1187,29 @@ def is_equal(self, other: Self) -> bool: # noqa: D102
11641187 @log_do_evaluate
11651188 @nvtx_annotate_cudf_polars (message = "Cache" )
11661189 def do_evaluate (
1167- cls , key : int , refcount : int | None , df : DataFrame
1190+ cls ,
1191+ key : int ,
1192+ refcount : int | None ,
1193+ df : DataFrame ,
1194+ * ,
1195+ context : IRExecutionContext ,
11681196 ) -> DataFrame : # pragma: no cover; basic evaluation never calls this
11691197 """Evaluate and return a dataframe."""
11701198 # Our value has already been computed for us, so let's just
11711199 # return it.
11721200 return df
11731201
1174- def evaluate (self , * , cache : CSECache , timer : Timer | None ) -> DataFrame :
1202+ def evaluate (
1203+ self , * , cache : CSECache , timer : Timer | None , context : IRExecutionContext
1204+ ) -> DataFrame :
11751205 """Evaluate and return a dataframe."""
11761206 # We must override the recursion scheme because we don't want
11771207 # to recurse if we're in the cache.
11781208 try :
11791209 (result , hits ) = cache [self .key ]
11801210 except KeyError :
11811211 (value ,) = self .children
1182- result = value .evaluate (cache = cache , timer = timer )
1212+ result = value .evaluate (cache = cache , timer = timer , context = context )
11831213 cache [self .key ] = (result , 0 )
11841214 return result
11851215 else :
@@ -1249,6 +1279,8 @@ def do_evaluate(
12491279 schema : Schema ,
12501280 df : Any ,
12511281 projection : tuple [str , ...] | None ,
1282+ * ,
1283+ context : IRExecutionContext ,
12521284 ) -> DataFrame :
12531285 """Evaluate and return a dataframe."""
12541286 if projection is not None :
@@ -1309,6 +1341,8 @@ def do_evaluate(
13091341 exprs : tuple [expr .NamedExpr , ...],
13101342 should_broadcast : bool , # noqa: FBT001
13111343 df : DataFrame ,
1344+ * ,
1345+ context : IRExecutionContext ,
13121346 ) -> DataFrame :
13131347 """Evaluate and return a dataframe."""
13141348 # Handle any broadcasting
@@ -1317,7 +1351,9 @@ def do_evaluate(
13171351 columns = broadcast (* columns )
13181352 return DataFrame (columns , stream = df .stream )
13191353
1320- def evaluate (self , * , cache : CSECache , timer : Timer | None ) -> DataFrame :
1354+ def evaluate (
1355+ self , * , cache : CSECache , timer : Timer | None , context : IRExecutionContext
1356+ ) -> DataFrame :
13211357 """
13221358 Evaluate the Select node with special handling for fast count queries.
13231359
@@ -1329,6 +1365,8 @@ def evaluate(self, *, cache: CSECache, timer: Timer | None) -> DataFrame:
13291365 timer
13301366 If not None, a Timer object to record timings for the
13311367 evaluation of the node.
1368+ context
1369+ The execution context for the node.
13321370
13331371 Returns
13341372 -------
@@ -1364,7 +1402,7 @@ def evaluate(self, *, cache: CSECache, timer: Timer | None) -> DataFrame:
13641402 )
13651403 return DataFrame ([col ], stream = stream )
13661404
1367- return super ().evaluate (cache = cache , timer = timer )
1405+ return super ().evaluate (cache = cache , timer = timer , context = context )
13681406
13691407
13701408class Reduce (IR ):
@@ -1394,6 +1432,8 @@ def do_evaluate(
13941432 cls ,
13951433 exprs : tuple [expr .NamedExpr , ...],
13961434 df : DataFrame ,
1435+ * ,
1436+ context : IRExecutionContext ,
13971437 ) -> DataFrame : # pragma: no cover; not exposed by polars yet
13981438 """Evaluate and return a dataframe."""
13991439 columns = broadcast (* (e .evaluate (df ) for e in exprs ))
@@ -1497,6 +1537,8 @@ def do_evaluate(
14971537 aggs : Sequence [expr .NamedExpr ],
14981538 zlice : Zlice | None ,
14991539 df : DataFrame ,
1540+ * ,
1541+ context : IRExecutionContext ,
15001542 ) -> DataFrame :
15011543 """Evaluate and return a dataframe."""
15021544 keys = broadcast (* (k .evaluate (df ) for k in keys_in ), target_length = df .num_rows )
@@ -1627,6 +1669,8 @@ def do_evaluate(
16271669 maintain_order : bool , # noqa: FBT001
16281670 zlice : Zlice | None ,
16291671 df : DataFrame ,
1672+ * ,
1673+ context : IRExecutionContext ,
16301674 ) -> DataFrame :
16311675 """Evaluate and return a dataframe."""
16321676 keys = broadcast (* (k .evaluate (df ) for k in keys_in ), target_length = df .num_rows )
@@ -1904,6 +1948,8 @@ def do_evaluate(
19041948 options : tuple ,
19051949 left : DataFrame ,
19061950 right : DataFrame ,
1951+ * ,
1952+ context : IRExecutionContext ,
19071953 ) -> DataFrame :
19081954 """Evaluate and return a dataframe."""
19091955 stream = get_joined_cuda_stream (upstreams = (left .stream , right .stream ))
@@ -2187,6 +2233,8 @@ def do_evaluate(
21872233 ],
21882234 left : DataFrame ,
21892235 right : DataFrame ,
2236+ * ,
2237+ context : IRExecutionContext ,
21902238 ) -> DataFrame :
21912239 """Evaluate and return a dataframe."""
21922240 stream = get_joined_cuda_stream (upstreams = (left .stream , right .stream ))
@@ -2358,6 +2406,8 @@ def do_evaluate(
23582406 exprs : Sequence [expr .NamedExpr ],
23592407 should_broadcast : bool , # noqa: FBT001
23602408 df : DataFrame ,
2409+ * ,
2410+ context : IRExecutionContext ,
23612411 ) -> DataFrame :
23622412 """Evaluate and return a dataframe."""
23632413 columns = [c .evaluate (df ) for c in exprs ]
@@ -2426,6 +2476,8 @@ def do_evaluate(
24262476 zlice : Zlice | None ,
24272477 stable : bool , # noqa: FBT001
24282478 df : DataFrame ,
2479+ * ,
2480+ context : IRExecutionContext ,
24292481 ) -> DataFrame :
24302482 """Evaluate and return a dataframe."""
24312483 if subset is None :
@@ -2519,6 +2571,8 @@ def do_evaluate(
25192571 stable : bool , # noqa: FBT001
25202572 zlice : Zlice | None ,
25212573 df : DataFrame ,
2574+ * ,
2575+ context : IRExecutionContext ,
25222576 ) -> DataFrame :
25232577 """Evaluate and return a dataframe."""
25242578 sort_keys = broadcast (* (k .evaluate (df ) for k in by ), target_length = df .num_rows )
@@ -2565,7 +2619,9 @@ def __init__(self, schema: Schema, offset: int, length: int | None, df: IR):
25652619 @classmethod
25662620 @log_do_evaluate
25672621 @nvtx_annotate_cudf_polars (message = "Slice" )
2568- def do_evaluate (cls , offset : int , length : int , df : DataFrame ) -> DataFrame :
2622+ def do_evaluate (
2623+ cls , offset : int , length : int , df : DataFrame , * , context : IRExecutionContext
2624+ ) -> DataFrame :
25692625 """Evaluate and return a dataframe."""
25702626 return df .slice ((offset , length ))
25712627
@@ -2587,7 +2643,9 @@ def __init__(self, schema: Schema, mask: expr.NamedExpr, df: IR):
25872643 @classmethod
25882644 @log_do_evaluate
25892645 @nvtx_annotate_cudf_polars (message = "Filter" )
2590- def do_evaluate (cls , mask_expr : expr .NamedExpr , df : DataFrame ) -> DataFrame :
2646+ def do_evaluate (
2647+ cls , mask_expr : expr .NamedExpr , df : DataFrame , * , context : IRExecutionContext
2648+ ) -> DataFrame :
25912649 """Evaluate and return a dataframe."""
25922650 (mask ,) = broadcast (mask_expr .evaluate (df ), target_length = df .num_rows )
25932651 return df .filter (mask )
@@ -2607,7 +2665,9 @@ def __init__(self, schema: Schema, df: IR):
26072665 @classmethod
26082666 @log_do_evaluate
26092667 @nvtx_annotate_cudf_polars (message = "Projection" )
2610- def do_evaluate (cls , schema : Schema , df : DataFrame ) -> DataFrame :
2668+ def do_evaluate (
2669+ cls , schema : Schema , df : DataFrame , * , context : IRExecutionContext
2670+ ) -> DataFrame :
26112671 """Evaluate and return a dataframe."""
26122672 # This can reorder things.
26132673 columns = broadcast (
@@ -2641,7 +2701,9 @@ def __init__(self, schema: Schema, key: str, left: IR, right: IR):
26412701 @classmethod
26422702 @log_do_evaluate
26432703 @nvtx_annotate_cudf_polars (message = "MergeSorted" )
2644- def do_evaluate (cls , key : str , * dfs : DataFrame ) -> DataFrame :
2704+ def do_evaluate (
2705+ cls , key : str , * dfs : DataFrame , context : IRExecutionContext
2706+ ) -> DataFrame :
26452707 """Evaluate and return a dataframe."""
26462708 stream = get_joined_cuda_stream (upstreams = (df .stream for df in dfs ))
26472709 left , right = dfs
@@ -2766,7 +2828,13 @@ def get_hashable(self) -> Hashable:
27662828 @log_do_evaluate
27672829 @nvtx_annotate_cudf_polars (message = "MapFunction" )
27682830 def do_evaluate (
2769- cls , schema : Schema , name : str , options : Any , df : DataFrame
2831+ cls ,
2832+ schema : Schema ,
2833+ name : str ,
2834+ options : Any ,
2835+ df : DataFrame ,
2836+ * ,
2837+ context : IRExecutionContext ,
27702838 ) -> DataFrame :
27712839 """Evaluate and return a dataframe."""
27722840 if name == "rechunk" :
@@ -2872,7 +2940,9 @@ def __init__(self, schema: Schema, zlice: Zlice | None, *children: IR):
28722940 @classmethod
28732941 @log_do_evaluate
28742942 @nvtx_annotate_cudf_polars (message = "Union" )
2875- def do_evaluate (cls , zlice : Zlice | None , * dfs : DataFrame ) -> DataFrame :
2943+ def do_evaluate (
2944+ cls , zlice : Zlice | None , * dfs : DataFrame , context : IRExecutionContext
2945+ ) -> DataFrame :
28762946 """Evaluate and return a dataframe."""
28772947 stream = get_joined_cuda_stream (upstreams = (df .stream for df in dfs ))
28782948
@@ -2942,6 +3012,7 @@ def do_evaluate(
29423012 cls ,
29433013 should_broadcast : bool , # noqa: FBT001
29443014 * dfs : DataFrame ,
3015+ context : IRExecutionContext ,
29453016 ) -> DataFrame :
29463017 """Evaluate and return a dataframe."""
29473018 stream = get_joined_cuda_stream (upstreams = (df .stream for df in dfs ))
@@ -2991,7 +3062,9 @@ def __init__(self, schema: Schema):
29913062 @classmethod
29923063 @log_do_evaluate
29933064 @nvtx_annotate_cudf_polars (message = "Empty" )
2994- def do_evaluate (cls , schema : Schema ) -> DataFrame : # pragma: no cover
3065+ def do_evaluate (
3066+ cls , schema : Schema , * , context : IRExecutionContext
3067+ ) -> DataFrame : # pragma: no cover
29953068 """Evaluate and return a dataframe."""
29963069 stream = get_cuda_stream ()
29973070 return DataFrame (
0 commit comments