1 use crate::utils::{AttrParams, DeriveType, State};
2 use convert_case::{Case, Casing};
3 use proc_macro2::TokenStream;
4 use quote::{format_ident, quote};
5 use syn::{DeriveInput, Fields, Ident, Result};
6 
expand(input: &DeriveInput, trait_name: &'static str) -> Result<TokenStream>7 pub fn expand(input: &DeriveInput, trait_name: &'static str) -> Result<TokenStream> {
8     let state = State::with_attr_params(
9         input,
10         trait_name,
11         quote!(),
12         String::from("unwrap"),
13         AttrParams {
14             enum_: vec!["ignore"],
15             variant: vec!["ignore"],
16             struct_: vec!["ignore"],
17             field: vec!["ignore"],
18         },
19     )?;
20     if state.derive_type != DeriveType::Enum {
21         panic!("Unwrap can only be derived for enums");
22     }
23 
24     let enum_name = &input.ident;
25     let (imp_generics, type_generics, where_clause) = input.generics.split_for_impl();
26 
27     let mut funcs = vec![];
28     for variant_state in state.enabled_variant_data().variant_states {
29         let variant = variant_state.variant.unwrap();
30         let fn_name = Ident::new(
31             &format_ident!("unwrap_{}", variant.ident)
32                 .to_string()
33                 .to_case(Case::Snake),
34             variant.ident.span(),
35         );
36         let variant_ident = &variant.ident;
37 
38         let (data_pattern, ret_value, ret_type) = match variant.fields {
39             Fields::Named(_) => panic!("cannot unwrap anonymous records"),
40             Fields::Unnamed(ref fields) => {
41                 let data_pattern =
42                     (0..fields.unnamed.len()).fold(vec![], |mut a, n| {
43                         a.push(format_ident!("field_{}", n));
44                         a
45                     });
46                 let ret_type = &fields.unnamed;
47                 (
48                     quote! { (#(#data_pattern),*) },
49                     quote! { (#(#data_pattern),*) },
50                     quote! { (#ret_type) },
51                 )
52             }
53             Fields::Unit => (quote! {}, quote! { () }, quote! { () }),
54         };
55 
56         let other_arms = state.variant_states.iter().map(|variant| {
57             variant.variant.unwrap()
58         }).filter(|variant| {
59             &variant.ident != variant_ident
60         }).map(|variant| {
61             let data_pattern = match variant.fields {
62                 Fields::Named(_) => quote! { {..} },
63                 Fields::Unnamed(_) => quote! { (..) },
64                 Fields::Unit => quote! {},
65             };
66             let variant_ident = &variant.ident;
67             quote! { #enum_name :: #variant_ident #data_pattern =>
68                       panic!(concat!("called `", stringify!(#enum_name), "::", stringify!(#fn_name),
69                                      "()` on a `", stringify!(#variant_ident), "` value"))
70             }
71         });
72 
73         // The `track-caller` feature is set by our build script based
74         // on rustc version detection, as `#[track_caller]` was
75         // stabilized in a later version (1.46) of Rust than our MSRV (1.36).
76         let track_caller = if cfg!(feature = "track-caller") {
77             quote! { #[track_caller] }
78         } else {
79             quote! {}
80         };
81         let func = quote! {
82             #track_caller
83             pub fn #fn_name(self) -> #ret_type {
84                 match self {
85                     #enum_name ::#variant_ident #data_pattern => #ret_value,
86                     #(#other_arms),*
87                 }
88             }
89         };
90         funcs.push(func);
91     }
92 
93     let imp = quote! {
94         impl #imp_generics #enum_name #type_generics #where_clause{
95             #(#funcs)*
96         }
97     };
98 
99     Ok(imp)
100 }
101