1package raft
2
3import (
4	"bytes"
5	"context"
6	"crypto/ecdsa"
7	"crypto/elliptic"
8	"crypto/rand"
9	"crypto/tls"
10	"crypto/x509"
11	"crypto/x509/pkix"
12	"errors"
13	fmt "fmt"
14	"io"
15	"math/big"
16	mathrand "math/rand"
17	"net"
18	"net/url"
19	"sync"
20	"time"
21
22	log "github.com/hashicorp/go-hclog"
23	uuid "github.com/hashicorp/go-uuid"
24	"github.com/hashicorp/raft"
25	"github.com/hashicorp/vault/sdk/helper/certutil"
26	"github.com/hashicorp/vault/sdk/helper/consts"
27	"github.com/hashicorp/vault/vault/cluster"
28)
29
// TLSKey is a single TLS key pair (private key plus certificate) held in a
// TLSKeyring.
type TLSKey struct {
	// ID is a unique identifier for this Key. Keys created by GenerateTLSKey
	// use "raft-<uuid>", which is also the certificate's common name and DNS
	// SAN; raftLayer.ServerLookup matches the TLS server name against it.
	ID string `json:"id"`

	// KeyType defines the algorithm used to generate the private keys.
	// GenerateTLSKey sets it to certutil.PrivateKeyTypeP521.
	KeyType string `json:"key_type"`

	// AppliedIndex is the earliest known raft index that safely contains this
	// key.
	AppliedIndex uint64 `json:"applied_index"`

	// CertBytes is the DER-encoded certificate, as returned by
	// x509.CreateCertificate and decoded with x509.ParseCertificate.
	CertBytes []byte `json:"cluster_cert"`

	// KeyParams holds the private key components: the key type, the public
	// point (X, Y) and the private scalar D.
	KeyParams *certutil.ClusterKeyParams `json:"cluster_key_params"`

	// CreatedTime is the time this key was generated. This value is useful in
	// determining when the next rotation should be.
	CreatedTime time.Time `json:"created_time"`

	// parsedCert and parsedKey are decoded forms of CertBytes and KeyParams,
	// populated by raftLayer.setTLSKeyring. They are unexported, so they are
	// not part of the JSON encoding.
	parsedCert *x509.Certificate
	parsedKey  *ecdsa.PrivateKey
}
55
// TLSKeyring is the set of keys that raft uses for network communication.
// Only the active key is used to dial, but every key in the keyring is
// offered when accepting connections (presumably so peers still using an
// older key can connect while a rotation is in progress).
type TLSKeyring struct {
	// Keys is the set of available key pairs
	Keys []*TLSKey `json:"keys"`

	// AppliedIndex is the earliest known raft index that safely contains the
	// latest key in the keyring.
	AppliedIndex uint64 `json:"applied_index"`

	// Term is an incrementing identifier value used to quickly determine if two
	// states of the keyring are different.
	Term uint64 `json:"term"`

	// ActiveKeyID is the key ID to track the active key in the keyring. Only
	// the active key is used for dialing.
	ActiveKeyID string `json:"active_key_id"`
}
75
76// GetActive returns the active key.
77func (k *TLSKeyring) GetActive() *TLSKey {
78	if k.ActiveKeyID == "" {
79		return nil
80	}
81
82	for _, key := range k.Keys {
83		if key.ID == k.ActiveKeyID {
84			return key
85		}
86	}
87	return nil
88}
89
90func GenerateTLSKey(reader io.Reader) (*TLSKey, error) {
91	key, err := ecdsa.GenerateKey(elliptic.P521(), reader)
92	if err != nil {
93		return nil, err
94	}
95
96	host, err := uuid.GenerateUUID()
97	if err != nil {
98		return nil, err
99	}
100	host = fmt.Sprintf("raft-%s", host)
101	template := &x509.Certificate{
102		Subject: pkix.Name{
103			CommonName: host,
104		},
105		DNSNames: []string{host},
106		ExtKeyUsage: []x509.ExtKeyUsage{
107			x509.ExtKeyUsageServerAuth,
108			x509.ExtKeyUsageClientAuth,
109		},
110		KeyUsage:     x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageKeyAgreement | x509.KeyUsageCertSign,
111		SerialNumber: big.NewInt(mathrand.Int63()),
112		NotBefore:    time.Now().Add(-30 * time.Second),
113		// 30 years ought to be enough for anybody
114		NotAfter:              time.Now().Add(262980 * time.Hour),
115		BasicConstraintsValid: true,
116		IsCA:                  true,
117	}
118
119	certBytes, err := x509.CreateCertificate(rand.Reader, template, template, key.Public(), key)
120	if err != nil {
121		return nil, fmt.Errorf("unable to generate local cluster certificate: %w", err)
122	}
123
124	return &TLSKey{
125		ID:        host,
126		KeyType:   certutil.PrivateKeyTypeP521,
127		CertBytes: certBytes,
128		KeyParams: &certutil.ClusterKeyParams{
129			Type: certutil.PrivateKeyTypeP521,
130			X:    key.PublicKey.X,
131			Y:    key.PublicKey.Y,
132			D:    key.D,
133		},
134		CreatedTime: time.Now(),
135	}, nil
136}
137
138var (
139	// Make sure raftLayer satisfies the raft.StreamLayer interface
140	_ raft.StreamLayer = (*raftLayer)(nil)
141
142	// Make sure raftLayer satisfies the cluster.Handler and cluster.Client
143	// interfaces
144	_ cluster.Handler = (*raftLayer)(nil)
145	_ cluster.Client  = (*raftLayer)(nil)
146)
147
// raftLayer implements the raft.StreamLayer interface,
// so that we can use a single RPC layer for Raft and Vault
type raftLayer struct {
	// addr is the listener address returned by Addr.
	addr net.Addr

	// connCh carries connections from Handoff to Accept. NewRaftLayer
	// creates it unbuffered.
	connCh chan net.Conn

	// closed tracks whether Close has run and is guarded by closeLock.
	// closeCh is closed by Close to wake blocked Accept calls and pending
	// Handoff goroutines.
	closed    bool
	closeCh   chan struct{}
	closeLock sync.Mutex

	logger log.Logger

	// NOTE(review): dialerFunc is not referenced in this file (Dial uses
	// clusterListener.GetDialerFunc); confirm whether other files in the
	// package still use it.
	dialerFunc func(string, time.Duration) (net.Conn, error)

	// TLS config
	// keyring is installed by setTLSKeyring and read by the TLS lookup
	// methods. NOTE(review): these reads and writes are not synchronized in
	// this file; confirm the keyring is never replaced concurrently with
	// TLS handshakes.
	keyring         *TLSKeyring
	clusterListener cluster.ClusterHook
}
170
171// NewRaftLayer creates a new raftLayer object. It parses the TLS information
172// from the network config.
173func NewRaftLayer(logger log.Logger, raftTLSKeyring *TLSKeyring, clusterListener cluster.ClusterHook) (*raftLayer, error) {
174	clusterAddr := clusterListener.Addr()
175	if clusterAddr == nil {
176		return nil, errors.New("no raft addr found")
177	}
178
179	{
180		// Test the advertised address to make sure it's not an unspecified IP
181		u := url.URL{
182			Host: clusterAddr.String(),
183		}
184		ip := net.ParseIP(u.Hostname())
185		if ip != nil && ip.IsUnspecified() {
186			return nil, fmt.Errorf("cannot use unspecified IP with raft storage: %s", clusterAddr.String())
187		}
188	}
189
190	layer := &raftLayer{
191		addr:            clusterAddr,
192		connCh:          make(chan net.Conn),
193		closeCh:         make(chan struct{}),
194		logger:          logger,
195		clusterListener: clusterListener,
196	}
197
198	if err := layer.setTLSKeyring(raftTLSKeyring); err != nil {
199		return nil, err
200	}
201
202	return layer, nil
203}
204
205func (l *raftLayer) setTLSKeyring(keyring *TLSKeyring) error {
206	// Fast path a noop update
207	if l.keyring != nil && l.keyring.Term == keyring.Term {
208		return nil
209	}
210
211	for _, key := range keyring.Keys {
212		switch {
213		case key.KeyParams == nil:
214			return errors.New("no raft cluster key params found")
215
216		case key.KeyParams.X == nil, key.KeyParams.Y == nil, key.KeyParams.D == nil:
217			return errors.New("failed to parse raft cluster key")
218
219		case key.KeyParams.Type != certutil.PrivateKeyTypeP521:
220			return errors.New("failed to find valid raft cluster key type")
221
222		case len(key.CertBytes) == 0:
223			return errors.New("no cluster cert found")
224		}
225
226		parsedCert, err := x509.ParseCertificate(key.CertBytes)
227		if err != nil {
228			return fmt.Errorf("error parsing raft cluster certificate: %w", err)
229		}
230
231		key.parsedCert = parsedCert
232		key.parsedKey = &ecdsa.PrivateKey{
233			PublicKey: ecdsa.PublicKey{
234				Curve: elliptic.P521(),
235				X:     key.KeyParams.X,
236				Y:     key.KeyParams.Y,
237			},
238			D: key.KeyParams.D,
239		}
240	}
241
242	if keyring.GetActive() == nil {
243		return errors.New("expected one active key to be present in the keyring")
244	}
245
246	l.keyring = keyring
247
248	return nil
249}
250
251func (l *raftLayer) ServerName() string {
252	key := l.keyring.GetActive()
253	if key == nil {
254		return ""
255	}
256
257	return key.parsedCert.Subject.CommonName
258}
259
260func (l *raftLayer) CACert(ctx context.Context) *x509.Certificate {
261	key := l.keyring.GetActive()
262	if key == nil {
263		return nil
264	}
265
266	return key.parsedCert
267}
268
269func (l *raftLayer) ClientLookup(ctx context.Context, requestInfo *tls.CertificateRequestInfo) (*tls.Certificate, error) {
270	for _, subj := range requestInfo.AcceptableCAs {
271		for _, key := range l.keyring.Keys {
272			if bytes.Equal(subj, key.parsedCert.RawIssuer) {
273				localCert := make([]byte, len(key.CertBytes))
274				copy(localCert, key.CertBytes)
275
276				return &tls.Certificate{
277					Certificate: [][]byte{localCert},
278					PrivateKey:  key.parsedKey,
279					Leaf:        key.parsedCert,
280				}, nil
281			}
282		}
283	}
284
285	return nil, nil
286}
287
288func (l *raftLayer) ServerLookup(ctx context.Context, clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
289	if l.keyring == nil {
290		return nil, errors.New("got raft connection but no local cert")
291	}
292
293	for _, key := range l.keyring.Keys {
294		if clientHello.ServerName == key.ID {
295			localCert := make([]byte, len(key.CertBytes))
296			copy(localCert, key.CertBytes)
297
298			return &tls.Certificate{
299				Certificate: [][]byte{localCert},
300				PrivateKey:  key.parsedKey,
301				Leaf:        key.parsedCert,
302			}, nil
303		}
304	}
305
306	return nil, nil
307}
308
309// CALookup returns the CA to use when validating this connection.
310func (l *raftLayer) CALookup(context.Context) ([]*x509.Certificate, error) {
311	ret := make([]*x509.Certificate, len(l.keyring.Keys))
312	for i, key := range l.keyring.Keys {
313		ret[i] = key.parsedCert
314	}
315	return ret, nil
316}
317
318// Stop shuts down the raft layer.
319func (l *raftLayer) Stop() error {
320	l.Close()
321	return nil
322}
323
324// Handoff is used to hand off a connection to the
325// RaftLayer. This allows it to be Accept()'ed
326func (l *raftLayer) Handoff(ctx context.Context, wg *sync.WaitGroup, quit chan struct{}, conn *tls.Conn) error {
327	l.closeLock.Lock()
328	closed := l.closed
329	l.closeLock.Unlock()
330
331	if closed {
332		return errors.New("raft is shutdown")
333	}
334
335	wg.Add(1)
336	go func() {
337		defer wg.Done()
338		select {
339		case l.connCh <- conn:
340		case <-l.closeCh:
341		case <-ctx.Done():
342		case <-quit:
343		}
344	}()
345
346	return nil
347}
348
349// Accept is used to return connection which are
350// dialed to be used with the Raft layer
351func (l *raftLayer) Accept() (net.Conn, error) {
352	select {
353	case conn := <-l.connCh:
354		return conn, nil
355	case <-l.closeCh:
356		return nil, fmt.Errorf("Raft RPC layer closed")
357	}
358}
359
360// Close is used to stop listening for Raft connections
361func (l *raftLayer) Close() error {
362	l.closeLock.Lock()
363	defer l.closeLock.Unlock()
364
365	if !l.closed {
366		l.closed = true
367		close(l.closeCh)
368	}
369	return nil
370}
371
372// Addr is used to return the address of the listener
373func (l *raftLayer) Addr() net.Addr {
374	return l.addr
375}
376
377// Dial is used to create a new outgoing connection
378func (l *raftLayer) Dial(address raft.ServerAddress, timeout time.Duration) (net.Conn, error) {
379	dialFunc := l.clusterListener.GetDialerFunc(context.Background(), consts.RaftStorageALPN)
380	return dialFunc(string(address), timeout)
381}
382