1package pq
2
3import (
4	"bufio"
5	"context"
6	"crypto/md5"
7	"crypto/sha256"
8	"database/sql"
9	"database/sql/driver"
10	"encoding/binary"
11	"errors"
12	"fmt"
13	"io"
14	"net"
15	"os"
16	"os/user"
17	"path"
18	"path/filepath"
19	"strconv"
20	"strings"
21	"sync"
22	"time"
23	"unicode"
24
25	"github.com/lib/pq/oid"
26	"github.com/lib/pq/scram"
27)
28
// Common error types
var (
	// ErrNotSupported signals an unsupported command (not used in the part
	// of the file visible here).
	ErrNotSupported              = errors.New("pq: Unsupported command")
	// ErrInFailedTransaction is returned by Commit when the transaction was
	// already in the failed state and therefore had to be rolled back.
	ErrInFailedTransaction       = errors.New("pq: Could not complete operation in a failed transaction")
	// ErrSSLNotSupported is returned when the server does not answer 'S' to
	// the SSL request sent during connection setup.
	ErrSSLNotSupported           = errors.New("pq: SSL is not enabled on the server")
	// ErrSSLKeyHasWorldPermissions reports a private key file whose
	// permissions are too permissive (see the message text).
	ErrSSLKeyHasWorldPermissions = errors.New("pq: Private key file has group or world access. Permissions should be u=rw (0600) or less")
	// ErrCouldNotDetectUsername reports that no default user name could be
	// determined and none was supplied.
	ErrCouldNotDetectUsername    = errors.New("pq: Could not detect default username. Please provide one explicitly")

	// errUnexpectedReady is returned by simpleExec when ReadyForQuery arrives
	// before any result or error has been seen.
	errUnexpectedReady = errors.New("unexpected ReadyForQuery")
	// errNoRowsAffected and errNoLastInsertID are the errors returned by the
	// noRows result used for empty statements.
	errNoRowsAffected  = errors.New("no RowsAffected available after the empty statement")
	errNoLastInsertID  = errors.New("no LastInsertId available after the empty statement")
)
41
// Compile time validation that our types implement the expected interfaces.
// The blank identifier keeps the assertion from producing a usable variable.
var (
	_ driver.Driver = Driver{}
)
46
47// Driver is the Postgres database driver.
48type Driver struct{}
49
50// Open opens a new connection to the database. name is a connection string.
51// Most users should only use it through database/sql package from the standard
52// library.
53func (d Driver) Open(name string) (driver.Conn, error) {
54	return Open(name)
55}
56
57func init() {
58	sql.Register("postgres", &Driver{})
59}
60
// parameterStatus caches session parameters of a connection; it is stored in
// conn.parameterStatus.
type parameterStatus struct {
	// server version in the same format as server_version_num, or 0 if
	// unavailable
	serverVersion int

	// the current location based on the TimeZone value of the session, if
	// available
	currentLocation *time.Location
}
70
71type transactionStatus byte
72
73const (
74	txnStatusIdle                transactionStatus = 'I'
75	txnStatusIdleInTransaction   transactionStatus = 'T'
76	txnStatusInFailedTransaction transactionStatus = 'E'
77)
78
79func (s transactionStatus) String() string {
80	switch s {
81	case txnStatusIdle:
82		return "idle"
83	case txnStatusIdleInTransaction:
84		return "idle in transaction"
85	case txnStatusInFailedTransaction:
86		return "in a failed transaction"
87	default:
88		errorf("unknown transactionStatus %d", s)
89	}
90
91	panic("not reached")
92}
93
// Dialer is the dialer interface. It can be used to obtain more control over
// how pq creates network connections.
type Dialer interface {
	// Dial is used when no connect_timeout applies and the dialer does not
	// implement DialerContext.
	Dial(network, address string) (net.Conn, error)
	// DialTimeout is used when a connect_timeout applies and the dialer does
	// not implement DialerContext.
	DialTimeout(network, address string, timeout time.Duration) (net.Conn, error)
}
100
// DialerContext is the context-aware dialer interface. When a Dialer also
// implements it, dial prefers DialContext over Dial and DialTimeout.
type DialerContext interface {
	DialContext(ctx context.Context, network, address string) (net.Conn, error)
}
105
// defaultDialer is the Dialer used by Open; it wraps a zero-configured
// net.Dialer and implements both Dialer and DialerContext.
type defaultDialer struct {
	d net.Dialer
}

// Dial connects to address on network without any timeout of its own.
func (d defaultDialer) Dial(network, address string) (net.Conn, error) {
	return d.d.Dial(network, address)
}

// DialTimeout connects to address on network, giving up after timeout.
func (d defaultDialer) DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) {
	timeoutCtx, stop := context.WithTimeout(context.Background(), timeout)
	defer stop()
	return d.d.DialContext(timeoutCtx, network, address)
}

// DialContext connects to address on network, honouring ctx.
func (d defaultDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
	return d.d.DialContext(ctx, network, address)
}
121
// conn is a single connection to a Postgres server.
type conn struct {
	// c is the underlying network connection; buf is the buffered reader
	// wrapped around it that recvMessage reads from.
	c         net.Conn
	buf       *bufio.Reader
	// namei is the counter gname uses to generate statement names.
	namei     int
	// scratch is reused for outgoing message headers (writeBuf) and for
	// incoming messages that fit into it (recvMessage).
	scratch   [512]byte
	// txnStatus is the last known transaction state.
	txnStatus transactionStatus
	// txnFinish, when non-nil, is invoked by closeTxn at the end of Commit
	// and Rollback.
	txnFinish func()

	// Save connection arguments to use during CancelRequest.
	dialer Dialer
	opts   values

	// Cancellation key data for use with CancelRequest messages.
	processID int
	secretKey int

	parameterStatus parameterStatus

	// A message stashed by saveMessage; recvMessage returns it on its next
	// call. A zero saveMessageType means nothing is stashed.
	saveMessageType   byte
	saveMessageBuffer []byte

	// If an error is set, this connection is bad and all public-facing
	// functions should return the appropriate error by calling get()
	// (ErrBadConn) or getForNext().
	err syncErr

	// If set, this connection should never use the binary format when
	// receiving query results from prepared statements.  Only provided for
	// debugging.
	disablePreparedBinaryResult bool

	// Whether to always send []byte parameters over as binary.  Enables single
	// round-trip mode for non-prepared Query calls.
	binaryParameters bool

	// If true this connection is in the middle of a COPY
	inCopy bool

	// If not nil, notices will be synchronously sent here
	noticeHandler func(*Error)

	// If not nil, notifications will be synchronously sent here
	notificationHandler func(*Notification)

	// GSSAPI context
	gss GSS
}
169
// syncErr is a mutex-guarded, write-once error slot.
type syncErr struct {
	err error
	sync.Mutex
}

// get reports driver.ErrBadConn once an error has been recorded, and nil
// otherwise. The recorded error itself is not exposed here.
func (e *syncErr) get() error {
	e.Lock()
	bad := e.err != nil
	e.Unlock()

	if bad {
		return driver.ErrBadConn
	}
	return nil
}

// getForNext returns the recorded error unchanged (nil if none). Currently
// only used by rows.Next.
func (e *syncErr) getForNext() error {
	e.Lock()
	recorded := e.err
	e.Unlock()
	return recorded
}

