1package gock
2
3import (
4	"net/http"
5	"sync"
6)
7
8// Mock represents the required interface that must
9// be implemented by HTTP mock instances.
10type Mock interface {
11	// Disable disables the current mock manually.
12	Disable()
13
14	// Done returns true if the current mock is disabled.
15	Done() bool
16
17	// Request returns the mock Request instance.
18	Request() *Request
19
20	// Response returns the mock Response instance.
21	Response() *Response
22
23	// Match matches the given http.Request with the current mock.
24	Match(*http.Request) (bool, error)
25
26	// AddMatcher adds a new matcher function.
27	AddMatcher(MatchFunc)
28
29	// SetMatcher uses a new matcher implementation.
30	SetMatcher(Matcher)
31}
32
33// Mocker implements a Mock capable interface providing
34// a default mock configuration used internally to store mocks.
35type Mocker struct {
36	// disabler stores a disabler for thread safety checking current mock is disabled
37	disabler *disabler
38
39	// mutex stores the mock mutex for thread safety.
40	mutex sync.Mutex
41
42	// matcher stores a Matcher capable instance to match the given http.Request.
43	matcher Matcher
44
45	// request stores the mock Request to match.
46	request *Request
47
48	// response stores the mock Response to use in case of match.
49	response *Response
50}
51
// disabler is a small lock-protected boolean flag that records whether
// a mock has been switched off.
type disabler struct {
	// disabled is true once the mock has been disabled.
	disabled bool

	// mutex guards disabled.
	mutex sync.RWMutex
}

// isDisabled reports the current value of the flag, read under a shared lock.
func (d *disabler) isDisabled() bool {
	d.mutex.RLock()
	state := d.disabled
	d.mutex.RUnlock()
	return state
}

// Disable sets the flag under an exclusive lock. Calling it more than once
// leaves the flag set.
func (d *disabler) Disable() {
	d.mutex.Lock()
	d.disabled = true
	d.mutex.Unlock()
}
71
72// NewMock creates a new HTTP mock based on the given request and response instances.
73// It's mostly used internally.
74func NewMock(req *Request, res *Response) *Mocker {
75	mock := &Mocker{
76		disabler: new(disabler),
77		request:  req,
78		response: res,
79		matcher:  DefaultMatcher.Clone(),
80	}
81	res.Mock = mock
82	req.Mock = mock
83	req.Response = res
84	return mock
85}
86
87// Disable disables the current mock manually.
88func (m *Mocker) Disable() {
89	m.disabler.Disable()
90}
91
92// Done returns true in case that the current mock
93// instance is disabled and therefore must be removed.
94func (m *Mocker) Done() bool {
95	// prevent deadlock with m.mutex
96	if m.disabler.isDisabled() {
97		return true
98	}
99
100	m.mutex.Lock()
101	defer m.mutex.Unlock()
102	return !m.request.Persisted && m.request.Counter == 0
103}
104
105// Request returns the Request instance
106// configured for the current HTTP mock.
107func (m *Mocker) Request() *Request {
108	return m.request
109}
110
111// Response returns the Response instance
112// configured for the current HTTP mock.
113func (m *Mocker) Response() *Response {
114	return m.response
115}
116
117// Match matches the given http.Request with the current Request
118// mock expectation, returning true if matches.
119func (m *Mocker) Match(req *http.Request) (bool, error) {
120	if m.disabler.isDisabled() {
121		return false, nil
122	}
123
124	// Filter
125	for _, filter := range m.request.Filters {
126		if !filter(req) {
127			return false, nil
128		}
129	}
130
131	// Map
132	for _, mapper := range m.request.Mappers {
133		if treq := mapper(req); treq != nil {
134			req = treq
135		}
136	}
137
138	// Match
139	matches, err := m.matcher.Match(req, m.request)
140	if matches {
141		m.decrement()
142	}
143
144	return matches, err
145}
146
147// SetMatcher sets a new matcher implementation
148// for the current mock expectation.
149func (m *Mocker) SetMatcher(matcher Matcher) {
150	m.matcher = matcher
151}
152
153// AddMatcher adds a new matcher function
154// for the current mock expectation.
155func (m *Mocker) AddMatcher(fn MatchFunc) {
156	m.matcher.Add(fn)
157}
158
159// decrement decrements the current mock Request counter.
160func (m *Mocker) decrement() {
161	if m.request.Persisted {
162		return
163	}
164
165	m.mutex.Lock()
166	defer m.mutex.Unlock()
167
168	m.request.Counter--
169	if m.request.Counter == 0 {
170		m.disabler.Disable()
171	}
172}
173