1// Copyright 2011 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package sql
6
7import (
8	"context"
9	"database/sql/driver"
10	"errors"
11	"fmt"
12	"io"
13	"log"
14	"reflect"
15	"sort"
16	"strconv"
17	"strings"
18	"sync"
19	"testing"
20	"time"
21)
22
// Reference log.Printf so the "log" import is used even when nothing
// else in this file logs (presumably kept for ad-hoc debugging).
var _ = log.Printf
24
25// fakeDriver is a fake database that implements Go's driver.Driver
26// interface, just for testing.
27//
28// It speaks a query language that's semantically similar to but
29// syntactically different and simpler than SQL.  The syntax is as
30// follows:
31//
32//   WIPE
33//   CREATE|<tablename>|<col>=<type>,<col>=<type>,...
34//     where types are: "string", [u]int{8,16,32,64}, "bool"
35//   INSERT|<tablename>|col=val,col2=val2,col3=?
36//   SELECT|<tablename>|projectcol1,projectcol2|filtercol=?,filtercol2=?
37//   SELECT|<tablename>|projectcol1,projectcol2|filtercol=?param1,filtercol2=?param2
38//
39// Any of these can be preceded by PANIC|<method>|, to cause the
40// named method on fakeStmt to panic.
41//
// Any of these can be preceded by WAIT|<duration>|, to cause preparing
// and executing the statement to wait for the specified duration.
44//
45// Multiple of these can be combined when separated with a semicolon.
46//
47// When opening a fakeDriver's database, it starts empty with no
48// tables. All tables and data are stored in memory only.
type fakeDriver struct {
	mu         sync.Mutex         // guards openCount, closeCount and dbs
	openCount  int                // conn opens
	closeCount int                // conn closes
	waitCh     chan struct{}      // when non-nil, Open blocks receiving from it before returning
	waitingCh  chan struct{}      // Open sends on this just before blocking on waitCh
	dbs        map[string]*fakeDB // databases by name, created on demand by getDB
}
57
// fakeConnector is the connector returned by fakeDriverCtx.OpenConnector.
// Connect opens connections through the package-level fdriver.
type fakeConnector struct {
	name string // DSN passed to fdriver.Open by Connect

	waiter func(context.Context) // copied onto every fakeConn that Connect creates
}
63
64func (c *fakeConnector) Connect(context.Context) (driver.Conn, error) {
65	conn, err := fdriver.Open(c.name)
66	conn.(*fakeConn).waiter = c.waiter
67	return conn, err
68}
69
70func (c *fakeConnector) Driver() driver.Driver {
71	return fdriver
72}
73
// fakeDriverCtx embeds fakeDriver and adds OpenConnector on top of it.
type fakeDriverCtx struct {
	fakeDriver
}

// Compile-time check that *fakeDriverCtx satisfies driver.DriverContext.
var _ driver.DriverContext = &fakeDriverCtx{}
79
80func (cc *fakeDriverCtx) OpenConnector(name string) (driver.Connector, error) {
81	return &fakeConnector{name: name}, nil
82}
83
// fakeDB is one named in-memory database owned by a fakeDriver.
type fakeDB struct {
	name string

	mu       sync.Mutex        // held while tables is read or written
	tables   map[string]*table // nil until the first createTable, and again after wipe
	badConn  bool              // flipped by fakeConn.isBad to alternate bad / not-bad answers
	allowAny bool              // when true, checkSubsetTypes accepts arguments of any type
}
92
// table is a single fakedb table: parallel column-name and column-type
// slices plus the rows stored so far.
type table struct {
	mu      sync.Mutex
	colname []string
	coltype []string
	rows    []*row
}

// columnIndex returns the position of name within t.colname, or -1 when
// the table has no column of that name.
func (t *table) columnIndex(name string) int {
	for i := 0; i < len(t.colname); i++ {
		if t.colname[i] == name {
			return i
		}
	}
	return -1
}

