1package mock
2
3import (
4	"errors"
5	"fmt"
6	"reflect"
7	"regexp"
8	"runtime"
9	"strings"
10	"sync"
11	"time"
12
13	"github.com/davecgh/go-spew/spew"
14	"github.com/pmezard/go-difflib/difflib"
15	"github.com/stretchr/objx"
16	"github.com/stretchr/testify/assert"
17)
18
// TestingT is the subset of *testing.T that this package needs: a way to log,
// a way to report a failure, and a way to abort the running test.
type TestingT interface {
	// Errorf records a formatted failure message.
	Errorf(format string, args ...interface{})
	// Logf records a formatted log message.
	Logf(format string, args ...interface{})
	// FailNow marks the test as failed; *testing.T also stops the calling
	// goroutine.
	FailNow()
}
25
26/*
27	Call
28*/
29
// Call represents a method call and is used for setting expectations,
// as well as recording activity.
//
// Expectations created by Mock.On are stored as *Call in Mock.ExpectedCalls;
// calls that actually happen are recorded by value in Mock.Calls. The builder
// methods (Return, Times, Run, ...) take the parent Mock's mutex before
// changing a Call, and MethodCalled reads these fields under the same mutex.
type Call struct {
	// Parent is the Mock this call belongs to; its mutex is used by the
	// builder methods when modifying the fields below.
	Parent *Mock

	// The name of the method that was or will be called.
	Method string

	// Holds the arguments of the method.
	Arguments Arguments

	// Holds the arguments that should be returned when
	// this method is called.
	ReturnArguments Arguments

	// Holds the caller info for the On() call
	callerInfo []string

	// The number of times to return the return arguments when setting
	// expectations. 0 means to always return the value. MethodCalled
	// decrements a positive value on each matching call and sets it to -1
	// once the last permitted call has been consumed.
	Repeatability int

	// Amount of times this call has been called
	totalCalls int

	// Call to this method can be optional (set by Maybe)
	optional bool

	// Holds a channel that will be used to block the Return until it either
	// receives a message or is closed. nil means it returns immediately.
	WaitFor <-chan time.Time

	// How long MethodCalled sleeps before returning when WaitFor is nil
	// (set by After).
	waitTime time.Duration

	// Holds a handler used to manipulate arguments content that are passed by
	// reference. It's useful when mocking methods such as unmarshalers or
	// decoders.
	RunFn func(Arguments)

	// PanicMsg holds the message used to mock a panic on the function call.
	// When it is non-nil, MethodCalled panics with *PanicMsg after any
	// configured wait and before RunFn would be invoked.
	PanicMsg *string
}
74
75func newCall(parent *Mock, methodName string, callerInfo []string, methodArguments ...interface{}) *Call {
76	return &Call{
77		Parent:          parent,
78		Method:          methodName,
79		Arguments:       methodArguments,
80		ReturnArguments: make([]interface{}, 0),
81		callerInfo:      callerInfo,
82		Repeatability:   0,
83		WaitFor:         nil,
84		RunFn:           nil,
85		PanicMsg:        nil,
86	}
87}
88
89func (c *Call) lock() {
90	c.Parent.mutex.Lock()
91}
92
93func (c *Call) unlock() {
94	c.Parent.mutex.Unlock()
95}
96
97// Return specifies the return arguments for the expectation.
98//
99//    Mock.On("DoSomething").Return(errors.New("failed"))
100func (c *Call) Return(returnArguments ...interface{}) *Call {
101	c.lock()
102	defer c.unlock()
103
104	c.ReturnArguments = returnArguments
105
106	return c
107}
108
109// Panic specifies if the functon call should fail and the panic message
110//
111//    Mock.On("DoSomething").Panic("test panic")
112func (c *Call) Panic(msg string) *Call {
113	c.lock()
114	defer c.unlock()
115
116	c.PanicMsg = &msg
117
118	return c
119}
120
121// Once indicates that that the mock should only return the value once.
122//
123//    Mock.On("MyMethod", arg1, arg2).Return(returnArg1, returnArg2).Once()
124func (c *Call) Once() *Call {
125	return c.Times(1)
126}
127
128// Twice indicates that that the mock should only return the value twice.
129//
130//    Mock.On("MyMethod", arg1, arg2).Return(returnArg1, returnArg2).Twice()
131func (c *Call) Twice() *Call {
132	return c.Times(2)
133}
134
135// Times indicates that that the mock should only return the indicated number
136// of times.
137//
138//    Mock.On("MyMethod", arg1, arg2).Return(returnArg1, returnArg2).Times(5)
139func (c *Call) Times(i int) *Call {
140	c.lock()
141	defer c.unlock()
142	c.Repeatability = i
143	return c
144}
145
146// WaitUntil sets the channel that will block the mock's return until its closed
147// or a message is received.
148//
149//    Mock.On("MyMethod", arg1, arg2).WaitUntil(time.After(time.Second))
150func (c *Call) WaitUntil(w <-chan time.Time) *Call {
151	c.lock()
152	defer c.unlock()
153	c.WaitFor = w
154	return c
155}
156
157// After sets how long to block until the call returns
158//
159//    Mock.On("MyMethod", arg1, arg2).After(time.Second)
160func (c *Call) After(d time.Duration) *Call {
161	c.lock()
162	defer c.unlock()
163	c.waitTime = d
164	return c
165}
166
167// Run sets a handler to be called before returning. It can be used when
168// mocking a method (such as an unmarshaler) that takes a pointer to a struct and
169// sets properties in such struct
170//
171//    Mock.On("Unmarshal", AnythingOfType("*map[string]interface{}")).Return().Run(func(args Arguments) {
172//    	arg := args.Get(0).(*map[string]interface{})
173//    	arg["foo"] = "bar"
174//    })
175func (c *Call) Run(fn func(args Arguments)) *Call {
176	c.lock()
177	defer c.unlock()
178	c.RunFn = fn
179	return c
180}
181
182// Maybe allows the method call to be optional. Not calling an optional method
183// will not cause an error while asserting expectations
184func (c *Call) Maybe() *Call {
185	c.lock()
186	defer c.unlock()
187	c.optional = true
188	return c
189}
190
191// On chains a new expectation description onto the mocked interface. This
192// allows syntax like.
193//
194//    Mock.
195//       On("MyMethod", 1).Return(nil).
196//       On("MyOtherMethod", 'a', 'b', 'c').Return(errors.New("Some Error"))
197//go:noinline
198func (c *Call) On(methodName string, arguments ...interface{}) *Call {
199	return c.Parent.On(methodName, arguments...)
200}
201
// Mock is the workhorse used to track activity on another object.
// For an example of its usage, refer to the "Example Usage" section at the top
// of this document.
//
// No constructor is needed: On and MethodCalled append to nil slices and the
// zero mutex is unlocked. The methods of Mock hold mutex while reading or
// writing ExpectedCalls, Calls and test. Because Mock contains a sync.Mutex it
// must not be copied after first use.
type Mock struct {
	// Represents the calls that are expected of
	// an object.
	ExpectedCalls []*Call

	// Holds the calls that were made to this mocked object.
	Calls []Call

	// test is an optional TestingT, set via Test, that fail reports to when an
	// invalid mock call was made; when it is nil, fail panics instead.
	test TestingT

	// testData holds any data that might be useful for testing (see TestData).
	// Testify ignores this data completely allowing you to do whatever you like
	// with it. It is created lazily and is not guarded by mutex.
	testData objx.Map

	mutex sync.Mutex
}
223
224// TestData holds any data that might be useful for testing.  Testify ignores
225// this data completely allowing you to do whatever you like with it.
226func (m *Mock) TestData() objx.Map {
227
228	if m.testData == nil {
229		m.testData = make(objx.Map)
230	}
231
232	return m.testData
233}
234
235/*
236	Setting expectations
237*/
238
239// Test sets the test struct variable of the mock object
240func (m *Mock) Test(t TestingT) {
241	m.mutex.Lock()
242	defer m.mutex.Unlock()
243	m.test = t
244}
245
246// fail fails the current test with the given formatted format and args.
247// In case that a test was defined, it uses the test APIs for failing a test,
248// otherwise it uses panic.
249func (m *Mock) fail(format string, args ...interface{}) {
250	m.mutex.Lock()
251	defer m.mutex.Unlock()
252
253	if m.test == nil {
254		panic(fmt.Sprintf(format, args...))
255	}
256	m.test.Errorf(format, args...)
257	m.test.FailNow()
258}
259
260// On starts a description of an expectation of the specified method
261// being called.
262//
263//     Mock.On("MyMethod", arg1, arg2)
264func (m *Mock) On(methodName string, arguments ...interface{}) *Call {
265	for _, arg := range arguments {
266		if v := reflect.ValueOf(arg); v.Kind() == reflect.Func {
267			panic(fmt.Sprintf("cannot use Func in expectations. Use mock.AnythingOfType(\"%T\")", arg))
268		}
269	}
270
271	m.mutex.Lock()
272	defer m.mutex.Unlock()
273	c := newCall(m, methodName, assert.CallerInfo(), arguments...)
274	m.ExpectedCalls = append(m.ExpectedCalls, c)
275	return c
276}
277
278// /*
279// 	Recording and responding to activity
280// */
281
282func (m *Mock) findExpectedCall(method string, arguments ...interface{}) (int, *Call) {
283	var expectedCall *Call
284
285	for i, call := range m.ExpectedCalls {
286		if call.Method == method {
287			_, diffCount := call.Arguments.Diff(arguments)
288			if diffCount == 0 {
289				expectedCall = call
290				if call.Repeatability > -1 {
291					return i, call
292				}
293			}
294		}
295	}
296
297	return -1, expectedCall
298}
299
300type matchCandidate struct {
301	call      *Call
302	mismatch  string
303	diffCount int
304}
305
306func (c matchCandidate) isBetterMatchThan(other matchCandidate) bool {
307	if c.call == nil {
308		return false
309	}
310	if other.call == nil {
311		return true
312	}
313
314	if c.diffCount > other.diffCount {
315		return false
316	}
317	if c.diffCount < other.diffCount {
318		return true
319	}
320
321	if c.call.Repeatability > 0 && other.call.Repeatability <= 0 {
322		return true
323	}
324	return false
325}
326
327func (m *Mock) findClosestCall(method string, arguments ...interface{}) (*Call, string) {
328	var bestMatch matchCandidate
329
330	for _, call := range m.expectedCalls() {
331		if call.Method == method {
332
333			errInfo, tempDiffCount := call.Arguments.Diff(arguments)
334			tempCandidate := matchCandidate{
335				call:      call,
336				mismatch:  errInfo,
337				diffCount: tempDiffCount,
338			}
339			if tempCandidate.isBetterMatchThan(bestMatch) {
340				bestMatch = tempCandidate
341			}
342		}
343	}
344
345	return bestMatch.call, bestMatch.mismatch
346}
347
348func callString(method string, arguments Arguments, includeArgumentValues bool) string {
349
350	var argValsString string
351	if includeArgumentValues {
352		var argVals []string
353		for argIndex, arg := range arguments {
354			argVals = append(argVals, fmt.Sprintf("%d: %#v", argIndex, arg))
355		}
356		argValsString = fmt.Sprintf("\n\t\t%s", strings.Join(argVals, "\n\t\t"))
357	}
358
359	return fmt.Sprintf("%s(%s)%s", method, arguments.String(), argValsString)
360}
361
362// Called tells the mock object that a method has been called, and gets an array
363// of arguments to return.  Panics if the call is unexpected (i.e. not preceded by
364// appropriate .On .Return() calls)
365// If Call.WaitFor is set, blocks until the channel is closed or receives a message.
366func (m *Mock) Called(arguments ...interface{}) Arguments {
367	// get the calling function's name
368	pc, _, _, ok := runtime.Caller(1)
369	if !ok {
370		panic("Couldn't get the caller information")
371	}
372	functionPath := runtime.FuncForPC(pc).Name()
373	//Next four lines are required to use GCCGO function naming conventions.
374	//For Ex:  github_com_docker_libkv_store_mock.WatchTree.pN39_github_com_docker_libkv_store_mock.Mock
375	//uses interface information unlike golang github.com/docker/libkv/store/mock.(*Mock).WatchTree
376	//With GCCGO we need to remove interface information starting from pN<dd>.
377	re := regexp.MustCompile("\\.pN\\d+_")
378	if re.MatchString(functionPath) {
379		functionPath = re.Split(functionPath, -1)[0]
380	}
381	parts := strings.Split(functionPath, ".")
382	functionName := parts[len(parts)-1]
383	return m.MethodCalled(functionName, arguments...)
384}
385
386// MethodCalled tells the mock object that the given method has been called, and gets
387// an array of arguments to return. Panics if the call is unexpected (i.e. not preceded
388// by appropriate .On .Return() calls)
389// If Call.WaitFor is set, blocks until the channel is closed or receives a message.
390func (m *Mock) MethodCalled(methodName string, arguments ...interface{}) Arguments {
391	m.mutex.Lock()
392	//TODO: could combine expected and closes in single loop
393	found, call := m.findExpectedCall(methodName, arguments...)
394
395	if found < 0 {
396		// expected call found but it has already been called with repeatable times
397		if call != nil {
398			m.mutex.Unlock()
399			m.fail("\nassert: mock: The method has been called over %d times.\n\tEither do one more Mock.On(\"%s\").Return(...), or remove extra call.\n\tThis call was unexpected:\n\t\t%s\n\tat: %s", call.totalCalls, methodName, callString(methodName, arguments, true), assert.CallerInfo())
400		}
401		// we have to fail here - because we don't know what to do
402		// as the return arguments.  This is because:
403		//
404		//   a) this is a totally unexpected call to this method,
405		//   b) the arguments are not what was expected, or
406		//   c) the developer has forgotten to add an accompanying On...Return pair.
407		closestCall, mismatch := m.findClosestCall(methodName, arguments...)
408		m.mutex.Unlock()
409
410		if closestCall != nil {
411			m.fail("\n\nmock: Unexpected Method Call\n-----------------------------\n\n%s\n\nThe closest call I have is: \n\n%s\n\n%s\nDiff: %s",
412				callString(methodName, arguments, true),
413				callString(methodName, closestCall.Arguments, true),
414				diffArguments(closestCall.Arguments, arguments),
415				strings.TrimSpace(mismatch),
416			)
417		} else {
418			m.fail("\nassert: mock: I don't know what to return because the method call was unexpected.\n\tEither do Mock.On(\"%s\").Return(...) first, or remove the %s() call.\n\tThis method was unexpected:\n\t\t%s\n\tat: %s", methodName, methodName, callString(methodName, arguments, true), assert.CallerInfo())
419		}
420	}
421
422	if call.Repeatability == 1 {
423		call.Repeatability = -1
424	} else if call.Repeatability > 1 {
425		call.Repeatability--
426	}
427	call.totalCalls++
428
429	// add the call
430	m.Calls = append(m.Calls, *newCall(m, methodName, assert.CallerInfo(), arguments...))
431	m.mutex.Unlock()
432
433	// block if specified
434	if call.WaitFor != nil {
435		<-call.WaitFor
436	} else {
437		time.Sleep(call.waitTime)
438	}
439
440	m.mutex.Lock()
441	panicMsg := call.PanicMsg
442	m.mutex.Unlock()
443	if panicMsg != nil {
444		panic(*panicMsg)
445	}
446
447	m.mutex.Lock()
448	runFn := call.RunFn
449	m.mutex.Unlock()
450
451	if runFn != nil {
452		runFn(arguments)
453	}
454
455	m.mutex.Lock()
456	returnArgs := call.ReturnArguments
457	m.mutex.Unlock()
458
459	return returnArgs
460}
461
462/*
463	Assertions
464*/
465
466type assertExpectationser interface {
467	AssertExpectations(TestingT) bool
468}
469
470// AssertExpectationsForObjects asserts that everything specified with On and Return
471// of the specified objects was in fact called as expected.
472//
473// Calls may have occurred in any order.
474func AssertExpectationsForObjects(t TestingT, testObjects ...interface{}) bool {
475	if h, ok := t.(tHelper); ok {
476		h.Helper()
477	}
478	for _, obj := range testObjects {
479		if m, ok := obj.(Mock); ok {
480			t.Logf("Deprecated mock.AssertExpectationsForObjects(myMock.Mock) use mock.AssertExpectationsForObjects(myMock)")
481			obj = &m
482		}
483		m := obj.(assertExpectationser)
484		if !m.AssertExpectations(t) {
485			t.Logf("Expectations didn't match for Mock: %+v", reflect.TypeOf(m))
486			return false
487		}
488	}
489	return true
490}
491
492// AssertExpectations asserts that everything specified with On and Return was
493// in fact called as expected.  Calls may have occurred in any order.
494func (m *Mock) AssertExpectations(t TestingT) bool {
495	if h, ok := t.(tHelper); ok {
496		h.Helper()
497	}
498	m.mutex.Lock()
499	defer m.mutex.Unlock()
500	var somethingMissing bool
501	var failedExpectations int
502
503	// iterate through each expectation
504	expectedCalls := m.expectedCalls()
505	for _, expectedCall := range expectedCalls {
506		if !expectedCall.optional && !m.methodWasCalled(expectedCall.Method, expectedCall.Arguments) && expectedCall.totalCalls == 0 {
507			somethingMissing = true
508			failedExpectations++
509			t.Logf("FAIL:\t%s(%s)\n\t\tat: %s", expectedCall.Method, expectedCall.Arguments.String(), expectedCall.callerInfo)
510		} else {
511			if expectedCall.Repeatability > 0 {
512				somethingMissing = true
513				failedExpectations++
514				t.Logf("FAIL:\t%s(%s)\n\t\tat: %s", expectedCall.Method, expectedCall.Arguments.String(), expectedCall.callerInfo)
515			} else {
516				t.Logf("PASS:\t%s(%s)", expectedCall.Method, expectedCall.Arguments.String())
517			}
518		}
519	}
520
521	if somethingMissing {
522		t.Errorf("FAIL: %d out of %d expectation(s) were met.\n\tThe code you are testing needs to make %d more call(s).\n\tat: %s", len(expectedCalls)-failedExpectations, len(expectedCalls), failedExpectations, assert.CallerInfo())
523	}
524
525	return !somethingMissing
526}
527
528// AssertNumberOfCalls asserts that the method was called expectedCalls times.
529func (m *Mock) AssertNumberOfCalls(t TestingT, methodName string, expectedCalls int) bool {
530	if h, ok := t.(tHelper); ok {
531		h.Helper()
532	}
533	m.mutex.Lock()
534	defer m.mutex.Unlock()
535	var actualCalls int
536	for _, call := range m.calls() {
537		if call.Method == methodName {
538			actualCalls++
539		}
540	}
541	return assert.Equal(t, expectedCalls, actualCalls, fmt.Sprintf("Expected number of calls (%d) does not match the actual number of calls (%d).", expectedCalls, actualCalls))
542}
543
544// AssertCalled asserts that the method was called.
545// It can produce a false result when an argument is a pointer type and the underlying value changed after calling the mocked method.
546func (m *Mock) AssertCalled(t TestingT, methodName string, arguments ...interface{}) bool {
547	if h, ok := t.(tHelper); ok {
548		h.Helper()
549	}
550	m.mutex.Lock()
551	defer m.mutex.Unlock()
552	if !m.methodWasCalled(methodName, arguments) {
553		var calledWithArgs []string
554		for _, call := range m.calls() {
555			calledWithArgs = append(calledWithArgs, fmt.Sprintf("%v", call.Arguments))
556		}
557		if len(calledWithArgs) == 0 {
558			return assert.Fail(t, "Should have called with given arguments",
559				fmt.Sprintf("Expected %q to have been called with:\n%v\nbut no actual calls happened", methodName, arguments))
560		}
561		return assert.Fail(t, "Should have called with given arguments",
562			fmt.Sprintf("Expected %q to have been called with:\n%v\nbut actual calls were:\n        %v", methodName, arguments, strings.Join(calledWithArgs, "\n")))
563	}
564	return true
565}
566
567// AssertNotCalled asserts that the method was not called.
568// It can produce a false result when an argument is a pointer type and the underlying value changed after calling the mocked method.
569func (m *Mock) AssertNotCalled(t TestingT, methodName string, arguments ...interface{}) bool {
570	if h, ok := t.(tHelper); ok {
571		h.Helper()
572	}
573	m.mutex.Lock()
574	defer m.mutex.Unlock()
575	if m.methodWasCalled(methodName, arguments) {
576		return assert.Fail(t, "Should not have called with given arguments",
577			fmt.Sprintf("Expected %q to not have been called with:\n%v\nbut actually it was.", methodName, arguments))
578	}
579	return true
580}
581
582// IsMethodCallable checking that the method can be called
583// If the method was called more than `Repeatability` return false
584func (m *Mock) IsMethodCallable(t TestingT, methodName string, arguments ...interface{}) bool {
585	if h, ok := t.(tHelper); ok {
586		h.Helper()
587	}
588	m.mutex.Lock()
589	defer m.mutex.Unlock()
590
591	for _, v := range m.ExpectedCalls {
592		if v.Method != methodName {
593			continue
594		}
595		if len(arguments) != len(v.Arguments) {
596			continue
597		}
598		if v.Repeatability < v.totalCalls {
599			continue
600		}
601		if isArgsEqual(v.Arguments, arguments) {
602			return true
603		}
604	}
605	return false
606}
607
608// isArgsEqual compares arguments
609func isArgsEqual(expected Arguments, args []interface{}) bool {
610	if len(expected) != len(args) {
611		return false
612	}
613	for i, v := range args {
614		if !reflect.DeepEqual(expected[i], v) {
615			return false
616		}
617	}
618	return true
619}
620
621func (m *Mock) methodWasCalled(methodName string, expected []interface{}) bool {
622	for _, call := range m.calls() {
623		if call.Method == methodName {
624
625			_, differences := Arguments(expected).Diff(call.Arguments)
626
627			if differences == 0 {
628				// found the expected call
629				return true
630			}
631
632		}
633	}
634	// we didn't find the expected call
635	return false
636}
637
638func (m *Mock) expectedCalls() []*Call {
639	return append([]*Call{}, m.ExpectedCalls...)
640}
641
642func (m *Mock) calls() []Call {
643	return append([]Call{}, m.Calls...)
644}
645
646/*
647	Arguments
648*/
649
// Arguments holds an array of method arguments or return values.
//
// It is used both for the arguments an expectation was registered with
// (Call.Arguments) and for the values a mocked method hands back
// (Call.ReturnArguments, returned by Called/MethodCalled). Get, String, Int,
// Error and Bool read individual entries; Diff and Assert compare against
// actual values.
type Arguments []interface{}
652
// Anything is used in Diff and Assert when the argument being tested
// shouldn't be taken into consideration: it matches any value in its position.
const Anything = "mock.Anything"
658
// AnythingOfTypeArgument carries a type name for Diff and Assert. It matches
// any argument whose reflect type's Name() or String() equals that text,
// for example "string" or "*http.Request".
type AnythingOfTypeArgument string

