Skip to content

Commit a6611c0

Browse files
committed
Add derive macros
Fixes #43
1 parent 13fcbc7 commit a6611c0

File tree

7 files changed

+594
-0
lines changed

7 files changed

+594
-0
lines changed

Cargo.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,13 @@ name = "approx"
2626

2727
[features]
2828
default = ["std"]
29+
derive = ["dep:approx-derive"]
2930
std = []
3031

3132
[dependencies]
33+
approx-derive = { version = "0.5.1", path = "approx-derive", optional = true }
3234
num-traits = { version = "0.2.0", default_features = false }
3335
num-complex = { version = "0.4.0", optional = true }
36+
37+
[workspace]
38+
members = ["approx-derive"]

approx-derive/Cargo.toml

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
[package]
2+
name = "approx-derive"
3+
version = "0.5.1"
4+
edition = "2021"
5+
6+
[lib]
7+
proc-macro = true
8+
9+
[dependencies]
10+
darling = "0.20"
11+
itertools = "0.13"
12+
proc-macro2 = "1"
13+
quote = "1"
14+
syn = "2"

approx-derive/src/abs_diff_eq.rs

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
use darling::{ast::Generics, FromMeta};
2+
use proc_macro2::TokenStream;
3+
use quote::quote;
4+
use syn::{DeriveInput, Expr, Path};
5+
6+
#[derive(FromMeta)]
7+
#[darling(allow_unknown_fields)]
8+
struct Opts {
9+
epsilon: Path,
10+
absolute: Expr,
11+
}
12+
13+
pub(crate) fn derive(item: DeriveInput) -> syn::Result<TokenStream> {
14+
let (
15+
super::Opts {
16+
value: Opts { epsilon, absolute },
17+
generics: Generics { where_clause, .. },
18+
data,
19+
},
20+
params,
21+
) = super::Opts::parse(&item)?;
22+
23+
let ident = item.ident;
24+
let comparisons = super::comparisons(data, "abs_diff_eq", None);
25+
26+
Ok(dbg!(quote! {
27+
#[automatically_derived]
28+
impl<#(#params)*> AbsDiffEq for #ident <#(#params)*> #where_clause {
29+
type Epsilon = #epsilon;
30+
31+
fn default_epsilon() -> Self::Epsilon {
32+
#absolute
33+
}
34+
35+
fn abs_diff_eq(&self, other: &Self, epsilon: Self::Epsilon) -> bool {
36+
#comparisons
37+
}
38+
}
39+
}))
40+
}

approx-derive/src/lib.rs

