1// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
2//
3// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved.
4//
5// This Source Code Form is subject to the terms of the Mozilla Public
6// License, v. 2.0. If a copy of the MPL was not distributed with this file,
7// You can obtain one at http://mozilla.org/MPL/2.0/.
8
9package mysql
10
11import (
12	"bytes"
13	"context"
14	"crypto/tls"
15	"database/sql"
16	"database/sql/driver"
17	"fmt"
18	"io"
19	"io/ioutil"
20	"log"
21	"math"
22	"net"
23	"net/url"
24	"os"
25	"reflect"
26	"strings"
27	"sync"
28	"sync/atomic"
29	"testing"
30	"time"
31)
32
33// Ensure that all the driver interfaces are implemented
34var (
35	_ driver.Rows = &binaryRows{}
36	_ driver.Rows = &textRows{}
37)
38
// Connection settings shared by the integration tests. init fills them from
// the MYSQL_TEST_* environment variables (defaults in parentheses).
var (
	user      string // MYSQL_TEST_USER ("root")
	pass      string // MYSQL_TEST_PASS ("")
	prot      string // MYSQL_TEST_PROT, the network passed to net.Dial ("tcp")
	addr      string // MYSQL_TEST_ADDR ("localhost:3306")
	dbname    string // MYSQL_TEST_DBNAME ("gotest")
	dsn       string // DSN built from the values above, with timeout=30s
	netAddr   string // "prot(addr)", the address part of dsn
	available bool   // set by init when net.Dial(prot, addr) succeeds; tests skip otherwise
)
49
// Reference date/time values paired with their MySQL string forms.
// tDate0 is the Go zero time; sDate0/sDateTime0 are MySQL's all-zero date and
// datetime literals.
// NOTE(review): no use is visible next to these definitions; presumably they
// are referenced by tests further down the file — confirm.
var (
	tDate      = time.Date(2012, 6, 14, 0, 0, 0, 0, time.UTC)
	sDate      = "2012-06-14"
	tDateTime  = time.Date(2011, 11, 20, 21, 27, 37, 0, time.UTC)
	sDateTime  = "2011-11-20 21:27:37"
	tDate0     = time.Time{}
	sDate0     = "0000-00-00"
	sDateTime0 = "0000-00-00 00:00:00"
)
59
60// See https://github.com/go-sql-driver/mysql/wiki/Testing
61func init() {
62	// get environment variables
63	env := func(key, defaultValue string) string {
64		if value := os.Getenv(key); value != "" {
65			return value
66		}
67		return defaultValue
68	}
69	user = env("MYSQL_TEST_USER", "root")
70	pass = env("MYSQL_TEST_PASS", "")
71	prot = env("MYSQL_TEST_PROT", "tcp")
72	addr = env("MYSQL_TEST_ADDR", "localhost:3306")
73	dbname = env("MYSQL_TEST_DBNAME", "gotest")
74	netAddr = fmt.Sprintf("%s(%s)", prot, addr)
75	dsn = fmt.Sprintf("%s:%s@%s/%s?timeout=30s", user, pass, netAddr, dbname)
76	c, err := net.Dial(prot, addr)
77	if err == nil {
78		available = true
79		c.Close()
80	}
81}
82
// DBTest pairs the running test with the connection pool a test closure
// should use. Embedding *testing.T exposes Errorf, Fatalf, Skip, etc. directly.
type DBTest struct {
	*testing.T
	db *sql.DB // pool under test, opened by runTests / runTestsWithMultiStatement
}
87
// netErrorMock is a configurable stand-in for a network error. It provides
// Error, Temporary and Timeout (the method set of net.Error), with the two
// flags taken from its fields.
type netErrorMock struct {
	temporary bool
	timeout   bool
}

// Timeout reports the configured timeout flag.
func (m netErrorMock) Timeout() bool { return m.timeout }

// Temporary reports the configured temporary flag.
func (m netErrorMock) Temporary() bool { return m.temporary }

// Error describes the mock together with both of its flags.
func (m netErrorMock) Error() string {
	return fmt.Sprintf("mock net error. Temporary: %v, Timeout %v", m.temporary, m.timeout)
}
104
105func runTestsWithMultiStatement(t *testing.T, dsn string, tests ...func(dbt *DBTest)) {
106	if !available {
107		t.Skipf("MySQL server not running on %s", netAddr)
108	}
109
110	dsn += "&multiStatements=true"
111	var db *sql.DB
112	if _, err := ParseDSN(dsn); err != errInvalidDSNUnsafeCollation {
113		db, err = sql.Open("mysql", dsn)
114		if err != nil {
115			t.Fatalf("error connecting: %s", err.Error())
116		}
117		defer db.Close()
118	}
119
120	dbt := &DBTest{t, db}
121	for _, test := range tests {
122		test(dbt)
123		dbt.db.Exec("DROP TABLE IF EXISTS test")
124	}
125}
126
127func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) {
128	if !available {
129		t.Skipf("MySQL server not running on %s", netAddr)
130	}
131
132	db, err := sql.Open("mysql", dsn)
133	if err != nil {
134		t.Fatalf("error connecting: %s", err.Error())
135	}
136	defer db.Close()
137
138	db.Exec("DROP TABLE IF EXISTS test")
139
140	dsn2 := dsn + "&interpolateParams=true"
141	var db2 *sql.DB
142	if _, err := ParseDSN(dsn2); err != errInvalidDSNUnsafeCollation {
143		db2, err = sql.Open("mysql", dsn2)
144		if err != nil {
145			t.Fatalf("error connecting: %s", err.Error())
146		}
147		defer db2.Close()
148	}
149
150	dsn3 := dsn + "&multiStatements=true"
151	var db3 *sql.DB
152	if _, err := ParseDSN(dsn3); err != errInvalidDSNUnsafeCollation {
153		db3, err = sql.Open("mysql", dsn3)
154		if err != nil {
155			t.Fatalf("error connecting: %s", err.Error())
156		}
157		defer db3.Close()
158	}
159
160	dbt := &DBTest{t, db}
161	dbt2 := &DBTest{t, db2}
162	dbt3 := &DBTest{t, db3}
163	for _, test := range tests {
164		test(dbt)
165		dbt.db.Exec("DROP TABLE IF EXISTS test")
166		if db2 != nil {
167			test(dbt2)
168			dbt2.db.Exec("DROP TABLE IF EXISTS test")
169		}
170		if db3 != nil {
171			test(dbt3)
172			dbt3.db.Exec("DROP TABLE IF EXISTS test")
173		}
174	}
175}
176
177func (dbt *DBTest) fail(method, query string, err error) {
178	if len(query) > 300 {
179		query = "[query too large to print]"
180	}
181	dbt.Fatalf("error on %s %s: %s", method, query, err.Error())
182}
183
184func (dbt *DBTest) mustExec(query string, args ...interface{}) (res sql.Result) {
185	res, err := dbt.db.Exec(query, args...)
186	if err != nil {
187		dbt.fail("exec", query, err)
188	}
189	return res
190}
191
192func (dbt *DBTest) mustQuery(query string, args ...interface{}) (rows *sql.Rows) {
193	rows, err := dbt.db.Query(query, args...)
194	if err != nil {
195		dbt.fail("query", query, err)
196	}
197	return rows
198}
199
200func maybeSkip(t *testing.T, err error, skipErrno uint16) {
201	mySQLErr, ok := err.(*MySQLError)
202	if !ok {
203		return
204	}
205
206	if mySQLErr.Number == skipErrno {
207		t.Skipf("skipping test for error: %v", err)
208	}
209}
210
211func TestEmptyQuery(t *testing.T) {
212	runTests(t, dsn, func(dbt *DBTest) {
213		// just a comment, no query
214		rows := dbt.mustQuery("--")
215		defer rows.Close()
216		// will hang before #255
217		if rows.Next() {
218			dbt.Errorf("next on rows must be false")
219		}
220	})
221}
222
// TestCRUD walks a single-column BOOL table through create, insert, select,
// update and delete. It checks RowsAffected (and LastInsertId after the
// insert) at each step.
func TestCRUD(t *testing.T) {
	runTests(t, dsn, func(dbt *DBTest) {
		// Create Table
		dbt.mustExec("CREATE TABLE test (value BOOL)")

		// Test for unexpected data: the fresh table must be empty.
		var out bool
		rows := dbt.mustQuery("SELECT * FROM test")
		if rows.Next() {
			dbt.Error("unexpected data in empty table")
		}
		rows.Close()

		// Create Data
		res := dbt.mustExec("INSERT INTO test VALUES (1)")
		count, err := res.RowsAffected()
		if err != nil {
			dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error())
		}
		if count != 1 {
			dbt.Fatalf("expected 1 affected row, got %d", count)
		}

		// The table has no AUTO_INCREMENT column, so the insert id must be 0.
		id, err := res.LastInsertId()
		if err != nil {
			dbt.Fatalf("res.LastInsertId() returned error: %s", err.Error())
		}
		if id != 0 {
			dbt.Fatalf("expected InsertId 0, got %d", id)
		}

		// Read: exactly one row holding true.
		rows = dbt.mustQuery("SELECT value FROM test")
		if rows.Next() {
			rows.Scan(&out)
			if true != out {
				dbt.Errorf("true != %t", out)
			}

			if rows.Next() {
				dbt.Error("unexpected data")
			}
		} else {
			dbt.Error("no data")
		}
		rows.Close()

		// Update true -> false using placeholder arguments.
		res = dbt.mustExec("UPDATE test SET value = ? WHERE value = ?", false, true)
		count, err = res.RowsAffected()
		if err != nil {
			dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error())
		}
		if count != 1 {
			dbt.Fatalf("expected 1 affected row, got %d", count)
		}

		// Check Update
		rows = dbt.mustQuery("SELECT value FROM test")
		if rows.Next() {
			rows.Scan(&out)
			if false != out {
				dbt.Errorf("false != %t", out)
			}

			if rows.Next() {
				dbt.Error("unexpected data")
			}
		} else {
			dbt.Error("no data")
		}
		rows.Close()

		// Delete the single (now false) row.
		res = dbt.mustExec("DELETE FROM test WHERE value = ?", false)
		count, err = res.RowsAffected()
		if err != nil {
			dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error())
		}
		if count != 1 {
			dbt.Fatalf("expected 1 affected row, got %d", count)
		}

		// Check for unexpected rows: an unconditional DELETE must now affect nothing.
		res = dbt.mustExec("DELETE FROM test")
		count, err = res.RowsAffected()
		if err != nil {
			dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error())
		}
		if count != 0 {
			dbt.Fatalf("expected 0 affected row, got %d", count)
		}
	})
}
317
// TestMultiQuery sends several ';'-separated statements in one Exec. This
// needs multiStatements=true, hence runTestsWithMultiStatement. It then checks
// that the last UPDATE is the one that sticks.
func TestMultiQuery(t *testing.T) {
	runTestsWithMultiStatement(t, dsn, func(dbt *DBTest) {
		// Create Table
		dbt.mustExec("CREATE TABLE `test` (`id` int(11) NOT NULL, `value` int(11) NOT NULL) ")

		// Create Data
		res := dbt.mustExec("INSERT INTO test VALUES (1, 1)")
		count, err := res.RowsAffected()
		if err != nil {
			dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error())
		}
		if count != 1 {
			dbt.Fatalf("expected 1 affected row, got %d", count)
		}

		// Update: three statements in one round trip. The expected RowsAffected
		// is 1, not 3 — presumably the count reported for a single statement
		// rather than a sum; TODO confirm against the driver's multi-result handling.
		res = dbt.mustExec("UPDATE test SET value = 3 WHERE id = 1; UPDATE test SET value = 4 WHERE id = 1; UPDATE test SET value = 5 WHERE id = 1;")
		count, err = res.RowsAffected()
		if err != nil {
			dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error())
		}
		if count != 1 {
			dbt.Fatalf("expected 1 affected row, got %d", count)
		}

		// Read: the final UPDATE (value = 5) must have been applied.
		var out int
		rows := dbt.mustQuery("SELECT value FROM test WHERE id=1;")
		if rows.Next() {
			rows.Scan(&out)
			if 5 != out {
				dbt.Errorf("5 != %d", out)
			}

			if rows.Next() {
				dbt.Error("unexpected data")
			}
		} else {
			dbt.Error("no data")
		}
		rows.Close()

	})
}
362
363func TestInt(t *testing.T) {
364	runTests(t, dsn, func(dbt *DBTest) {
365		types := [5]string{"TINYINT", "SMALLINT", "MEDIUMINT", "INT", "BIGINT"}
366		in := int64(42)
367		var out int64
368		var rows *sql.Rows
369
370		// SIGNED
371		for _, v := range types {
372			dbt.mustExec("CREATE TABLE test (value " + v + ")")
373
374			dbt.mustExec("INSERT INTO test VALUES (?)", in)
375
376			rows = dbt.mustQuery("SELECT value FROM test")
377			if rows.Next() {
378				rows.Scan(&out)
379				if in != out {
380					dbt.Errorf("%s: %d != %d", v, in, out)
381				}
382			} else {
383				dbt.Errorf("%s: no data", v)
384			}
385			rows.Close()
386
387			dbt.mustExec("DROP TABLE IF EXISTS test")
388		}
389
390		// UNSIGNED ZEROFILL
391		for _, v := range types {
392			dbt.mustExec("CREATE TABLE test (value " + v + " ZEROFILL)")
393
394			dbt.mustExec("INSERT INTO test VALUES (?)", in)
395
396			rows = dbt.mustQuery("SELECT value FROM test")
397			if rows.Next() {
398				rows.Scan(&out)
399				if in != out {
400					dbt.Errorf("%s ZEROFILL: %d != %d", v, in, out)
401				}
402			} else {
403				dbt.Errorf("%s ZEROFILL: no data", v)
404			}
405			rows.Close()
406
407			dbt.mustExec("DROP TABLE IF EXISTS test")
408		}
409	})
410}
411
// TestFloat32 stores a float32 parameter in FLOAT and DOUBLE columns and
// expects exactly the same float32 when scanning the value back.
func TestFloat32(t *testing.T) {
	runTests(t, dsn, func(dbt *DBTest) {
		types := [2]string{"FLOAT", "DOUBLE"}
		in := float32(42.23)
		var out float32
		var rows *sql.Rows
		for _, v := range types {
			dbt.mustExec("CREATE TABLE test (value " + v + ")")
			dbt.mustExec("INSERT INTO test VALUES (?)", in)
			rows = dbt.mustQuery("SELECT value FROM test")
			if rows.Next() {
				rows.Scan(&out)
				if in != out {
					dbt.Errorf("%s: %g != %g", v, in, out)
				}
			} else {
				dbt.Errorf("%s: no data", v)
			}
			rows.Close()
			dbt.mustExec("DROP TABLE IF EXISTS test")
		}
	})
}
435
// TestFloat64 inserts the SQL literal 42.23 (no parameters) into FLOAT and
// DOUBLE columns and expects the float64 42.23 back for both column types.
func TestFloat64(t *testing.T) {
	runTests(t, dsn, func(dbt *DBTest) {
		types := [2]string{"FLOAT", "DOUBLE"}
		var expected float64 = 42.23
		var out float64
		var rows *sql.Rows
		for _, v := range types {
			dbt.mustExec("CREATE TABLE test (value " + v + ")")
			dbt.mustExec("INSERT INTO test VALUES (42.23)")
			rows = dbt.mustQuery("SELECT value FROM test")
			if rows.Next() {
				rows.Scan(&out)
				if expected != out {
					dbt.Errorf("%s: %g != %g", v, expected, out)
				}
			} else {
				dbt.Errorf("%s: no data", v)
			}
			rows.Close()
			dbt.mustExec("DROP TABLE IF EXISTS test")
		}
	})
}
459
// TestFloat64Placeholder is TestFloat64 with a placeholder in the SELECT
// (`WHERE id = ?`). Presumably this makes pools without interpolateParams read
// the value over the binary protocol rather than the text protocol — confirm.
func TestFloat64Placeholder(t *testing.T) {
	runTests(t, dsn, func(dbt *DBTest) {
		types := [2]string{"FLOAT", "DOUBLE"}
		var expected float64 = 42.23
		var out float64
		var rows *sql.Rows
		for _, v := range types {
			dbt.mustExec("CREATE TABLE test (id int, value " + v + ")")
			dbt.mustExec("INSERT INTO test VALUES (1, 42.23)")
			rows = dbt.mustQuery("SELECT value FROM test WHERE id = ?", 1)
			if rows.Next() {
				rows.Scan(&out)
				if expected != out {
					dbt.Errorf("%s: %g != %g", v, expected, out)
				}
			} else {
				dbt.Errorf("%s: no data", v)
			}
			rows.Close()
			dbt.mustExec("DROP TABLE IF EXISTS test")
		}
	})
}
483
// TestString round-trips text through the string column types. First, a
// multi-script UTF-8 sample goes through each CHAR/VARCHAR/TEXT type of a
// utf8 table. Then a long ASCII paragraph goes through a BLOB column, read
// back with QueryRow.
func TestString(t *testing.T) {
	runTests(t, dsn, func(dbt *DBTest) {
		types := [6]string{"CHAR(255)", "VARCHAR(255)", "TINYTEXT", "TEXT", "MEDIUMTEXT", "LONGTEXT"}
		in := "κόσμε üöäßñóùéàâÿœ'îë Árvíztűrő いろはにほへとちりぬるを イロハニホヘト דג סקרן чащах  น่าฟังเอย"
		var out string
		var rows *sql.Rows

		for _, v := range types {
			dbt.mustExec("CREATE TABLE test (value " + v + ") CHARACTER SET utf8")

			dbt.mustExec("INSERT INTO test VALUES (?)", in)

			rows = dbt.mustQuery("SELECT value FROM test")
			if rows.Next() {
				rows.Scan(&out)
				if in != out {
					dbt.Errorf("%s: %s != %s", v, in, out)
				}
			} else {
				dbt.Errorf("%s: no data", v)
			}
			rows.Close()

			dbt.mustExec("DROP TABLE IF EXISTS test")
		}

		// BLOB: a longer value, looked up by id through a placeholder.
		dbt.mustExec("CREATE TABLE test (id int, value BLOB) CHARACTER SET utf8")

		id := 2
		in = "Lorem ipsum dolor sit amet, consetetur sadipscing elitr, " +
			"sed diam nonumy eirmod tempor invidunt ut labore et dolore magna aliquyam erat, " +
			"sed diam voluptua. At vero eos et accusam et justo duo dolores et ea rebum. " +
			"Stet clita kasd gubergren, no sea takimata sanctus est Lorem ipsum dolor sit amet. " +
			"Lorem ipsum dolor sit amet, consetetur sadipscing elitr, " +
			"sed diam nonumy eirmod tempor invidunt ut labore et dolore magna aliquyam erat, " +
			"sed diam voluptua. At vero eos et accusam et justo duo dolores et ea rebum. " +
			"Stet clita kasd gubergren, no sea takimata sanctus est Lorem ipsum dolor sit amet."
		dbt.mustExec("INSERT INTO test VALUES (?, ?)", id, in)

		err := dbt.db.QueryRow("SELECT value FROM test WHERE id = ?", id).Scan(&out)
		if err != nil {
			dbt.Fatalf("Error on BLOB-Query: %s", err.Error())
		} else if out != in {
			dbt.Errorf("BLOB: %s != %s", in, out)
		}
	})
}
532
// TestRawBytes echoes two []byte parameters back and scans them into
// sql.RawBytes. It also checks that the two RawBytes do not alias each other
// (issue #765).
func TestRawBytes(t *testing.T) {
	runTests(t, dsn, func(dbt *DBTest) {
		v1 := []byte("aaa")
		v2 := []byte("bbb")
		rows := dbt.mustQuery("SELECT ?, ?", v1, v2)
		defer rows.Close()
		if rows.Next() {
			var o1, o2 sql.RawBytes
			if err := rows.Scan(&o1, &o2); err != nil {
				dbt.Errorf("Got error: %v", err)
			}
			if !bytes.Equal(v1, o1) {
				dbt.Errorf("expected %v, got %v", v1, o1)
			}
			if !bytes.Equal(v2, o2) {
				dbt.Errorf("expected %v, got %v", v2, o2)
			}
			// https://github.com/go-sql-driver/mysql/issues/765
			// Appending to RawBytes shouldn't overwrite next RawBytes.
			// The assignment to o1 is deliberately never read again: only the
			// side effect of append on the shared backing array is under test.
			o1 = append(o1, "xyzzy"...)
			if !bytes.Equal(v2, o2) {
				dbt.Errorf("expected %v, got %v", v2, o2)
			}
		} else {
			dbt.Errorf("no data")
		}
	})
}
561
// testValuer is a minimal driver.Valuer that hands its string field to the
// driver unchanged.
type testValuer struct {
	value string
}

