1 #![recursion_limit="256"]
2 #![cfg_attr(feature = "nightly", feature(proc_macro_diagnostic))]
3
4 extern crate proc_macro;
5
6 use darling::*;
7 use proc_macro::TokenStream;
8 use proc_macro2::{TokenStream as SynTokenStream, Literal};
9 use syn::*;
10 use syn::export::Span;
11 use syn::spanned::Spanned;
12 use quote::*;
13
/// Reports `data` as a compiler error at `span` through the unstable
/// `proc_macro` diagnostics API.
///
/// Returns an empty token stream, so the derive emits no items for the input.
#[cfg(feature = "nightly")]
fn error(span: Span, data: &str) -> TokenStream {
    let diagnostic = span.unstable().error(data);
    diagnostic.emit();
    TokenStream::new()
}
19
20 #[cfg(not(feature = "nightly"))]
error(_: Span, data: &str) -> TokenStream21 fn error(_: Span, data: &str) -> TokenStream {
22 panic!("{}", data)
23 }
24
enum_set_type_impl( name: &Ident, all_variants: u128, repr: Ident, attrs: EnumsetAttrs, variants: Vec<Ident>, ) -> SynTokenStream25 fn enum_set_type_impl(
26 name: &Ident, all_variants: u128, repr: Ident, attrs: EnumsetAttrs, variants: Vec<Ident>,
27 ) -> SynTokenStream {
28 let is_uninhabited = variants.is_empty();
29 let is_zst = variants.len() == 1;
30
31 let typed_enumset = quote!(::enumset::EnumSet<#name>);
32 let core = quote!(::enumset::internal::core_export);
33 #[cfg(feature = "serde")]
34 let serde = quote!(::enumset::internal::serde);
35
36 // proc_macro2 does not support creating u128 literals.
37 let all_variants = Literal::u128_unsuffixed(all_variants);
38
39 let ops = if attrs.no_ops {
40 quote! {}
41 } else {
42 quote! {
43 impl <O : Into<#typed_enumset>> #core::ops::Sub<O> for #name {
44 type Output = #typed_enumset;
45 fn sub(self, other: O) -> Self::Output {
46 ::enumset::EnumSet::only(self) - other.into()
47 }
48 }
49 impl <O : Into<#typed_enumset>> #core::ops::BitAnd<O> for #name {
50 type Output = #typed_enumset;
51 fn bitand(self, other: O) -> Self::Output {
52 ::enumset::EnumSet::only(self) & other.into()
53 }
54 }
55 impl <O : Into<#typed_enumset>> #core::ops::BitOr<O> for #name {
56 type Output = #typed_enumset;
57 fn bitor(self, other: O) -> Self::Output {
58 ::enumset::EnumSet::only(self) | other.into()
59 }
60 }
61 impl <O : Into<#typed_enumset>> #core::ops::BitXor<O> for #name {
62 type Output = #typed_enumset;
63 fn bitxor(self, other: O) -> Self::Output {
64 ::enumset::EnumSet::only(self) ^ other.into()
65 }
66 }
67 impl #core::ops::Not for #name {
68 type Output = #typed_enumset;
69 fn not(self) -> Self::Output {
70 !::enumset::EnumSet::only(self)
71 }
72 }
73 impl #core::cmp::PartialEq<#typed_enumset> for #name {
74 fn eq(&self, other: &#typed_enumset) -> bool {
75 ::enumset::EnumSet::only(*self) == *other
76 }
77 }
78 }
79 };
80
81 #[cfg(feature = "serde")]
82 let serde_ops = if attrs.serialize_as_list {
83 let expecting_str = format!("a list of {}", name);
84 quote! {
85 fn serialize<S: #serde::Serializer>(
86 set: ::enumset::EnumSet<#name>, ser: S,
87 ) -> #core::result::Result<S::Ok, S::Error> {
88 use #serde::ser::SerializeSeq;
89 let mut seq = ser.serialize_seq(#core::prelude::v1::Some(set.len()))?;
90 for bit in set {
91 seq.serialize_element(&bit)?;
92 }
93 seq.end()
94 }
95 fn deserialize<'de, D: #serde::Deserializer<'de>>(
96 de: D,
97 ) -> #core::result::Result<::enumset::EnumSet<#name>, D::Error> {
98 struct Visitor;
99 impl <'de> #serde::de::Visitor<'de> for Visitor {
100 type Value = ::enumset::EnumSet<#name>;
101 fn expecting(
102 &self, formatter: &mut #core::fmt::Formatter,
103 ) -> #core::fmt::Result {
104 write!(formatter, #expecting_str)
105 }
106 fn visit_seq<A>(
107 mut self, mut seq: A,
108 ) -> #core::result::Result<Self::Value, A::Error> where
109 A: #serde::de::SeqAccess<'de>
110 {
111 let mut accum = ::enumset::EnumSet::<#name>::new();
112 while let #core::prelude::v1::Some(val) = seq.next_element::<#name>()? {
113 accum |= val;
114 }
115 #core::prelude::v1::Ok(accum)
116 }
117 }
118 de.deserialize_seq(Visitor)
119 }
120 }
121 } else {
122 let serialize_repr = attrs.serialize_repr.as_ref()
123 .map(|x| Ident::new(&x, Span::call_site()))
124 .unwrap_or(repr.clone());
125 let check_unknown = if attrs.serialize_deny_unknown {
126 quote! {
127 if value & !#all_variants != 0 {
128 use #serde::de::Error;
129 return #core::prelude::v1::Err(
130 D::Error::custom("enumset contains unknown bits")
131 )
132 }
133 }
134 } else {
135 quote! { }
136 };
137 quote! {
138 fn serialize<S: #serde::Serializer>(
139 set: ::enumset::EnumSet<#name>, ser: S,
140 ) -> #core::result::Result<S::Ok, S::Error> {
141 #serde::Serialize::serialize(&(set.__enumset_underlying as #serialize_repr), ser)
142 }
143 fn deserialize<'de, D: #serde::Deserializer<'de>>(
144 de: D,
145 ) -> #core::result::Result<::enumset::EnumSet<#name>, D::Error> {
146 let value = <#serialize_repr as #serde::Deserialize>::deserialize(de)?;
147 #check_unknown
148 #core::prelude::v1::Ok(::enumset::EnumSet {
149 __enumset_underlying: (value & #all_variants) as #repr,
150 })
151 }
152 }
153 };
154
155 #[cfg(not(feature = "serde"))]
156 let serde_ops = quote! { };
157
158 let into_impl = if is_uninhabited {
159 quote! {
160 fn enum_into_u8(self) -> u8 {
161 panic!(concat!(stringify!(#name), " is uninhabited."))
162 }
163 unsafe fn enum_from_u8(val: u8) -> Self {
164 panic!(concat!(stringify!(#name), " is uninhabited."))
165 }
166 }
167 } else if is_zst {
168 let variant = &variants[0];
169 quote! {
170 fn enum_into_u8(self) -> u8 {
171 self as u8
172 }
173 unsafe fn enum_from_u8(val: u8) -> Self {
174 #name::#variant
175 }
176 }
177 } else {
178 quote! {
179 fn enum_into_u8(self) -> u8 {
180 self as u8
181 }
182 unsafe fn enum_from_u8(val: u8) -> Self {
183 #core::mem::transmute(val)
184 }
185 }
186 };
187
188 let eq_impl = if is_uninhabited {
189 quote!(panic!(concat!(stringify!(#name), " is uninhabited.")))
190 } else {
191 quote!((*self as u8) == (*other as u8))
192 };
193
194 quote! {
195 unsafe impl ::enumset::internal::EnumSetTypePrivate for #name {
196 type Repr = #repr;
197 const ALL_BITS: Self::Repr = #all_variants;
198 #into_impl
199 #serde_ops
200 }
201
202 unsafe impl ::enumset::EnumSetType for #name { }
203
204 impl #core::cmp::PartialEq for #name {
205 fn eq(&self, other: &Self) -> bool {
206 #eq_impl
207 }
208 }
209 impl #core::cmp::Eq for #name { }
210 impl #core::clone::Clone for #name {
211 fn clone(&self) -> Self {
212 *self
213 }
214 }
215 impl #core::marker::Copy for #name { }
216
217 #ops
218 }
219 }
220
221 #[derive(FromDeriveInput, Default)]
222 #[darling(attributes(enumset), default)]
223 struct EnumsetAttrs {
224 no_ops: bool,
225 serialize_as_list: bool,
226 serialize_deny_unknown: bool,
227 #[darling(default)]
228 serialize_repr: Option<String>,
229 }
230
231 #[proc_macro_derive(EnumSetType, attributes(enumset))]
derive_enum_set_type(input: TokenStream) -> TokenStream232 pub fn derive_enum_set_type(input: TokenStream) -> TokenStream {
233 let input: DeriveInput = parse_macro_input!(input);
234 if let Data::Enum(data) = &input.data {
235 if !input.generics.params.is_empty() {
236 error(input.generics.span(),
237 "`#[derive(EnumSetType)]` cannot be used on enums with type parameters.")
238 } else {
239 let mut all_variants = 0u128;
240 let mut max_variant = 0;
241 let mut current_variant = 0;
242 let mut has_manual_discriminant = false;
243 let mut variants = Vec::new();
244
245 for variant in &data.variants {
246 if let Fields::Unit = variant.fields {
247 if let Some((_, expr)) = &variant.discriminant {
248 if let Expr::Lit(ExprLit { lit: Lit::Int(i), .. }) = expr {
249 current_variant = match i.base10_parse() {
250 Ok(val) => val,
251 Err(_) => return error(expr.span(), "Error parsing discriminant."),
252 };
253 has_manual_discriminant = true;
254 } else {
255 return error(variant.span(), "Unrecognized discriminant for variant.")
256 }
257 }
258
259 if current_variant >= 128 {
260 let message = if has_manual_discriminant {
261 "`#[derive(EnumSetType)]` only supports enum discriminants up to 127."
262 } else {
263 "`#[derive(EnumSetType)]` only supports enums up to 128 variants."
264 };
265 return error(variant.span(), message)
266 }
267
268 if all_variants & (1 << current_variant) != 0 {
269 return error(variant.span(),
270 &format!("Duplicate enum discriminant: {}", current_variant))
271 }
272 all_variants |= 1 << current_variant;
273 if current_variant > max_variant {
274 max_variant = current_variant
275 }
276
277 variants.push(variant.ident.clone());
278 current_variant += 1;
279 } else {
280 return error(variant.span(),
281 "`#[derive(EnumSetType)]` can only be used on C-like enums.")
282 }
283 }
284
285 let repr = Ident::new(if max_variant <= 7 {
286 "u8"
287 } else if max_variant <= 15 {
288 "u16"
289 } else if max_variant <= 31 {
290 "u32"
291 } else if max_variant <= 63 {
292 "u64"
293 } else if max_variant <= 127 {
294 "u128"
295 } else {
296 panic!("max_variant > 127?")
297 }, Span::call_site());
298
299 let attrs: EnumsetAttrs = match EnumsetAttrs::from_derive_input(&input) {
300 Ok(attrs) => attrs,
301 Err(e) => return e.write_errors().into(),
302 };
303
304 match attrs.serialize_repr.as_ref().map(|x| x.as_str()) {
305 Some("u8") => if max_variant > 7 {
306 return error(input.span(), "Too many variants for u8 serialization repr.")
307 }
308 Some("u16") => if max_variant > 15 {
309 return error(input.span(), "Too many variants for u16 serialization repr.")
310 }
311 Some("u32") => if max_variant > 31 {
312 return error(input.span(), "Too many variants for u32 serialization repr.")
313 }
314 Some("u64") => if max_variant > 63 {
315 return error(input.span(), "Too many variants for u64 serialization repr.")
316 }
317 Some("u128") => if max_variant > 127 {
318 return error(input.span(), "Too many variants for u128 serialization repr.")
319 }
320 None => { }
321 Some(x) => return error(input.span(),
322 &format!("{} is not a valid serialization repr.", x)),
323 };
324
325 enum_set_type_impl(&input.ident, all_variants, repr, attrs, variants).into()
326 }
327 } else {
328 error(input.span(), "`#[derive(EnumSetType)]` may only be used on enums")
329 }
330 }
331