1/*
2Copyright 2015 The Kubernetes Authors.
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8    http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15*/
16
17package transport
18
19import (
20	"context"
21	"crypto/tls"
22	"errors"
23	"fmt"
24	"net/http"
25	"testing"
26)
27
// PEM-encoded fixtures shared by the test cases in this file.
const (
	// rootCACert is a PEM CERTIFICATE block; TestNew supplies it as
	// TLSConfig.CAData.
	rootCACert = `-----BEGIN CERTIFICATE-----
MIIC4DCCAcqgAwIBAgIBATALBgkqhkiG9w0BAQswIzEhMB8GA1UEAwwYMTAuMTMu
MTI5LjEwNkAxNDIxMzU5MDU4MB4XDTE1MDExNTIxNTczN1oXDTE2MDExNTIxNTcz
OFowIzEhMB8GA1UEAwwYMTAuMTMuMTI5LjEwNkAxNDIxMzU5MDU4MIIBIjANBgkq
hkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAunDRXGwsiYWGFDlWH6kjGun+PshDGeZX
xtx9lUnL8pIRWH3wX6f13PO9sktaOWW0T0mlo6k2bMlSLlSZgG9H6og0W6gLS3vq
s4VavZ6DbXIwemZG2vbRwsvR+t4G6Nbwelm6F8RFnA1Fwt428pavmNQ/wgYzo+T1
1eS+HiN4ACnSoDSx3QRWcgBkB1g6VReofVjx63i0J+w8Q/41L9GUuLqquFxu6ZnH
60vTB55lHgFiDLjA1FkEz2dGvGh/wtnFlRvjaPC54JH2K1mPYAUXTreoeJtLJKX0
ycoiyB24+zGCniUmgIsmQWRPaOPircexCp1BOeze82BT1LCZNTVaxQIDAQABoyMw
ITAOBgNVHQ8BAf8EBAMCAKQwDwYDVR0TAQH/BAUwAwEB/zALBgkqhkiG9w0BAQsD
ggEBADMxsUuAFlsYDpF4fRCzXXwrhbtj4oQwcHpbu+rnOPHCZupiafzZpDu+rw4x
YGPnCb594bRTQn4pAu3Ac18NbLD5pV3uioAkv8oPkgr8aUhXqiv7KdDiaWm6sbAL
EHiXVBBAFvQws10HMqMoKtO8f1XDNAUkWduakR/U6yMgvOPwS7xl0eUTqyRB6zGb
K55q2dejiFWaFqB/y78txzvz6UlOZKE44g2JAVoJVM6kGaxh33q8/FmrL4kuN3ut
W+MmJCVDvd4eEqPwbp7146ZWTqpIJ8lvA6wuChtqV8lhAPka2hD/LMqY8iXNmfXD
uml0obOEy+ON91k+SWTJ3ggmF/U=
-----END CERTIFICATE-----`

	// certData is a PEM CERTIFICATE block; TestNew supplies it as
	// TLSConfig.CertData and pairs it with keyData in tls.X509KeyPair.
	certData = `-----BEGIN CERTIFICATE-----
MIIC6jCCAdSgAwIBAgIBCzALBgkqhkiG9w0BAQswIzEhMB8GA1UEAwwYMTAuMTMu
MTI5LjEwNkAxNDIxMzU5MDU4MB4XDTE1MDExNTIyMDEzMVoXDTE2MDExNTIyMDEz
MlowGzEZMBcGA1UEAxMQb3BlbnNoaWZ0LWNsaWVudDCCASIwDQYJKoZIhvcNAQEB
BQADggEPADCCAQoCggEBAKtdhz0+uCLXw5cSYns9rU/XifFSpb/x24WDdrm72S/v
b9BPYsAStiP148buylr1SOuNi8sTAZmlVDDIpIVwMLff+o2rKYDicn9fjbrTxTOj
lI4pHJBH+JU3AJ0tbajupioh70jwFS0oYpwtneg2zcnE2Z4l6mhrj2okrc5Q1/X2
I2HChtIU4JYTisObtin10QKJX01CLfYXJLa8upWzKZ4/GOcHG+eAV3jXWoXidtjb
1Usw70amoTZ6mIVCkiu1QwCoa8+ycojGfZhvqMsAp1536ZcCul+Na+AbCv4zKS7F
kQQaImVrXdUiFansIoofGlw/JNuoKK6ssVpS5Ic3pgcCAwEAAaM1MDMwDgYDVR0P
AQH/BAQDAgCgMBMGA1UdJQQMMAoGCCsGAQUFBwMCMAwGA1UdEwEB/wQCMAAwCwYJ
KoZIhvcNAQELA4IBAQCKLREH7bXtXtZ+8vI6cjD7W3QikiArGqbl36bAhhWsJLp/
p/ndKz39iFNaiZ3GlwIURWOOKx3y3GA0x9m8FR+Llthf0EQ8sUjnwaknWs0Y6DQ3
jjPFZOpV3KPCFrdMJ3++E3MgwFC/Ih/N2ebFX9EcV9Vcc6oVWMdwT0fsrhu683rq
6GSR/3iVX1G/pmOiuaR0fNUaCyCfYrnI4zHBDgSfnlm3vIvN2lrsR/DQBakNL8DJ
HBgKxMGeUPoneBv+c8DMXIL0EhaFXRlBv9QW45/GiAIOuyFJ0i6hCtGZpJjq4OpQ
BRjCI+izPzFTjsxD4aORE+WOkyWFCGPWKfNejfw0
-----END CERTIFICATE-----`

	// keyData is a PEM RSA PRIVATE KEY block; TestNew supplies it as
	// TLSConfig.KeyData alongside certData.
	keyData = `-----BEGIN RSA PRIVATE KEY-----
MIIEowIBAAKCAQEAq12HPT64ItfDlxJiez2tT9eJ8VKlv/HbhYN2ubvZL+9v0E9i
wBK2I/Xjxu7KWvVI642LyxMBmaVUMMikhXAwt9/6jaspgOJyf1+NutPFM6OUjikc
kEf4lTcAnS1tqO6mKiHvSPAVLShinC2d6DbNycTZniXqaGuPaiStzlDX9fYjYcKG
0hTglhOKw5u2KfXRAolfTUIt9hcktry6lbMpnj8Y5wcb54BXeNdaheJ22NvVSzDv
RqahNnqYhUKSK7VDAKhrz7JyiMZ9mG+oywCnXnfplwK6X41r4BsK/jMpLsWRBBoi
ZWtd1SIVqewiih8aXD8k26gorqyxWlLkhzemBwIDAQABAoIBAD2XYRs3JrGHQUpU
FkdbVKZkvrSY0vAZOqBTLuH0zUv4UATb8487anGkWBjRDLQCgxH+jucPTrztekQK
aW94clo0S3aNtV4YhbSYIHWs1a0It0UdK6ID7CmdWkAj6s0T8W8lQT7C46mWYVLm
5mFnCTHi6aB42jZrqmEpC7sivWwuU0xqj3Ml8kkxQCGmyc9JjmCB4OrFFC8NNt6M
ObvQkUI6Z3nO4phTbpxkE1/9dT0MmPIF7GhHVzJMS+EyyRYUDllZ0wvVSOM3qZT0
JMUaBerkNwm9foKJ1+dv2nMKZZbJajv7suUDCfU44mVeaEO+4kmTKSGCGjjTBGkr
7L1ySDECgYEA5ElIMhpdBzIivCuBIH8LlUeuzd93pqssO1G2Xg0jHtfM4tz7fyeI
cr90dc8gpli24dkSxzLeg3Tn3wIj/Bu64m2TpZPZEIlukYvgdgArmRIPQVxerYey
OkrfTNkxU1HXsYjLCdGcGXs5lmb+K/kuTcFxaMOs7jZi7La+jEONwf8CgYEAwCs/
rUOOA0klDsWWisbivOiNPII79c9McZCNBqncCBfMUoiGe8uWDEO4TFHN60vFuVk9
8PkwpCfvaBUX+ajvbafIfHxsnfk1M04WLGCeqQ/ym5Q4sQoQOcC1b1y9qc/xEWfg
nIUuia0ukYRpl7qQa3tNg+BNFyjypW8zukUAC/kCgYB1/Kojuxx5q5/oQVPrx73k
2bevD+B3c+DYh9MJqSCNwFtUpYIWpggPxoQan4LwdsmO0PKzocb/ilyNFj4i/vII
NToqSc/WjDFpaDIKyuu9oWfhECye45NqLWhb/6VOuu4QA/Nsj7luMhIBehnEAHW+
GkzTKM8oD1PxpEG3nPKXYQKBgQC6AuMPRt3XBl1NkCrpSBy/uObFlFaP2Enpf39S
3OZ0Gv0XQrnSaL1kP8TMcz68rMrGX8DaWYsgytstR4W+jyy7WvZwsUu+GjTJ5aMG
77uEcEBpIi9CBzivfn7hPccE8ZgqPf+n4i6q66yxBJflW5xhvafJqDtW2LcPNbW/
bvzdmQKBgExALRUXpq+5dbmkdXBHtvXdRDZ6rVmrnjy4nI5bPw+1GqQqk6uAR6B/
F6NmLCQOO4PDG/cuatNHIr2FrwTmGdEL6ObLUGWn9Oer9gJhHVqqsY5I4sEPo4XX
stR0Yiw0buV6DL/moUO0HIM9Bjh96HJp+LxiIS6UCdIhMPp5HoQa
-----END RSA PRIVATE KEY-----`
)
95
96func TestNew(t *testing.T) {
97	testCases := map[string]struct {
98		Config       *Config
99		Err          bool
100		TLS          bool
101		TLSCert      bool
102		TLSErr       bool
103		Default      bool
104		Insecure     bool
105		DefaultRoots bool
106	}{
107		"default transport": {
108			Default: true,
109			Config:  &Config{},
110		},
111
112		"insecure": {
113			TLS:          true,
114			Insecure:     true,
115			DefaultRoots: true,
116			Config: &Config{TLS: TLSConfig{
117				Insecure: true,
118			}},
119		},
120
121		"server name": {
122			TLS:          true,
123			DefaultRoots: true,
124			Config: &Config{TLS: TLSConfig{
125				ServerName: "foo",
126			}},
127		},
128
129		"ca transport": {
130			TLS: true,
131			Config: &Config{
132				TLS: TLSConfig{
133					CAData: []byte(rootCACert),
134				},
135			},
136		},
137		"bad ca file transport": {
138			Err: true,
139			Config: &Config{
140				TLS: TLSConfig{
141					CAFile: "invalid file",
142				},
143			},
144		},
145		"ca data overriding bad ca file transport": {
146			TLS: true,
147			Config: &Config{
148				TLS: TLSConfig{
149					CAData: []byte(rootCACert),
150					CAFile: "invalid file",
151				},
152			},
153		},
154
155		"cert transport": {
156			TLS:     true,
157			TLSCert: true,
158			Config: &Config{
159				TLS: TLSConfig{
160					CAData:   []byte(rootCACert),
161					CertData: []byte(certData),
162					KeyData:  []byte(keyData),
163				},
164			},
165		},
166		"bad cert data transport": {
167			Err: true,
168			Config: &Config{
169				TLS: TLSConfig{
170					CAData:   []byte(rootCACert),
171					CertData: []byte(certData),
172					KeyData:  []byte("bad key data"),
173				},
174			},
175		},
176		"bad file cert transport": {
177			Err: true,
178			Config: &Config{
179				TLS: TLSConfig{
180					CAData:   []byte(rootCACert),
181					CertData: []byte(certData),
182					KeyFile:  "invalid file",
183				},
184			},
185		},
186		"key data overriding bad file cert transport": {
187			TLS:     true,
188			TLSCert: true,
189			Config: &Config{
190				TLS: TLSConfig{
191					CAData:   []byte(rootCACert),
192					CertData: []byte(certData),
193					KeyData:  []byte(keyData),
194					KeyFile:  "invalid file",
195				},
196			},
197		},
198		"callback cert and key": {
199			TLS:     true,
200			TLSCert: true,
201			Config: &Config{
202				TLS: TLSConfig{
203					CAData: []byte(rootCACert),
204					GetCert: func() (*tls.Certificate, error) {
205						crt, err := tls.X509KeyPair([]byte(certData), []byte(keyData))
206						return &crt, err
207					},
208				},
209			},
210		},
211		"cert callback error": {
212			TLS:     true,
213			TLSCert: true,
214			TLSErr:  true,
215			Config: &Config{
216				TLS: TLSConfig{
217					CAData: []byte(rootCACert),
218					GetCert: func() (*tls.Certificate, error) {
219						return nil, errors.New("GetCert failure")
220					},
221				},
222			},
223		},
224		"cert data overrides empty callback result": {
225			TLS:     true,
226			TLSCert: true,
227			Config: &Config{
228				TLS: TLSConfig{
229					CAData: []byte(rootCACert),
230					GetCert: func() (*tls.Certificate, error) {
231						return nil, nil
232					},
233					CertData: []byte(certData),
234					KeyData:  []byte(keyData),
235				},
236			},
237		},
238		"callback returns nothing": {
239			TLS:     true,
240			TLSCert: true,
241			Config: &Config{
242				TLS: TLSConfig{
243					CAData: []byte(rootCACert),
244					GetCert: func() (*tls.Certificate, error) {
245						return nil, nil
246					},
247				},
248			},
249		},
250	}
251	for k, testCase := range testCases {
252		t.Run(k, func(t *testing.T) {
253			rt, err := New(testCase.Config)
254			switch {
255			case testCase.Err && err == nil:
256				t.Fatal("unexpected non-error")
257			case !testCase.Err && err != nil:
258				t.Fatalf("unexpected error: %v", err)
259			}
260			if testCase.Err {
261				return
262			}
263
264			switch {
265			case testCase.Default && rt != http.DefaultTransport:
266				t.Fatalf("got %#v, expected the default transport", rt)
267			case !testCase.Default && rt == http.DefaultTransport:
268				t.Fatalf("got %#v, expected non-default transport", rt)
269			}
270
271			// We only know how to check TLSConfig on http.Transports
272			transport := rt.(*http.Transport)
273			switch {
274			case testCase.TLS && transport.TLSClientConfig == nil:
275				t.Fatalf("got %#v, expected TLSClientConfig", transport)
276			case !testCase.TLS && transport.TLSClientConfig != nil:
277				t.Fatalf("got %#v, expected no TLSClientConfig", transport)
278			}
279			if !testCase.TLS {
280				return
281			}
282
283			switch {
284			case testCase.DefaultRoots && transport.TLSClientConfig.RootCAs != nil:
285				t.Fatalf("got %#v, expected nil root CAs", transport.TLSClientConfig.RootCAs)
286			case !testCase.DefaultRoots && transport.TLSClientConfig.RootCAs == nil:
287				t.Fatalf("got %#v, expected non-nil root CAs", transport.TLSClientConfig.RootCAs)
288			}
289
290			switch {
291			case testCase.Insecure != transport.TLSClientConfig.InsecureSkipVerify:
292				t.Fatalf("got %#v, expected %#v", transport.TLSClientConfig.InsecureSkipVerify, testCase.Insecure)
293			}
294
295			switch {
296			case testCase.TLSCert && transport.TLSClientConfig.GetClientCertificate == nil:
297				t.Fatalf("got %#v, expected TLSClientConfig.GetClientCertificate", transport.TLSClientConfig)
298			case !testCase.TLSCert && transport.TLSClientConfig.GetClientCertificate != nil:
299				t.Fatalf("got %#v, expected no TLSClientConfig.GetClientCertificate", transport.TLSClientConfig)
300			}
301			if !testCase.TLSCert {
302				return
303			}
304
305			_, err = transport.TLSClientConfig.GetClientCertificate(nil)
306			switch {
307			case testCase.TLSErr && err == nil:
308				t.Error("got nil error from GetClientCertificate, expected non-nil")
309			case !testCase.TLSErr && err != nil:
310				t.Errorf("got error from GetClientCertificate: %q, expected nil", err)
311			}
312		})
313	}
314}
315
// fakeRoundTripper is a canned http.RoundTripper for tests: it remembers the
// request it was last handed and answers with preset values.
type fakeRoundTripper struct {
	Req  *http.Request  // request most recently passed to RoundTrip
	Resp *http.Response // response returned by every RoundTrip call
	Err  error          // error returned by every RoundTrip call
}

