1 extern crate proc_macro;
2 
3 use proc_macro2::{Span, TokenStream};
4 use quote::quote;
5 use syn::*;
6 
// Name of the extra lifetime used by the generated impl; it is turned into a
// `Lifetime` in `build_arbitrary_lifetime` and appended to the impl generics.
static ARBITRARY_LIFETIME_NAME: &str = "'arbitrary";
8 
9 #[proc_macro_derive(Arbitrary)]
derive_arbitrary(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream10 pub fn derive_arbitrary(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream {
11     let input = syn::parse_macro_input!(tokens as syn::DeriveInput);
12     let (lifetime_without_bounds, lifetime_with_bounds) =
13         build_arbitrary_lifetime(input.generics.clone());
14 
15     let arbitrary_method = gen_arbitrary_method(&input, lifetime_without_bounds.clone());
16     let size_hint_method = gen_size_hint_method(&input);
17     let name = input.ident;
18     // Add a bound `T: Arbitrary` to every type parameter T.
19     let generics = add_trait_bounds(input.generics, lifetime_without_bounds.clone());
20 
21     // Build ImplGeneric with a lifetime (https://github.com/dtolnay/syn/issues/90)
22     let mut generics_with_lifetime = generics.clone();
23     generics_with_lifetime
24         .params
25         .push(GenericParam::Lifetime(lifetime_with_bounds));
26     let (impl_generics, _, _) = generics_with_lifetime.split_for_impl();
27 
28     // Build TypeGenerics and WhereClause without a lifetime
29     let (_, ty_generics, where_clause) = generics.split_for_impl();
30 
31     (quote! {
32         impl #impl_generics arbitrary::Arbitrary<#lifetime_without_bounds> for #name #ty_generics #where_clause {
33             #arbitrary_method
34             #size_hint_method
35         }
36     })
37     .into()
38 }
39 
40 // Returns: (lifetime without bounds, lifetime with bounds)
41 // Example: ("'arbitrary", "'arbitrary: 'a + 'b")
build_arbitrary_lifetime(generics: Generics) -> (LifetimeDef, LifetimeDef)42 fn build_arbitrary_lifetime(generics: Generics) -> (LifetimeDef, LifetimeDef) {
43     let lifetime_without_bounds =
44         LifetimeDef::new(Lifetime::new(ARBITRARY_LIFETIME_NAME, Span::call_site()));
45     let mut lifetime_with_bounds = lifetime_without_bounds.clone();
46 
47     for param in generics.params.iter() {
48         if let GenericParam::Lifetime(lifetime_def) = param {
49             lifetime_with_bounds
50                 .bounds
51                 .push(lifetime_def.lifetime.clone());
52         }
53     }
54 
55     (lifetime_without_bounds, lifetime_with_bounds)
56 }
57 
58 // Add a bound `T: Arbitrary` to every type parameter T.
add_trait_bounds(mut generics: Generics, lifetime: LifetimeDef) -> Generics59 fn add_trait_bounds(mut generics: Generics, lifetime: LifetimeDef) -> Generics {
60     for param in generics.params.iter_mut() {
61         if let GenericParam::Type(type_param) = param {
62             type_param
63                 .bounds
64                 .push(parse_quote!(arbitrary::Arbitrary<#lifetime>));
65         }
66     }
67     generics
68 }
69 
gen_arbitrary_method(input: &DeriveInput, lifetime: LifetimeDef) -> TokenStream70 fn gen_arbitrary_method(input: &DeriveInput, lifetime: LifetimeDef) -> TokenStream {
71     let ident = &input.ident;
72     let arbitrary_structlike = |fields| {
73         let arbitrary = construct(fields, |_, _| quote!(arbitrary::Arbitrary::arbitrary(u)?));
74         let arbitrary_take_rest = construct_take_rest(fields);
75         quote! {
76             fn arbitrary(u: &mut arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> {
77                 Ok(#ident #arbitrary)
78             }
79 
80             fn arbitrary_take_rest(mut u: arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> {
81                 Ok(#ident #arbitrary_take_rest)
82             }
83         }
84     };
85     match &input.data {
86         Data::Struct(data) => arbitrary_structlike(&data.fields),
87         Data::Union(data) => arbitrary_structlike(&Fields::Named(data.fields.clone())),
88         Data::Enum(data) => {
89             let variants = data.variants.iter().enumerate().map(|(i, variant)| {
90                 let idx = i as u64;
91                 let ctor = construct(&variant.fields, |_, _| {
92                     quote!(arbitrary::Arbitrary::arbitrary(u)?)
93                 });
94                 let variant_name = &variant.ident;
95                 quote! { #idx => #ident::#variant_name #ctor }
96             });
97             let variants_take_rest = data.variants.iter().enumerate().map(|(i, variant)| {
98                 let idx = i as u64;
99                 let ctor = construct_take_rest(&variant.fields);
100                 let variant_name = &variant.ident;
101                 quote! { #idx => #ident::#variant_name #ctor }
102             });
103             let count = data.variants.len() as u64;
104             quote! {
105                 fn arbitrary(u: &mut arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> {
106                     // Use a multiply + shift to generate a ranged random number
107                     // with slight bias. For details, see:
108                     // https://lemire.me/blog/2016/06/30/fast-random-shuffling
109                     Ok(match (u64::from(<u32 as arbitrary::Arbitrary>::arbitrary(u)?) * #count) >> 32 {
110                         #(#variants,)*
111                         _ => unreachable!()
112                     })
113                 }
114 
115                 fn arbitrary_take_rest(mut u: arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> {
116                     // Use a multiply + shift to generate a ranged random number
117                     // with slight bias. For details, see:
118                     // https://lemire.me/blog/2016/06/30/fast-random-shuffling
119                     Ok(match (u64::from(<u32 as arbitrary::Arbitrary>::arbitrary(&mut u)?) * #count) >> 32 {
120                         #(#variants_take_rest,)*
121                         _ => unreachable!()
122                     })
123                 }
124             }
125         }
126     }
127 }
128 
construct(fields: &Fields, ctor: impl Fn(usize, &Field) -> TokenStream) -> TokenStream129 fn construct(fields: &Fields, ctor: impl Fn(usize, &Field) -> TokenStream) -> TokenStream {
130     match fields {
131         Fields::Named(names) => {
132             let names = names.named.iter().enumerate().map(|(i, f)| {
133                 let name = f.ident.as_ref().unwrap();
134                 let ctor = ctor(i, f);
135                 quote! { #name: #ctor }
136             });
137             quote! { { #(#names,)* } }
138         }
139         Fields::Unnamed(names) => {
140             let names = names.unnamed.iter().enumerate().map(|(i, f)| {
141                 let ctor = ctor(i, f);
142                 quote! { #ctor }
143             });
144             quote! { ( #(#names),* ) }
145         }
146         Fields::Unit => quote!(),
147     }
148 }
149 
construct_take_rest(fields: &Fields) -> TokenStream150 fn construct_take_rest(fields: &Fields) -> TokenStream {
151     construct(fields, |idx, _| {
152         if idx + 1 == fields.len() {
153             quote! { arbitrary::Arbitrary::arbitrary_take_rest(u)? }
154         } else {
155             quote! { arbitrary::Arbitrary::arbitrary(&mut u)? }
156         }
157     })
158 }
159 
gen_size_hint_method(input: &DeriveInput) -> TokenStream160 fn gen_size_hint_method(input: &DeriveInput) -> TokenStream {
161     let size_hint_fields = |fields: &Fields| {
162         let tys = fields.iter().map(|f| &f.ty);
163         quote! {
164             arbitrary::size_hint::and_all(&[
165                 #( <#tys as arbitrary::Arbitrary>::size_hint(depth) ),*
166             ])
167         }
168     };
169     let size_hint_structlike = |fields: &Fields| {
170         let hint = size_hint_fields(fields);
171         quote! {
172             #[inline]
173             fn size_hint(depth: usize) -> (usize, Option<usize>) {
174                 arbitrary::size_hint::recursion_guard(depth, |depth| #hint)
175             }
176         }
177     };
178     match &input.data {
179         Data::Struct(data) => size_hint_structlike(&data.fields),
180         Data::Union(data) => size_hint_structlike(&Fields::Named(data.fields.clone())),
181         Data::Enum(data) => {
182             let variants = data.variants.iter().map(|v| size_hint_fields(&v.fields));
183             quote! {
184                 #[inline]
185                 fn size_hint(depth: usize) -> (usize, Option<usize>) {
186                     arbitrary::size_hint::and(
187                         <u32 as arbitrary::Arbitrary>::size_hint(depth),
188                         arbitrary::size_hint::recursion_guard(depth, |depth| {
189                             arbitrary::size_hint::or_all(&[ #( #variants ),* ])
190                         }),
191                     )
192                 }
193             }
194         }
195     }
196 }
197