1package pq 2 3import ( 4 "bytes" 5 "database/sql" 6 "database/sql/driver" 7 "fmt" 8 "net" 9 "strings" 10 "testing" 11 "time" 12) 13 14func TestCopyInStmt(t *testing.T) { 15 stmt := CopyIn("table name") 16 if stmt != `COPY "table name" () FROM STDIN` { 17 t.Fatal(stmt) 18 } 19 20 stmt = CopyIn("table name", "column 1", "column 2") 21 if stmt != `COPY "table name" ("column 1", "column 2") FROM STDIN` { 22 t.Fatal(stmt) 23 } 24 25 stmt = CopyIn(`table " name """`, `co"lumn""`) 26 if stmt != `COPY "table "" name """"""" ("co""lumn""""") FROM STDIN` { 27 t.Fatal(stmt) 28 } 29} 30 31func TestCopyInSchemaStmt(t *testing.T) { 32 stmt := CopyInSchema("schema name", "table name") 33 if stmt != `COPY "schema name"."table name" () FROM STDIN` { 34 t.Fatal(stmt) 35 } 36 37 stmt = CopyInSchema("schema name", "table name", "column 1", "column 2") 38 if stmt != `COPY "schema name"."table name" ("column 1", "column 2") FROM STDIN` { 39 t.Fatal(stmt) 40 } 41 42 stmt = CopyInSchema(`schema " name """`, `table " name """`, `co"lumn""`) 43 if stmt != `COPY "schema "" name """"""".`+ 44 `"table "" name """"""" ("co""lumn""""") FROM STDIN` { 45 t.Fatal(stmt) 46 } 47} 48 49func TestCopyInMultipleValues(t *testing.T) { 50 db := openTestConn(t) 51 defer db.Close() 52 53 txn, err := db.Begin() 54 if err != nil { 55 t.Fatal(err) 56 } 57 defer txn.Rollback() 58 59 _, err = txn.Exec("CREATE TEMP TABLE temp (a int, b varchar)") 60 if err != nil { 61 t.Fatal(err) 62 } 63 64 stmt, err := txn.Prepare(CopyIn("temp", "a", "b")) 65 if err != nil { 66 t.Fatal(err) 67 } 68 69 longString := strings.Repeat("#", 500) 70 71 for i := 0; i < 500; i++ { 72 _, err = stmt.Exec(int64(i), longString) 73 if err != nil { 74 t.Fatal(err) 75 } 76 } 77 78 result, err := stmt.Exec() 79 if err != nil { 80 t.Fatal(err) 81 } 82 83 rowsAffected, err := result.RowsAffected() 84 if err != nil { 85 t.Fatal(err) 86 } 87 88 if rowsAffected != 500 { 89 t.Fatalf("expected 500 rows affected, not %d", rowsAffected) 90 } 91 92 err = stmt.Close() 93 if err != nil { 94 t.Fatal(err) 95 } 96 97 var num int 98 err = txn.QueryRow("SELECT COUNT(*) FROM temp").Scan(&num) 99 if err != nil { 100 t.Fatal(err) 101 } 102 103 if num != 500 { 104 t.Fatalf("expected 500 items, not %d", num) 105 } 106} 107 108func TestCopyInRaiseStmtTrigger(t *testing.T) { 109 db := openTestConn(t) 110 defer db.Close() 111 112 if getServerVersion(t, db) < 90000 { 113 var exists int 114 err := db.QueryRow("SELECT 1 FROM pg_language WHERE lanname = 'plpgsql'").Scan(&exists) 115 if err == sql.ErrNoRows { 116 t.Skip("language PL/PgSQL does not exist; skipping TestCopyInRaiseStmtTrigger") 117 } else if err != nil { 118 t.Fatal(err) 119 } 120 } 121 122 txn, err := db.Begin() 123 if err != nil { 124 t.Fatal(err) 125 } 126 defer txn.Rollback() 127 128 _, err = txn.Exec("CREATE TEMP TABLE temp (a int, b varchar)") 129 if err != nil { 130 t.Fatal(err) 131 } 132 133 _, err = txn.Exec(` 134 CREATE OR REPLACE FUNCTION pg_temp.temptest() 135 RETURNS trigger AS 136 $BODY$ begin 137 raise notice 'Hello world'; 138 return new; 139 end $BODY$ 140 LANGUAGE plpgsql`) 141 if err != nil { 142 t.Fatal(err) 143 } 144 145 _, err = txn.Exec(` 146 CREATE TRIGGER temptest_trigger 147 BEFORE INSERT 148 ON temp 149 FOR EACH ROW 150 EXECUTE PROCEDURE pg_temp.temptest()`) 151 if err != nil { 152 t.Fatal(err) 153 } 154 155 stmt, err := txn.Prepare(CopyIn("temp", "a", "b")) 156 if err != nil { 157 t.Fatal(err) 158 } 159 160 longString := strings.Repeat("#", 500) 161 162 _, err = stmt.Exec(int64(1), longString) 163 if err != nil { 164 t.Fatal(err) 165 } 166 167 _, err = stmt.Exec() 168 if err != nil { 169 t.Fatal(err) 170 } 171 172 err = stmt.Close() 173 if err != nil { 174 t.Fatal(err) 175 } 176 177 var num int 178 err = txn.QueryRow("SELECT COUNT(*) FROM temp").Scan(&num) 179 if err != nil { 180 t.Fatal(err) 181 } 182 183 if num != 1 { 184 t.Fatalf("expected 1 items, not %d", num) 185 } 186} 187 188func TestCopyInTypes(t *testing.T) { 189 db := openTestConn(t) 190 defer db.Close() 191 192 txn, err := db.Begin() 193 if err != nil { 194 t.Fatal(err) 195 } 196 defer txn.Rollback() 197 198 _, err = txn.Exec("CREATE TEMP TABLE temp (num INTEGER, text VARCHAR, blob BYTEA, nothing VARCHAR)") 199 if err != nil { 200 t.Fatal(err) 201 } 202 203 stmt, err := txn.Prepare(CopyIn("temp", "num", "text", "blob", "nothing")) 204 if err != nil { 205 t.Fatal(err) 206 } 207 208 _, err = stmt.Exec(int64(1234567890), "Héllö\n ☃!\r\t\\", []byte{0, 255, 9, 10, 13}, nil) 209 if err != nil { 210 t.Fatal(err) 211 } 212 213 _, err = stmt.Exec() 214 if err != nil { 215 t.Fatal(err) 216 } 217 218 err = stmt.Close() 219 if err != nil { 220 t.Fatal(err) 221 } 222 223 var num int 224 var text string 225 var blob []byte 226 var nothing sql.NullString 227 228 err = txn.QueryRow("SELECT * FROM temp").Scan(&num, &text, &blob, ¬hing) 229 if err != nil { 230 t.Fatal(err) 231 } 232 233 if num != 1234567890 { 234 t.Fatal("unexpected result", num) 235 } 236 if text != "Héllö\n ☃!\r\t\\" { 237 t.Fatal("unexpected result", text) 238 } 239 if !bytes.Equal(blob, []byte{0, 255, 9, 10, 13}) { 240 t.Fatal("unexpected result", blob) 241 } 242 if nothing.Valid { 243 t.Fatal("unexpected result", nothing.String) 244 } 245} 246 247func TestCopyInWrongType(t *testing.T) { 248 db := openTestConn(t) 249 defer db.Close() 250 251 txn, err := db.Begin() 252 if err != nil { 253 t.Fatal(err) 254 } 255 defer txn.Rollback() 256 257 _, err = txn.Exec("CREATE TEMP TABLE temp (num INTEGER)") 258 if err != nil { 259 t.Fatal(err) 260 } 261 262 stmt, err := txn.Prepare(CopyIn("temp", "num")) 263 if err != nil { 264 t.Fatal(err) 265 } 266 defer stmt.Close() 267 268 _, err = stmt.Exec("Héllö\n ☃!\r\t\\") 269 if err != nil { 270 t.Fatal(err) 271 } 272 273 _, err = stmt.Exec() 274 if err == nil { 275 t.Fatal("expected error") 276 } 277 if pge := err.(*Error); pge.Code.Name() != "invalid_text_representation" { 278 t.Fatalf("expected 'invalid input syntax for integer' error, got %s (%+v)", pge.Code.Name(), pge) 279 } 280} 281 282func TestCopyOutsideOfTxnError(t *testing.T) { 283 db := openTestConn(t) 284 defer db.Close() 285 286 _, err := db.Prepare(CopyIn("temp", "num")) 287 if err == nil { 288 t.Fatal("COPY outside of transaction did not return an error") 289 } 290 if err != errCopyNotSupportedOutsideTxn { 291 t.Fatalf("expected %s, got %s", err, err.Error()) 292 } 293} 294 295func TestCopyInBinaryError(t *testing.T) { 296 db := openTestConn(t) 297 defer db.Close() 298 299 txn, err := db.Begin() 300 if err != nil { 301 t.Fatal(err) 302 } 303 defer txn.Rollback() 304 305 _, err = txn.Exec("CREATE TEMP TABLE temp (num INTEGER)") 306 if err != nil { 307 t.Fatal(err) 308 } 309 _, err = txn.Prepare("COPY temp (num) FROM STDIN WITH binary") 310 if err != errBinaryCopyNotSupported { 311 t.Fatalf("expected %s, got %+v", errBinaryCopyNotSupported, err) 312 } 313 // check that the protocol is in a valid state 314 err = txn.Rollback() 315 if err != nil { 316 t.Fatal(err) 317 } 318} 319 320func TestCopyFromError(t *testing.T) { 321 db := openTestConn(t) 322 defer db.Close() 323 324 txn, err := db.Begin() 325 if err != nil { 326 t.Fatal(err) 327 } 328 defer txn.Rollback() 329 330 _, err = txn.Exec("CREATE TEMP TABLE temp (num INTEGER)") 331 if err != nil { 332 t.Fatal(err) 333 } 334 _, err = txn.Prepare("COPY temp (num) TO STDOUT") 335 if err != errCopyToNotSupported { 336 t.Fatalf("expected %s, got %+v", errCopyToNotSupported, err) 337 } 338 // check that the protocol is in a valid state 339 err = txn.Rollback() 340 if err != nil { 341 t.Fatal(err) 342 } 343} 344 345func TestCopySyntaxError(t *testing.T) { 346 db := openTestConn(t) 347 defer db.Close() 348 349 txn, err := db.Begin() 350 if err != nil { 351 t.Fatal(err) 352 } 353 defer txn.Rollback() 354 355 _, err = txn.Prepare("COPY ") 356 if err == nil { 357 t.Fatal("expected error") 358 } 359 if pge := err.(*Error); pge.Code.Name() != "syntax_error" { 360 t.Fatalf("expected syntax error, got %s (%+v)", pge.Code.Name(), pge) 361 } 362 // check that the protocol is in a valid state 363 err = txn.Rollback() 364 if err != nil { 365 t.Fatal(err) 366 } 367} 368 369// Tests for connection errors in copyin.resploop() 370func TestCopyRespLoopConnectionError(t *testing.T) { 371 db := openTestConn(t) 372 defer db.Close() 373 374 txn, err := db.Begin() 375 if err != nil { 376 t.Fatal(err) 377 } 378 defer txn.Rollback() 379 380 var pid int 381 err = txn.QueryRow("SELECT pg_backend_pid()").Scan(&pid) 382 if err != nil { 383 t.Fatal(err) 384 } 385 386 _, err = txn.Exec("CREATE TEMP TABLE temp (a int)") 387 if err != nil { 388 t.Fatal(err) 389 } 390 391 stmt, err := txn.Prepare(CopyIn("temp", "a")) 392 if err != nil { 393 t.Fatal(err) 394 } 395 defer stmt.Close() 396 397 _, err = db.Exec("SELECT pg_terminate_backend($1)", pid) 398 if err != nil { 399 t.Fatal(err) 400 } 401 402 if getServerVersion(t, db) < 90500 { 403 // We have to try and send something over, since postgres before 404 // version 9.5 won't process SIGTERMs while it's waiting for 405 // CopyData/CopyEnd messages; see tcop/postgres.c. 406 _, err = stmt.Exec(1) 407 if err != nil { 408 t.Fatal(err) 409 } 410 } 411 retry(t, time.Second*5, func() error { 412 _, err = stmt.Exec() 413 if err == nil { 414 return fmt.Errorf("expected error") 415 } 416 return nil 417 }) 418 switch pge := err.(type) { 419 case *Error: 420 if pge.Code.Name() != "admin_shutdown" { 421 t.Fatalf("expected admin_shutdown, got %s", pge.Code.Name()) 422 } 423 case *net.OpError: 424 // ignore 425 default: 426 if err == driver.ErrBadConn { 427 // likely an EPIPE 428 } else if err == errCopyInClosed { 429 // ignore 430 } else { 431 t.Fatalf("unexpected error, got %+#v", err) 432 } 433 } 434 435 _ = stmt.Close() 436} 437 438// retry executes f in a backoff loop until it doesn't return an error. If this 439// doesn't happen within duration, t.Fatal is called with the latest error. 440func retry(t *testing.T, duration time.Duration, f func() error) { 441 start := time.Now() 442 next := time.Millisecond * 100 443 for { 444 err := f() 445 if err == nil { 446 return 447 } 448 if time.Since(start) > duration { 449 t.Fatal(err) 450 } 451 time.Sleep(next) 452 next *= 2 453 } 454} 455 456func BenchmarkCopyIn(b *testing.B) { 457 db := openTestConn(b) 458 defer db.Close() 459 460 txn, err := db.Begin() 461 if err != nil { 462 b.Fatal(err) 463 } 464 defer txn.Rollback() 465 466 _, err = txn.Exec("CREATE TEMP TABLE temp (a int, b varchar)") 467 if err != nil { 468 b.Fatal(err) 469 } 470 471 stmt, err := txn.Prepare(CopyIn("temp", "a", "b")) 472 if err != nil { 473 b.Fatal(err) 474 } 475 476 for i := 0; i < b.N; i++ { 477 _, err = stmt.Exec(int64(i), "hello world!") 478 if err != nil { 479 b.Fatal(err) 480 } 481 } 482 483 _, err = stmt.Exec() 484 if err != nil { 485 b.Fatal(err) 486 } 487 488 err = stmt.Close() 489 if err != nil { 490 b.Fatal(err) 491 } 492 493 var num int 494 err = txn.QueryRow("SELECT COUNT(*) FROM temp").Scan(&num) 495 if err != nil { 496 b.Fatal(err) 497 } 498 499 if num != b.N { 500 b.Fatalf("expected %d items, not %d", b.N, num) 501 } 502} 503