Skip to content

Commit c2531d6

Browse files
committed
Add derive macros
Fixes #43
1 parent 13fcbc7 commit c2531d6

File tree

7 files changed

+514
-0
lines changed

7 files changed

+514
-0
lines changed

Cargo.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,13 @@ name = "approx"
2626

2727
[features]
2828
default = ["std"]
29+
derive = ["dep:approx-derive"]
2930
std = []
3031

3132
[dependencies]
33+
approx-derive = { version = "0.5.1", path = "approx-derive", optional = true }
3234
num-traits = { version = "0.2.0", default_features = false }
3335
num-complex = { version = "0.4.0", optional = true }
36+
37+
[workspace]
38+
members = ["approx-derive"]

approx-derive/Cargo.toml

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
[package]
2+
name = "approx-derive"
3+
version = "0.5.1"
4+
edition = "2021"
5+
6+
[lib]
7+
proc-macro = true
8+
9+
[dependencies]
10+
darling = "0.20"
11+
itertools = "0.13"
12+
proc-macro2 = "1"
13+
quote = "1"
14+
syn = "2"

approx-derive/src/abs_diff_eq.rs

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
use darling::{ast::Data, FromDeriveInput};
2+
use proc_macro2::TokenStream;
3+
use quote::quote;
4+
use syn::{DeriveInput, Expr, Path};
5+
6+
use crate::{Field, Variant};
7+
8+
#[derive(FromDeriveInput)]
9+
#[darling(allow_unknown_fields, attributes(approx))]
10+
struct Opts {
11+
epsilon: Path,
12+
absolute: Expr,
13+
data: Data<Variant, Field>,
14+
}
15+
16+
pub(crate) fn derive(item: DeriveInput) -> syn::Result<TokenStream> {
17+
let Opts {
18+
epsilon,
19+
absolute,
20+
data,
21+
} = Opts::from_derive_input(&item)?;
22+
23+
let ident = item.ident;
24+
let comparisons = super::comparisons(data, "abs_diff_eq", None);
25+
26+
Ok(quote! {
27+
#[automatically_derived]
28+
impl AbsDiffEq for #ident {
29+
type Epsilon = #epsilon;
30+
31+
fn default_epsilon() -> Self::Epsilon {
32+
#absolute
33+
}
34+
35+
fn abs_diff_eq(&self, other: &Self, epsilon: Self::Epsilon) -> bool {
36+
#comparisons
37+
}
38+
}
39+
})
40+
}

