1package pgx_test
2
3import (
4	"context"
5	"testing"
6
7	"github.com/jackc/pgx"
8	"github.com/jackc/pgx/pgtype"
9)
10
11func TestConnBeginBatch(t *testing.T) {
12	t.Parallel()
13
14	conn := mustConnect(t, *defaultConnConfig)
15	defer closeConn(t, conn)
16
17	sql := `create temporary table ledger(
18  id serial primary key,
19  description varchar not null,
20  amount int not null
21);`
22	mustExec(t, conn, sql)
23
24	batch := conn.BeginBatch()
25	batch.Queue("insert into ledger(description, amount) values($1, $2)",
26		[]interface{}{"q1", 1},
27		[]pgtype.OID{pgtype.VarcharOID, pgtype.Int4OID},
28		nil,
29	)
30	batch.Queue("insert into ledger(description, amount) values($1, $2)",
31		[]interface{}{"q2", 2},
32		[]pgtype.OID{pgtype.VarcharOID, pgtype.Int4OID},
33		nil,
34	)
35	batch.Queue("insert into ledger(description, amount) values($1, $2)",
36		[]interface{}{"q3", 3},
37		[]pgtype.OID{pgtype.VarcharOID, pgtype.Int4OID},
38		nil,
39	)
40	batch.Queue("select id, description, amount from ledger order by id",
41		nil,
42		nil,
43		[]int16{pgx.BinaryFormatCode, pgx.TextFormatCode, pgx.BinaryFormatCode},
44	)
45	batch.Queue("select sum(amount) from ledger",
46		nil,
47		nil,
48		[]int16{pgx.BinaryFormatCode},
49	)
50
51	err := batch.Send(context.Background(), nil)
52	if err != nil {
53		t.Fatal(err)
54	}
55
56	ct, err := batch.ExecResults()
57	if err != nil {
58		t.Error(err)
59	}
60	if ct.RowsAffected() != 1 {
61		t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1)
62	}
63
64	ct, err = batch.ExecResults()
65	if err != nil {
66		t.Error(err)
67	}
68	if ct.RowsAffected() != 1 {
69		t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1)
70	}
71
72	rows, err := batch.QueryResults()
73	if err != nil {
74		t.Error(err)
75	}
76
77	var id int32
78	var description string
79	var amount int32
80	if !rows.Next() {
81		t.Fatal("expected a row to be available")
82	}
83	if err := rows.Scan(&id, &description, &amount); err != nil {
84		t.Fatal(err)
85	}
86	if id != 1 {
87		t.Errorf("id => %v, want %v", id, 1)
88	}
89	if description != "q1" {
90		t.Errorf("description => %v, want %v", description, "q1")
91	}
92	if amount != 1 {
93		t.Errorf("amount => %v, want %v", amount, 1)
94	}
95
96	if !rows.Next() {
97		t.Fatal("expected a row to be available")
98	}
99	if err := rows.Scan(&id, &description, &amount); err != nil {
100		t.Fatal(err)
101	}
102	if id != 2 {
103		t.Errorf("id => %v, want %v", id, 2)
104	}
105	if description != "q2" {
106		t.Errorf("description => %v, want %v", description, "q2")
107	}
108	if amount != 2 {
109		t.Errorf("amount => %v, want %v", amount, 2)
110	}
111
112	if !rows.Next() {
113		t.Fatal("expected a row to be available")
114	}
115	if err := rows.Scan(&id, &description, &amount); err != nil {
116		t.Fatal(err)
117	}
118	if id != 3 {
119		t.Errorf("id => %v, want %v", id, 3)
120	}
121	if description != "q3" {
122		t.Errorf("description => %v, want %v", description, "q3")
123	}
124	if amount != 3 {
125		t.Errorf("amount => %v, want %v", amount, 3)
126	}
127
128	if rows.Next() {
129		t.Fatal("did not expect a row to be available")
130	}
131
132	if rows.Err() != nil {
133		t.Fatal(rows.Err())
134	}
135
136	err = batch.QueryRowResults().Scan(&amount)
137	if err != nil {
138		t.Error(err)
139	}
140	if amount != 6 {
141		t.Errorf("amount => %v, want %v", amount, 6)
142	}
143
144	err = batch.Close()
145	if err != nil {
146		t.Fatal(err)
147	}
148
149	ensureConnValid(t, conn)
150}
151
152func TestConnBeginBatchWithPreparedStatement(t *testing.T) {
153	t.Parallel()
154
155	conn := mustConnect(t, *defaultConnConfig)
156	defer closeConn(t, conn)
157
158	_, err := conn.Prepare("ps1", "select n from generate_series(0,$1::int) n")
159	if err != nil {
160		t.Fatal(err)
161	}
162
163	batch := conn.BeginBatch()
164
165	queryCount := 3
166	for i := 0; i < queryCount; i++ {
167		batch.Queue("ps1",
168			[]interface{}{5},
169			nil,
170			[]int16{pgx.BinaryFormatCode},
171		)
172	}
173
174	err = batch.Send(context.Background(), nil)
175	if err != nil {
176		t.Fatal(err)
177	}
178
179	for i := 0; i < queryCount; i++ {
180		rows, err := batch.QueryResults()
181		if err != nil {
182			t.Fatal(err)
183		}
184
185		for k := 0; rows.Next(); k++ {
186			var n int
187			if err := rows.Scan(&n); err != nil {
188				t.Fatal(err)
189			}
190			if n != k {
191				t.Fatalf("n => %v, want %v", n, k)
192			}
193		}
194
195		if rows.Err() != nil {
196			t.Fatal(rows.Err())
197		}
198	}
199
200	err = batch.Close()
201	if err != nil {
202		t.Fatal(err)
203	}
204
205	ensureConnValid(t, conn)
206}
207
208func TestConnBeginBatchContextCancelBeforeExecResults(t *testing.T) {
209	t.Parallel()
210
211	conn := mustConnect(t, *defaultConnConfig)
212
213	sql := `create temporary table ledger(
214  id serial primary key,
215  description varchar not null,
216  amount int not null
217);`
218	mustExec(t, conn, sql)
219
220	batch := conn.BeginBatch()
221	batch.Queue("insert into ledger(description, amount) values($1, $2)",
222		[]interface{}{"q1", 1},
223		[]pgtype.OID{pgtype.VarcharOID, pgtype.Int4OID},
224		nil,
225	)
226	batch.Queue("select pg_sleep(2)",
227		nil,
228		nil,
229		nil,
230	)
231
232	ctx, cancelFn := context.WithCancel(context.Background())
233
234	err := batch.Send(ctx, nil)
235	if err != nil {
236		t.Fatal(err)
237	}
238
239	cancelFn()
240
241	_, err = batch.ExecResults()
242	if err != context.Canceled {
243		t.Errorf("err => %v, want %v", err, context.Canceled)
244	}
245
246	if conn.IsAlive() {
247		t.Error("conn should be dead, but was alive")
248	}
249}
250
251func TestConnBeginBatchContextCancelBeforeQueryResults(t *testing.T) {
252	t.Parallel()
253
254	conn := mustConnect(t, *defaultConnConfig)
255
256	batch := conn.BeginBatch()
257	batch.Queue("select pg_sleep(2)",
258		nil,
259		nil,
260		nil,
261	)
262	batch.Queue("select pg_sleep(2)",
263		nil,
264		nil,
265		nil,
266	)
267
268	ctx, cancelFn := context.WithCancel(context.Background())
269
270	err := batch.Send(ctx, nil)
271	if err != nil {
272		t.Fatal(err)
273	}
274
275	cancelFn()
276
277	_, err = batch.QueryResults()
278	if err != context.Canceled {
279		t.Errorf("err => %v, want %v", err, context.Canceled)
280	}
281
282	if conn.IsAlive() {
283		t.Error("conn should be dead, but was alive")
284	}
285}
286
287func TestConnBeginBatchContextCancelBeforeFinish(t *testing.T) {
288	t.Parallel()
289
290	conn := mustConnect(t, *defaultConnConfig)
291
292	batch := conn.BeginBatch()
293	batch.Queue("select pg_sleep(2)",
294		nil,
295		nil,
296		nil,
297	)
298	batch.Queue("select pg_sleep(2)",
299		nil,
300		nil,
301		nil,
302	)
303
304	ctx, cancelFn := context.WithCancel(context.Background())
305
306	err := batch.Send(ctx, nil)
307	if err != nil {
308		t.Fatal(err)
309	}
310
311	cancelFn()
312
313	err = batch.Close()
314	if err != context.Canceled {
315		t.Errorf("err => %v, want %v", err, context.Canceled)
316	}
317
318	if conn.IsAlive() {
319		t.Error("conn should be dead, but was alive")
320	}
321}
322
323func TestConnBeginBatchCloseRowsPartiallyRead(t *testing.T) {
324	t.Parallel()
325
326	conn := mustConnect(t, *defaultConnConfig)
327	defer closeConn(t, conn)
328
329	batch := conn.BeginBatch()
330	batch.Queue("select n from generate_series(0,5) n",
331		nil,
332		nil,
333		[]int16{pgx.BinaryFormatCode},
334	)
335	batch.Queue("select n from generate_series(0,5) n",
336		nil,
337		nil,
338		[]int16{pgx.BinaryFormatCode},
339	)
340
341	err := batch.Send(context.Background(), nil)
342	if err != nil {
343		t.Fatal(err)
344	}
345
346	rows, err := batch.QueryResults()
347	if err != nil {
348		t.Error(err)
349	}
350
351	for i := 0; i < 3; i++ {
352		if !rows.Next() {
353			t.Error("expected a row to be available")
354		}
355
356		var n int
357		if err := rows.Scan(&n); err != nil {
358			t.Error(err)
359		}
360		if n != i {
361			t.Errorf("n => %v, want %v", n, i)
362		}
363	}
364
365	rows.Close()
366
367	rows, err = batch.QueryResults()
368	if err != nil {
369		t.Error(err)
370	}
371
372	for i := 0; rows.Next(); i++ {
373		var n int
374		if err := rows.Scan(&n); err != nil {
375			t.Error(err)
376		}
377		if n != i {
378			t.Errorf("n => %v, want %v", n, i)
379		}
380	}
381
382	if rows.Err() != nil {
383		t.Error(rows.Err())
384	}
385
386	err = batch.Close()
387	if err != nil {
388		t.Fatal(err)
389	}
390
391	ensureConnValid(t, conn)
392}
393
394func TestConnBeginBatchQueryError(t *testing.T) {
395	t.Parallel()
396
397	conn := mustConnect(t, *defaultConnConfig)
398	defer closeConn(t, conn)
399
400	batch := conn.BeginBatch()
401	batch.Queue("select n from generate_series(0,5) n where 100/(5-n) > 0",
402		nil,
403		nil,
404		[]int16{pgx.BinaryFormatCode},
405	)
406	batch.Queue("select n from generate_series(0,5) n",
407		nil,
408		nil,
409		[]int16{pgx.BinaryFormatCode},
410	)
411
412	err := batch.Send(context.Background(), nil)
413	if err != nil {
414		t.Fatal(err)
415	}
416
417	rows, err := batch.QueryResults()
418	if err != nil {
419		t.Error(err)
420	}
421
422	for i := 0; rows.Next(); i++ {
423		var n int
424		if err := rows.Scan(&n); err != nil {
425			t.Error(err)
426		}
427		if n != i {
428			t.Errorf("n => %v, want %v", n, i)
429		}
430	}
431
432	if pgErr, ok := rows.Err().(pgx.PgError); !(ok && pgErr.Code == "22012") {
433		t.Errorf("rows.Err() => %v, want error code %v", rows.Err(), 22012)
434	}
435
436	err = batch.Close()
437	if pgErr, ok := err.(pgx.PgError); !(ok && pgErr.Code == "22012") {
438		t.Errorf("rows.Err() => %v, want error code %v", err, 22012)
439	}
440
441	if conn.IsAlive() {
442		t.Error("conn should be dead, but was alive")
443	}
444}
445
446func TestConnBeginBatchQuerySyntaxError(t *testing.T) {
447	t.Parallel()
448
449	conn := mustConnect(t, *defaultConnConfig)
450	defer closeConn(t, conn)
451
452	batch := conn.BeginBatch()
453	batch.Queue("select 1 1",
454		nil,
455		nil,
456		[]int16{pgx.BinaryFormatCode},
457	)
458
459	err := batch.Send(context.Background(), nil)
460	if err != nil {
461		t.Fatal(err)
462	}
463
464	var n int32
465	err = batch.QueryRowResults().Scan(&n)
466	if pgErr, ok := err.(pgx.PgError); !(ok && pgErr.Code == "42601") {
467		t.Errorf("rows.Err() => %v, want error code %v", err, 42601)
468	}
469
470	err = batch.Close()
471	if err == nil {
472		t.Error("Expected error")
473	}
474
475	if conn.IsAlive() {
476		t.Error("conn should be dead, but was alive")
477	}
478}
479
480func TestConnBeginBatchQueryRowInsert(t *testing.T) {
481	t.Parallel()
482
483	conn := mustConnect(t, *defaultConnConfig)
484	defer closeConn(t, conn)
485
486	sql := `create temporary table ledger(
487  id serial primary key,
488  description varchar not null,
489  amount int not null
490);`
491	mustExec(t, conn, sql)
492
493	batch := conn.BeginBatch()
494	batch.Queue("select 1",
495		nil,
496		nil,
497		[]int16{pgx.BinaryFormatCode},
498	)
499	batch.Queue("insert into ledger(description, amount) values($1, $2),($1, $2)",
500		[]interface{}{"q1", 1},
501		[]pgtype.OID{pgtype.VarcharOID, pgtype.Int4OID},
502		nil,
503	)
504
505	err := batch.Send(context.Background(), nil)
506	if err != nil {
507		t.Fatal(err)
508	}
509
510	var value int
511	err = batch.QueryRowResults().Scan(&value)
512	if err != nil {
513		t.Error(err)
514	}
515
516	ct, err := batch.ExecResults()
517	if err != nil {
518		t.Error(err)
519	}
520	if ct.RowsAffected() != 2 {
521		t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 2)
522	}
523
524	batch.Close()
525
526	ensureConnValid(t, conn)
527}
528
529func TestConnBeginBatchQueryPartialReadInsert(t *testing.T) {
530	t.Parallel()
531
532	conn := mustConnect(t, *defaultConnConfig)
533	defer closeConn(t, conn)
534
535	sql := `create temporary table ledger(
536  id serial primary key,
537  description varchar not null,
538  amount int not null
539);`
540	mustExec(t, conn, sql)
541
542	batch := conn.BeginBatch()
543	batch.Queue("select 1 union all select 2 union all select 3",
544		nil,
545		nil,
546		[]int16{pgx.BinaryFormatCode},
547	)
548	batch.Queue("insert into ledger(description, amount) values($1, $2),($1, $2)",
549		[]interface{}{"q1", 1},
550		[]pgtype.OID{pgtype.VarcharOID, pgtype.Int4OID},
551		nil,
552	)
553
554	err := batch.Send(context.Background(), nil)
555	if err != nil {
556		t.Fatal(err)
557	}
558
559	rows, err := batch.QueryResults()
560	if err != nil {
561		t.Error(err)
562	}
563	rows.Close()
564
565	ct, err := batch.ExecResults()
566	if err != nil {
567		t.Error(err)
568	}
569	if ct.RowsAffected() != 2 {
570		t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 2)
571	}
572
573	batch.Close()
574
575	ensureConnValid(t, conn)
576}
577
578func TestTxBeginBatch(t *testing.T) {
579	t.Parallel()
580
581	conn := mustConnect(t, *defaultConnConfig)
582	defer closeConn(t, conn)
583
584	sql := `create temporary table ledger1(
585  id serial primary key,
586  description varchar not null
587);`
588	mustExec(t, conn, sql)
589
590	sql = `create temporary table ledger2(
591  id int primary key,
592  amount int not null
593);`
594	mustExec(t, conn, sql)
595
596	tx, _ := conn.Begin()
597	batch := tx.BeginBatch()
598	batch.Queue("insert into ledger1(description) values($1) returning id",
599		[]interface{}{"q1"},
600		[]pgtype.OID{pgtype.VarcharOID},
601		[]int16{pgx.BinaryFormatCode},
602	)
603
604	err := batch.Send(context.Background(), nil)
605	if err != nil {
606		t.Fatal(err)
607	}
608	var id int
609	err = batch.QueryRowResults().Scan(&id)
610	if err != nil {
611		t.Error(err)
612	}
613	batch.Close()
614
615	batch = tx.BeginBatch()
616	batch.Queue("insert into ledger2(id,amount) values($1, $2)",
617		[]interface{}{id, 2},
618		[]pgtype.OID{pgtype.Int4OID, pgtype.Int4OID},
619		nil,
620	)
621
622	batch.Queue("select amount from ledger2 where id = $1",
623		[]interface{}{id},
624		[]pgtype.OID{pgtype.Int4OID},
625		nil,
626	)
627
628	err = batch.Send(context.Background(), nil)
629	if err != nil {
630		t.Fatal(err)
631	}
632	ct, err := batch.ExecResults()
633	if err != nil {
634		t.Error(err)
635	}
636	if ct.RowsAffected() != 1 {
637		t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1)
638	}
639
640	var amout int
641	err = batch.QueryRowResults().Scan(&amout)
642	if err != nil {
643		t.Error(err)
644	}
645
646	batch.Close()
647	tx.Commit()
648
649	var count int
650	conn.QueryRow("select count(1) from ledger1 where id = $1", id).Scan(&count)
651	if count != 1 {
652		t.Errorf("count => %v, want %v", count, 1)
653	}
654
655	err = batch.Close()
656	if err != nil {
657		t.Fatal(err)
658	}
659
660	ensureConnValid(t, conn)
661}
662
663func TestTxBeginBatchRollback(t *testing.T) {
664	t.Parallel()
665
666	conn := mustConnect(t, *defaultConnConfig)
667	defer closeConn(t, conn)
668
669	sql := `create temporary table ledger1(
670  id serial primary key,
671  description varchar not null
672);`
673	mustExec(t, conn, sql)
674
675	tx, _ := conn.Begin()
676	batch := tx.BeginBatch()
677	batch.Queue("insert into ledger1(description) values($1) returning id",
678		[]interface{}{"q1"},
679		[]pgtype.OID{pgtype.VarcharOID},
680		[]int16{pgx.BinaryFormatCode},
681	)
682
683	err := batch.Send(context.Background(), nil)
684	if err != nil {
685		t.Fatal(err)
686	}
687	var id int
688	err = batch.QueryRowResults().Scan(&id)
689	if err != nil {
690		t.Error(err)
691	}
692	batch.Close()
693	tx.Rollback()
694
695	row := conn.QueryRow("select count(1) from ledger1 where id = $1", id)
696	var count int
697	row.Scan(&count)
698	if count != 0 {
699		t.Errorf("count => %v, want %v", count, 0)
700	}
701
702	ensureConnValid(t, conn)
703}
704