1// Package deep provides function deep.Equal which is like reflect.DeepEqual but 2// returns a list of differences. This is helpful when comparing complex types 3// like structures and maps. 4package deep 5 6import ( 7 "errors" 8 "fmt" 9 "log" 10 "reflect" 11 "strings" 12) 13 14var ( 15 // FloatPrecision is the number of decimal places to round float values 16 // to when comparing. 17 FloatPrecision = 10 18 19 // MaxDiff specifies the maximum number of differences to return. 20 MaxDiff = 10 21 22 // MaxDepth specifies the maximum levels of a struct to recurse into, 23 // if greater than zero. If zero, there is no limit. 24 MaxDepth = 0 25 26 // LogErrors causes errors to be logged to STDERR when true. 27 LogErrors = false 28 29 // CompareUnexportedFields causes unexported struct fields, like s in 30 // T{s int}, to be compared when true. 31 CompareUnexportedFields = false 32) 33 34var ( 35 // ErrMaxRecursion is logged when MaxDepth is reached. 36 ErrMaxRecursion = errors.New("recursed to MaxDepth") 37 38 // ErrTypeMismatch is logged when Equal passed two different types of values. 39 ErrTypeMismatch = errors.New("variables are different reflect.Type") 40 41 // ErrNotHandled is logged when a primitive Go kind is not handled. 42 ErrNotHandled = errors.New("cannot compare the reflect.Kind") 43) 44 45type cmp struct { 46 diff []string 47 buff []string 48 floatFormat string 49} 50 51var errorType = reflect.TypeOf((*error)(nil)).Elem() 52 53// Equal compares variables a and b, recursing into their structure up to 54// MaxDepth levels deep (if greater than zero), and returns a list of differences, 55// or nil if there are none. Some differences may not be found if an error is 56// also returned. 57// 58// If a type has an Equal method, like time.Equal, it is called to check for 59// equality. 60func Equal(a, b interface{}) []string { 61 aVal := reflect.ValueOf(a) 62 bVal := reflect.ValueOf(b) 63 c := &cmp{ 64 diff: []string{}, 65 buff: []string{}, 66 floatFormat: fmt.Sprintf("%%.%df", FloatPrecision), 67 } 68 if a == nil && b == nil { 69 return nil 70 } else if a == nil && b != nil { 71 c.saveDiff("<nil pointer>", b) 72 } else if a != nil && b == nil { 73 c.saveDiff(a, "<nil pointer>") 74 } 75 if len(c.diff) > 0 { 76 return c.diff 77 } 78 79 c.equals(aVal, bVal, 0) 80 if len(c.diff) > 0 { 81 return c.diff // diffs 82 } 83 return nil // no diffs 84} 85 86func (c *cmp) equals(a, b reflect.Value, level int) { 87 if MaxDepth > 0 && level > MaxDepth { 88 logError(ErrMaxRecursion) 89 return 90 } 91 92 // Check if one value is nil, e.g. T{x: *X} and T.x is nil 93 if !a.IsValid() || !b.IsValid() { 94 if a.IsValid() && !b.IsValid() { 95 c.saveDiff(a.Type(), "<nil pointer>") 96 } else if !a.IsValid() && b.IsValid() { 97 c.saveDiff("<nil pointer>", b.Type()) 98 } 99 return 100 } 101 102 // If differenet types, they can't be equal 103 aType := a.Type() 104 bType := b.Type() 105 if aType != bType { 106 c.saveDiff(aType, bType) 107 logError(ErrTypeMismatch) 108 return 109 } 110 111 // Primitive https://golang.org/pkg/reflect/#Kind 112 aKind := a.Kind() 113 bKind := b.Kind() 114 115 // If both types implement the error interface, compare the error strings. 116 // This must be done before dereferencing because the interface is on a 117 // pointer receiver. 118 if aType.Implements(errorType) && bType.Implements(errorType) { 119 if a.Elem().IsValid() && b.Elem().IsValid() { // both err != nil 120 aString := a.MethodByName("Error").Call(nil)[0].String() 121 bString := b.MethodByName("Error").Call(nil)[0].String() 122 if aString != bString { 123 c.saveDiff(aString, bString) 124 return 125 } 126 } 127 } 128 129 // Dereference pointers and interface{} 130 if aElem, bElem := aKind == reflect.Ptr || aKind == reflect.Interface, 131 bKind == reflect.Ptr || bKind == reflect.Interface; aElem || bElem { 132 133 if aElem { 134 a = a.Elem() 135 } 136 137 if bElem { 138 b = b.Elem() 139 } 140 141 c.equals(a, b, level+1) 142 return 143 } 144 145 switch aKind { 146 147 ///////////////////////////////////////////////////////////////////// 148 // Iterable kinds 149 ///////////////////////////////////////////////////////////////////// 150 151 case reflect.Struct: 152 /* 153 The variables are structs like: 154 type T struct { 155 FirstName string 156 LastName string 157 } 158 Type = <pkg>.T, Kind = reflect.Struct 159 160 Iterate through the fields (FirstName, LastName), recurse into their values. 161 */ 162 163 // Types with an Equal() method, like time.Time, only if struct field 164 // is exported (CanInterface) 165 if eqFunc := a.MethodByName("Equal"); eqFunc.IsValid() && eqFunc.CanInterface() { 166 // Handle https://github.com/go-test/deep/issues/15: 167 // Don't call T.Equal if the method is from an embedded struct, like: 168 // type Foo struct { time.Time } 169 // First, we'll encounter Equal(Ttime, time.Time) but if we pass b 170 // as the 2nd arg we'll panic: "Call using pkg.Foo as type time.Time" 171 // As far as I can tell, there's no way to see that the method is from 172 // time.Time not Foo. So we check the type of the 1st (0) arg and skip 173 // unless it's b type. Later, we'll encounter the time.Time anonymous/ 174 // embedded field and then we'll have Equal(time.Time, time.Time). 175 funcType := eqFunc.Type() 176 if funcType.NumIn() == 1 && funcType.In(0) == bType { 177 retVals := eqFunc.Call([]reflect.Value{b}) 178 if !retVals[0].Bool() { 179 c.saveDiff(a, b) 180 } 181 return 182 } 183 } 184 185 for i := 0; i < a.NumField(); i++ { 186 if aType.Field(i).PkgPath != "" && !CompareUnexportedFields { 187 continue // skip unexported field, e.g. s in type T struct {s string} 188 } 189 190 c.push(aType.Field(i).Name) // push field name to buff 191 192 // Get the Value for each field, e.g. FirstName has Type = string, 193 // Kind = reflect.String. 194 af := a.Field(i) 195 bf := b.Field(i) 196 197 // Recurse to compare the field values 198 c.equals(af, bf, level+1) 199 200 c.pop() // pop field name from buff 201 202 if len(c.diff) >= MaxDiff { 203 break 204 } 205 } 206 case reflect.Map: 207 /* 208 The variables are maps like: 209 map[string]int{ 210 "foo": 1, 211 "bar": 2, 212 } 213 Type = map[string]int, Kind = reflect.Map 214 215 Or: 216 type T map[string]int{} 217 Type = <pkg>.T, Kind = reflect.Map 218 219 Iterate through the map keys (foo, bar), recurse into their values. 220 */ 221 222 if a.IsNil() || b.IsNil() { 223 if a.IsNil() && !b.IsNil() { 224 c.saveDiff("<nil map>", b) 225 } else if !a.IsNil() && b.IsNil() { 226 c.saveDiff(a, "<nil map>") 227 } 228 return 229 } 230 231 if a.Pointer() == b.Pointer() { 232 return 233 } 234 235 for _, key := range a.MapKeys() { 236 c.push(fmt.Sprintf("map[%s]", key)) 237 238 aVal := a.MapIndex(key) 239 bVal := b.MapIndex(key) 240 if bVal.IsValid() { 241 c.equals(aVal, bVal, level+1) 242 } else { 243 c.saveDiff(aVal, "<does not have key>") 244 } 245 246 c.pop() 247 248 if len(c.diff) >= MaxDiff { 249 return 250 } 251 } 252 253 for _, key := range b.MapKeys() { 254 if aVal := a.MapIndex(key); aVal.IsValid() { 255 continue 256 } 257 258 c.push(fmt.Sprintf("map[%s]", key)) 259 c.saveDiff("<does not have key>", b.MapIndex(key)) 260 c.pop() 261 if len(c.diff) >= MaxDiff { 262 return 263 } 264 } 265 case reflect.Array: 266 n := a.Len() 267 for i := 0; i < n; i++ { 268 c.push(fmt.Sprintf("array[%d]", i)) 269 c.equals(a.Index(i), b.Index(i), level+1) 270 c.pop() 271 if len(c.diff) >= MaxDiff { 272 break 273 } 274 } 275 case reflect.Slice: 276 if a.IsNil() || b.IsNil() { 277 if a.IsNil() && !b.IsNil() { 278 c.saveDiff("<nil slice>", b) 279 } else if !a.IsNil() && b.IsNil() { 280 c.saveDiff(a, "<nil slice>") 281 } 282 return 283 } 284 285 aLen := a.Len() 286 bLen := b.Len() 287 288 if a.Pointer() == b.Pointer() && aLen == bLen { 289 return 290 } 291 292 n := aLen 293 if bLen > aLen { 294 n = bLen 295 } 296 for i := 0; i < n; i++ { 297 c.push(fmt.Sprintf("slice[%d]", i)) 298 if i < aLen && i < bLen { 299 c.equals(a.Index(i), b.Index(i), level+1) 300 } else if i < aLen { 301 c.saveDiff(a.Index(i), "<no value>") 302 } else { 303 c.saveDiff("<no value>", b.Index(i)) 304 } 305 c.pop() 306 if len(c.diff) >= MaxDiff { 307 break 308 } 309 } 310 311 ///////////////////////////////////////////////////////////////////// 312 // Primitive kinds 313 ///////////////////////////////////////////////////////////////////// 314 315 case reflect.Float32, reflect.Float64: 316 // Avoid 0.04147685731961082 != 0.041476857319611 317 // 6 decimal places is close enough 318 aval := fmt.Sprintf(c.floatFormat, a.Float()) 319 bval := fmt.Sprintf(c.floatFormat, b.Float()) 320 if aval != bval { 321 c.saveDiff(a.Float(), b.Float()) 322 } 323 case reflect.Bool: 324 if a.Bool() != b.Bool() { 325 c.saveDiff(a.Bool(), b.Bool()) 326 } 327 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: 328 if a.Int() != b.Int() { 329 c.saveDiff(a.Int(), b.Int()) 330 } 331 case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: 332 if a.Uint() != b.Uint() { 333 c.saveDiff(a.Uint(), b.Uint()) 334 } 335 case reflect.String: 336 if a.String() != b.String() { 337 c.saveDiff(a.String(), b.String()) 338 } 339 340 default: 341 logError(ErrNotHandled) 342 } 343} 344 345func (c *cmp) push(name string) { 346 c.buff = append(c.buff, name) 347} 348 349func (c *cmp) pop() { 350 if len(c.buff) > 0 { 351 c.buff = c.buff[0 : len(c.buff)-1] 352 } 353} 354 355func (c *cmp) saveDiff(aval, bval interface{}) { 356 if len(c.buff) > 0 { 357 varName := strings.Join(c.buff, ".") 358 c.diff = append(c.diff, fmt.Sprintf("%s: %v != %v", varName, aval, bval)) 359 } else { 360 c.diff = append(c.diff, fmt.Sprintf("%v != %v", aval, bval)) 361 } 362} 363 364func logError(err error) { 365 if LogErrors { 366 log.Println(err) 367 } 368} 369