1 use std::collections::HashSet;
2
3 use syn;
4 use syn::punctuated::{Pair, Punctuated};
5 use syn::visit::{self, Visit};
6
7 use internals::ast::{Container, Data};
8 use internals::attr;
9
10 use proc_macro2::Span;
11
12 // Remove the default from every type parameter because in the generated impls
13 // they look like associated types: "error: associated type bindings are not
14 // allowed here".
without_defaults(generics: &syn::Generics) -> syn::Generics15 pub fn without_defaults(generics: &syn::Generics) -> syn::Generics {
16 syn::Generics {
17 params: generics
18 .params
19 .iter()
20 .map(|param| match param {
21 syn::GenericParam::Type(param) => syn::GenericParam::Type(syn::TypeParam {
22 eq_token: None,
23 default: None,
24 ..param.clone()
25 }),
26 _ => param.clone(),
27 })
28 .collect(),
29 ..generics.clone()
30 }
31 }
32
with_where_predicates( generics: &syn::Generics, predicates: &[syn::WherePredicate], ) -> syn::Generics33 pub fn with_where_predicates(
34 generics: &syn::Generics,
35 predicates: &[syn::WherePredicate],
36 ) -> syn::Generics {
37 let mut generics = generics.clone();
38 generics
39 .make_where_clause()
40 .predicates
41 .extend(predicates.iter().cloned());
42 generics
43 }
44
with_where_predicates_from_fields( cont: &Container, generics: &syn::Generics, from_field: fn(&attr::Field) -> Option<&[syn::WherePredicate]>, ) -> syn::Generics45 pub fn with_where_predicates_from_fields(
46 cont: &Container,
47 generics: &syn::Generics,
48 from_field: fn(&attr::Field) -> Option<&[syn::WherePredicate]>,
49 ) -> syn::Generics {
50 let predicates = cont
51 .data
52 .all_fields()
53 .flat_map(|field| from_field(&field.attrs))
54 .flat_map(|predicates| predicates.to_vec());
55
56 let mut generics = generics.clone();
57 generics.make_where_clause().predicates.extend(predicates);
58 generics
59 }
60
with_where_predicates_from_variants( cont: &Container, generics: &syn::Generics, from_variant: fn(&attr::Variant) -> Option<&[syn::WherePredicate]>, ) -> syn::Generics61 pub fn with_where_predicates_from_variants(
62 cont: &Container,
63 generics: &syn::Generics,
64 from_variant: fn(&attr::Variant) -> Option<&[syn::WherePredicate]>,
65 ) -> syn::Generics {
66 let variants = match &cont.data {
67 Data::Enum(variants) => variants,
68 Data::Struct(_, _) => {
69 return generics.clone();
70 }
71 };
72
73 let predicates = variants
74 .iter()
75 .flat_map(|variant| from_variant(&variant.attrs))
76 .flat_map(|predicates| predicates.to_vec());
77
78 let mut generics = generics.clone();
79 generics.make_where_clause().predicates.extend(predicates);
80 generics
81 }
82
83 // Puts the given bound on any generic type parameters that are used in fields
84 // for which filter returns true.
85 //
86 // For example, the following struct needs the bound `A: Serialize, B:
87 // Serialize`.
88 //
89 // struct S<'b, A, B: 'b, C> {
90 // a: A,
91 // b: Option<&'b B>
92 // #[serde(skip_serializing)]
93 // c: C,
94 // }
with_bound( cont: &Container, generics: &syn::Generics, filter: fn(&attr::Field, Option<&attr::Variant>) -> bool, bound: &syn::Path, ) -> syn::Generics95 pub fn with_bound(
96 cont: &Container,
97 generics: &syn::Generics,
98 filter: fn(&attr::Field, Option<&attr::Variant>) -> bool,
99 bound: &syn::Path,
100 ) -> syn::Generics {
101 struct FindTyParams<'ast> {
102 // Set of all generic type parameters on the current struct (A, B, C in
103 // the example). Initialized up front.
104 all_type_params: HashSet<syn::Ident>,
105
106 // Set of generic type parameters used in fields for which filter
107 // returns true (A and B in the example). Filled in as the visitor sees
108 // them.
109 relevant_type_params: HashSet<syn::Ident>,
110
111 // Fields whose type is an associated type of one of the generic type
112 // parameters.
113 associated_type_usage: Vec<&'ast syn::TypePath>,
114 }
115 impl<'ast> Visit<'ast> for FindTyParams<'ast> {
116 fn visit_field(&mut self, field: &'ast syn::Field) {
117 if let syn::Type::Path(ty) = &field.ty {
118 if let Some(Pair::Punctuated(t, _)) = ty.path.segments.pairs().next() {
119 if self.all_type_params.contains(&t.ident) {
120 self.associated_type_usage.push(ty);
121 }
122 }
123 }
124 self.visit_type(&field.ty);
125 }
126
127 fn visit_path(&mut self, path: &'ast syn::Path) {
128 if let Some(seg) = path.segments.last() {
129 if seg.ident == "PhantomData" {
130 // Hardcoded exception, because PhantomData<T> implements
131 // Serialize and Deserialize whether or not T implements it.
132 return;
133 }
134 }
135 if path.leading_colon.is_none() && path.segments.len() == 1 {
136 let id = &path.segments[0].ident;
137 if self.all_type_params.contains(id) {
138 self.relevant_type_params.insert(id.clone());
139 }
140 }
141 visit::visit_path(self, path);
142 }
143
144 // Type parameter should not be considered used by a macro path.
145 //
146 // struct TypeMacro<T> {
147 // mac: T!(),
148 // marker: PhantomData<T>,
149 // }
150 fn visit_macro(&mut self, _mac: &'ast syn::Macro) {}
151 }
152
153 let all_type_params = generics
154 .type_params()
155 .map(|param| param.ident.clone())
156 .collect();
157
158 let mut visitor = FindTyParams {
159 all_type_params,
160 relevant_type_params: HashSet::new(),
161 associated_type_usage: Vec::new(),
162 };
163 match &cont.data {
164 Data::Enum(variants) => {
165 for variant in variants.iter() {
166 let relevant_fields = variant
167 .fields
168 .iter()
169 .filter(|field| filter(&field.attrs, Some(&variant.attrs)));
170 for field in relevant_fields {
171 visitor.visit_field(field.original);
172 }
173 }
174 }
175 Data::Struct(_, fields) => {
176 for field in fields.iter().filter(|field| filter(&field.attrs, None)) {
177 visitor.visit_field(field.original);
178 }
179 }
180 }
181
182 let relevant_type_params = visitor.relevant_type_params;
183 let associated_type_usage = visitor.associated_type_usage;
184 let new_predicates = generics
185 .type_params()
186 .map(|param| param.ident.clone())
187 .filter(|id| relevant_type_params.contains(id))
188 .map(|id| syn::TypePath {
189 qself: None,
190 path: id.into(),
191 })
192 .chain(associated_type_usage.into_iter().cloned())
193 .map(|bounded_ty| {
194 syn::WherePredicate::Type(syn::PredicateType {
195 lifetimes: None,
196 // the type parameter that is being bounded e.g. T
197 bounded_ty: syn::Type::Path(bounded_ty),
198 colon_token: <Token![:]>::default(),
199 // the bound e.g. Serialize
200 bounds: vec![syn::TypeParamBound::Trait(syn::TraitBound {
201 paren_token: None,
202 modifier: syn::TraitBoundModifier::None,
203 lifetimes: None,
204 path: bound.clone(),
205 })]
206 .into_iter()
207 .collect(),
208 })
209 });
210
211 let mut generics = generics.clone();
212 generics
213 .make_where_clause()
214 .predicates
215 .extend(new_predicates);
216 generics
217 }
218
with_self_bound( cont: &Container, generics: &syn::Generics, bound: &syn::Path, ) -> syn::Generics219 pub fn with_self_bound(
220 cont: &Container,
221 generics: &syn::Generics,
222 bound: &syn::Path,
223 ) -> syn::Generics {
224 let mut generics = generics.clone();
225 generics
226 .make_where_clause()
227 .predicates
228 .push(syn::WherePredicate::Type(syn::PredicateType {
229 lifetimes: None,
230 // the type that is being bounded e.g. MyStruct<'a, T>
231 bounded_ty: type_of_item(cont),
232 colon_token: <Token![:]>::default(),
233 // the bound e.g. Default
234 bounds: vec![syn::TypeParamBound::Trait(syn::TraitBound {
235 paren_token: None,
236 modifier: syn::TraitBoundModifier::None,
237 lifetimes: None,
238 path: bound.clone(),
239 })]
240 .into_iter()
241 .collect(),
242 }));
243 generics
244 }
245
with_lifetime_bound(generics: &syn::Generics, lifetime: &str) -> syn::Generics246 pub fn with_lifetime_bound(generics: &syn::Generics, lifetime: &str) -> syn::Generics {
247 let bound = syn::Lifetime::new(lifetime, Span::call_site());
248 let def = syn::LifetimeDef {
249 attrs: Vec::new(),
250 lifetime: bound.clone(),
251 colon_token: None,
252 bounds: Punctuated::new(),
253 };
254
255 let params = Some(syn::GenericParam::Lifetime(def))
256 .into_iter()
257 .chain(generics.params.iter().cloned().map(|mut param| {
258 match &mut param {
259 syn::GenericParam::Lifetime(param) => {
260 param.bounds.push(bound.clone());
261 }
262 syn::GenericParam::Type(param) => {
263 param
264 .bounds
265 .push(syn::TypeParamBound::Lifetime(bound.clone()));
266 }
267 syn::GenericParam::Const(_) => {}
268 }
269 param
270 }))
271 .collect();
272
273 syn::Generics {
274 params,
275 ..generics.clone()
276 }
277 }
278
type_of_item(cont: &Container) -> syn::Type279 fn type_of_item(cont: &Container) -> syn::Type {
280 syn::Type::Path(syn::TypePath {
281 qself: None,
282 path: syn::Path {
283 leading_colon: None,
284 segments: vec![syn::PathSegment {
285 ident: cont.ident.clone(),
286 arguments: syn::PathArguments::AngleBracketed(
287 syn::AngleBracketedGenericArguments {
288 colon2_token: None,
289 lt_token: <Token![<]>::default(),
290 args: cont
291 .generics
292 .params
293 .iter()
294 .map(|param| match param {
295 syn::GenericParam::Type(param) => {
296 syn::GenericArgument::Type(syn::Type::Path(syn::TypePath {
297 qself: None,
298 path: param.ident.clone().into(),
299 }))
300 }
301 syn::GenericParam::Lifetime(param) => {
302 syn::GenericArgument::Lifetime(param.lifetime.clone())
303 }
304 syn::GenericParam::Const(_) => {
305 panic!("Serde does not support const generics yet");
306 }
307 })
308 .collect(),
309 gt_token: <Token![>]>::default(),
310 },
311 ),
312 }]
313 .into_iter()
314 .collect(),
315 },
316 })
317 }
318