1 extern crate proc_macro;
2
3 use ::proc_macro::TokenStream;
4 use ::proc_macro2::Span;
5 use ::quote::{format_ident, quote};
6 use ::syn::{
7 parse::{Parse, ParseStream},
8 parse_macro_input, parse_quote,
9 spanned::Spanned,
10 Data, DeriveInput, Error, Expr, Fields, Ident, LitInt, LitStr, Meta, Result,
11 };
12
/// Early-returns a [`syn::Error`] from the enclosing `Result`-returning function.
///
/// * `die!(tokens => message)` attaches the error to the span of `tokens`.
/// * `die!(message)` attaches the error to the macro call site.
macro_rules! die {
    ($spanned:expr => $msg:expr) => {
        return ::core::result::Result::Err(Error::new_spanned($spanned, $msg))
    };
    ($msg:expr) => {
        return ::core::result::Result::Err(Error::new(Span::call_site(), $msg))
    };
}
26
literal(i: u64) -> Expr27 fn literal(i: u64) -> Expr {
28 let literal = LitInt::new(&i.to_string(), Span::call_site());
29 parse_quote! {
30 #literal
31 }
32 }
33
34 mod kw {
35 syn::custom_keyword!(default);
36 syn::custom_keyword!(alternatives);
37 }
38
39 struct NumEnumVariantAttributes {
40 items: syn::punctuated::Punctuated<NumEnumVariantAttributeItem, syn::Token![,]>,
41 }
42
43 impl Parse for NumEnumVariantAttributes {
parse(input: ParseStream<'_>) -> Result<Self>44 fn parse(input: ParseStream<'_>) -> Result<Self> {
45 Ok(Self {
46 items: input.parse_terminated(NumEnumVariantAttributeItem::parse)?,
47 })
48 }
49 }
50
51 enum NumEnumVariantAttributeItem {
52 Default(VariantDefaultAttribute),
53 Alternatives(VariantAlternativesAttribute),
54 }
55
56 impl Parse for NumEnumVariantAttributeItem {
parse(input: ParseStream<'_>) -> Result<Self>57 fn parse(input: ParseStream<'_>) -> Result<Self> {
58 let lookahead = input.lookahead1();
59 if lookahead.peek(kw::default) {
60 input.parse().map(Self::Default)
61 } else if lookahead.peek(kw::alternatives) {
62 input.parse().map(Self::Alternatives)
63 } else {
64 Err(lookahead.error())
65 }
66 }
67 }
68
69 struct VariantDefaultAttribute {
70 keyword: kw::default,
71 }
72
73 impl Parse for VariantDefaultAttribute {
parse(input: ParseStream) -> Result<Self>74 fn parse(input: ParseStream) -> Result<Self> {
75 Ok(Self {
76 keyword: input.parse()?,
77 })
78 }
79 }
80
81 impl Spanned for VariantDefaultAttribute {
span(&self) -> Span82 fn span(&self) -> Span {
83 self.keyword.span()
84 }
85 }
86
87 struct VariantAlternativesAttribute {
88 keyword: kw::alternatives,
89 _eq_token: syn::Token![=],
90 _bracket_token: syn::token::Bracket,
91 expressions: syn::punctuated::Punctuated<Expr, syn::Token![,]>,
92 }
93
94 impl Parse for VariantAlternativesAttribute {
parse(input: ParseStream) -> Result<Self>95 fn parse(input: ParseStream) -> Result<Self> {
96 let content;
97 let keyword = input.parse()?;
98 let _eq_token = input.parse()?;
99 let _bracket_token = syn::bracketed!(content in input);
100 let expressions = content.parse_terminated(Expr::parse)?;
101 Ok(Self {
102 keyword,
103 _eq_token,
104 _bracket_token,
105 expressions,
106 })
107 }
108 }
109
110 impl Spanned for VariantAlternativesAttribute {
span(&self) -> Span111 fn span(&self) -> Span {
112 self.keyword.span()
113 }
114 }
115
116 #[derive(::core::default::Default)]
117 struct AttributeSpans {
118 default: Vec<Span>,
119 alternatives: Vec<Span>,
120 }
121
122 struct VariantInfo {
123 ident: Ident,
124 attr_spans: AttributeSpans,
125 is_default: bool,
126 canonical_value: Expr,
127 alternative_values: Vec<Expr>,
128 }
129
130 impl VariantInfo {
all_values(&self) -> impl Iterator<Item = &Expr>131 fn all_values(&self) -> impl Iterator<Item = &Expr> {
132 ::core::iter::once(&self.canonical_value).chain(self.alternative_values.iter())
133 }
134
is_complex(&self) -> bool135 fn is_complex(&self) -> bool {
136 !self.alternative_values.is_empty()
137 }
138 }
139
140 struct EnumInfo {
141 name: Ident,
142 repr: Ident,
143 variants: Vec<VariantInfo>,
144 }
145
146 impl EnumInfo {
has_default_variant(&self) -> bool147 fn has_default_variant(&self) -> bool {
148 self.default().is_some()
149 }
150
has_complex_variant(&self) -> bool151 fn has_complex_variant(&self) -> bool {
152 self.variants.iter().any(|info| info.is_complex())
153 }
154
default(&self) -> Option<&Ident>155 fn default(&self) -> Option<&Ident> {
156 self.variants
157 .iter()
158 .find(|info| info.is_default)
159 .map(|info| &info.ident)
160 }
161
first_default_attr_span(&self) -> Option<&Span>162 fn first_default_attr_span(&self) -> Option<&Span> {
163 self.variants
164 .iter()
165 .find_map(|info| info.attr_spans.default.first())
166 }
167
first_alternatives_attr_span(&self) -> Option<&Span>168 fn first_alternatives_attr_span(&self) -> Option<&Span> {
169 self.variants
170 .iter()
171 .find_map(|info| info.attr_spans.alternatives.first())
172 }
173
variant_idents(&self) -> Vec<Ident>174 fn variant_idents(&self) -> Vec<Ident> {
175 self.variants
176 .iter()
177 .map(|variant| variant.ident.clone())
178 .collect()
179 }
180
expression_idents(&self) -> Vec<Vec<Ident>>181 fn expression_idents(&self) -> Vec<Vec<Ident>> {
182 self.variants
183 .iter()
184 .map(|info| {
185 let indices = 0..(info.alternative_values.len() + 1);
186 indices
187 .map(|index| format_ident!("{}__num_enum_{}__", info.ident, index))
188 .collect()
189 })
190 .collect()
191 }
192
variant_expressions(&self) -> Vec<Vec<Expr>>193 fn variant_expressions(&self) -> Vec<Vec<Expr>> {
194 self.variants
195 .iter()
196 .map(|variant| variant.all_values().cloned().collect())
197 .collect()
198 }
199 }
200
201 impl Parse for EnumInfo {
parse(input: ParseStream) -> Result<Self>202 fn parse(input: ParseStream) -> Result<Self> {
203 Ok({
204 let input: DeriveInput = input.parse()?;
205 let name = input.ident;
206 let data = match input.data {
207 Data::Enum(data) => data,
208 Data::Union(data) => die!(data.union_token => "Expected enum but found union"),
209 Data::Struct(data) => die!(data.struct_token => "Expected enum but found struct"),
210 };
211
212 let repr: Ident = {
213 let mut attrs = input.attrs.into_iter();
214 loop {
215 if let Some(attr) = attrs.next() {
216 if let Ok(Meta::List(meta_list)) = attr.parse_meta() {
217 if let Some(ident) = meta_list.path.get_ident() {
218 if ident == "repr" {
219 let mut nested = meta_list.nested.iter();
220 if nested.len() != 1 {
221 die!(attr =>
222 "Expected exactly one `repr` argument"
223 );
224 }
225 let repr = nested.next().unwrap();
226 let repr: Ident = parse_quote! {
227 #repr
228 };
229 if repr == "C" {
230 die!(repr =>
231 "repr(C) doesn't have a well defined size"
232 );
233 } else {
234 break repr;
235 }
236 }
237 }
238 }
239 } else {
240 die!("Missing `#[repr({Integer})]` attribute");
241 }
242 }
243 };
244
245 let mut variants: Vec<VariantInfo> = vec![];
246 let mut has_default_variant: bool = false;
247
248 let mut next_discriminant = literal(0);
249 for variant in data.variants.into_iter() {
250 let ident = variant.ident.clone();
251
252 match &variant.fields {
253 Fields::Named(_) | Fields::Unnamed(_) => {
254 die!(variant => format!("`{}` only supports unit variants (with no associated data), but `{}::{}` was not a unit variant.", get_crate_name(), name, ident));
255 }
256 Fields::Unit => {}
257 }
258
259 let discriminant = match variant.discriminant {
260 Some(d) => d.1,
261 None => next_discriminant.clone(),
262 };
263
264 let mut attr_spans: AttributeSpans = Default::default();
265 let mut alternative_values: Vec<Expr> = vec![];
266
267 // `#[num_enum(default)]` is required by `#[derive(FromPrimitive)]`
268 // and forbidden by `#[derive(UnsafeFromPrimitive)]`, so we need to
269 // keep track of whether we encountered such an attribute:
270 let mut is_default: bool = false;
271
272 for attribute in variant.attrs {
273 if attribute.path.is_ident("default") {
274 if has_default_variant {
275 die!(attribute =>
276 "Multiple variants marked `#[default]` or `#[num_enum(default)]` found"
277 );
278 }
279 attr_spans.default.push(attribute.span());
280 is_default = true;
281 } else if attribute.path.is_ident("num_enum") {
282 match attribute.parse_args_with(NumEnumVariantAttributes::parse) {
283 Ok(variant_attributes) => {
284 for variant_attribute in variant_attributes.items {
285 match variant_attribute {
286 NumEnumVariantAttributeItem::Default(default) => {
287 if has_default_variant {
288 die!(default.keyword =>
289 "Multiple variants marked `#[default]` or `#[num_enum(default)]` found"
290 );
291 }
292 attr_spans.default.push(default.span());
293 is_default = true;
294 }
295 NumEnumVariantAttributeItem::Alternatives(alternatives) => {
296 attr_spans.alternatives.push(alternatives.span());
297 alternative_values.extend(alternatives.expressions);
298 }
299 }
300 }
301 }
302 Err(err) => {
303 die!(attribute =>
304 format!("Invalid attribute: {}", err)
305 );
306 }
307 }
308 } else {
309 continue;
310 }
311
312 has_default_variant |= is_default;
313 }
314
315 let canonical_value = discriminant.clone();
316
317 variants.push(VariantInfo {
318 ident,
319 attr_spans,
320 is_default,
321 canonical_value,
322 alternative_values,
323 });
324
325 next_discriminant = parse_quote! {
326 #repr::wrapping_add(#discriminant, 1)
327 };
328 }
329
330 EnumInfo {
331 name,
332 repr,
333 variants,
334 }
335 })
336 }
337 }
338
339 /// Implements `Into<Primitive>` for a `#[repr(Primitive)] enum`.
340 ///
341 /// (It actually implements `From<Enum> for Primitive`)
342 ///
343 /// ## Allows turning an enum into a primitive.
344 ///
345 /// ```rust
346 /// use num_enum::IntoPrimitive;
347 ///
348 /// #[derive(IntoPrimitive)]
349 /// #[repr(u8)]
350 /// enum Number {
351 /// Zero,
352 /// One,
353 /// }
354 ///
355 /// let zero: u8 = Number::Zero.into();
356 /// assert_eq!(zero, 0u8);
357 /// ```
358 #[proc_macro_derive(IntoPrimitive)]
derive_into_primitive(input: TokenStream) -> TokenStream359 pub fn derive_into_primitive(input: TokenStream) -> TokenStream {
360 let EnumInfo { name, repr, .. } = parse_macro_input!(input as EnumInfo);
361
362 TokenStream::from(quote! {
363 impl From<#name> for #repr {
364 #[inline]
365 fn from (enum_value: #name) -> Self
366 {
367 enum_value as Self
368 }
369 }
370 })
371 }
372
373 /// Implements `From<Primitive>` for a `#[repr(Primitive)] enum`.
374 ///
375 /// Turning a primitive into an enum with `from`.
376 /// ----------------------------------------------
377 ///
378 /// ```rust
379 /// use num_enum::FromPrimitive;
380 ///
381 /// #[derive(Debug, Eq, PartialEq, FromPrimitive)]
382 /// #[repr(u8)]
383 /// enum Number {
384 /// Zero,
385 /// #[num_enum(default)]
386 /// NonZero,
387 /// }
388 ///
389 /// let zero = Number::from(0u8);
390 /// assert_eq!(zero, Number::Zero);
391 ///
392 /// let one = Number::from(1u8);
393 /// assert_eq!(one, Number::NonZero);
394 ///
395 /// let two = Number::from(2u8);
396 /// assert_eq!(two, Number::NonZero);
397 /// ```
398 #[proc_macro_derive(FromPrimitive, attributes(num_enum, default))]
derive_from_primitive(input: TokenStream) -> TokenStream399 pub fn derive_from_primitive(input: TokenStream) -> TokenStream {
400 let enum_info: EnumInfo = parse_macro_input!(input);
401 let krate = Ident::new(&get_crate_name(), Span::call_site());
402
403 let default_ident: Ident = match enum_info.default() {
404 Some(ident) => ident.clone(),
405 None => {
406 let span = Span::call_site();
407 let message =
408 "#[derive(FromPrimitive)] requires a variant marked with `#[default]` or `#[num_enum(default)]`";
409 return syn::Error::new(span, message).to_compile_error().into();
410 }
411 };
412
413 let EnumInfo {
414 ref name, ref repr, ..
415 } = enum_info;
416
417 let variant_idents: Vec<Ident> = enum_info.variant_idents();
418 let expression_idents: Vec<Vec<Ident>> = enum_info.expression_idents();
419 let variant_expressions: Vec<Vec<Expr>> = enum_info.variant_expressions();
420
421 debug_assert_eq!(variant_idents.len(), variant_expressions.len());
422
423 TokenStream::from(quote! {
424 impl ::#krate::FromPrimitive for #name {
425 type Primitive = #repr;
426
427 fn from_primitive(number: Self::Primitive) -> Self {
428 // Use intermediate const(s) so that enums defined like
429 // `Two = ONE + 1u8` work properly.
430 #![allow(non_upper_case_globals)]
431 #(
432 #(
433 const #expression_idents: #repr = #variant_expressions;
434 )*
435 )*
436 #[deny(unreachable_patterns)]
437 match number {
438 #(
439 #( #expression_idents )|*
440 => Self::#variant_idents,
441 )*
442 #[allow(unreachable_patterns)]
443 _ => Self::#default_ident,
444 }
445 }
446 }
447
448 impl ::core::convert::From<#repr> for #name {
449 #[inline]
450 fn from (
451 number: #repr,
452 ) -> Self {
453 ::#krate::FromPrimitive::from_primitive(number)
454 }
455 }
456
457 // The Rust stdlib will implement `#name: From<#repr>` for us for free!
458
459 impl ::#krate::TryFromPrimitive for #name {
460 type Primitive = #repr;
461
462 const NAME: &'static str = stringify!(#name);
463
464 #[inline]
465 fn try_from_primitive (
466 number: Self::Primitive,
467 ) -> ::core::result::Result<
468 Self,
469 ::#krate::TryFromPrimitiveError<Self>,
470 >
471 {
472 Ok(::#krate::FromPrimitive::from_primitive(number))
473 }
474 }
475 })
476 }
477
478 /// Implements `TryFrom<Primitive>` for a `#[repr(Primitive)] enum`.
479 ///
480 /// Attempting to turn a primitive into an enum with `try_from`.
481 /// ----------------------------------------------
482 ///
483 /// ```rust
484 /// use num_enum::TryFromPrimitive;
485 /// use std::convert::TryFrom;
486 ///
487 /// #[derive(Debug, Eq, PartialEq, TryFromPrimitive)]
488 /// #[repr(u8)]
489 /// enum Number {
490 /// Zero,
491 /// One,
492 /// }
493 ///
494 /// let zero = Number::try_from(0u8);
495 /// assert_eq!(zero, Ok(Number::Zero));
496 ///
497 /// let three = Number::try_from(3u8);
498 /// assert_eq!(
499 /// three.unwrap_err().to_string(),
500 /// "No discriminant in enum `Number` matches the value `3`",
501 /// );
502 /// ```
503 #[proc_macro_derive(TryFromPrimitive, attributes(num_enum))]
derive_try_from_primitive(input: TokenStream) -> TokenStream504 pub fn derive_try_from_primitive(input: TokenStream) -> TokenStream {
505 let enum_info: EnumInfo = parse_macro_input!(input);
506 let krate = Ident::new(&get_crate_name(), Span::call_site());
507
508 let EnumInfo {
509 ref name, ref repr, ..
510 } = enum_info;
511
512 let variant_idents: Vec<Ident> = enum_info.variant_idents();
513 let expression_idents: Vec<Vec<Ident>> = enum_info.expression_idents();
514 let variant_expressions: Vec<Vec<Expr>> = enum_info.variant_expressions();
515
516 debug_assert_eq!(variant_idents.len(), variant_expressions.len());
517
518 let default_arm = match enum_info.default() {
519 Some(ident) => {
520 quote! {
521 _ => ::core::result::Result::Ok(
522 #name::#ident
523 )
524 }
525 }
526 None => {
527 quote! {
528 _ => ::core::result::Result::Err(
529 ::#krate::TryFromPrimitiveError { number }
530 )
531 }
532 }
533 };
534
535 TokenStream::from(quote! {
536 impl ::#krate::TryFromPrimitive for #name {
537 type Primitive = #repr;
538
539 const NAME: &'static str = stringify!(#name);
540
541 fn try_from_primitive (
542 number: Self::Primitive,
543 ) -> ::core::result::Result<
544 Self,
545 ::#krate::TryFromPrimitiveError<Self>
546 > {
547 // Use intermediate const(s) so that enums defined like
548 // `Two = ONE + 1u8` work properly.
549 #![allow(non_upper_case_globals)]
550 #(
551 #(
552 const #expression_idents: #repr = #variant_expressions;
553 )*
554 )*
555 #[deny(unreachable_patterns)]
556 match number {
557 #(
558 #( #expression_idents )|*
559 => ::core::result::Result::Ok(Self::#variant_idents),
560 )*
561 #[allow(unreachable_patterns)]
562 #default_arm,
563 }
564 }
565 }
566
567 impl ::core::convert::TryFrom<#repr> for #name {
568 type Error = ::#krate::TryFromPrimitiveError<Self>;
569
570 #[inline]
571 fn try_from (
572 number: #repr,
573 ) -> ::core::result::Result<Self, ::#krate::TryFromPrimitiveError<Self>>
574 {
575 ::#krate::TryFromPrimitive::try_from_primitive(number)
576 }
577 }
578 })
579 }
580
/// Name under which the calling crate refers to `num_enum` (it may have been
/// renamed in its manifest). Falls back to `num_enum`, with a warning on
/// stderr, if the lookup fails.
#[cfg(feature = "proc-macro-crate")]
fn get_crate_name() -> String {
    match proc_macro_crate::crate_name("num_enum") {
        Ok(proc_macro_crate::FoundCrate::Name(name)) => name,
        Ok(proc_macro_crate::FoundCrate::Itself) => String::from("num_enum"),
        Err(err) => {
            eprintln!("Warning: {}\n => defaulting to `num_enum`", err,);
            String::from("num_enum")
        }
    }
}
593
// `proc-macro-crate` is deliberately optional: depending on it in no_std
// environments drags in serde with `std`, which is awkward.
//
// The trade-off is that no_std dependents of num_enum cannot rename the
// num_enum crate when they depend on it. Sorry.
//
// See https://github.com/illicitonion/num_enum/issues/18
#[cfg(not(feature = "proc-macro-crate"))]
fn get_crate_name() -> String {
    "num_enum".to_owned()
}
604
605 /// Generates a `unsafe fn from_unchecked (number: Primitive) -> Self`
606 /// associated function.
607 ///
608 /// Allows unsafely turning a primitive into an enum with from_unchecked.
609 /// -------------------------------------------------------------
610 ///
611 /// If you're really certain a conversion will succeed, and want to avoid a small amount of overhead, you can use unsafe
612 /// code to do this conversion. Unless you have data showing that the match statement generated in the `try_from` above is a
613 /// bottleneck for you, you should avoid doing this, as the unsafe code has potential to cause serious memory issues in
614 /// your program.
615 ///
616 /// ```rust
617 /// use num_enum::UnsafeFromPrimitive;
618 ///
619 /// #[derive(Debug, Eq, PartialEq, UnsafeFromPrimitive)]
620 /// #[repr(u8)]
621 /// enum Number {
622 /// Zero,
623 /// One,
624 /// }
625 ///
626 /// fn main() {
627 /// assert_eq!(
628 /// Number::Zero,
629 /// unsafe { Number::from_unchecked(0_u8) },
630 /// );
631 /// assert_eq!(
632 /// Number::One,
633 /// unsafe { Number::from_unchecked(1_u8) },
634 /// );
635 /// }
636 ///
637 /// unsafe fn undefined_behavior() {
638 /// let _ = Number::from_unchecked(2); // 2 is not a valid discriminant!
639 /// }
640 /// ```
641 #[proc_macro_derive(UnsafeFromPrimitive, attributes(num_enum))]
derive_unsafe_from_primitive(stream: TokenStream) -> TokenStream642 pub fn derive_unsafe_from_primitive(stream: TokenStream) -> TokenStream {
643 let enum_info = parse_macro_input!(stream as EnumInfo);
644
645 if enum_info.has_default_variant() {
646 let span = enum_info
647 .first_default_attr_span()
648 .cloned()
649 .expect("Expected span");
650 let message = "#[derive(UnsafeFromPrimitive)] does not support `#[num_enum(default)]`";
651 return syn::Error::new(span, message).to_compile_error().into();
652 }
653
654 if enum_info.has_complex_variant() {
655 let span = enum_info
656 .first_alternatives_attr_span()
657 .cloned()
658 .expect("Expected span");
659 let message =
660 "#[derive(UnsafeFromPrimitive)] does not support `#[num_enum(alternatives = [..])]`";
661 return syn::Error::new(span, message).to_compile_error().into();
662 }
663
664 let EnumInfo {
665 ref name, ref repr, ..
666 } = enum_info;
667
668 let doc_string = LitStr::new(
669 &format!(
670 r#"
671 Transmutes `number: {repr}` into a [`{name}`].
672
673 # Safety
674
675 - `number` must represent a valid discriminant of [`{name}`]
676 "#,
677 repr = repr,
678 name = name,
679 ),
680 Span::call_site(),
681 );
682
683 TokenStream::from(quote! {
684 impl #name {
685 #[doc = #doc_string]
686 #[inline]
687 pub unsafe fn from_unchecked(number: #repr) -> Self {
688 ::core::mem::transmute(number)
689 }
690 }
691 })
692 }
693
694 /// Implements `core::default::Default` for a `#[repr(Primitive)] enum`.
695 ///
696 /// Whichever variant has the `#[default]` or `#[num_enum(default)]` attribute will be returned.
697 /// ----------------------------------------------
698 ///
699 /// ```rust
700 /// #[derive(Debug, Eq, PartialEq, num_enum::Default)]
701 /// #[repr(u8)]
702 /// enum Number {
703 /// Zero,
704 /// #[default]
705 /// One,
706 /// }
707 ///
708 /// assert_eq!(Number::One, Number::default());
709 /// assert_eq!(Number::One, <Number as ::core::default::Default>::default());
710 /// ```
711 #[proc_macro_derive(Default, attributes(num_enum, default))]
derive_default(stream: TokenStream) -> TokenStream712 pub fn derive_default(stream: TokenStream) -> TokenStream {
713 let enum_info = parse_macro_input!(stream as EnumInfo);
714
715 let default_ident = match enum_info.default() {
716 Some(ident) => ident,
717 None => {
718 let span = Span::call_site();
719 let message =
720 "#[derive(num_enum::Default)] requires a variant marked with `#[default]` or `#[num_enum(default)]`";
721 return syn::Error::new(span, message).to_compile_error().into();
722 }
723 };
724
725 let EnumInfo { ref name, .. } = enum_info;
726
727 TokenStream::from(quote! {
728 impl ::core::default::Default for #name {
729 #[inline]
730 fn default() -> Self {
731 Self::#default_ident
732 }
733 }
734 })
735 }
736