1// Package deep provides function deep.Equal which is like reflect.DeepEqual but
2// returns a list of differences. This is helpful when comparing complex types
3// like structures and maps.
4package deep
5
6import (
7	"errors"
8	"fmt"
9	"log"
10	"reflect"
11	"strings"
12)
13
// Tunable limits that control how Equal compares values. They are read from
// the package-level variables each time Equal runs.
var (
	// FloatPrecision is the number of decimal places that float values are
	// rounded to before being compared.
	FloatPrecision = 10

	// MaxDiff is the maximum number of differences Equal collects and returns.
	MaxDiff = 10

	// MaxDepth, when greater than zero, limits how many levels deep Equal
	// recurses. Each struct field, map value, array/slice element, and
	// pointer or interface dereference counts as one level. Zero means no
	// limit.
	MaxDepth int
)

// Behavior switches for Equal; both are off by default.
var (
	// LogErrors, when true, causes comparison errors to be logged to STDERR.
	LogErrors bool

	// CompareUnexportedFields, when true, makes Equal also compare unexported
	// struct fields, like s in T{s int}.
	CompareUnexportedFields bool
)
33
// Sentinel errors describing problems hit during a comparison. They are
// logged (when LogErrors is true) rather than returned from Equal.
var (
	// ErrMaxRecursion is logged when the comparison reaches MaxDepth.
	ErrMaxRecursion = errors.New("recursed to MaxDepth")

	// ErrTypeMismatch is logged when two values being compared have
	// different reflect.Types.
	ErrTypeMismatch = errors.New("variables are different reflect.Type")

	// ErrNotHandled is logged when a value's reflect.Kind is not supported by
	// the comparison.
	ErrNotHandled = errors.New("cannot compare the reflect.Kind")
)
44
// cmp carries the mutable state of a single Equal call.
type cmp struct {
	// diff accumulates one formatted message per difference found.
	diff []string

	// buff is the stack of path components (field names, map keys,
	// array/slice indexes) leading to the value currently being compared.
	buff []string

	// floatFormat is the fmt verb, e.g. "%.10f", used to round floats
	// before they are compared.
	floatFormat string
}
50
// errorType is the reflect.Type of the built-in error interface. It is used to
// detect values whose type implements error.
var errorType = reflect.TypeOf(new(error)).Elem()
52
53// Equal compares variables a and b, recursing into their structure up to
54// MaxDepth levels deep (if greater than zero), and returns a list of differences,
55// or nil if there are none. Some differences may not be found if an error is
56// also returned.
57//
58// If a type has an Equal method, like time.Equal, it is called to check for
59// equality.
60func Equal(a, b interface{}) []string {
61	aVal := reflect.ValueOf(a)
62	bVal := reflect.ValueOf(b)
63	c := &cmp{
64		diff:        []string{},
65		buff:        []string{},
66		floatFormat: fmt.Sprintf("%%.%df", FloatPrecision),
67	}
68	if a == nil && b == nil {
69		return nil
70	} else if a == nil && b != nil {
71		c.saveDiff("<nil pointer>", b)
72	} else if a != nil && b == nil {
73		c.saveDiff(a, "<nil pointer>")
74	}
75	if len(c.diff) > 0 {
76		return c.diff
77	}
78
79	c.equals(aVal, bVal, 0)
80	if len(c.diff) > 0 {
81		return c.diff // diffs
82	}
83	return nil // no diffs
84}
85
86func (c *cmp) equals(a, b reflect.Value, level int) {
87	if MaxDepth > 0 && level > MaxDepth {
88		logError(ErrMaxRecursion)
89		return
90	}
91
92	// Check if one value is nil, e.g. T{x: *X} and T.x is nil
93	if !a.IsValid() || !b.IsValid() {
94		if a.IsValid() && !b.IsValid() {
95			c.saveDiff(a.Type(), "<nil pointer>")
96		} else if !a.IsValid() && b.IsValid() {
97			c.saveDiff("<nil pointer>", b.Type())
98		}
99		return
100	}
101
102	// If differenet types, they can't be equal
103	aType := a.Type()
104	bType := b.Type()
105	if aType != bType {
106		c.saveDiff(aType, bType)
107		logError(ErrTypeMismatch)
108		return
109	}
110
111	// Primitive https://golang.org/pkg/reflect/#Kind
112	aKind := a.Kind()
113	bKind := b.Kind()
114
115	// If both types implement the error interface, compare the error strings.
116	// This must be done before dereferencing because the interface is on a
117	// pointer receiver.
118	if aType.Implements(errorType) && bType.Implements(errorType) {
119		if a.Elem().IsValid() && b.Elem().IsValid() { // both err != nil
120			aString := a.MethodByName("Error").Call(nil)[0].String()
121			bString := b.MethodByName("Error").Call(nil)[0].String()
122			if aString != bString {
123				c.saveDiff(aString, bString)
124				return
125			}
126		}
127	}
128
129	// Dereference pointers and interface{}
130	if aElem, bElem := aKind == reflect.Ptr || aKind == reflect.Interface,
131		bKind == reflect.Ptr || bKind == reflect.Interface; aElem || bElem {
132
133		if aElem {
134			a = a.Elem()
135		}
136
137		if bElem {
138			b = b.Elem()
139		}
140
141		c.equals(a, b, level+1)
142		return
143	}
144
145	switch aKind {
146
147	/////////////////////////////////////////////////////////////////////
148	// Iterable kinds
149	/////////////////////////////////////////////////////////////////////
150
151	case reflect.Struct:
152		/*
153			The variables are structs like:
154				type T struct {
155					FirstName string
156					LastName  string
157				}
158			Type = <pkg>.T, Kind = reflect.Struct
159
160			Iterate through the fields (FirstName, LastName), recurse into their values.
161		*/
162
163		// Types with an Equal() method, like time.Time, only if struct field
164		// is exported (CanInterface)
165		if eqFunc := a.MethodByName("Equal"); eqFunc.IsValid() && eqFunc.CanInterface() {
166			// Handle https://github.com/go-test/deep/issues/15:
167			// Don't call T.Equal if the method is from an embedded struct, like:
168			//   type Foo struct { time.Time }
169			// First, we'll encounter Equal(Ttime, time.Time) but if we pass b
170			// as the 2nd arg we'll panic: "Call using pkg.Foo as type time.Time"
171			// As far as I can tell, there's no way to see that the method is from
172			// time.Time not Foo. So we check the type of the 1st (0) arg and skip
173			// unless it's b type. Later, we'll encounter the time.Time anonymous/
174			// embedded field and then we'll have Equal(time.Time, time.Time).
175			funcType := eqFunc.Type()
176			if funcType.NumIn() == 1 && funcType.In(0) == bType {
177				retVals := eqFunc.Call([]reflect.Value{b})
178				if !retVals[0].Bool() {
179					c.saveDiff(a, b)
180				}
181				return
182			}
183		}
184
185		for i := 0; i < a.NumField(); i++ {
186			if aType.Field(i).PkgPath != "" && !CompareUnexportedFields {
187				continue // skip unexported field, e.g. s in type T struct {s string}
188			}
189
190			c.push(aType.Field(i).Name) // push field name to buff
191
192			// Get the Value for each field, e.g. FirstName has Type = string,
193			// Kind = reflect.String.
194			af := a.Field(i)
195			bf := b.Field(i)
196
197			// Recurse to compare the field values
198			c.equals(af, bf, level+1)
199
200			c.pop() // pop field name from buff
201
202			if len(c.diff) >= MaxDiff {
203				break
204			}
205		}
206	case reflect.Map:
207		/*
208			The variables are maps like:
209				map[string]int{
210					"foo": 1,
211					"bar": 2,
212				}
213			Type = map[string]int, Kind = reflect.Map
214
215			Or:
216				type T map[string]int{}
217			Type = <pkg>.T, Kind = reflect.Map
218
219			Iterate through the map keys (foo, bar), recurse into their values.
220		*/
221
222		if a.IsNil() || b.IsNil() {
223			if a.IsNil() && !b.IsNil() {
224				c.saveDiff("<nil map>", b)
225			} else if !a.IsNil() && b.IsNil() {
226				c.saveDiff(a, "<nil map>")
227			}
228			return
229		}
230
231		if a.Pointer() == b.Pointer() {
232			return
233		}
234
235		for _, key := range a.MapKeys() {
236			c.push(fmt.Sprintf("map[%s]", key))
237
238			aVal := a.MapIndex(key)
239			bVal := b.MapIndex(key)
240			if bVal.IsValid() {
241				c.equals(aVal, bVal, level+1)
242			} else {
243				c.saveDiff(aVal, "<does not have key>")
244			}
245
246			c.pop()
247
248			if len(c.diff) >= MaxDiff {
249				return
250			}
251		}
252
253		for _, key := range b.MapKeys() {
254			if aVal := a.MapIndex(key); aVal.IsValid() {
255				continue
256			}
257
258			c.push(fmt.Sprintf("map[%s]", key))
259			c.saveDiff("<does not have key>", b.MapIndex(key))
260			c.pop()
261			if len(c.diff) >= MaxDiff {
262				return
263			}
264		}
265	case reflect.Array:
266		n := a.Len()
267		for i := 0; i < n; i++ {
268			c.push(fmt.Sprintf("array[%d]", i))
269			c.equals(a.Index(i), b.Index(i), level+1)
270			c.pop()
271			if len(c.diff) >= MaxDiff {
272				break
273			}
274		}
275	case reflect.Slice:
276		if a.IsNil() || b.IsNil() {
277			if a.IsNil() && !b.IsNil() {
278				c.saveDiff("<nil slice>", b)
279			} else if !a.IsNil() && b.IsNil() {
280				c.saveDiff(a, "<nil slice>")
281			}
282			return
283		}
284
285		aLen := a.Len()
286		bLen := b.Len()
287
288		if a.Pointer() == b.Pointer() && aLen == bLen {
289			return
290		}
291
292		n := aLen
293		if bLen > aLen {
294			n = bLen
295		}
296		for i := 0; i < n; i++ {
297			c.push(fmt.Sprintf("slice[%d]", i))
298			if i < aLen && i < bLen {
299				c.equals(a.Index(i), b.Index(i), level+1)
300			} else if i < aLen {
301				c.saveDiff(a.Index(i), "<no value>")
302			} else {
303				c.saveDiff("<no value>", b.Index(i))
304			}
305			c.pop()
306			if len(c.diff) >= MaxDiff {
307				break
308			}
309		}
310
311	/////////////////////////////////////////////////////////////////////
312	// Primitive kinds
313	/////////////////////////////////////////////////////////////////////
314
315	case reflect.Float32, reflect.Float64:
316		// Avoid 0.04147685731961082 != 0.041476857319611
317		// 6 decimal places is close enough
318		aval := fmt.Sprintf(c.floatFormat, a.Float())
319		bval := fmt.Sprintf(c.floatFormat, b.Float())
320		if aval != bval {
321			c.saveDiff(a.Float(), b.Float())
322		}
323	case reflect.Bool:
324		if a.Bool() != b.Bool() {
325			c.saveDiff(a.Bool(), b.Bool())
326		}
327	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
328		if a.Int() != b.Int() {
329			c.saveDiff(a.Int(), b.Int())
330		}
331	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
332		if a.Uint() != b.Uint() {
333			c.saveDiff(a.Uint(), b.Uint())
334		}
335	case reflect.String:
336		if a.String() != b.String() {
337			c.saveDiff(a.String(), b.String())
338		}
339
340	default:
341		logError(ErrNotHandled)
342	}
343}
344
345func (c *cmp) push(name string) {
346	c.buff = append(c.buff, name)
347}
348
349func (c *cmp) pop() {
350	if len(c.buff) > 0 {
351		c.buff = c.buff[0 : len(c.buff)-1]
352	}
353}
354
355func (c *cmp) saveDiff(aval, bval interface{}) {
356	if len(c.buff) > 0 {
357		varName := strings.Join(c.buff, ".")
358		c.diff = append(c.diff, fmt.Sprintf("%s: %v != %v", varName, aval, bval))
359	} else {
360		c.diff = append(c.diff, fmt.Sprintf("%v != %v", aval, bval))
361	}
362}
363
364func logError(err error) {
365	if LogErrors {
366		log.Println(err)
367	}
368}
369