1package cobra
2
3import (
4	"bytes"
5	"fmt"
6	"os"
7	"sort"
8	"strings"
9
10	"github.com/spf13/pflag"
11)
12
// Annotation keys understood by the bash completion generator. They are stored
// on pflag flags with FlagSet.SetAnnotation and read back while the completion
// script is generated.
const (
	// BashCompFilenameExt marks a flag whose value is completed as a file name.
	// The annotation's values, if any, are the permitted file extensions.
	// The misspelling "extentions" is part of the stored annotation key and
	// must be kept as is.
	BashCompFilenameExt = "cobra_annotation_bash_completion_filename_extentions"

	// BashCompOneRequiredFlag marks a flag as required: completion offers the
	// command's required flags exclusively until one of them has been typed.
	BashCompOneRequiredFlag = "cobra_annotation_bash_completion_one_required_flag"

	// BashCompSubdirsInDir marks a flag whose value is completed as a
	// directory. With exactly one annotation value, completion lists the
	// subdirectories of that directory instead of the current one.
	BashCompSubdirsInDir = "cobra_annotation_bash_completion_subdirs_in_dir"
)
18
// preamble writes the program-independent bash helper functions that every
// generated completion script relies on (__handle_word, __handle_flag,
// __handle_reply, ...).
//
// The script is written verbatim with WriteString rather than used as a fmt
// format string. It contains no formatting directives. Passing it through
// Fprintf silently turned the `${flagname%%=*}` expansion in __handle_flag
// into `${flagname%=*}`, which strips only from the LAST '='. As a result, a
// flag typed as --opt=a=b was not recognised as the required --opt=.
func preamble(out *bytes.Buffer) {
	// bytes.Buffer.WriteString always returns a nil error.
	out.WriteString(`#!/bin/bash

__debug()
{
    if [[ -n ${BASH_COMP_DEBUG_FILE} ]]; then
        echo "$*" >> "${BASH_COMP_DEBUG_FILE}"
    fi
}

# Homebrew on Macs have version 1.3 of bash-completion which doesn't include
# _init_completion. This is a very minimal version of that function.
__my_init_completion()
{
    COMPREPLY=()
    _get_comp_words_by_ref cur prev words cword
}

__index_of_word()
{
    local w word=$1
    shift
    index=0
    for w in "$@"; do
        [[ $w = "$word" ]] && return
        index=$((index+1))
    done
    index=-1
}

__contains_word()
{
    local w word=$1; shift
    for w in "$@"; do
        [[ $w = "$word" ]] && return
    done
    return 1
}

__handle_reply()
{
    __debug "${FUNCNAME}"
    case $cur in
        -*)
            if [[ $(type -t compopt) = "builtin" ]]; then
                compopt -o nospace
            fi
            local allflags
            if [ ${#must_have_one_flag[@]} -ne 0 ]; then
                allflags=("${must_have_one_flag[@]}")
            else
                allflags=("${flags[*]} ${two_word_flags[*]}")
            fi
            COMPREPLY=( $(compgen -W "${allflags[*]}" -- "$cur") )
            if [[ $(type -t compopt) = "builtin" ]]; then
                [[ $COMPREPLY == *= ]] || compopt +o nospace
            fi
            return 0;
            ;;
    esac

    # check if we are handling a flag with special work handling
    local index
    __index_of_word "${prev}" "${flags_with_completion[@]}"
    if [[ ${index} -ge 0 ]]; then
        ${flags_completion[${index}]}
        return
    fi

    # we are parsing a flag and don't have a special handler, no completion
    if [[ ${cur} != "${words[cword]}" ]]; then
        return
    fi

    local completions
    if [[ ${#must_have_one_flag[@]} -ne 0 ]]; then
        completions=("${must_have_one_flag[@]}")
    elif [[ ${#must_have_one_noun[@]} -ne 0 ]]; then
        completions=("${must_have_one_noun[@]}")
    else
        completions=("${commands[@]}")
    fi
    COMPREPLY=( $(compgen -W "${completions[*]}" -- "$cur") )

    if [[ ${#COMPREPLY[@]} -eq 0 ]]; then
        declare -F __custom_func >/dev/null && __custom_func
    fi
}

# The arguments should be in the form "ext1|ext2|extn"
__handle_filename_extension_flag()
{
    local ext="$1"
    _filedir "@(${ext})"
}

__handle_subdirs_in_dir_flag()
{
    local dir="$1"
    pushd "${dir}" >/dev/null 2>&1 && _filedir -d && popd >/dev/null 2>&1
}

__handle_flag()
{
    __debug "${FUNCNAME}: c is $c words[c] is ${words[c]}"

    # if a command required a flag, and we found it, unset must_have_one_flag()
    local flagname=${words[c]}
    # if the word contained an =
    if [[ ${words[c]} == *"="* ]]; then
        flagname=${flagname%%=*} # strip everything after the =
        flagname="${flagname}=" # but put the = back
    fi
    __debug "${FUNCNAME}: looking for ${flagname}"
    if __contains_word "${flagname}" "${must_have_one_flag[@]}"; then
        must_have_one_flag=()
    fi

    # skip the argument to a two word flag
    if __contains_word "${words[c]}" "${two_word_flags[@]}"; then
        c=$((c+1))
        # if we are looking for a flags value, don't show commands
        if [[ $c -eq $cword ]]; then
            commands=()
        fi
    fi

    # skip the flag itself
    c=$((c+1))

}

__handle_noun()
{
    __debug "${FUNCNAME}: c is $c words[c] is ${words[c]}"

    if __contains_word "${words[c]}" "${must_have_one_noun[@]}"; then
        must_have_one_noun=()
    fi

    nouns+=("${words[c]}")
    c=$((c+1))
}

__handle_command()
{
    __debug "${FUNCNAME}: c is $c words[c] is ${words[c]}"

    local next_command
    if [[ -n ${last_command} ]]; then
        next_command="_${last_command}_${words[c]}"
    else
        next_command="_${words[c]}"
    fi
    c=$((c+1))
    __debug "${FUNCNAME}: looking for ${next_command}"
    declare -F $next_command >/dev/null && $next_command
}

__handle_word()
{
    if [[ $c -ge $cword ]]; then
        __handle_reply
        return
    fi
    __debug "${FUNCNAME}: c is $c words[c] is ${words[c]}"
    if [[ "${words[c]}" == -* ]]; then
        __handle_flag
    elif __contains_word "${words[c]}" "${commands[@]}"; then
        __handle_command
    else
        __handle_noun
    fi
    __handle_word
}

`)
}
197
198func postscript(out *bytes.Buffer, name string) {
199	fmt.Fprintf(out, "__start_%s()\n", name)
200	fmt.Fprintf(out, `{
201    local cur prev words cword
202    if declare -F _init_completion >/dev/null 2>&1; then
203        _init_completion -s || return
204    else
205        __my_init_completion || return
206    fi
207
208    local c=0
209    local flags=()
210    local two_word_flags=()
211    local flags_with_completion=()
212    local flags_completion=()
213    local commands=("%s")
214    local must_have_one_flag=()
215    local must_have_one_noun=()
216    local last_command
217    local nouns=()
218
219    __handle_word
220}
221
222`, name)
223	fmt.Fprintf(out, `if [[ $(type -t compopt) = "builtin" ]]; then
224    complete -F __start_%s %s
225else
226    complete -o nospace -F __start_%s %s
227fi
228
229`, name, name, name, name)
230	fmt.Fprintf(out, "# ex: ts=4 sw=4 et filetype=sh\n")
231}
232
// writeCommands emits the lines that reset the commands array and refill it
// with the names of cmd's subcommands. Unavailable commands and the help
// command are skipped. A blank line terminates the section.
func writeCommands(cmd *Command, out *bytes.Buffer) {
	fmt.Fprintln(out, "    commands=()")
	for _, sub := range cmd.Commands() {
		if sub.IsAvailableCommand() && sub != cmd.helpCommand {
			fmt.Fprintf(out, "    commands+=(%q)\n", sub.Name())
		}
	}
	fmt.Fprintln(out)
}
243
// writeFlagHandler emits the completion handler for the flag spelled name
// ("--long" or "-s") when its annotations request one.
//
// For each recognised annotation it appends name to flags_with_completion and
// the bash command to run to flags_completion. The generated script pairs the
// two arrays by index and uses the first entry that matches the flag. The
// annotations are therefore checked in a fixed order (filename extensions,
// then subdirectories) instead of ranging over the map. That keeps the script
// deterministic when a flag carries both.
func writeFlagHandler(name string, annotations map[string][]string, out *bytes.Buffer) {
	if value, ok := annotations[BashCompFilenameExt]; ok {
		fmt.Fprintf(out, "    flags_with_completion+=(%q)\n", name)
		// With no extensions listed, plain file completion is used.
		ext := "_filedir"
		if len(value) > 0 {
			ext = "__handle_filename_extension_flag " + strings.Join(value, "|")
		}
		fmt.Fprintf(out, "    flags_completion+=(%q)\n", ext)
	}
	if value, ok := annotations[BashCompSubdirsInDir]; ok {
		fmt.Fprintf(out, "    flags_with_completion+=(%q)\n", name)
		// Only a single directory is supported; anything else falls back to
		// completing directories relative to the current one.
		ext := "_filedir -d"
		if len(value) == 1 {
			ext = "__handle_subdirs_in_dir_flag " + value[0]
		}
		fmt.Fprintf(out, "    flags_completion+=(%q)\n", ext)
	}
}
270
// writeShortFlag registers the one-letter spelling "-x" of flag. Boolean flags
// go into flags. Every other flag goes into two_word_flags, because the
// script then treats the following word as the flag's value and skips it.
// Any completion-handler annotations are also emitted for the short spelling.
func writeShortFlag(flag *pflag.Flag, out *bytes.Buffer) {
	short := "-" + flag.Shorthand
	array := "flags"
	if flag.Value.Type() != "bool" {
		array = "two_word_flags"
	}
	fmt.Fprintf(out, "    %s+=(\"%s\")\n", array, short)
	writeFlagHandler(short, flag.Annotations, out)
}
282
// writeFlag registers the long spelling "--name" of flag in the flags array.
// Non-boolean flags get a trailing "=" so the suggestion reads "--name=". The
// script's __handle_reply keeps nospace on for replies ending in "=", so the
// value can be typed straight after it. Completion-handler annotations are
// emitted for the long spelling.
func writeFlag(flag *pflag.Flag, out *bytes.Buffer) {
	long := "--" + flag.Name
	suffix := ""
	if flag.Value.Type() != "bool" {
		suffix = "="
	}
	fmt.Fprintf(out, "    flags+=(\"%s%s\")\n", long, suffix)
	writeFlagHandler(long, flag.Annotations, out)
}
294
// writeFlags resets the four flag arrays used by the completion script. It
// then registers every flag visible to cmd, its own flags first and then the
// inherited ones. Each flag is registered in its long form and, when it has
// one, its shorthand form. A blank line terminates the section.
func writeFlags(cmd *Command, out *bytes.Buffer) {
	for _, array := range []string{"flags", "two_word_flags", "flags_with_completion", "flags_completion"} {
		fmt.Fprintf(out, "    %s=()\n", array)
	}
	fmt.Fprintln(out)

	register := func(flag *pflag.Flag) {
		writeFlag(flag, out)
		if flag.Shorthand != "" {
			writeShortFlag(flag, out)
		}
	}
	cmd.NonInheritedFlags().VisitAll(register)
	cmd.InheritedFlags().VisitAll(register)

	fmt.Fprintln(out)
}
317
// writeRequiredFlag resets must_have_one_flag and adds every non-inherited
// flag of cmd that carries the BashCompOneRequiredFlag annotation. The long
// spelling is always added, with a trailing "=" for non-boolean flags, and the
// shorthand too when there is one. The script offers only these flags until
// one of them has been typed.
func writeRequiredFlag(cmd *Command, out *bytes.Buffer) {
	fmt.Fprintln(out, "    must_have_one_flag=()")
	cmd.NonInheritedFlags().VisitAll(func(flag *pflag.Flag) {
		if _, required := flag.Annotations[BashCompOneRequiredFlag]; !required {
			return
		}
		long := "--" + flag.Name
		if flag.Value.Type() != "bool" {
			long += "="
		}
		fmt.Fprintf(out, "    must_have_one_flag+=(\"%s\")\n", long)
		if flag.Shorthand != "" {
			fmt.Fprintf(out, "    must_have_one_flag+=(\"-%s\")\n", flag.Shorthand)
		}
	})
}
340
// writeRequiredNoun resets must_have_one_noun and fills it with cmd.ValidArgs
// in sorted order.
//
// The arguments are sorted on a copy. Sorting cmd.ValidArgs directly would
// reorder the caller's slice as a hidden side effect of generating
// completions.
func writeRequiredNoun(cmd *Command, out *bytes.Buffer) {
	fmt.Fprintf(out, "    must_have_one_noun=()\n")
	nouns := make([]string, len(cmd.ValidArgs))
	copy(nouns, cmd.ValidArgs)
	sort.Strings(nouns)
	for _, value := range nouns {
		fmt.Fprintf(out, "    must_have_one_noun+=(%q)\n", value)
	}
}
348
// gen writes one bash function per command in the tree rooted at cmd. The
// functions for available, non-help subcommands are written first
// (depth-first), then the function for cmd itself. That function is named
// after cmd's command path with spaces replaced by underscores. It sets
// last_command and fills the commands, flag and required-flag/noun arrays
// for that command.
func gen(cmd *Command, out *bytes.Buffer) {
	for _, sub := range cmd.Commands() {
		if sub.IsAvailableCommand() && sub != cmd.helpCommand {
			gen(sub, out)
		}
	}

	fn := strings.Replace(cmd.CommandPath(), " ", "_", -1)
	fmt.Fprintf(out, "_%s()\n{\n    last_command=%q\n", fn, fn)
	writeCommands(cmd, out)
	writeFlags(cmd, out)
	writeRequiredFlag(cmd, out)
	writeRequiredNoun(cmd, out)
	fmt.Fprint(out, "}\n\n")
}
366
// GenBashCompletion writes a complete bash completion script for cmd and its
// available subcommands to out. The script consists of:
//   - the shared helper functions;
//   - cmd.BashCompletionFunction verbatim, when it is non-empty (the helpers
//     call a function named __custom_func, if one is defined, when nothing
//     else matched);
//   - one function per command;
//   - the __start_<name> entry point and its `complete` registration.
func (cmd *Command) GenBashCompletion(out *bytes.Buffer) {
	preamble(out)
	if custom := cmd.BashCompletionFunction; len(custom) != 0 {
		fmt.Fprintf(out, "%s\n", custom)
	}
	gen(cmd, out)
	postscript(out, cmd.Name())
}
375
// GenBashCompletionFile generates the bash completion script for cmd (see
// GenBashCompletion) and writes it to filename, creating the file or
// truncating it if it already exists. The whole script is generated in memory
// before the file is opened.
//
// Any error from creating, writing or closing the file is returned.
func (cmd *Command) GenBashCompletionFile(filename string) error {
	out := new(bytes.Buffer)
	cmd.GenBashCompletion(out)

	outFile, err := os.Create(filename)
	if err != nil {
		return err
	}
	if _, err := outFile.Write(out.Bytes()); err != nil {
		// The write error is the one worth reporting; Close here only
		// releases the descriptor.
		_ = outFile.Close()
		return err
	}
	// Close explicitly rather than via defer so that an error reported at
	// close time (e.g. a delayed write failure) reaches the caller instead of
	// being silently dropped.
	return outFile.Close()
}
393
394// MarkFlagRequired adds the BashCompOneRequiredFlag annotation to the named flag, if it exists.
395func (cmd *Command) MarkFlagRequired(name string) error {
396	return MarkFlagRequired(cmd.Flags(), name)
397}
398
399// MarkPersistentFlagRequired adds the BashCompOneRequiredFlag annotation to the named persistent flag, if it exists.
400func (cmd *Command) MarkPersistentFlagRequired(name string) error {
401	return MarkFlagRequired(cmd.PersistentFlags(), name)
402}
403
404// MarkFlagRequired adds the BashCompOneRequiredFlag annotation to the named flag in the flag set, if it exists.
405func MarkFlagRequired(flags *pflag.FlagSet, name string) error {
406	return flags.SetAnnotation(name, BashCompOneRequiredFlag, []string{"true"})
407}
408
409// MarkFlagFilename adds the BashCompFilenameExt annotation to the named flag, if it exists.
410// Generated bash autocompletion will select filenames for the flag, limiting to named extensions if provided.
411func (cmd *Command) MarkFlagFilename(name string, extensions ...string) error {
412	return MarkFlagFilename(cmd.Flags(), name, extensions...)
413}
414
415// MarkPersistentFlagFilename adds the BashCompFilenameExt annotation to the named persistent flag, if it exists.
416// Generated bash autocompletion will select filenames for the flag, limiting to named extensions if provided.
417func (cmd *Command) MarkPersistentFlagFilename(name string, extensions ...string) error {
418	return MarkFlagFilename(cmd.PersistentFlags(), name, extensions...)
419}
420
421// MarkFlagFilename adds the BashCompFilenameExt annotation to the named flag in the flag set, if it exists.
422// Generated bash autocompletion will select filenames for the flag, limiting to named extensions if provided.
423func MarkFlagFilename(flags *pflag.FlagSet, name string, extensions ...string) error {
424	return flags.SetAnnotation(name, BashCompFilenameExt, extensions)
425}
426