// Value implements driver.Valuer. It never returns an error.
func (v testValuer) Value() (driver.Value, error) {
	var dv driver.Value = v.value
	return dv, nil
}
569
// TestValuer checks that a driver.Valuer argument (testValuer) is stored
// through its Value result and can be read back as a plain string.
func TestValuer(t *testing.T) {
	runTests(t, dsn, func(dbt *DBTest) {
		in := testValuer{"a_value"}
		var out string
		var rows *sql.Rows

		dbt.mustExec("CREATE TABLE test (value VARCHAR(255)) CHARACTER SET utf8")
		dbt.mustExec("INSERT INTO test VALUES (?)", in)
		rows = dbt.mustQuery("SELECT value FROM test")
		if rows.Next() {
			rows.Scan(&out)
			if in.value != out {
				dbt.Errorf("Valuer: %v != %s", in, out)
			}
		} else {
			dbt.Errorf("Valuer: no data")
		}
		rows.Close()

		// runTests drops "test" after each run as well; this explicit drop is
		// redundant but harmless.
		dbt.mustExec("DROP TABLE IF EXISTS test")
	})
}
592
// testValuerWithValidation is a driver.Valuer that rejects an empty string.
// Its Value method returns an error instead of a value in that case.
type testValuerWithValidation struct {
	value string
}

// Value implements driver.Valuer. It returns the string field, or an error
// when the field is empty.
func (v testValuerWithValidation) Value() (driver.Value, error) {
	if v.value != "" {
		return v.value, nil
	}
	return nil, fmt.Errorf("Invalid string valuer. Value must not be empty")
}
604
605func TestValuerWithValidation(t *testing.T) {
606	runTests(t, dsn, func(dbt *DBTest) {
607		in := testValuerWithValidation{"a_value"}
608		var out string
609		var rows *sql.Rows
610
611		dbt.mustExec("CREATE TABLE testValuer (value VARCHAR(255)) CHARACTER SET utf8")
612		dbt.mustExec("INSERT INTO testValuer VALUES (?)", in)
613
614		rows = dbt.mustQuery("SELECT value FROM testValuer")
615		defer rows.Close()
616
617		if rows.Next() {
618			rows.Scan(&out)
619			if in.value != out {
620				dbt.Errorf("Valuer: %v != %s", in, out)
621			}
622		} else {
623			dbt.Errorf("Valuer: no data")
624		}
625
626		if _, err := dbt.db.Exec("INSERT INTO testValuer VALUES (?)", testValuerWithValidation{""}); err == nil {
627			dbt.Errorf("Failed to check valuer error")
628		}
629
630		if _, err := dbt.db.Exec("INSERT INTO testValuer VALUES (?)", nil); err != nil {
631			dbt.Errorf("Failed to check nil")
632		}
633
634		if _, err := dbt.db.Exec("INSERT INTO testValuer VALUES (?)", map[string]bool{}); err == nil {
635			dbt.Errorf("Failed to check not valuer")
636		}
637
638		dbt.mustExec("DROP TABLE IF EXISTS testValuer")
639	})
640}
641
// timeTests groups the cases for one MySQL type. dbtype is the type name used
// in CAST(... as dbtype); tlayout is the Go time layout that renders the
// expected string for that type.
type timeTests struct {
	dbtype  string
	tlayout string
	tests   []timeTest
}

// timeTest is a single round-trip case: s is the expected string form and t
// the expected time.Time.
type timeTest struct {
	s string // leading "!": do not use t as value in queries
	t time.Time
}

// timeMode selects how a timeTest value is sent with the query.
type timeMode byte

// String returns a short label for the mode. It panics on unknown values.
func (m timeMode) String() string {
	if label, ok := timeModeLabels[m]; ok {
		return label
	}
	panic("unsupported timeMode")
}

// Binary reports whether the mode passes the value as a "?" query argument
// rather than inlining it into the query text.
func (m timeMode) Binary() bool {
	return m == binaryString || m == binaryTime
}

const (
	binaryString timeMode = iota // value passed as a string argument
	binaryTime                   // value passed as a time.Time argument
	textString                   // value inlined as a quoted string literal
)

// timeModeLabels maps every supported timeMode to its String() label.
var timeModeLabels = map[timeMode]string{
	binaryString: "binary:string",
	binaryTime:   "binary:time.Time",
	textString:   "text:string",
}

