Skip to content

Commit 23e7901

Browse files
committed
Changed the binding method of table in parent, which should support more cases
1 parent b79fc7f commit 23e7901

File tree

23 files changed

+302
-132
lines changed

23 files changed

+302
-132
lines changed

src/binder/create_index.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ impl<T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'_, '_, T, A>
4343
for expr in exprs {
4444
// TODO: Expression Index
4545
match self.bind_expr(&expr.expr)? {
46-
ScalarExpression::ColumnRef(column) => columns.push(column),
46+
ScalarExpression::ColumnRef(column, false) => columns.push(column),
4747
expr => {
4848
return Err(DatabaseError::UnsupportedStmt(format!(
4949
"'CREATE INDEX' by {}",

src/binder/create_table.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,6 @@ mod tests {
189189
&scala_functions,
190190
&table_functions,
191191
Arc::new(AtomicUsize::new(0)),
192-
vec![],
193192
),
194193
&[],
195194
None,

src/binder/create_view.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,10 @@ impl<T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'_, '_, T, A>
3939
column.set_ref_table(view_name.clone(), Ulid::new(), true);
4040

4141
ScalarExpression::Alias {
42-
expr: Box::new(ScalarExpression::ColumnRef(mapping_column.clone())),
42+
expr: Box::new(ScalarExpression::ColumnRef(mapping_column.clone(), false)),
4343
alias: AliasType::Expr(Box::new(ScalarExpression::ColumnRef(
4444
ColumnRef::from(column),
45+
false,
4546
))),
4647
}
4748
})

src/binder/expr.rs

Lines changed: 35 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ use std::collections::HashMap;
1111
use std::slice;
1212
use std::sync::Arc;
1313

14-
use super::{lower_ident, Binder, BinderContext, QueryBindStep, SubQueryType};
14+
use super::{lower_ident, Binder, BinderContext, QueryBindStep, Source, SubQueryType};
1515
use crate::expression::function::scala::{ArcScalarFunctionImpl, ScalarFunction};
1616
use crate::expression::function::table::{ArcTableFunctionImpl, TableFunction};
1717
use crate::expression::function::FunctionSummary;
@@ -259,9 +259,10 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T
259259

260260
let alias_expr = ScalarExpression::Alias {
261261
expr: Box::new(expr),
262-
alias: AliasType::Expr(Box::new(ScalarExpression::ColumnRef(ColumnRef::from(
263-
alias_column,
264-
)))),
262+
alias: AliasType::Expr(Box::new(ScalarExpression::ColumnRef(
263+
ColumnRef::from(alias_column),
264+
false,
265+
))),
265266
};
266267
let alias_plan = self.bind_project(sub_query, vec![alias_expr.clone()])?;
267268
Ok((alias_expr, alias_plan))
@@ -279,7 +280,6 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T
279280
scala_functions,
280281
table_functions,
281282
temp_table_id,
282-
parent_name,
283283
..
284284
} = &self.context;
285285
let mut binder = Binder::new(
@@ -290,17 +290,11 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T
290290
scala_functions,
291291
table_functions,
292292
temp_table_id.clone(),
293-
parent_name.clone(),
294293
),
295294
self.args,
296295
Some(self),
297296
);
298297

299-
self.context.bind_table.iter().find(|((t, _, _), _)| {
300-
binder.context.parent_name.push(t.as_str().to_string());
301-
true
302-
});
303-
304298
let mut sub_query = binder.bind_query(subquery)?;
305299
let sub_query_schema = sub_query.output_schema();
306300

@@ -319,13 +313,13 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T
319313

320314
let columns = sub_query_schema
321315
.iter()
322-
.map(|column| ScalarExpression::ColumnRef(column.clone()))
316+
.map(|column| ScalarExpression::ColumnRef(column.clone(), false))
323317
.collect::<Vec<_>>();
324318
ScalarExpression::Tuple(columns)
325319
} else {
326320
fn_check(1)?;
327321

328-
ScalarExpression::ColumnRef(sub_query_schema[0].clone())
322+
ScalarExpression::ColumnRef(sub_query_schema[0].clone(), false)
329323
};
330324
Ok((sub_query, expr))
331325
}
@@ -376,13 +370,39 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T
376370
try_default!(&full_name.0, full_name.1);
377371
}
378372
if let Some(table) = full_name.0.or(bind_table_name) {
379-
let source = self.context.bind_source(&table)?;
373+
let mut source: &Source;
374+
let mut from_parent: bool;
375+
let (mut parent, mut parent_context) = if let Some(parent) = self.parent {
376+
(Some(parent), Some(&parent.context))
377+
} else {
378+
(None, None)
379+
};
380+
381+
loop {
382+
(source, from_parent) = match self.context.bind_source(parent_context, &table) {
383+
(Ok(source), from_parent) => (source, from_parent),
384+
(Err(e), _) => {
385+
if let Some(p) = parent {
386+
(parent, parent_context) = match p.parent {
387+
Some(parent) => (Some(parent), Some(&parent.context)),
388+
None => return Err(e),
389+
}
390+
} else {
391+
return Err(e);
392+
}
393+
continue;
394+
}
395+
};
396+
break;
397+
}
398+
380399
let schema_buf = self.table_schema_buf.entry(Arc::new(table)).or_default();
381400

382401
Ok(ScalarExpression::ColumnRef(
383402
source
384403
.column(&full_name.1, schema_buf)
385404
.ok_or_else(|| DatabaseError::ColumnNotFound(full_name.1.to_string()))?,
405+
from_parent,
386406
))
387407
} else {
388408
let op =
@@ -411,7 +431,7 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T
411431
table_schema_buf.entry(table_name.clone()).or_default();
412432
source.column(&full_name.1, schema_buf)
413433
} {
414-
*got_column = Some(ScalarExpression::ColumnRef(column));
434+
*got_column = Some(ScalarExpression::ColumnRef(column, false));
415435
}
416436
}
417437
};