// row is one stored record of a table.
type row struct {
	cols []interface{} // must be same size as its table colname + coltype
}
112
// memToucher is implemented by fakeConn and rowsCursor; fakeStmt embeds
// one and rowsCursor holds one as parentMem.
type memToucher interface {
	// touchMem reads & writes some memory, to help find data races.
	touchMem()
}
117
// fakeConn is the connection type created by fakeDriver.Open.
type fakeConn struct {
	db *fakeDB // where to return ourselves to; set to nil by a successful Close

	currTx *fakeTx // set by Begin, cleared by Commit and Rollback

	// Every operation writes to line to enable the race detector
	// check for data races.
	line int64

	// Stats for tests:
	mu          sync.Mutex // held by incrStat while a counter is incremented
	stmtsMade   int
	stmtsClosed int
	numPrepare  int

	// bad connection tests; see isBad()
	bad       bool
	stickyBad bool

	skipDirtySession bool // tests that use Conn should set this to true.

	// dirtySession tests ResetSession, true if a query has executed
	// until ResetSession is called.
	dirtySession bool

	// The waiter is called before each query. May be used in place of the "WAIT"
	// directive.
	waiter func(context.Context)
}
147
// touchMem implements memToucher by incrementing c.line, an unguarded
// read-modify-write.
func (c *fakeConn) touchMem() {
	c.line++
}
151
152func (c *fakeConn) incrStat(v *int) {
153	c.mu.Lock()
154	*v++
155	c.mu.Unlock()
156}
157
// fakeTx is the transaction value returned by fakeConn.Begin.
type fakeTx struct {
	c *fakeConn // connection the transaction was started on
}
161
// boundCol describes one placeholder of a SELECT where clause.
type boundCol struct {
	Column      string // table column being filtered on
	Placeholder string // placeholder as written in the query: "?" or "?name"
	Ordinal     int    // 1-based position among the statement's placeholders
}
167
// fakeStmt is one prepared fakedb statement. A query containing several
// semicolon-separated statements becomes a chain linked through next.
type fakeStmt struct {
	memToucher
	c *fakeConn
	q string // just for debugging

	cmd   string        // command word, e.g. "SELECT" or "INSERT"
	table string        // table the command operates on
	panic string        // name of the method that should panic (PANIC directive), if any
	wait  time.Duration // delay requested by a WAIT directive, if any

	next *fakeStmt // used for returning multiple results.

	closed bool

	colName      []string      // used by CREATE, INSERT, SELECT (selected columns)
	colType      []string      // used by CREATE
	colValue     []interface{} // used by INSERT (pre-bound values, or the "?"/"?name" placeholder string for bound params)
	placeholders int           // used by INSERT/SELECT: number of ? params

	whereCol []boundCol // used by SELECT (all placeholders)

	placeholderConverter []driver.ValueConverter // used by INSERT
}
191
// fdriver is the shared fake driver instance used throughout this file.
var fdriver driver.Driver = &fakeDriver{}

// init registers fdriver under the driver name "test".
func init() {
	Register("test", fdriver)
}
197
// contains reports whether y is an element of list.
func contains(list []string, y string) bool {
	for i := range list {
		if list[i] == y {
			return true
		}
	}
	return false
}
206
// Dummy is a placeholder driver that embeds a driver.Driver (left nil
// where it is used in this file). TestDrivers registers a Dummy value
// under the name "invalid".
type Dummy struct {
	driver.Driver
}
210
211func TestDrivers(t *testing.T) {
212	unregisterAllDrivers()
213	Register("test", fdriver)
214	Register("invalid", Dummy{})
215	all := Drivers()
216	if len(all) < 2 || !sort.StringsAreSorted(all) || !contains(all, "test") || !contains(all, "invalid") {
217		t.Fatalf("Drivers = %v, want sorted list with at least [invalid, test]", all)
218	}
219}
220
// hookOpenErr simulates connection failures: when fn is set,
// fakeDriver.Open calls it first and returns any non-nil error it yields.
// The embedded mutex guards fn.
var hookOpenErr struct {
	sync.Mutex
	fn func() error
}

