1package pgx
2
3import (
4	"context"
5	"fmt"
6	"time"
7
8	errors "golang.org/x/xerrors"
9
10	"github.com/jackc/pgconn"
11	"github.com/jackc/pgproto3/v2"
12	"github.com/jackc/pgtype"
13)
14
15// Rows is the result set returned from *Conn.Query. Rows must be closed before
16// the *Conn can be used again. Rows are closed by explicitly calling Close(),
17// calling Next() until it returns false, or when a fatal error occurs.
18//
19// Once a Rows is closed the only methods that may be called are Close(), Err(), and CommandTag().
20//
21// Rows is an interface instead of a struct to allow tests to mock Query. However,
22// adding a method to an interface is technically a breaking change. Because of this
23// the Rows interface is partially excluded from semantic version requirements.
24// Methods will not be removed or changed, but new methods may be added.
25type Rows interface {
26	// Close closes the rows, making the connection ready for use again. It is safe
27	// to call Close after rows is already closed.
28	Close()
29
30	// Err returns any error that occurred while reading.
31	Err() error
32
33	// CommandTag returns the command tag from this query. It is only available after Rows is closed.
34	CommandTag() pgconn.CommandTag
35
36	FieldDescriptions() []pgproto3.FieldDescription
37
38	// Next prepares the next row for reading. It returns true if there is another
39	// row and false if no more rows are available. It automatically closes rows
40	// when all rows are read.
41	Next() bool
42
43	// Scan reads the values from the current row into dest values positionally.
44	// dest can include pointers to core types, values implementing the Scanner
45	// interface, []byte, and nil. []byte will skip the decoding process and directly
46	// copy the raw bytes received from PostgreSQL. nil will skip the value entirely.
47	Scan(dest ...interface{}) error
48
49	// Values returns the decoded row values.
50	Values() ([]interface{}, error)
51
52	// RawValues returns the unparsed bytes of the row values. The returned [][]byte is only valid until the next Next
53	// call or the Rows is closed. However, the underlying byte data is safe to retain a reference to and mutate.
54	RawValues() [][]byte
55}
56
// Row is a convenience wrapper over Rows that is returned by QueryRow.
//
// Row is an interface instead of a struct to allow tests to mock QueryRow. However,
// adding a method to an interface is technically a breaking change. Because of this
// the Row interface is partially excluded from semantic version requirements.
// Methods will not be removed or changed, but new methods may be added.
type Row interface {
	// Scan works the same as Rows.Scan with the following exceptions. If no
	// rows were found it returns ErrNoRows. If multiple rows are returned it
	// ignores all but the first.
	Scan(dest ...interface{}) error
}
69
70// connRow implements the Row interface for Conn.QueryRow.
71type connRow connRows
72
73func (r *connRow) Scan(dest ...interface{}) (err error) {
74	rows := (*connRows)(r)
75
76	if rows.Err() != nil {
77		return rows.Err()
78	}
79
80	if !rows.Next() {
81		if rows.Err() == nil {
82			return ErrNoRows
83		}
84		return rows.Err()
85	}
86
87	rows.Scan(dest...)
88	rows.Close()
89	return rows.Err()
90}
91
92type rowLog interface {
93	shouldLog(lvl LogLevel) bool
94	log(ctx context.Context, lvl LogLevel, msg string, data map[string]interface{})
95}
96
// connRows implements the Rows interface for Conn.Query.
//
// Field order is left unchanged in case composite literals elsewhere in the
// package depend on it.
type connRows struct {
	// ctx is passed to logger.log when Close records the query outcome.
	ctx        context.Context
	// logger, when non-nil, receives a log entry from Close.
	logger     rowLog
	// connInfo maps type OIDs to pgtype types for Scan and Values.
	connInfo   *pgtype.ConnInfo
	// values holds the raw column bytes of the current row; Next sets it.
	values     [][]byte
	// rowCount counts rows returned by Next; it is included in the success log entry.
	rowCount   int
	// err is the first error recorded by fatal or Close; neither overwrites an existing error.
	err        error
	// commandTag is set from resultReader.Close during Close.
	commandTag pgconn.CommandTag
	// startTime is used by Close to compute the elapsed time it logs.
	startTime  time.Time
	// sql and args are included in Close's log entries.
	sql        string
	args       []interface{}
	// closed is set by Close; Next and Values check it.
	closed     bool

	// resultReader supplies field descriptions and rows; Close checks it for nil.
	resultReader      *pgconn.ResultReader
	// multiResultReader, when non-nil, is closed by Close after resultReader.
	multiResultReader *pgconn.MultiResultReader

	// scanPlans caches one pgtype.ScanPlan per column index for reuse by Scan across rows.
	scanPlans []pgtype.ScanPlan
}
116
117func (rows *connRows) FieldDescriptions() []pgproto3.FieldDescription {
118	return rows.resultReader.FieldDescriptions()
119}
120
121func (rows *connRows) Close() {
122	if rows.closed {
123		return
124	}
125
126	rows.closed = true
127
128	if rows.resultReader != nil {
129		var closeErr error
130		rows.commandTag, closeErr = rows.resultReader.Close()
131		if rows.err == nil {
132			rows.err = closeErr
133		}
134	}
135
136	if rows.multiResultReader != nil {
137		closeErr := rows.multiResultReader.Close()
138		if rows.err == nil {
139			rows.err = closeErr
140		}
141	}
142
143	if rows.logger != nil {
144		if rows.err == nil {
145			if rows.logger.shouldLog(LogLevelInfo) {
146				endTime := time.Now()
147				rows.logger.log(rows.ctx, LogLevelInfo, "Query", map[string]interface{}{"sql": rows.sql, "args": logQueryArgs(rows.args), "time": endTime.Sub(rows.startTime), "rowCount": rows.rowCount})
148			}
149		} else if rows.logger.shouldLog(LogLevelError) {
150			rows.logger.log(rows.ctx, LogLevelError, "Query", map[string]interface{}{"err": rows.err, "sql": rows.sql, "args": logQueryArgs(rows.args)})
151		}
152	}
153}
154
155func (rows *connRows) CommandTag() pgconn.CommandTag {
156	return rows.commandTag
157}
158
159func (rows *connRows) Err() error {
160	return rows.err
161}
162
163// fatal signals an error occurred after the query was sent to the server. It
164// closes the rows automatically.
165func (rows *connRows) fatal(err error) {
166	if rows.err != nil {
167		return
168	}
169
170	rows.err = err
171	rows.Close()
172}
173
174func (rows *connRows) Next() bool {
175	if rows.closed {
176		return false
177	}
178
179	if rows.resultReader.NextRow() {
180		rows.rowCount++
181		rows.values = rows.resultReader.Values()
182		return true
183	} else {
184		rows.Close()
185		return false
186	}
187}
188
189func (rows *connRows) Scan(dest ...interface{}) error {
190	ci := rows.connInfo
191	fieldDescriptions := rows.FieldDescriptions()
192	values := rows.values
193
194	if len(fieldDescriptions) != len(values) {
195		err := errors.Errorf("number of field descriptions must equal number of values, got %d and %d", len(fieldDescriptions), len(values))
196		rows.fatal(err)
197		return err
198	}
199	if len(fieldDescriptions) != len(dest) {
200		err := errors.Errorf("number of field descriptions must equal number of destinations, got %d and %d", len(fieldDescriptions), len(dest))
201		rows.fatal(err)
202		return err
203	}
204
205	if rows.scanPlans == nil {
206		rows.scanPlans = make([]pgtype.ScanPlan, len(values))
207		for i, dst := range dest {
208			if dst == nil {
209				continue
210			}
211			rows.scanPlans[i] = ci.PlanScan(fieldDescriptions[i].DataTypeOID, fieldDescriptions[i].Format, dest[i])
212		}
213	}
214
215	for i, dst := range dest {
216		if dst == nil {
217			continue
218		}
219
220		err := rows.scanPlans[i].Scan(ci, fieldDescriptions[i].DataTypeOID, fieldDescriptions[i].Format, values[i], dst)
221		if err != nil {
222			err = scanArgError{col: i, err: err}
223			rows.fatal(err)
224			return err
225		}
226	}
227
228	return nil
229}
230
231func (rows *connRows) Values() ([]interface{}, error) {
232	if rows.closed {
233		return nil, errors.New("rows is closed")
234	}
235
236	values := make([]interface{}, 0, len(rows.FieldDescriptions()))
237
238	for i := range rows.FieldDescriptions() {
239		buf := rows.values[i]
240		fd := &rows.FieldDescriptions()[i]
241
242		if buf == nil {
243			values = append(values, nil)
244			continue
245		}
246
247		if dt, ok := rows.connInfo.DataTypeForOID(fd.DataTypeOID); ok {
248			value := dt.Value
249
250			switch fd.Format {
251			case TextFormatCode:
252				decoder, ok := value.(pgtype.TextDecoder)
253				if !ok {
254					decoder = &pgtype.GenericText{}
255				}
256				err := decoder.DecodeText(rows.connInfo, buf)
257				if err != nil {
258					rows.fatal(err)
259				}
260				values = append(values, decoder.(pgtype.Value).Get())
261			case BinaryFormatCode:
262				decoder, ok := value.(pgtype.BinaryDecoder)
263				if !ok {
264					decoder = &pgtype.GenericBinary{}
265				}
266				err := decoder.DecodeBinary(rows.connInfo, buf)
267				if err != nil {
268					rows.fatal(err)
269				}
270				values = append(values, value.Get())
271			default:
272				rows.fatal(errors.New("Unknown format code"))
273			}
274		} else {
275			switch fd.Format {
276			case TextFormatCode:
277				decoder := &pgtype.GenericText{}
278				err := decoder.DecodeText(rows.connInfo, buf)
279				if err != nil {
280					rows.fatal(err)
281				}
282				values = append(values, decoder.Get())
283			case BinaryFormatCode:
284				decoder := &pgtype.GenericBinary{}
285				err := decoder.DecodeBinary(rows.connInfo, buf)
286				if err != nil {
287					rows.fatal(err)
288				}
289				values = append(values, decoder.Get())
290			default:
291				rows.fatal(errors.New("Unknown format code"))
292			}
293		}
294
295		if rows.Err() != nil {
296			return nil, rows.Err()
297		}
298	}
299
300	return values, rows.Err()
301}
302
// RawValues returns the raw column bytes of the current row, as stored by the
// most recent call to Next that returned true. The slice held by rows is
// returned directly, not copied.
func (rows *connRows) RawValues() [][]byte {
	return rows.values
}
306
// scanArgError reports that the value of column col could not be scanned into
// its destination. It wraps the underlying error, which Unwrap exposes so
// errors.Is / errors.As (standard library or golang.org/x/xerrors) can reach it.
type scanArgError struct {
	col int
	err error
}

