1package copier
2
3import (
4	"database/sql"
5	"database/sql/driver"
6	"fmt"
7	"reflect"
8	"strings"
9)
10
11// These flags define options for tag handling
12const (
13	// Denotes that a destination field must be copied to. If copying fails then a panic will ensue.
14	tagMust uint8 = 1 << iota
15
16	// Denotes that the program should not panic when the must flag is on and
17	// value is not copied. The program will return an error instead.
18	tagNoPanic
19
20	// Ignore a destination field from being copied to.
21	tagIgnore
22
23	// Denotes that the value as been copied
24	hasCopied
25)
26
27// Option sets copy options
28type Option struct {
29	// setting this value to true will ignore copying zero values of all the fields, including bools, as well as a
30	// struct having all it's fields set to their zero values respectively (see IsZero() in reflect/value.go)
31	IgnoreEmpty bool
32	DeepCopy    bool
33}
34
35// Copy copy things
36func Copy(toValue interface{}, fromValue interface{}) (err error) {
37	return copier(toValue, fromValue, Option{})
38}
39
40// CopyWithOption copy with option
41func CopyWithOption(toValue interface{}, fromValue interface{}, opt Option) (err error) {
42	return copier(toValue, fromValue, opt)
43}
44
45func copier(toValue interface{}, fromValue interface{}, opt Option) (err error) {
46	var (
47		isSlice bool
48		amount  = 1
49		from    = indirect(reflect.ValueOf(fromValue))
50		to      = indirect(reflect.ValueOf(toValue))
51	)
52
53	if !to.CanAddr() {
54		return ErrInvalidCopyDestination
55	}
56
57	// Return is from value is invalid
58	if !from.IsValid() {
59		return ErrInvalidCopyFrom
60	}
61
62	fromType, isPtrFrom := indirectType(from.Type())
63	toType, _ := indirectType(to.Type())
64
65	if fromType.Kind() == reflect.Interface {
66		fromType = reflect.TypeOf(from.Interface())
67	}
68
69	if toType.Kind() == reflect.Interface {
70		toType, _ = indirectType(reflect.TypeOf(to.Interface()))
71		oldTo := to
72		to = reflect.New(reflect.TypeOf(to.Interface())).Elem()
73		defer func() {
74			oldTo.Set(to)
75		}()
76	}
77
78	// Just set it if possible to assign for normal types
79	if from.Kind() != reflect.Slice && from.Kind() != reflect.Struct && from.Kind() != reflect.Map && (from.Type().AssignableTo(to.Type()) || from.Type().ConvertibleTo(to.Type())) {
80		if !isPtrFrom || !opt.DeepCopy {
81			to.Set(from.Convert(to.Type()))
82		} else {
83			fromCopy := reflect.New(from.Type())
84			fromCopy.Set(from.Elem())
85			to.Set(fromCopy.Convert(to.Type()))
86		}
87		return
88	}
89
90	if fromType.Kind() == reflect.Map && toType.Kind() == reflect.Map {
91		if !fromType.Key().ConvertibleTo(toType.Key()) {
92			return ErrMapKeyNotMatch
93		}
94
95		if to.IsNil() {
96			to.Set(reflect.MakeMapWithSize(toType, from.Len()))
97		}
98
99		for _, k := range from.MapKeys() {
100			toKey := indirect(reflect.New(toType.Key()))
101			if !set(toKey, k, opt.DeepCopy) {
102				return fmt.Errorf("%w map, old key: %v, new key: %v", ErrNotSupported, k.Type(), toType.Key())
103			}
104
105			elemType, _ := indirectType(toType.Elem())
106			toValue := indirect(reflect.New(elemType))
107			if !set(toValue, from.MapIndex(k), opt.DeepCopy) {
108				if err = copier(toValue.Addr().Interface(), from.MapIndex(k).Interface(), opt); err != nil {
109					return err
110				}
111			}
112
113			for {
114				if elemType == toType.Elem() {
115					to.SetMapIndex(toKey, toValue)
116					break
117				}
118				elemType = reflect.PtrTo(elemType)
119				toValue = toValue.Addr()
120			}
121		}
122		return
123	}
124
125	if from.Kind() == reflect.Slice && to.Kind() == reflect.Slice && fromType.ConvertibleTo(toType) {
126		if to.IsNil() {
127			slice := reflect.MakeSlice(reflect.SliceOf(to.Type().Elem()), from.Len(), from.Cap())
128			to.Set(slice)
129		}
130
131		for i := 0; i < from.Len(); i++ {
132			if to.Len() < i+1 {
133				to.Set(reflect.Append(to, reflect.New(to.Type().Elem()).Elem()))
134			}
135
136			if !set(to.Index(i), from.Index(i), opt.DeepCopy) {
137				err = CopyWithOption(to.Index(i).Addr().Interface(), from.Index(i).Interface(), opt)
138				if err != nil {
139					continue
140				}
141			}
142		}
143		return
144	}
145
146	if fromType.Kind() != reflect.Struct || toType.Kind() != reflect.Struct {
147		// skip not supported type
148		return
149	}
150
151	if to.Kind() == reflect.Slice {
152		isSlice = true
153		if from.Kind() == reflect.Slice {
154			amount = from.Len()
155		}
156	}
157
158	for i := 0; i < amount; i++ {
159		var dest, source reflect.Value
160
161		if isSlice {
162			// source
163			if from.Kind() == reflect.Slice {
164				source = indirect(from.Index(i))
165			} else {
166				source = indirect(from)
167			}
168			// dest
169			dest = indirect(reflect.New(toType).Elem())
170		} else {
171			source = indirect(from)
172			dest = indirect(to)
173		}
174
175		destKind := dest.Kind()
176		initDest := false
177		if destKind == reflect.Interface {
178			initDest = true
179			dest = indirect(reflect.New(toType))
180		}
181
182		// Get tag options
183		tagBitFlags := map[string]uint8{}
184		if dest.IsValid() {
185			tagBitFlags = getBitFlags(toType)
186		}
187
188		// check source
189		if source.IsValid() {
190			// Copy from source field to dest field or method
191			fromTypeFields := deepFields(fromType)
192			for _, field := range fromTypeFields {
193				name := field.Name
194
195				// Get bit flags for field
196				fieldFlags, _ := tagBitFlags[name]
197
198				// Check if we should ignore copying
199				if (fieldFlags & tagIgnore) != 0 {
200					continue
201				}
202
203				if fromField := source.FieldByName(name); fromField.IsValid() && !shouldIgnore(fromField, opt.IgnoreEmpty) {
204					// process for nested anonymous field
205					destFieldNotSet := false
206					if f, ok := dest.Type().FieldByName(name); ok {
207						for idx := range f.Index {
208							destField := dest.FieldByIndex(f.Index[:idx+1])
209
210							if destField.Kind() != reflect.Ptr {
211								continue
212							}
213
214							if !destField.IsNil() {
215								continue
216							}
217							if !destField.CanSet() {
218								destFieldNotSet = true
219								break
220							}
221
222							// destField is a nil pointer that can be set
223							newValue := reflect.New(destField.Type().Elem())
224							destField.Set(newValue)
225						}
226					}
227
228					if destFieldNotSet {
229						break
230					}
231
232					toField := dest.FieldByName(name)
233					if toField.IsValid() {
234						if toField.CanSet() {
235							if !set(toField, fromField, opt.DeepCopy) {
236								if err := copier(toField.Addr().Interface(), fromField.Interface(), opt); err != nil {
237									return err
238								}
239							}
240							if fieldFlags != 0 {
241								// Note that a copy was made
242								tagBitFlags[name] = fieldFlags | hasCopied
243							}
244						}
245					} else {
246						// try to set to method
247						var toMethod reflect.Value
248						if dest.CanAddr() {
249							toMethod = dest.Addr().MethodByName(name)
250						} else {
251							toMethod = dest.MethodByName(name)
252						}
253
254						if toMethod.IsValid() && toMethod.Type().NumIn() == 1 && fromField.Type().AssignableTo(toMethod.Type().In(0)) {
255							toMethod.Call([]reflect.Value{fromField})
256						}
257					}
258				}
259			}
260
261			// Copy from from method to dest field
262			for _, field := range deepFields(toType) {
263				name := field.Name
264
265				var fromMethod reflect.Value
266				if source.CanAddr() {
267					fromMethod = source.Addr().MethodByName(name)
268				} else {
269					fromMethod = source.MethodByName(name)
270				}
271
272				if fromMethod.IsValid() && fromMethod.Type().NumIn() == 0 && fromMethod.Type().NumOut() == 1 && !shouldIgnore(fromMethod, opt.IgnoreEmpty) {
273					if toField := dest.FieldByName(name); toField.IsValid() && toField.CanSet() {
274						values := fromMethod.Call([]reflect.Value{})
275						if len(values) >= 1 {
276							set(toField, values[0], opt.DeepCopy)
277						}
278					}
279				}
280			}
281		}
282
283		if isSlice {
284			if dest.Addr().Type().AssignableTo(to.Type().Elem()) {
285				to.Set(reflect.Append(to, dest.Addr()))
286			} else if dest.Type().AssignableTo(to.Type().Elem()) {
287				to.Set(reflect.Append(to, dest))
288			}
289		} else if initDest {
290			to.Set(dest)
291		}
292
293		err = checkBitFlags(tagBitFlags)
294	}
295
296	return
297}
298
299func shouldIgnore(v reflect.Value, ignoreEmpty bool) bool {
300	if !ignoreEmpty {
301		return false
302	}
303
304	return v.IsZero()
305}
306
307func deepFields(reflectType reflect.Type) []reflect.StructField {
308	if reflectType, _ = indirectType(reflectType); reflectType.Kind() == reflect.Struct {
309		fields := make([]reflect.StructField, 0, reflectType.NumField())
310
311		for i := 0; i < reflectType.NumField(); i++ {
312			v := reflectType.Field(i)
313			if v.Anonymous {
314				fields = append(fields, deepFields(v.Type)...)
315			} else {
316				fields = append(fields, v)
317			}
318		}
319
320		return fields
321	}
322
323	return nil
324}
325
326func indirect(reflectValue reflect.Value) reflect.Value {
327	for reflectValue.Kind() == reflect.Ptr {
328		reflectValue = reflectValue.Elem()
329	}
330	return reflectValue
331}
332
333func indirectType(reflectType reflect.Type) (_ reflect.Type, isPtr bool) {
334	for reflectType.Kind() == reflect.Ptr || reflectType.Kind() == reflect.Slice {
335		reflectType = reflectType.Elem()
336		isPtr = true
337	}
338	return reflectType, isPtr
339}
340
341func set(to, from reflect.Value, deepCopy bool) bool {
342	if from.IsValid() {
343		if to.Kind() == reflect.Ptr {
344			// set `to` to nil if from is nil
345			if from.Kind() == reflect.Ptr && from.IsNil() {
346				to.Set(reflect.Zero(to.Type()))
347				return true
348			} else if to.IsNil() {
349				// `from`         -> `to`
350				// sql.NullString -> *string
351				if fromValuer, ok := driverValuer(from); ok {
352					v, err := fromValuer.Value()
353					if err != nil {
354						return false
355					}
356					// if `from` is not valid do nothing with `to`
357					if v == nil {
358						return true
359					}
360				}
361				// allocate new `to` variable with default value (eg. *string -> new(string))
362				to.Set(reflect.New(to.Type().Elem()))
363			}
364			// depointer `to`
365			to = to.Elem()
366		}
367
368		if deepCopy {
369			toKind := to.Kind()
370			if toKind == reflect.Interface && to.IsNil() {
371				to.Set(reflect.New(reflect.TypeOf(from.Interface())).Elem())
372				toKind = reflect.TypeOf(to.Interface()).Kind()
373			}
374			if toKind == reflect.Struct || toKind == reflect.Map || toKind == reflect.Slice {
375				return false
376			}
377		}
378
379		if from.Type().ConvertibleTo(to.Type()) {
380			to.Set(from.Convert(to.Type()))
381		} else if toScanner, ok := to.Addr().Interface().(sql.Scanner); ok {
382			// `from`  -> `to`
383			// *string -> sql.NullString
384			if from.Kind() == reflect.Ptr {
385				// if `from` is nil do nothing with `to`
386				if from.IsNil() {
387					return true
388				}
389				// depointer `from`
390				from = indirect(from)
391			}
392			// `from` -> `to`
393			// string -> sql.NullString
394			// set `to` by invoking method Scan(`from`)
395			err := toScanner.Scan(from.Interface())
396			if err != nil {
397				return false
398			}
399		} else if fromValuer, ok := driverValuer(from); ok {
400			// `from`         -> `to`
401			// sql.NullString -> string
402			v, err := fromValuer.Value()
403			if err != nil {
404				return false
405			}
406			// if `from` is not valid do nothing with `to`
407			if v == nil {
408				return true
409			}
410			rv := reflect.ValueOf(v)
411			if rv.Type().AssignableTo(to.Type()) {
412				to.Set(rv)
413			}
414		} else if from.Kind() == reflect.Ptr {
415			return set(to, from.Elem(), deepCopy)
416		} else {
417			return false
418		}
419	}
420
421	return true
422}
423
424// parseTags Parses struct tags and returns uint8 bit flags.
425func parseTags(tag string) (flags uint8) {
426	for _, t := range strings.Split(tag, ",") {
427		switch t {
428		case "-":
429			flags = tagIgnore
430			return
431		case "must":
432			flags = flags | tagMust
433		case "nopanic":
434			flags = flags | tagNoPanic
435		}
436	}
437	return
438}
439
440// getBitFlags Parses struct tags for bit flags.
441func getBitFlags(toType reflect.Type) map[string]uint8 {
442	flags := map[string]uint8{}
443	toTypeFields := deepFields(toType)
444
445	// Get a list dest of tags
446	for _, field := range toTypeFields {
447		tags := field.Tag.Get("copier")
448		if tags != "" {
449			flags[field.Name] = parseTags(tags)
450		}
451	}
452	return flags
453}
454
455// checkBitFlags Checks flags for error or panic conditions.
456func checkBitFlags(flagsList map[string]uint8) (err error) {
457	// Check flag conditions were met
458	for name, flags := range flagsList {
459		if flags&hasCopied == 0 {
460			switch {
461			case flags&tagMust != 0 && flags&tagNoPanic != 0:
462				err = fmt.Errorf("field %s has must tag but was not copied", name)
463				return
464			case flags&(tagMust) != 0:
465				panic(fmt.Sprintf("Field %s has must tag but was not copied", name))
466			}
467		}
468	}
469	return
470}
471
472func driverValuer(v reflect.Value) (i driver.Valuer, ok bool) {
473
474	if !v.CanAddr() {
475		i, ok = v.Interface().(driver.Valuer)
476		return
477	}
478
479	i, ok = v.Addr().Interface().(driver.Valuer)
480	return
481}
482