1/*
2Copyright 2015 The Kubernetes Authors.
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8    http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15*/
16
17package transport
18
19import (
20	"context"
21	"crypto/tls"
22	"crypto/x509"
23	"fmt"
24	"io/ioutil"
25	"net/http"
26
27	utilnet "k8s.io/apimachinery/pkg/util/net"
28	"k8s.io/klog"
29)
30
31// New returns an http.RoundTripper that will provide the authentication
32// or transport level security defined by the provided Config.
33func New(config *Config) (http.RoundTripper, error) {
34	// Set transport level security
35	if config.Transport != nil && (config.HasCA() || config.HasCertAuth() || config.HasCertCallback() || config.TLS.Insecure) {
36		return nil, fmt.Errorf("using a custom transport with TLS certificate options or the insecure flag is not allowed")
37	}
38
39	var (
40		rt  http.RoundTripper
41		err error
42	)
43
44	if config.Transport != nil {
45		rt = config.Transport
46	} else {
47		rt, err = tlsCache.get(config)
48		if err != nil {
49			return nil, err
50		}
51	}
52
53	return HTTPWrappersForConfig(config, rt)
54}
55
56// TLSConfigFor returns a tls.Config that will provide the transport level security defined
57// by the provided Config. Will return nil if no transport level security is requested.
58func TLSConfigFor(c *Config) (*tls.Config, error) {
59	if !(c.HasCA() || c.HasCertAuth() || c.HasCertCallback() || c.TLS.Insecure || len(c.TLS.ServerName) > 0 || len(c.TLS.NextProtos) > 0) {
60		return nil, nil
61	}
62	if c.HasCA() && c.TLS.Insecure {
63		return nil, fmt.Errorf("specifying a root certificates file with the insecure flag is not allowed")
64	}
65	if err := loadTLSFiles(c); err != nil {
66		return nil, err
67	}
68
69	tlsConfig := &tls.Config{
70		// Can't use SSLv3 because of POODLE and BEAST
71		// Can't use TLSv1.0 because of POODLE and BEAST using CBC cipher
72		// Can't use TLSv1.1 because of RC4 cipher usage
73		MinVersion:         tls.VersionTLS12,
74		InsecureSkipVerify: c.TLS.Insecure,
75		ServerName:         c.TLS.ServerName,
76		NextProtos:         c.TLS.NextProtos,
77	}
78
79	if c.HasCA() {
80		tlsConfig.RootCAs = rootCertPool(c.TLS.CAData)
81	}
82
83	var staticCert *tls.Certificate
84	if c.HasCertAuth() {
85		// If key/cert were provided, verify them before setting up
86		// tlsConfig.GetClientCertificate.
87		cert, err := tls.X509KeyPair(c.TLS.CertData, c.TLS.KeyData)
88		if err != nil {
89			return nil, err
90		}
91		staticCert = &cert
92	}
93
94	if c.HasCertAuth() || c.HasCertCallback() {
95		tlsConfig.GetClientCertificate = func(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
96			// Note: static key/cert data always take precedence over cert
97			// callback.
98			if staticCert != nil {
99				return staticCert, nil
100			}
101			if c.HasCertCallback() {
102				cert, err := c.TLS.GetCert()
103				if err != nil {
104					return nil, err
105				}
106				// GetCert may return empty value, meaning no cert.
107				if cert != nil {
108					return cert, nil
109				}
110			}
111
112			// Both c.TLS.CertData/KeyData were unset and GetCert didn't return
113			// anything. Return an empty tls.Certificate, no client cert will
114			// be sent to the server.
115			return &tls.Certificate{}, nil
116		}
117	}
118
119	return tlsConfig, nil
120}
121
122// loadTLSFiles copies the data from the CertFile, KeyFile, and CAFile fields into the CertData,
123// KeyData, and CAFile fields, or returns an error. If no error is returned, all three fields are
124// either populated or were empty to start.
125func loadTLSFiles(c *Config) error {
126	var err error
127	c.TLS.CAData, err = dataFromSliceOrFile(c.TLS.CAData, c.TLS.CAFile)
128	if err != nil {
129		return err
130	}
131
132	c.TLS.CertData, err = dataFromSliceOrFile(c.TLS.CertData, c.TLS.CertFile)
133	if err != nil {
134		return err
135	}
136
137	c.TLS.KeyData, err = dataFromSliceOrFile(c.TLS.KeyData, c.TLS.KeyFile)
138	if err != nil {
139		return err
140	}
141	return nil
142}
143
// dataFromSliceOrFile returns data when it is non-empty. Otherwise, if file is
// non-empty, it returns the contents of that file. When both are empty it
// returns (nil, nil).
//
// If reading the file fails, the error from ioutil.ReadFile is returned
// unchanged, together with a nil slice. The error is not wrapped, so callers
// can still inspect it with os.IsNotExist and similar helpers.
func dataFromSliceOrFile(data []byte, file string) ([]byte, error) {
	if len(data) > 0 {
		return data, nil
	}
	if len(file) == 0 {
		return nil, nil
	}
	fileData, err := ioutil.ReadFile(file)
	if err != nil {
		// Return nil rather than an empty non-nil slice. The data is
		// meaningless on error, and callers that store the result (such as
		// loadTLSFiles) end up with the same "nothing loaded" value as the
		// both-empty case above.
		return nil, err
	}
	return fileData, nil
}
159
// rootCertPool builds a value suitable for tls.Config.RootCAs from PEM data.
//
// Empty caData yields nil, which tls.Config treats as "use the system roots".
// Non-empty caData yields a pool containing ONLY the certificates parsed from
// caData. The system roots are not merged in, so trust is either the system
// pool or caData, never both. Building a copy of the system pool is not
// exposed by crypto/x509, which is why nil is used to select it.
//
// AppendCertsFromPEM's success flag is not checked. If caData contains no
// parseable certificate, the result is an empty (non-nil) pool rather than an
// error.
func rootCertPool(caData []byte) *x509.CertPool {
	if len(caData) > 0 {
		pool := x509.NewCertPool()
		_ = pool.AppendCertsFromPEM(caData) // result intentionally ignored; see doc comment
		return pool
	}
	return nil
}
175
// WrapperFunc decorates an http.RoundTripper. It is applied when a new
// transport is created for a client, which lets per-connection behavior be
// injected.
type WrapperFunc func(rt http.RoundTripper) http.RoundTripper

