1/*
2Copyright 2015 The Kubernetes Authors.
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8    http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15*/
16
17package transport
18
19import (
20	"context"
21	"crypto/tls"
22	"crypto/x509"
23	"fmt"
24	"io/ioutil"
25	"net/http"
26)
27
28// New returns an http.RoundTripper that will provide the authentication
29// or transport level security defined by the provided Config.
30func New(config *Config) (http.RoundTripper, error) {
31	// Set transport level security
32	if config.Transport != nil && (config.HasCA() || config.HasCertAuth() || config.HasCertCallback() || config.TLS.Insecure) {
33		return nil, fmt.Errorf("using a custom transport with TLS certificate options or the insecure flag is not allowed")
34	}
35
36	var (
37		rt  http.RoundTripper
38		err error
39	)
40
41	if config.Transport != nil {
42		rt = config.Transport
43	} else {
44		rt, err = tlsCache.get(config)
45		if err != nil {
46			return nil, err
47		}
48	}
49
50	return HTTPWrappersForConfig(config, rt)
51}
52
53// TLSConfigFor returns a tls.Config that will provide the transport level security defined
54// by the provided Config. Will return nil if no transport level security is requested.
55func TLSConfigFor(c *Config) (*tls.Config, error) {
56	if !(c.HasCA() || c.HasCertAuth() || c.HasCertCallback() || c.TLS.Insecure || len(c.TLS.ServerName) > 0) {
57		return nil, nil
58	}
59	if c.HasCA() && c.TLS.Insecure {
60		return nil, fmt.Errorf("specifying a root certificates file with the insecure flag is not allowed")
61	}
62	if err := loadTLSFiles(c); err != nil {
63		return nil, err
64	}
65
66	tlsConfig := &tls.Config{
67		// Can't use SSLv3 because of POODLE and BEAST
68		// Can't use TLSv1.0 because of POODLE and BEAST using CBC cipher
69		// Can't use TLSv1.1 because of RC4 cipher usage
70		MinVersion:         tls.VersionTLS12,
71		InsecureSkipVerify: c.TLS.Insecure,
72		ServerName:         c.TLS.ServerName,
73	}
74
75	if c.HasCA() {
76		tlsConfig.RootCAs = rootCertPool(c.TLS.CAData)
77	}
78
79	var staticCert *tls.Certificate
80	if c.HasCertAuth() {
81		// If key/cert were provided, verify them before setting up
82		// tlsConfig.GetClientCertificate.
83		cert, err := tls.X509KeyPair(c.TLS.CertData, c.TLS.KeyData)
84		if err != nil {
85			return nil, err
86		}
87		staticCert = &cert
88	}
89
90	if c.HasCertAuth() || c.HasCertCallback() {
91		tlsConfig.GetClientCertificate = func(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
92			// Note: static key/cert data always take precedence over cert
93			// callback.
94			if staticCert != nil {
95				return staticCert, nil
96			}
97			if c.HasCertCallback() {
98				cert, err := c.TLS.GetCert()
99				if err != nil {
100					return nil, err
101				}
102				// GetCert may return empty value, meaning no cert.
103				if cert != nil {
104					return cert, nil
105				}
106			}
107
108			// Both c.TLS.CertData/KeyData were unset and GetCert didn't return
109			// anything. Return an empty tls.Certificate, no client cert will
110			// be sent to the server.
111			return &tls.Certificate{}, nil
112		}
113	}
114
115	return tlsConfig, nil
116}
117
118// loadTLSFiles copies the data from the CertFile, KeyFile, and CAFile fields into the CertData,
119// KeyData, and CAFile fields, or returns an error. If no error is returned, all three fields are
120// either populated or were empty to start.
121func loadTLSFiles(c *Config) error {
122	var err error
123	c.TLS.CAData, err = dataFromSliceOrFile(c.TLS.CAData, c.TLS.CAFile)
124	if err != nil {
125		return err
126	}
127
128	c.TLS.CertData, err = dataFromSliceOrFile(c.TLS.CertData, c.TLS.CertFile)
129	if err != nil {
130		return err
131	}
132
133	c.TLS.KeyData, err = dataFromSliceOrFile(c.TLS.KeyData, c.TLS.KeyFile)
134	if err != nil {
135		return err
136	}
137	return nil
138}
139
// dataFromSliceOrFile returns data when it is non-empty. Otherwise, when file
// is non-empty, it returns the contents of that file. When both are empty it
// returns (nil, nil).
//
// If reading the file fails, the returned slice is nil and the read error is
// returned unchanged (it already names the offending path).
func dataFromSliceOrFile(data []byte, file string) ([]byte, error) {
	if len(data) > 0 {
		return data, nil
	}
	if len(file) == 0 {
		return nil, nil
	}
	fileData, err := ioutil.ReadFile(file)
	if err != nil {
		// Don't hand back a placeholder or partially-read slice alongside
		// an error; the value is meaningless when err != nil.
		return nil, err
	}
	return fileData, nil
}
155
// rootCertPool builds the pool used to verify server certificates from
// PEM-encoded caData.
//
// Empty caData yields nil, which crypto/tls treats as "use the system roots".
// Non-empty caData yields a pool holding ONLY the certificates parsed from it;
// the system roots are not included.
func rootCertPool(caData []byte) *x509.CertPool {
	// The x509 package does not expose a copy of its system roots pool, and
	// rebuilding one is platform specific, so the nil-means-system-roots rule
	// is relied on for the empty case. As a consequence a caller can trust
	// either the system roots or caData, never both at once.
	if len(caData) > 0 {
		pool := x509.NewCertPool()
		// NOTE(review): the bool result (false when no certificate could be
		// parsed) is discarded, so malformed caData yields an empty pool.
		_ = pool.AppendCertsFromPEM(caData)
		return pool
	}
	return nil
}
171
// WrapperFunc decorates an http.RoundTripper. It is applied when a new
// transport is created for a client, which lets per-connection behavior be
// layered around the underlying round tripper.
type WrapperFunc func(next http.RoundTripper) http.RoundTripper

