1/*Package opt provides common go-cmp.Options for use with assert.DeepEqual.
2 */
3package opt // import "gotest.tools/assert/opt"
4
5import (
6	"fmt"
7	"reflect"
8	"strings"
9	"time"
10
11	gocmp "github.com/google/go-cmp/cmp"
12)
13
14// DurationWithThreshold returns a gocmp.Comparer for comparing time.Duration. The
15// Comparer returns true if the difference between the two Duration values is
16// within the threshold and neither value is zero.
17func DurationWithThreshold(threshold time.Duration) gocmp.Option {
18	return gocmp.Comparer(cmpDuration(threshold))
19}
20
// cmpDuration returns an equality function reporting whether x and y are both
// non-zero and differ by at most threshold (inclusive). A negative threshold
// never matches.
func cmpDuration(threshold time.Duration) func(x, y time.Duration) bool {
	return func(x, y time.Duration) bool {
		// A zero Duration is treated as "unset" and never matches.
		if x == 0 || y == 0 || threshold < 0 {
			return false
		}
		// Order the operands so the true distance is non-negative. x-y can
		// overflow int64 when the signs differ, but reinterpreting the wrapped
		// result as uint64 gives the exact distance: the arithmetic is modulo
		// 2^64 and the real distance is always below 2^64.
		if x < y {
			x, y = y, x
		}
		return uint64(x-y) <= uint64(threshold)
	}
}
30
31// TimeWithThreshold returns a gocmp.Comparer for comparing time.Time. The
32// Comparer returns true if the difference between the two Time values is
33// within the threshold and neither value is zero.
34func TimeWithThreshold(threshold time.Duration) gocmp.Option {
35	return gocmp.Comparer(cmpTime(threshold))
36}
37
// cmpTime returns an equality function reporting whether x and y are both
// non-zero and lie within threshold of each other (inclusive).
func cmpTime(threshold time.Duration) func(x, y time.Time) bool {
	return func(x, y time.Time) bool {
		if !x.IsZero() && !y.IsZero() {
			diff := x.Sub(y)
			return -threshold <= diff && diff <= threshold
		}
		// A zero Time is treated as "unset" and never matches.
		return false
	}
}
47
48// PathString is a gocmp.FilterPath filter that returns true when path.String()
49// matches any of the specs.
50//
51// The path spec is a dot separated string where each segment is a field name.
52// Slices, Arrays, and Maps are always matched against every element in the
53// sequence. gocmp.Indirect, gocmp.Transform, and gocmp.TypeAssertion are always
54// ignored.
55//
56// Note: this path filter is not type safe. Incorrect paths will be silently
57// ignored. Consider using a type safe path filter for more complex paths.
58func PathString(specs ...string) func(path gocmp.Path) bool {
59	return func(path gocmp.Path) bool {
60		for _, spec := range specs {
61			if path.String() == spec {
62				return true
63			}
64		}
65		return false
66	}
67}
68
69// PathDebug is a gocmp.FilerPath filter that always returns false. It prints
70// each path it receives. It can be used to debug path matching problems.
71func PathDebug(path gocmp.Path) bool {
72	fmt.Printf("PATH string=%s gostring=%s\n", path, path.GoString())
73	for _, step := range path {
74		fmt.Printf("  STEP %s\ttype=%s\t%s\n",
75			formatStepType(step), step.Type(), stepTypeFields(step))
76	}
77	return false
78}
79
80func formatStepType(step gocmp.PathStep) string {
81	return strings.Title(strings.TrimPrefix(reflect.TypeOf(step).String(), "*cmp."))
82}
83
84func stepTypeFields(step gocmp.PathStep) string {
85	switch typed := step.(type) {
86	case gocmp.StructField:
87		return fmt.Sprintf("name=%s", typed.Name())
88	case gocmp.MapIndex:
89		return fmt.Sprintf("key=%s", typed.Key().Interface())
90	case gocmp.Transform:
91		return fmt.Sprintf("name=%s", typed.Name())
92	case gocmp.SliceIndex:
93		return fmt.Sprintf("name=%d", typed.Key())
94	}
95	return ""
96}
97
98// PathField is a gocmp.FilerPath filter that matches a struct field by name.
99// PathField will match every instance of the field in a recursive or nested
100// structure.
101func PathField(structType interface{}, field string) func(gocmp.Path) bool {
102	typ := reflect.TypeOf(structType)
103	if typ.Kind() != reflect.Struct {
104		panic(fmt.Sprintf("type %s is not a struct", typ))
105	}
106	if _, ok := typ.FieldByName(field); !ok {
107		panic(fmt.Sprintf("type %s does not have field %s", typ, field))
108	}
109
110	return func(path gocmp.Path) bool {
111		return path.Index(-2).Type() == typ && isStructField(path.Index(-1), field)
112	}
113}
114
115func isStructField(step gocmp.PathStep, name string) bool {
116	field, ok := step.(gocmp.StructField)
117	return ok && field.Name() == name
118}
119