1// +build go1.8
2
3/*
4Copyright 2016 The Kubernetes Authors.
5
6Licensed under the Apache License, Version 2.0 (the "License");
7you may not use this file except in compliance with the License.
8You may obtain a copy of the License at
9
10    http://www.apache.org/licenses/LICENSE-2.0
11
12Unless required by applicable law or agreed to in writing, software
13distributed under the License is distributed on an "AS IS" BASIS,
14WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15See the License for the specific language governing permissions and
16limitations under the License.
17*/
18
19package net
20
21import (
22	"bufio"
23	"bytes"
24	"crypto/tls"
25	"fmt"
26	"io/ioutil"
27	"net"
28	"net/http"
29	"net/http/httptest"
30	"net/url"
31	"os"
32	"reflect"
33	"strings"
34	"testing"
35
36	"github.com/stretchr/testify/assert"
37	"github.com/stretchr/testify/require"
38	"k8s.io/apimachinery/pkg/util/wait"
39)
40
41func TestGetClientIP(t *testing.T) {
42	ipString := "10.0.0.1"
43	ip := net.ParseIP(ipString)
44	invalidIPString := "invalidIPString"
45	testCases := []struct {
46		Request    http.Request
47		ExpectedIP net.IP
48	}{
49		{
50			Request: http.Request{},
51		},
52		{
53			Request: http.Request{
54				Header: map[string][]string{
55					"X-Real-Ip": {ipString},
56				},
57			},
58			ExpectedIP: ip,
59		},
60		{
61			Request: http.Request{
62				Header: map[string][]string{
63					"X-Real-Ip": {invalidIPString},
64				},
65			},
66		},
67		{
68			Request: http.Request{
69				Header: map[string][]string{
70					"X-Forwarded-For": {ipString},
71				},
72			},
73			ExpectedIP: ip,
74		},
75		{
76			Request: http.Request{
77				Header: map[string][]string{
78					"X-Forwarded-For": {invalidIPString},
79				},
80			},
81		},
82		{
83			Request: http.Request{
84				Header: map[string][]string{
85					"X-Forwarded-For": {invalidIPString + "," + ipString},
86				},
87			},
88			ExpectedIP: ip,
89		},
90		{
91			Request: http.Request{
92				// RemoteAddr is in the form host:port
93				RemoteAddr: ipString + ":1234",
94			},
95			ExpectedIP: ip,
96		},
97		{
98			Request: http.Request{
99				RemoteAddr: invalidIPString,
100			},
101		},
102		{
103			Request: http.Request{
104				Header: map[string][]string{
105					"X-Forwarded-For": {invalidIPString},
106				},
107				// RemoteAddr is in the form host:port
108				RemoteAddr: ipString,
109			},
110			ExpectedIP: ip,
111		},
112	}
113
114	for i, test := range testCases {
115		if a, e := GetClientIP(&test.Request), test.ExpectedIP; reflect.DeepEqual(e, a) != true {
116			t.Fatalf("test case %d failed. expected: %v, actual: %v", i, e, a)
117		}
118	}
119}
120
121func TestAppendForwardedForHeader(t *testing.T) {
122	testCases := []struct {
123		addr, forwarded, expected string
124	}{
125		{"1.2.3.4:8000", "", "1.2.3.4"},
126		{"1.2.3.4:8000", "8.8.8.8", "8.8.8.8, 1.2.3.4"},
127		{"1.2.3.4:8000", "8.8.8.8, 1.2.3.4", "8.8.8.8, 1.2.3.4, 1.2.3.4"},
128		{"1.2.3.4:8000", "foo,bar", "foo,bar, 1.2.3.4"},
129	}
130	for i, test := range testCases {
131		req := &http.Request{
132			RemoteAddr: test.addr,
133			Header:     make(http.Header),
134		}
135		if test.forwarded != "" {
136			req.Header.Set("X-Forwarded-For", test.forwarded)
137		}
138
139		AppendForwardedForHeader(req)
140		actual := req.Header.Get("X-Forwarded-For")
141		if actual != test.expected {
142			t.Errorf("[%d] Expected %q, Got %q", i, test.expected, actual)
143		}
144	}
145}
146
147func TestProxierWithNoProxyCIDR(t *testing.T) {
148	testCases := []struct {
149		name    string
150		noProxy string
151		url     string
152
153		expectedDelegated bool
154	}{
155		{
156			name:              "no env",
157			url:               "https://192.168.143.1/api",
158			expectedDelegated: true,
159		},
160		{
161			name:              "no cidr",
162			noProxy:           "192.168.63.1",
163			url:               "https://192.168.143.1/api",
164			expectedDelegated: true,
165		},
166		{
167			name:              "hostname",
168			noProxy:           "192.168.63.0/24,192.168.143.0/24",
169			url:               "https://my-hostname/api",
170			expectedDelegated: true,
171		},
172		{
173			name:              "match second cidr",
174			noProxy:           "192.168.63.0/24,192.168.143.0/24",
175			url:               "https://192.168.143.1/api",
176			expectedDelegated: false,
177		},
178		{
179			name:              "match second cidr with host:port",
180			noProxy:           "192.168.63.0/24,192.168.143.0/24",
181			url:               "https://192.168.143.1:8443/api",
182			expectedDelegated: false,
183		},
184		{
185			name:              "IPv6 cidr",
186			noProxy:           "2001:db8::/48",
187			url:               "https://[2001:db8::1]/api",
188			expectedDelegated: false,
189		},
190		{
191			name:              "IPv6+port cidr",
192			noProxy:           "2001:db8::/48",
193			url:               "https://[2001:db8::1]:8443/api",
194			expectedDelegated: false,
195		},
196		{
197			name:              "IPv6, not matching cidr",
198			noProxy:           "2001:db8::/48",
199			url:               "https://[2001:db8:1::1]/api",
200			expectedDelegated: true,
201		},
202		{
203			name:              "IPv6+port, not matching cidr",
204			noProxy:           "2001:db8::/48",
205			url:               "https://[2001:db8:1::1]:8443/api",
206			expectedDelegated: true,
207		},
208	}
209
210	for _, test := range testCases {
211		os.Setenv("NO_PROXY", test.noProxy)
212		actualDelegated := false
213		proxyFunc := NewProxierWithNoProxyCIDR(func(req *http.Request) (*url.URL, error) {
214			actualDelegated = true
215			return nil, nil
216		})
217
218		req, err := http.NewRequest("GET", test.url, nil)
219		if err != nil {
220			t.Errorf("%s: unexpected err: %v", test.name, err)
221			continue
222		}
223		if _, err := proxyFunc(req); err != nil {
224			t.Errorf("%s: unexpected err: %v", test.name, err)
225			continue
226		}
227
228		if test.expectedDelegated != actualDelegated {
229			t.Errorf("%s: expected %v, got %v", test.name, test.expectedDelegated, actualDelegated)
230			continue
231		}
232	}
233}
234
235type fakeTLSClientConfigHolder struct {
236	called bool
237}
238
239func (f *fakeTLSClientConfigHolder) TLSClientConfig() *tls.Config {
240	f.called = true
241	return nil
242}
243func (f *fakeTLSClientConfigHolder) RoundTrip(*http.Request) (*http.Response, error) {
244	return nil, nil
245}
246
247func TestTLSClientConfigHolder(t *testing.T) {
248	rt := &fakeTLSClientConfigHolder{}
249	TLSClientConfig(rt)
250
251	if !rt.called {
252		t.Errorf("didn't find tls config")
253	}
254}
255
256func TestJoinPreservingTrailingSlash(t *testing.T) {
257	tests := []struct {
258		a    string
259		b    string
260		want string
261	}{
262		// All empty
263		{"", "", ""},
264
265		// Empty a
266		{"", "/", "/"},
267		{"", "foo", "foo"},
268		{"", "/foo", "/foo"},
269		{"", "/foo/", "/foo/"},
270
271		// Empty b
272		{"/", "", "/"},
273		{"foo", "", "foo"},
274		{"/foo", "", "/foo"},
275		{"/foo/", "", "/foo/"},
276
277		// Both populated
278		{"/", "/", "/"},
279		{"foo", "foo", "foo/foo"},
280		{"/foo", "/foo", "/foo/foo"},
281		{"/foo/", "/foo/", "/foo/foo/"},
282	}
283	for _, tt := range tests {
284		name := fmt.Sprintf("%q+%q=%q", tt.a, tt.b, tt.want)
285		t.Run(name, func(t *testing.T) {
286			if got := JoinPreservingTrailingSlash(tt.a, tt.b); got != tt.want {
287				t.Errorf("JoinPreservingTrailingSlash() = %v, want %v", got, tt.want)
288			}
289		})
290	}
291}
292
293func TestConnectWithRedirects(t *testing.T) {
294	tests := []struct {
295		desc              string
296		redirects         []string
297		method            string // initial request method, empty == GET
298		expectError       bool
299		expectedRedirects int
300		newPort           bool // special case different port test
301	}{{
302		desc:              "relative redirects allowed",
303		redirects:         []string{"/ok"},
304		expectedRedirects: 1,
305	}, {
306		desc:              "redirects to the same host are allowed",
307		redirects:         []string{"http://HOST/ok"}, // HOST replaced with server address in test
308		expectedRedirects: 1,
309	}, {
310		desc:              "POST redirects to GET",
311		method:            http.MethodPost,
312		redirects:         []string{"/ok"},
313		expectedRedirects: 1,
314	}, {
315		desc:              "PUT redirects to GET",
316		method:            http.MethodPut,
317		redirects:         []string{"/ok"},
318		expectedRedirects: 1,
319	}, {
320		desc:              "DELETE redirects to GET",
321		method:            http.MethodDelete,
322		redirects:         []string{"/ok"},
323		expectedRedirects: 1,
324	}, {
325		desc:              "9 redirects are allowed",
326		redirects:         []string{"/1", "/2", "/3", "/4", "/5", "/6", "/7", "/8", "/9"},
327		expectedRedirects: 9,
328	}, {
329		desc:        "10 redirects are forbidden",
330		redirects:   []string{"/1", "/2", "/3", "/4", "/5", "/6", "/7", "/8", "/9", "/10"},
331		expectError: true,
332	}, {
333		desc:              "redirect to different host are prevented",
334		redirects:         []string{"http://example.com/foo"},
335		expectedRedirects: 0,
336	}, {
337		desc:              "multiple redirect to different host forbidden",
338		redirects:         []string{"/1", "/2", "/3", "http://example.com/foo"},
339		expectedRedirects: 3,
340	}, {
341		desc:              "redirect to different port is allowed",
342		redirects:         []string{"http://HOST/foo"},
343		expectedRedirects: 1,
344		newPort:           true,
345	}}
346
347	const resultString = "Test output"
348	for _, test := range tests {
349		t.Run(test.desc, func(t *testing.T) {
350			redirectCount := 0
351			s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
352				// Verify redirect request.
353				if redirectCount > 0 {
354					expectedURL, err := url.Parse(test.redirects[redirectCount-1])
355					require.NoError(t, err, "test URL error")
356					assert.Equal(t, req.URL.Path, expectedURL.Path, "unknown redirect path")
357					assert.Equal(t, http.MethodGet, req.Method, "redirects must always be GET")
358				}
359				if redirectCount < len(test.redirects) {
360					http.Redirect(w, req, test.redirects[redirectCount], http.StatusFound)
361					redirectCount++
362				} else if redirectCount == len(test.redirects) {
363					w.Write([]byte(resultString))
364				} else {
365					t.Errorf("unexpected number of redirects %d to %s", redirectCount, req.URL.String())
366				}
367			}))
368			defer s.Close()
369
370			u, err := url.Parse(s.URL)
371			require.NoError(t, err, "Error parsing server URL")
372			host := u.Host
373
374			// Special case new-port test with a secondary server.
375			if test.newPort {
376				s2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
377					w.Write([]byte(resultString))
378				}))
379				defer s2.Close()
380				u2, err := url.Parse(s2.URL)
381				require.NoError(t, err, "Error parsing secondary server URL")
382
383				// Sanity check: secondary server uses same hostname, different port.
384				require.Equal(t, u.Hostname(), u2.Hostname(), "sanity check: same hostname")
385				require.NotEqual(t, u.Port(), u2.Port(), "sanity check: different port")
386
387				// Redirect to the secondary server.
388				host = u2.Host
389
390			}
391
392			// Update redirect URLs with actual host.
393			for i := range test.redirects {
394				test.redirects[i] = strings.Replace(test.redirects[i], "HOST", host, 1)
395			}
396
397			method := test.method
398			if method == "" {
399				method = http.MethodGet
400			}
401
402			netdialer := &net.Dialer{
403				Timeout:   wait.ForeverTestTimeout,
404				KeepAlive: wait.ForeverTestTimeout,
405			}
406			dialer := DialerFunc(func(req *http.Request) (net.Conn, error) {
407				conn, err := netdialer.Dial("tcp", req.URL.Host)
408				if err != nil {
409					return conn, err
410				}
411				if err = req.Write(conn); err != nil {
412					require.NoError(t, conn.Close())
413					return nil, fmt.Errorf("error sending request: %v", err)
414				}
415				return conn, err
416			})
417			conn, rawResponse, err := ConnectWithRedirects(method, u, http.Header{} /*body*/, nil, dialer, true)
418			if test.expectError {
419				require.Error(t, err, "expected request error")
420				return
421			}
422
423			require.NoError(t, err, "unexpected request error")
424			assert.NoError(t, conn.Close(), "error closing connection")
425
426			resp, err := http.ReadResponse(bufio.NewReader(bytes.NewReader(rawResponse)), nil)
427			require.NoError(t, err, "unexpected request error")
428
429			result, err := ioutil.ReadAll(resp.Body)
430			require.NoError(t, resp.Body.Close())
431			if test.expectedRedirects < len(test.redirects) {
432				// Expect the last redirect to be returned.
433				assert.Equal(t, http.StatusFound, resp.StatusCode, "Final response is not a redirect")
434				assert.Equal(t, test.redirects[len(test.redirects)-1], resp.Header.Get("Location"))
435				assert.NotEqual(t, resultString, string(result), "wrong content")
436			} else {
437				assert.Equal(t, resultString, string(result), "stream content does not match")
438			}
439		})
440	}
441}
442
443func TestAllowsHTTP2(t *testing.T) {
444	testcases := []struct {
445		Name         string
446		Transport    *http.Transport
447		ExpectAllows bool
448	}{
449		{
450			Name:         "empty",
451			Transport:    &http.Transport{},
452			ExpectAllows: true,
453		},
454		{
455			Name:         "empty tlsconfig",
456			Transport:    &http.Transport{TLSClientConfig: &tls.Config{}},
457			ExpectAllows: true,
458		},
459		{
460			Name:         "zero-length NextProtos",
461			Transport:    &http.Transport{TLSClientConfig: &tls.Config{NextProtos: []string{}}},
462			ExpectAllows: true,
463		},
464		{
465			Name:         "includes h2 in NextProtos after",
466			Transport:    &http.Transport{TLSClientConfig: &tls.Config{NextProtos: []string{"http/1.1", "h2"}}},
467			ExpectAllows: true,
468		},
469		{
470			Name:         "includes h2 in NextProtos before",
471			Transport:    &http.Transport{TLSClientConfig: &tls.Config{NextProtos: []string{"h2", "http/1.1"}}},
472			ExpectAllows: true,
473		},
474		{
475			Name:         "includes h2 in NextProtos between",
476			Transport:    &http.Transport{TLSClientConfig: &tls.Config{NextProtos: []string{"http/1.1", "h2", "h3"}}},
477			ExpectAllows: true,
478		},
479		{
480			Name:         "excludes h2 in NextProtos",
481			Transport:    &http.Transport{TLSClientConfig: &tls.Config{NextProtos: []string{"http/1.1"}}},
482			ExpectAllows: false,
483		},
484	}
485
486	for _, tc := range testcases {
487		t.Run(tc.Name, func(t *testing.T) {
488			allows := allowsHTTP2(tc.Transport)
489			if allows != tc.ExpectAllows {
490				t.Errorf("expected %v, got %v", tc.ExpectAllows, allows)
491			}
492		})
493	}
494}
495