1// Licensed under the MIT license, see LICENCE file for details.
2
3package quicktest
4
5import (
6	"fmt"
7	"reflect"
8	"strings"
9	"sync"
10	"testing"
11)
12
13// Check runs the given check using the provided t and continues execution in
14// case of failure. For instance:
15//
16//     qt.Check(t, answer, qt.Equals, 42)
17//     qt.Check(t, got, qt.IsNil, qt.Commentf("iteration %d", i))
18//
19// Additional args (not consumed by the checker), when provided, are included as
20// comments in the failure output when the check fails.
21func Check(t testing.TB, got interface{}, checker Checker, args ...interface{}) bool {
22	t.Helper()
23	return New(t).Check(got, checker, args...)
24}
25
26// Assert runs the given check using the provided t and stops execution in case
27// of failure. For instance:
28//
29//     qt.Assert(t, got, qt.DeepEquals, []int{42, 47})
30//     qt.Assert(t, got, qt.ErrorMatches, "bad wolf .*", qt.Commentf("a comment"))
31//
32// Additional args (not consumed by the checker), when provided, are included as
33// comments in the failure output when the check fails.
34func Assert(t testing.TB, got interface{}, checker Checker, args ...interface{}) bool {
35	t.Helper()
36	return New(t).Assert(got, checker, args...)
37}
38
39// New returns a new checker instance that uses t to fail the test when checks
40// fail. It only ever calls the Fatal, Error and (when available) Run methods
41// of t. For instance.
42//
43//     func TestFoo(t *testing.T) {
44//         t.Run("A=42", func(t *testing.T) {
45//             c := qt.New(t)
46//             c.Assert(a, qt.Equals, 42)
47//         })
48//     }
49//
50// The library already provides some base checkers, and more can be added by
51// implementing the Checker interface.
52//
53// If there is a likelihood that Defer will be called, then
54// a call to Done should be deferred after calling New.
55// For example:
56//
57//     func TestFoo(t *testing.T) {
58//             c := qt.New(t)
59//             defer c.Done()
60//             c.Setenv("HOME", "/non-existent")
61//             c.Assert(os.Getenv("HOME"), qt.Equals, "/non-existent")
62//     })
63//
64// A value of C that's has a non-nil TB field but is otherwise zero is valid.
65// So:
66//
67//	c := &qt.C{TB: t}
68//
69// is valid a way to create a C value; it's exactly the same as:
70//
71//	c := qt.New(t)
72//
73// Methods on C may be called concurrently, assuming the underlying
74// `testing.TB` implementation also allows that.
75func New(t testing.TB) *C {
76	return &C{
77		TB: t,
78	}
79}
80
81// C is a quicktest checker. It embeds a testing.TB value and provides
82// additional checking functionality. If an Assert or Check operation fails, it
83// uses the wrapped TB value to fail the test appropriately.
84type C struct {
85	testing.TB
86
87	mu         sync.Mutex
88	doneNeeded bool
89	deferred   func()
90	format     formatFunc
91}
92
// cleaner is satisfied by any TB that offers a Cleanup method, as
// testing.TB does from Go 1.14 onwards.
type cleaner interface {
	Cleanup(f func())
}
97
98// Defer registers a function to be called when c.Done is
99// called. Deferred functions will be called in last added, first called
100// order. If c.Done is not called by the end of the test, the test
101// may panic. Note that if Cleanup is called, there is no
102// need to call Done.
103//
104// Deprecated: in Go >= 1.14 use testing.TB.Cleanup instead.
105func (c *C) Defer(f func()) {
106	c.mu.Lock()
107	defer c.mu.Unlock()
108	if cleaner, ok := c.TB.(cleaner); ok {
109		// Use TB.Cleanup when available, but add a check
110		// that Done has been called so that we don't run
111		// into unexpected Go version incompatibilities.
112		if c.doneNeeded {
113			// We've already installed the wrapper func that checks for Done
114			// so we can avoid doing it again.
115			cleaner.Cleanup(f)
116			return
117		}
118		c.doneNeeded = true
119		cleaner.Cleanup(func() {
120			c.mu.Lock()
121			doneNeeded := c.doneNeeded
122			c.mu.Unlock()
123			if doneNeeded {
124				panic("Done not called after Defer")
125			}
126			f()
127		})
128		return
129	}
130
131	oldDeferred := c.deferred
132	c.deferred = func() {
133		if oldDeferred != nil {
134			defer oldDeferred()
135		}
136		f()
137	}
138}
139
140// Done calls all the functions registered by Defer in reverse
141// registration order. After it's called, the functions are
142// unregistered, so calling Done twice will only call them once.
143//
144// When a test function is called by Run, Done will be called
145// automatically on the C value passed into it.
146//
147// Deprecated: in Go >= 1.14 this is no longer needed if using
148// testing.TB.Cleanup.
149func (c *C) Done() {
150	c.mu.Lock()
151	deferred := c.deferred
152	c.deferred = nil
153	c.doneNeeded = false
154	c.mu.Unlock()
155
156	if deferred != nil {
157		deferred()
158	}
159}
160
161// SetFormat sets the function used to print values in test failures.
162// By default Format is used.
163// Any subsequent subtests invoked with c.Run will also use this function by
164// default.
165func (c *C) SetFormat(format func(interface{}) string) {
166	c.mu.Lock()
167	c.format = format
168	c.mu.Unlock()
169}
170
171// getFormat returns the format function
172// safely acquired under lock.
173func (c *C) getFormat() func(interface{}) string {
174	c.mu.Lock()
175	defer c.mu.Unlock()
176	return c.format
177}
178
179// Check runs the given check and continues execution in case of failure.
180// For instance:
181//
182//     c.Check(answer, qt.Equals, 42)
183//     c.Check(got, qt.IsNil, qt.Commentf("iteration %d", i))
184//
185// Additional args (not consumed by the checker), when provided, are included
186// as comments in the failure output when the check fails.
187func (c *C) Check(got interface{}, checker Checker, args ...interface{}) bool {
188	c.TB.Helper()
189	return check(c, checkParams{
190		fail:    c.TB.Error,
191		checker: checker,
192		got:     got,
193		args:    args,
194	})
195}
196
197// Assert runs the given check and stops execution in case of failure.
198// For instance:
199//
200//     c.Assert(got, qt.DeepEquals, []int{42, 47})
201//     c.Assert(got, qt.ErrorMatches, "bad wolf .*", qt.Commentf("a comment"))
202//
203// Additional args (not consumed by the checker), when provided, are included
204// as comments in the failure output when the check fails.
205func (c *C) Assert(got interface{}, checker Checker, args ...interface{}) bool {
206	c.TB.Helper()
207	return check(c, checkParams{
208		fail:    c.TB.Fatal,
209		checker: checker,
210		got:     got,
211		args:    args,
212	})
213}
214
// Reflection types used by Run to validate the signature of the underlying
// TB's Run method.
var (
	stringType = reflect.TypeOf("")
	boolType   = reflect.TypeOf(false)
	tbType     = reflect.TypeOf((*testing.TB)(nil)).Elem()
)
220
221// Run runs f as a subtest of t called name. It's a wrapper around
222// the Run method of c.TB that provides the quicktest checker to f. When
223// the function completes, c.Done will be called to run any
224// functions registered with c.Defer.
225//
226// c.TB must implement a Run method of the following form:
227//
228//	Run(string, func(T)) bool
229//
230// where T is any type that is assignable to testing.TB.
231// Implementations include *testing.T, *testing.B and *C itself.
232//
233// The TB field in the subtest will hold the value passed
234// by Run to its argument function.
235//
236//     func TestFoo(t *testing.T) {
237//         c := qt.New(t)
238//         c.Run("A=42", func(c *qt.C) {
239//             // This assertion only stops the current subtest.
240//             c.Assert(a, qt.Equals, 42)
241//         })
242//     }
243//
244// A panic is raised when Run is called and the embedded concrete type does not
245// implement a Run method with a correct signature.
246func (c *C) Run(name string, f func(c *C)) bool {
247	badType := func(m string) {
248		panic(fmt.Sprintf("cannot execute Run with underlying concrete type %T (%s)", c.TB, m))
249	}
250	m := reflect.ValueOf(c.TB).MethodByName("Run")
251	if !m.IsValid() {
252		// c.TB doesn't implement a Run method.
253		badType("no Run method")
254	}
255	mt := m.Type()
256	if mt.NumIn() != 2 ||
257		mt.In(0) != stringType ||
258		mt.NumOut() != 1 ||
259		mt.Out(0) != boolType {
260		// The Run method doesn't have the right argument counts and types.
261		badType("wrong argument count for Run method")
262	}
263	farg := mt.In(1)
264	if farg.Kind() != reflect.Func ||
265		farg.NumIn() != 1 ||
266		farg.NumOut() != 0 ||
267		!farg.In(0).AssignableTo(tbType) {
268		// The first argument to the Run function arg isn't right.
269		badType("bad first argument type for Run method")
270	}
271	fv := reflect.MakeFunc(farg, func(args []reflect.Value) []reflect.Value {
272		c2 := New(args[0].Interface().(testing.TB))
273		defer c2.Done()
274		c2.SetFormat(c.getFormat())
275		f(c2)
276		return nil
277	})
278	return m.Call([]reflect.Value{reflect.ValueOf(name), fv})[0].Interface().(bool)
279}
280
281// Parallel signals that this test is to be run in parallel with (and only with) other parallel tests.
282// It's a wrapper around *testing.T.Parallel.
283//
284// A panic is raised when Parallel is called and the embedded concrete type does not
285// implement Parallel, for instance if TB's concrete type is a benchmark.
286func (c *C) Parallel() {
287	p, ok := c.TB.(interface {
288		Parallel()
289	})
290	if !ok {
291		panic(fmt.Sprintf("cannot execute Parallel with underlying concrete type %T", c.TB))
292	}
293	p.Parallel()
294}
295
296// check performs the actual check with the provided params.
297// In case of failure p.fail is called. In the fail report values are formatted
298// using p.format.
299func check(c *C, p checkParams) bool {
300	c.TB.Helper()
301	rp := reportParams{
302		got:    p.got,
303		args:   p.args,
304		format: c.getFormat(),
305	}
306	if rp.format == nil {
307		// No format set; use the default: Format.
308		rp.format = Format
309	}
310	// Allow checkers to annotate messages.
311	note := func(key string, value interface{}) {
312		rp.notes = append(rp.notes, note{
313			key:   key,
314			value: value,
315		})
316	}
317	// Ensure that we have a checker.
318	if p.checker == nil {
319		p.fail(report(BadCheckf("nil checker provided"), rp))
320		return false
321	}
322	// Extract a comment if it has been provided.
323	rp.argNames = p.checker.ArgNames()
324	wantNumArgs := len(rp.argNames) - 1
325	if len(p.args) > 0 {
326		if comment, ok := p.args[len(p.args)-1].(Comment); ok {
327			rp.comment = comment
328			rp.args = p.args[:len(p.args)-1]
329		}
330	}
331	// Validate that we have the correct number of arguments.
332	if gotNumArgs := len(rp.args); gotNumArgs != wantNumArgs {
333		if gotNumArgs > 0 {
334			note("got args", rp.args)
335		}
336		if wantNumArgs > 0 {
337			note("want args", Unquoted(strings.Join(rp.argNames[1:], ", ")))
338		}
339		var prefix string
340		if gotNumArgs > wantNumArgs {
341			prefix = "too many arguments provided to checker"
342		} else {
343			prefix = "not enough arguments provided to checker"
344		}
345		p.fail(report(BadCheckf("%s: got %d, want %d", prefix, gotNumArgs, wantNumArgs), rp))
346		return false
347	}
348
349	// Execute the check and report the failure if necessary.
350	if err := p.checker.Check(p.got, p.args, note); err != nil {
351		p.fail(report(err, rp))
352		return false
353	}
354	return true
355}
356
357// checkParams holds parameters for executing a check.
358type checkParams struct {
359	fail    func(...interface{})
360	checker Checker
361	got     interface{}
362	args    []interface{}
363}
364