// Wrappers composes fns into a single WrapperFunc. The result applies fns in
// argument order: fns[0] wraps the input round tripper first, fns[1] wraps
// that result, and so on, so the last non-nil wrapper ends up outermost. Nil
// entries are skipped, which makes it convenient to extend a chain that may
// not have been started yet.
//
// Called with no arguments, Wrappers returns nil. Called with exactly two
// arguments where the first is nil, it returns the second one unchanged
// (which may itself be nil).
func Wrappers(fns ...WrapperFunc) WrapperFunc {
	switch {
	case len(fns) == 0:
		return nil
	case len(fns) == 2 && fns[0] == nil:
		// Shortcut for the common Wrappers(existing, next) call where no
		// wrapper existed yet.
		return fns[1]
	}
	return func(rt http.RoundTripper) http.RoundTripper {
		for i := range fns {
			if wrap := fns[i]; wrap != nil {
				rt = wrap(rt)
			}
		}
		return rt
	}
}
200
201// ContextCanceller prevents new requests after the provided context is finished.
202// err is returned when the context is closed, allowing the caller to provide a context
203// appropriate error.
204func ContextCanceller(ctx context.Context, err error) WrapperFunc {
205	return func(rt http.RoundTripper) http.RoundTripper {
206		return &contextCanceller{
207			ctx: ctx,
208			rt:  rt,
209			err: err,
210		}
211	}
212}
213
214type contextCanceller struct {
215	ctx context.Context
216	rt  http.RoundTripper
217	err error
218}
219
220func (b *contextCanceller) RoundTrip(req *http.Request) (*http.Response, error) {
221	select {
222	case <-b.ctx.Done():
223		return nil, b.err
224	default:
225		return b.rt.RoundTrip(req)
226	}
227}
228