package pq

import (
	"bufio"
	"context"
	"crypto/md5"
	"crypto/sha256"
	"database/sql"
	"database/sql/driver"
	"encoding/binary"
	"errors"
	"fmt"
	"io"
	"net"
	"os"
	"os/user"
	"path"
	"path/filepath"
	"strconv"
	"strings"
	"sync/atomic"
	"time"
	"unicode"

	"github.com/lib/pq/oid"
	"github.com/lib/pq/scram"
)

// Common error types.
//
// The exported errors may be returned to callers of this package; for
// example, Commit returns ErrInFailedTransaction when asked to commit a
// transaction the server reports as failed, and the SSL negotiation returns
// ErrSSLNotSupported when the server declines the SSL request. The
// unexported errors back the noRows result (returned for empty statements)
// and the "ReadyForQuery without a result" case in simpleExec.
var (
	ErrNotSupported              = errors.New("pq: Unsupported command")
	ErrInFailedTransaction       = errors.New("pq: Could not complete operation in a failed transaction")
	ErrSSLNotSupported           = errors.New("pq: SSL is not enabled on the server")
	ErrSSLKeyHasWorldPermissions = errors.New("pq: Private key file has group or world access. Permissions should be u=rw (0600) or less")
	ErrCouldNotDetectUsername    = errors.New("pq: Could not detect default username. Please provide one explicitly")

	errUnexpectedReady = errors.New("unexpected ReadyForQuery")
	errNoRowsAffected  = errors.New("no RowsAffected available after the empty statement")
	errNoLastInsertID  = errors.New("no LastInsertId available after the empty statement")
)

// Compile time validation that our types implement the expected interfaces
var (
	_ driver.Driver = Driver{}
)

// Driver is the Postgres database driver.
type Driver struct{}

