1// Go MySQL Driver - A MySQL-Driver for Go's database/sql package 2// 3// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. 4// 5// This Source Code Form is subject to the terms of the Mozilla Public 6// License, v. 2.0. If a copy of the MPL was not distributed with this file, 7// You can obtain one at http://mozilla.org/MPL/2.0/. 8 9package mysql 10 11import ( 12 "crypto/tls" 13 "database/sql" 14 "database/sql/driver" 15 "encoding/binary" 16 "errors" 17 "fmt" 18 "io" 19 "strconv" 20 "strings" 21 "sync" 22 "sync/atomic" 23 "time" 24) 25 26// Registry for custom tls.Configs 27var ( 28 tlsConfigLock sync.RWMutex 29 tlsConfigRegistry map[string]*tls.Config 30) 31 32// RegisterTLSConfig registers a custom tls.Config to be used with sql.Open. 33// Use the key as a value in the DSN where tls=value. 34// 35// Note: The provided tls.Config is exclusively owned by the driver after 36// registering it. 37// 38// rootCertPool := x509.NewCertPool() 39// pem, err := ioutil.ReadFile("/path/ca-cert.pem") 40// if err != nil { 41// log.Fatal(err) 42// } 43// if ok := rootCertPool.AppendCertsFromPEM(pem); !ok { 44// log.Fatal("Failed to append PEM.") 45// } 46// clientCert := make([]tls.Certificate, 0, 1) 47// certs, err := tls.LoadX509KeyPair("/path/client-cert.pem", "/path/client-key.pem") 48// if err != nil { 49// log.Fatal(err) 50// } 51// clientCert = append(clientCert, certs) 52// mysql.RegisterTLSConfig("custom", &tls.Config{ 53// RootCAs: rootCertPool, 54// Certificates: clientCert, 55// }) 56// db, err := sql.Open("mysql", "user@tcp(localhost:3306)/test?tls=custom") 57// 58func RegisterTLSConfig(key string, config *tls.Config) error { 59 if _, isBool := readBool(key); isBool || strings.ToLower(key) == "skip-verify" || strings.ToLower(key) == "preferred" { 60 return fmt.Errorf("key '%s' is reserved", key) 61 } 62 63 tlsConfigLock.Lock() 64 if tlsConfigRegistry == nil { 65 tlsConfigRegistry = make(map[string]*tls.Config) 66 } 67 68 tlsConfigRegistry[key] = config 69 tlsConfigLock.Unlock() 70 return nil 71} 72 73// DeregisterTLSConfig removes the tls.Config associated with key. 74func DeregisterTLSConfig(key string) { 75 tlsConfigLock.Lock() 76 if tlsConfigRegistry != nil { 77 delete(tlsConfigRegistry, key) 78 } 79 tlsConfigLock.Unlock() 80} 81 82func getTLSConfigClone(key string) (config *tls.Config) { 83 tlsConfigLock.RLock() 84 if v, ok := tlsConfigRegistry[key]; ok { 85 config = v.Clone() 86 } 87 tlsConfigLock.RUnlock() 88 return 89} 90 91// Returns the bool value of the input. 92// The 2nd return value indicates if the input was a valid bool value 93func readBool(input string) (value bool, valid bool) { 94 switch input { 95 case "1", "true", "TRUE", "True": 96 return true, true 97 case "0", "false", "FALSE", "False": 98 return false, true 99 } 100 101 // Not a valid bool value 102 return 103} 104 105/****************************************************************************** 106* Time related utils * 107******************************************************************************/ 108 109func parseDateTime(b []byte, loc *time.Location) (time.Time, error) { 110 const base = "0000-00-00 00:00:00.000000" 111 switch len(b) { 112 case 10, 19, 21, 22, 23, 24, 25, 26: // up to "YYYY-MM-DD HH:MM:SS.MMMMMM" 113 if string(b) == base[:len(b)] { 114 return time.Time{}, nil 115 } 116 117 year, err := parseByteYear(b) 118 if err != nil { 119 return time.Time{}, err 120 } 121 if year <= 0 { 122 year = 1 123 } 124 125 if b[4] != '-' { 126 return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[4]) 127 } 128 129 m, err := parseByte2Digits(b[5], b[6]) 130 if err != nil { 131 return time.Time{}, err 132 } 133 if m <= 0 { 134 m = 1 135 } 136 month := time.Month(m) 137 138 if b[7] != '-' { 139 return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[7]) 140 } 141 142 day, err := parseByte2Digits(b[8], b[9]) 143 if err != nil { 144 return time.Time{}, err 145 } 146 if day <= 0 { 147 day = 1 148 } 149 if len(b) == 10 { 150 return time.Date(year, month, day, 0, 0, 0, 0, loc), nil 151 } 152 153 if b[10] != ' ' { 154 return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[10]) 155 } 156 157 hour, err := parseByte2Digits(b[11], b[12]) 158 if err != nil { 159 return time.Time{}, err 160 } 161 if b[13] != ':' { 162 return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[13]) 163 } 164 165 min, err := parseByte2Digits(b[14], b[15]) 166 if err != nil { 167 return time.Time{}, err 168 } 169 if b[16] != ':' { 170 return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[16]) 171 } 172 173 sec, err := parseByte2Digits(b[17], b[18]) 174 if err != nil { 175 return time.Time{}, err 176 } 177 if len(b) == 19 { 178 return time.Date(year, month, day, hour, min, sec, 0, loc), nil 179 } 180 181 if b[19] != '.' { 182 return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[19]) 183 } 184 nsec, err := parseByteNanoSec(b[20:]) 185 if err != nil { 186 return time.Time{}, err 187 } 188 return time.Date(year, month, day, hour, min, sec, nsec, loc), nil 189 default: 190 return time.Time{}, fmt.Errorf("invalid time bytes: %s", b) 191 } 192} 193 194func parseByteYear(b []byte) (int, error) { 195 year, n := 0, 1000 196 for i := 0; i < 4; i++ { 197 v, err := bToi(b[i]) 198 if err != nil { 199 return 0, err 200 } 201 year += v * n 202 n = n / 10 203 } 204 return year, nil 205} 206 207func parseByte2Digits(b1, b2 byte) (int, error) { 208 d1, err := bToi(b1) 209 if err != nil { 210 return 0, err 211 } 212 d2, err := bToi(b2) 213 if err != nil { 214 return 0, err 215 } 216 return d1*10 + d2, nil 217} 218 219func parseByteNanoSec(b []byte) (int, error) { 220 ns, digit := 0, 100000 // max is 6-digits 221 for i := 0; i < len(b); i++ { 222 v, err := bToi(b[i]) 223 if err != nil { 224 return 0, err 225 } 226 ns += v * digit 227 digit /= 10 228 } 229 // nanoseconds has 10-digits. (needs to scale digits) 230 // 10 - 6 = 4, so we have to multiple 1000. 231 return ns * 1000, nil 232} 233 234func bToi(b byte) (int, error) { 235 if b < '0' || b > '9' { 236 return 0, errors.New("not [0-9]") 237 } 238 return int(b - '0'), nil 239} 240 241func parseBinaryDateTime(num uint64, data []byte, loc *time.Location) (driver.Value, error) { 242 switch num { 243 case 0: 244 return time.Time{}, nil 245 case 4: 246 return time.Date( 247 int(binary.LittleEndian.Uint16(data[:2])), // year 248 time.Month(data[2]), // month 249 int(data[3]), // day 250 0, 0, 0, 0, 251 loc, 252 ), nil 253 case 7: 254 return time.Date( 255 int(binary.LittleEndian.Uint16(data[:2])), // year 256 time.Month(data[2]), // month 257 int(data[3]), // day 258 int(data[4]), // hour 259 int(data[5]), // minutes 260 int(data[6]), // seconds 261 0, 262 loc, 263 ), nil 264 case 11: 265 return time.Date( 266 int(binary.LittleEndian.Uint16(data[:2])), // year 267 time.Month(data[2]), // month 268 int(data[3]), // day 269 int(data[4]), // hour 270 int(data[5]), // minutes 271 int(data[6]), // seconds 272 int(binary.LittleEndian.Uint32(data[7:11]))*1000, // nanoseconds 273 loc, 274 ), nil 275 } 276 return nil, fmt.Errorf("invalid DATETIME packet length %d", num) 277} 278 279func appendDateTime(buf []byte, t time.Time) ([]byte, error) { 280 year, month, day := t.Date() 281 hour, min, sec := t.Clock() 282 nsec := t.Nanosecond() 283 284 if year < 1 || year > 9999 { 285 return buf, errors.New("year is not in the range [1, 9999]: " + strconv.Itoa(year)) // use errors.New instead of fmt.Errorf to avoid year escape to heap 286 } 287 year100 := year / 100 288 year1 := year % 100 289 290 var localBuf [len("2006-01-02T15:04:05.999999999")]byte // does not escape 291 localBuf[0], localBuf[1], localBuf[2], localBuf[3] = digits10[year100], digits01[year100], digits10[year1], digits01[year1] 292 localBuf[4] = '-' 293 localBuf[5], localBuf[6] = digits10[month], digits01[month] 294 localBuf[7] = '-' 295 localBuf[8], localBuf[9] = digits10[day], digits01[day] 296 297 if hour == 0 && min == 0 && sec == 0 && nsec == 0 { 298 return append(buf, localBuf[:10]...), nil 299 } 300 301 localBuf[10] = ' ' 302 localBuf[11], localBuf[12] = digits10[hour], digits01[hour] 303 localBuf[13] = ':' 304 localBuf[14], localBuf[15] = digits10[min], digits01[min] 305 localBuf[16] = ':' 306 localBuf[17], localBuf[18] = digits10[sec], digits01[sec] 307 308 if nsec == 0 { 309 return append(buf, localBuf[:19]...), nil 310 } 311 nsec100000000 := nsec / 100000000 312 nsec1000000 := (nsec / 1000000) % 100 313 nsec10000 := (nsec / 10000) % 100 314 nsec100 := (nsec / 100) % 100 315 nsec1 := nsec % 100 316 localBuf[19] = '.' 317 318 // milli second 319 localBuf[20], localBuf[21], localBuf[22] = 320 digits01[nsec100000000], digits10[nsec1000000], digits01[nsec1000000] 321 // micro second 322 localBuf[23], localBuf[24], localBuf[25] = 323 digits10[nsec10000], digits01[nsec10000], digits10[nsec100] 324 // nano second 325 localBuf[26], localBuf[27], localBuf[28] = 326 digits01[nsec100], digits10[nsec1], digits01[nsec1] 327 328 // trim trailing zeros 329 n := len(localBuf) 330 for n > 0 && localBuf[n-1] == '0' { 331 n-- 332 } 333 334 return append(buf, localBuf[:n]...), nil 335} 336 337// zeroDateTime is used in formatBinaryDateTime to avoid an allocation 338// if the DATE or DATETIME has the zero value. 339// It must never be changed. 340// The current behavior depends on database/sql copying the result. 341var zeroDateTime = []byte("0000-00-00 00:00:00.000000") 342 343const digits01 = "0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789" 344const digits10 = "0000000000111111111122222222223333333333444444444455555555556666666666777777777788888888889999999999" 345 346func appendMicrosecs(dst, src []byte, decimals int) []byte { 347 if decimals <= 0 { 348 return dst 349 } 350 if len(src) == 0 { 351 return append(dst, ".000000"[:decimals+1]...) 352 } 353 354 microsecs := binary.LittleEndian.Uint32(src[:4]) 355 p1 := byte(microsecs / 10000) 356 microsecs -= 10000 * uint32(p1) 357 p2 := byte(microsecs / 100) 358 microsecs -= 100 * uint32(p2) 359 p3 := byte(microsecs) 360 361 switch decimals { 362 default: 363 return append(dst, '.', 364 digits10[p1], digits01[p1], 365 digits10[p2], digits01[p2], 366 digits10[p3], digits01[p3], 367 ) 368 case 1: 369 return append(dst, '.', 370 digits10[p1], 371 ) 372 case 2: 373 return append(dst, '.', 374 digits10[p1], digits01[p1], 375 ) 376 case 3: 377 return append(dst, '.', 378 digits10[p1], digits01[p1], 379 digits10[p2], 380 ) 381 case 4: 382 return append(dst, '.', 383 digits10[p1], digits01[p1], 384 digits10[p2], digits01[p2], 385 ) 386 case 5: 387 return append(dst, '.', 388 digits10[p1], digits01[p1], 389 digits10[p2], digits01[p2], 390 digits10[p3], 391 ) 392 } 393} 394 395func formatBinaryDateTime(src []byte, length uint8) (driver.Value, error) { 396 // length expects the deterministic length of the zero value, 397 // negative time and 100+ hours are automatically added if needed 398 if len(src) == 0 { 399 return zeroDateTime[:length], nil 400 } 401 var dst []byte // return value 402 var p1, p2, p3 byte // current digit pair 403 404 switch length { 405 case 10, 19, 21, 22, 23, 24, 25, 26: 406 default: 407 t := "DATE" 408 if length > 10 { 409 t += "TIME" 410 } 411 return nil, fmt.Errorf("illegal %s length %d", t, length) 412 } 413 switch len(src) { 414 case 4, 7, 11: 415 default: 416 t := "DATE" 417 if length > 10 { 418 t += "TIME" 419 } 420 return nil, fmt.Errorf("illegal %s packet length %d", t, len(src)) 421 } 422 dst = make([]byte, 0, length) 423 // start with the date 424 year := binary.LittleEndian.Uint16(src[:2]) 425 pt := year / 100 426 p1 = byte(year - 100*uint16(pt)) 427 p2, p3 = src[2], src[3] 428 dst = append(dst, 429 digits10[pt], digits01[pt], 430 digits10[p1], digits01[p1], '-', 431 digits10[p2], digits01[p2], '-', 432 digits10[p3], digits01[p3], 433 ) 434 if length == 10 { 435 return dst, nil 436 } 437 if len(src) == 4 { 438 return append(dst, zeroDateTime[10:length]...), nil 439 } 440 dst = append(dst, ' ') 441 p1 = src[4] // hour 442 src = src[5:] 443 444 // p1 is 2-digit hour, src is after hour 445 p2, p3 = src[0], src[1] 446 dst = append(dst, 447 digits10[p1], digits01[p1], ':', 448 digits10[p2], digits01[p2], ':', 449 digits10[p3], digits01[p3], 450 ) 451 return appendMicrosecs(dst, src[2:], int(length)-20), nil 452} 453 454func formatBinaryTime(src []byte, length uint8) (driver.Value, error) { 455 // length expects the deterministic length of the zero value, 456 // negative time and 100+ hours are automatically added if needed 457 if len(src) == 0 { 458 return zeroDateTime[11 : 11+length], nil 459 } 460 var dst []byte // return value 461 462 switch length { 463 case 464 8, // time (can be up to 10 when negative and 100+ hours) 465 10, 11, 12, 13, 14, 15: // time with fractional seconds 466 default: 467 return nil, fmt.Errorf("illegal TIME length %d", length) 468 } 469 switch len(src) { 470 case 8, 12: 471 default: 472 return nil, fmt.Errorf("invalid TIME packet length %d", len(src)) 473 } 474 // +2 to enable negative time and 100+ hours 475 dst = make([]byte, 0, length+2) 476 if src[0] == 1 { 477 dst = append(dst, '-') 478 } 479 days := binary.LittleEndian.Uint32(src[1:5]) 480 hours := int64(days)*24 + int64(src[5]) 481 482 if hours >= 100 { 483 dst = strconv.AppendInt(dst, hours, 10) 484 } else { 485 dst = append(dst, digits10[hours], digits01[hours]) 486 } 487 488 min, sec := src[6], src[7] 489 dst = append(dst, ':', 490 digits10[min], digits01[min], ':', 491 digits10[sec], digits01[sec], 492 ) 493 return appendMicrosecs(dst, src[8:], int(length)-9), nil 494} 495 496/****************************************************************************** 497* Convert from and to bytes * 498******************************************************************************/ 499 500func uint64ToBytes(n uint64) []byte { 501 return []byte{ 502 byte(n), 503 byte(n >> 8), 504 byte(n >> 16), 505 byte(n >> 24), 506 byte(n >> 32), 507 byte(n >> 40), 508 byte(n >> 48), 509 byte(n >> 56), 510 } 511} 512 513func uint64ToString(n uint64) []byte { 514 var a [20]byte 515 i := 20 516 517 // U+0030 = 0 518 // ... 519 // U+0039 = 9 520 521 var q uint64 522 for n >= 10 { 523 i-- 524 q = n / 10 525 a[i] = uint8(n-q*10) + 0x30 526 n = q 527 } 528 529 i-- 530 a[i] = uint8(n) + 0x30 531 532 return a[i:] 533} 534 535// treats string value as unsigned integer representation 536func stringToInt(b []byte) int { 537 val := 0 538 for i := range b { 539 val *= 10 540 val += int(b[i] - 0x30) 541 } 542 return val 543} 544 545// returns the string read as a bytes slice, wheter the value is NULL, 546// the number of bytes read and an error, in case the string is longer than 547// the input slice 548func readLengthEncodedString(b []byte) ([]byte, bool, int, error) { 549 // Get length 550 num, isNull, n := readLengthEncodedInteger(b) 551 if num < 1 { 552 return b[n:n], isNull, n, nil 553 } 554 555 n += int(num) 556 557 // Check data length 558 if len(b) >= n { 559 return b[n-int(num) : n : n], false, n, nil 560 } 561 return nil, false, n, io.EOF 562} 563 564// returns the number of bytes skipped and an error, in case the string is 565// longer than the input slice 566func skipLengthEncodedString(b []byte) (int, error) { 567 // Get length 568 num, _, n := readLengthEncodedInteger(b) 569 if num < 1 { 570 return n, nil 571 } 572 573 n += int(num) 574 575 // Check data length 576 if len(b) >= n { 577 return n, nil 578 } 579 return n, io.EOF 580} 581 582// returns the number read, whether the value is NULL and the number of bytes read 583func readLengthEncodedInteger(b []byte) (uint64, bool, int) { 584 // See issue #349 585 if len(b) == 0 { 586 return 0, true, 1 587 } 588 589 switch b[0] { 590 // 251: NULL 591 case 0xfb: 592 return 0, true, 1 593 594 // 252: value of following 2 595 case 0xfc: 596 return uint64(b[1]) | uint64(b[2])<<8, false, 3 597 598 // 253: value of following 3 599 case 0xfd: 600 return uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16, false, 4 601 602 // 254: value of following 8 603 case 0xfe: 604 return uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16 | 605 uint64(b[4])<<24 | uint64(b[5])<<32 | uint64(b[6])<<40 | 606 uint64(b[7])<<48 | uint64(b[8])<<56, 607 false, 9 608 } 609 610 // 0-250: value of first byte 611 return uint64(b[0]), false, 1 612} 613 614// encodes a uint64 value and appends it to the given bytes slice 615func appendLengthEncodedInteger(b []byte, n uint64) []byte { 616 switch { 617 case n <= 250: 618 return append(b, byte(n)) 619 620 case n <= 0xffff: 621 return append(b, 0xfc, byte(n), byte(n>>8)) 622 623 case n <= 0xffffff: 624 return append(b, 0xfd, byte(n), byte(n>>8), byte(n>>16)) 625 } 626 return append(b, 0xfe, byte(n), byte(n>>8), byte(n>>16), byte(n>>24), 627 byte(n>>32), byte(n>>40), byte(n>>48), byte(n>>56)) 628} 629 630// reserveBuffer checks cap(buf) and expand buffer to len(buf) + appendSize. 631// If cap(buf) is not enough, reallocate new buffer. 632func reserveBuffer(buf []byte, appendSize int) []byte { 633 newSize := len(buf) + appendSize 634 if cap(buf) < newSize { 635 // Grow buffer exponentially 636 newBuf := make([]byte, len(buf)*2+appendSize) 637 copy(newBuf, buf) 638 buf = newBuf 639 } 640 return buf[:newSize] 641} 642 643// escapeBytesBackslash escapes []byte with backslashes (\) 644// This escapes the contents of a string (provided as []byte) by adding backslashes before special 645// characters, and turning others into specific escape sequences, such as 646// turning newlines into \n and null bytes into \0. 647// https://github.com/mysql/mysql-server/blob/mysql-5.7.5/mysys/charset.c#L823-L932 648func escapeBytesBackslash(buf, v []byte) []byte { 649 pos := len(buf) 650 buf = reserveBuffer(buf, len(v)*2) 651 652 for _, c := range v { 653 switch c { 654 case '\x00': 655 buf[pos] = '\\' 656 buf[pos+1] = '0' 657 pos += 2 658 case '\n': 659 buf[pos] = '\\' 660 buf[pos+1] = 'n' 661 pos += 2 662 case '\r': 663 buf[pos] = '\\' 664 buf[pos+1] = 'r' 665 pos += 2 666 case '\x1a': 667 buf[pos] = '\\' 668 buf[pos+1] = 'Z' 669 pos += 2 670 case '\'': 671 buf[pos] = '\\' 672 buf[pos+1] = '\'' 673 pos += 2 674 case '"': 675 buf[pos] = '\\' 676 buf[pos+1] = '"' 677 pos += 2 678 case '\\': 679 buf[pos] = '\\' 680 buf[pos+1] = '\\' 681 pos += 2 682 default: 683 buf[pos] = c 684 pos++ 685 } 686 } 687 688 return buf[:pos] 689} 690 691// escapeStringBackslash is similar to escapeBytesBackslash but for string. 692func escapeStringBackslash(buf []byte, v string) []byte { 693 pos := len(buf) 694 buf = reserveBuffer(buf, len(v)*2) 695 696 for i := 0; i < len(v); i++ { 697 c := v[i] 698 switch c { 699 case '\x00': 700 buf[pos] = '\\' 701 buf[pos+1] = '0' 702 pos += 2 703 case '\n': 704 buf[pos] = '\\' 705 buf[pos+1] = 'n' 706 pos += 2 707 case '\r': 708 buf[pos] = '\\' 709 buf[pos+1] = 'r' 710 pos += 2 711 case '\x1a': 712 buf[pos] = '\\' 713 buf[pos+1] = 'Z' 714 pos += 2 715 case '\'': 716 buf[pos] = '\\' 717 buf[pos+1] = '\'' 718 pos += 2 719 case '"': 720 buf[pos] = '\\' 721 buf[pos+1] = '"' 722 pos += 2 723 case '\\': 724 buf[pos] = '\\' 725 buf[pos+1] = '\\' 726 pos += 2 727 default: 728 buf[pos] = c 729 pos++ 730 } 731 } 732 733 return buf[:pos] 734} 735 736// escapeBytesQuotes escapes apostrophes in []byte by doubling them up. 737// This escapes the contents of a string by doubling up any apostrophes that 738// it contains. This is used when the NO_BACKSLASH_ESCAPES SQL_MODE is in 739// effect on the server. 740// https://github.com/mysql/mysql-server/blob/mysql-5.7.5/mysys/charset.c#L963-L1038 741func escapeBytesQuotes(buf, v []byte) []byte { 742 pos := len(buf) 743 buf = reserveBuffer(buf, len(v)*2) 744 745 for _, c := range v { 746 if c == '\'' { 747 buf[pos] = '\'' 748 buf[pos+1] = '\'' 749 pos += 2 750 } else { 751 buf[pos] = c 752 pos++ 753 } 754 } 755 756 return buf[:pos] 757} 758 759// escapeStringQuotes is similar to escapeBytesQuotes but for string. 760func escapeStringQuotes(buf []byte, v string) []byte { 761 pos := len(buf) 762 buf = reserveBuffer(buf, len(v)*2) 763 764 for i := 0; i < len(v); i++ { 765 c := v[i] 766 if c == '\'' { 767 buf[pos] = '\'' 768 buf[pos+1] = '\'' 769 pos += 2 770 } else { 771 buf[pos] = c 772 pos++ 773 } 774 } 775 776 return buf[:pos] 777} 778 779/****************************************************************************** 780* Sync utils * 781******************************************************************************/ 782 783// noCopy may be embedded into structs which must not be copied 784// after the first use. 785// 786// See https://github.com/golang/go/issues/8005#issuecomment-190753527 787// for details. 788type noCopy struct{} 789 790// Lock is a no-op used by -copylocks checker from `go vet`. 791func (*noCopy) Lock() {} 792 793// atomicBool is a wrapper around uint32 for usage as a boolean value with 794// atomic access. 795type atomicBool struct { 796 _noCopy noCopy 797 value uint32 798} 799 800// IsSet returns whether the current boolean value is true 801func (ab *atomicBool) IsSet() bool { 802 return atomic.LoadUint32(&ab.value) > 0 803} 804 805// Set sets the value of the bool regardless of the previous value 806func (ab *atomicBool) Set(value bool) { 807 if value { 808 atomic.StoreUint32(&ab.value, 1) 809 } else { 810 atomic.StoreUint32(&ab.value, 0) 811 } 812} 813 814// TrySet sets the value of the bool and returns whether the value changed 815func (ab *atomicBool) TrySet(value bool) bool { 816 if value { 817 return atomic.SwapUint32(&ab.value, 1) == 0 818 } 819 return atomic.SwapUint32(&ab.value, 0) > 0 820} 821 822// atomicError is a wrapper for atomically accessed error values 823type atomicError struct { 824 _noCopy noCopy 825 value atomic.Value 826} 827 828// Set sets the error value regardless of the previous value. 829// The value must not be nil 830func (ae *atomicError) Set(value error) { 831 ae.value.Store(value) 832} 833 834// Value returns the current error value 835func (ae *atomicError) Value() error { 836 if v := ae.value.Load(); v != nil { 837 // this will panic if the value doesn't implement the error interface 838 return v.(error) 839 } 840 return nil 841} 842 843func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) { 844 dargs := make([]driver.Value, len(named)) 845 for n, param := range named { 846 if len(param.Name) > 0 { 847 // TODO: support the use of Named Parameters #561 848 return nil, errors.New("mysql: driver does not support the use of Named Parameters") 849 } 850 dargs[n] = param.Value 851 } 852 return dargs, nil 853} 854 855func mapIsolationLevel(level driver.IsolationLevel) (string, error) { 856 switch sql.IsolationLevel(level) { 857 case sql.LevelRepeatableRead: 858 return "REPEATABLE READ", nil 859 case sql.LevelReadCommitted: 860 return "READ COMMITTED", nil 861 case sql.LevelReadUncommitted: 862 return "READ UNCOMMITTED", nil 863 case sql.LevelSerializable: 864 return "SERIALIZABLE", nil 865 default: 866 return "", fmt.Errorf("mysql: unsupported isolation level: %v", level) 867 } 868} 869