1// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
2//
3// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved.
4//
5// This Source Code Form is subject to the terms of the Mozilla Public
6// License, v. 2.0. If a copy of the MPL was not distributed with this file,
7// You can obtain one at http://mozilla.org/MPL/2.0/.
8
9package mysql
10
11import (
12	"bytes"
13	"context"
14	"database/sql"
15	"database/sql/driver"
16	"fmt"
17	"math"
18	"runtime"
19	"strings"
20	"sync"
21	"sync/atomic"
22	"testing"
23	"time"
24)
25
// TB is testing.B with small helpers that unwrap (value, error) pairs and
// abort the benchmark on error. A *testing.B converts to *TB for free, and
// the methods promoted from testing.B (Fatal, Helper, ...) stay available.
type TB testing.B

// check fails the benchmark immediately when err is non-nil.
// Helper is only called on the error path so that successful calls inside
// timed loops pay no extra cost; on failure the reported line is the caller's.
func (tb *TB) check(err error) {
	if err == nil {
		return
	}
	tb.Helper()
	tb.Fatal(err)
}

// checkDB returns db, failing the benchmark if err is non-nil.
func (tb *TB) checkDB(db *sql.DB, err error) *sql.DB {
	if err != nil {
		tb.Helper()
		tb.check(err)
	}
	return db
}

// checkRows returns rows, failing the benchmark if err is non-nil.
func (tb *TB) checkRows(rows *sql.Rows, err error) *sql.Rows {
	if err != nil {
		tb.Helper()
		tb.check(err)
	}
	return rows
}

// checkStmt returns stmt, failing the benchmark if err is non-nil.
func (tb *TB) checkStmt(stmt *sql.Stmt, err error) *sql.Stmt {
	if err != nil {
		tb.Helper()
		tb.check(err)
	}
	return stmt
}
48
49func initDB(b *testing.B, queries ...string) *sql.DB {
50	tb := (*TB)(b)
51	db := tb.checkDB(sql.Open("mysql", dsn))
52	for _, query := range queries {
53		if _, err := db.Exec(query); err != nil {
54			b.Fatalf("error on %q: %v", query, err)
55		}
56	}
57	return db
58}
59
// concurrencyLevel is the number of worker goroutines spawned by
// BenchmarkQuery and BenchmarkExec. It is also their idle-connection limit.
const (
	concurrencyLevel = 10
)
61
62func BenchmarkQuery(b *testing.B) {
63	tb := (*TB)(b)
64	b.StopTimer()
65	b.ReportAllocs()
66	db := initDB(b,
67		"DROP TABLE IF EXISTS foo",
68		"CREATE TABLE foo (id INT PRIMARY KEY, val CHAR(50))",
69		`INSERT INTO foo VALUES (1, "one")`,
70		`INSERT INTO foo VALUES (2, "two")`,
71	)
72	db.SetMaxIdleConns(concurrencyLevel)
73	defer db.Close()
74
75	stmt := tb.checkStmt(db.Prepare("SELECT val FROM foo WHERE id=?"))
76	defer stmt.Close()
77
78	remain := int64(b.N)
79	var wg sync.WaitGroup
80	wg.Add(concurrencyLevel)
81	defer wg.Wait()
82	b.StartTimer()
83
84	for i := 0; i < concurrencyLevel; i++ {
85		go func() {
86			for {
87				if atomic.AddInt64(&remain, -1) < 0 {
88					wg.Done()
89					return
90				}
91
92				var got string
93				tb.check(stmt.QueryRow(1).Scan(&got))
94				if got != "one" {
95					b.Errorf("query = %q; want one", got)
96					wg.Done()
97					return
98				}
99			}
100		}()
101	}
102}
103
104func BenchmarkExec(b *testing.B) {
105	tb := (*TB)(b)
106	b.StopTimer()
107	b.ReportAllocs()
108	db := tb.checkDB(sql.Open("mysql", dsn))
109	db.SetMaxIdleConns(concurrencyLevel)
110	defer db.Close()
111
112	stmt := tb.checkStmt(db.Prepare("DO 1"))
113	defer stmt.Close()
114
115	remain := int64(b.N)
116	var wg sync.WaitGroup
117	wg.Add(concurrencyLevel)
118	defer wg.Wait()
119	b.StartTimer()
120
121	for i := 0; i < concurrencyLevel; i++ {
122		go func() {
123			for {
124				if atomic.AddInt64(&remain, -1) < 0 {
125					wg.Done()
126					return
127				}
128
129				if _, err := stmt.Exec(); err != nil {
130					b.Logf("stmt.Exec failed: %v", err)
131					b.Fail()
132				}
133			}
134		}()
135	}
136}
137
// roundtripSample is a 16 MiB payload built on first use and shared by the
// round-trip benchmarks. It is only echoed back through SELECTs and is never
// written to a table.
var roundtripSample []byte

