1/*
2Copyright 2015 The Kubernetes Authors.
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8    http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15*/
16
17package spdy
18
19import (
20	"bufio"
21	"bytes"
22	"context"
23	"crypto/tls"
24	"encoding/base64"
25	"fmt"
26	"io"
27	"io/ioutil"
28	"net"
29	"net/http"
30	"net/http/httputil"
31	"net/url"
32	"strings"
33
34	apierrors "k8s.io/apimachinery/pkg/api/errors"
35	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
36	"k8s.io/apimachinery/pkg/runtime"
37	"k8s.io/apimachinery/pkg/runtime/serializer"
38	"k8s.io/apimachinery/pkg/util/httpstream"
39	utilnet "k8s.io/apimachinery/pkg/util/net"
40	"k8s.io/apimachinery/third_party/forked/golang/netutil"
41)
42
// SpdyRoundTripper knows how to upgrade an HTTP request to one that supports
// multiplexed streams. After RoundTrip() succeeds, the underlying network
// connection is retained and NewConnection() can wrap it in an
// httpstream.Connection. SpdyRoundTripper implements the
// httpstream.UpgradeRoundTripper interface.
type SpdyRoundTripper struct {
	// tlsConfig holds the TLS configuration settings to use when connecting
	// to the remote server. It may be nil.
	tlsConfig *tls.Config

	// TODO: according to http://golang.org/pkg/net/http/#RoundTripper, a
	// RoundTripper must be safe for use by multiple concurrent goroutines. If
	// this is absolutely necessary, we could keep a map from http.Request to
	// net.Conn. In practice, a client will create an http.Client, set the
	// transport to a new instance of SpdyRoundTripper, and use it a single
	// time, so this hopefully won't be an issue.

	// conn is the underlying network connection to the remote server. It is
	// stored by RoundTrip once a response has been read from it and is later
	// consumed by NewConnection. It is not guarded by any lock.
	conn net.Conn

	// Dialer is the dialer used to connect. Used if non-nil; otherwise a
	// zero-value net.Dialer (plain HTTP) or the default TLS dial is used.
	Dialer *net.Dialer

	// proxier knows which proxy to use given a request. When nil, dial falls
	// back to utilnet.NewProxierWithNoProxyCIDR(http.ProxyFromEnvironment).
	// Used primarily for mocking the proxy discovery in tests.
	proxier func(req *http.Request) (*url.URL, error)

	// followRedirects indicates if the round tripper should examine responses for redirects and
	// follow them.
	followRedirects bool
	// requireSameHostRedirects restricts redirect following to only follow redirects to the same host
	// as the original request.
	requireSameHostRedirects bool
}
74
75var _ utilnet.TLSClientConfigHolder = &SpdyRoundTripper{}
76var _ httpstream.UpgradeRoundTripper = &SpdyRoundTripper{}
77var _ utilnet.Dialer = &SpdyRoundTripper{}
78
// NewRoundTripper returns an httpstream.UpgradeRoundTripper backed by a
// *SpdyRoundTripper that uses tlsConfig and the given redirect policy.
// It simply delegates to NewSpdyRoundTripper, hiding the concrete type.
func NewRoundTripper(tlsConfig *tls.Config, followRedirects, requireSameHostRedirects bool) httpstream.UpgradeRoundTripper {
	spdyRT := NewSpdyRoundTripper(tlsConfig, followRedirects, requireSameHostRedirects)
	return spdyRT
}
84
// NewSpdyRoundTripper creates a new SpdyRoundTripper that will use
// the specified tlsConfig. This function is mostly meant for unit tests.
//
// followRedirects enables redirect following in RoundTrip, and
// requireSameHostRedirects restricts that following to the original host.
// Dialer and proxier are left unset, so their defaults apply.
func NewSpdyRoundTripper(tlsConfig *tls.Config, followRedirects, requireSameHostRedirects bool) *SpdyRoundTripper {
	rt := new(SpdyRoundTripper)
	rt.tlsConfig = tlsConfig
	rt.followRedirects = followRedirects
	rt.requireSameHostRedirects = requireSameHostRedirects
	return rt
}
94
// TLSClientConfig implements pkg/util/net.TLSClientConfigHolder for proper TLS
// checking during proxying with a spdy roundtripper. It returns the round
// tripper's own tlsConfig, which may be nil.
func (s *SpdyRoundTripper) TLSClientConfig() *tls.Config {
	cfg := s.tlsConfig
	return cfg
}
100
// Dial implements k8s.io/apimachinery/pkg/util/net.Dialer.
//
// It opens a connection to the request's target through dial (which applies
// any proxy and TLS settings) and then writes req onto that connection. If
// the write fails, the connection is closed and the write error is returned;
// on success the caller owns the returned connection.
func (s *SpdyRoundTripper) Dial(req *http.Request) (net.Conn, error) {
	conn, dialErr := s.dial(req)
	if dialErr != nil {
		return nil, dialErr
	}

	writeErr := req.Write(conn)
	if writeErr == nil {
		return conn, nil
	}
	conn.Close()
	return nil, writeErr
}
115
// dial dials the host specified by req, using TLS if appropriate, optionally
// using a proxy server if one is configured via environment variables.
//
// Without a proxy, the target is dialed directly via dialWithoutProxy. With a
// proxy, a CONNECT tunnel to the canonical target address is requested from
// the proxy (sending Proxy-Authorization when the proxy URL carries
// credentials). A non-2xx CONNECT reply is reported as an error. For https
// targets a TLS handshake is then performed over the tunnel, and the server
// name is verified unless InsecureSkipVerify is set. Any connection opened
// here is closed before an error is returned.
func (s *SpdyRoundTripper) dial(req *http.Request) (net.Conn, error) {
	proxier := s.proxier
	if proxier == nil {
		proxier = utilnet.NewProxierWithNoProxyCIDR(http.ProxyFromEnvironment)
	}
	proxyURL, err := proxier(req)
	if err != nil {
		return nil, err
	}

	if proxyURL == nil {
		return s.dialWithoutProxy(req.Context(), req.URL)
	}

	// ensure we use a canonical host with proxyReq
	targetHost := netutil.CanonicalAddr(req.URL)

	// proxying logic adapted from http://blog.h6t.eu/post/74098062923/golang-websocket-with-http-proxy-support
	proxyReq := http.Request{
		Method: "CONNECT",
		URL:    &url.URL{},
		Host:   targetHost,
	}

	if pa := s.proxyAuth(proxyURL); pa != "" {
		proxyReq.Header = http.Header{}
		proxyReq.Header.Set("Proxy-Authorization", pa)
	}

	proxyDialConn, err := s.dialWithoutProxy(req.Context(), proxyURL)
	if err != nil {
		return nil, err
	}

	proxyClientConn := httputil.NewProxyClientConn(proxyDialConn, nil)
	proxyResp, err := proxyClientConn.Do(&proxyReq)
	// httputil reports ErrPersistEOF together with a valid response when the
	// proxy's reply asks for the connection to be closed (typical for a
	// CONNECT reply without Content-Length); the tunnel is still usable.
	if err != nil && err != httputil.ErrPersistEOF {
		proxyDialConn.Close()
		return nil, err
	}
	// A proxy that refuses the tunnel (e.g. 403 or 407) must not be treated as
	// connected: every later byte would go to the proxy, not the target.
	if proxyResp == nil || proxyResp.StatusCode < 200 || proxyResp.StatusCode >= 300 {
		proxyDialConn.Close()
		status := "no response"
		if proxyResp != nil {
			status = proxyResp.Status
		}
		return nil, fmt.Errorf("CONNECT request to %s returned response: %s", proxyURL.Host, status)
	}

	// NOTE(review): the bufio.Reader returned by Hijack is discarded; this
	// assumes the proxy sent nothing after its CONNECT reply.
	rwc, _ := proxyClientConn.Hijack()

	if req.URL.Scheme != "https" {
		return rwc, nil
	}

	host, _, err := net.SplitHostPort(targetHost)
	if err != nil {
		rwc.Close()
		return nil, err
	}

	tlsConfig := s.tlsConfig
	switch {
	case tlsConfig == nil:
		tlsConfig = &tls.Config{ServerName: host}
	case len(tlsConfig.ServerName) == 0:
		// Clone so the shared config is never mutated.
		tlsConfig = tlsConfig.Clone()
		tlsConfig.ServerName = host
	}

	tlsConn := tls.Client(rwc, tlsConfig)

	// need to manually call Handshake() so we can call VerifyHostname() below
	if err := tlsConn.Handshake(); err != nil {
		tlsConn.Close()
		return nil, err
	}

	// Return if we were configured to skip validation
	if tlsConfig.InsecureSkipVerify {
		return tlsConn, nil
	}

	if err := tlsConn.VerifyHostname(tlsConfig.ServerName); err != nil {
		tlsConn.Close()
		return nil, err
	}

	return tlsConn, nil
}
196
197// dialWithoutProxy dials the host specified by url, using TLS if appropriate.
198func (s *SpdyRoundTripper) dialWithoutProxy(ctx context.Context, url *url.URL) (net.Conn, error) {
199	dialAddr := netutil.CanonicalAddr(url)
200
201	if url.Scheme == "http" {
202		if s.Dialer == nil {
203			var d net.Dialer
204			return d.DialContext(ctx, "tcp", dialAddr)
205		} else {
206			return s.Dialer.DialContext(ctx, "tcp", dialAddr)
207		}
208	}
209
210	// TODO validate the TLSClientConfig is set up?
211	var conn *tls.Conn
212	var err error
213	if s.Dialer == nil {
214		conn, err = tls.Dial("tcp", dialAddr, s.tlsConfig)
215	} else {
216		conn, err = tls.DialWithDialer(s.Dialer, "tcp", dialAddr, s.tlsConfig)
217	}
218	if err != nil {
219		return nil, err
220	}
221
222	// Return if we were configured to skip validation
223	if s.tlsConfig != nil && s.tlsConfig.InsecureSkipVerify {
224		return conn, nil
225	}
226
227	host, _, err := net.SplitHostPort(dialAddr)
228	if err != nil {
229		return nil, err
230	}
231	if s.tlsConfig != nil && len(s.tlsConfig.ServerName) > 0 {
232		host = s.tlsConfig.ServerName
233	}
234	err = conn.VerifyHostname(host)
235	if err != nil {
236		return nil, err
237	}
238
239	return conn, nil
240}
241
// proxyAuth returns, for a given proxy URL, the value to be used for the
// Proxy-Authorization header, or "" when the URL is nil or has no user info.
//
// The value uses the HTTP Basic scheme (RFC 7617): the decoded username and
// password joined by ":" and base64-encoded. The password part is empty when
// the URL carries only a username.
func (s *SpdyRoundTripper) proxyAuth(proxyURL *url.URL) string {
	if proxyURL == nil || proxyURL.User == nil {
		return ""
	}
	// Userinfo.String() percent-escapes reserved characters (e.g. "@" -> "%40")
	// and drops the ":" when no password is set, which corrupts Basic
	// credentials; build the pair from the decoded components instead.
	username := proxyURL.User.Username()
	password, _ := proxyURL.User.Password()
	credentials := username + ":" + password
	return "Basic " + base64.StdEncoding.EncodeToString([]byte(credentials))
}
251
// RoundTrip executes the Request and upgrades it. After a successful upgrade,
// clients pass the returned response to SpdyRoundTripper.NewConnection() to
// retrieve the upgraded connection.
//
// It works on a clone of the request headers, adding a Connection header
// whose value is httpstream.HeaderUpgrade and an Upgrade header whose value
// is HeaderSpdy31.
//
// With followRedirects set, the connection comes from
// utilnet.ConnectWithRedirects (honoring requireSameHostRedirects). Any
// rawResponse bytes it returns are read before the connection itself.
// Otherwise a clone of req carrying the new headers is sent via Dial.
//
// The response status is not checked here (NewConnection does that). If the
// response cannot be parsed, the connection is closed.
func (s *SpdyRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	upgradeHeader := utilnet.CloneHeader(req.Header)
	upgradeHeader.Add(httpstream.HeaderConnection, httpstream.HeaderUpgrade)
	upgradeHeader.Add(httpstream.HeaderUpgrade, HeaderSpdy31)

	var conn net.Conn
	var prefetched []byte
	var err error
	if !s.followRedirects {
		dialReq := utilnet.CloneRequest(req)
		dialReq.Header = upgradeHeader
		conn, err = s.Dial(dialReq)
	} else {
		conn, prefetched, err = utilnet.ConnectWithRedirects(req.Method, req.URL, upgradeHeader, req.Body, s, s.requireSameHostRedirects)
	}
	if err != nil {
		return nil, err
	}

	// Bytes handed back by ConnectWithRedirects (if any) precede the live connection.
	reader := bufio.NewReader(io.MultiReader(bytes.NewBuffer(prefetched), conn))
	resp, err := http.ReadResponse(reader, nil)
	if err != nil {
		if conn != nil {
			conn.Close()
		}
		return nil, err
	}

	s.conn = conn
	return resp, nil
}
296
// NewConnection validates the upgrade response, creating and returning a new
// httpstream.Connection if there were no errors.
//
// The response is accepted only when all of these hold (compared
// case-insensitively):
//   - its status is 101 Switching Protocols,
//   - its Connection header contains httpstream.HeaderUpgrade,
//   - its Upgrade header contains HeaderSpdy31.
//
// The connection stored by RoundTrip is then wrapped via NewClientConnection.
// Otherwise the body is read and closed. A body that decodes as a
// metav1.Status yields an *apierrors.StatusError; anything else is reported
// as "unable to upgrade connection: " followed by the trimmed body text.
func (s *SpdyRoundTripper) NewConnection(resp *http.Response) (httpstream.Connection, error) {
	connection := strings.ToLower(resp.Header.Get(httpstream.HeaderConnection))
	upgrade := strings.ToLower(resp.Header.Get(httpstream.HeaderUpgrade))

	switched := resp.StatusCode == http.StatusSwitchingProtocols
	hasUpgradeToken := strings.Contains(connection, strings.ToLower(httpstream.HeaderUpgrade))
	hasSpdyToken := strings.Contains(upgrade, strings.ToLower(HeaderSpdy31))
	if switched && hasUpgradeToken && hasSpdyToken {
		return NewClientConnection(s.conn)
	}

	defer resp.Body.Close()

	reason := "unable to read error from server response"
	if body, readErr := ioutil.ReadAll(resp.Body); readErr == nil {
		// TODO: I don't belong here, I should be abstracted from this class
		decoded, _, decodeErr := statusCodecs.UniversalDecoder().Decode(body, nil, &metav1.Status{})
		if decodeErr == nil {
			if status, ok := decoded.(*metav1.Status); ok {
				return nil, &apierrors.StatusError{ErrStatus: *status}
			}
		}
		reason = strings.TrimSpace(string(body))
	}

	return nil, fmt.Errorf("unable to upgrade connection: %s", reason)
}
324
// statusScheme is a private scheme used only to decode metav1.Status error
// bodies in NewConnection, until someone fixes the TODO there.
var statusScheme = runtime.NewScheme()

// statusCodecs is the codec factory over statusScheme whose UniversalDecoder
// NewConnection uses to decode metav1.Status error bodies.
//
// NOTE(review): this factory is built during package variable initialization,
// before init() below registers metav1.Status in statusScheme; decoding
// presumably relies on the factory consulting the scheme at decode time —
// confirm.
var statusCodecs = serializer.NewCodecFactory(statusScheme)

// init registers metav1.Status as an unversioned type of
// metav1.SchemeGroupVersion in statusScheme so statusCodecs can decode it.
func init() {
	statusScheme.AddUnversionedTypes(metav1.SchemeGroupVersion,
		&metav1.Status{},
	)
}
336