// Copyright (c) 2015 Hiram Jerónimo Pérez https://worg.xyz

// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:

// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.

// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

// Package merger is a utility to merge structs of the same type.
package merger

import (
	"errors"
	"reflect"
)

var (
	// ErrDistinctType occurs when trying to merge structs of distinct type
	ErrDistinctType = errors.New(`dst and src must be of the same type`)
	// ErrNoPtr occurs when no struct pointer is sent as destination
	ErrNoPtr = errors.New(`dst must be a pointer to a struct`)
	// ErrNilArguments occurs on receiving nil as arguments, including typed
	// nil pointers such as (*T)(nil)
	ErrNilArguments = errors.New(`no nil values allowed`)
	// ErrUnknown occurs if the type can't be merged
	ErrUnknown = errors.New(`could not merge`)
)

// mergeVisit identifies a (dst, src) pointer pair that is already being
// merged, so that cyclic pointer graphs do not recurse forever.
type mergeVisit struct {
	dst, src uintptr
	typ      reflect.Type
}

// Merge fills the zero-valued fields of dst with the non-zero values of src.
//
// dst must be a non-nil pointer to a struct; src must be a struct, or a
// pointer to one, of the same type. Fields that cannot be set (unexported)
// are skipped. Nested structs, slices, maps and pointers are merged
// recursively by merge.
//
// It returns ErrNilArguments if either argument is nil (including a typed nil
// pointer), ErrNoPtr if dst is not a struct pointer, ErrDistinctType if the
// types differ, and ErrUnknown if a field of an unsupported kind needs merging.
func Merge(dst, src interface{}) error {
	if dst == nil || src == nil {
		return ErrNilArguments
	}

	if !isStructPtr(dst) {
		return ErrNoPtr
	}

	if !typesMatch(src, dst) {
		return ErrDistinctType
	}

	vSrc := getValue(src)
	vDst := getValue(dst)

	// getValue stops at the first nil pointer, so a typed nil argument shows
	// up here as a non-struct value; reject it instead of panicking in
	// NumField below.
	if vSrc.Kind() != reflect.Struct || vDst.Kind() != reflect.Struct {
		return ErrNilArguments
	}

	seen := make(map[mergeVisit]bool)
	for i := 0; i < vSrc.NumField(); i++ {
		if err := merge(vDst.Field(i), vSrc.Field(i), seen); err != nil {
			return err
		}
	}

	return nil
}

// merge merges src into dst according to their kind. Nothing happens unless
// dst is settable and src is non-zero.
//
//   - scalars (ints, uints, floats, strings, bools) are copied only when dst
//     is still zero;
//   - slices are merged element-wise (see mergeSlice);
//   - maps are merged key-wise (see mergeMap);
//   - structs are merged field-wise (see mergeStruct);
//   - a nil dst pointer adopts src's pointer (the pointee is shared, not
//     copied); otherwise the pointed-to values are merged.
//
// Any other kind returns ErrUnknown. seen records pointer pairs already
// being followed so cyclic data terminates.
func merge(dst, src reflect.Value, seen map[mergeVisit]bool) error {
	if !dst.CanSet() || isZero(src) {
		return nil
	}

	switch dst.Kind() {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
		reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr,
		reflect.Float32, reflect.Float64, reflect.String, reflect.Bool:
		// An explicit (non-zero) value already in dst always wins.
		if isZero(dst) {
			dst.Set(src)
		}
	case reflect.Slice:
		merged, err := mergeSlice(dst, src, seen)
		if err != nil {
			return err
		}
		dst.Set(merged)
	case reflect.Map:
		dst.Set(mergeMap(dst, src))
	case reflect.Struct:
		return mergeStruct(dst, src, seen)
	case reflect.Ptr:
		if dst.IsNil() {
			dst.Set(src)
			return nil
		}
		// Both pointers are non-nil here (src passed the isZero check above).
		key := mergeVisit{dst: dst.Pointer(), src: src.Pointer(), typ: dst.Type()}
		if seen[key] {
			return nil
		}
		seen[key] = true
		// Elem of a non-nil pointer is addressable, hence settable.
		return merge(dst.Elem(), src.Elem(), seen)
	default:
		return ErrUnknown
	}
	return nil
}

// mergeStruct merges src into dst field by field. A type exposing an
// `IsZero() bool` value method (such as time.Time) is replaced wholesale by
// src when dst reports itself as zero.
func mergeStruct(dst, src reflect.Value, seen map[mergeVisit]bool) error {
	if zero, ok := zeroByMethod(dst); ok && zero {
		// dst now equals src; walking the fields again would be redundant.
		dst.Set(src)
		return nil
	}

	for i := 0; i < src.NumField(); i++ {
		if err := merge(dst.Field(i), src.Field(i), seen); err != nil {
			return err
		}
	}
	return nil
}

