1package pgx
2
3import (
4	"context"
5	"database/sql"
6	"fmt"
7	"reflect"
8	"time"
9
10	"github.com/pkg/errors"
11
12	"github.com/jackc/pgx/internal/sanitize"
13	"github.com/jackc/pgx/pgproto3"
14	"github.com/jackc/pgx/pgtype"
15)
16
// Row is a convenience wrapper over Rows that is returned by QueryRow and
// QueryRowEx. It has the same underlying representation as Rows, so a *Rows
// can be converted directly to a *Row and back.
type Row Rows
19
20// Scan works the same as (*Rows Scan) with the following exceptions. If no
21// rows were found it returns ErrNoRows. If multiple rows are returned it
22// ignores all but the first.
23func (r *Row) Scan(dest ...interface{}) (err error) {
24	rows := (*Rows)(r)
25
26	if rows.Err() != nil {
27		return rows.Err()
28	}
29
30	if !rows.Next() {
31		if rows.Err() == nil {
32			return ErrNoRows
33		}
34		return rows.Err()
35	}
36
37	rows.Scan(dest...)
38	rows.Close()
39	return rows.Err()
40}
41
// Rows is the result set returned from *Conn.Query. Rows must be closed before
// the *Conn can be used again. Rows are closed by explicitly calling Close(),
// calling Next() until it returns false, or when a fatal error occurs.
//
// Field roles, as used in this file:
//   - conn is the connection messages are read from; connPool, when non-nil,
//     receives conn back on Close.
//   - batch, when non-nil, is notified on Close if an error was recorded and
//     has its pendingCommandComplete flag cleared when Next sees
//     CommandComplete.
//   - values holds the raw column values of the current data row and fields
//     the corresponding field descriptions; columnIdx is the index of the next
//     column handed out by nextColumn and is reset by Next.
//   - rowCount is maintained by Next and reported in the Close log entry.
//   - err is the recorded error returned by Err.
//   - startTime, sql and args are set by getRows and used for logging.
//   - unlockConn tells Close that the connection lock must be released.
//   - closed is set by Close and makes Close and Next no-ops afterwards.
type Rows struct {
	conn       *Conn
	connPool   *ConnPool
	batch      *Batch
	values     [][]byte
	fields     []FieldDescription
	rowCount   int
	columnIdx  int
	err        error
	startTime  time.Time
	sql        string
	args       []interface{}
	unlockConn bool
	closed     bool
}
60
61func (rows *Rows) FieldDescriptions() []FieldDescription {
62	return rows.fields
63}
64
65// Close closes the rows, making the connection ready for use again. It is safe
66// to call Close after rows is already closed.
67func (rows *Rows) Close() {
68	if rows.closed {
69		return
70	}
71
72	if rows.unlockConn {
73		rows.conn.unlock()
74		rows.unlockConn = false
75	}
76
77	rows.closed = true
78
79	rows.err = rows.conn.termContext(rows.err)
80
81	if rows.err == nil {
82		if rows.conn.shouldLog(LogLevelInfo) {
83			endTime := time.Now()
84			rows.conn.log(LogLevelInfo, "Query", map[string]interface{}{"sql": rows.sql, "args": logQueryArgs(rows.args), "time": endTime.Sub(rows.startTime), "rowCount": rows.rowCount})
85		}
86	} else if rows.conn.shouldLog(LogLevelError) {
87		rows.conn.log(LogLevelError, "Query", map[string]interface{}{"sql": rows.sql, "args": logQueryArgs(rows.args)})
88	}
89
90	if rows.batch != nil && rows.err != nil {
91		rows.batch.die(rows.err)
92	}
93
94	if rows.connPool != nil {
95		rows.connPool.Release(rows.conn)
96	}
97}
98
99func (rows *Rows) Err() error {
100	return rows.err
101}
102
103// fatal signals an error occurred after the query was sent to the server. It
104// closes the rows automatically.
105func (rows *Rows) fatal(err error) {
106	if rows.err != nil {
107		return
108	}
109
110	rows.err = err
111	rows.Close()
112}
113
114// Next prepares the next row for reading. It returns true if there is another
115// row and false if no more rows are available. It automatically closes rows
116// when all rows are read.
117func (rows *Rows) Next() bool {
118	if rows.closed {
119		return false
120	}
121
122	rows.rowCount++
123	rows.columnIdx = 0
124
125	for {
126		msg, err := rows.conn.rxMsg()
127		if err != nil {
128			rows.fatal(err)
129			return false
130		}
131
132		switch msg := msg.(type) {
133		case *pgproto3.RowDescription:
134			rows.fields = rows.conn.rxRowDescription(msg)
135			for i := range rows.fields {
136				if dt, ok := rows.conn.ConnInfo.DataTypeForOID(rows.fields[i].DataType); ok {
137					rows.fields[i].DataTypeName = dt.Name
138					rows.fields[i].FormatCode = TextFormatCode
139				} else {
140					rows.fatal(errors.Errorf("unknown oid: %d", rows.fields[i].DataType))
141					return false
142				}
143			}
144		case *pgproto3.DataRow:
145			if len(msg.Values) != len(rows.fields) {
146				rows.fatal(ProtocolError(fmt.Sprintf("Row description field count (%v) and data row field count (%v) do not match", len(rows.fields), len(msg.Values))))
147				return false
148			}
149
150			rows.values = msg.Values
151			return true
152		case *pgproto3.CommandComplete:
153			if rows.batch != nil {
154				rows.batch.pendingCommandComplete = false
155			}
156			rows.Close()
157			return false
158
159		default:
160			err = rows.conn.processContextFreeMsg(msg)
161			if err != nil {
162				rows.fatal(err)
163				return false
164			}
165		}
166	}
167}
168
169func (rows *Rows) nextColumn() ([]byte, *FieldDescription, bool) {
170	if rows.closed {
171		return nil, nil, false
172	}
173	if len(rows.fields) <= rows.columnIdx {
174		rows.fatal(ProtocolError("No next column available"))
175		return nil, nil, false
176	}
177
178	buf := rows.values[rows.columnIdx]
179	fd := &rows.fields[rows.columnIdx]
180	rows.columnIdx++
181	return buf, fd, true
182}
183
// scanArgError reports a failure to scan a column into one of the destination
// arguments passed to Scan.
type scanArgError struct {
	col int   // zero-based index of the destination argument
	err error // underlying decode or assignment error
}

