1package pq
2
3import (
4	"bytes"
5	"database/sql"
6	"database/sql/driver"
7	"net"
8	"strings"
9	"testing"
10)
11
12func TestCopyInStmt(t *testing.T) {
13	stmt := CopyIn("table name")
14	if stmt != `COPY "table name" () FROM STDIN` {
15		t.Fatal(stmt)
16	}
17
18	stmt = CopyIn("table name", "column 1", "column 2")
19	if stmt != `COPY "table name" ("column 1", "column 2") FROM STDIN` {
20		t.Fatal(stmt)
21	}
22
23	stmt = CopyIn(`table " name """`, `co"lumn""`)
24	if stmt != `COPY "table "" name """"""" ("co""lumn""""") FROM STDIN` {
25		t.Fatal(stmt)
26	}
27}
28
29func TestCopyInSchemaStmt(t *testing.T) {
30	stmt := CopyInSchema("schema name", "table name")
31	if stmt != `COPY "schema name"."table name" () FROM STDIN` {
32		t.Fatal(stmt)
33	}
34
35	stmt = CopyInSchema("schema name", "table name", "column 1", "column 2")
36	if stmt != `COPY "schema name"."table name" ("column 1", "column 2") FROM STDIN` {
37		t.Fatal(stmt)
38	}
39
40	stmt = CopyInSchema(`schema " name """`, `table " name """`, `co"lumn""`)
41	if stmt != `COPY "schema "" name """"""".`+
42		`"table "" name """"""" ("co""lumn""""") FROM STDIN` {
43		t.Fatal(stmt)
44	}
45}
46
47func TestCopyInMultipleValues(t *testing.T) {
48	db := openTestConn(t)
49	defer db.Close()
50
51	txn, err := db.Begin()
52	if err != nil {
53		t.Fatal(err)
54	}
55	defer txn.Rollback()
56
57	_, err = txn.Exec("CREATE TEMP TABLE temp (a int, b varchar)")
58	if err != nil {
59		t.Fatal(err)
60	}
61
62	stmt, err := txn.Prepare(CopyIn("temp", "a", "b"))
63	if err != nil {
64		t.Fatal(err)
65	}
66
67	longString := strings.Repeat("#", 500)
68
69	for i := 0; i < 500; i++ {
70		_, err = stmt.Exec(int64(i), longString)
71		if err != nil {
72			t.Fatal(err)
73		}
74	}
75
76	result, err := stmt.Exec()
77	if err != nil {
78		t.Fatal(err)
79	}
80
81	rowsAffected, err := result.RowsAffected()
82	if err != nil {
83		t.Fatal(err)
84	}
85
86	if rowsAffected != 500 {
87		t.Fatalf("expected 500 rows affected, not %d", rowsAffected)
88	}
89
90	err = stmt.Close()
91	if err != nil {
92		t.Fatal(err)
93	}
94
95	var num int
96	err = txn.QueryRow("SELECT COUNT(*) FROM temp").Scan(&num)
97	if err != nil {
98		t.Fatal(err)
99	}
100
101	if num != 500 {
102		t.Fatalf("expected 500 items, not %d", num)
103	}
104}
105
106func TestCopyInRaiseStmtTrigger(t *testing.T) {
107	db := openTestConn(t)
108	defer db.Close()
109
110	if getServerVersion(t, db) < 90000 {
111		var exists int
112		err := db.QueryRow("SELECT 1 FROM pg_language WHERE lanname = 'plpgsql'").Scan(&exists)
113		if err == sql.ErrNoRows {
114			t.Skip("language PL/PgSQL does not exist; skipping TestCopyInRaiseStmtTrigger")
115		} else if err != nil {
116			t.Fatal(err)
117		}
118	}
119
120	txn, err := db.Begin()
121	if err != nil {
122		t.Fatal(err)
123	}
124	defer txn.Rollback()
125
126	_, err = txn.Exec("CREATE TEMP TABLE temp (a int, b varchar)")
127	if err != nil {
128		t.Fatal(err)
129	}
130
131	_, err = txn.Exec(`
132			CREATE OR REPLACE FUNCTION pg_temp.temptest()
133			RETURNS trigger AS
134			$BODY$ begin
135				raise notice 'Hello world';
136				return new;
137			end $BODY$
138			LANGUAGE plpgsql`)
139	if err != nil {
140		t.Fatal(err)
141	}
142
143	_, err = txn.Exec(`
144			CREATE TRIGGER temptest_trigger
145			BEFORE INSERT
146			ON temp
147			FOR EACH ROW
148			EXECUTE PROCEDURE pg_temp.temptest()`)
149	if err != nil {
150		t.Fatal(err)
151	}
152
153	stmt, err := txn.Prepare(CopyIn("temp", "a", "b"))
154	if err != nil {
155		t.Fatal(err)
156	}
157
158	longString := strings.Repeat("#", 500)
159
160	_, err = stmt.Exec(int64(1), longString)
161	if err != nil {
162		t.Fatal(err)
163	}
164
165	_, err = stmt.Exec()
166	if err != nil {
167		t.Fatal(err)
168	}
169
170	err = stmt.Close()
171	if err != nil {
172		t.Fatal(err)
173	}
174
175	var num int
176	err = txn.QueryRow("SELECT COUNT(*) FROM temp").Scan(&num)
177	if err != nil {
178		t.Fatal(err)
179	}
180
181	if num != 1 {
182		t.Fatalf("expected 1 items, not %d", num)
183	}
184}
185
186func TestCopyInTypes(t *testing.T) {
187	db := openTestConn(t)
188	defer db.Close()
189
190	txn, err := db.Begin()
191	if err != nil {
192		t.Fatal(err)
193	}
194	defer txn.Rollback()
195
196	_, err = txn.Exec("CREATE TEMP TABLE temp (num INTEGER, text VARCHAR, blob BYTEA, nothing VARCHAR)")
197	if err != nil {
198		t.Fatal(err)
199	}
200
201	stmt, err := txn.Prepare(CopyIn("temp", "num", "text", "blob", "nothing"))
202	if err != nil {
203		t.Fatal(err)
204	}
205
206	_, err = stmt.Exec(int64(1234567890), "Héllö\n ☃!\r\t\\", []byte{0, 255, 9, 10, 13}, nil)
207	if err != nil {
208		t.Fatal(err)
209	}
210
211	_, err = stmt.Exec()
212	if err != nil {
213		t.Fatal(err)
214	}
215
216	err = stmt.Close()
217	if err != nil {
218		t.Fatal(err)
219	}
220
221	var num int
222	var text string
223	var blob []byte
224	var nothing sql.NullString
225
226	err = txn.QueryRow("SELECT * FROM temp").Scan(&num, &text, &blob, &nothing)
227	if err != nil {
228		t.Fatal(err)
229	}
230
231	if num != 1234567890 {
232		t.Fatal("unexpected result", num)
233	}
234	if text != "Héllö\n ☃!\r\t\\" {
235		t.Fatal("unexpected result", text)
236	}
237	if !bytes.Equal(blob, []byte{0, 255, 9, 10, 13}) {
238		t.Fatal("unexpected result", blob)
239	}
240	if nothing.Valid {
241		t.Fatal("unexpected result", nothing.String)
242	}
243}
244
245func TestCopyInWrongType(t *testing.T) {
246	db := openTestConn(t)
247	defer db.Close()
248
249	txn, err := db.Begin()
250	if err != nil {
251		t.Fatal(err)
252	}
253	defer txn.Rollback()
254
255	_, err = txn.Exec("CREATE TEMP TABLE temp (num INTEGER)")
256	if err != nil {
257		t.Fatal(err)
258	}
259
260	stmt, err := txn.Prepare(CopyIn("temp", "num"))
261	if err != nil {
262		t.Fatal(err)
263	}
264	defer stmt.Close()
265
266	_, err = stmt.Exec("Héllö\n ☃!\r\t\\")
267	if err != nil {
268		t.Fatal(err)
269	}
270
271	_, err = stmt.Exec()
272	if err == nil {
273		t.Fatal("expected error")
274	}
275	if pge := err.(*Error); pge.Code.Name() != "invalid_text_representation" {
276		t.Fatalf("expected 'invalid input syntax for integer' error, got %s (%+v)", pge.Code.Name(), pge)
277	}
278}
279
280func TestCopyOutsideOfTxnError(t *testing.T) {
281	db := openTestConn(t)
282	defer db.Close()
283
284	_, err := db.Prepare(CopyIn("temp", "num"))
285	if err == nil {
286		t.Fatal("COPY outside of transaction did not return an error")
287	}
288	if err != errCopyNotSupportedOutsideTxn {
289		t.Fatalf("expected %s, got %s", err, err.Error())
290	}
291}
292
293func TestCopyInBinaryError(t *testing.T) {
294	db := openTestConn(t)
295	defer db.Close()
296
297	txn, err := db.Begin()
298	if err != nil {
299		t.Fatal(err)
300	}
301	defer txn.Rollback()
302
303	_, err = txn.Exec("CREATE TEMP TABLE temp (num INTEGER)")
304	if err != nil {
305		t.Fatal(err)
306	}
307	_, err = txn.Prepare("COPY temp (num) FROM STDIN WITH binary")
308	if err != errBinaryCopyNotSupported {
309		t.Fatalf("expected %s, got %+v", errBinaryCopyNotSupported, err)
310	}
311	// check that the protocol is in a valid state
312	err = txn.Rollback()
313	if err != nil {
314		t.Fatal(err)
315	}
316}
317
318func TestCopyFromError(t *testing.T) {
319	db := openTestConn(t)
320	defer db.Close()
321
322	txn, err := db.Begin()
323	if err != nil {
324		t.Fatal(err)
325	}
326	defer txn.Rollback()
327
328	_, err = txn.Exec("CREATE TEMP TABLE temp (num INTEGER)")
329	if err != nil {
330		t.Fatal(err)
331	}
332	_, err = txn.Prepare("COPY temp (num) TO STDOUT")
333	if err != errCopyToNotSupported {
334		t.Fatalf("expected %s, got %+v", errCopyToNotSupported, err)
335	}
336	// check that the protocol is in a valid state
337	err = txn.Rollback()
338	if err != nil {
339		t.Fatal(err)
340	}
341}
342
343func TestCopySyntaxError(t *testing.T) {
344	db := openTestConn(t)
345	defer db.Close()
346
347	txn, err := db.Begin()
348	if err != nil {
349		t.Fatal(err)
350	}
351	defer txn.Rollback()
352
353	_, err = txn.Prepare("COPY ")
354	if err == nil {
355		t.Fatal("expected error")
356	}
357	if pge := err.(*Error); pge.Code.Name() != "syntax_error" {
358		t.Fatalf("expected syntax error, got %s (%+v)", pge.Code.Name(), pge)
359	}
360	// check that the protocol is in a valid state
361	err = txn.Rollback()
362	if err != nil {
363		t.Fatal(err)
364	}
365}
366
367// Tests for connection errors in copyin.resploop()
368func TestCopyRespLoopConnectionError(t *testing.T) {
369	db := openTestConn(t)
370	defer db.Close()
371
372	txn, err := db.Begin()
373	if err != nil {
374		t.Fatal(err)
375	}
376	defer txn.Rollback()
377
378	var pid int
379	err = txn.QueryRow("SELECT pg_backend_pid()").Scan(&pid)
380	if err != nil {
381		t.Fatal(err)
382	}
383
384	_, err = txn.Exec("CREATE TEMP TABLE temp (a int)")
385	if err != nil {
386		t.Fatal(err)
387	}
388
389	stmt, err := txn.Prepare(CopyIn("temp", "a"))
390	if err != nil {
391		t.Fatal(err)
392	}
393	defer stmt.Close()
394
395	_, err = db.Exec("SELECT pg_terminate_backend($1)", pid)
396	if err != nil {
397		t.Fatal(err)
398	}
399
400	if getServerVersion(t, db) < 90500 {
401		// We have to try and send something over, since postgres before
402		// version 9.5 won't process SIGTERMs while it's waiting for
403		// CopyData/CopyEnd messages; see tcop/postgres.c.
404		_, err = stmt.Exec(1)
405		if err != nil {
406			t.Fatal(err)
407		}
408	}
409	_, err = stmt.Exec()
410	if err == nil {
411		t.Fatalf("expected error")
412	}
413	switch pge := err.(type) {
414	case *Error:
415		if pge.Code.Name() != "admin_shutdown" {
416			t.Fatalf("expected admin_shutdown, got %s", pge.Code.Name())
417		}
418	case *net.OpError:
419		// ignore
420	default:
421		if err == driver.ErrBadConn {
422			// likely an EPIPE
423		} else {
424			t.Fatalf("unexpected error, got %+#v", err)
425		}
426	}
427
428	_ = stmt.Close()
429}
430
431func BenchmarkCopyIn(b *testing.B) {
432	db := openTestConn(b)
433	defer db.Close()
434
435	txn, err := db.Begin()
436	if err != nil {
437		b.Fatal(err)
438	}
439	defer txn.Rollback()
440
441	_, err = txn.Exec("CREATE TEMP TABLE temp (a int, b varchar)")
442	if err != nil {
443		b.Fatal(err)
444	}
445
446	stmt, err := txn.Prepare(CopyIn("temp", "a", "b"))
447	if err != nil {
448		b.Fatal(err)
449	}
450
451	for i := 0; i < b.N; i++ {
452		_, err = stmt.Exec(int64(i), "hello world!")
453		if err != nil {
454			b.Fatal(err)
455		}
456	}
457
458	_, err = stmt.Exec()
459	if err != nil {
460		b.Fatal(err)
461	}
462
463	err = stmt.Close()
464	if err != nil {
465		b.Fatal(err)
466	}
467
468	var num int
469	err = txn.QueryRow("SELECT COUNT(*) FROM temp").Scan(&num)
470	if err != nil {
471		b.Fatal(err)
472	}
473
474	if num != b.N {
475		b.Fatalf("expected %d items, not %d", b.N, num)
476	}
477}
478