1// Copyright 2011 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package sql
6
7import (
8	"context"
9	"database/sql/driver"
10	"errors"
11	"fmt"
12	"io"
13	"reflect"
14	"sort"
15	"strconv"
16	"strings"
17	"sync"
18	"testing"
19	"time"
20)
21
22// fakeDriver is a fake database that implements Go's driver.Driver
23// interface, just for testing.
24//
25// It speaks a query language that's semantically similar to but
26// syntactically different and simpler than SQL.  The syntax is as
27// follows:
28//
29//   WIPE
30//   CREATE|<tablename>|<col>=<type>,<col>=<type>,...
31//     where types are: "string", [u]int{8,16,32,64}, "bool"
32//   INSERT|<tablename>|col=val,col2=val2,col3=?
33//   SELECT|<tablename>|projectcol1,projectcol2|filtercol=?,filtercol2=?
34//   SELECT|<tablename>|projectcol1,projectcol2|filtercol=?param1,filtercol2=?param2
35//
36// Any of these can be preceded by PANIC|<method>|, to cause the
37// named method on fakeStmt to panic.
38//
39// Any of these can be proceeded by WAIT|<duration>|, to cause the
40// named method on fakeStmt to sleep for the specified duration.
41//
42// Multiple of these can be combined when separated with a semicolon.
43//
44// When opening a fakeDriver's database, it starts empty with no
45// tables. All tables and data are stored in memory only.
46type fakeDriver struct {
47	mu         sync.Mutex // guards 3 following fields
48	openCount  int        // conn opens
49	closeCount int        // conn closes
50	waitCh     chan struct{}
51	waitingCh  chan struct{}
52	dbs        map[string]*fakeDB
53}
54
// fakeConnector is a driver.Connector for the fake driver. Each connection
// it creates opens the fakedb database called name.
type fakeConnector struct {
	name string // DSN handed to fdriver.Open

	// waiter, if non-nil, is installed on every connection created by
	// Connect; the connection calls it before preparing each statement.
	waiter func(context.Context)

	closed bool // set by Close; a second Close reports an error
}
61
62func (c *fakeConnector) Connect(context.Context) (driver.Conn, error) {
63	conn, err := fdriver.Open(c.name)
64	conn.(*fakeConn).waiter = c.waiter
65	return conn, err
66}
67
68func (c *fakeConnector) Driver() driver.Driver {
69	return fdriver
70}
71
72func (c *fakeConnector) Close() error {
73	if c.closed {
74		return errors.New("fakedb: connector is closed")
75	}
76	c.closed = true
77	return nil
78}
79
80type fakeDriverCtx struct {
81	fakeDriver
82}
83
84var _ driver.DriverContext = &fakeDriverCtx{}
85
86func (cc *fakeDriverCtx) OpenConnector(name string) (driver.Connector, error) {
87	return &fakeConnector{name: name}, nil
88}
89
90type fakeDB struct {
91	name string
92
93	mu       sync.Mutex
94	tables   map[string]*table
95	badConn  bool
96	allowAny bool
97}
98
// fakeError is the error the fake driver returns when it reports a
// wrapped cause (in this file always driver.ErrBadConn) together with a
// human-readable message.
type fakeError struct {
	Message string
	Wrapped error
}

// Error returns only the message; the wrapped error is not included.
func (err fakeError) Error() string {
	msg := err.Message
	return msg
}

// Unwrap exposes the wrapped error to errors.Is and errors.As.
func (err fakeError) Unwrap() error {
	inner := err.Wrapped
	return inner
}
111
// table is one in-memory table. colname and coltype are parallel slices
// describing the schema; each row stores its values in the same order.
type table struct {
	mu      sync.Mutex // held while rows are read or appended
	colname []string
	coltype []string
	rows    []*row
}

// columnIndex returns the position of the column called name in the
// table's schema, or -1 if there is no such column.
func (t *table) columnIndex(name string) int {
	for i := 0; i < len(t.colname); i++ {
		if t.colname[i] == name {
			return i
		}
	}
	return -1
}

