1package stdlib_test 2 3import ( 4 "bytes" 5 "context" 6 "database/sql" 7 "database/sql/driver" 8 "encoding/json" 9 "math" 10 "os" 11 "reflect" 12 "testing" 13 "time" 14 15 "github.com/jackc/pgconn" 16 "github.com/jackc/pgx/v4" 17 "github.com/jackc/pgx/v4/stdlib" 18 "github.com/stretchr/testify/assert" 19 "github.com/stretchr/testify/require" 20) 21 22func openDB(t testing.TB) *sql.DB { 23 config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) 24 require.NoError(t, err) 25 return stdlib.OpenDB(*config) 26} 27 28func closeDB(t testing.TB, db *sql.DB) { 29 err := db.Close() 30 require.NoError(t, err) 31} 32 33func testWithAndWithoutPreferSimpleProtocol(t *testing.T, f func(t *testing.T, db *sql.DB)) { 34 t.Run("SimpleProto", 35 func(t *testing.T) { 36 config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) 37 require.NoError(t, err) 38 39 config.PreferSimpleProtocol = true 40 db := stdlib.OpenDB(*config) 41 defer func() { 42 err := db.Close() 43 require.NoError(t, err) 44 }() 45 46 f(t, db) 47 48 ensureDBValid(t, db) 49 }, 50 ) 51 52 t.Run("DefaultProto", 53 func(t *testing.T) { 54 config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) 55 require.NoError(t, err) 56 57 db := stdlib.OpenDB(*config) 58 defer func() { 59 err := db.Close() 60 require.NoError(t, err) 61 }() 62 63 f(t, db) 64 65 ensureDBValid(t, db) 66 }, 67 ) 68} 69 70// Do a simple query to ensure the DB is still usable. This is of less use in stdlib as the connection pool should 71// cover an broken connections. 72func ensureDBValid(t testing.TB, db *sql.DB) { 73 var sum, rowCount int32 74 75 rows, err := db.Query("select generate_series(1,$1)", 10) 76 require.NoError(t, err) 77 defer rows.Close() 78 79 for rows.Next() { 80 var n int32 81 rows.Scan(&n) 82 sum += n 83 rowCount++ 84 } 85 86 require.NoError(t, rows.Err()) 87 88 if rowCount != 10 { 89 t.Error("Select called onDataRow wrong number of times") 90 } 91 if sum != 55 { 92 t.Error("Wrong values returned") 93 } 94} 95 96type preparer interface { 97 Prepare(query string) (*sql.Stmt, error) 98} 99 100func prepareStmt(t *testing.T, p preparer, sql string) *sql.Stmt { 101 stmt, err := p.Prepare(sql) 102 require.NoError(t, err) 103 return stmt 104} 105 106func closeStmt(t *testing.T, stmt *sql.Stmt) { 107 err := stmt.Close() 108 require.NoError(t, err) 109} 110 111func TestSQLOpen(t *testing.T) { 112 db, err := sql.Open("pgx", os.Getenv("PGX_TEST_DATABASE")) 113 require.NoError(t, err) 114 closeDB(t, db) 115} 116 117func TestNormalLifeCycle(t *testing.T) { 118 db := openDB(t) 119 defer closeDB(t, db) 120 121 stmt := prepareStmt(t, db, "select 'foo', n from generate_series($1::int, $2::int) n") 122 defer closeStmt(t, stmt) 123 124 rows, err := stmt.Query(int32(1), int32(10)) 125 require.NoError(t, err) 126 127 rowCount := int64(0) 128 129 for rows.Next() { 130 rowCount++ 131 132 var s string 133 var n int64 134 err := rows.Scan(&s, &n) 135 require.NoError(t, err) 136 137 if s != "foo" { 138 t.Errorf(`Expected "foo", received "%v"`, s) 139 } 140 if n != rowCount { 141 t.Errorf("Expected %d, received %d", rowCount, n) 142 } 143 } 144 require.NoError(t, rows.Err()) 145 146 require.EqualValues(t, 10, rowCount) 147 148 err = rows.Close() 149 require.NoError(t, err) 150 151 ensureDBValid(t, db) 152} 153 154func TestStmtExec(t *testing.T) { 155 db := openDB(t) 156 defer closeDB(t, db) 157 158 tx, err := db.Begin() 159 require.NoError(t, err) 160 161 createStmt := prepareStmt(t, tx, "create temporary table t(a varchar not null)") 162 _, err = createStmt.Exec() 163 require.NoError(t, err) 164 closeStmt(t, createStmt) 165 166 insertStmt := prepareStmt(t, tx, "insert into t values($1::text)") 167 result, err := insertStmt.Exec("foo") 168 require.NoError(t, err) 169 170 n, err := result.RowsAffected() 171 require.NoError(t, err) 172 require.EqualValues(t, 1, n) 173 closeStmt(t, insertStmt) 174 175 ensureDBValid(t, db) 176} 177 178func TestQueryCloseRowsEarly(t *testing.T) { 179 db := openDB(t) 180 defer closeDB(t, db) 181 182 stmt := prepareStmt(t, db, "select 'foo', n from generate_series($1::int, $2::int) n") 183 defer closeStmt(t, stmt) 184 185 rows, err := stmt.Query(int32(1), int32(10)) 186 require.NoError(t, err) 187 188 // Close rows immediately without having read them 189 err = rows.Close() 190 require.NoError(t, err) 191 192 // Run the query again to ensure the connection and statement are still ok 193 rows, err = stmt.Query(int32(1), int32(10)) 194 require.NoError(t, err) 195 196 rowCount := int64(0) 197 198 for rows.Next() { 199 rowCount++ 200 201 var s string 202 var n int64 203 err := rows.Scan(&s, &n) 204 require.NoError(t, err) 205 if s != "foo" { 206 t.Errorf(`Expected "foo", received "%v"`, s) 207 } 208 if n != rowCount { 209 t.Errorf("Expected %d, received %d", rowCount, n) 210 } 211 } 212 require.NoError(t, rows.Err()) 213 require.EqualValues(t, 10, rowCount) 214 215 err = rows.Close() 216 require.NoError(t, err) 217 218 ensureDBValid(t, db) 219} 220 221func TestConnExec(t *testing.T) { 222 testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { 223 _, err := db.Exec("create temporary table t(a varchar not null)") 224 require.NoError(t, err) 225 226 result, err := db.Exec("insert into t values('hey')") 227 require.NoError(t, err) 228 229 n, err := result.RowsAffected() 230 require.NoError(t, err) 231 require.EqualValues(t, 1, n) 232 }) 233} 234 235func TestConnQuery(t *testing.T) { 236 testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { 237 rows, err := db.Query("select 'foo', n from generate_series($1::int, $2::int) n", int32(1), int32(10)) 238 require.NoError(t, err) 239 240 rowCount := int64(0) 241 242 for rows.Next() { 243 rowCount++ 244 245 var s string 246 var n int64 247 err := rows.Scan(&s, &n) 248 require.NoError(t, err) 249 if s != "foo" { 250 t.Errorf(`Expected "foo", received "%v"`, s) 251 } 252 if n != rowCount { 253 t.Errorf("Expected %d, received %d", rowCount, n) 254 } 255 } 256 require.NoError(t, rows.Err()) 257 require.EqualValues(t, 10, rowCount) 258 259 err = rows.Close() 260 require.NoError(t, err) 261 }) 262} 263 264// https://github.com/jackc/pgx/issues/781 265func TestConnQueryDifferentScanPlansIssue781(t *testing.T) { 266 testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { 267 var s string 268 var b bool 269 270 rows, err := db.Query("select true, 'foo'") 271 require.NoError(t, err) 272 273 require.True(t, rows.Next()) 274 require.NoError(t, rows.Scan(&b, &s)) 275 assert.Equal(t, true, b) 276 assert.Equal(t, "foo", s) 277 }) 278} 279 280func TestConnQueryNull(t *testing.T) { 281 testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { 282 rows, err := db.Query("select $1::int", nil) 283 require.NoError(t, err) 284 285 rowCount := int64(0) 286 287 for rows.Next() { 288 rowCount++ 289 290 var n sql.NullInt64 291 err := rows.Scan(&n) 292 require.NoError(t, err) 293 if n.Valid != false { 294 t.Errorf("Expected n to be null, but it was %v", n) 295 } 296 } 297 require.NoError(t, rows.Err()) 298 require.EqualValues(t, 1, rowCount) 299 300 err = rows.Close() 301 require.NoError(t, err) 302 }) 303} 304 305func TestConnQueryRowByteSlice(t *testing.T) { 306 testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { 307 expected := []byte{222, 173, 190, 239} 308 var actual []byte 309 310 err := db.QueryRow(`select E'\\xdeadbeef'::bytea`).Scan(&actual) 311 require.NoError(t, err) 312 require.EqualValues(t, expected, actual) 313 }) 314} 315 316func TestConnQueryFailure(t *testing.T) { 317 testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { 318 _, err := db.Query("select 'foo") 319 require.Error(t, err) 320 require.IsType(t, new(pgconn.PgError), err) 321 }) 322} 323 324func TestConnSimpleSlicePassThrough(t *testing.T) { 325 testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { 326 var n int64 327 err := db.QueryRow("select cardinality($1::text[])", []string{"a", "b", "c"}).Scan(&n) 328 require.NoError(t, err) 329 assert.EqualValues(t, 3, n) 330 }) 331} 332 333// Test type that pgx would handle natively in binary, but since it is not a 334// database/sql native type should be passed through as a string 335func TestConnQueryRowPgxBinary(t *testing.T) { 336 testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { 337 sql := "select $1::int4[]" 338 expected := "{1,2,3}" 339 var actual string 340 341 err := db.QueryRow(sql, expected).Scan(&actual) 342 require.NoError(t, err) 343 require.EqualValues(t, expected, actual) 344 }) 345} 346 347func TestConnQueryRowUnknownType(t *testing.T) { 348 testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { 349 sql := "select $1::point" 350 expected := "(1,2)" 351 var actual string 352 353 err := db.QueryRow(sql, expected).Scan(&actual) 354 require.NoError(t, err) 355 require.EqualValues(t, expected, actual) 356 }) 357} 358 359func TestConnQueryJSONIntoByteSlice(t *testing.T) { 360 testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { 361 _, err := db.Exec(` 362 create temporary table docs( 363 body json not null 364 ); 365 366 insert into docs(body) values('{"foo":"bar"}'); 367`) 368 require.NoError(t, err) 369 370 sql := `select * from docs` 371 expected := []byte(`{"foo":"bar"}`) 372 var actual []byte 373 374 err = db.QueryRow(sql).Scan(&actual) 375 if err != nil { 376 t.Errorf("Unexpected failure: %v (sql -> %v)", err, sql) 377 } 378 379 if bytes.Compare(actual, expected) != 0 { 380 t.Errorf(`Expected "%v", got "%v" (sql -> %v)`, string(expected), string(actual), sql) 381 } 382 383 _, err = db.Exec(`drop table docs`) 384 require.NoError(t, err) 385 }) 386} 387 388func TestConnExecInsertByteSliceIntoJSON(t *testing.T) { 389 // Not testing with simple protocol because there is no way for that to work. A []byte will be considered binary data 390 // that needs to escape. No way to know whether the destination is really a text compatible or a bytea. 391 392 db := openDB(t) 393 defer closeDB(t, db) 394 395 _, err := db.Exec(` 396 create temporary table docs( 397 body json not null 398 ); 399`) 400 require.NoError(t, err) 401 402 expected := []byte(`{"foo":"bar"}`) 403 404 _, err = db.Exec(`insert into docs(body) values($1)`, expected) 405 require.NoError(t, err) 406 407 var actual []byte 408 err = db.QueryRow(`select body from docs`).Scan(&actual) 409 require.NoError(t, err) 410 411 if bytes.Compare(actual, expected) != 0 { 412 t.Errorf(`Expected "%v", got "%v"`, string(expected), string(actual)) 413 } 414 415 _, err = db.Exec(`drop table docs`) 416 require.NoError(t, err) 417} 418 419func TestTransactionLifeCycle(t *testing.T) { 420 testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { 421 _, err := db.Exec("create temporary table t(a varchar not null)") 422 require.NoError(t, err) 423 424 tx, err := db.Begin() 425 require.NoError(t, err) 426 427 _, err = tx.Exec("insert into t values('hi')") 428 require.NoError(t, err) 429 430 err = tx.Rollback() 431 require.NoError(t, err) 432 433 var n int64 434 err = db.QueryRow("select count(*) from t").Scan(&n) 435 require.NoError(t, err) 436 require.EqualValues(t, 0, n) 437 438 tx, err = db.Begin() 439 require.NoError(t, err) 440 441 _, err = tx.Exec("insert into t values('hi')") 442 require.NoError(t, err) 443 444 err = tx.Commit() 445 require.NoError(t, err) 446 447 err = db.QueryRow("select count(*) from t").Scan(&n) 448 require.NoError(t, err) 449 require.EqualValues(t, 1, n) 450 }) 451} 452 453func TestConnBeginTxIsolation(t *testing.T) { 454 testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { 455 var defaultIsoLevel string 456 err := db.QueryRow("show transaction_isolation").Scan(&defaultIsoLevel) 457 require.NoError(t, err) 458 459 supportedTests := []struct { 460 sqlIso sql.IsolationLevel 461 pgIso string 462 }{ 463 {sqlIso: sql.LevelDefault, pgIso: defaultIsoLevel}, 464 {sqlIso: sql.LevelReadUncommitted, pgIso: "read uncommitted"}, 465 {sqlIso: sql.LevelReadCommitted, pgIso: "read committed"}, 466 {sqlIso: sql.LevelRepeatableRead, pgIso: "repeatable read"}, 467 {sqlIso: sql.LevelSnapshot, pgIso: "repeatable read"}, 468 {sqlIso: sql.LevelSerializable, pgIso: "serializable"}, 469 } 470 for i, tt := range supportedTests { 471 func() { 472 tx, err := db.BeginTx(context.Background(), &sql.TxOptions{Isolation: tt.sqlIso}) 473 if err != nil { 474 t.Errorf("%d. BeginTx failed: %v", i, err) 475 return 476 } 477 defer tx.Rollback() 478 479 var pgIso string 480 err = tx.QueryRow("show transaction_isolation").Scan(&pgIso) 481 if err != nil { 482 t.Errorf("%d. QueryRow failed: %v", i, err) 483 } 484 485 if pgIso != tt.pgIso { 486 t.Errorf("%d. pgIso => %s, want %s", i, pgIso, tt.pgIso) 487 } 488 }() 489 } 490 491 unsupportedTests := []struct { 492 sqlIso sql.IsolationLevel 493 }{ 494 {sqlIso: sql.LevelWriteCommitted}, 495 {sqlIso: sql.LevelLinearizable}, 496 } 497 for i, tt := range unsupportedTests { 498 tx, err := db.BeginTx(context.Background(), &sql.TxOptions{Isolation: tt.sqlIso}) 499 if err == nil { 500 t.Errorf("%d. BeginTx should have failed", i) 501 tx.Rollback() 502 } 503 } 504 }) 505} 506 507func TestConnBeginTxReadOnly(t *testing.T) { 508 testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { 509 tx, err := db.BeginTx(context.Background(), &sql.TxOptions{ReadOnly: true}) 510 require.NoError(t, err) 511 defer tx.Rollback() 512 513 var pgReadOnly string 514 err = tx.QueryRow("show transaction_read_only").Scan(&pgReadOnly) 515 if err != nil { 516 t.Errorf("QueryRow failed: %v", err) 517 } 518 519 if pgReadOnly != "on" { 520 t.Errorf("pgReadOnly => %s, want %s", pgReadOnly, "on") 521 } 522 }) 523} 524 525func TestBeginTxContextCancel(t *testing.T) { 526 testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { 527 _, err := db.Exec("drop table if exists t") 528 require.NoError(t, err) 529 530 ctx, cancelFn := context.WithCancel(context.Background()) 531 532 tx, err := db.BeginTx(ctx, nil) 533 require.NoError(t, err) 534 535 _, err = tx.Exec("create table t(id serial)") 536 require.NoError(t, err) 537 538 cancelFn() 539 540 err = tx.Commit() 541 if err != context.Canceled && err != sql.ErrTxDone { 542 t.Fatalf("err => %v, want %v or %v", err, context.Canceled, sql.ErrTxDone) 543 } 544 545 var n int 546 err = db.QueryRow("select count(*) from t").Scan(&n) 547 if pgErr, ok := err.(*pgconn.PgError); !ok || pgErr.Code != "42P01" { 548 t.Fatalf(`err => %v, want PgError{Code: "42P01"}`, err) 549 } 550 }) 551} 552 553func TestAcquireConn(t *testing.T) { 554 testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { 555 var conns []*pgx.Conn 556 557 for i := 1; i < 6; i++ { 558 conn, err := stdlib.AcquireConn(db) 559 if err != nil { 560 t.Errorf("%d. AcquireConn failed: %v", i, err) 561 continue 562 } 563 564 var n int32 565 err = conn.QueryRow(context.Background(), "select 1").Scan(&n) 566 if err != nil { 567 t.Errorf("%d. QueryRow failed: %v", i, err) 568 } 569 if n != 1 { 570 t.Errorf("%d. n => %d, want %d", i, n, 1) 571 } 572 573 stats := db.Stats() 574 if stats.OpenConnections != i { 575 t.Errorf("%d. stats.OpenConnections => %d, want %d", i, stats.OpenConnections, i) 576 } 577 578 conns = append(conns, conn) 579 } 580 581 for i, conn := range conns { 582 if err := stdlib.ReleaseConn(db, conn); err != nil { 583 t.Errorf("%d. stdlib.ReleaseConn failed: %v", i, err) 584 } 585 } 586 }) 587} 588 589func TestConnRaw(t *testing.T) { 590 testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { 591 conn, err := db.Conn(context.Background()) 592 require.NoError(t, err) 593 594 var n int 595 err = conn.Raw(func(driverConn interface{}) error { 596 conn := driverConn.(*stdlib.Conn).Conn() 597 return conn.QueryRow(context.Background(), "select 42").Scan(&n) 598 }) 599 require.NoError(t, err) 600 assert.EqualValues(t, 42, n) 601 }) 602} 603 604// https://github.com/jackc/pgx/issues/673 605func TestReleaseConnWithTxInProgress(t *testing.T) { 606 testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { 607 c1, err := stdlib.AcquireConn(db) 608 require.NoError(t, err) 609 610 _, err = c1.Exec(context.Background(), "begin") 611 require.NoError(t, err) 612 613 c1PID := c1.PgConn().PID() 614 615 err = stdlib.ReleaseConn(db, c1) 616 require.NoError(t, err) 617 618 c2, err := stdlib.AcquireConn(db) 619 require.NoError(t, err) 620 621 c2PID := c2.PgConn().PID() 622 623 err = stdlib.ReleaseConn(db, c2) 624 require.NoError(t, err) 625 626 require.NotEqual(t, c1PID, c2PID) 627 628 // Releasing a conn with a tx in progress should close the connection 629 stats := db.Stats() 630 require.Equal(t, 1, stats.OpenConnections) 631 }) 632} 633 634func TestConnPingContextSuccess(t *testing.T) { 635 testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { 636 err := db.PingContext(context.Background()) 637 require.NoError(t, err) 638 }) 639} 640 641func TestConnPrepareContextSuccess(t *testing.T) { 642 testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { 643 stmt, err := db.PrepareContext(context.Background(), "select now()") 644 require.NoError(t, err) 645 err = stmt.Close() 646 require.NoError(t, err) 647 }) 648} 649 650func TestConnExecContextSuccess(t *testing.T) { 651 testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { 652 _, err := db.ExecContext(context.Background(), "create temporary table exec_context_test(id serial primary key)") 653 require.NoError(t, err) 654 }) 655} 656 657func TestConnExecContextFailureRetry(t *testing.T) { 658 testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { 659 // we get a connection, immediately close it, and then get it back 660 { 661 conn, err := stdlib.AcquireConn(db) 662 require.NoError(t, err) 663 conn.Close(context.Background()) 664 stdlib.ReleaseConn(db, conn) 665 } 666 conn, err := db.Conn(context.Background()) 667 require.NoError(t, err) 668 _, err = conn.ExecContext(context.Background(), "select 1") 669 require.EqualValues(t, driver.ErrBadConn, err) 670 }) 671} 672 673func TestConnQueryContextSuccess(t *testing.T) { 674 testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { 675 rows, err := db.QueryContext(context.Background(), "select * from generate_series(1,10) n") 676 require.NoError(t, err) 677 678 for rows.Next() { 679 var n int64 680 err := rows.Scan(&n) 681 require.NoError(t, err) 682 } 683 require.NoError(t, rows.Err()) 684 }) 685} 686 687func TestConnQueryContextFailureRetry(t *testing.T) { 688 testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { 689 // we get a connection, immediately close it, and then get it back 690 { 691 conn, err := stdlib.AcquireConn(db) 692 require.NoError(t, err) 693 conn.Close(context.Background()) 694 stdlib.ReleaseConn(db, conn) 695 } 696 conn, err := db.Conn(context.Background()) 697 require.NoError(t, err) 698 699 _, err = conn.QueryContext(context.Background(), "select 1") 700 require.EqualValues(t, driver.ErrBadConn, err) 701 }) 702} 703 704func TestRowsColumnTypeDatabaseTypeName(t *testing.T) { 705 testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { 706 rows, err := db.Query("select * from generate_series(1,10) n") 707 require.NoError(t, err) 708 709 columnTypes, err := rows.ColumnTypes() 710 require.NoError(t, err) 711 require.Len(t, columnTypes, 1) 712 713 if columnTypes[0].DatabaseTypeName() != "INT4" { 714 t.Errorf("columnTypes[0].DatabaseTypeName() => %v, want %v", columnTypes[0].DatabaseTypeName(), "INT4") 715 } 716 717 err = rows.Close() 718 require.NoError(t, err) 719 }) 720} 721 722func TestStmtExecContextSuccess(t *testing.T) { 723 db := openDB(t) 724 defer closeDB(t, db) 725 726 _, err := db.Exec("create temporary table t(id int primary key)") 727 require.NoError(t, err) 728 729 stmt, err := db.Prepare("insert into t(id) values ($1::int4)") 730 require.NoError(t, err) 731 defer stmt.Close() 732 733 _, err = stmt.ExecContext(context.Background(), 42) 734 require.NoError(t, err) 735 736 ensureDBValid(t, db) 737} 738 739func TestStmtExecContextCancel(t *testing.T) { 740 db := openDB(t) 741 defer closeDB(t, db) 742 743 _, err := db.Exec("create temporary table t(id int primary key)") 744 require.NoError(t, err) 745 746 stmt, err := db.Prepare("insert into t(id) select $1::int4 from pg_sleep(5)") 747 require.NoError(t, err) 748 defer stmt.Close() 749 750 ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) 751 defer cancel() 752 753 _, err = stmt.ExecContext(ctx, 42) 754 if !pgconn.Timeout(err) { 755 t.Errorf("expected timeout error, got %v", err) 756 } 757 758 ensureDBValid(t, db) 759} 760 761func TestStmtQueryContextSuccess(t *testing.T) { 762 db := openDB(t) 763 defer closeDB(t, db) 764 765 stmt, err := db.Prepare("select * from generate_series(1,$1::int4) n") 766 require.NoError(t, err) 767 defer stmt.Close() 768 769 rows, err := stmt.QueryContext(context.Background(), 5) 770 require.NoError(t, err) 771 772 for rows.Next() { 773 var n int64 774 if err := rows.Scan(&n); err != nil { 775 t.Error(err) 776 } 777 } 778 779 if rows.Err() != nil { 780 t.Error(rows.Err()) 781 } 782 783 ensureDBValid(t, db) 784} 785 786func TestRowsColumnTypes(t *testing.T) { 787 testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { 788 columnTypesTests := []struct { 789 Name string 790 TypeName string 791 Length struct { 792 Len int64 793 OK bool 794 } 795 DecimalSize struct { 796 Precision int64 797 Scale int64 798 OK bool 799 } 800 ScanType reflect.Type 801 }{ 802 { 803 Name: "a", 804 TypeName: "INT4", 805 Length: struct { 806 Len int64 807 OK bool 808 }{ 809 Len: 0, 810 OK: false, 811 }, 812 DecimalSize: struct { 813 Precision int64 814 Scale int64 815 OK bool 816 }{ 817 Precision: 0, 818 Scale: 0, 819 OK: false, 820 }, 821 ScanType: reflect.TypeOf(int32(0)), 822 }, { 823 Name: "bar", 824 TypeName: "TEXT", 825 Length: struct { 826 Len int64 827 OK bool 828 }{ 829 Len: math.MaxInt64, 830 OK: true, 831 }, 832 DecimalSize: struct { 833 Precision int64 834 Scale int64 835 OK bool 836 }{ 837 Precision: 0, 838 Scale: 0, 839 OK: false, 840 }, 841 ScanType: reflect.TypeOf(""), 842 }, { 843 Name: "dec", 844 TypeName: "NUMERIC", 845 Length: struct { 846 Len int64 847 OK bool 848 }{ 849 Len: 0, 850 OK: false, 851 }, 852 DecimalSize: struct { 853 Precision int64 854 Scale int64 855 OK bool 856 }{ 857 Precision: 9, 858 Scale: 2, 859 OK: true, 860 }, 861 ScanType: reflect.TypeOf(float64(0)), 862 }, { 863 Name: "d", 864 TypeName: "1266", 865 Length: struct { 866 Len int64 867 OK bool 868 }{ 869 Len: 0, 870 OK: false, 871 }, 872 DecimalSize: struct { 873 Precision int64 874 Scale int64 875 OK bool 876 }{ 877 Precision: 0, 878 Scale: 0, 879 OK: false, 880 }, 881 ScanType: reflect.TypeOf(""), 882 }, 883 } 884 885 rows, err := db.Query("SELECT 1 AS a, text 'bar' AS bar, 1.28::numeric(9, 2) AS dec, '12:00:00'::timetz as d") 886 require.NoError(t, err) 887 888 columns, err := rows.ColumnTypes() 889 require.NoError(t, err) 890 assert.Len(t, columns, 4) 891 892 for i, tt := range columnTypesTests { 893 c := columns[i] 894 if c.Name() != tt.Name { 895 t.Errorf("(%d) got: %s, want: %s", i, c.Name(), tt.Name) 896 } 897 if c.DatabaseTypeName() != tt.TypeName { 898 t.Errorf("(%d) got: %s, want: %s", i, c.DatabaseTypeName(), tt.TypeName) 899 } 900 l, ok := c.Length() 901 if l != tt.Length.Len { 902 t.Errorf("(%d) got: %d, want: %d", i, l, tt.Length.Len) 903 } 904 if ok != tt.Length.OK { 905 t.Errorf("(%d) got: %t, want: %t", i, ok, tt.Length.OK) 906 } 907 p, s, ok := c.DecimalSize() 908 if p != tt.DecimalSize.Precision { 909 t.Errorf("(%d) got: %d, want: %d", i, p, tt.DecimalSize.Precision) 910 } 911 if s != tt.DecimalSize.Scale { 912 t.Errorf("(%d) got: %d, want: %d", i, s, tt.DecimalSize.Scale) 913 } 914 if ok != tt.DecimalSize.OK { 915 t.Errorf("(%d) got: %t, want: %t", i, ok, tt.DecimalSize.OK) 916 } 917 if c.ScanType() != tt.ScanType { 918 t.Errorf("(%d) got: %v, want: %v", i, c.ScanType(), tt.ScanType) 919 } 920 } 921 }) 922} 923 924func TestQueryLifeCycle(t *testing.T) { 925 testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { 926 rows, err := db.Query("SELECT 'foo', n FROM generate_series($1::int, $2::int) n WHERE 3 = $3", 1, 10, 3) 927 require.NoError(t, err) 928 929 rowCount := int64(0) 930 931 for rows.Next() { 932 rowCount++ 933 var ( 934 s string 935 n int64 936 ) 937 938 err := rows.Scan(&s, &n) 939 require.NoError(t, err) 940 941 if s != "foo" { 942 t.Errorf(`Expected "foo", received "%v"`, s) 943 } 944 945 if n != rowCount { 946 t.Errorf("Expected %d, received %d", rowCount, n) 947 } 948 } 949 require.NoError(t, rows.Err()) 950 951 err = rows.Close() 952 require.NoError(t, err) 953 954 rows, err = db.Query("select 1 where false") 955 require.NoError(t, err) 956 957 rowCount = int64(0) 958 959 for rows.Next() { 960 rowCount++ 961 } 962 require.NoError(t, rows.Err()) 963 require.EqualValues(t, 0, rowCount) 964 965 err = rows.Close() 966 require.NoError(t, err) 967 }) 968} 969 970// https://github.com/jackc/pgx/issues/409 971func TestScanJSONIntoJSONRawMessage(t *testing.T) { 972 testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { 973 var msg json.RawMessage 974 975 err := db.QueryRow("select '{}'::json").Scan(&msg) 976 require.NoError(t, err) 977 require.EqualValues(t, []byte("{}"), []byte(msg)) 978 }) 979} 980 981type testLog struct { 982 lvl pgx.LogLevel 983 msg string 984 data map[string]interface{} 985} 986 987type testLogger struct { 988 logs []testLog 989} 990 991func (l *testLogger) Log(ctx context.Context, lvl pgx.LogLevel, msg string, data map[string]interface{}) { 992 l.logs = append(l.logs, testLog{lvl: lvl, msg: msg, data: data}) 993} 994 995func TestRegisterConnConfig(t *testing.T) { 996 connConfig, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) 997 require.NoError(t, err) 998 999 logger := &testLogger{} 1000 connConfig.Logger = logger 1001 1002 connStr := stdlib.RegisterConnConfig(connConfig) 1003 defer stdlib.UnregisterConnConfig(connStr) 1004 1005 db, err := sql.Open("pgx", connStr) 1006 require.NoError(t, err) 1007 defer closeDB(t, db) 1008 1009 var n int64 1010 err = db.QueryRow("select 1").Scan(&n) 1011 require.NoError(t, err) 1012 1013 l := logger.logs[len(logger.logs)-1] 1014 assert.Equal(t, "Query", l.msg) 1015 assert.Equal(t, "select 1", l.data["sql"]) 1016} 1017