1// Copyright 2011 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package sql
6
7import (
8	"database/sql/driver"
9	"errors"
10	"fmt"
11	"io"
12	"log"
13	"strconv"
14	"strings"
15	"sync"
16	"testing"
17	"time"
18)
19
// Reference log.Printf so the "log" import above counts as used even when
// nothing else in this file calls it (presumably so temporary debug
// logging can be added without editing the import block).
var _ = log.Printf
21
// fakeDriver is a fake database that implements Go's driver.Driver
// interface, just for testing.
//
// It speaks a query language that's semantically similar to but
// syntactically different and simpler than SQL.  The syntax is as
// follows:
//
//   WIPE
//   CREATE|<tablename>|<col>=<type>,<col>=<type>,...
//     where types are: "string", [u]int{8,16,32,64}, "bool"
//     (CREATE stores the type name verbatim; INSERT placeholders only
//     work for the type names that converterForType recognizes)
//   INSERT|<tablename>|col=val,col2=val2,col3=?
//   SELECT|<tablename>|projectcol1,projectcol2|filtercol=?,filtercol2=?
//
// When opening a fakeDriver's database, it starts empty with no
// tables.  All tables and data are stored in memory only.
type fakeDriver struct {
	mu         sync.Mutex         // guards 3 following fields
	openCount  int                // conn opens
	closeCount int                // conn closes
	dbs        map[string]*fakeDB // database name -> state; filled in lazily by getDB
}
43
// fakeDB holds the in-memory state of one named fake database. Instances
// are created on demand by fakeDriver.getDB and referenced by every
// fakeConn that Open returns for that name.
type fakeDB struct {
	name string

	mu      sync.Mutex        // held while tables is read or modified
	free    []*fakeConn       // NOTE(review): not referenced by the code visible in this file
	tables  map[string]*table // table name -> table; nil until the first createTable and again after wipe
	badConn bool              // flipped by fakeConn.isBad (without taking mu) for conns opened with the badConn option
}
52
// table is one in-memory table of a fakeDB. colname and coltype are
// parallel slices (createTable rejects mismatched lengths) and rows holds
// the data appended by INSERT statements.
type table struct {
	mu      sync.Mutex // held by execInsert and fakeStmt.Query while they touch the table
	colname []string   // column names, in the order given to CREATE
	coltype []string   // column type names; coltype[i] belongs to colname[i]
	rows    []*row     // stored rows, in insertion order
}
59
60func (t *table) columnIndex(name string) int {
61	for n, nname := range t.colname {
62		if name == nname {
63			return n
64		}
65	}
66	return -1
67}
68
// row is a single stored (or selected) record.
type row struct {
	cols []interface{} // must be same size as its table colname + coltype
}

// clone returns a new row whose cols slice is a shallow copy of r.cols:
// the slice itself is independent, but the values it holds (for example
// []byte contents) are shared with r.
func (r *row) clone() *row {
	dup := make([]interface{}, len(r.cols))
	copy(dup, r.cols)
	return &row{cols: dup}
}
78
// fakeConn is the connection type returned by fakeDriver.Open. Besides its
// per-connection state it keeps a few counters for tests.
type fakeConn struct {
	db *fakeDB // where to return ourselves to; set to nil by a successful Close

	currTx *fakeTx // non-nil between Begin and the matching Commit or Rollback

	// Stats for tests:
	mu          sync.Mutex // held by incrStat while it bumps a counter
	stmtsMade   int        // statements created by Prepare
	stmtsClosed int        // statements closed through fakeStmt.Close
	numPrepare  int        // calls to Prepare
	bad         bool       // set by Open for the ";badConn" DSN option; see isBad
}
91
92func (c *fakeConn) incrStat(v *int) {
93	c.mu.Lock()
94	*v++
95	c.mu.Unlock()
96}
97
// fakeTx is the transaction handle returned by fakeConn.Begin. It holds no
// data of its own; Commit and Rollback only clear c.currTx.
type fakeTx struct {
	c *fakeConn // connection the transaction was started on
}
101
// fakeStmt is a prepared statement in the fake query language. Prepare
// fills in the fields relevant to the parsed command; Exec and Query
// interpret them later.
type fakeStmt struct {
	c *fakeConn // connection that prepared the statement
	q string    // just for debugging

	cmd   string // command word: the text before the first '|' of the query
	table string // target table name (left empty for WIPE)

	closed bool // set by Close; Exec and Query then return errClosed

	colName      []string      // used by CREATE, INSERT, SELECT (selected columns)
	colType      []string      // used by CREATE
	colValue     []interface{} // used by INSERT (pre-bound []byte/int64 values, or the string "?" for bound params)
	placeholders int           // used by INSERT/SELECT: number of ? params

	whereCol []string // used by SELECT (all placeholders)

	placeholderConverter []driver.ValueConverter // used by INSERT: one converter per ? param, in order
}
120
// fdriver is the fakeDriver instance registered below. fakeConn.Close
// type-asserts it back to *fakeDriver in order to update closeCount.
var fdriver driver.Driver = &fakeDriver{}

