1/*
2Copyright 2016 The Kubernetes Authors.
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8    http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15*/
16
17package streaming
18
19import (
20	"crypto/tls"
21	"io"
22	"net/http"
23	"net/http/httptest"
24	"net/url"
25	"strconv"
26	"strings"
27	"sync"
28	"testing"
29
30	"github.com/stretchr/testify/assert"
31	"github.com/stretchr/testify/require"
32
33	api "k8s.io/api/core/v1"
34	restclient "k8s.io/client-go/rest"
35	"k8s.io/client-go/tools/remotecommand"
36	"k8s.io/client-go/transport/spdy"
37	runtimeapi "k8s.io/cri-api/pkg/apis/runtime/v1alpha2"
38	kubeletportforward "k8s.io/kubernetes/pkg/kubelet/cri/streaming/portforward"
39)
40
// Identifiers shared by the Get* URL tests and the fakeRuntime checks.
const (
	testContainerID  = "container789"
	testPodSandboxID = "pod0987"
	testAddr         = "localhost:12345"
)
46
47func TestGetExec(t *testing.T) {
48	serv, err := NewServer(Config{
49		Addr: testAddr,
50	}, nil)
51	assert.NoError(t, err)
52
53	tlsServer, err := NewServer(Config{
54		Addr:      testAddr,
55		TLSConfig: &tls.Config{},
56	}, nil)
57	assert.NoError(t, err)
58
59	const pathPrefix = "cri/shim"
60	prefixServer, err := NewServer(Config{
61		Addr: testAddr,
62		BaseURL: &url.URL{
63			Scheme: "http",
64			Host:   testAddr,
65			Path:   "/" + pathPrefix + "/",
66		},
67	}, nil)
68	assert.NoError(t, err)
69
70	assertRequestToken := func(expectedReq *runtimeapi.ExecRequest, cache *requestCache, token string) {
71		req, ok := cache.Consume(token)
72		require.True(t, ok, "token %s not found!", token)
73		assert.Equal(t, expectedReq, req)
74	}
75	request := &runtimeapi.ExecRequest{
76		ContainerId: testContainerID,
77		Cmd:         []string{"echo", "foo"},
78		Tty:         true,
79		Stdin:       true,
80	}
81	{ // Non-TLS
82		resp, err := serv.GetExec(request)
83		assert.NoError(t, err)
84		expectedURL := "http://" + testAddr + "/exec/"
85		assert.Contains(t, resp.Url, expectedURL)
86		token := strings.TrimPrefix(resp.Url, expectedURL)
87		assertRequestToken(request, serv.(*server).cache, token)
88	}
89
90	{ // TLS
91		resp, err := tlsServer.GetExec(request)
92		assert.NoError(t, err)
93		expectedURL := "https://" + testAddr + "/exec/"
94		assert.Contains(t, resp.Url, expectedURL)
95		token := strings.TrimPrefix(resp.Url, expectedURL)
96		assertRequestToken(request, tlsServer.(*server).cache, token)
97	}
98
99	{ // Path prefix
100		resp, err := prefixServer.GetExec(request)
101		assert.NoError(t, err)
102		expectedURL := "http://" + testAddr + "/" + pathPrefix + "/exec/"
103		assert.Contains(t, resp.Url, expectedURL)
104		token := strings.TrimPrefix(resp.Url, expectedURL)
105		assertRequestToken(request, prefixServer.(*server).cache, token)
106	}
107}
108
109func TestValidateExecAttachRequest(t *testing.T) {
110	type config struct {
111		tty    bool
112		stdin  bool
113		stdout bool
114		stderr bool
115	}
116	for _, tc := range []struct {
117		desc      string
118		configs   []config
119		expectErr bool
120	}{
121		{
122			desc:      "at least one stream must be true",
123			expectErr: true,
124			configs: []config{
125				{false, false, false, false},
126				{true, false, false, false}},
127		},
128		{
129			desc:      "tty and stderr cannot both be true",
130			expectErr: true,
131			configs: []config{
132				{true, false, false, true},
133				{true, false, true, true},
134				{true, true, false, true},
135				{true, true, true, true},
136			},
137		},
138		{
139			desc:      "a valid config should pass",
140			expectErr: false,
141			configs: []config{
142				{false, false, false, true},
143				{false, false, true, false},
144				{false, false, true, true},
145				{false, true, false, false},
146				{false, true, false, true},
147				{false, true, true, false},
148				{false, true, true, true},
149				{true, false, true, false},
150				{true, true, false, false},
151				{true, true, true, false},
152			},
153		},
154	} {
155		t.Run(tc.desc, func(t *testing.T) {
156			for _, c := range tc.configs {
157				// validate the exec request.
158				execReq := &runtimeapi.ExecRequest{
159					ContainerId: testContainerID,
160					Cmd:         []string{"date"},
161					Tty:         c.tty,
162					Stdin:       c.stdin,
163					Stdout:      c.stdout,
164					Stderr:      c.stderr,
165				}
166				err := validateExecRequest(execReq)
167				assert.Equal(t, tc.expectErr, err != nil, "config: %v,  err: %v", c, err)
168
169				// validate the attach request.
170				attachReq := &runtimeapi.AttachRequest{
171					ContainerId: testContainerID,
172					Tty:         c.tty,
173					Stdin:       c.stdin,
174					Stdout:      c.stdout,
175					Stderr:      c.stderr,
176				}
177				err = validateAttachRequest(attachReq)
178				assert.Equal(t, tc.expectErr, err != nil, "config: %v, err: %v", c, err)
179			}
180		})
181	}
182}
183
184func TestGetAttach(t *testing.T) {
185	serv, err := NewServer(Config{
186		Addr: testAddr,
187	}, nil)
188	require.NoError(t, err)
189
190	tlsServer, err := NewServer(Config{
191		Addr:      testAddr,
192		TLSConfig: &tls.Config{},
193	}, nil)
194	require.NoError(t, err)
195
196	assertRequestToken := func(expectedReq *runtimeapi.AttachRequest, cache *requestCache, token string) {
197		req, ok := cache.Consume(token)
198		require.True(t, ok, "token %s not found!", token)
199		assert.Equal(t, expectedReq, req)
200	}
201
202	request := &runtimeapi.AttachRequest{
203		ContainerId: testContainerID,
204		Stdin:       true,
205		Tty:         true,
206	}
207	{ // Non-TLS
208		resp, err := serv.GetAttach(request)
209		assert.NoError(t, err)
210		expectedURL := "http://" + testAddr + "/attach/"
211		assert.Contains(t, resp.Url, expectedURL)
212		token := strings.TrimPrefix(resp.Url, expectedURL)
213		assertRequestToken(request, serv.(*server).cache, token)
214	}
215
216	{ // TLS
217		resp, err := tlsServer.GetAttach(request)
218		assert.NoError(t, err)
219		expectedURL := "https://" + testAddr + "/attach/"
220		assert.Contains(t, resp.Url, expectedURL)
221		token := strings.TrimPrefix(resp.Url, expectedURL)
222		assertRequestToken(request, tlsServer.(*server).cache, token)
223	}
224}
225
226func TestGetPortForward(t *testing.T) {
227	podSandboxID := testPodSandboxID
228	request := &runtimeapi.PortForwardRequest{
229		PodSandboxId: podSandboxID,
230		Port:         []int32{1, 2, 3, 4},
231	}
232
233	{ // Non-TLS
234		serv, err := NewServer(Config{
235			Addr: testAddr,
236		}, nil)
237		assert.NoError(t, err)
238		resp, err := serv.GetPortForward(request)
239		assert.NoError(t, err)
240		expectedURL := "http://" + testAddr + "/portforward/"
241		assert.True(t, strings.HasPrefix(resp.Url, expectedURL))
242		token := strings.TrimPrefix(resp.Url, expectedURL)
243		req, ok := serv.(*server).cache.Consume(token)
244		require.True(t, ok, "token %s not found!", token)
245		assert.Equal(t, testPodSandboxID, req.(*runtimeapi.PortForwardRequest).PodSandboxId)
246	}
247
248	{ // TLS
249		tlsServer, err := NewServer(Config{
250			Addr:      testAddr,
251			TLSConfig: &tls.Config{},
252		}, nil)
253		assert.NoError(t, err)
254		resp, err := tlsServer.GetPortForward(request)
255		assert.NoError(t, err)
256		expectedURL := "https://" + testAddr + "/portforward/"
257		assert.True(t, strings.HasPrefix(resp.Url, expectedURL))
258		token := strings.TrimPrefix(resp.Url, expectedURL)
259		req, ok := tlsServer.(*server).cache.Consume(token)
260		require.True(t, ok, "token %s not found!", token)
261		assert.Equal(t, testPodSandboxID, req.(*runtimeapi.PortForwardRequest).PodSandboxId)
262	}
263}
264
265func TestServeExec(t *testing.T) {
266	runRemoteCommandTest(t, "exec")
267}
268
269func TestServeAttach(t *testing.T) {
270	runRemoteCommandTest(t, "attach")
271}
272
273func TestServePortForward(t *testing.T) {
274	s, testServer := startTestServer(t)
275	defer testServer.Close()
276
277	resp, err := s.GetPortForward(&runtimeapi.PortForwardRequest{
278		PodSandboxId: testPodSandboxID,
279	})
280	require.NoError(t, err)
281	reqURL, err := url.Parse(resp.Url)
282	require.NoError(t, err)
283
284	transport, upgrader, err := spdy.RoundTripperFor(&restclient.Config{})
285	require.NoError(t, err)
286	dialer := spdy.NewDialer(upgrader, &http.Client{Transport: transport}, "POST", reqURL)
287	streamConn, _, err := dialer.Dial(kubeletportforward.ProtocolV1Name)
288	require.NoError(t, err)
289	defer streamConn.Close()
290
291	// Create the streams.
292	headers := http.Header{}
293	// Error stream is required, but unused in this test.
294	headers.Set(api.StreamType, api.StreamTypeError)
295	headers.Set(api.PortHeader, strconv.Itoa(testPort))
296	_, err = streamConn.CreateStream(headers)
297	require.NoError(t, err)
298	// Setup the data stream.
299	headers.Set(api.StreamType, api.StreamTypeData)
300	headers.Set(api.PortHeader, strconv.Itoa(testPort))
301	stream, err := streamConn.CreateStream(headers)
302	require.NoError(t, err)
303
304	doClientStreams(t, "portforward", stream, stream, nil)
305}
306
307//
308// Run the remote command test.
309// commandType is either "exec" or "attach".
310func runRemoteCommandTest(t *testing.T, commandType string) {
311	s, testServer := startTestServer(t)
312	defer testServer.Close()
313
314	var reqURL *url.URL
315	stdin, stdout, stderr := true, true, true
316	containerID := testContainerID
317	switch commandType {
318	case "exec":
319		resp, err := s.GetExec(&runtimeapi.ExecRequest{
320			ContainerId: containerID,
321			Cmd:         []string{"echo"},
322			Stdin:       stdin,
323			Stdout:      stdout,
324			Stderr:      stderr,
325		})
326		require.NoError(t, err)
327		reqURL, err = url.Parse(resp.Url)
328		require.NoError(t, err)
329	case "attach":
330		resp, err := s.GetAttach(&runtimeapi.AttachRequest{
331			ContainerId: containerID,
332			Stdin:       stdin,
333			Stdout:      stdout,
334			Stderr:      stderr,
335		})
336		require.NoError(t, err)
337		reqURL, err = url.Parse(resp.Url)
338		require.NoError(t, err)
339	}
340
341	wg := sync.WaitGroup{}
342	wg.Add(2)
343
344	stdinR, stdinW := io.Pipe()
345	stdoutR, stdoutW := io.Pipe()
346	stderrR, stderrW := io.Pipe()
347
348	go func() {
349		defer wg.Done()
350		exec, err := remotecommand.NewSPDYExecutor(&restclient.Config{}, "POST", reqURL)
351		require.NoError(t, err)
352
353		opts := remotecommand.StreamOptions{
354			Stdin:  stdinR,
355			Stdout: stdoutW,
356			Stderr: stderrW,
357			Tty:    false,
358		}
359		require.NoError(t, exec.Stream(opts))
360	}()
361
362	go func() {
363		defer wg.Done()
364		doClientStreams(t, commandType, stdinW, stdoutR, stderrR)
365	}()
366
367	wg.Wait()
368
369	// Repeat request with the same URL should be a 404.
370	resp, err := http.Get(reqURL.String())
371	require.NoError(t, err)
372	assert.Equal(t, http.StatusNotFound, resp.StatusCode)
373}
374
375func startTestServer(t *testing.T) (Server, *httptest.Server) {
376	var s Server
377	testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
378		s.ServeHTTP(w, r)
379	}))
380	cleanup := true
381	defer func() {
382		if cleanup {
383			testServer.Close()
384		}
385	}()
386
387	testURL, err := url.Parse(testServer.URL)
388	require.NoError(t, err)
389
390	rt := newFakeRuntime(t)
391	config := DefaultConfig
392	config.BaseURL = testURL
393	s, err = NewServer(config, rt)
394	require.NoError(t, err)
395
396	cleanup = false // Caller must close the test server.
397	return s, testServer
398}
399
// Payloads exchanged by doServerStreams and doClientStreams. Each one is sent
// with the caller's prefix prepended.
const (
	testInput  = "abcdefg"
	testOutput = "fooBARbaz"
	testErr    = "ERROR!!!"
)

