package pq

import (
	"bufio"
	"context"
	"crypto/md5"
	"crypto/sha256"
	"database/sql"
	"database/sql/driver"
	"encoding/binary"
	"errors"
	"fmt"
	"io"
	"net"
	"os"
	"os/user"
	"path"
	"path/filepath"
	"strconv"
	"strings"
	"sync"
	"time"
	"unicode"

	"github.com/lib/pq/oid"
	"github.com/lib/pq/scram"
)

// Common error types
//
// The exported values are sentinel errors callers can compare against; the
// unexported ones are only produced inside this package.
var (
	ErrNotSupported = errors.New("pq: Unsupported command")
	// ErrInFailedTransaction is returned by Commit when the transaction was
	// already in the failed state; the transaction is rolled back first.
	ErrInFailedTransaction = errors.New("pq: Could not complete operation in a failed transaction")
	// ErrSSLNotSupported is returned during connection setup when the server
	// does not answer the SSL request with 'S'.
	ErrSSLNotSupported           = errors.New("pq: SSL is not enabled on the server")
	ErrSSLKeyHasWorldPermissions = errors.New("pq: Private key file has group or world access. Permissions should be u=rw (0600) or less")
	ErrCouldNotDetectUsername    = errors.New("pq: Could not detect default username. Please provide one explicitly")

	// errUnexpectedReady is reported by simpleExec when ReadyForQuery arrives
	// before any result or error was seen.
	errUnexpectedReady = errors.New("unexpected ReadyForQuery")
	// errNoRowsAffected and errNoLastInsertID are returned by the noRows
	// result, which stands in for the result of an empty statement.
	errNoRowsAffected = errors.New("no RowsAffected available after the empty statement")
	errNoLastInsertID = errors.New("no LastInsertId available after the empty statement")
)

// Compile time validation that our types implement the expected interfaces
var (
	_ driver.Driver = Driver{}
)

// Driver is the Postgres database driver.
//
// It carries no state; a *Driver is registered with database/sql under the
// name "postgres" by this package's init function.
type Driver struct{}

