1package cobra
2
3import (
4	"bytes"
5	"fmt"
6	"io"
7	"os"
8	"sort"
9	"strings"
10
11	"github.com/spf13/pflag"
12)
13
// Annotations for Bash completion.
//
// Each constant is a key looked up in a pflag.Flag's Annotations map by the
// bash completion generator in this file.
const (
	// BashCompFilenameExt marks a flag whose value is completed as a file
	// name. The annotation values, if any, are the accepted extensions; with
	// no values any file name is offered.
	BashCompFilenameExt = "cobra_annotation_bash_completion_filename_extensions"

	// BashCompCustom marks a flag whose value is completed by running the
	// shell command(s) given as the annotation values.
	BashCompCustom = "cobra_annotation_bash_completion_custom"

	// BashCompOneRequiredFlag marks a flag that is tracked in the generated
	// must_have_one_flag list until it appears on the command line.
	BashCompOneRequiredFlag = "cobra_annotation_bash_completion_one_required_flag"

	// BashCompSubdirsInDir marks a flag whose value is completed with the
	// subdirectories of the single directory given as the annotation value;
	// otherwise directories relative to the current one are offered.
	BashCompSubdirsInDir = "cobra_annotation_bash_completion_subdirs_in_dir"
)
21
// writePreamble writes the fixed bash helper functions used by every
// generated completion script. Every helper is named "__<name>_..." where
// name is the root command's name, and the root command function is
// expected to be "_<name>_root_command".
//
// The script text is a fmt format string, so each literal '%' in the shell
// code must be written as "%%". In particular the bash longest-suffix
// removal operator "${var%%pattern}" must appear as "%%%%" in the template;
// writing only "%%" would emit the shortest-suffix form "${var%pattern}",
// which splits "--flag=a=b" at the last '=' instead of the first.
func writePreamble(buf *bytes.Buffer, name string) {
	fmt.Fprintf(buf, "# bash completion for %-36s -*- shell-script -*-\n", name)
	fmt.Fprintf(buf, `
__%[1]s_debug()
{
    if [[ -n ${BASH_COMP_DEBUG_FILE} ]]; then
        echo "$*" >> "${BASH_COMP_DEBUG_FILE}"
    fi
}

# Homebrew on Macs have version 1.3 of bash-completion which doesn't include
# _init_completion. This is a very minimal version of that function.
__%[1]s_init_completion()
{
    COMPREPLY=()
    _get_comp_words_by_ref "$@" cur prev words cword
}

__%[1]s_index_of_word()
{
    local w word=$1
    shift
    index=0
    for w in "$@"; do
        [[ $w = "$word" ]] && return
        index=$((index+1))
    done
    index=-1
}

__%[1]s_contains_word()
{
    local w word=$1; shift
    for w in "$@"; do
        [[ $w = "$word" ]] && return
    done
    return 1
}

__%[1]s_handle_reply()
{
    __%[1]s_debug "${FUNCNAME[0]}"
    case $cur in
        -*)
            if [[ $(type -t compopt) = "builtin" ]]; then
                compopt -o nospace
            fi
            local allflags
            if [ ${#must_have_one_flag[@]} -ne 0 ]; then
                allflags=("${must_have_one_flag[@]}")
            else
                allflags=("${flags[*]} ${two_word_flags[*]}")
            fi
            COMPREPLY=( $(compgen -W "${allflags[*]}" -- "$cur") )
            if [[ $(type -t compopt) = "builtin" ]]; then
                [[ "${COMPREPLY[0]}" == *= ]] || compopt +o nospace
            fi

            # complete after --flag=abc
            if [[ $cur == *=* ]]; then
                if [[ $(type -t compopt) = "builtin" ]]; then
                    compopt +o nospace
                fi

                local index flag
                flag="${cur%%%%=*}"
                __%[1]s_index_of_word "${flag}" "${flags_with_completion[@]}"
                COMPREPLY=()
                if [[ ${index} -ge 0 ]]; then
                    PREFIX=""
                    cur="${cur#*=}"
                    ${flags_completion[${index}]}
                    if [ -n "${ZSH_VERSION}" ]; then
                        # zsh completion needs --flag= prefix
                        eval "COMPREPLY=( \"\${COMPREPLY[@]/#/${flag}=}\" )"
                    fi
                fi
            fi
            return 0;
            ;;
    esac

    # check if we are handling a flag with special work handling
    local index
    __%[1]s_index_of_word "${prev}" "${flags_with_completion[@]}"
    if [[ ${index} -ge 0 ]]; then
        ${flags_completion[${index}]}
        return
    fi

    # we are parsing a flag and don't have a special handler, no completion
    if [[ ${cur} != "${words[cword]}" ]]; then
        return
    fi

    local completions
    completions=("${commands[@]}")
    if [[ ${#must_have_one_noun[@]} -ne 0 ]]; then
        completions=("${must_have_one_noun[@]}")
    fi
    if [[ ${#must_have_one_flag[@]} -ne 0 ]]; then
        completions+=("${must_have_one_flag[@]}")
    fi
    COMPREPLY=( $(compgen -W "${completions[*]}" -- "$cur") )

    if [[ ${#COMPREPLY[@]} -eq 0 && ${#noun_aliases[@]} -gt 0 && ${#must_have_one_noun[@]} -ne 0 ]]; then
        COMPREPLY=( $(compgen -W "${noun_aliases[*]}" -- "$cur") )
    fi

    if [[ ${#COMPREPLY[@]} -eq 0 ]]; then
        if declare -F __%[1]s_custom_func >/dev/null; then
            # try command name qualified custom func
            __%[1]s_custom_func
        else
            # otherwise fall back to unqualified for compatibility
            declare -F __custom_func >/dev/null && __custom_func
        fi
    fi

    # available in bash-completion >= 2, not always present on macOS
    if declare -F __ltrim_colon_completions >/dev/null; then
        __ltrim_colon_completions "$cur"
    fi

    # If there is only 1 completion and it is a flag with an = it will be completed
    # but we don't want a space after the =
    if [[ "${#COMPREPLY[@]}" -eq "1" ]] && [[ $(type -t compopt) = "builtin" ]] && [[ "${COMPREPLY[0]}" == --*= ]]; then
        compopt -o nospace
    fi
}

# The arguments should be in the form "ext1|ext2|extn"
__%[1]s_handle_filename_extension_flag()
{
    local ext="$1"
    _filedir "@(${ext})"
}

__%[1]s_handle_subdirs_in_dir_flag()
{
    local dir="$1"
    # popd must run even if _filedir fails, or the user's shell stays in ${dir}
    pushd "${dir}" >/dev/null 2>&1 || return
    _filedir -d
    popd >/dev/null 2>&1
}

__%[1]s_handle_flag()
{
    __%[1]s_debug "${FUNCNAME[0]}: c is $c words[c] is ${words[c]}"

    # if a command required a flag, and we found it, unset must_have_one_flag()
    local flagname=${words[c]}
    local flagvalue
    # if the word contained an =
    if [[ ${words[c]} == *"="* ]]; then
        flagvalue=${flagname#*=} # take in as flagvalue after the =
        flagname=${flagname%%%%=*} # strip everything after the =
        flagname="${flagname}=" # but put the = back
    fi
    __%[1]s_debug "${FUNCNAME[0]}: looking for ${flagname}"
    if __%[1]s_contains_word "${flagname}" "${must_have_one_flag[@]}"; then
        must_have_one_flag=()
    fi

    # if you set a flag which only applies to this command, don't show subcommands
    if __%[1]s_contains_word "${flagname}" "${local_nonpersistent_flags[@]}"; then
        commands=()
    fi

    # keep flag value with flagname as flaghash
    # flaghash variable is an associative array which is only supported in bash > 3.
    if [[ -z "${BASH_VERSION}" || "${BASH_VERSINFO[0]}" -gt 3 ]]; then
        if [ -n "${flagvalue}" ] ; then
            flaghash[${flagname}]=${flagvalue}
        elif [ -n "${words[ $((c+1)) ]}" ] ; then
            flaghash[${flagname}]=${words[ $((c+1)) ]}
        else
            flaghash[${flagname}]="true" # pad "true" for bool flag
        fi
    fi

    # skip the argument to a two word flag
    if [[ ${words[c]} != *"="* ]] && __%[1]s_contains_word "${words[c]}" "${two_word_flags[@]}"; then
        __%[1]s_debug "${FUNCNAME[0]}: found a flag ${words[c]}, skip the next argument"
        c=$((c+1))
        # if we are looking for a flags value, don't show commands
        if [[ $c -eq $cword ]]; then
            commands=()
        fi
    fi

    c=$((c+1))

}

__%[1]s_handle_noun()
{
    __%[1]s_debug "${FUNCNAME[0]}: c is $c words[c] is ${words[c]}"

    if __%[1]s_contains_word "${words[c]}" "${must_have_one_noun[@]}"; then
        must_have_one_noun=()
    elif __%[1]s_contains_word "${words[c]}" "${noun_aliases[@]}"; then
        must_have_one_noun=()
    fi

    nouns+=("${words[c]}")
    c=$((c+1))
}

__%[1]s_handle_command()
{
    __%[1]s_debug "${FUNCNAME[0]}: c is $c words[c] is ${words[c]}"

    local next_command
    if [[ -n ${last_command} ]]; then
        next_command="_${last_command}_${words[c]//:/__}"
    else
        if [[ $c -eq 0 ]]; then
            next_command="_%[1]s_root_command"
        else
            next_command="_${words[c]//:/__}"
        fi
    fi
    c=$((c+1))
    __%[1]s_debug "${FUNCNAME[0]}: looking for ${next_command}"
    declare -F "$next_command" >/dev/null && $next_command
}

__%[1]s_handle_word()
{
    if [[ $c -ge $cword ]]; then
        __%[1]s_handle_reply
        return
    fi
    __%[1]s_debug "${FUNCNAME[0]}: c is $c words[c] is ${words[c]}"
    if [[ "${words[c]}" == -* ]]; then
        __%[1]s_handle_flag
    elif __%[1]s_contains_word "${words[c]}" "${commands[@]}"; then
        __%[1]s_handle_command
    elif [[ $c -eq 0 ]]; then
        __%[1]s_handle_command
    elif __%[1]s_contains_word "${words[c]}" "${command_aliases[@]}"; then
        # aliashash variable is an associative array which is only supported in bash > 3.
        if [[ -z "${BASH_VERSION}" || "${BASH_VERSINFO[0]}" -gt 3 ]]; then
            words[c]=${aliashash[${words[c]}]}
            __%[1]s_handle_command
        else
            __%[1]s_handle_noun
        fi
    else
        __%[1]s_handle_noun
    fi
    __%[1]s_handle_word
}

`, name)
}
277
// writePostscript writes the __start_<name> entry point and the `complete`
// registration that binds it to the root command.
//
// name is the root command's name, unmodified. Only the entry point's own
// function name has ':' replaced by "__". The helper functions emitted by
// writePreamble ("__<name>_init_completion", "__<name>_handle_word") are
// defined with the raw name, so calls to them must use it too. The
// `complete` target is the command the user actually types, so it is also
// left unchanged.
func writePostscript(buf *bytes.Buffer, name string) {
	startFunc := "__start_" + strings.Replace(name, ":", "__", -1)
	fmt.Fprintf(buf, "%s()\n", startFunc)
	fmt.Fprintf(buf, `{
    local cur prev words cword
    declare -A flaghash 2>/dev/null || :
    declare -A aliashash 2>/dev/null || :
    if declare -F _init_completion >/dev/null 2>&1; then
        _init_completion -s || return
    else
        __%[1]s_init_completion -n "=" || return
    fi

    local c=0
    local flags=()
    local two_word_flags=()
    local local_nonpersistent_flags=()
    local flags_with_completion=()
    local flags_completion=()
    local commands=("%[1]s")
    local must_have_one_flag=()
    local must_have_one_noun=()
    local last_command
    local nouns=()

    __%[1]s_handle_word
}

`, name)
	// Without compopt we cannot toggle nospace per completion, so nospace is
	// set for the whole command instead.
	fmt.Fprintf(buf, `if [[ $(type -t compopt) = "builtin" ]]; then
    complete -o default -F %[1]s %[2]s
else
    complete -o default -o nospace -F %[1]s %[2]s
fi

`, startFunc, name)
	buf.WriteString("# ex: ts=4 sw=4 et filetype=sh\n")
}
316
// writeCommands emits the body lines that reset the bash `commands` array
// and fill it with cmd's completable subcommands. A subcommand is skipped
// when it is not available or is cmd's help command. Each listed subcommand
// is followed by its alias registrations. A blank line terminates the
// section.
func writeCommands(buf *bytes.Buffer, cmd *Command) {
	buf.WriteString("    commands=()\n")
	for _, sub := range cmd.Commands() {
		if sub.IsAvailableCommand() && sub != cmd.helpCommand {
			fmt.Fprintf(buf, "    commands+=(%q)\n", sub.Name())
			writeCmdAliases(buf, sub)
		}
	}
	buf.WriteByte('\n')
}
328
// writeFlagHandler emits the value-completion handler for one flag, based on
// its bash-completion annotations.
//
// name is the flag as typed, including its dashes (e.g. "--file" or "-f").
// For every recognized annotation, name is appended to flags_with_completion
// and a shell command is appended to flags_completion, keeping the two
// arrays index-aligned. Unrecognized annotation keys are ignored. cmd is
// used only to obtain the root command's name, which prefixes the generated
// helper function calls.
//
// Annotation keys are visited in sorted order so the generated script is
// deterministic. Ranging over the map directly would emit entries in Go's
// randomized map order, which also makes it unpredictable which handler the
// shell's first-match lookup picks when a flag carries several annotations.
func writeFlagHandler(buf *bytes.Buffer, name string, annotations map[string][]string, cmd *Command) {
	keys := make([]string, 0, len(annotations))
	for key := range annotations {
		keys = append(keys, key)
	}
	sort.Strings(keys)

	for _, key := range keys {
		value := annotations[key]
		switch key {
		case BashCompFilenameExt:
			fmt.Fprintf(buf, "    flags_with_completion+=(%q)\n", name)

			var ext string
			if len(value) > 0 {
				// Extensions become an extglob alternation: "@(ext1|ext2)".
				ext = fmt.Sprintf("__%s_handle_filename_extension_flag ", cmd.Root().Name()) + strings.Join(value, "|")
			} else {
				ext = "_filedir"
			}
			fmt.Fprintf(buf, "    flags_completion+=(%q)\n", ext)
		case BashCompCustom:
			fmt.Fprintf(buf, "    flags_with_completion+=(%q)\n", name)
			if len(value) > 0 {
				handlers := strings.Join(value, "; ")
				fmt.Fprintf(buf, "    flags_completion+=(%q)\n", handlers)
			} else {
				// ":" is the shell no-op, keeping the arrays aligned.
				buf.WriteString("    flags_completion+=(:)\n")
			}
		case BashCompSubdirsInDir:
			fmt.Fprintf(buf, "    flags_with_completion+=(%q)\n", name)

			var ext string
			if len(value) == 1 {
				ext = fmt.Sprintf("__%s_handle_subdirs_in_dir_flag ", cmd.Root().Name()) + value[0]
			} else {
				ext = "_filedir -d"
			}
			fmt.Fprintf(buf, "    flags_completion+=(%q)\n", ext)
		}
	}
}
363
// writeShortFlag registers the single-dash form of flag. A flag with no
// NoOptDefVal needs a value, so it goes into two_word_flags (its next word is
// consumed as the value). Otherwise it goes into flags. Any value-completion
// handler for the short form is emitted afterwards.
func writeShortFlag(buf *bytes.Buffer, flag *pflag.Flag, cmd *Command) {
	short := flag.Shorthand
	array := "flags"
	if flag.NoOptDefVal == "" {
		array = "two_word_flags"
	}
	fmt.Fprintf(buf, "    %s+=(\"-%s\")\n", array, short)
	writeFlagHandler(buf, "-"+short, flag.Annotations, cmd)
}
374
// writeFlag registers the double-dash form of flag in the bash `flags`
// array. A flag that requires a value (empty NoOptDefVal) is listed as
// "--name=" so completing it leaves the cursor after '='. It is also listed
// as "--name" in two_word_flags so "--name value" is recognized. Any
// value-completion handler is emitted afterwards.
func writeFlag(buf *bytes.Buffer, flag *pflag.Flag, cmd *Command) {
	long := flag.Name
	takesValue := len(flag.NoOptDefVal) == 0

	suffix := ""
	if takesValue {
		suffix = "="
	}
	fmt.Fprintf(buf, "    flags+=(\"--%s%s\")\n", long, suffix)

	if takesValue {
		fmt.Fprintf(buf, "    two_word_flags+=(\"--%s\")\n", long)
	}
	writeFlagHandler(buf, "--"+long, flag.Annotations, cmd)
}
389
// writeLocalNonPersistentFlag appends flag to local_nonpersistent_flags,
// spelled the same way writeFlag spells it in `flags`: "--name=" when the
// flag requires a value, "--name" otherwise. When the generated script sees
// one of these flags it clears the subcommand list.
func writeLocalNonPersistentFlag(buf *bytes.Buffer, flag *pflag.Flag) {
	entry := "--" + flag.Name
	if flag.NoOptDefVal == "" {
		entry += "="
	}
	fmt.Fprintf(buf, "    local_nonpersistent_flags+=(\"%s\")\n", entry)
}
399
// writeFlags emits the flag-related arrays for cmd. It first resets them,
// then registers every completable flag: the command's own (non-inherited)
// flags first, then inherited ones. Only non-inherited flags can be marked
// local non-persistent. The section ends with a blank line.
func writeFlags(buf *bytes.Buffer, cmd *Command) {
	for _, array := range []string{
		"flags",
		"two_word_flags",
		"local_nonpersistent_flags",
		"flags_with_completion",
		"flags_completion",
	} {
		fmt.Fprintf(buf, "    %s=()\n", array)
	}
	buf.WriteString("\n")

	localOnly := cmd.LocalNonPersistentFlags()
	emit := func(flag *pflag.Flag, checkLocal bool) {
		if nonCompletableFlag(flag) {
			return
		}
		writeFlag(buf, flag, cmd)
		if flag.Shorthand != "" {
			writeShortFlag(buf, flag, cmd)
		}
		if checkLocal && localOnly.Lookup(flag.Name) != nil {
			writeLocalNonPersistentFlag(buf, flag)
		}
	}

	cmd.NonInheritedFlags().VisitAll(func(flag *pflag.Flag) { emit(flag, true) })
	cmd.InheritedFlags().VisitAll(func(flag *pflag.Flag) { emit(flag, false) })

	buf.WriteString("\n")
}
433
// writeRequiredFlag resets must_have_one_flag and appends every completable
// non-inherited flag of cmd that carries the BashCompOneRequiredFlag
// annotation. The long form is always added; the short form is added when
// the flag has one.
//
// The long form gets a trailing "=" exactly when the flag requires a value
// (empty NoOptDefVal). This matches how writeFlag and
// writeLocalNonPersistentFlag spell the same flag. The generated
// handle_flag function compares the typed flag name against these entries,
// and handle_reply offers them as completions, so the spelling must agree
// with `flags`. Keying off the value type ("bool") instead left
// optional-value flags such as counters listed as "--name=", which never
// matches and keeps the requirement unsatisfied.
func writeRequiredFlag(buf *bytes.Buffer, cmd *Command) {
	buf.WriteString("    must_have_one_flag=()\n")
	flags := cmd.NonInheritedFlags()
	flags.VisitAll(func(flag *pflag.Flag) {
		if nonCompletableFlag(flag) {
			return
		}
		if _, required := flag.Annotations[BashCompOneRequiredFlag]; !required {
			return
		}

		format := "    must_have_one_flag+=(\"--%s"
		if len(flag.NoOptDefVal) == 0 {
			format += "="
		}
		format += "\")\n"
		fmt.Fprintf(buf, format, flag.Name)

		if len(flag.Shorthand) > 0 {
			fmt.Fprintf(buf, "    must_have_one_flag+=(\"-%s\")\n", flag.Shorthand)
		}
	})
}
458
// writeRequiredNouns resets must_have_one_noun and fills it with cmd's
// ValidArgs in sorted order.
//
// The arguments are sorted on a copy. Sorting cmd.ValidArgs in place would
// reorder the caller's command as a side effect of generating completions.
func writeRequiredNouns(buf *bytes.Buffer, cmd *Command) {
	buf.WriteString("    must_have_one_noun=()\n")
	nouns := append([]string(nil), cmd.ValidArgs...)
	sort.Strings(nouns)
	for _, value := range nouns {
		fmt.Fprintf(buf, "    must_have_one_noun+=(%q)\n", value)
	}
}
466
// writeCmdAliases registers cmd's aliases in sorted order. Each alias is
// appended to command_aliases and mapped back to cmd's real name in
// aliashash. Nothing is written when cmd has no aliases.
//
// The emitted lines are wrapped in a bash-version guard because aliashash is
// an associative array, which the generated script only declares on bash > 3
// or non-bash shells.
//
// The aliases are sorted on a copy. Sorting cmd.Aliases in place would
// permanently reorder the command's aliases (e.g. as later shown in help
// output) merely because a completion script was generated.
func writeCmdAliases(buf *bytes.Buffer, cmd *Command) {
	if len(cmd.Aliases) == 0 {
		return
	}

	aliases := append([]string(nil), cmd.Aliases...)
	sort.Strings(aliases)

	buf.WriteString(`    if [[ -z "${BASH_VERSION}" || "${BASH_VERSINFO[0]}" -gt 3 ]]; then` + "\n")
	for _, alias := range aliases {
		fmt.Fprintf(buf, "        command_aliases+=(%q)\n", alias)
		fmt.Fprintf(buf, "        aliashash[%q]=%q\n", alias, cmd.Name())
	}
	buf.WriteString("    fi\n")
}

