1// Copyright 2011 The Go Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style 3// license that can be found in the LICENSE file. 4 5package sql 6 7import ( 8 "database/sql/driver" 9 "errors" 10 "fmt" 11 "io" 12 "log" 13 "strconv" 14 "strings" 15 "sync" 16 "testing" 17 "time" 18) 19 20var _ = log.Printf 21 22// fakeDriver is a fake database that implements Go's driver.Driver 23// interface, just for testing. 24// 25// It speaks a query language that's semantically similar to but 26// syntantically different and simpler than SQL. The syntax is as 27// follows: 28// 29// WIPE 30// CREATE|<tablename>|<col>=<type>,<col>=<type>,... 31// where types are: "string", [u]int{8,16,32,64}, "bool" 32// INSERT|<tablename>|col=val,col2=val2,col3=? 33// SELECT|<tablename>|projectcol1,projectcol2|filtercol=?,filtercol2=? 34// 35// When opening a fakeDriver's database, it starts empty with no 36// tables. All tables and data are stored in memory only. 37type fakeDriver struct { 38 mu sync.Mutex // guards 3 following fields 39 openCount int // conn opens 40 closeCount int // conn closes 41 dbs map[string]*fakeDB 42} 43 44type fakeDB struct { 45 name string 46 47 mu sync.Mutex 48 free []*fakeConn 49 tables map[string]*table 50 badConn bool 51} 52 53type table struct { 54 mu sync.Mutex 55 colname []string 56 coltype []string 57 rows []*row 58} 59 60func (t *table) columnIndex(name string) int { 61 for n, nname := range t.colname { 62 if name == nname { 63 return n 64 } 65 } 66 return -1 67} 68 69type row struct { 70 cols []interface{} // must be same size as its table colname + coltype 71} 72 73func (r *row) clone() *row { 74 nrow := &row{cols: make([]interface{}, len(r.cols))} 75 copy(nrow.cols, r.cols) 76 return nrow 77} 78 79type fakeConn struct { 80 db *fakeDB // where to return ourselves to 81 82 currTx *fakeTx 83 84 // Stats for tests: 85 mu sync.Mutex 86 stmtsMade int 87 stmtsClosed int 88 numPrepare int 89 bad bool 90} 91 92func (c *fakeConn) incrStat(v *int) { 93 c.mu.Lock() 94 *v++ 95 c.mu.Unlock() 96} 97 98type fakeTx struct { 99 c *fakeConn 100} 101 102type fakeStmt struct { 103 c *fakeConn 104 q string // just for debugging 105 106 cmd string 107 table string 108 109 closed bool 110 111 colName []string // used by CREATE, INSERT, SELECT (selected columns) 112 colType []string // used by CREATE 113 colValue []interface{} // used by INSERT (mix of strings and "?" for bound params) 114 placeholders int // used by INSERT/SELECT: number of ? params 115 116 whereCol []string // used by SELECT (all placeholders) 117 118 placeholderConverter []driver.ValueConverter // used by INSERT 119} 120 121var fdriver driver.Driver = &fakeDriver{} 122 123func init() { 124 Register("test", fdriver) 125} 126 127// Supports dsn forms: 128// <dbname> 129// <dbname>;<opts> (only currently supported option is `badConn`, 130// which causes driver.ErrBadConn to be returned on 131// every other conn.Begin()) 132func (d *fakeDriver) Open(dsn string) (driver.Conn, error) { 133 parts := strings.Split(dsn, ";") 134 if len(parts) < 1 { 135 return nil, errors.New("fakedb: no database name") 136 } 137 name := parts[0] 138 139 db := d.getDB(name) 140 141 d.mu.Lock() 142 d.openCount++ 143 d.mu.Unlock() 144 conn := &fakeConn{db: db} 145 146 if len(parts) >= 2 && parts[1] == "badConn" { 147 conn.bad = true 148 } 149 return conn, nil 150} 151 152func (d *fakeDriver) getDB(name string) *fakeDB { 153 d.mu.Lock() 154 defer d.mu.Unlock() 155 if d.dbs == nil { 156 d.dbs = make(map[string]*fakeDB) 157 } 158 db, ok := d.dbs[name] 159 if !ok { 160 db = &fakeDB{name: name} 161 d.dbs[name] = db 162 } 163 return db 164} 165 166func (db *fakeDB) wipe() { 167 db.mu.Lock() 168 defer db.mu.Unlock() 169 db.tables = nil 170} 171 172func (db *fakeDB) createTable(name string, columnNames, columnTypes []string) error { 173 db.mu.Lock() 174 defer db.mu.Unlock() 175 if db.tables == nil { 176 db.tables = make(map[string]*table) 177 } 178 if _, exist := db.tables[name]; exist { 179 return fmt.Errorf("table %q already exists", name) 180 } 181 if len(columnNames) != len(columnTypes) { 182 return fmt.Errorf("create table of %q len(names) != len(types): %d vs %d", 183 name, len(columnNames), len(columnTypes)) 184 } 185 db.tables[name] = &table{colname: columnNames, coltype: columnTypes} 186 return nil 187} 188 189// must be called with db.mu lock held 190func (db *fakeDB) table(table string) (*table, bool) { 191 if db.tables == nil { 192 return nil, false 193 } 194 t, ok := db.tables[table] 195 return t, ok 196} 197 198func (db *fakeDB) columnType(table, column string) (typ string, ok bool) { 199 db.mu.Lock() 200 defer db.mu.Unlock() 201 t, ok := db.table(table) 202 if !ok { 203 return 204 } 205 for n, cname := range t.colname { 206 if cname == column { 207 return t.coltype[n], true 208 } 209 } 210 return "", false 211} 212 213func (c *fakeConn) isBad() bool { 214 // if not simulating bad conn, do nothing 215 if !c.bad { 216 return false 217 } 218 // alternate between bad conn and not bad conn 219 c.db.badConn = !c.db.badConn 220 return c.db.badConn 221} 222 223func (c *fakeConn) Begin() (driver.Tx, error) { 224 if c.isBad() { 225 return nil, driver.ErrBadConn 226 } 227 if c.currTx != nil { 228 return nil, errors.New("already in a transaction") 229 } 230 c.currTx = &fakeTx{c: c} 231 return c.currTx, nil 232} 233 234var hookPostCloseConn struct { 235 sync.Mutex 236 fn func(*fakeConn, error) 237} 238 239func setHookpostCloseConn(fn func(*fakeConn, error)) { 240 hookPostCloseConn.Lock() 241 defer hookPostCloseConn.Unlock() 242 hookPostCloseConn.fn = fn 243} 244 245var testStrictClose *testing.T 246 247// setStrictFakeConnClose sets the t to Errorf on when fakeConn.Close 248// fails to close. If nil, the check is disabled. 249func setStrictFakeConnClose(t *testing.T) { 250 testStrictClose = t 251} 252 253func (c *fakeConn) Close() (err error) { 254 drv := fdriver.(*fakeDriver) 255 defer func() { 256 if err != nil && testStrictClose != nil { 257 testStrictClose.Errorf("failed to close a test fakeConn: %v", err) 258 } 259 hookPostCloseConn.Lock() 260 fn := hookPostCloseConn.fn 261 hookPostCloseConn.Unlock() 262 if fn != nil { 263 fn(c, err) 264 } 265 if err == nil { 266 drv.mu.Lock() 267 drv.closeCount++ 268 drv.mu.Unlock() 269 } 270 }() 271 if c.currTx != nil { 272 return errors.New("can't close fakeConn; in a Transaction") 273 } 274 if c.db == nil { 275 return errors.New("can't close fakeConn; already closed") 276 } 277 if c.stmtsMade > c.stmtsClosed { 278 return errors.New("can't close; dangling statement(s)") 279 } 280 c.db = nil 281 return nil 282} 283 284func checkSubsetTypes(args []driver.Value) error { 285 for n, arg := range args { 286 switch arg.(type) { 287 case int64, float64, bool, nil, []byte, string, time.Time: 288 default: 289 return fmt.Errorf("fakedb_test: invalid argument #%d: %v, type %T", n+1, arg, arg) 290 } 291 } 292 return nil 293} 294 295func (c *fakeConn) Exec(query string, args []driver.Value) (driver.Result, error) { 296 // This is an optional interface, but it's implemented here 297 // just to check that all the args are of the proper types. 298 // ErrSkip is returned so the caller acts as if we didn't 299 // implement this at all. 300 err := checkSubsetTypes(args) 301 if err != nil { 302 return nil, err 303 } 304 return nil, driver.ErrSkip 305} 306 307func (c *fakeConn) Query(query string, args []driver.Value) (driver.Rows, error) { 308 // This is an optional interface, but it's implemented here 309 // just to check that all the args are of the proper types. 310 // ErrSkip is returned so the caller acts as if we didn't 311 // implement this at all. 312 err := checkSubsetTypes(args) 313 if err != nil { 314 return nil, err 315 } 316 return nil, driver.ErrSkip 317} 318 319func errf(msg string, args ...interface{}) error { 320 return errors.New("fakedb: " + fmt.Sprintf(msg, args...)) 321} 322 323// parts are table|selectCol1,selectCol2|whereCol=?,whereCol2=? 324// (note that where columns must always contain ? marks, 325// just a limitation for fakedb) 326func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (driver.Stmt, error) { 327 if len(parts) != 3 { 328 stmt.Close() 329 return nil, errf("invalid SELECT syntax with %d parts; want 3", len(parts)) 330 } 331 stmt.table = parts[0] 332 stmt.colName = strings.Split(parts[1], ",") 333 for n, colspec := range strings.Split(parts[2], ",") { 334 if colspec == "" { 335 continue 336 } 337 nameVal := strings.Split(colspec, "=") 338 if len(nameVal) != 2 { 339 stmt.Close() 340 return nil, errf("SELECT on table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n) 341 } 342 column, value := nameVal[0], nameVal[1] 343 _, ok := c.db.columnType(stmt.table, column) 344 if !ok { 345 stmt.Close() 346 return nil, errf("SELECT on table %q references non-existent column %q", stmt.table, column) 347 } 348 if value != "?" { 349 stmt.Close() 350 return nil, errf("SELECT on table %q has pre-bound value for where column %q; need a question mark", 351 stmt.table, column) 352 } 353 stmt.whereCol = append(stmt.whereCol, column) 354 stmt.placeholders++ 355 } 356 return stmt, nil 357} 358 359// parts are table|col=type,col2=type2 360func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (driver.Stmt, error) { 361 if len(parts) != 2 { 362 stmt.Close() 363 return nil, errf("invalid CREATE syntax with %d parts; want 2", len(parts)) 364 } 365 stmt.table = parts[0] 366 for n, colspec := range strings.Split(parts[1], ",") { 367 nameType := strings.Split(colspec, "=") 368 if len(nameType) != 2 { 369 stmt.Close() 370 return nil, errf("CREATE table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n) 371 } 372 stmt.colName = append(stmt.colName, nameType[0]) 373 stmt.colType = append(stmt.colType, nameType[1]) 374 } 375 return stmt, nil 376} 377 378// parts are table|col=?,col2=val 379func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (driver.Stmt, error) { 380 if len(parts) != 2 { 381 stmt.Close() 382 return nil, errf("invalid INSERT syntax with %d parts; want 2", len(parts)) 383 } 384 stmt.table = parts[0] 385 for n, colspec := range strings.Split(parts[1], ",") { 386 nameVal := strings.Split(colspec, "=") 387 if len(nameVal) != 2 { 388 stmt.Close() 389 return nil, errf("INSERT table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n) 390 } 391 column, value := nameVal[0], nameVal[1] 392 ctype, ok := c.db.columnType(stmt.table, column) 393 if !ok { 394 stmt.Close() 395 return nil, errf("INSERT table %q references non-existent column %q", stmt.table, column) 396 } 397 stmt.colName = append(stmt.colName, column) 398 399 if value != "?" { 400 var subsetVal interface{} 401 // Convert to driver subset type 402 switch ctype { 403 case "string": 404 subsetVal = []byte(value) 405 case "blob": 406 subsetVal = []byte(value) 407 case "int32": 408 i, err := strconv.Atoi(value) 409 if err != nil { 410 stmt.Close() 411 return nil, errf("invalid conversion to int32 from %q", value) 412 } 413 subsetVal = int64(i) // int64 is a subset type, but not int32 414 default: 415 stmt.Close() 416 return nil, errf("unsupported conversion for pre-bound parameter %q to type %q", value, ctype) 417 } 418 stmt.colValue = append(stmt.colValue, subsetVal) 419 } else { 420 stmt.placeholders++ 421 stmt.placeholderConverter = append(stmt.placeholderConverter, converterForType(ctype)) 422 stmt.colValue = append(stmt.colValue, "?") 423 } 424 } 425 return stmt, nil 426} 427 428func (c *fakeConn) Prepare(query string) (driver.Stmt, error) { 429 c.numPrepare++ 430 if c.db == nil { 431 panic("nil c.db; conn = " + fmt.Sprintf("%#v", c)) 432 } 433 parts := strings.Split(query, "|") 434 if len(parts) < 1 { 435 return nil, errf("empty query") 436 } 437 cmd := parts[0] 438 parts = parts[1:] 439 stmt := &fakeStmt{q: query, c: c, cmd: cmd} 440 c.incrStat(&c.stmtsMade) 441 switch cmd { 442 case "WIPE": 443 // Nothing 444 case "SELECT": 445 return c.prepareSelect(stmt, parts) 446 case "CREATE": 447 return c.prepareCreate(stmt, parts) 448 case "INSERT": 449 return c.prepareInsert(stmt, parts) 450 default: 451 stmt.Close() 452 return nil, errf("unsupported command type %q", cmd) 453 } 454 return stmt, nil 455} 456 457func (s *fakeStmt) ColumnConverter(idx int) driver.ValueConverter { 458 if len(s.placeholderConverter) == 0 { 459 return driver.DefaultParameterConverter 460 } 461 return s.placeholderConverter[idx] 462} 463 464func (s *fakeStmt) Close() error { 465 if s.c == nil { 466 panic("nil conn in fakeStmt.Close") 467 } 468 if s.c.db == nil { 469 panic("in fakeStmt.Close, conn's db is nil (already closed)") 470 } 471 if !s.closed { 472 s.c.incrStat(&s.c.stmtsClosed) 473 s.closed = true 474 } 475 return nil 476} 477 478var errClosed = errors.New("fakedb: statement has been closed") 479 480func (s *fakeStmt) Exec(args []driver.Value) (driver.Result, error) { 481 if s.closed { 482 return nil, errClosed 483 } 484 err := checkSubsetTypes(args) 485 if err != nil { 486 return nil, err 487 } 488 489 db := s.c.db 490 switch s.cmd { 491 case "WIPE": 492 db.wipe() 493 return driver.ResultNoRows, nil 494 case "CREATE": 495 if err := db.createTable(s.table, s.colName, s.colType); err != nil { 496 return nil, err 497 } 498 return driver.ResultNoRows, nil 499 case "INSERT": 500 return s.execInsert(args) 501 } 502 fmt.Printf("EXEC statement, cmd=%q: %#v\n", s.cmd, s) 503 return nil, fmt.Errorf("unimplemented statement Exec command type of %q", s.cmd) 504} 505 506func (s *fakeStmt) execInsert(args []driver.Value) (driver.Result, error) { 507 db := s.c.db 508 if len(args) != s.placeholders { 509 panic("error in pkg db; should only get here if size is correct") 510 } 511 db.mu.Lock() 512 t, ok := db.table(s.table) 513 db.mu.Unlock() 514 if !ok { 515 return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table) 516 } 517 518 t.mu.Lock() 519 defer t.mu.Unlock() 520 521 cols := make([]interface{}, len(t.colname)) 522 argPos := 0 523 for n, colname := range s.colName { 524 colidx := t.columnIndex(colname) 525 if colidx == -1 { 526 return nil, fmt.Errorf("fakedb: column %q doesn't exist or dropped since prepared statement was created", colname) 527 } 528 var val interface{} 529 if strvalue, ok := s.colValue[n].(string); ok && strvalue == "?" { 530 val = args[argPos] 531 argPos++ 532 } else { 533 val = s.colValue[n] 534 } 535 cols[colidx] = val 536 } 537 538 t.rows = append(t.rows, &row{cols: cols}) 539 return driver.RowsAffected(1), nil 540} 541 542func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) { 543 if s.closed { 544 return nil, errClosed 545 } 546 err := checkSubsetTypes(args) 547 if err != nil { 548 return nil, err 549 } 550 551 db := s.c.db 552 if len(args) != s.placeholders { 553 panic("error in pkg db; should only get here if size is correct") 554 } 555 556 db.mu.Lock() 557 t, ok := db.table(s.table) 558 db.mu.Unlock() 559 if !ok { 560 return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table) 561 } 562 563 if s.table == "magicquery" { 564 if len(s.whereCol) == 2 && s.whereCol[0] == "op" && s.whereCol[1] == "millis" { 565 if args[0] == "sleep" { 566 time.Sleep(time.Duration(args[1].(int64)) * time.Millisecond) 567 } 568 } 569 } 570 571 t.mu.Lock() 572 defer t.mu.Unlock() 573 574 colIdx := make(map[string]int) // select column name -> column index in table 575 for _, name := range s.colName { 576 idx := t.columnIndex(name) 577 if idx == -1 { 578 return nil, fmt.Errorf("fakedb: unknown column name %q", name) 579 } 580 colIdx[name] = idx 581 } 582 583 mrows := []*row{} 584rows: 585 for _, trow := range t.rows { 586 // Process the where clause, skipping non-match rows. This is lazy 587 // and just uses fmt.Sprintf("%v") to test equality. Good enough 588 // for test code. 589 for widx, wcol := range s.whereCol { 590 idx := t.columnIndex(wcol) 591 if idx == -1 { 592 return nil, fmt.Errorf("db: invalid where clause column %q", wcol) 593 } 594 tcol := trow.cols[idx] 595 if bs, ok := tcol.([]byte); ok { 596 // lazy hack to avoid sprintf %v on a []byte 597 tcol = string(bs) 598 } 599 if fmt.Sprintf("%v", tcol) != fmt.Sprintf("%v", args[widx]) { 600 continue rows 601 } 602 } 603 mrow := &row{cols: make([]interface{}, len(s.colName))} 604 for seli, name := range s.colName { 605 mrow.cols[seli] = trow.cols[colIdx[name]] 606 } 607 mrows = append(mrows, mrow) 608 } 609 610 cursor := &rowsCursor{ 611 pos: -1, 612 rows: mrows, 613 cols: s.colName, 614 } 615 return cursor, nil 616} 617 618func (s *fakeStmt) NumInput() int { 619 return s.placeholders 620} 621 622func (tx *fakeTx) Commit() error { 623 tx.c.currTx = nil 624 return nil 625} 626 627func (tx *fakeTx) Rollback() error { 628 tx.c.currTx = nil 629 return nil 630} 631 632type rowsCursor struct { 633 cols []string 634 pos int 635 rows []*row 636 closed bool 637 638 // a clone of slices to give out to clients, indexed by the 639 // the original slice's first byte address. we clone them 640 // just so we're able to corrupt them on close. 641 bytesClone map[*byte][]byte 642} 643 644func (rc *rowsCursor) Close() error { 645 if !rc.closed { 646 for _, bs := range rc.bytesClone { 647 bs[0] = 255 // first byte corrupted 648 } 649 } 650 rc.closed = true 651 return nil 652} 653 654func (rc *rowsCursor) Columns() []string { 655 return rc.cols 656} 657 658func (rc *rowsCursor) Next(dest []driver.Value) error { 659 if rc.closed { 660 return errors.New("fakedb: cursor is closed") 661 } 662 rc.pos++ 663 if rc.pos >= len(rc.rows) { 664 return io.EOF // per interface spec 665 } 666 for i, v := range rc.rows[rc.pos].cols { 667 // TODO(bradfitz): convert to subset types? naah, I 668 // think the subset types should only be input to 669 // driver, but the sql package should be able to handle 670 // a wider range of types coming out of drivers. all 671 // for ease of drivers, and to prevent drivers from 672 // messing up conversions or doing them differently. 673 dest[i] = v 674 675 if bs, ok := v.([]byte); ok { 676 if rc.bytesClone == nil { 677 rc.bytesClone = make(map[*byte][]byte) 678 } 679 clone, ok := rc.bytesClone[&bs[0]] 680 if !ok { 681 clone = make([]byte, len(bs)) 682 copy(clone, bs) 683 rc.bytesClone[&bs[0]] = clone 684 } 685 dest[i] = clone 686 } 687 } 688 return nil 689} 690 691// fakeDriverString is like driver.String, but indirects pointers like 692// DefaultValueConverter. 693// 694// This could be surprising behavior to retroactively apply to 695// driver.String now that Go1 is out, but this is convenient for 696// our TestPointerParamsAndScans. 697// 698type fakeDriverString struct{} 699 700func (fakeDriverString) ConvertValue(v interface{}) (driver.Value, error) { 701 switch c := v.(type) { 702 case string, []byte: 703 return v, nil 704 case *string: 705 if c == nil { 706 return nil, nil 707 } 708 return *c, nil 709 } 710 return fmt.Sprintf("%v", v), nil 711} 712 713func converterForType(typ string) driver.ValueConverter { 714 switch typ { 715 case "bool": 716 return driver.Bool 717 case "nullbool": 718 return driver.Null{Converter: driver.Bool} 719 case "int32": 720 return driver.Int32 721 case "string": 722 return driver.NotNull{Converter: fakeDriverString{}} 723 case "nullstring": 724 return driver.Null{Converter: fakeDriverString{}} 725 case "int64": 726 // TODO(coopernurse): add type-specific converter 727 return driver.NotNull{Converter: driver.DefaultParameterConverter} 728 case "nullint64": 729 // TODO(coopernurse): add type-specific converter 730 return driver.Null{Converter: driver.DefaultParameterConverter} 731 case "float64": 732 // TODO(coopernurse): add type-specific converter 733 return driver.NotNull{Converter: driver.DefaultParameterConverter} 734 case "nullfloat64": 735 // TODO(coopernurse): add type-specific converter 736 return driver.Null{Converter: driver.DefaultParameterConverter} 737 case "datetime": 738 return driver.DefaultParameterConverter 739 } 740 panic("invalid fakedb column type of " + typ) 741} 742