1// Copyright 2011 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package sql
6
7import (
8	"context"
9	"database/sql/driver"
10	"errors"
11	"fmt"
12	"io"
13	"reflect"
14	"sort"
15	"strconv"
16	"strings"
17	"sync"
18	"testing"
19	"time"
20)
21
22// fakeDriver is a fake database that implements Go's driver.Driver
23// interface, just for testing.
24//
25// It speaks a query language that's semantically similar to but
26// syntactically different and simpler than SQL.  The syntax is as
27// follows:
28//
29//   WIPE
30//   CREATE|<tablename>|<col>=<type>,<col>=<type>,...
31//     where types are: "string", [u]int{8,16,32,64}, "bool"
32//   INSERT|<tablename>|col=val,col2=val2,col3=?
33//   SELECT|<tablename>|projectcol1,projectcol2|filtercol=?,filtercol2=?
34//   SELECT|<tablename>|projectcol1,projectcol2|filtercol=?param1,filtercol2=?param2
35//
36// Any of these can be preceded by PANIC|<method>|, to cause the
37// named method on fakeStmt to panic.
38//
39// Any of these can be proceeded by WAIT|<duration>|, to cause the
40// named method on fakeStmt to sleep for the specified duration.
41//
42// Multiple of these can be combined when separated with a semicolon.
43//
44// When opening a fakeDriver's database, it starts empty with no
45// tables. All tables and data are stored in memory only.
46type fakeDriver struct {
47	mu         sync.Mutex // guards 3 following fields
48	openCount  int        // conn opens
49	closeCount int        // conn closes
50	waitCh     chan struct{}
51	waitingCh  chan struct{}
52	dbs        map[string]*fakeDB
53}
54
55type fakeConnector struct {
56	name string
57
58	waiter func(context.Context)
59}
60
61func (c *fakeConnector) Connect(context.Context) (driver.Conn, error) {
62	conn, err := fdriver.Open(c.name)
63	conn.(*fakeConn).waiter = c.waiter
64	return conn, err
65}
66
67func (c *fakeConnector) Driver() driver.Driver {
68	return fdriver
69}
70
71type fakeDriverCtx struct {
72	fakeDriver
73}
74
75var _ driver.DriverContext = &fakeDriverCtx{}
76
77func (cc *fakeDriverCtx) OpenConnector(name string) (driver.Connector, error) {
78	return &fakeConnector{name: name}, nil
79}
80
81type fakeDB struct {
82	name string
83
84	mu       sync.Mutex
85	tables   map[string]*table
86	badConn  bool
87	allowAny bool
88}
89
// table is a single in-memory table. colname and coltype are parallel
// slices describing the schema; rows holds the stored records.
type table struct {
	mu      sync.Mutex
	colname []string
	coltype []string
	rows    []*row
}

// columnIndex reports the position of the first column called name,
// or -1 when the table has no such column.
func (t *table) columnIndex(name string) int {
	for i := range t.colname {
		if t.colname[i] == name {
			return i
		}
	}
	return -1
}

// row is one stored record, indexed like its table's colname/coltype.
type row struct {
	cols []interface{} // must be same size as its table colname + coltype
}
109
110type memToucher interface {
111	// touchMem reads & writes some memory, to help find data races.
112	touchMem()
113}
114
115type fakeConn struct {
116	db *fakeDB // where to return ourselves to
117
118	currTx *fakeTx
119
120	// Every operation writes to line to enable the race detector
121	// check for data races.
122	line int64
123
124	// Stats for tests:
125	mu          sync.Mutex
126	stmtsMade   int
127	stmtsClosed int
128	numPrepare  int
129
130	// bad connection tests; see isBad()
131	bad       bool
132	stickyBad bool
133
134	skipDirtySession bool // tests that use Conn should set this to true.
135
136	// dirtySession tests ResetSession, true if a query has executed
137	// until ResetSession is called.
138	dirtySession bool
139
140	// The waiter is called before each query. May be used in place of the "WAIT"
141	// directive.
142	waiter func(context.Context)
143}
144
145func (c *fakeConn) touchMem() {
146	c.line++
147}
148
149func (c *fakeConn) incrStat(v *int) {
150	c.mu.Lock()
151	*v++
152	c.mu.Unlock()
153}
154
155type fakeTx struct {
156	c *fakeConn
157}
158
159type boundCol struct {
160	Column      string
161	Placeholder string
162	Ordinal     int
163}
164
165type fakeStmt struct {
166	memToucher
167	c *fakeConn
168	q string // just for debugging
169
170	cmd   string
171	table string
172	panic string
173	wait  time.Duration
174
175	next *fakeStmt // used for returning multiple results.
176
177	closed bool
178
179	colName      []string      // used by CREATE, INSERT, SELECT (selected columns)
180	colType      []string      // used by CREATE
181	colValue     []interface{} // used by INSERT (mix of strings and "?" for bound params)
182	placeholders int           // used by INSERT/SELECT: number of ? params
183
184	whereCol []boundCol // used by SELECT (all placeholders)
185
186	placeholderConverter []driver.ValueConverter // used by INSERT
187}
188
189var fdriver driver.Driver = &fakeDriver{}
190
191func init() {
192	Register("test", fdriver)
193}
194
// contains reports whether y is an element of list.
func contains(list []string, y string) bool {
	for i := 0; i < len(list); i++ {
		if list[i] == y {
			return true
		}
	}
	return false
}
203
204type Dummy struct {
205	driver.Driver
206}
207
208func TestDrivers(t *testing.T) {
209	unregisterAllDrivers()
210	Register("test", fdriver)
211	Register("invalid", Dummy{})
212	all := Drivers()
213	if len(all) < 2 || !sort.StringsAreSorted(all) || !contains(all, "test") || !contains(all, "invalid") {
214		t.Fatalf("Drivers = %v, want sorted list with at least [invalid, test]", all)
215	}
216}
217
// hookOpenErr simulates connection failures: when fn is non-nil,
// fakeDriver.Open calls it and fails with any error it returns.
var hookOpenErr struct {
	sync.Mutex
	fn func() error
}

