1/*
Package sqlmock is a mock library implementing an sql driver, which has one and only one
purpose - to simulate any sql driver behavior in tests, without needing a real
database connection. It helps to maintain a correct **TDD** workflow.
5
6It does not require any modifications to your source code in order to test
7and mock database operations. Supports concurrency and multiple database mocking.
8
The driver allows mocking the behavior of any sql driver method.
10*/
11package sqlmock
12
13import (
14	"database/sql"
15	"database/sql/driver"
16	"fmt"
17	"time"
18)
19
20// Sqlmock interface serves to create expectations
21// for any kind of database action in order to mock
22// and test real database behavior.
23type SqlmockCommon interface {
24	// ExpectClose queues an expectation for this database
25	// action to be triggered. the *ExpectedClose allows
26	// to mock database response
27	ExpectClose() *ExpectedClose
28
29	// ExpectationsWereMet checks whether all queued expectations
30	// were met in order. If any of them was not met - an error is returned.
31	ExpectationsWereMet() error
32
33	// ExpectPrepare expects Prepare() to be called with expectedSQL query.
34	// the *ExpectedPrepare allows to mock database response.
35	// Note that you may expect Query() or Exec() on the *ExpectedPrepare
36	// statement to prevent repeating expectedSQL
37	ExpectPrepare(expectedSQL string) *ExpectedPrepare
38
39	// ExpectQuery expects Query() or QueryRow() to be called with expectedSQL query.
40	// the *ExpectedQuery allows to mock database response.
41	ExpectQuery(expectedSQL string) *ExpectedQuery
42
43	// ExpectExec expects Exec() to be called with expectedSQL query.
44	// the *ExpectedExec allows to mock database response
45	ExpectExec(expectedSQL string) *ExpectedExec
46
47	// ExpectBegin expects *sql.DB.Begin to be called.
48	// the *ExpectedBegin allows to mock database response
49	ExpectBegin() *ExpectedBegin
50
51	// ExpectCommit expects *sql.Tx.Commit to be called.
52	// the *ExpectedCommit allows to mock database response
53	ExpectCommit() *ExpectedCommit
54
55	// ExpectRollback expects *sql.Tx.Rollback to be called.
56	// the *ExpectedRollback allows to mock database response
57	ExpectRollback() *ExpectedRollback
58
59	// ExpectPing expected *sql.DB.Ping to be called.
60	// the *ExpectedPing allows to mock database response
61	//
62	// Ping support only exists in the SQL library in Go 1.8 and above.
63	// ExpectPing in Go <=1.7 will return an ExpectedPing but not register
64	// any expectations.
65	//
66	// You must enable pings using MonitorPingsOption for this to register
67	// any expectations.
68	ExpectPing() *ExpectedPing
69
70	// MatchExpectationsInOrder gives an option whether to match all
71	// expectations in the order they were set or not.
72	//
73	// By default it is set to - true. But if you use goroutines
74	// to parallelize your query executation, that option may
75	// be handy.
76	//
77	// This option may be turned on anytime during tests. As soon
78	// as it is switched to false, expectations will be matched
79	// in any order. Or otherwise if switched to true, any unmatched
80	// expectations will be expected in order
81	MatchExpectationsInOrder(bool)
82
83	// NewRows allows Rows to be created from a
84	// sql driver.Value slice or from the CSV string and
85	// to be used as sql driver.Rows.
86	NewRows(columns []string) *Rows
87}
88
// sqlmock holds the state of one mocked database. It queues the
// expectations registered by a test and matches incoming driver calls
// against them. The same value also serves as the driver connection
// (Close, Begin, Prepare) and as the transaction returned by Begin
// (Commit, Rollback).
type sqlmock struct {
	// ordered, when true, requires each call to match the next
	// unfulfilled expectation (see MatchExpectationsInOrder).
	ordered      bool
	// dsn is the data source name passed to sql.Open. It is also the
	// key of this mock in drv.conns.
	dsn          string
	// opened counts open connections to this mock. Close decrements it
	// and removes the mock from drv.conns when it reaches zero.
	// NOTE(review): the increment is not visible in this file; it is
	// presumably done by the driver's Open. Confirm.
	opened       int
	// drv is the owning driver. Close holds its lock while updating
	// opened and drv.conns.
	drv          *mockDriver
	// converter is copied into Exec/Query expectations and into Rows
	// built by the NewRows method. open defaults it to
	// driver.DefaultParameterConverter.
	converter    driver.ValueConverter
	// queryMatcher compares expected SQL with the actual query (used by
	// prepare). open defaults it to QueryMatcherRegexp.
	queryMatcher QueryMatcher
	// monitorPings controls whether Ping calls are matched against
	// expectations. open switches it off temporarily for its startup Ping.
	monitorPings bool

	// expected is the queue of registered expectations, in the order
	// the Expect* methods appended them.
	expected []expectation
}
100
101func (c *sqlmock) open(options []func(*sqlmock) error) (*sql.DB, Sqlmock, error) {
102	db, err := sql.Open("sqlmock", c.dsn)
103	if err != nil {
104		return db, c, err
105	}
106	for _, option := range options {
107		err := option(c)
108		if err != nil {
109			return db, c, err
110		}
111	}
112	if c.converter == nil {
113		c.converter = driver.DefaultParameterConverter
114	}
115	if c.queryMatcher == nil {
116		c.queryMatcher = QueryMatcherRegexp
117	}
118
119	if c.monitorPings {
120		// We call Ping on the driver shortly to verify startup assertions by
121		// driving internal behaviour of the sql standard library. We don't
122		// want this call to ping to be monitored for expectation purposes so
123		// temporarily disable.
124		c.monitorPings = false
125		defer func() { c.monitorPings = true }()
126	}
127	return db, c, db.Ping()
128}
129
130func (c *sqlmock) ExpectClose() *ExpectedClose {
131	e := &ExpectedClose{}
132	c.expected = append(c.expected, e)
133	return e
134}
135
136func (c *sqlmock) MatchExpectationsInOrder(b bool) {
137	c.ordered = b
138}
139
140// Close a mock database driver connection. It may or may not
141// be called depending on the circumstances, but if it is called
142// there must be an *ExpectedClose expectation satisfied.
143// meets http://golang.org/pkg/database/sql/driver/#Conn interface
144func (c *sqlmock) Close() error {
145	c.drv.Lock()
146	defer c.drv.Unlock()
147
148	c.opened--
149	if c.opened == 0 {
150		delete(c.drv.conns, c.dsn)
151	}
152
153	var expected *ExpectedClose
154	var fulfilled int
155	var ok bool
156	for _, next := range c.expected {
157		next.Lock()
158		if next.fulfilled() {
159			next.Unlock()
160			fulfilled++
161			continue
162		}
163
164		if expected, ok = next.(*ExpectedClose); ok {
165			break
166		}
167
168		next.Unlock()
169		if c.ordered {
170			return fmt.Errorf("call to database Close, was not expected, next expectation is: %s", next)
171		}
172	}
173
174	if expected == nil {
175		msg := "call to database Close was not expected"
176		if fulfilled == len(c.expected) {
177			msg = "all expectations were already fulfilled, " + msg
178		}
179		return fmt.Errorf(msg)
180	}
181
182	expected.triggered = true
183	expected.Unlock()
184	return expected.err
185}
186
187func (c *sqlmock) ExpectationsWereMet() error {
188	for _, e := range c.expected {
189		e.Lock()
190		fulfilled := e.fulfilled()
191		e.Unlock()
192
193		if !fulfilled {
194			return fmt.Errorf("there is a remaining expectation which was not matched: %s", e)
195		}
196
197		// for expected prepared statement check whether it was closed if expected
198		if prep, ok := e.(*ExpectedPrepare); ok {
199			if prep.mustBeClosed && !prep.wasClosed {
200				return fmt.Errorf("expected prepared statement to be closed, but it was not: %s", prep)
201			}
202		}
203
204		// must check whether all expected queried rows are closed
205		if query, ok := e.(*ExpectedQuery); ok {
206			if query.rowsMustBeClosed && !query.rowsWereClosed {
207				return fmt.Errorf("expected query rows to be closed, but it was not: %s", query)
208			}
209		}
210	}
211	return nil
212}
213
214// Begin meets http://golang.org/pkg/database/sql/driver/#Conn interface
215func (c *sqlmock) Begin() (driver.Tx, error) {
216	ex, err := c.begin()
217	if ex != nil {
218		time.Sleep(ex.delay)
219	}
220	if err != nil {
221		return nil, err
222	}
223
224	return c, nil
225}
226
227func (c *sqlmock) begin() (*ExpectedBegin, error) {
228	var expected *ExpectedBegin
229	var ok bool
230	var fulfilled int
231	for _, next := range c.expected {
232		next.Lock()
233		if next.fulfilled() {
234			next.Unlock()
235			fulfilled++
236			continue
237		}
238
239		if expected, ok = next.(*ExpectedBegin); ok {
240			break
241		}
242
243		next.Unlock()
244		if c.ordered {
245			return nil, fmt.Errorf("call to database transaction Begin, was not expected, next expectation is: %s", next)
246		}
247	}
248	if expected == nil {
249		msg := "call to database transaction Begin was not expected"
250		if fulfilled == len(c.expected) {
251			msg = "all expectations were already fulfilled, " + msg
252		}
253		return nil, fmt.Errorf(msg)
254	}
255
256	expected.triggered = true
257	expected.Unlock()
258
259	return expected, expected.err
260}
261
262func (c *sqlmock) ExpectBegin() *ExpectedBegin {
263	e := &ExpectedBegin{}
264	c.expected = append(c.expected, e)
265	return e
266}
267
268func (c *sqlmock) ExpectExec(expectedSQL string) *ExpectedExec {
269	e := &ExpectedExec{}
270	e.expectSQL = expectedSQL
271	e.converter = c.converter
272	c.expected = append(c.expected, e)
273	return e
274}
275
276// Prepare meets http://golang.org/pkg/database/sql/driver/#Conn interface
277func (c *sqlmock) Prepare(query string) (driver.Stmt, error) {
278	ex, err := c.prepare(query)
279	if ex != nil {
280		time.Sleep(ex.delay)
281	}
282	if err != nil {
283		return nil, err
284	}
285
286	return &statement{c, ex, query}, nil
287}
288
289func (c *sqlmock) prepare(query string) (*ExpectedPrepare, error) {
290	var expected *ExpectedPrepare
291	var fulfilled int
292	var ok bool
293
294	for _, next := range c.expected {
295		next.Lock()
296		if next.fulfilled() {
297			next.Unlock()
298			fulfilled++
299			continue
300		}
301
302		if c.ordered {
303			if expected, ok = next.(*ExpectedPrepare); ok {
304				break
305			}
306
307			next.Unlock()
308			return nil, fmt.Errorf("call to Prepare statement with query '%s', was not expected, next expectation is: %s", query, next)
309		}
310
311		if pr, ok := next.(*ExpectedPrepare); ok {
312			if err := c.queryMatcher.Match(pr.expectSQL, query); err == nil {
313				expected = pr
314				break
315			}
316		}
317		next.Unlock()
318	}
319
320	if expected == nil {
321		msg := "call to Prepare '%s' query was not expected"
322		if fulfilled == len(c.expected) {
323			msg = "all expectations were already fulfilled, " + msg
324		}
325		return nil, fmt.Errorf(msg, query)
326	}
327	defer expected.Unlock()
328	if err := c.queryMatcher.Match(expected.expectSQL, query); err != nil {
329		return nil, fmt.Errorf("Prepare: %v", err)
330	}
331
332	expected.triggered = true
333	return expected, expected.err
334}
335
336func (c *sqlmock) ExpectPrepare(expectedSQL string) *ExpectedPrepare {
337	e := &ExpectedPrepare{expectSQL: expectedSQL, mock: c}
338	c.expected = append(c.expected, e)
339	return e
340}
341
342func (c *sqlmock) ExpectQuery(expectedSQL string) *ExpectedQuery {
343	e := &ExpectedQuery{}
344	e.expectSQL = expectedSQL
345	e.converter = c.converter
346	c.expected = append(c.expected, e)
347	return e
348}
349
350func (c *sqlmock) ExpectCommit() *ExpectedCommit {
351	e := &ExpectedCommit{}
352	c.expected = append(c.expected, e)
353	return e
354}
355
356func (c *sqlmock) ExpectRollback() *ExpectedRollback {
357	e := &ExpectedRollback{}
358	c.expected = append(c.expected, e)
359	return e
360}
361
362// Commit meets http://golang.org/pkg/database/sql/driver/#Tx
363func (c *sqlmock) Commit() error {
364	var expected *ExpectedCommit
365	var fulfilled int
366	var ok bool
367	for _, next := range c.expected {
368		next.Lock()
369		if next.fulfilled() {
370			next.Unlock()
371			fulfilled++
372			continue
373		}
374
375		if expected, ok = next.(*ExpectedCommit); ok {
376			break
377		}
378
379		next.Unlock()
380		if c.ordered {
381			return fmt.Errorf("call to Commit transaction, was not expected, next expectation is: %s", next)
382		}
383	}
384	if expected == nil {
385		msg := "call to Commit transaction was not expected"
386		if fulfilled == len(c.expected) {
387			msg = "all expectations were already fulfilled, " + msg
388		}
389		return fmt.Errorf(msg)
390	}
391
392	expected.triggered = true
393	expected.Unlock()
394	return expected.err
395}
396
397// Rollback meets http://golang.org/pkg/database/sql/driver/#Tx
398func (c *sqlmock) Rollback() error {
399	var expected *ExpectedRollback
400	var fulfilled int
401	var ok bool
402	for _, next := range c.expected {
403		next.Lock()
404		if next.fulfilled() {
405			next.Unlock()
406			fulfilled++
407			continue
408		}
409
410		if expected, ok = next.(*ExpectedRollback); ok {
411			break
412		}
413
414		next.Unlock()
415		if c.ordered {
416			return fmt.Errorf("call to Rollback transaction, was not expected, next expectation is: %s", next)
417		}
418	}
419	if expected == nil {
420		msg := "call to Rollback transaction was not expected"
421		if fulfilled == len(c.expected) {
422			msg = "all expectations were already fulfilled, " + msg
423		}
424		return fmt.Errorf(msg)
425	}
426
427	expected.triggered = true
428	expected.Unlock()
429	return expected.err
430}
431
432// NewRows allows Rows to be created from a
433// sql driver.Value slice or from the CSV string and
434// to be used as sql driver.Rows.
435func (c *sqlmock) NewRows(columns []string) *Rows {
436	r := NewRows(columns)
437	r.converter = c.converter
438	return r
439}
440