diff --git a/scylla-cql/src/_macro_internal.rs b/scylla-cql/src/_macro_internal.rs index 4327ab72c8..09c9d7564b 100644 --- a/scylla-cql/src/_macro_internal.rs +++ b/scylla-cql/src/_macro_internal.rs @@ -5,7 +5,7 @@ pub use crate::deserialize::row::{ BuiltinDeserializationError as BuiltinRowDeserializationError, BuiltinDeserializationErrorKind as BuiltinRowDeserializationErrorKind, BuiltinTypeCheckErrorKind as DeserBuiltinRowTypeCheckErrorKind, ColumnIterator, DeserializeRow, - deser_error_replace_rust_name as row_deser_error_replace_rust_name, + RawColumn, deser_error_replace_rust_name as row_deser_error_replace_rust_name, mk_deser_err as mk_row_deser_err, mk_typck_err as mk_row_typck_err, }; pub use crate::deserialize::value::{ diff --git a/scylla-cql/src/deserialize/row_tests.rs b/scylla-cql/src/deserialize/row_tests.rs index 080771d984..8bdd8e7e60 100644 --- a/scylla-cql/src/deserialize/row_tests.rs +++ b/scylla-cql/src/deserialize/row_tests.rs @@ -114,9 +114,11 @@ fn test_struct_deserialization_loose_ordering() { d: i32, #[scylla(default_when_null)] e: &'a str, + #[scylla(allow_missing)] + f: &'a str, } - // Original order of columns + // Original order of columns without field f let specs = &[ spec("a", ColumnType::Native(NativeType::Text)), spec("b", ColumnType::Native(NativeType::Int)), @@ -133,17 +135,19 @@ fn test_struct_deserialization_loose_ordering() { c: String::new(), d: 0, e: "def", + f: "", } ); - // Different order of columns - should still work + // Different order of columns with field f - should still work let specs = &[ spec("e", ColumnType::Native(NativeType::Text)), spec("b", ColumnType::Native(NativeType::Int)), spec("d", ColumnType::Native(NativeType::Int)), + spec("f", ColumnType::Native(NativeType::Text)), spec("a", ColumnType::Native(NativeType::Text)), ]; - let byts = serialize_cells([None, val_int(123), None, val_str("abc")]); + let byts = serialize_cells([None, val_int(123), None, val_str("efg"), val_str("abc")]); let row = 
deserialize::>(specs, &byts).unwrap(); assert_eq!( row, @@ -153,6 +157,7 @@ fn test_struct_deserialization_loose_ordering() { c: String::new(), d: 0, e: "", + f: "efg", } ); @@ -189,11 +194,13 @@ fn test_struct_deserialization_strict_ordering() { c: String, #[scylla(default_when_null)] d: i32, + #[scylla(allow_missing)] + f: i32, #[scylla(default_when_null)] e: &'a str, } - // Correct order of columns + // Correct order of columns without field f let specs = &[ spec("a", ColumnType::Native(NativeType::Text)), spec("b", ColumnType::Native(NativeType::Int)), @@ -210,6 +217,35 @@ fn test_struct_deserialization_strict_ordering() { c: String::new(), d: 0, e: "def", + f: 0, + } + ); + + // Correct order of columns with field f + let specs = &[ + spec("a", ColumnType::Native(NativeType::Text)), + spec("b", ColumnType::Native(NativeType::Int)), + spec("d", ColumnType::Native(NativeType::Int)), + spec("f", ColumnType::Native(NativeType::Int)), + spec("e", ColumnType::Native(NativeType::Text)), + ]; + let byts = serialize_cells([ + val_str("abc"), + val_int(123), + None, + val_int(234), + val_str("def"), + ]); + let row = deserialize::>(specs, &byts).unwrap(); + assert_eq!( + row, + MyRow { + a: "abc", + b: Some(123), + c: String::new(), + d: 0, + e: "def", + f: 234, } ); diff --git a/scylla-macros/src/deserialize/row.rs b/scylla-macros/src/deserialize/row.rs index ab40bf3881..761a8df88b 100644 --- a/scylla-macros/src/deserialize/row.rs +++ b/scylla-macros/src/deserialize/row.rs @@ -2,6 +2,7 @@ use std::collections::HashMap; use darling::{FromAttributes, FromField}; use proc_macro2::Span; + use syn::ext::IdentExt; use syn::parse_quote; @@ -51,13 +52,19 @@ struct Field { #[darling(default)] default_when_null: bool, + // If true, then - if the column for this field is missing from the row's column specs + // - it will be initialized to Default::default(). 
+ #[darling(default)] + #[darling(rename = "allow_missing")] + default_when_missing: bool, + ident: Option, ty: syn::Type, } impl DeserializeCommonFieldAttrs for Field { fn needs_default(&self) -> bool { - self.skip || self.default_when_null + self.skip || self.default_when_null || self.default_when_missing } fn deserialize_target(&self) -> &syn::Type { @@ -137,7 +144,7 @@ fn validate_attrs(attrs: &StructAttrs, fields: &[Field]) -> Result<(), darling:: impl Field { // Returns whether this field is mandatory for deserialization. fn is_required(&self) -> bool { - !self.skip + !self.skip && !self.default_when_missing } // The name of the column corresponding to this Rust struct field @@ -175,19 +182,41 @@ impl StructDesc { struct TypeCheckAssumeOrderGenerator<'sd>(&'sd StructDesc); impl TypeCheckAssumeOrderGenerator<'_> { - fn generate_name_verification( + fn generate_field_validation( &self, field_index: usize, // These two indices can be different because of `skip` attribute column_index: usize, // applied to some field. field: &Field, - column_spec: &syn::Ident, - ) -> Option { - (!self.0.attrs.skip_name_checks).then(|| { - let macro_internal = self.0.struct_attrs().macro_internal_path(); - let rust_field_name = field.cql_name_literal(); + ) -> syn::Expr { + let skip_name_checks = self.0.attrs.skip_name_checks; + let default_when_missing = field.default_when_missing; + let (frame_lifetime, metadata_lifetime) = self.0.constraint_lifetimes(); + let macro_internal = self.0.struct_attrs().macro_internal_path(); + let rust_field_name = field.cql_name_literal(); + let field_deserialization: syn::Type = { + let typ = field.deserialize_target(); + if field.default_when_null { + parse_quote! { + ::std::option::Option<#typ> + } + } else { + parse_quote! { + #typ + } + } + }; + + let name_mismatch: syn::Expr = if field.default_when_missing { parse_quote! 
{ - if #column_spec.name() != #rust_field_name { + { + saved_col = ::std::option::Option::Some(next_col); + break 'verification; + } + } + } else { + parse_quote! { + { return ::std::result::Result::Err( #macro_internal::mk_row_typck_err::( column_types_iter(), @@ -195,13 +224,47 @@ impl TypeCheckAssumeOrderGenerator<'_> { field_index: #field_index, column_index: #column_index, rust_column_name: #rust_field_name, - db_column_name: ::std::borrow::ToOwned::to_owned(#column_spec.name()), + db_column_name: ::std::borrow::ToOwned::to_owned(next_col.name()), } ) ); } } - }) + }; + + let name_verification: Option = (!skip_name_checks).then(|| { + parse_quote! { + if next_col.name() != #rust_field_name { + #name_mismatch + } + } + }); + + parse_quote! { + 'field: { + let next_col = match saved_col.take().or_else(|| ::std::iter::Iterator::next(&mut col_iter)) { + ::std::option::Option::Some(col_spec) => col_spec, + // In case the Rust field allows default-initialisation and there are no more CQL fields, + // simply assume it's going to be default-initialised. + ::std::option::Option::None if #default_when_missing => break 'field, + ::std::option::Option::None => return ::std::result::Result::Err(wrong_column_count()), + }; + + 'verification: { + #name_verification + + <#field_deserialization as #macro_internal::DeserializeValue<#frame_lifetime, #metadata_lifetime>>::type_check(next_col.typ()) + .map_err(|err| #macro_internal::mk_row_typck_err::( + column_types_iter(), + #macro_internal::DeserBuiltinRowTypeCheckErrorKind::ColumnTypeCheckFailed { + column_index: #column_index, + column_name: ::std::borrow::ToOwned::to_owned(next_col.name()), + err, + } + ))?; + } + } + } } fn generate(&self) -> syn::ImplItemFn { @@ -209,38 +272,31 @@ impl TypeCheckAssumeOrderGenerator<'_> { // of the columns correspond fields' names/types. 
let macro_internal = self.0.struct_attrs().macro_internal_path(); - let (frame_lifetime, metadata_lifetime) = self.0.constraint_lifetimes(); - let required_fields_iter = || { + let required_fields_iter = || self.0.fields().iter().filter(|f| f.is_required()); + let required_fields_count = required_fields_iter().count(); + let required_fields_count_lit = + syn::LitInt::new(&required_fields_count.to_string(), Span::call_site()); + + let nonskipped_fields_iter = || { self.0 .fields() .iter() + // It is important that we enumerate **before** filtering, because otherwise we would not + // count the skipped fields, which might be confusing. .enumerate() - .filter(|(_, f)| f.is_required()) + .filter(|(_idx, f)| !f.skip) }; - let required_fields_count = required_fields_iter().count(); - let required_fields_idents: Vec<_> = (0..required_fields_count) + + let nonskippable_fields_idents: Vec<_> = (0..nonskipped_fields_iter().count()) .map(|i| quote::format_ident!("f_{}", i)) .collect(); - let name_verifications = required_fields_iter() - .zip(required_fields_idents.iter().enumerate()) - .map(|((field_idx, field), (col_idx, fidents))| { - self.generate_name_verification(field_idx, col_idx, field, fidents) - }); - let required_fields_deserializers = required_fields_iter().map(|(_, f)| -> syn::Type { - let typ = f.deserialize_target(); - if f.default_when_null { - parse_quote! { - ::std::option::Option<#typ> - } - } else { - parse_quote! { - #typ - } - } - }); - let numbers = 0usize..; + let field_validations = nonskipped_fields_iter() + .zip(nonskippable_fields_idents.iter().enumerate()) + .map(|((field_idx, field), (col_idx, ..))| { + self.generate_field_validation(field_idx, col_idx, field) + }); parse_quote! 
{ fn type_check( @@ -248,35 +304,34 @@ impl TypeCheckAssumeOrderGenerator<'_> { ) -> ::std::result::Result<(), #macro_internal::TypeCheckError> { let column_types_iter = || ::std::iter::Iterator::map(specs.iter(), |spec| ::std::clone::Clone::clone(spec.typ()).into_owned()); - match specs { - [#(#required_fields_idents),*] => { - #( - // Verify the name (unless `skip_name_checks' is specified) - #name_verifications + let wrong_column_count = || { + #macro_internal::mk_row_typck_err::( + column_types_iter(), + #macro_internal::DeserBuiltinRowTypeCheckErrorKind::WrongColumnCount { + rust_cols: #required_fields_count, + cql_cols: specs.len(), + } + ) + }; - // Verify the type - <#required_fields_deserializers as #macro_internal::DeserializeValue<#frame_lifetime, #metadata_lifetime>>::type_check(#required_fields_idents.typ()) - .map_err(|err| #macro_internal::mk_row_typck_err::( - column_types_iter(), - #macro_internal::DeserBuiltinRowTypeCheckErrorKind::ColumnTypeCheckFailed { - column_index: #numbers, - column_name: ::std::borrow::ToOwned::to_owned(#required_fields_idents.name()), - err, - } - ))?; - )* - ::std::result::Result::Ok(()) - }, - _ => ::std::result::Result::Err( - #macro_internal::mk_row_typck_err::( - column_types_iter(), - #macro_internal::DeserBuiltinRowTypeCheckErrorKind::WrongColumnCount { - rust_cols: #required_fields_count, - cql_cols: specs.len(), - } - ), - ), + if specs.len() < #required_fields_count_lit { + return ::std::result::Result::Err(wrong_column_count()); + } + + + let mut col_iter = specs.iter(); + let mut saved_col = ::std::option::Option::None::<&#macro_internal::ColumnSpec>; + #( + #field_validations + )* + + if let ::std::option::Option::Some(next_col) = saved_col + .take() + .or_else(|| ::std::iter::Iterator::next(&mut col_iter)) { + return ::std::result::Result::Err(wrong_column_count()); } + + ::std::result::Result::Ok(()) } } } @@ -296,15 +351,24 @@ impl DeserializeAssumeOrderGenerator<'_> { let deserializer = 
field.deserialize_target(); let (frame_lifetime, metadata_lifetime) = self.0.constraint_lifetimes(); - let name_check: Option = (!self.0.struct_attrs().skip_name_checks).then(|| parse_quote! { - if col.spec.name() != #cql_name_literal { - ::std::panic!( - "Typecheck should have prevented this scenario - field-column name mismatch! Rust field name {}, CQL column name {}", - #cql_name_literal, - col.spec.name() - ); + let name_mismatch: syn::Expr = if field.default_when_missing { + parse_quote! { + { + saved_col = ::std::option::Option::Some(col); + ::std::default::Default::default() + } } - }); + } else { + parse_quote! { + { + ::std::panic!( + "Typecheck should have prevented this scenario - field-column name mismatch! Rust field name {}, CQL column name {}", + #cql_name_literal, + col.spec.name() + ); + } + } + }; let deserialize_expr: syn::Expr = if field.default_when_null { parse_quote! { @@ -331,15 +395,45 @@ impl DeserializeAssumeOrderGenerator<'_> { } }; + let maybe_name_check_and_deserialize_or_save: syn::Expr = + if self.0.struct_attrs().skip_name_checks { + parse_quote! { + #deserialize_expr + } + } else { + parse_quote! { + if col.spec.name() == #cql_name_literal { + #deserialize_expr + } else { + #name_mismatch + } + } + }; + + let no_more_fields: syn::Expr = if field.default_when_missing { + parse_quote! { + ::std::default::Default::default() + } + } else { + parse_quote! { + // Type check has ensured that there are enough CQL columns in the row. + ::std::panic!("Typecheck should have prevented this scenario! 
Too few columns in the serialized data.") + let maybe_next_col = saved_col + .take() + .map(::std::result::Result::Ok) + .or_else(|| ::std::iter::Iterator::next(&mut row)) + .transpose() .map_err(#macro_internal::row_deser_error_replace_rust_name::)?; - #name_check - - #deserialize_expr + if let ::std::option::Option::Some(col) = maybe_next_col { + #maybe_name_check_and_deserialize_or_save + } else { + #no_more_fields + } } ) } @@ -359,6 +453,8 @@ impl DeserializeAssumeOrderGenerator<'_> { fn deserialize( mut row: #macro_internal::ColumnIterator<#frame_lifetime, #metadata_lifetime>, ) -> ::std::result::Result { + let mut saved_col = ::std::option::Option::None::<#macro_internal::RawColumn<#frame_lifetime, #metadata_lifetime>>; + ::std::result::Result::Ok(Self { #(#field_idents: #field_finalizers,)* }) @@ -543,11 +639,18 @@ impl DeserializeUnorderedGenerator<'_> { let deserialize_field = Self::deserialize_field_variable(field); let cql_name_literal = field.cql_name_literal(); - parse_quote! { - #deserialize_field.unwrap_or_else(|| ::std::panic!( - "column {} missing in DB row - type check should have prevented this!", - #cql_name_literal - )) + if field.default_when_missing { + // Generate Default::default if the field was missing + parse_quote! { + #deserialize_field.unwrap_or_default() + } + } else { + parse_quote! 
{ + #deserialize_field.unwrap_or_else(|| ::std::panic!( + "column {} missing in DB row - type check should have prevented this!", + #cql_name_literal + )) + } } } diff --git a/scylla-macros/src/deserialize/value.rs b/scylla-macros/src/deserialize/value.rs index bed81ac771..39ea8abf59 100644 --- a/scylla-macros/src/deserialize/value.rs +++ b/scylla-macros/src/deserialize/value.rs @@ -595,8 +595,7 @@ impl TypeCheckUnorderedGenerator<'_> { let visited_flag = Self::visited_flag_variable(field); let typ = field.deserialize_target(); let cql_name_literal = field.cql_name_literal(); - let decrement_if_required: Option = field - .is_required() + let decrement_if_required: Option = field.is_required() .then(|| parse_quote! {remaining_required_cql_fields -= 1;}); parse_quote! { @@ -662,6 +661,7 @@ impl TypeCheckUnorderedGenerator<'_> { .filter(|f| !f.skip) .map(|f| f.cql_name_literal()); let required_cql_field_count = rust_fields.iter().filter(|f| f.is_required()).count(); + let required_cql_field_count_lit = syn::LitInt::new(&required_cql_field_count.to_string(), Span::call_site()); let extract_cql_fields_expr = self.0.generate_extract_fields_from_type(parse_quote!(typ)); diff --git a/scylla-macros/src/lib.rs b/scylla-macros/src/lib.rs index 05a8ddcf78..9881735d81 100644 --- a/scylla-macros/src/lib.rs +++ b/scylla-macros/src/lib.rs @@ -376,7 +376,8 @@ mod deserialize; /// If set, the generated implementation will not verify the column names at /// all. Because it only works with `enforce_order`, it will deserialize first /// column into the first field, second column into the second field and so on. -/// It will still still verify that the column types and field types match. +/// It will still verify that the column types and field types match. +/// /// /// ## Field attributes /// @@ -395,6 +396,11 @@ mod deserialize; /// By default, the generated implementation will try to match the Rust field /// to a column with the same name. 
This attribute allows to match to a column /// with provided name. +/// +/// `#[scylla(allow_missing)]` +/// +/// If set, the generated implementation will not fail if the column corresponding to this field is missing. +/// Instead, it will initialize the field with `Default::default()`. #[proc_macro_derive(DeserializeRow, attributes(scylla))] pub fn deserialize_row_derive(tokens_input: TokenStream) -> TokenStream { match deserialize::row::deserialize_row_derive(tokens_input) { @@ -501,6 +507,7 @@ pub fn deserialize_row_derive(tokens_input: TokenStream) -> TokenStream { /// If more strictness is desired, this flag makes sure that no excess fields /// are present and forces error in case there are some. /// +/// /// ## Field attributes /// /// `#[scylla(skip)]` diff --git a/scylla-macros/src/serialize/row.rs b/scylla-macros/src/serialize/row.rs index ec114f43d0..4d3d4a3d38 100644 --- a/scylla-macros/src/serialize/row.rs +++ b/scylla-macros/src/serialize/row.rs @@ -22,6 +22,11 @@ struct Attributes { // This annotation only works if `enforce_order` flavor is specified. #[darling(default)] skip_name_checks: bool, + + // Used for deserialization only. Ignored in serialization. + #[darling(default)] + #[darling(rename = "allow_missing")] + _default_when_missing: bool, } impl Attributes { @@ -70,6 +75,11 @@ struct FieldAttributes { #[darling(default)] #[darling(rename = "default_when_null")] _default_when_null: bool, + + // Used for deserialization only. Ignored in serialization. + #[darling(default)] + #[darling(rename = "allow_missing")] + _default_when_missing: bool, } struct Context { diff --git a/scylla-macros/src/serialize/value.rs b/scylla-macros/src/serialize/value.rs index 47c62c297e..b8454e8219 100644 --- a/scylla-macros/src/serialize/value.rs +++ b/scylla-macros/src/serialize/value.rs @@ -30,6 +30,11 @@ struct Attributes { // the DB will interpret them as NULLs anyway. #[darling(default)] forbid_excess_udt_fields: bool, + + // Used for deserialization only. Ignored in serialization. 
+ #[darling(default)] + #[darling(rename = "allow_missing")] + _default_when_missing: bool, } impl Attributes { diff --git a/scylla/tests/integration/macros/hygiene.rs b/scylla/tests/integration/macros/hygiene.rs index 047503a6d9..d385b0ca21 100644 --- a/scylla/tests/integration/macros/hygiene.rs +++ b/scylla/tests/integration/macros/hygiene.rs @@ -311,6 +311,8 @@ macro_rules! test_crate { c: ::core::primitive::i32, #[scylla(default_when_null)] d: ::core::primitive::i32, + #[scylla(allow_missing)] + e: ::core::primitive::i32, } // Test attributes for row struct with ordered flavor @@ -327,6 +329,8 @@ macro_rules! test_crate { c: ::core::primitive::i32, #[scylla(default_when_null)] d: ::core::primitive::i32, + #[scylla(allow_missing)] + e: ::core::primitive::i32, } // Test attributes for row struct with ordered flavor and skipped name checks @@ -342,6 +346,8 @@ macro_rules! test_crate { c: ::core::primitive::i32, #[scylla(default_when_null)] d: ::core::primitive::i32, + #[scylla(allow_missing)] + e: ::core::primitive::i32, } }; }