// setHookOpenErr installs fn as the open hook; passing nil removes it.
func setHookOpenErr(fn func() error) {
	hookOpenErr.Lock()
	hookOpenErr.fn = fn
	hookOpenErr.Unlock()
}
232
233// Supports dsn forms:
234//    <dbname>
235//    <dbname>;<opts>  (only currently supported option is `badConn`,
236//                      which causes driver.ErrBadConn to be returned on
237//                      every other conn.Begin())
238func (d *fakeDriver) Open(dsn string) (driver.Conn, error) {
239	hookOpenErr.Lock()
240	fn := hookOpenErr.fn
241	hookOpenErr.Unlock()
242	if fn != nil {
243		if err := fn(); err != nil {
244			return nil, err
245		}
246	}
247	parts := strings.Split(dsn, ";")
248	if len(parts) < 1 {
249		return nil, errors.New("fakedb: no database name")
250	}
251	name := parts[0]
252
253	db := d.getDB(name)
254
255	d.mu.Lock()
256	d.openCount++
257	d.mu.Unlock()
258	conn := &fakeConn{db: db}
259
260	if len(parts) >= 2 && parts[1] == "badConn" {
261		conn.bad = true
262	}
263	if d.waitCh != nil {
264		d.waitingCh <- struct{}{}
265		<-d.waitCh
266		d.waitCh = nil
267		d.waitingCh = nil
268	}
269	return conn, nil
270}
271
272func (d *fakeDriver) getDB(name string) *fakeDB {
273	d.mu.Lock()
274	defer d.mu.Unlock()
275	if d.dbs == nil {
276		d.dbs = make(map[string]*fakeDB)
277	}
278	db, ok := d.dbs[name]
279	if !ok {
280		db = &fakeDB{name: name}
281		d.dbs[name] = db
282	}
283	return db
284}
285
286func (db *fakeDB) wipe() {
287	db.mu.Lock()
288	defer db.mu.Unlock()
289	db.tables = nil
290}
291
292func (db *fakeDB) createTable(name string, columnNames, columnTypes []string) error {
293	db.mu.Lock()
294	defer db.mu.Unlock()
295	if db.tables == nil {
296		db.tables = make(map[string]*table)
297	}
298	if _, exist := db.tables[name]; exist {
299		return fmt.Errorf("table %q already exists", name)
300	}
301	if len(columnNames) != len(columnTypes) {
302		return fmt.Errorf("create table of %q len(names) != len(types): %d vs %d",
303			name, len(columnNames), len(columnTypes))
304	}
305	db.tables[name] = &table{colname: columnNames, coltype: columnTypes}
306	return nil
307}
308
309// must be called with db.mu lock held
310func (db *fakeDB) table(table string) (*table, bool) {
311	if db.tables == nil {
312		return nil, false
313	}
314	t, ok := db.tables[table]
315	return t, ok
316}
317
318func (db *fakeDB) columnType(table, column string) (typ string, ok bool) {
319	db.mu.Lock()
320	defer db.mu.Unlock()
321	t, ok := db.table(table)
322	if !ok {
323		return
324	}
325	for n, cname := range t.colname {
326		if cname == column {
327			return t.coltype[n], true
328		}
329	}
330	return "", false
331}
332
333func (c *fakeConn) isBad() bool {
334	if c.stickyBad {
335		return true
336	} else if c.bad {
337		if c.db == nil {
338			return false
339		}
340		// alternate between bad conn and not bad conn
341		c.db.badConn = !c.db.badConn
342		return c.db.badConn
343	} else {
344		return false
345	}
346}
347
348func (c *fakeConn) isDirtyAndMark() bool {
349	if c.skipDirtySession {
350		return false
351	}
352	if c.currTx != nil {
353		c.dirtySession = true
354		return false
355	}
356	if c.dirtySession {
357		return true
358	}
359	c.dirtySession = true
360	return false
361}
362
363func (c *fakeConn) Begin() (driver.Tx, error) {
364	if c.isBad() {
365		return nil, driver.ErrBadConn
366	}
367	if c.currTx != nil {
368		return nil, errors.New("already in a transaction")
369	}
370	c.touchMem()
371	c.currTx = &fakeTx{c: c}
372	return c.currTx, nil
373}
374
375var hookPostCloseConn struct {
376	sync.Mutex
377	fn func(*fakeConn, error)
378}
379
380func setHookpostCloseConn(fn func(*fakeConn, error)) {
381	hookPostCloseConn.Lock()
382	defer hookPostCloseConn.Unlock()
383	hookPostCloseConn.fn = fn
384}
385
// testStrictClose, when non-nil, is the test that fakeConn.Close reports
// close failures to.
var testStrictClose *testing.T

