1/*Package icmd executes binaries and provides convenient assertions for testing the results.
2 */
3package icmd // import "gotest.tools/icmd"
4
5import (
6	"bytes"
7	"fmt"
8	"io"
9	"os/exec"
10	"strings"
11	"sync"
12	"time"
13
14	"gotest.tools/assert"
15	"gotest.tools/assert/cmp"
16)
17
// helperT is satisfied by test handles, such as *testing.T, that can mark
// the calling function as a test helper.
type helperT interface{ Helper() }
21
// None is a sentinel for Expected.Out and Expected.Err: instead of being
// matched as a substring, it requires the corresponding output stream to be
// completely empty when the Result is asserted.
const (
	None = "[NOTHING]"
)
24
// lockedBuffer is an io.Writer backed by a bytes.Buffer whose accesses are
// guarded by a mutex, so Write and String may be called concurrently.
type lockedBuffer struct {
	m   sync.RWMutex
	buf bytes.Buffer
}

// Write appends b to the buffer while holding the write lock.
func (buf *lockedBuffer) Write(b []byte) (int, error) {
	buf.m.Lock()
	defer buf.m.Unlock()
	return buf.buf.Write(b)
}

// String returns a copy of the buffer's current contents, taken under the
// read lock. A nil *lockedBuffer (e.g. on a Result whose buffers were never
// set) reads as the empty string instead of panicking.
func (buf *lockedBuffer) String() string {
	if buf == nil {
		return ""
	}
	buf.m.RLock()
	defer buf.m.RUnlock()
	return buf.buf.String()
}
41
42// Result stores the result of running a command
43type Result struct {
44	Cmd      *exec.Cmd
45	ExitCode int
46	Error    error
47	// Timeout is true if the command was killed because it ran for too long
48	Timeout   bool
49	outBuffer *lockedBuffer
50	errBuffer *lockedBuffer
51}
52
53// Assert compares the Result against the Expected struct, and fails the test if
54// any of the expectations are not met.
55//
56// This function is equivalent to assert.Assert(t, result.Equal(exp)).
57func (r *Result) Assert(t assert.TestingT, exp Expected) *Result {
58	if ht, ok := t.(helperT); ok {
59		ht.Helper()
60	}
61	assert.Assert(t, r.Equal(exp))
62	return r
63}
64
65// Equal compares the result to Expected. If the result doesn't match expected
66// returns a formatted failure message with the command, stdout, stderr, exit code,
67// and any failed expectations.
68func (r *Result) Equal(exp Expected) cmp.Comparison {
69	return func() cmp.Result {
70		return cmp.ResultFromError(r.match(exp))
71	}
72}
73
74// Compare the result to Expected and return an error if they do not match.
75func (r *Result) Compare(exp Expected) error {
76	return r.match(exp)
77}
78
79// nolint: gocyclo
80func (r *Result) match(exp Expected) error {
81	errors := []string{}
82	add := func(format string, args ...interface{}) {
83		errors = append(errors, fmt.Sprintf(format, args...))
84	}
85
86	if exp.ExitCode != r.ExitCode {
87		add("ExitCode was %d expected %d", r.ExitCode, exp.ExitCode)
88	}
89	if exp.Timeout != r.Timeout {
90		if exp.Timeout {
91			add("Expected command to timeout")
92		} else {
93			add("Expected command to finish, but it hit the timeout")
94		}
95	}
96	if !matchOutput(exp.Out, r.Stdout()) {
97		add("Expected stdout to contain %q", exp.Out)
98	}
99	if !matchOutput(exp.Err, r.Stderr()) {
100		add("Expected stderr to contain %q", exp.Err)
101	}
102	switch {
103	// If a non-zero exit code is expected there is going to be an error.
104	// Don't require an error message as well as an exit code because the
105	// error message is going to be "exit status <code> which is not useful
106	case exp.Error == "" && exp.ExitCode != 0:
107	case exp.Error == "" && r.Error != nil:
108		add("Expected no error")
109	case exp.Error != "" && r.Error == nil:
110		add("Expected error to contain %q, but there was no error", exp.Error)
111	case exp.Error != "" && !strings.Contains(r.Error.Error(), exp.Error):
112		add("Expected error to contain %q", exp.Error)
113	}
114
115	if len(errors) == 0 {
116		return nil
117	}
118	return fmt.Errorf("%s\nFailures:\n%s", r, strings.Join(errors, "\n"))
119}
120
121func matchOutput(expected string, actual string) bool {
122	switch expected {
123	case None:
124		return actual == ""
125	default:
126		return strings.Contains(actual, expected)
127	}
128}
129
130func (r *Result) String() string {
131	var timeout string
132	if r.Timeout {
133		timeout = " (timeout)"
134	}
135
136	return fmt.Sprintf(`
137Command:  %s
138ExitCode: %d%s
139Error:    %v
140Stdout:   %v
141Stderr:   %v
142`,
143		strings.Join(r.Cmd.Args, " "),
144		r.ExitCode,
145		timeout,
146		r.Error,
147		r.Stdout(),
148		r.Stderr())
149}
150
// Expected is the expected output from a Command. This struct is compared to a
// Result struct by Result.Assert().
type Expected struct {
	// ExitCode is the exit code the command must report.
	ExitCode int
	// Timeout is whether the command is expected to have been killed for
	// running past its timeout.
	Timeout bool
	// Error, when non-empty, is a substring the command's error must contain.
	// When empty, any error is rejected unless ExitCode is non-zero, since a
	// non-zero exit is expected to come with an "exit status" error.
	Error string
	// Out, when non-empty, is a substring stdout must contain. Set it to None
	// to require that stdout is empty.
	Out string
	// Err is the stderr counterpart of Out.
	Err string
}

