1package copystructure
2
3import (
4	"errors"
5	"reflect"
6	"sync"
7
8	"github.com/mitchellh/reflectwalk"
9)
10
11// Copy returns a deep copy of v.
12func Copy(v interface{}) (interface{}, error) {
13	return Config{}.Copy(v)
14}
15
// CopierFunc deep copies a value of one specific type and returns the copy.
// Register implementations globally in the Copiers map, keyed by the
// reflect.Type they handle. For structs, the returned value is stored into a
// new value of the original type, so it should be assignable to that type.
type CopierFunc func(v interface{}) (interface{}, error)
19
// Copiers maps a type to the CopierFunc used to copy it. When a value of a
// registered struct type is encountered during a deep copy, its CopierFunc is
// called to produce the copy instead of copying each field.
//
// Keys are obtained with reflect.TypeOf(someValueOfTheType).
//
// Writing to this map after copies have started is unsafe. If you must modify
// it while copies may be running, guard every modification of the map and
// every call to Copy with the same mutex.
var Copiers = make(map[reflect.Type]CopierFunc)
30
// Must wraps a call returning (interface{}, error): it returns v when err is
// nil and panics with "copy error: " followed by the error text otherwise.
// Use it only in variable initializations where a copy failure should crash.
func Must(v interface{}, err error) interface{} {
	if err == nil {
		return v
	}
	panic("copy error: " + err.Error())
}
42
var (
	// errPointerRequired is returned by Config.Copy when Lock is set but the
	// value to copy is not a pointer.
	errPointerRequired = errors.New("Copy argument must be a pointer when Lock is true")
)
44
45type Config struct {
46	// Lock any types that are a sync.Locker and are not a mutex while copying.
47	// If there is an RLocker method, use that to get the sync.Locker.
48	Lock bool
49
50	// Copiers is a map of types associated with a CopierFunc. Use the global
51	// Copiers map if this is nil.
52	Copiers map[reflect.Type]CopierFunc
53}
54
55func (c Config) Copy(v interface{}) (interface{}, error) {
56	if c.Lock && reflect.ValueOf(v).Kind() != reflect.Ptr {
57		return nil, errPointerRequired
58	}
59
60	w := new(walker)
61	if c.Lock {
62		w.useLocks = true
63	}
64
65	if c.Copiers == nil {
66		c.Copiers = Copiers
67	}
68
69	err := reflectwalk.Walk(v, w)
70	if err != nil {
71		return nil, err
72	}
73
74	// Get the result. If the result is nil, then we want to turn it
75	// into a typed nil if we can.
76	result := w.Result
77	if result == nil {
78		val := reflect.ValueOf(v)
79		result = reflect.Indirect(reflect.New(val.Type())).Interface()
80	}
81
82	return result, nil
83}
84
85// Return the key used to index interfaces types we've seen. Store the number
86// of pointers in the upper 32bits, and the depth in the lower 32bits. This is
87// easy to calculate, easy to match a key with our current depth, and we don't
88// need to deal with initializing and cleaning up nested maps or slices.
89func ifaceKey(pointers, depth int) uint64 {
90	return uint64(pointers)<<32 | uint64(depth)
91}
92
93type walker struct {
94	Result interface{}
95
96	depth       int
97	ignoreDepth int
98	vals        []reflect.Value
99	cs          []reflect.Value
100
101	// This stores the number of pointers we've walked over, indexed by depth.
102	ps []int
103
104	// If an interface is indirected by a pointer, we need to know the type of
105	// interface to create when creating the new value.  Store the interface
106	// types here, indexed by both the walk depth and the number of pointers
107	// already seen at that depth. Use ifaceKey to calculate the proper uint64
108	// value.
109	ifaceTypes map[uint64]reflect.Type
110
111	// any locks we've taken, indexed by depth
112	locks []sync.Locker
113	// take locks while walking the structure
114	useLocks bool
115}
116
117func (w *walker) Enter(l reflectwalk.Location) error {
118	w.depth++
119
120	// ensure we have enough elements to index via w.depth
121	for w.depth >= len(w.locks) {
122		w.locks = append(w.locks, nil)
123	}
124
125	for len(w.ps) < w.depth+1 {
126		w.ps = append(w.ps, 0)
127	}
128
129	return nil
130}
131
132func (w *walker) Exit(l reflectwalk.Location) error {
133	locker := w.locks[w.depth]
134	w.locks[w.depth] = nil
135	if locker != nil {
136		defer locker.Unlock()
137	}
138
139	// clear out pointers and interfaces as we exit the stack
140	w.ps[w.depth] = 0
141
142	for k := range w.ifaceTypes {
143		mask := uint64(^uint32(0))
144		if k&mask == uint64(w.depth) {
145			delete(w.ifaceTypes, k)
146		}
147	}
148
149	w.depth--
150	if w.ignoreDepth > w.depth {
151		w.ignoreDepth = 0
152	}
153
154	if w.ignoring() {
155		return nil
156	}
157
158	switch l {
159	case reflectwalk.Array:
160		fallthrough
161	case reflectwalk.Map:
162		fallthrough
163	case reflectwalk.Slice:
164		w.replacePointerMaybe()
165
166		// Pop map off our container
167		w.cs = w.cs[:len(w.cs)-1]
168	case reflectwalk.MapValue:
169		// Pop off the key and value
170		mv := w.valPop()
171		mk := w.valPop()
172		m := w.cs[len(w.cs)-1]
173
174		// If mv is the zero value, SetMapIndex deletes the key form the map,
175		// or in this case never adds it. We need to create a properly typed
176		// zero value so that this key can be set.
177		if !mv.IsValid() {
178			mv = reflect.Zero(m.Elem().Type().Elem())
179		}
180		m.Elem().SetMapIndex(mk, mv)
181	case reflectwalk.ArrayElem:
182		// Pop off the value and the index and set it on the array
183		v := w.valPop()
184		i := w.valPop().Interface().(int)
185		if v.IsValid() {
186			a := w.cs[len(w.cs)-1]
187			ae := a.Elem().Index(i) // storing array as pointer on stack - so need Elem() call
188			if ae.CanSet() {
189				ae.Set(v)
190			}
191		}
192	case reflectwalk.SliceElem:
193		// Pop off the value and the index and set it on the slice
194		v := w.valPop()
195		i := w.valPop().Interface().(int)
196		if v.IsValid() {
197			s := w.cs[len(w.cs)-1]
198			se := s.Elem().Index(i)
199			if se.CanSet() {
200				se.Set(v)
201			}
202		}
203	case reflectwalk.Struct:
204		w.replacePointerMaybe()
205
206		// Remove the struct from the container stack
207		w.cs = w.cs[:len(w.cs)-1]
208	case reflectwalk.StructField:
209		// Pop off the value and the field
210		v := w.valPop()
211		f := w.valPop().Interface().(reflect.StructField)
212		if v.IsValid() {
213			s := w.cs[len(w.cs)-1]
214			sf := reflect.Indirect(s).FieldByName(f.Name)
215
216			if sf.CanSet() {
217				sf.Set(v)
218			}
219		}
220	case reflectwalk.WalkLoc:
221		// Clear out the slices for GC
222		w.cs = nil
223		w.vals = nil
224	}
225
226	return nil
227}
228
229func (w *walker) Map(m reflect.Value) error {
230	if w.ignoring() {
231		return nil
232	}
233	w.lock(m)
234
235	// Create the map. If the map itself is nil, then just make a nil map
236	var newMap reflect.Value
237	if m.IsNil() {
238		newMap = reflect.New(m.Type())
239	} else {
240		newMap = wrapPtr(reflect.MakeMap(m.Type()))
241	}
242
243	w.cs = append(w.cs, newMap)
244	w.valPush(newMap)
245	return nil
246}
247
248func (w *walker) MapElem(m, k, v reflect.Value) error {
249	return nil
250}
251
252func (w *walker) PointerEnter(v bool) error {
253	if v {
254		w.ps[w.depth]++
255	}
256	return nil
257}
258
259func (w *walker) PointerExit(v bool) error {
260	if v {
261		w.ps[w.depth]--
262	}
263	return nil
264}
265
266func (w *walker) Interface(v reflect.Value) error {
267	if !v.IsValid() {
268		return nil
269	}
270	if w.ifaceTypes == nil {
271		w.ifaceTypes = make(map[uint64]reflect.Type)
272	}
273
274	w.ifaceTypes[ifaceKey(w.ps[w.depth], w.depth)] = v.Type()
275	return nil
276}
277
278func (w *walker) Primitive(v reflect.Value) error {
279	if w.ignoring() {
280		return nil
281	}
282	w.lock(v)
283
284	// IsValid verifies the v is non-zero and CanInterface verifies
285	// that we're allowed to read this value (unexported fields).
286	var newV reflect.Value
287	if v.IsValid() && v.CanInterface() {
288		newV = reflect.New(v.Type())
289		newV.Elem().Set(v)
290	}
291
292	w.valPush(newV)
293	w.replacePointerMaybe()
294	return nil
295}
296
297func (w *walker) Slice(s reflect.Value) error {
298	if w.ignoring() {
299		return nil
300	}
301	w.lock(s)
302
303	var newS reflect.Value
304	if s.IsNil() {
305		newS = reflect.New(s.Type())
306	} else {
307		newS = wrapPtr(reflect.MakeSlice(s.Type(), s.Len(), s.Cap()))
308	}
309
310	w.cs = append(w.cs, newS)
311	w.valPush(newS)
312	return nil
313}
314
315func (w *walker) SliceElem(i int, elem reflect.Value) error {
316	if w.ignoring() {
317		return nil
318	}
319
320	// We don't write the slice here because elem might still be
321	// arbitrarily complex. Just record the index and continue on.
322	w.valPush(reflect.ValueOf(i))
323
324	return nil
325}
326
327func (w *walker) Array(a reflect.Value) error {
328	if w.ignoring() {
329		return nil
330	}
331	w.lock(a)
332
333	newA := reflect.New(a.Type())
334
335	w.cs = append(w.cs, newA)
336	w.valPush(newA)
337	return nil
338}
339
340func (w *walker) ArrayElem(i int, elem reflect.Value) error {
341	if w.ignoring() {
342		return nil
343	}
344
345	// We don't write the array here because elem might still be
346	// arbitrarily complex. Just record the index and continue on.
347	w.valPush(reflect.ValueOf(i))
348
349	return nil
350}
351
352func (w *walker) Struct(s reflect.Value) error {
353	if w.ignoring() {
354		return nil
355	}
356	w.lock(s)
357
358	var v reflect.Value
359	if c, ok := Copiers[s.Type()]; ok {
360		// We have a Copier for this struct, so we use that copier to
361		// get the copy, and we ignore anything deeper than this.
362		w.ignoreDepth = w.depth
363
364		dup, err := c(s.Interface())
365		if err != nil {
366			return err
367		}
368
369		// We need to put a pointer to the value on the value stack,
370		// so allocate a new pointer and set it.
371		v = reflect.New(s.Type())
372		reflect.Indirect(v).Set(reflect.ValueOf(dup))
373	} else {
374		// No copier, we copy ourselves and allow reflectwalk to guide
375		// us deeper into the structure for copying.
376		v = reflect.New(s.Type())
377	}
378
379	// Push the value onto the value stack for setting the struct field,
380	// and add the struct itself to the containers stack in case we walk
381	// deeper so that its own fields can be modified.
382	w.valPush(v)
383	w.cs = append(w.cs, v)
384
385	return nil
386}
387
// StructField records field f so Exit(StructField) can assign the copied
// value to it. Unexported fields are skipped entirely.
func (w *walker) StructField(f reflect.StructField, v reflect.Value) error {
	switch {
	case w.ignoring():
		return nil
	case f.PkgPath != "":
		// A non-empty PkgPath marks an unexported field, which the Go
		// runtime does not let us set via reflection.
		return reflectwalk.SkipEntry
	}

	// The field descriptor is consumed when this location is exited.
	w.valPush(reflect.ValueOf(f))
	return nil
}
404
405// ignore causes the walker to ignore any more values until we exit this on
406func (w *walker) ignore() {
407	w.ignoreDepth = w.depth
408}
409
410func (w *walker) ignoring() bool {
411	return w.ignoreDepth > 0 && w.depth >= w.ignoreDepth
412}
413
414func (w *walker) pointerPeek() bool {
415	return w.ps[w.depth] > 0
416}
417
// valPop removes and returns the top of the value stack. When that empties
// the stack, Result is cleared so the next pushed value becomes the result.
func (w *walker) valPop() reflect.Value {
	last := len(w.vals) - 1
	top := w.vals[last]
	w.vals = w.vals[:last]

	if last == 0 {
		w.Result = nil
	}

	return top
}
431
// valPush pushes v onto the value stack. If Result is still unset and v is
// valid, v becomes the Result, since it is the outermost value seen so far.
func (w *walker) valPush(v reflect.Value) {
	w.vals = append(w.vals, v)

	if w.Result != nil || !v.IsValid() {
		return
	}
	w.Result = v.Interface()
}
441
// replacePointerMaybe pops the value on top of the value stack and pushes back
// the form its location expects. Values are built as pointers (*T); when no
// pointer was walked over at the current depth the value is dereferenced to T.
// Otherwise the pointer is kept, converted to a recorded interface type where
// w.ifaceTypes says one was seen, and re-wrapped once for every additional
// pointer level counted in w.ps for this depth.
func (w *walker) replacePointerMaybe() {
	// Determine the last pointer value. If it is NOT a pointer, then
	// we need to push that onto the stack.
	if !w.pointerPeek() {
		w.valPush(reflect.Indirect(w.valPop()))
		return
	}

	v := w.valPop()

	// If the expected type is a pointer to an interface of any depth,
	// such as *interface{}, **interface{}, etc., then we need to convert
	// the value "v" from *CONCRETE to *interface{} so types match for
	// Set.
	//
	// Example if v is type *Foo where Foo is a struct, v would become
	// *interface{} instead. This only happens if we have an interface expectation
	// at this depth.
	//
	// For more info, see GH-16
	if iType, ok := w.ifaceTypes[ifaceKey(w.ps[w.depth], w.depth)]; ok && iType.Kind() == reflect.Interface {
		y := reflect.New(iType)           // Create *interface{}
		y.Elem().Set(reflect.Indirect(v)) // Assign "Foo" to interface{} (dereferenced)
		v = y                             // v is now typed *interface{} (where *v = Foo)
	}

	// Add one pointer level per extra pointer walked over at this depth.
	// Before wrapping, convert v into the interface type recorded for that
	// pointer count, if any, so the outer pointer has the expected element type.
	for i := 1; i < w.ps[w.depth]; i++ {
		if iType, ok := w.ifaceTypes[ifaceKey(w.ps[w.depth]-i, w.depth)]; ok {
			iface := reflect.New(iType).Elem()
			iface.Set(v)
			v = iface
		}

		p := reflect.New(v.Type())
		p.Elem().Set(v)
		v = p
	}

	w.valPush(v)
}
482
// lock takes the lock for v when locking is enabled and v (or a pointer to
// it) is a sync.Locker, preferring the Locker returned by an RLocker method.
// Bare *sync.Mutex and *sync.RWMutex values are never locked directly. The
// lock taken is remembered at the current depth so Exit can release it.
func (w *walker) lock(v reflect.Value) {
	if !w.useLocks || !v.IsValid() || !v.CanInterface() {
		return
	}

	type rlocker interface {
		RLocker() sync.Locker
	}

	// lockerOf picks the Locker for x: RLocker() when available (except on
	// a bare RWMutex), else x itself if it is a sync.Locker, else nil.
	lockerOf := func(x interface{}) sync.Locker {
		switch l := x.(type) {
		case rlocker:
			if _, isRWMutex := l.(*sync.RWMutex); isRWMutex {
				return nil
			}
			return l.RLocker()
		case sync.Locker:
			return l
		}
		return nil
	}

	// Interface() on a non-pointer would copy the value, so use its address
	// instead; a pointer to a Locker value is itself a Locker.
	var locker sync.Locker
	switch {
	case v.Kind() == reflect.Ptr:
		locker = lockerOf(v.Interface())
	case v.CanAddr():
		locker = lockerOf(v.Addr().Interface())
	}

	if locker == nil {
		return
	}

	// don't lock a mutex directly
	switch locker.(type) {
	case *sync.Mutex, *sync.RWMutex:
		return
	}

	locker.Lock()
	w.locks[w.depth] = locker
}
538
// wrapPtr returns a newly allocated *T holding a copy of v (of type T). An
// invalid v is returned unchanged. copystructure keeps values behind pointers
// internally until the last moment, when they are unwrapped.
func wrapPtr(v reflect.Value) reflect.Value {
	if v.IsValid() {
		ptr := reflect.New(v.Type())
		ptr.Elem().Set(v)
		return ptr
	}
	return v
}
549