1 #![recursion_limit = "256"]
2 // Copyright (c) 2020 Google LLC All rights reserved.
3 // Use of this source code is governed by a BSD-style
4 // license that can be found in the LICENSE file.
5
6 /// Implementation of the `FromArgs` and `argh(...)` derive attributes.
7 ///
8 /// For more thorough documentation, see the `argh` crate itself.
9 extern crate proc_macro;
10
11 use {
12 crate::{
13 errors::Errors,
14 parse_attrs::{FieldAttrs, FieldKind, TypeAttrs},
15 },
16 proc_macro2::{Span, TokenStream},
17 quote::{quote, quote_spanned, ToTokens},
18 std::str::FromStr,
19 syn::spanned::Spanned,
20 };
21
22 mod errors;
23 mod help;
24 mod parse_attrs;
25
26 /// Entrypoint for `#[derive(FromArgs)]`.
27 #[proc_macro_derive(FromArgs, attributes(argh))]
argh_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream28 pub fn argh_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
29 let ast = syn::parse_macro_input!(input as syn::DeriveInput);
30 let gen = impl_from_args(&ast);
31 gen.into()
32 }
33
34 /// Transform the input into a token stream containing any generated implementations,
35 /// as well as all errors that occurred.
impl_from_args(input: &syn::DeriveInput) -> TokenStream36 fn impl_from_args(input: &syn::DeriveInput) -> TokenStream {
37 let errors = &Errors::default();
38 if input.generics.params.len() != 0 {
39 errors.err(
40 &input.generics,
41 "`#![derive(FromArgs)]` cannot be applied to types with generic parameters",
42 );
43 }
44 let type_attrs = &TypeAttrs::parse(errors, input);
45 let mut output_tokens = match &input.data {
46 syn::Data::Struct(ds) => impl_from_args_struct(errors, &input.ident, type_attrs, ds),
47 syn::Data::Enum(de) => impl_from_args_enum(errors, &input.ident, type_attrs, de),
48 syn::Data::Union(_) => {
49 errors.err(input, "`#[derive(FromArgs)]` cannot be applied to unions");
50 TokenStream::new()
51 }
52 };
53 errors.to_tokens(&mut output_tokens);
54 output_tokens
55 }
56
57 /// The kind of optionality a parameter has.
58 enum Optionality {
59 None,
60 Defaulted(TokenStream),
61 Optional,
62 Repeating,
63 }
64
65 impl PartialEq<Optionality> for Optionality {
eq(&self, other: &Optionality) -> bool66 fn eq(&self, other: &Optionality) -> bool {
67 use Optionality::*;
68 match (self, other) {
69 (None, None) | (Optional, Optional) | (Repeating, Repeating) => true,
70 // NB: (Defaulted, Defaulted) can't contain the same token streams
71 _ => false,
72 }
73 }
74 }
75
76 impl Optionality {
77 /// Whether or not this is `Optionality::None`
is_required(&self) -> bool78 fn is_required(&self) -> bool {
79 if let Optionality::None = self {
80 true
81 } else {
82 false
83 }
84 }
85 }
86
87 /// A field of a `#![derive(FromArgs)]` struct with attributes and some other
88 /// notable metadata appended.
89 struct StructField<'a> {
90 /// The original parsed field
91 field: &'a syn::Field,
92 /// The parsed attributes of the field
93 attrs: FieldAttrs,
94 /// The field name. This is contained optionally inside `field`,
95 /// but is duplicated non-optionally here to indicate that all field that
96 /// have reached this point must have a field name, and it no longer
97 /// needs to be unwrapped.
98 name: &'a syn::Ident,
99 /// Similar to `name` above, this is contained optionally inside `FieldAttrs`,
100 /// but here is fully present to indicate that we only have to consider fields
101 /// with a valid `kind` at this point.
102 kind: FieldKind,
103 // If `field.ty` is `Vec<T>` or `Option<T>`, this is `T`, otherwise it's `&field.ty`.
104 // This is used to enable consistent parsing code between optional and non-optional
105 // keyed and subcommand fields.
106 ty_without_wrapper: &'a syn::Type,
107 // Whether the field represents an optional value, such as an `Option` subcommand field
108 // or an `Option` or `Vec` keyed argument, or if it has a `default`.
109 optionality: Optionality,
110 // The `--`-prefixed name of the option, if one exists.
111 long_name: Option<String>,
112 }
113
114 impl<'a> StructField<'a> {
115 /// Attempts to parse a field of a `#[derive(FromArgs)]` struct, pulling out the
116 /// fields required for code generation.
new(errors: &Errors, field: &'a syn::Field, attrs: FieldAttrs) -> Option<Self>117 fn new(errors: &Errors, field: &'a syn::Field, attrs: FieldAttrs) -> Option<Self> {
118 let name = field.ident.as_ref().expect("missing ident for named field");
119
120 // Ensure that one "kind" is present (switch, option, subcommand, positional)
121 let kind = if let Some(field_type) = &attrs.field_type {
122 field_type.kind
123 } else {
124 errors.err(
125 field,
126 concat!(
127 "Missing `argh` field kind attribute.\n",
128 "Expected one of: `switch`, `option`, `subcommand`, `positional`",
129 ),
130 );
131 return None;
132 };
133
134 // Parse out whether a field is optional (`Option` or `Vec`).
135 let optionality;
136 let ty_without_wrapper;
137 match kind {
138 FieldKind::Switch => {
139 if !ty_expect_switch(errors, &field.ty) {
140 return None;
141 }
142 optionality = Optionality::Optional;
143 ty_without_wrapper = &field.ty;
144 }
145 FieldKind::Option | FieldKind::Positional => {
146 if let Some(default) = &attrs.default {
147 let tokens = match TokenStream::from_str(&default.value()) {
148 Ok(tokens) => tokens,
149 Err(_) => {
150 errors.err(&default, "Invalid tokens: unable to lex `default` value");
151 return None;
152 }
153 };
154 // Set the span of the generated tokens to the string literal
155 let tokens: TokenStream = tokens
156 .into_iter()
157 .map(|mut tree| {
158 tree.set_span(default.span().clone());
159 tree
160 })
161 .collect();
162 optionality = Optionality::Defaulted(tokens);
163 ty_without_wrapper = &field.ty;
164 } else {
165 let mut inner = None;
166 optionality = if let Some(x) = ty_inner(&["Option"], &field.ty) {
167 inner = Some(x);
168 Optionality::Optional
169 } else if let Some(x) = ty_inner(&["Vec"], &field.ty) {
170 inner = Some(x);
171 Optionality::Repeating
172 } else {
173 Optionality::None
174 };
175 ty_without_wrapper = inner.unwrap_or(&field.ty);
176 }
177 }
178 FieldKind::SubCommand => {
179 let inner = ty_inner(&["Option"], &field.ty);
180 optionality =
181 if inner.is_some() { Optionality::Optional } else { Optionality::None };
182 ty_without_wrapper = inner.unwrap_or(&field.ty);
183 }
184 }
185
186 // Determine the "long" name of options and switches.
187 // Defaults to the kebab-case'd field name if `#[argh(long = "...")]` is omitted.
188 let long_name = match kind {
189 FieldKind::Switch | FieldKind::Option => {
190 let long_name = attrs
191 .long
192 .as_ref()
193 .map(syn::LitStr::value)
194 .unwrap_or_else(|| heck::KebabCase::to_kebab_case(&*name.to_string()));
195 if long_name == "help" {
196 errors.err(field, "Custom `--help` flags are not supported.");
197 }
198 let long_name = format!("--{}", long_name);
199 Some(long_name)
200 }
201 FieldKind::SubCommand | FieldKind::Positional => None,
202 };
203
204 Some(StructField { field, attrs, kind, optionality, ty_without_wrapper, name, long_name })
205 }
206 }
207
208 /// Implements `FromArgs` and `TopLevelCommand` or `SubCommand` for a `#[derive(FromArgs)]` struct.
impl_from_args_struct( errors: &Errors, name: &syn::Ident, type_attrs: &TypeAttrs, ds: &syn::DataStruct, ) -> TokenStream209 fn impl_from_args_struct(
210 errors: &Errors,
211 name: &syn::Ident,
212 type_attrs: &TypeAttrs,
213 ds: &syn::DataStruct,
214 ) -> TokenStream {
215 let fields = match &ds.fields {
216 syn::Fields::Named(fields) => fields,
217 syn::Fields::Unnamed(_) => {
218 errors.err(
219 &ds.struct_token,
220 "`#![derive(FromArgs)]` is not currently supported on tuple structs",
221 );
222 return TokenStream::new();
223 }
224 syn::Fields::Unit => {
225 errors.err(&ds.struct_token, "#![derive(FromArgs)]` cannot be applied to unit structs");
226 return TokenStream::new();
227 }
228 };
229
230 let fields: Vec<_> = fields
231 .named
232 .iter()
233 .filter_map(|field| {
234 let attrs = FieldAttrs::parse(errors, field);
235 StructField::new(errors, field, attrs)
236 })
237 .collect();
238
239 ensure_only_last_positional_is_optional(errors, &fields);
240
241 let impl_span = Span::call_site();
242
243 let from_args_method = impl_from_args_struct_from_args(errors, type_attrs, &fields);
244
245 let redact_arg_values_method =
246 impl_from_args_struct_redact_arg_values(errors, type_attrs, &fields);
247
248 let top_or_sub_cmd_impl = top_or_sub_cmd_impl(errors, name, type_attrs);
249
250 let trait_impl = quote_spanned! { impl_span =>
251 impl argh::FromArgs for #name {
252 #from_args_method
253
254 #redact_arg_values_method
255 }
256
257 #top_or_sub_cmd_impl
258 };
259
260 trait_impl.into()
261 }
262
impl_from_args_struct_from_args<'a>( errors: &Errors, type_attrs: &TypeAttrs, fields: &'a [StructField<'a>], ) -> TokenStream263 fn impl_from_args_struct_from_args<'a>(
264 errors: &Errors,
265 type_attrs: &TypeAttrs,
266 fields: &'a [StructField<'a>],
267 ) -> TokenStream {
268 let init_fields = declare_local_storage_for_from_args_fields(&fields);
269 let unwrap_fields = unwrap_from_args_fields(&fields);
270 let positional_fields: Vec<&StructField<'_>> =
271 fields.iter().filter(|field| field.kind == FieldKind::Positional).collect();
272 let positional_field_idents = positional_fields.iter().map(|field| &field.field.ident);
273 let positional_field_names = positional_fields.iter().map(|field| field.name.to_string());
274 let last_positional_is_repeating = positional_fields
275 .last()
276 .map(|field| field.optionality == Optionality::Repeating)
277 .unwrap_or(false);
278
279 let flag_output_table = fields.iter().filter_map(|field| {
280 let field_name = &field.field.ident;
281 match field.kind {
282 FieldKind::Option => Some(quote! { argh::ParseStructOption::Value(&mut #field_name) }),
283 FieldKind::Switch => Some(quote! { argh::ParseStructOption::Flag(&mut #field_name) }),
284 FieldKind::SubCommand | FieldKind::Positional => None,
285 }
286 });
287
288 let flag_str_to_output_table_map = flag_str_to_output_table_map_entries(&fields);
289
290 let mut subcommands_iter =
291 fields.iter().filter(|field| field.kind == FieldKind::SubCommand).fuse();
292
293 let subcommand: Option<&StructField<'_>> = subcommands_iter.next();
294 while let Some(dup_subcommand) = subcommands_iter.next() {
295 errors.duplicate_attrs("subcommand", subcommand.unwrap().field, dup_subcommand.field);
296 }
297
298 let impl_span = Span::call_site();
299
300 let missing_requirements_ident = syn::Ident::new("__missing_requirements", impl_span.clone());
301
302 let append_missing_requirements =
303 append_missing_requirements(&missing_requirements_ident, &fields);
304
305 let parse_subcommands = if let Some(subcommand) = subcommand {
306 let name = subcommand.name;
307 let ty = subcommand.ty_without_wrapper;
308 quote_spanned! { impl_span =>
309 Some(argh::ParseStructSubCommand {
310 subcommands: <#ty as argh::SubCommands>::COMMANDS,
311 parse_func: &mut |__command, __remaining_args| {
312 #name = Some(<#ty as argh::FromArgs>::from_args(__command, __remaining_args)?);
313 Ok(())
314 },
315 })
316 }
317 } else {
318 quote_spanned! { impl_span => None }
319 };
320
321 // Identifier referring to a value containing the name of the current command as an `&[&str]`.
322 let cmd_name_str_array_ident = syn::Ident::new("__cmd_name", impl_span.clone());
323 let help = help::help(errors, cmd_name_str_array_ident, type_attrs, &fields, subcommand);
324
325 let method_impl = quote_spanned! { impl_span =>
326 fn from_args(__cmd_name: &[&str], __args: &[&str])
327 -> std::result::Result<Self, argh::EarlyExit>
328 {
329 #( #init_fields )*
330
331 argh::parse_struct_args(
332 __cmd_name,
333 __args,
334 argh::ParseStructOptions {
335 arg_to_slot: &[ #( #flag_str_to_output_table_map ,)* ],
336 slots: &mut [ #( #flag_output_table, )* ],
337 },
338 argh::ParseStructPositionals {
339 positionals: &mut [
340 #(
341 argh::ParseStructPositional {
342 name: #positional_field_names,
343 slot: &mut #positional_field_idents as &mut argh::ParseValueSlot,
344 },
345 )*
346 ],
347 last_is_repeating: #last_positional_is_repeating,
348 },
349 #parse_subcommands,
350 &|| #help,
351 )?;
352
353 let mut #missing_requirements_ident = argh::MissingRequirements::default();
354 #(
355 #append_missing_requirements
356 )*
357 #missing_requirements_ident.err_on_any()?;
358
359 Ok(Self {
360 #( #unwrap_fields, )*
361 })
362 }
363 };
364
365 method_impl.into()
366 }
367
impl_from_args_struct_redact_arg_values<'a>( errors: &Errors, type_attrs: &TypeAttrs, fields: &'a [StructField<'a>], ) -> TokenStream368 fn impl_from_args_struct_redact_arg_values<'a>(
369 errors: &Errors,
370 type_attrs: &TypeAttrs,
371 fields: &'a [StructField<'a>],
372 ) -> TokenStream {
373 let init_fields = declare_local_storage_for_redacted_fields(&fields);
374 let unwrap_fields = unwrap_redacted_fields(&fields);
375
376 let positional_fields: Vec<&StructField<'_>> =
377 fields.iter().filter(|field| field.kind == FieldKind::Positional).collect();
378 let positional_field_idents = positional_fields.iter().map(|field| &field.field.ident);
379 let positional_field_names = positional_fields.iter().map(|field| field.name.to_string());
380 let last_positional_is_repeating = positional_fields
381 .last()
382 .map(|field| field.optionality == Optionality::Repeating)
383 .unwrap_or(false);
384
385 let flag_output_table = fields.iter().filter_map(|field| {
386 let field_name = &field.field.ident;
387 match field.kind {
388 FieldKind::Option => Some(quote! { argh::ParseStructOption::Value(&mut #field_name) }),
389 FieldKind::Switch => Some(quote! { argh::ParseStructOption::Flag(&mut #field_name) }),
390 FieldKind::SubCommand | FieldKind::Positional => None,
391 }
392 });
393
394 let flag_str_to_output_table_map = flag_str_to_output_table_map_entries(&fields);
395
396 let mut subcommands_iter =
397 fields.iter().filter(|field| field.kind == FieldKind::SubCommand).fuse();
398
399 let subcommand: Option<&StructField<'_>> = subcommands_iter.next();
400 while let Some(dup_subcommand) = subcommands_iter.next() {
401 errors.duplicate_attrs("subcommand", subcommand.unwrap().field, dup_subcommand.field);
402 }
403
404 let impl_span = Span::call_site();
405
406 let missing_requirements_ident = syn::Ident::new("__missing_requirements", impl_span.clone());
407
408 let append_missing_requirements =
409 append_missing_requirements(&missing_requirements_ident, &fields);
410
411 let redact_subcommands = if let Some(subcommand) = subcommand {
412 let name = subcommand.name;
413 let ty = subcommand.ty_without_wrapper;
414 quote_spanned! { impl_span =>
415 Some(argh::ParseStructSubCommand {
416 subcommands: <#ty as argh::SubCommands>::COMMANDS,
417 parse_func: &mut |__command, __remaining_args| {
418 #name = Some(<#ty as argh::FromArgs>::redact_arg_values(__command, __remaining_args)?);
419 Ok(())
420 },
421 })
422 }
423 } else {
424 quote_spanned! { impl_span => None }
425 };
426
427 let cmd_name = if type_attrs.is_subcommand.is_none() {
428 quote! { __cmd_name.last().expect("no command name").to_string() }
429 } else {
430 quote! { __cmd_name.last().expect("no subcommand name").to_string() }
431 };
432
433 // Identifier referring to a value containing the name of the current command as an `&[&str]`.
434 let cmd_name_str_array_ident = syn::Ident::new("__cmd_name", impl_span.clone());
435 let help = help::help(errors, cmd_name_str_array_ident, type_attrs, &fields, subcommand);
436
437 let method_impl = quote_spanned! { impl_span =>
438 fn redact_arg_values(__cmd_name: &[&str], __args: &[&str]) -> Result<Vec<String>, argh::EarlyExit> {
439 #( #init_fields )*
440
441 argh::parse_struct_args(
442 __cmd_name,
443 __args,
444 argh::ParseStructOptions {
445 arg_to_slot: &[ #( #flag_str_to_output_table_map ,)* ],
446 slots: &mut [ #( #flag_output_table, )* ],
447 },
448 argh::ParseStructPositionals {
449 positionals: &mut [
450 #(
451 argh::ParseStructPositional {
452 name: #positional_field_names,
453 slot: &mut #positional_field_idents as &mut argh::ParseValueSlot,
454 },
455 )*
456 ],
457 last_is_repeating: #last_positional_is_repeating,
458 },
459 #redact_subcommands,
460 &|| #help,
461 )?;
462
463 let mut #missing_requirements_ident = argh::MissingRequirements::default();
464 #(
465 #append_missing_requirements
466 )*
467 #missing_requirements_ident.err_on_any()?;
468
469 let mut __redacted = vec![
470 #cmd_name,
471 ];
472
473 #( #unwrap_fields )*
474
475 Ok(__redacted)
476 }
477 };
478
479 method_impl.into()
480 }
481
482 /// Ensures that only the last positional arg is non-required.
ensure_only_last_positional_is_optional(errors: &Errors, fields: &[StructField<'_>])483 fn ensure_only_last_positional_is_optional(errors: &Errors, fields: &[StructField<'_>]) {
484 let mut first_non_required_span = None;
485 for field in fields {
486 if field.kind == FieldKind::Positional {
487 if let Some(first) = first_non_required_span {
488 errors.err_span(
489 first,
490 "Only the last positional argument may be `Option`, `Vec`, or defaulted.",
491 );
492 errors.err(&field.field, "Later positional argument declared here.");
493 return;
494 }
495 if !field.optionality.is_required() {
496 first_non_required_span = Some(field.field.span());
497 }
498 }
499 }
500 }
501
502 /// Implement `argh::TopLevelCommand` or `argh::SubCommand` as appropriate.
top_or_sub_cmd_impl(errors: &Errors, name: &syn::Ident, type_attrs: &TypeAttrs) -> TokenStream503 fn top_or_sub_cmd_impl(errors: &Errors, name: &syn::Ident, type_attrs: &TypeAttrs) -> TokenStream {
504 let description =
505 help::require_description(errors, name.span(), &type_attrs.description, "type");
506 if type_attrs.is_subcommand.is_none() {
507 // Not a subcommand
508 quote! {
509 impl argh::TopLevelCommand for #name {}
510 }
511 } else {
512 let empty_str = syn::LitStr::new("", Span::call_site());
513 let subcommand_name = type_attrs.name.as_ref().unwrap_or_else(|| {
514 errors.err(name, "`#[argh(name = \"...\")]` attribute is required for subcommands");
515 &empty_str
516 });
517 quote! {
518 impl argh::SubCommand for #name {
519 const COMMAND: &'static argh::CommandInfo = &argh::CommandInfo {
520 name: #subcommand_name,
521 description: #description,
522 };
523 }
524 }
525 }
526 }
527
528 /// Declare a local slots to store each field in during parsing.
529 ///
530 /// Most fields are stored in `Option<FieldType>` locals.
531 /// `argh(option)` fields are stored in a `ParseValueSlotTy` along with a
532 /// function that knows how to decode the appropriate value.
declare_local_storage_for_from_args_fields<'a>( fields: &'a [StructField<'a>], ) -> impl Iterator<Item = TokenStream> + 'a533 fn declare_local_storage_for_from_args_fields<'a>(
534 fields: &'a [StructField<'a>],
535 ) -> impl Iterator<Item = TokenStream> + 'a {
536 fields.iter().map(|field| {
537 let field_name = &field.field.ident;
538 let field_type = &field.ty_without_wrapper;
539
540 // Wrap field types in `Option` if they aren't already `Option` or `Vec`-wrapped.
541 let field_slot_type = match field.optionality {
542 Optionality::Optional | Optionality::Repeating => (&field.field.ty).into_token_stream(),
543 Optionality::None | Optionality::Defaulted(_) => {
544 quote! { std::option::Option<#field_type> }
545 }
546 };
547
548 match field.kind {
549 FieldKind::Option | FieldKind::Positional => {
550 let from_str_fn = match &field.attrs.from_str_fn {
551 Some(from_str_fn) => from_str_fn.into_token_stream(),
552 None => {
553 quote! {
554 <#field_type as argh::FromArgValue>::from_arg_value
555 }
556 }
557 };
558
559 quote! {
560 let mut #field_name: argh::ParseValueSlotTy<#field_slot_type, #field_type>
561 = argh::ParseValueSlotTy {
562 slot: std::default::Default::default(),
563 parse_func: |_, value| { #from_str_fn(value) },
564 };
565 }
566 }
567 FieldKind::SubCommand => {
568 quote! { let mut #field_name: #field_slot_type = None; }
569 }
570 FieldKind::Switch => {
571 quote! { let mut #field_name: #field_slot_type = argh::Flag::default(); }
572 }
573 }
574 })
575 }
576
577 /// Unwrap non-optional fields and take options out of their tuple slots.
unwrap_from_args_fields<'a>( fields: &'a [StructField<'a>], ) -> impl Iterator<Item = TokenStream> + 'a578 fn unwrap_from_args_fields<'a>(
579 fields: &'a [StructField<'a>],
580 ) -> impl Iterator<Item = TokenStream> + 'a {
581 fields.iter().map(|field| {
582 let field_name = field.name;
583 match field.kind {
584 FieldKind::Option | FieldKind::Positional => match &field.optionality {
585 Optionality::None => quote! { #field_name: #field_name.slot.unwrap() },
586 Optionality::Optional | Optionality::Repeating => {
587 quote! { #field_name: #field_name.slot }
588 }
589 Optionality::Defaulted(tokens) => {
590 quote! {
591 #field_name: #field_name.slot.unwrap_or_else(|| #tokens)
592 }
593 }
594 },
595 FieldKind::Switch => field_name.into_token_stream(),
596 FieldKind::SubCommand => match field.optionality {
597 Optionality::None => quote! { #field_name: #field_name.unwrap() },
598 Optionality::Optional | Optionality::Repeating => field_name.into_token_stream(),
599 Optionality::Defaulted(_) => unreachable!(),
600 },
601 }
602 })
603 }
604
605 /// Declare a local slots to store each field in during parsing.
606 ///
607 /// Most fields are stored in `Option<FieldType>` locals.
608 /// `argh(option)` fields are stored in a `ParseValueSlotTy` along with a
609 /// function that knows how to decode the appropriate value.
declare_local_storage_for_redacted_fields<'a>( fields: &'a [StructField<'a>], ) -> impl Iterator<Item = TokenStream> + 'a610 fn declare_local_storage_for_redacted_fields<'a>(
611 fields: &'a [StructField<'a>],
612 ) -> impl Iterator<Item = TokenStream> + 'a {
613 fields.iter().map(|field| {
614 let field_name = &field.field.ident;
615
616 match field.kind {
617 FieldKind::Switch => {
618 quote! {
619 let mut #field_name = argh::RedactFlag {
620 slot: None,
621 };
622 }
623 }
624 FieldKind::Option => {
625 quote! {
626 let mut #field_name: argh::ParseValueSlotTy::<Option<String>, String> =
627 argh::ParseValueSlotTy {
628 slot: std::default::Default::default(),
629 parse_func: |arg, _| { Ok(arg.to_string()) },
630 };
631 }
632 }
633 FieldKind::Positional => {
634 let field_slot_type = match field.optionality {
635 Optionality::Repeating => {
636 quote! { std::vec::Vec<String> }
637 }
638 Optionality::None | Optionality::Optional | Optionality::Defaulted(_) => {
639 quote! { std::option::Option<String> }
640 }
641 };
642
643 let long_name = field.name.to_string();
644 quote! {
645 let mut #field_name: argh::ParseValueSlotTy::<#field_slot_type, String> =
646 argh::ParseValueSlotTy {
647 slot: std::default::Default::default(),
648 parse_func: |_, _| { Ok(#long_name.to_string()) },
649 };
650 }
651 }
652 FieldKind::SubCommand => {
653 quote! { let mut #field_name: std::option::Option<std::vec::Vec<String>> = None; }
654 }
655 }
656 })
657 }
658
659 /// Unwrap non-optional fields and take options out of their tuple slots.
unwrap_redacted_fields<'a>( fields: &'a [StructField<'a>], ) -> impl Iterator<Item = TokenStream> + 'a660 fn unwrap_redacted_fields<'a>(
661 fields: &'a [StructField<'a>],
662 ) -> impl Iterator<Item = TokenStream> + 'a {
663 fields.iter().map(|field| {
664 let field_name = field.name;
665
666 match field.kind {
667 FieldKind::Switch | FieldKind::Option => {
668 quote! {
669 if let Some(__field_name) = #field_name.slot {
670 __redacted.push(__field_name);
671 }
672 }
673 }
674 FieldKind::Positional => {
675 quote! {
676 __redacted.extend(#field_name.slot.into_iter());
677 }
678 }
679 FieldKind::SubCommand => {
680 quote! {
681 if let Some(__subcommand_args) = #field_name {
682 __redacted.extend(__subcommand_args.into_iter());
683 }
684 }
685 }
686 }
687 })
688 }
689
690 /// Entries of tokens like `("--some-flag-key", 5)` that map from a flag key string
691 /// to an index in the output table.
flag_str_to_output_table_map_entries<'a>(fields: &'a [StructField<'a>]) -> Vec<TokenStream>692 fn flag_str_to_output_table_map_entries<'a>(fields: &'a [StructField<'a>]) -> Vec<TokenStream> {
693 let mut flag_str_to_output_table_map = vec![];
694 for (i, (field, long_name)) in fields
695 .iter()
696 .filter_map(|field| field.long_name.as_ref().map(|long_name| (field, long_name)))
697 .enumerate()
698 {
699 if let Some(short) = &field.attrs.short {
700 let short = format!("-{}", short.value());
701 flag_str_to_output_table_map.push(quote! { (#short, #i) });
702 }
703
704 flag_str_to_output_table_map.push(quote! { (#long_name, #i) });
705 }
706 flag_str_to_output_table_map
707 }
708
709 /// For each non-optional field, add an entry to the `argh::MissingRequirements`.
append_missing_requirements<'a>( mri: &syn::Ident, fields: &'a [StructField<'a>], ) -> impl Iterator<Item = TokenStream> + 'a710 fn append_missing_requirements<'a>(
711 // missing_requirements_ident
712 mri: &syn::Ident,
713 fields: &'a [StructField<'a>],
714 ) -> impl Iterator<Item = TokenStream> + 'a {
715 let mri = mri.clone();
716 fields.iter().filter(|f| f.optionality.is_required()).map(move |field| {
717 let field_name = field.name;
718 match field.kind {
719 FieldKind::Switch => unreachable!("switches are always optional"),
720 FieldKind::Positional => {
721 let name = field.name.to_string();
722 quote! {
723 if #field_name.slot.is_none() {
724 #mri.missing_positional_arg(#name)
725 }
726 }
727 }
728 FieldKind::Option => {
729 let name = field.long_name.as_ref().expect("options always have a long name");
730 quote! {
731 if #field_name.slot.is_none() {
732 #mri.missing_option(#name)
733 }
734 }
735 }
736 FieldKind::SubCommand => {
737 let ty = field.ty_without_wrapper;
738 quote! {
739 if #field_name.is_none() {
740 #mri.missing_subcommands(
741 <#ty as argh::SubCommands>::COMMANDS,
742 )
743 }
744 }
745 }
746 }
747 })
748 }
749
750 /// Require that a type can be a `switch`.
751 /// Throws an error for all types except booleans and integers
ty_expect_switch(errors: &Errors, ty: &syn::Type) -> bool752 fn ty_expect_switch(errors: &Errors, ty: &syn::Type) -> bool {
753 fn ty_can_be_switch(ty: &syn::Type) -> bool {
754 if let syn::Type::Path(path) = ty {
755 if path.qself.is_some() {
756 return false;
757 }
758 if path.path.segments.len() != 1 {
759 return false;
760 }
761 let ident = &path.path.segments[0].ident;
762 ["bool", "u8", "u16", "u32", "u64", "u128", "i8", "i16", "i32", "i64", "i128"]
763 .iter()
764 .any(|path| ident == path)
765 } else {
766 false
767 }
768 }
769
770 let res = ty_can_be_switch(ty);
771 if !res {
772 errors.err(ty, "switches must be of type `bool` or integer type");
773 }
774 res
775 }
776
777 /// Returns `Some(T)` if a type is `wrapper_name<T>` for any `wrapper_name` in `wrapper_names`.
ty_inner<'a>(wrapper_names: &[&str], ty: &'a syn::Type) -> Option<&'a syn::Type>778 fn ty_inner<'a>(wrapper_names: &[&str], ty: &'a syn::Type) -> Option<&'a syn::Type> {
779 if let syn::Type::Path(path) = ty {
780 if path.qself.is_some() {
781 return None;
782 }
783 // Since we only check the last path segment, it isn't necessarily the case that
784 // we're referring to `std::vec::Vec` or `std::option::Option`, but there isn't
785 // a fool proof way to check these since name resolution happens after macro expansion,
786 // so this is likely "good enough" (so long as people don't have their own types called
787 // `Option` or `Vec` that take one generic parameter they're looking to parse).
788 let last_segment = path.path.segments.last()?;
789 if !wrapper_names.iter().any(|name| last_segment.ident == *name) {
790 return None;
791 }
792 if let syn::PathArguments::AngleBracketed(gen_args) = &last_segment.arguments {
793 let generic_arg = gen_args.args.first()?;
794 if let syn::GenericArgument::Type(ty) = &generic_arg {
795 return Some(ty);
796 }
797 }
798 }
799 None
800 }
801
802 /// Implements `FromArgs` and `SubCommands` for a `#![derive(FromArgs)]` enum.
impl_from_args_enum( errors: &Errors, name: &syn::Ident, type_attrs: &TypeAttrs, de: &syn::DataEnum, ) -> TokenStream803 fn impl_from_args_enum(
804 errors: &Errors,
805 name: &syn::Ident,
806 type_attrs: &TypeAttrs,
807 de: &syn::DataEnum,
808 ) -> TokenStream {
809 parse_attrs::check_enum_type_attrs(errors, type_attrs, &de.enum_token.span);
810
811 // An enum variant like `<name>(<ty>)`
812 struct SubCommandVariant<'a> {
813 name: &'a syn::Ident,
814 ty: &'a syn::Type,
815 }
816
817 let variants: Vec<SubCommandVariant<'_>> = de
818 .variants
819 .iter()
820 .filter_map(|variant| {
821 parse_attrs::check_enum_variant_attrs(errors, variant);
822 let name = &variant.ident;
823 let ty = enum_only_single_field_unnamed_variants(errors, &variant.fields)?;
824 Some(SubCommandVariant { name, ty })
825 })
826 .collect();
827
828 let name_repeating = std::iter::repeat(name.clone());
829 let variant_ty = variants.iter().map(|x| x.ty).collect::<Vec<_>>();
830 let variant_names = variants.iter().map(|x| x.name).collect::<Vec<_>>();
831
832 quote! {
833 impl argh::FromArgs for #name {
834 fn from_args(command_name: &[&str], args: &[&str])
835 -> std::result::Result<Self, argh::EarlyExit>
836 {
837 let subcommand_name = *command_name.last().expect("no subcommand name");
838 #(
839 if subcommand_name == <#variant_ty as argh::SubCommand>::COMMAND.name {
840 return Ok(#name_repeating::#variant_names(
841 <#variant_ty as argh::FromArgs>::from_args(command_name, args)?
842 ));
843 }
844 )*
845 unreachable!("no subcommand matched")
846 }
847
848 fn redact_arg_values(command_name: &[&str], args: &[&str]) -> std::result::Result<Vec<String>, argh::EarlyExit> {
849 let subcommand_name = *command_name.last().expect("no subcommand name");
850 #(
851 if subcommand_name == <#variant_ty as argh::SubCommand>::COMMAND.name {
852 return <#variant_ty as argh::FromArgs>::redact_arg_values(command_name, args);
853 }
854 )*
855 unreachable!("no subcommand matched")
856 }
857 }
858
859 impl argh::SubCommands for #name {
860 const COMMANDS: &'static [&'static argh::CommandInfo] = &[#(
861 <#variant_ty as argh::SubCommand>::COMMAND,
862 )*];
863 }
864 }
865 }
866
867 /// Returns `Some(Bar)` if the field is a single-field unnamed variant like `Foo(Bar)`.
868 /// Otherwise, generates an error.
enum_only_single_field_unnamed_variants<'a>( errors: &Errors, variant_fields: &'a syn::Fields, ) -> Option<&'a syn::Type>869 fn enum_only_single_field_unnamed_variants<'a>(
870 errors: &Errors,
871 variant_fields: &'a syn::Fields,
872 ) -> Option<&'a syn::Type> {
873 macro_rules! with_enum_suggestion {
874 ($help_text:literal) => {
875 concat!(
876 $help_text,
877 "\nInstead, use a variant with a single unnamed field for each subcommand:\n",
878 " enum MyCommandEnum {\n",
879 " SubCommandOne(SubCommandOne),\n",
880 " SubCommandTwo(SubCommandTwo),\n",
881 " }",
882 )
883 };
884 }
885
886 match variant_fields {
887 syn::Fields::Named(fields) => {
888 errors.err(
889 fields,
890 with_enum_suggestion!(
891 "`#![derive(FromArgs)]` `enum`s do not support variants with named fields."
892 ),
893 );
894 None
895 }
896 syn::Fields::Unit => {
897 errors.err(
898 variant_fields,
899 with_enum_suggestion!(
900 "`#![derive(FromArgs)]` does not support `enum`s with no variants."
901 ),
902 );
903 None
904 }
905 syn::Fields::Unnamed(fields) => {
906 if fields.unnamed.len() != 1 {
907 errors.err(
908 fields,
909 with_enum_suggestion!(
910 "`#![derive(FromArgs)]` `enum` variants must only contain one field."
911 ),
912 );
913 None
914 } else {
915 // `unwrap` is okay because of the length check above.
916 let first_field = fields.unnamed.first().unwrap();
917 Some(&first_field.ty)
918 }
919 }
920 }
921 }
922