1// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
2//
3// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
4//
5// This Source Code Form is subject to the terms of the Mozilla Public
6// License, v. 2.0. If a copy of the MPL was not distributed with this file,
7// You can obtain one at http://mozilla.org/MPL/2.0/.
8
9package mysql
10
11import (
12	"crypto/tls"
13	"database/sql"
14	"database/sql/driver"
15	"encoding/binary"
16	"errors"
17	"fmt"
18	"io"
19	"strconv"
20	"strings"
21	"sync"
22	"sync/atomic"
23	"time"
24)
25
// Registry for custom tls.Configs.
//
// tlsConfigRegistry maps each key registered through RegisterTLSConfig to its
// *tls.Config. It stays nil until the first registration allocates it.
var tlsConfigRegistry map[string]*tls.Config

// tlsConfigLock guards every read and write of tlsConfigRegistry.
var tlsConfigLock sync.RWMutex
31
32// RegisterTLSConfig registers a custom tls.Config to be used with sql.Open.
33// Use the key as a value in the DSN where tls=value.
34//
35// Note: The provided tls.Config is exclusively owned by the driver after
36// registering it.
37//
38//  rootCertPool := x509.NewCertPool()
39//  pem, err := ioutil.ReadFile("/path/ca-cert.pem")
40//  if err != nil {
41//      log.Fatal(err)
42//  }
43//  if ok := rootCertPool.AppendCertsFromPEM(pem); !ok {
44//      log.Fatal("Failed to append PEM.")
45//  }
46//  clientCert := make([]tls.Certificate, 0, 1)
47//  certs, err := tls.LoadX509KeyPair("/path/client-cert.pem", "/path/client-key.pem")
48//  if err != nil {
49//      log.Fatal(err)
50//  }
51//  clientCert = append(clientCert, certs)
52//  mysql.RegisterTLSConfig("custom", &tls.Config{
53//      RootCAs: rootCertPool,
54//      Certificates: clientCert,
55//  })
56//  db, err := sql.Open("mysql", "user@tcp(localhost:3306)/test?tls=custom")
57//
58func RegisterTLSConfig(key string, config *tls.Config) error {
59	if _, isBool := readBool(key); isBool || strings.ToLower(key) == "skip-verify" || strings.ToLower(key) == "preferred" {
60		return fmt.Errorf("key '%s' is reserved", key)
61	}
62
63	tlsConfigLock.Lock()
64	if tlsConfigRegistry == nil {
65		tlsConfigRegistry = make(map[string]*tls.Config)
66	}
67
68	tlsConfigRegistry[key] = config
69	tlsConfigLock.Unlock()
70	return nil
71}
72
73// DeregisterTLSConfig removes the tls.Config associated with key.
74func DeregisterTLSConfig(key string) {
75	tlsConfigLock.Lock()
76	if tlsConfigRegistry != nil {
77		delete(tlsConfigRegistry, key)
78	}
79	tlsConfigLock.Unlock()
80}
81
82func getTLSConfigClone(key string) (config *tls.Config) {
83	tlsConfigLock.RLock()
84	if v, ok := tlsConfigRegistry[key]; ok {
85		config = v.Clone()
86	}
87	tlsConfigLock.RUnlock()
88	return
89}
90
// readBool interprets input as a boolean DSN value.
// It returns the parsed value and whether input was one of the accepted
// spellings: "1", "true", "TRUE", "True" (true) or "0", "false", "FALSE",
// "False" (false). Any other input yields (false, false).
func readBool(input string) (value bool, valid bool) {
	switch input {
	case "0", "false", "FALSE", "False":
		return false, true
	case "1", "true", "TRUE", "True":
		return true, true
	default:
		return false, false
	}
}
104
105/******************************************************************************
106*                           Time related utils                                *
107******************************************************************************/
108
109func parseDateTime(b []byte, loc *time.Location) (time.Time, error) {
110	const base = "0000-00-00 00:00:00.000000"
111	switch len(b) {
112	case 10, 19, 21, 22, 23, 24, 25, 26: // up to "YYYY-MM-DD HH:MM:SS.MMMMMM"
113		if string(b) == base[:len(b)] {
114			return time.Time{}, nil
115		}
116
117		year, err := parseByteYear(b)
118		if err != nil {
119			return time.Time{}, err
120		}
121		if year <= 0 {
122			year = 1
123		}
124
125		if b[4] != '-' {
126			return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[4])
127		}
128
129		m, err := parseByte2Digits(b[5], b[6])
130		if err != nil {
131			return time.Time{}, err
132		}
133		if m <= 0 {
134			m = 1
135		}
136		month := time.Month(m)
137
138		if b[7] != '-' {
139			return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[7])
140		}
141
142		day, err := parseByte2Digits(b[8], b[9])
143		if err != nil {
144			return time.Time{}, err
145		}
146		if day <= 0 {
147			day = 1
148		}
149		if len(b) == 10 {
150			return time.Date(year, month, day, 0, 0, 0, 0, loc), nil
151		}
152
153		if b[10] != ' ' {
154			return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[10])
155		}
156
157		hour, err := parseByte2Digits(b[11], b[12])
158		if err != nil {
159			return time.Time{}, err
160		}
161		if b[13] != ':' {
162			return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[13])
163		}
164
165		min, err := parseByte2Digits(b[14], b[15])
166		if err != nil {
167			return time.Time{}, err
168		}
169		if b[16] != ':' {
170			return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[16])
171		}
172
173		sec, err := parseByte2Digits(b[17], b[18])
174		if err != nil {
175			return time.Time{}, err
176		}
177		if len(b) == 19 {
178			return time.Date(year, month, day, hour, min, sec, 0, loc), nil
179		}
180
181		if b[19] != '.' {
182			return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[19])
183		}
184		nsec, err := parseByteNanoSec(b[20:])
185		if err != nil {
186			return time.Time{}, err
187		}
188		return time.Date(year, month, day, hour, min, sec, nsec, loc), nil
189	default:
190		return time.Time{}, fmt.Errorf("invalid time bytes: %s", b)
191	}
192}
193
194func parseByteYear(b []byte) (int, error) {
195	year, n := 0, 1000
196	for i := 0; i < 4; i++ {
197		v, err := bToi(b[i])
198		if err != nil {
199			return 0, err
200		}
201		year += v * n
202		n = n / 10
203	}
204	return year, nil
205}
206
207func parseByte2Digits(b1, b2 byte) (int, error) {
208	d1, err := bToi(b1)
209	if err != nil {
210		return 0, err
211	}
212	d2, err := bToi(b2)
213	if err != nil {
214		return 0, err
215	}
216	return d1*10 + d2, nil
217}
218
219func parseByteNanoSec(b []byte) (int, error) {
220	ns, digit := 0, 100000 // max is 6-digits
221	for i := 0; i < len(b); i++ {
222		v, err := bToi(b[i])
223		if err != nil {
224			return 0, err
225		}
226		ns += v * digit
227		digit /= 10
228	}
229	// nanoseconds has 10-digits. (needs to scale digits)
230	// 10 - 6 = 4, so we have to multiple 1000.
231	return ns * 1000, nil
232}
233
// bToi converts a single ASCII decimal digit into its integer value.
// Any byte outside '0'..'9' produces an error.
func bToi(b byte) (int, error) {
	if d := int(b) - '0'; d >= 0 && d <= 9 {
		return d, nil
	}
	return 0, errors.New("not [0-9]")
}
240
// parseBinaryDateTime decodes a binary-protocol temporal value into a
// time.Time in loc. num is the encoded length announced for the value and
// data holds its bytes:
//
//	 0 bytes: zero time.Time
//	 4 bytes: year (little-endian uint16), month, day
//	 7 bytes: as above, plus hour, minute, second
//	11 bytes: as above, plus microseconds (little-endian uint32)
//
// Any other num yields an error. So does data shorter than num: a truncated
// packet is reported instead of panicking or reading past len(data).
func parseBinaryDateTime(num uint64, data []byte, loc *time.Location) (driver.Value, error) {
	switch num {
	case 0:
		return time.Time{}, nil
	case 4, 7, 11:
	default:
		return nil, fmt.Errorf("invalid DATETIME packet length %d", num)
	}
	if uint64(len(data)) < num {
		return nil, fmt.Errorf("invalid DATETIME packet: need %d bytes, have %d", num, len(data))
	}

	year := int(binary.LittleEndian.Uint16(data[:2]))
	month := time.Month(data[2])
	day := int(data[3])

	var hour, minute, second, nsec int
	if num >= 7 {
		hour, minute, second = int(data[4]), int(data[5]), int(data[6])
	}
	if num == 11 {
		nsec = int(binary.LittleEndian.Uint32(data[7:11])) * 1000 // microseconds -> nanoseconds
	}
	return time.Date(year, month, day, hour, minute, second, nsec, loc), nil
}
278
279func appendDateTime(buf []byte, t time.Time) ([]byte, error) {
280	year, month, day := t.Date()
281	hour, min, sec := t.Clock()
282	nsec := t.Nanosecond()
283
284	if year < 1 || year > 9999 {
285		return buf, errors.New("year is not in the range [1, 9999]: " + strconv.Itoa(year)) // use errors.New instead of fmt.Errorf to avoid year escape to heap
286	}
287	year100 := year / 100
288	year1 := year % 100
289
290	var localBuf [len("2006-01-02T15:04:05.999999999")]byte // does not escape
291	localBuf[0], localBuf[1], localBuf[2], localBuf[3] = digits10[year100], digits01[year100], digits10[year1], digits01[year1]
292	localBuf[4] = '-'
293	localBuf[5], localBuf[6] = digits10[month], digits01[month]
294	localBuf[7] = '-'
295	localBuf[8], localBuf[9] = digits10[day], digits01[day]
296
297	if hour == 0 && min == 0 && sec == 0 && nsec == 0 {
298		return append(buf, localBuf[:10]...), nil
299	}
300
301	localBuf[10] = ' '
302	localBuf[11], localBuf[12] = digits10[hour], digits01[hour]
303	localBuf[13] = ':'
304	localBuf[14], localBuf[15] = digits10[min], digits01[min]
305	localBuf[16] = ':'
306	localBuf[17], localBuf[18] = digits10[sec], digits01[sec]
307
308	if nsec == 0 {
309		return append(buf, localBuf[:19]...), nil
310	}
311	nsec100000000 := nsec / 100000000
312	nsec1000000 := (nsec / 1000000) % 100
313	nsec10000 := (nsec / 10000) % 100
314	nsec100 := (nsec / 100) % 100
315	nsec1 := nsec % 100
316	localBuf[19] = '.'
317
318	// milli second
319	localBuf[20], localBuf[21], localBuf[22] =
320		digits01[nsec100000000], digits10[nsec1000000], digits01[nsec1000000]
321	// micro second
322	localBuf[23], localBuf[24], localBuf[25] =
323		digits10[nsec10000], digits01[nsec10000], digits10[nsec100]
324	// nano second
325	localBuf[26], localBuf[27], localBuf[28] =
326		digits01[nsec100], digits10[nsec1], digits01[nsec1]
327
328	// trim trailing zeros
329	n := len(localBuf)
330	for n > 0 && localBuf[n-1] == '0' {
331		n--
332	}
333
334	return append(buf, localBuf[:n]...), nil
335}
336
// zeroDateTime is the textual zero value of a DATETIME with six fractional
// digits. formatBinaryDateTime and formatBinaryTime return sub-slices of it
// for empty (zero) binary values, which avoids an allocation.
// It must never be changed: those sub-slices share its backing array.
// The current behavior depends on database/sql copying the result.
var zeroDateTime = []byte("0000-00-00 00:00:00.000000")