// Error implements the error interface, naming the destination index and the
// underlying error.
func (e scanArgError) Error() string {
	const format = "can't scan into dest[%d]: %v"
	return fmt.Sprintf(format, e.col, e.err)
}
192
193// Scan reads the values from the current row into dest values positionally.
194// dest can include pointers to core types, values implementing the Scanner
195// interface, []byte, and nil. []byte will skip the decoding process and directly
196// copy the raw bytes received from PostgreSQL. nil will skip the value entirely.
197func (rows *Rows) Scan(dest ...interface{}) (err error) {
198	if len(rows.fields) != len(dest) {
199		err = errors.Errorf("Scan received wrong number of arguments, got %d but expected %d", len(dest), len(rows.fields))
200		rows.fatal(err)
201		return err
202	}
203
204	for i, d := range dest {
205		buf, fd, _ := rows.nextColumn()
206
207		if d == nil {
208			continue
209		}
210
211		if s, ok := d.(pgtype.BinaryDecoder); ok && fd.FormatCode == BinaryFormatCode {
212			err = s.DecodeBinary(rows.conn.ConnInfo, buf)
213			if err != nil {
214				rows.fatal(scanArgError{col: i, err: err})
215			}
216		} else if s, ok := d.(pgtype.TextDecoder); ok && fd.FormatCode == TextFormatCode {
217			err = s.DecodeText(rows.conn.ConnInfo, buf)
218			if err != nil {
219				rows.fatal(scanArgError{col: i, err: err})
220			}
221		} else {
222			if dt, ok := rows.conn.ConnInfo.DataTypeForOID(fd.DataType); ok {
223				value := dt.Value
224				switch fd.FormatCode {
225				case TextFormatCode:
226					if textDecoder, ok := value.(pgtype.TextDecoder); ok {
227						err = textDecoder.DecodeText(rows.conn.ConnInfo, buf)
228						if err != nil {
229							rows.fatal(scanArgError{col: i, err: err})
230						}
231					} else {
232						rows.fatal(scanArgError{col: i, err: errors.Errorf("%T is not a pgtype.TextDecoder", value)})
233					}
234				case BinaryFormatCode:
235					if binaryDecoder, ok := value.(pgtype.BinaryDecoder); ok {
236						err = binaryDecoder.DecodeBinary(rows.conn.ConnInfo, buf)
237						if err != nil {
238							rows.fatal(scanArgError{col: i, err: err})
239						}
240					} else {
241						rows.fatal(scanArgError{col: i, err: errors.Errorf("%T is not a pgtype.BinaryDecoder", value)})
242					}
243				default:
244					rows.fatal(scanArgError{col: i, err: errors.Errorf("unknown format code: %v", fd.FormatCode)})
245				}
246
247				if rows.Err() == nil {
248					if scanner, ok := d.(sql.Scanner); ok {
249						sqlSrc, err := pgtype.DatabaseSQLValue(rows.conn.ConnInfo, value)
250						if err != nil {
251							rows.fatal(err)
252						}
253						err = scanner.Scan(sqlSrc)
254						if err != nil {
255							rows.fatal(scanArgError{col: i, err: err})
256						}
257					} else if err := value.AssignTo(d); err != nil {
258						rows.fatal(scanArgError{col: i, err: err})
259					}
260				}
261			} else {
262				rows.fatal(scanArgError{col: i, err: errors.Errorf("unknown oid: %v", fd.DataType)})
263			}
264		}
265
266		if rows.Err() != nil {
267			return rows.Err()
268		}
269	}
270
271	return nil
272}
273
274// Values returns an array of the row values
275func (rows *Rows) Values() ([]interface{}, error) {
276	if rows.closed {
277		return nil, errors.New("rows is closed")
278	}
279
280	values := make([]interface{}, 0, len(rows.fields))
281
282	for range rows.fields {
283		buf, fd, _ := rows.nextColumn()
284
285		if buf == nil {
286			values = append(values, nil)
287			continue
288		}
289
290		if dt, ok := rows.conn.ConnInfo.DataTypeForOID(fd.DataType); ok {
291			value := reflect.New(reflect.ValueOf(dt.Value).Elem().Type()).Interface().(pgtype.Value)
292
293			switch fd.FormatCode {
294			case TextFormatCode:
295				decoder := value.(pgtype.TextDecoder)
296				if decoder == nil {
297					decoder = &pgtype.GenericText{}
298				}
299				err := decoder.DecodeText(rows.conn.ConnInfo, buf)
300				if err != nil {
301					rows.fatal(err)
302				}
303				values = append(values, decoder.(pgtype.Value).Get())
304			case BinaryFormatCode:
305				decoder := value.(pgtype.BinaryDecoder)
306				if decoder == nil {
307					decoder = &pgtype.GenericBinary{}
308				}
309				err := decoder.DecodeBinary(rows.conn.ConnInfo, buf)
310				if err != nil {
311					rows.fatal(err)
312				}
313				values = append(values, value.Get())
314			default:
315				rows.fatal(errors.New("Unknown format code"))
316			}
317		} else {
318			rows.fatal(errors.New("Unknown type"))
319		}
320
321		if rows.Err() != nil {
322			return nil, rows.Err()
323		}
324	}
325
326	return values, rows.Err()
327}
328
329// Query executes sql with args. If there is an error the returned *Rows will
330// be returned in an error state. So it is allowed to ignore the error returned
331// from Query and handle it in *Rows.
332func (c *Conn) Query(sql string, args ...interface{}) (*Rows, error) {
333	return c.QueryEx(context.Background(), sql, nil, args...)
334}
335
336func (c *Conn) getRows(sql string, args []interface{}) *Rows {
337	if len(c.preallocatedRows) == 0 {
338		c.preallocatedRows = make([]Rows, 64)
339	}
340
341	r := &c.preallocatedRows[len(c.preallocatedRows)-1]
342	c.preallocatedRows = c.preallocatedRows[0 : len(c.preallocatedRows)-1]
343
344	r.conn = c
345	r.startTime = c.lastActivityTime
346	r.sql = sql
347	r.args = args
348
349	return r
350}
351
352// QueryRow is a convenience wrapper over Query. Any error that occurs while
353// querying is deferred until calling Scan on the returned *Row. That *Row will
354// error with ErrNoRows if no rows are returned.
355func (c *Conn) QueryRow(sql string, args ...interface{}) *Row {
356	rows, _ := c.Query(sql, args...)
357	return (*Row)(rows)
358}
359
// QueryExOptions holds the optional settings accepted by QueryEx and
// QueryRowEx.
type QueryExOptions struct {
	// When ParameterOIDs are present and the query is not a prepared statement,
	// then ParameterOIDs and ResultFormatCodes will be used to avoid an extra
	// network round-trip.
	//
	// On that path QueryEx interprets ResultFormatCodes as follows: empty means
	// text format for every column, a single code applies to every column, and
	// otherwise the codes are applied to the columns positionally.
	ParameterOIDs     []pgtype.OID
	ResultFormatCodes []int16

	// SimpleProtocol, when true, makes QueryEx sanitize the arguments into the
	// SQL text and send it as a simple query.
	SimpleProtocol bool
}
369
370func (c *Conn) QueryEx(ctx context.Context, sql string, options *QueryExOptions, args ...interface{}) (rows *Rows, err error) {
371	c.lastStmtSent = false
372	c.lastActivityTime = time.Now()
373	rows = c.getRows(sql, args)
374
375	err = c.waitForPreviousCancelQuery(ctx)
376	if err != nil {
377		rows.fatal(err)
378		return rows, err
379	}
380
381	if err := c.ensureConnectionReadyForQuery(); err != nil {
382		rows.fatal(err)
383		return rows, err
384	}
385
386	if err := c.lock(); err != nil {
387		rows.fatal(err)
388		return rows, err
389	}
390	rows.unlockConn = true
391
392	err = c.initContext(ctx)
393	if err != nil {
394		rows.fatal(err)
395		return rows, rows.err
396	}
397
398	if (options == nil && c.config.PreferSimpleProtocol) || (options != nil && options.SimpleProtocol) {
399		c.lastStmtSent = true
400		err = c.sanitizeAndSendSimpleQuery(sql, args...)
401		if err != nil {
402			rows.fatal(err)
403			return rows, err
404		}
405
406		return rows, nil
407	}
408
409	if options != nil && len(options.ParameterOIDs) > 0 {
410
411		buf, err := c.buildOneRoundTripQueryEx(c.wbuf, sql, options, args)
412		if err != nil {
413			rows.fatal(err)
414			return rows, err
415		}
416
417		buf = appendSync(buf)
418
419		c.lastStmtSent = true
420		n, err := c.conn.Write(buf)
421		if err != nil && fatalWriteErr(n, err) {
422			rows.fatal(err)
423			c.die(err)
424			return rows, err
425		}
426		c.pendingReadyForQueryCount++
427
428		fieldDescriptions, err := c.readUntilRowDescription()
429		if err != nil {
430			rows.fatal(err)
431			return rows, err
432		}
433
434		if len(options.ResultFormatCodes) == 0 {
435			for i := range fieldDescriptions {
436				fieldDescriptions[i].FormatCode = TextFormatCode
437			}
438		} else if len(options.ResultFormatCodes) == 1 {
439			fc := options.ResultFormatCodes[0]
440			for i := range fieldDescriptions {
441				fieldDescriptions[i].FormatCode = fc
442			}
443		} else {
444			for i := range options.ResultFormatCodes {
445				fieldDescriptions[i].FormatCode = options.ResultFormatCodes[i]
446			}
447		}
448
449		rows.sql = sql
450		rows.fields = fieldDescriptions
451		return rows, nil
452	}
453
454	ps, ok := c.preparedStatements[sql]
455	if !ok {
456		var err error
457		ps, err = c.prepareEx("", sql, nil)
458		if err != nil {
459			rows.fatal(err)
460			return rows, rows.err
461		}
462	}
463	rows.sql = ps.SQL
464	rows.fields = ps.FieldDescriptions
465
466	c.lastStmtSent = true
467	err = c.sendPreparedQuery(ps, args...)
468	if err != nil {
469		rows.fatal(err)
470	}
471
472	return rows, rows.err
473}
474
475func (c *Conn) buildOneRoundTripQueryEx(buf []byte, sql string, options *QueryExOptions, arguments []interface{}) ([]byte, error) {
476	if len(arguments) != len(options.ParameterOIDs) {
477		return nil, errors.Errorf("mismatched number of arguments (%d) and options.ParameterOIDs (%d)", len(arguments), len(options.ParameterOIDs))
478	}
479
480	if len(options.ParameterOIDs) > 65535 {
481		return nil, errors.Errorf("Number of QueryExOptions ParameterOIDs must be between 0 and 65535, received %d", len(options.ParameterOIDs))
482	}
483
484	buf = appendParse(buf, "", sql, options.ParameterOIDs)
485	buf = appendDescribe(buf, 'S', "")
486	buf, err := appendBind(buf, "", "", c.ConnInfo, options.ParameterOIDs, arguments, options.ResultFormatCodes)
487	if err != nil {
488		return nil, err
489	}
490	buf = appendExecute(buf, "", 0)
491
492	return buf, nil
493}
494
495func (c *Conn) readUntilRowDescription() ([]FieldDescription, error) {
496	for {
497		msg, err := c.rxMsg()
498		if err != nil {
499			return nil, err
500		}
501
502		switch msg := msg.(type) {
503		case *pgproto3.ParameterDescription:
504		case *pgproto3.RowDescription:
505			fieldDescriptions := c.rxRowDescription(msg)
506			for i := range fieldDescriptions {
507				if dt, ok := c.ConnInfo.DataTypeForOID(fieldDescriptions[i].DataType); ok {
508					fieldDescriptions[i].DataTypeName = dt.Name
509				} else {
510					return nil, errors.Errorf("unknown oid: %d", fieldDescriptions[i].DataType)
511				}
512			}
513			return fieldDescriptions, nil
514		default:
515			if err := c.processContextFreeMsg(msg); err != nil {
516				return nil, err
517			}
518		}
519	}
520}
521
522func (c *Conn) sanitizeAndSendSimpleQuery(sql string, args ...interface{}) (err error) {
523	if c.RuntimeParams["standard_conforming_strings"] != "on" {
524		return errors.New("simple protocol queries must be run with standard_conforming_strings=on")
525	}
526
527	if c.RuntimeParams["client_encoding"] != "UTF8" {
528		return errors.New("simple protocol queries must be run with client_encoding=UTF8")
529	}
530
531	valueArgs := make([]interface{}, len(args))
532	for i, a := range args {
533		valueArgs[i], err = convertSimpleArgument(c.ConnInfo, a)
534		if err != nil {
535			return err
536		}
537	}
538
539	sql, err = sanitize.SanitizeSQL(sql, valueArgs...)
540	if err != nil {
541		return err
542	}
543
544	return c.sendSimpleQuery(sql)
545}
546
547func (c *Conn) QueryRowEx(ctx context.Context, sql string, options *QueryExOptions, args ...interface{}) *Row {
548	rows, _ := c.QueryEx(ctx, sql, options, args...)
549	return (*Row)(rows)
550}
551