1// Licensed under the MIT license, see LICENCE file for details.
2
3package quicktest
4
5import (
6	"fmt"
7	"reflect"
8	"strings"
9	"sync"
10	"testing"
11)
12
13// New returns a new checker instance that uses t to fail the test when checks
14// fail. It only ever calls the Fatal, Error and (when available) Run methods
15// of t. For instance.
16//
17//     func TestFoo(t *testing.T) {
18//         t.Run("A=42", func(t *testing.T) {
19//             c := qt.New(t)
20//             c.Assert(a, qt.Equals, 42)
21//         })
22//     }
23//
24// The library already provides some base checkers, and more can be added by
25// implementing the Checker interface.
26//
27// If there is a likelihood that Defer will be called, then
28// a call to Done should be deferred after calling New.
29// For example:
30//
31//     func TestFoo(t *testing.T) {
32//             c := qt.New(t)
33//             defer c.Done()
34//             c.Setenv("HOME", "/non-existent")
35//             c.Assert(os.Getenv("HOME"), qt.Equals, "/non-existent")
36//     })
37//
38// A value of C that's has a non-nil TB field but is otherwise zero is valid.
39// So:
40//
41//	c := &qt.C{TB: t}
42//
43// is valid a way to create a C value; it's exactly the same as:
44//
45//	c := qt.New(t)
46//
47// Methods on C may be called concurrently, assuming the underlying
48// `testing.TB` implementation also allows that.
49func New(t testing.TB) *C {
50	return &C{
51		TB: t,
52	}
53}
54
55// C is a quicktest checker. It embeds a testing.TB value and provides
56// additional checking functionality. If an Assert or Check operation fails, it
57// uses the wrapped TB value to fail the test appropriately.
58type C struct {
59	testing.TB
60
61	mu         sync.Mutex
62	doneNeeded bool
63	deferred   func()
64	format     formatFunc
65}
66
// cleaner is the part of testing.TB that Defer relies on to register
// functions through TB.Cleanup. testing.TB provides Cleanup from Go 1.14
// onwards; when the type assertion against this interface fails, Defer
// keeps its own chain of deferred functions instead.
type cleaner interface {
	Cleanup(fn func())
}
71
72// Defer registers a function to be called when c.Done is
73// called. Deferred functions will be called in last added, first called
74// order. If c.Done is not called by the end of the test, the test
75// may panic. Note that if Cleanup is called, there is no
76// need to call Done.
77//
78// Deprecated: in Go >= 1.14 use testing.TB.Cleanup instead.
79func (c *C) Defer(f func()) {
80	c.mu.Lock()
81	defer c.mu.Unlock()
82	if cleaner, ok := c.TB.(cleaner); ok {
83		// Use TB.Cleanup when available, but add a check
84		// that Done has been called so that we don't run
85		// into unexpected Go version incompatibilities.
86		if c.doneNeeded {
87			// We've already installed the wrapper func that checks for Done
88			// so we can avoid doing it again.
89			cleaner.Cleanup(f)
90			return
91		}
92		c.doneNeeded = true
93		cleaner.Cleanup(func() {
94			c.mu.Lock()
95			doneNeeded := c.doneNeeded
96			c.mu.Unlock()
97			if doneNeeded {
98				panic("Done not called after Defer")
99			}
100			f()
101		})
102		return
103	}
104
105	oldDeferred := c.deferred
106	c.deferred = func() {
107		if oldDeferred != nil {
108			defer oldDeferred()
109		}
110		f()
111	}
112}
113
114// Done calls all the functions registered by Defer in reverse
115// registration order. After it's called, the functions are
116// unregistered, so calling Done twice will only call them once.
117//
118// When a test function is called by Run, Done will be called
119// automatically on the C value passed into it.
120//
121// Deprecated: in Go >= 1.14 this is no longer needed if using
122// testing.TB.Cleanup.
123func (c *C) Done() {
124	c.mu.Lock()
125	deferred := c.deferred
126	c.deferred = nil
127	c.doneNeeded = false
128	c.mu.Unlock()
129
130	if deferred != nil {
131		deferred()
132	}
133}
134
135// SetFormat sets the function used to print values in test failures.
136// By default Format is used.
137// Any subsequent subtests invoked with c.Run will also use this function by
138// default.
139func (c *C) SetFormat(format func(interface{}) string) {
140	c.mu.Lock()
141	c.format = format
142	c.mu.Unlock()
143}
144
145// getFormat returns the format function
146// safely acquired under lock.
147func (c *C) getFormat() func(interface{}) string {
148	c.mu.Lock()
149	defer c.mu.Unlock()
150	return c.format
151}
152
153// Check runs the given check and continues execution in case of failure.
154// For instance:
155//
156//     c.Check(answer, qt.Equals, 42)
157//     c.Check(got, qt.IsNil, qt.Commentf("iteration %d", i))
158//
159// Additional args (not consumed by the checker), when provided, are included
160// as comments in the failure output when the check fails.
161func (c *C) Check(got interface{}, checker Checker, args ...interface{}) bool {
162	return c.check(c.TB.Error, checker, got, args)
163}
164
165// Assert runs the given check and stops execution in case of failure.
166// For instance:
167//
168//     c.Assert(got, qt.DeepEquals, []int{42, 47})
169//     c.Assert(got, qt.ErrorMatches, "bad wolf .*", qt.Commentf("a comment"))
170//
171// Additional args (not consumed by the checker), when provided, are included
172// as comments in the failure output when the check fails.
173func (c *C) Assert(got interface{}, checker Checker, args ...interface{}) bool {
174	return c.check(c.TB.Fatal, checker, got, args)
175}
176
177var (
178	stringType = reflect.TypeOf("")
179	boolType   = reflect.TypeOf(true)
180	tbType     = reflect.TypeOf(new(testing.TB)).Elem()
181)
182
183// Run runs f as a subtest of t called name. It's a wrapper around
184// the Run method of c.TB that provides the quicktest checker to f. When
185// the function completes, c.Done will be called to run any
186// functions registered with c.Defer.
187//
188// c.TB must implement a Run method of the following form:
189//
190//	Run(string, func(T)) bool
191//
192// where T is any type that is assignable to testing.TB.
193// Implementations include *testing.T, *testing.B and *C itself.
194//
195// The TB field in the subtest will hold the value passed
196// by Run to its argument function.
197//
198//     func TestFoo(t *testing.T) {
199//         c := qt.New(t)
200//         c.Run("A=42", func(c *qt.C) {
201//             // This assertion only stops the current subtest.
202//             c.Assert(a, qt.Equals, 42)
203//         })
204//     }
205//
206// A panic is raised when Run is called and the embedded concrete type does not
207// implement a Run method with a correct signature.
208func (c *C) Run(name string, f func(c *C)) bool {
209	badType := func(m string) {
210		panic(fmt.Sprintf("cannot execute Run with underlying concrete type %T (%s)", c.TB, m))
211	}
212	m := reflect.ValueOf(c.TB).MethodByName("Run")
213	if !m.IsValid() {
214		// c.TB doesn't implement a Run method.
215		badType("no Run method")
216	}
217	mt := m.Type()
218	if mt.NumIn() != 2 ||
219		mt.In(0) != stringType ||
220		mt.NumOut() != 1 ||
221		mt.Out(0) != boolType {
222		// The Run method doesn't have the right argument counts and types.
223		badType("wrong argument count for Run method")
224	}
225	farg := mt.In(1)
226	if farg.Kind() != reflect.Func ||
227		farg.NumIn() != 1 ||
228		farg.NumOut() != 0 ||
229		!farg.In(0).AssignableTo(tbType) {
230		// The first argument to the Run function arg isn't right.
231		badType("bad first argument type for Run method")
232	}
233	fv := reflect.MakeFunc(farg, func(args []reflect.Value) []reflect.Value {
234		c2 := New(args[0].Interface().(testing.TB))
235		defer c2.Done()
236		c2.SetFormat(c.getFormat())
237		f(c2)
238		return nil
239	})
240	return m.Call([]reflect.Value{reflect.ValueOf(name), fv})[0].Interface().(bool)
241}
242
243// Parallel signals that this test is to be run in parallel with (and only with) other parallel tests.
244// It's a wrapper around *testing.T.Parallel.
245//
246// A panic is raised when Parallel is called and the embedded concrete type does not
247// implement Parallel, for instance if TB's concrete type is a benchmark.
248func (c *C) Parallel() {
249	p, ok := c.TB.(interface {
250		Parallel()
251	})
252	if !ok {
253		panic(fmt.Sprintf("cannot execute Parallel with underlying concrete type %T", c.TB))
254	}
255	p.Parallel()
256}
257
258// check performs the actual check and calls the provided fail function in case
259// of failure.
260func (c *C) check(fail func(...interface{}), checker Checker, got interface{}, args []interface{}) bool {
261	// Allow checkers to annotate messages.
262	rp := reportParams{
263		got:    got,
264		args:   args,
265		format: c.getFormat(),
266	}
267	if rp.format == nil {
268		// No format set; use the default: Format.
269		rp.format = Format
270	}
271	note := func(key string, value interface{}) {
272		rp.notes = append(rp.notes, note{
273			key:   key,
274			value: value,
275		})
276	}
277	// Ensure that we have a checker.
278	if checker == nil {
279		fail(report(BadCheckf("nil checker provided"), rp))
280		return false
281	}
282	// Extract a comment if it has been provided.
283	rp.argNames = checker.ArgNames()
284	wantNumArgs := len(rp.argNames) - 1
285	if len(args) > 0 {
286		if comment, ok := args[len(args)-1].(Comment); ok {
287			rp.comment = comment
288			rp.args = args[:len(args)-1]
289		}
290	}
291	// Validate that we have the correct number of arguments.
292	if gotNumArgs := len(rp.args); gotNumArgs != wantNumArgs {
293		if gotNumArgs > 0 {
294			note("got args", rp.args)
295		}
296		if wantNumArgs > 0 {
297			note("want args", Unquoted(strings.Join(rp.argNames[1:], ", ")))
298		}
299		var prefix string
300		if gotNumArgs > wantNumArgs {
301			prefix = "too many arguments provided to checker"
302		} else {
303			prefix = "not enough arguments provided to checker"
304		}
305		fail(report(BadCheckf("%s: got %d, want %d", prefix, gotNumArgs, wantNumArgs), rp))
306		return false
307	}
308
309	// Execute the check and report the failure if necessary.
310	if err := checker.Check(got, args, note); err != nil {
311		fail(report(err, rp))
312		return false
313	}
314	return true
315}
316