Skip to content

Commit 53ff1a8

Browse files
committed
add #[default_error(<type>)] attribute to initializer macros
The `#[default_error(<type>)]` attribute macro can be used to supply a default type as the error used for the `[pin_]init!` macros. This way one can easily define custom `try_[pin_]init!` variants that default to your project specific error type. Just write the following declarative macro: macro_rules! try_init { ($($args:tt)*) => { ::pin_init::init!( #[default_error(YourCustomErrorType)] $($args)* ) } } Signed-off-by: Benno Lossin <lossin@kernel.org>
1 parent 5f88f2c commit 53ff1a8

File tree

2 files changed

+65
-4
lines changed

2 files changed

+65
-4
lines changed

internal/src/init.rs

Lines changed: 49 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,11 @@ use syn::{
66
parse_quote,
77
punctuated::Punctuated,
88
spanned::Spanned,
9-
token, Block, Expr, ExprCall, ExprPath, Ident, Path, Token, Type,
9+
token, Attribute, Block, Expr, ExprCall, ExprPath, Ident, Path, Token, Type,
1010
};
1111

1212
pub struct Initializer {
13+
attrs: Vec<InitializerAttribute>,
1314
this: Option<This>,
1415
path: Path,
1516
brace_token: token::Brace,
@@ -50,23 +51,44 @@ impl InitializerField {
5051
}
5152
}
5253

54+
enum InitializerAttribute {
55+
DefaultError(DefaultErrorAttribute),
56+
}
57+
58+
struct DefaultErrorAttribute {
59+
ty: Type,
60+
}
61+
5362
pub(crate) fn expand(
5463
Initializer {
64+
attrs,
5565
this,
5666
path,
5767
brace_token,
5868
fields,
5969
rest,
60-
mut error,
70+
error,
6171
}: Initializer,
6272
default_error: Option<&'static str>,
6373
pinned: bool,
6474
) -> TokenStream {
6575
let mut errors = TokenStream::new();
76+
let mut error = error.map(|(_, err)| err);
77+
if let Some(default_error) = attrs.iter().fold(None, |acc, attr| {
78+
#[expect(irrefutable_let_patterns)]
79+
if let InitializerAttribute::DefaultError(DefaultErrorAttribute { ty }) = attr {
80+
Some(ty.clone())
81+
} else {
82+
acc
83+
}
84+
}) {
85+
error.get_or_insert(default_error);
86+
}
6687
if let Some(default_error) = default_error {
67-
error.get_or_insert((Default::default(), syn::parse_str(default_error).unwrap()));
88+
error.get_or_insert(syn::parse_str(default_error).unwrap());
6889
}
69-
let error = error.map(|(_, err)| err).unwrap_or_else(|| {
90+
91+
let error = error.unwrap_or_else(|| {
7092
errors.extend(quote_spanned!(brace_token.span.close()=>
7193
::core::compile_error!("expected `? <type>` after `}`");
7294
));
@@ -350,6 +372,7 @@ fn make_field_check(
350372

351373
impl Parse for Initializer {
352374
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
375+
let attrs = input.call(Attribute::parse_outer)?;
353376
let this = input.peek(Token![&]).then(|| input.parse()).transpose()?;
354377
let path = input.parse()?;
355378
let content;
@@ -381,7 +404,19 @@ impl Parse for Initializer {
381404
.peek(Token![?])
382405
.then(|| Ok::<_, syn::Error>((input.parse()?, input.parse()?)))
383406
.transpose()?;
407+
let attrs = attrs
408+
.into_iter()
409+
.map(|a| {
410+
if a.path().is_ident("default_error") {
411+
a.parse_args::<DefaultErrorAttribute>()
412+
.map(InitializerAttribute::DefaultError)
413+
} else {
414+
Err(syn::Error::new_spanned(a, "unknown initializer attribute"))
415+
}
416+
})
417+
.collect::<Result<Vec<_>, _>>()?;
384418
Ok(Self {
419+
attrs,
385420
this,
386421
path,
387422
brace_token,
@@ -392,6 +427,16 @@ impl Parse for Initializer {
392427
}
393428
}
394429

430+
impl Parse for DefaultErrorAttribute {
431+
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
432+
let ty = input.parse()?;
433+
if !input.peek(End) {
434+
return Err(input.error("expected end of input"));
435+
}
436+
Ok(Self { ty })
437+
}
438+
}
439+
395440
impl Parse for This {
396441
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
397442
Ok(Self {

tests/default_error.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
#![allow(dead_code)]
2+
3+
use pin_init::{init, Init};
4+
5+
struct Foo {}
6+
7+
struct Error;
8+
9+
impl Foo {
10+
fn new() -> impl Init<Foo, Error> {
11+
init!(
12+
#[default_error(Error)]
13+
Foo {}
14+
)
15+
}
16+
}

0 commit comments

Comments
 (0)