1// Copyright 2015 The etcd Authors
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package transport
16
17import (
18	"crypto/ecdsa"
19	"crypto/elliptic"
20	"crypto/rand"
21	"crypto/tls"
22	"crypto/x509"
23	"crypto/x509/pkix"
24	"encoding/pem"
25	"errors"
26	"fmt"
27	"math/big"
28	"net"
29	"os"
30	"path/filepath"
31	"strings"
32	"time"
33
34	"github.com/coreos/etcd/pkg/tlsutil"
35)
36
37func NewListener(addr, scheme string, tlsinfo *TLSInfo) (l net.Listener, err error) {
38	if l, err = newListener(addr, scheme); err != nil {
39		return nil, err
40	}
41	return wrapTLS(addr, scheme, tlsinfo, l)
42}
43
44func newListener(addr string, scheme string) (net.Listener, error) {
45	if scheme == "unix" || scheme == "unixs" {
46		// unix sockets via unix://laddr
47		return NewUnixListener(addr)
48	}
49	return net.Listen("tcp", addr)
50}
51
52func wrapTLS(addr, scheme string, tlsinfo *TLSInfo, l net.Listener) (net.Listener, error) {
53	if scheme != "https" && scheme != "unixs" {
54		return l, nil
55	}
56	return newTLSListener(l, tlsinfo, checkSAN)
57}
58
// TLSInfo describes the certificate files and TLS options used to build the
// server and client *tls.Config values (ServerConfig, ClientConfig) and the
// TLS listeners created for the "https" and "unixs" schemes.
type TLSInfo struct {
	// CertFile and KeyFile are paths to the certificate and private key this
	// side presents. They are loaded with tlsutil.NewCert once up front to
	// validate them and then re-read on every handshake. ServerConfig requires
	// both; ClientConfig also accepts both empty (no client certificate).
	// CAFile and TrustedCAFile are CA bundles used to verify the peer: client
	// CAs in ServerConfig, root CAs in ClientConfig. A non-empty CAFile also
	// makes ServerConfig require client certificates.
	// ClientCertAuth makes ServerConfig require and verify client certificates.
	// CRLFile is a certificate revocation list path. Nothing in this file reads
	// it besides String(); presumably consumed by the TLS listener — TODO confirm.
	// InsecureSkipVerify is copied into the ClientConfig result.
	CertFile           string
	KeyFile            string
	CAFile             string // TODO: deprecate this in v4
	TrustedCAFile      string
	ClientCertAuth     bool
	CRLFile            string
	InsecureSkipVerify bool

	// ServerName ensures the cert matches the given host in case of discovery / virtual hosting
	ServerName string

	// HandshakeFailure is optionally called when a connection fails to handshake. The
	// connection will be closed immediately afterwards.
	HandshakeFailure func(*tls.Conn, error)

	// CipherSuites is a list of supported cipher suites.
	// If empty, Go auto-populates it by default.
	// Note that cipher suites are prioritized in the given order.
	CipherSuites []uint16

	// selfCert is set by SelfCert. It makes ClientConfig skip server
	// certificate verification, because a self-signed certificate has no CA
	// to chain to.
	selfCert bool

	// parseFunc exists to simplify testing. Typically, parseFunc
	// should be left nil. In that case, tls.X509KeyPair will be used.
	parseFunc func([]byte, []byte) (tls.Certificate, error)

	// AllowedCN, when non-empty, is the Subject CommonName a peer certificate
	// must carry: baseConfig rejects the handshake unless the leaf certificate
	// of some verified chain has exactly this CN. baseConfig is shared by
	// ServerConfig and ClientConfig, so the check is installed on both sides.
	AllowedCN string
}
89
90func (info TLSInfo) String() string {
91	return fmt.Sprintf("cert = %s, key = %s, ca = %s, trusted-ca = %s, client-cert-auth = %v, crl-file = %s", info.CertFile, info.KeyFile, info.CAFile, info.TrustedCAFile, info.ClientCertAuth, info.CRLFile)
92}
93
94func (info TLSInfo) Empty() bool {
95	return info.CertFile == "" && info.KeyFile == ""
96}
97
98func SelfCert(dirpath string, hosts []string) (info TLSInfo, err error) {
99	if err = os.MkdirAll(dirpath, 0700); err != nil {
100		return
101	}
102
103	certPath := filepath.Join(dirpath, "cert.pem")
104	keyPath := filepath.Join(dirpath, "key.pem")
105	_, errcert := os.Stat(certPath)
106	_, errkey := os.Stat(keyPath)
107	if errcert == nil && errkey == nil {
108		info.CertFile = certPath
109		info.KeyFile = keyPath
110		info.selfCert = true
111		return
112	}
113
114	serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
115	serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
116	if err != nil {
117		return
118	}
119
120	tmpl := x509.Certificate{
121		SerialNumber: serialNumber,
122		Subject:      pkix.Name{Organization: []string{"etcd"}},
123		NotBefore:    time.Now(),
124		NotAfter:     time.Now().Add(365 * (24 * time.Hour)),
125
126		KeyUsage:              x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
127		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
128		BasicConstraintsValid: true,
129	}
130
131	for _, host := range hosts {
132		h, _, _ := net.SplitHostPort(host)
133		if ip := net.ParseIP(h); ip != nil {
134			tmpl.IPAddresses = append(tmpl.IPAddresses, ip)
135		} else {
136			tmpl.DNSNames = append(tmpl.DNSNames, h)
137		}
138	}
139
140	priv, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader)
141	if err != nil {
142		return
143	}
144
145	derBytes, err := x509.CreateCertificate(rand.Reader, &tmpl, &tmpl, &priv.PublicKey, priv)
146	if err != nil {
147		return
148	}
149
150	certOut, err := os.Create(certPath)
151	if err != nil {
152		return
153	}
154	pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
155	certOut.Close()
156
157	b, err := x509.MarshalECPrivateKey(priv)
158	if err != nil {
159		return
160	}
161	keyOut, err := os.OpenFile(keyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
162	if err != nil {
163		return
164	}
165	pem.Encode(keyOut, &pem.Block{Type: "EC PRIVATE KEY", Bytes: b})
166	keyOut.Close()
167
168	return SelfCert(dirpath, hosts)
169}
170
171func (info TLSInfo) baseConfig() (*tls.Config, error) {
172	if info.KeyFile == "" || info.CertFile == "" {
173		return nil, fmt.Errorf("KeyFile and CertFile must both be present[key: %v, cert: %v]", info.KeyFile, info.CertFile)
174	}
175
176	_, err := tlsutil.NewCert(info.CertFile, info.KeyFile, info.parseFunc)
177	if err != nil {
178		return nil, err
179	}
180
181	cfg := &tls.Config{
182		MinVersion: tls.VersionTLS12,
183		ServerName: info.ServerName,
184	}
185
186	if len(info.CipherSuites) > 0 {
187		cfg.CipherSuites = info.CipherSuites
188	}
189
190	if info.AllowedCN != "" {
191		cfg.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
192			for _, chains := range verifiedChains {
193				if len(chains) != 0 {
194					if info.AllowedCN == chains[0].Subject.CommonName {
195						return nil
196					}
197				}
198			}
199			return errors.New("CommonName authentication failed")
200		}
201	}
202
203	// this only reloads certs when there's a client request
204	// TODO: support server-side refresh (e.g. inotify, SIGHUP), caching
205	cfg.GetCertificate = func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
206		return tlsutil.NewCert(info.CertFile, info.KeyFile, info.parseFunc)
207	}
208	cfg.GetClientCertificate = func(unused *tls.CertificateRequestInfo) (*tls.Certificate, error) {
209		return tlsutil.NewCert(info.CertFile, info.KeyFile, info.parseFunc)
210	}
211	return cfg, nil
212}
213
214// cafiles returns a list of CA file paths.
215func (info TLSInfo) cafiles() []string {
216	cs := make([]string, 0)
217	if info.CAFile != "" {
218		cs = append(cs, info.CAFile)
219	}
220	if info.TrustedCAFile != "" {
221		cs = append(cs, info.TrustedCAFile)
222	}
223	return cs
224}
225
226// ServerConfig generates a tls.Config object for use by an HTTP server.
227func (info TLSInfo) ServerConfig() (*tls.Config, error) {
228	cfg, err := info.baseConfig()
229	if err != nil {
230		return nil, err
231	}
232
233	cfg.ClientAuth = tls.NoClientCert
234	if info.CAFile != "" || info.ClientCertAuth {
235		cfg.ClientAuth = tls.RequireAndVerifyClientCert
236	}
237
238	CAFiles := info.cafiles()
239	if len(CAFiles) > 0 {
240		cp, err := tlsutil.NewCertPool(CAFiles)
241		if err != nil {
242			return nil, err
243		}
244		cfg.ClientCAs = cp
245	}
246
247	// "h2" NextProtos is necessary for enabling HTTP2 for go's HTTP server
248	cfg.NextProtos = []string{"h2"}
249
250	return cfg, nil
251}
252
253// ClientConfig generates a tls.Config object for use by an HTTP client.
254func (info TLSInfo) ClientConfig() (*tls.Config, error) {
255	var cfg *tls.Config
256	var err error
257
258	if !info.Empty() {
259		cfg, err = info.baseConfig()
260		if err != nil {
261			return nil, err
262		}
263	} else {
264		cfg = &tls.Config{ServerName: info.ServerName}
265	}
266	cfg.InsecureSkipVerify = info.InsecureSkipVerify
267
268	CAFiles := info.cafiles()
269	if len(CAFiles) > 0 {
270		cfg.RootCAs, err = tlsutil.NewCertPool(CAFiles)
271		if err != nil {
272			return nil, err
273		}
274	}
275
276	if info.selfCert {
277		cfg.InsecureSkipVerify = true
278	}
279	return cfg, nil
280}
281
// IsClosedConnError reports whether err looks like it came from using a
// closed listener or connection, including cmux's closed-listener error.
// Copied from golang.org/x/net/http2/http2.go.
//
// It is a plain substring match on "closed", because the messages it targets
// come from different sources:
//   - "use of closed network connection" (Go <= 1.8)
//   - "use of closed file or network connection" (Go > 1.8, internal/poll.ErrClosing)
//   - "mux: listener closed" (cmux.ErrListenerClosed)
//
// Any other error whose text contains "closed" also matches; nil never does.
func IsClosedConnError(err error) bool {
	if err == nil {
		return false
	}
	return strings.Contains(err.Error(), "closed")
}
290