1// Copyright (c) 2017 Couchbase, Inc. 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14 15package vellum 16 17import ( 18 "reflect" 19 "sort" 20 "testing" 21) 22 23func TestMergeIterator(t *testing.T) { 24 25 tests := []struct { 26 desc string 27 in []map[string]uint64 28 merge MergeFunc 29 want map[string]uint64 30 }{ 31 { 32 desc: "two non-empty iterators with no duplicate keys", 33 in: []map[string]uint64{ 34 map[string]uint64{ 35 "a": 1, 36 "c": 3, 37 "e": 5, 38 }, 39 map[string]uint64{ 40 "b": 2, 41 "d": 4, 42 "f": 6, 43 }, 44 }, 45 merge: func(mvs []uint64) uint64 { 46 return mvs[0] 47 }, 48 want: map[string]uint64{ 49 "a": 1, 50 "c": 3, 51 "e": 5, 52 "b": 2, 53 "d": 4, 54 "f": 6, 55 }, 56 }, 57 { 58 desc: "two non-empty iterators with duplicate keys summed", 59 in: []map[string]uint64{ 60 map[string]uint64{ 61 "a": 1, 62 "c": 3, 63 "e": 5, 64 }, 65 map[string]uint64{ 66 "a": 2, 67 "c": 4, 68 "e": 6, 69 }, 70 }, 71 merge: func(mvs []uint64) uint64 { 72 var rv uint64 73 for _, mv := range mvs { 74 rv += mv 75 } 76 return rv 77 }, 78 want: map[string]uint64{ 79 "a": 3, 80 "c": 7, 81 "e": 11, 82 }, 83 }, 84 85 { 86 desc: "non-working example", 87 in: []map[string]uint64{ 88 map[string]uint64{ 89 "mon": 2, 90 "tues": 3, 91 "thurs": 5, 92 "tye": 99, 93 }, 94 map[string]uint64{ 95 "bold": 25, 96 "last": 1, 97 "next": 500, 98 "tank": 0, 99 }, 100 }, 101 merge: func(mvs []uint64) uint64 { 102 return mvs[0] 103 }, 104 want: map[string]uint64{ 105 "mon": 2, 106 "tues": 3, 107 "thurs": 5, 108 "tye": 99, 109 "bold": 25, 110 "last": 1, 111 "next": 500, 112 "tank": 0, 113 }, 114 }, 115 } 116 117 for _, test := range tests { 118 t.Run(test.desc, func(t *testing.T) { 119 var itrs []Iterator 120 for i := range test.in { 121 itr, err := newTestIterator(test.in[i]) 122 if err != nil && err != ErrIteratorDone { 123 t.Fatalf("error creating iterator: %v", err) 124 } 125 if err == nil { 126 itrs = append(itrs, itr) 127 } 128 } 129 mi, err := NewMergeIterator(itrs, test.merge) 130 if err != nil && err != ErrIteratorDone { 131 t.Fatalf("error creating iterator: %v", err) 132 } 133 got := make(map[string]uint64) 134 for err == nil { 135 currk, currv := mi.Current() 136 err = mi.Next() 137 got[string(currk)] = currv 138 } 139 if err != nil && err != ErrIteratorDone { 140 t.Fatalf("error iterating: %v", err) 141 } 142 143 if !reflect.DeepEqual(got, test.want) { 144 t.Errorf("expected %v, got %v", test.want, got) 145 } 146 }) 147 } 148} 149 150type testIterator struct { 151 vals map[int]uint64 152 keys []string 153 curr int 154} 155 156func newTestIterator(in map[string]uint64) (*testIterator, error) { 157 rv := &testIterator{ 158 vals: make(map[int]uint64, len(in)), 159 } 160 for k := range in { 161 rv.keys = append(rv.keys, k) 162 } 163 sort.Strings(rv.keys) 164 for i, k := range rv.keys { 165 rv.vals[i] = in[k] 166 } 167 return rv, nil 168} 169 170func (m *testIterator) Current() ([]byte, uint64) { 171 if m.curr >= len(m.keys) { 172 return nil, 0 173 } 174 return []byte(m.keys[m.curr]), m.vals[m.curr] 175} 176 177func (m *testIterator) Next() error { 178 m.curr++ 179 if m.curr >= len(m.keys) { 180 return ErrIteratorDone 181 } 182 return nil 183} 184 185func (m *testIterator) Seek(key []byte) error { 186 m.curr = sort.SearchStrings(m.keys, string(key)) 187 if m.curr >= len(m.keys) { 188 return ErrIteratorDone 189 } 190 return nil 191} 192 193func (m *testIterator) Reset(f *FST, startKeyInclusive, endKeyExclusive []byte, aut Automaton) error { 194 return nil 195} 196 197func (m *testIterator) Close() error { 198 return nil 199} 200 201func TestMergeFunc(t *testing.T) { 202 tests := []struct { 203 desc string 204 in []uint64 205 merge MergeFunc 206 want uint64 207 }{ 208 { 209 desc: "min", 210 in: []uint64{5, 99, 1}, 211 merge: MergeMin, 212 want: 1, 213 }, 214 { 215 desc: "max", 216 in: []uint64{5, 99, 1}, 217 merge: MergeMax, 218 want: 99, 219 }, 220 { 221 desc: "sum", 222 in: []uint64{5, 99, 1}, 223 merge: MergeSum, 224 want: 105, 225 }, 226 } 227 228 for _, test := range tests { 229 t.Run(test.desc, func(t *testing.T) { 230 got := test.merge(test.in) 231 if test.want != got { 232 t.Errorf("expected %d, got %d", test.want, got) 233 } 234 }) 235 } 236} 237 238func TestEmptyMergeIterator(t *testing.T) { 239 mi, err := NewMergeIterator([]Iterator{}, MergeMin) 240 if err != ErrIteratorDone { 241 t.Fatalf("expected iterator done, got %v", err) 242 } 243 244 // should get valid merge iterator anyway 245 if mi == nil { 246 t.Fatalf("expected non-nil merge iterator") 247 } 248 249 // current returns nil, 0 per interface spec 250 ck, cv := mi.Current() 251 if ck != nil { 252 t.Errorf("expected current to return nil key, got %v", ck) 253 } 254 if cv != 0 { 255 t.Errorf("expected current to return 0 val, got %d", cv) 256 } 257 258 // calling Next/Seek continues to return ErrIteratorDone 259 err = mi.Next() 260 if err != ErrIteratorDone { 261 t.Errorf("expected iterator done, got %v", err) 262 } 263 err = mi.Seek([]byte("anywhere")) 264 if err != ErrIteratorDone { 265 t.Errorf("expected iterator done, got %v", err) 266 } 267 268 err = mi.Close() 269 if err != nil { 270 t.Errorf("error closing %v", err) 271 } 272 273} 274