//go:build windows
// +build windows
2
3package cmd
4
5import (
6	"bytes"
7	"context"
8	"errors"
9	"io"
10	"os"
11	"os/exec"
12	"strings"
13	"syscall"
14	"testing"
15	"time"
16
17	"github.com/Microsoft/hcsshim/internal/cow"
18	hcsschema "github.com/Microsoft/hcsshim/internal/schema2"
19)
20
// localProcessHost is a test-only process host that launches processes
// directly on the local machine via os.StartProcess.
type localProcessHost struct{}
23
// localProcess is the process handle returned by
// localProcessHost.CreateProcess.
type localProcess struct {
	// p is the started OS process; it stays nil if os.StartProcess fails.
	p *os.Process

	// state is recorded by the background wait goroutine once the
	// process exits.
	state *os.ProcessState

	// ch is closed after the process has exited and state has been set.
	ch chan struct{}

	// Parent-side pipe ends; each is nil when that pipe was not requested.
	stdin, stdout, stderr *os.File
}
30
31func (h *localProcessHost) OS() string {
32	return "windows"
33}
34
35func (h *localProcessHost) IsOCI() bool {
36	return false
37}
38
39func (h *localProcessHost) CreateProcess(ctx context.Context, cfg interface{}) (_ cow.Process, err error) {
40	params := cfg.(*hcsschema.ProcessParameters)
41	lp := &localProcess{ch: make(chan struct{})}
42	defer func() {
43		if err != nil {
44			lp.Close()
45		}
46	}()
47	var stdin, stdout, stderr *os.File
48	if params.CreateStdInPipe {
49		stdin, lp.stdin, err = os.Pipe()
50		if err != nil {
51			return nil, err
52		}
53		defer stdin.Close()
54	}
55	if params.CreateStdOutPipe {
56		lp.stdout, stdout, err = os.Pipe()
57		if err != nil {
58			return nil, err
59		}
60		defer stdout.Close()
61	}
62	if params.CreateStdErrPipe {
63		lp.stderr, stderr, err = os.Pipe()
64		if err != nil {
65			return nil, err
66		}
67		defer stderr.Close()
68	}
69	path := strings.Split(params.CommandLine, " ")[0] // should be fixed for non-test use...
70	if ppath, err := exec.LookPath(path); err == nil {
71		path = ppath
72	}
73	lp.p, err = os.StartProcess(path, nil, &os.ProcAttr{
74		Files: []*os.File{stdin, stdout, stderr},
75		Sys: &syscall.SysProcAttr{
76			CmdLine: params.CommandLine,
77		},
78	})
79	if err != nil {
80		return nil, err
81	}
82	go func() {
83		lp.state, _ = lp.p.Wait()
84		close(lp.ch)
85	}()
86	return lp, nil
87}
88
89func (p *localProcess) Close() error {
90	if p.p != nil {
91		p.p.Release()
92	}
93	if p.stdin != nil {
94		p.stdin.Close()
95	}
96	if p.stdout != nil {
97		p.stdout.Close()
98	}
99	if p.stderr != nil {
100		p.stderr.Close()
101	}
102	return nil
103}
104
105func (p *localProcess) CloseStdin(ctx context.Context) error {
106	return p.stdin.Close()
107}
108
109func (p *localProcess) ExitCode() (int, error) {
110	select {
111	case <-p.ch:
112		return p.state.ExitCode(), nil
113	default:
114		return -1, errors.New("not exited")
115	}
116}
117
118func (p *localProcess) Kill(ctx context.Context) (bool, error) {
119	return true, p.p.Kill()
120}
121
122func (p *localProcess) Signal(ctx context.Context, _ interface{}) (bool, error) {
123	return p.Kill(ctx)
124}
125
126func (p *localProcess) Pid() int {
127	return p.p.Pid
128}
129
130func (p *localProcess) ResizeConsole(ctx context.Context, x, y uint16) error {
131	return errors.New("not supported")
132}
133
134func (p *localProcess) Stdio() (io.Writer, io.Reader, io.Reader) {
135	return p.stdin, p.stdout, p.stderr
136}
137
138func (p *localProcess) Wait() error {
139	<-p.ch
140	return nil
141}
142
143func TestCmdExitCode(t *testing.T) {
144	cmd := Command(&localProcessHost{}, "cmd", "/c", "exit", "/b", "64")
145	err := cmd.Run()
146	if e, ok := err.(*ExitError); !ok || e.ExitCode() != 64 {
147		t.Fatal("expected exit code 64, got ", err)
148	}
149}
150
151func TestCmdOutput(t *testing.T) {
152	cmd := Command(&localProcessHost{}, "cmd", "/c", "echo", "hello")
153	output, err := cmd.Output()
154	if err != nil {
155		t.Fatal(err)
156	}
157	if string(output) != "hello\r\n" {
158		t.Fatalf("got %q", string(output))
159	}
160}
161
162func TestCmdContext(t *testing.T) {
163	ctx, cancel := context.WithTimeout(context.Background(), 250*time.Millisecond)
164	defer cancel()
165	cmd := CommandContext(ctx, &localProcessHost{}, "cmd", "/c", "pause")
166	r, w := io.Pipe()
167	cmd.Stdin = r
168	err := cmd.Start()
169	if err != nil {
170		t.Fatal(err)
171	}
172	cmd.Process.Wait()
173	w.Close()
174	err = cmd.Wait()
175	if e, ok := err.(*ExitError); !ok || e.ExitCode() != 1 || ctx.Err() == nil {
176		t.Fatal(err)
177	}
178}
179
180func TestCmdStdin(t *testing.T) {
181	cmd := Command(&localProcessHost{}, "findstr", "x*")
182	cmd.Stdin = bytes.NewBufferString("testing 1 2 3")
183	out, err := cmd.Output()
184	if err != nil {
185		t.Fatal(err)
186	}
187	if string(out) != "testing 1 2 3\r\n" {
188		t.Fatalf("got %q", string(out))
189	}
190}
191
192func TestCmdStdinBlocked(t *testing.T) {
193	cmd := Command(&localProcessHost{}, "cmd", "/c", "pause")
194	r, w := io.Pipe()
195	defer r.Close()
196	go func() {
197		b := []byte{'\n'}
198		w.Write(b)
199	}()
200	cmd.Stdin = r
201	_, err := cmd.Output()
202	if err != nil {
203		t.Fatal(err)
204	}
205}
206
207type stuckIoProcessHost struct {
208	cow.ProcessHost
209}
210
211type stuckIoProcess struct {
212	cow.Process
213	stdin, pstdout, pstderr *io.PipeWriter
214	pstdin, stdout, stderr  *io.PipeReader
215}
216
217func (h *stuckIoProcessHost) CreateProcess(ctx context.Context, cfg interface{}) (cow.Process, error) {
218	p, err := h.ProcessHost.CreateProcess(ctx, cfg)
219	if err != nil {
220		return nil, err
221	}
222	sp := &stuckIoProcess{
223		Process: p,
224	}
225	sp.pstdin, sp.stdin = io.Pipe()
226	sp.stdout, sp.pstdout = io.Pipe()
227	sp.stderr, sp.pstderr = io.Pipe()
228	return sp, nil
229}
230
231func (p *stuckIoProcess) Stdio() (io.Writer, io.Reader, io.Reader) {
232	return p.stdin, p.stdout, p.stderr
233}
234
235func (p *stuckIoProcess) Close() error {
236	p.stdin.Close()
237	p.stdout.Close()
238	p.stderr.Close()
239	return p.Process.Close()
240}
241
242func TestCmdStuckIo(t *testing.T) {
243	cmd := Command(&stuckIoProcessHost{&localProcessHost{}}, "cmd", "/c", "echo", "hello")
244	cmd.CopyAfterExitTimeout = time.Millisecond * 200
245	_, err := cmd.Output()
246	if err != io.ErrClosedPipe {
247		t.Fatal(err)
248	}
249}
250