1package cobra
2
import (
	"bytes"
	"fmt"
	"io"
	"os"
	"sort"
	"strings"
)
10
11// GenZshCompletionFile generates zsh completion file.
12func (c *Command) GenZshCompletionFile(filename string) error {
13	outFile, err := os.Create(filename)
14	if err != nil {
15		return err
16	}
17	defer outFile.Close()
18
19	return c.GenZshCompletion(outFile)
20}
21
22// GenZshCompletion generates a zsh completion file and writes to the passed writer.
23func (c *Command) GenZshCompletion(w io.Writer) error {
24	buf := new(bytes.Buffer)
25
26	writeHeader(buf, c)
27	maxDepth := maxDepth(c)
28	writeLevelMapping(buf, maxDepth)
29	writeLevelCases(buf, maxDepth, c)
30
31	_, err := buf.WriteTo(w)
32	return err
33}
34
35func writeHeader(w io.Writer, cmd *Command) {
36	fmt.Fprintf(w, "#compdef %s\n\n", cmd.Name())
37}
38
39func maxDepth(c *Command) int {
40	if len(c.Commands()) == 0 {
41		return 0
42	}
43	maxDepthSub := 0
44	for _, s := range c.Commands() {
45		subDepth := maxDepth(s)
46		if subDepth > maxDepthSub {
47			maxDepthSub = subDepth
48		}
49	}
50	return 1 + maxDepthSub
51}
52
// writeLevelMapping writes the top-level "_arguments" call to w. Positions
// 1 through numLevels each map to a state named "level<N>", and position
// numLevels+1 is completed with _files. Every line except the last ends in a
// backslash continuation.
func writeLevelMapping(w io.Writer, numLevels int) {
	fmt.Fprint(w, "_arguments \\\n")
	for level := 1; level <= numLevels; level++ {
		fmt.Fprintf(w, "  '%d: :->level%d' \\\n", level, level)
	}
	fmt.Fprintf(w, "  '%d: :%s'\n", numLevels+1, "_files")
}
62
63func writeLevelCases(w io.Writer, maxDepth int, root *Command) {
64	fmt.Fprintln(w, "case $state in")
65	defer fmt.Fprintln(w, "esac")
66
67	for i := 1; i <= maxDepth; i++ {
68		fmt.Fprintf(w, "  level%d)\n", i)
69		writeLevel(w, root, i)
70		fmt.Fprintln(w, "  ;;")
71	}
72	fmt.Fprintln(w, "  *)")
73	fmt.Fprintln(w, "    _arguments '*: :_files'")
74	fmt.Fprintln(w, "  ;;")
75}
76
77func writeLevel(w io.Writer, root *Command, i int) {
78	fmt.Fprintf(w, "    case $words[%d] in\n", i)
79	defer fmt.Fprintln(w, "    esac")
80
81	commands := filterByLevel(root, i)
82	byParent := groupByParent(commands)
83
84	for p, c := range byParent {
85		names := names(c)
86		fmt.Fprintf(w, "      %s)\n", p)
87		fmt.Fprintf(w, "        _arguments '%d: :(%s)'\n", i, strings.Join(names, " "))
88		fmt.Fprintln(w, "      ;;")
89	}
90	fmt.Fprintln(w, "      *)")
91	fmt.Fprintln(w, "        _arguments '*: :_files'")
92	fmt.Fprintln(w, "      ;;")
93
94}
95
96func filterByLevel(c *Command, l int) []*Command {
97	cs := make([]*Command, 0)
98	if l == 0 {
99		cs = append(cs, c)
100		return cs
101	}
102	for _, s := range c.Commands() {
103		cs = append(cs, filterByLevel(s, l-1)...)
104	}
105	return cs
106}
107
108func groupByParent(commands []*Command) map[string][]*Command {
109	m := make(map[string][]*Command)
110	for _, c := range commands {
111		parent := c.Parent()
112		if parent == nil {
113			continue
114		}
115		m[parent.Name()] = append(m[parent.Name()], c)
116	}
117	return m
118}
119
120func names(commands []*Command) []string {
121	ns := make([]string, len(commands))
122	for i, c := range commands {
123		ns[i] = c.Name()
124	}
125	return ns
126}
127