From 58d8c9a352cf43e0127fb1e3518c624b82a4215d Mon Sep 17 00:00:00 2001 From: hsiang-c Date: Thu, 16 Oct 2025 14:32:23 -0700 Subject: [PATCH 1/8] Removed old ABS implementation --- native/core/src/execution/planner.rs | 13 ------------- native/proto/src/proto/expr.proto | 6 ------ 2 files changed, 19 deletions(-) diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 1550efd799..ffa5f86052 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -675,19 +675,6 @@ impl PhysicalPlanner { let op = DataFusionOperator::BitwiseShiftLeft; Ok(Arc::new(BinaryExpr::new(left, op, right))) } - // https://github.com/apache/datafusion-comet/issues/666 - // ExprStruct::Abs(expr) => { - // let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&input_schema))?; - // let return_type = child.data_type(&input_schema)?; - // let args = vec![child]; - // let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?; - // let comet_abs = Arc::new(ScalarUDF::new_from_impl(Abs::new( - // eval_mode, - // return_type.to_string(), - // )?)); - // let expr = ScalarFunctionExpr::new("abs", comet_abs, args, return_type); - // Ok(Arc::new(expr)) - // } ExprStruct::CaseWhen(case_when) => { let when_then_pairs = case_when .when diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto index 5853bc613c..c9037dcd69 100644 --- a/native/proto/src/proto/expr.proto +++ b/native/proto/src/proto/expr.proto @@ -70,7 +70,6 @@ message Expr { IfExpr if = 44; NormalizeNaNAndZero normalize_nan_and_zero = 45; TruncTimestamp truncTimestamp = 47; - Abs abs = 49; Subquery subquery = 50; UnboundReference unbound = 51; BloomFilterMightContain bloom_filter_might_contain = 52; @@ -351,11 +350,6 @@ message TruncTimestamp { string timezone = 3; } -message Abs { - Expr child = 1; - EvalMode eval_mode = 2; -} - message Subquery { int64 id = 1; DataType datatype = 2; From d5b70c3275fbe958b3b7a88a35fd78bce49f1bba 
Mon Sep 17 00:00:00 2001 From: hsiang-c Date: Thu, 16 Oct 2025 14:33:28 -0700 Subject: [PATCH 2/8] Define Comet's ABS in Scala --- .../apache/comet/serde/QueryPlanSerde.scala | 3 ++- .../scala/org/apache/comet/serde/math.scala | 21 ++++++++++++++++++- 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 233261091b..33739d9491 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -137,7 +137,8 @@ object QueryPlanSerde extends Logging with CometExprShim { classOf[Subtract] -> CometSubtract, classOf[Tan] -> CometScalarFunction("tan"), classOf[UnaryMinus] -> CometUnaryMinus, - classOf[Unhex] -> CometUnhex) + classOf[Unhex] -> CometUnhex, + classOf[Abs] -> CometAbs) private val mapExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map( classOf[GetMapValue] -> CometMapExtract, diff --git a/spark/src/main/scala/org/apache/comet/serde/math.scala b/spark/src/main/scala/org/apache/comet/serde/math.scala index bfcd242d76..90de894e3e 100644 --- a/spark/src/main/scala/org/apache/comet/serde/math.scala +++ b/spark/src/main/scala/org/apache/comet/serde/math.scala @@ -19,7 +19,7 @@ package org.apache.comet.serde -import org.apache.spark.sql.catalyst.expressions.{Atan2, Attribute, Ceil, CheckOverflow, Expression, Floor, Hex, If, LessThanOrEqual, Literal, Log, Log10, Log2, Unhex} +import org.apache.spark.sql.catalyst.expressions.{Abs, Atan2, Attribute, Ceil, CheckOverflow, Expression, Floor, Hex, If, LessThanOrEqual, Literal, Log, Log10, Log2, Unhex} import org.apache.spark.sql.types.DecimalType import org.apache.comet.CometSparkSessionExtensions.withInfo @@ -144,6 +144,25 @@ object CometUnhex extends CometExpressionSerde[Unhex] with MathExprBase { } } +object CometAbs extends CometExpressionSerde[Abs] with MathExprBase { + 
override def convert( + expr: Abs, + inputs: Seq[Attribute], + binding: Boolean): Option[ExprOuterClass.Expr] = { + val childExpr = exprToProtoInternal(expr.child, inputs, binding) + val failOnErrorExpr = exprToProtoInternal(Literal(expr.failOnError), inputs, binding) + + val optExpr = + scalarFunctionExprToProtoWithReturnType( + "abs", + expr.dataType, + false, + childExpr, + failOnErrorExpr) + optExprWithInfo(optExpr, expr, expr.child) + } +} + sealed trait MathExprBase { protected def nullIfNegative(expression: Expression): Expression = { val zero = Literal.default(expression.dataType) From 7b1965adb8777974d251809e5451065cce888acf Mon Sep 17 00:00:00 2001 From: hsiang-c Date: Thu, 16 Oct 2025 14:34:32 -0700 Subject: [PATCH 3/8] Implement Comet's ABS in rust --- native/spark-expr/src/comet_scalar_funcs.rs | 5 + native/spark-expr/src/math_funcs/abs.rs | 798 ++++++++++++++++++++ native/spark-expr/src/math_funcs/mod.rs | 1 + 3 files changed, 804 insertions(+) create mode 100644 native/spark-expr/src/math_funcs/abs.rs diff --git a/native/spark-expr/src/comet_scalar_funcs.rs b/native/spark-expr/src/comet_scalar_funcs.rs index fc0c096b15..021bb1c78f 100644 --- a/native/spark-expr/src/comet_scalar_funcs.rs +++ b/native/spark-expr/src/comet_scalar_funcs.rs @@ -16,6 +16,7 @@ // under the License. 
use crate::hash_funcs::*; +use crate::math_funcs::abs::abs; use crate::math_funcs::checked_arithmetic::{checked_add, checked_div, checked_mul, checked_sub}; use crate::math_funcs::modulo_expr::spark_modulo; use crate::{ @@ -180,6 +181,10 @@ pub fn create_comet_physical_fun_with_eval_mode( let func = Arc::new(spark_modulo); make_comet_scalar_udf!("spark_modulo", func, without data_type, fail_on_error) } + "abs" => { + let func = Arc::new(abs); + make_comet_scalar_udf!("abs", func, without data_type) + } _ => registry.udf(fun_name).map_err(|e| { DataFusionError::Execution(format!( "Function {fun_name} not found in the registry: {e}", diff --git a/native/spark-expr/src/math_funcs/abs.rs b/native/spark-expr/src/math_funcs/abs.rs new file mode 100644 index 0000000000..54d21a3bb5 --- /dev/null +++ b/native/spark-expr/src/math_funcs/abs.rs @@ -0,0 +1,798 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::arithmetic_overflow_error; +use arrow::array::*; +use arrow::datatypes::*; +use arrow::error::ArrowError; +use datafusion::common::{exec_err, DataFusionError, Result, ScalarValue}; +use datafusion::logical_expr::ColumnarValue; +use std::sync::Arc; + +macro_rules! 
legacy_compute_op {
+    ($ARRAY:expr, $FUNC:ident, $TYPE:ident, $RESULT:ident) => {{
+        let n = $ARRAY.as_any().downcast_ref::<$TYPE>();
+        match n {
+            Some(array) => {
+                let res: $RESULT = arrow::compute::kernels::arity::unary(array, |x| x.$FUNC());
+                Ok(res)
+            }
+            _ => Err(DataFusionError::Internal(format!(
+                "Invalid data type for abs"
+            ))),
+        }
+    }};
+}
+
+macro_rules! ansi_compute_op {
+    ($ARRAY:expr, $FUNC:ident, $TYPE:ident, $RESULT:ident, $NATIVE:ident, $FROM_TYPE:expr) => {{
+        let n = $ARRAY.as_any().downcast_ref::<$TYPE>();
+        match n {
+            Some(array) => {
+                match arrow::compute::kernels::arity::try_unary(array, |x| {
+                    if x == $NATIVE::MIN {
+                        Err(ArrowError::ArithmeticOverflow($FROM_TYPE.to_string()))
+                    } else {
+                        Ok(x.$FUNC())
+                    }
+                }) {
+                    Ok(res) => Ok(ColumnarValue::Array(Arc::<PrimitiveArray<$RESULT>>::new(
+                        res,
+                    ))),
+                    Err(_) => Err(arithmetic_overflow_error($FROM_TYPE).into()),
+                }
+            }
+            _ => Err(DataFusionError::Internal("Invalid data type".to_string())),
+        }
+    }};
+}
+
+/// This function mimics SparkSQL's [Abs]: https://github.com/apache/spark/blob/v4.0.1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala#L148
+/// Spark's [ANSI-compliant]: https://spark.apache.org/docs/latest/sql-ref-ansi-compliance.html#arithmetic-operations dialect mode throws org.apache.spark.SparkArithmeticException
+/// when abs causes overflow.
+pub fn abs(args: &[ColumnarValue]) -> Result<ColumnarValue> {
+    if args.len() > 2 {
+        return exec_err!("abs takes at most 2 arguments, but got: {}", args.len());
+    }
+
+    let fail_on_error = if args.len() == 2 {
+        match &args[1] {
+            ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))) => *fail_on_error,
+            _ => {
+                return exec_err!(
+                    "The second argument must be boolean scalar, but got: {:?}",
+                    args[1]
+                );
+            }
+        }
+    } else {
+        false
+    };
+
+    match &args[0] {
+        ColumnarValue::Array(array) => match array.data_type() {
+            DataType::Null
+            | DataType::UInt8
+            | DataType::UInt16
+            | DataType::UInt32
+            | DataType::UInt64 => Ok(args[0].clone()),
+            DataType::Int8 => {
+                if !fail_on_error {
+                    let result = legacy_compute_op!(array, wrapping_abs, Int8Array, Int8Array);
+                    Ok(ColumnarValue::Array(Arc::new(result?)))
+                } else {
+                    ansi_compute_op!(array, abs, Int8Array, Int8Type, i8, "Int8")
+                }
+            }
+            DataType::Int16 => {
+                if !fail_on_error {
+                    let result = legacy_compute_op!(array, wrapping_abs, Int16Array, Int16Array);
+                    Ok(ColumnarValue::Array(Arc::new(result?)))
+                } else {
+                    ansi_compute_op!(array, abs, Int16Array, Int16Type, i16, "Int16")
+                }
+            }
+            DataType::Int32 => {
+                if !fail_on_error {
+                    let result = legacy_compute_op!(array, wrapping_abs, Int32Array, Int32Array);
+                    Ok(ColumnarValue::Array(Arc::new(result?)))
+                } else {
+                    ansi_compute_op!(array, abs, Int32Array, Int32Type, i32, "Int32")
+                }
+            }
+            DataType::Int64 => {
+                if !fail_on_error {
+                    let result = legacy_compute_op!(array, wrapping_abs, Int64Array, Int64Array);
+                    Ok(ColumnarValue::Array(Arc::new(result?)))
+                } else {
+                    ansi_compute_op!(array, abs, Int64Array, Int64Type, i64, "Int64")
+                }
+            }
+            DataType::Float32 => {
+                let result = legacy_compute_op!(array, abs, Float32Array, Float32Array);
+                Ok(ColumnarValue::Array(Arc::new(result?)))
+            }
+            DataType::Float64 => {
+                let result = legacy_compute_op!(array, abs, Float64Array, Float64Array);
+                Ok(ColumnarValue::Array(Arc::new(result?)))
+            }
+            
DataType::Decimal128(precision, scale) => {
+                if !fail_on_error {
+                    let result =
+                        legacy_compute_op!(array, wrapping_abs, Decimal128Array, Decimal128Array)?;
+                    let result = result.with_data_type(DataType::Decimal128(*precision, *scale));
+                    Ok(ColumnarValue::Array(Arc::new(result)))
+                } else {
+                    // Need to pass precision and scale from input, so not using ansi_compute_op
+                    let input = array.as_any().downcast_ref::<Decimal128Array>();
+                    match input {
+                        Some(i) => {
+                            match arrow::compute::kernels::arity::try_unary(i, |x| {
+                                if x == i128::MIN {
+                                    Err(ArrowError::ArithmeticOverflow("Decimal128".to_string()))
+                                } else {
+                                    Ok(x.abs())
+                                }
+                            }) {
+                                Ok(res) => Ok(ColumnarValue::Array(Arc::<
+                                    PrimitiveArray<Decimal128Type>,
+                                >::new(
+                                    res.with_data_type(DataType::Decimal128(*precision, *scale)),
+                                ))),
+                                Err(_) => Err(arithmetic_overflow_error("Decimal128").into()),
+                            }
+                        }
+                        _ => Err(DataFusionError::Internal("Invalid data type".to_string())),
+                    }
+                }
+            }
+            DataType::Decimal256(precision, scale) => {
+                if !fail_on_error {
+                    let result =
+                        legacy_compute_op!(array, wrapping_abs, Decimal256Array, Decimal256Array)?;
+                    let result = result.with_data_type(DataType::Decimal256(*precision, *scale));
+                    Ok(ColumnarValue::Array(Arc::new(result)))
+                } else {
+                    // Need to pass precision and scale from input, so not using ansi_compute_op
+                    let input = array.as_any().downcast_ref::<Decimal256Array>();
+                    match input {
+                        Some(i) => {
+                            match arrow::compute::kernels::arity::try_unary(i, |x| {
+                                if x == i256::MIN {
+                                    Err(ArrowError::ArithmeticOverflow("Decimal256".to_string()))
+                                } else {
+                                    Ok(x.wrapping_abs()) // i256 doesn't define abs() method
+                                }
+                            }) {
+                                Ok(res) => Ok(ColumnarValue::Array(Arc::<
+                                    PrimitiveArray<Decimal256Type>,
+                                >::new(
+                                    res.with_data_type(DataType::Decimal256(*precision, *scale)),
+                                ))),
+                                Err(_) => Err(arithmetic_overflow_error("Decimal256").into()),
+                            }
+                        }
+                        _ => Err(DataFusionError::Internal("Invalid data type".to_string())),
+                    }
+                }
+            }
+            dt => exec_err!("Not supported datatype for ABS: {dt}"),
+        },
+        
ColumnarValue::Scalar(sv) => match sv { + ScalarValue::Null + | ScalarValue::UInt8(_) + | ScalarValue::UInt16(_) + | ScalarValue::UInt32(_) + | ScalarValue::UInt64(_) => Ok(args[0].clone()), + ScalarValue::Int8(a) => a + .map(|v| match v.checked_abs() { + Some(abs_val) => Ok(ColumnarValue::Scalar(ScalarValue::Int8(Some(abs_val)))), + None => { + if !fail_on_error { + // return the original value + Ok(ColumnarValue::Scalar(ScalarValue::Int8(Some(v)))) + } else { + Err(arithmetic_overflow_error("Int8").into()) + } + } + }) + .unwrap(), + ScalarValue::Int16(a) => a + .map(|v| match v.checked_abs() { + Some(abs_val) => Ok(ColumnarValue::Scalar(ScalarValue::Int16(Some(abs_val)))), + None => { + if !fail_on_error { + // return the original value + Ok(ColumnarValue::Scalar(ScalarValue::Int16(Some(v)))) + } else { + Err(arithmetic_overflow_error("Int16").into()) + } + } + }) + .unwrap(), + ScalarValue::Int32(a) => a + .map(|v| match v.checked_abs() { + Some(abs_val) => Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(abs_val)))), + None => { + if !fail_on_error { + // return the original value + Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(v)))) + } else { + Err(arithmetic_overflow_error("Int32").into()) + } + } + }) + .unwrap(), + ScalarValue::Int64(a) => a + .map(|v| match v.checked_abs() { + Some(abs_val) => Ok(ColumnarValue::Scalar(ScalarValue::Int64(Some(abs_val)))), + None => { + if !fail_on_error { + // return the original value + Ok(ColumnarValue::Scalar(ScalarValue::Int64(Some(v)))) + } else { + Err(arithmetic_overflow_error("Int64").into()) + } + } + }) + .unwrap(), + ScalarValue::Float32(a) => Ok(ColumnarValue::Scalar(ScalarValue::Float32( + a.map(|x| x.abs()), + ))), + ScalarValue::Float64(a) => Ok(ColumnarValue::Scalar(ScalarValue::Float64( + a.map(|x| x.abs()), + ))), + ScalarValue::Decimal128(a, precision, scale) => a + .map(|v| match v.checked_abs() { + Some(abs_val) => Ok(ColumnarValue::Scalar(ScalarValue::Decimal128( + Some(abs_val), + *precision, + 
*scale,
+                    ))),
+                    None => {
+                        if !fail_on_error {
+                            // return the original value
+                            Ok(ColumnarValue::Scalar(ScalarValue::Decimal128(
+                                Some(v),
+                                *precision,
+                                *scale,
+                            )))
+                        } else {
+                            Err(arithmetic_overflow_error("Decimal128").into())
+                        }
+                    }
+                })
+                .unwrap(),
+            ScalarValue::Decimal256(a, precision, scale) => a
+                .map(|v| match v.checked_abs() {
+                    Some(abs_val) => Ok(ColumnarValue::Scalar(ScalarValue::Decimal256(
+                        Some(abs_val),
+                        *precision,
+                        *scale,
+                    ))),
+                    None => {
+                        if !fail_on_error {
+                            // return the original value
+                            Ok(ColumnarValue::Scalar(ScalarValue::Decimal256(
+                                Some(v),
+                                *precision,
+                                *scale,
+                            )))
+                        } else {
+                            Err(arithmetic_overflow_error("Decimal256").into())
+                        }
+                    }
+                })
+                .unwrap(),
+            dt => exec_err!("Not supported datatype for ABS: {dt}"),
+        },
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use datafusion::common::cast::{
+        as_decimal128_array, as_decimal256_array, as_float32_array, as_float64_array,
+        as_int16_array, as_int32_array, as_int64_array, as_int8_array, as_uint64_array,
+    };
+
+    fn with_fail_on_error<F: Fn(bool) -> Result<()>>(test_fn: F) {
+        for fail_on_error in [true, false] {
+            let _ = test_fn(fail_on_error);
+        }
+    }
+
+    // Unsigned types, return as is
+    #[test]
+    fn test_abs_u8_scalar() {
+        with_fail_on_error(|fail_on_error| {
+            let args = ColumnarValue::Scalar(ScalarValue::UInt8(Some(u8::MAX)));
+            let fail_on_error_arg =
+                ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error)));
+            match abs(&[args, fail_on_error_arg]) {
+                Ok(ColumnarValue::Scalar(ScalarValue::UInt8(Some(result)))) => {
+                    assert_eq!(result, u8::MAX);
+                    Ok(())
+                }
+                Err(e) => {
+                    if fail_on_error {
+                        assert!(
+                            e.to_string().contains("ARITHMETIC_OVERFLOW"),
+                            "Error message did not match. 
Actual message: {e}" + ); + Ok(()) + } else { + panic!("Didn't expect error, but got: {e:?}") + } + } + _ => unreachable!(), + } + }); + } + + #[test] + fn test_abs_i8_scalar() { + with_fail_on_error(|fail_on_error| { + let args = ColumnarValue::Scalar(ScalarValue::Int8(Some(i8::MIN))); + let fail_on_error_arg = + ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))); + match abs(&[args, fail_on_error_arg]) { + Ok(ColumnarValue::Scalar(ScalarValue::Int8(Some(result)))) => { + assert_eq!(result, i8::MIN); + Ok(()) + } + Err(e) => { + if fail_on_error { + assert!( + e.to_string().contains("ARITHMETIC_OVERFLOW"), + "Error message did not match. Actual message: {e}" + ); + Ok(()) + } else { + panic!("Didn't expect error, but got: {e:?}") + } + } + _ => unreachable!(), + } + }); + } + + #[test] + fn test_abs_i16_scalar() { + with_fail_on_error(|fail_on_error| { + let args = ColumnarValue::Scalar(ScalarValue::Int16(Some(i16::MIN))); + let fail_on_error_arg = + ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))); + match abs(&[args, fail_on_error_arg]) { + Ok(ColumnarValue::Scalar(ScalarValue::Int16(Some(result)))) => { + assert_eq!(result, i16::MIN); + Ok(()) + } + Err(e) => { + if fail_on_error { + assert!( + e.to_string().contains("ARITHMETIC_OVERFLOW"), + "Error message did not match. Actual message: {e}" + ); + Ok(()) + } else { + panic!("Didn't expect error, but got: {e:?}") + } + } + _ => unreachable!(), + } + }); + } + + #[test] + fn test_abs_i32_scalar() { + with_fail_on_error(|fail_on_error| { + let args = ColumnarValue::Scalar(ScalarValue::Int32(Some(i32::MIN))); + let fail_on_error_arg = + ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))); + match abs(&[args, fail_on_error_arg]) { + Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(result)))) => { + assert_eq!(result, i32::MIN); + Ok(()) + } + Err(e) => { + if fail_on_error { + assert!( + e.to_string().contains("ARITHMETIC_OVERFLOW"), + "Error message did not match. 
Actual message: {e}" + ); + Ok(()) + } else { + panic!("Didn't expect error, but got: {e:?}") + } + } + _ => unreachable!(), + } + }); + } + + #[test] + fn test_abs_i64_scalar() { + with_fail_on_error(|fail_on_error| { + let args = ColumnarValue::Scalar(ScalarValue::Int64(Some(i64::MIN))); + let fail_on_error_arg = + ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))); + match abs(&[args, fail_on_error_arg]) { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(Some(result)))) => { + assert_eq!(result, i64::MIN); + Ok(()) + } + Err(e) => { + if fail_on_error { + assert!( + e.to_string().contains("ARITHMETIC_OVERFLOW"), + "Error message did not match. Actual message: {e}" + ); + Ok(()) + } else { + panic!("Didn't expect error, but got: {e:?}") + } + } + _ => unreachable!(), + } + }); + } + + #[test] + fn test_abs_decimal128_scalar() { + with_fail_on_error(|fail_on_error| { + let args = ColumnarValue::Scalar(ScalarValue::Decimal128(Some(i128::MIN), 18, 10)); + let fail_on_error_arg = + ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))); + match abs(&[args, fail_on_error_arg]) { + Ok(ColumnarValue::Scalar(ScalarValue::Decimal128(Some(result), precision, scale))) => { + assert_eq!(result, i128::MIN); + assert_eq!(precision, 18); + assert_eq!(scale, 10); + Ok(()) + } + Err(e) => { + if fail_on_error { + assert!( + e.to_string().contains("ARITHMETIC_OVERFLOW"), + "Error message did not match. 
Actual message: {e}" + ); + Ok(()) + } else { + panic!("Didn't expect error, but got: {e:?}") + } + } + _ => unreachable!(), + } + }); + } + + #[test] + fn test_abs_decimal256_scalar() { + with_fail_on_error(|fail_on_error| { + let args = ColumnarValue::Scalar(ScalarValue::Decimal256(Some(i256::MIN), 10, 2)); + let fail_on_error_arg = + ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))); + match abs(&[args, fail_on_error_arg]) { + Ok(ColumnarValue::Scalar(ScalarValue::Decimal256(Some(result), precision, scale))) => { + assert_eq!(result, i256::MIN); + assert_eq!(precision, 10); + assert_eq!(scale, 2); + Ok(()) + } + Err(e) => { + if fail_on_error { + assert!( + e.to_string().contains("ARITHMETIC_OVERFLOW"), + "Error message did not match. Actual message: {e}" + ); + Ok(()) + } else { + panic!("Didn't expect error, but got: {e:?}") + } + } + _ => unreachable!(), + } + }); + } + + #[test] + fn test_abs_i8_array() { + with_fail_on_error(|fail_on_error| { + let input = Int8Array::from(vec![Some(-1), Some(i8::MIN), Some(i8::MAX), None]); + let args = ColumnarValue::Array(Arc::new(input)); + let expected = Int8Array::from(vec![Some(1), Some(i8::MIN), Some(i8::MAX), None]); + let fail_on_error_arg = + ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))); + + match abs(&[args, fail_on_error_arg]) { + Ok(ColumnarValue::Array(result)) => { + let actual = as_int8_array(&result)?; + assert_eq!(actual, &expected); + Ok(()) + } + Err(e) => { + if fail_on_error { + assert!( + e.to_string().contains("ARITHMETIC_OVERFLOW"), + "Error message did not match. 
Actual message: {e}" + ); + Ok(()) + } else { + panic!("Didn't expect error, but got: {e:?}") + } + } + _ => unreachable!(), + } + }); + } + + #[test] + fn test_abs_i16_array() { + with_fail_on_error(|fail_on_error| { + let input = Int16Array::from(vec![Some(-1), Some(i16::MIN), Some(i16::MAX), None]); + let args = ColumnarValue::Array(Arc::new(input)); + let expected = Int16Array::from(vec![Some(1), Some(i16::MIN), Some(i16::MAX), None]); + let fail_on_error_arg = + ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))); + + match abs(&[args, fail_on_error_arg]) { + Ok(ColumnarValue::Array(result)) => { + let actual = as_int16_array(&result)?; + assert_eq!(actual, &expected); + Ok(()) + } + Err(e) => { + if fail_on_error { + assert!( + e.to_string().contains("ARITHMETIC_OVERFLOW"), + "Error message did not match. Actual message: {e}" + ); + Ok(()) + } else { + panic!("Didn't expect error, but got: {e:?}") + } + } + _ => unreachable!(), + } + }); + } + + #[test] + fn test_abs_i32_array() { + with_fail_on_error(|fail_on_error| { + let input = Int32Array::from(vec![Some(-1), Some(i32::MIN), Some(i32::MAX), None]); + let args = ColumnarValue::Array(Arc::new(input)); + let expected = Int32Array::from(vec![Some(1), Some(i32::MIN), Some(i32::MAX), None]); + let fail_on_error_arg = + ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))); + + match abs(&[args, fail_on_error_arg]) { + Ok(ColumnarValue::Array(result)) => { + let actual = as_int32_array(&result)?; + assert_eq!(actual, &expected); + Ok(()) + } + Err(e) => { + if fail_on_error { + assert!( + e.to_string().contains("ARITHMETIC_OVERFLOW"), + "Error message did not match. 
Actual message: {e}" + ); + Ok(()) + } else { + panic!("Didn't expect error, but got: {e:?}") + } + } + _ => unreachable!(), + } + }); + } + + #[test] + fn test_abs_i64_array() { + with_fail_on_error(|fail_on_error| { + let input = Int64Array::from(vec![Some(-1), Some(i64::MIN), Some(i64::MAX), None]); + let args = ColumnarValue::Array(Arc::new(input)); + let expected = Int64Array::from(vec![Some(1), Some(i64::MIN), Some(i64::MAX), None]); + let fail_on_error_arg = + ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))); + + match abs(&[args, fail_on_error_arg]) { + Ok(ColumnarValue::Array(result)) => { + let actual = as_int64_array(&result)?; + assert_eq!(actual, &expected); + Ok(()) + } + Err(e) => { + if fail_on_error { + assert!( + e.to_string().contains("ARITHMETIC_OVERFLOW"), + "Error message did not match. Actual message: {e}" + ); + Ok(()) + } else { + panic!("Didn't expect error, but got: {e:?}") + } + } + _ => unreachable!(), + } + }); + } + + #[test] + fn test_abs_f32_array() { + with_fail_on_error(|fail_on_error| { + let input = Float32Array::from(vec![Some(-1f32), Some(f32::MIN), Some(f32::MAX), None]); + let args = ColumnarValue::Array(Arc::new(input)); + let expected = + Float32Array::from(vec![Some(1f32), Some(f32::MAX), Some(f32::MAX), None]); + let fail_on_error_arg = + ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))); + + match abs(&[args, fail_on_error_arg]) { + Ok(ColumnarValue::Array(result)) => { + let actual = as_float32_array(&result)?; + assert_eq!(actual, &expected); + Ok(()) + } + Err(e) => { + if fail_on_error { + assert!( + e.to_string().contains("ARITHMETIC_OVERFLOW"), + "Error message did not match. 
Actual message: {e}" + ); + Ok(()) + } else { + panic!("Didn't expect error, but got: {e:?}") + } + } + _ => unreachable!(), + } + }); + } + + #[test] + fn test_abs_f64_array() { + with_fail_on_error(|fail_on_error| { + let input = Float64Array::from(vec![Some(-1f64), Some(f64::MIN), Some(f64::MAX), None]); + let args = ColumnarValue::Array(Arc::new(input)); + let expected = + Float64Array::from(vec![Some(1f64), Some(f64::MAX), Some(f64::MAX), None]); + let fail_on_error_arg = + ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))); + + match abs(&[args, fail_on_error_arg]) { + Ok(ColumnarValue::Array(result)) => { + let actual = as_float64_array(&result)?; + assert_eq!(actual, &expected); + Ok(()) + } + Err(e) => { + if fail_on_error { + assert!( + e.to_string().contains("ARITHMETIC_OVERFLOW"), + "Error message did not match. Actual message: {e}" + ); + Ok(()) + } else { + panic!("Didn't expect error, but got: {e:?}") + } + } + _ => unreachable!(), + } + }); + } + + #[test] + fn test_abs_decimal128_array() { + with_fail_on_error(|fail_on_error| { + let input = Decimal128Array::from(vec![Some(i128::MIN), None]) + .with_precision_and_scale(38, 37)?; + let args = ColumnarValue::Array(Arc::new(input)); + let expected = Decimal128Array::from(vec![Some(i128::MIN), None]) + .with_precision_and_scale(38, 37)?; + let fail_on_error_arg = + ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))); + + match abs(&[args, fail_on_error_arg]) { + Ok(ColumnarValue::Array(result)) => { + let actual = as_decimal128_array(&result)?; + assert_eq!(actual, &expected); + Ok(()) + } + Err(e) => { + if fail_on_error { + assert!( + e.to_string().contains("ARITHMETIC_OVERFLOW"), + "Error message did not match. 
Actual message: {e}" + ); + Ok(()) + } else { + panic!("Didn't expect error, but got: {e:?}") + } + } + _ => unreachable!(), + } + }); + } + + #[test] + fn test_abs_decimal256_array() { + with_fail_on_error(|fail_on_error| { + let input = Decimal256Array::from(vec![Some(i256::MIN), None]) + .with_precision_and_scale(5, 2)?; + let args = ColumnarValue::Array(Arc::new(input)); + let expected = Decimal256Array::from(vec![Some(i256::MIN), None]) + .with_precision_and_scale(5, 2)?; + let fail_on_error_arg = + ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))); + + match abs(&[args, fail_on_error_arg]) { + Ok(ColumnarValue::Array(result)) => { + let actual = as_decimal256_array(&result)?; + assert_eq!(actual, &expected); + Ok(()) + } + Err(e) => { + if fail_on_error { + assert!( + e.to_string().contains("ARITHMETIC_OVERFLOW"), + "Error message did not match. Actual message: {e}" + ); + Ok(()) + } else { + panic!("Didn't expect error, but got: {e:?}") + } + } + _ => unreachable!(), + } + }); + } + + #[test] + fn test_abs_u64_array() { + with_fail_on_error(|fail_on_error| { + let input = UInt64Array::from(vec![Some(u64::MIN), Some(u64::MAX), None]); + let args = ColumnarValue::Array(Arc::new(input)); + let expected = UInt64Array::from(vec![Some(u64::MIN), Some(u64::MAX), None]); + let fail_on_error_arg = + ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))); + + match abs(&[args, fail_on_error_arg]) { + Ok(ColumnarValue::Array(result)) => { + let actual = as_uint64_array(&result)?; + assert_eq!(actual, &expected); + Ok(()) + } + Err(e) => { + if fail_on_error { + assert!( + e.to_string().contains("ARITHMETIC_OVERFLOW"), + "Error message did not match. 
Actual message: {e}" + ); + Ok(()) + } else { + panic!("Didn't expect error, but got: {e:?}") + } + } + _ => unreachable!(), + } + }); + } +} diff --git a/native/spark-expr/src/math_funcs/mod.rs b/native/spark-expr/src/math_funcs/mod.rs index 873b290ebd..7df87eb9f2 100644 --- a/native/spark-expr/src/math_funcs/mod.rs +++ b/native/spark-expr/src/math_funcs/mod.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +pub(crate) mod abs; mod ceil; pub(crate) mod checked_arithmetic; mod div; From eb94cb6027a994e669ca4b517f0c98db999d13cd Mon Sep 17 00:00:00 2001 From: hsiang-c Date: Thu, 16 Oct 2025 14:34:49 -0700 Subject: [PATCH 4/8] Enable ABS tests in legacy/ANSI mode --- .../apache/comet/CometExpressionSuite.scala | 27 ++++++++++--------- .../org/apache/spark/sql/CometTestBase.scala | 8 +++--- 2 files changed, 18 insertions(+), 17 deletions(-) diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index 9085c0fa29..f64b6d22b5 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -1385,23 +1385,25 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { testDoubleScalarExpr("expm1") } - // https://github.com/apache/datafusion-comet/issues/666 - ignore("abs") { - Seq(true, false).foreach { dictionaryEnabled => - withTempDir { dir => - val path = new Path(dir.toURI.toString, "test.parquet") - makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled = dictionaryEnabled, 100) - withParquetTable(path.toString, "tbl") { - Seq(2, 3, 4, 5, 6, 7, 9, 10, 11, 12, 15, 16, 17).foreach { col => - checkSparkAnswerAndOperator(s"SELECT abs(_${col}) FROM tbl") + test("abs") { + Seq(true, false).foreach { ansi_enabled => + Seq(true, false).foreach { dictionaryEnabled => + withSQLConf(SQLConf.ANSI_ENABLED.key -> 
ansi_enabled.toString) { + withTempDir { dir => + val path = new Path(dir.toURI.toString, "test.parquet") + makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled = dictionaryEnabled, 100) + withParquetTable(path.toString, "tbl") { + Seq(2, 3, 4, 5, 6, 7, 9, 10, 11, 12, 15, 16, 17).foreach { col => + checkSparkAnswerAndOperator(s"SELECT abs(_${col}) FROM tbl") + } + } } } } } } - // https://github.com/apache/datafusion-comet/issues/666 - ignore("abs Overflow ansi mode") { + test("abs Overflow ANSI mode") { def testAbsAnsiOverflow[T <: Product: ClassTag: TypeTag](data: Seq[T]): Unit = { withParquetTable(data, "tbl") { @@ -1434,8 +1436,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } - // https://github.com/apache/datafusion-comet/issues/666 - ignore("abs Overflow legacy mode") { + test("abs Overflow legacy mode") { def testAbsLegacyOverflow[T <: Product: ClassTag: TypeTag](data: Seq[T]): Unit = { withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") { diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala index 844bd07f3b..900b8a44f5 100644 --- a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala +++ b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala @@ -504,10 +504,10 @@ abstract class CometTestBase | optional float _6; | optional double _7; | optional binary _8(UTF8); - | optional int32 _9(UINT_8); - | optional int32 _10(UINT_16); - | optional int32 _11(UINT_32); - | optional int64 _12(UINT_64); + | optional int32 _9(INT_8); + | optional int32 _10(INT_16); + | optional int32 _11(INT_32); + | optional int64 _12(INT_64); | optional binary _13(ENUM); | optional FIXED_LEN_BYTE_ARRAY(3) _14; | optional int32 _15(DECIMAL(5, 2)); From ce1c1c20684d2a4a0386e8bc0686eca59392d2f2 Mon Sep 17 00:00:00 2001 From: hsiang-c Date: Wed, 15 Oct 2025 17:02:14 -0700 Subject: [PATCH 5/8] Fix bit position b/c schema change in CometTestBase --- 
.../scala/org/apache/comet/CometBitwiseExpressionSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spark/src/test/scala/org/apache/comet/CometBitwiseExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometBitwiseExpressionSuite.scala index d89e81b0fd..cf7eb02bff 100644 --- a/spark/src/test/scala/org/apache/comet/CometBitwiseExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometBitwiseExpressionSuite.scala @@ -134,7 +134,7 @@ class CometBitwiseExpressionSuite extends CometTestBase with AdaptiveSparkPlanHe s"bit_get(_3, $shortBitPosition)", s"bit_get(_4, $intBitPosition)", s"bit_get(_5, $longBitPosition)", - s"bit_get(_11, $longBitPosition)")) + s"bit_get(_11, $intBitPosition)")) } } } From 9cb2f6ca208699958459f0cce30a77deecd1cd95 Mon Sep 17 00:00:00 2001 From: hsiang-c Date: Thu, 16 Oct 2025 14:21:58 -0700 Subject: [PATCH 6/8] Update docs --- docs/source/user-guide/latest/configs.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/user-guide/latest/configs.md b/docs/source/user-guide/latest/configs.md index a299a75738..1b593e4821 100644 --- a/docs/source/user-guide/latest/configs.md +++ b/docs/source/user-guide/latest/configs.md @@ -150,6 +150,7 @@ These settings can be used to determine which parts of the plan are accelerated | Config | Description | Default Value | |--------|-------------|---------------| +| `spark.comet.expression.Abs.enabled` | Enable Comet acceleration for `Abs` | true | | `spark.comet.expression.Acos.enabled` | Enable Comet acceleration for `Acos` | true | | `spark.comet.expression.Add.enabled` | Enable Comet acceleration for `Add` | true | | `spark.comet.expression.Alias.enabled` | Enable Comet acceleration for `Alias` | true | From 6016f6993ee151ba8d1fe336b096422f8ce552f8 Mon Sep 17 00:00:00 2001 From: hsiang-c Date: Fri, 17 Oct 2025 13:32:29 -0700 Subject: [PATCH 7/8] Fix style --- native/spark-expr/src/math_funcs/abs.rs | 12 ++++++++++-- 1 file changed, 10 insertions(+), 
2 deletions(-) diff --git a/native/spark-expr/src/math_funcs/abs.rs b/native/spark-expr/src/math_funcs/abs.rs index 54d21a3bb5..78148995dd 100644 --- a/native/spark-expr/src/math_funcs/abs.rs +++ b/native/spark-expr/src/math_funcs/abs.rs @@ -460,7 +460,11 @@ mod tests { let fail_on_error_arg = ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))); match abs(&[args, fail_on_error_arg]) { - Ok(ColumnarValue::Scalar(ScalarValue::Decimal128(Some(result), precision, scale))) => { + Ok(ColumnarValue::Scalar(ScalarValue::Decimal128( + Some(result), + precision, + scale, + ))) => { assert_eq!(result, i128::MIN); assert_eq!(precision, 18); assert_eq!(scale, 10); @@ -489,7 +493,11 @@ mod tests { let fail_on_error_arg = ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))); match abs(&[args, fail_on_error_arg]) { - Ok(ColumnarValue::Scalar(ScalarValue::Decimal256(Some(result), precision, scale))) => { + Ok(ColumnarValue::Scalar(ScalarValue::Decimal256( + Some(result), + precision, + scale, + ))) => { assert_eq!(result, i256::MIN); assert_eq!(precision, 10); assert_eq!(scale, 2); From 1ddff9831c6e0b750289723e24e89a08b7c6a39e Mon Sep 17 00:00:00 2001 From: hsiang-c Date: Fri, 17 Oct 2025 13:48:23 -0700 Subject: [PATCH 8/8] Update doc --- docs/source/user-guide/latest/expressions.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/user-guide/latest/expressions.md b/docs/source/user-guide/latest/expressions.md index 3ccead03a1..809e69d2f8 100644 --- a/docs/source/user-guide/latest/expressions.md +++ b/docs/source/user-guide/latest/expressions.md @@ -118,6 +118,7 @@ incompatible expressions. | Expression | SQL | Spark-Compatible? | Compatibility Notes | |----------------|-----------|-------------------|-----------------------------------| +| Abs | `abs` | Yes | | | Acos | `acos` | Yes | | | Add | `+` | Yes | | | Asin | `asin` | Yes | |