1 /* This file incorporates work covered by the following copyright and
2  * permission notice:
3  *   Copyright 2016 The serde Developers. See
4  *   https://github.com/serde-rs/serde/blob/3f28a9324042950afa80354722aeeee1a55cbfa3/README.md#license.
5  *
6  *   Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
7  *   http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
8  *   <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
9  *   option. This file may not be copied, modified, or distributed
10  *   except according to those terms.
11  */
12 
13 use ast;
14 use attr;
15 use std::collections::HashSet;
16 use syn::{self, visit, GenericParam};
17 
18 // use internals::ast::Item;
19 // use internals::attr;
20 
21 /// Remove the default from every type parameter because in the generated `impl`s
22 /// they look like associated types: "error: associated type bindings are not
23 /// allowed here".
without_defaults(generics: &syn::Generics) -> syn::Generics24 pub fn without_defaults(generics: &syn::Generics) -> syn::Generics {
25     syn::Generics {
26         params: generics
27             .params
28             .iter()
29             .map(|generic_param| match *generic_param {
30                 GenericParam::Type(ref ty_param) => syn::GenericParam::Type(syn::TypeParam {
31                     default: None,
32                     ..ty_param.clone()
33                 }),
34                 ref param => param.clone(),
35             })
36             .collect(),
37         ..generics.clone()
38     }
39 }
40 
with_where_predicates( generics: &syn::Generics, predicates: &[syn::WherePredicate], ) -> syn::Generics41 pub fn with_where_predicates(
42     generics: &syn::Generics,
43     predicates: &[syn::WherePredicate],
44 ) -> syn::Generics {
45     let mut cloned = generics.clone();
46     cloned
47         .make_where_clause()
48         .predicates
49         .extend(predicates.iter().cloned());
50     cloned
51 }
52 
with_where_predicates_from_fields<F>( item: &ast::Input, generics: &syn::Generics, from_field: F, ) -> syn::Generics where F: Fn(&attr::Field) -> Option<&[syn::WherePredicate]>,53 pub fn with_where_predicates_from_fields<F>(
54     item: &ast::Input,
55     generics: &syn::Generics,
56     from_field: F,
57 ) -> syn::Generics
58 where
59     F: Fn(&attr::Field) -> Option<&[syn::WherePredicate]>,
60 {
61     let mut cloned = generics.clone();
62     {
63         let fields = item.body.all_fields();
64         let field_where_predicates = fields
65             .iter()
66             .flat_map(|field| from_field(&field.attrs))
67             .flat_map(|predicates| predicates.to_vec());
68 
69         cloned
70             .make_where_clause()
71             .predicates
72             .extend(field_where_predicates);
73     }
74     cloned
75 }
76 
77 /// Puts the given bound on any generic type parameters that are used in fields
78 /// for which filter returns true.
79 ///
80 /// For example, the following structure needs the bound `A: Debug, B: Debug`.
81 ///
82 /// ```ignore
83 /// struct S<'b, A, B: 'b, C> {
84 ///     a: A,
85 ///     b: Option<&'b B>
86 ///     #[derivative(Debug="ignore")]
87 ///     c: C,
88 /// }
89 /// ```
with_bound<F>( item: &ast::Input, generics: &syn::Generics, filter: F, bound: &syn::Path, ) -> syn::Generics where F: Fn(&attr::Field) -> bool,90 pub fn with_bound<F>(
91     item: &ast::Input,
92     generics: &syn::Generics,
93     filter: F,
94     bound: &syn::Path,
95 ) -> syn::Generics
96 where
97     F: Fn(&attr::Field) -> bool,
98 {
99     #[derive(Debug)]
100     struct FindTyParams {
101         /// Set of all generic type parameters on the current struct (A, B, C in
102         /// the example). Initialized up front.
103         all_ty_params: HashSet<syn::Ident>,
104         /// Set of generic type parameters used in fields for which filter
105         /// returns true (A and B in the example). Filled in as the visitor sees
106         /// them.
107         relevant_ty_params: HashSet<syn::Ident>,
108     }
109     impl<'ast> visit::Visit<'ast> for FindTyParams {
110         fn visit_path(&mut self, path: &'ast syn::Path) {
111             if is_phantom_data(path) {
112                 // Hardcoded exception, because `PhantomData<T>` implements
113                 // most traits whether or not `T` implements it.
114                 return;
115             }
116             if path.leading_colon.is_none() && path.segments.len() == 1 {
117                 let id = &path.segments[0].ident;
118                 if self.all_ty_params.contains(id) {
119                     self.relevant_ty_params.insert(id.clone());
120                 }
121             }
122             visit::visit_path(self, path);
123         }
124     }
125 
126     let all_ty_params: HashSet<_> = generics
127         .type_params()
128         .map(|ty_param| ty_param.ident.clone())
129         .collect();
130 
131     let relevant_tys = item
132         .body
133         .all_fields()
134         .into_iter()
135         .filter(|field| {
136             if let syn::Type::Path(syn::TypePath { ref path, .. }) = *field.ty {
137                 !is_phantom_data(path)
138             } else {
139                 true
140             }
141         })
142         .filter(|field| filter(&field.attrs))
143         .map(|field| &field.ty);
144 
145     let mut visitor = FindTyParams {
146         all_ty_params: all_ty_params,
147         relevant_ty_params: HashSet::new(),
148     };
149     for ty in relevant_tys {
150         visit::visit_type(&mut visitor, ty);
151     }
152 
153     let mut cloned = generics.clone();
154     {
155         let relevant_where_predicates = generics
156             .type_params()
157             .map(|ty_param| &ty_param.ident)
158             .filter(|id| visitor.relevant_ty_params.contains(id))
159             .map(|id| -> syn::WherePredicate { parse_quote!( #id : #bound ) });
160 
161         cloned
162             .make_where_clause()
163             .predicates
164             .extend(relevant_where_predicates);
165     }
166     cloned
167 }
168 
is_phantom_data(path: &syn::Path) -> bool169 fn is_phantom_data(path: &syn::Path) -> bool {
170     match path.segments.last() {
171         Some(syn::punctuated::Pair::End(seg)) if seg.ident == "PhantomData" => true,
172         _ => false,
173     }
174 }
175