package pq

import (
	"bufio"
	"context"
	"crypto/md5"
	"crypto/sha256"
	"database/sql"
	"database/sql/driver"
	"encoding/binary"
	"errors"
	"fmt"
	"io"
	"net"
	"os"
	"os/user"
	"path"
	"path/filepath"
	"strconv"
	"strings"
	"time"
	"unicode"

	"github.com/lib/pq/oid"
	"github.com/lib/pq/scram"
)

// Common error types
var (
	// ErrNotSupported reports an unsupported command.
	ErrNotSupported = errors.New("pq: Unsupported command")

	// ErrInFailedTransaction is returned by Commit when the transaction had
	// already failed; the transaction is rolled back before it is returned.
	ErrInFailedTransaction = errors.New("pq: Could not complete operation in a failed transaction")

	// ErrSSLNotSupported is returned when SSL was requested but the server
	// declined the SSL request.
	ErrSSLNotSupported = errors.New("pq: SSL is not enabled on the server")

	// ErrSSLKeyHasWorldPermissions reports a client private key file whose
	// permissions grant group or world access.
	ErrSSLKeyHasWorldPermissions = errors.New("pq: Private key file has group or world access. Permissions should be u=rw (0600) or less")

	// ErrCouldNotDetectUsername reports that no default user name could be
	// determined, so one must be supplied explicitly.
	ErrCouldNotDetectUsername = errors.New("pq: Could not detect default username. Please provide one explicitly")

	// errUnexpectedReady is returned by simpleExec when ReadyForQuery arrives
	// without any command result or error before it.
	errUnexpectedReady = errors.New("unexpected ReadyForQuery")

	// errNoRowsAffected and errNoLastInsertID are returned by the noRows
	// result that represents an empty statement.
	errNoRowsAffected = errors.New("no RowsAffected available after the empty statement")
	errNoLastInsertID = errors.New("no LastInsertId available after the empty statement")
)

// Driver is the Postgres database driver.
type Driver struct{}

