1package pgx
2
3import (
4	"context"
5	"crypto/md5"
6	"crypto/tls"
7	"crypto/x509"
8	"encoding/binary"
9	"encoding/hex"
10	"fmt"
11	"io"
12	"io/ioutil"
13	"net"
14	"net/url"
15	"os"
16	"os/user"
17	"path/filepath"
18	"reflect"
19	"regexp"
20	"strconv"
21	"strings"
22	"sync"
23	"time"
24
25	"github.com/pkg/errors"
26
27	"github.com/jackc/pgx/pgio"
28	"github.com/jackc/pgx/pgproto3"
29	"github.com/jackc/pgx/pgtype"
30)
31
// Lifecycle states of a Conn, stored in Conn.status. The numeric order is
// significant: Close treats any status below connStatusIdle as "not open".
const (
	connStatusUninitialized = 0 // zero value of a freshly allocated Conn
	connStatusClosed        = 1 // set by Close and when a connect attempt fails
	connStatusIdle          = 2 // set once the socket has been dialed
	connStatusBusy          = 3 // presumably set while a query is in progress (not shown here)
)
38
// minimalConnInfo has just enough static type information to establish the
// connection and retrieve the type data.
//
// Connect starts every connection with this registry. (*Conn).connect then
// replaces it with one introspected from the server, but only while
// c.ConnInfo is still this exact pointer and the connection is not a
// replication connection.
var minimalConnInfo *pgtype.ConnInfo

// init populates minimalConnInfo with the handful of types needed while
// connecting and introspecting the server's types. varchar is included
// because CrateDB reports pg_type.typname as varchar.
func init() {
	minimalConnInfo = pgtype.NewConnInfo()
	minimalConnInfo.InitializeDataTypes(map[string]pgtype.OID{
		"int4":    pgtype.Int4OID,
		"name":    pgtype.NameOID,
		"oid":     pgtype.OIDOID,
		"text":    pgtype.TextOID,
		"varchar": pgtype.VarcharOID,
	})
}
53
54// NoticeHandler is a function that can handle notices received from the
55// PostgreSQL server. Notices can be received at any time, usually during
56// handling of a query response. The *Conn is provided so the handler is aware
57// of the origin of the notice, but it must not invoke any query method. Be
58// aware that this is distinct from LISTEN/NOTIFY notification.
59type NoticeHandler func(*Conn, *Notice)
60
// DialFunc opens the raw network connection to a PostgreSQL server. It has
// the same shape as (*net.Dialer).Dial, so a dialer's Dial method can be used
// directly.
type DialFunc func(network, address string) (net.Conn, error)
63
// ConnConfig contains all the options used to establish a connection.
//
// Zero-valued fields receive defaults when connecting: Port becomes 5432, User
// becomes the OS user name, LogLevel 0 means LogLevelDebug, and a nil Dial is
// replaced by a net.Dialer with a 5 minute keep-alive.
type ConnConfig struct {
	Host              string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp); any existing filesystem path is treated as a unix socket
	Port              uint16 // default: 5432
	Database          string
	User              string // default: OS user name
	Password          string
	TLSConfig         *tls.Config // config for TLS connection -- nil disables TLS
	UseFallbackTLS    bool        // Try FallbackTLSConfig if connecting with TLSConfig fails. Used for preferring TLS, but allowing unencrypted, or vice-versa
	FallbackTLSConfig *tls.Config // config for fallback TLS connection (only used if UseFallbackTLS is true) -- nil disables TLS
	Logger            Logger
	LogLevel          int
	Dial              DialFunc
	RuntimeParams     map[string]string                     // Run-time parameters to set on connection as session default values (e.g. search_path or application_name)
	OnNotice          NoticeHandler                         // Callback function called when a notice response is received.
	CustomConnInfo    func(*Conn) (*pgtype.ConnInfo, error) // Callback function to implement connection strategies for different backends: CrateDB, pgbouncer, pgpool, etc.
	CustomCancel      func(*Conn) error                     // Callback function used to override cancellation behavior

	// PreferSimpleProtocol disables implicit prepared statement usage. By default
	// pgx automatically uses the unnamed prepared statement for Query and
	// QueryRow. It also uses a prepared statement when Exec has arguments. This
	// can improve performance due to being able to use the binary format. It also
	// does not rely on client side parameter sanitization. However, it does incur
	// two round-trips per query and may be incompatible with proxies such as
	// PGBouncer. Setting PreferSimpleProtocol causes the simple protocol to be
	// used by default. The same functionality can be controlled on a per query
	// basis by setting QueryExOptions.SimpleProtocol.
	PreferSimpleProtocol bool
}
93
94func (cc *ConnConfig) networkAddress() (network, address string) {
95	network = "tcp"
96	address = fmt.Sprintf("%s:%d", cc.Host, cc.Port)
97	// See if host is a valid path, if yes connect with a socket
98	if _, err := os.Stat(cc.Host); err == nil {
99		// For backward compatibility accept socket file paths -- but directories are now preferred
100		network = "unix"
101		address = cc.Host
102		if !strings.Contains(address, "/.s.PGSQL.") {
103			address = filepath.Join(address, ".s.PGSQL.") + strconv.FormatInt(int64(cc.Port), 10)
104		}
105	}
106
107	return network, address
108}
109
// Conn is a PostgreSQL connection handle. It is not safe for concurrent usage.
// Use ConnPool to manage access to multiple database connections from multiple
// goroutines.
//
// A Conn is created by Connect; its fields are (re)initialized on each
// connection attempt.
type Conn struct {
	conn               net.Conn  // the underlying TCP or unix domain socket connection
	lastActivityTime   time.Time // the last time the connection was used
	wbuf               []byte
	pid                uint32            // backend pid
	secretKey          uint32            // key to use to send a cancel query message to the server
	RuntimeParams      map[string]string // parameters that have been reported by the server
	config             ConnConfig        // config used when establishing this connection
	txStatus           byte
	preparedStatements map[string]*PreparedStatement
	channels           map[string]struct{}
	notifications      []*Notification
	logger             Logger
	logLevel           int
	fp                 *fastpath
	poolResetCount     int
	preallocatedRows   []Rows
	onNotice           NoticeHandler

	// mux guards status; Close and connect only change status while holding it.
	mux          sync.Mutex
	status       byte // One of connStatus* constants
	causeOfDeath error

	pendingReadyForQueryCount int // number of ReadyForQuery messages expected
	cancelQueryCompleted      chan struct{}
	lastStmtSent              bool

	// context support
	ctxInProgress bool
	doneChan      chan struct{}
	closedChan    chan error

	// ConnInfo is the type registry for this connection. It starts as the
	// ConnInfo passed to connect (minimalConnInfo for Connect). For
	// non-replication connections it is then replaced by one built from the
	// server's catalog.
	ConnInfo *pgtype.ConnInfo

	// frontend is the pgproto3 message reader built over conn during connect.
	frontend *pgproto3.Frontend
}
149
// PreparedStatement is a description of a prepared statement.
//
// Prepared statements are cached on the Conn by Name. Preparing the same
// non-empty Name again with identical SQL returns the cached description
// instead of contacting the server.
type PreparedStatement struct {
	Name              string
	SQL               string
	FieldDescriptions []FieldDescription
	ParameterOIDs     []pgtype.OID
}
157
// PrepareExOptions is an option struct that can be passed to PrepareEx.
// A nil *PrepareExOptions is accepted (Prepare passes nil).
type PrepareExOptions struct {
	// ParameterOIDs optionally specifies the parameter types of the
	// statement. NOTE(review): presumably sent to the server with the Parse
	// message; confirm against prepareEx.
	ParameterOIDs []pgtype.OID
}
162
// Notification is a message received from the PostgreSQL LISTEN/NOTIFY system.
type Notification struct {
	PID     uint32 // backend pid that sent the notification
	Channel string // channel from which notification was received
	Payload string // payload string supplied with the NOTIFY (may be empty)
}
169
// CommandTag is the command completion tag returned by an Exec function,
// e.g. "INSERT 0 5" or "CREATE TABLE".
type CommandTag string

// RowsAffected returns the number of rows affected, taken from the last
// space-separated word of the tag. Tags without a numeric last word (such as
// "CREATE TABLE") yield 0.
func (ct CommandTag) RowsAffected() int64 {
	tag := string(ct)
	lastSpace := strings.LastIndexByte(tag, ' ')
	if lastSpace < 0 {
		return 0
	}
	// The parse error is deliberately ignored: ParseInt yields 0 for
	// non-numeric words, which is the desired result.
	count, _ := strconv.ParseInt(tag[lastSpace+1:], 10, 64)
	return count
}
184
// Identifier is a PostgreSQL identifier or name. Identifiers can be composed
// of multiple parts such as ["schema", "table"] or ["table", "column"].
type Identifier []string

