1 #![recursion_limit = "2048"]
2 extern crate proc_macro;
3 #[macro_use]
4 extern crate quote;
5 
6 use proc_macro2::{Span, TokenStream};
7 use std::convert::TryFrom;
8 use syn::{
9     parse::{Parse, ParseStream},
10     parse_macro_input,
11     spanned::Spanned,
12     Expr, Ident, Item, ItemEnum, Token, Variant,
13 };
14 
/// One enum variant as seen by `#[bitflags]`.
struct Flag<'a> {
    /// The variant's identifier.
    name: Ident,
    /// Span of the whole variant; errors and deferred assertions are reported here.
    span: Span,
    /// How the variant's discriminant is known, or how it will be determined.
    value: FlagValue<'a>,
}
20 
/// Classification of a variant's discriminant.
enum FlagValue<'a> {
    /// The discriminant was written and `fold_expr` evaluated it to this value.
    Literal(u128),
    /// The discriminant was written but could not be evaluated at expansion time;
    /// `check_flag` emits a compile-time assertion for it instead.
    Deferred,
    /// No discriminant was written. Holds the variant mutably so that `infer_values`
    /// can fill one in.
    Inferred(&'a mut Variant),
}
26 
27 impl FlagValue<'_> {
28     // matches! is beyond our MSRV
29     #[allow(clippy::match_like_matches_macro)]
is_inferred(&self) -> bool30     fn is_inferred(&self) -> bool {
31         match self {
32             FlagValue::Inferred(_) => true,
33             _ => false,
34         }
35     }
36 }
37 
/// Arguments accepted by the `#[bitflags]` attribute itself.
struct Parameters {
    /// Variant names given as `default = A | B | ...`; empty when the attribute has
    /// no arguments.
    default: Vec<Ident>,
}
41 
42 impl Parse for Parameters {
parse(input: ParseStream) -> syn::parse::Result<Self>43     fn parse(input: ParseStream) -> syn::parse::Result<Self> {
44         if input.is_empty() {
45             return Ok(Parameters { default: vec![] });
46         }
47 
48         input.parse::<Token![default]>()?;
49         input.parse::<Token![=]>()?;
50         let mut default = vec![input.parse()?];
51         while !input.is_empty() {
52             input.parse::<Token![|]>()?;
53             default.push(input.parse()?);
54         }
55 
56         Ok(Parameters { default })
57     }
58 }
59 
60 #[proc_macro_attribute]
bitflags_internal( attr: proc_macro::TokenStream, input: proc_macro::TokenStream, ) -> proc_macro::TokenStream61 pub fn bitflags_internal(
62     attr: proc_macro::TokenStream,
63     input: proc_macro::TokenStream,
64 ) -> proc_macro::TokenStream {
65     let Parameters { default } = parse_macro_input!(attr as Parameters);
66     let mut ast = parse_macro_input!(input as Item);
67     let output = match ast {
68         Item::Enum(ref mut item_enum) => gen_enumflags(item_enum, default),
69         _ => Err(syn::Error::new_spanned(
70             &ast,
71             "#[bitflags] requires an enum",
72         )),
73     };
74 
75     output
76         .unwrap_or_else(|err| {
77             let error = err.to_compile_error();
78             quote! {
79                 #ast
80                 #error
81             }
82         })
83         .into()
84 }
85 
86 /// Try to evaluate the expression given.
fold_expr(expr: &syn::Expr) -> Option<u128>87 fn fold_expr(expr: &syn::Expr) -> Option<u128> {
88     match expr {
89         Expr::Lit(ref expr_lit) => match expr_lit.lit {
90             syn::Lit::Int(ref lit_int) => lit_int.base10_parse().ok(),
91             _ => None,
92         },
93         Expr::Binary(ref expr_binary) => {
94             let l = fold_expr(&expr_binary.left)?;
95             let r = fold_expr(&expr_binary.right)?;
96             match &expr_binary.op {
97                 syn::BinOp::Shl(_) => u32::try_from(r).ok().and_then(|r| l.checked_shl(r)),
98                 _ => None,
99             }
100         }
101         Expr::Paren(syn::ExprParen { expr, .. }) | Expr::Group(syn::ExprGroup { expr, .. }) => {
102             fold_expr(expr)
103         }
104         _ => None,
105     }
106 }
107 
collect_flags<'a>( variants: impl Iterator<Item = &'a mut Variant>, ) -> Result<Vec<Flag<'a>>, syn::Error>108 fn collect_flags<'a>(
109     variants: impl Iterator<Item = &'a mut Variant>,
110 ) -> Result<Vec<Flag<'a>>, syn::Error> {
111     variants
112         .map(|variant| {
113             // MSRV: Would this be cleaner with `matches!`?
114             match variant.fields {
115                 syn::Fields::Unit => (),
116                 _ => {
117                     return Err(syn::Error::new_spanned(
118                         &variant.fields,
119                         "Bitflag variants cannot contain additional data",
120                     ))
121                 }
122             }
123 
124             let name = variant.ident.clone();
125             let span = variant.span();
126             let value = if let Some(ref expr) = variant.discriminant {
127                 if let Some(n) = fold_expr(&expr.1) {
128                     FlagValue::Literal(n)
129                 } else {
130                     FlagValue::Deferred
131                 }
132             } else {
133                 FlagValue::Inferred(variant)
134             };
135 
136             Ok(Flag { name, span, value })
137         })
138         .collect()
139 }
140 
inferred_value(type_name: &Ident, previous_variants: &[Ident], repr: &Ident) -> Expr141 fn inferred_value(type_name: &Ident, previous_variants: &[Ident], repr: &Ident) -> Expr {
142     let tokens = if previous_variants.is_empty() {
143         quote!(1)
144     } else {
145         quote!(::enumflags2::_internal::next_bit(
146                 #(#type_name::#previous_variants as u128)|*
147         ) as #repr)
148     };
149 
150     syn::parse2(tokens).expect("couldn't parse inferred value")
151 }
152 
infer_values(flags: &mut [Flag], type_name: &Ident, repr: &Ident)153 fn infer_values(flags: &mut [Flag], type_name: &Ident, repr: &Ident) {
154     let mut previous_variants: Vec<Ident> = flags
155         .iter()
156         .filter(|flag| !flag.value.is_inferred())
157         .map(|flag| flag.name.clone())
158         .collect();
159 
160     for flag in flags {
161         if let FlagValue::Inferred(ref mut variant) = flag.value {
162             variant.discriminant = Some((
163                 <Token![=]>::default(),
164                 inferred_value(type_name, &previous_variants, repr),
165             ));
166             previous_variants.push(flag.name.clone());
167         }
168     }
169 }
170 
171 /// Given a list of attributes, find the `repr`, if any, and return the integer
172 /// type specified.
extract_repr(attrs: &[syn::Attribute]) -> Result<Option<Ident>, syn::Error>173 fn extract_repr(attrs: &[syn::Attribute]) -> Result<Option<Ident>, syn::Error> {
174     use syn::{Meta, NestedMeta};
175     attrs
176         .iter()
177         .find_map(|attr| match attr.parse_meta() {
178             Err(why) => Some(Err(syn::Error::new_spanned(
179                 attr,
180                 format!("Couldn't parse attribute: {}", why),
181             ))),
182             Ok(Meta::List(ref meta)) if meta.path.is_ident("repr") => {
183                 meta.nested.iter().find_map(|mi| match mi {
184                     NestedMeta::Meta(Meta::Path(path)) => path.get_ident().cloned().map(Ok),
185                     _ => None,
186                 })
187             }
188             Ok(_) => None,
189         })
190         .transpose()
191 }
192 
193 /// Check the repr and return the number of bits available
type_bits(ty: &Ident) -> Result<u8, syn::Error>194 fn type_bits(ty: &Ident) -> Result<u8, syn::Error> {
195     // This would be so much easier if we could just match on an Ident...
196     if ty == "usize" {
197         Err(syn::Error::new_spanned(
198             ty,
199             "#[repr(usize)] is not supported. Use u32 or u64 instead.",
200         ))
201     } else if ty == "i8"
202         || ty == "i16"
203         || ty == "i32"
204         || ty == "i64"
205         || ty == "i128"
206         || ty == "isize"
207     {
208         Err(syn::Error::new_spanned(
209             ty,
210             "Signed types in a repr are not supported.",
211         ))
212     } else if ty == "u8" {
213         Ok(8)
214     } else if ty == "u16" {
215         Ok(16)
216     } else if ty == "u32" {
217         Ok(32)
218     } else if ty == "u64" {
219         Ok(64)
220     } else if ty == "u128" {
221         Ok(128)
222     } else {
223         Err(syn::Error::new_spanned(
224             ty,
225             "repr must be an integer type for #[bitflags].",
226         ))
227     }
228 }
229 
230 /// Returns deferred checks
check_flag(type_name: &Ident, flag: &Flag, bits: u8) -> Result<Option<TokenStream>, syn::Error>231 fn check_flag(type_name: &Ident, flag: &Flag, bits: u8) -> Result<Option<TokenStream>, syn::Error> {
232     use FlagValue::*;
233     match flag.value {
234         Literal(n) => {
235             if !n.is_power_of_two() {
236                 Err(syn::Error::new(
237                     flag.span,
238                     "Flags must have exactly one set bit",
239                 ))
240             } else if bits < 128 && n >= 1 << bits {
241                 Err(syn::Error::new(
242                     flag.span,
243                     format!("Flag value out of range for u{}", bits),
244                 ))
245             } else {
246                 Ok(None)
247             }
248         }
249         Inferred(_) => Ok(None),
250         Deferred => {
251             let variant_name = &flag.name;
252             // MSRV: Use an unnamed constant (`const _: ...`).
253             let assertion_name = syn::Ident::new(
254                 &format!("__enumflags_assertion_{}_{}", type_name, flag.name),
255                 Span::call_site(),
256             ); // call_site because def_site is unstable
257 
258             Ok(Some(quote_spanned!(flag.span =>
259                 #[doc(hidden)]
260                 const #assertion_name:
261                     <<[(); (
262                         (#type_name::#variant_name as u128).is_power_of_two()
263                     ) as usize] as enumflags2::_internal::AssertionHelper>
264                         ::Status as enumflags2::_internal::ExactlyOneBitSet>::X
265                     = ();
266             )))
267         }
268     }
269 }
270 
/// Generates the full `#[bitflags]` expansion for `ast`.
///
/// Steps, in order:
/// 1. Read the integer type from `#[repr(...)]` (`extract_repr`) and its bit width
///    (`type_bits`).
/// 2. Classify every variant (`collect_flags`) and validate each one (`check_flag`),
///    collecting the compile-time assertions emitted for non-foldable discriminants.
/// 3. Assign discriminants to variants written without one (`infer_values`). This
///    rewrites `ast` in place.
/// 4. Reject enums with more variants than the repr has bits.
/// 5. Emit the rewritten enum, the deferred assertions, `Not`/`BitOr`/`BitAnd`/`BitXor`
///    impls whose output is `::enumflags2::BitFlags<Self>`, and the `RawBitFlags` and
///    `BitFlag` impls.
///
/// `default` lists the variants ORed together into `RawBitFlags::DEFAULT`.
///
/// # Errors
/// Returns a `syn::Error` if the repr attribute is missing or unsupported, a variant
/// is invalid, or there are more variants than bits. In the last case `ast` has
/// already been rewritten by `infer_values`.
fn gen_enumflags(ast: &mut ItemEnum, default: Vec<Ident>) -> Result<TokenStream, syn::Error> {
    let ident = &ast.ident;

    // Every generated token below uses call-site hygiene.
    let span = Span::call_site();

    let repr = extract_repr(&ast.attrs)?
        .ok_or_else(|| syn::Error::new_spanned(&ident,
                        "repr attribute missing. Add #[repr(u64)] or a similar attribute to specify the size of the bitfield."))?;
    let bits = type_bits(&repr)?;

    let mut variants = collect_flags(ast.variants.iter_mut())?;
    // `transpose` turns `Result<Option<_>>` into `Option<Result<_>>`, so `flat_map`
    // drops flags needing no deferred check while the first error still
    // short-circuits the collect.
    let deferred = variants
        .iter()
        .flat_map(|variant| check_flag(ident, variant, bits).transpose())
        .collect::<Result<Vec<_>, _>>()?;

    // Writes discriminants into the variants of `ast` that had none.
    infer_values(&mut variants, ident, &repr);

    if (bits as usize) < variants.len() {
        return Err(syn::Error::new_spanned(
            &repr,
            format!("Not enough bits for {} flags", variants.len()),
        ));
    }

    // `core` as re-exported by enumflags2, so generated impls never name `core` or
    // `std` directly (presumably to support no_std users).
    let std_path = quote_spanned!(span => ::enumflags2::_internal::core);
    // Read from `ast` after inference; every variant now has a discriminant.
    let variant_names = ast.variants.iter().map(|v| &v.ident).collect::<Vec<_>>();

    Ok(quote_spanned! {
        span =>
            // The enum itself, with inferred discriminants filled in.
            #ast
            #(#deferred)*
            // Operators on single flags produce a `BitFlags` set.
            impl #std_path::ops::Not for #ident {
                type Output = ::enumflags2::BitFlags<Self>;
                #[inline(always)]
                fn not(self) -> Self::Output {
                    use ::enumflags2::{BitFlags, _internal::RawBitFlags};
                    unsafe { BitFlags::from_bits_unchecked(self.bits()).not() }
                }
            }

            impl #std_path::ops::BitOr for #ident {
                type Output = ::enumflags2::BitFlags<Self>;
                #[inline(always)]
                fn bitor(self, other: Self) -> Self::Output {
                    use ::enumflags2::{BitFlags, _internal::RawBitFlags};
                    unsafe { BitFlags::from_bits_unchecked(self.bits() | other.bits())}
                }
            }

            impl #std_path::ops::BitAnd for #ident {
                type Output = ::enumflags2::BitFlags<Self>;
                #[inline(always)]
                fn bitand(self, other: Self) -> Self::Output {
                    use ::enumflags2::{BitFlags, _internal::RawBitFlags};
                    unsafe { BitFlags::from_bits_unchecked(self.bits() & other.bits())}
                }
            }

            impl #std_path::ops::BitXor for #ident {
                type Output = ::enumflags2::BitFlags<Self>;
                #[inline(always)]
                fn bitxor(self, other: Self) -> Self::Output {
                    #std_path::convert::Into::<Self::Output>::into(self) ^ #std_path::convert::Into::<Self::Output>::into(other)
                }
            }

            impl ::enumflags2::_internal::RawBitFlags for #ident {
                type Numeric = #repr;

                const EMPTY: Self::Numeric = 0;

                // OR of the variants listed in `#[bitflags(default = ...)]`.
                const DEFAULT: Self::Numeric =
                    0 #(| (Self::#default as #repr))*;

                // OR of every variant.
                const ALL_BITS: Self::Numeric =
                    0 #(| (Self::#variant_names as #repr))*;

                const FLAG_LIST: &'static [Self] =
                    &[#(Self::#variant_names),*];

                const BITFLAGS_TYPE_NAME : &'static str =
                    concat!("BitFlags<", stringify!(#ident), ">");

                fn bits(self) -> Self::Numeric {
                    self as #repr
                }
            }

            impl ::enumflags2::BitFlag for #ident {}
    })
}
363