1// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. 2// 3// This Source Code Form is subject to the terms of the Mozilla Public 4// License, v. 2.0. If a copy of the MPL was not distributed with this file, 5// You can obtain one at http://mozilla.org/MPL/2.0/. 6 7// Package mysql provides a MySQL driver for Go's database/sql package. 8// 9// The driver should be used via the database/sql package: 10// 11// import "database/sql" 12// import _ "github.com/go-sql-driver/mysql" 13// 14// db, err := sql.Open("mysql", "user:password@/dbname") 15// 16// See https://github.com/go-sql-driver/mysql#usage for details 17package mysql 18 19import ( 20 "database/sql" 21 "database/sql/driver" 22 "net" 23 "sync" 24) 25 26// watcher interface is used for context support (From Go 1.8) 27type watcher interface { 28 startWatcher() 29} 30 31// MySQLDriver is exported to make the driver directly accessible. 32// In general the driver is used via the database/sql package. 33type MySQLDriver struct{} 34 35// DialFunc is a function which can be used to establish the network connection. 36// Custom dial functions must be registered with RegisterDial 37type DialFunc func(addr string) (net.Conn, error) 38 39var ( 40 dialsLock sync.RWMutex 41 dials map[string]DialFunc 42) 43 44// RegisterDial registers a custom dial function. It can then be used by the 45// network address mynet(addr), where mynet is the registered new network. 46// addr is passed as a parameter to the dial function. 47func RegisterDial(net string, dial DialFunc) { 48 dialsLock.Lock() 49 defer dialsLock.Unlock() 50 if dials == nil { 51 dials = make(map[string]DialFunc) 52 } 53 dials[net] = dial 54} 55 56// Open new Connection. 57// See https://github.com/go-sql-driver/mysql#dsn-data-source-name for how 58// the DSN string is formated 59func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { 60 var err error 61 62 // New mysqlConn 63 mc := &mysqlConn{ 64 maxAllowedPacket: maxPacketSize, 65 maxWriteSize: maxPacketSize - 1, 66 closech: make(chan struct{}), 67 } 68 mc.cfg, err = ParseDSN(dsn) 69 if err != nil { 70 return nil, err 71 } 72 mc.parseTime = mc.cfg.ParseTime 73 74 // Connect to Server 75 dialsLock.RLock() 76 dial, ok := dials[mc.cfg.Net] 77 dialsLock.RUnlock() 78 if ok { 79 mc.netConn, err = dial(mc.cfg.Addr) 80 } else { 81 nd := net.Dialer{Timeout: mc.cfg.Timeout} 82 mc.netConn, err = nd.Dial(mc.cfg.Net, mc.cfg.Addr) 83 } 84 if err != nil { 85 return nil, err 86 } 87 88 // Enable TCP Keepalives on TCP connections 89 if tc, ok := mc.netConn.(*net.TCPConn); ok { 90 if err := tc.SetKeepAlive(true); err != nil { 91 // Don't send COM_QUIT before handshake. 92 mc.netConn.Close() 93 mc.netConn = nil 94 return nil, err 95 } 96 } 97 98 // Call startWatcher for context support (From Go 1.8) 99 if s, ok := interface{}(mc).(watcher); ok { 100 s.startWatcher() 101 } 102 103 mc.buf = newBuffer(mc.netConn) 104 105 // Set I/O timeouts 106 mc.buf.timeout = mc.cfg.ReadTimeout 107 mc.writeTimeout = mc.cfg.WriteTimeout 108 109 // Reading Handshake Initialization Packet 110 authData, plugin, err := mc.readHandshakePacket() 111 if err != nil { 112 mc.cleanup() 113 return nil, err 114 } 115 if plugin == "" { 116 plugin = defaultAuthPlugin 117 } 118 119 // Send Client Authentication Packet 120 authResp, err := mc.auth(authData, plugin) 121 if err != nil { 122 // try the default auth plugin, if using the requested plugin failed 123 errLog.Print("could not use requested auth plugin '"+plugin+"': ", err.Error()) 124 plugin = defaultAuthPlugin 125 authResp, err = mc.auth(authData, plugin) 126 if err != nil { 127 mc.cleanup() 128 return nil, err 129 } 130 } 131 if err = mc.writeHandshakeResponsePacket(authResp, plugin); err != nil { 132 mc.cleanup() 133 return nil, err 134 } 135 136 // Handle response to auth packet, switch methods if possible 137 if err = mc.handleAuthResult(authData, plugin); err != nil { 138 // Authentication failed and MySQL has already closed the connection 139 // (https://dev.mysql.com/doc/internals/en/authentication-fails.html). 140 // Do not send COM_QUIT, just cleanup and return the error. 141 mc.cleanup() 142 return nil, err 143 } 144 145 if mc.cfg.MaxAllowedPacket > 0 { 146 mc.maxAllowedPacket = mc.cfg.MaxAllowedPacket 147 } else { 148 // Get max allowed packet size 149 maxap, err := mc.getSystemVar("max_allowed_packet") 150 if err != nil { 151 mc.Close() 152 return nil, err 153 } 154 mc.maxAllowedPacket = stringToInt(maxap) - 1 155 } 156 if mc.maxAllowedPacket < maxPacketSize { 157 mc.maxWriteSize = mc.maxAllowedPacket 158 } 159 160 // Handle DSN Params 161 err = mc.handleParams() 162 if err != nil { 163 mc.Close() 164 return nil, err 165 } 166 167 return mc, nil 168} 169 170func init() { 171 sql.Register("mysql", &MySQLDriver{}) 172} 173