
Commit d584167

refactor: impl ReferenceSerialization (#230)
* refactor: impl `ReferenceSerialization`
* refactor: ColumnId -> Ulid
* fix: column miss match on `Insert`
1 parent bce7cd0 commit d584167


95 files changed (+2166, -2164 lines)

Cargo.toml

Lines changed: 3 additions & 1 deletion
@@ -59,13 +59,15 @@ regex = { version = "1" }
 rocksdb = { version = "0.22.0" }
 rust_decimal = { version = "1" }
 serde = { version = "1", features = ["derive", "rc"] }
+serde_macros = { path = "serde_macros" }
 siphasher = { version = "1", features = ["serde"] }
 sqlparser = { version = "0.34", features = ["serde"] }
 strum_macros = { version = "0.26.2" }
 thiserror = { version = "1" }
 tokio = { version = "1.36", features = ["full"], optional = true }
 tracing = { version = "0.1" }
 typetag = { version = "0.2" }
+ulid = { version = "1", features = ["serde"] }

 [dev-dependencies]
 cargo-tarpaulin = { version = "0.27" }
@@ -83,7 +85,7 @@ pprof = { version = "0.13", features = ["flamegraph", "criterion"] }
 members = [
     "tests/sqllogictest",
     "tests/macros-test"
-]
+, "serde_macros"]

 [profile.release]
 lto = true

serde_macros/Cargo.toml

Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
+[package]
+name = "serde_macros"
+version = "0.1.0"
+edition = "2021"
+
+[dependencies]
+darling = "0.20"
+proc-macro2 = "1"
+quote = "1"
+syn = "2"
+
+[lib]
+path = "src/lib.rs"
+proc-macro = true

serde_macros/src/lib.rs

Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
+mod reference_serialization;
+
+use proc_macro::TokenStream;
+use syn::{parse_macro_input, DeriveInput};
+
+#[proc_macro_derive(ReferenceSerialization, attributes(reference_serialization))]
+pub fn reference_serialization(input: TokenStream) -> TokenStream {
+    let ast = parse_macro_input!(input as DeriveInput);
+
+    let result = reference_serialization::handle(ast);
+    match result {
+        Ok(codegen) => codegen.into(),
+        Err(e) => e.to_compile_error().into(),
+    }
+}
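The derive above is the crate's single entry point. As a minimal usage sketch (the `Counter` type below is hypothetical, not part of this commit), a type in the main crate would opt in like this, assuming each field type itself implements `ReferenceSerialization` or is one of the generic wrappers the macro special-cases (`Vec`, `Option`, `Arc`, `Box`, ...):

use serde_macros::ReferenceSerialization;

// Hypothetical example, for illustration only: the derive generates
// encode/decode impls that simply delegate to each field in order,
// so every field type must itself support the trait.
#[derive(Debug, ReferenceSerialization)]
pub struct Counter {
    pub name: String,
    pub buckets: Vec<u64>,
}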
Lines changed: 205 additions & 0 deletions
@@ -0,0 +1,205 @@
+use darling::ast::Data;
+use darling::{FromDeriveInput, FromField, FromVariant};
+use proc_macro2::{Ident, Span, TokenStream};
+use quote::quote;
+use syn::{
+    AngleBracketedGenericArguments, DeriveInput, Error, GenericArgument, PathArguments, Type,
+    TypePath,
+};
+
+#[derive(Debug, FromDeriveInput)]
+#[darling(attributes(record))]
+struct SerializationOpts {
+    ident: Ident,
+    data: Data<SerializationVariantOpts, SerializationFieldOpt>,
+}
+
+#[derive(Debug, FromVariant)]
+#[darling(attributes(record))]
+struct SerializationVariantOpts {
+    ident: Ident,
+    fields: darling::ast::Fields<SerializationFieldOpt>,
+}
+
+#[derive(Debug, FromField)]
+#[darling(attributes(record))]
+struct SerializationFieldOpt {
+    ident: Option<Ident>,
+    ty: Type,
+}
+
+fn process_type(ty: &Type) -> TokenStream {
+    if let Type::Path(TypePath { path, .. }) = ty {
+        let ident = &path.segments.last().unwrap().ident;
+
+        match ident.to_string().as_str() {
+            "Vec" | "Option" | "Arc" | "Box" | "PhantomData" | "Bound" | "CountMinSketch" => {
+                if let PathArguments::AngleBracketed(AngleBracketedGenericArguments {
+                    args, ..
+                }) = &path.segments.last().unwrap().arguments
+                {
+                    if let Some(GenericArgument::Type(inner_ty)) = args.first() {
+                        let inner_processed = process_type(inner_ty);
+
+                        return quote! {
+                            #ident::<#inner_processed>
+                        };
+                    }
+                }
+            }
+            _ => {}
+        }
+
+        quote! { #ty }
+    } else {
+        quote! { #ty }
+    }
+}
+
+pub(crate) fn handle(ast: DeriveInput) -> Result<TokenStream, Error> {
+    let record_opts: SerializationOpts = SerializationOpts::from_derive_input(&ast)?;
+    let struct_name = &record_opts.ident;
+
+    Ok(match record_opts.data {
+        Data::Struct(data_struct) => {
+            let mut encode_fields: Vec<TokenStream> = Vec::new();
+            let mut decode_fields: Vec<TokenStream> = Vec::new();
+            let mut init_fields: Vec<TokenStream> = Vec::new();
+            let mut is_tuple = false;
+
+            for (i, field_opts) in data_struct.fields.into_iter().enumerate() {
+                is_tuple = is_tuple || field_opts.ident.is_none();
+
+                let field_name = field_opts
+                    .ident
+                    .unwrap_or_else(|| Ident::new(&format!("filed_{}", i), Span::call_site()));
+                let ty = process_type(&field_opts.ty);
+
+                encode_fields.push(quote! {
+                    #field_name.encode(writer, is_direct, reference_tables)?;
+                });
+                decode_fields.push(quote! {
+                    let #field_name = #ty::decode(reader, drive, reference_tables)?;
+                });
+                init_fields.push(quote! {
+                    #field_name,
+                })
+            }
+            let init_stream = if is_tuple {
+                quote! { #struct_name ( #(#init_fields)* ) }
+            } else {
+                quote! { #struct_name { #(#init_fields)* } }
+            };
+
+            quote! {
+                impl crate::serdes::ReferenceSerialization for #struct_name {
+                    fn encode<W: std::io::Write>(
+                        &self,
+                        writer: &mut W,
+                        is_direct: bool,
+                        reference_tables: &mut crate::serdes::ReferenceTables,
+                    ) -> Result<(), crate::errors::DatabaseError> {
+                        let #init_stream = self;
+
+                        #(#encode_fields)*
+
+                        Ok(())
+                    }
+
+                    fn decode<T: crate::storage::Transaction, R: std::io::Read>(
+                        reader: &mut R,
+                        drive: Option<(&T, &crate::storage::TableCache)>,
+                        reference_tables: &crate::serdes::ReferenceTables,
+                    ) -> Result<Self, crate::errors::DatabaseError> {
+                        #(#decode_fields)*
+
+                        Ok(#init_stream)
+                    }
+                }
+            }
+        }
+        Data::Enum(data_enum) => {
+            let mut variant_encode_fields: Vec<TokenStream> = Vec::new();
+            let mut variant_decode_fields: Vec<TokenStream> = Vec::new();
+
+            for (i, variant_opts) in data_enum.into_iter().enumerate() {
+                let i = i as u8;
+                let mut encode_fields: Vec<TokenStream> = Vec::new();
+                let mut decode_fields: Vec<TokenStream> = Vec::new();
+                let mut init_fields: Vec<TokenStream> = Vec::new();
+                let enum_name = variant_opts.ident;
+                let mut is_tuple = false;
+
+                for (i, field_opts) in variant_opts.fields.into_iter().enumerate() {
+                    is_tuple = is_tuple || field_opts.ident.is_none();
+
+                    let field_name = field_opts
+                        .ident
+                        .unwrap_or_else(|| Ident::new(&format!("filed_{}", i), Span::call_site()));
+                    let ty = process_type(&field_opts.ty);
+
+                    encode_fields.push(quote! {
+                        #field_name.encode(writer, is_direct, reference_tables)?;
+                    });
+                    decode_fields.push(quote! {
+                        let #field_name = #ty::decode(reader, drive, reference_tables)?;
+                    });
+                    init_fields.push(quote! {
+                        #field_name,
+                    })
+                }
+
+                let init_stream = if is_tuple {
+                    quote! { #struct_name::#enum_name ( #(#init_fields)* ) }
+                } else {
+                    quote! { #struct_name::#enum_name { #(#init_fields)* } }
+                };
+                variant_encode_fields.push(quote! {
+                    #init_stream => {
+                        std::io::Write::write_all(writer, &[#i])?;
+
+                        #(#encode_fields)*
+                    }
+                });
+                variant_decode_fields.push(quote! {
+                    #i => {
+                        #(#decode_fields)*
+
+                        #init_stream
+                    }
+                });
+            }
+
+            quote! {
+                impl crate::serdes::ReferenceSerialization for #struct_name {
+                    fn encode<W: std::io::Write>(
+                        &self,
+                        writer: &mut W,
+                        is_direct: bool,
+                        reference_tables: &mut crate::serdes::ReferenceTables,
+                    ) -> Result<(), crate::errors::DatabaseError> {
+                        match self {
+                            #(#variant_encode_fields)*
+                        }
+
+                        Ok(())
+                    }
+
+                    fn decode<T: crate::storage::Transaction, R: std::io::Read>(
+                        reader: &mut R,
+                        drive: Option<(&T, &crate::storage::TableCache)>,
+                        reference_tables: &crate::serdes::ReferenceTables,
+                    ) -> Result<Self, crate::errors::DatabaseError> {
+                        let mut type_bytes = [0u8; 1];
+                        std::io::Read::read_exact(reader, &mut type_bytes)?;
+
+                        Ok(match type_bytes[0] {
+                            #(#variant_decode_fields)*
+                            _ => unreachable!(),
+                        })
+                    }
+                }
+            }
+        }
+    })
+}
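The generated impls target `crate::serdes::ReferenceSerialization`, whose definition lives in the main crate and is not part of this diff. Inferred from the `encode`/`decode` signatures emitted above, the trait presumably has roughly this shape (a sketch, not the actual definition; the `Sized` bound is assumed because `decode` returns `Self` by value):

// Sketch reconstructed from the derive output above. Struct fields are
// encoded in declaration order; enum values are prefixed with a one-byte
// variant index, which decode reads back to pick the variant.
pub trait ReferenceSerialization: Sized {
    fn encode<W: std::io::Write>(
        &self,
        writer: &mut W,
        is_direct: bool,
        reference_tables: &mut crate::serdes::ReferenceTables,
    ) -> Result<(), crate::errors::DatabaseError>;

    fn decode<T: crate::storage::Transaction, R: std::io::Read>(
        reader: &mut R,
        drive: Option<(&T, &crate::storage::TableCache)>,
        reference_tables: &crate::serdes::ReferenceTables,
    ) -> Result<Self, crate::errors::DatabaseError>;
}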

src/binder/copy.rs

Lines changed: 26 additions & 4 deletions
@@ -2,23 +2,45 @@ use std::path::PathBuf;
 use std::str::FromStr;
 use std::sync::Arc;

+use super::*;
 use crate::errors::DatabaseError;
 use crate::planner::operator::copy_from_file::CopyFromFileOperator;
 use crate::planner::operator::copy_to_file::CopyToFileOperator;
 use crate::planner::operator::Operator;
 use serde::{Deserialize, Serialize};
+use serde_macros::ReferenceSerialization;
 use sqlparser::ast::{CopyOption, CopySource, CopyTarget};

-use super::*;
-
-#[derive(Debug, PartialEq, PartialOrd, Ord, Hash, Eq, Clone, Serialize, Deserialize)]
+#[derive(
+    Debug,
+    PartialEq,
+    PartialOrd,
+    Ord,
+    Hash,
+    Eq,
+    Clone,
+    Serialize,
+    Deserialize,
+    ReferenceSerialization,
+)]
 pub struct ExtSource {
     pub path: PathBuf,
     pub format: FileFormat,
 }

 /// File format.
-#[derive(Debug, PartialEq, PartialOrd, Ord, Hash, Eq, Clone, Serialize, Deserialize)]
+#[derive(
+    Debug,
+    PartialEq,
+    PartialOrd,
+    Ord,
+    Hash,
+    Eq,
+    Clone,
+    Serialize,
+    Deserialize,
+    ReferenceSerialization,
+)]
 pub enum FileFormat {
     Csv {
         /// Delimiter to parse.
src/binder/expr.rs

Lines changed: 8 additions & 8 deletions
@@ -11,14 +11,14 @@ use std::slice;
 use std::sync::Arc;

 use super::{lower_ident, Binder, BinderContext, QueryBindStep, SubQueryType};
-use crate::expression::function::scala::ScalarFunction;
-use crate::expression::function::table::TableFunction;
+use crate::expression::function::scala::{ArcScalarFunctionImpl, ScalarFunction};
+use crate::expression::function::table::{ArcTableFunctionImpl, TableFunction};
 use crate::expression::function::FunctionSummary;
 use crate::expression::{AliasType, ScalarExpression};
 use crate::planner::LogicalPlan;
 use crate::storage::Transaction;
 use crate::types::value::{DataValue, Utf8Type};
-use crate::types::LogicalType;
+use crate::types::{ColumnId, LogicalType};

 macro_rules! try_alias {
     ($context:expr, $full_name:expr) => {
@@ -231,11 +231,11 @@ impl<'a, 'b, T: Transaction> Binder<'a, 'b, T> {
         sub_query: LogicalPlan,
     ) -> Result<(ScalarExpression, LogicalPlan), DatabaseError> {
         let mut alias_column = ColumnCatalog::clone(&column);
-        alias_column.set_ref_table(self.context.temp_table(), 0);
+        alias_column.set_ref_table(self.context.temp_table(), ColumnId::new());

         let alias_expr = ScalarExpression::Alias {
             expr: Box::new(ScalarExpression::ColumnRef(column)),
-            alias: AliasType::Expr(Box::new(ScalarExpression::ColumnRef(Arc::new(
+            alias: AliasType::Expr(Box::new(ScalarExpression::ColumnRef(ColumnRef::from(
                 alias_column,
             )))),
         };
@@ -246,7 +246,7 @@ impl<'a, 'b, T: Transaction> Binder<'a, 'b, T> {
     fn bind_subquery(
         &mut self,
         subquery: &Query,
-    ) -> Result<(LogicalPlan, Arc<ColumnCatalog>), DatabaseError> {
+    ) -> Result<(LogicalPlan, ColumnRef), DatabaseError> {
         let BinderContext {
             table_cache,
             transaction,
@@ -601,13 +601,13 @@ impl<'a, 'b, T: Transaction> Binder<'a, 'b, T> {
         if let Some(function) = self.context.scala_functions.get(&summary) {
             return Ok(ScalarExpression::ScalaFunction(ScalarFunction {
                 args,
-                inner: function.clone(),
+                inner: ArcScalarFunctionImpl(function.clone()),
             }));
         }
         if let Some(function) = self.context.table_functions.get(&summary) {
             return Ok(ScalarExpression::TableFunction(TableFunction {
                 args,
-                inner: function.clone(),
+                inner: ArcTableFunctionImpl(function.clone()),
             }));
         }
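The hard-coded `0` previously passed to `set_ref_table` is replaced by `ColumnId::new()`, matching the `refactor: ColumnId -> Ulid` item in the commit message and the new `ulid` dependency in Cargo.toml. The actual `ColumnId` definition in `src/types` is not shown in this excerpt; assuming it is now backed by `ulid::Ulid`, the change amounts to something like this sketch:

use ulid::Ulid;

// Assumed shape only: ColumnId as a thin wrapper over a ULID (the real code
// may instead use a plain type alias). Each new() call yields a fresh
// 128-bit identifier instead of the old placeholder 0, and the "serde"
// feature of the ulid crate keeps it Serialize/Deserialize.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
pub struct ColumnId(Ulid);

impl ColumnId {
    pub fn new() -> Self {
        ColumnId(Ulid::new())
    }
}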

src/binder/insert.rs

Lines changed: 2 additions & 0 deletions
@@ -19,6 +19,7 @@ impl<'a, 'b, T: Transaction> Binder<'a, 'b, T> {
         idents: &[Ident],
         expr_rows: &Vec<Vec<Expr>>,
         is_overwrite: bool,
+        is_mapping_by_name: bool,
     ) -> Result<LogicalPlan, DatabaseError> {
         // FIXME: Make it better to detect the current BindStep
         self.context.allow_default = true;
@@ -97,6 +98,7 @@ impl<'a, 'b, T: Transaction> Binder<'a, 'b, T> {
             Operator::Insert(InsertOperator {
                 table_name,
                 is_overwrite,
+                is_mapping_by_name,
             }),
             vec![values_plan],
         ))
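The new `is_mapping_by_name` flag is threaded from `bind_insert` into `InsertOperator`, which is how the `fix: column miss match on Insert` item reaches the planner. The executor side is outside this excerpt; purely as an illustration of the distinction (names and signature hypothetical, not taken from this repository), by-name versus positional column mapping could look like:

// Illustrative sketch only: decide, for each inserted value, which table
// column it belongs to, depending on the mapping mode chosen by the binder.
fn resolve_column_order(
    table_columns: &[String],
    insert_idents: &[String],
    is_mapping_by_name: bool,
) -> Vec<usize> {
    if is_mapping_by_name {
        // Match values to table columns by name, ignoring the order in which
        // the INSERT statement listed its columns.
        insert_idents
            .iter()
            .map(|name| {
                table_columns
                    .iter()
                    .position(|c| c == name)
                    .expect("unknown column in INSERT column list")
            })
            .collect()
    } else {
        // Positional mapping: the i-th value goes to the i-th table column.
        (0..insert_idents.len()).collect()
    }
}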
