diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index d09393fc90..f3815d2eae 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -64,7 +64,7 @@ use datafusion::{ use datafusion_comet_spark_expr::{ create_comet_physical_fun, create_comet_physical_fun_with_eval_mode, create_modulo_expr, create_negate_expr, BinaryOutputStyle, BloomFilterAgg, BloomFilterMightContain, EvalMode, - SparkHour, SparkMinute, SparkSecond, + SparkHour, SparkMinute, SparkSecond, SumInteger, }; use iceberg::expr::Bind; @@ -2027,6 +2027,12 @@ impl PhysicalPlanner { AggregateUDF::new_from_impl(SumDecimal::try_new(datatype, eval_mode)?); AggregateExprBuilder::new(Arc::new(func), vec![child]) } + DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => { + let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?; + let func = + AggregateUDF::new_from_impl(SumInteger::try_new(datatype, eval_mode)?); + AggregateExprBuilder::new(Arc::new(func), vec![child]) + } _ => { // cast to the result data type of SUM if necessary, we should not expect // a cast failure since it should have already been checked at Spark side diff --git a/native/spark-expr/src/agg_funcs/mod.rs b/native/spark-expr/src/agg_funcs/mod.rs index 252da78890..b1027153e8 100644 --- a/native/spark-expr/src/agg_funcs/mod.rs +++ b/native/spark-expr/src/agg_funcs/mod.rs @@ -21,6 +21,7 @@ mod correlation; mod covariance; mod stddev; mod sum_decimal; +mod sum_int; mod variance; pub use avg::Avg; @@ -29,4 +30,5 @@ pub use correlation::Correlation; pub use covariance::Covariance; pub use stddev::Stddev; pub use sum_decimal::SumDecimal; +pub use sum_int::SumInteger; pub use variance::Variance; diff --git a/native/spark-expr/src/agg_funcs/sum_int.rs b/native/spark-expr/src/agg_funcs/sum_int.rs new file mode 100644 index 0000000000..af56c55fdd --- /dev/null +++ b/native/spark-expr/src/agg_funcs/sum_int.rs @@ -0,0 +1,564 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
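+//! Spark-compatible SUM aggregation over integral inputs (Int8/Int16/Int32/Int64),
+//! accumulating into Int64. Overflow handling follows the eval mode: Legacy wraps,
+//! Ansi raises an arithmetic overflow error, and Try turns the result into NULL.
+//! In Try mode an extra `has_all_nulls` state field distinguishes an all-NULL input
+//! (NULL result) from an overflowed sum (also a NULL result).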
+
+use crate::{arithmetic_overflow_error, EvalMode};
+use arrow::array::{
+    as_primitive_array, cast::AsArray, Array, ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType,
+    BooleanArray, Int64Array, PrimitiveArray,
+};
+use arrow::datatypes::{
+    ArrowNativeType, DataType, Field, FieldRef, Int16Type, Int32Type, Int64Type, Int8Type,
+};
+use datafusion::common::{DataFusionError, Result as DFResult, ScalarValue};
+use datafusion::logical_expr::function::{AccumulatorArgs, StateFieldsArgs};
+use datafusion::logical_expr::Volatility::Immutable;
+use datafusion::logical_expr::{
+    Accumulator, AggregateUDFImpl, EmitTo, GroupsAccumulator, ReversedUDAF, Signature,
+};
+use std::{any::Any, sync::Arc};
+
+#[derive(Debug, PartialEq, Eq, Hash)]
+pub struct SumInteger {
+    signature: Signature,
+    eval_mode: EvalMode,
+}
+
+impl SumInteger {
+    pub fn try_new(data_type: DataType, eval_mode: EvalMode) -> DFResult<Self> {
+        match data_type {
+            DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => Ok(Self {
+                signature: Signature::user_defined(Immutable),
+                eval_mode,
+            }),
+            _ => Err(DataFusionError::Internal(
+                "Invalid data type for SumInteger".into(),
+            )),
+        }
+    }
+}
+
+impl AggregateUDFImpl for SumInteger {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "sum"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
+        Ok(DataType::Int64)
+    }
+
+    fn accumulator(&self, _acc_args: AccumulatorArgs) -> DFResult<Box<dyn Accumulator>> {
+        Ok(Box::new(SumIntegerAccumulator::new(self.eval_mode)))
+    }
+
+    fn state_fields(&self, _args: StateFieldsArgs) -> DFResult<Vec<FieldRef>> {
+        if self.eval_mode == EvalMode::Try {
+            Ok(vec![
+                Arc::new(Field::new("sum", DataType::Int64, true)),
+                Arc::new(Field::new("has_all_nulls", DataType::Boolean, false)),
+            ])
+        } else {
+            Ok(vec![Arc::new(Field::new("sum", DataType::Int64, true))])
+        }
+    }
+
+    fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
+        true
+    }
+
+    fn create_groups_accumulator(
+        &self,
+        _args: AccumulatorArgs,
+    ) -> DFResult<Box<dyn GroupsAccumulator>> {
+        Ok(Box::new(SumIntGroupsAccumulator::new(self.eval_mode)))
+    }
+
+    fn reverse_expr(&self) -> ReversedUDAF {
+        ReversedUDAF::Identical
+    }
+}
+
+#[derive(Debug)]
+struct SumIntegerAccumulator {
+    sum: Option<i64>,
+    eval_mode: EvalMode,
+    has_all_nulls: bool,
+}
+
+impl SumIntegerAccumulator {
+    fn new(eval_mode: EvalMode) -> Self {
+        if eval_mode == EvalMode::Try {
+            Self {
+                // Try mode starts at 0: if the sum were initialized to None we could not
+                // tell whether None means "all inputs were null" or "an overflow occurred".
+                sum: Some(0),
+                has_all_nulls: true,
+                eval_mode,
+            }
+        } else {
+            Self {
+                sum: None,
+                has_all_nulls: false,
+                eval_mode,
+            }
+        }
+    }
+}
+
+impl Accumulator for SumIntegerAccumulator {
+    fn update_batch(&mut self, values: &[ArrayRef]) -> DFResult<()> {
+        // Internal helper that adds the batch values to `sum`; in Try eval mode it returns
+        // a null sum (with `has_all_nulls` left false) when the addition overflows.
+        fn update_sum_internal<T>(
+            int_array: &PrimitiveArray<T>,
+            eval_mode: EvalMode,
+            mut sum: i64,
+        ) -> Result<Option<i64>, DataFusionError>
+        where
+            T: ArrowPrimitiveType,
+        {
+            for i in 0..int_array.len() {
+                if !int_array.is_null(i) {
+                    let v = int_array.value(i).to_i64().ok_or_else(|| {
+                        DataFusionError::Internal(format!(
+                            "Failed to convert value {:?} to i64",
+                            int_array.value(i)
+                        ))
+                    })?;
+                    match eval_mode {
+                        EvalMode::Legacy => {
+                            sum = v.add_wrapping(sum);
+                        }
+                        EvalMode::Ansi | EvalMode::Try => {
+                            match v.add_checked(sum) {
+                                Ok(v) => sum = v,
+                                Err(_e) => {
+                                    return if eval_mode == EvalMode::Ansi {
+                                        Err(DataFusionError::from(arithmetic_overflow_error(
+                                            "integer",
+                                        )))
+                                    } else {
+                                        Ok(None)
+                                    };
+                                }
+                            };
+                        }
+                    }
+                }
+            }
+            Ok(Some(sum))
+        }
+
+        if self.eval_mode == EvalMode::Try && !self.has_all_nulls && self.sum.is_none() {
+            // We saw an overflow earlier (Try eval mode). Skip processing.
+            return Ok(());
+        }
+        let values = &values[0];
+        if values.len() == values.null_count() {
+            Ok(())
+        } else {
+            // At least one value is non-null, so the result is a non-null sum, or a null
+            // sum if the addition overflows in Try eval mode.
+            let running_sum = self.sum.unwrap_or(0);
+            let sum = match values.data_type() {
+                DataType::Int64 => update_sum_internal(
+                    as_primitive_array::<Int64Type>(values),
+                    self.eval_mode,
+                    running_sum,
+                )?,
+                DataType::Int32 => update_sum_internal(
+                    as_primitive_array::<Int32Type>(values),
+                    self.eval_mode,
+                    running_sum,
+                )?,
+                DataType::Int16 => update_sum_internal(
+                    as_primitive_array::<Int16Type>(values),
+                    self.eval_mode,
+                    running_sum,
+                )?,
+                DataType::Int8 => update_sum_internal(
+                    as_primitive_array::<Int8Type>(values),
+                    self.eval_mode,
+                    running_sum,
+                )?,
+                _ => {
+                    return Err(DataFusionError::Internal(format!(
+                        "unsupported data type: {:?}",
+                        values.data_type()
+                    )));
+                }
+            };
+            self.sum = sum;
+            self.has_all_nulls = false;
+            Ok(())
+        }
+    }
+
+    fn evaluate(&mut self) -> DFResult<ScalarValue> {
+        if self.has_all_nulls {
+            Ok(ScalarValue::Int64(None))
+        } else {
+            Ok(ScalarValue::Int64(self.sum))
+        }
+    }
+
+    fn size(&self) -> usize {
+        std::mem::size_of_val(self)
+    }
+
+    fn state(&mut self) -> DFResult<Vec<ScalarValue>> {
+        if self.eval_mode == EvalMode::Try {
+            Ok(vec![
+                ScalarValue::Int64(self.sum),
+                ScalarValue::Boolean(Some(self.has_all_nulls)),
+            ])
+        } else {
+            Ok(vec![ScalarValue::Int64(self.sum)])
+        }
+    }
+
+    fn merge_batch(&mut self, states: &[ArrayRef]) -> DFResult<()> {
+        let that_sum_array = states[0].as_primitive::<Int64Type>();
+        let that_sum = if that_sum_array.is_null(0) {
+            None
+        } else {
+            Some(that_sum_array.value(0))
+        };
+
+        // Check for overflow for early termination
+        if self.eval_mode == EvalMode::Try {
+            let that_has_all_nulls = states[1].as_boolean().value(0);
+            let that_overflowed = !that_has_all_nulls && that_sum.is_none();
+            let this_overflowed = !self.has_all_nulls && self.sum.is_none();
+            if that_overflowed || this_overflowed {
+                self.sum = None;
+                self.has_all_nulls = false;
+                return Ok(());
+            }
+            if that_has_all_nulls {
+                return Ok(());
+            }
+            if self.has_all_nulls {
+                self.sum = that_sum;
+                self.has_all_nulls = false;
+                return Ok(());
+            }
+        } else {
+            if that_sum.is_none() {
+                return Ok(());
+            }
+            if self.sum.is_none() {
+                self.sum = that_sum;
+                return Ok(());
+            }
+        }
+
+        // Safe to unwrap (nulls were checked above), but handle the error in case the
+        // state is corrupt.
+        let left = self.sum.ok_or_else(|| {
+            DataFusionError::Internal(
+                "Invalid state in merging batch. Current batch's sum is None".to_string(),
+            )
+        })?;
+        let right = that_sum.ok_or_else(|| {
+            DataFusionError::Internal(
+                "Invalid state in merging batch. Incoming sum is None".to_string(),
+            )
+        })?;
+
+        match self.eval_mode {
+            EvalMode::Legacy => {
+                self.sum = Some(left.add_wrapping(right));
+            }
+            EvalMode::Ansi | EvalMode::Try => match left.add_checked(right) {
+                Ok(v) => self.sum = Some(v),
+                Err(_) => {
+                    if self.eval_mode == EvalMode::Ansi {
+                        return Err(DataFusionError::from(arithmetic_overflow_error("integer")));
+                    } else {
+                        self.sum = None;
+                        self.has_all_nulls = false;
+                    }
+                }
+            },
+        }
+        Ok(())
+    }
+}
+
+struct SumIntGroupsAccumulator {
+    sums: Vec<Option<i64>>,
+    has_all_nulls: Vec<bool>,
+    eval_mode: EvalMode,
+}
+
+impl SumIntGroupsAccumulator {
+    fn new(eval_mode: EvalMode) -> Self {
+        Self {
+            sums: Vec::new(),
+            eval_mode,
+            has_all_nulls: Vec::new(),
+        }
+    }
+
+    fn resize_helper(&mut self, total_num_groups: usize) {
+        if self.eval_mode == EvalMode::Try {
+            self.sums.resize(total_num_groups, Some(0));
+            self.has_all_nulls.resize(total_num_groups, true);
+        } else {
+            self.sums.resize(total_num_groups, None);
+            self.has_all_nulls.resize(total_num_groups, false);
+        }
+    }
+}
+
+impl GroupsAccumulator for SumIntGroupsAccumulator {
+    fn update_batch(
+        &mut self,
+        values: &[ArrayRef],
+        group_indices: &[usize],
+        opt_filter: Option<&BooleanArray>,
+        total_num_groups: usize,
+    ) -> DFResult<()> {
+        fn update_groups_sum_internal<T>(
+            int_array: &PrimitiveArray<T>,
+            group_indices: &[usize],
+            sums: &mut [Option<i64>],
+            has_all_nulls: &mut [bool],
+            eval_mode: EvalMode,
+        ) -> DFResult<()>
+        where
+            T: ArrowPrimitiveType,
+            T::Native: ArrowNativeType,
+        {
+            for (i, &group_index) in group_indices.iter().enumerate() {
+                if !int_array.is_null(i) {
+                    // This group already overflowed in Try eval mode. Skip processing.
+                    if eval_mode == EvalMode::Try
+                        && !has_all_nulls[group_index]
+                        && sums[group_index].is_none()
+                    {
+                        continue;
+                    }
+                    let v = int_array.value(i).to_i64().ok_or_else(|| {
+                        DataFusionError::Internal("Failed to convert value to i64".to_string())
+                    })?;
+                    match eval_mode {
+                        EvalMode::Legacy => {
+                            sums[group_index] =
+                                Some(sums[group_index].unwrap_or(0).add_wrapping(v));
+                        }
+                        EvalMode::Ansi | EvalMode::Try => {
+                            match sums[group_index].unwrap_or(0).add_checked(v) {
+                                Ok(new_sum) => {
+                                    sums[group_index] = Some(new_sum);
+                                }
+                                Err(_) => {
+                                    if eval_mode == EvalMode::Ansi {
+                                        return Err(DataFusionError::from(
+                                            arithmetic_overflow_error("integer"),
+                                        ));
+                                    } else {
+                                        sums[group_index] = None;
+                                    }
+                                }
+                            };
+                        }
+                    }
+                    has_all_nulls[group_index] = false;
+                }
+            }
+            Ok(())
+        }
+
+        debug_assert!(opt_filter.is_none(), "opt_filter is not supported yet");
+        let values = &values[0];
+        self.resize_helper(total_num_groups);
+
+        match values.data_type() {
+            DataType::Int64 => update_groups_sum_internal(
+                as_primitive_array::<Int64Type>(values),
+                group_indices,
+                &mut self.sums,
+                &mut self.has_all_nulls,
+                self.eval_mode,
+            )?,
+            DataType::Int32 => update_groups_sum_internal(
+                as_primitive_array::<Int32Type>(values),
+                group_indices,
+                &mut self.sums,
+                &mut self.has_all_nulls,
+                self.eval_mode,
+            )?,
+            DataType::Int16 => update_groups_sum_internal(
+                as_primitive_array::<Int16Type>(values),
+                group_indices,
+                &mut self.sums,
+                &mut self.has_all_nulls,
+                self.eval_mode,
+            )?,
+            DataType::Int8 => update_groups_sum_internal(
+                as_primitive_array::<Int8Type>(values),
+                group_indices,
+                &mut self.sums,
+                &mut self.has_all_nulls,
+                self.eval_mode,
+            )?,
+            _ => {
+                return Err(DataFusionError::Internal(format!(
+                    "Unsupported data type for SumIntGroupsAccumulator: {:?}",
+                    values.data_type()
+                )))
+            }
+        };
+        Ok(())
+    }
+
+    fn evaluate(&mut self, emit_to: EmitTo) -> DFResult<ArrayRef> {
+        match emit_to {
+            EmitTo::All => {
+                let result = Arc::new(Int64Array::from_iter(
+                    self.sums
+                        .iter()
+                        .zip(self.has_all_nulls.iter())
+                        .map(|(&sum, &is_null)| if is_null { None } else { sum }),
+                )) as ArrayRef;
+
+                self.sums.clear();
+                self.has_all_nulls.clear();
+                Ok(result)
+            }
+            EmitTo::First(n) => {
+                let result = Arc::new(Int64Array::from_iter(
+                    self.sums
+                        .drain(..n)
+                        .zip(self.has_all_nulls.drain(..n))
+                        .map(|(sum, is_null)| if is_null { None } else { sum }),
+                )) as ArrayRef;
+                Ok(result)
+            }
+        }
+    }
+
+    fn state(&mut self, emit_to: EmitTo) -> DFResult<Vec<ArrayRef>> {
+        let sums = emit_to.take_needed(&mut self.sums);
+
+        if self.eval_mode == EvalMode::Try {
+            let has_all_nulls = emit_to.take_needed(&mut self.has_all_nulls);
+            Ok(vec![
+                Arc::new(Int64Array::from(sums)),
+                Arc::new(BooleanArray::from(has_all_nulls)),
+            ])
+        } else {
+            Ok(vec![Arc::new(Int64Array::from(sums))])
+        }
+    }
+
+    fn merge_batch(
+        &mut self,
+        values: &[ArrayRef],
+        group_indices: &[usize],
+        opt_filter: Option<&BooleanArray>,
+        total_num_groups: usize,
+    ) -> DFResult<()> {
+        debug_assert!(opt_filter.is_none(), "opt_filter is not supported yet");
+
+        let that_sums = values[0].as_primitive::<Int64Type>();
+
+        self.resize_helper(total_num_groups);
+
+        let that_sums_is_all_nulls = if self.eval_mode == EvalMode::Try {
+            Some(values[1].as_boolean())
+        } else {
+            None
+        };
+
+        for (idx, &group_index) in group_indices.iter().enumerate() {
+            let that_sum = if that_sums.is_null(idx) {
+                None
+            } else {
+                Some(that_sums.value(idx))
+            };
+
+            if self.eval_mode == EvalMode::Try {
+                let that_has_all_nulls = that_sums_is_all_nulls.unwrap().value(idx);
+
+                let that_overflowed = !that_has_all_nulls && that_sum.is_none();
+                let this_overflowed =
+                    !self.has_all_nulls[group_index] && self.sums[group_index].is_none();
+
+                if that_overflowed || this_overflowed {
+                    self.sums[group_index] = None;
+                    self.has_all_nulls[group_index] = false;
+                    continue;
+                }
+
+                if that_has_all_nulls {
+                    continue;
+                }
+
+                if self.has_all_nulls[group_index] {
+                    self.sums[group_index] = that_sum;
+                    self.has_all_nulls[group_index] = false;
+                    continue;
+                }
+            } else {
+                if that_sum.is_none() {
+                    continue;
+                }
+                if self.sums[group_index].is_none() {
+                    self.sums[group_index] = that_sum;
+                    continue;
+                }
+            }
+
+            // Both sides are non-null. Update the sum now.
+            let left = self.sums[group_index].unwrap();
+            let right = that_sum.unwrap();
+
+            match self.eval_mode {
+                EvalMode::Legacy => {
+                    self.sums[group_index] = Some(left.add_wrapping(right));
+                }
+                EvalMode::Ansi | EvalMode::Try => {
+                    match left.add_checked(right) {
+                        Ok(v) => self.sums[group_index] = Some(v),
+                        Err(_) => {
+                            if self.eval_mode == EvalMode::Ansi {
+                                return Err(DataFusionError::from(arithmetic_overflow_error(
+                                    "integer",
+                                )));
+                            } else {
+                                // overflow.
update flag accordingly + self.sums[group_index] = None; + self.has_all_nulls[group_index] = false; + } + } + } + } + } + } + Ok(()) + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) + } +} diff --git a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala index 8ab568dc83..a05efaebbc 100644 --- a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala +++ b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala @@ -213,17 +213,6 @@ object CometAverage extends CometAggregateExpressionSerde[Average] { object CometSum extends CometAggregateExpressionSerde[Sum] { - override def getSupportLevel(sum: Sum): SupportLevel = { - sum.evalMode match { - case EvalMode.ANSI if !sum.dataType.isInstanceOf[DecimalType] => - Incompatible(Some("ANSI mode for non decimal inputs is not supported")) - case EvalMode.TRY if !sum.dataType.isInstanceOf[DecimalType] => - Incompatible(Some("TRY mode for non decimal inputs is not supported")) - case _ => - Compatible() - } - } - override def convert( aggExpr: AggregateExpression, sum: Sum, @@ -236,6 +225,8 @@ object CometSum extends CometAggregateExpressionSerde[Sum] { return None } + val evalMode = sum.evalMode + val childExpr = exprToProto(sum.child, inputs, binding) val dataType = serializeDataType(sum.dataType) @@ -243,7 +234,7 @@ object CometSum extends CometAggregateExpressionSerde[Sum] { val builder = ExprOuterClass.Sum.newBuilder() builder.setChild(childExpr.get) builder.setDatatype(dataType.get) - builder.setEvalMode(evalModeToProto(CometEvalModeUtil.fromSparkEvalMode(sum.evalMode))) + builder.setEvalMode(evalModeToProto(CometEvalModeUtil.fromSparkEvalMode(evalMode))) Some( ExprOuterClass.AggExpr diff --git a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala index 060579b2ba..fe2dc76595 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala @@ -24,7 +24,6 @@ import scala.util.Random import org.apache.hadoop.fs.Path import org.apache.spark.sql.{CometTestBase, DataFrame, Row} import org.apache.spark.sql.catalyst.expressions.Cast -import org.apache.spark.sql.catalyst.expressions.aggregate.Sum import org.apache.spark.sql.catalyst.optimizer.EliminateSorts import org.apache.spark.sql.comet.CometHashAggregateExec import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper @@ -1472,11 +1471,23 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } + test("ANSI support for sum - null test") { + Seq(true, false).foreach { ansiEnabled => + withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) { + withParquetTable( + Seq((null.asInstanceOf[java.lang.Long], "a"), (null.asInstanceOf[java.lang.Long], "b")), + "null_tbl") { + val res = sql("SELECT sum(_1) FROM null_tbl") + checkSparkAnswerAndOperator(res) + assert(res.collect() === Array(Row(null))) + } + } + } + } + test("ANSI support for decimal sum - null test") { Seq(true, false).foreach { ansiEnabled => - withSQLConf( - SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString, - CometConf.getExprAllowIncompatConfigKey(classOf[Sum]) -> "true") { + withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) { withParquetTable( Seq( (null.asInstanceOf[java.math.BigDecimal], "a"), @@ -1490,11 +1501,23 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } + test("ANSI 
support for try_sum - null test") { + Seq(true, false).foreach { ansiEnabled => + withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) { + withParquetTable( + Seq((null.asInstanceOf[java.lang.Long], "a"), (null.asInstanceOf[java.lang.Long], "b")), + "null_tbl") { + val res = sql("SELECT try_sum(_1) FROM null_tbl") + checkSparkAnswerAndOperator(res) + assert(res.collect() === Array(Row(null))) + } + } + } + } + test("ANSI support for try_sum decimal - null test") { Seq(true, false).foreach { ansiEnabled => - withSQLConf( - SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString, - CometConf.getExprAllowIncompatConfigKey(classOf[Sum]) -> "true") { + withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) { withParquetTable( Seq( (null.asInstanceOf[java.math.BigDecimal], "a"), @@ -1508,11 +1531,28 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } + test("ANSI support for sum - null test (group by)") { + Seq(true, false).foreach { ansiEnabled => + withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) { + withParquetTable( + Seq( + (null.asInstanceOf[java.lang.Long], "a"), + (null.asInstanceOf[java.lang.Long], "a"), + (null.asInstanceOf[java.lang.Long], "b"), + (null.asInstanceOf[java.lang.Long], "b"), + (null.asInstanceOf[java.lang.Long], "b")), + "tbl") { + val res = sql("SELECT _2, sum(_1) FROM tbl group by 1") + checkSparkAnswerAndOperator(res) + assert(res.orderBy(col("_2")).collect() === Array(Row("a", null), Row("b", null))) + } + } + } + } + test("ANSI support for decimal sum - null test (group by)") { Seq(true, false).foreach { ansiEnabled => - withSQLConf( - SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString, - CometConf.getExprAllowIncompatConfigKey(classOf[Sum]) -> "true") { + withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) { withParquetTable( Seq( (null.asInstanceOf[java.math.BigDecimal], "a"), @@ -1529,11 +1569,28 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } + test("ANSI support for try_sum - null test (group by)") { + Seq(true, false).foreach { ansiEnabled => + withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) { + withParquetTable( + Seq( + (null.asInstanceOf[java.lang.Long], "a"), + (null.asInstanceOf[java.lang.Long], "a"), + (null.asInstanceOf[java.lang.Long], "b"), + (null.asInstanceOf[java.lang.Long], "b"), + (null.asInstanceOf[java.lang.Long], "b")), + "tbl") { + val res = sql("SELECT _2, try_sum(_1) FROM tbl group by 1") + checkSparkAnswerAndOperator(res) + assert(res.orderBy(col("_2")).collect() === Array(Row("a", null), Row("b", null))) + } + } + } + } + test("ANSI support for try_sum decimal - null test (group by)") { Seq(true, false).foreach { ansiEnabled => - withSQLConf( - SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString, - CometConf.getExprAllowIncompatConfigKey(classOf[Sum]) -> "true") { + withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) { withParquetTable( Seq( (null.asInstanceOf[java.math.BigDecimal], "a"), @@ -1555,11 +1612,63 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { (1 to 50).flatMap(_ => Seq((maxDec38_0, 1))) } + test("ANSI support - SUM function") { + Seq(true, false).foreach { ansiEnabled => + withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) { + // Test long overflow + withParquetTable(Seq((Long.MaxValue, 1L), (100L, 1L)), "tbl") { + val res = sql("SELECT SUM(_1) FROM tbl") + if (ansiEnabled) { + checkSparkAnswerMaybeThrows(res) match { + case (Some(sparkExc), Some(cometExc)) => 
+ assert(sparkExc.getMessage.contains("ARITHMETIC_OVERFLOW")) + assert(cometExc.getMessage.contains("ARITHMETIC_OVERFLOW")) + case _ => fail("Exception should be thrown for Long overflow in ANSI mode") + } + } else { + checkSparkAnswerAndOperator(res) + } + } + // Test long underflow + withParquetTable(Seq((Long.MinValue, 1L), (-100L, 1L)), "tbl") { + val res = sql("SELECT SUM(_1) FROM tbl") + if (ansiEnabled) { + checkSparkAnswerMaybeThrows(res) match { + case (Some(sparkExc), Some(cometExc)) => + assert(sparkExc.getMessage.contains("ARITHMETIC_OVERFLOW")) + assert(cometExc.getMessage.contains("ARITHMETIC_OVERFLOW")) + case _ => fail("Exception should be thrown for Long underflow in ANSI mode") + } + } else { + checkSparkAnswerAndOperator(res) + } + } + // Test Int SUM (should not overflow) + withParquetTable(Seq((Int.MaxValue, 1), (Int.MaxValue, 1), (100, 1)), "tbl") { + val res = sql("SELECT SUM(_1) FROM tbl") + checkSparkAnswerAndOperator(res) + } + // Test Short SUM (should not overflow) + withParquetTable( + Seq((Short.MaxValue, 1.toShort), (Short.MaxValue, 1.toShort), (100.toShort, 1.toShort)), + "tbl") { + val res = sql("SELECT SUM(_1) FROM tbl") + checkSparkAnswerAndOperator(res) + } + // Test Byte SUM (should not overflow) + withParquetTable( + Seq((Byte.MaxValue, 1.toByte), (Byte.MaxValue, 1.toByte), (10.toByte, 1.toByte)), + "tbl") { + val res = sql("SELECT SUM(_1) FROM tbl") + checkSparkAnswerAndOperator(res) + } + } + } + } + test("ANSI support for decimal SUM function") { Seq(true, false).foreach { ansiEnabled => - withSQLConf( - SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString, - CometConf.getExprAllowIncompatConfigKey(classOf[Sum]) -> "true") { + withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) { withParquetTable(generateOverflowDecimalInputs, "tbl") { val res = sql("SELECT SUM(_1) FROM tbl") if (ansiEnabled) { @@ -1578,11 +1687,68 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } + test("ANSI support for SUM - GROUP BY") { + Seq(true, false).foreach { ansiEnabled => + withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) { + withParquetTable( + Seq((Long.MaxValue, 1), (100L, 1), (Long.MaxValue, 2), (200L, 2)), + "tbl") { + val res = sql("SELECT _2, SUM(_1) FROM tbl GROUP BY _2").repartition(2) + if (ansiEnabled) { + checkSparkAnswerMaybeThrows(res) match { + case (Some(sparkExc), Some(cometExc)) => + assert(sparkExc.getMessage.contains("ARITHMETIC_OVERFLOW")) + assert(cometExc.getMessage.contains("ARITHMETIC_OVERFLOW")) + case _ => + fail("Exception should be thrown for Long overflow with GROUP BY in ANSI mode") + } + } else { + checkSparkAnswerAndOperator(res) + } + } + + withParquetTable( + Seq((Long.MinValue, 1), (-100L, 1), (Long.MinValue, 2), (-200L, 2)), + "tbl") { + val res = sql("SELECT _2, SUM(_1) FROM tbl GROUP BY _2") + if (ansiEnabled) { + checkSparkAnswerMaybeThrows(res) match { + case (Some(sparkExc), Some(cometExc)) => + assert(sparkExc.getMessage.contains("ARITHMETIC_OVERFLOW")) + assert(cometExc.getMessage.contains("ARITHMETIC_OVERFLOW")) + case _ => + fail("Exception should be thrown for Long underflow with GROUP BY in ANSI mode") + } + } else { + checkSparkAnswerAndOperator(res) + } + } + // Test Int with GROUP BY + withParquetTable(Seq((Int.MaxValue, 1), (Int.MaxValue, 1), (100, 2), (200, 2)), "tbl") { + val res = sql("SELECT _2, SUM(_1) FROM tbl GROUP BY _2") + checkSparkAnswerAndOperator(res) + } + // Test Short with GROUP BY + withParquetTable( + Seq((Short.MaxValue, 1), (Short.MaxValue, 1), 
(100.toShort, 2), (200.toShort, 2)), + "tbl") { + val res = sql("SELECT _2, SUM(_1) FROM tbl GROUP BY _2") + checkSparkAnswerAndOperator(res) + } + // Test Byte with GROUP BY + withParquetTable( + Seq((Byte.MaxValue, 1), (Byte.MaxValue, 1), (10.toByte, 2), (20.toByte, 2)), + "tbl") { + val res = sql("SELECT _2, SUM(_1) FROM tbl GROUP BY _2") + checkSparkAnswerAndOperator(res) + } + } + } + } + test("ANSI support for decimal SUM - GROUP BY") { Seq(true, false).foreach { ansiEnabled => - withSQLConf( - SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString, - CometConf.getExprAllowIncompatConfigKey(classOf[Sum]) -> "true") { + withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) { withParquetTable(generateOverflowDecimalInputs, "tbl") { val res = sql("SELECT _2, SUM(_1) FROM tbl GROUP BY _2").repartition(2) @@ -1602,35 +1768,68 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } + test("try_sum overflow - with GROUP BY") { + // Test Long overflow with GROUP BY - some groups overflow while some don't + withParquetTable(Seq((Long.MaxValue, 1), (100L, 1), (200L, 2), (300L, 2)), "tbl") { + val res = sql("SELECT _2, try_sum(_1) FROM tbl GROUP BY _2").repartition(2, col("_2")) + // first group should return NULL (overflow) and group 2 should return 500 + checkSparkAnswerAndOperator(res) + } + + // Test Long underflow with GROUP BY + withParquetTable(Seq((Long.MinValue, 1), (-100L, 1), (-200L, 2), (-300L, 2)), "tbl") { + val res = sql("SELECT _2, try_sum(_1) FROM tbl GROUP BY _2").repartition(2, col("_2")) + // first group should return NULL (underflow), second group should return neg 500 + checkSparkAnswerAndOperator(res) + } + + // Test all groups overflow + withParquetTable(Seq((Long.MaxValue, 1), (100L, 1), (Long.MaxValue, 2), (100L, 2)), "tbl") { + val res = sql("SELECT _2, try_sum(_1) FROM tbl GROUP BY _2").repartition(2, col("_2")) + // Both groups should return NULL + checkSparkAnswerAndOperator(res) + } + + // Test Short with GROUP BY (should NOT overflow) + withParquetTable( + Seq((Short.MaxValue, 1), (Short.MaxValue, 1), (100.toShort, 2), (200.toShort, 2)), + "tbl") { + val res = sql("SELECT _2, try_sum(_1) FROM tbl GROUP BY _2").repartition(2, col("_2")) + checkSparkAnswerAndOperator(res) + } + + // Test Byte with GROUP BY (no overflow) + withParquetTable( + Seq((Byte.MaxValue, 1), (Byte.MaxValue, 1), (10.toByte, 2), (20.toByte, 2)), + "tbl") { + val res = sql("SELECT _2, try_sum(_1) FROM tbl GROUP BY _2").repartition(2, col("_2")) + checkSparkAnswerAndOperator(res) + } + } + test("try_sum decimal overflow") { - withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Sum]) -> "true") { - withParquetTable(generateOverflowDecimalInputs, "tbl") { - val res = sql("SELECT try_sum(_1) FROM tbl") - checkSparkAnswerAndOperator(res) - } + withParquetTable(generateOverflowDecimalInputs, "tbl") { + val res = sql("SELECT try_sum(_1) FROM tbl") + checkSparkAnswerAndOperator(res) } } test("try_sum decimal overflow - with GROUP BY") { - withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Sum]) -> "true") { - withParquetTable(generateOverflowDecimalInputs, "tbl") { - val res = sql("SELECT _2, try_sum(_1) FROM tbl GROUP BY _2").repartition(2, col("_2")) - checkSparkAnswerAndOperator(res) - } + withParquetTable(generateOverflowDecimalInputs, "tbl") { + val res = sql("SELECT _2, try_sum(_1) FROM tbl GROUP BY _2").repartition(2, col("_2")) + checkSparkAnswerAndOperator(res) } } test("try_sum decimal partial overflow - with GROUP BY") { - 
withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Sum]) -> "true") { - // Group 1 overflows, Group 2 succeeds - val data: Seq[(java.math.BigDecimal, Int)] = generateOverflowDecimalInputs ++ Seq( - (new java.math.BigDecimal(300), 2), - (new java.math.BigDecimal(200), 2)) - withParquetTable(data, "tbl") { - val res = sql("SELECT _2, try_sum(_1) FROM tbl GROUP BY _2") - // Group 1 should be NULL, Group 2 should be 500 - checkSparkAnswerAndOperator(res) - } + // Group 1 overflows, Group 2 succeeds + val data: Seq[(java.math.BigDecimal, Int)] = generateOverflowDecimalInputs ++ Seq( + (new java.math.BigDecimal(300), 2), + (new java.math.BigDecimal(200), 2)) + withParquetTable(data, "tbl") { + val res = sql("SELECT _2, try_sum(_1) FROM tbl GROUP BY _2") + // Group 1 should be NULL, Group 2 should be 500 + checkSparkAnswerAndOperator(res) } } diff --git a/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanStabilitySuite.scala b/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanStabilitySuite.scala index 8f260e2ca8..1728ce5b27 100644 --- a/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanStabilitySuite.scala +++ b/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanStabilitySuite.scala @@ -29,7 +29,7 @@ import org.apache.spark.SparkContext import org.apache.spark.internal.config.{MEMORY_OFFHEAP_ENABLED, MEMORY_OFFHEAP_SIZE} import org.apache.spark.sql.TPCDSBase import org.apache.spark.sql.catalyst.expressions.{AttributeSet, Cast} -import org.apache.spark.sql.catalyst.expressions.aggregate.{Average, Sum} +import org.apache.spark.sql.catalyst.expressions.aggregate.Average import org.apache.spark.sql.catalyst.util.resourceToString import org.apache.spark.sql.execution.{FormattedMode, ReusedSubqueryExec, SparkPlan, SubqueryBroadcastExec, SubqueryExec} import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecutionSuite @@ -226,7 +226,6 @@ trait CometPlanStabilitySuite extends DisableAdaptiveExecutionSuite with TPCDSBa CometConf.COMET_EXEC_SORT_MERGE_JOIN_WITH_JOIN_FILTER_ENABLED.key -> "true", // Allow Incompatible is needed for Sum + Average for Spark 4.0.0 / ANSI support CometConf.getExprAllowIncompatConfigKey(classOf[Average]) -> "true", - CometConf.getExprAllowIncompatConfigKey(classOf[Sum]) -> "true", // as well as for v1.4/q9, v1.4/q44, v2.7.0/q6, v2.7.0/q64 CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true", SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "10MB") {
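
A minimal, self-contained sketch (not part of the patch) of the per-eval-mode overflow
semantics that SumIntegerAccumulator implements, written with plain i64 arithmetic instead
of Arrow's add_wrapping/add_checked kernels; the Mode enum and sum_i64 helper are
illustrative names only:

    // Ok(Some(s)) = non-NULL sum; Ok(None) = SQL NULL (all inputs NULL, or overflow under Try);
    // Err(_) = ANSI overflow error.
    #[derive(Clone, Copy)]
    enum Mode {
        Legacy, // Spark non-ANSI: wrap on overflow
        Ansi,   // raise ARITHMETIC_OVERFLOW
        Try,    // try_sum: overflow becomes NULL
    }

    fn sum_i64(values: &[Option<i64>], mode: Mode) -> Result<Option<i64>, String> {
        let mut sum: Option<i64> = None; // stays None until a non-NULL value is seen
        for v in values.iter().flatten() {
            let acc = sum.unwrap_or(0);
            sum = match mode {
                Mode::Legacy => Some(acc.wrapping_add(*v)),
                Mode::Ansi => Some(acc.checked_add(*v).ok_or("ARITHMETIC_OVERFLOW")?),
                Mode::Try => match acc.checked_add(*v) {
                    Some(s) => Some(s),
                    None => return Ok(None),
                },
            };
        }
        Ok(sum)
    }

    fn main() {
        assert_eq!(sum_i64(&[Some(1), None, Some(2)], Mode::Legacy), Ok(Some(3)));
        assert_eq!(sum_i64(&[None, None], Mode::Try), Ok(None));
        assert_eq!(sum_i64(&[Some(i64::MAX), Some(100)], Mode::Try), Ok(None));
        assert!(sum_i64(&[Some(i64::MAX), Some(100)], Mode::Ansi).is_err());
    }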