// Open opens a new connection to the database. name is a connection string.
// Most users should only use it through database/sql package from the standard
// library.
53func (d Driver) Open(name string) (driver.Conn, error) { 54 return Open(name) 55} 56 57func init() { 58 sql.Register("postgres", &Driver{}) 59} 60 61type parameterStatus struct { 62 // server version in the same format as server_version_num, or 0 if 63 // unavailable 64 serverVersion int 65 66 // the current location based on the TimeZone value of the session, if 67 // available 68 currentLocation *time.Location 69} 70 71type transactionStatus byte 72 73const ( 74 txnStatusIdle transactionStatus = 'I' 75 txnStatusIdleInTransaction transactionStatus = 'T' 76 txnStatusInFailedTransaction transactionStatus = 'E' 77) 78 79func (s transactionStatus) String() string { 80 switch s { 81 case txnStatusIdle: 82 return "idle" 83 case txnStatusIdleInTransaction: 84 return "idle in transaction" 85 case txnStatusInFailedTransaction: 86 return "in a failed transaction" 87 default: 88 errorf("unknown transactionStatus %d", s) 89 } 90 91 panic("not reached") 92} 93 94// Dialer is the dialer interface. It can be used to obtain more control over 95// how pq creates network connections. 96type Dialer interface { 97 Dial(network, address string) (net.Conn, error) 98 DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) 99} 100 101// DialerContext is the context-aware dialer interface. 
102type DialerContext interface { 103 DialContext(ctx context.Context, network, address string) (net.Conn, error) 104} 105 106type defaultDialer struct { 107 d net.Dialer 108} 109 110func (d defaultDialer) Dial(network, address string) (net.Conn, error) { 111 return d.d.Dial(network, address) 112} 113func (d defaultDialer) DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) { 114 ctx, cancel := context.WithTimeout(context.Background(), timeout) 115 defer cancel() 116 return d.DialContext(ctx, network, address) 117} 118func (d defaultDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { 119 return d.d.DialContext(ctx, network, address) 120} 121 122type conn struct { 123 c net.Conn 124 buf *bufio.Reader 125 namei int 126 scratch [512]byte 127 txnStatus transactionStatus 128 txnFinish func() 129 130 // Save connection arguments to use during CancelRequest. 131 dialer Dialer 132 opts values 133 134 // Cancellation key data for use with CancelRequest messages. 135 processID int 136 secretKey int 137 138 parameterStatus parameterStatus 139 140 saveMessageType byte 141 saveMessageBuffer []byte 142 143 // If an error is set, this connection is bad and all public-facing 144 // functions should return the appropriate error by calling get() 145 // (ErrBadConn) or getForNext(). 146 err syncErr 147 148 // If set, this connection should never use the binary format when 149 // receiving query results from prepared statements. Only provided for 150 // debugging. 151 disablePreparedBinaryResult bool 152 153 // Whether to always send []byte parameters over as binary. Enables single 154 // round-trip mode for non-prepared Query calls. 
155 binaryParameters bool 156 157 // If true this connection is in the middle of a COPY 158 inCopy bool 159 160 // If not nil, notices will be synchronously sent here 161 noticeHandler func(*Error) 162 163 // If not nil, notifications will be synchronously sent here 164 notificationHandler func(*Notification) 165 166 // GSSAPI context 167 gss GSS 168} 169 170type syncErr struct { 171 err error 172 sync.Mutex 173} 174 175// Return ErrBadConn if connection is bad. 176func (e *syncErr) get() error { 177 e.Lock() 178 defer e.Unlock() 179 if e.err != nil { 180 return driver.ErrBadConn 181 } 182 return nil 183} 184 185// Return the error set on the connection. Currently only used by rows.Next. 186func (e *syncErr) getForNext() error { 187 e.Lock() 188 defer e.Unlock() 189 return e.err 190} 191 192// Set error, only if it isn't set yet. 193func (e *syncErr) set(err error) { 194 if err == nil { 195 panic("attempt to set nil err") 196 } 197 e.Lock() 198 defer e.Unlock() 199 if e.err == nil { 200 e.err = err 201 } 202} 203 204// Handle driver-side settings in parsed connection string. 
205func (cn *conn) handleDriverSettings(o values) (err error) { 206 boolSetting := func(key string, val *bool) error { 207 if value, ok := o[key]; ok { 208 if value == "yes" { 209 *val = true 210 } else if value == "no" { 211 *val = false 212 } else { 213 return fmt.Errorf("unrecognized value %q for %s", value, key) 214 } 215 } 216 return nil 217 } 218 219 err = boolSetting("disable_prepared_binary_result", &cn.disablePreparedBinaryResult) 220 if err != nil { 221 return err 222 } 223 return boolSetting("binary_parameters", &cn.binaryParameters) 224} 225 226func (cn *conn) handlePgpass(o values) { 227 // if a password was supplied, do not process .pgpass 228 if _, ok := o["password"]; ok { 229 return 230 } 231 filename := os.Getenv("PGPASSFILE") 232 if filename == "" { 233 // XXX this code doesn't work on Windows where the default filename is 234 // XXX %APPDATA%\postgresql\pgpass.conf 235 // Prefer $HOME over user.Current due to glibc bug: golang.org/issue/13470 236 userHome := os.Getenv("HOME") 237 if userHome == "" { 238 user, err := user.Current() 239 if err != nil { 240 return 241 } 242 userHome = user.HomeDir 243 } 244 filename = filepath.Join(userHome, ".pgpass") 245 } 246 fileinfo, err := os.Stat(filename) 247 if err != nil { 248 return 249 } 250 mode := fileinfo.Mode() 251 if mode&(0x77) != 0 { 252 // XXX should warn about incorrect .pgpass permissions as psql does 253 return 254 } 255 file, err := os.Open(filename) 256 if err != nil { 257 return 258 } 259 defer file.Close() 260 scanner := bufio.NewScanner(io.Reader(file)) 261 hostname := o["host"] 262 ntw, _ := network(o) 263 port := o["port"] 264 db := o["dbname"] 265 username := o["user"] 266 // From: https://github.com/tg/pgpass/blob/master/reader.go 267 getFields := func(s string) []string { 268 fs := make([]string, 0, 5) 269 f := make([]rune, 0, len(s)) 270 271 var esc bool 272 for _, c := range s { 273 switch { 274 case esc: 275 f = append(f, c) 276 esc = false 277 case c == '\\': 278 esc = true 279 
case c == ':': 280 fs = append(fs, string(f)) 281 f = f[:0] 282 default: 283 f = append(f, c) 284 } 285 } 286 return append(fs, string(f)) 287 } 288 for scanner.Scan() { 289 line := scanner.Text() 290 if len(line) == 0 || line[0] == '#' { 291 continue 292 } 293 split := getFields(line) 294 if len(split) != 5 { 295 continue 296 } 297 if (split[0] == "*" || split[0] == hostname || (split[0] == "localhost" && (hostname == "" || ntw == "unix"))) && (split[1] == "*" || split[1] == port) && (split[2] == "*" || split[2] == db) && (split[3] == "*" || split[3] == username) { 298 o["password"] = split[4] 299 return 300 } 301 } 302} 303 304func (cn *conn) writeBuf(b byte) *writeBuf { 305 cn.scratch[0] = b 306 return &writeBuf{ 307 buf: cn.scratch[:5], 308 pos: 1, 309 } 310} 311 312// Open opens a new connection to the database. dsn is a connection string. 313// Most users should only use it through database/sql package from the standard 314// library. 315func Open(dsn string) (_ driver.Conn, err error) { 316 return DialOpen(defaultDialer{}, dsn) 317} 318 319// DialOpen opens a new connection to the database using a dialer. 320func DialOpen(d Dialer, dsn string) (_ driver.Conn, err error) { 321 c, err := NewConnector(dsn) 322 if err != nil { 323 return nil, err 324 } 325 c.dialer = d 326 return c.open(context.Background()) 327} 328 329func (c *Connector) open(ctx context.Context) (cn *conn, err error) { 330 // Handle any panics during connection initialization. Note that we 331 // specifically do *not* want to use errRecover(), as that would turn any 332 // connection errors into ErrBadConns, hiding the real error message from 333 // the user. 334 defer errRecoverNoErrBadConn(&err) 335 336 // Create a new values map (copy). This makes it so maps in different 337 // connections do not reference the same underlying data structure, so it 338 // is safe for multiple connections to concurrently write to their opts. 
339 o := make(values) 340 for k, v := range c.opts { 341 o[k] = v 342 } 343 344 cn = &conn{ 345 opts: o, 346 dialer: c.dialer, 347 } 348 err = cn.handleDriverSettings(o) 349 if err != nil { 350 return nil, err 351 } 352 cn.handlePgpass(o) 353 354 cn.c, err = dial(ctx, c.dialer, o) 355 if err != nil { 356 return nil, err 357 } 358 359 err = cn.ssl(o) 360 if err != nil { 361 if cn.c != nil { 362 cn.c.Close() 363 } 364 return nil, err 365 } 366 367 // cn.startup panics on error. Make sure we don't leak cn.c. 368 panicking := true 369 defer func() { 370 if panicking { 371 cn.c.Close() 372 } 373 }() 374 375 cn.buf = bufio.NewReader(cn.c) 376 cn.startup(o) 377 378 // reset the deadline, in case one was set (see dial) 379 if timeout, ok := o["connect_timeout"]; ok && timeout != "0" { 380 err = cn.c.SetDeadline(time.Time{}) 381 } 382 panicking = false 383 return cn, err 384} 385 386func dial(ctx context.Context, d Dialer, o values) (net.Conn, error) { 387 network, address := network(o) 388 389 // Zero or not specified means wait indefinitely. 390 if timeout, ok := o["connect_timeout"]; ok && timeout != "0" { 391 seconds, err := strconv.ParseInt(timeout, 10, 0) 392 if err != nil { 393 return nil, fmt.Errorf("invalid value for parameter connect_timeout: %s", err) 394 } 395 duration := time.Duration(seconds) * time.Second 396 397 // connect_timeout should apply to the entire connection establishment 398 // procedure, so we both use a timeout for the TCP connection 399 // establishment and set a deadline for doing the initial handshake. 400 // The deadline is then reset after startup() is done. 
401 deadline := time.Now().Add(duration) 402 var conn net.Conn 403 if dctx, ok := d.(DialerContext); ok { 404 ctx, cancel := context.WithTimeout(ctx, duration) 405 defer cancel() 406 conn, err = dctx.DialContext(ctx, network, address) 407 } else { 408 conn, err = d.DialTimeout(network, address, duration) 409 } 410 if err != nil { 411 return nil, err 412 } 413 err = conn.SetDeadline(deadline) 414 return conn, err 415 } 416 if dctx, ok := d.(DialerContext); ok { 417 return dctx.DialContext(ctx, network, address) 418 } 419 return d.Dial(network, address) 420} 421 422func network(o values) (string, string) { 423 host := o["host"] 424 425 if strings.HasPrefix(host, "/") { 426 sockPath := path.Join(host, ".s.PGSQL."+o["port"]) 427 return "unix", sockPath 428 } 429 430 return "tcp", net.JoinHostPort(host, o["port"]) 431} 432 433type values map[string]string 434 435// scanner implements a tokenizer for libpq-style option strings. 436type scanner struct { 437 s []rune 438 i int 439} 440 441// newScanner returns a new scanner initialized with the option string s. 442func newScanner(s string) *scanner { 443 return &scanner{[]rune(s), 0} 444} 445 446// Next returns the next rune. 447// It returns 0, false if the end of the text has been reached. 448func (s *scanner) Next() (rune, bool) { 449 if s.i >= len(s.s) { 450 return 0, false 451 } 452 r := s.s[s.i] 453 s.i++ 454 return r, true 455} 456 457// SkipSpaces returns the next non-whitespace rune. 458// It returns 0, false if the end of the text has been reached. 459func (s *scanner) SkipSpaces() (rune, bool) { 460 r, ok := s.Next() 461 for unicode.IsSpace(r) && ok { 462 r, ok = s.Next() 463 } 464 return r, ok 465} 466 467// parseOpts parses the options from name and adds them to the values. 
468// 469// The parsing code is based on conninfo_parse from libpq's fe-connect.c 470func parseOpts(name string, o values) error { 471 s := newScanner(name) 472 473 for { 474 var ( 475 keyRunes, valRunes []rune 476 r rune 477 ok bool 478 ) 479 480 if r, ok = s.SkipSpaces(); !ok { 481 break 482 } 483 484 // Scan the key 485 for !unicode.IsSpace(r) && r != '=' { 486 keyRunes = append(keyRunes, r) 487 if r, ok = s.Next(); !ok { 488 break 489 } 490 } 491 492 // Skip any whitespace if we're not at the = yet 493 if r != '=' { 494 r, ok = s.SkipSpaces() 495 } 496 497 // The current character should be = 498 if r != '=' || !ok { 499 return fmt.Errorf(`missing "=" after %q in connection info string"`, string(keyRunes)) 500 } 501 502 // Skip any whitespace after the = 503 if r, ok = s.SkipSpaces(); !ok { 504 // If we reach the end here, the last value is just an empty string as per libpq. 505 o[string(keyRunes)] = "" 506 break 507 } 508 509 if r != '\'' { 510 for !unicode.IsSpace(r) { 511 if r == '\\' { 512 if r, ok = s.Next(); !ok { 513 return fmt.Errorf(`missing character after backslash`) 514 } 515 } 516 valRunes = append(valRunes, r) 517 518 if r, ok = s.Next(); !ok { 519 break 520 } 521 } 522 } else { 523 quote: 524 for { 525 if r, ok = s.Next(); !ok { 526 return fmt.Errorf(`unterminated quoted string literal in connection string`) 527 } 528 switch r { 529 case '\'': 530 break quote 531 case '\\': 532 r, _ = s.Next() 533 fallthrough 534 default: 535 valRunes = append(valRunes, r) 536 } 537 } 538 } 539 540 o[string(keyRunes)] = string(valRunes) 541 } 542 543 return nil 544} 545 546func (cn *conn) isInTransaction() bool { 547 return cn.txnStatus == txnStatusIdleInTransaction || 548 cn.txnStatus == txnStatusInFailedTransaction 549} 550 551func (cn *conn) checkIsInTransaction(intxn bool) { 552 if cn.isInTransaction() != intxn { 553 cn.err.set(driver.ErrBadConn) 554 errorf("unexpected transaction status %v", cn.txnStatus) 555 } 556} 557 558func (cn *conn) Begin() (_ 
driver.Tx, err error) { 559 return cn.begin("") 560} 561 562func (cn *conn) begin(mode string) (_ driver.Tx, err error) { 563 if err := cn.err.get(); err != nil { 564 return nil, err 565 } 566 defer cn.errRecover(&err) 567 568 cn.checkIsInTransaction(false) 569 _, commandTag, err := cn.simpleExec("BEGIN" + mode) 570 if err != nil { 571 return nil, err 572 } 573 if commandTag != "BEGIN" { 574 cn.err.set(driver.ErrBadConn) 575 return nil, fmt.Errorf("unexpected command tag %s", commandTag) 576 } 577 if cn.txnStatus != txnStatusIdleInTransaction { 578 cn.err.set(driver.ErrBadConn) 579 return nil, fmt.Errorf("unexpected transaction status %v", cn.txnStatus) 580 } 581 return cn, nil 582} 583 584func (cn *conn) closeTxn() { 585 if finish := cn.txnFinish; finish != nil { 586 finish() 587 } 588} 589 590func (cn *conn) Commit() (err error) { 591 defer cn.closeTxn() 592 if err := cn.err.get(); err != nil { 593 return err 594 } 595 defer cn.errRecover(&err) 596 597 cn.checkIsInTransaction(true) 598 // We don't want the client to think that everything is okay if it tries 599 // to commit a failed transaction. However, no matter what we return, 600 // database/sql will release this connection back into the free connection 601 // pool so we have to abort the current transaction here. Note that you 602 // would get the same behaviour if you issued a COMMIT in a failed 603 // transaction, so it's also the least surprising thing to do here. 
604 if cn.txnStatus == txnStatusInFailedTransaction { 605 if err := cn.rollback(); err != nil { 606 return err 607 } 608 return ErrInFailedTransaction 609 } 610 611 _, commandTag, err := cn.simpleExec("COMMIT") 612 if err != nil { 613 if cn.isInTransaction() { 614 cn.err.set(driver.ErrBadConn) 615 } 616 return err 617 } 618 if commandTag != "COMMIT" { 619 cn.err.set(driver.ErrBadConn) 620 return fmt.Errorf("unexpected command tag %s", commandTag) 621 } 622 cn.checkIsInTransaction(false) 623 return nil 624} 625 626func (cn *conn) Rollback() (err error) { 627 defer cn.closeTxn() 628 if err := cn.err.get(); err != nil { 629 return err 630 } 631 defer cn.errRecover(&err) 632 return cn.rollback() 633} 634 635func (cn *conn) rollback() (err error) { 636 cn.checkIsInTransaction(true) 637 _, commandTag, err := cn.simpleExec("ROLLBACK") 638 if err != nil { 639 if cn.isInTransaction() { 640 cn.err.set(driver.ErrBadConn) 641 } 642 return err 643 } 644 if commandTag != "ROLLBACK" { 645 return fmt.Errorf("unexpected command tag %s", commandTag) 646 } 647 cn.checkIsInTransaction(false) 648 return nil 649} 650 651func (cn *conn) gname() string { 652 cn.namei++ 653 return strconv.FormatInt(int64(cn.namei), 10) 654} 655 656func (cn *conn) simpleExec(q string) (res driver.Result, commandTag string, err error) { 657 b := cn.writeBuf('Q') 658 b.string(q) 659 cn.send(b) 660 661 for { 662 t, r := cn.recv1() 663 switch t { 664 case 'C': 665 res, commandTag = cn.parseComplete(r.string()) 666 case 'Z': 667 cn.processReadyForQuery(r) 668 if res == nil && err == nil { 669 err = errUnexpectedReady 670 } 671 // done 672 return 673 case 'E': 674 err = parseError(r) 675 case 'I': 676 res = emptyRows 677 case 'T', 'D': 678 // ignore any results 679 default: 680 cn.err.set(driver.ErrBadConn) 681 errorf("unknown response for simple query: %q", t) 682 } 683 } 684} 685 686func (cn *conn) simpleQuery(q string) (res *rows, err error) { 687 defer cn.errRecover(&err) 688 689 b := cn.writeBuf('Q') 690 
b.string(q) 691 cn.send(b) 692 693 for { 694 t, r := cn.recv1() 695 switch t { 696 case 'C', 'I': 697 // We allow queries which don't return any results through Query as 698 // well as Exec. We still have to give database/sql a rows object 699 // the user can close, though, to avoid connections from being 700 // leaked. A "rows" with done=true works fine for that purpose. 701 if err != nil { 702 cn.err.set(driver.ErrBadConn) 703 errorf("unexpected message %q in simple query execution", t) 704 } 705 if res == nil { 706 res = &rows{ 707 cn: cn, 708 } 709 } 710 // Set the result and tag to the last command complete if there wasn't a 711 // query already run. Although queries usually return from here and cede 712 // control to Next, a query with zero results does not. 713 if t == 'C' { 714 res.result, res.tag = cn.parseComplete(r.string()) 715 if res.colNames != nil { 716 return 717 } 718 } 719 res.done = true 720 case 'Z': 721 cn.processReadyForQuery(r) 722 // done 723 return 724 case 'E': 725 res = nil 726 err = parseError(r) 727 case 'D': 728 if res == nil { 729 cn.err.set(driver.ErrBadConn) 730 errorf("unexpected DataRow in simple query execution") 731 } 732 // the query didn't fail; kick off to Next 733 cn.saveMessage(t, r) 734 return 735 case 'T': 736 // res might be non-nil here if we received a previous 737 // CommandComplete, but that's fine; just overwrite it 738 res = &rows{cn: cn} 739 res.rowsHeader = parsePortalRowDescribe(r) 740 741 // To work around a bug in QueryRow in Go 1.2 and earlier, wait 742 // until the first DataRow has been received. 
743 default: 744 cn.err.set(driver.ErrBadConn) 745 errorf("unknown response for simple query: %q", t) 746 } 747 } 748} 749 750type noRows struct{} 751 752var emptyRows noRows 753 754var _ driver.Result = noRows{} 755 756func (noRows) LastInsertId() (int64, error) { 757 return 0, errNoLastInsertID 758} 759 760func (noRows) RowsAffected() (int64, error) { 761 return 0, errNoRowsAffected 762} 763 764// Decides which column formats to use for a prepared statement. The input is 765// an array of type oids, one element per result column. 766func decideColumnFormats(colTyps []fieldDesc, forceText bool) (colFmts []format, colFmtData []byte) { 767 if len(colTyps) == 0 { 768 return nil, colFmtDataAllText 769 } 770 771 colFmts = make([]format, len(colTyps)) 772 if forceText { 773 return colFmts, colFmtDataAllText 774 } 775 776 allBinary := true 777 allText := true 778 for i, t := range colTyps { 779 switch t.OID { 780 // This is the list of types to use binary mode for when receiving them 781 // through a prepared statement. If a type appears in this list, it 782 // must also be implemented in binaryDecode in encode.go. 
783 case oid.T_bytea: 784 fallthrough 785 case oid.T_int8: 786 fallthrough 787 case oid.T_int4: 788 fallthrough 789 case oid.T_int2: 790 fallthrough 791 case oid.T_uuid: 792 colFmts[i] = formatBinary 793 allText = false 794 795 default: 796 allBinary = false 797 } 798 } 799 800 if allBinary { 801 return colFmts, colFmtDataAllBinary 802 } else if allText { 803 return colFmts, colFmtDataAllText 804 } else { 805 colFmtData = make([]byte, 2+len(colFmts)*2) 806 binary.BigEndian.PutUint16(colFmtData, uint16(len(colFmts))) 807 for i, v := range colFmts { 808 binary.BigEndian.PutUint16(colFmtData[2+i*2:], uint16(v)) 809 } 810 return colFmts, colFmtData 811 } 812} 813 814func (cn *conn) prepareTo(q, stmtName string) *stmt { 815 st := &stmt{cn: cn, name: stmtName} 816 817 b := cn.writeBuf('P') 818 b.string(st.name) 819 b.string(q) 820 b.int16(0) 821 822 b.next('D') 823 b.byte('S') 824 b.string(st.name) 825 826 b.next('S') 827 cn.send(b) 828 829 cn.readParseResponse() 830 st.paramTyps, st.colNames, st.colTyps = cn.readStatementDescribeResponse() 831 st.colFmts, st.colFmtData = decideColumnFormats(st.colTyps, cn.disablePreparedBinaryResult) 832 cn.readReadyForQuery() 833 return st 834} 835 836func (cn *conn) Prepare(q string) (_ driver.Stmt, err error) { 837 if err := cn.err.get(); err != nil { 838 return nil, err 839 } 840 defer cn.errRecover(&err) 841 842 if len(q) >= 4 && strings.EqualFold(q[:4], "COPY") { 843 s, err := cn.prepareCopyIn(q) 844 if err == nil { 845 cn.inCopy = true 846 } 847 return s, err 848 } 849 return cn.prepareTo(q, cn.gname()), nil 850} 851 852func (cn *conn) Close() (err error) { 853 // Skip cn.bad return here because we always want to close a connection. 854 defer cn.errRecover(&err) 855 856 // Ensure that cn.c.Close is always run. Since error handling is done with 857 // panics and cn.errRecover, the Close must be in a defer. 
858 defer func() { 859 cerr := cn.c.Close() 860 if err == nil { 861 err = cerr 862 } 863 }() 864 865 // Don't go through send(); ListenerConn relies on us not scribbling on the 866 // scratch buffer of this connection. 867 return cn.sendSimpleMessage('X') 868} 869 870// Implement the "Queryer" interface 871func (cn *conn) Query(query string, args []driver.Value) (driver.Rows, error) { 872 return cn.query(query, args) 873} 874 875func (cn *conn) query(query string, args []driver.Value) (_ *rows, err error) { 876 if err := cn.err.get(); err != nil { 877 return nil, err 878 } 879 if cn.inCopy { 880 return nil, errCopyInProgress 881 } 882 defer cn.errRecover(&err) 883 884 // Check to see if we can use the "simpleQuery" interface, which is 885 // *much* faster than going through prepare/exec 886 if len(args) == 0 { 887 return cn.simpleQuery(query) 888 } 889 890 if cn.binaryParameters { 891 cn.sendBinaryModeQuery(query, args) 892 893 cn.readParseResponse() 894 cn.readBindResponse() 895 rows := &rows{cn: cn} 896 rows.rowsHeader = cn.readPortalDescribeResponse() 897 cn.postExecuteWorkaround() 898 return rows, nil 899 } 900 st := cn.prepareTo(query, "") 901 st.exec(args) 902 return &rows{ 903 cn: cn, 904 rowsHeader: st.rowsHeader, 905 }, nil 906} 907 908// Implement the optional "Execer" interface for one-shot queries 909func (cn *conn) Exec(query string, args []driver.Value) (res driver.Result, err error) { 910 if err := cn.err.get(); err != nil { 911 return nil, err 912 } 913 defer cn.errRecover(&err) 914 915 // Check to see if we can use the "simpleExec" interface, which is 916 // *much* faster than going through prepare/exec 917 if len(args) == 0 { 918 // ignore commandTag, our caller doesn't care 919 r, _, err := cn.simpleExec(query) 920 return r, err 921 } 922 923 if cn.binaryParameters { 924 cn.sendBinaryModeQuery(query, args) 925 926 cn.readParseResponse() 927 cn.readBindResponse() 928 cn.readPortalDescribeResponse() 929 cn.postExecuteWorkaround() 930 res, _, err = 
cn.readExecuteResponse("Execute") 931 return res, err 932 } 933 // Use the unnamed statement to defer planning until bind 934 // time, or else value-based selectivity estimates cannot be 935 // used. 936 st := cn.prepareTo(query, "") 937 r, err := st.Exec(args) 938 if err != nil { 939 panic(err) 940 } 941 return r, err 942} 943 944type safeRetryError struct { 945 Err error 946} 947 948func (se *safeRetryError) Error() string { 949 return se.Err.Error() 950} 951 952func (cn *conn) send(m *writeBuf) { 953 n, err := cn.c.Write(m.wrap()) 954 if err != nil { 955 if n == 0 { 956 err = &safeRetryError{Err: err} 957 } 958 panic(err) 959 } 960} 961 962func (cn *conn) sendStartupPacket(m *writeBuf) error { 963 _, err := cn.c.Write((m.wrap())[1:]) 964 return err 965} 966 967// Send a message of type typ to the server on the other end of cn. The 968// message should have no payload. This method does not use the scratch 969// buffer. 970func (cn *conn) sendSimpleMessage(typ byte) (err error) { 971 _, err = cn.c.Write([]byte{typ, '\x00', '\x00', '\x00', '\x04'}) 972 return err 973} 974 975// saveMessage memorizes a message and its buffer in the conn struct. 976// recvMessage will then return these values on the next call to it. This 977// method is useful in cases where you have to see what the next message is 978// going to be (e.g. to see whether it's an error or not) but you can't handle 979// the message yourself. 980func (cn *conn) saveMessage(typ byte, buf *readBuf) { 981 if cn.saveMessageType != 0 { 982 cn.err.set(driver.ErrBadConn) 983 errorf("unexpected saveMessageType %d", cn.saveMessageType) 984 } 985 cn.saveMessageType = typ 986 cn.saveMessageBuffer = *buf 987} 988 989// recvMessage receives any message from the backend, or returns an error if 990// a problem occurred while reading the message. 
991func (cn *conn) recvMessage(r *readBuf) (byte, error) { 992 // workaround for a QueryRow bug, see exec 993 if cn.saveMessageType != 0 { 994 t := cn.saveMessageType 995 *r = cn.saveMessageBuffer 996 cn.saveMessageType = 0 997 cn.saveMessageBuffer = nil 998 return t, nil 999 } 1000 1001 x := cn.scratch[:5] 1002 _, err := io.ReadFull(cn.buf, x) 1003 if err != nil { 1004 return 0, err 1005 } 1006 1007 // read the type and length of the message that follows 1008 t := x[0] 1009 n := int(binary.BigEndian.Uint32(x[1:])) - 4 1010 var y []byte 1011 if n <= len(cn.scratch) { 1012 y = cn.scratch[:n] 1013 } else { 1014 y = make([]byte, n) 1015 } 1016 _, err = io.ReadFull(cn.buf, y) 1017 if err != nil { 1018 return 0, err 1019 } 1020 *r = y 1021 return t, nil 1022} 1023 1024// recv receives a message from the backend, but if an error happened while 1025// reading the message or the received message was an ErrorResponse, it panics. 1026// NoticeResponses are ignored. This function should generally be used only 1027// during the startup sequence. 1028func (cn *conn) recv() (t byte, r *readBuf) { 1029 for { 1030 var err error 1031 r = &readBuf{} 1032 t, err = cn.recvMessage(r) 1033 if err != nil { 1034 panic(err) 1035 } 1036 switch t { 1037 case 'E': 1038 panic(parseError(r)) 1039 case 'N': 1040 if n := cn.noticeHandler; n != nil { 1041 n(parseError(r)) 1042 } 1043 case 'A': 1044 if n := cn.notificationHandler; n != nil { 1045 n(recvNotification(r)) 1046 } 1047 default: 1048 return 1049 } 1050 } 1051} 1052 1053// recv1Buf is exactly equivalent to recv1, except it uses a buffer supplied by 1054// the caller to avoid an allocation. 
1055func (cn *conn) recv1Buf(r *readBuf) byte { 1056 for { 1057 t, err := cn.recvMessage(r) 1058 if err != nil { 1059 panic(err) 1060 } 1061 1062 switch t { 1063 case 'A': 1064 if n := cn.notificationHandler; n != nil { 1065 n(recvNotification(r)) 1066 } 1067 case 'N': 1068 if n := cn.noticeHandler; n != nil { 1069 n(parseError(r)) 1070 } 1071 case 'S': 1072 cn.processParameterStatus(r) 1073 default: 1074 return t 1075 } 1076 } 1077} 1078 1079// recv1 receives a message from the backend, panicking if an error occurs 1080// while attempting to read it. All asynchronous messages are ignored, with 1081// the exception of ErrorResponse. 1082func (cn *conn) recv1() (t byte, r *readBuf) { 1083 r = &readBuf{} 1084 t = cn.recv1Buf(r) 1085 return t, r 1086} 1087 1088func (cn *conn) ssl(o values) error { 1089 upgrade, err := ssl(o) 1090 if err != nil { 1091 return err 1092 } 1093 1094 if upgrade == nil { 1095 // Nothing to do 1096 return nil 1097 } 1098 1099 w := cn.writeBuf(0) 1100 w.int32(80877103) 1101 if err = cn.sendStartupPacket(w); err != nil { 1102 return err 1103 } 1104 1105 b := cn.scratch[:1] 1106 _, err = io.ReadFull(cn.c, b) 1107 if err != nil { 1108 return err 1109 } 1110 1111 if b[0] != 'S' { 1112 return ErrSSLNotSupported 1113 } 1114 1115 cn.c, err = upgrade(cn.c) 1116 return err 1117} 1118 1119// isDriverSetting returns true iff a setting is purely for configuring the 1120// driver's options and should not be sent to the server in the connection 1121// startup packet. 
1122func isDriverSetting(key string) bool { 1123 switch key { 1124 case "host", "port": 1125 return true 1126 case "password": 1127 return true 1128 case "sslmode", "sslcert", "sslkey", "sslrootcert", "sslinline": 1129 return true 1130 case "fallback_application_name": 1131 return true 1132 case "connect_timeout": 1133 return true 1134 case "disable_prepared_binary_result": 1135 return true 1136 case "binary_parameters": 1137 return true 1138 case "krbsrvname": 1139 return true 1140 case "krbspn": 1141 return true 1142 default: 1143 return false 1144 } 1145} 1146 1147func (cn *conn) startup(o values) { 1148 w := cn.writeBuf(0) 1149 w.int32(196608) 1150 // Send the backend the name of the database we want to connect to, and the 1151 // user we want to connect as. Additionally, we send over any run-time 1152 // parameters potentially included in the connection string. If the server 1153 // doesn't recognize any of them, it will reply with an error. 1154 for k, v := range o { 1155 if isDriverSetting(k) { 1156 // skip options which can't be run-time parameters 1157 continue 1158 } 1159 // The protocol requires us to supply the database name as "database" 1160 // instead of "dbname". 
1161 if k == "dbname" { 1162 k = "database" 1163 } 1164 w.string(k) 1165 w.string(v) 1166 } 1167 w.string("") 1168 if err := cn.sendStartupPacket(w); err != nil { 1169 panic(err) 1170 } 1171 1172 for { 1173 t, r := cn.recv() 1174 switch t { 1175 case 'K': 1176 cn.processBackendKeyData(r) 1177 case 'S': 1178 cn.processParameterStatus(r) 1179 case 'R': 1180 cn.auth(r, o) 1181 case 'Z': 1182 cn.processReadyForQuery(r) 1183 return 1184 default: 1185 errorf("unknown response for startup: %q", t) 1186 } 1187 } 1188} 1189 1190func (cn *conn) auth(r *readBuf, o values) { 1191 switch code := r.int32(); code { 1192 case 0: 1193 // OK 1194 case 3: 1195 w := cn.writeBuf('p') 1196 w.string(o["password"]) 1197 cn.send(w) 1198 1199 t, r := cn.recv() 1200 if t != 'R' { 1201 errorf("unexpected password response: %q", t) 1202 } 1203 1204 if r.int32() != 0 { 1205 errorf("unexpected authentication response: %q", t) 1206 } 1207 case 5: 1208 s := string(r.next(4)) 1209 w := cn.writeBuf('p') 1210 w.string("md5" + md5s(md5s(o["password"]+o["user"])+s)) 1211 cn.send(w) 1212 1213 t, r := cn.recv() 1214 if t != 'R' { 1215 errorf("unexpected password response: %q", t) 1216 } 1217 1218 if r.int32() != 0 { 1219 errorf("unexpected authentication response: %q", t) 1220 } 1221 case 7: // GSSAPI, startup 1222 if newGss == nil { 1223 errorf("kerberos error: no GSSAPI provider registered (import github.com/lib/pq/auth/kerberos if you need Kerberos support)") 1224 } 1225 cli, err := newGss() 1226 if err != nil { 1227 errorf("kerberos error: %s", err.Error()) 1228 } 1229 1230 var token []byte 1231 1232 if spn, ok := o["krbspn"]; ok { 1233 // Use the supplied SPN if provided.. 
1234 token, err = cli.GetInitTokenFromSpn(spn) 1235 } else { 1236 // Allow the kerberos service name to be overridden 1237 service := "postgres" 1238 if val, ok := o["krbsrvname"]; ok { 1239 service = val 1240 } 1241 1242 token, err = cli.GetInitToken(o["host"], service) 1243 } 1244 1245 if err != nil { 1246 errorf("failed to get Kerberos ticket: %q", err) 1247 } 1248 1249 w := cn.writeBuf('p') 1250 w.bytes(token) 1251 cn.send(w) 1252 1253 // Store for GSSAPI continue message 1254 cn.gss = cli 1255 1256 case 8: // GSSAPI continue 1257 1258 if cn.gss == nil { 1259 errorf("GSSAPI protocol error") 1260 } 1261 1262 b := []byte(*r) 1263 1264 done, tokOut, err := cn.gss.Continue(b) 1265 if err == nil && !done { 1266 w := cn.writeBuf('p') 1267 w.bytes(tokOut) 1268 cn.send(w) 1269 } 1270 1271 // Errors fall through and read the more detailed message 1272 // from the server.. 1273 1274 case 10: 1275 sc := scram.NewClient(sha256.New, o["user"], o["password"]) 1276 sc.Step(nil) 1277 if sc.Err() != nil { 1278 errorf("SCRAM-SHA-256 error: %s", sc.Err().Error()) 1279 } 1280 scOut := sc.Out() 1281 1282 w := cn.writeBuf('p') 1283 w.string("SCRAM-SHA-256") 1284 w.int32(len(scOut)) 1285 w.bytes(scOut) 1286 cn.send(w) 1287 1288 t, r := cn.recv() 1289 if t != 'R' { 1290 errorf("unexpected password response: %q", t) 1291 } 1292 1293 if r.int32() != 11 { 1294 errorf("unexpected authentication response: %q", t) 1295 } 1296 1297 nextStep := r.next(len(*r)) 1298 sc.Step(nextStep) 1299 if sc.Err() != nil { 1300 errorf("SCRAM-SHA-256 error: %s", sc.Err().Error()) 1301 } 1302 1303 scOut = sc.Out() 1304 w = cn.writeBuf('p') 1305 w.bytes(scOut) 1306 cn.send(w) 1307 1308 t, r = cn.recv() 1309 if t != 'R' { 1310 errorf("unexpected password response: %q", t) 1311 } 1312 1313 if r.int32() != 12 { 1314 errorf("unexpected authentication response: %q", t) 1315 } 1316 1317 nextStep = r.next(len(*r)) 1318 sc.Step(nextStep) 1319 if sc.Err() != nil { 1320 errorf("SCRAM-SHA-256 error: %s", 
sc.Err().Error()) 1321 } 1322 1323 default: 1324 errorf("unknown authentication response: %d", code) 1325 } 1326} 1327 1328type format int 1329 1330const formatText format = 0 1331const formatBinary format = 1 1332 1333// One result-column format code with the value 1 (i.e. all binary). 1334var colFmtDataAllBinary = []byte{0, 1, 0, 1} 1335 1336// No result-column format codes (i.e. all text). 1337var colFmtDataAllText = []byte{0, 0} 1338 1339type stmt struct { 1340 cn *conn 1341 name string 1342 rowsHeader 1343 colFmtData []byte 1344 paramTyps []oid.Oid 1345 closed bool 1346} 1347 1348func (st *stmt) Close() (err error) { 1349 if st.closed { 1350 return nil 1351 } 1352 if err := st.cn.err.get(); err != nil { 1353 return err 1354 } 1355 defer st.cn.errRecover(&err) 1356 1357 w := st.cn.writeBuf('C') 1358 w.byte('S') 1359 w.string(st.name) 1360 st.cn.send(w) 1361 1362 st.cn.send(st.cn.writeBuf('S')) 1363 1364 t, _ := st.cn.recv1() 1365 if t != '3' { 1366 st.cn.err.set(driver.ErrBadConn) 1367 errorf("unexpected close response: %q", t) 1368 } 1369 st.closed = true 1370 1371 t, r := st.cn.recv1() 1372 if t != 'Z' { 1373 st.cn.err.set(driver.ErrBadConn) 1374 errorf("expected ready for query, but got: %q", t) 1375 } 1376 st.cn.processReadyForQuery(r) 1377 1378 return nil 1379} 1380 1381func (st *stmt) Query(v []driver.Value) (r driver.Rows, err error) { 1382 return st.query(v) 1383} 1384 1385func (st *stmt) query(v []driver.Value) (r *rows, err error) { 1386 if err := st.cn.err.get(); err != nil { 1387 return nil, err 1388 } 1389 defer st.cn.errRecover(&err) 1390 1391 st.exec(v) 1392 return &rows{ 1393 cn: st.cn, 1394 rowsHeader: st.rowsHeader, 1395 }, nil 1396} 1397 1398func (st *stmt) Exec(v []driver.Value) (res driver.Result, err error) { 1399 if err := st.cn.err.get(); err != nil { 1400 return nil, err 1401 } 1402 defer st.cn.errRecover(&err) 1403 1404 st.exec(v) 1405 res, _, err = st.cn.readExecuteResponse("simple query") 1406 return res, err 1407} 1408 1409func (st 
*stmt) exec(v []driver.Value) { 1410 if len(v) >= 65536 { 1411 errorf("got %d parameters but PostgreSQL only supports 65535 parameters", len(v)) 1412 } 1413 if len(v) != len(st.paramTyps) { 1414 errorf("got %d parameters but the statement requires %d", len(v), len(st.paramTyps)) 1415 } 1416 1417 cn := st.cn 1418 w := cn.writeBuf('B') 1419 w.byte(0) // unnamed portal 1420 w.string(st.name) 1421 1422 if cn.binaryParameters { 1423 cn.sendBinaryParameters(w, v) 1424 } else { 1425 w.int16(0) 1426 w.int16(len(v)) 1427 for i, x := range v { 1428 if x == nil { 1429 w.int32(-1) 1430 } else { 1431 b := encode(&cn.parameterStatus, x, st.paramTyps[i]) 1432 w.int32(len(b)) 1433 w.bytes(b) 1434 } 1435 } 1436 } 1437 w.bytes(st.colFmtData) 1438 1439 w.next('E') 1440 w.byte(0) 1441 w.int32(0) 1442 1443 w.next('S') 1444 cn.send(w) 1445 1446 cn.readBindResponse() 1447 cn.postExecuteWorkaround() 1448 1449} 1450 1451func (st *stmt) NumInput() int { 1452 return len(st.paramTyps) 1453} 1454 1455// parseComplete parses the "command tag" from a CommandComplete message, and 1456// returns the number of rows affected (if applicable) and a string 1457// identifying only the command that was executed, e.g. "ALTER TABLE". If the 1458// command tag could not be parsed, parseComplete panics. 1459func (cn *conn) parseComplete(commandTag string) (driver.Result, string) { 1460 commandsWithAffectedRows := []string{ 1461 "SELECT ", 1462 // INSERT is handled below 1463 "UPDATE ", 1464 "DELETE ", 1465 "FETCH ", 1466 "MOVE ", 1467 "COPY ", 1468 } 1469 1470 var affectedRows *string 1471 for _, tag := range commandsWithAffectedRows { 1472 if strings.HasPrefix(commandTag, tag) { 1473 t := commandTag[len(tag):] 1474 affectedRows = &t 1475 commandTag = tag[:len(tag)-1] 1476 break 1477 } 1478 } 1479 // INSERT also includes the oid of the inserted row in its command tag. 
1480 // Oids in user tables are deprecated, and the oid is only returned when 1481 // exactly one row is inserted, so it's unlikely to be of value to any 1482 // real-world application and we can ignore it. 1483 if affectedRows == nil && strings.HasPrefix(commandTag, "INSERT ") { 1484 parts := strings.Split(commandTag, " ") 1485 if len(parts) != 3 { 1486 cn.err.set(driver.ErrBadConn) 1487 errorf("unexpected INSERT command tag %s", commandTag) 1488 } 1489 affectedRows = &parts[len(parts)-1] 1490 commandTag = "INSERT" 1491 } 1492 // There should be no affected rows attached to the tag, just return it 1493 if affectedRows == nil { 1494 return driver.RowsAffected(0), commandTag 1495 } 1496 n, err := strconv.ParseInt(*affectedRows, 10, 64) 1497 if err != nil { 1498 cn.err.set(driver.ErrBadConn) 1499 errorf("could not parse commandTag: %s", err) 1500 } 1501 return driver.RowsAffected(n), commandTag 1502} 1503 1504type rowsHeader struct { 1505 colNames []string 1506 colTyps []fieldDesc 1507 colFmts []format 1508} 1509 1510type rows struct { 1511 cn *conn 1512 finish func() 1513 rowsHeader 1514 done bool 1515 rb readBuf 1516 result driver.Result 1517 tag string 1518 1519 next *rowsHeader 1520} 1521 1522func (rs *rows) Close() error { 1523 if finish := rs.finish; finish != nil { 1524 defer finish() 1525 } 1526 // no need to look at cn.bad as Next() will 1527 for { 1528 err := rs.Next(nil) 1529 switch err { 1530 case nil: 1531 case io.EOF: 1532 // rs.Next can return io.EOF on both 'Z' (ready for query) and 'T' (row 1533 // description, used with HasNextResultSet). We need to fetch messages until 1534 // we hit a 'Z', which is done by waiting for done to be set. 
1535 if rs.done { 1536 return nil 1537 } 1538 default: 1539 return err 1540 } 1541 } 1542} 1543 1544func (rs *rows) Columns() []string { 1545 return rs.colNames 1546} 1547 1548func (rs *rows) Result() driver.Result { 1549 if rs.result == nil { 1550 return emptyRows 1551 } 1552 return rs.result 1553} 1554 1555func (rs *rows) Tag() string { 1556 return rs.tag 1557} 1558 1559func (rs *rows) Next(dest []driver.Value) (err error) { 1560 if rs.done { 1561 return io.EOF 1562 } 1563 1564 conn := rs.cn 1565 if err := conn.err.getForNext(); err != nil { 1566 return err 1567 } 1568 defer conn.errRecover(&err) 1569 1570 for { 1571 t := conn.recv1Buf(&rs.rb) 1572 switch t { 1573 case 'E': 1574 err = parseError(&rs.rb) 1575 case 'C', 'I': 1576 if t == 'C' { 1577 rs.result, rs.tag = conn.parseComplete(rs.rb.string()) 1578 } 1579 continue 1580 case 'Z': 1581 conn.processReadyForQuery(&rs.rb) 1582 rs.done = true 1583 if err != nil { 1584 return err 1585 } 1586 return io.EOF 1587 case 'D': 1588 n := rs.rb.int16() 1589 if err != nil { 1590 conn.err.set(driver.ErrBadConn) 1591 errorf("unexpected DataRow after error %s", err) 1592 } 1593 if n < len(dest) { 1594 dest = dest[:n] 1595 } 1596 for i := range dest { 1597 l := rs.rb.int32() 1598 if l == -1 { 1599 dest[i] = nil 1600 continue 1601 } 1602 dest[i] = decode(&conn.parameterStatus, rs.rb.next(l), rs.colTyps[i].OID, rs.colFmts[i]) 1603 } 1604 return 1605 case 'T': 1606 next := parsePortalRowDescribe(&rs.rb) 1607 rs.next = &next 1608 return io.EOF 1609 default: 1610 errorf("unexpected message after execute: %q", t) 1611 } 1612 } 1613} 1614 1615func (rs *rows) HasNextResultSet() bool { 1616 hasNext := rs.next != nil && !rs.done 1617 return hasNext 1618} 1619 1620func (rs *rows) NextResultSet() error { 1621 if rs.next == nil { 1622 return io.EOF 1623 } 1624 rs.rowsHeader = *rs.next 1625 rs.next = nil 1626 return nil 1627} 1628 1629// QuoteIdentifier quotes an "identifier" (e.g. 
a table or a column name) to be 1630// used as part of an SQL statement. For example: 1631// 1632// tblname := "my_table" 1633// data := "my_data" 1634// quoted := pq.QuoteIdentifier(tblname) 1635// err := db.Exec(fmt.Sprintf("INSERT INTO %s VALUES ($1)", quoted), data) 1636// 1637// Any double quotes in name will be escaped. The quoted identifier will be 1638// case sensitive when used in a query. If the input string contains a zero 1639// byte, the result will be truncated immediately before it. 1640func QuoteIdentifier(name string) string { 1641 end := strings.IndexRune(name, 0) 1642 if end > -1 { 1643 name = name[:end] 1644 } 1645 return `"` + strings.Replace(name, `"`, `""`, -1) + `"` 1646} 1647 1648// QuoteLiteral quotes a 'literal' (e.g. a parameter, often used to pass literal 1649// to DDL and other statements that do not accept parameters) to be used as part 1650// of an SQL statement. For example: 1651// 1652// exp_date := pq.QuoteLiteral("2023-01-05 15:00:00Z") 1653// err := db.Exec(fmt.Sprintf("CREATE ROLE my_user VALID UNTIL %s", exp_date)) 1654// 1655// Any single quotes in name will be escaped. Any backslashes (i.e. "\") will be 1656// replaced by two backslashes (i.e. "\\") and the C-style escape identifier 1657// that PostgreSQL provides ('E') will be prepended to the string. 1658func QuoteLiteral(literal string) string { 1659 // This follows the PostgreSQL internal algorithm for handling quoted literals 1660 // from libpq, which can be found in the "PQEscapeStringInternal" function, 1661 // which is found in the libpq/fe-exec.c source file: 1662 // https://git.postgresql.org/gitweb/?p=postgresql.git;a=blob;f=src/interfaces/libpq/fe-exec.c 1663 // 1664 // substitute any single-quotes (') with two single-quotes ('') 1665 literal = strings.Replace(literal, `'`, `''`, -1) 1666 // determine if the string has any backslashes (\) in it. 
1667 // if it does, replace any backslashes (\) with two backslashes (\\) 1668 // then, we need to wrap the entire string with a PostgreSQL 1669 // C-style escape. Per how "PQEscapeStringInternal" handles this case, we 1670 // also add a space before the "E" 1671 if strings.Contains(literal, `\`) { 1672 literal = strings.Replace(literal, `\`, `\\`, -1) 1673 literal = ` E'` + literal + `'` 1674 } else { 1675 // otherwise, we can just wrap the literal with a pair of single quotes 1676 literal = `'` + literal + `'` 1677 } 1678 return literal 1679} 1680 1681func md5s(s string) string { 1682 h := md5.New() 1683 h.Write([]byte(s)) 1684 return fmt.Sprintf("%x", h.Sum(nil)) 1685} 1686 1687func (cn *conn) sendBinaryParameters(b *writeBuf, args []driver.Value) { 1688 // Do one pass over the parameters to see if we're going to send any of 1689 // them over in binary. If we are, create a paramFormats array at the 1690 // same time. 1691 var paramFormats []int 1692 for i, x := range args { 1693 _, ok := x.([]byte) 1694 if ok { 1695 if paramFormats == nil { 1696 paramFormats = make([]int, len(args)) 1697 } 1698 paramFormats[i] = 1 1699 } 1700 } 1701 if paramFormats == nil { 1702 b.int16(0) 1703 } else { 1704 b.int16(len(paramFormats)) 1705 for _, x := range paramFormats { 1706 b.int16(x) 1707 } 1708 } 1709 1710 b.int16(len(args)) 1711 for _, x := range args { 1712 if x == nil { 1713 b.int32(-1) 1714 } else { 1715 datum := binaryEncode(&cn.parameterStatus, x) 1716 b.int32(len(datum)) 1717 b.bytes(datum) 1718 } 1719 } 1720} 1721 1722func (cn *conn) sendBinaryModeQuery(query string, args []driver.Value) { 1723 if len(args) >= 65536 { 1724 errorf("got %d parameters but PostgreSQL only supports 65535 parameters", len(args)) 1725 } 1726 1727 b := cn.writeBuf('P') 1728 b.byte(0) // unnamed statement 1729 b.string(query) 1730 b.int16(0) 1731 1732 b.next('B') 1733 b.int16(0) // unnamed portal and statement 1734 cn.sendBinaryParameters(b, args) 1735 b.bytes(colFmtDataAllText) 1736 1737 
b.next('D') 1738 b.byte('P') 1739 b.byte(0) // unnamed portal 1740 1741 b.next('E') 1742 b.byte(0) 1743 b.int32(0) 1744 1745 b.next('S') 1746 cn.send(b) 1747} 1748 1749func (cn *conn) processParameterStatus(r *readBuf) { 1750 var err error 1751 1752 param := r.string() 1753 switch param { 1754 case "server_version": 1755 var major1 int 1756 var major2 int 1757 _, err = fmt.Sscanf(r.string(), "%d.%d", &major1, &major2) 1758 if err == nil { 1759 cn.parameterStatus.serverVersion = major1*10000 + major2*100 1760 } 1761 1762 case "TimeZone": 1763 cn.parameterStatus.currentLocation, err = time.LoadLocation(r.string()) 1764 if err != nil { 1765 cn.parameterStatus.currentLocation = nil 1766 } 1767 1768 default: 1769 // ignore 1770 } 1771} 1772 1773func (cn *conn) processReadyForQuery(r *readBuf) { 1774 cn.txnStatus = transactionStatus(r.byte()) 1775} 1776 1777func (cn *conn) readReadyForQuery() { 1778 t, r := cn.recv1() 1779 switch t { 1780 case 'Z': 1781 cn.processReadyForQuery(r) 1782 return 1783 default: 1784 cn.err.set(driver.ErrBadConn) 1785 errorf("unexpected message %q; expected ReadyForQuery", t) 1786 } 1787} 1788 1789func (cn *conn) processBackendKeyData(r *readBuf) { 1790 cn.processID = r.int32() 1791 cn.secretKey = r.int32() 1792} 1793 1794func (cn *conn) readParseResponse() { 1795 t, r := cn.recv1() 1796 switch t { 1797 case '1': 1798 return 1799 case 'E': 1800 err := parseError(r) 1801 cn.readReadyForQuery() 1802 panic(err) 1803 default: 1804 cn.err.set(driver.ErrBadConn) 1805 errorf("unexpected Parse response %q", t) 1806 } 1807} 1808 1809func (cn *conn) readStatementDescribeResponse() (paramTyps []oid.Oid, colNames []string, colTyps []fieldDesc) { 1810 for { 1811 t, r := cn.recv1() 1812 switch t { 1813 case 't': 1814 nparams := r.int16() 1815 paramTyps = make([]oid.Oid, nparams) 1816 for i := range paramTyps { 1817 paramTyps[i] = r.oid() 1818 } 1819 case 'n': 1820 return paramTyps, nil, nil 1821 case 'T': 1822 colNames, colTyps = parseStatementRowDescribe(r) 
1823 return paramTyps, colNames, colTyps 1824 case 'E': 1825 err := parseError(r) 1826 cn.readReadyForQuery() 1827 panic(err) 1828 default: 1829 cn.err.set(driver.ErrBadConn) 1830 errorf("unexpected Describe statement response %q", t) 1831 } 1832 } 1833} 1834 1835func (cn *conn) readPortalDescribeResponse() rowsHeader { 1836 t, r := cn.recv1() 1837 switch t { 1838 case 'T': 1839 return parsePortalRowDescribe(r) 1840 case 'n': 1841 return rowsHeader{} 1842 case 'E': 1843 err := parseError(r) 1844 cn.readReadyForQuery() 1845 panic(err) 1846 default: 1847 cn.err.set(driver.ErrBadConn) 1848 errorf("unexpected Describe response %q", t) 1849 } 1850 panic("not reached") 1851} 1852 1853func (cn *conn) readBindResponse() { 1854 t, r := cn.recv1() 1855 switch t { 1856 case '2': 1857 return 1858 case 'E': 1859 err := parseError(r) 1860 cn.readReadyForQuery() 1861 panic(err) 1862 default: 1863 cn.err.set(driver.ErrBadConn) 1864 errorf("unexpected Bind response %q", t) 1865 } 1866} 1867 1868func (cn *conn) postExecuteWorkaround() { 1869 // Work around a bug in sql.DB.QueryRow: in Go 1.2 and earlier it ignores 1870 // any errors from rows.Next, which masks errors that happened during the 1871 // execution of the query. To avoid the problem in common cases, we wait 1872 // here for one more message from the database. If it's not an error the 1873 // query will likely succeed (or perhaps has already, if it's a 1874 // CommandComplete), so we push the message into the conn struct; recv1 1875 // will return it as the next message for rows.Next or rows.Close. 1876 // However, if it's an error, we wait until ReadyForQuery and then return 1877 // the error to our caller. 
1878 for { 1879 t, r := cn.recv1() 1880 switch t { 1881 case 'E': 1882 err := parseError(r) 1883 cn.readReadyForQuery() 1884 panic(err) 1885 case 'C', 'D', 'I': 1886 // the query didn't fail, but we can't process this message 1887 cn.saveMessage(t, r) 1888 return 1889 default: 1890 cn.err.set(driver.ErrBadConn) 1891 errorf("unexpected message during extended query execution: %q", t) 1892 } 1893 } 1894} 1895 1896// Only for Exec(), since we ignore the returned data 1897func (cn *conn) readExecuteResponse(protocolState string) (res driver.Result, commandTag string, err error) { 1898 for { 1899 t, r := cn.recv1() 1900 switch t { 1901 case 'C': 1902 if err != nil { 1903 cn.err.set(driver.ErrBadConn) 1904 errorf("unexpected CommandComplete after error %s", err) 1905 } 1906 res, commandTag = cn.parseComplete(r.string()) 1907 case 'Z': 1908 cn.processReadyForQuery(r) 1909 if res == nil && err == nil { 1910 err = errUnexpectedReady 1911 } 1912 return res, commandTag, err 1913 case 'E': 1914 err = parseError(r) 1915 case 'T', 'D', 'I': 1916 if err != nil { 1917 cn.err.set(driver.ErrBadConn) 1918 errorf("unexpected %q after error %s", t, err) 1919 } 1920 if t == 'I' { 1921 res = emptyRows 1922 } 1923 // ignore any results 1924 default: 1925 cn.err.set(driver.ErrBadConn) 1926 errorf("unknown %s response: %q", protocolState, t) 1927 } 1928 } 1929} 1930 1931func parseStatementRowDescribe(r *readBuf) (colNames []string, colTyps []fieldDesc) { 1932 n := r.int16() 1933 colNames = make([]string, n) 1934 colTyps = make([]fieldDesc, n) 1935 for i := range colNames { 1936 colNames[i] = r.string() 1937 r.next(6) 1938 colTyps[i].OID = r.oid() 1939 colTyps[i].Len = r.int16() 1940 colTyps[i].Mod = r.int32() 1941 // format code not known when describing a statement; always 0 1942 r.next(2) 1943 } 1944 return 1945} 1946 1947func parsePortalRowDescribe(r *readBuf) rowsHeader { 1948 n := r.int16() 1949 colNames := make([]string, n) 1950 colFmts := make([]format, n) 1951 colTyps := 
make([]fieldDesc, n) 1952 for i := range colNames { 1953 colNames[i] = r.string() 1954 r.next(6) 1955 colTyps[i].OID = r.oid() 1956 colTyps[i].Len = r.int16() 1957 colTyps[i].Mod = r.int32() 1958 colFmts[i] = format(r.int16()) 1959 } 1960 return rowsHeader{ 1961 colNames: colNames, 1962 colFmts: colFmts, 1963 colTyps: colTyps, 1964 } 1965} 1966 1967// parseEnviron tries to mimic some of libpq's environment handling 1968// 1969// To ease testing, it does not directly reference os.Environ, but is 1970// designed to accept its output. 1971// 1972// Environment-set connection information is intended to have a higher 1973// precedence than a library default but lower than any explicitly 1974// passed information (such as in the URL or connection string). 1975func parseEnviron(env []string) (out map[string]string) { 1976 out = make(map[string]string) 1977 1978 for _, v := range env { 1979 parts := strings.SplitN(v, "=", 2) 1980 1981 accrue := func(keyname string) { 1982 out[keyname] = parts[1] 1983 } 1984 unsupported := func() { 1985 panic(fmt.Sprintf("setting %v not supported", parts[0])) 1986 } 1987 1988 // The order of these is the same as is seen in the 1989 // PostgreSQL 9.1 manual. Unsupported but well-defined 1990 // keys cause a panic; these should be unset prior to 1991 // execution. Options which pq expects to be set to a 1992 // certain value are allowed, but must be set to that 1993 // value if present (they can, of course, be absent). 
1994 switch parts[0] { 1995 case "PGHOST": 1996 accrue("host") 1997 case "PGHOSTADDR": 1998 unsupported() 1999 case "PGPORT": 2000 accrue("port") 2001 case "PGDATABASE": 2002 accrue("dbname") 2003 case "PGUSER": 2004 accrue("user") 2005 case "PGPASSWORD": 2006 accrue("password") 2007 case "PGSERVICE", "PGSERVICEFILE", "PGREALM": 2008 unsupported() 2009 case "PGOPTIONS": 2010 accrue("options") 2011 case "PGAPPNAME": 2012 accrue("application_name") 2013 case "PGSSLMODE": 2014 accrue("sslmode") 2015 case "PGSSLCERT": 2016 accrue("sslcert") 2017 case "PGSSLKEY": 2018 accrue("sslkey") 2019 case "PGSSLROOTCERT": 2020 accrue("sslrootcert") 2021 case "PGREQUIRESSL", "PGSSLCRL": 2022 unsupported() 2023 case "PGREQUIREPEER": 2024 unsupported() 2025 case "PGKRBSRVNAME", "PGGSSLIB": 2026 unsupported() 2027 case "PGCONNECT_TIMEOUT": 2028 accrue("connect_timeout") 2029 case "PGCLIENTENCODING": 2030 accrue("client_encoding") 2031 case "PGDATESTYLE": 2032 accrue("datestyle") 2033 case "PGTZ": 2034 accrue("timezone") 2035 case "PGGEQO": 2036 accrue("geqo") 2037 case "PGSYSCONFDIR", "PGLOCALEDIR": 2038 unsupported() 2039 } 2040 } 2041 2042 return out 2043} 2044 2045// isUTF8 returns whether name is a fuzzy variation of the string "UTF-8". 2046func isUTF8(name string) bool { 2047 // Recognize all sorts of silly things as "UTF-8", like Postgres does 2048 s := strings.Map(alnumLowerASCII, name) 2049 return s == "utf8" || s == "unicode" 2050} 2051 2052func alnumLowerASCII(ch rune) rune { 2053 if 'A' <= ch && ch <= 'Z' { 2054 return ch + ('a' - 'A') 2055 } 2056 if 'a' <= ch && ch <= 'z' || '0' <= ch && ch <= '9' { 2057 return ch 2058 } 2059 return -1 // discard 2060} 2061