1// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. 2// 3// This Source Code Form is subject to the terms of the Mozilla Public 4// License, v. 2.0. If a copy of the MPL was not distributed with this file, 5// You can obtain one at http://mozilla.org/MPL/2.0/. 6 7// Package mysql provides a MySQL driver for Go's database/sql package. 8// 9// The driver should be used via the database/sql package: 10// 11// import "database/sql" 12// import _ "github.com/go-sql-driver/mysql" 13// 14// db, err := sql.Open("mysql", "user:password@/dbname") 15// 16// See https://github.com/go-sql-driver/mysql#usage for details 17package mysql 18 19import ( 20 "database/sql" 21 "database/sql/driver" 22 "net" 23 "sync" 24) 25 26// MySQLDriver is exported to make the driver directly accessible. 27// In general the driver is used via the database/sql package. 28type MySQLDriver struct{} 29 30// DialFunc is a function which can be used to establish the network connection. 31// Custom dial functions must be registered with RegisterDial 32type DialFunc func(addr string) (net.Conn, error) 33 34var ( 35 dialsLock sync.RWMutex 36 dials map[string]DialFunc 37) 38 39// RegisterDial registers a custom dial function. It can then be used by the 40// network address mynet(addr), where mynet is the registered new network. 41// addr is passed as a parameter to the dial function. 42func RegisterDial(net string, dial DialFunc) { 43 dialsLock.Lock() 44 defer dialsLock.Unlock() 45 if dials == nil { 46 dials = make(map[string]DialFunc) 47 } 48 dials[net] = dial 49} 50 51// Open new Connection. 52// See https://github.com/go-sql-driver/mysql#dsn-data-source-name for how 53// the DSN string is formated 54func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { 55 var err error 56 57 // New mysqlConn 58 mc := &mysqlConn{ 59 maxAllowedPacket: maxPacketSize, 60 maxWriteSize: maxPacketSize - 1, 61 closech: make(chan struct{}), 62 } 63 mc.cfg, err = ParseDSN(dsn) 64 if err != nil { 65 return nil, err 66 } 67 mc.parseTime = mc.cfg.ParseTime 68 69 // Connect to Server 70 dialsLock.RLock() 71 dial, ok := dials[mc.cfg.Net] 72 dialsLock.RUnlock() 73 if ok { 74 mc.netConn, err = dial(mc.cfg.Addr) 75 } else { 76 nd := net.Dialer{Timeout: mc.cfg.Timeout} 77 mc.netConn, err = nd.Dial(mc.cfg.Net, mc.cfg.Addr) 78 } 79 if err != nil { 80 return nil, err 81 } 82 83 // Enable TCP Keepalives on TCP connections 84 if tc, ok := mc.netConn.(*net.TCPConn); ok { 85 if err := tc.SetKeepAlive(true); err != nil { 86 // Don't send COM_QUIT before handshake. 87 mc.netConn.Close() 88 mc.netConn = nil 89 return nil, err 90 } 91 } 92 93 // Call startWatcher for context support (From Go 1.8) 94 mc.startWatcher() 95 96 mc.buf = newBuffer(mc.netConn) 97 98 // Set I/O timeouts 99 mc.buf.timeout = mc.cfg.ReadTimeout 100 mc.writeTimeout = mc.cfg.WriteTimeout 101 102 // Reading Handshake Initialization Packet 103 authData, plugin, err := mc.readHandshakePacket() 104 if err != nil { 105 mc.cleanup() 106 return nil, err 107 } 108 if plugin == "" { 109 plugin = defaultAuthPlugin 110 } 111 112 // Send Client Authentication Packet 113 authResp, addNUL, err := mc.auth(authData, plugin) 114 if err != nil { 115 // try the default auth plugin, if using the requested plugin failed 116 errLog.Print("could not use requested auth plugin '"+plugin+"': ", err.Error()) 117 plugin = defaultAuthPlugin 118 authResp, addNUL, err = mc.auth(authData, plugin) 119 if err != nil { 120 mc.cleanup() 121 return nil, err 122 } 123 } 124 if err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin); err != nil { 125 mc.cleanup() 126 return nil, err 127 } 128 129 // Handle response to auth packet, switch methods if possible 130 if err = mc.handleAuthResult(authData, plugin); err != nil { 131 // Authentication failed and MySQL has already closed the connection 132 // (https://dev.mysql.com/doc/internals/en/authentication-fails.html). 133 // Do not send COM_QUIT, just cleanup and return the error. 134 mc.cleanup() 135 return nil, err 136 } 137 138 if mc.cfg.MaxAllowedPacket > 0 { 139 mc.maxAllowedPacket = mc.cfg.MaxAllowedPacket 140 } else { 141 // Get max allowed packet size 142 maxap, err := mc.getSystemVar("max_allowed_packet") 143 if err != nil { 144 mc.Close() 145 return nil, err 146 } 147 mc.maxAllowedPacket = stringToInt(maxap) - 1 148 } 149 if mc.maxAllowedPacket < maxPacketSize { 150 mc.maxWriteSize = mc.maxAllowedPacket 151 } 152 153 // Handle DSN Params 154 err = mc.handleParams() 155 if err != nil { 156 mc.Close() 157 return nil, err 158 } 159 160 return mc, nil 161} 162 163func init() { 164 sql.Register("mysql", &MySQLDriver{}) 165} 166