1package pgx_test
2
3import (
4	"context"
5	"fmt"
6	"os"
7	"reflect"
8	"testing"
9	"time"
10
11	"github.com/jackc/pgconn"
12	"github.com/jackc/pgx/v4"
13	"github.com/stretchr/testify/require"
14)
15
16func TestConnCopyFromSmall(t *testing.T) {
17	t.Parallel()
18
19	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
20	defer closeConn(t, conn)
21
22	mustExec(t, conn, `create temporary table foo(
23		a int2,
24		b int4,
25		c int8,
26		d varchar,
27		e text,
28		f date,
29		g timestamptz
30	)`)
31
32	tzedTime := time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local)
33
34	inputRows := [][]interface{}{
35		{int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), tzedTime},
36		{nil, nil, nil, nil, nil, nil, nil},
37	}
38
39	copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g"}, pgx.CopyFromRows(inputRows))
40	if err != nil {
41		t.Errorf("Unexpected error for CopyFrom: %v", err)
42	}
43	if int(copyCount) != len(inputRows) {
44		t.Errorf("Expected CopyFrom to return %d copied rows, but got %d", len(inputRows), copyCount)
45	}
46
47	rows, err := conn.Query(context.Background(), "select * from foo")
48	if err != nil {
49		t.Errorf("Unexpected error for Query: %v", err)
50	}
51
52	var outputRows [][]interface{}
53	for rows.Next() {
54		row, err := rows.Values()
55		if err != nil {
56			t.Errorf("Unexpected error for rows.Values(): %v", err)
57		}
58		outputRows = append(outputRows, row)
59	}
60
61	if rows.Err() != nil {
62		t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
63	}
64
65	if !reflect.DeepEqual(inputRows, outputRows) {
66		t.Errorf("Input rows and output rows do not equal: %v -> %v", inputRows, outputRows)
67	}
68
69	ensureConnValid(t, conn)
70}
71
72func TestConnCopyFromSliceSmall(t *testing.T) {
73	t.Parallel()
74
75	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
76	defer closeConn(t, conn)
77
78	mustExec(t, conn, `create temporary table foo(
79		a int2,
80		b int4,
81		c int8,
82		d varchar,
83		e text,
84		f date,
85		g timestamptz
86	)`)
87
88	tzedTime := time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local)
89
90	inputRows := [][]interface{}{
91		{int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), tzedTime},
92		{nil, nil, nil, nil, nil, nil, nil},
93	}
94
95	copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g"},
96		pgx.CopyFromSlice(len(inputRows), func(i int) ([]interface{}, error) {
97			return inputRows[i], nil
98		}))
99	if err != nil {
100		t.Errorf("Unexpected error for CopyFrom: %v", err)
101	}
102	if int(copyCount) != len(inputRows) {
103		t.Errorf("Expected CopyFrom to return %d copied rows, but got %d", len(inputRows), copyCount)
104	}
105
106	rows, err := conn.Query(context.Background(), "select * from foo")
107	if err != nil {
108		t.Errorf("Unexpected error for Query: %v", err)
109	}
110
111	var outputRows [][]interface{}
112	for rows.Next() {
113		row, err := rows.Values()
114		if err != nil {
115			t.Errorf("Unexpected error for rows.Values(): %v", err)
116		}
117		outputRows = append(outputRows, row)
118	}
119
120	if rows.Err() != nil {
121		t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
122	}
123
124	if !reflect.DeepEqual(inputRows, outputRows) {
125		t.Errorf("Input rows and output rows do not equal: %v -> %v", inputRows, outputRows)
126	}
127
128	ensureConnValid(t, conn)
129}
130
131func TestConnCopyFromLarge(t *testing.T) {
132	t.Parallel()
133
134	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
135	defer closeConn(t, conn)
136
137	skipCockroachDB(t, conn, "Skipping due to known server issue: (https://github.com/cockroachdb/cockroach/issues/52722)")
138
139	mustExec(t, conn, `create temporary table foo(
140		a int2,
141		b int4,
142		c int8,
143		d varchar,
144		e text,
145		f date,
146		g timestamptz,
147		h bytea
148	)`)
149
150	tzedTime := time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local)
151
152	inputRows := [][]interface{}{}
153
154	for i := 0; i < 10000; i++ {
155		inputRows = append(inputRows, []interface{}{int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), tzedTime, []byte{111, 111, 111, 111}})
156	}
157
158	copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g", "h"}, pgx.CopyFromRows(inputRows))
159	if err != nil {
160		t.Errorf("Unexpected error for CopyFrom: %v", err)
161	}
162	if int(copyCount) != len(inputRows) {
163		t.Errorf("Expected CopyFrom to return %d copied rows, but got %d", len(inputRows), copyCount)
164	}
165
166	rows, err := conn.Query(context.Background(), "select * from foo")
167	if err != nil {
168		t.Errorf("Unexpected error for Query: %v", err)
169	}
170
171	var outputRows [][]interface{}
172	for rows.Next() {
173		row, err := rows.Values()
174		if err != nil {
175			t.Errorf("Unexpected error for rows.Values(): %v", err)
176		}
177		outputRows = append(outputRows, row)
178	}
179
180	if rows.Err() != nil {
181		t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
182	}
183
184	if !reflect.DeepEqual(inputRows, outputRows) {
185		t.Errorf("Input rows and output rows do not equal")
186	}
187
188	ensureConnValid(t, conn)
189}
190
191func TestConnCopyFromEnum(t *testing.T) {
192	t.Parallel()
193
194	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
195	defer closeConn(t, conn)
196
197	ctx := context.Background()
198	tx, err := conn.Begin(ctx)
199	require.NoError(t, err)
200	defer tx.Rollback(ctx)
201
202	_, err = tx.Exec(ctx, `drop type if exists color`)
203	require.NoError(t, err)
204
205	_, err = tx.Exec(ctx, `drop type if exists fruit`)
206	require.NoError(t, err)
207
208	_, err = tx.Exec(ctx, `create type color as enum ('blue', 'green', 'orange')`)
209	require.NoError(t, err)
210
211	_, err = tx.Exec(ctx, `create type fruit as enum ('apple', 'orange', 'grape')`)
212	require.NoError(t, err)
213
214	_, err = tx.Exec(ctx, `create table foo(
215		a text,
216		b color,
217		c fruit,
218		d color,
219		e fruit,
220		f text
221	)`)
222	require.NoError(t, err)
223
224	inputRows := [][]interface{}{
225		{"abc", "blue", "grape", "orange", "orange", "def"},
226		{nil, nil, nil, nil, nil, nil},
227	}
228
229	copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f"}, pgx.CopyFromRows(inputRows))
230	require.NoError(t, err)
231	require.EqualValues(t, len(inputRows), copyCount)
232
233	rows, err := conn.Query(ctx, "select * from foo")
234	require.NoError(t, err)
235
236	var outputRows [][]interface{}
237	for rows.Next() {
238		row, err := rows.Values()
239		require.NoError(t, err)
240		outputRows = append(outputRows, row)
241	}
242
243	require.NoError(t, rows.Err())
244
245	if !reflect.DeepEqual(inputRows, outputRows) {
246		t.Errorf("Input rows and output rows do not equal: %v -> %v", inputRows, outputRows)
247	}
248
249	ensureConnValid(t, conn)
250}
251
252func TestConnCopyFromJSON(t *testing.T) {
253	t.Parallel()
254
255	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
256	defer closeConn(t, conn)
257
258	for _, typeName := range []string{"json", "jsonb"} {
259		if _, ok := conn.ConnInfo().DataTypeForName(typeName); !ok {
260			return // No JSON/JSONB type -- must be running against old PostgreSQL
261		}
262	}
263
264	mustExec(t, conn, `create temporary table foo(
265		a json,
266		b jsonb
267	)`)
268
269	inputRows := [][]interface{}{
270		{map[string]interface{}{"foo": "bar"}, map[string]interface{}{"bar": "quz"}},
271		{nil, nil},
272	}
273
274	copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a", "b"}, pgx.CopyFromRows(inputRows))
275	if err != nil {
276		t.Errorf("Unexpected error for CopyFrom: %v", err)
277	}
278	if int(copyCount) != len(inputRows) {
279		t.Errorf("Expected CopyFrom to return %d copied rows, but got %d", len(inputRows), copyCount)
280	}
281
282	rows, err := conn.Query(context.Background(), "select * from foo")
283	if err != nil {
284		t.Errorf("Unexpected error for Query: %v", err)
285	}
286
287	var outputRows [][]interface{}
288	for rows.Next() {
289		row, err := rows.Values()
290		if err != nil {
291			t.Errorf("Unexpected error for rows.Values(): %v", err)
292		}
293		outputRows = append(outputRows, row)
294	}
295
296	if rows.Err() != nil {
297		t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
298	}
299
300	if !reflect.DeepEqual(inputRows, outputRows) {
301		t.Errorf("Input rows and output rows do not equal: %v -> %v", inputRows, outputRows)
302	}
303
304	ensureConnValid(t, conn)
305}
306
// clientFailSource is a row source for conn.CopyFrom that simulates a
// client-side failure partway through a copy. It offers up to 99 rows, each a
// single 100000-byte []byte value, but Values fails on the third row with
// "client error". That same error is then reported by Err.
type clientFailSource struct {
	count int   // number of Next calls made so far
	err   error // error produced by Values, later surfaced by Err
}

