1package pgtype_test
2
3import (
4	"bytes"
5	"errors"
6	"net"
7	"testing"
8
9	"github.com/jackc/pgtype"
10	"github.com/jackc/pgx/v4"
11	_ "github.com/jackc/pgx/v4/stdlib"
12	_ "github.com/lib/pq"
13	"github.com/stretchr/testify/assert"
14	"github.com/stretchr/testify/require"
15)
16
// Renamed (defined) types whose underlying types are builtins. Tests use them to
// check that scanning handles named types and not only the builtin types
// themselves (_string and _byteSlice are used below; the remaining types are
// presumably used by other test files in this package — TODO confirm).
type (
	_string       string
	_bool         bool
	_int8         int8
	_int16        int16
	_int16Slice   []int16
	_int32Slice   []int32
	_int64Slice   []int64
	_float32Slice []float32
	_float64Slice []float64
	_byteSlice    []byte
)
28
// mustParseCIDR parses s with net.ParseCIDR and returns the network it
// describes (IPNet.IP holds the address masked to the prefix length).
// It fails the test immediately if s is not a valid CIDR string.
func mustParseCIDR(t testing.TB, s string) *net.IPNet {
	// Mark as a helper so a failure is attributed to the caller's line.
	t.Helper()

	_, ipnet, err := net.ParseCIDR(s)
	if err != nil {
		t.Fatal(err)
	}

	return ipnet
}
37
// mustParseInet parses s, an address with a prefix length such as
// "192.168.1.5/24", into a *net.IPNet. Unlike mustParseCIDR, IPNet.IP holds
// the full address from s rather than the masked network address. IPv4
// addresses are stored in their 4-byte form. It fails the test immediately if
// s is invalid.
func mustParseInet(t testing.TB, s string) *net.IPNet {
	// Mark as a helper so a failure is attributed to the caller's line.
	t.Helper()

	ip, ipnet, err := net.ParseCIDR(s)
	if err != nil {
		t.Fatal(err)
	}

	// net.ParseCIDR may return an IPv4 address in its 16-byte form; normalize it
	// to 4 bytes (presumably so results compare equal to decoded inet values —
	// confirm against the callers).
	if ipv4 := ip.To4(); ipv4 != nil {
		ip = ipv4
	}

	// Keep the host address instead of the network address ParseCIDR stored.
	ipnet.IP = ip

	return ipnet
}
51
// mustParseMacaddr parses s with net.ParseMAC and returns the hardware address.
// It fails the test immediately if s is not a valid MAC address.
func mustParseMacaddr(t testing.TB, s string) net.HardwareAddr {
	// Mark as a helper so a failure is attributed to the caller's line.
	t.Helper()

	addr, err := net.ParseMAC(s)
	if err != nil {
		t.Fatal(err)
	}

	return addr
}
60
61func TestConnInfoResultFormatCodeForOID(t *testing.T) {
62	ci := pgtype.NewConnInfo()
63
64	// pgtype.JSONB implements BinaryDecoder but also implements ResultFormatPreferrer to override it to text.
65	assert.Equal(t, int16(pgtype.TextFormatCode), ci.ResultFormatCodeForOID(pgtype.JSONBOID))
66
67	// pgtype.Int4 implements BinaryDecoder but does not implement ResultFormatPreferrer so it should be binary.
68	assert.Equal(t, int16(pgtype.BinaryFormatCode), ci.ResultFormatCodeForOID(pgtype.Int4OID))
69}
70
71func TestConnInfoParamFormatCodeForOID(t *testing.T) {
72	ci := pgtype.NewConnInfo()
73
74	// pgtype.JSONB implements BinaryEncoder but also implements ParamFormatPreferrer to override it to text.
75	assert.Equal(t, int16(pgtype.TextFormatCode), ci.ParamFormatCodeForOID(pgtype.JSONBOID))
76
77	// pgtype.Int4 implements BinaryEncoder but does not implement ParamFormatPreferrer so it should be binary.
78	assert.Equal(t, int16(pgtype.BinaryFormatCode), ci.ParamFormatCodeForOID(pgtype.Int4OID))
79}
80
81func TestConnInfoScanNilIsNoOp(t *testing.T) {
82	ci := pgtype.NewConnInfo()
83
84	err := ci.Scan(pgtype.TextOID, pgx.TextFormatCode, []byte("foo"), nil)
85	assert.NoError(t, err)
86}
87
88func TestConnInfoScanTextFormatInterfacePtr(t *testing.T) {
89	ci := pgtype.NewConnInfo()
90	var got interface{}
91	err := ci.Scan(pgtype.TextOID, pgx.TextFormatCode, []byte("foo"), &got)
92	require.NoError(t, err)
93	assert.Equal(t, "foo", got)
94}
95
96func TestConnInfoScanTextFormatNonByteaIntoByteSlice(t *testing.T) {
97	ci := pgtype.NewConnInfo()
98	var got []byte
99	err := ci.Scan(pgtype.JSONBOID, pgx.TextFormatCode, []byte("{}"), &got)
100	require.NoError(t, err)
101	assert.Equal(t, []byte("{}"), got)
102}
103
104func TestConnInfoScanBinaryFormatInterfacePtr(t *testing.T) {
105	ci := pgtype.NewConnInfo()
106	var got interface{}
107	err := ci.Scan(pgtype.TextOID, pgx.BinaryFormatCode, []byte("foo"), &got)
108	require.NoError(t, err)
109	assert.Equal(t, "foo", got)
110}
111
112func TestConnInfoScanUnknownOIDToStringsAndBytes(t *testing.T) {
113	unknownOID := uint32(999999)
114	srcBuf := []byte("foo")
115	ci := pgtype.NewConnInfo()
116
117	var s string
118	err := ci.Scan(unknownOID, pgx.TextFormatCode, srcBuf, &s)
119	assert.NoError(t, err)
120	assert.Equal(t, "foo", s)
121
122	var rs _string
123	err = ci.Scan(unknownOID, pgx.TextFormatCode, srcBuf, &rs)
124	assert.NoError(t, err)
125	assert.Equal(t, "foo", string(rs))
126
127	var b []byte
128	err = ci.Scan(unknownOID, pgx.TextFormatCode, srcBuf, &b)
129	assert.NoError(t, err)
130	assert.Equal(t, []byte("foo"), b)
131
132	err = ci.Scan(unknownOID, pgx.BinaryFormatCode, srcBuf, &b)
133	assert.NoError(t, err)
134	assert.Equal(t, []byte("foo"), b)
135
136	var rb _byteSlice
137	err = ci.Scan(unknownOID, pgx.TextFormatCode, srcBuf, &rb)
138	assert.NoError(t, err)
139	assert.Equal(t, []byte("foo"), []byte(rb))
140
141	err = ci.Scan(unknownOID, pgx.BinaryFormatCode, srcBuf, &b)
142	assert.NoError(t, err)
143	assert.Equal(t, []byte("foo"), []byte(rb))
144}
145
// pgCustomType is a minimal two-field value type. It is used to test scanning
// an unregistered OID into a user-defined type that supplies its own
// DecodeText method.
type pgCustomType struct {
	a, b string
}
150
151func (ct *pgCustomType) DecodeText(ci *pgtype.ConnInfo, buf []byte) error {
152	// This is not a complete parser for the text format of composite types. This is just for test purposes.
153	if buf == nil {
154		return errors.New("cannot parse null")
155	}
156
157	if len(buf) < 2 {
158		return errors.New("invalid text format")
159	}
160
161	parts := bytes.Split(buf[1:len(buf)-1], []byte(","))
162	if len(parts) != 2 {
163		return errors.New("wrong number of parts")
164	}
165
166	ct.a = string(parts[0])
167	ct.b = string(parts[1])
168
169	return nil
170}
171
172func TestConnInfoScanUnregisteredOIDToCustomType(t *testing.T) {
173	unregisteredOID := uint32(999999)
174	ci := pgtype.NewConnInfo()
175
176	var ct pgCustomType
177	err := ci.Scan(unregisteredOID, pgx.TextFormatCode, []byte("(foo,bar)"), &ct)
178	assert.NoError(t, err)
179	assert.Equal(t, "foo", ct.a)
180	assert.Equal(t, "bar", ct.b)
181
182	// Scan value into pointer to custom type
183	var pCt *pgCustomType
184	err = ci.Scan(unregisteredOID, pgx.TextFormatCode, []byte("(foo,bar)"), &pCt)
185	assert.NoError(t, err)
186	require.NotNil(t, pCt)
187	assert.Equal(t, "foo", pCt.a)
188	assert.Equal(t, "bar", pCt.b)
189
190	// Scan null into pointer to custom type
191	err = ci.Scan(unregisteredOID, pgx.TextFormatCode, nil, &pCt)
192	assert.NoError(t, err)
193	assert.Nil(t, pCt)
194}
195
196func TestConnInfoScanUnknownOIDTextFormat(t *testing.T) {
197	ci := pgtype.NewConnInfo()
198
199	var n int32
200	err := ci.Scan(0, pgx.TextFormatCode, []byte("123"), &n)
201	assert.NoError(t, err)
202	assert.EqualValues(t, 123, n)
203}
204
205func BenchmarkConnInfoScanInt4IntoBinaryDecoder(b *testing.B) {
206	ci := pgtype.NewConnInfo()
207	src := []byte{0, 0, 0, 42}
208	var v pgtype.Int4
209
210	for i := 0; i < b.N; i++ {
211		v = pgtype.Int4{}
212		err := ci.Scan(pgtype.Int4OID, pgtype.BinaryFormatCode, src, &v)
213		if err != nil {
214			b.Fatal(err)
215		}
216		if v != (pgtype.Int4{Int: 42, Status: pgtype.Present}) {
217			b.Fatal("scan failed due to bad value")
218		}
219	}
220}
221
222func TestScanPlanBinaryInt32ScanChangedType(t *testing.T) {
223	ci := pgtype.NewConnInfo()
224	src := []byte{0, 0, 0, 42}
225	var v int32
226
227	plan := ci.PlanScan(pgtype.Int4OID, pgtype.BinaryFormatCode, &v)
228	err := plan.Scan(ci, pgtype.Int4OID, pgtype.BinaryFormatCode, src, &v)
229	require.NoError(t, err)
230	require.EqualValues(t, 42, v)
231
232	var d pgtype.Int4
233	err = plan.Scan(ci, pgtype.Int4OID, pgtype.BinaryFormatCode, src, &d)
234	require.NoError(t, err)
235	require.EqualValues(t, 42, d.Int)
236	require.EqualValues(t, pgtype.Present, d.Status)
237}
238
239func BenchmarkConnInfoScanInt4IntoGoInt32(b *testing.B) {
240	ci := pgtype.NewConnInfo()
241	src := []byte{0, 0, 0, 42}
242	var v int32
243
244	for i := 0; i < b.N; i++ {
245		v = 0
246		err := ci.Scan(pgtype.Int4OID, pgtype.BinaryFormatCode, src, &v)
247		if err != nil {
248			b.Fatal(err)
249		}
250		if v != 42 {
251			b.Fatal("scan failed due to bad value")
252		}
253	}
254}
255
256func BenchmarkScanPlanScanInt4IntoBinaryDecoder(b *testing.B) {
257	ci := pgtype.NewConnInfo()
258	src := []byte{0, 0, 0, 42}
259	var v pgtype.Int4
260
261	plan := ci.PlanScan(pgtype.Int4OID, pgtype.BinaryFormatCode, &v)
262
263	for i := 0; i < b.N; i++ {
264		v = pgtype.Int4{}
265		err := plan.Scan(ci, pgtype.Int4OID, pgtype.BinaryFormatCode, src, &v)
266		if err != nil {
267			b.Fatal(err)
268		}
269		if v != (pgtype.Int4{Int: 42, Status: pgtype.Present}) {
270			b.Fatal("scan failed due to bad value")
271		}
272	}
273}
274
275func BenchmarkScanPlanScanInt4IntoGoInt32(b *testing.B) {
276	ci := pgtype.NewConnInfo()
277	src := []byte{0, 0, 0, 42}
278	var v int32
279
280	plan := ci.PlanScan(pgtype.Int4OID, pgtype.BinaryFormatCode, &v)
281
282	for i := 0; i < b.N; i++ {
283		v = 0
284		err := plan.Scan(ci, pgtype.Int4OID, pgtype.BinaryFormatCode, src, &v)
285		if err != nil {
286			b.Fatal(err)
287		}
288		if v != 42 {
289			b.Fatal("scan failed due to bad value")
290		}
291	}
292}
293