// testPort is the port TestServePortForward requests and
// fakeRuntime.PortForward expects.
const testPort = 12345
406
407func newFakeRuntime(t *testing.T) *fakeRuntime {
408	return &fakeRuntime{
409		t: t,
410	}
411}
412
// fakeRuntime is a stub runtime for the streaming server tests. Each method
// checks the identifiers it receives against the test constants, then plays
// the server side of the expected stream exchange.
type fakeRuntime struct {
	t *testing.T // test that assertion failures are reported to
}
416
417func (f *fakeRuntime) Exec(containerID string, cmd []string, stdin io.Reader, stdout, stderr io.WriteCloser, tty bool, resize <-chan remotecommand.TerminalSize) error {
418	assert.Equal(f.t, testContainerID, containerID)
419	doServerStreams(f.t, "exec", stdin, stdout, stderr)
420	return nil
421}
422
423func (f *fakeRuntime) Attach(containerID string, stdin io.Reader, stdout, stderr io.WriteCloser, tty bool, resize <-chan remotecommand.TerminalSize) error {
424	assert.Equal(f.t, testContainerID, containerID)
425	doServerStreams(f.t, "attach", stdin, stdout, stderr)
426	return nil
427}
428
429func (f *fakeRuntime) PortForward(podSandboxID string, port int32, stream io.ReadWriteCloser) error {
430	assert.Equal(f.t, testPodSandboxID, podSandboxID)
431	assert.EqualValues(f.t, testPort, port)
432	doServerStreams(f.t, "portforward", stream, stream, nil)
433	return nil
434}
435
436// Send & receive expected input/output. Must be the inverse of doClientStreams.
437// Function will block until the expected i/o is finished.
438func doServerStreams(t *testing.T, prefix string, stdin io.Reader, stdout, stderr io.Writer) {
439	if stderr != nil {
440		writeExpected(t, "server stderr", stderr, prefix+testErr)
441	}
442	readExpected(t, "server stdin", stdin, prefix+testInput)
443	writeExpected(t, "server stdout", stdout, prefix+testOutput)
444}
445
446// Send & receive expected input/output. Must be the inverse of doServerStreams.
447// Function will block until the expected i/o is finished.
448func doClientStreams(t *testing.T, prefix string, stdin io.Writer, stdout, stderr io.Reader) {
449	if stderr != nil {
450		readExpected(t, "client stderr", stderr, prefix+testErr)
451	}
452	writeExpected(t, "client stdin", stdin, prefix+testInput)
453	readExpected(t, "client stdout", stdout, prefix+testOutput)
454}
455
456// Read and verify the expected string from the stream.
457func readExpected(t *testing.T, streamName string, r io.Reader, expected string) {
458	result := make([]byte, len(expected))
459	_, err := io.ReadAtLeast(r, result, len(expected))
460	assert.NoError(t, err, "stream %s", streamName)
461	assert.Equal(t, expected, string(result), "stream %s", streamName)
462}
463
464// Write and verify success of the data over the stream.
465func writeExpected(t *testing.T, streamName string, w io.Writer, data string) {
466	n, err := io.WriteString(w, data)
467	assert.NoError(t, err, "stream %s", streamName)
468	assert.Equal(t, len(data), n, "stream %s", streamName)
469}
470