1// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
2//
3// Copyright 2016 The Go-MySQL-Driver Authors. All rights reserved.
4//
5// This Source Code Form is subject to the terms of the Mozilla Public
6// License, v. 2.0. If a copy of the MPL was not distributed with this file,
7// You can obtain one at http://mozilla.org/MPL/2.0/.
8
9package mysql
10
11import (
12	"database/sql/driver"
13	"errors"
14	"net"
15	"testing"
16	"time"
17)
18
// Errors reported by mockConn.
var (
	errConnClosed        = errors.New("connection is closed")
	errConnTooManyReads  = errors.New("too many reads")
	errConnTooManyWrites = errors.New("too many writes")
)

// mockConn is an in-memory net.Conn for the packet tests. Read hands out the
// bytes queued in data; Write only counts what it is given.
type mockConn struct {
	laddr, raddr net.Addr // returned as-is by LocalAddr / RemoteAddr

	data   []byte // bytes still to be handed out by Read
	closed bool   // set by Close; later Read/Write calls fail

	read, written int // total bytes read / written so far
	reads, writes int // Read / Write calls made while open (including rejected ones)

	// Limits on the number of Read / Write calls; a value <= 0 means no limit.
	maxReads, maxWrites int
}

// Read copies as much pending data into b as fits and consumes it. It fails
// with errConnClosed after Close, and with errConnTooManyReads once more than
// maxReads calls have been made. When data is exhausted it returns (0, nil),
// not io.EOF, so callers that retry must be bounded via maxReads.
func (m *mockConn) Read(b []byte) (n int, err error) {
	if m.closed {
		return 0, errConnClosed
	}
	if m.reads++; m.maxReads > 0 && m.reads > m.maxReads {
		return 0, errConnTooManyReads
	}
	n = copy(b, m.data)
	m.data = m.data[n:]
	m.read += n
	return n, nil
}

// Write pretends to send b: it only accounts for len(b) bytes. It fails with
// errConnClosed after Close, and with errConnTooManyWrites once more than
// maxWrites calls have been made.
func (m *mockConn) Write(b []byte) (n int, err error) {
	if m.closed {
		return 0, errConnClosed
	}
	if m.writes++; m.maxWrites > 0 && m.writes > m.maxWrites {
		return 0, errConnTooManyWrites
	}
	m.written += len(b)
	return len(b), nil
}

// Close marks the connection closed; it never fails.
func (m *mockConn) Close() error {
	m.closed = true
	return nil
}

// LocalAddr returns the configured local address (nil unless set).
func (m *mockConn) LocalAddr() net.Addr { return m.laddr }

// RemoteAddr returns the configured remote address (nil unless set).
func (m *mockConn) RemoteAddr() net.Addr { return m.raddr }

// SetDeadline is a no-op; deadlines are not simulated.
func (m *mockConn) SetDeadline(time.Time) error { return nil }

// SetReadDeadline is a no-op; deadlines are not simulated.
func (m *mockConn) SetReadDeadline(time.Time) error { return nil }

// SetWriteDeadline is a no-op; deadlines are not simulated.
func (m *mockConn) SetWriteDeadline(time.Time) error { return nil }