// Two-digit lookup tables: for 0 <= v <= 99, digits10[v] is the ASCII tens
// digit of v and digits01[v] is its ones digit.
const (
	digits01 = "0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789"
	digits10 = "0000000000111111111122222222223333333333444444444455555555556666666666777777777788888888889999999999"
)
345
346func appendMicrosecs(dst, src []byte, decimals int) []byte {
347	if decimals <= 0 {
348		return dst
349	}
350	if len(src) == 0 {
351		return append(dst, ".000000"[:decimals+1]...)
352	}
353
354	microsecs := binary.LittleEndian.Uint32(src[:4])
355	p1 := byte(microsecs / 10000)
356	microsecs -= 10000 * uint32(p1)
357	p2 := byte(microsecs / 100)
358	microsecs -= 100 * uint32(p2)
359	p3 := byte(microsecs)
360
361	switch decimals {
362	default:
363		return append(dst, '.',
364			digits10[p1], digits01[p1],
365			digits10[p2], digits01[p2],
366			digits10[p3], digits01[p3],
367		)
368	case 1:
369		return append(dst, '.',
370			digits10[p1],
371		)
372	case 2:
373		return append(dst, '.',
374			digits10[p1], digits01[p1],
375		)
376	case 3:
377		return append(dst, '.',
378			digits10[p1], digits01[p1],
379			digits10[p2],
380		)
381	case 4:
382		return append(dst, '.',
383			digits10[p1], digits01[p1],
384			digits10[p2], digits01[p2],
385		)
386	case 5:
387		return append(dst, '.',
388			digits10[p1], digits01[p1],
389			digits10[p2], digits01[p2],
390			digits10[p3],
391		)
392	}
393}
394
395func formatBinaryDateTime(src []byte, length uint8) (driver.Value, error) {
396	// length expects the deterministic length of the zero value,
397	// negative time and 100+ hours are automatically added if needed
398	if len(src) == 0 {
399		return zeroDateTime[:length], nil
400	}
401	var dst []byte      // return value
402	var p1, p2, p3 byte // current digit pair
403
404	switch length {
405	case 10, 19, 21, 22, 23, 24, 25, 26:
406	default:
407		t := "DATE"
408		if length > 10 {
409			t += "TIME"
410		}
411		return nil, fmt.Errorf("illegal %s length %d", t, length)
412	}
413	switch len(src) {
414	case 4, 7, 11:
415	default:
416		t := "DATE"
417		if length > 10 {
418			t += "TIME"
419		}
420		return nil, fmt.Errorf("illegal %s packet length %d", t, len(src))
421	}
422	dst = make([]byte, 0, length)
423	// start with the date
424	year := binary.LittleEndian.Uint16(src[:2])
425	pt := year / 100
426	p1 = byte(year - 100*uint16(pt))
427	p2, p3 = src[2], src[3]
428	dst = append(dst,
429		digits10[pt], digits01[pt],
430		digits10[p1], digits01[p1], '-',
431		digits10[p2], digits01[p2], '-',
432		digits10[p3], digits01[p3],
433	)
434	if length == 10 {
435		return dst, nil
436	}
437	if len(src) == 4 {
438		return append(dst, zeroDateTime[10:length]...), nil
439	}
440	dst = append(dst, ' ')
441	p1 = src[4] // hour
442	src = src[5:]
443
444	// p1 is 2-digit hour, src is after hour
445	p2, p3 = src[0], src[1]
446	dst = append(dst,
447		digits10[p1], digits01[p1], ':',
448		digits10[p2], digits01[p2], ':',
449		digits10[p3], digits01[p3],
450	)
451	return appendMicrosecs(dst, src[2:], int(length)-20), nil
452}
453
454func formatBinaryTime(src []byte, length uint8) (driver.Value, error) {
455	// length expects the deterministic length of the zero value,
456	// negative time and 100+ hours are automatically added if needed
457	if len(src) == 0 {
458		return zeroDateTime[11 : 11+length], nil
459	}
460	var dst []byte // return value
461
462	switch length {
463	case
464		8,                      // time (can be up to 10 when negative and 100+ hours)
465		10, 11, 12, 13, 14, 15: // time with fractional seconds
466	default:
467		return nil, fmt.Errorf("illegal TIME length %d", length)
468	}
469	switch len(src) {
470	case 8, 12:
471	default:
472		return nil, fmt.Errorf("invalid TIME packet length %d", len(src))
473	}
474	// +2 to enable negative time and 100+ hours
475	dst = make([]byte, 0, length+2)
476	if src[0] == 1 {
477		dst = append(dst, '-')
478	}
479	days := binary.LittleEndian.Uint32(src[1:5])
480	hours := int64(days)*24 + int64(src[5])
481
482	if hours >= 100 {
483		dst = strconv.AppendInt(dst, hours, 10)
484	} else {
485		dst = append(dst, digits10[hours], digits01[hours])
486	}
487
488	min, sec := src[6], src[7]
489	dst = append(dst, ':',
490		digits10[min], digits01[min], ':',
491		digits10[sec], digits01[sec],
492	)
493	return appendMicrosecs(dst, src[8:], int(length)-9), nil
494}
495
496/******************************************************************************
497*                       Convert from and to bytes                             *
498******************************************************************************/
499
// uint64ToBytes encodes n as eight bytes in little-endian order (least
// significant byte first) and returns them in a newly allocated slice.
func uint64ToBytes(n uint64) []byte {
	out := make([]byte, 8)
	binary.LittleEndian.PutUint64(out, n)
	return out
}
512
// uint64ToString returns the decimal ASCII representation of n as a newly
// allocated byte slice, e.g. 0 -> "0", 18446744073709551615 ->
// "18446744073709551615".
func uint64ToString(n uint64) []byte {
	// A uint64 has at most 20 decimal digits. Sizing the buffer up front
	// keeps this to a single allocation, like the former hand-rolled loop.
	return strconv.AppendUint(make([]byte, 0, 20), n, 10)
}
534
// stringToInt interprets b as the ASCII decimal digits of a non-negative
// integer, most significant digit first. It performs no validation: bytes
// outside '0'..'9' and overflow give meaningless results. An empty b yields 0.
func stringToInt(b []byte) int {
	n := 0
	for _, c := range b {
		n = n*10 + int(c-'0')
	}
	return n
}
544
545// returns the string read as a bytes slice, wheter the value is NULL,
546// the number of bytes read and an error, in case the string is longer than
547// the input slice
548func readLengthEncodedString(b []byte) ([]byte, bool, int, error) {
549	// Get length
550	num, isNull, n := readLengthEncodedInteger(b)
551	if num < 1 {
552		return b[n:n], isNull, n, nil
553	}
554
555	n += int(num)
556
557	// Check data length
558	if len(b) >= n {
559		return b[n-int(num) : n : n], false, n, nil
560	}
561	return nil, false, n, io.EOF
562}
563
564// returns the number of bytes skipped and an error, in case the string is
565// longer than the input slice
566func skipLengthEncodedString(b []byte) (int, error) {
567	// Get length
568	num, _, n := readLengthEncodedInteger(b)
569	if num < 1 {
570		return n, nil
571	}
572
573	n += int(num)
574
575	// Check data length
576	if len(b) >= n {
577		return n, nil
578	}
579	return n, io.EOF
580}
581
// readLengthEncodedInteger decodes a MySQL length-encoded integer from the
// start of b. It returns the value, whether it encodes NULL, and the number of
// bytes it occupies, chosen by the first byte:
//
//	0xfb:       NULL (1 byte)
//	0xfc:       following 2 bytes, little-endian (3 bytes total)
//	0xfd:       following 3 bytes, little-endian (4 bytes total)
//	0xfe:       following 8 bytes, little-endian (9 bytes total)
//	any other:  the first byte itself (1 byte)
//
// An empty b is reported as NULL occupying 1 byte (see issue #349).
func readLengthEncodedInteger(b []byte) (uint64, bool, int) {
	if len(b) == 0 {
		return 0, true, 1
	}

	var width int
	switch first := b[0]; first {
	case 0xfb:
		return 0, true, 1
	case 0xfc:
		width = 2
	case 0xfd:
		width = 3
	case 0xfe:
		width = 8
	default:
		return uint64(first), false, 1
	}

	// Fold the little-endian payload b[1..width], most significant byte first.
	var v uint64
	for i := width; i > 0; i-- {
		v = v<<8 | uint64(b[i])
	}
	return v, false, width + 1
}
613
// appendLengthEncodedInteger appends n to b as a MySQL length-encoded integer.
// Values up to 250 take a single byte. Larger values take a marker byte (0xfc,
// 0xfd or 0xfe) followed by the value in 2, 3 or 8 little-endian bytes.
func appendLengthEncodedInteger(b []byte, n uint64) []byte {
	switch {
	case n > 0xffffff:
		return append(b, 0xfe,
			byte(n), byte(n>>8), byte(n>>16), byte(n>>24),
			byte(n>>32), byte(n>>40), byte(n>>48), byte(n>>56))
	case n > 0xffff:
		return append(b, 0xfd, byte(n), byte(n>>8), byte(n>>16))
	case n > 250:
		return append(b, 0xfc, byte(n), byte(n>>8))
	default:
		return append(b, byte(n))
	}
}
629
// reserveBuffer returns buf extended by appendSize bytes, i.e. with length
// len(buf)+appendSize, preserving its current contents. If cap(buf) cannot
// hold that, a new backing array of len(buf)*2+appendSize bytes is allocated
// so that repeated calls grow the buffer geometrically.
func reserveBuffer(buf []byte, appendSize int) []byte {
	want := len(buf) + appendSize
	if want <= cap(buf) {
		return buf[:want]
	}
	grown := make([]byte, len(buf)*2+appendSize)
	copy(grown, buf)
	return grown[:want]
}
642
643// escapeBytesBackslash escapes []byte with backslashes (\)
644// This escapes the contents of a string (provided as []byte) by adding backslashes before special
645// characters, and turning others into specific escape sequences, such as
646// turning newlines into \n and null bytes into \0.
647// https://github.com/mysql/mysql-server/blob/mysql-5.7.5/mysys/charset.c#L823-L932
648func escapeBytesBackslash(buf, v []byte) []byte {
649	pos := len(buf)
650	buf = reserveBuffer(buf, len(v)*2)
651
652	for _, c := range v {
653		switch c {
654		case '\x00':
655			buf[pos] = '\\'
656			buf[pos+1] = '0'
657			pos += 2
658		case '\n':
659			buf[pos] = '\\'
660			buf[pos+1] = 'n'
661			pos += 2
662		case '\r':
663			buf[pos] = '\\'
664			buf[pos+1] = 'r'
665			pos += 2
666		case '\x1a':
667			buf[pos] = '\\'
668			buf[pos+1] = 'Z'
669			pos += 2
670		case '\'':
671			buf[pos] = '\\'
672			buf[pos+1] = '\''
673			pos += 2
674		case '"':
675			buf[pos] = '\\'
676			buf[pos+1] = '"'
677			pos += 2
678		case '\\':
679			buf[pos] = '\\'
680			buf[pos+1] = '\\'
681			pos += 2
682		default:
683			buf[pos] = c
684			pos++
685		}
686	}
687
688	return buf[:pos]
689}
690
691// escapeStringBackslash is similar to escapeBytesBackslash but for string.
692func escapeStringBackslash(buf []byte, v string) []byte {
693	pos := len(buf)
694	buf = reserveBuffer(buf, len(v)*2)
695
696	for i := 0; i < len(v); i++ {
697		c := v[i]
698		switch c {
699		case '\x00':
700			buf[pos] = '\\'
701			buf[pos+1] = '0'
702			pos += 2
703		case '\n':
704			buf[pos] = '\\'
705			buf[pos+1] = 'n'
706			pos += 2
707		case '\r':
708			buf[pos] = '\\'
709			buf[pos+1] = 'r'
710			pos += 2
711		case '\x1a':
712			buf[pos] = '\\'
713			buf[pos+1] = 'Z'
714			pos += 2
715		case '\'':
716			buf[pos] = '\\'
717			buf[pos+1] = '\''
718			pos += 2
719		case '"':
720			buf[pos] = '\\'
721			buf[pos+1] = '"'
722			pos += 2
723		case '\\':
724			buf[pos] = '\\'
725			buf[pos+1] = '\\'
726			pos += 2
727		default:
728			buf[pos] = c
729			pos++
730		}
731	}
732
733	return buf[:pos]
734}
735
736// escapeBytesQuotes escapes apostrophes in []byte by doubling them up.
737// This escapes the contents of a string by doubling up any apostrophes that
738// it contains. This is used when the NO_BACKSLASH_ESCAPES SQL_MODE is in
739// effect on the server.
740// https://github.com/mysql/mysql-server/blob/mysql-5.7.5/mysys/charset.c#L963-L1038
741func escapeBytesQuotes(buf, v []byte) []byte {
742	pos := len(buf)
743	buf = reserveBuffer(buf, len(v)*2)
744
745	for _, c := range v {
746		if c == '\'' {
747			buf[pos] = '\''
748			buf[pos+1] = '\''
749			pos += 2
750		} else {
751			buf[pos] = c
752			pos++
753		}
754	}
755
756	return buf[:pos]
757}
758
759// escapeStringQuotes is similar to escapeBytesQuotes but for string.
760func escapeStringQuotes(buf []byte, v string) []byte {
761	pos := len(buf)
762	buf = reserveBuffer(buf, len(v)*2)
763
764	for i := 0; i < len(v); i++ {
765		c := v[i]
766		if c == '\'' {
767			buf[pos] = '\''
768			buf[pos+1] = '\''
769			pos += 2
770		} else {
771			buf[pos] = c
772			pos++
773		}
774	}
775
776	return buf[:pos]
777}
778
779/******************************************************************************
780*                               Sync utils                                    *
781******************************************************************************/
782
// noCopy may be added to structs which must not be copied after the first
// use. Add it as a named field (e.g. `_noCopy noCopy`) rather than embedding
// it, so that its Lock and Unlock methods are not promoted to the outer type.
//
// See https://github.com/golang/go/issues/8005#issuecomment-190753527
// for details.
//
// The -copylocks checker of `go vet` only reports copies of types whose
// pointer implements sync.Locker. Both Lock and Unlock are therefore
// required for this marker to have any effect.
type noCopy struct{}