// setStrictFakeConnClose selects the t that receives an Errorf whenever
// fakeConn.Close fails. Passing nil switches the check off.
func setStrictFakeConnClose(t *testing.T) {
	testStrictClose = t
}
393
394func (c *fakeConn) ResetSession(ctx context.Context) error {
395	c.dirtySession = false
396	if c.isBad() {
397		return driver.ErrBadConn
398	}
399	return nil
400}
401
402func (c *fakeConn) Close() (err error) {
403	drv := fdriver.(*fakeDriver)
404	defer func() {
405		if err != nil && testStrictClose != nil {
406			testStrictClose.Errorf("failed to close a test fakeConn: %v", err)
407		}
408		hookPostCloseConn.Lock()
409		fn := hookPostCloseConn.fn
410		hookPostCloseConn.Unlock()
411		if fn != nil {
412			fn(c, err)
413		}
414		if err == nil {
415			drv.mu.Lock()
416			drv.closeCount++
417			drv.mu.Unlock()
418		}
419	}()
420	c.touchMem()
421	if c.currTx != nil {
422		return errors.New("can't close fakeConn; in a Transaction")
423	}
424	if c.db == nil {
425		return errors.New("can't close fakeConn; already closed")
426	}
427	if c.stmtsMade > c.stmtsClosed {
428		return errors.New("can't close; dangling statement(s)")
429	}
430	c.db = nil
431	return nil
432}
433
// checkSubsetTypes verifies that every argument value has one of the
// types int64, float64, bool, nil, []byte, string or time.Time. With
// allowAny set, every value is accepted. The first offending argument is
// reported with its ordinal, value and type.
func checkSubsetTypes(allowAny bool, args []driver.NamedValue) error {
	if allowAny {
		return nil
	}
	for i := range args {
		arg := &args[i]
		switch arg.Value.(type) {
		case int64, float64, bool, nil, []byte, string, time.Time:
			continue
		}
		return fmt.Errorf("fakedb_test: invalid argument ordinal %[1]d: %[2]v, type %[2]T", arg.Ordinal, arg.Value)
	}
	return nil
}
446
// Exec must never be reached: fakeConn also provides ExecContext, so
// this panics to flag any caller that falls back to the plain method.
func (c *fakeConn) Exec(query string, args []driver.Value) (driver.Result, error) {
	// Ensure that ExecContext is called if available.
	panic("ExecContext was not called.")
}
451
452func (c *fakeConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
453	// This is an optional interface, but it's implemented here
454	// just to check that all the args are of the proper types.
455	// ErrSkip is returned so the caller acts as if we didn't
456	// implement this at all.
457	err := checkSubsetTypes(c.db.allowAny, args)
458	if err != nil {
459		return nil, err
460	}
461	return nil, driver.ErrSkip
462}
463
// Query must never be reached: fakeConn also provides QueryContext, so
// this panics to flag any caller that falls back to the plain method.
func (c *fakeConn) Query(query string, args []driver.Value) (driver.Rows, error) {
	// Ensure that QueryContext is called if available.
	panic("QueryContext was not called.")
}
468
469func (c *fakeConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
470	// This is an optional interface, but it's implemented here
471	// just to check that all the args are of the proper types.
472	// ErrSkip is returned so the caller acts as if we didn't
473	// implement this at all.
474	err := checkSubsetTypes(c.db.allowAny, args)
475	if err != nil {
476		return nil, err
477	}
478	return nil, driver.ErrSkip
479}
480
// errf builds an error whose text is the formatted message prefixed with
// "fakedb: ".
func errf(msg string, args ...interface{}) error {
	text := fmt.Sprintf(msg, args...)
	return errors.New("fakedb: " + text)
}
484
485// parts are table|selectCol1,selectCol2|whereCol=?,whereCol2=?
486// (note that where columns must always contain ? marks,
487//  just a limitation for fakedb)
488func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (*fakeStmt, error) {
489	if len(parts) != 3 {
490		stmt.Close()
491		return nil, errf("invalid SELECT syntax with %d parts; want 3", len(parts))
492	}
493	stmt.table = parts[0]
494
495	stmt.colName = strings.Split(parts[1], ",")
496	for n, colspec := range strings.Split(parts[2], ",") {
497		if colspec == "" {
498			continue
499		}
500		nameVal := strings.Split(colspec, "=")
501		if len(nameVal) != 2 {
502			stmt.Close()
503			return nil, errf("SELECT on table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
504		}
505		column, value := nameVal[0], nameVal[1]
506		_, ok := c.db.columnType(stmt.table, column)
507		if !ok {
508			stmt.Close()
509			return nil, errf("SELECT on table %q references non-existent column %q", stmt.table, column)
510		}
511		if !strings.HasPrefix(value, "?") {
512			stmt.Close()
513			return nil, errf("SELECT on table %q has pre-bound value for where column %q; need a question mark",
514				stmt.table, column)
515		}
516		stmt.placeholders++
517		stmt.whereCol = append(stmt.whereCol, boundCol{Column: column, Placeholder: value, Ordinal: stmt.placeholders})
518	}
519	return stmt, nil
520}
521
522// parts are table|col=type,col2=type2
523func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (*fakeStmt, error) {
524	if len(parts) != 2 {
525		stmt.Close()
526		return nil, errf("invalid CREATE syntax with %d parts; want 2", len(parts))
527	}
528	stmt.table = parts[0]
529	for n, colspec := range strings.Split(parts[1], ",") {
530		nameType := strings.Split(colspec, "=")
531		if len(nameType) != 2 {
532			stmt.Close()
533			return nil, errf("CREATE table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
534		}
535		stmt.colName = append(stmt.colName, nameType[0])
536		stmt.colType = append(stmt.colType, nameType[1])
537	}
538	return stmt, nil
539}
540
541// parts are table|col=?,col2=val
542func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (*fakeStmt, error) {
543	if len(parts) != 2 {
544		stmt.Close()
545		return nil, errf("invalid INSERT syntax with %d parts; want 2", len(parts))
546	}
547	stmt.table = parts[0]
548	for n, colspec := range strings.Split(parts[1], ",") {
549		nameVal := strings.Split(colspec, "=")
550		if len(nameVal) != 2 {
551			stmt.Close()
552			return nil, errf("INSERT table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
553		}
554		column, value := nameVal[0], nameVal[1]
555		ctype, ok := c.db.columnType(stmt.table, column)
556		if !ok {
557			stmt.Close()
558			return nil, errf("INSERT table %q references non-existent column %q", stmt.table, column)
559		}
560		stmt.colName = append(stmt.colName, column)
561
562		if !strings.HasPrefix(value, "?") {
563			var subsetVal interface{}
564			// Convert to driver subset type
565			switch ctype {
566			case "string":
567				subsetVal = []byte(value)
568			case "blob":
569				subsetVal = []byte(value)
570			case "int32":
571				i, err := strconv.Atoi(value)
572				if err != nil {
573					stmt.Close()
574					return nil, errf("invalid conversion to int32 from %q", value)
575				}
576				subsetVal = int64(i) // int64 is a subset type, but not int32
577			default:
578				stmt.Close()
579				return nil, errf("unsupported conversion for pre-bound parameter %q to type %q", value, ctype)
580			}
581			stmt.colValue = append(stmt.colValue, subsetVal)
582		} else {
583			stmt.placeholders++
584			stmt.placeholderConverter = append(stmt.placeholderConverter, converterForType(ctype))
585			stmt.colValue = append(stmt.colValue, value)
586		}
587	}
588	return stmt, nil
589}
590
// hook to simulate broken connections: when set and returning true,
// PrepareContext fails with driver.ErrBadConn.
var hookPrepareBadConn func() bool

// Prepare must never be reached; it panics to flag callers that do not
// go through PrepareContext.
func (c *fakeConn) Prepare(query string) (driver.Stmt, error) {
	panic("use PrepareContext")
}
597
598func (c *fakeConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
599	c.numPrepare++
600	if c.db == nil {
601		panic("nil c.db; conn = " + fmt.Sprintf("%#v", c))
602	}
603
604	if c.stickyBad || (hookPrepareBadConn != nil && hookPrepareBadConn()) {
605		return nil, driver.ErrBadConn
606	}
607
608	c.touchMem()
609	var firstStmt, prev *fakeStmt
610	for _, query := range strings.Split(query, ";") {
611		parts := strings.Split(query, "|")
612		if len(parts) < 1 {
613			return nil, errf("empty query")
614		}
615		stmt := &fakeStmt{q: query, c: c, memToucher: c}
616		if firstStmt == nil {
617			firstStmt = stmt
618		}
619		if len(parts) >= 3 {
620			switch parts[0] {
621			case "PANIC":
622				stmt.panic = parts[1]
623				parts = parts[2:]
624			case "WAIT":
625				wait, err := time.ParseDuration(parts[1])
626				if err != nil {
627					return nil, errf("expected section after WAIT to be a duration, got %q %v", parts[1], err)
628				}
629				parts = parts[2:]
630				stmt.wait = wait
631			}
632		}
633		cmd := parts[0]
634		stmt.cmd = cmd
635		parts = parts[1:]
636
637		if c.waiter != nil {
638			c.waiter(ctx)
639		}
640
641		if stmt.wait > 0 {
642			wait := time.NewTimer(stmt.wait)
643			select {
644			case <-wait.C:
645			case <-ctx.Done():
646				wait.Stop()
647				return nil, ctx.Err()
648			}
649		}
650
651		c.incrStat(&c.stmtsMade)
652		var err error
653		switch cmd {
654		case "WIPE":
655			// Nothing
656		case "SELECT":
657			stmt, err = c.prepareSelect(stmt, parts)
658		case "CREATE":
659			stmt, err = c.prepareCreate(stmt, parts)
660		case "INSERT":
661			stmt, err = c.prepareInsert(stmt, parts)
662		case "NOSERT":
663			// Do all the prep-work like for an INSERT but don't actually insert the row.
664			// Used for some of the concurrent tests.
665			stmt, err = c.prepareInsert(stmt, parts)
666		default:
667			stmt.Close()
668			return nil, errf("unsupported command type %q", cmd)
669		}
670		if err != nil {
671			return nil, err
672		}
673		if prev != nil {
674			prev.next = stmt
675		}
676		prev = stmt
677	}
678	return firstStmt, nil
679}
680
681func (s *fakeStmt) ColumnConverter(idx int) driver.ValueConverter {
682	if s.panic == "ColumnConverter" {
683		panic(s.panic)
684	}
685	if len(s.placeholderConverter) == 0 {
686		return driver.DefaultParameterConverter
687	}
688	return s.placeholderConverter[idx]
689}
690
691func (s *fakeStmt) Close() error {
692	if s.panic == "Close" {
693		panic(s.panic)
694	}
695	if s.c == nil {
696		panic("nil conn in fakeStmt.Close")
697	}
698	if s.c.db == nil {
699		panic("in fakeStmt.Close, conn's db is nil (already closed)")
700	}
701	s.touchMem()
702	if !s.closed {
703		s.c.incrStat(&s.c.stmtsClosed)
704		s.closed = true
705	}
706	if s.next != nil {
707		s.next.Close()
708	}
709	return nil
710}
711
// errClosed is returned by ExecContext and QueryContext on a statement
// that has already been closed.
var errClosed = errors.New("fakedb: statement has been closed")

// hook to simulate broken connections: when set and returning true,
// fakeStmt.ExecContext fails with driver.ErrBadConn.
var hookExecBadConn func() bool

// Exec must never be reached; it panics to flag callers that do not go
// through ExecContext.
func (s *fakeStmt) Exec(args []driver.Value) (driver.Result, error) {
	panic("Using ExecContext")
}
720func (s *fakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
721	if s.panic == "Exec" {
722		panic(s.panic)
723	}
724	if s.closed {
725		return nil, errClosed
726	}
727
728	if s.c.stickyBad || (hookExecBadConn != nil && hookExecBadConn()) {
729		return nil, driver.ErrBadConn
730	}
731	if s.c.isDirtyAndMark() {
732		return nil, errors.New("session is dirty")
733	}
734
735	err := checkSubsetTypes(s.c.db.allowAny, args)
736	if err != nil {
737		return nil, err
738	}
739	s.touchMem()
740
741	if s.wait > 0 {
742		time.Sleep(s.wait)
743	}
744
745	select {
746	default:
747	case <-ctx.Done():
748		return nil, ctx.Err()
749	}
750
751	db := s.c.db
752	switch s.cmd {
753	case "WIPE":
754		db.wipe()
755		return driver.ResultNoRows, nil
756	case "CREATE":
757		if err := db.createTable(s.table, s.colName, s.colType); err != nil {
758			return nil, err
759		}
760		return driver.ResultNoRows, nil
761	case "INSERT":
762		return s.execInsert(args, true)
763	case "NOSERT":
764		// Do all the prep-work like for an INSERT but don't actually insert the row.
765		// Used for some of the concurrent tests.
766		return s.execInsert(args, false)
767	}
768	fmt.Printf("EXEC statement, cmd=%q: %#v\n", s.cmd, s)
769	return nil, fmt.Errorf("unimplemented statement Exec command type of %q", s.cmd)
770}
771
772// When doInsert is true, add the row to the table.
773// When doInsert is false do prep-work and error checking, but don't
774// actually add the row to the table.
775func (s *fakeStmt) execInsert(args []driver.NamedValue, doInsert bool) (driver.Result, error) {
776	db := s.c.db
777	if len(args) != s.placeholders {
778		panic("error in pkg db; should only get here if size is correct")
779	}
780	db.mu.Lock()
781	t, ok := db.table(s.table)
782	db.mu.Unlock()
783	if !ok {
784		return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table)
785	}
786
787	t.mu.Lock()
788	defer t.mu.Unlock()
789
790	var cols []interface{}
791	if doInsert {
792		cols = make([]interface{}, len(t.colname))
793	}
794	argPos := 0
795	for n, colname := range s.colName {
796		colidx := t.columnIndex(colname)
797		if colidx == -1 {
798			return nil, fmt.Errorf("fakedb: column %q doesn't exist or dropped since prepared statement was created", colname)
799		}
800		var val interface{}
801		if strvalue, ok := s.colValue[n].(string); ok && strings.HasPrefix(strvalue, "?") {
802			if strvalue == "?" {
803				val = args[argPos].Value
804			} else {
805				// Assign value from argument placeholder name.
806				for _, a := range args {
807					if a.Name == strvalue[1:] {
808						val = a.Value
809						break
810					}
811				}
812			}
813			argPos++
814		} else {
815			val = s.colValue[n]
816		}
817		if doInsert {
818			cols[colidx] = val
819		}
820	}
821
822	if doInsert {
823		t.rows = append(t.rows, &row{cols: cols})
824	}
825	return driver.RowsAffected(1), nil
826}
827
// hook to simulate broken connections: when set and returning true,
// fakeStmt.QueryContext fails with driver.ErrBadConn.
var hookQueryBadConn func() bool

// Query must never be reached; it panics to flag callers that do not go
// through QueryContext.
func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) {
	panic("Use QueryContext")
}
834
835func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
836	if s.panic == "Query" {
837		panic(s.panic)
838	}
839	if s.closed {
840		return nil, errClosed
841	}
842
843	if s.c.stickyBad || (hookQueryBadConn != nil && hookQueryBadConn()) {
844		return nil, driver.ErrBadConn
845	}
846	if s.c.isDirtyAndMark() {
847		return nil, errors.New("session is dirty")
848	}
849
850	err := checkSubsetTypes(s.c.db.allowAny, args)
851	if err != nil {
852		return nil, err
853	}
854
855	s.touchMem()
856	db := s.c.db
857	if len(args) != s.placeholders {
858		panic("error in pkg db; should only get here if size is correct")
859	}
860
861	setMRows := make([][]*row, 0, 1)
862	setColumns := make([][]string, 0, 1)
863	setColType := make([][]string, 0, 1)
864
865	for {
866		db.mu.Lock()
867		t, ok := db.table(s.table)
868		db.mu.Unlock()
869		if !ok {
870			return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table)
871		}
872
873		if s.table == "magicquery" {
874			if len(s.whereCol) == 2 && s.whereCol[0].Column == "op" && s.whereCol[1].Column == "millis" {
875				if args[0].Value == "sleep" {
876					time.Sleep(time.Duration(args[1].Value.(int64)) * time.Millisecond)
877				}
878			}
879		}
880
881		t.mu.Lock()
882
883		colIdx := make(map[string]int) // select column name -> column index in table
884		for _, name := range s.colName {
885			idx := t.columnIndex(name)
886			if idx == -1 {
887				t.mu.Unlock()
888				return nil, fmt.Errorf("fakedb: unknown column name %q", name)
889			}
890			colIdx[name] = idx
891		}
892
893		mrows := []*row{}
894	rows:
895		for _, trow := range t.rows {
896			// Process the where clause, skipping non-match rows. This is lazy
897			// and just uses fmt.Sprintf("%v") to test equality. Good enough
898			// for test code.
899			for _, wcol := range s.whereCol {
900				idx := t.columnIndex(wcol.Column)
901				if idx == -1 {
902					t.mu.Unlock()
903					return nil, fmt.Errorf("db: invalid where clause column %q", wcol)
904				}
905				tcol := trow.cols[idx]
906				if bs, ok := tcol.([]byte); ok {
907					// lazy hack to avoid sprintf %v on a []byte
908					tcol = string(bs)
909				}
910				var argValue interface{}
911				if wcol.Placeholder == "?" {
912					argValue = args[wcol.Ordinal-1].Value
913				} else {
914					// Assign arg value from placeholder name.
915					for _, a := range args {
916						if a.Name == wcol.Placeholder[1:] {
917							argValue = a.Value
918							break
919						}
920					}
921				}
922				if fmt.Sprintf("%v", tcol) != fmt.Sprintf("%v", argValue) {
923					continue rows
924				}
925			}
926			mrow := &row{cols: make([]interface{}, len(s.colName))}
927			for seli, name := range s.colName {
928				mrow.cols[seli] = trow.cols[colIdx[name]]
929			}
930			mrows = append(mrows, mrow)
931		}
932
933		var colType []string
934		for _, column := range s.colName {
935			colType = append(colType, t.coltype[t.columnIndex(column)])
936		}
937
938		t.mu.Unlock()
939
940		setMRows = append(setMRows, mrows)
941		setColumns = append(setColumns, s.colName)
942		setColType = append(setColType, colType)
943
944		if s.next == nil {
945			break
946		}
947		s = s.next
948	}
949
950	cursor := &rowsCursor{
951		parentMem: s.c,
952		posRow:    -1,
953		rows:      setMRows,
954		cols:      setColumns,
955		colType:   setColType,
956		errPos:    -1,
957	}
958	return cursor, nil
959}
960
961func (s *fakeStmt) NumInput() int {
962	if s.panic == "NumInput" {
963		panic(s.panic)
964	}
965	return s.placeholders
966}
967
968// hook to simulate broken connections
969var hookCommitBadConn func() bool
970
971func (tx *fakeTx) Commit() error {
972	tx.c.currTx = nil
973	if hookCommitBadConn != nil && hookCommitBadConn() {
974		return driver.ErrBadConn
975	}
976	tx.c.touchMem()
977	return nil
978}
979
980// hook to simulate broken connections
981var hookRollbackBadConn func() bool
982
983func (tx *fakeTx) Rollback() error {
984	tx.c.currTx = nil
985	if hookRollbackBadConn != nil && hookRollbackBadConn() {
986		return driver.ErrBadConn
987	}
988	tx.c.touchMem()
989	return nil
990}
991
// rowsCursor is the rows value returned by fakeStmt.QueryContext. It
// holds one result set per statement of the query; cols, colType and
// rows are all indexed by result set.
type rowsCursor struct {
	parentMem memToucher
	cols      [][]string // column names, per result set
	colType   [][]string // column types, per result set
	posSet    int        // index of the current result set
	posRow    int        // index of the current row; -1 before the first Next
	rows      [][]*row   // matched rows, per result set
	closed    bool

	// errPos and err are for making Next return early with error.
	errPos int
	err    error

	// a clone of slices to give out to clients, indexed by the
	// original slice's first byte address.  we clone them
	// just so we're able to corrupt them on close.
	bytesClone map[*byte][]byte

	// Every operation writes to line to enable the race detector
	// check for data races.
	// This is separate from the fakeConn.line to allow for drivers that
	// can start multiple queries on the same transaction at the same time.
	line int64
}
1016
// touchMem implements memToucher: it touches the parent first and then
// increments the cursor's own line counter.
func (rc *rowsCursor) touchMem() {
	rc.parentMem.touchMem()
	rc.line++
}
1021
// Close marks the cursor closed; it always returns nil.
func (rc *rowsCursor) Close() error {
	rc.touchMem()
	// rc.touchMem has already touched the parent once; this touches it a
	// second time.
	rc.parentMem.touchMem()
	rc.closed = true
	return nil
}
1028
1029func (rc *rowsCursor) Columns() []string {
1030	return rc.cols[rc.posSet]
1031}
1032
1033func (rc *rowsCursor) ColumnTypeScanType(index int) reflect.Type {
1034	return colTypeToReflectType(rc.colType[rc.posSet][index])
1035}
1036
1037var rowsCursorNextHook func(dest []driver.Value) error
1038
1039func (rc *rowsCursor) Next(dest []driver.Value) error {
1040	if rowsCursorNextHook != nil {
1041		return rowsCursorNextHook(dest)
1042	}
1043
1044	if rc.closed {
1045		return errors.New("fakedb: cursor is closed")
1046	}
1047	rc.touchMem()
1048	rc.posRow++
1049	if rc.posRow == rc.errPos {
1050		return rc.err
1051	}
1052	if rc.posRow >= len(rc.rows[rc.posSet]) {
1053		return io.EOF // per interface spec
1054	}
1055	for i, v := range rc.rows[rc.posSet][rc.posRow].cols {
1056		// TODO(bradfitz): convert to subset types? naah, I
1057		// think the subset types should only be input to
1058		// driver, but the sql package should be able to handle
1059		// a wider range of types coming out of drivers. all
1060		// for ease of drivers, and to prevent drivers from
1061		// messing up conversions or doing them differently.
1062		dest[i] = v
1063
1064		if bs, ok := v.([]byte); ok {
1065			if rc.bytesClone == nil {
1066				rc.bytesClone = make(map[*byte][]byte)
1067			}
1068			clone, ok := rc.bytesClone[&bs[0]]
1069			if !ok {
1070				clone = make([]byte, len(bs))
1071				copy(clone, bs)
1072				rc.bytesClone[&bs[0]] = clone
1073			}
1074			dest[i] = clone
1075		}
1076	}
1077	return nil
1078}
1079
1080func (rc *rowsCursor) HasNextResultSet() bool {
1081	rc.touchMem()
1082	return rc.posSet < len(rc.rows)-1
1083}
1084
1085func (rc *rowsCursor) NextResultSet() error {
1086	rc.touchMem()
1087	if rc.HasNextResultSet() {
1088		rc.posSet++
1089		rc.posRow = -1
1090		return nil
1091	}
1092	return io.EOF // Per interface spec.
1093}
1094
// fakeDriverString is like driver.String, but indirects pointers like
// DefaultValueConverter.
//
// This could be surprising behavior to retroactively apply to
// driver.String now that Go1 is out, but this is convenient for
// our TestPointerParamsAndScans.
type fakeDriverString struct{}

