1 use syn;
2 use proc_macro::{TokenStream, Diagnostic};
3 use proc_macro2::TokenStream as TokenStream2;
4
5 use spanned::Spanned;
6 use ext::GenericExt;
7
8 use field::{Field, Fields};
9 use support::{GenericSupport, DataSupport};
10 use derived::{Derived, Variant, Struct, Enum};
11
12 pub type Result<T> = ::std::result::Result<T, Diagnostic>;
13 pub type MapResult = Result<TokenStream2>;
14
/// Generates a builder method `$name(f)` (used inside `impl DeriveGenerator`)
/// that boxes the validator `f` and stores it in `self.$slot`, replacing any
/// previously registered validator. Returns `&mut Self` for chaining.
macro_rules! validator {
    ($name:ident: $arg:ty, $slot:ident) => {
        pub fn $name<F: 'static>(&mut self, f: F) -> &mut Self
            where F: Fn(&DeriveGenerator, $arg) -> Result<()>
        {
            // Only one validator per slot: the new one overwrites the old.
            self.$slot = Box::new(f);
            self
        }
    };
}
25
/// Generates the mapper API used inside `impl DeriveGenerator`:
///
/// * `push_default_mappers()` appends `default_<getter>` to every listed
///   mapper vector. It is called once per registered function, so each
///   vector holds one mapper per function.
/// * For each `($map, $try_map, $getter): $ty, $vec` entry:
///   - `$map(f)` replaces the most recently pushed mapper with the
///     infallible `f`, wrapped so it returns `Ok(..)`;
///   - `$try_map(f)` replaces the most recently pushed mapper with `f`;
///   - `$getter()` returns the most recently pushed mapper.
///
/// NOTE(review): `$map`/`$try_map` silently do nothing while `$vec` is empty
/// (i.e. before any function has been registered), whereas `$getter` panics
/// in that case. `$getter` always returns the *last* entry, not the entry of
/// the function currently being generated.
macro_rules! mappers {
    ($(($map_f:ident, $try_f:ident, $get_f:ident): $type:ty, $vec:ident),*) => {
        crate fn push_default_mappers(&mut self) {
            $( self.$vec.push(Box::new(concat_idents!(default_, $get_f))); )*
        }

        $(
            pub fn $map_f<F: 'static>(&mut self, f: F) -> &mut Self
                where F: Fn(&DeriveGenerator, $type) -> TokenStream2
            {
                if let Some(slot) = self.$vec.last_mut() {
                    *slot = Box::new(move |g, v| Ok(f(g, v)));
                }

                self
            }

            pub fn $try_f<F: 'static>(&mut self, f: F) -> &mut Self
                where F: Fn(&DeriveGenerator, $type) -> MapResult
            {
                if let Some(slot) = self.$vec.last_mut() {
                    *slot = Box::new(f);
                }

                self
            }

            pub fn $get_f(&self) -> &Box<Fn(&DeriveGenerator, $type) -> MapResult> {
                assert!(!self.$vec.is_empty());
                // Cannot fail: emptiness was ruled out just above.
                self.$vec.last().unwrap()
            }
        )*
    };
}
63
64 // FIXME: Take a `Box<Fn>` everywhere so we can capture args!
65 pub struct DeriveGenerator {
66 pub input: syn::DeriveInput,
67 pub trait_impl: syn::ItemImpl,
68 pub trait_path: syn::Path,
69 crate generic_support: GenericSupport,
70 crate data_support: DataSupport,
71 crate enum_validator: Box<Fn(&DeriveGenerator, Enum) -> Result<()>>,
72 crate struct_validator: Box<Fn(&DeriveGenerator, Struct) -> Result<()>>,
73 crate generics_validator: Box<Fn(&DeriveGenerator, &::syn::Generics) -> Result<()>>,
74 crate fields_validator: Box<Fn(&DeriveGenerator, Fields) -> Result<()>>,
75 crate type_generic_mapper: Option<Box<Fn(&DeriveGenerator, &syn::Ident, &syn::TypeParam) -> TokenStream2>>,
76 crate generic_replacements: Vec<(usize, usize)>,
77 crate functions: Vec<Box<Fn(&DeriveGenerator, TokenStream2) -> TokenStream2>>,
78 crate enum_mappers: Vec<Box<Fn(&DeriveGenerator, Enum) -> MapResult>>,
79 crate struct_mappers: Vec<Box<Fn(&DeriveGenerator, Struct) -> MapResult>>,
80 crate variant_mappers: Vec<Box<Fn(&DeriveGenerator, Variant) -> MapResult>>,
81 crate fields_mappers: Vec<Box<Fn(&DeriveGenerator, Fields) -> MapResult>>,
82 crate field_mappers: Vec<Box<Fn(&DeriveGenerator, Field) -> MapResult>>,
83 }
84
default_enum_mapper(gen: &DeriveGenerator, data: Enum) -> MapResult85 pub fn default_enum_mapper(gen: &DeriveGenerator, data: Enum) -> MapResult {
86 let variant = data.variants().map(|v| &v.value.ident);
87 let fields = data.variants().map(|v| v.fields().match_tokens());
88 let enum_name = ::std::iter::repeat(&data.derive_input.ident);
89 let expression = data.variants()
90 .map(|v| gen.variant_mapper()(gen, v))
91 .collect::<Result<Vec<_>>>()?;
92
93 Ok(quote! {
94 // FIXME: Check if we can also use id_match_tokens due to match
95 // ergonomics. I don't think so, though. If we can't, then ask (in
96 // `function`) whether receiver is `&self`, `&mut self` or `self` and
97 // bind match accordingly.
98 match self {
99 #(#enum_name::#variant #fields => { #expression }),*
100 }
101 })
102 }
103
null_enum_mapper(gen: &DeriveGenerator, data: Enum) -> MapResult104 pub fn null_enum_mapper(gen: &DeriveGenerator, data: Enum) -> MapResult {
105 let expression = data.variants()
106 .map(|v| gen.variant_mapper()(gen, v))
107 .collect::<Result<Vec<_>>>()?;
108
109 Ok(quote!(#(#expression)*))
110 }
111
default_struct_mapper(gen: &DeriveGenerator, data: Struct) -> MapResult112 pub fn default_struct_mapper(gen: &DeriveGenerator, data: Struct) -> MapResult {
113 gen.fields_mapper()(gen, data.fields())
114 }
115
default_variant_mapper(gen: &DeriveGenerator, data: Variant) -> MapResult116 pub fn default_variant_mapper(gen: &DeriveGenerator, data: Variant) -> MapResult {
117 gen.fields_mapper()(gen, data.fields())
118 }
119
default_field_mapper(_gen: &DeriveGenerator, _data: Field) -> MapResult120 pub fn default_field_mapper(_gen: &DeriveGenerator, _data: Field) -> MapResult {
121 Ok(TokenStream2::new())
122 }
123
default_fields_mapper(g: &DeriveGenerator, fields: Fields) -> MapResult124 pub fn default_fields_mapper(g: &DeriveGenerator, fields: Fields) -> MapResult {
125 let field = fields.iter()
126 .map(|field| g.field_mapper()(g, field))
127 .collect::<Result<Vec<_>>>()?;
128
129 Ok(quote!({ #(#field)* }))
130 }
131
132 impl DeriveGenerator {
build_for(input: TokenStream, trait_impl: TokenStream2) -> DeriveGenerator133 pub fn build_for(input: TokenStream, trait_impl: TokenStream2) -> DeriveGenerator {
134 let trait_impl: syn::ItemImpl = syn::parse2(quote!(#trait_impl for Foo {}))
135 .expect("invalid impl");
136 let trait_path = trait_impl.trait_.clone().expect("impl does not have trait").1;
137 let input = syn::parse(input).expect("invalid derive input");
138
139 DeriveGenerator {
140 input, trait_impl, trait_path,
141 generic_support: GenericSupport::None,
142 data_support: DataSupport::None,
143 type_generic_mapper: None,
144 generic_replacements: vec![],
145 enum_validator: Box::new(|_, _| Ok(())),
146 struct_validator: Box::new(|_, _| Ok(())),
147 generics_validator: Box::new(|_, _| Ok(())),
148 fields_validator: Box::new(|_, _| Ok(())),
149 functions: vec![],
150 enum_mappers: vec![],
151 struct_mappers: vec![],
152 variant_mappers: vec![],
153 field_mappers: vec![],
154 fields_mappers: vec![],
155 }
156 }
157
generic_support(&mut self, support: GenericSupport) -> &mut Self158 pub fn generic_support(&mut self, support: GenericSupport) -> &mut Self {
159 self.generic_support = support;
160 self
161 }
162
data_support(&mut self, support: DataSupport) -> &mut Self163 pub fn data_support(&mut self, support: DataSupport) -> &mut Self {
164 self.data_support = support;
165 self
166 }
167
map_type_generic<F: 'static>(&mut self, f: F) -> &mut Self where F: Fn(&DeriveGenerator, &syn::Ident, &syn::TypeParam) -> TokenStream2168 pub fn map_type_generic<F: 'static>(&mut self, f: F) -> &mut Self
169 where F: Fn(&DeriveGenerator, &syn::Ident, &syn::TypeParam) -> TokenStream2
170 {
171 self.type_generic_mapper = Some(Box::new(f));
172 self
173 }
174
replace_generic(&mut self, trait_gen: usize, impl_gen: usize) -> &mut Self175 pub fn replace_generic(&mut self, trait_gen: usize, impl_gen: usize) -> &mut Self {
176 self.generic_replacements.push((trait_gen, impl_gen));
177 self
178 }
179
180 validator!(validate_enum: Enum, enum_validator);
181 validator!(validate_struct: Struct, struct_validator);
182 validator!(validate_generics: &syn::Generics, generics_validator);
183 validator!(validate_fields: Fields, fields_validator);
184
function<F: 'static>(&mut self, f: F) -> &mut Self where F: Fn(&DeriveGenerator, TokenStream2) -> TokenStream2185 pub fn function<F: 'static>(&mut self, f: F) -> &mut Self
186 where F: Fn(&DeriveGenerator, TokenStream2) -> TokenStream2
187 {
188 self.functions.push(Box::new(f));
189 self.push_default_mappers();
190 self
191 }
192
193 mappers! {
194 (map_struct, try_map_struct, struct_mapper): Struct, struct_mappers,
195 (map_enum, try_map_enum, enum_mapper): Enum, enum_mappers,
196 (map_variant, try_map_variant, variant_mapper): Variant, variant_mappers,
197 (map_fields, try_map_fields, fields_mapper): Fields, fields_mappers,
198 (map_field, try_map_field, field_mapper): Field, field_mappers
199 }
200
_to_tokens(&mut self) -> Result<TokenStream>201 fn _to_tokens(&mut self) -> Result<TokenStream> {
202 use syn::*;
203
204 // Step 1: Run all validators.
205 // Step 1a: First, check for data support.
206 let (span, support) = (self.input.span(), self.data_support);
207 match self.input.data {
208 Data::Struct(ref data) => {
209 let named = Struct::from(&self.input, data).fields().are_named();
210 if named && !support.contains(DataSupport::NamedStruct) {
211 return Err(span.error("named structs are not supported"));
212 }
213
214 if !named && !support.contains(DataSupport::TupleStruct) {
215 return Err(span.error("tuple structs are not supported"));
216 }
217 }
218 Data::Enum(..) if !support.contains(DataSupport::Enum) => {
219 return Err(span.error("enums are not supported"));
220 }
221 Data::Union(..) if !support.contains(DataSupport::Union) => {
222 return Err(span.error("unions are not supported"));
223 }
224 _ => { /* we're okay! */ }
225 }
226
227 // Step 1b: Second, check for generics support.
228 for generic in &self.input.generics.params {
229 use syn::GenericParam::*;
230
231 let (span, support) = (generic.span(), self.generic_support);
232 match generic {
233 Type(..) if !support.contains(GenericSupport::Type) => {
234 return Err(span.error("type generics are not supported"));
235 }
236 Lifetime(..) if !support.contains(GenericSupport::Lifetime) => {
237 return Err(span.error("lifetime generics are not supported"));
238 }
239 Const(..) if !support.contains(GenericSupport::Const) => {
240 return Err(span.error("const generics are not supported"));
241 }
242 _ => { /* we're okay! */ }
243 }
244 }
245
246 // Step 1c: Third, run the custom validators.
247 (self.generics_validator)(self, &self.input.generics)?;
248 match self.input.data {
249 Data::Struct(ref data) => {
250 let derived = Derived::from(&self.input, data);
251 (self.struct_validator)(self, derived)?;
252 (self.fields_validator)(self, derived.fields())?;
253 }
254 Data::Enum(ref data) => {
255 let derived = Derived::from(&self.input, data);
256 (self.enum_validator)(self, derived)?;
257 for variant in derived.variants() {
258 (self.fields_validator)(self, variant.fields())?;
259 }
260 }
261 Data::Union(ref _data) => unimplemented!("union custom validation"),
262 }
263
264 // Step 2: Generate the code!
265 // Step 2a: Generate the code for each function.
266 let mut function_code = vec![];
267 for i in 0..self.functions.len() {
268 let function = &self.functions[i];
269 let inner = match self.input.data {
270 Data::Struct(ref data) => {
271 let derived = Derived::from(&self.input, data);
272 self.struct_mappers[i](self, derived)?
273 }
274 Data::Enum(ref data) => {
275 let derived = Derived::from(&self.input, data);
276 self.enum_mappers[i](self, derived)?
277 }
278 Data::Union(ref _data) => unimplemented!("can't gen unions yet"),
279 };
280
281 function_code.push(function(self, inner));
282 }
283
284 // Step 2b: Create a couple of generics to mutate with user's input.
285 let mut generics = self.input.generics.clone();
286
287 // Step 2c: Add additional where bounds if the generator asks for it.
288 if let Some(ref type_mapper) = self.type_generic_mapper {
289 for ty in self.input.generics.type_params() {
290 let new_ty = type_mapper(self, &ty.ident, ty);
291 let clause = syn::parse2(new_ty).expect("invalid type generic mapping");
292 generics.make_where_clause().predicates.push(clause);
293 }
294 }
295
296 // Step 2d: Add any generics in the trait.
297 let mut generics_for_impl_generics = generics.clone();
298 for (i, trait_param) in self.trait_impl.generics.params.iter().enumerate() {
299 // Step 2d.0: Perform a generic replacement if requested. Here,
300 // we determine if a generic (i) in the trait is going to replace a
301 // generic in the user's type (the `jth` of the right kind).
302 let replacement = self.generic_replacements.iter()
303 .filter(|r| r.0 == i)
304 .next();
305
306 if let Some((_, j)) = replacement {
307 use syn::{punctuated::Punctuated, token::Comma};
308
309 // Step 2d.1: Actually perform the replacement.
310 let replace_in = |ps: &mut Punctuated<GenericParam, Comma>| -> bool {
311 ps.iter_mut()
312 .filter(|param| param.kind() == trait_param.kind())
313 .nth(*j)
314 .map(|impl_param| *impl_param = trait_param.clone())
315 .is_some()
316 };
317
318 // Step 2d.2: If it fails, insert a new impl generic.
319 // NOTE: It's critical that `generics` is attempted first!
320 // Otherwise, we might replace generics that don't exist in the
321 // user's type.
322 if !replace_in(&mut generics.params)
323 || !replace_in(&mut generics_for_impl_generics.params)
324 {
325 generics_for_impl_generics.params.insert(0, trait_param.clone());
326 }
327 } else {
328 // Step 2d.2: Otherwise, insert a new impl<..> generic.
329 generics_for_impl_generics.params.insert(0, trait_param.clone());
330 }
331 }
332
333 // Step 2e: Split the generics, but use the `impl_generics` from above.
334 let (impl_gen, _, _) = generics_for_impl_generics.split_for_impl();
335 let (_, ty_gen, where_gen) = generics.split_for_impl();
336
337 // Step 2b: Generate the complete implementation.
338 let target = &self.input.ident;
339 let trait_name = &self.trait_path;
340 Ok(quote! {
341 impl #impl_gen #trait_name for #target #ty_gen #where_gen {
342 #(#function_code)*
343 }
344 }.into())
345 }
346
debug(&mut self) -> &mut Self347 pub fn debug(&mut self) -> &mut Self {
348 match self._to_tokens() {
349 Ok(tokens) => println!("Tokens produced: {}", tokens.to_string()),
350 Err(e) => println!("Error produced: {:?}", e)
351 }
352
353 self
354 }
355
to_tokens(&mut self) -> TokenStream356 pub fn to_tokens(&mut self) -> TokenStream {
357 // FIXME: Emit something like: Trait: msg.
358 self._to_tokens()
359 .unwrap_or_else(|diag| {
360 if let Some(last) = self.trait_path.segments.last() {
361 use proc_macro::Span;
362 use proc_macro::Level::*;
363
364 let id = &last.ident;
365 let msg = match diag.level() {
366 Error => format!("error occurred while deriving `{}`", id),
367 Warning => format!("warning issued by `{}` derive", id),
368 Note => format!("note issued by `{}` derive", id),
369 Help => format!("help provided by `{}` derive", id),
370 _ => format!("while deriving `{}`", id)
371 };
372
373 diag.span_note(Span::call_site(), msg).emit();
374 }
375
376 TokenStream::new().into()
377 })
378 }
379 }
380