// Compile-time check that mockConn satisfies net.Conn.
var _ net.Conn = (*mockConn)(nil)
90
91func TestReadPacketSingleByte(t *testing.T) {
92	conn := new(mockConn)
93	mc := &mysqlConn{
94		buf: newBuffer(conn),
95	}
96
97	conn.data = []byte{0x01, 0x00, 0x00, 0x00, 0xff}
98	conn.maxReads = 1
99	packet, err := mc.readPacket()
100	if err != nil {
101		t.Fatal(err)
102	}
103	if len(packet) != 1 {
104		t.Fatalf("unexpected packet lenght: expected %d, got %d", 1, len(packet))
105	}
106	if packet[0] != 0xff {
107		t.Fatalf("unexpected packet content: expected %x, got %x", 0xff, packet[0])
108	}
109}
110
111func TestReadPacketWrongSequenceID(t *testing.T) {
112	conn := new(mockConn)
113	mc := &mysqlConn{
114		buf: newBuffer(conn),
115	}
116
117	// too low sequence id
118	conn.data = []byte{0x01, 0x00, 0x00, 0x00, 0xff}
119	conn.maxReads = 1
120	mc.sequence = 1
121	_, err := mc.readPacket()
122	if err != ErrPktSync {
123		t.Errorf("expected ErrPktSync, got %v", err)
124	}
125
126	// reset
127	conn.reads = 0
128	mc.sequence = 0
129	mc.buf = newBuffer(conn)
130
131	// too high sequence id
132	conn.data = []byte{0x01, 0x00, 0x00, 0x42, 0xff}
133	_, err = mc.readPacket()
134	if err != ErrPktSyncMul {
135		t.Errorf("expected ErrPktSyncMul, got %v", err)
136	}
137}
138
139func TestReadPacketSplit(t *testing.T) {
140	conn := new(mockConn)
141	mc := &mysqlConn{
142		buf: newBuffer(conn),
143	}
144
145	data := make([]byte, maxPacketSize*2+4*3)
146	const pkt2ofs = maxPacketSize + 4
147	const pkt3ofs = 2 * (maxPacketSize + 4)
148
149	// case 1: payload has length maxPacketSize
150	data = data[:pkt2ofs+4]
151
152	// 1st packet has maxPacketSize length and sequence id 0
153	// ff ff ff 00 ...
154	data[0] = 0xff
155	data[1] = 0xff
156	data[2] = 0xff
157
158	// mark the payload start and end of 1st packet so that we can check if the
159	// content was correctly appended
160	data[4] = 0x11
161	data[maxPacketSize+3] = 0x22
162
163	// 2nd packet has payload length 0 and squence id 1
164	// 00 00 00 01
165	data[pkt2ofs+3] = 0x01
166
167	conn.data = data
168	conn.maxReads = 3
169	packet, err := mc.readPacket()
170	if err != nil {
171		t.Fatal(err)
172	}
173	if len(packet) != maxPacketSize {
174		t.Fatalf("unexpected packet lenght: expected %d, got %d", maxPacketSize, len(packet))
175	}
176	if packet[0] != 0x11 {
177		t.Fatalf("unexpected payload start: expected %x, got %x", 0x11, packet[0])
178	}
179	if packet[maxPacketSize-1] != 0x22 {
180		t.Fatalf("unexpected payload end: expected %x, got %x", 0x22, packet[maxPacketSize-1])
181	}
182
183	// case 2: payload has length which is a multiple of maxPacketSize
184	data = data[:cap(data)]
185
186	// 2nd packet now has maxPacketSize length
187	data[pkt2ofs] = 0xff
188	data[pkt2ofs+1] = 0xff
189	data[pkt2ofs+2] = 0xff
190
191	// mark the payload start and end of the 2nd packet
192	data[pkt2ofs+4] = 0x33
193	data[pkt2ofs+maxPacketSize+3] = 0x44
194
195	// 3rd packet has payload length 0 and squence id 2
196	// 00 00 00 02
197	data[pkt3ofs+3] = 0x02
198
199	conn.data = data
200	conn.reads = 0
201	conn.maxReads = 5
202	mc.sequence = 0
203	packet, err = mc.readPacket()
204	if err != nil {
205		t.Fatal(err)
206	}
207	if len(packet) != 2*maxPacketSize {
208		t.Fatalf("unexpected packet lenght: expected %d, got %d", 2*maxPacketSize, len(packet))
209	}
210	if packet[0] != 0x11 {
211		t.Fatalf("unexpected payload start: expected %x, got %x", 0x11, packet[0])
212	}
213	if packet[2*maxPacketSize-1] != 0x44 {
214		t.Fatalf("unexpected payload end: expected %x, got %x", 0x44, packet[2*maxPacketSize-1])
215	}
216
217	// case 3: payload has a length larger maxPacketSize, which is not an exact
218	// multiple of it
219	data = data[:pkt2ofs+4+42]
220	data[pkt2ofs] = 0x2a
221	data[pkt2ofs+1] = 0x00
222	data[pkt2ofs+2] = 0x00
223	data[pkt2ofs+4+41] = 0x44
224
225	conn.data = data
226	conn.reads = 0
227	conn.maxReads = 4
228	mc.sequence = 0
229	packet, err = mc.readPacket()
230	if err != nil {
231		t.Fatal(err)
232	}
233	if len(packet) != maxPacketSize+42 {
234		t.Fatalf("unexpected packet lenght: expected %d, got %d", maxPacketSize+42, len(packet))
235	}
236	if packet[0] != 0x11 {
237		t.Fatalf("unexpected payload start: expected %x, got %x", 0x11, packet[0])
238	}
239	if packet[maxPacketSize+41] != 0x44 {
240		t.Fatalf("unexpected payload end: expected %x, got %x", 0x44, packet[maxPacketSize+41])
241	}
242}
243
244func TestReadPacketFail(t *testing.T) {
245	conn := new(mockConn)
246	mc := &mysqlConn{
247		buf: newBuffer(conn),
248	}
249
250	// illegal empty (stand-alone) packet
251	conn.data = []byte{0x00, 0x00, 0x00, 0x00}
252	conn.maxReads = 1
253	_, err := mc.readPacket()
254	if err != driver.ErrBadConn {
255		t.Errorf("expected ErrBadConn, got %v", err)
256	}
257
258	// reset
259	conn.reads = 0
260	mc.sequence = 0
261	mc.buf = newBuffer(conn)
262
263	// fail to read header
264	conn.closed = true
265	_, err = mc.readPacket()
266	if err != driver.ErrBadConn {
267		t.Errorf("expected ErrBadConn, got %v", err)
268	}
269
270	// reset
271	conn.closed = false
272	conn.reads = 0
273	mc.sequence = 0
274	mc.buf = newBuffer(conn)
275
276	// fail to read body
277	conn.maxReads = 1
278	_, err = mc.readPacket()
279	if err != driver.ErrBadConn {
280		t.Errorf("expected ErrBadConn, got %v", err)
281	}
282}
283