1 #![recursion_limit = "2048"]
2 extern crate proc_macro;
3 #[macro_use]
4 extern crate quote;
5
6 use proc_macro2::{Span, TokenStream};
7 use std::convert::TryFrom;
8 use syn::{
9 parse::{Parse, ParseStream},
10 parse_macro_input,
11 spanned::Spanned,
12 Expr, Ident, Item, ItemEnum, Token, Variant,
13 };
14
15 struct Flag<'a> {
16 name: Ident,
17 span: Span,
18 value: FlagValue<'a>,
19 }
20
21 enum FlagValue<'a> {
22 Literal(u128),
23 Deferred,
24 Inferred(&'a mut Variant),
25 }
26
27 impl FlagValue<'_> {
28 // matches! is beyond our MSRV
29 #[allow(clippy::match_like_matches_macro)]
is_inferred(&self) -> bool30 fn is_inferred(&self) -> bool {
31 match self {
32 FlagValue::Inferred(_) => true,
33 _ => false,
34 }
35 }
36 }
37
38 struct Parameters {
39 default: Vec<Ident>,
40 }
41
42 impl Parse for Parameters {
parse(input: ParseStream) -> syn::parse::Result<Self>43 fn parse(input: ParseStream) -> syn::parse::Result<Self> {
44 if input.is_empty() {
45 return Ok(Parameters { default: vec![] });
46 }
47
48 input.parse::<Token![default]>()?;
49 input.parse::<Token![=]>()?;
50 let mut default = vec![input.parse()?];
51 while !input.is_empty() {
52 input.parse::<Token![|]>()?;
53 default.push(input.parse()?);
54 }
55
56 Ok(Parameters { default })
57 }
58 }
59
60 #[proc_macro_attribute]
bitflags_internal( attr: proc_macro::TokenStream, input: proc_macro::TokenStream, ) -> proc_macro::TokenStream61 pub fn bitflags_internal(
62 attr: proc_macro::TokenStream,
63 input: proc_macro::TokenStream,
64 ) -> proc_macro::TokenStream {
65 let Parameters { default } = parse_macro_input!(attr as Parameters);
66 let mut ast = parse_macro_input!(input as Item);
67 let output = match ast {
68 Item::Enum(ref mut item_enum) => gen_enumflags(item_enum, default),
69 _ => Err(syn::Error::new_spanned(
70 &ast,
71 "#[bitflags] requires an enum",
72 )),
73 };
74
75 output
76 .unwrap_or_else(|err| {
77 let error = err.to_compile_error();
78 quote! {
79 #ast
80 #error
81 }
82 })
83 .into()
84 }
85
86 /// Try to evaluate the expression given.
fold_expr(expr: &syn::Expr) -> Option<u128>87 fn fold_expr(expr: &syn::Expr) -> Option<u128> {
88 match expr {
89 Expr::Lit(ref expr_lit) => match expr_lit.lit {
90 syn::Lit::Int(ref lit_int) => lit_int.base10_parse().ok(),
91 _ => None,
92 },
93 Expr::Binary(ref expr_binary) => {
94 let l = fold_expr(&expr_binary.left)?;
95 let r = fold_expr(&expr_binary.right)?;
96 match &expr_binary.op {
97 syn::BinOp::Shl(_) => u32::try_from(r).ok().and_then(|r| l.checked_shl(r)),
98 _ => None,
99 }
100 }
101 Expr::Paren(syn::ExprParen { expr, .. }) | Expr::Group(syn::ExprGroup { expr, .. }) => {
102 fold_expr(expr)
103 }
104 _ => None,
105 }
106 }
107
collect_flags<'a>( variants: impl Iterator<Item = &'a mut Variant>, ) -> Result<Vec<Flag<'a>>, syn::Error>108 fn collect_flags<'a>(
109 variants: impl Iterator<Item = &'a mut Variant>,
110 ) -> Result<Vec<Flag<'a>>, syn::Error> {
111 variants
112 .map(|variant| {
113 // MSRV: Would this be cleaner with `matches!`?
114 match variant.fields {
115 syn::Fields::Unit => (),
116 _ => {
117 return Err(syn::Error::new_spanned(
118 &variant.fields,
119 "Bitflag variants cannot contain additional data",
120 ))
121 }
122 }
123
124 let name = variant.ident.clone();
125 let span = variant.span();
126 let value = if let Some(ref expr) = variant.discriminant {
127 if let Some(n) = fold_expr(&expr.1) {
128 FlagValue::Literal(n)
129 } else {
130 FlagValue::Deferred
131 }
132 } else {
133 FlagValue::Inferred(variant)
134 };
135
136 Ok(Flag { name, span, value })
137 })
138 .collect()
139 }
140
inferred_value(type_name: &Ident, previous_variants: &[Ident], repr: &Ident) -> Expr141 fn inferred_value(type_name: &Ident, previous_variants: &[Ident], repr: &Ident) -> Expr {
142 let tokens = if previous_variants.is_empty() {
143 quote!(1)
144 } else {
145 quote!(::enumflags2::_internal::next_bit(
146 #(#type_name::#previous_variants as u128)|*
147 ) as #repr)
148 };
149
150 syn::parse2(tokens).expect("couldn't parse inferred value")
151 }
152
infer_values(flags: &mut [Flag], type_name: &Ident, repr: &Ident)153 fn infer_values(flags: &mut [Flag], type_name: &Ident, repr: &Ident) {
154 let mut previous_variants: Vec<Ident> = flags
155 .iter()
156 .filter(|flag| !flag.value.is_inferred())
157 .map(|flag| flag.name.clone())
158 .collect();
159
160 for flag in flags {
161 if let FlagValue::Inferred(ref mut variant) = flag.value {
162 variant.discriminant = Some((
163 <Token![=]>::default(),
164 inferred_value(type_name, &previous_variants, repr),
165 ));
166 previous_variants.push(flag.name.clone());
167 }
168 }
169 }
170
171 /// Given a list of attributes, find the `repr`, if any, and return the integer
172 /// type specified.
extract_repr(attrs: &[syn::Attribute]) -> Result<Option<Ident>, syn::Error>173 fn extract_repr(attrs: &[syn::Attribute]) -> Result<Option<Ident>, syn::Error> {
174 use syn::{Meta, NestedMeta};
175 attrs
176 .iter()
177 .find_map(|attr| match attr.parse_meta() {
178 Err(why) => Some(Err(syn::Error::new_spanned(
179 attr,
180 format!("Couldn't parse attribute: {}", why),
181 ))),
182 Ok(Meta::List(ref meta)) if meta.path.is_ident("repr") => {
183 meta.nested.iter().find_map(|mi| match mi {
184 NestedMeta::Meta(Meta::Path(path)) => path.get_ident().cloned().map(Ok),
185 _ => None,
186 })
187 }
188 Ok(_) => None,
189 })
190 .transpose()
191 }
192
193 /// Check the repr and return the number of bits available
type_bits(ty: &Ident) -> Result<u8, syn::Error>194 fn type_bits(ty: &Ident) -> Result<u8, syn::Error> {
195 // This would be so much easier if we could just match on an Ident...
196 if ty == "usize" {
197 Err(syn::Error::new_spanned(
198 ty,
199 "#[repr(usize)] is not supported. Use u32 or u64 instead.",
200 ))
201 } else if ty == "i8"
202 || ty == "i16"
203 || ty == "i32"
204 || ty == "i64"
205 || ty == "i128"
206 || ty == "isize"
207 {
208 Err(syn::Error::new_spanned(
209 ty,
210 "Signed types in a repr are not supported.",
211 ))
212 } else if ty == "u8" {
213 Ok(8)
214 } else if ty == "u16" {
215 Ok(16)
216 } else if ty == "u32" {
217 Ok(32)
218 } else if ty == "u64" {
219 Ok(64)
220 } else if ty == "u128" {
221 Ok(128)
222 } else {
223 Err(syn::Error::new_spanned(
224 ty,
225 "repr must be an integer type for #[bitflags].",
226 ))
227 }
228 }
229
230 /// Returns deferred checks
check_flag(type_name: &Ident, flag: &Flag, bits: u8) -> Result<Option<TokenStream>, syn::Error>231 fn check_flag(type_name: &Ident, flag: &Flag, bits: u8) -> Result<Option<TokenStream>, syn::Error> {
232 use FlagValue::*;
233 match flag.value {
234 Literal(n) => {
235 if !n.is_power_of_two() {
236 Err(syn::Error::new(
237 flag.span,
238 "Flags must have exactly one set bit",
239 ))
240 } else if bits < 128 && n >= 1 << bits {
241 Err(syn::Error::new(
242 flag.span,
243 format!("Flag value out of range for u{}", bits),
244 ))
245 } else {
246 Ok(None)
247 }
248 }
249 Inferred(_) => Ok(None),
250 Deferred => {
251 let variant_name = &flag.name;
252 // MSRV: Use an unnamed constant (`const _: ...`).
253 let assertion_name = syn::Ident::new(
254 &format!("__enumflags_assertion_{}_{}", type_name, flag.name),
255 Span::call_site(),
256 ); // call_site because def_site is unstable
257
258 Ok(Some(quote_spanned!(flag.span =>
259 #[doc(hidden)]
260 const #assertion_name:
261 <<[(); (
262 (#type_name::#variant_name as u128).is_power_of_two()
263 ) as usize] as enumflags2::_internal::AssertionHelper>
264 ::Status as enumflags2::_internal::ExactlyOneBitSet>::X
265 = ();
266 )))
267 }
268 }
269 }
270
gen_enumflags(ast: &mut ItemEnum, default: Vec<Ident>) -> Result<TokenStream, syn::Error>271 fn gen_enumflags(ast: &mut ItemEnum, default: Vec<Ident>) -> Result<TokenStream, syn::Error> {
272 let ident = &ast.ident;
273
274 let span = Span::call_site();
275
276 let repr = extract_repr(&ast.attrs)?
277 .ok_or_else(|| syn::Error::new_spanned(&ident,
278 "repr attribute missing. Add #[repr(u64)] or a similar attribute to specify the size of the bitfield."))?;
279 let bits = type_bits(&repr)?;
280
281 let mut variants = collect_flags(ast.variants.iter_mut())?;
282 let deferred = variants
283 .iter()
284 .flat_map(|variant| check_flag(ident, variant, bits).transpose())
285 .collect::<Result<Vec<_>, _>>()?;
286
287 infer_values(&mut variants, ident, &repr);
288
289 if (bits as usize) < variants.len() {
290 return Err(syn::Error::new_spanned(
291 &repr,
292 format!("Not enough bits for {} flags", variants.len()),
293 ));
294 }
295
296 let std_path = quote_spanned!(span => ::enumflags2::_internal::core);
297 let variant_names = ast.variants.iter().map(|v| &v.ident).collect::<Vec<_>>();
298
299 Ok(quote_spanned! {
300 span =>
301 #ast
302 #(#deferred)*
303 impl #std_path::ops::Not for #ident {
304 type Output = ::enumflags2::BitFlags<Self>;
305 #[inline(always)]
306 fn not(self) -> Self::Output {
307 use ::enumflags2::{BitFlags, _internal::RawBitFlags};
308 unsafe { BitFlags::from_bits_unchecked(self.bits()).not() }
309 }
310 }
311
312 impl #std_path::ops::BitOr for #ident {
313 type Output = ::enumflags2::BitFlags<Self>;
314 #[inline(always)]
315 fn bitor(self, other: Self) -> Self::Output {
316 use ::enumflags2::{BitFlags, _internal::RawBitFlags};
317 unsafe { BitFlags::from_bits_unchecked(self.bits() | other.bits())}
318 }
319 }
320
321 impl #std_path::ops::BitAnd for #ident {
322 type Output = ::enumflags2::BitFlags<Self>;
323 #[inline(always)]
324 fn bitand(self, other: Self) -> Self::Output {
325 use ::enumflags2::{BitFlags, _internal::RawBitFlags};
326 unsafe { BitFlags::from_bits_unchecked(self.bits() & other.bits())}
327 }
328 }
329
330 impl #std_path::ops::BitXor for #ident {
331 type Output = ::enumflags2::BitFlags<Self>;
332 #[inline(always)]
333 fn bitxor(self, other: Self) -> Self::Output {
334 #std_path::convert::Into::<Self::Output>::into(self) ^ #std_path::convert::Into::<Self::Output>::into(other)
335 }
336 }
337
338 impl ::enumflags2::_internal::RawBitFlags for #ident {
339 type Numeric = #repr;
340
341 const EMPTY: Self::Numeric = 0;
342
343 const DEFAULT: Self::Numeric =
344 0 #(| (Self::#default as #repr))*;
345
346 const ALL_BITS: Self::Numeric =
347 0 #(| (Self::#variant_names as #repr))*;
348
349 const FLAG_LIST: &'static [Self] =
350 &[#(Self::#variant_names),*];
351
352 const BITFLAGS_TYPE_NAME : &'static str =
353 concat!("BitFlags<", stringify!(#ident), ">");
354
355 fn bits(self) -> Self::Numeric {
356 self as #repr
357 }
358 }
359
360 impl ::enumflags2::BitFlag for #ident {}
361 })
362 }
363