1//  Copyright (c) 2017 Couchbase, Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// 		http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package vellum
16
17import (
18	"reflect"
19	"sort"
20	"testing"
21)
22
23func TestMergeIterator(t *testing.T) {
24
25	tests := []struct {
26		desc  string
27		in    []map[string]uint64
28		merge MergeFunc
29		want  map[string]uint64
30	}{
31		{
32			desc: "two non-empty iterators with no duplicate keys",
33			in: []map[string]uint64{
34				map[string]uint64{
35					"a": 1,
36					"c": 3,
37					"e": 5,
38				},
39				map[string]uint64{
40					"b": 2,
41					"d": 4,
42					"f": 6,
43				},
44			},
45			merge: func(mvs []uint64) uint64 {
46				return mvs[0]
47			},
48			want: map[string]uint64{
49				"a": 1,
50				"c": 3,
51				"e": 5,
52				"b": 2,
53				"d": 4,
54				"f": 6,
55			},
56		},
57		{
58			desc: "two non-empty iterators with duplicate keys summed",
59			in: []map[string]uint64{
60				map[string]uint64{
61					"a": 1,
62					"c": 3,
63					"e": 5,
64				},
65				map[string]uint64{
66					"a": 2,
67					"c": 4,
68					"e": 6,
69				},
70			},
71			merge: func(mvs []uint64) uint64 {
72				var rv uint64
73				for _, mv := range mvs {
74					rv += mv
75				}
76				return rv
77			},
78			want: map[string]uint64{
79				"a": 3,
80				"c": 7,
81				"e": 11,
82			},
83		},
84
85		{
86			desc: "non-working example",
87			in: []map[string]uint64{
88				map[string]uint64{
89					"mon":   2,
90					"tues":  3,
91					"thurs": 5,
92					"tye":   99,
93				},
94				map[string]uint64{
95					"bold": 25,
96					"last": 1,
97					"next": 500,
98					"tank": 0,
99				},
100			},
101			merge: func(mvs []uint64) uint64 {
102				return mvs[0]
103			},
104			want: map[string]uint64{
105				"mon":   2,
106				"tues":  3,
107				"thurs": 5,
108				"tye":   99,
109				"bold":  25,
110				"last":  1,
111				"next":  500,
112				"tank":  0,
113			},
114		},
115	}
116
117	for _, test := range tests {
118		t.Run(test.desc, func(t *testing.T) {
119			var itrs []Iterator
120			for i := range test.in {
121				itr, err := newTestIterator(test.in[i])
122				if err != nil && err != ErrIteratorDone {
123					t.Fatalf("error creating iterator: %v", err)
124				}
125				if err == nil {
126					itrs = append(itrs, itr)
127				}
128			}
129			mi, err := NewMergeIterator(itrs, test.merge)
130			if err != nil && err != ErrIteratorDone {
131				t.Fatalf("error creating iterator: %v", err)
132			}
133			got := make(map[string]uint64)
134			for err == nil {
135				currk, currv := mi.Current()
136				err = mi.Next()
137				got[string(currk)] = currv
138			}
139			if err != nil && err != ErrIteratorDone {
140				t.Fatalf("error iterating: %v", err)
141			}
142
143			if !reflect.DeepEqual(got, test.want) {
144				t.Errorf("expected %v, got %v", test.want, got)
145			}
146		})
147	}
148}
149
150type testIterator struct {
151	vals map[int]uint64
152	keys []string
153	curr int
154}
155
156func newTestIterator(in map[string]uint64) (*testIterator, error) {
157	rv := &testIterator{
158		vals: make(map[int]uint64, len(in)),
159	}
160	for k := range in {
161		rv.keys = append(rv.keys, k)
162	}
163	sort.Strings(rv.keys)
164	for i, k := range rv.keys {
165		rv.vals[i] = in[k]
166	}
167	return rv, nil
168}
169
170func (m *testIterator) Current() ([]byte, uint64) {
171	if m.curr >= len(m.keys) {
172		return nil, 0
173	}
174	return []byte(m.keys[m.curr]), m.vals[m.curr]
175}
176
177func (m *testIterator) Next() error {
178	m.curr++
179	if m.curr >= len(m.keys) {
180		return ErrIteratorDone
181	}
182	return nil
183}
184
185func (m *testIterator) Seek(key []byte) error {
186	m.curr = sort.SearchStrings(m.keys, string(key))
187	if m.curr >= len(m.keys) {
188		return ErrIteratorDone
189	}
190	return nil
191}
192
193func (m *testIterator) Reset(f *FST, startKeyInclusive, endKeyExclusive []byte, aut Automaton) error {
194	return nil
195}
196
197func (m *testIterator) Close() error {
198	return nil
199}
200
201func TestMergeFunc(t *testing.T) {
202	tests := []struct {
203		desc  string
204		in    []uint64
205		merge MergeFunc
206		want  uint64
207	}{
208		{
209			desc:  "min",
210			in:    []uint64{5, 99, 1},
211			merge: MergeMin,
212			want:  1,
213		},
214		{
215			desc:  "max",
216			in:    []uint64{5, 99, 1},
217			merge: MergeMax,
218			want:  99,
219		},
220		{
221			desc:  "sum",
222			in:    []uint64{5, 99, 1},
223			merge: MergeSum,
224			want:  105,
225		},
226	}
227
228	for _, test := range tests {
229		t.Run(test.desc, func(t *testing.T) {
230			got := test.merge(test.in)
231			if test.want != got {
232				t.Errorf("expected %d, got %d", test.want, got)
233			}
234		})
235	}
236}
237
238func TestEmptyMergeIterator(t *testing.T) {
239	mi, err := NewMergeIterator([]Iterator{}, MergeMin)
240	if err != ErrIteratorDone {
241		t.Fatalf("expected iterator done, got %v", err)
242	}
243
244	// should get valid merge iterator anyway
245	if mi == nil {
246		t.Fatalf("expected non-nil merge iterator")
247	}
248
249	// current returns nil, 0 per interface spec
250	ck, cv := mi.Current()
251	if ck != nil {
252		t.Errorf("expected current to return nil key, got %v", ck)
253	}
254	if cv != 0 {
255		t.Errorf("expected current to return 0 val, got %d", cv)
256	}
257
258	// calling Next/Seek continues to return ErrIteratorDone
259	err = mi.Next()
260	if err != ErrIteratorDone {
261		t.Errorf("expected iterator done, got %v", err)
262	}
263	err = mi.Seek([]byte("anywhere"))
264	if err != ErrIteratorDone {
265		t.Errorf("expected iterator done, got %v", err)
266	}
267
268	err = mi.Close()
269	if err != nil {
270		t.Errorf("error closing %v", err)
271	}
272
273}
274