1 use proc_macro2::TokenStream;
2 use quote::quote;
3 use syn::{spanned::Spanned as _, Error, Result};
4 
5 use crate::utils::{
6     self, AttrParams, DeriveType, FullMetaInfo, HashSet, MetaInfo, MultiFieldData,
7     State,
8 };
9 
expand( input: &syn::DeriveInput, trait_name: &'static str, ) -> Result<TokenStream>10 pub fn expand(
11     input: &syn::DeriveInput,
12     trait_name: &'static str,
13 ) -> Result<TokenStream> {
14     let syn::DeriveInput {
15         ident, generics, ..
16     } = input;
17 
18     let state = State::with_attr_params(
19         input,
20         trait_name,
21         quote!(::std::error),
22         trait_name.to_lowercase(),
23         allowed_attr_params(),
24     )?;
25 
26     let type_params: HashSet<_> = generics
27         .params
28         .iter()
29         .filter_map(|generic| match generic {
30             syn::GenericParam::Type(ty) => Some(ty.ident.clone()),
31             _ => None,
32         })
33         .collect();
34 
35     let (bounds, source, backtrace) = match state.derive_type {
36         DeriveType::Named | DeriveType::Unnamed => render_struct(&type_params, &state)?,
37         DeriveType::Enum => render_enum(&type_params, &state)?,
38     };
39 
40     let source = source.map(|source| {
41         quote! {
42             fn source(&self) -> Option<&(dyn ::std::error::Error + 'static)> {
43                 #source
44             }
45         }
46     });
47 
48     let backtrace = backtrace.map(|backtrace| {
49         quote! {
50             fn backtrace(&self) -> Option<&::std::backtrace::Backtrace> {
51                 #backtrace
52             }
53         }
54     });
55 
56     let mut generics = generics.clone();
57 
58     if !type_params.is_empty() {
59         let generic_parameters = generics.params.iter();
60         generics = utils::add_extra_where_clauses(
61             &generics,
62             quote! {
63                 where
64                     #ident<#(#generic_parameters),*>: ::std::fmt::Debug + ::std::fmt::Display
65             },
66         );
67     }
68 
69     if !bounds.is_empty() {
70         let bounds = bounds.iter();
71         generics = utils::add_extra_where_clauses(
72             &generics,
73             quote! {
74                 where
75                     #(#bounds: ::std::fmt::Debug + ::std::fmt::Display + ::std::error::Error + 'static),*
76             },
77         );
78     }
79 
80     let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
81 
82     let render = quote! {
83         impl#impl_generics ::std::error::Error for #ident#ty_generics #where_clause {
84             #source
85             #backtrace
86         }
87     };
88 
89     Ok(render)
90 }
91 
render_struct( type_params: &HashSet<syn::Ident>, state: &State, ) -> Result<(HashSet<syn::Type>, Option<TokenStream>, Option<TokenStream>)>92 fn render_struct(
93     type_params: &HashSet<syn::Ident>,
94     state: &State,
95 ) -> Result<(HashSet<syn::Type>, Option<TokenStream>, Option<TokenStream>)> {
96     let parsed_fields = parse_fields(type_params, state)?;
97 
98     let source = parsed_fields.render_source_as_struct();
99     let backtrace = parsed_fields.render_backtrace_as_struct();
100 
101     Ok((parsed_fields.bounds, source, backtrace))
102 }
103 
render_enum( type_params: &HashSet<syn::Ident>, state: &State, ) -> Result<(HashSet<syn::Type>, Option<TokenStream>, Option<TokenStream>)>104 fn render_enum(
105     type_params: &HashSet<syn::Ident>,
106     state: &State,
107 ) -> Result<(HashSet<syn::Type>, Option<TokenStream>, Option<TokenStream>)> {
108     let mut bounds = HashSet::default();
109     let mut source_match_arms = Vec::new();
110     let mut backtrace_match_arms = Vec::new();
111 
112     for variant in state.enabled_variant_data().variants {
113         let default_info = FullMetaInfo {
114             enabled: true,
115             ..FullMetaInfo::default()
116         };
117 
118         let state = State::from_variant(
119             state.input,
120             state.trait_name,
121             state.trait_module.clone(),
122             state.trait_attr.clone(),
123             allowed_attr_params(),
124             variant,
125             default_info,
126         )?;
127 
128         let parsed_fields = parse_fields(type_params, &state)?;
129 
130         if let Some(expr) = parsed_fields.render_source_as_enum_variant_match_arm() {
131             source_match_arms.push(expr);
132         }
133 
134         if let Some(expr) = parsed_fields.render_backtrace_as_enum_variant_match_arm() {
135             backtrace_match_arms.push(expr);
136         }
137 
138         bounds.extend(parsed_fields.bounds.into_iter());
139     }
140 
141     let render = |match_arms: &mut Vec<TokenStream>| {
142         if !match_arms.is_empty() && match_arms.len() < state.variants.len() {
143             match_arms.push(quote!(_ => None));
144         }
145 
146         if !match_arms.is_empty() {
147             let expr = quote! {
148                 match self {
149                     #(#match_arms),*
150                 }
151             };
152 
153             Some(expr)
154         } else {
155             None
156         }
157     };
158 
159     let source = render(&mut source_match_arms);
160     let backtrace = render(&mut backtrace_match_arms);
161 
162     Ok((bounds, source, backtrace))
163 }
164 
allowed_attr_params() -> AttrParams165 fn allowed_attr_params() -> AttrParams {
166     AttrParams {
167         enum_: vec!["ignore"],
168         struct_: vec!["ignore"],
169         variant: vec!["ignore"],
170         field: vec!["ignore", "source", "backtrace"],
171     }
172 }
173 
174 struct ParsedFields<'input, 'state> {
175     data: MultiFieldData<'input, 'state>,
176     source: Option<usize>,
177     backtrace: Option<usize>,
178     bounds: HashSet<syn::Type>,
179 }
180 
181 impl<'input, 'state> ParsedFields<'input, 'state> {
new(data: MultiFieldData<'input, 'state>) -> Self182     fn new(data: MultiFieldData<'input, 'state>) -> Self {
183         Self {
184             data,
185             source: None,
186             backtrace: None,
187             bounds: HashSet::default(),
188         }
189     }
190 }
191 
192 impl<'input, 'state> ParsedFields<'input, 'state> {
render_source_as_struct(&self) -> Option<TokenStream>193     fn render_source_as_struct(&self) -> Option<TokenStream> {
194         let source = self.source?;
195         let ident = &self.data.members[source];
196         Some(render_some(quote!(&#ident)))
197     }
198 
render_source_as_enum_variant_match_arm(&self) -> Option<TokenStream>199     fn render_source_as_enum_variant_match_arm(&self) -> Option<TokenStream> {
200         let source = self.source?;
201         let pattern = self.data.matcher(&[source], &[quote!(source)]);
202         let expr = render_some(quote!(source));
203         Some(quote!(#pattern => #expr))
204     }
205 
render_backtrace_as_struct(&self) -> Option<TokenStream>206     fn render_backtrace_as_struct(&self) -> Option<TokenStream> {
207         let backtrace = self.backtrace?;
208         let backtrace_expr = &self.data.members[backtrace];
209         Some(quote!(Some(&#backtrace_expr)))
210     }
211 
render_backtrace_as_enum_variant_match_arm(&self) -> Option<TokenStream>212     fn render_backtrace_as_enum_variant_match_arm(&self) -> Option<TokenStream> {
213         let backtrace = self.backtrace?;
214         let pattern = self.data.matcher(&[backtrace], &[quote!(backtrace)]);
215         Some(quote!(#pattern => Some(backtrace)))
216     }
217 }
218 
render_some<T>(expr: T) -> TokenStream where T: quote::ToTokens,219 fn render_some<T>(expr: T) -> TokenStream
220 where
221     T: quote::ToTokens,
222 {
223     quote!(Some(#expr as &(dyn ::std::error::Error + 'static)))
224 }
225 
parse_fields<'input, 'state>( type_params: &HashSet<syn::Ident>, state: &'state State<'input>, ) -> Result<ParsedFields<'input, 'state>>226 fn parse_fields<'input, 'state>(
227     type_params: &HashSet<syn::Ident>,
228     state: &'state State<'input>,
229 ) -> Result<ParsedFields<'input, 'state>> {
230     let mut parsed_fields = match state.derive_type {
231         DeriveType::Named => {
232             parse_fields_impl(state, |attr, field, _| {
233                 // Unwrapping is safe, cause fields in named struct
234                 // always have an ident
235                 let ident = field.ident.as_ref().unwrap();
236 
237                 match attr {
238                     "source" => ident == "source",
239                     "backtrace" => {
240                         ident == "backtrace"
241                             || is_type_path_ends_with_segment(&field.ty, "Backtrace")
242                     }
243                     _ => unreachable!(),
244                 }
245             })
246         }
247 
248         DeriveType::Unnamed => {
249             let mut parsed_fields =
250                 parse_fields_impl(state, |attr, field, len| match attr {
251                     "source" => {
252                         len == 1
253                             && !is_type_path_ends_with_segment(&field.ty, "Backtrace")
254                     }
255                     "backtrace" => {
256                         is_type_path_ends_with_segment(&field.ty, "Backtrace")
257                     }
258                     _ => unreachable!(),
259                 })?;
260 
261             parsed_fields.source = parsed_fields
262                 .source
263                 .or_else(|| infer_source_field(&state.fields, &parsed_fields));
264 
265             Ok(parsed_fields)
266         }
267 
268         _ => unreachable!(),
269     }?;
270 
271     if let Some(source) = parsed_fields.source {
272         add_bound_if_type_parameter_used_in_type(
273             &mut parsed_fields.bounds,
274             type_params,
275             &state.fields[source].ty,
276         );
277     }
278 
279     Ok(parsed_fields)
280 }
281 
282 /// Checks if `ty` is [`syn::Type::Path`] and ends with segment matching `tail`
283 /// and doesn't contain any generic parameters.
is_type_path_ends_with_segment(ty: &syn::Type, tail: &str) -> bool284 fn is_type_path_ends_with_segment(ty: &syn::Type, tail: &str) -> bool {
285     let ty = match ty {
286         syn::Type::Path(ty) => ty,
287         _ => return false,
288     };
289 
290     // Unwrapping is safe, cause 'syn::TypePath.path.segments'
291     // have to have at least one segment
292     let segment = ty.path.segments.last().unwrap();
293 
294     match segment.arguments {
295         syn::PathArguments::None => (),
296         _ => return false,
297     };
298 
299     segment.ident == tail
300 }
301 
infer_source_field( fields: &[&syn::Field], parsed_fields: &ParsedFields, ) -> Option<usize>302 fn infer_source_field(
303     fields: &[&syn::Field],
304     parsed_fields: &ParsedFields,
305 ) -> Option<usize> {
306     // if we have exactly two fields
307     if fields.len() != 2 {
308         return None;
309     }
310 
311     // no source field was specified/inferred
312     if parsed_fields.source.is_some() {
313         return None;
314     }
315 
316     // but one of the fields was specified/inferred as backtrace field
317     if let Some(backtrace) = parsed_fields.backtrace {
318         // then infer *other field* as source field
319         let source = (backtrace + 1) % 2;
320         // unless it was explicitly marked as non-source
321         if parsed_fields.data.infos[source].info.source != Some(false) {
322             return Some(source);
323         }
324     }
325 
326     None
327 }
328 
parse_fields_impl<'input, 'state, P>( state: &'state State<'input>, is_valid_default_field_for_attr: P, ) -> Result<ParsedFields<'input, 'state>> where P: Fn(&str, &syn::Field, usize) -> bool,329 fn parse_fields_impl<'input, 'state, P>(
330     state: &'state State<'input>,
331     is_valid_default_field_for_attr: P,
332 ) -> Result<ParsedFields<'input, 'state>>
333 where
334     P: Fn(&str, &syn::Field, usize) -> bool,
335 {
336     let MultiFieldData { fields, infos, .. } = state.enabled_fields_data();
337 
338     let iter = fields
339         .iter()
340         .zip(infos.iter().map(|info| &info.info))
341         .enumerate()
342         .map(|(index, (field, info))| (index, *field, info));
343 
344     let source = parse_field_impl(
345         &is_valid_default_field_for_attr,
346         state.fields.len(),
347         iter.clone(),
348         "source",
349         |info| info.source,
350     )?;
351 
352     let backtrace = parse_field_impl(
353         &is_valid_default_field_for_attr,
354         state.fields.len(),
355         iter.clone(),
356         "backtrace",
357         |info| info.backtrace,
358     )?;
359 
360     let mut parsed_fields = ParsedFields::new(state.enabled_fields_data());
361 
362     if let Some((index, _, _)) = source {
363         parsed_fields.source = Some(index);
364     }
365 
366     if let Some((index, _, _)) = backtrace {
367         parsed_fields.backtrace = Some(index);
368     }
369 
370     Ok(parsed_fields)
371 }
372 
parse_field_impl<'a, P, V>( is_valid_default_field_for_attr: &P, len: usize, iter: impl Iterator<Item = (usize, &'a syn::Field, &'a MetaInfo)> + Clone, attr: &str, value: V, ) -> Result<Option<(usize, &'a syn::Field, &'a MetaInfo)>> where P: Fn(&str, &syn::Field, usize) -> bool, V: Fn(&MetaInfo) -> Option<bool>,373 fn parse_field_impl<'a, P, V>(
374     is_valid_default_field_for_attr: &P,
375     len: usize,
376     iter: impl Iterator<Item = (usize, &'a syn::Field, &'a MetaInfo)> + Clone,
377     attr: &str,
378     value: V,
379 ) -> Result<Option<(usize, &'a syn::Field, &'a MetaInfo)>>
380 where
381     P: Fn(&str, &syn::Field, usize) -> bool,
382     V: Fn(&MetaInfo) -> Option<bool>,
383 {
384     let explicit_fields = iter.clone().filter(|(_, _, info)| match value(info) {
385         Some(true) => true,
386         _ => false,
387     });
388 
389     let inferred_fields = iter.filter(|(_, field, info)| match value(info) {
390         None => is_valid_default_field_for_attr(attr, field, len),
391         _ => false,
392     });
393 
394     let field = assert_iter_contains_zero_or_one_item(
395         explicit_fields,
396         &format!(
397             "Multiple `{}` attributes specified. \
398              Single attribute per struct/enum variant allowed.",
399             attr
400         ),
401     )?;
402 
403     let field = match field {
404         field @ Some(_) => field,
405         None => assert_iter_contains_zero_or_one_item(
406             inferred_fields,
407             "Conflicting fields found. Consider specifying some \
408              `#[error(...)]` attributes to resolve conflict.",
409         )?,
410     };
411 
412     Ok(field)
413 }
414 
assert_iter_contains_zero_or_one_item<'a>( mut iter: impl Iterator<Item = (usize, &'a syn::Field, &'a MetaInfo)>, error_msg: &str, ) -> Result<Option<(usize, &'a syn::Field, &'a MetaInfo)>>415 fn assert_iter_contains_zero_or_one_item<'a>(
416     mut iter: impl Iterator<Item = (usize, &'a syn::Field, &'a MetaInfo)>,
417     error_msg: &str,
418 ) -> Result<Option<(usize, &'a syn::Field, &'a MetaInfo)>> {
419     let item = match iter.next() {
420         Some(item) => item,
421         None => return Ok(None),
422     };
423 
424     if let Some((_, field, _)) = iter.next() {
425         return Err(Error::new(field.span(), error_msg));
426     }
427 
428     Ok(Some(item))
429 }
430 
add_bound_if_type_parameter_used_in_type( bounds: &mut HashSet<syn::Type>, type_params: &HashSet<syn::Ident>, ty: &syn::Type, )431 fn add_bound_if_type_parameter_used_in_type(
432     bounds: &mut HashSet<syn::Type>,
433     type_params: &HashSet<syn::Ident>,
434     ty: &syn::Type,
435 ) {
436     if let Some(ty) = utils::get_if_type_parameter_used_in_type(type_params, ty) {
437         bounds.insert(ty);
438     }
439 }
440