1package rafttests
2
3import (
4	"bytes"
5	"crypto/md5"
6	"fmt"
7	"io/ioutil"
8	"net/http"
9	"strings"
10	"sync/atomic"
11	"testing"
12	"time"
13
14	"github.com/hashicorp/go-cleanhttp"
15	uuid "github.com/hashicorp/go-uuid"
16	"github.com/hashicorp/vault/api"
17	"github.com/hashicorp/vault/helper/testhelpers"
18	"github.com/hashicorp/vault/helper/testhelpers/teststorage"
19	vaulthttp "github.com/hashicorp/vault/http"
20	"github.com/hashicorp/vault/physical/raft"
21	"github.com/hashicorp/vault/vault"
22	"golang.org/x/net/http2"
23)
24
25func raftCluster(t testing.TB) *vault.TestCluster {
26	var conf vault.CoreConfig
27	var opts = vault.TestClusterOptions{HandlerFunc: vaulthttp.Handler}
28	teststorage.RaftBackendSetup(&conf, &opts)
29	cluster := vault.NewTestCluster(t, &conf, &opts)
30	cluster.Start()
31	vault.TestWaitActive(t, cluster.Cores[0].Core)
32	return cluster
33}
34
35func TestRaft_Join(t *testing.T) {
36	var conf vault.CoreConfig
37	var opts = vault.TestClusterOptions{HandlerFunc: vaulthttp.Handler}
38	teststorage.RaftBackendSetup(&conf, &opts)
39	opts.SetupFunc = nil
40	cluster := vault.NewTestCluster(t, &conf, &opts)
41	cluster.Start()
42	defer cluster.Cleanup()
43
44	addressProvider := &testhelpers.TestRaftServerAddressProvider{Cluster: cluster}
45
46	leaderCore := cluster.Cores[0]
47	leaderAPI := leaderCore.Client.Address()
48	atomic.StoreUint32(&vault.UpdateClusterAddrForTests, 1)
49
50	// Seal the leader so we can install an address provider
51	{
52		testhelpers.EnsureCoreSealed(t, leaderCore)
53		leaderCore.UnderlyingRawStorage.(*raft.RaftBackend).SetServerAddressProvider(addressProvider)
54		cluster.UnsealCore(t, leaderCore)
55		vault.TestWaitActive(t, leaderCore.Core)
56	}
57
58	joinFunc := func(client *api.Client, addClientCerts bool) {
59		req := &api.RaftJoinRequest{
60			LeaderAPIAddr: leaderAPI,
61			LeaderCACert:  string(cluster.CACertPEM),
62		}
63		if addClientCerts {
64			req.LeaderClientCert = string(cluster.CACertPEM)
65			req.LeaderClientKey = string(cluster.CAKeyPEM)
66		}
67		resp, err := client.Sys().RaftJoin(req)
68		if err != nil {
69			t.Fatal(err)
70		}
71		if !resp.Joined {
72			t.Fatalf("failed to join raft cluster")
73		}
74	}
75
76	joinFunc(cluster.Cores[1].Client, false)
77	joinFunc(cluster.Cores[2].Client, false)
78
79	_, err := cluster.Cores[0].Client.Logical().Write("sys/storage/raft/remove-peer", map[string]interface{}{
80		"server_id": "core-1",
81	})
82	if err != nil {
83		t.Fatal(err)
84	}
85
86	_, err = cluster.Cores[0].Client.Logical().Write("sys/storage/raft/remove-peer", map[string]interface{}{
87		"server_id": "core-2",
88	})
89	if err != nil {
90		t.Fatal(err)
91	}
92
93	joinFunc(cluster.Cores[1].Client, true)
94	joinFunc(cluster.Cores[2].Client, true)
95}
96
97func TestRaft_RemovePeer(t *testing.T) {
98	cluster := raftCluster(t)
99	defer cluster.Cleanup()
100
101	for i, c := range cluster.Cores {
102		if c.Core.Sealed() {
103			t.Fatalf("failed to unseal core %d", i)
104		}
105	}
106
107	client := cluster.Cores[0].Client
108
109	checkConfigFunc := func(expected map[string]bool) {
110		secret, err := client.Logical().Read("sys/storage/raft/configuration")
111		if err != nil {
112			t.Fatal(err)
113		}
114		servers := secret.Data["config"].(map[string]interface{})["servers"].([]interface{})
115
116		for _, s := range servers {
117			server := s.(map[string]interface{})
118			delete(expected, server["node_id"].(string))
119		}
120		if len(expected) != 0 {
121			t.Fatalf("failed to read configuration successfully")
122		}
123	}
124
125	checkConfigFunc(map[string]bool{
126		"core-0": true,
127		"core-1": true,
128		"core-2": true,
129	})
130
131	_, err := client.Logical().Write("sys/storage/raft/remove-peer", map[string]interface{}{
132		"server_id": "core-2",
133	})
134	if err != nil {
135		t.Fatal(err)
136	}
137
138	checkConfigFunc(map[string]bool{
139		"core-0": true,
140		"core-1": true,
141	})
142
143	_, err = client.Logical().Write("sys/storage/raft/remove-peer", map[string]interface{}{
144		"server_id": "core-1",
145	})
146	if err != nil {
147		t.Fatal(err)
148	}
149
150	checkConfigFunc(map[string]bool{
151		"core-0": true,
152	})
153}
154
155func TestRaft_Configuration(t *testing.T) {
156	cluster := raftCluster(t)
157	defer cluster.Cleanup()
158
159	for i, c := range cluster.Cores {
160		if c.Core.Sealed() {
161			t.Fatalf("failed to unseal core %d", i)
162		}
163	}
164
165	client := cluster.Cores[0].Client
166	secret, err := client.Logical().Read("sys/storage/raft/configuration")
167	if err != nil {
168		t.Fatal(err)
169	}
170	servers := secret.Data["config"].(map[string]interface{})["servers"].([]interface{})
171	expected := map[string]bool{
172		"core-0": true,
173		"core-1": true,
174		"core-2": true,
175	}
176	if len(servers) != 3 {
177		t.Fatalf("incorrect number of servers in the configuration")
178	}
179	for _, s := range servers {
180		server := s.(map[string]interface{})
181		nodeID := server["node_id"].(string)
182		leader := server["leader"].(bool)
183		switch nodeID {
184		case "core-0":
185			if !leader {
186				t.Fatalf("expected server to be leader: %#v", server)
187			}
188		default:
189			if leader {
190				t.Fatalf("expected server to not be leader: %#v", server)
191			}
192		}
193
194		delete(expected, nodeID)
195	}
196	if len(expected) != 0 {
197		t.Fatalf("failed to read configuration successfully")
198	}
199}
200
201func TestRaft_ShamirUnseal(t *testing.T) {
202	cluster := raftCluster(t)
203	defer cluster.Cleanup()
204
205	for i, c := range cluster.Cores {
206		if c.Core.Sealed() {
207			t.Fatalf("failed to unseal core %d", i)
208		}
209	}
210}
211
212func TestRaft_SnapshotAPI(t *testing.T) {
213	cluster := raftCluster(t)
214	defer cluster.Cleanup()
215
216	leaderClient := cluster.Cores[0].Client
217
218	// Write a few keys
219	for i := 0; i < 10; i++ {
220		_, err := leaderClient.Logical().Write(fmt.Sprintf("secret/%d", i), map[string]interface{}{
221			"test": "data",
222		})
223		if err != nil {
224			t.Fatal(err)
225		}
226	}
227
228	transport := cleanhttp.DefaultPooledTransport()
229	transport.TLSClientConfig = cluster.Cores[0].TLSConfig.Clone()
230	if err := http2.ConfigureTransport(transport); err != nil {
231		t.Fatal(err)
232	}
233	client := &http.Client{
234		Transport: transport,
235	}
236
237	// Take a snapshot
238	req := leaderClient.NewRequest("GET", "/v1/sys/storage/raft/snapshot")
239	httpReq, err := req.ToHTTP()
240	if err != nil {
241		t.Fatal(err)
242	}
243	resp, err := client.Do(httpReq)
244	if err != nil {
245		t.Fatal(err)
246	}
247	defer resp.Body.Close()
248
249	snap, err := ioutil.ReadAll(resp.Body)
250	if err != nil {
251		t.Fatal(err)
252	}
253	if len(snap) == 0 {
254		t.Fatal("no snapshot returned")
255	}
256
257	// Write a few more keys
258	for i := 10; i < 20; i++ {
259		_, err := leaderClient.Logical().Write(fmt.Sprintf("secret/%d", i), map[string]interface{}{
260			"test": "data",
261		})
262		if err != nil {
263			t.Fatal(err)
264		}
265	}
266
267	// Restore snapshot
268	req = leaderClient.NewRequest("POST", "/v1/sys/storage/raft/snapshot")
269	req.Body = bytes.NewBuffer(snap)
270	httpReq, err = req.ToHTTP()
271	if err != nil {
272		t.Fatal(err)
273	}
274	resp, err = client.Do(httpReq)
275	if err != nil {
276		t.Fatal(err)
277	}
278
279	// List kv to make sure we removed the extra keys
280	secret, err := leaderClient.Logical().List("secret/")
281	if err != nil {
282		t.Fatal(err)
283	}
284
285	if len(secret.Data["keys"].([]interface{})) != 10 {
286		t.Fatal("snapshot didn't apply correctly")
287	}
288}
289
290func TestRaft_SnapshotAPI_RekeyRotate_Backward(t *testing.T) {
291	tCases := []struct {
292		Name   string
293		Rekey  bool
294		Rotate bool
295	}{
296		{
297			Name:   "rekey",
298			Rekey:  true,
299			Rotate: false,
300		},
301		{
302			Name:   "rotate",
303			Rekey:  false,
304			Rotate: true,
305		},
306		{
307			Name:   "both",
308			Rekey:  true,
309			Rotate: true,
310		},
311	}
312
313	for _, tCase := range tCases {
314		t.Run(tCase.Name, func(t *testing.T) {
315			// bind locally
316			tCaseLocal := tCase
317			t.Parallel()
318
319			cluster := raftCluster(t)
320			defer cluster.Cleanup()
321
322			leaderClient := cluster.Cores[0].Client
323
324			// Write a few keys
325			for i := 0; i < 10; i++ {
326				_, err := leaderClient.Logical().Write(fmt.Sprintf("secret/%d", i), map[string]interface{}{
327					"test": "data",
328				})
329				if err != nil {
330					t.Fatal(err)
331				}
332			}
333
334			transport := cleanhttp.DefaultPooledTransport()
335			transport.TLSClientConfig = cluster.Cores[0].TLSConfig.Clone()
336			if err := http2.ConfigureTransport(transport); err != nil {
337				t.Fatal(err)
338			}
339			client := &http.Client{
340				Transport: transport,
341			}
342
343			// Take a snapshot
344			req := leaderClient.NewRequest("GET", "/v1/sys/storage/raft/snapshot")
345			httpReq, err := req.ToHTTP()
346			if err != nil {
347				t.Fatal(err)
348			}
349			resp, err := client.Do(httpReq)
350			if err != nil {
351				t.Fatal(err)
352			}
353			defer resp.Body.Close()
354
355			snap, err := ioutil.ReadAll(resp.Body)
356			if err != nil {
357				t.Fatal(err)
358			}
359			if len(snap) == 0 {
360				t.Fatal("no snapshot returned")
361			}
362
363			// cache the original barrier keys
364			barrierKeys := cluster.BarrierKeys
365
366			if tCaseLocal.Rotate {
367				// Rotate
368				err = leaderClient.Sys().Rotate()
369				if err != nil {
370					t.Fatal(err)
371				}
372			}
373
374			if tCaseLocal.Rekey {
375				// Rekey
376				cluster.BarrierKeys = testhelpers.RekeyCluster(t, cluster, false)
377			}
378
379			if tCaseLocal.Rekey {
380				// Restore snapshot, should fail.
381				req = leaderClient.NewRequest("POST", "/v1/sys/storage/raft/snapshot")
382				req.Body = bytes.NewBuffer(snap)
383				httpReq, err = req.ToHTTP()
384				if err != nil {
385					t.Fatal(err)
386				}
387				resp, err = client.Do(httpReq)
388				if err != nil {
389					t.Fatal(err)
390				}
391				// Parse Response
392				apiResp := api.Response{Response: resp}
393				if !strings.Contains(apiResp.Error().Error(), "could not verify hash file, possibly the snapshot is using a different set of unseal keys") {
394					t.Fatal(apiResp.Error())
395				}
396			}
397
398			// Restore snapshot force
399			req = leaderClient.NewRequest("POST", "/v1/sys/storage/raft/snapshot-force")
400			req.Body = bytes.NewBuffer(snap)
401			httpReq, err = req.ToHTTP()
402			if err != nil {
403				t.Fatal(err)
404			}
405			resp, err = client.Do(httpReq)
406			if err != nil {
407				t.Fatal(err)
408			}
409
410			testhelpers.EnsureStableActiveNode(t, cluster)
411
412			// Write some data so we can make sure we can read it later. This is testing
413			// that we correctly reload the keyring
414			_, err = leaderClient.Logical().Write("secret/foo", map[string]interface{}{
415				"test": "data",
416			})
417			if err != nil {
418				t.Fatal(err)
419			}
420
421			testhelpers.EnsureCoresSealed(t, cluster)
422
423			cluster.BarrierKeys = barrierKeys
424			testhelpers.EnsureCoresUnsealed(t, cluster)
425			testhelpers.WaitForActiveNode(t, cluster)
426			activeCore := testhelpers.DeriveActiveCore(t, cluster)
427
428			// Read the value.
429			data, err := activeCore.Client.Logical().Read("secret/foo")
430			if err != nil {
431				t.Fatal(err)
432			}
433			if data.Data["test"].(string) != "data" {
434				t.Fatal(data)
435			}
436		})
437	}
438}
439
440func TestRaft_SnapshotAPI_RekeyRotate_Forward(t *testing.T) {
441	tCases := []struct {
442		Name       string
443		Rekey      bool
444		Rotate     bool
445		ShouldSeal bool
446	}{
447		{
448			Name:       "rekey",
449			Rekey:      true,
450			Rotate:     false,
451			ShouldSeal: false,
452		},
453		{
454			Name:   "rotate",
455			Rekey:  false,
456			Rotate: true,
457			// Rotate writes a new master key upgrade using the new term, which
458			// we can no longer decrypt. We must seal here.
459			ShouldSeal: true,
460		},
461		{
462			Name:   "both",
463			Rekey:  true,
464			Rotate: true,
465			// If we are moving forward and we have rekeyed and rotated there
466			// isn't any way to restore the latest keys so expect to seal.
467			ShouldSeal: true,
468		},
469	}
470
471	for _, tCase := range tCases {
472		t.Run(tCase.Name, func(t *testing.T) {
473			// bind locally
474			tCaseLocal := tCase
475			t.Parallel()
476
477			cluster := raftCluster(t)
478			defer cluster.Cleanup()
479
480			leaderClient := cluster.Cores[0].Client
481
482			// Write a few keys
483			for i := 0; i < 10; i++ {
484				_, err := leaderClient.Logical().Write(fmt.Sprintf("secret/%d", i), map[string]interface{}{
485					"test": "data",
486				})
487				if err != nil {
488					t.Fatal(err)
489				}
490			}
491
492			transport := cleanhttp.DefaultPooledTransport()
493			transport.TLSClientConfig = cluster.Cores[0].TLSConfig.Clone()
494			if err := http2.ConfigureTransport(transport); err != nil {
495				t.Fatal(err)
496			}
497			client := &http.Client{
498				Transport: transport,
499			}
500
501			// Take a snapshot
502			req := leaderClient.NewRequest("GET", "/v1/sys/storage/raft/snapshot")
503			httpReq, err := req.ToHTTP()
504			if err != nil {
505				t.Fatal(err)
506			}
507			resp, err := client.Do(httpReq)
508			if err != nil {
509				t.Fatal(err)
510			}
511
512			snap, err := ioutil.ReadAll(resp.Body)
513			resp.Body.Close()
514			if err != nil {
515				t.Fatal(err)
516			}
517			if len(snap) == 0 {
518				t.Fatal("no snapshot returned")
519			}
520
521			if tCaseLocal.Rekey {
522				// Rekey
523				cluster.BarrierKeys = testhelpers.RekeyCluster(t, cluster, false)
524			}
525			if tCaseLocal.Rotate {
526				// Set the key clean up to 0 so it's cleaned immediately. This
527				// will simulate that there are no ways to upgrade to the latest
528				// term.
529				vault.KeyRotateGracePeriod = 0
530
531				// Rotate
532				err = leaderClient.Sys().Rotate()
533				if err != nil {
534					t.Fatal(err)
535				}
536				// Let the key upgrade get deleted
537				time.Sleep(1 * time.Second)
538			}
539
540			// cache the new barrier keys
541			newBarrierKeys := cluster.BarrierKeys
542
543			// Take another snapshot for later use in "jumping" forward
544			req = leaderClient.NewRequest("GET", "/v1/sys/storage/raft/snapshot")
545			httpReq, err = req.ToHTTP()
546			if err != nil {
547				t.Fatal(err)
548			}
549			resp, err = client.Do(httpReq)
550			if err != nil {
551				t.Fatal(err)
552			}
553
554			snap2, err := ioutil.ReadAll(resp.Body)
555			resp.Body.Close()
556			if err != nil {
557				t.Fatal(err)
558			}
559			if len(snap2) == 0 {
560				t.Fatal("no snapshot returned")
561			}
562
563			// Restore snapshot to move us back in time so we can test going
564			// forward
565			req = leaderClient.NewRequest("POST", "/v1/sys/storage/raft/snapshot-force")
566			req.Body = bytes.NewBuffer(snap)
567			httpReq, err = req.ToHTTP()
568			if err != nil {
569				t.Fatal(err)
570			}
571			resp, err = client.Do(httpReq)
572			if err != nil {
573				t.Fatal(err)
574			}
575
576			testhelpers.EnsureStableActiveNode(t, cluster)
577			if tCaseLocal.Rekey {
578				// Restore snapshot, should fail.
579				req = leaderClient.NewRequest("POST", "/v1/sys/storage/raft/snapshot")
580				req.Body = bytes.NewBuffer(snap2)
581				httpReq, err = req.ToHTTP()
582				if err != nil {
583					t.Fatal(err)
584				}
585				resp, err = client.Do(httpReq)
586				if err != nil {
587					t.Fatal(err)
588				}
589				// Parse Response
590				apiResp := api.Response{Response: resp}
591				if apiResp.Error() == nil || !strings.Contains(apiResp.Error().Error(), "could not verify hash file, possibly the snapshot is using a different set of unseal keys") {
592					t.Fatalf("expected error verifying hash file, got %v", apiResp.Error())
593				}
594			}
595
596			// Restore snapshot force
597			req = leaderClient.NewRequest("POST", "/v1/sys/storage/raft/snapshot-force")
598			req.Body = bytes.NewBuffer(snap2)
599			httpReq, err = req.ToHTTP()
600			if err != nil {
601				t.Fatal(err)
602			}
603			resp, err = client.Do(httpReq)
604			if err != nil {
605				t.Fatal(err)
606			}
607
608			switch tCaseLocal.ShouldSeal {
609			case true:
610				testhelpers.WaitForNCoresSealed(t, cluster, 3)
611
612			case false:
613				testhelpers.EnsureStableActiveNode(t, cluster)
614
615				// Write some data so we can make sure we can read it later. This is testing
616				// that we correctly reload the keyring
617				_, err = leaderClient.Logical().Write("secret/foo", map[string]interface{}{
618					"test": "data",
619				})
620				if err != nil {
621					t.Fatal(err)
622				}
623
624				testhelpers.EnsureCoresSealed(t, cluster)
625
626				cluster.BarrierKeys = newBarrierKeys
627				testhelpers.EnsureCoresUnsealed(t, cluster)
628				testhelpers.WaitForActiveNode(t, cluster)
629				activeCore := testhelpers.DeriveActiveCore(t, cluster)
630
631				// Read the value.
632				data, err := activeCore.Client.Logical().Read("secret/foo")
633				if err != nil {
634					t.Fatal(err)
635				}
636				if data.Data["test"].(string) != "data" {
637					t.Fatal(data)
638				}
639			}
640		})
641	}
642}
643
644func TestRaft_SnapshotAPI_DifferentCluster(t *testing.T) {
645	cluster := raftCluster(t)
646	defer cluster.Cleanup()
647
648	leaderClient := cluster.Cores[0].Client
649
650	// Write a few keys
651	for i := 0; i < 10; i++ {
652		_, err := leaderClient.Logical().Write(fmt.Sprintf("secret/%d", i), map[string]interface{}{
653			"test": "data",
654		})
655		if err != nil {
656			t.Fatal(err)
657		}
658	}
659
660	transport := cleanhttp.DefaultPooledTransport()
661	transport.TLSClientConfig = cluster.Cores[0].TLSConfig.Clone()
662	if err := http2.ConfigureTransport(transport); err != nil {
663		t.Fatal(err)
664	}
665	client := &http.Client{
666		Transport: transport,
667	}
668
669	// Take a snapshot
670	req := leaderClient.NewRequest("GET", "/v1/sys/storage/raft/snapshot")
671	httpReq, err := req.ToHTTP()
672	if err != nil {
673		t.Fatal(err)
674	}
675	resp, err := client.Do(httpReq)
676	if err != nil {
677		t.Fatal(err)
678	}
679
680	snap, err := ioutil.ReadAll(resp.Body)
681	resp.Body.Close()
682	if err != nil {
683		t.Fatal(err)
684	}
685	if len(snap) == 0 {
686		t.Fatal("no snapshot returned")
687	}
688
689	// Cluster 2
690	{
691		cluster2 := raftCluster(t)
692		defer cluster2.Cleanup()
693
694		leaderClient := cluster2.Cores[0].Client
695
696		transport := cleanhttp.DefaultPooledTransport()
697		transport.TLSClientConfig = cluster2.Cores[0].TLSConfig.Clone()
698		if err := http2.ConfigureTransport(transport); err != nil {
699			t.Fatal(err)
700		}
701		client := &http.Client{
702			Transport: transport,
703		}
704		// Restore snapshot, should fail.
705		req = leaderClient.NewRequest("POST", "/v1/sys/storage/raft/snapshot")
706		req.Body = bytes.NewBuffer(snap)
707		httpReq, err = req.ToHTTP()
708		if err != nil {
709			t.Fatal(err)
710		}
711		resp, err = client.Do(httpReq)
712		if err != nil {
713			t.Fatal(err)
714		}
715		// Parse Response
716		apiResp := api.Response{Response: resp}
717		if !strings.Contains(apiResp.Error().Error(), "could not verify hash file, possibly the snapshot is using a different set of unseal keys") {
718			t.Fatal(apiResp.Error())
719		}
720
721		// Restore snapshot force
722		req = leaderClient.NewRequest("POST", "/v1/sys/storage/raft/snapshot-force")
723		req.Body = bytes.NewBuffer(snap)
724		httpReq, err = req.ToHTTP()
725		if err != nil {
726			t.Fatal(err)
727		}
728		resp, err = client.Do(httpReq)
729		if err != nil {
730			t.Fatal(err)
731		}
732
733		testhelpers.WaitForNCoresSealed(t, cluster2, 3)
734	}
735}
736
737func BenchmarkRaft_SingleNode(b *testing.B) {
738	cluster := raftCluster(b)
739	defer cluster.Cleanup()
740
741	leaderClient := cluster.Cores[0].Client
742
743	bench := func(b *testing.B, dataSize int) {
744		data, err := uuid.GenerateRandomBytes(dataSize)
745		if err != nil {
746			b.Fatal(err)
747		}
748
749		testName := b.Name()
750
751		b.ResetTimer()
752		for i := 0; i < b.N; i++ {
753			key := fmt.Sprintf("secret/%x", md5.Sum([]byte(fmt.Sprintf("%s-%d", testName, i))))
754			_, err := leaderClient.Logical().Write(key, map[string]interface{}{
755				"test": data,
756			})
757			if err != nil {
758				b.Fatal(err)
759			}
760		}
761	}
762
763	b.Run("256b", func(b *testing.B) { bench(b, 25) })
764}
765