Skip to content

feat: Add ScalarValue::{new_one,new_zero,new_ten,distance} support for Decimal128 and Decimal256 #16831

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jul 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
301 changes: 293 additions & 8 deletions datafusion/common/src/scalar/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,13 @@ use arrow::compute::kernels::numeric::{
add, add_wrapping, div, mul, mul_wrapping, rem, sub, sub_wrapping,
};
use arrow::datatypes::{
i256, ArrowDictionaryKeyType, ArrowNativeType, ArrowTimestampType, DataType,
Date32Type, Field, Float32Type, Int16Type, Int32Type, Int64Type, Int8Type,
IntervalDayTime, IntervalDayTimeType, IntervalMonthDayNano, IntervalMonthDayNanoType,
IntervalUnit, IntervalYearMonthType, TimeUnit, TimestampMicrosecondType,
TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, UInt16Type,
UInt32Type, UInt64Type, UInt8Type, UnionFields, UnionMode, DECIMAL128_MAX_PRECISION,
i256, validate_decimal_precision_and_scale, ArrowDictionaryKeyType, ArrowNativeType,
ArrowTimestampType, DataType, Date32Type, Decimal128Type, Decimal256Type, Field,
Float32Type, Int16Type, Int32Type, Int64Type, Int8Type, IntervalDayTime,
IntervalDayTimeType, IntervalMonthDayNano, IntervalMonthDayNanoType, IntervalUnit,
IntervalYearMonthType, TimeUnit, TimestampMicrosecondType, TimestampMillisecondType,
TimestampNanosecondType, TimestampSecondType, UInt16Type, UInt32Type, UInt64Type,
UInt8Type, UnionFields, UnionMode, DECIMAL128_MAX_PRECISION,
};
use arrow::util::display::{array_value_to_string, ArrayFormatter, FormatOptions};
use chrono::{Duration, NaiveDate};
Expand Down Expand Up @@ -1382,6 +1383,34 @@ impl ScalarValue {
DataType::Float16 => ScalarValue::Float16(Some(f16::from_f32(1.0))),
DataType::Float32 => ScalarValue::Float32(Some(1.0)),
DataType::Float64 => ScalarValue::Float64(Some(1.0)),
DataType::Decimal128(precision, scale) => {
validate_decimal_precision_and_scale::<Decimal128Type>(
*precision, *scale,
)?;
if *scale < 0 {
return _internal_err!("Negative scale is not supported");
}
match i128::from(10).checked_pow(*scale as u32) {
Some(value) => {
ScalarValue::Decimal128(Some(value), *precision, *scale)
}
None => return _internal_err!("Unsupported scale {scale}"),
}
}
DataType::Decimal256(precision, scale) => {
validate_decimal_precision_and_scale::<Decimal256Type>(
*precision, *scale,
)?;
if *scale < 0 {
return _internal_err!("Negative scale is not supported");
}
match i256::from(10).checked_pow(*scale as u32) {
Some(value) => {
ScalarValue::Decimal256(Some(value), *precision, *scale)
}
None => return _internal_err!("Unsupported scale {scale}"),
}
}
_ => {
return _not_impl_err!(
"Can't create an one scalar from data_type \"{datatype:?}\""
Expand All @@ -1400,6 +1429,34 @@ impl ScalarValue {
DataType::Float16 => ScalarValue::Float16(Some(f16::from_f32(-1.0))),
DataType::Float32 => ScalarValue::Float32(Some(-1.0)),
DataType::Float64 => ScalarValue::Float64(Some(-1.0)),
DataType::Decimal128(precision, scale) => {
validate_decimal_precision_and_scale::<Decimal128Type>(
*precision, *scale,
)?;
if *scale < 0 {
return _internal_err!("Negative scale is not supported");
}
match i128::from(10).checked_pow(*scale as u32) {
Some(value) => {
ScalarValue::Decimal128(Some(-value), *precision, *scale)
}
None => return _internal_err!("Unsupported scale {scale}"),
}
}
DataType::Decimal256(precision, scale) => {
validate_decimal_precision_and_scale::<Decimal256Type>(
*precision, *scale,
)?;
if *scale < 0 {
return _internal_err!("Negative scale is not supported");
}
match i256::from(10).checked_pow(*scale as u32) {
Some(value) => {
ScalarValue::Decimal256(Some(-value), *precision, *scale)
}
None => return _internal_err!("Unsupported scale {scale}"),
}
}
_ => {
return _not_impl_err!(
"Can't create a negative one scalar from data_type \"{datatype:?}\""
Expand All @@ -1421,6 +1478,38 @@ impl ScalarValue {
DataType::Float16 => ScalarValue::Float16(Some(f16::from_f32(10.0))),
DataType::Float32 => ScalarValue::Float32(Some(10.0)),
DataType::Float64 => ScalarValue::Float64(Some(10.0)),
DataType::Decimal128(precision, scale) => {
if let Err(err) = validate_decimal_precision_and_scale::<Decimal128Type>(
*precision, *scale,
) {
return _internal_err!("Invalid precision and scale {err}");
}
if *scale <= 0 {
return _internal_err!("Negative scale is not supported");
}
match i128::from(10).checked_pow((*scale + 1) as u32) {
Some(value) => {
ScalarValue::Decimal128(Some(value), *precision, *scale)
}
None => return _internal_err!("Unsupported scale {scale}"),
}
}
DataType::Decimal256(precision, scale) => {
if let Err(err) = validate_decimal_precision_and_scale::<Decimal256Type>(
*precision, *scale,
) {
return _internal_err!("Invalid precision and scale {err}");
}
if *scale <= 0 {
return _internal_err!("Negative scale is not supported");
}
match i256::from(10).checked_pow((*scale + 1) as u32) {
Some(value) => {
ScalarValue::Decimal256(Some(value), *precision, *scale)
}
None => return _internal_err!("Unsupported scale {scale}"),
}
}
_ => {
return _not_impl_err!(
"Can't create a ten scalar from data_type \"{datatype:?}\""
Expand Down Expand Up @@ -1790,6 +1879,26 @@ impl ScalarValue {
(Self::Float64(Some(l)), Self::Float64(Some(r))) => {
Some((l - r).abs().round() as _)
}
(
Self::Decimal128(Some(l), lprecision, lscale),
Self::Decimal128(Some(r), rprecision, rscale),
) => {
if lprecision == rprecision && lscale == rscale {
l.checked_sub(*r)?.checked_abs()?.to_usize()
} else {
None
}
}
(
Self::Decimal256(Some(l), lprecision, lscale),
Self::Decimal256(Some(r), rprecision, rscale),
) => {
if lprecision == rprecision && lscale == rscale {
l.checked_sub(*r)?.checked_abs()?.to_usize()
} else {
None
}
}
_ => None,
}
}
Expand Down Expand Up @@ -4160,7 +4269,9 @@ mod tests {
};
use arrow::buffer::{Buffer, OffsetBuffer};
use arrow::compute::{is_null, kernels};
use arrow::datatypes::{ArrowNumericType, Fields, Float64Type};
use arrow::datatypes::{
ArrowNumericType, Fields, Float64Type, DECIMAL256_MAX_PRECISION,
};
use arrow::error::ArrowError;
use arrow::util::pretty::pretty_format_columns;
use chrono::NaiveDate;
Expand Down Expand Up @@ -4896,6 +5007,116 @@ mod tests {
Ok(())
}

/// `ScalarValue::new_one` on `Decimal128`: the result is the unscaled integer
/// `10^scale`, and unsupported precision/scale inputs are reported as errors.
#[test]
fn test_new_one_decimal128() {
    // Shorthand for requesting the "one" scalar of Decimal128(precision, scale).
    let one = |precision: u8, scale: i8| {
        ScalarValue::new_one(&DataType::Decimal128(precision, scale))
    };

    // (precision, scale, expected unscaled value). The last row uses a wider
    // precision with the same scale and expects the same stored value.
    let valid = [(5, 0, 1_i128), (5, 1, 10), (5, 2, 100), (7, 2, 100)];
    for (precision, scale, unscaled) in valid {
        assert_eq!(
            one(precision, scale).unwrap(),
            ScalarValue::Decimal128(Some(unscaled), precision, scale)
        );
    }

    // No negative scale
    assert!(one(5, -1).is_err());
    // Invalid precision/scale combinations
    assert!(one(0, 2).is_err());
    assert!(one(5, 7).is_err());
}

/// `ScalarValue::new_one` on `Decimal256`: the result is the unscaled integer
/// `10^scale`, and unsupported precision/scale inputs are reported as errors.
#[test]
fn test_new_one_decimal256() {
    // Shorthand for requesting the "one" scalar of Decimal256(precision, scale).
    let one = |precision: u8, scale: i8| {
        ScalarValue::new_one(&DataType::Decimal256(precision, scale))
    };

    // (precision, scale, expected unscaled value). The last row uses a wider
    // precision with the same scale and expects the same stored value.
    let valid = [
        (5, 0, i256::from(1)),
        (5, 1, i256::from(10)),
        (5, 2, i256::from(100)),
        (7, 2, i256::from(100)),
    ];
    for (precision, scale, unscaled) in valid {
        assert_eq!(
            one(precision, scale).unwrap(),
            ScalarValue::Decimal256(Some(unscaled), precision, scale)
        );
    }

    // No negative scale
    assert!(one(5, -1).is_err());
    // Invalid precision/scale combinations
    assert!(one(0, 2).is_err());
    assert!(one(5, 7).is_err());
}

/// `ScalarValue::new_ten` on `Decimal128`: the result is the unscaled integer
/// `10^(scale + 1)`, and zero/negative scales or invalid precision/scale
/// inputs are reported as errors.
#[test]
fn test_new_ten_decimal128() {
    // Shorthand for requesting the "ten" scalar of Decimal128(precision, scale).
    let ten = |precision: u8, scale: i8| {
        ScalarValue::new_ten(&DataType::Decimal128(precision, scale))
    };

    // (precision, scale, expected unscaled value). The last row uses a wider
    // precision with the same scale and expects the same stored value.
    let valid = [(5, 1, 100_i128), (5, 2, 1000), (7, 2, 1000)];
    for (precision, scale, unscaled) in valid {
        assert_eq!(
            ten(precision, scale).unwrap(),
            ScalarValue::Decimal128(Some(unscaled), precision, scale)
        );
    }

    // No negative or zero scale
    assert!(ten(5, 0).is_err());
    assert!(ten(5, -1).is_err());
    // Invalid precision/scale combinations
    assert!(ten(0, 2).is_err());
    assert!(ten(5, 7).is_err());
}

/// `ScalarValue::new_ten` on `Decimal256`: the result is the unscaled integer
/// `10^(scale + 1)`, and zero/negative scales or invalid precision/scale
/// inputs are reported as errors.
#[test]
fn test_new_ten_decimal256() {
    // Shorthand for requesting the "ten" scalar of Decimal256(precision, scale).
    let ten = |precision: u8, scale: i8| {
        ScalarValue::new_ten(&DataType::Decimal256(precision, scale))
    };

    // (precision, scale, expected unscaled value). The last row uses a wider
    // precision with the same scale and expects the same stored value.
    let valid = [
        (5, 1, i256::from(100)),
        (5, 2, i256::from(1000)),
        (7, 2, i256::from(1000)),
    ];
    for (precision, scale, unscaled) in valid {
        assert_eq!(
            ten(precision, scale).unwrap(),
            ScalarValue::Decimal256(Some(unscaled), precision, scale)
        );
    }

    // No negative or zero scale
    assert!(ten(5, 0).is_err());
    assert!(ten(5, -1).is_err());
    // Invalid precision/scale combinations
    assert!(ten(0, 2).is_err());
    assert!(ten(5, 7).is_err());
}

/// `ScalarValue::new_negative_one` on `Decimal128`: the result is the unscaled
/// integer `-(10^scale)`, and unsupported precision/scale inputs are reported
/// as errors, mirroring the coverage of the `new_one` tests.
#[test]
fn test_new_negative_one_decimal128() {
    assert_eq!(
        ScalarValue::new_negative_one(&DataType::Decimal128(5, 0)).unwrap(),
        ScalarValue::Decimal128(Some(-1), 5, 0)
    );
    assert_eq!(
        ScalarValue::new_negative_one(&DataType::Decimal128(5, 2)).unwrap(),
        ScalarValue::Decimal128(Some(-100), 5, 2)
    );
    // More precision
    assert_eq!(
        ScalarValue::new_negative_one(&DataType::Decimal128(7, 2)).unwrap(),
        ScalarValue::Decimal128(Some(-100), 7, 2)
    );
    // No negative scale
    assert!(ScalarValue::new_negative_one(&DataType::Decimal128(5, -1)).is_err());
    // Invalid precision/scale combinations
    assert!(ScalarValue::new_negative_one(&DataType::Decimal128(0, 2)).is_err());
    assert!(ScalarValue::new_negative_one(&DataType::Decimal128(5, 7)).is_err());
}

#[test]
fn test_list_partial_cmp() {
let a =
Expand Down Expand Up @@ -6946,13 +7167,51 @@ mod tests {
ScalarValue::Float64(Some(-9.9)),
5,
),
(
ScalarValue::Decimal128(Some(10), 1, 0),
ScalarValue::Decimal128(Some(5), 1, 0),
5,
),
(
ScalarValue::Decimal128(Some(5), 1, 0),
ScalarValue::Decimal128(Some(10), 1, 0),
5,
),
(
ScalarValue::Decimal256(Some(10.into()), 1, 0),
ScalarValue::Decimal256(Some(5.into()), 1, 0),
5,
),
(
ScalarValue::Decimal256(Some(5.into()), 1, 0),
ScalarValue::Decimal256(Some(10.into()), 1, 0),
5,
),
];
for (lhs, rhs, expected) in cases.iter() {
let distance = lhs.distance(rhs).unwrap();
assert_eq!(distance, *expected);
}
}

/// `distance` yields `None` when the difference between two decimals of the
/// same precision and scale cannot be represented: here each pair runs from
/// the integer type's maximum to its negated maximum, so the checked
/// subtraction overflows.
#[test]
fn test_distance_none() {
    // Constructors for scale-0 decimals at the maximum precision of each width.
    let dec128 =
        |value: i128| ScalarValue::Decimal128(Some(value), DECIMAL128_MAX_PRECISION, 0);
    let dec256 =
        |value: i256| ScalarValue::Decimal256(Some(value), DECIMAL256_MAX_PRECISION, 0);

    let cases = [
        (dec128(i128::MAX), dec128(-i128::MAX)),
        (dec256(i256::MAX), dec256(-i256::MAX)),
    ];
    for (lhs, rhs) in &cases {
        assert!(lhs.distance(rhs).is_none(), "{lhs} vs {rhs}");
    }
}

#[test]
fn test_scalar_distance_invalid() {
let cases = [
Expand Down Expand Up @@ -6994,7 +7253,33 @@ mod tests {
(ScalarValue::Date64(Some(0)), ScalarValue::Date64(Some(1))),
(
ScalarValue::Decimal128(Some(123), 5, 5),
ScalarValue::Decimal128(Some(120), 5, 5),
ScalarValue::Decimal128(Some(120), 5, 3),
),
(
ScalarValue::Decimal128(Some(123), 5, 5),
ScalarValue::Decimal128(Some(120), 3, 5),
),
(
ScalarValue::Decimal256(Some(123.into()), 5, 5),
ScalarValue::Decimal256(Some(120.into()), 3, 5),
),
// Distance 2 * 2^50 is larger than usize
(
ScalarValue::Decimal256(
Some(i256::from_parts(0, 2_i64.pow(50).into())),
1,
0,
),
ScalarValue::Decimal256(
Some(i256::from_parts(0, (-(2_i64).pow(50)).into())),
1,
0,
),
),
// Distance overflow
(
ScalarValue::Decimal256(Some(i256::from_parts(0, i128::MAX)), 1, 0),
ScalarValue::Decimal256(Some(i256::from_parts(0, -i128::MAX)), 1, 0),
),
];
for (lhs, rhs) in cases {
Expand Down
Loading