1 use crate::ast::{Enum, Field, Input, Struct, Variant};
2 use crate::attr::Attrs;
3 use quote::ToTokens;
4 use std::collections::BTreeSet as Set;
5 use syn::{Error, GenericArgument, Member, PathArguments, Result, Type};
6
7 impl Input<'_> {
validate(&self) -> Result<()>8 pub(crate) fn validate(&self) -> Result<()> {
9 match self {
10 Input::Struct(input) => input.validate(),
11 Input::Enum(input) => input.validate(),
12 }
13 }
14 }
15
16 impl Struct<'_> {
validate(&self) -> Result<()>17 fn validate(&self) -> Result<()> {
18 check_non_field_attrs(&self.attrs)?;
19 if let Some(transparent) = self.attrs.transparent {
20 if self.fields.len() != 1 {
21 return Err(Error::new_spanned(
22 transparent.original,
23 "#[error(transparent)] requires exactly one field",
24 ));
25 }
26 if let Some(source) = self.fields.iter().find_map(|f| f.attrs.source) {
27 return Err(Error::new_spanned(
28 source,
29 "transparent error struct can't contain #[source]",
30 ));
31 }
32 }
33 check_field_attrs(&self.fields)?;
34 for field in &self.fields {
35 field.validate()?;
36 }
37 Ok(())
38 }
39 }
40
41 impl Enum<'_> {
validate(&self) -> Result<()>42 fn validate(&self) -> Result<()> {
43 check_non_field_attrs(&self.attrs)?;
44 let has_display = self.has_display();
45 for variant in &self.variants {
46 variant.validate()?;
47 if has_display && variant.attrs.display.is_none() && variant.attrs.transparent.is_none()
48 {
49 return Err(Error::new_spanned(
50 variant.original,
51 "missing #[error(\"...\")] display attribute",
52 ));
53 }
54 }
55 let mut from_types = Set::new();
56 for variant in &self.variants {
57 if let Some(from_field) = variant.from_field() {
58 let repr = from_field.ty.to_token_stream().to_string();
59 if !from_types.insert(repr) {
60 return Err(Error::new_spanned(
61 from_field.original,
62 "cannot derive From because another variant has the same source type",
63 ));
64 }
65 }
66 }
67 Ok(())
68 }
69 }
70
71 impl Variant<'_> {
validate(&self) -> Result<()>72 fn validate(&self) -> Result<()> {
73 check_non_field_attrs(&self.attrs)?;
74 if self.attrs.transparent.is_some() {
75 if self.fields.len() != 1 {
76 return Err(Error::new_spanned(
77 self.original,
78 "#[error(transparent)] requires exactly one field",
79 ));
80 }
81 if let Some(source) = self.fields.iter().find_map(|f| f.attrs.source) {
82 return Err(Error::new_spanned(
83 source,
84 "transparent variant can't contain #[source]",
85 ));
86 }
87 }
88 check_field_attrs(&self.fields)?;
89 for field in &self.fields {
90 field.validate()?;
91 }
92 Ok(())
93 }
94 }
95
96 impl Field<'_> {
validate(&self) -> Result<()>97 fn validate(&self) -> Result<()> {
98 if let Some(display) = &self.attrs.display {
99 return Err(Error::new_spanned(
100 display.original,
101 "not expected here; the #[error(...)] attribute belongs on top of a struct or an enum variant",
102 ));
103 }
104 Ok(())
105 }
106 }
107
check_non_field_attrs(attrs: &Attrs) -> Result<()>108 fn check_non_field_attrs(attrs: &Attrs) -> Result<()> {
109 if let Some(from) = &attrs.from {
110 return Err(Error::new_spanned(
111 from,
112 "not expected here; the #[from] attribute belongs on a specific field",
113 ));
114 }
115 if let Some(source) = &attrs.source {
116 return Err(Error::new_spanned(
117 source,
118 "not expected here; the #[source] attribute belongs on a specific field",
119 ));
120 }
121 if let Some(backtrace) = &attrs.backtrace {
122 return Err(Error::new_spanned(
123 backtrace,
124 "not expected here; the #[backtrace] attribute belongs on a specific field",
125 ));
126 }
127 if let Some(display) = &attrs.display {
128 if attrs.transparent.is_some() {
129 return Err(Error::new_spanned(
130 display.original,
131 "cannot have both #[error(transparent)] and a display attribute",
132 ));
133 }
134 }
135 Ok(())
136 }
137
check_field_attrs(fields: &[Field]) -> Result<()>138 fn check_field_attrs(fields: &[Field]) -> Result<()> {
139 let mut from_field = None;
140 let mut source_field = None;
141 let mut backtrace_field = None;
142 let mut has_backtrace = false;
143 for field in fields {
144 if let Some(from) = field.attrs.from {
145 if from_field.is_some() {
146 return Err(Error::new_spanned(from, "duplicate #[from] attribute"));
147 }
148 from_field = Some(field);
149 }
150 if let Some(source) = field.attrs.source {
151 if source_field.is_some() {
152 return Err(Error::new_spanned(source, "duplicate #[source] attribute"));
153 }
154 source_field = Some(field);
155 }
156 if let Some(backtrace) = field.attrs.backtrace {
157 if backtrace_field.is_some() {
158 return Err(Error::new_spanned(
159 backtrace,
160 "duplicate #[backtrace] attribute",
161 ));
162 }
163 backtrace_field = Some(field);
164 has_backtrace = true;
165 }
166 if let Some(transparent) = field.attrs.transparent {
167 return Err(Error::new_spanned(
168 transparent.original,
169 "#[error(transparent)] needs to go outside the enum or struct, not on an individual field",
170 ));
171 }
172 has_backtrace |= field.is_backtrace();
173 }
174 if let (Some(from_field), Some(source_field)) = (from_field, source_field) {
175 if !same_member(from_field, source_field) {
176 return Err(Error::new_spanned(
177 from_field.attrs.from,
178 "#[from] is only supported on the source field, not any other field",
179 ));
180 }
181 }
182 if let Some(from_field) = from_field {
183 if fields.len() > 1 + has_backtrace as usize {
184 return Err(Error::new_spanned(
185 from_field.attrs.from,
186 "deriving From requires no fields other than source and backtrace",
187 ));
188 }
189 }
190 if let Some(source_field) = source_field.or(from_field) {
191 if contains_non_static_lifetime(&source_field.ty) {
192 return Err(Error::new_spanned(
193 &source_field.original.ty,
194 "non-static lifetimes are not allowed in the source of an error, because std::error::Error requires the source is dyn Error + 'static",
195 ));
196 }
197 }
198 Ok(())
199 }
200
same_member(one: &Field, two: &Field) -> bool201 fn same_member(one: &Field, two: &Field) -> bool {
202 match (&one.member, &two.member) {
203 (Member::Named(one), Member::Named(two)) => one == two,
204 (Member::Unnamed(one), Member::Unnamed(two)) => one.index == two.index,
205 _ => unreachable!(),
206 }
207 }
208
contains_non_static_lifetime(ty: &Type) -> bool209 fn contains_non_static_lifetime(ty: &Type) -> bool {
210 match ty {
211 Type::Path(ty) => {
212 let bracketed = match &ty.path.segments.last().unwrap().arguments {
213 PathArguments::AngleBracketed(bracketed) => bracketed,
214 _ => return false,
215 };
216 for arg in &bracketed.args {
217 match arg {
218 GenericArgument::Type(ty) if contains_non_static_lifetime(ty) => return true,
219 GenericArgument::Lifetime(lifetime) if lifetime.ident != "static" => {
220 return true
221 }
222 _ => {}
223 }
224 }
225 false
226 }
227 Type::Reference(ty) => ty
228 .lifetime
229 .as_ref()
230 .map_or(false, |lifetime| lifetime.ident != "static"),
231 _ => false, // maybe implement later if there are common other cases
232 }
233 }
234