// genQuery builds "SELECT cast(<operand> as <dbtype>)". For argument-passing
// modes the operand is a "?" placeholder. For text mode it is a quoted "%s"
// verb that the caller fills in with fmt.Sprintf.
func (timeTest) genQuery(dbtype string, mode timeMode) string {
	operand := `"%s"`
	if mode.Binary() {
		operand = "?"
	}
	return "SELECT cast(" + operand + " as " + dbtype + ")"
}
690
691func (t timeTest) run(dbt *DBTest, dbtype, tlayout string, mode timeMode) {
692	var rows *sql.Rows
693	query := t.genQuery(dbtype, mode)
694	switch mode {
695	case binaryString:
696		rows = dbt.mustQuery(query, t.s)
697	case binaryTime:
698		rows = dbt.mustQuery(query, t.t)
699	case textString:
700		query = fmt.Sprintf(query, t.s)
701		rows = dbt.mustQuery(query)
702	default:
703		panic("unsupported mode")
704	}
705	defer rows.Close()
706	var err error
707	if !rows.Next() {
708		err = rows.Err()
709		if err == nil {
710			err = fmt.Errorf("no data")
711		}
712		dbt.Errorf("%s [%s]: %s", dbtype, mode, err)
713		return
714	}
715	var dst interface{}
716	err = rows.Scan(&dst)
717	if err != nil {
718		dbt.Errorf("%s [%s]: %s", dbtype, mode, err)
719		return
720	}
721	switch val := dst.(type) {
722	case []uint8:
723		str := string(val)
724		if str == t.s {
725			return
726		}
727		if mode.Binary() && dbtype == "DATETIME" && len(str) == 26 && str[:19] == t.s {
728			// a fix mainly for TravisCI:
729			// accept full microsecond resolution in result for DATETIME columns
730			// where the binary protocol was used
731			return
732		}
733		dbt.Errorf("%s [%s] to string: expected %q, got %q",
734			dbtype, mode,
735			t.s, str,
736		)
737	case time.Time:
738		if val == t.t {
739			return
740		}
741		dbt.Errorf("%s [%s] to string: expected %q, got %q",
742			dbtype, mode,
743			t.s, val.Format(tlayout),
744		)
745	default:
746		fmt.Printf("%#v\n", []interface{}{dbtype, tlayout, mode, t.s, t.t})
747		dbt.Errorf("%s [%s]: unhandled type %T (is '%v')",
748			dbtype, mode,
749			val, val,
750		)
751	}
752}
753
754func TestDateTime(t *testing.T) {
755	afterTime := func(t time.Time, d string) time.Time {
756		dur, err := time.ParseDuration(d)
757		if err != nil {
758			panic(err)
759		}
760		return t.Add(dur)
761	}
762	// NOTE: MySQL rounds DATETIME(x) up - but that's not included in the tests
763	format := "2006-01-02 15:04:05.999999"
764	t0 := time.Time{}
765	tstr0 := "0000-00-00 00:00:00.000000"
766	testcases := []timeTests{
767		{"DATE", format[:10], []timeTest{
768			{t: time.Date(2011, 11, 20, 0, 0, 0, 0, time.UTC)},
769			{t: t0, s: tstr0[:10]},
770		}},
771		{"DATETIME", format[:19], []timeTest{
772			{t: time.Date(2011, 11, 20, 21, 27, 37, 0, time.UTC)},
773			{t: t0, s: tstr0[:19]},
774		}},
775		{"DATETIME(0)", format[:21], []timeTest{
776			{t: time.Date(2011, 11, 20, 21, 27, 37, 0, time.UTC)},
777			{t: t0, s: tstr0[:19]},
778		}},
779		{"DATETIME(1)", format[:21], []timeTest{
780			{t: time.Date(2011, 11, 20, 21, 27, 37, 100000000, time.UTC)},
781			{t: t0, s: tstr0[:21]},
782		}},
783		{"DATETIME(6)", format, []timeTest{
784			{t: time.Date(2011, 11, 20, 21, 27, 37, 123456000, time.UTC)},
785			{t: t0, s: tstr0},
786		}},
787		{"TIME", format[11:19], []timeTest{
788			{t: afterTime(t0, "12345s")},
789			{s: "!-12:34:56"},
790			{s: "!-838:59:59"},
791			{s: "!838:59:59"},
792			{t: t0, s: tstr0[11:19]},
793		}},
794		{"TIME(0)", format[11:19], []timeTest{
795			{t: afterTime(t0, "12345s")},
796			{s: "!-12:34:56"},
797			{s: "!-838:59:59"},
798			{s: "!838:59:59"},
799			{t: t0, s: tstr0[11:19]},
800		}},
801		{"TIME(1)", format[11:21], []timeTest{
802			{t: afterTime(t0, "12345600ms")},
803			{s: "!-12:34:56.7"},
804			{s: "!-838:59:58.9"},
805			{s: "!838:59:58.9"},
806			{t: t0, s: tstr0[11:21]},
807		}},
808		{"TIME(6)", format[11:], []timeTest{
809			{t: afterTime(t0, "1234567890123000ns")},
810			{s: "!-12:34:56.789012"},
811			{s: "!-838:59:58.999999"},
812			{s: "!838:59:58.999999"},
813			{t: t0, s: tstr0[11:]},
814		}},
815	}
816	dsns := []string{
817		dsn + "&parseTime=true",
818		dsn + "&parseTime=false",
819	}
820	for _, testdsn := range dsns {
821		runTests(t, testdsn, func(dbt *DBTest) {
822			microsecsSupported := false
823			zeroDateSupported := false
824			var rows *sql.Rows
825			var err error
826			rows, err = dbt.db.Query(`SELECT cast("00:00:00.1" as TIME(1)) = "00:00:00.1"`)
827			if err == nil {
828				rows.Scan(&microsecsSupported)
829				rows.Close()
830			}
831			rows, err = dbt.db.Query(`SELECT cast("0000-00-00" as DATE) = "0000-00-00"`)
832			if err == nil {
833				rows.Scan(&zeroDateSupported)
834				rows.Close()
835			}
836			for _, setups := range testcases {
837				if t := setups.dbtype; !microsecsSupported && t[len(t)-1:] == ")" {
838					// skip fractional second tests if unsupported by server
839					continue
840				}
841				for _, setup := range setups.tests {
842					allowBinTime := true
843					if setup.s == "" {
844						// fill time string wherever Go can reliable produce it
845						setup.s = setup.t.Format(setups.tlayout)
846					} else if setup.s[0] == '!' {
847						// skip tests using setup.t as source in queries
848						allowBinTime = false
849						// fix setup.s - remove the "!"
850						setup.s = setup.s[1:]
851					}
852					if !zeroDateSupported && setup.s == tstr0[:len(setup.s)] {
853						// skip disallowed 0000-00-00 date
854						continue
855					}
856					setup.run(dbt, setups.dbtype, setups.tlayout, textString)
857					setup.run(dbt, setups.dbtype, setups.tlayout, binaryString)
858					if allowBinTime {
859						setup.run(dbt, setups.dbtype, setups.tlayout, binaryTime)
860					}
861				}
862			}
863		})
864	}
865}
866
867func TestTimestampMicros(t *testing.T) {
868	format := "2006-01-02 15:04:05.999999"
869	f0 := format[:19]
870	f1 := format[:21]
871	f6 := format[:26]
872	runTests(t, dsn, func(dbt *DBTest) {
873		// check if microseconds are supported.
874		// Do not use timestamp(x) for that check - before 5.5.6, x would mean display width
875		// and not precision.
876		// Se last paragraph at http://dev.mysql.com/doc/refman/5.6/en/fractional-seconds.html
877		microsecsSupported := false
878		if rows, err := dbt.db.Query(`SELECT cast("00:00:00.1" as TIME(1)) = "00:00:00.1"`); err == nil {
879			rows.Scan(&microsecsSupported)
880			rows.Close()
881		}
882		if !microsecsSupported {
883			// skip test
884			return
885		}
886		_, err := dbt.db.Exec(`
887			CREATE TABLE test (
888				value0 TIMESTAMP NOT NULL DEFAULT '` + f0 + `',
889				value1 TIMESTAMP(1) NOT NULL DEFAULT '` + f1 + `',
890				value6 TIMESTAMP(6) NOT NULL DEFAULT '` + f6 + `'
891			)`,
892		)
893		if err != nil {
894			dbt.Error(err)
895		}
896		defer dbt.mustExec("DROP TABLE IF EXISTS test")
897		dbt.mustExec("INSERT INTO test SET value0=?, value1=?, value6=?", f0, f1, f6)
898		var res0, res1, res6 string
899		rows := dbt.mustQuery("SELECT * FROM test")
900		defer rows.Close()
901		if !rows.Next() {
902			dbt.Errorf("test contained no selectable values")
903		}
904		err = rows.Scan(&res0, &res1, &res6)
905		if err != nil {
906			dbt.Error(err)
907		}
908		if res0 != f0 {
909			dbt.Errorf("expected %q, got %q", f0, res0)
910		}
911		if res1 != f1 {
912			dbt.Errorf("expected %q, got %q", f1, res1)
913		}
914		if res6 != f6 {
915			dbt.Errorf("expected %q, got %q", f6, res6)
916		}
917	})
918}
919
920func TestNULL(t *testing.T) {
921	runTests(t, dsn, func(dbt *DBTest) {
922		nullStmt, err := dbt.db.Prepare("SELECT NULL")
923		if err != nil {
924			dbt.Fatal(err)
925		}
926		defer nullStmt.Close()
927
928		nonNullStmt, err := dbt.db.Prepare("SELECT 1")
929		if err != nil {
930			dbt.Fatal(err)
931		}
932		defer nonNullStmt.Close()
933
934		// NullBool
935		var nb sql.NullBool
936		// Invalid
937		if err = nullStmt.QueryRow().Scan(&nb); err != nil {
938			dbt.Fatal(err)
939		}
940		if nb.Valid {
941			dbt.Error("valid NullBool which should be invalid")
942		}
943		// Valid
944		if err = nonNullStmt.QueryRow().Scan(&nb); err != nil {
945			dbt.Fatal(err)
946		}
947		if !nb.Valid {
948			dbt.Error("invalid NullBool which should be valid")
949		} else if nb.Bool != true {
950			dbt.Errorf("Unexpected NullBool value: %t (should be true)", nb.Bool)
951		}
952
953		// NullFloat64
954		var nf sql.NullFloat64
955		// Invalid
956		if err = nullStmt.QueryRow().Scan(&nf); err != nil {
957			dbt.Fatal(err)
958		}
959		if nf.Valid {
960			dbt.Error("valid NullFloat64 which should be invalid")
961		}
962		// Valid
963		if err = nonNullStmt.QueryRow().Scan(&nf); err != nil {
964			dbt.Fatal(err)
965		}
966		if !nf.Valid {
967			dbt.Error("invalid NullFloat64 which should be valid")
968		} else if nf.Float64 != float64(1) {
969			dbt.Errorf("unexpected NullFloat64 value: %f (should be 1.0)", nf.Float64)
970		}
971
972		// NullInt64
973		var ni sql.NullInt64
974		// Invalid
975		if err = nullStmt.QueryRow().Scan(&ni); err != nil {
976			dbt.Fatal(err)
977		}
978		if ni.Valid {
979			dbt.Error("valid NullInt64 which should be invalid")
980		}
981		// Valid
982		if err = nonNullStmt.QueryRow().Scan(&ni); err != nil {
983			dbt.Fatal(err)
984		}
985		if !ni.Valid {
986			dbt.Error("invalid NullInt64 which should be valid")
987		} else if ni.Int64 != int64(1) {
988			dbt.Errorf("unexpected NullInt64 value: %d (should be 1)", ni.Int64)
989		}
990
991		// NullString
992		var ns sql.NullString
993		// Invalid
994		if err = nullStmt.QueryRow().Scan(&ns); err != nil {
995			dbt.Fatal(err)
996		}
997		if ns.Valid {
998			dbt.Error("valid NullString which should be invalid")
999		}
1000		// Valid
1001		if err = nonNullStmt.QueryRow().Scan(&ns); err != nil {
1002			dbt.Fatal(err)
1003		}
1004		if !ns.Valid {
1005			dbt.Error("invalid NullString which should be valid")
1006		} else if ns.String != `1` {
1007			dbt.Error("unexpected NullString value:" + ns.String + " (should be `1`)")
1008		}
1009
1010		// nil-bytes
1011		var b []byte
1012		// Read nil
1013		if err = nullStmt.QueryRow().Scan(&b); err != nil {
1014			dbt.Fatal(err)
1015		}
1016		if b != nil {
1017			dbt.Error("non-nil []byte which should be nil")
1018		}
1019		// Read non-nil
1020		if err = nonNullStmt.QueryRow().Scan(&b); err != nil {
1021			dbt.Fatal(err)
1022		}
1023		if b == nil {
1024			dbt.Error("nil []byte which should be non-nil")
1025		}
1026		// Insert nil
1027		b = nil
1028		success := false
1029		if err = dbt.db.QueryRow("SELECT ? IS NULL", b).Scan(&success); err != nil {
1030			dbt.Fatal(err)
1031		}
1032		if !success {
1033			dbt.Error("inserting []byte(nil) as NULL failed")
1034		}
1035		// Check input==output with input==nil
1036		b = nil
1037		if err = dbt.db.QueryRow("SELECT ?", b).Scan(&b); err != nil {
1038			dbt.Fatal(err)
1039		}
1040		if b != nil {
1041			dbt.Error("non-nil echo from nil input")
1042		}
1043		// Check input==output with input!=nil
1044		b = []byte("")
1045		if err = dbt.db.QueryRow("SELECT ?", b).Scan(&b); err != nil {
1046			dbt.Fatal(err)
1047		}
1048		if b == nil {
1049			dbt.Error("nil echo from non-nil input")
1050		}
1051
1052		// Insert NULL
1053		dbt.mustExec("CREATE TABLE test (dummmy1 int, value int, dummy2 int)")
1054
1055		dbt.mustExec("INSERT INTO test VALUES (?, ?, ?)", 1, nil, 2)
1056
1057		var out interface{}
1058		rows := dbt.mustQuery("SELECT * FROM test")
1059		defer rows.Close()
1060		if rows.Next() {
1061			rows.Scan(&out)
1062			if out != nil {
1063				dbt.Errorf("%v != nil", out)
1064			}
1065		} else {
1066			dbt.Error("no data")
1067		}
1068	})
1069}
1070
// TestUint64 echoes boundary values of uint64 and int64 through a prepared
// statement and checks that each one comes back bit-for-bit.
func TestUint64(t *testing.T) {
	const (
		u0    = uint64(0)
		uall  = ^u0         // all 64 bits set (math.MaxUint64)
		uhigh = uall >> 1   // all bits except the top one (math.MaxInt64 as uint64)
		utop  = ^uhigh      // only the top bit set
		s0    = int64(0)
		sall  = ^s0         // -1
		shigh = int64(uhigh) // math.MaxInt64
		stop  = ^shigh      // math.MinInt64
	)
	runTests(t, dsn, func(dbt *DBTest) {
		stmt, err := dbt.db.Prepare(`SELECT ?, ?, ? ,?, ?, ?, ?, ?`)
		if err != nil {
			dbt.Fatal(err)
		}
		defer stmt.Close()
		row := stmt.QueryRow(
			u0, uhigh, utop, uall,
			s0, shigh, stop, sall,
		)

		var ua, ub, uc, ud uint64
		var sa, sb, sc, sd int64

		err = row.Scan(&ua, &ub, &uc, &ud, &sa, &sb, &sc, &sd)
		if err != nil {
			dbt.Fatal(err)
		}
		// Any mismatch among the eight values fails the test.
		switch {
		case ua != u0,
			ub != uhigh,
			uc != utop,
			ud != uall,
			sa != s0,
			sb != shigh,
			sc != stop,
			sd != sall:
			dbt.Fatal("unexpected result value")
		}
	})
}
1113
// TestLongData writes values close to the server's max_allowed_packet (capped
// at 32 MiB) into a LONGBLOB column, first inlined in the query text and then
// as a placeholder argument, and reads each one back intact.
func TestLongData(t *testing.T) {
	// NOTE(review): maxAllowedPacket=0 presumably tells the driver to take the
	// limit from the server instead of a fixed default — confirm in the DSN docs.
	runTests(t, dsn+"&maxAllowedPacket=0", func(dbt *DBTest) {
		var maxAllowedPacketSize int
		err := dbt.db.QueryRow("select @@max_allowed_packet").Scan(&maxAllowedPacketSize)
		if err != nil {
			dbt.Fatal(err)
		}
		// Stay one byte below the server limit (presumably for packet framing — TODO confirm).
		maxAllowedPacketSize--

		// don't get too ambitious
		if maxAllowedPacketSize > 1<<25 {
			maxAllowedPacketSize = 1 << 25
		}

		dbt.mustExec("CREATE TABLE test (value LONGBLOB)")

		in := strings.Repeat(`a`, maxAllowedPacketSize+1)
		var out string
		var rows *sql.Rows

		// Long text data
		// nonDataQueryLen: bytes of the INSERT besides the value. The SQL text
		// `INSERT INTO test VALUES('')` is 27 characters; the extra byte is
		// presumably the command byte of the packet — TODO confirm.
		const nonDataQueryLen = 28 // length query w/o value
		inS := in[:maxAllowedPacketSize-nonDataQueryLen]
		dbt.mustExec("INSERT INTO test VALUES('" + inS + "')")
		rows = dbt.mustQuery("SELECT value FROM test")
		defer rows.Close()
		if rows.Next() {
			rows.Scan(&out)
			if inS != out {
				dbt.Fatalf("LONGBLOB: length in: %d, length out: %d", len(inS), len(out))
			}
			if rows.Next() {
				dbt.Error("LONGBLOB: unexpexted row")
			}
		} else {
			dbt.Fatalf("LONGBLOB: no data")
		}

		// Empty table
		dbt.mustExec("TRUNCATE TABLE test")

		// Long binary data: the value is one byte longer than the text-query
		// case and is sent as an argument; the placeholder in the SELECT
		// exercises the argument path on the read side as well.
		dbt.mustExec("INSERT INTO test VALUES(?)", in)
		rows = dbt.mustQuery("SELECT value FROM test WHERE 1=?", 1)
		defer rows.Close()
		if rows.Next() {
			rows.Scan(&out)
			if in != out {
				dbt.Fatalf("LONGBLOB: length in: %d, length out: %d", len(in), len(out))
			}
			if rows.Next() {
				dbt.Error("LONGBLOB: unexpexted row")
			}
		} else {
			if err = rows.Err(); err != nil {
				dbt.Fatalf("LONGBLOB: no data (err: %s)", err.Error())
			} else {
				dbt.Fatal("LONGBLOB: no data (err: <nil>)")
			}
		}
	})
}
1176
1177func TestLoadData(t *testing.T) {
1178	runTests(t, dsn, func(dbt *DBTest) {
1179		verifyLoadDataResult := func() {
1180			rows, err := dbt.db.Query("SELECT * FROM test")
1181			if err != nil {
1182				dbt.Fatal(err.Error())
1183			}
1184
1185			i := 0
1186			values := [4]string{
1187				"a string",
1188				"a string containing a \t",
1189				"a string containing a \n",
1190				"a string containing both \t\n",
1191			}
1192
1193			var id int
1194			var value string
1195
1196			for rows.Next() {
1197				i++
1198				err = rows.Scan(&id, &value)
1199				if err != nil {
1200					dbt.Fatal(err.Error())
1201				}
1202				if i != id {
1203					dbt.Fatalf("%d != %d", i, id)
1204				}
1205				if values[i-1] != value {
1206					dbt.Fatalf("%q != %q", values[i-1], value)
1207				}
1208			}
1209			err = rows.Err()
1210			if err != nil {
1211				dbt.Fatal(err.Error())
1212			}
1213
1214			if i != 4 {
1215				dbt.Fatalf("rows count mismatch. Got %d, want 4", i)
1216			}
1217		}
1218
1219		dbt.db.Exec("DROP TABLE IF EXISTS test")
1220		dbt.mustExec("CREATE TABLE test (id INT NOT NULL PRIMARY KEY, value TEXT NOT NULL) CHARACTER SET utf8")
1221
1222		// Local File
1223		file, err := ioutil.TempFile("", "gotest")
1224		defer os.Remove(file.Name())
1225		if err != nil {
1226			dbt.Fatal(err)
1227		}
1228		RegisterLocalFile(file.Name())
1229
1230		// Try first with empty file
1231		dbt.mustExec(fmt.Sprintf("LOAD DATA LOCAL INFILE %q INTO TABLE test", file.Name()))
1232		var count int
1233		err = dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&count)
1234		if err != nil {
1235			dbt.Fatal(err.Error())
1236		}
1237		if count != 0 {
1238			dbt.Fatalf("unexpected row count: got %d, want 0", count)
1239		}
1240
1241		// Then fille File with data and try to load it
1242		file.WriteString("1\ta string\n2\ta string containing a \\t\n3\ta string containing a \\n\n4\ta string containing both \\t\\n\n")
1243		file.Close()
1244		dbt.mustExec(fmt.Sprintf("LOAD DATA LOCAL INFILE %q INTO TABLE test", file.Name()))
1245		verifyLoadDataResult()
1246
1247		// Try with non-existing file
1248		_, err = dbt.db.Exec("LOAD DATA LOCAL INFILE 'doesnotexist' INTO TABLE test")
1249		if err == nil {
1250			dbt.Fatal("load non-existent file didn't fail")
1251		} else if err.Error() != "local file 'doesnotexist' is not registered" {
1252			dbt.Fatal(err.Error())
1253		}
1254
1255		// Empty table
1256		dbt.mustExec("TRUNCATE TABLE test")
1257
1258		// Reader
1259		RegisterReaderHandler("test", func() io.Reader {
1260			file, err = os.Open(file.Name())
1261			if err != nil {
1262				dbt.Fatal(err)
1263			}
1264			return file
1265		})
1266		dbt.mustExec("LOAD DATA LOCAL INFILE 'Reader::test' INTO TABLE test")
1267		verifyLoadDataResult()
1268		// negative test
1269		_, err = dbt.db.Exec("LOAD DATA LOCAL INFILE 'Reader::doesnotexist' INTO TABLE test")
1270		if err == nil {
1271			dbt.Fatal("load non-existent Reader didn't fail")
1272		} else if err.Error() != "Reader 'doesnotexist' is not registered" {
1273			dbt.Fatal(err.Error())
1274		}
1275	})
1276}
1277
1278func TestFoundRows(t *testing.T) {
1279	runTests(t, dsn, func(dbt *DBTest) {
1280		dbt.mustExec("CREATE TABLE test (id INT NOT NULL ,data INT NOT NULL)")
1281		dbt.mustExec("INSERT INTO test (id, data) VALUES (0, 0),(0, 0),(1, 0),(1, 0),(1, 1)")
1282
1283		res := dbt.mustExec("UPDATE test SET data = 1 WHERE id = 0")
1284		count, err := res.RowsAffected()
1285		if err != nil {
1286			dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error())
1287		}
1288		if count != 2 {
1289			dbt.Fatalf("Expected 2 affected rows, got %d", count)
1290		}
1291		res = dbt.mustExec("UPDATE test SET data = 1 WHERE id = 1")
1292		count, err = res.RowsAffected()
1293		if err != nil {
1294			dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error())
1295		}
1296		if count != 2 {
1297			dbt.Fatalf("Expected 2 affected rows, got %d", count)
1298		}
1299	})
1300	runTests(t, dsn+"&clientFoundRows=true", func(dbt *DBTest) {
1301		dbt.mustExec("CREATE TABLE test (id INT NOT NULL ,data INT NOT NULL)")
1302		dbt.mustExec("INSERT INTO test (id, data) VALUES (0, 0),(0, 0),(1, 0),(1, 0),(1, 1)")
1303
1304		res := dbt.mustExec("UPDATE test SET data = 1 WHERE id = 0")
1305		count, err := res.RowsAffected()
1306		if err != nil {
1307			dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error())
1308		}
1309		if count != 2 {
1310			dbt.Fatalf("Expected 2 matched rows, got %d", count)
1311		}
1312		res = dbt.mustExec("UPDATE test SET data = 1 WHERE id = 1")
1313		count, err = res.RowsAffected()
1314		if err != nil {
1315			dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error())
1316		}
1317		if count != 3 {
1318			dbt.Fatalf("Expected 3 matched rows, got %d", count)
1319		}
1320	})
1321}
1322
1323func TestTLS(t *testing.T) {
1324	tlsTestReq := func(dbt *DBTest) {
1325		if err := dbt.db.Ping(); err != nil {
1326			if err == ErrNoTLS {
1327				dbt.Skip("server does not support TLS")
1328			} else {
1329				dbt.Fatalf("error on Ping: %s", err.Error())
1330			}
1331		}
1332
1333		rows := dbt.mustQuery("SHOW STATUS LIKE 'Ssl_cipher'")
1334		defer rows.Close()
1335
1336		var variable, value *sql.RawBytes
1337		for rows.Next() {
1338			if err := rows.Scan(&variable, &value); err != nil {
1339				dbt.Fatal(err.Error())
1340			}
1341
1342			if (*value == nil) || (len(*value) == 0) {
1343				dbt.Fatalf("no Cipher")
1344			} else {
1345				dbt.Logf("Cipher: %s", *value)
1346			}
1347		}
1348	}
1349	tlsTestOpt := func(dbt *DBTest) {
1350		if err := dbt.db.Ping(); err != nil {
1351			dbt.Fatalf("error on Ping: %s", err.Error())
1352		}
1353	}
1354
1355	runTests(t, dsn+"&tls=preferred", tlsTestOpt)
1356	runTests(t, dsn+"&tls=skip-verify", tlsTestReq)
1357
1358	// Verify that registering / using a custom cfg works
1359	RegisterTLSConfig("custom-skip-verify", &tls.Config{
1360		InsecureSkipVerify: true,
1361	})
1362	runTests(t, dsn+"&tls=custom-skip-verify", tlsTestReq)
1363}
1364
1365func TestReuseClosedConnection(t *testing.T) {
1366	// this test does not use sql.database, it uses the driver directly
1367	if !available {
1368		t.Skipf("MySQL server not running on %s", netAddr)
1369	}
1370
1371	md := &MySQLDriver{}
1372	conn, err := md.Open(dsn)
1373	if err != nil {
1374		t.Fatalf("error connecting: %s", err.Error())
1375	}
1376	stmt, err := conn.Prepare("DO 1")
1377	if err != nil {
1378		t.Fatalf("error preparing statement: %s", err.Error())
1379	}
1380	_, err = stmt.Exec(nil)
1381	if err != nil {
1382		t.Fatalf("error executing statement: %s", err.Error())
1383	}
1384	err = conn.Close()
1385	if err != nil {
1386		t.Fatalf("error closing connection: %s", err.Error())
1387	}
1388
1389	defer func() {
1390		if err := recover(); err != nil {
1391			t.Errorf("panic after reusing a closed connection: %v", err)
1392		}
1393	}()
1394	_, err = stmt.Exec(nil)
1395	if err != nil && err != driver.ErrBadConn {
1396		t.Errorf("unexpected error '%s', expected '%s'",
1397			err.Error(), driver.ErrBadConn.Error())
1398	}
1399}
1400
1401func TestCharset(t *testing.T) {
1402	if !available {
1403		t.Skipf("MySQL server not running on %s", netAddr)
1404	}
1405
1406	mustSetCharset := func(charsetParam, expected string) {
1407		runTests(t, dsn+"&"+charsetParam, func(dbt *DBTest) {
1408			rows := dbt.mustQuery("SELECT @@character_set_connection")
1409			defer rows.Close()
1410
1411			if !rows.Next() {
1412				dbt.Fatalf("error getting connection charset: %s", rows.Err())
1413			}
1414
1415			var got string
1416			rows.Scan(&got)
1417
1418			if got != expected {
1419				dbt.Fatalf("expected connection charset %s but got %s", expected, got)
1420			}
1421		})
1422	}
1423
1424	// non utf8 test
1425	mustSetCharset("charset=ascii", "ascii")
1426
1427	// when the first charset is invalid, use the second
1428	mustSetCharset("charset=none,utf8", "utf8")
1429
1430	// when the first charset is valid, use it
1431	mustSetCharset("charset=ascii,utf8", "ascii")
1432	mustSetCharset("charset=utf8,ascii", "utf8")
1433}
1434
1435func TestFailingCharset(t *testing.T) {
1436	runTests(t, dsn+"&charset=none", func(dbt *DBTest) {
1437		// run query to really establish connection...
1438		_, err := dbt.db.Exec("SELECT 1")
1439		if err == nil {
1440			dbt.db.Close()
1441			t.Fatalf("connection must not succeed without a valid charset")
1442		}
1443	})
1444}
1445
1446func TestCollation(t *testing.T) {
1447	if !available {
1448		t.Skipf("MySQL server not running on %s", netAddr)
1449	}
1450
1451	defaultCollation := "utf8mb4_general_ci"
1452	testCollations := []string{
1453		"",               // do not set
1454		defaultCollation, // driver default
1455		"latin1_general_ci",
1456		"binary",
1457		"utf8_unicode_ci",
1458		"cp1257_bin",
1459	}
1460
1461	for _, collation := range testCollations {
1462		var expected, tdsn string
1463		if collation != "" {
1464			tdsn = dsn + "&collation=" + collation
1465			expected = collation
1466		} else {
1467			tdsn = dsn
1468			expected = defaultCollation
1469		}
1470
1471		runTests(t, tdsn, func(dbt *DBTest) {
1472			var got string
1473			if err := dbt.db.QueryRow("SELECT @@collation_connection").Scan(&got); err != nil {
1474				dbt.Fatal(err)
1475			}
1476
1477			if got != expected {
1478				dbt.Fatalf("expected connection collation %s but got %s", expected, got)
1479			}
1480		})
1481	}
1482}
1483
1484func TestColumnsWithAlias(t *testing.T) {
1485	runTests(t, dsn+"&columnsWithAlias=true", func(dbt *DBTest) {
1486		rows := dbt.mustQuery("SELECT 1 AS A")
1487		defer rows.Close()
1488		cols, _ := rows.Columns()
1489		if len(cols) != 1 {
1490			t.Fatalf("expected 1 column, got %d", len(cols))
1491		}
1492		if cols[0] != "A" {
1493			t.Fatalf("expected column name \"A\", got \"%s\"", cols[0])
1494		}
1495
1496		rows = dbt.mustQuery("SELECT * FROM (SELECT 1 AS one) AS A")
1497		defer rows.Close()
1498		cols, _ = rows.Columns()
1499		if len(cols) != 1 {
1500			t.Fatalf("expected 1 column, got %d", len(cols))
1501		}
1502		if cols[0] != "A.one" {
1503			t.Fatalf("expected column name \"A.one\", got \"%s\"", cols[0])
1504		}
1505	})
1506}
1507
1508func TestRawBytesResultExceedsBuffer(t *testing.T) {
1509	runTests(t, dsn, func(dbt *DBTest) {
1510		// defaultBufSize from buffer.go
1511		expected := strings.Repeat("abc", defaultBufSize)
1512
1513		rows := dbt.mustQuery("SELECT '" + expected + "'")
1514		defer rows.Close()
1515		if !rows.Next() {
1516			dbt.Error("expected result, got none")
1517		}
1518		var result sql.RawBytes
1519		rows.Scan(&result)
1520		if expected != string(result) {
1521			dbt.Error("result did not match expected value")
1522		}
1523	})
1524}
1525
1526func TestTimezoneConversion(t *testing.T) {
1527	zones := []string{"UTC", "US/Central", "US/Pacific", "Local"}
1528
1529	// Regression test for timezone handling
1530	tzTest := func(dbt *DBTest) {
1531		// Create table
1532		dbt.mustExec("CREATE TABLE test (ts TIMESTAMP)")
1533
1534		// Insert local time into database (should be converted)
1535		usCentral, _ := time.LoadLocation("US/Central")
1536		reftime := time.Date(2014, 05, 30, 18, 03, 17, 0, time.UTC).In(usCentral)
1537		dbt.mustExec("INSERT INTO test VALUE (?)", reftime)
1538
1539		// Retrieve time from DB
1540		rows := dbt.mustQuery("SELECT ts FROM test")
1541		defer rows.Close()
1542		if !rows.Next() {
1543			dbt.Fatal("did not get any rows out")
1544		}
1545
1546		var dbTime time.Time
1547		err := rows.Scan(&dbTime)
1548		if err != nil {
1549			dbt.Fatal("Err", err)
1550		}
1551
1552		// Check that dates match
1553		if reftime.Unix() != dbTime.Unix() {
1554			dbt.Errorf("times do not match.\n")
1555			dbt.Errorf(" Now(%v)=%v\n", usCentral, reftime)
1556			dbt.Errorf(" Now(UTC)=%v\n", dbTime)
1557		}
1558	}
1559
1560	for _, tz := range zones {
1561		runTests(t, dsn+"&parseTime=true&loc="+url.QueryEscape(tz), tzTest)
1562	}
1563}
1564
1565// Special cases
1566
1567func TestRowsClose(t *testing.T) {
1568	runTests(t, dsn, func(dbt *DBTest) {
1569		rows, err := dbt.db.Query("SELECT 1")
1570		if err != nil {
1571			dbt.Fatal(err)
1572		}
1573
1574		err = rows.Close()
1575		if err != nil {
1576			dbt.Fatal(err)
1577		}
1578
1579		if rows.Next() {
1580			dbt.Fatal("unexpected row after rows.Close()")
1581		}
1582
1583		err = rows.Err()
1584		if err != nil {
1585			dbt.Fatal(err)
1586		}
1587	})
1588}
1589
1590// dangling statements
1591// http://code.google.com/p/go/issues/detail?id=3865
1592func TestCloseStmtBeforeRows(t *testing.T) {
1593	runTests(t, dsn, func(dbt *DBTest) {
1594		stmt, err := dbt.db.Prepare("SELECT 1")
1595		if err != nil {
1596			dbt.Fatal(err)
1597		}
1598
1599		rows, err := stmt.Query()
1600		if err != nil {
1601			stmt.Close()
1602			dbt.Fatal(err)
1603		}
1604		defer rows.Close()
1605
1606		err = stmt.Close()
1607		if err != nil {
1608			dbt.Fatal(err)
1609		}
1610
1611		if !rows.Next() {
1612			dbt.Fatal("getting row failed")
1613		} else {
1614			err = rows.Err()
1615			if err != nil {
1616				dbt.Fatal(err)
1617			}
1618
1619			var out bool
1620			err = rows.Scan(&out)
1621			if err != nil {
1622				dbt.Fatalf("error on rows.Scan(): %s", err.Error())
1623			}
1624			if out != true {
1625				dbt.Errorf("true != %t", out)
1626			}
1627		}
1628	})
1629}
1630
1631// It is valid to have multiple Rows for the same Stmt
1632// http://code.google.com/p/go/issues/detail?id=3734
1633func TestStmtMultiRows(t *testing.T) {
1634	runTests(t, dsn, func(dbt *DBTest) {
1635		stmt, err := dbt.db.Prepare("SELECT 1 UNION SELECT 0")
1636		if err != nil {
1637			dbt.Fatal(err)
1638		}
1639
1640		rows1, err := stmt.Query()
1641		if err != nil {
1642			stmt.Close()
1643			dbt.Fatal(err)
1644		}
1645		defer rows1.Close()
1646
1647		rows2, err := stmt.Query()
1648		if err != nil {
1649			stmt.Close()
1650			dbt.Fatal(err)
1651		}
1652		defer rows2.Close()
1653
1654		var out bool
1655
1656		// 1
1657		if !rows1.Next() {
1658			dbt.Fatal("first rows1.Next failed")
1659		} else {
1660			err = rows1.Err()
1661			if err != nil {
1662				dbt.Fatal(err)
1663			}
1664
1665			err = rows1.Scan(&out)
1666			if err != nil {
1667				dbt.Fatalf("error on rows.Scan(): %s", err.Error())
1668			}
1669			if out != true {
1670				dbt.Errorf("true != %t", out)
1671			}
1672		}
1673
1674		if !rows2.Next() {
1675			dbt.Fatal("first rows2.Next failed")
1676		} else {
1677			err = rows2.Err()
1678			if err != nil {
1679				dbt.Fatal(err)
1680			}
1681
1682			err = rows2.Scan(&out)
1683			if err != nil {
1684				dbt.Fatalf("error on rows.Scan(): %s", err.Error())
1685			}
1686			if out != true {
1687				dbt.Errorf("true != %t", out)
1688			}
1689		}
1690
1691		// 2
1692		if !rows1.Next() {
1693			dbt.Fatal("second rows1.Next failed")
1694		} else {
1695			err = rows1.Err()
1696			if err != nil {
1697				dbt.Fatal(err)
1698			}
1699
1700			err = rows1.Scan(&out)
1701			if err != nil {
1702				dbt.Fatalf("error on rows.Scan(): %s", err.Error())
1703			}
1704			if out != false {
1705				dbt.Errorf("false != %t", out)
1706			}
1707
1708			if rows1.Next() {
1709				dbt.Fatal("unexpected row on rows1")
1710			}
1711			err = rows1.Close()
1712			if err != nil {
1713				dbt.Fatal(err)
1714			}
1715		}
1716
1717		if !rows2.Next() {
1718			dbt.Fatal("second rows2.Next failed")
1719		} else {
1720			err = rows2.Err()
1721			if err != nil {
1722				dbt.Fatal(err)
1723			}
1724
1725			err = rows2.Scan(&out)
1726			if err != nil {
1727				dbt.Fatalf("error on rows.Scan(): %s", err.Error())
1728			}
1729			if out != false {
1730				dbt.Errorf("false != %t", out)
1731			}
1732
1733			if rows2.Next() {
1734				dbt.Fatal("unexpected row on rows2")
1735			}
1736			err = rows2.Close()
1737			if err != nil {
1738				dbt.Fatal(err)
1739			}
1740		}
1741	})
1742}
1743
1744// Regression test for
1745// * more than 32 NULL parameters (issue 209)
1746// * more parameters than fit into the buffer (issue 201)
1747// * parameters * 64 > max_allowed_packet (issue 734)
1748func TestPreparedManyCols(t *testing.T) {
1749	numParams := 65535
1750	runTests(t, dsn, func(dbt *DBTest) {
1751		query := "SELECT ?" + strings.Repeat(",?", numParams-1)
1752		stmt, err := dbt.db.Prepare(query)
1753		if err != nil {
1754			dbt.Fatal(err)
1755		}
1756		defer stmt.Close()
1757
1758		// create more parameters than fit into the buffer
1759		// which will take nil-values
1760		params := make([]interface{}, numParams)
1761		rows, err := stmt.Query(params...)
1762		if err != nil {
1763			dbt.Fatal(err)
1764		}
1765		rows.Close()
1766
1767		// Create 0byte string which we can't send via STMT_LONG_DATA.
1768		for i := 0; i < numParams; i++ {
1769			params[i] = ""
1770		}
1771		rows, err = stmt.Query(params...)
1772		if err != nil {
1773			dbt.Fatal(err)
1774		}
1775		rows.Close()
1776	})
1777}
1778
1779func TestConcurrent(t *testing.T) {
1780	if enabled, _ := readBool(os.Getenv("MYSQL_TEST_CONCURRENT")); !enabled {
1781		t.Skip("MYSQL_TEST_CONCURRENT env var not set")
1782	}
1783
1784	runTests(t, dsn, func(dbt *DBTest) {
1785		var max int
1786		err := dbt.db.QueryRow("SELECT @@max_connections").Scan(&max)
1787		if err != nil {
1788			dbt.Fatalf("%s", err.Error())
1789		}
1790		dbt.Logf("testing up to %d concurrent connections \r\n", max)
1791
1792		var remaining, succeeded int32 = int32(max), 0
1793
1794		var wg sync.WaitGroup
1795		wg.Add(max)
1796
1797		var fatalError string
1798		var once sync.Once
1799		fatalf := func(s string, vals ...interface{}) {
1800			once.Do(func() {
1801				fatalError = fmt.Sprintf(s, vals...)
1802			})
1803		}
1804
1805		for i := 0; i < max; i++ {
1806			go func(id int) {
1807				defer wg.Done()
1808
1809				tx, err := dbt.db.Begin()
1810				atomic.AddInt32(&remaining, -1)
1811
1812				if err != nil {
1813					if err.Error() != "Error 1040: Too many connections" {
1814						fatalf("error on conn %d: %s", id, err.Error())
1815					}
1816					return
1817				}
1818
1819				// keep the connection busy until all connections are open
1820				for remaining > 0 {
1821					if _, err = tx.Exec("DO 1"); err != nil {
1822						fatalf("error on conn %d: %s", id, err.Error())
1823						return
1824					}
1825				}
1826
1827				if err = tx.Commit(); err != nil {
1828					fatalf("error on conn %d: %s", id, err.Error())
1829					return
1830				}
1831
1832				// everything went fine with this connection
1833				atomic.AddInt32(&succeeded, 1)
1834			}(i)
1835		}
1836
1837		// wait until all conections are open
1838		wg.Wait()
1839
1840		if fatalError != "" {
1841			dbt.Fatal(fatalError)
1842		}
1843
1844		dbt.Logf("reached %d concurrent connections\r\n", succeeded)
1845	})
1846}
1847
1848func testDialError(t *testing.T, dialErr error, expectErr error) {
1849	RegisterDialContext("mydial", func(ctx context.Context, addr string) (net.Conn, error) {
1850		return nil, dialErr
1851	})
1852
1853	db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@mydial(%s)/%s?timeout=30s", user, pass, addr, dbname))
1854	if err != nil {
1855		t.Fatalf("error connecting: %s", err.Error())
1856	}
1857	defer db.Close()
1858
1859	_, err = db.Exec("DO 1")
1860	if err != expectErr {
1861		t.Fatalf("was expecting %s. Got: %s", dialErr, err)
1862	}
1863}
1864
1865func TestDialUnknownError(t *testing.T) {
1866	testErr := fmt.Errorf("test")
1867	testDialError(t, testErr, testErr)
1868}
1869
1870func TestDialNonRetryableNetErr(t *testing.T) {
1871	testErr := netErrorMock{}
1872	testDialError(t, testErr, testErr)
1873}
1874
1875func TestDialTemporaryNetErr(t *testing.T) {
1876	testErr := netErrorMock{temporary: true}
1877	testDialError(t, testErr, testErr)
1878}
1879
1880// Tests custom dial functions
1881func TestCustomDial(t *testing.T) {
1882	if !available {
1883		t.Skipf("MySQL server not running on %s", netAddr)
1884	}
1885
1886	// our custom dial function which justs wraps net.Dial here
1887	RegisterDialContext("mydial", func(ctx context.Context, addr string) (net.Conn, error) {
1888		var d net.Dialer
1889		return d.DialContext(ctx, prot, addr)
1890	})
1891
1892	db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@mydial(%s)/%s?timeout=30s", user, pass, addr, dbname))
1893	if err != nil {
1894		t.Fatalf("error connecting: %s", err.Error())
1895	}
1896	defer db.Close()
1897
1898	if _, err = db.Exec("DO 1"); err != nil {
1899		t.Fatalf("connection failed: %s", err.Error())
1900	}
1901}
1902
1903func TestSQLInjection(t *testing.T) {
1904	createTest := func(arg string) func(dbt *DBTest) {
1905		return func(dbt *DBTest) {
1906			dbt.mustExec("CREATE TABLE test (v INTEGER)")
1907			dbt.mustExec("INSERT INTO test VALUES (?)", 1)
1908
1909			var v int
1910			// NULL can't be equal to anything, the idea here is to inject query so it returns row
1911			// This test verifies that escapeQuotes and escapeBackslash are working properly
1912			err := dbt.db.QueryRow("SELECT v FROM test WHERE NULL = ?", arg).Scan(&v)
1913			if err == sql.ErrNoRows {
1914				return // success, sql injection failed
1915			} else if err == nil {
1916				dbt.Errorf("sql injection successful with arg: %s", arg)
1917			} else {
1918				dbt.Errorf("error running query with arg: %s; err: %s", arg, err.Error())
1919			}
1920		}
1921	}
1922
1923	dsns := []string{
1924		dsn,
1925		dsn + "&sql_mode='NO_BACKSLASH_ESCAPES'",
1926	}
1927	for _, testdsn := range dsns {
1928		runTests(t, testdsn, createTest("1 OR 1=1"))
1929		runTests(t, testdsn, createTest("' OR '1'='1"))
1930	}
1931}
1932
1933// Test if inserted data is correctly retrieved after being escaped
1934func TestInsertRetrieveEscapedData(t *testing.T) {
1935	testData := func(dbt *DBTest) {
1936		dbt.mustExec("CREATE TABLE test (v VARCHAR(255))")
1937
1938		// All sequences that are escaped by escapeQuotes and escapeBackslash
1939		v := "foo \x00\n\r\x1a\"'\\"
1940		dbt.mustExec("INSERT INTO test VALUES (?)", v)
1941
1942		var out string
1943		err := dbt.db.QueryRow("SELECT v FROM test").Scan(&out)
1944		if err != nil {
1945			dbt.Fatalf("%s", err.Error())
1946		}
1947
1948		if out != v {
1949			dbt.Errorf("%q != %q", out, v)
1950		}
1951	}
1952
1953	dsns := []string{
1954		dsn,
1955		dsn + "&sql_mode='NO_BACKSLASH_ESCAPES'",
1956	}
1957	for _, testdsn := range dsns {
1958		runTests(t, testdsn, testData)
1959	}
1960}
1961
1962func TestUnixSocketAuthFail(t *testing.T) {
1963	runTests(t, dsn, func(dbt *DBTest) {
1964		// Save the current logger so we can restore it.
1965		oldLogger := errLog
1966
1967		// Set a new logger so we can capture its output.
1968		buffer := bytes.NewBuffer(make([]byte, 0, 64))
1969		newLogger := log.New(buffer, "prefix: ", 0)
1970		SetLogger(newLogger)
1971
1972		// Restore the logger.
1973		defer SetLogger(oldLogger)
1974
1975		// Make a new DSN that uses the MySQL socket file and a bad password, which
1976		// we can make by simply appending any character to the real password.
1977		badPass := pass + "x"
1978		socket := ""
1979		if prot == "unix" {
1980			socket = addr
1981		} else {
1982			// Get socket file from MySQL.
1983			err := dbt.db.QueryRow("SELECT @@socket").Scan(&socket)
1984			if err != nil {
1985				t.Fatalf("error on SELECT @@socket: %s", err.Error())
1986			}
1987		}
1988		t.Logf("socket: %s", socket)
1989		badDSN := fmt.Sprintf("%s:%s@unix(%s)/%s?timeout=30s", user, badPass, socket, dbname)
1990		db, err := sql.Open("mysql", badDSN)
1991		if err != nil {
1992			t.Fatalf("error connecting: %s", err.Error())
1993		}
1994		defer db.Close()
1995
1996		// Connect to MySQL for real. This will cause an auth failure.
1997		err = db.Ping()
1998		if err == nil {
1999			t.Error("expected Ping() to return an error")
2000		}
2001
2002		// The driver should not log anything.
2003		if actual := buffer.String(); actual != "" {
2004			t.Errorf("expected no output, got %q", actual)
2005		}
2006	})
2007}
2008
2009// See Issue #422
2010func TestInterruptBySignal(t *testing.T) {
2011	runTestsWithMultiStatement(t, dsn, func(dbt *DBTest) {
2012		dbt.mustExec(`
2013			DROP PROCEDURE IF EXISTS test_signal;
2014			CREATE PROCEDURE test_signal(ret INT)
2015			BEGIN
2016				SELECT ret;
2017				SIGNAL SQLSTATE
2018					'45001'
2019				SET
2020					MESSAGE_TEXT = "an error",
2021					MYSQL_ERRNO = 45001;
2022			END
2023		`)
2024		defer dbt.mustExec("DROP PROCEDURE test_signal")
2025
2026		var val int
2027
2028		// text protocol
2029		rows, err := dbt.db.Query("CALL test_signal(42)")
2030		if err != nil {
2031			dbt.Fatalf("error on text query: %s", err.Error())
2032		}
2033		for rows.Next() {
2034			if err := rows.Scan(&val); err != nil {
2035				dbt.Error(err)
2036			} else if val != 42 {
2037				dbt.Errorf("expected val to be 42")
2038			}
2039		}
2040		rows.Close()
2041
2042		// binary protocol
2043		rows, err = dbt.db.Query("CALL test_signal(?)", 42)
2044		if err != nil {
2045			dbt.Fatalf("error on binary query: %s", err.Error())
2046		}
2047		for rows.Next() {
2048			if err := rows.Scan(&val); err != nil {
2049				dbt.Error(err)
2050			} else if val != 42 {
2051				dbt.Errorf("expected val to be 42")
2052			}
2053		}
2054		rows.Close()
2055	})
2056}
2057
2058func TestColumnsReusesSlice(t *testing.T) {
2059	rows := mysqlRows{
2060		rs: resultSet{
2061			columns: []mysqlField{
2062				{
2063					tableName: "test",
2064					name:      "A",
2065				},
2066				{
2067					tableName: "test",
2068					name:      "B",
2069				},
2070			},
2071		},
2072	}
2073
2074	allocs := testing.AllocsPerRun(1, func() {
2075		cols := rows.Columns()
2076
2077		if len(cols) != 2 {
2078			t.Fatalf("expected 2 columns, got %d", len(cols))
2079		}
2080	})
2081
2082	if allocs != 0 {
2083		t.Fatalf("expected 0 allocations, got %d", int(allocs))
2084	}
2085
2086	if rows.rs.columnNames == nil {
2087		t.Fatalf("expected columnNames to be set, got nil")
2088	}
2089}
2090
2091func TestRejectReadOnly(t *testing.T) {
2092	runTests(t, dsn, func(dbt *DBTest) {
2093		// Create Table
2094		dbt.mustExec("CREATE TABLE test (value BOOL)")
2095		// Set the session to read-only. We didn't set the `rejectReadOnly`
2096		// option, so any writes after this should fail.
2097		_, err := dbt.db.Exec("SET SESSION TRANSACTION READ ONLY")
2098		// Error 1193: Unknown system variable 'TRANSACTION' => skip test,
2099		// MySQL server version is too old
2100		maybeSkip(t, err, 1193)
2101		if _, err := dbt.db.Exec("DROP TABLE test"); err == nil {
2102			t.Fatalf("writing to DB in read-only session without " +
2103				"rejectReadOnly did not error")
2104		}
2105		// Set the session back to read-write so runTests() can properly clean
2106		// up the table `test`.
2107		dbt.mustExec("SET SESSION TRANSACTION READ WRITE")
2108	})
2109
2110	// Enable the `rejectReadOnly` option.
2111	runTests(t, dsn+"&rejectReadOnly=true", func(dbt *DBTest) {
2112		// Create Table
2113		dbt.mustExec("CREATE TABLE test (value BOOL)")
2114		// Set the session to read only. Any writes after this should error on
2115		// a driver.ErrBadConn, and cause `database/sql` to initiate a new
2116		// connection.
2117		dbt.mustExec("SET SESSION TRANSACTION READ ONLY")
2118		// This would error, but `database/sql` should automatically retry on a
2119		// new connection which is not read-only, and eventually succeed.
2120		dbt.mustExec("DROP TABLE test")
2121	})
2122}
2123
2124func TestPing(t *testing.T) {
2125	runTests(t, dsn, func(dbt *DBTest) {
2126		if err := dbt.db.Ping(); err != nil {
2127			dbt.fail("Ping", "Ping", err)
2128		}
2129	})
2130}
2131
2132// See Issue #799
2133func TestEmptyPassword(t *testing.T) {
2134	if !available {
2135		t.Skipf("MySQL server not running on %s", netAddr)
2136	}
2137
2138	dsn := fmt.Sprintf("%s:%s@%s/%s?timeout=30s", user, "", netAddr, dbname)
2139	db, err := sql.Open("mysql", dsn)
2140	if err == nil {
2141		defer db.Close()
2142		err = db.Ping()
2143	}
2144
2145	if pass == "" {
2146		if err != nil {
2147			t.Fatal(err.Error())
2148		}
2149	} else {
2150		if err == nil {
2151			t.Fatal("expected authentication error")
2152		}
2153		if !strings.HasPrefix(err.Error(), "Error 1045") {
2154			t.Fatal(err.Error())
2155		}
2156	}
2157}
2158
2159// static interface implementation checks of mysqlConn
2160var (
2161	_ driver.ConnBeginTx        = &mysqlConn{}
2162	_ driver.ConnPrepareContext = &mysqlConn{}
2163	_ driver.ExecerContext      = &mysqlConn{}
2164	_ driver.Pinger             = &mysqlConn{}
2165	_ driver.QueryerContext     = &mysqlConn{}
2166)
2167
2168// static interface implementation checks of mysqlStmt
2169var (
2170	_ driver.StmtExecContext  = &mysqlStmt{}
2171	_ driver.StmtQueryContext = &mysqlStmt{}
2172)
2173
2174// Ensure that all the driver interfaces are implemented
2175var (
2176	// _ driver.RowsColumnTypeLength        = &binaryRows{}
2177	// _ driver.RowsColumnTypeLength        = &textRows{}
2178	_ driver.RowsColumnTypeDatabaseTypeName = &binaryRows{}
2179	_ driver.RowsColumnTypeDatabaseTypeName = &textRows{}
2180	_ driver.RowsColumnTypeNullable         = &binaryRows{}
2181	_ driver.RowsColumnTypeNullable         = &textRows{}
2182	_ driver.RowsColumnTypePrecisionScale   = &binaryRows{}
2183	_ driver.RowsColumnTypePrecisionScale   = &textRows{}
2184	_ driver.RowsColumnTypeScanType         = &binaryRows{}
2185	_ driver.RowsColumnTypeScanType         = &textRows{}
2186	_ driver.RowsNextResultSet              = &binaryRows{}
2187	_ driver.RowsNextResultSet              = &textRows{}
2188)
2189
2190func TestMultiResultSet(t *testing.T) {
2191	type result struct {
2192		values  [][]int
2193		columns []string
2194	}
2195
2196	// checkRows is a helper test function to validate rows containing 3 result
2197	// sets with specific values and columns. The basic query would look like this:
2198	//
2199	// SELECT 1 AS col1, 2 AS col2 UNION SELECT 3, 4;
2200	// SELECT 0 UNION SELECT 1;
2201	// SELECT 1 AS col1, 2 AS col2, 3 AS col3 UNION SELECT 4, 5, 6;
2202	//
2203	// to distinguish test cases the first string argument is put in front of
2204	// every error or fatal message.
2205	checkRows := func(desc string, rows *sql.Rows, dbt *DBTest) {
2206		expected := []result{
2207			{
2208				values:  [][]int{{1, 2}, {3, 4}},
2209				columns: []string{"col1", "col2"},
2210			},
2211			{
2212				values:  [][]int{{1, 2, 3}, {4, 5, 6}},
2213				columns: []string{"col1", "col2", "col3"},
2214			},
2215		}
2216
2217		var res1 result
2218		for rows.Next() {
2219			var res [2]int
2220			if err := rows.Scan(&res[0], &res[1]); err != nil {
2221				dbt.Fatal(err)
2222			}
2223			res1.values = append(res1.values, res[:])
2224		}
2225
2226		cols, err := rows.Columns()
2227		if err != nil {
2228			dbt.Fatal(desc, err)
2229		}
2230		res1.columns = cols
2231
2232		if !reflect.DeepEqual(expected[0], res1) {
2233			dbt.Error(desc, "want =", expected[0], "got =", res1)
2234		}
2235
2236		if !rows.NextResultSet() {
2237			dbt.Fatal(desc, "expected next result set")
2238		}
2239
2240		// ignoring one result set
2241
2242		if !rows.NextResultSet() {
2243			dbt.Fatal(desc, "expected next result set")
2244		}
2245
2246		var res2 result
2247		cols, err = rows.Columns()
2248		if err != nil {
2249			dbt.Fatal(desc, err)
2250		}
2251		res2.columns = cols
2252
2253		for rows.Next() {
2254			var res [3]int
2255			if err := rows.Scan(&res[0], &res[1], &res[2]); err != nil {
2256				dbt.Fatal(desc, err)
2257			}
2258			res2.values = append(res2.values, res[:])
2259		}
2260
2261		if !reflect.DeepEqual(expected[1], res2) {
2262			dbt.Error(desc, "want =", expected[1], "got =", res2)
2263		}
2264
2265		if rows.NextResultSet() {
2266			dbt.Error(desc, "unexpected next result set")
2267		}
2268
2269		if err := rows.Err(); err != nil {
2270			dbt.Error(desc, err)
2271		}
2272	}
2273
2274	runTestsWithMultiStatement(t, dsn, func(dbt *DBTest) {
2275		rows := dbt.mustQuery(`DO 1;
2276		SELECT 1 AS col1, 2 AS col2 UNION SELECT 3, 4;
2277		DO 1;
2278		SELECT 0 UNION SELECT 1;
2279		SELECT 1 AS col1, 2 AS col2, 3 AS col3 UNION SELECT 4, 5, 6;`)
2280		defer rows.Close()
2281		checkRows("query: ", rows, dbt)
2282	})
2283
2284	runTestsWithMultiStatement(t, dsn, func(dbt *DBTest) {
2285		queries := []string{
2286			`
2287			DROP PROCEDURE IF EXISTS test_mrss;
2288			CREATE PROCEDURE test_mrss()
2289			BEGIN
2290				DO 1;
2291				SELECT 1 AS col1, 2 AS col2 UNION SELECT 3, 4;
2292				DO 1;
2293				SELECT 0 UNION SELECT 1;
2294				SELECT 1 AS col1, 2 AS col2, 3 AS col3 UNION SELECT 4, 5, 6;
2295			END
2296		`,
2297			`
2298			DROP PROCEDURE IF EXISTS test_mrss;
2299			CREATE PROCEDURE test_mrss()
2300			BEGIN
2301				SELECT 1 AS col1, 2 AS col2 UNION SELECT 3, 4;
2302				SELECT 0 UNION SELECT 1;
2303				SELECT 1 AS col1, 2 AS col2, 3 AS col3 UNION SELECT 4, 5, 6;
2304			END
2305		`,
2306		}
2307
2308		defer dbt.mustExec("DROP PROCEDURE IF EXISTS test_mrss")
2309
2310		for i, query := range queries {
2311			dbt.mustExec(query)
2312
2313			stmt, err := dbt.db.Prepare("CALL test_mrss()")
2314			if err != nil {
2315				dbt.Fatalf("%v (i=%d)", err, i)
2316			}
2317			defer stmt.Close()
2318
2319			for j := 0; j < 2; j++ {
2320				rows, err := stmt.Query()
2321				if err != nil {
2322					dbt.Fatalf("%v (i=%d) (j=%d)", err, i, j)
2323				}
2324				checkRows(fmt.Sprintf("prepared stmt query (i=%d) (j=%d): ", i, j), rows, dbt)
2325			}
2326		}
2327	})
2328}
2329
2330func TestMultiResultSetNoSelect(t *testing.T) {
2331	runTestsWithMultiStatement(t, dsn, func(dbt *DBTest) {
2332		rows := dbt.mustQuery("DO 1; DO 2;")
2333		defer rows.Close()
2334
2335		if rows.Next() {
2336			dbt.Error("unexpected row")
2337		}
2338
2339		if rows.NextResultSet() {
2340			dbt.Error("unexpected next result set")
2341		}
2342
2343		if err := rows.Err(); err != nil {
2344			dbt.Error("expected nil; got ", err)
2345		}
2346	})
2347}
2348
2349// tests if rows are set in a proper state if some results were ignored before
2350// calling rows.NextResultSet.
2351func TestSkipResults(t *testing.T) {
2352	runTests(t, dsn, func(dbt *DBTest) {
2353		rows := dbt.mustQuery("SELECT 1, 2")
2354		defer rows.Close()
2355
2356		if !rows.Next() {
2357			dbt.Error("expected row")
2358		}
2359
2360		if rows.NextResultSet() {
2361			dbt.Error("unexpected next result set")
2362		}
2363
2364		if err := rows.Err(); err != nil {
2365			dbt.Error("expected nil; got ", err)
2366		}
2367	})
2368}
2369
2370func TestPingContext(t *testing.T) {
2371	runTests(t, dsn, func(dbt *DBTest) {
2372		ctx, cancel := context.WithCancel(context.Background())
2373		cancel()
2374		if err := dbt.db.PingContext(ctx); err != context.Canceled {
2375			dbt.Errorf("expected context.Canceled, got %v", err)
2376		}
2377	})
2378}
2379
2380func TestContextCancelExec(t *testing.T) {
2381	runTests(t, dsn, func(dbt *DBTest) {
2382		dbt.mustExec("CREATE TABLE test (v INTEGER)")
2383		ctx, cancel := context.WithCancel(context.Background())
2384
2385		// Delay execution for just a bit until db.ExecContext has begun.
2386		defer time.AfterFunc(250*time.Millisecond, cancel).Stop()
2387
2388		// This query will be canceled.
2389		startTime := time.Now()
2390		if _, err := dbt.db.ExecContext(ctx, "INSERT INTO test VALUES (SLEEP(1))"); err != context.Canceled {
2391			dbt.Errorf("expected context.Canceled, got %v", err)
2392		}
2393		if d := time.Since(startTime); d > 500*time.Millisecond {
2394			dbt.Errorf("too long execution time: %s", d)
2395		}
2396
2397		// Wait for the INSERT query to be done.
2398		time.Sleep(time.Second)
2399
2400		// Check how many times the query is executed.
2401		var v int
2402		if err := dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&v); err != nil {
2403			dbt.Fatalf("%s", err.Error())
2404		}
2405		if v != 1 { // TODO: need to kill the query, and v should be 0.
2406			dbt.Skipf("[WARN] expected val to be 1, got %d", v)
2407		}
2408
2409		// Context is already canceled, so error should come before execution.
2410		if _, err := dbt.db.ExecContext(ctx, "INSERT INTO test VALUES (1)"); err == nil {
2411			dbt.Error("expected error")
2412		} else if err.Error() != "context canceled" {
2413			dbt.Fatalf("unexpected error: %s", err)
2414		}
2415
2416		// The second insert query will fail, so the table has no changes.
2417		if err := dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&v); err != nil {
2418			dbt.Fatalf("%s", err.Error())
2419		}
2420		if v != 1 {
2421			dbt.Skipf("[WARN] expected val to be 1, got %d", v)
2422		}
2423	})
2424}
2425
2426func TestContextCancelQuery(t *testing.T) {
2427	runTests(t, dsn, func(dbt *DBTest) {
2428		dbt.mustExec("CREATE TABLE test (v INTEGER)")
2429		ctx, cancel := context.WithCancel(context.Background())
2430
2431		// Delay execution for just a bit until db.ExecContext has begun.
2432		defer time.AfterFunc(250*time.Millisecond, cancel).Stop()
2433
2434		// This query will be canceled.
2435		startTime := time.Now()
2436		if _, err := dbt.db.QueryContext(ctx, "INSERT INTO test VALUES (SLEEP(1))"); err != context.Canceled {
2437			dbt.Errorf("expected context.Canceled, got %v", err)
2438		}
2439		if d := time.Since(startTime); d > 500*time.Millisecond {
2440			dbt.Errorf("too long execution time: %s", d)
2441		}
2442
2443		// Wait for the INSERT query to be done.
2444		time.Sleep(time.Second)
2445
2446		// Check how many times the query is executed.
2447		var v int
2448		if err := dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&v); err != nil {
2449			dbt.Fatalf("%s", err.Error())
2450		}
2451		if v != 1 { // TODO: need to kill the query, and v should be 0.
2452			dbt.Skipf("[WARN] expected val to be 1, got %d", v)
2453		}
2454
2455		// Context is already canceled, so error should come before execution.
2456		if _, err := dbt.db.QueryContext(ctx, "INSERT INTO test VALUES (1)"); err != context.Canceled {
2457			dbt.Errorf("expected context.Canceled, got %v", err)
2458		}
2459
2460		// The second insert query will fail, so the table has no changes.
2461		if err := dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&v); err != nil {
2462			dbt.Fatalf("%s", err.Error())
2463		}
2464		if v != 1 {
2465			dbt.Skipf("[WARN] expected val to be 1, got %d", v)
2466		}
2467	})
2468}
2469
2470func TestContextCancelQueryRow(t *testing.T) {
2471	runTests(t, dsn, func(dbt *DBTest) {
2472		dbt.mustExec("CREATE TABLE test (v INTEGER)")
2473		dbt.mustExec("INSERT INTO test VALUES (1), (2), (3)")
2474		ctx, cancel := context.WithCancel(context.Background())
2475
2476		rows, err := dbt.db.QueryContext(ctx, "SELECT v FROM test")
2477		if err != nil {
2478			dbt.Fatalf("%s", err.Error())
2479		}
2480
2481		// the first row will be succeed.
2482		var v int
2483		if !rows.Next() {
2484			dbt.Fatalf("unexpected end")
2485		}
2486		if err := rows.Scan(&v); err != nil {
2487			dbt.Fatalf("%s", err.Error())
2488		}
2489
2490		cancel()
2491		// make sure the driver receives the cancel request.
2492		time.Sleep(100 * time.Millisecond)
2493
2494		if rows.Next() {
2495			dbt.Errorf("expected end, but not")
2496		}
2497		if err := rows.Err(); err != context.Canceled {
2498			dbt.Errorf("expected context.Canceled, got %v", err)
2499		}
2500	})
2501}
2502
2503func TestContextCancelPrepare(t *testing.T) {
2504	runTests(t, dsn, func(dbt *DBTest) {
2505		ctx, cancel := context.WithCancel(context.Background())
2506		cancel()
2507		if _, err := dbt.db.PrepareContext(ctx, "SELECT 1"); err != context.Canceled {
2508			dbt.Errorf("expected context.Canceled, got %v", err)
2509		}
2510	})
2511}
2512
2513func TestContextCancelStmtExec(t *testing.T) {
2514	runTests(t, dsn, func(dbt *DBTest) {
2515		dbt.mustExec("CREATE TABLE test (v INTEGER)")
2516		ctx, cancel := context.WithCancel(context.Background())
2517		stmt, err := dbt.db.PrepareContext(ctx, "INSERT INTO test VALUES (SLEEP(1))")
2518		if err != nil {
2519			dbt.Fatalf("unexpected error: %v", err)
2520		}
2521
2522		// Delay execution for just a bit until db.ExecContext has begun.
2523		defer time.AfterFunc(250*time.Millisecond, cancel).Stop()
2524
2525		// This query will be canceled.
2526		startTime := time.Now()
2527		if _, err := stmt.ExecContext(ctx); err != context.Canceled {
2528			dbt.Errorf("expected context.Canceled, got %v", err)
2529		}
2530		if d := time.Since(startTime); d > 500*time.Millisecond {
2531			dbt.Errorf("too long execution time: %s", d)
2532		}
2533
2534		// Wait for the INSERT query to be done.
2535		time.Sleep(time.Second)
2536
2537		// Check how many times the query is executed.
2538		var v int
2539		if err := dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&v); err != nil {
2540			dbt.Fatalf("%s", err.Error())
2541		}
2542		if v != 1 { // TODO: need to kill the query, and v should be 0.
2543			dbt.Skipf("[WARN] expected val to be 1, got %d", v)
2544		}
2545	})
2546}
2547
2548func TestContextCancelStmtQuery(t *testing.T) {
2549	runTests(t, dsn, func(dbt *DBTest) {
2550		dbt.mustExec("CREATE TABLE test (v INTEGER)")
2551		ctx, cancel := context.WithCancel(context.Background())
2552		stmt, err := dbt.db.PrepareContext(ctx, "INSERT INTO test VALUES (SLEEP(1))")
2553		if err != nil {
2554			dbt.Fatalf("unexpected error: %v", err)
2555		}
2556
2557		// Delay execution for just a bit until db.ExecContext has begun.
2558		defer time.AfterFunc(250*time.Millisecond, cancel).Stop()
2559
2560		// This query will be canceled.
2561		startTime := time.Now()
2562		if _, err := stmt.QueryContext(ctx); err != context.Canceled {
2563			dbt.Errorf("expected context.Canceled, got %v", err)
2564		}
2565		if d := time.Since(startTime); d > 500*time.Millisecond {
2566			dbt.Errorf("too long execution time: %s", d)
2567		}
2568
2569		// Wait for the INSERT query has done.
2570		time.Sleep(time.Second)
2571
2572		// Check how many times the query is executed.
2573		var v int
2574		if err := dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&v); err != nil {
2575			dbt.Fatalf("%s", err.Error())
2576		}
2577		if v != 1 { // TODO: need to kill the query, and v should be 0.
2578			dbt.Skipf("[WARN] expected val to be 1, got %d", v)
2579		}
2580	})
2581}
2582
2583func TestContextCancelBegin(t *testing.T) {
2584	runTests(t, dsn, func(dbt *DBTest) {
2585		dbt.mustExec("CREATE TABLE test (v INTEGER)")
2586		ctx, cancel := context.WithCancel(context.Background())
2587		tx, err := dbt.db.BeginTx(ctx, nil)
2588		if err != nil {
2589			dbt.Fatal(err)
2590		}
2591
2592		// Delay execution for just a bit until db.ExecContext has begun.
2593		defer time.AfterFunc(100*time.Millisecond, cancel).Stop()
2594
2595		// This query will be canceled.
2596		startTime := time.Now()
2597		if _, err := tx.ExecContext(ctx, "INSERT INTO test VALUES (SLEEP(1))"); err != context.Canceled {
2598			dbt.Errorf("expected context.Canceled, got %v", err)
2599		}
2600		if d := time.Since(startTime); d > 500*time.Millisecond {
2601			dbt.Errorf("too long execution time: %s", d)
2602		}
2603
2604		// Transaction is canceled, so expect an error.
2605		switch err := tx.Commit(); err {
2606		case sql.ErrTxDone:
2607			// because the transaction has already been rollbacked.
2608			// the database/sql package watches ctx
2609			// and rollbacks when ctx is canceled.
2610		case context.Canceled:
2611			// the database/sql package rollbacks on another goroutine,
2612			// so the transaction may not be rollbacked depending on goroutine scheduling.
2613		default:
2614			dbt.Errorf("expected sql.ErrTxDone or context.Canceled, got %v", err)
2615		}
2616
2617		// Context is canceled, so cannot begin a transaction.
2618		if _, err := dbt.db.BeginTx(ctx, nil); err != context.Canceled {
2619			dbt.Errorf("expected context.Canceled, got %v", err)
2620		}
2621	})
2622}
2623
2624func TestContextBeginIsolationLevel(t *testing.T) {
2625	runTests(t, dsn, func(dbt *DBTest) {
2626		dbt.mustExec("CREATE TABLE test (v INTEGER)")
2627		ctx, cancel := context.WithCancel(context.Background())
2628		defer cancel()
2629
2630		tx1, err := dbt.db.BeginTx(ctx, &sql.TxOptions{
2631			Isolation: sql.LevelRepeatableRead,
2632		})
2633		if err != nil {
2634			dbt.Fatal(err)
2635		}
2636
2637		tx2, err := dbt.db.BeginTx(ctx, &sql.TxOptions{
2638			Isolation: sql.LevelReadCommitted,
2639		})
2640		if err != nil {
2641			dbt.Fatal(err)
2642		}
2643
2644		_, err = tx1.ExecContext(ctx, "INSERT INTO test VALUES (1)")
2645		if err != nil {
2646			dbt.Fatal(err)
2647		}
2648
2649		var v int
2650		row := tx2.QueryRowContext(ctx, "SELECT COUNT(*) FROM test")
2651		if err := row.Scan(&v); err != nil {
2652			dbt.Fatal(err)
2653		}
2654		// Because writer transaction wasn't commited yet, it should be available
2655		if v != 0 {
2656			dbt.Errorf("expected val to be 0, got %d", v)
2657		}
2658
2659		err = tx1.Commit()
2660		if err != nil {
2661			dbt.Fatal(err)
2662		}
2663
2664		row = tx2.QueryRowContext(ctx, "SELECT COUNT(*) FROM test")
2665		if err := row.Scan(&v); err != nil {
2666			dbt.Fatal(err)
2667		}
2668		// Data written by writer transaction is already commited, it should be selectable
2669		if v != 1 {
2670			dbt.Errorf("expected val to be 1, got %d", v)
2671		}
2672		tx2.Commit()
2673	})
2674}
2675
2676func TestContextBeginReadOnly(t *testing.T) {
2677	runTests(t, dsn, func(dbt *DBTest) {
2678		dbt.mustExec("CREATE TABLE test (v INTEGER)")
2679		ctx, cancel := context.WithCancel(context.Background())
2680		defer cancel()
2681
2682		tx, err := dbt.db.BeginTx(ctx, &sql.TxOptions{
2683			ReadOnly: true,
2684		})
2685		if _, ok := err.(*MySQLError); ok {
2686			dbt.Skip("It seems that your MySQL does not support READ ONLY transactions")
2687			return
2688		} else if err != nil {
2689			dbt.Fatal(err)
2690		}
2691
2692		// INSERT queries fail in a READ ONLY transaction.
2693		_, err = tx.ExecContext(ctx, "INSERT INTO test VALUES (1)")
2694		if _, ok := err.(*MySQLError); !ok {
2695			dbt.Errorf("expected MySQLError, got %v", err)
2696		}
2697
2698		// SELECT queries can be executed.
2699		var v int
2700		row := tx.QueryRowContext(ctx, "SELECT COUNT(*) FROM test")
2701		if err := row.Scan(&v); err != nil {
2702			dbt.Fatal(err)
2703		}
2704		if v != 0 {
2705			dbt.Errorf("expected val to be 0, got %d", v)
2706		}
2707
2708		if err := tx.Commit(); err != nil {
2709			dbt.Fatal(err)
2710		}
2711	})
2712}
2713
2714func TestRowsColumnTypes(t *testing.T) {
2715	niNULL := sql.NullInt64{Int64: 0, Valid: false}
2716	ni0 := sql.NullInt64{Int64: 0, Valid: true}
2717	ni1 := sql.NullInt64{Int64: 1, Valid: true}
2718	ni42 := sql.NullInt64{Int64: 42, Valid: true}
2719	nfNULL := sql.NullFloat64{Float64: 0.0, Valid: false}
2720	nf0 := sql.NullFloat64{Float64: 0.0, Valid: true}
2721	nf1337 := sql.NullFloat64{Float64: 13.37, Valid: true}
2722	nt0 := NullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 0, time.UTC), Valid: true}
2723	nt1 := NullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 100000000, time.UTC), Valid: true}
2724	nt2 := NullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 110000000, time.UTC), Valid: true}
2725	nt6 := NullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 111111000, time.UTC), Valid: true}
2726	nd1 := NullTime{Time: time.Date(2006, 01, 02, 0, 0, 0, 0, time.UTC), Valid: true}
2727	nd2 := NullTime{Time: time.Date(2006, 03, 04, 0, 0, 0, 0, time.UTC), Valid: true}
2728	ndNULL := NullTime{Time: time.Time{}, Valid: false}
2729	rbNULL := sql.RawBytes(nil)
2730	rb0 := sql.RawBytes("0")
2731	rb42 := sql.RawBytes("42")
2732	rbTest := sql.RawBytes("Test")
2733	rb0pad4 := sql.RawBytes("0\x00\x00\x00") // BINARY right-pads values with 0x00
2734	rbx0 := sql.RawBytes("\x00")
2735	rbx42 := sql.RawBytes("\x42")
2736
2737	var columns = []struct {
2738		name             string
2739		fieldType        string // type used when creating table schema
2740		databaseTypeName string // actual type used by MySQL
2741		scanType         reflect.Type
2742		nullable         bool
2743		precision        int64 // 0 if not ok
2744		scale            int64
2745		valuesIn         [3]string
2746		valuesOut        [3]interface{}
2747	}{
2748		{"bit8null", "BIT(8)", "BIT", scanTypeRawBytes, true, 0, 0, [3]string{"0x0", "NULL", "0x42"}, [3]interface{}{rbx0, rbNULL, rbx42}},
2749		{"boolnull", "BOOL", "TINYINT", scanTypeNullInt, true, 0, 0, [3]string{"NULL", "true", "0"}, [3]interface{}{niNULL, ni1, ni0}},
2750		{"bool", "BOOL NOT NULL", "TINYINT", scanTypeInt8, false, 0, 0, [3]string{"1", "0", "FALSE"}, [3]interface{}{int8(1), int8(0), int8(0)}},
2751		{"intnull", "INTEGER", "INT", scanTypeNullInt, true, 0, 0, [3]string{"0", "NULL", "42"}, [3]interface{}{ni0, niNULL, ni42}},
2752		{"smallint", "SMALLINT NOT NULL", "SMALLINT", scanTypeInt16, false, 0, 0, [3]string{"0", "-32768", "32767"}, [3]interface{}{int16(0), int16(-32768), int16(32767)}},
2753		{"smallintnull", "SMALLINT", "SMALLINT", scanTypeNullInt, true, 0, 0, [3]string{"0", "NULL", "42"}, [3]interface{}{ni0, niNULL, ni42}},
2754		{"int3null", "INT(3)", "INT", scanTypeNullInt, true, 0, 0, [3]string{"0", "NULL", "42"}, [3]interface{}{ni0, niNULL, ni42}},
2755		{"int7", "INT(7) NOT NULL", "INT", scanTypeInt32, false, 0, 0, [3]string{"0", "-1337", "42"}, [3]interface{}{int32(0), int32(-1337), int32(42)}},
2756		{"mediumintnull", "MEDIUMINT", "MEDIUMINT", scanTypeNullInt, true, 0, 0, [3]string{"0", "42", "NULL"}, [3]interface{}{ni0, ni42, niNULL}},
2757		{"bigint", "BIGINT NOT NULL", "BIGINT", scanTypeInt64, false, 0, 0, [3]string{"0", "65535", "-42"}, [3]interface{}{int64(0), int64(65535), int64(-42)}},
2758		{"bigintnull", "BIGINT", "BIGINT", scanTypeNullInt, true, 0, 0, [3]string{"NULL", "1", "42"}, [3]interface{}{niNULL, ni1, ni42}},
2759		{"tinyuint", "TINYINT UNSIGNED NOT NULL", "TINYINT", scanTypeUint8, false, 0, 0, [3]string{"0", "255", "42"}, [3]interface{}{uint8(0), uint8(255), uint8(42)}},
2760		{"smalluint", "SMALLINT UNSIGNED NOT NULL", "SMALLINT", scanTypeUint16, false, 0, 0, [3]string{"0", "65535", "42"}, [3]interface{}{uint16(0), uint16(65535), uint16(42)}},
2761		{"biguint", "BIGINT UNSIGNED NOT NULL", "BIGINT", scanTypeUint64, false, 0, 0, [3]string{"0", "65535", "42"}, [3]interface{}{uint64(0), uint64(65535), uint64(42)}},
2762		{"uint13", "INT(13) UNSIGNED NOT NULL", "INT", scanTypeUint32, false, 0, 0, [3]string{"0", "1337", "42"}, [3]interface{}{uint32(0), uint32(1337), uint32(42)}},
2763		{"float", "FLOAT NOT NULL", "FLOAT", scanTypeFloat32, false, math.MaxInt64, math.MaxInt64, [3]string{"0", "42", "13.37"}, [3]interface{}{float32(0), float32(42), float32(13.37)}},
2764		{"floatnull", "FLOAT", "FLOAT", scanTypeNullFloat, true, math.MaxInt64, math.MaxInt64, [3]string{"0", "NULL", "13.37"}, [3]interface{}{nf0, nfNULL, nf1337}},
2765		{"float74null", "FLOAT(7,4)", "FLOAT", scanTypeNullFloat, true, math.MaxInt64, 4, [3]string{"0", "NULL", "13.37"}, [3]interface{}{nf0, nfNULL, nf1337}},
2766		{"double", "DOUBLE NOT NULL", "DOUBLE", scanTypeFloat64, false, math.MaxInt64, math.MaxInt64, [3]string{"0", "42", "13.37"}, [3]interface{}{float64(0), float64(42), float64(13.37)}},
2767		{"doublenull", "DOUBLE", "DOUBLE", scanTypeNullFloat, true, math.MaxInt64, math.MaxInt64, [3]string{"0", "NULL", "13.37"}, [3]interface{}{nf0, nfNULL, nf1337}},
2768		{"decimal1", "DECIMAL(10,6) NOT NULL", "DECIMAL", scanTypeRawBytes, false, 10, 6, [3]string{"0", "13.37", "1234.123456"}, [3]interface{}{sql.RawBytes("0.000000"), sql.RawBytes("13.370000"), sql.RawBytes("1234.123456")}},
2769		{"decimal1null", "DECIMAL(10,6)", "DECIMAL", scanTypeRawBytes, true, 10, 6, [3]string{"0", "NULL", "1234.123456"}, [3]interface{}{sql.RawBytes("0.000000"), rbNULL, sql.RawBytes("1234.123456")}},
2770		{"decimal2", "DECIMAL(8,4) NOT NULL", "DECIMAL", scanTypeRawBytes, false, 8, 4, [3]string{"0", "13.37", "1234.123456"}, [3]interface{}{sql.RawBytes("0.0000"), sql.RawBytes("13.3700"), sql.RawBytes("1234.1235")}},
2771		{"decimal2null", "DECIMAL(8,4)", "DECIMAL", scanTypeRawBytes, true, 8, 4, [3]string{"0", "NULL", "1234.123456"}, [3]interface{}{sql.RawBytes("0.0000"), rbNULL, sql.RawBytes("1234.1235")}},
2772		{"decimal3", "DECIMAL(5,0) NOT NULL", "DECIMAL", scanTypeRawBytes, false, 5, 0, [3]string{"0", "13.37", "-12345.123456"}, [3]interface{}{rb0, sql.RawBytes("13"), sql.RawBytes("-12345")}},
2773		{"decimal3null", "DECIMAL(5,0)", "DECIMAL", scanTypeRawBytes, true, 5, 0, [3]string{"0", "NULL", "-12345.123456"}, [3]interface{}{rb0, rbNULL, sql.RawBytes("-12345")}},
2774		{"char25null", "CHAR(25)", "CHAR", scanTypeRawBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{rb0, rbNULL, rbTest}},
2775		{"varchar42", "VARCHAR(42) NOT NULL", "VARCHAR", scanTypeRawBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{rb0, rbTest, rb42}},
2776		{"binary4null", "BINARY(4)", "BINARY", scanTypeRawBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{rb0pad4, rbNULL, rbTest}},
2777		{"varbinary42", "VARBINARY(42) NOT NULL", "VARBINARY", scanTypeRawBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{rb0, rbTest, rb42}},
2778		{"tinyblobnull", "TINYBLOB", "BLOB", scanTypeRawBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{rb0, rbNULL, rbTest}},
2779		{"tinytextnull", "TINYTEXT", "TEXT", scanTypeRawBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{rb0, rbNULL, rbTest}},
2780		{"blobnull", "BLOB", "BLOB", scanTypeRawBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{rb0, rbNULL, rbTest}},
2781		{"textnull", "TEXT", "TEXT", scanTypeRawBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{rb0, rbNULL, rbTest}},
2782		{"mediumblob", "MEDIUMBLOB NOT NULL", "BLOB", scanTypeRawBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{rb0, rbTest, rb42}},
2783		{"mediumtext", "MEDIUMTEXT NOT NULL", "TEXT", scanTypeRawBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{rb0, rbTest, rb42}},
2784		{"longblob", "LONGBLOB NOT NULL", "BLOB", scanTypeRawBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{rb0, rbTest, rb42}},
2785		{"longtext", "LONGTEXT NOT NULL", "TEXT", scanTypeRawBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{rb0, rbTest, rb42}},
2786		{"datetime", "DATETIME", "DATETIME", scanTypeNullTime, true, 0, 0, [3]string{"'2006-01-02 15:04:05'", "'2006-01-02 15:04:05.1'", "'2006-01-02 15:04:05.111111'"}, [3]interface{}{nt0, nt0, nt0}},
2787		{"datetime2", "DATETIME(2)", "DATETIME", scanTypeNullTime, true, 2, 2, [3]string{"'2006-01-02 15:04:05'", "'2006-01-02 15:04:05.1'", "'2006-01-02 15:04:05.111111'"}, [3]interface{}{nt0, nt1, nt2}},
2788		{"datetime6", "DATETIME(6)", "DATETIME", scanTypeNullTime, true, 6, 6, [3]string{"'2006-01-02 15:04:05'", "'2006-01-02 15:04:05.1'", "'2006-01-02 15:04:05.111111'"}, [3]interface{}{nt0, nt1, nt6}},
2789		{"date", "DATE", "DATE", scanTypeNullTime, true, 0, 0, [3]string{"'2006-01-02'", "NULL", "'2006-03-04'"}, [3]interface{}{nd1, ndNULL, nd2}},
2790		{"year", "YEAR NOT NULL", "YEAR", scanTypeUint16, false, 0, 0, [3]string{"2006", "2000", "1994"}, [3]interface{}{uint16(2006), uint16(2000), uint16(1994)}},
2791	}
2792
2793	schema := ""
2794	values1 := ""
2795	values2 := ""
2796	values3 := ""
2797	for _, column := range columns {
2798		schema += fmt.Sprintf("`%s` %s, ", column.name, column.fieldType)
2799		values1 += column.valuesIn[0] + ", "
2800		values2 += column.valuesIn[1] + ", "
2801		values3 += column.valuesIn[2] + ", "
2802	}
2803	schema = schema[:len(schema)-2]
2804	values1 = values1[:len(values1)-2]
2805	values2 = values2[:len(values2)-2]
2806	values3 = values3[:len(values3)-2]
2807
2808	dsns := []string{
2809		dsn + "&parseTime=true",
2810		dsn + "&parseTime=false",
2811	}
2812	for _, testdsn := range dsns {
2813		runTests(t, testdsn, func(dbt *DBTest) {
2814			dbt.mustExec("CREATE TABLE test (" + schema + ")")
2815			dbt.mustExec("INSERT INTO test VALUES (" + values1 + "), (" + values2 + "), (" + values3 + ")")
2816
2817			rows, err := dbt.db.Query("SELECT * FROM test")
2818			if err != nil {
2819				t.Fatalf("Query: %v", err)
2820			}
2821
2822			tt, err := rows.ColumnTypes()
2823			if err != nil {
2824				t.Fatalf("ColumnTypes: %v", err)
2825			}
2826
2827			if len(tt) != len(columns) {
2828				t.Fatalf("unexpected number of columns: expected %d, got %d", len(columns), len(tt))
2829			}
2830
2831			types := make([]reflect.Type, len(tt))
2832			for i, tp := range tt {
2833				column := columns[i]
2834
2835				// Name
2836				name := tp.Name()
2837				if name != column.name {
2838					t.Errorf("column name mismatch %s != %s", name, column.name)
2839					continue
2840				}
2841
2842				// DatabaseTypeName
2843				databaseTypeName := tp.DatabaseTypeName()
2844				if databaseTypeName != column.databaseTypeName {
2845					t.Errorf("databasetypename name mismatch for column %q: %s != %s", name, databaseTypeName, column.databaseTypeName)
2846					continue
2847				}
2848
2849				// ScanType
2850				scanType := tp.ScanType()
2851				if scanType != column.scanType {
2852					if scanType == nil {
2853						t.Errorf("scantype is null for column %q", name)
2854					} else {
2855						t.Errorf("scantype mismatch for column %q: %s != %s", name, scanType.Name(), column.scanType.Name())
2856					}
2857					continue
2858				}
2859				types[i] = scanType
2860
2861				// Nullable
2862				nullable, ok := tp.Nullable()
2863				if !ok {
2864					t.Errorf("nullable not ok %q", name)
2865					continue
2866				}
2867				if nullable != column.nullable {
2868					t.Errorf("nullable mismatch for column %q: %t != %t", name, nullable, column.nullable)
2869				}
2870
2871				// Length
2872				// length, ok := tp.Length()
2873				// if length != column.length {
2874				// 	if !ok {
2875				// 		t.Errorf("length not ok for column %q", name)
2876				// 	} else {
2877				// 		t.Errorf("length mismatch for column %q: %d != %d", name, length, column.length)
2878				// 	}
2879				// 	continue
2880				// }
2881
2882				// Precision and Scale
2883				precision, scale, ok := tp.DecimalSize()
2884				if precision != column.precision {
2885					if !ok {
2886						t.Errorf("precision not ok for column %q", name)
2887					} else {
2888						t.Errorf("precision mismatch for column %q: %d != %d", name, precision, column.precision)
2889					}
2890					continue
2891				}
2892				if scale != column.scale {
2893					if !ok {
2894						t.Errorf("scale not ok for column %q", name)
2895					} else {
2896						t.Errorf("scale mismatch for column %q: %d != %d", name, scale, column.scale)
2897					}
2898					continue
2899				}
2900			}
2901
2902			values := make([]interface{}, len(tt))
2903			for i := range values {
2904				values[i] = reflect.New(types[i]).Interface()
2905			}
2906			i := 0
2907			for rows.Next() {
2908				err = rows.Scan(values...)
2909				if err != nil {
2910					t.Fatalf("failed to scan values in %v", err)
2911				}
2912				for j := range values {
2913					value := reflect.ValueOf(values[j]).Elem().Interface()
2914					if !reflect.DeepEqual(value, columns[j].valuesOut[i]) {
2915						if columns[j].scanType == scanTypeRawBytes {
2916							t.Errorf("row %d, column %d: %v != %v", i, j, string(value.(sql.RawBytes)), string(columns[j].valuesOut[i].(sql.RawBytes)))
2917						} else {
2918							t.Errorf("row %d, column %d: %v != %v", i, j, value, columns[j].valuesOut[i])
2919						}
2920					}
2921				}
2922				i++
2923			}
2924			if i != 3 {
2925				t.Errorf("expected 3 rows, got %d", i)
2926			}
2927
2928			if err := rows.Close(); err != nil {
2929				t.Errorf("error closing rows: %s", err)
2930			}
2931		})
2932	}
2933}
2934
2935func TestValuerWithValueReceiverGivenNilValue(t *testing.T) {
2936	runTests(t, dsn, func(dbt *DBTest) {
2937		dbt.mustExec("CREATE TABLE test (value VARCHAR(255))")
2938		dbt.db.Exec("INSERT INTO test VALUES (?)", (*testValuer)(nil))
2939		// This test will panic on the INSERT if ConvertValue() does not check for typed nil before calling Value()
2940	})
2941}
2942
2943// TestRawBytesAreNotModified checks for a race condition that arises when a query context
2944// is canceled while a user is calling rows.Scan. This is a more stringent test than the one
2945// proposed in https://github.com/golang/go/issues/23519. Here we're explicitly using
2946// `sql.RawBytes` to check the contents of our internal buffers are not modified after an implicit
2947// call to `Rows.Close`, so Context cancellation should **not** invalidate the backing buffers.
2948func TestRawBytesAreNotModified(t *testing.T) {
2949	const blob = "abcdefghijklmnop"
2950	const contextRaceIterations = 20
2951	const blobSize = defaultBufSize * 3 / 4 // Second row overwrites first row.
2952	const insertRows = 4
2953
2954	var sqlBlobs = [2]string{
2955		strings.Repeat(blob, blobSize/len(blob)),
2956		strings.Repeat(strings.ToUpper(blob), blobSize/len(blob)),
2957	}
2958
2959	runTests(t, dsn, func(dbt *DBTest) {
2960		dbt.mustExec("CREATE TABLE test (id int, value BLOB) CHARACTER SET utf8")
2961		for i := 0; i < insertRows; i++ {
2962			dbt.mustExec("INSERT INTO test VALUES (?, ?)", i+1, sqlBlobs[i&1])
2963		}
2964
2965		for i := 0; i < contextRaceIterations; i++ {
2966			func() {
2967				ctx, cancel := context.WithCancel(context.Background())
2968				defer cancel()
2969
2970				rows, err := dbt.db.QueryContext(ctx, `SELECT id, value FROM test`)
2971				if err != nil {
2972					t.Fatal(err)
2973				}
2974
2975				var b int
2976				var raw sql.RawBytes
2977				for rows.Next() {
2978					if err := rows.Scan(&b, &raw); err != nil {
2979						t.Fatal(err)
2980					}
2981
2982					before := string(raw)
2983					// Ensure cancelling the query does not corrupt the contents of `raw`
2984					cancel()
2985					time.Sleep(time.Microsecond * 100)
2986					after := string(raw)
2987
2988					if before != after {
2989						t.Fatalf("the backing storage for sql.RawBytes has been modified (i=%v)", i)
2990					}
2991				}
2992				rows.Close()
2993			}()
2994		}
2995	})
2996}
2997
2998var _ driver.DriverContext = &MySQLDriver{}
2999
// dialCtxKey is the context key under which TestConnectorObeysDialTimeouts
// stores a marker value that its custom dialer looks up. Using an unexported
// struct type as the key avoids collisions with keys defined elsewhere.
type dialCtxKey struct{}
3001
3002func TestConnectorObeysDialTimeouts(t *testing.T) {
3003	if !available {
3004		t.Skipf("MySQL server not running on %s", netAddr)
3005	}
3006
3007	RegisterDialContext("dialctxtest", func(ctx context.Context, addr string) (net.Conn, error) {
3008		var d net.Dialer
3009		if !ctx.Value(dialCtxKey{}).(bool) {
3010			return nil, fmt.Errorf("test error: query context is not propagated to our dialer")
3011		}
3012		return d.DialContext(ctx, prot, addr)
3013	})
3014
3015	db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@dialctxtest(%s)/%s?timeout=30s", user, pass, addr, dbname))
3016	if err != nil {
3017		t.Fatalf("error connecting: %s", err.Error())
3018	}
3019	defer db.Close()
3020
3021	ctx := context.WithValue(context.Background(), dialCtxKey{}, true)
3022
3023	_, err = db.ExecContext(ctx, "DO 1")
3024	if err != nil {
3025		t.Fatal(err)
3026	}
3027}
3028
3029func configForTests(t *testing.T) *Config {
3030	if !available {
3031		t.Skipf("MySQL server not running on %s", netAddr)
3032	}
3033
3034	mycnf := NewConfig()
3035	mycnf.User = user
3036	mycnf.Passwd = pass
3037	mycnf.Addr = addr
3038	mycnf.Net = prot
3039	mycnf.DBName = dbname
3040	return mycnf
3041}
3042
3043func TestNewConnector(t *testing.T) {
3044	mycnf := configForTests(t)
3045	conn, err := NewConnector(mycnf)
3046	if err != nil {
3047		t.Fatal(err)
3048	}
3049
3050	db := sql.OpenDB(conn)
3051	defer db.Close()
3052
3053	if err := db.Ping(); err != nil {
3054		t.Fatal(err)
3055	}
3056}
3057
// slowConnection wraps a net.Conn and delays every Read by slowdown. It is
// used to simulate a server that is slow to respond.
type slowConnection struct {
	net.Conn
	slowdown time.Duration
}

