1// Go MySQL Driver - A MySQL-Driver for Go's database/sql package 2// 3// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. 4// 5// This Source Code Form is subject to the terms of the Mozilla Public 6// License, v. 2.0. If a copy of the MPL was not distributed with this file, 7// You can obtain one at http://mozilla.org/MPL/2.0/. 8 9package mysql 10 11import ( 12 "bytes" 13 "crypto/tls" 14 "database/sql/driver" 15 "encoding/binary" 16 "errors" 17 "fmt" 18 "io" 19 "math" 20 "time" 21) 22 23// Packets documentation: 24// http://dev.mysql.com/doc/internals/en/client-server-protocol.html 25 26// Read packet to buffer 'data' 27func (mc *mysqlConn) readPacket() ([]byte, error) { 28 var prevData []byte 29 for { 30 // read packet header 31 data, err := mc.buf.readNext(4) 32 if err != nil { 33 if cerr := mc.canceled.Value(); cerr != nil { 34 return nil, cerr 35 } 36 errLog.Print(err) 37 mc.Close() 38 return nil, ErrInvalidConn 39 } 40 41 // packet length [24 bit] 42 pktLen := int(uint32(data[0]) | uint32(data[1])<<8 | uint32(data[2])<<16) 43 44 // check packet sync [8 bit] 45 if data[3] != mc.sequence { 46 if data[3] > mc.sequence { 47 return nil, ErrPktSyncMul 48 } 49 return nil, ErrPktSync 50 } 51 mc.sequence++ 52 53 // packets with length 0 terminate a previous packet which is a 54 // multiple of (2^24)−1 bytes long 55 if pktLen == 0 { 56 // there was no previous packet 57 if prevData == nil { 58 errLog.Print(ErrMalformPkt) 59 mc.Close() 60 return nil, ErrInvalidConn 61 } 62 63 return prevData, nil 64 } 65 66 // read packet body [pktLen bytes] 67 data, err = mc.buf.readNext(pktLen) 68 if err != nil { 69 if cerr := mc.canceled.Value(); cerr != nil { 70 return nil, cerr 71 } 72 errLog.Print(err) 73 mc.Close() 74 return nil, ErrInvalidConn 75 } 76 77 // return data if this was the last packet 78 if pktLen < maxPacketSize { 79 // zero allocations for non-split packets 80 if prevData == nil { 81 return data, nil 82 } 83 84 return append(prevData, data...), nil 85 } 86 87 prevData = append(prevData, data...) 88 } 89} 90 91// Write packet buffer 'data' 92func (mc *mysqlConn) writePacket(data []byte) error { 93 pktLen := len(data) - 4 94 95 if pktLen > mc.maxAllowedPacket { 96 return ErrPktTooLarge 97 } 98 99 for { 100 var size int 101 if pktLen >= maxPacketSize { 102 data[0] = 0xff 103 data[1] = 0xff 104 data[2] = 0xff 105 size = maxPacketSize 106 } else { 107 data[0] = byte(pktLen) 108 data[1] = byte(pktLen >> 8) 109 data[2] = byte(pktLen >> 16) 110 size = pktLen 111 } 112 data[3] = mc.sequence 113 114 // Write packet 115 if mc.writeTimeout > 0 { 116 if err := mc.netConn.SetWriteDeadline(time.Now().Add(mc.writeTimeout)); err != nil { 117 return err 118 } 119 } 120 121 n, err := mc.netConn.Write(data[:4+size]) 122 if err == nil && n == 4+size { 123 mc.sequence++ 124 if size != maxPacketSize { 125 return nil 126 } 127 pktLen -= size 128 data = data[size:] 129 continue 130 } 131 132 // Handle error 133 if err == nil { // n != len(data) 134 mc.cleanup() 135 errLog.Print(ErrMalformPkt) 136 } else { 137 if cerr := mc.canceled.Value(); cerr != nil { 138 return cerr 139 } 140 if n == 0 && pktLen == len(data)-4 { 141 // only for the first loop iteration when nothing was written yet 142 return errBadConnNoWrite 143 } 144 mc.cleanup() 145 errLog.Print(err) 146 } 147 return ErrInvalidConn 148 } 149} 150 151/****************************************************************************** 152* Initialization Process * 153******************************************************************************/ 154 155// Handshake Initialization Packet 156// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake 157func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err error) { 158 data, err = mc.readPacket() 159 if err != nil { 160 // for init we can rewrite this to ErrBadConn for sql.Driver to retry, since 161 // in connection initialization we don't risk retrying non-idempotent actions. 162 if err == ErrInvalidConn { 163 return nil, "", driver.ErrBadConn 164 } 165 return 166 } 167 168 if data[0] == iERR { 169 return nil, "", mc.handleErrorPacket(data) 170 } 171 172 // protocol version [1 byte] 173 if data[0] < minProtocolVersion { 174 return nil, "", fmt.Errorf( 175 "unsupported protocol version %d. Version %d or higher is required", 176 data[0], 177 minProtocolVersion, 178 ) 179 } 180 181 // server version [null terminated string] 182 // connection id [4 bytes] 183 pos := 1 + bytes.IndexByte(data[1:], 0x00) + 1 + 4 184 185 // first part of the password cipher [8 bytes] 186 authData := data[pos : pos+8] 187 188 // (filler) always 0x00 [1 byte] 189 pos += 8 + 1 190 191 // capability flags (lower 2 bytes) [2 bytes] 192 mc.flags = clientFlag(binary.LittleEndian.Uint16(data[pos : pos+2])) 193 if mc.flags&clientProtocol41 == 0 { 194 return nil, "", ErrOldProtocol 195 } 196 if mc.flags&clientSSL == 0 && mc.cfg.tls != nil { 197 return nil, "", ErrNoTLS 198 } 199 pos += 2 200 201 if len(data) > pos { 202 // character set [1 byte] 203 // status flags [2 bytes] 204 // capability flags (upper 2 bytes) [2 bytes] 205 // length of auth-plugin-data [1 byte] 206 // reserved (all [00]) [10 bytes] 207 pos += 1 + 2 + 2 + 1 + 10 208 209 // second part of the password cipher [mininum 13 bytes], 210 // where len=MAX(13, length of auth-plugin-data - 8) 211 // 212 // The web documentation is ambiguous about the length. However, 213 // according to mysql-5.7/sql/auth/sql_authentication.cc line 538, 214 // the 13th byte is "\0 byte, terminating the second part of 215 // a scramble". So the second part of the password cipher is 216 // a NULL terminated string that's at least 13 bytes with the 217 // last byte being NULL. 218 // 219 // The official Python library uses the fixed length 12 220 // which seems to work but technically could have a hidden bug. 221 authData = append(authData, data[pos:pos+12]...) 222 pos += 13 223 224 // EOF if version (>= 5.5.7 and < 5.5.10) or (>= 5.6.0 and < 5.6.2) 225 // \NUL otherwise 226 if end := bytes.IndexByte(data[pos:], 0x00); end != -1 { 227 plugin = string(data[pos : pos+end]) 228 } else { 229 plugin = string(data[pos:]) 230 } 231 232 // make a memory safe copy of the cipher slice 233 var b [20]byte 234 copy(b[:], authData) 235 return b[:], plugin, nil 236 } 237 238 // make a memory safe copy of the cipher slice 239 var b [8]byte 240 copy(b[:], authData) 241 return b[:], plugin, nil 242} 243 244// Client Authentication Packet 245// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse 246func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string) error { 247 // Adjust client flags based on server support 248 clientFlags := clientProtocol41 | 249 clientSecureConn | 250 clientLongPassword | 251 clientTransactions | 252 clientLocalFiles | 253 clientPluginAuth | 254 clientMultiResults | 255 mc.flags&clientLongFlag 256 257 if mc.cfg.ClientFoundRows { 258 clientFlags |= clientFoundRows 259 } 260 261 // To enable TLS / SSL 262 if mc.cfg.tls != nil { 263 clientFlags |= clientSSL 264 } 265 266 if mc.cfg.MultiStatements { 267 clientFlags |= clientMultiStatements 268 } 269 270 // encode length of the auth plugin data 271 var authRespLEIBuf [9]byte 272 authRespLen := len(authResp) 273 authRespLEI := appendLengthEncodedInteger(authRespLEIBuf[:0], uint64(authRespLen)) 274 if len(authRespLEI) > 1 { 275 // if the length can not be written in 1 byte, it must be written as a 276 // length encoded integer 277 clientFlags |= clientPluginAuthLenEncClientData 278 } 279 280 pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + len(authRespLEI) + len(authResp) + 21 + 1 281 282 // To specify a db name 283 if n := len(mc.cfg.DBName); n > 0 { 284 clientFlags |= clientConnectWithDB 285 pktLen += n + 1 286 } 287 288 // Calculate packet length and get buffer with that size 289 data := mc.buf.takeSmallBuffer(pktLen + 4) 290 if data == nil { 291 // cannot take the buffer. Something must be wrong with the connection 292 errLog.Print(ErrBusyBuffer) 293 return errBadConnNoWrite 294 } 295 296 // ClientFlags [32 bit] 297 data[4] = byte(clientFlags) 298 data[5] = byte(clientFlags >> 8) 299 data[6] = byte(clientFlags >> 16) 300 data[7] = byte(clientFlags >> 24) 301 302 // MaxPacketSize [32 bit] (none) 303 data[8] = 0x00 304 data[9] = 0x00 305 data[10] = 0x00 306 data[11] = 0x00 307 308 // Charset [1 byte] 309 var found bool 310 data[12], found = collations[mc.cfg.Collation] 311 if !found { 312 // Note possibility for false negatives: 313 // could be triggered although the collation is valid if the 314 // collations map does not contain entries the server supports. 315 return errors.New("unknown collation") 316 } 317 318 // SSL Connection Request Packet 319 // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest 320 if mc.cfg.tls != nil { 321 // Send TLS / SSL request packet 322 if err := mc.writePacket(data[:(4+4+1+23)+4]); err != nil { 323 return err 324 } 325 326 // Switch to TLS 327 tlsConn := tls.Client(mc.netConn, mc.cfg.tls) 328 if err := tlsConn.Handshake(); err != nil { 329 return err 330 } 331 mc.netConn = tlsConn 332 mc.buf.nc = tlsConn 333 } 334 335 // Filler [23 bytes] (all 0x00) 336 pos := 13 337 for ; pos < 13+23; pos++ { 338 data[pos] = 0 339 } 340 341 // User [null terminated string] 342 if len(mc.cfg.User) > 0 { 343 pos += copy(data[pos:], mc.cfg.User) 344 } 345 data[pos] = 0x00 346 pos++ 347 348 // Auth Data [length encoded integer] 349 pos += copy(data[pos:], authRespLEI) 350 pos += copy(data[pos:], authResp) 351 352 // Databasename [null terminated string] 353 if len(mc.cfg.DBName) > 0 { 354 pos += copy(data[pos:], mc.cfg.DBName) 355 data[pos] = 0x00 356 pos++ 357 } 358 359 pos += copy(data[pos:], plugin) 360 data[pos] = 0x00 361 pos++ 362 363 // Send Auth packet 364 return mc.writePacket(data[:pos]) 365} 366 367// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse 368func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte) error { 369 pktLen := 4 + len(authData) 370 data := mc.buf.takeSmallBuffer(pktLen) 371 if data == nil { 372 // cannot take the buffer. Something must be wrong with the connection 373 errLog.Print(ErrBusyBuffer) 374 return errBadConnNoWrite 375 } 376 377 // Add the auth data [EOF] 378 copy(data[4:], authData) 379 return mc.writePacket(data) 380} 381 382/****************************************************************************** 383* Command Packets * 384******************************************************************************/ 385 386func (mc *mysqlConn) writeCommandPacket(command byte) error { 387 // Reset Packet Sequence 388 mc.sequence = 0 389 390 data := mc.buf.takeSmallBuffer(4 + 1) 391 if data == nil { 392 // cannot take the buffer. Something must be wrong with the connection 393 errLog.Print(ErrBusyBuffer) 394 return errBadConnNoWrite 395 } 396 397 // Add command byte 398 data[4] = command 399 400 // Send CMD packet 401 return mc.writePacket(data) 402} 403 404func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error { 405 // Reset Packet Sequence 406 mc.sequence = 0 407 408 pktLen := 1 + len(arg) 409 data := mc.buf.takeBuffer(pktLen + 4) 410 if data == nil { 411 // cannot take the buffer. Something must be wrong with the connection 412 errLog.Print(ErrBusyBuffer) 413 return errBadConnNoWrite 414 } 415 416 // Add command byte 417 data[4] = command 418 419 // Add arg 420 copy(data[5:], arg) 421 422 // Send CMD packet 423 return mc.writePacket(data) 424} 425 426func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error { 427 // Reset Packet Sequence 428 mc.sequence = 0 429 430 data := mc.buf.takeSmallBuffer(4 + 1 + 4) 431 if data == nil { 432 // cannot take the buffer. Something must be wrong with the connection 433 errLog.Print(ErrBusyBuffer) 434 return errBadConnNoWrite 435 } 436 437 // Add command byte 438 data[4] = command 439 440 // Add arg [32 bit] 441 data[5] = byte(arg) 442 data[6] = byte(arg >> 8) 443 data[7] = byte(arg >> 16) 444 data[8] = byte(arg >> 24) 445 446 // Send CMD packet 447 return mc.writePacket(data) 448} 449 450/****************************************************************************** 451* Result Packets * 452******************************************************************************/ 453 454func (mc *mysqlConn) readAuthResult() ([]byte, string, error) { 455 data, err := mc.readPacket() 456 if err != nil { 457 return nil, "", err 458 } 459 460 // packet indicator 461 switch data[0] { 462 463 case iOK: 464 return nil, "", mc.handleOkPacket(data) 465 466 case iAuthMoreData: 467 return data[1:], "", err 468 469 case iEOF: 470 if len(data) == 1 { 471 // https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::OldAuthSwitchRequest 472 return nil, "mysql_old_password", nil 473 } 474 pluginEndIndex := bytes.IndexByte(data, 0x00) 475 if pluginEndIndex < 0 { 476 return nil, "", ErrMalformPkt 477 } 478 plugin := string(data[1:pluginEndIndex]) 479 authData := data[pluginEndIndex+1:] 480 return authData, plugin, nil 481 482 default: // Error otherwise 483 return nil, "", mc.handleErrorPacket(data) 484 } 485} 486 487// Returns error if Packet is not an 'Result OK'-Packet 488func (mc *mysqlConn) readResultOK() error { 489 data, err := mc.readPacket() 490 if err != nil { 491 return err 492 } 493 494 if data[0] == iOK { 495 return mc.handleOkPacket(data) 496 } 497 return mc.handleErrorPacket(data) 498} 499 500// Result Set Header Packet 501// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::Resultset 502func (mc *mysqlConn) readResultSetHeaderPacket() (int, error) { 503 data, err := mc.readPacket() 504 if err == nil { 505 switch data[0] { 506 507 case iOK: 508 return 0, mc.handleOkPacket(data) 509 510 case iERR: 511 return 0, mc.handleErrorPacket(data) 512 513 case iLocalInFile: 514 return 0, mc.handleInFileRequest(string(data[1:])) 515 } 516 517 // column count 518 num, _, n := readLengthEncodedInteger(data) 519 if n-len(data) == 0 { 520 return int(num), nil 521 } 522 523 return 0, ErrMalformPkt 524 } 525 return 0, err 526} 527 528// Error Packet 529// http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-ERR_Packet 530func (mc *mysqlConn) handleErrorPacket(data []byte) error { 531 if data[0] != iERR { 532 return ErrMalformPkt 533 } 534 535 // 0xff [1 byte] 536 537 // Error Number [16 bit uint] 538 errno := binary.LittleEndian.Uint16(data[1:3]) 539 540 // 1792: ER_CANT_EXECUTE_IN_READ_ONLY_TRANSACTION 541 // 1290: ER_OPTION_PREVENTS_STATEMENT (returned by Aurora during failover) 542 if (errno == 1792 || errno == 1290) && mc.cfg.RejectReadOnly { 543 // Oops; we are connected to a read-only connection, and won't be able 544 // to issue any write statements. Since RejectReadOnly is configured, 545 // we throw away this connection hoping this one would have write 546 // permission. This is specifically for a possible race condition 547 // during failover (e.g. on AWS Aurora). See README.md for more. 548 // 549 // We explicitly close the connection before returning 550 // driver.ErrBadConn to ensure that `database/sql` purges this 551 // connection and initiates a new one for next statement next time. 552 mc.Close() 553 return driver.ErrBadConn 554 } 555 556 pos := 3 557 558 // SQL State [optional: # + 5bytes string] 559 if data[3] == 0x23 { 560 //sqlstate := string(data[4 : 4+5]) 561 pos = 9 562 } 563 564 // Error Message [string] 565 return &MySQLError{ 566 Number: errno, 567 Message: string(data[pos:]), 568 } 569} 570 571func readStatus(b []byte) statusFlag { 572 return statusFlag(b[0]) | statusFlag(b[1])<<8 573} 574 575// Ok Packet 576// http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-OK_Packet 577func (mc *mysqlConn) handleOkPacket(data []byte) error { 578 var n, m int 579 580 // 0x00 [1 byte] 581 582 // Affected rows [Length Coded Binary] 583 mc.affectedRows, _, n = readLengthEncodedInteger(data[1:]) 584 585 // Insert id [Length Coded Binary] 586 mc.insertId, _, m = readLengthEncodedInteger(data[1+n:]) 587 588 // server_status [2 bytes] 589 mc.status = readStatus(data[1+n+m : 1+n+m+2]) 590 if mc.status&statusMoreResultsExists != 0 { 591 return nil 592 } 593 594 // warning count [2 bytes] 595 596 return nil 597} 598 599// Read Packets as Field Packets until EOF-Packet or an Error appears 600// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-Protocol::ColumnDefinition41 601func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) { 602 columns := make([]mysqlField, count) 603 604 for i := 0; ; i++ { 605 data, err := mc.readPacket() 606 if err != nil { 607 return nil, err 608 } 609 610 // EOF Packet 611 if data[0] == iEOF && (len(data) == 5 || len(data) == 1) { 612 if i == count { 613 return columns, nil 614 } 615 return nil, fmt.Errorf("column count mismatch n:%d len:%d", count, len(columns)) 616 } 617 618 // Catalog 619 pos, err := skipLengthEncodedString(data) 620 if err != nil { 621 return nil, err 622 } 623 624 // Database [len coded string] 625 n, err := skipLengthEncodedString(data[pos:]) 626 if err != nil { 627 return nil, err 628 } 629 pos += n 630 631 // Table [len coded string] 632 if mc.cfg.ColumnsWithAlias { 633 tableName, _, n, err := readLengthEncodedString(data[pos:]) 634 if err != nil { 635 return nil, err 636 } 637 pos += n 638 columns[i].tableName = string(tableName) 639 } else { 640 n, err = skipLengthEncodedString(data[pos:]) 641 if err != nil { 642 return nil, err 643 } 644 pos += n 645 } 646 647 // Original table [len coded string] 648 n, err = skipLengthEncodedString(data[pos:]) 649 if err != nil { 650 return nil, err 651 } 652 pos += n 653 654 // Name [len coded string] 655 name, _, n, err := readLengthEncodedString(data[pos:]) 656 if err != nil { 657 return nil, err 658 } 659 columns[i].name = string(name) 660 pos += n 661 662 // Original name [len coded string] 663 n, err = skipLengthEncodedString(data[pos:]) 664 if err != nil { 665 return nil, err 666 } 667 pos += n 668 669 // Filler [uint8] 670 pos++ 671 672 // Charset [charset, collation uint8] 673 columns[i].charSet = data[pos] 674 pos += 2 675 676 // Length [uint32] 677 columns[i].length = binary.LittleEndian.Uint32(data[pos : pos+4]) 678 pos += 4 679 680 // Field type [uint8] 681 columns[i].fieldType = fieldType(data[pos]) 682 pos++ 683 684 // Flags [uint16] 685 columns[i].flags = fieldFlag(binary.LittleEndian.Uint16(data[pos : pos+2])) 686 pos += 2 687 688 // Decimals [uint8] 689 columns[i].decimals = data[pos] 690 //pos++ 691 692 // Default value [len coded binary] 693 //if pos < len(data) { 694 // defaultVal, _, err = bytesToLengthCodedBinary(data[pos:]) 695 //} 696 } 697} 698 699// Read Packets as Field Packets until EOF-Packet or an Error appears 700// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::ResultsetRow 701func (rows *textRows) readRow(dest []driver.Value) error { 702 mc := rows.mc 703 704 if rows.rs.done { 705 return io.EOF 706 } 707 708 data, err := mc.readPacket() 709 if err != nil { 710 return err 711 } 712 713 // EOF Packet 714 if data[0] == iEOF && len(data) == 5 { 715 // server_status [2 bytes] 716 rows.mc.status = readStatus(data[3:]) 717 rows.rs.done = true 718 if !rows.HasNextResultSet() { 719 rows.mc = nil 720 } 721 return io.EOF 722 } 723 if data[0] == iERR { 724 rows.mc = nil 725 return mc.handleErrorPacket(data) 726 } 727 728 // RowSet Packet 729 var n int 730 var isNull bool 731 pos := 0 732 733 for i := range dest { 734 // Read bytes and convert to string 735 dest[i], isNull, n, err = readLengthEncodedString(data[pos:]) 736 pos += n 737 if err == nil { 738 if !isNull { 739 if !mc.parseTime { 740 continue 741 } else { 742 switch rows.rs.columns[i].fieldType { 743 case fieldTypeTimestamp, fieldTypeDateTime, 744 fieldTypeDate, fieldTypeNewDate: 745 dest[i], err = parseDateTime( 746 string(dest[i].([]byte)), 747 mc.cfg.Loc, 748 ) 749 if err == nil { 750 continue 751 } 752 default: 753 continue 754 } 755 } 756 757 } else { 758 dest[i] = nil 759 continue 760 } 761 } 762 return err // err != nil 763 } 764 765 return nil 766} 767 768// Reads Packets until EOF-Packet or an Error appears. Returns count of Packets read 769func (mc *mysqlConn) readUntilEOF() error { 770 for { 771 data, err := mc.readPacket() 772 if err != nil { 773 return err 774 } 775 776 switch data[0] { 777 case iERR: 778 return mc.handleErrorPacket(data) 779 case iEOF: 780 if len(data) == 5 { 781 mc.status = readStatus(data[3:]) 782 } 783 return nil 784 } 785 } 786} 787 788/****************************************************************************** 789* Prepared Statements * 790******************************************************************************/ 791 792// Prepare Result Packets 793// http://dev.mysql.com/doc/internals/en/com-stmt-prepare-response.html 794func (stmt *mysqlStmt) readPrepareResultPacket() (uint16, error) { 795 data, err := stmt.mc.readPacket() 796 if err == nil { 797 // packet indicator [1 byte] 798 if data[0] != iOK { 799 return 0, stmt.mc.handleErrorPacket(data) 800 } 801 802 // statement id [4 bytes] 803 stmt.id = binary.LittleEndian.Uint32(data[1:5]) 804 805 // Column count [16 bit uint] 806 columnCount := binary.LittleEndian.Uint16(data[5:7]) 807 808 // Param count [16 bit uint] 809 stmt.paramCount = int(binary.LittleEndian.Uint16(data[7:9])) 810 811 // Reserved [8 bit] 812 813 // Warning count [16 bit uint] 814 815 return columnCount, nil 816 } 817 return 0, err 818} 819 820// http://dev.mysql.com/doc/internals/en/com-stmt-send-long-data.html 821func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error { 822 maxLen := stmt.mc.maxAllowedPacket - 1 823 pktLen := maxLen 824 825 // After the header (bytes 0-3) follows before the data: 826 // 1 byte command 827 // 4 bytes stmtID 828 // 2 bytes paramID 829 const dataOffset = 1 + 4 + 2 830 831 // Cannot use the write buffer since 832 // a) the buffer is too small 833 // b) it is in use 834 data := make([]byte, 4+1+4+2+len(arg)) 835 836 copy(data[4+dataOffset:], arg) 837 838 for argLen := len(arg); argLen > 0; argLen -= pktLen - dataOffset { 839 if dataOffset+argLen < maxLen { 840 pktLen = dataOffset + argLen 841 } 842 843 stmt.mc.sequence = 0 844 // Add command byte [1 byte] 845 data[4] = comStmtSendLongData 846 847 // Add stmtID [32 bit] 848 data[5] = byte(stmt.id) 849 data[6] = byte(stmt.id >> 8) 850 data[7] = byte(stmt.id >> 16) 851 data[8] = byte(stmt.id >> 24) 852 853 // Add paramID [16 bit] 854 data[9] = byte(paramID) 855 data[10] = byte(paramID >> 8) 856 857 // Send CMD packet 858 err := stmt.mc.writePacket(data[:4+pktLen]) 859 if err == nil { 860 data = data[pktLen-dataOffset:] 861 continue 862 } 863 return err 864 865 } 866 867 // Reset Packet Sequence 868 stmt.mc.sequence = 0 869 return nil 870} 871 872// Execute Prepared Statement 873// http://dev.mysql.com/doc/internals/en/com-stmt-execute.html 874func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { 875 if len(args) != stmt.paramCount { 876 return fmt.Errorf( 877 "argument count mismatch (got: %d; has: %d)", 878 len(args), 879 stmt.paramCount, 880 ) 881 } 882 883 const minPktLen = 4 + 1 + 4 + 1 + 4 884 mc := stmt.mc 885 886 // Determine threshould dynamically to avoid packet size shortage. 887 longDataSize := mc.maxAllowedPacket / (stmt.paramCount + 1) 888 if longDataSize < 64 { 889 longDataSize = 64 890 } 891 892 // Reset packet-sequence 893 mc.sequence = 0 894 895 var data []byte 896 897 if len(args) == 0 { 898 data = mc.buf.takeBuffer(minPktLen) 899 } else { 900 data = mc.buf.takeCompleteBuffer() 901 } 902 if data == nil { 903 // cannot take the buffer. Something must be wrong with the connection 904 errLog.Print(ErrBusyBuffer) 905 return errBadConnNoWrite 906 } 907 908 // command [1 byte] 909 data[4] = comStmtExecute 910 911 // statement_id [4 bytes] 912 data[5] = byte(stmt.id) 913 data[6] = byte(stmt.id >> 8) 914 data[7] = byte(stmt.id >> 16) 915 data[8] = byte(stmt.id >> 24) 916 917 // flags (0: CURSOR_TYPE_NO_CURSOR) [1 byte] 918 data[9] = 0x00 919 920 // iteration_count (uint32(1)) [4 bytes] 921 data[10] = 0x01 922 data[11] = 0x00 923 data[12] = 0x00 924 data[13] = 0x00 925 926 if len(args) > 0 { 927 pos := minPktLen 928 929 var nullMask []byte 930 if maskLen, typesLen := (len(args)+7)/8, 1+2*len(args); pos+maskLen+typesLen >= len(data) { 931 // buffer has to be extended but we don't know by how much so 932 // we depend on append after all data with known sizes fit. 933 // We stop at that because we deal with a lot of columns here 934 // which makes the required allocation size hard to guess. 935 tmp := make([]byte, pos+maskLen+typesLen) 936 copy(tmp[:pos], data[:pos]) 937 data = tmp 938 nullMask = data[pos : pos+maskLen] 939 pos += maskLen 940 } else { 941 nullMask = data[pos : pos+maskLen] 942 for i := 0; i < maskLen; i++ { 943 nullMask[i] = 0 944 } 945 pos += maskLen 946 } 947 948 // newParameterBoundFlag 1 [1 byte] 949 data[pos] = 0x01 950 pos++ 951 952 // type of each parameter [len(args)*2 bytes] 953 paramTypes := data[pos:] 954 pos += len(args) * 2 955 956 // value of each parameter [n bytes] 957 paramValues := data[pos:pos] 958 valuesCap := cap(paramValues) 959 960 for i, arg := range args { 961 // build NULL-bitmap 962 if arg == nil { 963 nullMask[i/8] |= 1 << (uint(i) & 7) 964 paramTypes[i+i] = byte(fieldTypeNULL) 965 paramTypes[i+i+1] = 0x00 966 continue 967 } 968 969 // cache types and values 970 switch v := arg.(type) { 971 case int64: 972 paramTypes[i+i] = byte(fieldTypeLongLong) 973 paramTypes[i+i+1] = 0x00 974 975 if cap(paramValues)-len(paramValues)-8 >= 0 { 976 paramValues = paramValues[:len(paramValues)+8] 977 binary.LittleEndian.PutUint64( 978 paramValues[len(paramValues)-8:], 979 uint64(v), 980 ) 981 } else { 982 paramValues = append(paramValues, 983 uint64ToBytes(uint64(v))..., 984 ) 985 } 986 987 case float64: 988 paramTypes[i+i] = byte(fieldTypeDouble) 989 paramTypes[i+i+1] = 0x00 990 991 if cap(paramValues)-len(paramValues)-8 >= 0 { 992 paramValues = paramValues[:len(paramValues)+8] 993 binary.LittleEndian.PutUint64( 994 paramValues[len(paramValues)-8:], 995 math.Float64bits(v), 996 ) 997 } else { 998 paramValues = append(paramValues, 999 uint64ToBytes(math.Float64bits(v))..., 1000 ) 1001 } 1002 1003 case bool: 1004 paramTypes[i+i] = byte(fieldTypeTiny) 1005 paramTypes[i+i+1] = 0x00 1006 1007 if v { 1008 paramValues = append(paramValues, 0x01) 1009 } else { 1010 paramValues = append(paramValues, 0x00) 1011 } 1012 1013 case []byte: 1014 // Common case (non-nil value) first 1015 if v != nil { 1016 paramTypes[i+i] = byte(fieldTypeString) 1017 paramTypes[i+i+1] = 0x00 1018 1019 if len(v) < longDataSize { 1020 paramValues = appendLengthEncodedInteger(paramValues, 1021 uint64(len(v)), 1022 ) 1023 paramValues = append(paramValues, v...) 1024 } else { 1025 if err := stmt.writeCommandLongData(i, v); err != nil { 1026 return err 1027 } 1028 } 1029 continue 1030 } 1031 1032 // Handle []byte(nil) as a NULL value 1033 nullMask[i/8] |= 1 << (uint(i) & 7) 1034 paramTypes[i+i] = byte(fieldTypeNULL) 1035 paramTypes[i+i+1] = 0x00 1036 1037 case string: 1038 paramTypes[i+i] = byte(fieldTypeString) 1039 paramTypes[i+i+1] = 0x00 1040 1041 if len(v) < longDataSize { 1042 paramValues = appendLengthEncodedInteger(paramValues, 1043 uint64(len(v)), 1044 ) 1045 paramValues = append(paramValues, v...) 1046 } else { 1047 if err := stmt.writeCommandLongData(i, []byte(v)); err != nil { 1048 return err 1049 } 1050 } 1051 1052 case time.Time: 1053 paramTypes[i+i] = byte(fieldTypeString) 1054 paramTypes[i+i+1] = 0x00 1055 1056 var a [64]byte 1057 var b = a[:0] 1058 1059 if v.IsZero() { 1060 b = append(b, "0000-00-00"...) 1061 } else { 1062 b = v.In(mc.cfg.Loc).AppendFormat(b, timeFormat) 1063 } 1064 1065 paramValues = appendLengthEncodedInteger(paramValues, 1066 uint64(len(b)), 1067 ) 1068 paramValues = append(paramValues, b...) 1069 1070 default: 1071 return fmt.Errorf("cannot convert type: %T", arg) 1072 } 1073 } 1074 1075 // Check if param values exceeded the available buffer 1076 // In that case we must build the data packet with the new values buffer 1077 if valuesCap != cap(paramValues) { 1078 data = append(data[:pos], paramValues...) 1079 mc.buf.buf = data 1080 } 1081 1082 pos += len(paramValues) 1083 data = data[:pos] 1084 } 1085 1086 return mc.writePacket(data) 1087} 1088 1089func (mc *mysqlConn) discardResults() error { 1090 for mc.status&statusMoreResultsExists != 0 { 1091 resLen, err := mc.readResultSetHeaderPacket() 1092 if err != nil { 1093 return err 1094 } 1095 if resLen > 0 { 1096 // columns 1097 if err := mc.readUntilEOF(); err != nil { 1098 return err 1099 } 1100 // rows 1101 if err := mc.readUntilEOF(); err != nil { 1102 return err 1103 } 1104 } 1105 } 1106 return nil 1107} 1108 1109// http://dev.mysql.com/doc/internals/en/binary-protocol-resultset-row.html 1110func (rows *binaryRows) readRow(dest []driver.Value) error { 1111 data, err := rows.mc.readPacket() 1112 if err != nil { 1113 return err 1114 } 1115 1116 // packet indicator [1 byte] 1117 if data[0] != iOK { 1118 // EOF Packet 1119 if data[0] == iEOF && len(data) == 5 { 1120 rows.mc.status = readStatus(data[3:]) 1121 rows.rs.done = true 1122 if !rows.HasNextResultSet() { 1123 rows.mc = nil 1124 } 1125 return io.EOF 1126 } 1127 mc := rows.mc 1128 rows.mc = nil 1129 1130 // Error otherwise 1131 return mc.handleErrorPacket(data) 1132 } 1133 1134 // NULL-bitmap, [(column-count + 7 + 2) / 8 bytes] 1135 pos := 1 + (len(dest)+7+2)>>3 1136 nullMask := data[1:pos] 1137 1138 for i := range dest { 1139 // Field is NULL 1140 // (byte >> bit-pos) % 2 == 1 1141 if ((nullMask[(i+2)>>3] >> uint((i+2)&7)) & 1) == 1 { 1142 dest[i] = nil 1143 continue 1144 } 1145 1146 // Convert to byte-coded string 1147 switch rows.rs.columns[i].fieldType { 1148 case fieldTypeNULL: 1149 dest[i] = nil 1150 continue 1151 1152 // Numeric Types 1153 case fieldTypeTiny: 1154 if rows.rs.columns[i].flags&flagUnsigned != 0 { 1155 dest[i] = int64(data[pos]) 1156 } else { 1157 dest[i] = int64(int8(data[pos])) 1158 } 1159 pos++ 1160 continue 1161 1162 case fieldTypeShort, fieldTypeYear: 1163 if rows.rs.columns[i].flags&flagUnsigned != 0 { 1164 dest[i] = int64(binary.LittleEndian.Uint16(data[pos : pos+2])) 1165 } else { 1166 dest[i] = int64(int16(binary.LittleEndian.Uint16(data[pos : pos+2]))) 1167 } 1168 pos += 2 1169 continue 1170 1171 case fieldTypeInt24, fieldTypeLong: 1172 if rows.rs.columns[i].flags&flagUnsigned != 0 { 1173 dest[i] = int64(binary.LittleEndian.Uint32(data[pos : pos+4])) 1174 } else { 1175 dest[i] = int64(int32(binary.LittleEndian.Uint32(data[pos : pos+4]))) 1176 } 1177 pos += 4 1178 continue 1179 1180 case fieldTypeLongLong: 1181 if rows.rs.columns[i].flags&flagUnsigned != 0 { 1182 val := binary.LittleEndian.Uint64(data[pos : pos+8]) 1183 if val > math.MaxInt64 { 1184 dest[i] = uint64ToString(val) 1185 } else { 1186 dest[i] = int64(val) 1187 } 1188 } else { 1189 dest[i] = int64(binary.LittleEndian.Uint64(data[pos : pos+8])) 1190 } 1191 pos += 8 1192 continue 1193 1194 case fieldTypeFloat: 1195 dest[i] = math.Float32frombits(binary.LittleEndian.Uint32(data[pos : pos+4])) 1196 pos += 4 1197 continue 1198 1199 case fieldTypeDouble: 1200 dest[i] = math.Float64frombits(binary.LittleEndian.Uint64(data[pos : pos+8])) 1201 pos += 8 1202 continue 1203 1204 // Length coded Binary Strings 1205 case fieldTypeDecimal, fieldTypeNewDecimal, fieldTypeVarChar, 1206 fieldTypeBit, fieldTypeEnum, fieldTypeSet, fieldTypeTinyBLOB, 1207 fieldTypeMediumBLOB, fieldTypeLongBLOB, fieldTypeBLOB, 1208 fieldTypeVarString, fieldTypeString, fieldTypeGeometry, fieldTypeJSON: 1209 var isNull bool 1210 var n int 1211 dest[i], isNull, n, err = readLengthEncodedString(data[pos:]) 1212 pos += n 1213 if err == nil { 1214 if !isNull { 1215 continue 1216 } else { 1217 dest[i] = nil 1218 continue 1219 } 1220 } 1221 return err 1222 1223 case 1224 fieldTypeDate, fieldTypeNewDate, // Date YYYY-MM-DD 1225 fieldTypeTime, // Time [-][H]HH:MM:SS[.fractal] 1226 fieldTypeTimestamp, fieldTypeDateTime: // Timestamp YYYY-MM-DD HH:MM:SS[.fractal] 1227 1228 num, isNull, n := readLengthEncodedInteger(data[pos:]) 1229 pos += n 1230 1231 switch { 1232 case isNull: 1233 dest[i] = nil 1234 continue 1235 case rows.rs.columns[i].fieldType == fieldTypeTime: 1236 // database/sql does not support an equivalent to TIME, return a string 1237 var dstlen uint8 1238 switch decimals := rows.rs.columns[i].decimals; decimals { 1239 case 0x00, 0x1f: 1240 dstlen = 8 1241 case 1, 2, 3, 4, 5, 6: 1242 dstlen = 8 + 1 + decimals 1243 default: 1244 return fmt.Errorf( 1245 "protocol error, illegal decimals value %d", 1246 rows.rs.columns[i].decimals, 1247 ) 1248 } 1249 dest[i], err = formatBinaryTime(data[pos:pos+int(num)], dstlen) 1250 case rows.mc.parseTime: 1251 dest[i], err = parseBinaryDateTime(num, data[pos:], rows.mc.cfg.Loc) 1252 default: 1253 var dstlen uint8 1254 if rows.rs.columns[i].fieldType == fieldTypeDate { 1255 dstlen = 10 1256 } else { 1257 switch decimals := rows.rs.columns[i].decimals; decimals { 1258 case 0x00, 0x1f: 1259 dstlen = 19 1260 case 1, 2, 3, 4, 5, 6: 1261 dstlen = 19 + 1 + decimals 1262 default: 1263 return fmt.Errorf( 1264 "protocol error, illegal decimals value %d", 1265 rows.rs.columns[i].decimals, 1266 ) 1267 } 1268 } 1269 dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], dstlen) 1270 } 1271 1272 if err == nil { 1273 pos += int(num) 1274 continue 1275 } else { 1276 return err 1277 } 1278 1279 // Please report if this happens! 1280 default: 1281 return fmt.Errorf("unknown field type %d", rows.rs.columns[i].fieldType) 1282 } 1283 } 1284 1285 return nil 1286} 1287