1// +build go1.8 2 3/* 4Copyright 2016 The Kubernetes Authors. 5 6Licensed under the Apache License, Version 2.0 (the "License"); 7you may not use this file except in compliance with the License. 8You may obtain a copy of the License at 9 10 http://www.apache.org/licenses/LICENSE-2.0 11 12Unless required by applicable law or agreed to in writing, software 13distributed under the License is distributed on an "AS IS" BASIS, 14WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15See the License for the specific language governing permissions and 16limitations under the License. 17*/ 18 19package net 20 21import ( 22 "bufio" 23 "bytes" 24 "crypto/tls" 25 "fmt" 26 "io/ioutil" 27 "net" 28 "net/http" 29 "net/http/httptest" 30 "net/url" 31 "os" 32 "reflect" 33 "strings" 34 "testing" 35 36 "github.com/stretchr/testify/assert" 37 "github.com/stretchr/testify/require" 38 "k8s.io/apimachinery/pkg/util/wait" 39) 40 41func TestGetClientIP(t *testing.T) { 42 ipString := "10.0.0.1" 43 ip := net.ParseIP(ipString) 44 invalidIPString := "invalidIPString" 45 testCases := []struct { 46 Request http.Request 47 ExpectedIP net.IP 48 }{ 49 { 50 Request: http.Request{}, 51 }, 52 { 53 Request: http.Request{ 54 Header: map[string][]string{ 55 "X-Real-Ip": {ipString}, 56 }, 57 }, 58 ExpectedIP: ip, 59 }, 60 { 61 Request: http.Request{ 62 Header: map[string][]string{ 63 "X-Real-Ip": {invalidIPString}, 64 }, 65 }, 66 }, 67 { 68 Request: http.Request{ 69 Header: map[string][]string{ 70 "X-Forwarded-For": {ipString}, 71 }, 72 }, 73 ExpectedIP: ip, 74 }, 75 { 76 Request: http.Request{ 77 Header: map[string][]string{ 78 "X-Forwarded-For": {invalidIPString}, 79 }, 80 }, 81 }, 82 { 83 Request: http.Request{ 84 Header: map[string][]string{ 85 "X-Forwarded-For": {invalidIPString + "," + ipString}, 86 }, 87 }, 88 ExpectedIP: ip, 89 }, 90 { 91 Request: http.Request{ 92 // RemoteAddr is in the form host:port 93 RemoteAddr: ipString + ":1234", 94 }, 95 ExpectedIP: ip, 96 }, 97 { 98 Request: http.Request{ 99 RemoteAddr: invalidIPString, 100 }, 101 }, 102 { 103 Request: http.Request{ 104 Header: map[string][]string{ 105 "X-Forwarded-For": {invalidIPString}, 106 }, 107 // RemoteAddr is in the form host:port 108 RemoteAddr: ipString, 109 }, 110 ExpectedIP: ip, 111 }, 112 } 113 114 for i, test := range testCases { 115 if a, e := GetClientIP(&test.Request), test.ExpectedIP; reflect.DeepEqual(e, a) != true { 116 t.Fatalf("test case %d failed. expected: %v, actual: %v", i, e, a) 117 } 118 } 119} 120 121func TestAppendForwardedForHeader(t *testing.T) { 122 testCases := []struct { 123 addr, forwarded, expected string 124 }{ 125 {"1.2.3.4:8000", "", "1.2.3.4"}, 126 {"1.2.3.4:8000", "8.8.8.8", "8.8.8.8, 1.2.3.4"}, 127 {"1.2.3.4:8000", "8.8.8.8, 1.2.3.4", "8.8.8.8, 1.2.3.4, 1.2.3.4"}, 128 {"1.2.3.4:8000", "foo,bar", "foo,bar, 1.2.3.4"}, 129 } 130 for i, test := range testCases { 131 req := &http.Request{ 132 RemoteAddr: test.addr, 133 Header: make(http.Header), 134 } 135 if test.forwarded != "" { 136 req.Header.Set("X-Forwarded-For", test.forwarded) 137 } 138 139 AppendForwardedForHeader(req) 140 actual := req.Header.Get("X-Forwarded-For") 141 if actual != test.expected { 142 t.Errorf("[%d] Expected %q, Got %q", i, test.expected, actual) 143 } 144 } 145} 146 147func TestProxierWithNoProxyCIDR(t *testing.T) { 148 testCases := []struct { 149 name string 150 noProxy string 151 url string 152 153 expectedDelegated bool 154 }{ 155 { 156 name: "no env", 157 url: "https://192.168.143.1/api", 158 expectedDelegated: true, 159 }, 160 { 161 name: "no cidr", 162 noProxy: "192.168.63.1", 163 url: "https://192.168.143.1/api", 164 expectedDelegated: true, 165 }, 166 { 167 name: "hostname", 168 noProxy: "192.168.63.0/24,192.168.143.0/24", 169 url: "https://my-hostname/api", 170 expectedDelegated: true, 171 }, 172 { 173 name: "match second cidr", 174 noProxy: "192.168.63.0/24,192.168.143.0/24", 175 url: "https://192.168.143.1/api", 176 expectedDelegated: false, 177 }, 178 { 179 name: "match second cidr with host:port", 180 noProxy: "192.168.63.0/24,192.168.143.0/24", 181 url: "https://192.168.143.1:8443/api", 182 expectedDelegated: false, 183 }, 184 { 185 name: "IPv6 cidr", 186 noProxy: "2001:db8::/48", 187 url: "https://[2001:db8::1]/api", 188 expectedDelegated: false, 189 }, 190 { 191 name: "IPv6+port cidr", 192 noProxy: "2001:db8::/48", 193 url: "https://[2001:db8::1]:8443/api", 194 expectedDelegated: false, 195 }, 196 { 197 name: "IPv6, not matching cidr", 198 noProxy: "2001:db8::/48", 199 url: "https://[2001:db8:1::1]/api", 200 expectedDelegated: true, 201 }, 202 { 203 name: "IPv6+port, not matching cidr", 204 noProxy: "2001:db8::/48", 205 url: "https://[2001:db8:1::1]:8443/api", 206 expectedDelegated: true, 207 }, 208 } 209 210 for _, test := range testCases { 211 os.Setenv("NO_PROXY", test.noProxy) 212 actualDelegated := false 213 proxyFunc := NewProxierWithNoProxyCIDR(func(req *http.Request) (*url.URL, error) { 214 actualDelegated = true 215 return nil, nil 216 }) 217 218 req, err := http.NewRequest("GET", test.url, nil) 219 if err != nil { 220 t.Errorf("%s: unexpected err: %v", test.name, err) 221 continue 222 } 223 if _, err := proxyFunc(req); err != nil { 224 t.Errorf("%s: unexpected err: %v", test.name, err) 225 continue 226 } 227 228 if test.expectedDelegated != actualDelegated { 229 t.Errorf("%s: expected %v, got %v", test.name, test.expectedDelegated, actualDelegated) 230 continue 231 } 232 } 233} 234 235type fakeTLSClientConfigHolder struct { 236 called bool 237} 238 239func (f *fakeTLSClientConfigHolder) TLSClientConfig() *tls.Config { 240 f.called = true 241 return nil 242} 243func (f *fakeTLSClientConfigHolder) RoundTrip(*http.Request) (*http.Response, error) { 244 return nil, nil 245} 246 247func TestTLSClientConfigHolder(t *testing.T) { 248 rt := &fakeTLSClientConfigHolder{} 249 TLSClientConfig(rt) 250 251 if !rt.called { 252 t.Errorf("didn't find tls config") 253 } 254} 255 256func TestJoinPreservingTrailingSlash(t *testing.T) { 257 tests := []struct { 258 a string 259 b string 260 want string 261 }{ 262 // All empty 263 {"", "", ""}, 264 265 // Empty a 266 {"", "/", "/"}, 267 {"", "foo", "foo"}, 268 {"", "/foo", "/foo"}, 269 {"", "/foo/", "/foo/"}, 270 271 // Empty b 272 {"/", "", "/"}, 273 {"foo", "", "foo"}, 274 {"/foo", "", "/foo"}, 275 {"/foo/", "", "/foo/"}, 276 277 // Both populated 278 {"/", "/", "/"}, 279 {"foo", "foo", "foo/foo"}, 280 {"/foo", "/foo", "/foo/foo"}, 281 {"/foo/", "/foo/", "/foo/foo/"}, 282 } 283 for _, tt := range tests { 284 name := fmt.Sprintf("%q+%q=%q", tt.a, tt.b, tt.want) 285 t.Run(name, func(t *testing.T) { 286 if got := JoinPreservingTrailingSlash(tt.a, tt.b); got != tt.want { 287 t.Errorf("JoinPreservingTrailingSlash() = %v, want %v", got, tt.want) 288 } 289 }) 290 } 291} 292 293func TestConnectWithRedirects(t *testing.T) { 294 tests := []struct { 295 desc string 296 redirects []string 297 method string // initial request method, empty == GET 298 expectError bool 299 expectedRedirects int 300 newPort bool // special case different port test 301 }{{ 302 desc: "relative redirects allowed", 303 redirects: []string{"/ok"}, 304 expectedRedirects: 1, 305 }, { 306 desc: "redirects to the same host are allowed", 307 redirects: []string{"http://HOST/ok"}, // HOST replaced with server address in test 308 expectedRedirects: 1, 309 }, { 310 desc: "POST redirects to GET", 311 method: http.MethodPost, 312 redirects: []string{"/ok"}, 313 expectedRedirects: 1, 314 }, { 315 desc: "PUT redirects to GET", 316 method: http.MethodPut, 317 redirects: []string{"/ok"}, 318 expectedRedirects: 1, 319 }, { 320 desc: "DELETE redirects to GET", 321 method: http.MethodDelete, 322 redirects: []string{"/ok"}, 323 expectedRedirects: 1, 324 }, { 325 desc: "9 redirects are allowed", 326 redirects: []string{"/1", "/2", "/3", "/4", "/5", "/6", "/7", "/8", "/9"}, 327 expectedRedirects: 9, 328 }, { 329 desc: "10 redirects are forbidden", 330 redirects: []string{"/1", "/2", "/3", "/4", "/5", "/6", "/7", "/8", "/9", "/10"}, 331 expectError: true, 332 }, { 333 desc: "redirect to different host are prevented", 334 redirects: []string{"http://example.com/foo"}, 335 expectedRedirects: 0, 336 }, { 337 desc: "multiple redirect to different host forbidden", 338 redirects: []string{"/1", "/2", "/3", "http://example.com/foo"}, 339 expectedRedirects: 3, 340 }, { 341 desc: "redirect to different port is allowed", 342 redirects: []string{"http://HOST/foo"}, 343 expectedRedirects: 1, 344 newPort: true, 345 }} 346 347 const resultString = "Test output" 348 for _, test := range tests { 349 t.Run(test.desc, func(t *testing.T) { 350 redirectCount := 0 351 s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 352 // Verify redirect request. 353 if redirectCount > 0 { 354 expectedURL, err := url.Parse(test.redirects[redirectCount-1]) 355 require.NoError(t, err, "test URL error") 356 assert.Equal(t, req.URL.Path, expectedURL.Path, "unknown redirect path") 357 assert.Equal(t, http.MethodGet, req.Method, "redirects must always be GET") 358 } 359 if redirectCount < len(test.redirects) { 360 http.Redirect(w, req, test.redirects[redirectCount], http.StatusFound) 361 redirectCount++ 362 } else if redirectCount == len(test.redirects) { 363 w.Write([]byte(resultString)) 364 } else { 365 t.Errorf("unexpected number of redirects %d to %s", redirectCount, req.URL.String()) 366 } 367 })) 368 defer s.Close() 369 370 u, err := url.Parse(s.URL) 371 require.NoError(t, err, "Error parsing server URL") 372 host := u.Host 373 374 // Special case new-port test with a secondary server. 375 if test.newPort { 376 s2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 377 w.Write([]byte(resultString)) 378 })) 379 defer s2.Close() 380 u2, err := url.Parse(s2.URL) 381 require.NoError(t, err, "Error parsing secondary server URL") 382 383 // Sanity check: secondary server uses same hostname, different port. 384 require.Equal(t, u.Hostname(), u2.Hostname(), "sanity check: same hostname") 385 require.NotEqual(t, u.Port(), u2.Port(), "sanity check: different port") 386 387 // Redirect to the secondary server. 388 host = u2.Host 389 390 } 391 392 // Update redirect URLs with actual host. 393 for i := range test.redirects { 394 test.redirects[i] = strings.Replace(test.redirects[i], "HOST", host, 1) 395 } 396 397 method := test.method 398 if method == "" { 399 method = http.MethodGet 400 } 401 402 netdialer := &net.Dialer{ 403 Timeout: wait.ForeverTestTimeout, 404 KeepAlive: wait.ForeverTestTimeout, 405 } 406 dialer := DialerFunc(func(req *http.Request) (net.Conn, error) { 407 conn, err := netdialer.Dial("tcp", req.URL.Host) 408 if err != nil { 409 return conn, err 410 } 411 if err = req.Write(conn); err != nil { 412 require.NoError(t, conn.Close()) 413 return nil, fmt.Errorf("error sending request: %v", err) 414 } 415 return conn, err 416 }) 417 conn, rawResponse, err := ConnectWithRedirects(method, u, http.Header{} /*body*/, nil, dialer, true) 418 if test.expectError { 419 require.Error(t, err, "expected request error") 420 return 421 } 422 423 require.NoError(t, err, "unexpected request error") 424 assert.NoError(t, conn.Close(), "error closing connection") 425 426 resp, err := http.ReadResponse(bufio.NewReader(bytes.NewReader(rawResponse)), nil) 427 require.NoError(t, err, "unexpected request error") 428 429 result, err := ioutil.ReadAll(resp.Body) 430 require.NoError(t, resp.Body.Close()) 431 if test.expectedRedirects < len(test.redirects) { 432 // Expect the last redirect to be returned. 433 assert.Equal(t, http.StatusFound, resp.StatusCode, "Final response is not a redirect") 434 assert.Equal(t, test.redirects[len(test.redirects)-1], resp.Header.Get("Location")) 435 assert.NotEqual(t, resultString, string(result), "wrong content") 436 } else { 437 assert.Equal(t, resultString, string(result), "stream content does not match") 438 } 439 }) 440 } 441} 442 443func TestAllowsHTTP2(t *testing.T) { 444 testcases := []struct { 445 Name string 446 Transport *http.Transport 447 ExpectAllows bool 448 }{ 449 { 450 Name: "empty", 451 Transport: &http.Transport{}, 452 ExpectAllows: true, 453 }, 454 { 455 Name: "empty tlsconfig", 456 Transport: &http.Transport{TLSClientConfig: &tls.Config{}}, 457 ExpectAllows: true, 458 }, 459 { 460 Name: "zero-length NextProtos", 461 Transport: &http.Transport{TLSClientConfig: &tls.Config{NextProtos: []string{}}}, 462 ExpectAllows: true, 463 }, 464 { 465 Name: "includes h2 in NextProtos after", 466 Transport: &http.Transport{TLSClientConfig: &tls.Config{NextProtos: []string{"http/1.1", "h2"}}}, 467 ExpectAllows: true, 468 }, 469 { 470 Name: "includes h2 in NextProtos before", 471 Transport: &http.Transport{TLSClientConfig: &tls.Config{NextProtos: []string{"h2", "http/1.1"}}}, 472 ExpectAllows: true, 473 }, 474 { 475 Name: "includes h2 in NextProtos between", 476 Transport: &http.Transport{TLSClientConfig: &tls.Config{NextProtos: []string{"http/1.1", "h2", "h3"}}}, 477 ExpectAllows: true, 478 }, 479 { 480 Name: "excludes h2 in NextProtos", 481 Transport: &http.Transport{TLSClientConfig: &tls.Config{NextProtos: []string{"http/1.1"}}}, 482 ExpectAllows: false, 483 }, 484 } 485 486 for _, tc := range testcases { 487 t.Run(tc.Name, func(t *testing.T) { 488 allows := allowsHTTP2(tc.Transport) 489 if allows != tc.ExpectAllows { 490 t.Errorf("expected %v, got %v", tc.ExpectAllows, allows) 491 } 492 }) 493 } 494} 495