1package pgx_test
2
3import (
4	"bytes"
5	"context"
6	"net"
7	"os"
8	"reflect"
9	"testing"
10	"time"
11
12	"github.com/jackc/pgx/v4"
13	"github.com/stretchr/testify/assert"
14	"github.com/stretchr/testify/require"
15)
16
17func TestDateTranscode(t *testing.T) {
18	t.Parallel()
19
20	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) {
21		dates := []time.Time{
22			time.Date(1, 1, 1, 0, 0, 0, 0, time.UTC),
23			time.Date(1000, 1, 1, 0, 0, 0, 0, time.UTC),
24			time.Date(1600, 1, 1, 0, 0, 0, 0, time.UTC),
25			time.Date(1700, 1, 1, 0, 0, 0, 0, time.UTC),
26			time.Date(1800, 1, 1, 0, 0, 0, 0, time.UTC),
27			time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC),
28			time.Date(1990, 1, 1, 0, 0, 0, 0, time.UTC),
29			time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC),
30			time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC),
31			time.Date(2001, 1, 2, 0, 0, 0, 0, time.UTC),
32			time.Date(2004, 2, 29, 0, 0, 0, 0, time.UTC),
33			time.Date(2013, 7, 4, 0, 0, 0, 0, time.UTC),
34			time.Date(2013, 12, 25, 0, 0, 0, 0, time.UTC),
35			time.Date(2029, 1, 1, 0, 0, 0, 0, time.UTC),
36			time.Date(2081, 1, 1, 0, 0, 0, 0, time.UTC),
37			time.Date(2096, 2, 29, 0, 0, 0, 0, time.UTC),
38			time.Date(2550, 1, 1, 0, 0, 0, 0, time.UTC),
39			time.Date(9999, 12, 31, 0, 0, 0, 0, time.UTC),
40		}
41
42		for _, actualDate := range dates {
43			var d time.Time
44
45			err := conn.QueryRow(context.Background(), "select $1::date", actualDate).Scan(&d)
46			if err != nil {
47				t.Fatalf("Unexpected failure on QueryRow Scan: %v", err)
48			}
49			if !actualDate.Equal(d) {
50				t.Errorf("Did not transcode date successfully: %v is not %v", d, actualDate)
51			}
52		}
53	})
54}
55
56func TestTimestampTzTranscode(t *testing.T) {
57	t.Parallel()
58
59	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) {
60		inputTime := time.Date(2013, 1, 2, 3, 4, 5, 6000, time.Local)
61
62		var outputTime time.Time
63
64		err := conn.QueryRow(context.Background(), "select $1::timestamptz", inputTime).Scan(&outputTime)
65		if err != nil {
66			t.Fatalf("QueryRow Scan failed: %v", err)
67		}
68		if !inputTime.Equal(outputTime) {
69			t.Errorf("Did not transcode time successfully: %v is not %v", outputTime, inputTime)
70		}
71	})
72}
73
74// TODO - move these tests to pgtype
75
76func TestJSONAndJSONBTranscode(t *testing.T) {
77	t.Parallel()
78
79	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) {
80		for _, typename := range []string{"json", "jsonb"} {
81			if _, ok := conn.ConnInfo().DataTypeForName(typename); !ok {
82				continue // No JSON/JSONB type -- must be running against old PostgreSQL
83			}
84
85			testJSONString(t, conn, typename)
86			testJSONStringPointer(t, conn, typename)
87		}
88	})
89}
90
91func TestJSONAndJSONBTranscodeExtendedOnly(t *testing.T) {
92	t.Parallel()
93
94	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
95	defer closeConn(t, conn)
96
97	for _, typename := range []string{"json", "jsonb"} {
98		if _, ok := conn.ConnInfo().DataTypeForName(typename); !ok {
99			continue // No JSON/JSONB type -- must be running against old PostgreSQL
100		}
101		testJSONSingleLevelStringMap(t, conn, typename)
102		testJSONNestedMap(t, conn, typename)
103		testJSONStringArray(t, conn, typename)
104		testJSONInt64Array(t, conn, typename)
105		testJSONInt16ArrayFailureDueToOverflow(t, conn, typename)
106		testJSONStruct(t, conn, typename)
107	}
108
109}
110
111func testJSONString(t *testing.T, conn *pgx.Conn, typename string) {
112	input := `{"key": "value"}`
113	expectedOutput := map[string]string{"key": "value"}
114	var output map[string]string
115	err := conn.QueryRow(context.Background(), "select $1::"+typename, input).Scan(&output)
116	if err != nil {
117		t.Errorf("%s: QueryRow Scan failed: %v", typename, err)
118		return
119	}
120
121	if !reflect.DeepEqual(expectedOutput, output) {
122		t.Errorf("%s: Did not transcode map[string]string successfully: %v is not %v", typename, expectedOutput, output)
123		return
124	}
125}
126
127func testJSONStringPointer(t *testing.T, conn *pgx.Conn, typename string) {
128	input := `{"key": "value"}`
129	expectedOutput := map[string]string{"key": "value"}
130	var output map[string]string
131	err := conn.QueryRow(context.Background(), "select $1::"+typename, &input).Scan(&output)
132	if err != nil {
133		t.Errorf("%s: QueryRow Scan failed: %v", typename, err)
134		return
135	}
136
137	if !reflect.DeepEqual(expectedOutput, output) {
138		t.Errorf("%s: Did not transcode map[string]string successfully: %v is not %v", typename, expectedOutput, output)
139		return
140	}
141}
142
143func testJSONSingleLevelStringMap(t *testing.T, conn *pgx.Conn, typename string) {
144	input := map[string]string{"key": "value"}
145	var output map[string]string
146	err := conn.QueryRow(context.Background(), "select $1::"+typename, input).Scan(&output)
147	if err != nil {
148		t.Errorf("%s: QueryRow Scan failed: %v", typename, err)
149		return
150	}
151
152	if !reflect.DeepEqual(input, output) {
153		t.Errorf("%s: Did not transcode map[string]string successfully: %v is not %v", typename, input, output)
154		return
155	}
156}
157
158func testJSONNestedMap(t *testing.T, conn *pgx.Conn, typename string) {
159	input := map[string]interface{}{
160		"name":      "Uncanny",
161		"stats":     map[string]interface{}{"hp": float64(107), "maxhp": float64(150)},
162		"inventory": []interface{}{"phone", "key"},
163	}
164	var output map[string]interface{}
165	err := conn.QueryRow(context.Background(), "select $1::"+typename, input).Scan(&output)
166	if err != nil {
167		t.Errorf("%s: QueryRow Scan failed: %v", typename, err)
168		return
169	}
170
171	if !reflect.DeepEqual(input, output) {
172		t.Errorf("%s: Did not transcode map[string]interface{} successfully: %v is not %v", typename, input, output)
173		return
174	}
175}
176
177func testJSONStringArray(t *testing.T, conn *pgx.Conn, typename string) {
178	input := []string{"foo", "bar", "baz"}
179	var output []string
180	err := conn.QueryRow(context.Background(), "select $1::"+typename, input).Scan(&output)
181	if err != nil {
182		t.Errorf("%s: QueryRow Scan failed: %v", typename, err)
183	}
184
185	if !reflect.DeepEqual(input, output) {
186		t.Errorf("%s: Did not transcode []string successfully: %v is not %v", typename, input, output)
187	}
188}
189
190func testJSONInt64Array(t *testing.T, conn *pgx.Conn, typename string) {
191	input := []int64{1, 2, 234432}
192	var output []int64
193	err := conn.QueryRow(context.Background(), "select $1::"+typename, input).Scan(&output)
194	if err != nil {
195		t.Errorf("%s: QueryRow Scan failed: %v", typename, err)
196	}
197
198	if !reflect.DeepEqual(input, output) {
199		t.Errorf("%s: Did not transcode []int64 successfully: %v is not %v", typename, input, output)
200	}
201}
202
203func testJSONInt16ArrayFailureDueToOverflow(t *testing.T, conn *pgx.Conn, typename string) {
204	input := []int{1, 2, 234432}
205	var output []int16
206	err := conn.QueryRow(context.Background(), "select $1::"+typename, input).Scan(&output)
207	if err == nil || err.Error() != "can't scan into dest[0]: json: cannot unmarshal number 234432 into Go value of type int16" {
208		t.Errorf("%s: Expected *json.UnmarkalTypeError, but got %v", typename, err)
209	}
210}
211
212func testJSONStruct(t *testing.T, conn *pgx.Conn, typename string) {
213	type person struct {
214		Name string `json:"name"`
215		Age  int    `json:"age"`
216	}
217
218	input := person{
219		Name: "John",
220		Age:  42,
221	}
222
223	var output person
224
225	err := conn.QueryRow(context.Background(), "select $1::"+typename, input).Scan(&output)
226	if err != nil {
227		t.Errorf("%s: QueryRow Scan failed: %v", typename, err)
228	}
229
230	if !reflect.DeepEqual(input, output) {
231		t.Errorf("%s: Did not transcode struct successfully: %v is not %v", typename, input, output)
232	}
233}
234
// mustParseCIDR parses s with net.ParseCIDR and returns the resulting network.
// The test is failed immediately with t.Fatal when s cannot be parsed.
func mustParseCIDR(t *testing.T, s string) *net.IPNet {
	// Mark this as a helper so a parse failure is reported at the caller's
	// line instead of inside this function.
	t.Helper()

	_, ipnet, err := net.ParseCIDR(s)
	if err != nil {
		t.Fatal(err)
	}

	return ipnet
}
243
244func TestStringToNotTextTypeTranscode(t *testing.T) {
245	t.Parallel()
246
247	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) {
248		input := "01086ee0-4963-4e35-9116-30c173a8d0bd"
249
250		var output string
251		err := conn.QueryRow(context.Background(), "select $1::uuid", input).Scan(&output)
252		if err != nil {
253			t.Fatal(err)
254		}
255		if input != output {
256			t.Errorf("uuid: Did not transcode string successfully: %s is not %s", input, output)
257		}
258
259		err = conn.QueryRow(context.Background(), "select $1::uuid", &input).Scan(&output)
260		if err != nil {
261			t.Fatal(err)
262		}
263		if input != output {
264			t.Errorf("uuid: Did not transcode pointer to string successfully: %s is not %s", input, output)
265		}
266	})
267}
268
269func TestInetCIDRTranscodeIPNet(t *testing.T) {
270	t.Parallel()
271
272	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) {
273		tests := []struct {
274			sql   string
275			value *net.IPNet
276		}{
277			{"select $1::inet", mustParseCIDR(t, "0.0.0.0/32")},
278			{"select $1::inet", mustParseCIDR(t, "127.0.0.1/32")},
279			{"select $1::inet", mustParseCIDR(t, "12.34.56.0/32")},
280			{"select $1::inet", mustParseCIDR(t, "192.168.1.0/24")},
281			{"select $1::inet", mustParseCIDR(t, "255.0.0.0/8")},
282			{"select $1::inet", mustParseCIDR(t, "255.255.255.255/32")},
283			{"select $1::inet", mustParseCIDR(t, "::/128")},
284			{"select $1::inet", mustParseCIDR(t, "::/0")},
285			{"select $1::inet", mustParseCIDR(t, "::1/128")},
286			{"select $1::inet", mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128")},
287			{"select $1::cidr", mustParseCIDR(t, "0.0.0.0/32")},
288			{"select $1::cidr", mustParseCIDR(t, "127.0.0.1/32")},
289			{"select $1::cidr", mustParseCIDR(t, "12.34.56.0/32")},
290			{"select $1::cidr", mustParseCIDR(t, "192.168.1.0/24")},
291			{"select $1::cidr", mustParseCIDR(t, "255.0.0.0/8")},
292			{"select $1::cidr", mustParseCIDR(t, "255.255.255.255/32")},
293			{"select $1::cidr", mustParseCIDR(t, "::/128")},
294			{"select $1::cidr", mustParseCIDR(t, "::/0")},
295			{"select $1::cidr", mustParseCIDR(t, "::1/128")},
296			{"select $1::cidr", mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128")},
297		}
298
299		for i, tt := range tests {
300			var actual net.IPNet
301
302			err := conn.QueryRow(context.Background(), tt.sql, tt.value).Scan(&actual)
303			if err != nil {
304				t.Errorf("%d. Unexpected failure: %v (sql -> %v, value -> %v)", i, err, tt.sql, tt.value)
305				continue
306			}
307
308			if actual.String() != tt.value.String() {
309				t.Errorf("%d. Expected %v, got %v (sql -> %v)", i, tt.value, actual, tt.sql)
310			}
311		}
312	})
313}
314
315func TestInetCIDRTranscodeIP(t *testing.T) {
316	t.Parallel()
317
318	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) {
319		tests := []struct {
320			sql   string
321			value net.IP
322		}{
323			{"select $1::inet", net.ParseIP("0.0.0.0")},
324			{"select $1::inet", net.ParseIP("127.0.0.1")},
325			{"select $1::inet", net.ParseIP("12.34.56.0")},
326			{"select $1::inet", net.ParseIP("255.255.255.255")},
327			{"select $1::inet", net.ParseIP("::1")},
328			{"select $1::inet", net.ParseIP("2607:f8b0:4009:80b::200e")},
329			{"select $1::cidr", net.ParseIP("0.0.0.0")},
330			{"select $1::cidr", net.ParseIP("127.0.0.1")},
331			{"select $1::cidr", net.ParseIP("12.34.56.0")},
332			{"select $1::cidr", net.ParseIP("255.255.255.255")},
333			{"select $1::cidr", net.ParseIP("::1")},
334			{"select $1::cidr", net.ParseIP("2607:f8b0:4009:80b::200e")},
335		}
336
337		for i, tt := range tests {
338			var actual net.IP
339
340			err := conn.QueryRow(context.Background(), tt.sql, tt.value).Scan(&actual)
341			if err != nil {
342				t.Errorf("%d. Unexpected failure: %v (sql -> %v, value -> %v)", i, err, tt.sql, tt.value)
343				continue
344			}
345
346			if !actual.Equal(tt.value) {
347				t.Errorf("%d. Expected %v, got %v (sql -> %v)", i, tt.value, actual, tt.sql)
348			}
349
350			ensureConnValid(t, conn)
351		}
352
353		failTests := []struct {
354			sql   string
355			value *net.IPNet
356		}{
357			{"select $1::inet", mustParseCIDR(t, "192.168.1.0/24")},
358			{"select $1::cidr", mustParseCIDR(t, "192.168.1.0/24")},
359		}
360		for i, tt := range failTests {
361			var actual net.IP
362
363			err := conn.QueryRow(context.Background(), tt.sql, tt.value).Scan(&actual)
364			if err == nil {
365				t.Errorf("%d. Expected failure but got none", i)
366				continue
367			}
368
369			ensureConnValid(t, conn)
370		}
371	})
372}
373
374func TestInetCIDRArrayTranscodeIPNet(t *testing.T) {
375	t.Parallel()
376
377	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) {
378		tests := []struct {
379			sql   string
380			value []*net.IPNet
381		}{
382			{
383				"select $1::inet[]",
384				[]*net.IPNet{
385					mustParseCIDR(t, "0.0.0.0/32"),
386					mustParseCIDR(t, "127.0.0.1/32"),
387					mustParseCIDR(t, "12.34.56.0/32"),
388					mustParseCIDR(t, "192.168.1.0/24"),
389					mustParseCIDR(t, "255.0.0.0/8"),
390					mustParseCIDR(t, "255.255.255.255/32"),
391					mustParseCIDR(t, "::/128"),
392					mustParseCIDR(t, "::/0"),
393					mustParseCIDR(t, "::1/128"),
394					mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"),
395				},
396			},
397			{
398				"select $1::cidr[]",
399				[]*net.IPNet{
400					mustParseCIDR(t, "0.0.0.0/32"),
401					mustParseCIDR(t, "127.0.0.1/32"),
402					mustParseCIDR(t, "12.34.56.0/32"),
403					mustParseCIDR(t, "192.168.1.0/24"),
404					mustParseCIDR(t, "255.0.0.0/8"),
405					mustParseCIDR(t, "255.255.255.255/32"),
406					mustParseCIDR(t, "::/128"),
407					mustParseCIDR(t, "::/0"),
408					mustParseCIDR(t, "::1/128"),
409					mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"),
410				},
411			},
412		}
413
414		for i, tt := range tests {
415			var actual []*net.IPNet
416
417			err := conn.QueryRow(context.Background(), tt.sql, tt.value).Scan(&actual)
418			if err != nil {
419				t.Errorf("%d. Unexpected failure: %v (sql -> %v, value -> %v)", i, err, tt.sql, tt.value)
420				continue
421			}
422
423			if !reflect.DeepEqual(actual, tt.value) {
424				t.Errorf("%d. Expected %v, got %v (sql -> %v)", i, tt.value, actual, tt.sql)
425			}
426
427			ensureConnValid(t, conn)
428		}
429	})
430}
431
432func TestInetCIDRArrayTranscodeIP(t *testing.T) {
433	t.Parallel()
434
435	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) {
436		tests := []struct {
437			sql   string
438			value []net.IP
439		}{
440			{
441				"select $1::inet[]",
442				[]net.IP{
443					net.ParseIP("0.0.0.0"),
444					net.ParseIP("127.0.0.1"),
445					net.ParseIP("12.34.56.0"),
446					net.ParseIP("255.255.255.255"),
447					net.ParseIP("2607:f8b0:4009:80b::200e"),
448				},
449			},
450			{
451				"select $1::cidr[]",
452				[]net.IP{
453					net.ParseIP("0.0.0.0"),
454					net.ParseIP("127.0.0.1"),
455					net.ParseIP("12.34.56.0"),
456					net.ParseIP("255.255.255.255"),
457					net.ParseIP("2607:f8b0:4009:80b::200e"),
458				},
459			},
460		}
461
462		for i, tt := range tests {
463			var actual []net.IP
464
465			err := conn.QueryRow(context.Background(), tt.sql, tt.value).Scan(&actual)
466			if err != nil {
467				t.Errorf("%d. Unexpected failure: %v (sql -> %v, value -> %v)", i, err, tt.sql, tt.value)
468				continue
469			}
470
471			assert.Equal(t, len(tt.value), len(actual), "%d", i)
472			for j := range actual {
473				assert.True(t, actual[j].Equal(tt.value[j]), "%d", i)
474			}
475
476			ensureConnValid(t, conn)
477		}
478
479		failTests := []struct {
480			sql   string
481			value []*net.IPNet
482		}{
483			{
484				"select $1::inet[]",
485				[]*net.IPNet{
486					mustParseCIDR(t, "12.34.56.0/32"),
487					mustParseCIDR(t, "192.168.1.0/24"),
488				},
489			},
490			{
491				"select $1::cidr[]",
492				[]*net.IPNet{
493					mustParseCIDR(t, "12.34.56.0/32"),
494					mustParseCIDR(t, "192.168.1.0/24"),
495				},
496			},
497		}
498
499		for i, tt := range failTests {
500			var actual []net.IP
501
502			err := conn.QueryRow(context.Background(), tt.sql, tt.value).Scan(&actual)
503			if err == nil {
504				t.Errorf("%d. Expected failure but got none", i)
505				continue
506			}
507
508			ensureConnValid(t, conn)
509		}
510	})
511}
512
513func TestInetCIDRTranscodeWithJustIP(t *testing.T) {
514	t.Parallel()
515
516	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) {
517		tests := []struct {
518			sql   string
519			value string
520		}{
521			{"select $1::inet", "0.0.0.0/32"},
522			{"select $1::inet", "127.0.0.1/32"},
523			{"select $1::inet", "12.34.56.0/32"},
524			{"select $1::inet", "255.255.255.255/32"},
525			{"select $1::inet", "::/128"},
526			{"select $1::inet", "2607:f8b0:4009:80b::200e/128"},
527			{"select $1::cidr", "0.0.0.0/32"},
528			{"select $1::cidr", "127.0.0.1/32"},
529			{"select $1::cidr", "12.34.56.0/32"},
530			{"select $1::cidr", "255.255.255.255/32"},
531			{"select $1::cidr", "::/128"},
532			{"select $1::cidr", "2607:f8b0:4009:80b::200e/128"},
533		}
534
535		for i, tt := range tests {
536			expected := mustParseCIDR(t, tt.value)
537			var actual net.IPNet
538
539			err := conn.QueryRow(context.Background(), tt.sql, expected.IP).Scan(&actual)
540			if err != nil {
541				t.Errorf("%d. Unexpected failure: %v (sql -> %v, value -> %v)", i, err, tt.sql, tt.value)
542				continue
543			}
544
545			if actual.String() != expected.String() {
546				t.Errorf("%d. Expected %v, got %v (sql -> %v)", i, tt.value, actual, tt.sql)
547			}
548
549			ensureConnValid(t, conn)
550		}
551	})
552}
553
554func TestArrayDecoding(t *testing.T) {
555	t.Parallel()
556
557	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) {
558		tests := []struct {
559			sql    string
560			query  interface{}
561			scan   interface{}
562			assert func(*testing.T, interface{}, interface{})
563		}{
564			{
565				"select $1::bool[]", []bool{true, false, true}, &[]bool{},
566				func(t *testing.T, query, scan interface{}) {
567					if !reflect.DeepEqual(query, *(scan.(*[]bool))) {
568						t.Errorf("failed to encode bool[]")
569					}
570				},
571			},
572			{
573				"select $1::smallint[]", []int16{2, 4, 484, 32767}, &[]int16{},
574				func(t *testing.T, query, scan interface{}) {
575					if !reflect.DeepEqual(query, *(scan.(*[]int16))) {
576						t.Errorf("failed to encode smallint[]")
577					}
578				},
579			},
580			{
581				"select $1::smallint[]", []uint16{2, 4, 484, 32767}, &[]uint16{},
582				func(t *testing.T, query, scan interface{}) {
583					if !reflect.DeepEqual(query, *(scan.(*[]uint16))) {
584						t.Errorf("failed to encode smallint[]")
585					}
586				},
587			},
588			{
589				"select $1::int[]", []int32{2, 4, 484}, &[]int32{},
590				func(t *testing.T, query, scan interface{}) {
591					if !reflect.DeepEqual(query, *(scan.(*[]int32))) {
592						t.Errorf("failed to encode int[]")
593					}
594				},
595			},
596			{
597				"select $1::int[]", []uint32{2, 4, 484, 2147483647}, &[]uint32{},
598				func(t *testing.T, query, scan interface{}) {
599					if !reflect.DeepEqual(query, *(scan.(*[]uint32))) {
600						t.Errorf("failed to encode int[]")
601					}
602				},
603			},
604			{
605				"select $1::bigint[]", []int64{2, 4, 484, 9223372036854775807}, &[]int64{},
606				func(t *testing.T, query, scan interface{}) {
607					if !reflect.DeepEqual(query, *(scan.(*[]int64))) {
608						t.Errorf("failed to encode bigint[]")
609					}
610				},
611			},
612			{
613				"select $1::bigint[]", []uint64{2, 4, 484, 9223372036854775807}, &[]uint64{},
614				func(t *testing.T, query, scan interface{}) {
615					if !reflect.DeepEqual(query, *(scan.(*[]uint64))) {
616						t.Errorf("failed to encode bigint[]")
617					}
618				},
619			},
620			{
621				"select $1::text[]", []string{"it's", "over", "9000!"}, &[]string{},
622				func(t *testing.T, query, scan interface{}) {
623					if !reflect.DeepEqual(query, *(scan.(*[]string))) {
624						t.Errorf("failed to encode text[]")
625					}
626				},
627			},
628			{
629				"select $1::timestamptz[]", []time.Time{time.Unix(323232, 0), time.Unix(3239949334, 00)}, &[]time.Time{},
630				func(t *testing.T, query, scan interface{}) {
631					queryTimeSlice := query.([]time.Time)
632					scanTimeSlice := *(scan.(*[]time.Time))
633					require.Equal(t, len(queryTimeSlice), len(scanTimeSlice))
634					for i := range queryTimeSlice {
635						assert.Truef(t, queryTimeSlice[i].Equal(scanTimeSlice[i]), "%d", i)
636					}
637				},
638			},
639			{
640				"select $1::bytea[]", [][]byte{{0, 1, 2, 3}, {4, 5, 6, 7}}, &[][]byte{},
641				func(t *testing.T, query, scan interface{}) {
642					queryBytesSliceSlice := query.([][]byte)
643					scanBytesSliceSlice := *(scan.(*[][]byte))
644					if len(queryBytesSliceSlice) != len(scanBytesSliceSlice) {
645						t.Errorf("failed to encode byte[][] to bytea[]: expected %d to equal %d", len(queryBytesSliceSlice), len(scanBytesSliceSlice))
646					}
647					for i := range queryBytesSliceSlice {
648						qb := queryBytesSliceSlice[i]
649						sb := scanBytesSliceSlice[i]
650						if !bytes.Equal(qb, sb) {
651							t.Errorf("failed to encode byte[][] to bytea[]: expected %v to equal %v", qb, sb)
652						}
653					}
654				},
655			},
656		}
657
658		for i, tt := range tests {
659			err := conn.QueryRow(context.Background(), tt.sql, tt.query).Scan(tt.scan)
660			if err != nil {
661				t.Errorf(`%d. error reading array: %v`, i, err)
662				continue
663			}
664			tt.assert(t, tt.query, tt.scan)
665			ensureConnValid(t, conn)
666		}
667	})
668}
669
670func TestEmptyArrayDecoding(t *testing.T) {
671	t.Parallel()
672
673	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) {
674		var val []string
675
676		err := conn.QueryRow(context.Background(), "select array[]::text[]").Scan(&val)
677		if err != nil {
678			t.Errorf(`error reading array: %v`, err)
679		}
680		if len(val) != 0 {
681			t.Errorf("Expected 0 values, got %d", len(val))
682		}
683
684		var n, m int32
685
686		err = conn.QueryRow(context.Background(), "select 1::integer, array[]::text[], 42::integer").Scan(&n, &val, &m)
687		if err != nil {
688			t.Errorf(`error reading array: %v`, err)
689		}
690		if len(val) != 0 {
691			t.Errorf("Expected 0 values, got %d", len(val))
692		}
693		if n != 1 {
694			t.Errorf("Expected n to be 1, but it was %d", n)
695		}
696		if m != 42 {
697			t.Errorf("Expected n to be 42, but it was %d", n)
698		}
699
700		rows, err := conn.Query(context.Background(), "select 1::integer, array['test']::text[] union select 2::integer, array[]::text[] union select 3::integer, array['test']::text[]")
701		if err != nil {
702			t.Errorf(`error retrieving rows with array: %v`, err)
703		}
704		defer rows.Close()
705
706		for rows.Next() {
707			err = rows.Scan(&n, &val)
708			if err != nil {
709				t.Errorf(`error reading array: %v`, err)
710			}
711		}
712	})
713}
714
715func TestPointerPointer(t *testing.T) {
716	t.Parallel()
717
718	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) {
719		type allTypes struct {
720			s   *string
721			i16 *int16
722			i32 *int32
723			i64 *int64
724			f32 *float32
725			f64 *float64
726			b   *bool
727			t   *time.Time
728		}
729
730		var actual, zero, expected allTypes
731
732		{
733			s := "foo"
734			expected.s = &s
735			i16 := int16(1)
736			expected.i16 = &i16
737			i32 := int32(1)
738			expected.i32 = &i32
739			i64 := int64(1)
740			expected.i64 = &i64
741			f32 := float32(1.23)
742			expected.f32 = &f32
743			f64 := float64(1.23)
744			expected.f64 = &f64
745			b := true
746			expected.b = &b
747			t := time.Unix(123, 5000)
748			expected.t = &t
749		}
750
751		tests := []struct {
752			sql       string
753			queryArgs []interface{}
754			scanArgs  []interface{}
755			expected  allTypes
756		}{
757			{"select $1::text", []interface{}{expected.s}, []interface{}{&actual.s}, allTypes{s: expected.s}},
758			{"select $1::text", []interface{}{zero.s}, []interface{}{&actual.s}, allTypes{}},
759			{"select $1::int2", []interface{}{expected.i16}, []interface{}{&actual.i16}, allTypes{i16: expected.i16}},
760			{"select $1::int2", []interface{}{zero.i16}, []interface{}{&actual.i16}, allTypes{}},
761			{"select $1::int4", []interface{}{expected.i32}, []interface{}{&actual.i32}, allTypes{i32: expected.i32}},
762			{"select $1::int4", []interface{}{zero.i32}, []interface{}{&actual.i32}, allTypes{}},
763			{"select $1::int8", []interface{}{expected.i64}, []interface{}{&actual.i64}, allTypes{i64: expected.i64}},
764			{"select $1::int8", []interface{}{zero.i64}, []interface{}{&actual.i64}, allTypes{}},
765			{"select $1::float4", []interface{}{expected.f32}, []interface{}{&actual.f32}, allTypes{f32: expected.f32}},
766			{"select $1::float4", []interface{}{zero.f32}, []interface{}{&actual.f32}, allTypes{}},
767			{"select $1::float8", []interface{}{expected.f64}, []interface{}{&actual.f64}, allTypes{f64: expected.f64}},
768			{"select $1::float8", []interface{}{zero.f64}, []interface{}{&actual.f64}, allTypes{}},
769			{"select $1::bool", []interface{}{expected.b}, []interface{}{&actual.b}, allTypes{b: expected.b}},
770			{"select $1::bool", []interface{}{zero.b}, []interface{}{&actual.b}, allTypes{}},
771			{"select $1::timestamptz", []interface{}{expected.t}, []interface{}{&actual.t}, allTypes{t: expected.t}},
772			{"select $1::timestamptz", []interface{}{zero.t}, []interface{}{&actual.t}, allTypes{}},
773		}
774
775		for i, tt := range tests {
776			actual = zero
777
778			err := conn.QueryRow(context.Background(), tt.sql, tt.queryArgs...).Scan(tt.scanArgs...)
779			if err != nil {
780				t.Errorf("%d. Unexpected failure: %v (sql -> %v, queryArgs -> %v)", i, err, tt.sql, tt.queryArgs)
781			}
782
783			if !reflect.DeepEqual(actual, tt.expected) {
784				t.Errorf("%d. Expected %v, got %v (sql -> %v, queryArgs -> %v)", i, tt.expected, actual, tt.sql, tt.queryArgs)
785			}
786
787			ensureConnValid(t, conn)
788		}
789	})
790}
791
792func TestPointerPointerNonZero(t *testing.T) {
793	t.Parallel()
794
795	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) {
796		f := "foo"
797		dest := &f
798
799		err := conn.QueryRow(context.Background(), "select $1::text", nil).Scan(&dest)
800		if err != nil {
801			t.Errorf("Unexpected failure scanning: %v", err)
802		}
803		if dest != nil {
804			t.Errorf("Expected dest to be nil, got %#v", dest)
805		}
806	})
807}
808
809func TestEncodeTypeRename(t *testing.T) {
810	t.Parallel()
811
812	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) {
813		type _int int
814		inInt := _int(1)
815		var outInt _int
816
817		type _int8 int8
818		inInt8 := _int8(2)
819		var outInt8 _int8
820
821		type _int16 int16
822		inInt16 := _int16(3)
823		var outInt16 _int16
824
825		type _int32 int32
826		inInt32 := _int32(4)
827		var outInt32 _int32
828
829		type _int64 int64
830		inInt64 := _int64(5)
831		var outInt64 _int64
832
833		type _uint uint
834		inUint := _uint(6)
835		var outUint _uint
836
837		type _uint8 uint8
838		inUint8 := _uint8(7)
839		var outUint8 _uint8
840
841		type _uint16 uint16
842		inUint16 := _uint16(8)
843		var outUint16 _uint16
844
845		type _uint32 uint32
846		inUint32 := _uint32(9)
847		var outUint32 _uint32
848
849		type _uint64 uint64
850		inUint64 := _uint64(10)
851		var outUint64 _uint64
852
853		type _string string
854		inString := _string("foo")
855		var outString _string
856
857		err := conn.QueryRow(context.Background(), "select $1::int, $2::int, $3::int2, $4::int4, $5::int8, $6::int, $7::int, $8::int, $9::int, $10::int, $11::text",
858			inInt, inInt8, inInt16, inInt32, inInt64, inUint, inUint8, inUint16, inUint32, inUint64, inString,
859		).Scan(&outInt, &outInt8, &outInt16, &outInt32, &outInt64, &outUint, &outUint8, &outUint16, &outUint32, &outUint64, &outString)
860		if err != nil {
861			t.Fatalf("Failed with type rename: %v", err)
862		}
863
864		if inInt != outInt {
865			t.Errorf("int rename: expected %v, got %v", inInt, outInt)
866		}
867
868		if inInt8 != outInt8 {
869			t.Errorf("int8 rename: expected %v, got %v", inInt8, outInt8)
870		}
871
872		if inInt16 != outInt16 {
873			t.Errorf("int16 rename: expected %v, got %v", inInt16, outInt16)
874		}
875
876		if inInt32 != outInt32 {
877			t.Errorf("int32 rename: expected %v, got %v", inInt32, outInt32)
878		}
879
880		if inInt64 != outInt64 {
881			t.Errorf("int64 rename: expected %v, got %v", inInt64, outInt64)
882		}
883
884		if inUint != outUint {
885			t.Errorf("uint rename: expected %v, got %v", inUint, outUint)
886		}
887
888		if inUint8 != outUint8 {
889			t.Errorf("uint8 rename: expected %v, got %v", inUint8, outUint8)
890		}
891
892		if inUint16 != outUint16 {
893			t.Errorf("uint16 rename: expected %v, got %v", inUint16, outUint16)
894		}
895
896		if inUint32 != outUint32 {
897			t.Errorf("uint32 rename: expected %v, got %v", inUint32, outUint32)
898		}
899
900		if inUint64 != outUint64 {
901			t.Errorf("uint64 rename: expected %v, got %v", inUint64, outUint64)
902		}
903
904		if inString != outString {
905			t.Errorf("string rename: expected %v, got %v", inString, outString)
906		}
907	})
908}
909
910func TestRowDecodeBinary(t *testing.T) {
911	t.Parallel()
912
913	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
914	defer closeConn(t, conn)
915
916	tests := []struct {
917		sql      string
918		expected []interface{}
919	}{
920		{
921			"select row(1, 'cat', '2015-01-01 08:12:42-00'::timestamptz)",
922			[]interface{}{
923				int32(1),
924				"cat",
925				time.Date(2015, 1, 1, 8, 12, 42, 0, time.UTC).Local(),
926			},
927		},
928		{
929			"select row(100.0::float, 1.09::float)",
930			[]interface{}{
931				float64(100),
932				float64(1.09),
933			},
934		},
935	}
936
937	for i, tt := range tests {
938		var actual []interface{}
939
940		err := conn.QueryRow(context.Background(), tt.sql).Scan(&actual)
941		if err != nil {
942			t.Errorf("%d. Unexpected failure: %v (sql -> %v)", i, err, tt.sql)
943			continue
944		}
945
946		if !reflect.DeepEqual(actual, tt.expected) {
947			t.Errorf("%d. Expected %v, got %v (sql -> %v)", i, tt.expected, actual, tt.sql)
948		}
949
950		ensureConnValid(t, conn)
951	}
952}
953
954// https://github.com/jackc/pgx/issues/810
955func TestRowsScanNilThenScanValue(t *testing.T) {
956	t.Parallel()
957
958	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) {
959		sql := `select null as a, null as b
960union
961select 1, 2
962order by a nulls first
963`
964		rows, err := conn.Query(context.Background(), sql)
965		require.NoError(t, err)
966
967		require.True(t, rows.Next())
968
969		err = rows.Scan(nil, nil)
970		require.NoError(t, err)
971
972		require.True(t, rows.Next())
973
974		var a int
975		var b int
976		err = rows.Scan(&a, &b)
977		require.NoError(t, err)
978
979		require.EqualValues(t, 1, a)
980		require.EqualValues(t, 2, b)
981
982		rows.Close()
983		require.NoError(t, rows.Err())
984	})
985}
986
987func TestScanIntoByteSlice(t *testing.T) {
988	t.Parallel()
989
990	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
991	defer closeConn(t, conn)
992	// Success cases
993	for _, tt := range []struct {
994		name             string
995		sql              string
996		resultFormatCode int16
997		output           []byte
998	}{
999		{"int - text", "select 42", pgx.TextFormatCode, []byte("42")},
1000		{"text - text", "select 'hi'", pgx.TextFormatCode, []byte("hi")},
1001		{"text - binary", "select 'hi'", pgx.BinaryFormatCode, []byte("hi")},
1002		{"json - text", "select '{}'::json", pgx.TextFormatCode, []byte("{}")},
1003		{"json - binary", "select '{}'::json", pgx.BinaryFormatCode, []byte("{}")},
1004		{"jsonb - text", "select '{}'::jsonb", pgx.TextFormatCode, []byte("{}")},
1005		{"jsonb - binary", "select '{}'::jsonb", pgx.BinaryFormatCode, []byte("{}")},
1006	} {
1007		t.Run(tt.name, func(t *testing.T) {
1008			var buf []byte
1009			err := conn.QueryRow(context.Background(), tt.sql, pgx.QueryResultFormats{tt.resultFormatCode}).Scan(&buf)
1010			require.NoError(t, err)
1011			require.Equal(t, tt.output, buf)
1012		})
1013	}
1014
1015	// Failure cases
1016	for _, tt := range []struct {
1017		name string
1018		sql  string
1019		err  string
1020	}{
1021		{"int binary", "select 42", "can't scan into dest[0]: cannot assign 42 into *[]uint8"},
1022	} {
1023		t.Run(tt.name, func(t *testing.T) {
1024			var buf []byte
1025			err := conn.QueryRow(context.Background(), tt.sql, pgx.QueryResultFormats{pgx.BinaryFormatCode}).Scan(&buf)
1026			require.EqualError(t, err, tt.err)
1027		})
1028	}
1029}
1030