1/* 2Copyright 2016 The Kubernetes Authors. 3 4Licensed under the Apache License, Version 2.0 (the "License"); 5you may not use this file except in compliance with the License. 6You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10Unless required by applicable law or agreed to in writing, software 11distributed under the License is distributed on an "AS IS" BASIS, 12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13See the License for the specific language governing permissions and 14limitations under the License. 15*/ 16 17package streaming 18 19import ( 20 "crypto/tls" 21 "io" 22 "net/http" 23 "net/http/httptest" 24 "net/url" 25 "strconv" 26 "strings" 27 "sync" 28 "testing" 29 30 "github.com/stretchr/testify/assert" 31 "github.com/stretchr/testify/require" 32 33 api "k8s.io/api/core/v1" 34 restclient "k8s.io/client-go/rest" 35 "k8s.io/client-go/tools/remotecommand" 36 "k8s.io/client-go/transport/spdy" 37 runtimeapi "k8s.io/cri-api/pkg/apis/runtime/v1alpha2" 38 kubeletportforward "k8s.io/kubernetes/pkg/kubelet/cri/streaming/portforward" 39) 40 41const ( 42 testAddr = "localhost:12345" 43 testContainerID = "container789" 44 testPodSandboxID = "pod0987" 45) 46 47func TestGetExec(t *testing.T) { 48 serv, err := NewServer(Config{ 49 Addr: testAddr, 50 }, nil) 51 assert.NoError(t, err) 52 53 tlsServer, err := NewServer(Config{ 54 Addr: testAddr, 55 TLSConfig: &tls.Config{}, 56 }, nil) 57 assert.NoError(t, err) 58 59 const pathPrefix = "cri/shim" 60 prefixServer, err := NewServer(Config{ 61 Addr: testAddr, 62 BaseURL: &url.URL{ 63 Scheme: "http", 64 Host: testAddr, 65 Path: "/" + pathPrefix + "/", 66 }, 67 }, nil) 68 assert.NoError(t, err) 69 70 assertRequestToken := func(expectedReq *runtimeapi.ExecRequest, cache *requestCache, token string) { 71 req, ok := cache.Consume(token) 72 require.True(t, ok, "token %s not found!", token) 73 assert.Equal(t, expectedReq, req) 74 } 75 request := &runtimeapi.ExecRequest{ 76 ContainerId: testContainerID, 77 Cmd: []string{"echo", "foo"}, 78 Tty: true, 79 Stdin: true, 80 } 81 { // Non-TLS 82 resp, err := serv.GetExec(request) 83 assert.NoError(t, err) 84 expectedURL := "http://" + testAddr + "/exec/" 85 assert.Contains(t, resp.Url, expectedURL) 86 token := strings.TrimPrefix(resp.Url, expectedURL) 87 assertRequestToken(request, serv.(*server).cache, token) 88 } 89 90 { // TLS 91 resp, err := tlsServer.GetExec(request) 92 assert.NoError(t, err) 93 expectedURL := "https://" + testAddr + "/exec/" 94 assert.Contains(t, resp.Url, expectedURL) 95 token := strings.TrimPrefix(resp.Url, expectedURL) 96 assertRequestToken(request, tlsServer.(*server).cache, token) 97 } 98 99 { // Path prefix 100 resp, err := prefixServer.GetExec(request) 101 assert.NoError(t, err) 102 expectedURL := "http://" + testAddr + "/" + pathPrefix + "/exec/" 103 assert.Contains(t, resp.Url, expectedURL) 104 token := strings.TrimPrefix(resp.Url, expectedURL) 105 assertRequestToken(request, prefixServer.(*server).cache, token) 106 } 107} 108 109func TestValidateExecAttachRequest(t *testing.T) { 110 type config struct { 111 tty bool 112 stdin bool 113 stdout bool 114 stderr bool 115 } 116 for _, tc := range []struct { 117 desc string 118 configs []config 119 expectErr bool 120 }{ 121 { 122 desc: "at least one stream must be true", 123 expectErr: true, 124 configs: []config{ 125 {false, false, false, false}, 126 {true, false, false, false}}, 127 }, 128 { 129 desc: "tty and stderr cannot both be true", 130 expectErr: true, 131 configs: []config{ 132 {true, false, false, true}, 133 {true, false, true, true}, 134 {true, true, false, true}, 135 {true, true, true, true}, 136 }, 137 }, 138 { 139 desc: "a valid config should pass", 140 expectErr: false, 141 configs: []config{ 142 {false, false, false, true}, 143 {false, false, true, false}, 144 {false, false, true, true}, 145 {false, true, false, false}, 146 {false, true, false, true}, 147 {false, true, true, false}, 148 {false, true, true, true}, 149 {true, false, true, false}, 150 {true, true, false, false}, 151 {true, true, true, false}, 152 }, 153 }, 154 } { 155 t.Run(tc.desc, func(t *testing.T) { 156 for _, c := range tc.configs { 157 // validate the exec request. 158 execReq := &runtimeapi.ExecRequest{ 159 ContainerId: testContainerID, 160 Cmd: []string{"date"}, 161 Tty: c.tty, 162 Stdin: c.stdin, 163 Stdout: c.stdout, 164 Stderr: c.stderr, 165 } 166 err := validateExecRequest(execReq) 167 assert.Equal(t, tc.expectErr, err != nil, "config: %v, err: %v", c, err) 168 169 // validate the attach request. 170 attachReq := &runtimeapi.AttachRequest{ 171 ContainerId: testContainerID, 172 Tty: c.tty, 173 Stdin: c.stdin, 174 Stdout: c.stdout, 175 Stderr: c.stderr, 176 } 177 err = validateAttachRequest(attachReq) 178 assert.Equal(t, tc.expectErr, err != nil, "config: %v, err: %v", c, err) 179 } 180 }) 181 } 182} 183 184func TestGetAttach(t *testing.T) { 185 serv, err := NewServer(Config{ 186 Addr: testAddr, 187 }, nil) 188 require.NoError(t, err) 189 190 tlsServer, err := NewServer(Config{ 191 Addr: testAddr, 192 TLSConfig: &tls.Config{}, 193 }, nil) 194 require.NoError(t, err) 195 196 assertRequestToken := func(expectedReq *runtimeapi.AttachRequest, cache *requestCache, token string) { 197 req, ok := cache.Consume(token) 198 require.True(t, ok, "token %s not found!", token) 199 assert.Equal(t, expectedReq, req) 200 } 201 202 request := &runtimeapi.AttachRequest{ 203 ContainerId: testContainerID, 204 Stdin: true, 205 Tty: true, 206 } 207 { // Non-TLS 208 resp, err := serv.GetAttach(request) 209 assert.NoError(t, err) 210 expectedURL := "http://" + testAddr + "/attach/" 211 assert.Contains(t, resp.Url, expectedURL) 212 token := strings.TrimPrefix(resp.Url, expectedURL) 213 assertRequestToken(request, serv.(*server).cache, token) 214 } 215 216 { // TLS 217 resp, err := tlsServer.GetAttach(request) 218 assert.NoError(t, err) 219 expectedURL := "https://" + testAddr + "/attach/" 220 assert.Contains(t, resp.Url, expectedURL) 221 token := strings.TrimPrefix(resp.Url, expectedURL) 222 assertRequestToken(request, tlsServer.(*server).cache, token) 223 } 224} 225 226func TestGetPortForward(t *testing.T) { 227 podSandboxID := testPodSandboxID 228 request := &runtimeapi.PortForwardRequest{ 229 PodSandboxId: podSandboxID, 230 Port: []int32{1, 2, 3, 4}, 231 } 232 233 { // Non-TLS 234 serv, err := NewServer(Config{ 235 Addr: testAddr, 236 }, nil) 237 assert.NoError(t, err) 238 resp, err := serv.GetPortForward(request) 239 assert.NoError(t, err) 240 expectedURL := "http://" + testAddr + "/portforward/" 241 assert.True(t, strings.HasPrefix(resp.Url, expectedURL)) 242 token := strings.TrimPrefix(resp.Url, expectedURL) 243 req, ok := serv.(*server).cache.Consume(token) 244 require.True(t, ok, "token %s not found!", token) 245 assert.Equal(t, testPodSandboxID, req.(*runtimeapi.PortForwardRequest).PodSandboxId) 246 } 247 248 { // TLS 249 tlsServer, err := NewServer(Config{ 250 Addr: testAddr, 251 TLSConfig: &tls.Config{}, 252 }, nil) 253 assert.NoError(t, err) 254 resp, err := tlsServer.GetPortForward(request) 255 assert.NoError(t, err) 256 expectedURL := "https://" + testAddr + "/portforward/" 257 assert.True(t, strings.HasPrefix(resp.Url, expectedURL)) 258 token := strings.TrimPrefix(resp.Url, expectedURL) 259 req, ok := tlsServer.(*server).cache.Consume(token) 260 require.True(t, ok, "token %s not found!", token) 261 assert.Equal(t, testPodSandboxID, req.(*runtimeapi.PortForwardRequest).PodSandboxId) 262 } 263} 264 265func TestServeExec(t *testing.T) { 266 runRemoteCommandTest(t, "exec") 267} 268 269func TestServeAttach(t *testing.T) { 270 runRemoteCommandTest(t, "attach") 271} 272 273func TestServePortForward(t *testing.T) { 274 s, testServer := startTestServer(t) 275 defer testServer.Close() 276 277 resp, err := s.GetPortForward(&runtimeapi.PortForwardRequest{ 278 PodSandboxId: testPodSandboxID, 279 }) 280 require.NoError(t, err) 281 reqURL, err := url.Parse(resp.Url) 282 require.NoError(t, err) 283 284 transport, upgrader, err := spdy.RoundTripperFor(&restclient.Config{}) 285 require.NoError(t, err) 286 dialer := spdy.NewDialer(upgrader, &http.Client{Transport: transport}, "POST", reqURL) 287 streamConn, _, err := dialer.Dial(kubeletportforward.ProtocolV1Name) 288 require.NoError(t, err) 289 defer streamConn.Close() 290 291 // Create the streams. 292 headers := http.Header{} 293 // Error stream is required, but unused in this test. 294 headers.Set(api.StreamType, api.StreamTypeError) 295 headers.Set(api.PortHeader, strconv.Itoa(testPort)) 296 _, err = streamConn.CreateStream(headers) 297 require.NoError(t, err) 298 // Setup the data stream. 299 headers.Set(api.StreamType, api.StreamTypeData) 300 headers.Set(api.PortHeader, strconv.Itoa(testPort)) 301 stream, err := streamConn.CreateStream(headers) 302 require.NoError(t, err) 303 304 doClientStreams(t, "portforward", stream, stream, nil) 305} 306 307// 308// Run the remote command test. 309// commandType is either "exec" or "attach". 310func runRemoteCommandTest(t *testing.T, commandType string) { 311 s, testServer := startTestServer(t) 312 defer testServer.Close() 313 314 var reqURL *url.URL 315 stdin, stdout, stderr := true, true, true 316 containerID := testContainerID 317 switch commandType { 318 case "exec": 319 resp, err := s.GetExec(&runtimeapi.ExecRequest{ 320 ContainerId: containerID, 321 Cmd: []string{"echo"}, 322 Stdin: stdin, 323 Stdout: stdout, 324 Stderr: stderr, 325 }) 326 require.NoError(t, err) 327 reqURL, err = url.Parse(resp.Url) 328 require.NoError(t, err) 329 case "attach": 330 resp, err := s.GetAttach(&runtimeapi.AttachRequest{ 331 ContainerId: containerID, 332 Stdin: stdin, 333 Stdout: stdout, 334 Stderr: stderr, 335 }) 336 require.NoError(t, err) 337 reqURL, err = url.Parse(resp.Url) 338 require.NoError(t, err) 339 } 340 341 wg := sync.WaitGroup{} 342 wg.Add(2) 343 344 stdinR, stdinW := io.Pipe() 345 stdoutR, stdoutW := io.Pipe() 346 stderrR, stderrW := io.Pipe() 347 348 go func() { 349 defer wg.Done() 350 exec, err := remotecommand.NewSPDYExecutor(&restclient.Config{}, "POST", reqURL) 351 require.NoError(t, err) 352 353 opts := remotecommand.StreamOptions{ 354 Stdin: stdinR, 355 Stdout: stdoutW, 356 Stderr: stderrW, 357 Tty: false, 358 } 359 require.NoError(t, exec.Stream(opts)) 360 }() 361 362 go func() { 363 defer wg.Done() 364 doClientStreams(t, commandType, stdinW, stdoutR, stderrR) 365 }() 366 367 wg.Wait() 368 369 // Repeat request with the same URL should be a 404. 370 resp, err := http.Get(reqURL.String()) 371 require.NoError(t, err) 372 assert.Equal(t, http.StatusNotFound, resp.StatusCode) 373} 374 375func startTestServer(t *testing.T) (Server, *httptest.Server) { 376 var s Server 377 testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 378 s.ServeHTTP(w, r) 379 })) 380 cleanup := true 381 defer func() { 382 if cleanup { 383 testServer.Close() 384 } 385 }() 386 387 testURL, err := url.Parse(testServer.URL) 388 require.NoError(t, err) 389 390 rt := newFakeRuntime(t) 391 config := DefaultConfig 392 config.BaseURL = testURL 393 s, err = NewServer(config, rt) 394 require.NoError(t, err) 395 396 cleanup = false // Caller must close the test server. 397 return s, testServer 398} 399 400const ( 401 testInput = "abcdefg" 402 testOutput = "fooBARbaz" 403 testErr = "ERROR!!!" 404 testPort = 12345 405) 406 407func newFakeRuntime(t *testing.T) *fakeRuntime { 408 return &fakeRuntime{ 409 t: t, 410 } 411} 412 413type fakeRuntime struct { 414 t *testing.T 415} 416 417func (f *fakeRuntime) Exec(containerID string, cmd []string, stdin io.Reader, stdout, stderr io.WriteCloser, tty bool, resize <-chan remotecommand.TerminalSize) error { 418 assert.Equal(f.t, testContainerID, containerID) 419 doServerStreams(f.t, "exec", stdin, stdout, stderr) 420 return nil 421} 422 423func (f *fakeRuntime) Attach(containerID string, stdin io.Reader, stdout, stderr io.WriteCloser, tty bool, resize <-chan remotecommand.TerminalSize) error { 424 assert.Equal(f.t, testContainerID, containerID) 425 doServerStreams(f.t, "attach", stdin, stdout, stderr) 426 return nil 427} 428 429func (f *fakeRuntime) PortForward(podSandboxID string, port int32, stream io.ReadWriteCloser) error { 430 assert.Equal(f.t, testPodSandboxID, podSandboxID) 431 assert.EqualValues(f.t, testPort, port) 432 doServerStreams(f.t, "portforward", stream, stream, nil) 433 return nil 434} 435 436// Send & receive expected input/output. Must be the inverse of doClientStreams. 437// Function will block until the expected i/o is finished. 438func doServerStreams(t *testing.T, prefix string, stdin io.Reader, stdout, stderr io.Writer) { 439 if stderr != nil { 440 writeExpected(t, "server stderr", stderr, prefix+testErr) 441 } 442 readExpected(t, "server stdin", stdin, prefix+testInput) 443 writeExpected(t, "server stdout", stdout, prefix+testOutput) 444} 445 446// Send & receive expected input/output. Must be the inverse of doServerStreams. 447// Function will block until the expected i/o is finished. 448func doClientStreams(t *testing.T, prefix string, stdin io.Writer, stdout, stderr io.Reader) { 449 if stderr != nil { 450 readExpected(t, "client stderr", stderr, prefix+testErr) 451 } 452 writeExpected(t, "client stdin", stdin, prefix+testInput) 453 readExpected(t, "client stdout", stdout, prefix+testOutput) 454} 455 456// Read and verify the expected string from the stream. 457func readExpected(t *testing.T, streamName string, r io.Reader, expected string) { 458 result := make([]byte, len(expected)) 459 _, err := io.ReadAtLeast(r, result, len(expected)) 460 assert.NoError(t, err, "stream %s", streamName) 461 assert.Equal(t, expected, string(result), "stream %s", streamName) 462} 463 464// Write and verify success of the data over the stream. 465func writeExpected(t *testing.T, streamName string, w io.Writer, data string) { 466 n, err := io.WriteString(w, data) 467 assert.NoError(t, err, "stream %s", streamName) 468 assert.Equal(t, len(data), n, "stream %s", streamName) 469} 470