1package pgtype_test 2 3import ( 4 "context" 5 "fmt" 6 "os" 7 "testing" 8 9 "github.com/jackc/pgtype" 10 "github.com/jackc/pgtype/testutil" 11 pgx "github.com/jackc/pgx/v4" 12 "github.com/stretchr/testify/assert" 13 "github.com/stretchr/testify/require" 14) 15 16func TestCompositeTypeSetAndGet(t *testing.T) { 17 ci := pgtype.NewConnInfo() 18 ct, err := pgtype.NewCompositeType("test", []pgtype.CompositeTypeField{ 19 {"a", pgtype.TextOID}, 20 {"b", pgtype.Int4OID}, 21 }, ci) 22 require.NoError(t, err) 23 assert.Equal(t, pgtype.Undefined, ct.Get()) 24 25 nilTests := []struct { 26 src interface{} 27 }{ 28 {nil}, // nil interface 29 {(*[]interface{})(nil)}, // typed nil 30 } 31 32 for i, tt := range nilTests { 33 err := ct.Set(tt.src) 34 assert.NoErrorf(t, err, "%d", i) 35 assert.Equal(t, nil, ct.Get()) 36 } 37 38 compatibleValuesTests := []struct { 39 src []interface{} 40 expected map[string]interface{} 41 }{ 42 { 43 src: []interface{}{"foo", int32(42)}, 44 expected: map[string]interface{}{"a": "foo", "b": int32(42)}, 45 }, 46 { 47 src: []interface{}{nil, nil}, 48 expected: map[string]interface{}{"a": nil, "b": nil}, 49 }, 50 { 51 src: []interface{}{&pgtype.Text{String: "hi", Status: pgtype.Present}, &pgtype.Int4{Int: 7, Status: pgtype.Present}}, 52 expected: map[string]interface{}{"a": "hi", "b": int32(7)}, 53 }, 54 } 55 56 for i, tt := range compatibleValuesTests { 57 err := ct.Set(tt.src) 58 assert.NoErrorf(t, err, "%d", i) 59 assert.EqualValues(t, tt.expected, ct.Get()) 60 } 61} 62 63func TestCompositeTypeAssignTo(t *testing.T) { 64 ci := pgtype.NewConnInfo() 65 ct, err := pgtype.NewCompositeType("test", []pgtype.CompositeTypeField{ 66 {"a", pgtype.TextOID}, 67 {"b", pgtype.Int4OID}, 68 }, ci) 69 require.NoError(t, err) 70 71 { 72 err := ct.Set([]interface{}{"foo", int32(42)}) 73 assert.NoError(t, err) 74 75 var a string 76 var b int32 77 78 err = ct.AssignTo([]interface{}{&a, &b}) 79 assert.NoError(t, err) 80 81 assert.Equal(t, "foo", a) 82 assert.Equal(t, int32(42), b) 83 } 84 85 { 86 err := ct.Set([]interface{}{"foo", int32(42)}) 87 assert.NoError(t, err) 88 89 var a pgtype.Text 90 var b pgtype.Int4 91 92 err = ct.AssignTo([]interface{}{&a, &b}) 93 assert.NoError(t, err) 94 95 assert.Equal(t, pgtype.Text{String: "foo", Status: pgtype.Present}, a) 96 assert.Equal(t, pgtype.Int4{Int: 42, Status: pgtype.Present}, b) 97 } 98 99 // Allow nil destination component as no-op 100 { 101 err := ct.Set([]interface{}{"foo", int32(42)}) 102 assert.NoError(t, err) 103 104 var b int32 105 106 err = ct.AssignTo([]interface{}{nil, &b}) 107 assert.NoError(t, err) 108 109 assert.Equal(t, int32(42), b) 110 } 111 112 // *[]interface{} dest when null 113 { 114 err := ct.Set(nil) 115 assert.NoError(t, err) 116 117 var a pgtype.Text 118 var b pgtype.Int4 119 dst := []interface{}{&a, &b} 120 121 err = ct.AssignTo(&dst) 122 assert.NoError(t, err) 123 124 assert.Nil(t, dst) 125 } 126 127 // *[]interface{} dest when not null 128 { 129 err := ct.Set([]interface{}{"foo", int32(42)}) 130 assert.NoError(t, err) 131 132 var a pgtype.Text 133 var b pgtype.Int4 134 dst := []interface{}{&a, &b} 135 136 err = ct.AssignTo(&dst) 137 assert.NoError(t, err) 138 139 assert.NotNil(t, dst) 140 assert.Equal(t, pgtype.Text{String: "foo", Status: pgtype.Present}, a) 141 assert.Equal(t, pgtype.Int4{Int: 42, Status: pgtype.Present}, b) 142 } 143 144 // Struct fields positionally via reflection 145 { 146 err := ct.Set([]interface{}{"foo", int32(42)}) 147 assert.NoError(t, err) 148 149 s := struct { 150 A string 151 B int32 152 }{} 153 154 err = ct.AssignTo(&s) 155 if assert.NoError(t, err) { 156 assert.Equal(t, "foo", s.A) 157 assert.Equal(t, int32(42), s.B) 158 } 159 } 160} 161 162func TestCompositeTypeTranscode(t *testing.T) { 163 conn := testutil.MustConnectPgx(t) 164 defer testutil.MustCloseContext(t, conn) 165 166 _, err := conn.Exec(context.Background(), `drop type if exists ct_test; 167 168create type ct_test as ( 169 a text, 170 b int4 171);`) 172 require.NoError(t, err) 173 defer conn.Exec(context.Background(), "drop type ct_test") 174 175 var oid uint32 176 err = conn.QueryRow(context.Background(), `select 'ct_test'::regtype::oid`).Scan(&oid) 177 require.NoError(t, err) 178 179 defer conn.Exec(context.Background(), "drop type ct_test") 180 181 ct, err := pgtype.NewCompositeType("ct_test", []pgtype.CompositeTypeField{ 182 {"a", pgtype.TextOID}, 183 {"b", pgtype.Int4OID}, 184 }, conn.ConnInfo()) 185 require.NoError(t, err) 186 conn.ConnInfo().RegisterDataType(pgtype.DataType{Value: ct, Name: ct.TypeName(), OID: oid}) 187 188 // Use simple protocol to force text or binary encoding 189 simpleProtocols := []bool{true, false} 190 191 var a string 192 var b int32 193 194 for _, simpleProtocol := range simpleProtocols { 195 err := conn.QueryRow(context.Background(), "select $1::ct_test", pgx.QuerySimpleProtocol(simpleProtocol), 196 pgtype.CompositeFields{"hi", int32(42)}, 197 ).Scan( 198 []interface{}{&a, &b}, 199 ) 200 if assert.NoErrorf(t, err, "Simple Protocol: %v", simpleProtocol) { 201 assert.EqualValuesf(t, "hi", a, "Simple Protocol: %v", simpleProtocol) 202 assert.EqualValuesf(t, 42, b, "Simple Protocol: %v", simpleProtocol) 203 } 204 } 205} 206 207// https://github.com/jackc/pgx/issues/874 208func TestCompositeTypeTextDecodeNested(t *testing.T) { 209 newCompositeType := func(name string, fieldNames []string, vals ...pgtype.ValueTranscoder) *pgtype.CompositeType { 210 fields := make([]pgtype.CompositeTypeField, len(fieldNames)) 211 for i, name := range fieldNames { 212 fields[i] = pgtype.CompositeTypeField{Name: name} 213 } 214 215 rowType, err := pgtype.NewCompositeTypeValues(name, fields, vals) 216 require.NoError(t, err) 217 return rowType 218 } 219 220 dimensionsType := func() pgtype.ValueTranscoder { 221 return newCompositeType( 222 "dimensions", 223 []string{"width", "height"}, 224 &pgtype.Int4{}, 225 &pgtype.Int4{}, 226 ) 227 } 228 productImageType := func() pgtype.ValueTranscoder { 229 return newCompositeType( 230 "product_image_type", 231 []string{"source", "dimensions"}, 232 &pgtype.Text{}, 233 dimensionsType(), 234 ) 235 } 236 productImageSetType := newCompositeType( 237 "product_image_set_type", 238 []string{"name", "orig_image", "images"}, 239 &pgtype.Text{}, 240 productImageType(), 241 pgtype.NewArrayType("product_image", 0, func() pgtype.ValueTranscoder { 242 return productImageType() 243 }), 244 ) 245 246 err := productImageSetType.DecodeText(nil, []byte(`(name,"(img1,""(11,11)"")","{""(img2,\\""(22,22)\\"")"",""(img3,\\""(33,33)\\"")""}")`)) 247 require.NoError(t, err) 248} 249 250func Example_composite() { 251 conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) 252 if err != nil { 253 fmt.Println(err) 254 return 255 } 256 257 defer conn.Close(context.Background()) 258 _, err = conn.Exec(context.Background(), `drop type if exists mytype;`) 259 if err != nil { 260 fmt.Println(err) 261 return 262 } 263 264 _, err = conn.Exec(context.Background(), `create type mytype as ( 265 a int4, 266 b text 267);`) 268 if err != nil { 269 fmt.Println(err) 270 return 271 } 272 defer conn.Exec(context.Background(), "drop type mytype") 273 274 var oid uint32 275 err = conn.QueryRow(context.Background(), `select 'mytype'::regtype::oid`).Scan(&oid) 276 if err != nil { 277 fmt.Println(err) 278 return 279 } 280 281 ct, err := pgtype.NewCompositeType("mytype", []pgtype.CompositeTypeField{ 282 {"a", pgtype.Int4OID}, 283 {"b", pgtype.TextOID}, 284 }, conn.ConnInfo()) 285 if err != nil { 286 fmt.Println(err) 287 return 288 } 289 conn.ConnInfo().RegisterDataType(pgtype.DataType{Value: ct, Name: ct.TypeName(), OID: oid}) 290 291 var a int 292 var b *string 293 294 err = conn.QueryRow(context.Background(), "select $1::mytype", []interface{}{2, "bar"}).Scan([]interface{}{&a, &b}) 295 if err != nil { 296 fmt.Println(err) 297 return 298 } 299 300 fmt.Printf("First: a=%d b=%s\n", a, *b) 301 302 err = conn.QueryRow(context.Background(), "select (1, NULL)::mytype").Scan([]interface{}{&a, &b}) 303 if err != nil { 304 fmt.Println(err) 305 return 306 } 307 308 fmt.Printf("Second: a=%d b=%v\n", a, b) 309 310 scanTarget := []interface{}{&a, &b} 311 err = conn.QueryRow(context.Background(), "select NULL::mytype").Scan(&scanTarget) 312 E(err) 313 314 fmt.Printf("Third: isNull=%v\n", scanTarget == nil) 315 316 // Output: 317 // First: a=2 b=bar 318 // Second: a=1 b=<nil> 319 // Third: isNull=true 320} 321