1package raft 2 3import ( 4 "bytes" 5 "context" 6 "crypto/ecdsa" 7 "crypto/elliptic" 8 "crypto/rand" 9 "crypto/tls" 10 "crypto/x509" 11 "crypto/x509/pkix" 12 "errors" 13 fmt "fmt" 14 "io" 15 "math/big" 16 mathrand "math/rand" 17 "net" 18 "net/url" 19 "sync" 20 "time" 21 22 log "github.com/hashicorp/go-hclog" 23 uuid "github.com/hashicorp/go-uuid" 24 "github.com/hashicorp/raft" 25 "github.com/hashicorp/vault/sdk/helper/certutil" 26 "github.com/hashicorp/vault/sdk/helper/consts" 27 "github.com/hashicorp/vault/vault/cluster" 28) 29 30// TLSKey is a single TLS keypair in the Keyring 31type TLSKey struct { 32 // ID is a unique identifier for this Key 33 ID string `json:"id"` 34 35 // KeyType defines the algorighm used to generate the private keys 36 KeyType string `json:"key_type"` 37 38 // AppliedIndex is the earliest known raft index that safely contains this 39 // key. 40 AppliedIndex uint64 `json:"applied_index"` 41 42 // CertBytes is the marshaled certificate. 43 CertBytes []byte `json:"cluster_cert"` 44 45 // KeyParams is the marshaled private key. 46 KeyParams *certutil.ClusterKeyParams `json:"cluster_key_params"` 47 48 // CreatedTime is the time this key was generated. This value is useful in 49 // determining when the next rotation should be. 50 CreatedTime time.Time `json:"created_time"` 51 52 parsedCert *x509.Certificate 53 parsedKey *ecdsa.PrivateKey 54} 55 56// TLSKeyring is the set of keys that raft uses for network communication. 57// Only one key is used to dial at a time but both keys will be used to accept 58// connections. 59type TLSKeyring struct { 60 // Keys is the set of available key pairs 61 Keys []*TLSKey `json:"keys"` 62 63 // AppliedIndex is the earliest known raft index that safely contains the 64 // latest key in the keyring. 65 AppliedIndex uint64 `json:"applied_index"` 66 67 // Term is an incrementing identifier value used to quickly determine if two 68 // states of the keyring are different. 69 Term uint64 `json:"term"` 70 71 // ActiveKeyID is the key ID to track the active key in the keyring. Only 72 // the active key is used for dialing. 73 ActiveKeyID string `json:"active_key_id"` 74} 75 76// GetActive returns the active key. 77func (k *TLSKeyring) GetActive() *TLSKey { 78 if k.ActiveKeyID == "" { 79 return nil 80 } 81 82 for _, key := range k.Keys { 83 if key.ID == k.ActiveKeyID { 84 return key 85 } 86 } 87 return nil 88} 89 90func GenerateTLSKey(reader io.Reader) (*TLSKey, error) { 91 key, err := ecdsa.GenerateKey(elliptic.P521(), reader) 92 if err != nil { 93 return nil, err 94 } 95 96 host, err := uuid.GenerateUUID() 97 if err != nil { 98 return nil, err 99 } 100 host = fmt.Sprintf("raft-%s", host) 101 template := &x509.Certificate{ 102 Subject: pkix.Name{ 103 CommonName: host, 104 }, 105 DNSNames: []string{host}, 106 ExtKeyUsage: []x509.ExtKeyUsage{ 107 x509.ExtKeyUsageServerAuth, 108 x509.ExtKeyUsageClientAuth, 109 }, 110 KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageKeyAgreement | x509.KeyUsageCertSign, 111 SerialNumber: big.NewInt(mathrand.Int63()), 112 NotBefore: time.Now().Add(-30 * time.Second), 113 // 30 years ought to be enough for anybody 114 NotAfter: time.Now().Add(262980 * time.Hour), 115 BasicConstraintsValid: true, 116 IsCA: true, 117 } 118 119 certBytes, err := x509.CreateCertificate(rand.Reader, template, template, key.Public(), key) 120 if err != nil { 121 return nil, fmt.Errorf("unable to generate local cluster certificate: %w", err) 122 } 123 124 return &TLSKey{ 125 ID: host, 126 KeyType: certutil.PrivateKeyTypeP521, 127 CertBytes: certBytes, 128 KeyParams: &certutil.ClusterKeyParams{ 129 Type: certutil.PrivateKeyTypeP521, 130 X: key.PublicKey.X, 131 Y: key.PublicKey.Y, 132 D: key.D, 133 }, 134 CreatedTime: time.Now(), 135 }, nil 136} 137 138var ( 139 // Make sure raftLayer satisfies the raft.StreamLayer interface 140 _ raft.StreamLayer = (*raftLayer)(nil) 141 142 // Make sure raftLayer satisfies the cluster.Handler and cluster.Client 143 // interfaces 144 _ cluster.Handler = (*raftLayer)(nil) 145 _ cluster.Client = (*raftLayer)(nil) 146) 147 148// RaftLayer implements the raft.StreamLayer interface, 149// so that we can use a single RPC layer for Raft and Vault 150type raftLayer struct { 151 // Addr is the listener address to return 152 addr net.Addr 153 154 // connCh is used to accept connections 155 connCh chan net.Conn 156 157 // Tracks if we are closed 158 closed bool 159 closeCh chan struct{} 160 closeLock sync.Mutex 161 162 logger log.Logger 163 164 dialerFunc func(string, time.Duration) (net.Conn, error) 165 166 // TLS config 167 keyring *TLSKeyring 168 clusterListener cluster.ClusterHook 169} 170 171// NewRaftLayer creates a new raftLayer object. It parses the TLS information 172// from the network config. 173func NewRaftLayer(logger log.Logger, raftTLSKeyring *TLSKeyring, clusterListener cluster.ClusterHook) (*raftLayer, error) { 174 clusterAddr := clusterListener.Addr() 175 if clusterAddr == nil { 176 return nil, errors.New("no raft addr found") 177 } 178 179 { 180 // Test the advertised address to make sure it's not an unspecified IP 181 u := url.URL{ 182 Host: clusterAddr.String(), 183 } 184 ip := net.ParseIP(u.Hostname()) 185 if ip != nil && ip.IsUnspecified() { 186 return nil, fmt.Errorf("cannot use unspecified IP with raft storage: %s", clusterAddr.String()) 187 } 188 } 189 190 layer := &raftLayer{ 191 addr: clusterAddr, 192 connCh: make(chan net.Conn), 193 closeCh: make(chan struct{}), 194 logger: logger, 195 clusterListener: clusterListener, 196 } 197 198 if err := layer.setTLSKeyring(raftTLSKeyring); err != nil { 199 return nil, err 200 } 201 202 return layer, nil 203} 204 205func (l *raftLayer) setTLSKeyring(keyring *TLSKeyring) error { 206 // Fast path a noop update 207 if l.keyring != nil && l.keyring.Term == keyring.Term { 208 return nil 209 } 210 211 for _, key := range keyring.Keys { 212 switch { 213 case key.KeyParams == nil: 214 return errors.New("no raft cluster key params found") 215 216 case key.KeyParams.X == nil, key.KeyParams.Y == nil, key.KeyParams.D == nil: 217 return errors.New("failed to parse raft cluster key") 218 219 case key.KeyParams.Type != certutil.PrivateKeyTypeP521: 220 return errors.New("failed to find valid raft cluster key type") 221 222 case len(key.CertBytes) == 0: 223 return errors.New("no cluster cert found") 224 } 225 226 parsedCert, err := x509.ParseCertificate(key.CertBytes) 227 if err != nil { 228 return fmt.Errorf("error parsing raft cluster certificate: %w", err) 229 } 230 231 key.parsedCert = parsedCert 232 key.parsedKey = &ecdsa.PrivateKey{ 233 PublicKey: ecdsa.PublicKey{ 234 Curve: elliptic.P521(), 235 X: key.KeyParams.X, 236 Y: key.KeyParams.Y, 237 }, 238 D: key.KeyParams.D, 239 } 240 } 241 242 if keyring.GetActive() == nil { 243 return errors.New("expected one active key to be present in the keyring") 244 } 245 246 l.keyring = keyring 247 248 return nil 249} 250 251func (l *raftLayer) ServerName() string { 252 key := l.keyring.GetActive() 253 if key == nil { 254 return "" 255 } 256 257 return key.parsedCert.Subject.CommonName 258} 259 260func (l *raftLayer) CACert(ctx context.Context) *x509.Certificate { 261 key := l.keyring.GetActive() 262 if key == nil { 263 return nil 264 } 265 266 return key.parsedCert 267} 268 269func (l *raftLayer) ClientLookup(ctx context.Context, requestInfo *tls.CertificateRequestInfo) (*tls.Certificate, error) { 270 for _, subj := range requestInfo.AcceptableCAs { 271 for _, key := range l.keyring.Keys { 272 if bytes.Equal(subj, key.parsedCert.RawIssuer) { 273 localCert := make([]byte, len(key.CertBytes)) 274 copy(localCert, key.CertBytes) 275 276 return &tls.Certificate{ 277 Certificate: [][]byte{localCert}, 278 PrivateKey: key.parsedKey, 279 Leaf: key.parsedCert, 280 }, nil 281 } 282 } 283 } 284 285 return nil, nil 286} 287 288func (l *raftLayer) ServerLookup(ctx context.Context, clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { 289 if l.keyring == nil { 290 return nil, errors.New("got raft connection but no local cert") 291 } 292 293 for _, key := range l.keyring.Keys { 294 if clientHello.ServerName == key.ID { 295 localCert := make([]byte, len(key.CertBytes)) 296 copy(localCert, key.CertBytes) 297 298 return &tls.Certificate{ 299 Certificate: [][]byte{localCert}, 300 PrivateKey: key.parsedKey, 301 Leaf: key.parsedCert, 302 }, nil 303 } 304 } 305 306 return nil, nil 307} 308 309// CALookup returns the CA to use when validating this connection. 310func (l *raftLayer) CALookup(context.Context) ([]*x509.Certificate, error) { 311 ret := make([]*x509.Certificate, len(l.keyring.Keys)) 312 for i, key := range l.keyring.Keys { 313 ret[i] = key.parsedCert 314 } 315 return ret, nil 316} 317 318// Stop shuts down the raft layer. 319func (l *raftLayer) Stop() error { 320 l.Close() 321 return nil 322} 323 324// Handoff is used to hand off a connection to the 325// RaftLayer. This allows it to be Accept()'ed 326func (l *raftLayer) Handoff(ctx context.Context, wg *sync.WaitGroup, quit chan struct{}, conn *tls.Conn) error { 327 l.closeLock.Lock() 328 closed := l.closed 329 l.closeLock.Unlock() 330 331 if closed { 332 return errors.New("raft is shutdown") 333 } 334 335 wg.Add(1) 336 go func() { 337 defer wg.Done() 338 select { 339 case l.connCh <- conn: 340 case <-l.closeCh: 341 case <-ctx.Done(): 342 case <-quit: 343 } 344 }() 345 346 return nil 347} 348 349// Accept is used to return connection which are 350// dialed to be used with the Raft layer 351func (l *raftLayer) Accept() (net.Conn, error) { 352 select { 353 case conn := <-l.connCh: 354 return conn, nil 355 case <-l.closeCh: 356 return nil, fmt.Errorf("Raft RPC layer closed") 357 } 358} 359 360// Close is used to stop listening for Raft connections 361func (l *raftLayer) Close() error { 362 l.closeLock.Lock() 363 defer l.closeLock.Unlock() 364 365 if !l.closed { 366 l.closed = true 367 close(l.closeCh) 368 } 369 return nil 370} 371 372// Addr is used to return the address of the listener 373func (l *raftLayer) Addr() net.Addr { 374 return l.addr 375} 376 377// Dial is used to create a new outgoing connection 378func (l *raftLayer) Dial(address raft.ServerAddress, timeout time.Duration) (net.Conn, error) { 379 dialFunc := l.clusterListener.GetDialerFunc(context.Background(), consts.RaftStorageALPN) 380 return dialFunc(string(address), timeout) 381} 382