// setHookOpenErr installs fn as the Open failure hook; nil clears it.
func setHookOpenErr(fn func() error) {
	hookOpenErr.Lock()
	hookOpenErr.fn = fn
	hookOpenErr.Unlock()
}
229
230// Supports dsn forms:
231//    <dbname>
232//    <dbname>;<opts>  (only currently supported option is `badConn`,
233//                      which causes driver.ErrBadConn to be returned on
234//                      every other conn.Begin())
235func (d *fakeDriver) Open(dsn string) (driver.Conn, error) {
236	hookOpenErr.Lock()
237	fn := hookOpenErr.fn
238	hookOpenErr.Unlock()
239	if fn != nil {
240		if err := fn(); err != nil {
241			return nil, err
242		}
243	}
244	parts := strings.Split(dsn, ";")
245	if len(parts) < 1 {
246		return nil, errors.New("fakedb: no database name")
247	}
248	name := parts[0]
249
250	db := d.getDB(name)
251
252	d.mu.Lock()
253	d.openCount++
254	d.mu.Unlock()
255	conn := &fakeConn{db: db}
256
257	if len(parts) >= 2 && parts[1] == "badConn" {
258		conn.bad = true
259	}
260	if d.waitCh != nil {
261		d.waitingCh <- struct{}{}
262		<-d.waitCh
263		d.waitCh = nil
264		d.waitingCh = nil
265	}
266	return conn, nil
267}
268
269func (d *fakeDriver) getDB(name string) *fakeDB {
270	d.mu.Lock()
271	defer d.mu.Unlock()
272	if d.dbs == nil {
273		d.dbs = make(map[string]*fakeDB)
274	}
275	db, ok := d.dbs[name]
276	if !ok {
277		db = &fakeDB{name: name}
278		d.dbs[name] = db
279	}
280	return db
281}
282
283func (db *fakeDB) wipe() {
284	db.mu.Lock()
285	defer db.mu.Unlock()
286	db.tables = nil
287}
288
289func (db *fakeDB) createTable(name string, columnNames, columnTypes []string) error {
290	db.mu.Lock()
291	defer db.mu.Unlock()
292	if db.tables == nil {
293		db.tables = make(map[string]*table)
294	}
295	if _, exist := db.tables[name]; exist {
296		return fmt.Errorf("fakedb: table %q already exists", name)
297	}
298	if len(columnNames) != len(columnTypes) {
299		return fmt.Errorf("fakedb: create table of %q len(names) != len(types): %d vs %d",
300			name, len(columnNames), len(columnTypes))
301	}
302	db.tables[name] = &table{colname: columnNames, coltype: columnTypes}
303	return nil
304}
305
306// must be called with db.mu lock held
307func (db *fakeDB) table(table string) (*table, bool) {
308	if db.tables == nil {
309		return nil, false
310	}
311	t, ok := db.tables[table]
312	return t, ok
313}
314
315func (db *fakeDB) columnType(table, column string) (typ string, ok bool) {
316	db.mu.Lock()
317	defer db.mu.Unlock()
318	t, ok := db.table(table)
319	if !ok {
320		return
321	}
322	for n, cname := range t.colname {
323		if cname == column {
324			return t.coltype[n], true
325		}
326	}
327	return "", false
328}
329
330func (c *fakeConn) isBad() bool {
331	if c.stickyBad {
332		return true
333	} else if c.bad {
334		if c.db == nil {
335			return false
336		}
337		// alternate between bad conn and not bad conn
338		c.db.badConn = !c.db.badConn
339		return c.db.badConn
340	} else {
341		return false
342	}
343}
344
345func (c *fakeConn) isDirtyAndMark() bool {
346	if c.skipDirtySession {
347		return false
348	}
349	if c.currTx != nil {
350		c.dirtySession = true
351		return false
352	}
353	if c.dirtySession {
354		return true
355	}
356	c.dirtySession = true
357	return false
358}
359
360func (c *fakeConn) Begin() (driver.Tx, error) {
361	if c.isBad() {
362		return nil, driver.ErrBadConn
363	}
364	if c.currTx != nil {
365		return nil, errors.New("fakedb: already in a transaction")
366	}
367	c.touchMem()
368	c.currTx = &fakeTx{c: c}
369	return c.currTx, nil
370}
371
372var hookPostCloseConn struct {
373	sync.Mutex
374	fn func(*fakeConn, error)
375}
376
377func setHookpostCloseConn(fn func(*fakeConn, error)) {
378	hookPostCloseConn.Lock()
379	defer hookPostCloseConn.Unlock()
380	hookPostCloseConn.fn = fn
381}
382
383var testStrictClose *testing.T
384
385// setStrictFakeConnClose sets the t to Errorf on when fakeConn.Close
386// fails to close. If nil, the check is disabled.
387func setStrictFakeConnClose(t *testing.T) {
388	testStrictClose = t
389}
390
391func (c *fakeConn) ResetSession(ctx context.Context) error {
392	c.dirtySession = false
393	c.currTx = nil
394	if c.isBad() {
395		return driver.ErrBadConn
396	}
397	return nil
398}
399
400var _ driver.Validator = (*fakeConn)(nil)
401
402func (c *fakeConn) IsValid() bool {
403	return !c.isBad()
404}
405
406func (c *fakeConn) Close() (err error) {
407	drv := fdriver.(*fakeDriver)
408	defer func() {
409		if err != nil && testStrictClose != nil {
410			testStrictClose.Errorf("failed to close a test fakeConn: %v", err)
411		}
412		hookPostCloseConn.Lock()
413		fn := hookPostCloseConn.fn
414		hookPostCloseConn.Unlock()
415		if fn != nil {
416			fn(c, err)
417		}
418		if err == nil {
419			drv.mu.Lock()
420			drv.closeCount++
421			drv.mu.Unlock()
422		}
423	}()
424	c.touchMem()
425	if c.currTx != nil {
426		return errors.New("fakedb: can't close fakeConn; in a Transaction")
427	}
428	if c.db == nil {
429		return errors.New("fakedb: can't close fakeConn; already closed")
430	}
431	if c.stmtsMade > c.stmtsClosed {
432		return errors.New("fakedb: can't close; dangling statement(s)")
433	}
434	c.db = nil
435	return nil
436}
437
// checkSubsetTypes returns an error for the first argument whose value
// is not one of int64, float64, bool, nil, []byte, string or time.Time.
// When allowAny is set, every argument is accepted.
func checkSubsetTypes(allowAny bool, args []driver.NamedValue) error {
	if allowAny {
		return nil
	}
	for _, arg := range args {
		switch arg.Value.(type) {
		case int64, float64, bool, nil, []byte, string, time.Time:
			continue
		}
		return fmt.Errorf("fakedb: invalid argument ordinal %[1]d: %[2]v, type %[2]T", arg.Ordinal, arg.Value)
	}
	return nil
}
450
451func (c *fakeConn) Exec(query string, args []driver.Value) (driver.Result, error) {
452	// Ensure that ExecContext is called if available.
453	panic("ExecContext was not called.")
454}
455
456func (c *fakeConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
457	// This is an optional interface, but it's implemented here
458	// just to check that all the args are of the proper types.
459	// ErrSkip is returned so the caller acts as if we didn't
460	// implement this at all.
461	err := checkSubsetTypes(c.db.allowAny, args)
462	if err != nil {
463		return nil, err
464	}
465	return nil, driver.ErrSkip
466}
467
468func (c *fakeConn) Query(query string, args []driver.Value) (driver.Rows, error) {
469	// Ensure that ExecContext is called if available.
470	panic("QueryContext was not called.")
471}
472
473func (c *fakeConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
474	// This is an optional interface, but it's implemented here
475	// just to check that all the args are of the proper types.
476	// ErrSkip is returned so the caller acts as if we didn't
477	// implement this at all.
478	err := checkSubsetTypes(c.db.allowAny, args)
479	if err != nil {
480		return nil, err
481	}
482	return nil, driver.ErrSkip
483}
484
// errf formats msg with args and prefixes the result with "fakedb: ".
// It delegates to fmt.Errorf, so a %w verb wraps its operand and the
// result works with errors.Is and errors.As.
func errf(msg string, args ...interface{}) error {
	return fmt.Errorf("fakedb: "+msg, args...)
}
488
489// parts are table|selectCol1,selectCol2|whereCol=?,whereCol2=?
490// (note that where columns must always contain ? marks,
491//  just a limitation for fakedb)
492func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (*fakeStmt, error) {
493	if len(parts) != 3 {
494		stmt.Close()
495		return nil, errf("invalid SELECT syntax with %d parts; want 3", len(parts))
496	}
497	stmt.table = parts[0]
498
499	stmt.colName = strings.Split(parts[1], ",")
500	for n, colspec := range strings.Split(parts[2], ",") {
501		if colspec == "" {
502			continue
503		}
504		nameVal := strings.Split(colspec, "=")
505		if len(nameVal) != 2 {
506			stmt.Close()
507			return nil, errf("SELECT on table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
508		}
509		column, value := nameVal[0], nameVal[1]
510		_, ok := c.db.columnType(stmt.table, column)
511		if !ok {
512			stmt.Close()
513			return nil, errf("SELECT on table %q references non-existent column %q", stmt.table, column)
514		}
515		if !strings.HasPrefix(value, "?") {
516			stmt.Close()
517			return nil, errf("SELECT on table %q has pre-bound value for where column %q; need a question mark",
518				stmt.table, column)
519		}
520		stmt.placeholders++
521		stmt.whereCol = append(stmt.whereCol, boundCol{Column: column, Placeholder: value, Ordinal: stmt.placeholders})
522	}
523	return stmt, nil
524}
525
526// parts are table|col=type,col2=type2
527func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (*fakeStmt, error) {
528	if len(parts) != 2 {
529		stmt.Close()
530		return nil, errf("invalid CREATE syntax with %d parts; want 2", len(parts))
531	}
532	stmt.table = parts[0]
533	for n, colspec := range strings.Split(parts[1], ",") {
534		nameType := strings.Split(colspec, "=")
535		if len(nameType) != 2 {
536			stmt.Close()
537			return nil, errf("CREATE table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
538		}
539		stmt.colName = append(stmt.colName, nameType[0])
540		stmt.colType = append(stmt.colType, nameType[1])
541	}
542	return stmt, nil
543}
544
545// parts are table|col=?,col2=val
546func (c *fakeConn) prepareInsert(ctx context.Context, stmt *fakeStmt, parts []string) (*fakeStmt, error) {
547	if len(parts) != 2 {
548		stmt.Close()
549		return nil, errf("invalid INSERT syntax with %d parts; want 2", len(parts))
550	}
551	stmt.table = parts[0]
552	for n, colspec := range strings.Split(parts[1], ",") {
553		nameVal := strings.Split(colspec, "=")
554		if len(nameVal) != 2 {
555			stmt.Close()
556			return nil, errf("INSERT table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
557		}
558		column, value := nameVal[0], nameVal[1]
559		ctype, ok := c.db.columnType(stmt.table, column)
560		if !ok {
561			stmt.Close()
562			return nil, errf("INSERT table %q references non-existent column %q", stmt.table, column)
563		}
564		stmt.colName = append(stmt.colName, column)
565
566		if !strings.HasPrefix(value, "?") {
567			var subsetVal interface{}
568			// Convert to driver subset type
569			switch ctype {
570			case "string":
571				subsetVal = []byte(value)
572			case "blob":
573				subsetVal = []byte(value)
574			case "int32":
575				i, err := strconv.Atoi(value)
576				if err != nil {
577					stmt.Close()
578					return nil, errf("invalid conversion to int32 from %q", value)
579				}
580				subsetVal = int64(i) // int64 is a subset type, but not int32
581			case "table": // For testing cursor reads.
582				c.skipDirtySession = true
583				vparts := strings.Split(value, "!")
584
585				substmt, err := c.PrepareContext(ctx, fmt.Sprintf("SELECT|%s|%s|", vparts[0], strings.Join(vparts[1:], ",")))
586				if err != nil {
587					return nil, err
588				}
589				cursor, err := (substmt.(driver.StmtQueryContext)).QueryContext(ctx, []driver.NamedValue{})
590				substmt.Close()
591				if err != nil {
592					return nil, err
593				}
594				subsetVal = cursor
595			default:
596				stmt.Close()
597				return nil, errf("unsupported conversion for pre-bound parameter %q to type %q", value, ctype)
598			}
599			stmt.colValue = append(stmt.colValue, subsetVal)
600		} else {
601			stmt.placeholders++
602			stmt.placeholderConverter = append(stmt.placeholderConverter, converterForType(ctype))
603			stmt.colValue = append(stmt.colValue, value)
604		}
605	}
606	return stmt, nil
607}
608
609// hook to simulate broken connections
610var hookPrepareBadConn func() bool
611
612func (c *fakeConn) Prepare(query string) (driver.Stmt, error) {
613	panic("use PrepareContext")
614}
615
616func (c *fakeConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
617	c.numPrepare++
618	if c.db == nil {
619		panic("nil c.db; conn = " + fmt.Sprintf("%#v", c))
620	}
621
622	if c.stickyBad || (hookPrepareBadConn != nil && hookPrepareBadConn()) {
623		return nil, driver.ErrBadConn
624	}
625
626	c.touchMem()
627	var firstStmt, prev *fakeStmt
628	for _, query := range strings.Split(query, ";") {
629		parts := strings.Split(query, "|")
630		if len(parts) < 1 {
631			return nil, errf("empty query")
632		}
633		stmt := &fakeStmt{q: query, c: c, memToucher: c}
634		if firstStmt == nil {
635			firstStmt = stmt
636		}
637		if len(parts) >= 3 {
638			switch parts[0] {
639			case "PANIC":
640				stmt.panic = parts[1]
641				parts = parts[2:]
642			case "WAIT":
643				wait, err := time.ParseDuration(parts[1])
644				if err != nil {
645					return nil, errf("expected section after WAIT to be a duration, got %q %v", parts[1], err)
646				}
647				parts = parts[2:]
648				stmt.wait = wait
649			}
650		}
651		cmd := parts[0]
652		stmt.cmd = cmd
653		parts = parts[1:]
654
655		if c.waiter != nil {
656			c.waiter(ctx)
657		}
658
659		if stmt.wait > 0 {
660			wait := time.NewTimer(stmt.wait)
661			select {
662			case <-wait.C:
663			case <-ctx.Done():
664				wait.Stop()
665				return nil, ctx.Err()
666			}
667		}
668
669		c.incrStat(&c.stmtsMade)
670		var err error
671		switch cmd {
672		case "WIPE":
673			// Nothing
674		case "SELECT":
675			stmt, err = c.prepareSelect(stmt, parts)
676		case "CREATE":
677			stmt, err = c.prepareCreate(stmt, parts)
678		case "INSERT":
679			stmt, err = c.prepareInsert(ctx, stmt, parts)
680		case "NOSERT":
681			// Do all the prep-work like for an INSERT but don't actually insert the row.
682			// Used for some of the concurrent tests.
683			stmt, err = c.prepareInsert(ctx, stmt, parts)
684		default:
685			stmt.Close()
686			return nil, errf("unsupported command type %q", cmd)
687		}
688		if err != nil {
689			return nil, err
690		}
691		if prev != nil {
692			prev.next = stmt
693		}
694		prev = stmt
695	}
696	return firstStmt, nil
697}
698
699func (s *fakeStmt) ColumnConverter(idx int) driver.ValueConverter {
700	if s.panic == "ColumnConverter" {
701		panic(s.panic)
702	}
703	if len(s.placeholderConverter) == 0 {
704		return driver.DefaultParameterConverter
705	}
706	return s.placeholderConverter[idx]
707}
708
709func (s *fakeStmt) Close() error {
710	if s.panic == "Close" {
711		panic(s.panic)
712	}
713	if s.c == nil {
714		panic("nil conn in fakeStmt.Close")
715	}
716	if s.c.db == nil {
717		panic("in fakeStmt.Close, conn's db is nil (already closed)")
718	}
719	s.touchMem()
720	if !s.closed {
721		s.c.incrStat(&s.c.stmtsClosed)
722		s.closed = true
723	}
724	if s.next != nil {
725		s.next.Close()
726	}
727	return nil
728}
729
730var errClosed = errors.New("fakedb: statement has been closed")
731
732// hook to simulate broken connections
733var hookExecBadConn func() bool
734
735func (s *fakeStmt) Exec(args []driver.Value) (driver.Result, error) {
736	panic("Using ExecContext")
737}
738
739var errFakeConnSessionDirty = errors.New("fakedb: session is dirty")
740
741func (s *fakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
742	if s.panic == "Exec" {
743		panic(s.panic)
744	}
745	if s.closed {
746		return nil, errClosed
747	}
748
749	if s.c.stickyBad || (hookExecBadConn != nil && hookExecBadConn()) {
750		return nil, driver.ErrBadConn
751	}
752	if s.c.isDirtyAndMark() {
753		return nil, errFakeConnSessionDirty
754	}
755
756	err := checkSubsetTypes(s.c.db.allowAny, args)
757	if err != nil {
758		return nil, err
759	}
760	s.touchMem()
761
762	if s.wait > 0 {
763		time.Sleep(s.wait)
764	}
765
766	select {
767	default:
768	case <-ctx.Done():
769		return nil, ctx.Err()
770	}
771
772	db := s.c.db
773	switch s.cmd {
774	case "WIPE":
775		db.wipe()
776		return driver.ResultNoRows, nil
777	case "CREATE":
778		if err := db.createTable(s.table, s.colName, s.colType); err != nil {
779			return nil, err
780		}
781		return driver.ResultNoRows, nil
782	case "INSERT":
783		return s.execInsert(args, true)
784	case "NOSERT":
785		// Do all the prep-work like for an INSERT but don't actually insert the row.
786		// Used for some of the concurrent tests.
787		return s.execInsert(args, false)
788	}
789	return nil, fmt.Errorf("fakedb: unimplemented statement Exec command type of %q", s.cmd)
790}
791
792// When doInsert is true, add the row to the table.
793// When doInsert is false do prep-work and error checking, but don't
794// actually add the row to the table.
795func (s *fakeStmt) execInsert(args []driver.NamedValue, doInsert bool) (driver.Result, error) {
796	db := s.c.db
797	if len(args) != s.placeholders {
798		panic("error in pkg db; should only get here if size is correct")
799	}
800	db.mu.Lock()
801	t, ok := db.table(s.table)
802	db.mu.Unlock()
803	if !ok {
804		return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table)
805	}
806
807	t.mu.Lock()
808	defer t.mu.Unlock()
809
810	var cols []interface{}
811	if doInsert {
812		cols = make([]interface{}, len(t.colname))
813	}
814	argPos := 0
815	for n, colname := range s.colName {
816		colidx := t.columnIndex(colname)
817		if colidx == -1 {
818			return nil, fmt.Errorf("fakedb: column %q doesn't exist or dropped since prepared statement was created", colname)
819		}
820		var val interface{}
821		if strvalue, ok := s.colValue[n].(string); ok && strings.HasPrefix(strvalue, "?") {
822			if strvalue == "?" {
823				val = args[argPos].Value
824			} else {
825				// Assign value from argument placeholder name.
826				for _, a := range args {
827					if a.Name == strvalue[1:] {
828						val = a.Value
829						break
830					}
831				}
832			}
833			argPos++
834		} else {
835			val = s.colValue[n]
836		}
837		if doInsert {
838			cols[colidx] = val
839		}
840	}
841
842	if doInsert {
843		t.rows = append(t.rows, &row{cols: cols})
844	}
845	return driver.RowsAffected(1), nil
846}
847
848// hook to simulate broken connections
849var hookQueryBadConn func() bool
850
851func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) {
852	panic("Use QueryContext")
853}
854
855func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
856	if s.panic == "Query" {
857		panic(s.panic)
858	}
859	if s.closed {
860		return nil, errClosed
861	}
862
863	if s.c.stickyBad || (hookQueryBadConn != nil && hookQueryBadConn()) {
864		return nil, driver.ErrBadConn
865	}
866	if s.c.isDirtyAndMark() {
867		return nil, errFakeConnSessionDirty
868	}
869
870	err := checkSubsetTypes(s.c.db.allowAny, args)
871	if err != nil {
872		return nil, err
873	}
874
875	s.touchMem()
876	db := s.c.db
877	if len(args) != s.placeholders {
878		panic("error in pkg db; should only get here if size is correct")
879	}
880
881	setMRows := make([][]*row, 0, 1)
882	setColumns := make([][]string, 0, 1)
883	setColType := make([][]string, 0, 1)
884
885	for {
886		db.mu.Lock()
887		t, ok := db.table(s.table)
888		db.mu.Unlock()
889		if !ok {
890			return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table)
891		}
892
893		if s.table == "magicquery" {
894			if len(s.whereCol) == 2 && s.whereCol[0].Column == "op" && s.whereCol[1].Column == "millis" {
895				if args[0].Value == "sleep" {
896					time.Sleep(time.Duration(args[1].Value.(int64)) * time.Millisecond)
897				}
898			}
899		}
900		if s.table == "tx_status" && s.colName[0] == "tx_status" {
901			txStatus := "autocommit"
902			if s.c.currTx != nil {
903				txStatus = "transaction"
904			}
905			cursor := &rowsCursor{
906				parentMem: s.c,
907				posRow:    -1,
908				rows: [][]*row{
909					[]*row{
910						{
911							cols: []interface{}{
912								txStatus,
913							},
914						},
915					},
916				},
917				cols: [][]string{
918					[]string{
919						"tx_status",
920					},
921				},
922				colType: [][]string{
923					[]string{
924						"string",
925					},
926				},
927				errPos: -1,
928			}
929			return cursor, nil
930		}
931
932		t.mu.Lock()
933
934		colIdx := make(map[string]int) // select column name -> column index in table
935		for _, name := range s.colName {
936			idx := t.columnIndex(name)
937			if idx == -1 {
938				t.mu.Unlock()
939				return nil, fmt.Errorf("fakedb: unknown column name %q", name)
940			}
941			colIdx[name] = idx
942		}
943
944		mrows := []*row{}
945	rows:
946		for _, trow := range t.rows {
947			// Process the where clause, skipping non-match rows. This is lazy
948			// and just uses fmt.Sprintf("%v") to test equality. Good enough
949			// for test code.
950			for _, wcol := range s.whereCol {
951				idx := t.columnIndex(wcol.Column)
952				if idx == -1 {
953					t.mu.Unlock()
954					return nil, fmt.Errorf("fakedb: invalid where clause column %q", wcol)
955				}
956				tcol := trow.cols[idx]
957				if bs, ok := tcol.([]byte); ok {
958					// lazy hack to avoid sprintf %v on a []byte
959					tcol = string(bs)
960				}
961				var argValue interface{}
962				if wcol.Placeholder == "?" {
963					argValue = args[wcol.Ordinal-1].Value
964				} else {
965					// Assign arg value from placeholder name.
966					for _, a := range args {
967						if a.Name == wcol.Placeholder[1:] {
968							argValue = a.Value
969							break
970						}
971					}
972				}
973				if fmt.Sprintf("%v", tcol) != fmt.Sprintf("%v", argValue) {
974					continue rows
975				}
976			}
977			mrow := &row{cols: make([]interface{}, len(s.colName))}
978			for seli, name := range s.colName {
979				mrow.cols[seli] = trow.cols[colIdx[name]]
980			}
981			mrows = append(mrows, mrow)
982		}
983
984		var colType []string
985		for _, column := range s.colName {
986			colType = append(colType, t.coltype[t.columnIndex(column)])
987		}
988
989		t.mu.Unlock()
990
991		setMRows = append(setMRows, mrows)
992		setColumns = append(setColumns, s.colName)
993		setColType = append(setColType, colType)
994
995		if s.next == nil {
996			break
997		}
998		s = s.next
999	}
1000
1001	cursor := &rowsCursor{
1002		parentMem: s.c,
1003		posRow:    -1,
1004		rows:      setMRows,
1005		cols:      setColumns,
1006		colType:   setColType,
1007		errPos:    -1,
1008	}
1009	return cursor, nil
1010}
1011
1012func (s *fakeStmt) NumInput() int {
1013	if s.panic == "NumInput" {
1014		panic(s.panic)
1015	}
1016	return s.placeholders
1017}
1018
1019// hook to simulate broken connections
1020var hookCommitBadConn func() bool
1021
1022func (tx *fakeTx) Commit() error {
1023	tx.c.currTx = nil
1024	if hookCommitBadConn != nil && hookCommitBadConn() {
1025		return driver.ErrBadConn
1026	}
1027	tx.c.touchMem()
1028	return nil
1029}
1030
1031// hook to simulate broken connections
1032var hookRollbackBadConn func() bool
1033
1034func (tx *fakeTx) Rollback() error {
1035	tx.c.currTx = nil
1036	if hookRollbackBadConn != nil && hookRollbackBadConn() {
1037		return driver.ErrBadConn
1038	}
1039	tx.c.touchMem()
1040	return nil
1041}
1042
1043type rowsCursor struct {
1044	parentMem memToucher
1045	cols      [][]string
1046	colType   [][]string
1047	posSet    int
1048	posRow    int
1049	rows      [][]*row
1050	closed    bool
1051
1052	// errPos and err are for making Next return early with error.
1053	errPos int
1054	err    error
1055
1056	// a clone of slices to give out to clients, indexed by the
1057	// original slice's first byte address.  we clone them
1058	// just so we're able to corrupt them on close.
1059	bytesClone map[*byte][]byte
1060
1061	// Every operation writes to line to enable the race detector
1062	// check for data races.
1063	// This is separate from the fakeConn.line to allow for drivers that
1064	// can start multiple queries on the same transaction at the same time.
1065	line int64
1066}
1067
1068func (rc *rowsCursor) touchMem() {
1069	rc.parentMem.touchMem()
1070	rc.line++
1071}
1072
1073func (rc *rowsCursor) Close() error {
1074	rc.touchMem()
1075	rc.parentMem.touchMem()
1076	rc.closed = true
1077	return nil
1078}
1079
1080func (rc *rowsCursor) Columns() []string {
1081	return rc.cols[rc.posSet]
1082}
1083
1084func (rc *rowsCursor) ColumnTypeScanType(index int) reflect.Type {
1085	return colTypeToReflectType(rc.colType[rc.posSet][index])
1086}
1087
1088var rowsCursorNextHook func(dest []driver.Value) error
1089
1090func (rc *rowsCursor) Next(dest []driver.Value) error {
1091	if rowsCursorNextHook != nil {
1092		return rowsCursorNextHook(dest)
1093	}
1094
1095	if rc.closed {
1096		return errors.New("fakedb: cursor is closed")
1097	}
1098	rc.touchMem()
1099	rc.posRow++
1100	if rc.posRow == rc.errPos {
1101		return rc.err
1102	}
1103	if rc.posRow >= len(rc.rows[rc.posSet]) {
1104		return io.EOF // per interface spec
1105	}
1106	for i, v := range rc.rows[rc.posSet][rc.posRow].cols {
1107		// TODO(bradfitz): convert to subset types? naah, I
1108		// think the subset types should only be input to
1109		// driver, but the sql package should be able to handle
1110		// a wider range of types coming out of drivers. all
1111		// for ease of drivers, and to prevent drivers from
1112		// messing up conversions or doing them differently.
1113		dest[i] = v
1114
1115		if bs, ok := v.([]byte); ok {
1116			if rc.bytesClone == nil {
1117				rc.bytesClone = make(map[*byte][]byte)
1118			}
1119			clone, ok := rc.bytesClone[&bs[0]]
1120			if !ok {
1121				clone = make([]byte, len(bs))
1122				copy(clone, bs)
1123				rc.bytesClone[&bs[0]] = clone
1124			}
1125			dest[i] = clone
1126		}
1127	}
1128	return nil
1129}
1130
1131func (rc *rowsCursor) HasNextResultSet() bool {
1132	rc.touchMem()
1133	return rc.posSet < len(rc.rows)-1
1134}
1135
1136func (rc *rowsCursor) NextResultSet() error {
1137	rc.touchMem()
1138	if rc.HasNextResultSet() {
1139		rc.posSet++
1140		rc.posRow = -1
1141		return nil
1142	}
1143	return io.EOF // Per interface spec.
1144}
1145
// fakeDriverString is like driver.String, but indirects pointers like
// DefaultValueConverter. A nil *string converts to nil and a non-nil one
// to the string it points to. string and []byte pass through unchanged,
// and anything else is rendered with fmt's %v verb.
//
// Retroactively giving driver.String this behavior now that Go 1 is out
// could surprise users, but it is convenient for our
// TestPointerParamsAndScans.
type fakeDriverString struct{}

