1package raft
2
3import (
4	"context"
5	"crypto/md5"
6	"encoding/base64"
7	fmt "fmt"
8	"io"
9	"io/ioutil"
10	"os"
11	"path/filepath"
12	"testing"
13	"time"
14
15	"github.com/go-test/deep"
16	"github.com/golang/protobuf/proto"
17	hclog "github.com/hashicorp/go-hclog"
18	uuid "github.com/hashicorp/go-uuid"
19	"github.com/hashicorp/raft"
20	"github.com/hashicorp/vault/sdk/helper/jsonutil"
21	"github.com/hashicorp/vault/sdk/physical"
22	bolt "go.etcd.io/bbolt"
23)
24
25func getRaft(t testing.TB, bootstrap bool, noStoreState bool) (*RaftBackend, string) {
26	raftDir, err := ioutil.TempDir("", "vault-raft-")
27	if err != nil {
28		t.Fatal(err)
29	}
30	t.Logf("raft dir: %s", raftDir)
31
32	return getRaftWithDir(t, bootstrap, noStoreState, raftDir)
33}
34
35func getRaftWithDir(t testing.TB, bootstrap bool, noStoreState bool, raftDir string) (*RaftBackend, string) {
36	logger := hclog.New(&hclog.LoggerOptions{
37		Name:  "raft",
38		Level: hclog.Trace,
39	})
40	logger.Info("raft dir", "dir", raftDir)
41
42	conf := map[string]string{
43		"path":          raftDir,
44		"trailing_logs": "100",
45	}
46
47	if noStoreState {
48		conf["doNotStoreLatestState"] = ""
49	}
50
51	backendRaw, err := NewRaftBackend(conf, logger)
52	if err != nil {
53		t.Fatal(err)
54	}
55	backend := backendRaw.(*RaftBackend)
56
57	if bootstrap {
58		err = backend.Bootstrap(context.Background(), []Peer{Peer{ID: backend.NodeID(), Address: backend.NodeID()}})
59		if err != nil {
60			t.Fatal(err)
61		}
62
63		err = backend.SetupCluster(context.Background(), SetupOpts{})
64		if err != nil {
65			t.Fatal(err)
66		}
67
68		for {
69			if backend.AppliedIndex() >= 2 {
70				break
71			}
72		}
73
74	}
75
76	return backend, raftDir
77}
78
79func compareFSMs(t *testing.T, fsm1, fsm2 *FSM) {
80	t.Helper()
81	index1, config1 := fsm1.LatestState()
82	index2, config2 := fsm2.LatestState()
83
84	if !proto.Equal(index1, index2) {
85		t.Fatalf("indexes did not match: %+v != %+v", index1, index2)
86	}
87	if !proto.Equal(config1, config2) {
88		t.Fatalf("configs did not match: %+v != %+v", config1, config2)
89	}
90
91	compareDBs(t, fsm1.db, fsm2.db)
92}
93
94func compareDBs(t *testing.T, boltDB1, boltDB2 *bolt.DB) {
95	db1 := make(map[string]string)
96	db2 := make(map[string]string)
97
98	err := boltDB1.View(func(tx *bolt.Tx) error {
99
100		c := tx.Cursor()
101		for bucketName, _ := c.First(); bucketName != nil; bucketName, _ = c.Next() {
102			b := tx.Bucket(bucketName)
103
104			cBucket := b.Cursor()
105
106			for k, v := cBucket.First(); k != nil; k, v = cBucket.Next() {
107				db1[string(k)] = base64.StdEncoding.EncodeToString(v)
108			}
109		}
110
111		return nil
112	})
113
114	if err != nil {
115		t.Fatal(err)
116	}
117
118	err = boltDB2.View(func(tx *bolt.Tx) error {
119		c := tx.Cursor()
120		for bucketName, _ := c.First(); bucketName != nil; bucketName, _ = c.Next() {
121			b := tx.Bucket(bucketName)
122
123			c := b.Cursor()
124
125			for k, v := c.First(); k != nil; k, v = c.Next() {
126				db2[string(k)] = base64.StdEncoding.EncodeToString(v)
127			}
128		}
129
130		return nil
131	})
132
133	if err != nil {
134		t.Fatal(err)
135	}
136
137	if diff := deep.Equal(db1, db2); diff != nil {
138		t.Fatal(diff)
139	}
140}
141
142func TestRaft_Backend(t *testing.T) {
143	b, dir := getRaft(t, true, true)
144	defer os.RemoveAll(dir)
145
146	physical.ExerciseBackend(t, b)
147}
148
149func TestRaft_Backend_ListPrefix(t *testing.T) {
150	b, dir := getRaft(t, true, true)
151	defer os.RemoveAll(dir)
152
153	physical.ExerciseBackend_ListPrefix(t, b)
154}
155
156func TestRaft_TransactionalBackend(t *testing.T) {
157	b, dir := getRaft(t, true, true)
158	defer os.RemoveAll(dir)
159
160	physical.ExerciseTransactionalBackend(t, b)
161}
162
163func TestRaft_HABackend(t *testing.T) {
164	t.Skip()
165	raft, dir := getRaft(t, true, true)
166	defer os.RemoveAll(dir)
167	raft2, dir2 := getRaft(t, false, true)
168	defer os.RemoveAll(dir2)
169
170	// Add raft2 to the cluster
171	addPeer(t, raft, raft2)
172
173	physical.ExerciseHABackend(t, raft, raft2)
174}
175
176func TestRaft_Backend_ThreeNode(t *testing.T) {
177	raft1, dir := getRaft(t, true, true)
178	raft2, dir2 := getRaft(t, false, true)
179	raft3, dir3 := getRaft(t, false, true)
180	defer os.RemoveAll(dir)
181	defer os.RemoveAll(dir2)
182	defer os.RemoveAll(dir3)
183
184	// Add raft2 to the cluster
185	addPeer(t, raft1, raft2)
186
187	// Add raft3 to the cluster
188	addPeer(t, raft1, raft3)
189
190	physical.ExerciseBackend(t, raft1)
191
192	time.Sleep(10 * time.Second)
193	// Make sure all stores are the same
194	compareFSMs(t, raft1.fsm, raft2.fsm)
195	compareFSMs(t, raft1.fsm, raft3.fsm)
196}
197
198func TestRaft_Recovery(t *testing.T) {
199	// Create 4 raft nodes
200	raft1, dir1 := getRaft(t, true, true)
201	raft2, dir2 := getRaft(t, false, true)
202	raft3, dir3 := getRaft(t, false, true)
203	raft4, dir4 := getRaft(t, false, true)
204	defer os.RemoveAll(dir1)
205	defer os.RemoveAll(dir2)
206	defer os.RemoveAll(dir3)
207	defer os.RemoveAll(dir4)
208
209	// Add them all to the cluster
210	addPeer(t, raft1, raft2)
211	addPeer(t, raft1, raft3)
212	addPeer(t, raft1, raft4)
213
214	// Add some data into the FSM
215	physical.ExerciseBackend(t, raft1)
216
217	time.Sleep(10 * time.Second)
218
219	// Bring down all nodes
220	raft1.TeardownCluster(nil)
221	raft2.TeardownCluster(nil)
222	raft3.TeardownCluster(nil)
223	raft4.TeardownCluster(nil)
224
225	// Prepare peers.json
226	type RecoveryPeer struct {
227		ID       string `json:"id"`
228		Address  string `json:"address"`
229		NonVoter bool   `json: non_voter`
230	}
231
232	// Leave out node 1 during recovery
233	peersList := make([]*RecoveryPeer, 0, 3)
234	peersList = append(peersList, &RecoveryPeer{
235		ID:       raft1.NodeID(),
236		Address:  raft1.NodeID(),
237		NonVoter: false,
238	})
239	peersList = append(peersList, &RecoveryPeer{
240		ID:       raft2.NodeID(),
241		Address:  raft2.NodeID(),
242		NonVoter: false,
243	})
244	peersList = append(peersList, &RecoveryPeer{
245		ID:       raft4.NodeID(),
246		Address:  raft4.NodeID(),
247		NonVoter: false,
248	})
249
250	peersJSONBytes, err := jsonutil.EncodeJSON(peersList)
251	if err != nil {
252		t.Fatal(err)
253	}
254	err = ioutil.WriteFile(filepath.Join(filepath.Join(dir1, raftState), "peers.json"), peersJSONBytes, 0644)
255	if err != nil {
256		t.Fatal(err)
257	}
258	err = ioutil.WriteFile(filepath.Join(filepath.Join(dir2, raftState), "peers.json"), peersJSONBytes, 0644)
259	if err != nil {
260		t.Fatal(err)
261	}
262	err = ioutil.WriteFile(filepath.Join(filepath.Join(dir4, raftState), "peers.json"), peersJSONBytes, 0644)
263	if err != nil {
264		t.Fatal(err)
265	}
266
267	// Bring up the nodes again
268	raft1.SetupCluster(context.Background(), SetupOpts{})
269	raft2.SetupCluster(context.Background(), SetupOpts{})
270	raft4.SetupCluster(context.Background(), SetupOpts{})
271
272	peers, err := raft1.Peers(context.Background())
273	if err != nil {
274		t.Fatal(err)
275	}
276	if len(peers) != 3 {
277		t.Fatalf("failed to recover the cluster")
278	}
279
280	time.Sleep(10 * time.Second)
281
282	compareFSMs(t, raft1.fsm, raft2.fsm)
283	compareFSMs(t, raft1.fsm, raft4.fsm)
284}
285
286func TestRaft_TransactionalBackend_ThreeNode(t *testing.T) {
287	raft1, dir := getRaft(t, true, true)
288	raft2, dir2 := getRaft(t, false, true)
289	raft3, dir3 := getRaft(t, false, true)
290	defer os.RemoveAll(dir)
291	defer os.RemoveAll(dir2)
292	defer os.RemoveAll(dir3)
293
294	// Add raft2 to the cluster
295	addPeer(t, raft1, raft2)
296
297	// Add raft3 to the cluster
298	addPeer(t, raft1, raft3)
299
300	physical.ExerciseTransactionalBackend(t, raft1)
301
302	time.Sleep(10 * time.Second)
303	// Make sure all stores are the same
304	compareFSMs(t, raft1.fsm, raft2.fsm)
305	compareFSMs(t, raft1.fsm, raft3.fsm)
306}
307
308func TestRaft_Backend_Performance(t *testing.T) {
309	b, dir := getRaft(t, true, false)
310	defer os.RemoveAll(dir)
311
312	defaultConfig := raft.DefaultConfig()
313
314	localConfig := raft.DefaultConfig()
315	b.applyConfigSettings(localConfig)
316
317	if localConfig.ElectionTimeout != defaultConfig.ElectionTimeout*5 {
318		t.Fatalf("bad config: %v", localConfig)
319	}
320	if localConfig.HeartbeatTimeout != defaultConfig.HeartbeatTimeout*5 {
321		t.Fatalf("bad config: %v", localConfig)
322	}
323	if localConfig.LeaderLeaseTimeout != defaultConfig.LeaderLeaseTimeout*5 {
324		t.Fatalf("bad config: %v", localConfig)
325	}
326
327	b.conf = map[string]string{
328		"path":                   dir,
329		"performance_multiplier": "5",
330	}
331
332	localConfig = raft.DefaultConfig()
333	b.applyConfigSettings(localConfig)
334
335	if localConfig.ElectionTimeout != defaultConfig.ElectionTimeout*5 {
336		t.Fatalf("bad config: %v", localConfig)
337	}
338	if localConfig.HeartbeatTimeout != defaultConfig.HeartbeatTimeout*5 {
339		t.Fatalf("bad config: %v", localConfig)
340	}
341	if localConfig.LeaderLeaseTimeout != defaultConfig.LeaderLeaseTimeout*5 {
342		t.Fatalf("bad config: %v", localConfig)
343	}
344
345	b.conf = map[string]string{
346		"path":                   dir,
347		"performance_multiplier": "1",
348	}
349
350	localConfig = raft.DefaultConfig()
351	b.applyConfigSettings(localConfig)
352
353	if localConfig.ElectionTimeout != defaultConfig.ElectionTimeout {
354		t.Fatalf("bad config: %v", localConfig)
355	}
356	if localConfig.HeartbeatTimeout != defaultConfig.HeartbeatTimeout {
357		t.Fatalf("bad config: %v", localConfig)
358	}
359	if localConfig.LeaderLeaseTimeout != defaultConfig.LeaderLeaseTimeout {
360		t.Fatalf("bad config: %v", localConfig)
361	}
362
363}
364
365func BenchmarkDB_Puts(b *testing.B) {
366	raft, dir := getRaft(b, true, false)
367	defer os.RemoveAll(dir)
368	raft2, dir2 := getRaft(b, true, false)
369	defer os.RemoveAll(dir2)
370
371	bench := func(b *testing.B, s physical.Backend, dataSize int) {
372		data, err := uuid.GenerateRandomBytes(dataSize)
373		if err != nil {
374			b.Fatal(err)
375		}
376
377		ctx := context.Background()
378		pe := &physical.Entry{
379			Value: data,
380		}
381		testName := b.Name()
382
383		b.ResetTimer()
384		for i := 0; i < b.N; i++ {
385			pe.Key = fmt.Sprintf("%x", md5.Sum([]byte(fmt.Sprintf("%s-%d", testName, i))))
386			err := s.Put(ctx, pe)
387			if err != nil {
388				b.Fatal(err)
389			}
390		}
391	}
392
393	b.Run("256b", func(b *testing.B) { bench(b, raft, 256) })
394	b.Run("256kb", func(b *testing.B) { bench(b, raft2, 256*1024) })
395}
396
397func BenchmarkDB_Snapshot(b *testing.B) {
398	raft, dir := getRaft(b, true, false)
399	defer os.RemoveAll(dir)
400
401	data, err := uuid.GenerateRandomBytes(256 * 1024)
402	if err != nil {
403		b.Fatal(err)
404	}
405
406	ctx := context.Background()
407	pe := &physical.Entry{
408		Value: data,
409	}
410	testName := b.Name()
411
412	for i := 0; i < 100; i++ {
413		pe.Key = fmt.Sprintf("%x", md5.Sum([]byte(fmt.Sprintf("%s-%d", testName, i))))
414		err = raft.Put(ctx, pe)
415		if err != nil {
416			b.Fatal(err)
417		}
418	}
419
420	bench := func(b *testing.B, s *FSM) {
421		b.ResetTimer()
422		for i := 0; i < b.N; i++ {
423			pe.Key = fmt.Sprintf("%x", md5.Sum([]byte(fmt.Sprintf("%s-%d", testName, i))))
424			s.writeTo(ctx, discardCloser{Writer: ioutil.Discard}, discardCloser{Writer: ioutil.Discard})
425		}
426	}
427
428	b.Run("256kb", func(b *testing.B) { bench(b, raft.fsm) })
429}
430
// discardCloser embeds an io.Writer and adds Close and CloseWithError methods
// that do nothing and always return nil. A plain writer can therefore be used
// where a closable sink is expected.
type discardCloser struct {
	io.Writer
}

// Close does nothing and always returns nil.
func (discardCloser) Close() error { return nil }

// CloseWithError ignores its argument, does nothing and always returns nil.
func (discardCloser) CloseWithError(error) error { return nil }
437