1package cobra
2
3import (
4	"encoding/json"
5	"fmt"
6	"io"
7	"os"
8	"sort"
9	"strings"
10	"text/template"
11
12	"github.com/spf13/pflag"
13)
14
// Annotation keys and hint types used by the zsh completion generator:
//
//   - zshCompArgumentAnnotation: the Command.Annotations key under which the
//     JSON-encoded positional-argument hints (zshCompArgsAnnotation) are stored.
//   - zshCompArgumentFilenameComp: zshCompArgHint.Tipe value for an argument
//     completed by file selection (its options are rendered as -g globs).
//   - zshCompArgumentWordComp: zshCompArgHint.Tipe value for an argument
//     completed from a fixed list of words.
//   - zshCompDirname: a pflag annotation key; its first value is rendered as
//     the -g argument of _files in the flag's completion spec.
const (
	zshCompArgumentAnnotation   = "cobra_annotations_zsh_completion_argument_annotation"
	zshCompArgumentFilenameComp = "cobra_annotations_zsh_completion_argument_file_completion"
	zshCompArgumentWordComp     = "cobra_annotations_zsh_completion_argument_word_completion"
	zshCompDirname              = "cobra_annotations_zsh_dirname"
)
21
22var (
23	zshCompFuncMap = template.FuncMap{
24		"genZshFuncName":              zshCompGenFuncName,
25		"extractFlags":                zshCompExtractFlag,
26		"genFlagEntryForZshArguments": zshCompGenFlagEntryForArguments,
27		"extractArgsCompletions":      zshCompExtractArgumentCompletionHintsForRendering,
28	}
29	zshCompletionText = `
30{{/* should accept Command (that contains subcommands) as parameter */}}
31{{define "argumentsC" -}}
32{{ $cmdPath := genZshFuncName .}}
33function {{$cmdPath}} {
34  local -a commands
35
36  _arguments -C \{{- range extractFlags .}}
37    {{genFlagEntryForZshArguments .}} \{{- end}}
38    "1: :->cmnds" \
39    "*::arg:->args"
40
41  case $state in
42  cmnds)
43    commands=({{range .Commands}}{{if not .Hidden}}
44      "{{.Name}}:{{.Short}}"{{end}}{{end}}
45    )
46    _describe "command" commands
47    ;;
48  esac
49
50  case "$words[1]" in {{- range .Commands}}{{if not .Hidden}}
51  {{.Name}})
52    {{$cmdPath}}_{{.Name}}
53    ;;{{end}}{{end}}
54  esac
55}
56{{range .Commands}}{{if not .Hidden}}
57{{template "selectCmdTemplate" .}}
58{{- end}}{{end}}
59{{- end}}
60
61{{/* should accept Command without subcommands as parameter */}}
62{{define "arguments" -}}
63function {{genZshFuncName .}} {
64{{"  _arguments"}}{{range extractFlags .}} \
65    {{genFlagEntryForZshArguments . -}}
66{{end}}{{range extractArgsCompletions .}} \
67    {{.}}{{end}}
68}
69{{end}}
70
71{{/* dispatcher for commands with or without subcommands */}}
72{{define "selectCmdTemplate" -}}
73{{if .Hidden}}{{/* ignore hidden*/}}{{else -}}
74{{if .Commands}}{{template "argumentsC" .}}{{else}}{{template "arguments" .}}{{end}}
75{{- end}}
76{{- end}}
77
78{{/* template entry point */}}
79{{define "Main" -}}
80#compdef _{{.Name}} {{.Name}}
81
82{{template "selectCmdTemplate" .}}
83{{end}}
84`
85)
86
// zshCompArgsAnnotation is used to encode/decode zsh completion for
// arguments to/from Command.Annotations. Keys are positional argument
// indexes, starting at 1 for the first argument. The map is stored as JSON
// under the zshCompArgumentAnnotation key.
type zshCompArgsAnnotation map[int]zshCompArgHint