// set records err unless an earlier error is already stored. Passing a nil
// error is a programming mistake and panics.
func (e *syncErr) set(err error) {
	if err == nil {
		panic("attempt to set nil err")
	}

	e.Lock()
	if e.err == nil {
		e.err = err
	}
	e.Unlock()
}
203
204// Handle driver-side settings in parsed connection string.
205func (cn *conn) handleDriverSettings(o values) (err error) {
206	boolSetting := func(key string, val *bool) error {
207		if value, ok := o[key]; ok {
208			if value == "yes" {
209				*val = true
210			} else if value == "no" {
211				*val = false
212			} else {
213				return fmt.Errorf("unrecognized value %q for %s", value, key)
214			}
215		}
216		return nil
217	}
218
219	err = boolSetting("disable_prepared_binary_result", &cn.disablePreparedBinaryResult)
220	if err != nil {
221		return err
222	}
223	return boolSetting("binary_parameters", &cn.binaryParameters)
224}
225
226func (cn *conn) handlePgpass(o values) {
227	// if a password was supplied, do not process .pgpass
228	if _, ok := o["password"]; ok {
229		return
230	}
231	filename := os.Getenv("PGPASSFILE")
232	if filename == "" {
233		// XXX this code doesn't work on Windows where the default filename is
234		// XXX %APPDATA%\postgresql\pgpass.conf
235		// Prefer $HOME over user.Current due to glibc bug: golang.org/issue/13470
236		userHome := os.Getenv("HOME")
237		if userHome == "" {
238			user, err := user.Current()
239			if err != nil {
240				return
241			}
242			userHome = user.HomeDir
243		}
244		filename = filepath.Join(userHome, ".pgpass")
245	}
246	fileinfo, err := os.Stat(filename)
247	if err != nil {
248		return
249	}
250	mode := fileinfo.Mode()
251	if mode&(0x77) != 0 {
252		// XXX should warn about incorrect .pgpass permissions as psql does
253		return
254	}
255	file, err := os.Open(filename)
256	if err != nil {
257		return
258	}
259	defer file.Close()
260	scanner := bufio.NewScanner(io.Reader(file))
261	hostname := o["host"]
262	ntw, _ := network(o)
263	port := o["port"]
264	db := o["dbname"]
265	username := o["user"]
266	// From: https://github.com/tg/pgpass/blob/master/reader.go
267	getFields := func(s string) []string {
268		fs := make([]string, 0, 5)
269		f := make([]rune, 0, len(s))
270
271		var esc bool
272		for _, c := range s {
273			switch {
274			case esc:
275				f = append(f, c)
276				esc = false
277			case c == '\\':
278				esc = true
279			case c == ':':
280				fs = append(fs, string(f))
281				f = f[:0]
282			default:
283				f = append(f, c)
284			}
285		}
286		return append(fs, string(f))
287	}
288	for scanner.Scan() {
289		line := scanner.Text()
290		if len(line) == 0 || line[0] == '#' {
291			continue
292		}
293		split := getFields(line)
294		if len(split) != 5 {
295			continue
296		}
297		if (split[0] == "*" || split[0] == hostname || (split[0] == "localhost" && (hostname == "" || ntw == "unix"))) && (split[1] == "*" || split[1] == port) && (split[2] == "*" || split[2] == db) && (split[3] == "*" || split[3] == username) {
298			o["password"] = split[4]
299			return
300		}
301	}
302}
303
304func (cn *conn) writeBuf(b byte) *writeBuf {
305	cn.scratch[0] = b
306	return &writeBuf{
307		buf: cn.scratch[:5],
308		pos: 1,
309	}
310}
311
312// Open opens a new connection to the database. dsn is a connection string.
313// Most users should only use it through database/sql package from the standard
314// library.
315func Open(dsn string) (_ driver.Conn, err error) {
316	return DialOpen(defaultDialer{}, dsn)
317}
318
319// DialOpen opens a new connection to the database using a dialer.
320func DialOpen(d Dialer, dsn string) (_ driver.Conn, err error) {
321	c, err := NewConnector(dsn)
322	if err != nil {
323		return nil, err
324	}
325	c.dialer = d
326	return c.open(context.Background())
327}
328
329func (c *Connector) open(ctx context.Context) (cn *conn, err error) {
330	// Handle any panics during connection initialization.  Note that we
331	// specifically do *not* want to use errRecover(), as that would turn any
332	// connection errors into ErrBadConns, hiding the real error message from
333	// the user.
334	defer errRecoverNoErrBadConn(&err)
335
336	// Create a new values map (copy). This makes it so maps in different
337	// connections do not reference the same underlying data structure, so it
338	// is safe for multiple connections to concurrently write to their opts.
339	o := make(values)
340	for k, v := range c.opts {
341		o[k] = v
342	}
343
344	cn = &conn{
345		opts:   o,
346		dialer: c.dialer,
347	}
348	err = cn.handleDriverSettings(o)
349	if err != nil {
350		return nil, err
351	}
352	cn.handlePgpass(o)
353
354	cn.c, err = dial(ctx, c.dialer, o)
355	if err != nil {
356		return nil, err
357	}
358
359	err = cn.ssl(o)
360	if err != nil {
361		if cn.c != nil {
362			cn.c.Close()
363		}
364		return nil, err
365	}
366
367	// cn.startup panics on error. Make sure we don't leak cn.c.
368	panicking := true
369	defer func() {
370		if panicking {
371			cn.c.Close()
372		}
373	}()
374
375	cn.buf = bufio.NewReader(cn.c)
376	cn.startup(o)
377
378	// reset the deadline, in case one was set (see dial)
379	if timeout, ok := o["connect_timeout"]; ok && timeout != "0" {
380		err = cn.c.SetDeadline(time.Time{})
381	}
382	panicking = false
383	return cn, err
384}
385
386func dial(ctx context.Context, d Dialer, o values) (net.Conn, error) {
387	network, address := network(o)
388
389	// Zero or not specified means wait indefinitely.
390	if timeout, ok := o["connect_timeout"]; ok && timeout != "0" {
391		seconds, err := strconv.ParseInt(timeout, 10, 0)
392		if err != nil {
393			return nil, fmt.Errorf("invalid value for parameter connect_timeout: %s", err)
394		}
395		duration := time.Duration(seconds) * time.Second
396
397		// connect_timeout should apply to the entire connection establishment
398		// procedure, so we both use a timeout for the TCP connection
399		// establishment and set a deadline for doing the initial handshake.
400		// The deadline is then reset after startup() is done.
401		deadline := time.Now().Add(duration)
402		var conn net.Conn
403		if dctx, ok := d.(DialerContext); ok {
404			ctx, cancel := context.WithTimeout(ctx, duration)
405			defer cancel()
406			conn, err = dctx.DialContext(ctx, network, address)
407		} else {
408			conn, err = d.DialTimeout(network, address, duration)
409		}
410		if err != nil {
411			return nil, err
412		}
413		err = conn.SetDeadline(deadline)
414		return conn, err
415	}
416	if dctx, ok := d.(DialerContext); ok {
417		return dctx.DialContext(ctx, network, address)
418	}
419	return d.Dial(network, address)
420}
421
422func network(o values) (string, string) {
423	host := o["host"]
424
425	if strings.HasPrefix(host, "/") {
426		sockPath := path.Join(host, ".s.PGSQL."+o["port"])
427		return "unix", sockPath
428	}
429
430	return "tcp", net.JoinHostPort(host, o["port"])
431}
432
// values holds connection options keyed by option name (for example "host",
// "port", "password"), as filled in by parseOpts.
type values map[string]string
434
// scanner is a rune-by-rune reader over a libpq-style option string.
type scanner struct {
	s []rune
	i int
}

// newScanner creates a scanner positioned at the start of s.
func newScanner(s string) *scanner {
	return &scanner{s: []rune(s)}
}

// Next consumes and returns the next rune.
// At the end of the text it returns 0, false.
func (s *scanner) Next() (rune, bool) {
	if s.i < len(s.s) {
		r := s.s[s.i]
		s.i++
		return r, true
	}
	return 0, false
}

