diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs index 8bfcc8f76e69..821840d08f99 100644 --- a/datafusion/common/src/scalar/mod.rs +++ b/datafusion/common/src/scalar/mod.rs @@ -73,12 +73,13 @@ use arrow::compute::kernels::numeric::{ add, add_wrapping, div, mul, mul_wrapping, rem, sub, sub_wrapping, }; use arrow::datatypes::{ - i256, ArrowDictionaryKeyType, ArrowNativeType, ArrowTimestampType, DataType, - Date32Type, Field, Float32Type, Int16Type, Int32Type, Int64Type, Int8Type, - IntervalDayTime, IntervalDayTimeType, IntervalMonthDayNano, IntervalMonthDayNanoType, - IntervalUnit, IntervalYearMonthType, TimeUnit, TimestampMicrosecondType, - TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, UInt16Type, - UInt32Type, UInt64Type, UInt8Type, UnionFields, UnionMode, DECIMAL128_MAX_PRECISION, + i256, validate_decimal_precision_and_scale, ArrowDictionaryKeyType, ArrowNativeType, + ArrowTimestampType, DataType, Date32Type, Decimal128Type, Decimal256Type, Field, + Float32Type, Int16Type, Int32Type, Int64Type, Int8Type, IntervalDayTime, + IntervalDayTimeType, IntervalMonthDayNano, IntervalMonthDayNanoType, IntervalUnit, + IntervalYearMonthType, TimeUnit, TimestampMicrosecondType, TimestampMillisecondType, + TimestampNanosecondType, TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, + UInt8Type, UnionFields, UnionMode, DECIMAL128_MAX_PRECISION, }; use arrow::util::display::{array_value_to_string, ArrayFormatter, FormatOptions}; use chrono::{Duration, NaiveDate}; @@ -1382,6 +1383,34 @@ impl ScalarValue { DataType::Float16 => ScalarValue::Float16(Some(f16::from_f32(1.0))), DataType::Float32 => ScalarValue::Float32(Some(1.0)), DataType::Float64 => ScalarValue::Float64(Some(1.0)), + DataType::Decimal128(precision, scale) => { + validate_decimal_precision_and_scale::( + *precision, *scale, + )?; + if *scale < 0 { + return _internal_err!("Negative scale is not supported"); + } + match i128::from(10).checked_pow(*scale as u32) { + Some(value) => { + ScalarValue::Decimal128(Some(value), *precision, *scale) + } + None => return _internal_err!("Unsupported scale {scale}"), + } + } + DataType::Decimal256(precision, scale) => { + validate_decimal_precision_and_scale::( + *precision, *scale, + )?; + if *scale < 0 { + return _internal_err!("Negative scale is not supported"); + } + match i256::from(10).checked_pow(*scale as u32) { + Some(value) => { + ScalarValue::Decimal256(Some(value), *precision, *scale) + } + None => return _internal_err!("Unsupported scale {scale}"), + } + } _ => { return _not_impl_err!( "Can't create an one scalar from data_type \"{datatype:?}\"" @@ -1400,6 +1429,34 @@ impl ScalarValue { DataType::Float16 => ScalarValue::Float16(Some(f16::from_f32(-1.0))), DataType::Float32 => ScalarValue::Float32(Some(-1.0)), DataType::Float64 => ScalarValue::Float64(Some(-1.0)), + DataType::Decimal128(precision, scale) => { + validate_decimal_precision_and_scale::( + *precision, *scale, + )?; + if *scale < 0 { + return _internal_err!("Negative scale is not supported"); + } + match i128::from(10).checked_pow(*scale as u32) { + Some(value) => { + ScalarValue::Decimal128(Some(-value), *precision, *scale) + } + None => return _internal_err!("Unsupported scale {scale}"), + } + } + DataType::Decimal256(precision, scale) => { + validate_decimal_precision_and_scale::( + *precision, *scale, + )?; + if *scale < 0 { + return _internal_err!("Negative scale is not supported"); + } + match i256::from(10).checked_pow(*scale as u32) { + Some(value) => { + ScalarValue::Decimal256(Some(-value), *precision, *scale) + } + None => return _internal_err!("Unsupported scale {scale}"), + } + } _ => { return _not_impl_err!( "Can't create a negative one scalar from data_type \"{datatype:?}\"" @@ -1421,6 +1478,38 @@ impl ScalarValue { DataType::Float16 => ScalarValue::Float16(Some(f16::from_f32(10.0))), DataType::Float32 => ScalarValue::Float32(Some(10.0)), DataType::Float64 => ScalarValue::Float64(Some(10.0)), + DataType::Decimal128(precision, scale) => { + if let Err(err) = validate_decimal_precision_and_scale::( + *precision, *scale, + ) { + return _internal_err!("Invalid precision and scale {err}"); + } + if *scale <= 0 { + return _internal_err!("Negative scale is not supported"); + } + match i128::from(10).checked_pow((*scale + 1) as u32) { + Some(value) => { + ScalarValue::Decimal128(Some(value), *precision, *scale) + } + None => return _internal_err!("Unsupported scale {scale}"), + } + } + DataType::Decimal256(precision, scale) => { + if let Err(err) = validate_decimal_precision_and_scale::( + *precision, *scale, + ) { + return _internal_err!("Invalid precision and scale {err}"); + } + if *scale <= 0 { + return _internal_err!("Negative scale is not supported"); + } + match i256::from(10).checked_pow((*scale + 1) as u32) { + Some(value) => { + ScalarValue::Decimal256(Some(value), *precision, *scale) + } + None => return _internal_err!("Unsupported scale {scale}"), + } + } _ => { return _not_impl_err!( "Can't create a ten scalar from data_type \"{datatype:?}\"" @@ -1790,6 +1879,26 @@ impl ScalarValue { (Self::Float64(Some(l)), Self::Float64(Some(r))) => { Some((l - r).abs().round() as _) } + ( + Self::Decimal128(Some(l), lprecision, lscale), + Self::Decimal128(Some(r), rprecision, rscale), + ) => { + if lprecision == rprecision && lscale == rscale { + l.checked_sub(*r)?.checked_abs()?.to_usize() + } else { + None + } + } + ( + Self::Decimal256(Some(l), lprecision, lscale), + Self::Decimal256(Some(r), rprecision, rscale), + ) => { + if lprecision == rprecision && lscale == rscale { + l.checked_sub(*r)?.checked_abs()?.to_usize() + } else { + None + } + } _ => None, } } @@ -4160,7 +4269,9 @@ mod tests { }; use arrow::buffer::{Buffer, OffsetBuffer}; use arrow::compute::{is_null, kernels}; - use arrow::datatypes::{ArrowNumericType, Fields, Float64Type}; + use arrow::datatypes::{ + ArrowNumericType, Fields, Float64Type, DECIMAL256_MAX_PRECISION, + }; use arrow::error::ArrowError; use arrow::util::pretty::pretty_format_columns; use chrono::NaiveDate; @@ -4896,6 +5007,116 @@ mod tests { Ok(()) } + #[test] + fn test_new_one_decimal128() { + assert_eq!( + ScalarValue::new_one(&DataType::Decimal128(5, 0)).unwrap(), + ScalarValue::Decimal128(Some(1), 5, 0) + ); + assert_eq!( + ScalarValue::new_one(&DataType::Decimal128(5, 1)).unwrap(), + ScalarValue::Decimal128(Some(10), 5, 1) + ); + assert_eq!( + ScalarValue::new_one(&DataType::Decimal128(5, 2)).unwrap(), + ScalarValue::Decimal128(Some(100), 5, 2) + ); + // More precision + assert_eq!( + ScalarValue::new_one(&DataType::Decimal128(7, 2)).unwrap(), + ScalarValue::Decimal128(Some(100), 7, 2) + ); + // No negative scale + assert!(ScalarValue::new_one(&DataType::Decimal128(5, -1)).is_err()); + // Invalid combination + assert!(ScalarValue::new_one(&DataType::Decimal128(0, 2)).is_err()); + assert!(ScalarValue::new_one(&DataType::Decimal128(5, 7)).is_err()); + } + + #[test] + fn test_new_one_decimal256() { + assert_eq!( + ScalarValue::new_one(&DataType::Decimal256(5, 0)).unwrap(), + ScalarValue::Decimal256(Some(1.into()), 5, 0) + ); + assert_eq!( + ScalarValue::new_one(&DataType::Decimal256(5, 1)).unwrap(), + ScalarValue::Decimal256(Some(10.into()), 5, 1) + ); + assert_eq!( + ScalarValue::new_one(&DataType::Decimal256(5, 2)).unwrap(), + ScalarValue::Decimal256(Some(100.into()), 5, 2) + ); + // More precision + assert_eq!( + ScalarValue::new_one(&DataType::Decimal256(7, 2)).unwrap(), + ScalarValue::Decimal256(Some(100.into()), 7, 2) + ); + // No negative scale + assert!(ScalarValue::new_one(&DataType::Decimal256(5, -1)).is_err()); + // Invalid combination + assert!(ScalarValue::new_one(&DataType::Decimal256(0, 2)).is_err()); + assert!(ScalarValue::new_one(&DataType::Decimal256(5, 7)).is_err()); + } + + #[test] + fn test_new_ten_decimal128() { + assert_eq!( + ScalarValue::new_ten(&DataType::Decimal128(5, 1)).unwrap(), + ScalarValue::Decimal128(Some(100), 5, 1) + ); + assert_eq!( + ScalarValue::new_ten(&DataType::Decimal128(5, 2)).unwrap(), + ScalarValue::Decimal128(Some(1000), 5, 2) + ); + // More precision + assert_eq!( + ScalarValue::new_ten(&DataType::Decimal128(7, 2)).unwrap(), + ScalarValue::Decimal128(Some(1000), 7, 2) + ); + // No negative or zero scale + assert!(ScalarValue::new_ten(&DataType::Decimal128(5, 0)).is_err()); + assert!(ScalarValue::new_ten(&DataType::Decimal128(5, -1)).is_err()); + // Invalid combination + assert!(ScalarValue::new_ten(&DataType::Decimal128(0, 2)).is_err()); + assert!(ScalarValue::new_ten(&DataType::Decimal128(5, 7)).is_err()); + } + + #[test] + fn test_new_ten_decimal256() { + assert_eq!( + ScalarValue::new_ten(&DataType::Decimal256(5, 1)).unwrap(), + ScalarValue::Decimal256(Some(100.into()), 5, 1) + ); + assert_eq!( + ScalarValue::new_ten(&DataType::Decimal256(5, 2)).unwrap(), + ScalarValue::Decimal256(Some(1000.into()), 5, 2) + ); + // More precision + assert_eq!( + ScalarValue::new_ten(&DataType::Decimal256(7, 2)).unwrap(), + ScalarValue::Decimal256(Some(1000.into()), 7, 2) + ); + // No negative or zero scale + assert!(ScalarValue::new_ten(&DataType::Decimal256(5, 0)).is_err()); + assert!(ScalarValue::new_ten(&DataType::Decimal256(5, -1)).is_err()); + // Invalid combination + assert!(ScalarValue::new_ten(&DataType::Decimal256(0, 2)).is_err()); + assert!(ScalarValue::new_ten(&DataType::Decimal256(5, 7)).is_err()); + } + + #[test] + fn test_new_negative_one_decimal128() { + assert_eq!( + ScalarValue::new_negative_one(&DataType::Decimal128(5, 0)).unwrap(), + ScalarValue::Decimal128(Some(-1), 5, 0) + ); + assert_eq!( + ScalarValue::new_negative_one(&DataType::Decimal128(5, 2)).unwrap(), + ScalarValue::Decimal128(Some(-100), 5, 2) + ); + } + #[test] fn test_list_partial_cmp() { let a = @@ -6946,6 +7167,26 @@ mod tests { ScalarValue::Float64(Some(-9.9)), 5, ), + ( + ScalarValue::Decimal128(Some(10), 1, 0), + ScalarValue::Decimal128(Some(5), 1, 0), + 5, + ), + ( + ScalarValue::Decimal128(Some(5), 1, 0), + ScalarValue::Decimal128(Some(10), 1, 0), + 5, + ), + ( + ScalarValue::Decimal256(Some(10.into()), 1, 0), + ScalarValue::Decimal256(Some(5.into()), 1, 0), + 5, + ), + ( + ScalarValue::Decimal256(Some(5.into()), 1, 0), + ScalarValue::Decimal256(Some(10.into()), 1, 0), + 5, + ), ]; for (lhs, rhs, expected) in cases.iter() { let distance = lhs.distance(rhs).unwrap(); @@ -6953,6 +7194,24 @@ mod tests { } } + #[test] + fn test_distance_none() { + let cases = [ + ( + ScalarValue::Decimal128(Some(i128::MAX), DECIMAL128_MAX_PRECISION, 0), + ScalarValue::Decimal128(Some(-i128::MAX), DECIMAL128_MAX_PRECISION, 0), + ), + ( + ScalarValue::Decimal256(Some(i256::MAX), DECIMAL256_MAX_PRECISION, 0), + ScalarValue::Decimal256(Some(-i256::MAX), DECIMAL256_MAX_PRECISION, 0), + ), + ]; + for (lhs, rhs) in cases.iter() { + let distance = lhs.distance(rhs); + assert!(distance.is_none(), "{lhs} vs {rhs}"); + } + } + #[test] fn test_scalar_distance_invalid() { let cases = [ @@ -6994,7 +7253,33 @@ mod tests { (ScalarValue::Date64(Some(0)), ScalarValue::Date64(Some(1))), ( ScalarValue::Decimal128(Some(123), 5, 5), - ScalarValue::Decimal128(Some(120), 5, 5), + ScalarValue::Decimal128(Some(120), 5, 3), + ), + ( + ScalarValue::Decimal128(Some(123), 5, 5), + ScalarValue::Decimal128(Some(120), 3, 5), + ), + ( + ScalarValue::Decimal256(Some(123.into()), 5, 5), + ScalarValue::Decimal256(Some(120.into()), 3, 5), + ), + // Distance 2 * 2^50 is larger than usize + ( + ScalarValue::Decimal256( + Some(i256::from_parts(0, 2_i64.pow(50).into())), + 1, + 0, + ), + ScalarValue::Decimal256( + Some(i256::from_parts(0, (-(2_i64).pow(50)).into())), + 1, + 0, + ), + ), + // Distance overflow + ( + ScalarValue::Decimal256(Some(i256::from_parts(0, i128::MAX)), 1, 0), + ScalarValue::Decimal256(Some(i256::from_parts(0, -i128::MAX)), 1, 0), ), ]; for (lhs, rhs) in cases { diff --git a/datafusion/optimizer/src/simplify_expressions/utils.rs b/datafusion/optimizer/src/simplify_expressions/utils.rs index 4df0e125eb18..2f7dadcebaa4 100644 --- a/datafusion/optimizer/src/simplify_expressions/utils.rs +++ b/datafusion/optimizer/src/simplify_expressions/utils.rs @@ -17,6 +17,7 @@ //! Utility functions for expression simplification +use arrow::datatypes::i256; use datafusion_common::{internal_err, Result, ScalarValue}; use datafusion_expr::{ expr::{Between, BinaryExpr, InList}, @@ -150,6 +151,11 @@ pub fn is_zero(s: &Expr) -> bool { Expr::Literal(ScalarValue::Float32(Some(v)), _) if *v == 0. => true, Expr::Literal(ScalarValue::Float64(Some(v)), _) if *v == 0. => true, Expr::Literal(ScalarValue::Decimal128(Some(v), _p, _s), _) if *v == 0 => true, + Expr::Literal(ScalarValue::Decimal256(Some(v), _p, _s), _) + if *v == i256::ZERO => + { + true + } _ => false, } } @@ -173,6 +179,13 @@ pub fn is_one(s: &Expr) -> bool { .map(|x| x == v) .unwrap_or_default() } + Expr::Literal(ScalarValue::Decimal256(Some(v), _p, s), _) => { + *s >= 0 + && match i256::from(10).checked_pow(*s as u32) { + Some(res) => res == *v, + None => false, + } + } _ => false, } } @@ -365,3 +378,78 @@ pub fn distribute_negation(expr: Expr) -> Expr { _ => Expr::Negative(Box::new(expr)), } } + +#[cfg(test)] +mod tests { + use super::{is_one, is_zero}; + use arrow::datatypes::i256; + use datafusion_common::ScalarValue; + use datafusion_expr::lit; + + #[test] + fn test_is_zero() { + assert!(is_zero(&lit(ScalarValue::Int8(Some(0))))); + assert!(is_zero(&lit(ScalarValue::Float32(Some(0.0))))); + assert!(is_zero(&lit(ScalarValue::Decimal128( + Some(i128::from(0)), + 9, + 0 + )))); + assert!(is_zero(&lit(ScalarValue::Decimal128( + Some(i128::from(0)), + 9, + 5 + )))); + assert!(is_zero(&lit(ScalarValue::Decimal256( + Some(i256::ZERO), + 9, + 0 + )))); + assert!(is_zero(&lit(ScalarValue::Decimal256( + Some(i256::ZERO), + 9, + 5 + )))); + } + + #[test] + fn test_is_one() { + assert!(is_one(&lit(ScalarValue::Int8(Some(1))))); + assert!(is_one(&lit(ScalarValue::Float32(Some(1.0))))); + assert!(is_one(&lit(ScalarValue::Decimal128( + Some(i128::from(1)), + 9, + 0 + )))); + assert!(is_one(&lit(ScalarValue::Decimal128( + Some(i128::from(10)), + 9, + 1 + )))); + assert!(is_one(&lit(ScalarValue::Decimal128( + Some(i128::from(100)), + 9, + 2 + )))); + assert!(is_one(&lit(ScalarValue::Decimal256( + Some(i256::from(1)), + 9, + 0 + )))); + assert!(is_one(&lit(ScalarValue::Decimal256( + Some(i256::from(10)), + 9, + 1 + )))); + assert!(is_one(&lit(ScalarValue::Decimal256( + Some(i256::from(100)), + 9, + 2 + )))); + assert!(!is_one(&lit(ScalarValue::Decimal256( + Some(i256::from(100)), + 9, + -1 + )))); + } +}