1package cobra
2
3import (
4	"bytes"
5	"fmt"
6	"io"
7	"os"
8	"sort"
9	"strings"
10
11	"github.com/spf13/pflag"
12)
13
// Annotations for Bash completion.
//
// Each constant is a key looked up in a flag's Annotations map while the
// completion script is generated:
//
//   - BashCompFilenameExt: the values are joined with "|" and handed to the
//     generated __<root>_handle_filename_extension_flag shell function; with
//     no values the flag completes with plain _filedir.
//   - BashCompCustom: the values are joined with "; " and emitted as the
//     flag's completion handler; with no values the handler is ":".
//   - BashCompOneRequiredFlag: the flag (and its shorthand, if any) is added
//     to the must_have_one_flag array of the command's completion function.
//   - BashCompSubdirsInDir: a single value is passed as the directory to the
//     generated __<root>_handle_subdirs_in_dir_flag shell function; otherwise
//     the flag completes with "_filedir -d".
const (
	BashCompFilenameExt     = "cobra_annotation_bash_completion_filename_extensions"
	BashCompCustom          = "cobra_annotation_bash_completion_custom"
	BashCompOneRequiredFlag = "cobra_annotation_bash_completion_one_required_flag"
	BashCompSubdirsInDir    = "cobra_annotation_bash_completion_subdirs_in_dir"
)
21
// writePreamble appends the fixed head of the bash completion script to buf:
// a header comment followed by the shell helper functions (debug logging,
// word lookup, reply handling and the flag/noun/command/word handlers) that
// the per-command functions rely on.
//
// name is interpolated verbatim wherever the template says %[1]s, so every
// helper is emitted as __<name>_... .
//
// The template is a fmt format string, which means a literal '%' in the shell
// code has to be written as "%%". The shell expansions that strip the longest
// "=..." suffix (${var%%=*}) are therefore spelled "%%%%" below; with only
// "%%" the generated script would contain ${var%=*}, which cuts at the LAST
// '=' and turns "--flag=a=b" into the flag name "--flag=a".
func writePreamble(buf *bytes.Buffer, name string) {
	fmt.Fprintf(buf, "# bash completion for %-36s -*- shell-script -*-\n", name)
	fmt.Fprintf(buf, `
__%[1]s_debug()
{
    if [[ -n ${BASH_COMP_DEBUG_FILE} ]]; then
        echo "$*" >> "${BASH_COMP_DEBUG_FILE}"
    fi
}

# Homebrew on Macs have version 1.3 of bash-completion which doesn't include
# _init_completion. This is a very minimal version of that function.
__%[1]s_init_completion()
{
    COMPREPLY=()
    _get_comp_words_by_ref "$@" cur prev words cword
}

__%[1]s_index_of_word()
{
    local w word=$1
    shift
    index=0
    for w in "$@"; do
        [[ $w = "$word" ]] && return
        index=$((index+1))
    done
    index=-1
}

__%[1]s_contains_word()
{
    local w word=$1; shift
    for w in "$@"; do
        [[ $w = "$word" ]] && return
    done
    return 1
}

__%[1]s_handle_reply()
{
    __%[1]s_debug "${FUNCNAME[0]}"
    local comp
    case $cur in
        -*)
            if [[ $(type -t compopt) = "builtin" ]]; then
                compopt -o nospace
            fi
            local allflags
            if [ ${#must_have_one_flag[@]} -ne 0 ]; then
                allflags=("${must_have_one_flag[@]}")
            else
                allflags=("${flags[*]} ${two_word_flags[*]}")
            fi
            while IFS='' read -r comp; do
                COMPREPLY+=("$comp")
            done < <(compgen -W "${allflags[*]}" -- "$cur")
            if [[ $(type -t compopt) = "builtin" ]]; then
                [[ "${COMPREPLY[0]}" == *= ]] || compopt +o nospace
            fi

            # complete after --flag=abc
            if [[ $cur == *=* ]]; then
                if [[ $(type -t compopt) = "builtin" ]]; then
                    compopt +o nospace
                fi

                local index flag
                flag="${cur%%%%=*}"
                __%[1]s_index_of_word "${flag}" "${flags_with_completion[@]}"
                COMPREPLY=()
                if [[ ${index} -ge 0 ]]; then
                    PREFIX=""
                    cur="${cur#*=}"
                    ${flags_completion[${index}]}
                    if [ -n "${ZSH_VERSION}" ]; then
                        # zsh completion needs --flag= prefix
                        eval "COMPREPLY=( \"\${COMPREPLY[@]/#/${flag}=}\" )"
                    fi
                fi
            fi
            return 0;
            ;;
    esac

    # check if we are handling a flag with special work handling
    local index
    __%[1]s_index_of_word "${prev}" "${flags_with_completion[@]}"
    if [[ ${index} -ge 0 ]]; then
        ${flags_completion[${index}]}
        return
    fi

    # we are parsing a flag and don't have a special handler, no completion
    if [[ ${cur} != "${words[cword]}" ]]; then
        return
    fi

    local completions
    completions=("${commands[@]}")
    if [[ ${#must_have_one_noun[@]} -ne 0 ]]; then
        completions=("${must_have_one_noun[@]}")
    fi
    if [[ ${#must_have_one_flag[@]} -ne 0 ]]; then
        completions+=("${must_have_one_flag[@]}")
    fi
    while IFS='' read -r comp; do
        COMPREPLY+=("$comp")
    done < <(compgen -W "${completions[*]}" -- "$cur")

    if [[ ${#COMPREPLY[@]} -eq 0 && ${#noun_aliases[@]} -gt 0 && ${#must_have_one_noun[@]} -ne 0 ]]; then
        while IFS='' read -r comp; do
            COMPREPLY+=("$comp")
        done < <(compgen -W "${noun_aliases[*]}" -- "$cur")
    fi

    if [[ ${#COMPREPLY[@]} -eq 0 ]]; then
		if declare -F __%[1]s_custom_func >/dev/null; then
			# try command name qualified custom func
			__%[1]s_custom_func
		else
			# otherwise fall back to unqualified for compatibility
			declare -F __custom_func >/dev/null && __custom_func
		fi
    fi

    # available in bash-completion >= 2, not always present on macOS
    if declare -F __ltrim_colon_completions >/dev/null; then
        __ltrim_colon_completions "$cur"
    fi

    # If there is only 1 completion and it is a flag with an = it will be completed
    # but we don't want a space after the =
    if [[ "${#COMPREPLY[@]}" -eq "1" ]] && [[ $(type -t compopt) = "builtin" ]] && [[ "${COMPREPLY[0]}" == --*= ]]; then
       compopt -o nospace
    fi
}

# The arguments should be in the form "ext1|ext2|extn"
__%[1]s_handle_filename_extension_flag()
{
    local ext="$1"
    _filedir "@(${ext})"
}

__%[1]s_handle_subdirs_in_dir_flag()
{
    local dir="$1"
    pushd "${dir}" >/dev/null 2>&1 && _filedir -d && popd >/dev/null 2>&1 || return
}

__%[1]s_handle_flag()
{
    __%[1]s_debug "${FUNCNAME[0]}: c is $c words[c] is ${words[c]}"

    # if a command required a flag, and we found it, unset must_have_one_flag()
    local flagname=${words[c]}
    local flagvalue
    # if the word contained an =
    if [[ ${words[c]} == *"="* ]]; then
        flagvalue=${flagname#*=} # take in as flagvalue after the =
        flagname=${flagname%%%%=*} # strip everything after the =
        flagname="${flagname}=" # but put the = back
    fi
    __%[1]s_debug "${FUNCNAME[0]}: looking for ${flagname}"
    if __%[1]s_contains_word "${flagname}" "${must_have_one_flag[@]}"; then
        must_have_one_flag=()
    fi

    # if you set a flag which only applies to this command, don't show subcommands
    if __%[1]s_contains_word "${flagname}" "${local_nonpersistent_flags[@]}"; then
      commands=()
    fi

    # keep flag value with flagname as flaghash
    # flaghash variable is an associative array which is only supported in bash > 3.
    if [[ -z "${BASH_VERSION}" || "${BASH_VERSINFO[0]}" -gt 3 ]]; then
        if [ -n "${flagvalue}" ] ; then
            flaghash[${flagname}]=${flagvalue}
        elif [ -n "${words[ $((c+1)) ]}" ] ; then
            flaghash[${flagname}]=${words[ $((c+1)) ]}
        else
            flaghash[${flagname}]="true" # pad "true" for bool flag
        fi
    fi

    # skip the argument to a two word flag
    if [[ ${words[c]} != *"="* ]] && __%[1]s_contains_word "${words[c]}" "${two_word_flags[@]}"; then
			  __%[1]s_debug "${FUNCNAME[0]}: found a flag ${words[c]}, skip the next argument"
        c=$((c+1))
        # if we are looking for a flags value, don't show commands
        if [[ $c -eq $cword ]]; then
            commands=()
        fi
    fi

    c=$((c+1))

}

__%[1]s_handle_noun()
{
    __%[1]s_debug "${FUNCNAME[0]}: c is $c words[c] is ${words[c]}"

    if __%[1]s_contains_word "${words[c]}" "${must_have_one_noun[@]}"; then
        must_have_one_noun=()
    elif __%[1]s_contains_word "${words[c]}" "${noun_aliases[@]}"; then
        must_have_one_noun=()
    fi

    nouns+=("${words[c]}")
    c=$((c+1))
}

__%[1]s_handle_command()
{
    __%[1]s_debug "${FUNCNAME[0]}: c is $c words[c] is ${words[c]}"

    local next_command
    if [[ -n ${last_command} ]]; then
        next_command="_${last_command}_${words[c]//:/__}"
    else
        if [[ $c -eq 0 ]]; then
            next_command="_%[1]s_root_command"
        else
            next_command="_${words[c]//:/__}"
        fi
    fi
    c=$((c+1))
    __%[1]s_debug "${FUNCNAME[0]}: looking for ${next_command}"
    declare -F "$next_command" >/dev/null && $next_command
}

__%[1]s_handle_word()
{
    if [[ $c -ge $cword ]]; then
        __%[1]s_handle_reply
        return
    fi
    __%[1]s_debug "${FUNCNAME[0]}: c is $c words[c] is ${words[c]}"
    if [[ "${words[c]}" == -* ]]; then
        __%[1]s_handle_flag
    elif __%[1]s_contains_word "${words[c]}" "${commands[@]}"; then
        __%[1]s_handle_command
    elif [[ $c -eq 0 ]]; then
        __%[1]s_handle_command
    elif __%[1]s_contains_word "${words[c]}" "${command_aliases[@]}"; then
        # aliashash variable is an associative array which is only supported in bash > 3.
        if [[ -z "${BASH_VERSION}" || "${BASH_VERSINFO[0]}" -gt 3 ]]; then
            words[c]=${aliashash[${words[c]}]}
            __%[1]s_handle_command
        else
            __%[1]s_handle_noun
        fi
    else
        __%[1]s_handle_noun
    fi
    __%[1]s_handle_word
}

`, name)
}
284
// writePostscript appends the tail of the completion script to buf: the
// __start_<name> entry function, which declares the completion state and
// kicks off word handling, the `complete` registration for the command, and
// the trailing editor modeline.
//
// Every ':' in name is replaced by "__" before it is interpolated.
func writePostscript(buf *bytes.Buffer, name string) {
	const postscript = `__start_%[1]s()
{
    local cur prev words cword
    declare -A flaghash 2>/dev/null || :
    declare -A aliashash 2>/dev/null || :
    if declare -F _init_completion >/dev/null 2>&1; then
        _init_completion -s || return
    else
        __%[1]s_init_completion -n "=" || return
    fi

    local c=0
    local flags=()
    local two_word_flags=()
    local local_nonpersistent_flags=()
    local flags_with_completion=()
    local flags_completion=()
    local commands=("%[1]s")
    local must_have_one_flag=()
    local must_have_one_noun=()
    local last_command
    local nouns=()

    __%[1]s_handle_word
}

if [[ $(type -t compopt) = "builtin" ]]; then
    complete -o default -F __start_%[1]s %[1]s
else
    complete -o default -o nospace -F __start_%[1]s %[1]s
fi

# ex: ts=4 sw=4 et filetype=sh
`
	shellName := strings.Replace(name, ":", "__", -1)
	fmt.Fprintf(buf, postscript, shellName)
}
323
324func writeCommands(buf *bytes.Buffer, cmd *Command) {
325	buf.WriteString("    commands=()\n")
326	for _, c := range cmd.Commands() {
327		if !c.IsAvailableCommand() || c == cmd.helpCommand {
328			continue
329		}
330		buf.WriteString(fmt.Sprintf("    commands+=(%q)\n", c.Name()))
331		writeCmdAliases(buf, c)
332	}
333	buf.WriteString("\n")
334}
335
336func writeFlagHandler(buf *bytes.Buffer, name string, annotations map[string][]string, cmd *Command) {
337	for key, value := range annotations {
338		switch key {
339		case BashCompFilenameExt:
340			buf.WriteString(fmt.Sprintf("    flags_with_completion+=(%q)\n", name))
341
342			var ext string
343			if len(value) > 0 {
344				ext = fmt.Sprintf("__%s_handle_filename_extension_flag ", cmd.Root().Name()) + strings.Join(value, "|")
345			} else {
346				ext = "_filedir"
347			}
348			buf.WriteString(fmt.Sprintf("    flags_completion+=(%q)\n", ext))
349		case BashCompCustom:
350			buf.WriteString(fmt.Sprintf("    flags_with_completion+=(%q)\n", name))
351			if len(value) > 0 {
352				handlers := strings.Join(value, "; ")
353				buf.WriteString(fmt.Sprintf("    flags_completion+=(%q)\n", handlers))
354			} else {
355				buf.WriteString("    flags_completion+=(:)\n")
356			}
357		case BashCompSubdirsInDir:
358			buf.WriteString(fmt.Sprintf("    flags_with_completion+=(%q)\n", name))
359
360			var ext string
361			if len(value) == 1 {
362				ext = fmt.Sprintf("__%s_handle_subdirs_in_dir_flag ", cmd.Root().Name()) + value[0]
363			} else {
364				ext = "_filedir -d"
365			}
366			buf.WriteString(fmt.Sprintf("    flags_completion+=(%q)\n", ext))
367		}
368	}
369}
370
371func writeShortFlag(buf *bytes.Buffer, flag *pflag.Flag, cmd *Command) {
372	name := flag.Shorthand
373	format := "    "
374	if len(flag.NoOptDefVal) == 0 {
375		format += "two_word_"
376	}
377	format += "flags+=(\"-%s\")\n"
378	buf.WriteString(fmt.Sprintf(format, name))
379	writeFlagHandler(buf, "-"+name, flag.Annotations, cmd)
380}
381
382func writeFlag(buf *bytes.Buffer, flag *pflag.Flag, cmd *Command) {
383	name := flag.Name
384	format := "    flags+=(\"--%s"
385	if len(flag.NoOptDefVal) == 0 {
386		format += "="
387	}
388	format += "\")\n"
389	buf.WriteString(fmt.Sprintf(format, name))
390	if len(flag.NoOptDefVal) == 0 {
391		format = "    two_word_flags+=(\"--%s\")\n"
392		buf.WriteString(fmt.Sprintf(format, name))
393	}
394	writeFlagHandler(buf, "--"+name, flag.Annotations, cmd)
395}
396
397func writeLocalNonPersistentFlag(buf *bytes.Buffer, flag *pflag.Flag) {
398	name := flag.Name
399	format := "    local_nonpersistent_flags+=(\"--%s"
400	if len(flag.NoOptDefVal) == 0 {
401		format += "="
402	}
403	format += "\")\n"
404	buf.WriteString(fmt.Sprintf(format, name))
405}
406
407func writeFlags(buf *bytes.Buffer, cmd *Command) {
408	buf.WriteString(`    flags=()
409    two_word_flags=()
410    local_nonpersistent_flags=()
411    flags_with_completion=()
412    flags_completion=()
413
414`)
415	localNonPersistentFlags := cmd.LocalNonPersistentFlags()
416	cmd.NonInheritedFlags().VisitAll(func(flag *pflag.Flag) {
417		if nonCompletableFlag(flag) {
418			return
419		}
420		writeFlag(buf, flag, cmd)
421		if len(flag.Shorthand) > 0 {
422			writeShortFlag(buf, flag, cmd)
423		}
424		if localNonPersistentFlags.Lookup(flag.Name) != nil {
425			writeLocalNonPersistentFlag(buf, flag)
426		}
427	})
428	cmd.InheritedFlags().VisitAll(func(flag *pflag.Flag) {
429		if nonCompletableFlag(flag) {
430			return
431		}
432		writeFlag(buf, flag, cmd)
433		if len(flag.Shorthand) > 0 {
434			writeShortFlag(buf, flag, cmd)
435		}
436	})
437
438	buf.WriteString("\n")
439}
440
441func writeRequiredFlag(buf *bytes.Buffer, cmd *Command) {
442	buf.WriteString("    must_have_one_flag=()\n")
443	flags := cmd.NonInheritedFlags()
444	flags.VisitAll(func(flag *pflag.Flag) {
445		if nonCompletableFlag(flag) {
446			return
447		}
448		for key := range flag.Annotations {
449			switch key {
450			case BashCompOneRequiredFlag:
451				format := "    must_have_one_flag+=(\"--%s"
452				if flag.Value.Type() != "bool" {
453					format += "="
454				}
455				format += "\")\n"
456				buf.WriteString(fmt.Sprintf(format, flag.Name))
457
458				if len(flag.Shorthand) > 0 {
459					buf.WriteString(fmt.Sprintf("    must_have_one_flag+=(\"-%s\")\n", flag.Shorthand))
460				}
461			}
462		}
463	})
464}
465
466func writeRequiredNouns(buf *bytes.Buffer, cmd *Command) {
467	buf.WriteString("    must_have_one_noun=()\n")
468	sort.Sort(sort.StringSlice(cmd.ValidArgs))
469	for _, value := range cmd.ValidArgs {
470		buf.WriteString(fmt.Sprintf("    must_have_one_noun+=(%q)\n", value))
471	}
472}
473
474func writeCmdAliases(buf *bytes.Buffer, cmd *Command) {
475	if len(cmd.Aliases) == 0 {
476		return
477	}
478
479	sort.Sort(sort.StringSlice(cmd.Aliases))
480
481	buf.WriteString(fmt.Sprint(`    if [[ -z "${BASH_VERSION}" || "${BASH_VERSINFO[0]}" -gt 3 ]]; then`, "\n"))
482	for _, value := range cmd.Aliases {
483		buf.WriteString(fmt.Sprintf("        command_aliases+=(%q)\n", value))
484		buf.WriteString(fmt.Sprintf("        aliashash[%q]=%q\n", value, cmd.Name()))
485	}
486	buf.WriteString(`    fi`)
487	buf.WriteString("\n")
488}
489func writeArgAliases(buf *bytes.Buffer, cmd *Command) {
490	buf.WriteString("    noun_aliases=()\n")
491	sort.Sort(sort.StringSlice(cmd.ArgAliases))
492	for _, value := range cmd.ArgAliases {
493		buf.WriteString(fmt.Sprintf("    noun_aliases+=(%q)\n", value))
494	}
495}
496
497func gen(buf *bytes.Buffer, cmd *Command) {
498	for _, c := range cmd.Commands() {
499		if !c.IsAvailableCommand() || c == cmd.helpCommand {
500			continue
501		}
502		gen(buf, c)
503	}
504	commandName := cmd.CommandPath()
505	commandName = strings.Replace(commandName, " ", "_", -1)
506	commandName = strings.Replace(commandName, ":", "__", -1)
507
508	if cmd.Root() == cmd {
509		buf.WriteString(fmt.Sprintf("_%s_root_command()\n{\n", commandName))
510	} else {
511		buf.WriteString(fmt.Sprintf("_%s()\n{\n", commandName))
512	}
513
514	buf.WriteString(fmt.Sprintf("    last_command=%q\n", commandName))
515	buf.WriteString("\n")
516	buf.WriteString("    command_aliases=()\n")
517	buf.WriteString("\n")
518
519	writeCommands(buf, cmd)
520	writeFlags(buf, cmd)
521	writeRequiredFlag(buf, cmd)
522	writeRequiredNouns(buf, cmd)
523	writeArgAliases(buf, cmd)
524	buf.WriteString("}\n\n")
525}
526
527// GenBashCompletion generates bash completion file and writes to the passed writer.
528func (c *Command) GenBashCompletion(w io.Writer) error {
529	buf := new(bytes.Buffer)
530	writePreamble(buf, c.Name())
531	if len(c.BashCompletionFunction) > 0 {
532		buf.WriteString(c.BashCompletionFunction + "\n")
533	}
534	gen(buf, c)
535	writePostscript(buf, c.Name())
536
537	_, err := buf.WriteTo(w)
538	return err
539}
540
541func nonCompletableFlag(flag *pflag.Flag) bool {
542	return flag.Hidden || len(flag.Deprecated) > 0
543}
544
545// GenBashCompletionFile generates bash completion file.
546func (c *Command) GenBashCompletionFile(filename string) error {
547	outFile, err := os.Create(filename)
548	if err != nil {
549		return err
550	}
551	defer outFile.Close()
552
553	return c.GenBashCompletion(outFile)
554}
555