1// Go MySQL Driver - A MySQL-Driver for Go's database/sql package 2// 3// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. 4// 5// This Source Code Form is subject to the terms of the Mozilla Public 6// License, v. 2.0. If a copy of the MPL was not distributed with this file, 7// You can obtain one at http://mozilla.org/MPL/2.0/. 8 9package mysql 10 11import ( 12 "context" 13 "database/sql" 14 "database/sql/driver" 15 "encoding/json" 16 "io" 17 "net" 18 "strconv" 19 "strings" 20 "time" 21) 22 23type mysqlConn struct { 24 buf buffer 25 netConn net.Conn 26 rawConn net.Conn // underlying connection when netConn is TLS connection. 27 affectedRows uint64 28 insertId uint64 29 cfg *Config 30 maxAllowedPacket int 31 maxWriteSize int 32 writeTimeout time.Duration 33 flags clientFlag 34 status statusFlag 35 sequence uint8 36 parseTime bool 37 reset bool // set when the Go SQL package calls ResetSession 38 39 // for context support (Go 1.8+) 40 watching bool 41 watcher chan<- context.Context 42 closech chan struct{} 43 finished chan<- struct{} 44 canceled atomicError // set non-nil if conn is canceled 45 closed atomicBool // set when conn is closed, before closech is closed 46} 47 48// Handles parameters set in DSN after the connection is established 49func (mc *mysqlConn) handleParams() (err error) { 50 var cmdSet strings.Builder 51 for param, val := range mc.cfg.Params { 52 switch param { 53 // Charset: character_set_connection, character_set_client, character_set_results 54 case "charset": 55 charsets := strings.Split(val, ",") 56 for i := range charsets { 57 // ignore errors here - a charset may not exist 58 err = mc.exec("SET NAMES " + charsets[i]) 59 if err == nil { 60 break 61 } 62 } 63 if err != nil { 64 return 65 } 66 67 // Other system vars accumulated in a single SET command 68 default: 69 if cmdSet.Len() == 0 { 70 // Heuristic: 29 chars for each other key=value to reduce reallocations 71 cmdSet.Grow(4 + len(param) + 1 + len(val) + 30*(len(mc.cfg.Params)-1)) 72 cmdSet.WriteString("SET ") 73 } else { 74 cmdSet.WriteByte(',') 75 } 76 cmdSet.WriteString(param) 77 cmdSet.WriteByte('=') 78 cmdSet.WriteString(val) 79 } 80 } 81 82 if cmdSet.Len() > 0 { 83 err = mc.exec(cmdSet.String()) 84 if err != nil { 85 return 86 } 87 } 88 89 return 90} 91 92func (mc *mysqlConn) markBadConn(err error) error { 93 if mc == nil { 94 return err 95 } 96 if err != errBadConnNoWrite { 97 return err 98 } 99 return driver.ErrBadConn 100} 101 102func (mc *mysqlConn) Begin() (driver.Tx, error) { 103 return mc.begin(false) 104} 105 106func (mc *mysqlConn) begin(readOnly bool) (driver.Tx, error) { 107 if mc.closed.IsSet() { 108 errLog.Print(ErrInvalidConn) 109 return nil, driver.ErrBadConn 110 } 111 var q string 112 if readOnly { 113 q = "START TRANSACTION READ ONLY" 114 } else { 115 q = "START TRANSACTION" 116 } 117 err := mc.exec(q) 118 if err == nil { 119 return &mysqlTx{mc}, err 120 } 121 return nil, mc.markBadConn(err) 122} 123 124func (mc *mysqlConn) Close() (err error) { 125 // Makes Close idempotent 126 if !mc.closed.IsSet() { 127 err = mc.writeCommandPacket(comQuit) 128 } 129 130 mc.cleanup() 131 132 return 133} 134 135// Closes the network connection and unsets internal variables. Do not call this 136// function after successfully authentication, call Close instead. This function 137// is called before auth or on auth failure because MySQL will have already 138// closed the network connection. 139func (mc *mysqlConn) cleanup() { 140 if !mc.closed.TrySet(true) { 141 return 142 } 143 144 // Makes cleanup idempotent 145 close(mc.closech) 146 if mc.netConn == nil { 147 return 148 } 149 if err := mc.netConn.Close(); err != nil { 150 errLog.Print(err) 151 } 152} 153 154func (mc *mysqlConn) error() error { 155 if mc.closed.IsSet() { 156 if err := mc.canceled.Value(); err != nil { 157 return err 158 } 159 return ErrInvalidConn 160 } 161 return nil 162} 163 164func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) { 165 if mc.closed.IsSet() { 166 errLog.Print(ErrInvalidConn) 167 return nil, driver.ErrBadConn 168 } 169 // Send command 170 err := mc.writeCommandPacketStr(comStmtPrepare, query) 171 if err != nil { 172 // STMT_PREPARE is safe to retry. So we can return ErrBadConn here. 173 errLog.Print(err) 174 return nil, driver.ErrBadConn 175 } 176 177 stmt := &mysqlStmt{ 178 mc: mc, 179 } 180 181 // Read Result 182 columnCount, err := stmt.readPrepareResultPacket() 183 if err == nil { 184 if stmt.paramCount > 0 { 185 if err = mc.readUntilEOF(); err != nil { 186 return nil, err 187 } 188 } 189 190 if columnCount > 0 { 191 err = mc.readUntilEOF() 192 } 193 } 194 195 return stmt, err 196} 197 198func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (string, error) { 199 // Number of ? should be same to len(args) 200 if strings.Count(query, "?") != len(args) { 201 return "", driver.ErrSkip 202 } 203 204 buf, err := mc.buf.takeCompleteBuffer() 205 if err != nil { 206 // can not take the buffer. Something must be wrong with the connection 207 errLog.Print(err) 208 return "", ErrInvalidConn 209 } 210 buf = buf[:0] 211 argPos := 0 212 213 for i := 0; i < len(query); i++ { 214 q := strings.IndexByte(query[i:], '?') 215 if q == -1 { 216 buf = append(buf, query[i:]...) 217 break 218 } 219 buf = append(buf, query[i:i+q]...) 220 i += q 221 222 arg := args[argPos] 223 argPos++ 224 225 if arg == nil { 226 buf = append(buf, "NULL"...) 227 continue 228 } 229 230 switch v := arg.(type) { 231 case int64: 232 buf = strconv.AppendInt(buf, v, 10) 233 case uint64: 234 // Handle uint64 explicitly because our custom ConvertValue emits unsigned values 235 buf = strconv.AppendUint(buf, v, 10) 236 case float64: 237 buf = strconv.AppendFloat(buf, v, 'g', -1, 64) 238 case bool: 239 if v { 240 buf = append(buf, '1') 241 } else { 242 buf = append(buf, '0') 243 } 244 case time.Time: 245 if v.IsZero() { 246 buf = append(buf, "'0000-00-00'"...) 247 } else { 248 buf = append(buf, '\'') 249 buf, err = appendDateTime(buf, v.In(mc.cfg.Loc)) 250 if err != nil { 251 return "", err 252 } 253 buf = append(buf, '\'') 254 } 255 case json.RawMessage: 256 buf = append(buf, '\'') 257 if mc.status&statusNoBackslashEscapes == 0 { 258 buf = escapeBytesBackslash(buf, v) 259 } else { 260 buf = escapeBytesQuotes(buf, v) 261 } 262 buf = append(buf, '\'') 263 case []byte: 264 if v == nil { 265 buf = append(buf, "NULL"...) 266 } else { 267 buf = append(buf, "_binary'"...) 268 if mc.status&statusNoBackslashEscapes == 0 { 269 buf = escapeBytesBackslash(buf, v) 270 } else { 271 buf = escapeBytesQuotes(buf, v) 272 } 273 buf = append(buf, '\'') 274 } 275 case string: 276 buf = append(buf, '\'') 277 if mc.status&statusNoBackslashEscapes == 0 { 278 buf = escapeStringBackslash(buf, v) 279 } else { 280 buf = escapeStringQuotes(buf, v) 281 } 282 buf = append(buf, '\'') 283 default: 284 return "", driver.ErrSkip 285 } 286 287 if len(buf)+4 > mc.maxAllowedPacket { 288 return "", driver.ErrSkip 289 } 290 } 291 if argPos != len(args) { 292 return "", driver.ErrSkip 293 } 294 return string(buf), nil 295} 296 297func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) { 298 if mc.closed.IsSet() { 299 errLog.Print(ErrInvalidConn) 300 return nil, driver.ErrBadConn 301 } 302 if len(args) != 0 { 303 if !mc.cfg.InterpolateParams { 304 return nil, driver.ErrSkip 305 } 306 // try to interpolate the parameters to save extra roundtrips for preparing and closing a statement 307 prepared, err := mc.interpolateParams(query, args) 308 if err != nil { 309 return nil, err 310 } 311 query = prepared 312 } 313 mc.affectedRows = 0 314 mc.insertId = 0 315 316 err := mc.exec(query) 317 if err == nil { 318 return &mysqlResult{ 319 affectedRows: int64(mc.affectedRows), 320 insertId: int64(mc.insertId), 321 }, err 322 } 323 return nil, mc.markBadConn(err) 324} 325 326// Internal function to execute commands 327func (mc *mysqlConn) exec(query string) error { 328 // Send command 329 if err := mc.writeCommandPacketStr(comQuery, query); err != nil { 330 return mc.markBadConn(err) 331 } 332 333 // Read Result 334 resLen, err := mc.readResultSetHeaderPacket() 335 if err != nil { 336 return err 337 } 338 339 if resLen > 0 { 340 // columns 341 if err := mc.readUntilEOF(); err != nil { 342 return err 343 } 344 345 // rows 346 if err := mc.readUntilEOF(); err != nil { 347 return err 348 } 349 } 350 351 return mc.discardResults() 352} 353 354func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) { 355 return mc.query(query, args) 356} 357 358func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) { 359 if mc.closed.IsSet() { 360 errLog.Print(ErrInvalidConn) 361 return nil, driver.ErrBadConn 362 } 363 if len(args) != 0 { 364 if !mc.cfg.InterpolateParams { 365 return nil, driver.ErrSkip 366 } 367 // try client-side prepare to reduce roundtrip 368 prepared, err := mc.interpolateParams(query, args) 369 if err != nil { 370 return nil, err 371 } 372 query = prepared 373 } 374 // Send command 375 err := mc.writeCommandPacketStr(comQuery, query) 376 if err == nil { 377 // Read Result 378 var resLen int 379 resLen, err = mc.readResultSetHeaderPacket() 380 if err == nil { 381 rows := new(textRows) 382 rows.mc = mc 383 384 if resLen == 0 { 385 rows.rs.done = true 386 387 switch err := rows.NextResultSet(); err { 388 case nil, io.EOF: 389 return rows, nil 390 default: 391 return nil, err 392 } 393 } 394 395 // Columns 396 rows.rs.columns, err = mc.readColumns(resLen) 397 return rows, err 398 } 399 } 400 return nil, mc.markBadConn(err) 401} 402 403// Gets the value of the given MySQL System Variable 404// The returned byte slice is only valid until the next read 405func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) { 406 // Send command 407 if err := mc.writeCommandPacketStr(comQuery, "SELECT @@"+name); err != nil { 408 return nil, err 409 } 410 411 // Read Result 412 resLen, err := mc.readResultSetHeaderPacket() 413 if err == nil { 414 rows := new(textRows) 415 rows.mc = mc 416 rows.rs.columns = []mysqlField{{fieldType: fieldTypeVarChar}} 417 418 if resLen > 0 { 419 // Columns 420 if err := mc.readUntilEOF(); err != nil { 421 return nil, err 422 } 423 } 424 425 dest := make([]driver.Value, resLen) 426 if err = rows.readRow(dest); err == nil { 427 return dest[0].([]byte), mc.readUntilEOF() 428 } 429 } 430 return nil, err 431} 432 433// finish is called when the query has canceled. 434func (mc *mysqlConn) cancel(err error) { 435 mc.canceled.Set(err) 436 mc.cleanup() 437} 438 439// finish is called when the query has succeeded. 440func (mc *mysqlConn) finish() { 441 if !mc.watching || mc.finished == nil { 442 return 443 } 444 select { 445 case mc.finished <- struct{}{}: 446 mc.watching = false 447 case <-mc.closech: 448 } 449} 450 451// Ping implements driver.Pinger interface 452func (mc *mysqlConn) Ping(ctx context.Context) (err error) { 453 if mc.closed.IsSet() { 454 errLog.Print(ErrInvalidConn) 455 return driver.ErrBadConn 456 } 457 458 if err = mc.watchCancel(ctx); err != nil { 459 return 460 } 461 defer mc.finish() 462 463 if err = mc.writeCommandPacket(comPing); err != nil { 464 return mc.markBadConn(err) 465 } 466 467 return mc.readResultOK() 468} 469 470// BeginTx implements driver.ConnBeginTx interface 471func (mc *mysqlConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { 472 if mc.closed.IsSet() { 473 return nil, driver.ErrBadConn 474 } 475 476 if err := mc.watchCancel(ctx); err != nil { 477 return nil, err 478 } 479 defer mc.finish() 480 481 if sql.IsolationLevel(opts.Isolation) != sql.LevelDefault { 482 level, err := mapIsolationLevel(opts.Isolation) 483 if err != nil { 484 return nil, err 485 } 486 err = mc.exec("SET TRANSACTION ISOLATION LEVEL " + level) 487 if err != nil { 488 return nil, err 489 } 490 } 491 492 return mc.begin(opts.ReadOnly) 493} 494 495func (mc *mysqlConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { 496 dargs, err := namedValueToValue(args) 497 if err != nil { 498 return nil, err 499 } 500 501 if err := mc.watchCancel(ctx); err != nil { 502 return nil, err 503 } 504 505 rows, err := mc.query(query, dargs) 506 if err != nil { 507 mc.finish() 508 return nil, err 509 } 510 rows.finish = mc.finish 511 return rows, err 512} 513 514func (mc *mysqlConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { 515 dargs, err := namedValueToValue(args) 516 if err != nil { 517 return nil, err 518 } 519 520 if err := mc.watchCancel(ctx); err != nil { 521 return nil, err 522 } 523 defer mc.finish() 524 525 return mc.Exec(query, dargs) 526} 527 528func (mc *mysqlConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { 529 if err := mc.watchCancel(ctx); err != nil { 530 return nil, err 531 } 532 533 stmt, err := mc.Prepare(query) 534 mc.finish() 535 if err != nil { 536 return nil, err 537 } 538 539 select { 540 default: 541 case <-ctx.Done(): 542 stmt.Close() 543 return nil, ctx.Err() 544 } 545 return stmt, nil 546} 547 548func (stmt *mysqlStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { 549 dargs, err := namedValueToValue(args) 550 if err != nil { 551 return nil, err 552 } 553 554 if err := stmt.mc.watchCancel(ctx); err != nil { 555 return nil, err 556 } 557 558 rows, err := stmt.query(dargs) 559 if err != nil { 560 stmt.mc.finish() 561 return nil, err 562 } 563 rows.finish = stmt.mc.finish 564 return rows, err 565} 566 567func (stmt *mysqlStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { 568 dargs, err := namedValueToValue(args) 569 if err != nil { 570 return nil, err 571 } 572 573 if err := stmt.mc.watchCancel(ctx); err != nil { 574 return nil, err 575 } 576 defer stmt.mc.finish() 577 578 return stmt.Exec(dargs) 579} 580 581func (mc *mysqlConn) watchCancel(ctx context.Context) error { 582 if mc.watching { 583 // Reach here if canceled, 584 // so the connection is already invalid 585 mc.cleanup() 586 return nil 587 } 588 // When ctx is already cancelled, don't watch it. 589 if err := ctx.Err(); err != nil { 590 return err 591 } 592 // When ctx is not cancellable, don't watch it. 593 if ctx.Done() == nil { 594 return nil 595 } 596 // When watcher is not alive, can't watch it. 597 if mc.watcher == nil { 598 return nil 599 } 600 601 mc.watching = true 602 mc.watcher <- ctx 603 return nil 604} 605 606func (mc *mysqlConn) startWatcher() { 607 watcher := make(chan context.Context, 1) 608 mc.watcher = watcher 609 finished := make(chan struct{}) 610 mc.finished = finished 611 go func() { 612 for { 613 var ctx context.Context 614 select { 615 case ctx = <-watcher: 616 case <-mc.closech: 617 return 618 } 619 620 select { 621 case <-ctx.Done(): 622 mc.cancel(ctx.Err()) 623 case <-finished: 624 case <-mc.closech: 625 return 626 } 627 } 628 }() 629} 630 631func (mc *mysqlConn) CheckNamedValue(nv *driver.NamedValue) (err error) { 632 nv.Value, err = converter{}.ConvertValue(nv.Value) 633 return 634} 635 636// ResetSession implements driver.SessionResetter. 637// (From Go 1.10) 638func (mc *mysqlConn) ResetSession(ctx context.Context) error { 639 if mc.closed.IsSet() { 640 return driver.ErrBadConn 641 } 642 mc.reset = true 643 return nil 644} 645 646// IsValid implements driver.Validator interface 647// (From Go 1.15) 648func (mc *mysqlConn) IsValid() bool { 649 return !mc.closed.IsSet() 650} 651