// zeroByMethod reports whether v's type has a method `IsZero() bool` in its
// value method set and, if so, what that method returns for v. Methods named
// IsZero with any other signature are ignored rather than panicking in Call.
func zeroByMethod(v reflect.Value) (zero, ok bool) {
	m, found := v.Type().MethodByName(`IsZero`)
	if !found {
		return false, false
	}
	// m.Type includes the receiver as its first parameter.
	if m.Type.NumIn() != 1 || m.Type.NumOut() != 1 || m.Type.Out(0).Kind() != reflect.Bool {
		return false, false
	}
	return m.Func.Call([]reflect.Value{v})[0].Bool(), true
}

// mergeSlice merges src into dst and returns the resulting slice, which the
// caller must store back into dst.
//
// Elements present in both slices are merged with merge (same rules as struct
// fields); elements beyond the end of dst are appended from src unchanged.
// The first element merge error is returned.
func mergeSlice(dst, src reflect.Value, seen map[mergeVisit]bool) (reflect.Value, error) {
	overlap := dst.Len()
	if src.Len() < overlap {
		overlap = src.Len()
	}

	for i := 0; i < overlap; i++ {
		if err := merge(dst.Index(i), src.Index(i), seen); err != nil {
			return dst, err
		}
	}

	if src.Len() > dst.Len() {
		dst = reflect.AppendSlice(dst, src.Slice(dst.Len(), src.Len()))
	}
	return dst, nil
}

// mergeMap copies entries of src into dst and returns the resulting map,
// which the caller must store back into dst. A key is copied when it is
// missing from dst or maps to a zero value there, and the src value is
// non-zero. Values are copied as-is, not merged recursively. A nil dst is
// replaced by a fresh map of the same type.
func mergeMap(dst, src reflect.Value) reflect.Value {
	if dst.IsNil() {
		dst = reflect.MakeMap(dst.Type())
	}

	for _, k := range src.MapKeys() {
		vs := src.MapIndex(k)
		vd := dst.MapIndex(k) // invalid Value when k is absent from dst
		if (!vd.IsValid() || isZero(vd)) && !isZero(vs) {
			dst.SetMapIndex(k, vs)
		}
	}

	return dst
}

// typesMatch reports whether a and b have the same type once a single level
// of pointer indirection is removed from each, so T and *T match. Types are
// compared by identity rather than by name, so distinct types that happen to
// print the same (e.g. same-named types in different scopes) do not match.
func typesMatch(a, b interface{}) bool {
	return derefType(reflect.TypeOf(a)) == derefType(reflect.TypeOf(b))
}

// derefType strips one level of pointer indirection from t, if present.
func derefType(t reflect.Type) reflect.Type {
	if t != nil && t.Kind() == reflect.Ptr {
		return t.Elem()
	}
	return t
}

// getValue returns a reflect.Value for t, dereferencing pointers until it
// reaches a non-pointer or a nil pointer (which is returned as-is).
func getValue(t interface{}) (rslt reflect.Value) {
	rslt = reflect.ValueOf(t)

	for rslt.Kind() == reflect.Ptr && !rslt.IsNil() {
		rslt = rslt.Elem()
	}

	return
}

// isStructPtr determines if a value is a struct pointer
func isStructPtr(v interface{}) bool {
	t := reflect.TypeOf(v)
	return t.Kind() == reflect.Ptr && t.Elem().Kind() == reflect.Struct
}

// isZero reports whether v holds the zero value of its type. It is adapted
// from encoding/json's isEmptyValue; note that, as there, arrays, maps,
// slices and strings count as zero only when their length is 0, and an
// invalid Value counts as zero.
func isZero(v reflect.Value) bool {
	switch v.Kind() {
	case reflect.Array, reflect.Map, reflect.Slice, reflect.String:
		return v.Len() == 0
	case reflect.Bool:
		return !v.Bool()
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return v.Int() == 0
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		return v.Uint() == 0
	case reflect.Float32, reflect.Float64:
		return v.Float() == 0
	case reflect.Interface, reflect.Ptr, reflect.Func:
		return v.IsNil()
	case reflect.Struct:
		zero := reflect.Zero(v.Type()).Interface()
		return reflect.DeepEqual(v.Interface(), zero)
	default:
		if !v.IsValid() {
			return true
		}

		zero := reflect.Zero(v.Type())
		return v.Interface() == zero.Interface()
	}
}