1 use std::collections::HashSet;
2 
3 use syn;
4 use syn::punctuated::{Pair, Punctuated};
5 use syn::visit::{self, Visit};
6 
7 use internals::ast::{Container, Data};
8 use internals::attr;
9 
10 use proc_macro2::Span;
11 
12 // Remove the default from every type parameter because in the generated impls
13 // they look like associated types: "error: associated type bindings are not
14 // allowed here".
without_defaults(generics: &syn::Generics) -> syn::Generics15 pub fn without_defaults(generics: &syn::Generics) -> syn::Generics {
16     syn::Generics {
17         params: generics
18             .params
19             .iter()
20             .map(|param| match param {
21                 syn::GenericParam::Type(param) => syn::GenericParam::Type(syn::TypeParam {
22                     eq_token: None,
23                     default: None,
24                     ..param.clone()
25                 }),
26                 _ => param.clone(),
27             })
28             .collect(),
29         ..generics.clone()
30     }
31 }
32 
with_where_predicates( generics: &syn::Generics, predicates: &[syn::WherePredicate], ) -> syn::Generics33 pub fn with_where_predicates(
34     generics: &syn::Generics,
35     predicates: &[syn::WherePredicate],
36 ) -> syn::Generics {
37     let mut generics = generics.clone();
38     generics
39         .make_where_clause()
40         .predicates
41         .extend(predicates.iter().cloned());
42     generics
43 }
44 
with_where_predicates_from_fields( cont: &Container, generics: &syn::Generics, from_field: fn(&attr::Field) -> Option<&[syn::WherePredicate]>, ) -> syn::Generics45 pub fn with_where_predicates_from_fields(
46     cont: &Container,
47     generics: &syn::Generics,
48     from_field: fn(&attr::Field) -> Option<&[syn::WherePredicate]>,
49 ) -> syn::Generics {
50     let predicates = cont
51         .data
52         .all_fields()
53         .flat_map(|field| from_field(&field.attrs))
54         .flat_map(|predicates| predicates.to_vec());
55 
56     let mut generics = generics.clone();
57     generics.make_where_clause().predicates.extend(predicates);
58     generics
59 }
60 
with_where_predicates_from_variants( cont: &Container, generics: &syn::Generics, from_variant: fn(&attr::Variant) -> Option<&[syn::WherePredicate]>, ) -> syn::Generics61 pub fn with_where_predicates_from_variants(
62     cont: &Container,
63     generics: &syn::Generics,
64     from_variant: fn(&attr::Variant) -> Option<&[syn::WherePredicate]>,
65 ) -> syn::Generics {
66     let variants = match &cont.data {
67         Data::Enum(variants) => variants,
68         Data::Struct(_, _) => {
69             return generics.clone();
70         }
71     };
72 
73     let predicates = variants
74         .iter()
75         .flat_map(|variant| from_variant(&variant.attrs))
76         .flat_map(|predicates| predicates.to_vec());
77 
78     let mut generics = generics.clone();
79     generics.make_where_clause().predicates.extend(predicates);
80     generics
81 }
82 
83 // Puts the given bound on any generic type parameters that are used in fields
84 // for which filter returns true.
85 //
86 // For example, the following struct needs the bound `A: Serialize, B:
87 // Serialize`.
88 //
89 //     struct S<'b, A, B: 'b, C> {
90 //         a: A,
91 //         b: Option<&'b B>
92 //         #[serde(skip_serializing)]
93 //         c: C,
94 //     }
with_bound( cont: &Container, generics: &syn::Generics, filter: fn(&attr::Field, Option<&attr::Variant>) -> bool, bound: &syn::Path, ) -> syn::Generics95 pub fn with_bound(
96     cont: &Container,
97     generics: &syn::Generics,
98     filter: fn(&attr::Field, Option<&attr::Variant>) -> bool,
99     bound: &syn::Path,
100 ) -> syn::Generics {
101     struct FindTyParams<'ast> {
102         // Set of all generic type parameters on the current struct (A, B, C in
103         // the example). Initialized up front.
104         all_type_params: HashSet<syn::Ident>,
105 
106         // Set of generic type parameters used in fields for which filter
107         // returns true (A and B in the example). Filled in as the visitor sees
108         // them.
109         relevant_type_params: HashSet<syn::Ident>,
110 
111         // Fields whose type is an associated type of one of the generic type
112         // parameters.
113         associated_type_usage: Vec<&'ast syn::TypePath>,
114     }
115     impl<'ast> Visit<'ast> for FindTyParams<'ast> {
116         fn visit_field(&mut self, field: &'ast syn::Field) {
117             if let syn::Type::Path(ty) = &field.ty {
118                 if let Some(Pair::Punctuated(t, _)) = ty.path.segments.pairs().next() {
119                     if self.all_type_params.contains(&t.ident) {
120                         self.associated_type_usage.push(ty);
121                     }
122                 }
123             }
124             self.visit_type(&field.ty);
125         }
126 
127         fn visit_path(&mut self, path: &'ast syn::Path) {
128             if let Some(seg) = path.segments.last() {
129                 if seg.ident == "PhantomData" {
130                     // Hardcoded exception, because PhantomData<T> implements
131                     // Serialize and Deserialize whether or not T implements it.
132                     return;
133                 }
134             }
135             if path.leading_colon.is_none() && path.segments.len() == 1 {
136                 let id = &path.segments[0].ident;
137                 if self.all_type_params.contains(id) {
138                     self.relevant_type_params.insert(id.clone());
139                 }
140             }
141             visit::visit_path(self, path);
142         }
143 
144         // Type parameter should not be considered used by a macro path.
145         //
146         //     struct TypeMacro<T> {
147         //         mac: T!(),
148         //         marker: PhantomData<T>,
149         //     }
150         fn visit_macro(&mut self, _mac: &'ast syn::Macro) {}
151     }
152 
153     let all_type_params = generics
154         .type_params()
155         .map(|param| param.ident.clone())
156         .collect();
157 
158     let mut visitor = FindTyParams {
159         all_type_params,
160         relevant_type_params: HashSet::new(),
161         associated_type_usage: Vec::new(),
162     };
163     match &cont.data {
164         Data::Enum(variants) => {
165             for variant in variants.iter() {
166                 let relevant_fields = variant
167                     .fields
168                     .iter()
169                     .filter(|field| filter(&field.attrs, Some(&variant.attrs)));
170                 for field in relevant_fields {
171                     visitor.visit_field(field.original);
172                 }
173             }
174         }
175         Data::Struct(_, fields) => {
176             for field in fields.iter().filter(|field| filter(&field.attrs, None)) {
177                 visitor.visit_field(field.original);
178             }
179         }
180     }
181 
182     let relevant_type_params = visitor.relevant_type_params;
183     let associated_type_usage = visitor.associated_type_usage;
184     let new_predicates = generics
185         .type_params()
186         .map(|param| param.ident.clone())
187         .filter(|id| relevant_type_params.contains(id))
188         .map(|id| syn::TypePath {
189             qself: None,
190             path: id.into(),
191         })
192         .chain(associated_type_usage.into_iter().cloned())
193         .map(|bounded_ty| {
194             syn::WherePredicate::Type(syn::PredicateType {
195                 lifetimes: None,
196                 // the type parameter that is being bounded e.g. T
197                 bounded_ty: syn::Type::Path(bounded_ty),
198                 colon_token: <Token![:]>::default(),
199                 // the bound e.g. Serialize
200                 bounds: vec![syn::TypeParamBound::Trait(syn::TraitBound {
201                     paren_token: None,
202                     modifier: syn::TraitBoundModifier::None,
203                     lifetimes: None,
204                     path: bound.clone(),
205                 })]
206                 .into_iter()
207                 .collect(),
208             })
209         });
210 
211     let mut generics = generics.clone();
212     generics
213         .make_where_clause()
214         .predicates
215         .extend(new_predicates);
216     generics
217 }
218 
with_self_bound( cont: &Container, generics: &syn::Generics, bound: &syn::Path, ) -> syn::Generics219 pub fn with_self_bound(
220     cont: &Container,
221     generics: &syn::Generics,
222     bound: &syn::Path,
223 ) -> syn::Generics {
224     let mut generics = generics.clone();
225     generics
226         .make_where_clause()
227         .predicates
228         .push(syn::WherePredicate::Type(syn::PredicateType {
229             lifetimes: None,
230             // the type that is being bounded e.g. MyStruct<'a, T>
231             bounded_ty: type_of_item(cont),
232             colon_token: <Token![:]>::default(),
233             // the bound e.g. Default
234             bounds: vec![syn::TypeParamBound::Trait(syn::TraitBound {
235                 paren_token: None,
236                 modifier: syn::TraitBoundModifier::None,
237                 lifetimes: None,
238                 path: bound.clone(),
239             })]
240             .into_iter()
241             .collect(),
242         }));
243     generics
244 }
245 
with_lifetime_bound(generics: &syn::Generics, lifetime: &str) -> syn::Generics246 pub fn with_lifetime_bound(generics: &syn::Generics, lifetime: &str) -> syn::Generics {
247     let bound = syn::Lifetime::new(lifetime, Span::call_site());
248     let def = syn::LifetimeDef {
249         attrs: Vec::new(),
250         lifetime: bound.clone(),
251         colon_token: None,
252         bounds: Punctuated::new(),
253     };
254 
255     let params = Some(syn::GenericParam::Lifetime(def))
256         .into_iter()
257         .chain(generics.params.iter().cloned().map(|mut param| {
258             match &mut param {
259                 syn::GenericParam::Lifetime(param) => {
260                     param.bounds.push(bound.clone());
261                 }
262                 syn::GenericParam::Type(param) => {
263                     param
264                         .bounds
265                         .push(syn::TypeParamBound::Lifetime(bound.clone()));
266                 }
267                 syn::GenericParam::Const(_) => {}
268             }
269             param
270         }))
271         .collect();
272 
273     syn::Generics {
274         params,
275         ..generics.clone()
276     }
277 }
278 
type_of_item(cont: &Container) -> syn::Type279 fn type_of_item(cont: &Container) -> syn::Type {
280     syn::Type::Path(syn::TypePath {
281         qself: None,
282         path: syn::Path {
283             leading_colon: None,
284             segments: vec![syn::PathSegment {
285                 ident: cont.ident.clone(),
286                 arguments: syn::PathArguments::AngleBracketed(
287                     syn::AngleBracketedGenericArguments {
288                         colon2_token: None,
289                         lt_token: <Token![<]>::default(),
290                         args: cont
291                             .generics
292                             .params
293                             .iter()
294                             .map(|param| match param {
295                                 syn::GenericParam::Type(param) => {
296                                     syn::GenericArgument::Type(syn::Type::Path(syn::TypePath {
297                                         qself: None,
298                                         path: param.ident.clone().into(),
299                                     }))
300                                 }
301                                 syn::GenericParam::Lifetime(param) => {
302                                     syn::GenericArgument::Lifetime(param.lifetime.clone())
303                                 }
304                                 syn::GenericParam::Const(_) => {
305                                     panic!("Serde does not support const generics yet");
306                                 }
307                             })
308                             .collect(),
309                         gt_token: <Token![>]>::default(),
310                     },
311                 ),
312             }]
313             .into_iter()
314             .collect(),
315         },
316     })
317 }
318