1package stdlib_test
2
import (
	"bytes"
	"context"
	"database/sql"
	"database/sql/driver"
	"encoding/json"
	"errors"
	"math"
	"os"
	"reflect"
	"testing"
	"time"

	"github.com/jackc/pgconn"
	"github.com/jackc/pgx/v4"
	"github.com/jackc/pgx/v4/stdlib"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)
21
22func openDB(t testing.TB) *sql.DB {
23	config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
24	require.NoError(t, err)
25	return stdlib.OpenDB(*config)
26}
27
28func closeDB(t testing.TB, db *sql.DB) {
29	err := db.Close()
30	require.NoError(t, err)
31}
32
33func testWithAndWithoutPreferSimpleProtocol(t *testing.T, f func(t *testing.T, db *sql.DB)) {
34	t.Run("SimpleProto",
35		func(t *testing.T) {
36			config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
37			require.NoError(t, err)
38
39			config.PreferSimpleProtocol = true
40			db := stdlib.OpenDB(*config)
41			defer func() {
42				err := db.Close()
43				require.NoError(t, err)
44			}()
45
46			f(t, db)
47
48			ensureDBValid(t, db)
49		},
50	)
51
52	t.Run("DefaultProto",
53		func(t *testing.T) {
54			config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
55			require.NoError(t, err)
56
57			db := stdlib.OpenDB(*config)
58			defer func() {
59				err := db.Close()
60				require.NoError(t, err)
61			}()
62
63			f(t, db)
64
65			ensureDBValid(t, db)
66		},
67	)
68}
69
70// Do a simple query to ensure the DB is still usable. This is of less use in stdlib as the connection pool should
71// cover an broken connections.
72func ensureDBValid(t testing.TB, db *sql.DB) {
73	var sum, rowCount int32
74
75	rows, err := db.Query("select generate_series(1,$1)", 10)
76	require.NoError(t, err)
77	defer rows.Close()
78
79	for rows.Next() {
80		var n int32
81		rows.Scan(&n)
82		sum += n
83		rowCount++
84	}
85
86	require.NoError(t, rows.Err())
87
88	if rowCount != 10 {
89		t.Error("Select called onDataRow wrong number of times")
90	}
91	if sum != 55 {
92		t.Error("Wrong values returned")
93	}
94}
95
// preparer is the subset of the database/sql handle API needed to prepare a
// statement. In this file both *sql.DB and *sql.Tx are passed where a preparer
// is expected.
type preparer interface {
	Prepare(sqlText string) (*sql.Stmt, error)
}
99
100func prepareStmt(t *testing.T, p preparer, sql string) *sql.Stmt {
101	stmt, err := p.Prepare(sql)
102	require.NoError(t, err)
103	return stmt
104}
105
106func closeStmt(t *testing.T, stmt *sql.Stmt) {
107	err := stmt.Close()
108	require.NoError(t, err)
109}
110
111func TestSQLOpen(t *testing.T) {
112	db, err := sql.Open("pgx", os.Getenv("PGX_TEST_DATABASE"))
113	require.NoError(t, err)
114	closeDB(t, db)
115}
116
117func TestNormalLifeCycle(t *testing.T) {
118	db := openDB(t)
119	defer closeDB(t, db)
120
121	stmt := prepareStmt(t, db, "select 'foo', n from generate_series($1::int, $2::int) n")
122	defer closeStmt(t, stmt)
123
124	rows, err := stmt.Query(int32(1), int32(10))
125	require.NoError(t, err)
126
127	rowCount := int64(0)
128
129	for rows.Next() {
130		rowCount++
131
132		var s string
133		var n int64
134		err := rows.Scan(&s, &n)
135		require.NoError(t, err)
136
137		if s != "foo" {
138			t.Errorf(`Expected "foo", received "%v"`, s)
139		}
140		if n != rowCount {
141			t.Errorf("Expected %d, received %d", rowCount, n)
142		}
143	}
144	require.NoError(t, rows.Err())
145
146	require.EqualValues(t, 10, rowCount)
147
148	err = rows.Close()
149	require.NoError(t, err)
150
151	ensureDBValid(t, db)
152}
153
154func TestStmtExec(t *testing.T) {
155	db := openDB(t)
156	defer closeDB(t, db)
157
158	tx, err := db.Begin()
159	require.NoError(t, err)
160
161	createStmt := prepareStmt(t, tx, "create temporary table t(a varchar not null)")
162	_, err = createStmt.Exec()
163	require.NoError(t, err)
164	closeStmt(t, createStmt)
165
166	insertStmt := prepareStmt(t, tx, "insert into t values($1::text)")
167	result, err := insertStmt.Exec("foo")
168	require.NoError(t, err)
169
170	n, err := result.RowsAffected()
171	require.NoError(t, err)
172	require.EqualValues(t, 1, n)
173	closeStmt(t, insertStmt)
174
175	ensureDBValid(t, db)
176}
177
178func TestQueryCloseRowsEarly(t *testing.T) {
179	db := openDB(t)
180	defer closeDB(t, db)
181
182	stmt := prepareStmt(t, db, "select 'foo', n from generate_series($1::int, $2::int) n")
183	defer closeStmt(t, stmt)
184
185	rows, err := stmt.Query(int32(1), int32(10))
186	require.NoError(t, err)
187
188	// Close rows immediately without having read them
189	err = rows.Close()
190	require.NoError(t, err)
191
192	// Run the query again to ensure the connection and statement are still ok
193	rows, err = stmt.Query(int32(1), int32(10))
194	require.NoError(t, err)
195
196	rowCount := int64(0)
197
198	for rows.Next() {
199		rowCount++
200
201		var s string
202		var n int64
203		err := rows.Scan(&s, &n)
204		require.NoError(t, err)
205		if s != "foo" {
206			t.Errorf(`Expected "foo", received "%v"`, s)
207		}
208		if n != rowCount {
209			t.Errorf("Expected %d, received %d", rowCount, n)
210		}
211	}
212	require.NoError(t, rows.Err())
213	require.EqualValues(t, 10, rowCount)
214
215	err = rows.Close()
216	require.NoError(t, err)
217
218	ensureDBValid(t, db)
219}
220
221func TestConnExec(t *testing.T) {
222	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
223		_, err := db.Exec("create temporary table t(a varchar not null)")
224		require.NoError(t, err)
225
226		result, err := db.Exec("insert into t values('hey')")
227		require.NoError(t, err)
228
229		n, err := result.RowsAffected()
230		require.NoError(t, err)
231		require.EqualValues(t, 1, n)
232	})
233}
234
235func TestConnQuery(t *testing.T) {
236	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
237		rows, err := db.Query("select 'foo', n from generate_series($1::int, $2::int) n", int32(1), int32(10))
238		require.NoError(t, err)
239
240		rowCount := int64(0)
241
242		for rows.Next() {
243			rowCount++
244
245			var s string
246			var n int64
247			err := rows.Scan(&s, &n)
248			require.NoError(t, err)
249			if s != "foo" {
250				t.Errorf(`Expected "foo", received "%v"`, s)
251			}
252			if n != rowCount {
253				t.Errorf("Expected %d, received %d", rowCount, n)
254			}
255		}
256		require.NoError(t, rows.Err())
257		require.EqualValues(t, 10, rowCount)
258
259		err = rows.Close()
260		require.NoError(t, err)
261	})
262}
263
264// https://github.com/jackc/pgx/issues/781
265func TestConnQueryDifferentScanPlansIssue781(t *testing.T) {
266	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
267		var s string
268		var b bool
269
270		rows, err := db.Query("select true, 'foo'")
271		require.NoError(t, err)
272
273		require.True(t, rows.Next())
274		require.NoError(t, rows.Scan(&b, &s))
275		assert.Equal(t, true, b)
276		assert.Equal(t, "foo", s)
277	})
278}
279
280func TestConnQueryNull(t *testing.T) {
281	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
282		rows, err := db.Query("select $1::int", nil)
283		require.NoError(t, err)
284
285		rowCount := int64(0)
286
287		for rows.Next() {
288			rowCount++
289
290			var n sql.NullInt64
291			err := rows.Scan(&n)
292			require.NoError(t, err)
293			if n.Valid != false {
294				t.Errorf("Expected n to be null, but it was %v", n)
295			}
296		}
297		require.NoError(t, rows.Err())
298		require.EqualValues(t, 1, rowCount)
299
300		err = rows.Close()
301		require.NoError(t, err)
302	})
303}
304
305func TestConnQueryRowByteSlice(t *testing.T) {
306	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
307		expected := []byte{222, 173, 190, 239}
308		var actual []byte
309
310		err := db.QueryRow(`select E'\\xdeadbeef'::bytea`).Scan(&actual)
311		require.NoError(t, err)
312		require.EqualValues(t, expected, actual)
313	})
314}
315
316func TestConnQueryFailure(t *testing.T) {
317	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
318		_, err := db.Query("select 'foo")
319		require.Error(t, err)
320		require.IsType(t, new(pgconn.PgError), err)
321	})
322}
323
324func TestConnSimpleSlicePassThrough(t *testing.T) {
325	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
326		var n int64
327		err := db.QueryRow("select cardinality($1::text[])", []string{"a", "b", "c"}).Scan(&n)
328		require.NoError(t, err)
329		assert.EqualValues(t, 3, n)
330	})
331}
332
333// Test type that pgx would handle natively in binary, but since it is not a
334// database/sql native type should be passed through as a string
335func TestConnQueryRowPgxBinary(t *testing.T) {
336	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
337		sql := "select $1::int4[]"
338		expected := "{1,2,3}"
339		var actual string
340
341		err := db.QueryRow(sql, expected).Scan(&actual)
342		require.NoError(t, err)
343		require.EqualValues(t, expected, actual)
344	})
345}
346
347func TestConnQueryRowUnknownType(t *testing.T) {
348	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
349		sql := "select $1::point"
350		expected := "(1,2)"
351		var actual string
352
353		err := db.QueryRow(sql, expected).Scan(&actual)
354		require.NoError(t, err)
355		require.EqualValues(t, expected, actual)
356	})
357}
358
359func TestConnQueryJSONIntoByteSlice(t *testing.T) {
360	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
361		_, err := db.Exec(`
362		create temporary table docs(
363			body json not null
364		);
365
366		insert into docs(body) values('{"foo":"bar"}');
367`)
368		require.NoError(t, err)
369
370		sql := `select * from docs`
371		expected := []byte(`{"foo":"bar"}`)
372		var actual []byte
373
374		err = db.QueryRow(sql).Scan(&actual)
375		if err != nil {
376			t.Errorf("Unexpected failure: %v (sql -> %v)", err, sql)
377		}
378
379		if bytes.Compare(actual, expected) != 0 {
380			t.Errorf(`Expected "%v", got "%v" (sql -> %v)`, string(expected), string(actual), sql)
381		}
382
383		_, err = db.Exec(`drop table docs`)
384		require.NoError(t, err)
385	})
386}
387
388func TestConnExecInsertByteSliceIntoJSON(t *testing.T) {
389	// Not testing with simple protocol because there is no way for that to work. A []byte will be considered binary data
390	// that needs to escape. No way to know whether the destination is really a text compatible or a bytea.
391
392	db := openDB(t)
393	defer closeDB(t, db)
394
395	_, err := db.Exec(`
396		create temporary table docs(
397			body json not null
398		);
399`)
400	require.NoError(t, err)
401
402	expected := []byte(`{"foo":"bar"}`)
403
404	_, err = db.Exec(`insert into docs(body) values($1)`, expected)
405	require.NoError(t, err)
406
407	var actual []byte
408	err = db.QueryRow(`select body from docs`).Scan(&actual)
409	require.NoError(t, err)
410
411	if bytes.Compare(actual, expected) != 0 {
412		t.Errorf(`Expected "%v", got "%v"`, string(expected), string(actual))
413	}
414
415	_, err = db.Exec(`drop table docs`)
416	require.NoError(t, err)
417}
418
419func TestTransactionLifeCycle(t *testing.T) {
420	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
421		_, err := db.Exec("create temporary table t(a varchar not null)")
422		require.NoError(t, err)
423
424		tx, err := db.Begin()
425		require.NoError(t, err)
426
427		_, err = tx.Exec("insert into t values('hi')")
428		require.NoError(t, err)
429
430		err = tx.Rollback()
431		require.NoError(t, err)
432
433		var n int64
434		err = db.QueryRow("select count(*) from t").Scan(&n)
435		require.NoError(t, err)
436		require.EqualValues(t, 0, n)
437
438		tx, err = db.Begin()
439		require.NoError(t, err)
440
441		_, err = tx.Exec("insert into t values('hi')")
442		require.NoError(t, err)
443
444		err = tx.Commit()
445		require.NoError(t, err)
446
447		err = db.QueryRow("select count(*) from t").Scan(&n)
448		require.NoError(t, err)
449		require.EqualValues(t, 1, n)
450	})
451}
452
453func TestConnBeginTxIsolation(t *testing.T) {
454	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
455		var defaultIsoLevel string
456		err := db.QueryRow("show transaction_isolation").Scan(&defaultIsoLevel)
457		require.NoError(t, err)
458
459		supportedTests := []struct {
460			sqlIso sql.IsolationLevel
461			pgIso  string
462		}{
463			{sqlIso: sql.LevelDefault, pgIso: defaultIsoLevel},
464			{sqlIso: sql.LevelReadUncommitted, pgIso: "read uncommitted"},
465			{sqlIso: sql.LevelReadCommitted, pgIso: "read committed"},
466			{sqlIso: sql.LevelRepeatableRead, pgIso: "repeatable read"},
467			{sqlIso: sql.LevelSnapshot, pgIso: "repeatable read"},
468			{sqlIso: sql.LevelSerializable, pgIso: "serializable"},
469		}
470		for i, tt := range supportedTests {
471			func() {
472				tx, err := db.BeginTx(context.Background(), &sql.TxOptions{Isolation: tt.sqlIso})
473				if err != nil {
474					t.Errorf("%d. BeginTx failed: %v", i, err)
475					return
476				}
477				defer tx.Rollback()
478
479				var pgIso string
480				err = tx.QueryRow("show transaction_isolation").Scan(&pgIso)
481				if err != nil {
482					t.Errorf("%d. QueryRow failed: %v", i, err)
483				}
484
485				if pgIso != tt.pgIso {
486					t.Errorf("%d. pgIso => %s, want %s", i, pgIso, tt.pgIso)
487				}
488			}()
489		}
490
491		unsupportedTests := []struct {
492			sqlIso sql.IsolationLevel
493		}{
494			{sqlIso: sql.LevelWriteCommitted},
495			{sqlIso: sql.LevelLinearizable},
496		}
497		for i, tt := range unsupportedTests {
498			tx, err := db.BeginTx(context.Background(), &sql.TxOptions{Isolation: tt.sqlIso})
499			if err == nil {
500				t.Errorf("%d. BeginTx should have failed", i)
501				tx.Rollback()
502			}
503		}
504	})
505}
506
507func TestConnBeginTxReadOnly(t *testing.T) {
508	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
509		tx, err := db.BeginTx(context.Background(), &sql.TxOptions{ReadOnly: true})
510		require.NoError(t, err)
511		defer tx.Rollback()
512
513		var pgReadOnly string
514		err = tx.QueryRow("show transaction_read_only").Scan(&pgReadOnly)
515		if err != nil {
516			t.Errorf("QueryRow failed: %v", err)
517		}
518
519		if pgReadOnly != "on" {
520			t.Errorf("pgReadOnly => %s, want %s", pgReadOnly, "on")
521		}
522	})
523}
524
525func TestBeginTxContextCancel(t *testing.T) {
526	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
527		_, err := db.Exec("drop table if exists t")
528		require.NoError(t, err)
529
530		ctx, cancelFn := context.WithCancel(context.Background())
531
532		tx, err := db.BeginTx(ctx, nil)
533		require.NoError(t, err)
534
535		_, err = tx.Exec("create table t(id serial)")
536		require.NoError(t, err)
537
538		cancelFn()
539
540		err = tx.Commit()
541		if err != context.Canceled && err != sql.ErrTxDone {
542			t.Fatalf("err => %v, want %v or %v", err, context.Canceled, sql.ErrTxDone)
543		}
544
545		var n int
546		err = db.QueryRow("select count(*) from t").Scan(&n)
547		if pgErr, ok := err.(*pgconn.PgError); !ok || pgErr.Code != "42P01" {
548			t.Fatalf(`err => %v, want PgError{Code: "42P01"}`, err)
549		}
550	})
551}
552
553func TestAcquireConn(t *testing.T) {
554	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
555		var conns []*pgx.Conn
556
557		for i := 1; i < 6; i++ {
558			conn, err := stdlib.AcquireConn(db)
559			if err != nil {
560				t.Errorf("%d. AcquireConn failed: %v", i, err)
561				continue
562			}
563
564			var n int32
565			err = conn.QueryRow(context.Background(), "select 1").Scan(&n)
566			if err != nil {
567				t.Errorf("%d. QueryRow failed: %v", i, err)
568			}
569			if n != 1 {
570				t.Errorf("%d. n => %d, want %d", i, n, 1)
571			}
572
573			stats := db.Stats()
574			if stats.OpenConnections != i {
575				t.Errorf("%d. stats.OpenConnections => %d, want %d", i, stats.OpenConnections, i)
576			}
577
578			conns = append(conns, conn)
579		}
580
581		for i, conn := range conns {
582			if err := stdlib.ReleaseConn(db, conn); err != nil {
583				t.Errorf("%d. stdlib.ReleaseConn failed: %v", i, err)
584			}
585		}
586	})
587}
588
589func TestConnRaw(t *testing.T) {
590	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
591		conn, err := db.Conn(context.Background())
592		require.NoError(t, err)
593
594		var n int
595		err = conn.Raw(func(driverConn interface{}) error {
596			conn := driverConn.(*stdlib.Conn).Conn()
597			return conn.QueryRow(context.Background(), "select 42").Scan(&n)
598		})
599		require.NoError(t, err)
600		assert.EqualValues(t, 42, n)
601	})
602}
603
604// https://github.com/jackc/pgx/issues/673
605func TestReleaseConnWithTxInProgress(t *testing.T) {
606	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
607		c1, err := stdlib.AcquireConn(db)
608		require.NoError(t, err)
609
610		_, err = c1.Exec(context.Background(), "begin")
611		require.NoError(t, err)
612
613		c1PID := c1.PgConn().PID()
614
615		err = stdlib.ReleaseConn(db, c1)
616		require.NoError(t, err)
617
618		c2, err := stdlib.AcquireConn(db)
619		require.NoError(t, err)
620
621		c2PID := c2.PgConn().PID()
622
623		err = stdlib.ReleaseConn(db, c2)
624		require.NoError(t, err)
625
626		require.NotEqual(t, c1PID, c2PID)
627
628		// Releasing a conn with a tx in progress should close the connection
629		stats := db.Stats()
630		require.Equal(t, 1, stats.OpenConnections)
631	})
632}
633
634func TestConnPingContextSuccess(t *testing.T) {
635	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
636		err := db.PingContext(context.Background())
637		require.NoError(t, err)
638	})
639}
640
641func TestConnPrepareContextSuccess(t *testing.T) {
642	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
643		stmt, err := db.PrepareContext(context.Background(), "select now()")
644		require.NoError(t, err)
645		err = stmt.Close()
646		require.NoError(t, err)
647	})
648}
649
650func TestConnExecContextSuccess(t *testing.T) {
651	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
652		_, err := db.ExecContext(context.Background(), "create temporary table exec_context_test(id serial primary key)")
653		require.NoError(t, err)
654	})
655}
656
657func TestConnExecContextFailureRetry(t *testing.T) {
658	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
659		// we get a connection, immediately close it, and then get it back
660		{
661			conn, err := stdlib.AcquireConn(db)
662			require.NoError(t, err)
663			conn.Close(context.Background())
664			stdlib.ReleaseConn(db, conn)
665		}
666		conn, err := db.Conn(context.Background())
667		require.NoError(t, err)
668		_, err = conn.ExecContext(context.Background(), "select 1")
669		require.EqualValues(t, driver.ErrBadConn, err)
670	})
671}
672
673func TestConnQueryContextSuccess(t *testing.T) {
674	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
675		rows, err := db.QueryContext(context.Background(), "select * from generate_series(1,10) n")
676		require.NoError(t, err)
677
678		for rows.Next() {
679			var n int64
680			err := rows.Scan(&n)
681			require.NoError(t, err)
682		}
683		require.NoError(t, rows.Err())
684	})
685}
686
687func TestConnQueryContextFailureRetry(t *testing.T) {
688	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
689		// we get a connection, immediately close it, and then get it back
690		{
691			conn, err := stdlib.AcquireConn(db)
692			require.NoError(t, err)
693			conn.Close(context.Background())
694			stdlib.ReleaseConn(db, conn)
695		}
696		conn, err := db.Conn(context.Background())
697		require.NoError(t, err)
698
699		_, err = conn.QueryContext(context.Background(), "select 1")
700		require.EqualValues(t, driver.ErrBadConn, err)
701	})
702}
703
704func TestRowsColumnTypeDatabaseTypeName(t *testing.T) {
705	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
706		rows, err := db.Query("select * from generate_series(1,10) n")
707		require.NoError(t, err)
708
709		columnTypes, err := rows.ColumnTypes()
710		require.NoError(t, err)
711		require.Len(t, columnTypes, 1)
712
713		if columnTypes[0].DatabaseTypeName() != "INT4" {
714			t.Errorf("columnTypes[0].DatabaseTypeName() => %v, want %v", columnTypes[0].DatabaseTypeName(), "INT4")
715		}
716
717		err = rows.Close()
718		require.NoError(t, err)
719	})
720}
721
722func TestStmtExecContextSuccess(t *testing.T) {
723	db := openDB(t)
724	defer closeDB(t, db)
725
726	_, err := db.Exec("create temporary table t(id int primary key)")
727	require.NoError(t, err)
728
729	stmt, err := db.Prepare("insert into t(id) values ($1::int4)")
730	require.NoError(t, err)
731	defer stmt.Close()
732
733	_, err = stmt.ExecContext(context.Background(), 42)
734	require.NoError(t, err)
735
736	ensureDBValid(t, db)
737}
738
739func TestStmtExecContextCancel(t *testing.T) {
740	db := openDB(t)
741	defer closeDB(t, db)
742
743	_, err := db.Exec("create temporary table t(id int primary key)")
744	require.NoError(t, err)
745
746	stmt, err := db.Prepare("insert into t(id) select $1::int4 from pg_sleep(5)")
747	require.NoError(t, err)
748	defer stmt.Close()
749
750	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
751	defer cancel()
752
753	_, err = stmt.ExecContext(ctx, 42)
754	if !pgconn.Timeout(err) {
755		t.Errorf("expected timeout error, got %v", err)
756	}
757
758	ensureDBValid(t, db)
759}
760
761func TestStmtQueryContextSuccess(t *testing.T) {
762	db := openDB(t)
763	defer closeDB(t, db)
764
765	stmt, err := db.Prepare("select * from generate_series(1,$1::int4) n")
766	require.NoError(t, err)
767	defer stmt.Close()
768
769	rows, err := stmt.QueryContext(context.Background(), 5)
770	require.NoError(t, err)
771
772	for rows.Next() {
773		var n int64
774		if err := rows.Scan(&n); err != nil {
775			t.Error(err)
776		}
777	}
778
779	if rows.Err() != nil {
780		t.Error(rows.Err())
781	}
782
783	ensureDBValid(t, db)
784}
785
786func TestRowsColumnTypes(t *testing.T) {
787	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
788		columnTypesTests := []struct {
789			Name     string
790			TypeName string
791			Length   struct {
792				Len int64
793				OK  bool
794			}
795			DecimalSize struct {
796				Precision int64
797				Scale     int64
798				OK        bool
799			}
800			ScanType reflect.Type
801		}{
802			{
803				Name:     "a",
804				TypeName: "INT4",
805				Length: struct {
806					Len int64
807					OK  bool
808				}{
809					Len: 0,
810					OK:  false,
811				},
812				DecimalSize: struct {
813					Precision int64
814					Scale     int64
815					OK        bool
816				}{
817					Precision: 0,
818					Scale:     0,
819					OK:        false,
820				},
821				ScanType: reflect.TypeOf(int32(0)),
822			}, {
823				Name:     "bar",
824				TypeName: "TEXT",
825				Length: struct {
826					Len int64
827					OK  bool
828				}{
829					Len: math.MaxInt64,
830					OK:  true,
831				},
832				DecimalSize: struct {
833					Precision int64
834					Scale     int64
835					OK        bool
836				}{
837					Precision: 0,
838					Scale:     0,
839					OK:        false,
840				},
841				ScanType: reflect.TypeOf(""),
842			}, {
843				Name:     "dec",
844				TypeName: "NUMERIC",
845				Length: struct {
846					Len int64
847					OK  bool
848				}{
849					Len: 0,
850					OK:  false,
851				},
852				DecimalSize: struct {
853					Precision int64
854					Scale     int64
855					OK        bool
856				}{
857					Precision: 9,
858					Scale:     2,
859					OK:        true,
860				},
861				ScanType: reflect.TypeOf(float64(0)),
862			}, {
863				Name:     "d",
864				TypeName: "1266",
865				Length: struct {
866					Len int64
867					OK  bool
868				}{
869					Len: 0,
870					OK:  false,
871				},
872				DecimalSize: struct {
873					Precision int64
874					Scale     int64
875					OK        bool
876				}{
877					Precision: 0,
878					Scale:     0,
879					OK:        false,
880				},
881				ScanType: reflect.TypeOf(""),
882			},
883		}
884
885		rows, err := db.Query("SELECT 1 AS a, text 'bar' AS bar, 1.28::numeric(9, 2) AS dec, '12:00:00'::timetz as d")
886		require.NoError(t, err)
887
888		columns, err := rows.ColumnTypes()
889		require.NoError(t, err)
890		assert.Len(t, columns, 4)
891
892		for i, tt := range columnTypesTests {
893			c := columns[i]
894			if c.Name() != tt.Name {
895				t.Errorf("(%d) got: %s, want: %s", i, c.Name(), tt.Name)
896			}
897			if c.DatabaseTypeName() != tt.TypeName {
898				t.Errorf("(%d) got: %s, want: %s", i, c.DatabaseTypeName(), tt.TypeName)
899			}
900			l, ok := c.Length()
901			if l != tt.Length.Len {
902				t.Errorf("(%d) got: %d, want: %d", i, l, tt.Length.Len)
903			}
904			if ok != tt.Length.OK {
905				t.Errorf("(%d) got: %t, want: %t", i, ok, tt.Length.OK)
906			}
907			p, s, ok := c.DecimalSize()
908			if p != tt.DecimalSize.Precision {
909				t.Errorf("(%d) got: %d, want: %d", i, p, tt.DecimalSize.Precision)
910			}
911			if s != tt.DecimalSize.Scale {
912				t.Errorf("(%d) got: %d, want: %d", i, s, tt.DecimalSize.Scale)
913			}
914			if ok != tt.DecimalSize.OK {
915				t.Errorf("(%d) got: %t, want: %t", i, ok, tt.DecimalSize.OK)
916			}
917			if c.ScanType() != tt.ScanType {
918				t.Errorf("(%d) got: %v, want: %v", i, c.ScanType(), tt.ScanType)
919			}
920		}
921	})
922}
923
924func TestQueryLifeCycle(t *testing.T) {
925	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
926		rows, err := db.Query("SELECT 'foo', n FROM generate_series($1::int, $2::int) n WHERE 3 = $3", 1, 10, 3)
927		require.NoError(t, err)
928
929		rowCount := int64(0)
930
931		for rows.Next() {
932			rowCount++
933			var (
934				s string
935				n int64
936			)
937
938			err := rows.Scan(&s, &n)
939			require.NoError(t, err)
940
941			if s != "foo" {
942				t.Errorf(`Expected "foo", received "%v"`, s)
943			}
944
945			if n != rowCount {
946				t.Errorf("Expected %d, received %d", rowCount, n)
947			}
948		}
949		require.NoError(t, rows.Err())
950
951		err = rows.Close()
952		require.NoError(t, err)
953
954		rows, err = db.Query("select 1 where false")
955		require.NoError(t, err)
956
957		rowCount = int64(0)
958
959		for rows.Next() {
960			rowCount++
961		}
962		require.NoError(t, rows.Err())
963		require.EqualValues(t, 0, rowCount)
964
965		err = rows.Close()
966		require.NoError(t, err)
967	})
968}
969
970// https://github.com/jackc/pgx/issues/409
971func TestScanJSONIntoJSONRawMessage(t *testing.T) {
972	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
973		var msg json.RawMessage
974
975		err := db.QueryRow("select '{}'::json").Scan(&msg)
976		require.NoError(t, err)
977		require.EqualValues(t, []byte("{}"), []byte(msg))
978	})
979}
980
981type testLog struct {
982	lvl  pgx.LogLevel
983	msg  string
984	data map[string]interface{}
985}
986
987type testLogger struct {
988	logs []testLog
989}
990
991func (l *testLogger) Log(ctx context.Context, lvl pgx.LogLevel, msg string, data map[string]interface{}) {
992	l.logs = append(l.logs, testLog{lvl: lvl, msg: msg, data: data})
993}
994
995func TestRegisterConnConfig(t *testing.T) {
996	connConfig, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
997	require.NoError(t, err)
998
999	logger := &testLogger{}
1000	connConfig.Logger = logger
1001
1002	connStr := stdlib.RegisterConnConfig(connConfig)
1003	defer stdlib.UnregisterConnConfig(connStr)
1004
1005	db, err := sql.Open("pgx", connStr)
1006	require.NoError(t, err)
1007	defer closeDB(t, db)
1008
1009	var n int64
1010	err = db.QueryRow("select 1").Scan(&n)
1011	require.NoError(t, err)
1012
1013	l := logger.logs[len(logger.logs)-1]
1014	assert.Equal(t, "Query", l.msg)
1015	assert.Equal(t, "select 1", l.data["sql"])
1016}
1017