1/*
2Copyright 2015 The Kubernetes Authors.
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8    http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15*/
16
17package transport
18
19import (
20	"context"
21	"crypto/tls"
22	"errors"
23	"fmt"
24	"net/http"
25	"testing"
26)
27
// PEM-encoded TLS fixtures shared by the tests in this file.
//
// NOTE(review): the embedded certificates appear to have a validity window
// of 2015-01-15 to 2016-01-15, so they are long expired. The tests in this
// file only load and parse this material (via New and tls.X509KeyPair) and
// never complete a TLS handshake that would check dates — re-check before
// reusing these fixtures in a test that does.
const (
	// rootCACert is a PEM "CERTIFICATE" block supplied as TLSConfig.CAData;
	// cases that set it expect New to produce a non-nil RootCAs pool.
	rootCACert = `-----BEGIN CERTIFICATE-----
MIIC4DCCAcqgAwIBAgIBATALBgkqhkiG9w0BAQswIzEhMB8GA1UEAwwYMTAuMTMu
MTI5LjEwNkAxNDIxMzU5MDU4MB4XDTE1MDExNTIxNTczN1oXDTE2MDExNTIxNTcz
OFowIzEhMB8GA1UEAwwYMTAuMTMuMTI5LjEwNkAxNDIxMzU5MDU4MIIBIjANBgkq
hkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAunDRXGwsiYWGFDlWH6kjGun+PshDGeZX
xtx9lUnL8pIRWH3wX6f13PO9sktaOWW0T0mlo6k2bMlSLlSZgG9H6og0W6gLS3vq
s4VavZ6DbXIwemZG2vbRwsvR+t4G6Nbwelm6F8RFnA1Fwt428pavmNQ/wgYzo+T1
1eS+HiN4ACnSoDSx3QRWcgBkB1g6VReofVjx63i0J+w8Q/41L9GUuLqquFxu6ZnH
60vTB55lHgFiDLjA1FkEz2dGvGh/wtnFlRvjaPC54JH2K1mPYAUXTreoeJtLJKX0
ycoiyB24+zGCniUmgIsmQWRPaOPircexCp1BOeze82BT1LCZNTVaxQIDAQABoyMw
ITAOBgNVHQ8BAf8EBAMCAKQwDwYDVR0TAQH/BAUwAwEB/zALBgkqhkiG9w0BAQsD
ggEBADMxsUuAFlsYDpF4fRCzXXwrhbtj4oQwcHpbu+rnOPHCZupiafzZpDu+rw4x
YGPnCb594bRTQn4pAu3Ac18NbLD5pV3uioAkv8oPkgr8aUhXqiv7KdDiaWm6sbAL
EHiXVBBAFvQws10HMqMoKtO8f1XDNAUkWduakR/U6yMgvOPwS7xl0eUTqyRB6zGb
K55q2dejiFWaFqB/y78txzvz6UlOZKE44g2JAVoJVM6kGaxh33q8/FmrL4kuN3ut
W+MmJCVDvd4eEqPwbp7146ZWTqpIJ8lvA6wuChtqV8lhAPka2hD/LMqY8iXNmfXD
uml0obOEy+ON91k+SWTJ3ggmF/U=
-----END CERTIFICATE-----`

	// certData is a PEM client certificate supplied as TLSConfig.CertData
	// and paired with keyData. NOTE(review): its issuer name appears to match
	// rootCACert's subject, i.e. it is presumably signed by that CA — confirm
	// if a test ever relies on chain verification.
	certData = `-----BEGIN CERTIFICATE-----
MIIC6jCCAdSgAwIBAgIBCzALBgkqhkiG9w0BAQswIzEhMB8GA1UEAwwYMTAuMTMu
MTI5LjEwNkAxNDIxMzU5MDU4MB4XDTE1MDExNTIyMDEzMVoXDTE2MDExNTIyMDEz
MlowGzEZMBcGA1UEAxMQb3BlbnNoaWZ0LWNsaWVudDCCASIwDQYJKoZIhvcNAQEB
BQADggEPADCCAQoCggEBAKtdhz0+uCLXw5cSYns9rU/XifFSpb/x24WDdrm72S/v
b9BPYsAStiP148buylr1SOuNi8sTAZmlVDDIpIVwMLff+o2rKYDicn9fjbrTxTOj
lI4pHJBH+JU3AJ0tbajupioh70jwFS0oYpwtneg2zcnE2Z4l6mhrj2okrc5Q1/X2
I2HChtIU4JYTisObtin10QKJX01CLfYXJLa8upWzKZ4/GOcHG+eAV3jXWoXidtjb
1Usw70amoTZ6mIVCkiu1QwCoa8+ycojGfZhvqMsAp1536ZcCul+Na+AbCv4zKS7F
kQQaImVrXdUiFansIoofGlw/JNuoKK6ssVpS5Ic3pgcCAwEAAaM1MDMwDgYDVR0P
AQH/BAQDAgCgMBMGA1UdJQQMMAoGCCsGAQUFBwMCMAwGA1UdEwEB/wQCMAAwCwYJ
KoZIhvcNAQELA4IBAQCKLREH7bXtXtZ+8vI6cjD7W3QikiArGqbl36bAhhWsJLp/
p/ndKz39iFNaiZ3GlwIURWOOKx3y3GA0x9m8FR+Llthf0EQ8sUjnwaknWs0Y6DQ3
jjPFZOpV3KPCFrdMJ3++E3MgwFC/Ih/N2ebFX9EcV9Vcc6oVWMdwT0fsrhu683rq
6GSR/3iVX1G/pmOiuaR0fNUaCyCfYrnI4zHBDgSfnlm3vIvN2lrsR/DQBakNL8DJ
HBgKxMGeUPoneBv+c8DMXIL0EhaFXRlBv9QW45/GiAIOuyFJ0i6hCtGZpJjq4OpQ
BRjCI+izPzFTjsxD4aORE+WOkyWFCGPWKfNejfw0
-----END CERTIFICATE-----`

	// keyData is a PEM "RSA PRIVATE KEY" matching certData. It is supplied as
	// TLSConfig.KeyData, and the "callback cert and key" case expects
	// tls.X509KeyPair(certData, keyData) to succeed.
	keyData = `-----BEGIN RSA PRIVATE KEY-----
MIIEowIBAAKCAQEAq12HPT64ItfDlxJiez2tT9eJ8VKlv/HbhYN2ubvZL+9v0E9i
wBK2I/Xjxu7KWvVI642LyxMBmaVUMMikhXAwt9/6jaspgOJyf1+NutPFM6OUjikc
kEf4lTcAnS1tqO6mKiHvSPAVLShinC2d6DbNycTZniXqaGuPaiStzlDX9fYjYcKG
0hTglhOKw5u2KfXRAolfTUIt9hcktry6lbMpnj8Y5wcb54BXeNdaheJ22NvVSzDv
RqahNnqYhUKSK7VDAKhrz7JyiMZ9mG+oywCnXnfplwK6X41r4BsK/jMpLsWRBBoi
ZWtd1SIVqewiih8aXD8k26gorqyxWlLkhzemBwIDAQABAoIBAD2XYRs3JrGHQUpU
FkdbVKZkvrSY0vAZOqBTLuH0zUv4UATb8487anGkWBjRDLQCgxH+jucPTrztekQK
aW94clo0S3aNtV4YhbSYIHWs1a0It0UdK6ID7CmdWkAj6s0T8W8lQT7C46mWYVLm
5mFnCTHi6aB42jZrqmEpC7sivWwuU0xqj3Ml8kkxQCGmyc9JjmCB4OrFFC8NNt6M
ObvQkUI6Z3nO4phTbpxkE1/9dT0MmPIF7GhHVzJMS+EyyRYUDllZ0wvVSOM3qZT0
JMUaBerkNwm9foKJ1+dv2nMKZZbJajv7suUDCfU44mVeaEO+4kmTKSGCGjjTBGkr
7L1ySDECgYEA5ElIMhpdBzIivCuBIH8LlUeuzd93pqssO1G2Xg0jHtfM4tz7fyeI
cr90dc8gpli24dkSxzLeg3Tn3wIj/Bu64m2TpZPZEIlukYvgdgArmRIPQVxerYey
OkrfTNkxU1HXsYjLCdGcGXs5lmb+K/kuTcFxaMOs7jZi7La+jEONwf8CgYEAwCs/
rUOOA0klDsWWisbivOiNPII79c9McZCNBqncCBfMUoiGe8uWDEO4TFHN60vFuVk9
8PkwpCfvaBUX+ajvbafIfHxsnfk1M04WLGCeqQ/ym5Q4sQoQOcC1b1y9qc/xEWfg
nIUuia0ukYRpl7qQa3tNg+BNFyjypW8zukUAC/kCgYB1/Kojuxx5q5/oQVPrx73k
2bevD+B3c+DYh9MJqSCNwFtUpYIWpggPxoQan4LwdsmO0PKzocb/ilyNFj4i/vII
NToqSc/WjDFpaDIKyuu9oWfhECye45NqLWhb/6VOuu4QA/Nsj7luMhIBehnEAHW+
GkzTKM8oD1PxpEG3nPKXYQKBgQC6AuMPRt3XBl1NkCrpSBy/uObFlFaP2Enpf39S
3OZ0Gv0XQrnSaL1kP8TMcz68rMrGX8DaWYsgytstR4W+jyy7WvZwsUu+GjTJ5aMG
77uEcEBpIi9CBzivfn7hPccE8ZgqPf+n4i6q66yxBJflW5xhvafJqDtW2LcPNbW/
bvzdmQKBgExALRUXpq+5dbmkdXBHtvXdRDZ6rVmrnjy4nI5bPw+1GqQqk6uAR6B/
F6NmLCQOO4PDG/cuatNHIr2FrwTmGdEL6ObLUGWn9Oer9gJhHVqqsY5I4sEPo4XX
stR0Yiw0buV6DL/moUO0HIM9Bjh96HJp+LxiIS6UCdIhMPp5HoQa
-----END RSA PRIVATE KEY-----`
)
95
96func TestNew(t *testing.T) {
97	testCases := map[string]struct {
98		Config       *Config
99		Err          bool
100		TLS          bool
101		TLSCert      bool
102		TLSErr       bool
103		Default      bool
104		Insecure     bool
105		DefaultRoots bool
106	}{
107		"default transport": {
108			Default: true,
109			Config:  &Config{},
110		},
111
112		"insecure": {
113			TLS:          true,
114			Insecure:     true,
115			DefaultRoots: true,
116			Config: &Config{TLS: TLSConfig{
117				Insecure: true,
118			}},
119		},
120
121		"server name": {
122			TLS:          true,
123			DefaultRoots: true,
124			Config: &Config{TLS: TLSConfig{
125				ServerName: "foo",
126			}},
127		},
128
129		"ca transport": {
130			TLS: true,
131			Config: &Config{
132				TLS: TLSConfig{
133					CAData: []byte(rootCACert),
134				},
135			},
136		},
137		"bad ca file transport": {
138			Err: true,
139			Config: &Config{
140				TLS: TLSConfig{
141					CAFile: "invalid file",
142				},
143			},
144		},
145		"bad ca data transport": {
146			Err: true,
147			Config: &Config{
148				TLS: TLSConfig{
149					CAData: []byte(rootCACert + "this is not valid"),
150				},
151			},
152		},
153		"ca data overriding bad ca file transport": {
154			TLS: true,
155			Config: &Config{
156				TLS: TLSConfig{
157					CAData: []byte(rootCACert),
158					CAFile: "invalid file",
159				},
160			},
161		},
162
163		"cert transport": {
164			TLS:     true,
165			TLSCert: true,
166			Config: &Config{
167				TLS: TLSConfig{
168					CAData:   []byte(rootCACert),
169					CertData: []byte(certData),
170					KeyData:  []byte(keyData),
171				},
172			},
173		},
174		"bad cert data transport": {
175			Err: true,
176			Config: &Config{
177				TLS: TLSConfig{
178					CAData:   []byte(rootCACert),
179					CertData: []byte(certData),
180					KeyData:  []byte("bad key data"),
181				},
182			},
183		},
184		"bad file cert transport": {
185			Err: true,
186			Config: &Config{
187				TLS: TLSConfig{
188					CAData:   []byte(rootCACert),
189					CertData: []byte(certData),
190					KeyFile:  "invalid file",
191				},
192			},
193		},
194		"key data overriding bad file cert transport": {
195			TLS:     true,
196			TLSCert: true,
197			Config: &Config{
198				TLS: TLSConfig{
199					CAData:   []byte(rootCACert),
200					CertData: []byte(certData),
201					KeyData:  []byte(keyData),
202					KeyFile:  "invalid file",
203				},
204			},
205		},
206		"callback cert and key": {
207			TLS:     true,
208			TLSCert: true,
209			Config: &Config{
210				TLS: TLSConfig{
211					CAData: []byte(rootCACert),
212					GetCert: func() (*tls.Certificate, error) {
213						crt, err := tls.X509KeyPair([]byte(certData), []byte(keyData))
214						return &crt, err
215					},
216				},
217			},
218		},
219		"cert callback error": {
220			TLS:     true,
221			TLSCert: true,
222			TLSErr:  true,
223			Config: &Config{
224				TLS: TLSConfig{
225					CAData: []byte(rootCACert),
226					GetCert: func() (*tls.Certificate, error) {
227						return nil, errors.New("GetCert failure")
228					},
229				},
230			},
231		},
232		"cert data overrides empty callback result": {
233			TLS:     true,
234			TLSCert: true,
235			Config: &Config{
236				TLS: TLSConfig{
237					CAData: []byte(rootCACert),
238					GetCert: func() (*tls.Certificate, error) {
239						return nil, nil
240					},
241					CertData: []byte(certData),
242					KeyData:  []byte(keyData),
243				},
244			},
245		},
246		"callback returns nothing": {
247			TLS:     true,
248			TLSCert: true,
249			Config: &Config{
250				TLS: TLSConfig{
251					CAData: []byte(rootCACert),
252					GetCert: func() (*tls.Certificate, error) {
253						return nil, nil
254					},
255				},
256			},
257		},
258	}
259	for k, testCase := range testCases {
260		t.Run(k, func(t *testing.T) {
261			rt, err := New(testCase.Config)
262			switch {
263			case testCase.Err && err == nil:
264				t.Fatal("unexpected non-error")
265			case !testCase.Err && err != nil:
266				t.Fatalf("unexpected error: %v", err)
267			}
268			if testCase.Err {
269				return
270			}
271
272			switch {
273			case testCase.Default && rt != http.DefaultTransport:
274				t.Fatalf("got %#v, expected the default transport", rt)
275			case !testCase.Default && rt == http.DefaultTransport:
276				t.Fatalf("got %#v, expected non-default transport", rt)
277			}
278
279			// We only know how to check TLSConfig on http.Transports
280			transport := rt.(*http.Transport)
281			switch {
282			case testCase.TLS && transport.TLSClientConfig == nil:
283				t.Fatalf("got %#v, expected TLSClientConfig", transport)
284			case !testCase.TLS && transport.TLSClientConfig != nil:
285				t.Fatalf("got %#v, expected no TLSClientConfig", transport)
286			}
287			if !testCase.TLS {
288				return
289			}
290
291			switch {
292			case testCase.DefaultRoots && transport.TLSClientConfig.RootCAs != nil:
293				t.Fatalf("got %#v, expected nil root CAs", transport.TLSClientConfig.RootCAs)
294			case !testCase.DefaultRoots && transport.TLSClientConfig.RootCAs == nil:
295				t.Fatalf("got %#v, expected non-nil root CAs", transport.TLSClientConfig.RootCAs)
296			}
297
298			switch {
299			case testCase.Insecure != transport.TLSClientConfig.InsecureSkipVerify:
300				t.Fatalf("got %#v, expected %#v", transport.TLSClientConfig.InsecureSkipVerify, testCase.Insecure)
301			}
302
303			switch {
304			case testCase.TLSCert && transport.TLSClientConfig.GetClientCertificate == nil:
305				t.Fatalf("got %#v, expected TLSClientConfig.GetClientCertificate", transport.TLSClientConfig)
306			case !testCase.TLSCert && transport.TLSClientConfig.GetClientCertificate != nil:
307				t.Fatalf("got %#v, expected no TLSClientConfig.GetClientCertificate", transport.TLSClientConfig)
308			}
309			if !testCase.TLSCert {
310				return
311			}
312
313			_, err = transport.TLSClientConfig.GetClientCertificate(nil)
314			switch {
315			case testCase.TLSErr && err == nil:
316				t.Error("got nil error from GetClientCertificate, expected non-nil")
317			case !testCase.TLSErr && err != nil:
318				t.Errorf("got error from GetClientCertificate: %q, expected nil", err)
319			}
320		})
321	}
322}
323
// fakeRoundTripper is an http.RoundTripper test double: it remembers the most
// recent request it was handed and answers every call with the canned Resp
// and Err values.
type fakeRoundTripper struct {
	Req  *http.Request  // last request passed to RoundTrip
	Resp *http.Response // response returned from every call
	Err  error          // error returned from every call
}