// Read waits for the configured slowdown and then reads from the wrapped
// connection, returning its result unchanged.
func (sc *slowConnection) Read(p []byte) (int, error) {
	delay := sc.slowdown
	time.Sleep(delay)
	n, err := sc.Conn.Read(p)
	return n, err
}
3067
// connectorHijack wraps a driver.Connector and records the error returned by
// its most recent Connect call in connErr, so tests can inspect it later.
type connectorHijack struct {
	driver.Connector
	connErr error
}

// Connect delegates to the wrapped connector, stores the resulting error in
// connErr, and returns the connection and error unchanged.
func (cw *connectorHijack) Connect(ctx context.Context) (driver.Conn, error) {
	conn, err := cw.Connector.Connect(ctx)
	cw.connErr = err
	return conn, err
}
3078
3079func TestConnectorTimeoutsDuringOpen(t *testing.T) {
3080	RegisterDialContext("slowconn", func(ctx context.Context, addr string) (net.Conn, error) {
3081		var d net.Dialer
3082		conn, err := d.DialContext(ctx, prot, addr)
3083		if err != nil {
3084			return nil, err
3085		}
3086		return &slowConnection{Conn: conn, slowdown: 100 * time.Millisecond}, nil
3087	})
3088
3089	mycnf := configForTests(t)
3090	mycnf.Net = "slowconn"
3091
3092	conn, err := NewConnector(mycnf)
3093	if err != nil {
3094		t.Fatal(err)
3095	}
3096
3097	hijack := &connectorHijack{Connector: conn}
3098
3099	db := sql.OpenDB(hijack)
3100	defer db.Close()
3101
3102	ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
3103	defer cancel()
3104
3105	_, err = db.ExecContext(ctx, "DO 1")
3106	if err != context.DeadlineExceeded {
3107		t.Fatalf("ExecContext should have timed out")
3108	}
3109	if hijack.connErr != context.DeadlineExceeded {
3110		t.Fatalf("(*Connector).Connect should have timed out")
3111	}
3112}
3113
// dummyConnection is a net.Conn stand-in whose only usable method is Close.
// Close records that it was called by setting closed. When the embedded Conn
// is left nil, calling any other net.Conn method panics.
type dummyConnection struct {
	net.Conn
	closed bool
}

