1/*
2Copyright 2015 The Kubernetes Authors.
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8    http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15*/
16
17package proxy
18
19import (
20	"context"
21	"crypto/tls"
22	"fmt"
23	"net"
24	"net/http"
25	"net/url"
26
27	"k8s.io/klog/v2"
28
29	utilnet "k8s.io/apimachinery/pkg/util/net"
30	"k8s.io/apimachinery/third_party/forked/golang/netutil"
31)
32
33// dialURL will dial the specified URL using the underlying dialer held by the passed
34// RoundTripper. The primary use of this method is to support proxying upgradable connections.
35// For this reason this method will prefer to negotiate http/1.1 if the URL scheme is https.
36// If you wish to ensure ALPN negotiates http2 then set NextProto=[]string{"http2"} in the
37// TLSConfig of the http.Transport
38func dialURL(ctx context.Context, url *url.URL, transport http.RoundTripper) (net.Conn, error) {
39	dialAddr := netutil.CanonicalAddr(url)
40
41	dialer, err := utilnet.DialerFor(transport)
42	if err != nil {
43		klog.V(5).Infof("Unable to unwrap transport %T to get dialer: %v", transport, err)
44	}
45
46	switch url.Scheme {
47	case "http":
48		if dialer != nil {
49			return dialer(ctx, "tcp", dialAddr)
50		}
51		var d net.Dialer
52		return d.DialContext(ctx, "tcp", dialAddr)
53	case "https":
54		// Get the tls config from the transport if we recognize it
55		var tlsConfig *tls.Config
56		var tlsConn *tls.Conn
57		var err error
58		tlsConfig, err = utilnet.TLSClientConfig(transport)
59		if err != nil {
60			klog.V(5).Infof("Unable to unwrap transport %T to get at TLS config: %v", transport, err)
61		}
62
63		if dialer != nil {
64			// We have a dialer; use it to open the connection, then
65			// create a tls client using the connection.
66			netConn, err := dialer(ctx, "tcp", dialAddr)
67			if err != nil {
68				return nil, err
69			}
70			if tlsConfig == nil {
71				// tls.Client requires non-nil config
72				klog.Warningf("using custom dialer with no TLSClientConfig. Defaulting to InsecureSkipVerify")
73				// tls.Handshake() requires ServerName or InsecureSkipVerify
74				tlsConfig = &tls.Config{
75					InsecureSkipVerify: true,
76				}
77			} else if len(tlsConfig.ServerName) == 0 && !tlsConfig.InsecureSkipVerify {
78				// tls.Handshake() requires ServerName or InsecureSkipVerify
79				// infer the ServerName from the hostname we're connecting to.
80				inferredHost := dialAddr
81				if host, _, err := net.SplitHostPort(dialAddr); err == nil {
82					inferredHost = host
83				}
84				// Make a copy to avoid polluting the provided config
85				tlsConfigCopy := tlsConfig.Clone()
86				tlsConfigCopy.ServerName = inferredHost
87				tlsConfig = tlsConfigCopy
88			}
89
90			// Since this method is primary used within a "Connection: Upgrade" call we assume the caller is
91			// going to write HTTP/1.1 request to the wire. http2 should not be allowed in the TLSConfig.NextProtos,
92			// so we explicitly set that here. We only do this check if the TLSConfig support http/1.1.
93			if supportsHTTP11(tlsConfig.NextProtos) {
94				tlsConfig = tlsConfig.Clone()
95				tlsConfig.NextProtos = []string{"http/1.1"}
96			}
97
98			tlsConn = tls.Client(netConn, tlsConfig)
99			if err := tlsConn.Handshake(); err != nil {
100				netConn.Close()
101				return nil, err
102			}
103
104		} else {
105			// Dial. This Dial method does not allow to pass a context unfortunately
106			tlsConn, err = tls.Dial("tcp", dialAddr, tlsConfig)
107			if err != nil {
108				return nil, err
109			}
110		}
111
112		// Return if we were configured to skip validation
113		if tlsConfig != nil && tlsConfig.InsecureSkipVerify {
114			return tlsConn, nil
115		}
116
117		// Verify
118		host, _, _ := net.SplitHostPort(dialAddr)
119		if tlsConfig != nil && len(tlsConfig.ServerName) > 0 {
120			host = tlsConfig.ServerName
121		}
122		if err := tlsConn.VerifyHostname(host); err != nil {
123			tlsConn.Close()
124			return nil, err
125		}
126
127		return tlsConn, nil
128	default:
129		return nil, fmt.Errorf("Unknown scheme: %s", url.Scheme)
130	}
131}
132
// supportsHTTP11 reports whether the given protocol list permits http/1.1.
// A nil or empty list is treated as permitting it; otherwise the list must
// contain the exact entry "http/1.1".
func supportsHTTP11(nextProtos []string) bool {
	// Start from the "no restriction" answer and only scan when the list is
	// non-empty; stop as soon as a match is found.
	allowed := len(nextProtos) == 0
	for i := 0; !allowed && i < len(nextProtos); i++ {
		allowed = nextProtos[i] == "http/1.1"
	}
	return allowed
}
144