1package raft 2 3import ( 4 "context" 5 fmt "fmt" 6 "io/ioutil" 7 "math/rand" 8 "os" 9 "sort" 10 "testing" 11 12 "github.com/go-test/deep" 13 proto "github.com/golang/protobuf/proto" 14 hclog "github.com/hashicorp/go-hclog" 15 "github.com/hashicorp/raft" 16 "github.com/hashicorp/vault/sdk/physical" 17) 18 19func getFSM(t testing.TB) (*FSM, string) { 20 raftDir, err := ioutil.TempDir("", "vault-raft-") 21 if err != nil { 22 t.Fatal(err) 23 } 24 t.Logf("raft dir: %s", raftDir) 25 26 logger := hclog.New(&hclog.LoggerOptions{ 27 Name: "raft", 28 Level: hclog.Trace, 29 }) 30 31 fsm, err := NewFSM(raftDir, "", logger) 32 if err != nil { 33 t.Fatal(err) 34 } 35 36 return fsm, raftDir 37} 38 39func TestFSM_Batching(t *testing.T) { 40 fsm, dir := getFSM(t) 41 defer os.RemoveAll(dir) 42 43 var index uint64 44 var term uint64 = 1 45 46 getLog := func(i uint64) (int, *raft.Log) { 47 if rand.Intn(10) >= 8 { 48 term += 1 49 return 0, &raft.Log{ 50 Index: i, 51 Term: term, 52 Type: raft.LogConfiguration, 53 Data: raft.EncodeConfiguration(raft.Configuration{ 54 Servers: []raft.Server{ 55 { 56 Address: raft.ServerAddress("test"), 57 ID: raft.ServerID("test"), 58 }, 59 }, 60 }), 61 } 62 } 63 64 command := &LogData{ 65 Operations: make([]*LogOperation, rand.Intn(10)), 66 } 67 68 for j := range command.Operations { 69 command.Operations[j] = &LogOperation{ 70 OpType: putOp, 71 Key: fmt.Sprintf("key-%d-%d", i, j), 72 Value: []byte(fmt.Sprintf("value-%d-%d", i, j)), 73 } 74 } 75 commandBytes, err := proto.Marshal(command) 76 if err != nil { 77 t.Fatal(err) 78 } 79 return len(command.Operations), &raft.Log{ 80 Index: i, 81 Term: term, 82 Type: raft.LogCommand, 83 Data: commandBytes, 84 } 85 } 86 87 totalKeys := 0 88 for i := 0; i < 100; i++ { 89 batchSize := rand.Intn(64) 90 batch := make([]*raft.Log, batchSize) 91 for j := 0; j < batchSize; j++ { 92 var keys int 93 index++ 94 keys, batch[j] = getLog(index) 95 totalKeys += keys 96 } 97 98 resp := fsm.ApplyBatch(batch) 99 if len(resp) != batchSize { 100 t.Fatalf("incorrect response length: got %d expected %d", len(resp), batchSize) 101 } 102 103 for _, r := range resp { 104 if _, ok := r.(*FSMApplyResponse); !ok { 105 t.Fatal("bad response type") 106 } 107 } 108 } 109 110 keys, err := fsm.List(context.Background(), "") 111 if err != nil { 112 t.Fatal(err) 113 } 114 115 if len(keys) != totalKeys { 116 t.Fatalf("incorrect number of keys: got %d expected %d", len(keys), totalKeys) 117 } 118 119 latestIndex, latestConfig := fsm.LatestState() 120 if latestIndex.Index != index { 121 t.Fatalf("bad latest index: got %d expected %d", latestIndex.Index, index) 122 } 123 if latestIndex.Term != term { 124 t.Fatalf("bad latest term: got %d expected %d", latestIndex.Term, term) 125 } 126 127 if latestConfig == nil && term > 1 { 128 t.Fatal("config wasn't updated") 129 } 130} 131 132func TestFSM_List(t *testing.T) { 133 fsm, dir := getFSM(t) 134 defer os.RemoveAll(dir) 135 136 ctx := context.Background() 137 count := 100 138 keys := rand.Perm(count) 139 var sorted []string 140 for _, k := range keys { 141 err := fsm.Put(ctx, &physical.Entry{Key: fmt.Sprintf("foo/%d/bar", k)}) 142 if err != nil { 143 t.Fatal(err) 144 } 145 err = fsm.Put(ctx, &physical.Entry{Key: fmt.Sprintf("foo/%d/baz", k)}) 146 if err != nil { 147 t.Fatal(err) 148 } 149 sorted = append(sorted, fmt.Sprintf("%d/", k)) 150 } 151 sort.Strings(sorted) 152 153 got, err := fsm.List(ctx, "foo/") 154 if err != nil { 155 t.Fatal(err) 156 } 157 sort.Strings(got) 158 if diff := deep.Equal(sorted, got); len(diff) > 0 { 159 t.Fatal(diff) 160 } 161} 162