// Success is the default expected result: exit code 0, no timeout, no error,
// and no constraint on the output.
var Success Expected
164
165// Stdout returns the stdout of the process as a string
166func (r *Result) Stdout() string {
167	return r.outBuffer.String()
168}
169
170// Stderr returns the stderr of the process as a string
171func (r *Result) Stderr() string {
172	return r.errBuffer.String()
173}
174
175// Combined returns the stdout and stderr combined into a single string
176func (r *Result) Combined() string {
177	return r.outBuffer.String() + r.errBuffer.String()
178}
179
180func (r *Result) setExitError(err error) {
181	if err == nil {
182		return
183	}
184	r.Error = err
185	r.ExitCode = processExitCode(err)
186}
187
// Cmd describes a process to run as part of a test suite: its command line
// plus the options used to launch it.
type Cmd struct {
	// Command is the program to run followed by its arguments.
	Command []string
	// Timeout bounds how long RunCmd waits for the process. If exceeded, the
	// process is killed and Result.Timeout is set; zero means no limit.
	Timeout time.Duration
	// Stdin, if non-nil, is used as the process's standard input.
	Stdin io.Reader
	// Stdout, if non-nil, receives a copy of the process's standard output in
	// addition to the copy captured in the Result.
	Stdout io.Writer
	// Dir is the working directory of the process, as in exec.Cmd.Dir.
	Dir string
	// Env is the environment of the process, as in exec.Cmd.Env.
	Env []string
}

// Command builds a Cmd that runs command with args and default options.
func Command(command string, args ...string) Cmd {
	argv := append([]string{command}, args...)
	return Cmd{Command: argv}
}
203
204// RunCmd runs a command and returns a Result
205func RunCmd(cmd Cmd, cmdOperators ...CmdOp) *Result {
206	for _, op := range cmdOperators {
207		op(&cmd)
208	}
209	result := StartCmd(cmd)
210	if result.Error != nil {
211		return result
212	}
213	return WaitOnCmd(cmd.Timeout, result)
214}
215
216// RunCommand runs a command with default options, and returns a result
217func RunCommand(command string, args ...string) *Result {
218	return RunCmd(Command(command, args...))
219}
220
221// StartCmd starts a command, but doesn't wait for it to finish
222func StartCmd(cmd Cmd) *Result {
223	result := buildCmd(cmd)
224	if result.Error != nil {
225		return result
226	}
227	result.setExitError(result.Cmd.Start())
228	return result
229}
230
231// TODO: support exec.CommandContext
232func buildCmd(cmd Cmd) *Result {
233	var execCmd *exec.Cmd
234	switch len(cmd.Command) {
235	case 1:
236		execCmd = exec.Command(cmd.Command[0])
237	default:
238		execCmd = exec.Command(cmd.Command[0], cmd.Command[1:]...)
239	}
240	outBuffer := new(lockedBuffer)
241	errBuffer := new(lockedBuffer)
242
243	execCmd.Stdin = cmd.Stdin
244	execCmd.Dir = cmd.Dir
245	execCmd.Env = cmd.Env
246	if cmd.Stdout != nil {
247		execCmd.Stdout = io.MultiWriter(outBuffer, cmd.Stdout)
248	} else {
249		execCmd.Stdout = outBuffer
250	}
251	execCmd.Stderr = errBuffer
252	return &Result{
253		Cmd:       execCmd,
254		outBuffer: outBuffer,
255		errBuffer: errBuffer,
256	}
257}
258
259// WaitOnCmd waits for a command to complete. If timeout is non-nil then
260// only wait until the timeout.
261func WaitOnCmd(timeout time.Duration, result *Result) *Result {
262	if timeout == time.Duration(0) {
263		result.setExitError(result.Cmd.Wait())
264		return result
265	}
266
267	done := make(chan error, 1)
268	// Wait for command to exit in a goroutine
269	go func() {
270		done <- result.Cmd.Wait()
271	}()
272
273	select {
274	case <-time.After(timeout):
275		killErr := result.Cmd.Process.Kill()
276		if killErr != nil {
277			fmt.Printf("failed to kill (pid=%d): %v\n", result.Cmd.Process.Pid, killErr)
278		}
279		result.Timeout = true
280	case err := <-done:
281		result.setExitError(err)
282	}
283	return result
284}
285