// SkipSpaces consumes runes up to and including the next non-whitespace rune
// and returns that rune.
// At the end of the text it returns 0, false.
func (s *scanner) SkipSpaces() (rune, bool) {
	for {
		r, ok := s.Next()
		if !ok || !unicode.IsSpace(r) {
			return r, ok
		}
	}
}
466
467// parseOpts parses the options from name and adds them to the values.
468//
469// The parsing code is based on conninfo_parse from libpq's fe-connect.c
470func parseOpts(name string, o values) error {
471	s := newScanner(name)
472
473	for {
474		var (
475			keyRunes, valRunes []rune
476			r                  rune
477			ok                 bool
478		)
479
480		if r, ok = s.SkipSpaces(); !ok {
481			break
482		}
483
484		// Scan the key
485		for !unicode.IsSpace(r) && r != '=' {
486			keyRunes = append(keyRunes, r)
487			if r, ok = s.Next(); !ok {
488				break
489			}
490		}
491
492		// Skip any whitespace if we're not at the = yet
493		if r != '=' {
494			r, ok = s.SkipSpaces()
495		}
496
497		// The current character should be =
498		if r != '=' || !ok {
499			return fmt.Errorf(`missing "=" after %q in connection info string"`, string(keyRunes))
500		}
501
502		// Skip any whitespace after the =
503		if r, ok = s.SkipSpaces(); !ok {
504			// If we reach the end here, the last value is just an empty string as per libpq.
505			o[string(keyRunes)] = ""
506			break
507		}
508
509		if r != '\'' {
510			for !unicode.IsSpace(r) {
511				if r == '\\' {
512					if r, ok = s.Next(); !ok {
513						return fmt.Errorf(`missing character after backslash`)
514					}
515				}
516				valRunes = append(valRunes, r)
517
518				if r, ok = s.Next(); !ok {
519					break
520				}
521			}
522		} else {
523		quote:
524			for {
525				if r, ok = s.Next(); !ok {
526					return fmt.Errorf(`unterminated quoted string literal in connection string`)
527				}
528				switch r {
529				case '\'':
530					break quote
531				case '\\':
532					r, _ = s.Next()
533					fallthrough
534				default:
535					valRunes = append(valRunes, r)
536				}
537			}
538		}
539
540		o[string(keyRunes)] = string(valRunes)
541	}
542
543	return nil
544}
545
546func (cn *conn) isInTransaction() bool {
547	return cn.txnStatus == txnStatusIdleInTransaction ||
548		cn.txnStatus == txnStatusInFailedTransaction
549}
550
551func (cn *conn) checkIsInTransaction(intxn bool) {
552	if cn.isInTransaction() != intxn {
553		cn.err.set(driver.ErrBadConn)
554		errorf("unexpected transaction status %v", cn.txnStatus)
555	}
556}
557
558func (cn *conn) Begin() (_ driver.Tx, err error) {
559	return cn.begin("")
560}
561
// begin starts a transaction by executing "BEGIN" followed by mode and, on
// success, returns the connection itself as the driver.Tx.
//
// It fails immediately with the connection's bad-connection error if one is
// set. The connection is marked bad when the server's command tag is not
// "BEGIN" or when the transaction status afterwards is not "idle in
// transaction".
func (cn *conn) begin(mode string) (_ driver.Tx, err error) {
	if err := cn.err.get(); err != nil {
		return nil, err
	}
	// Errors below this point may be raised as panics; errRecover stores them
	// in err.
	defer cn.errRecover(&err)

	// A transaction must not already be open.
	cn.checkIsInTransaction(false)
	_, commandTag, err := cn.simpleExec("BEGIN" + mode)
	if err != nil {
		return nil, err
	}
	if commandTag != "BEGIN" {
		cn.err.set(driver.ErrBadConn)
		return nil, fmt.Errorf("unexpected command tag %s", commandTag)
	}
	if cn.txnStatus != txnStatusIdleInTransaction {
		cn.err.set(driver.ErrBadConn)
		return nil, fmt.Errorf("unexpected transaction status %v", cn.txnStatus)
	}
	return cn, nil
}
583
584func (cn *conn) closeTxn() {
585	if finish := cn.txnFinish; finish != nil {
586		finish()
587	}
588}
589
// Commit commits the current transaction.
//
// closeTxn always runs when Commit returns. If the transaction is in the
// failed state it is rolled back instead and ErrInFailedTransaction is
// returned. The connection is marked bad when COMMIT fails while still
// inside a transaction or when the server answers with an unexpected command
// tag.
func (cn *conn) Commit() (err error) {
	defer cn.closeTxn()
	if err := cn.err.get(); err != nil {
		return err
	}
	// Errors below this point may be raised as panics; errRecover stores them
	// in err.
	defer cn.errRecover(&err)

	cn.checkIsInTransaction(true)
	// We don't want the client to think that everything is okay if it tries
	// to commit a failed transaction.  However, no matter what we return,
	// database/sql will release this connection back into the free connection
	// pool so we have to abort the current transaction here.  Note that you
	// would get the same behaviour if you issued a COMMIT in a failed
	// transaction, so it's also the least surprising thing to do here.
	if cn.txnStatus == txnStatusInFailedTransaction {
		if err := cn.rollback(); err != nil {
			return err
		}
		return ErrInFailedTransaction
	}

	_, commandTag, err := cn.simpleExec("COMMIT")
	if err != nil {
		// Still inside a transaction after a failed COMMIT: the connection
		// cannot be reused.
		if cn.isInTransaction() {
			cn.err.set(driver.ErrBadConn)
		}
		return err
	}
	if commandTag != "COMMIT" {
		cn.err.set(driver.ErrBadConn)
		return fmt.Errorf("unexpected command tag %s", commandTag)
	}
	// After a successful COMMIT no transaction may be open.
	cn.checkIsInTransaction(false)
	return nil
}
625
// Rollback aborts the current transaction. closeTxn always runs when
// Rollback returns, and a connection already marked bad is rejected before
// anything is sent.
func (cn *conn) Rollback() (err error) {
	defer cn.closeTxn()
	if err := cn.err.get(); err != nil {
		return err
	}
	// Errors raised as panics by rollback are stored in err by errRecover.
	defer cn.errRecover(&err)
	return cn.rollback()
}
634
// rollback executes ROLLBACK on a connection that must currently be inside a
// transaction. The connection is marked bad if ROLLBACK fails while still in
// a transaction. An unexpected command tag yields an error but, unlike in
// Commit, does not mark the connection bad.
func (cn *conn) rollback() (err error) {
	cn.checkIsInTransaction(true)
	_, commandTag, err := cn.simpleExec("ROLLBACK")
	if err != nil {
		if cn.isInTransaction() {
			cn.err.set(driver.ErrBadConn)
		}
		return err
	}
	if commandTag != "ROLLBACK" {
		return fmt.Errorf("unexpected command tag %s", commandTag)
	}
	// After a successful ROLLBACK no transaction may be open.
	cn.checkIsInTransaction(false)
	return nil
}
650
651func (cn *conn) gname() string {
652	cn.namei++
653	return strconv.FormatInt(int64(cn.namei), 10)
654}
655
// simpleExec sends q as a single 'Q' message and reads responses until
// ReadyForQuery ('Z') arrives.
//
// It returns the result and command tag parsed from the last 'C' message.
// An 'E' message sets err; an 'I' message yields emptyRows as the result;
// row descriptions and data rows are discarded. If ReadyForQuery arrives
// with neither a result nor an error, errUnexpectedReady is returned. Any
// other message type marks the connection bad and is reported via errorf.
// Send and receive failures surface as panics (see send and recv1).
func (cn *conn) simpleExec(q string) (res driver.Result, commandTag string, err error) {
	b := cn.writeBuf('Q')
	b.string(q)
	cn.send(b)

	for {
		t, r := cn.recv1()
		switch t {
		case 'C':
			res, commandTag = cn.parseComplete(r.string())
		case 'Z':
			cn.processReadyForQuery(r)
			if res == nil && err == nil {
				err = errUnexpectedReady
			}
			// done
			return
		case 'E':
			err = parseError(r)
		case 'I':
			res = emptyRows
		case 'T', 'D':
			// ignore any results
		default:
			cn.err.set(driver.ErrBadConn)
			errorf("unknown response for simple query: %q", t)
		}
	}
}
685
// simpleQuery sends q as a single 'Q' message and returns a rows object for
// its result.
//
// It returns as soon as the first DataRow is seen (the row is stashed with
// saveMessage so that a later receive picks it up), when a CommandComplete
// follows a row description, or when ReadyForQuery arrives. An error
// response clears the result and sets err. Unexpected messages mark the
// connection bad and are reported via errorf; panics are converted into err
// by errRecover.
func (cn *conn) simpleQuery(q string) (res *rows, err error) {
	defer cn.errRecover(&err)

	b := cn.writeBuf('Q')
	b.string(q)
	cn.send(b)

	for {
		t, r := cn.recv1()
		switch t {
		case 'C', 'I':
			// We allow queries which don't return any results through Query as
			// well as Exec.  We still have to give database/sql a rows object
			// the user can close, though, to avoid connections from being
			// leaked.  A "rows" with done=true works fine for that purpose.
			if err != nil {
				cn.err.set(driver.ErrBadConn)
				errorf("unexpected message %q in simple query execution", t)
			}
			if res == nil {
				res = &rows{
					cn: cn,
				}
			}
			// Set the result and tag to the last command complete if there wasn't a
			// query already run. Although queries usually return from here and cede
			// control to Next, a query with zero results does not.
			if t == 'C' {
				res.result, res.tag = cn.parseComplete(r.string())
				if res.colNames != nil {
					return
				}
			}
			res.done = true
		case 'Z':
			cn.processReadyForQuery(r)
			// done
			return
		case 'E':
			res = nil
			err = parseError(r)
		case 'D':
			// A DataRow is only valid after a row description created res.
			if res == nil {
				cn.err.set(driver.ErrBadConn)
				errorf("unexpected DataRow in simple query execution")
			}
			// the query didn't fail; kick off to Next
			cn.saveMessage(t, r)
			return
		case 'T':
			// res might be non-nil here if we received a previous
			// CommandComplete, but that's fine; just overwrite it
			res = &rows{cn: cn}
			res.rowsHeader = parsePortalRowDescribe(r)

			// To work around a bug in QueryRow in Go 1.2 and earlier, wait
			// until the first DataRow has been received.
		default:
			cn.err.set(driver.ErrBadConn)
			errorf("unknown response for simple query: %q", t)
		}
	}
}
749
750type noRows struct{}
751
752var emptyRows noRows
753
754var _ driver.Result = noRows{}
755
756func (noRows) LastInsertId() (int64, error) {
757	return 0, errNoLastInsertID
758}
759
760func (noRows) RowsAffected() (int64, error) {
761	return 0, errNoRowsAffected
762}
763
764// Decides which column formats to use for a prepared statement.  The input is
765// an array of type oids, one element per result column.
766func decideColumnFormats(colTyps []fieldDesc, forceText bool) (colFmts []format, colFmtData []byte) {
767	if len(colTyps) == 0 {
768		return nil, colFmtDataAllText
769	}
770
771	colFmts = make([]format, len(colTyps))
772	if forceText {
773		return colFmts, colFmtDataAllText
774	}
775
776	allBinary := true
777	allText := true
778	for i, t := range colTyps {
779		switch t.OID {
780		// This is the list of types to use binary mode for when receiving them
781		// through a prepared statement.  If a type appears in this list, it
782		// must also be implemented in binaryDecode in encode.go.
783		case oid.T_bytea:
784			fallthrough
785		case oid.T_int8:
786			fallthrough
787		case oid.T_int4:
788			fallthrough
789		case oid.T_int2:
790			fallthrough
791		case oid.T_uuid:
792			colFmts[i] = formatBinary
793			allText = false
794
795		default:
796			allBinary = false
797		}
798	}
799
800	if allBinary {
801		return colFmts, colFmtDataAllBinary
802	} else if allText {
803		return colFmts, colFmtDataAllText
804	} else {
805		colFmtData = make([]byte, 2+len(colFmts)*2)
806		binary.BigEndian.PutUint16(colFmtData, uint16(len(colFmts)))
807		for i, v := range colFmts {
808			binary.BigEndian.PutUint16(colFmtData[2+i*2:], uint16(v))
809		}
810		return colFmts, colFmtData
811	}
812}
813
// prepareTo prepares query q under the statement name stmtName and returns
// the resulting stmt with its parameter types, column names, column types
// and chosen column formats filled in.
//
// Three messages are written into one buffer and sent together: 'P' (name,
// query text, zero parameter types), 'D' with 'S' and the statement name,
// and a final 'S'. The responses are then consumed in order by
// readParseResponse, readStatementDescribeResponse and readReadyForQuery.
// Failures surface as panics.
func (cn *conn) prepareTo(q, stmtName string) *stmt {
	st := &stmt{cn: cn, name: stmtName}

	b := cn.writeBuf('P')
	b.string(st.name)
	b.string(q)
	b.int16(0)

	b.next('D')
	b.byte('S')
	b.string(st.name)

	b.next('S')
	cn.send(b)

	cn.readParseResponse()
	st.paramTyps, st.colNames, st.colTyps = cn.readStatementDescribeResponse()
	st.colFmts, st.colFmtData = decideColumnFormats(st.colTyps, cn.disablePreparedBinaryResult)
	cn.readReadyForQuery()
	return st
}
835
// Prepare creates a prepared statement for q.
//
// A query whose first four bytes equal "COPY" (case-insensitively) is handed
// to prepareCopyIn, and the connection is flagged as being in a COPY when
// that succeeds. Every other query is prepared under a freshly generated
// statement name. A connection already marked bad is rejected up front.
func (cn *conn) Prepare(q string) (_ driver.Stmt, err error) {
	if err := cn.err.get(); err != nil {
		return nil, err
	}
	// prepareTo reports failures by panicking; errRecover stores them in err.
	defer cn.errRecover(&err)

	if len(q) >= 4 && strings.EqualFold(q[:4], "COPY") {
		s, err := cn.prepareCopyIn(q)
		if err == nil {
			cn.inCopy = true
		}
		return s, err
	}
	return cn.prepareTo(q, cn.gname()), nil
}
851
// Close sends a payload-less 'X' message to the server and closes the
// network connection. The network connection is closed even if sending
// fails; the send error takes precedence over the close error.
func (cn *conn) Close() (err error) {
	// Skip cn.bad return here because we always want to close a connection.
	defer cn.errRecover(&err)

	// Ensure that cn.c.Close is always run. Since error handling is done with
	// panics and cn.errRecover, the Close must be in a defer.
	defer func() {
		cerr := cn.c.Close()
		if err == nil {
			err = cerr
		}
	}()

	// Don't go through send(); ListenerConn relies on us not scribbling on the
	// scratch buffer of this connection.
	return cn.sendSimpleMessage('X')
}
869
870// Implement the "Queryer" interface
871func (cn *conn) Query(query string, args []driver.Value) (driver.Rows, error) {
872	return cn.query(query, args)
873}
874
// query runs query with args and returns its rows.
//
// It refuses to run on a connection marked bad or in the middle of a COPY.
// Without arguments the simple query path is used. With arguments it either
// uses the single round-trip binary-parameter mode (when binaryParameters is
// set) or prepares the unnamed statement and executes it. Failures on the
// latter two paths are raised as panics and stored in err by errRecover.
func (cn *conn) query(query string, args []driver.Value) (_ *rows, err error) {
	if err := cn.err.get(); err != nil {
		return nil, err
	}
	if cn.inCopy {
		return nil, errCopyInProgress
	}
	defer cn.errRecover(&err)

	// Check to see if we can use the "simpleQuery" interface, which is
	// *much* faster than going through prepare/exec
	if len(args) == 0 {
		return cn.simpleQuery(query)
	}

	if cn.binaryParameters {
		cn.sendBinaryModeQuery(query, args)

		// Consume the responses in the order the requests were sent.
		cn.readParseResponse()
		cn.readBindResponse()
		rows := &rows{cn: cn}
		rows.rowsHeader = cn.readPortalDescribeResponse()
		cn.postExecuteWorkaround()
		return rows, nil
	}
	st := cn.prepareTo(query, "")
	st.exec(args)
	return &rows{
		cn:         cn,
		rowsHeader: st.rowsHeader,
	}, nil
}
907
// Exec implements the optional "Execer" interface for one-shot queries.
//
// Without arguments the simple exec path is used. With arguments it either
// uses the single round-trip binary-parameter mode (when binaryParameters is
// set) or prepares the unnamed statement and executes it. A connection
// already marked bad is rejected up front.
func (cn *conn) Exec(query string, args []driver.Value) (res driver.Result, err error) {
	if err := cn.err.get(); err != nil {
		return nil, err
	}
	defer cn.errRecover(&err)

	// Check to see if we can use the "simpleExec" interface, which is
	// *much* faster than going through prepare/exec
	if len(args) == 0 {
		// ignore commandTag, our caller doesn't care
		r, _, err := cn.simpleExec(query)
		return r, err
	}

	if cn.binaryParameters {
		cn.sendBinaryModeQuery(query, args)

		// Consume the responses in the order the requests were sent.
		cn.readParseResponse()
		cn.readBindResponse()
		cn.readPortalDescribeResponse()
		cn.postExecuteWorkaround()
		res, _, err = cn.readExecuteResponse("Execute")
		return res, err
	}
	// Use the unnamed statement to defer planning until bind
	// time, or else value-based selectivity estimates cannot be
	// used.
	st := cn.prepareTo(query, "")
	r, err := st.Exec(args)
	if err != nil {
		// Re-raise so that the deferred errRecover processes the error like
		// any other failure on this connection.
		panic(err)
	}
	return r, err
}
943
// safeRetryError wraps a write error for which no bytes were sent (see
// send).
type safeRetryError struct {
	Err error
}

