Skip to content

Commit 5e46381

Browse files
authored
feat(query): pivot support any order by expression (#18770)
* fix(query): fix pivot working with cte * fix(query): fix pivot working with cte * fix(query): fix pivot working with cte * fix(query): fix pivot working with cte * fix(query): fix pivot working with cte * fix(query): fix alias name * fix(query): fix alias name
1 parent 0a218c8 commit 5e46381

File tree

9 files changed

+870
-176
lines changed

9 files changed

+870
-176
lines changed

src/query/ast/src/ast/query.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -607,6 +607,7 @@ impl Display for TimeTravelPoint {
607607
/// The value specification that appears inside `IN (...)` of a `PIVOT` clause.
pub enum PivotValues {
608608
/// An explicit list of expressions, e.g. `IN ('JAN', 'FEB', 'MAR')`.
ColumnValues(Vec<Expr>),
609609
/// A subquery that produces the values, e.g. `IN (SELECT DISTINCT month FROM t)`.
Subquery(Box<Query>),
610+
/// The `ANY` keyword, optionally followed by `ORDER BY <exprs>`;
/// `order_by` is `None` when no `ORDER BY` part was written.
Any { order_by: Option<Vec<OrderByExpr>> },
610611
}
611612

612613
#[derive(Debug, Clone, PartialEq, Drive, DriveMut)]
@@ -626,6 +627,13 @@ impl Display for Pivot {
626627
PivotValues::Subquery(subquery) => {
627628
write!(f, "{}", subquery)?;
628629
}
630+
PivotValues::Any { order_by } => {
631+
write!(f, "ANY")?;
632+
if let Some(order_by_exprs) = order_by {
633+
write!(f, " ORDER BY ")?;
634+
write_comma_separated_list(f, order_by_exprs)?;
635+
}
636+
}
629637
}
630638
write!(f, "))")?;
631639
Ok(())

src/query/ast/src/parser/query.rs

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -945,8 +945,20 @@ fn unpivot(i: Input) -> IResult<Unpivot> {
945945

946946
fn pivot_values(i: Input) -> IResult<PivotValues> {
947947
alt((
948-
map(comma_separated_list1(expr), PivotValues::ColumnValues),
948+
// Parse ANY [ORDER BY ...] - must be first to avoid ANY being parsed as expression
949+
map(
950+
rule! {
951+
ANY ~
952+
(ORDER ~ BY ~ #comma_separated_list1(order_by_expr))?
953+
},
954+
|(_, order_by_opt)| PivotValues::Any {
955+
order_by: order_by_opt.map(|(_, _, order_by_list)| order_by_list),
956+
},
957+
),
958+
// Parse subquery - must be before expr list to avoid parsing subquery as expression
949959
map(query, |q| PivotValues::Subquery(Box::new(q))),
960+
// Parse expression list - must be last
961+
map(comma_separated_list1(expr), PivotValues::ColumnValues),
950962
))(i)
951963
}
952964

src/query/ast/tests/it/parser.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1225,6 +1225,7 @@ fn test_query() {
12251225
r#"SELECT * FROM ((SELECT *) EXCEPT (SELECT *)) foo"#,
12261226
r#"SELECT * FROM (((SELECT *) EXCEPT (SELECT *))) foo"#,
12271227
r#"SELECT * FROM (SELECT * FROM xyu ORDER BY x, y) AS xyu"#,
1228+
r#"SELECT 'Lowest value sale' AS aggregate, * FROM quarterly_sales PIVOT(MIN(amount) FOR quarter IN (ANY ORDER BY quarter))"#,
12281229
r#"select * from monthly_sales pivot(sum(amount) for month in ('JAN', 'FEB', 'MAR', 'APR')) order by empid"#,
12291230
r#"select * from (select * from monthly_sales) pivot(sum(amount) for month in ('JAN', 'FEB', 'MAR', 'APR')) order by empid"#,
12301231
r#"select * from monthly_sales pivot(sum(amount) for month in (select distinct month from monthly_sales)) order by empid"#,

src/query/ast/tests/it/testdata/query.txt

Lines changed: 233 additions & 73 deletions
Large diffs are not rendered by default.

src/query/expression/src/types.rs

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ use std::ops::Range;
4848

4949
use databend_common_ast::ast::TypeName;
5050
pub use databend_common_base::base::OrderedFloat;
51+
use databend_common_exception::ErrorCode;
5152
pub use databend_common_io::deserialize_bitmap;
5253
use enum_as_inner::EnumAsInner;
5354
use serde::Deserialize;
@@ -406,6 +407,76 @@ impl DataType {
406407
}
407408
}
408409

410+
/// Converts this `DataType` into the corresponding AST `TypeName`.
///
/// Nested types (`Array`, `Map`, `Tuple`, `Nullable`) are converted
/// recursively. Tuples are emitted without field names (`fields_name: None`).
///
/// # Errors
///
/// Returns `ErrorCode::BadArguments` when `self` (or any nested type reached
/// through the recursion) is `Null`, `EmptyArray`, `EmptyMap`, `Opaque` or
/// `Generic`.
pub fn to_type_name(&self) -> databend_common_exception::Result<TypeName> {
411+
match self {
412+
DataType::Number(num_ty) => match num_ty {
413+
NumberDataType::UInt8 => Ok(TypeName::UInt8),
414+
NumberDataType::UInt16 => Ok(TypeName::UInt16),
415+
NumberDataType::UInt32 => Ok(TypeName::UInt32),
416+
NumberDataType::UInt64 => Ok(TypeName::UInt64),
417+
NumberDataType::Int8 => Ok(TypeName::Int8),
418+
NumberDataType::Int16 => Ok(TypeName::Int16),
419+
NumberDataType::Int32 => Ok(TypeName::Int32),
420+
NumberDataType::Int64 => Ok(TypeName::Int64),
421+
NumberDataType::Float32 => Ok(TypeName::Float32),
422+
NumberDataType::Float64 => Ok(TypeName::Float64),
423+
},
424+
DataType::String => Ok(TypeName::String),
425+
DataType::Date => Ok(TypeName::Date),
426+
DataType::Timestamp => Ok(TypeName::Timestamp),
427+
DataType::Interval => Ok(TypeName::Interval),
428+
DataType::Decimal(size) => {
429+
let precision = size.precision();
430+
let scale = size.scale();
431+
Ok(TypeName::Decimal { precision, scale })
432+
}
433+
DataType::Array(inner_ty) => {
434+
let inner_ty = inner_ty.to_type_name()?;
435+
Ok(TypeName::Array(Box::new(inner_ty)))
436+
}
437+
DataType::Map(inner_ty) => {
438+
// NOTE(review): assumes the inner type of a Map is always a Tuple whose
// first two elements are (key, value); the unwrap and the indexing below
// panic otherwise — confirm this invariant holds for every Map.
let inner_ty = inner_ty.as_tuple().unwrap();
439+
let key_ty = inner_ty[0].to_type_name()?;
440+
let val_ty = inner_ty[1].to_type_name()?;
441+
Ok(TypeName::Map {
442+
key_type: Box::new(key_ty),
443+
val_type: Box::new(val_ty),
444+
})
445+
}
446+
DataType::Bitmap => Ok(TypeName::Bitmap),
447+
DataType::Variant => Ok(TypeName::Variant),
448+
DataType::Geometry => Ok(TypeName::Geometry),
449+
DataType::Geography => Ok(TypeName::Geography),
450+
DataType::Tuple(inner_tys) => {
451+
// Convert every field; the first failing field aborts the whole conversion.
let inner_tys = inner_tys
452+
.iter()
453+
.map(|inner_ty| inner_ty.to_type_name())
454+
.collect::<Result<Vec<TypeName>, ErrorCode>>()?;
455+
Ok(TypeName::Tuple {
456+
fields_name: None,
457+
fields_type: inner_tys,
458+
})
459+
}
460+
DataType::Vector(inner_ty) => {
461+
let d = inner_ty.dimension();
462+
Ok(TypeName::Vector(d))
463+
}
464+
DataType::Nullable(inner_ty) => {
465+
Ok(TypeName::Nullable(Box::new(inner_ty.to_type_name()?)))
466+
}
467+
DataType::Boolean => Ok(TypeName::Boolean),
468+
DataType::Binary => Ok(TypeName::Binary),
469+
// These variants are rejected with a BadArguments error rather than mapped.
DataType::Null
470+
| DataType::EmptyArray
471+
| DataType::EmptyMap
472+
| DataType::Opaque(_)
473+
| DataType::Generic(_) => Err(ErrorCode::BadArguments(format!(
474+
"Unsupported data type {} to sql type",
475+
self
476+
))),
477+
}
478+
}
479+
409480
// Returns the number of leaf columns of the DataType
410481
pub fn num_leaf_columns(&self) -> usize {
411482
match self {

src/query/functions/src/aggregates/adaptors/aggregate_combinator_if.rs

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ pub struct AggregateIfCombinator {
4444
argument_len: usize,
4545
nested_name: String,
4646
nested: AggregateFunctionRef,
47+
always_false: bool,
4748
}
4849

4950
impl AggregateIfCombinator {
@@ -63,7 +64,10 @@ impl AggregateIfCombinator {
6364
)));
6465
}
6566

66-
if !matches!(&arguments[argument_len - 1], DataType::Boolean) {
67+
let mut always_false = false;
68+
if arguments[argument_len - 1].is_null() {
69+
always_false = true;
70+
} else if !matches!(&arguments[argument_len - 1], DataType::Boolean) {
6771
return Err(ErrorCode::BadArguments(format!(
6872
"The type of the last argument for {name} must be boolean type, but got {:?}",
6973
&arguments[argument_len - 1]
@@ -78,6 +82,7 @@ impl AggregateIfCombinator {
7882
argument_len,
7983
nested_name: nested_name.to_owned(),
8084
nested,
85+
always_false,
8186
}))
8287
}
8388

@@ -110,6 +115,9 @@ impl AggregateFunction for AggregateIfCombinator {
110115
validity: Option<&Bitmap>,
111116
input_rows: usize,
112117
) -> Result<()> {
118+
if self.always_false {
119+
return Ok(());
120+
}
113121
let predicate =
114122
BooleanType::try_downcast_column(&columns[self.argument_len - 1].to_column()).unwrap();
115123

@@ -132,6 +140,9 @@ impl AggregateFunction for AggregateIfCombinator {
132140
columns: ProjectedBlock,
133141
_input_rows: usize,
134142
) -> Result<()> {
143+
if self.always_false {
144+
return Ok(());
145+
}
135146
let predicate: Bitmap =
136147
BooleanType::try_downcast_column(&columns[self.argument_len - 1].to_column()).unwrap();
137148
let (columns, row_size) =
@@ -145,6 +156,9 @@ impl AggregateFunction for AggregateIfCombinator {
145156
}
146157

147158
fn accumulate_row(&self, place: AggrState, columns: ProjectedBlock, row: usize) -> Result<()> {
159+
if self.always_false {
160+
return Ok(());
161+
}
148162
let predicate: Bitmap =
149163
BooleanType::try_downcast_column(&columns[self.argument_len - 1].to_column()).unwrap();
150164
if predicate.get_bit(row) {
@@ -199,6 +213,9 @@ impl AggregateFunction for AggregateIfCombinator {
199213
}
200214

201215
fn get_if_condition(&self, entries: ProjectedBlock) -> Option<Bitmap> {
216+
if self.always_false {
217+
return Some(Bitmap::new_constant(false, entries.len()));
218+
}
202219
let condition_col = entries[self.argument_len - 1].clone().remove_nullable();
203220
let predicate = BooleanType::try_downcast_column(&condition_col.to_column()).unwrap();
204221
Some(predicate)

src/query/sql/src/planner/binder/aggregate.rs

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -941,11 +941,12 @@ impl Binder {
941941
let mut finder = Finder::new(&f);
942942
finder.visit(&item.scalar)?;
943943
if !finder.scalars().is_empty() {
944-
return Err(ErrorCode::SemanticError(
945-
"GROUP BY items can't contain aggregate functions or window functions"
946-
.to_string(),
947-
)
948-
.set_span(item.scalar.span()));
944+
let scalar = finder.scalars().first().unwrap();
945+
return Err(ErrorCode::SemanticError(format!(
946+
"GROUP BY items can't contain aggregate functions or window functions: {:?}",
947+
scalar
948+
))
949+
.set_span(scalar.span()));
949950
}
950951
}
951952

0 commit comments

Comments
 (0)