1package cobra
2
3import (
4	"bytes"
5	"fmt"
6	"io"
7	"os"
8	"sort"
9	"strings"
10
11	"github.com/spf13/pflag"
12)
13
// Annotation keys understood by the Bash completion generator. Each one is
// used as a key in a flag's Annotations map to change how that flag is
// completed.
const (
	// BashCompFilenameExt completes the flag's argument with file names that
	// match one of the extensions listed in the annotation value, or with any
	// file name when the value list is empty.
	BashCompFilenameExt = "cobra_annotation_bash_completion_filename_extensions"

	// BashCompCustom completes the flag's argument by running the bash
	// function(s) named in the annotation value.
	BashCompCustom = "cobra_annotation_bash_completion_custom"

	// BashCompOneRequiredFlag marks a flag as required; such flags are put in
	// the generated must_have_one_flag list and suggested until one of them
	// appears on the command line.
	BashCompOneRequiredFlag = "cobra_annotation_bash_completion_one_required_flag"

	// BashCompSubdirsInDir completes the flag's argument with subdirectories
	// of the single directory given in the annotation value, or with any
	// directory otherwise.
	BashCompSubdirsInDir = "cobra_annotation_bash_completion_subdirs_in_dir"
)
21
// writePreamble writes the header comment and the shared bash helper functions
// used by the generated completion script. Every helper is named with the
// prefix "__<name>_", where name is the root command's name:
//
//   - debug: append its arguments to $BASH_COMP_DEBUG_FILE when that is set.
//   - init_completion: minimal fallback for bash-completion's _init_completion.
//   - index_of_word / contains_word: search a list of words.
//   - handle_reply: fill COMPREPLY once the word being completed is reached.
//   - handle_filename_extension_flag / handle_subdirs_in_dir_flag: flag
//     argument completers built on _filedir.
//   - handle_flag / handle_noun / handle_command / handle_word: walk the words
//     already on the command line, updating the completion state.
//
// The script body is a fmt format string, so a literal "%" must be written as
// "%%". The bash expansions ${cur%%=*} and ${flagname%%=*} (strip everything
// from the FIRST "=") are therefore spelled "%%%%" below. A bare "%%" would be
// emitted as a single "%", which strips only from the LAST "=" and mis-parses
// arguments whose value itself contains "=", such as --flag=a=b.
func writePreamble(buf *bytes.Buffer, name string) {
	fmt.Fprintf(buf, "# bash completion for %-36s -*- shell-script -*-\n", name)
	fmt.Fprintf(buf, `
__%[1]s_debug()
{
    if [[ -n ${BASH_COMP_DEBUG_FILE} ]]; then
        echo "$*" >> "${BASH_COMP_DEBUG_FILE}"
    fi
}

# Homebrew on Macs have version 1.3 of bash-completion which doesn't include
# _init_completion. This is a very minimal version of that function.
__%[1]s_init_completion()
{
    COMPREPLY=()
    _get_comp_words_by_ref "$@" cur prev words cword
}

__%[1]s_index_of_word()
{
    local w word=$1
    shift
    index=0
    for w in "$@"; do
        [[ $w = "$word" ]] && return
        index=$((index+1))
    done
    index=-1
}

__%[1]s_contains_word()
{
    local w word=$1; shift
    for w in "$@"; do
        [[ $w = "$word" ]] && return
    done
    return 1
}

__%[1]s_handle_reply()
{
    __%[1]s_debug "${FUNCNAME[0]}"
    case $cur in
        -*)
            if [[ $(type -t compopt) = "builtin" ]]; then
                compopt -o nospace
            fi
            local allflags
            if [ ${#must_have_one_flag[@]} -ne 0 ]; then
                allflags=("${must_have_one_flag[@]}")
            else
                allflags=("${flags[*]} ${two_word_flags[*]}")
            fi
            COMPREPLY=( $(compgen -W "${allflags[*]}" -- "$cur") )
            if [[ $(type -t compopt) = "builtin" ]]; then
                [[ "${COMPREPLY[0]}" == *= ]] || compopt +o nospace
            fi

            # complete after --flag=abc
            if [[ $cur == *=* ]]; then
                if [[ $(type -t compopt) = "builtin" ]]; then
                    compopt +o nospace
                fi

                local index flag
                flag="${cur%%%%=*}"
                __%[1]s_index_of_word "${flag}" "${flags_with_completion[@]}"
                COMPREPLY=()
                if [[ ${index} -ge 0 ]]; then
                    PREFIX=""
                    cur="${cur#*=}"
                    ${flags_completion[${index}]}
                    if [ -n "${ZSH_VERSION}" ]; then
                        # zsh completion needs --flag= prefix
                        eval "COMPREPLY=( \"\${COMPREPLY[@]/#/${flag}=}\" )"
                    fi
                fi
            fi
            return 0;
            ;;
    esac

    # check if we are handling a flag with special work handling
    local index
    __%[1]s_index_of_word "${prev}" "${flags_with_completion[@]}"
    if [[ ${index} -ge 0 ]]; then
        ${flags_completion[${index}]}
        return
    fi

    # we are parsing a flag and don't have a special handler, no completion
    if [[ ${cur} != "${words[cword]}" ]]; then
        return
    fi

    local completions
    completions=("${commands[@]}")
    if [[ ${#must_have_one_noun[@]} -ne 0 ]]; then
        completions=("${must_have_one_noun[@]}")
    fi
    if [[ ${#must_have_one_flag[@]} -ne 0 ]]; then
        completions+=("${must_have_one_flag[@]}")
    fi
    COMPREPLY=( $(compgen -W "${completions[*]}" -- "$cur") )

    if [[ ${#COMPREPLY[@]} -eq 0 && ${#noun_aliases[@]} -gt 0 && ${#must_have_one_noun[@]} -ne 0 ]]; then
        COMPREPLY=( $(compgen -W "${noun_aliases[*]}" -- "$cur") )
    fi

    if [[ ${#COMPREPLY[@]} -eq 0 ]]; then
        declare -F __custom_func >/dev/null && __custom_func
    fi

    # available in bash-completion >= 2, not always present on macOS
    if declare -F __ltrim_colon_completions >/dev/null; then
        __ltrim_colon_completions "$cur"
    fi

    # If there is only 1 completion and it is a flag with an = it will be completed
    # but we don't want a space after the =
    if [[ "${#COMPREPLY[@]}" -eq "1" ]] && [[ $(type -t compopt) = "builtin" ]] && [[ "${COMPREPLY[0]}" == --*= ]]; then
       compopt -o nospace
    fi
}

# The arguments should be in the form "ext1|ext2|extn"
__%[1]s_handle_filename_extension_flag()
{
    local ext="$1"
    _filedir "@(${ext})"
}

__%[1]s_handle_subdirs_in_dir_flag()
{
    local dir="$1"
    pushd "${dir}" >/dev/null 2>&1 && _filedir -d && popd >/dev/null 2>&1
}

__%[1]s_handle_flag()
{
    __%[1]s_debug "${FUNCNAME[0]}: c is $c words[c] is ${words[c]}"

    # if a command required a flag, and we found it, unset must_have_one_flag()
    local flagname=${words[c]}
    local flagvalue
    # if the word contained an =
    if [[ ${words[c]} == *"="* ]]; then
        flagvalue=${flagname#*=} # take in as flagvalue after the =
        flagname=${flagname%%%%=*} # strip everything after the =
        flagname="${flagname}=" # but put the = back
    fi
    __%[1]s_debug "${FUNCNAME[0]}: looking for ${flagname}"
    if __%[1]s_contains_word "${flagname}" "${must_have_one_flag[@]}"; then
        must_have_one_flag=()
    fi

    # if you set a flag which only applies to this command, don't show subcommands
    if __%[1]s_contains_word "${flagname}" "${local_nonpersistent_flags[@]}"; then
      commands=()
    fi

    # keep flag value with flagname as flaghash
    # flaghash variable is an associative array which is only supported in bash > 3.
    if [[ -z "${BASH_VERSION}" || "${BASH_VERSINFO[0]}" -gt 3 ]]; then
        if [ -n "${flagvalue}" ] ; then
            flaghash[${flagname}]=${flagvalue}
        elif [ -n "${words[ $((c+1)) ]}" ] ; then
            flaghash[${flagname}]=${words[ $((c+1)) ]}
        else
            flaghash[${flagname}]="true" # pad "true" for bool flag
        fi
    fi

    # skip the argument to a two word flag
    if __%[1]s_contains_word "${words[c]}" "${two_word_flags[@]}"; then
        c=$((c+1))
        # if we are looking for a flags value, don't show commands
        if [[ $c -eq $cword ]]; then
            commands=()
        fi
    fi

    c=$((c+1))

}

__%[1]s_handle_noun()
{
    __%[1]s_debug "${FUNCNAME[0]}: c is $c words[c] is ${words[c]}"

    if __%[1]s_contains_word "${words[c]}" "${must_have_one_noun[@]}"; then
        must_have_one_noun=()
    elif __%[1]s_contains_word "${words[c]}" "${noun_aliases[@]}"; then
        must_have_one_noun=()
    fi

    nouns+=("${words[c]}")
    c=$((c+1))
}

__%[1]s_handle_command()
{
    __%[1]s_debug "${FUNCNAME[0]}: c is $c words[c] is ${words[c]}"

    local next_command
    if [[ -n ${last_command} ]]; then
        next_command="_${last_command}_${words[c]//:/__}"
    else
        if [[ $c -eq 0 ]]; then
            next_command="_%[1]s_root_command"
        else
            next_command="_${words[c]//:/__}"
        fi
    fi
    c=$((c+1))
    __%[1]s_debug "${FUNCNAME[0]}: looking for ${next_command}"
    declare -F "$next_command" >/dev/null && $next_command
}

__%[1]s_handle_word()
{
    if [[ $c -ge $cword ]]; then
        __%[1]s_handle_reply
        return
    fi
    __%[1]s_debug "${FUNCNAME[0]}: c is $c words[c] is ${words[c]}"
    if [[ "${words[c]}" == -* ]]; then
        __%[1]s_handle_flag
    elif __%[1]s_contains_word "${words[c]}" "${commands[@]}"; then
        __%[1]s_handle_command
    elif [[ $c -eq 0 ]]; then
        __%[1]s_handle_command
    elif __%[1]s_contains_word "${words[c]}" "${command_aliases[@]}"; then
        # aliashash variable is an associative array which is only supported in bash > 3.
        if [[ -z "${BASH_VERSION}" || "${BASH_VERSINFO[0]}" -gt 3 ]]; then
            words[c]=${aliashash[${words[c]}]}
            __%[1]s_handle_command
        else
            __%[1]s_handle_noun
        fi
    else
        __%[1]s_handle_noun
    fi
    __%[1]s_handle_word
}

`, name)
}
270
// writePostscript writes the completion entry point and its registration.
// It defines the bash function __start_<name>, which initializes the
// completion state (cur/prev/words/cword plus the per-command arrays) and
// starts word processing with __<name>_handle_word. It then registers that
// function with `complete` (adding -o nospace only when the compopt builtin is
// unavailable) and ends with an editor modeline. Any ":" in name is replaced
// with "__" before name is used.
func writePostscript(buf *bytes.Buffer, name string) {
	fnName := strings.Replace(name, ":", "__", -1)

	fmt.Fprintf(buf, "__start_%s()\n", fnName)
	fmt.Fprintf(buf, `{
    local cur prev words cword
    declare -A flaghash 2>/dev/null || :
    declare -A aliashash 2>/dev/null || :
    if declare -F _init_completion >/dev/null 2>&1; then
        _init_completion -s || return
    else
        __%[1]s_init_completion -n "=" || return
    fi

    local c=0
    local flags=()
    local two_word_flags=()
    local local_nonpersistent_flags=()
    local flags_with_completion=()
    local flags_completion=()
    local commands=("%[1]s")
    local must_have_one_flag=()
    local must_have_one_noun=()
    local last_command
    local nouns=()

    __%[1]s_handle_word
}

`, fnName)

	// Registration: the four %s verbs are all the sanitized name.
	fmt.Fprintf(buf, `if [[ $(type -t compopt) = "builtin" ]]; then
    complete -o default -F __start_%s %s
else
    complete -o default -o nospace -F __start_%s %s
fi

`, fnName, fnName, fnName, fnName)

	buf.WriteString("# ex: ts=4 sw=4 et filetype=sh\n")
}
309
310func writeCommands(buf *bytes.Buffer, cmd *Command) {
311	buf.WriteString("    commands=()\n")
312	for _, c := range cmd.Commands() {
313		if !c.IsAvailableCommand() || c == cmd.helpCommand {
314			continue
315		}
316		buf.WriteString(fmt.Sprintf("    commands+=(%q)\n", c.Name()))
317		writeCmdAliases(buf, c)
318	}
319	buf.WriteString("\n")
320}
321
322func writeFlagHandler(buf *bytes.Buffer, name string, annotations map[string][]string, cmd *Command) {
323	for key, value := range annotations {
324		switch key {
325		case BashCompFilenameExt:
326			buf.WriteString(fmt.Sprintf("    flags_with_completion+=(%q)\n", name))
327
328			var ext string
329			if len(value) > 0 {
330				ext = fmt.Sprintf("__%s_handle_filename_extension_flag ", cmd.Root().Name()) + strings.Join(value, "|")
331			} else {
332				ext = "_filedir"
333			}
334			buf.WriteString(fmt.Sprintf("    flags_completion+=(%q)\n", ext))
335		case BashCompCustom:
336			buf.WriteString(fmt.Sprintf("    flags_with_completion+=(%q)\n", name))
337			if len(value) > 0 {
338				handlers := strings.Join(value, "; ")
339				buf.WriteString(fmt.Sprintf("    flags_completion+=(%q)\n", handlers))
340			} else {
341				buf.WriteString("    flags_completion+=(:)\n")
342			}
343		case BashCompSubdirsInDir:
344			buf.WriteString(fmt.Sprintf("    flags_with_completion+=(%q)\n", name))
345
346			var ext string
347			if len(value) == 1 {
348				ext = fmt.Sprintf("__%s_handle_subdirs_in_dir_flag ", cmd.Root().Name()) + value[0]
349			} else {
350				ext = "_filedir -d"
351			}
352			buf.WriteString(fmt.Sprintf("    flags_completion+=(%q)\n", ext))
353		}
354	}
355}
356
357func writeShortFlag(buf *bytes.Buffer, flag *pflag.Flag, cmd *Command) {
358	name := flag.Shorthand
359	format := "    "
360	if len(flag.NoOptDefVal) == 0 {
361		format += "two_word_"
362	}
363	format += "flags+=(\"-%s\")\n"
364	buf.WriteString(fmt.Sprintf(format, name))
365	writeFlagHandler(buf, "-"+name, flag.Annotations, cmd)
366}
367
368func writeFlag(buf *bytes.Buffer, flag *pflag.Flag, cmd *Command) {
369	name := flag.Name
370	format := "    flags+=(\"--%s"
371	if len(flag.NoOptDefVal) == 0 {
372		format += "="
373	}
374	format += "\")\n"
375	buf.WriteString(fmt.Sprintf(format, name))
376	writeFlagHandler(buf, "--"+name, flag.Annotations, cmd)
377}
378
379func writeLocalNonPersistentFlag(buf *bytes.Buffer, flag *pflag.Flag) {
380	name := flag.Name
381	format := "    local_nonpersistent_flags+=(\"--%s"
382	if len(flag.NoOptDefVal) == 0 {
383		format += "="
384	}
385	format += "\")\n"
386	buf.WriteString(fmt.Sprintf(format, name))
387}
388
389func writeFlags(buf *bytes.Buffer, cmd *Command) {
390	buf.WriteString(`    flags=()
391    two_word_flags=()
392    local_nonpersistent_flags=()
393    flags_with_completion=()
394    flags_completion=()
395
396`)
397	localNonPersistentFlags := cmd.LocalNonPersistentFlags()
398	cmd.NonInheritedFlags().VisitAll(func(flag *pflag.Flag) {
399		if nonCompletableFlag(flag) {
400			return
401		}
402		writeFlag(buf, flag, cmd)
403		if len(flag.Shorthand) > 0 {
404			writeShortFlag(buf, flag, cmd)
405		}
406		if localNonPersistentFlags.Lookup(flag.Name) != nil {
407			writeLocalNonPersistentFlag(buf, flag)
408		}
409	})
410	cmd.InheritedFlags().VisitAll(func(flag *pflag.Flag) {
411		if nonCompletableFlag(flag) {
412			return
413		}
414		writeFlag(buf, flag, cmd)
415		if len(flag.Shorthand) > 0 {
416			writeShortFlag(buf, flag, cmd)
417		}
418	})
419
420	buf.WriteString("\n")
421}
422
423func writeRequiredFlag(buf *bytes.Buffer, cmd *Command) {
424	buf.WriteString("    must_have_one_flag=()\n")
425	flags := cmd.NonInheritedFlags()
426	flags.VisitAll(func(flag *pflag.Flag) {
427		if nonCompletableFlag(flag) {
428			return
429		}
430		for key := range flag.Annotations {
431			switch key {
432			case BashCompOneRequiredFlag:
433				format := "    must_have_one_flag+=(\"--%s"
434				if flag.Value.Type() != "bool" {
435					format += "="
436				}
437				format += "\")\n"
438				buf.WriteString(fmt.Sprintf(format, flag.Name))
439
440				if len(flag.Shorthand) > 0 {
441					buf.WriteString(fmt.Sprintf("    must_have_one_flag+=(\"-%s\")\n", flag.Shorthand))
442				}
443			}
444		}
445	})
446}
447
448func writeRequiredNouns(buf *bytes.Buffer, cmd *Command) {
449	buf.WriteString("    must_have_one_noun=()\n")
450	sort.Sort(sort.StringSlice(cmd.ValidArgs))
451	for _, value := range cmd.ValidArgs {
452		buf.WriteString(fmt.Sprintf("    must_have_one_noun+=(%q)\n", value))
453	}
454}
455
456func writeCmdAliases(buf *bytes.Buffer, cmd *Command) {
457	if len(cmd.Aliases) == 0 {
458		return
459	}
460
461	sort.Sort(sort.StringSlice(cmd.Aliases))
462
463	buf.WriteString(fmt.Sprint(`    if [[ -z "${BASH_VERSION}" || "${BASH_VERSINFO[0]}" -gt 3 ]]; then`, "\n"))
464	for _, value := range cmd.Aliases {
465		buf.WriteString(fmt.Sprintf("        command_aliases+=(%q)\n", value))
466		buf.WriteString(fmt.Sprintf("        aliashash[%q]=%q\n", value, cmd.Name()))
467	}
468	buf.WriteString(`    fi`)
469	buf.WriteString("\n")
470}
471func writeArgAliases(buf *bytes.Buffer, cmd *Command) {
472	buf.WriteString("    noun_aliases=()\n")
473	sort.Sort(sort.StringSlice(cmd.ArgAliases))
474	for _, value := range cmd.ArgAliases {
475		buf.WriteString(fmt.Sprintf("    noun_aliases+=(%q)\n", value))
476	}
477}
478
479func gen(buf *bytes.Buffer, cmd *Command) {
480	for _, c := range cmd.Commands() {
481		if !c.IsAvailableCommand() || c == cmd.helpCommand {
482			continue
483		}
484		gen(buf, c)
485	}
486	commandName := cmd.CommandPath()
487	commandName = strings.Replace(commandName, " ", "_", -1)
488	commandName = strings.Replace(commandName, ":", "__", -1)
489
490	if cmd.Root() == cmd {
491		buf.WriteString(fmt.Sprintf("_%s_root_command()\n{\n", commandName))
492	} else {
493		buf.WriteString(fmt.Sprintf("_%s()\n{\n", commandName))
494	}
495
496	buf.WriteString(fmt.Sprintf("    last_command=%q\n", commandName))
497	buf.WriteString("\n")
498	buf.WriteString("    command_aliases=()\n")
499	buf.WriteString("\n")
500
501	writeCommands(buf, cmd)
502	writeFlags(buf, cmd)
503	writeRequiredFlag(buf, cmd)
504	writeRequiredNouns(buf, cmd)
505	writeArgAliases(buf, cmd)
506	buf.WriteString("}\n\n")
507}
508
509// GenBashCompletion generates bash completion file and writes to the passed writer.
510func (c *Command) GenBashCompletion(w io.Writer) error {
511	buf := new(bytes.Buffer)
512	writePreamble(buf, c.Name())
513	if len(c.BashCompletionFunction) > 0 {
514		buf.WriteString(c.BashCompletionFunction + "\n")
515	}
516	gen(buf, c)
517	writePostscript(buf, c.Name())
518
519	_, err := buf.WriteTo(w)
520	return err
521}
522
523func nonCompletableFlag(flag *pflag.Flag) bool {
524	return flag.Hidden || len(flag.Deprecated) > 0
525}
526
527// GenBashCompletionFile generates bash completion file.
528func (c *Command) GenBashCompletionFile(filename string) error {
529	outFile, err := os.Create(filename)
530	if err != nil {
531		return err
532	}
533	defer outFile.Close()
534
535	return c.GenBashCompletion(outFile)
536}
537