diff --git a/derive/src/lib.rs b/derive/src/lib.rs index 8b011d6e..0b6c47ef 100644 --- a/derive/src/lib.rs +++ b/derive/src/lib.rs @@ -13,24 +13,39 @@ fn column_names(data: &DataStruct, cx: &Ctxt, container: &Container) -> Result { let rename_rule = container.rename_all_rules().deserialize; - let column_names_iter = fields + + let chain_iters = fields .named .iter() .enumerate() - .map(|(index, field)| Field::from_ast(cx, index, field, None, &SerdeDefault::None)) - .filter(|field| !field.skip_serializing() && !field.skip_deserializing()) - .map(|field| { - rename_rule - .apply_to_field(field.name().serialize_name()) - .to_string() + .map(|(index, field)| { + ( + Field::from_ast(cx, index, field, None, &SerdeDefault::None), + &field.ty, + ) + }) + .filter(|(field, _)| !(field.skip_serializing() || field.skip_deserializing())) + .map(|(field_meta, ty)| { + if field_meta.flatten() { + quote! { + <#ty as clickhouse::Row>::column_names().into_iter() + } + } else { + let column_name = rename_rule + .apply_to_field(field_meta.name().serialize_name()) + .to_string(); + quote! { + std::iter::once(#column_name) + } + } }); quote! { - &[#( #column_names_iter,)*] + std::iter::empty() #(.chain(#chain_iters))* } } Fields::Unnamed(_) => { - quote! { &[] } + quote! { [] } } Fields::Unit => unreachable!("checked by the caller"), }) @@ -94,8 +109,8 @@ fn row_impl(input: DeriveInput) -> Result { #[automatically_derived] impl #impl_generics clickhouse::Row for #name #ty_generics #where_clause { const NAME: &'static str = stringify!(#name); - const COLUMN_NAMES: &'static [&'static str] = #column_names; - const COLUMN_COUNT: usize = ::COLUMN_NAMES.len(); + fn column_names() -> impl IntoIterator { #column_names } + fn column_count() -> usize { ::column_names().into_iter().count() } const KIND: clickhouse::_priv::RowKind = clickhouse::_priv::RowKind::Struct; type Value<'__v> = #value; diff --git a/derive/src/tests/snapshots/generic_borrowed_row-2.snap b/derive/src/tests/snapshots/generic_borrowed_row-2.snap index 2d6511e5..fc7b09f8 100644 --- a/derive/src/tests/snapshots/generic_borrowed_row-2.snap +++ b/derive/src/tests/snapshots/generic_borrowed_row-2.snap @@ -11,8 +11,12 @@ struct Sample<'a, A, B> { #[automatically_derived] impl<'a, A, B> clickhouse::Row for Sample<'a, A, B> { const NAME: &'static str = stringify!(Sample); - const COLUMN_NAMES: &'static [&'static str] = &["a", "b"]; - const COLUMN_COUNT: usize = ::COLUMN_NAMES.len(); + fn column_names() -> impl IntoIterator { + std::iter::empty().chain(std::iter::once("a")).chain(std::iter::once("b")) + } + fn column_count() -> usize { + ::column_names().into_iter().count() + } const KIND: clickhouse::_priv::RowKind = clickhouse::_priv::RowKind::Struct; type Value<'__v> = Sample<'__v, A, B>; } diff --git a/derive/src/tests/snapshots/generic_borrowed_row-3.snap b/derive/src/tests/snapshots/generic_borrowed_row-3.snap index 19f7d5fc..708817a5 100644 --- a/derive/src/tests/snapshots/generic_borrowed_row-3.snap +++ b/derive/src/tests/snapshots/generic_borrowed_row-3.snap @@ -17,8 +17,12 @@ where T: Clone, { const NAME: &'static str = stringify!(Sample); - const COLUMN_NAMES: &'static [&'static str] = &["a", "b"]; - const COLUMN_COUNT: usize = ::COLUMN_NAMES.len(); + fn column_names() -> impl IntoIterator { + std::iter::empty().chain(std::iter::once("a")).chain(std::iter::once("b")) + } + fn column_count() -> usize { + ::column_names().into_iter().count() + } const KIND: clickhouse::_priv::RowKind = clickhouse::_priv::RowKind::Struct; type Value<'__v> = Sample<'__v, T>; } diff --git a/derive/src/tests/snapshots/generic_borrowed_row.snap b/derive/src/tests/snapshots/generic_borrowed_row.snap index e0c8e897..026ffd40 100644 --- a/derive/src/tests/snapshots/generic_borrowed_row.snap +++ b/derive/src/tests/snapshots/generic_borrowed_row.snap @@ -11,8 +11,12 @@ struct Sample<'a, T> { #[automatically_derived] impl<'a, T> clickhouse::Row for Sample<'a, T> { const NAME: &'static str = stringify!(Sample); - const COLUMN_NAMES: &'static [&'static str] = &["a", "b"]; - const COLUMN_COUNT: usize = ::COLUMN_NAMES.len(); + fn column_names() -> impl IntoIterator { + std::iter::empty().chain(std::iter::once("a")).chain(std::iter::once("b")) + } + fn column_count() -> usize { + ::column_names().into_iter().count() + } const KIND: clickhouse::_priv::RowKind = clickhouse::_priv::RowKind::Struct; type Value<'__v> = Sample<'__v, T>; } diff --git a/derive/src/tests/snapshots/generic_owned_row-2.snap b/derive/src/tests/snapshots/generic_owned_row-2.snap index 225cad3d..8b8bb322 100644 --- a/derive/src/tests/snapshots/generic_owned_row-2.snap +++ b/derive/src/tests/snapshots/generic_owned_row-2.snap @@ -11,8 +11,12 @@ struct Sample { #[automatically_derived] impl clickhouse::Row for Sample { const NAME: &'static str = stringify!(Sample); - const COLUMN_NAMES: &'static [&'static str] = &["a", "b"]; - const COLUMN_COUNT: usize = ::COLUMN_NAMES.len(); + fn column_names() -> impl IntoIterator { + std::iter::empty().chain(std::iter::once("a")).chain(std::iter::once("b")) + } + fn column_count() -> usize { + ::column_names().into_iter().count() + } const KIND: clickhouse::_priv::RowKind = clickhouse::_priv::RowKind::Struct; type Value<'__v> = Self; } diff --git a/derive/src/tests/snapshots/generic_owned_row-3.snap b/derive/src/tests/snapshots/generic_owned_row-3.snap index 5d8d1694..eeb4d4a0 100644 --- a/derive/src/tests/snapshots/generic_owned_row-3.snap +++ b/derive/src/tests/snapshots/generic_owned_row-3.snap @@ -17,8 +17,12 @@ where T: Clone, { const NAME: &'static str = stringify!(Sample); - const COLUMN_NAMES: &'static [&'static str] = &["a", "b"]; - const COLUMN_COUNT: usize = ::COLUMN_NAMES.len(); + fn column_names() -> impl IntoIterator { + std::iter::empty().chain(std::iter::once("a")).chain(std::iter::once("b")) + } + fn column_count() -> usize { + ::column_names().into_iter().count() + } const KIND: clickhouse::_priv::RowKind = clickhouse::_priv::RowKind::Struct; type Value<'__v> = Self; } diff --git a/derive/src/tests/snapshots/generic_owned_row.snap b/derive/src/tests/snapshots/generic_owned_row.snap index 1e3f6862..f17eaabe 100644 --- a/derive/src/tests/snapshots/generic_owned_row.snap +++ b/derive/src/tests/snapshots/generic_owned_row.snap @@ -11,8 +11,12 @@ struct Sample { #[automatically_derived] impl clickhouse::Row for Sample { const NAME: &'static str = stringify!(Sample); - const COLUMN_NAMES: &'static [&'static str] = &["a", "b"]; - const COLUMN_COUNT: usize = ::COLUMN_NAMES.len(); + fn column_names() -> impl IntoIterator { + std::iter::empty().chain(std::iter::once("a")).chain(std::iter::once("b")) + } + fn column_count() -> usize { + ::column_names().into_iter().count() + } const KIND: clickhouse::_priv::RowKind = clickhouse::_priv::RowKind::Struct; type Value<'__v> = Self; } diff --git a/derive/src/tests/snapshots/serde_rename.snap b/derive/src/tests/snapshots/serde_rename.snap index 28b3ab22..20af346e 100644 --- a/derive/src/tests/snapshots/serde_rename.snap +++ b/derive/src/tests/snapshots/serde_rename.snap @@ -14,8 +14,15 @@ struct Sample { #[automatically_derived] impl clickhouse::Row for Sample { const NAME: &'static str = stringify!(Sample); - const COLUMN_NAMES: &'static [&'static str] = &["a", "items.a", "items.b"]; - const COLUMN_COUNT: usize = ::COLUMN_NAMES.len(); + fn column_names() -> impl IntoIterator { + std::iter::empty() + .chain(std::iter::once("a")) + .chain(std::iter::once("items.a")) + .chain(std::iter::once("items.b")) + } + fn column_count() -> usize { + ::column_names().into_iter().count() + } const KIND: clickhouse::_priv::RowKind = clickhouse::_priv::RowKind::Struct; type Value<'__v> = Self; } diff --git a/derive/src/tests/snapshots/serde_skip_deserializing.snap b/derive/src/tests/snapshots/serde_skip_deserializing.snap index 630f01d3..9a2d1179 100644 --- a/derive/src/tests/snapshots/serde_skip_deserializing.snap +++ b/derive/src/tests/snapshots/serde_skip_deserializing.snap @@ -12,8 +12,12 @@ struct Sample { #[automatically_derived] impl clickhouse::Row for Sample { const NAME: &'static str = stringify!(Sample); - const COLUMN_NAMES: &'static [&'static str] = &["a"]; - const COLUMN_COUNT: usize = ::COLUMN_NAMES.len(); + fn column_names() -> impl IntoIterator { + std::iter::empty().chain(std::iter::once("a")) + } + fn column_count() -> usize { + ::column_names().into_iter().count() + } const KIND: clickhouse::_priv::RowKind = clickhouse::_priv::RowKind::Struct; type Value<'__v> = Self; } diff --git a/derive/src/tests/snapshots/serde_skip_serializing.snap b/derive/src/tests/snapshots/serde_skip_serializing.snap index 4188c568..aeca1bd2 100644 --- a/derive/src/tests/snapshots/serde_skip_serializing.snap +++ b/derive/src/tests/snapshots/serde_skip_serializing.snap @@ -12,8 +12,12 @@ struct Sample { #[automatically_derived] impl clickhouse::Row for Sample { const NAME: &'static str = stringify!(Sample); - const COLUMN_NAMES: &'static [&'static str] = &["a"]; - const COLUMN_COUNT: usize = ::COLUMN_NAMES.len(); + fn column_names() -> impl IntoIterator { + std::iter::empty().chain(std::iter::once("a")) + } + fn column_count() -> usize { + ::column_names().into_iter().count() + } const KIND: clickhouse::_priv::RowKind = clickhouse::_priv::RowKind::Struct; type Value<'__v> = Self; } diff --git a/derive/src/tests/snapshots/simple_borrowed_row.snap b/derive/src/tests/snapshots/simple_borrowed_row.snap index 2cb01c26..eadd0114 100644 --- a/derive/src/tests/snapshots/simple_borrowed_row.snap +++ b/derive/src/tests/snapshots/simple_borrowed_row.snap @@ -11,8 +11,12 @@ struct Sample<'a> { #[automatically_derived] impl<'a> clickhouse::Row for Sample<'a> { const NAME: &'static str = stringify!(Sample); - const COLUMN_NAMES: &'static [&'static str] = &["a", "b"]; - const COLUMN_COUNT: usize = ::COLUMN_NAMES.len(); + fn column_names() -> impl IntoIterator { + std::iter::empty().chain(std::iter::once("a")).chain(std::iter::once("b")) + } + fn column_count() -> usize { + ::column_names().into_iter().count() + } const KIND: clickhouse::_priv::RowKind = clickhouse::_priv::RowKind::Struct; type Value<'__v> = Sample<'__v>; } diff --git a/derive/src/tests/snapshots/simple_owned_row.snap b/derive/src/tests/snapshots/simple_owned_row.snap index 2140a3f9..5e18b5d2 100644 --- a/derive/src/tests/snapshots/simple_owned_row.snap +++ b/derive/src/tests/snapshots/simple_owned_row.snap @@ -11,8 +11,12 @@ struct Sample { #[automatically_derived] impl clickhouse::Row for Sample { const NAME: &'static str = stringify!(Sample); - const COLUMN_NAMES: &'static [&'static str] = &["a", "b"]; - const COLUMN_COUNT: usize = ::COLUMN_NAMES.len(); + fn column_names() -> impl IntoIterator { + std::iter::empty().chain(std::iter::once("a")).chain(std::iter::once("b")) + } + fn column_count() -> usize { + ::column_names().into_iter().count() + } const KIND: clickhouse::_priv::RowKind = clickhouse::_priv::RowKind::Struct; type Value<'__v> = Self; } diff --git a/src/row.rs b/src/row.rs index 2181b0f8..67e30e47 100644 --- a/src/row.rs +++ b/src/row.rs @@ -29,9 +29,9 @@ pub trait Row { const NAME: &'static str; // TODO: different list for SELECT/INSERT (de/ser) #[doc(hidden)] - const COLUMN_NAMES: &'static [&'static str]; + fn column_names() -> impl IntoIterator; #[doc(hidden)] - const COLUMN_COUNT: usize; + fn column_count() -> usize; #[doc(hidden)] const KIND: RowKind; #[doc(hidden)] @@ -220,8 +220,8 @@ macro_rules! impl_row_for_tuple { ($i:ident $($other:ident)+) => { impl<$i: Row, $($other: Primitive),+> Row for ($i, $($other),+) { const NAME: &'static str = $i::NAME; - const COLUMN_NAMES: &'static [&'static str] = $i::COLUMN_NAMES; - const COLUMN_COUNT: usize = $i::COLUMN_COUNT + count_tokens!($($other)*); + fn column_names() -> impl IntoIterator { $i::column_names() } + fn column_count() -> usize { $i::column_count() + count_tokens!($($other)*) } const KIND: RowKind = RowKind::Tuple; type Value<'a> = Self; @@ -237,8 +237,12 @@ impl Primitive for () {} impl Row for P { const NAME: &'static str = stringify!(P); - const COLUMN_NAMES: &'static [&'static str] = &[]; - const COLUMN_COUNT: usize = 1; + fn column_names() -> impl IntoIterator { + [] + } + fn column_count() -> usize { + 1 + } const KIND: RowKind = RowKind::Primitive; type Value<'a> = Self; @@ -248,8 +252,12 @@ impl_row_for_tuple!(T0 T1 T2 T3 T4 T5 T6 T7 T8); impl Row for Vec { const NAME: &'static str = "Vec"; - const COLUMN_NAMES: &'static [&'static str] = &[]; - const COLUMN_COUNT: usize = 1; + fn column_names() -> impl IntoIterator { + [] + } + fn column_count() -> usize { + 1 + } const KIND: RowKind = RowKind::Vec; type Value<'a> = Self; @@ -257,20 +265,21 @@ impl Row for Vec { /// Collects all field names in depth and joins them with comma. pub(crate) fn join_column_names() -> Option { - if R::COLUMN_NAMES.is_empty() { + if R::column_names().into_iter().next().is_none() { return None; } - let out = R::COLUMN_NAMES - .iter() - .enumerate() - .fold(String::new(), |mut res, (idx, name)| { - if idx > 0 { - res.push(','); - } - sql::escape::identifier(name, &mut res).expect("impossible"); - res - }); + let out = + R::column_names() + .into_iter() + .enumerate() + .fold(String::new(), |mut res, (idx, name)| { + if idx > 0 { + res.push(','); + } + sql::escape::identifier(name, &mut res).expect("impossible"); + res + }); Some(out) } @@ -381,4 +390,26 @@ mod tests { assert_eq!(join_column_names::().unwrap(), "`type`,`if`"); } + + #[test] + fn it_flattens() { + use serde::Serialize; + + #[derive(Row, Serialize)] + #[allow(dead_code)] + struct Inner { + a: u32, + b: u32, + } + + #[derive(Row, Serialize)] + #[allow(dead_code)] + struct Outer { + #[serde(flatten)] + inner: Inner, + c: u32, + } + + assert_eq!(join_column_names::().unwrap(), "`a`,`b`,`c`"); + } } diff --git a/src/row_metadata.rs b/src/row_metadata.rs index 1f2a921f..5c87e893 100644 --- a/src/row_metadata.rs +++ b/src/row_metadata.rs @@ -22,7 +22,7 @@ static ROW_METADATA_CACHE: OnceCell = OnceCell::const_ne #[derive(Debug, PartialEq)] pub(crate) enum AccessType { WithSeqAccess, - WithMapAccess(Vec), + WithMapAccess(Vec, Vec<&'static str>), } /// Contains a vector of [`Column`] objects parsed from the beginning @@ -57,13 +57,13 @@ impl RowMetadata { AccessType::WithSeqAccess // ignored } RowKind::Tuple => { - if T::COLUMN_COUNT != columns.len() { + if T::column_count() != columns.len() { panic!( "While processing a tuple row: database schema has {} columns, \ but the tuple definition has {} fields in total.\ \n#### All schema columns:\n{}", columns.len(), - T::COLUMN_COUNT, + T::column_count(), join_panic_schema_hint(&columns), ); } @@ -82,23 +82,25 @@ impl RowMetadata { AccessType::WithSeqAccess // ignored } RowKind::Struct => { - if columns.len() != T::COLUMN_NAMES.len() { + if columns.len() != T::column_names().into_iter().count() { panic!( "While processing struct {}: database schema has {} columns, \ but the struct definition has {} fields.\ \n#### All struct fields:\n{}\n#### All schema columns:\n{}", T::NAME, columns.len(), - T::COLUMN_NAMES.len(), - join_panic_schema_hint(T::COLUMN_NAMES), + T::column_names().into_iter().count(), + join_panic_schema_hint(&T::column_names().into_iter().collect::>()), join_panic_schema_hint(&columns), ); } - let mut mapping = Vec::with_capacity(T::COLUMN_NAMES.len()); + let mut mapping = Vec::with_capacity(T::column_names().into_iter().count()); let mut expected_index = 0; let mut should_use_map = false; for col in &columns { - if let Some(index) = T::COLUMN_NAMES.iter().position(|field| col.name == *field) + if let Some(index) = T::column_names() + .into_iter() + .position(|field| col.name == *field) { if index != expected_index { should_use_map = true @@ -112,13 +114,15 @@ impl RowMetadata { \n#### All struct fields:\n{}\n#### All schema columns:\n{}", T::NAME, col, - join_panic_schema_hint(T::COLUMN_NAMES), + join_panic_schema_hint( + &T::column_names().into_iter().collect::>() + ), join_panic_schema_hint(&columns), ); } } if should_use_map { - AccessType::WithMapAccess(mapping) + AccessType::WithMapAccess(mapping, T::column_names().into_iter().collect()) } else { AccessType::WithSeqAccess } @@ -133,7 +137,7 @@ impl RowMetadata { #[inline] pub(crate) fn get_schema_index(&self, struct_idx: usize) -> usize { match &self.access_type { - AccessType::WithMapAccess(mapping) => { + AccessType::WithMapAccess(mapping, _) => { if struct_idx < mapping.len() { mapping[struct_idx] } else { @@ -145,9 +149,22 @@ impl RowMetadata { } } + #[inline] + pub(crate) fn get_field_name(&self, struct_idx: usize) -> Option<&'static str> { + match &self.access_type { + AccessType::WithMapAccess(mapping, field_names) => { + let Some(mapped) = mapping.get(struct_idx) else { + return None; + }; + field_names.get(*mapped).copied() + } + AccessType::WithSeqAccess => None, + } + } + #[inline] pub(crate) fn is_field_order_wrong(&self) -> bool { - matches!(self.access_type, AccessType::WithMapAccess(_)) + matches!(self.access_type, AccessType::WithMapAccess(_, _)) } } diff --git a/src/rowbinary/de.rs b/src/rowbinary/de.rs index 88bf9823..65a2cb4d 100644 --- a/src/rowbinary/de.rs +++ b/src/rowbinary/de.rs @@ -61,6 +61,7 @@ where { input: &'cursor mut &'data [u8], validator: V, + is_inner: bool, _marker: PhantomData, } @@ -72,6 +73,7 @@ where Self { input, validator, + is_inner: false, _marker: PhantomData, } } @@ -83,6 +85,7 @@ where RowBinaryDeserializer { input: self.input, validator: self.validator.validate(serde_type), + is_inner: true, _marker: PhantomData, } } @@ -156,8 +159,35 @@ where impl_num!(f64, deserialize_f64, visit_f64, get_f64_le, SerdeType::F64); #[inline(always)] - fn deserialize_any>(self, _: V) -> Result { - Err(Error::DeserializeAnyNotSupported) + fn deserialize_any>(self, visitor: V) -> Result { + match self + .validator + .peek() + .ok_or(Error::DeserializeAnyNotSupported)? + { + SerdeType::Bool => self.deserialize_bool(visitor), + SerdeType::I8 => self.deserialize_i8(visitor), + SerdeType::I16 => self.deserialize_i16(visitor), + SerdeType::I32 => self.deserialize_i32(visitor), + SerdeType::I64 => self.deserialize_i64(visitor), + SerdeType::I128 => self.deserialize_i128(visitor), + SerdeType::U8 => self.deserialize_u8(visitor), + SerdeType::U16 => self.deserialize_u16(visitor), + SerdeType::U32 => self.deserialize_u32(visitor), + SerdeType::U64 => self.deserialize_u64(visitor), + SerdeType::U128 => self.deserialize_u128(visitor), + SerdeType::F32 => self.deserialize_f32(visitor), + SerdeType::F64 => self.deserialize_f64(visitor), + SerdeType::Str => self.deserialize_str(visitor), + SerdeType::String => self.deserialize_string(visitor), + SerdeType::Option => self.deserialize_option(visitor), + SerdeType::Enum => self.deserialize_enum("", &[], visitor), + SerdeType::Bytes(_) => self.deserialize_bytes(visitor), + SerdeType::ByteBuf(_) => self.deserialize_byte_buf(visitor), + SerdeType::Tuple(len) => self.deserialize_tuple(len, visitor), + SerdeType::Seq(_) => self.deserialize_seq(visitor), + SerdeType::Map(_) => self.deserialize_map(visitor), + } } #[inline(always)] @@ -263,12 +293,19 @@ where #[inline(always)] fn deserialize_map>(self, visitor: V) -> Result { - let len = self.read_size()?; - let deserializer = &mut self.inner(SerdeType::Map(len)); - visitor.visit_map(RowBinaryMapAccess { - deserializer, - remaining: len, - }) + if self.is_inner { + let len = self.read_size()?; + let deserializer = &mut self.inner(SerdeType::Map(len)); + visitor.visit_map(RowBinaryMapAccess { + deserializer, + remaining: len, + }) + } else { + visitor.visit_map(RowBinaryStructAsMapAccess { + deserializer: self, + current_field_idx: 0, + }) + } } #[inline(always)] @@ -278,6 +315,7 @@ where fields: &'static [&'static str], visitor: V, ) -> Result { + self.is_inner = true; if !self.validator.is_field_order_wrong() { visitor.visit_seq(RowBinarySeqAccess { deserializer: self, @@ -287,7 +325,6 @@ where visitor.visit_map(RowBinaryStructAsMapAccess { deserializer: self, current_field_idx: 0, - fields, }) } } @@ -417,7 +454,6 @@ where { deserializer: &'de mut RowBinaryDeserializer<'cursor, 'data, R, Validator>, current_field_idx: usize, - fields: &'static [&'static str], } struct StructFieldIdentifier(&'static str); @@ -475,14 +511,14 @@ where where K: DeserializeSeed<'data>, { - if self.current_field_idx >= self.fields.len() { - return Ok(None); - } - let schema_index = self + let Some(field_name) = self .deserializer .validator - .get_schema_index(self.current_field_idx); - let field_id = StructFieldIdentifier(self.fields[schema_index]); + .get_field_name(self.current_field_idx) + else { + return Ok(None); + }; + let field_id = StructFieldIdentifier(field_name); self.current_field_idx += 1; seed.deserialize(field_id).map(Some) } @@ -495,7 +531,8 @@ where } fn size_hint(&self) -> Option { - Some(self.fields.len()) + // Some(self.fields.len()) + None } } diff --git a/src/rowbinary/tests.rs b/src/rowbinary/tests.rs index 12ee127a..d1e36e39 100644 --- a/src/rowbinary/tests.rs +++ b/src/rowbinary/tests.rs @@ -48,30 +48,34 @@ struct Sample<'a> { // clickhouse_derive is not working here impl Row for Sample<'_> { const NAME: &'static str = "Sample"; - const COLUMN_NAMES: &'static [&'static str] = &[ - "int8", - "int32", - "int64", - "uint8", - "uint32", - "uint64", - "float32", - "float64", - "datetime", - "datetime64", - "time32", - "time64", - "decimal64", - "decimal128", - "string", - "blob", - "optional_decimal64", - "optional_datetime", - "fixed_string", - "array", - "boolean", - ]; - const COLUMN_COUNT: usize = 21; + fn column_names() -> impl IntoIterator { + [ + "int8", + "int32", + "int64", + "uint8", + "uint32", + "uint64", + "float32", + "float64", + "datetime", + "datetime64", + "time32", + "time64", + "decimal64", + "decimal128", + "string", + "blob", + "optional_decimal64", + "optional_datetime", + "fixed_string", + "array", + "boolean", + ] + } + fn column_count() -> usize { + 21 + } const KIND: crate::row::RowKind = crate::row::RowKind::Struct; type Value<'a> = Sample<'a>; diff --git a/src/rowbinary/validation.rs b/src/rowbinary/validation.rs index abecd5e7..7469f61b 100644 --- a/src/rowbinary/validation.rs +++ b/src/rowbinary/validation.rs @@ -28,6 +28,8 @@ pub(crate) trait SchemaValidator: Sized { /// It is used only if the crate detects that while the field names and the types are correct, /// the field order in the struct does not match the column order in the database schema. fn get_schema_index(&self, struct_idx: usize) -> usize; + fn get_field_name(&self, struct_idx: usize) -> Option<&'static str>; + fn peek(&self) -> Option; } pub(crate) struct DataTypeValidator<'cursor, R: Row> { @@ -184,6 +186,34 @@ impl<'cursor, R: Row> SchemaValidator for DataTypeValidator<'cursor, R> { self.metadata.get_schema_index(struct_idx) } + fn get_field_name(&self, struct_idx: usize) -> Option<&'static str> { + self.metadata.get_field_name(struct_idx) + } + + fn peek(&self) -> Option { + match R::KIND { + RowKind::Primitive => { + let data_type = &self.metadata.columns[0].data_type; + Some(data_type.into()) + } + RowKind::Tuple => Some(SerdeType::Tuple(self.metadata.columns.len()).into()), + RowKind::Vec => { + let data_type = &self.metadata.columns[0].data_type; + match data_type { + DataTypeNode::Array(inner_type) => Some(From::from(&**inner_type)), + _ => panic!( + "Expected Array type when validating root level sequence, but got {}", + self.metadata.columns[0].data_type + ), + } + } + RowKind::Struct => { + let current_column = &self.metadata.columns[self.current_column_idx]; + Some(From::from(¤t_column.data_type)) + } + } + } + #[cold] fn validate_identifier(&mut self, _value: T) { unreachable!() @@ -393,6 +423,50 @@ impl<'cursor, R: Row> SchemaValidator for Option usize { unreachable!() } + + fn get_field_name(&self, _struct_idx: usize) -> Option<&'static str> { + unreachable!() + } + + fn peek(&self) -> Option { + let inner = self.as_ref()?; + match &inner.kind { + InnerDataTypeValidatorKind::Map(kv, state) => match state { + MapValidatorState::Key => Some(From::from(&*kv[0])), + MapValidatorState::Value => Some(From::from(&*kv[1])), + }, + InnerDataTypeValidatorKind::MapAsSequence(kv, state) => match state { + MapAsSequenceValidatorState::Tuple | MapAsSequenceValidatorState::Key => { + Some(From::from(&*kv[0])) + } + MapAsSequenceValidatorState::Value => Some(From::from(&*kv[1])), + }, + InnerDataTypeValidatorKind::Array(inner_type) => Some(From::from(*inner_type)), + InnerDataTypeValidatorKind::Nullable(inner_type) => Some(From::from(*inner_type)), + InnerDataTypeValidatorKind::Tuple(elements_types) => { + Some(From::from(&elements_types[0])) + } + InnerDataTypeValidatorKind::FixedString(_len) => None, + InnerDataTypeValidatorKind::RootTuple(columns, current_index) => { + Some(From::from(&columns[*current_index].data_type)) + } + InnerDataTypeValidatorKind::RootArray(inner_data_type) => { + Some(From::from(*inner_data_type)) + } + InnerDataTypeValidatorKind::Variant(possible_types, state) => match state { + VariantValidationState::Pending => { + unreachable!() + } + VariantValidationState::Identifier(value) => { + let data_type = &possible_types[*value as usize]; + Some(From::from(data_type)) + } + }, + InnerDataTypeValidatorKind::Enum(_values_map) => { + unreachable!() + } + } + } } impl Drop for InnerDataTypeValidator<'_, '_, R> { @@ -638,6 +712,14 @@ impl SchemaValidator for () { fn get_schema_index(&self, _struct_idx: usize) -> usize { unreachable!() } + + fn get_field_name(&self, _struct_idx: usize) -> Option<&'static str> { + unreachable!() + } + + fn peek(&self) -> Option { + unreachable!() + } } /// Which Serde data type (De)serializer used for the given type. @@ -677,6 +759,51 @@ pub(crate) enum SerdeType { // IgnoredAny, } +impl From<&DataTypeNode> for SerdeType { + fn from(data_type: &DataTypeNode) -> Self { + match dbg!(data_type.remove_low_cardinality()) { + DataTypeNode::Bool => SerdeType::Bool, + DataTypeNode::Int8 | DataTypeNode::Enum(EnumType::Enum8, _) => SerdeType::I8, + DataTypeNode::Int16 | DataTypeNode::Enum(EnumType::Enum16, _) => SerdeType::I16, + DataTypeNode::Int32 + | DataTypeNode::Date32 + | DataTypeNode::Time + | DataTypeNode::Decimal(_, _, DecimalType::Decimal32) => SerdeType::I32, + DataTypeNode::Int64 + | DataTypeNode::DateTime64(_, _) + | DataTypeNode::Time64(_) + | DataTypeNode::Decimal(_, _, DecimalType::Decimal64) + | DataTypeNode::Interval(_) => SerdeType::I64, + DataTypeNode::Int128 | DataTypeNode::Decimal(_, _, DecimalType::Decimal128) => { + SerdeType::I128 + } + DataTypeNode::UInt8 => SerdeType::U8, + DataTypeNode::UInt16 | DataTypeNode::Date => SerdeType::U16, + DataTypeNode::UInt32 | DataTypeNode::DateTime(_) | DataTypeNode::IPv4 => SerdeType::U32, + DataTypeNode::UInt64 => SerdeType::U64, + DataTypeNode::UInt128 => SerdeType::U128, + DataTypeNode::Float32 => SerdeType::F32, + DataTypeNode::Float64 => SerdeType::F64, + DataTypeNode::String | DataTypeNode::JSON => SerdeType::String, + DataTypeNode::Nullable(_) => SerdeType::Option, + DataTypeNode::Array(_) + | DataTypeNode::Ring + | DataTypeNode::Polygon + | DataTypeNode::MultiPolygon + | DataTypeNode::LineString + | DataTypeNode::MultiLineString => SerdeType::Seq(0), + DataTypeNode::Tuple(elements) => SerdeType::Tuple(elements.len()), + DataTypeNode::FixedString(len) => SerdeType::Tuple(*len), + DataTypeNode::IPv6 => SerdeType::Tuple(16), + DataTypeNode::UUID => SerdeType::Tuple(UUID_TUPLE_ELEMENTS.len()), + DataTypeNode::Point => SerdeType::Tuple(POINT_TUPLE_ELEMENTS.len()), + DataTypeNode::Map(_) => SerdeType::Map(0), + DataTypeNode::Variant(_) => SerdeType::Enum, + _ => unimplemented!(), + } + } +} + impl Display for SerdeType { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self {