1 use proc_macro2::{Ident, Span, TokenStream};
2 use std::fmt::Display;
3 use syn::{
4     parse::{Error, Result},
5     spanned::Spanned,
6     Attribute, Data, DeriveInput, Fields, Lit, Meta, MetaNameValue, NestedMeta,
7 };
8 
9 /// Provides the hook to expand `#[derive(Display)]` into an implementation of `From`
expand(input: &DeriveInput, trait_name: &str) -> Result<TokenStream>10 pub fn expand(input: &DeriveInput, trait_name: &str) -> Result<TokenStream> {
11     let trait_ident = Ident::new(trait_name, Span::call_site());
12     let trait_path = &quote!(::std::fmt::#trait_ident);
13     let trait_attr = match trait_name {
14         "Display" => "display",
15         "Binary" => "binary",
16         "Octal" => "octal",
17         "LowerHex" => "lower_hex",
18         "UpperHex" => "upper_hex",
19         "LowerExp" => "lower_exp",
20         "UpperExp" => "upper_exp",
21         "Pointer" => "pointer",
22         _ => unimplemented!(),
23     };
24 
25     let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
26     let name = &input.ident;
27 
28     let arms = State {
29         trait_path,
30         trait_attr,
31         input,
32     }
33     .get_match_arms()?;
34 
35     Ok(quote! {
36         impl #impl_generics #trait_path for #name #ty_generics #where_clause
37         {
38             #[inline]
39             fn fmt(&self, _derive_more_Display_formatter: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
40                 match self {
41                     #arms
42                     _ => Ok(()) // This is needed for empty enums
43                 }
44             }
45         }
46     })
47 }
48 
49 struct State<'a, 'b> {
50     trait_path: &'b TokenStream,
51     trait_attr: &'static str,
52     input: &'a DeriveInput,
53 }
54 
55 impl<'a, 'b> State<'a, 'b> {
get_proper_syntax(&self) -> impl Display56     fn get_proper_syntax(&self) -> impl Display {
57         format!(
58             r#"Proper syntax: #[{}(fmt = "My format", "arg1", "arg2")]"#,
59             self.trait_attr
60         )
61     }
get_matcher(&self, fields: &Fields) -> TokenStream62     fn get_matcher(&self, fields: &Fields) -> TokenStream {
63         match fields {
64             Fields::Unit => TokenStream::new(),
65             Fields::Unnamed(fields) => {
66                 let fields: TokenStream = (0..fields.unnamed.len())
67                     .map(|n| {
68                         let i = Ident::new(&format!("_{}", n), Span::call_site());
69                         quote!(#i,)
70                     })
71                     .collect();
72                 quote!((#fields))
73             }
74             Fields::Named(fields) => {
75                 let fields: TokenStream = fields
76                     .named
77                     .iter()
78                     .map(|f| {
79                         let i = f.ident.as_ref().unwrap();
80                         quote!(#i,)
81                     })
82                     .collect();
83                 quote!({#fields})
84             }
85         }
86     }
find_meta(&self, attrs: &[Attribute]) -> Result<Option<Meta>>87     fn find_meta(&self, attrs: &[Attribute]) -> Result<Option<Meta>> {
88         let mut it = attrs
89             .iter()
90             .filter_map(|a| a.interpret_meta())
91             .filter(|m| m.name() == self.trait_attr);
92 
93         let meta = it.next();
94         if it.next().is_some() {
95             Err(Error::new(meta.span(), "Too many formats given"))
96         } else {
97             Ok(meta)
98         }
99     }
get_meta_fmt(&self, meta: Meta) -> Result<TokenStream>100     fn get_meta_fmt(&self, meta: Meta) -> Result<TokenStream> {
101         let list = match &meta {
102             Meta::List(list) => list,
103             _ => return Err(Error::new(meta.span(), self.get_proper_syntax())),
104         };
105 
106         let fmt = match &list.nested[0] {
107             NestedMeta::Meta(Meta::NameValue(MetaNameValue {
108                 ident,
109                 lit: Lit::Str(s),
110                 ..
111             }))
112                 if ident == "fmt" =>
113             {
114                 s
115             }
116             _ => return Err(Error::new(list.nested[0].span(), self.get_proper_syntax())),
117         };
118 
119         let args = list
120             .nested
121             .iter()
122             .skip(1) // skip fmt = "..."
123             .try_fold(TokenStream::new(), |args, arg| {
124                 let arg = match arg {
125                     NestedMeta::Literal(Lit::Str(s)) => s,
126                     NestedMeta::Meta(Meta::Word(i)) => {
127                         return Ok(quote_spanned!(list.span()=> #args #i,))
128                     }
129                     _ => return Err(Error::new(arg.span(), self.get_proper_syntax())),
130                 };
131                 let arg: TokenStream = match arg.parse() {
132                     Ok(arg) => arg,
133                     Err(e) => return Err(Error::new(arg.span(), e)),
134                 };
135                 Ok(quote_spanned!(list.span()=> #args #arg,))
136             })?;
137 
138         Ok(quote_spanned!(meta.span()=> write!(_derive_more_Display_formatter, #fmt, #args)))
139     }
infer_fmt(&self, fields: &Fields, name: &Ident) -> Result<TokenStream>140     fn infer_fmt(&self, fields: &Fields, name: &Ident) -> Result<TokenStream> {
141         let fields = match fields {
142             Fields::Unit => {
143                 return Ok(quote!(write!(
144                     _derive_more_Display_formatter,
145                     stringify!(#name)
146                 )));
147             }
148             Fields::Named(fields) => &fields.named,
149             Fields::Unnamed(fields) => &fields.unnamed,
150         };
151         if fields.len() == 0 {
152             return Ok(quote!(write!(
153                 _derive_more_Display_formatter,
154                 stringify!(#name)
155             )));
156         } else if fields.len() > 1 {
157             return Err(Error::new(
158                 fields.span(),
159                 "Can not automatically infer format for types with more than 1 field",
160             ));
161         }
162 
163         let trait_path = self.trait_path;
164         if let Some(ident) = &fields.iter().next().as_ref().unwrap().ident {
165             Ok(quote!(#trait_path::fmt(#ident, _derive_more_Display_formatter)))
166         } else {
167             Ok(quote!(#trait_path::fmt(_0, _derive_more_Display_formatter)))
168         }
169     }
get_match_arms(&self) -> Result<TokenStream>170     fn get_match_arms(&self) -> Result<TokenStream> {
171         match &self.input.data {
172             Data::Enum(e) => {
173                 if let Some(meta) = self.find_meta(&self.input.attrs)? {
174                     let fmt = self.get_meta_fmt(meta)?;
175                     e.variants.iter().try_for_each(|v| {
176                         if let Some(meta) = self.find_meta(&v.attrs)? {
177                             Err(Error::new(
178                                 meta.span(),
179                                 "Can not have a format on the variant when the whole enum has one",
180                             ))
181                         } else {
182                             Ok(())
183                         }
184                     })?;
185                     Ok(quote_spanned!(self.input.span()=> _ => #fmt,))
186                 } else {
187                     e.variants.iter().try_fold(TokenStream::new(), |arms, v| {
188                         let matcher = self.get_matcher(&v.fields);
189                         let fmt = if let Some(meta) = self.find_meta(&v.attrs)? {
190                             self.get_meta_fmt(meta)?
191                         } else {
192                             self.infer_fmt(&v.fields, &v.ident)?
193                         };
194                         let name = &self.input.ident;
195                         let v_name = &v.ident;
196                         Ok(quote_spanned!(self.input.span()=> #arms #name::#v_name #matcher => #fmt,))
197                     })
198                 }
199             }
200             Data::Struct(s) => {
201                 let matcher = self.get_matcher(&s.fields);
202                 let fmt = if let Some(meta) = self.find_meta(&self.input.attrs)? {
203                     self.get_meta_fmt(meta)?
204                 } else {
205                     self.infer_fmt(&s.fields, &self.input.ident)?
206                 };
207                 let name = &self.input.ident;
208                 Ok(quote_spanned!(self.input.span()=> #name #matcher => #fmt,))
209             }
210             Data::Union(_) => {
211                 let meta = self.find_meta(&self.input.attrs)?.ok_or(Error::new(
212                     self.input.span(),
213                     "Can not automatically infer format for unions",
214                 ))?;
215                 let fmt = self.get_meta_fmt(meta)?;
216                 Ok(quote_spanned!(self.input.span()=> _ => #fmt,))
217             }
218         }
219     }
220 }
221