1 // Copyright 2018 Guillaume Pinot (@TeXitoi) <texitoi@texitoi.eu>,
2 // Kevin Knapp (@kbknapp) <kbknapp@gmail.com>, and
3 // Ana Hobden (@hoverbear) <operator@hoverbear.org>
4 //
5 // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
6 // http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
7 // <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
8 // option. This file may not be copied, modified, or distributed
9 // except according to those terms.
10 //
11 // This work was derived from Structopt (https://github.com/TeXitoi/structopt)
12 // commit#ea76fa1b1b273e65e3b0b1046643715b49bec51f which is licensed under the
13 // MIT/Apache 2.0 license.
14 
15 use crate::{
16     attrs::{Attrs, Kind, Name, ParserKind, DEFAULT_CASING, DEFAULT_ENV_CASING},
17     dummies,
18     utils::{inner_type, sub_type, Sp, Ty},
19 };
20 
21 use proc_macro2::{Ident, Span, TokenStream};
22 use proc_macro_error::{abort, abort_call_site};
23 use quote::{format_ident, quote, quote_spanned};
24 use syn::{
25     punctuated::Punctuated, spanned::Spanned, token::Comma, Attribute, Data, DataStruct,
26     DeriveInput, Field, Fields, Generics, Type,
27 };
28 
derive_args(input: &DeriveInput) -> TokenStream29 pub fn derive_args(input: &DeriveInput) -> TokenStream {
30     let ident = &input.ident;
31 
32     dummies::args(ident);
33 
34     match input.data {
35         Data::Struct(DataStruct {
36             fields: Fields::Named(ref fields),
37             ..
38         }) => gen_for_struct(ident, &input.generics, &fields.named, &input.attrs),
39         Data::Struct(DataStruct {
40             fields: Fields::Unit,
41             ..
42         }) => gen_for_struct(
43             ident,
44             &input.generics,
45             &Punctuated::<Field, Comma>::new(),
46             &input.attrs,
47         ),
48         _ => abort_call_site!("`#[derive(Args)]` only supports non-tuple structs"),
49     }
50 }
51 
gen_for_struct( struct_name: &Ident, generics: &Generics, fields: &Punctuated<Field, Comma>, attrs: &[Attribute], ) -> TokenStream52 pub fn gen_for_struct(
53     struct_name: &Ident,
54     generics: &Generics,
55     fields: &Punctuated<Field, Comma>,
56     attrs: &[Attribute],
57 ) -> TokenStream {
58     let from_arg_matches = gen_from_arg_matches_for_struct(struct_name, generics, fields, attrs);
59 
60     let attrs = Attrs::from_struct(
61         Span::call_site(),
62         attrs,
63         Name::Derived(struct_name.clone()),
64         Sp::call_site(DEFAULT_CASING),
65         Sp::call_site(DEFAULT_ENV_CASING),
66     );
67     let app_var = Ident::new("__clap_app", Span::call_site());
68     let augmentation = gen_augment(fields, &app_var, &attrs, false);
69     let augmentation_update = gen_augment(fields, &app_var, &attrs, true);
70 
71     let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
72 
73     quote! {
74         #from_arg_matches
75 
76         #[allow(dead_code, unreachable_code, unused_variables, unused_braces)]
77         #[allow(
78             clippy::style,
79             clippy::complexity,
80             clippy::pedantic,
81             clippy::restriction,
82             clippy::perf,
83             clippy::deprecated,
84             clippy::nursery,
85             clippy::cargo,
86             clippy::suspicious_else_formatting,
87         )]
88         #[deny(clippy::correctness)]
89         impl #impl_generics clap::Args for #struct_name #ty_generics #where_clause {
90             fn augment_args<'b>(#app_var: clap::App<'b>) -> clap::App<'b> {
91                 #augmentation
92             }
93             fn augment_args_for_update<'b>(#app_var: clap::App<'b>) -> clap::App<'b> {
94                 #augmentation_update
95             }
96         }
97     }
98 }
99 
gen_from_arg_matches_for_struct( struct_name: &Ident, generics: &Generics, fields: &Punctuated<Field, Comma>, attrs: &[Attribute], ) -> TokenStream100 pub fn gen_from_arg_matches_for_struct(
101     struct_name: &Ident,
102     generics: &Generics,
103     fields: &Punctuated<Field, Comma>,
104     attrs: &[Attribute],
105 ) -> TokenStream {
106     let attrs = Attrs::from_struct(
107         Span::call_site(),
108         attrs,
109         Name::Derived(struct_name.clone()),
110         Sp::call_site(DEFAULT_CASING),
111         Sp::call_site(DEFAULT_ENV_CASING),
112     );
113 
114     let constructor = gen_constructor(fields, &attrs);
115     let updater = gen_updater(fields, &attrs, true);
116 
117     let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
118 
119     quote! {
120         #[allow(dead_code, unreachable_code, unused_variables, unused_braces)]
121         #[allow(
122             clippy::style,
123             clippy::complexity,
124             clippy::pedantic,
125             clippy::restriction,
126             clippy::perf,
127             clippy::deprecated,
128             clippy::nursery,
129             clippy::cargo,
130             clippy::suspicious_else_formatting,
131         )]
132         #[deny(clippy::correctness)]
133         impl #impl_generics clap::FromArgMatches for #struct_name #ty_generics #where_clause {
134             fn from_arg_matches(__clap_arg_matches: &clap::ArgMatches) -> ::std::result::Result<Self, clap::Error> {
135                 let v = #struct_name #constructor;
136                 ::std::result::Result::Ok(v)
137             }
138 
139             fn update_from_arg_matches(&mut self, __clap_arg_matches: &clap::ArgMatches) -> ::std::result::Result<(), clap::Error> {
140                 #updater
141                 ::std::result::Result::Ok(())
142             }
143         }
144     }
145 }
146 
/// Generate a block of code to add arguments/subcommands corresponding to
/// the `fields` to an app.
///
/// The returned tokens form a block expression that repeatedly rebinds `app_var`.
/// It applies, in order:
/// 1. the parent's initial top-level methods;
/// 2. one `.arg(...)` per plain argument field, with flattened `Args` types spliced in
///    at their field position;
/// 3. at most one subcommand set;
/// 4. the parent's final top-level methods.
///
/// The block evaluates to the resulting app.
///
/// When `override_required` is `true`, as it is for the `augment_args_for_update` body
/// generated in `gen_for_struct`:
/// - plain `Ty::Other` arguments are emitted with `.required(false)`;
/// - flattened types are augmented with `augment_args_for_update`;
/// - subcommands are added with `augment_subcommands_for_update`, without setting
///   `SubcommandRequiredElseHelp`.
///
/// Aborts if more than one field is a subcommand, pointing at the second one.
pub fn gen_augment(
    fields: &Punctuated<Field, Comma>,
    app_var: &Ident,
    parent_attribute: &Attrs,
    override_required: bool,
) -> TokenStream {
    // Lazily yields `(field span, tokens)` for every subcommand field; only the first two
    // are ever pulled (one to use, one to detect a duplicate).
    let mut subcmds = fields.iter().filter_map(|field| {
        let attrs = Attrs::from_field(
            field,
            parent_attribute.casing(),
            parent_attribute.env_casing(),
        );
        let kind = attrs.kind();
        if let Kind::Subcommand(ty) = &*kind {
            // For `Option<T>` fields the subcommand type is `T`; otherwise the field type itself.
            let subcmd_type = match (**ty, sub_type(&field.ty)) {
                (Ty::Option, Some(sub_type)) => sub_type,
                _ => &field.ty,
            };
            // A non-`Option` subcommand field sets `SubcommandRequiredElseHelp` on the app.
            let required = if **ty == Ty::Option {
                quote!()
            } else {
                quote_spanned! { kind.span()=>
                    let #app_var = #app_var.setting(
                        clap::AppSettings::SubcommandRequiredElseHelp
                    );
                }
            };

            let span = field.span();
            let ts = if override_required {
                quote! {
                    let #app_var = <#subcmd_type as clap::Subcommand>::augment_subcommands_for_update( #app_var );
                }
            } else{
                quote! {
                    let #app_var = <#subcmd_type as clap::Subcommand>::augment_subcommands( #app_var );
                    #required
                }
            };
            Some((span, ts))
        } else {
            None
        }
    });
    // Keep the first subcommand set; a second one is a hard error at that field's span.
    let subcmd = subcmds.next().map(|(_, ts)| ts);
    if let Some((span, _)) = subcmds.next() {
        abort!(
            span,
            "multiple subcommand sets are not allowed, that's the second"
        );
    }

    // One statement per field that contributes arguments to the app.
    let args = fields.iter().filter_map(|field| {
        let attrs = Attrs::from_field(
            field,
            parent_attribute.casing(),
            parent_attribute.env_casing(),
        );
        let kind = attrs.kind();
        match &*kind {
            // These kinds register no `Arg` here (subcommands were handled above).
            Kind::Subcommand(_)
            | Kind::Skip(_)
            | Kind::FromGlobal(_)
            | Kind::ExternalSubcommand => None,
            // Apply this field's help heading while the flattened type adds its args,
            // then restore whatever heading was active before.
            Kind::Flatten => {
                let ty = &field.ty;
                let old_heading_var = format_ident!("__clap_old_heading");
                let help_heading = attrs.help_heading();
                if override_required {
                    Some(quote_spanned! { kind.span()=>
                        let #old_heading_var = #app_var.get_help_heading();
                        let #app_var = #app_var #help_heading;
                        let #app_var = <#ty as clap::Args>::augment_args_for_update(#app_var);
                        let #app_var = #app_var.help_heading(#old_heading_var);
                    })
                } else {
                    Some(quote_spanned! { kind.span()=>
                        let #old_heading_var = #app_var.get_help_heading();
                        let #app_var = #app_var #help_heading;
                        let #app_var = <#ty as clap::Args>::augment_args(#app_var);
                        let #app_var = #app_var.help_heading(#old_heading_var);
                    })
                }
            }
            Kind::Arg(ty) => {
                // Element/inner type that the parser produces (e.g. `T` for `Vec<T>`).
                let convert_type = inner_type(**ty, &field.ty);

                let occurrences = *attrs.parser().kind == ParserKind::FromOccurrences;
                let flag = *attrs.parser().kind == ParserKind::FromFlag;

                let parser = attrs.parser();
                let func = &parser.func;

                // Fallible `TryFrom*` parsers are registered as validators, with the parsed
                // value discarded. ArgEnum fields get no validator; their values are
                // constrained by `possible_values` below.
                let validator = match *parser.kind {
                    _ if attrs.is_enum() => quote!(),
                    ParserKind::TryFromStr => quote_spanned! { func.span()=>
                        .validator(|s| {
                            #func(s)
                            .map(|_: #convert_type| ())
                        })
                    },
                    ParserKind::TryFromOsStr => quote_spanned! { func.span()=>
                        .validator_os(|s| #func(s).map(|_: #convert_type| ()))
                    },
                    ParserKind::FromStr
                    | ParserKind::FromOsStr
                    | ParserKind::FromFlag
                    | ParserKind::FromOccurrences => quote!(),
                };
                // `*OsStr` parsers read values via `value_of_os`/`values_of_os` (see
                // `gen_parsers`), so invalid UTF-8 is allowed for them.
                let allow_invalid_utf8 = match *parser.kind {
                    _ if attrs.is_enum() => quote!(),
                    ParserKind::FromOsStr | ParserKind::TryFromOsStr => {
                        quote_spanned! { func.span()=>
                            .allow_invalid_utf8(true)
                        }
                    }
                    ParserKind::FromStr
                    | ParserKind::TryFromStr
                    | ParserKind::FromFlag
                    | ParserKind::FromOccurrences => quote!(),
                };

                let value_name = attrs.value_name();
                let possible_values = if attrs.is_enum() {
                    gen_arg_enum_possible_values(convert_type)
                } else {
                    quote!()
                };

                // Builder calls that depend on the shape of the field's type.
                let modifier = match **ty {
                    // Plain `bool` fields take no value.
                    Ty::Bool => quote!(),

                    Ty::Option => {
                        quote_spanned! { ty.span()=>
                            .takes_value(true)
                            .value_name(#value_name)
                            #possible_values
                            #validator
                            #allow_invalid_utf8
                        }
                    }

                    // `Option<Option<T>>`: the value itself is optional (0 or 1 values).
                    Ty::OptionOption => quote_spanned! { ty.span()=>
                        .takes_value(true)
                        .value_name(#value_name)
                        .min_values(0)
                        .max_values(1)
                        .multiple_values(false)
                        #possible_values
                        #validator
                        #allow_invalid_utf8
                    },

                    Ty::OptionVec => quote_spanned! { ty.span()=>
                        .takes_value(true)
                        .value_name(#value_name)
                        .multiple_occurrences(true)
                        #possible_values
                        #validator
                        #allow_invalid_utf8
                    },

                    Ty::Vec => {
                        quote_spanned! { ty.span()=>
                            .takes_value(true)
                            .value_name(#value_name)
                            .multiple_occurrences(true)
                            #possible_values
                            #validator
                            #allow_invalid_utf8
                        }
                    }

                    Ty::Other if occurrences => quote_spanned! { ty.span()=>
                        .multiple_occurrences(true)
                    },

                    Ty::Other if flag => quote_spanned! { ty.span()=>
                        .takes_value(false)
                    },

                    // Required unless `find_default_method` finds one, or this is the
                    // update variant.
                    Ty::Other => {
                        let required = attrs.find_default_method().is_none() && !override_required;
                        quote_spanned! { ty.span()=>
                            .takes_value(true)
                            .value_name(#value_name)
                            .required(#required)
                            #possible_values
                            #validator
                            #allow_invalid_utf8
                        }
                    }
                };

                let name = attrs.cased_name();
                let methods = attrs.field_methods(true);

                Some(quote_spanned! { field.span()=>
                    let #app_var = #app_var.arg(
                        clap::Arg::new(#name)
                            #modifier
                            #methods
                    );
                })
            }
        }
    });

    let initial_app_methods = parent_attribute.initial_top_level_methods();
    let final_app_methods = parent_attribute.final_top_level_methods();
    quote! {{
        let #app_var = #app_var #initial_app_methods;
        #( #args )*
        #subcmd
        #app_var #final_app_methods
    }}
}
366 
gen_arg_enum_possible_values(ty: &Type) -> TokenStream367 fn gen_arg_enum_possible_values(ty: &Type) -> TokenStream {
368     quote_spanned! { ty.span()=>
369         .possible_values(<#ty as clap::ArgEnum>::value_variants().iter().filter_map(clap::ArgEnum::to_possible_value))
370     }
371 }
372 
gen_constructor(fields: &Punctuated<Field, Comma>, parent_attribute: &Attrs) -> TokenStream373 pub fn gen_constructor(fields: &Punctuated<Field, Comma>, parent_attribute: &Attrs) -> TokenStream {
374     let fields = fields.iter().map(|field| {
375         let attrs = Attrs::from_field(
376             field,
377             parent_attribute.casing(),
378             parent_attribute.env_casing(),
379         );
380         let field_name = field.ident.as_ref().unwrap();
381         let kind = attrs.kind();
382         let arg_matches = format_ident!("__clap_arg_matches");
383         match &*kind {
384             Kind::ExternalSubcommand => {
385                 abort! { kind.span(),
386                     "`external_subcommand` can be used only on enum variants"
387                 }
388             }
389             Kind::Subcommand(ty) => {
390                 let subcmd_type = match (**ty, sub_type(&field.ty)) {
391                     (Ty::Option, Some(sub_type)) => sub_type,
392                     _ => &field.ty,
393                 };
394                 match **ty {
395                     Ty::Option => {
396                         quote_spanned! { kind.span()=>
397                             #field_name: {
398                                 if #arg_matches.subcommand_name().map(<#subcmd_type as clap::Subcommand>::has_subcommand).unwrap_or(false) {
399                                     Some(<#subcmd_type as clap::FromArgMatches>::from_arg_matches(#arg_matches)?)
400                                 } else {
401                                     None
402                                 }
403                             }
404                         }
405                     },
406                     _ => {
407                         quote_spanned! { kind.span()=>
408                             #field_name: {
409                                 <#subcmd_type as clap::FromArgMatches>::from_arg_matches(#arg_matches)?
410                             }
411                         }
412                     },
413                 }
414             }
415 
416             Kind::Flatten => quote_spanned! { kind.span()=>
417                 #field_name: clap::FromArgMatches::from_arg_matches(#arg_matches)?
418             },
419 
420             Kind::Skip(val) => match val {
421                 None => quote_spanned!(kind.span()=> #field_name: Default::default()),
422                 Some(val) => quote_spanned!(kind.span()=> #field_name: (#val).into()),
423             },
424 
425             Kind::Arg(ty) | Kind::FromGlobal(ty) => {
426                 gen_parsers(&attrs, ty, field_name, field, None)
427             }
428         }
429     });
430 
431     quote! {{
432         #( #fields ),*
433     }}
434 }
435 
gen_updater( fields: &Punctuated<Field, Comma>, parent_attribute: &Attrs, use_self: bool, ) -> TokenStream436 pub fn gen_updater(
437     fields: &Punctuated<Field, Comma>,
438     parent_attribute: &Attrs,
439     use_self: bool,
440 ) -> TokenStream {
441     let fields = fields.iter().map(|field| {
442         let attrs = Attrs::from_field(
443             field,
444             parent_attribute.casing(),
445             parent_attribute.env_casing(),
446         );
447         let field_name = field.ident.as_ref().unwrap();
448         let kind = attrs.kind();
449 
450         let access = if use_self {
451             quote! {
452                 #[allow(non_snake_case)]
453                 let #field_name = &mut self.#field_name;
454             }
455         } else {
456             quote!()
457         };
458         let arg_matches = format_ident!("__clap_arg_matches");
459 
460         match &*kind {
461             Kind::ExternalSubcommand => {
462                 abort! { kind.span(),
463                     "`external_subcommand` can be used only on enum variants"
464                 }
465             }
466             Kind::Subcommand(ty) => {
467                 let subcmd_type = match (**ty, sub_type(&field.ty)) {
468                     (Ty::Option, Some(sub_type)) => sub_type,
469                     _ => &field.ty,
470                 };
471 
472                 let updater = quote_spanned! { ty.span()=>
473                     <#subcmd_type as clap::FromArgMatches>::update_from_arg_matches(#field_name, #arg_matches)?;
474                 };
475 
476                 let updater = match **ty {
477                     Ty::Option => quote_spanned! { kind.span()=>
478                         if let Some(#field_name) = #field_name.as_mut() {
479                             #updater
480                         } else {
481                             *#field_name = Some(<#subcmd_type as clap::FromArgMatches>::from_arg_matches(
482                                 #arg_matches
483                             )?);
484                         }
485                     },
486                     _ => quote_spanned! { kind.span()=>
487                         #updater
488                     },
489                 };
490 
491                 quote_spanned! { kind.span()=>
492                     {
493                         #access
494                         #updater
495                     }
496                 }
497             }
498 
499             Kind::Flatten => quote_spanned! { kind.span()=> {
500                     #access
501                     clap::FromArgMatches::update_from_arg_matches(#field_name, #arg_matches)?;
502                 }
503             },
504 
505             Kind::Skip(_) => quote!(),
506 
507             Kind::Arg(ty) | Kind::FromGlobal(ty) => gen_parsers(&attrs, ty, field_name, field, Some(&access)),
508         }
509     });
510 
511     quote! {
512         #( #fields )*
513     }
514 }
515 
/// Generate the code that reads one argument field out of `__clap_arg_matches`.
///
/// The shape of the result depends on `update`:
/// - `None` (used by `gen_constructor`): a struct-literal initializer
///   `field_name: <value>`.
/// - `Some(access)` (used by `gen_updater`): a statement that, only when the argument
///   is present, runs `access` and then assigns the new value through `*field_name`.
///
/// The accessor depends on the parser kind:
/// - `*OsStr` parsers read through `value_of_os`/`values_of_os`;
/// - `FromOccurrences` reads `occurrences_of`;
/// - `FromFlag` reads `is_present`.
///
/// ArgEnum fields replace the parse closure with `ArgEnum::from_str`. Failures of
/// `TryFrom*` and ArgEnum conversions are reported as `clap::ErrorKind::ValueValidation`.
/// A missing value for a plain `Ty::Other` field is reported as `MissingRequiredArgument`.
fn gen_parsers(
    attrs: &Attrs,
    ty: &Sp<Ty>,
    field_name: &Ident,
    field: &Field,
    update: Option<&TokenStream>,
) -> TokenStream {
    use self::ParserKind::*;

    let parser = attrs.parser();
    let func = &parser.func;
    let span = parser.kind.span();
    // Element/inner type the parse closure produces (e.g. `T` for `Vec<T>`).
    let convert_type = inner_type(**ty, &field.ty);
    let name = attrs.cased_name();
    // Single-value accessor, multi-value accessor, and the conversion applied to each raw
    // value. For `FromOccurrences`/`FromFlag`, `func` is applied directly to the count /
    // presence boolean (see the `Ty::Other` arms below).
    let (value_of, values_of, mut parse) = match *parser.kind {
        FromStr => (
            quote_spanned!(span=> value_of),
            quote_spanned!(span=> values_of),
            quote_spanned!(func.span()=> |s| ::std::result::Result::Ok::<_, clap::Error>(#func(s))),
        ),
        TryFromStr => (
            quote_spanned!(span=> value_of),
            quote_spanned!(span=> values_of),
            quote_spanned!(func.span()=> |s| #func(s).map_err(|err| clap::Error::raw(clap::ErrorKind::ValueValidation, format!("Invalid value for {}: {}", #name, err)))),
        ),
        FromOsStr => (
            quote_spanned!(span=> value_of_os),
            quote_spanned!(span=> values_of_os),
            quote_spanned!(func.span()=> |s| ::std::result::Result::Ok::<_, clap::Error>(#func(s))),
        ),
        TryFromOsStr => (
            quote_spanned!(span=> value_of_os),
            quote_spanned!(span=> values_of_os),
            quote_spanned!(func.span()=> |s| #func(s).map_err(|err| clap::Error::raw(clap::ErrorKind::ValueValidation, format!("Invalid value for {}: {}", #name, err)))),
        ),
        FromOccurrences => (
            quote_spanned!(span=> occurrences_of),
            quote!(),
            func.clone(),
        ),
        FromFlag => (quote!(), quote!(), func.clone()),
    };
    // ArgEnum fields parse through the enum's `from_str`, honouring `ignore_case`.
    if attrs.is_enum() {
        let ci = attrs.ignore_case();

        parse = quote_spanned! { convert_type.span()=>
            |s| <#convert_type as clap::ArgEnum>::from_str(s, #ci).map_err(|err| clap::Error::raw(clap::ErrorKind::ValueValidation, format!("Invalid value for {}: {}", #name, err)))
        }
    }

    let flag = *attrs.parser().kind == ParserKind::FromFlag;
    let occurrences = *attrs.parser().kind == ParserKind::FromOccurrences;
    // Give this identifier the same hygiene
    // as the `arg_matches` parameter definition. This
    // allows us to refer to `arg_matches` within a `quote_spanned` block
    let arg_matches = format_ident!("__clap_arg_matches");

    let field_value = match **ty {
        // In update mode the current value is OR-ed with presence.
        Ty::Bool => {
            if update.is_some() {
                quote_spanned! { ty.span()=>
                    *#field_name || #arg_matches.is_present(#name)
                }
            } else {
                quote_spanned! { ty.span()=>
                    #arg_matches.is_present(#name)
                }
            }
        }

        // `None` when no value was given.
        Ty::Option => {
            quote_spanned! { ty.span()=>
                #arg_matches.#value_of(#name)
                    .map(#parse)
                    .transpose()?
            }
        }

        // `None` when absent, `Some(None)` when present without a value.
        Ty::OptionOption => quote_spanned! { ty.span()=>
            if #arg_matches.is_present(#name) {
                Some(#arg_matches.#value_of(#name).map(#parse).transpose()?)
            } else {
                None
            }
        },

        // `None` when absent, otherwise `Some` of all parsed values (possibly empty).
        Ty::OptionVec => quote_spanned! { ty.span()=>
            if #arg_matches.is_present(#name) {
                Some(#arg_matches.#values_of(#name)
                    .map(|v| v.map::<::std::result::Result<#convert_type, clap::Error>, _>(#parse).collect::<::std::result::Result<Vec<_>, clap::Error>>())
                    .transpose()?
                    .unwrap_or_else(Vec::new))
            } else {
                None
            }
        },

        // All parsed values; an empty `Vec` when the argument is absent.
        Ty::Vec => {
            quote_spanned! { ty.span()=>
                #arg_matches.#values_of(#name)
                    .map(|v| v.map::<::std::result::Result<#convert_type, clap::Error>, _>(#parse).collect::<::std::result::Result<Vec<_>, clap::Error>>())
                    .transpose()?
                    .unwrap_or_else(Vec::new)
            }
        }

        Ty::Other if occurrences => quote_spanned! { ty.span()=>
            #parse(#arg_matches.#value_of(#name))
        },

        Ty::Other if flag => quote_spanned! { ty.span()=>
            #parse(#arg_matches.is_present(#name))
        },

        // A missing value is a `MissingRequiredArgument` error.
        Ty::Other => {
            quote_spanned! { ty.span()=>
                #arg_matches.#value_of(#name)
                    .ok_or_else(|| clap::Error::raw(clap::ErrorKind::MissingRequiredArgument, format!("The following required argument was not provided: {}", #name)))
                    .and_then(#parse)?
            }
        }
    };

    // Update mode only overwrites the field when the argument was actually given.
    if let Some(access) = update {
        quote_spanned! { field.span()=>
            if #arg_matches.is_present(#name) {
                #access
                *#field_name = #field_value
            }
        }
    } else {
        quote_spanned!(field.span()=> #field_name: #field_value )
    }
}
650