1 use crate::utils::{AttrParams, DeriveType, State};
2 use convert_case::{Case, Casing};
3 use proc_macro2::TokenStream;
4 use quote::{format_ident, quote};
5 use syn::{DeriveInput, Fields, Ident, Result};
6
expand(input: &DeriveInput, trait_name: &'static str) -> Result<TokenStream>7 pub fn expand(input: &DeriveInput, trait_name: &'static str) -> Result<TokenStream> {
8 let state = State::with_attr_params(
9 input,
10 trait_name,
11 quote!(),
12 String::from("unwrap"),
13 AttrParams {
14 enum_: vec!["ignore"],
15 variant: vec!["ignore"],
16 struct_: vec!["ignore"],
17 field: vec!["ignore"],
18 },
19 )?;
20 if state.derive_type != DeriveType::Enum {
21 panic!("Unwrap can only be derived for enums");
22 }
23
24 let enum_name = &input.ident;
25 let (imp_generics, type_generics, where_clause) = input.generics.split_for_impl();
26
27 let mut funcs = vec![];
28 for variant_state in state.enabled_variant_data().variant_states {
29 let variant = variant_state.variant.unwrap();
30 let fn_name = Ident::new(
31 &format_ident!("unwrap_{}", variant.ident)
32 .to_string()
33 .to_case(Case::Snake),
34 variant.ident.span(),
35 );
36 let variant_ident = &variant.ident;
37
38 let (data_pattern, ret_value, ret_type) = match variant.fields {
39 Fields::Named(_) => panic!("cannot unwrap anonymous records"),
40 Fields::Unnamed(ref fields) => {
41 let data_pattern =
42 (0..fields.unnamed.len()).fold(vec![], |mut a, n| {
43 a.push(format_ident!("field_{}", n));
44 a
45 });
46 let ret_type = &fields.unnamed;
47 (
48 quote! { (#(#data_pattern),*) },
49 quote! { (#(#data_pattern),*) },
50 quote! { (#ret_type) },
51 )
52 }
53 Fields::Unit => (quote! {}, quote! { () }, quote! { () }),
54 };
55
56 let other_arms = state.variant_states.iter().map(|variant| {
57 variant.variant.unwrap()
58 }).filter(|variant| {
59 &variant.ident != variant_ident
60 }).map(|variant| {
61 let data_pattern = match variant.fields {
62 Fields::Named(_) => quote! { {..} },
63 Fields::Unnamed(_) => quote! { (..) },
64 Fields::Unit => quote! {},
65 };
66 let variant_ident = &variant.ident;
67 quote! { #enum_name :: #variant_ident #data_pattern =>
68 panic!(concat!("called `", stringify!(#enum_name), "::", stringify!(#fn_name),
69 "()` on a `", stringify!(#variant_ident), "` value"))
70 }
71 });
72
73 // The `track-caller` feature is set by our build script based
74 // on rustc version detection, as `#[track_caller]` was
75 // stabilized in a later version (1.46) of Rust than our MSRV (1.36).
76 let track_caller = if cfg!(feature = "track-caller") {
77 quote! { #[track_caller] }
78 } else {
79 quote! {}
80 };
81 let func = quote! {
82 #track_caller
83 pub fn #fn_name(self) -> #ret_type {
84 match self {
85 #enum_name ::#variant_ident #data_pattern => #ret_value,
86 #(#other_arms),*
87 }
88 }
89 };
90 funcs.push(func);
91 }
92
93 let imp = quote! {
94 impl #imp_generics #enum_name #type_generics #where_clause{
95 #(#funcs)*
96 }
97 };
98
99 Ok(imp)
100 }
101