// row is a single table row. cols is indexed like the owning table's
// colname/coltype slices and must have the same length.
type row struct {
	cols []any
}
131
// memToucher is implemented by objects whose touchMem method does an
// unsynchronized read-modify-write of their own state, so that the race
// detector reports any concurrent, unsynchronized use of them.
type memToucher interface {
	touchMem()
}
136
137type fakeConn struct {
138	db *fakeDB // where to return ourselves to
139
140	currTx *fakeTx
141
142	// Every operation writes to line to enable the race detector
143	// check for data races.
144	line int64
145
146	// Stats for tests:
147	mu          sync.Mutex
148	stmtsMade   int
149	stmtsClosed int
150	numPrepare  int
151
152	// bad connection tests; see isBad()
153	bad       bool
154	stickyBad bool
155
156	skipDirtySession bool // tests that use Conn should set this to true.
157
158	// dirtySession tests ResetSession, true if a query has executed
159	// until ResetSession is called.
160	dirtySession bool
161
162	// The waiter is called before each query. May be used in place of the "WAIT"
163	// directive.
164	waiter func(context.Context)
165}
166
167func (c *fakeConn) touchMem() {
168	c.line++
169}
170
171func (c *fakeConn) incrStat(v *int) {
172	c.mu.Lock()
173	*v++
174	c.mu.Unlock()
175}
176
177type fakeTx struct {
178	c *fakeConn
179}
180
// boundCol describes one SELECT filter column and the placeholder bound
// to it.
type boundCol struct {
	Column      string // name of the filtered column
	Placeholder string // "?" for positional, or "?name" for a named argument
	Ordinal     int    // 1-based position among the statement's placeholders
}
186
187type fakeStmt struct {
188	memToucher
189	c *fakeConn
190	q string // just for debugging
191
192	cmd   string
193	table string
194	panic string
195	wait  time.Duration
196
197	next *fakeStmt // used for returning multiple results.
198
199	closed bool
200
201	colName      []string // used by CREATE, INSERT, SELECT (selected columns)
202	colType      []string // used by CREATE
203	colValue     []any    // used by INSERT (mix of strings and "?" for bound params)
204	placeholders int      // used by INSERT/SELECT: number of ? params
205
206	whereCol []boundCol // used by SELECT (all placeholders)
207
208	placeholderConverter []driver.ValueConverter // used by INSERT
209}
210
211var fdriver driver.Driver = &fakeDriver{}
212
213func init() {
214	Register("test", fdriver)
215}
216
// contains reports whether y is an element of list.
func contains(list []string, y string) bool {
	found := false
	for i := 0; i < len(list) && !found; i++ {
		found = list[i] == y
	}
	return found
}
225
// Dummy is a placeholder driver registered by TestDrivers. As used there
// its embedded driver.Driver is nil, so opening it would panic.
type Dummy struct {
	driver.Driver
}
229
230func TestDrivers(t *testing.T) {
231	unregisterAllDrivers()
232	Register("test", fdriver)
233	Register("invalid", Dummy{})
234	all := Drivers()
235	if len(all) < 2 || !sort.StringsAreSorted(all) || !contains(all, "test") || !contains(all, "invalid") {
236		t.Fatalf("Drivers = %v, want sorted list with at least [invalid, test]", all)
237	}
238}
239
// hookOpenErr simulates connection failures: when fn is set,
// fakeDriver.Open calls it first and returns any non-nil error it
// reports instead of a connection.
var hookOpenErr struct {
	sync.Mutex
	fn func() error
}

// setHookOpenErr installs fn as the open-failure hook; nil disables it.
func setHookOpenErr(fn func() error) {
	hookOpenErr.Lock()
	hookOpenErr.fn = fn
	hookOpenErr.Unlock()
}
251
252// Supports dsn forms:
253//    <dbname>
254//    <dbname>;<opts>  (only currently supported option is `badConn`,
255//                      which causes driver.ErrBadConn to be returned on
256//                      every other conn.Begin())
257func (d *fakeDriver) Open(dsn string) (driver.Conn, error) {
258	hookOpenErr.Lock()
259	fn := hookOpenErr.fn
260	hookOpenErr.Unlock()
261	if fn != nil {
262		if err := fn(); err != nil {
263			return nil, err
264		}
265	}
266	parts := strings.Split(dsn, ";")
267	if len(parts) < 1 {
268		return nil, errors.New("fakedb: no database name")
269	}
270	name := parts[0]
271
272	db := d.getDB(name)
273
274	d.mu.Lock()
275	d.openCount++
276	d.mu.Unlock()
277	conn := &fakeConn{db: db}
278
279	if len(parts) >= 2 && parts[1] == "badConn" {
280		conn.bad = true
281	}
282	if d.waitCh != nil {
283		d.waitingCh <- struct{}{}
284		<-d.waitCh
285		d.waitCh = nil
286		d.waitingCh = nil
287	}
288	return conn, nil
289}
290
291func (d *fakeDriver) getDB(name string) *fakeDB {
292	d.mu.Lock()
293	defer d.mu.Unlock()
294	if d.dbs == nil {
295		d.dbs = make(map[string]*fakeDB)
296	}
297	db, ok := d.dbs[name]
298	if !ok {
299		db = &fakeDB{name: name}
300		d.dbs[name] = db
301	}
302	return db
303}
304
305func (db *fakeDB) wipe() {
306	db.mu.Lock()
307	defer db.mu.Unlock()
308	db.tables = nil
309}
310
311func (db *fakeDB) createTable(name string, columnNames, columnTypes []string) error {
312	db.mu.Lock()
313	defer db.mu.Unlock()
314	if db.tables == nil {
315		db.tables = make(map[string]*table)
316	}
317	if _, exist := db.tables[name]; exist {
318		return fmt.Errorf("fakedb: table %q already exists", name)
319	}
320	if len(columnNames) != len(columnTypes) {
321		return fmt.Errorf("fakedb: create table of %q len(names) != len(types): %d vs %d",
322			name, len(columnNames), len(columnTypes))
323	}
324	db.tables[name] = &table{colname: columnNames, coltype: columnTypes}
325	return nil
326}
327
328// must be called with db.mu lock held
329func (db *fakeDB) table(table string) (*table, bool) {
330	if db.tables == nil {
331		return nil, false
332	}
333	t, ok := db.tables[table]
334	return t, ok
335}
336
337func (db *fakeDB) columnType(table, column string) (typ string, ok bool) {
338	db.mu.Lock()
339	defer db.mu.Unlock()
340	t, ok := db.table(table)
341	if !ok {
342		return
343	}
344	for n, cname := range t.colname {
345		if cname == column {
346			return t.coltype[n], true
347		}
348	}
349	return "", false
350}
351
352func (c *fakeConn) isBad() bool {
353	if c.stickyBad {
354		return true
355	} else if c.bad {
356		if c.db == nil {
357			return false
358		}
359		// alternate between bad conn and not bad conn
360		c.db.badConn = !c.db.badConn
361		return c.db.badConn
362	} else {
363		return false
364	}
365}
366
367func (c *fakeConn) isDirtyAndMark() bool {
368	if c.skipDirtySession {
369		return false
370	}
371	if c.currTx != nil {
372		c.dirtySession = true
373		return false
374	}
375	if c.dirtySession {
376		return true
377	}
378	c.dirtySession = true
379	return false
380}
381
382func (c *fakeConn) Begin() (driver.Tx, error) {
383	if c.isBad() {
384		return nil, fakeError{Wrapped: driver.ErrBadConn}
385	}
386	if c.currTx != nil {
387		return nil, errors.New("fakedb: already in a transaction")
388	}
389	c.touchMem()
390	c.currTx = &fakeTx{c: c}
391	return c.currTx, nil
392}
393
394var hookPostCloseConn struct {
395	sync.Mutex
396	fn func(*fakeConn, error)
397}
398
399func setHookpostCloseConn(fn func(*fakeConn, error)) {
400	hookPostCloseConn.Lock()
401	defer hookPostCloseConn.Unlock()
402	hookPostCloseConn.fn = fn
403}
404
// testStrictClose, when non-nil, is the test that fakeConn.Close reports
// close failures to.
var testStrictClose *testing.T