// Sanitize returns a sanitized string safe for SQL interpolation.
//
// Each part is wrapped in double quotes, with embedded double quotes doubled.
// The parts are joined with ".". NUL bytes are removed because PostgreSQL
// quoted identifiers cannot contain the character with code zero.
func (ident Identifier) Sanitize() string {
	parts := make([]string, len(ident))
	for i := range ident {
		s := strings.Replace(ident[i], "\x00", "", -1)
		parts[i] = `"` + strings.Replace(s, `"`, `""`, -1) + `"`
	}
	return strings.Join(parts, ".")
}
197
// Sentinel errors returned by this package.
var (
	// ErrNoRows occurs when rows are expected but none are returned.
	ErrNoRows = errors.New("no rows in result set")

	// ErrDeadConn occurs on an attempt to use a dead connection.
	ErrDeadConn = errors.New("conn is dead")

	// ErrTLSRefused occurs when the connection attempt requires TLS and the
	// PostgreSQL server refuses to use TLS.
	ErrTLSRefused = errors.New("server refused TLS connection")

	// ErrConnBusy occurs when the connection is busy (for example, in the
	// middle of reading query results) and another action is attempted.
	ErrConnBusy = errors.New("conn is busy")

	// ErrInvalidLogLevel occurs on attempt to set an invalid log level.
	ErrInvalidLogLevel = errors.New("invalid log level")
)
214
// ProtocolError occurs when unexpected data is received from PostgreSQL. The
// string value is the error message.
type ProtocolError string