// RoundTrip records req in f.Req and replies with f.Resp and f.Err.
func (f *fakeRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	f.Req = req
	resp, err := f.Resp, f.Err
	return resp, err
}
334
335type chainRoundTripper struct {
336	rt    http.RoundTripper
337	value string
338}
339
340func testChain(value string) WrapperFunc {
341	return func(rt http.RoundTripper) http.RoundTripper {
342		return &chainRoundTripper{rt: rt, value: value}
343	}
344}
345
346func (rt *chainRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
347	resp, err := rt.rt.RoundTrip(req)
348	if resp != nil {
349		if resp.Header == nil {
350			resp.Header = make(http.Header)
351		}
352		resp.Header.Set("Value", resp.Header.Get("Value")+rt.value)
353	}
354	return resp, err
355}
356
357func TestWrappers(t *testing.T) {
358	resp1 := &http.Response{}
359	wrapperResp1 := func(rt http.RoundTripper) http.RoundTripper {
360		return &fakeRoundTripper{Resp: resp1}
361	}
362	resp2 := &http.Response{}
363	wrapperResp2 := func(rt http.RoundTripper) http.RoundTripper {
364		return &fakeRoundTripper{Resp: resp2}
365	}
366
367	tests := []struct {
368		name    string
369		fns     []WrapperFunc
370		wantNil bool
371		want    func(*http.Response) bool
372	}{
373		{fns: []WrapperFunc{}, wantNil: true},
374		{fns: []WrapperFunc{nil, nil}, wantNil: true},
375		{fns: []WrapperFunc{nil}, wantNil: false},
376
377		{fns: []WrapperFunc{nil, wrapperResp1}, want: func(resp *http.Response) bool { return resp == resp1 }},
378		{fns: []WrapperFunc{wrapperResp1, nil}, want: func(resp *http.Response) bool { return resp == resp1 }},
379		{fns: []WrapperFunc{nil, wrapperResp1, nil}, want: func(resp *http.Response) bool { return resp == resp1 }},
380		{fns: []WrapperFunc{nil, wrapperResp1, wrapperResp2}, want: func(resp *http.Response) bool { return resp == resp2 }},
381		{fns: []WrapperFunc{wrapperResp1, wrapperResp2}, want: func(resp *http.Response) bool { return resp == resp2 }},
382		{fns: []WrapperFunc{wrapperResp2, wrapperResp1}, want: func(resp *http.Response) bool { return resp == resp1 }},
383
384		{fns: []WrapperFunc{testChain("1")}, want: func(resp *http.Response) bool { return resp.Header.Get("Value") == "1" }},
385		{fns: []WrapperFunc{testChain("1"), testChain("2")}, want: func(resp *http.Response) bool { return resp.Header.Get("Value") == "12" }},
386		{fns: []WrapperFunc{testChain("2"), testChain("1")}, want: func(resp *http.Response) bool { return resp.Header.Get("Value") == "21" }},
387		{fns: []WrapperFunc{testChain("1"), testChain("2"), testChain("3")}, want: func(resp *http.Response) bool { return resp.Header.Get("Value") == "123" }},
388	}
389	for _, tt := range tests {
390		t.Run(tt.name, func(t *testing.T) {
391			got := Wrappers(tt.fns...)
392			if got == nil != tt.wantNil {
393				t.Errorf("Wrappers() = %v", got)
394				return
395			}
396			if got == nil {
397				return
398			}
399
400			rt := &fakeRoundTripper{Resp: &http.Response{}}
401			nested := got(rt)
402			req := &http.Request{}
403			resp, _ := nested.RoundTrip(req)
404			if tt.want != nil && !tt.want(resp) {
405				t.Errorf("unexpected response: %#v", resp)
406			}
407		})
408	}
409}
410
411func Test_contextCanceller_RoundTrip(t *testing.T) {
412	tests := []struct {
413		name string
414		open bool
415		want bool
416	}{
417		{name: "open context should call nested round tripper", open: true, want: true},
418		{name: "closed context should return a known error", open: false, want: false},
419	}
420	for _, tt := range tests {
421		t.Run(tt.name, func(t *testing.T) {
422			req := &http.Request{}
423			rt := &fakeRoundTripper{Resp: &http.Response{}}
424			ctx := context.Background()
425			if !tt.open {
426				c, fn := context.WithCancel(ctx)
427				fn()
428				ctx = c
429			}
430			errTesting := fmt.Errorf("testing")
431			b := &contextCanceller{
432				rt:  rt,
433				ctx: ctx,
434				err: errTesting,
435			}
436			got, err := b.RoundTrip(req)
437			if tt.want {
438				if err != nil {
439					t.Errorf("unexpected error: %v", err)
440				}
441				if got != rt.Resp {
442					t.Errorf("wanted response")
443				}
444				if req != rt.Req {
445					t.Errorf("expect nested call")
446				}
447			} else {
448				if err != errTesting {
449					t.Errorf("unexpected error: %v", err)
450				}
451				if got != nil {
452					t.Errorf("wanted no response")
453				}
454				if rt.Req != nil {
455					t.Errorf("want no nested call")
456				}
457			}
458		})
459	}
460}
461