1// Go MySQL Driver - A MySQL-Driver for Go's database/sql package 2// 3// Copyright 2016 The Go-MySQL-Driver Authors. All rights reserved. 4// 5// This Source Code Form is subject to the terms of the Mozilla Public 6// License, v. 2.0. If a copy of the MPL was not distributed with this file, 7// You can obtain one at http://mozilla.org/MPL/2.0/. 8 9package mysql 10 11import ( 12 "database/sql/driver" 13 "errors" 14 "net" 15 "testing" 16 "time" 17) 18 19var ( 20 errConnClosed = errors.New("connection is closed") 21 errConnTooManyReads = errors.New("too many reads") 22 errConnTooManyWrites = errors.New("too many writes") 23) 24 25// struct to mock a net.Conn for testing purposes 26type mockConn struct { 27 laddr net.Addr 28 raddr net.Addr 29 data []byte 30 closed bool 31 read int 32 written int 33 reads int 34 writes int 35 maxReads int 36 maxWrites int 37} 38 39func (m *mockConn) Read(b []byte) (n int, err error) { 40 if m.closed { 41 return 0, errConnClosed 42 } 43 44 m.reads++ 45 if m.maxReads > 0 && m.reads > m.maxReads { 46 return 0, errConnTooManyReads 47 } 48 49 n = copy(b, m.data) 50 m.read += n 51 m.data = m.data[n:] 52 return 53} 54func (m *mockConn) Write(b []byte) (n int, err error) { 55 if m.closed { 56 return 0, errConnClosed 57 } 58 59 m.writes++ 60 if m.maxWrites > 0 && m.writes > m.maxWrites { 61 return 0, errConnTooManyWrites 62 } 63 64 n = len(b) 65 m.written += n 66 return 67} 68func (m *mockConn) Close() error { 69 m.closed = true 70 return nil 71} 72func (m *mockConn) LocalAddr() net.Addr { 73 return m.laddr 74} 75func (m *mockConn) RemoteAddr() net.Addr { 76 return m.raddr 77} 78func (m *mockConn) SetDeadline(t time.Time) error { 79 return nil 80} 81func (m *mockConn) SetReadDeadline(t time.Time) error { 82 return nil 83} 84func (m *mockConn) SetWriteDeadline(t time.Time) error { 85 return nil 86} 87 88// make sure mockConn implements the net.Conn interface 89var _ net.Conn = new(mockConn) 90 91func TestReadPacketSingleByte(t *testing.T) { 92 conn := new(mockConn) 93 mc := &mysqlConn{ 94 buf: newBuffer(conn), 95 } 96 97 conn.data = []byte{0x01, 0x00, 0x00, 0x00, 0xff} 98 conn.maxReads = 1 99 packet, err := mc.readPacket() 100 if err != nil { 101 t.Fatal(err) 102 } 103 if len(packet) != 1 { 104 t.Fatalf("unexpected packet lenght: expected %d, got %d", 1, len(packet)) 105 } 106 if packet[0] != 0xff { 107 t.Fatalf("unexpected packet content: expected %x, got %x", 0xff, packet[0]) 108 } 109} 110 111func TestReadPacketWrongSequenceID(t *testing.T) { 112 conn := new(mockConn) 113 mc := &mysqlConn{ 114 buf: newBuffer(conn), 115 } 116 117 // too low sequence id 118 conn.data = []byte{0x01, 0x00, 0x00, 0x00, 0xff} 119 conn.maxReads = 1 120 mc.sequence = 1 121 _, err := mc.readPacket() 122 if err != ErrPktSync { 123 t.Errorf("expected ErrPktSync, got %v", err) 124 } 125 126 // reset 127 conn.reads = 0 128 mc.sequence = 0 129 mc.buf = newBuffer(conn) 130 131 // too high sequence id 132 conn.data = []byte{0x01, 0x00, 0x00, 0x42, 0xff} 133 _, err = mc.readPacket() 134 if err != ErrPktSyncMul { 135 t.Errorf("expected ErrPktSyncMul, got %v", err) 136 } 137} 138 139func TestReadPacketSplit(t *testing.T) { 140 conn := new(mockConn) 141 mc := &mysqlConn{ 142 buf: newBuffer(conn), 143 } 144 145 data := make([]byte, maxPacketSize*2+4*3) 146 const pkt2ofs = maxPacketSize + 4 147 const pkt3ofs = 2 * (maxPacketSize + 4) 148 149 // case 1: payload has length maxPacketSize 150 data = data[:pkt2ofs+4] 151 152 // 1st packet has maxPacketSize length and sequence id 0 153 // ff ff ff 00 ... 154 data[0] = 0xff 155 data[1] = 0xff 156 data[2] = 0xff 157 158 // mark the payload start and end of 1st packet so that we can check if the 159 // content was correctly appended 160 data[4] = 0x11 161 data[maxPacketSize+3] = 0x22 162 163 // 2nd packet has payload length 0 and squence id 1 164 // 00 00 00 01 165 data[pkt2ofs+3] = 0x01 166 167 conn.data = data 168 conn.maxReads = 3 169 packet, err := mc.readPacket() 170 if err != nil { 171 t.Fatal(err) 172 } 173 if len(packet) != maxPacketSize { 174 t.Fatalf("unexpected packet lenght: expected %d, got %d", maxPacketSize, len(packet)) 175 } 176 if packet[0] != 0x11 { 177 t.Fatalf("unexpected payload start: expected %x, got %x", 0x11, packet[0]) 178 } 179 if packet[maxPacketSize-1] != 0x22 { 180 t.Fatalf("unexpected payload end: expected %x, got %x", 0x22, packet[maxPacketSize-1]) 181 } 182 183 // case 2: payload has length which is a multiple of maxPacketSize 184 data = data[:cap(data)] 185 186 // 2nd packet now has maxPacketSize length 187 data[pkt2ofs] = 0xff 188 data[pkt2ofs+1] = 0xff 189 data[pkt2ofs+2] = 0xff 190 191 // mark the payload start and end of the 2nd packet 192 data[pkt2ofs+4] = 0x33 193 data[pkt2ofs+maxPacketSize+3] = 0x44 194 195 // 3rd packet has payload length 0 and squence id 2 196 // 00 00 00 02 197 data[pkt3ofs+3] = 0x02 198 199 conn.data = data 200 conn.reads = 0 201 conn.maxReads = 5 202 mc.sequence = 0 203 packet, err = mc.readPacket() 204 if err != nil { 205 t.Fatal(err) 206 } 207 if len(packet) != 2*maxPacketSize { 208 t.Fatalf("unexpected packet lenght: expected %d, got %d", 2*maxPacketSize, len(packet)) 209 } 210 if packet[0] != 0x11 { 211 t.Fatalf("unexpected payload start: expected %x, got %x", 0x11, packet[0]) 212 } 213 if packet[2*maxPacketSize-1] != 0x44 { 214 t.Fatalf("unexpected payload end: expected %x, got %x", 0x44, packet[2*maxPacketSize-1]) 215 } 216 217 // case 3: payload has a length larger maxPacketSize, which is not an exact 218 // multiple of it 219 data = data[:pkt2ofs+4+42] 220 data[pkt2ofs] = 0x2a 221 data[pkt2ofs+1] = 0x00 222 data[pkt2ofs+2] = 0x00 223 data[pkt2ofs+4+41] = 0x44 224 225 conn.data = data 226 conn.reads = 0 227 conn.maxReads = 4 228 mc.sequence = 0 229 packet, err = mc.readPacket() 230 if err != nil { 231 t.Fatal(err) 232 } 233 if len(packet) != maxPacketSize+42 { 234 t.Fatalf("unexpected packet lenght: expected %d, got %d", maxPacketSize+42, len(packet)) 235 } 236 if packet[0] != 0x11 { 237 t.Fatalf("unexpected payload start: expected %x, got %x", 0x11, packet[0]) 238 } 239 if packet[maxPacketSize+41] != 0x44 { 240 t.Fatalf("unexpected payload end: expected %x, got %x", 0x44, packet[maxPacketSize+41]) 241 } 242} 243 244func TestReadPacketFail(t *testing.T) { 245 conn := new(mockConn) 246 mc := &mysqlConn{ 247 buf: newBuffer(conn), 248 } 249 250 // illegal empty (stand-alone) packet 251 conn.data = []byte{0x00, 0x00, 0x00, 0x00} 252 conn.maxReads = 1 253 _, err := mc.readPacket() 254 if err != driver.ErrBadConn { 255 t.Errorf("expected ErrBadConn, got %v", err) 256 } 257 258 // reset 259 conn.reads = 0 260 mc.sequence = 0 261 mc.buf = newBuffer(conn) 262 263 // fail to read header 264 conn.closed = true 265 _, err = mc.readPacket() 266 if err != driver.ErrBadConn { 267 t.Errorf("expected ErrBadConn, got %v", err) 268 } 269 270 // reset 271 conn.closed = false 272 conn.reads = 0 273 mc.sequence = 0 274 mc.buf = newBuffer(conn) 275 276 // fail to read body 277 conn.maxReads = 1 278 _, err = mc.readPacket() 279 if err != driver.ErrBadConn { 280 t.Errorf("expected ErrBadConn, got %v", err) 281 } 282} 283