// setStrictFakeConnClose makes every failing fakeConn.Close report the
// failure with t.Errorf. Passing nil disables the check.
func setStrictFakeConnClose(t *testing.T) {
	testStrictClose = t
}
412
413func (c *fakeConn) ResetSession(ctx context.Context) error {
414	c.dirtySession = false
415	c.currTx = nil
416	if c.isBad() {
417		return fakeError{Message: "Reset Session: bad conn", Wrapped: driver.ErrBadConn}
418	}
419	return nil
420}
421
422var _ driver.Validator = (*fakeConn)(nil)
423
424func (c *fakeConn) IsValid() bool {
425	return !c.isBad()
426}
427
428func (c *fakeConn) Close() (err error) {
429	drv := fdriver.(*fakeDriver)
430	defer func() {
431		if err != nil && testStrictClose != nil {
432			testStrictClose.Errorf("failed to close a test fakeConn: %v", err)
433		}
434		hookPostCloseConn.Lock()
435		fn := hookPostCloseConn.fn
436		hookPostCloseConn.Unlock()
437		if fn != nil {
438			fn(c, err)
439		}
440		if err == nil {
441			drv.mu.Lock()
442			drv.closeCount++
443			drv.mu.Unlock()
444		}
445	}()
446	c.touchMem()
447	if c.currTx != nil {
448		return errors.New("fakedb: can't close fakeConn; in a Transaction")
449	}
450	if c.db == nil {
451		return errors.New("fakedb: can't close fakeConn; already closed")
452	}
453	if c.stmtsMade > c.stmtsClosed {
454		return errors.New("fakedb: can't close; dangling statement(s)")
455	}
456	c.db = nil
457	return nil
458}
459
// checkSubsetTypes verifies that every argument is one of the driver.Value
// subset types (int64, float64, bool, nil, []byte, string, time.Time).
// With allowAny set, every value is accepted.
func checkSubsetTypes(allowAny bool, args []driver.NamedValue) error {
	if allowAny {
		return nil
	}
	for _, arg := range args {
		switch arg.Value.(type) {
		case int64, float64, bool, nil, []byte, string, time.Time:
			continue
		}
		return fmt.Errorf("fakedb: invalid argument ordinal %[1]d: %[2]v, type %[2]T", arg.Ordinal, arg.Value)
	}
	return nil
}
472
473func (c *fakeConn) Exec(query string, args []driver.Value) (driver.Result, error) {
474	// Ensure that ExecContext is called if available.
475	panic("ExecContext was not called.")
476}
477
478func (c *fakeConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
479	// This is an optional interface, but it's implemented here
480	// just to check that all the args are of the proper types.
481	// ErrSkip is returned so the caller acts as if we didn't
482	// implement this at all.
483	err := checkSubsetTypes(c.db.allowAny, args)
484	if err != nil {
485		return nil, err
486	}
487	return nil, driver.ErrSkip
488}
489
490func (c *fakeConn) Query(query string, args []driver.Value) (driver.Rows, error) {
491	// Ensure that ExecContext is called if available.
492	panic("QueryContext was not called.")
493}
494
495func (c *fakeConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
496	// This is an optional interface, but it's implemented here
497	// just to check that all the args are of the proper types.
498	// ErrSkip is returned so the caller acts as if we didn't
499	// implement this at all.
500	err := checkSubsetTypes(c.db.allowAny, args)
501	if err != nil {
502		return nil, err
503	}
504	return nil, driver.ErrSkip
505}
506
// errf formats msg with args (as fmt.Sprintf does) and returns it as an
// error prefixed with "fakedb: ".
func errf(msg string, args ...any) error {
	text := fmt.Sprintf(msg, args...)
	return errors.New("fakedb: " + text)
}
510
511// parts are table|selectCol1,selectCol2|whereCol=?,whereCol2=?
512// (note that where columns must always contain ? marks,
513//  just a limitation for fakedb)
514func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (*fakeStmt, error) {
515	if len(parts) != 3 {
516		stmt.Close()
517		return nil, errf("invalid SELECT syntax with %d parts; want 3", len(parts))
518	}
519	stmt.table = parts[0]
520
521	stmt.colName = strings.Split(parts[1], ",")
522	for n, colspec := range strings.Split(parts[2], ",") {
523		if colspec == "" {
524			continue
525		}
526		nameVal := strings.Split(colspec, "=")
527		if len(nameVal) != 2 {
528			stmt.Close()
529			return nil, errf("SELECT on table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
530		}
531		column, value := nameVal[0], nameVal[1]
532		_, ok := c.db.columnType(stmt.table, column)
533		if !ok {
534			stmt.Close()
535			return nil, errf("SELECT on table %q references non-existent column %q", stmt.table, column)
536		}
537		if !strings.HasPrefix(value, "?") {
538			stmt.Close()
539			return nil, errf("SELECT on table %q has pre-bound value for where column %q; need a question mark",
540				stmt.table, column)
541		}
542		stmt.placeholders++
543		stmt.whereCol = append(stmt.whereCol, boundCol{Column: column, Placeholder: value, Ordinal: stmt.placeholders})
544	}
545	return stmt, nil
546}
547
548// parts are table|col=type,col2=type2
549func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (*fakeStmt, error) {
550	if len(parts) != 2 {
551		stmt.Close()
552		return nil, errf("invalid CREATE syntax with %d parts; want 2", len(parts))
553	}
554	stmt.table = parts[0]
555	for n, colspec := range strings.Split(parts[1], ",") {
556		nameType := strings.Split(colspec, "=")
557		if len(nameType) != 2 {
558			stmt.Close()
559			return nil, errf("CREATE table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
560		}
561		stmt.colName = append(stmt.colName, nameType[0])
562		stmt.colType = append(stmt.colType, nameType[1])
563	}
564	return stmt, nil
565}
566
567// parts are table|col=?,col2=val
568func (c *fakeConn) prepareInsert(ctx context.Context, stmt *fakeStmt, parts []string) (*fakeStmt, error) {
569	if len(parts) != 2 {
570		stmt.Close()
571		return nil, errf("invalid INSERT syntax with %d parts; want 2", len(parts))
572	}
573	stmt.table = parts[0]
574	for n, colspec := range strings.Split(parts[1], ",") {
575		nameVal := strings.Split(colspec, "=")
576		if len(nameVal) != 2 {
577			stmt.Close()
578			return nil, errf("INSERT table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
579		}
580		column, value := nameVal[0], nameVal[1]
581		ctype, ok := c.db.columnType(stmt.table, column)
582		if !ok {
583			stmt.Close()
584			return nil, errf("INSERT table %q references non-existent column %q", stmt.table, column)
585		}
586		stmt.colName = append(stmt.colName, column)
587
588		if !strings.HasPrefix(value, "?") {
589			var subsetVal any
590			// Convert to driver subset type
591			switch ctype {
592			case "string":
593				subsetVal = []byte(value)
594			case "blob":
595				subsetVal = []byte(value)
596			case "int32":
597				i, err := strconv.Atoi(value)
598				if err != nil {
599					stmt.Close()
600					return nil, errf("invalid conversion to int32 from %q", value)
601				}
602				subsetVal = int64(i) // int64 is a subset type, but not int32
603			case "table": // For testing cursor reads.
604				c.skipDirtySession = true
605				vparts := strings.Split(value, "!")
606
607				substmt, err := c.PrepareContext(ctx, fmt.Sprintf("SELECT|%s|%s|", vparts[0], strings.Join(vparts[1:], ",")))
608				if err != nil {
609					return nil, err
610				}
611				cursor, err := (substmt.(driver.StmtQueryContext)).QueryContext(ctx, []driver.NamedValue{})
612				substmt.Close()
613				if err != nil {
614					return nil, err
615				}
616				subsetVal = cursor
617			default:
618				stmt.Close()
619				return nil, errf("unsupported conversion for pre-bound parameter %q to type %q", value, ctype)
620			}
621			stmt.colValue = append(stmt.colValue, subsetVal)
622		} else {
623			stmt.placeholders++
624			stmt.placeholderConverter = append(stmt.placeholderConverter, converterForType(ctype))
625			stmt.colValue = append(stmt.colValue, value)
626		}
627	}
628	return stmt, nil
629}
630
631// hook to simulate broken connections
632var hookPrepareBadConn func() bool
633
634func (c *fakeConn) Prepare(query string) (driver.Stmt, error) {
635	panic("use PrepareContext")
636}
637
638func (c *fakeConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
639	c.numPrepare++
640	if c.db == nil {
641		panic("nil c.db; conn = " + fmt.Sprintf("%#v", c))
642	}
643
644	if c.stickyBad || (hookPrepareBadConn != nil && hookPrepareBadConn()) {
645		return nil, fakeError{Message: "Preapre: Sticky Bad", Wrapped: driver.ErrBadConn}
646	}
647
648	c.touchMem()
649	var firstStmt, prev *fakeStmt
650	for _, query := range strings.Split(query, ";") {
651		parts := strings.Split(query, "|")
652		if len(parts) < 1 {
653			return nil, errf("empty query")
654		}
655		stmt := &fakeStmt{q: query, c: c, memToucher: c}
656		if firstStmt == nil {
657			firstStmt = stmt
658		}
659		if len(parts) >= 3 {
660			switch parts[0] {
661			case "PANIC":
662				stmt.panic = parts[1]
663				parts = parts[2:]
664			case "WAIT":
665				wait, err := time.ParseDuration(parts[1])
666				if err != nil {
667					return nil, errf("expected section after WAIT to be a duration, got %q %v", parts[1], err)
668				}
669				parts = parts[2:]
670				stmt.wait = wait
671			}
672		}
673		cmd := parts[0]
674		stmt.cmd = cmd
675		parts = parts[1:]
676
677		if c.waiter != nil {
678			c.waiter(ctx)
679		}
680
681		if stmt.wait > 0 {
682			wait := time.NewTimer(stmt.wait)
683			select {
684			case <-wait.C:
685			case <-ctx.Done():
686				wait.Stop()
687				return nil, ctx.Err()
688			}
689		}
690
691		c.incrStat(&c.stmtsMade)
692		var err error
693		switch cmd {
694		case "WIPE":
695			// Nothing
696		case "SELECT":
697			stmt, err = c.prepareSelect(stmt, parts)
698		case "CREATE":
699			stmt, err = c.prepareCreate(stmt, parts)
700		case "INSERT":
701			stmt, err = c.prepareInsert(ctx, stmt, parts)
702		case "NOSERT":
703			// Do all the prep-work like for an INSERT but don't actually insert the row.
704			// Used for some of the concurrent tests.
705			stmt, err = c.prepareInsert(ctx, stmt, parts)
706		default:
707			stmt.Close()
708			return nil, errf("unsupported command type %q", cmd)
709		}
710		if err != nil {
711			return nil, err
712		}
713		if prev != nil {
714			prev.next = stmt
715		}
716		prev = stmt
717	}
718	return firstStmt, nil
719}
720
721func (s *fakeStmt) ColumnConverter(idx int) driver.ValueConverter {
722	if s.panic == "ColumnConverter" {
723		panic(s.panic)
724	}
725	if len(s.placeholderConverter) == 0 {
726		return driver.DefaultParameterConverter
727	}
728	return s.placeholderConverter[idx]
729}
730
731func (s *fakeStmt) Close() error {
732	if s.panic == "Close" {
733		panic(s.panic)
734	}
735	if s.c == nil {
736		panic("nil conn in fakeStmt.Close")
737	}
738	if s.c.db == nil {
739		panic("in fakeStmt.Close, conn's db is nil (already closed)")
740	}
741	s.touchMem()
742	if !s.closed {
743		s.c.incrStat(&s.c.stmtsClosed)
744		s.closed = true
745	}
746	if s.next != nil {
747		s.next.Close()
748	}
749	return nil
750}
751
752var errClosed = errors.New("fakedb: statement has been closed")
753
754// hook to simulate broken connections
755var hookExecBadConn func() bool
756
757func (s *fakeStmt) Exec(args []driver.Value) (driver.Result, error) {
758	panic("Using ExecContext")
759}
760
761var errFakeConnSessionDirty = errors.New("fakedb: session is dirty")
762
763func (s *fakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
764	if s.panic == "Exec" {
765		panic(s.panic)
766	}
767	if s.closed {
768		return nil, errClosed
769	}
770
771	if s.c.stickyBad || (hookExecBadConn != nil && hookExecBadConn()) {
772		return nil, fakeError{Message: "Exec: Sticky Bad", Wrapped: driver.ErrBadConn}
773	}
774	if s.c.isDirtyAndMark() {
775		return nil, errFakeConnSessionDirty
776	}
777
778	err := checkSubsetTypes(s.c.db.allowAny, args)
779	if err != nil {
780		return nil, err
781	}
782	s.touchMem()
783
784	if s.wait > 0 {
785		time.Sleep(s.wait)
786	}
787
788	select {
789	default:
790	case <-ctx.Done():
791		return nil, ctx.Err()
792	}
793
794	db := s.c.db
795	switch s.cmd {
796	case "WIPE":
797		db.wipe()
798		return driver.ResultNoRows, nil
799	case "CREATE":
800		if err := db.createTable(s.table, s.colName, s.colType); err != nil {
801			return nil, err
802		}
803		return driver.ResultNoRows, nil
804	case "INSERT":
805		return s.execInsert(args, true)
806	case "NOSERT":
807		// Do all the prep-work like for an INSERT but don't actually insert the row.
808		// Used for some of the concurrent tests.
809		return s.execInsert(args, false)
810	}
811	return nil, fmt.Errorf("fakedb: unimplemented statement Exec command type of %q", s.cmd)
812}
813
// execInsert executes a prepared INSERT or NOSERT using args for the
// statement's placeholders.
//
// When doInsert is true, add the row to the table.
// When doInsert is false do prep-work and error checking, but don't
// actually add the row to the table.
//
// It returns an error if the table, or one of the statement's columns,
// no longer exists; otherwise it reports one affected row in both modes.
// Table columns not named by the statement are left nil in the new row.
func (s *fakeStmt) execInsert(args []driver.NamedValue, doInsert bool) (driver.Result, error) {
	db := s.c.db
	// The message indicates package sql is expected to have validated the
	// argument count (presumably via NumInput) before calling in.
	if len(args) != s.placeholders {
		panic("error in pkg db; should only get here if size is correct")
	}
	db.mu.Lock()
	t, ok := db.table(s.table)
	db.mu.Unlock()
	if !ok {
		return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table)
	}

	t.mu.Lock()
	defer t.mu.Unlock()

	// cols is only allocated when the row will actually be stored.
	var cols []any
	if doInsert {
		cols = make([]any, len(t.colname))
	}
	// argPos indexes args for positional ("?") placeholders. It advances on
	// every placeholder column, including named ("?name") ones.
	argPos := 0
	for n, colname := range s.colName {
		colidx := t.columnIndex(colname)
		if colidx == -1 {
			return nil, fmt.Errorf("fakedb: column %q doesn't exist or dropped since prepared statement was created", colname)
		}
		var val any
		// Placeholder columns hold a string starting with "?"; any other
		// colValue was pre-bound when the statement was prepared.
		if strvalue, ok := s.colValue[n].(string); ok && strings.HasPrefix(strvalue, "?") {
			if strvalue == "?" {
				val = args[argPos].Value
			} else {
				// Assign value from argument placeholder name.
				// If no argument has that name, val stays nil.
				for _, a := range args {
					if a.Name == strvalue[1:] {
						val = a.Value
						break
					}
				}
			}
			argPos++
		} else {
			val = s.colValue[n]
		}
		if doInsert {
			cols[colidx] = val
		}
	}

	if doInsert {
		t.rows = append(t.rows, &row{cols: cols})
	}
	return driver.RowsAffected(1), nil
}
869
870// hook to simulate broken connections
871var hookQueryBadConn func() bool
872
873func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) {
874	panic("Use QueryContext")
875}
876
877func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
878	if s.panic == "Query" {
879		panic(s.panic)
880	}
881	if s.closed {
882		return nil, errClosed
883	}
884
885	if s.c.stickyBad || (hookQueryBadConn != nil && hookQueryBadConn()) {
886		return nil, fakeError{Message: "Query: Sticky Bad", Wrapped: driver.ErrBadConn}
887	}
888	if s.c.isDirtyAndMark() {
889		return nil, errFakeConnSessionDirty
890	}
891
892	err := checkSubsetTypes(s.c.db.allowAny, args)
893	if err != nil {
894		return nil, err
895	}
896
897	s.touchMem()
898	db := s.c.db
899	if len(args) != s.placeholders {
900		panic("error in pkg db; should only get here if size is correct")
901	}
902
903	setMRows := make([][]*row, 0, 1)
904	setColumns := make([][]string, 0, 1)
905	setColType := make([][]string, 0, 1)
906
907	for {
908		db.mu.Lock()
909		t, ok := db.table(s.table)
910		db.mu.Unlock()
911		if !ok {
912			return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table)
913		}
914
915		if s.table == "magicquery" {
916			if len(s.whereCol) == 2 && s.whereCol[0].Column == "op" && s.whereCol[1].Column == "millis" {
917				if args[0].Value == "sleep" {
918					time.Sleep(time.Duration(args[1].Value.(int64)) * time.Millisecond)
919				}
920			}
921		}
922		if s.table == "tx_status" && s.colName[0] == "tx_status" {
923			txStatus := "autocommit"
924			if s.c.currTx != nil {
925				txStatus = "transaction"
926			}
927			cursor := &rowsCursor{
928				parentMem: s.c,
929				posRow:    -1,
930				rows: [][]*row{
931					{
932						{
933							cols: []any{
934								txStatus,
935							},
936						},
937					},
938				},
939				cols: [][]string{
940					{
941						"tx_status",
942					},
943				},
944				colType: [][]string{
945					{
946						"string",
947					},
948				},
949				errPos: -1,
950			}
951			return cursor, nil
952		}
953
954		t.mu.Lock()
955
956		colIdx := make(map[string]int) // select column name -> column index in table
957		for _, name := range s.colName {
958			idx := t.columnIndex(name)
959			if idx == -1 {
960				t.mu.Unlock()
961				return nil, fmt.Errorf("fakedb: unknown column name %q", name)
962			}
963			colIdx[name] = idx
964		}
965
966		mrows := []*row{}
967	rows:
968		for _, trow := range t.rows {
969			// Process the where clause, skipping non-match rows. This is lazy
970			// and just uses fmt.Sprintf("%v") to test equality. Good enough
971			// for test code.
972			for _, wcol := range s.whereCol {
973				idx := t.columnIndex(wcol.Column)
974				if idx == -1 {
975					t.mu.Unlock()
976					return nil, fmt.Errorf("fakedb: invalid where clause column %q", wcol)
977				}
978				tcol := trow.cols[idx]
979				if bs, ok := tcol.([]byte); ok {
980					// lazy hack to avoid sprintf %v on a []byte
981					tcol = string(bs)
982				}
983				var argValue any
984				if wcol.Placeholder == "?" {
985					argValue = args[wcol.Ordinal-1].Value
986				} else {
987					// Assign arg value from placeholder name.
988					for _, a := range args {
989						if a.Name == wcol.Placeholder[1:] {
990							argValue = a.Value
991							break
992						}
993					}
994				}
995				if fmt.Sprintf("%v", tcol) != fmt.Sprintf("%v", argValue) {
996					continue rows
997				}
998			}
999			mrow := &row{cols: make([]any, len(s.colName))}
1000			for seli, name := range s.colName {
1001				mrow.cols[seli] = trow.cols[colIdx[name]]
1002			}
1003			mrows = append(mrows, mrow)
1004		}
1005
1006		var colType []string
1007		for _, column := range s.colName {
1008			colType = append(colType, t.coltype[t.columnIndex(column)])
1009		}
1010
1011		t.mu.Unlock()
1012
1013		setMRows = append(setMRows, mrows)
1014		setColumns = append(setColumns, s.colName)
1015		setColType = append(setColType, colType)
1016
1017		if s.next == nil {
1018			break
1019		}
1020		s = s.next
1021	}
1022
1023	cursor := &rowsCursor{
1024		parentMem: s.c,
1025		posRow:    -1,
1026		rows:      setMRows,
1027		cols:      setColumns,
1028		colType:   setColType,
1029		errPos:    -1,
1030	}
1031	return cursor, nil
1032}
1033
1034func (s *fakeStmt) NumInput() int {
1035	if s.panic == "NumInput" {
1036		panic(s.panic)
1037	}
1038	return s.placeholders
1039}
1040
1041// hook to simulate broken connections
1042var hookCommitBadConn func() bool
1043
1044func (tx *fakeTx) Commit() error {
1045	tx.c.currTx = nil
1046	if hookCommitBadConn != nil && hookCommitBadConn() {
1047		return fakeError{Message: "Commit: Hook Bad Conn", Wrapped: driver.ErrBadConn}
1048	}
1049	tx.c.touchMem()
1050	return nil
1051}
1052
1053// hook to simulate broken connections
1054var hookRollbackBadConn func() bool
1055
1056func (tx *fakeTx) Rollback() error {
1057	tx.c.currTx = nil
1058	if hookRollbackBadConn != nil && hookRollbackBadConn() {
1059		return fakeError{Message: "Rollback: Hook Bad Conn", Wrapped: driver.ErrBadConn}
1060	}
1061	tx.c.touchMem()
1062	return nil
1063}
1064
1065type rowsCursor struct {
1066	parentMem memToucher
1067	cols      [][]string
1068	colType   [][]string
1069	posSet    int
1070	posRow    int
1071	rows      [][]*row
1072	closed    bool
1073
1074	// errPos and err are for making Next return early with error.
1075	errPos int
1076	err    error
1077
1078	// a clone of slices to give out to clients, indexed by the
1079	// original slice's first byte address.  we clone them
1080	// just so we're able to corrupt them on close.
1081	bytesClone map[*byte][]byte
1082
1083	// Every operation writes to line to enable the race detector
1084	// check for data races.
1085	// This is separate from the fakeConn.line to allow for drivers that
1086	// can start multiple queries on the same transaction at the same time.
1087	line int64
1088}
1089
1090func (rc *rowsCursor) touchMem() {
1091	rc.parentMem.touchMem()
1092	rc.line++
1093}
1094
1095func (rc *rowsCursor) Close() error {
1096	rc.touchMem()
1097	rc.parentMem.touchMem()
1098	rc.closed = true
1099	return nil
1100}
1101
1102func (rc *rowsCursor) Columns() []string {
1103	return rc.cols[rc.posSet]
1104}
1105
1106func (rc *rowsCursor) ColumnTypeScanType(index int) reflect.Type {
1107	return colTypeToReflectType(rc.colType[rc.posSet][index])
1108}
1109
1110var rowsCursorNextHook func(dest []driver.Value) error
1111
1112func (rc *rowsCursor) Next(dest []driver.Value) error {
1113	if rowsCursorNextHook != nil {
1114		return rowsCursorNextHook(dest)
1115	}
1116
1117	if rc.closed {
1118		return errors.New("fakedb: cursor is closed")
1119	}
1120	rc.touchMem()
1121	rc.posRow++
1122	if rc.posRow == rc.errPos {
1123		return rc.err
1124	}
1125	if rc.posRow >= len(rc.rows[rc.posSet]) {
1126		return io.EOF // per interface spec
1127	}
1128	for i, v := range rc.rows[rc.posSet][rc.posRow].cols {
1129		// TODO(bradfitz): convert to subset types? naah, I
1130		// think the subset types should only be input to
1131		// driver, but the sql package should be able to handle
1132		// a wider range of types coming out of drivers. all
1133		// for ease of drivers, and to prevent drivers from
1134		// messing up conversions or doing them differently.
1135		dest[i] = v
1136
1137		if bs, ok := v.([]byte); ok {
1138			if rc.bytesClone == nil {
1139				rc.bytesClone = make(map[*byte][]byte)
1140			}
1141			clone, ok := rc.bytesClone[&bs[0]]
1142			if !ok {
1143				clone = make([]byte, len(bs))
1144				copy(clone, bs)
1145				rc.bytesClone[&bs[0]] = clone
1146			}
1147			dest[i] = clone
1148		}
1149	}
1150	return nil
1151}
1152
1153func (rc *rowsCursor) HasNextResultSet() bool {
1154	rc.touchMem()
1155	return rc.posSet < len(rc.rows)-1
1156}
1157
1158func (rc *rowsCursor) NextResultSet() error {
1159	rc.touchMem()
1160	if rc.HasNextResultSet() {
1161		rc.posSet++
1162		rc.posRow = -1
1163		return nil
1164	}
1165	return io.EOF // Per interface spec.
1166}
1167
// fakeDriverString is like driver.String, but indirects pointers like
// driver.DefaultParameterConverter does.
//
// This could be surprising behavior to retroactively apply to
// driver.String now that Go1 is out, but this is convenient for
// our TestPointerParamsAndScans.
type fakeDriverString struct{}

