1 #![recursion_limit = "256"]
2 // Copyright (c) 2020 Google LLC All rights reserved.
3 // Use of this source code is governed by a BSD-style
4 // license that can be found in the LICENSE file.
5 
//! Implementation of the `FromArgs` and `argh(...)` derive attributes.
//!
//! For more thorough documentation, see the `argh` crate itself.
9 extern crate proc_macro;
10 
11 use {
12     crate::{
13         errors::Errors,
14         parse_attrs::{FieldAttrs, FieldKind, TypeAttrs},
15     },
16     proc_macro2::{Span, TokenStream},
17     quote::{quote, quote_spanned, ToTokens},
18     std::str::FromStr,
19     syn::spanned::Spanned,
20 };
21 
22 mod errors;
23 mod help;
24 mod parse_attrs;
25 
26 /// Entrypoint for `#[derive(FromArgs)]`.
27 #[proc_macro_derive(FromArgs, attributes(argh))]
argh_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream28 pub fn argh_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
29     let ast = syn::parse_macro_input!(input as syn::DeriveInput);
30     let gen = impl_from_args(&ast);
31     gen.into()
32 }
33 
34 /// Transform the input into a token stream containing any generated implementations,
35 /// as well as all errors that occurred.
impl_from_args(input: &syn::DeriveInput) -> TokenStream36 fn impl_from_args(input: &syn::DeriveInput) -> TokenStream {
37     let errors = &Errors::default();
38     if input.generics.params.len() != 0 {
39         errors.err(
40             &input.generics,
41             "`#![derive(FromArgs)]` cannot be applied to types with generic parameters",
42         );
43     }
44     let type_attrs = &TypeAttrs::parse(errors, input);
45     let mut output_tokens = match &input.data {
46         syn::Data::Struct(ds) => impl_from_args_struct(errors, &input.ident, type_attrs, ds),
47         syn::Data::Enum(de) => impl_from_args_enum(errors, &input.ident, type_attrs, de),
48         syn::Data::Union(_) => {
49             errors.err(input, "`#[derive(FromArgs)]` cannot be applied to unions");
50             TokenStream::new()
51         }
52     };
53     errors.to_tokens(&mut output_tokens);
54     output_tokens
55 }
56 
/// The kind of optionality a parameter has.
///
/// Decides whether an absent argument is reported as a missing requirement and how the
/// field's final value is produced from its parsing slot.
enum Optionality {
    /// The argument must be supplied; an absent value is a missing requirement.
    None,
    /// The field has `#[argh(default = "...")]`. Holds the lexed tokens of the default
    /// expression, which is evaluated when the argument is absent.
    Defaulted(TokenStream),
    /// The field may be absent: an `Option<T>` option/positional/subcommand, or a switch.
    Optional,
    /// The field is a `Vec<T>` option or positional that collects every occurrence.
    Repeating,
}
64 
65 impl PartialEq<Optionality> for Optionality {
eq(&self, other: &Optionality) -> bool66     fn eq(&self, other: &Optionality) -> bool {
67         use Optionality::*;
68         match (self, other) {
69             (None, None) | (Optional, Optional) | (Repeating, Repeating) => true,
70             // NB: (Defaulted, Defaulted) can't contain the same token streams
71             _ => false,
72         }
73     }
74 }
75 
76 impl Optionality {
77     /// Whether or not this is `Optionality::None`
is_required(&self) -> bool78     fn is_required(&self) -> bool {
79         if let Optionality::None = self {
80             true
81         } else {
82             false
83         }
84     }
85 }
86 
/// A field of a `#[derive(FromArgs)]` struct with attributes and some other
/// notable metadata appended.
struct StructField<'a> {
    /// The original parsed field
    field: &'a syn::Field,
    /// The parsed attributes of the field
    attrs: FieldAttrs,
    /// The field name. This is contained optionally inside `field`,
    /// but is duplicated non-optionally here to indicate that all fields that
    /// have reached this point must have a field name, and it no longer
    /// needs to be unwrapped.
    name: &'a syn::Ident,
    /// Similar to `name` above, this is contained optionally inside `FieldAttrs`,
    /// but here is fully present to indicate that we only have to consider fields
    /// with a valid `kind` at this point.
    kind: FieldKind,
    /// The type a single argument value is parsed as.
    ///
    /// For `option`/`positional` fields without a `default`, this is `T` when `field.ty`
    /// is `Option<T>` or `Vec<T>`; for `subcommand` fields it is `T` when `field.ty` is
    /// `Option<T>`. In every other case it is `field.ty` itself.
    /// This is used to enable consistent parsing code between optional and non-optional
    /// keyed and subcommand fields.
    ty_without_wrapper: &'a syn::Type,
    /// Whether the field represents an optional value: a switch, an `Option` subcommand
    /// field, an `Option` or `Vec` option/positional, or any option/positional with a
    /// `default`.
    optionality: Optionality,
    /// The `--`-prefixed name of the argument. Present only for `option` and `switch` fields.
    long_name: Option<String>,
}
113 
114 impl<'a> StructField<'a> {
115     /// Attempts to parse a field of a `#[derive(FromArgs)]` struct, pulling out the
116     /// fields required for code generation.
new(errors: &Errors, field: &'a syn::Field, attrs: FieldAttrs) -> Option<Self>117     fn new(errors: &Errors, field: &'a syn::Field, attrs: FieldAttrs) -> Option<Self> {
118         let name = field.ident.as_ref().expect("missing ident for named field");
119 
120         // Ensure that one "kind" is present (switch, option, subcommand, positional)
121         let kind = if let Some(field_type) = &attrs.field_type {
122             field_type.kind
123         } else {
124             errors.err(
125                 field,
126                 concat!(
127                     "Missing `argh` field kind attribute.\n",
128                     "Expected one of: `switch`, `option`, `subcommand`, `positional`",
129                 ),
130             );
131             return None;
132         };
133 
134         // Parse out whether a field is optional (`Option` or `Vec`).
135         let optionality;
136         let ty_without_wrapper;
137         match kind {
138             FieldKind::Switch => {
139                 if !ty_expect_switch(errors, &field.ty) {
140                     return None;
141                 }
142                 optionality = Optionality::Optional;
143                 ty_without_wrapper = &field.ty;
144             }
145             FieldKind::Option | FieldKind::Positional => {
146                 if let Some(default) = &attrs.default {
147                     let tokens = match TokenStream::from_str(&default.value()) {
148                         Ok(tokens) => tokens,
149                         Err(_) => {
150                             errors.err(&default, "Invalid tokens: unable to lex `default` value");
151                             return None;
152                         }
153                     };
154                     // Set the span of the generated tokens to the string literal
155                     let tokens: TokenStream = tokens
156                         .into_iter()
157                         .map(|mut tree| {
158                             tree.set_span(default.span().clone());
159                             tree
160                         })
161                         .collect();
162                     optionality = Optionality::Defaulted(tokens);
163                     ty_without_wrapper = &field.ty;
164                 } else {
165                     let mut inner = None;
166                     optionality = if let Some(x) = ty_inner(&["Option"], &field.ty) {
167                         inner = Some(x);
168                         Optionality::Optional
169                     } else if let Some(x) = ty_inner(&["Vec"], &field.ty) {
170                         inner = Some(x);
171                         Optionality::Repeating
172                     } else {
173                         Optionality::None
174                     };
175                     ty_without_wrapper = inner.unwrap_or(&field.ty);
176                 }
177             }
178             FieldKind::SubCommand => {
179                 let inner = ty_inner(&["Option"], &field.ty);
180                 optionality =
181                     if inner.is_some() { Optionality::Optional } else { Optionality::None };
182                 ty_without_wrapper = inner.unwrap_or(&field.ty);
183             }
184         }
185 
186         // Determine the "long" name of options and switches.
187         // Defaults to the kebab-case'd field name if `#[argh(long = "...")]` is omitted.
188         let long_name = match kind {
189             FieldKind::Switch | FieldKind::Option => {
190                 let long_name = attrs
191                     .long
192                     .as_ref()
193                     .map(syn::LitStr::value)
194                     .unwrap_or_else(|| heck::KebabCase::to_kebab_case(&*name.to_string()));
195                 if long_name == "help" {
196                     errors.err(field, "Custom `--help` flags are not supported.");
197                 }
198                 let long_name = format!("--{}", long_name);
199                 Some(long_name)
200             }
201             FieldKind::SubCommand | FieldKind::Positional => None,
202         };
203 
204         Some(StructField { field, attrs, kind, optionality, ty_without_wrapper, name, long_name })
205     }
206 }
207 
208 /// Implements `FromArgs` and `TopLevelCommand` or `SubCommand` for a `#[derive(FromArgs)]` struct.
impl_from_args_struct( errors: &Errors, name: &syn::Ident, type_attrs: &TypeAttrs, ds: &syn::DataStruct, ) -> TokenStream209 fn impl_from_args_struct(
210     errors: &Errors,
211     name: &syn::Ident,
212     type_attrs: &TypeAttrs,
213     ds: &syn::DataStruct,
214 ) -> TokenStream {
215     let fields = match &ds.fields {
216         syn::Fields::Named(fields) => fields,
217         syn::Fields::Unnamed(_) => {
218             errors.err(
219                 &ds.struct_token,
220                 "`#![derive(FromArgs)]` is not currently supported on tuple structs",
221             );
222             return TokenStream::new();
223         }
224         syn::Fields::Unit => {
225             errors.err(&ds.struct_token, "#![derive(FromArgs)]` cannot be applied to unit structs");
226             return TokenStream::new();
227         }
228     };
229 
230     let fields: Vec<_> = fields
231         .named
232         .iter()
233         .filter_map(|field| {
234             let attrs = FieldAttrs::parse(errors, field);
235             StructField::new(errors, field, attrs)
236         })
237         .collect();
238 
239     ensure_only_last_positional_is_optional(errors, &fields);
240 
241     let impl_span = Span::call_site();
242 
243     let from_args_method = impl_from_args_struct_from_args(errors, type_attrs, &fields);
244 
245     let redact_arg_values_method =
246         impl_from_args_struct_redact_arg_values(errors, type_attrs, &fields);
247 
248     let top_or_sub_cmd_impl = top_or_sub_cmd_impl(errors, name, type_attrs);
249 
250     let trait_impl = quote_spanned! { impl_span =>
251         impl argh::FromArgs for #name {
252             #from_args_method
253 
254             #redact_arg_values_method
255         }
256 
257         #top_or_sub_cmd_impl
258     };
259 
260     trait_impl.into()
261 }
262 
impl_from_args_struct_from_args<'a>( errors: &Errors, type_attrs: &TypeAttrs, fields: &'a [StructField<'a>], ) -> TokenStream263 fn impl_from_args_struct_from_args<'a>(
264     errors: &Errors,
265     type_attrs: &TypeAttrs,
266     fields: &'a [StructField<'a>],
267 ) -> TokenStream {
268     let init_fields = declare_local_storage_for_from_args_fields(&fields);
269     let unwrap_fields = unwrap_from_args_fields(&fields);
270     let positional_fields: Vec<&StructField<'_>> =
271         fields.iter().filter(|field| field.kind == FieldKind::Positional).collect();
272     let positional_field_idents = positional_fields.iter().map(|field| &field.field.ident);
273     let positional_field_names = positional_fields.iter().map(|field| field.name.to_string());
274     let last_positional_is_repeating = positional_fields
275         .last()
276         .map(|field| field.optionality == Optionality::Repeating)
277         .unwrap_or(false);
278 
279     let flag_output_table = fields.iter().filter_map(|field| {
280         let field_name = &field.field.ident;
281         match field.kind {
282             FieldKind::Option => Some(quote! { argh::ParseStructOption::Value(&mut #field_name) }),
283             FieldKind::Switch => Some(quote! { argh::ParseStructOption::Flag(&mut #field_name) }),
284             FieldKind::SubCommand | FieldKind::Positional => None,
285         }
286     });
287 
288     let flag_str_to_output_table_map = flag_str_to_output_table_map_entries(&fields);
289 
290     let mut subcommands_iter =
291         fields.iter().filter(|field| field.kind == FieldKind::SubCommand).fuse();
292 
293     let subcommand: Option<&StructField<'_>> = subcommands_iter.next();
294     while let Some(dup_subcommand) = subcommands_iter.next() {
295         errors.duplicate_attrs("subcommand", subcommand.unwrap().field, dup_subcommand.field);
296     }
297 
298     let impl_span = Span::call_site();
299 
300     let missing_requirements_ident = syn::Ident::new("__missing_requirements", impl_span.clone());
301 
302     let append_missing_requirements =
303         append_missing_requirements(&missing_requirements_ident, &fields);
304 
305     let parse_subcommands = if let Some(subcommand) = subcommand {
306         let name = subcommand.name;
307         let ty = subcommand.ty_without_wrapper;
308         quote_spanned! { impl_span =>
309             Some(argh::ParseStructSubCommand {
310                 subcommands: <#ty as argh::SubCommands>::COMMANDS,
311                 parse_func: &mut |__command, __remaining_args| {
312                     #name = Some(<#ty as argh::FromArgs>::from_args(__command, __remaining_args)?);
313                     Ok(())
314                 },
315             })
316         }
317     } else {
318         quote_spanned! { impl_span => None }
319     };
320 
321     // Identifier referring to a value containing the name of the current command as an `&[&str]`.
322     let cmd_name_str_array_ident = syn::Ident::new("__cmd_name", impl_span.clone());
323     let help = help::help(errors, cmd_name_str_array_ident, type_attrs, &fields, subcommand);
324 
325     let method_impl = quote_spanned! { impl_span =>
326         fn from_args(__cmd_name: &[&str], __args: &[&str])
327             -> std::result::Result<Self, argh::EarlyExit>
328         {
329             #( #init_fields )*
330 
331             argh::parse_struct_args(
332                 __cmd_name,
333                 __args,
334                 argh::ParseStructOptions {
335                     arg_to_slot: &[ #( #flag_str_to_output_table_map ,)* ],
336                     slots: &mut [ #( #flag_output_table, )* ],
337                 },
338                 argh::ParseStructPositionals {
339                     positionals: &mut [
340                         #(
341                             argh::ParseStructPositional {
342                                 name: #positional_field_names,
343                                 slot: &mut #positional_field_idents as &mut argh::ParseValueSlot,
344                             },
345                         )*
346                     ],
347                     last_is_repeating: #last_positional_is_repeating,
348                 },
349                 #parse_subcommands,
350                 &|| #help,
351             )?;
352 
353             let mut #missing_requirements_ident = argh::MissingRequirements::default();
354             #(
355                 #append_missing_requirements
356             )*
357             #missing_requirements_ident.err_on_any()?;
358 
359             Ok(Self {
360                 #( #unwrap_fields, )*
361             })
362         }
363     };
364 
365     method_impl.into()
366 }
367 
impl_from_args_struct_redact_arg_values<'a>( errors: &Errors, type_attrs: &TypeAttrs, fields: &'a [StructField<'a>], ) -> TokenStream368 fn impl_from_args_struct_redact_arg_values<'a>(
369     errors: &Errors,
370     type_attrs: &TypeAttrs,
371     fields: &'a [StructField<'a>],
372 ) -> TokenStream {
373     let init_fields = declare_local_storage_for_redacted_fields(&fields);
374     let unwrap_fields = unwrap_redacted_fields(&fields);
375 
376     let positional_fields: Vec<&StructField<'_>> =
377         fields.iter().filter(|field| field.kind == FieldKind::Positional).collect();
378     let positional_field_idents = positional_fields.iter().map(|field| &field.field.ident);
379     let positional_field_names = positional_fields.iter().map(|field| field.name.to_string());
380     let last_positional_is_repeating = positional_fields
381         .last()
382         .map(|field| field.optionality == Optionality::Repeating)
383         .unwrap_or(false);
384 
385     let flag_output_table = fields.iter().filter_map(|field| {
386         let field_name = &field.field.ident;
387         match field.kind {
388             FieldKind::Option => Some(quote! { argh::ParseStructOption::Value(&mut #field_name) }),
389             FieldKind::Switch => Some(quote! { argh::ParseStructOption::Flag(&mut #field_name) }),
390             FieldKind::SubCommand | FieldKind::Positional => None,
391         }
392     });
393 
394     let flag_str_to_output_table_map = flag_str_to_output_table_map_entries(&fields);
395 
396     let mut subcommands_iter =
397         fields.iter().filter(|field| field.kind == FieldKind::SubCommand).fuse();
398 
399     let subcommand: Option<&StructField<'_>> = subcommands_iter.next();
400     while let Some(dup_subcommand) = subcommands_iter.next() {
401         errors.duplicate_attrs("subcommand", subcommand.unwrap().field, dup_subcommand.field);
402     }
403 
404     let impl_span = Span::call_site();
405 
406     let missing_requirements_ident = syn::Ident::new("__missing_requirements", impl_span.clone());
407 
408     let append_missing_requirements =
409         append_missing_requirements(&missing_requirements_ident, &fields);
410 
411     let redact_subcommands = if let Some(subcommand) = subcommand {
412         let name = subcommand.name;
413         let ty = subcommand.ty_without_wrapper;
414         quote_spanned! { impl_span =>
415             Some(argh::ParseStructSubCommand {
416                 subcommands: <#ty as argh::SubCommands>::COMMANDS,
417                 parse_func: &mut |__command, __remaining_args| {
418                     #name = Some(<#ty as argh::FromArgs>::redact_arg_values(__command, __remaining_args)?);
419                     Ok(())
420                 },
421             })
422         }
423     } else {
424         quote_spanned! { impl_span => None }
425     };
426 
427     let cmd_name = if type_attrs.is_subcommand.is_none() {
428         quote! { __cmd_name.last().expect("no command name").to_string() }
429     } else {
430         quote! { __cmd_name.last().expect("no subcommand name").to_string() }
431     };
432 
433     // Identifier referring to a value containing the name of the current command as an `&[&str]`.
434     let cmd_name_str_array_ident = syn::Ident::new("__cmd_name", impl_span.clone());
435     let help = help::help(errors, cmd_name_str_array_ident, type_attrs, &fields, subcommand);
436 
437     let method_impl = quote_spanned! { impl_span =>
438         fn redact_arg_values(__cmd_name: &[&str], __args: &[&str]) -> Result<Vec<String>, argh::EarlyExit> {
439             #( #init_fields )*
440 
441             argh::parse_struct_args(
442                 __cmd_name,
443                 __args,
444                 argh::ParseStructOptions {
445                     arg_to_slot: &[ #( #flag_str_to_output_table_map ,)* ],
446                     slots: &mut [ #( #flag_output_table, )* ],
447                 },
448                 argh::ParseStructPositionals {
449                     positionals: &mut [
450                         #(
451                             argh::ParseStructPositional {
452                                 name: #positional_field_names,
453                                 slot: &mut #positional_field_idents as &mut argh::ParseValueSlot,
454                             },
455                         )*
456                     ],
457                     last_is_repeating: #last_positional_is_repeating,
458                 },
459                 #redact_subcommands,
460                 &|| #help,
461             )?;
462 
463             let mut #missing_requirements_ident = argh::MissingRequirements::default();
464             #(
465                 #append_missing_requirements
466             )*
467             #missing_requirements_ident.err_on_any()?;
468 
469             let mut __redacted = vec![
470                 #cmd_name,
471             ];
472 
473             #( #unwrap_fields )*
474 
475             Ok(__redacted)
476         }
477     };
478 
479     method_impl.into()
480 }
481 
482 /// Ensures that only the last positional arg is non-required.
ensure_only_last_positional_is_optional(errors: &Errors, fields: &[StructField<'_>])483 fn ensure_only_last_positional_is_optional(errors: &Errors, fields: &[StructField<'_>]) {
484     let mut first_non_required_span = None;
485     for field in fields {
486         if field.kind == FieldKind::Positional {
487             if let Some(first) = first_non_required_span {
488                 errors.err_span(
489                     first,
490                     "Only the last positional argument may be `Option`, `Vec`, or defaulted.",
491                 );
492                 errors.err(&field.field, "Later positional argument declared here.");
493                 return;
494             }
495             if !field.optionality.is_required() {
496                 first_non_required_span = Some(field.field.span());
497             }
498         }
499     }
500 }
501 
502 /// Implement `argh::TopLevelCommand` or `argh::SubCommand` as appropriate.
top_or_sub_cmd_impl(errors: &Errors, name: &syn::Ident, type_attrs: &TypeAttrs) -> TokenStream503 fn top_or_sub_cmd_impl(errors: &Errors, name: &syn::Ident, type_attrs: &TypeAttrs) -> TokenStream {
504     let description =
505         help::require_description(errors, name.span(), &type_attrs.description, "type");
506     if type_attrs.is_subcommand.is_none() {
507         // Not a subcommand
508         quote! {
509             impl argh::TopLevelCommand for #name {}
510         }
511     } else {
512         let empty_str = syn::LitStr::new("", Span::call_site());
513         let subcommand_name = type_attrs.name.as_ref().unwrap_or_else(|| {
514             errors.err(name, "`#[argh(name = \"...\")]` attribute is required for subcommands");
515             &empty_str
516         });
517         quote! {
518             impl argh::SubCommand for #name {
519                 const COMMAND: &'static argh::CommandInfo = &argh::CommandInfo {
520                     name: #subcommand_name,
521                     description: #description,
522                 };
523             }
524         }
525     }
526 }
527 
528 /// Declare a local slots to store each field in during parsing.
529 ///
530 /// Most fields are stored in `Option<FieldType>` locals.
531 /// `argh(option)` fields are stored in a `ParseValueSlotTy` along with a
532 /// function that knows how to decode the appropriate value.
declare_local_storage_for_from_args_fields<'a>( fields: &'a [StructField<'a>], ) -> impl Iterator<Item = TokenStream> + 'a533 fn declare_local_storage_for_from_args_fields<'a>(
534     fields: &'a [StructField<'a>],
535 ) -> impl Iterator<Item = TokenStream> + 'a {
536     fields.iter().map(|field| {
537         let field_name = &field.field.ident;
538         let field_type = &field.ty_without_wrapper;
539 
540         // Wrap field types in `Option` if they aren't already `Option` or `Vec`-wrapped.
541         let field_slot_type = match field.optionality {
542             Optionality::Optional | Optionality::Repeating => (&field.field.ty).into_token_stream(),
543             Optionality::None | Optionality::Defaulted(_) => {
544                 quote! { std::option::Option<#field_type> }
545             }
546         };
547 
548         match field.kind {
549             FieldKind::Option | FieldKind::Positional => {
550                 let from_str_fn = match &field.attrs.from_str_fn {
551                     Some(from_str_fn) => from_str_fn.into_token_stream(),
552                     None => {
553                         quote! {
554                             <#field_type as argh::FromArgValue>::from_arg_value
555                         }
556                     }
557                 };
558 
559                 quote! {
560                     let mut #field_name: argh::ParseValueSlotTy<#field_slot_type, #field_type>
561                         = argh::ParseValueSlotTy {
562                             slot: std::default::Default::default(),
563                             parse_func: |_, value| { #from_str_fn(value) },
564                         };
565                 }
566             }
567             FieldKind::SubCommand => {
568                 quote! { let mut #field_name: #field_slot_type = None; }
569             }
570             FieldKind::Switch => {
571                 quote! { let mut #field_name: #field_slot_type = argh::Flag::default(); }
572             }
573         }
574     })
575 }
576 
577 /// Unwrap non-optional fields and take options out of their tuple slots.
unwrap_from_args_fields<'a>( fields: &'a [StructField<'a>], ) -> impl Iterator<Item = TokenStream> + 'a578 fn unwrap_from_args_fields<'a>(
579     fields: &'a [StructField<'a>],
580 ) -> impl Iterator<Item = TokenStream> + 'a {
581     fields.iter().map(|field| {
582         let field_name = field.name;
583         match field.kind {
584             FieldKind::Option | FieldKind::Positional => match &field.optionality {
585                 Optionality::None => quote! { #field_name: #field_name.slot.unwrap() },
586                 Optionality::Optional | Optionality::Repeating => {
587                     quote! { #field_name: #field_name.slot }
588                 }
589                 Optionality::Defaulted(tokens) => {
590                     quote! {
591                         #field_name: #field_name.slot.unwrap_or_else(|| #tokens)
592                     }
593                 }
594             },
595             FieldKind::Switch => field_name.into_token_stream(),
596             FieldKind::SubCommand => match field.optionality {
597                 Optionality::None => quote! { #field_name: #field_name.unwrap() },
598                 Optionality::Optional | Optionality::Repeating => field_name.into_token_stream(),
599                 Optionality::Defaulted(_) => unreachable!(),
600             },
601         }
602     })
603 }
604 
605 /// Declare a local slots to store each field in during parsing.
606 ///
607 /// Most fields are stored in `Option<FieldType>` locals.
608 /// `argh(option)` fields are stored in a `ParseValueSlotTy` along with a
609 /// function that knows how to decode the appropriate value.
declare_local_storage_for_redacted_fields<'a>( fields: &'a [StructField<'a>], ) -> impl Iterator<Item = TokenStream> + 'a610 fn declare_local_storage_for_redacted_fields<'a>(
611     fields: &'a [StructField<'a>],
612 ) -> impl Iterator<Item = TokenStream> + 'a {
613     fields.iter().map(|field| {
614         let field_name = &field.field.ident;
615 
616         match field.kind {
617             FieldKind::Switch => {
618                 quote! {
619                     let mut #field_name = argh::RedactFlag {
620                         slot: None,
621                     };
622                 }
623             }
624             FieldKind::Option => {
625                 quote! {
626                     let mut #field_name: argh::ParseValueSlotTy::<Option<String>, String> =
627                         argh::ParseValueSlotTy {
628                         slot: std::default::Default::default(),
629                         parse_func: |arg, _| { Ok(arg.to_string()) },
630                     };
631                 }
632             }
633             FieldKind::Positional => {
634                 let field_slot_type = match field.optionality {
635                     Optionality::Repeating => {
636                         quote! { std::vec::Vec<String> }
637                     }
638                     Optionality::None | Optionality::Optional | Optionality::Defaulted(_) => {
639                         quote! { std::option::Option<String> }
640                     }
641                 };
642 
643                 let long_name = field.name.to_string();
644                 quote! {
645                     let mut #field_name: argh::ParseValueSlotTy::<#field_slot_type, String> =
646                         argh::ParseValueSlotTy {
647                         slot: std::default::Default::default(),
648                         parse_func: |_, _| { Ok(#long_name.to_string()) },
649                     };
650                 }
651             }
652             FieldKind::SubCommand => {
653                 quote! { let mut #field_name: std::option::Option<std::vec::Vec<String>> = None; }
654             }
655         }
656     })
657 }
658 
659 /// Unwrap non-optional fields and take options out of their tuple slots.
unwrap_redacted_fields<'a>( fields: &'a [StructField<'a>], ) -> impl Iterator<Item = TokenStream> + 'a660 fn unwrap_redacted_fields<'a>(
661     fields: &'a [StructField<'a>],
662 ) -> impl Iterator<Item = TokenStream> + 'a {
663     fields.iter().map(|field| {
664         let field_name = field.name;
665 
666         match field.kind {
667             FieldKind::Switch | FieldKind::Option => {
668                 quote! {
669                     if let Some(__field_name) = #field_name.slot {
670                         __redacted.push(__field_name);
671                     }
672                 }
673             }
674             FieldKind::Positional => {
675                 quote! {
676                     __redacted.extend(#field_name.slot.into_iter());
677                 }
678             }
679             FieldKind::SubCommand => {
680                 quote! {
681                     if let Some(__subcommand_args) = #field_name {
682                         __redacted.extend(__subcommand_args.into_iter());
683                     }
684                 }
685             }
686         }
687     })
688 }
689 
690 /// Entries of tokens like `("--some-flag-key", 5)` that map from a flag key string
691 /// to an index in the output table.
flag_str_to_output_table_map_entries<'a>(fields: &'a [StructField<'a>]) -> Vec<TokenStream>692 fn flag_str_to_output_table_map_entries<'a>(fields: &'a [StructField<'a>]) -> Vec<TokenStream> {
693     let mut flag_str_to_output_table_map = vec![];
694     for (i, (field, long_name)) in fields
695         .iter()
696         .filter_map(|field| field.long_name.as_ref().map(|long_name| (field, long_name)))
697         .enumerate()
698     {
699         if let Some(short) = &field.attrs.short {
700             let short = format!("-{}", short.value());
701             flag_str_to_output_table_map.push(quote! { (#short, #i) });
702         }
703 
704         flag_str_to_output_table_map.push(quote! { (#long_name, #i) });
705     }
706     flag_str_to_output_table_map
707 }
708 
709 /// For each non-optional field, add an entry to the `argh::MissingRequirements`.
append_missing_requirements<'a>( mri: &syn::Ident, fields: &'a [StructField<'a>], ) -> impl Iterator<Item = TokenStream> + 'a710 fn append_missing_requirements<'a>(
711     // missing_requirements_ident
712     mri: &syn::Ident,
713     fields: &'a [StructField<'a>],
714 ) -> impl Iterator<Item = TokenStream> + 'a {
715     let mri = mri.clone();
716     fields.iter().filter(|f| f.optionality.is_required()).map(move |field| {
717         let field_name = field.name;
718         match field.kind {
719             FieldKind::Switch => unreachable!("switches are always optional"),
720             FieldKind::Positional => {
721                 let name = field.name.to_string();
722                 quote! {
723                     if #field_name.slot.is_none() {
724                         #mri.missing_positional_arg(#name)
725                     }
726                 }
727             }
728             FieldKind::Option => {
729                 let name = field.long_name.as_ref().expect("options always have a long name");
730                 quote! {
731                     if #field_name.slot.is_none() {
732                         #mri.missing_option(#name)
733                     }
734                 }
735             }
736             FieldKind::SubCommand => {
737                 let ty = field.ty_without_wrapper;
738                 quote! {
739                     if #field_name.is_none() {
740                         #mri.missing_subcommands(
741                             <#ty as argh::SubCommands>::COMMANDS,
742                         )
743                     }
744                 }
745             }
746         }
747     })
748 }
749 
750 /// Require that a type can be a `switch`.
751 /// Throws an error for all types except booleans and integers
ty_expect_switch(errors: &Errors, ty: &syn::Type) -> bool752 fn ty_expect_switch(errors: &Errors, ty: &syn::Type) -> bool {
753     fn ty_can_be_switch(ty: &syn::Type) -> bool {
754         if let syn::Type::Path(path) = ty {
755             if path.qself.is_some() {
756                 return false;
757             }
758             if path.path.segments.len() != 1 {
759                 return false;
760             }
761             let ident = &path.path.segments[0].ident;
762             ["bool", "u8", "u16", "u32", "u64", "u128", "i8", "i16", "i32", "i64", "i128"]
763                 .iter()
764                 .any(|path| ident == path)
765         } else {
766             false
767         }
768     }
769 
770     let res = ty_can_be_switch(ty);
771     if !res {
772         errors.err(ty, "switches must be of type `bool` or integer type");
773     }
774     res
775 }
776 
777 /// Returns `Some(T)` if a type is `wrapper_name<T>` for any `wrapper_name` in `wrapper_names`.
ty_inner<'a>(wrapper_names: &[&str], ty: &'a syn::Type) -> Option<&'a syn::Type>778 fn ty_inner<'a>(wrapper_names: &[&str], ty: &'a syn::Type) -> Option<&'a syn::Type> {
779     if let syn::Type::Path(path) = ty {
780         if path.qself.is_some() {
781             return None;
782         }
783         // Since we only check the last path segment, it isn't necessarily the case that
784         // we're referring to `std::vec::Vec` or `std::option::Option`, but there isn't
785         // a fool proof way to check these since name resolution happens after macro expansion,
786         // so this is likely "good enough" (so long as people don't have their own types called
787         // `Option` or `Vec` that take one generic parameter they're looking to parse).
788         let last_segment = path.path.segments.last()?;
789         if !wrapper_names.iter().any(|name| last_segment.ident == *name) {
790             return None;
791         }
792         if let syn::PathArguments::AngleBracketed(gen_args) = &last_segment.arguments {
793             let generic_arg = gen_args.args.first()?;
794             if let syn::GenericArgument::Type(ty) = &generic_arg {
795                 return Some(ty);
796             }
797         }
798     }
799     None
800 }
801 
802 /// Implements `FromArgs` and `SubCommands` for a `#![derive(FromArgs)]` enum.
impl_from_args_enum( errors: &Errors, name: &syn::Ident, type_attrs: &TypeAttrs, de: &syn::DataEnum, ) -> TokenStream803 fn impl_from_args_enum(
804     errors: &Errors,
805     name: &syn::Ident,
806     type_attrs: &TypeAttrs,
807     de: &syn::DataEnum,
808 ) -> TokenStream {
809     parse_attrs::check_enum_type_attrs(errors, type_attrs, &de.enum_token.span);
810 
811     // An enum variant like `<name>(<ty>)`
812     struct SubCommandVariant<'a> {
813         name: &'a syn::Ident,
814         ty: &'a syn::Type,
815     }
816 
817     let variants: Vec<SubCommandVariant<'_>> = de
818         .variants
819         .iter()
820         .filter_map(|variant| {
821             parse_attrs::check_enum_variant_attrs(errors, variant);
822             let name = &variant.ident;
823             let ty = enum_only_single_field_unnamed_variants(errors, &variant.fields)?;
824             Some(SubCommandVariant { name, ty })
825         })
826         .collect();
827 
828     let name_repeating = std::iter::repeat(name.clone());
829     let variant_ty = variants.iter().map(|x| x.ty).collect::<Vec<_>>();
830     let variant_names = variants.iter().map(|x| x.name).collect::<Vec<_>>();
831 
832     quote! {
833         impl argh::FromArgs for #name {
834             fn from_args(command_name: &[&str], args: &[&str])
835                 -> std::result::Result<Self, argh::EarlyExit>
836             {
837                 let subcommand_name = *command_name.last().expect("no subcommand name");
838                 #(
839                     if subcommand_name == <#variant_ty as argh::SubCommand>::COMMAND.name {
840                         return Ok(#name_repeating::#variant_names(
841                             <#variant_ty as argh::FromArgs>::from_args(command_name, args)?
842                         ));
843                     }
844                 )*
845                 unreachable!("no subcommand matched")
846             }
847 
848             fn redact_arg_values(command_name: &[&str], args: &[&str]) -> std::result::Result<Vec<String>, argh::EarlyExit> {
849                 let subcommand_name = *command_name.last().expect("no subcommand name");
850                 #(
851                     if subcommand_name == <#variant_ty as argh::SubCommand>::COMMAND.name {
852                         return <#variant_ty as argh::FromArgs>::redact_arg_values(command_name, args);
853                     }
854                 )*
855                 unreachable!("no subcommand matched")
856             }
857         }
858 
859         impl argh::SubCommands for #name {
860             const COMMANDS: &'static [&'static argh::CommandInfo] = &[#(
861                 <#variant_ty as argh::SubCommand>::COMMAND,
862             )*];
863         }
864     }
865 }
866 
867 /// Returns `Some(Bar)` if the field is a single-field unnamed variant like `Foo(Bar)`.
868 /// Otherwise, generates an error.
enum_only_single_field_unnamed_variants<'a>( errors: &Errors, variant_fields: &'a syn::Fields, ) -> Option<&'a syn::Type>869 fn enum_only_single_field_unnamed_variants<'a>(
870     errors: &Errors,
871     variant_fields: &'a syn::Fields,
872 ) -> Option<&'a syn::Type> {
873     macro_rules! with_enum_suggestion {
874         ($help_text:literal) => {
875             concat!(
876                 $help_text,
877                 "\nInstead, use a variant with a single unnamed field for each subcommand:\n",
878                 "    enum MyCommandEnum {\n",
879                 "        SubCommandOne(SubCommandOne),\n",
880                 "        SubCommandTwo(SubCommandTwo),\n",
881                 "    }",
882             )
883         };
884     }
885 
886     match variant_fields {
887         syn::Fields::Named(fields) => {
888             errors.err(
889                 fields,
890                 with_enum_suggestion!(
891                     "`#![derive(FromArgs)]` `enum`s do not support variants with named fields."
892                 ),
893             );
894             None
895         }
896         syn::Fields::Unit => {
897             errors.err(
898                 variant_fields,
899                 with_enum_suggestion!(
900                     "`#![derive(FromArgs)]` does not support `enum`s with no variants."
901                 ),
902             );
903             None
904         }
905         syn::Fields::Unnamed(fields) => {
906             if fields.unnamed.len() != 1 {
907                 errors.err(
908                     fields,
909                     with_enum_suggestion!(
910                         "`#![derive(FromArgs)]` `enum` variants must only contain one field."
911                     ),
912                 );
913                 None
914             } else {
915                 // `unwrap` is okay because of the length check above.
916                 let first_field = fields.unnamed.first().unwrap();
917                 Some(&first_field.ty)
918             }
919         }
920     }
921 }
922