1package cobra
2
3import (
4	"fmt"
5	"os"
6	"strings"
7
8	"github.com/spf13/pflag"
9)
10
const (
	// ShellCompRequestCmd is the name of the hidden command that is used to request
	// completion results from the program. It is used by the shell completion scripts.
	ShellCompRequestCmd = "__complete"
	// ShellCompNoDescRequestCmd is the name of the hidden command that is used to request
	// completion results without their description. It is used by the shell completion scripts.
	// It is registered as an alias of the ShellCompRequestCmd command rather than as a
	// separate command.
	ShellCompNoDescRequestCmd = "__completeNoDesc"
)
19
// Global map of flag completion functions, keyed by the flag they complete.
// Entries are added by RegisterFlagCompletionFunc and looked up by
// getCompletions when the value of a flag is being completed.
var flagCompletionFunctions = map[*pflag.Flag]func(cmd *Command, args []string, toComplete string) ([]string, ShellCompDirective){}
22
// ShellCompDirective is a bit map representing the different behaviors the shell
// can be instructed to have once completions have been provided.
// Each directive constant occupies its own bit; the zero value is
// ShellCompDirectiveDefault.
type ShellCompDirective int
26
// The directives below are defined with 1 << iota, so each one gets its own
// bit: 1, 2, 4, 8, 16. shellCompDirectiveMaxValue (32) is the first unused bit
// and serves as an exclusive upper bound when validating a directive.
const (
	// ShellCompDirectiveError indicates an error occurred and completions should be ignored.
	ShellCompDirectiveError ShellCompDirective = 1 << iota

	// ShellCompDirectiveNoSpace indicates that the shell should not add a space
	// after the completion even if there is a single completion provided.
	ShellCompDirectiveNoSpace

	// ShellCompDirectiveNoFileComp indicates that the shell should not provide
	// file completion even when no completion is provided.
	// This currently does not work for zsh or bash < 4.
	ShellCompDirectiveNoFileComp

	// ShellCompDirectiveFilterFileExt indicates that the provided completions
	// should be used as file extension filters.
	// For flags, using Command.MarkFlagFilename() and Command.MarkPersistentFlagFilename()
	// is a shortcut to using this directive explicitly. The BashCompFilenameExt
	// annotation can also be used to obtain the same behavior for flags.
	ShellCompDirectiveFilterFileExt

	// ShellCompDirectiveFilterDirs indicates that only directory names should
	// be provided in file completion. To request directory names within another
	// directory, the returned completions should specify the directory within
	// which to search. The BashCompSubdirsInDir annotation can be used to
	// obtain the same behavior but only for flags.
	ShellCompDirectiveFilterDirs

	// ===========================================================================

	// All directives using iota should be above this one.
	// For internal use: any directive value greater than or equal to this one
	// is treated as invalid.
	shellCompDirectiveMaxValue

	// ShellCompDirectiveDefault indicates to let the shell perform its default
	// behavior after completions have been provided.
	// This one must be last to avoid messing up the iota count.
	ShellCompDirectiveDefault ShellCompDirective = 0
)
65
66// RegisterFlagCompletionFunc should be called to register a function to provide completion for a flag.
67func (c *Command) RegisterFlagCompletionFunc(flagName string, f func(cmd *Command, args []string, toComplete string) ([]string, ShellCompDirective)) error {
68	flag := c.Flag(flagName)
69	if flag == nil {
70		return fmt.Errorf("RegisterFlagCompletionFunc: flag '%s' does not exist", flagName)
71	}
72	if _, exists := flagCompletionFunctions[flag]; exists {
73		return fmt.Errorf("RegisterFlagCompletionFunc: flag '%s' already registered", flagName)
74	}
75	flagCompletionFunctions[flag] = f
76	return nil
77}
78
79// Returns a string listing the different directive enabled in the specified parameter
80func (d ShellCompDirective) string() string {
81	var directives []string
82	if d&ShellCompDirectiveError != 0 {
83		directives = append(directives, "ShellCompDirectiveError")
84	}
85	if d&ShellCompDirectiveNoSpace != 0 {
86		directives = append(directives, "ShellCompDirectiveNoSpace")
87	}
88	if d&ShellCompDirectiveNoFileComp != 0 {
89		directives = append(directives, "ShellCompDirectiveNoFileComp")
90	}
91	if d&ShellCompDirectiveFilterFileExt != 0 {
92		directives = append(directives, "ShellCompDirectiveFilterFileExt")
93	}
94	if d&ShellCompDirectiveFilterDirs != 0 {
95		directives = append(directives, "ShellCompDirectiveFilterDirs")
96	}
97	if len(directives) == 0 {
98		directives = append(directives, "ShellCompDirectiveDefault")
99	}
100
101	if d >= shellCompDirectiveMaxValue {
102		return fmt.Sprintf("ERROR: unexpected ShellCompDirective value: %d", d)
103	}
104	return strings.Join(directives, ", ")
105}
106
107// Adds a special hidden command that can be used to request custom completions.
108func (c *Command) initCompleteCmd(args []string) {
109	completeCmd := &Command{
110		Use:                   fmt.Sprintf("%s [command-line]", ShellCompRequestCmd),
111		Aliases:               []string{ShellCompNoDescRequestCmd},
112		DisableFlagsInUseLine: true,
113		Hidden:                true,
114		DisableFlagParsing:    true,
115		Args:                  MinimumNArgs(1),
116		Short:                 "Request shell completion choices for the specified command-line",
117		Long: fmt.Sprintf("%[2]s is a special command that is used by the shell completion logic\n%[1]s",
118			"to request completion choices for the specified command-line.", ShellCompRequestCmd),
119		Run: func(cmd *Command, args []string) {
120			finalCmd, completions, directive, err := cmd.getCompletions(args)
121			if err != nil {
122				CompErrorln(err.Error())
123				// Keep going for multiple reasons:
124				// 1- There could be some valid completions even though there was an error
125				// 2- Even without completions, we need to print the directive
126			}
127
128			noDescriptions := (cmd.CalledAs() == ShellCompNoDescRequestCmd)
129			for _, comp := range completions {
130				if noDescriptions {
131					// Remove any description that may be included following a tab character.
132					comp = strings.Split(comp, "\t")[0]
133				}
134
135				// Make sure we only write the first line to the output.
136				// This is needed if a description contains a linebreak.
137				// Otherwise the shell scripts will interpret the other lines as new flags
138				// and could therefore provide a wrong completion.
139				comp = strings.Split(comp, "\n")[0]
140
141				// Finally trim the completion.  This is especially important to get rid
142				// of a trailing tab when there are no description following it.
143				// For example, a sub-command without a description should not be completed
144				// with a tab at the end (or else zsh will show a -- following it
145				// although there is no description).
146				comp = strings.TrimSpace(comp)
147
148				// Print each possible completion to stdout for the completion script to consume.
149				fmt.Fprintln(finalCmd.OutOrStdout(), comp)
150			}
151
152			if directive >= shellCompDirectiveMaxValue {
153				directive = ShellCompDirectiveDefault
154			}
155
156			// As the last printout, print the completion directive for the completion script to parse.
157			// The directive integer must be that last character following a single colon (:).
158			// The completion script expects :<directive>
159			fmt.Fprintf(finalCmd.OutOrStdout(), ":%d\n", directive)
160
161			// Print some helpful info to stderr for the user to understand.
162			// Output from stderr must be ignored by the completion script.
163			fmt.Fprintf(finalCmd.ErrOrStderr(), "Completion ended with directive: %s\n", directive.string())
164		},
165	}
166	c.AddCommand(completeCmd)
167	subCmd, _, err := c.Find(args)
168	if err != nil || subCmd.Name() != ShellCompRequestCmd {
169		// Only create this special command if it is actually being called.
170		// This reduces possible side-effects of creating such a command;
171		// for example, having this command would cause problems to a
172		// cobra program that only consists of the root command, since this
173		// command would cause the root command to suddenly have a subcommand.
174		c.RemoveCommand(completeCmd)
175	}
176}
177
// getCompletions computes the completion choices for the command-line in args.
// The last element of args is the word being completed; the elements before it
// are used to locate the target command, starting from the root command.
// args must not be empty, since its last element is read unconditionally.
//
// It returns the command for which completion was performed (c itself when no
// command could be found), the completion choices, the directive to hand to
// the shell, and an error when the command lookup or the flag handling failed.
func (c *Command) getCompletions(args []string) (*Command, []string, ShellCompDirective, error) {
	// The last argument, which is not completely typed by the user,
	// should not be part of the list of arguments
	toComplete := args[len(args)-1]
	trimmedArgs := args[:len(args)-1]

	var finalCmd *Command
	var finalArgs []string
	var err error
	// Find the real command for which completion must be performed
	// check if we need to traverse here to parse local flags on parent commands
	if c.Root().TraverseChildren {
		finalCmd, finalArgs, err = c.Root().Traverse(trimmedArgs)
	} else {
		finalCmd, finalArgs, err = c.Root().Find(trimmedArgs)
	}
	if err != nil {
		// Unable to find the real command. E.g., <program> someInvalidCmd <TAB>
		return c, []string{}, ShellCompDirectiveDefault, fmt.Errorf("Unable to find a command for arguments: %v", trimmedArgs)
	}

	// Check if we are doing flag value completion before parsing the flags.
	// This is important because if we are completing a flag value, we need to also
	// remove the flag name argument from the list of finalArgs or else the parsing
	// could fail due to an invalid value (incomplete) for the flag.
	// Note: only flag is newly declared here; finalArgs, toComplete and err are
	// the variables declared above and are overwritten.
	flag, finalArgs, toComplete, err := checkIfFlagCompletion(finalCmd, finalArgs, toComplete)
	if err != nil {
		// Error while attempting to parse flags
		return finalCmd, []string{}, ShellCompDirectiveDefault, err
	}

	// Parse the flags early so we can check if required flags are set
	if err = finalCmd.ParseFlags(finalArgs); err != nil {
		return finalCmd, []string{}, ShellCompDirectiveDefault, fmt.Errorf("Error while parsing flags from args %v: %s", finalArgs, err.Error())
	}

	if flag != nil {
		// Check if we are completing a flag value subject to annotations
		if validExts, present := flag.Annotations[BashCompFilenameExt]; present {
			if len(validExts) != 0 {
				// File completion filtered by extensions
				return finalCmd, validExts, ShellCompDirectiveFilterFileExt, nil
			}

			// The annotation requests simple file completion.  There is no reason to do
			// that since it is the default behavior anyway.  Let's ignore this annotation
			// in case the program also registered a completion function for this flag.
			// Even though it is a mistake on the program's side, let's be nice when we can.
		}

		if subDir, present := flag.Annotations[BashCompSubdirsInDir]; present {
			if len(subDir) == 1 {
				// Directory completion from within a directory
				return finalCmd, subDir, ShellCompDirectiveFilterDirs, nil
			}
			// Directory completion
			return finalCmd, []string{}, ShellCompDirectiveFilterDirs, nil
		}
	}

	// When doing completion of a flag name, as soon as an argument starts with
	// a '-' we know it is a flag.  We cannot use isFlagArg() here as it requires
	// the flag name to be complete
	if flag == nil && len(toComplete) > 0 && toComplete[0] == '-' && !strings.Contains(toComplete, "=") {
		var completions []string

		// First check for required flags
		completions = completeRequireFlags(finalCmd, toComplete)

		// If we have not found any required flags, only then can we show regular flags
		if len(completions) == 0 {
			doCompleteFlags := func(flag *pflag.Flag) {
				if !flag.Changed ||
					strings.Contains(flag.Value.Type(), "Slice") ||
					strings.Contains(flag.Value.Type(), "Array") {
					// If the flag is not already present, or if it can be specified multiple times (Array or Slice)
					// we suggest it as a completion
					completions = append(completions, getFlagNameCompletions(flag, toComplete)...)
				}
			}

			// We cannot use finalCmd.Flags() because we may not have called ParsedFlags() for commands
			// that have set DisableFlagParsing; it is ParseFlags() that merges the inherited and
			// non-inherited flags.
			finalCmd.InheritedFlags().VisitAll(func(flag *pflag.Flag) {
				doCompleteFlags(flag)
			})
			finalCmd.NonInheritedFlags().VisitAll(func(flag *pflag.Flag) {
				doCompleteFlags(flag)
			})
		}

		// Flag names are never completed with file names.
		directive := ShellCompDirectiveNoFileComp
		if len(completions) == 1 && strings.HasSuffix(completions[0], "=") {
			// If there is a single completion, the shell usually adds a space
			// after the completion.  We don't want that if the flag ends with an =
			// Note that this replaces the NoFileComp directive rather than adding to it.
			directive = ShellCompDirectiveNoSpace
		}
		return finalCmd, completions, directive, nil
	}

	// We only remove the flags from the arguments if DisableFlagParsing is not set.
	// This is important for commands which have requested to do their own flag completion.
	if !finalCmd.DisableFlagParsing {
		finalArgs = finalCmd.Flags().Args()
	}

	var completions []string
	directive := ShellCompDirectiveDefault
	if flag == nil {
		// Noun completion: subcommands, required flags and ValidArgs.
		foundLocalNonPersistentFlag := false
		// If TraverseChildren is true on the root command we don't check for
		// local flags because we can use a local flag on a parent command
		if !finalCmd.Root().TraverseChildren {
			// Check if there are any local, non-persistent flags on the command-line
			localNonPersistentFlags := finalCmd.LocalNonPersistentFlags()
			finalCmd.NonInheritedFlags().VisitAll(func(flag *pflag.Flag) {
				if localNonPersistentFlags.Lookup(flag.Name) != nil && flag.Changed {
					foundLocalNonPersistentFlag = true
				}
			})
		}

		// Complete subcommand names, including the help command
		if len(finalArgs) == 0 && !foundLocalNonPersistentFlag {
			// We only complete sub-commands if:
			// - there are no arguments on the command-line and
			// - there are no local, non-persistent flag on the command-line or TraverseChildren is true
			for _, subCmd := range finalCmd.Commands() {
				if subCmd.IsAvailableCommand() || subCmd == finalCmd.helpCommand {
					if strings.HasPrefix(subCmd.Name(), toComplete) {
						completions = append(completions, fmt.Sprintf("%s\t%s", subCmd.Name(), subCmd.Short))
					}
					// File completion is turned off as soon as one such subcommand
					// exists, whether or not its name matches toComplete.
					directive = ShellCompDirectiveNoFileComp
				}
			}
		}

		// Complete required flags even without the '-' prefix
		completions = append(completions, completeRequireFlags(finalCmd, toComplete)...)

		// Always complete ValidArgs, even if we are completing a subcommand name.
		// This is for commands that have both subcommands and ValidArgs.
		if len(finalCmd.ValidArgs) > 0 {
			if len(finalArgs) == 0 {
				// ValidArgs are only for the first argument
				for _, validArg := range finalCmd.ValidArgs {
					if strings.HasPrefix(validArg, toComplete) {
						completions = append(completions, validArg)
					}
				}
				directive = ShellCompDirectiveNoFileComp

				// If no completions were found within commands or ValidArgs,
				// see if there are any ArgAliases that should be completed.
				if len(completions) == 0 {
					for _, argAlias := range finalCmd.ArgAliases {
						if strings.HasPrefix(argAlias, toComplete) {
							completions = append(completions, argAlias)
						}
					}
				}
			}

			// If there are ValidArgs specified (even if they don't match), we stop completion.
			// Only one of ValidArgs or ValidArgsFunction can be used for a single command.
			return finalCmd, completions, directive, nil
		}

		// Let the logic continue so as to add any ValidArgsFunction completions,
		// even if we already found sub-commands.
		// This is for commands that have subcommands but also specify a ValidArgsFunction.
	}

	// Find the completion function for the flag or command
	var completionFn func(cmd *Command, args []string, toComplete string) ([]string, ShellCompDirective)
	if flag != nil {
		completionFn = flagCompletionFunctions[flag]
	} else {
		completionFn = finalCmd.ValidArgsFunction
	}
	if completionFn != nil {
		// Go custom completion defined for this flag or command.
		// Call the registered completion function to get the completions.
		// The directive it returns overrides the one computed so far.
		var comps []string
		comps, directive = completionFn(finalCmd, finalArgs, toComplete)
		completions = append(completions, comps...)
	}

	return finalCmd, completions, directive, nil
}
369
370func getFlagNameCompletions(flag *pflag.Flag, toComplete string) []string {
371	if nonCompletableFlag(flag) {
372		return []string{}
373	}
374
375	var completions []string
376	flagName := "--" + flag.Name
377	if strings.HasPrefix(flagName, toComplete) {
378		// Flag without the =
379		completions = append(completions, fmt.Sprintf("%s\t%s", flagName, flag.Usage))
380
381		// Why suggest both long forms: --flag and --flag= ?
382		// This forces the user to *always* have to type either an = or a space after the flag name.
383		// Let's be nice and avoid making users have to do that.
384		// Since boolean flags and shortname flags don't show the = form, let's go that route and never show it.
385		// The = form will still work, we just won't suggest it.
386		// This also makes the list of suggested flags shorter as we avoid all the = forms.
387		//
388		// if len(flag.NoOptDefVal) == 0 {
389		// 	// Flag requires a value, so it can be suffixed with =
390		// 	flagName += "="
391		// 	completions = append(completions, fmt.Sprintf("%s\t%s", flagName, flag.Usage))
392		// }
393	}
394
395	flagName = "-" + flag.Shorthand
396	if len(flag.Shorthand) > 0 && strings.HasPrefix(flagName, toComplete) {
397		completions = append(completions, fmt.Sprintf("%s\t%s", flagName, flag.Usage))
398	}
399
400	return completions
401}
402
403func completeRequireFlags(finalCmd *Command, toComplete string) []string {
404	var completions []string
405
406	doCompleteRequiredFlags := func(flag *pflag.Flag) {
407		if _, present := flag.Annotations[BashCompOneRequiredFlag]; present {
408			if !flag.Changed {
409				// If the flag is not already present, we suggest it as a completion
410				completions = append(completions, getFlagNameCompletions(flag, toComplete)...)
411			}
412		}
413	}
414
415	// We cannot use finalCmd.Flags() because we may not have called ParsedFlags() for commands
416	// that have set DisableFlagParsing; it is ParseFlags() that merges the inherited and
417	// non-inherited flags.
418	finalCmd.InheritedFlags().VisitAll(func(flag *pflag.Flag) {
419		doCompleteRequiredFlags(flag)
420	})
421	finalCmd.NonInheritedFlags().VisitAll(func(flag *pflag.Flag) {
422		doCompleteRequiredFlags(flag)
423	})
424
425	return completions
426}
427
428func checkIfFlagCompletion(finalCmd *Command, args []string, lastArg string) (*pflag.Flag, []string, string, error) {
429	if finalCmd.DisableFlagParsing {
430		// We only do flag completion if we are allowed to parse flags
431		// This is important for commands which have requested to do their own flag completion.
432		return nil, args, lastArg, nil
433	}
434
435	var flagName string
436	trimmedArgs := args
437	flagWithEqual := false
438
439	// When doing completion of a flag name, as soon as an argument starts with
440	// a '-' we know it is a flag.  We cannot use isFlagArg() here as that function
441	// requires the flag name to be complete
442	if len(lastArg) > 0 && lastArg[0] == '-' {
443		if index := strings.Index(lastArg, "="); index >= 0 {
444			// Flag with an =
445			flagName = strings.TrimLeft(lastArg[:index], "-")
446			lastArg = lastArg[index+1:]
447			flagWithEqual = true
448		} else {
449			// Normal flag completion
450			return nil, args, lastArg, nil
451		}
452	}
453
454	if len(flagName) == 0 {
455		if len(args) > 0 {
456			prevArg := args[len(args)-1]
457			if isFlagArg(prevArg) {
458				// Only consider the case where the flag does not contain an =.
459				// If the flag contains an = it means it has already been fully processed,
460				// so we don't need to deal with it here.
461				if index := strings.Index(prevArg, "="); index < 0 {
462					flagName = strings.TrimLeft(prevArg, "-")
463
464					// Remove the uncompleted flag or else there could be an error created
465					// for an invalid value for that flag
466					trimmedArgs = args[:len(args)-1]
467				}
468			}
469		}
470	}
471
472	if len(flagName) == 0 {
473		// Not doing flag completion
474		return nil, trimmedArgs, lastArg, nil
475	}
476
477	flag := findFlag(finalCmd, flagName)
478	if flag == nil {
479		// Flag not supported by this command, nothing to complete
480		err := fmt.Errorf("Subcommand '%s' does not support flag '%s'", finalCmd.Name(), flagName)
481		return nil, nil, "", err
482	}
483
484	if !flagWithEqual {
485		if len(flag.NoOptDefVal) != 0 {
486			// We had assumed dealing with a two-word flag but the flag is a boolean flag.
487			// In that case, there is no value following it, so we are not really doing flag completion.
488			// Reset everything to do noun completion.
489			trimmedArgs = args
490			flag = nil
491		}
492	}
493
494	return flag, trimmedArgs, lastArg, nil
495}
496
497func findFlag(cmd *Command, name string) *pflag.Flag {
498	flagSet := cmd.Flags()
499	if len(name) == 1 {
500		// First convert the short flag into a long flag
501		// as the cmd.Flag() search only accepts long flags
502		if short := flagSet.ShorthandLookup(name); short != nil {
503			name = short.Name
504		} else {
505			set := cmd.InheritedFlags()
506			if short = set.ShorthandLookup(name); short != nil {
507				name = short.Name
508			} else {
509				return nil
510			}
511		}
512	}
513	return cmd.Flag(name)
514}
515
516// CompDebug prints the specified string to the same file as where the
517// completion script prints its logs.
518// Note that completion printouts should never be on stdout as they would
519// be wrongly interpreted as actual completion choices by the completion script.
520func CompDebug(msg string, printToStdErr bool) {
521	msg = fmt.Sprintf("[Debug] %s", msg)
522
523	// Such logs are only printed when the user has set the environment
524	// variable BASH_COMP_DEBUG_FILE to the path of some file to be used.
525	if path := os.Getenv("BASH_COMP_DEBUG_FILE"); path != "" {
526		f, err := os.OpenFile(path,
527			os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
528		if err == nil {
529			defer f.Close()
530			WriteStringAndCheck(f, msg)
531		}
532	}
533
534	if printToStdErr {
535		// Must print to stderr for this not to be read by the completion script.
536		fmt.Fprint(os.Stderr, msg)
537	}
538}
539
540// CompDebugln prints the specified string with a newline at the end
541// to the same file as where the completion script prints its logs.
542// Such logs are only printed when the user has set the environment
543// variable BASH_COMP_DEBUG_FILE to the path of some file to be used.
544func CompDebugln(msg string, printToStdErr bool) {
545	CompDebug(fmt.Sprintf("%s\n", msg), printToStdErr)
546}
547
548// CompError prints the specified completion message to stderr.
549func CompError(msg string) {
550	msg = fmt.Sprintf("[Error] %s", msg)
551	CompDebug(msg, true)
552}
553
554// CompErrorln prints the specified completion message to stderr with a newline at the end.
555func CompErrorln(msg string) {
556	CompError(fmt.Sprintf("%s\n", msg))
557}
558