1 use proc_macro2::TokenStream;
2 use quote::quote;
3 use syn::{spanned::Spanned as _, Error, Result};
4 
5 use crate::utils::{
6     self, AttrParams, DeriveType, FullMetaInfo, HashSet, MetaInfo, MultiFieldData,
7     State,
8 };
9 
expand( input: &syn::DeriveInput, trait_name: &'static str, ) -> Result<TokenStream>10 pub fn expand(
11     input: &syn::DeriveInput,
12     trait_name: &'static str,
13 ) -> Result<TokenStream> {
14     let syn::DeriveInput {
15         ident, generics, ..
16     } = input;
17 
18     let state = State::with_attr_params(
19         input,
20         trait_name,
21         quote!(::std::error),
22         trait_name.to_lowercase(),
23         allowed_attr_params(),
24     )?;
25 
26     let type_params: HashSet<_> = generics
27         .params
28         .iter()
29         .filter_map(|generic| match generic {
30             syn::GenericParam::Type(ty) => Some(ty.ident.clone()),
31             _ => None,
32         })
33         .collect();
34 
35     let (bounds, source, backtrace) = match state.derive_type {
36         DeriveType::Named | DeriveType::Unnamed => render_struct(&type_params, &state)?,
37         DeriveType::Enum => render_enum(&type_params, &state)?,
38     };
39 
40     let source = source.map(|source| {
41         quote! {
42             fn source(&self) -> Option<&(dyn ::std::error::Error + 'static)> {
43                 #source
44             }
45         }
46     });
47 
48     let backtrace = backtrace.map(|backtrace| {
49         quote! {
50             fn backtrace(&self) -> Option<&::std::backtrace::Backtrace> {
51                 #backtrace
52             }
53         }
54     });
55 
56     let mut generics = generics.clone();
57 
58     if !type_params.is_empty() {
59         let generic_parameters = generics.params.iter();
60         generics = utils::add_extra_where_clauses(
61             &generics,
62             quote! {
63                 where
64                     #ident<#(#generic_parameters),*>: ::std::fmt::Debug + ::std::fmt::Display
65             },
66         );
67     }
68 
69     if !bounds.is_empty() {
70         let bounds = bounds.iter();
71         generics = utils::add_extra_where_clauses(
72             &generics,
73             quote! {
74                 where
75                     #(#bounds: ::std::fmt::Debug + ::std::fmt::Display + ::std::error::Error + 'static),*
76             },
77         );
78     }
79 
80     let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
81 
82     let render = quote! {
83         impl#impl_generics ::std::error::Error for #ident#ty_generics #where_clause {
84             #source
85             #backtrace
86         }
87     };
88 
89     Ok(render)
90 }
91 
render_struct( type_params: &HashSet<syn::Ident>, state: &State, ) -> Result<(HashSet<syn::Type>, Option<TokenStream>, Option<TokenStream>)>92 fn render_struct(
93     type_params: &HashSet<syn::Ident>,
94     state: &State,
95 ) -> Result<(HashSet<syn::Type>, Option<TokenStream>, Option<TokenStream>)> {
96     let parsed_fields = parse_fields(&type_params, &state)?;
97 
98     let source = parsed_fields.render_source_as_struct();
99     let backtrace = parsed_fields.render_backtrace_as_struct();
100 
101     Ok((parsed_fields.bounds, source, backtrace))
102 }
103 
render_enum( type_params: &HashSet<syn::Ident>, state: &State, ) -> Result<(HashSet<syn::Type>, Option<TokenStream>, Option<TokenStream>)>104 fn render_enum(
105     type_params: &HashSet<syn::Ident>,
106     state: &State,
107 ) -> Result<(HashSet<syn::Type>, Option<TokenStream>, Option<TokenStream>)> {
108     let mut bounds = HashSet::default();
109     let mut source_match_arms = Vec::new();
110     let mut backtrace_match_arms = Vec::new();
111 
112     for variant in state.enabled_variant_data().variants {
113         let mut default_info = FullMetaInfo::default();
114         default_info.enabled = true;
115 
116         let state = State::from_variant(
117             state.input,
118             state.trait_name,
119             state.trait_module.clone(),
120             state.trait_attr.clone(),
121             allowed_attr_params(),
122             variant,
123             default_info,
124         )?;
125 
126         let parsed_fields = parse_fields(&type_params, &state)?;
127 
128         if let Some(expr) = parsed_fields.render_source_as_enum_variant_match_arm() {
129             source_match_arms.push(expr);
130         }
131 
132         if let Some(expr) = parsed_fields.render_backtrace_as_enum_variant_match_arm() {
133             backtrace_match_arms.push(expr);
134         }
135 
136         bounds.extend(parsed_fields.bounds.into_iter());
137     }
138 
139     let render = |match_arms: &mut Vec<TokenStream>| {
140         if !match_arms.is_empty() && match_arms.len() < state.variants.len() {
141             match_arms.push(quote!(_ => None));
142         }
143 
144         if !match_arms.is_empty() {
145             let expr = quote! {
146                 match self {
147                     #(#match_arms),*
148                 }
149             };
150 
151             Some(expr)
152         } else {
153             None
154         }
155     };
156 
157     let source = render(&mut source_match_arms);
158     let backtrace = render(&mut backtrace_match_arms);
159 
160     Ok((bounds, source, backtrace))
161 }
162 
allowed_attr_params() -> AttrParams163 fn allowed_attr_params() -> AttrParams {
164     AttrParams {
165         enum_: vec!["ignore"],
166         struct_: vec!["ignore"],
167         variant: vec!["ignore"],
168         field: vec!["ignore", "source", "backtrace"],
169     }
170 }
171 
/// Outcome of analysing the fields of a struct or of a single enum variant.
struct ParsedFields<'input, 'state> {
    /// The enabled fields, as produced by `State::enabled_fields_data()`.
    data: MultiFieldData<'input, 'state>,
    /// Index into `data` of the field returned from `Error::source()`, if any.
    source: Option<usize>,
    /// Index into `data` of the field returned from `Error::backtrace()`, if any.
    backtrace: Option<usize>,
    /// Types reported by `utils::get_if_type_parameter_used_in_type` for the
    /// source field; `expand` bounds each by `Debug + Display + Error + 'static`.
    bounds: HashSet<syn::Type>,
}
178 
179 impl<'input, 'state> ParsedFields<'input, 'state> {
new(data: MultiFieldData<'input, 'state>) -> Self180     fn new(data: MultiFieldData<'input, 'state>) -> Self {
181         Self {
182             data,
183             source: None,
184             backtrace: None,
185             bounds: HashSet::default(),
186         }
187     }
188 }
189 
190 impl<'input, 'state> ParsedFields<'input, 'state> {
render_source_as_struct(&self) -> Option<TokenStream>191     fn render_source_as_struct(&self) -> Option<TokenStream> {
192         let source = self.source?;
193         let ident = &self.data.members[source];
194         Some(render_some(quote!(&#ident)))
195     }
196 
render_source_as_enum_variant_match_arm(&self) -> Option<TokenStream>197     fn render_source_as_enum_variant_match_arm(&self) -> Option<TokenStream> {
198         let source = self.source?;
199         let pattern = self.data.matcher(&[source], &[quote!(source)]);
200         let expr = render_some(quote!(source));
201         Some(quote!(#pattern => #expr))
202     }
203 
render_backtrace_as_struct(&self) -> Option<TokenStream>204     fn render_backtrace_as_struct(&self) -> Option<TokenStream> {
205         let backtrace = self.backtrace?;
206         let backtrace_expr = &self.data.members[backtrace];
207         Some(quote!(Some(&#backtrace_expr)))
208     }
209 
render_backtrace_as_enum_variant_match_arm(&self) -> Option<TokenStream>210     fn render_backtrace_as_enum_variant_match_arm(&self) -> Option<TokenStream> {
211         let backtrace = self.backtrace?;
212         let pattern = self.data.matcher(&[backtrace], &[quote!(backtrace)]);
213         Some(quote!(#pattern => Some(backtrace)))
214     }
215 }
216 
render_some<T>(expr: T) -> TokenStream where T: quote::ToTokens,217 fn render_some<T>(expr: T) -> TokenStream
218 where
219     T: quote::ToTokens,
220 {
221     quote!(Some(#expr as &(dyn ::std::error::Error + 'static)))
222 }
223 
parse_fields<'input, 'state>( type_params: &HashSet<syn::Ident>, state: &'state State<'input>, ) -> Result<ParsedFields<'input, 'state>>224 fn parse_fields<'input, 'state>(
225     type_params: &HashSet<syn::Ident>,
226     state: &'state State<'input>,
227 ) -> Result<ParsedFields<'input, 'state>> {
228     let mut parsed_fields = match state.derive_type {
229         DeriveType::Named => {
230             parse_fields_impl(state, |attr, field, _| {
231                 // Unwrapping is safe, cause fields in named struct
232                 // always have an ident
233                 let ident = field.ident.as_ref().unwrap();
234 
235                 match attr {
236                     "source" => ident == "source",
237                     "backtrace" => {
238                         ident == "backtrace"
239                             || is_type_path_ends_with_segment(&field.ty, "Backtrace")
240                     }
241                     _ => unreachable!(),
242                 }
243             })
244         }
245 
246         DeriveType::Unnamed => {
247             let mut parsed_fields =
248                 parse_fields_impl(state, |attr, field, len| match attr {
249                     "source" => {
250                         len == 1
251                             && !is_type_path_ends_with_segment(&field.ty, "Backtrace")
252                     }
253                     "backtrace" => {
254                         is_type_path_ends_with_segment(&field.ty, "Backtrace")
255                     }
256                     _ => unreachable!(),
257                 })?;
258 
259             parsed_fields.source = parsed_fields
260                 .source
261                 .or_else(|| infer_source_field(&state.fields, &parsed_fields));
262 
263             Ok(parsed_fields)
264         }
265 
266         _ => unreachable!(),
267     }?;
268 
269     if let Some(source) = parsed_fields.source {
270         add_bound_if_type_parameter_used_in_type(
271             &mut parsed_fields.bounds,
272             type_params,
273             &state.fields[source].ty,
274         );
275     }
276 
277     Ok(parsed_fields)
278 }
279 
280 /// Checks if `ty` is [`syn::Type::Path`] and ends with segment matching `tail`
281 /// and doesn't contain any generic parameters.
is_type_path_ends_with_segment(ty: &syn::Type, tail: &str) -> bool282 fn is_type_path_ends_with_segment(ty: &syn::Type, tail: &str) -> bool {
283     let ty = match ty {
284         syn::Type::Path(ty) => ty,
285         _ => return false,
286     };
287 
288     // Unwrapping is safe, cause 'syn::TypePath.path.segments'
289     // have to have at least one segment
290     let segment = ty.path.segments.last().unwrap();
291 
292     match segment.arguments {
293         syn::PathArguments::None => (),
294         _ => return false,
295     };
296 
297     segment.ident == tail
298 }
299 
infer_source_field( fields: &[&syn::Field], parsed_fields: &ParsedFields, ) -> Option<usize>300 fn infer_source_field(
301     fields: &[&syn::Field],
302     parsed_fields: &ParsedFields,
303 ) -> Option<usize> {
304     // if we have exactly two fields
305     if fields.len() != 2 {
306         return None;
307     }
308 
309     // no source field was specified/inferred
310     if parsed_fields.source.is_some() {
311         return None;
312     }
313 
314     // but one of the fields was specified/inferred as backtrace field
315     if let Some(backtrace) = parsed_fields.backtrace {
316         // then infer *other field* as source field
317         let source = (backtrace + 1) % 2;
318         // unless it was explicitly marked as non-source
319         if parsed_fields.data.infos[source].info.source != Some(false) {
320             return Some(source);
321         }
322     }
323 
324     None
325 }
326 
parse_fields_impl<'input, 'state, P>( state: &'state State<'input>, is_valid_default_field_for_attr: P, ) -> Result<ParsedFields<'input, 'state>> where P: Fn(&str, &syn::Field, usize) -> bool,327 fn parse_fields_impl<'input, 'state, P>(
328     state: &'state State<'input>,
329     is_valid_default_field_for_attr: P,
330 ) -> Result<ParsedFields<'input, 'state>>
331 where
332     P: Fn(&str, &syn::Field, usize) -> bool,
333 {
334     let MultiFieldData { fields, infos, .. } = state.enabled_fields_data();
335 
336     let iter = fields
337         .iter()
338         .zip(infos.iter().map(|info| &info.info))
339         .enumerate()
340         .map(|(index, (field, info))| (index, *field, info));
341 
342     let source = parse_field_impl(
343         &is_valid_default_field_for_attr,
344         state.fields.len(),
345         iter.clone(),
346         "source",
347         |info| info.source,
348     )?;
349 
350     let backtrace = parse_field_impl(
351         &is_valid_default_field_for_attr,
352         state.fields.len(),
353         iter.clone(),
354         "backtrace",
355         |info| info.backtrace,
356     )?;
357 
358     let mut parsed_fields = ParsedFields::new(state.enabled_fields_data());
359 
360     if let Some((index, _, _)) = source {
361         parsed_fields.source = Some(index);
362     }
363 
364     if let Some((index, _, _)) = backtrace {
365         parsed_fields.backtrace = Some(index);
366     }
367 
368     Ok(parsed_fields)
369 }
370 
parse_field_impl<'a, P, V>( is_valid_default_field_for_attr: &P, len: usize, iter: impl Iterator<Item = (usize, &'a syn::Field, &'a MetaInfo)> + Clone, attr: &str, value: V, ) -> Result<Option<(usize, &'a syn::Field, &'a MetaInfo)>> where P: Fn(&str, &syn::Field, usize) -> bool, V: Fn(&MetaInfo) -> Option<bool>,371 fn parse_field_impl<'a, P, V>(
372     is_valid_default_field_for_attr: &P,
373     len: usize,
374     iter: impl Iterator<Item = (usize, &'a syn::Field, &'a MetaInfo)> + Clone,
375     attr: &str,
376     value: V,
377 ) -> Result<Option<(usize, &'a syn::Field, &'a MetaInfo)>>
378 where
379     P: Fn(&str, &syn::Field, usize) -> bool,
380     V: Fn(&MetaInfo) -> Option<bool>,
381 {
382     let explicit_fields = iter.clone().filter(|(_, _, info)| match value(info) {
383         Some(true) => true,
384         _ => false,
385     });
386 
387     let inferred_fields = iter.filter(|(_, field, info)| match value(info) {
388         None => is_valid_default_field_for_attr(attr, field, len),
389         _ => false,
390     });
391 
392     let field = assert_iter_contains_zero_or_one_item(
393         explicit_fields,
394         &format!(
395             "Multiple `{}` attributes specified. \
396              Single attribute per struct/enum variant allowed.",
397             attr
398         ),
399     )?;
400 
401     let field = match field {
402         field @ Some(_) => field,
403         None => assert_iter_contains_zero_or_one_item(
404             inferred_fields,
405             "Conflicting fields found. Consider specifying some \
406              `#[error(...)]` attributes to resolve conflict.",
407         )?,
408     };
409 
410     Ok(field)
411 }
412 
assert_iter_contains_zero_or_one_item<'a>( mut iter: impl Iterator<Item = (usize, &'a syn::Field, &'a MetaInfo)>, error_msg: &str, ) -> Result<Option<(usize, &'a syn::Field, &'a MetaInfo)>>413 fn assert_iter_contains_zero_or_one_item<'a>(
414     mut iter: impl Iterator<Item = (usize, &'a syn::Field, &'a MetaInfo)>,
415     error_msg: &str,
416 ) -> Result<Option<(usize, &'a syn::Field, &'a MetaInfo)>> {
417     let item = match iter.next() {
418         Some(item) => item,
419         None => return Ok(None),
420     };
421 
422     if let Some((_, field, _)) = iter.next() {
423         return Err(Error::new(field.span(), error_msg));
424     }
425 
426     Ok(Some(item))
427 }
428 
add_bound_if_type_parameter_used_in_type( bounds: &mut HashSet<syn::Type>, type_params: &HashSet<syn::Ident>, ty: &syn::Type, )429 fn add_bound_if_type_parameter_used_in_type(
430     bounds: &mut HashSet<syn::Type>,
431     type_params: &HashSet<syn::Ident>,
432     ty: &syn::Type,
433 ) {
434     if let Some(ty) = utils::get_if_type_parameter_used_in_type(type_params, ty) {
435         bounds.insert(ty);
436     }
437 }
438