// AnythingOfType wraps the type name t so it can be used as an expected
// argument in Diff and Assert.
//
// For example:
//	Assert(t, AnythingOfType("string"), AnythingOfType("int"))
func AnythingOfType(t string) AnythingOfTypeArgument {
	typeName := AnythingOfTypeArgument(t)
	return typeName
}
671
// IsTypeArgument holds a sample value whose dynamic type an argument must
// have exactly. It is an alternative to AnythingOfType for use in Diff and
// Assert.
type IsTypeArgument struct {
	t interface{}
}

// IsType returns an IsTypeArgument for the type of t; pass a zero value of
// the wanted type. It is an alternative to AnythingOfType for use in Diff and
// Assert.
//
// For example:
// Assert(t, IsType(""), IsType(0))
func IsType(t interface{}) *IsTypeArgument {
	arg := new(IsTypeArgument)
	arg.t = t
	return arg
}
688
// argumentMatcher performs custom argument matching, returning whether or
// not the argument is matched by the expectation fixture function.
type argumentMatcher struct {
	// fn is a function which accepts one argument, and returns a bool.
	fn reflect.Value
}

// Matches calls the matcher function with argument and returns its result.
// An untyped nil is passed as the zero value of the parameter type. That is
// only allowed for nilable parameter kinds (interface, chan, func, map,
// slice, pointer); for any other kind Matches panics. An argument whose type
// is not assignable to the parameter type never matches.
func (f argumentMatcher) Matches(argument interface{}) bool {
	expectType := f.fn.Type().In(0)
	expectTypeNilSupported := false
	switch expectType.Kind() {
	case reflect.Interface, reflect.Chan, reflect.Func, reflect.Map, reflect.Slice, reflect.Ptr:
		expectTypeNilSupported = true
	}

	argType := reflect.TypeOf(argument)
	var arg reflect.Value
	if argType == nil {
		arg = reflect.New(expectType).Elem()
	} else {
		arg = reflect.ValueOf(argument)
	}

	if argType == nil && !expectTypeNilSupported {
		panic(errors.New("attempting to call matcher with nil for non-nil expected type"))
	}
	if argType == nil || argType.AssignableTo(expectType) {
		result := f.fn.Call([]reflect.Value{arg})
		return result[0].Bool()
	}
	return false
}

