1// Copyright 2015 The etcd Authors 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14 15package transport 16 17import ( 18 "crypto/ecdsa" 19 "crypto/elliptic" 20 "crypto/rand" 21 "crypto/tls" 22 "crypto/x509" 23 "crypto/x509/pkix" 24 "encoding/pem" 25 "errors" 26 "fmt" 27 "math/big" 28 "net" 29 "os" 30 "path/filepath" 31 "strings" 32 "time" 33 34 "github.com/coreos/etcd/pkg/tlsutil" 35) 36 37func NewListener(addr, scheme string, tlsinfo *TLSInfo) (l net.Listener, err error) { 38 if l, err = newListener(addr, scheme); err != nil { 39 return nil, err 40 } 41 return wrapTLS(addr, scheme, tlsinfo, l) 42} 43 44func newListener(addr string, scheme string) (net.Listener, error) { 45 if scheme == "unix" || scheme == "unixs" { 46 // unix sockets via unix://laddr 47 return NewUnixListener(addr) 48 } 49 return net.Listen("tcp", addr) 50} 51 52func wrapTLS(addr, scheme string, tlsinfo *TLSInfo, l net.Listener) (net.Listener, error) { 53 if scheme != "https" && scheme != "unixs" { 54 return l, nil 55 } 56 return newTLSListener(l, tlsinfo, checkSAN) 57} 58 59type TLSInfo struct { 60 CertFile string 61 KeyFile string 62 CAFile string // TODO: deprecate this in v4 63 TrustedCAFile string 64 ClientCertAuth bool 65 CRLFile string 66 InsecureSkipVerify bool 67 68 // ServerName ensures the cert matches the given host in case of discovery / virtual hosting 69 ServerName string 70 71 // HandshakeFailure is optionally called when a connection fails to handshake. The 72 // connection will be closed immediately afterwards. 73 HandshakeFailure func(*tls.Conn, error) 74 75 // CipherSuites is a list of supported cipher suites. 76 // If empty, Go auto-populates it by default. 77 // Note that cipher suites are prioritized in the given order. 78 CipherSuites []uint16 79 80 selfCert bool 81 82 // parseFunc exists to simplify testing. Typically, parseFunc 83 // should be left nil. In that case, tls.X509KeyPair will be used. 84 parseFunc func([]byte, []byte) (tls.Certificate, error) 85 86 // AllowedCN is a CN which must be provided by a client. 87 AllowedCN string 88} 89 90func (info TLSInfo) String() string { 91 return fmt.Sprintf("cert = %s, key = %s, ca = %s, trusted-ca = %s, client-cert-auth = %v, crl-file = %s", info.CertFile, info.KeyFile, info.CAFile, info.TrustedCAFile, info.ClientCertAuth, info.CRLFile) 92} 93 94func (info TLSInfo) Empty() bool { 95 return info.CertFile == "" && info.KeyFile == "" 96} 97 98func SelfCert(dirpath string, hosts []string) (info TLSInfo, err error) { 99 if err = os.MkdirAll(dirpath, 0700); err != nil { 100 return 101 } 102 103 certPath := filepath.Join(dirpath, "cert.pem") 104 keyPath := filepath.Join(dirpath, "key.pem") 105 _, errcert := os.Stat(certPath) 106 _, errkey := os.Stat(keyPath) 107 if errcert == nil && errkey == nil { 108 info.CertFile = certPath 109 info.KeyFile = keyPath 110 info.selfCert = true 111 return 112 } 113 114 serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) 115 serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) 116 if err != nil { 117 return 118 } 119 120 tmpl := x509.Certificate{ 121 SerialNumber: serialNumber, 122 Subject: pkix.Name{Organization: []string{"etcd"}}, 123 NotBefore: time.Now(), 124 NotAfter: time.Now().Add(365 * (24 * time.Hour)), 125 126 KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, 127 ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, 128 BasicConstraintsValid: true, 129 } 130 131 for _, host := range hosts { 132 h, _, _ := net.SplitHostPort(host) 133 if ip := net.ParseIP(h); ip != nil { 134 tmpl.IPAddresses = append(tmpl.IPAddresses, ip) 135 } else { 136 tmpl.DNSNames = append(tmpl.DNSNames, h) 137 } 138 } 139 140 priv, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader) 141 if err != nil { 142 return 143 } 144 145 derBytes, err := x509.CreateCertificate(rand.Reader, &tmpl, &tmpl, &priv.PublicKey, priv) 146 if err != nil { 147 return 148 } 149 150 certOut, err := os.Create(certPath) 151 if err != nil { 152 return 153 } 154 pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) 155 certOut.Close() 156 157 b, err := x509.MarshalECPrivateKey(priv) 158 if err != nil { 159 return 160 } 161 keyOut, err := os.OpenFile(keyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) 162 if err != nil { 163 return 164 } 165 pem.Encode(keyOut, &pem.Block{Type: "EC PRIVATE KEY", Bytes: b}) 166 keyOut.Close() 167 168 return SelfCert(dirpath, hosts) 169} 170 171func (info TLSInfo) baseConfig() (*tls.Config, error) { 172 if info.KeyFile == "" || info.CertFile == "" { 173 return nil, fmt.Errorf("KeyFile and CertFile must both be present[key: %v, cert: %v]", info.KeyFile, info.CertFile) 174 } 175 176 _, err := tlsutil.NewCert(info.CertFile, info.KeyFile, info.parseFunc) 177 if err != nil { 178 return nil, err 179 } 180 181 cfg := &tls.Config{ 182 MinVersion: tls.VersionTLS12, 183 ServerName: info.ServerName, 184 } 185 186 if len(info.CipherSuites) > 0 { 187 cfg.CipherSuites = info.CipherSuites 188 } 189 190 if info.AllowedCN != "" { 191 cfg.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { 192 for _, chains := range verifiedChains { 193 if len(chains) != 0 { 194 if info.AllowedCN == chains[0].Subject.CommonName { 195 return nil 196 } 197 } 198 } 199 return errors.New("CommonName authentication failed") 200 } 201 } 202 203 // this only reloads certs when there's a client request 204 // TODO: support server-side refresh (e.g. inotify, SIGHUP), caching 205 cfg.GetCertificate = func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { 206 return tlsutil.NewCert(info.CertFile, info.KeyFile, info.parseFunc) 207 } 208 cfg.GetClientCertificate = func(unused *tls.CertificateRequestInfo) (*tls.Certificate, error) { 209 return tlsutil.NewCert(info.CertFile, info.KeyFile, info.parseFunc) 210 } 211 return cfg, nil 212} 213 214// cafiles returns a list of CA file paths. 215func (info TLSInfo) cafiles() []string { 216 cs := make([]string, 0) 217 if info.CAFile != "" { 218 cs = append(cs, info.CAFile) 219 } 220 if info.TrustedCAFile != "" { 221 cs = append(cs, info.TrustedCAFile) 222 } 223 return cs 224} 225 226// ServerConfig generates a tls.Config object for use by an HTTP server. 227func (info TLSInfo) ServerConfig() (*tls.Config, error) { 228 cfg, err := info.baseConfig() 229 if err != nil { 230 return nil, err 231 } 232 233 cfg.ClientAuth = tls.NoClientCert 234 if info.CAFile != "" || info.ClientCertAuth { 235 cfg.ClientAuth = tls.RequireAndVerifyClientCert 236 } 237 238 CAFiles := info.cafiles() 239 if len(CAFiles) > 0 { 240 cp, err := tlsutil.NewCertPool(CAFiles) 241 if err != nil { 242 return nil, err 243 } 244 cfg.ClientCAs = cp 245 } 246 247 // "h2" NextProtos is necessary for enabling HTTP2 for go's HTTP server 248 cfg.NextProtos = []string{"h2"} 249 250 return cfg, nil 251} 252 253// ClientConfig generates a tls.Config object for use by an HTTP client. 254func (info TLSInfo) ClientConfig() (*tls.Config, error) { 255 var cfg *tls.Config 256 var err error 257 258 if !info.Empty() { 259 cfg, err = info.baseConfig() 260 if err != nil { 261 return nil, err 262 } 263 } else { 264 cfg = &tls.Config{ServerName: info.ServerName} 265 } 266 cfg.InsecureSkipVerify = info.InsecureSkipVerify 267 268 CAFiles := info.cafiles() 269 if len(CAFiles) > 0 { 270 cfg.RootCAs, err = tlsutil.NewCertPool(CAFiles) 271 if err != nil { 272 return nil, err 273 } 274 } 275 276 if info.selfCert { 277 cfg.InsecureSkipVerify = true 278 } 279 return cfg, nil 280} 281 282// IsClosedConnError returns true if the error is from closing listener, cmux. 283// copied from golang.org/x/net/http2/http2.go 284func IsClosedConnError(err error) bool { 285 // 'use of closed network connection' (Go <=1.8) 286 // 'use of closed file or network connection' (Go >1.8, internal/poll.ErrClosing) 287 // 'mux: listener closed' (cmux.ErrListenerClosed) 288 return err != nil && strings.Contains(err.Error(), "closed") 289} 290