1package pq 2 3import ( 4 "bytes" 5 "database/sql" 6 "database/sql/driver" 7 "net" 8 "strings" 9 "testing" 10) 11 12func TestCopyInStmt(t *testing.T) { 13 stmt := CopyIn("table name") 14 if stmt != `COPY "table name" () FROM STDIN` { 15 t.Fatal(stmt) 16 } 17 18 stmt = CopyIn("table name", "column 1", "column 2") 19 if stmt != `COPY "table name" ("column 1", "column 2") FROM STDIN` { 20 t.Fatal(stmt) 21 } 22 23 stmt = CopyIn(`table " name """`, `co"lumn""`) 24 if stmt != `COPY "table "" name """"""" ("co""lumn""""") FROM STDIN` { 25 t.Fatal(stmt) 26 } 27} 28 29func TestCopyInSchemaStmt(t *testing.T) { 30 stmt := CopyInSchema("schema name", "table name") 31 if stmt != `COPY "schema name"."table name" () FROM STDIN` { 32 t.Fatal(stmt) 33 } 34 35 stmt = CopyInSchema("schema name", "table name", "column 1", "column 2") 36 if stmt != `COPY "schema name"."table name" ("column 1", "column 2") FROM STDIN` { 37 t.Fatal(stmt) 38 } 39 40 stmt = CopyInSchema(`schema " name """`, `table " name """`, `co"lumn""`) 41 if stmt != `COPY "schema "" name """"""".`+ 42 `"table "" name """"""" ("co""lumn""""") FROM STDIN` { 43 t.Fatal(stmt) 44 } 45} 46 47func TestCopyInMultipleValues(t *testing.T) { 48 db := openTestConn(t) 49 defer db.Close() 50 51 txn, err := db.Begin() 52 if err != nil { 53 t.Fatal(err) 54 } 55 defer txn.Rollback() 56 57 _, err = txn.Exec("CREATE TEMP TABLE temp (a int, b varchar)") 58 if err != nil { 59 t.Fatal(err) 60 } 61 62 stmt, err := txn.Prepare(CopyIn("temp", "a", "b")) 63 if err != nil { 64 t.Fatal(err) 65 } 66 67 longString := strings.Repeat("#", 500) 68 69 for i := 0; i < 500; i++ { 70 _, err = stmt.Exec(int64(i), longString) 71 if err != nil { 72 t.Fatal(err) 73 } 74 } 75 76 result, err := stmt.Exec() 77 if err != nil { 78 t.Fatal(err) 79 } 80 81 rowsAffected, err := result.RowsAffected() 82 if err != nil { 83 t.Fatal(err) 84 } 85 86 if rowsAffected != 500 { 87 t.Fatalf("expected 500 rows affected, not %d", rowsAffected) 88 } 89 90 err = stmt.Close() 91 if err != nil { 92 t.Fatal(err) 93 } 94 95 var num int 96 err = txn.QueryRow("SELECT COUNT(*) FROM temp").Scan(&num) 97 if err != nil { 98 t.Fatal(err) 99 } 100 101 if num != 500 { 102 t.Fatalf("expected 500 items, not %d", num) 103 } 104} 105 106func TestCopyInRaiseStmtTrigger(t *testing.T) { 107 db := openTestConn(t) 108 defer db.Close() 109 110 if getServerVersion(t, db) < 90000 { 111 var exists int 112 err := db.QueryRow("SELECT 1 FROM pg_language WHERE lanname = 'plpgsql'").Scan(&exists) 113 if err == sql.ErrNoRows { 114 t.Skip("language PL/PgSQL does not exist; skipping TestCopyInRaiseStmtTrigger") 115 } else if err != nil { 116 t.Fatal(err) 117 } 118 } 119 120 txn, err := db.Begin() 121 if err != nil { 122 t.Fatal(err) 123 } 124 defer txn.Rollback() 125 126 _, err = txn.Exec("CREATE TEMP TABLE temp (a int, b varchar)") 127 if err != nil { 128 t.Fatal(err) 129 } 130 131 _, err = txn.Exec(` 132 CREATE OR REPLACE FUNCTION pg_temp.temptest() 133 RETURNS trigger AS 134 $BODY$ begin 135 raise notice 'Hello world'; 136 return new; 137 end $BODY$ 138 LANGUAGE plpgsql`) 139 if err != nil { 140 t.Fatal(err) 141 } 142 143 _, err = txn.Exec(` 144 CREATE TRIGGER temptest_trigger 145 BEFORE INSERT 146 ON temp 147 FOR EACH ROW 148 EXECUTE PROCEDURE pg_temp.temptest()`) 149 if err != nil { 150 t.Fatal(err) 151 } 152 153 stmt, err := txn.Prepare(CopyIn("temp", "a", "b")) 154 if err != nil { 155 t.Fatal(err) 156 } 157 158 longString := strings.Repeat("#", 500) 159 160 _, err = stmt.Exec(int64(1), longString) 161 if err != nil { 162 t.Fatal(err) 163 } 164 165 _, err = stmt.Exec() 166 if err != nil { 167 t.Fatal(err) 168 } 169 170 err = stmt.Close() 171 if err != nil { 172 t.Fatal(err) 173 } 174 175 var num int 176 err = txn.QueryRow("SELECT COUNT(*) FROM temp").Scan(&num) 177 if err != nil { 178 t.Fatal(err) 179 } 180 181 if num != 1 { 182 t.Fatalf("expected 1 items, not %d", num) 183 } 184} 185 186func TestCopyInTypes(t *testing.T) { 187 db := openTestConn(t) 188 defer db.Close() 189 190 txn, err := db.Begin() 191 if err != nil { 192 t.Fatal(err) 193 } 194 defer txn.Rollback() 195 196 _, err = txn.Exec("CREATE TEMP TABLE temp (num INTEGER, text VARCHAR, blob BYTEA, nothing VARCHAR)") 197 if err != nil { 198 t.Fatal(err) 199 } 200 201 stmt, err := txn.Prepare(CopyIn("temp", "num", "text", "blob", "nothing")) 202 if err != nil { 203 t.Fatal(err) 204 } 205 206 _, err = stmt.Exec(int64(1234567890), "Héllö\n ☃!\r\t\\", []byte{0, 255, 9, 10, 13}, nil) 207 if err != nil { 208 t.Fatal(err) 209 } 210 211 _, err = stmt.Exec() 212 if err != nil { 213 t.Fatal(err) 214 } 215 216 err = stmt.Close() 217 if err != nil { 218 t.Fatal(err) 219 } 220 221 var num int 222 var text string 223 var blob []byte 224 var nothing sql.NullString 225 226 err = txn.QueryRow("SELECT * FROM temp").Scan(&num, &text, &blob, ¬hing) 227 if err != nil { 228 t.Fatal(err) 229 } 230 231 if num != 1234567890 { 232 t.Fatal("unexpected result", num) 233 } 234 if text != "Héllö\n ☃!\r\t\\" { 235 t.Fatal("unexpected result", text) 236 } 237 if !bytes.Equal(blob, []byte{0, 255, 9, 10, 13}) { 238 t.Fatal("unexpected result", blob) 239 } 240 if nothing.Valid { 241 t.Fatal("unexpected result", nothing.String) 242 } 243} 244 245func TestCopyInWrongType(t *testing.T) { 246 db := openTestConn(t) 247 defer db.Close() 248 249 txn, err := db.Begin() 250 if err != nil { 251 t.Fatal(err) 252 } 253 defer txn.Rollback() 254 255 _, err = txn.Exec("CREATE TEMP TABLE temp (num INTEGER)") 256 if err != nil { 257 t.Fatal(err) 258 } 259 260 stmt, err := txn.Prepare(CopyIn("temp", "num")) 261 if err != nil { 262 t.Fatal(err) 263 } 264 defer stmt.Close() 265 266 _, err = stmt.Exec("Héllö\n ☃!\r\t\\") 267 if err != nil { 268 t.Fatal(err) 269 } 270 271 _, err = stmt.Exec() 272 if err == nil { 273 t.Fatal("expected error") 274 } 275 if pge := err.(*Error); pge.Code.Name() != "invalid_text_representation" { 276 t.Fatalf("expected 'invalid input syntax for integer' error, got %s (%+v)", pge.Code.Name(), pge) 277 } 278} 279 280func TestCopyOutsideOfTxnError(t *testing.T) { 281 db := openTestConn(t) 282 defer db.Close() 283 284 _, err := db.Prepare(CopyIn("temp", "num")) 285 if err == nil { 286 t.Fatal("COPY outside of transaction did not return an error") 287 } 288 if err != errCopyNotSupportedOutsideTxn { 289 t.Fatalf("expected %s, got %s", err, err.Error()) 290 } 291} 292 293func TestCopyInBinaryError(t *testing.T) { 294 db := openTestConn(t) 295 defer db.Close() 296 297 txn, err := db.Begin() 298 if err != nil { 299 t.Fatal(err) 300 } 301 defer txn.Rollback() 302 303 _, err = txn.Exec("CREATE TEMP TABLE temp (num INTEGER)") 304 if err != nil { 305 t.Fatal(err) 306 } 307 _, err = txn.Prepare("COPY temp (num) FROM STDIN WITH binary") 308 if err != errBinaryCopyNotSupported { 309 t.Fatalf("expected %s, got %+v", errBinaryCopyNotSupported, err) 310 } 311 // check that the protocol is in a valid state 312 err = txn.Rollback() 313 if err != nil { 314 t.Fatal(err) 315 } 316} 317 318func TestCopyFromError(t *testing.T) { 319 db := openTestConn(t) 320 defer db.Close() 321 322 txn, err := db.Begin() 323 if err != nil { 324 t.Fatal(err) 325 } 326 defer txn.Rollback() 327 328 _, err = txn.Exec("CREATE TEMP TABLE temp (num INTEGER)") 329 if err != nil { 330 t.Fatal(err) 331 } 332 _, err = txn.Prepare("COPY temp (num) TO STDOUT") 333 if err != errCopyToNotSupported { 334 t.Fatalf("expected %s, got %+v", errCopyToNotSupported, err) 335 } 336 // check that the protocol is in a valid state 337 err = txn.Rollback() 338 if err != nil { 339 t.Fatal(err) 340 } 341} 342 343func TestCopySyntaxError(t *testing.T) { 344 db := openTestConn(t) 345 defer db.Close() 346 347 txn, err := db.Begin() 348 if err != nil { 349 t.Fatal(err) 350 } 351 defer txn.Rollback() 352 353 _, err = txn.Prepare("COPY ") 354 if err == nil { 355 t.Fatal("expected error") 356 } 357 if pge := err.(*Error); pge.Code.Name() != "syntax_error" { 358 t.Fatalf("expected syntax error, got %s (%+v)", pge.Code.Name(), pge) 359 } 360 // check that the protocol is in a valid state 361 err = txn.Rollback() 362 if err != nil { 363 t.Fatal(err) 364 } 365} 366 367// Tests for connection errors in copyin.resploop() 368func TestCopyRespLoopConnectionError(t *testing.T) { 369 db := openTestConn(t) 370 defer db.Close() 371 372 txn, err := db.Begin() 373 if err != nil { 374 t.Fatal(err) 375 } 376 defer txn.Rollback() 377 378 var pid int 379 err = txn.QueryRow("SELECT pg_backend_pid()").Scan(&pid) 380 if err != nil { 381 t.Fatal(err) 382 } 383 384 _, err = txn.Exec("CREATE TEMP TABLE temp (a int)") 385 if err != nil { 386 t.Fatal(err) 387 } 388 389 stmt, err := txn.Prepare(CopyIn("temp", "a")) 390 if err != nil { 391 t.Fatal(err) 392 } 393 defer stmt.Close() 394 395 _, err = db.Exec("SELECT pg_terminate_backend($1)", pid) 396 if err != nil { 397 t.Fatal(err) 398 } 399 400 if getServerVersion(t, db) < 90500 { 401 // We have to try and send something over, since postgres before 402 // version 9.5 won't process SIGTERMs while it's waiting for 403 // CopyData/CopyEnd messages; see tcop/postgres.c. 404 _, err = stmt.Exec(1) 405 if err != nil { 406 t.Fatal(err) 407 } 408 } 409 _, err = stmt.Exec() 410 if err == nil { 411 t.Fatalf("expected error") 412 } 413 switch pge := err.(type) { 414 case *Error: 415 if pge.Code.Name() != "admin_shutdown" { 416 t.Fatalf("expected admin_shutdown, got %s", pge.Code.Name()) 417 } 418 case *net.OpError: 419 // ignore 420 default: 421 if err == driver.ErrBadConn { 422 // likely an EPIPE 423 } else { 424 t.Fatalf("unexpected error, got %+#v", err) 425 } 426 } 427 428 _ = stmt.Close() 429} 430 431func BenchmarkCopyIn(b *testing.B) { 432 db := openTestConn(b) 433 defer db.Close() 434 435 txn, err := db.Begin() 436 if err != nil { 437 b.Fatal(err) 438 } 439 defer txn.Rollback() 440 441 _, err = txn.Exec("CREATE TEMP TABLE temp (a int, b varchar)") 442 if err != nil { 443 b.Fatal(err) 444 } 445 446 stmt, err := txn.Prepare(CopyIn("temp", "a", "b")) 447 if err != nil { 448 b.Fatal(err) 449 } 450 451 for i := 0; i < b.N; i++ { 452 _, err = stmt.Exec(int64(i), "hello world!") 453 if err != nil { 454 b.Fatal(err) 455 } 456 } 457 458 _, err = stmt.Exec() 459 if err != nil { 460 b.Fatal(err) 461 } 462 463 err = stmt.Close() 464 if err != nil { 465 b.Fatal(err) 466 } 467 468 var num int 469 err = txn.QueryRow("SELECT COUNT(*) FROM temp").Scan(&num) 470 if err != nil { 471 b.Fatal(err) 472 } 473 474 if num != b.N { 475 b.Fatalf("expected %d items, not %d", b.N, num) 476 } 477} 478