// Open opens a new connection to the database. name is a connection string.
// Most users should only use it through database/sql package from the standard
// library.
47func (d *Driver) Open(name string) (driver.Conn, error) { 48 return Open(name) 49} 50 51func init() { 52 sql.Register("postgres", &Driver{}) 53} 54 55type parameterStatus struct { 56 // server version in the same format as server_version_num, or 0 if 57 // unavailable 58 serverVersion int 59 60 // the current location based on the TimeZone value of the session, if 61 // available 62 currentLocation *time.Location 63} 64 65type transactionStatus byte 66 67const ( 68 txnStatusIdle transactionStatus = 'I' 69 txnStatusIdleInTransaction transactionStatus = 'T' 70 txnStatusInFailedTransaction transactionStatus = 'E' 71) 72 73func (s transactionStatus) String() string { 74 switch s { 75 case txnStatusIdle: 76 return "idle" 77 case txnStatusIdleInTransaction: 78 return "idle in transaction" 79 case txnStatusInFailedTransaction: 80 return "in a failed transaction" 81 default: 82 errorf("unknown transactionStatus %d", s) 83 } 84 85 panic("not reached") 86} 87 88// Dialer is the dialer interface. It can be used to obtain more control over 89// how pq creates network connections. 
90type Dialer interface { 91 Dial(network, address string) (net.Conn, error) 92 DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) 93} 94 95type DialerContext interface { 96 DialContext(ctx context.Context, network, address string) (net.Conn, error) 97} 98 99type defaultDialer struct { 100 d net.Dialer 101} 102 103func (d defaultDialer) Dial(network, address string) (net.Conn, error) { 104 return d.d.Dial(network, address) 105} 106func (d defaultDialer) DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) { 107 ctx, cancel := context.WithTimeout(context.Background(), timeout) 108 defer cancel() 109 return d.DialContext(ctx, network, address) 110} 111func (d defaultDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { 112 return d.d.DialContext(ctx, network, address) 113} 114 115type conn struct { 116 c net.Conn 117 buf *bufio.Reader 118 namei int 119 scratch [512]byte 120 txnStatus transactionStatus 121 txnFinish func() 122 123 // Save connection arguments to use during CancelRequest. 124 dialer Dialer 125 opts values 126 127 // Cancellation key data for use with CancelRequest messages. 128 processID int 129 secretKey int 130 131 parameterStatus parameterStatus 132 133 saveMessageType byte 134 saveMessageBuffer []byte 135 136 // If true, this connection is bad and all public-facing functions should 137 // return ErrBadConn. 138 bad bool 139 140 // If set, this connection should never use the binary format when 141 // receiving query results from prepared statements. Only provided for 142 // debugging. 143 disablePreparedBinaryResult bool 144 145 // Whether to always send []byte parameters over as binary. Enables single 146 // round-trip mode for non-prepared Query calls. 147 binaryParameters bool 148 149 // If true this connection is in the middle of a COPY 150 inCopy bool 151} 152 153// Handle driver-side settings in parsed connection string. 
154func (cn *conn) handleDriverSettings(o values) (err error) { 155 boolSetting := func(key string, val *bool) error { 156 if value, ok := o[key]; ok { 157 if value == "yes" { 158 *val = true 159 } else if value == "no" { 160 *val = false 161 } else { 162 return fmt.Errorf("unrecognized value %q for %s", value, key) 163 } 164 } 165 return nil 166 } 167 168 err = boolSetting("disable_prepared_binary_result", &cn.disablePreparedBinaryResult) 169 if err != nil { 170 return err 171 } 172 return boolSetting("binary_parameters", &cn.binaryParameters) 173} 174 175func (cn *conn) handlePgpass(o values) { 176 // if a password was supplied, do not process .pgpass 177 if _, ok := o["password"]; ok { 178 return 179 } 180 filename := os.Getenv("PGPASSFILE") 181 if filename == "" { 182 // XXX this code doesn't work on Windows where the default filename is 183 // XXX %APPDATA%\postgresql\pgpass.conf 184 // Prefer $HOME over user.Current due to glibc bug: golang.org/issue/13470 185 userHome := os.Getenv("HOME") 186 if userHome == "" { 187 user, err := user.Current() 188 if err != nil { 189 return 190 } 191 userHome = user.HomeDir 192 } 193 filename = filepath.Join(userHome, ".pgpass") 194 } 195 fileinfo, err := os.Stat(filename) 196 if err != nil { 197 return 198 } 199 mode := fileinfo.Mode() 200 if mode&(0x77) != 0 { 201 // XXX should warn about incorrect .pgpass permissions as psql does 202 return 203 } 204 file, err := os.Open(filename) 205 if err != nil { 206 return 207 } 208 defer file.Close() 209 scanner := bufio.NewScanner(io.Reader(file)) 210 hostname := o["host"] 211 ntw, _ := network(o) 212 port := o["port"] 213 db := o["dbname"] 214 username := o["user"] 215 // From: https://github.com/tg/pgpass/blob/master/reader.go 216 getFields := func(s string) []string { 217 fs := make([]string, 0, 5) 218 f := make([]rune, 0, len(s)) 219 220 var esc bool 221 for _, c := range s { 222 switch { 223 case esc: 224 f = append(f, c) 225 esc = false 226 case c == '\\': 227 esc = true 228 
case c == ':': 229 fs = append(fs, string(f)) 230 f = f[:0] 231 default: 232 f = append(f, c) 233 } 234 } 235 return append(fs, string(f)) 236 } 237 for scanner.Scan() { 238 line := scanner.Text() 239 if len(line) == 0 || line[0] == '#' { 240 continue 241 } 242 split := getFields(line) 243 if len(split) != 5 { 244 continue 245 } 246 if (split[0] == "*" || split[0] == hostname || (split[0] == "localhost" && (hostname == "" || ntw == "unix"))) && (split[1] == "*" || split[1] == port) && (split[2] == "*" || split[2] == db) && (split[3] == "*" || split[3] == username) { 247 o["password"] = split[4] 248 return 249 } 250 } 251} 252 253func (cn *conn) writeBuf(b byte) *writeBuf { 254 cn.scratch[0] = b 255 return &writeBuf{ 256 buf: cn.scratch[:5], 257 pos: 1, 258 } 259} 260 261// Open opens a new connection to the database. dsn is a connection string. 262// Most users should only use it through database/sql package from the standard 263// library. 264func Open(dsn string) (_ driver.Conn, err error) { 265 return DialOpen(defaultDialer{}, dsn) 266} 267 268// DialOpen opens a new connection to the database using a dialer. 269func DialOpen(d Dialer, dsn string) (_ driver.Conn, err error) { 270 c, err := NewConnector(dsn) 271 if err != nil { 272 return nil, err 273 } 274 c.dialer = d 275 return c.open(context.Background()) 276} 277 278func (c *Connector) open(ctx context.Context) (cn *conn, err error) { 279 // Handle any panics during connection initialization. Note that we 280 // specifically do *not* want to use errRecover(), as that would turn any 281 // connection errors into ErrBadConns, hiding the real error message from 282 // the user. 
283 defer errRecoverNoErrBadConn(&err) 284 285 o := c.opts 286 287 cn = &conn{ 288 opts: o, 289 dialer: c.dialer, 290 } 291 err = cn.handleDriverSettings(o) 292 if err != nil { 293 return nil, err 294 } 295 cn.handlePgpass(o) 296 297 cn.c, err = dial(ctx, c.dialer, o) 298 if err != nil { 299 return nil, err 300 } 301 302 err = cn.ssl(o) 303 if err != nil { 304 return nil, err 305 } 306 307 // cn.startup panics on error. Make sure we don't leak cn.c. 308 panicking := true 309 defer func() { 310 if panicking { 311 cn.c.Close() 312 } 313 }() 314 315 cn.buf = bufio.NewReader(cn.c) 316 cn.startup(o) 317 318 // reset the deadline, in case one was set (see dial) 319 if timeout, ok := o["connect_timeout"]; ok && timeout != "0" { 320 err = cn.c.SetDeadline(time.Time{}) 321 } 322 panicking = false 323 return cn, err 324} 325 326func dial(ctx context.Context, d Dialer, o values) (net.Conn, error) { 327 network, address := network(o) 328 // SSL is not necessary or supported over UNIX domain sockets 329 if network == "unix" { 330 o["sslmode"] = "disable" 331 } 332 333 // Zero or not specified means wait indefinitely. 334 if timeout, ok := o["connect_timeout"]; ok && timeout != "0" { 335 seconds, err := strconv.ParseInt(timeout, 10, 0) 336 if err != nil { 337 return nil, fmt.Errorf("invalid value for parameter connect_timeout: %s", err) 338 } 339 duration := time.Duration(seconds) * time.Second 340 341 // connect_timeout should apply to the entire connection establishment 342 // procedure, so we both use a timeout for the TCP connection 343 // establishment and set a deadline for doing the initial handshake. 344 // The deadline is then reset after startup() is done. 
345 deadline := time.Now().Add(duration) 346 var conn net.Conn 347 if dctx, ok := d.(DialerContext); ok { 348 ctx, cancel := context.WithTimeout(ctx, duration) 349 defer cancel() 350 conn, err = dctx.DialContext(ctx, network, address) 351 } else { 352 conn, err = d.DialTimeout(network, address, duration) 353 } 354 if err != nil { 355 return nil, err 356 } 357 err = conn.SetDeadline(deadline) 358 return conn, err 359 } 360 if dctx, ok := d.(DialerContext); ok { 361 return dctx.DialContext(ctx, network, address) 362 } 363 return d.Dial(network, address) 364} 365 366func network(o values) (string, string) { 367 host := o["host"] 368 369 if strings.HasPrefix(host, "/") { 370 sockPath := path.Join(host, ".s.PGSQL."+o["port"]) 371 return "unix", sockPath 372 } 373 374 return "tcp", net.JoinHostPort(host, o["port"]) 375} 376 377type values map[string]string 378 379// scanner implements a tokenizer for libpq-style option strings. 380type scanner struct { 381 s []rune 382 i int 383} 384 385// newScanner returns a new scanner initialized with the option string s. 386func newScanner(s string) *scanner { 387 return &scanner{[]rune(s), 0} 388} 389 390// Next returns the next rune. 391// It returns 0, false if the end of the text has been reached. 392func (s *scanner) Next() (rune, bool) { 393 if s.i >= len(s.s) { 394 return 0, false 395 } 396 r := s.s[s.i] 397 s.i++ 398 return r, true 399} 400 401// SkipSpaces returns the next non-whitespace rune. 402// It returns 0, false if the end of the text has been reached. 403func (s *scanner) SkipSpaces() (rune, bool) { 404 r, ok := s.Next() 405 for unicode.IsSpace(r) && ok { 406 r, ok = s.Next() 407 } 408 return r, ok 409} 410 411// parseOpts parses the options from name and adds them to the values. 
412// 413// The parsing code is based on conninfo_parse from libpq's fe-connect.c 414func parseOpts(name string, o values) error { 415 s := newScanner(name) 416 417 for { 418 var ( 419 keyRunes, valRunes []rune 420 r rune 421 ok bool 422 ) 423 424 if r, ok = s.SkipSpaces(); !ok { 425 break 426 } 427 428 // Scan the key 429 for !unicode.IsSpace(r) && r != '=' { 430 keyRunes = append(keyRunes, r) 431 if r, ok = s.Next(); !ok { 432 break 433 } 434 } 435 436 // Skip any whitespace if we're not at the = yet 437 if r != '=' { 438 r, ok = s.SkipSpaces() 439 } 440 441 // The current character should be = 442 if r != '=' || !ok { 443 return fmt.Errorf(`missing "=" after %q in connection info string"`, string(keyRunes)) 444 } 445 446 // Skip any whitespace after the = 447 if r, ok = s.SkipSpaces(); !ok { 448 // If we reach the end here, the last value is just an empty string as per libpq. 449 o[string(keyRunes)] = "" 450 break 451 } 452 453 if r != '\'' { 454 for !unicode.IsSpace(r) { 455 if r == '\\' { 456 if r, ok = s.Next(); !ok { 457 return fmt.Errorf(`missing character after backslash`) 458 } 459 } 460 valRunes = append(valRunes, r) 461 462 if r, ok = s.Next(); !ok { 463 break 464 } 465 } 466 } else { 467 quote: 468 for { 469 if r, ok = s.Next(); !ok { 470 return fmt.Errorf(`unterminated quoted string literal in connection string`) 471 } 472 switch r { 473 case '\'': 474 break quote 475 case '\\': 476 r, _ = s.Next() 477 fallthrough 478 default: 479 valRunes = append(valRunes, r) 480 } 481 } 482 } 483 484 o[string(keyRunes)] = string(valRunes) 485 } 486 487 return nil 488} 489 490func (cn *conn) isInTransaction() bool { 491 return cn.txnStatus == txnStatusIdleInTransaction || 492 cn.txnStatus == txnStatusInFailedTransaction 493} 494 495func (cn *conn) checkIsInTransaction(intxn bool) { 496 if cn.isInTransaction() != intxn { 497 cn.bad = true 498 errorf("unexpected transaction status %v", cn.txnStatus) 499 } 500} 501 502func (cn *conn) Begin() (_ driver.Tx, err error) { 
503 return cn.begin("") 504} 505 506func (cn *conn) begin(mode string) (_ driver.Tx, err error) { 507 if cn.bad { 508 return nil, driver.ErrBadConn 509 } 510 defer cn.errRecover(&err) 511 512 cn.checkIsInTransaction(false) 513 _, commandTag, err := cn.simpleExec("BEGIN" + mode) 514 if err != nil { 515 return nil, err 516 } 517 if commandTag != "BEGIN" { 518 cn.bad = true 519 return nil, fmt.Errorf("unexpected command tag %s", commandTag) 520 } 521 if cn.txnStatus != txnStatusIdleInTransaction { 522 cn.bad = true 523 return nil, fmt.Errorf("unexpected transaction status %v", cn.txnStatus) 524 } 525 return cn, nil 526} 527 528func (cn *conn) closeTxn() { 529 if finish := cn.txnFinish; finish != nil { 530 finish() 531 } 532} 533 534func (cn *conn) Commit() (err error) { 535 defer cn.closeTxn() 536 if cn.bad { 537 return driver.ErrBadConn 538 } 539 defer cn.errRecover(&err) 540 541 cn.checkIsInTransaction(true) 542 // We don't want the client to think that everything is okay if it tries 543 // to commit a failed transaction. However, no matter what we return, 544 // database/sql will release this connection back into the free connection 545 // pool so we have to abort the current transaction here. Note that you 546 // would get the same behaviour if you issued a COMMIT in a failed 547 // transaction, so it's also the least surprising thing to do here. 
548 if cn.txnStatus == txnStatusInFailedTransaction { 549 if err := cn.Rollback(); err != nil { 550 return err 551 } 552 return ErrInFailedTransaction 553 } 554 555 _, commandTag, err := cn.simpleExec("COMMIT") 556 if err != nil { 557 if cn.isInTransaction() { 558 cn.bad = true 559 } 560 return err 561 } 562 if commandTag != "COMMIT" { 563 cn.bad = true 564 return fmt.Errorf("unexpected command tag %s", commandTag) 565 } 566 cn.checkIsInTransaction(false) 567 return nil 568} 569 570func (cn *conn) Rollback() (err error) { 571 defer cn.closeTxn() 572 if cn.bad { 573 return driver.ErrBadConn 574 } 575 defer cn.errRecover(&err) 576 577 cn.checkIsInTransaction(true) 578 _, commandTag, err := cn.simpleExec("ROLLBACK") 579 if err != nil { 580 if cn.isInTransaction() { 581 cn.bad = true 582 } 583 return err 584 } 585 if commandTag != "ROLLBACK" { 586 return fmt.Errorf("unexpected command tag %s", commandTag) 587 } 588 cn.checkIsInTransaction(false) 589 return nil 590} 591 592func (cn *conn) gname() string { 593 cn.namei++ 594 return strconv.FormatInt(int64(cn.namei), 10) 595} 596 597func (cn *conn) simpleExec(q string) (res driver.Result, commandTag string, err error) { 598 b := cn.writeBuf('Q') 599 b.string(q) 600 cn.send(b) 601 602 for { 603 t, r := cn.recv1() 604 switch t { 605 case 'C': 606 res, commandTag = cn.parseComplete(r.string()) 607 case 'Z': 608 cn.processReadyForQuery(r) 609 if res == nil && err == nil { 610 err = errUnexpectedReady 611 } 612 // done 613 return 614 case 'E': 615 err = parseError(r) 616 case 'I': 617 res = emptyRows 618 case 'T', 'D': 619 // ignore any results 620 default: 621 cn.bad = true 622 errorf("unknown response for simple query: %q", t) 623 } 624 } 625} 626 627func (cn *conn) simpleQuery(q string) (res *rows, err error) { 628 defer cn.errRecover(&err) 629 630 b := cn.writeBuf('Q') 631 b.string(q) 632 cn.send(b) 633 634 for { 635 t, r := cn.recv1() 636 switch t { 637 case 'C', 'I': 638 // We allow queries which don't return any results 
through Query as 639 // well as Exec. We still have to give database/sql a rows object 640 // the user can close, though, to avoid connections from being 641 // leaked. A "rows" with done=true works fine for that purpose. 642 if err != nil { 643 cn.bad = true 644 errorf("unexpected message %q in simple query execution", t) 645 } 646 if res == nil { 647 res = &rows{ 648 cn: cn, 649 } 650 } 651 // Set the result and tag to the last command complete if there wasn't a 652 // query already run. Although queries usually return from here and cede 653 // control to Next, a query with zero results does not. 654 if t == 'C' && res.colNames == nil { 655 res.result, res.tag = cn.parseComplete(r.string()) 656 } 657 res.done = true 658 case 'Z': 659 cn.processReadyForQuery(r) 660 // done 661 return 662 case 'E': 663 res = nil 664 err = parseError(r) 665 case 'D': 666 if res == nil { 667 cn.bad = true 668 errorf("unexpected DataRow in simple query execution") 669 } 670 // the query didn't fail; kick off to Next 671 cn.saveMessage(t, r) 672 return 673 case 'T': 674 // res might be non-nil here if we received a previous 675 // CommandComplete, but that's fine; just overwrite it 676 res = &rows{cn: cn} 677 res.rowsHeader = parsePortalRowDescribe(r) 678 679 // To work around a bug in QueryRow in Go 1.2 and earlier, wait 680 // until the first DataRow has been received. 681 default: 682 cn.bad = true 683 errorf("unknown response for simple query: %q", t) 684 } 685 } 686} 687 688type noRows struct{} 689 690var emptyRows noRows 691 692var _ driver.Result = noRows{} 693 694func (noRows) LastInsertId() (int64, error) { 695 return 0, errNoLastInsertID 696} 697 698func (noRows) RowsAffected() (int64, error) { 699 return 0, errNoRowsAffected 700} 701 702// Decides which column formats to use for a prepared statement. The input is 703// an array of type oids, one element per result column. 
704func decideColumnFormats(colTyps []fieldDesc, forceText bool) (colFmts []format, colFmtData []byte) { 705 if len(colTyps) == 0 { 706 return nil, colFmtDataAllText 707 } 708 709 colFmts = make([]format, len(colTyps)) 710 if forceText { 711 return colFmts, colFmtDataAllText 712 } 713 714 allBinary := true 715 allText := true 716 for i, t := range colTyps { 717 switch t.OID { 718 // This is the list of types to use binary mode for when receiving them 719 // through a prepared statement. If a type appears in this list, it 720 // must also be implemented in binaryDecode in encode.go. 721 case oid.T_bytea: 722 fallthrough 723 case oid.T_int8: 724 fallthrough 725 case oid.T_int4: 726 fallthrough 727 case oid.T_int2: 728 fallthrough 729 case oid.T_uuid: 730 colFmts[i] = formatBinary 731 allText = false 732 733 default: 734 allBinary = false 735 } 736 } 737 738 if allBinary { 739 return colFmts, colFmtDataAllBinary 740 } else if allText { 741 return colFmts, colFmtDataAllText 742 } else { 743 colFmtData = make([]byte, 2+len(colFmts)*2) 744 binary.BigEndian.PutUint16(colFmtData, uint16(len(colFmts))) 745 for i, v := range colFmts { 746 binary.BigEndian.PutUint16(colFmtData[2+i*2:], uint16(v)) 747 } 748 return colFmts, colFmtData 749 } 750} 751 752func (cn *conn) prepareTo(q, stmtName string) *stmt { 753 st := &stmt{cn: cn, name: stmtName} 754 755 b := cn.writeBuf('P') 756 b.string(st.name) 757 b.string(q) 758 b.int16(0) 759 760 b.next('D') 761 b.byte('S') 762 b.string(st.name) 763 764 b.next('S') 765 cn.send(b) 766 767 cn.readParseResponse() 768 st.paramTyps, st.colNames, st.colTyps = cn.readStatementDescribeResponse() 769 st.colFmts, st.colFmtData = decideColumnFormats(st.colTyps, cn.disablePreparedBinaryResult) 770 cn.readReadyForQuery() 771 return st 772} 773 774func (cn *conn) Prepare(q string) (_ driver.Stmt, err error) { 775 if cn.bad { 776 return nil, driver.ErrBadConn 777 } 778 defer cn.errRecover(&err) 779 780 if len(q) >= 4 && strings.EqualFold(q[:4], "COPY") { 
781 s, err := cn.prepareCopyIn(q) 782 if err == nil { 783 cn.inCopy = true 784 } 785 return s, err 786 } 787 return cn.prepareTo(q, cn.gname()), nil 788} 789 790func (cn *conn) Close() (err error) { 791 // Skip cn.bad return here because we always want to close a connection. 792 defer cn.errRecover(&err) 793 794 // Ensure that cn.c.Close is always run. Since error handling is done with 795 // panics and cn.errRecover, the Close must be in a defer. 796 defer func() { 797 cerr := cn.c.Close() 798 if err == nil { 799 err = cerr 800 } 801 }() 802 803 // Don't go through send(); ListenerConn relies on us not scribbling on the 804 // scratch buffer of this connection. 805 return cn.sendSimpleMessage('X') 806} 807 808// Implement the "Queryer" interface 809func (cn *conn) Query(query string, args []driver.Value) (driver.Rows, error) { 810 return cn.query(query, args) 811} 812 813func (cn *conn) query(query string, args []driver.Value) (_ *rows, err error) { 814 if cn.bad { 815 return nil, driver.ErrBadConn 816 } 817 if cn.inCopy { 818 return nil, errCopyInProgress 819 } 820 defer cn.errRecover(&err) 821 822 // Check to see if we can use the "simpleQuery" interface, which is 823 // *much* faster than going through prepare/exec 824 if len(args) == 0 { 825 return cn.simpleQuery(query) 826 } 827 828 if cn.binaryParameters { 829 cn.sendBinaryModeQuery(query, args) 830 831 cn.readParseResponse() 832 cn.readBindResponse() 833 rows := &rows{cn: cn} 834 rows.rowsHeader = cn.readPortalDescribeResponse() 835 cn.postExecuteWorkaround() 836 return rows, nil 837 } 838 st := cn.prepareTo(query, "") 839 st.exec(args) 840 return &rows{ 841 cn: cn, 842 rowsHeader: st.rowsHeader, 843 }, nil 844} 845 846// Implement the optional "Execer" interface for one-shot queries 847func (cn *conn) Exec(query string, args []driver.Value) (res driver.Result, err error) { 848 if cn.bad { 849 return nil, driver.ErrBadConn 850 } 851 defer cn.errRecover(&err) 852 853 // Check to see if we can use the 
"simpleExec" interface, which is 854 // *much* faster than going through prepare/exec 855 if len(args) == 0 { 856 // ignore commandTag, our caller doesn't care 857 r, _, err := cn.simpleExec(query) 858 return r, err 859 } 860 861 if cn.binaryParameters { 862 cn.sendBinaryModeQuery(query, args) 863 864 cn.readParseResponse() 865 cn.readBindResponse() 866 cn.readPortalDescribeResponse() 867 cn.postExecuteWorkaround() 868 res, _, err = cn.readExecuteResponse("Execute") 869 return res, err 870 } 871 // Use the unnamed statement to defer planning until bind 872 // time, or else value-based selectivity estimates cannot be 873 // used. 874 st := cn.prepareTo(query, "") 875 r, err := st.Exec(args) 876 if err != nil { 877 panic(err) 878 } 879 return r, err 880} 881 882func (cn *conn) send(m *writeBuf) { 883 _, err := cn.c.Write(m.wrap()) 884 if err != nil { 885 panic(err) 886 } 887} 888 889func (cn *conn) sendStartupPacket(m *writeBuf) error { 890 _, err := cn.c.Write((m.wrap())[1:]) 891 return err 892} 893 894// Send a message of type typ to the server on the other end of cn. The 895// message should have no payload. This method does not use the scratch 896// buffer. 897func (cn *conn) sendSimpleMessage(typ byte) (err error) { 898 _, err = cn.c.Write([]byte{typ, '\x00', '\x00', '\x00', '\x04'}) 899 return err 900} 901 902// saveMessage memorizes a message and its buffer in the conn struct. 903// recvMessage will then return these values on the next call to it. This 904// method is useful in cases where you have to see what the next message is 905// going to be (e.g. to see whether it's an error or not) but you can't handle 906// the message yourself. 
907func (cn *conn) saveMessage(typ byte, buf *readBuf) { 908 if cn.saveMessageType != 0 { 909 cn.bad = true 910 errorf("unexpected saveMessageType %d", cn.saveMessageType) 911 } 912 cn.saveMessageType = typ 913 cn.saveMessageBuffer = *buf 914} 915 916// recvMessage receives any message from the backend, or returns an error if 917// a problem occurred while reading the message. 918func (cn *conn) recvMessage(r *readBuf) (byte, error) { 919 // workaround for a QueryRow bug, see exec 920 if cn.saveMessageType != 0 { 921 t := cn.saveMessageType 922 *r = cn.saveMessageBuffer 923 cn.saveMessageType = 0 924 cn.saveMessageBuffer = nil 925 return t, nil 926 } 927 928 x := cn.scratch[:5] 929 _, err := io.ReadFull(cn.buf, x) 930 if err != nil { 931 return 0, err 932 } 933 934 // read the type and length of the message that follows 935 t := x[0] 936 n := int(binary.BigEndian.Uint32(x[1:])) - 4 937 var y []byte 938 if n <= len(cn.scratch) { 939 y = cn.scratch[:n] 940 } else { 941 y = make([]byte, n) 942 } 943 _, err = io.ReadFull(cn.buf, y) 944 if err != nil { 945 return 0, err 946 } 947 *r = y 948 return t, nil 949} 950 951// recv receives a message from the backend, but if an error happened while 952// reading the message or the received message was an ErrorResponse, it panics. 953// NoticeResponses are ignored. This function should generally be used only 954// during the startup sequence. 955func (cn *conn) recv() (t byte, r *readBuf) { 956 for { 957 var err error 958 r = &readBuf{} 959 t, err = cn.recvMessage(r) 960 if err != nil { 961 panic(err) 962 } 963 switch t { 964 case 'E': 965 panic(parseError(r)) 966 case 'N': 967 // ignore 968 default: 969 return 970 } 971 } 972} 973 974// recv1Buf is exactly equivalent to recv1, except it uses a buffer supplied by 975// the caller to avoid an allocation. 
976func (cn *conn) recv1Buf(r *readBuf) byte { 977 for { 978 t, err := cn.recvMessage(r) 979 if err != nil { 980 panic(err) 981 } 982 983 switch t { 984 case 'A', 'N': 985 // ignore 986 case 'S': 987 cn.processParameterStatus(r) 988 default: 989 return t 990 } 991 } 992} 993 994// recv1 receives a message from the backend, panicking if an error occurs 995// while attempting to read it. All asynchronous messages are ignored, with 996// the exception of ErrorResponse. 997func (cn *conn) recv1() (t byte, r *readBuf) { 998 r = &readBuf{} 999 t = cn.recv1Buf(r) 1000 return t, r 1001} 1002 1003func (cn *conn) ssl(o values) error { 1004 upgrade, err := ssl(o) 1005 if err != nil { 1006 return err 1007 } 1008 1009 if upgrade == nil { 1010 // Nothing to do 1011 return nil 1012 } 1013 1014 w := cn.writeBuf(0) 1015 w.int32(80877103) 1016 if err = cn.sendStartupPacket(w); err != nil { 1017 return err 1018 } 1019 1020 b := cn.scratch[:1] 1021 _, err = io.ReadFull(cn.c, b) 1022 if err != nil { 1023 return err 1024 } 1025 1026 if b[0] != 'S' { 1027 return ErrSSLNotSupported 1028 } 1029 1030 cn.c, err = upgrade(cn.c) 1031 return err 1032} 1033 1034// isDriverSetting returns true iff a setting is purely for configuring the 1035// driver's options and should not be sent to the server in the connection 1036// startup packet. 
1037func isDriverSetting(key string) bool { 1038 switch key { 1039 case "host", "port": 1040 return true 1041 case "password": 1042 return true 1043 case "sslmode", "sslcert", "sslkey", "sslrootcert": 1044 return true 1045 case "fallback_application_name": 1046 return true 1047 case "connect_timeout": 1048 return true 1049 case "disable_prepared_binary_result": 1050 return true 1051 case "binary_parameters": 1052 return true 1053 1054 default: 1055 return false 1056 } 1057} 1058 1059func (cn *conn) startup(o values) { 1060 w := cn.writeBuf(0) 1061 w.int32(196608) 1062 // Send the backend the name of the database we want to connect to, and the 1063 // user we want to connect as. Additionally, we send over any run-time 1064 // parameters potentially included in the connection string. If the server 1065 // doesn't recognize any of them, it will reply with an error. 1066 for k, v := range o { 1067 if isDriverSetting(k) { 1068 // skip options which can't be run-time parameters 1069 continue 1070 } 1071 // The protocol requires us to supply the database name as "database" 1072 // instead of "dbname". 
1073 if k == "dbname" { 1074 k = "database" 1075 } 1076 w.string(k) 1077 w.string(v) 1078 } 1079 w.string("") 1080 if err := cn.sendStartupPacket(w); err != nil { 1081 panic(err) 1082 } 1083 1084 for { 1085 t, r := cn.recv() 1086 switch t { 1087 case 'K': 1088 cn.processBackendKeyData(r) 1089 case 'S': 1090 cn.processParameterStatus(r) 1091 case 'R': 1092 cn.auth(r, o) 1093 case 'Z': 1094 cn.processReadyForQuery(r) 1095 return 1096 default: 1097 errorf("unknown response for startup: %q", t) 1098 } 1099 } 1100} 1101 1102func (cn *conn) auth(r *readBuf, o values) { 1103 switch code := r.int32(); code { 1104 case 0: 1105 // OK 1106 case 3: 1107 w := cn.writeBuf('p') 1108 w.string(o["password"]) 1109 cn.send(w) 1110 1111 t, r := cn.recv() 1112 if t != 'R' { 1113 errorf("unexpected password response: %q", t) 1114 } 1115 1116 if r.int32() != 0 { 1117 errorf("unexpected authentication response: %q", t) 1118 } 1119 case 5: 1120 s := string(r.next(4)) 1121 w := cn.writeBuf('p') 1122 w.string("md5" + md5s(md5s(o["password"]+o["user"])+s)) 1123 cn.send(w) 1124 1125 t, r := cn.recv() 1126 if t != 'R' { 1127 errorf("unexpected password response: %q", t) 1128 } 1129 1130 if r.int32() != 0 { 1131 errorf("unexpected authentication response: %q", t) 1132 } 1133 case 10: 1134 sc := scram.NewClient(sha256.New, o["user"], o["password"]) 1135 sc.Step(nil) 1136 if sc.Err() != nil { 1137 errorf("SCRAM-SHA-256 error: %s", sc.Err().Error()) 1138 } 1139 scOut := sc.Out() 1140 1141 w := cn.writeBuf('p') 1142 w.string("SCRAM-SHA-256") 1143 w.int32(len(scOut)) 1144 w.bytes(scOut) 1145 cn.send(w) 1146 1147 t, r := cn.recv() 1148 if t != 'R' { 1149 errorf("unexpected password response: %q", t) 1150 } 1151 1152 if r.int32() != 11 { 1153 errorf("unexpected authentication response: %q", t) 1154 } 1155 1156 nextStep := r.next(len(*r)) 1157 sc.Step(nextStep) 1158 if sc.Err() != nil { 1159 errorf("SCRAM-SHA-256 error: %s", sc.Err().Error()) 1160 } 1161 1162 scOut = sc.Out() 1163 w = cn.writeBuf('p') 
1164 w.bytes(scOut) 1165 cn.send(w) 1166 1167 t, r = cn.recv() 1168 if t != 'R' { 1169 errorf("unexpected password response: %q", t) 1170 } 1171 1172 if r.int32() != 12 { 1173 errorf("unexpected authentication response: %q", t) 1174 } 1175 1176 nextStep = r.next(len(*r)) 1177 sc.Step(nextStep) 1178 if sc.Err() != nil { 1179 errorf("SCRAM-SHA-256 error: %s", sc.Err().Error()) 1180 } 1181 1182 default: 1183 errorf("unknown authentication response: %d", code) 1184 } 1185} 1186 1187type format int 1188 1189const formatText format = 0 1190const formatBinary format = 1 1191 1192// One result-column format code with the value 1 (i.e. all binary). 1193var colFmtDataAllBinary = []byte{0, 1, 0, 1} 1194 1195// No result-column format codes (i.e. all text). 1196var colFmtDataAllText = []byte{0, 0} 1197 1198type stmt struct { 1199 cn *conn 1200 name string 1201 rowsHeader 1202 colFmtData []byte 1203 paramTyps []oid.Oid 1204 closed bool 1205} 1206 1207func (st *stmt) Close() (err error) { 1208 if st.closed { 1209 return nil 1210 } 1211 if st.cn.bad { 1212 return driver.ErrBadConn 1213 } 1214 defer st.cn.errRecover(&err) 1215 1216 w := st.cn.writeBuf('C') 1217 w.byte('S') 1218 w.string(st.name) 1219 st.cn.send(w) 1220 1221 st.cn.send(st.cn.writeBuf('S')) 1222 1223 t, _ := st.cn.recv1() 1224 if t != '3' { 1225 st.cn.bad = true 1226 errorf("unexpected close response: %q", t) 1227 } 1228 st.closed = true 1229 1230 t, r := st.cn.recv1() 1231 if t != 'Z' { 1232 st.cn.bad = true 1233 errorf("expected ready for query, but got: %q", t) 1234 } 1235 st.cn.processReadyForQuery(r) 1236 1237 return nil 1238} 1239 1240func (st *stmt) Query(v []driver.Value) (r driver.Rows, err error) { 1241 if st.cn.bad { 1242 return nil, driver.ErrBadConn 1243 } 1244 defer st.cn.errRecover(&err) 1245 1246 st.exec(v) 1247 return &rows{ 1248 cn: st.cn, 1249 rowsHeader: st.rowsHeader, 1250 }, nil 1251} 1252 1253func (st *stmt) Exec(v []driver.Value) (res driver.Result, err error) { 1254 if st.cn.bad { 1255 
return nil, driver.ErrBadConn 1256 } 1257 defer st.cn.errRecover(&err) 1258 1259 st.exec(v) 1260 res, _, err = st.cn.readExecuteResponse("simple query") 1261 return res, err 1262} 1263 1264func (st *stmt) exec(v []driver.Value) { 1265 if len(v) >= 65536 { 1266 errorf("got %d parameters but PostgreSQL only supports 65535 parameters", len(v)) 1267 } 1268 if len(v) != len(st.paramTyps) { 1269 errorf("got %d parameters but the statement requires %d", len(v), len(st.paramTyps)) 1270 } 1271 1272 cn := st.cn 1273 w := cn.writeBuf('B') 1274 w.byte(0) // unnamed portal 1275 w.string(st.name) 1276 1277 if cn.binaryParameters { 1278 cn.sendBinaryParameters(w, v) 1279 } else { 1280 w.int16(0) 1281 w.int16(len(v)) 1282 for i, x := range v { 1283 if x == nil { 1284 w.int32(-1) 1285 } else { 1286 b := encode(&cn.parameterStatus, x, st.paramTyps[i]) 1287 w.int32(len(b)) 1288 w.bytes(b) 1289 } 1290 } 1291 } 1292 w.bytes(st.colFmtData) 1293 1294 w.next('E') 1295 w.byte(0) 1296 w.int32(0) 1297 1298 w.next('S') 1299 cn.send(w) 1300 1301 cn.readBindResponse() 1302 cn.postExecuteWorkaround() 1303 1304} 1305 1306func (st *stmt) NumInput() int { 1307 return len(st.paramTyps) 1308} 1309 1310// parseComplete parses the "command tag" from a CommandComplete message, and 1311// returns the number of rows affected (if applicable) and a string 1312// identifying only the command that was executed, e.g. "ALTER TABLE". If the 1313// command tag could not be parsed, parseComplete panics. 
1314func (cn *conn) parseComplete(commandTag string) (driver.Result, string) { 1315 commandsWithAffectedRows := []string{ 1316 "SELECT ", 1317 // INSERT is handled below 1318 "UPDATE ", 1319 "DELETE ", 1320 "FETCH ", 1321 "MOVE ", 1322 "COPY ", 1323 } 1324 1325 var affectedRows *string 1326 for _, tag := range commandsWithAffectedRows { 1327 if strings.HasPrefix(commandTag, tag) { 1328 t := commandTag[len(tag):] 1329 affectedRows = &t 1330 commandTag = tag[:len(tag)-1] 1331 break 1332 } 1333 } 1334 // INSERT also includes the oid of the inserted row in its command tag. 1335 // Oids in user tables are deprecated, and the oid is only returned when 1336 // exactly one row is inserted, so it's unlikely to be of value to any 1337 // real-world application and we can ignore it. 1338 if affectedRows == nil && strings.HasPrefix(commandTag, "INSERT ") { 1339 parts := strings.Split(commandTag, " ") 1340 if len(parts) != 3 { 1341 cn.bad = true 1342 errorf("unexpected INSERT command tag %s", commandTag) 1343 } 1344 affectedRows = &parts[len(parts)-1] 1345 commandTag = "INSERT" 1346 } 1347 // There should be no affected rows attached to the tag, just return it 1348 if affectedRows == nil { 1349 return driver.RowsAffected(0), commandTag 1350 } 1351 n, err := strconv.ParseInt(*affectedRows, 10, 64) 1352 if err != nil { 1353 cn.bad = true 1354 errorf("could not parse commandTag: %s", err) 1355 } 1356 return driver.RowsAffected(n), commandTag 1357} 1358 1359type rowsHeader struct { 1360 colNames []string 1361 colTyps []fieldDesc 1362 colFmts []format 1363} 1364 1365type rows struct { 1366 cn *conn 1367 finish func() 1368 rowsHeader 1369 done bool 1370 rb readBuf 1371 result driver.Result 1372 tag string 1373 1374 next *rowsHeader 1375} 1376 1377func (rs *rows) Close() error { 1378 if finish := rs.finish; finish != nil { 1379 defer finish() 1380 } 1381 // no need to look at cn.bad as Next() will 1382 for { 1383 err := rs.Next(nil) 1384 switch err { 1385 case nil: 1386 case io.EOF: 
1387 // rs.Next can return io.EOF on both 'Z' (ready for query) and 'T' (row 1388 // description, used with HasNextResultSet). We need to fetch messages until 1389 // we hit a 'Z', which is done by waiting for done to be set. 1390 if rs.done { 1391 return nil 1392 } 1393 default: 1394 return err 1395 } 1396 } 1397} 1398 1399func (rs *rows) Columns() []string { 1400 return rs.colNames 1401} 1402 1403func (rs *rows) Result() driver.Result { 1404 if rs.result == nil { 1405 return emptyRows 1406 } 1407 return rs.result 1408} 1409 1410func (rs *rows) Tag() string { 1411 return rs.tag 1412} 1413 1414func (rs *rows) Next(dest []driver.Value) (err error) { 1415 if rs.done { 1416 return io.EOF 1417 } 1418 1419 conn := rs.cn 1420 if conn.bad { 1421 return driver.ErrBadConn 1422 } 1423 defer conn.errRecover(&err) 1424 1425 for { 1426 t := conn.recv1Buf(&rs.rb) 1427 switch t { 1428 case 'E': 1429 err = parseError(&rs.rb) 1430 case 'C', 'I': 1431 if t == 'C' { 1432 rs.result, rs.tag = conn.parseComplete(rs.rb.string()) 1433 } 1434 continue 1435 case 'Z': 1436 conn.processReadyForQuery(&rs.rb) 1437 rs.done = true 1438 if err != nil { 1439 return err 1440 } 1441 return io.EOF 1442 case 'D': 1443 n := rs.rb.int16() 1444 if err != nil { 1445 conn.bad = true 1446 errorf("unexpected DataRow after error %s", err) 1447 } 1448 if n < len(dest) { 1449 dest = dest[:n] 1450 } 1451 for i := range dest { 1452 l := rs.rb.int32() 1453 if l == -1 { 1454 dest[i] = nil 1455 continue 1456 } 1457 dest[i] = decode(&conn.parameterStatus, rs.rb.next(l), rs.colTyps[i].OID, rs.colFmts[i]) 1458 } 1459 return 1460 case 'T': 1461 next := parsePortalRowDescribe(&rs.rb) 1462 rs.next = &next 1463 return io.EOF 1464 default: 1465 errorf("unexpected message after execute: %q", t) 1466 } 1467 } 1468} 1469 1470func (rs *rows) HasNextResultSet() bool { 1471 hasNext := rs.next != nil && !rs.done 1472 return hasNext 1473} 1474 1475func (rs *rows) NextResultSet() error { 1476 if rs.next == nil { 1477 return io.EOF 
1478 } 1479 rs.rowsHeader = *rs.next 1480 rs.next = nil 1481 return nil 1482} 1483 1484// QuoteIdentifier quotes an "identifier" (e.g. a table or a column name) to be 1485// used as part of an SQL statement. For example: 1486// 1487// tblname := "my_table" 1488// data := "my_data" 1489// quoted := pq.QuoteIdentifier(tblname) 1490// err := db.Exec(fmt.Sprintf("INSERT INTO %s VALUES ($1)", quoted), data) 1491// 1492// Any double quotes in name will be escaped. The quoted identifier will be 1493// case sensitive when used in a query. If the input string contains a zero 1494// byte, the result will be truncated immediately before it. 1495func QuoteIdentifier(name string) string { 1496 end := strings.IndexRune(name, 0) 1497 if end > -1 { 1498 name = name[:end] 1499 } 1500 return `"` + strings.Replace(name, `"`, `""`, -1) + `"` 1501} 1502 1503func md5s(s string) string { 1504 h := md5.New() 1505 h.Write([]byte(s)) 1506 return fmt.Sprintf("%x", h.Sum(nil)) 1507} 1508 1509func (cn *conn) sendBinaryParameters(b *writeBuf, args []driver.Value) { 1510 // Do one pass over the parameters to see if we're going to send any of 1511 // them over in binary. If we are, create a paramFormats array at the 1512 // same time. 
1513 var paramFormats []int 1514 for i, x := range args { 1515 _, ok := x.([]byte) 1516 if ok { 1517 if paramFormats == nil { 1518 paramFormats = make([]int, len(args)) 1519 } 1520 paramFormats[i] = 1 1521 } 1522 } 1523 if paramFormats == nil { 1524 b.int16(0) 1525 } else { 1526 b.int16(len(paramFormats)) 1527 for _, x := range paramFormats { 1528 b.int16(x) 1529 } 1530 } 1531 1532 b.int16(len(args)) 1533 for _, x := range args { 1534 if x == nil { 1535 b.int32(-1) 1536 } else { 1537 datum := binaryEncode(&cn.parameterStatus, x) 1538 b.int32(len(datum)) 1539 b.bytes(datum) 1540 } 1541 } 1542} 1543 1544func (cn *conn) sendBinaryModeQuery(query string, args []driver.Value) { 1545 if len(args) >= 65536 { 1546 errorf("got %d parameters but PostgreSQL only supports 65535 parameters", len(args)) 1547 } 1548 1549 b := cn.writeBuf('P') 1550 b.byte(0) // unnamed statement 1551 b.string(query) 1552 b.int16(0) 1553 1554 b.next('B') 1555 b.int16(0) // unnamed portal and statement 1556 cn.sendBinaryParameters(b, args) 1557 b.bytes(colFmtDataAllText) 1558 1559 b.next('D') 1560 b.byte('P') 1561 b.byte(0) // unnamed portal 1562 1563 b.next('E') 1564 b.byte(0) 1565 b.int32(0) 1566 1567 b.next('S') 1568 cn.send(b) 1569} 1570 1571func (cn *conn) processParameterStatus(r *readBuf) { 1572 var err error 1573 1574 param := r.string() 1575 switch param { 1576 case "server_version": 1577 var major1 int 1578 var major2 int 1579 var minor int 1580 _, err = fmt.Sscanf(r.string(), "%d.%d.%d", &major1, &major2, &minor) 1581 if err == nil { 1582 cn.parameterStatus.serverVersion = major1*10000 + major2*100 + minor 1583 } 1584 1585 case "TimeZone": 1586 cn.parameterStatus.currentLocation, err = time.LoadLocation(r.string()) 1587 if err != nil { 1588 cn.parameterStatus.currentLocation = nil 1589 } 1590 1591 default: 1592 // ignore 1593 } 1594} 1595 1596func (cn *conn) processReadyForQuery(r *readBuf) { 1597 cn.txnStatus = transactionStatus(r.byte()) 1598} 1599 1600func (cn *conn) 
readReadyForQuery() { 1601 t, r := cn.recv1() 1602 switch t { 1603 case 'Z': 1604 cn.processReadyForQuery(r) 1605 return 1606 default: 1607 cn.bad = true 1608 errorf("unexpected message %q; expected ReadyForQuery", t) 1609 } 1610} 1611 1612func (cn *conn) processBackendKeyData(r *readBuf) { 1613 cn.processID = r.int32() 1614 cn.secretKey = r.int32() 1615} 1616 1617func (cn *conn) readParseResponse() { 1618 t, r := cn.recv1() 1619 switch t { 1620 case '1': 1621 return 1622 case 'E': 1623 err := parseError(r) 1624 cn.readReadyForQuery() 1625 panic(err) 1626 default: 1627 cn.bad = true 1628 errorf("unexpected Parse response %q", t) 1629 } 1630} 1631 1632func (cn *conn) readStatementDescribeResponse() (paramTyps []oid.Oid, colNames []string, colTyps []fieldDesc) { 1633 for { 1634 t, r := cn.recv1() 1635 switch t { 1636 case 't': 1637 nparams := r.int16() 1638 paramTyps = make([]oid.Oid, nparams) 1639 for i := range paramTyps { 1640 paramTyps[i] = r.oid() 1641 } 1642 case 'n': 1643 return paramTyps, nil, nil 1644 case 'T': 1645 colNames, colTyps = parseStatementRowDescribe(r) 1646 return paramTyps, colNames, colTyps 1647 case 'E': 1648 err := parseError(r) 1649 cn.readReadyForQuery() 1650 panic(err) 1651 default: 1652 cn.bad = true 1653 errorf("unexpected Describe statement response %q", t) 1654 } 1655 } 1656} 1657 1658func (cn *conn) readPortalDescribeResponse() rowsHeader { 1659 t, r := cn.recv1() 1660 switch t { 1661 case 'T': 1662 return parsePortalRowDescribe(r) 1663 case 'n': 1664 return rowsHeader{} 1665 case 'E': 1666 err := parseError(r) 1667 cn.readReadyForQuery() 1668 panic(err) 1669 default: 1670 cn.bad = true 1671 errorf("unexpected Describe response %q", t) 1672 } 1673 panic("not reached") 1674} 1675 1676func (cn *conn) readBindResponse() { 1677 t, r := cn.recv1() 1678 switch t { 1679 case '2': 1680 return 1681 case 'E': 1682 err := parseError(r) 1683 cn.readReadyForQuery() 1684 panic(err) 1685 default: 1686 cn.bad = true 1687 errorf("unexpected Bind 
response %q", t) 1688 } 1689} 1690 1691func (cn *conn) postExecuteWorkaround() { 1692 // Work around a bug in sql.DB.QueryRow: in Go 1.2 and earlier it ignores 1693 // any errors from rows.Next, which masks errors that happened during the 1694 // execution of the query. To avoid the problem in common cases, we wait 1695 // here for one more message from the database. If it's not an error the 1696 // query will likely succeed (or perhaps has already, if it's a 1697 // CommandComplete), so we push the message into the conn struct; recv1 1698 // will return it as the next message for rows.Next or rows.Close. 1699 // However, if it's an error, we wait until ReadyForQuery and then return 1700 // the error to our caller. 1701 for { 1702 t, r := cn.recv1() 1703 switch t { 1704 case 'E': 1705 err := parseError(r) 1706 cn.readReadyForQuery() 1707 panic(err) 1708 case 'C', 'D', 'I': 1709 // the query didn't fail, but we can't process this message 1710 cn.saveMessage(t, r) 1711 return 1712 default: 1713 cn.bad = true 1714 errorf("unexpected message during extended query execution: %q", t) 1715 } 1716 } 1717} 1718 1719// Only for Exec(), since we ignore the returned data 1720func (cn *conn) readExecuteResponse(protocolState string) (res driver.Result, commandTag string, err error) { 1721 for { 1722 t, r := cn.recv1() 1723 switch t { 1724 case 'C': 1725 if err != nil { 1726 cn.bad = true 1727 errorf("unexpected CommandComplete after error %s", err) 1728 } 1729 res, commandTag = cn.parseComplete(r.string()) 1730 case 'Z': 1731 cn.processReadyForQuery(r) 1732 if res == nil && err == nil { 1733 err = errUnexpectedReady 1734 } 1735 return res, commandTag, err 1736 case 'E': 1737 err = parseError(r) 1738 case 'T', 'D', 'I': 1739 if err != nil { 1740 cn.bad = true 1741 errorf("unexpected %q after error %s", t, err) 1742 } 1743 if t == 'I' { 1744 res = emptyRows 1745 } 1746 // ignore any results 1747 default: 1748 cn.bad = true 1749 errorf("unknown %s response: %q", protocolState, t) 
1750 } 1751 } 1752} 1753 1754func parseStatementRowDescribe(r *readBuf) (colNames []string, colTyps []fieldDesc) { 1755 n := r.int16() 1756 colNames = make([]string, n) 1757 colTyps = make([]fieldDesc, n) 1758 for i := range colNames { 1759 colNames[i] = r.string() 1760 r.next(6) 1761 colTyps[i].OID = r.oid() 1762 colTyps[i].Len = r.int16() 1763 colTyps[i].Mod = r.int32() 1764 // format code not known when describing a statement; always 0 1765 r.next(2) 1766 } 1767 return 1768} 1769 1770func parsePortalRowDescribe(r *readBuf) rowsHeader { 1771 n := r.int16() 1772 colNames := make([]string, n) 1773 colFmts := make([]format, n) 1774 colTyps := make([]fieldDesc, n) 1775 for i := range colNames { 1776 colNames[i] = r.string() 1777 r.next(6) 1778 colTyps[i].OID = r.oid() 1779 colTyps[i].Len = r.int16() 1780 colTyps[i].Mod = r.int32() 1781 colFmts[i] = format(r.int16()) 1782 } 1783 return rowsHeader{ 1784 colNames: colNames, 1785 colFmts: colFmts, 1786 colTyps: colTyps, 1787 } 1788} 1789 1790// parseEnviron tries to mimic some of libpq's environment handling 1791// 1792// To ease testing, it does not directly reference os.Environ, but is 1793// designed to accept its output. 1794// 1795// Environment-set connection information is intended to have a higher 1796// precedence than a library default but lower than any explicitly 1797// passed information (such as in the URL or connection string). 1798func parseEnviron(env []string) (out map[string]string) { 1799 out = make(map[string]string) 1800 1801 for _, v := range env { 1802 parts := strings.SplitN(v, "=", 2) 1803 1804 accrue := func(keyname string) { 1805 out[keyname] = parts[1] 1806 } 1807 unsupported := func() { 1808 panic(fmt.Sprintf("setting %v not supported", parts[0])) 1809 } 1810 1811 // The order of these is the same as is seen in the 1812 // PostgreSQL 9.1 manual. Unsupported but well-defined 1813 // keys cause a panic; these should be unset prior to 1814 // execution. 
Options which pq expects to be set to a 1815 // certain value are allowed, but must be set to that 1816 // value if present (they can, of course, be absent). 1817 switch parts[0] { 1818 case "PGHOST": 1819 accrue("host") 1820 case "PGHOSTADDR": 1821 unsupported() 1822 case "PGPORT": 1823 accrue("port") 1824 case "PGDATABASE": 1825 accrue("dbname") 1826 case "PGUSER": 1827 accrue("user") 1828 case "PGPASSWORD": 1829 accrue("password") 1830 case "PGSERVICE", "PGSERVICEFILE", "PGREALM": 1831 unsupported() 1832 case "PGOPTIONS": 1833 accrue("options") 1834 case "PGAPPNAME": 1835 accrue("application_name") 1836 case "PGSSLMODE": 1837 accrue("sslmode") 1838 case "PGSSLCERT": 1839 accrue("sslcert") 1840 case "PGSSLKEY": 1841 accrue("sslkey") 1842 case "PGSSLROOTCERT": 1843 accrue("sslrootcert") 1844 case "PGREQUIRESSL", "PGSSLCRL": 1845 unsupported() 1846 case "PGREQUIREPEER": 1847 unsupported() 1848 case "PGKRBSRVNAME", "PGGSSLIB": 1849 unsupported() 1850 case "PGCONNECT_TIMEOUT": 1851 accrue("connect_timeout") 1852 case "PGCLIENTENCODING": 1853 accrue("client_encoding") 1854 case "PGDATESTYLE": 1855 accrue("datestyle") 1856 case "PGTZ": 1857 accrue("timezone") 1858 case "PGGEQO": 1859 accrue("geqo") 1860 case "PGSYSCONFDIR", "PGLOCALEDIR": 1861 unsupported() 1862 } 1863 } 1864 1865 return out 1866} 1867 1868// isUTF8 returns whether name is a fuzzy variation of the string "UTF-8". 1869func isUTF8(name string) bool { 1870 // Recognize all sorts of silly things as "UTF-8", like Postgres does 1871 s := strings.Map(alnumLowerASCII, name) 1872 return s == "utf8" || s == "unicode" 1873} 1874 1875func alnumLowerASCII(ch rune) rune { 1876 if 'A' <= ch && ch <= 'Z' { 1877 return ch + ('a' - 'A') 1878 } 1879 if 'a' <= ch && ch <= 'z' || '0' <= ch && ch <= '9' { 1880 return ch 1881 } 1882 return -1 // discard 1883} 1884