1// Copyright 2013 Dario Castañé. All rights reserved.
2// Copyright 2009 The Go Authors. All rights reserved.
3// Use of this source code is governed by a BSD-style
4// license that can be found in the LICENSE file.
5
6// Based on src/pkg/reflect/deepequal.go from official
7// golang's stdlib.
8
9package mergo
10
11import (
12	"fmt"
13	"reflect"
14)
15
16func hasMergeableFields(dst reflect.Value) (exported bool) {
17	for i, n := 0, dst.NumField(); i < n; i++ {
18		field := dst.Type().Field(i)
19		if field.Anonymous && dst.Field(i).Kind() == reflect.Struct {
20			exported = exported || hasMergeableFields(dst.Field(i))
21		} else if isExportedComponent(&field) {
22			exported = exported || len(field.PkgPath) == 0
23		}
24	}
25	return
26}
27
// isExportedComponent reports whether field is usable as a merge target:
// its PkgPath must be empty and its name must not begin with an ASCII
// lower-case letter or an underscore.
func isExportedComponent(field *reflect.StructField) bool {
	if field.PkgPath != "" {
		// A non-empty PkgPath marks the field as unexported.
		return false
	}
	first := field.Name[0]
	startsLower := first >= 'a' && first <= 'z'
	return !startsLower && first != '_'
}
39
// Config holds the switches that steer deepMerge. merge starts from the zero
// value and lets each option function passed to Merge mutate it.
type Config struct {
	// Overwrite allows dst values that are not empty to be replaced by src
	// values. Set by WithOverride, WithOverwriteWithEmptyValue and
	// WithSliceDeepCopy.
	Overwrite                    bool
	// AppendSlice makes deepMerge append src slices to dst slices instead of
	// replacing them. Set by WithAppendSlice.
	AppendSlice                  bool
	// TypeCheck makes deepMerge return an error when a slice stored in a map
	// would be replaced by a slice of a different type. Set by WithTypeCheck.
	TypeCheck                    bool
	// Transformers, when non-nil, is asked for a custom merge function for
	// the type of every non-empty dst value deepMerge visits.
	Transformers                 Transformers
	// overwriteWithEmptyValue lets empty src values replace dst values.
	// Set by WithOverwriteWithEmptyValue.
	overwriteWithEmptyValue      bool
	// overwriteSliceWithEmptyValue lets an empty src slice take part in the
	// slice replacement branch of deepMerge. Set by WithOverrideEmptySlice.
	overwriteSliceWithEmptyValue bool
	// sliceDeepCopy makes deepMerge merge slices element by element.
	// Set by WithSliceDeepCopy.
	sliceDeepCopy                bool
	// debug is neither read nor written in this part of the file.
	// NOTE(review): confirm whether it is still used elsewhere.
	debug                        bool
}
50
// Transformers supplies custom merge functions keyed by type.
//
// deepMerge calls Transformer with the type of each non-empty dst value. When
// the returned function is non-nil it is invoked with (dst, src) and its
// error becomes the result for that value, replacing the built-in merge.
// Returning nil lets the built-in merge proceed.
type Transformers interface {
	Transformer(reflect.Type) func(dst, src reflect.Value) error
}
54
55// Traverses recursively both values, assigning src's fields values to dst.
56// The map argument tracks comparisons that have already been seen, which allows
57// short circuiting on recursive types.
58func deepMerge(dst, src reflect.Value, visited map[uintptr]*visit, depth int, config *Config) (err error) {
59	overwrite := config.Overwrite
60	typeCheck := config.TypeCheck
61	overwriteWithEmptySrc := config.overwriteWithEmptyValue
62	overwriteSliceWithEmptySrc := config.overwriteSliceWithEmptyValue
63	sliceDeepCopy := config.sliceDeepCopy
64
65	if !src.IsValid() {
66		return
67	}
68	if dst.CanAddr() {
69		addr := dst.UnsafeAddr()
70		h := 17 * addr
71		seen := visited[h]
72		typ := dst.Type()
73		for p := seen; p != nil; p = p.next {
74			if p.ptr == addr && p.typ == typ {
75				return nil
76			}
77		}
78		// Remember, remember...
79		visited[h] = &visit{addr, typ, seen}
80	}
81
82	if config.Transformers != nil && !isEmptyValue(dst) {
83		if fn := config.Transformers.Transformer(dst.Type()); fn != nil {
84			err = fn(dst, src)
85			return
86		}
87	}
88
89	switch dst.Kind() {
90	case reflect.Struct:
91		if hasMergeableFields(dst) {
92			for i, n := 0, dst.NumField(); i < n; i++ {
93				if err = deepMerge(dst.Field(i), src.Field(i), visited, depth+1, config); err != nil {
94					return
95				}
96			}
97		} else {
98			if dst.CanSet() && (isReflectNil(dst) || overwrite) && (!isEmptyValue(src) || overwriteWithEmptySrc) {
99				dst.Set(src)
100			}
101		}
102	case reflect.Map:
103		if dst.IsNil() && !src.IsNil() {
104			if dst.CanSet() {
105				dst.Set(reflect.MakeMap(dst.Type()))
106			} else {
107				dst = src
108				return
109			}
110		}
111
112		if src.Kind() != reflect.Map {
113			if overwrite {
114				dst.Set(src)
115			}
116			return
117		}
118
119		for _, key := range src.MapKeys() {
120			srcElement := src.MapIndex(key)
121			if !srcElement.IsValid() {
122				continue
123			}
124			dstElement := dst.MapIndex(key)
125			switch srcElement.Kind() {
126			case reflect.Chan, reflect.Func, reflect.Map, reflect.Interface, reflect.Slice:
127				if srcElement.IsNil() {
128					if overwrite {
129						dst.SetMapIndex(key, srcElement)
130					}
131					continue
132				}
133				fallthrough
134			default:
135				if !srcElement.CanInterface() {
136					continue
137				}
138				switch reflect.TypeOf(srcElement.Interface()).Kind() {
139				case reflect.Struct:
140					fallthrough
141				case reflect.Ptr:
142					fallthrough
143				case reflect.Map:
144					srcMapElm := srcElement
145					dstMapElm := dstElement
146					if srcMapElm.CanInterface() {
147						srcMapElm = reflect.ValueOf(srcMapElm.Interface())
148						if dstMapElm.IsValid() {
149							dstMapElm = reflect.ValueOf(dstMapElm.Interface())
150						}
151					}
152					if err = deepMerge(dstMapElm, srcMapElm, visited, depth+1, config); err != nil {
153						return
154					}
155				case reflect.Slice:
156					srcSlice := reflect.ValueOf(srcElement.Interface())
157
158					var dstSlice reflect.Value
159					if !dstElement.IsValid() || dstElement.IsNil() {
160						dstSlice = reflect.MakeSlice(srcSlice.Type(), 0, srcSlice.Len())
161					} else {
162						dstSlice = reflect.ValueOf(dstElement.Interface())
163					}
164
165					if (!isEmptyValue(src) || overwriteWithEmptySrc || overwriteSliceWithEmptySrc) && (overwrite || isEmptyValue(dst)) && !config.AppendSlice && !sliceDeepCopy {
166						if typeCheck && srcSlice.Type() != dstSlice.Type() {
167							return fmt.Errorf("cannot override two slices with different type (%s, %s)", srcSlice.Type(), dstSlice.Type())
168						}
169						dstSlice = srcSlice
170					} else if config.AppendSlice {
171						if srcSlice.Type() != dstSlice.Type() {
172							return fmt.Errorf("cannot append two slices with different type (%s, %s)", srcSlice.Type(), dstSlice.Type())
173						}
174						dstSlice = reflect.AppendSlice(dstSlice, srcSlice)
175					} else if sliceDeepCopy {
176						i := 0
177						for ; i < srcSlice.Len() && i < dstSlice.Len(); i++ {
178							srcElement := srcSlice.Index(i)
179							dstElement := dstSlice.Index(i)
180
181							if srcElement.CanInterface() {
182								srcElement = reflect.ValueOf(srcElement.Interface())
183							}
184							if dstElement.CanInterface() {
185								dstElement = reflect.ValueOf(dstElement.Interface())
186							}
187
188							if err = deepMerge(dstElement, srcElement, visited, depth+1, config); err != nil {
189								return
190							}
191						}
192
193					}
194					dst.SetMapIndex(key, dstSlice)
195				}
196			}
197			if dstElement.IsValid() && !isEmptyValue(dstElement) && (reflect.TypeOf(srcElement.Interface()).Kind() == reflect.Map || reflect.TypeOf(srcElement.Interface()).Kind() == reflect.Slice) {
198				continue
199			}
200
201			if srcElement.IsValid() && ((srcElement.Kind() != reflect.Ptr && overwrite) || !dstElement.IsValid() || isEmptyValue(dstElement)) {
202				if dst.IsNil() {
203					dst.Set(reflect.MakeMap(dst.Type()))
204				}
205				dst.SetMapIndex(key, srcElement)
206			}
207		}
208	case reflect.Slice:
209		if !dst.CanSet() {
210			break
211		}
212		if (!isEmptyValue(src) || overwriteWithEmptySrc || overwriteSliceWithEmptySrc) && (overwrite || isEmptyValue(dst)) && !config.AppendSlice && !sliceDeepCopy {
213			dst.Set(src)
214		} else if config.AppendSlice {
215			if src.Type() != dst.Type() {
216				return fmt.Errorf("cannot append two slice with different type (%s, %s)", src.Type(), dst.Type())
217			}
218			dst.Set(reflect.AppendSlice(dst, src))
219		} else if sliceDeepCopy {
220			for i := 0; i < src.Len() && i < dst.Len(); i++ {
221				srcElement := src.Index(i)
222				dstElement := dst.Index(i)
223				if srcElement.CanInterface() {
224					srcElement = reflect.ValueOf(srcElement.Interface())
225				}
226				if dstElement.CanInterface() {
227					dstElement = reflect.ValueOf(dstElement.Interface())
228				}
229
230				if err = deepMerge(dstElement, srcElement, visited, depth+1, config); err != nil {
231					return
232				}
233			}
234		}
235	case reflect.Ptr:
236		fallthrough
237	case reflect.Interface:
238		if isReflectNil(src) {
239			if overwriteWithEmptySrc && dst.CanSet() && src.Type().AssignableTo(dst.Type()) {
240				dst.Set(src)
241			}
242			break
243		}
244
245		if src.Kind() != reflect.Interface {
246			if dst.IsNil() || (src.Kind() != reflect.Ptr && overwrite) {
247				if dst.CanSet() && (overwrite || isEmptyValue(dst)) {
248					dst.Set(src)
249				}
250			} else if src.Kind() == reflect.Ptr {
251				if err = deepMerge(dst.Elem(), src.Elem(), visited, depth+1, config); err != nil {
252					return
253				}
254			} else if dst.Elem().Type() == src.Type() {
255				if err = deepMerge(dst.Elem(), src, visited, depth+1, config); err != nil {
256					return
257				}
258			} else {
259				return ErrDifferentArgumentsTypes
260			}
261			break
262		}
263
264		if dst.IsNil() || overwrite {
265			if dst.CanSet() && (overwrite || isEmptyValue(dst)) {
266				dst.Set(src)
267			}
268			break
269		}
270
271		if dst.Elem().Kind() == src.Elem().Kind() {
272			if err = deepMerge(dst.Elem(), src.Elem(), visited, depth+1, config); err != nil {
273				return
274			}
275			break
276		}
277	default:
278		mustSet := (isEmptyValue(dst) || overwrite) && (!isEmptyValue(src) || overwriteWithEmptySrc)
279		if mustSet {
280			if dst.CanSet() {
281				dst.Set(src)
282			} else {
283				dst = src
284			}
285		}
286	}
287
288	return
289}
290
291// Merge will fill any empty for value type attributes on the dst struct using corresponding
292// src attributes if they themselves are not empty. dst and src must be valid same-type structs
293// and dst must be a pointer to struct.
294// It won't merge unexported (private) fields and will do recursively any exported field.
295func Merge(dst, src interface{}, opts ...func(*Config)) error {
296	return merge(dst, src, opts...)
297}
298
299// MergeWithOverwrite will do the same as Merge except that non-empty dst attributes will be overridden by
300// non-empty src attribute values.
301// Deprecated: use Merge(…) with WithOverride
302func MergeWithOverwrite(dst, src interface{}, opts ...func(*Config)) error {
303	return merge(dst, src, append(opts, WithOverride)...)
304}
305
306// WithTransformers adds transformers to merge, allowing to customize the merging of some types.
307func WithTransformers(transformers Transformers) func(*Config) {
308	return func(config *Config) {
309		config.Transformers = transformers
310	}
311}
312
// WithOverride is a merge option that sets Config.Overwrite, so that
// non-empty dst attributes are overridden by non-empty src attribute values.
func WithOverride(config *Config) {
	config.Overwrite = true
}
317
318// WithOverwriteWithEmptyValue will make merge override non empty dst attributes with empty src attributes values.
319func WithOverwriteWithEmptyValue(config *Config) {
320	config.Overwrite = true
321	config.overwriteWithEmptyValue = true
322}
323
// WithOverrideEmptySlice is a merge option that lets an empty src slice
// override an empty dst slice. It does not set Config.Overwrite, so a
// non-empty dst slice is only replaced when overwriting is enabled separately.
func WithOverrideEmptySlice(config *Config) {
	config.overwriteSliceWithEmptyValue = true
}
328
// WithAppendSlice is a merge option that makes merge append src slices to dst
// slices instead of overwriting them. Appending two slices of different types
// is reported as an error by deepMerge.
func WithAppendSlice(config *Config) {
	config.AppendSlice = true
}
333
// WithTypeCheck is a merge option that makes merge check types while
// overwriting (must be used with WithOverride). deepMerge consults the flag
// when a slice stored in a map is about to be replaced and returns an error
// if the two slice types differ.
func WithTypeCheck(config *Config) {
	config.TypeCheck = true
}
338
339// WithSliceDeepCopy will merge slice element one by one with Overwrite flag.
340func WithSliceDeepCopy(config *Config) {
341	config.sliceDeepCopy = true
342	config.Overwrite = true
343}
344
345func merge(dst, src interface{}, opts ...func(*Config)) error {
346	if dst != nil && reflect.ValueOf(dst).Kind() != reflect.Ptr {
347		return ErrNonPointerAgument
348	}
349	var (
350		vDst, vSrc reflect.Value
351		err        error
352	)
353
354	config := &Config{}
355
356	for _, opt := range opts {
357		opt(config)
358	}
359
360	if vDst, vSrc, err = resolveValues(dst, src); err != nil {
361		return err
362	}
363	if vDst.Type() != vSrc.Type() {
364		return ErrDifferentArgumentsTypes
365	}
366	return deepMerge(vDst, vSrc, make(map[uintptr]*visit), 0, config)
367}
368
// isReflectNil reports whether v holds a nil value. Only kinds on which
// reflect.Value.IsNil is called here (chan, func, interface, map, pointer and
// slice) can yield true; every other kind, including the invalid zero Value,
// yields false.
func isReflectNil(v reflect.Value) bool {
	switch v.Kind() {
	case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice:
		return v.IsNil()
	}
	return false
}
381