// writeArgAliases resets noun_aliases and fills it with cmd's ArgAliases in
// sorted order. The generated script falls back to these aliases when
// ValidArgs produce no completion.
//
// The aliases are sorted on a copy so cmd.ArgAliases is not reordered as a
// side effect of generating completions.
func writeArgAliases(buf *bytes.Buffer, cmd *Command) {
	buf.WriteString("    noun_aliases=()\n")
	aliases := append([]string(nil), cmd.ArgAliases...)
	sort.Strings(aliases)
	for _, value := range aliases {
		fmt.Fprintf(buf, "    noun_aliases+=(%q)\n", value)
	}
}
489
// gen writes the per-command bash function for cmd after first recursing
// into every available subcommand (the help command is skipped). The
// function name is cmd's full path with spaces turned into '_' and colons
// into "__", prefixed with '_'. The root command's function additionally
// gets a "_root_command" suffix. The function sets last_command and then
// (re)initializes all per-command completion arrays.
func gen(buf *bytes.Buffer, cmd *Command) {
	for _, child := range cmd.Commands() {
		if child.IsAvailableCommand() && child != cmd.helpCommand {
			gen(buf, child)
		}
	}

	fnName := strings.NewReplacer(" ", "_", ":", "__").Replace(cmd.CommandPath())

	header := "_" + fnName + "()\n{\n"
	if cmd.Root() == cmd {
		header = "_" + fnName + "_root_command()\n{\n"
	}
	buf.WriteString(header)

	fmt.Fprintf(buf, "    last_command=%q\n\n    command_aliases=()\n\n", fnName)

	writeCommands(buf, cmd)
	writeFlags(buf, cmd)
	writeRequiredFlag(buf, cmd)
	writeRequiredNouns(buf, cmd)
	writeArgAliases(buf, cmd)

	buf.WriteString("}\n\n")
}
519
520// GenBashCompletion generates bash completion file and writes to the passed writer.
521func (c *Command) GenBashCompletion(w io.Writer) error {
522	buf := new(bytes.Buffer)
523	writePreamble(buf, c.Name())
524	if len(c.BashCompletionFunction) > 0 {
525		buf.WriteString(c.BashCompletionFunction + "\n")
526	}
527	gen(buf, c)
528	writePostscript(buf, c.Name())
529
530	_, err := buf.WriteTo(w)
531	return err
532}
533
534func nonCompletableFlag(flag *pflag.Flag) bool {
535	return flag.Hidden || len(flag.Deprecated) > 0
536}
537
// GenBashCompletionFile generates bash completion file.
//
// filename is created (or truncated if it exists) and the script produced
// by GenBashCompletion is written to it. If generation succeeds, any error
// from closing the file is returned. For a file that was just written, a
// Close failure may mean the data was not fully persisted, so it must not be
// dropped the way a deferred Close would drop it.
func (c *Command) GenBashCompletionFile(filename string) error {
	outFile, err := os.Create(filename)
	if err != nil {
		return err
	}

	if err := c.GenBashCompletion(outFile); err != nil {
		// Best-effort close; the generation error is the one worth reporting.
		outFile.Close()
		return err
	}
	return outFile.Close()
}
548