// Next advances to the following row and reports whether one is available.
func (s *clientFailSource) Next() bool {
	const rowLimit = 100
	s.count++
	return s.count < rowLimit
}

// Values returns the current row, or records and returns an error on row 3.
func (s *clientFailSource) Values() ([]interface{}, error) {
	if s.count != 3 {
		payload := make([]byte, 100000)
		return []interface{}{payload}, nil
	}
	s.err = fmt.Errorf("client error")
	return nil, s.err
}

// Err returns the error recorded by Values, if any.
func (s *clientFailSource) Err() error { return s.err }
328
329func TestConnCopyFromFailServerSideMidway(t *testing.T) {
330	t.Parallel()
331
332	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
333	defer closeConn(t, conn)
334
335	mustExec(t, conn, `create temporary table foo(
336		a int4,
337		b varchar not null
338	)`)
339
340	inputRows := [][]interface{}{
341		{int32(1), "abc"},
342		{int32(2), nil}, // this row should trigger a failure
343		{int32(3), "def"},
344	}
345
346	copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a", "b"}, pgx.CopyFromRows(inputRows))
347	if err == nil {
348		t.Errorf("Expected CopyFrom return error, but it did not")
349	}
350	if _, ok := err.(*pgconn.PgError); !ok {
351		t.Errorf("Expected CopyFrom return pgx.PgError, but instead it returned: %v", err)
352	}
353	if copyCount != 0 {
354		t.Errorf("Expected CopyFrom to return 0 copied rows, but got %d", copyCount)
355	}
356
357	rows, err := conn.Query(context.Background(), "select * from foo")
358	if err != nil {
359		t.Errorf("Unexpected error for Query: %v", err)
360	}
361
362	var outputRows [][]interface{}
363	for rows.Next() {
364		row, err := rows.Values()
365		if err != nil {
366			t.Errorf("Unexpected error for rows.Values(): %v", err)
367		}
368		outputRows = append(outputRows, row)
369	}
370
371	if rows.Err() != nil {
372		t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
373	}
374
375	if len(outputRows) != 0 {
376		t.Errorf("Expected 0 rows, but got %v", outputRows)
377	}
378
379	mustExec(t, conn, "truncate foo")
380
381	ensureConnValid(t, conn)
382}
383
// failSource is a deliberately slow row source for conn.CopyFrom. Every Next
// call sleeps 100ms before advancing, and up to 99 rows are offered. Row 3
// carries a nil (NULL) value, while every other row carries a single
// 100000-byte []byte value. Err always reports nil.
type failSource struct {
	count int // number of Next calls made so far
}