// init registers fdriver with this package under the driver name "test".
func init() {
	Register("test", fdriver)
}
126
127// Supports dsn forms:
128//    <dbname>
129//    <dbname>;<opts>  (only currently supported option is `badConn`,
130//                      which causes driver.ErrBadConn to be returned on
131//                      every other conn.Begin())
132func (d *fakeDriver) Open(dsn string) (driver.Conn, error) {
133	parts := strings.Split(dsn, ";")
134	if len(parts) < 1 {
135		return nil, errors.New("fakedb: no database name")
136	}
137	name := parts[0]
138
139	db := d.getDB(name)
140
141	d.mu.Lock()
142	d.openCount++
143	d.mu.Unlock()
144	conn := &fakeConn{db: db}
145
146	if len(parts) >= 2 && parts[1] == "badConn" {
147		conn.bad = true
148	}
149	return conn, nil
150}
151
152func (d *fakeDriver) getDB(name string) *fakeDB {
153	d.mu.Lock()
154	defer d.mu.Unlock()
155	if d.dbs == nil {
156		d.dbs = make(map[string]*fakeDB)
157	}
158	db, ok := d.dbs[name]
159	if !ok {
160		db = &fakeDB{name: name}
161		d.dbs[name] = db
162	}
163	return db
164}
165
166func (db *fakeDB) wipe() {
167	db.mu.Lock()
168	defer db.mu.Unlock()
169	db.tables = nil
170}
171
172func (db *fakeDB) createTable(name string, columnNames, columnTypes []string) error {
173	db.mu.Lock()
174	defer db.mu.Unlock()
175	if db.tables == nil {
176		db.tables = make(map[string]*table)
177	}
178	if _, exist := db.tables[name]; exist {
179		return fmt.Errorf("table %q already exists", name)
180	}
181	if len(columnNames) != len(columnTypes) {
182		return fmt.Errorf("create table of %q len(names) != len(types): %d vs %d",
183			name, len(columnNames), len(columnTypes))
184	}
185	db.tables[name] = &table{colname: columnNames, coltype: columnTypes}
186	return nil
187}
188
189// must be called with db.mu lock held
190func (db *fakeDB) table(table string) (*table, bool) {
191	if db.tables == nil {
192		return nil, false
193	}
194	t, ok := db.tables[table]
195	return t, ok
196}
197
198func (db *fakeDB) columnType(table, column string) (typ string, ok bool) {
199	db.mu.Lock()
200	defer db.mu.Unlock()
201	t, ok := db.table(table)
202	if !ok {
203		return
204	}
205	for n, cname := range t.colname {
206		if cname == column {
207			return t.coltype[n], true
208		}
209	}
210	return "", false
211}
212
213func (c *fakeConn) isBad() bool {
214	// if not simulating bad conn, do nothing
215	if !c.bad {
216		return false
217	}
218	// alternate between bad conn and not bad conn
219	c.db.badConn = !c.db.badConn
220	return c.db.badConn
221}
222
223func (c *fakeConn) Begin() (driver.Tx, error) {
224	if c.isBad() {
225		return nil, driver.ErrBadConn
226	}
227	if c.currTx != nil {
228		return nil, errors.New("already in a transaction")
229	}
230	c.currTx = &fakeTx{c: c}
231	return c.currTx, nil
232}
233
234var hookPostCloseConn struct {
235	sync.Mutex
236	fn func(*fakeConn, error)
237}
238
239func setHookpostCloseConn(fn func(*fakeConn, error)) {
240	hookPostCloseConn.Lock()
241	defer hookPostCloseConn.Unlock()
242	hookPostCloseConn.fn = fn
243}
244
// testStrictClose, when non-nil, is the test that fakeConn.Close reports
// close failures to via Errorf.
var testStrictClose *testing.T

