1// PowerShell completions are based on the amazing work from clap:
2// https://github.com/clap-rs/clap/blob/3294d18efe5f264d12c9035f404c7d189d4824e1/src/completions/powershell.rs
3//
// The generated scripts require PowerShell v5.0+ (which comes with Windows 10, but
// can be downloaded separately for Windows 7 or 8.1).
6
7package cobra
8
9import (
10	"bytes"
11	"fmt"
12	"io"
13	"os"
14	"strings"
15
16	"github.com/spf13/pflag"
17)
18
// powerShellCompletionTemplate is the PowerShell script emitted by
// GenPowerShellCompletion. Its three %s verbs are filled, in order, with the
// command name registered via Register-ArgumentCompleter -CommandName, the
// root command name that starts the ';'-joined command path, and the switch
// cases produced by generatePowerShellSubcommandCases.
//
// At completion time the script walks the typed command elements after the
// first one, collecting bare-word values until it meets a non-bare-word
// element or one starting with '-'. It joins the collected path with ';' to
// select a switch case. The CompletionResult entries of that case are then
// filtered by prefix against $wordToComplete and sorted by ListItemText.
//
// NOTE(review): $wordToComplete is used inside a -like pattern, so wildcard
// characters it contains (*, ?, [) are interpreted as wildcards rather than
// matched literally.
var powerShellCompletionTemplate = `using namespace System.Management.Automation
using namespace System.Management.Automation.Language
Register-ArgumentCompleter -Native -CommandName '%s' -ScriptBlock {
    param($wordToComplete, $commandAst, $cursorPosition)
    $commandElements = $commandAst.CommandElements
    $command = @(
        '%s'
        for ($i = 1; $i -lt $commandElements.Count; $i++) {
            $element = $commandElements[$i]
            if ($element -isnot [StringConstantExpressionAst] -or
                $element.StringConstantType -ne [StringConstantType]::BareWord -or
                $element.Value.StartsWith('-')) {
                break
            }
            $element.Value
        }
    ) -join ';'
    $completions = @(switch ($command) {%s
    })
    $completions.Where{ $_.CompletionText -like "$wordToComplete*" } |
        Sort-Object -Property ListItemText
}`
41
42func generatePowerShellSubcommandCases(out io.Writer, cmd *Command, previousCommandName string) {
43	var cmdName string
44	if previousCommandName == "" {
45		cmdName = cmd.Name()
46	} else {
47		cmdName = fmt.Sprintf("%s;%s", previousCommandName, cmd.Name())
48	}
49
50	fmt.Fprintf(out, "\n        '%s' {", cmdName)
51
52	cmd.Flags().VisitAll(func(flag *pflag.Flag) {
53		if nonCompletableFlag(flag) {
54			return
55		}
56		usage := escapeStringForPowerShell(flag.Usage)
57		if len(flag.Shorthand) > 0 {
58			fmt.Fprintf(out, "\n            [CompletionResult]::new('-%s', '%s', [CompletionResultType]::ParameterName, '%s')", flag.Shorthand, flag.Shorthand, usage)
59		}
60		fmt.Fprintf(out, "\n            [CompletionResult]::new('--%s', '%s', [CompletionResultType]::ParameterName, '%s')", flag.Name, flag.Name, usage)
61	})
62
63	for _, subCmd := range cmd.Commands() {
64		usage := escapeStringForPowerShell(subCmd.Short)
65		fmt.Fprintf(out, "\n            [CompletionResult]::new('%s', '%s', [CompletionResultType]::ParameterValue, '%s')", subCmd.Name(), subCmd.Name(), usage)
66	}
67
68	fmt.Fprint(out, "\n            break\n        }")
69
70	for _, subCmd := range cmd.Commands() {
71		generatePowerShellSubcommandCases(out, subCmd, cmdName)
72	}
73}
74
// powerShellSingleQuoteEscaper doubles every character that PowerShell accepts
// as a single-quote delimiter: the ASCII apostrophe and the typographic quotes
// U+2018, U+2019, U+201A and U+201B. A strings.Replacer is safe for concurrent
// use, so one shared instance suffices.
var powerShellSingleQuoteEscaper = strings.NewReplacer(
	"'", "''",
	"\u2018", "\u2018\u2018",
	"\u2019", "\u2019\u2019",
	"\u201a", "\u201a\u201a",
	"\u201b", "\u201b\u201b",
)

// escapeStringForPowerShell returns s escaped for embedding between single
// quotes in a generated PowerShell script.
//
// PowerShell treats the typographic single quotes as string delimiters too.
// Escaping only the ASCII apostrophe would let usage text containing, say, a
// right single quotation mark (U+2019) end the literal early and break the
// whole script. Each quote-like character is therefore doubled.
func escapeStringForPowerShell(s string) string {
	return powerShellSingleQuoteEscaper.Replace(s)
}
78
79// GenPowerShellCompletion generates PowerShell completion file and writes to the passed writer.
80func (c *Command) GenPowerShellCompletion(w io.Writer) error {
81	buf := new(bytes.Buffer)
82
83	var subCommandCases bytes.Buffer
84	generatePowerShellSubcommandCases(&subCommandCases, c, "")
85	fmt.Fprintf(buf, powerShellCompletionTemplate, c.Name(), c.Name(), subCommandCases.String())
86
87	_, err := buf.WriteTo(w)
88	return err
89}
90
91// GenPowerShellCompletionFile generates PowerShell completion file.
92func (c *Command) GenPowerShellCompletionFile(filename string) error {
93	outFile, err := os.Create(filename)
94	if err != nil {
95		return err
96	}
97	defer outFile.Close()
98
99	return c.GenPowerShellCompletion(outFile)
100}
101