1 use crate::ast::{Enum, Field, Input, Struct, Variant};
2 use crate::attr::Attrs;
3 use quote::ToTokens;
4 use std::collections::BTreeSet as Set;
5 use syn::{Error, GenericArgument, Member, PathArguments, Result, Type};
6 
7 impl Input<'_> {
validate(&self) -> Result<()>8     pub(crate) fn validate(&self) -> Result<()> {
9         match self {
10             Input::Struct(input) => input.validate(),
11             Input::Enum(input) => input.validate(),
12         }
13     }
14 }
15 
16 impl Struct<'_> {
validate(&self) -> Result<()>17     fn validate(&self) -> Result<()> {
18         check_non_field_attrs(&self.attrs)?;
19         if let Some(transparent) = self.attrs.transparent {
20             if self.fields.len() != 1 {
21                 return Err(Error::new_spanned(
22                     transparent.original,
23                     "#[error(transparent)] requires exactly one field",
24                 ));
25             }
26             if let Some(source) = self.fields.iter().find_map(|f| f.attrs.source) {
27                 return Err(Error::new_spanned(
28                     source,
29                     "transparent error struct can't contain #[source]",
30                 ));
31             }
32         }
33         check_field_attrs(&self.fields)?;
34         for field in &self.fields {
35             field.validate()?;
36         }
37         Ok(())
38     }
39 }
40 
41 impl Enum<'_> {
validate(&self) -> Result<()>42     fn validate(&self) -> Result<()> {
43         check_non_field_attrs(&self.attrs)?;
44         let has_display = self.has_display();
45         for variant in &self.variants {
46             variant.validate()?;
47             if has_display && variant.attrs.display.is_none() && variant.attrs.transparent.is_none()
48             {
49                 return Err(Error::new_spanned(
50                     variant.original,
51                     "missing #[error(\"...\")] display attribute",
52                 ));
53             }
54         }
55         let mut from_types = Set::new();
56         for variant in &self.variants {
57             if let Some(from_field) = variant.from_field() {
58                 let repr = from_field.ty.to_token_stream().to_string();
59                 if !from_types.insert(repr) {
60                     return Err(Error::new_spanned(
61                         from_field.original,
62                         "cannot derive From because another variant has the same source type",
63                     ));
64                 }
65             }
66         }
67         Ok(())
68     }
69 }
70 
71 impl Variant<'_> {
validate(&self) -> Result<()>72     fn validate(&self) -> Result<()> {
73         check_non_field_attrs(&self.attrs)?;
74         if self.attrs.transparent.is_some() {
75             if self.fields.len() != 1 {
76                 return Err(Error::new_spanned(
77                     self.original,
78                     "#[error(transparent)] requires exactly one field",
79                 ));
80             }
81             if let Some(source) = self.fields.iter().find_map(|f| f.attrs.source) {
82                 return Err(Error::new_spanned(
83                     source,
84                     "transparent variant can't contain #[source]",
85                 ));
86             }
87         }
88         check_field_attrs(&self.fields)?;
89         for field in &self.fields {
90             field.validate()?;
91         }
92         Ok(())
93     }
94 }
95 
96 impl Field<'_> {
validate(&self) -> Result<()>97     fn validate(&self) -> Result<()> {
98         if let Some(display) = &self.attrs.display {
99             return Err(Error::new_spanned(
100                 display.original,
101                 "not expected here; the #[error(...)] attribute belongs on top of a struct or an enum variant",
102             ));
103         }
104         Ok(())
105     }
106 }
107 
check_non_field_attrs(attrs: &Attrs) -> Result<()>108 fn check_non_field_attrs(attrs: &Attrs) -> Result<()> {
109     if let Some(from) = &attrs.from {
110         return Err(Error::new_spanned(
111             from,
112             "not expected here; the #[from] attribute belongs on a specific field",
113         ));
114     }
115     if let Some(source) = &attrs.source {
116         return Err(Error::new_spanned(
117             source,
118             "not expected here; the #[source] attribute belongs on a specific field",
119         ));
120     }
121     if let Some(backtrace) = &attrs.backtrace {
122         return Err(Error::new_spanned(
123             backtrace,
124             "not expected here; the #[backtrace] attribute belongs on a specific field",
125         ));
126     }
127     if let Some(display) = &attrs.display {
128         if attrs.transparent.is_some() {
129             return Err(Error::new_spanned(
130                 display.original,
131                 "cannot have both #[error(transparent)] and a display attribute",
132             ));
133         }
134     }
135     Ok(())
136 }
137 
check_field_attrs(fields: &[Field]) -> Result<()>138 fn check_field_attrs(fields: &[Field]) -> Result<()> {
139     let mut from_field = None;
140     let mut source_field = None;
141     let mut backtrace_field = None;
142     let mut has_backtrace = false;
143     for field in fields {
144         if let Some(from) = field.attrs.from {
145             if from_field.is_some() {
146                 return Err(Error::new_spanned(from, "duplicate #[from] attribute"));
147             }
148             from_field = Some(field);
149         }
150         if let Some(source) = field.attrs.source {
151             if source_field.is_some() {
152                 return Err(Error::new_spanned(source, "duplicate #[source] attribute"));
153             }
154             source_field = Some(field);
155         }
156         if let Some(backtrace) = field.attrs.backtrace {
157             if backtrace_field.is_some() {
158                 return Err(Error::new_spanned(
159                     backtrace,
160                     "duplicate #[backtrace] attribute",
161                 ));
162             }
163             backtrace_field = Some(field);
164             has_backtrace = true;
165         }
166         if let Some(transparent) = field.attrs.transparent {
167             return Err(Error::new_spanned(
168                 transparent.original,
169                 "#[error(transparent)] needs to go outside the enum or struct, not on an individual field",
170             ));
171         }
172         has_backtrace |= field.is_backtrace();
173     }
174     if let (Some(from_field), Some(source_field)) = (from_field, source_field) {
175         if !same_member(from_field, source_field) {
176             return Err(Error::new_spanned(
177                 from_field.attrs.from,
178                 "#[from] is only supported on the source field, not any other field",
179             ));
180         }
181     }
182     if let Some(from_field) = from_field {
183         if fields.len() > 1 + has_backtrace as usize {
184             return Err(Error::new_spanned(
185                 from_field.attrs.from,
186                 "deriving From requires no fields other than source and backtrace",
187             ));
188         }
189     }
190     if let Some(source_field) = source_field.or(from_field) {
191         if contains_non_static_lifetime(&source_field.ty) {
192             return Err(Error::new_spanned(
193                 &source_field.original.ty,
194                 "non-static lifetimes are not allowed in the source of an error, because std::error::Error requires the source is dyn Error + 'static",
195             ));
196         }
197     }
198     Ok(())
199 }
200 
same_member(one: &Field, two: &Field) -> bool201 fn same_member(one: &Field, two: &Field) -> bool {
202     match (&one.member, &two.member) {
203         (Member::Named(one), Member::Named(two)) => one == two,
204         (Member::Unnamed(one), Member::Unnamed(two)) => one.index == two.index,
205         _ => unreachable!(),
206     }
207 }
208 
contains_non_static_lifetime(ty: &Type) -> bool209 fn contains_non_static_lifetime(ty: &Type) -> bool {
210     match ty {
211         Type::Path(ty) => {
212             let bracketed = match &ty.path.segments.last().unwrap().arguments {
213                 PathArguments::AngleBracketed(bracketed) => bracketed,
214                 _ => return false,
215             };
216             for arg in &bracketed.args {
217                 match arg {
218                     GenericArgument::Type(ty) if contains_non_static_lifetime(ty) => return true,
219                     GenericArgument::Lifetime(lifetime) if lifetime.ident != "static" => {
220                         return true
221                     }
222                     _ => {}
223                 }
224             }
225             false
226         }
227         Type::Reference(ty) => ty
228             .lifetime
229             .as_ref()
230             .map_or(false, |lifetime| lifetime.ident != "static"),
231         _ => false, // maybe implement later if there are common other cases
232     }
233 }
234