// zshCompArgHint describes how one positional argument is completed.
type zshCompArgHint struct {
	// Indicates the type of the completion to use. One of:
	// zshCompArgumentFilenameComp or zshCompArgumentWordComp
	Tipe string `json:"type"`

	// A value for the type above (globs for file completion or words)
	Options []string `json:"options"`
}
99
100// GenZshCompletionFile generates zsh completion file.
101func (c *Command) GenZshCompletionFile(filename string) error {
102	outFile, err := os.Create(filename)
103	if err != nil {
104		return err
105	}
106	defer outFile.Close()
107
108	return c.GenZshCompletion(outFile)
109}
110
111// GenZshCompletion generates a zsh completion file and writes to the passed
112// writer. The completion always run on the root command regardless of the
113// command it was called from.
114func (c *Command) GenZshCompletion(w io.Writer) error {
115	tmpl, err := template.New("Main").Funcs(zshCompFuncMap).Parse(zshCompletionText)
116	if err != nil {
117		return fmt.Errorf("error creating zsh completion template: %v", err)
118	}
119	return tmpl.Execute(w, c.Root())
120}
121
122// MarkZshCompPositionalArgumentFile marks the specified argument (first
123// argument is 1) as completed by file selection. patterns (e.g. "*.txt") are
124// optional - if not provided the completion will search for all files.
125func (c *Command) MarkZshCompPositionalArgumentFile(argPosition int, patterns ...string) error {
126	if argPosition < 1 {
127		return fmt.Errorf("Invalid argument position (%d)", argPosition)
128	}
129	annotation, err := c.zshCompGetArgsAnnotations()
130	if err != nil {
131		return err
132	}
133	if c.zshcompArgsAnnotationnIsDuplicatePosition(annotation, argPosition) {
134		return fmt.Errorf("Duplicate annotation for positional argument at index %d", argPosition)
135	}
136	annotation[argPosition] = zshCompArgHint{
137		Tipe:    zshCompArgumentFilenameComp,
138		Options: patterns,
139	}
140	return c.zshCompSetArgsAnnotations(annotation)
141}
142
143// MarkZshCompPositionalArgumentWords marks the specified positional argument
144// (first argument is 1) as completed by the provided words. At east one word
145// must be provided, spaces within words will be offered completion with
146// "word\ word".
147func (c *Command) MarkZshCompPositionalArgumentWords(argPosition int, words ...string) error {
148	if argPosition < 1 {
149		return fmt.Errorf("Invalid argument position (%d)", argPosition)
150	}
151	if len(words) == 0 {
152		return fmt.Errorf("Trying to set empty word list for positional argument %d", argPosition)
153	}
154	annotation, err := c.zshCompGetArgsAnnotations()
155	if err != nil {
156		return err
157	}
158	if c.zshcompArgsAnnotationnIsDuplicatePosition(annotation, argPosition) {
159		return fmt.Errorf("Duplicate annotation for positional argument at index %d", argPosition)
160	}
161	annotation[argPosition] = zshCompArgHint{
162		Tipe:    zshCompArgumentWordComp,
163		Options: words,
164	}
165	return c.zshCompSetArgsAnnotations(annotation)
166}
167
168func zshCompExtractArgumentCompletionHintsForRendering(c *Command) ([]string, error) {
169	var result []string
170	annotation, err := c.zshCompGetArgsAnnotations()
171	if err != nil {
172		return nil, err
173	}
174	for k, v := range annotation {
175		s, err := zshCompRenderZshCompArgHint(k, v)
176		if err != nil {
177			return nil, err
178		}
179		result = append(result, s)
180	}
181	if len(c.ValidArgs) > 0 {
182		if _, positionOneExists := annotation[1]; !positionOneExists {
183			s, err := zshCompRenderZshCompArgHint(1, zshCompArgHint{
184				Tipe:    zshCompArgumentWordComp,
185				Options: c.ValidArgs,
186			})
187			if err != nil {
188				return nil, err
189			}
190			result = append(result, s)
191		}
192	}
193	sort.Strings(result)
194	return result, nil
195}
196
197func zshCompRenderZshCompArgHint(i int, z zshCompArgHint) (string, error) {
198	switch t := z.Tipe; t {
199	case zshCompArgumentFilenameComp:
200		var globs []string
201		for _, g := range z.Options {
202			globs = append(globs, fmt.Sprintf(`-g "%s"`, g))
203		}
204		return fmt.Sprintf(`'%d: :_files %s'`, i, strings.Join(globs, " ")), nil
205	case zshCompArgumentWordComp:
206		var words []string
207		for _, w := range z.Options {
208			words = append(words, fmt.Sprintf("%q", w))
209		}
210		return fmt.Sprintf(`'%d: :(%s)'`, i, strings.Join(words, " ")), nil
211	default:
212		return "", fmt.Errorf("Invalid zsh argument completion annotation: %s", t)
213	}
214}
215
216func (c *Command) zshcompArgsAnnotationnIsDuplicatePosition(annotation zshCompArgsAnnotation, position int) bool {
217	_, dup := annotation[position]
218	return dup
219}
220
221func (c *Command) zshCompGetArgsAnnotations() (zshCompArgsAnnotation, error) {
222	annotation := make(zshCompArgsAnnotation)
223	annotationString, ok := c.Annotations[zshCompArgumentAnnotation]
224	if !ok {
225		return annotation, nil
226	}
227	err := json.Unmarshal([]byte(annotationString), &annotation)
228	if err != nil {
229		return annotation, fmt.Errorf("Error unmarshaling zsh argument annotation: %v", err)
230	}
231	return annotation, nil
232}
233
234func (c *Command) zshCompSetArgsAnnotations(annotation zshCompArgsAnnotation) error {
235	jsn, err := json.Marshal(annotation)
236	if err != nil {
237		return fmt.Errorf("Error marshaling zsh argument annotation: %v", err)
238	}
239	if c.Annotations == nil {
240		c.Annotations = make(map[string]string)
241	}
242	c.Annotations[zshCompArgumentAnnotation] = string(jsn)
243	return nil
244}
245
246func zshCompGenFuncName(c *Command) string {
247	if c.HasParent() {
248		return zshCompGenFuncName(c.Parent()) + "_" + c.Name()
249	}
250	return "_" + c.Name()
251}
252
253func zshCompExtractFlag(c *Command) []*pflag.Flag {
254	var flags []*pflag.Flag
255	c.LocalFlags().VisitAll(func(f *pflag.Flag) {
256		if !f.Hidden {
257			flags = append(flags, f)
258		}
259	})
260	c.InheritedFlags().VisitAll(func(f *pflag.Flag) {
261		if !f.Hidden {
262			flags = append(flags, f)
263		}
264	})
265	return flags
266}
267
268// zshCompGenFlagEntryForArguments returns an entry that matches _arguments
269// zsh-completion parameters. It's too complicated to generate in a template.
270func zshCompGenFlagEntryForArguments(f *pflag.Flag) string {
271	if f.Name == "" || f.Shorthand == "" {
272		return zshCompGenFlagEntryForSingleOptionFlag(f)
273	}
274	return zshCompGenFlagEntryForMultiOptionFlag(f)
275}
276
277func zshCompGenFlagEntryForSingleOptionFlag(f *pflag.Flag) string {
278	var option, multiMark, extras string
279
280	if zshCompFlagCouldBeSpecifiedMoreThenOnce(f) {
281		multiMark = "*"
282	}
283
284	option = "--" + f.Name
285	if option == "--" {
286		option = "-" + f.Shorthand
287	}
288	extras = zshCompGenFlagEntryExtras(f)
289
290	return fmt.Sprintf(`'%s%s[%s]%s'`, multiMark, option, zshCompQuoteFlagDescription(f.Usage), extras)
291}
292
293func zshCompGenFlagEntryForMultiOptionFlag(f *pflag.Flag) string {
294	var options, parenMultiMark, curlyMultiMark, extras string
295
296	if zshCompFlagCouldBeSpecifiedMoreThenOnce(f) {
297		parenMultiMark = "*"
298		curlyMultiMark = "\\*"
299	}
300
301	options = fmt.Sprintf(`'(%s-%s %s--%s)'{%s-%s,%s--%s}`,
302		parenMultiMark, f.Shorthand, parenMultiMark, f.Name, curlyMultiMark, f.Shorthand, curlyMultiMark, f.Name)
303	extras = zshCompGenFlagEntryExtras(f)
304
305	return fmt.Sprintf(`%s'[%s]%s'`, options, zshCompQuoteFlagDescription(f.Usage), extras)
306}
307
308func zshCompGenFlagEntryExtras(f *pflag.Flag) string {
309	if f.NoOptDefVal != "" {
310		return ""
311	}
312
313	extras := ":" // allow options for flag (even without assistance)
314	for key, values := range f.Annotations {
315		switch key {
316		case zshCompDirname:
317			extras = fmt.Sprintf(":filename:_files -g %q", values[0])
318		case BashCompFilenameExt:
319			extras = ":filename:_files"
320			for _, pattern := range values {
321				extras = extras + fmt.Sprintf(` -g "%s"`, pattern)
322			}
323		}
324	}
325
326	return extras
327}
328
329func zshCompFlagCouldBeSpecifiedMoreThenOnce(f *pflag.Flag) bool {
330	return strings.Contains(f.Value.Type(), "Slice") ||
331		strings.Contains(f.Value.Type(), "Array")
332}
333
// zshCompQuoteFlagDescription makes s safe to embed in a single-quoted shell
// word. Each ' is turned into '\'' (close quote, escaped quote, reopen
// quote).
func zshCompQuoteFlagDescription(s string) string {
	pieces := strings.Split(s, "'")
	return strings.Join(pieces, `'\''`)
}
337