1package verify
2
3import (
4	"fmt"
5	"reflect"
6	"strings"
7	"testing"
8)
9
// Values verifies that got has all the content, and only the content, defined
// by want. All differences found are collected and handed to tb.Error as one
// message built from name; the result reports whether the values matched.
// Note that NaN always results in a mismatch, because it never equals itself.
func Values(tb testing.TB, name string, got, want interface{}) (ok bool) {
	t := travel{}
	t.values(reflect.ValueOf(got), reflect.ValueOf(want), nil)

	msg := t.report(name)
	if msg == "" {
		return true
	}

	tb.Helper()
	tb.Error(msg)
	return false
}
25
// values compares got with want recursively and records every difference via
// t.differ, located by path. path holds the segments leading from the root to
// the current pair; composite kinds push a segment, update it for each child,
// and pop it again where the path is still used afterwards.
//
// An invalid (zero) reflect.Value marks absence: a valid got with an invalid
// want is reported as "Unwanted", the reverse as "Missing". Values of differing
// types are reported once without descending further.
func (t *travel) values(got, want reflect.Value, path []*segment) {
	if !want.IsValid() {
		if got.IsValid() {
			t.differ(path, "Unwanted %s", got.Type())
		}
		return
	}
	if !got.IsValid() {
		t.differ(path, "Missing %s", want.Type())
		return
	}

	if got.Type() != want.Type() {
		t.differ(path, "Got type %s, want %s", got.Type(), want.Type())
		return
	}

	switch got.Kind() {

	case reflect.Struct:
		seg := &segment{format: "/%s"}
		path = append(path, seg)

		var unexp []string
		for i, n := 0, got.NumField(); i < n; i++ {
			field := got.Type().Field(i)
			if field.PkgPath != "" {
				// Unexported fields are not compared one by one; they are
				// covered by the whole-struct DeepEqual check below.
				unexp = append(unexp, field.Name)
				continue
			}
			seg.x = field.Name
			t.values(got.Field(i), want.Field(i), path)
		}
		path = path[:len(path)-1] // the struct-level report below uses the parent path

		if len(unexp) != 0 && !reflect.DeepEqual(got.Interface(), want.Interface()) {
			t.differ(path, "Type %s with unexported fields %q not equal", got.Type(), unexp)
		}

	case reflect.Slice, reflect.Array:
		n := got.Len()
		if n != want.Len() {
			t.differ(path, "Got %d elements, want %d", n, want.Len())
			return
		}

		seg := &segment{format: "[%d]"}
		path = append(path, seg)
		for i := 0; i < n; i++ {
			seg.x = i
			t.values(got.Index(i), want.Index(i), path)
		}

	case reflect.Ptr:
		// Identical addresses (including both nil) need no further inspection.
		if got.Pointer() != want.Pointer() {
			t.values(got.Elem(), want.Elem(), path)
		}

	case reflect.Interface:
		t.values(got.Elem(), want.Elem(), path)

	case reflect.Map:
		seg := &segment{}
		path = append(path, seg)

		// Every wanted key; keys absent from got come out as "Missing".
		for _, key := range want.MapKeys() {
			applyKeySeg(seg, key)
			t.values(got.MapIndex(key), want.MapIndex(key), path)
		}

		// Keys present only in got come out as "Unwanted".
		for _, key := range got.MapKeys() {
			v := want.MapIndex(key)
			if v.IsValid() {
				continue
			}
			applyKeySeg(seg, key)
			t.values(got.MapIndex(key), v, path)
		}

	case reflect.Func:
		if !(got.IsNil() && want.IsNil()) {
			t.differ(path, "Can't compare functions")
		}

	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		if a, b := got.Int(), want.Int(); a != b {
			if a < 0xA && a > -0xA && b < 0xA && b > -0xA {
				t.differ(path, "Got %d, want %d", a, b)
			} else {
				// %#x keeps the sign in front of the prefix (-0xa, not 0x-a).
				t.differ(path, "Got %d (%#x), want %d (%#x)", a, a, b, b)
			}
		}

	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		if a, b := got.Uint(), want.Uint(); a != b {
			if a < 0xA && b < 0xA {
				t.differ(path, "Got %d, want %d", a, b)
			} else {
				t.differ(path, "Got %d (%#x), want %d (%#x)", a, a, b, b)
			}
		}

	case reflect.String:
		if a, b := got.String(), want.String(); a != b {
			// The message embeds arbitrary string content, which may contain
			// '%'; it must be an argument, never the format itself.
			t.differ(path, "%s", differMsg(a, b))
		}

	default:
		// Remaining kinds (bool, float, complex, uintptr, chan, unsafe
		// pointer) are compared with ==, so NaN never matches.
		if a, b := got.Interface(), want.Interface(); a != b {
			t.differ(path, "Got %v, want %v", a, b)
		}
	}
}
140
// applyKeySeg configures dst to describe a map key: string keys are shown
// quoted ("[%q]"), keys of any other kind with default formatting ("[%v]").
func applyKeySeg(dst *segment, key reflect.Value) {
	format := "[%v]"
	if key.Kind() == reflect.String {
		format = "[%q]"
	}
	dst.format = format
	dst.x = key.Interface()
}
149
150func differMsg(got, want string) string {
151	if len(got) < 9 || len(want) < 9 {
152		return fmt.Sprintf("Got %q, want %q", got, want)
153	}
154
155	got, want = fmt.Sprintf("%q", got), fmt.Sprintf("%q", want)
156
157	// find first character which differs
158	var i int
159	a, b := []rune(got), []rune(want)
160	for i = 0; i < len(a); i++ {
161		if i >= len(b) || a[i] != b[i] {
162			break
163		}
164	}
165	return fmt.Sprintf("Got %s, want %s\n    %s^", got, want, strings.Repeat(" ", i))
166}
167