// Next waits 100ms, advances, and reports whether another row is available.
func (s *failSource) Next() bool {
	const perRowDelay = 100 * time.Millisecond
	time.Sleep(perRowDelay)
	s.count++
	return s.count < 100
}

// Values returns a one-column row whose value is nil on row 3.
func (s *failSource) Values() ([]interface{}, error) {
	var value interface{}
	if s.count != 3 {
		value = make([]byte, 100000)
	}
	return []interface{}{value}, nil
}

// Err never reports an error.
func (s *failSource) Err() error { return nil }
404
405func TestConnCopyFromFailServerSideMidwayAbortsWithoutWaiting(t *testing.T) {
406	t.Parallel()
407
408	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
409	defer closeConn(t, conn)
410
411	mustExec(t, conn, `create temporary table foo(
412		a bytea not null
413	)`)
414
415	startTime := time.Now()
416
417	copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a"}, &failSource{})
418	if err == nil {
419		t.Errorf("Expected CopyFrom return error, but it did not")
420	}
421	if _, ok := err.(*pgconn.PgError); !ok {
422		t.Errorf("Expected CopyFrom return pgx.PgError, but instead it returned: %v", err)
423	}
424	if copyCount != 0 {
425		t.Errorf("Expected CopyFrom to return 0 copied rows, but got %d", copyCount)
426	}
427
428	endTime := time.Now()
429	copyTime := endTime.Sub(startTime)
430	if copyTime > time.Second {
431		t.Errorf("Failing CopyFrom shouldn't have taken so long: %v", copyTime)
432	}
433
434	rows, err := conn.Query(context.Background(), "select * from foo")
435	if err != nil {
436		t.Errorf("Unexpected error for Query: %v", err)
437	}
438
439	var outputRows [][]interface{}
440	for rows.Next() {
441		row, err := rows.Values()
442		if err != nil {
443			t.Errorf("Unexpected error for rows.Values(): %v", err)
444		}
445		outputRows = append(outputRows, row)
446	}
447
448	if rows.Err() != nil {
449		t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
450	}
451
452	if len(outputRows) != 0 {
453		t.Errorf("Expected 0 rows, but got %v", outputRows)
454	}
455
456	ensureConnValid(t, conn)
457}
458
// slowFailRaceSource is a row source for conn.CopyFrom that sleeps 1ms per Next
// call and offers up to 999 two-column rows. Every row is (1, 1000-byte
// []byte) except row 500, which is (nil, nil). The table used by
// TestConnCopyFromSlowFailRace declares both columns NOT NULL, so that row is
// the intended server-side failure. Err always reports nil.
type slowFailRaceSource struct {
	count int // number of Next calls made so far
}

