Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
309 changes: 301 additions & 8 deletions datafusion/common/src/scalar/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,13 @@ use arrow::compute::kernels::numeric::{
add, add_wrapping, div, mul, mul_wrapping, rem, sub, sub_wrapping,
};
use arrow::datatypes::{
i256, ArrowDictionaryKeyType, ArrowNativeType, ArrowTimestampType, DataType,
Date32Type, Field, Float32Type, Int16Type, Int32Type, Int64Type, Int8Type,
IntervalDayTime, IntervalDayTimeType, IntervalMonthDayNano, IntervalMonthDayNanoType,
IntervalUnit, IntervalYearMonthType, TimeUnit, TimestampMicrosecondType,
TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, UInt16Type,
UInt32Type, UInt64Type, UInt8Type, UnionFields, UnionMode, DECIMAL128_MAX_PRECISION,
i256, validate_decimal_precision_and_scale, ArrowDictionaryKeyType, ArrowNativeType,
ArrowTimestampType, DataType, Date32Type, Decimal128Type, Decimal256Type, Field,
Float32Type, Int16Type, Int32Type, Int64Type, Int8Type, IntervalDayTime,
IntervalDayTimeType, IntervalMonthDayNano, IntervalMonthDayNanoType, IntervalUnit,
IntervalYearMonthType, TimeUnit, TimestampMicrosecondType, TimestampMillisecondType,
TimestampNanosecondType, TimestampSecondType, UInt16Type, UInt32Type, UInt64Type,
UInt8Type, UnionFields, UnionMode, DECIMAL128_MAX_PRECISION,
};
use arrow::util::display::{array_value_to_string, ArrayFormatter, FormatOptions};
use chrono::{Duration, NaiveDate};
Expand Down Expand Up @@ -1382,6 +1383,38 @@ impl ScalarValue {
DataType::Float16 => ScalarValue::Float16(Some(f16::from_f32(1.0))),
DataType::Float32 => ScalarValue::Float32(Some(1.0)),
DataType::Float64 => ScalarValue::Float64(Some(1.0)),
DataType::Decimal128(precision, scale) => {
if let Err(err) = validate_decimal_precision_and_scale::<Decimal128Type>(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the reason to add new InternalError wrappers here?

As in why not just

Suggested change
if let Err(err) = validate_decimal_precision_and_scale::<Decimal128Type>(
validate_decimal_precision_and_scale::<Decimal128Type>(*precision, *scale)?;

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree, no need for it, updated. Forgot about the auto-conversion from ArrowError.

*precision, *scale,
) {
return _internal_err!("Invalid precision and scale {err}");
}
if *scale < 0 {
return _internal_err!("Negative scale is not supported");
}
match i128::from(10).checked_pow(*scale as u32) {
Some(value) => {
ScalarValue::Decimal128(Some(value), *precision, *scale)
}
None => return _internal_err!("Unsupported scale {scale}"),
}
}
DataType::Decimal256(precision, scale) => {
if let Err(err) = validate_decimal_precision_and_scale::<Decimal256Type>(
*precision, *scale,
) {
return _internal_err!("Invalid precision and scale {err}");
}
if *scale < 0 {
return _internal_err!("Negative scale is not supported");
}
match i256::from(10).checked_pow(*scale as u32) {
Some(value) => {
ScalarValue::Decimal256(Some(value), *precision, *scale)
}
None => return _internal_err!("Unsupported scale {scale}"),
}
}
_ => {
return _not_impl_err!(
"Can't create an one scalar from data_type \"{datatype:?}\""
Expand All @@ -1400,6 +1433,38 @@ impl ScalarValue {
DataType::Float16 => ScalarValue::Float16(Some(f16::from_f32(-1.0))),
DataType::Float32 => ScalarValue::Float32(Some(-1.0)),
DataType::Float64 => ScalarValue::Float64(Some(-1.0)),
DataType::Decimal128(precision, scale) => {
if let Err(err) = validate_decimal_precision_and_scale::<Decimal128Type>(
*precision, *scale,
) {
return _internal_err!("Invalid precision and scale {err}");
}
if *scale < 0 {
return _internal_err!("Negative scale is not supported");
}
match i128::from(10).checked_pow(*scale as u32) {
Some(value) => {
ScalarValue::Decimal128(Some(-value), *precision, *scale)
}
None => return _internal_err!("Unsupported scale {scale}"),
}
}
DataType::Decimal256(precision, scale) => {
if let Err(err) = validate_decimal_precision_and_scale::<Decimal256Type>(
*precision, *scale,
) {
return _internal_err!("Invalid precision and scale {err}");
}
if *scale < 0 {
return _internal_err!("Negative scale is not supported");
}
match i256::from(10).checked_pow(*scale as u32) {
Some(value) => {
ScalarValue::Decimal256(Some(-value), *precision, *scale)
}
None => return _internal_err!("Unsupported scale {scale}"),
}
}
_ => {
return _not_impl_err!(
"Can't create a negative one scalar from data_type \"{datatype:?}\""
Expand All @@ -1421,6 +1486,38 @@ impl ScalarValue {
DataType::Float16 => ScalarValue::Float16(Some(f16::from_f32(10.0))),
DataType::Float32 => ScalarValue::Float32(Some(10.0)),
DataType::Float64 => ScalarValue::Float64(Some(10.0)),
DataType::Decimal128(precision, scale) => {
if let Err(err) = validate_decimal_precision_and_scale::<Decimal128Type>(
*precision, *scale,
) {
return _internal_err!("Invalid precision and scale {err}");
}
if *scale <= 0 {
return _internal_err!("Negative scale is not supported");
}
match i128::from(10).checked_pow((*scale + 1) as u32) {
Some(value) => {
ScalarValue::Decimal128(Some(value), *precision, *scale)
}
None => return _internal_err!("Unsupported scale {scale}"),
}
}
DataType::Decimal256(precision, scale) => {
if let Err(err) = validate_decimal_precision_and_scale::<Decimal256Type>(
*precision, *scale,
) {
return _internal_err!("Invalid precision and scale {err}");
}
if *scale <= 0 {
return _internal_err!("Negative scale is not supported");
}
match i256::from(10).checked_pow((*scale + 1) as u32) {
Some(value) => {
ScalarValue::Decimal256(Some(value), *precision, *scale)
}
None => return _internal_err!("Unsupported scale {scale}"),
}
}
_ => {
return _not_impl_err!(
"Can't create a ten scalar from data_type \"{datatype:?}\""
Expand Down Expand Up @@ -1790,6 +1887,26 @@ impl ScalarValue {
(Self::Float64(Some(l)), Self::Float64(Some(r))) => {
Some((l - r).abs().round() as _)
}
(
Self::Decimal128(Some(l), lprecision, lscale),
Self::Decimal128(Some(r), rprecision, rscale),
) => {
if lprecision == rprecision && lscale == rscale {
l.checked_sub(*r)?.checked_abs()?.to_usize()
} else {
None
}
}
(
Self::Decimal256(Some(l), lprecision, lscale),
Self::Decimal256(Some(r), rprecision, rscale),
) => {
if lprecision == rprecision && lscale == rscale {
l.checked_sub(*r)?.checked_abs()?.to_usize()
} else {
None
}
}
_ => None,
}
}
Expand Down Expand Up @@ -4160,7 +4277,9 @@ mod tests {
};
use arrow::buffer::{Buffer, OffsetBuffer};
use arrow::compute::{is_null, kernels};
use arrow::datatypes::{ArrowNumericType, Fields, Float64Type};
use arrow::datatypes::{
ArrowNumericType, Fields, Float64Type, DECIMAL256_MAX_PRECISION,
};
use arrow::error::ArrowError;
use arrow::util::pretty::pretty_format_columns;
use chrono::NaiveDate;
Expand Down Expand Up @@ -4896,6 +5015,116 @@ mod tests {
Ok(())
}

#[test]
fn test_new_one_decimal128() {
assert_eq!(
ScalarValue::new_one(&DataType::Decimal128(5, 0)).unwrap(),
ScalarValue::Decimal128(Some(1), 5, 0)
);
assert_eq!(
ScalarValue::new_one(&DataType::Decimal128(5, 1)).unwrap(),
ScalarValue::Decimal128(Some(10), 5, 1)
);
assert_eq!(
ScalarValue::new_one(&DataType::Decimal128(5, 2)).unwrap(),
ScalarValue::Decimal128(Some(100), 5, 2)
);
// More precision
assert_eq!(
ScalarValue::new_one(&DataType::Decimal128(7, 2)).unwrap(),
ScalarValue::Decimal128(Some(100), 7, 2)
);
// No negative scale
assert!(ScalarValue::new_one(&DataType::Decimal128(5, -1)).is_err());
// Invalid combination
assert!(ScalarValue::new_one(&DataType::Decimal128(0, 2)).is_err());
assert!(ScalarValue::new_one(&DataType::Decimal128(5, 7)).is_err());
}

#[test]
fn test_new_one_decimal256() {
assert_eq!(
ScalarValue::new_one(&DataType::Decimal256(5, 0)).unwrap(),
ScalarValue::Decimal256(Some(1.into()), 5, 0)
);
assert_eq!(
ScalarValue::new_one(&DataType::Decimal256(5, 1)).unwrap(),
ScalarValue::Decimal256(Some(10.into()), 5, 1)
);
assert_eq!(
ScalarValue::new_one(&DataType::Decimal256(5, 2)).unwrap(),
ScalarValue::Decimal256(Some(100.into()), 5, 2)
);
// More precision
assert_eq!(
ScalarValue::new_one(&DataType::Decimal256(7, 2)).unwrap(),
ScalarValue::Decimal256(Some(100.into()), 7, 2)
);
// No negative scale
assert!(ScalarValue::new_one(&DataType::Decimal256(5, -1)).is_err());
// Invalid combination
assert!(ScalarValue::new_one(&DataType::Decimal256(0, 2)).is_err());
assert!(ScalarValue::new_one(&DataType::Decimal256(5, 7)).is_err());
}

#[test]
fn test_new_ten_decimal128() {
assert_eq!(
ScalarValue::new_ten(&DataType::Decimal128(5, 1)).unwrap(),
ScalarValue::Decimal128(Some(100), 5, 1)
);
assert_eq!(
ScalarValue::new_ten(&DataType::Decimal128(5, 2)).unwrap(),
ScalarValue::Decimal128(Some(1000), 5, 2)
);
// More precision
assert_eq!(
ScalarValue::new_ten(&DataType::Decimal128(7, 2)).unwrap(),
ScalarValue::Decimal128(Some(1000), 7, 2)
);
// No negative or zero scale
assert!(ScalarValue::new_ten(&DataType::Decimal128(5, 0)).is_err());
assert!(ScalarValue::new_ten(&DataType::Decimal128(5, -1)).is_err());
// Invalid combination
assert!(ScalarValue::new_ten(&DataType::Decimal128(0, 2)).is_err());
assert!(ScalarValue::new_ten(&DataType::Decimal128(5, 7)).is_err());
}

#[test]
fn test_new_ten_decimal256() {
assert_eq!(
ScalarValue::new_ten(&DataType::Decimal256(5, 1)).unwrap(),
ScalarValue::Decimal256(Some(100.into()), 5, 1)
);
assert_eq!(
ScalarValue::new_ten(&DataType::Decimal256(5, 2)).unwrap(),
ScalarValue::Decimal256(Some(1000.into()), 5, 2)
);
// More precision
assert_eq!(
ScalarValue::new_ten(&DataType::Decimal256(7, 2)).unwrap(),
ScalarValue::Decimal256(Some(1000.into()), 7, 2)
);
// No negative or zero scale
assert!(ScalarValue::new_ten(&DataType::Decimal256(5, 0)).is_err());
assert!(ScalarValue::new_ten(&DataType::Decimal256(5, -1)).is_err());
// Invalid combination
assert!(ScalarValue::new_ten(&DataType::Decimal256(0, 2)).is_err());
assert!(ScalarValue::new_ten(&DataType::Decimal256(5, 7)).is_err());
}

#[test]
fn test_new_negative_one_decimal128() {
assert_eq!(
ScalarValue::new_negative_one(&DataType::Decimal128(5, 0)).unwrap(),
ScalarValue::Decimal128(Some(-1), 5, 0)
);
assert_eq!(
ScalarValue::new_negative_one(&DataType::Decimal128(5, 2)).unwrap(),
ScalarValue::Decimal128(Some(-100), 5, 2)
);
}

#[test]
fn test_list_partial_cmp() {
let a =
Expand Down Expand Up @@ -6946,13 +7175,51 @@ mod tests {
ScalarValue::Float64(Some(-9.9)),
5,
),
(
ScalarValue::Decimal128(Some(10), 1, 0),
ScalarValue::Decimal128(Some(5), 1, 0),
5,
),
(
ScalarValue::Decimal128(Some(5), 1, 0),
ScalarValue::Decimal128(Some(10), 1, 0),
5,
),
(
ScalarValue::Decimal256(Some(10.into()), 1, 0),
ScalarValue::Decimal256(Some(5.into()), 1, 0),
5,
),
(
ScalarValue::Decimal256(Some(5.into()), 1, 0),
ScalarValue::Decimal256(Some(10.into()), 1, 0),
5,
),
];
for (lhs, rhs, expected) in cases.iter() {
let distance = lhs.distance(rhs).unwrap();
assert_eq!(distance, *expected);
}
}

#[test]
fn test_distance_none() {
let cases = [
(
ScalarValue::Decimal128(Some(i128::MAX), DECIMAL128_MAX_PRECISION, 0),
ScalarValue::Decimal128(Some(-i128::MAX), DECIMAL128_MAX_PRECISION, 0),
),
(
ScalarValue::Decimal256(Some(i256::MAX), DECIMAL256_MAX_PRECISION, 0),
ScalarValue::Decimal256(Some(-i256::MAX), DECIMAL256_MAX_PRECISION, 0),
),
];
for (lhs, rhs) in cases.iter() {
let distance = lhs.distance(rhs);
assert!(distance.is_none(), "{lhs} vs {rhs}");
}
}

#[test]
fn test_scalar_distance_invalid() {
let cases = [
Expand Down Expand Up @@ -6994,7 +7261,33 @@ mod tests {
(ScalarValue::Date64(Some(0)), ScalarValue::Date64(Some(1))),
(
ScalarValue::Decimal128(Some(123), 5, 5),
ScalarValue::Decimal128(Some(120), 5, 5),
ScalarValue::Decimal128(Some(120), 5, 3),
),
(
ScalarValue::Decimal128(Some(123), 5, 5),
ScalarValue::Decimal128(Some(120), 3, 5),
),
(
ScalarValue::Decimal256(Some(123.into()), 5, 5),
ScalarValue::Decimal256(Some(120.into()), 3, 5),
),
// Distance 2 * 2^50 is larger than usize
(
ScalarValue::Decimal256(
Some(i256::from_parts(0, 2_i64.pow(50).into())),
1,
0,
),
ScalarValue::Decimal256(
Some(i256::from_parts(0, (-(2_i64).pow(50)).into())),
1,
0,
),
),
// Distance overflow
(
ScalarValue::Decimal256(Some(i256::from_parts(0, i128::MAX)), 1, 0),
ScalarValue::Decimal256(Some(i256::from_parts(0, -i128::MAX)), 1, 0),
),
];
for (lhs, rhs) in cases {
Expand Down
Loading