1 // https://github.com/rust-lang/rust/issues/13101
2 
3 use ast;
4 use attr;
5 use matcher;
6 use paths;
7 use proc_macro2;
8 use syn;
9 use utils;
10 
11 /// Derive `Eq` for `input`.
derive_eq(input: &ast::Input) -> proc_macro2::TokenStream12 pub fn derive_eq(input: &ast::Input) -> proc_macro2::TokenStream {
13     let name = &input.ident;
14 
15     let eq_trait_path = eq_trait_path();
16     let generics = utils::build_impl_generics(
17         input,
18         &eq_trait_path,
19         needs_eq_bound,
20         |field| field.eq_bound(),
21         |input| input.eq_bound(),
22     );
23     let new_where_clause;
24     let (impl_generics, ty_generics, mut where_clause) = generics.split_for_impl();
25 
26     if let Some(new_where_clause2) =
27         maybe_add_copy(input, where_clause, |f| !f.attrs.ignore_partial_eq())
28     {
29         new_where_clause = new_where_clause2;
30         where_clause = Some(&new_where_clause);
31     }
32 
33     quote! {
34         #[allow(unused_qualifications)]
35         impl #impl_generics #eq_trait_path for #name #ty_generics #where_clause {}
36     }
37 }
38 
39 /// Derive `PartialEq` for `input`.
derive_partial_eq(input: &ast::Input) -> proc_macro2::TokenStream40 pub fn derive_partial_eq(input: &ast::Input) -> proc_macro2::TokenStream {
41     let discriminant_cmp = if let ast::Body::Enum(_) = input.body {
42         let discriminant_path = paths::discriminant_path();
43 
44         quote!((#discriminant_path(&*self) == #discriminant_path(&*other)))
45     } else {
46         quote!(true)
47     };
48 
49     let name = &input.ident;
50 
51     let partial_eq_trait_path = partial_eq_trait_path();
52     let generics = utils::build_impl_generics(
53         input,
54         &partial_eq_trait_path,
55         needs_partial_eq_bound,
56         |field| field.partial_eq_bound(),
57         |input| input.partial_eq_bound(),
58     );
59     let new_where_clause;
60     let (impl_generics, ty_generics, mut where_clause) = generics.split_for_impl();
61 
62     let match_fields = if input.is_trivial_enum() {
63         quote!(true)
64     } else {
65         matcher::Matcher::new(matcher::BindingStyle::Ref, input.attrs.is_packed)
66             .with_field_filter(|f: &ast::Field| !f.attrs.ignore_partial_eq())
67             .build_2_arms(
68                 (quote!(*self), quote!(*other)),
69                 (input, "__self"),
70                 (input, "__other"),
71                 |_, _, _, (left_variant, right_variant)| {
72                     let cmp = left_variant.iter().zip(&right_variant).map(|(o, i)| {
73                         let outer_name = &o.expr;
74                         let inner_name = &i.expr;
75 
76                         if o.field.attrs.ignore_partial_eq() {
77                             None
78                         } else if let Some(compare_fn) = o.field.attrs.partial_eq_compare_with() {
79                             Some(quote!(&& #compare_fn(&#outer_name, &#inner_name)))
80                         } else {
81                             Some(quote!(&& &#outer_name == &#inner_name))
82                         }
83                     });
84 
85                     quote!(true #(#cmp)*)
86                 },
87             )
88     };
89 
90     if let Some(new_where_clause2) =
91         maybe_add_copy(input, where_clause, |f| !f.attrs.ignore_partial_eq())
92     {
93         new_where_clause = new_where_clause2;
94         where_clause = Some(&new_where_clause);
95     }
96 
97     quote! {
98         #[allow(unused_qualifications)]
99         #[allow(clippy::unneeded_field_pattern)]
100         impl #impl_generics #partial_eq_trait_path for #name #ty_generics #where_clause {
101             fn eq(&self, other: &Self) -> bool {
102                 #discriminant_cmp && #match_fields
103             }
104         }
105     }
106 }
107 
108 /// Derive `PartialOrd` for `input`.
derive_partial_ord( input: &ast::Input, errors: &mut proc_macro2::TokenStream, ) -> proc_macro2::TokenStream109 pub fn derive_partial_ord(
110     input: &ast::Input,
111     errors: &mut proc_macro2::TokenStream,
112 ) -> proc_macro2::TokenStream {
113     if let ast::Body::Enum(_) = input.body {
114         if !input.attrs.partial_ord_on_enum() {
115             let message = "can't use `#[derivative(PartialOrd)]` on an enumeration without \
116             `feature_allow_slow_enum`; see the documentation for more details";
117             errors.extend(syn::Error::new(input.span, message).to_compile_error());
118         }
119     }
120 
121     let option_path = option_path();
122     let ordering_path = ordering_path();
123 
124     let body = matcher::Matcher::new(matcher::BindingStyle::Ref, input.attrs.is_packed)
125         .with_field_filter(|f: &ast::Field| !f.attrs.ignore_partial_ord())
126         .build_arms(input, "__self", |_, n, _, _, _, outer_bis| {
127             let body = matcher::Matcher::new(matcher::BindingStyle::Ref, input.attrs.is_packed)
128                 .with_field_filter(|f: &ast::Field| !f.attrs.ignore_partial_ord())
129                 .build_arms(input, "__other", |_, m, _, _, _, inner_bis| {
130                     match n.cmp(&m) {
131                         ::std::cmp::Ordering::Less => {
132                             quote!(#option_path::Some(#ordering_path::Less))
133                         }
134                         ::std::cmp::Ordering::Greater => {
135                             quote!(#option_path::Some(#ordering_path::Greater))
136                         }
137                         ::std::cmp::Ordering::Equal => {
138                             let equal_path = quote!(#ordering_path::Equal);
139                             outer_bis
140                                 .iter()
141                                 .rev()
142                                 .zip(inner_bis.into_iter().rev())
143                                 .fold(quote!(#option_path::Some(#equal_path)), |acc, (o, i)| {
144                                     let outer_name = &o.expr;
145                                     let inner_name = &i.expr;
146 
147                                     if o.field.attrs.ignore_partial_ord() {
148                                         acc
149                                     } else {
150                                         let cmp_fn = o
151                                             .field
152                                             .attrs
153                                             .partial_ord_compare_with()
154                                             .map(|f| quote!(#f))
155                                             .unwrap_or_else(|| {
156                                                 let path = partial_ord_trait_path();
157                                                 quote!(#path::partial_cmp)
158                                             });
159 
160                                         quote!(match #cmp_fn(&#outer_name, &#inner_name) {
161                                             #option_path::Some(#equal_path) => #acc,
162                                             __derive_ordering_other => __derive_ordering_other,
163                                         })
164                                     }
165                                 })
166                         }
167                     }
168                 });
169 
170             quote! {
171                 match *other {
172                     #body
173                 }
174 
175             }
176         });
177 
178     let name = &input.ident;
179 
180     let partial_ord_trait_path = partial_ord_trait_path();
181     let generics = utils::build_impl_generics(
182         input,
183         &partial_ord_trait_path,
184         needs_partial_ord_bound,
185         |field| field.partial_ord_bound(),
186         |input| input.partial_ord_bound(),
187     );
188     let new_where_clause;
189     let (impl_generics, ty_generics, mut where_clause) = generics.split_for_impl();
190 
191     if let Some(new_where_clause2) =
192         maybe_add_copy(input, where_clause, |f| !f.attrs.ignore_partial_ord())
193     {
194         new_where_clause = new_where_clause2;
195         where_clause = Some(&new_where_clause);
196     }
197 
198     quote! {
199         #[allow(unused_qualifications)]
200         #[allow(clippy::unneeded_field_pattern)]
201         impl #impl_generics #partial_ord_trait_path for #name #ty_generics #where_clause {
202             fn partial_cmp(&self, other: &Self) -> #option_path<#ordering_path> {
203                 match *self {
204                     #body
205                 }
206             }
207         }
208     }
209 }
210 
211 /// Derive `Ord` for `input`.
derive_ord( input: &ast::Input, errors: &mut proc_macro2::TokenStream, ) -> proc_macro2::TokenStream212 pub fn derive_ord(
213     input: &ast::Input,
214     errors: &mut proc_macro2::TokenStream,
215 ) -> proc_macro2::TokenStream {
216     if let ast::Body::Enum(_) = input.body {
217         if !input.attrs.ord_on_enum() {
218             let message = "can't use `#[derivative(Ord)]` on an enumeration without \
219             `feature_allow_slow_enum`; see the documentation for more details";
220             errors.extend(syn::Error::new(input.span, message).to_compile_error());
221         }
222     }
223 
224     let ordering_path = ordering_path();
225 
226     let body = matcher::Matcher::new(matcher::BindingStyle::Ref, input.attrs.is_packed)
227         .with_field_filter(|f: &ast::Field| !f.attrs.ignore_ord())
228         .build_arms(input, "__self", |_, n, _, _, _, outer_bis| {
229             let body = matcher::Matcher::new(matcher::BindingStyle::Ref, input.attrs.is_packed)
230                 .with_field_filter(|f: &ast::Field| !f.attrs.ignore_ord())
231                 .build_arms(input, "__other", |_, m, _, _, _, inner_bis| {
232                     match n.cmp(&m) {
233                         ::std::cmp::Ordering::Less => quote!(#ordering_path::Less),
234                         ::std::cmp::Ordering::Greater => quote!(#ordering_path::Greater),
235                         ::std::cmp::Ordering::Equal => {
236                             let equal_path = quote!(#ordering_path::Equal);
237                             outer_bis
238                                 .iter()
239                                 .rev()
240                                 .zip(inner_bis.into_iter().rev())
241                                 .fold(quote!(#equal_path), |acc, (o, i)| {
242                                     let outer_name = &o.expr;
243                                     let inner_name = &i.expr;
244 
245                                     if o.field.attrs.ignore_ord() {
246                                         acc
247                                     } else {
248                                         let cmp_fn = o
249                                             .field
250                                             .attrs
251                                             .ord_compare_with()
252                                             .map(|f| quote!(#f))
253                                             .unwrap_or_else(|| {
254                                                 let path = ord_trait_path();
255                                                 quote!(#path::cmp)
256                                             });
257 
258                                         quote!(match #cmp_fn(&#outer_name, &#inner_name) {
259                                            #equal_path => #acc,
260                                             __derive_ordering_other => __derive_ordering_other,
261                                         })
262                                     }
263                                 })
264                         }
265                     }
266                 });
267 
268             quote! {
269                 match *other {
270                     #body
271                 }
272 
273             }
274         });
275 
276     let name = &input.ident;
277 
278     let ord_trait_path = ord_trait_path();
279     let generics = utils::build_impl_generics(
280         input,
281         &ord_trait_path,
282         needs_ord_bound,
283         |field| field.ord_bound(),
284         |input| input.ord_bound(),
285     );
286     let new_where_clause;
287     let (impl_generics, ty_generics, mut where_clause) = generics.split_for_impl();
288 
289     if let Some(new_where_clause2) = maybe_add_copy(input, where_clause, |f| !f.attrs.ignore_ord())
290     {
291         new_where_clause = new_where_clause2;
292         where_clause = Some(&new_where_clause);
293     }
294 
295     quote! {
296         #[allow(unused_qualifications)]
297         #[allow(clippy::unneeded_field_pattern)]
298         impl #impl_generics #ord_trait_path for #name #ty_generics #where_clause {
299             fn cmp(&self, other: &Self) -> #ordering_path {
300                 match *self {
301                     #body
302                 }
303             }
304         }
305     }
306 }
307 
needs_partial_eq_bound(attrs: &attr::Field) -> bool308 fn needs_partial_eq_bound(attrs: &attr::Field) -> bool {
309     !attrs.ignore_partial_eq() && attrs.partial_eq_bound().is_none()
310 }
311 
needs_partial_ord_bound(attrs: &attr::Field) -> bool312 fn needs_partial_ord_bound(attrs: &attr::Field) -> bool {
313     !attrs.ignore_partial_ord() && attrs.partial_ord_bound().is_none()
314 }
315 
needs_ord_bound(attrs: &attr::Field) -> bool316 fn needs_ord_bound(attrs: &attr::Field) -> bool {
317     !attrs.ignore_ord() && attrs.ord_bound().is_none()
318 }
319 
needs_eq_bound(attrs: &attr::Field) -> bool320 fn needs_eq_bound(attrs: &attr::Field) -> bool {
321     !attrs.ignore_partial_eq() && attrs.eq_bound().is_none()
322 }
323 
324 /// Return the path of the `Eq` trait, that is `::std::cmp::Eq`.
eq_trait_path() -> syn::Path325 fn eq_trait_path() -> syn::Path {
326     if cfg!(feature = "use_core") {
327         parse_quote!(::core::cmp::Eq)
328     } else {
329         parse_quote!(::std::cmp::Eq)
330     }
331 }
332 
333 /// Return the path of the `PartialEq` trait, that is `::std::cmp::PartialEq`.
partial_eq_trait_path() -> syn::Path334 fn partial_eq_trait_path() -> syn::Path {
335     if cfg!(feature = "use_core") {
336         parse_quote!(::core::cmp::PartialEq)
337     } else {
338         parse_quote!(::std::cmp::PartialEq)
339     }
340 }
341 
342 /// Return the path of the `PartialOrd` trait, that is `::std::cmp::PartialOrd`.
partial_ord_trait_path() -> syn::Path343 fn partial_ord_trait_path() -> syn::Path {
344     if cfg!(feature = "use_core") {
345         parse_quote!(::core::cmp::PartialOrd)
346     } else {
347         parse_quote!(::std::cmp::PartialOrd)
348     }
349 }
350 
351 /// Return the path of the `Ord` trait, that is `::std::cmp::Ord`.
ord_trait_path() -> syn::Path352 fn ord_trait_path() -> syn::Path {
353     if cfg!(feature = "use_core") {
354         parse_quote!(::core::cmp::Ord)
355     } else {
356         parse_quote!(::std::cmp::Ord)
357     }
358 }
359 
360 /// Return the path of the `Option` trait, that is `::std::option::Option`.
option_path() -> syn::Path361 fn option_path() -> syn::Path {
362     if cfg!(feature = "use_core") {
363         parse_quote!(::core::option::Option)
364     } else {
365         parse_quote!(::std::option::Option)
366     }
367 }
368 
369 /// Return the path of the `Ordering` trait, that is `::std::cmp::Ordering`.
ordering_path() -> syn::Path370 fn ordering_path() -> syn::Path {
371     if cfg!(feature = "use_core") {
372         parse_quote!(::core::cmp::Ordering)
373     } else {
374         parse_quote!(::std::cmp::Ordering)
375     }
376 }
377 
maybe_add_copy( input: &ast::Input, where_clause: Option<&syn::WhereClause>, field_filter: impl Fn(&ast::Field) -> bool, ) -> Option<syn::WhereClause>378 fn maybe_add_copy(
379     input: &ast::Input,
380     where_clause: Option<&syn::WhereClause>,
381     field_filter: impl Fn(&ast::Field) -> bool,
382 ) -> Option<syn::WhereClause> {
383     if input.attrs.is_packed && !input.body.is_empty() {
384         let mut new_where_clause = where_clause.cloned().unwrap_or_else(|| syn::WhereClause {
385             where_token: parse_quote!(where),
386             predicates: Default::default(),
387         });
388 
389         new_where_clause.predicates.extend(
390             input
391                 .body
392                 .all_fields()
393                 .into_iter()
394                 .filter(|f| field_filter(f))
395                 .map(|f| {
396                     let ty = f.ty;
397 
398                     let pred: syn::WherePredicate = parse_quote!(#ty: Copy);
399                     pred
400                 }),
401         );
402 
403         Some(new_where_clause)
404     } else {
405         None
406     }
407 }
408