// Next waits 1ms, advances, and reports whether another row is available.
func (s *slowFailRaceSource) Next() bool {
	const maxCalls = 1000
	time.Sleep(time.Millisecond)
	s.count++
	return s.count < maxCalls
}

// Values returns the current row; row 500 consists of two nil values.
func (s *slowFailRaceSource) Values() ([]interface{}, error) {
	switch s.count {
	case 500:
		return []interface{}{nil, nil}, nil
	default:
		return []interface{}{1, make([]byte, 1000)}, nil
	}
}

// Err never reports an error.
func (s *slowFailRaceSource) Err() error { return nil }
479
480func TestConnCopyFromSlowFailRace(t *testing.T) {
481	t.Parallel()
482
483	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
484	defer closeConn(t, conn)
485
486	mustExec(t, conn, `create temporary table foo(
487		a int not null,
488		b bytea not null
489	)`)
490
491	copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a", "b"}, &slowFailRaceSource{})
492	if err == nil {
493		t.Errorf("Expected CopyFrom return error, but it did not")
494	}
495	if _, ok := err.(*pgconn.PgError); !ok {
496		t.Errorf("Expected CopyFrom return pgx.PgError, but instead it returned: %v", err)
497	}
498	if copyCount != 0 {
499		t.Errorf("Expected CopyFrom to return 0 copied rows, but got %d", copyCount)
500	}
501
502	ensureConnValid(t, conn)
503}
504
505func TestConnCopyFromCopyFromSourceErrorMidway(t *testing.T) {
506	t.Parallel()
507
508	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
509	defer closeConn(t, conn)
510
511	mustExec(t, conn, `create temporary table foo(
512		a bytea not null
513	)`)
514
515	copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a"}, &clientFailSource{})
516	if err == nil {
517		t.Errorf("Expected CopyFrom return error, but it did not")
518	}
519	if copyCount != 0 {
520		t.Errorf("Expected CopyFrom to return 0 copied rows, but got %d", copyCount)
521	}
522
523	rows, err := conn.Query(context.Background(), "select * from foo")
524	if err != nil {
525		t.Errorf("Unexpected error for Query: %v", err)
526	}
527
528	var outputRows [][]interface{}
529	for rows.Next() {
530		row, err := rows.Values()
531		if err != nil {
532			t.Errorf("Unexpected error for rows.Values(): %v", err)
533		}
534		outputRows = append(outputRows, row)
535	}
536
537	if rows.Err() != nil {
538		t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
539	}
540
541	if len(outputRows) != 0 {
542		t.Errorf("Expected 0 rows, but got %v", len(outputRows))
543	}
544
545	ensureConnValid(t, conn)
546}
547
// clientFinalErrSource is a row source for conn.CopyFrom that yields four rows,
// each holding a single 100000-byte []byte value. Its Err method returns a
// non-nil "final error" on every call, regardless of how far iteration has
// progressed.
type clientFinalErrSource struct {
	count int // number of Next calls made so far
}