// Close marks the connection as closed and always reports success.
func (d *dummyConnection) Close() error {
	if !d.closed {
		d.closed = true
	}
	return nil
}
3124
3125func TestConnectorTimeoutsWatchCancel(t *testing.T) {
3126	var (
3127		cancel  func()           // Used to cancel the context just after connecting.
3128		created *dummyConnection // The created connection.
3129	)
3130
3131	RegisterDialContext("TestConnectorTimeoutsWatchCancel", func(ctx context.Context, addr string) (net.Conn, error) {
3132		// Canceling at this time triggers the watchCancel error branch in Connect().
3133		cancel()
3134		created = &dummyConnection{}
3135		return created, nil
3136	})
3137
3138	mycnf := NewConfig()
3139	mycnf.User = "root"
3140	mycnf.Addr = "foo"
3141	mycnf.Net = "TestConnectorTimeoutsWatchCancel"
3142
3143	conn, err := NewConnector(mycnf)
3144	if err != nil {
3145		t.Fatal(err)
3146	}
3147
3148	db := sql.OpenDB(conn)
3149	defer db.Close()
3150
3151	var ctx context.Context
3152	ctx, cancel = context.WithCancel(context.Background())
3153	defer cancel()
3154
3155	if _, err := db.Conn(ctx); err != context.Canceled {
3156		t.Errorf("got %v, want context.Canceled", err)
3157	}
3158
3159	if created == nil {
3160		t.Fatal("no connection created")
3161	}
3162	if !created.closed {
3163		t.Errorf("connection not closed")
3164	}
3165}
3166