Lines changed: 209 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,209 @@
1+
mod abs_diff_eq;
2+
mod relative_eq;
3+
4+
use darling::{
5+
ast::{Data, Fields, GenericParam, GenericParamExt, Generics, Style},
6+
util::Flag,
7+
FromDeriveInput, FromField, FromMeta, FromVariant,
8+
};
9+
use itertools::Itertools;
10+
use proc_macro2::{Literal, Span, TokenStream};
11+
use quote::{format_ident, quote, ToTokens};
12+
use syn::{parse_macro_input, DeriveInput, Ident};
13+
14+
#[derive(FromDeriveInput)]
15+
#[darling(attributes(approx))]
16+
struct Opts<T> {
17+
generics: Generics<GenericParam<Ident>>,
18+
data: Data<Variant, Field>,
19+
#[darling(flatten)]
20+
value: T,
21+
}
22+
23+
impl<T> Opts<T> {
24+
fn parse(input: &DeriveInput) -> darling::Result<(Self, Vec<Ident>)>
25+
where
26+
T: FromMeta,
27+
{
28+
Self::from_derive_input(&input).map(|opts| {
29+
let params = opts
30+
.generics
31+
.params
32+
.iter()
33+
.map(|param| param.as_type_param().cloned().unwrap())
34+
.collect();
35+
(opts, params)
36+
})
37+
}
38+
}
39+
40+
#[derive(FromVariant)]
41+
#[darling(attributes(approx))]
42+
struct Variant {
43+
ident: Ident,
44+
fields: Fields<Field>,
45+
}
46+
47+
#[derive(FromField)]
48+
#[darling(attributes(approx), and_then = Self::validate)]
49+
struct Field {
50+
ident: Option<Ident>,
51+
skip: Flag,
52+
approximate: Flag,
53+
}
54+
55+
impl Field {
56+
fn validate(self) -> darling::Result<Self> {
57+
if self.skip.is_present() && self.approximate.is_present() {
58+
Err(
59+
darling::Error::custom("Cannot both skip and use approximate equality")
60+
.with_span(&self.approximate.span()),
61+
)
62+
} else {
63+
Ok(self)
64+
}
65+
}
66+
}
67+
68+
#[proc_macro_derive(AbsDiffEq, attributes(approx))]
69+
pub fn abs_diff_eq(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
70+
let item = parse_macro_input!(item as DeriveInput);
71+
convert(abs_diff_eq::derive(item))
72+
}
73+
74+
#[proc_macro_derive(RelativeEq, attributes(approx))]
75+
pub fn relative_eq(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
76+
let item = parse_macro_input!(item as DeriveInput);
77+
convert(relative_eq::derive(item))
78+
}
79+
80+
fn convert(tokens: syn::Result<TokenStream>) -> proc_macro::TokenStream {
81+
tokens.unwrap_or_else(syn::Error::into_compile_error).into()
82+
}
83+
84+
fn comparisons(data: Data<Variant, Field>, f: &str, arg: Option<&str>) -> TokenStream {
85+
let comparisons = data
86+
.map_enum_variants(|variant| {
87+
let ident = variant.ident;
88+
let fields = variant.fields;
89+
match fields.style {
90+
Style::Tuple => {
91+
let (comps, self_extractors, other_extractors): (Vec<_>, Vec<_>, Vec<_>) =
92+
fields
93+
.fields
94+
.into_iter()
95+
.enumerate()
96+
.map(|(i, field)| {
97+
if field.skip.is_present() {
98+
(None, format_ident!("_"), format_ident!("_"))
99+
} else {
100+
let one = format_ident!("_{}", i);
101+
let other = format_ident!("other_{}", i);
102+
(
103+
Some(compare(&one, &other, field.approximate, f, arg)),
104+
one,
105+
other,
106+
)
107+
}
108+
})
109+
.multiunzip();
110+
let comps = comps.iter().flatten();
111+
quote! {
112+
Self::#ident(#(#self_extractors),*) => match other {
113+
Self::#ident(#(#other_extractors),*) => #(#comps)&&*,
114+
_ => false
115+
}
116+
}
117+
}
118+
Style::Struct => {
119+
let (comps, self_extractors, other_extractors): (Vec<_>, Vec<_>, Vec<_>) =
120+
fields
121+
.fields
122+
.into_iter()
123+
.filter_map(|field| {
124+
if field.skip.is_present() {
125+
None
126+
} else {
127+
let one = field.ident.clone().unwrap();
128+
let other = format_ident!("other_{}", one);
129+
Some((
130+
compare(&one, &other, field.approximate, f, arg),
131+
one.clone(),
132+
quote! { #one: #other },
133+
))
134+
}
135+
})
136+
.multiunzip();
137+
quote! {
138+
Self::#ident { #(#self_extractors),*, .. } => match other {
139+
Self::#ident {#(#other_extractors),*, ..} => #(#comps)&&*,
140+
_ => false
141+
}
142+
}
143+
}
144+
Style::Unit => quote! { Self::#ident => self == other },
145+
}
146+
})
147+
.map_struct(|fields| {
148+
Fields::<TokenStream>::from((
149+
fields.style,
150+
fields
151+
.into_iter()
152+
.enumerate()
153+
.filter_map(|(i, field)| {
154+
if field.skip.is_present() {
155+
None
156+
} else {
157+
let ident = match field.ident {
158+
None => Literal::usize_unsuffixed(i).to_token_stream(),
159+
Some(ident) => quote! { #ident },
160+
};
161+
Some(compare(
162+
quote! { self.#ident },
163+
quote! { other.#ident },
164+
field.approximate,
165+
f,
166+
arg,
167+
))
168+
}
169+
})
170+
.collect::<Vec<_>>(),
171+
))
172+
});
173+
174+
if comparisons.is_enum() {
175+
let comparisons = comparisons.take_enum().unwrap();
176+
quote! {
177+
match self {
178+
#(#comparisons),*
179+
}
180+
}
181+
} else {
182+
let comparisons = comparisons.take_struct().unwrap().fields.into_iter();
183+
quote!(#(#comparisons)&&*)
184+
}
185+
}
186+
187+
fn compare<One, Other>(
188+
one: One,
189+
other: Other,
190+
approximate: Flag,
191+
f: &str,
192+
arg: Option<&str>,
193+
) -> TokenStream
194+
where
195+
One: ToTokens,
196+
Other: ToTokens,
197+
{
198+
if approximate.is_present() {
199+
let q = |s| Ident::new(s, Span::call_site());
200+
let arg = arg.map(|arg| {
201+
let arg = q(arg);
202+
quote! {, #arg}
203+
});
204+
let f = q(f);
205+
quote! { #one.#f(&#other, epsilon #arg) }
206+
} else {
207+
quote! { #one == #other }
208+
}
209+
}

approx-derive/src/relative_eq.rs

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
use darling::{ast::Generics, FromMeta};
2+
use proc_macro2::TokenStream;
3+
use quote::quote;
4+
use syn::{DeriveInput, Expr};
5+
6+
#[derive(FromMeta)]
7+
#[darling(allow_unknown_fields)]
8+
struct Opts {
9+
relative: Expr,
10+
}
11+
12+
pub(crate) fn derive(item: DeriveInput) -> syn::Result<TokenStream> {
13+
let (
14+
super::Opts {
15+
value: Opts { relative },
16+
generics: Generics { where_clause, .. },
17+
data,
18+
},
19+
params,
20+
) = super::Opts::parse(&item)?;
21+
22+
let ident = item.ident;
23+
let comparisons = super::comparisons(data, "relative_eq", Some("max_relative"));
24+
25+
Ok(quote! {
26+
#[automatically_derived]
27+
impl<#(#params)*> RelativeEq for #ident <#(#params)*> #where_clause {
28+
fn default_max_relative() -> Self::Epsilon {
29+
#relative
30+
}
31+
32+
fn relative_eq(&self, other: &Self, epsilon: Self::Epsilon, max_relative: Self::Epsilon) -> bool {
33+
#comparisons
34+
}
35+
}
36+
})
37+
}

src/lib.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,8 @@
157157
#![no_std]
158158
#![allow(clippy::transmute_float_to_int)]
159159

160+
#[cfg(feature = "derive")]
161+
extern crate approx_derive;
160162
#[cfg(feature = "num-complex")]
161163
extern crate num_complex;
162164
extern crate num_traits;
@@ -171,6 +173,9 @@ pub use abs_diff_eq::AbsDiffEq;
171173
pub use relative_eq::RelativeEq;
172174
pub use ulps_eq::UlpsEq;
173175

176+
#[cfg(feature = "derive")]
177+
pub use approx_derive::{AbsDiffEq, RelativeEq};
178+
174179
/// The requisite parameters for testing for approximate equality using a
175180
/// absolute difference based comparison.
176181
///

0 commit comments

Comments
 (0)