package pq

import (
	"bufio"
	"context"
	"crypto/md5"
	"crypto/sha256"
	"database/sql"
	"database/sql/driver"
	"encoding/binary"
	"errors"
	"fmt"
	"io"
	"net"
	"os"
	"os/user"
	"path"
	"path/filepath"
	"strconv"
	"strings"
	"time"
	"unicode"

	"github.com/lib/pq/oid"
	"github.com/lib/pq/scram"
)

// Common error types
var (
	// ErrNotSupported reports a command the driver does not support.
	ErrNotSupported = errors.New("pq: Unsupported command")
	// ErrInFailedTransaction is returned by Commit when the transaction was
	// already in the failed state; the transaction is rolled back first.
	ErrInFailedTransaction = errors.New("pq: Could not complete operation in a failed transaction")
	// ErrSSLNotSupported is returned while negotiating SSL when the server
	// does not answer the SSL request with 'S'.
	ErrSSLNotSupported = errors.New("pq: SSL is not enabled on the server")
	// ErrSSLKeyHasWorldPermissions reports a private key file that is
	// accessible by group or world.
	ErrSSLKeyHasWorldPermissions = errors.New("pq: Private key file has group or world access. Permissions should be u=rw (0600) or less")
	// ErrCouldNotDetectUsername reports that no default user name could be
	// determined and none was supplied.
	ErrCouldNotDetectUsername = errors.New("pq: Could not detect default username. Please provide one explicitly")

	// errUnexpectedReady is set by simpleExec when ReadyForQuery arrives
	// before any result or error was seen.
	errUnexpectedReady = errors.New("unexpected ReadyForQuery")
	// errNoRowsAffected and errNoLastInsertID are what the noRows result
	// returns from RowsAffected and LastInsertId respectively.
	errNoRowsAffected = errors.New("no RowsAffected available after the empty statement")
	errNoLastInsertID = errors.New("no LastInsertId available after the empty statement")
)

// Driver is the Postgres database driver.
type Driver struct{}

