Skip to content

Commit 5a19105

Browse files
berkaysynnada authored and adriangb committed
minor: implement with_new_expressions for AggregateFunctionExpr (apache#16897)
* minor
* Update aggregate.rs
1 parent 55c08d5 commit 5a19105

File tree

9 files changed

+68
-19
lines changed

9 files changed

+68
-19
lines changed

datafusion/expr/src/window_state.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ use datafusion_common::{
3434
};
3535

3636
/// Holds the state of evaluating a window function
37-
#[derive(Debug)]
37+
#[derive(Debug, Clone)]
3838
pub struct WindowAggState {
3939
/// The range that we calculate the window function
4040
pub window_frame_range: Range<usize>,
@@ -112,7 +112,7 @@ impl WindowAggState {
112112
}
113113

114114
/// This object stores the window frame state for use in incremental calculations.
115-
#[derive(Debug)]
115+
#[derive(Debug, Clone)]
116116
pub enum WindowFrameContext {
117117
/// ROWS frames are inherently stateless.
118118
Rows(Arc<WindowFrame>),
@@ -240,7 +240,7 @@ impl WindowFrameContext {
240240
}
241241

242242
/// State for each unique partition determined according to PARTITION BY column(s)
243-
#[derive(Debug)]
243+
#[derive(Debug, Clone, PartialEq)]
244244
pub struct PartitionBatchState {
245245
/// The record batch belonging to current partition
246246
pub record_batch: RecordBatch,
@@ -282,7 +282,7 @@ impl PartitionBatchState {
282282
/// ranges of data while processing RANGE frames.
283283
/// Attribute `sort_options` stores the column ordering specified by the ORDER
284284
/// BY clause. This information is used to calculate the range.
285-
#[derive(Debug, Default)]
285+
#[derive(Debug, Default, Clone)]
286286
pub struct WindowFrameStateRange {
287287
sort_options: Vec<SortOptions>,
288288
}
@@ -454,7 +454,7 @@ impl WindowFrameStateRange {
454454

455455
/// This structure encapsulates all the state information we require as we
456456
/// scan groups of data while processing window frames.
457-
#[derive(Debug, Default)]
457+
#[derive(Debug, Default, Clone)]
458458
pub struct WindowFrameStateGroups {
459459
/// A tuple containing group values and the row index where the group ends.
460460
/// Example: [[1, 1], [1, 1], [2, 1], [2, 1], ...] would correspond to

datafusion/physical-expr/src/aggregate.rs

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -616,10 +616,42 @@ impl AggregateFunctionExpr {
616616
/// Returns `Some(AggregateFunctionExpr)` if re-write is supported, otherwise returns `None`.
617617
pub fn with_new_expressions(
618618
&self,
619-
_args: Vec<Arc<dyn PhysicalExpr>>,
620-
_order_by_exprs: Vec<Arc<dyn PhysicalExpr>>,
619+
args: Vec<Arc<dyn PhysicalExpr>>,
620+
order_by_exprs: Vec<Arc<dyn PhysicalExpr>>,
621621
) -> Option<AggregateFunctionExpr> {
622-
None
622+
if args.len() != self.args.len()
623+
|| (self.order_sensitivity() != AggregateOrderSensitivity::Insensitive
624+
&& order_by_exprs.len() != self.order_bys.len())
625+
{
626+
return None;
627+
}
628+
629+
let new_order_bys = self
630+
.order_bys
631+
.iter()
632+
.zip(order_by_exprs)
633+
.map(|(req, new_expr)| PhysicalSortExpr {
634+
expr: new_expr,
635+
options: req.options,
636+
})
637+
.collect();
638+
639+
Some(AggregateFunctionExpr {
640+
fun: self.fun.clone(),
641+
args,
642+
return_field: Arc::clone(&self.return_field),
643+
name: self.name.clone(),
644+
// TODO: Human name should be updated after re-write to not mislead
645+
human_display: self.human_display.clone(),
646+
schema: self.schema.clone(),
647+
order_bys: new_order_bys,
648+
ignore_nulls: self.ignore_nulls,
649+
ordering_fields: self.ordering_fields.clone(),
650+
is_distinct: self.is_distinct,
651+
is_reversed: false,
652+
input_fields: self.input_fields.clone(),
653+
is_nullable: self.is_nullable,
654+
})
623655
}
624656

625657
/// If this function is max, return (output_field, true)

datafusion/physical-expr/src/expressions/literal.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ use datafusion_expr_common::interval_arithmetic::Interval;
3636
use datafusion_expr_common::sort_properties::{ExprProperties, SortProperties};
3737

3838
/// Represents a literal value
39-
#[derive(Debug, PartialEq, Eq)]
39+
#[derive(Debug, PartialEq, Eq, Clone)]
4040
pub struct Literal {
4141
value: ScalarValue,
4242
field: FieldRef,

datafusion/physical-expr/src/window/aggregate.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ use std::sync::Arc;
2323

2424
use crate::aggregate::AggregateFunctionExpr;
2525
use crate::window::standard::add_new_ordering_expr_with_partition_by;
26-
use crate::window::window_expr::AggregateWindowExpr;
26+
use crate::window::window_expr::{AggregateWindowExpr, WindowFn};
2727
use crate::window::{
2828
PartitionBatches, PartitionWindowAggStates, SlidingAggregateWindowExpr, WindowExpr,
2929
};
@@ -211,6 +211,10 @@ impl WindowExpr for PlainAggregateWindowExpr {
211211
fn uses_bounded_memory(&self) -> bool {
212212
!self.window_frame.end_bound.is_unbounded()
213213
}
214+
215+
fn create_window_fn(&self) -> Result<WindowFn> {
216+
Ok(WindowFn::Aggregate(self.get_accumulator()?))
217+
}
214218
}
215219

216220
impl AggregateWindowExpr for PlainAggregateWindowExpr {

datafusion/physical-expr/src/window/sliding_aggregate.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ use std::ops::Range;
2222
use std::sync::Arc;
2323

2424
use crate::aggregate::AggregateFunctionExpr;
25-
use crate::window::window_expr::AggregateWindowExpr;
25+
use crate::window::window_expr::{AggregateWindowExpr, WindowFn};
2626
use crate::window::{
2727
PartitionBatches, PartitionWindowAggStates, PlainAggregateWindowExpr, WindowExpr,
2828
};
@@ -175,6 +175,10 @@ impl WindowExpr for SlidingAggregateWindowExpr {
175175
window_frame: Arc::clone(&self.window_frame),
176176
}))
177177
}
178+
179+
fn create_window_fn(&self) -> Result<WindowFn> {
180+
Ok(WindowFn::Aggregate(self.get_accumulator()?))
181+
}
178182
}
179183

180184
impl AggregateWindowExpr for SlidingAggregateWindowExpr {

datafusion/physical-expr/src/window/standard.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,10 @@ impl WindowExpr for StandardWindowExpr {
275275
false
276276
}
277277
}
278+
279+
fn create_window_fn(&self) -> Result<WindowFn> {
280+
Ok(WindowFn::Builtin(self.expr.create_evaluator()?))
281+
}
278282
}
279283

280284
/// Adds a new ordering expression into existing ordering equivalence class(es) based on

datafusion/physical-expr/src/window/window_expr.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,12 @@ pub trait WindowExpr: Send + Sync + Debug {
130130
/// Get the reverse expression of this [WindowExpr].
131131
fn get_reverse_expr(&self) -> Option<Arc<dyn WindowExpr>>;
132132

133+
/// Creates a new instance of the window function evaluator.
134+
///
135+
/// Returns `WindowFn::Builtin` for built-in window functions (e.g., ROW_NUMBER, RANK)
136+
/// or `WindowFn::Aggregate` for aggregate window functions (e.g., SUM, AVG).
137+
fn create_window_fn(&self) -> Result<WindowFn>;
138+
133139
/// Returns all expressions used in the [`WindowExpr`].
134140
/// These expressions are (1) function arguments, (2) partition by expressions, (3) order by expressions.
135141
fn all_expressions(&self) -> WindowPhysicalExpressions {

datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes.rs

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,14 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
use std::mem::size_of;
19+
1820
use crate::aggregates::group_values::GroupValues;
21+
1922
use arrow::array::{Array, ArrayRef, OffsetSizeTrait, RecordBatch};
23+
use datafusion_common::Result;
2024
use datafusion_expr::EmitTo;
2125
use datafusion_physical_expr_common::binary_map::{ArrowBytesMap, OutputType};
22-
use std::mem::size_of;
2326

2427
/// A [`GroupValues`] storing single column of Utf8/LargeUtf8/Binary/LargeBinary values
2528
///
@@ -42,11 +45,7 @@ impl<O: OffsetSizeTrait> GroupValuesByes<O> {
4245
}
4346

4447
impl<O: OffsetSizeTrait> GroupValues for GroupValuesByes<O> {
45-
fn intern(
46-
&mut self,
47-
cols: &[ArrayRef],
48-
groups: &mut Vec<usize>,
49-
) -> datafusion_common::Result<()> {
48+
fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec<usize>) -> Result<()> {
5049
assert_eq!(cols.len(), 1);
5150

5251
// look up / add entries in the table
@@ -85,7 +84,7 @@ impl<O: OffsetSizeTrait> GroupValues for GroupValuesByes<O> {
8584
self.num_groups
8685
}
8786

88-
fn emit(&mut self, emit_to: EmitTo) -> datafusion_common::Result<Vec<ArrayRef>> {
87+
fn emit(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
8988
// Reset the map to default, and convert it into a single array
9089
let map_contents = self.map.take().into_state();
9190

datafusion/physical-plan/src/test.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ impl ExecutionPlan for TestMemoryExec {
131131
}
132132

133133
fn as_any(&self) -> &dyn Any {
134-
unimplemented!()
134+
self
135135
}
136136

137137
fn properties(&self) -> &PlanProperties {

0 commit comments

Comments
 (0)