1package raft
2
3import (
4	"bytes"
5	"context"
6	"crypto/ecdsa"
7	"crypto/elliptic"
8	"crypto/rand"
9	"crypto/tls"
10	"crypto/x509"
11	"crypto/x509/pkix"
12	"errors"
13	fmt "fmt"
14	"math/big"
15	mathrand "math/rand"
16	"net"
17	"sync"
18	"time"
19
20	"github.com/hashicorp/errwrap"
21	log "github.com/hashicorp/go-hclog"
22	uuid "github.com/hashicorp/go-uuid"
23	"github.com/hashicorp/raft"
24	"github.com/hashicorp/vault/sdk/helper/certutil"
25	"github.com/hashicorp/vault/sdk/helper/consts"
26	"github.com/hashicorp/vault/vault/cluster"
27)
28
// RaftTLSKey is a single TLS keypair in the Keyring.
//
// The exported fields carry JSON tags for serialization. The unexported
// parsed* fields are runtime caches and are never serialized.
type RaftTLSKey struct {
	// ID is a unique identifier for this Key. GenerateTLSKey sets it to the
	// certificate's common name ("raft-<uuid>"), and ServerLookup matches it
	// against the TLS server name sent by the client.
	ID string `json:"id"`

	// KeyType defines the algorithm used to generate the private key.
	// GenerateTLSKey always uses certutil.PrivateKeyTypeP521.
	KeyType string `json:"key_type"`

	// AppliedIndex is the earliest known raft index that safely contains this
	// key.
	AppliedIndex uint64 `json:"applied_index"`

	// CertBytes is the DER-encoded certificate, as produced by
	// x509.CreateCertificate.
	CertBytes []byte `json:"cluster_cert"`

	// KeyParams holds the private key's components: the key type, the public
	// point (X, Y) and the private scalar D.
	KeyParams *certutil.ClusterKeyParams `json:"cluster_key_params"`

	// CreatedTime is the time this key was generated. This value is useful in
	// determining when the next rotation should be.
	CreatedTime time.Time `json:"created_time"`

	// parsedCert and parsedKey are populated from CertBytes and KeyParams by
	// raftLayer.setTLSKeyring so that TLS lookups do not re-parse on every
	// handshake. Being unexported, they are not serialized.
	parsedCert *x509.Certificate
	parsedKey  *ecdsa.PrivateKey
}
54
// RaftTLSKeyring is the set of keys that raft uses for network communication.
// Only the active key (see ActiveKeyID) is used to dial, but every key in the
// keyring is used when accepting connections: ServerLookup may serve any of
// them, and CALookup trusts all of them.
type RaftTLSKeyring struct {
	// Keys is the set of available key pairs.
	Keys []*RaftTLSKey `json:"keys"`

	// AppliedIndex is the earliest known raft index that safely contains the
	// latest key in the keyring.
	AppliedIndex uint64 `json:"applied_index"`

	// Term is an incrementing identifier value used to quickly determine if two
	// states of the keyring are different. raftLayer.setTLSKeyring treats an
	// unchanged Term as a no-op update and skips re-parsing the keys.
	Term uint64 `json:"term"`

	// ActiveKeyID is the key ID to track the active key in the keyring. Only
	// the active key is used for dialing; GetActive resolves this ID to a key.
	ActiveKeyID string `json:"active_key_id"`
}
74
75// GetActive returns the active key.
76func (k *RaftTLSKeyring) GetActive() *RaftTLSKey {
77	if k.ActiveKeyID == "" {
78		return nil
79	}
80
81	for _, key := range k.Keys {
82		if key.ID == k.ActiveKeyID {
83			return key
84		}
85	}
86	return nil
87}
88
89func GenerateTLSKey() (*RaftTLSKey, error) {
90	key, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader)
91	if err != nil {
92		return nil, err
93	}
94
95	host, err := uuid.GenerateUUID()
96	if err != nil {
97		return nil, err
98	}
99	host = fmt.Sprintf("raft-%s", host)
100	template := &x509.Certificate{
101		Subject: pkix.Name{
102			CommonName: host,
103		},
104		DNSNames: []string{host},
105		ExtKeyUsage: []x509.ExtKeyUsage{
106			x509.ExtKeyUsageServerAuth,
107			x509.ExtKeyUsageClientAuth,
108		},
109		KeyUsage:     x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageKeyAgreement | x509.KeyUsageCertSign,
110		SerialNumber: big.NewInt(mathrand.Int63()),
111		NotBefore:    time.Now().Add(-30 * time.Second),
112		// 30 years of single-active uptime ought to be enough for anybody
113		NotAfter:              time.Now().Add(262980 * time.Hour),
114		BasicConstraintsValid: true,
115		IsCA:                  true,
116	}
117
118	certBytes, err := x509.CreateCertificate(rand.Reader, template, template, key.Public(), key)
119	if err != nil {
120		return nil, errwrap.Wrapf("unable to generate local cluster certificate: {{err}}", err)
121	}
122
123	return &RaftTLSKey{
124		ID:        host,
125		KeyType:   certutil.PrivateKeyTypeP521,
126		CertBytes: certBytes,
127		KeyParams: &certutil.ClusterKeyParams{
128			Type: certutil.PrivateKeyTypeP521,
129			X:    key.PublicKey.X,
130			Y:    key.PublicKey.Y,
131			D:    key.D,
132		},
133		CreatedTime: time.Now(),
134	}, nil
135}
136
137// Make sure raftLayer satisfies the raft.StreamLayer interface
138var _ raft.StreamLayer = (*raftLayer)(nil)
139
140// Make sure raftLayer satisfies the cluster.Handler and cluster.Client
141// interfaces
142var _ cluster.Handler = (*raftLayer)(nil)
143var _ cluster.Client = (*raftLayer)(nil)
144
// raftLayer implements the raft.StreamLayer interface,
// so that we can use a single RPC layer for Raft and Vault.
// Incoming connections arrive through Handoff and are returned by Accept;
// outgoing connections are created by Dial.
type raftLayer struct {
	// addr is the listener address returned by Addr.
	addr net.Addr

	// connCh carries connections from Handoff goroutines to Accept. It is
	// created unbuffered by NewRaftLayer.
	connCh chan net.Conn

	// Tracks if we are closed. closed and closeCh are changed only by Close,
	// under closeLock; closing closeCh unblocks Accept and pending Handoff
	// goroutines.
	closed    bool
	closeCh   chan struct{}
	closeLock sync.Mutex

	logger log.Logger

	// NOTE(review): dialerFunc is not referenced in this file (Dial builds its
	// own net.Dialer); confirm whether anything else in the package uses it.
	dialerFunc func(string, time.Duration) (net.Conn, error)

	// TLS config. keyring is installed by setTLSKeyring and read by the
	// lookup methods and Dial without holding any lock. NOTE(review): confirm
	// setTLSKeyring is never called concurrently with those. baseTLSConfig is
	// the template that Dial clones for each outgoing connection.
	keyring       *RaftTLSKeyring
	baseTLSConfig *tls.Config
}
167
168// NewRaftLayer creates a new raftLayer object. It parses the TLS information
169// from the network config.
170func NewRaftLayer(logger log.Logger, raftTLSKeyring *RaftTLSKeyring, clusterAddr net.Addr, baseTLSConfig *tls.Config) (*raftLayer, error) {
171	switch {
172	case clusterAddr == nil:
173		// Clustering disabled on the server, don't try to look for params
174		return nil, errors.New("no raft addr found")
175	}
176
177	layer := &raftLayer{
178		addr:          clusterAddr,
179		connCh:        make(chan net.Conn),
180		closeCh:       make(chan struct{}),
181		logger:        logger,
182		baseTLSConfig: baseTLSConfig,
183	}
184
185	if err := layer.setTLSKeyring(raftTLSKeyring); err != nil {
186		return nil, err
187	}
188
189	return layer, nil
190}
191
192func (l *raftLayer) setTLSKeyring(keyring *RaftTLSKeyring) error {
193	// Fast path a noop update
194	if l.keyring != nil && l.keyring.Term == keyring.Term {
195		return nil
196	}
197
198	for _, key := range keyring.Keys {
199		switch {
200		case key.KeyParams == nil:
201			return errors.New("no raft cluster key params found")
202
203		case key.KeyParams.X == nil, key.KeyParams.Y == nil, key.KeyParams.D == nil:
204			return errors.New("failed to parse raft cluster key")
205
206		case key.KeyParams.Type != certutil.PrivateKeyTypeP521:
207			return errors.New("failed to find valid raft cluster key type")
208
209		case len(key.CertBytes) == 0:
210			return errors.New("no cluster cert found")
211		}
212
213		parsedCert, err := x509.ParseCertificate(key.CertBytes)
214		if err != nil {
215			return errwrap.Wrapf("error parsing raft cluster certificate: {{err}}", err)
216		}
217
218		key.parsedCert = parsedCert
219		key.parsedKey = &ecdsa.PrivateKey{
220			PublicKey: ecdsa.PublicKey{
221				Curve: elliptic.P521(),
222				X:     key.KeyParams.X,
223				Y:     key.KeyParams.Y,
224			},
225			D: key.KeyParams.D,
226		}
227	}
228
229	if keyring.GetActive() == nil {
230		return errors.New("expected one active key to be present in the keyring")
231	}
232
233	l.keyring = keyring
234
235	return nil
236}
237
238func (l *raftLayer) ClientLookup(ctx context.Context, requestInfo *tls.CertificateRequestInfo) (*tls.Certificate, error) {
239	for _, subj := range requestInfo.AcceptableCAs {
240		for _, key := range l.keyring.Keys {
241			if bytes.Equal(subj, key.parsedCert.RawIssuer) {
242				localCert := make([]byte, len(key.CertBytes))
243				copy(localCert, key.CertBytes)
244
245				return &tls.Certificate{
246					Certificate: [][]byte{localCert},
247					PrivateKey:  key.parsedKey,
248					Leaf:        key.parsedCert,
249				}, nil
250			}
251		}
252	}
253
254	return nil, nil
255}
256
257func (l *raftLayer) ServerLookup(ctx context.Context, clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
258	if l.keyring == nil {
259		return nil, errors.New("got raft connection but no local cert")
260	}
261
262	for _, key := range l.keyring.Keys {
263		if clientHello.ServerName == key.ID {
264			localCert := make([]byte, len(key.CertBytes))
265			copy(localCert, key.CertBytes)
266
267			return &tls.Certificate{
268				Certificate: [][]byte{localCert},
269				PrivateKey:  key.parsedKey,
270				Leaf:        key.parsedCert,
271			}, nil
272		}
273	}
274
275	return nil, nil
276}
277
278// CALookup returns the CA to use when validating this connection.
279func (l *raftLayer) CALookup(context.Context) ([]*x509.Certificate, error) {
280	ret := make([]*x509.Certificate, len(l.keyring.Keys))
281	for i, key := range l.keyring.Keys {
282		ret[i] = key.parsedCert
283	}
284	return ret, nil
285}
286
287// Stop shutsdown the raft layer.
288func (l *raftLayer) Stop() error {
289	l.Close()
290	return nil
291}
292
293// Handoff is used to hand off a connection to the
294// RaftLayer. This allows it to be Accept()'ed
295func (l *raftLayer) Handoff(ctx context.Context, wg *sync.WaitGroup, quit chan struct{}, conn *tls.Conn) error {
296	l.closeLock.Lock()
297	closed := l.closed
298	l.closeLock.Unlock()
299
300	if closed {
301		return errors.New("raft is shutdown")
302	}
303
304	wg.Add(1)
305	go func() {
306		defer wg.Done()
307		select {
308		case l.connCh <- conn:
309		case <-l.closeCh:
310		case <-ctx.Done():
311		case <-quit:
312		}
313	}()
314
315	return nil
316}
317
318// Accept is used to return connection which are
319// dialed to be used with the Raft layer
320func (l *raftLayer) Accept() (net.Conn, error) {
321	select {
322	case conn := <-l.connCh:
323		return conn, nil
324	case <-l.closeCh:
325		return nil, fmt.Errorf("Raft RPC layer closed")
326	}
327}
328
329// Close is used to stop listening for Raft connections
330func (l *raftLayer) Close() error {
331	l.closeLock.Lock()
332	defer l.closeLock.Unlock()
333
334	if !l.closed {
335		l.closed = true
336		close(l.closeCh)
337	}
338	return nil
339}
340
341// Addr is used to return the address of the listener
342func (l *raftLayer) Addr() net.Addr {
343	return l.addr
344}
345
346// Dial is used to create a new outgoing connection
347func (l *raftLayer) Dial(address raft.ServerAddress, timeout time.Duration) (net.Conn, error) {
348
349	tlsConfig := l.baseTLSConfig.Clone()
350
351	key := l.keyring.GetActive()
352	if key == nil {
353		return nil, errors.New("no active key")
354	}
355
356	tlsConfig.NextProtos = []string{consts.RaftStorageALPN}
357	tlsConfig.ServerName = key.parsedCert.Subject.CommonName
358
359	l.logger.Debug("creating rpc dialer", "host", tlsConfig.ServerName)
360
361	pool := x509.NewCertPool()
362	pool.AddCert(key.parsedCert)
363	tlsConfig.RootCAs = pool
364	tlsConfig.ClientCAs = pool
365
366	dialer := &net.Dialer{
367		Timeout: timeout,
368	}
369	return tls.DialWithDialer(dialer, "tcp", string(address), tlsConfig)
370}
371