1// Copyright (c) 2015 Hiram Jerónimo Pérez https://worg.xyz
2
3// Permission is hereby granted, free of charge, to any person obtaining a copy
4// of this software and associated documentation files (the "Software"), to deal
5// in the Software without restriction, including without limitation the rights
6// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7// copies of the Software, and to permit persons to whom the Software is
8// furnished to do so, subject to the following conditions:
9
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
19// THE SOFTWARE.
20
// Package merger is a utility to merge structs of the same type.
22package merger
23
24import (
25	"errors"
26	"reflect"
27	"strings"
28)
29
30var (
31	// ErrDistinctType occurs when trying to merge structs of distinct type
32	ErrDistinctType = errors.New(`dst and src must be of the same type`)
33	// ErrNoPtr occurs when no struct pointer is sent as destination
34	ErrNoPtr = errors.New(`dst must be a pointer to a struct`)
35	// ErrNilArguments occurs on receiving nil as arguments
36	ErrNilArguments = errors.New(`no nil values allowed`)
37	// ErrUnknown occurs if the type can't be merged
38	ErrUnknown = errors.New(`could not merge`)
39)
40
41// Merge sets zero values from dst to non zero values of src
42// accepts two structs of the same type as arguments
43// dst must be a pointer to a struct
44func Merge(dst, src interface{}) error {
45	if dst == nil || src == nil {
46		return ErrNilArguments
47	}
48
49	if !isStructPtr(dst) {
50		return ErrNoPtr
51	}
52
53	if !typesMatch(src, dst) {
54		return ErrDistinctType
55	}
56
57	vSrc := getValue(src)
58	vDst := getValue(dst)
59
60	for i := 0; i < vSrc.NumField(); i++ {
61		df := vDst.Field(i)
62		sf := vSrc.Field(i)
63		if err := merge(df, sf); err != nil {
64			return err
65		}
66	}
67
68	return nil
69}
70
71// merge merges two reflect values based upon their kinds
72func merge(dst, src reflect.Value) (err error) {
73	if dst.CanSet() && !isZero(src) {
74		switch dst.Kind() {
75		// base types
76		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
77			reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr,
78			reflect.Float32, reflect.Float64, reflect.String, reflect.Bool:
79			if isZero(dst) {
80				switch dst.Kind() {
81				case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
82					dst.SetInt(src.Int())
83				case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
84					dst.SetUint(src.Uint())
85				case reflect.Float32, reflect.Float64:
86					dst.SetFloat(src.Float())
87				case reflect.String:
88					dst.SetString(src.String())
89				case reflect.Bool:
90					dst.SetBool(src.Bool())
91				}
92			}
93		case reflect.Slice:
94			dst.Set(mergeSlice(dst, src))
95		case reflect.Struct:
96			// handle structs with IsZero method [ie time.Time]
97			if fnZero, ok := dst.Type().MethodByName(`IsZero`); ok {
98				res := fnZero.Func.Call([]reflect.Value{dst})
99				if len(res) > 0 {
100					if v, isok := res[0].Interface().(bool); isok && v {
101						dst.Set(src)
102					}
103				}
104			}
105
106			for i := 0; i < src.NumField(); i++ {
107				df := dst.Field(i)
108				sf := src.Field(i)
109				if err := merge(df, sf); err != nil {
110					return err
111				}
112			}
113		case reflect.Map:
114			dst.Set(mergeMap(dst, src))
115		case reflect.Ptr:
116			// defer pointers
117			if !dst.IsNil() {
118				dst = getValue(dst)
119			} else {
120				dst.Set(src)
121				break
122			}
123			if src.CanAddr() && src.IsNil() {
124				src = getValue(src)
125				if err := merge(dst, src); err != nil {
126					return err
127				}
128			}
129		default:
130			return ErrUnknown
131		}
132	}
133	return
134}
135
136// mergeSlice merges two slices only if dst slice fields are zero and
137// src fields are nonzero
138func mergeSlice(dst, src reflect.Value) (res reflect.Value) {
139	for i := 0; i < src.Len(); i++ {
140		if i >= dst.Len() {
141			dst = reflect.Append(dst, src.Index(i))
142		}
143		if err := merge(dst.Index(i), src.Index(i)); err != nil {
144			res = dst
145			return
146		}
147	}
148
149	res = dst
150	return
151}
152
153// mergeMap traverses a map and merges the nonzero values of
154// src into dst
155func mergeMap(dst, src reflect.Value) (res reflect.Value) {
156	if dst.IsNil() {
157		dst = reflect.MakeMap(dst.Type())
158	}
159
160	for _, k := range src.MapKeys() {
161		vs := src.MapIndex(k)
162		vd := dst.MapIndex(k)
163		if !vd.IsValid() && isZero(vd) && !isZero(vs) {
164			dst.SetMapIndex(k, vs)
165		}
166	}
167
168	return dst
169}
170
// typesMatch reports whether a and b have the same type name after one
// leading "*" is stripped from each, so T and *T match. Only names are
// compared: distinct types that print the same name also match, and
// passing an untyped nil panics.
func typesMatch(a, b interface{}) bool {
	baseName := func(v interface{}) string {
		return strings.TrimPrefix(reflect.TypeOf(v).String(), "*")
	}
	return baseName(a) == baseName(b)
}
175
176// getValue returns a reflect.Value from an interface
177// deferring pointers if needed
178func getValue(t interface{}) (rslt reflect.Value) {
179	rslt = reflect.ValueOf(t)
180
181	for rslt.Kind() == reflect.Ptr && !rslt.IsNil() {
182		rslt = rslt.Elem()
183	}
184
185	return
186}
187
// isStructPtr reports whether v is a pointer to a struct. The pointer itself
// may be nil (e.g. (*T)(nil) reports true); an untyped nil reports false.
func isStructPtr(v interface{}) bool {
	t := reflect.TypeOf(v)
	// reflect.TypeOf(nil) returns a nil Type, whose methods would panic.
	return t != nil && t.Kind() == reflect.Ptr && t.Elem().Kind() == reflect.Struct
}
193
194// isZero is mostly stolen from encoding/json package's isEmptyValue function
195// determines if a value has the zero value of its type
196func isZero(v reflect.Value) bool {
197	switch v.Kind() {
198	case reflect.Array, reflect.Map, reflect.Slice, reflect.String:
199		return v.Len() == 0
200	case reflect.Bool:
201		return !v.Bool()
202	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
203		return v.Int() == 0
204	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
205		return v.Uint() == 0
206	case reflect.Float32, reflect.Float64:
207		return v.Float() == 0
208	case reflect.Interface, reflect.Ptr, reflect.Func:
209		return v.IsNil()
210	case reflect.Struct:
211		zero := reflect.Zero(v.Type()).Interface()
212		return reflect.DeepEqual(v.Interface(), zero)
213	default:
214		if !v.IsValid() {
215			return true
216		}
217
218		zero := reflect.Zero(v.Type())
219		return v.Interface() == zero.Interface()
220	}
221
222}
223