1 use proc_macro2::TokenStream;
2 use quote::quote;
3 use syn::{spanned::Spanned as _, Error, Result};
4
5 use crate::utils::{
6 self, AttrParams, DeriveType, FullMetaInfo, HashSet, MetaInfo, MultiFieldData,
7 State,
8 };
9
expand( input: &syn::DeriveInput, trait_name: &'static str, ) -> Result<TokenStream>10 pub fn expand(
11 input: &syn::DeriveInput,
12 trait_name: &'static str,
13 ) -> Result<TokenStream> {
14 let syn::DeriveInput {
15 ident, generics, ..
16 } = input;
17
18 let state = State::with_attr_params(
19 input,
20 trait_name,
21 quote!(::std::error),
22 trait_name.to_lowercase(),
23 allowed_attr_params(),
24 )?;
25
26 let type_params: HashSet<_> = generics
27 .params
28 .iter()
29 .filter_map(|generic| match generic {
30 syn::GenericParam::Type(ty) => Some(ty.ident.clone()),
31 _ => None,
32 })
33 .collect();
34
35 let (bounds, source, backtrace) = match state.derive_type {
36 DeriveType::Named | DeriveType::Unnamed => render_struct(&type_params, &state)?,
37 DeriveType::Enum => render_enum(&type_params, &state)?,
38 };
39
40 let source = source.map(|source| {
41 quote! {
42 fn source(&self) -> Option<&(dyn ::std::error::Error + 'static)> {
43 #source
44 }
45 }
46 });
47
48 let backtrace = backtrace.map(|backtrace| {
49 quote! {
50 fn backtrace(&self) -> Option<&::std::backtrace::Backtrace> {
51 #backtrace
52 }
53 }
54 });
55
56 let mut generics = generics.clone();
57
58 if !type_params.is_empty() {
59 let generic_parameters = generics.params.iter();
60 generics = utils::add_extra_where_clauses(
61 &generics,
62 quote! {
63 where
64 #ident<#(#generic_parameters),*>: ::std::fmt::Debug + ::std::fmt::Display
65 },
66 );
67 }
68
69 if !bounds.is_empty() {
70 let bounds = bounds.iter();
71 generics = utils::add_extra_where_clauses(
72 &generics,
73 quote! {
74 where
75 #(#bounds: ::std::fmt::Debug + ::std::fmt::Display + ::std::error::Error + 'static),*
76 },
77 );
78 }
79
80 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
81
82 let render = quote! {
83 impl#impl_generics ::std::error::Error for #ident#ty_generics #where_clause {
84 #source
85 #backtrace
86 }
87 };
88
89 Ok(render)
90 }
91
render_struct( type_params: &HashSet<syn::Ident>, state: &State, ) -> Result<(HashSet<syn::Type>, Option<TokenStream>, Option<TokenStream>)>92 fn render_struct(
93 type_params: &HashSet<syn::Ident>,
94 state: &State,
95 ) -> Result<(HashSet<syn::Type>, Option<TokenStream>, Option<TokenStream>)> {
96 let parsed_fields = parse_fields(&type_params, &state)?;
97
98 let source = parsed_fields.render_source_as_struct();
99 let backtrace = parsed_fields.render_backtrace_as_struct();
100
101 Ok((parsed_fields.bounds, source, backtrace))
102 }
103
render_enum( type_params: &HashSet<syn::Ident>, state: &State, ) -> Result<(HashSet<syn::Type>, Option<TokenStream>, Option<TokenStream>)>104 fn render_enum(
105 type_params: &HashSet<syn::Ident>,
106 state: &State,
107 ) -> Result<(HashSet<syn::Type>, Option<TokenStream>, Option<TokenStream>)> {
108 let mut bounds = HashSet::default();
109 let mut source_match_arms = Vec::new();
110 let mut backtrace_match_arms = Vec::new();
111
112 for variant in state.enabled_variant_data().variants {
113 let mut default_info = FullMetaInfo::default();
114 default_info.enabled = true;
115
116 let state = State::from_variant(
117 state.input,
118 state.trait_name,
119 state.trait_module.clone(),
120 state.trait_attr.clone(),
121 allowed_attr_params(),
122 variant,
123 default_info,
124 )?;
125
126 let parsed_fields = parse_fields(&type_params, &state)?;
127
128 if let Some(expr) = parsed_fields.render_source_as_enum_variant_match_arm() {
129 source_match_arms.push(expr);
130 }
131
132 if let Some(expr) = parsed_fields.render_backtrace_as_enum_variant_match_arm() {
133 backtrace_match_arms.push(expr);
134 }
135
136 bounds.extend(parsed_fields.bounds.into_iter());
137 }
138
139 let render = |match_arms: &mut Vec<TokenStream>| {
140 if !match_arms.is_empty() && match_arms.len() < state.variants.len() {
141 match_arms.push(quote!(_ => None));
142 }
143
144 if !match_arms.is_empty() {
145 let expr = quote! {
146 match self {
147 #(#match_arms),*
148 }
149 };
150
151 Some(expr)
152 } else {
153 None
154 }
155 };
156
157 let source = render(&mut source_match_arms);
158 let backtrace = render(&mut backtrace_match_arms);
159
160 Ok((bounds, source, backtrace))
161 }
162
allowed_attr_params() -> AttrParams163 fn allowed_attr_params() -> AttrParams {
164 AttrParams {
165 enum_: vec!["ignore"],
166 struct_: vec!["ignore"],
167 variant: vec!["ignore"],
168 field: vec!["ignore", "source", "backtrace"],
169 }
170 }
171
172 struct ParsedFields<'input, 'state> {
173 data: MultiFieldData<'input, 'state>,
174 source: Option<usize>,
175 backtrace: Option<usize>,
176 bounds: HashSet<syn::Type>,
177 }
178
179 impl<'input, 'state> ParsedFields<'input, 'state> {
new(data: MultiFieldData<'input, 'state>) -> Self180 fn new(data: MultiFieldData<'input, 'state>) -> Self {
181 Self {
182 data,
183 source: None,
184 backtrace: None,
185 bounds: HashSet::default(),
186 }
187 }
188 }
189
190 impl<'input, 'state> ParsedFields<'input, 'state> {
render_source_as_struct(&self) -> Option<TokenStream>191 fn render_source_as_struct(&self) -> Option<TokenStream> {
192 let source = self.source?;
193 let ident = &self.data.members[source];
194 Some(render_some(quote!(&#ident)))
195 }
196
render_source_as_enum_variant_match_arm(&self) -> Option<TokenStream>197 fn render_source_as_enum_variant_match_arm(&self) -> Option<TokenStream> {
198 let source = self.source?;
199 let pattern = self.data.matcher(&[source], &[quote!(source)]);
200 let expr = render_some(quote!(source));
201 Some(quote!(#pattern => #expr))
202 }
203
render_backtrace_as_struct(&self) -> Option<TokenStream>204 fn render_backtrace_as_struct(&self) -> Option<TokenStream> {
205 let backtrace = self.backtrace?;
206 let backtrace_expr = &self.data.members[backtrace];
207 Some(quote!(Some(&#backtrace_expr)))
208 }
209
render_backtrace_as_enum_variant_match_arm(&self) -> Option<TokenStream>210 fn render_backtrace_as_enum_variant_match_arm(&self) -> Option<TokenStream> {
211 let backtrace = self.backtrace?;
212 let pattern = self.data.matcher(&[backtrace], &[quote!(backtrace)]);
213 Some(quote!(#pattern => Some(backtrace)))
214 }
215 }
216
render_some<T>(expr: T) -> TokenStream where T: quote::ToTokens,217 fn render_some<T>(expr: T) -> TokenStream
218 where
219 T: quote::ToTokens,
220 {
221 quote!(Some(#expr as &(dyn ::std::error::Error + 'static)))
222 }
223
parse_fields<'input, 'state>( type_params: &HashSet<syn::Ident>, state: &'state State<'input>, ) -> Result<ParsedFields<'input, 'state>>224 fn parse_fields<'input, 'state>(
225 type_params: &HashSet<syn::Ident>,
226 state: &'state State<'input>,
227 ) -> Result<ParsedFields<'input, 'state>> {
228 let mut parsed_fields = match state.derive_type {
229 DeriveType::Named => {
230 parse_fields_impl(state, |attr, field, _| {
231 // Unwrapping is safe, cause fields in named struct
232 // always have an ident
233 let ident = field.ident.as_ref().unwrap();
234
235 match attr {
236 "source" => ident == "source",
237 "backtrace" => {
238 ident == "backtrace"
239 || is_type_path_ends_with_segment(&field.ty, "Backtrace")
240 }
241 _ => unreachable!(),
242 }
243 })
244 }
245
246 DeriveType::Unnamed => {
247 let mut parsed_fields =
248 parse_fields_impl(state, |attr, field, len| match attr {
249 "source" => {
250 len == 1
251 && !is_type_path_ends_with_segment(&field.ty, "Backtrace")
252 }
253 "backtrace" => {
254 is_type_path_ends_with_segment(&field.ty, "Backtrace")
255 }
256 _ => unreachable!(),
257 })?;
258
259 parsed_fields.source = parsed_fields
260 .source
261 .or_else(|| infer_source_field(&state.fields, &parsed_fields));
262
263 Ok(parsed_fields)
264 }
265
266 _ => unreachable!(),
267 }?;
268
269 if let Some(source) = parsed_fields.source {
270 add_bound_if_type_parameter_used_in_type(
271 &mut parsed_fields.bounds,
272 type_params,
273 &state.fields[source].ty,
274 );
275 }
276
277 Ok(parsed_fields)
278 }
279
280 /// Checks if `ty` is [`syn::Type::Path`] and ends with segment matching `tail`
281 /// and doesn't contain any generic parameters.
is_type_path_ends_with_segment(ty: &syn::Type, tail: &str) -> bool282 fn is_type_path_ends_with_segment(ty: &syn::Type, tail: &str) -> bool {
283 let ty = match ty {
284 syn::Type::Path(ty) => ty,
285 _ => return false,
286 };
287
288 // Unwrapping is safe, cause 'syn::TypePath.path.segments'
289 // have to have at least one segment
290 let segment = ty.path.segments.last().unwrap();
291
292 match segment.arguments {
293 syn::PathArguments::None => (),
294 _ => return false,
295 };
296
297 segment.ident == tail
298 }
299
infer_source_field( fields: &[&syn::Field], parsed_fields: &ParsedFields, ) -> Option<usize>300 fn infer_source_field(
301 fields: &[&syn::Field],
302 parsed_fields: &ParsedFields,
303 ) -> Option<usize> {
304 // if we have exactly two fields
305 if fields.len() != 2 {
306 return None;
307 }
308
309 // no source field was specified/inferred
310 if parsed_fields.source.is_some() {
311 return None;
312 }
313
314 // but one of the fields was specified/inferred as backtrace field
315 if let Some(backtrace) = parsed_fields.backtrace {
316 // then infer *other field* as source field
317 let source = (backtrace + 1) % 2;
318 // unless it was explicitly marked as non-source
319 if parsed_fields.data.infos[source].info.source != Some(false) {
320 return Some(source);
321 }
322 }
323
324 None
325 }
326
parse_fields_impl<'input, 'state, P>( state: &'state State<'input>, is_valid_default_field_for_attr: P, ) -> Result<ParsedFields<'input, 'state>> where P: Fn(&str, &syn::Field, usize) -> bool,327 fn parse_fields_impl<'input, 'state, P>(
328 state: &'state State<'input>,
329 is_valid_default_field_for_attr: P,
330 ) -> Result<ParsedFields<'input, 'state>>
331 where
332 P: Fn(&str, &syn::Field, usize) -> bool,
333 {
334 let MultiFieldData { fields, infos, .. } = state.enabled_fields_data();
335
336 let iter = fields
337 .iter()
338 .zip(infos.iter().map(|info| &info.info))
339 .enumerate()
340 .map(|(index, (field, info))| (index, *field, info));
341
342 let source = parse_field_impl(
343 &is_valid_default_field_for_attr,
344 state.fields.len(),
345 iter.clone(),
346 "source",
347 |info| info.source,
348 )?;
349
350 let backtrace = parse_field_impl(
351 &is_valid_default_field_for_attr,
352 state.fields.len(),
353 iter.clone(),
354 "backtrace",
355 |info| info.backtrace,
356 )?;
357
358 let mut parsed_fields = ParsedFields::new(state.enabled_fields_data());
359
360 if let Some((index, _, _)) = source {
361 parsed_fields.source = Some(index);
362 }
363
364 if let Some((index, _, _)) = backtrace {
365 parsed_fields.backtrace = Some(index);
366 }
367
368 Ok(parsed_fields)
369 }
370
parse_field_impl<'a, P, V>( is_valid_default_field_for_attr: &P, len: usize, iter: impl Iterator<Item = (usize, &'a syn::Field, &'a MetaInfo)> + Clone, attr: &str, value: V, ) -> Result<Option<(usize, &'a syn::Field, &'a MetaInfo)>> where P: Fn(&str, &syn::Field, usize) -> bool, V: Fn(&MetaInfo) -> Option<bool>,371 fn parse_field_impl<'a, P, V>(
372 is_valid_default_field_for_attr: &P,
373 len: usize,
374 iter: impl Iterator<Item = (usize, &'a syn::Field, &'a MetaInfo)> + Clone,
375 attr: &str,
376 value: V,
377 ) -> Result<Option<(usize, &'a syn::Field, &'a MetaInfo)>>
378 where
379 P: Fn(&str, &syn::Field, usize) -> bool,
380 V: Fn(&MetaInfo) -> Option<bool>,
381 {
382 let explicit_fields = iter.clone().filter(|(_, _, info)| match value(info) {
383 Some(true) => true,
384 _ => false,
385 });
386
387 let inferred_fields = iter.filter(|(_, field, info)| match value(info) {
388 None => is_valid_default_field_for_attr(attr, field, len),
389 _ => false,
390 });
391
392 let field = assert_iter_contains_zero_or_one_item(
393 explicit_fields,
394 &format!(
395 "Multiple `{}` attributes specified. \
396 Single attribute per struct/enum variant allowed.",
397 attr
398 ),
399 )?;
400
401 let field = match field {
402 field @ Some(_) => field,
403 None => assert_iter_contains_zero_or_one_item(
404 inferred_fields,
405 "Conflicting fields found. Consider specifying some \
406 `#[error(...)]` attributes to resolve conflict.",
407 )?,
408 };
409
410 Ok(field)
411 }
412
assert_iter_contains_zero_or_one_item<'a>( mut iter: impl Iterator<Item = (usize, &'a syn::Field, &'a MetaInfo)>, error_msg: &str, ) -> Result<Option<(usize, &'a syn::Field, &'a MetaInfo)>>413 fn assert_iter_contains_zero_or_one_item<'a>(
414 mut iter: impl Iterator<Item = (usize, &'a syn::Field, &'a MetaInfo)>,
415 error_msg: &str,
416 ) -> Result<Option<(usize, &'a syn::Field, &'a MetaInfo)>> {
417 let item = match iter.next() {
418 Some(item) => item,
419 None => return Ok(None),
420 };
421
422 if let Some((_, field, _)) = iter.next() {
423 return Err(Error::new(field.span(), error_msg));
424 }
425
426 Ok(Some(item))
427 }
428
add_bound_if_type_parameter_used_in_type( bounds: &mut HashSet<syn::Type>, type_params: &HashSet<syn::Ident>, ty: &syn::Type, )429 fn add_bound_if_type_parameter_used_in_type(
430 bounds: &mut HashSet<syn::Type>,
431 type_params: &HashSet<syn::Ident>,
432 ty: &syn::Type,
433 ) {
434 if let Some(ty) = utils::get_if_type_parameter_used_in_type(type_params, ty) {
435 bounds.insert(ty);
436 }
437 }
438