1// Copyright 2013 Dario Castañé. All rights reserved.
2// Copyright 2009 The Go Authors. All rights reserved.
3// Use of this source code is governed by a BSD-style
4// license that can be found in the LICENSE file.
5
6package mergo
7
8import (
9	"fmt"
10	"io/ioutil"
11	"reflect"
12	"testing"
13	"time"
14
15	"gopkg.in/yaml.v1"
16)
17
// simpleTest is the smallest merge fixture: a struct with a single exported
// int field. Tests build it positionally (e.g. simpleTest{42}).
type simpleTest struct {
	Value int
}
21
// complexTest nests a simpleTest and mixes exported fields with an unexported
// one. TestComplexStruct and TestComplexStructWithOverwrite use sz to check
// that merging leaves the destination's unexported field untouched.
type complexTest struct {
	St simpleTest
	sz int // unexported: expected to survive merges unchanged
	Id string
}
27
// moreComplextText nests a complexTest next to two simpleTest fields; TestMap
// uses it as the destination of Map. (The spelling of the type name is kept
// as-is because other code refers to it.)
type moreComplextText struct {
	Ct complexTest
	St simpleTest
	Nt simpleTest
}
33
// pointerTest holds a pointer-to-struct field; TestPointerStruct and
// TestPointerStructNil merge into it with the destination pointer set and nil.
type pointerTest struct {
	C *simpleTest
}
37
// sliceTest carries two slice fields — one of ints and one of maps — that
// TestSliceStruct uses to check how Merge combines slices.
type sliceTest struct {
	S []int
	M []map[string]string
}
42
43func TestKb(t *testing.T) {
44	type testStruct struct {
45		Name     string
46		KeyValue map[string]interface{}
47	}
48
49	akv := make(map[string]interface{})
50	akv["Key1"] = "not value 1"
51	akv["Key2"] = "value2"
52	a := testStruct{}
53	a.Name = "A"
54	a.KeyValue = akv
55
56	bkv := make(map[string]interface{})
57	bkv["Key1"] = "value1"
58	bkv["Key3"] = "value3"
59	b := testStruct{}
60	b.Name = "B"
61	b.KeyValue = bkv
62
63	ekv := make(map[string]interface{})
64	ekv["Key1"] = "value1"
65	ekv["Key2"] = "value2"
66	ekv["Key3"] = "value3"
67	expected := testStruct{}
68	expected.Name = "B"
69	expected.KeyValue = ekv
70
71	Merge(&b, a)
72
73	if !reflect.DeepEqual(b, expected) {
74		t.Errorf("Actual: %#v did not match \nExpected: %#v", b, expected)
75	}
76}
77
78func TestNil(t *testing.T) {
79	if err := Merge(nil, nil); err != ErrNilArguments {
80		t.Fail()
81	}
82}
83
84func TestDifferentTypes(t *testing.T) {
85	a := simpleTest{42}
86	b := 42
87	if err := Merge(&a, b); err != ErrDifferentArgumentsTypes {
88		t.Fail()
89	}
90}
91
92func TestSimpleStruct(t *testing.T) {
93	a := simpleTest{}
94	b := simpleTest{42}
95	if err := Merge(&a, b); err != nil {
96		t.FailNow()
97	}
98	if a.Value != 42 {
99		t.Fatalf("b not merged in properly: a.Value(%d) != b.Value(%d)", a.Value, b.Value)
100	}
101	if !reflect.DeepEqual(a, b) {
102		t.FailNow()
103	}
104}
105
106func TestComplexStruct(t *testing.T) {
107	a := complexTest{}
108	a.Id = "athing"
109	b := complexTest{simpleTest{42}, 1, "bthing"}
110	if err := Merge(&a, b); err != nil {
111		t.FailNow()
112	}
113	if a.St.Value != 42 {
114		t.Fatalf("b not merged in properly: a.St.Value(%d) != b.St.Value(%d)", a.St.Value, b.St.Value)
115	}
116	if a.sz == 1 {
117		t.Fatalf("a's private field sz not preserved from merge: a.sz(%d) == b.sz(%d)", a.sz, b.sz)
118	}
119	if a.Id == b.Id {
120		t.Fatalf("a's field Id merged unexpectedly: a.Id(%s) == b.Id(%s)", a.Id, b.Id)
121	}
122}
123
124func TestComplexStructWithOverwrite(t *testing.T) {
125	a := complexTest{simpleTest{1}, 1, "do-not-overwrite-with-empty-value"}
126	b := complexTest{simpleTest{42}, 2, ""}
127
128	expect := complexTest{simpleTest{42}, 1, "do-not-overwrite-with-empty-value"}
129	if err := MergeWithOverwrite(&a, b); err != nil {
130		t.FailNow()
131	}
132
133	if !reflect.DeepEqual(a, expect) {
134		t.Fatalf("Test failed:\ngot  :\n%#v\n\nwant :\n%#v\n\n", a, expect)
135	}
136}
137
138func TestPointerStruct(t *testing.T) {
139	s1 := simpleTest{}
140	s2 := simpleTest{19}
141	a := pointerTest{&s1}
142	b := pointerTest{&s2}
143	if err := Merge(&a, b); err != nil {
144		t.FailNow()
145	}
146	if a.C.Value != b.C.Value {
147		t.Fatalf("b not merged in properly: a.C.Value(%d) != b.C.Value(%d)", a.C.Value, b.C.Value)
148	}
149}
150
// embeddingStruct embeds embeddedStruct so TestEmbeddedStruct can check how
// Merge treats fields promoted from an embedded struct.
type embeddingStruct struct {
	embeddedStruct
}
154
// embeddedStruct is the struct embedded in embeddingStruct; its single
// exported field A is what TestEmbeddedStruct inspects.
type embeddedStruct struct {
	A string
}
158
159func TestEmbeddedStruct(t *testing.T) {
160	tests := []struct {
161		src      embeddingStruct
162		dst      embeddingStruct
163		expected embeddingStruct
164	}{
165		{
166			src: embeddingStruct{
167				embeddedStruct{"foo"},
168			},
169			dst: embeddingStruct{
170				embeddedStruct{""},
171			},
172			expected: embeddingStruct{
173				embeddedStruct{"foo"},
174			},
175		},
176		{
177			src: embeddingStruct{
178				embeddedStruct{""},
179			},
180			dst: embeddingStruct{
181				embeddedStruct{"bar"},
182			},
183			expected: embeddingStruct{
184				embeddedStruct{"bar"},
185			},
186		},
187		{
188			src: embeddingStruct{
189				embeddedStruct{"foo"},
190			},
191			dst: embeddingStruct{
192				embeddedStruct{"bar"},
193			},
194			expected: embeddingStruct{
195				embeddedStruct{"bar"},
196			},
197		},
198	}
199
200	for _, test := range tests {
201		err := Merge(&test.dst, test.src)
202		if err != nil {
203			t.Errorf("unexpected error: %v", err)
204			continue
205		}
206		if !reflect.DeepEqual(test.dst, test.expected) {
207			t.Errorf("unexpected output\nexpected:\n%+v\nsaw:\n%+v\n", test.expected, test.dst)
208		}
209	}
210}
211
212func TestPointerStructNil(t *testing.T) {
213	a := pointerTest{nil}
214	b := pointerTest{&simpleTest{19}}
215	if err := Merge(&a, b); err != nil {
216		t.FailNow()
217	}
218	if a.C.Value != b.C.Value {
219		t.Fatalf("b not merged in a properly: a.C.Value(%d) != b.C.Value(%d)", a.C.Value, b.C.Value)
220	}
221}
222
223func TestSliceStruct(t *testing.T) {
224	a := sliceTest{}
225	b := sliceTest{S: []int{1, 2, 3}}
226	if err := Merge(&a, b); err != nil {
227		t.FailNow()
228	}
229	if len(b.S) != 3 {
230		t.FailNow()
231	}
232	if len(a.S) != len(b.S) {
233		t.Fatalf("b not merged in a proper way %d != %d", len(a.S), len(b.S))
234	}
235
236	a = sliceTest{S: []int{1}}
237	b = sliceTest{S: []int{2, 3, 4}}
238	if err := Merge(&a, b); err != nil {
239		t.FailNow()
240	}
241	if len(a.S) != 4 {
242		t.FailNow()
243	}
244
245	a = sliceTest{S: []int{1}}
246	b = sliceTest{S: []int{1, 2, 3, 4}}
247	if err := Merge(&a, b); err != nil {
248		t.FailNow()
249	}
250	if len(a.S) != 4 {
251		t.FailNow()
252	}
253
254	c := sliceTest{M: []map[string]string{
255		map[string]string{"foo": "bar"},
256		map[string]string{"baz": "qux"},
257	}}
258	d := sliceTest{M: []map[string]string{
259		map[string]string{"foo": "qux"},
260	}}
261	if err := Merge(&c, d); err != nil {
262		fmt.Println(err)
263		t.FailNow()
264	}
265	if len(c.M) != 3 {
266		fmt.Println(a)
267		t.FailNow()
268	}
269
270	c = sliceTest{M: []map[string]string{
271		map[string]string{"foo": "bar"},
272		map[string]string{"baz": "qux"},
273	}}
274	d = sliceTest{M: []map[string]string{
275		map[string]string{"foo": "bar"},
276	}}
277	if err := Merge(&c, d); err != nil {
278		fmt.Println(err)
279		t.FailNow()
280	}
281	if len(c.M) != 2 {
282		fmt.Println(a)
283		t.FailNow()
284	}
285}
286
287func TestMapsWithOverwrite(t *testing.T) {
288	m := map[string]simpleTest{
289		"a": simpleTest{},   // overwritten by 16
290		"b": simpleTest{42}, // not overwritten by empty value
291		"c": simpleTest{13}, // overwritten by 12
292		"d": simpleTest{61},
293	}
294	n := map[string]simpleTest{
295		"a": simpleTest{16},
296		"b": simpleTest{},
297		"c": simpleTest{12},
298		"e": simpleTest{14},
299	}
300	expect := map[string]simpleTest{
301		"a": simpleTest{16},
302		"b": simpleTest{},
303		"c": simpleTest{12},
304		"d": simpleTest{61},
305		"e": simpleTest{14},
306	}
307
308	if err := MergeWithOverwrite(&m, n); err != nil {
309		t.Fatalf(err.Error())
310	}
311
312	if !reflect.DeepEqual(m, expect) {
313		t.Fatalf("Test failed:\ngot  :\n%#v\n\nwant :\n%#v\n\n", m, expect)
314	}
315}
316
317func TestMaps(t *testing.T) {
318	m := map[string]simpleTest{
319		"a": simpleTest{},
320		"b": simpleTest{42},
321		"c": simpleTest{13},
322		"d": simpleTest{61},
323	}
324	n := map[string]simpleTest{
325		"a": simpleTest{16},
326		"b": simpleTest{},
327		"c": simpleTest{12},
328		"e": simpleTest{14},
329	}
330	expect := map[string]simpleTest{
331		"a": simpleTest{0},
332		"b": simpleTest{42},
333		"c": simpleTest{13},
334		"d": simpleTest{61},
335		"e": simpleTest{14},
336	}
337
338	if err := Merge(&m, n); err != nil {
339		t.Fatalf(err.Error())
340	}
341
342	if !reflect.DeepEqual(m, expect) {
343		t.Fatalf("Test failed:\ngot  :\n%#v\n\nwant :\n%#v\n\n", m, expect)
344	}
345	if m["a"].Value != 0 {
346		t.Fatalf(`n merged in m because I solved non-addressable map values TODO: m["a"].Value(%d) != n["a"].Value(%d)`, m["a"].Value, n["a"].Value)
347	}
348	if m["b"].Value != 42 {
349		t.Fatalf(`n wrongly merged in m: m["b"].Value(%d) != n["b"].Value(%d)`, m["b"].Value, n["b"].Value)
350	}
351	if m["c"].Value != 13 {
352		t.Fatalf(`n overwritten in m: m["c"].Value(%d) != n["c"].Value(%d)`, m["c"].Value, n["c"].Value)
353	}
354}
355
356func TestYAMLMaps(t *testing.T) {
357	thing := loadYAML("testdata/thing.yml")
358	license := loadYAML("testdata/license.yml")
359	ft := thing["fields"].(map[interface{}]interface{})
360	fl := license["fields"].(map[interface{}]interface{})
361	expectedLength := len(ft) + len(fl)
362	if err := Merge(&license, thing); err != nil {
363		t.Fatal(err.Error())
364	}
365	currentLength := len(license["fields"].(map[interface{}]interface{}))
366	if currentLength != expectedLength {
367		t.Fatalf(`thing not merged in license properly, license must have %d elements instead of %d`, expectedLength, currentLength)
368	}
369	fields := license["fields"].(map[interface{}]interface{})
370	if _, ok := fields["id"]; !ok {
371		t.Fatalf(`thing not merged in license properly, license must have a new id field from thing`)
372	}
373}
374
375func TestTwoPointerValues(t *testing.T) {
376	a := &simpleTest{}
377	b := &simpleTest{42}
378	if err := Merge(a, b); err != nil {
379		t.Fatalf(`Boom. You crossed the streams: %s`, err)
380	}
381}
382
383func TestMap(t *testing.T) {
384	a := complexTest{}
385	a.Id = "athing"
386	c := moreComplextText{a, simpleTest{}, simpleTest{}}
387	b := map[string]interface{}{
388		"ct": map[string]interface{}{
389			"st": map[string]interface{}{
390				"value": 42,
391			},
392			"sz": 1,
393			"id": "bthing",
394		},
395		"st": &simpleTest{144}, // Mapping a reference
396		"zt": simpleTest{299},  // Mapping a missing field (zt doesn't exist)
397		"nt": simpleTest{3},
398	}
399	if err := Map(&c, b); err != nil {
400		t.FailNow()
401	}
402	m := b["ct"].(map[string]interface{})
403	n := m["st"].(map[string]interface{})
404	o := b["st"].(*simpleTest)
405	p := b["nt"].(simpleTest)
406	if c.Ct.St.Value != 42 {
407		t.Fatalf("b not merged in properly: c.Ct.St.Value(%d) != b.Ct.St.Value(%d)", c.Ct.St.Value, n["value"])
408	}
409	if c.St.Value != 144 {
410		t.Fatalf("b not merged in properly: c.St.Value(%d) != b.St.Value(%d)", c.St.Value, o.Value)
411	}
412	if c.Nt.Value != 3 {
413		t.Fatalf("b not merged in properly: c.Nt.Value(%d) != b.Nt.Value(%d)", c.St.Value, p.Value)
414	}
415	if c.Ct.sz == 1 {
416		t.Fatalf("a's private field sz not preserved from merge: c.Ct.sz(%d) == b.Ct.sz(%d)", c.Ct.sz, m["sz"])
417	}
418	if c.Ct.Id == m["id"] {
419		t.Fatalf("a's field Id merged unexpectedly: c.Ct.Id(%s) == b.Ct.Id(%s)", c.Ct.Id, m["id"])
420	}
421}
422
423func TestSimpleMap(t *testing.T) {
424	a := simpleTest{}
425	b := map[string]interface{}{
426		"value": 42,
427	}
428	if err := Map(&a, b); err != nil {
429		t.FailNow()
430	}
431	if a.Value != 42 {
432		t.Fatalf("b not merged in properly: a.Value(%d) != b.Value(%v)", a.Value, b["value"])
433	}
434}
435
// pointerMapTest is the fixture for TestBackAndForth's struct -> map -> struct
// round trip. It mixes a plain exported field (A), an unexported field
// (hidden, expected not to survive the round trip) and a pointer field (B).
type pointerMapTest struct {
	A      int
	hidden int
	B      *simpleTest
}
441
442func TestBackAndForth(t *testing.T) {
443	pt := pointerMapTest{42, 1, &simpleTest{66}}
444	m := make(map[string]interface{})
445	if err := Map(&m, pt); err != nil {
446		t.FailNow()
447	}
448	var (
449		v  interface{}
450		ok bool
451	)
452	if v, ok = m["a"]; v.(int) != pt.A || !ok {
453		t.Fatalf("pt not merged in properly: m[`a`](%d) != pt.A(%d)", v, pt.A)
454	}
455	if v, ok = m["b"]; !ok {
456		t.Fatalf("pt not merged in properly: B is missing in m")
457	}
458	var st *simpleTest
459	if st = v.(*simpleTest); st.Value != 66 {
460		t.Fatalf("something went wrong while mapping pt on m, B wasn't copied")
461	}
462	bpt := pointerMapTest{}
463	if err := Map(&bpt, m); err != nil {
464		t.Fatal(err)
465	}
466	if bpt.A != pt.A {
467		t.Fatalf("pt not merged in properly: bpt.A(%d) != pt.A(%d)", bpt.A, pt.A)
468	}
469	if bpt.hidden == pt.hidden {
470		t.Fatalf("pt unexpectedly merged: bpt.hidden(%d) == pt.hidden(%d)", bpt.hidden, pt.hidden)
471	}
472	if bpt.B.Value != pt.B.Value {
473		t.Fatalf("pt not merged in properly: bpt.B.Value(%d) != pt.B.Value(%d)", bpt.B.Value, pt.B.Value)
474	}
475}
476
// structWithTimePointer holds a *time.Time; TestTime uses it to check that
// Merge and Map handle pointers to time.Time.
type structWithTimePointer struct {
	Birth *time.Time
}
480
481func TestTime(t *testing.T) {
482	now := time.Now()
483	dataStruct := structWithTimePointer{
484		Birth: &now,
485	}
486	dataMap := map[string]interface{}{
487		"Birth": &now,
488	}
489	b := structWithTimePointer{}
490	if err := Merge(&b, dataStruct); err != nil {
491		t.FailNow()
492	}
493	if b.Birth.IsZero() {
494		t.Fatalf("time.Time not merged in properly: b.Birth(%v) != dataStruct['Birth'](%v)", b.Birth, dataStruct.Birth)
495	}
496	if b.Birth != dataStruct.Birth {
497		t.Fatalf("time.Time not merged in properly: b.Birth(%v) != dataStruct['Birth'](%v)", b.Birth, dataStruct.Birth)
498	}
499	b = structWithTimePointer{}
500	if err := Map(&b, dataMap); err != nil {
501		t.FailNow()
502	}
503	if b.Birth.IsZero() {
504		t.Fatalf("time.Time not merged in properly: b.Birth(%v) != dataMap['Birth'](%v)", b.Birth, dataMap["Birth"])
505	}
506}
507
508func loadYAML(path string) (m map[string]interface{}) {
509	m = make(map[string]interface{})
510	raw, _ := ioutil.ReadFile(path)
511	_ = yaml.Unmarshal(raw, &m)
512	return
513}
514