@@ -74,12 +74,13 @@ use arrow::compute::kernels::numeric::{
74
74
add, add_wrapping, div, mul, mul_wrapping, rem, sub, sub_wrapping,
75
75
} ;
76
76
use arrow:: datatypes:: {
77
- i256, ArrowDictionaryKeyType , ArrowNativeType , ArrowTimestampType , DataType ,
78
- Date32Type , Field , Float32Type , Int16Type , Int32Type , Int64Type , Int8Type ,
79
- IntervalDayTime , IntervalDayTimeType , IntervalMonthDayNano , IntervalMonthDayNanoType ,
80
- IntervalUnit , IntervalYearMonthType , TimeUnit , TimestampMicrosecondType ,
81
- TimestampMillisecondType , TimestampNanosecondType , TimestampSecondType , UInt16Type ,
82
- UInt32Type , UInt64Type , UInt8Type , UnionFields , UnionMode , DECIMAL128_MAX_PRECISION ,
77
+ i256, validate_decimal_precision_and_scale, ArrowDictionaryKeyType , ArrowNativeType ,
78
+ ArrowTimestampType , DataType , Date32Type , Decimal128Type , Decimal256Type , Field ,
79
+ Float32Type , Int16Type , Int32Type , Int64Type , Int8Type , IntervalDayTime ,
80
+ IntervalDayTimeType , IntervalMonthDayNano , IntervalMonthDayNanoType , IntervalUnit ,
81
+ IntervalYearMonthType , TimeUnit , TimestampMicrosecondType , TimestampMillisecondType ,
82
+ TimestampNanosecondType , TimestampSecondType , UInt16Type , UInt32Type , UInt64Type ,
83
+ UInt8Type , UnionFields , UnionMode , DECIMAL128_MAX_PRECISION ,
83
84
} ;
84
85
use arrow:: util:: display:: { array_value_to_string, ArrayFormatter , FormatOptions } ;
85
86
use cache:: { get_or_create_cached_key_array, get_or_create_cached_null_array} ;
@@ -1516,6 +1517,34 @@ impl ScalarValue {
1516
1517
DataType :: Float16 => ScalarValue :: Float16 ( Some ( f16:: from_f32 ( 1.0 ) ) ) ,
1517
1518
DataType :: Float32 => ScalarValue :: Float32 ( Some ( 1.0 ) ) ,
1518
1519
DataType :: Float64 => ScalarValue :: Float64 ( Some ( 1.0 ) ) ,
1520
+ DataType :: Decimal128 ( precision, scale) => {
1521
+ validate_decimal_precision_and_scale :: < Decimal128Type > (
1522
+ * precision, * scale,
1523
+ ) ?;
1524
+ if * scale < 0 {
1525
+ return _internal_err ! ( "Negative scale is not supported" ) ;
1526
+ }
1527
+ match i128:: from ( 10 ) . checked_pow ( * scale as u32 ) {
1528
+ Some ( value) => {
1529
+ ScalarValue :: Decimal128 ( Some ( value) , * precision, * scale)
1530
+ }
1531
+ None => return _internal_err ! ( "Unsupported scale {scale}" ) ,
1532
+ }
1533
+ }
1534
+ DataType :: Decimal256 ( precision, scale) => {
1535
+ validate_decimal_precision_and_scale :: < Decimal256Type > (
1536
+ * precision, * scale,
1537
+ ) ?;
1538
+ if * scale < 0 {
1539
+ return _internal_err ! ( "Negative scale is not supported" ) ;
1540
+ }
1541
+ match i256:: from ( 10 ) . checked_pow ( * scale as u32 ) {
1542
+ Some ( value) => {
1543
+ ScalarValue :: Decimal256 ( Some ( value) , * precision, * scale)
1544
+ }
1545
+ None => return _internal_err ! ( "Unsupported scale {scale}" ) ,
1546
+ }
1547
+ }
1519
1548
_ => {
1520
1549
return _not_impl_err ! (
1521
1550
"Can't create an one scalar from data_type \" {datatype:?}\" "
@@ -1534,6 +1563,34 @@ impl ScalarValue {
1534
1563
DataType :: Float16 => ScalarValue :: Float16 ( Some ( f16:: from_f32 ( -1.0 ) ) ) ,
1535
1564
DataType :: Float32 => ScalarValue :: Float32 ( Some ( -1.0 ) ) ,
1536
1565
DataType :: Float64 => ScalarValue :: Float64 ( Some ( -1.0 ) ) ,
1566
+ DataType :: Decimal128 ( precision, scale) => {
1567
+ validate_decimal_precision_and_scale :: < Decimal128Type > (
1568
+ * precision, * scale,
1569
+ ) ?;
1570
+ if * scale < 0 {
1571
+ return _internal_err ! ( "Negative scale is not supported" ) ;
1572
+ }
1573
+ match i128:: from ( 10 ) . checked_pow ( * scale as u32 ) {
1574
+ Some ( value) => {
1575
+ ScalarValue :: Decimal128 ( Some ( -value) , * precision, * scale)
1576
+ }
1577
+ None => return _internal_err ! ( "Unsupported scale {scale}" ) ,
1578
+ }
1579
+ }
1580
+ DataType :: Decimal256 ( precision, scale) => {
1581
+ validate_decimal_precision_and_scale :: < Decimal256Type > (
1582
+ * precision, * scale,
1583
+ ) ?;
1584
+ if * scale < 0 {
1585
+ return _internal_err ! ( "Negative scale is not supported" ) ;
1586
+ }
1587
+ match i256:: from ( 10 ) . checked_pow ( * scale as u32 ) {
1588
+ Some ( value) => {
1589
+ ScalarValue :: Decimal256 ( Some ( -value) , * precision, * scale)
1590
+ }
1591
+ None => return _internal_err ! ( "Unsupported scale {scale}" ) ,
1592
+ }
1593
+ }
1537
1594
_ => {
1538
1595
return _not_impl_err ! (
1539
1596
"Can't create a negative one scalar from data_type \" {datatype:?}\" "
@@ -1555,6 +1612,38 @@ impl ScalarValue {
1555
1612
DataType :: Float16 => ScalarValue :: Float16 ( Some ( f16:: from_f32 ( 10.0 ) ) ) ,
1556
1613
DataType :: Float32 => ScalarValue :: Float32 ( Some ( 10.0 ) ) ,
1557
1614
DataType :: Float64 => ScalarValue :: Float64 ( Some ( 10.0 ) ) ,
1615
+ DataType :: Decimal128 ( precision, scale) => {
1616
+ if let Err ( err) = validate_decimal_precision_and_scale :: < Decimal128Type > (
1617
+ * precision, * scale,
1618
+ ) {
1619
+ return _internal_err ! ( "Invalid precision and scale {err}" ) ;
1620
+ }
1621
+ if * scale <= 0 {
1622
+ return _internal_err ! ( "Negative scale is not supported" ) ;
1623
+ }
1624
+ match i128:: from ( 10 ) . checked_pow ( ( * scale + 1 ) as u32 ) {
1625
+ Some ( value) => {
1626
+ ScalarValue :: Decimal128 ( Some ( value) , * precision, * scale)
1627
+ }
1628
+ None => return _internal_err ! ( "Unsupported scale {scale}" ) ,
1629
+ }
1630
+ }
1631
+ DataType :: Decimal256 ( precision, scale) => {
1632
+ if let Err ( err) = validate_decimal_precision_and_scale :: < Decimal256Type > (
1633
+ * precision, * scale,
1634
+ ) {
1635
+ return _internal_err ! ( "Invalid precision and scale {err}" ) ;
1636
+ }
1637
+ if * scale <= 0 {
1638
+ return _internal_err ! ( "Negative scale is not supported" ) ;
1639
+ }
1640
+ match i256:: from ( 10 ) . checked_pow ( ( * scale + 1 ) as u32 ) {
1641
+ Some ( value) => {
1642
+ ScalarValue :: Decimal256 ( Some ( value) , * precision, * scale)
1643
+ }
1644
+ None => return _internal_err ! ( "Unsupported scale {scale}" ) ,
1645
+ }
1646
+ }
1558
1647
_ => {
1559
1648
return _not_impl_err ! (
1560
1649
"Can't create a ten scalar from data_type \" {datatype:?}\" "
@@ -1924,6 +2013,26 @@ impl ScalarValue {
1924
2013
( Self :: Float64 ( Some ( l) ) , Self :: Float64 ( Some ( r) ) ) => {
1925
2014
Some ( ( l - r) . abs ( ) . round ( ) as _ )
1926
2015
}
2016
+ (
2017
+ Self :: Decimal128 ( Some ( l) , lprecision, lscale) ,
2018
+ Self :: Decimal128 ( Some ( r) , rprecision, rscale) ,
2019
+ ) => {
2020
+ if lprecision == rprecision && lscale == rscale {
2021
+ l. checked_sub ( * r) ?. checked_abs ( ) ?. to_usize ( )
2022
+ } else {
2023
+ None
2024
+ }
2025
+ }
2026
+ (
2027
+ Self :: Decimal256 ( Some ( l) , lprecision, lscale) ,
2028
+ Self :: Decimal256 ( Some ( r) , rprecision, rscale) ,
2029
+ ) => {
2030
+ if lprecision == rprecision && lscale == rscale {
2031
+ l. checked_sub ( * r) ?. checked_abs ( ) ?. to_usize ( )
2032
+ } else {
2033
+ None
2034
+ }
2035
+ }
1927
2036
_ => None ,
1928
2037
}
1929
2038
}
@@ -4489,7 +4598,9 @@ mod tests {
4489
4598
} ;
4490
4599
use arrow:: buffer:: { Buffer , OffsetBuffer } ;
4491
4600
use arrow:: compute:: { is_null, kernels} ;
4492
- use arrow:: datatypes:: { ArrowNumericType , Fields , Float64Type } ;
4601
+ use arrow:: datatypes:: {
4602
+ ArrowNumericType , Fields , Float64Type , DECIMAL256_MAX_PRECISION ,
4603
+ } ;
4493
4604
use arrow:: error:: ArrowError ;
4494
4605
use arrow:: util:: pretty:: pretty_format_columns;
4495
4606
use chrono:: NaiveDate ;
@@ -5225,6 +5336,116 @@ mod tests {
5225
5336
Ok ( ( ) )
5226
5337
}
5227
5338
5339
#[test]
fn test_new_one_decimal128() {
    // (precision, scale, expected unscaled value): "one" is stored as 10^scale,
    // and a wider precision does not change the stored value.
    let valid: [(u8, i8, i128); 4] = [(5, 0, 1), (5, 1, 10), (5, 2, 100), (7, 2, 100)];
    for (precision, scale, unscaled) in valid {
        let one = ScalarValue::new_one(&DataType::Decimal128(precision, scale)).unwrap();
        assert_eq!(
            one,
            ScalarValue::Decimal128(Some(unscaled), precision, scale)
        );
    }

    // Rejected inputs: a negative scale, a zero precision, and a scale that
    // exceeds the precision.
    let rejected: [(u8, i8); 3] = [(5, -1), (0, 2), (5, 7)];
    for (precision, scale) in rejected {
        assert!(ScalarValue::new_one(&DataType::Decimal128(precision, scale)).is_err());
    }
}
5364
+
5365
#[test]
fn test_new_one_decimal256() {
    // (precision, scale, expected unscaled value): "one" is stored as 10^scale,
    // and a wider precision does not change the stored value.
    let valid: [(u8, i8, i32); 4] = [(5, 0, 1), (5, 1, 10), (5, 2, 100), (7, 2, 100)];
    for (precision, scale, unscaled) in valid {
        let one = ScalarValue::new_one(&DataType::Decimal256(precision, scale)).unwrap();
        assert_eq!(
            one,
            ScalarValue::Decimal256(Some(i256::from(unscaled)), precision, scale)
        );
    }

    // Rejected inputs: a negative scale, a zero precision, and a scale that
    // exceeds the precision.
    let rejected: [(u8, i8); 3] = [(5, -1), (0, 2), (5, 7)];
    for (precision, scale) in rejected {
        assert!(ScalarValue::new_one(&DataType::Decimal256(precision, scale)).is_err());
    }
}
5390
+
5391
#[test]
fn test_new_ten_decimal128() {
    // (precision, scale, expected unscaled value): "ten" is stored as
    // 10^(scale + 1), and a wider precision does not change the stored value.
    let valid: [(u8, i8, i128); 3] = [(5, 1, 100), (5, 2, 1000), (7, 2, 1000)];
    for (precision, scale, unscaled) in valid {
        let ten = ScalarValue::new_ten(&DataType::Decimal128(precision, scale)).unwrap();
        assert_eq!(
            ten,
            ScalarValue::Decimal128(Some(unscaled), precision, scale)
        );
    }

    // Rejected inputs: zero scale, negative scale, zero precision, and a
    // scale that exceeds the precision.
    let rejected: [(u8, i8); 4] = [(5, 0), (5, -1), (0, 2), (5, 7)];
    for (precision, scale) in rejected {
        assert!(ScalarValue::new_ten(&DataType::Decimal128(precision, scale)).is_err());
    }
}
5413
+
5414
#[test]
fn test_new_ten_decimal256() {
    // (precision, scale, expected unscaled value): "ten" is stored as
    // 10^(scale + 1), and a wider precision does not change the stored value.
    let valid: [(u8, i8, i32); 3] = [(5, 1, 100), (5, 2, 1000), (7, 2, 1000)];
    for (precision, scale, unscaled) in valid {
        let ten = ScalarValue::new_ten(&DataType::Decimal256(precision, scale)).unwrap();
        assert_eq!(
            ten,
            ScalarValue::Decimal256(Some(i256::from(unscaled)), precision, scale)
        );
    }

    // Rejected inputs: zero scale, negative scale, zero precision, and a
    // scale that exceeds the precision.
    let rejected: [(u8, i8); 4] = [(5, 0), (5, -1), (0, 2), (5, 7)];
    for (precision, scale) in rejected {
        assert!(ScalarValue::new_ten(&DataType::Decimal256(precision, scale)).is_err());
    }
}
5436
+
5437
#[test]
fn test_new_negative_one_decimal128() {
    // (precision, scale, expected unscaled value): "negative one" is stored
    // as -(10^scale), and a wider precision does not change the stored value.
    let valid: [(u8, i8, i128); 4] = [(5, 0, -1), (5, 1, -10), (5, 2, -100), (7, 2, -100)];
    for (precision, scale, unscaled) in valid {
        assert_eq!(
            ScalarValue::new_negative_one(&DataType::Decimal128(precision, scale)).unwrap(),
            ScalarValue::Decimal128(Some(unscaled), precision, scale)
        );
    }

    // Error paths, mirroring the `new_one` / `new_ten` decimal tests:
    // a negative scale is rejected explicitly by `new_negative_one`.
    assert!(ScalarValue::new_negative_one(&DataType::Decimal128(5, -1)).is_err());
    // A zero precision or a scale above the precision fails the
    // precision/scale validation that runs first.
    assert!(ScalarValue::new_negative_one(&DataType::Decimal128(0, 2)).is_err());
    assert!(ScalarValue::new_negative_one(&DataType::Decimal128(5, 7)).is_err());
}
5448
+
5228
5449
#[ test]
5229
5450
fn test_list_partial_cmp ( ) {
5230
5451
let a =
@@ -7275,13 +7496,51 @@ mod tests {
7275
7496
ScalarValue :: Float64 ( Some ( -9.9 ) ) ,
7276
7497
5 ,
7277
7498
) ,
7499
+ (
7500
+ ScalarValue :: Decimal128 ( Some ( 10 ) , 1 , 0 ) ,
7501
+ ScalarValue :: Decimal128 ( Some ( 5 ) , 1 , 0 ) ,
7502
+ 5 ,
7503
+ ) ,
7504
+ (
7505
+ ScalarValue :: Decimal128 ( Some ( 5 ) , 1 , 0 ) ,
7506
+ ScalarValue :: Decimal128 ( Some ( 10 ) , 1 , 0 ) ,
7507
+ 5 ,
7508
+ ) ,
7509
+ (
7510
+ ScalarValue :: Decimal256 ( Some ( 10 . into ( ) ) , 1 , 0 ) ,
7511
+ ScalarValue :: Decimal256 ( Some ( 5 . into ( ) ) , 1 , 0 ) ,
7512
+ 5 ,
7513
+ ) ,
7514
+ (
7515
+ ScalarValue :: Decimal256 ( Some ( 5 . into ( ) ) , 1 , 0 ) ,
7516
+ ScalarValue :: Decimal256 ( Some ( 10 . into ( ) ) , 1 , 0 ) ,
7517
+ 5 ,
7518
+ ) ,
7278
7519
] ;
7279
7520
for ( lhs, rhs, expected) in cases. iter ( ) {
7280
7521
let distance = lhs. distance ( rhs) . unwrap ( ) ;
7281
7522
assert_eq ! ( distance, * expected) ;
7282
7523
}
7283
7524
}
7284
7525
7526
#[test]
fn test_distance_none() {
    // Each pair sits at opposite extremes of its integer type, so the
    // checked subtraction inside `distance` overflows and yields `None`.
    let decimal128_pair = (
        ScalarValue::Decimal128(Some(i128::MAX), DECIMAL128_MAX_PRECISION, 0),
        ScalarValue::Decimal128(Some(-i128::MAX), DECIMAL128_MAX_PRECISION, 0),
    );
    let decimal256_pair = (
        ScalarValue::Decimal256(Some(i256::MAX), DECIMAL256_MAX_PRECISION, 0),
        ScalarValue::Decimal256(Some(-i256::MAX), DECIMAL256_MAX_PRECISION, 0),
    );

    for (lhs, rhs) in &[decimal128_pair, decimal256_pair] {
        assert!(lhs.distance(rhs).is_none(), "{lhs} vs {rhs}");
    }
}
7543
+
7285
7544
#[ test]
7286
7545
fn test_scalar_distance_invalid ( ) {
7287
7546
let cases = [
@@ -7323,7 +7582,33 @@ mod tests {
7323
7582
( ScalarValue :: Date64 ( Some ( 0 ) ) , ScalarValue :: Date64 ( Some ( 1 ) ) ) ,
7324
7583
(
7325
7584
ScalarValue :: Decimal128 ( Some ( 123 ) , 5 , 5 ) ,
7326
- ScalarValue :: Decimal128 ( Some ( 120 ) , 5 , 5 ) ,
7585
+ ScalarValue :: Decimal128 ( Some ( 120 ) , 5 , 3 ) ,
7586
+ ) ,
7587
+ (
7588
+ ScalarValue :: Decimal128 ( Some ( 123 ) , 5 , 5 ) ,
7589
+ ScalarValue :: Decimal128 ( Some ( 120 ) , 3 , 5 ) ,
7590
+ ) ,
7591
+ (
7592
+ ScalarValue :: Decimal256 ( Some ( 123 . into ( ) ) , 5 , 5 ) ,
7593
+ ScalarValue :: Decimal256 ( Some ( 120 . into ( ) ) , 3 , 5 ) ,
7594
+ ) ,
7595
+ // High word is +/-2^50, so the values are +/-2^178; their distance (2^179) does not fit in usize
7596
+ (
7597
+ ScalarValue :: Decimal256 (
7598
+ Some ( i256:: from_parts ( 0 , 2_i64 . pow ( 50 ) . into ( ) ) ) ,
7599
+ 1 ,
7600
+ 0 ,
7601
+ ) ,
7602
+ ScalarValue :: Decimal256 (
7603
+ Some ( i256:: from_parts ( 0 , ( -( 2_i64 ) . pow ( 50 ) ) . into ( ) ) ) ,
7604
+ 1 ,
7605
+ 0 ,
7606
+ ) ,
7607
+ ) ,
7608
+ // Distance overflow
7609
+ (
7610
+ ScalarValue :: Decimal256 ( Some ( i256:: from_parts ( 0 , i128:: MAX ) ) , 1 , 0 ) ,
7611
+ ScalarValue :: Decimal256 ( Some ( i256:: from_parts ( 0 , -i128:: MAX ) ) , 1 , 0 ) ,
7327
7612
) ,
7328
7613
] ;
7329
7614
for ( lhs, rhs) in cases {
0 commit comments