// Lock is a no-op used by -copylocks checker from `go vet`.
func (*noCopy) Lock() {}

// Unlock is a no-op used by -copylocks checker from `go vet`.
func (*noCopy) Unlock() {}
792
793// atomicBool is a wrapper around uint32 for usage as a boolean value with
794// atomic access.
795type atomicBool struct {
796	_noCopy noCopy
797	value   uint32
798}
799
800// IsSet returns whether the current boolean value is true
801func (ab *atomicBool) IsSet() bool {
802	return atomic.LoadUint32(&ab.value) > 0
803}
804
805// Set sets the value of the bool regardless of the previous value
806func (ab *atomicBool) Set(value bool) {
807	if value {
808		atomic.StoreUint32(&ab.value, 1)
809	} else {
810		atomic.StoreUint32(&ab.value, 0)
811	}
812}
813
814// TrySet sets the value of the bool and returns whether the value changed
815func (ab *atomicBool) TrySet(value bool) bool {
816	if value {
817		return atomic.SwapUint32(&ab.value, 1) == 0
818	}
819	return atomic.SwapUint32(&ab.value, 0) > 0
820}
821
822// atomicError is a wrapper for atomically accessed error values
823type atomicError struct {
824	_noCopy noCopy
825	value   atomic.Value
826}
827
828// Set sets the error value regardless of the previous value.
829// The value must not be nil
830func (ae *atomicError) Set(value error) {
831	ae.value.Store(value)
832}
833
834// Value returns the current error value
835func (ae *atomicError) Value() error {
836	if v := ae.value.Load(); v != nil {
837		// this will panic if the value doesn't implement the error interface
838		return v.(error)
839	}
840	return nil
841}
842
// namedValueToValue converts driver arguments to the positional
// []driver.Value form, keeping their order. Named parameters are not
// supported: any argument with a non-empty Name produces an error.
func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) {
	values := make([]driver.Value, 0, len(named))
	for i := range named {
		if named[i].Name != "" {
			// TODO: support the use of Named Parameters #561
			return nil, errors.New("mysql: driver does not support the use of Named Parameters")
		}
		values = append(values, named[i].Value)
	}
	return values, nil
}
854
// mapIsolationLevel converts a database/sql isolation level to its SQL
// spelling, e.g. "READ COMMITTED". Only read uncommitted, read committed,
// repeatable read and serializable are supported. Every other level,
// including sql.LevelDefault, yields an error.
func mapIsolationLevel(level driver.IsolationLevel) (string, error) {
	var name string
	switch sql.IsolationLevel(level) {
	case sql.LevelReadUncommitted:
		name = "READ UNCOMMITTED"
	case sql.LevelReadCommitted:
		name = "READ COMMITTED"
	case sql.LevelRepeatableRead:
		name = "REPEATABLE READ"
	case sql.LevelSerializable:
		name = "SERIALIZABLE"
	default:
		return "", fmt.Errorf("mysql: unsupported isolation level: %v", level)
	}
	return name, nil
}
869