1package raft
2
3import (
4	"context"
5	fmt "fmt"
6	"io/ioutil"
7	"math/rand"
8	"os"
9	"sort"
10	"testing"
11
12	"github.com/go-test/deep"
13	proto "github.com/golang/protobuf/proto"
14	hclog "github.com/hashicorp/go-hclog"
15	"github.com/hashicorp/raft"
16	"github.com/hashicorp/vault/sdk/physical"
17)
18
19func getFSM(t testing.TB) (*FSM, string) {
20	raftDir, err := ioutil.TempDir("", "vault-raft-")
21	if err != nil {
22		t.Fatal(err)
23	}
24	t.Logf("raft dir: %s", raftDir)
25
26	logger := hclog.New(&hclog.LoggerOptions{
27		Name:  "raft",
28		Level: hclog.Trace,
29	})
30
31	fsm, err := NewFSM(raftDir, "", logger)
32	if err != nil {
33		t.Fatal(err)
34	}
35
36	return fsm, raftDir
37}
38
39func TestFSM_Batching(t *testing.T) {
40	fsm, dir := getFSM(t)
41	defer os.RemoveAll(dir)
42
43	var index uint64
44	var term uint64 = 1
45
46	getLog := func(i uint64) (int, *raft.Log) {
47		if rand.Intn(10) >= 8 {
48			term += 1
49			return 0, &raft.Log{
50				Index: i,
51				Term:  term,
52				Type:  raft.LogConfiguration,
53				Data: raft.EncodeConfiguration(raft.Configuration{
54					Servers: []raft.Server{
55						{
56							Address: raft.ServerAddress("test"),
57							ID:      raft.ServerID("test"),
58						},
59					},
60				}),
61			}
62		}
63
64		command := &LogData{
65			Operations: make([]*LogOperation, rand.Intn(10)),
66		}
67
68		for j := range command.Operations {
69			command.Operations[j] = &LogOperation{
70				OpType: putOp,
71				Key:    fmt.Sprintf("key-%d-%d", i, j),
72				Value:  []byte(fmt.Sprintf("value-%d-%d", i, j)),
73			}
74		}
75		commandBytes, err := proto.Marshal(command)
76		if err != nil {
77			t.Fatal(err)
78		}
79		return len(command.Operations), &raft.Log{
80			Index: i,
81			Term:  term,
82			Type:  raft.LogCommand,
83			Data:  commandBytes,
84		}
85	}
86
87	totalKeys := 0
88	for i := 0; i < 100; i++ {
89		batchSize := rand.Intn(64)
90		batch := make([]*raft.Log, batchSize)
91		for j := 0; j < batchSize; j++ {
92			var keys int
93			index++
94			keys, batch[j] = getLog(index)
95			totalKeys += keys
96		}
97
98		resp := fsm.ApplyBatch(batch)
99		if len(resp) != batchSize {
100			t.Fatalf("incorrect response length: got %d expected %d", len(resp), batchSize)
101		}
102
103		for _, r := range resp {
104			if _, ok := r.(*FSMApplyResponse); !ok {
105				t.Fatal("bad response type")
106			}
107		}
108	}
109
110	keys, err := fsm.List(context.Background(), "")
111	if err != nil {
112		t.Fatal(err)
113	}
114
115	if len(keys) != totalKeys {
116		t.Fatalf("incorrect number of keys: got %d expected %d", len(keys), totalKeys)
117	}
118
119	latestIndex, latestConfig := fsm.LatestState()
120	if latestIndex.Index != index {
121		t.Fatalf("bad latest index: got %d expected %d", latestIndex.Index, index)
122	}
123	if latestIndex.Term != term {
124		t.Fatalf("bad latest term: got %d expected %d", latestIndex.Term, term)
125	}
126
127	if latestConfig == nil && term > 1 {
128		t.Fatal("config wasn't updated")
129	}
130}
131
132func TestFSM_List(t *testing.T) {
133	fsm, dir := getFSM(t)
134	defer os.RemoveAll(dir)
135
136	ctx := context.Background()
137	count := 100
138	keys := rand.Perm(count)
139	var sorted []string
140	for _, k := range keys {
141		err := fsm.Put(ctx, &physical.Entry{Key: fmt.Sprintf("foo/%d/bar", k)})
142		if err != nil {
143			t.Fatal(err)
144		}
145		err = fsm.Put(ctx, &physical.Entry{Key: fmt.Sprintf("foo/%d/baz", k)})
146		if err != nil {
147			t.Fatal(err)
148		}
149		sorted = append(sorted, fmt.Sprintf("%d/", k))
150	}
151	sort.Strings(sorted)
152
153	got, err := fsm.List(ctx, "foo/")
154	if err != nil {
155		t.Fatal(err)
156	}
157	sort.Strings(got)
158	if diff := deep.Equal(sorted, got); len(diff) > 0 {
159		t.Fatal(diff)
160	}
161}
162