Skip to content

Commit 9218efd

Browse files
theirix authored and adriangb committed
feat: Add ScalarValue::{new_one,new_zero,new_ten,distance} support for Decimal128 and Decimal256 (apache#16831)
* Add missing ScalarValue impls for large decimals

  Add methods distance, new_zero, new_one, new_ten for Decimal128, Decimal256
* Support expr simplification for Decimal256
* Replace lookup table with i128::pow
* Support different scales for Decimal constructors
  - Allow to construct one and ten with different scales
  - Add tests for new_one, new_ten
  - Add test for distance
* Revert "Replace lookup table with i128::pow"

  This reverts commit ba23e8c.
* Use Arrow error directly
1 parent 5554794 commit 9218efd

File tree

2 files changed

+381
-8
lines changed
  • datafusion

2 files changed

+381
-8
lines changed

datafusion/common/src/scalar/mod.rs

Lines changed: 293 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -74,12 +74,13 @@ use arrow::compute::kernels::numeric::{
7474
add, add_wrapping, div, mul, mul_wrapping, rem, sub, sub_wrapping,
7575
};
7676
use arrow::datatypes::{
77-
i256, ArrowDictionaryKeyType, ArrowNativeType, ArrowTimestampType, DataType,
78-
Date32Type, Field, Float32Type, Int16Type, Int32Type, Int64Type, Int8Type,
79-
IntervalDayTime, IntervalDayTimeType, IntervalMonthDayNano, IntervalMonthDayNanoType,
80-
IntervalUnit, IntervalYearMonthType, TimeUnit, TimestampMicrosecondType,
81-
TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, UInt16Type,
82-
UInt32Type, UInt64Type, UInt8Type, UnionFields, UnionMode, DECIMAL128_MAX_PRECISION,
77+
i256, validate_decimal_precision_and_scale, ArrowDictionaryKeyType, ArrowNativeType,
78+
ArrowTimestampType, DataType, Date32Type, Decimal128Type, Decimal256Type, Field,
79+
Float32Type, Int16Type, Int32Type, Int64Type, Int8Type, IntervalDayTime,
80+
IntervalDayTimeType, IntervalMonthDayNano, IntervalMonthDayNanoType, IntervalUnit,
81+
IntervalYearMonthType, TimeUnit, TimestampMicrosecondType, TimestampMillisecondType,
82+
TimestampNanosecondType, TimestampSecondType, UInt16Type, UInt32Type, UInt64Type,
83+
UInt8Type, UnionFields, UnionMode, DECIMAL128_MAX_PRECISION,
8384
};
8485
use arrow::util::display::{array_value_to_string, ArrayFormatter, FormatOptions};
8586
use cache::{get_or_create_cached_key_array, get_or_create_cached_null_array};
@@ -1516,6 +1517,34 @@ impl ScalarValue {
15161517
DataType::Float16 => ScalarValue::Float16(Some(f16::from_f32(1.0))),
15171518
DataType::Float32 => ScalarValue::Float32(Some(1.0)),
15181519
DataType::Float64 => ScalarValue::Float64(Some(1.0)),
1520+
DataType::Decimal128(precision, scale) => {
1521+
validate_decimal_precision_and_scale::<Decimal128Type>(
1522+
*precision, *scale,
1523+
)?;
1524+
if *scale < 0 {
1525+
return _internal_err!("Negative scale is not supported");
1526+
}
1527+
match i128::from(10).checked_pow(*scale as u32) {
1528+
Some(value) => {
1529+
ScalarValue::Decimal128(Some(value), *precision, *scale)
1530+
}
1531+
None => return _internal_err!("Unsupported scale {scale}"),
1532+
}
1533+
}
1534+
DataType::Decimal256(precision, scale) => {
1535+
validate_decimal_precision_and_scale::<Decimal256Type>(
1536+
*precision, *scale,
1537+
)?;
1538+
if *scale < 0 {
1539+
return _internal_err!("Negative scale is not supported");
1540+
}
1541+
match i256::from(10).checked_pow(*scale as u32) {
1542+
Some(value) => {
1543+
ScalarValue::Decimal256(Some(value), *precision, *scale)
1544+
}
1545+
None => return _internal_err!("Unsupported scale {scale}"),
1546+
}
1547+
}
15191548
_ => {
15201549
return _not_impl_err!(
15211550
"Can't create an one scalar from data_type \"{datatype:?}\""
@@ -1534,6 +1563,34 @@ impl ScalarValue {
15341563
DataType::Float16 => ScalarValue::Float16(Some(f16::from_f32(-1.0))),
15351564
DataType::Float32 => ScalarValue::Float32(Some(-1.0)),
15361565
DataType::Float64 => ScalarValue::Float64(Some(-1.0)),
1566+
DataType::Decimal128(precision, scale) => {
1567+
validate_decimal_precision_and_scale::<Decimal128Type>(
1568+
*precision, *scale,
1569+
)?;
1570+
if *scale < 0 {
1571+
return _internal_err!("Negative scale is not supported");
1572+
}
1573+
match i128::from(10).checked_pow(*scale as u32) {
1574+
Some(value) => {
1575+
ScalarValue::Decimal128(Some(-value), *precision, *scale)
1576+
}
1577+
None => return _internal_err!("Unsupported scale {scale}"),
1578+
}
1579+
}
1580+
DataType::Decimal256(precision, scale) => {
1581+
validate_decimal_precision_and_scale::<Decimal256Type>(
1582+
*precision, *scale,
1583+
)?;
1584+
if *scale < 0 {
1585+
return _internal_err!("Negative scale is not supported");
1586+
}
1587+
match i256::from(10).checked_pow(*scale as u32) {
1588+
Some(value) => {
1589+
ScalarValue::Decimal256(Some(-value), *precision, *scale)
1590+
}
1591+
None => return _internal_err!("Unsupported scale {scale}"),
1592+
}
1593+
}
15371594
_ => {
15381595
return _not_impl_err!(
15391596
"Can't create a negative one scalar from data_type \"{datatype:?}\""
@@ -1555,6 +1612,38 @@ impl ScalarValue {
15551612
DataType::Float16 => ScalarValue::Float16(Some(f16::from_f32(10.0))),
15561613
DataType::Float32 => ScalarValue::Float32(Some(10.0)),
15571614
DataType::Float64 => ScalarValue::Float64(Some(10.0)),
1615+
DataType::Decimal128(precision, scale) => {
1616+
if let Err(err) = validate_decimal_precision_and_scale::<Decimal128Type>(
1617+
*precision, *scale,
1618+
) {
1619+
return _internal_err!("Invalid precision and scale {err}");
1620+
}
1621+
if *scale <= 0 {
1622+
return _internal_err!("Negative scale is not supported");
1623+
}
1624+
match i128::from(10).checked_pow((*scale + 1) as u32) {
1625+
Some(value) => {
1626+
ScalarValue::Decimal128(Some(value), *precision, *scale)
1627+
}
1628+
None => return _internal_err!("Unsupported scale {scale}"),
1629+
}
1630+
}
1631+
DataType::Decimal256(precision, scale) => {
1632+
if let Err(err) = validate_decimal_precision_and_scale::<Decimal256Type>(
1633+
*precision, *scale,
1634+
) {
1635+
return _internal_err!("Invalid precision and scale {err}");
1636+
}
1637+
if *scale <= 0 {
1638+
return _internal_err!("Negative scale is not supported");
1639+
}
1640+
match i256::from(10).checked_pow((*scale + 1) as u32) {
1641+
Some(value) => {
1642+
ScalarValue::Decimal256(Some(value), *precision, *scale)
1643+
}
1644+
None => return _internal_err!("Unsupported scale {scale}"),
1645+
}
1646+
}
15581647
_ => {
15591648
return _not_impl_err!(
15601649
"Can't create a ten scalar from data_type \"{datatype:?}\""
@@ -1924,6 +2013,26 @@ impl ScalarValue {
19242013
(Self::Float64(Some(l)), Self::Float64(Some(r))) => {
19252014
Some((l - r).abs().round() as _)
19262015
}
2016+
(
2017+
Self::Decimal128(Some(l), lprecision, lscale),
2018+
Self::Decimal128(Some(r), rprecision, rscale),
2019+
) => {
2020+
if lprecision == rprecision && lscale == rscale {
2021+
l.checked_sub(*r)?.checked_abs()?.to_usize()
2022+
} else {
2023+
None
2024+
}
2025+
}
2026+
(
2027+
Self::Decimal256(Some(l), lprecision, lscale),
2028+
Self::Decimal256(Some(r), rprecision, rscale),
2029+
) => {
2030+
if lprecision == rprecision && lscale == rscale {
2031+
l.checked_sub(*r)?.checked_abs()?.to_usize()
2032+
} else {
2033+
None
2034+
}
2035+
}
19272036
_ => None,
19282037
}
19292038
}
@@ -4489,7 +4598,9 @@ mod tests {
44894598
};
44904599
use arrow::buffer::{Buffer, OffsetBuffer};
44914600
use arrow::compute::{is_null, kernels};
4492-
use arrow::datatypes::{ArrowNumericType, Fields, Float64Type};
4601+
use arrow::datatypes::{
4602+
ArrowNumericType, Fields, Float64Type, DECIMAL256_MAX_PRECISION,
4603+
};
44934604
use arrow::error::ArrowError;
44944605
use arrow::util::pretty::pretty_format_columns;
44954606
use chrono::NaiveDate;
@@ -5225,6 +5336,116 @@ mod tests {
52255336
Ok(())
52265337
}
52275338

5339+
#[test]
5340+
fn test_new_one_decimal128() {
5341+
assert_eq!(
5342+
ScalarValue::new_one(&DataType::Decimal128(5, 0)).unwrap(),
5343+
ScalarValue::Decimal128(Some(1), 5, 0)
5344+
);
5345+
assert_eq!(
5346+
ScalarValue::new_one(&DataType::Decimal128(5, 1)).unwrap(),
5347+
ScalarValue::Decimal128(Some(10), 5, 1)
5348+
);
5349+
assert_eq!(
5350+
ScalarValue::new_one(&DataType::Decimal128(5, 2)).unwrap(),
5351+
ScalarValue::Decimal128(Some(100), 5, 2)
5352+
);
5353+
// More precision
5354+
assert_eq!(
5355+
ScalarValue::new_one(&DataType::Decimal128(7, 2)).unwrap(),
5356+
ScalarValue::Decimal128(Some(100), 7, 2)
5357+
);
5358+
// No negative scale
5359+
assert!(ScalarValue::new_one(&DataType::Decimal128(5, -1)).is_err());
5360+
// Invalid combination
5361+
assert!(ScalarValue::new_one(&DataType::Decimal128(0, 2)).is_err());
5362+
assert!(ScalarValue::new_one(&DataType::Decimal128(5, 7)).is_err());
5363+
}
5364+
5365+
#[test]
5366+
fn test_new_one_decimal256() {
5367+
assert_eq!(
5368+
ScalarValue::new_one(&DataType::Decimal256(5, 0)).unwrap(),
5369+
ScalarValue::Decimal256(Some(1.into()), 5, 0)
5370+
);
5371+
assert_eq!(
5372+
ScalarValue::new_one(&DataType::Decimal256(5, 1)).unwrap(),
5373+
ScalarValue::Decimal256(Some(10.into()), 5, 1)
5374+
);
5375+
assert_eq!(
5376+
ScalarValue::new_one(&DataType::Decimal256(5, 2)).unwrap(),
5377+
ScalarValue::Decimal256(Some(100.into()), 5, 2)
5378+
);
5379+
// More precision
5380+
assert_eq!(
5381+
ScalarValue::new_one(&DataType::Decimal256(7, 2)).unwrap(),
5382+
ScalarValue::Decimal256(Some(100.into()), 7, 2)
5383+
);
5384+
// No negative scale
5385+
assert!(ScalarValue::new_one(&DataType::Decimal256(5, -1)).is_err());
5386+
// Invalid combination
5387+
assert!(ScalarValue::new_one(&DataType::Decimal256(0, 2)).is_err());
5388+
assert!(ScalarValue::new_one(&DataType::Decimal256(5, 7)).is_err());
5389+
}
5390+
5391+
#[test]
5392+
fn test_new_ten_decimal128() {
5393+
assert_eq!(
5394+
ScalarValue::new_ten(&DataType::Decimal128(5, 1)).unwrap(),
5395+
ScalarValue::Decimal128(Some(100), 5, 1)
5396+
);
5397+
assert_eq!(
5398+
ScalarValue::new_ten(&DataType::Decimal128(5, 2)).unwrap(),
5399+
ScalarValue::Decimal128(Some(1000), 5, 2)
5400+
);
5401+
// More precision
5402+
assert_eq!(
5403+
ScalarValue::new_ten(&DataType::Decimal128(7, 2)).unwrap(),
5404+
ScalarValue::Decimal128(Some(1000), 7, 2)
5405+
);
5406+
// No negative or zero scale
5407+
assert!(ScalarValue::new_ten(&DataType::Decimal128(5, 0)).is_err());
5408+
assert!(ScalarValue::new_ten(&DataType::Decimal128(5, -1)).is_err());
5409+
// Invalid combination
5410+
assert!(ScalarValue::new_ten(&DataType::Decimal128(0, 2)).is_err());
5411+
assert!(ScalarValue::new_ten(&DataType::Decimal128(5, 7)).is_err());
5412+
}
5413+
5414+
#[test]
5415+
fn test_new_ten_decimal256() {
5416+
assert_eq!(
5417+
ScalarValue::new_ten(&DataType::Decimal256(5, 1)).unwrap(),
5418+
ScalarValue::Decimal256(Some(100.into()), 5, 1)
5419+
);
5420+
assert_eq!(
5421+
ScalarValue::new_ten(&DataType::Decimal256(5, 2)).unwrap(),
5422+
ScalarValue::Decimal256(Some(1000.into()), 5, 2)
5423+
);
5424+
// More precision
5425+
assert_eq!(
5426+
ScalarValue::new_ten(&DataType::Decimal256(7, 2)).unwrap(),
5427+
ScalarValue::Decimal256(Some(1000.into()), 7, 2)
5428+
);
5429+
// No negative or zero scale
5430+
assert!(ScalarValue::new_ten(&DataType::Decimal256(5, 0)).is_err());
5431+
assert!(ScalarValue::new_ten(&DataType::Decimal256(5, -1)).is_err());
5432+
// Invalid combination
5433+
assert!(ScalarValue::new_ten(&DataType::Decimal256(0, 2)).is_err());
5434+
assert!(ScalarValue::new_ten(&DataType::Decimal256(5, 7)).is_err());
5435+
}
5436+
5437+
#[test]
5438+
fn test_new_negative_one_decimal128() {
5439+
assert_eq!(
5440+
ScalarValue::new_negative_one(&DataType::Decimal128(5, 0)).unwrap(),
5441+
ScalarValue::Decimal128(Some(-1), 5, 0)
5442+
);
5443+
assert_eq!(
5444+
ScalarValue::new_negative_one(&DataType::Decimal128(5, 2)).unwrap(),
5445+
ScalarValue::Decimal128(Some(-100), 5, 2)
5446+
);
5447+
}
5448+
52285449
#[test]
52295450
fn test_list_partial_cmp() {
52305451
let a =
@@ -7275,13 +7496,51 @@ mod tests {
72757496
ScalarValue::Float64(Some(-9.9)),
72767497
5,
72777498
),
7499+
(
7500+
ScalarValue::Decimal128(Some(10), 1, 0),
7501+
ScalarValue::Decimal128(Some(5), 1, 0),
7502+
5,
7503+
),
7504+
(
7505+
ScalarValue::Decimal128(Some(5), 1, 0),
7506+
ScalarValue::Decimal128(Some(10), 1, 0),
7507+
5,
7508+
),
7509+
(
7510+
ScalarValue::Decimal256(Some(10.into()), 1, 0),
7511+
ScalarValue::Decimal256(Some(5.into()), 1, 0),
7512+
5,
7513+
),
7514+
(
7515+
ScalarValue::Decimal256(Some(5.into()), 1, 0),
7516+
ScalarValue::Decimal256(Some(10.into()), 1, 0),
7517+
5,
7518+
),
72787519
];
72797520
for (lhs, rhs, expected) in cases.iter() {
72807521
let distance = lhs.distance(rhs).unwrap();
72817522
assert_eq!(distance, *expected);
72827523
}
72837524
}
72847525

7526+
#[test]
7527+
fn test_distance_none() {
7528+
let cases = [
7529+
(
7530+
ScalarValue::Decimal128(Some(i128::MAX), DECIMAL128_MAX_PRECISION, 0),
7531+
ScalarValue::Decimal128(Some(-i128::MAX), DECIMAL128_MAX_PRECISION, 0),
7532+
),
7533+
(
7534+
ScalarValue::Decimal256(Some(i256::MAX), DECIMAL256_MAX_PRECISION, 0),
7535+
ScalarValue::Decimal256(Some(-i256::MAX), DECIMAL256_MAX_PRECISION, 0),
7536+
),
7537+
];
7538+
for (lhs, rhs) in cases.iter() {
7539+
let distance = lhs.distance(rhs);
7540+
assert!(distance.is_none(), "{lhs} vs {rhs}");
7541+
}
7542+
}
7543+
72857544
#[test]
72867545
fn test_scalar_distance_invalid() {
72877546
let cases = [
@@ -7323,7 +7582,33 @@ mod tests {
73237582
(ScalarValue::Date64(Some(0)), ScalarValue::Date64(Some(1))),
73247583
(
73257584
ScalarValue::Decimal128(Some(123), 5, 5),
7326-
ScalarValue::Decimal128(Some(120), 5, 5),
7585+
ScalarValue::Decimal128(Some(120), 5, 3),
7586+
),
7587+
(
7588+
ScalarValue::Decimal128(Some(123), 5, 5),
7589+
ScalarValue::Decimal128(Some(120), 3, 5),
7590+
),
7591+
(
7592+
ScalarValue::Decimal256(Some(123.into()), 5, 5),
7593+
ScalarValue::Decimal256(Some(120.into()), 3, 5),
7594+
),
7595+
// Distance 2 * 2^178 (each operand has magnitude 2^50 in the high 128-bit half) is larger than usize
7596+
(
7597+
ScalarValue::Decimal256(
7598+
Some(i256::from_parts(0, 2_i64.pow(50).into())),
7599+
1,
7600+
0,
7601+
),
7602+
ScalarValue::Decimal256(
7603+
Some(i256::from_parts(0, (-(2_i64).pow(50)).into())),
7604+
1,
7605+
0,
7606+
),
7607+
),
7608+
// Distance overflow
7609+
(
7610+
ScalarValue::Decimal256(Some(i256::from_parts(0, i128::MAX)), 1, 0),
7611+
ScalarValue::Decimal256(Some(i256::from_parts(0, -i128::MAX)), 1, 0),
73277612
),
73287613
];
73297614
for (lhs, rhs) in cases {

0 commit comments

Comments
 (0)