1package pgx 2 3import ( 4 "context" 5 "crypto/md5" 6 "crypto/tls" 7 "crypto/x509" 8 "encoding/binary" 9 "encoding/hex" 10 "fmt" 11 "io" 12 "io/ioutil" 13 "net" 14 "net/url" 15 "os" 16 "os/user" 17 "path/filepath" 18 "reflect" 19 "regexp" 20 "strconv" 21 "strings" 22 "sync" 23 "time" 24 25 "github.com/pkg/errors" 26 27 "github.com/jackc/pgx/pgio" 28 "github.com/jackc/pgx/pgproto3" 29 "github.com/jackc/pgx/pgtype" 30) 31 32const ( 33 connStatusUninitialized = iota 34 connStatusClosed 35 connStatusIdle 36 connStatusBusy 37) 38 39// minimalConnInfo has just enough static type information to establish the 40// connection and retrieve the type data. 41var minimalConnInfo *pgtype.ConnInfo 42 43func init() { 44 minimalConnInfo = pgtype.NewConnInfo() 45 minimalConnInfo.InitializeDataTypes(map[string]pgtype.OID{ 46 "int4": pgtype.Int4OID, 47 "name": pgtype.NameOID, 48 "oid": pgtype.OIDOID, 49 "text": pgtype.TextOID, 50 "varchar": pgtype.VarcharOID, 51 }) 52} 53 54// NoticeHandler is a function that can handle notices received from the 55// PostgreSQL server. Notices can be received at any time, usually during 56// handling of a query response. The *Conn is provided so the handler is aware 57// of the origin of the notice, but it must not invoke any query method. Be 58// aware that this is distinct from LISTEN/NOTIFY notification. 59type NoticeHandler func(*Conn, *Notice) 60 61// DialFunc is a function that can be used to connect to a PostgreSQL server 62type DialFunc func(network, addr string) (net.Conn, error) 63 64// ConnConfig contains all the options used to establish a connection. 65type ConnConfig struct { 66 Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp) 67 Port uint16 // default: 5432 68 Database string 69 User string // default: OS user name 70 Password string 71 TLSConfig *tls.Config // config for TLS connection -- nil disables TLS 72 UseFallbackTLS bool // Try FallbackTLSConfig if connecting with TLSConfig fails. Used for preferring TLS, but allowing unencrypted, or vice-versa 73 FallbackTLSConfig *tls.Config // config for fallback TLS connection (only used if UseFallBackTLS is true)-- nil disables TLS 74 Logger Logger 75 LogLevel int 76 Dial DialFunc 77 RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name) 78 OnNotice NoticeHandler // Callback function called when a notice response is received. 79 CustomConnInfo func(*Conn) (*pgtype.ConnInfo, error) // Callback function to implement connection strategies for different backends. crate, pgbouncer, pgpool, etc. 80 CustomCancel func(*Conn) error // Callback function used to override cancellation behavior 81 82 // PreferSimpleProtocol disables implicit prepared statement usage. By default 83 // pgx automatically uses the unnamed prepared statement for Query and 84 // QueryRow. It also uses a prepared statement when Exec has arguments. This 85 // can improve performance due to being able to use the binary format. It also 86 // does not rely on client side parameter sanitization. However, it does incur 87 // two round-trips per query and may be incompatible proxies such as 88 // PGBouncer. Setting PreferSimpleProtocol causes the simple protocol to be 89 // used by default. The same functionality can be controlled on a per query 90 // basis by setting QueryExOptions.SimpleProtocol. 91 PreferSimpleProtocol bool 92} 93 94func (cc *ConnConfig) networkAddress() (network, address string) { 95 network = "tcp" 96 address = fmt.Sprintf("%s:%d", cc.Host, cc.Port) 97 // See if host is a valid path, if yes connect with a socket 98 if _, err := os.Stat(cc.Host); err == nil { 99 // For backward compatibility accept socket file paths -- but directories are now preferred 100 network = "unix" 101 address = cc.Host 102 if !strings.Contains(address, "/.s.PGSQL.") { 103 address = filepath.Join(address, ".s.PGSQL.") + strconv.FormatInt(int64(cc.Port), 10) 104 } 105 } 106 107 return network, address 108} 109 110// Conn is a PostgreSQL connection handle. It is not safe for concurrent usage. 111// Use ConnPool to manage access to multiple database connections from multiple 112// goroutines. 113type Conn struct { 114 conn net.Conn // the underlying TCP or unix domain socket connection 115 lastActivityTime time.Time // the last time the connection was used 116 wbuf []byte 117 pid uint32 // backend pid 118 secretKey uint32 // key to use to send a cancel query message to the server 119 RuntimeParams map[string]string // parameters that have been reported by the server 120 config ConnConfig // config used when establishing this connection 121 txStatus byte 122 preparedStatements map[string]*PreparedStatement 123 channels map[string]struct{} 124 notifications []*Notification 125 logger Logger 126 logLevel int 127 fp *fastpath 128 poolResetCount int 129 preallocatedRows []Rows 130 onNotice NoticeHandler 131 132 mux sync.Mutex 133 status byte // One of connStatus* constants 134 causeOfDeath error 135 136 pendingReadyForQueryCount int // number of ReadyForQuery messages expected 137 cancelQueryCompleted chan struct{} 138 lastStmtSent bool 139 140 // context support 141 ctxInProgress bool 142 doneChan chan struct{} 143 closedChan chan error 144 145 ConnInfo *pgtype.ConnInfo 146 147 frontend *pgproto3.Frontend 148} 149 150// PreparedStatement is a description of a prepared statement 151type PreparedStatement struct { 152 Name string 153 SQL string 154 FieldDescriptions []FieldDescription 155 ParameterOIDs []pgtype.OID 156} 157 158// PrepareExOptions is an option struct that can be passed to PrepareEx 159type PrepareExOptions struct { 160 ParameterOIDs []pgtype.OID 161} 162 163// Notification is a message received from the PostgreSQL LISTEN/NOTIFY system 164type Notification struct { 165 PID uint32 // backend pid that sent the notification 166 Channel string // channel from which notification was received 167 Payload string 168} 169 170// CommandTag is the result of an Exec function 171type CommandTag string 172 173// RowsAffected returns the number of rows affected. If the CommandTag was not 174// for a row affecting command (such as "CREATE TABLE") then it returns 0 175func (ct CommandTag) RowsAffected() int64 { 176 s := string(ct) 177 index := strings.LastIndex(s, " ") 178 if index == -1 { 179 return 0 180 } 181 n, _ := strconv.ParseInt(s[index+1:], 10, 64) 182 return n 183} 184 185// Identifier a PostgreSQL identifier or name. Identifiers can be composed of 186// multiple parts such as ["schema", "table"] or ["table", "column"]. 187type Identifier []string 188 189// Sanitize returns a sanitized string safe for SQL interpolation. 190func (ident Identifier) Sanitize() string { 191 parts := make([]string, len(ident)) 192 for i := range ident { 193 parts[i] = `"` + strings.Replace(ident[i], `"`, `""`, -1) + `"` 194 } 195 return strings.Join(parts, ".") 196} 197 198// ErrNoRows occurs when rows are expected but none are returned. 199var ErrNoRows = errors.New("no rows in result set") 200 201// ErrDeadConn occurs on an attempt to use a dead connection 202var ErrDeadConn = errors.New("conn is dead") 203 204// ErrTLSRefused occurs when the connection attempt requires TLS and the 205// PostgreSQL server refuses to use TLS 206var ErrTLSRefused = errors.New("server refused TLS connection") 207 208// ErrConnBusy occurs when the connection is busy (for example, in the middle of 209// reading query results) and another action is attempted. 210var ErrConnBusy = errors.New("conn is busy") 211 212// ErrInvalidLogLevel occurs on attempt to set an invalid log level. 213var ErrInvalidLogLevel = errors.New("invalid log level") 214 215// ProtocolError occurs when unexpected data is received from PostgreSQL 216type ProtocolError string 217 218func (e ProtocolError) Error() string { 219 return string(e) 220} 221 222// Connect establishes a connection with a PostgreSQL server using config. 223// config.Host must be specified. config.User will default to the OS user name. 224// Other config fields are optional. 225func Connect(config ConnConfig) (c *Conn, err error) { 226 return connect(config, minimalConnInfo) 227} 228 229func defaultDialer() *net.Dialer { 230 return &net.Dialer{KeepAlive: 5 * time.Minute} 231} 232 233func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error) { 234 c = new(Conn) 235 236 c.config = config 237 c.ConnInfo = connInfo 238 239 if c.config.LogLevel != 0 { 240 c.logLevel = c.config.LogLevel 241 } else { 242 // Preserve pre-LogLevel behavior by defaulting to LogLevelDebug 243 c.logLevel = LogLevelDebug 244 } 245 c.logger = c.config.Logger 246 247 if c.config.User == "" { 248 user, err := user.Current() 249 if err != nil { 250 return nil, err 251 } 252 c.config.User = user.Username 253 if c.shouldLog(LogLevelDebug) { 254 c.log(LogLevelDebug, "Using default connection config", map[string]interface{}{"User": c.config.User}) 255 } 256 } 257 258 if c.config.Port == 0 { 259 c.config.Port = 5432 260 if c.shouldLog(LogLevelDebug) { 261 c.log(LogLevelDebug, "Using default connection config", map[string]interface{}{"Port": c.config.Port}) 262 } 263 } 264 265 c.onNotice = config.OnNotice 266 267 network, address := c.config.networkAddress() 268 if c.config.Dial == nil { 269 d := defaultDialer() 270 c.config.Dial = d.Dial 271 } 272 273 if c.shouldLog(LogLevelInfo) { 274 c.log(LogLevelInfo, "Dialing PostgreSQL server", map[string]interface{}{"network": network, "address": address}) 275 } 276 err = c.connect(config, network, address, config.TLSConfig) 277 if err != nil && config.UseFallbackTLS { 278 if c.shouldLog(LogLevelInfo) { 279 c.log(LogLevelInfo, "connect with TLSConfig failed, trying FallbackTLSConfig", map[string]interface{}{"err": err}) 280 } 281 err = c.connect(config, network, address, config.FallbackTLSConfig) 282 } 283 284 if err != nil { 285 if c.shouldLog(LogLevelError) { 286 c.log(LogLevelError, "connect failed", map[string]interface{}{"err": err}) 287 } 288 return nil, err 289 } 290 291 return c, nil 292} 293 294func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tls.Config) (err error) { 295 c.conn, err = c.config.Dial(network, address) 296 if err != nil { 297 return err 298 } 299 defer func() { 300 if c != nil && err != nil { 301 c.conn.Close() 302 c.mux.Lock() 303 c.status = connStatusClosed 304 c.mux.Unlock() 305 } 306 }() 307 308 c.RuntimeParams = make(map[string]string) 309 c.preparedStatements = make(map[string]*PreparedStatement) 310 c.channels = make(map[string]struct{}) 311 c.lastActivityTime = time.Now() 312 c.cancelQueryCompleted = make(chan struct{}) 313 close(c.cancelQueryCompleted) 314 c.doneChan = make(chan struct{}) 315 c.closedChan = make(chan error) 316 c.wbuf = make([]byte, 0, 1024) 317 318 c.mux.Lock() 319 c.status = connStatusIdle 320 c.mux.Unlock() 321 322 if tlsConfig != nil { 323 if c.shouldLog(LogLevelDebug) { 324 c.log(LogLevelDebug, "starting TLS handshake", nil) 325 } 326 if err := c.startTLS(tlsConfig); err != nil { 327 return err 328 } 329 } 330 331 c.frontend, err = pgproto3.NewFrontend(c.conn, c.conn) 332 if err != nil { 333 return err 334 } 335 336 startupMsg := pgproto3.StartupMessage{ 337 ProtocolVersion: pgproto3.ProtocolVersionNumber, 338 Parameters: make(map[string]string), 339 } 340 341 // Copy default run-time params 342 for k, v := range config.RuntimeParams { 343 startupMsg.Parameters[k] = v 344 } 345 346 startupMsg.Parameters["user"] = c.config.User 347 if c.config.Database != "" { 348 startupMsg.Parameters["database"] = c.config.Database 349 } 350 351 if _, err := c.conn.Write(startupMsg.Encode(nil)); err != nil { 352 return err 353 } 354 355 c.pendingReadyForQueryCount = 1 356 357 for { 358 msg, err := c.rxMsg() 359 if err != nil { 360 return err 361 } 362 363 switch msg := msg.(type) { 364 case *pgproto3.BackendKeyData: 365 c.rxBackendKeyData(msg) 366 case *pgproto3.Authentication: 367 if err = c.rxAuthenticationX(msg); err != nil { 368 return err 369 } 370 case *pgproto3.ReadyForQuery: 371 c.rxReadyForQuery(msg) 372 if c.shouldLog(LogLevelInfo) { 373 c.log(LogLevelInfo, "connection established", nil) 374 } 375 376 // Replication connections can't execute the queries to 377 // populate the c.PgTypes and c.pgsqlAfInet 378 if _, ok := config.RuntimeParams["replication"]; ok { 379 return nil 380 } 381 382 if c.ConnInfo == minimalConnInfo { 383 err = c.initConnInfo() 384 if err != nil { 385 return err 386 } 387 } 388 389 return nil 390 default: 391 if err = c.processContextFreeMsg(msg); err != nil { 392 return err 393 } 394 } 395 } 396} 397 398func initPostgresql(c *Conn) (*pgtype.ConnInfo, error) { 399 const ( 400 namedOIDQuery = `select t.oid, 401 case when nsp.nspname in ('pg_catalog', 'public') then t.typname 402 else nsp.nspname||'.'||t.typname 403 end 404from pg_type t 405left join pg_type base_type on t.typelem=base_type.oid 406left join pg_namespace nsp on t.typnamespace=nsp.oid 407where ( 408 t.typtype in('b', 'p', 'r', 'e') 409 and (base_type.oid is null or base_type.typtype in('b', 'p', 'r')) 410 )` 411 ) 412 413 nameOIDs, err := connInfoFromRows(c.Query(namedOIDQuery)) 414 if err != nil { 415 return nil, err 416 } 417 418 cinfo := pgtype.NewConnInfo() 419 cinfo.InitializeDataTypes(nameOIDs) 420 421 if err = c.initConnInfoEnumArray(cinfo); err != nil { 422 return nil, err 423 } 424 425 if err = c.initConnInfoDomains(cinfo); err != nil { 426 return nil, err 427 } 428 429 return cinfo, nil 430} 431 432func (c *Conn) initConnInfo() (err error) { 433 var ( 434 connInfo *pgtype.ConnInfo 435 ) 436 437 if c.config.CustomConnInfo != nil { 438 if c.ConnInfo, err = c.config.CustomConnInfo(c); err != nil { 439 return err 440 } 441 442 return nil 443 } 444 445 if connInfo, err = initPostgresql(c); err == nil { 446 c.ConnInfo = connInfo 447 return err 448 } 449 450 // Check if CrateDB specific approach might still allow us to connect. 451 if connInfo, err = c.crateDBTypesQuery(err); err == nil { 452 c.ConnInfo = connInfo 453 } 454 455 return err 456} 457 458// initConnInfoEnumArray introspects for arrays of enums and registers a data type for them. 459func (c *Conn) initConnInfoEnumArray(cinfo *pgtype.ConnInfo) error { 460 nameOIDs := make(map[string]pgtype.OID, 16) 461 rows, err := c.Query(`select t.oid, t.typname 462from pg_type t 463 join pg_type base_type on t.typelem=base_type.oid 464where t.typtype = 'b' 465 and base_type.typtype = 'e'`) 466 if err != nil { 467 return err 468 } 469 470 for rows.Next() { 471 var oid pgtype.OID 472 var name pgtype.Text 473 if err := rows.Scan(&oid, &name); err != nil { 474 return err 475 } 476 477 nameOIDs[name.String] = oid 478 } 479 480 if rows.Err() != nil { 481 return rows.Err() 482 } 483 484 for name, oid := range nameOIDs { 485 cinfo.RegisterDataType(pgtype.DataType{ 486 Value: &pgtype.EnumArray{}, 487 Name: name, 488 OID: oid, 489 }) 490 } 491 492 return nil 493} 494 495// initConnInfoDomains introspects for domains and registers a data type for them. 496func (c *Conn) initConnInfoDomains(cinfo *pgtype.ConnInfo) error { 497 type domain struct { 498 oid pgtype.OID 499 name pgtype.Text 500 baseOID pgtype.OID 501 } 502 503 domains := make([]*domain, 0, 16) 504 505 rows, err := c.Query(`select t.oid, t.typname, t.typbasetype 506from pg_type t 507 join pg_type base_type on t.typbasetype=base_type.oid 508where t.typtype = 'd' 509 and base_type.typtype = 'b'`) 510 if err != nil { 511 return err 512 } 513 514 for rows.Next() { 515 var d domain 516 if err := rows.Scan(&d.oid, &d.name, &d.baseOID); err != nil { 517 return err 518 } 519 520 domains = append(domains, &d) 521 } 522 523 if rows.Err() != nil { 524 return rows.Err() 525 } 526 527 for _, d := range domains { 528 baseDataType, ok := cinfo.DataTypeForOID(d.baseOID) 529 if ok { 530 cinfo.RegisterDataType(pgtype.DataType{ 531 Value: reflect.New(reflect.ValueOf(baseDataType.Value).Elem().Type()).Interface().(pgtype.Value), 532 Name: d.name.String, 533 OID: d.oid, 534 }) 535 } 536 } 537 538 return nil 539} 540 541// crateDBTypesQuery checks if the given err is likely to be the result of 542// CrateDB not implementing the pg_types table correctly. If yes, a CrateDB 543// specific query against pg_types is executed and its results are returned. If 544// not, the original error is returned. 545func (c *Conn) crateDBTypesQuery(err error) (*pgtype.ConnInfo, error) { 546 // CrateDB 2.1.6 is a database that implements the PostgreSQL wire protocol, 547 // but not perfectly. In particular, the pg_catalog schema containing the 548 // pg_type table is not visible by default and the pg_type.typtype column is 549 // not implemented. Therefor the query above currently returns the following 550 // error: 551 // 552 // pgx.PgError{Severity:"ERROR", Code:"XX000", 553 // Message:"TableUnknownException: Table 'test.pg_type' unknown", 554 // Detail:"", Hint:"", Position:0, InternalPosition:0, InternalQuery:"", 555 // Where:"", SchemaName:"", TableName:"", ColumnName:"", DataTypeName:"", 556 // ConstraintName:"", File:"Schemas.java", Line:99, Routine:"getTableInfo"} 557 // 558 // If CrateDB was to fix the pg_type table visbility in the future, we'd 559 // still get this error until typtype column is implemented: 560 // 561 // pgx.PgError{Severity:"ERROR", Code:"XX000", 562 // Message:"ColumnUnknownException: Column typtype unknown", Detail:"", 563 // Hint:"", Position:0, InternalPosition:0, InternalQuery:"", Where:"", 564 // SchemaName:"", TableName:"", ColumnName:"", DataTypeName:"", 565 // ConstraintName:"", File:"FullQualifiedNameFieldProvider.java", Line:132, 566 // 567 // Additionally CrateDB doesn't implement Postgres error codes [2], and 568 // instead always returns "XX000" (internal_error). The code below uses all 569 // of this knowledge as a heuristic to detect CrateDB. If CrateDB is 570 // detected, a CrateDB specific pg_type query is executed instead. 571 // 572 // The heuristic is designed to still work even if CrateDB fixes [2] or 573 // renames its internal exception names. If both are changed but pg_types 574 // isn't fixed, this code will need to be changed. 575 // 576 // There is also a small chance the heuristic will yield a false positive for 577 // non-CrateDB databases (e.g. if a real Postgres instance returns a XX000 578 // error), but hopefully there will be no harm in attempting the alternative 579 // query in this case. 580 // 581 // CrateDB also uses the type varchar for the typname column which required 582 // adding varchar to the minimalConnInfo init code. 583 // 584 // Also see the discussion here [3]. 585 // 586 // [1] https://crate.io/ 587 // [2] https://github.com/crate/crate/issues/5027 588 // [3] https://github.com/jackc/pgx/issues/320 589 590 if pgErr, ok := err.(PgError); ok && 591 (pgErr.Code == "XX000" || 592 strings.Contains(pgErr.Message, "TableUnknownException") || 593 strings.Contains(pgErr.Message, "ColumnUnknownException")) { 594 var ( 595 nameOIDs map[string]pgtype.OID 596 ) 597 598 if nameOIDs, err = connInfoFromRows(c.Query(`select oid, typname from pg_catalog.pg_type`)); err != nil { 599 return nil, err 600 } 601 602 cinfo := pgtype.NewConnInfo() 603 cinfo.InitializeDataTypes(nameOIDs) 604 605 return cinfo, err 606 } 607 608 return nil, err 609} 610 611// PID returns the backend PID for this connection. 612func (c *Conn) PID() uint32 { 613 return c.pid 614} 615 616// LocalAddr returns the underlying connection's local address 617func (c *Conn) LocalAddr() (net.Addr, error) { 618 if !c.IsAlive() { 619 return nil, errors.New("connection not ready") 620 } 621 return c.conn.LocalAddr(), nil 622} 623 624// Close closes a connection. It is safe to call Close on a already closed 625// connection. 626func (c *Conn) Close() (err error) { 627 c.mux.Lock() 628 defer c.mux.Unlock() 629 630 if c.status < connStatusIdle { 631 return nil 632 } 633 c.status = connStatusClosed 634 635 defer func() { 636 c.conn.Close() 637 c.causeOfDeath = errors.New("Closed") 638 if c.shouldLog(LogLevelInfo) { 639 c.log(LogLevelInfo, "closed connection", nil) 640 } 641 }() 642 643 err = c.conn.SetDeadline(time.Time{}) 644 if err != nil && c.shouldLog(LogLevelWarn) { 645 c.log(LogLevelWarn, "failed to clear deadlines to send close message", map[string]interface{}{"err": err}) 646 return err 647 } 648 649 _, err = c.conn.Write([]byte{'X', 0, 0, 0, 4}) 650 if err != nil && c.shouldLog(LogLevelWarn) { 651 c.log(LogLevelWarn, "failed to send terminate message", map[string]interface{}{"err": err}) 652 return err 653 } 654 655 err = c.conn.SetReadDeadline(time.Now().Add(5 * time.Second)) 656 if err != nil && c.shouldLog(LogLevelWarn) { 657 c.log(LogLevelWarn, "failed to set read deadline to finish closing", map[string]interface{}{"err": err}) 658 return err 659 } 660 661 _, err = c.conn.Read(make([]byte, 1)) 662 if err != io.EOF { 663 return err 664 } 665 666 return nil 667} 668 669// Merge returns a new ConnConfig with the attributes of old and other 670// combined. When an attribute is set on both, other takes precedence. 671// 672// As a security precaution, if the other TLSConfig is nil, all old TLS 673// attributes will be preserved. 674func (old ConnConfig) Merge(other ConnConfig) ConnConfig { 675 cc := old 676 677 if other.Host != "" { 678 cc.Host = other.Host 679 } 680 if other.Port != 0 { 681 cc.Port = other.Port 682 } 683 if other.Database != "" { 684 cc.Database = other.Database 685 } 686 if other.User != "" { 687 cc.User = other.User 688 } 689 if other.Password != "" { 690 cc.Password = other.Password 691 } 692 693 if other.TLSConfig != nil { 694 cc.TLSConfig = other.TLSConfig 695 cc.UseFallbackTLS = other.UseFallbackTLS 696 cc.FallbackTLSConfig = other.FallbackTLSConfig 697 } 698 699 if other.Logger != nil { 700 cc.Logger = other.Logger 701 } 702 if other.LogLevel != 0 { 703 cc.LogLevel = other.LogLevel 704 } 705 706 if other.Dial != nil { 707 cc.Dial = other.Dial 708 } 709 710 cc.PreferSimpleProtocol = other.PreferSimpleProtocol 711 712 cc.RuntimeParams = make(map[string]string) 713 for k, v := range old.RuntimeParams { 714 cc.RuntimeParams[k] = v 715 } 716 for k, v := range other.RuntimeParams { 717 cc.RuntimeParams[k] = v 718 } 719 720 return cc 721} 722 723// ParseURI parses a database URI into ConnConfig 724// 725// Query parameters not used by the connection process are parsed into ConnConfig.RuntimeParams. 726func ParseURI(uri string) (ConnConfig, error) { 727 var cp ConnConfig 728 729 url, err := url.Parse(uri) 730 if err != nil { 731 return cp, err 732 } 733 734 if url.User != nil { 735 cp.User = url.User.Username() 736 cp.Password, _ = url.User.Password() 737 } 738 739 parts := strings.SplitN(url.Host, ":", 2) 740 cp.Host = parts[0] 741 if len(parts) == 2 { 742 p, err := strconv.ParseUint(parts[1], 10, 16) 743 if err != nil { 744 return cp, err 745 } 746 cp.Port = uint16(p) 747 } 748 cp.Database = strings.TrimLeft(url.Path, "/") 749 750 if pgtimeout := url.Query().Get("connect_timeout"); pgtimeout != "" { 751 timeout, err := strconv.ParseInt(pgtimeout, 10, 64) 752 if err != nil { 753 return cp, err 754 } 755 d := defaultDialer() 756 d.Timeout = time.Duration(timeout) * time.Second 757 cp.Dial = d.Dial 758 } 759 760 tlsArgs := configTLSArgs{ 761 sslCert: url.Query().Get("sslcert"), 762 sslKey: url.Query().Get("sslkey"), 763 sslMode: url.Query().Get("sslmode"), 764 sslRootCert: url.Query().Get("sslrootcert"), 765 } 766 err = configTLS(tlsArgs, &cp) 767 if err != nil { 768 return cp, err 769 } 770 771 ignoreKeys := map[string]struct{}{ 772 "connect_timeout": {}, 773 "sslcert": {}, 774 "sslkey": {}, 775 "sslmode": {}, 776 "sslrootcert": {}, 777 } 778 779 cp.RuntimeParams = make(map[string]string) 780 781 for k, v := range url.Query() { 782 if _, ok := ignoreKeys[k]; ok { 783 continue 784 } 785 786 if k == "host" { 787 cp.Host = v[0] 788 continue 789 } 790 791 cp.RuntimeParams[k] = v[0] 792 } 793 if cp.Password == "" { 794 pgpass(&cp) 795 } 796 return cp, nil 797} 798 799var dsnRegexp = regexp.MustCompile(`([a-zA-Z_]+)=((?:"[^"]+")|(?:[^ ]+))`) 800 801// ParseDSN parses a database DSN (data source name) into a ConnConfig 802// 803// e.g. ParseDSN("user=username password=password host=1.2.3.4 port=5432 dbname=mydb sslmode=disable") 804// 805// Any options not used by the connection process are parsed into ConnConfig.RuntimeParams. 806// 807// e.g. ParseDSN("application_name=pgxtest search_path=admin user=username password=password host=1.2.3.4 dbname=mydb") 808// 809// ParseDSN tries to match libpq behavior with regard to sslmode. See comments 810// for ParseEnvLibpq for more information on the security implications of 811// sslmode options. 812func ParseDSN(s string) (ConnConfig, error) { 813 var cp ConnConfig 814 815 m := dsnRegexp.FindAllStringSubmatch(s, -1) 816 817 tlsArgs := configTLSArgs{} 818 819 cp.RuntimeParams = make(map[string]string) 820 821 for _, b := range m { 822 switch b[1] { 823 case "user": 824 cp.User = b[2] 825 case "password": 826 cp.Password = b[2] 827 case "host": 828 cp.Host = b[2] 829 case "port": 830 p, err := strconv.ParseUint(b[2], 10, 16) 831 if err != nil { 832 return cp, err 833 } 834 cp.Port = uint16(p) 835 case "dbname": 836 cp.Database = b[2] 837 case "sslmode": 838 tlsArgs.sslMode = b[2] 839 case "sslrootcert": 840 tlsArgs.sslRootCert = b[2] 841 case "sslcert": 842 tlsArgs.sslCert = b[2] 843 case "sslkey": 844 tlsArgs.sslKey = b[2] 845 case "connect_timeout": 846 timeout, err := strconv.ParseInt(b[2], 10, 64) 847 if err != nil { 848 return cp, err 849 } 850 d := defaultDialer() 851 d.Timeout = time.Duration(timeout) * time.Second 852 cp.Dial = d.Dial 853 default: 854 cp.RuntimeParams[b[1]] = b[2] 855 } 856 } 857 858 err := configTLS(tlsArgs, &cp) 859 if err != nil { 860 return cp, err 861 } 862 if cp.Password == "" { 863 pgpass(&cp) 864 } 865 return cp, nil 866} 867 868// ParseConnectionString parses either a URI or a DSN connection string. 869// see ParseURI and ParseDSN for details. 870func ParseConnectionString(s string) (ConnConfig, error) { 871 if u, err := url.Parse(s); err == nil && u.Scheme != "" { 872 return ParseURI(s) 873 } 874 return ParseDSN(s) 875} 876 877// ParseEnvLibpq parses the environment like libpq does into a ConnConfig 878// 879// See http://www.postgresql.org/docs/9.4/static/libpq-envars.html for details 880// on the meaning of environment variables. 881// 882// ParseEnvLibpq currently recognizes the following environment variables: 883// PGHOST 884// PGPORT 885// PGDATABASE 886// PGUSER 887// PGPASSWORD 888// PGSSLMODE 889// PGSSLCERT 890// PGSSLKEY 891// PGSSLROOTCERT 892// PGAPPNAME 893// PGCONNECT_TIMEOUT 894// 895// Important TLS Security Notes: 896// ParseEnvLibpq tries to match libpq behavior with regard to PGSSLMODE. This 897// includes defaulting to "prefer" behavior if no environment variable is set. 898// 899// See http://www.postgresql.org/docs/9.4/static/libpq-ssl.html#LIBPQ-SSL-PROTECTION 900// for details on what level of security each sslmode provides. 901// 902// "verify-ca" mode currently is treated as "verify-full". e.g. It has stronger 903// security guarantees than it would with libpq. Do not rely on this behavior as it 904// may be possible to match libpq in the future. If you need full security use 905// "verify-full". 906// 907// Several of the PGSSLMODE options (including the default behavior of "prefer") 908// will set UseFallbackTLS to true and FallbackTLSConfig to a disabled or 909// weakened TLS mode. This means that if ParseEnvLibpq is used, but TLSConfig is 910// later set from a different source that UseFallbackTLS MUST be set false to 911// avoid the possibility of falling back to weaker or disabled security. 912func ParseEnvLibpq() (ConnConfig, error) { 913 var cc ConnConfig 914 915 cc.Host = os.Getenv("PGHOST") 916 917 if pgport := os.Getenv("PGPORT"); pgport != "" { 918 if port, err := strconv.ParseUint(pgport, 10, 16); err == nil { 919 cc.Port = uint16(port) 920 } else { 921 return cc, err 922 } 923 } 924 925 cc.Database = os.Getenv("PGDATABASE") 926 cc.User = os.Getenv("PGUSER") 927 cc.Password = os.Getenv("PGPASSWORD") 928 929 if pgtimeout := os.Getenv("PGCONNECT_TIMEOUT"); pgtimeout != "" { 930 if timeout, err := strconv.ParseInt(pgtimeout, 10, 64); err == nil { 931 d := defaultDialer() 932 d.Timeout = time.Duration(timeout) * time.Second 933 cc.Dial = d.Dial 934 } else { 935 return cc, err 936 } 937 } 938 939 tlsArgs := configTLSArgs{ 940 sslMode: os.Getenv("PGSSLMODE"), 941 sslKey: os.Getenv("PGSSLKEY"), 942 sslCert: os.Getenv("PGSSLCERT"), 943 sslRootCert: os.Getenv("PGSSLROOTCERT"), 944 } 945 946 err := configTLS(tlsArgs, &cc) 947 if err != nil { 948 return cc, err 949 } 950 951 cc.RuntimeParams = make(map[string]string) 952 if appname := os.Getenv("PGAPPNAME"); appname != "" { 953 cc.RuntimeParams["application_name"] = appname 954 } 955 if cc.Password == "" { 956 pgpass(&cc) 957 } 958 return cc, nil 959} 960 961type configTLSArgs struct { 962 sslMode string 963 sslRootCert string 964 sslCert string 965 sslKey string 966} 967 968// configTLS uses lib/pq's TLS parameters to reconstruct a coherent tls.Config. 969// Inputs are parsed out and provided by ParseDSN() or ParseURI(). 970func configTLS(args configTLSArgs, cc *ConnConfig) error { 971 // Match libpq default behavior 972 if args.sslMode == "" { 973 args.sslMode = "prefer" 974 } 975 976 switch args.sslMode { 977 case "disable": 978 cc.UseFallbackTLS = false 979 cc.TLSConfig = nil 980 cc.FallbackTLSConfig = nil 981 return nil 982 case "allow": 983 cc.UseFallbackTLS = true 984 cc.FallbackTLSConfig = &tls.Config{InsecureSkipVerify: true} 985 case "prefer": 986 cc.TLSConfig = &tls.Config{InsecureSkipVerify: true} 987 cc.UseFallbackTLS = true 988 cc.FallbackTLSConfig = nil 989 case "require": 990 cc.TLSConfig = &tls.Config{InsecureSkipVerify: true} 991 case "verify-ca", "verify-full": 992 cc.TLSConfig = &tls.Config{ 993 ServerName: cc.Host, 994 } 995 default: 996 return errors.New("sslmode is invalid") 997 } 998 999 if args.sslRootCert != "" { 1000 caCertPool := x509.NewCertPool() 1001 1002 caPath := args.sslRootCert 1003 caCert, err := ioutil.ReadFile(caPath) 1004 if err != nil { 1005 return errors.Wrapf(err, "unable to read CA file %q", caPath) 1006 } 1007 1008 if !caCertPool.AppendCertsFromPEM(caCert) { 1009 return errors.Wrap(err, "unable to add CA to cert pool") 1010 } 1011 1012 cc.TLSConfig.RootCAs = caCertPool 1013 cc.TLSConfig.ClientCAs = caCertPool 1014 } 1015 1016 sslcert := args.sslCert 1017 sslkey := args.sslKey 1018 1019 if (sslcert != "" && sslkey == "") || (sslcert == "" && sslkey != "") { 1020 return fmt.Errorf(`both "sslcert" and "sslkey" are required`) 1021 } 1022 1023 if sslcert != "" && sslkey != "" { 1024 cert, err := tls.LoadX509KeyPair(sslcert, sslkey) 1025 if err != nil { 1026 return errors.Wrap(err, "unable to read cert") 1027 } 1028 1029 cc.TLSConfig.Certificates = []tls.Certificate{cert} 1030 } 1031 1032 return nil 1033} 1034 1035// Prepare creates a prepared statement with name and sql. sql can contain placeholders 1036// for bound parameters. These placeholders are referenced positional as $1, $2, etc. 1037// 1038// Prepare is idempotent; i.e. it is safe to call Prepare multiple times with the same 1039// name and sql arguments. This allows a code path to Prepare and Query/Exec without 1040// concern for if the statement has already been prepared. 1041func (c *Conn) Prepare(name, sql string) (ps *PreparedStatement, err error) { 1042 return c.PrepareEx(context.Background(), name, sql, nil) 1043} 1044 1045// PrepareEx creates a prepared statement with name and sql. sql can contain placeholders 1046// for bound parameters. These placeholders are referenced positional as $1, $2, etc. 1047// It differs from Prepare as it allows additional options (such as parameter OIDs) to be passed via struct 1048// 1049// PrepareEx is idempotent; i.e. it is safe to call PrepareEx multiple times with the same 1050// name and sql arguments. This allows a code path to PrepareEx and Query/Exec without 1051// concern for if the statement has already been prepared. 1052func (c *Conn) PrepareEx(ctx context.Context, name, sql string, opts *PrepareExOptions) (ps *PreparedStatement, err error) { 1053 err = c.waitForPreviousCancelQuery(ctx) 1054 if err != nil { 1055 return nil, err 1056 } 1057 1058 err = c.initContext(ctx) 1059 if err != nil { 1060 return nil, err 1061 } 1062 1063 ps, err = c.prepareEx(name, sql, opts) 1064 err = c.termContext(err) 1065 return ps, err 1066} 1067 1068func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *PreparedStatement, err error) { 1069 if name != "" { 1070 if ps, ok := c.preparedStatements[name]; ok && ps.SQL == sql { 1071 return ps, nil 1072 } 1073 } 1074 1075 if err := c.ensureConnectionReadyForQuery(); err != nil { 1076 return nil, err 1077 } 1078 1079 if c.shouldLog(LogLevelError) { 1080 defer func() { 1081 if err != nil { 1082 c.log(LogLevelError, "prepareEx failed", map[string]interface{}{"err": err, "name": name, "sql": sql}) 1083 } 1084 }() 1085 } 1086 1087 if opts == nil { 1088 opts = &PrepareExOptions{} 1089 } 1090 1091 if len(opts.ParameterOIDs) > 65535 { 1092 return nil, errors.Errorf("Number of PrepareExOptions ParameterOIDs must be between 0 and 65535, received %d", len(opts.ParameterOIDs)) 1093 } 1094 1095 buf := appendParse(c.wbuf, name, sql, opts.ParameterOIDs) 1096 buf = appendDescribe(buf, 'S', name) 1097 buf = appendSync(buf) 1098 1099 n, err := c.conn.Write(buf) 1100 if err != nil { 1101 if fatalWriteErr(n, err) { 1102 c.die(err) 1103 } 1104 return nil, err 1105 } 1106 c.pendingReadyForQueryCount++ 1107 1108 ps = &PreparedStatement{Name: name, SQL: sql} 1109 1110 var softErr error 1111 1112 for { 1113 msg, err := c.rxMsg() 1114 if err != nil { 1115 return nil, err 1116 } 1117 1118 switch msg := msg.(type) { 1119 case *pgproto3.ParameterDescription: 1120 ps.ParameterOIDs = c.rxParameterDescription(msg) 1121 1122 if len(ps.ParameterOIDs) > 65535 && softErr == nil { 1123 softErr = errors.Errorf("PostgreSQL supports maximum of 65535 parameters, received %d", len(ps.ParameterOIDs)) 1124 } 1125 case *pgproto3.RowDescription: 1126 ps.FieldDescriptions = c.rxRowDescription(msg) 1127 for i := range ps.FieldDescriptions { 1128 if dt, ok := c.ConnInfo.DataTypeForOID(ps.FieldDescriptions[i].DataType); ok { 1129 ps.FieldDescriptions[i].DataTypeName = dt.Name 1130 if _, ok := dt.Value.(pgtype.BinaryDecoder); ok { 1131 ps.FieldDescriptions[i].FormatCode = BinaryFormatCode 1132 } else { 1133 ps.FieldDescriptions[i].FormatCode = TextFormatCode 1134 } 1135 } else { 1136 return nil, errors.Errorf("unknown oid: %d", ps.FieldDescriptions[i].DataType) 1137 } 1138 } 1139 case *pgproto3.ReadyForQuery: 1140 c.rxReadyForQuery(msg) 1141 1142 if softErr == nil { 1143 c.preparedStatements[name] = ps 1144 } 1145 1146 return ps, softErr 1147 default: 1148 if e := c.processContextFreeMsg(msg); e != nil && softErr == nil { 1149 softErr = e 1150 } 1151 } 1152 } 1153} 1154 1155// Deallocate released a prepared statement 1156func (c *Conn) Deallocate(name string) error { 1157 return c.deallocateContext(context.Background(), name) 1158} 1159 1160// TODO - consider making this public 1161func (c *Conn) deallocateContext(ctx context.Context, name string) (err error) { 1162 err = c.waitForPreviousCancelQuery(ctx) 1163 if err != nil { 1164 return err 1165 } 1166 1167 err = c.initContext(ctx) 1168 if err != nil { 1169 return err 1170 } 1171 defer func() { 1172 err = c.termContext(err) 1173 }() 1174 1175 if err := c.ensureConnectionReadyForQuery(); err != nil { 1176 return err 1177 } 1178 1179 delete(c.preparedStatements, name) 1180 1181 // close 1182 buf := c.wbuf 1183 buf = append(buf, 'C') 1184 sp := len(buf) 1185 buf = pgio.AppendInt32(buf, -1) 1186 buf = append(buf, 'S') 1187 buf = append(buf, name...) 1188 buf = append(buf, 0) 1189 pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) 1190 1191 // flush 1192 buf = append(buf, 'H') 1193 buf = pgio.AppendInt32(buf, 4) 1194 1195 _, err = c.conn.Write(buf) 1196 if err != nil { 1197 c.die(err) 1198 return err 1199 } 1200 1201 for { 1202 msg, err := c.rxMsg() 1203 if err != nil { 1204 return err 1205 } 1206 1207 switch msg.(type) { 1208 case *pgproto3.CloseComplete: 1209 return nil 1210 default: 1211 err = c.processContextFreeMsg(msg) 1212 if err != nil { 1213 return err 1214 } 1215 } 1216 } 1217} 1218 1219// Listen establishes a PostgreSQL listen/notify to channel 1220func (c *Conn) Listen(channel string) error { 1221 _, err := c.Exec("listen " + quoteIdentifier(channel)) 1222 if err != nil { 1223 return err 1224 } 1225 1226 c.channels[channel] = struct{}{} 1227 1228 return nil 1229} 1230 1231// Unlisten unsubscribes from a listen channel 1232func (c *Conn) Unlisten(channel string) error { 1233 _, err := c.Exec("unlisten " + quoteIdentifier(channel)) 1234 if err != nil { 1235 return err 1236 } 1237 1238 delete(c.channels, channel) 1239 return nil 1240} 1241 1242// WaitForNotification waits for a PostgreSQL notification. 1243func (c *Conn) WaitForNotification(ctx context.Context) (notification *Notification, err error) { 1244 // Return already received notification immediately 1245 if len(c.notifications) > 0 { 1246 notification := c.notifications[0] 1247 c.notifications = c.notifications[1:] 1248 return notification, nil 1249 } 1250 1251 err = c.waitForPreviousCancelQuery(ctx) 1252 if err != nil { 1253 return nil, err 1254 } 1255 1256 err = c.initContext(ctx) 1257 if err != nil { 1258 return nil, err 1259 } 1260 defer func() { 1261 err = c.termContext(err) 1262 }() 1263 1264 if err = c.lock(); err != nil { 1265 return nil, err 1266 } 1267 defer func() { 1268 if unlockErr := c.unlock(); unlockErr != nil && err == nil { 1269 err = unlockErr 1270 } 1271 }() 1272 1273 if err := c.ensureConnectionReadyForQuery(); err != nil { 1274 return nil, err 1275 } 1276 1277 for { 1278 msg, err := c.rxMsg() 1279 if err != nil { 1280 return nil, err 1281 } 1282 1283 err = c.processContextFreeMsg(msg) 1284 if err != nil { 1285 return nil, err 1286 } 1287 1288 if len(c.notifications) > 0 { 1289 notification := c.notifications[0] 1290 c.notifications = c.notifications[1:] 1291 return notification, nil 1292 } 1293 } 1294} 1295 1296func (c *Conn) IsAlive() bool { 1297 c.mux.Lock() 1298 defer c.mux.Unlock() 1299 return c.status >= connStatusIdle 1300} 1301 1302func (c *Conn) CauseOfDeath() error { 1303 c.mux.Lock() 1304 defer c.mux.Unlock() 1305 return c.causeOfDeath 1306} 1307 1308func (c *Conn) sendQuery(sql string, arguments ...interface{}) (err error) { 1309 if ps, present := c.preparedStatements[sql]; present { 1310 return c.sendPreparedQuery(ps, arguments...) 1311 } 1312 return c.sendSimpleQuery(sql, arguments...) 1313} 1314 1315func (c *Conn) sendSimpleQuery(sql string, args ...interface{}) error { 1316 if err := c.ensureConnectionReadyForQuery(); err != nil { 1317 return err 1318 } 1319 1320 if len(args) == 0 { 1321 buf := appendQuery(c.wbuf, sql) 1322 1323 _, err := c.conn.Write(buf) 1324 if err != nil { 1325 c.die(err) 1326 return err 1327 } 1328 c.pendingReadyForQueryCount++ 1329 1330 return nil 1331 } 1332 1333 ps, err := c.Prepare("", sql) 1334 if err != nil { 1335 return err 1336 } 1337 1338 return c.sendPreparedQuery(ps, args...) 1339} 1340 1341func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}) (err error) { 1342 if len(ps.ParameterOIDs) != len(arguments) { 1343 return errors.Errorf("Prepared statement \"%v\" requires %d parameters, but %d were provided", ps.Name, len(ps.ParameterOIDs), len(arguments)) 1344 } 1345 1346 if err := c.ensureConnectionReadyForQuery(); err != nil { 1347 return err 1348 } 1349 1350 resultFormatCodes := make([]int16, len(ps.FieldDescriptions)) 1351 for i, fd := range ps.FieldDescriptions { 1352 resultFormatCodes[i] = fd.FormatCode 1353 } 1354 buf, err := appendBind(c.wbuf, "", ps.Name, c.ConnInfo, ps.ParameterOIDs, arguments, resultFormatCodes) 1355 if err != nil { 1356 return err 1357 } 1358 1359 buf = appendExecute(buf, "", 0) 1360 buf = appendSync(buf) 1361 1362 n, err := c.conn.Write(buf) 1363 if err != nil { 1364 if fatalWriteErr(n, err) { 1365 c.die(err) 1366 } 1367 return err 1368 } 1369 c.pendingReadyForQueryCount++ 1370 1371 return nil 1372} 1373 1374// fatalWriteError takes the response of a net.Conn.Write and determines if it is fatal 1375func fatalWriteErr(bytesWritten int, err error) bool { 1376 // Partial writes break the connection 1377 if bytesWritten > 0 { 1378 return true 1379 } 1380 1381 netErr, is := err.(net.Error) 1382 return !(is && netErr.Timeout()) 1383} 1384 1385// Exec executes sql. sql can be either a prepared statement name or an SQL string. 1386// arguments should be referenced positionally from the sql string as $1, $2, etc. 1387func (c *Conn) Exec(sql string, arguments ...interface{}) (commandTag CommandTag, err error) { 1388 return c.ExecEx(context.Background(), sql, nil, arguments...) 1389} 1390 1391// Processes messages that are not exclusive to one context such as 1392// authentication or query response. The response to these messages is the same 1393// regardless of when they occur. It also ignores messages that are only 1394// meaningful in a given context. These messages can occur due to a context 1395// deadline interrupting message processing. For example, an interrupted query 1396// may have left DataRow messages on the wire. 1397func (c *Conn) processContextFreeMsg(msg pgproto3.BackendMessage) (err error) { 1398 switch msg := msg.(type) { 1399 case *pgproto3.ErrorResponse: 1400 return c.rxErrorResponse(msg) 1401 case *pgproto3.NoticeResponse: 1402 c.rxNoticeResponse(msg) 1403 case *pgproto3.NotificationResponse: 1404 c.rxNotificationResponse(msg) 1405 case *pgproto3.ReadyForQuery: 1406 c.rxReadyForQuery(msg) 1407 case *pgproto3.ParameterStatus: 1408 c.rxParameterStatus(msg) 1409 } 1410 1411 return nil 1412} 1413 1414func (c *Conn) rxMsg() (pgproto3.BackendMessage, error) { 1415 if !c.IsAlive() { 1416 return nil, ErrDeadConn 1417 } 1418 1419 msg, err := c.frontend.Receive() 1420 if err != nil { 1421 if netErr, ok := err.(net.Error); !(ok && netErr.Timeout()) { 1422 c.die(err) 1423 } 1424 return nil, err 1425 } 1426 1427 c.lastActivityTime = time.Now() 1428 1429 // fmt.Printf("rxMsg: %#v\n", msg) 1430 1431 return msg, nil 1432} 1433 1434func (c *Conn) rxAuthenticationX(msg *pgproto3.Authentication) (err error) { 1435 switch msg.Type { 1436 case pgproto3.AuthTypeOk: 1437 case pgproto3.AuthTypeCleartextPassword: 1438 err = c.txPasswordMessage(c.config.Password) 1439 case pgproto3.AuthTypeMD5Password: 1440 digestedPassword := "md5" + hexMD5(hexMD5(c.config.Password+c.config.User)+string(msg.Salt[:])) 1441 err = c.txPasswordMessage(digestedPassword) 1442 default: 1443 err = errors.New("Received unknown authentication message") 1444 } 1445 1446 return 1447} 1448 1449func hexMD5(s string) string { 1450 hash := md5.New() 1451 io.WriteString(hash, s) 1452 return hex.EncodeToString(hash.Sum(nil)) 1453} 1454 1455func (c *Conn) rxParameterStatus(msg *pgproto3.ParameterStatus) { 1456 c.RuntimeParams[msg.Name] = msg.Value 1457} 1458 1459func (c *Conn) rxErrorResponse(msg *pgproto3.ErrorResponse) PgError { 1460 err := PgError{ 1461 Severity: msg.Severity, 1462 Code: msg.Code, 1463 Message: msg.Message, 1464 Detail: msg.Detail, 1465 Hint: msg.Hint, 1466 Position: msg.Position, 1467 InternalPosition: msg.InternalPosition, 1468 InternalQuery: msg.InternalQuery, 1469 Where: msg.Where, 1470 SchemaName: msg.SchemaName, 1471 TableName: msg.TableName, 1472 ColumnName: msg.ColumnName, 1473 DataTypeName: msg.DataTypeName, 1474 ConstraintName: msg.ConstraintName, 1475 File: msg.File, 1476 Line: msg.Line, 1477 Routine: msg.Routine, 1478 } 1479 1480 if err.Severity == "FATAL" { 1481 c.die(err) 1482 } 1483 1484 return err 1485} 1486 1487func (c *Conn) rxNoticeResponse(msg *pgproto3.NoticeResponse) { 1488 if c.onNotice == nil { 1489 return 1490 } 1491 1492 notice := &Notice{ 1493 Severity: msg.Severity, 1494 Code: msg.Code, 1495 Message: msg.Message, 1496 Detail: msg.Detail, 1497 Hint: msg.Hint, 1498 Position: msg.Position, 1499 InternalPosition: msg.InternalPosition, 1500 InternalQuery: msg.InternalQuery, 1501 Where: msg.Where, 1502 SchemaName: msg.SchemaName, 1503 TableName: msg.TableName, 1504 ColumnName: msg.ColumnName, 1505 DataTypeName: msg.DataTypeName, 1506 ConstraintName: msg.ConstraintName, 1507 File: msg.File, 1508 Line: msg.Line, 1509 Routine: msg.Routine, 1510 } 1511 1512 c.onNotice(c, notice) 1513} 1514 1515func (c *Conn) rxBackendKeyData(msg *pgproto3.BackendKeyData) { 1516 c.pid = msg.ProcessID 1517 c.secretKey = msg.SecretKey 1518} 1519 1520func (c *Conn) rxReadyForQuery(msg *pgproto3.ReadyForQuery) { 1521 c.pendingReadyForQueryCount-- 1522 c.txStatus = msg.TxStatus 1523} 1524 1525func (c *Conn) rxRowDescription(msg *pgproto3.RowDescription) []FieldDescription { 1526 fields := make([]FieldDescription, len(msg.Fields)) 1527 for i := 0; i < len(fields); i++ { 1528 fields[i].Name = msg.Fields[i].Name 1529 fields[i].Table = pgtype.OID(msg.Fields[i].TableOID) 1530 fields[i].AttributeNumber = msg.Fields[i].TableAttributeNumber 1531 fields[i].DataType = pgtype.OID(msg.Fields[i].DataTypeOID) 1532 fields[i].DataTypeSize = msg.Fields[i].DataTypeSize 1533 fields[i].Modifier = msg.Fields[i].TypeModifier 1534 fields[i].FormatCode = msg.Fields[i].Format 1535 } 1536 return fields 1537} 1538 1539func (c *Conn) rxParameterDescription(msg *pgproto3.ParameterDescription) []pgtype.OID { 1540 parameters := make([]pgtype.OID, len(msg.ParameterOIDs)) 1541 for i := 0; i < len(parameters); i++ { 1542 parameters[i] = pgtype.OID(msg.ParameterOIDs[i]) 1543 } 1544 return parameters 1545} 1546 1547func (c *Conn) rxNotificationResponse(msg *pgproto3.NotificationResponse) { 1548 n := new(Notification) 1549 n.PID = msg.PID 1550 n.Channel = msg.Channel 1551 n.Payload = msg.Payload 1552 c.notifications = append(c.notifications, n) 1553} 1554 1555func (c *Conn) startTLS(tlsConfig *tls.Config) (err error) { 1556 err = binary.Write(c.conn, binary.BigEndian, []int32{8, 80877103}) 1557 if err != nil { 1558 return 1559 } 1560 1561 response := make([]byte, 1) 1562 if _, err = io.ReadFull(c.conn, response); err != nil { 1563 return 1564 } 1565 1566 if response[0] != 'S' { 1567 return ErrTLSRefused 1568 } 1569 1570 c.conn = tls.Client(c.conn, tlsConfig) 1571 1572 return nil 1573} 1574 1575func (c *Conn) txPasswordMessage(password string) (err error) { 1576 buf := c.wbuf 1577 buf = append(buf, 'p') 1578 sp := len(buf) 1579 buf = pgio.AppendInt32(buf, -1) 1580 buf = append(buf, password...) 1581 buf = append(buf, 0) 1582 pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) 1583 1584 _, err = c.conn.Write(buf) 1585 1586 return err 1587} 1588 1589func (c *Conn) die(err error) { 1590 c.mux.Lock() 1591 defer c.mux.Unlock() 1592 1593 if c.status == connStatusClosed { 1594 return 1595 } 1596 1597 c.status = connStatusClosed 1598 c.causeOfDeath = err 1599 c.conn.Close() 1600} 1601 1602func (c *Conn) lock() error { 1603 c.mux.Lock() 1604 defer c.mux.Unlock() 1605 1606 if c.status != connStatusIdle { 1607 return ErrConnBusy 1608 } 1609 1610 c.status = connStatusBusy 1611 return nil 1612} 1613 1614func (c *Conn) unlock() error { 1615 c.mux.Lock() 1616 defer c.mux.Unlock() 1617 1618 if c.status != connStatusBusy { 1619 return errors.New("unlock conn that is not busy") 1620 } 1621 1622 c.status = connStatusIdle 1623 return nil 1624} 1625 1626func (c *Conn) shouldLog(lvl int) bool { 1627 return c.logger != nil && c.logLevel >= lvl 1628} 1629 1630func (c *Conn) log(lvl LogLevel, msg string, data map[string]interface{}) { 1631 if data == nil { 1632 data = map[string]interface{}{} 1633 } 1634 if c.pid != 0 { 1635 data["pid"] = c.pid 1636 } 1637 1638 c.logger.Log(lvl, msg, data) 1639} 1640 1641// SetLogger replaces the current logger and returns the previous logger. 1642func (c *Conn) SetLogger(logger Logger) Logger { 1643 oldLogger := c.logger 1644 c.logger = logger 1645 return oldLogger 1646} 1647 1648// SetLogLevel replaces the current log level and returns the previous log 1649// level. 1650func (c *Conn) SetLogLevel(lvl int) (int, error) { 1651 oldLvl := c.logLevel 1652 1653 if lvl < LogLevelNone || lvl > LogLevelTrace { 1654 return oldLvl, ErrInvalidLogLevel 1655 } 1656 1657 c.logLevel = lvl 1658 return lvl, nil 1659} 1660 1661func quoteIdentifier(s string) string { 1662 return `"` + strings.Replace(s, `"`, `""`, -1) + `"` 1663} 1664 1665func doCancel(c *Conn) error { 1666 network, address := c.config.networkAddress() 1667 cancelConn, err := c.config.Dial(network, address) 1668 if err != nil { 1669 return err 1670 } 1671 defer cancelConn.Close() 1672 1673 // If server doesn't process cancellation request in bounded time then abort. 1674 now := time.Now() 1675 err = cancelConn.SetDeadline(now.Add(15 * time.Second)) 1676 if err != nil { 1677 return err 1678 } 1679 1680 buf := make([]byte, 16) 1681 binary.BigEndian.PutUint32(buf[0:4], 16) 1682 binary.BigEndian.PutUint32(buf[4:8], 80877102) 1683 binary.BigEndian.PutUint32(buf[8:12], uint32(c.pid)) 1684 binary.BigEndian.PutUint32(buf[12:16], uint32(c.secretKey)) 1685 _, err = cancelConn.Write(buf) 1686 if err != nil { 1687 return err 1688 } 1689 1690 _, err = cancelConn.Read(buf) 1691 if err != io.EOF { 1692 return errors.Errorf("Server failed to close connection after cancel query request: %v %v", err, buf) 1693 } 1694 1695 return nil 1696} 1697 1698// cancelQuery sends a cancel request to the PostgreSQL server. It returns an 1699// error if unable to deliver the cancel request, but lack of an error does not 1700// ensure that the query was canceled. As specified in the documentation, there 1701// is no way to be sure a query was canceled. See 1702// https://www.postgresql.org/docs/current/static/protocol-flow.html#AEN112861 1703func (c *Conn) cancelQuery() { 1704 if err := c.conn.SetDeadline(time.Now()); err != nil { 1705 c.Close() // Close connection if unable to set deadline 1706 return 1707 } 1708 1709 var cancelFn func(*Conn) error 1710 completeCh := make(chan struct{}) 1711 c.mux.Lock() 1712 c.cancelQueryCompleted = completeCh 1713 c.mux.Unlock() 1714 if c.config.CustomCancel != nil { 1715 cancelFn = c.config.CustomCancel 1716 } else { 1717 cancelFn = doCancel 1718 } 1719 1720 go func() { 1721 defer close(completeCh) 1722 err := cancelFn(c) 1723 if err != nil { 1724 c.Close() // Something is very wrong. Terminate the connection. 1725 } 1726 }() 1727} 1728 1729func (c *Conn) Ping(ctx context.Context) error { 1730 _, err := c.ExecEx(ctx, ";", nil) 1731 return err 1732} 1733 1734func (c *Conn) ExecEx(ctx context.Context, sql string, options *QueryExOptions, arguments ...interface{}) (CommandTag, error) { 1735 c.lastStmtSent = false 1736 err := c.waitForPreviousCancelQuery(ctx) 1737 if err != nil { 1738 return "", err 1739 } 1740 1741 if err := c.lock(); err != nil { 1742 return "", err 1743 } 1744 defer c.unlock() 1745 1746 startTime := time.Now() 1747 c.lastActivityTime = startTime 1748 1749 commandTag, err := c.execEx(ctx, sql, options, arguments...) 1750 if err != nil { 1751 if c.shouldLog(LogLevelError) { 1752 c.log(LogLevelError, "Exec", map[string]interface{}{"sql": sql, "args": logQueryArgs(arguments), "err": err}) 1753 } 1754 return commandTag, err 1755 } 1756 1757 if c.shouldLog(LogLevelInfo) { 1758 endTime := time.Now() 1759 c.log(LogLevelInfo, "Exec", map[string]interface{}{"sql": sql, "args": logQueryArgs(arguments), "time": endTime.Sub(startTime), "commandTag": commandTag}) 1760 } 1761 1762 return commandTag, err 1763} 1764 1765func (c *Conn) execEx(ctx context.Context, sql string, options *QueryExOptions, arguments ...interface{}) (commandTag CommandTag, err error) { 1766 err = c.initContext(ctx) 1767 if err != nil { 1768 return "", err 1769 } 1770 defer func() { 1771 err = c.termContext(err) 1772 }() 1773 1774 if (options == nil && c.config.PreferSimpleProtocol) || (options != nil && options.SimpleProtocol) { 1775 c.lastStmtSent = true 1776 err = c.sanitizeAndSendSimpleQuery(sql, arguments...) 1777 if err != nil { 1778 return "", err 1779 } 1780 } else if options != nil && len(options.ParameterOIDs) > 0 { 1781 if err := c.ensureConnectionReadyForQuery(); err != nil { 1782 return "", err 1783 } 1784 1785 buf, err := c.buildOneRoundTripExec(c.wbuf, sql, options, arguments) 1786 if err != nil { 1787 return "", err 1788 } 1789 1790 buf = appendSync(buf) 1791 1792 c.lastStmtSent = true 1793 n, err := c.conn.Write(buf) 1794 if err != nil && fatalWriteErr(n, err) { 1795 c.die(err) 1796 return "", err 1797 } 1798 c.pendingReadyForQueryCount++ 1799 } else { 1800 if len(arguments) > 0 { 1801 ps, ok := c.preparedStatements[sql] 1802 if !ok { 1803 var err error 1804 ps, err = c.prepareEx("", sql, nil) 1805 if err != nil { 1806 return "", err 1807 } 1808 } 1809 1810 c.lastStmtSent = true 1811 err = c.sendPreparedQuery(ps, arguments...) 1812 if err != nil { 1813 return "", err 1814 } 1815 } else { 1816 c.lastStmtSent = true 1817 if err = c.sendQuery(sql, arguments...); err != nil { 1818 return 1819 } 1820 } 1821 } 1822 1823 var softErr error 1824 1825 for { 1826 msg, err := c.rxMsg() 1827 if err != nil { 1828 return commandTag, err 1829 } 1830 1831 switch msg := msg.(type) { 1832 case *pgproto3.ReadyForQuery: 1833 c.rxReadyForQuery(msg) 1834 return commandTag, softErr 1835 case *pgproto3.CommandComplete: 1836 commandTag = CommandTag(msg.CommandTag) 1837 default: 1838 if e := c.processContextFreeMsg(msg); e != nil && softErr == nil { 1839 softErr = e 1840 } 1841 } 1842 } 1843} 1844 1845func (c *Conn) buildOneRoundTripExec(buf []byte, sql string, options *QueryExOptions, arguments []interface{}) ([]byte, error) { 1846 if len(arguments) != len(options.ParameterOIDs) { 1847 return nil, errors.Errorf("mismatched number of arguments (%d) and options.ParameterOIDs (%d)", len(arguments), len(options.ParameterOIDs)) 1848 } 1849 1850 if len(options.ParameterOIDs) > 65535 { 1851 return nil, errors.Errorf("Number of QueryExOptions ParameterOIDs must be between 0 and 65535, received %d", len(options.ParameterOIDs)) 1852 } 1853 1854 buf = appendParse(buf, "", sql, options.ParameterOIDs) 1855 buf, err := appendBind(buf, "", "", c.ConnInfo, options.ParameterOIDs, arguments, nil) 1856 if err != nil { 1857 return nil, err 1858 } 1859 buf = appendExecute(buf, "", 0) 1860 1861 return buf, nil 1862} 1863 1864func (c *Conn) initContext(ctx context.Context) error { 1865 if c.ctxInProgress { 1866 return errors.New("ctx already in progress") 1867 } 1868 1869 if ctx.Done() == nil { 1870 return nil 1871 } 1872 1873 select { 1874 case <-ctx.Done(): 1875 return ctx.Err() 1876 default: 1877 } 1878 1879 c.ctxInProgress = true 1880 1881 go c.contextHandler(ctx) 1882 1883 return nil 1884} 1885 1886func (c *Conn) termContext(opErr error) error { 1887 if !c.ctxInProgress { 1888 return opErr 1889 } 1890 1891 var err error 1892 1893 select { 1894 case err = <-c.closedChan: 1895 if opErr == nil { 1896 err = nil 1897 } 1898 case c.doneChan <- struct{}{}: 1899 err = opErr 1900 } 1901 1902 c.ctxInProgress = false 1903 return err 1904} 1905 1906func (c *Conn) contextHandler(ctx context.Context) { 1907 select { 1908 case <-ctx.Done(): 1909 c.cancelQuery() 1910 c.closedChan <- ctx.Err() 1911 case <-c.doneChan: 1912 } 1913} 1914 1915// WaitUntilReady will return when the connection is ready for another query 1916func (c *Conn) WaitUntilReady(ctx context.Context) error { 1917 err := c.waitForPreviousCancelQuery(ctx) 1918 if err != nil { 1919 return err 1920 } 1921 return c.ensureConnectionReadyForQuery() 1922} 1923 1924func (c *Conn) waitForPreviousCancelQuery(ctx context.Context) error { 1925 c.mux.Lock() 1926 completeCh := c.cancelQueryCompleted 1927 c.mux.Unlock() 1928 select { 1929 case <-completeCh: 1930 if err := c.conn.SetDeadline(time.Time{}); err != nil { 1931 c.Close() // Close connection if unable to disable deadline 1932 return err 1933 } 1934 return nil 1935 case <-ctx.Done(): 1936 return ctx.Err() 1937 } 1938} 1939 1940func (c *Conn) ensureConnectionReadyForQuery() error { 1941 for c.pendingReadyForQueryCount > 0 { 1942 msg, err := c.rxMsg() 1943 if err != nil { 1944 return err 1945 } 1946 1947 switch msg := msg.(type) { 1948 case *pgproto3.ErrorResponse: 1949 pgErr := c.rxErrorResponse(msg) 1950 if pgErr.Severity == "FATAL" { 1951 return pgErr 1952 } 1953 default: 1954 err = c.processContextFreeMsg(msg) 1955 if err != nil { 1956 return err 1957 } 1958 } 1959 } 1960 1961 return nil 1962} 1963 1964func connInfoFromRows(rows *Rows, err error) (map[string]pgtype.OID, error) { 1965 if err != nil { 1966 return nil, err 1967 } 1968 defer rows.Close() 1969 1970 nameOIDs := make(map[string]pgtype.OID, 256) 1971 for rows.Next() { 1972 var oid pgtype.OID 1973 var name pgtype.Text 1974 if err = rows.Scan(&oid, &name); err != nil { 1975 return nil, err 1976 } 1977 1978 nameOIDs[name.String] = oid 1979 } 1980 1981 if err = rows.Err(); err != nil { 1982 return nil, err 1983 } 1984 1985 return nameOIDs, err 1986} 1987 1988// LastStmtSent returns true if the last call to Query(Ex)/Exec(Ex) attempted to 1989// send the statement over the wire. Each call to a Query(Ex)/Exec(Ex) resets 1990// the value to false initially until the statement has been sent. This does 1991// NOT mean that the statement was successful or even received, it just means 1992// that a write was attempted and therefore it could have been executed. Calls 1993// to prepare a statement are ignored, only when the prepared statement is 1994// attempted to be executed will this return true. 1995func (c *Conn) LastStmtSent() bool { 1996 return c.lastStmtSent 1997} 1998