1// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
2//
3// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved.
4//
5// This Source Code Form is subject to the terms of the Mozilla Public
6// License, v. 2.0. If a copy of the MPL was not distributed with this file,
7// You can obtain one at http://mozilla.org/MPL/2.0/.
8
9package mysql
10
11import (
12	"bytes"
13	"context"
14	"database/sql"
15	"database/sql/driver"
16	"fmt"
17	"math"
18	"runtime"
19	"strings"
20	"sync"
21	"sync/atomic"
22	"testing"
23	"time"
24)
25
// TB is testing.B under a distinct name so that the small error-checking
// helpers below can be declared on it as methods. Convert with (*TB)(b).
type TB testing.B

// check fails the benchmark immediately (via Fatal) when err is non-nil.
// A nil err is a no-op.
func (tb *TB) check(err error) {
	if err == nil {
		return
	}
	tb.Fatal(err)
}

// checkDB wraps a (*sql.DB, error) result pair: it fails the benchmark on a
// non-nil err and otherwise hands db back unchanged, so a call such as
// sql.Open can be used inline.
func (tb *TB) checkDB(db *sql.DB, err error) *sql.DB {
	tb.check(err)
	return db
}

// checkRows wraps a (*sql.Rows, error) result pair: it fails the benchmark on
// a non-nil err and otherwise hands rows back unchanged.
func (tb *TB) checkRows(rows *sql.Rows, err error) *sql.Rows {
	tb.check(err)
	return rows
}

// checkStmt wraps a (*sql.Stmt, error) result pair: it fails the benchmark on
// a non-nil err and otherwise hands stmt back unchanged.
func (tb *TB) checkStmt(stmt *sql.Stmt, err error) *sql.Stmt {
	tb.check(err)
	return stmt
}
48
49func initDB(b *testing.B, queries ...string) *sql.DB {
50	tb := (*TB)(b)
51	db := tb.checkDB(sql.Open("mysql", dsn))
52	for _, query := range queries {
53		if _, err := db.Exec(query); err != nil {
54			b.Fatalf("error on %q: %v", query, err)
55		}
56	}
57	return db
58}
59
// concurrencyLevel is the number of worker goroutines started by
// BenchmarkQuery and BenchmarkExec, and the idle-connection limit those
// benchmarks set on the pool.
const concurrencyLevel = 10
61
62func BenchmarkQuery(b *testing.B) {
63	tb := (*TB)(b)
64	b.StopTimer()
65	b.ReportAllocs()
66	db := initDB(b,
67		"DROP TABLE IF EXISTS foo",
68		"CREATE TABLE foo (id INT PRIMARY KEY, val CHAR(50))",
69		`INSERT INTO foo VALUES (1, "one")`,
70		`INSERT INTO foo VALUES (2, "two")`,
71	)
72	db.SetMaxIdleConns(concurrencyLevel)
73	defer db.Close()
74
75	stmt := tb.checkStmt(db.Prepare("SELECT val FROM foo WHERE id=?"))
76	defer stmt.Close()
77
78	remain := int64(b.N)
79	var wg sync.WaitGroup
80	wg.Add(concurrencyLevel)
81	defer wg.Wait()
82	b.StartTimer()
83
84	for i := 0; i < concurrencyLevel; i++ {
85		go func() {
86			for {
87				if atomic.AddInt64(&remain, -1) < 0 {
88					wg.Done()
89					return
90				}
91
92				var got string
93				tb.check(stmt.QueryRow(1).Scan(&got))
94				if got != "one" {
95					b.Errorf("query = %q; want one", got)
96					wg.Done()
97					return
98				}
99			}
100		}()
101	}
102}
103
104func BenchmarkExec(b *testing.B) {
105	tb := (*TB)(b)
106	b.StopTimer()
107	b.ReportAllocs()
108	db := tb.checkDB(sql.Open("mysql", dsn))
109	db.SetMaxIdleConns(concurrencyLevel)
110	defer db.Close()
111
112	stmt := tb.checkStmt(db.Prepare("DO 1"))
113	defer stmt.Close()
114
115	remain := int64(b.N)
116	var wg sync.WaitGroup
117	wg.Add(concurrencyLevel)
118	defer wg.Wait()
119	b.StartTimer()
120
121	for i := 0; i < concurrencyLevel; i++ {
122		go func() {
123			for {
124				if atomic.AddInt64(&remain, -1) < 0 {
125					wg.Done()
126					return
127				}
128
129				if _, err := stmt.Exec(); err != nil {
130					b.Fatal(err.Error())
131				}
132			}
133		}()
134	}
135}
136
// Roundtrip benchmarks: data is sent to the server and read back, but nothing
// is written to the database.
//
// roundtripSample caches the payload used by those benchmarks; it is built on
// the first call to initRoundtripBenchmarks.
var roundtripSample []byte