func (fakeDriverString) ConvertValue(v interface{}) (driver.Value, error) {
	if p, isPtr := v.(*string); isPtr {
		if p == nil {
			return nil, nil
		}
		return *p, nil
	}
	switch v.(type) {
	case string, []byte:
		return v, nil
	}
	return fmt.Sprintf("%v", v), nil
}
1167
// anyTypeConverter returns every value unchanged; converterForType uses
// it for the "any" column type.
type anyTypeConverter struct{}

func (anyTypeConverter) ConvertValue(v interface{}) (driver.Value, error) {
	var out driver.Value = v
	return out, nil
}
1173
1174func converterForType(typ string) driver.ValueConverter {
1175	switch typ {
1176	case "bool":
1177		return driver.Bool
1178	case "nullbool":
1179		return driver.Null{Converter: driver.Bool}
1180	case "int32":
1181		return driver.Int32
1182	case "nullint32":
1183		return driver.Null{Converter: driver.DefaultParameterConverter}
1184	case "string":
1185		return driver.NotNull{Converter: fakeDriverString{}}
1186	case "nullstring":
1187		return driver.Null{Converter: fakeDriverString{}}
1188	case "int64":
1189		// TODO(coopernurse): add type-specific converter
1190		return driver.NotNull{Converter: driver.DefaultParameterConverter}
1191	case "nullint64":
1192		// TODO(coopernurse): add type-specific converter
1193		return driver.Null{Converter: driver.DefaultParameterConverter}
1194	case "float64":
1195		// TODO(coopernurse): add type-specific converter
1196		return driver.NotNull{Converter: driver.DefaultParameterConverter}
1197	case "nullfloat64":
1198		// TODO(coopernurse): add type-specific converter
1199		return driver.Null{Converter: driver.DefaultParameterConverter}
1200	case "datetime":
1201		return driver.NotNull{Converter: driver.DefaultParameterConverter}
1202	case "nulldatetime":
1203		return driver.Null{Converter: driver.DefaultParameterConverter}
1204	case "any":
1205		return anyTypeConverter{}
1206	}
1207	panic("invalid fakedb column type of " + typ)
1208}
1209
1210func colTypeToReflectType(typ string) reflect.Type {
1211	switch typ {
1212	case "bool":
1213		return reflect.TypeOf(false)
1214	case "nullbool":
1215		return reflect.TypeOf(NullBool{})
1216	case "int32":
1217		return reflect.TypeOf(int32(0))
1218	case "nullint32":
1219		return reflect.TypeOf(NullInt32{})
1220	case "string":
1221		return reflect.TypeOf("")
1222	case "nullstring":
1223		return reflect.TypeOf(NullString{})
1224	case "int64":
1225		return reflect.TypeOf(int64(0))
1226	case "nullint64":
1227		return reflect.TypeOf(NullInt64{})
1228	case "float64":
1229		return reflect.TypeOf(float64(0))
1230	case "nullfloat64":
1231		return reflect.TypeOf(NullFloat64{})
1232	case "datetime":
1233		return reflect.TypeOf(time.Time{})
1234	case "any":
1235		return reflect.TypeOf(new(interface{})).Elem()
1236	}
1237	panic("invalid fakedb column type of " + typ)
1238}
1239