// Error implements the error interface by returning the message unchanged.
func (pe ProtocolError) Error() string {
	msg := string(pe)
	return msg
}
221
222// Connect establishes a connection with a PostgreSQL server using config.
223// config.Host must be specified. config.User will default to the OS user name.
224// Other config fields are optional.
225func Connect(config ConnConfig) (c *Conn, err error) {
226	return connect(config, minimalConnInfo)
227}
228
// defaultDialer returns a new net.Dialer with a 5 minute TCP keep-alive period
// and no connect timeout. Callers may adjust the returned dialer (e.g. set
// Timeout) without affecting other callers.
func defaultDialer() *net.Dialer {
	d := new(net.Dialer)
	d.KeepAlive = 5 * time.Minute
	return d
}
232
233func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error) {
234	c = new(Conn)
235
236	c.config = config
237	c.ConnInfo = connInfo
238
239	if c.config.LogLevel != 0 {
240		c.logLevel = c.config.LogLevel
241	} else {
242		// Preserve pre-LogLevel behavior by defaulting to LogLevelDebug
243		c.logLevel = LogLevelDebug
244	}
245	c.logger = c.config.Logger
246
247	if c.config.User == "" {
248		user, err := user.Current()
249		if err != nil {
250			return nil, err
251		}
252		c.config.User = user.Username
253		if c.shouldLog(LogLevelDebug) {
254			c.log(LogLevelDebug, "Using default connection config", map[string]interface{}{"User": c.config.User})
255		}
256	}
257
258	if c.config.Port == 0 {
259		c.config.Port = 5432
260		if c.shouldLog(LogLevelDebug) {
261			c.log(LogLevelDebug, "Using default connection config", map[string]interface{}{"Port": c.config.Port})
262		}
263	}
264
265	c.onNotice = config.OnNotice
266
267	network, address := c.config.networkAddress()
268	if c.config.Dial == nil {
269		d := defaultDialer()
270		c.config.Dial = d.Dial
271	}
272
273	if c.shouldLog(LogLevelInfo) {
274		c.log(LogLevelInfo, "Dialing PostgreSQL server", map[string]interface{}{"network": network, "address": address})
275	}
276	err = c.connect(config, network, address, config.TLSConfig)
277	if err != nil && config.UseFallbackTLS {
278		if c.shouldLog(LogLevelInfo) {
279			c.log(LogLevelInfo, "connect with TLSConfig failed, trying FallbackTLSConfig", map[string]interface{}{"err": err})
280		}
281		err = c.connect(config, network, address, config.FallbackTLSConfig)
282	}
283
284	if err != nil {
285		if c.shouldLog(LogLevelError) {
286			c.log(LogLevelError, "connect failed", map[string]interface{}{"err": err})
287		}
288		return nil, err
289	}
290
291	return c, nil
292}
293
// connect performs a single connection attempt to network/address.
// tlsConfig == nil means no TLS.
//
// It dials with c.config.Dial, resets the per-connection state, optionally
// starts TLS, and sends the startup message. That message carries
// config.RuntimeParams plus the user and (if set) database from c.config.
// It then processes backend messages until the first ReadyForQuery.
//
// If any step after dialing fails, the deferred cleanup closes the socket and
// marks the connection closed. connect (the package-level function) relies on
// this to retry with FallbackTLSConfig.
func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tls.Config) (err error) {
	c.conn, err = c.config.Dial(network, address)
	if err != nil {
		return err
	}
	// On failure, close the socket and mark the connection closed.
	defer func() {
		if c != nil && err != nil {
			c.conn.Close()
			c.mux.Lock()
			c.status = connStatusClosed
			c.mux.Unlock()
		}
	}()

	// Per-connection state is rebuilt on every attempt.
	c.RuntimeParams = make(map[string]string)
	c.preparedStatements = make(map[string]*PreparedStatement)
	c.channels = make(map[string]struct{})
	c.lastActivityTime = time.Now()
	c.cancelQueryCompleted = make(chan struct{})
	// Closed up front -- presumably so that waiting for a previous cancel
	// returns immediately when none is in flight.
	close(c.cancelQueryCompleted)
	c.doneChan = make(chan struct{})
	c.closedChan = make(chan error)
	c.wbuf = make([]byte, 0, 1024)

	c.mux.Lock()
	c.status = connStatusIdle
	c.mux.Unlock()

	if tlsConfig != nil {
		if c.shouldLog(LogLevelDebug) {
			c.log(LogLevelDebug, "starting TLS handshake", nil)
		}
		if err := c.startTLS(tlsConfig); err != nil {
			return err
		}
	}

	c.frontend, err = pgproto3.NewFrontend(c.conn, c.conn)
	if err != nil {
		return err
	}

	startupMsg := pgproto3.StartupMessage{
		ProtocolVersion: pgproto3.ProtocolVersionNumber,
		Parameters:      make(map[string]string),
	}

	// Copy default run-time params
	for k, v := range config.RuntimeParams {
		startupMsg.Parameters[k] = v
	}

	// user and database are set last so they win over RuntimeParams entries
	// with the same keys.
	startupMsg.Parameters["user"] = c.config.User
	if c.config.Database != "" {
		startupMsg.Parameters["database"] = c.config.Database
	}

	if _, err := c.conn.Write(startupMsg.Encode(nil)); err != nil {
		return err
	}

	// The startup sequence ends with exactly one ReadyForQuery.
	c.pendingReadyForQueryCount = 1

	for {
		msg, err := c.rxMsg()
		if err != nil {
			return err
		}

		switch msg := msg.(type) {
		case *pgproto3.BackendKeyData:
			c.rxBackendKeyData(msg)
		case *pgproto3.Authentication:
			if err = c.rxAuthenticationX(msg); err != nil {
				return err
			}
		case *pgproto3.ReadyForQuery:
			c.rxReadyForQuery(msg)
			if c.shouldLog(LogLevelInfo) {
				c.log(LogLevelInfo, "connection established", nil)
			}

			// Replication connections can't execute the type introspection
			// queries used to populate c.ConnInfo, so they keep the initial one.
			if _, ok := config.RuntimeParams["replication"]; ok {
				return nil
			}

			// Only introspect when the caller did not supply its own registry
			// (Connect passes minimalConnInfo).
			if c.ConnInfo == minimalConnInfo {
				err = c.initConnInfo()
				if err != nil {
					return err
				}
			}

			return nil
		default:
			// Any other message is delegated; an error aborts the attempt.
			if err = c.processContextFreeMsg(msg); err != nil {
				return err
			}
		}
	}
}
397
398func initPostgresql(c *Conn) (*pgtype.ConnInfo, error) {
399	const (
400		namedOIDQuery = `select t.oid,
401	case when nsp.nspname in ('pg_catalog', 'public') then t.typname
402		else nsp.nspname||'.'||t.typname
403	end
404from pg_type t
405left join pg_type base_type on t.typelem=base_type.oid
406left join pg_namespace nsp on t.typnamespace=nsp.oid
407where (
408	  t.typtype in('b', 'p', 'r', 'e')
409	  and (base_type.oid is null or base_type.typtype in('b', 'p', 'r'))
410	)`
411	)
412
413	nameOIDs, err := connInfoFromRows(c.Query(namedOIDQuery))
414	if err != nil {
415		return nil, err
416	}
417
418	cinfo := pgtype.NewConnInfo()
419	cinfo.InitializeDataTypes(nameOIDs)
420
421	if err = c.initConnInfoEnumArray(cinfo); err != nil {
422		return nil, err
423	}
424
425	if err = c.initConnInfoDomains(cinfo); err != nil {
426		return nil, err
427	}
428
429	return cinfo, nil
430}
431
432func (c *Conn) initConnInfo() (err error) {
433	var (
434		connInfo *pgtype.ConnInfo
435	)
436
437	if c.config.CustomConnInfo != nil {
438		if c.ConnInfo, err = c.config.CustomConnInfo(c); err != nil {
439			return err
440		}
441
442		return nil
443	}
444
445	if connInfo, err = initPostgresql(c); err == nil {
446		c.ConnInfo = connInfo
447		return err
448	}
449
450	// Check if CrateDB specific approach might still allow us to connect.
451	if connInfo, err = c.crateDBTypesQuery(err); err == nil {
452		c.ConnInfo = connInfo
453	}
454
455	return err
456}
457
458// initConnInfoEnumArray introspects for arrays of enums and registers a data type for them.
459func (c *Conn) initConnInfoEnumArray(cinfo *pgtype.ConnInfo) error {
460	nameOIDs := make(map[string]pgtype.OID, 16)
461	rows, err := c.Query(`select t.oid, t.typname
462from pg_type t
463  join pg_type base_type on t.typelem=base_type.oid
464where t.typtype = 'b'
465  and base_type.typtype = 'e'`)
466	if err != nil {
467		return err
468	}
469
470	for rows.Next() {
471		var oid pgtype.OID
472		var name pgtype.Text
473		if err := rows.Scan(&oid, &name); err != nil {
474			return err
475		}
476
477		nameOIDs[name.String] = oid
478	}
479
480	if rows.Err() != nil {
481		return rows.Err()
482	}
483
484	for name, oid := range nameOIDs {
485		cinfo.RegisterDataType(pgtype.DataType{
486			Value: &pgtype.EnumArray{},
487			Name:  name,
488			OID:   oid,
489		})
490	}
491
492	return nil
493}
494
495// initConnInfoDomains introspects for domains and registers a data type for them.
496func (c *Conn) initConnInfoDomains(cinfo *pgtype.ConnInfo) error {
497	type domain struct {
498		oid     pgtype.OID
499		name    pgtype.Text
500		baseOID pgtype.OID
501	}
502
503	domains := make([]*domain, 0, 16)
504
505	rows, err := c.Query(`select t.oid, t.typname, t.typbasetype
506from pg_type t
507  join pg_type base_type on t.typbasetype=base_type.oid
508where t.typtype = 'd'
509  and base_type.typtype = 'b'`)
510	if err != nil {
511		return err
512	}
513
514	for rows.Next() {
515		var d domain
516		if err := rows.Scan(&d.oid, &d.name, &d.baseOID); err != nil {
517			return err
518		}
519
520		domains = append(domains, &d)
521	}
522
523	if rows.Err() != nil {
524		return rows.Err()
525	}
526
527	for _, d := range domains {
528		baseDataType, ok := cinfo.DataTypeForOID(d.baseOID)
529		if ok {
530			cinfo.RegisterDataType(pgtype.DataType{
531				Value: reflect.New(reflect.ValueOf(baseDataType.Value).Elem().Type()).Interface().(pgtype.Value),
532				Name:  d.name.String,
533				OID:   d.oid,
534			})
535		}
536	}
537
538	return nil
539}
540
541// crateDBTypesQuery checks if the given err is likely to be the result of
542// CrateDB not implementing the pg_types table correctly. If yes, a CrateDB
543// specific query against pg_types is executed and its results are returned. If
544// not, the original error is returned.
545func (c *Conn) crateDBTypesQuery(err error) (*pgtype.ConnInfo, error) {
546	// CrateDB 2.1.6 is a database that implements the PostgreSQL wire protocol,
547	// but not perfectly. In particular, the pg_catalog schema containing the
548	// pg_type table is not visible by default and the pg_type.typtype column is
549	// not implemented. Therefor the query above currently returns the following
550	// error:
551	//
552	//   pgx.PgError{Severity:"ERROR", Code:"XX000",
553	//   Message:"TableUnknownException: Table 'test.pg_type' unknown",
554	//   Detail:"", Hint:"", Position:0, InternalPosition:0, InternalQuery:"",
555	//   Where:"", SchemaName:"", TableName:"", ColumnName:"", DataTypeName:"",
556	//   ConstraintName:"", File:"Schemas.java", Line:99, Routine:"getTableInfo"}
557	//
558	// If CrateDB was to fix the pg_type table visbility in the future, we'd
559	// still get this error until typtype column is implemented:
560	//
561	//   pgx.PgError{Severity:"ERROR", Code:"XX000",
562	//   Message:"ColumnUnknownException: Column typtype unknown", Detail:"",
563	//   Hint:"", Position:0, InternalPosition:0, InternalQuery:"", Where:"",
564	//   SchemaName:"", TableName:"", ColumnName:"", DataTypeName:"",
565	//   ConstraintName:"", File:"FullQualifiedNameFieldProvider.java", Line:132,
566	//
567	// Additionally CrateDB doesn't implement Postgres error codes [2], and
568	// instead always returns "XX000" (internal_error). The code below uses all
569	// of this knowledge as a heuristic to detect CrateDB. If CrateDB is
570	// detected, a CrateDB specific pg_type query is executed instead.
571	//
572	// The heuristic is designed to still work even if CrateDB fixes [2] or
573	// renames its internal exception names. If both are changed but pg_types
574	// isn't fixed, this code will need to be changed.
575	//
576	// There is also a small chance the heuristic will yield a false positive for
577	// non-CrateDB databases (e.g. if a real Postgres instance returns a XX000
578	// error), but hopefully there will be no harm in attempting the alternative
579	// query in this case.
580	//
581	// CrateDB also uses the type varchar for the typname column which required
582	// adding varchar to the minimalConnInfo init code.
583	//
584	// Also see the discussion here [3].
585	//
586	// [1] https://crate.io/
587	// [2] https://github.com/crate/crate/issues/5027
588	// [3] https://github.com/jackc/pgx/issues/320
589
590	if pgErr, ok := err.(PgError); ok &&
591		(pgErr.Code == "XX000" ||
592			strings.Contains(pgErr.Message, "TableUnknownException") ||
593			strings.Contains(pgErr.Message, "ColumnUnknownException")) {
594		var (
595			nameOIDs map[string]pgtype.OID
596		)
597
598		if nameOIDs, err = connInfoFromRows(c.Query(`select oid, typname from pg_catalog.pg_type`)); err != nil {
599			return nil, err
600		}
601
602		cinfo := pgtype.NewConnInfo()
603		cinfo.InitializeDataTypes(nameOIDs)
604
605		return cinfo, err
606	}
607
608	return nil, err
609}
610
611// PID returns the backend PID for this connection.
612func (c *Conn) PID() uint32 {
613	return c.pid
614}
615
616// LocalAddr returns the underlying connection's local address
617func (c *Conn) LocalAddr() (net.Addr, error) {
618	if !c.IsAlive() {
619		return nil, errors.New("connection not ready")
620	}
621	return c.conn.LocalAddr(), nil
622}
623
624// Close closes a connection. It is safe to call Close on a already closed
625// connection.
626func (c *Conn) Close() (err error) {
627	c.mux.Lock()
628	defer c.mux.Unlock()
629
630	if c.status < connStatusIdle {
631		return nil
632	}
633	c.status = connStatusClosed
634
635	defer func() {
636		c.conn.Close()
637		c.causeOfDeath = errors.New("Closed")
638		if c.shouldLog(LogLevelInfo) {
639			c.log(LogLevelInfo, "closed connection", nil)
640		}
641	}()
642
643	err = c.conn.SetDeadline(time.Time{})
644	if err != nil && c.shouldLog(LogLevelWarn) {
645		c.log(LogLevelWarn, "failed to clear deadlines to send close message", map[string]interface{}{"err": err})
646		return err
647	}
648
649	_, err = c.conn.Write([]byte{'X', 0, 0, 0, 4})
650	if err != nil && c.shouldLog(LogLevelWarn) {
651		c.log(LogLevelWarn, "failed to send terminate message", map[string]interface{}{"err": err})
652		return err
653	}
654
655	err = c.conn.SetReadDeadline(time.Now().Add(5 * time.Second))
656	if err != nil && c.shouldLog(LogLevelWarn) {
657		c.log(LogLevelWarn, "failed to set read deadline to finish closing", map[string]interface{}{"err": err})
658		return err
659	}
660
661	_, err = c.conn.Read(make([]byte, 1))
662	if err != io.EOF {
663		return err
664	}
665
666	return nil
667}
668
669// Merge returns a new ConnConfig with the attributes of old and other
670// combined. When an attribute is set on both, other takes precedence.
671//
672// As a security precaution, if the other TLSConfig is nil, all old TLS
673// attributes will be preserved.
674func (old ConnConfig) Merge(other ConnConfig) ConnConfig {
675	cc := old
676
677	if other.Host != "" {
678		cc.Host = other.Host
679	}
680	if other.Port != 0 {
681		cc.Port = other.Port
682	}
683	if other.Database != "" {
684		cc.Database = other.Database
685	}
686	if other.User != "" {
687		cc.User = other.User
688	}
689	if other.Password != "" {
690		cc.Password = other.Password
691	}
692
693	if other.TLSConfig != nil {
694		cc.TLSConfig = other.TLSConfig
695		cc.UseFallbackTLS = other.UseFallbackTLS
696		cc.FallbackTLSConfig = other.FallbackTLSConfig
697	}
698
699	if other.Logger != nil {
700		cc.Logger = other.Logger
701	}
702	if other.LogLevel != 0 {
703		cc.LogLevel = other.LogLevel
704	}
705
706	if other.Dial != nil {
707		cc.Dial = other.Dial
708	}
709
710	cc.PreferSimpleProtocol = other.PreferSimpleProtocol
711
712	cc.RuntimeParams = make(map[string]string)
713	for k, v := range old.RuntimeParams {
714		cc.RuntimeParams[k] = v
715	}
716	for k, v := range other.RuntimeParams {
717		cc.RuntimeParams[k] = v
718	}
719
720	return cc
721}
722
723// ParseURI parses a database URI into ConnConfig
724//
725// Query parameters not used by the connection process are parsed into ConnConfig.RuntimeParams.
726func ParseURI(uri string) (ConnConfig, error) {
727	var cp ConnConfig
728
729	url, err := url.Parse(uri)
730	if err != nil {
731		return cp, err
732	}
733
734	if url.User != nil {
735		cp.User = url.User.Username()
736		cp.Password, _ = url.User.Password()
737	}
738
739	parts := strings.SplitN(url.Host, ":", 2)
740	cp.Host = parts[0]
741	if len(parts) == 2 {
742		p, err := strconv.ParseUint(parts[1], 10, 16)
743		if err != nil {
744			return cp, err
745		}
746		cp.Port = uint16(p)
747	}
748	cp.Database = strings.TrimLeft(url.Path, "/")
749
750	if pgtimeout := url.Query().Get("connect_timeout"); pgtimeout != "" {
751		timeout, err := strconv.ParseInt(pgtimeout, 10, 64)
752		if err != nil {
753			return cp, err
754		}
755		d := defaultDialer()
756		d.Timeout = time.Duration(timeout) * time.Second
757		cp.Dial = d.Dial
758	}
759
760	tlsArgs := configTLSArgs{
761		sslCert:     url.Query().Get("sslcert"),
762		sslKey:      url.Query().Get("sslkey"),
763		sslMode:     url.Query().Get("sslmode"),
764		sslRootCert: url.Query().Get("sslrootcert"),
765	}
766	err = configTLS(tlsArgs, &cp)
767	if err != nil {
768		return cp, err
769	}
770
771	ignoreKeys := map[string]struct{}{
772		"connect_timeout": {},
773		"sslcert":         {},
774		"sslkey":          {},
775		"sslmode":         {},
776		"sslrootcert":     {},
777	}
778
779	cp.RuntimeParams = make(map[string]string)
780
781	for k, v := range url.Query() {
782		if _, ok := ignoreKeys[k]; ok {
783			continue
784		}
785
786		if k == "host" {
787			cp.Host = v[0]
788			continue
789		}
790
791		cp.RuntimeParams[k] = v[0]
792	}
793	if cp.Password == "" {
794		pgpass(&cp)
795	}
796	return cp, nil
797}
798
// dsnRegexp matches one key=value pair of a DSN. Group 1 is the key (letters
// and underscores). Group 2 is the value: either a double-quoted run without
// embedded double quotes (the quotes are included in the capture) or a run of
// non-space characters. No whitespace is allowed around '='.
var dsnRegexp = regexp.MustCompile(`([a-zA-Z_]+)=((?:"[^"]+")|(?:[^ ]+))`)
800
801// ParseDSN parses a database DSN (data source name) into a ConnConfig
802//
803// e.g. ParseDSN("user=username password=password host=1.2.3.4 port=5432 dbname=mydb sslmode=disable")
804//
805// Any options not used by the connection process are parsed into ConnConfig.RuntimeParams.
806//
807// e.g. ParseDSN("application_name=pgxtest search_path=admin user=username password=password host=1.2.3.4 dbname=mydb")
808//
809// ParseDSN tries to match libpq behavior with regard to sslmode. See comments
810// for ParseEnvLibpq for more information on the security implications of
811// sslmode options.
812func ParseDSN(s string) (ConnConfig, error) {
813	var cp ConnConfig
814
815	m := dsnRegexp.FindAllStringSubmatch(s, -1)
816
817	tlsArgs := configTLSArgs{}
818
819	cp.RuntimeParams = make(map[string]string)
820
821	for _, b := range m {
822		switch b[1] {
823		case "user":
824			cp.User = b[2]
825		case "password":
826			cp.Password = b[2]
827		case "host":
828			cp.Host = b[2]
829		case "port":
830			p, err := strconv.ParseUint(b[2], 10, 16)
831			if err != nil {
832				return cp, err
833			}
834			cp.Port = uint16(p)
835		case "dbname":
836			cp.Database = b[2]
837		case "sslmode":
838			tlsArgs.sslMode = b[2]
839		case "sslrootcert":
840			tlsArgs.sslRootCert = b[2]
841		case "sslcert":
842			tlsArgs.sslCert = b[2]
843		case "sslkey":
844			tlsArgs.sslKey = b[2]
845		case "connect_timeout":
846			timeout, err := strconv.ParseInt(b[2], 10, 64)
847			if err != nil {
848				return cp, err
849			}
850			d := defaultDialer()
851			d.Timeout = time.Duration(timeout) * time.Second
852			cp.Dial = d.Dial
853		default:
854			cp.RuntimeParams[b[1]] = b[2]
855		}
856	}
857
858	err := configTLS(tlsArgs, &cp)
859	if err != nil {
860		return cp, err
861	}
862	if cp.Password == "" {
863		pgpass(&cp)
864	}
865	return cp, nil
866}
867
868// ParseConnectionString parses either a URI or a DSN connection string.
869// see ParseURI and ParseDSN for details.
870func ParseConnectionString(s string) (ConnConfig, error) {
871	if u, err := url.Parse(s); err == nil && u.Scheme != "" {
872		return ParseURI(s)
873	}
874	return ParseDSN(s)
875}
876
877// ParseEnvLibpq parses the environment like libpq does into a ConnConfig
878//
879// See http://www.postgresql.org/docs/9.4/static/libpq-envars.html for details
880// on the meaning of environment variables.
881//
882// ParseEnvLibpq currently recognizes the following environment variables:
883// PGHOST
884// PGPORT
885// PGDATABASE
886// PGUSER
887// PGPASSWORD
888// PGSSLMODE
889// PGSSLCERT
890// PGSSLKEY
891// PGSSLROOTCERT
892// PGAPPNAME
893// PGCONNECT_TIMEOUT
894//
895// Important TLS Security Notes:
896// ParseEnvLibpq tries to match libpq behavior with regard to PGSSLMODE. This
897// includes defaulting to "prefer" behavior if no environment variable is set.
898//
899// See http://www.postgresql.org/docs/9.4/static/libpq-ssl.html#LIBPQ-SSL-PROTECTION
900// for details on what level of security each sslmode provides.
901//
902// "verify-ca" mode currently is treated as "verify-full". e.g. It has stronger
903// security guarantees than it would with libpq. Do not rely on this behavior as it
904// may be possible to match libpq in the future. If you need full security use
905// "verify-full".
906//
907// Several of the PGSSLMODE options (including the default behavior of "prefer")
908// will set UseFallbackTLS to true and FallbackTLSConfig to a disabled or
909// weakened TLS mode. This means that if ParseEnvLibpq is used, but TLSConfig is
910// later set from a different source that UseFallbackTLS MUST be set false to
911// avoid the possibility of falling back to weaker or disabled security.
912func ParseEnvLibpq() (ConnConfig, error) {
913	var cc ConnConfig
914
915	cc.Host = os.Getenv("PGHOST")
916
917	if pgport := os.Getenv("PGPORT"); pgport != "" {
918		if port, err := strconv.ParseUint(pgport, 10, 16); err == nil {
919			cc.Port = uint16(port)
920		} else {
921			return cc, err
922		}
923	}
924
925	cc.Database = os.Getenv("PGDATABASE")
926	cc.User = os.Getenv("PGUSER")
927	cc.Password = os.Getenv("PGPASSWORD")
928
929	if pgtimeout := os.Getenv("PGCONNECT_TIMEOUT"); pgtimeout != "" {
930		if timeout, err := strconv.ParseInt(pgtimeout, 10, 64); err == nil {
931			d := defaultDialer()
932			d.Timeout = time.Duration(timeout) * time.Second
933			cc.Dial = d.Dial
934		} else {
935			return cc, err
936		}
937	}
938
939	tlsArgs := configTLSArgs{
940		sslMode:     os.Getenv("PGSSLMODE"),
941		sslKey:      os.Getenv("PGSSLKEY"),
942		sslCert:     os.Getenv("PGSSLCERT"),
943		sslRootCert: os.Getenv("PGSSLROOTCERT"),
944	}
945
946	err := configTLS(tlsArgs, &cc)
947	if err != nil {
948		return cc, err
949	}
950
951	cc.RuntimeParams = make(map[string]string)
952	if appname := os.Getenv("PGAPPNAME"); appname != "" {
953		cc.RuntimeParams["application_name"] = appname
954	}
955	if cc.Password == "" {
956		pgpass(&cc)
957	}
958	return cc, nil
959}
960
// configTLSArgs carries the raw libpq-style TLS settings collected by
// ParseURI, ParseDSN and ParseEnvLibpq (sslmode, sslrootcert, sslcert, sslkey)
// for configTLS to interpret. Empty strings mean "not specified".
type configTLSArgs struct {
	sslMode     string
	sslRootCert string
	sslCert     string
	sslKey      string
}
967
968// configTLS uses lib/pq's TLS parameters to reconstruct a coherent tls.Config.
969// Inputs are parsed out and provided by ParseDSN() or ParseURI().
970func configTLS(args configTLSArgs, cc *ConnConfig) error {
971	// Match libpq default behavior
972	if args.sslMode == "" {
973		args.sslMode = "prefer"
974	}
975
976	switch args.sslMode {
977	case "disable":
978		cc.UseFallbackTLS = false
979		cc.TLSConfig = nil
980		cc.FallbackTLSConfig = nil
981		return nil
982	case "allow":
983		cc.UseFallbackTLS = true
984		cc.FallbackTLSConfig = &tls.Config{InsecureSkipVerify: true}
985	case "prefer":
986		cc.TLSConfig = &tls.Config{InsecureSkipVerify: true}
987		cc.UseFallbackTLS = true
988		cc.FallbackTLSConfig = nil
989	case "require":
990		cc.TLSConfig = &tls.Config{InsecureSkipVerify: true}
991	case "verify-ca", "verify-full":
992		cc.TLSConfig = &tls.Config{
993			ServerName: cc.Host,
994		}
995	default:
996		return errors.New("sslmode is invalid")
997	}
998
999	if args.sslRootCert != "" {
1000		caCertPool := x509.NewCertPool()
1001
1002		caPath := args.sslRootCert
1003		caCert, err := ioutil.ReadFile(caPath)
1004		if err != nil {
1005			return errors.Wrapf(err, "unable to read CA file %q", caPath)
1006		}
1007
1008		if !caCertPool.AppendCertsFromPEM(caCert) {
1009			return errors.Wrap(err, "unable to add CA to cert pool")
1010		}
1011
1012		cc.TLSConfig.RootCAs = caCertPool
1013		cc.TLSConfig.ClientCAs = caCertPool
1014	}
1015
1016	sslcert := args.sslCert
1017	sslkey := args.sslKey
1018
1019	if (sslcert != "" && sslkey == "") || (sslcert == "" && sslkey != "") {
1020		return fmt.Errorf(`both "sslcert" and "sslkey" are required`)
1021	}
1022
1023	if sslcert != "" && sslkey != "" {
1024		cert, err := tls.LoadX509KeyPair(sslcert, sslkey)
1025		if err != nil {
1026			return errors.Wrap(err, "unable to read cert")
1027		}
1028
1029		cc.TLSConfig.Certificates = []tls.Certificate{cert}
1030	}
1031
1032	return nil
1033}
1034
1035// Prepare creates a prepared statement with name and sql. sql can contain placeholders
1036// for bound parameters. These placeholders are referenced positional as $1, $2, etc.
1037//
1038// Prepare is idempotent; i.e. it is safe to call Prepare multiple times with the same
1039// name and sql arguments. This allows a code path to Prepare and Query/Exec without
1040// concern for if the statement has already been prepared.
1041func (c *Conn) Prepare(name, sql string) (ps *PreparedStatement, err error) {
1042	return c.PrepareEx(context.Background(), name, sql, nil)
1043}
1044
1045// PrepareEx creates a prepared statement with name and sql. sql can contain placeholders
1046// for bound parameters. These placeholders are referenced positional as $1, $2, etc.
1047// It differs from Prepare as it allows additional options (such as parameter OIDs) to be passed via struct
1048//
1049// PrepareEx is idempotent; i.e. it is safe to call PrepareEx multiple times with the same
1050// name and sql arguments. This allows a code path to PrepareEx and Query/Exec without
1051// concern for if the statement has already been prepared.
1052func (c *Conn) PrepareEx(ctx context.Context, name, sql string, opts *PrepareExOptions) (ps *PreparedStatement, err error) {
1053	err = c.waitForPreviousCancelQuery(ctx)
1054	if err != nil {
1055		return nil, err
1056	}
1057
1058	err = c.initContext(ctx)
1059	if err != nil {
1060		return nil, err
1061	}
1062
1063	ps, err = c.prepareEx(name, sql, opts)
1064	err = c.termContext(err)
1065	return ps, err
1066}
1067
// prepareEx does the wire-level work of PrepareEx. It sends Parse, Describe
// (statement) and Sync for sql under name, then reads the responses up to
// ReadyForQuery. Context setup and teardown are left to the caller.
//
// A named statement already cached in c.preparedStatements with identical SQL
// is returned without any network traffic. On success the statement is cached
// under name, including the unnamed statement "". Result columns whose type
// implements pgtype.BinaryDecoder are requested in binary format; all others
// are requested in text format.
func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *PreparedStatement, err error) {
	if name != "" {
		if ps, ok := c.preparedStatements[name]; ok && ps.SQL == sql {
			return ps, nil
		}
	}

	if err := c.ensureConnectionReadyForQuery(); err != nil {
		return nil, err
	}

	// Log any failure on the way out; the closure reads the named result err.
	if c.shouldLog(LogLevelError) {
		defer func() {
			if err != nil {
				c.log(LogLevelError, "prepareEx failed", map[string]interface{}{"err": err, "name": name, "sql": sql})
			}
		}()
	}

	if opts == nil {
		opts = &PrepareExOptions{}
	}

	// Explicit parameter OIDs are capped at 65535, the maximum parameter count
	// PostgreSQL supports.
	if len(opts.ParameterOIDs) > 65535 {
		return nil, errors.Errorf("Number of PrepareExOptions ParameterOIDs must be between 0 and 65535, received %d", len(opts.ParameterOIDs))
	}

	buf := appendParse(c.wbuf, name, sql, opts.ParameterOIDs)
	buf = appendDescribe(buf, 'S', name)
	buf = appendSync(buf)

	// Only a fatal write error (see fatalWriteErr) kills the connection.
	n, err := c.conn.Write(buf)
	if err != nil {
		if fatalWriteErr(n, err) {
			c.die(err)
		}
		return nil, err
	}
	// The Sync was sent, so the server now owes one ReadyForQuery.
	c.pendingReadyForQueryCount++

	ps = &PreparedStatement{Name: name, SQL: sql}

	// The first error seen before ReadyForQuery is remembered here so that the
	// loop keeps reading until the server is ready again.
	var softErr error

	for {
		msg, err := c.rxMsg()
		if err != nil {
			return nil, err
		}

		switch msg := msg.(type) {
		case *pgproto3.ParameterDescription:
			ps.ParameterOIDs = c.rxParameterDescription(msg)

			if len(ps.ParameterOIDs) > 65535 && softErr == nil {
				softErr = errors.Errorf("PostgreSQL supports maximum of 65535 parameters, received %d", len(ps.ParameterOIDs))
			}
		case *pgproto3.RowDescription:
			ps.FieldDescriptions = c.rxRowDescription(msg)
			// Choose the result format per column based on what the
			// registered pgtype value can decode.
			for i := range ps.FieldDescriptions {
				if dt, ok := c.ConnInfo.DataTypeForOID(ps.FieldDescriptions[i].DataType); ok {
					ps.FieldDescriptions[i].DataTypeName = dt.Name
					if _, ok := dt.Value.(pgtype.BinaryDecoder); ok {
						ps.FieldDescriptions[i].FormatCode = BinaryFormatCode
					} else {
						ps.FieldDescriptions[i].FormatCode = TextFormatCode
					}
				} else {
					// Returning here leaves the rest of the response, including
					// ReadyForQuery, unread. pendingReadyForQueryCount lets the
					// next ensureConnectionReadyForQuery drain it.
					return nil, errors.Errorf("unknown oid: %d", ps.FieldDescriptions[i].DataType)
				}
			}
		case *pgproto3.ReadyForQuery:
			c.rxReadyForQuery(msg)

			if softErr == nil {
				c.preparedStatements[name] = ps
			}

			return ps, softErr
		default:
			if e := c.processContextFreeMsg(msg); e != nil && softErr == nil {
				softErr = e
			}
		}
	}
}
1154
1155// Deallocate released a prepared statement
1156func (c *Conn) Deallocate(name string) error {
1157	return c.deallocateContext(context.Background(), name)
1158}
1159
1160// TODO - consider making this public
1161func (c *Conn) deallocateContext(ctx context.Context, name string) (err error) {
1162	err = c.waitForPreviousCancelQuery(ctx)
1163	if err != nil {
1164		return err
1165	}
1166
1167	err = c.initContext(ctx)
1168	if err != nil {
1169		return err
1170	}
1171	defer func() {
1172		err = c.termContext(err)
1173	}()
1174
1175	if err := c.ensureConnectionReadyForQuery(); err != nil {
1176		return err
1177	}
1178
1179	delete(c.preparedStatements, name)
1180
1181	// close
1182	buf := c.wbuf
1183	buf = append(buf, 'C')
1184	sp := len(buf)
1185	buf = pgio.AppendInt32(buf, -1)
1186	buf = append(buf, 'S')
1187	buf = append(buf, name...)
1188	buf = append(buf, 0)
1189	pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
1190
1191	// flush
1192	buf = append(buf, 'H')
1193	buf = pgio.AppendInt32(buf, 4)
1194
1195	_, err = c.conn.Write(buf)
1196	if err != nil {
1197		c.die(err)
1198		return err
1199	}
1200
1201	for {
1202		msg, err := c.rxMsg()
1203		if err != nil {
1204			return err
1205		}
1206
1207		switch msg.(type) {
1208		case *pgproto3.CloseComplete:
1209			return nil
1210		default:
1211			err = c.processContextFreeMsg(msg)
1212			if err != nil {
1213				return err
1214			}
1215		}
1216	}
1217}
1218
1219// Listen establishes a PostgreSQL listen/notify to channel
1220func (c *Conn) Listen(channel string) error {
1221	_, err := c.Exec("listen " + quoteIdentifier(channel))
1222	if err != nil {
1223		return err
1224	}
1225
1226	c.channels[channel] = struct{}{}
1227
1228	return nil
1229}
1230
1231// Unlisten unsubscribes from a listen channel
1232func (c *Conn) Unlisten(channel string) error {
1233	_, err := c.Exec("unlisten " + quoteIdentifier(channel))
1234	if err != nil {
1235		return err
1236	}
1237
1238	delete(c.channels, channel)
1239	return nil
1240}
1241
1242// WaitForNotification waits for a PostgreSQL notification.
1243func (c *Conn) WaitForNotification(ctx context.Context) (notification *Notification, err error) {
1244	// Return already received notification immediately
1245	if len(c.notifications) > 0 {
1246		notification := c.notifications[0]
1247		c.notifications = c.notifications[1:]
1248		return notification, nil
1249	}
1250
1251	err = c.waitForPreviousCancelQuery(ctx)
1252	if err != nil {
1253		return nil, err
1254	}
1255
1256	err = c.initContext(ctx)
1257	if err != nil {
1258		return nil, err
1259	}
1260	defer func() {
1261		err = c.termContext(err)
1262	}()
1263
1264	if err = c.lock(); err != nil {
1265		return nil, err
1266	}
1267	defer func() {
1268		if unlockErr := c.unlock(); unlockErr != nil && err == nil {
1269			err = unlockErr
1270		}
1271	}()
1272
1273	if err := c.ensureConnectionReadyForQuery(); err != nil {
1274		return nil, err
1275	}
1276
1277	for {
1278		msg, err := c.rxMsg()
1279		if err != nil {
1280			return nil, err
1281		}
1282
1283		err = c.processContextFreeMsg(msg)
1284		if err != nil {
1285			return nil, err
1286		}
1287
1288		if len(c.notifications) > 0 {
1289			notification := c.notifications[0]
1290			c.notifications = c.notifications[1:]
1291			return notification, nil
1292		}
1293	}
1294}
1295
1296func (c *Conn) IsAlive() bool {
1297	c.mux.Lock()
1298	defer c.mux.Unlock()
1299	return c.status >= connStatusIdle
1300}
1301
1302func (c *Conn) CauseOfDeath() error {
1303	c.mux.Lock()
1304	defer c.mux.Unlock()
1305	return c.causeOfDeath
1306}
1307
1308func (c *Conn) sendQuery(sql string, arguments ...interface{}) (err error) {
1309	if ps, present := c.preparedStatements[sql]; present {
1310		return c.sendPreparedQuery(ps, arguments...)
1311	}
1312	return c.sendSimpleQuery(sql, arguments...)
1313}
1314
1315func (c *Conn) sendSimpleQuery(sql string, args ...interface{}) error {
1316	if err := c.ensureConnectionReadyForQuery(); err != nil {
1317		return err
1318	}
1319
1320	if len(args) == 0 {
1321		buf := appendQuery(c.wbuf, sql)
1322
1323		_, err := c.conn.Write(buf)
1324		if err != nil {
1325			c.die(err)
1326			return err
1327		}
1328		c.pendingReadyForQueryCount++
1329
1330		return nil
1331	}
1332
1333	ps, err := c.Prepare("", sql)
1334	if err != nil {
1335		return err
1336	}
1337
1338	return c.sendPreparedQuery(ps, args...)
1339}
1340
1341func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}) (err error) {
1342	if len(ps.ParameterOIDs) != len(arguments) {
1343		return errors.Errorf("Prepared statement \"%v\" requires %d parameters, but %d were provided", ps.Name, len(ps.ParameterOIDs), len(arguments))
1344	}
1345
1346	if err := c.ensureConnectionReadyForQuery(); err != nil {
1347		return err
1348	}
1349
1350	resultFormatCodes := make([]int16, len(ps.FieldDescriptions))
1351	for i, fd := range ps.FieldDescriptions {
1352		resultFormatCodes[i] = fd.FormatCode
1353	}
1354	buf, err := appendBind(c.wbuf, "", ps.Name, c.ConnInfo, ps.ParameterOIDs, arguments, resultFormatCodes)
1355	if err != nil {
1356		return err
1357	}
1358
1359	buf = appendExecute(buf, "", 0)
1360	buf = appendSync(buf)
1361
1362	n, err := c.conn.Write(buf)
1363	if err != nil {
1364		if fatalWriteErr(n, err) {
1365			c.die(err)
1366		}
1367		return err
1368	}
1369	c.pendingReadyForQueryCount++
1370
1371	return nil
1372}
1373
// fatalWriteErr reports whether the result of a net.Conn.Write leaves the
// connection unusable. Any partial write is fatal because the protocol stream
// is now out of sync. Otherwise only a net.Error timeout, with nothing written,
// is survivable.
func fatalWriteErr(bytesWritten int, err error) bool {
	switch netErr, ok := err.(net.Error); {
	case bytesWritten > 0:
		return true
	case ok && netErr.Timeout():
		return false
	default:
		return true
	}
}
1384
1385// Exec executes sql. sql can be either a prepared statement name or an SQL string.
1386// arguments should be referenced positionally from the sql string as $1, $2, etc.
1387func (c *Conn) Exec(sql string, arguments ...interface{}) (commandTag CommandTag, err error) {
1388	return c.ExecEx(context.Background(), sql, nil, arguments...)
1389}
1390
1391// Processes messages that are not exclusive to one context such as
1392// authentication or query response. The response to these messages is the same
1393// regardless of when they occur. It also ignores messages that are only
1394// meaningful in a given context. These messages can occur due to a context
1395// deadline interrupting message processing. For example, an interrupted query
1396// may have left DataRow messages on the wire.
1397func (c *Conn) processContextFreeMsg(msg pgproto3.BackendMessage) (err error) {
1398	switch msg := msg.(type) {
1399	case *pgproto3.ErrorResponse:
1400		return c.rxErrorResponse(msg)
1401	case *pgproto3.NoticeResponse:
1402		c.rxNoticeResponse(msg)
1403	case *pgproto3.NotificationResponse:
1404		c.rxNotificationResponse(msg)
1405	case *pgproto3.ReadyForQuery:
1406		c.rxReadyForQuery(msg)
1407	case *pgproto3.ParameterStatus:
1408		c.rxParameterStatus(msg)
1409	}
1410
1411	return nil
1412}
1413
1414func (c *Conn) rxMsg() (pgproto3.BackendMessage, error) {
1415	if !c.IsAlive() {
1416		return nil, ErrDeadConn
1417	}
1418
1419	msg, err := c.frontend.Receive()
1420	if err != nil {
1421		if netErr, ok := err.(net.Error); !(ok && netErr.Timeout()) {
1422			c.die(err)
1423		}
1424		return nil, err
1425	}
1426
1427	c.lastActivityTime = time.Now()
1428
1429	// fmt.Printf("rxMsg: %#v\n", msg)
1430
1431	return msg, nil
1432}
1433
1434func (c *Conn) rxAuthenticationX(msg *pgproto3.Authentication) (err error) {
1435	switch msg.Type {
1436	case pgproto3.AuthTypeOk:
1437	case pgproto3.AuthTypeCleartextPassword:
1438		err = c.txPasswordMessage(c.config.Password)
1439	case pgproto3.AuthTypeMD5Password:
1440		digestedPassword := "md5" + hexMD5(hexMD5(c.config.Password+c.config.User)+string(msg.Salt[:]))
1441		err = c.txPasswordMessage(digestedPassword)
1442	default:
1443		err = errors.New("Received unknown authentication message")
1444	}
1445
1446	return
1447}
1448
// hexMD5 returns the lowercase hexadecimal MD5 digest of s. It is used to build
// PostgreSQL md5 password responses.
func hexMD5(s string) string {
	digest := md5.New()
	io.WriteString(digest, s) // hash.Hash writes never return an error
	sum := digest.Sum(nil)
	return hex.EncodeToString(sum)
}
1454
1455func (c *Conn) rxParameterStatus(msg *pgproto3.ParameterStatus) {
1456	c.RuntimeParams[msg.Name] = msg.Value
1457}
1458
1459func (c *Conn) rxErrorResponse(msg *pgproto3.ErrorResponse) PgError {
1460	err := PgError{
1461		Severity:         msg.Severity,
1462		Code:             msg.Code,
1463		Message:          msg.Message,
1464		Detail:           msg.Detail,
1465		Hint:             msg.Hint,
1466		Position:         msg.Position,
1467		InternalPosition: msg.InternalPosition,
1468		InternalQuery:    msg.InternalQuery,
1469		Where:            msg.Where,
1470		SchemaName:       msg.SchemaName,
1471		TableName:        msg.TableName,
1472		ColumnName:       msg.ColumnName,
1473		DataTypeName:     msg.DataTypeName,
1474		ConstraintName:   msg.ConstraintName,
1475		File:             msg.File,
1476		Line:             msg.Line,
1477		Routine:          msg.Routine,
1478	}
1479
1480	if err.Severity == "FATAL" {
1481		c.die(err)
1482	}
1483
1484	return err
1485}
1486
1487func (c *Conn) rxNoticeResponse(msg *pgproto3.NoticeResponse) {
1488	if c.onNotice == nil {
1489		return
1490	}
1491
1492	notice := &Notice{
1493		Severity:         msg.Severity,
1494		Code:             msg.Code,
1495		Message:          msg.Message,
1496		Detail:           msg.Detail,
1497		Hint:             msg.Hint,
1498		Position:         msg.Position,
1499		InternalPosition: msg.InternalPosition,
1500		InternalQuery:    msg.InternalQuery,
1501		Where:            msg.Where,
1502		SchemaName:       msg.SchemaName,
1503		TableName:        msg.TableName,
1504		ColumnName:       msg.ColumnName,
1505		DataTypeName:     msg.DataTypeName,
1506		ConstraintName:   msg.ConstraintName,
1507		File:             msg.File,
1508		Line:             msg.Line,
1509		Routine:          msg.Routine,
1510	}
1511
1512	c.onNotice(c, notice)
1513}
1514
1515func (c *Conn) rxBackendKeyData(msg *pgproto3.BackendKeyData) {
1516	c.pid = msg.ProcessID
1517	c.secretKey = msg.SecretKey
1518}
1519
1520func (c *Conn) rxReadyForQuery(msg *pgproto3.ReadyForQuery) {
1521	c.pendingReadyForQueryCount--
1522	c.txStatus = msg.TxStatus
1523}
1524
1525func (c *Conn) rxRowDescription(msg *pgproto3.RowDescription) []FieldDescription {
1526	fields := make([]FieldDescription, len(msg.Fields))
1527	for i := 0; i < len(fields); i++ {
1528		fields[i].Name = msg.Fields[i].Name
1529		fields[i].Table = pgtype.OID(msg.Fields[i].TableOID)
1530		fields[i].AttributeNumber = msg.Fields[i].TableAttributeNumber
1531		fields[i].DataType = pgtype.OID(msg.Fields[i].DataTypeOID)
1532		fields[i].DataTypeSize = msg.Fields[i].DataTypeSize
1533		fields[i].Modifier = msg.Fields[i].TypeModifier
1534		fields[i].FormatCode = msg.Fields[i].Format
1535	}
1536	return fields
1537}
1538
1539func (c *Conn) rxParameterDescription(msg *pgproto3.ParameterDescription) []pgtype.OID {
1540	parameters := make([]pgtype.OID, len(msg.ParameterOIDs))
1541	for i := 0; i < len(parameters); i++ {
1542		parameters[i] = pgtype.OID(msg.ParameterOIDs[i])
1543	}
1544	return parameters
1545}
1546
1547func (c *Conn) rxNotificationResponse(msg *pgproto3.NotificationResponse) {
1548	n := new(Notification)
1549	n.PID = msg.PID
1550	n.Channel = msg.Channel
1551	n.Payload = msg.Payload
1552	c.notifications = append(c.notifications, n)
1553}
1554
1555func (c *Conn) startTLS(tlsConfig *tls.Config) (err error) {
1556	err = binary.Write(c.conn, binary.BigEndian, []int32{8, 80877103})
1557	if err != nil {
1558		return
1559	}
1560
1561	response := make([]byte, 1)
1562	if _, err = io.ReadFull(c.conn, response); err != nil {
1563		return
1564	}
1565
1566	if response[0] != 'S' {
1567		return ErrTLSRefused
1568	}
1569
1570	c.conn = tls.Client(c.conn, tlsConfig)
1571
1572	return nil
1573}
1574
1575func (c *Conn) txPasswordMessage(password string) (err error) {
1576	buf := c.wbuf
1577	buf = append(buf, 'p')
1578	sp := len(buf)
1579	buf = pgio.AppendInt32(buf, -1)
1580	buf = append(buf, password...)
1581	buf = append(buf, 0)
1582	pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
1583
1584	_, err = c.conn.Write(buf)
1585
1586	return err
1587}
1588
1589func (c *Conn) die(err error) {
1590	c.mux.Lock()
1591	defer c.mux.Unlock()
1592
1593	if c.status == connStatusClosed {
1594		return
1595	}
1596
1597	c.status = connStatusClosed
1598	c.causeOfDeath = err
1599	c.conn.Close()
1600}
1601
1602func (c *Conn) lock() error {
1603	c.mux.Lock()
1604	defer c.mux.Unlock()
1605
1606	if c.status != connStatusIdle {
1607		return ErrConnBusy
1608	}
1609
1610	c.status = connStatusBusy
1611	return nil
1612}
1613
1614func (c *Conn) unlock() error {
1615	c.mux.Lock()
1616	defer c.mux.Unlock()
1617
1618	if c.status != connStatusBusy {
1619		return errors.New("unlock conn that is not busy")
1620	}
1621
1622	c.status = connStatusIdle
1623	return nil
1624}
1625
1626func (c *Conn) shouldLog(lvl int) bool {
1627	return c.logger != nil && c.logLevel >= lvl
1628}
1629
1630func (c *Conn) log(lvl LogLevel, msg string, data map[string]interface{}) {
1631	if data == nil {
1632		data = map[string]interface{}{}
1633	}
1634	if c.pid != 0 {
1635		data["pid"] = c.pid
1636	}
1637
1638	c.logger.Log(lvl, msg, data)
1639}
1640
1641// SetLogger replaces the current logger and returns the previous logger.
1642func (c *Conn) SetLogger(logger Logger) Logger {
1643	oldLogger := c.logger
1644	c.logger = logger
1645	return oldLogger
1646}
1647
1648// SetLogLevel replaces the current log level and returns the previous log
1649// level.
1650func (c *Conn) SetLogLevel(lvl int) (int, error) {
1651	oldLvl := c.logLevel
1652
1653	if lvl < LogLevelNone || lvl > LogLevelTrace {
1654		return oldLvl, ErrInvalidLogLevel
1655	}
1656
1657	c.logLevel = lvl
1658	return lvl, nil
1659}
1660
// quoteIdentifier returns s as a double-quoted SQL identifier. Embedded double
// quotes are doubled, so the result is safe to splice into statements such as
// LISTEN.
func quoteIdentifier(s string) string {
	escaped := strings.Replace(s, `"`, `""`, -1)
	return `"` + escaped + `"`
}
1664
1665func doCancel(c *Conn) error {
1666	network, address := c.config.networkAddress()
1667	cancelConn, err := c.config.Dial(network, address)
1668	if err != nil {
1669		return err
1670	}
1671	defer cancelConn.Close()
1672
1673	// If server doesn't process cancellation request in bounded time then abort.
1674	now := time.Now()
1675	err = cancelConn.SetDeadline(now.Add(15 * time.Second))
1676	if err != nil {
1677		return err
1678	}
1679
1680	buf := make([]byte, 16)
1681	binary.BigEndian.PutUint32(buf[0:4], 16)
1682	binary.BigEndian.PutUint32(buf[4:8], 80877102)
1683	binary.BigEndian.PutUint32(buf[8:12], uint32(c.pid))
1684	binary.BigEndian.PutUint32(buf[12:16], uint32(c.secretKey))
1685	_, err = cancelConn.Write(buf)
1686	if err != nil {
1687		return err
1688	}
1689
1690	_, err = cancelConn.Read(buf)
1691	if err != io.EOF {
1692		return errors.Errorf("Server failed to close connection after cancel query request: %v %v", err, buf)
1693	}
1694
1695	return nil
1696}
1697
1698// cancelQuery sends a cancel request to the PostgreSQL server. It returns an
1699// error if unable to deliver the cancel request, but lack of an error does not
1700// ensure that the query was canceled. As specified in the documentation, there
1701// is no way to be sure a query was canceled. See
1702// https://www.postgresql.org/docs/current/static/protocol-flow.html#AEN112861
1703func (c *Conn) cancelQuery() {
1704	if err := c.conn.SetDeadline(time.Now()); err != nil {
1705		c.Close() // Close connection if unable to set deadline
1706		return
1707	}
1708
1709	var cancelFn func(*Conn) error
1710	completeCh := make(chan struct{})
1711	c.mux.Lock()
1712	c.cancelQueryCompleted = completeCh
1713	c.mux.Unlock()
1714	if c.config.CustomCancel != nil {
1715		cancelFn = c.config.CustomCancel
1716	} else {
1717		cancelFn = doCancel
1718	}
1719
1720	go func() {
1721		defer close(completeCh)
1722		err := cancelFn(c)
1723		if err != nil {
1724			c.Close() // Something is very wrong. Terminate the connection.
1725		}
1726	}()
1727}
1728
1729func (c *Conn) Ping(ctx context.Context) error {
1730	_, err := c.ExecEx(ctx, ";", nil)
1731	return err
1732}
1733
1734func (c *Conn) ExecEx(ctx context.Context, sql string, options *QueryExOptions, arguments ...interface{}) (CommandTag, error) {
1735	c.lastStmtSent = false
1736	err := c.waitForPreviousCancelQuery(ctx)
1737	if err != nil {
1738		return "", err
1739	}
1740
1741	if err := c.lock(); err != nil {
1742		return "", err
1743	}
1744	defer c.unlock()
1745
1746	startTime := time.Now()
1747	c.lastActivityTime = startTime
1748
1749	commandTag, err := c.execEx(ctx, sql, options, arguments...)
1750	if err != nil {
1751		if c.shouldLog(LogLevelError) {
1752			c.log(LogLevelError, "Exec", map[string]interface{}{"sql": sql, "args": logQueryArgs(arguments), "err": err})
1753		}
1754		return commandTag, err
1755	}
1756
1757	if c.shouldLog(LogLevelInfo) {
1758		endTime := time.Now()
1759		c.log(LogLevelInfo, "Exec", map[string]interface{}{"sql": sql, "args": logQueryArgs(arguments), "time": endTime.Sub(startTime), "commandTag": commandTag})
1760	}
1761
1762	return commandTag, err
1763}
1764
1765func (c *Conn) execEx(ctx context.Context, sql string, options *QueryExOptions, arguments ...interface{}) (commandTag CommandTag, err error) {
1766	err = c.initContext(ctx)
1767	if err != nil {
1768		return "", err
1769	}
1770	defer func() {
1771		err = c.termContext(err)
1772	}()
1773
1774	if (options == nil && c.config.PreferSimpleProtocol) || (options != nil && options.SimpleProtocol) {
1775		c.lastStmtSent = true
1776		err = c.sanitizeAndSendSimpleQuery(sql, arguments...)
1777		if err != nil {
1778			return "", err
1779		}
1780	} else if options != nil && len(options.ParameterOIDs) > 0 {
1781		if err := c.ensureConnectionReadyForQuery(); err != nil {
1782			return "", err
1783		}
1784
1785		buf, err := c.buildOneRoundTripExec(c.wbuf, sql, options, arguments)
1786		if err != nil {
1787			return "", err
1788		}
1789
1790		buf = appendSync(buf)
1791
1792		c.lastStmtSent = true
1793		n, err := c.conn.Write(buf)
1794		if err != nil && fatalWriteErr(n, err) {
1795			c.die(err)
1796			return "", err
1797		}
1798		c.pendingReadyForQueryCount++
1799	} else {
1800		if len(arguments) > 0 {
1801			ps, ok := c.preparedStatements[sql]
1802			if !ok {
1803				var err error
1804				ps, err = c.prepareEx("", sql, nil)
1805				if err != nil {
1806					return "", err
1807				}
1808			}
1809
1810			c.lastStmtSent = true
1811			err = c.sendPreparedQuery(ps, arguments...)
1812			if err != nil {
1813				return "", err
1814			}
1815		} else {
1816			c.lastStmtSent = true
1817			if err = c.sendQuery(sql, arguments...); err != nil {
1818				return
1819			}
1820		}
1821	}
1822
1823	var softErr error
1824
1825	for {
1826		msg, err := c.rxMsg()
1827		if err != nil {
1828			return commandTag, err
1829		}
1830
1831		switch msg := msg.(type) {
1832		case *pgproto3.ReadyForQuery:
1833			c.rxReadyForQuery(msg)
1834			return commandTag, softErr
1835		case *pgproto3.CommandComplete:
1836			commandTag = CommandTag(msg.CommandTag)
1837		default:
1838			if e := c.processContextFreeMsg(msg); e != nil && softErr == nil {
1839				softErr = e
1840			}
1841		}
1842	}
1843}
1844
1845func (c *Conn) buildOneRoundTripExec(buf []byte, sql string, options *QueryExOptions, arguments []interface{}) ([]byte, error) {
1846	if len(arguments) != len(options.ParameterOIDs) {
1847		return nil, errors.Errorf("mismatched number of arguments (%d) and options.ParameterOIDs (%d)", len(arguments), len(options.ParameterOIDs))
1848	}
1849
1850	if len(options.ParameterOIDs) > 65535 {
1851		return nil, errors.Errorf("Number of QueryExOptions ParameterOIDs must be between 0 and 65535, received %d", len(options.ParameterOIDs))
1852	}
1853
1854	buf = appendParse(buf, "", sql, options.ParameterOIDs)
1855	buf, err := appendBind(buf, "", "", c.ConnInfo, options.ParameterOIDs, arguments, nil)
1856	if err != nil {
1857		return nil, err
1858	}
1859	buf = appendExecute(buf, "", 0)
1860
1861	return buf, nil
1862}
1863
1864func (c *Conn) initContext(ctx context.Context) error {
1865	if c.ctxInProgress {
1866		return errors.New("ctx already in progress")
1867	}
1868
1869	if ctx.Done() == nil {
1870		return nil
1871	}
1872
1873	select {
1874	case <-ctx.Done():
1875		return ctx.Err()
1876	default:
1877	}
1878
1879	c.ctxInProgress = true
1880
1881	go c.contextHandler(ctx)
1882
1883	return nil
1884}
1885
1886func (c *Conn) termContext(opErr error) error {
1887	if !c.ctxInProgress {
1888		return opErr
1889	}
1890
1891	var err error
1892
1893	select {
1894	case err = <-c.closedChan:
1895		if opErr == nil {
1896			err = nil
1897		}
1898	case c.doneChan <- struct{}{}:
1899		err = opErr
1900	}
1901
1902	c.ctxInProgress = false
1903	return err
1904}
1905
1906func (c *Conn) contextHandler(ctx context.Context) {
1907	select {
1908	case <-ctx.Done():
1909		c.cancelQuery()
1910		c.closedChan <- ctx.Err()
1911	case <-c.doneChan:
1912	}
1913}
1914
1915// WaitUntilReady will return when the connection is ready for another query
1916func (c *Conn) WaitUntilReady(ctx context.Context) error {
1917	err := c.waitForPreviousCancelQuery(ctx)
1918	if err != nil {
1919		return err
1920	}
1921	return c.ensureConnectionReadyForQuery()
1922}
1923
1924func (c *Conn) waitForPreviousCancelQuery(ctx context.Context) error {
1925	c.mux.Lock()
1926	completeCh := c.cancelQueryCompleted
1927	c.mux.Unlock()
1928	select {
1929	case <-completeCh:
1930		if err := c.conn.SetDeadline(time.Time{}); err != nil {
1931			c.Close() // Close connection if unable to disable deadline
1932			return err
1933		}
1934		return nil
1935	case <-ctx.Done():
1936		return ctx.Err()
1937	}
1938}
1939
1940func (c *Conn) ensureConnectionReadyForQuery() error {
1941	for c.pendingReadyForQueryCount > 0 {
1942		msg, err := c.rxMsg()
1943		if err != nil {
1944			return err
1945		}
1946
1947		switch msg := msg.(type) {
1948		case *pgproto3.ErrorResponse:
1949			pgErr := c.rxErrorResponse(msg)
1950			if pgErr.Severity == "FATAL" {
1951				return pgErr
1952			}
1953		default:
1954			err = c.processContextFreeMsg(msg)
1955			if err != nil {
1956				return err
1957			}
1958		}
1959	}
1960
1961	return nil
1962}
1963
1964func connInfoFromRows(rows *Rows, err error) (map[string]pgtype.OID, error) {
1965	if err != nil {
1966		return nil, err
1967	}
1968	defer rows.Close()
1969
1970	nameOIDs := make(map[string]pgtype.OID, 256)
1971	for rows.Next() {
1972		var oid pgtype.OID
1973		var name pgtype.Text
1974		if err = rows.Scan(&oid, &name); err != nil {
1975			return nil, err
1976		}
1977
1978		nameOIDs[name.String] = oid
1979	}
1980
1981	if err = rows.Err(); err != nil {
1982		return nil, err
1983	}
1984
1985	return nameOIDs, err
1986}
1987
1988// LastStmtSent returns true if the last call to Query(Ex)/Exec(Ex) attempted to
1989// send the statement over the wire. Each call to a Query(Ex)/Exec(Ex) resets
1990// the value to false initially until the statement has been sent. This does
1991// NOT mean that the statement was successful or even received, it just means
1992// that a write was attempted and therefore it could have been executed. Calls
1993// to prepare a statement are ignored, only when the prepared statement is
1994// attempted to be executed will this return true.
1995func (c *Conn) LastStmtSent() bool {
1996	return c.lastStmtSent
1997}
1998