1// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
2//
3// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
4//
5// This Source Code Form is subject to the terms of the Mozilla Public
6// License, v. 2.0. If a copy of the MPL was not distributed with this file,
7// You can obtain one at http://mozilla.org/MPL/2.0/.
8
9package mysql
10
11import (
12	"bytes"
13	"crypto/tls"
14	"database/sql/driver"
15	"encoding/binary"
16	"errors"
17	"fmt"
18	"io"
19	"math"
20	"time"
21)
22
23// Packets documentation:
24// http://dev.mysql.com/doc/internals/en/client-server-protocol.html
25
26// Read packet to buffer 'data'
27func (mc *mysqlConn) readPacket() ([]byte, error) {
28	var prevData []byte
29	for {
30		// read packet header
31		data, err := mc.buf.readNext(4)
32		if err != nil {
33			if cerr := mc.canceled.Value(); cerr != nil {
34				return nil, cerr
35			}
36			errLog.Print(err)
37			mc.Close()
38			return nil, ErrInvalidConn
39		}
40
41		// packet length [24 bit]
42		pktLen := int(uint32(data[0]) | uint32(data[1])<<8 | uint32(data[2])<<16)
43
44		// check packet sync [8 bit]
45		if data[3] != mc.sequence {
46			if data[3] > mc.sequence {
47				return nil, ErrPktSyncMul
48			}
49			return nil, ErrPktSync
50		}
51		mc.sequence++
52
53		// packets with length 0 terminate a previous packet which is a
54		// multiple of (2^24)−1 bytes long
55		if pktLen == 0 {
56			// there was no previous packet
57			if prevData == nil {
58				errLog.Print(ErrMalformPkt)
59				mc.Close()
60				return nil, ErrInvalidConn
61			}
62
63			return prevData, nil
64		}
65
66		// read packet body [pktLen bytes]
67		data, err = mc.buf.readNext(pktLen)
68		if err != nil {
69			if cerr := mc.canceled.Value(); cerr != nil {
70				return nil, cerr
71			}
72			errLog.Print(err)
73			mc.Close()
74			return nil, ErrInvalidConn
75		}
76
77		// return data if this was the last packet
78		if pktLen < maxPacketSize {
79			// zero allocations for non-split packets
80			if prevData == nil {
81				return data, nil
82			}
83
84			return append(prevData, data...), nil
85		}
86
87		prevData = append(prevData, data...)
88	}
89}
90
91// Write packet buffer 'data'
92func (mc *mysqlConn) writePacket(data []byte) error {
93	pktLen := len(data) - 4
94
95	if pktLen > mc.maxAllowedPacket {
96		return ErrPktTooLarge
97	}
98
99	for {
100		var size int
101		if pktLen >= maxPacketSize {
102			data[0] = 0xff
103			data[1] = 0xff
104			data[2] = 0xff
105			size = maxPacketSize
106		} else {
107			data[0] = byte(pktLen)
108			data[1] = byte(pktLen >> 8)
109			data[2] = byte(pktLen >> 16)
110			size = pktLen
111		}
112		data[3] = mc.sequence
113
114		// Write packet
115		if mc.writeTimeout > 0 {
116			if err := mc.netConn.SetWriteDeadline(time.Now().Add(mc.writeTimeout)); err != nil {
117				return err
118			}
119		}
120
121		n, err := mc.netConn.Write(data[:4+size])
122		if err == nil && n == 4+size {
123			mc.sequence++
124			if size != maxPacketSize {
125				return nil
126			}
127			pktLen -= size
128			data = data[size:]
129			continue
130		}
131
132		// Handle error
133		if err == nil { // n != len(data)
134			mc.cleanup()
135			errLog.Print(ErrMalformPkt)
136		} else {
137			if cerr := mc.canceled.Value(); cerr != nil {
138				return cerr
139			}
140			if n == 0 && pktLen == len(data)-4 {
141				// only for the first loop iteration when nothing was written yet
142				return errBadConnNoWrite
143			}
144			mc.cleanup()
145			errLog.Print(err)
146		}
147		return ErrInvalidConn
148	}
149}
150
151/******************************************************************************
152*                           Initialization Process                            *
153******************************************************************************/
154
155// Handshake Initialization Packet
156// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake
157func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err error) {
158	data, err = mc.readPacket()
159	if err != nil {
160		// for init we can rewrite this to ErrBadConn for sql.Driver to retry, since
161		// in connection initialization we don't risk retrying non-idempotent actions.
162		if err == ErrInvalidConn {
163			return nil, "", driver.ErrBadConn
164		}
165		return
166	}
167
168	if data[0] == iERR {
169		return nil, "", mc.handleErrorPacket(data)
170	}
171
172	// protocol version [1 byte]
173	if data[0] < minProtocolVersion {
174		return nil, "", fmt.Errorf(
175			"unsupported protocol version %d. Version %d or higher is required",
176			data[0],
177			minProtocolVersion,
178		)
179	}
180
181	// server version [null terminated string]
182	// connection id [4 bytes]
183	pos := 1 + bytes.IndexByte(data[1:], 0x00) + 1 + 4
184
185	// first part of the password cipher [8 bytes]
186	authData := data[pos : pos+8]
187
188	// (filler) always 0x00 [1 byte]
189	pos += 8 + 1
190
191	// capability flags (lower 2 bytes) [2 bytes]
192	mc.flags = clientFlag(binary.LittleEndian.Uint16(data[pos : pos+2]))
193	if mc.flags&clientProtocol41 == 0 {
194		return nil, "", ErrOldProtocol
195	}
196	if mc.flags&clientSSL == 0 && mc.cfg.tls != nil {
197		return nil, "", ErrNoTLS
198	}
199	pos += 2
200
201	if len(data) > pos {
202		// character set [1 byte]
203		// status flags [2 bytes]
204		// capability flags (upper 2 bytes) [2 bytes]
205		// length of auth-plugin-data [1 byte]
206		// reserved (all [00]) [10 bytes]
207		pos += 1 + 2 + 2 + 1 + 10
208
209		// second part of the password cipher [mininum 13 bytes],
210		// where len=MAX(13, length of auth-plugin-data - 8)
211		//
212		// The web documentation is ambiguous about the length. However,
213		// according to mysql-5.7/sql/auth/sql_authentication.cc line 538,
214		// the 13th byte is "\0 byte, terminating the second part of
215		// a scramble". So the second part of the password cipher is
216		// a NULL terminated string that's at least 13 bytes with the
217		// last byte being NULL.
218		//
219		// The official Python library uses the fixed length 12
220		// which seems to work but technically could have a hidden bug.
221		authData = append(authData, data[pos:pos+12]...)
222		pos += 13
223
224		// EOF if version (>= 5.5.7 and < 5.5.10) or (>= 5.6.0 and < 5.6.2)
225		// \NUL otherwise
226		if end := bytes.IndexByte(data[pos:], 0x00); end != -1 {
227			plugin = string(data[pos : pos+end])
228		} else {
229			plugin = string(data[pos:])
230		}
231
232		// make a memory safe copy of the cipher slice
233		var b [20]byte
234		copy(b[:], authData)
235		return b[:], plugin, nil
236	}
237
238	// make a memory safe copy of the cipher slice
239	var b [8]byte
240	copy(b[:], authData)
241	return b[:], plugin, nil
242}
243
244// Client Authentication Packet
245// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
246func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string) error {
247	// Adjust client flags based on server support
248	clientFlags := clientProtocol41 |
249		clientSecureConn |
250		clientLongPassword |
251		clientTransactions |
252		clientLocalFiles |
253		clientPluginAuth |
254		clientMultiResults |
255		mc.flags&clientLongFlag
256
257	if mc.cfg.ClientFoundRows {
258		clientFlags |= clientFoundRows
259	}
260
261	// To enable TLS / SSL
262	if mc.cfg.tls != nil {
263		clientFlags |= clientSSL
264	}
265
266	if mc.cfg.MultiStatements {
267		clientFlags |= clientMultiStatements
268	}
269
270	// encode length of the auth plugin data
271	var authRespLEIBuf [9]byte
272	authRespLen := len(authResp)
273	authRespLEI := appendLengthEncodedInteger(authRespLEIBuf[:0], uint64(authRespLen))
274	if len(authRespLEI) > 1 {
275		// if the length can not be written in 1 byte, it must be written as a
276		// length encoded integer
277		clientFlags |= clientPluginAuthLenEncClientData
278	}
279
280	pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + len(authRespLEI) + len(authResp) + 21 + 1
281
282	// To specify a db name
283	if n := len(mc.cfg.DBName); n > 0 {
284		clientFlags |= clientConnectWithDB
285		pktLen += n + 1
286	}
287
288	// Calculate packet length and get buffer with that size
289	data := mc.buf.takeSmallBuffer(pktLen + 4)
290	if data == nil {
291		// cannot take the buffer. Something must be wrong with the connection
292		errLog.Print(ErrBusyBuffer)
293		return errBadConnNoWrite
294	}
295
296	// ClientFlags [32 bit]
297	data[4] = byte(clientFlags)
298	data[5] = byte(clientFlags >> 8)
299	data[6] = byte(clientFlags >> 16)
300	data[7] = byte(clientFlags >> 24)
301
302	// MaxPacketSize [32 bit] (none)
303	data[8] = 0x00
304	data[9] = 0x00
305	data[10] = 0x00
306	data[11] = 0x00
307
308	// Charset [1 byte]
309	var found bool
310	data[12], found = collations[mc.cfg.Collation]
311	if !found {
312		// Note possibility for false negatives:
313		// could be triggered  although the collation is valid if the
314		// collations map does not contain entries the server supports.
315		return errors.New("unknown collation")
316	}
317
318	// SSL Connection Request Packet
319	// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest
320	if mc.cfg.tls != nil {
321		// Send TLS / SSL request packet
322		if err := mc.writePacket(data[:(4+4+1+23)+4]); err != nil {
323			return err
324		}
325
326		// Switch to TLS
327		tlsConn := tls.Client(mc.netConn, mc.cfg.tls)
328		if err := tlsConn.Handshake(); err != nil {
329			return err
330		}
331		mc.netConn = tlsConn
332		mc.buf.nc = tlsConn
333	}
334
335	// Filler [23 bytes] (all 0x00)
336	pos := 13
337	for ; pos < 13+23; pos++ {
338		data[pos] = 0
339	}
340
341	// User [null terminated string]
342	if len(mc.cfg.User) > 0 {
343		pos += copy(data[pos:], mc.cfg.User)
344	}
345	data[pos] = 0x00
346	pos++
347
348	// Auth Data [length encoded integer]
349	pos += copy(data[pos:], authRespLEI)
350	pos += copy(data[pos:], authResp)
351
352	// Databasename [null terminated string]
353	if len(mc.cfg.DBName) > 0 {
354		pos += copy(data[pos:], mc.cfg.DBName)
355		data[pos] = 0x00
356		pos++
357	}
358
359	pos += copy(data[pos:], plugin)
360	data[pos] = 0x00
361	pos++
362
363	// Send Auth packet
364	return mc.writePacket(data[:pos])
365}
366
367// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse
368func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte) error {
369	pktLen := 4 + len(authData)
370	data := mc.buf.takeSmallBuffer(pktLen)
371	if data == nil {
372		// cannot take the buffer. Something must be wrong with the connection
373		errLog.Print(ErrBusyBuffer)
374		return errBadConnNoWrite
375	}
376
377	// Add the auth data [EOF]
378	copy(data[4:], authData)
379	return mc.writePacket(data)
380}
381
382/******************************************************************************
383*                             Command Packets                                 *
384******************************************************************************/
385
386func (mc *mysqlConn) writeCommandPacket(command byte) error {
387	// Reset Packet Sequence
388	mc.sequence = 0
389
390	data := mc.buf.takeSmallBuffer(4 + 1)
391	if data == nil {
392		// cannot take the buffer. Something must be wrong with the connection
393		errLog.Print(ErrBusyBuffer)
394		return errBadConnNoWrite
395	}
396
397	// Add command byte
398	data[4] = command
399
400	// Send CMD packet
401	return mc.writePacket(data)
402}
403
404func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error {
405	// Reset Packet Sequence
406	mc.sequence = 0
407
408	pktLen := 1 + len(arg)
409	data := mc.buf.takeBuffer(pktLen + 4)
410	if data == nil {
411		// cannot take the buffer. Something must be wrong with the connection
412		errLog.Print(ErrBusyBuffer)
413		return errBadConnNoWrite
414	}
415
416	// Add command byte
417	data[4] = command
418
419	// Add arg
420	copy(data[5:], arg)
421
422	// Send CMD packet
423	return mc.writePacket(data)
424}
425
426func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error {
427	// Reset Packet Sequence
428	mc.sequence = 0
429
430	data := mc.buf.takeSmallBuffer(4 + 1 + 4)
431	if data == nil {
432		// cannot take the buffer. Something must be wrong with the connection
433		errLog.Print(ErrBusyBuffer)
434		return errBadConnNoWrite
435	}
436
437	// Add command byte
438	data[4] = command
439
440	// Add arg [32 bit]
441	data[5] = byte(arg)
442	data[6] = byte(arg >> 8)
443	data[7] = byte(arg >> 16)
444	data[8] = byte(arg >> 24)
445
446	// Send CMD packet
447	return mc.writePacket(data)
448}
449
450/******************************************************************************
451*                              Result Packets                                 *
452******************************************************************************/
453
454func (mc *mysqlConn) readAuthResult() ([]byte, string, error) {
455	data, err := mc.readPacket()
456	if err != nil {
457		return nil, "", err
458	}
459
460	// packet indicator
461	switch data[0] {
462
463	case iOK:
464		return nil, "", mc.handleOkPacket(data)
465
466	case iAuthMoreData:
467		return data[1:], "", err
468
469	case iEOF:
470		if len(data) == 1 {
471			// https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::OldAuthSwitchRequest
472			return nil, "mysql_old_password", nil
473		}
474		pluginEndIndex := bytes.IndexByte(data, 0x00)
475		if pluginEndIndex < 0 {
476			return nil, "", ErrMalformPkt
477		}
478		plugin := string(data[1:pluginEndIndex])
479		authData := data[pluginEndIndex+1:]
480		return authData, plugin, nil
481
482	default: // Error otherwise
483		return nil, "", mc.handleErrorPacket(data)
484	}
485}
486
487// Returns error if Packet is not an 'Result OK'-Packet
488func (mc *mysqlConn) readResultOK() error {
489	data, err := mc.readPacket()
490	if err != nil {
491		return err
492	}
493
494	if data[0] == iOK {
495		return mc.handleOkPacket(data)
496	}
497	return mc.handleErrorPacket(data)
498}
499
500// Result Set Header Packet
501// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::Resultset
502func (mc *mysqlConn) readResultSetHeaderPacket() (int, error) {
503	data, err := mc.readPacket()
504	if err == nil {
505		switch data[0] {
506
507		case iOK:
508			return 0, mc.handleOkPacket(data)
509
510		case iERR:
511			return 0, mc.handleErrorPacket(data)
512
513		case iLocalInFile:
514			return 0, mc.handleInFileRequest(string(data[1:]))
515		}
516
517		// column count
518		num, _, n := readLengthEncodedInteger(data)
519		if n-len(data) == 0 {
520			return int(num), nil
521		}
522
523		return 0, ErrMalformPkt
524	}
525	return 0, err
526}
527
528// Error Packet
529// http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-ERR_Packet
530func (mc *mysqlConn) handleErrorPacket(data []byte) error {
531	if data[0] != iERR {
532		return ErrMalformPkt
533	}
534
535	// 0xff [1 byte]
536
537	// Error Number [16 bit uint]
538	errno := binary.LittleEndian.Uint16(data[1:3])
539
540	// 1792: ER_CANT_EXECUTE_IN_READ_ONLY_TRANSACTION
541	// 1290: ER_OPTION_PREVENTS_STATEMENT (returned by Aurora during failover)
542	if (errno == 1792 || errno == 1290) && mc.cfg.RejectReadOnly {
543		// Oops; we are connected to a read-only connection, and won't be able
544		// to issue any write statements. Since RejectReadOnly is configured,
545		// we throw away this connection hoping this one would have write
546		// permission. This is specifically for a possible race condition
547		// during failover (e.g. on AWS Aurora). See README.md for more.
548		//
549		// We explicitly close the connection before returning
550		// driver.ErrBadConn to ensure that `database/sql` purges this
551		// connection and initiates a new one for next statement next time.
552		mc.Close()
553		return driver.ErrBadConn
554	}
555
556	pos := 3
557
558	// SQL State [optional: # + 5bytes string]
559	if data[3] == 0x23 {
560		//sqlstate := string(data[4 : 4+5])
561		pos = 9
562	}
563
564	// Error Message [string]
565	return &MySQLError{
566		Number:  errno,
567		Message: string(data[pos:]),
568	}
569}
570
571func readStatus(b []byte) statusFlag {
572	return statusFlag(b[0]) | statusFlag(b[1])<<8
573}
574
575// Ok Packet
576// http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-OK_Packet
577func (mc *mysqlConn) handleOkPacket(data []byte) error {
578	var n, m int
579
580	// 0x00 [1 byte]
581
582	// Affected rows [Length Coded Binary]
583	mc.affectedRows, _, n = readLengthEncodedInteger(data[1:])
584
585	// Insert id [Length Coded Binary]
586	mc.insertId, _, m = readLengthEncodedInteger(data[1+n:])
587
588	// server_status [2 bytes]
589	mc.status = readStatus(data[1+n+m : 1+n+m+2])
590	if mc.status&statusMoreResultsExists != 0 {
591		return nil
592	}
593
594	// warning count [2 bytes]
595
596	return nil
597}
598
599// Read Packets as Field Packets until EOF-Packet or an Error appears
600// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-Protocol::ColumnDefinition41
601func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) {
602	columns := make([]mysqlField, count)
603
604	for i := 0; ; i++ {
605		data, err := mc.readPacket()
606		if err != nil {
607			return nil, err
608		}
609
610		// EOF Packet
611		if data[0] == iEOF && (len(data) == 5 || len(data) == 1) {
612			if i == count {
613				return columns, nil
614			}
615			return nil, fmt.Errorf("column count mismatch n:%d len:%d", count, len(columns))
616		}
617
618		// Catalog
619		pos, err := skipLengthEncodedString(data)
620		if err != nil {
621			return nil, err
622		}
623
624		// Database [len coded string]
625		n, err := skipLengthEncodedString(data[pos:])
626		if err != nil {
627			return nil, err
628		}
629		pos += n
630
631		// Table [len coded string]
632		if mc.cfg.ColumnsWithAlias {
633			tableName, _, n, err := readLengthEncodedString(data[pos:])
634			if err != nil {
635				return nil, err
636			}
637			pos += n
638			columns[i].tableName = string(tableName)
639		} else {
640			n, err = skipLengthEncodedString(data[pos:])
641			if err != nil {
642				return nil, err
643			}
644			pos += n
645		}
646
647		// Original table [len coded string]
648		n, err = skipLengthEncodedString(data[pos:])
649		if err != nil {
650			return nil, err
651		}
652		pos += n
653
654		// Name [len coded string]
655		name, _, n, err := readLengthEncodedString(data[pos:])
656		if err != nil {
657			return nil, err
658		}
659		columns[i].name = string(name)
660		pos += n
661
662		// Original name [len coded string]
663		n, err = skipLengthEncodedString(data[pos:])
664		if err != nil {
665			return nil, err
666		}
667		pos += n
668
669		// Filler [uint8]
670		pos++
671
672		// Charset [charset, collation uint8]
673		columns[i].charSet = data[pos]
674		pos += 2
675
676		// Length [uint32]
677		columns[i].length = binary.LittleEndian.Uint32(data[pos : pos+4])
678		pos += 4
679
680		// Field type [uint8]
681		columns[i].fieldType = fieldType(data[pos])
682		pos++
683
684		// Flags [uint16]
685		columns[i].flags = fieldFlag(binary.LittleEndian.Uint16(data[pos : pos+2]))
686		pos += 2
687
688		// Decimals [uint8]
689		columns[i].decimals = data[pos]
690		//pos++
691
692		// Default value [len coded binary]
693		//if pos < len(data) {
694		//	defaultVal, _, err = bytesToLengthCodedBinary(data[pos:])
695		//}
696	}
697}
698
699// Read Packets as Field Packets until EOF-Packet or an Error appears
700// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::ResultsetRow
701func (rows *textRows) readRow(dest []driver.Value) error {
702	mc := rows.mc
703
704	if rows.rs.done {
705		return io.EOF
706	}
707
708	data, err := mc.readPacket()
709	if err != nil {
710		return err
711	}
712
713	// EOF Packet
714	if data[0] == iEOF && len(data) == 5 {
715		// server_status [2 bytes]
716		rows.mc.status = readStatus(data[3:])
717		rows.rs.done = true
718		if !rows.HasNextResultSet() {
719			rows.mc = nil
720		}
721		return io.EOF
722	}
723	if data[0] == iERR {
724		rows.mc = nil
725		return mc.handleErrorPacket(data)
726	}
727
728	// RowSet Packet
729	var n int
730	var isNull bool
731	pos := 0
732
733	for i := range dest {
734		// Read bytes and convert to string
735		dest[i], isNull, n, err = readLengthEncodedString(data[pos:])
736		pos += n
737		if err == nil {
738			if !isNull {
739				if !mc.parseTime {
740					continue
741				} else {
742					switch rows.rs.columns[i].fieldType {
743					case fieldTypeTimestamp, fieldTypeDateTime,
744						fieldTypeDate, fieldTypeNewDate:
745						dest[i], err = parseDateTime(
746							string(dest[i].([]byte)),
747							mc.cfg.Loc,
748						)
749						if err == nil {
750							continue
751						}
752					default:
753						continue
754					}
755				}
756
757			} else {
758				dest[i] = nil
759				continue
760			}
761		}
762		return err // err != nil
763	}
764
765	return nil
766}
767
768// Reads Packets until EOF-Packet or an Error appears. Returns count of Packets read
769func (mc *mysqlConn) readUntilEOF() error {
770	for {
771		data, err := mc.readPacket()
772		if err != nil {
773			return err
774		}
775
776		switch data[0] {
777		case iERR:
778			return mc.handleErrorPacket(data)
779		case iEOF:
780			if len(data) == 5 {
781				mc.status = readStatus(data[3:])
782			}
783			return nil
784		}
785	}
786}
787
788/******************************************************************************
789*                           Prepared Statements                               *
790******************************************************************************/
791
792// Prepare Result Packets
793// http://dev.mysql.com/doc/internals/en/com-stmt-prepare-response.html
794func (stmt *mysqlStmt) readPrepareResultPacket() (uint16, error) {
795	data, err := stmt.mc.readPacket()
796	if err == nil {
797		// packet indicator [1 byte]
798		if data[0] != iOK {
799			return 0, stmt.mc.handleErrorPacket(data)
800		}
801
802		// statement id [4 bytes]
803		stmt.id = binary.LittleEndian.Uint32(data[1:5])
804
805		// Column count [16 bit uint]
806		columnCount := binary.LittleEndian.Uint16(data[5:7])
807
808		// Param count [16 bit uint]
809		stmt.paramCount = int(binary.LittleEndian.Uint16(data[7:9]))
810
811		// Reserved [8 bit]
812
813		// Warning count [16 bit uint]
814
815		return columnCount, nil
816	}
817	return 0, err
818}
819
820// http://dev.mysql.com/doc/internals/en/com-stmt-send-long-data.html
821func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error {
822	maxLen := stmt.mc.maxAllowedPacket - 1
823	pktLen := maxLen
824
825	// After the header (bytes 0-3) follows before the data:
826	// 1 byte command
827	// 4 bytes stmtID
828	// 2 bytes paramID
829	const dataOffset = 1 + 4 + 2
830
831	// Cannot use the write buffer since
832	// a) the buffer is too small
833	// b) it is in use
834	data := make([]byte, 4+1+4+2+len(arg))
835
836	copy(data[4+dataOffset:], arg)
837
838	for argLen := len(arg); argLen > 0; argLen -= pktLen - dataOffset {
839		if dataOffset+argLen < maxLen {
840			pktLen = dataOffset + argLen
841		}
842
843		stmt.mc.sequence = 0
844		// Add command byte [1 byte]
845		data[4] = comStmtSendLongData
846
847		// Add stmtID [32 bit]
848		data[5] = byte(stmt.id)
849		data[6] = byte(stmt.id >> 8)
850		data[7] = byte(stmt.id >> 16)
851		data[8] = byte(stmt.id >> 24)
852
853		// Add paramID [16 bit]
854		data[9] = byte(paramID)
855		data[10] = byte(paramID >> 8)
856
857		// Send CMD packet
858		err := stmt.mc.writePacket(data[:4+pktLen])
859		if err == nil {
860			data = data[pktLen-dataOffset:]
861			continue
862		}
863		return err
864
865	}
866
867	// Reset Packet Sequence
868	stmt.mc.sequence = 0
869	return nil
870}
871
// writeExecutePacket sends COM_STMT_EXECUTE for stmt with args bound to its
// placeholders (Execute Prepared Statement).
// http://dev.mysql.com/doc/internals/en/com-stmt-execute.html
//
// len(args) must equal stmt.paramCount. Supported argument types are nil,
// int64, float64, bool, []byte (nil treated as NULL), string and time.Time.
// Any other type yields a "cannot convert type" error. []byte and string
// values of at least longDataSize bytes are sent beforehand with
// writeCommandLongData; only their type is included in the execute packet.
func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
	if len(args) != stmt.paramCount {
		return fmt.Errorf(
			"argument count mismatch (got: %d; has: %d)",
			len(args),
			stmt.paramCount,
		)
	}

	// header [4] + command [1] + statement_id [4] + flags [1] + iteration_count [4]
	const minPktLen = 4 + 1 + 4 + 1 + 4
	mc := stmt.mc

	// Determine threshold dynamically to avoid packet size shortage.
	// (maxAllowedPacket is split evenly across the parameters plus the
	// fixed part, with a floor of 64 bytes.)
	longDataSize := mc.maxAllowedPacket / (stmt.paramCount + 1)
	if longDataSize < 64 {
		longDataSize = 64
	}

	// Reset packet-sequence
	mc.sequence = 0

	var data []byte

	if len(args) == 0 {
		data = mc.buf.takeBuffer(minPktLen)
	} else {
		data = mc.buf.takeCompleteBuffer()
	}
	if data == nil {
		// cannot take the buffer. Something must be wrong with the connection
		errLog.Print(ErrBusyBuffer)
		return errBadConnNoWrite
	}

	// command [1 byte]
	data[4] = comStmtExecute

	// statement_id [4 bytes]
	data[5] = byte(stmt.id)
	data[6] = byte(stmt.id >> 8)
	data[7] = byte(stmt.id >> 16)
	data[8] = byte(stmt.id >> 24)

	// flags (0: CURSOR_TYPE_NO_CURSOR) [1 byte]
	data[9] = 0x00

	// iteration_count (uint32(1)) [4 bytes]
	data[10] = 0x01
	data[11] = 0x00
	data[12] = 0x00
	data[13] = 0x00

	if len(args) > 0 {
		pos := minPktLen

		// NULL bitmap [(len(args)+7)/8 bytes], one bit per parameter
		var nullMask []byte
		if maskLen, typesLen := (len(args)+7)/8, 1+2*len(args); pos+maskLen+typesLen >= len(data) {
			// buffer has to be extended but we don't know by how much so
			// we depend on append after all data with known sizes fit.
			// We stop at that because we deal with a lot of columns here
			// which makes the required allocation size hard to guess.
			// (make returns zeroed memory, so the mask needs no clearing.)
			tmp := make([]byte, pos+maskLen+typesLen)
			copy(tmp[:pos], data[:pos])
			data = tmp
			nullMask = data[pos : pos+maskLen]
			pos += maskLen
		} else {
			// The taken buffer is not cleared here, so zero the mask explicitly.
			nullMask = data[pos : pos+maskLen]
			for i := 0; i < maskLen; i++ {
				nullMask[i] = 0
			}
			pos += maskLen
		}

		// newParameterBoundFlag 1 [1 byte]
		data[pos] = 0x01
		pos++

		// type of each parameter [len(args)*2 bytes]
		paramTypes := data[pos:]
		pos += len(args) * 2

		// value of each parameter [n bytes]
		// paramValues starts empty but shares data's spare capacity; values
		// are written in place until that capacity runs out, then append
		// reallocates (detected via valuesCap below).
		paramValues := data[pos:pos]
		valuesCap := cap(paramValues)

		for i, arg := range args {
			// build NULL-bitmap
			if arg == nil {
				nullMask[i/8] |= 1 << (uint(i) & 7)
				paramTypes[i+i] = byte(fieldTypeNULL)
				paramTypes[i+i+1] = 0x00
				continue
			}

			// cache types and values
			switch v := arg.(type) {
			case int64:
				paramTypes[i+i] = byte(fieldTypeLongLong)
				paramTypes[i+i+1] = 0x00

				if cap(paramValues)-len(paramValues)-8 >= 0 {
					paramValues = paramValues[:len(paramValues)+8]
					binary.LittleEndian.PutUint64(
						paramValues[len(paramValues)-8:],
						uint64(v),
					)
				} else {
					paramValues = append(paramValues,
						uint64ToBytes(uint64(v))...,
					)
				}

			case float64:
				paramTypes[i+i] = byte(fieldTypeDouble)
				paramTypes[i+i+1] = 0x00

				if cap(paramValues)-len(paramValues)-8 >= 0 {
					paramValues = paramValues[:len(paramValues)+8]
					binary.LittleEndian.PutUint64(
						paramValues[len(paramValues)-8:],
						math.Float64bits(v),
					)
				} else {
					paramValues = append(paramValues,
						uint64ToBytes(math.Float64bits(v))...,
					)
				}

			case bool:
				paramTypes[i+i] = byte(fieldTypeTiny)
				paramTypes[i+i+1] = 0x00

				if v {
					paramValues = append(paramValues, 0x01)
				} else {
					paramValues = append(paramValues, 0x00)
				}

			case []byte:
				// Common case (non-nil value) first
				if v != nil {
					paramTypes[i+i] = byte(fieldTypeString)
					paramTypes[i+i+1] = 0x00

					if len(v) < longDataSize {
						paramValues = appendLengthEncodedInteger(paramValues,
							uint64(len(v)),
						)
						paramValues = append(paramValues, v...)
					} else {
						// sent separately; no value bytes in this packet
						if err := stmt.writeCommandLongData(i, v); err != nil {
							return err
						}
					}
					continue
				}

				// Handle []byte(nil) as a NULL value
				nullMask[i/8] |= 1 << (uint(i) & 7)
				paramTypes[i+i] = byte(fieldTypeNULL)
				paramTypes[i+i+1] = 0x00

			case string:
				paramTypes[i+i] = byte(fieldTypeString)
				paramTypes[i+i+1] = 0x00

				if len(v) < longDataSize {
					paramValues = appendLengthEncodedInteger(paramValues,
						uint64(len(v)),
					)
					paramValues = append(paramValues, v...)
				} else {
					// sent separately; no value bytes in this packet
					if err := stmt.writeCommandLongData(i, []byte(v)); err != nil {
						return err
					}
				}

			case time.Time:
				// sent as a string: "0000-00-00" for the zero time, otherwise
				// formatted with timeFormat in mc.cfg.Loc
				paramTypes[i+i] = byte(fieldTypeString)
				paramTypes[i+i+1] = 0x00

				var a [64]byte
				var b = a[:0]

				if v.IsZero() {
					b = append(b, "0000-00-00"...)
				} else {
					b = v.In(mc.cfg.Loc).AppendFormat(b, timeFormat)
				}

				paramValues = appendLengthEncodedInteger(paramValues,
					uint64(len(b)),
				)
				paramValues = append(paramValues, b...)

			default:
				return fmt.Errorf("cannot convert type: %T", arg)
			}
		}

		// Check if param values exceeded the available buffer
		// In that case we must build the data packet with the new values buffer
		if valuesCap != cap(paramValues) {
			data = append(data[:pos], paramValues...)
			// NOTE(review): this replaces the connection buffer's backing
			// slice directly; confirm this keeps the buffer type's invariants.
			mc.buf.buf = data
		}

		pos += len(paramValues)
		data = data[:pos]
	}

	return mc.writePacket(data)
}
1088
1089func (mc *mysqlConn) discardResults() error {
1090	for mc.status&statusMoreResultsExists != 0 {
1091		resLen, err := mc.readResultSetHeaderPacket()
1092		if err != nil {
1093			return err
1094		}
1095		if resLen > 0 {
1096			// columns
1097			if err := mc.readUntilEOF(); err != nil {
1098				return err
1099			}
1100			// rows
1101			if err := mc.readUntilEOF(); err != nil {
1102				return err
1103			}
1104		}
1105	}
1106	return nil
1107}
1108
// readRow reads the next row of a binary-protocol result set into dest,
// one value per column.
// http://dev.mysql.com/doc/internals/en/binary-protocol-resultset-row.html
//
// NULL columns become nil. Integer types become int64; unsigned BIGINT
// values above math.MaxInt64 are converted with uint64ToString instead.
// FLOAT becomes float32 and DOUBLE float64. String-like and decimal types
// are read with readLengthEncodedString. TIME values are formatted with
// formatBinaryTime. DATE/DATETIME/TIMESTAMP values are parsed with
// parseBinaryDateTime when parseTime is set, otherwise formatted with
// formatBinaryDateTime. An EOF packet marks the result set done and returns
// io.EOF. Any other non-OK packet detaches the connection and is converted
// by handleErrorPacket.
func (rows *binaryRows) readRow(dest []driver.Value) error {
	data, err := rows.mc.readPacket()
	if err != nil {
		return err
	}

	// packet indicator [1 byte]
	if data[0] != iOK {
		// EOF Packet
		if data[0] == iEOF && len(data) == 5 {
			// server_status [2 bytes] at offset 3
			rows.mc.status = readStatus(data[3:])
			rows.rs.done = true
			if !rows.HasNextResultSet() {
				rows.mc = nil
			}
			return io.EOF
		}
		mc := rows.mc
		rows.mc = nil

		// Error otherwise
		return mc.handleErrorPacket(data)
	}

	// NULL-bitmap,  [(column-count + 7 + 2) / 8 bytes]
	// The bitmap's first two bits are not used for columns, hence the
	// +2 offset here and in the bit lookup below.
	pos := 1 + (len(dest)+7+2)>>3
	nullMask := data[1:pos]

	for i := range dest {
		// Field is NULL
		// (byte >> bit-pos) % 2 == 1
		if ((nullMask[(i+2)>>3] >> uint((i+2)&7)) & 1) == 1 {
			dest[i] = nil
			continue
		}

		// Convert to byte-coded string
		switch rows.rs.columns[i].fieldType {
		case fieldTypeNULL:
			dest[i] = nil
			continue

		// Numeric Types
		case fieldTypeTiny:
			if rows.rs.columns[i].flags&flagUnsigned != 0 {
				dest[i] = int64(data[pos])
			} else {
				dest[i] = int64(int8(data[pos]))
			}
			pos++
			continue

		case fieldTypeShort, fieldTypeYear:
			if rows.rs.columns[i].flags&flagUnsigned != 0 {
				dest[i] = int64(binary.LittleEndian.Uint16(data[pos : pos+2]))
			} else {
				dest[i] = int64(int16(binary.LittleEndian.Uint16(data[pos : pos+2])))
			}
			pos += 2
			continue

		case fieldTypeInt24, fieldTypeLong:
			if rows.rs.columns[i].flags&flagUnsigned != 0 {
				dest[i] = int64(binary.LittleEndian.Uint32(data[pos : pos+4]))
			} else {
				dest[i] = int64(int32(binary.LittleEndian.Uint32(data[pos : pos+4])))
			}
			pos += 4
			continue

		case fieldTypeLongLong:
			if rows.rs.columns[i].flags&flagUnsigned != 0 {
				val := binary.LittleEndian.Uint64(data[pos : pos+8])
				if val > math.MaxInt64 {
					// does not fit into int64
					dest[i] = uint64ToString(val)
				} else {
					dest[i] = int64(val)
				}
			} else {
				dest[i] = int64(binary.LittleEndian.Uint64(data[pos : pos+8]))
			}
			pos += 8
			continue

		case fieldTypeFloat:
			dest[i] = math.Float32frombits(binary.LittleEndian.Uint32(data[pos : pos+4]))
			pos += 4
			continue

		case fieldTypeDouble:
			dest[i] = math.Float64frombits(binary.LittleEndian.Uint64(data[pos : pos+8]))
			pos += 8
			continue

		// Length coded Binary Strings
		case fieldTypeDecimal, fieldTypeNewDecimal, fieldTypeVarChar,
			fieldTypeBit, fieldTypeEnum, fieldTypeSet, fieldTypeTinyBLOB,
			fieldTypeMediumBLOB, fieldTypeLongBLOB, fieldTypeBLOB,
			fieldTypeVarString, fieldTypeString, fieldTypeGeometry, fieldTypeJSON:
			var isNull bool
			var n int
			dest[i], isNull, n, err = readLengthEncodedString(data[pos:])
			pos += n
			if err == nil {
				if !isNull {
					continue
				} else {
					dest[i] = nil
					continue
				}
			}
			return err

		case
			fieldTypeDate, fieldTypeNewDate, // Date YYYY-MM-DD
			fieldTypeTime,                         // Time [-][H]HH:MM:SS[.fractal]
			fieldTypeTimestamp, fieldTypeDateTime: // Timestamp YYYY-MM-DD HH:MM:SS[.fractal]

			// num is the byte length of the encoded value that follows
			num, isNull, n := readLengthEncodedInteger(data[pos:])
			pos += n

			switch {
			case isNull:
				dest[i] = nil
				continue
			case rows.rs.columns[i].fieldType == fieldTypeTime:
				// database/sql does not support an equivalent to TIME, return a string
				// dstlen: "HH:MM:SS" (8) plus '.' and up to 6 fraction digits
				var dstlen uint8
				switch decimals := rows.rs.columns[i].decimals; decimals {
				case 0x00, 0x1f:
					dstlen = 8
				case 1, 2, 3, 4, 5, 6:
					dstlen = 8 + 1 + decimals
				default:
					return fmt.Errorf(
						"protocol error, illegal decimals value %d",
						rows.rs.columns[i].decimals,
					)
				}
				dest[i], err = formatBinaryTime(data[pos:pos+int(num)], dstlen)
			case rows.mc.parseTime:
				dest[i], err = parseBinaryDateTime(num, data[pos:], rows.mc.cfg.Loc)
			default:
				// dstlen: "YYYY-MM-DD" (10) for DATE, else
				// "YYYY-MM-DD HH:MM:SS" (19) plus optional fraction
				// NOTE(review): fieldTypeNewDate takes the datetime branch
				// here, not the 10-byte date branch; confirm this is intended.
				var dstlen uint8
				if rows.rs.columns[i].fieldType == fieldTypeDate {
					dstlen = 10
				} else {
					switch decimals := rows.rs.columns[i].decimals; decimals {
					case 0x00, 0x1f:
						dstlen = 19
					case 1, 2, 3, 4, 5, 6:
						dstlen = 19 + 1 + decimals
					default:
						return fmt.Errorf(
							"protocol error, illegal decimals value %d",
							rows.rs.columns[i].decimals,
						)
					}
				}
				dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], dstlen)
			}

			if err == nil {
				pos += int(num)
				continue
			} else {
				return err
			}

		// Please report if this happens!
		default:
			return fmt.Errorf("unknown field type %d", rows.rs.columns[i].fieldType)
		}
	}

	return nil
}
1287