// ConvertValue passes string and []byte values through unchanged,
// dereferences a *string (a nil pointer becomes a nil value), and
// renders anything else with fmt's %v verb. It never returns an error.
func (fakeDriverString) ConvertValue(v interface{}) (driver.Value, error) {
	if sp, isPtr := v.(*string); isPtr {
		if sp != nil {
			return *sp, nil
		}
		return nil, nil
	}
	switch v.(type) {
	case string, []byte:
		return v, nil
	}
	return fmt.Sprintf("%v", v), nil
}
1116
// anyTypeConverter is the converter used for "any" columns; it accepts
// every value as-is.
type anyTypeConverter struct{}

// ConvertValue returns v unchanged and never fails.
func (anyTypeConverter) ConvertValue(v interface{}) (driver.Value, error) {
	var out driver.Value = v
	return out, nil
}
1122
1123func converterForType(typ string) driver.ValueConverter {
1124	switch typ {
1125	case "bool":
1126		return driver.Bool
1127	case "nullbool":
1128		return driver.Null{Converter: driver.Bool}
1129	case "int32":
1130		return driver.Int32
1131	case "string":
1132		return driver.NotNull{Converter: fakeDriverString{}}
1133	case "nullstring":
1134		return driver.Null{Converter: fakeDriverString{}}
1135	case "int64":
1136		// TODO(coopernurse): add type-specific converter
1137		return driver.NotNull{Converter: driver.DefaultParameterConverter}
1138	case "nullint64":
1139		// TODO(coopernurse): add type-specific converter
1140		return driver.Null{Converter: driver.DefaultParameterConverter}
1141	case "float64":
1142		// TODO(coopernurse): add type-specific converter
1143		return driver.NotNull{Converter: driver.DefaultParameterConverter}
1144	case "nullfloat64":
1145		// TODO(coopernurse): add type-specific converter
1146		return driver.Null{Converter: driver.DefaultParameterConverter}
1147	case "datetime":
1148		return driver.DefaultParameterConverter
1149	case "any":
1150		return anyTypeConverter{}
1151	}
1152	panic("invalid fakedb column type of " + typ)
1153}
1154
1155func colTypeToReflectType(typ string) reflect.Type {
1156	switch typ {
1157	case "bool":
1158		return reflect.TypeOf(false)
1159	case "nullbool":
1160		return reflect.TypeOf(NullBool{})
1161	case "int32":
1162		return reflect.TypeOf(int32(0))
1163	case "string":
1164		return reflect.TypeOf("")
1165	case "nullstring":
1166		return reflect.TypeOf(NullString{})
1167	case "int64":
1168		return reflect.TypeOf(int64(0))
1169	case "nullint64":
1170		return reflect.TypeOf(NullInt64{})
1171	case "float64":
1172		return reflect.TypeOf(float64(0))
1173	case "nullfloat64":
1174		return reflect.TypeOf(NullFloat64{})
1175	case "datetime":
1176		return reflect.TypeOf(time.Time{})
1177	case "any":
1178		return reflect.TypeOf(new(interface{})).Elem()
1179	}
1180	panic("invalid fakedb column type of " + typ)
1181}
1182