1/*
2Copyright 2016 The Kubernetes Authors.
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8    http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15*/
16
17package remotecommand
18
19import (
20	"errors"
21	"io"
22	"net/http"
23	"strings"
24	"testing"
25	"time"
26
27	"k8s.io/api/core/v1"
28	"k8s.io/apimachinery/pkg/util/httpstream"
29	"k8s.io/apimachinery/pkg/util/wait"
30)
31
// fakeReader is a test io.Reader that never produces any data: every call to
// Read reports zero bytes read together with the configured err, which may be
// nil.
type fakeReader struct {
	err error // handed back unchanged by every Read call
}

// Read ignores the supplied buffer and returns (0, r.err).
func (r *fakeReader) Read(_ []byte) (int, error) {
	configured := r.err
	return 0, configured
}
37
// fakeWriter is a test io.Writer that silently discards everything written
// to it.
type fakeWriter struct{}

// Write drops p and reports it as fully written. Returning len(p) (rather
// than 0) honours the io.Writer contract, under which n < len(p) must be
// accompanied by a non-nil error; otherwise helpers such as io.Copy would
// fail with io.ErrShortWrite.
func (*fakeWriter) Write(p []byte) (int, error) { return len(p), nil }
41
42type fakeStreamCreator struct {
43	created map[string]bool
44	errors  map[string]error
45}
46
47var _ streamCreator = &fakeStreamCreator{}
48
49func (f *fakeStreamCreator) CreateStream(headers http.Header) (httpstream.Stream, error) {
50	streamType := headers.Get(v1.StreamType)
51	f.created[streamType] = true
52	return nil, f.errors[streamType]
53}
54
55func TestV2CreateStreams(t *testing.T) {
56	tests := []struct {
57		name        string
58		stdin       bool
59		stdinError  error
60		stdout      bool
61		stdoutError error
62		stderr      bool
63		stderrError error
64		errorError  error
65		tty         bool
66		expectError bool
67	}{
68		{
69			name:        "stdin error",
70			stdin:       true,
71			stdinError:  errors.New("stdin error"),
72			expectError: true,
73		},
74		{
75			name:        "stdout error",
76			stdout:      true,
77			stdoutError: errors.New("stdout error"),
78			expectError: true,
79		},
80		{
81			name:        "stderr error",
82			stderr:      true,
83			stderrError: errors.New("stderr error"),
84			expectError: true,
85		},
86		{
87			name:        "error stream error",
88			stdin:       true,
89			stdout:      true,
90			stderr:      true,
91			errorError:  errors.New("error stream error"),
92			expectError: true,
93		},
94		{
95			name:        "no errors",
96			stdin:       true,
97			stdout:      true,
98			stderr:      true,
99			expectError: false,
100		},
101		{
102			name:        "no errors, stderr & tty set, don't expect stderr",
103			stdin:       true,
104			stdout:      true,
105			stderr:      true,
106			tty:         true,
107			expectError: false,
108		},
109	}
110	for _, test := range tests {
111		conn := &fakeStreamCreator{
112			created: make(map[string]bool),
113			errors: map[string]error{
114				v1.StreamTypeStdin:  test.stdinError,
115				v1.StreamTypeStdout: test.stdoutError,
116				v1.StreamTypeStderr: test.stderrError,
117				v1.StreamTypeError:  test.errorError,
118			},
119		}
120
121		opts := StreamOptions{Tty: test.tty}
122		if test.stdin {
123			opts.Stdin = &fakeReader{}
124		}
125		if test.stdout {
126			opts.Stdout = &fakeWriter{}
127		}
128		if test.stderr {
129			opts.Stderr = &fakeWriter{}
130		}
131
132		h := newStreamProtocolV2(opts).(*streamProtocolV2)
133		err := h.createStreams(conn)
134
135		if test.expectError {
136			if err == nil {
137				t.Errorf("%s: expected error", test.name)
138				continue
139			}
140			if e, a := test.stdinError, err; test.stdinError != nil && e != a {
141				t.Errorf("%s: expected %v, got %v", test.name, e, a)
142			}
143			if e, a := test.stdoutError, err; test.stdoutError != nil && e != a {
144				t.Errorf("%s: expected %v, got %v", test.name, e, a)
145			}
146			if e, a := test.stderrError, err; test.stderrError != nil && e != a {
147				t.Errorf("%s: expected %v, got %v", test.name, e, a)
148			}
149			if e, a := test.errorError, err; test.errorError != nil && e != a {
150				t.Errorf("%s: expected %v, got %v", test.name, e, a)
151			}
152			continue
153		}
154
155		if !test.expectError && err != nil {
156			t.Errorf("%s: unexpected error: %v", test.name, err)
157			continue
158		}
159
160		if test.stdin && !conn.created[v1.StreamTypeStdin] {
161			t.Errorf("%s: expected stdin stream", test.name)
162		}
163		if test.stdout && !conn.created[v1.StreamTypeStdout] {
164			t.Errorf("%s: expected stdout stream", test.name)
165		}
166		if test.stderr {
167			if test.tty && conn.created[v1.StreamTypeStderr] {
168				t.Errorf("%s: unexpected stderr stream because tty is set", test.name)
169			} else if !test.tty && !conn.created[v1.StreamTypeStderr] {
170				t.Errorf("%s: expected stderr stream", test.name)
171			}
172		}
173		if !conn.created[v1.StreamTypeError] {
174			t.Errorf("%s: expected error stream", test.name)
175		}
176
177	}
178}
179
180func TestV2ErrorStreamReading(t *testing.T) {
181	tests := []struct {
182		name          string
183		stream        io.Reader
184		expectedError error
185	}{
186		{
187			name:          "error reading from stream",
188			stream:        &fakeReader{errors.New("foo")},
189			expectedError: errors.New("error reading from error stream: foo"),
190		},
191		{
192			name:          "stream returns an error",
193			stream:        strings.NewReader("some error"),
194			expectedError: errors.New("error executing remote command: some error"),
195		},
196	}
197
198	for _, test := range tests {
199		h := newStreamProtocolV2(StreamOptions{}).(*streamProtocolV2)
200		h.errorStream = test.stream
201
202		ch := watchErrorStream(h.errorStream, &errorDecoderV2{})
203		if ch == nil {
204			t.Fatalf("%s: unexpected nil channel", test.name)
205		}
206
207		var err error
208		select {
209		case err = <-ch:
210		case <-time.After(wait.ForeverTestTimeout):
211			t.Fatalf("%s: timed out", test.name)
212		}
213
214		if test.expectedError != nil {
215			if err == nil {
216				t.Errorf("%s: expected an error", test.name)
217			} else if e, a := test.expectedError, err; e.Error() != a.Error() {
218				t.Errorf("%s: expected %q, got %q", test.name, e, a)
219			}
220			continue
221		}
222
223		if test.expectedError == nil && err != nil {
224			t.Errorf("%s: unexpected error: %v", test.name, err)
225			continue
226		}
227	}
228}
229