1// Copyright 2015 The etcd Authors
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package transport
16
17import (
18	"crypto/ecdsa"
19	"crypto/elliptic"
20	"crypto/rand"
21	"crypto/tls"
22	"crypto/x509"
23	"crypto/x509/pkix"
24	"encoding/pem"
25	"errors"
26	"fmt"
27	"math/big"
28	"net"
29	"os"
30	"path/filepath"
31	"strings"
32	"time"
33
34	"go.etcd.io/etcd/pkg/fileutil"
35	"go.etcd.io/etcd/pkg/tlsutil"
36
37	"go.uber.org/zap"
38)
39
40// NewListener creates a new listner.
41func NewListener(addr, scheme string, tlsinfo *TLSInfo) (l net.Listener, err error) {
42	if l, err = newListener(addr, scheme); err != nil {
43		return nil, err
44	}
45	return wrapTLS(scheme, tlsinfo, l)
46}
47
48func newListener(addr string, scheme string) (net.Listener, error) {
49	if scheme == "unix" || scheme == "unixs" {
50		// unix sockets via unix://laddr
51		return NewUnixListener(addr)
52	}
53	return net.Listen("tcp", addr)
54}
55
56func wrapTLS(scheme string, tlsinfo *TLSInfo, l net.Listener) (net.Listener, error) {
57	if scheme != "https" && scheme != "unixs" {
58		return l, nil
59	}
60	if tlsinfo != nil && tlsinfo.SkipClientSANVerify {
61		return NewTLSListener(l, tlsinfo)
62	}
63	return newTLSListener(l, tlsinfo, checkSAN)
64}
65
// TLSInfo holds the file paths and options from which ServerConfig and
// ClientConfig build their *tls.Config values.
type TLSInfo struct {
	// CertFile and KeyFile are the paths of the certificate and private key.
	// baseConfig requires both to be set and reloads them through
	// tlsutil.NewCert on every handshake callback.
	CertFile string
	KeyFile  string
	// TrustedCAFile is the path of a CA file. ServerConfig loads it into
	// ClientCAs and ClientConfig loads it into RootCAs; a non-empty value
	// also makes ServerConfig require and verify client certificates.
	TrustedCAFile string
	// ClientCertAuth makes ServerConfig require and verify client certificates.
	ClientCertAuth bool
	// CRLFile is reported by String; it is not consumed by the config builders
	// in this file.
	// NOTE(review): presumably a certificate revocation list checked by the
	// TLS listener — confirm against the listener implementation.
	CRLFile string
	// InsecureSkipVerify is copied into tls.Config.InsecureSkipVerify by
	// ClientConfig.
	InsecureSkipVerify bool
	// SkipClientSANVerify makes wrapTLS use NewTLSListener instead of
	// newTLSListener with the checkSAN hook.
	SkipClientSANVerify bool

	// ServerName ensures the cert matches the given host in case of discovery / virtual hosting
	ServerName string

	// HandshakeFailure is optionally called when a connection fails to handshake. The
	// connection will be closed immediately afterwards.
	HandshakeFailure func(*tls.Conn, error)

	// CipherSuites is a list of supported cipher suites.
	// If empty, Go auto-populates it by default.
	// Note that cipher suites are prioritized in the given order.
	CipherSuites []uint16

	// selfCert is set by SelfCert; when true, ClientConfig forces
	// InsecureSkipVerify on.
	selfCert bool

	// parseFunc exists to simplify testing. Typically, parseFunc
	// should be left nil. In that case, tls.X509KeyPair will be used.
	parseFunc func([]byte, []byte) (tls.Certificate, error)

	// AllowedCN is a CN which must be provided by a client.
	// It is mutually exclusive with AllowedHostname.
	AllowedCN string

	// AllowedHostname is an IP address or hostname that must match the TLS
	// certificate provided by a client.
	// It is mutually exclusive with AllowedCN.
	AllowedHostname string

	// Logger logs TLS errors.
	// If nil, all logs are discarded.
	Logger *zap.Logger

	// EmptyCN indicates that the cert must have empty CN.
	// If true, ClientConfig() will return an error for a cert with non empty CN.
	EmptyCN bool
}
108
109func (info TLSInfo) String() string {
110	return fmt.Sprintf("cert = %s, key = %s, trusted-ca = %s, client-cert-auth = %v, crl-file = %s", info.CertFile, info.KeyFile, info.TrustedCAFile, info.ClientCertAuth, info.CRLFile)
111}
112
113func (info TLSInfo) Empty() bool {
114	return info.CertFile == "" && info.KeyFile == ""
115}
116
117func SelfCert(lg *zap.Logger, dirpath string, hosts []string, additionalUsages ...x509.ExtKeyUsage) (info TLSInfo, err error) {
118	info.Logger = lg
119	err = fileutil.TouchDirAll(dirpath)
120	if err != nil {
121		if info.Logger != nil {
122			info.Logger.Warn(
123				"cannot create cert directory",
124				zap.Error(err),
125			)
126		}
127		return
128	}
129
130	certPath := filepath.Join(dirpath, "cert.pem")
131	keyPath := filepath.Join(dirpath, "key.pem")
132	_, errcert := os.Stat(certPath)
133	_, errkey := os.Stat(keyPath)
134	if errcert == nil && errkey == nil {
135		info.CertFile = certPath
136		info.KeyFile = keyPath
137		info.selfCert = true
138		return
139	}
140
141	serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
142	serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
143	if err != nil {
144		if info.Logger != nil {
145			info.Logger.Warn(
146				"cannot generate random number",
147				zap.Error(err),
148			)
149		}
150		return
151	}
152
153	tmpl := x509.Certificate{
154		SerialNumber: serialNumber,
155		Subject:      pkix.Name{Organization: []string{"etcd"}},
156		NotBefore:    time.Now(),
157		NotAfter:     time.Now().Add(365 * (24 * time.Hour)),
158
159		KeyUsage:              x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
160		ExtKeyUsage:           append([]x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, additionalUsages...),
161		BasicConstraintsValid: true,
162	}
163
164	for _, host := range hosts {
165		h, _, _ := net.SplitHostPort(host)
166		if ip := net.ParseIP(h); ip != nil {
167			tmpl.IPAddresses = append(tmpl.IPAddresses, ip)
168		} else {
169			tmpl.DNSNames = append(tmpl.DNSNames, h)
170		}
171	}
172
173	priv, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader)
174	if err != nil {
175		if info.Logger != nil {
176			info.Logger.Warn(
177				"cannot generate ECDSA key",
178				zap.Error(err),
179			)
180		}
181		return
182	}
183
184	derBytes, err := x509.CreateCertificate(rand.Reader, &tmpl, &tmpl, &priv.PublicKey, priv)
185	if err != nil {
186		if info.Logger != nil {
187			info.Logger.Warn(
188				"cannot generate x509 certificate",
189				zap.Error(err),
190			)
191		}
192		return
193	}
194
195	certOut, err := os.Create(certPath)
196	if err != nil {
197		info.Logger.Warn(
198			"cannot cert file",
199			zap.String("path", certPath),
200			zap.Error(err),
201		)
202		return
203	}
204	pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
205	certOut.Close()
206	if info.Logger != nil {
207		info.Logger.Info("created cert file", zap.String("path", certPath))
208	}
209
210	b, err := x509.MarshalECPrivateKey(priv)
211	if err != nil {
212		return
213	}
214	keyOut, err := os.OpenFile(keyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
215	if err != nil {
216		if info.Logger != nil {
217			info.Logger.Warn(
218				"cannot key file",
219				zap.String("path", keyPath),
220				zap.Error(err),
221			)
222		}
223		return
224	}
225	pem.Encode(keyOut, &pem.Block{Type: "EC PRIVATE KEY", Bytes: b})
226	keyOut.Close()
227	if info.Logger != nil {
228		info.Logger.Info("created key file", zap.String("path", keyPath))
229	}
230	return SelfCert(lg, dirpath, hosts)
231}
232
233// baseConfig is called on initial TLS handshake start.
234//
235// Previously,
236// 1. Server has non-empty (*tls.Config).Certificates on client hello
237// 2. Server calls (*tls.Config).GetCertificate iff:
238//    - Server's (*tls.Config).Certificates is not empty, or
239//    - Client supplies SNI; non-empty (*tls.ClientHelloInfo).ServerName
240//
241// When (*tls.Config).Certificates is always populated on initial handshake,
242// client is expected to provide a valid matching SNI to pass the TLS
243// verification, thus trigger server (*tls.Config).GetCertificate to reload
244// TLS assets. However, a cert whose SAN field does not include domain names
245// but only IP addresses, has empty (*tls.ClientHelloInfo).ServerName, thus
246// it was never able to trigger TLS reload on initial handshake; first
247// ceritifcate object was being used, never being updated.
248//
249// Now, (*tls.Config).Certificates is created empty on initial TLS client
250// handshake, in order to trigger (*tls.Config).GetCertificate and populate
251// rest of the certificates on every new TLS connection, even when client
252// SNI is empty (e.g. cert only includes IPs).
253func (info TLSInfo) baseConfig() (*tls.Config, error) {
254	if info.KeyFile == "" || info.CertFile == "" {
255		return nil, fmt.Errorf("KeyFile and CertFile must both be present[key: %v, cert: %v]", info.KeyFile, info.CertFile)
256	}
257	if info.Logger == nil {
258		info.Logger = zap.NewNop()
259	}
260
261	_, err := tlsutil.NewCert(info.CertFile, info.KeyFile, info.parseFunc)
262	if err != nil {
263		return nil, err
264	}
265
266	cfg := &tls.Config{
267		MinVersion: tls.VersionTLS12,
268		ServerName: info.ServerName,
269	}
270
271	if len(info.CipherSuites) > 0 {
272		cfg.CipherSuites = info.CipherSuites
273	}
274
275	// Client certificates may be verified by either an exact match on the CN,
276	// or a more general check of the CN and SANs.
277	var verifyCertificate func(*x509.Certificate) bool
278	if info.AllowedCN != "" {
279		if info.AllowedHostname != "" {
280			return nil, fmt.Errorf("AllowedCN and AllowedHostname are mutually exclusive (cn=%q, hostname=%q)", info.AllowedCN, info.AllowedHostname)
281		}
282		verifyCertificate = func(cert *x509.Certificate) bool {
283			return info.AllowedCN == cert.Subject.CommonName
284		}
285	}
286	if info.AllowedHostname != "" {
287		verifyCertificate = func(cert *x509.Certificate) bool {
288			return cert.VerifyHostname(info.AllowedHostname) == nil
289		}
290	}
291	if verifyCertificate != nil {
292		cfg.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
293			for _, chains := range verifiedChains {
294				if len(chains) != 0 {
295					if verifyCertificate(chains[0]) {
296						return nil
297					}
298				}
299			}
300			return errors.New("client certificate authentication failed")
301		}
302	}
303
304	// this only reloads certs when there's a client request
305	// TODO: support server-side refresh (e.g. inotify, SIGHUP), caching
306	cfg.GetCertificate = func(clientHello *tls.ClientHelloInfo) (cert *tls.Certificate, err error) {
307		cert, err = tlsutil.NewCert(info.CertFile, info.KeyFile, info.parseFunc)
308		if os.IsNotExist(err) {
309			if info.Logger != nil {
310				info.Logger.Warn(
311					"failed to find peer cert files",
312					zap.String("cert-file", info.CertFile),
313					zap.String("key-file", info.KeyFile),
314					zap.Error(err),
315				)
316			}
317		} else if err != nil {
318			if info.Logger != nil {
319				info.Logger.Warn(
320					"failed to create peer certificate",
321					zap.String("cert-file", info.CertFile),
322					zap.String("key-file", info.KeyFile),
323					zap.Error(err),
324				)
325			}
326		}
327		return cert, err
328	}
329	cfg.GetClientCertificate = func(unused *tls.CertificateRequestInfo) (cert *tls.Certificate, err error) {
330		cert, err = tlsutil.NewCert(info.CertFile, info.KeyFile, info.parseFunc)
331		if os.IsNotExist(err) {
332			if info.Logger != nil {
333				info.Logger.Warn(
334					"failed to find client cert files",
335					zap.String("cert-file", info.CertFile),
336					zap.String("key-file", info.KeyFile),
337					zap.Error(err),
338				)
339			}
340		} else if err != nil {
341			if info.Logger != nil {
342				info.Logger.Warn(
343					"failed to create client certificate",
344					zap.String("cert-file", info.CertFile),
345					zap.String("key-file", info.KeyFile),
346					zap.Error(err),
347				)
348			}
349		}
350		return cert, err
351	}
352	return cfg, nil
353}
354
355// cafiles returns a list of CA file paths.
356func (info TLSInfo) cafiles() []string {
357	cs := make([]string, 0)
358	if info.TrustedCAFile != "" {
359		cs = append(cs, info.TrustedCAFile)
360	}
361	return cs
362}
363
364// ServerConfig generates a tls.Config object for use by an HTTP server.
365func (info TLSInfo) ServerConfig() (*tls.Config, error) {
366	cfg, err := info.baseConfig()
367	if err != nil {
368		return nil, err
369	}
370
371	cfg.ClientAuth = tls.NoClientCert
372	if info.TrustedCAFile != "" || info.ClientCertAuth {
373		cfg.ClientAuth = tls.RequireAndVerifyClientCert
374	}
375
376	cs := info.cafiles()
377	if len(cs) > 0 {
378		cp, err := tlsutil.NewCertPool(cs)
379		if err != nil {
380			return nil, err
381		}
382		cfg.ClientCAs = cp
383	}
384
385	// "h2" NextProtos is necessary for enabling HTTP2 for go's HTTP server
386	cfg.NextProtos = []string{"h2"}
387
388	return cfg, nil
389}
390
391// ClientConfig generates a tls.Config object for use by an HTTP client.
392func (info TLSInfo) ClientConfig() (*tls.Config, error) {
393	var cfg *tls.Config
394	var err error
395
396	if !info.Empty() {
397		cfg, err = info.baseConfig()
398		if err != nil {
399			return nil, err
400		}
401	} else {
402		cfg = &tls.Config{ServerName: info.ServerName}
403	}
404	cfg.InsecureSkipVerify = info.InsecureSkipVerify
405
406	cs := info.cafiles()
407	if len(cs) > 0 {
408		cfg.RootCAs, err = tlsutil.NewCertPool(cs)
409		if err != nil {
410			return nil, err
411		}
412	}
413
414	if info.selfCert {
415		cfg.InsecureSkipVerify = true
416	}
417
418	if info.EmptyCN {
419		hasNonEmptyCN := false
420		cn := ""
421		tlsutil.NewCert(info.CertFile, info.KeyFile, func(certPEMBlock []byte, keyPEMBlock []byte) (tls.Certificate, error) {
422			var block *pem.Block
423			block, _ = pem.Decode(certPEMBlock)
424			cert, err := x509.ParseCertificate(block.Bytes)
425			if err != nil {
426				return tls.Certificate{}, err
427			}
428			if len(cert.Subject.CommonName) != 0 {
429				hasNonEmptyCN = true
430				cn = cert.Subject.CommonName
431			}
432			return tls.X509KeyPair(certPEMBlock, keyPEMBlock)
433		})
434		if hasNonEmptyCN {
435			return nil, fmt.Errorf("cert has non empty Common Name (%s)", cn)
436		}
437	}
438
439	return cfg, nil
440}
441
// IsClosedConnError reports whether err stems from a closed listener or
// connection (including cmux). The check is textual: any non-nil error whose
// message contains "closed" matches.
// copied from golang.org/x/net/http2/http2.go
func IsClosedConnError(err error) bool {
	if err == nil {
		return false
	}
	// Messages this is meant to catch:
	//   'use of closed network connection' (Go <=1.8)
	//   'use of closed file or network connection' (Go >1.8, internal/poll.ErrClosing)
	//   'mux: listener closed' (cmux.ErrListenerClosed)
	return strings.Contains(err.Error(), "closed")
}
450