1package pgx_test 2 3import ( 4 "context" 5 "testing" 6 7 "github.com/jackc/pgx" 8 "github.com/jackc/pgx/pgtype" 9) 10 11func TestConnBeginBatch(t *testing.T) { 12 t.Parallel() 13 14 conn := mustConnect(t, *defaultConnConfig) 15 defer closeConn(t, conn) 16 17 sql := `create temporary table ledger( 18 id serial primary key, 19 description varchar not null, 20 amount int not null 21);` 22 mustExec(t, conn, sql) 23 24 batch := conn.BeginBatch() 25 batch.Queue("insert into ledger(description, amount) values($1, $2)", 26 []interface{}{"q1", 1}, 27 []pgtype.OID{pgtype.VarcharOID, pgtype.Int4OID}, 28 nil, 29 ) 30 batch.Queue("insert into ledger(description, amount) values($1, $2)", 31 []interface{}{"q2", 2}, 32 []pgtype.OID{pgtype.VarcharOID, pgtype.Int4OID}, 33 nil, 34 ) 35 batch.Queue("insert into ledger(description, amount) values($1, $2)", 36 []interface{}{"q3", 3}, 37 []pgtype.OID{pgtype.VarcharOID, pgtype.Int4OID}, 38 nil, 39 ) 40 batch.Queue("select id, description, amount from ledger order by id", 41 nil, 42 nil, 43 []int16{pgx.BinaryFormatCode, pgx.TextFormatCode, pgx.BinaryFormatCode}, 44 ) 45 batch.Queue("select sum(amount) from ledger", 46 nil, 47 nil, 48 []int16{pgx.BinaryFormatCode}, 49 ) 50 51 err := batch.Send(context.Background(), nil) 52 if err != nil { 53 t.Fatal(err) 54 } 55 56 ct, err := batch.ExecResults() 57 if err != nil { 58 t.Error(err) 59 } 60 if ct.RowsAffected() != 1 { 61 t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1) 62 } 63 64 ct, err = batch.ExecResults() 65 if err != nil { 66 t.Error(err) 67 } 68 if ct.RowsAffected() != 1 { 69 t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1) 70 } 71 72 rows, err := batch.QueryResults() 73 if err != nil { 74 t.Error(err) 75 } 76 77 var id int32 78 var description string 79 var amount int32 80 if !rows.Next() { 81 t.Fatal("expected a row to be available") 82 } 83 if err := rows.Scan(&id, &description, &amount); err != nil { 84 t.Fatal(err) 85 } 86 if id != 1 { 87 t.Errorf("id => %v, want %v", id, 1) 88 } 89 if description != "q1" { 90 t.Errorf("description => %v, want %v", description, "q1") 91 } 92 if amount != 1 { 93 t.Errorf("amount => %v, want %v", amount, 1) 94 } 95 96 if !rows.Next() { 97 t.Fatal("expected a row to be available") 98 } 99 if err := rows.Scan(&id, &description, &amount); err != nil { 100 t.Fatal(err) 101 } 102 if id != 2 { 103 t.Errorf("id => %v, want %v", id, 2) 104 } 105 if description != "q2" { 106 t.Errorf("description => %v, want %v", description, "q2") 107 } 108 if amount != 2 { 109 t.Errorf("amount => %v, want %v", amount, 2) 110 } 111 112 if !rows.Next() { 113 t.Fatal("expected a row to be available") 114 } 115 if err := rows.Scan(&id, &description, &amount); err != nil { 116 t.Fatal(err) 117 } 118 if id != 3 { 119 t.Errorf("id => %v, want %v", id, 3) 120 } 121 if description != "q3" { 122 t.Errorf("description => %v, want %v", description, "q3") 123 } 124 if amount != 3 { 125 t.Errorf("amount => %v, want %v", amount, 3) 126 } 127 128 if rows.Next() { 129 t.Fatal("did not expect a row to be available") 130 } 131 132 if rows.Err() != nil { 133 t.Fatal(rows.Err()) 134 } 135 136 err = batch.QueryRowResults().Scan(&amount) 137 if err != nil { 138 t.Error(err) 139 } 140 if amount != 6 { 141 t.Errorf("amount => %v, want %v", amount, 6) 142 } 143 144 err = batch.Close() 145 if err != nil { 146 t.Fatal(err) 147 } 148 149 ensureConnValid(t, conn) 150} 151 152func TestConnBeginBatchWithPreparedStatement(t *testing.T) { 153 t.Parallel() 154 155 conn := mustConnect(t, *defaultConnConfig) 156 defer closeConn(t, conn) 157 158 _, err := conn.Prepare("ps1", "select n from generate_series(0,$1::int) n") 159 if err != nil { 160 t.Fatal(err) 161 } 162 163 batch := conn.BeginBatch() 164 165 queryCount := 3 166 for i := 0; i < queryCount; i++ { 167 batch.Queue("ps1", 168 []interface{}{5}, 169 nil, 170 []int16{pgx.BinaryFormatCode}, 171 ) 172 } 173 174 err = batch.Send(context.Background(), nil) 175 if err != nil { 176 t.Fatal(err) 177 } 178 179 for i := 0; i < queryCount; i++ { 180 rows, err := batch.QueryResults() 181 if err != nil { 182 t.Fatal(err) 183 } 184 185 for k := 0; rows.Next(); k++ { 186 var n int 187 if err := rows.Scan(&n); err != nil { 188 t.Fatal(err) 189 } 190 if n != k { 191 t.Fatalf("n => %v, want %v", n, k) 192 } 193 } 194 195 if rows.Err() != nil { 196 t.Fatal(rows.Err()) 197 } 198 } 199 200 err = batch.Close() 201 if err != nil { 202 t.Fatal(err) 203 } 204 205 ensureConnValid(t, conn) 206} 207 208func TestConnBeginBatchContextCancelBeforeExecResults(t *testing.T) { 209 t.Parallel() 210 211 conn := mustConnect(t, *defaultConnConfig) 212 213 sql := `create temporary table ledger( 214 id serial primary key, 215 description varchar not null, 216 amount int not null 217);` 218 mustExec(t, conn, sql) 219 220 batch := conn.BeginBatch() 221 batch.Queue("insert into ledger(description, amount) values($1, $2)", 222 []interface{}{"q1", 1}, 223 []pgtype.OID{pgtype.VarcharOID, pgtype.Int4OID}, 224 nil, 225 ) 226 batch.Queue("select pg_sleep(2)", 227 nil, 228 nil, 229 nil, 230 ) 231 232 ctx, cancelFn := context.WithCancel(context.Background()) 233 234 err := batch.Send(ctx, nil) 235 if err != nil { 236 t.Fatal(err) 237 } 238 239 cancelFn() 240 241 _, err = batch.ExecResults() 242 if err != context.Canceled { 243 t.Errorf("err => %v, want %v", err, context.Canceled) 244 } 245 246 if conn.IsAlive() { 247 t.Error("conn should be dead, but was alive") 248 } 249} 250 251func TestConnBeginBatchContextCancelBeforeQueryResults(t *testing.T) { 252 t.Parallel() 253 254 conn := mustConnect(t, *defaultConnConfig) 255 256 batch := conn.BeginBatch() 257 batch.Queue("select pg_sleep(2)", 258 nil, 259 nil, 260 nil, 261 ) 262 batch.Queue("select pg_sleep(2)", 263 nil, 264 nil, 265 nil, 266 ) 267 268 ctx, cancelFn := context.WithCancel(context.Background()) 269 270 err := batch.Send(ctx, nil) 271 if err != nil { 272 t.Fatal(err) 273 } 274 275 cancelFn() 276 277 _, err = batch.QueryResults() 278 if err != context.Canceled { 279 t.Errorf("err => %v, want %v", err, context.Canceled) 280 } 281 282 if conn.IsAlive() { 283 t.Error("conn should be dead, but was alive") 284 } 285} 286 287func TestConnBeginBatchContextCancelBeforeFinish(t *testing.T) { 288 t.Parallel() 289 290 conn := mustConnect(t, *defaultConnConfig) 291 292 batch := conn.BeginBatch() 293 batch.Queue("select pg_sleep(2)", 294 nil, 295 nil, 296 nil, 297 ) 298 batch.Queue("select pg_sleep(2)", 299 nil, 300 nil, 301 nil, 302 ) 303 304 ctx, cancelFn := context.WithCancel(context.Background()) 305 306 err := batch.Send(ctx, nil) 307 if err != nil { 308 t.Fatal(err) 309 } 310 311 cancelFn() 312 313 err = batch.Close() 314 if err != context.Canceled { 315 t.Errorf("err => %v, want %v", err, context.Canceled) 316 } 317 318 if conn.IsAlive() { 319 t.Error("conn should be dead, but was alive") 320 } 321} 322 323func TestConnBeginBatchCloseRowsPartiallyRead(t *testing.T) { 324 t.Parallel() 325 326 conn := mustConnect(t, *defaultConnConfig) 327 defer closeConn(t, conn) 328 329 batch := conn.BeginBatch() 330 batch.Queue("select n from generate_series(0,5) n", 331 nil, 332 nil, 333 []int16{pgx.BinaryFormatCode}, 334 ) 335 batch.Queue("select n from generate_series(0,5) n", 336 nil, 337 nil, 338 []int16{pgx.BinaryFormatCode}, 339 ) 340 341 err := batch.Send(context.Background(), nil) 342 if err != nil { 343 t.Fatal(err) 344 } 345 346 rows, err := batch.QueryResults() 347 if err != nil { 348 t.Error(err) 349 } 350 351 for i := 0; i < 3; i++ { 352 if !rows.Next() { 353 t.Error("expected a row to be available") 354 } 355 356 var n int 357 if err := rows.Scan(&n); err != nil { 358 t.Error(err) 359 } 360 if n != i { 361 t.Errorf("n => %v, want %v", n, i) 362 } 363 } 364 365 rows.Close() 366 367 rows, err = batch.QueryResults() 368 if err != nil { 369 t.Error(err) 370 } 371 372 for i := 0; rows.Next(); i++ { 373 var n int 374 if err := rows.Scan(&n); err != nil { 375 t.Error(err) 376 } 377 if n != i { 378 t.Errorf("n => %v, want %v", n, i) 379 } 380 } 381 382 if rows.Err() != nil { 383 t.Error(rows.Err()) 384 } 385 386 err = batch.Close() 387 if err != nil { 388 t.Fatal(err) 389 } 390 391 ensureConnValid(t, conn) 392} 393 394func TestConnBeginBatchQueryError(t *testing.T) { 395 t.Parallel() 396 397 conn := mustConnect(t, *defaultConnConfig) 398 defer closeConn(t, conn) 399 400 batch := conn.BeginBatch() 401 batch.Queue("select n from generate_series(0,5) n where 100/(5-n) > 0", 402 nil, 403 nil, 404 []int16{pgx.BinaryFormatCode}, 405 ) 406 batch.Queue("select n from generate_series(0,5) n", 407 nil, 408 nil, 409 []int16{pgx.BinaryFormatCode}, 410 ) 411 412 err := batch.Send(context.Background(), nil) 413 if err != nil { 414 t.Fatal(err) 415 } 416 417 rows, err := batch.QueryResults() 418 if err != nil { 419 t.Error(err) 420 } 421 422 for i := 0; rows.Next(); i++ { 423 var n int 424 if err := rows.Scan(&n); err != nil { 425 t.Error(err) 426 } 427 if n != i { 428 t.Errorf("n => %v, want %v", n, i) 429 } 430 } 431 432 if pgErr, ok := rows.Err().(pgx.PgError); !(ok && pgErr.Code == "22012") { 433 t.Errorf("rows.Err() => %v, want error code %v", rows.Err(), 22012) 434 } 435 436 err = batch.Close() 437 if pgErr, ok := err.(pgx.PgError); !(ok && pgErr.Code == "22012") { 438 t.Errorf("rows.Err() => %v, want error code %v", err, 22012) 439 } 440 441 if conn.IsAlive() { 442 t.Error("conn should be dead, but was alive") 443 } 444} 445 446func TestConnBeginBatchQuerySyntaxError(t *testing.T) { 447 t.Parallel() 448 449 conn := mustConnect(t, *defaultConnConfig) 450 defer closeConn(t, conn) 451 452 batch := conn.BeginBatch() 453 batch.Queue("select 1 1", 454 nil, 455 nil, 456 []int16{pgx.BinaryFormatCode}, 457 ) 458 459 err := batch.Send(context.Background(), nil) 460 if err != nil { 461 t.Fatal(err) 462 } 463 464 var n int32 465 err = batch.QueryRowResults().Scan(&n) 466 if pgErr, ok := err.(pgx.PgError); !(ok && pgErr.Code == "42601") { 467 t.Errorf("rows.Err() => %v, want error code %v", err, 42601) 468 } 469 470 err = batch.Close() 471 if err == nil { 472 t.Error("Expected error") 473 } 474 475 if conn.IsAlive() { 476 t.Error("conn should be dead, but was alive") 477 } 478} 479 480func TestConnBeginBatchQueryRowInsert(t *testing.T) { 481 t.Parallel() 482 483 conn := mustConnect(t, *defaultConnConfig) 484 defer closeConn(t, conn) 485 486 sql := `create temporary table ledger( 487 id serial primary key, 488 description varchar not null, 489 amount int not null 490);` 491 mustExec(t, conn, sql) 492 493 batch := conn.BeginBatch() 494 batch.Queue("select 1", 495 nil, 496 nil, 497 []int16{pgx.BinaryFormatCode}, 498 ) 499 batch.Queue("insert into ledger(description, amount) values($1, $2),($1, $2)", 500 []interface{}{"q1", 1}, 501 []pgtype.OID{pgtype.VarcharOID, pgtype.Int4OID}, 502 nil, 503 ) 504 505 err := batch.Send(context.Background(), nil) 506 if err != nil { 507 t.Fatal(err) 508 } 509 510 var value int 511 err = batch.QueryRowResults().Scan(&value) 512 if err != nil { 513 t.Error(err) 514 } 515 516 ct, err := batch.ExecResults() 517 if err != nil { 518 t.Error(err) 519 } 520 if ct.RowsAffected() != 2 { 521 t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 2) 522 } 523 524 batch.Close() 525 526 ensureConnValid(t, conn) 527} 528 529func TestConnBeginBatchQueryPartialReadInsert(t *testing.T) { 530 t.Parallel() 531 532 conn := mustConnect(t, *defaultConnConfig) 533 defer closeConn(t, conn) 534 535 sql := `create temporary table ledger( 536 id serial primary key, 537 description varchar not null, 538 amount int not null 539);` 540 mustExec(t, conn, sql) 541 542 batch := conn.BeginBatch() 543 batch.Queue("select 1 union all select 2 union all select 3", 544 nil, 545 nil, 546 []int16{pgx.BinaryFormatCode}, 547 ) 548 batch.Queue("insert into ledger(description, amount) values($1, $2),($1, $2)", 549 []interface{}{"q1", 1}, 550 []pgtype.OID{pgtype.VarcharOID, pgtype.Int4OID}, 551 nil, 552 ) 553 554 err := batch.Send(context.Background(), nil) 555 if err != nil { 556 t.Fatal(err) 557 } 558 559 rows, err := batch.QueryResults() 560 if err != nil { 561 t.Error(err) 562 } 563 rows.Close() 564 565 ct, err := batch.ExecResults() 566 if err != nil { 567 t.Error(err) 568 } 569 if ct.RowsAffected() != 2 { 570 t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 2) 571 } 572 573 batch.Close() 574 575 ensureConnValid(t, conn) 576} 577 578func TestTxBeginBatch(t *testing.T) { 579 t.Parallel() 580 581 conn := mustConnect(t, *defaultConnConfig) 582 defer closeConn(t, conn) 583 584 sql := `create temporary table ledger1( 585 id serial primary key, 586 description varchar not null 587);` 588 mustExec(t, conn, sql) 589 590 sql = `create temporary table ledger2( 591 id int primary key, 592 amount int not null 593);` 594 mustExec(t, conn, sql) 595 596 tx, _ := conn.Begin() 597 batch := tx.BeginBatch() 598 batch.Queue("insert into ledger1(description) values($1) returning id", 599 []interface{}{"q1"}, 600 []pgtype.OID{pgtype.VarcharOID}, 601 []int16{pgx.BinaryFormatCode}, 602 ) 603 604 err := batch.Send(context.Background(), nil) 605 if err != nil { 606 t.Fatal(err) 607 } 608 var id int 609 err = batch.QueryRowResults().Scan(&id) 610 if err != nil { 611 t.Error(err) 612 } 613 batch.Close() 614 615 batch = tx.BeginBatch() 616 batch.Queue("insert into ledger2(id,amount) values($1, $2)", 617 []interface{}{id, 2}, 618 []pgtype.OID{pgtype.Int4OID, pgtype.Int4OID}, 619 nil, 620 ) 621 622 batch.Queue("select amount from ledger2 where id = $1", 623 []interface{}{id}, 624 []pgtype.OID{pgtype.Int4OID}, 625 nil, 626 ) 627 628 err = batch.Send(context.Background(), nil) 629 if err != nil { 630 t.Fatal(err) 631 } 632 ct, err := batch.ExecResults() 633 if err != nil { 634 t.Error(err) 635 } 636 if ct.RowsAffected() != 1 { 637 t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1) 638 } 639 640 var amout int 641 err = batch.QueryRowResults().Scan(&amout) 642 if err != nil { 643 t.Error(err) 644 } 645 646 batch.Close() 647 tx.Commit() 648 649 var count int 650 conn.QueryRow("select count(1) from ledger1 where id = $1", id).Scan(&count) 651 if count != 1 { 652 t.Errorf("count => %v, want %v", count, 1) 653 } 654 655 err = batch.Close() 656 if err != nil { 657 t.Fatal(err) 658 } 659 660 ensureConnValid(t, conn) 661} 662 663func TestTxBeginBatchRollback(t *testing.T) { 664 t.Parallel() 665 666 conn := mustConnect(t, *defaultConnConfig) 667 defer closeConn(t, conn) 668 669 sql := `create temporary table ledger1( 670 id serial primary key, 671 description varchar not null 672);` 673 mustExec(t, conn, sql) 674 675 tx, _ := conn.Begin() 676 batch := tx.BeginBatch() 677 batch.Queue("insert into ledger1(description) values($1) returning id", 678 []interface{}{"q1"}, 679 []pgtype.OID{pgtype.VarcharOID}, 680 []int16{pgx.BinaryFormatCode}, 681 ) 682 683 err := batch.Send(context.Background(), nil) 684 if err != nil { 685 t.Fatal(err) 686 } 687 var id int 688 err = batch.QueryRowResults().Scan(&id) 689 if err != nil { 690 t.Error(err) 691 } 692 batch.Close() 693 tx.Rollback() 694 695 row := conn.QueryRow("select count(1) from ledger1 where id = $1", id) 696 var count int 697 row.Scan(&count) 698 if count != 0 { 699 t.Errorf("count => %v, want %v", count, 0) 700 } 701 702 ensureConnValid(t, conn) 703} 704