1package pgtype_test
2
3import (
4	"context"
5	"fmt"
6	"os"
7	"testing"
8
9	"github.com/jackc/pgtype"
10	"github.com/jackc/pgtype/testutil"
11	pgx "github.com/jackc/pgx/v4"
12	"github.com/stretchr/testify/assert"
13	"github.com/stretchr/testify/require"
14)
15
16func TestCompositeTypeSetAndGet(t *testing.T) {
17	ci := pgtype.NewConnInfo()
18	ct, err := pgtype.NewCompositeType("test", []pgtype.CompositeTypeField{
19		{"a", pgtype.TextOID},
20		{"b", pgtype.Int4OID},
21	}, ci)
22	require.NoError(t, err)
23	assert.Equal(t, pgtype.Undefined, ct.Get())
24
25	nilTests := []struct {
26		src interface{}
27	}{
28		{nil},                   // nil interface
29		{(*[]interface{})(nil)}, // typed nil
30	}
31
32	for i, tt := range nilTests {
33		err := ct.Set(tt.src)
34		assert.NoErrorf(t, err, "%d", i)
35		assert.Equal(t, nil, ct.Get())
36	}
37
38	compatibleValuesTests := []struct {
39		src      []interface{}
40		expected map[string]interface{}
41	}{
42		{
43			src:      []interface{}{"foo", int32(42)},
44			expected: map[string]interface{}{"a": "foo", "b": int32(42)},
45		},
46		{
47			src:      []interface{}{nil, nil},
48			expected: map[string]interface{}{"a": nil, "b": nil},
49		},
50		{
51			src:      []interface{}{&pgtype.Text{String: "hi", Status: pgtype.Present}, &pgtype.Int4{Int: 7, Status: pgtype.Present}},
52			expected: map[string]interface{}{"a": "hi", "b": int32(7)},
53		},
54	}
55
56	for i, tt := range compatibleValuesTests {
57		err := ct.Set(tt.src)
58		assert.NoErrorf(t, err, "%d", i)
59		assert.EqualValues(t, tt.expected, ct.Get())
60	}
61}
62
63func TestCompositeTypeAssignTo(t *testing.T) {
64	ci := pgtype.NewConnInfo()
65	ct, err := pgtype.NewCompositeType("test", []pgtype.CompositeTypeField{
66		{"a", pgtype.TextOID},
67		{"b", pgtype.Int4OID},
68	}, ci)
69	require.NoError(t, err)
70
71	{
72		err := ct.Set([]interface{}{"foo", int32(42)})
73		assert.NoError(t, err)
74
75		var a string
76		var b int32
77
78		err = ct.AssignTo([]interface{}{&a, &b})
79		assert.NoError(t, err)
80
81		assert.Equal(t, "foo", a)
82		assert.Equal(t, int32(42), b)
83	}
84
85	{
86		err := ct.Set([]interface{}{"foo", int32(42)})
87		assert.NoError(t, err)
88
89		var a pgtype.Text
90		var b pgtype.Int4
91
92		err = ct.AssignTo([]interface{}{&a, &b})
93		assert.NoError(t, err)
94
95		assert.Equal(t, pgtype.Text{String: "foo", Status: pgtype.Present}, a)
96		assert.Equal(t, pgtype.Int4{Int: 42, Status: pgtype.Present}, b)
97	}
98
99	// Allow nil destination component as no-op
100	{
101		err := ct.Set([]interface{}{"foo", int32(42)})
102		assert.NoError(t, err)
103
104		var b int32
105
106		err = ct.AssignTo([]interface{}{nil, &b})
107		assert.NoError(t, err)
108
109		assert.Equal(t, int32(42), b)
110	}
111
112	// *[]interface{} dest when null
113	{
114		err := ct.Set(nil)
115		assert.NoError(t, err)
116
117		var a pgtype.Text
118		var b pgtype.Int4
119		dst := []interface{}{&a, &b}
120
121		err = ct.AssignTo(&dst)
122		assert.NoError(t, err)
123
124		assert.Nil(t, dst)
125	}
126
127	// *[]interface{} dest when not null
128	{
129		err := ct.Set([]interface{}{"foo", int32(42)})
130		assert.NoError(t, err)
131
132		var a pgtype.Text
133		var b pgtype.Int4
134		dst := []interface{}{&a, &b}
135
136		err = ct.AssignTo(&dst)
137		assert.NoError(t, err)
138
139		assert.NotNil(t, dst)
140		assert.Equal(t, pgtype.Text{String: "foo", Status: pgtype.Present}, a)
141		assert.Equal(t, pgtype.Int4{Int: 42, Status: pgtype.Present}, b)
142	}
143
144	// Struct fields positionally via reflection
145	{
146		err := ct.Set([]interface{}{"foo", int32(42)})
147		assert.NoError(t, err)
148
149		s := struct {
150			A string
151			B int32
152		}{}
153
154		err = ct.AssignTo(&s)
155		if assert.NoError(t, err) {
156			assert.Equal(t, "foo", s.A)
157			assert.Equal(t, int32(42), s.B)
158		}
159	}
160}
161
162func TestCompositeTypeTranscode(t *testing.T) {
163	conn := testutil.MustConnectPgx(t)
164	defer testutil.MustCloseContext(t, conn)
165
166	_, err := conn.Exec(context.Background(), `drop type if exists ct_test;
167
168create type ct_test as (
169	a text,
170  b int4
171);`)
172	require.NoError(t, err)
173	defer conn.Exec(context.Background(), "drop type ct_test")
174
175	var oid uint32
176	err = conn.QueryRow(context.Background(), `select 'ct_test'::regtype::oid`).Scan(&oid)
177	require.NoError(t, err)
178
179	defer conn.Exec(context.Background(), "drop type ct_test")
180
181	ct, err := pgtype.NewCompositeType("ct_test", []pgtype.CompositeTypeField{
182		{"a", pgtype.TextOID},
183		{"b", pgtype.Int4OID},
184	}, conn.ConnInfo())
185	require.NoError(t, err)
186	conn.ConnInfo().RegisterDataType(pgtype.DataType{Value: ct, Name: ct.TypeName(), OID: oid})
187
188	// Use simple protocol to force text or binary encoding
189	simpleProtocols := []bool{true, false}
190
191	var a string
192	var b int32
193
194	for _, simpleProtocol := range simpleProtocols {
195		err := conn.QueryRow(context.Background(), "select $1::ct_test", pgx.QuerySimpleProtocol(simpleProtocol),
196			pgtype.CompositeFields{"hi", int32(42)},
197		).Scan(
198			[]interface{}{&a, &b},
199		)
200		if assert.NoErrorf(t, err, "Simple Protocol: %v", simpleProtocol) {
201			assert.EqualValuesf(t, "hi", a, "Simple Protocol: %v", simpleProtocol)
202			assert.EqualValuesf(t, 42, b, "Simple Protocol: %v", simpleProtocol)
203		}
204	}
205}
206
207// https://github.com/jackc/pgx/issues/874
208func TestCompositeTypeTextDecodeNested(t *testing.T) {
209	newCompositeType := func(name string, fieldNames []string, vals ...pgtype.ValueTranscoder) *pgtype.CompositeType {
210		fields := make([]pgtype.CompositeTypeField, len(fieldNames))
211		for i, name := range fieldNames {
212			fields[i] = pgtype.CompositeTypeField{Name: name}
213		}
214
215		rowType, err := pgtype.NewCompositeTypeValues(name, fields, vals)
216		require.NoError(t, err)
217		return rowType
218	}
219
220	dimensionsType := func() pgtype.ValueTranscoder {
221		return newCompositeType(
222			"dimensions",
223			[]string{"width", "height"},
224			&pgtype.Int4{},
225			&pgtype.Int4{},
226		)
227	}
228	productImageType := func() pgtype.ValueTranscoder {
229		return newCompositeType(
230			"product_image_type",
231			[]string{"source", "dimensions"},
232			&pgtype.Text{},
233			dimensionsType(),
234		)
235	}
236	productImageSetType := newCompositeType(
237		"product_image_set_type",
238		[]string{"name", "orig_image", "images"},
239		&pgtype.Text{},
240		productImageType(),
241		pgtype.NewArrayType("product_image", 0, func() pgtype.ValueTranscoder {
242			return productImageType()
243		}),
244	)
245
246	err := productImageSetType.DecodeText(nil, []byte(`(name,"(img1,""(11,11)"")","{""(img2,\\""(22,22)\\"")"",""(img3,\\""(33,33)\\"")""}")`))
247	require.NoError(t, err)
248}
249
250func Example_composite() {
251	conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
252	if err != nil {
253		fmt.Println(err)
254		return
255	}
256
257	defer conn.Close(context.Background())
258	_, err = conn.Exec(context.Background(), `drop type if exists mytype;`)
259	if err != nil {
260		fmt.Println(err)
261		return
262	}
263
264	_, err = conn.Exec(context.Background(), `create type mytype as (
265  a int4,
266  b text
267);`)
268	if err != nil {
269		fmt.Println(err)
270		return
271	}
272	defer conn.Exec(context.Background(), "drop type mytype")
273
274	var oid uint32
275	err = conn.QueryRow(context.Background(), `select 'mytype'::regtype::oid`).Scan(&oid)
276	if err != nil {
277		fmt.Println(err)
278		return
279	}
280
281	ct, err := pgtype.NewCompositeType("mytype", []pgtype.CompositeTypeField{
282		{"a", pgtype.Int4OID},
283		{"b", pgtype.TextOID},
284	}, conn.ConnInfo())
285	if err != nil {
286		fmt.Println(err)
287		return
288	}
289	conn.ConnInfo().RegisterDataType(pgtype.DataType{Value: ct, Name: ct.TypeName(), OID: oid})
290
291	var a int
292	var b *string
293
294	err = conn.QueryRow(context.Background(), "select $1::mytype", []interface{}{2, "bar"}).Scan([]interface{}{&a, &b})
295	if err != nil {
296		fmt.Println(err)
297		return
298	}
299
300	fmt.Printf("First: a=%d b=%s\n", a, *b)
301
302	err = conn.QueryRow(context.Background(), "select (1, NULL)::mytype").Scan([]interface{}{&a, &b})
303	if err != nil {
304		fmt.Println(err)
305		return
306	}
307
308	fmt.Printf("Second: a=%d b=%v\n", a, b)
309
310	scanTarget := []interface{}{&a, &b}
311	err = conn.QueryRow(context.Background(), "select NULL::mytype").Scan(&scanTarget)
312	E(err)
313
314	fmt.Printf("Third: isNull=%v\n", scanTarget == nil)
315
316	// Output:
317	// First: a=2 b=bar
318	// Second: a=1 b=<nil>
319	// Third: isNull=true
320}
321