// Next advances the source; it reports true for the first four calls only.
func (s *clientFinalErrSource) Next() bool {
	const limit = 5
	s.count++
	return s.count < limit
}

// Values returns a row containing one freshly allocated 100000-byte slice.
func (s *clientFinalErrSource) Values() ([]interface{}, error) {
	row := []interface{}{make([]byte, 100000)}
	return row, nil
}

// Err always reports a failure.
func (s *clientFinalErrSource) Err() error {
	finalErr := fmt.Errorf("final error")
	return finalErr
}
564
565func TestConnCopyFromCopyFromSourceErrorEnd(t *testing.T) {
566	t.Parallel()
567
568	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
569	defer closeConn(t, conn)
570
571	mustExec(t, conn, `create temporary table foo(
572		a bytea not null
573	)`)
574
575	copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a"}, &clientFinalErrSource{})
576	if err == nil {
577		t.Errorf("Expected CopyFrom return error, but it did not")
578	}
579	if copyCount != 0 {
580		t.Errorf("Expected CopyFrom to return 0 copied rows, but got %d", copyCount)
581	}
582
583	rows, err := conn.Query(context.Background(), "select * from foo")
584	if err != nil {
585		t.Errorf("Unexpected error for Query: %v", err)
586	}
587
588	var outputRows [][]interface{}
589	for rows.Next() {
590		row, err := rows.Values()
591		if err != nil {
592			t.Errorf("Unexpected error for rows.Values(): %v", err)
593		}
594		outputRows = append(outputRows, row)
595	}
596
597	if rows.Err() != nil {
598		t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
599	}
600
601	if len(outputRows) != 0 {
602		t.Errorf("Expected 0 rows, but got %v", outputRows)
603	}
604
605	ensureConnValid(t, conn)
606}
607