1// Copyright 2011 Google Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package gomock
16
17import (
18	"bytes"
19	"errors"
20	"fmt"
21)
22
23// callSet represents a set of expected calls, indexed by receiver and method
24// name.
25type callSet struct {
26	// Calls that are still expected.
27	expected map[callSetKey][]*Call
28	// Calls that have been exhausted.
29	exhausted map[callSetKey][]*Call
30}
31
// callSetKey is the key of both maps in callSet: calls are grouped by the
// receiver they were registered on together with the method name (fname).
//
// The field order matters, because keys in this file are built with
// positional composite literals (callSetKey{receiver, method}). Since the
// key is used in a map and receiver is an interface{}, the dynamic value
// stored in receiver must be comparable; a non-comparable one makes the
// map operation panic at run time.
type callSetKey struct {
	receiver interface{}
	fname    string
}
37
38func newCallSet() *callSet {
39	return &callSet{make(map[callSetKey][]*Call), make(map[callSetKey][]*Call)}
40}
41
42// Add adds a new expected call.
43func (cs callSet) Add(call *Call) {
44	key := callSetKey{call.receiver, call.method}
45	m := cs.expected
46	if call.exhausted() {
47		m = cs.exhausted
48	}
49	m[key] = append(m[key], call)
50}
51
52// Remove removes an expected call.
53func (cs callSet) Remove(call *Call) {
54	key := callSetKey{call.receiver, call.method}
55	calls := cs.expected[key]
56	for i, c := range calls {
57		if c == call {
58			// maintain order for remaining calls
59			cs.expected[key] = append(calls[:i], calls[i+1:]...)
60			cs.exhausted[key] = append(cs.exhausted[key], call)
61			break
62		}
63	}
64}
65
66// FindMatch searches for a matching call. Returns error with explanation message if no call matched.
67func (cs callSet) FindMatch(receiver interface{}, method string, args []interface{}) (*Call, error) {
68	key := callSetKey{receiver, method}
69
70	// Search through the expected calls.
71	expected := cs.expected[key]
72	var callsErrors bytes.Buffer
73	for _, call := range expected {
74		err := call.matches(args)
75		if err != nil {
76			_, _ = fmt.Fprintf(&callsErrors, "\n%v", err)
77		} else {
78			return call, nil
79		}
80	}
81
82	// If we haven't found a match then search through the exhausted calls so we
83	// get useful error messages.
84	exhausted := cs.exhausted[key]
85	for _, call := range exhausted {
86		if err := call.matches(args); err != nil {
87			_, _ = fmt.Fprintf(&callsErrors, "\n%v", err)
88			continue
89		}
90		_, _ = fmt.Fprintf(
91			&callsErrors, "all expected calls for method %q have been exhausted", method,
92		)
93	}
94
95	if len(expected)+len(exhausted) == 0 {
96		_, _ = fmt.Fprintf(&callsErrors, "there are no expected calls of the method %q for that receiver", method)
97	}
98
99	return nil, errors.New(callsErrors.String())
100}
101
102// Failures returns the calls that are not satisfied.
103func (cs callSet) Failures() []*Call {
104	failures := make([]*Call, 0, len(cs.expected))
105	for _, calls := range cs.expected {
106		for _, call := range calls {
107			if !call.satisfied() {
108				failures = append(failures, call)
109			}
110		}
111	}
112	return failures
113}
114