Skip to content

Commit 16ce9af

Browse files
authored
chore(query): improve resolve large array (#18949)
* chore(query): improve resolve large array * update * update * update * update
1 parent e9cc22c commit 16ce9af

File tree

6 files changed

+221
-9
lines changed

6 files changed

+221
-9
lines changed

src/query/expression/src/type_check.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ pub fn check<Index: ColumnIndex>(
112112

113113
if dest_ty.is_integer() && src_ty.is_integer() {
114114
if let Ok(casted_scalar) =
115-
cast_scalar(*span, scalar.clone(), dest_ty, fn_registry)
115+
cast_scalar(*span, scalar.clone(), &dest_ty, fn_registry)
116116
{
117117
*scalar = casted_scalar;
118118
*data_type = scalar.as_ref().infer_data_type();

src/query/expression/src/utils/mod.rs

Lines changed: 94 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ pub mod visitor;
2727

2828
use databend_common_ast::Span;
2929
use databend_common_column::bitmap::Bitmap;
30+
use databend_common_exception::ErrorCode;
3031
use databend_common_exception::Result;
3132

3233
pub use self::column_from::*;
@@ -36,8 +37,12 @@ use crate::types::AnyType;
3637
use crate::types::DataType;
3738
use crate::types::Decimal;
3839
use crate::types::DecimalDataKind;
40+
use crate::types::DecimalDataType;
3941
use crate::types::DecimalSize;
42+
use crate::types::NumberDataType;
4043
use crate::types::NumberScalar;
44+
use crate::types::F32;
45+
use crate::types::F64;
4146
use crate::BlockEntry;
4247
use crate::Column;
4348
use crate::DataBlock;
@@ -87,9 +92,13 @@ pub fn eval_function(
8792
pub fn cast_scalar(
8893
span: Span,
8994
scalar: Scalar,
90-
dest_type: DataType,
95+
dest_type: &DataType,
9196
fn_registry: &FunctionRegistry,
9297
) -> Result<Scalar> {
98+
if let Some(result) = try_fast_cast_scalar(&scalar, dest_type) {
99+
return result;
100+
}
101+
93102
let raw_expr = RawExpr::Cast {
94103
span,
95104
is_try: false,
@@ -98,7 +107,7 @@ pub fn cast_scalar(
98107
scalar,
99108
data_type: None,
100109
}),
101-
dest_type,
110+
dest_type: dest_type.clone(),
102111
};
103112
let expr = crate::type_check::check(&raw_expr, fn_registry)?;
104113
let block = DataBlock::empty();
@@ -107,6 +116,89 @@ pub fn cast_scalar(
107116
Ok(evaluator.run(&expr)?.into_scalar().unwrap())
108117
}
109118

119+
fn try_fast_cast_scalar(scalar: &Scalar, dest_type: &DataType) -> Option<Result<Scalar>> {
120+
match dest_type {
121+
DataType::Null => Some(Ok(Scalar::Null)),
122+
DataType::Nullable(inner) => {
123+
if matches!(scalar, Scalar::Null) {
124+
Some(Ok(Scalar::Null))
125+
} else {
126+
try_fast_cast_scalar(scalar, inner)
127+
}
128+
}
129+
DataType::Number(NumberDataType::Float32) => match scalar {
130+
Scalar::Null => Some(Ok(Scalar::Null)),
131+
Scalar::Number(num) => Some(Ok(Scalar::Number(NumberScalar::Float32(num.to_f32())))),
132+
Scalar::Decimal(dec) => Some(Ok(Scalar::Number(NumberScalar::Float32(F32::from(
133+
dec.to_float32(),
134+
))))),
135+
_ => None,
136+
},
137+
DataType::Number(NumberDataType::Float64) => match scalar {
138+
Scalar::Null => Some(Ok(Scalar::Null)),
139+
Scalar::Number(num) => Some(Ok(Scalar::Number(NumberScalar::Float64(num.to_f64())))),
140+
Scalar::Decimal(dec) => Some(Ok(Scalar::Number(NumberScalar::Float64(F64::from(
141+
dec.to_float64(),
142+
))))),
143+
_ => None,
144+
},
145+
DataType::Decimal(size) => match scalar {
146+
Scalar::Null => Some(Ok(Scalar::Null)),
147+
Scalar::Decimal(dec) => Some(rescale_decimal_scalar(*dec, *size)),
148+
_ => None,
149+
},
150+
_ => None,
151+
}
152+
}
153+
154+
fn rescale_decimal_scalar(decimal: DecimalScalar, target_size: DecimalSize) -> Result<Scalar> {
155+
let from_size = decimal.size();
156+
if from_size == target_size {
157+
return Ok(Scalar::Decimal(decimal));
158+
}
159+
160+
let source_scale = from_size.scale();
161+
let target_scale = target_size.scale();
162+
let data_type: DecimalDataType = target_size.into();
163+
164+
let scaled = match data_type {
165+
DecimalDataType::Decimal64(_) => {
166+
let value = decimal.as_decimal::<i64>();
167+
let adjusted = rescale_decimal_value(value, source_scale, target_scale)?;
168+
Scalar::Decimal(DecimalScalar::Decimal64(adjusted, target_size))
169+
}
170+
DecimalDataType::Decimal128(_) => {
171+
let value = decimal.as_decimal::<i128>();
172+
let adjusted = rescale_decimal_value(value, source_scale, target_scale)?;
173+
Scalar::Decimal(DecimalScalar::Decimal128(adjusted, target_size))
174+
}
175+
DecimalDataType::Decimal256(_) => {
176+
let value = decimal.as_decimal::<i256>();
177+
let adjusted = rescale_decimal_value(value, source_scale, target_scale)?;
178+
Scalar::Decimal(DecimalScalar::Decimal256(adjusted, target_size))
179+
}
180+
};
181+
182+
Ok(scaled)
183+
}
184+
185+
fn rescale_decimal_value<T: Decimal>(value: T, source_scale: u8, target_scale: u8) -> Result<T> {
186+
if source_scale == target_scale {
187+
return Ok(value);
188+
}
189+
190+
let diff = target_scale.abs_diff(source_scale);
191+
if target_scale > source_scale {
192+
value.checked_mul(T::e(diff)).ok_or_else(|| {
193+
ErrorCode::Overflow("Decimal literal overflow after scale expansion".to_string())
194+
})
195+
} else {
196+
value.checked_div(T::e(diff)).ok_or_else(|| {
197+
ErrorCode::Overflow("Decimal literal overflow after scale reduction".to_string())
198+
})
199+
}
200+
}
201+
110202
pub fn column_merge_validity(entry: &BlockEntry, bitmap: Option<Bitmap>) -> Option<Bitmap> {
111203
match entry {
112204
BlockEntry::Const(scalar, data_type, n) => {

src/query/functions/src/scalars/decimal/src/comparison.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,12 @@ fn op_decimal<Op: CmpOp>(
145145
T::e(size_calc.scale() - a_type.scale()),
146146
T::e(size_calc.scale() - b_type.scale()),
147147
);
148-
compare_decimal(a, b, |a, b, _| Op::compare(a, b, f_a, f_b), ctx)
148+
149+
if (f_a == f_b) {
150+
compare_decimal(a, b, |a, b, _| Op::is(a.cmp(&b)), ctx)
151+
} else {
152+
compare_decimal(a, b, |a, b, _| Op::compare(a, b, f_a, f_b), ctx)
153+
}
149154
}
150155
})
151156
}

src/query/service/src/interpreters/interpreter_set.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ impl SetInterpreter {
6262
async fn execute_settings(&self, scalars: Vec<Scalar>, is_global: bool) -> Result<()> {
6363
let scalars: Vec<Scalar> = scalars
6464
.into_iter()
65-
.map(|scalar| cast_scalar(None, scalar.clone(), DataType::String, &BUILTIN_FUNCTIONS))
65+
.map(|scalar| cast_scalar(None, scalar.clone(), &DataType::String, &BUILTIN_FUNCTIONS))
6666
.collect::<Result<Vec<_>>>()?;
6767

6868
let mut keys: Vec<String> = vec![];

src/query/sql/src/planner/binder/statement_settings.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ impl Binder {
8383
let scalar = cast_scalar(
8484
None,
8585
scalar.clone(),
86-
DataType::String,
86+
&DataType::String,
8787
&BUILTIN_FUNCTIONS,
8888
)?;
8989
results.push(scalar);

src/query/sql/src/planner/semantic/type_check.rs

Lines changed: 118 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,12 +62,14 @@ use databend_common_compress::CompressAlgorithm;
6262
use databend_common_compress::DecompressDecoder;
6363
use databend_common_exception::ErrorCode;
6464
use databend_common_exception::Result;
65+
use databend_common_expression::cast_scalar;
6566
use databend_common_expression::display::display_tuple_field_name;
6667
use databend_common_expression::expr;
6768
use databend_common_expression::infer_schema_type;
6869
use databend_common_expression::shrink_scalar;
6970
use databend_common_expression::type_check;
7071
use databend_common_expression::type_check::check_number;
72+
use databend_common_expression::type_check::common_super_type;
7173
use databend_common_expression::type_check::convert_escape_pattern;
7274
use databend_common_expression::types::decimal::DecimalScalar;
7375
use databend_common_expression::types::decimal::DecimalSize;
@@ -81,6 +83,7 @@ use databend_common_expression::types::F32;
8183
use databend_common_expression::udf_client::UDFFlightClient;
8284
use databend_common_expression::BlockEntry;
8385
use databend_common_expression::Column;
86+
use databend_common_expression::ColumnBuilder;
8487
use databend_common_expression::ColumnIndex;
8588
use databend_common_expression::Constant;
8689
use databend_common_expression::ConstantFolder;
@@ -3459,6 +3462,15 @@ impl<'a> TypeChecker<'a> {
34593462
// Omit unary + operator
34603463
self.resolve(child)
34613464
}
3465+
UnaryOperator::Minus => {
3466+
if let Expr::Literal { value, .. } = child {
3467+
let box (value, data_type) = self.resolve_minus_literal_scalar(span, value)?;
3468+
let scalar_expr = ScalarExpr::ConstantExpr(ConstantExpr { span, value });
3469+
return Ok(Box::new((scalar_expr, data_type)));
3470+
}
3471+
let name = op.to_func_name();
3472+
self.resolve_function(span, name.as_str(), vec![], &[child])
3473+
}
34623474
other => {
34633475
let name = other.to_func_name();
34643476
self.resolve_function(span, name.as_str(), vec![], &[child])
@@ -4871,18 +4883,121 @@ impl<'a> TypeChecker<'a> {
48714883
Ok(Box::new((value, data_type)))
48724884
}
48734885

4874-
// TODO(leiysky): use an array builder function instead, since we should allow declaring
4875-
// an array with variable as element.
4886+
pub fn resolve_minus_literal_scalar(
4887+
&self,
4888+
span: Span,
4889+
literal: &databend_common_ast::ast::Literal,
4890+
) -> Result<Box<(Scalar, DataType)>> {
4891+
let value = match literal {
4892+
Literal::UInt64(v) => {
4893+
if *v <= i64::MAX as u64 {
4894+
Scalar::Number(NumberScalar::Int64(-(*v as i64)))
4895+
} else {
4896+
Scalar::Decimal(DecimalScalar::Decimal128(
4897+
-(*v as i128),
4898+
DecimalSize::new_unchecked(i128::MAX_PRECISION, 0),
4899+
))
4900+
}
4901+
}
4902+
Literal::Decimal256 {
4903+
value,
4904+
precision,
4905+
scale,
4906+
} => Scalar::Decimal(DecimalScalar::Decimal256(
4907+
i256(*value).checked_mul(i256::minus_one()).unwrap(),
4908+
DecimalSize::new_unchecked(*precision, *scale),
4909+
)),
4910+
Literal::Float64(v) => Scalar::Number(NumberScalar::Float64((-*v).into())),
4911+
Literal::Null => Scalar::Null,
4912+
Literal::String(_) | Literal::Boolean(_) => {
4913+
return Err(ErrorCode::InvalidArgument(format!(
4914+
"Invalid minus operator for {}",
4915+
literal
4916+
))
4917+
.set_span(span));
4918+
}
4919+
};
4920+
let value = shrink_scalar(value);
4921+
let data_type = value.as_ref().infer_data_type();
4922+
Ok(Box::new((value, data_type)))
4923+
}
4924+
4925+
// Fast path for constant arrays so we don't need to go through the scalar `array()` function
4926+
// (which performs full type-checking and constant-folding). Non-constant elements still use
4927+
// the generic resolver to preserve the previous behaviour.
48764928
fn resolve_array(&mut self, span: Span, exprs: &[Expr]) -> Result<Box<(ScalarExpr, DataType)>> {
48774929
let mut elems = Vec::with_capacity(exprs.len());
4930+
let mut constant_values: Option<Vec<(Scalar, DataType)>> =
4931+
Some(Vec::with_capacity(exprs.len()));
4932+
let mut element_type: Option<DataType> = None;
4933+
48784934
for expr in exprs {
4879-
let box (arg, _data_type) = self.resolve(expr)?;
4935+
let box (arg, data_type) = self.resolve(expr)?;
4936+
if let Some(values) = constant_values.as_mut() {
4937+
let maybe_constant = match &arg {
4938+
ScalarExpr::ConstantExpr(constant) => Some(constant.value.clone()),
4939+
ScalarExpr::TypedConstantExpr(constant, _) => Some(constant.value.clone()),
4940+
_ => None,
4941+
};
4942+
if let Some(value) = maybe_constant {
4943+
element_type = if let Some(current_ty) = element_type.clone() {
4944+
common_super_type(
4945+
current_ty.clone(),
4946+
data_type.clone(),
4947+
&BUILTIN_FUNCTIONS.default_cast_rules,
4948+
)
4949+
} else {
4950+
Some(data_type.clone())
4951+
};
4952+
4953+
if element_type.is_some() {
4954+
values.push((value, data_type));
4955+
} else {
4956+
constant_values = None;
4957+
element_type = None;
4958+
}
4959+
} else {
4960+
constant_values = None;
4961+
element_type = None;
4962+
}
4963+
}
48804964
elems.push(arg);
48814965
}
48824966

4967+
if let (Some(values), Some(element_ty)) = (constant_values, element_type) {
4968+
let mut casted = Vec::with_capacity(values.len());
4969+
for (value, ty) in values {
4970+
if ty == element_ty {
4971+
casted.push(value);
4972+
} else {
4973+
casted.push(cast_scalar(span, value, &element_ty, &BUILTIN_FUNCTIONS)?);
4974+
}
4975+
}
4976+
return Ok(Self::build_constant_array(span, element_ty, casted));
4977+
}
4978+
48834979
self.resolve_scalar_function_call(span, "array", vec![], elems)
48844980
}
48854981

4982+
fn build_constant_array(
4983+
span: Span,
4984+
element_ty: DataType,
4985+
values: Vec<Scalar>,
4986+
) -> Box<(ScalarExpr, DataType)> {
4987+
let mut builder = ColumnBuilder::with_capacity(&element_ty, values.len());
4988+
for value in &values {
4989+
builder.push(value.as_ref());
4990+
}
4991+
let scalar = Scalar::Array(builder.build());
4992+
Box::new((
4993+
ScalarExpr::ConstantExpr(ConstantExpr {
4994+
span,
4995+
value: scalar,
4996+
}),
4997+
DataType::Array(Box::new(element_ty)),
4998+
))
4999+
}
5000+
48865001
fn resolve_map(
48875002
&mut self,
48885003
span: Span,

0 commit comments

Comments
 (0)