// setStrictFakeConnClose sets the t on which Errorf is called when
// fakeConn.Close fails to close. If t is nil, the check is disabled.
func setStrictFakeConnClose(t *testing.T) {
	testStrictClose = t
}
252
253func (c *fakeConn) Close() (err error) {
254	drv := fdriver.(*fakeDriver)
255	defer func() {
256		if err != nil && testStrictClose != nil {
257			testStrictClose.Errorf("failed to close a test fakeConn: %v", err)
258		}
259		hookPostCloseConn.Lock()
260		fn := hookPostCloseConn.fn
261		hookPostCloseConn.Unlock()
262		if fn != nil {
263			fn(c, err)
264		}
265		if err == nil {
266			drv.mu.Lock()
267			drv.closeCount++
268			drv.mu.Unlock()
269		}
270	}()
271	if c.currTx != nil {
272		return errors.New("can't close fakeConn; in a Transaction")
273	}
274	if c.db == nil {
275		return errors.New("can't close fakeConn; already closed")
276	}
277	if c.stmtsMade > c.stmtsClosed {
278		return errors.New("can't close; dangling statement(s)")
279	}
280	c.db = nil
281	return nil
282}
283
// checkSubsetTypes verifies that every argument has one of the types a
// driver may receive: int64, float64, bool, nil, []byte, string or
// time.Time. It returns an error naming the first offending argument
// (1-based) or nil if all arguments are acceptable.
func checkSubsetTypes(args []driver.Value) error {
	for i := range args {
		switch v := args[i].(type) {
		case nil, bool, int64, float64, string, []byte, time.Time:
			// allowed subset type
		default:
			return fmt.Errorf("fakedb_test: invalid argument #%d: %v, type %T", i+1, v, v)
		}
	}
	return nil
}
294
295func (c *fakeConn) Exec(query string, args []driver.Value) (driver.Result, error) {
296	// This is an optional interface, but it's implemented here
297	// just to check that all the args are of the proper types.
298	// ErrSkip is returned so the caller acts as if we didn't
299	// implement this at all.
300	err := checkSubsetTypes(args)
301	if err != nil {
302		return nil, err
303	}
304	return nil, driver.ErrSkip
305}
306
307func (c *fakeConn) Query(query string, args []driver.Value) (driver.Rows, error) {
308	// This is an optional interface, but it's implemented here
309	// just to check that all the args are of the proper types.
310	// ErrSkip is returned so the caller acts as if we didn't
311	// implement this at all.
312	err := checkSubsetTypes(args)
313	if err != nil {
314		return nil, err
315	}
316	return nil, driver.ErrSkip
317}
318
// errf builds an error whose text is "fakedb: " followed by msg formatted
// with args in the manner of fmt.Sprintf.
func errf(msg string, args ...interface{}) error {
	text := fmt.Sprintf(msg, args...)
	return errors.New("fakedb: " + text)
}
322
323// parts are table|selectCol1,selectCol2|whereCol=?,whereCol2=?
324// (note that where columns must always contain ? marks,
325//  just a limitation for fakedb)
326func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (driver.Stmt, error) {
327	if len(parts) != 3 {
328		stmt.Close()
329		return nil, errf("invalid SELECT syntax with %d parts; want 3", len(parts))
330	}
331	stmt.table = parts[0]
332	stmt.colName = strings.Split(parts[1], ",")
333	for n, colspec := range strings.Split(parts[2], ",") {
334		if colspec == "" {
335			continue
336		}
337		nameVal := strings.Split(colspec, "=")
338		if len(nameVal) != 2 {
339			stmt.Close()
340			return nil, errf("SELECT on table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
341		}
342		column, value := nameVal[0], nameVal[1]
343		_, ok := c.db.columnType(stmt.table, column)
344		if !ok {
345			stmt.Close()
346			return nil, errf("SELECT on table %q references non-existent column %q", stmt.table, column)
347		}
348		if value != "?" {
349			stmt.Close()
350			return nil, errf("SELECT on table %q has pre-bound value for where column %q; need a question mark",
351				stmt.table, column)
352		}
353		stmt.whereCol = append(stmt.whereCol, column)
354		stmt.placeholders++
355	}
356	return stmt, nil
357}
358
359// parts are table|col=type,col2=type2
360func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (driver.Stmt, error) {
361	if len(parts) != 2 {
362		stmt.Close()
363		return nil, errf("invalid CREATE syntax with %d parts; want 2", len(parts))
364	}
365	stmt.table = parts[0]
366	for n, colspec := range strings.Split(parts[1], ",") {
367		nameType := strings.Split(colspec, "=")
368		if len(nameType) != 2 {
369			stmt.Close()
370			return nil, errf("CREATE table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
371		}
372		stmt.colName = append(stmt.colName, nameType[0])
373		stmt.colType = append(stmt.colType, nameType[1])
374	}
375	return stmt, nil
376}
377
378// parts are table|col=?,col2=val
379func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (driver.Stmt, error) {
380	if len(parts) != 2 {
381		stmt.Close()
382		return nil, errf("invalid INSERT syntax with %d parts; want 2", len(parts))
383	}
384	stmt.table = parts[0]
385	for n, colspec := range strings.Split(parts[1], ",") {
386		nameVal := strings.Split(colspec, "=")
387		if len(nameVal) != 2 {
388			stmt.Close()
389			return nil, errf("INSERT table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
390		}
391		column, value := nameVal[0], nameVal[1]
392		ctype, ok := c.db.columnType(stmt.table, column)
393		if !ok {
394			stmt.Close()
395			return nil, errf("INSERT table %q references non-existent column %q", stmt.table, column)
396		}
397		stmt.colName = append(stmt.colName, column)
398
399		if value != "?" {
400			var subsetVal interface{}
401			// Convert to driver subset type
402			switch ctype {
403			case "string":
404				subsetVal = []byte(value)
405			case "blob":
406				subsetVal = []byte(value)
407			case "int32":
408				i, err := strconv.Atoi(value)
409				if err != nil {
410					stmt.Close()
411					return nil, errf("invalid conversion to int32 from %q", value)
412				}
413				subsetVal = int64(i) // int64 is a subset type, but not int32
414			default:
415				stmt.Close()
416				return nil, errf("unsupported conversion for pre-bound parameter %q to type %q", value, ctype)
417			}
418			stmt.colValue = append(stmt.colValue, subsetVal)
419		} else {
420			stmt.placeholders++
421			stmt.placeholderConverter = append(stmt.placeholderConverter, converterForType(ctype))
422			stmt.colValue = append(stmt.colValue, "?")
423		}
424	}
425	return stmt, nil
426}
427
428func (c *fakeConn) Prepare(query string) (driver.Stmt, error) {
429	c.numPrepare++
430	if c.db == nil {
431		panic("nil c.db; conn = " + fmt.Sprintf("%#v", c))
432	}
433	parts := strings.Split(query, "|")
434	if len(parts) < 1 {
435		return nil, errf("empty query")
436	}
437	cmd := parts[0]
438	parts = parts[1:]
439	stmt := &fakeStmt{q: query, c: c, cmd: cmd}
440	c.incrStat(&c.stmtsMade)
441	switch cmd {
442	case "WIPE":
443		// Nothing
444	case "SELECT":
445		return c.prepareSelect(stmt, parts)
446	case "CREATE":
447		return c.prepareCreate(stmt, parts)
448	case "INSERT":
449		return c.prepareInsert(stmt, parts)
450	default:
451		stmt.Close()
452		return nil, errf("unsupported command type %q", cmd)
453	}
454	return stmt, nil
455}
456
457func (s *fakeStmt) ColumnConverter(idx int) driver.ValueConverter {
458	if len(s.placeholderConverter) == 0 {
459		return driver.DefaultParameterConverter
460	}
461	return s.placeholderConverter[idx]
462}
463
464func (s *fakeStmt) Close() error {
465	if s.c == nil {
466		panic("nil conn in fakeStmt.Close")
467	}
468	if s.c.db == nil {
469		panic("in fakeStmt.Close, conn's db is nil (already closed)")
470	}
471	if !s.closed {
472		s.c.incrStat(&s.c.stmtsClosed)
473		s.closed = true
474	}
475	return nil
476}
477
// errClosed is returned by fakeStmt.Exec and fakeStmt.Query when they are
// called on a statement that has already been closed.
var errClosed = errors.New("fakedb: statement has been closed")
479
480func (s *fakeStmt) Exec(args []driver.Value) (driver.Result, error) {
481	if s.closed {
482		return nil, errClosed
483	}
484	err := checkSubsetTypes(args)
485	if err != nil {
486		return nil, err
487	}
488
489	db := s.c.db
490	switch s.cmd {
491	case "WIPE":
492		db.wipe()
493		return driver.ResultNoRows, nil
494	case "CREATE":
495		if err := db.createTable(s.table, s.colName, s.colType); err != nil {
496			return nil, err
497		}
498		return driver.ResultNoRows, nil
499	case "INSERT":
500		return s.execInsert(args)
501	}
502	fmt.Printf("EXEC statement, cmd=%q: %#v\n", s.cmd, s)
503	return nil, fmt.Errorf("unimplemented statement Exec command type of %q", s.cmd)
504}
505
506func (s *fakeStmt) execInsert(args []driver.Value) (driver.Result, error) {
507	db := s.c.db
508	if len(args) != s.placeholders {
509		panic("error in pkg db; should only get here if size is correct")
510	}
511	db.mu.Lock()
512	t, ok := db.table(s.table)
513	db.mu.Unlock()
514	if !ok {
515		return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table)
516	}
517
518	t.mu.Lock()
519	defer t.mu.Unlock()
520
521	cols := make([]interface{}, len(t.colname))
522	argPos := 0
523	for n, colname := range s.colName {
524		colidx := t.columnIndex(colname)
525		if colidx == -1 {
526			return nil, fmt.Errorf("fakedb: column %q doesn't exist or dropped since prepared statement was created", colname)
527		}
528		var val interface{}
529		if strvalue, ok := s.colValue[n].(string); ok && strvalue == "?" {
530			val = args[argPos]
531			argPos++
532		} else {
533			val = s.colValue[n]
534		}
535		cols[colidx] = val
536	}
537
538	t.rows = append(t.rows, &row{cols: cols})
539	return driver.RowsAffected(1), nil
540}
541
542func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) {
543	if s.closed {
544		return nil, errClosed
545	}
546	err := checkSubsetTypes(args)
547	if err != nil {
548		return nil, err
549	}
550
551	db := s.c.db
552	if len(args) != s.placeholders {
553		panic("error in pkg db; should only get here if size is correct")
554	}
555
556	db.mu.Lock()
557	t, ok := db.table(s.table)
558	db.mu.Unlock()
559	if !ok {
560		return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table)
561	}
562
563	if s.table == "magicquery" {
564		if len(s.whereCol) == 2 && s.whereCol[0] == "op" && s.whereCol[1] == "millis" {
565			if args[0] == "sleep" {
566				time.Sleep(time.Duration(args[1].(int64)) * time.Millisecond)
567			}
568		}
569	}
570
571	t.mu.Lock()
572	defer t.mu.Unlock()
573
574	colIdx := make(map[string]int) // select column name -> column index in table
575	for _, name := range s.colName {
576		idx := t.columnIndex(name)
577		if idx == -1 {
578			return nil, fmt.Errorf("fakedb: unknown column name %q", name)
579		}
580		colIdx[name] = idx
581	}
582
583	mrows := []*row{}
584rows:
585	for _, trow := range t.rows {
586		// Process the where clause, skipping non-match rows. This is lazy
587		// and just uses fmt.Sprintf("%v") to test equality.  Good enough
588		// for test code.
589		for widx, wcol := range s.whereCol {
590			idx := t.columnIndex(wcol)
591			if idx == -1 {
592				return nil, fmt.Errorf("db: invalid where clause column %q", wcol)
593			}
594			tcol := trow.cols[idx]
595			if bs, ok := tcol.([]byte); ok {
596				// lazy hack to avoid sprintf %v on a []byte
597				tcol = string(bs)
598			}
599			if fmt.Sprintf("%v", tcol) != fmt.Sprintf("%v", args[widx]) {
600				continue rows
601			}
602		}
603		mrow := &row{cols: make([]interface{}, len(s.colName))}
604		for seli, name := range s.colName {
605			mrow.cols[seli] = trow.cols[colIdx[name]]
606		}
607		mrows = append(mrows, mrow)
608	}
609
610	cursor := &rowsCursor{
611		pos:  -1,
612		rows: mrows,
613		cols: s.colName,
614	}
615	return cursor, nil
616}
617
// NumInput returns the number of "?" placeholders counted while the
// statement was prepared.
func (s *fakeStmt) NumInput() int {
	return s.placeholders
}
621
// Commit ends the transaction by clearing the connection's currTx, which
// allows a later Begin or Close on that connection. It always returns nil.
func (tx *fakeTx) Commit() error {
	tx.c.currTx = nil
	return nil
}
626
// Rollback ends the transaction by clearing the connection's currTx. It
// does nothing else: no table data is reverted. It always returns nil.
func (tx *fakeTx) Rollback() error {
	tx.c.currTx = nil
	return nil
}
631
// rowsCursor is the result cursor returned by fakeStmt.Query. It iterates
// over a slice of rows that was fully materialized by Query.
type rowsCursor struct {
	cols   []string // column names reported by Columns
	pos    int      // index of the current row; starts at -1, advanced by Next
	rows   []*row   // result rows, already filtered and projected
	closed bool     // set by Close; Next then returns an error

	// a clone of slices to give out to clients, indexed by the
	// original slice's first byte address.  we clone them
	// just so we're able to corrupt them on close.
	bytesClone map[*byte][]byte
}
643
644func (rc *rowsCursor) Close() error {
645	if !rc.closed {
646		for _, bs := range rc.bytesClone {
647			bs[0] = 255 // first byte corrupted
648		}
649	}
650	rc.closed = true
651	return nil
652}
653
// Columns returns the cursor's column names; fakeStmt.Query sets them to
// the statement's selected columns.
func (rc *rowsCursor) Columns() []string {
	return rc.cols
}
657
658func (rc *rowsCursor) Next(dest []driver.Value) error {
659	if rc.closed {
660		return errors.New("fakedb: cursor is closed")
661	}
662	rc.pos++
663	if rc.pos >= len(rc.rows) {
664		return io.EOF // per interface spec
665	}
666	for i, v := range rc.rows[rc.pos].cols {
667		// TODO(bradfitz): convert to subset types? naah, I
668		// think the subset types should only be input to
669		// driver, but the sql package should be able to handle
670		// a wider range of types coming out of drivers. all
671		// for ease of drivers, and to prevent drivers from
672		// messing up conversions or doing them differently.
673		dest[i] = v
674
675		if bs, ok := v.([]byte); ok {
676			if rc.bytesClone == nil {
677				rc.bytesClone = make(map[*byte][]byte)
678			}
679			clone, ok := rc.bytesClone[&bs[0]]
680			if !ok {
681				clone = make([]byte, len(bs))
682				copy(clone, bs)
683				rc.bytesClone[&bs[0]] = clone
684			}
685			dest[i] = clone
686		}
687	}
688	return nil
689}
690
// fakeDriverString is like driver.String, but additionally dereferences
// *string arguments.
//
// This could be surprising behavior to retroactively apply to
// driver.String now that Go1 is out, but this is convenient for
// our TestPointerParamsAndScans.
type fakeDriverString struct{}

