1package evaltest
2
3import (
4	"fmt"
5	"math"
6	"reflect"
7	"regexp"
8
9	"src.elv.sh/pkg/eval"
10)
11
const (
	// ApproximatelyThreshold is the tolerance used when matching float64 values
	// through Approximately.
	ApproximatelyThreshold = 1.0e-15
)
15
// Approximately can be passed to Case.Puts in place of a float64 value; it
// matches any float64 that lies within ApproximatelyThreshold of F.
type Approximately struct {
	F float64 // the expected value
}
19
// matchFloat64 reports whether a and b should be treated as equal. Two NaNs
// match each other, two infinities match when they have the same sign, and any
// other pair matches when their absolute difference is at most threshold.
func matchFloat64(a, b, threshold float64) bool {
	switch {
	case math.IsNaN(a) && math.IsNaN(b):
		return true
	case math.IsInf(a, 0) && a == b:
		// Same-signed infinities; a-b would be NaN, so handle them explicitly.
		return true
	default:
		return math.Abs(a-b) <= threshold
	}
}
30
// MatchingRegexp can be passed to Case.Puts to match any string that matches
// a regexp pattern. If the pattern is not a valid regexp, the test will panic.
type MatchingRegexp struct{ Pattern string }
34
// matchRegexp reports whether s contains a match of the regular expression p.
// An invalid pattern does not produce a mismatch; instead it panics with the
// error returned by regexp.MatchString.
func matchRegexp(p, s string) bool {
	ok, err := regexp.MatchString(p, s)
	if err == nil {
		return ok
	}
	panic(err)
}
42
// errorMatcher is implemented by the expected-error values in this file (such
// as AnyError and the results of ErrorWithType and ErrorWithMessage) that carry
// their own matching logic. NOTE(review): presumably consulted by matchErr,
// which is defined elsewhere — confirm.
type errorMatcher interface {
	// matchError reports whether the actual error satisfies this expectation.
	matchError(error) bool
}
44
// anyError is an errorMatcher that is satisfied by every non-nil error.
type anyError struct{}

// AnyError is an error that can be passed to Case.Throws to match any non-nil
// error.
var AnyError = anyError{}

// Error returns a fixed description of this matcher.
func (anyError) Error() string {
	return "any error"
}

// matchError accepts err exactly when it is not nil.
func (anyError) matchError(err error) bool {
	return nil != err
}
55
// exc is an errorMatcher for exceptions: it describes the expected reason and,
// optionally, the expected source texts of the stack trace.
type exc struct {
	reason error    // expected exception reason
	stacks []string // expected stack texts; empty means "don't check"
}

// Error describes the expected exception, mentioning the stacks only when some
// were given.
func (e exc) Error() string {
	if len(e.stacks) > 0 {
		return fmt.Sprintf("exception with reason %v and stacks %v", e.reason, e.stacks)
	}
	return fmt.Sprintf("exception with reason %v", e.reason)
}
68
69func (e exc) matchError(e2 error) bool {
70	if e2, ok := e2.(eval.Exception); ok {
71		return matchErr(e.reason, e2.Reason()) &&
72			(len(e.stacks) == 0 ||
73				reflect.DeepEqual(e.stacks, getStackTexts(e2.StackTrace())))
74	}
75	return false
76}
77
78func getStackTexts(tb *eval.StackTrace) []string {
79	texts := []string{}
80	for tb != nil {
81		ctx := tb.Head
82		texts = append(texts, ctx.Source[ctx.From:ctx.To])
83		tb = tb.Next
84	}
85	return texts
86}
87
// errWithType is an errorMatcher that accepts any error whose dynamic type is
// identical to that of the sample error v. Because reflect.TypeOf(nil) is nil,
// a nil error never matches a non-nil sample.
type errWithType struct{ v error }

// ErrorWithType returns an error that can be passed to Case.Throws to match
// any error with the same type as the argument.
func ErrorWithType(v error) error {
	return errWithType{v: v}
}

// Error names the expected error type.
func (e errWithType) Error() string {
	return fmt.Sprintf("error with type %T", e.v)
}

func (e errWithType) matchError(e2 error) bool {
	want, got := reflect.TypeOf(e.v), reflect.TypeOf(e2)
	return got == want
}
100
// errWithMessage is an errorMatcher that accepts any non-nil error whose
// Error() string equals msg exactly.
type errWithMessage struct{ msg string }

// ErrorWithMessage returns an error that can be passed to Case.Throws to match
// any error with the given message.
func ErrorWithMessage(msg string) error {
	return errWithMessage{msg: msg}
}

// Error describes the expected message.
func (e errWithMessage) Error() string {
	return fmt.Sprintf("error with message %s", e.msg)
}

func (e errWithMessage) matchError(e2 error) bool {
	if e2 == nil {
		return false
	}
	return e2.Error() == e.msg
}
113
114// CmdExit returns an error that can be passed to Case.Throws to match an
115// eval.ExternalCmdExit ignoring the Pid field.
116func CmdExit(v eval.ExternalCmdExit) error { return errCmdExit{v} }
117
118// An errorMatcher for an ExternalCmdExit error that ignores the `Pid` member.
119// We only match the command name and exit status because at run time we
120// cannot know the correct value for `Pid`.
121type errCmdExit struct{ v eval.ExternalCmdExit }
122
123func (e errCmdExit) Error() string {
124	return e.v.Error()
125}
126
127func (e errCmdExit) matchError(gotErr error) bool {
128	if gotErr == nil {
129		return false
130	}
131	ge := gotErr.(eval.ExternalCmdExit)
132	return e.v.CmdName == ge.CmdName && e.v.WaitStatus == ge.WaitStatus
133}
134
// An errorMatcher that matches an error if any of the candidate errors in errs
// matches it.
type errOneOf struct{ errs []error }

// OneOfErrors returns an error that can be passed to Case.Throws to match an
// error that matches any one of the given errors.
func OneOfErrors(errs ...error) error { return errOneOf{errs} }

// Error lists the candidate errors. fmt.Sprint inserts no space after a string
// operand, so an explicit format is used to get "one of [a b]".
func (e errOneOf) Error() string { return fmt.Sprintf("one of %v", e.errs) }
140
141func (e errOneOf) matchError(gotError error) bool {
142	for _, want := range e.errs {
143		if matchErr(want, gotError) {
144			return true
145		}
146	}
147	return false
148}
149