// Wrappers composes fns into one WrapperFunc that applies every non-nil
// function in argument order. fns[0] wraps the base round tripper, and the
// last non-nil function produces the outermost layer. Nil entries are
// skipped, which makes Wrappers convenient for incrementally extending a
// possibly nil wrapper.
//
// Calling it with no arguments returns nil. As a shortcut for the common
// Wrappers(existing, extra) pattern, when exactly two functions are passed
// and the first is nil, the second is returned unchanged (it may itself be
// nil).
func Wrappers(fns ...WrapperFunc) WrapperFunc {
	switch {
	case len(fns) == 0:
		return nil
	case len(fns) == 2 && fns[0] == nil:
		return fns[1]
	}
	return func(base http.RoundTripper) http.RoundTripper {
		wrapped := base
		for i := range fns {
			if f := fns[i]; f != nil {
				wrapped = f(wrapped)
			}
		}
		return wrapped
	}
}
204
205// ContextCanceller prevents new requests after the provided context is finished.
206// err is returned when the context is closed, allowing the caller to provide a context
207// appropriate error.
208func ContextCanceller(ctx context.Context, err error) WrapperFunc {
209	return func(rt http.RoundTripper) http.RoundTripper {
210		return &contextCanceller{
211			ctx: ctx,
212			rt:  rt,
213			err: err,
214		}
215	}
216}
217
// contextCanceller is the http.RoundTripper installed by ContextCanceller. It
// delegates to rt until ctx is done. After that, every new request fails
// immediately with err and never reaches rt.
type contextCanceller struct {
	ctx context.Context
	rt  http.RoundTripper
	err error
}

// RoundTrip implements http.RoundTripper. ctx is checked once, before
// delegating; a request already passed to b.rt is not aborted if ctx is
// cancelled afterwards.
func (b *contextCanceller) RoundTrip(req *http.Request) (*http.Response, error) {
	select {
	case <-b.ctx.Done():
		// The http.RoundTripper contract requires RoundTrip to close the
		// request body even when it returns an error. b.rt never sees this
		// request, so the body must be closed here or it leaks.
		if req.Body != nil {
			// Best effort: b.err is the error the caller asked for, so a
			// Close failure is deliberately not reported.
			_ = req.Body.Close()
		}
		return nil, b.err
	default:
		return b.rt.RoundTrip(req)
	}
}
232
233func tryCancelRequest(rt http.RoundTripper, req *http.Request) {
234	type canceler interface {
235		CancelRequest(*http.Request)
236	}
237	switch rt := rt.(type) {
238	case canceler:
239		rt.CancelRequest(req)
240	case utilnet.RoundTripperWrapper:
241		tryCancelRequest(rt.WrappedRoundTripper(), req)
242	default:
243		klog.Warningf("Unable to cancel request for %T", rt)
244	}
245}
246