// Error formats the failure with the destination index and the cause.
func (e scanArgError) Error() string {
	return fmt.Sprintf("can't scan into dest[%d]: %v", e.col, e.err)
}

// Unwrap returns the underlying scan error.
func (e scanArgError) Unwrap() error {
	return e.err
}
315
316// ScanRow decodes raw row data into dest. It can be used to scan rows read from the lower level pgconn interface.
317//
318// connInfo - OID to Go type mapping.
319// fieldDescriptions - OID and format of values
320// values - the raw data as returned from the PostgreSQL server
321// dest - the destination that values will be decoded into
322func ScanRow(connInfo *pgtype.ConnInfo, fieldDescriptions []pgproto3.FieldDescription, values [][]byte, dest ...interface{}) error {
323	if len(fieldDescriptions) != len(values) {
324		return errors.Errorf("number of field descriptions must equal number of values, got %d and %d", len(fieldDescriptions), len(values))
325	}
326	if len(fieldDescriptions) != len(dest) {
327		return errors.Errorf("number of field descriptions must equal number of destinations, got %d and %d", len(fieldDescriptions), len(dest))
328	}
329
330	for i, d := range dest {
331		if d == nil {
332			continue
333		}
334
335		err := connInfo.Scan(fieldDescriptions[i].DataTypeOID, fieldDescriptions[i].Format, values[i], d)
336		if err != nil {
337			return scanArgError{col: i, err: err}
338		}
339	}
340
341	return nil
342}
343