1package pq
2
3import (
4	"bytes"
5	"database/sql"
6	"database/sql/driver"
7	"fmt"
8	"net"
9	"strings"
10	"testing"
11	"time"
12)
13
14func TestCopyInStmt(t *testing.T) {
15	stmt := CopyIn("table name")
16	if stmt != `COPY "table name" () FROM STDIN` {
17		t.Fatal(stmt)
18	}
19
20	stmt = CopyIn("table name", "column 1", "column 2")
21	if stmt != `COPY "table name" ("column 1", "column 2") FROM STDIN` {
22		t.Fatal(stmt)
23	}
24
25	stmt = CopyIn(`table " name """`, `co"lumn""`)
26	if stmt != `COPY "table "" name """"""" ("co""lumn""""") FROM STDIN` {
27		t.Fatal(stmt)
28	}
29}
30
31func TestCopyInSchemaStmt(t *testing.T) {
32	stmt := CopyInSchema("schema name", "table name")
33	if stmt != `COPY "schema name"."table name" () FROM STDIN` {
34		t.Fatal(stmt)
35	}
36
37	stmt = CopyInSchema("schema name", "table name", "column 1", "column 2")
38	if stmt != `COPY "schema name"."table name" ("column 1", "column 2") FROM STDIN` {
39		t.Fatal(stmt)
40	}
41
42	stmt = CopyInSchema(`schema " name """`, `table " name """`, `co"lumn""`)
43	if stmt != `COPY "schema "" name """"""".`+
44		`"table "" name """"""" ("co""lumn""""") FROM STDIN` {
45		t.Fatal(stmt)
46	}
47}
48
49func TestCopyInMultipleValues(t *testing.T) {
50	db := openTestConn(t)
51	defer db.Close()
52
53	txn, err := db.Begin()
54	if err != nil {
55		t.Fatal(err)
56	}
57	defer txn.Rollback()
58
59	_, err = txn.Exec("CREATE TEMP TABLE temp (a int, b varchar)")
60	if err != nil {
61		t.Fatal(err)
62	}
63
64	stmt, err := txn.Prepare(CopyIn("temp", "a", "b"))
65	if err != nil {
66		t.Fatal(err)
67	}
68
69	longString := strings.Repeat("#", 500)
70
71	for i := 0; i < 500; i++ {
72		_, err = stmt.Exec(int64(i), longString)
73		if err != nil {
74			t.Fatal(err)
75		}
76	}
77
78	result, err := stmt.Exec()
79	if err != nil {
80		t.Fatal(err)
81	}
82
83	rowsAffected, err := result.RowsAffected()
84	if err != nil {
85		t.Fatal(err)
86	}
87
88	if rowsAffected != 500 {
89		t.Fatalf("expected 500 rows affected, not %d", rowsAffected)
90	}
91
92	err = stmt.Close()
93	if err != nil {
94		t.Fatal(err)
95	}
96
97	var num int
98	err = txn.QueryRow("SELECT COUNT(*) FROM temp").Scan(&num)
99	if err != nil {
100		t.Fatal(err)
101	}
102
103	if num != 500 {
104		t.Fatalf("expected 500 items, not %d", num)
105	}
106}
107
108func TestCopyInRaiseStmtTrigger(t *testing.T) {
109	db := openTestConn(t)
110	defer db.Close()
111
112	if getServerVersion(t, db) < 90000 {
113		var exists int
114		err := db.QueryRow("SELECT 1 FROM pg_language WHERE lanname = 'plpgsql'").Scan(&exists)
115		if err == sql.ErrNoRows {
116			t.Skip("language PL/PgSQL does not exist; skipping TestCopyInRaiseStmtTrigger")
117		} else if err != nil {
118			t.Fatal(err)
119		}
120	}
121
122	txn, err := db.Begin()
123	if err != nil {
124		t.Fatal(err)
125	}
126	defer txn.Rollback()
127
128	_, err = txn.Exec("CREATE TEMP TABLE temp (a int, b varchar)")
129	if err != nil {
130		t.Fatal(err)
131	}
132
133	_, err = txn.Exec(`
134			CREATE OR REPLACE FUNCTION pg_temp.temptest()
135			RETURNS trigger AS
136			$BODY$ begin
137				raise notice 'Hello world';
138				return new;
139			end $BODY$
140			LANGUAGE plpgsql`)
141	if err != nil {
142		t.Fatal(err)
143	}
144
145	_, err = txn.Exec(`
146			CREATE TRIGGER temptest_trigger
147			BEFORE INSERT
148			ON temp
149			FOR EACH ROW
150			EXECUTE PROCEDURE pg_temp.temptest()`)
151	if err != nil {
152		t.Fatal(err)
153	}
154
155	stmt, err := txn.Prepare(CopyIn("temp", "a", "b"))
156	if err != nil {
157		t.Fatal(err)
158	}
159
160	longString := strings.Repeat("#", 500)
161
162	_, err = stmt.Exec(int64(1), longString)
163	if err != nil {
164		t.Fatal(err)
165	}
166
167	_, err = stmt.Exec()
168	if err != nil {
169		t.Fatal(err)
170	}
171
172	err = stmt.Close()
173	if err != nil {
174		t.Fatal(err)
175	}
176
177	var num int
178	err = txn.QueryRow("SELECT COUNT(*) FROM temp").Scan(&num)
179	if err != nil {
180		t.Fatal(err)
181	}
182
183	if num != 1 {
184		t.Fatalf("expected 1 items, not %d", num)
185	}
186}
187
188func TestCopyInTypes(t *testing.T) {
189	db := openTestConn(t)
190	defer db.Close()
191
192	txn, err := db.Begin()
193	if err != nil {
194		t.Fatal(err)
195	}
196	defer txn.Rollback()
197
198	_, err = txn.Exec("CREATE TEMP TABLE temp (num INTEGER, text VARCHAR, blob BYTEA, nothing VARCHAR)")
199	if err != nil {
200		t.Fatal(err)
201	}
202
203	stmt, err := txn.Prepare(CopyIn("temp", "num", "text", "blob", "nothing"))
204	if err != nil {
205		t.Fatal(err)
206	}
207
208	_, err = stmt.Exec(int64(1234567890), "Héllö\n ☃!\r\t\\", []byte{0, 255, 9, 10, 13}, nil)
209	if err != nil {
210		t.Fatal(err)
211	}
212
213	_, err = stmt.Exec()
214	if err != nil {
215		t.Fatal(err)
216	}
217
218	err = stmt.Close()
219	if err != nil {
220		t.Fatal(err)
221	}
222
223	var num int
224	var text string
225	var blob []byte
226	var nothing sql.NullString
227
228	err = txn.QueryRow("SELECT * FROM temp").Scan(&num, &text, &blob, &nothing)
229	if err != nil {
230		t.Fatal(err)
231	}
232
233	if num != 1234567890 {
234		t.Fatal("unexpected result", num)
235	}
236	if text != "Héllö\n ☃!\r\t\\" {
237		t.Fatal("unexpected result", text)
238	}
239	if !bytes.Equal(blob, []byte{0, 255, 9, 10, 13}) {
240		t.Fatal("unexpected result", blob)
241	}
242	if nothing.Valid {
243		t.Fatal("unexpected result", nothing.String)
244	}
245}
246
247func TestCopyInWrongType(t *testing.T) {
248	db := openTestConn(t)
249	defer db.Close()
250
251	txn, err := db.Begin()
252	if err != nil {
253		t.Fatal(err)
254	}
255	defer txn.Rollback()
256
257	_, err = txn.Exec("CREATE TEMP TABLE temp (num INTEGER)")
258	if err != nil {
259		t.Fatal(err)
260	}
261
262	stmt, err := txn.Prepare(CopyIn("temp", "num"))
263	if err != nil {
264		t.Fatal(err)
265	}
266	defer stmt.Close()
267
268	_, err = stmt.Exec("Héllö\n ☃!\r\t\\")
269	if err != nil {
270		t.Fatal(err)
271	}
272
273	_, err = stmt.Exec()
274	if err == nil {
275		t.Fatal("expected error")
276	}
277	if pge := err.(*Error); pge.Code.Name() != "invalid_text_representation" {
278		t.Fatalf("expected 'invalid input syntax for integer' error, got %s (%+v)", pge.Code.Name(), pge)
279	}
280}
281
282func TestCopyOutsideOfTxnError(t *testing.T) {
283	db := openTestConn(t)
284	defer db.Close()
285
286	_, err := db.Prepare(CopyIn("temp", "num"))
287	if err == nil {
288		t.Fatal("COPY outside of transaction did not return an error")
289	}
290	if err != errCopyNotSupportedOutsideTxn {
291		t.Fatalf("expected %s, got %s", err, err.Error())
292	}
293}
294
295func TestCopyInBinaryError(t *testing.T) {
296	db := openTestConn(t)
297	defer db.Close()
298
299	txn, err := db.Begin()
300	if err != nil {
301		t.Fatal(err)
302	}
303	defer txn.Rollback()
304
305	_, err = txn.Exec("CREATE TEMP TABLE temp (num INTEGER)")
306	if err != nil {
307		t.Fatal(err)
308	}
309	_, err = txn.Prepare("COPY temp (num) FROM STDIN WITH binary")
310	if err != errBinaryCopyNotSupported {
311		t.Fatalf("expected %s, got %+v", errBinaryCopyNotSupported, err)
312	}
313	// check that the protocol is in a valid state
314	err = txn.Rollback()
315	if err != nil {
316		t.Fatal(err)
317	}
318}
319
320func TestCopyFromError(t *testing.T) {
321	db := openTestConn(t)
322	defer db.Close()
323
324	txn, err := db.Begin()
325	if err != nil {
326		t.Fatal(err)
327	}
328	defer txn.Rollback()
329
330	_, err = txn.Exec("CREATE TEMP TABLE temp (num INTEGER)")
331	if err != nil {
332		t.Fatal(err)
333	}
334	_, err = txn.Prepare("COPY temp (num) TO STDOUT")
335	if err != errCopyToNotSupported {
336		t.Fatalf("expected %s, got %+v", errCopyToNotSupported, err)
337	}
338	// check that the protocol is in a valid state
339	err = txn.Rollback()
340	if err != nil {
341		t.Fatal(err)
342	}
343}
344
345func TestCopySyntaxError(t *testing.T) {
346	db := openTestConn(t)
347	defer db.Close()
348
349	txn, err := db.Begin()
350	if err != nil {
351		t.Fatal(err)
352	}
353	defer txn.Rollback()
354
355	_, err = txn.Prepare("COPY ")
356	if err == nil {
357		t.Fatal("expected error")
358	}
359	if pge := err.(*Error); pge.Code.Name() != "syntax_error" {
360		t.Fatalf("expected syntax error, got %s (%+v)", pge.Code.Name(), pge)
361	}
362	// check that the protocol is in a valid state
363	err = txn.Rollback()
364	if err != nil {
365		t.Fatal(err)
366	}
367}
368
369// Tests for connection errors in copyin.resploop()
370func TestCopyRespLoopConnectionError(t *testing.T) {
371	db := openTestConn(t)
372	defer db.Close()
373
374	txn, err := db.Begin()
375	if err != nil {
376		t.Fatal(err)
377	}
378	defer txn.Rollback()
379
380	var pid int
381	err = txn.QueryRow("SELECT pg_backend_pid()").Scan(&pid)
382	if err != nil {
383		t.Fatal(err)
384	}
385
386	_, err = txn.Exec("CREATE TEMP TABLE temp (a int)")
387	if err != nil {
388		t.Fatal(err)
389	}
390
391	stmt, err := txn.Prepare(CopyIn("temp", "a"))
392	if err != nil {
393		t.Fatal(err)
394	}
395	defer stmt.Close()
396
397	_, err = db.Exec("SELECT pg_terminate_backend($1)", pid)
398	if err != nil {
399		t.Fatal(err)
400	}
401
402	if getServerVersion(t, db) < 90500 {
403		// We have to try and send something over, since postgres before
404		// version 9.5 won't process SIGTERMs while it's waiting for
405		// CopyData/CopyEnd messages; see tcop/postgres.c.
406		_, err = stmt.Exec(1)
407		if err != nil {
408			t.Fatal(err)
409		}
410	}
411	retry(t, time.Second*5, func() error {
412		_, err = stmt.Exec()
413		if err == nil {
414			return fmt.Errorf("expected error")
415		}
416		return nil
417	})
418	switch pge := err.(type) {
419	case *Error:
420		if pge.Code.Name() != "admin_shutdown" {
421			t.Fatalf("expected admin_shutdown, got %s", pge.Code.Name())
422		}
423	case *net.OpError:
424		// ignore
425	default:
426		if err == driver.ErrBadConn {
427			// likely an EPIPE
428		} else if err == errCopyInClosed {
429			// ignore
430		} else {
431			t.Fatalf("unexpected error, got %+#v", err)
432		}
433	}
434
435	_ = stmt.Close()
436}
437
// retry executes f in a backoff loop until it doesn't return an error. If this
// doesn't happen within duration, t.Fatal is called with the latest error.
//
// The delay between attempts starts at 100ms and doubles after each failure.
// Each sleep is capped at the time remaining before the deadline, so the loop
// does not overshoot duration by more than a single call to f.
func retry(t *testing.T, duration time.Duration, f func() error) {
	// Report a timeout at the caller's line rather than inside retry.
	t.Helper()
	start := time.Now()
	next := time.Millisecond * 100
	for {
		err := f()
		if err == nil {
			return
		}
		elapsed := time.Since(start)
		if elapsed > duration {
			t.Fatal(err)
		}
		wait := next
		if remaining := duration - elapsed; wait > remaining {
			wait = remaining
		}
		time.Sleep(wait)
		// Stop growing once the backoff already exceeds the whole budget;
		// this also rules out Duration overflow on very long budgets.
		if next < duration {
			next *= 2
		}
	}
}
455
456func BenchmarkCopyIn(b *testing.B) {
457	db := openTestConn(b)
458	defer db.Close()
459
460	txn, err := db.Begin()
461	if err != nil {
462		b.Fatal(err)
463	}
464	defer txn.Rollback()
465
466	_, err = txn.Exec("CREATE TEMP TABLE temp (a int, b varchar)")
467	if err != nil {
468		b.Fatal(err)
469	}
470
471	stmt, err := txn.Prepare(CopyIn("temp", "a", "b"))
472	if err != nil {
473		b.Fatal(err)
474	}
475
476	for i := 0; i < b.N; i++ {
477		_, err = stmt.Exec(int64(i), "hello world!")
478		if err != nil {
479			b.Fatal(err)
480		}
481	}
482
483	_, err = stmt.Exec()
484	if err != nil {
485		b.Fatal(err)
486	}
487
488	err = stmt.Close()
489	if err != nil {
490		b.Fatal(err)
491	}
492
493	var num int
494	err = txn.QueryRow("SELECT COUNT(*) FROM temp").Scan(&num)
495	if err != nil {
496		b.Fatal(err)
497	}
498
499	if num != b.N {
500		b.Fatalf("expected %d items, not %d", b.N, num)
501	}
502}
503