1package copystructure
2
3import (
4	"reflect"
5
6	"github.com/mitchellh/reflectwalk"
7)
8
9// Copy returns a deep copy of v.
10func Copy(v interface{}) (interface{}, error) {
11	w := new(walker)
12	err := reflectwalk.Walk(v, w)
13	if err != nil {
14		return nil, err
15	}
16
17	// Get the result. If the result is nil, then we want to turn it
18	// into a typed nil if we can.
19	result := w.Result
20	if result == nil {
21		val := reflect.ValueOf(v)
22		result = reflect.Indirect(reflect.New(val.Type())).Interface()
23	}
24
25	return result, nil
26}
27
// CopierFunc deep-copies a value of one specific type. It receives the
// original value and returns the copy, or an error if copying fails.
// Implementations are registered globally in Copiers, keyed by the type
// they handle.
type CopierFunc func(interface{}) (interface{}, error)
31
32// Copiers is a map of types that behave specially when they are copied.
33// If a type is found in this map while deep copying, this function
34// will be called to copy it instead of attempting to copy all fields.
35//
36// The key should be the type, obtained using: reflect.TypeOf(value with type).
37//
38// It is unsafe to write to this map after Copies have started. If you
39// are writing to this map while also copying, wrap all modifications to
40// this map as well as to Copy in a mutex.
41var Copiers map[reflect.Type]CopierFunc = make(map[reflect.Type]CopierFunc)
42
// walker receives the reflectwalk callbacks and builds the copy of the
// walked value incrementally, using explicit stacks for the pieces in
// progress.
type walker struct {
	// Result is the outermost value pushed onto vals; once the walk
	// finishes it holds the copy.
	Result interface{}

	depth       int             // current Enter/Exit nesting level
	ignoreDepth int             // when > 0, callbacks at or deeper than this level are skipped
	vals        []reflect.Value // stack of values (and indices/field descriptors) being assembled
	cs          []reflect.Value // stack of containers (maps, slices, structs) being filled
	ps          []bool          // stack of flags pushed by PointerEnter
}
52
53func (w *walker) Enter(l reflectwalk.Location) error {
54	w.depth++
55	return nil
56}
57
58func (w *walker) Exit(l reflectwalk.Location) error {
59	w.depth--
60	if w.ignoreDepth > w.depth {
61		w.ignoreDepth = 0
62	}
63
64	if w.ignoring() {
65		return nil
66	}
67
68	switch l {
69	case reflectwalk.Map:
70		fallthrough
71	case reflectwalk.Slice:
72		// Pop map off our container
73		w.cs = w.cs[:len(w.cs)-1]
74	case reflectwalk.MapValue:
75		// Pop off the key and value
76		mv := w.valPop()
77		mk := w.valPop()
78		m := w.cs[len(w.cs)-1]
79		m.SetMapIndex(mk, mv)
80	case reflectwalk.SliceElem:
81		// Pop off the value and the index and set it on the slice
82		v := w.valPop()
83		i := w.valPop().Interface().(int)
84		s := w.cs[len(w.cs)-1]
85		s.Index(i).Set(v)
86	case reflectwalk.Struct:
87		w.replacePointerMaybe()
88
89		// Remove the struct from the container stack
90		w.cs = w.cs[:len(w.cs)-1]
91	case reflectwalk.StructField:
92		// Pop off the value and the field
93		v := w.valPop()
94		f := w.valPop().Interface().(reflect.StructField)
95		if v.IsValid() {
96			s := w.cs[len(w.cs)-1]
97			sf := reflect.Indirect(s).FieldByName(f.Name)
98			sf.Set(v)
99		}
100	case reflectwalk.WalkLoc:
101		// Clear out the slices for GC
102		w.cs = nil
103		w.vals = nil
104	}
105
106	return nil
107}
108
109func (w *walker) Map(m reflect.Value) error {
110	if w.ignoring() {
111		return nil
112	}
113
114	// Get the type for the map
115	t := m.Type()
116	mapType := reflect.MapOf(t.Key(), t.Elem())
117
118	// Create the map. If the map itself is nil, then just make a nil map
119	var newMap reflect.Value
120	if m.IsNil() {
121		newMap = reflect.Indirect(reflect.New(mapType))
122	} else {
123		newMap = reflect.MakeMap(reflect.MapOf(t.Key(), t.Elem()))
124	}
125
126	w.cs = append(w.cs, newMap)
127	w.valPush(newMap)
128	return nil
129}
130
131func (w *walker) MapElem(m, k, v reflect.Value) error {
132	return nil
133}
134
135func (w *walker) PointerEnter(v bool) error {
136	if w.ignoring() {
137		return nil
138	}
139
140	w.ps = append(w.ps, v)
141	return nil
142}
143
144func (w *walker) PointerExit(bool) error {
145	if w.ignoring() {
146		return nil
147	}
148
149	w.ps = w.ps[:len(w.ps)-1]
150	return nil
151}
152
153func (w *walker) Primitive(v reflect.Value) error {
154	if w.ignoring() {
155		return nil
156	}
157
158	// IsValid verifies the v is non-zero and CanInterface verifies
159	// that we're allowed to read this value (unexported fields).
160	var newV reflect.Value
161	if v.IsValid() && v.CanInterface() {
162		newV = reflect.New(v.Type())
163		reflect.Indirect(newV).Set(v)
164	}
165
166	w.valPush(newV)
167	w.replacePointerMaybe()
168	return nil
169}
170
171func (w *walker) Slice(s reflect.Value) error {
172	if w.ignoring() {
173		return nil
174	}
175
176	var newS reflect.Value
177	if s.IsNil() {
178		newS = reflect.Indirect(reflect.New(s.Type()))
179	} else {
180		newS = reflect.MakeSlice(s.Type(), s.Len(), s.Cap())
181	}
182
183	w.cs = append(w.cs, newS)
184	w.valPush(newS)
185	return nil
186}
187
188func (w *walker) SliceElem(i int, elem reflect.Value) error {
189	if w.ignoring() {
190		return nil
191	}
192
193	// We don't write the slice here because elem might still be
194	// arbitrarily complex. Just record the index and continue on.
195	w.valPush(reflect.ValueOf(i))
196
197	return nil
198}
199
200func (w *walker) Struct(s reflect.Value) error {
201	if w.ignoring() {
202		return nil
203	}
204
205	var v reflect.Value
206	if c, ok := Copiers[s.Type()]; ok {
207		// We have a Copier for this struct, so we use that copier to
208		// get the copy, and we ignore anything deeper than this.
209		w.ignoreDepth = w.depth
210
211		dup, err := c(s.Interface())
212		if err != nil {
213			return err
214		}
215
216		v = reflect.ValueOf(dup)
217	} else {
218		// No copier, we copy ourselves and allow reflectwalk to guide
219		// us deeper into the structure for copying.
220		v = reflect.New(s.Type())
221	}
222
223	// Push the value onto the value stack for setting the struct field,
224	// and add the struct itself to the containers stack in case we walk
225	// deeper so that its own fields can be modified.
226	w.valPush(v)
227	w.cs = append(w.cs, v)
228
229	return nil
230}
231
232func (w *walker) StructField(f reflect.StructField, v reflect.Value) error {
233	if w.ignoring() {
234		return nil
235	}
236
237	// Push the field onto the stack, we'll handle it when we exit
238	// the struct field in Exit...
239	w.valPush(reflect.ValueOf(f))
240	return nil
241}
242
243func (w *walker) ignoring() bool {
244	return w.ignoreDepth > 0 && w.depth >= w.ignoreDepth
245}
246
247func (w *walker) pointerPeek() bool {
248	return w.ps[len(w.ps)-1]
249}
250
251func (w *walker) valPop() reflect.Value {
252	result := w.vals[len(w.vals)-1]
253	w.vals = w.vals[:len(w.vals)-1]
254
255	// If we're out of values, that means we popped everything off. In
256	// this case, we reset the result so the next pushed value becomes
257	// the result.
258	if len(w.vals) == 0 {
259		w.Result = nil
260	}
261
262	return result
263}
264
265func (w *walker) valPush(v reflect.Value) {
266	w.vals = append(w.vals, v)
267
268	// If we haven't set the result yet, then this is the result since
269	// it is the first (outermost) value we're seeing.
270	if w.Result == nil && v.IsValid() {
271		w.Result = v.Interface()
272	}
273}
274
275func (w *walker) replacePointerMaybe() {
276	// Determine the last pointer value. If it is NOT a pointer, then
277	// we need to push that onto the stack.
278	if !w.pointerPeek() {
279		w.valPush(reflect.Indirect(w.valPop()))
280	}
281}
282