1package pretty
2
3import (
4	"fmt"
5	"io"
6	"reflect"
7)
8
// sbuf is a Printfer that records each formatted message as its own
// element instead of writing it anywhere.
type sbuf []string

// Printf formats a according to format (no newline is added) and appends
// the resulting string to the buffer.
func (p *sbuf) Printf(format string, a ...interface{}) {
	*p = append(*p, fmt.Sprintf(format, a...))
}
15
16// Diff returns a slice where each element describes
17// a difference between a and b.
18func Diff(a, b interface{}) (desc []string) {
19	Pdiff((*sbuf)(&desc), a, b)
20	return desc
21}
22
// wprintfer is a Printfer that writes every message to w as its own line.
// Write errors are discarded, because Printf has no way to return them.
type wprintfer struct{ w io.Writer }

func (p *wprintfer) Printf(format string, a ...interface{}) {
	// The newline is appended to the format itself, so each message
	// reaches w through a single Fprintf call.
	withNewline := format + "\n"
	fmt.Fprintf(p.w, withNewline, a...)
}
30
31// Fdiff writes to w a description of the differences between a and b.
32func Fdiff(w io.Writer, a, b interface{}) {
33	Pdiff(&wprintfer{w}, a, b)
34}
35
// Printfer is implemented by any value with a printf-style Printf method.
// Pdiff sends its output to a Printfer. The standard library *log.Logger
// is one.
type Printfer interface {
	Printf(format string, a ...interface{})
}
39
40// Pdiff prints to p a description of the differences between a and b.
41// It calls Printf once for each difference, with no trailing newline.
42// The standard library log.Logger is a Printfer.
43func Pdiff(p Printfer, a, b interface{}) {
44	diffPrinter{w: p}.diff(reflect.ValueOf(a), reflect.ValueOf(b))
45}
46
// Logfer is implemented by any value with a printf-style Logf method.
// Ldiff sends its output to a Logfer. The standard library *testing.T
// and *testing.B are Logfers.
type Logfer interface {
	Logf(format string, a ...interface{})
}

// logprintfer adapts a Logfer to the Printfer interface. Each Printf call
// is forwarded unchanged to l.Logf, and no newline is added.
type logprintfer struct{ l Logfer }