// RoundTrip stores req in Req and returns the configured Resp and Err.
func (f *fakeRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	f.Req = req
	resp, err := f.Resp, f.Err
	return resp, err
}
326
// chainRoundTripper wraps another round tripper and carries a string that its
// RoundTrip method appends to the "Value" header of the wrapped response.
type chainRoundTripper struct {
	rt    http.RoundTripper // wrapped round tripper that RoundTrip delegates to
	value string            // suffix appended to the response's "Value" header
}
331
332func testChain(value string) WrapperFunc {
333	return func(rt http.RoundTripper) http.RoundTripper {
334		return &chainRoundTripper{rt: rt, value: value}
335	}
336}
337
338func (rt *chainRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
339	resp, err := rt.rt.RoundTrip(req)
340	if resp != nil {
341		if resp.Header == nil {
342			resp.Header = make(http.Header)
343		}
344		resp.Header.Set("Value", resp.Header.Get("Value")+rt.value)
345	}
346	return resp, err
347}
348
349func TestWrappers(t *testing.T) {
350	resp1 := &http.Response{}
351	wrapperResp1 := func(rt http.RoundTripper) http.RoundTripper {
352		return &fakeRoundTripper{Resp: resp1}
353	}
354	resp2 := &http.Response{}
355	wrapperResp2 := func(rt http.RoundTripper) http.RoundTripper {
356		return &fakeRoundTripper{Resp: resp2}
357	}
358
359	tests := []struct {
360		name    string
361		fns     []WrapperFunc
362		wantNil bool
363		want    func(*http.Response) bool
364	}{
365		{fns: []WrapperFunc{}, wantNil: true},
366		{fns: []WrapperFunc{nil, nil}, wantNil: true},
367		{fns: []WrapperFunc{nil}, wantNil: false},
368
369		{fns: []WrapperFunc{nil, wrapperResp1}, want: func(resp *http.Response) bool { return resp == resp1 }},
370		{fns: []WrapperFunc{wrapperResp1, nil}, want: func(resp *http.Response) bool { return resp == resp1 }},
371		{fns: []WrapperFunc{nil, wrapperResp1, nil}, want: func(resp *http.Response) bool { return resp == resp1 }},
372		{fns: []WrapperFunc{nil, wrapperResp1, wrapperResp2}, want: func(resp *http.Response) bool { return resp == resp2 }},
373		{fns: []WrapperFunc{wrapperResp1, wrapperResp2}, want: func(resp *http.Response) bool { return resp == resp2 }},
374		{fns: []WrapperFunc{wrapperResp2, wrapperResp1}, want: func(resp *http.Response) bool { return resp == resp1 }},
375
376		{fns: []WrapperFunc{testChain("1")}, want: func(resp *http.Response) bool { return resp.Header.Get("Value") == "1" }},
377		{fns: []WrapperFunc{testChain("1"), testChain("2")}, want: func(resp *http.Response) bool { return resp.Header.Get("Value") == "12" }},
378		{fns: []WrapperFunc{testChain("2"), testChain("1")}, want: func(resp *http.Response) bool { return resp.Header.Get("Value") == "21" }},
379		{fns: []WrapperFunc{testChain("1"), testChain("2"), testChain("3")}, want: func(resp *http.Response) bool { return resp.Header.Get("Value") == "123" }},
380	}
381	for _, tt := range tests {
382		t.Run(tt.name, func(t *testing.T) {
383			got := Wrappers(tt.fns...)
384			if got == nil != tt.wantNil {
385				t.Errorf("Wrappers() = %v", got)
386				return
387			}
388			if got == nil {
389				return
390			}
391
392			rt := &fakeRoundTripper{Resp: &http.Response{}}
393			nested := got(rt)
394			req := &http.Request{}
395			resp, _ := nested.RoundTrip(req)
396			if tt.want != nil && !tt.want(resp) {
397				t.Errorf("unexpected response: %#v", resp)
398			}
399		})
400	}
401}
402
403func Test_contextCanceller_RoundTrip(t *testing.T) {
404	tests := []struct {
405		name string
406		open bool
407		want bool
408	}{
409		{name: "open context should call nested round tripper", open: true, want: true},
410		{name: "closed context should return a known error", open: false, want: false},
411	}
412	for _, tt := range tests {
413		t.Run(tt.name, func(t *testing.T) {
414			req := &http.Request{}
415			rt := &fakeRoundTripper{Resp: &http.Response{}}
416			ctx := context.Background()
417			if !tt.open {
418				c, fn := context.WithCancel(ctx)
419				fn()
420				ctx = c
421			}
422			errTesting := fmt.Errorf("testing")
423			b := &contextCanceller{
424				rt:  rt,
425				ctx: ctx,
426				err: errTesting,
427			}
428			got, err := b.RoundTrip(req)
429			if tt.want {
430				if err != nil {
431					t.Errorf("unexpected error: %v", err)
432				}
433				if got != rt.Resp {
434					t.Errorf("wanted response")
435				}
436				if req != rt.Req {
437					t.Errorf("expect nested call")
438				}
439			} else {
440				if err != errTesting {
441					t.Errorf("unexpected error: %v", err)
442				}
443				if got != nil {
444					t.Errorf("wanted no response")
445				}
446				if rt.Req != nil {
447					t.Errorf("want no nested call")
448				}
449			}
450		})
451	}
452}
453