// String describes the matcher as "func(<param type>) bool". It uses
// Type.String rather than Type.Name, which is empty for unnamed types such as
// *http.Request or []int.
func (f argumentMatcher) String() string {
	return fmt.Sprintf("func(%s) bool", f.fn.Type().In(0).String())
}

// MatchedBy can be used to match a mock call based on only certain properties
// from a complex struct or some calculation. It takes a function that will be
// evaluated with the called argument and will return true when there's a match
// and false otherwise.
//
// Example:
// m.On("Do", MatchedBy(func(req *http.Request) bool { return req.Host == "example.com" }))
//
// |fn|, must be a function accepting a single argument (of the expected type)
// which returns a bool. If |fn| doesn't match the required signature (or is
// nil), MatchedBy() panics with a message naming the offending type.
func MatchedBy(fn interface{}) argumentMatcher {
	fnType := reflect.TypeOf(fn)

	// reflect.TypeOf(nil) is nil; check it before calling Kind on it.
	if fnType == nil || fnType.Kind() != reflect.Func {
		panic(fmt.Sprintf("assert: arguments: %T is not a func", fn))
	}
	if fnType.NumIn() != 1 {
		panic(fmt.Sprintf("assert: arguments: %T does not take exactly one argument", fn))
	}
	if fnType.NumOut() != 1 || fnType.Out(0).Kind() != reflect.Bool {
		panic(fmt.Sprintf("assert: arguments: %T does not return a bool", fn))
	}

	return argumentMatcher{fn: reflect.ValueOf(fn)}
}
752
753// Get Returns the argument at the specified index.
754func (args Arguments) Get(index int) interface{} {
755	if index+1 > len(args) {
756		panic(fmt.Sprintf("assert: arguments: Cannot call Get(%d) because there are %d argument(s).", index, len(args)))
757	}
758	return args[index]
759}
760
761// Is gets whether the objects match the arguments specified.
762func (args Arguments) Is(objects ...interface{}) bool {
763	for i, obj := range args {
764		if obj != objects[i] {
765			return false
766		}
767	}
768	return true
769}
770
771// Diff gets a string describing the differences between the arguments
772// and the specified objects.
773//
774// Returns the diff string and number of differences found.
775func (args Arguments) Diff(objects []interface{}) (string, int) {
776	//TODO: could return string as error and nil for No difference
777
778	var output = "\n"
779	var differences int
780
781	var maxArgCount = len(args)
782	if len(objects) > maxArgCount {
783		maxArgCount = len(objects)
784	}
785
786	for i := 0; i < maxArgCount; i++ {
787		var actual, expected interface{}
788		var actualFmt, expectedFmt string
789
790		if len(objects) <= i {
791			actual = "(Missing)"
792			actualFmt = "(Missing)"
793		} else {
794			actual = objects[i]
795			actualFmt = fmt.Sprintf("(%[1]T=%[1]v)", actual)
796		}
797
798		if len(args) <= i {
799			expected = "(Missing)"
800			expectedFmt = "(Missing)"
801		} else {
802			expected = args[i]
803			expectedFmt = fmt.Sprintf("(%[1]T=%[1]v)", expected)
804		}
805
806		if matcher, ok := expected.(argumentMatcher); ok {
807			if matcher.Matches(actual) {
808				output = fmt.Sprintf("%s\t%d: PASS:  %s matched by %s\n", output, i, actualFmt, matcher)
809			} else {
810				differences++
811				output = fmt.Sprintf("%s\t%d: FAIL:  %s not matched by %s\n", output, i, actualFmt, matcher)
812			}
813		} else if reflect.TypeOf(expected) == reflect.TypeOf((*AnythingOfTypeArgument)(nil)).Elem() {
814
815			// type checking
816			if reflect.TypeOf(actual).Name() != string(expected.(AnythingOfTypeArgument)) && reflect.TypeOf(actual).String() != string(expected.(AnythingOfTypeArgument)) {
817				// not match
818				differences++
819				output = fmt.Sprintf("%s\t%d: FAIL:  type %s != type %s - %s\n", output, i, expected, reflect.TypeOf(actual).Name(), actualFmt)
820			}
821
822		} else if reflect.TypeOf(expected) == reflect.TypeOf((*IsTypeArgument)(nil)) {
823			t := expected.(*IsTypeArgument).t
824			if reflect.TypeOf(t) != reflect.TypeOf(actual) {
825				differences++
826				output = fmt.Sprintf("%s\t%d: FAIL:  type %s != type %s - %s\n", output, i, reflect.TypeOf(t).Name(), reflect.TypeOf(actual).Name(), actualFmt)
827			}
828		} else {
829
830			// normal checking
831
832			if assert.ObjectsAreEqual(expected, Anything) || assert.ObjectsAreEqual(actual, Anything) || assert.ObjectsAreEqual(actual, expected) {
833				// match
834				output = fmt.Sprintf("%s\t%d: PASS:  %s == %s\n", output, i, actualFmt, expectedFmt)
835			} else {
836				// not match
837				differences++
838				output = fmt.Sprintf("%s\t%d: FAIL:  %s != %s\n", output, i, actualFmt, expectedFmt)
839			}
840		}
841
842	}
843
844	if differences == 0 {
845		return "No differences.", differences
846	}
847
848	return output, differences
849
850}
851
852// Assert compares the arguments with the specified objects and fails if
853// they do not exactly match.
854func (args Arguments) Assert(t TestingT, objects ...interface{}) bool {
855	if h, ok := t.(tHelper); ok {
856		h.Helper()
857	}
858
859	// get the differences
860	diff, diffCount := args.Diff(objects)
861
862	if diffCount == 0 {
863		return true
864	}
865
866	// there are differences... report them...
867	t.Logf(diff)
868	t.Errorf("%sArguments do not match.", assert.CallerInfo())
869
870	return false
871
872}
873
874// String gets the argument at the specified index. Panics if there is no argument, or
875// if the argument is of the wrong type.
876//
877// If no index is provided, String() returns a complete string representation
878// of the arguments.
879func (args Arguments) String(indexOrNil ...int) string {
880
881	if len(indexOrNil) == 0 {
882		// normal String() method - return a string representation of the args
883		var argsStr []string
884		for _, arg := range args {
885			argsStr = append(argsStr, fmt.Sprintf("%T", arg)) // handles nil nicely
886		}
887		return strings.Join(argsStr, ",")
888	} else if len(indexOrNil) == 1 {
889		// Index has been specified - get the argument at that index
890		var index = indexOrNil[0]
891		var s string
892		var ok bool
893		if s, ok = args.Get(index).(string); !ok {
894			panic(fmt.Sprintf("assert: arguments: String(%d) failed because object wasn't correct type: %s", index, args.Get(index)))
895		}
896		return s
897	}
898
899	panic(fmt.Sprintf("assert: arguments: Wrong number of arguments passed to String.  Must be 0 or 1, not %d", len(indexOrNil)))
900
901}
902
903// Int gets the argument at the specified index. Panics if there is no argument, or
904// if the argument is of the wrong type.
905func (args Arguments) Int(index int) int {
906	var s int
907	var ok bool
908	if s, ok = args.Get(index).(int); !ok {
909		panic(fmt.Sprintf("assert: arguments: Int(%d) failed because object wasn't correct type: %v", index, args.Get(index)))
910	}
911	return s
912}
913
914// Error gets the argument at the specified index. Panics if there is no argument, or
915// if the argument is of the wrong type.
916func (args Arguments) Error(index int) error {
917	obj := args.Get(index)
918	var s error
919	var ok bool
920	if obj == nil {
921		return nil
922	}
923	if s, ok = obj.(error); !ok {
924		panic(fmt.Sprintf("assert: arguments: Error(%d) failed because object wasn't correct type: %v", index, args.Get(index)))
925	}
926	return s
927}
928
929// Bool gets the argument at the specified index. Panics if there is no argument, or
930// if the argument is of the wrong type.
931func (args Arguments) Bool(index int) bool {
932	var s bool
933	var ok bool
934	if s, ok = args.Get(index).(bool); !ok {
935		panic(fmt.Sprintf("assert: arguments: Bool(%d) failed because object wasn't correct type: %v", index, args.Get(index)))
936	}
937	return s
938}
939
// typeAndKind returns the reflect type of v and its kind, looking through
// exactly one level of pointer: for a *T it reports T and T's kind. v must not
// be nil, because reflect.TypeOf(nil) has no Kind.
func typeAndKind(v interface{}) (reflect.Type, reflect.Kind) {
	typ := reflect.TypeOf(v)
	if typ.Kind() == reflect.Ptr {
		typ = typ.Elem()
	}
	return typ, typ.Kind()
}
950
951func diffArguments(expected Arguments, actual Arguments) string {
952	if len(expected) != len(actual) {
953		return fmt.Sprintf("Provided %v arguments, mocked for %v arguments", len(expected), len(actual))
954	}
955
956	for x := range expected {
957		if diffString := diff(expected[x], actual[x]); diffString != "" {
958			return fmt.Sprintf("Difference found in argument %v:\n\n%s", x, diffString)
959		}
960	}
961
962	return ""
963}
964
965// diff returns a diff of both values as long as both are of the same type and
966// are a struct, map, slice or array. Otherwise it returns an empty string.
967func diff(expected interface{}, actual interface{}) string {
968	if expected == nil || actual == nil {
969		return ""
970	}
971
972	et, ek := typeAndKind(expected)
973	at, _ := typeAndKind(actual)
974
975	if et != at {
976		return ""
977	}
978
979	if ek != reflect.Struct && ek != reflect.Map && ek != reflect.Slice && ek != reflect.Array {
980		return ""
981	}
982
983	e := spewConfig.Sdump(expected)
984	a := spewConfig.Sdump(actual)
985
986	diff, _ := difflib.GetUnifiedDiffString(difflib.UnifiedDiff{
987		A:        difflib.SplitLines(e),
988		B:        difflib.SplitLines(a),
989		FromFile: "Expected",
990		FromDate: "",
991		ToFile:   "Actual",
992		ToDate:   "",
993		Context:  1,
994	})
995
996	return diff
997}
998
999var spewConfig = spew.ConfigState{
1000	Indent:                  " ",
1001	DisablePointerAddresses: true,
1002	DisableCapacities:       true,
1003	SortKeys:                true,
1004}
1005
// tHelper is implemented by test objects that can mark the calling function
// as a test helper, as *testing.T does via Helper. The assertion functions in
// this file call it when the TestingT they receive supports it.
type tHelper interface {
	Helper()
}
1009