1// Copyright 2015 The etcd Authors
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package transport
16
17import (
18	"crypto/ecdsa"
19	"crypto/elliptic"
20	"crypto/rand"
21	"crypto/tls"
22	"crypto/x509"
23	"crypto/x509/pkix"
24	"encoding/pem"
25	"errors"
26	"fmt"
27	"math/big"
28	"net"
29	"os"
30	"path/filepath"
31	"strings"
32	"time"
33
34	"go.etcd.io/etcd/pkg/tlsutil"
35
36	"go.uber.org/zap"
37)
38
39// NewListener creates a new listner.
40func NewListener(addr, scheme string, tlsinfo *TLSInfo) (l net.Listener, err error) {
41	if l, err = newListener(addr, scheme); err != nil {
42		return nil, err
43	}
44	return wrapTLS(scheme, tlsinfo, l)
45}
46
47func newListener(addr string, scheme string) (net.Listener, error) {
48	if scheme == "unix" || scheme == "unixs" {
49		// unix sockets via unix://laddr
50		return NewUnixListener(addr)
51	}
52	return net.Listen("tcp", addr)
53}
54
55func wrapTLS(scheme string, tlsinfo *TLSInfo, l net.Listener) (net.Listener, error) {
56	if scheme != "https" && scheme != "unixs" {
57		return l, nil
58	}
59	return newTLSListener(l, tlsinfo, checkSAN)
60}
61
// TLSInfo describes the TLS assets and options from which ServerConfig and
// ClientConfig build a *tls.Config. When a cert/key pair is configured, the
// generated config re-reads it from disk on every handshake through its
// GetCertificate/GetClientCertificate callbacks.
type TLSInfo struct {
	// CertFile and KeyFile are the paths of the certificate and private key,
	// loaded with tlsutil.NewCert. Both must be set for ServerConfig. When both
	// are empty, ClientConfig builds a config without a client certificate.
	//
	// TrustedCAFile is the CA bundle path. ServerConfig uses it as ClientCAs and
	// ClientConfig uses it as RootCAs. When it is set, ServerConfig also
	// requires and verifies client certificates.
	//
	// ClientCertAuth makes ServerConfig require and verify client certificates.
	//
	// CRLFile is reported by String but is not consumed by the config builders
	// in this file. NOTE(review): presumably a revocation list applied
	// elsewhere; confirm.
	//
	// InsecureSkipVerify is copied into the config built by ClientConfig.
	// ClientConfig also forces it on for TLSInfo values produced by SelfCert.
	CertFile           string
	KeyFile            string
	TrustedCAFile      string
	ClientCertAuth     bool
	CRLFile            string
	InsecureSkipVerify bool

	// ServerName ensures the cert matches the given host in case of discovery / virtual hosting.
	// It is copied into tls.Config.ServerName.
	ServerName string

	// HandshakeFailure is optionally called when a connection fails to handshake. The
	// connection will be closed immediately afterwards.
	HandshakeFailure func(*tls.Conn, error)

	// CipherSuites is a list of supported cipher suites.
	// If empty, Go auto-populates it by default.
	// Note that cipher suites are prioritized in the given order.
	CipherSuites []uint16

	// selfCert is set by SelfCert. When true, ClientConfig skips server
	// certificate verification (InsecureSkipVerify).
	selfCert bool

	// parseFunc exists to simplify testing. Typically, parseFunc
	// should be left nil. In that case, tls.X509KeyPair will be used.
	parseFunc func([]byte, []byte) (tls.Certificate, error)

	// AllowedCN, when non-empty, makes baseConfig install a VerifyPeerCertificate
	// hook on both server and client configs. The hook accepts the peer only if
	// the leaf of some verified chain has exactly this CommonName.
	AllowedCN string

	// Logger logs TLS errors.
	// If nil, all logs are discarded.
	Logger *zap.Logger

	// EmptyCN indicates that the cert must have empty CN.
	// If true, ClientConfig() will return an error for a cert with non empty CN.
	EmptyCN bool
}
99
100func (info TLSInfo) String() string {
101	return fmt.Sprintf("cert = %s, key = %s, trusted-ca = %s, client-cert-auth = %v, crl-file = %s", info.CertFile, info.KeyFile, info.TrustedCAFile, info.ClientCertAuth, info.CRLFile)
102}
103
104func (info TLSInfo) Empty() bool {
105	return info.CertFile == "" && info.KeyFile == ""
106}
107
108func SelfCert(lg *zap.Logger, dirpath string, hosts []string) (info TLSInfo, err error) {
109	if err = os.MkdirAll(dirpath, 0700); err != nil {
110		return
111	}
112	info.Logger = lg
113
114	certPath := filepath.Join(dirpath, "cert.pem")
115	keyPath := filepath.Join(dirpath, "key.pem")
116	_, errcert := os.Stat(certPath)
117	_, errkey := os.Stat(keyPath)
118	if errcert == nil && errkey == nil {
119		info.CertFile = certPath
120		info.KeyFile = keyPath
121		info.selfCert = true
122		return
123	}
124
125	serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
126	serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
127	if err != nil {
128		if info.Logger != nil {
129			info.Logger.Warn(
130				"cannot generate random number",
131				zap.Error(err),
132			)
133		}
134		return
135	}
136
137	tmpl := x509.Certificate{
138		SerialNumber: serialNumber,
139		Subject:      pkix.Name{Organization: []string{"etcd"}},
140		NotBefore:    time.Now(),
141		NotAfter:     time.Now().Add(365 * (24 * time.Hour)),
142
143		KeyUsage:              x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
144		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
145		BasicConstraintsValid: true,
146	}
147
148	for _, host := range hosts {
149		h, _, _ := net.SplitHostPort(host)
150		if ip := net.ParseIP(h); ip != nil {
151			tmpl.IPAddresses = append(tmpl.IPAddresses, ip)
152		} else {
153			tmpl.DNSNames = append(tmpl.DNSNames, h)
154		}
155	}
156
157	priv, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader)
158	if err != nil {
159		if info.Logger != nil {
160			info.Logger.Warn(
161				"cannot generate ECDSA key",
162				zap.Error(err),
163			)
164		}
165		return
166	}
167
168	derBytes, err := x509.CreateCertificate(rand.Reader, &tmpl, &tmpl, &priv.PublicKey, priv)
169	if err != nil {
170		if info.Logger != nil {
171			info.Logger.Warn(
172				"cannot generate x509 certificate",
173				zap.Error(err),
174			)
175		}
176		return
177	}
178
179	certOut, err := os.Create(certPath)
180	if err != nil {
181		info.Logger.Warn(
182			"cannot cert file",
183			zap.String("path", certPath),
184			zap.Error(err),
185		)
186		return
187	}
188	pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
189	certOut.Close()
190	if info.Logger != nil {
191		info.Logger.Info("created cert file", zap.String("path", certPath))
192	}
193
194	b, err := x509.MarshalECPrivateKey(priv)
195	if err != nil {
196		return
197	}
198	keyOut, err := os.OpenFile(keyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
199	if err != nil {
200		if info.Logger != nil {
201			info.Logger.Warn(
202				"cannot key file",
203				zap.String("path", keyPath),
204				zap.Error(err),
205			)
206		}
207		return
208	}
209	pem.Encode(keyOut, &pem.Block{Type: "EC PRIVATE KEY", Bytes: b})
210	keyOut.Close()
211	if info.Logger != nil {
212		info.Logger.Info("created key file", zap.String("path", keyPath))
213	}
214	return SelfCert(lg, dirpath, hosts)
215}
216
217// baseConfig is called on initial TLS handshake start.
218//
219// Previously,
220// 1. Server has non-empty (*tls.Config).Certificates on client hello
221// 2. Server calls (*tls.Config).GetCertificate iff:
222//    - Server's (*tls.Config).Certificates is not empty, or
223//    - Client supplies SNI; non-empty (*tls.ClientHelloInfo).ServerName
224//
225// When (*tls.Config).Certificates is always populated on initial handshake,
226// client is expected to provide a valid matching SNI to pass the TLS
227// verification, thus trigger server (*tls.Config).GetCertificate to reload
228// TLS assets. However, a cert whose SAN field does not include domain names
229// but only IP addresses, has empty (*tls.ClientHelloInfo).ServerName, thus
230// it was never able to trigger TLS reload on initial handshake; first
231// ceritifcate object was being used, never being updated.
232//
233// Now, (*tls.Config).Certificates is created empty on initial TLS client
234// handshake, in order to trigger (*tls.Config).GetCertificate and populate
235// rest of the certificates on every new TLS connection, even when client
236// SNI is empty (e.g. cert only includes IPs).
237func (info TLSInfo) baseConfig() (*tls.Config, error) {
238	if info.KeyFile == "" || info.CertFile == "" {
239		return nil, fmt.Errorf("KeyFile and CertFile must both be present[key: %v, cert: %v]", info.KeyFile, info.CertFile)
240	}
241	if info.Logger == nil {
242		info.Logger = zap.NewNop()
243	}
244
245	_, err := tlsutil.NewCert(info.CertFile, info.KeyFile, info.parseFunc)
246	if err != nil {
247		return nil, err
248	}
249
250	cfg := &tls.Config{
251		MinVersion: tls.VersionTLS12,
252		ServerName: info.ServerName,
253	}
254
255	if len(info.CipherSuites) > 0 {
256		cfg.CipherSuites = info.CipherSuites
257	}
258
259	if info.AllowedCN != "" {
260		cfg.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
261			for _, chains := range verifiedChains {
262				if len(chains) != 0 {
263					if info.AllowedCN == chains[0].Subject.CommonName {
264						return nil
265					}
266				}
267			}
268			return errors.New("CommonName authentication failed")
269		}
270	}
271
272	// this only reloads certs when there's a client request
273	// TODO: support server-side refresh (e.g. inotify, SIGHUP), caching
274	cfg.GetCertificate = func(clientHello *tls.ClientHelloInfo) (cert *tls.Certificate, err error) {
275		cert, err = tlsutil.NewCert(info.CertFile, info.KeyFile, info.parseFunc)
276		if os.IsNotExist(err) {
277			if info.Logger != nil {
278				info.Logger.Warn(
279					"failed to find peer cert files",
280					zap.String("cert-file", info.CertFile),
281					zap.String("key-file", info.KeyFile),
282					zap.Error(err),
283				)
284			}
285		} else if err != nil {
286			if info.Logger != nil {
287				info.Logger.Warn(
288					"failed to create peer certificate",
289					zap.String("cert-file", info.CertFile),
290					zap.String("key-file", info.KeyFile),
291					zap.Error(err),
292				)
293			}
294		}
295		return cert, err
296	}
297	cfg.GetClientCertificate = func(unused *tls.CertificateRequestInfo) (cert *tls.Certificate, err error) {
298		cert, err = tlsutil.NewCert(info.CertFile, info.KeyFile, info.parseFunc)
299		if os.IsNotExist(err) {
300			if info.Logger != nil {
301				info.Logger.Warn(
302					"failed to find client cert files",
303					zap.String("cert-file", info.CertFile),
304					zap.String("key-file", info.KeyFile),
305					zap.Error(err),
306				)
307			}
308		} else if err != nil {
309			if info.Logger != nil {
310				info.Logger.Warn(
311					"failed to create client certificate",
312					zap.String("cert-file", info.CertFile),
313					zap.String("key-file", info.KeyFile),
314					zap.Error(err),
315				)
316			}
317		}
318		return cert, err
319	}
320	return cfg, nil
321}
322
323// cafiles returns a list of CA file paths.
324func (info TLSInfo) cafiles() []string {
325	cs := make([]string, 0)
326	if info.TrustedCAFile != "" {
327		cs = append(cs, info.TrustedCAFile)
328	}
329	return cs
330}
331
332// ServerConfig generates a tls.Config object for use by an HTTP server.
333func (info TLSInfo) ServerConfig() (*tls.Config, error) {
334	cfg, err := info.baseConfig()
335	if err != nil {
336		return nil, err
337	}
338
339	cfg.ClientAuth = tls.NoClientCert
340	if info.TrustedCAFile != "" || info.ClientCertAuth {
341		cfg.ClientAuth = tls.RequireAndVerifyClientCert
342	}
343
344	cs := info.cafiles()
345	if len(cs) > 0 {
346		cp, err := tlsutil.NewCertPool(cs)
347		if err != nil {
348			return nil, err
349		}
350		cfg.ClientCAs = cp
351	}
352
353	// "h2" NextProtos is necessary for enabling HTTP2 for go's HTTP server
354	cfg.NextProtos = []string{"h2"}
355
356	return cfg, nil
357}
358
359// ClientConfig generates a tls.Config object for use by an HTTP client.
360func (info TLSInfo) ClientConfig() (*tls.Config, error) {
361	var cfg *tls.Config
362	var err error
363
364	if !info.Empty() {
365		cfg, err = info.baseConfig()
366		if err != nil {
367			return nil, err
368		}
369	} else {
370		cfg = &tls.Config{ServerName: info.ServerName}
371	}
372	cfg.InsecureSkipVerify = info.InsecureSkipVerify
373
374	cs := info.cafiles()
375	if len(cs) > 0 {
376		cfg.RootCAs, err = tlsutil.NewCertPool(cs)
377		if err != nil {
378			return nil, err
379		}
380	}
381
382	if info.selfCert {
383		cfg.InsecureSkipVerify = true
384	}
385
386	if info.EmptyCN {
387		hasNonEmptyCN := false
388		cn := ""
389		tlsutil.NewCert(info.CertFile, info.KeyFile, func(certPEMBlock []byte, keyPEMBlock []byte) (tls.Certificate, error) {
390			var block *pem.Block
391			block, _ = pem.Decode(certPEMBlock)
392			cert, err := x509.ParseCertificate(block.Bytes)
393			if err != nil {
394				return tls.Certificate{}, err
395			}
396			if len(cert.Subject.CommonName) != 0 {
397				hasNonEmptyCN = true
398				cn = cert.Subject.CommonName
399			}
400			return tls.X509KeyPair(certPEMBlock, keyPEMBlock)
401		})
402		if hasNonEmptyCN {
403			return nil, fmt.Errorf("cert has non empty Common Name (%s)", cn)
404		}
405	}
406
407	return cfg, nil
408}
409
// IsClosedConnError reports whether err looks like it came from a closed
// connection or listener, including cmux's listener-closed error. It matches
// on the error text only: any non-nil error whose message contains "closed"
// counts.
// copied from golang.org/x/net/http2/http2.go
func IsClosedConnError(err error) bool {
	if err == nil {
		return false
	}
	// Messages this is meant to catch:
	//   'use of closed network connection' (Go <=1.8)
	//   'use of closed file or network connection' (Go >1.8, internal/poll.ErrClosing)
	//   'mux: listener closed' (cmux.ErrListenerClosed)
	msg := err.Error()
	return strings.Contains(msg, "closed")
}
418