1package cobra
2
3import (
4	"errors"
5	"fmt"
6	"os"
7	"strings"
8
9	"github.com/spf13/pflag"
10)
11
// Names of the hidden commands that the shell completion scripts invoke
// to obtain completion choices from the program.
const (
	// ShellCompRequestCmd is the hidden command used to request completion
	// results from the program.
	ShellCompRequestCmd = "__complete"

	// ShellCompNoDescRequestCmd is the hidden command used to request
	// completion results without their descriptions.
	ShellCompNoDescRequestCmd = ShellCompRequestCmd + "NoDesc"
)
20
21// Global map of flag completion functions.
22var flagCompletionFunctions = map[*pflag.Flag]func(cmd *Command, args []string, toComplete string) ([]string, ShellCompDirective){}
23
// ShellCompDirective is a bit map representing the different behaviors the shell
// can be instructed to have once completions have been provided.
// Each directive occupies its own bit, so several directives can be combined
// in a single value with a bitwise OR.
type ShellCompDirective int
27
const (
	// ShellCompDirectiveError indicates an error occurred and completions should be ignored.
	// It is the first bit flag (value 1); the two flags that follow continue the
	// same iota sequence and therefore take the values 2 and 4.
	ShellCompDirectiveError ShellCompDirective = 1 << iota

	// ShellCompDirectiveNoSpace indicates that the shell should not add a space
	// after the completion even if there is a single completion provided.
	ShellCompDirectiveNoSpace

	// ShellCompDirectiveNoFileComp indicates that the shell should not provide
	// file completion even when no completion is provided.
	// This currently does not work for zsh or bash < 4.
	ShellCompDirectiveNoFileComp

	// ShellCompDirectiveDefault indicates to let the shell perform its default
	// behavior after completions have been provided.
	// It is the zero value (no directive bit set) and is given an explicit value
	// so that it does not take part in the iota sequence above.
	ShellCompDirectiveDefault ShellCompDirective = 0
)
45
46// RegisterFlagCompletionFunc should be called to register a function to provide completion for a flag.
47func (c *Command) RegisterFlagCompletionFunc(flagName string, f func(cmd *Command, args []string, toComplete string) ([]string, ShellCompDirective)) error {
48	flag := c.Flag(flagName)
49	if flag == nil {
50		return fmt.Errorf("RegisterFlagCompletionFunc: flag '%s' does not exist", flagName)
51	}
52	if _, exists := flagCompletionFunctions[flag]; exists {
53		return fmt.Errorf("RegisterFlagCompletionFunc: flag '%s' already registered", flagName)
54	}
55	flagCompletionFunctions[flag] = f
56	return nil
57}
58
59// Returns a string listing the different directive enabled in the specified parameter
60func (d ShellCompDirective) string() string {
61	var directives []string
62	if d&ShellCompDirectiveError != 0 {
63		directives = append(directives, "ShellCompDirectiveError")
64	}
65	if d&ShellCompDirectiveNoSpace != 0 {
66		directives = append(directives, "ShellCompDirectiveNoSpace")
67	}
68	if d&ShellCompDirectiveNoFileComp != 0 {
69		directives = append(directives, "ShellCompDirectiveNoFileComp")
70	}
71	if len(directives) == 0 {
72		directives = append(directives, "ShellCompDirectiveDefault")
73	}
74
75	if d > ShellCompDirectiveError+ShellCompDirectiveNoSpace+ShellCompDirectiveNoFileComp {
76		return fmt.Sprintf("ERROR: unexpected ShellCompDirective value: %d", d)
77	}
78	return strings.Join(directives, ", ")
79}
80
81// Adds a special hidden command that can be used to request custom completions.
82func (c *Command) initCompleteCmd(args []string) {
83	completeCmd := &Command{
84		Use:                   fmt.Sprintf("%s [command-line]", ShellCompRequestCmd),
85		Aliases:               []string{ShellCompNoDescRequestCmd},
86		DisableFlagsInUseLine: true,
87		Hidden:                true,
88		DisableFlagParsing:    true,
89		Args:                  MinimumNArgs(1),
90		Short:                 "Request shell completion choices for the specified command-line",
91		Long: fmt.Sprintf("%[2]s is a special command that is used by the shell completion logic\n%[1]s",
92			"to request completion choices for the specified command-line.", ShellCompRequestCmd),
93		Run: func(cmd *Command, args []string) {
94			finalCmd, completions, directive, err := cmd.getCompletions(args)
95			if err != nil {
96				CompErrorln(err.Error())
97				// Keep going for multiple reasons:
98				// 1- There could be some valid completions even though there was an error
99				// 2- Even without completions, we need to print the directive
100			}
101
102			noDescriptions := (cmd.CalledAs() == ShellCompNoDescRequestCmd)
103			for _, comp := range completions {
104				if noDescriptions {
105					// Remove any description that may be included following a tab character.
106					comp = strings.Split(comp, "\t")[0]
107				}
108				// Print each possible completion to stdout for the completion script to consume.
109				fmt.Fprintln(finalCmd.OutOrStdout(), comp)
110			}
111
112			if directive > ShellCompDirectiveError+ShellCompDirectiveNoSpace+ShellCompDirectiveNoFileComp {
113				directive = ShellCompDirectiveDefault
114			}
115
116			// As the last printout, print the completion directive for the completion script to parse.
117			// The directive integer must be that last character following a single colon (:).
118			// The completion script expects :<directive>
119			fmt.Fprintf(finalCmd.OutOrStdout(), ":%d\n", directive)
120
121			// Print some helpful info to stderr for the user to understand.
122			// Output from stderr must be ignored by the completion script.
123			fmt.Fprintf(finalCmd.ErrOrStderr(), "Completion ended with directive: %s\n", directive.string())
124		},
125	}
126	c.AddCommand(completeCmd)
127	subCmd, _, err := c.Find(args)
128	if err != nil || subCmd.Name() != ShellCompRequestCmd {
129		// Only create this special command if it is actually being called.
130		// This reduces possible side-effects of creating such a command;
131		// for example, having this command would cause problems to a
132		// cobra program that only consists of the root command, since this
133		// command would cause the root command to suddenly have a subcommand.
134		c.RemoveCommand(completeCmd)
135	}
136}
137
138func (c *Command) getCompletions(args []string) (*Command, []string, ShellCompDirective, error) {
139	var completions []string
140
141	// The last argument, which is not completely typed by the user,
142	// should not be part of the list of arguments
143	toComplete := args[len(args)-1]
144	trimmedArgs := args[:len(args)-1]
145
146	// Find the real command for which completion must be performed
147	finalCmd, finalArgs, err := c.Root().Find(trimmedArgs)
148	if err != nil {
149		// Unable to find the real command. E.g., <program> someInvalidCmd <TAB>
150		return c, completions, ShellCompDirectiveDefault, fmt.Errorf("Unable to find a command for arguments: %v", trimmedArgs)
151	}
152
153	// When doing completion of a flag name, as soon as an argument starts with
154	// a '-' we know it is a flag.  We cannot use isFlagArg() here as it requires
155	// the flag to be complete
156	if len(toComplete) > 0 && toComplete[0] == '-' && !strings.Contains(toComplete, "=") {
157		// We are completing a flag name
158		finalCmd.NonInheritedFlags().VisitAll(func(flag *pflag.Flag) {
159			completions = append(completions, getFlagNameCompletions(flag, toComplete)...)
160		})
161		finalCmd.InheritedFlags().VisitAll(func(flag *pflag.Flag) {
162			completions = append(completions, getFlagNameCompletions(flag, toComplete)...)
163		})
164
165		directive := ShellCompDirectiveDefault
166		if len(completions) > 0 {
167			if strings.HasSuffix(completions[0], "=") {
168				directive = ShellCompDirectiveNoSpace
169			}
170		}
171		return finalCmd, completions, directive, nil
172	}
173
174	var flag *pflag.Flag
175	if !finalCmd.DisableFlagParsing {
176		// We only do flag completion if we are allowed to parse flags
177		// This is important for commands which have requested to do their own flag completion.
178		flag, finalArgs, toComplete, err = checkIfFlagCompletion(finalCmd, finalArgs, toComplete)
179		if err != nil {
180			// Error while attempting to parse flags
181			return finalCmd, completions, ShellCompDirectiveDefault, err
182		}
183	}
184
185	if flag == nil {
186		// Complete subcommand names
187		for _, subCmd := range finalCmd.Commands() {
188			if subCmd.IsAvailableCommand() && strings.HasPrefix(subCmd.Name(), toComplete) {
189				completions = append(completions, fmt.Sprintf("%s\t%s", subCmd.Name(), subCmd.Short))
190			}
191		}
192
193		if len(finalCmd.ValidArgs) > 0 {
194			// Always complete ValidArgs, even if we are completing a subcommand name.
195			// This is for commands that have both subcommands and ValidArgs.
196			for _, validArg := range finalCmd.ValidArgs {
197				if strings.HasPrefix(validArg, toComplete) {
198					completions = append(completions, validArg)
199				}
200			}
201
202			// If there are ValidArgs specified (even if they don't match), we stop completion.
203			// Only one of ValidArgs or ValidArgsFunction can be used for a single command.
204			return finalCmd, completions, ShellCompDirectiveNoFileComp, nil
205		}
206
207		// Always let the logic continue so as to add any ValidArgsFunction completions,
208		// even if we already found sub-commands.
209		// This is for commands that have subcommands but also specify a ValidArgsFunction.
210	}
211
212	// Parse the flags and extract the arguments to prepare for calling the completion function
213	if err = finalCmd.ParseFlags(finalArgs); err != nil {
214		return finalCmd, completions, ShellCompDirectiveDefault, fmt.Errorf("Error while parsing flags from args %v: %s", finalArgs, err.Error())
215	}
216
217	// We only remove the flags from the arguments if DisableFlagParsing is not set.
218	// This is important for commands which have requested to do their own flag completion.
219	if !finalCmd.DisableFlagParsing {
220		finalArgs = finalCmd.Flags().Args()
221	}
222
223	// Find the completion function for the flag or command
224	var completionFn func(cmd *Command, args []string, toComplete string) ([]string, ShellCompDirective)
225	if flag != nil {
226		completionFn = flagCompletionFunctions[flag]
227	} else {
228		completionFn = finalCmd.ValidArgsFunction
229	}
230	if completionFn == nil {
231		// Go custom completion not supported/needed for this flag or command
232		return finalCmd, completions, ShellCompDirectiveDefault, nil
233	}
234
235	// Call the registered completion function to get the completions
236	comps, directive := completionFn(finalCmd, finalArgs, toComplete)
237	completions = append(completions, comps...)
238	return finalCmd, completions, directive, nil
239}
240
241func getFlagNameCompletions(flag *pflag.Flag, toComplete string) []string {
242	if nonCompletableFlag(flag) {
243		return []string{}
244	}
245
246	var completions []string
247	flagName := "--" + flag.Name
248	if strings.HasPrefix(flagName, toComplete) {
249		// Flag without the =
250		completions = append(completions, fmt.Sprintf("%s\t%s", flagName, flag.Usage))
251
252		if len(flag.NoOptDefVal) == 0 {
253			// Flag requires a value, so it can be suffixed with =
254			flagName += "="
255			completions = append(completions, fmt.Sprintf("%s\t%s", flagName, flag.Usage))
256		}
257	}
258
259	flagName = "-" + flag.Shorthand
260	if len(flag.Shorthand) > 0 && strings.HasPrefix(flagName, toComplete) {
261		completions = append(completions, fmt.Sprintf("%s\t%s", flagName, flag.Usage))
262	}
263
264	return completions
265}
266
267func checkIfFlagCompletion(finalCmd *Command, args []string, lastArg string) (*pflag.Flag, []string, string, error) {
268	var flagName string
269	trimmedArgs := args
270	flagWithEqual := false
271	if isFlagArg(lastArg) {
272		if index := strings.Index(lastArg, "="); index >= 0 {
273			flagName = strings.TrimLeft(lastArg[:index], "-")
274			lastArg = lastArg[index+1:]
275			flagWithEqual = true
276		} else {
277			return nil, nil, "", errors.New("Unexpected completion request for flag")
278		}
279	}
280
281	if len(flagName) == 0 {
282		if len(args) > 0 {
283			prevArg := args[len(args)-1]
284			if isFlagArg(prevArg) {
285				// Only consider the case where the flag does not contain an =.
286				// If the flag contains an = it means it has already been fully processed,
287				// so we don't need to deal with it here.
288				if index := strings.Index(prevArg, "="); index < 0 {
289					flagName = strings.TrimLeft(prevArg, "-")
290
291					// Remove the uncompleted flag or else there could be an error created
292					// for an invalid value for that flag
293					trimmedArgs = args[:len(args)-1]
294				}
295			}
296		}
297	}
298
299	if len(flagName) == 0 {
300		// Not doing flag completion
301		return nil, trimmedArgs, lastArg, nil
302	}
303
304	flag := findFlag(finalCmd, flagName)
305	if flag == nil {
306		// Flag not supported by this command, nothing to complete
307		err := fmt.Errorf("Subcommand '%s' does not support flag '%s'", finalCmd.Name(), flagName)
308		return nil, nil, "", err
309	}
310
311	if !flagWithEqual {
312		if len(flag.NoOptDefVal) != 0 {
313			// We had assumed dealing with a two-word flag but the flag is a boolean flag.
314			// In that case, there is no value following it, so we are not really doing flag completion.
315			// Reset everything to do noun completion.
316			trimmedArgs = args
317			flag = nil
318		}
319	}
320
321	return flag, trimmedArgs, lastArg, nil
322}
323
324func findFlag(cmd *Command, name string) *pflag.Flag {
325	flagSet := cmd.Flags()
326	if len(name) == 1 {
327		// First convert the short flag into a long flag
328		// as the cmd.Flag() search only accepts long flags
329		if short := flagSet.ShorthandLookup(name); short != nil {
330			name = short.Name
331		} else {
332			set := cmd.InheritedFlags()
333			if short = set.ShorthandLookup(name); short != nil {
334				name = short.Name
335			} else {
336				return nil
337			}
338		}
339	}
340	return cmd.Flag(name)
341}
342
// CompDebug prints the specified string, prefixed with "[Debug] ", to the same
// file as where the completion script prints its logs, and also to stderr when
// printToStdErr is true.
// The message is always printed verbatim: any '%' it contains is not
// interpreted as a formatting verb.
// Note that completion printouts should never be on stdout as they would
// be wrongly interpreted as actual completion choices by the completion script.
func CompDebug(msg string, printToStdErr bool) {
	msg = fmt.Sprintf("[Debug] %s", msg)

	// Such logs are only printed when the user has set the environment
	// variable BASH_COMP_DEBUG_FILE to the path of some file to be used.
	if path := os.Getenv("BASH_COMP_DEBUG_FILE"); path != "" {
		f, err := os.OpenFile(path,
			os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
		if err == nil {
			defer f.Close()
			// Debug logging is best effort: a failed write is deliberately ignored
			// so that it cannot disturb the completion itself.
			_, _ = f.WriteString(msg)
		}
	}

	if printToStdErr {
		// Must print to stderr for this not to be read by the completion script.
		// Fprint is used rather than Fprintf because msg is data, not a format string.
		fmt.Fprint(os.Stderr, msg)
	}
}
366
367// CompDebugln prints the specified string with a newline at the end
368// to the same file as where the completion script prints its logs.
369// Such logs are only printed when the user has set the environment
370// variable BASH_COMP_DEBUG_FILE to the path of some file to be used.
371func CompDebugln(msg string, printToStdErr bool) {
372	CompDebug(fmt.Sprintf("%s\n", msg), printToStdErr)
373}
374
375// CompError prints the specified completion message to stderr.
376func CompError(msg string) {
377	msg = fmt.Sprintf("[Error] %s", msg)
378	CompDebug(msg, true)
379}
380
381// CompErrorln prints the specified completion message to stderr with a newline at the end.
382func CompErrorln(msg string) {
383	CompError(fmt.Sprintf("%s\n", msg))
384}
385