// Open opens a new connection to the database. name is a connection string.
// Most users should only use it through database/sql package from the standard
// library.
53func (d Driver) Open(name string) (driver.Conn, error) { 54 return Open(name) 55} 56 57func init() { 58 sql.Register("postgres", &Driver{}) 59} 60 61type parameterStatus struct { 62 // server version in the same format as server_version_num, or 0 if 63 // unavailable 64 serverVersion int 65 66 // the current location based on the TimeZone value of the session, if 67 // available 68 currentLocation *time.Location 69} 70 71type transactionStatus byte 72 73const ( 74 txnStatusIdle transactionStatus = 'I' 75 txnStatusIdleInTransaction transactionStatus = 'T' 76 txnStatusInFailedTransaction transactionStatus = 'E' 77) 78 79func (s transactionStatus) String() string { 80 switch s { 81 case txnStatusIdle: 82 return "idle" 83 case txnStatusIdleInTransaction: 84 return "idle in transaction" 85 case txnStatusInFailedTransaction: 86 return "in a failed transaction" 87 default: 88 errorf("unknown transactionStatus %d", s) 89 } 90 91 panic("not reached") 92} 93 94// Dialer is the dialer interface. It can be used to obtain more control over 95// how pq creates network connections. 96type Dialer interface { 97 Dial(network, address string) (net.Conn, error) 98 DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) 99} 100 101// DialerContext is the context-aware dialer interface. 
102type DialerContext interface { 103 DialContext(ctx context.Context, network, address string) (net.Conn, error) 104} 105 106type defaultDialer struct { 107 d net.Dialer 108} 109 110func (d defaultDialer) Dial(network, address string) (net.Conn, error) { 111 return d.d.Dial(network, address) 112} 113func (d defaultDialer) DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) { 114 ctx, cancel := context.WithTimeout(context.Background(), timeout) 115 defer cancel() 116 return d.DialContext(ctx, network, address) 117} 118func (d defaultDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { 119 return d.d.DialContext(ctx, network, address) 120} 121 122type conn struct { 123 c net.Conn 124 buf *bufio.Reader 125 namei int 126 scratch [512]byte 127 txnStatus transactionStatus 128 txnFinish func() 129 130 // Save connection arguments to use during CancelRequest. 131 dialer Dialer 132 opts values 133 134 // Cancellation key data for use with CancelRequest messages. 135 processID int 136 secretKey int 137 138 parameterStatus parameterStatus 139 140 saveMessageType byte 141 saveMessageBuffer []byte 142 143 // If true, this connection is bad and all public-facing functions should 144 // return ErrBadConn. 145 bad *atomic.Value 146 147 // If set, this connection should never use the binary format when 148 // receiving query results from prepared statements. Only provided for 149 // debugging. 150 disablePreparedBinaryResult bool 151 152 // Whether to always send []byte parameters over as binary. Enables single 153 // round-trip mode for non-prepared Query calls. 
154 binaryParameters bool 155 156 // If true this connection is in the middle of a COPY 157 inCopy bool 158 159 // If not nil, notices will be synchronously sent here 160 noticeHandler func(*Error) 161 162 // If not nil, notifications will be synchronously sent here 163 notificationHandler func(*Notification) 164 165 // GSSAPI context 166 gss GSS 167} 168 169// Handle driver-side settings in parsed connection string. 170func (cn *conn) handleDriverSettings(o values) (err error) { 171 boolSetting := func(key string, val *bool) error { 172 if value, ok := o[key]; ok { 173 if value == "yes" { 174 *val = true 175 } else if value == "no" { 176 *val = false 177 } else { 178 return fmt.Errorf("unrecognized value %q for %s", value, key) 179 } 180 } 181 return nil 182 } 183 184 err = boolSetting("disable_prepared_binary_result", &cn.disablePreparedBinaryResult) 185 if err != nil { 186 return err 187 } 188 return boolSetting("binary_parameters", &cn.binaryParameters) 189} 190 191func (cn *conn) handlePgpass(o values) { 192 // if a password was supplied, do not process .pgpass 193 if _, ok := o["password"]; ok { 194 return 195 } 196 filename := os.Getenv("PGPASSFILE") 197 if filename == "" { 198 // XXX this code doesn't work on Windows where the default filename is 199 // XXX %APPDATA%\postgresql\pgpass.conf 200 // Prefer $HOME over user.Current due to glibc bug: golang.org/issue/13470 201 userHome := os.Getenv("HOME") 202 if userHome == "" { 203 user, err := user.Current() 204 if err != nil { 205 return 206 } 207 userHome = user.HomeDir 208 } 209 filename = filepath.Join(userHome, ".pgpass") 210 } 211 fileinfo, err := os.Stat(filename) 212 if err != nil { 213 return 214 } 215 mode := fileinfo.Mode() 216 if mode&(0x77) != 0 { 217 // XXX should warn about incorrect .pgpass permissions as psql does 218 return 219 } 220 file, err := os.Open(filename) 221 if err != nil { 222 return 223 } 224 defer file.Close() 225 scanner := bufio.NewScanner(io.Reader(file)) 226 hostname := 
o["host"] 227 ntw, _ := network(o) 228 port := o["port"] 229 db := o["dbname"] 230 username := o["user"] 231 // From: https://github.com/tg/pgpass/blob/master/reader.go 232 getFields := func(s string) []string { 233 fs := make([]string, 0, 5) 234 f := make([]rune, 0, len(s)) 235 236 var esc bool 237 for _, c := range s { 238 switch { 239 case esc: 240 f = append(f, c) 241 esc = false 242 case c == '\\': 243 esc = true 244 case c == ':': 245 fs = append(fs, string(f)) 246 f = f[:0] 247 default: 248 f = append(f, c) 249 } 250 } 251 return append(fs, string(f)) 252 } 253 for scanner.Scan() { 254 line := scanner.Text() 255 if len(line) == 0 || line[0] == '#' { 256 continue 257 } 258 split := getFields(line) 259 if len(split) != 5 { 260 continue 261 } 262 if (split[0] == "*" || split[0] == hostname || (split[0] == "localhost" && (hostname == "" || ntw == "unix"))) && (split[1] == "*" || split[1] == port) && (split[2] == "*" || split[2] == db) && (split[3] == "*" || split[3] == username) { 263 o["password"] = split[4] 264 return 265 } 266 } 267} 268 269func (cn *conn) writeBuf(b byte) *writeBuf { 270 cn.scratch[0] = b 271 return &writeBuf{ 272 buf: cn.scratch[:5], 273 pos: 1, 274 } 275} 276 277// Open opens a new connection to the database. dsn is a connection string. 278// Most users should only use it through database/sql package from the standard 279// library. 280func Open(dsn string) (_ driver.Conn, err error) { 281 return DialOpen(defaultDialer{}, dsn) 282} 283 284// DialOpen opens a new connection to the database using a dialer. 285func DialOpen(d Dialer, dsn string) (_ driver.Conn, err error) { 286 c, err := NewConnector(dsn) 287 if err != nil { 288 return nil, err 289 } 290 c.dialer = d 291 return c.open(context.Background()) 292} 293 294func (c *Connector) open(ctx context.Context) (cn *conn, err error) { 295 // Handle any panics during connection initialization. 
Note that we 296 // specifically do *not* want to use errRecover(), as that would turn any 297 // connection errors into ErrBadConns, hiding the real error message from 298 // the user. 299 defer errRecoverNoErrBadConn(&err) 300 301 // Create a new values map (copy). This makes it so maps in different 302 // connections do not reference the same underlying data structure, so it 303 // is safe for multiple connections to concurrently write to their opts. 304 o := make(values) 305 for k, v := range c.opts { 306 o[k] = v 307 } 308 309 bad := &atomic.Value{} 310 bad.Store(false) 311 cn = &conn{ 312 opts: o, 313 dialer: c.dialer, 314 bad: bad, 315 } 316 err = cn.handleDriverSettings(o) 317 if err != nil { 318 return nil, err 319 } 320 cn.handlePgpass(o) 321 322 cn.c, err = dial(ctx, c.dialer, o) 323 if err != nil { 324 return nil, err 325 } 326 327 err = cn.ssl(o) 328 if err != nil { 329 if cn.c != nil { 330 cn.c.Close() 331 } 332 return nil, err 333 } 334 335 // cn.startup panics on error. Make sure we don't leak cn.c. 336 panicking := true 337 defer func() { 338 if panicking { 339 cn.c.Close() 340 } 341 }() 342 343 cn.buf = bufio.NewReader(cn.c) 344 cn.startup(o) 345 346 // reset the deadline, in case one was set (see dial) 347 if timeout, ok := o["connect_timeout"]; ok && timeout != "0" { 348 err = cn.c.SetDeadline(time.Time{}) 349 } 350 panicking = false 351 return cn, err 352} 353 354func dial(ctx context.Context, d Dialer, o values) (net.Conn, error) { 355 network, address := network(o) 356 357 // Zero or not specified means wait indefinitely. 
358 if timeout, ok := o["connect_timeout"]; ok && timeout != "0" { 359 seconds, err := strconv.ParseInt(timeout, 10, 0) 360 if err != nil { 361 return nil, fmt.Errorf("invalid value for parameter connect_timeout: %s", err) 362 } 363 duration := time.Duration(seconds) * time.Second 364 365 // connect_timeout should apply to the entire connection establishment 366 // procedure, so we both use a timeout for the TCP connection 367 // establishment and set a deadline for doing the initial handshake. 368 // The deadline is then reset after startup() is done. 369 deadline := time.Now().Add(duration) 370 var conn net.Conn 371 if dctx, ok := d.(DialerContext); ok { 372 ctx, cancel := context.WithTimeout(ctx, duration) 373 defer cancel() 374 conn, err = dctx.DialContext(ctx, network, address) 375 } else { 376 conn, err = d.DialTimeout(network, address, duration) 377 } 378 if err != nil { 379 return nil, err 380 } 381 err = conn.SetDeadline(deadline) 382 return conn, err 383 } 384 if dctx, ok := d.(DialerContext); ok { 385 return dctx.DialContext(ctx, network, address) 386 } 387 return d.Dial(network, address) 388} 389 390func network(o values) (string, string) { 391 host := o["host"] 392 393 if strings.HasPrefix(host, "/") { 394 sockPath := path.Join(host, ".s.PGSQL."+o["port"]) 395 return "unix", sockPath 396 } 397 398 return "tcp", net.JoinHostPort(host, o["port"]) 399} 400 401type values map[string]string 402 403// scanner implements a tokenizer for libpq-style option strings. 404type scanner struct { 405 s []rune 406 i int 407} 408 409// newScanner returns a new scanner initialized with the option string s. 410func newScanner(s string) *scanner { 411 return &scanner{[]rune(s), 0} 412} 413 414// Next returns the next rune. 415// It returns 0, false if the end of the text has been reached. 
416func (s *scanner) Next() (rune, bool) { 417 if s.i >= len(s.s) { 418 return 0, false 419 } 420 r := s.s[s.i] 421 s.i++ 422 return r, true 423} 424 425// SkipSpaces returns the next non-whitespace rune. 426// It returns 0, false if the end of the text has been reached. 427func (s *scanner) SkipSpaces() (rune, bool) { 428 r, ok := s.Next() 429 for unicode.IsSpace(r) && ok { 430 r, ok = s.Next() 431 } 432 return r, ok 433} 434 435// parseOpts parses the options from name and adds them to the values. 436// 437// The parsing code is based on conninfo_parse from libpq's fe-connect.c 438func parseOpts(name string, o values) error { 439 s := newScanner(name) 440 441 for { 442 var ( 443 keyRunes, valRunes []rune 444 r rune 445 ok bool 446 ) 447 448 if r, ok = s.SkipSpaces(); !ok { 449 break 450 } 451 452 // Scan the key 453 for !unicode.IsSpace(r) && r != '=' { 454 keyRunes = append(keyRunes, r) 455 if r, ok = s.Next(); !ok { 456 break 457 } 458 } 459 460 // Skip any whitespace if we're not at the = yet 461 if r != '=' { 462 r, ok = s.SkipSpaces() 463 } 464 465 // The current character should be = 466 if r != '=' || !ok { 467 return fmt.Errorf(`missing "=" after %q in connection info string"`, string(keyRunes)) 468 } 469 470 // Skip any whitespace after the = 471 if r, ok = s.SkipSpaces(); !ok { 472 // If we reach the end here, the last value is just an empty string as per libpq. 
473 o[string(keyRunes)] = "" 474 break 475 } 476 477 if r != '\'' { 478 for !unicode.IsSpace(r) { 479 if r == '\\' { 480 if r, ok = s.Next(); !ok { 481 return fmt.Errorf(`missing character after backslash`) 482 } 483 } 484 valRunes = append(valRunes, r) 485 486 if r, ok = s.Next(); !ok { 487 break 488 } 489 } 490 } else { 491 quote: 492 for { 493 if r, ok = s.Next(); !ok { 494 return fmt.Errorf(`unterminated quoted string literal in connection string`) 495 } 496 switch r { 497 case '\'': 498 break quote 499 case '\\': 500 r, _ = s.Next() 501 fallthrough 502 default: 503 valRunes = append(valRunes, r) 504 } 505 } 506 } 507 508 o[string(keyRunes)] = string(valRunes) 509 } 510 511 return nil 512} 513 514func (cn *conn) isInTransaction() bool { 515 return cn.txnStatus == txnStatusIdleInTransaction || 516 cn.txnStatus == txnStatusInFailedTransaction 517} 518 519func (cn *conn) setBad() { 520 if cn.bad != nil { 521 cn.bad.Store(true) 522 } 523} 524 525func (cn *conn) getBad() bool { 526 if cn.bad != nil { 527 return cn.bad.Load().(bool) 528 } 529 return false 530} 531 532func (cn *conn) checkIsInTransaction(intxn bool) { 533 if cn.isInTransaction() != intxn { 534 cn.setBad() 535 errorf("unexpected transaction status %v", cn.txnStatus) 536 } 537} 538 539func (cn *conn) Begin() (_ driver.Tx, err error) { 540 return cn.begin("") 541} 542 543func (cn *conn) begin(mode string) (_ driver.Tx, err error) { 544 if cn.getBad() { 545 return nil, driver.ErrBadConn 546 } 547 defer cn.errRecover(&err) 548 549 cn.checkIsInTransaction(false) 550 _, commandTag, err := cn.simpleExec("BEGIN" + mode) 551 if err != nil { 552 return nil, err 553 } 554 if commandTag != "BEGIN" { 555 cn.setBad() 556 return nil, fmt.Errorf("unexpected command tag %s", commandTag) 557 } 558 if cn.txnStatus != txnStatusIdleInTransaction { 559 cn.setBad() 560 return nil, fmt.Errorf("unexpected transaction status %v", cn.txnStatus) 561 } 562 return cn, nil 563} 564 565func (cn *conn) closeTxn() { 566 if finish := 
cn.txnFinish; finish != nil { 567 finish() 568 } 569} 570 571func (cn *conn) Commit() (err error) { 572 defer cn.closeTxn() 573 if cn.getBad() { 574 return driver.ErrBadConn 575 } 576 defer cn.errRecover(&err) 577 578 cn.checkIsInTransaction(true) 579 // We don't want the client to think that everything is okay if it tries 580 // to commit a failed transaction. However, no matter what we return, 581 // database/sql will release this connection back into the free connection 582 // pool so we have to abort the current transaction here. Note that you 583 // would get the same behaviour if you issued a COMMIT in a failed 584 // transaction, so it's also the least surprising thing to do here. 585 if cn.txnStatus == txnStatusInFailedTransaction { 586 if err := cn.rollback(); err != nil { 587 return err 588 } 589 return ErrInFailedTransaction 590 } 591 592 _, commandTag, err := cn.simpleExec("COMMIT") 593 if err != nil { 594 if cn.isInTransaction() { 595 cn.setBad() 596 } 597 return err 598 } 599 if commandTag != "COMMIT" { 600 cn.setBad() 601 return fmt.Errorf("unexpected command tag %s", commandTag) 602 } 603 cn.checkIsInTransaction(false) 604 return nil 605} 606 607func (cn *conn) Rollback() (err error) { 608 defer cn.closeTxn() 609 if cn.getBad() { 610 return driver.ErrBadConn 611 } 612 defer cn.errRecover(&err) 613 return cn.rollback() 614} 615 616func (cn *conn) rollback() (err error) { 617 cn.checkIsInTransaction(true) 618 _, commandTag, err := cn.simpleExec("ROLLBACK") 619 if err != nil { 620 if cn.isInTransaction() { 621 cn.setBad() 622 } 623 return err 624 } 625 if commandTag != "ROLLBACK" { 626 return fmt.Errorf("unexpected command tag %s", commandTag) 627 } 628 cn.checkIsInTransaction(false) 629 return nil 630} 631 632func (cn *conn) gname() string { 633 cn.namei++ 634 return strconv.FormatInt(int64(cn.namei), 10) 635} 636 637func (cn *conn) simpleExec(q string) (res driver.Result, commandTag string, err error) { 638 b := cn.writeBuf('Q') 639 b.string(q) 640 
cn.send(b) 641 642 for { 643 t, r := cn.recv1() 644 switch t { 645 case 'C': 646 res, commandTag = cn.parseComplete(r.string()) 647 case 'Z': 648 cn.processReadyForQuery(r) 649 if res == nil && err == nil { 650 err = errUnexpectedReady 651 } 652 // done 653 return 654 case 'E': 655 err = parseError(r) 656 case 'I': 657 res = emptyRows 658 case 'T', 'D': 659 // ignore any results 660 default: 661 cn.setBad() 662 errorf("unknown response for simple query: %q", t) 663 } 664 } 665} 666 667func (cn *conn) simpleQuery(q string) (res *rows, err error) { 668 defer cn.errRecover(&err) 669 670 b := cn.writeBuf('Q') 671 b.string(q) 672 cn.send(b) 673 674 for { 675 t, r := cn.recv1() 676 switch t { 677 case 'C', 'I': 678 // We allow queries which don't return any results through Query as 679 // well as Exec. We still have to give database/sql a rows object 680 // the user can close, though, to avoid connections from being 681 // leaked. A "rows" with done=true works fine for that purpose. 682 if err != nil { 683 cn.setBad() 684 errorf("unexpected message %q in simple query execution", t) 685 } 686 if res == nil { 687 res = &rows{ 688 cn: cn, 689 } 690 } 691 // Set the result and tag to the last command complete if there wasn't a 692 // query already run. Although queries usually return from here and cede 693 // control to Next, a query with zero results does not. 
694 if t == 'C' { 695 res.result, res.tag = cn.parseComplete(r.string()) 696 if res.colNames != nil { 697 return 698 } 699 } 700 res.done = true 701 case 'Z': 702 cn.processReadyForQuery(r) 703 // done 704 return 705 case 'E': 706 res = nil 707 err = parseError(r) 708 case 'D': 709 if res == nil { 710 cn.setBad() 711 errorf("unexpected DataRow in simple query execution") 712 } 713 // the query didn't fail; kick off to Next 714 cn.saveMessage(t, r) 715 return 716 case 'T': 717 // res might be non-nil here if we received a previous 718 // CommandComplete, but that's fine; just overwrite it 719 res = &rows{cn: cn} 720 res.rowsHeader = parsePortalRowDescribe(r) 721 722 // To work around a bug in QueryRow in Go 1.2 and earlier, wait 723 // until the first DataRow has been received. 724 default: 725 cn.setBad() 726 errorf("unknown response for simple query: %q", t) 727 } 728 } 729} 730 731type noRows struct{} 732 733var emptyRows noRows 734 735var _ driver.Result = noRows{} 736 737func (noRows) LastInsertId() (int64, error) { 738 return 0, errNoLastInsertID 739} 740 741func (noRows) RowsAffected() (int64, error) { 742 return 0, errNoRowsAffected 743} 744 745// Decides which column formats to use for a prepared statement. The input is 746// an array of type oids, one element per result column. 747func decideColumnFormats(colTyps []fieldDesc, forceText bool) (colFmts []format, colFmtData []byte) { 748 if len(colTyps) == 0 { 749 return nil, colFmtDataAllText 750 } 751 752 colFmts = make([]format, len(colTyps)) 753 if forceText { 754 return colFmts, colFmtDataAllText 755 } 756 757 allBinary := true 758 allText := true 759 for i, t := range colTyps { 760 switch t.OID { 761 // This is the list of types to use binary mode for when receiving them 762 // through a prepared statement. If a type appears in this list, it 763 // must also be implemented in binaryDecode in encode.go. 
764 case oid.T_bytea: 765 fallthrough 766 case oid.T_int8: 767 fallthrough 768 case oid.T_int4: 769 fallthrough 770 case oid.T_int2: 771 fallthrough 772 case oid.T_uuid: 773 colFmts[i] = formatBinary 774 allText = false 775 776 default: 777 allBinary = false 778 } 779 } 780 781 if allBinary { 782 return colFmts, colFmtDataAllBinary 783 } else if allText { 784 return colFmts, colFmtDataAllText 785 } else { 786 colFmtData = make([]byte, 2+len(colFmts)*2) 787 binary.BigEndian.PutUint16(colFmtData, uint16(len(colFmts))) 788 for i, v := range colFmts { 789 binary.BigEndian.PutUint16(colFmtData[2+i*2:], uint16(v)) 790 } 791 return colFmts, colFmtData 792 } 793} 794 795func (cn *conn) prepareTo(q, stmtName string) *stmt { 796 st := &stmt{cn: cn, name: stmtName} 797 798 b := cn.writeBuf('P') 799 b.string(st.name) 800 b.string(q) 801 b.int16(0) 802 803 b.next('D') 804 b.byte('S') 805 b.string(st.name) 806 807 b.next('S') 808 cn.send(b) 809 810 cn.readParseResponse() 811 st.paramTyps, st.colNames, st.colTyps = cn.readStatementDescribeResponse() 812 st.colFmts, st.colFmtData = decideColumnFormats(st.colTyps, cn.disablePreparedBinaryResult) 813 cn.readReadyForQuery() 814 return st 815} 816 817func (cn *conn) Prepare(q string) (_ driver.Stmt, err error) { 818 if cn.getBad() { 819 return nil, driver.ErrBadConn 820 } 821 defer cn.errRecover(&err) 822 823 if len(q) >= 4 && strings.EqualFold(q[:4], "COPY") { 824 s, err := cn.prepareCopyIn(q) 825 if err == nil { 826 cn.inCopy = true 827 } 828 return s, err 829 } 830 return cn.prepareTo(q, cn.gname()), nil 831} 832 833func (cn *conn) Close() (err error) { 834 // Skip cn.bad return here because we always want to close a connection. 835 defer cn.errRecover(&err) 836 837 // Ensure that cn.c.Close is always run. Since error handling is done with 838 // panics and cn.errRecover, the Close must be in a defer. 
839 defer func() { 840 cerr := cn.c.Close() 841 if err == nil { 842 err = cerr 843 } 844 }() 845 846 // Don't go through send(); ListenerConn relies on us not scribbling on the 847 // scratch buffer of this connection. 848 return cn.sendSimpleMessage('X') 849} 850 851// Implement the "Queryer" interface 852func (cn *conn) Query(query string, args []driver.Value) (driver.Rows, error) { 853 return cn.query(query, args) 854} 855 856func (cn *conn) query(query string, args []driver.Value) (_ *rows, err error) { 857 if cn.getBad() { 858 return nil, driver.ErrBadConn 859 } 860 if cn.inCopy { 861 return nil, errCopyInProgress 862 } 863 defer cn.errRecover(&err) 864 865 // Check to see if we can use the "simpleQuery" interface, which is 866 // *much* faster than going through prepare/exec 867 if len(args) == 0 { 868 return cn.simpleQuery(query) 869 } 870 871 if cn.binaryParameters { 872 cn.sendBinaryModeQuery(query, args) 873 874 cn.readParseResponse() 875 cn.readBindResponse() 876 rows := &rows{cn: cn} 877 rows.rowsHeader = cn.readPortalDescribeResponse() 878 cn.postExecuteWorkaround() 879 return rows, nil 880 } 881 st := cn.prepareTo(query, "") 882 st.exec(args) 883 return &rows{ 884 cn: cn, 885 rowsHeader: st.rowsHeader, 886 }, nil 887} 888 889// Implement the optional "Execer" interface for one-shot queries 890func (cn *conn) Exec(query string, args []driver.Value) (res driver.Result, err error) { 891 if cn.getBad() { 892 return nil, driver.ErrBadConn 893 } 894 defer cn.errRecover(&err) 895 896 // Check to see if we can use the "simpleExec" interface, which is 897 // *much* faster than going through prepare/exec 898 if len(args) == 0 { 899 // ignore commandTag, our caller doesn't care 900 r, _, err := cn.simpleExec(query) 901 return r, err 902 } 903 904 if cn.binaryParameters { 905 cn.sendBinaryModeQuery(query, args) 906 907 cn.readParseResponse() 908 cn.readBindResponse() 909 cn.readPortalDescribeResponse() 910 cn.postExecuteWorkaround() 911 res, _, err = 
cn.readExecuteResponse("Execute") 912 return res, err 913 } 914 // Use the unnamed statement to defer planning until bind 915 // time, or else value-based selectivity estimates cannot be 916 // used. 917 st := cn.prepareTo(query, "") 918 r, err := st.Exec(args) 919 if err != nil { 920 panic(err) 921 } 922 return r, err 923} 924 925type safeRetryError struct { 926 Err error 927} 928 929func (se *safeRetryError) Error() string { 930 return se.Err.Error() 931} 932 933func (cn *conn) send(m *writeBuf) { 934 n, err := cn.c.Write(m.wrap()) 935 if err != nil { 936 if n == 0 { 937 err = &safeRetryError{Err: err} 938 } 939 panic(err) 940 } 941} 942 943func (cn *conn) sendStartupPacket(m *writeBuf) error { 944 _, err := cn.c.Write((m.wrap())[1:]) 945 return err 946} 947 948// Send a message of type typ to the server on the other end of cn. The 949// message should have no payload. This method does not use the scratch 950// buffer. 951func (cn *conn) sendSimpleMessage(typ byte) (err error) { 952 _, err = cn.c.Write([]byte{typ, '\x00', '\x00', '\x00', '\x04'}) 953 return err 954} 955 956// saveMessage memorizes a message and its buffer in the conn struct. 957// recvMessage will then return these values on the next call to it. This 958// method is useful in cases where you have to see what the next message is 959// going to be (e.g. to see whether it's an error or not) but you can't handle 960// the message yourself. 961func (cn *conn) saveMessage(typ byte, buf *readBuf) { 962 if cn.saveMessageType != 0 { 963 cn.setBad() 964 errorf("unexpected saveMessageType %d", cn.saveMessageType) 965 } 966 cn.saveMessageType = typ 967 cn.saveMessageBuffer = *buf 968} 969 970// recvMessage receives any message from the backend, or returns an error if 971// a problem occurred while reading the message. 
972func (cn *conn) recvMessage(r *readBuf) (byte, error) { 973 // workaround for a QueryRow bug, see exec 974 if cn.saveMessageType != 0 { 975 t := cn.saveMessageType 976 *r = cn.saveMessageBuffer 977 cn.saveMessageType = 0 978 cn.saveMessageBuffer = nil 979 return t, nil 980 } 981 982 x := cn.scratch[:5] 983 _, err := io.ReadFull(cn.buf, x) 984 if err != nil { 985 return 0, err 986 } 987 988 // read the type and length of the message that follows 989 t := x[0] 990 n := int(binary.BigEndian.Uint32(x[1:])) - 4 991 var y []byte 992 if n <= len(cn.scratch) { 993 y = cn.scratch[:n] 994 } else { 995 y = make([]byte, n) 996 } 997 _, err = io.ReadFull(cn.buf, y) 998 if err != nil { 999 return 0, err 1000 } 1001 *r = y 1002 return t, nil 1003} 1004 1005// recv receives a message from the backend, but if an error happened while 1006// reading the message or the received message was an ErrorResponse, it panics. 1007// NoticeResponses are ignored. This function should generally be used only 1008// during the startup sequence. 1009func (cn *conn) recv() (t byte, r *readBuf) { 1010 for { 1011 var err error 1012 r = &readBuf{} 1013 t, err = cn.recvMessage(r) 1014 if err != nil { 1015 panic(err) 1016 } 1017 switch t { 1018 case 'E': 1019 panic(parseError(r)) 1020 case 'N': 1021 if n := cn.noticeHandler; n != nil { 1022 n(parseError(r)) 1023 } 1024 case 'A': 1025 if n := cn.notificationHandler; n != nil { 1026 n(recvNotification(r)) 1027 } 1028 default: 1029 return 1030 } 1031 } 1032} 1033 1034// recv1Buf is exactly equivalent to recv1, except it uses a buffer supplied by 1035// the caller to avoid an allocation. 
1036func (cn *conn) recv1Buf(r *readBuf) byte { 1037 for { 1038 t, err := cn.recvMessage(r) 1039 if err != nil { 1040 panic(err) 1041 } 1042 1043 switch t { 1044 case 'A': 1045 if n := cn.notificationHandler; n != nil { 1046 n(recvNotification(r)) 1047 } 1048 case 'N': 1049 if n := cn.noticeHandler; n != nil { 1050 n(parseError(r)) 1051 } 1052 case 'S': 1053 cn.processParameterStatus(r) 1054 default: 1055 return t 1056 } 1057 } 1058} 1059 1060// recv1 receives a message from the backend, panicking if an error occurs 1061// while attempting to read it. All asynchronous messages are ignored, with 1062// the exception of ErrorResponse. 1063func (cn *conn) recv1() (t byte, r *readBuf) { 1064 r = &readBuf{} 1065 t = cn.recv1Buf(r) 1066 return t, r 1067} 1068 1069func (cn *conn) ssl(o values) error { 1070 upgrade, err := ssl(o) 1071 if err != nil { 1072 return err 1073 } 1074 1075 if upgrade == nil { 1076 // Nothing to do 1077 return nil 1078 } 1079 1080 w := cn.writeBuf(0) 1081 w.int32(80877103) 1082 if err = cn.sendStartupPacket(w); err != nil { 1083 return err 1084 } 1085 1086 b := cn.scratch[:1] 1087 _, err = io.ReadFull(cn.c, b) 1088 if err != nil { 1089 return err 1090 } 1091 1092 if b[0] != 'S' { 1093 return ErrSSLNotSupported 1094 } 1095 1096 cn.c, err = upgrade(cn.c) 1097 return err 1098} 1099 1100// isDriverSetting returns true iff a setting is purely for configuring the 1101// driver's options and should not be sent to the server in the connection 1102// startup packet. 
1103func isDriverSetting(key string) bool { 1104 switch key { 1105 case "host", "port": 1106 return true 1107 case "password": 1108 return true 1109 case "sslmode", "sslcert", "sslkey", "sslrootcert", "sslinline": 1110 return true 1111 case "fallback_application_name": 1112 return true 1113 case "connect_timeout": 1114 return true 1115 case "disable_prepared_binary_result": 1116 return true 1117 case "binary_parameters": 1118 return true 1119 case "krbsrvname": 1120 return true 1121 case "krbspn": 1122 return true 1123 default: 1124 return false 1125 } 1126} 1127 1128func (cn *conn) startup(o values) { 1129 w := cn.writeBuf(0) 1130 w.int32(196608) 1131 // Send the backend the name of the database we want to connect to, and the 1132 // user we want to connect as. Additionally, we send over any run-time 1133 // parameters potentially included in the connection string. If the server 1134 // doesn't recognize any of them, it will reply with an error. 1135 for k, v := range o { 1136 if isDriverSetting(k) { 1137 // skip options which can't be run-time parameters 1138 continue 1139 } 1140 // The protocol requires us to supply the database name as "database" 1141 // instead of "dbname". 
1142 if k == "dbname" { 1143 k = "database" 1144 } 1145 w.string(k) 1146 w.string(v) 1147 } 1148 w.string("") 1149 if err := cn.sendStartupPacket(w); err != nil { 1150 panic(err) 1151 } 1152 1153 for { 1154 t, r := cn.recv() 1155 switch t { 1156 case 'K': 1157 cn.processBackendKeyData(r) 1158 case 'S': 1159 cn.processParameterStatus(r) 1160 case 'R': 1161 cn.auth(r, o) 1162 case 'Z': 1163 cn.processReadyForQuery(r) 1164 return 1165 default: 1166 errorf("unknown response for startup: %q", t) 1167 } 1168 } 1169} 1170 1171func (cn *conn) auth(r *readBuf, o values) { 1172 switch code := r.int32(); code { 1173 case 0: 1174 // OK 1175 case 3: 1176 w := cn.writeBuf('p') 1177 w.string(o["password"]) 1178 cn.send(w) 1179 1180 t, r := cn.recv() 1181 if t != 'R' { 1182 errorf("unexpected password response: %q", t) 1183 } 1184 1185 if r.int32() != 0 { 1186 errorf("unexpected authentication response: %q", t) 1187 } 1188 case 5: 1189 s := string(r.next(4)) 1190 w := cn.writeBuf('p') 1191 w.string("md5" + md5s(md5s(o["password"]+o["user"])+s)) 1192 cn.send(w) 1193 1194 t, r := cn.recv() 1195 if t != 'R' { 1196 errorf("unexpected password response: %q", t) 1197 } 1198 1199 if r.int32() != 0 { 1200 errorf("unexpected authentication response: %q", t) 1201 } 1202 case 7: // GSSAPI, startup 1203 if newGss == nil { 1204 errorf("kerberos error: no GSSAPI provider registered (import github.com/lib/pq/auth/kerberos if you need Kerberos support)") 1205 } 1206 cli, err := newGss() 1207 if err != nil { 1208 errorf("kerberos error: %s", err.Error()) 1209 } 1210 1211 var token []byte 1212 1213 if spn, ok := o["krbspn"]; ok { 1214 // Use the supplied SPN if provided.. 
1215 token, err = cli.GetInitTokenFromSpn(spn) 1216 } else { 1217 // Allow the kerberos service name to be overridden 1218 service := "postgres" 1219 if val, ok := o["krbsrvname"]; ok { 1220 service = val 1221 } 1222 1223 token, err = cli.GetInitToken(o["host"], service) 1224 } 1225 1226 if err != nil { 1227 errorf("failed to get Kerberos ticket: %q", err) 1228 } 1229 1230 w := cn.writeBuf('p') 1231 w.bytes(token) 1232 cn.send(w) 1233 1234 // Store for GSSAPI continue message 1235 cn.gss = cli 1236 1237 case 8: // GSSAPI continue 1238 1239 if cn.gss == nil { 1240 errorf("GSSAPI protocol error") 1241 } 1242 1243 b := []byte(*r) 1244 1245 done, tokOut, err := cn.gss.Continue(b) 1246 if err == nil && !done { 1247 w := cn.writeBuf('p') 1248 w.bytes(tokOut) 1249 cn.send(w) 1250 } 1251 1252 // Errors fall through and read the more detailed message 1253 // from the server.. 1254 1255 case 10: 1256 sc := scram.NewClient(sha256.New, o["user"], o["password"]) 1257 sc.Step(nil) 1258 if sc.Err() != nil { 1259 errorf("SCRAM-SHA-256 error: %s", sc.Err().Error()) 1260 } 1261 scOut := sc.Out() 1262 1263 w := cn.writeBuf('p') 1264 w.string("SCRAM-SHA-256") 1265 w.int32(len(scOut)) 1266 w.bytes(scOut) 1267 cn.send(w) 1268 1269 t, r := cn.recv() 1270 if t != 'R' { 1271 errorf("unexpected password response: %q", t) 1272 } 1273 1274 if r.int32() != 11 { 1275 errorf("unexpected authentication response: %q", t) 1276 } 1277 1278 nextStep := r.next(len(*r)) 1279 sc.Step(nextStep) 1280 if sc.Err() != nil { 1281 errorf("SCRAM-SHA-256 error: %s", sc.Err().Error()) 1282 } 1283 1284 scOut = sc.Out() 1285 w = cn.writeBuf('p') 1286 w.bytes(scOut) 1287 cn.send(w) 1288 1289 t, r = cn.recv() 1290 if t != 'R' { 1291 errorf("unexpected password response: %q", t) 1292 } 1293 1294 if r.int32() != 12 { 1295 errorf("unexpected authentication response: %q", t) 1296 } 1297 1298 nextStep = r.next(len(*r)) 1299 sc.Step(nextStep) 1300 if sc.Err() != nil { 1301 errorf("SCRAM-SHA-256 error: %s", 
sc.Err().Error()) 1302 } 1303 1304 default: 1305 errorf("unknown authentication response: %d", code) 1306 } 1307} 1308 1309type format int 1310 1311const formatText format = 0 1312const formatBinary format = 1 1313 1314// One result-column format code with the value 1 (i.e. all binary). 1315var colFmtDataAllBinary = []byte{0, 1, 0, 1} 1316 1317// No result-column format codes (i.e. all text). 1318var colFmtDataAllText = []byte{0, 0} 1319 1320type stmt struct { 1321 cn *conn 1322 name string 1323 rowsHeader 1324 colFmtData []byte 1325 paramTyps []oid.Oid 1326 closed bool 1327} 1328 1329func (st *stmt) Close() (err error) { 1330 if st.closed { 1331 return nil 1332 } 1333 if st.cn.getBad() { 1334 return driver.ErrBadConn 1335 } 1336 defer st.cn.errRecover(&err) 1337 1338 w := st.cn.writeBuf('C') 1339 w.byte('S') 1340 w.string(st.name) 1341 st.cn.send(w) 1342 1343 st.cn.send(st.cn.writeBuf('S')) 1344 1345 t, _ := st.cn.recv1() 1346 if t != '3' { 1347 st.cn.setBad() 1348 errorf("unexpected close response: %q", t) 1349 } 1350 st.closed = true 1351 1352 t, r := st.cn.recv1() 1353 if t != 'Z' { 1354 st.cn.setBad() 1355 errorf("expected ready for query, but got: %q", t) 1356 } 1357 st.cn.processReadyForQuery(r) 1358 1359 return nil 1360} 1361 1362func (st *stmt) Query(v []driver.Value) (r driver.Rows, err error) { 1363 if st.cn.getBad() { 1364 return nil, driver.ErrBadConn 1365 } 1366 defer st.cn.errRecover(&err) 1367 1368 st.exec(v) 1369 return &rows{ 1370 cn: st.cn, 1371 rowsHeader: st.rowsHeader, 1372 }, nil 1373} 1374 1375func (st *stmt) Exec(v []driver.Value) (res driver.Result, err error) { 1376 if st.cn.getBad() { 1377 return nil, driver.ErrBadConn 1378 } 1379 defer st.cn.errRecover(&err) 1380 1381 st.exec(v) 1382 res, _, err = st.cn.readExecuteResponse("simple query") 1383 return res, err 1384} 1385 1386func (st *stmt) exec(v []driver.Value) { 1387 if len(v) >= 65536 { 1388 errorf("got %d parameters but PostgreSQL only supports 65535 parameters", len(v)) 1389 } 1390 
if len(v) != len(st.paramTyps) { 1391 errorf("got %d parameters but the statement requires %d", len(v), len(st.paramTyps)) 1392 } 1393 1394 cn := st.cn 1395 w := cn.writeBuf('B') 1396 w.byte(0) // unnamed portal 1397 w.string(st.name) 1398 1399 if cn.binaryParameters { 1400 cn.sendBinaryParameters(w, v) 1401 } else { 1402 w.int16(0) 1403 w.int16(len(v)) 1404 for i, x := range v { 1405 if x == nil { 1406 w.int32(-1) 1407 } else { 1408 b := encode(&cn.parameterStatus, x, st.paramTyps[i]) 1409 w.int32(len(b)) 1410 w.bytes(b) 1411 } 1412 } 1413 } 1414 w.bytes(st.colFmtData) 1415 1416 w.next('E') 1417 w.byte(0) 1418 w.int32(0) 1419 1420 w.next('S') 1421 cn.send(w) 1422 1423 cn.readBindResponse() 1424 cn.postExecuteWorkaround() 1425 1426} 1427 1428func (st *stmt) NumInput() int { 1429 return len(st.paramTyps) 1430} 1431 1432// parseComplete parses the "command tag" from a CommandComplete message, and 1433// returns the number of rows affected (if applicable) and a string 1434// identifying only the command that was executed, e.g. "ALTER TABLE". If the 1435// command tag could not be parsed, parseComplete panics. 1436func (cn *conn) parseComplete(commandTag string) (driver.Result, string) { 1437 commandsWithAffectedRows := []string{ 1438 "SELECT ", 1439 // INSERT is handled below 1440 "UPDATE ", 1441 "DELETE ", 1442 "FETCH ", 1443 "MOVE ", 1444 "COPY ", 1445 } 1446 1447 var affectedRows *string 1448 for _, tag := range commandsWithAffectedRows { 1449 if strings.HasPrefix(commandTag, tag) { 1450 t := commandTag[len(tag):] 1451 affectedRows = &t 1452 commandTag = tag[:len(tag)-1] 1453 break 1454 } 1455 } 1456 // INSERT also includes the oid of the inserted row in its command tag. 1457 // Oids in user tables are deprecated, and the oid is only returned when 1458 // exactly one row is inserted, so it's unlikely to be of value to any 1459 // real-world application and we can ignore it. 
1460 if affectedRows == nil && strings.HasPrefix(commandTag, "INSERT ") { 1461 parts := strings.Split(commandTag, " ") 1462 if len(parts) != 3 { 1463 cn.setBad() 1464 errorf("unexpected INSERT command tag %s", commandTag) 1465 } 1466 affectedRows = &parts[len(parts)-1] 1467 commandTag = "INSERT" 1468 } 1469 // There should be no affected rows attached to the tag, just return it 1470 if affectedRows == nil { 1471 return driver.RowsAffected(0), commandTag 1472 } 1473 n, err := strconv.ParseInt(*affectedRows, 10, 64) 1474 if err != nil { 1475 cn.setBad() 1476 errorf("could not parse commandTag: %s", err) 1477 } 1478 return driver.RowsAffected(n), commandTag 1479} 1480 1481type rowsHeader struct { 1482 colNames []string 1483 colTyps []fieldDesc 1484 colFmts []format 1485} 1486 1487type rows struct { 1488 cn *conn 1489 finish func() 1490 rowsHeader 1491 done bool 1492 rb readBuf 1493 result driver.Result 1494 tag string 1495 1496 next *rowsHeader 1497} 1498 1499func (rs *rows) Close() error { 1500 if finish := rs.finish; finish != nil { 1501 defer finish() 1502 } 1503 // no need to look at cn.bad as Next() will 1504 for { 1505 err := rs.Next(nil) 1506 switch err { 1507 case nil: 1508 case io.EOF: 1509 // rs.Next can return io.EOF on both 'Z' (ready for query) and 'T' (row 1510 // description, used with HasNextResultSet). We need to fetch messages until 1511 // we hit a 'Z', which is done by waiting for done to be set. 
1512 if rs.done { 1513 return nil 1514 } 1515 default: 1516 return err 1517 } 1518 } 1519} 1520 1521func (rs *rows) Columns() []string { 1522 return rs.colNames 1523} 1524 1525func (rs *rows) Result() driver.Result { 1526 if rs.result == nil { 1527 return emptyRows 1528 } 1529 return rs.result 1530} 1531 1532func (rs *rows) Tag() string { 1533 return rs.tag 1534} 1535 1536func (rs *rows) Next(dest []driver.Value) (err error) { 1537 if rs.done { 1538 return io.EOF 1539 } 1540 1541 conn := rs.cn 1542 if conn.getBad() { 1543 return driver.ErrBadConn 1544 } 1545 defer conn.errRecover(&err) 1546 1547 for { 1548 t := conn.recv1Buf(&rs.rb) 1549 switch t { 1550 case 'E': 1551 err = parseError(&rs.rb) 1552 case 'C', 'I': 1553 if t == 'C' { 1554 rs.result, rs.tag = conn.parseComplete(rs.rb.string()) 1555 } 1556 continue 1557 case 'Z': 1558 conn.processReadyForQuery(&rs.rb) 1559 rs.done = true 1560 if err != nil { 1561 return err 1562 } 1563 return io.EOF 1564 case 'D': 1565 n := rs.rb.int16() 1566 if err != nil { 1567 conn.setBad() 1568 errorf("unexpected DataRow after error %s", err) 1569 } 1570 if n < len(dest) { 1571 dest = dest[:n] 1572 } 1573 for i := range dest { 1574 l := rs.rb.int32() 1575 if l == -1 { 1576 dest[i] = nil 1577 continue 1578 } 1579 dest[i] = decode(&conn.parameterStatus, rs.rb.next(l), rs.colTyps[i].OID, rs.colFmts[i]) 1580 } 1581 return 1582 case 'T': 1583 next := parsePortalRowDescribe(&rs.rb) 1584 rs.next = &next 1585 return io.EOF 1586 default: 1587 errorf("unexpected message after execute: %q", t) 1588 } 1589 } 1590} 1591 1592func (rs *rows) HasNextResultSet() bool { 1593 hasNext := rs.next != nil && !rs.done 1594 return hasNext 1595} 1596 1597func (rs *rows) NextResultSet() error { 1598 if rs.next == nil { 1599 return io.EOF 1600 } 1601 rs.rowsHeader = *rs.next 1602 rs.next = nil 1603 return nil 1604} 1605 1606// QuoteIdentifier quotes an "identifier" (e.g. a table or a column name) to be 1607// used as part of an SQL statement. 
For example: 1608// 1609// tblname := "my_table" 1610// data := "my_data" 1611// quoted := pq.QuoteIdentifier(tblname) 1612// err := db.Exec(fmt.Sprintf("INSERT INTO %s VALUES ($1)", quoted), data) 1613// 1614// Any double quotes in name will be escaped. The quoted identifier will be 1615// case sensitive when used in a query. If the input string contains a zero 1616// byte, the result will be truncated immediately before it. 1617func QuoteIdentifier(name string) string { 1618 end := strings.IndexRune(name, 0) 1619 if end > -1 { 1620 name = name[:end] 1621 } 1622 return `"` + strings.Replace(name, `"`, `""`, -1) + `"` 1623} 1624 1625// QuoteLiteral quotes a 'literal' (e.g. a parameter, often used to pass literal 1626// to DDL and other statements that do not accept parameters) to be used as part 1627// of an SQL statement. For example: 1628// 1629// exp_date := pq.QuoteLiteral("2023-01-05 15:00:00Z") 1630// err := db.Exec(fmt.Sprintf("CREATE ROLE my_user VALID UNTIL %s", exp_date)) 1631// 1632// Any single quotes in name will be escaped. Any backslashes (i.e. "\") will be 1633// replaced by two backslashes (i.e. "\\") and the C-style escape identifier 1634// that PostgreSQL provides ('E') will be prepended to the string. 1635func QuoteLiteral(literal string) string { 1636 // This follows the PostgreSQL internal algorithm for handling quoted literals 1637 // from libpq, which can be found in the "PQEscapeStringInternal" function, 1638 // which is found in the libpq/fe-exec.c source file: 1639 // https://git.postgresql.org/gitweb/?p=postgresql.git;a=blob;f=src/interfaces/libpq/fe-exec.c 1640 // 1641 // substitute any single-quotes (') with two single-quotes ('') 1642 literal = strings.Replace(literal, `'`, `''`, -1) 1643 // determine if the string has any backslashes (\) in it. 1644 // if it does, replace any backslashes (\) with two backslashes (\\) 1645 // then, we need to wrap the entire string with a PostgreSQL 1646 // C-style escape. 
Per how "PQEscapeStringInternal" handles this case, we 1647 // also add a space before the "E" 1648 if strings.Contains(literal, `\`) { 1649 literal = strings.Replace(literal, `\`, `\\`, -1) 1650 literal = ` E'` + literal + `'` 1651 } else { 1652 // otherwise, we can just wrap the literal with a pair of single quotes 1653 literal = `'` + literal + `'` 1654 } 1655 return literal 1656} 1657 1658func md5s(s string) string { 1659 h := md5.New() 1660 h.Write([]byte(s)) 1661 return fmt.Sprintf("%x", h.Sum(nil)) 1662} 1663 1664func (cn *conn) sendBinaryParameters(b *writeBuf, args []driver.Value) { 1665 // Do one pass over the parameters to see if we're going to send any of 1666 // them over in binary. If we are, create a paramFormats array at the 1667 // same time. 1668 var paramFormats []int 1669 for i, x := range args { 1670 _, ok := x.([]byte) 1671 if ok { 1672 if paramFormats == nil { 1673 paramFormats = make([]int, len(args)) 1674 } 1675 paramFormats[i] = 1 1676 } 1677 } 1678 if paramFormats == nil { 1679 b.int16(0) 1680 } else { 1681 b.int16(len(paramFormats)) 1682 for _, x := range paramFormats { 1683 b.int16(x) 1684 } 1685 } 1686 1687 b.int16(len(args)) 1688 for _, x := range args { 1689 if x == nil { 1690 b.int32(-1) 1691 } else { 1692 datum := binaryEncode(&cn.parameterStatus, x) 1693 b.int32(len(datum)) 1694 b.bytes(datum) 1695 } 1696 } 1697} 1698 1699func (cn *conn) sendBinaryModeQuery(query string, args []driver.Value) { 1700 if len(args) >= 65536 { 1701 errorf("got %d parameters but PostgreSQL only supports 65535 parameters", len(args)) 1702 } 1703 1704 b := cn.writeBuf('P') 1705 b.byte(0) // unnamed statement 1706 b.string(query) 1707 b.int16(0) 1708 1709 b.next('B') 1710 b.int16(0) // unnamed portal and statement 1711 cn.sendBinaryParameters(b, args) 1712 b.bytes(colFmtDataAllText) 1713 1714 b.next('D') 1715 b.byte('P') 1716 b.byte(0) // unnamed portal 1717 1718 b.next('E') 1719 b.byte(0) 1720 b.int32(0) 1721 1722 b.next('S') 1723 cn.send(b) 1724} 1725 
1726func (cn *conn) processParameterStatus(r *readBuf) { 1727 var err error 1728 1729 param := r.string() 1730 switch param { 1731 case "server_version": 1732 var major1 int 1733 var major2 int 1734 _, err = fmt.Sscanf(r.string(), "%d.%d", &major1, &major2) 1735 if err == nil { 1736 cn.parameterStatus.serverVersion = major1*10000 + major2*100 1737 } 1738 1739 case "TimeZone": 1740 cn.parameterStatus.currentLocation, err = time.LoadLocation(r.string()) 1741 if err != nil { 1742 cn.parameterStatus.currentLocation = nil 1743 } 1744 1745 default: 1746 // ignore 1747 } 1748} 1749 1750func (cn *conn) processReadyForQuery(r *readBuf) { 1751 cn.txnStatus = transactionStatus(r.byte()) 1752} 1753 1754func (cn *conn) readReadyForQuery() { 1755 t, r := cn.recv1() 1756 switch t { 1757 case 'Z': 1758 cn.processReadyForQuery(r) 1759 return 1760 default: 1761 cn.setBad() 1762 errorf("unexpected message %q; expected ReadyForQuery", t) 1763 } 1764} 1765 1766func (cn *conn) processBackendKeyData(r *readBuf) { 1767 cn.processID = r.int32() 1768 cn.secretKey = r.int32() 1769} 1770 1771func (cn *conn) readParseResponse() { 1772 t, r := cn.recv1() 1773 switch t { 1774 case '1': 1775 return 1776 case 'E': 1777 err := parseError(r) 1778 cn.readReadyForQuery() 1779 panic(err) 1780 default: 1781 cn.setBad() 1782 errorf("unexpected Parse response %q", t) 1783 } 1784} 1785 1786func (cn *conn) readStatementDescribeResponse() (paramTyps []oid.Oid, colNames []string, colTyps []fieldDesc) { 1787 for { 1788 t, r := cn.recv1() 1789 switch t { 1790 case 't': 1791 nparams := r.int16() 1792 paramTyps = make([]oid.Oid, nparams) 1793 for i := range paramTyps { 1794 paramTyps[i] = r.oid() 1795 } 1796 case 'n': 1797 return paramTyps, nil, nil 1798 case 'T': 1799 colNames, colTyps = parseStatementRowDescribe(r) 1800 return paramTyps, colNames, colTyps 1801 case 'E': 1802 err := parseError(r) 1803 cn.readReadyForQuery() 1804 panic(err) 1805 default: 1806 cn.setBad() 1807 errorf("unexpected Describe statement 
response %q", t) 1808 } 1809 } 1810} 1811 1812func (cn *conn) readPortalDescribeResponse() rowsHeader { 1813 t, r := cn.recv1() 1814 switch t { 1815 case 'T': 1816 return parsePortalRowDescribe(r) 1817 case 'n': 1818 return rowsHeader{} 1819 case 'E': 1820 err := parseError(r) 1821 cn.readReadyForQuery() 1822 panic(err) 1823 default: 1824 cn.setBad() 1825 errorf("unexpected Describe response %q", t) 1826 } 1827 panic("not reached") 1828} 1829 1830func (cn *conn) readBindResponse() { 1831 t, r := cn.recv1() 1832 switch t { 1833 case '2': 1834 return 1835 case 'E': 1836 err := parseError(r) 1837 cn.readReadyForQuery() 1838 panic(err) 1839 default: 1840 cn.setBad() 1841 errorf("unexpected Bind response %q", t) 1842 } 1843} 1844 1845func (cn *conn) postExecuteWorkaround() { 1846 // Work around a bug in sql.DB.QueryRow: in Go 1.2 and earlier it ignores 1847 // any errors from rows.Next, which masks errors that happened during the 1848 // execution of the query. To avoid the problem in common cases, we wait 1849 // here for one more message from the database. If it's not an error the 1850 // query will likely succeed (or perhaps has already, if it's a 1851 // CommandComplete), so we push the message into the conn struct; recv1 1852 // will return it as the next message for rows.Next or rows.Close. 1853 // However, if it's an error, we wait until ReadyForQuery and then return 1854 // the error to our caller. 
1855 for { 1856 t, r := cn.recv1() 1857 switch t { 1858 case 'E': 1859 err := parseError(r) 1860 cn.readReadyForQuery() 1861 panic(err) 1862 case 'C', 'D', 'I': 1863 // the query didn't fail, but we can't process this message 1864 cn.saveMessage(t, r) 1865 return 1866 default: 1867 cn.setBad() 1868 errorf("unexpected message during extended query execution: %q", t) 1869 } 1870 } 1871} 1872 1873// Only for Exec(), since we ignore the returned data 1874func (cn *conn) readExecuteResponse(protocolState string) (res driver.Result, commandTag string, err error) { 1875 for { 1876 t, r := cn.recv1() 1877 switch t { 1878 case 'C': 1879 if err != nil { 1880 cn.setBad() 1881 errorf("unexpected CommandComplete after error %s", err) 1882 } 1883 res, commandTag = cn.parseComplete(r.string()) 1884 case 'Z': 1885 cn.processReadyForQuery(r) 1886 if res == nil && err == nil { 1887 err = errUnexpectedReady 1888 } 1889 return res, commandTag, err 1890 case 'E': 1891 err = parseError(r) 1892 case 'T', 'D', 'I': 1893 if err != nil { 1894 cn.setBad() 1895 errorf("unexpected %q after error %s", t, err) 1896 } 1897 if t == 'I' { 1898 res = emptyRows 1899 } 1900 // ignore any results 1901 default: 1902 cn.setBad() 1903 errorf("unknown %s response: %q", protocolState, t) 1904 } 1905 } 1906} 1907 1908func parseStatementRowDescribe(r *readBuf) (colNames []string, colTyps []fieldDesc) { 1909 n := r.int16() 1910 colNames = make([]string, n) 1911 colTyps = make([]fieldDesc, n) 1912 for i := range colNames { 1913 colNames[i] = r.string() 1914 r.next(6) 1915 colTyps[i].OID = r.oid() 1916 colTyps[i].Len = r.int16() 1917 colTyps[i].Mod = r.int32() 1918 // format code not known when describing a statement; always 0 1919 r.next(2) 1920 } 1921 return 1922} 1923 1924func parsePortalRowDescribe(r *readBuf) rowsHeader { 1925 n := r.int16() 1926 colNames := make([]string, n) 1927 colFmts := make([]format, n) 1928 colTyps := make([]fieldDesc, n) 1929 for i := range colNames { 1930 colNames[i] = r.string() 
1931 r.next(6) 1932 colTyps[i].OID = r.oid() 1933 colTyps[i].Len = r.int16() 1934 colTyps[i].Mod = r.int32() 1935 colFmts[i] = format(r.int16()) 1936 } 1937 return rowsHeader{ 1938 colNames: colNames, 1939 colFmts: colFmts, 1940 colTyps: colTyps, 1941 } 1942} 1943 1944// parseEnviron tries to mimic some of libpq's environment handling 1945// 1946// To ease testing, it does not directly reference os.Environ, but is 1947// designed to accept its output. 1948// 1949// Environment-set connection information is intended to have a higher 1950// precedence than a library default but lower than any explicitly 1951// passed information (such as in the URL or connection string). 1952func parseEnviron(env []string) (out map[string]string) { 1953 out = make(map[string]string) 1954 1955 for _, v := range env { 1956 parts := strings.SplitN(v, "=", 2) 1957 1958 accrue := func(keyname string) { 1959 out[keyname] = parts[1] 1960 } 1961 unsupported := func() { 1962 panic(fmt.Sprintf("setting %v not supported", parts[0])) 1963 } 1964 1965 // The order of these is the same as is seen in the 1966 // PostgreSQL 9.1 manual. Unsupported but well-defined 1967 // keys cause a panic; these should be unset prior to 1968 // execution. Options which pq expects to be set to a 1969 // certain value are allowed, but must be set to that 1970 // value if present (they can, of course, be absent). 
1971 switch parts[0] { 1972 case "PGHOST": 1973 accrue("host") 1974 case "PGHOSTADDR": 1975 unsupported() 1976 case "PGPORT": 1977 accrue("port") 1978 case "PGDATABASE": 1979 accrue("dbname") 1980 case "PGUSER": 1981 accrue("user") 1982 case "PGPASSWORD": 1983 accrue("password") 1984 case "PGSERVICE", "PGSERVICEFILE", "PGREALM": 1985 unsupported() 1986 case "PGOPTIONS": 1987 accrue("options") 1988 case "PGAPPNAME": 1989 accrue("application_name") 1990 case "PGSSLMODE": 1991 accrue("sslmode") 1992 case "PGSSLCERT": 1993 accrue("sslcert") 1994 case "PGSSLKEY": 1995 accrue("sslkey") 1996 case "PGSSLROOTCERT": 1997 accrue("sslrootcert") 1998 case "PGREQUIRESSL", "PGSSLCRL": 1999 unsupported() 2000 case "PGREQUIREPEER": 2001 unsupported() 2002 case "PGKRBSRVNAME", "PGGSSLIB": 2003 unsupported() 2004 case "PGCONNECT_TIMEOUT": 2005 accrue("connect_timeout") 2006 case "PGCLIENTENCODING": 2007 accrue("client_encoding") 2008 case "PGDATESTYLE": 2009 accrue("datestyle") 2010 case "PGTZ": 2011 accrue("timezone") 2012 case "PGGEQO": 2013 accrue("geqo") 2014 case "PGSYSCONFDIR", "PGLOCALEDIR": 2015 unsupported() 2016 } 2017 } 2018 2019 return out 2020} 2021 2022// isUTF8 returns whether name is a fuzzy variation of the string "UTF-8". 2023func isUTF8(name string) bool { 2024 // Recognize all sorts of silly things as "UTF-8", like Postgres does 2025 s := strings.Map(alnumLowerASCII, name) 2026 return s == "utf8" || s == "unicode" 2027} 2028 2029func alnumLowerASCII(ch rune) rune { 2030 if 'A' <= ch && ch <= 'Z' { 2031 return ch + ('a' - 'A') 2032 } 2033 if 'a' <= ch && ch <= 'z' || '0' <= ch && ch <= '9' { 2034 return ch 2035 } 2036 return -1 // discard 2037} 2038