// ConvertValue passes string and []byte values through unchanged,
// dereferences a *string (a nil pointer becomes nil), and formats any
// other value with %v.
func (fakeDriverString) ConvertValue(v any) (driver.Value, error) {
	if p, isPtr := v.(*string); isPtr {
		if p == nil {
			return nil, nil
		}
		return *p, nil
	}
	switch v.(type) {
	case string, []byte:
		return v, nil
	}
	return fmt.Sprintf("%v", v), nil
}
1189
// anyTypeConverter accepts every value unchanged; converterForType uses it
// for the "any" column type.
type anyTypeConverter struct{}

// ConvertValue returns v as-is and never fails.
func (anyTypeConverter) ConvertValue(v any) (driver.Value, error) {
	var out driver.Value = v
	return out, nil
}
1195
1196func converterForType(typ string) driver.ValueConverter {
1197	switch typ {
1198	case "bool":
1199		return driver.Bool
1200	case "nullbool":
1201		return driver.Null{Converter: driver.Bool}
1202	case "byte", "int16":
1203		return driver.NotNull{Converter: driver.DefaultParameterConverter}
1204	case "int32":
1205		return driver.Int32
1206	case "nullbyte", "nullint32", "nullint16":
1207		return driver.Null{Converter: driver.DefaultParameterConverter}
1208	case "string":
1209		return driver.NotNull{Converter: fakeDriverString{}}
1210	case "nullstring":
1211		return driver.Null{Converter: fakeDriverString{}}
1212	case "int64":
1213		// TODO(coopernurse): add type-specific converter
1214		return driver.NotNull{Converter: driver.DefaultParameterConverter}
1215	case "nullint64":
1216		// TODO(coopernurse): add type-specific converter
1217		return driver.Null{Converter: driver.DefaultParameterConverter}
1218	case "float64":
1219		// TODO(coopernurse): add type-specific converter
1220		return driver.NotNull{Converter: driver.DefaultParameterConverter}
1221	case "nullfloat64":
1222		// TODO(coopernurse): add type-specific converter
1223		return driver.Null{Converter: driver.DefaultParameterConverter}
1224	case "datetime":
1225		return driver.NotNull{Converter: driver.DefaultParameterConverter}
1226	case "nulldatetime":
1227		return driver.Null{Converter: driver.DefaultParameterConverter}
1228	case "any":
1229		return anyTypeConverter{}
1230	}
1231	panic("invalid fakedb column type of " + typ)
1232}
1233
1234func colTypeToReflectType(typ string) reflect.Type {
1235	switch typ {
1236	case "bool":
1237		return reflect.TypeOf(false)
1238	case "nullbool":
1239		return reflect.TypeOf(NullBool{})
1240	case "int16":
1241		return reflect.TypeOf(int16(0))
1242	case "nullint16":
1243		return reflect.TypeOf(NullInt16{})
1244	case "int32":
1245		return reflect.TypeOf(int32(0))
1246	case "nullint32":
1247		return reflect.TypeOf(NullInt32{})
1248	case "string":
1249		return reflect.TypeOf("")
1250	case "nullstring":
1251		return reflect.TypeOf(NullString{})
1252	case "int64":
1253		return reflect.TypeOf(int64(0))
1254	case "nullint64":
1255		return reflect.TypeOf(NullInt64{})
1256	case "float64":
1257		return reflect.TypeOf(float64(0))
1258	case "nullfloat64":
1259		return reflect.TypeOf(NullFloat64{})
1260	case "datetime":
1261		return reflect.TypeOf(time.Time{})
1262	case "any":
1263		return reflect.TypeOf(new(any)).Elem()
1264	}
1265	panic("invalid fakedb column type of " + typ)
1266}
1267