1 use proc_macro2::TokenStream;
2 use quote::quote;
3 use syn::{Data, DeriveInput, Fields};
4 
5 use crate::helpers::{
6     non_enum_error, occurrence_error, HasStrumVariantProperties, HasTypeProperties,
7 };
8 
from_string_inner(ast: &DeriveInput) -> syn::Result<TokenStream>9 pub fn from_string_inner(ast: &DeriveInput) -> syn::Result<TokenStream> {
10     let name = &ast.ident;
11     let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
12     let variants = match &ast.data {
13         Data::Enum(v) => &v.variants,
14         _ => return Err(non_enum_error()),
15     };
16 
17     let type_properties = ast.get_type_properties()?;
18     let strum_module_path = type_properties.crate_module_path();
19 
20     let mut default_kw = None;
21     let mut default = quote! { _ => ::core::result::Result::Err(#strum_module_path::ParseError::VariantNotFound) };
22     let mut arms = Vec::new();
23     for variant in variants {
24         let ident = &variant.ident;
25         let variant_properties = variant.get_variant_properties()?;
26 
27         if variant_properties.disabled.is_some() {
28             continue;
29         }
30 
31         if let Some(kw) = variant_properties.default {
32             if let Some(fst_kw) = default_kw {
33                 return Err(occurrence_error(fst_kw, kw, "default"));
34             }
35 
36             match &variant.fields {
37                 Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {}
38                 _ => {
39                     return Err(syn::Error::new_spanned(
40                         variant,
41                         "Default only works on newtype structs with a single String field",
42                     ))
43                 }
44             }
45 
46             default_kw = Some(kw);
47             default = quote! {
48                 default => ::core::result::Result::Ok(#name::#ident(default.into()))
49             };
50             continue;
51         }
52 
53         let is_ascii_case_insensitive = variant_properties
54             .ascii_case_insensitive
55             .unwrap_or(type_properties.ascii_case_insensitive);
56         // If we don't have any custom variants, add the default serialized name.
57         let attrs = variant_properties
58             .get_serializations(type_properties.case_style)
59             .into_iter()
60             .map(|serialization| {
61                 if is_ascii_case_insensitive {
62                     quote! { s if s.eq_ignore_ascii_case(#serialization) }
63                 } else {
64                     quote! { #serialization }
65                 }
66             });
67 
68         let params = match &variant.fields {
69             Fields::Unit => quote! {},
70             Fields::Unnamed(fields) => {
71                 let defaults =
72                     ::std::iter::repeat(quote!(Default::default())).take(fields.unnamed.len());
73                 quote! { (#(#defaults),*) }
74             }
75             Fields::Named(fields) => {
76                 let fields = fields
77                     .named
78                     .iter()
79                     .map(|field| field.ident.as_ref().unwrap());
80                 quote! { {#(#fields: Default::default()),*} }
81             }
82         };
83 
84         arms.push(quote! { #(#attrs => ::core::result::Result::Ok(#name::#ident #params)),* });
85     }
86 
87     arms.push(default);
88 
89     let from_str = quote! {
90         #[allow(clippy::use_self)]
91         impl #impl_generics ::core::str::FromStr for #name #ty_generics #where_clause {
92             type Err = #strum_module_path::ParseError;
93             fn from_str(s: &str) -> ::core::result::Result< #name #ty_generics , <Self as ::core::str::FromStr>::Err> {
94                 match s {
95                     #(#arms),*
96                 }
97             }
98         }
99     };
100 
101     let try_from_str = try_from_str(
102         name,
103         impl_generics,
104         ty_generics,
105         where_clause,
106         strum_module_path,
107     );
108 
109     Ok(quote! {
110         #from_str
111         #try_from_str
112     })
113 }
114 
115 #[rustversion::before(1.34)]
try_from_str( _name: &proc_macro2::Ident, _impl_generics: syn::ImplGenerics, _ty_generics: syn::TypeGenerics, _where_clause: Option<&syn::WhereClause>, _strum_module_path: syn::Path, ) -> TokenStream116 fn try_from_str(
117     _name: &proc_macro2::Ident,
118     _impl_generics: syn::ImplGenerics,
119     _ty_generics: syn::TypeGenerics,
120     _where_clause: Option<&syn::WhereClause>,
121     _strum_module_path: syn::Path,
122 ) -> TokenStream {
123     Default::default()
124 }
125 
126 #[rustversion::since(1.34)]
try_from_str( name: &proc_macro2::Ident, impl_generics: syn::ImplGenerics, ty_generics: syn::TypeGenerics, where_clause: Option<&syn::WhereClause>, strum_module_path: syn::Path, ) -> TokenStream127 fn try_from_str(
128     name: &proc_macro2::Ident,
129     impl_generics: syn::ImplGenerics,
130     ty_generics: syn::TypeGenerics,
131     where_clause: Option<&syn::WhereClause>,
132     strum_module_path: syn::Path,
133 ) -> TokenStream {
134     quote! {
135         #[allow(clippy::use_self)]
136         impl #impl_generics ::core::convert::TryFrom<&str> for #name #ty_generics #where_clause {
137             type Error = #strum_module_path::ParseError;
138             fn try_from(s: &str) -> ::core::result::Result< #name #ty_generics , <Self as ::core::convert::TryFrom<&str>>::Error> {
139                 ::core::str::FromStr::from_str(s)
140             }
141         }
142     }
143 }
144