1package pgx_test
2
3import (
4	"bytes"
5	"context"
6	"fmt"
7	"io"
8	"net"
9	"os"
10	"strconv"
11	"strings"
12	"testing"
13	"time"
14
15	"github.com/jackc/pgconn"
16	"github.com/jackc/pgconn/stmtcache"
17	"github.com/jackc/pgtype"
18	"github.com/jackc/pgx/v4"
19	"github.com/stretchr/testify/require"
20)
21
22func BenchmarkMinimalUnpreparedSelectWithoutStatementCache(b *testing.B) {
23	config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE"))
24	config.BuildStatementCache = nil
25
26	conn := mustConnect(b, config)
27	defer closeConn(b, conn)
28
29	var n int64
30
31	b.ResetTimer()
32	for i := 0; i < b.N; i++ {
33		err := conn.QueryRow(context.Background(), "select $1::int8", i).Scan(&n)
34		if err != nil {
35			b.Fatal(err)
36		}
37
38		if n != int64(i) {
39			b.Fatalf("expected %d, got %d", i, n)
40		}
41	}
42}
43
44func BenchmarkMinimalUnpreparedSelectWithStatementCacheModeDescribe(b *testing.B) {
45	config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE"))
46	config.BuildStatementCache = func(conn *pgconn.PgConn) stmtcache.Cache {
47		return stmtcache.New(conn, stmtcache.ModeDescribe, 32)
48	}
49
50	conn := mustConnect(b, config)
51	defer closeConn(b, conn)
52
53	var n int64
54
55	b.ResetTimer()
56	for i := 0; i < b.N; i++ {
57		err := conn.QueryRow(context.Background(), "select $1::int8", i).Scan(&n)
58		if err != nil {
59			b.Fatal(err)
60		}
61
62		if n != int64(i) {
63			b.Fatalf("expected %d, got %d", i, n)
64		}
65	}
66}
67
68func BenchmarkMinimalUnpreparedSelectWithStatementCacheModePrepare(b *testing.B) {
69	config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE"))
70	config.BuildStatementCache = func(conn *pgconn.PgConn) stmtcache.Cache {
71		return stmtcache.New(conn, stmtcache.ModePrepare, 32)
72	}
73
74	conn := mustConnect(b, config)
75	defer closeConn(b, conn)
76
77	var n int64
78
79	b.ResetTimer()
80	for i := 0; i < b.N; i++ {
81		err := conn.QueryRow(context.Background(), "select $1::int8", i).Scan(&n)
82		if err != nil {
83			b.Fatal(err)
84		}
85
86		if n != int64(i) {
87			b.Fatalf("expected %d, got %d", i, n)
88		}
89	}
90}
91
92func BenchmarkMinimalPreparedSelect(b *testing.B) {
93	conn := mustConnect(b, mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE")))
94	defer closeConn(b, conn)
95
96	_, err := conn.Prepare(context.Background(), "ps1", "select $1::int8")
97	if err != nil {
98		b.Fatal(err)
99	}
100
101	var n int64
102
103	b.ResetTimer()
104	for i := 0; i < b.N; i++ {
105		err = conn.QueryRow(context.Background(), "ps1", i).Scan(&n)
106		if err != nil {
107			b.Fatal(err)
108		}
109
110		if n != int64(i) {
111			b.Fatalf("expected %d, got %d", i, n)
112		}
113	}
114}
115
116func BenchmarkMinimalPgConnPreparedSelect(b *testing.B) {
117	conn := mustConnect(b, mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE")))
118	defer closeConn(b, conn)
119
120	pgConn := conn.PgConn()
121
122	_, err := pgConn.Prepare(context.Background(), "ps1", "select $1::int8", nil)
123	if err != nil {
124		b.Fatal(err)
125	}
126
127	encodedBytes := make([]byte, 8)
128
129	b.ResetTimer()
130	for i := 0; i < b.N; i++ {
131
132		rr := pgConn.ExecPrepared(context.Background(), "ps1", [][]byte{encodedBytes}, []int16{1}, []int16{1})
133		if err != nil {
134			b.Fatal(err)
135		}
136
137		for rr.NextRow() {
138			for i := range rr.Values() {
139				if bytes.Compare(rr.Values()[0], encodedBytes) != 0 {
140					b.Fatalf("unexpected values: %s %s", rr.Values()[i], encodedBytes)
141				}
142			}
143		}
144		_, err = rr.Close()
145		if err != nil {
146			b.Fatal(err)
147		}
148	}
149}
150
151func BenchmarkPointerPointerWithNullValues(b *testing.B) {
152	conn := mustConnect(b, mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE")))
153	defer closeConn(b, conn)
154
155	_, err := conn.Prepare(context.Background(), "selectNulls", "select 1::int4, 'johnsmith', null::text, null::text, null::text, null::date, null::timestamptz")
156	if err != nil {
157		b.Fatal(err)
158	}
159
160	b.ResetTimer()
161	for i := 0; i < b.N; i++ {
162		var record struct {
163			id            int32
164			userName      string
165			email         *string
166			name          *string
167			sex           *string
168			birthDate     *time.Time
169			lastLoginTime *time.Time
170		}
171
172		err = conn.QueryRow(context.Background(), "selectNulls").Scan(
173			&record.id,
174			&record.userName,
175			&record.email,
176			&record.name,
177			&record.sex,
178			&record.birthDate,
179			&record.lastLoginTime,
180		)
181		if err != nil {
182			b.Fatal(err)
183		}
184
185		// These checks both ensure that the correct data was returned
186		// and provide a benchmark of accessing the returned values.
187		if record.id != 1 {
188			b.Fatalf("bad value for id: %v", record.id)
189		}
190		if record.userName != "johnsmith" {
191			b.Fatalf("bad value for userName: %v", record.userName)
192		}
193		if record.email != nil {
194			b.Fatalf("bad value for email: %v", record.email)
195		}
196		if record.name != nil {
197			b.Fatalf("bad value for name: %v", record.name)
198		}
199		if record.sex != nil {
200			b.Fatalf("bad value for sex: %v", record.sex)
201		}
202		if record.birthDate != nil {
203			b.Fatalf("bad value for birthDate: %v", record.birthDate)
204		}
205		if record.lastLoginTime != nil {
206			b.Fatalf("bad value for lastLoginTime: %v", record.lastLoginTime)
207		}
208	}
209}
210
211func BenchmarkPointerPointerWithPresentValues(b *testing.B) {
212	conn := mustConnect(b, mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE")))
213	defer closeConn(b, conn)
214
215	_, err := conn.Prepare(context.Background(), "selectNulls", "select 1::int4, 'johnsmith', 'johnsmith@example.com', 'John Smith', 'male', '1970-01-01'::date, '2015-01-01 00:00:00'::timestamptz")
216	if err != nil {
217		b.Fatal(err)
218	}
219
220	b.ResetTimer()
221	for i := 0; i < b.N; i++ {
222		var record struct {
223			id            int32
224			userName      string
225			email         *string
226			name          *string
227			sex           *string
228			birthDate     *time.Time
229			lastLoginTime *time.Time
230		}
231
232		err = conn.QueryRow(context.Background(), "selectNulls").Scan(
233			&record.id,
234			&record.userName,
235			&record.email,
236			&record.name,
237			&record.sex,
238			&record.birthDate,
239			&record.lastLoginTime,
240		)
241		if err != nil {
242			b.Fatal(err)
243		}
244
245		// These checks both ensure that the correct data was returned
246		// and provide a benchmark of accessing the returned values.
247		if record.id != 1 {
248			b.Fatalf("bad value for id: %v", record.id)
249		}
250		if record.userName != "johnsmith" {
251			b.Fatalf("bad value for userName: %v", record.userName)
252		}
253		if record.email == nil || *record.email != "johnsmith@example.com" {
254			b.Fatalf("bad value for email: %v", record.email)
255		}
256		if record.name == nil || *record.name != "John Smith" {
257			b.Fatalf("bad value for name: %v", record.name)
258		}
259		if record.sex == nil || *record.sex != "male" {
260			b.Fatalf("bad value for sex: %v", record.sex)
261		}
262		if record.birthDate == nil || *record.birthDate != time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC) {
263			b.Fatalf("bad value for birthDate: %v", record.birthDate)
264		}
265		if record.lastLoginTime == nil || *record.lastLoginTime != time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local) {
266			b.Fatalf("bad value for lastLoginTime: %v", record.lastLoginTime)
267		}
268	}
269}
270
271func BenchmarkSelectWithoutLogging(b *testing.B) {
272	conn := mustConnect(b, mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE")))
273	defer closeConn(b, conn)
274
275	benchmarkSelectWithLog(b, conn)
276}
277
278type discardLogger struct{}
279
280func (dl discardLogger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) {
281}
282
283func BenchmarkSelectWithLoggingTraceDiscard(b *testing.B) {
284	var logger discardLogger
285	config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE"))
286	config.Logger = logger
287	config.LogLevel = pgx.LogLevelTrace
288
289	conn := mustConnect(b, config)
290	defer closeConn(b, conn)
291
292	benchmarkSelectWithLog(b, conn)
293}
294
295func BenchmarkSelectWithLoggingDebugWithDiscard(b *testing.B) {
296	var logger discardLogger
297	config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE"))
298	config.Logger = logger
299	config.LogLevel = pgx.LogLevelDebug
300
301	conn := mustConnect(b, config)
302	defer closeConn(b, conn)
303
304	benchmarkSelectWithLog(b, conn)
305}
306
307func BenchmarkSelectWithLoggingInfoWithDiscard(b *testing.B) {
308	var logger discardLogger
309	config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE"))
310	config.Logger = logger
311	config.LogLevel = pgx.LogLevelInfo
312
313	conn := mustConnect(b, config)
314	defer closeConn(b, conn)
315
316	benchmarkSelectWithLog(b, conn)
317}
318
319func BenchmarkSelectWithLoggingErrorWithDiscard(b *testing.B) {
320	var logger discardLogger
321	config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE"))
322	config.Logger = logger
323	config.LogLevel = pgx.LogLevelError
324
325	conn := mustConnect(b, config)
326	defer closeConn(b, conn)
327
328	benchmarkSelectWithLog(b, conn)
329}
330
331func benchmarkSelectWithLog(b *testing.B, conn *pgx.Conn) {
332	_, err := conn.Prepare(context.Background(), "test", "select 1::int4, 'johnsmith', 'johnsmith@example.com', 'John Smith', 'male', '1970-01-01'::date, '2015-01-01 00:00:00'::timestamptz")
333	if err != nil {
334		b.Fatal(err)
335	}
336
337	b.ResetTimer()
338	for i := 0; i < b.N; i++ {
339		var record struct {
340			id            int32
341			userName      string
342			email         string
343			name          string
344			sex           string
345			birthDate     time.Time
346			lastLoginTime time.Time
347		}
348
349		err = conn.QueryRow(context.Background(), "test").Scan(
350			&record.id,
351			&record.userName,
352			&record.email,
353			&record.name,
354			&record.sex,
355			&record.birthDate,
356			&record.lastLoginTime,
357		)
358		if err != nil {
359			b.Fatal(err)
360		}
361
362		// These checks both ensure that the correct data was returned
363		// and provide a benchmark of accessing the returned values.
364		if record.id != 1 {
365			b.Fatalf("bad value for id: %v", record.id)
366		}
367		if record.userName != "johnsmith" {
368			b.Fatalf("bad value for userName: %v", record.userName)
369		}
370		if record.email != "johnsmith@example.com" {
371			b.Fatalf("bad value for email: %v", record.email)
372		}
373		if record.name != "John Smith" {
374			b.Fatalf("bad value for name: %v", record.name)
375		}
376		if record.sex != "male" {
377			b.Fatalf("bad value for sex: %v", record.sex)
378		}
379		if record.birthDate != time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC) {
380			b.Fatalf("bad value for birthDate: %v", record.birthDate)
381		}
382		if record.lastLoginTime != time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local) {
383			b.Fatalf("bad value for lastLoginTime: %v", record.lastLoginTime)
384		}
385	}
386}
387
// benchmarkWriteTableCreateSQL drops table t if it exists and recreates it
// with thirteen columns: three varchar, two date, three int4, two timestamptz
// and three bool. The columns with "null" in their name, and tstz_2, are the
// ones declared without "not null".
const benchmarkWriteTableCreateSQL = `drop table if exists t;

create table t(
	varchar_1 varchar not null,
	varchar_2 varchar not null,
	varchar_null_1 varchar,
	date_1 date not null,
	date_null_1 date,
	int4_1 int4 not null,
	int4_2 int4 not null,
	int4_null_1 int4,
	tstz_1 timestamptz not null,
	tstz_2 timestamptz,
	bool_1 bool not null,
	bool_2 bool not null,
	bool_3 bool not null
);
`
406
// benchmarkWriteTableInsertSQL inserts a single row into table t. It takes
// thirteen positional parameters, $1 through $13, one per column in the order
// listed, each with an explicit cast to the column's type.
const benchmarkWriteTableInsertSQL = `insert into t(
	varchar_1,
	varchar_2,
	varchar_null_1,
	date_1,
	date_null_1,
	int4_1,
	int4_2,
	int4_null_1,
	tstz_1,
	tstz_2,
	bool_1,
	bool_2,
	bool_3
) values (
	$1::varchar,
	$2::varchar,
	$3::varchar,
	$4::date,
	$5::date,
	$6::int4,
	$7::int4,
	$8::int4,
	$9::timestamptz,
	$10::timestamptz,
	$11::bool,
	$12::bool,
	$13::bool
)`
436
// benchmarkWriteTableCopyFromSrc is a row source that hands out the same row
// exactly count times.
type benchmarkWriteTableCopyFromSrc struct {
	// count is the total number of rows to produce.
	count int
	// idx is the number of Next calls made so far.
	idx int
	// row is the slice returned by every Values call.
	row []interface{}
}

// Next reports whether another row is available and advances the source.
// It returns true for the first count calls and false afterwards, so a
// "for src.Next()" loop runs exactly count times (zero times when count <= 0).
func (s *benchmarkWriteTableCopyFromSrc) Next() bool {
	// Compare before incrementing: incrementing first would skip one row.
	next := s.idx < s.count
	s.idx++
	return next
}

// Values returns the shared row and a nil error.
func (s *benchmarkWriteTableCopyFromSrc) Values() ([]interface{}, error) {
	return s.row, nil
}

// Err always returns nil; this source cannot fail.
func (s *benchmarkWriteTableCopyFromSrc) Err() error {
	return nil
}
455
456func newBenchmarkWriteTableCopyFromSrc(count int) pgx.CopyFromSource {
457	return &benchmarkWriteTableCopyFromSrc{
458		count: count,
459		row: []interface{}{
460			"varchar_1",
461			"varchar_2",
462			&pgtype.Text{Status: pgtype.Null},
463			time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local),
464			&pgtype.Date{Status: pgtype.Null},
465			1,
466			2,
467			&pgtype.Int4{Status: pgtype.Null},
468			time.Date(2001, 1, 1, 0, 0, 0, 0, time.Local),
469			time.Date(2002, 1, 1, 0, 0, 0, 0, time.Local),
470			true,
471			false,
472			true,
473		},
474	}
475}
476
477func benchmarkWriteNRowsViaInsert(b *testing.B, n int) {
478	conn := mustConnect(b, mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE")))
479	defer closeConn(b, conn)
480
481	mustExec(b, conn, benchmarkWriteTableCreateSQL)
482	_, err := conn.Prepare(context.Background(), "insert_t", benchmarkWriteTableInsertSQL)
483	if err != nil {
484		b.Fatal(err)
485	}
486
487	b.ResetTimer()
488
489	for i := 0; i < b.N; i++ {
490		src := newBenchmarkWriteTableCopyFromSrc(n)
491
492		tx, err := conn.Begin(context.Background())
493		if err != nil {
494			b.Fatal(err)
495		}
496
497		for src.Next() {
498			values, _ := src.Values()
499			if _, err = tx.Exec(context.Background(), "insert_t", values...); err != nil {
500				b.Fatalf("Exec unexpectedly failed with: %v", err)
501			}
502		}
503
504		err = tx.Commit(context.Background())
505		if err != nil {
506			b.Fatal(err)
507		}
508	}
509}
510
// queryArgs accumulates positional query arguments.
type queryArgs []interface{}

// Append adds v to the end of the list and returns the placeholder that refers
// to it: "$" followed by the new length of the list, so the first appended
// value yields "$1".
func (qa *queryArgs) Append(v interface{}) string {
	grown := append(*qa, v)
	*qa = grown

	position := len(grown)
	return "$" + strconv.Itoa(position)
}
517
518// note this function is only used for benchmarks -- it doesn't escape tableName
519// or columnNames
520func multiInsert(conn *pgx.Conn, tableName string, columnNames []string, rowSrc pgx.CopyFromSource) (int, error) {
521	maxRowsPerInsert := 65535 / len(columnNames)
522	rowsThisInsert := 0
523	rowCount := 0
524
525	sqlBuf := &bytes.Buffer{}
526	args := make(queryArgs, 0)
527
528	resetQuery := func() {
529		sqlBuf.Reset()
530		fmt.Fprintf(sqlBuf, "insert into %s(%s) values", tableName, strings.Join(columnNames, ", "))
531
532		args = args[0:0]
533
534		rowsThisInsert = 0
535	}
536	resetQuery()
537
538	tx, err := conn.Begin(context.Background())
539	if err != nil {
540		return 0, err
541	}
542	defer tx.Rollback(context.Background())
543
544	for rowSrc.Next() {
545		if rowsThisInsert > 0 {
546			sqlBuf.WriteByte(',')
547		}
548
549		sqlBuf.WriteByte('(')
550
551		values, err := rowSrc.Values()
552		if err != nil {
553			return 0, err
554		}
555
556		for i, val := range values {
557			if i > 0 {
558				sqlBuf.WriteByte(',')
559			}
560			sqlBuf.WriteString(args.Append(val))
561		}
562
563		sqlBuf.WriteByte(')')
564
565		rowsThisInsert++
566
567		if rowsThisInsert == maxRowsPerInsert {
568			_, err := tx.Exec(context.Background(), sqlBuf.String(), args...)
569			if err != nil {
570				return 0, err
571			}
572
573			rowCount += rowsThisInsert
574			resetQuery()
575		}
576	}
577
578	if rowsThisInsert > 0 {
579		_, err := tx.Exec(context.Background(), sqlBuf.String(), args...)
580		if err != nil {
581			return 0, err
582		}
583
584		rowCount += rowsThisInsert
585	}
586
587	if err := tx.Commit(context.Background()); err != nil {
588		return 0, nil
589	}
590
591	return rowCount, nil
592
593}
594
595func benchmarkWriteNRowsViaMultiInsert(b *testing.B, n int) {
596	conn := mustConnect(b, mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE")))
597	defer closeConn(b, conn)
598
599	mustExec(b, conn, benchmarkWriteTableCreateSQL)
600	_, err := conn.Prepare(context.Background(), "insert_t", benchmarkWriteTableInsertSQL)
601	if err != nil {
602		b.Fatal(err)
603	}
604
605	b.ResetTimer()
606
607	for i := 0; i < b.N; i++ {
608		src := newBenchmarkWriteTableCopyFromSrc(n)
609
610		_, err := multiInsert(conn, "t",
611			[]string{"varchar_1",
612				"varchar_2",
613				"varchar_null_1",
614				"date_1",
615				"date_null_1",
616				"int4_1",
617				"int4_2",
618				"int4_null_1",
619				"tstz_1",
620				"tstz_2",
621				"bool_1",
622				"bool_2",
623				"bool_3"},
624			src)
625		if err != nil {
626			b.Fatal(err)
627		}
628	}
629}
630
631func benchmarkWriteNRowsViaCopy(b *testing.B, n int) {
632	conn := mustConnect(b, mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE")))
633	defer closeConn(b, conn)
634
635	mustExec(b, conn, benchmarkWriteTableCreateSQL)
636
637	b.ResetTimer()
638
639	for i := 0; i < b.N; i++ {
640		src := newBenchmarkWriteTableCopyFromSrc(n)
641
642		_, err := conn.CopyFrom(context.Background(),
643			pgx.Identifier{"t"},
644			[]string{"varchar_1",
645				"varchar_2",
646				"varchar_null_1",
647				"date_1",
648				"date_null_1",
649				"int4_1",
650				"int4_2",
651				"int4_null_1",
652				"tstz_1",
653				"tstz_2",
654				"bool_1",
655				"bool_2",
656				"bool_3"},
657			src)
658		if err != nil {
659			b.Fatal(err)
660		}
661	}
662}
663
// BenchmarkWrite5RowsViaInsert runs benchmarkWriteNRowsViaInsert with n = 5.
func BenchmarkWrite5RowsViaInsert(b *testing.B) {
	benchmarkWriteNRowsViaInsert(b, 5)
}
667
// BenchmarkWrite5RowsViaMultiInsert runs benchmarkWriteNRowsViaMultiInsert
// with n = 5.
func BenchmarkWrite5RowsViaMultiInsert(b *testing.B) {
	benchmarkWriteNRowsViaMultiInsert(b, 5)
}
671
// BenchmarkWrite5RowsViaCopy runs benchmarkWriteNRowsViaCopy with n = 5.
func BenchmarkWrite5RowsViaCopy(b *testing.B) {
	benchmarkWriteNRowsViaCopy(b, 5)
}
675
// BenchmarkWrite10RowsViaInsert runs benchmarkWriteNRowsViaInsert with n = 10.
func BenchmarkWrite10RowsViaInsert(b *testing.B) {
	benchmarkWriteNRowsViaInsert(b, 10)
}
679
// BenchmarkWrite10RowsViaMultiInsert runs benchmarkWriteNRowsViaMultiInsert
// with n = 10.
func BenchmarkWrite10RowsViaMultiInsert(b *testing.B) {
	benchmarkWriteNRowsViaMultiInsert(b, 10)
}
683
// BenchmarkWrite10RowsViaCopy runs benchmarkWriteNRowsViaCopy with n = 10.
func BenchmarkWrite10RowsViaCopy(b *testing.B) {
	benchmarkWriteNRowsViaCopy(b, 10)
}
687
// BenchmarkWrite100RowsViaInsert runs benchmarkWriteNRowsViaInsert with
// n = 100.
func BenchmarkWrite100RowsViaInsert(b *testing.B) {
	benchmarkWriteNRowsViaInsert(b, 100)
}
691
// BenchmarkWrite100RowsViaMultiInsert runs benchmarkWriteNRowsViaMultiInsert
// with n = 100.
func BenchmarkWrite100RowsViaMultiInsert(b *testing.B) {
	benchmarkWriteNRowsViaMultiInsert(b, 100)
}
695
// BenchmarkWrite100RowsViaCopy runs benchmarkWriteNRowsViaCopy with n = 100.
func BenchmarkWrite100RowsViaCopy(b *testing.B) {
	benchmarkWriteNRowsViaCopy(b, 100)
}
699
// BenchmarkWrite1000RowsViaInsert runs benchmarkWriteNRowsViaInsert with
// n = 1000.
func BenchmarkWrite1000RowsViaInsert(b *testing.B) {
	benchmarkWriteNRowsViaInsert(b, 1000)
}
703
// BenchmarkWrite1000RowsViaMultiInsert runs benchmarkWriteNRowsViaMultiInsert
// with n = 1000.
func BenchmarkWrite1000RowsViaMultiInsert(b *testing.B) {
	benchmarkWriteNRowsViaMultiInsert(b, 1000)
}
707
// BenchmarkWrite1000RowsViaCopy runs benchmarkWriteNRowsViaCopy with n = 1000.
func BenchmarkWrite1000RowsViaCopy(b *testing.B) {
	benchmarkWriteNRowsViaCopy(b, 1000)
}
711
// BenchmarkWrite10000RowsViaInsert runs benchmarkWriteNRowsViaInsert with
// n = 10000.
func BenchmarkWrite10000RowsViaInsert(b *testing.B) {
	benchmarkWriteNRowsViaInsert(b, 10000)
}
715
// BenchmarkWrite10000RowsViaMultiInsert runs benchmarkWriteNRowsViaMultiInsert
// with n = 10000.
func BenchmarkWrite10000RowsViaMultiInsert(b *testing.B) {
	benchmarkWriteNRowsViaMultiInsert(b, 10000)
}
719
// BenchmarkWrite10000RowsViaCopy runs benchmarkWriteNRowsViaCopy with
// n = 10000.
func BenchmarkWrite10000RowsViaCopy(b *testing.B) {
	benchmarkWriteNRowsViaCopy(b, 10000)
}
723
724func BenchmarkMultipleQueriesNonBatchNoStatementCache(b *testing.B) {
725	config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE"))
726	config.BuildStatementCache = nil
727
728	conn := mustConnect(b, config)
729	defer closeConn(b, conn)
730
731	benchmarkMultipleQueriesNonBatch(b, conn, 3)
732}
733
734func BenchmarkMultipleQueriesNonBatchPrepareStatementCache(b *testing.B) {
735	config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE"))
736	config.BuildStatementCache = func(conn *pgconn.PgConn) stmtcache.Cache {
737		return stmtcache.New(conn, stmtcache.ModePrepare, 32)
738	}
739
740	conn := mustConnect(b, config)
741	defer closeConn(b, conn)
742
743	benchmarkMultipleQueriesNonBatch(b, conn, 3)
744}
745
746func BenchmarkMultipleQueriesNonBatchDescribeStatementCache(b *testing.B) {
747	config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE"))
748	config.BuildStatementCache = func(conn *pgconn.PgConn) stmtcache.Cache {
749		return stmtcache.New(conn, stmtcache.ModeDescribe, 32)
750	}
751
752	conn := mustConnect(b, config)
753	defer closeConn(b, conn)
754
755	benchmarkMultipleQueriesNonBatch(b, conn, 3)
756}
757
758func benchmarkMultipleQueriesNonBatch(b *testing.B, conn *pgx.Conn, queryCount int) {
759	b.ResetTimer()
760	for i := 0; i < b.N; i++ {
761		for j := 0; j < queryCount; j++ {
762			rows, err := conn.Query(context.Background(), "select n from generate_series(0, 5) n")
763			if err != nil {
764				b.Fatal(err)
765			}
766
767			for k := 0; rows.Next(); k++ {
768				var n int
769				if err := rows.Scan(&n); err != nil {
770					b.Fatal(err)
771				}
772				if n != k {
773					b.Fatalf("n => %v, want %v", n, k)
774				}
775			}
776
777			if rows.Err() != nil {
778				b.Fatal(rows.Err())
779			}
780		}
781	}
782}
783
784func BenchmarkMultipleQueriesBatchNoStatementCache(b *testing.B) {
785	config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE"))
786	config.BuildStatementCache = nil
787
788	conn := mustConnect(b, config)
789	defer closeConn(b, conn)
790
791	benchmarkMultipleQueriesBatch(b, conn, 3)
792}
793
794func BenchmarkMultipleQueriesBatchPrepareStatementCache(b *testing.B) {
795	config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE"))
796	config.BuildStatementCache = func(conn *pgconn.PgConn) stmtcache.Cache {
797		return stmtcache.New(conn, stmtcache.ModePrepare, 32)
798	}
799
800	conn := mustConnect(b, config)
801	defer closeConn(b, conn)
802
803	benchmarkMultipleQueriesBatch(b, conn, 3)
804}
805
806func BenchmarkMultipleQueriesBatchDescribeStatementCache(b *testing.B) {
807	config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE"))
808	config.BuildStatementCache = func(conn *pgconn.PgConn) stmtcache.Cache {
809		return stmtcache.New(conn, stmtcache.ModeDescribe, 32)
810	}
811
812	conn := mustConnect(b, config)
813	defer closeConn(b, conn)
814
815	benchmarkMultipleQueriesBatch(b, conn, 3)
816}
817
818func benchmarkMultipleQueriesBatch(b *testing.B, conn *pgx.Conn, queryCount int) {
819	b.ResetTimer()
820	for i := 0; i < b.N; i++ {
821		batch := &pgx.Batch{}
822		for j := 0; j < queryCount; j++ {
823			batch.Queue("select n from generate_series(0,5) n")
824		}
825
826		br := conn.SendBatch(context.Background(), batch)
827
828		for j := 0; j < queryCount; j++ {
829			rows, err := br.Query()
830			if err != nil {
831				b.Fatal(err)
832			}
833
834			for k := 0; rows.Next(); k++ {
835				var n int
836				if err := rows.Scan(&n); err != nil {
837					b.Fatal(err)
838				}
839				if n != k {
840					b.Fatalf("n => %v, want %v", n, k)
841				}
842			}
843
844			if rows.Err() != nil {
845				b.Fatal(rows.Err())
846			}
847		}
848
849		err := br.Close()
850		if err != nil {
851			b.Fatal(err)
852		}
853	}
854}
855
856func BenchmarkSelectManyUnknownEnum(b *testing.B) {
857	conn := mustConnectString(b, os.Getenv("PGX_TEST_DATABASE"))
858	defer closeConn(b, conn)
859
860	ctx := context.Background()
861	tx, err := conn.Begin(ctx)
862	require.NoError(b, err)
863	defer tx.Rollback(ctx)
864
865	_, err = tx.Exec(context.Background(), "drop type if exists color;")
866	require.NoError(b, err)
867
868	_, err = tx.Exec(ctx, `create type color as enum ('blue', 'green', 'orange')`)
869	require.NoError(b, err)
870
871	b.ResetTimer()
872	var x, y, z string
873	for i := 0; i < b.N; i++ {
874		rows, err := conn.Query(ctx, "select 'blue'::color, 'green'::color, 'orange'::color from generate_series(1,10)")
875		if err != nil {
876			b.Fatal(err)
877		}
878
879		for rows.Next() {
880			err = rows.Scan(&x, &y, &z)
881			if err != nil {
882				b.Fatal(err)
883			}
884
885			if x != "blue" {
886				b.Fatal("unexpected result")
887			}
888			if y != "green" {
889				b.Fatal("unexpected result")
890			}
891			if z != "orange" {
892				b.Fatal("unexpected result")
893			}
894		}
895
896		if rows.Err() != nil {
897			b.Fatal(rows.Err())
898		}
899	}
900}
901
902func BenchmarkSelectManyRegisteredEnum(b *testing.B) {
903	conn := mustConnectString(b, os.Getenv("PGX_TEST_DATABASE"))
904	defer closeConn(b, conn)
905
906	ctx := context.Background()
907	tx, err := conn.Begin(ctx)
908	require.NoError(b, err)
909	defer tx.Rollback(ctx)
910
911	_, err = tx.Exec(context.Background(), "drop type if exists color;")
912	require.NoError(b, err)
913
914	_, err = tx.Exec(ctx, `create type color as enum ('blue', 'green', 'orange')`)
915	require.NoError(b, err)
916
917	var oid uint32
918	err = conn.QueryRow(context.Background(), "select oid from pg_type where typname=$1;", "color").Scan(&oid)
919	require.NoError(b, err)
920
921	et := pgtype.NewEnumType("color", []string{"blue", "green", "orange"})
922	conn.ConnInfo().RegisterDataType(pgtype.DataType{Value: et, Name: "color", OID: oid})
923
924	b.ResetTimer()
925	var x, y, z string
926	for i := 0; i < b.N; i++ {
927		rows, err := conn.Query(ctx, "select 'blue'::color, 'green'::color, 'orange'::color from generate_series(1,10)")
928		if err != nil {
929			b.Fatal(err)
930		}
931
932		for rows.Next() {
933			err = rows.Scan(&x, &y, &z)
934			if err != nil {
935				b.Fatal(err)
936			}
937
938			if x != "blue" {
939				b.Fatal("unexpected result")
940			}
941			if y != "green" {
942				b.Fatal("unexpected result")
943			}
944			if z != "orange" {
945				b.Fatal("unexpected result")
946			}
947		}
948
949		if rows.Err() != nil {
950			b.Fatal(rows.Err())
951		}
952	}
953}
954
// getSelectRowsCounts returns the row counts the select benchmarks should run
// with.
//
// The counts are read from the PGX_BENCH_SELECT_ROWS_COUNTS environment
// variable as whitespace-separated base-10 integers. Repeated, leading and
// trailing whitespace is ignored. If the variable is unset, empty or contains
// only whitespace, the default {1, 10, 100, 1000} is returned. A field that
// does not parse as an int64 aborts the benchmark via b.Fatalf.
func getSelectRowsCounts(b *testing.B) []int64 {
	var rowCounts []int64

	// strings.Fields never yields empty strings, unlike splitting on a single
	// space, so stray whitespace cannot produce an unparsable field.
	for _, field := range strings.Fields(os.Getenv("PGX_BENCH_SELECT_ROWS_COUNTS")) {
		n, err := strconv.ParseInt(field, 10, 64)
		if err != nil {
			b.Fatalf("Bad PGX_BENCH_SELECT_ROWS_COUNTS value: %v", err)
		}
		rowCounts = append(rowCounts, n)
	}

	if len(rowCounts) == 0 {
		return []int64{1, 10, 100, 1000}
	}

	return rowCounts
}
976
// BenchRowSimple is the scan target used by BenchmarkSelectRowsScanSimple.
// Every field is a plain Go type; the text columns are held as strings.
type BenchRowSimple struct {
	ID                       int32
	FirstName, LastName, Sex string
	BirthDate                time.Time
	Weight, Height           int32
	UpdateTime               time.Time
}
987
988func BenchmarkSelectRowsScanSimple(b *testing.B) {
989	conn := mustConnectString(b, os.Getenv("PGX_TEST_DATABASE"))
990	defer closeConn(b, conn)
991
992	rowCounts := getSelectRowsCounts(b)
993
994	for _, rowCount := range rowCounts {
995		b.Run(fmt.Sprintf("%d rows", rowCount), func(b *testing.B) {
996			br := &BenchRowSimple{}
997			for i := 0; i < b.N; i++ {
998				rows, err := conn.Query(context.Background(), "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", rowCount)
999				if err != nil {
1000					b.Fatal(err)
1001				}
1002
1003				for rows.Next() {
1004					rows.Scan(&br.ID, &br.FirstName, &br.LastName, &br.Sex, &br.BirthDate, &br.Weight, &br.Height, &br.UpdateTime)
1005				}
1006
1007				if rows.Err() != nil {
1008					b.Fatal(rows.Err())
1009				}
1010			}
1011		})
1012	}
1013}
1014
// BenchRowStringBytes is the scan target used by
// BenchmarkSelectRowsScanStringBytes. It has the same fields as BenchRowSimple
// except that the three text columns are held as byte slices.
type BenchRowStringBytes struct {
	ID                       int32
	FirstName, LastName, Sex []byte
	BirthDate                time.Time
	Weight, Height           int32
	UpdateTime               time.Time
}
1025
1026func BenchmarkSelectRowsScanStringBytes(b *testing.B) {
1027	conn := mustConnectString(b, os.Getenv("PGX_TEST_DATABASE"))
1028	defer closeConn(b, conn)
1029
1030	rowCounts := getSelectRowsCounts(b)
1031
1032	for _, rowCount := range rowCounts {
1033		b.Run(fmt.Sprintf("%d rows", rowCount), func(b *testing.B) {
1034			br := &BenchRowStringBytes{}
1035			for i := 0; i < b.N; i++ {
1036				rows, err := conn.Query(context.Background(), "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", rowCount)
1037				if err != nil {
1038					b.Fatal(err)
1039				}
1040
1041				for rows.Next() {
1042					rows.Scan(&br.ID, &br.FirstName, &br.LastName, &br.Sex, &br.BirthDate, &br.Weight, &br.Height, &br.UpdateTime)
1043				}
1044
1045				if rows.Err() != nil {
1046					b.Fatal(rows.Err())
1047				}
1048			}
1049		})
1050	}
1051}
1052
1053type BenchRowDecoder struct {
1054	ID         pgtype.Int4
1055	FirstName  pgtype.Text
1056	LastName   pgtype.Text
1057	Sex        pgtype.Text
1058	BirthDate  pgtype.Date
1059	Weight     pgtype.Int4
1060	Height     pgtype.Int4
1061	UpdateTime pgtype.Timestamptz
1062}
1063
1064func BenchmarkSelectRowsScanDecoder(b *testing.B) {
1065	conn := mustConnectString(b, os.Getenv("PGX_TEST_DATABASE"))
1066	defer closeConn(b, conn)
1067
1068	rowCounts := getSelectRowsCounts(b)
1069
1070	for _, rowCount := range rowCounts {
1071		b.Run(fmt.Sprintf("%d rows", rowCount), func(b *testing.B) {
1072			formats := []struct {
1073				name string
1074				code int16
1075			}{
1076				{"text", pgx.TextFormatCode},
1077				{"binary", pgx.BinaryFormatCode},
1078			}
1079			for _, format := range formats {
1080				b.Run(format.name, func(b *testing.B) {
1081
1082					br := &BenchRowDecoder{}
1083					for i := 0; i < b.N; i++ {
1084						rows, err := conn.Query(
1085							context.Background(),
1086							"select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n",
1087							pgx.QueryResultFormats{format.code},
1088							rowCount,
1089						)
1090						if err != nil {
1091							b.Fatal(err)
1092						}
1093
1094						for rows.Next() {
1095							rows.Scan(&br.ID, &br.FirstName, &br.LastName, &br.Sex, &br.BirthDate, &br.Weight, &br.Height, &br.UpdateTime)
1096						}
1097
1098						if rows.Err() != nil {
1099							b.Fatal(rows.Err())
1100						}
1101					}
1102				})
1103			}
1104		})
1105	}
1106}
1107
1108func BenchmarkSelectRowsExplicitDecoding(b *testing.B) {
1109	conn := mustConnectString(b, os.Getenv("PGX_TEST_DATABASE"))
1110	defer closeConn(b, conn)
1111
1112	rowCounts := getSelectRowsCounts(b)
1113
1114	for _, rowCount := range rowCounts {
1115		b.Run(fmt.Sprintf("%d rows", rowCount), func(b *testing.B) {
1116			br := &BenchRowDecoder{}
1117			for i := 0; i < b.N; i++ {
1118				rows, err := conn.Query(context.Background(), "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", rowCount)
1119				if err != nil {
1120					b.Fatal(err)
1121				}
1122
1123				for rows.Next() {
1124					rawValues := rows.RawValues()
1125
1126					err = br.ID.DecodeBinary(conn.ConnInfo(), rawValues[0])
1127					if err != nil {
1128						b.Fatal(err)
1129					}
1130
1131					err = br.FirstName.DecodeText(conn.ConnInfo(), rawValues[1])
1132					if err != nil {
1133						b.Fatal(err)
1134					}
1135
1136					err = br.LastName.DecodeText(conn.ConnInfo(), rawValues[2])
1137					if err != nil {
1138						b.Fatal(err)
1139					}
1140
1141					err = br.Sex.DecodeText(conn.ConnInfo(), rawValues[3])
1142					if err != nil {
1143						b.Fatal(err)
1144					}
1145
1146					err = br.BirthDate.DecodeBinary(conn.ConnInfo(), rawValues[4])
1147					if err != nil {
1148						b.Fatal(err)
1149					}
1150
1151					err = br.Weight.DecodeBinary(conn.ConnInfo(), rawValues[5])
1152					if err != nil {
1153						b.Fatal(err)
1154					}
1155
1156					err = br.Height.DecodeBinary(conn.ConnInfo(), rawValues[6])
1157					if err != nil {
1158						b.Fatal(err)
1159					}
1160
1161					err = br.UpdateTime.DecodeBinary(conn.ConnInfo(), rawValues[7])
1162					if err != nil {
1163						b.Fatal(err)
1164					}
1165				}
1166
1167				if rows.Err() != nil {
1168					b.Fatal(rows.Err())
1169				}
1170			}
1171		})
1172	}
1173}
1174
1175func BenchmarkSelectRowsPgConnExecText(b *testing.B) {
1176	conn := mustConnectString(b, os.Getenv("PGX_TEST_DATABASE"))
1177	defer closeConn(b, conn)
1178
1179	rowCounts := getSelectRowsCounts(b)
1180
1181	for _, rowCount := range rowCounts {
1182		b.Run(fmt.Sprintf("%d rows", rowCount), func(b *testing.B) {
1183			for i := 0; i < b.N; i++ {
1184				mrr := conn.PgConn().Exec(context.Background(), fmt.Sprintf("select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + %d) n", rowCount))
1185				for mrr.NextResult() {
1186					rr := mrr.ResultReader()
1187					for rr.NextRow() {
1188						rr.Values()
1189					}
1190				}
1191
1192				err := mrr.Close()
1193				if err != nil {
1194					b.Fatal(err)
1195				}
1196			}
1197		})
1198	}
1199}
1200
1201func BenchmarkSelectRowsPgConnExecParams(b *testing.B) {
1202	conn := mustConnectString(b, os.Getenv("PGX_TEST_DATABASE"))
1203	defer closeConn(b, conn)
1204
1205	rowCounts := getSelectRowsCounts(b)
1206
1207	for _, rowCount := range rowCounts {
1208		b.Run(fmt.Sprintf("%d rows", rowCount), func(b *testing.B) {
1209			formats := []struct {
1210				name string
1211				code int16
1212			}{
1213				{"text", pgx.TextFormatCode},
1214				{"binary - mostly", pgx.BinaryFormatCode},
1215			}
1216			for _, format := range formats {
1217				b.Run(format.name, func(b *testing.B) {
1218					for i := 0; i < b.N; i++ {
1219						rr := conn.PgConn().ExecParams(
1220							context.Background(),
1221							"select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n",
1222							[][]byte{[]byte(strconv.FormatInt(rowCount, 10))},
1223							nil,
1224							nil,
1225							[]int16{format.code, pgx.TextFormatCode, pgx.TextFormatCode, pgx.TextFormatCode, format.code, format.code, format.code, format.code},
1226						)
1227						for rr.NextRow() {
1228							rr.Values()
1229						}
1230
1231						_, err := rr.Close()
1232						if err != nil {
1233							b.Fatal(err)
1234						}
1235					}
1236				})
1237			}
1238		})
1239	}
1240}
1241
1242func BenchmarkSelectRowsPgConnExecPrepared(b *testing.B) {
1243	conn := mustConnectString(b, os.Getenv("PGX_TEST_DATABASE"))
1244	defer closeConn(b, conn)
1245
1246	rowCounts := getSelectRowsCounts(b)
1247
1248	_, err := conn.PgConn().Prepare(context.Background(), "ps1", "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", nil)
1249	if err != nil {
1250		b.Fatal(err)
1251	}
1252
1253	for _, rowCount := range rowCounts {
1254		b.Run(fmt.Sprintf("%d rows", rowCount), func(b *testing.B) {
1255			formats := []struct {
1256				name string
1257				code int16
1258			}{
1259				{"text", pgx.TextFormatCode},
1260				{"binary - mostly", pgx.BinaryFormatCode},
1261			}
1262			for _, format := range formats {
1263				b.Run(format.name, func(b *testing.B) {
1264					for i := 0; i < b.N; i++ {
1265						rr := conn.PgConn().ExecPrepared(
1266							context.Background(),
1267							"ps1",
1268							[][]byte{[]byte(strconv.FormatInt(rowCount, 10))},
1269							nil,
1270							[]int16{format.code, pgx.TextFormatCode, pgx.TextFormatCode, pgx.TextFormatCode, format.code, format.code, format.code, format.code},
1271						)
1272						for rr.NextRow() {
1273							rr.Values()
1274						}
1275
1276						_, err := rr.Close()
1277						if err != nil {
1278							b.Fatal(err)
1279						}
1280					}
1281				})
1282			}
1283		})
1284	}
1285}
1286
// queryRecorder wraps a net.Conn and forwards every call to it, while keeping a copy of the bytes written through
// it and a running total of the number of bytes read through it.
type queryRecorder struct {
	conn      net.Conn // wrapped connection that every call is forwarded to
	writeBuf  []byte   // bytes the wrapped connection accepted via Write, in order
	readCount int      // total number of bytes returned by Read
}

// Compile-time check that queryRecorder can be used wherever a net.Conn is expected.
var _ net.Conn = (*queryRecorder)(nil)

// Read reads from the wrapped connection and adds the number of bytes read to readCount.
func (qr *queryRecorder) Read(b []byte) (n int, err error) {
	n, err = qr.conn.Read(b)
	qr.readCount += n
	return n, err
}

// Write writes b to the wrapped connection and appends the bytes that were actually accepted to writeBuf.
// Recording only b[:n] keeps writeBuf an exact copy of what went out even if the write is short or fails.
func (qr *queryRecorder) Write(b []byte) (n int, err error) {
	n, err = qr.conn.Write(b)
	if n > 0 {
		qr.writeBuf = append(qr.writeBuf, b[:n]...)
	}
	return n, err
}

// Close closes the wrapped connection.
func (qr *queryRecorder) Close() error {
	return qr.conn.Close()
}

// LocalAddr returns the wrapped connection's local address.
func (qr *queryRecorder) LocalAddr() net.Addr {
	return qr.conn.LocalAddr()
}

// RemoteAddr returns the wrapped connection's remote address.
func (qr *queryRecorder) RemoteAddr() net.Addr {
	return qr.conn.RemoteAddr()
}

// SetDeadline forwards the read and write deadline to the wrapped connection.
func (qr *queryRecorder) SetDeadline(t time.Time) error {
	return qr.conn.SetDeadline(t)
}

// SetReadDeadline forwards the read deadline to the wrapped connection.
func (qr *queryRecorder) SetReadDeadline(t time.Time) error {
	return qr.conn.SetReadDeadline(t)
}

// SetWriteDeadline forwards the write deadline to the wrapped connection.
func (qr *queryRecorder) SetWriteDeadline(t time.Time) error {
	return qr.conn.SetWriteDeadline(t)
}
1327
1328// BenchmarkSelectRowsRawPrepared hijacks a pgconn connection and inserts a queryRecorder. It then executes the query
1329// once. The benchmark is simply sending the exact query bytes over the wire to the server and reading the expected
1330// number of bytes back. It does nothing else. This should be the theoretical maximum performance a Go application
1331// could achieve.
1332func BenchmarkSelectRowsRawPrepared(b *testing.B) {
1333	rowCounts := getSelectRowsCounts(b)
1334
1335	for _, rowCount := range rowCounts {
1336		b.Run(fmt.Sprintf("%d rows", rowCount), func(b *testing.B) {
1337			formats := []struct {
1338				name string
1339				code int16
1340			}{
1341				{"text", pgx.TextFormatCode},
1342				{"binary - mostly", pgx.BinaryFormatCode},
1343			}
1344			for _, format := range formats {
1345				b.Run(format.name, func(b *testing.B) {
1346					conn := mustConnectString(b, os.Getenv("PGX_TEST_DATABASE")).PgConn()
1347					defer conn.Close(context.Background())
1348
1349					_, err := conn.Prepare(context.Background(), "ps1", "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", nil)
1350					if err != nil {
1351						b.Fatal(err)
1352					}
1353
1354					hijackedConn, err := conn.Hijack()
1355					require.NoError(b, err)
1356
1357					qr := &queryRecorder{
1358						conn: hijackedConn.Conn,
1359					}
1360
1361					hijackedConn.Conn = qr
1362					hijackedConn.Frontend = hijackedConn.Config.BuildFrontend(qr, qr)
1363					conn, err = pgconn.Construct(hijackedConn)
1364					require.NoError(b, err)
1365
1366					{
1367						rr := conn.ExecPrepared(
1368							context.Background(),
1369							"ps1",
1370							[][]byte{[]byte(strconv.FormatInt(rowCount, 10))},
1371							nil,
1372							[]int16{format.code, pgx.TextFormatCode, pgx.TextFormatCode, pgx.TextFormatCode, format.code, format.code, format.code, format.code},
1373						)
1374						_, err := rr.Close()
1375						require.NoError(b, err)
1376					}
1377
1378					buf := make([]byte, qr.readCount)
1379
1380					b.ResetTimer()
1381					for i := 0; i < b.N; i++ {
1382						_, err := qr.conn.Write(qr.writeBuf)
1383						if err != nil {
1384							b.Fatal(err)
1385						}
1386
1387						_, err = io.ReadFull(qr.conn, buf)
1388						if err != nil {
1389							b.Fatal(err)
1390						}
1391					}
1392				})
1393			}
1394		})
1395	}
1396}
1397