1// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
2//
3// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
4//
5// This Source Code Form is subject to the terms of the Mozilla Public
6// License, v. 2.0. If a copy of the MPL was not distributed with this file,
7// You can obtain one at http://mozilla.org/MPL/2.0/.
8
9package mysql
10
11import (
12	"context"
13	"database/sql"
14	"database/sql/driver"
15	"encoding/json"
16	"io"
17	"net"
18	"strconv"
19	"strings"
20	"time"
21)
22
// mysqlConn holds the client-side state of a single connection to a MySQL
// server. Its methods in this file implement database/sql/driver interfaces
// such as driver.Conn, driver.Pinger, driver.ConnBeginTx,
// driver.SessionResetter and driver.Validator.
//
// NOTE(review): no mutex guards these fields. Besides the goroutine using the
// connection, the watcher goroutine started by startWatcher touches it via
// cancel/cleanup; closed and canceled are presumably atomic wrappers for that
// reason — confirm before adding more cross-goroutine state.
type mysqlConn struct {
	// buf is the packet I/O buffer; interpolateParams also borrows its storage
	// (takeCompleteBuffer) to assemble query text.
	buf              buffer
	netConn          net.Conn
	rawConn          net.Conn // underlying connection when netConn is TLS connection.
	// affectedRows and insertId are zeroed by Exec before each statement and
	// reported back through the mysqlResult it returns.
	affectedRows     uint64
	insertId         uint64
	cfg              *Config // DSN configuration (Params, InterpolateParams, Loc, ...)
	// maxAllowedPacket bounds interpolateParams: it gives up (driver.ErrSkip)
	// once the interpolated text plus 4 bytes exceeds this value.
	maxAllowedPacket int
	// maxWriteSize and writeTimeout are not used in this file; presumably they
	// limit and time out packet writes — confirm in the packet code.
	maxWriteSize     int
	writeTimeout     time.Duration
	flags            clientFlag // presumably the negotiated client capability flags
	// status holds server status flags; interpolateParams checks
	// statusNoBackslashEscapes in it to choose an escaping style.
	status           statusFlag
	sequence         uint8 // presumably the packet sequence id — not used in this file
	parseTime        bool  // not used in this file; presumably mirrors the DSN parseTime option
	reset            bool // set when the Go SQL package calls ResetSession

	// for context support (Go 1.8+)
	// watching is set by watchCancel when a context is handed to the watcher
	// and cleared by finish once the watcher has acknowledged completion.
	watching bool
	// watcher delivers contexts to the goroutine started by startWatcher; it
	// is nil when no watcher was started.
	watcher  chan<- context.Context
	// closech is closed by cleanup; it tells the watcher goroutine and finish
	// that the connection is gone.
	closech  chan struct{}
	// finished is signaled by finish to end the current watch.
	finished chan<- struct{}
	canceled atomicError // set non-nil if conn is canceled
	closed   atomicBool  // set when conn is closed, before closech is closed
}
47
48// Handles parameters set in DSN after the connection is established
49func (mc *mysqlConn) handleParams() (err error) {
50	var cmdSet strings.Builder
51	for param, val := range mc.cfg.Params {
52		switch param {
53		// Charset: character_set_connection, character_set_client, character_set_results
54		case "charset":
55			charsets := strings.Split(val, ",")
56			for i := range charsets {
57				// ignore errors here - a charset may not exist
58				err = mc.exec("SET NAMES " + charsets[i])
59				if err == nil {
60					break
61				}
62			}
63			if err != nil {
64				return
65			}
66
67		// Other system vars accumulated in a single SET command
68		default:
69			if cmdSet.Len() == 0 {
70				// Heuristic: 29 chars for each other key=value to reduce reallocations
71				cmdSet.Grow(4 + len(param) + 1 + len(val) + 30*(len(mc.cfg.Params)-1))
72				cmdSet.WriteString("SET ")
73			} else {
74				cmdSet.WriteByte(',')
75			}
76			cmdSet.WriteString(param)
77			cmdSet.WriteByte('=')
78			cmdSet.WriteString(val)
79		}
80	}
81
82	if cmdSet.Len() > 0 {
83		err = mc.exec(cmdSet.String())
84		if err != nil {
85			return
86		}
87	}
88
89	return
90}
91
92func (mc *mysqlConn) markBadConn(err error) error {
93	if mc == nil {
94		return err
95	}
96	if err != errBadConnNoWrite {
97		return err
98	}
99	return driver.ErrBadConn
100}
101
102func (mc *mysqlConn) Begin() (driver.Tx, error) {
103	return mc.begin(false)
104}
105
106func (mc *mysqlConn) begin(readOnly bool) (driver.Tx, error) {
107	if mc.closed.IsSet() {
108		errLog.Print(ErrInvalidConn)
109		return nil, driver.ErrBadConn
110	}
111	var q string
112	if readOnly {
113		q = "START TRANSACTION READ ONLY"
114	} else {
115		q = "START TRANSACTION"
116	}
117	err := mc.exec(q)
118	if err == nil {
119		return &mysqlTx{mc}, err
120	}
121	return nil, mc.markBadConn(err)
122}
123
124func (mc *mysqlConn) Close() (err error) {
125	// Makes Close idempotent
126	if !mc.closed.IsSet() {
127		err = mc.writeCommandPacket(comQuit)
128	}
129
130	mc.cleanup()
131
132	return
133}
134
135// Closes the network connection and unsets internal variables. Do not call this
136// function after successfully authentication, call Close instead. This function
137// is called before auth or on auth failure because MySQL will have already
138// closed the network connection.
139func (mc *mysqlConn) cleanup() {
140	if !mc.closed.TrySet(true) {
141		return
142	}
143
144	// Makes cleanup idempotent
145	close(mc.closech)
146	if mc.netConn == nil {
147		return
148	}
149	if err := mc.netConn.Close(); err != nil {
150		errLog.Print(err)
151	}
152}
153
154func (mc *mysqlConn) error() error {
155	if mc.closed.IsSet() {
156		if err := mc.canceled.Value(); err != nil {
157			return err
158		}
159		return ErrInvalidConn
160	}
161	return nil
162}
163
164func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
165	if mc.closed.IsSet() {
166		errLog.Print(ErrInvalidConn)
167		return nil, driver.ErrBadConn
168	}
169	// Send command
170	err := mc.writeCommandPacketStr(comStmtPrepare, query)
171	if err != nil {
172		// STMT_PREPARE is safe to retry.  So we can return ErrBadConn here.
173		errLog.Print(err)
174		return nil, driver.ErrBadConn
175	}
176
177	stmt := &mysqlStmt{
178		mc: mc,
179	}
180
181	// Read Result
182	columnCount, err := stmt.readPrepareResultPacket()
183	if err == nil {
184		if stmt.paramCount > 0 {
185			if err = mc.readUntilEOF(); err != nil {
186				return nil, err
187			}
188		}
189
190		if columnCount > 0 {
191			err = mc.readUntilEOF()
192		}
193	}
194
195	return stmt, err
196}
197
// interpolateParams builds the final SQL text on the client. It replaces each
// '?' placeholder in query, in order, with the SQL literal form of the
// corresponding element of args. This lets Exec and query send a plain
// COM_QUERY instead of preparing, executing and closing a server-side
// statement.
//
// It returns driver.ErrSkip, making database/sql fall back to a server-side
// prepared statement, when:
//   - the number of '?' characters in query differs from len(args),
//   - an argument's type is not one of the cases handled below, or
//   - the text built so far plus 4 bytes exceeds mc.maxAllowedPacket.
//
// NOTE(review): every '?' is treated as a placeholder, including one inside a
// quoted string or a comment in query.
//
// The text is assembled in the connection's own buffer (takeCompleteBuffer)
// and copied out by the final string conversion. If that buffer cannot be
// taken, the error is logged and ErrInvalidConn is returned. An error from
// appendDateTime is returned unchanged.
func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (string, error) {
	// The number of '?' placeholders must equal len(args).
	if strings.Count(query, "?") != len(args) {
		return "", driver.ErrSkip
	}

	buf, err := mc.buf.takeCompleteBuffer()
	if err != nil {
		// Cannot take the buffer; something must be wrong with the connection.
		errLog.Print(err)
		return "", ErrInvalidConn
	}
	// Keep the buffer's storage but discard its previous contents.
	buf = buf[:0]
	argPos := 0

	for i := 0; i < len(query); i++ {
		// q is the offset of the next '?' relative to i, or -1 when none remain.
		q := strings.IndexByte(query[i:], '?')
		if q == -1 {
			// Copy the remaining text after the last placeholder.
			buf = append(buf, query[i:]...)
			break
		}
		buf = append(buf, query[i:i+q]...)
		// Move i onto the '?'; the loop's i++ then steps past it.
		i += q

		arg := args[argPos]
		argPos++

		if arg == nil {
			buf = append(buf, "NULL"...)
			continue
		}

		switch v := arg.(type) {
		case int64:
			buf = strconv.AppendInt(buf, v, 10)
		case uint64:
			// Handle uint64 explicitly because our custom ConvertValue emits unsigned values
			buf = strconv.AppendUint(buf, v, 10)
		case float64:
			// 'g' with precision -1: the shortest text that parses back to v.
			buf = strconv.AppendFloat(buf, v, 'g', -1, 64)
		case bool:
			// Booleans are written as 1 / 0.
			if v {
				buf = append(buf, '1')
			} else {
				buf = append(buf, '0')
			}
		case time.Time:
			// The zero time becomes the zero date literal; any other instant
			// is converted to mc.cfg.Loc and formatted inside single quotes.
			if v.IsZero() {
				buf = append(buf, "'0000-00-00'"...)
			} else {
				buf = append(buf, '\'')
				buf, err = appendDateTime(buf, v.In(mc.cfg.Loc))
				if err != nil {
					return "", err
				}
				buf = append(buf, '\'')
			}
		case json.RawMessage:
			// Written as a quoted string. The escaping helper depends on
			// whether statusNoBackslashEscapes is set in mc.status (presumably
			// the NO_BACKSLASH_ESCAPES SQL mode).
			buf = append(buf, '\'')
			if mc.status&statusNoBackslashEscapes == 0 {
				buf = escapeBytesBackslash(buf, v)
			} else {
				buf = escapeBytesQuotes(buf, v)
			}
			buf = append(buf, '\'')
		case []byte:
			// A nil slice is NULL. Otherwise a _binary'...' literal is written,
			// using the same escaping choice as above.
			if v == nil {
				buf = append(buf, "NULL"...)
			} else {
				buf = append(buf, "_binary'"...)
				if mc.status&statusNoBackslashEscapes == 0 {
					buf = escapeBytesBackslash(buf, v)
				} else {
					buf = escapeBytesQuotes(buf, v)
				}
				buf = append(buf, '\'')
			}
		case string:
			buf = append(buf, '\'')
			if mc.status&statusNoBackslashEscapes == 0 {
				buf = escapeStringBackslash(buf, v)
			} else {
				buf = escapeStringQuotes(buf, v)
			}
			buf = append(buf, '\'')
		default:
			// Unsupported type: let database/sql use a prepared statement.
			return "", driver.ErrSkip
		}

		// Give up once the text would not fit in one packet (the extra 4
		// bytes are presumably the packet header).
		if len(buf)+4 > mc.maxAllowedPacket {
			return "", driver.ErrSkip
		}
	}
	// Safety net: the placeholder count was already checked against len(args).
	if argPos != len(args) {
		return "", driver.ErrSkip
	}
	return string(buf), nil
}
296
297func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) {
298	if mc.closed.IsSet() {
299		errLog.Print(ErrInvalidConn)
300		return nil, driver.ErrBadConn
301	}
302	if len(args) != 0 {
303		if !mc.cfg.InterpolateParams {
304			return nil, driver.ErrSkip
305		}
306		// try to interpolate the parameters to save extra roundtrips for preparing and closing a statement
307		prepared, err := mc.interpolateParams(query, args)
308		if err != nil {
309			return nil, err
310		}
311		query = prepared
312	}
313	mc.affectedRows = 0
314	mc.insertId = 0
315
316	err := mc.exec(query)
317	if err == nil {
318		return &mysqlResult{
319			affectedRows: int64(mc.affectedRows),
320			insertId:     int64(mc.insertId),
321		}, err
322	}
323	return nil, mc.markBadConn(err)
324}
325
326// Internal function to execute commands
327func (mc *mysqlConn) exec(query string) error {
328	// Send command
329	if err := mc.writeCommandPacketStr(comQuery, query); err != nil {
330		return mc.markBadConn(err)
331	}
332
333	// Read Result
334	resLen, err := mc.readResultSetHeaderPacket()
335	if err != nil {
336		return err
337	}
338
339	if resLen > 0 {
340		// columns
341		if err := mc.readUntilEOF(); err != nil {
342			return err
343		}
344
345		// rows
346		if err := mc.readUntilEOF(); err != nil {
347			return err
348		}
349	}
350
351	return mc.discardResults()
352}
353
354func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) {
355	return mc.query(query, args)
356}
357
358func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) {
359	if mc.closed.IsSet() {
360		errLog.Print(ErrInvalidConn)
361		return nil, driver.ErrBadConn
362	}
363	if len(args) != 0 {
364		if !mc.cfg.InterpolateParams {
365			return nil, driver.ErrSkip
366		}
367		// try client-side prepare to reduce roundtrip
368		prepared, err := mc.interpolateParams(query, args)
369		if err != nil {
370			return nil, err
371		}
372		query = prepared
373	}
374	// Send command
375	err := mc.writeCommandPacketStr(comQuery, query)
376	if err == nil {
377		// Read Result
378		var resLen int
379		resLen, err = mc.readResultSetHeaderPacket()
380		if err == nil {
381			rows := new(textRows)
382			rows.mc = mc
383
384			if resLen == 0 {
385				rows.rs.done = true
386
387				switch err := rows.NextResultSet(); err {
388				case nil, io.EOF:
389					return rows, nil
390				default:
391					return nil, err
392				}
393			}
394
395			// Columns
396			rows.rs.columns, err = mc.readColumns(resLen)
397			return rows, err
398		}
399	}
400	return nil, mc.markBadConn(err)
401}
402
403// Gets the value of the given MySQL System Variable
404// The returned byte slice is only valid until the next read
405func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {
406	// Send command
407	if err := mc.writeCommandPacketStr(comQuery, "SELECT @@"+name); err != nil {
408		return nil, err
409	}
410
411	// Read Result
412	resLen, err := mc.readResultSetHeaderPacket()
413	if err == nil {
414		rows := new(textRows)
415		rows.mc = mc
416		rows.rs.columns = []mysqlField{{fieldType: fieldTypeVarChar}}
417
418		if resLen > 0 {
419			// Columns
420			if err := mc.readUntilEOF(); err != nil {
421				return nil, err
422			}
423		}
424
425		dest := make([]driver.Value, resLen)
426		if err = rows.readRow(dest); err == nil {
427			return dest[0].([]byte), mc.readUntilEOF()
428		}
429	}
430	return nil, err
431}
432
433// finish is called when the query has canceled.
434func (mc *mysqlConn) cancel(err error) {
435	mc.canceled.Set(err)
436	mc.cleanup()
437}
438
439// finish is called when the query has succeeded.
440func (mc *mysqlConn) finish() {
441	if !mc.watching || mc.finished == nil {
442		return
443	}
444	select {
445	case mc.finished <- struct{}{}:
446		mc.watching = false
447	case <-mc.closech:
448	}
449}
450
451// Ping implements driver.Pinger interface
452func (mc *mysqlConn) Ping(ctx context.Context) (err error) {
453	if mc.closed.IsSet() {
454		errLog.Print(ErrInvalidConn)
455		return driver.ErrBadConn
456	}
457
458	if err = mc.watchCancel(ctx); err != nil {
459		return
460	}
461	defer mc.finish()
462
463	if err = mc.writeCommandPacket(comPing); err != nil {
464		return mc.markBadConn(err)
465	}
466
467	return mc.readResultOK()
468}
469
470// BeginTx implements driver.ConnBeginTx interface
471func (mc *mysqlConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
472	if mc.closed.IsSet() {
473		return nil, driver.ErrBadConn
474	}
475
476	if err := mc.watchCancel(ctx); err != nil {
477		return nil, err
478	}
479	defer mc.finish()
480
481	if sql.IsolationLevel(opts.Isolation) != sql.LevelDefault {
482		level, err := mapIsolationLevel(opts.Isolation)
483		if err != nil {
484			return nil, err
485		}
486		err = mc.exec("SET TRANSACTION ISOLATION LEVEL " + level)
487		if err != nil {
488			return nil, err
489		}
490	}
491
492	return mc.begin(opts.ReadOnly)
493}
494
495func (mc *mysqlConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
496	dargs, err := namedValueToValue(args)
497	if err != nil {
498		return nil, err
499	}
500
501	if err := mc.watchCancel(ctx); err != nil {
502		return nil, err
503	}
504
505	rows, err := mc.query(query, dargs)
506	if err != nil {
507		mc.finish()
508		return nil, err
509	}
510	rows.finish = mc.finish
511	return rows, err
512}
513
514func (mc *mysqlConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
515	dargs, err := namedValueToValue(args)
516	if err != nil {
517		return nil, err
518	}
519
520	if err := mc.watchCancel(ctx); err != nil {
521		return nil, err
522	}
523	defer mc.finish()
524
525	return mc.Exec(query, dargs)
526}
527
528func (mc *mysqlConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
529	if err := mc.watchCancel(ctx); err != nil {
530		return nil, err
531	}
532
533	stmt, err := mc.Prepare(query)
534	mc.finish()
535	if err != nil {
536		return nil, err
537	}
538
539	select {
540	default:
541	case <-ctx.Done():
542		stmt.Close()
543		return nil, ctx.Err()
544	}
545	return stmt, nil
546}
547
548func (stmt *mysqlStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
549	dargs, err := namedValueToValue(args)
550	if err != nil {
551		return nil, err
552	}
553
554	if err := stmt.mc.watchCancel(ctx); err != nil {
555		return nil, err
556	}
557
558	rows, err := stmt.query(dargs)
559	if err != nil {
560		stmt.mc.finish()
561		return nil, err
562	}
563	rows.finish = stmt.mc.finish
564	return rows, err
565}
566
567func (stmt *mysqlStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
568	dargs, err := namedValueToValue(args)
569	if err != nil {
570		return nil, err
571	}
572
573	if err := stmt.mc.watchCancel(ctx); err != nil {
574		return nil, err
575	}
576	defer stmt.mc.finish()
577
578	return stmt.Exec(dargs)
579}
580
581func (mc *mysqlConn) watchCancel(ctx context.Context) error {
582	if mc.watching {
583		// Reach here if canceled,
584		// so the connection is already invalid
585		mc.cleanup()
586		return nil
587	}
588	// When ctx is already cancelled, don't watch it.
589	if err := ctx.Err(); err != nil {
590		return err
591	}
592	// When ctx is not cancellable, don't watch it.
593	if ctx.Done() == nil {
594		return nil
595	}
596	// When watcher is not alive, can't watch it.
597	if mc.watcher == nil {
598		return nil
599	}
600
601	mc.watching = true
602	mc.watcher <- ctx
603	return nil
604}
605
606func (mc *mysqlConn) startWatcher() {
607	watcher := make(chan context.Context, 1)
608	mc.watcher = watcher
609	finished := make(chan struct{})
610	mc.finished = finished
611	go func() {
612		for {
613			var ctx context.Context
614			select {
615			case ctx = <-watcher:
616			case <-mc.closech:
617				return
618			}
619
620			select {
621			case <-ctx.Done():
622				mc.cancel(ctx.Err())
623			case <-finished:
624			case <-mc.closech:
625				return
626			}
627		}
628	}()
629}
630
631func (mc *mysqlConn) CheckNamedValue(nv *driver.NamedValue) (err error) {
632	nv.Value, err = converter{}.ConvertValue(nv.Value)
633	return
634}
635
636// ResetSession implements driver.SessionResetter.
637// (From Go 1.10)
638func (mc *mysqlConn) ResetSession(ctx context.Context) error {
639	if mc.closed.IsSet() {
640		return driver.ErrBadConn
641	}
642	mc.reset = true
643	return nil
644}
645
646// IsValid implements driver.Validator interface
647// (From Go 1.15)
648func (mc *mysqlConn) IsValid() bool {
649	return !mc.closed.IsSet()
650}
651