func (p *logprintfer) Printf(format string, a ...interface{}) {
	p.l.Logf(format, a...)
}
58
59// Ldiff prints to l a description of the differences between a and b.
60// It calls Logf once for each difference, with no trailing newline.
61// The standard library testing.T and testing.B are Logfers.
62func Ldiff(l Logfer, a, b interface{}) {
63	Pdiff(&logprintfer{l}, a, b)
64}
65
66type diffPrinter struct {
67	w Printfer
68	l string // label
69}
70
71func (w diffPrinter) printf(f string, a ...interface{}) {
72	var l string
73	if w.l != "" {
74		l = w.l + ": "
75	}
76	w.w.Printf(l+f, a...)
77}
78
79func (w diffPrinter) diff(av, bv reflect.Value) {
80	if !av.IsValid() && bv.IsValid() {
81		w.printf("nil != %# v", formatter{v: bv, quote: true})
82		return
83	}
84	if av.IsValid() && !bv.IsValid() {
85		w.printf("%# v != nil", formatter{v: av, quote: true})
86		return
87	}
88	if !av.IsValid() && !bv.IsValid() {
89		return
90	}
91
92	at := av.Type()
93	bt := bv.Type()
94	if at != bt {
95		w.printf("%v != %v", at, bt)
96		return
97	}
98
99	switch kind := at.Kind(); kind {
100	case reflect.Bool:
101		if a, b := av.Bool(), bv.Bool(); a != b {
102			w.printf("%v != %v", a, b)
103		}
104	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
105		if a, b := av.Int(), bv.Int(); a != b {
106			w.printf("%d != %d", a, b)
107		}
108	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
109		if a, b := av.Uint(), bv.Uint(); a != b {
110			w.printf("%d != %d", a, b)
111		}
112	case reflect.Float32, reflect.Float64:
113		if a, b := av.Float(), bv.Float(); a != b {
114			w.printf("%v != %v", a, b)
115		}
116	case reflect.Complex64, reflect.Complex128:
117		if a, b := av.Complex(), bv.Complex(); a != b {
118			w.printf("%v != %v", a, b)
119		}
120	case reflect.Array:
121		n := av.Len()
122		for i := 0; i < n; i++ {
123			w.relabel(fmt.Sprintf("[%d]", i)).diff(av.Index(i), bv.Index(i))
124		}
125	case reflect.Chan, reflect.Func, reflect.UnsafePointer:
126		if a, b := av.Pointer(), bv.Pointer(); a != b {
127			w.printf("%#x != %#x", a, b)
128		}
129	case reflect.Interface:
130		w.diff(av.Elem(), bv.Elem())
131	case reflect.Map:
132		ak, both, bk := keyDiff(av.MapKeys(), bv.MapKeys())
133		for _, k := range ak {
134			w := w.relabel(fmt.Sprintf("[%#v]", k))
135			w.printf("%q != (missing)", av.MapIndex(k))
136		}
137		for _, k := range both {
138			w := w.relabel(fmt.Sprintf("[%#v]", k))
139			w.diff(av.MapIndex(k), bv.MapIndex(k))
140		}
141		for _, k := range bk {
142			w := w.relabel(fmt.Sprintf("[%#v]", k))
143			w.printf("(missing) != %q", bv.MapIndex(k))
144		}
145	case reflect.Ptr:
146		switch {
147		case av.IsNil() && !bv.IsNil():
148			w.printf("nil != %# v", formatter{v: bv, quote: true})
149		case !av.IsNil() && bv.IsNil():
150			w.printf("%# v != nil", formatter{v: av, quote: true})
151		case !av.IsNil() && !bv.IsNil():
152			w.diff(av.Elem(), bv.Elem())
153		}
154	case reflect.Slice:
155		lenA := av.Len()
156		lenB := bv.Len()
157		if lenA != lenB {
158			w.printf("%s[%d] != %s[%d]", av.Type(), lenA, bv.Type(), lenB)
159			break
160		}
161		for i := 0; i < lenA; i++ {
162			w.relabel(fmt.Sprintf("[%d]", i)).diff(av.Index(i), bv.Index(i))
163		}
164	case reflect.String:
165		if a, b := av.String(), bv.String(); a != b {
166			w.printf("%q != %q", a, b)
167		}
168	case reflect.Struct:
169		for i := 0; i < av.NumField(); i++ {
170			w.relabel(at.Field(i).Name).diff(av.Field(i), bv.Field(i))
171		}
172	default:
173		panic("unknown reflect Kind: " + kind.String())
174	}
175}
176
177func (d diffPrinter) relabel(name string) (d1 diffPrinter) {
178	d1 = d
179	if d.l != "" && name[0] != '[' {
180		d1.l += "."
181	}
182	d1.l += name
183	return d1
184}
185
// keyEqual reports whether the map keys av and bv are equal.
//
// Two invalid Values are equal. An invalid and a valid Value are not
// equal, and neither are Values of different types. Arrays, structs and
// interfaces are compared recursively. Channels and pointers are compared
// by address. Floats use ==, so NaN never equals itself.
//
// It panics for kinds that cannot be map keys (funcs, maps, slices).
func keyEqual(av, bv reflect.Value) bool {
	aOK, bOK := av.IsValid(), bv.IsValid()
	if !aOK || !bOK {
		return !aOK && !bOK
	}
	if av.Type() != bv.Type() {
		return false
	}
	switch av.Kind() {
	case reflect.Bool:
		return av.Bool() == bv.Bool()
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return av.Int() == bv.Int()
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		return av.Uint() == bv.Uint()
	case reflect.Float32, reflect.Float64:
		return av.Float() == bv.Float()
	case reflect.Complex64, reflect.Complex128:
		return av.Complex() == bv.Complex()
	case reflect.String:
		return av.String() == bv.String()
	case reflect.Chan, reflect.UnsafePointer, reflect.Ptr:
		return av.Pointer() == bv.Pointer()
	case reflect.Interface:
		return keyEqual(av.Elem(), bv.Elem())
	case reflect.Array:
		for i, n := 0, av.Len(); i < n; i++ {
			if !keyEqual(av.Index(i), bv.Index(i)) {
				return false
			}
		}
		return true
	case reflect.Struct:
		for i, n := 0, av.NumField(); i < n; i++ {
			if !keyEqual(av.Field(i), bv.Field(i)) {
				return false
			}
		}
		return true
	}
	panic("invalid map key type " + av.Type().String())
}
237
238func keyDiff(a, b []reflect.Value) (ak, both, bk []reflect.Value) {
239	for _, av := range a {
240		inBoth := false
241		for _, bv := range b {
242			if keyEqual(av, bv) {
243				inBoth = true
244				both = append(both, av)
245				break
246			}
247		}
248		if !inBoth {
249			ak = append(ak, av)
250		}
251	}
252	for _, bv := range b {
253		inBoth := false
254		for _, av := range a {
255			if keyEqual(av, bv) {
256				inBoth = true
257				break
258			}
259		}
260		if !inBoth {
261			bk = append(bk, bv)
262		}
263	}
264	return
265}
266