// Error returns the message of the wrapped error.
func (e *safeRetryError) Error() string {
	wrapped := e.Err
	return wrapped.Error()
}
951
952func (cn *conn) send(m *writeBuf) {
953	n, err := cn.c.Write(m.wrap())
954	if err != nil {
955		if n == 0 {
956			err = &safeRetryError{Err: err}
957		}
958		panic(err)
959	}
960}
961
962func (cn *conn) sendStartupPacket(m *writeBuf) error {
963	_, err := cn.c.Write((m.wrap())[1:])
964	return err
965}
966
967// Send a message of type typ to the server on the other end of cn.  The
968// message should have no payload.  This method does not use the scratch
969// buffer.
970func (cn *conn) sendSimpleMessage(typ byte) (err error) {
971	_, err = cn.c.Write([]byte{typ, '\x00', '\x00', '\x00', '\x04'})
972	return err
973}
974
975// saveMessage memorizes a message and its buffer in the conn struct.
976// recvMessage will then return these values on the next call to it.  This
977// method is useful in cases where you have to see what the next message is
978// going to be (e.g. to see whether it's an error or not) but you can't handle
979// the message yourself.
980func (cn *conn) saveMessage(typ byte, buf *readBuf) {
981	if cn.saveMessageType != 0 {
982		cn.err.set(driver.ErrBadConn)
983		errorf("unexpected saveMessageType %d", cn.saveMessageType)
984	}
985	cn.saveMessageType = typ
986	cn.saveMessageBuffer = *buf
987}
988
989// recvMessage receives any message from the backend, or returns an error if
990// a problem occurred while reading the message.
991func (cn *conn) recvMessage(r *readBuf) (byte, error) {
992	// workaround for a QueryRow bug, see exec
993	if cn.saveMessageType != 0 {
994		t := cn.saveMessageType
995		*r = cn.saveMessageBuffer
996		cn.saveMessageType = 0
997		cn.saveMessageBuffer = nil
998		return t, nil
999	}
1000
1001	x := cn.scratch[:5]
1002	_, err := io.ReadFull(cn.buf, x)
1003	if err != nil {
1004		return 0, err
1005	}
1006
1007	// read the type and length of the message that follows
1008	t := x[0]
1009	n := int(binary.BigEndian.Uint32(x[1:])) - 4
1010	var y []byte
1011	if n <= len(cn.scratch) {
1012		y = cn.scratch[:n]
1013	} else {
1014		y = make([]byte, n)
1015	}
1016	_, err = io.ReadFull(cn.buf, y)
1017	if err != nil {
1018		return 0, err
1019	}
1020	*r = y
1021	return t, nil
1022}
1023
// recv receives a message from the backend, but if an error happened while
// reading the message or the received message was an ErrorResponse, it panics.
// NoticeResponses ('N') and notifications ('A') are passed to the
// corresponding handler if one is installed and are otherwise dropped; in
// both cases reading continues with the next message.  This function should
// generally be used only during the startup sequence.
func (cn *conn) recv() (t byte, r *readBuf) {
	for {
		var err error
		// A fresh buffer per message: the returned r must not alias an
		// earlier message.
		r = &readBuf{}
		t, err = cn.recvMessage(r)
		if err != nil {
			panic(err)
		}
		switch t {
		case 'E':
			panic(parseError(r))
		case 'N':
			if n := cn.noticeHandler; n != nil {
				n(parseError(r))
			}
		case 'A':
			if n := cn.notificationHandler; n != nil {
				n(recvNotification(r))
			}
		default:
			return
		}
	}
}
1052
// recv1Buf is exactly equivalent to recv1, except it uses a buffer supplied by
// the caller to avoid an allocation.
//
// Notifications ('A') and notices ('N') go to their handler if one is
// installed, parameter status messages ('S') are processed internally, and
// reading continues after each of them. Any other message type, including an
// ErrorResponse, is returned to the caller. A read error causes a panic.
func (cn *conn) recv1Buf(r *readBuf) byte {
	for {
		t, err := cn.recvMessage(r)
		if err != nil {
			panic(err)
		}

		switch t {
		case 'A':
			if n := cn.notificationHandler; n != nil {
				n(recvNotification(r))
			}
		case 'N':
			if n := cn.noticeHandler; n != nil {
				n(parseError(r))
			}
		case 'S':
			cn.processParameterStatus(r)
		default:
			return t
		}
	}
}
1078
1079// recv1 receives a message from the backend, panicking if an error occurs
1080// while attempting to read it.  All asynchronous messages are ignored, with
1081// the exception of ErrorResponse.
1082func (cn *conn) recv1() (t byte, r *readBuf) {
1083	r = &readBuf{}
1084	t = cn.recv1Buf(r)
1085	return t, r
1086}
1087
1088func (cn *conn) ssl(o values) error {
1089	upgrade, err := ssl(o)
1090	if err != nil {
1091		return err
1092	}
1093
1094	if upgrade == nil {
1095		// Nothing to do
1096		return nil
1097	}
1098
1099	w := cn.writeBuf(0)
1100	w.int32(80877103)
1101	if err = cn.sendStartupPacket(w); err != nil {
1102		return err
1103	}
1104
1105	b := cn.scratch[:1]
1106	_, err = io.ReadFull(cn.c, b)
1107	if err != nil {
1108		return err
1109	}
1110
1111	if b[0] != 'S' {
1112		return ErrSSLNotSupported
1113	}
1114
1115	cn.c, err = upgrade(cn.c)
1116	return err
1117}
1118
// isDriverSetting reports whether key names an option that only configures
// the driver itself. Such options are left out of the startup packet sent to
// the server.
func isDriverSetting(key string) bool {
	switch key {
	case
		// where to connect and how to authenticate
		"host", "port", "password",
		// TLS configuration
		"sslmode", "sslcert", "sslkey", "sslrootcert", "sslinline",
		// client-side behaviour
		"fallback_application_name", "connect_timeout",
		"disable_prepared_binary_result", "binary_parameters",
		// Kerberos
		"krbsrvname", "krbspn":
		return true
	}
	return false
}
1146
// startup sends the startup packet built from o and processes the server's
// replies until ReadyForQuery ('Z') arrives.
//
// Backend key data ('K'), parameter status ('S') and authentication requests
// ('R') are handled as they come in. Every failure (a write error, an error
// response from the server, an unknown message type) is raised as a panic,
// so callers must be prepared to recover.
func (cn *conn) startup(o values) {
	w := cn.writeBuf(0)
	// 196608 == 3<<16: protocol version 3.0 in the PostgreSQL startup
	// message layout.
	w.int32(196608)
	// Send the backend the name of the database we want to connect to, and the
	// user we want to connect as.  Additionally, we send over any run-time
	// parameters potentially included in the connection string.  If the server
	// doesn't recognize any of them, it will reply with an error.
	for k, v := range o {
		if isDriverSetting(k) {
			// skip options which can't be run-time parameters
			continue
		}
		// The protocol requires us to supply the database name as "database"
		// instead of "dbname".
		if k == "dbname" {
			k = "database"
		}
		w.string(k)
		w.string(v)
	}
	// An empty string ends the parameter list.
	w.string("")
	if err := cn.sendStartupPacket(w); err != nil {
		panic(err)
	}

	for {
		t, r := cn.recv()
		switch t {
		case 'K':
			cn.processBackendKeyData(r)
		case 'S':
			cn.processParameterStatus(r)
		case 'R':
			cn.auth(r, o)
		case 'Z':
			cn.processReadyForQuery(r)
			return
		default:
			errorf("unknown response for startup: %q", t)
		}
	}
}
1189
1190func (cn *conn) auth(r *readBuf, o values) {
1191	switch code := r.int32(); code {
1192	case 0:
1193		// OK
1194	case 3:
1195		w := cn.writeBuf('p')
1196		w.string(o["password"])
1197		cn.send(w)
1198
1199		t, r := cn.recv()
1200		if t != 'R' {
1201			errorf("unexpected password response: %q", t)
1202		}
1203
1204		if r.int32() != 0 {
1205			errorf("unexpected authentication response: %q", t)
1206		}
1207	case 5:
1208		s := string(r.next(4))
1209		w := cn.writeBuf('p')
1210		w.string("md5" + md5s(md5s(o["password"]+o["user"])+s))
1211		cn.send(w)
1212
1213		t, r := cn.recv()
1214		if t != 'R' {
1215			errorf("unexpected password response: %q", t)
1216		}
1217
1218		if r.int32() != 0 {
1219			errorf("unexpected authentication response: %q", t)
1220		}
1221	case 7: // GSSAPI, startup
1222		if newGss == nil {
1223			errorf("kerberos error: no GSSAPI provider registered (import github.com/lib/pq/auth/kerberos if you need Kerberos support)")
1224		}
1225		cli, err := newGss()
1226		if err != nil {
1227			errorf("kerberos error: %s", err.Error())
1228		}
1229
1230		var token []byte
1231
1232		if spn, ok := o["krbspn"]; ok {
1233			// Use the supplied SPN if provided..
1234			token, err = cli.GetInitTokenFromSpn(spn)
1235		} else {
1236			// Allow the kerberos service name to be overridden
1237			service := "postgres"
1238			if val, ok := o["krbsrvname"]; ok {
1239				service = val
1240			}
1241
1242			token, err = cli.GetInitToken(o["host"], service)
1243		}
1244
1245		if err != nil {
1246			errorf("failed to get Kerberos ticket: %q", err)
1247		}
1248
1249		w := cn.writeBuf('p')
1250		w.bytes(token)
1251		cn.send(w)
1252
1253		// Store for GSSAPI continue message
1254		cn.gss = cli
1255
1256	case 8: // GSSAPI continue
1257
1258		if cn.gss == nil {
1259			errorf("GSSAPI protocol error")
1260		}
1261
1262		b := []byte(*r)
1263
1264		done, tokOut, err := cn.gss.Continue(b)
1265		if err == nil && !done {
1266			w := cn.writeBuf('p')
1267			w.bytes(tokOut)
1268			cn.send(w)
1269		}
1270
1271		// Errors fall through and read the more detailed message
1272		// from the server..
1273
1274	case 10:
1275		sc := scram.NewClient(sha256.New, o["user"], o["password"])
1276		sc.Step(nil)
1277		if sc.Err() != nil {
1278			errorf("SCRAM-SHA-256 error: %s", sc.Err().Error())
1279		}
1280		scOut := sc.Out()
1281
1282		w := cn.writeBuf('p')
1283		w.string("SCRAM-SHA-256")
1284		w.int32(len(scOut))
1285		w.bytes(scOut)
1286		cn.send(w)
1287
1288		t, r := cn.recv()
1289		if t != 'R' {
1290			errorf("unexpected password response: %q", t)
1291		}
1292
1293		if r.int32() != 11 {
1294			errorf("unexpected authentication response: %q", t)
1295		}
1296
1297		nextStep := r.next(len(*r))
1298		sc.Step(nextStep)
1299		if sc.Err() != nil {
1300			errorf("SCRAM-SHA-256 error: %s", sc.Err().Error())
1301		}
1302
1303		scOut = sc.Out()
1304		w = cn.writeBuf('p')
1305		w.bytes(scOut)
1306		cn.send(w)
1307
1308		t, r = cn.recv()
1309		if t != 'R' {
1310			errorf("unexpected password response: %q", t)
1311		}
1312
1313		if r.int32() != 12 {
1314			errorf("unexpected authentication response: %q", t)
1315		}
1316
1317		nextStep = r.next(len(*r))
1318		sc.Step(nextStep)
1319		if sc.Err() != nil {
1320			errorf("SCRAM-SHA-256 error: %s", sc.Err().Error())
1321		}
1322
1323	default:
1324		errorf("unknown authentication response: %d", code)
1325	}
1326}
1327
// format is the format code of a column.
type format int

