1package raft 2 3import ( 4 "bytes" 5 "context" 6 "crypto/ecdsa" 7 "crypto/elliptic" 8 "crypto/rand" 9 "crypto/tls" 10 "crypto/x509" 11 "crypto/x509/pkix" 12 "errors" 13 fmt "fmt" 14 "math/big" 15 mathrand "math/rand" 16 "net" 17 "sync" 18 "time" 19 20 "github.com/hashicorp/errwrap" 21 log "github.com/hashicorp/go-hclog" 22 uuid "github.com/hashicorp/go-uuid" 23 "github.com/hashicorp/raft" 24 "github.com/hashicorp/vault/sdk/helper/certutil" 25 "github.com/hashicorp/vault/sdk/helper/consts" 26 "github.com/hashicorp/vault/vault/cluster" 27) 28 29// RaftTLSKey is a single TLS keypair in the Keyring 30type RaftTLSKey struct { 31 // ID is a unique identifier for this Key 32 ID string `json:"id"` 33 34 // KeyType defines the algorighm used to generate the private keys 35 KeyType string `json:"key_type"` 36 37 // AppliedIndex is the earliest known raft index that safely contains this 38 // key. 39 AppliedIndex uint64 `json:"applied_index"` 40 41 // CertBytes is the marshaled certificate. 42 CertBytes []byte `json:"cluster_cert"` 43 44 // KeyParams is the marshaled private key. 45 KeyParams *certutil.ClusterKeyParams `json:"cluster_key_params"` 46 47 // CreatedTime is the time this key was generated. This value is useful in 48 // determining when the next rotation should be. 49 CreatedTime time.Time `json:"created_time"` 50 51 parsedCert *x509.Certificate 52 parsedKey *ecdsa.PrivateKey 53} 54 55// RaftTLSKeyring is the set of keys that raft uses for network communication. 56// Only one key is used to dial at a time but both keys will be used to accept 57// connections. 58type RaftTLSKeyring struct { 59 // Keys is the set of available key pairs 60 Keys []*RaftTLSKey `json:"keys"` 61 62 // AppliedIndex is the earliest known raft index that safely contains the 63 // latest key in the keyring. 64 AppliedIndex uint64 `json:"applied_index"` 65 66 // Term is an incrementing identifier value used to quickly determine if two 67 // states of the keyring are different. 68 Term uint64 `json:"term"` 69 70 // ActiveKeyID is the key ID to track the active key in the keyring. Only 71 // the active key is used for dialing. 72 ActiveKeyID string `json:"active_key_id"` 73} 74 75// GetActive returns the active key. 76func (k *RaftTLSKeyring) GetActive() *RaftTLSKey { 77 if k.ActiveKeyID == "" { 78 return nil 79 } 80 81 for _, key := range k.Keys { 82 if key.ID == k.ActiveKeyID { 83 return key 84 } 85 } 86 return nil 87} 88 89func GenerateTLSKey() (*RaftTLSKey, error) { 90 key, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader) 91 if err != nil { 92 return nil, err 93 } 94 95 host, err := uuid.GenerateUUID() 96 if err != nil { 97 return nil, err 98 } 99 host = fmt.Sprintf("raft-%s", host) 100 template := &x509.Certificate{ 101 Subject: pkix.Name{ 102 CommonName: host, 103 }, 104 DNSNames: []string{host}, 105 ExtKeyUsage: []x509.ExtKeyUsage{ 106 x509.ExtKeyUsageServerAuth, 107 x509.ExtKeyUsageClientAuth, 108 }, 109 KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageKeyAgreement | x509.KeyUsageCertSign, 110 SerialNumber: big.NewInt(mathrand.Int63()), 111 NotBefore: time.Now().Add(-30 * time.Second), 112 // 30 years of single-active uptime ought to be enough for anybody 113 NotAfter: time.Now().Add(262980 * time.Hour), 114 BasicConstraintsValid: true, 115 IsCA: true, 116 } 117 118 certBytes, err := x509.CreateCertificate(rand.Reader, template, template, key.Public(), key) 119 if err != nil { 120 return nil, errwrap.Wrapf("unable to generate local cluster certificate: {{err}}", err) 121 } 122 123 return &RaftTLSKey{ 124 ID: host, 125 KeyType: certutil.PrivateKeyTypeP521, 126 CertBytes: certBytes, 127 KeyParams: &certutil.ClusterKeyParams{ 128 Type: certutil.PrivateKeyTypeP521, 129 X: key.PublicKey.X, 130 Y: key.PublicKey.Y, 131 D: key.D, 132 }, 133 CreatedTime: time.Now(), 134 }, nil 135} 136 137// Make sure raftLayer satisfies the raft.StreamLayer interface 138var _ raft.StreamLayer = (*raftLayer)(nil) 139 140// Make sure raftLayer satisfies the cluster.Handler and cluster.Client 141// interfaces 142var _ cluster.Handler = (*raftLayer)(nil) 143var _ cluster.Client = (*raftLayer)(nil) 144 145// RaftLayer implements the raft.StreamLayer interface, 146// so that we can use a single RPC layer for Raft and Vault 147type raftLayer struct { 148 // Addr is the listener address to return 149 addr net.Addr 150 151 // connCh is used to accept connections 152 connCh chan net.Conn 153 154 // Tracks if we are closed 155 closed bool 156 closeCh chan struct{} 157 closeLock sync.Mutex 158 159 logger log.Logger 160 161 dialerFunc func(string, time.Duration) (net.Conn, error) 162 163 // TLS config 164 keyring *RaftTLSKeyring 165 baseTLSConfig *tls.Config 166} 167 168// NewRaftLayer creates a new raftLayer object. It parses the TLS information 169// from the network config. 170func NewRaftLayer(logger log.Logger, raftTLSKeyring *RaftTLSKeyring, clusterAddr net.Addr, baseTLSConfig *tls.Config) (*raftLayer, error) { 171 switch { 172 case clusterAddr == nil: 173 // Clustering disabled on the server, don't try to look for params 174 return nil, errors.New("no raft addr found") 175 } 176 177 layer := &raftLayer{ 178 addr: clusterAddr, 179 connCh: make(chan net.Conn), 180 closeCh: make(chan struct{}), 181 logger: logger, 182 baseTLSConfig: baseTLSConfig, 183 } 184 185 if err := layer.setTLSKeyring(raftTLSKeyring); err != nil { 186 return nil, err 187 } 188 189 return layer, nil 190} 191 192func (l *raftLayer) setTLSKeyring(keyring *RaftTLSKeyring) error { 193 // Fast path a noop update 194 if l.keyring != nil && l.keyring.Term == keyring.Term { 195 return nil 196 } 197 198 for _, key := range keyring.Keys { 199 switch { 200 case key.KeyParams == nil: 201 return errors.New("no raft cluster key params found") 202 203 case key.KeyParams.X == nil, key.KeyParams.Y == nil, key.KeyParams.D == nil: 204 return errors.New("failed to parse raft cluster key") 205 206 case key.KeyParams.Type != certutil.PrivateKeyTypeP521: 207 return errors.New("failed to find valid raft cluster key type") 208 209 case len(key.CertBytes) == 0: 210 return errors.New("no cluster cert found") 211 } 212 213 parsedCert, err := x509.ParseCertificate(key.CertBytes) 214 if err != nil { 215 return errwrap.Wrapf("error parsing raft cluster certificate: {{err}}", err) 216 } 217 218 key.parsedCert = parsedCert 219 key.parsedKey = &ecdsa.PrivateKey{ 220 PublicKey: ecdsa.PublicKey{ 221 Curve: elliptic.P521(), 222 X: key.KeyParams.X, 223 Y: key.KeyParams.Y, 224 }, 225 D: key.KeyParams.D, 226 } 227 } 228 229 if keyring.GetActive() == nil { 230 return errors.New("expected one active key to be present in the keyring") 231 } 232 233 l.keyring = keyring 234 235 return nil 236} 237 238func (l *raftLayer) ClientLookup(ctx context.Context, requestInfo *tls.CertificateRequestInfo) (*tls.Certificate, error) { 239 for _, subj := range requestInfo.AcceptableCAs { 240 for _, key := range l.keyring.Keys { 241 if bytes.Equal(subj, key.parsedCert.RawIssuer) { 242 localCert := make([]byte, len(key.CertBytes)) 243 copy(localCert, key.CertBytes) 244 245 return &tls.Certificate{ 246 Certificate: [][]byte{localCert}, 247 PrivateKey: key.parsedKey, 248 Leaf: key.parsedCert, 249 }, nil 250 } 251 } 252 } 253 254 return nil, nil 255} 256 257func (l *raftLayer) ServerLookup(ctx context.Context, clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { 258 if l.keyring == nil { 259 return nil, errors.New("got raft connection but no local cert") 260 } 261 262 for _, key := range l.keyring.Keys { 263 if clientHello.ServerName == key.ID { 264 localCert := make([]byte, len(key.CertBytes)) 265 copy(localCert, key.CertBytes) 266 267 return &tls.Certificate{ 268 Certificate: [][]byte{localCert}, 269 PrivateKey: key.parsedKey, 270 Leaf: key.parsedCert, 271 }, nil 272 } 273 } 274 275 return nil, nil 276} 277 278// CALookup returns the CA to use when validating this connection. 279func (l *raftLayer) CALookup(context.Context) ([]*x509.Certificate, error) { 280 ret := make([]*x509.Certificate, len(l.keyring.Keys)) 281 for i, key := range l.keyring.Keys { 282 ret[i] = key.parsedCert 283 } 284 return ret, nil 285} 286 287// Stop shutsdown the raft layer. 288func (l *raftLayer) Stop() error { 289 l.Close() 290 return nil 291} 292 293// Handoff is used to hand off a connection to the 294// RaftLayer. This allows it to be Accept()'ed 295func (l *raftLayer) Handoff(ctx context.Context, wg *sync.WaitGroup, quit chan struct{}, conn *tls.Conn) error { 296 l.closeLock.Lock() 297 closed := l.closed 298 l.closeLock.Unlock() 299 300 if closed { 301 return errors.New("raft is shutdown") 302 } 303 304 wg.Add(1) 305 go func() { 306 defer wg.Done() 307 select { 308 case l.connCh <- conn: 309 case <-l.closeCh: 310 case <-ctx.Done(): 311 case <-quit: 312 } 313 }() 314 315 return nil 316} 317 318// Accept is used to return connection which are 319// dialed to be used with the Raft layer 320func (l *raftLayer) Accept() (net.Conn, error) { 321 select { 322 case conn := <-l.connCh: 323 return conn, nil 324 case <-l.closeCh: 325 return nil, fmt.Errorf("Raft RPC layer closed") 326 } 327} 328 329// Close is used to stop listening for Raft connections 330func (l *raftLayer) Close() error { 331 l.closeLock.Lock() 332 defer l.closeLock.Unlock() 333 334 if !l.closed { 335 l.closed = true 336 close(l.closeCh) 337 } 338 return nil 339} 340 341// Addr is used to return the address of the listener 342func (l *raftLayer) Addr() net.Addr { 343 return l.addr 344} 345 346// Dial is used to create a new outgoing connection 347func (l *raftLayer) Dial(address raft.ServerAddress, timeout time.Duration) (net.Conn, error) { 348 349 tlsConfig := l.baseTLSConfig.Clone() 350 351 key := l.keyring.GetActive() 352 if key == nil { 353 return nil, errors.New("no active key") 354 } 355 356 tlsConfig.NextProtos = []string{consts.RaftStorageALPN} 357 tlsConfig.ServerName = key.parsedCert.Subject.CommonName 358 359 l.logger.Debug("creating rpc dialer", "host", tlsConfig.ServerName) 360 361 pool := x509.NewCertPool() 362 pool.AddCert(key.parsedCert) 363 tlsConfig.RootCAs = pool 364 tlsConfig.ClientCAs = pool 365 366 dialer := &net.Dialer{ 367 Timeout: timeout, 368 } 369 return tls.DialWithDialer(dialer, "tcp", string(address), tlsConfig) 370} 371