1 extern crate proc_macro;
2 
3 use ::proc_macro::TokenStream;
4 use ::proc_macro2::Span;
5 use ::quote::{format_ident, quote};
6 use ::syn::{
7     parse::{Parse, ParseStream},
8     parse_macro_input, parse_quote,
9     spanned::Spanned,
10     Data, DeriveInput, Error, Expr, Fields, Ident, LitInt, LitStr, Meta, Result,
11 };
12 
// Early-returns `Err(syn::Error)` from the enclosing function.
//
// * `die!(tokens => message)` attaches the error to the span of `tokens`.
// * `die!(message)` attaches the error to the macro call site.
macro_rules! die {
    ($span_source:expr => $message:expr) => {
        return Err(Error::new_spanned($span_source, $message))
    };
    ($message:expr) => {
        return Err(Error::new(Span::call_site(), $message))
    };
}
26 
literal(i: u64) -> Expr27 fn literal(i: u64) -> Expr {
28     let literal = LitInt::new(&i.to_string(), Span::call_site());
29     parse_quote! {
30         #literal
31     }
32 }
33 
34 mod kw {
35     syn::custom_keyword!(default);
36     syn::custom_keyword!(alternatives);
37 }
38 
39 struct NumEnumVariantAttributes {
40     items: syn::punctuated::Punctuated<NumEnumVariantAttributeItem, syn::Token![,]>,
41 }
42 
43 impl Parse for NumEnumVariantAttributes {
parse(input: ParseStream<'_>) -> Result<Self>44     fn parse(input: ParseStream<'_>) -> Result<Self> {
45         Ok(Self {
46             items: input.parse_terminated(NumEnumVariantAttributeItem::parse)?,
47         })
48     }
49 }
50 
51 enum NumEnumVariantAttributeItem {
52     Default(VariantDefaultAttribute),
53     Alternatives(VariantAlternativesAttribute),
54 }
55 
56 impl Parse for NumEnumVariantAttributeItem {
parse(input: ParseStream<'_>) -> Result<Self>57     fn parse(input: ParseStream<'_>) -> Result<Self> {
58         let lookahead = input.lookahead1();
59         if lookahead.peek(kw::default) {
60             input.parse().map(Self::Default)
61         } else if lookahead.peek(kw::alternatives) {
62             input.parse().map(Self::Alternatives)
63         } else {
64             Err(lookahead.error())
65         }
66     }
67 }
68 
69 struct VariantDefaultAttribute {
70     keyword: kw::default,
71 }
72 
73 impl Parse for VariantDefaultAttribute {
parse(input: ParseStream) -> Result<Self>74     fn parse(input: ParseStream) -> Result<Self> {
75         Ok(Self {
76             keyword: input.parse()?,
77         })
78     }
79 }
80 
81 impl Spanned for VariantDefaultAttribute {
span(&self) -> Span82     fn span(&self) -> Span {
83         self.keyword.span()
84     }
85 }
86 
87 struct VariantAlternativesAttribute {
88     keyword: kw::alternatives,
89     _eq_token: syn::Token![=],
90     _bracket_token: syn::token::Bracket,
91     expressions: syn::punctuated::Punctuated<Expr, syn::Token![,]>,
92 }
93 
94 impl Parse for VariantAlternativesAttribute {
parse(input: ParseStream) -> Result<Self>95     fn parse(input: ParseStream) -> Result<Self> {
96         let content;
97         let keyword = input.parse()?;
98         let _eq_token = input.parse()?;
99         let _bracket_token = syn::bracketed!(content in input);
100         let expressions = content.parse_terminated(Expr::parse)?;
101         Ok(Self {
102             keyword,
103             _eq_token,
104             _bracket_token,
105             expressions,
106         })
107     }
108 }
109 
110 impl Spanned for VariantAlternativesAttribute {
span(&self) -> Span111     fn span(&self) -> Span {
112         self.keyword.span()
113     }
114 }
115 
116 #[derive(::core::default::Default)]
117 struct AttributeSpans {
118     default: Vec<Span>,
119     alternatives: Vec<Span>,
120 }
121 
122 struct VariantInfo {
123     ident: Ident,
124     attr_spans: AttributeSpans,
125     is_default: bool,
126     canonical_value: Expr,
127     alternative_values: Vec<Expr>,
128 }
129 
130 impl VariantInfo {
all_values(&self) -> impl Iterator<Item = &Expr>131     fn all_values(&self) -> impl Iterator<Item = &Expr> {
132         ::core::iter::once(&self.canonical_value).chain(self.alternative_values.iter())
133     }
134 
is_complex(&self) -> bool135     fn is_complex(&self) -> bool {
136         !self.alternative_values.is_empty()
137     }
138 }
139 
140 struct EnumInfo {
141     name: Ident,
142     repr: Ident,
143     variants: Vec<VariantInfo>,
144 }
145 
146 impl EnumInfo {
has_default_variant(&self) -> bool147     fn has_default_variant(&self) -> bool {
148         self.default().is_some()
149     }
150 
has_complex_variant(&self) -> bool151     fn has_complex_variant(&self) -> bool {
152         self.variants.iter().any(|info| info.is_complex())
153     }
154 
default(&self) -> Option<&Ident>155     fn default(&self) -> Option<&Ident> {
156         self.variants
157             .iter()
158             .find(|info| info.is_default)
159             .map(|info| &info.ident)
160     }
161 
first_default_attr_span(&self) -> Option<&Span>162     fn first_default_attr_span(&self) -> Option<&Span> {
163         self.variants
164             .iter()
165             .find_map(|info| info.attr_spans.default.first())
166     }
167 
first_alternatives_attr_span(&self) -> Option<&Span>168     fn first_alternatives_attr_span(&self) -> Option<&Span> {
169         self.variants
170             .iter()
171             .find_map(|info| info.attr_spans.alternatives.first())
172     }
173 
variant_idents(&self) -> Vec<Ident>174     fn variant_idents(&self) -> Vec<Ident> {
175         self.variants
176             .iter()
177             .map(|variant| variant.ident.clone())
178             .collect()
179     }
180 
expression_idents(&self) -> Vec<Vec<Ident>>181     fn expression_idents(&self) -> Vec<Vec<Ident>> {
182         self.variants
183             .iter()
184             .map(|info| {
185                 let indices = 0..(info.alternative_values.len() + 1);
186                 indices
187                     .map(|index| format_ident!("{}__num_enum_{}__", info.ident, index))
188                     .collect()
189             })
190             .collect()
191     }
192 
variant_expressions(&self) -> Vec<Vec<Expr>>193     fn variant_expressions(&self) -> Vec<Vec<Expr>> {
194         self.variants
195             .iter()
196             .map(|variant| variant.all_values().cloned().collect())
197             .collect()
198     }
199 }
200 
201 impl Parse for EnumInfo {
parse(input: ParseStream) -> Result<Self>202     fn parse(input: ParseStream) -> Result<Self> {
203         Ok({
204             let input: DeriveInput = input.parse()?;
205             let name = input.ident;
206             let data = match input.data {
207                 Data::Enum(data) => data,
208                 Data::Union(data) => die!(data.union_token => "Expected enum but found union"),
209                 Data::Struct(data) => die!(data.struct_token => "Expected enum but found struct"),
210             };
211 
212             let repr: Ident = {
213                 let mut attrs = input.attrs.into_iter();
214                 loop {
215                     if let Some(attr) = attrs.next() {
216                         if let Ok(Meta::List(meta_list)) = attr.parse_meta() {
217                             if let Some(ident) = meta_list.path.get_ident() {
218                                 if ident == "repr" {
219                                     let mut nested = meta_list.nested.iter();
220                                     if nested.len() != 1 {
221                                         die!(attr =>
222                                             "Expected exactly one `repr` argument"
223                                         );
224                                     }
225                                     let repr = nested.next().unwrap();
226                                     let repr: Ident = parse_quote! {
227                                         #repr
228                                     };
229                                     if repr == "C" {
230                                         die!(repr =>
231                                             "repr(C) doesn't have a well defined size"
232                                         );
233                                     } else {
234                                         break repr;
235                                     }
236                                 }
237                             }
238                         }
239                     } else {
240                         die!("Missing `#[repr({Integer})]` attribute");
241                     }
242                 }
243             };
244 
245             let mut variants: Vec<VariantInfo> = vec![];
246             let mut has_default_variant: bool = false;
247 
248             let mut next_discriminant = literal(0);
249             for variant in data.variants.into_iter() {
250                 let ident = variant.ident.clone();
251 
252                 match &variant.fields {
253                     Fields::Named(_) | Fields::Unnamed(_) => {
254                         die!(variant => format!("`{}` only supports unit variants (with no associated data), but `{}::{}` was not a unit variant.", get_crate_name(), name, ident));
255                     }
256                     Fields::Unit => {}
257                 }
258 
259                 let discriminant = match variant.discriminant {
260                     Some(d) => d.1,
261                     None => next_discriminant.clone(),
262                 };
263 
264                 let mut attr_spans: AttributeSpans = Default::default();
265                 let mut alternative_values: Vec<Expr> = vec![];
266 
267                 // `#[num_enum(default)]` is required by `#[derive(FromPrimitive)]`
268                 // and forbidden by `#[derive(UnsafeFromPrimitive)]`, so we need to
269                 // keep track of whether we encountered such an attribute:
270                 let mut is_default: bool = false;
271 
272                 for attribute in variant.attrs {
273                     if attribute.path.is_ident("default") {
274                         if has_default_variant {
275                             die!(attribute =>
276                                 "Multiple variants marked `#[default]` or `#[num_enum(default)]` found"
277                             );
278                         }
279                         attr_spans.default.push(attribute.span());
280                         is_default = true;
281                     } else if attribute.path.is_ident("num_enum") {
282                         match attribute.parse_args_with(NumEnumVariantAttributes::parse) {
283                             Ok(variant_attributes) => {
284                                 for variant_attribute in variant_attributes.items {
285                                     match variant_attribute {
286                                         NumEnumVariantAttributeItem::Default(default) => {
287                                             if has_default_variant {
288                                                 die!(default.keyword =>
289                                                     "Multiple variants marked `#[default]` or `#[num_enum(default)]` found"
290                                                 );
291                                             }
292                                             attr_spans.default.push(default.span());
293                                             is_default = true;
294                                         }
295                                         NumEnumVariantAttributeItem::Alternatives(alternatives) => {
296                                             attr_spans.alternatives.push(alternatives.span());
297                                             alternative_values.extend(alternatives.expressions);
298                                         }
299                                     }
300                                 }
301                             }
302                             Err(err) => {
303                                 die!(attribute =>
304                                     format!("Invalid attribute: {}", err)
305                                 );
306                             }
307                         }
308                     } else {
309                         continue;
310                     }
311 
312                     has_default_variant |= is_default;
313                 }
314 
315                 let canonical_value = discriminant.clone();
316 
317                 variants.push(VariantInfo {
318                     ident,
319                     attr_spans,
320                     is_default,
321                     canonical_value,
322                     alternative_values,
323                 });
324 
325                 next_discriminant = parse_quote! {
326                     #repr::wrapping_add(#discriminant, 1)
327                 };
328             }
329 
330             EnumInfo {
331                 name,
332                 repr,
333                 variants,
334             }
335         })
336     }
337 }
338 
339 /// Implements `Into<Primitive>` for a `#[repr(Primitive)] enum`.
340 ///
341 /// (It actually implements `From<Enum> for Primitive`)
342 ///
343 /// ## Allows turning an enum into a primitive.
344 ///
345 /// ```rust
346 /// use num_enum::IntoPrimitive;
347 ///
348 /// #[derive(IntoPrimitive)]
349 /// #[repr(u8)]
350 /// enum Number {
351 ///     Zero,
352 ///     One,
353 /// }
354 ///
355 /// let zero: u8 = Number::Zero.into();
356 /// assert_eq!(zero, 0u8);
357 /// ```
358 #[proc_macro_derive(IntoPrimitive)]
derive_into_primitive(input: TokenStream) -> TokenStream359 pub fn derive_into_primitive(input: TokenStream) -> TokenStream {
360     let EnumInfo { name, repr, .. } = parse_macro_input!(input as EnumInfo);
361 
362     TokenStream::from(quote! {
363         impl From<#name> for #repr {
364             #[inline]
365             fn from (enum_value: #name) -> Self
366             {
367                 enum_value as Self
368             }
369         }
370     })
371 }
372 
373 /// Implements `From<Primitive>` for a `#[repr(Primitive)] enum`.
374 ///
375 /// Turning a primitive into an enum with `from`.
376 /// ----------------------------------------------
377 ///
378 /// ```rust
379 /// use num_enum::FromPrimitive;
380 ///
381 /// #[derive(Debug, Eq, PartialEq, FromPrimitive)]
382 /// #[repr(u8)]
383 /// enum Number {
384 ///     Zero,
385 ///     #[num_enum(default)]
386 ///     NonZero,
387 /// }
388 ///
389 /// let zero = Number::from(0u8);
390 /// assert_eq!(zero, Number::Zero);
391 ///
392 /// let one = Number::from(1u8);
393 /// assert_eq!(one, Number::NonZero);
394 ///
395 /// let two = Number::from(2u8);
396 /// assert_eq!(two, Number::NonZero);
397 /// ```
398 #[proc_macro_derive(FromPrimitive, attributes(num_enum, default))]
derive_from_primitive(input: TokenStream) -> TokenStream399 pub fn derive_from_primitive(input: TokenStream) -> TokenStream {
400     let enum_info: EnumInfo = parse_macro_input!(input);
401     let krate = Ident::new(&get_crate_name(), Span::call_site());
402 
403     let default_ident: Ident = match enum_info.default() {
404         Some(ident) => ident.clone(),
405         None => {
406             let span = Span::call_site();
407             let message =
408                 "#[derive(FromPrimitive)] requires a variant marked with `#[default]` or `#[num_enum(default)]`";
409             return syn::Error::new(span, message).to_compile_error().into();
410         }
411     };
412 
413     let EnumInfo {
414         ref name, ref repr, ..
415     } = enum_info;
416 
417     let variant_idents: Vec<Ident> = enum_info.variant_idents();
418     let expression_idents: Vec<Vec<Ident>> = enum_info.expression_idents();
419     let variant_expressions: Vec<Vec<Expr>> = enum_info.variant_expressions();
420 
421     debug_assert_eq!(variant_idents.len(), variant_expressions.len());
422 
423     TokenStream::from(quote! {
424         impl ::#krate::FromPrimitive for #name {
425             type Primitive = #repr;
426 
427             fn from_primitive(number: Self::Primitive) -> Self {
428                 // Use intermediate const(s) so that enums defined like
429                 // `Two = ONE + 1u8` work properly.
430                 #![allow(non_upper_case_globals)]
431                 #(
432                     #(
433                         const #expression_idents: #repr = #variant_expressions;
434                     )*
435                 )*
436                 #[deny(unreachable_patterns)]
437                 match number {
438                     #(
439                         #( #expression_idents )|*
440                         => Self::#variant_idents,
441                     )*
442                     #[allow(unreachable_patterns)]
443                     _ => Self::#default_ident,
444                 }
445             }
446         }
447 
448         impl ::core::convert::From<#repr> for #name {
449             #[inline]
450             fn from (
451                 number: #repr,
452             ) -> Self {
453                 ::#krate::FromPrimitive::from_primitive(number)
454             }
455         }
456 
457         // The Rust stdlib will implement `#name: From<#repr>` for us for free!
458 
459         impl ::#krate::TryFromPrimitive for #name {
460             type Primitive = #repr;
461 
462             const NAME: &'static str = stringify!(#name);
463 
464             #[inline]
465             fn try_from_primitive (
466                 number: Self::Primitive,
467             ) -> ::core::result::Result<
468                 Self,
469                 ::#krate::TryFromPrimitiveError<Self>,
470             >
471             {
472                 Ok(::#krate::FromPrimitive::from_primitive(number))
473             }
474         }
475     })
476 }
477 
478 /// Implements `TryFrom<Primitive>` for a `#[repr(Primitive)] enum`.
479 ///
480 /// Attempting to turn a primitive into an enum with `try_from`.
481 /// ----------------------------------------------
482 ///
483 /// ```rust
484 /// use num_enum::TryFromPrimitive;
485 /// use std::convert::TryFrom;
486 ///
487 /// #[derive(Debug, Eq, PartialEq, TryFromPrimitive)]
488 /// #[repr(u8)]
489 /// enum Number {
490 ///     Zero,
491 ///     One,
492 /// }
493 ///
494 /// let zero = Number::try_from(0u8);
495 /// assert_eq!(zero, Ok(Number::Zero));
496 ///
497 /// let three = Number::try_from(3u8);
498 /// assert_eq!(
499 ///     three.unwrap_err().to_string(),
500 ///     "No discriminant in enum `Number` matches the value `3`",
501 /// );
502 /// ```
503 #[proc_macro_derive(TryFromPrimitive, attributes(num_enum))]
derive_try_from_primitive(input: TokenStream) -> TokenStream504 pub fn derive_try_from_primitive(input: TokenStream) -> TokenStream {
505     let enum_info: EnumInfo = parse_macro_input!(input);
506     let krate = Ident::new(&get_crate_name(), Span::call_site());
507 
508     let EnumInfo {
509         ref name, ref repr, ..
510     } = enum_info;
511 
512     let variant_idents: Vec<Ident> = enum_info.variant_idents();
513     let expression_idents: Vec<Vec<Ident>> = enum_info.expression_idents();
514     let variant_expressions: Vec<Vec<Expr>> = enum_info.variant_expressions();
515 
516     debug_assert_eq!(variant_idents.len(), variant_expressions.len());
517 
518     let default_arm = match enum_info.default() {
519         Some(ident) => {
520             quote! {
521                 _ => ::core::result::Result::Ok(
522                     #name::#ident
523                 )
524             }
525         }
526         None => {
527             quote! {
528                 _ => ::core::result::Result::Err(
529                     ::#krate::TryFromPrimitiveError { number }
530                 )
531             }
532         }
533     };
534 
535     TokenStream::from(quote! {
536         impl ::#krate::TryFromPrimitive for #name {
537             type Primitive = #repr;
538 
539             const NAME: &'static str = stringify!(#name);
540 
541             fn try_from_primitive (
542                 number: Self::Primitive,
543             ) -> ::core::result::Result<
544                 Self,
545                 ::#krate::TryFromPrimitiveError<Self>
546             > {
547                 // Use intermediate const(s) so that enums defined like
548                 // `Two = ONE + 1u8` work properly.
549                 #![allow(non_upper_case_globals)]
550                 #(
551                     #(
552                         const #expression_idents: #repr = #variant_expressions;
553                     )*
554                 )*
555                 #[deny(unreachable_patterns)]
556                 match number {
557                     #(
558                         #( #expression_idents )|*
559                         => ::core::result::Result::Ok(Self::#variant_idents),
560                     )*
561                     #[allow(unreachable_patterns)]
562                     #default_arm,
563                 }
564             }
565         }
566 
567         impl ::core::convert::TryFrom<#repr> for #name {
568             type Error = ::#krate::TryFromPrimitiveError<Self>;
569 
570             #[inline]
571             fn try_from (
572                 number: #repr,
573             ) -> ::core::result::Result<Self, ::#krate::TryFromPrimitiveError<Self>>
574             {
575                 ::#krate::TryFromPrimitive::try_from_primitive(number)
576             }
577         }
578     })
579 }
580 
/// Name under which the `num_enum` crate is visible to the deriving crate, honouring renames
/// in its `Cargo.toml`.
///
/// Falls back to `num_enum` (with a warning on stderr) when the lookup fails.
#[cfg(feature = "proc-macro-crate")]
fn get_crate_name() -> String {
    use proc_macro_crate::FoundCrate;

    match proc_macro_crate::crate_name("num_enum") {
        Ok(FoundCrate::Name(renamed)) => renamed,
        Ok(FoundCrate::Itself) => String::from("num_enum"),
        Err(err) => {
            eprintln!("Warning: {}\n    => defaulting to `num_enum`", err,);
            String::from("num_enum")
        }
    }
}
593 
// Don't depend on proc-macro-crate in no_std environments because it causes an awkward dependency
// on serde with std.
//
// no_std dependees on num_enum cannot rename the num_enum crate when they depend on it. Sorry.
//
// See https://github.com/illicitonion/num_enum/issues/18
#[cfg(not(feature = "proc-macro-crate"))]
fn get_crate_name() -> String {
    "num_enum".to_owned()
}
604 
605 /// Generates a `unsafe fn from_unchecked (number: Primitive) -> Self`
606 /// associated function.
607 ///
608 /// Allows unsafely turning a primitive into an enum with from_unchecked.
609 /// -------------------------------------------------------------
610 ///
611 /// If you're really certain a conversion will succeed, and want to avoid a small amount of overhead, you can use unsafe
612 /// code to do this conversion. Unless you have data showing that the match statement generated in the `try_from` above is a
613 /// bottleneck for you, you should avoid doing this, as the unsafe code has potential to cause serious memory issues in
614 /// your program.
615 ///
616 /// ```rust
617 /// use num_enum::UnsafeFromPrimitive;
618 ///
619 /// #[derive(Debug, Eq, PartialEq, UnsafeFromPrimitive)]
620 /// #[repr(u8)]
621 /// enum Number {
622 ///     Zero,
623 ///     One,
624 /// }
625 ///
626 /// fn main() {
627 ///     assert_eq!(
628 ///         Number::Zero,
629 ///         unsafe { Number::from_unchecked(0_u8) },
630 ///     );
631 ///     assert_eq!(
632 ///         Number::One,
633 ///         unsafe { Number::from_unchecked(1_u8) },
634 ///     );
635 /// }
636 ///
637 /// unsafe fn undefined_behavior() {
638 ///     let _ = Number::from_unchecked(2); // 2 is not a valid discriminant!
639 /// }
640 /// ```
641 #[proc_macro_derive(UnsafeFromPrimitive, attributes(num_enum))]
derive_unsafe_from_primitive(stream: TokenStream) -> TokenStream642 pub fn derive_unsafe_from_primitive(stream: TokenStream) -> TokenStream {
643     let enum_info = parse_macro_input!(stream as EnumInfo);
644 
645     if enum_info.has_default_variant() {
646         let span = enum_info
647             .first_default_attr_span()
648             .cloned()
649             .expect("Expected span");
650         let message = "#[derive(UnsafeFromPrimitive)] does not support `#[num_enum(default)]`";
651         return syn::Error::new(span, message).to_compile_error().into();
652     }
653 
654     if enum_info.has_complex_variant() {
655         let span = enum_info
656             .first_alternatives_attr_span()
657             .cloned()
658             .expect("Expected span");
659         let message =
660             "#[derive(UnsafeFromPrimitive)] does not support `#[num_enum(alternatives = [..])]`";
661         return syn::Error::new(span, message).to_compile_error().into();
662     }
663 
664     let EnumInfo {
665         ref name, ref repr, ..
666     } = enum_info;
667 
668     let doc_string = LitStr::new(
669         &format!(
670             r#"
671 Transmutes `number: {repr}` into a [`{name}`].
672 
673 # Safety
674 
675   - `number` must represent a valid discriminant of [`{name}`]
676 "#,
677             repr = repr,
678             name = name,
679         ),
680         Span::call_site(),
681     );
682 
683     TokenStream::from(quote! {
684         impl #name {
685             #[doc = #doc_string]
686             #[inline]
687             pub unsafe fn from_unchecked(number: #repr) -> Self {
688                 ::core::mem::transmute(number)
689             }
690         }
691     })
692 }
693 
694 /// Implements `core::default::Default` for a `#[repr(Primitive)] enum`.
695 ///
696 /// Whichever variant has the `#[default]` or `#[num_enum(default)]` attribute will be returned.
697 /// ----------------------------------------------
698 ///
699 /// ```rust
700 /// #[derive(Debug, Eq, PartialEq, num_enum::Default)]
701 /// #[repr(u8)]
702 /// enum Number {
703 ///     Zero,
704 ///     #[default]
705 ///     One,
706 /// }
707 ///
708 /// assert_eq!(Number::One, Number::default());
709 /// assert_eq!(Number::One, <Number as ::core::default::Default>::default());
710 /// ```
711 #[proc_macro_derive(Default, attributes(num_enum, default))]
derive_default(stream: TokenStream) -> TokenStream712 pub fn derive_default(stream: TokenStream) -> TokenStream {
713     let enum_info = parse_macro_input!(stream as EnumInfo);
714 
715     let default_ident = match enum_info.default() {
716         Some(ident) => ident,
717         None => {
718             let span = Span::call_site();
719             let message =
720                 "#[derive(num_enum::Default)] requires a variant marked with `#[default]` or `#[num_enum(default)]`";
721             return syn::Error::new(span, message).to_compile_error().into();
722         }
723     };
724 
725     let EnumInfo { ref name, .. } = enum_info;
726 
727     TokenStream::from(quote! {
728         impl ::core::default::Default for #name {
729             #[inline]
730             fn default() -> Self {
731                 Self::#default_ident
732             }
733         }
734     })
735 }
736