// ConvertValue returns strings and byte slices unchanged, the pointed-to
// string for a non-nil *string, nil for a nil *string, and the "%v"
// formatting of anything else. It never returns an error.
func (fakeDriverString) ConvertValue(v interface{}) (driver.Value, error) {
	if p, isPtr := v.(*string); isPtr {
		if p == nil {
			return nil, nil
		}
		return *p, nil
	}
	switch v.(type) {
	case string, []byte:
		return v, nil
	}
	return fmt.Sprintf("%v", v), nil
}
712
713func converterForType(typ string) driver.ValueConverter {
714	switch typ {
715	case "bool":
716		return driver.Bool
717	case "nullbool":
718		return driver.Null{Converter: driver.Bool}
719	case "int32":
720		return driver.Int32
721	case "string":
722		return driver.NotNull{Converter: fakeDriverString{}}
723	case "nullstring":
724		return driver.Null{Converter: fakeDriverString{}}
725	case "int64":
726		// TODO(coopernurse): add type-specific converter
727		return driver.NotNull{Converter: driver.DefaultParameterConverter}
728	case "nullint64":
729		// TODO(coopernurse): add type-specific converter
730		return driver.Null{Converter: driver.DefaultParameterConverter}
731	case "float64":
732		// TODO(coopernurse): add type-specific converter
733		return driver.NotNull{Converter: driver.DefaultParameterConverter}
734	case "nullfloat64":
735		// TODO(coopernurse): add type-specific converter
736		return driver.Null{Converter: driver.DefaultParameterConverter}
737	case "datetime":
738		return driver.DefaultParameterConverter
739	}
740	panic("invalid fakedb column type of " + typ)
741}
742