src/binder/insert.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ impl<T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'_, '_, T, A>
5151
slice::from_ref(ident),
5252
Some(table_name.to_string()),
5353
)? {
54-
ScalarExpression::ColumnRef(catalog) => columns.push(catalog),
54+
ScalarExpression::ColumnRef(catalog, _) => columns.push(catalog),
5555
_ => return Err(DatabaseError::UnsupportedStmt(ident.to_string())),
5656
}
5757
}

src/binder/mod.rs

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,6 @@ pub struct BinderContext<'a, T: Transaction> {
114114
using: HashSet<String>,
115115

116116
bind_step: QueryBindStep,
117-
parent_name: Vec<String>,
118117
sub_queries: HashMap<QueryBindStep, Vec<SubQueryType>>,
119118

120119
temp_table_id: Arc<AtomicUsize>,
@@ -172,7 +171,6 @@ impl<'a, T: Transaction> BinderContext<'a, T> {
172171
scala_functions: &'a ScalaFunctions,
173172
table_functions: &'a TableFunctions,
174173
temp_table_id: Arc<AtomicUsize>,
175-
parent_name: Vec<String>,
176174
) -> Self {
177175
BinderContext {
178176
scala_functions,
@@ -187,7 +185,6 @@ impl<'a, T: Transaction> BinderContext<'a, T> {
187185
agg_calls: Default::default(),
188186
using: Default::default(),
189187
bind_step: QueryBindStep::From,
190-
parent_name,
191188
sub_queries: Default::default(),
192189
temp_table_id,
193190
allow_default: false,
@@ -278,14 +275,27 @@ impl<'a, T: Transaction> BinderContext<'a, T> {
278275
Ok(source)
279276
}
280277

281-
pub fn bind_source<'b: 'a>(&self, table_name: &str) -> Result<&Source, DatabaseError> {
278+
pub fn bind_source<'b: 'a>(
279+
&self,
280+
parent: Option<&'a BinderContext<'_, T>>,
281+
table_name: &str,
282+
) -> (Result<&'b Source, DatabaseError>, bool) {
282283
if let Some(source) = self.bind_table.iter().find(|((t, alias, _), _)| {
283284
t.as_str() == table_name
284285
|| matches!(alias.as_ref().map(|a| a.as_str() == table_name), Some(true))
285286
}) {
286-
Ok(source.1)
287+
(Ok(source.1), false)
288+
} else if let Some(context) = parent {
289+
if let Some(source) = context.bind_table.iter().find(|((t, alias, _), _)| {
290+
t.as_str() == table_name
291+
|| matches!(alias.as_ref().map(|a| a.as_str() == table_name), Some(true))
292+
}) {
293+
(Ok(source.1), true)
294+
} else {
295+
(Err(DatabaseError::InvalidTable(table_name.into())), false)
296+
}
287297
} else {
288-
Err(DatabaseError::InvalidTable(table_name.into()))
298+
(Err(DatabaseError::InvalidTable(table_name.into())), false)
289299
}
290300
}
291301

@@ -553,7 +563,6 @@ pub mod test {
553563
&scala_functions,
554564
&table_functions,
555565
Arc::new(AtomicUsize::new(0)),
556-
vec![],
557566
),
558567
&[],
559568
None,

0 commit comments

Comments
 (0)