1 #![recursion_limit="256"]
2 #![cfg_attr(feature = "nightly", feature(proc_macro_diagnostic))]
3 
4 extern crate proc_macro;
5 
6 use darling::*;
7 use proc_macro::TokenStream;
8 use proc_macro2::{TokenStream as SynTokenStream, Literal};
9 use syn::*;
10 use syn::export::Span;
11 use syn::spanned::Spanned;
12 use quote::*;
13 
/// Reports `data` as a compiler error attached to `span` (nightly-only diagnostics API).
///
/// The diagnostic is emitted immediately as an error-level diagnostic, so the derive returns
/// an empty token stream in place of its normal expansion.
#[cfg(feature = "nightly")]
fn error(span: Span, data: &str) -> TokenStream {
    let diagnostic = span.unstable().error(data);
    diagnostic.emit();
    TokenStream::new()
}
19 
20 #[cfg(not(feature = "nightly"))]
error(_: Span, data: &str) -> TokenStream21 fn error(_: Span, data: &str) -> TokenStream {
22     panic!("{}", data)
23 }
24 
enum_set_type_impl( name: &Ident, all_variants: u128, repr: Ident, attrs: EnumsetAttrs, variants: Vec<Ident>, ) -> SynTokenStream25 fn enum_set_type_impl(
26     name: &Ident, all_variants: u128, repr: Ident, attrs: EnumsetAttrs, variants: Vec<Ident>,
27 ) -> SynTokenStream {
28     let is_uninhabited = variants.is_empty();
29     let is_zst = variants.len() == 1;
30 
31     let typed_enumset = quote!(::enumset::EnumSet<#name>);
32     let core = quote!(::enumset::internal::core_export);
33     #[cfg(feature = "serde")]
34     let serde = quote!(::enumset::internal::serde);
35 
36     // proc_macro2 does not support creating u128 literals.
37     let all_variants = Literal::u128_unsuffixed(all_variants);
38 
39     let ops = if attrs.no_ops {
40         quote! {}
41     } else {
42         quote! {
43             impl <O : Into<#typed_enumset>> #core::ops::Sub<O> for #name {
44                 type Output = #typed_enumset;
45                 fn sub(self, other: O) -> Self::Output {
46                     ::enumset::EnumSet::only(self) - other.into()
47                 }
48             }
49             impl <O : Into<#typed_enumset>> #core::ops::BitAnd<O> for #name {
50                 type Output = #typed_enumset;
51                 fn bitand(self, other: O) -> Self::Output {
52                     ::enumset::EnumSet::only(self) & other.into()
53                 }
54             }
55             impl <O : Into<#typed_enumset>> #core::ops::BitOr<O> for #name {
56                 type Output = #typed_enumset;
57                 fn bitor(self, other: O) -> Self::Output {
58                     ::enumset::EnumSet::only(self) | other.into()
59                 }
60             }
61             impl <O : Into<#typed_enumset>> #core::ops::BitXor<O> for #name {
62                 type Output = #typed_enumset;
63                 fn bitxor(self, other: O) -> Self::Output {
64                     ::enumset::EnumSet::only(self) ^ other.into()
65                 }
66             }
67             impl #core::ops::Not for #name {
68                 type Output = #typed_enumset;
69                 fn not(self) -> Self::Output {
70                     !::enumset::EnumSet::only(self)
71                 }
72             }
73             impl #core::cmp::PartialEq<#typed_enumset> for #name {
74                 fn eq(&self, other: &#typed_enumset) -> bool {
75                     ::enumset::EnumSet::only(*self) == *other
76                 }
77             }
78         }
79     };
80 
81     #[cfg(feature = "serde")]
82     let serde_ops = if attrs.serialize_as_list {
83         let expecting_str = format!("a list of {}", name);
84         quote! {
85             fn serialize<S: #serde::Serializer>(
86                 set: ::enumset::EnumSet<#name>, ser: S,
87             ) -> #core::result::Result<S::Ok, S::Error> {
88                 use #serde::ser::SerializeSeq;
89                 let mut seq = ser.serialize_seq(#core::prelude::v1::Some(set.len()))?;
90                 for bit in set {
91                     seq.serialize_element(&bit)?;
92                 }
93                 seq.end()
94             }
95             fn deserialize<'de, D: #serde::Deserializer<'de>>(
96                 de: D,
97             ) -> #core::result::Result<::enumset::EnumSet<#name>, D::Error> {
98                 struct Visitor;
99                 impl <'de> #serde::de::Visitor<'de> for Visitor {
100                     type Value = ::enumset::EnumSet<#name>;
101                     fn expecting(
102                         &self, formatter: &mut #core::fmt::Formatter,
103                     ) -> #core::fmt::Result {
104                         write!(formatter, #expecting_str)
105                     }
106                     fn visit_seq<A>(
107                         mut self, mut seq: A,
108                     ) -> #core::result::Result<Self::Value, A::Error> where
109                         A: #serde::de::SeqAccess<'de>
110                     {
111                         let mut accum = ::enumset::EnumSet::<#name>::new();
112                         while let #core::prelude::v1::Some(val) = seq.next_element::<#name>()? {
113                             accum |= val;
114                         }
115                         #core::prelude::v1::Ok(accum)
116                     }
117                 }
118                 de.deserialize_seq(Visitor)
119             }
120         }
121     } else {
122         let serialize_repr = attrs.serialize_repr.as_ref()
123             .map(|x| Ident::new(&x, Span::call_site()))
124             .unwrap_or(repr.clone());
125         let check_unknown = if attrs.serialize_deny_unknown {
126             quote! {
127                 if value & !#all_variants != 0 {
128                     use #serde::de::Error;
129                     return #core::prelude::v1::Err(
130                         D::Error::custom("enumset contains unknown bits")
131                     )
132                 }
133             }
134         } else {
135             quote! { }
136         };
137         quote! {
138             fn serialize<S: #serde::Serializer>(
139                 set: ::enumset::EnumSet<#name>, ser: S,
140             ) -> #core::result::Result<S::Ok, S::Error> {
141                 #serde::Serialize::serialize(&(set.__enumset_underlying as #serialize_repr), ser)
142             }
143             fn deserialize<'de, D: #serde::Deserializer<'de>>(
144                 de: D,
145             ) -> #core::result::Result<::enumset::EnumSet<#name>, D::Error> {
146                 let value = <#serialize_repr as #serde::Deserialize>::deserialize(de)?;
147                 #check_unknown
148                 #core::prelude::v1::Ok(::enumset::EnumSet {
149                     __enumset_underlying: (value & #all_variants) as #repr,
150                 })
151             }
152         }
153     };
154 
155     #[cfg(not(feature = "serde"))]
156     let serde_ops = quote! { };
157 
158     let into_impl = if is_uninhabited {
159         quote! {
160             fn enum_into_u8(self) -> u8 {
161                 panic!(concat!(stringify!(#name), " is uninhabited."))
162             }
163             unsafe fn enum_from_u8(val: u8) -> Self {
164                 panic!(concat!(stringify!(#name), " is uninhabited."))
165             }
166         }
167     } else if is_zst {
168         let variant = &variants[0];
169         quote! {
170             fn enum_into_u8(self) -> u8 {
171                 self as u8
172             }
173             unsafe fn enum_from_u8(val: u8) -> Self {
174                 #name::#variant
175             }
176         }
177     } else {
178         quote! {
179             fn enum_into_u8(self) -> u8 {
180                 self as u8
181             }
182             unsafe fn enum_from_u8(val: u8) -> Self {
183                 #core::mem::transmute(val)
184             }
185         }
186     };
187 
188     let eq_impl = if is_uninhabited {
189         quote!(panic!(concat!(stringify!(#name), " is uninhabited.")))
190     } else {
191         quote!((*self as u8) == (*other as u8))
192     };
193 
194     quote! {
195         unsafe impl ::enumset::internal::EnumSetTypePrivate for #name {
196             type Repr = #repr;
197             const ALL_BITS: Self::Repr = #all_variants;
198             #into_impl
199             #serde_ops
200         }
201 
202         unsafe impl ::enumset::EnumSetType for #name { }
203 
204         impl #core::cmp::PartialEq for #name {
205             fn eq(&self, other: &Self) -> bool {
206                 #eq_impl
207             }
208         }
209         impl #core::cmp::Eq for #name { }
210         impl #core::clone::Clone for #name {
211             fn clone(&self) -> Self {
212                 *self
213             }
214         }
215         impl #core::marker::Copy for #name { }
216 
217         #ops
218     }
219 }
220 
/// Options accepted in `#[enumset(...)]` attributes on the derived enum.
///
/// Parsed by darling. The container-level `default` means any option not written in the
/// attribute takes its `Default` value (`false` / `None`). The `serialize_*` options are
/// only consulted when this crate is built with the `serde` feature.
#[derive(FromDeriveInput, Default)]
#[darling(attributes(enumset), default)]
struct EnumsetAttrs {
    /// `no_ops`: do not generate the `-`, `&`, `|`, `^`, `!` operator impls or the
    /// `PartialEq<EnumSet<_>>` impl for the enum.
    no_ops: bool,
    /// `serialize_as_list`: (de)serialize the set as a sequence of variants instead of as an
    /// integer bitmask.
    serialize_as_list: bool,
    /// `serialize_deny_unknown`: when deserializing a bitmask, fail on bits that match no
    /// variant instead of silently masking them off.
    serialize_deny_unknown: bool,
    /// `serialize_repr = "..."`: integer type (`u8`, `u16`, `u32`, `u64` or `u128`) used for
    /// the serialized bitmask; when absent, the set's own storage type is used.
    #[darling(default)]
    serialize_repr: Option<String>,
}
230 
231 #[proc_macro_derive(EnumSetType, attributes(enumset))]
derive_enum_set_type(input: TokenStream) -> TokenStream232 pub fn derive_enum_set_type(input: TokenStream) -> TokenStream {
233     let input: DeriveInput = parse_macro_input!(input);
234     if let Data::Enum(data) = &input.data {
235         if !input.generics.params.is_empty() {
236             error(input.generics.span(),
237                   "`#[derive(EnumSetType)]` cannot be used on enums with type parameters.")
238         } else {
239             let mut all_variants = 0u128;
240             let mut max_variant = 0;
241             let mut current_variant = 0;
242             let mut has_manual_discriminant = false;
243             let mut variants = Vec::new();
244 
245             for variant in &data.variants {
246                 if let Fields::Unit = variant.fields {
247                     if let Some((_, expr)) = &variant.discriminant {
248                         if let Expr::Lit(ExprLit { lit: Lit::Int(i), .. }) = expr {
249                             current_variant = match i.base10_parse() {
250                                 Ok(val) => val,
251                                 Err(_) => return error(expr.span(), "Error parsing discriminant."),
252                             };
253                             has_manual_discriminant = true;
254                         } else {
255                             return error(variant.span(), "Unrecognized discriminant for variant.")
256                         }
257                     }
258 
259                     if current_variant >= 128 {
260                         let message = if has_manual_discriminant {
261                             "`#[derive(EnumSetType)]` only supports enum discriminants up to 127."
262                         } else {
263                             "`#[derive(EnumSetType)]` only supports enums up to 128 variants."
264                         };
265                         return error(variant.span(), message)
266                     }
267 
268                     if all_variants & (1 << current_variant) != 0 {
269                         return error(variant.span(),
270                                      &format!("Duplicate enum discriminant: {}", current_variant))
271                     }
272                     all_variants |= 1 << current_variant;
273                     if current_variant > max_variant {
274                         max_variant = current_variant
275                     }
276 
277                     variants.push(variant.ident.clone());
278                     current_variant += 1;
279                 } else {
280                     return error(variant.span(),
281                                  "`#[derive(EnumSetType)]` can only be used on C-like enums.")
282                 }
283             }
284 
285             let repr = Ident::new(if max_variant <= 7 {
286                 "u8"
287             } else if max_variant <= 15 {
288                 "u16"
289             } else if max_variant <= 31 {
290                 "u32"
291             } else if max_variant <= 63 {
292                 "u64"
293             } else if max_variant <= 127 {
294                 "u128"
295             } else {
296                 panic!("max_variant > 127?")
297             }, Span::call_site());
298 
299             let attrs: EnumsetAttrs = match EnumsetAttrs::from_derive_input(&input) {
300                 Ok(attrs) => attrs,
301                 Err(e) => return e.write_errors().into(),
302             };
303 
304             match attrs.serialize_repr.as_ref().map(|x| x.as_str()) {
305                 Some("u8") => if max_variant > 7 {
306                     return error(input.span(), "Too many variants for u8 serialization repr.")
307                 }
308                 Some("u16") => if max_variant > 15 {
309                     return error(input.span(), "Too many variants for u16 serialization repr.")
310                 }
311                 Some("u32") => if max_variant > 31 {
312                     return error(input.span(), "Too many variants for u32 serialization repr.")
313                 }
314                 Some("u64") => if max_variant > 63 {
315                     return error(input.span(), "Too many variants for u64 serialization repr.")
316                 }
317                 Some("u128") => if max_variant > 127 {
318                     return error(input.span(), "Too many variants for u128 serialization repr.")
319                 }
320                 None => { }
321                 Some(x) => return error(input.span(),
322                                         &format!("{} is not a valid serialization repr.", x)),
323             };
324 
325             enum_set_type_impl(&input.ident, all_variants, repr, attrs, variants).into()
326         }
327     } else {
328         error(input.span(), "`#[derive(EnumSetType)]` may only be used on enums")
329     }
330 }
331