// initRoundtripBenchmarks returns the shared sample payload together with the
// smallest (16) and largest (the full payload length) slice lengths the
// roundtrip benchmarks use. The payload is 1024*1024 repetitions of the
// 16-byte pattern "0123456789abcdef" and is created only once.
func initRoundtripBenchmarks() ([]byte, int, int) {
	const (
		pattern = "0123456789abcdef"
		repeats = 1024 * 1024
	)
	if roundtripSample == nil {
		roundtripSample = []byte(strings.Repeat(pattern, repeats))
	}
	return roundtripSample, len(pattern), len(roundtripSample)
}
146
147func BenchmarkRoundtripTxt(b *testing.B) {
148	b.StopTimer()
149	sample, min, max := initRoundtripBenchmarks()
150	sampleString := string(sample)
151	b.ReportAllocs()
152	tb := (*TB)(b)
153	db := tb.checkDB(sql.Open("mysql", dsn))
154	defer db.Close()
155	b.StartTimer()
156	var result string
157	for i := 0; i < b.N; i++ {
158		length := min + i
159		if length > max {
160			length = max
161		}
162		test := sampleString[0:length]
163		rows := tb.checkRows(db.Query(`SELECT "` + test + `"`))
164		if !rows.Next() {
165			rows.Close()
166			b.Fatalf("crashed")
167		}
168		err := rows.Scan(&result)
169		if err != nil {
170			rows.Close()
171			b.Fatalf("crashed")
172		}
173		if result != test {
174			rows.Close()
175			b.Errorf("mismatch")
176		}
177		rows.Close()
178	}
179}
180
181func BenchmarkRoundtripBin(b *testing.B) {
182	b.StopTimer()
183	sample, min, max := initRoundtripBenchmarks()
184	b.ReportAllocs()
185	tb := (*TB)(b)
186	db := tb.checkDB(sql.Open("mysql", dsn))
187	defer db.Close()
188	stmt := tb.checkStmt(db.Prepare("SELECT ?"))
189	defer stmt.Close()
190	b.StartTimer()
191	var result sql.RawBytes
192	for i := 0; i < b.N; i++ {
193		length := min + i
194		if length > max {
195			length = max
196		}
197		test := sample[0:length]
198		rows := tb.checkRows(stmt.Query(test))
199		if !rows.Next() {
200			rows.Close()
201			b.Fatalf("crashed")
202		}
203		err := rows.Scan(&result)
204		if err != nil {
205			rows.Close()
206			b.Fatalf("crashed")
207		}
208		if !bytes.Equal(result, test) {
209			rows.Close()
210			b.Errorf("mismatch")
211		}
212		rows.Close()
213	}
214}
215
216func BenchmarkInterpolation(b *testing.B) {
217	mc := &mysqlConn{
218		cfg: &Config{
219			InterpolateParams: true,
220			Loc:               time.UTC,
221		},
222		maxAllowedPacket: maxPacketSize,
223		maxWriteSize:     maxPacketSize - 1,
224		buf:              newBuffer(nil),
225	}
226
227	args := []driver.Value{
228		int64(42424242),
229		float64(math.Pi),
230		false,
231		time.Unix(1423411542, 807015000),
232		[]byte("bytes containing special chars ' \" \a \x00"),
233		"string containing special chars ' \" \a \x00",
234	}
235	q := "SELECT ?, ?, ?, ?, ?, ?"
236
237	b.ReportAllocs()
238	b.ResetTimer()
239	for i := 0; i < b.N; i++ {
240		_, err := mc.interpolateParams(q, args)
241		if err != nil {
242			b.Fatal(err)
243		}
244	}
245}
246
247func benchmarkQueryContext(b *testing.B, db *sql.DB, p int) {
248	ctx, cancel := context.WithCancel(context.Background())
249	defer cancel()
250	db.SetMaxIdleConns(p * runtime.GOMAXPROCS(0))
251
252	tb := (*TB)(b)
253	stmt := tb.checkStmt(db.PrepareContext(ctx, "SELECT val FROM foo WHERE id=?"))
254	defer stmt.Close()
255
256	b.SetParallelism(p)
257	b.ReportAllocs()
258	b.ResetTimer()
259	b.RunParallel(func(pb *testing.PB) {
260		var got string
261		for pb.Next() {
262			tb.check(stmt.QueryRow(1).Scan(&got))
263			if got != "one" {
264				b.Fatalf("query = %q; want one", got)
265			}
266		}
267	})
268}
269
270func BenchmarkQueryContext(b *testing.B) {
271	db := initDB(b,
272		"DROP TABLE IF EXISTS foo",
273		"CREATE TABLE foo (id INT PRIMARY KEY, val CHAR(50))",
274		`INSERT INTO foo VALUES (1, "one")`,
275		`INSERT INTO foo VALUES (2, "two")`,
276	)
277	defer db.Close()
278	for _, p := range []int{1, 2, 3, 4} {
279		b.Run(fmt.Sprintf("%d", p), func(b *testing.B) {
280			benchmarkQueryContext(b, db, p)
281		})
282	}
283}
284
285func benchmarkExecContext(b *testing.B, db *sql.DB, p int) {
286	ctx, cancel := context.WithCancel(context.Background())
287	defer cancel()
288	db.SetMaxIdleConns(p * runtime.GOMAXPROCS(0))
289
290	tb := (*TB)(b)
291	stmt := tb.checkStmt(db.PrepareContext(ctx, "DO 1"))
292	defer stmt.Close()
293
294	b.SetParallelism(p)
295	b.ReportAllocs()
296	b.ResetTimer()
297	b.RunParallel(func(pb *testing.PB) {
298		for pb.Next() {
299			if _, err := stmt.ExecContext(ctx); err != nil {
300				b.Fatal(err)
301			}
302		}
303	})
304}
305
306func BenchmarkExecContext(b *testing.B) {
307	db := initDB(b,
308		"DROP TABLE IF EXISTS foo",
309		"CREATE TABLE foo (id INT PRIMARY KEY, val CHAR(50))",
310		`INSERT INTO foo VALUES (1, "one")`,
311		`INSERT INTO foo VALUES (2, "two")`,
312	)
313	defer db.Close()
314	for _, p := range []int{1, 2, 3, 4} {
315		b.Run(fmt.Sprintf("%d", p), func(b *testing.B) {
316			benchmarkQueryContext(b, db, p)
317		})
318	}
319}
320
321// BenchmarkQueryRawBytes benchmarks fetching 100 blobs using sql.RawBytes.
322// "size=" means size of each blobs.
323func BenchmarkQueryRawBytes(b *testing.B) {
324	var sizes []int = []int{100, 1000, 2000, 4000, 8000, 12000, 16000, 32000, 64000, 256000}
325	db := initDB(b,
326		"DROP TABLE IF EXISTS bench_rawbytes",
327		"CREATE TABLE bench_rawbytes (id INT PRIMARY KEY, val LONGBLOB)",
328	)
329	defer db.Close()
330
331	blob := make([]byte, sizes[len(sizes)-1])
332	for i := range blob {
333		blob[i] = 42
334	}
335	for i := 0; i < 100; i++ {
336		_, err := db.Exec("INSERT INTO bench_rawbytes VALUES (?, ?)", i, blob)
337		if err != nil {
338			b.Fatal(err)
339		}
340	}
341
342	for _, s := range sizes {
343		b.Run(fmt.Sprintf("size=%v", s), func(b *testing.B) {
344			db.SetMaxIdleConns(0)
345			db.SetMaxIdleConns(1)
346			b.ReportAllocs()
347			b.ResetTimer()
348
349			for j := 0; j < b.N; j++ {
350				rows, err := db.Query("SELECT LEFT(val, ?) as v FROM bench_rawbytes", s)
351				if err != nil {
352					b.Fatal(err)
353				}
354				nrows := 0
355				for rows.Next() {
356					var buf sql.RawBytes
357					err := rows.Scan(&buf)
358					if err != nil {
359						b.Fatal(err)
360					}
361					if len(buf) != s {
362						b.Fatalf("size mismatch: expected %v, got %v", s, len(buf))
363					}
364					nrows++
365				}
366				rows.Close()
367				if nrows != 100 {
368					b.Fatalf("numbers of rows mismatch: expected %v, got %v", 100, nrows)
369				}
370			}
371		})
372	}
373}
374