1package pgx_test 2 3import ( 4 "bytes" 5 "context" 6 "net" 7 "os" 8 "reflect" 9 "testing" 10 "time" 11 12 "github.com/jackc/pgx/v4" 13 "github.com/stretchr/testify/assert" 14 "github.com/stretchr/testify/require" 15) 16 17func TestDateTranscode(t *testing.T) { 18 t.Parallel() 19 20 testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { 21 dates := []time.Time{ 22 time.Date(1, 1, 1, 0, 0, 0, 0, time.UTC), 23 time.Date(1000, 1, 1, 0, 0, 0, 0, time.UTC), 24 time.Date(1600, 1, 1, 0, 0, 0, 0, time.UTC), 25 time.Date(1700, 1, 1, 0, 0, 0, 0, time.UTC), 26 time.Date(1800, 1, 1, 0, 0, 0, 0, time.UTC), 27 time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), 28 time.Date(1990, 1, 1, 0, 0, 0, 0, time.UTC), 29 time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), 30 time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), 31 time.Date(2001, 1, 2, 0, 0, 0, 0, time.UTC), 32 time.Date(2004, 2, 29, 0, 0, 0, 0, time.UTC), 33 time.Date(2013, 7, 4, 0, 0, 0, 0, time.UTC), 34 time.Date(2013, 12, 25, 0, 0, 0, 0, time.UTC), 35 time.Date(2029, 1, 1, 0, 0, 0, 0, time.UTC), 36 time.Date(2081, 1, 1, 0, 0, 0, 0, time.UTC), 37 time.Date(2096, 2, 29, 0, 0, 0, 0, time.UTC), 38 time.Date(2550, 1, 1, 0, 0, 0, 0, time.UTC), 39 time.Date(9999, 12, 31, 0, 0, 0, 0, time.UTC), 40 } 41 42 for _, actualDate := range dates { 43 var d time.Time 44 45 err := conn.QueryRow(context.Background(), "select $1::date", actualDate).Scan(&d) 46 if err != nil { 47 t.Fatalf("Unexpected failure on QueryRow Scan: %v", err) 48 } 49 if !actualDate.Equal(d) { 50 t.Errorf("Did not transcode date successfully: %v is not %v", d, actualDate) 51 } 52 } 53 }) 54} 55 56func TestTimestampTzTranscode(t *testing.T) { 57 t.Parallel() 58 59 testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { 60 inputTime := time.Date(2013, 1, 2, 3, 4, 5, 6000, time.Local) 61 62 var outputTime time.Time 63 64 err := conn.QueryRow(context.Background(), "select $1::timestamptz", inputTime).Scan(&outputTime) 65 if err != nil { 66 t.Fatalf("QueryRow Scan failed: %v", err) 67 } 68 if !inputTime.Equal(outputTime) { 69 t.Errorf("Did not transcode time successfully: %v is not %v", outputTime, inputTime) 70 } 71 }) 72} 73 74// TODO - move these tests to pgtype 75 76func TestJSONAndJSONBTranscode(t *testing.T) { 77 t.Parallel() 78 79 testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { 80 for _, typename := range []string{"json", "jsonb"} { 81 if _, ok := conn.ConnInfo().DataTypeForName(typename); !ok { 82 continue // No JSON/JSONB type -- must be running against old PostgreSQL 83 } 84 85 testJSONString(t, conn, typename) 86 testJSONStringPointer(t, conn, typename) 87 } 88 }) 89} 90 91func TestJSONAndJSONBTranscodeExtendedOnly(t *testing.T) { 92 t.Parallel() 93 94 conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) 95 defer closeConn(t, conn) 96 97 for _, typename := range []string{"json", "jsonb"} { 98 if _, ok := conn.ConnInfo().DataTypeForName(typename); !ok { 99 continue // No JSON/JSONB type -- must be running against old PostgreSQL 100 } 101 testJSONSingleLevelStringMap(t, conn, typename) 102 testJSONNestedMap(t, conn, typename) 103 testJSONStringArray(t, conn, typename) 104 testJSONInt64Array(t, conn, typename) 105 testJSONInt16ArrayFailureDueToOverflow(t, conn, typename) 106 testJSONStruct(t, conn, typename) 107 } 108 109} 110 111func testJSONString(t *testing.T, conn *pgx.Conn, typename string) { 112 input := `{"key": "value"}` 113 expectedOutput := map[string]string{"key": "value"} 114 var output map[string]string 115 err := conn.QueryRow(context.Background(), "select $1::"+typename, input).Scan(&output) 116 if err != nil { 117 t.Errorf("%s: QueryRow Scan failed: %v", typename, err) 118 return 119 } 120 121 if !reflect.DeepEqual(expectedOutput, output) { 122 t.Errorf("%s: Did not transcode map[string]string successfully: %v is not %v", typename, expectedOutput, output) 123 return 124 } 125} 126 127func testJSONStringPointer(t *testing.T, conn *pgx.Conn, typename string) { 128 input := `{"key": "value"}` 129 expectedOutput := map[string]string{"key": "value"} 130 var output map[string]string 131 err := conn.QueryRow(context.Background(), "select $1::"+typename, &input).Scan(&output) 132 if err != nil { 133 t.Errorf("%s: QueryRow Scan failed: %v", typename, err) 134 return 135 } 136 137 if !reflect.DeepEqual(expectedOutput, output) { 138 t.Errorf("%s: Did not transcode map[string]string successfully: %v is not %v", typename, expectedOutput, output) 139 return 140 } 141} 142 143func testJSONSingleLevelStringMap(t *testing.T, conn *pgx.Conn, typename string) { 144 input := map[string]string{"key": "value"} 145 var output map[string]string 146 err := conn.QueryRow(context.Background(), "select $1::"+typename, input).Scan(&output) 147 if err != nil { 148 t.Errorf("%s: QueryRow Scan failed: %v", typename, err) 149 return 150 } 151 152 if !reflect.DeepEqual(input, output) { 153 t.Errorf("%s: Did not transcode map[string]string successfully: %v is not %v", typename, input, output) 154 return 155 } 156} 157 158func testJSONNestedMap(t *testing.T, conn *pgx.Conn, typename string) { 159 input := map[string]interface{}{ 160 "name": "Uncanny", 161 "stats": map[string]interface{}{"hp": float64(107), "maxhp": float64(150)}, 162 "inventory": []interface{}{"phone", "key"}, 163 } 164 var output map[string]interface{} 165 err := conn.QueryRow(context.Background(), "select $1::"+typename, input).Scan(&output) 166 if err != nil { 167 t.Errorf("%s: QueryRow Scan failed: %v", typename, err) 168 return 169 } 170 171 if !reflect.DeepEqual(input, output) { 172 t.Errorf("%s: Did not transcode map[string]interface{} successfully: %v is not %v", typename, input, output) 173 return 174 } 175} 176 177func testJSONStringArray(t *testing.T, conn *pgx.Conn, typename string) { 178 input := []string{"foo", "bar", "baz"} 179 var output []string 180 err := conn.QueryRow(context.Background(), "select $1::"+typename, input).Scan(&output) 181 if err != nil { 182 t.Errorf("%s: QueryRow Scan failed: %v", typename, err) 183 } 184 185 if !reflect.DeepEqual(input, output) { 186 t.Errorf("%s: Did not transcode []string successfully: %v is not %v", typename, input, output) 187 } 188} 189 190func testJSONInt64Array(t *testing.T, conn *pgx.Conn, typename string) { 191 input := []int64{1, 2, 234432} 192 var output []int64 193 err := conn.QueryRow(context.Background(), "select $1::"+typename, input).Scan(&output) 194 if err != nil { 195 t.Errorf("%s: QueryRow Scan failed: %v", typename, err) 196 } 197 198 if !reflect.DeepEqual(input, output) { 199 t.Errorf("%s: Did not transcode []int64 successfully: %v is not %v", typename, input, output) 200 } 201} 202 203func testJSONInt16ArrayFailureDueToOverflow(t *testing.T, conn *pgx.Conn, typename string) { 204 input := []int{1, 2, 234432} 205 var output []int16 206 err := conn.QueryRow(context.Background(), "select $1::"+typename, input).Scan(&output) 207 if err == nil || err.Error() != "can't scan into dest[0]: json: cannot unmarshal number 234432 into Go value of type int16" { 208 t.Errorf("%s: Expected *json.UnmarkalTypeError, but got %v", typename, err) 209 } 210} 211 212func testJSONStruct(t *testing.T, conn *pgx.Conn, typename string) { 213 type person struct { 214 Name string `json:"name"` 215 Age int `json:"age"` 216 } 217 218 input := person{ 219 Name: "John", 220 Age: 42, 221 } 222 223 var output person 224 225 err := conn.QueryRow(context.Background(), "select $1::"+typename, input).Scan(&output) 226 if err != nil { 227 t.Errorf("%s: QueryRow Scan failed: %v", typename, err) 228 } 229 230 if !reflect.DeepEqual(input, output) { 231 t.Errorf("%s: Did not transcode struct successfully: %v is not %v", typename, input, output) 232 } 233} 234 235func mustParseCIDR(t *testing.T, s string) *net.IPNet { 236 _, ipnet, err := net.ParseCIDR(s) 237 if err != nil { 238 t.Fatal(err) 239 } 240 241 return ipnet 242} 243 244func TestStringToNotTextTypeTranscode(t *testing.T) { 245 t.Parallel() 246 247 testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { 248 input := "01086ee0-4963-4e35-9116-30c173a8d0bd" 249 250 var output string 251 err := conn.QueryRow(context.Background(), "select $1::uuid", input).Scan(&output) 252 if err != nil { 253 t.Fatal(err) 254 } 255 if input != output { 256 t.Errorf("uuid: Did not transcode string successfully: %s is not %s", input, output) 257 } 258 259 err = conn.QueryRow(context.Background(), "select $1::uuid", &input).Scan(&output) 260 if err != nil { 261 t.Fatal(err) 262 } 263 if input != output { 264 t.Errorf("uuid: Did not transcode pointer to string successfully: %s is not %s", input, output) 265 } 266 }) 267} 268 269func TestInetCIDRTranscodeIPNet(t *testing.T) { 270 t.Parallel() 271 272 testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { 273 tests := []struct { 274 sql string 275 value *net.IPNet 276 }{ 277 {"select $1::inet", mustParseCIDR(t, "0.0.0.0/32")}, 278 {"select $1::inet", mustParseCIDR(t, "127.0.0.1/32")}, 279 {"select $1::inet", mustParseCIDR(t, "12.34.56.0/32")}, 280 {"select $1::inet", mustParseCIDR(t, "192.168.1.0/24")}, 281 {"select $1::inet", mustParseCIDR(t, "255.0.0.0/8")}, 282 {"select $1::inet", mustParseCIDR(t, "255.255.255.255/32")}, 283 {"select $1::inet", mustParseCIDR(t, "::/128")}, 284 {"select $1::inet", mustParseCIDR(t, "::/0")}, 285 {"select $1::inet", mustParseCIDR(t, "::1/128")}, 286 {"select $1::inet", mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128")}, 287 {"select $1::cidr", mustParseCIDR(t, "0.0.0.0/32")}, 288 {"select $1::cidr", mustParseCIDR(t, "127.0.0.1/32")}, 289 {"select $1::cidr", mustParseCIDR(t, "12.34.56.0/32")}, 290 {"select $1::cidr", mustParseCIDR(t, "192.168.1.0/24")}, 291 {"select $1::cidr", mustParseCIDR(t, "255.0.0.0/8")}, 292 {"select $1::cidr", mustParseCIDR(t, "255.255.255.255/32")}, 293 {"select $1::cidr", mustParseCIDR(t, "::/128")}, 294 {"select $1::cidr", mustParseCIDR(t, "::/0")}, 295 {"select $1::cidr", mustParseCIDR(t, "::1/128")}, 296 {"select $1::cidr", mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128")}, 297 } 298 299 for i, tt := range tests { 300 var actual net.IPNet 301 302 err := conn.QueryRow(context.Background(), tt.sql, tt.value).Scan(&actual) 303 if err != nil { 304 t.Errorf("%d. Unexpected failure: %v (sql -> %v, value -> %v)", i, err, tt.sql, tt.value) 305 continue 306 } 307 308 if actual.String() != tt.value.String() { 309 t.Errorf("%d. Expected %v, got %v (sql -> %v)", i, tt.value, actual, tt.sql) 310 } 311 } 312 }) 313} 314 315func TestInetCIDRTranscodeIP(t *testing.T) { 316 t.Parallel() 317 318 testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { 319 tests := []struct { 320 sql string 321 value net.IP 322 }{ 323 {"select $1::inet", net.ParseIP("0.0.0.0")}, 324 {"select $1::inet", net.ParseIP("127.0.0.1")}, 325 {"select $1::inet", net.ParseIP("12.34.56.0")}, 326 {"select $1::inet", net.ParseIP("255.255.255.255")}, 327 {"select $1::inet", net.ParseIP("::1")}, 328 {"select $1::inet", net.ParseIP("2607:f8b0:4009:80b::200e")}, 329 {"select $1::cidr", net.ParseIP("0.0.0.0")}, 330 {"select $1::cidr", net.ParseIP("127.0.0.1")}, 331 {"select $1::cidr", net.ParseIP("12.34.56.0")}, 332 {"select $1::cidr", net.ParseIP("255.255.255.255")}, 333 {"select $1::cidr", net.ParseIP("::1")}, 334 {"select $1::cidr", net.ParseIP("2607:f8b0:4009:80b::200e")}, 335 } 336 337 for i, tt := range tests { 338 var actual net.IP 339 340 err := conn.QueryRow(context.Background(), tt.sql, tt.value).Scan(&actual) 341 if err != nil { 342 t.Errorf("%d. Unexpected failure: %v (sql -> %v, value -> %v)", i, err, tt.sql, tt.value) 343 continue 344 } 345 346 if !actual.Equal(tt.value) { 347 t.Errorf("%d. Expected %v, got %v (sql -> %v)", i, tt.value, actual, tt.sql) 348 } 349 350 ensureConnValid(t, conn) 351 } 352 353 failTests := []struct { 354 sql string 355 value *net.IPNet 356 }{ 357 {"select $1::inet", mustParseCIDR(t, "192.168.1.0/24")}, 358 {"select $1::cidr", mustParseCIDR(t, "192.168.1.0/24")}, 359 } 360 for i, tt := range failTests { 361 var actual net.IP 362 363 err := conn.QueryRow(context.Background(), tt.sql, tt.value).Scan(&actual) 364 if err == nil { 365 t.Errorf("%d. Expected failure but got none", i) 366 continue 367 } 368 369 ensureConnValid(t, conn) 370 } 371 }) 372} 373 374func TestInetCIDRArrayTranscodeIPNet(t *testing.T) { 375 t.Parallel() 376 377 testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { 378 tests := []struct { 379 sql string 380 value []*net.IPNet 381 }{ 382 { 383 "select $1::inet[]", 384 []*net.IPNet{ 385 mustParseCIDR(t, "0.0.0.0/32"), 386 mustParseCIDR(t, "127.0.0.1/32"), 387 mustParseCIDR(t, "12.34.56.0/32"), 388 mustParseCIDR(t, "192.168.1.0/24"), 389 mustParseCIDR(t, "255.0.0.0/8"), 390 mustParseCIDR(t, "255.255.255.255/32"), 391 mustParseCIDR(t, "::/128"), 392 mustParseCIDR(t, "::/0"), 393 mustParseCIDR(t, "::1/128"), 394 mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), 395 }, 396 }, 397 { 398 "select $1::cidr[]", 399 []*net.IPNet{ 400 mustParseCIDR(t, "0.0.0.0/32"), 401 mustParseCIDR(t, "127.0.0.1/32"), 402 mustParseCIDR(t, "12.34.56.0/32"), 403 mustParseCIDR(t, "192.168.1.0/24"), 404 mustParseCIDR(t, "255.0.0.0/8"), 405 mustParseCIDR(t, "255.255.255.255/32"), 406 mustParseCIDR(t, "::/128"), 407 mustParseCIDR(t, "::/0"), 408 mustParseCIDR(t, "::1/128"), 409 mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), 410 }, 411 }, 412 } 413 414 for i, tt := range tests { 415 var actual []*net.IPNet 416 417 err := conn.QueryRow(context.Background(), tt.sql, tt.value).Scan(&actual) 418 if err != nil { 419 t.Errorf("%d. Unexpected failure: %v (sql -> %v, value -> %v)", i, err, tt.sql, tt.value) 420 continue 421 } 422 423 if !reflect.DeepEqual(actual, tt.value) { 424 t.Errorf("%d. Expected %v, got %v (sql -> %v)", i, tt.value, actual, tt.sql) 425 } 426 427 ensureConnValid(t, conn) 428 } 429 }) 430} 431 432func TestInetCIDRArrayTranscodeIP(t *testing.T) { 433 t.Parallel() 434 435 testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { 436 tests := []struct { 437 sql string 438 value []net.IP 439 }{ 440 { 441 "select $1::inet[]", 442 []net.IP{ 443 net.ParseIP("0.0.0.0"), 444 net.ParseIP("127.0.0.1"), 445 net.ParseIP("12.34.56.0"), 446 net.ParseIP("255.255.255.255"), 447 net.ParseIP("2607:f8b0:4009:80b::200e"), 448 }, 449 }, 450 { 451 "select $1::cidr[]", 452 []net.IP{ 453 net.ParseIP("0.0.0.0"), 454 net.ParseIP("127.0.0.1"), 455 net.ParseIP("12.34.56.0"), 456 net.ParseIP("255.255.255.255"), 457 net.ParseIP("2607:f8b0:4009:80b::200e"), 458 }, 459 }, 460 } 461 462 for i, tt := range tests { 463 var actual []net.IP 464 465 err := conn.QueryRow(context.Background(), tt.sql, tt.value).Scan(&actual) 466 if err != nil { 467 t.Errorf("%d. Unexpected failure: %v (sql -> %v, value -> %v)", i, err, tt.sql, tt.value) 468 continue 469 } 470 471 assert.Equal(t, len(tt.value), len(actual), "%d", i) 472 for j := range actual { 473 assert.True(t, actual[j].Equal(tt.value[j]), "%d", i) 474 } 475 476 ensureConnValid(t, conn) 477 } 478 479 failTests := []struct { 480 sql string 481 value []*net.IPNet 482 }{ 483 { 484 "select $1::inet[]", 485 []*net.IPNet{ 486 mustParseCIDR(t, "12.34.56.0/32"), 487 mustParseCIDR(t, "192.168.1.0/24"), 488 }, 489 }, 490 { 491 "select $1::cidr[]", 492 []*net.IPNet{ 493 mustParseCIDR(t, "12.34.56.0/32"), 494 mustParseCIDR(t, "192.168.1.0/24"), 495 }, 496 }, 497 } 498 499 for i, tt := range failTests { 500 var actual []net.IP 501 502 err := conn.QueryRow(context.Background(), tt.sql, tt.value).Scan(&actual) 503 if err == nil { 504 t.Errorf("%d. Expected failure but got none", i) 505 continue 506 } 507 508 ensureConnValid(t, conn) 509 } 510 }) 511} 512 513func TestInetCIDRTranscodeWithJustIP(t *testing.T) { 514 t.Parallel() 515 516 testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { 517 tests := []struct { 518 sql string 519 value string 520 }{ 521 {"select $1::inet", "0.0.0.0/32"}, 522 {"select $1::inet", "127.0.0.1/32"}, 523 {"select $1::inet", "12.34.56.0/32"}, 524 {"select $1::inet", "255.255.255.255/32"}, 525 {"select $1::inet", "::/128"}, 526 {"select $1::inet", "2607:f8b0:4009:80b::200e/128"}, 527 {"select $1::cidr", "0.0.0.0/32"}, 528 {"select $1::cidr", "127.0.0.1/32"}, 529 {"select $1::cidr", "12.34.56.0/32"}, 530 {"select $1::cidr", "255.255.255.255/32"}, 531 {"select $1::cidr", "::/128"}, 532 {"select $1::cidr", "2607:f8b0:4009:80b::200e/128"}, 533 } 534 535 for i, tt := range tests { 536 expected := mustParseCIDR(t, tt.value) 537 var actual net.IPNet 538 539 err := conn.QueryRow(context.Background(), tt.sql, expected.IP).Scan(&actual) 540 if err != nil { 541 t.Errorf("%d. Unexpected failure: %v (sql -> %v, value -> %v)", i, err, tt.sql, tt.value) 542 continue 543 } 544 545 if actual.String() != expected.String() { 546 t.Errorf("%d. Expected %v, got %v (sql -> %v)", i, tt.value, actual, tt.sql) 547 } 548 549 ensureConnValid(t, conn) 550 } 551 }) 552} 553 554func TestArrayDecoding(t *testing.T) { 555 t.Parallel() 556 557 testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { 558 tests := []struct { 559 sql string 560 query interface{} 561 scan interface{} 562 assert func(*testing.T, interface{}, interface{}) 563 }{ 564 { 565 "select $1::bool[]", []bool{true, false, true}, &[]bool{}, 566 func(t *testing.T, query, scan interface{}) { 567 if !reflect.DeepEqual(query, *(scan.(*[]bool))) { 568 t.Errorf("failed to encode bool[]") 569 } 570 }, 571 }, 572 { 573 "select $1::smallint[]", []int16{2, 4, 484, 32767}, &[]int16{}, 574 func(t *testing.T, query, scan interface{}) { 575 if !reflect.DeepEqual(query, *(scan.(*[]int16))) { 576 t.Errorf("failed to encode smallint[]") 577 } 578 }, 579 }, 580 { 581 "select $1::smallint[]", []uint16{2, 4, 484, 32767}, &[]uint16{}, 582 func(t *testing.T, query, scan interface{}) { 583 if !reflect.DeepEqual(query, *(scan.(*[]uint16))) { 584 t.Errorf("failed to encode smallint[]") 585 } 586 }, 587 }, 588 { 589 "select $1::int[]", []int32{2, 4, 484}, &[]int32{}, 590 func(t *testing.T, query, scan interface{}) { 591 if !reflect.DeepEqual(query, *(scan.(*[]int32))) { 592 t.Errorf("failed to encode int[]") 593 } 594 }, 595 }, 596 { 597 "select $1::int[]", []uint32{2, 4, 484, 2147483647}, &[]uint32{}, 598 func(t *testing.T, query, scan interface{}) { 599 if !reflect.DeepEqual(query, *(scan.(*[]uint32))) { 600 t.Errorf("failed to encode int[]") 601 } 602 }, 603 }, 604 { 605 "select $1::bigint[]", []int64{2, 4, 484, 9223372036854775807}, &[]int64{}, 606 func(t *testing.T, query, scan interface{}) { 607 if !reflect.DeepEqual(query, *(scan.(*[]int64))) { 608 t.Errorf("failed to encode bigint[]") 609 } 610 }, 611 }, 612 { 613 "select $1::bigint[]", []uint64{2, 4, 484, 9223372036854775807}, &[]uint64{}, 614 func(t *testing.T, query, scan interface{}) { 615 if !reflect.DeepEqual(query, *(scan.(*[]uint64))) { 616 t.Errorf("failed to encode bigint[]") 617 } 618 }, 619 }, 620 { 621 "select $1::text[]", []string{"it's", "over", "9000!"}, &[]string{}, 622 func(t *testing.T, query, scan interface{}) { 623 if !reflect.DeepEqual(query, *(scan.(*[]string))) { 624 t.Errorf("failed to encode text[]") 625 } 626 }, 627 }, 628 { 629 "select $1::timestamptz[]", []time.Time{time.Unix(323232, 0), time.Unix(3239949334, 00)}, &[]time.Time{}, 630 func(t *testing.T, query, scan interface{}) { 631 queryTimeSlice := query.([]time.Time) 632 scanTimeSlice := *(scan.(*[]time.Time)) 633 require.Equal(t, len(queryTimeSlice), len(scanTimeSlice)) 634 for i := range queryTimeSlice { 635 assert.Truef(t, queryTimeSlice[i].Equal(scanTimeSlice[i]), "%d", i) 636 } 637 }, 638 }, 639 { 640 "select $1::bytea[]", [][]byte{{0, 1, 2, 3}, {4, 5, 6, 7}}, &[][]byte{}, 641 func(t *testing.T, query, scan interface{}) { 642 queryBytesSliceSlice := query.([][]byte) 643 scanBytesSliceSlice := *(scan.(*[][]byte)) 644 if len(queryBytesSliceSlice) != len(scanBytesSliceSlice) { 645 t.Errorf("failed to encode byte[][] to bytea[]: expected %d to equal %d", len(queryBytesSliceSlice), len(scanBytesSliceSlice)) 646 } 647 for i := range queryBytesSliceSlice { 648 qb := queryBytesSliceSlice[i] 649 sb := scanBytesSliceSlice[i] 650 if !bytes.Equal(qb, sb) { 651 t.Errorf("failed to encode byte[][] to bytea[]: expected %v to equal %v", qb, sb) 652 } 653 } 654 }, 655 }, 656 } 657 658 for i, tt := range tests { 659 err := conn.QueryRow(context.Background(), tt.sql, tt.query).Scan(tt.scan) 660 if err != nil { 661 t.Errorf(`%d. error reading array: %v`, i, err) 662 continue 663 } 664 tt.assert(t, tt.query, tt.scan) 665 ensureConnValid(t, conn) 666 } 667 }) 668} 669 670func TestEmptyArrayDecoding(t *testing.T) { 671 t.Parallel() 672 673 testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { 674 var val []string 675 676 err := conn.QueryRow(context.Background(), "select array[]::text[]").Scan(&val) 677 if err != nil { 678 t.Errorf(`error reading array: %v`, err) 679 } 680 if len(val) != 0 { 681 t.Errorf("Expected 0 values, got %d", len(val)) 682 } 683 684 var n, m int32 685 686 err = conn.QueryRow(context.Background(), "select 1::integer, array[]::text[], 42::integer").Scan(&n, &val, &m) 687 if err != nil { 688 t.Errorf(`error reading array: %v`, err) 689 } 690 if len(val) != 0 { 691 t.Errorf("Expected 0 values, got %d", len(val)) 692 } 693 if n != 1 { 694 t.Errorf("Expected n to be 1, but it was %d", n) 695 } 696 if m != 42 { 697 t.Errorf("Expected n to be 42, but it was %d", n) 698 } 699 700 rows, err := conn.Query(context.Background(), "select 1::integer, array['test']::text[] union select 2::integer, array[]::text[] union select 3::integer, array['test']::text[]") 701 if err != nil { 702 t.Errorf(`error retrieving rows with array: %v`, err) 703 } 704 defer rows.Close() 705 706 for rows.Next() { 707 err = rows.Scan(&n, &val) 708 if err != nil { 709 t.Errorf(`error reading array: %v`, err) 710 } 711 } 712 }) 713} 714 715func TestPointerPointer(t *testing.T) { 716 t.Parallel() 717 718 testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { 719 type allTypes struct { 720 s *string 721 i16 *int16 722 i32 *int32 723 i64 *int64 724 f32 *float32 725 f64 *float64 726 b *bool 727 t *time.Time 728 } 729 730 var actual, zero, expected allTypes 731 732 { 733 s := "foo" 734 expected.s = &s 735 i16 := int16(1) 736 expected.i16 = &i16 737 i32 := int32(1) 738 expected.i32 = &i32 739 i64 := int64(1) 740 expected.i64 = &i64 741 f32 := float32(1.23) 742 expected.f32 = &f32 743 f64 := float64(1.23) 744 expected.f64 = &f64 745 b := true 746 expected.b = &b 747 t := time.Unix(123, 5000) 748 expected.t = &t 749 } 750 751 tests := []struct { 752 sql string 753 queryArgs []interface{} 754 scanArgs []interface{} 755 expected allTypes 756 }{ 757 {"select $1::text", []interface{}{expected.s}, []interface{}{&actual.s}, allTypes{s: expected.s}}, 758 {"select $1::text", []interface{}{zero.s}, []interface{}{&actual.s}, allTypes{}}, 759 {"select $1::int2", []interface{}{expected.i16}, []interface{}{&actual.i16}, allTypes{i16: expected.i16}}, 760 {"select $1::int2", []interface{}{zero.i16}, []interface{}{&actual.i16}, allTypes{}}, 761 {"select $1::int4", []interface{}{expected.i32}, []interface{}{&actual.i32}, allTypes{i32: expected.i32}}, 762 {"select $1::int4", []interface{}{zero.i32}, []interface{}{&actual.i32}, allTypes{}}, 763 {"select $1::int8", []interface{}{expected.i64}, []interface{}{&actual.i64}, allTypes{i64: expected.i64}}, 764 {"select $1::int8", []interface{}{zero.i64}, []interface{}{&actual.i64}, allTypes{}}, 765 {"select $1::float4", []interface{}{expected.f32}, []interface{}{&actual.f32}, allTypes{f32: expected.f32}}, 766 {"select $1::float4", []interface{}{zero.f32}, []interface{}{&actual.f32}, allTypes{}}, 767 {"select $1::float8", []interface{}{expected.f64}, []interface{}{&actual.f64}, allTypes{f64: expected.f64}}, 768 {"select $1::float8", []interface{}{zero.f64}, []interface{}{&actual.f64}, allTypes{}}, 769 {"select $1::bool", []interface{}{expected.b}, []interface{}{&actual.b}, allTypes{b: expected.b}}, 770 {"select $1::bool", []interface{}{zero.b}, []interface{}{&actual.b}, allTypes{}}, 771 {"select $1::timestamptz", []interface{}{expected.t}, []interface{}{&actual.t}, allTypes{t: expected.t}}, 772 {"select $1::timestamptz", []interface{}{zero.t}, []interface{}{&actual.t}, allTypes{}}, 773 } 774 775 for i, tt := range tests { 776 actual = zero 777 778 err := conn.QueryRow(context.Background(), tt.sql, tt.queryArgs...).Scan(tt.scanArgs...) 779 if err != nil { 780 t.Errorf("%d. Unexpected failure: %v (sql -> %v, queryArgs -> %v)", i, err, tt.sql, tt.queryArgs) 781 } 782 783 if !reflect.DeepEqual(actual, tt.expected) { 784 t.Errorf("%d. Expected %v, got %v (sql -> %v, queryArgs -> %v)", i, tt.expected, actual, tt.sql, tt.queryArgs) 785 } 786 787 ensureConnValid(t, conn) 788 } 789 }) 790} 791 792func TestPointerPointerNonZero(t *testing.T) { 793 t.Parallel() 794 795 testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { 796 f := "foo" 797 dest := &f 798 799 err := conn.QueryRow(context.Background(), "select $1::text", nil).Scan(&dest) 800 if err != nil { 801 t.Errorf("Unexpected failure scanning: %v", err) 802 } 803 if dest != nil { 804 t.Errorf("Expected dest to be nil, got %#v", dest) 805 } 806 }) 807} 808 809func TestEncodeTypeRename(t *testing.T) { 810 t.Parallel() 811 812 testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { 813 type _int int 814 inInt := _int(1) 815 var outInt _int 816 817 type _int8 int8 818 inInt8 := _int8(2) 819 var outInt8 _int8 820 821 type _int16 int16 822 inInt16 := _int16(3) 823 var outInt16 _int16 824 825 type _int32 int32 826 inInt32 := _int32(4) 827 var outInt32 _int32 828 829 type _int64 int64 830 inInt64 := _int64(5) 831 var outInt64 _int64 832 833 type _uint uint 834 inUint := _uint(6) 835 var outUint _uint 836 837 type _uint8 uint8 838 inUint8 := _uint8(7) 839 var outUint8 _uint8 840 841 type _uint16 uint16 842 inUint16 := _uint16(8) 843 var outUint16 _uint16 844 845 type _uint32 uint32 846 inUint32 := _uint32(9) 847 var outUint32 _uint32 848 849 type _uint64 uint64 850 inUint64 := _uint64(10) 851 var outUint64 _uint64 852 853 type _string string 854 inString := _string("foo") 855 var outString _string 856 857 err := conn.QueryRow(context.Background(), "select $1::int, $2::int, $3::int2, $4::int4, $5::int8, $6::int, $7::int, $8::int, $9::int, $10::int, $11::text", 858 inInt, inInt8, inInt16, inInt32, inInt64, inUint, inUint8, inUint16, inUint32, inUint64, inString, 859 ).Scan(&outInt, &outInt8, &outInt16, &outInt32, &outInt64, &outUint, &outUint8, &outUint16, &outUint32, &outUint64, &outString) 860 if err != nil { 861 t.Fatalf("Failed with type rename: %v", err) 862 } 863 864 if inInt != outInt { 865 t.Errorf("int rename: expected %v, got %v", inInt, outInt) 866 } 867 868 if inInt8 != outInt8 { 869 t.Errorf("int8 rename: expected %v, got %v", inInt8, outInt8) 870 } 871 872 if inInt16 != outInt16 { 873 t.Errorf("int16 rename: expected %v, got %v", inInt16, outInt16) 874 } 875 876 if inInt32 != outInt32 { 877 t.Errorf("int32 rename: expected %v, got %v", inInt32, outInt32) 878 } 879 880 if inInt64 != outInt64 { 881 t.Errorf("int64 rename: expected %v, got %v", inInt64, outInt64) 882 } 883 884 if inUint != outUint { 885 t.Errorf("uint rename: expected %v, got %v", inUint, outUint) 886 } 887 888 if inUint8 != outUint8 { 889 t.Errorf("uint8 rename: expected %v, got %v", inUint8, outUint8) 890 } 891 892 if inUint16 != outUint16 { 893 t.Errorf("uint16 rename: expected %v, got %v", inUint16, outUint16) 894 } 895 896 if inUint32 != outUint32 { 897 t.Errorf("uint32 rename: expected %v, got %v", inUint32, outUint32) 898 } 899 900 if inUint64 != outUint64 { 901 t.Errorf("uint64 rename: expected %v, got %v", inUint64, outUint64) 902 } 903 904 if inString != outString { 905 t.Errorf("string rename: expected %v, got %v", inString, outString) 906 } 907 }) 908} 909 910func TestRowDecodeBinary(t *testing.T) { 911 t.Parallel() 912 913 conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) 914 defer closeConn(t, conn) 915 916 tests := []struct { 917 sql string 918 expected []interface{} 919 }{ 920 { 921 "select row(1, 'cat', '2015-01-01 08:12:42-00'::timestamptz)", 922 []interface{}{ 923 int32(1), 924 "cat", 925 time.Date(2015, 1, 1, 8, 12, 42, 0, time.UTC).Local(), 926 }, 927 }, 928 { 929 "select row(100.0::float, 1.09::float)", 930 []interface{}{ 931 float64(100), 932 float64(1.09), 933 }, 934 }, 935 } 936 937 for i, tt := range tests { 938 var actual []interface{} 939 940 err := conn.QueryRow(context.Background(), tt.sql).Scan(&actual) 941 if err != nil { 942 t.Errorf("%d. Unexpected failure: %v (sql -> %v)", i, err, tt.sql) 943 continue 944 } 945 946 if !reflect.DeepEqual(actual, tt.expected) { 947 t.Errorf("%d. Expected %v, got %v (sql -> %v)", i, tt.expected, actual, tt.sql) 948 } 949 950 ensureConnValid(t, conn) 951 } 952} 953 954// https://github.com/jackc/pgx/issues/810 955func TestRowsScanNilThenScanValue(t *testing.T) { 956 t.Parallel() 957 958 testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { 959 sql := `select null as a, null as b 960union 961select 1, 2 962order by a nulls first 963` 964 rows, err := conn.Query(context.Background(), sql) 965 require.NoError(t, err) 966 967 require.True(t, rows.Next()) 968 969 err = rows.Scan(nil, nil) 970 require.NoError(t, err) 971 972 require.True(t, rows.Next()) 973 974 var a int 975 var b int 976 err = rows.Scan(&a, &b) 977 require.NoError(t, err) 978 979 require.EqualValues(t, 1, a) 980 require.EqualValues(t, 2, b) 981 982 rows.Close() 983 require.NoError(t, rows.Err()) 984 }) 985} 986 987func TestScanIntoByteSlice(t *testing.T) { 988 t.Parallel() 989 990 conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) 991 defer closeConn(t, conn) 992 // Success cases 993 for _, tt := range []struct { 994 name string 995 sql string 996 resultFormatCode int16 997 output []byte 998 }{ 999 {"int - text", "select 42", pgx.TextFormatCode, []byte("42")}, 1000 {"text - text", "select 'hi'", pgx.TextFormatCode, []byte("hi")}, 1001 {"text - binary", "select 'hi'", pgx.BinaryFormatCode, []byte("hi")}, 1002 {"json - text", "select '{}'::json", pgx.TextFormatCode, []byte("{}")}, 1003 {"json - binary", "select '{}'::json", pgx.BinaryFormatCode, []byte("{}")}, 1004 {"jsonb - text", "select '{}'::jsonb", pgx.TextFormatCode, []byte("{}")}, 1005 {"jsonb - binary", "select '{}'::jsonb", pgx.BinaryFormatCode, []byte("{}")}, 1006 } { 1007 t.Run(tt.name, func(t *testing.T) { 1008 var buf []byte 1009 err := conn.QueryRow(context.Background(), tt.sql, pgx.QueryResultFormats{tt.resultFormatCode}).Scan(&buf) 1010 require.NoError(t, err) 1011 require.Equal(t, tt.output, buf) 1012 }) 1013 } 1014 1015 // Failure cases 1016 for _, tt := range []struct { 1017 name string 1018 sql string 1019 err string 1020 }{ 1021 {"int binary", "select 42", "can't scan into dest[0]: cannot assign 42 into *[]uint8"}, 1022 } { 1023 t.Run(tt.name, func(t *testing.T) { 1024 var buf []byte 1025 err := conn.QueryRow(context.Background(), tt.sql, pgx.QueryResultFormats{pgx.BinaryFormatCode}).Scan(&buf) 1026 require.EqualError(t, err, tt.err) 1027 }) 1028 } 1029} 1030