1 use syn;
2 use proc_macro::{TokenStream, Diagnostic};
3 use proc_macro2::TokenStream as TokenStream2;
4 
5 use spanned::Spanned;
6 use ext::GenericExt;
7 
8 use field::{Field, Fields};
9 use support::{GenericSupport, DataSupport};
10 use derived::{Derived, Variant, Struct, Enum};
11 
/// Result type used by the generator, validators, and mappers; errors are
/// reported as `proc_macro::Diagnostic`s.
pub type Result<T> = ::std::result::Result<T, Diagnostic>;
/// Result of a mapper: the generated tokens, or a diagnostic.
pub type MapResult = Result<TokenStream2>;
14 
// Expands to a builder method `$method` that installs a custom validator in
// the field `$slot`. The validator receives the generator and a value of type
// `$arg_ty`; returning `Err` aborts generation with that diagnostic. Calling
// the method again replaces the previously installed validator.
macro_rules! validator {
    ($method:ident: $arg_ty:ty, $slot:ident) => {
        pub fn $method<F: 'static>(&mut self, f: F) -> &mut Self
            where F: Fn(&DeriveGenerator, $arg_ty) -> Result<()>
        {
            // Only one validator per slot: overwrite whatever was there.
            self.$slot = Box::new(f);
            self
        }
    }
}
25 
// Expands to the mapper plumbing for each `(map, try_map, getter): Type, list`
// entry:
//
//   * `push_default_mappers` pushes `default_<getter>` onto every `list`;
//   * `map` installs an infallible mapper (wrapped to return `Ok`);
//   * `try_map` installs a fallible mapper;
//   * `getter` returns the most recently pushed mapper.
//
// `map`/`try_map` replace the *last* entry of `list`, i.e. they configure the
// most recently added function; if no function has been added yet (the list
// is empty) they do nothing. `getter` panics when the list is empty.
macro_rules! mappers {
    ($(($map:ident, $try_map:ident, $getter:ident): $arg_ty:ty, $list:ident),*) => (
        crate fn push_default_mappers(&mut self) {
            $(self.$list.push(Box::new(concat_idents!(default_, $getter)));)*
        }

        $(
            pub fn $map<F: 'static>(&mut self, f: F) -> &mut Self
                where F: Fn(&DeriveGenerator, $arg_ty) -> TokenStream2
            {
                // An infallible mapper is just a fallible one that never errs.
                self.$try_map(move |g, v| Ok(f(g, v)))
            }

            pub fn $try_map<F: 'static>(&mut self, f: F) -> &mut Self
                where F: Fn(&DeriveGenerator, $arg_ty) -> MapResult
            {
                // Without a function there is no slot to configure.
                if let Some(slot) = self.$list.last_mut() {
                    *slot = Box::new(f);
                }

                self
            }

            pub fn $getter(&self) -> &Box<Fn(&DeriveGenerator, $arg_ty) -> MapResult> {
                assert!(!self.$list.is_empty());
                &self.$list[self.$list.len() - 1]
            }
        )*
    )
}
63 
64 // FIXME: Take a `Box<Fn>` everywhere so we can capture args!
65 pub struct DeriveGenerator {
66     pub input: syn::DeriveInput,
67     pub trait_impl: syn::ItemImpl,
68     pub trait_path: syn::Path,
69     crate generic_support: GenericSupport,
70     crate data_support: DataSupport,
71     crate enum_validator: Box<Fn(&DeriveGenerator, Enum) -> Result<()>>,
72     crate struct_validator: Box<Fn(&DeriveGenerator, Struct) -> Result<()>>,
73     crate generics_validator: Box<Fn(&DeriveGenerator, &::syn::Generics) -> Result<()>>,
74     crate fields_validator: Box<Fn(&DeriveGenerator, Fields) -> Result<()>>,
75     crate type_generic_mapper: Option<Box<Fn(&DeriveGenerator, &syn::Ident, &syn::TypeParam) -> TokenStream2>>,
76     crate generic_replacements: Vec<(usize, usize)>,
77     crate functions: Vec<Box<Fn(&DeriveGenerator, TokenStream2) -> TokenStream2>>,
78     crate enum_mappers: Vec<Box<Fn(&DeriveGenerator, Enum) -> MapResult>>,
79     crate struct_mappers: Vec<Box<Fn(&DeriveGenerator, Struct) -> MapResult>>,
80     crate variant_mappers: Vec<Box<Fn(&DeriveGenerator, Variant) -> MapResult>>,
81     crate fields_mappers: Vec<Box<Fn(&DeriveGenerator, Fields) -> MapResult>>,
82     crate field_mappers: Vec<Box<Fn(&DeriveGenerator, Field) -> MapResult>>,
83 }
84 
default_enum_mapper(gen: &DeriveGenerator, data: Enum) -> MapResult85 pub fn default_enum_mapper(gen: &DeriveGenerator, data: Enum) -> MapResult {
86     let variant = data.variants().map(|v| &v.value.ident);
87     let fields = data.variants().map(|v| v.fields().match_tokens());
88     let enum_name = ::std::iter::repeat(&data.derive_input.ident);
89     let expression = data.variants()
90         .map(|v| gen.variant_mapper()(gen, v))
91         .collect::<Result<Vec<_>>>()?;
92 
93     Ok(quote! {
94         // FIXME: Check if we can also use id_match_tokens due to match
95         // ergonomics. I don't think so, though. If we can't, then ask (in
96         // `function`) whether receiver is `&self`, `&mut self` or `self` and
97         // bind match accordingly.
98         match self {
99             #(#enum_name::#variant #fields => { #expression }),*
100         }
101     })
102 }
103 
null_enum_mapper(gen: &DeriveGenerator, data: Enum) -> MapResult104 pub fn null_enum_mapper(gen: &DeriveGenerator, data: Enum) -> MapResult {
105     let expression = data.variants()
106         .map(|v| gen.variant_mapper()(gen, v))
107         .collect::<Result<Vec<_>>>()?;
108 
109     Ok(quote!(#(#expression)*))
110 }
111 
default_struct_mapper(gen: &DeriveGenerator, data: Struct) -> MapResult112 pub fn default_struct_mapper(gen: &DeriveGenerator, data: Struct) -> MapResult {
113     gen.fields_mapper()(gen, data.fields())
114 }
115 
default_variant_mapper(gen: &DeriveGenerator, data: Variant) -> MapResult116 pub fn default_variant_mapper(gen: &DeriveGenerator, data: Variant) -> MapResult {
117     gen.fields_mapper()(gen, data.fields())
118 }
119 
default_field_mapper(_gen: &DeriveGenerator, _data: Field) -> MapResult120 pub fn default_field_mapper(_gen: &DeriveGenerator, _data: Field) -> MapResult {
121     Ok(TokenStream2::new())
122 }
123 
default_fields_mapper(g: &DeriveGenerator, fields: Fields) -> MapResult124 pub fn default_fields_mapper(g: &DeriveGenerator, fields: Fields) -> MapResult {
125     let field = fields.iter()
126         .map(|field| g.field_mapper()(g, field))
127         .collect::<Result<Vec<_>>>()?;
128 
129     Ok(quote!({ #(#field)* }))
130 }
131 
132 impl DeriveGenerator {
build_for(input: TokenStream, trait_impl: TokenStream2) -> DeriveGenerator133     pub fn build_for(input: TokenStream, trait_impl: TokenStream2) -> DeriveGenerator {
134         let trait_impl: syn::ItemImpl = syn::parse2(quote!(#trait_impl for Foo {}))
135             .expect("invalid impl");
136         let trait_path = trait_impl.trait_.clone().expect("impl does not have trait").1;
137         let input = syn::parse(input).expect("invalid derive input");
138 
139         DeriveGenerator {
140             input, trait_impl, trait_path,
141             generic_support: GenericSupport::None,
142             data_support: DataSupport::None,
143             type_generic_mapper: None,
144             generic_replacements: vec![],
145             enum_validator: Box::new(|_, _| Ok(())),
146             struct_validator: Box::new(|_, _| Ok(())),
147             generics_validator: Box::new(|_, _| Ok(())),
148             fields_validator: Box::new(|_, _| Ok(())),
149             functions: vec![],
150             enum_mappers: vec![],
151             struct_mappers: vec![],
152             variant_mappers: vec![],
153             field_mappers: vec![],
154             fields_mappers: vec![],
155         }
156     }
157 
generic_support(&mut self, support: GenericSupport) -> &mut Self158     pub fn generic_support(&mut self, support: GenericSupport) -> &mut Self {
159         self.generic_support = support;
160         self
161     }
162 
data_support(&mut self, support: DataSupport) -> &mut Self163     pub fn data_support(&mut self, support: DataSupport) -> &mut Self {
164         self.data_support = support;
165         self
166     }
167 
map_type_generic<F: 'static>(&mut self, f: F) -> &mut Self where F: Fn(&DeriveGenerator, &syn::Ident, &syn::TypeParam) -> TokenStream2168     pub fn map_type_generic<F: 'static>(&mut self, f: F) -> &mut Self
169         where F: Fn(&DeriveGenerator, &syn::Ident, &syn::TypeParam) -> TokenStream2
170     {
171         self.type_generic_mapper = Some(Box::new(f));
172         self
173     }
174 
replace_generic(&mut self, trait_gen: usize, impl_gen: usize) -> &mut Self175     pub fn replace_generic(&mut self, trait_gen: usize, impl_gen: usize) -> &mut Self {
176         self.generic_replacements.push((trait_gen, impl_gen));
177         self
178     }
179 
180     validator!(validate_enum: Enum, enum_validator);
181     validator!(validate_struct: Struct, struct_validator);
182     validator!(validate_generics: &syn::Generics, generics_validator);
183     validator!(validate_fields: Fields, fields_validator);
184 
function<F: 'static>(&mut self, f: F) -> &mut Self where F: Fn(&DeriveGenerator, TokenStream2) -> TokenStream2185     pub fn function<F: 'static>(&mut self, f: F) -> &mut Self
186         where F: Fn(&DeriveGenerator, TokenStream2) -> TokenStream2
187     {
188         self.functions.push(Box::new(f));
189         self.push_default_mappers();
190         self
191     }
192 
193     mappers! {
194         (map_struct, try_map_struct, struct_mapper): Struct, struct_mappers,
195         (map_enum, try_map_enum, enum_mapper): Enum, enum_mappers,
196         (map_variant, try_map_variant, variant_mapper): Variant, variant_mappers,
197         (map_fields, try_map_fields, fields_mapper): Fields, fields_mappers,
198         (map_field, try_map_field, field_mapper): Field, field_mappers
199     }
200 
_to_tokens(&mut self) -> Result<TokenStream>201     fn _to_tokens(&mut self) -> Result<TokenStream> {
202         use syn::*;
203 
204         // Step 1: Run all validators.
205         // Step 1a: First, check for data support.
206         let (span, support) = (self.input.span(), self.data_support);
207         match self.input.data {
208             Data::Struct(ref data) => {
209                 let named = Struct::from(&self.input, data).fields().are_named();
210                 if named && !support.contains(DataSupport::NamedStruct) {
211                     return Err(span.error("named structs are not supported"));
212                 }
213 
214                 if !named && !support.contains(DataSupport::TupleStruct) {
215                     return Err(span.error("tuple structs are not supported"));
216                 }
217             }
218             Data::Enum(..) if !support.contains(DataSupport::Enum) => {
219                 return Err(span.error("enums are not supported"));
220             }
221             Data::Union(..) if !support.contains(DataSupport::Union) => {
222                 return Err(span.error("unions are not supported"));
223             }
224             _ => { /* we're okay! */ }
225         }
226 
227         // Step 1b: Second, check for generics support.
228         for generic in &self.input.generics.params {
229             use syn::GenericParam::*;
230 
231             let (span, support) = (generic.span(), self.generic_support);
232             match generic {
233                 Type(..) if !support.contains(GenericSupport::Type) => {
234                     return Err(span.error("type generics are not supported"));
235                 }
236                 Lifetime(..) if !support.contains(GenericSupport::Lifetime) => {
237                     return Err(span.error("lifetime generics are not supported"));
238                 }
239                 Const(..) if !support.contains(GenericSupport::Const) => {
240                     return Err(span.error("const generics are not supported"));
241                 }
242                 _ => { /* we're okay! */ }
243             }
244         }
245 
246         // Step 1c: Third, run the custom validators.
247         (self.generics_validator)(self, &self.input.generics)?;
248         match self.input.data {
249             Data::Struct(ref data) => {
250                 let derived = Derived::from(&self.input, data);
251                 (self.struct_validator)(self, derived)?;
252                 (self.fields_validator)(self, derived.fields())?;
253             }
254             Data::Enum(ref data) => {
255                 let derived = Derived::from(&self.input, data);
256                 (self.enum_validator)(self, derived)?;
257                 for variant in derived.variants() {
258                     (self.fields_validator)(self, variant.fields())?;
259                 }
260             }
261             Data::Union(ref _data) => unimplemented!("union custom validation"),
262         }
263 
264         // Step 2: Generate the code!
265         // Step 2a: Generate the code for each function.
266         let mut function_code = vec![];
267         for i in 0..self.functions.len() {
268             let function = &self.functions[i];
269             let inner = match self.input.data {
270                 Data::Struct(ref data) => {
271                     let derived = Derived::from(&self.input, data);
272                     self.struct_mappers[i](self, derived)?
273                 }
274                 Data::Enum(ref data) => {
275                     let derived = Derived::from(&self.input, data);
276                     self.enum_mappers[i](self, derived)?
277                 }
278                 Data::Union(ref _data) => unimplemented!("can't gen unions yet"),
279             };
280 
281             function_code.push(function(self, inner));
282         }
283 
284         // Step 2b: Create a couple of generics to mutate with user's input.
285         let mut generics = self.input.generics.clone();
286 
287         // Step 2c: Add additional where bounds if the generator asks for it.
288         if let Some(ref type_mapper) = self.type_generic_mapper {
289             for ty in self.input.generics.type_params() {
290                 let new_ty = type_mapper(self, &ty.ident, ty);
291                 let clause = syn::parse2(new_ty).expect("invalid type generic mapping");
292                 generics.make_where_clause().predicates.push(clause);
293             }
294         }
295 
296         // Step 2d: Add any generics in the trait.
297         let mut generics_for_impl_generics = generics.clone();
298         for (i, trait_param) in self.trait_impl.generics.params.iter().enumerate() {
299             // Step 2d.0: Perform a generic replacement if requested. Here,
300             // we determine if a generic (i) in the trait is going to replace a
301             // generic in the user's type (the `jth` of the right kind).
302             let replacement = self.generic_replacements.iter()
303                 .filter(|r| r.0 == i)
304                 .next();
305 
306             if let Some((_, j)) = replacement {
307                 use syn::{punctuated::Punctuated, token::Comma};
308 
309                 // Step 2d.1: Actually perform the replacement.
310                 let replace_in = |ps: &mut Punctuated<GenericParam, Comma>| -> bool {
311                     ps.iter_mut()
312                         .filter(|param| param.kind() == trait_param.kind())
313                         .nth(*j)
314                         .map(|impl_param| *impl_param = trait_param.clone())
315                         .is_some()
316                 };
317 
318                 // Step 2d.2: If it fails, insert a new impl generic.
319                 // NOTE: It's critical that `generics` is attempted first!
320                 // Otherwise, we might replace generics that don't exist in the
321                 // user's type.
322                 if !replace_in(&mut generics.params)
323                     || !replace_in(&mut generics_for_impl_generics.params)
324                 {
325                     generics_for_impl_generics.params.insert(0, trait_param.clone());
326                 }
327             } else {
328                 // Step 2d.2: Otherwise, insert a new impl<..> generic.
329                 generics_for_impl_generics.params.insert(0, trait_param.clone());
330             }
331         }
332 
333         // Step 2e: Split the generics, but use the `impl_generics` from above.
334         let (impl_gen, _, _) = generics_for_impl_generics.split_for_impl();
335         let (_, ty_gen, where_gen) = generics.split_for_impl();
336 
337         // Step 2b: Generate the complete implementation.
338         let target = &self.input.ident;
339         let trait_name = &self.trait_path;
340         Ok(quote! {
341             impl #impl_gen #trait_name for #target #ty_gen #where_gen {
342                 #(#function_code)*
343             }
344         }.into())
345     }
346 
debug(&mut self) -> &mut Self347     pub fn debug(&mut self) -> &mut Self {
348         match self._to_tokens() {
349             Ok(tokens) => println!("Tokens produced: {}", tokens.to_string()),
350             Err(e) => println!("Error produced: {:?}", e)
351         }
352 
353         self
354     }
355 
to_tokens(&mut self) -> TokenStream356     pub fn to_tokens(&mut self) -> TokenStream {
357         // FIXME: Emit something like: Trait: msg.
358         self._to_tokens()
359             .unwrap_or_else(|diag| {
360                 if let Some(last) = self.trait_path.segments.last() {
361                     use proc_macro::Span;
362                     use proc_macro::Level::*;
363 
364                     let id = &last.ident;
365                     let msg = match diag.level() {
366                         Error => format!("error occurred while deriving `{}`", id),
367                         Warning => format!("warning issued by `{}` derive", id),
368                         Note => format!("note issued by `{}` derive", id),
369                         Help => format!("help provided by `{}` derive", id),
370                         _ => format!("while deriving `{}`", id)
371                     };
372 
373                     diag.span_note(Span::call_site(), msg).emit();
374                 }
375 
376                 TokenStream::new().into()
377             })
378     }
379 }
380