// initRoundtripBenchmarks returns the shared sample (building it on the first
// call) along with the smallest (16) and largest (len(sample)) payload lengths
// the round-trip benchmarks use.
func initRoundtripBenchmarks() ([]byte, int, int) {
	const (
		chunk   = "0123456789abcdef"
		repeats = 1024 * 1024
		minLen  = len(chunk)
	)
	if roundtripSample == nil {
		roundtripSample = []byte(strings.Repeat(chunk, repeats))
	}
	sample := roundtripSample
	return sample, minLen, len(sample)
}
147
148func BenchmarkRoundtripTxt(b *testing.B) {
149	b.StopTimer()
150	sample, min, max := initRoundtripBenchmarks()
151	sampleString := string(sample)
152	b.ReportAllocs()
153	tb := (*TB)(b)
154	db := tb.checkDB(sql.Open("mysql", dsn))
155	defer db.Close()
156	b.StartTimer()
157	var result string
158	for i := 0; i < b.N; i++ {
159		length := min + i
160		if length > max {
161			length = max
162		}
163		test := sampleString[0:length]
164		rows := tb.checkRows(db.Query(`SELECT "` + test + `"`))
165		if !rows.Next() {
166			rows.Close()
167			b.Fatalf("crashed")
168		}
169		err := rows.Scan(&result)
170		if err != nil {
171			rows.Close()
172			b.Fatalf("crashed")
173		}
174		if result != test {
175			rows.Close()
176			b.Errorf("mismatch")
177		}
178		rows.Close()
179	}
180}
181
182func BenchmarkRoundtripBin(b *testing.B) {
183	b.StopTimer()
184	sample, min, max := initRoundtripBenchmarks()
185	b.ReportAllocs()
186	tb := (*TB)(b)
187	db := tb.checkDB(sql.Open("mysql", dsn))
188	defer db.Close()
189	stmt := tb.checkStmt(db.Prepare("SELECT ?"))
190	defer stmt.Close()
191	b.StartTimer()
192	var result sql.RawBytes
193	for i := 0; i < b.N; i++ {
194		length := min + i
195		if length > max {
196			length = max
197		}
198		test := sample[0:length]
199		rows := tb.checkRows(stmt.Query(test))
200		if !rows.Next() {
201			rows.Close()
202			b.Fatalf("crashed")
203		}
204		err := rows.Scan(&result)
205		if err != nil {
206			rows.Close()
207			b.Fatalf("crashed")
208		}
209		if !bytes.Equal(result, test) {
210			rows.Close()
211			b.Errorf("mismatch")
212		}
213		rows.Close()
214	}
215}
216
217func BenchmarkInterpolation(b *testing.B) {
218	mc := &mysqlConn{
219		cfg: &Config{
220			InterpolateParams: true,
221			Loc:               time.UTC,
222		},
223		maxAllowedPacket: maxPacketSize,
224		maxWriteSize:     maxPacketSize - 1,
225		buf:              newBuffer(nil),
226	}
227
228	args := []driver.Value{
229		int64(42424242),
230		float64(math.Pi),
231		false,
232		time.Unix(1423411542, 807015000),
233		[]byte("bytes containing special chars ' \" \a \x00"),
234		"string containing special chars ' \" \a \x00",
235	}
236	q := "SELECT ?, ?, ?, ?, ?, ?"
237
238	b.ReportAllocs()
239	b.ResetTimer()
240	for i := 0; i < b.N; i++ {
241		_, err := mc.interpolateParams(q, args)
242		if err != nil {
243			b.Fatal(err)
244		}
245	}
246}
247
248func benchmarkQueryContext(b *testing.B, db *sql.DB, p int) {
249	ctx, cancel := context.WithCancel(context.Background())
250	defer cancel()
251	db.SetMaxIdleConns(p * runtime.GOMAXPROCS(0))
252
253	tb := (*TB)(b)
254	stmt := tb.checkStmt(db.PrepareContext(ctx, "SELECT val FROM foo WHERE id=?"))
255	defer stmt.Close()
256
257	b.SetParallelism(p)
258	b.ReportAllocs()
259	b.ResetTimer()
260	b.RunParallel(func(pb *testing.PB) {
261		var got string
262		for pb.Next() {
263			tb.check(stmt.QueryRow(1).Scan(&got))
264			if got != "one" {
265				b.Fatalf("query = %q; want one", got)
266			}
267		}
268	})
269}
270
271func BenchmarkQueryContext(b *testing.B) {
272	db := initDB(b,
273		"DROP TABLE IF EXISTS foo",
274		"CREATE TABLE foo (id INT PRIMARY KEY, val CHAR(50))",
275		`INSERT INTO foo VALUES (1, "one")`,
276		`INSERT INTO foo VALUES (2, "two")`,
277	)
278	defer db.Close()
279	for _, p := range []int{1, 2, 3, 4} {
280		b.Run(fmt.Sprintf("%d", p), func(b *testing.B) {
281			benchmarkQueryContext(b, db, p)
282		})
283	}
284}
285
286func benchmarkExecContext(b *testing.B, db *sql.DB, p int) {
287	ctx, cancel := context.WithCancel(context.Background())
288	defer cancel()
289	db.SetMaxIdleConns(p * runtime.GOMAXPROCS(0))
290
291	tb := (*TB)(b)
292	stmt := tb.checkStmt(db.PrepareContext(ctx, "DO 1"))
293	defer stmt.Close()
294
295	b.SetParallelism(p)
296	b.ReportAllocs()
297	b.ResetTimer()
298	b.RunParallel(func(pb *testing.PB) {
299		for pb.Next() {
300			if _, err := stmt.ExecContext(ctx); err != nil {
301				b.Fatal(err)
302			}
303		}
304	})
305}
306
307func BenchmarkExecContext(b *testing.B) {
308	db := initDB(b,
309		"DROP TABLE IF EXISTS foo",
310		"CREATE TABLE foo (id INT PRIMARY KEY, val CHAR(50))",
311		`INSERT INTO foo VALUES (1, "one")`,
312		`INSERT INTO foo VALUES (2, "two")`,
313	)
314	defer db.Close()
315	for _, p := range []int{1, 2, 3, 4} {
316		b.Run(fmt.Sprintf("%d", p), func(b *testing.B) {
317			benchmarkQueryContext(b, db, p)
318		})
319	}
320}
321
322// BenchmarkQueryRawBytes benchmarks fetching 100 blobs using sql.RawBytes.
323// "size=" means size of each blobs.
324func BenchmarkQueryRawBytes(b *testing.B) {
325	var sizes []int = []int{100, 1000, 2000, 4000, 8000, 12000, 16000, 32000, 64000, 256000}
326	db := initDB(b,
327		"DROP TABLE IF EXISTS bench_rawbytes",
328		"CREATE TABLE bench_rawbytes (id INT PRIMARY KEY, val LONGBLOB)",
329	)
330	defer db.Close()
331
332	blob := make([]byte, sizes[len(sizes)-1])
333	for i := range blob {
334		blob[i] = 42
335	}
336	for i := 0; i < 100; i++ {
337		_, err := db.Exec("INSERT INTO bench_rawbytes VALUES (?, ?)", i, blob)
338		if err != nil {
339			b.Fatal(err)
340		}
341	}
342
343	for _, s := range sizes {
344		b.Run(fmt.Sprintf("size=%v", s), func(b *testing.B) {
345			db.SetMaxIdleConns(0)
346			db.SetMaxIdleConns(1)
347			b.ReportAllocs()
348			b.ResetTimer()
349
350			for j := 0; j < b.N; j++ {
351				rows, err := db.Query("SELECT LEFT(val, ?) as v FROM bench_rawbytes", s)
352				if err != nil {
353					b.Fatal(err)
354				}
355				nrows := 0
356				for rows.Next() {
357					var buf sql.RawBytes
358					err := rows.Scan(&buf)
359					if err != nil {
360						b.Fatal(err)
361					}
362					if len(buf) != s {
363						b.Fatalf("size mismatch: expected %v, got %v", s, len(buf))
364					}
365					nrows++
366				}
367				rows.Close()
368				if nrows != 100 {
369					b.Fatalf("numbers of rows mismatch: expected %v, got %v", 100, nrows)
370				}
371			}
372		})
373	}
374}
375