// Open opens a new connection to the database. name is a connection string.
// Most users should only use it through database/sql package from the standard
// library.
47func (d *Driver) Open(name string) (driver.Conn, error) { 48 return Open(name) 49} 50 51func init() { 52 sql.Register("postgres", &Driver{}) 53} 54 55type parameterStatus struct { 56 // server version in the same format as server_version_num, or 0 if 57 // unavailable 58 serverVersion int 59 60 // the current location based on the TimeZone value of the session, if 61 // available 62 currentLocation *time.Location 63} 64 65type transactionStatus byte 66 67const ( 68 txnStatusIdle transactionStatus = 'I' 69 txnStatusIdleInTransaction transactionStatus = 'T' 70 txnStatusInFailedTransaction transactionStatus = 'E' 71) 72 73func (s transactionStatus) String() string { 74 switch s { 75 case txnStatusIdle: 76 return "idle" 77 case txnStatusIdleInTransaction: 78 return "idle in transaction" 79 case txnStatusInFailedTransaction: 80 return "in a failed transaction" 81 default: 82 errorf("unknown transactionStatus %d", s) 83 } 84 85 panic("not reached") 86} 87 88// Dialer is the dialer interface. It can be used to obtain more control over 89// how pq creates network connections. 90type Dialer interface { 91 Dial(network, address string) (net.Conn, error) 92 DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) 93} 94 95// DialerContext is the context-aware dialer interface. 
96type DialerContext interface { 97 DialContext(ctx context.Context, network, address string) (net.Conn, error) 98} 99 100type defaultDialer struct { 101 d net.Dialer 102} 103 104func (d defaultDialer) Dial(network, address string) (net.Conn, error) { 105 return d.d.Dial(network, address) 106} 107func (d defaultDialer) DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) { 108 ctx, cancel := context.WithTimeout(context.Background(), timeout) 109 defer cancel() 110 return d.DialContext(ctx, network, address) 111} 112func (d defaultDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { 113 return d.d.DialContext(ctx, network, address) 114} 115 116type conn struct { 117 c net.Conn 118 buf *bufio.Reader 119 namei int 120 scratch [512]byte 121 txnStatus transactionStatus 122 txnFinish func() 123 124 // Save connection arguments to use during CancelRequest. 125 dialer Dialer 126 opts values 127 128 // Cancellation key data for use with CancelRequest messages. 129 processID int 130 secretKey int 131 132 parameterStatus parameterStatus 133 134 saveMessageType byte 135 saveMessageBuffer []byte 136 137 // If true, this connection is bad and all public-facing functions should 138 // return ErrBadConn. 139 bad bool 140 141 // If set, this connection should never use the binary format when 142 // receiving query results from prepared statements. Only provided for 143 // debugging. 144 disablePreparedBinaryResult bool 145 146 // Whether to always send []byte parameters over as binary. Enables single 147 // round-trip mode for non-prepared Query calls. 148 binaryParameters bool 149 150 // If true this connection is in the middle of a COPY 151 inCopy bool 152} 153 154// Handle driver-side settings in parsed connection string. 
155func (cn *conn) handleDriverSettings(o values) (err error) { 156 boolSetting := func(key string, val *bool) error { 157 if value, ok := o[key]; ok { 158 if value == "yes" { 159 *val = true 160 } else if value == "no" { 161 *val = false 162 } else { 163 return fmt.Errorf("unrecognized value %q for %s", value, key) 164 } 165 } 166 return nil 167 } 168 169 err = boolSetting("disable_prepared_binary_result", &cn.disablePreparedBinaryResult) 170 if err != nil { 171 return err 172 } 173 return boolSetting("binary_parameters", &cn.binaryParameters) 174} 175 176func (cn *conn) handlePgpass(o values) { 177 // if a password was supplied, do not process .pgpass 178 if _, ok := o["password"]; ok { 179 return 180 } 181 filename := os.Getenv("PGPASSFILE") 182 if filename == "" { 183 // XXX this code doesn't work on Windows where the default filename is 184 // XXX %APPDATA%\postgresql\pgpass.conf 185 // Prefer $HOME over user.Current due to glibc bug: golang.org/issue/13470 186 userHome := os.Getenv("HOME") 187 if userHome == "" { 188 user, err := user.Current() 189 if err != nil { 190 return 191 } 192 userHome = user.HomeDir 193 } 194 filename = filepath.Join(userHome, ".pgpass") 195 } 196 fileinfo, err := os.Stat(filename) 197 if err != nil { 198 return 199 } 200 mode := fileinfo.Mode() 201 if mode&(0x77) != 0 { 202 // XXX should warn about incorrect .pgpass permissions as psql does 203 return 204 } 205 file, err := os.Open(filename) 206 if err != nil { 207 return 208 } 209 defer file.Close() 210 scanner := bufio.NewScanner(io.Reader(file)) 211 hostname := o["host"] 212 ntw, _ := network(o) 213 port := o["port"] 214 db := o["dbname"] 215 username := o["user"] 216 // From: https://github.com/tg/pgpass/blob/master/reader.go 217 getFields := func(s string) []string { 218 fs := make([]string, 0, 5) 219 f := make([]rune, 0, len(s)) 220 221 var esc bool 222 for _, c := range s { 223 switch { 224 case esc: 225 f = append(f, c) 226 esc = false 227 case c == '\\': 228 esc = true 229 
case c == ':': 230 fs = append(fs, string(f)) 231 f = f[:0] 232 default: 233 f = append(f, c) 234 } 235 } 236 return append(fs, string(f)) 237 } 238 for scanner.Scan() { 239 line := scanner.Text() 240 if len(line) == 0 || line[0] == '#' { 241 continue 242 } 243 split := getFields(line) 244 if len(split) != 5 { 245 continue 246 } 247 if (split[0] == "*" || split[0] == hostname || (split[0] == "localhost" && (hostname == "" || ntw == "unix"))) && (split[1] == "*" || split[1] == port) && (split[2] == "*" || split[2] == db) && (split[3] == "*" || split[3] == username) { 248 o["password"] = split[4] 249 return 250 } 251 } 252} 253 254func (cn *conn) writeBuf(b byte) *writeBuf { 255 cn.scratch[0] = b 256 return &writeBuf{ 257 buf: cn.scratch[:5], 258 pos: 1, 259 } 260} 261 262// Open opens a new connection to the database. dsn is a connection string. 263// Most users should only use it through database/sql package from the standard 264// library. 265func Open(dsn string) (_ driver.Conn, err error) { 266 return DialOpen(defaultDialer{}, dsn) 267} 268 269// DialOpen opens a new connection to the database using a dialer. 270func DialOpen(d Dialer, dsn string) (_ driver.Conn, err error) { 271 c, err := NewConnector(dsn) 272 if err != nil { 273 return nil, err 274 } 275 c.dialer = d 276 return c.open(context.Background()) 277} 278 279func (c *Connector) open(ctx context.Context) (cn *conn, err error) { 280 // Handle any panics during connection initialization. Note that we 281 // specifically do *not* want to use errRecover(), as that would turn any 282 // connection errors into ErrBadConns, hiding the real error message from 283 // the user. 
284 defer errRecoverNoErrBadConn(&err) 285 286 o := c.opts 287 288 cn = &conn{ 289 opts: o, 290 dialer: c.dialer, 291 } 292 err = cn.handleDriverSettings(o) 293 if err != nil { 294 return nil, err 295 } 296 cn.handlePgpass(o) 297 298 cn.c, err = dial(ctx, c.dialer, o) 299 if err != nil { 300 return nil, err 301 } 302 303 err = cn.ssl(o) 304 if err != nil { 305 if cn.c != nil { 306 cn.c.Close() 307 } 308 return nil, err 309 } 310 311 // cn.startup panics on error. Make sure we don't leak cn.c. 312 panicking := true 313 defer func() { 314 if panicking { 315 cn.c.Close() 316 } 317 }() 318 319 cn.buf = bufio.NewReader(cn.c) 320 cn.startup(o) 321 322 // reset the deadline, in case one was set (see dial) 323 if timeout, ok := o["connect_timeout"]; ok && timeout != "0" { 324 err = cn.c.SetDeadline(time.Time{}) 325 } 326 panicking = false 327 return cn, err 328} 329 330func dial(ctx context.Context, d Dialer, o values) (net.Conn, error) { 331 network, address := network(o) 332 // SSL is not necessary or supported over UNIX domain sockets 333 if network == "unix" { 334 o["sslmode"] = "disable" 335 } 336 337 // Zero or not specified means wait indefinitely. 338 if timeout, ok := o["connect_timeout"]; ok && timeout != "0" { 339 seconds, err := strconv.ParseInt(timeout, 10, 0) 340 if err != nil { 341 return nil, fmt.Errorf("invalid value for parameter connect_timeout: %s", err) 342 } 343 duration := time.Duration(seconds) * time.Second 344 345 // connect_timeout should apply to the entire connection establishment 346 // procedure, so we both use a timeout for the TCP connection 347 // establishment and set a deadline for doing the initial handshake. 348 // The deadline is then reset after startup() is done. 
349 deadline := time.Now().Add(duration) 350 var conn net.Conn 351 if dctx, ok := d.(DialerContext); ok { 352 ctx, cancel := context.WithTimeout(ctx, duration) 353 defer cancel() 354 conn, err = dctx.DialContext(ctx, network, address) 355 } else { 356 conn, err = d.DialTimeout(network, address, duration) 357 } 358 if err != nil { 359 return nil, err 360 } 361 err = conn.SetDeadline(deadline) 362 return conn, err 363 } 364 if dctx, ok := d.(DialerContext); ok { 365 return dctx.DialContext(ctx, network, address) 366 } 367 return d.Dial(network, address) 368} 369 370func network(o values) (string, string) { 371 host := o["host"] 372 373 if strings.HasPrefix(host, "/") { 374 sockPath := path.Join(host, ".s.PGSQL."+o["port"]) 375 return "unix", sockPath 376 } 377 378 return "tcp", net.JoinHostPort(host, o["port"]) 379} 380 381type values map[string]string 382 383// scanner implements a tokenizer for libpq-style option strings. 384type scanner struct { 385 s []rune 386 i int 387} 388 389// newScanner returns a new scanner initialized with the option string s. 390func newScanner(s string) *scanner { 391 return &scanner{[]rune(s), 0} 392} 393 394// Next returns the next rune. 395// It returns 0, false if the end of the text has been reached. 396func (s *scanner) Next() (rune, bool) { 397 if s.i >= len(s.s) { 398 return 0, false 399 } 400 r := s.s[s.i] 401 s.i++ 402 return r, true 403} 404 405// SkipSpaces returns the next non-whitespace rune. 406// It returns 0, false if the end of the text has been reached. 407func (s *scanner) SkipSpaces() (rune, bool) { 408 r, ok := s.Next() 409 for unicode.IsSpace(r) && ok { 410 r, ok = s.Next() 411 } 412 return r, ok 413} 414 415// parseOpts parses the options from name and adds them to the values. 
416// 417// The parsing code is based on conninfo_parse from libpq's fe-connect.c 418func parseOpts(name string, o values) error { 419 s := newScanner(name) 420 421 for { 422 var ( 423 keyRunes, valRunes []rune 424 r rune 425 ok bool 426 ) 427 428 if r, ok = s.SkipSpaces(); !ok { 429 break 430 } 431 432 // Scan the key 433 for !unicode.IsSpace(r) && r != '=' { 434 keyRunes = append(keyRunes, r) 435 if r, ok = s.Next(); !ok { 436 break 437 } 438 } 439 440 // Skip any whitespace if we're not at the = yet 441 if r != '=' { 442 r, ok = s.SkipSpaces() 443 } 444 445 // The current character should be = 446 if r != '=' || !ok { 447 return fmt.Errorf(`missing "=" after %q in connection info string"`, string(keyRunes)) 448 } 449 450 // Skip any whitespace after the = 451 if r, ok = s.SkipSpaces(); !ok { 452 // If we reach the end here, the last value is just an empty string as per libpq. 453 o[string(keyRunes)] = "" 454 break 455 } 456 457 if r != '\'' { 458 for !unicode.IsSpace(r) { 459 if r == '\\' { 460 if r, ok = s.Next(); !ok { 461 return fmt.Errorf(`missing character after backslash`) 462 } 463 } 464 valRunes = append(valRunes, r) 465 466 if r, ok = s.Next(); !ok { 467 break 468 } 469 } 470 } else { 471 quote: 472 for { 473 if r, ok = s.Next(); !ok { 474 return fmt.Errorf(`unterminated quoted string literal in connection string`) 475 } 476 switch r { 477 case '\'': 478 break quote 479 case '\\': 480 r, _ = s.Next() 481 fallthrough 482 default: 483 valRunes = append(valRunes, r) 484 } 485 } 486 } 487 488 o[string(keyRunes)] = string(valRunes) 489 } 490 491 return nil 492} 493 494func (cn *conn) isInTransaction() bool { 495 return cn.txnStatus == txnStatusIdleInTransaction || 496 cn.txnStatus == txnStatusInFailedTransaction 497} 498 499func (cn *conn) checkIsInTransaction(intxn bool) { 500 if cn.isInTransaction() != intxn { 501 cn.bad = true 502 errorf("unexpected transaction status %v", cn.txnStatus) 503 } 504} 505 506func (cn *conn) Begin() (_ driver.Tx, err error) { 
507 return cn.begin("") 508} 509 510func (cn *conn) begin(mode string) (_ driver.Tx, err error) { 511 if cn.bad { 512 return nil, driver.ErrBadConn 513 } 514 defer cn.errRecover(&err) 515 516 cn.checkIsInTransaction(false) 517 _, commandTag, err := cn.simpleExec("BEGIN" + mode) 518 if err != nil { 519 return nil, err 520 } 521 if commandTag != "BEGIN" { 522 cn.bad = true 523 return nil, fmt.Errorf("unexpected command tag %s", commandTag) 524 } 525 if cn.txnStatus != txnStatusIdleInTransaction { 526 cn.bad = true 527 return nil, fmt.Errorf("unexpected transaction status %v", cn.txnStatus) 528 } 529 return cn, nil 530} 531 532func (cn *conn) closeTxn() { 533 if finish := cn.txnFinish; finish != nil { 534 finish() 535 } 536} 537 538func (cn *conn) Commit() (err error) { 539 defer cn.closeTxn() 540 if cn.bad { 541 return driver.ErrBadConn 542 } 543 defer cn.errRecover(&err) 544 545 cn.checkIsInTransaction(true) 546 // We don't want the client to think that everything is okay if it tries 547 // to commit a failed transaction. However, no matter what we return, 548 // database/sql will release this connection back into the free connection 549 // pool so we have to abort the current transaction here. Note that you 550 // would get the same behaviour if you issued a COMMIT in a failed 551 // transaction, so it's also the least surprising thing to do here. 
552 if cn.txnStatus == txnStatusInFailedTransaction { 553 if err := cn.rollback(); err != nil { 554 return err 555 } 556 return ErrInFailedTransaction 557 } 558 559 _, commandTag, err := cn.simpleExec("COMMIT") 560 if err != nil { 561 if cn.isInTransaction() { 562 cn.bad = true 563 } 564 return err 565 } 566 if commandTag != "COMMIT" { 567 cn.bad = true 568 return fmt.Errorf("unexpected command tag %s", commandTag) 569 } 570 cn.checkIsInTransaction(false) 571 return nil 572} 573 574func (cn *conn) Rollback() (err error) { 575 defer cn.closeTxn() 576 if cn.bad { 577 return driver.ErrBadConn 578 } 579 defer cn.errRecover(&err) 580 return cn.rollback() 581} 582 583func (cn *conn) rollback() (err error) { 584 cn.checkIsInTransaction(true) 585 _, commandTag, err := cn.simpleExec("ROLLBACK") 586 if err != nil { 587 if cn.isInTransaction() { 588 cn.bad = true 589 } 590 return err 591 } 592 if commandTag != "ROLLBACK" { 593 return fmt.Errorf("unexpected command tag %s", commandTag) 594 } 595 cn.checkIsInTransaction(false) 596 return nil 597} 598 599func (cn *conn) gname() string { 600 cn.namei++ 601 return strconv.FormatInt(int64(cn.namei), 10) 602} 603 604func (cn *conn) simpleExec(q string) (res driver.Result, commandTag string, err error) { 605 b := cn.writeBuf('Q') 606 b.string(q) 607 cn.send(b) 608 609 for { 610 t, r := cn.recv1() 611 switch t { 612 case 'C': 613 res, commandTag = cn.parseComplete(r.string()) 614 case 'Z': 615 cn.processReadyForQuery(r) 616 if res == nil && err == nil { 617 err = errUnexpectedReady 618 } 619 // done 620 return 621 case 'E': 622 err = parseError(r) 623 case 'I': 624 res = emptyRows 625 case 'T', 'D': 626 // ignore any results 627 default: 628 cn.bad = true 629 errorf("unknown response for simple query: %q", t) 630 } 631 } 632} 633 634func (cn *conn) simpleQuery(q string) (res *rows, err error) { 635 defer cn.errRecover(&err) 636 637 b := cn.writeBuf('Q') 638 b.string(q) 639 cn.send(b) 640 641 for { 642 t, r := cn.recv1() 643 switch t { 
644 case 'C', 'I': 645 // We allow queries which don't return any results through Query as 646 // well as Exec. We still have to give database/sql a rows object 647 // the user can close, though, to avoid connections from being 648 // leaked. A "rows" with done=true works fine for that purpose. 649 if err != nil { 650 cn.bad = true 651 errorf("unexpected message %q in simple query execution", t) 652 } 653 if res == nil { 654 res = &rows{ 655 cn: cn, 656 } 657 } 658 // Set the result and tag to the last command complete if there wasn't a 659 // query already run. Although queries usually return from here and cede 660 // control to Next, a query with zero results does not. 661 if t == 'C' && res.colNames == nil { 662 res.result, res.tag = cn.parseComplete(r.string()) 663 } 664 res.done = true 665 case 'Z': 666 cn.processReadyForQuery(r) 667 // done 668 return 669 case 'E': 670 res = nil 671 err = parseError(r) 672 case 'D': 673 if res == nil { 674 cn.bad = true 675 errorf("unexpected DataRow in simple query execution") 676 } 677 // the query didn't fail; kick off to Next 678 cn.saveMessage(t, r) 679 return 680 case 'T': 681 // res might be non-nil here if we received a previous 682 // CommandComplete, but that's fine; just overwrite it 683 res = &rows{cn: cn} 684 res.rowsHeader = parsePortalRowDescribe(r) 685 686 // To work around a bug in QueryRow in Go 1.2 and earlier, wait 687 // until the first DataRow has been received. 688 default: 689 cn.bad = true 690 errorf("unknown response for simple query: %q", t) 691 } 692 } 693} 694 695type noRows struct{} 696 697var emptyRows noRows 698 699var _ driver.Result = noRows{} 700 701func (noRows) LastInsertId() (int64, error) { 702 return 0, errNoLastInsertID 703} 704 705func (noRows) RowsAffected() (int64, error) { 706 return 0, errNoRowsAffected 707} 708 709// Decides which column formats to use for a prepared statement. The input is 710// an array of type oids, one element per result column. 
711func decideColumnFormats(colTyps []fieldDesc, forceText bool) (colFmts []format, colFmtData []byte) { 712 if len(colTyps) == 0 { 713 return nil, colFmtDataAllText 714 } 715 716 colFmts = make([]format, len(colTyps)) 717 if forceText { 718 return colFmts, colFmtDataAllText 719 } 720 721 allBinary := true 722 allText := true 723 for i, t := range colTyps { 724 switch t.OID { 725 // This is the list of types to use binary mode for when receiving them 726 // through a prepared statement. If a type appears in this list, it 727 // must also be implemented in binaryDecode in encode.go. 728 case oid.T_bytea: 729 fallthrough 730 case oid.T_int8: 731 fallthrough 732 case oid.T_int4: 733 fallthrough 734 case oid.T_int2: 735 fallthrough 736 case oid.T_uuid: 737 colFmts[i] = formatBinary 738 allText = false 739 740 default: 741 allBinary = false 742 } 743 } 744 745 if allBinary { 746 return colFmts, colFmtDataAllBinary 747 } else if allText { 748 return colFmts, colFmtDataAllText 749 } else { 750 colFmtData = make([]byte, 2+len(colFmts)*2) 751 binary.BigEndian.PutUint16(colFmtData, uint16(len(colFmts))) 752 for i, v := range colFmts { 753 binary.BigEndian.PutUint16(colFmtData[2+i*2:], uint16(v)) 754 } 755 return colFmts, colFmtData 756 } 757} 758 759func (cn *conn) prepareTo(q, stmtName string) *stmt { 760 st := &stmt{cn: cn, name: stmtName} 761 762 b := cn.writeBuf('P') 763 b.string(st.name) 764 b.string(q) 765 b.int16(0) 766 767 b.next('D') 768 b.byte('S') 769 b.string(st.name) 770 771 b.next('S') 772 cn.send(b) 773 774 cn.readParseResponse() 775 st.paramTyps, st.colNames, st.colTyps = cn.readStatementDescribeResponse() 776 st.colFmts, st.colFmtData = decideColumnFormats(st.colTyps, cn.disablePreparedBinaryResult) 777 cn.readReadyForQuery() 778 return st 779} 780 781func (cn *conn) Prepare(q string) (_ driver.Stmt, err error) { 782 if cn.bad { 783 return nil, driver.ErrBadConn 784 } 785 defer cn.errRecover(&err) 786 787 if len(q) >= 4 && strings.EqualFold(q[:4], "COPY") { 
788 s, err := cn.prepareCopyIn(q) 789 if err == nil { 790 cn.inCopy = true 791 } 792 return s, err 793 } 794 return cn.prepareTo(q, cn.gname()), nil 795} 796 797func (cn *conn) Close() (err error) { 798 // Skip cn.bad return here because we always want to close a connection. 799 defer cn.errRecover(&err) 800 801 // Ensure that cn.c.Close is always run. Since error handling is done with 802 // panics and cn.errRecover, the Close must be in a defer. 803 defer func() { 804 cerr := cn.c.Close() 805 if err == nil { 806 err = cerr 807 } 808 }() 809 810 // Don't go through send(); ListenerConn relies on us not scribbling on the 811 // scratch buffer of this connection. 812 return cn.sendSimpleMessage('X') 813} 814 815// Implement the "Queryer" interface 816func (cn *conn) Query(query string, args []driver.Value) (driver.Rows, error) { 817 return cn.query(query, args) 818} 819 820func (cn *conn) query(query string, args []driver.Value) (_ *rows, err error) { 821 if cn.bad { 822 return nil, driver.ErrBadConn 823 } 824 if cn.inCopy { 825 return nil, errCopyInProgress 826 } 827 defer cn.errRecover(&err) 828 829 // Check to see if we can use the "simpleQuery" interface, which is 830 // *much* faster than going through prepare/exec 831 if len(args) == 0 { 832 return cn.simpleQuery(query) 833 } 834 835 if cn.binaryParameters { 836 cn.sendBinaryModeQuery(query, args) 837 838 cn.readParseResponse() 839 cn.readBindResponse() 840 rows := &rows{cn: cn} 841 rows.rowsHeader = cn.readPortalDescribeResponse() 842 cn.postExecuteWorkaround() 843 return rows, nil 844 } 845 st := cn.prepareTo(query, "") 846 st.exec(args) 847 return &rows{ 848 cn: cn, 849 rowsHeader: st.rowsHeader, 850 }, nil 851} 852 853// Implement the optional "Execer" interface for one-shot queries 854func (cn *conn) Exec(query string, args []driver.Value) (res driver.Result, err error) { 855 if cn.bad { 856 return nil, driver.ErrBadConn 857 } 858 defer cn.errRecover(&err) 859 860 // Check to see if we can use the 
"simpleExec" interface, which is 861 // *much* faster than going through prepare/exec 862 if len(args) == 0 { 863 // ignore commandTag, our caller doesn't care 864 r, _, err := cn.simpleExec(query) 865 return r, err 866 } 867 868 if cn.binaryParameters { 869 cn.sendBinaryModeQuery(query, args) 870 871 cn.readParseResponse() 872 cn.readBindResponse() 873 cn.readPortalDescribeResponse() 874 cn.postExecuteWorkaround() 875 res, _, err = cn.readExecuteResponse("Execute") 876 return res, err 877 } 878 // Use the unnamed statement to defer planning until bind 879 // time, or else value-based selectivity estimates cannot be 880 // used. 881 st := cn.prepareTo(query, "") 882 r, err := st.Exec(args) 883 if err != nil { 884 panic(err) 885 } 886 return r, err 887} 888 889func (cn *conn) send(m *writeBuf) { 890 _, err := cn.c.Write(m.wrap()) 891 if err != nil { 892 panic(err) 893 } 894} 895 896func (cn *conn) sendStartupPacket(m *writeBuf) error { 897 _, err := cn.c.Write((m.wrap())[1:]) 898 return err 899} 900 901// Send a message of type typ to the server on the other end of cn. The 902// message should have no payload. This method does not use the scratch 903// buffer. 904func (cn *conn) sendSimpleMessage(typ byte) (err error) { 905 _, err = cn.c.Write([]byte{typ, '\x00', '\x00', '\x00', '\x04'}) 906 return err 907} 908 909// saveMessage memorizes a message and its buffer in the conn struct. 910// recvMessage will then return these values on the next call to it. This 911// method is useful in cases where you have to see what the next message is 912// going to be (e.g. to see whether it's an error or not) but you can't handle 913// the message yourself. 
914func (cn *conn) saveMessage(typ byte, buf *readBuf) { 915 if cn.saveMessageType != 0 { 916 cn.bad = true 917 errorf("unexpected saveMessageType %d", cn.saveMessageType) 918 } 919 cn.saveMessageType = typ 920 cn.saveMessageBuffer = *buf 921} 922 923// recvMessage receives any message from the backend, or returns an error if 924// a problem occurred while reading the message. 925func (cn *conn) recvMessage(r *readBuf) (byte, error) { 926 // workaround for a QueryRow bug, see exec 927 if cn.saveMessageType != 0 { 928 t := cn.saveMessageType 929 *r = cn.saveMessageBuffer 930 cn.saveMessageType = 0 931 cn.saveMessageBuffer = nil 932 return t, nil 933 } 934 935 x := cn.scratch[:5] 936 _, err := io.ReadFull(cn.buf, x) 937 if err != nil { 938 return 0, err 939 } 940 941 // read the type and length of the message that follows 942 t := x[0] 943 n := int(binary.BigEndian.Uint32(x[1:])) - 4 944 var y []byte 945 if n <= len(cn.scratch) { 946 y = cn.scratch[:n] 947 } else { 948 y = make([]byte, n) 949 } 950 _, err = io.ReadFull(cn.buf, y) 951 if err != nil { 952 return 0, err 953 } 954 *r = y 955 return t, nil 956} 957 958// recv receives a message from the backend, but if an error happened while 959// reading the message or the received message was an ErrorResponse, it panics. 960// NoticeResponses are ignored. This function should generally be used only 961// during the startup sequence. 962func (cn *conn) recv() (t byte, r *readBuf) { 963 for { 964 var err error 965 r = &readBuf{} 966 t, err = cn.recvMessage(r) 967 if err != nil { 968 panic(err) 969 } 970 switch t { 971 case 'E': 972 panic(parseError(r)) 973 case 'N': 974 // ignore 975 default: 976 return 977 } 978 } 979} 980 981// recv1Buf is exactly equivalent to recv1, except it uses a buffer supplied by 982// the caller to avoid an allocation. 
983func (cn *conn) recv1Buf(r *readBuf) byte { 984 for { 985 t, err := cn.recvMessage(r) 986 if err != nil { 987 panic(err) 988 } 989 990 switch t { 991 case 'A', 'N': 992 // ignore 993 case 'S': 994 cn.processParameterStatus(r) 995 default: 996 return t 997 } 998 } 999} 1000 1001// recv1 receives a message from the backend, panicking if an error occurs 1002// while attempting to read it. All asynchronous messages are ignored, with 1003// the exception of ErrorResponse. 1004func (cn *conn) recv1() (t byte, r *readBuf) { 1005 r = &readBuf{} 1006 t = cn.recv1Buf(r) 1007 return t, r 1008} 1009 1010func (cn *conn) ssl(o values) error { 1011 upgrade, err := ssl(o) 1012 if err != nil { 1013 return err 1014 } 1015 1016 if upgrade == nil { 1017 // Nothing to do 1018 return nil 1019 } 1020 1021 w := cn.writeBuf(0) 1022 w.int32(80877103) 1023 if err = cn.sendStartupPacket(w); err != nil { 1024 return err 1025 } 1026 1027 b := cn.scratch[:1] 1028 _, err = io.ReadFull(cn.c, b) 1029 if err != nil { 1030 return err 1031 } 1032 1033 if b[0] != 'S' { 1034 return ErrSSLNotSupported 1035 } 1036 1037 cn.c, err = upgrade(cn.c) 1038 return err 1039} 1040 1041// isDriverSetting returns true iff a setting is purely for configuring the 1042// driver's options and should not be sent to the server in the connection 1043// startup packet. 
1044func isDriverSetting(key string) bool { 1045 switch key { 1046 case "host", "port": 1047 return true 1048 case "password": 1049 return true 1050 case "sslmode", "sslcert", "sslkey", "sslrootcert": 1051 return true 1052 case "fallback_application_name": 1053 return true 1054 case "connect_timeout": 1055 return true 1056 case "disable_prepared_binary_result": 1057 return true 1058 case "binary_parameters": 1059 return true 1060 1061 default: 1062 return false 1063 } 1064} 1065 1066func (cn *conn) startup(o values) { 1067 w := cn.writeBuf(0) 1068 w.int32(196608) 1069 // Send the backend the name of the database we want to connect to, and the 1070 // user we want to connect as. Additionally, we send over any run-time 1071 // parameters potentially included in the connection string. If the server 1072 // doesn't recognize any of them, it will reply with an error. 1073 for k, v := range o { 1074 if isDriverSetting(k) { 1075 // skip options which can't be run-time parameters 1076 continue 1077 } 1078 // The protocol requires us to supply the database name as "database" 1079 // instead of "dbname". 
1080 if k == "dbname" { 1081 k = "database" 1082 } 1083 w.string(k) 1084 w.string(v) 1085 } 1086 w.string("") 1087 if err := cn.sendStartupPacket(w); err != nil { 1088 panic(err) 1089 } 1090 1091 for { 1092 t, r := cn.recv() 1093 switch t { 1094 case 'K': 1095 cn.processBackendKeyData(r) 1096 case 'S': 1097 cn.processParameterStatus(r) 1098 case 'R': 1099 cn.auth(r, o) 1100 case 'Z': 1101 cn.processReadyForQuery(r) 1102 return 1103 default: 1104 errorf("unknown response for startup: %q", t) 1105 } 1106 } 1107} 1108 1109func (cn *conn) auth(r *readBuf, o values) { 1110 switch code := r.int32(); code { 1111 case 0: 1112 // OK 1113 case 3: 1114 w := cn.writeBuf('p') 1115 w.string(o["password"]) 1116 cn.send(w) 1117 1118 t, r := cn.recv() 1119 if t != 'R' { 1120 errorf("unexpected password response: %q", t) 1121 } 1122 1123 if r.int32() != 0 { 1124 errorf("unexpected authentication response: %q", t) 1125 } 1126 case 5: 1127 s := string(r.next(4)) 1128 w := cn.writeBuf('p') 1129 w.string("md5" + md5s(md5s(o["password"]+o["user"])+s)) 1130 cn.send(w) 1131 1132 t, r := cn.recv() 1133 if t != 'R' { 1134 errorf("unexpected password response: %q", t) 1135 } 1136 1137 if r.int32() != 0 { 1138 errorf("unexpected authentication response: %q", t) 1139 } 1140 case 10: 1141 sc := scram.NewClient(sha256.New, o["user"], o["password"]) 1142 sc.Step(nil) 1143 if sc.Err() != nil { 1144 errorf("SCRAM-SHA-256 error: %s", sc.Err().Error()) 1145 } 1146 scOut := sc.Out() 1147 1148 w := cn.writeBuf('p') 1149 w.string("SCRAM-SHA-256") 1150 w.int32(len(scOut)) 1151 w.bytes(scOut) 1152 cn.send(w) 1153 1154 t, r := cn.recv() 1155 if t != 'R' { 1156 errorf("unexpected password response: %q", t) 1157 } 1158 1159 if r.int32() != 11 { 1160 errorf("unexpected authentication response: %q", t) 1161 } 1162 1163 nextStep := r.next(len(*r)) 1164 sc.Step(nextStep) 1165 if sc.Err() != nil { 1166 errorf("SCRAM-SHA-256 error: %s", sc.Err().Error()) 1167 } 1168 1169 scOut = sc.Out() 1170 w = cn.writeBuf('p') 
1171 w.bytes(scOut) 1172 cn.send(w) 1173 1174 t, r = cn.recv() 1175 if t != 'R' { 1176 errorf("unexpected password response: %q", t) 1177 } 1178 1179 if r.int32() != 12 { 1180 errorf("unexpected authentication response: %q", t) 1181 } 1182 1183 nextStep = r.next(len(*r)) 1184 sc.Step(nextStep) 1185 if sc.Err() != nil { 1186 errorf("SCRAM-SHA-256 error: %s", sc.Err().Error()) 1187 } 1188 1189 default: 1190 errorf("unknown authentication response: %d", code) 1191 } 1192} 1193 1194type format int 1195 1196const formatText format = 0 1197const formatBinary format = 1 1198 1199// One result-column format code with the value 1 (i.e. all binary). 1200var colFmtDataAllBinary = []byte{0, 1, 0, 1} 1201 1202// No result-column format codes (i.e. all text). 1203var colFmtDataAllText = []byte{0, 0} 1204 1205type stmt struct { 1206 cn *conn 1207 name string 1208 rowsHeader 1209 colFmtData []byte 1210 paramTyps []oid.Oid 1211 closed bool 1212} 1213 1214func (st *stmt) Close() (err error) { 1215 if st.closed { 1216 return nil 1217 } 1218 if st.cn.bad { 1219 return driver.ErrBadConn 1220 } 1221 defer st.cn.errRecover(&err) 1222 1223 w := st.cn.writeBuf('C') 1224 w.byte('S') 1225 w.string(st.name) 1226 st.cn.send(w) 1227 1228 st.cn.send(st.cn.writeBuf('S')) 1229 1230 t, _ := st.cn.recv1() 1231 if t != '3' { 1232 st.cn.bad = true 1233 errorf("unexpected close response: %q", t) 1234 } 1235 st.closed = true 1236 1237 t, r := st.cn.recv1() 1238 if t != 'Z' { 1239 st.cn.bad = true 1240 errorf("expected ready for query, but got: %q", t) 1241 } 1242 st.cn.processReadyForQuery(r) 1243 1244 return nil 1245} 1246 1247func (st *stmt) Query(v []driver.Value) (r driver.Rows, err error) { 1248 if st.cn.bad { 1249 return nil, driver.ErrBadConn 1250 } 1251 defer st.cn.errRecover(&err) 1252 1253 st.exec(v) 1254 return &rows{ 1255 cn: st.cn, 1256 rowsHeader: st.rowsHeader, 1257 }, nil 1258} 1259 1260func (st *stmt) Exec(v []driver.Value) (res driver.Result, err error) { 1261 if st.cn.bad { 1262 
return nil, driver.ErrBadConn 1263 } 1264 defer st.cn.errRecover(&err) 1265 1266 st.exec(v) 1267 res, _, err = st.cn.readExecuteResponse("simple query") 1268 return res, err 1269} 1270 1271func (st *stmt) exec(v []driver.Value) { 1272 if len(v) >= 65536 { 1273 errorf("got %d parameters but PostgreSQL only supports 65535 parameters", len(v)) 1274 } 1275 if len(v) != len(st.paramTyps) { 1276 errorf("got %d parameters but the statement requires %d", len(v), len(st.paramTyps)) 1277 } 1278 1279 cn := st.cn 1280 w := cn.writeBuf('B') 1281 w.byte(0) // unnamed portal 1282 w.string(st.name) 1283 1284 if cn.binaryParameters { 1285 cn.sendBinaryParameters(w, v) 1286 } else { 1287 w.int16(0) 1288 w.int16(len(v)) 1289 for i, x := range v { 1290 if x == nil { 1291 w.int32(-1) 1292 } else { 1293 b := encode(&cn.parameterStatus, x, st.paramTyps[i]) 1294 w.int32(len(b)) 1295 w.bytes(b) 1296 } 1297 } 1298 } 1299 w.bytes(st.colFmtData) 1300 1301 w.next('E') 1302 w.byte(0) 1303 w.int32(0) 1304 1305 w.next('S') 1306 cn.send(w) 1307 1308 cn.readBindResponse() 1309 cn.postExecuteWorkaround() 1310 1311} 1312 1313func (st *stmt) NumInput() int { 1314 return len(st.paramTyps) 1315} 1316 1317// parseComplete parses the "command tag" from a CommandComplete message, and 1318// returns the number of rows affected (if applicable) and a string 1319// identifying only the command that was executed, e.g. "ALTER TABLE". If the 1320// command tag could not be parsed, parseComplete panics. 
1321func (cn *conn) parseComplete(commandTag string) (driver.Result, string) { 1322 commandsWithAffectedRows := []string{ 1323 "SELECT ", 1324 // INSERT is handled below 1325 "UPDATE ", 1326 "DELETE ", 1327 "FETCH ", 1328 "MOVE ", 1329 "COPY ", 1330 } 1331 1332 var affectedRows *string 1333 for _, tag := range commandsWithAffectedRows { 1334 if strings.HasPrefix(commandTag, tag) { 1335 t := commandTag[len(tag):] 1336 affectedRows = &t 1337 commandTag = tag[:len(tag)-1] 1338 break 1339 } 1340 } 1341 // INSERT also includes the oid of the inserted row in its command tag. 1342 // Oids in user tables are deprecated, and the oid is only returned when 1343 // exactly one row is inserted, so it's unlikely to be of value to any 1344 // real-world application and we can ignore it. 1345 if affectedRows == nil && strings.HasPrefix(commandTag, "INSERT ") { 1346 parts := strings.Split(commandTag, " ") 1347 if len(parts) != 3 { 1348 cn.bad = true 1349 errorf("unexpected INSERT command tag %s", commandTag) 1350 } 1351 affectedRows = &parts[len(parts)-1] 1352 commandTag = "INSERT" 1353 } 1354 // There should be no affected rows attached to the tag, just return it 1355 if affectedRows == nil { 1356 return driver.RowsAffected(0), commandTag 1357 } 1358 n, err := strconv.ParseInt(*affectedRows, 10, 64) 1359 if err != nil { 1360 cn.bad = true 1361 errorf("could not parse commandTag: %s", err) 1362 } 1363 return driver.RowsAffected(n), commandTag 1364} 1365 1366type rowsHeader struct { 1367 colNames []string 1368 colTyps []fieldDesc 1369 colFmts []format 1370} 1371 1372type rows struct { 1373 cn *conn 1374 finish func() 1375 rowsHeader 1376 done bool 1377 rb readBuf 1378 result driver.Result 1379 tag string 1380 1381 next *rowsHeader 1382} 1383 1384func (rs *rows) Close() error { 1385 if finish := rs.finish; finish != nil { 1386 defer finish() 1387 } 1388 // no need to look at cn.bad as Next() will 1389 for { 1390 err := rs.Next(nil) 1391 switch err { 1392 case nil: 1393 case io.EOF: 
1394 // rs.Next can return io.EOF on both 'Z' (ready for query) and 'T' (row 1395 // description, used with HasNextResultSet). We need to fetch messages until 1396 // we hit a 'Z', which is done by waiting for done to be set. 1397 if rs.done { 1398 return nil 1399 } 1400 default: 1401 return err 1402 } 1403 } 1404} 1405 1406func (rs *rows) Columns() []string { 1407 return rs.colNames 1408} 1409 1410func (rs *rows) Result() driver.Result { 1411 if rs.result == nil { 1412 return emptyRows 1413 } 1414 return rs.result 1415} 1416 1417func (rs *rows) Tag() string { 1418 return rs.tag 1419} 1420 1421func (rs *rows) Next(dest []driver.Value) (err error) { 1422 if rs.done { 1423 return io.EOF 1424 } 1425 1426 conn := rs.cn 1427 if conn.bad { 1428 return driver.ErrBadConn 1429 } 1430 defer conn.errRecover(&err) 1431 1432 for { 1433 t := conn.recv1Buf(&rs.rb) 1434 switch t { 1435 case 'E': 1436 err = parseError(&rs.rb) 1437 case 'C', 'I': 1438 if t == 'C' { 1439 rs.result, rs.tag = conn.parseComplete(rs.rb.string()) 1440 } 1441 continue 1442 case 'Z': 1443 conn.processReadyForQuery(&rs.rb) 1444 rs.done = true 1445 if err != nil { 1446 return err 1447 } 1448 return io.EOF 1449 case 'D': 1450 n := rs.rb.int16() 1451 if err != nil { 1452 conn.bad = true 1453 errorf("unexpected DataRow after error %s", err) 1454 } 1455 if n < len(dest) { 1456 dest = dest[:n] 1457 } 1458 for i := range dest { 1459 l := rs.rb.int32() 1460 if l == -1 { 1461 dest[i] = nil 1462 continue 1463 } 1464 dest[i] = decode(&conn.parameterStatus, rs.rb.next(l), rs.colTyps[i].OID, rs.colFmts[i]) 1465 } 1466 return 1467 case 'T': 1468 next := parsePortalRowDescribe(&rs.rb) 1469 rs.next = &next 1470 return io.EOF 1471 default: 1472 errorf("unexpected message after execute: %q", t) 1473 } 1474 } 1475} 1476 1477func (rs *rows) HasNextResultSet() bool { 1478 hasNext := rs.next != nil && !rs.done 1479 return hasNext 1480} 1481 1482func (rs *rows) NextResultSet() error { 1483 if rs.next == nil { 1484 return io.EOF 
1485 } 1486 rs.rowsHeader = *rs.next 1487 rs.next = nil 1488 return nil 1489} 1490 1491// QuoteIdentifier quotes an "identifier" (e.g. a table or a column name) to be 1492// used as part of an SQL statement. For example: 1493// 1494// tblname := "my_table" 1495// data := "my_data" 1496// quoted := pq.QuoteIdentifier(tblname) 1497// err := db.Exec(fmt.Sprintf("INSERT INTO %s VALUES ($1)", quoted), data) 1498// 1499// Any double quotes in name will be escaped. The quoted identifier will be 1500// case sensitive when used in a query. If the input string contains a zero 1501// byte, the result will be truncated immediately before it. 1502func QuoteIdentifier(name string) string { 1503 end := strings.IndexRune(name, 0) 1504 if end > -1 { 1505 name = name[:end] 1506 } 1507 return `"` + strings.Replace(name, `"`, `""`, -1) + `"` 1508} 1509 1510// QuoteLiteral quotes a 'literal' (e.g. a parameter, often used to pass literal 1511// to DDL and other statements that do not accept parameters) to be used as part 1512// of an SQL statement. For example: 1513// 1514// exp_date := pq.QuoteLiteral("2023-01-05 15:00:00Z") 1515// err := db.Exec(fmt.Sprintf("CREATE ROLE my_user VALID UNTIL %s", exp_date)) 1516// 1517// Any single quotes in name will be escaped. Any backslashes (i.e. "\") will be 1518// replaced by two backslashes (i.e. "\\") and the C-style escape identifier 1519// that PostgreSQL provides ('E') will be prepended to the string. 
1520func QuoteLiteral(literal string) string { 1521 // This follows the PostgreSQL internal algorithm for handling quoted literals 1522 // from libpq, which can be found in the "PQEscapeStringInternal" function, 1523 // which is found in the libpq/fe-exec.c source file: 1524 // https://git.postgresql.org/gitweb/?p=postgresql.git;a=blob;f=src/interfaces/libpq/fe-exec.c 1525 // 1526 // substitute any single-quotes (') with two single-quotes ('') 1527 literal = strings.Replace(literal, `'`, `''`, -1) 1528 // determine if the string has any backslashes (\) in it. 1529 // if it does, replace any backslashes (\) with two backslashes (\\) 1530 // then, we need to wrap the entire string with a PostgreSQL 1531 // C-style escape. Per how "PQEscapeStringInternal" handles this case, we 1532 // also add a space before the "E" 1533 if strings.Contains(literal, `\`) { 1534 literal = strings.Replace(literal, `\`, `\\`, -1) 1535 literal = ` E'` + literal + `'` 1536 } else { 1537 // otherwise, we can just wrap the literal with a pair of single quotes 1538 literal = `'` + literal + `'` 1539 } 1540 return literal 1541} 1542 1543func md5s(s string) string { 1544 h := md5.New() 1545 h.Write([]byte(s)) 1546 return fmt.Sprintf("%x", h.Sum(nil)) 1547} 1548 1549func (cn *conn) sendBinaryParameters(b *writeBuf, args []driver.Value) { 1550 // Do one pass over the parameters to see if we're going to send any of 1551 // them over in binary. If we are, create a paramFormats array at the 1552 // same time. 
1553 var paramFormats []int 1554 for i, x := range args { 1555 _, ok := x.([]byte) 1556 if ok { 1557 if paramFormats == nil { 1558 paramFormats = make([]int, len(args)) 1559 } 1560 paramFormats[i] = 1 1561 } 1562 } 1563 if paramFormats == nil { 1564 b.int16(0) 1565 } else { 1566 b.int16(len(paramFormats)) 1567 for _, x := range paramFormats { 1568 b.int16(x) 1569 } 1570 } 1571 1572 b.int16(len(args)) 1573 for _, x := range args { 1574 if x == nil { 1575 b.int32(-1) 1576 } else { 1577 datum := binaryEncode(&cn.parameterStatus, x) 1578 b.int32(len(datum)) 1579 b.bytes(datum) 1580 } 1581 } 1582} 1583 1584func (cn *conn) sendBinaryModeQuery(query string, args []driver.Value) { 1585 if len(args) >= 65536 { 1586 errorf("got %d parameters but PostgreSQL only supports 65535 parameters", len(args)) 1587 } 1588 1589 b := cn.writeBuf('P') 1590 b.byte(0) // unnamed statement 1591 b.string(query) 1592 b.int16(0) 1593 1594 b.next('B') 1595 b.int16(0) // unnamed portal and statement 1596 cn.sendBinaryParameters(b, args) 1597 b.bytes(colFmtDataAllText) 1598 1599 b.next('D') 1600 b.byte('P') 1601 b.byte(0) // unnamed portal 1602 1603 b.next('E') 1604 b.byte(0) 1605 b.int32(0) 1606 1607 b.next('S') 1608 cn.send(b) 1609} 1610 1611func (cn *conn) processParameterStatus(r *readBuf) { 1612 var err error 1613 1614 param := r.string() 1615 switch param { 1616 case "server_version": 1617 var major1 int 1618 var major2 int 1619 var minor int 1620 _, err = fmt.Sscanf(r.string(), "%d.%d.%d", &major1, &major2, &minor) 1621 if err == nil { 1622 cn.parameterStatus.serverVersion = major1*10000 + major2*100 + minor 1623 } 1624 1625 case "TimeZone": 1626 cn.parameterStatus.currentLocation, err = time.LoadLocation(r.string()) 1627 if err != nil { 1628 cn.parameterStatus.currentLocation = nil 1629 } 1630 1631 default: 1632 // ignore 1633 } 1634} 1635 1636func (cn *conn) processReadyForQuery(r *readBuf) { 1637 cn.txnStatus = transactionStatus(r.byte()) 1638} 1639 1640func (cn *conn) 
readReadyForQuery() { 1641 t, r := cn.recv1() 1642 switch t { 1643 case 'Z': 1644 cn.processReadyForQuery(r) 1645 return 1646 default: 1647 cn.bad = true 1648 errorf("unexpected message %q; expected ReadyForQuery", t) 1649 } 1650} 1651 1652func (cn *conn) processBackendKeyData(r *readBuf) { 1653 cn.processID = r.int32() 1654 cn.secretKey = r.int32() 1655} 1656 1657func (cn *conn) readParseResponse() { 1658 t, r := cn.recv1() 1659 switch t { 1660 case '1': 1661 return 1662 case 'E': 1663 err := parseError(r) 1664 cn.readReadyForQuery() 1665 panic(err) 1666 default: 1667 cn.bad = true 1668 errorf("unexpected Parse response %q", t) 1669 } 1670} 1671 1672func (cn *conn) readStatementDescribeResponse() (paramTyps []oid.Oid, colNames []string, colTyps []fieldDesc) { 1673 for { 1674 t, r := cn.recv1() 1675 switch t { 1676 case 't': 1677 nparams := r.int16() 1678 paramTyps = make([]oid.Oid, nparams) 1679 for i := range paramTyps { 1680 paramTyps[i] = r.oid() 1681 } 1682 case 'n': 1683 return paramTyps, nil, nil 1684 case 'T': 1685 colNames, colTyps = parseStatementRowDescribe(r) 1686 return paramTyps, colNames, colTyps 1687 case 'E': 1688 err := parseError(r) 1689 cn.readReadyForQuery() 1690 panic(err) 1691 default: 1692 cn.bad = true 1693 errorf("unexpected Describe statement response %q", t) 1694 } 1695 } 1696} 1697 1698func (cn *conn) readPortalDescribeResponse() rowsHeader { 1699 t, r := cn.recv1() 1700 switch t { 1701 case 'T': 1702 return parsePortalRowDescribe(r) 1703 case 'n': 1704 return rowsHeader{} 1705 case 'E': 1706 err := parseError(r) 1707 cn.readReadyForQuery() 1708 panic(err) 1709 default: 1710 cn.bad = true 1711 errorf("unexpected Describe response %q", t) 1712 } 1713 panic("not reached") 1714} 1715 1716func (cn *conn) readBindResponse() { 1717 t, r := cn.recv1() 1718 switch t { 1719 case '2': 1720 return 1721 case 'E': 1722 err := parseError(r) 1723 cn.readReadyForQuery() 1724 panic(err) 1725 default: 1726 cn.bad = true 1727 errorf("unexpected Bind 
response %q", t) 1728 } 1729} 1730 1731func (cn *conn) postExecuteWorkaround() { 1732 // Work around a bug in sql.DB.QueryRow: in Go 1.2 and earlier it ignores 1733 // any errors from rows.Next, which masks errors that happened during the 1734 // execution of the query. To avoid the problem in common cases, we wait 1735 // here for one more message from the database. If it's not an error the 1736 // query will likely succeed (or perhaps has already, if it's a 1737 // CommandComplete), so we push the message into the conn struct; recv1 1738 // will return it as the next message for rows.Next or rows.Close. 1739 // However, if it's an error, we wait until ReadyForQuery and then return 1740 // the error to our caller. 1741 for { 1742 t, r := cn.recv1() 1743 switch t { 1744 case 'E': 1745 err := parseError(r) 1746 cn.readReadyForQuery() 1747 panic(err) 1748 case 'C', 'D', 'I': 1749 // the query didn't fail, but we can't process this message 1750 cn.saveMessage(t, r) 1751 return 1752 default: 1753 cn.bad = true 1754 errorf("unexpected message during extended query execution: %q", t) 1755 } 1756 } 1757} 1758 1759// Only for Exec(), since we ignore the returned data 1760func (cn *conn) readExecuteResponse(protocolState string) (res driver.Result, commandTag string, err error) { 1761 for { 1762 t, r := cn.recv1() 1763 switch t { 1764 case 'C': 1765 if err != nil { 1766 cn.bad = true 1767 errorf("unexpected CommandComplete after error %s", err) 1768 } 1769 res, commandTag = cn.parseComplete(r.string()) 1770 case 'Z': 1771 cn.processReadyForQuery(r) 1772 if res == nil && err == nil { 1773 err = errUnexpectedReady 1774 } 1775 return res, commandTag, err 1776 case 'E': 1777 err = parseError(r) 1778 case 'T', 'D', 'I': 1779 if err != nil { 1780 cn.bad = true 1781 errorf("unexpected %q after error %s", t, err) 1782 } 1783 if t == 'I' { 1784 res = emptyRows 1785 } 1786 // ignore any results 1787 default: 1788 cn.bad = true 1789 errorf("unknown %s response: %q", protocolState, t) 
1790 } 1791 } 1792} 1793 1794func parseStatementRowDescribe(r *readBuf) (colNames []string, colTyps []fieldDesc) { 1795 n := r.int16() 1796 colNames = make([]string, n) 1797 colTyps = make([]fieldDesc, n) 1798 for i := range colNames { 1799 colNames[i] = r.string() 1800 r.next(6) 1801 colTyps[i].OID = r.oid() 1802 colTyps[i].Len = r.int16() 1803 colTyps[i].Mod = r.int32() 1804 // format code not known when describing a statement; always 0 1805 r.next(2) 1806 } 1807 return 1808} 1809 1810func parsePortalRowDescribe(r *readBuf) rowsHeader { 1811 n := r.int16() 1812 colNames := make([]string, n) 1813 colFmts := make([]format, n) 1814 colTyps := make([]fieldDesc, n) 1815 for i := range colNames { 1816 colNames[i] = r.string() 1817 r.next(6) 1818 colTyps[i].OID = r.oid() 1819 colTyps[i].Len = r.int16() 1820 colTyps[i].Mod = r.int32() 1821 colFmts[i] = format(r.int16()) 1822 } 1823 return rowsHeader{ 1824 colNames: colNames, 1825 colFmts: colFmts, 1826 colTyps: colTyps, 1827 } 1828} 1829 1830// parseEnviron tries to mimic some of libpq's environment handling 1831// 1832// To ease testing, it does not directly reference os.Environ, but is 1833// designed to accept its output. 1834// 1835// Environment-set connection information is intended to have a higher 1836// precedence than a library default but lower than any explicitly 1837// passed information (such as in the URL or connection string). 1838func parseEnviron(env []string) (out map[string]string) { 1839 out = make(map[string]string) 1840 1841 for _, v := range env { 1842 parts := strings.SplitN(v, "=", 2) 1843 1844 accrue := func(keyname string) { 1845 out[keyname] = parts[1] 1846 } 1847 unsupported := func() { 1848 panic(fmt.Sprintf("setting %v not supported", parts[0])) 1849 } 1850 1851 // The order of these is the same as is seen in the 1852 // PostgreSQL 9.1 manual. Unsupported but well-defined 1853 // keys cause a panic; these should be unset prior to 1854 // execution. 
Options which pq expects to be set to a 1855 // certain value are allowed, but must be set to that 1856 // value if present (they can, of course, be absent). 1857 switch parts[0] { 1858 case "PGHOST": 1859 accrue("host") 1860 case "PGHOSTADDR": 1861 unsupported() 1862 case "PGPORT": 1863 accrue("port") 1864 case "PGDATABASE": 1865 accrue("dbname") 1866 case "PGUSER": 1867 accrue("user") 1868 case "PGPASSWORD": 1869 accrue("password") 1870 case "PGSERVICE", "PGSERVICEFILE", "PGREALM": 1871 unsupported() 1872 case "PGOPTIONS": 1873 accrue("options") 1874 case "PGAPPNAME": 1875 accrue("application_name") 1876 case "PGSSLMODE": 1877 accrue("sslmode") 1878 case "PGSSLCERT": 1879 accrue("sslcert") 1880 case "PGSSLKEY": 1881 accrue("sslkey") 1882 case "PGSSLROOTCERT": 1883 accrue("sslrootcert") 1884 case "PGREQUIRESSL", "PGSSLCRL": 1885 unsupported() 1886 case "PGREQUIREPEER": 1887 unsupported() 1888 case "PGKRBSRVNAME", "PGGSSLIB": 1889 unsupported() 1890 case "PGCONNECT_TIMEOUT": 1891 accrue("connect_timeout") 1892 case "PGCLIENTENCODING": 1893 accrue("client_encoding") 1894 case "PGDATESTYLE": 1895 accrue("datestyle") 1896 case "PGTZ": 1897 accrue("timezone") 1898 case "PGGEQO": 1899 accrue("geqo") 1900 case "PGSYSCONFDIR", "PGLOCALEDIR": 1901 unsupported() 1902 } 1903 } 1904 1905 return out 1906} 1907 1908// isUTF8 returns whether name is a fuzzy variation of the string "UTF-8". 1909func isUTF8(name string) bool { 1910 // Recognize all sorts of silly things as "UTF-8", like Postgres does 1911 s := strings.Map(alnumLowerASCII, name) 1912 return s == "utf8" || s == "unicode" 1913} 1914 1915func alnumLowerASCII(ch rune) rune { 1916 if 'A' <= ch && ch <= 'Z' { 1917 return ch + ('a' - 'A') 1918 } 1919 if 'a' <= ch && ch <= 'z' || '0' <= ch && ch <= '9' { 1920 return ch 1921 } 1922 return -1 // discard 1923} 1924