const (
	formatText   format = 0
	formatBinary format = 1
)

var (
	// One result-column format code with the value 1 (i.e. all binary).
	colFmtDataAllBinary = []byte{0, 1, 0, 1}

	// No result-column format codes (i.e. all text).
	colFmtDataAllText = []byte{0, 0}
)
1338
// stmt is a prepared statement bound to the connection it was created on.
type stmt struct {
	// cn is the connection all messages for this statement are sent on.
	cn   *conn
	// name identifies the statement in the messages sent to the server.
	name string
	// rowsHeader describes the result columns; it is copied into the rows
	// value returned by query.
	rowsHeader
	// colFmtData is appended verbatim to the bind message built by exec.
	colFmtData []byte
	// paramTyps holds one type OID per statement parameter; its length is
	// the number of arguments exec requires.
	paramTyps  []oid.Oid
	// closed is set once Close has seen the server's close response, which
	// turns later Close calls into no-ops.
	closed     bool
}
1347
1348func (st *stmt) Close() (err error) {
1349	if st.closed {
1350		return nil
1351	}
1352	if err := st.cn.err.get(); err != nil {
1353		return err
1354	}
1355	defer st.cn.errRecover(&err)
1356
1357	w := st.cn.writeBuf('C')
1358	w.byte('S')
1359	w.string(st.name)
1360	st.cn.send(w)
1361
1362	st.cn.send(st.cn.writeBuf('S'))
1363
1364	t, _ := st.cn.recv1()
1365	if t != '3' {
1366		st.cn.err.set(driver.ErrBadConn)
1367		errorf("unexpected close response: %q", t)
1368	}
1369	st.closed = true
1370
1371	t, r := st.cn.recv1()
1372	if t != 'Z' {
1373		st.cn.err.set(driver.ErrBadConn)
1374		errorf("expected ready for query, but got: %q", t)
1375	}
1376	st.cn.processReadyForQuery(r)
1377
1378	return nil
1379}
1380
1381func (st *stmt) Query(v []driver.Value) (r driver.Rows, err error) {
1382	return st.query(v)
1383}
1384
1385func (st *stmt) query(v []driver.Value) (r *rows, err error) {
1386	if err := st.cn.err.get(); err != nil {
1387		return nil, err
1388	}
1389	defer st.cn.errRecover(&err)
1390
1391	st.exec(v)
1392	return &rows{
1393		cn:         st.cn,
1394		rowsHeader: st.rowsHeader,
1395	}, nil
1396}
1397
1398func (st *stmt) Exec(v []driver.Value) (res driver.Result, err error) {
1399	if err := st.cn.err.get(); err != nil {
1400		return nil, err
1401	}
1402	defer st.cn.errRecover(&err)
1403
1404	st.exec(v)
1405	res, _, err = st.cn.readExecuteResponse("simple query")
1406	return res, err
1407}
1408
1409func (st *stmt) exec(v []driver.Value) {
1410	if len(v) >= 65536 {
1411		errorf("got %d parameters but PostgreSQL only supports 65535 parameters", len(v))
1412	}
1413	if len(v) != len(st.paramTyps) {
1414		errorf("got %d parameters but the statement requires %d", len(v), len(st.paramTyps))
1415	}
1416
1417	cn := st.cn
1418	w := cn.writeBuf('B')
1419	w.byte(0) // unnamed portal
1420	w.string(st.name)
1421
1422	if cn.binaryParameters {
1423		cn.sendBinaryParameters(w, v)
1424	} else {
1425		w.int16(0)
1426		w.int16(len(v))
1427		for i, x := range v {
1428			if x == nil {
1429				w.int32(-1)
1430			} else {
1431				b := encode(&cn.parameterStatus, x, st.paramTyps[i])
1432				w.int32(len(b))
1433				w.bytes(b)
1434			}
1435		}
1436	}
1437	w.bytes(st.colFmtData)
1438
1439	w.next('E')
1440	w.byte(0)
1441	w.int32(0)
1442
1443	w.next('S')
1444	cn.send(w)
1445
1446	cn.readBindResponse()
1447	cn.postExecuteWorkaround()
1448
1449}
1450
1451func (st *stmt) NumInput() int {
1452	return len(st.paramTyps)
1453}
1454
1455// parseComplete parses the "command tag" from a CommandComplete message, and
1456// returns the number of rows affected (if applicable) and a string
1457// identifying only the command that was executed, e.g. "ALTER TABLE".  If the
1458// command tag could not be parsed, parseComplete panics.
1459func (cn *conn) parseComplete(commandTag string) (driver.Result, string) {
1460	commandsWithAffectedRows := []string{
1461		"SELECT ",
1462		// INSERT is handled below
1463		"UPDATE ",
1464		"DELETE ",
1465		"FETCH ",
1466		"MOVE ",
1467		"COPY ",
1468	}
1469
1470	var affectedRows *string
1471	for _, tag := range commandsWithAffectedRows {
1472		if strings.HasPrefix(commandTag, tag) {
1473			t := commandTag[len(tag):]
1474			affectedRows = &t
1475			commandTag = tag[:len(tag)-1]
1476			break
1477		}
1478	}
1479	// INSERT also includes the oid of the inserted row in its command tag.
1480	// Oids in user tables are deprecated, and the oid is only returned when
1481	// exactly one row is inserted, so it's unlikely to be of value to any
1482	// real-world application and we can ignore it.
1483	if affectedRows == nil && strings.HasPrefix(commandTag, "INSERT ") {
1484		parts := strings.Split(commandTag, " ")
1485		if len(parts) != 3 {
1486			cn.err.set(driver.ErrBadConn)
1487			errorf("unexpected INSERT command tag %s", commandTag)
1488		}
1489		affectedRows = &parts[len(parts)-1]
1490		commandTag = "INSERT"
1491	}
1492	// There should be no affected rows attached to the tag, just return it
1493	if affectedRows == nil {
1494		return driver.RowsAffected(0), commandTag
1495	}
1496	n, err := strconv.ParseInt(*affectedRows, 10, 64)
1497	if err != nil {
1498		cn.err.set(driver.ErrBadConn)
1499		errorf("could not parse commandTag: %s", err)
1500	}
1501	return driver.RowsAffected(n), commandTag
1502}
1503
// rowsHeader describes the columns of a result set.  The three slices are
// parallel: index i refers to the same column in each of them.
type rowsHeader struct {
	// colNames holds the column names, as returned by rows.Columns.
	colNames []string
	// colTyps holds the type description (OID, Len, Mod) of each column.
	colTyps  []fieldDesc
	// colFmts holds the format code of each column; it is passed to decode
	// when a row is read.
	colFmts  []format
}
1509
// rows reads the results of a query from its connection, one message at a
// time.
type rows struct {
	// cn is the connection the results are read from.
	cn     *conn
	// finish, if non-nil, is called when Close returns.
	finish func()
	// rowsHeader describes the columns of the current result set.
	rowsHeader
	// done is set once the 'Z' (ready for query) message has been read;
	// after that Next only returns io.EOF.
	done   bool
	// rb is the buffer each incoming message is received into.
	rb     readBuf
	// result and tag are filled in from the most recent 'C' message.
	result driver.Result
	tag    string

	// next holds the column description of the following result set, if
	// one has been announced; NextResultSet installs it.
	next *rowsHeader
}
1521
1522func (rs *rows) Close() error {
1523	if finish := rs.finish; finish != nil {
1524		defer finish()
1525	}
1526	// no need to look at cn.bad as Next() will
1527	for {
1528		err := rs.Next(nil)
1529		switch err {
1530		case nil:
1531		case io.EOF:
1532			// rs.Next can return io.EOF on both 'Z' (ready for query) and 'T' (row
1533			// description, used with HasNextResultSet). We need to fetch messages until
1534			// we hit a 'Z', which is done by waiting for done to be set.
1535			if rs.done {
1536				return nil
1537			}
1538		default:
1539			return err
1540		}
1541	}
1542}
1543
1544func (rs *rows) Columns() []string {
1545	return rs.colNames
1546}
1547
1548func (rs *rows) Result() driver.Result {
1549	if rs.result == nil {
1550		return emptyRows
1551	}
1552	return rs.result
1553}
1554
1555func (rs *rows) Tag() string {
1556	return rs.tag
1557}
1558
// Next reads messages from the connection until it can fill dest with the
// values of one data row, and returns nil in that case.  It returns io.EOF
// when the query is finished ('Z' received) or when the description of a
// following result set ('T') has been read.  An error sent by the server is
// held back until the 'Z' message and returned in place of io.EOF.  Panics
// raised while reading are converted into err by errRecover.
func (rs *rows) Next(dest []driver.Value) (err error) {
	if rs.done {
		return io.EOF
	}

	conn := rs.cn
	if err := conn.err.getForNext(); err != nil {
		return err
	}
	defer conn.errRecover(&err)

	for {
		t := conn.recv1Buf(&rs.rb)
		switch t {
		case 'E':
			// Remember the error; it is returned once 'Z' arrives.
			err = parseError(&rs.rb)
		case 'C', 'I':
			// Only 'C' carries a command tag; 'I' has nothing to record.
			if t == 'C' {
				rs.result, rs.tag = conn.parseComplete(rs.rb.string())
			}
			continue
		case 'Z':
			conn.processReadyForQuery(&rs.rb)
			rs.done = true
			if err != nil {
				return err
			}
			return io.EOF
		case 'D':
			// n is the number of values in this row.
			n := rs.rb.int16()
			if err != nil {
				conn.err.set(driver.ErrBadConn)
				errorf("unexpected DataRow after error %s", err)
			}
			// Never read more values than the row contains.
			if n < len(dest) {
				dest = dest[:n]
			}
			for i := range dest {
				// A length of -1 marks a NULL value, which has no data.
				l := rs.rb.int32()
				if l == -1 {
					dest[i] = nil
					continue
				}
				dest[i] = decode(&conn.parameterStatus, rs.rb.next(l), rs.colTyps[i].OID, rs.colFmts[i])
			}
			return
		case 'T':
			// Column description of the next result set: keep it for
			// NextResultSet and end the current set without setting done.
			next := parsePortalRowDescribe(&rs.rb)
			rs.next = &next
			return io.EOF
		default:
			errorf("unexpected message after execute: %q", t)
		}
	}
}
1614
1615func (rs *rows) HasNextResultSet() bool {
1616	hasNext := rs.next != nil && !rs.done
1617	return hasNext
1618}
1619
1620func (rs *rows) NextResultSet() error {
1621	if rs.next == nil {
1622		return io.EOF
1623	}
1624	rs.rowsHeader = *rs.next
1625	rs.next = nil
1626	return nil
1627}
1628
// QuoteIdentifier quotes an "identifier" (e.g. a table or a column name) so
// that it can be embedded in an SQL statement.  For example:
//
//    tblname := "my_table"
//    data := "my_data"
//    quoted := pq.QuoteIdentifier(tblname)
//    _, err := db.Exec(fmt.Sprintf("INSERT INTO %s VALUES ($1)", quoted), data)
//
// The name is wrapped in double quotes and any double quote inside it is
// doubled.  A quoted identifier is case sensitive when used in a query.  If
// the input contains a zero byte, the name is cut off immediately before it.
func QuoteIdentifier(name string) string {
	if nul := strings.IndexByte(name, 0); nul >= 0 {
		name = name[:nul]
	}
	escaped := strings.Replace(name, `"`, `""`, -1)
	return `"` + escaped + `"`
}
1647
// QuoteLiteral quotes a 'literal' (e.g. a parameter, often used to pass
// literals to DDL and other statements that do not accept parameters) so that
// it can be embedded in an SQL statement.  For example:
//
//    exp_date := pq.QuoteLiteral("2023-01-05 15:00:00Z")
//    _, err := db.Exec(fmt.Sprintf("CREATE ROLE my_user VALID UNTIL %s", exp_date))
//
// Any single quote in literal is doubled.  If literal contains a backslash,
// every backslash (i.e. "\") is replaced by two backslashes (i.e. "\\") and
// the C-style escape marker that PostgreSQL provides ('E'), preceded by a
// space, is put in front of the quoted string.
func QuoteLiteral(literal string) string {
	// This follows the algorithm libpq uses for quoted literals, found in
	// the "PQEscapeStringInternal" function of libpq/fe-exec.c:
	// https://git.postgresql.org/gitweb/?p=postgresql.git;a=blob;f=src/interfaces/libpq/fe-exec.c
	escaped := strings.Replace(literal, `'`, `''`, -1)

	// Without backslashes a plain pair of single quotes is enough.
	if !strings.Contains(escaped, `\`) {
		return `'` + escaped + `'`
	}

	// Otherwise double the backslashes and mark the string as a C-style
	// escape string.  As in "PQEscapeStringInternal", a space goes before
	// the "E".
	return ` E'` + strings.Replace(escaped, `\`, `\\`, -1) + `'`
}
1680
// md5s returns the MD5 digest of s as a lower-case hexadecimal string.
func md5s(s string) string {
	sum := md5.Sum([]byte(s))
	return fmt.Sprintf("%x", sum[:])
}
1686
1687func (cn *conn) sendBinaryParameters(b *writeBuf, args []driver.Value) {
1688	// Do one pass over the parameters to see if we're going to send any of
1689	// them over in binary.  If we are, create a paramFormats array at the
1690	// same time.
1691	var paramFormats []int
1692	for i, x := range args {
1693		_, ok := x.([]byte)
1694		if ok {
1695			if paramFormats == nil {
1696				paramFormats = make([]int, len(args))
1697			}
1698			paramFormats[i] = 1
1699		}
1700	}
1701	if paramFormats == nil {
1702		b.int16(0)
1703	} else {
1704		b.int16(len(paramFormats))
1705		for _, x := range paramFormats {
1706			b.int16(x)
1707		}
1708	}
1709
1710	b.int16(len(args))
1711	for _, x := range args {
1712		if x == nil {
1713			b.int32(-1)
1714		} else {
1715			datum := binaryEncode(&cn.parameterStatus, x)
1716			b.int32(len(datum))
1717			b.bytes(datum)
1718		}
1719	}
1720}
1721
1722func (cn *conn) sendBinaryModeQuery(query string, args []driver.Value) {
1723	if len(args) >= 65536 {
1724		errorf("got %d parameters but PostgreSQL only supports 65535 parameters", len(args))
1725	}
1726
1727	b := cn.writeBuf('P')
1728	b.byte(0) // unnamed statement
1729	b.string(query)
1730	b.int16(0)
1731
1732	b.next('B')
1733	b.int16(0) // unnamed portal and statement
1734	cn.sendBinaryParameters(b, args)
1735	b.bytes(colFmtDataAllText)
1736
1737	b.next('D')
1738	b.byte('P')
1739	b.byte(0) // unnamed portal
1740
1741	b.next('E')
1742	b.byte(0)
1743	b.int32(0)
1744
1745	b.next('S')
1746	cn.send(b)
1747}
1748
1749func (cn *conn) processParameterStatus(r *readBuf) {
1750	var err error
1751
1752	param := r.string()
1753	switch param {
1754	case "server_version":
1755		var major1 int
1756		var major2 int
1757		_, err = fmt.Sscanf(r.string(), "%d.%d", &major1, &major2)
1758		if err == nil {
1759			cn.parameterStatus.serverVersion = major1*10000 + major2*100
1760		}
1761
1762	case "TimeZone":
1763		cn.parameterStatus.currentLocation, err = time.LoadLocation(r.string())
1764		if err != nil {
1765			cn.parameterStatus.currentLocation = nil
1766		}
1767
1768	default:
1769		// ignore
1770	}
1771}
1772
1773func (cn *conn) processReadyForQuery(r *readBuf) {
1774	cn.txnStatus = transactionStatus(r.byte())
1775}
1776
1777func (cn *conn) readReadyForQuery() {
1778	t, r := cn.recv1()
1779	switch t {
1780	case 'Z':
1781		cn.processReadyForQuery(r)
1782		return
1783	default:
1784		cn.err.set(driver.ErrBadConn)
1785		errorf("unexpected message %q; expected ReadyForQuery", t)
1786	}
1787}
1788
1789func (cn *conn) processBackendKeyData(r *readBuf) {
1790	cn.processID = r.int32()
1791	cn.secretKey = r.int32()
1792}
1793
1794func (cn *conn) readParseResponse() {
1795	t, r := cn.recv1()
1796	switch t {
1797	case '1':
1798		return
1799	case 'E':
1800		err := parseError(r)
1801		cn.readReadyForQuery()
1802		panic(err)
1803	default:
1804		cn.err.set(driver.ErrBadConn)
1805		errorf("unexpected Parse response %q", t)
1806	}
1807}
1808
1809func (cn *conn) readStatementDescribeResponse() (paramTyps []oid.Oid, colNames []string, colTyps []fieldDesc) {
1810	for {
1811		t, r := cn.recv1()
1812		switch t {
1813		case 't':
1814			nparams := r.int16()
1815			paramTyps = make([]oid.Oid, nparams)
1816			for i := range paramTyps {
1817				paramTyps[i] = r.oid()
1818			}
1819		case 'n':
1820			return paramTyps, nil, nil
1821		case 'T':
1822			colNames, colTyps = parseStatementRowDescribe(r)
1823			return paramTyps, colNames, colTyps
1824		case 'E':
1825			err := parseError(r)
1826			cn.readReadyForQuery()
1827			panic(err)
1828		default:
1829			cn.err.set(driver.ErrBadConn)
1830			errorf("unexpected Describe statement response %q", t)
1831		}
1832	}
1833}
1834
1835func (cn *conn) readPortalDescribeResponse() rowsHeader {
1836	t, r := cn.recv1()
1837	switch t {
1838	case 'T':
1839		return parsePortalRowDescribe(r)
1840	case 'n':
1841		return rowsHeader{}
1842	case 'E':
1843		err := parseError(r)
1844		cn.readReadyForQuery()
1845		panic(err)
1846	default:
1847		cn.err.set(driver.ErrBadConn)
1848		errorf("unexpected Describe response %q", t)
1849	}
1850	panic("not reached")
1851}
1852
1853func (cn *conn) readBindResponse() {
1854	t, r := cn.recv1()
1855	switch t {
1856	case '2':
1857		return
1858	case 'E':
1859		err := parseError(r)
1860		cn.readReadyForQuery()
1861		panic(err)
1862	default:
1863		cn.err.set(driver.ErrBadConn)
1864		errorf("unexpected Bind response %q", t)
1865	}
1866}
1867
1868func (cn *conn) postExecuteWorkaround() {
1869	// Work around a bug in sql.DB.QueryRow: in Go 1.2 and earlier it ignores
1870	// any errors from rows.Next, which masks errors that happened during the
1871	// execution of the query.  To avoid the problem in common cases, we wait
1872	// here for one more message from the database.  If it's not an error the
1873	// query will likely succeed (or perhaps has already, if it's a
1874	// CommandComplete), so we push the message into the conn struct; recv1
1875	// will return it as the next message for rows.Next or rows.Close.
1876	// However, if it's an error, we wait until ReadyForQuery and then return
1877	// the error to our caller.
1878	for {
1879		t, r := cn.recv1()
1880		switch t {
1881		case 'E':
1882			err := parseError(r)
1883			cn.readReadyForQuery()
1884			panic(err)
1885		case 'C', 'D', 'I':
1886			// the query didn't fail, but we can't process this message
1887			cn.saveMessage(t, r)
1888			return
1889		default:
1890			cn.err.set(driver.ErrBadConn)
1891			errorf("unexpected message during extended query execution: %q", t)
1892		}
1893	}
1894}
1895
// readExecuteResponse reads messages until ReadyForQuery ('Z') and returns
// the result and command tag parsed from the last CommandComplete ('C')
// message, or the error sent by the server.  Row descriptions and data rows
// are ignored.  protocolState is only used to label the error raised for an
// unknown message type.
//
// Only for Exec(), since we ignore the returned data
func (cn *conn) readExecuteResponse(protocolState string) (res driver.Result, commandTag string, err error) {
	for {
		t, r := cn.recv1()
		switch t {
		case 'C':
			// A result must not follow an error in the same response.
			if err != nil {
				cn.err.set(driver.ErrBadConn)
				errorf("unexpected CommandComplete after error %s", err)
			}
			res, commandTag = cn.parseComplete(r.string())
		case 'Z':
			cn.processReadyForQuery(r)
			// Neither a result nor an error was seen before 'Z'.
			if res == nil && err == nil {
				err = errUnexpectedReady
			}
			return res, commandTag, err
		case 'E':
			// Remember the error and keep reading until 'Z'.
			err = parseError(r)
		case 'T', 'D', 'I':
			if err != nil {
				cn.err.set(driver.ErrBadConn)
				errorf("unexpected %q after error %s", t, err)
			}
			// 'I' produces no CommandComplete, so emptyRows stands in as
			// the result.
			if t == 'I' {
				res = emptyRows
			}
			// ignore any results
		default:
			cn.err.set(driver.ErrBadConn)
			errorf("unknown %s response: %q", protocolState, t)
		}
	}
}
1930
1931func parseStatementRowDescribe(r *readBuf) (colNames []string, colTyps []fieldDesc) {
1932	n := r.int16()
1933	colNames = make([]string, n)
1934	colTyps = make([]fieldDesc, n)
1935	for i := range colNames {
1936		colNames[i] = r.string()
1937		r.next(6)
1938		colTyps[i].OID = r.oid()
1939		colTyps[i].Len = r.int16()
1940		colTyps[i].Mod = r.int32()
1941		// format code not known when describing a statement; always 0
1942		r.next(2)
1943	}
1944	return
1945}
1946
1947func parsePortalRowDescribe(r *readBuf) rowsHeader {
1948	n := r.int16()
1949	colNames := make([]string, n)
1950	colFmts := make([]format, n)
1951	colTyps := make([]fieldDesc, n)
1952	for i := range colNames {
1953		colNames[i] = r.string()
1954		r.next(6)
1955		colTyps[i].OID = r.oid()
1956		colTyps[i].Len = r.int16()
1957		colTyps[i].Mod = r.int32()
1958		colFmts[i] = format(r.int16())
1959	}
1960	return rowsHeader{
1961		colNames: colNames,
1962		colFmts:  colFmts,
1963		colTyps:  colTyps,
1964	}
1965}
1966
// parseEnviron tries to mimic some of libpq's environment handling.
//
// To ease testing, it does not directly reference os.Environ, but is
// designed to accept its output: a list of "NAME=value" entries.  The result
// maps connection-option names (e.g. "host") to the value of the matching
// environment variable (e.g. PGHOST).  Variables that are not known are
// ignored, and so is a supported variable that appears without a "=value"
// part.
//
// Environment-set connection information is intended to have a higher
// precedence than a library default but lower than any explicitly
// passed information (such as in the URL or connection string).
//
// Unsupported but well-defined variables cause a panic; these should be unset
// prior to execution.
func parseEnviron(env []string) (out map[string]string) {
	out = make(map[string]string)

	unsupported := func(name string) {
		panic(fmt.Sprintf("setting %v not supported", name))
	}

	for _, v := range env {
		parts := strings.SplitN(v, "=", 2)
		name := parts[0]

		// The order of these is the same as is seen in the
		// PostgreSQL 9.1 manual. Unsupported but well-defined
		// keys cause a panic; these should be unset prior to
		// execution. Options which pq expects to be set to a
		// certain value are allowed, but must be set to that
		// value if present (they can, of course, be absent).
		var key string
		switch name {
		case "PGHOST":
			key = "host"
		case "PGHOSTADDR":
			unsupported(name)
		case "PGPORT":
			key = "port"
		case "PGDATABASE":
			key = "dbname"
		case "PGUSER":
			key = "user"
		case "PGPASSWORD":
			key = "password"
		case "PGSERVICE", "PGSERVICEFILE", "PGREALM":
			unsupported(name)
		case "PGOPTIONS":
			key = "options"
		case "PGAPPNAME":
			key = "application_name"
		case "PGSSLMODE":
			key = "sslmode"
		case "PGSSLCERT":
			key = "sslcert"
		case "PGSSLKEY":
			key = "sslkey"
		case "PGSSLROOTCERT":
			key = "sslrootcert"
		case "PGREQUIRESSL", "PGSSLCRL":
			unsupported(name)
		case "PGREQUIREPEER":
			unsupported(name)
		case "PGKRBSRVNAME", "PGGSSLIB":
			unsupported(name)
		case "PGCONNECT_TIMEOUT":
			key = "connect_timeout"
		case "PGCLIENTENCODING":
			key = "client_encoding"
		case "PGDATESTYLE":
			key = "datestyle"
		case "PGTZ":
			key = "timezone"
		case "PGGEQO":
			key = "geqo"
		case "PGSYSCONFDIR", "PGLOCALEDIR":
			unsupported(name)
		default:
			continue
		}

		// An entry without '=' carries no value; indexing parts[1] for it
		// would panic with an index out of range.
		if len(parts) == 2 {
			out[key] = parts[1]
		}
	}

	return out
}
2044
2045// isUTF8 returns whether name is a fuzzy variation of the string "UTF-8".
2046func isUTF8(name string) bool {
2047	// Recognize all sorts of silly things as "UTF-8", like Postgres does
2048	s := strings.Map(alnumLowerASCII, name)
2049	return s == "utf8" || s == "unicode"
2050}
2051
// alnumLowerASCII is a mapping function for strings.Map: ASCII upper-case
// letters are lowered, ASCII lower-case letters and digits are kept, and
// every other rune is dropped (signalled by a negative result).
func alnumLowerASCII(ch rune) rune {
	switch {
	case ch >= 'A' && ch <= 'Z':
		return ch - 'A' + 'a'
	case ch >= 'a' && ch <= 'z', ch >= '0' && ch <= '9':
		return ch
	default:
		return -1 // discard
	}
}
2061