approx-derive/src/lib.rs

Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
mod abs_diff_eq;
2+
mod relative_eq;
3+
4+
use darling::{
5+
ast::{Data, Fields, Style},
6+
util::Flag,
7+
FromField, FromVariant,
8+
};
9+
use itertools::Itertools;
10+
use proc_macro2::{Literal, Span, TokenStream};
11+
use quote::{format_ident, quote, ToTokens};
12+
use syn::{parse_macro_input, DeriveInput, Ident};
13+
14+
#[derive(FromVariant)]
15+
#[darling(attributes(approx))]
16+
struct Variant {
17+
ident: Ident,
18+
fields: Fields<Field>,
19+
}
20+
21+
#[derive(FromField)]
22+
#[darling(attributes(approx), and_then = Self::validate)]
23+
struct Field {
24+
ident: Option<Ident>,
25+
skip: Flag,
26+
approximate: Flag,
27+
}
28+
29+
impl Field {
30+
fn validate(self) -> darling::Result<Self> {
31+
if self.skip.is_present() && self.approximate.is_present() {
32+
Err(
33+
darling::Error::custom("Cannot both skip and use approximate equality")
34+
.with_span(&self.approximate.span()),
35+
)
36+
} else {
37+
Ok(self)
38+
}
39+
}
40+
}
41+
42+
#[proc_macro_derive(AbsDiffEq, attributes(approx))]
43+
pub fn abs_diff_eq(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
44+
let item = parse_macro_input!(item as DeriveInput);
45+
convert(abs_diff_eq::derive(item))
46+
}
47+
48+
#[proc_macro_derive(RelativeEq, attributes(approx))]
49+
pub fn relative_eq(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
50+
let item = parse_macro_input!(item as DeriveInput);
51+
convert(relative_eq::derive(item))
52+
}
53+
54+
fn convert(tokens: syn::Result<TokenStream>) -> proc_macro::TokenStream {
55+
tokens.unwrap_or_else(syn::Error::into_compile_error).into()
56+
}
57+
58+
fn comparisons(data: Data<Variant, Field>, f: &str, arg: Option<&str>) -> TokenStream {
59+
let comparisons = data
60+
.map_enum_variants(|variant| {
61+
let ident = variant.ident;
62+
let fields = variant.fields;
63+
match fields.style {
64+
Style::Tuple => {
65+
let (comps, self_extractors, other_extractors): (Vec<_>, Vec<_>, Vec<_>) =
66+
fields
67+
.fields
68+
.into_iter()
69+
.enumerate()
70+
.map(|(i, field)| {
71+
if field.skip.is_present() {
72+
(None, format_ident!("_"), format_ident!("_"))
73+
} else {
74+
let one = format_ident!("_{}", i);
75+
let other = format_ident!("other_{}", i);
76+
(
77+
Some(compare(&one, &other, field.approximate, f, arg)),
78+
one,
79+
other,
80+
)
81+
}
82+
})
83+
.multiunzip();
84+
let comps = comps.iter().flatten();
85+
quote! {
86+
Self::#ident(#(#self_extractors),*) => match other {
87+
Self::#ident(#(#other_extractors),*) => #(#comps)&&*,
88+
_ => false
89+
}
90+
}
91+
}
92+
Style::Struct => {
93+
let (comps, self_extractors, other_extractors): (Vec<_>, Vec<_>, Vec<_>) =
94+
fields
95+
.fields
96+
.into_iter()
97+
.filter_map(|field| {
98+
if field.skip.is_present() {
99+
None
100+
} else {
101+
let one = field.ident.clone().unwrap();
102+
let other = format_ident!("other_{}", one);
103+
Some((
104+
compare(&one, &other, field.approximate, f, arg),
105+
one.clone(),
106+
quote! { #one: #other },
107+
))
108+
}
109+
})
110+
.multiunzip();
111+
quote! {
112+
Self::#ident { #(#self_extractors),*, .. } => match other {
113+
Self::#ident {#(#other_extractors),*, ..} => #(#comps)&&*,
114+
_ => false
115+
}
116+
}
117+
}
118+
Style::Unit => quote! { Self::#ident => self == other },
119+
}
120+
})
121+
.map_struct(|fields| {
122+
Fields::<TokenStream>::from((
123+
fields.style,
124+
fields
125+
.into_iter()
126+
.enumerate()
127+
.filter_map(|(i, field)| {
128+
if field.skip.is_present() {
129+
None
130+
} else {
131+
let ident = match field.ident {
132+
None => Literal::usize_unsuffixed(i).to_token_stream(),
133+
Some(ident) => quote! { #ident },
134+
};
135+
Some(compare(
136+
quote! { self.#ident },
137+
quote! { other.#ident },
138+
field.approximate,
139+
f,
140+
arg,
141+
))
142+
}
143+
})
144+
.collect::<Vec<_>>(),
145+
))
146+
});
147+
148+
if comparisons.is_enum() {
149+
let comparisons = comparisons.take_enum().unwrap();
150+
quote! {
151+
match self {
152+
#(#comparisons),*
153+
}
154+
}
155+
} else {
156+
let comparisons = comparisons.take_struct().unwrap().fields.into_iter();
157+
quote!(#(#comparisons)&&*)
158+
}
159+
}
160+
161+
fn compare<One, Other>(
162+
one: One,
163+
other: Other,
164+
approximate: Flag,
165+
f: &str,
166+
arg: Option<&str>,
167+
) -> TokenStream
168+
where
169+
One: ToTokens,
170+
Other: ToTokens,
171+
{
172+
if approximate.is_present() {
173+
let q = |s| Ident::new(s, Span::call_site());
174+
let arg = arg.map(|arg| {
175+
let arg = q(arg);
176+
quote! {, #arg}
177+
});
178+
let f = q(f);
179+
quote! { #one.#f(&#other, epsilon #arg) }
180+
} else {
181+
quote! { #one == #other }
182+
}
183+
}

approx-derive/src/relative_eq.rs

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
use darling::{ast::Data, FromDeriveInput};
2+
use proc_macro2::TokenStream;
3+
use quote::quote;
4+
use syn::{DeriveInput, Expr};
5+
6+
use crate::{Field, Variant};
7+
8+
#[derive(FromDeriveInput)]
9+
#[darling(allow_unknown_fields, attributes(approx))]
10+
struct Opts {
11+
relative: Expr,
12+
data: Data<Variant, Field>,
13+
}
14+
15+
pub(crate) fn derive(item: DeriveInput) -> syn::Result<TokenStream> {
16+
let Opts { relative, data } = Opts::from_derive_input(&item)?;
17+
18+
let ident = item.ident;
19+
let comparisons = super::comparisons(data, "relative_eq", Some("max_relative"));
20+
21+
Ok(quote! {
22+
#[automatically_derived]
23+
impl RelativeEq for #ident {
24+
fn default_max_relative() -> Self::Epsilon {
25+
#relative
26+
}
27+
28+
fn relative_eq(&self, other: &Self, epsilon: Self::Epsilon, max_relative: Self::Epsilon) -> bool {
29+
#comparisons
30+
}
31+
}
32+
})
33+
}

src/lib.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,8 @@
157157
#![no_std]
158158
#![allow(clippy::transmute_float_to_int)]
159159

160+
#[cfg(feature = "derive")]
161+
extern crate approx_derive;
160162
#[cfg(feature = "num-complex")]
161163
extern crate num_complex;
162164
extern crate num_traits;
@@ -171,6 +173,9 @@ pub use abs_diff_eq::AbsDiffEq;
171173
pub use relative_eq::RelativeEq;
172174
pub use ulps_eq::UlpsEq;
173175

176+
#[cfg(feature = "derive")]
177+
pub use approx_derive::{AbsDiffEq, RelativeEq};
178+
174179
/// The requisite parameters for testing for approximate equality using a
175180
/// absolute difference based comparison.
176181
///

0 commit comments

Comments
 (0)