1package set
2
3import (
4	"fmt"
5	"reflect"
6	"sort"
7	"testing"
8)
9
10// TestBasicSetOps tests the fundamental operations, whose implementations operate
11// directly on the underlying data structure. The remaining operations are implemented
12// in terms of these.
13func TestBasicSetOps(t *testing.T) {
14	s := NewSet(testRules{})
15	want := map[int][]interface{}{}
16	if !reflect.DeepEqual(s.vals, want) {
17		t.Fatalf("new set has unexpected contents %#v; want %#v", s.vals, want)
18	}
19	s.Add(1)
20	want[1] = []interface{}{1}
21	if !reflect.DeepEqual(s.vals, want) {
22		t.Fatalf("after s.Add(1) set has unexpected contents %#v; want %#v", s.vals, want)
23	}
24	if !s.Has(1) {
25		t.Fatalf("s.Has(1) returned false; want true")
26	}
27	s.Add(2)
28	want[2] = []interface{}{2}
29	if !reflect.DeepEqual(s.vals, want) {
30		t.Fatalf("after s.Add(2) set has unexpected contents %#v; want %#v", s.vals, want)
31	}
32	if !s.Has(2) {
33		t.Fatalf("s.Has(2) returned false; want true")
34	}
35
36	// Our testRules cause 17 and 33 to return the same hash value as 1, so we can use this
37	// to test the situation where multiple values are in a bucket.
38	if s.Has(17) {
39		t.Fatalf("s.Has(17) returned true; want false")
40	}
41	s.Add(17)
42	s.Add(33)
43	want[1] = append(want[1], 17, 33)
44	if !reflect.DeepEqual(s.vals, want) {
45		t.Fatalf("after s.Add(17) and s.Add(33) set has unexpected contents %#v; want %#v", s.vals, want)
46	}
47	if !s.Has(17) {
48		t.Fatalf("s.Has(17) returned false; want true")
49	}
50	if !s.Has(33) {
51		t.Fatalf("s.Has(33) returned false; want true")
52	}
53
54	vals := make([]int, 0)
55	s.EachValue(func(v interface{}) {
56		vals = append(vals, v.(int))
57	})
58	sort.Ints(vals)
59	if want := []int{1, 2, 17, 33}; !reflect.DeepEqual(vals, want) {
60		t.Fatalf("wrong values from EachValue %#v; want %#v", vals, want)
61	}
62
63	s.Remove(2)
64	delete(want, 2)
65	if !reflect.DeepEqual(s.vals, want) {
66		t.Fatalf("after s.Remove(2) set has unexpected contents %#v; want %#v", s.vals, want)
67	}
68
69	s.Remove(17)
70	want[1] = []interface{}{1, 33}
71	if !reflect.DeepEqual(s.vals, want) {
72		t.Fatalf("after s.Remove(17) set has unexpected contents %#v; want %#v", s.vals, want)
73	}
74
75	s.Remove(1)
76	want[1] = []interface{}{33}
77	if !reflect.DeepEqual(s.vals, want) {
78		t.Fatalf("after s.Remove(1) set has unexpected contents %#v; want %#v", s.vals, want)
79	}
80
81	s.Remove(33)
82	delete(want, 1)
83	if !reflect.DeepEqual(s.vals, want) {
84		t.Fatalf("after s.Remove(33) set has unexpected contents %#v; want %#v", s.vals, want)
85	}
86
87	vals = make([]int, 0)
88	s.EachValue(func(v interface{}) {
89		vals = append(vals, v.(int))
90	})
91	if len(vals) > 0 {
92		t.Fatalf("s.EachValue produced values %#v; want no calls", vals)
93	}
94}
95
96func TestUnion(t *testing.T) {
97	tests := []struct {
98		s1         Set
99		s2         Set
100		wantValues []int
101	}{
102		{
103			NewSet(testRules{}),
104			NewSet(testRules{}),
105			nil,
106		},
107		{
108			NewSetFromSlice(testRules{}, []interface{}{1}),
109			NewSet(testRules{}),
110			[]int{1},
111		},
112		{
113			NewSetFromSlice(testRules{}, []interface{}{1}),
114			NewSetFromSlice(testRules{}, []interface{}{2}),
115			[]int{1, 2},
116		},
117		{
118			NewSetFromSlice(testRules{}, []interface{}{1}),
119			NewSetFromSlice(testRules{}, []interface{}{1}),
120			[]int{1},
121		},
122		{
123			NewSetFromSlice(testRules{}, []interface{}{17, 33}),
124			NewSetFromSlice(testRules{}, []interface{}{1}),
125			[]int{1, 17, 33},
126		},
127		{
128			NewSetFromSlice(testRules{}, []interface{}{17, 33}),
129			NewSetFromSlice(testRules{}, []interface{}{2, 1}),
130			[]int{1, 2, 17, 33},
131		},
132	}
133
134	for i, test := range tests {
135		t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
136			got := test.s1.Union(test.s2)
137			var gotValues []int
138			got.EachValue(func(v interface{}) {
139				gotValues = append(gotValues, v.(int))
140			})
141			sort.Ints(gotValues)
142			sort.Ints(test.wantValues)
143			if !reflect.DeepEqual(gotValues, test.wantValues) {
144				s1Values := test.s1.Values()
145				s2Values := test.s2.Values()
146				t.Errorf(
147					"wrong result %#v for %#v union %#v; want %#v",
148					gotValues,
149					s1Values,
150					s2Values,
151					test.wantValues,
152				)
153			}
154		})
155	}
156}
157
158func TestIntersection(t *testing.T) {
159	tests := []struct {
160		s1         Set
161		s2         Set
162		wantValues []int
163	}{
164		{
165			NewSet(testRules{}),
166			NewSet(testRules{}),
167			nil,
168		},
169		{
170			NewSetFromSlice(testRules{}, []interface{}{1}),
171			NewSet(testRules{}),
172			nil,
173		},
174		{
175			NewSetFromSlice(testRules{}, []interface{}{1}),
176			NewSetFromSlice(testRules{}, []interface{}{2}),
177			nil,
178		},
179		{
180			NewSetFromSlice(testRules{}, []interface{}{1}),
181			NewSetFromSlice(testRules{}, []interface{}{1}),
182			[]int{1},
183		},
184		{
185			NewSetFromSlice(testRules{}, []interface{}{1, 17}),
186			NewSetFromSlice(testRules{}, []interface{}{1, 2, 3}),
187			[]int{1},
188		},
189		{
190			NewSetFromSlice(testRules{}, []interface{}{3, 2, 1}),
191			NewSetFromSlice(testRules{}, []interface{}{1, 2, 3}),
192			[]int{1, 2, 3},
193		},
194		{
195			NewSetFromSlice(testRules{}, []interface{}{17, 33}),
196			NewSetFromSlice(testRules{}, []interface{}{1}),
197			nil,
198		},
199		{
200			NewSetFromSlice(testRules{}, []interface{}{17, 33}),
201			NewSetFromSlice(testRules{}, []interface{}{2, 1}),
202			nil,
203		},
204	}
205
206	for i, test := range tests {
207		t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
208			got := test.s1.Intersection(test.s2)
209			var gotValues []int
210			got.EachValue(func(v interface{}) {
211				gotValues = append(gotValues, v.(int))
212			})
213			sort.Ints(gotValues)
214			sort.Ints(test.wantValues)
215			if !reflect.DeepEqual(gotValues, test.wantValues) {
216				s1Values := test.s1.Values()
217				s2Values := test.s2.Values()
218				t.Errorf(
219					"wrong result %#v for %#v intersection %#v; want %#v",
220					gotValues,
221					s1Values,
222					s2Values,
223					test.wantValues,
224				)
225			}
226		})
227	}
228}
229
230func TestSubtract(t *testing.T) {
231	tests := []struct {
232		s1         Set
233		s2         Set
234		wantValues []int
235	}{
236		{
237			NewSet(testRules{}),
238			NewSet(testRules{}),
239			nil,
240		},
241		{
242			NewSetFromSlice(testRules{}, []interface{}{1}),
243			NewSet(testRules{}),
244			[]int{1},
245		},
246		{
247			NewSetFromSlice(testRules{}, []interface{}{1}),
248			NewSetFromSlice(testRules{}, []interface{}{2}),
249			[]int{1},
250		},
251		{
252			NewSetFromSlice(testRules{}, []interface{}{1}),
253			NewSetFromSlice(testRules{}, []interface{}{1}),
254			nil,
255		},
256		{
257			NewSetFromSlice(testRules{}, []interface{}{1, 17}),
258			NewSetFromSlice(testRules{}, []interface{}{1, 2, 3}),
259			[]int{17},
260		},
261		{
262			NewSetFromSlice(testRules{}, []interface{}{3, 2, 1}),
263			NewSetFromSlice(testRules{}, []interface{}{1, 2, 3}),
264			nil,
265		},
266		{
267			NewSetFromSlice(testRules{}, []interface{}{17, 33}),
268			NewSetFromSlice(testRules{}, []interface{}{1}),
269			[]int{17, 33},
270		},
271		{
272			NewSetFromSlice(testRules{}, []interface{}{17, 33}),
273			NewSetFromSlice(testRules{}, []interface{}{2, 1}),
274			[]int{17, 33},
275		},
276	}
277
278	for i, test := range tests {
279		t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
280			got := test.s1.Subtract(test.s2)
281			var gotValues []int
282			got.EachValue(func(v interface{}) {
283				gotValues = append(gotValues, v.(int))
284			})
285			sort.Ints(gotValues)
286			sort.Ints(test.wantValues)
287			if !reflect.DeepEqual(gotValues, test.wantValues) {
288				s1Values := test.s1.Values()
289				s2Values := test.s2.Values()
290				t.Errorf(
291					"wrong result %#v for %#v subtract %#v; want %#v",
292					gotValues,
293					s1Values,
294					s2Values,
295					test.wantValues,
296				)
297			}
298		})
299	}
300}
301
302func TestSymmetricDifference(t *testing.T) {
303	tests := []struct {
304		s1         Set
305		s2         Set
306		wantValues []int
307	}{
308		{
309			NewSet(testRules{}),
310			NewSet(testRules{}),
311			nil,
312		},
313		{
314			NewSetFromSlice(testRules{}, []interface{}{1}),
315			NewSet(testRules{}),
316			[]int{1},
317		},
318		{
319			NewSetFromSlice(testRules{}, []interface{}{1}),
320			NewSetFromSlice(testRules{}, []interface{}{2}),
321			[]int{1, 2},
322		},
323		{
324			NewSetFromSlice(testRules{}, []interface{}{1}),
325			NewSetFromSlice(testRules{}, []interface{}{1}),
326			nil,
327		},
328		{
329			NewSetFromSlice(testRules{}, []interface{}{1, 17}),
330			NewSetFromSlice(testRules{}, []interface{}{1, 2, 3}),
331			[]int{2, 3, 17},
332		},
333		{
334			NewSetFromSlice(testRules{}, []interface{}{3, 2, 1}),
335			NewSetFromSlice(testRules{}, []interface{}{1, 2, 3}),
336			nil,
337		},
338		{
339			NewSetFromSlice(testRules{}, []interface{}{17, 33}),
340			NewSetFromSlice(testRules{}, []interface{}{1}),
341			[]int{1, 17, 33},
342		},
343		{
344			NewSetFromSlice(testRules{}, []interface{}{17, 33}),
345			NewSetFromSlice(testRules{}, []interface{}{2, 1}),
346			[]int{1, 2, 17, 33},
347		},
348	}
349
350	for i, test := range tests {
351		t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
352			got := test.s1.SymmetricDifference(test.s2)
353			var gotValues []int
354			got.EachValue(func(v interface{}) {
355				gotValues = append(gotValues, v.(int))
356			})
357			sort.Ints(gotValues)
358			sort.Ints(test.wantValues)
359			if !reflect.DeepEqual(gotValues, test.wantValues) {
360				s1Values := test.s1.Values()
361				s2Values := test.s2.Values()
362				t.Errorf(
363					"wrong result %#v for %#v symmetric difference %#v; want %#v",
364					gotValues,
365					s1Values,
366					s2Values,
367					test.wantValues,
368				)
369			}
370		})
371	}
372}
373