1 /* This file incorporates work covered by the following copyright and
2 * permission notice:
3 * Copyright 2016 The serde Developers. See
4 * https://github.com/serde-rs/serde/blob/3f28a9324042950afa80354722aeeee1a55cbfa3/README.md#license.
5 *
6 * Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
7 * http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
8 * <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
9 * option. This file may not be copied, modified, or distributed
10 * except according to those terms.
11 */
12
13 use ast;
14 use attr;
15 use std::collections::HashSet;
16 use syn::{self, visit, GenericParam};
17
18 // use internals::ast::Item;
19 // use internals::attr;
20
21 /// Remove the default from every type parameter because in the generated `impl`s
22 /// they look like associated types: "error: associated type bindings are not
23 /// allowed here".
without_defaults(generics: &syn::Generics) -> syn::Generics24 pub fn without_defaults(generics: &syn::Generics) -> syn::Generics {
25 syn::Generics {
26 params: generics
27 .params
28 .iter()
29 .map(|generic_param| match *generic_param {
30 GenericParam::Type(ref ty_param) => syn::GenericParam::Type(syn::TypeParam {
31 default: None,
32 ..ty_param.clone()
33 }),
34 ref param => param.clone(),
35 })
36 .collect(),
37 ..generics.clone()
38 }
39 }
40
with_where_predicates( generics: &syn::Generics, predicates: &[syn::WherePredicate], ) -> syn::Generics41 pub fn with_where_predicates(
42 generics: &syn::Generics,
43 predicates: &[syn::WherePredicate],
44 ) -> syn::Generics {
45 let mut cloned = generics.clone();
46 cloned
47 .make_where_clause()
48 .predicates
49 .extend(predicates.iter().cloned());
50 cloned
51 }
52
with_where_predicates_from_fields<F>( item: &ast::Input, generics: &syn::Generics, from_field: F, ) -> syn::Generics where F: Fn(&attr::Field) -> Option<&[syn::WherePredicate]>,53 pub fn with_where_predicates_from_fields<F>(
54 item: &ast::Input,
55 generics: &syn::Generics,
56 from_field: F,
57 ) -> syn::Generics
58 where
59 F: Fn(&attr::Field) -> Option<&[syn::WherePredicate]>,
60 {
61 let mut cloned = generics.clone();
62 {
63 let fields = item.body.all_fields();
64 let field_where_predicates = fields
65 .iter()
66 .flat_map(|field| from_field(&field.attrs))
67 .flat_map(|predicates| predicates.to_vec());
68
69 cloned
70 .make_where_clause()
71 .predicates
72 .extend(field_where_predicates);
73 }
74 cloned
75 }
76
77 /// Puts the given bound on any generic type parameters that are used in fields
78 /// for which filter returns true.
79 ///
80 /// For example, the following structure needs the bound `A: Debug, B: Debug`.
81 ///
82 /// ```ignore
83 /// struct S<'b, A, B: 'b, C> {
84 /// a: A,
85 /// b: Option<&'b B>
86 /// #[derivative(Debug="ignore")]
87 /// c: C,
88 /// }
89 /// ```
with_bound<F>( item: &ast::Input, generics: &syn::Generics, filter: F, bound: &syn::Path, ) -> syn::Generics where F: Fn(&attr::Field) -> bool,90 pub fn with_bound<F>(
91 item: &ast::Input,
92 generics: &syn::Generics,
93 filter: F,
94 bound: &syn::Path,
95 ) -> syn::Generics
96 where
97 F: Fn(&attr::Field) -> bool,
98 {
99 #[derive(Debug)]
100 struct FindTyParams {
101 /// Set of all generic type parameters on the current struct (A, B, C in
102 /// the example). Initialized up front.
103 all_ty_params: HashSet<syn::Ident>,
104 /// Set of generic type parameters used in fields for which filter
105 /// returns true (A and B in the example). Filled in as the visitor sees
106 /// them.
107 relevant_ty_params: HashSet<syn::Ident>,
108 }
109 impl<'ast> visit::Visit<'ast> for FindTyParams {
110 fn visit_path(&mut self, path: &'ast syn::Path) {
111 if is_phantom_data(path) {
112 // Hardcoded exception, because `PhantomData<T>` implements
113 // most traits whether or not `T` implements it.
114 return;
115 }
116 if path.leading_colon.is_none() && path.segments.len() == 1 {
117 let id = &path.segments[0].ident;
118 if self.all_ty_params.contains(id) {
119 self.relevant_ty_params.insert(id.clone());
120 }
121 }
122 visit::visit_path(self, path);
123 }
124 }
125
126 let all_ty_params: HashSet<_> = generics
127 .type_params()
128 .map(|ty_param| ty_param.ident.clone())
129 .collect();
130
131 let relevant_tys = item
132 .body
133 .all_fields()
134 .into_iter()
135 .filter(|field| {
136 if let syn::Type::Path(syn::TypePath { ref path, .. }) = *field.ty {
137 !is_phantom_data(path)
138 } else {
139 true
140 }
141 })
142 .filter(|field| filter(&field.attrs))
143 .map(|field| &field.ty);
144
145 let mut visitor = FindTyParams {
146 all_ty_params: all_ty_params,
147 relevant_ty_params: HashSet::new(),
148 };
149 for ty in relevant_tys {
150 visit::visit_type(&mut visitor, ty);
151 }
152
153 let mut cloned = generics.clone();
154 {
155 let relevant_where_predicates = generics
156 .type_params()
157 .map(|ty_param| &ty_param.ident)
158 .filter(|id| visitor.relevant_ty_params.contains(id))
159 .map(|id| -> syn::WherePredicate { parse_quote!( #id : #bound ) });
160
161 cloned
162 .make_where_clause()
163 .predicates
164 .extend(relevant_where_predicates);
165 }
166 cloned
167 }
168
is_phantom_data(path: &syn::Path) -> bool169 fn is_phantom_data(path: &syn::Path) -> bool {
170 match path.segments.last() {
171 Some(syn::punctuated::Pair::End(seg)) if seg.ident == "PhantomData" => true,
172 _ => false,
173 }
174 }
175