1package pgx 2 3import ( 4 "context" 5 "database/sql" 6 "fmt" 7 "reflect" 8 "time" 9 10 "github.com/pkg/errors" 11 12 "github.com/jackc/pgx/internal/sanitize" 13 "github.com/jackc/pgx/pgproto3" 14 "github.com/jackc/pgx/pgtype" 15) 16 17// Row is a convenience wrapper over Rows that is returned by QueryRow. 18type Row Rows 19 20// Scan works the same as (*Rows Scan) with the following exceptions. If no 21// rows were found it returns ErrNoRows. If multiple rows are returned it 22// ignores all but the first. 23func (r *Row) Scan(dest ...interface{}) (err error) { 24 rows := (*Rows)(r) 25 26 if rows.Err() != nil { 27 return rows.Err() 28 } 29 30 if !rows.Next() { 31 if rows.Err() == nil { 32 return ErrNoRows 33 } 34 return rows.Err() 35 } 36 37 rows.Scan(dest...) 38 rows.Close() 39 return rows.Err() 40} 41 42// Rows is the result set returned from *Conn.Query. Rows must be closed before 43// the *Conn can be used again. Rows are closed by explicitly calling Close(), 44// calling Next() until it returns false, or when a fatal error occurs. 45type Rows struct { 46 conn *Conn 47 connPool *ConnPool 48 batch *Batch 49 values [][]byte 50 fields []FieldDescription 51 rowCount int 52 columnIdx int 53 err error 54 startTime time.Time 55 sql string 56 args []interface{} 57 unlockConn bool 58 closed bool 59} 60 61func (rows *Rows) FieldDescriptions() []FieldDescription { 62 return rows.fields 63} 64 65// Close closes the rows, making the connection ready for use again. It is safe 66// to call Close after rows is already closed. 67func (rows *Rows) Close() { 68 if rows.closed { 69 return 70 } 71 72 if rows.unlockConn { 73 rows.conn.unlock() 74 rows.unlockConn = false 75 } 76 77 rows.closed = true 78 79 rows.err = rows.conn.termContext(rows.err) 80 81 if rows.err == nil { 82 if rows.conn.shouldLog(LogLevelInfo) { 83 endTime := time.Now() 84 rows.conn.log(LogLevelInfo, "Query", map[string]interface{}{"sql": rows.sql, "args": logQueryArgs(rows.args), "time": endTime.Sub(rows.startTime), "rowCount": rows.rowCount}) 85 } 86 } else if rows.conn.shouldLog(LogLevelError) { 87 rows.conn.log(LogLevelError, "Query", map[string]interface{}{"sql": rows.sql, "args": logQueryArgs(rows.args)}) 88 } 89 90 if rows.batch != nil && rows.err != nil { 91 rows.batch.die(rows.err) 92 } 93 94 if rows.connPool != nil { 95 rows.connPool.Release(rows.conn) 96 } 97} 98 99func (rows *Rows) Err() error { 100 return rows.err 101} 102 103// fatal signals an error occurred after the query was sent to the server. It 104// closes the rows automatically. 105func (rows *Rows) fatal(err error) { 106 if rows.err != nil { 107 return 108 } 109 110 rows.err = err 111 rows.Close() 112} 113 114// Next prepares the next row for reading. It returns true if there is another 115// row and false if no more rows are available. It automatically closes rows 116// when all rows are read. 117func (rows *Rows) Next() bool { 118 if rows.closed { 119 return false 120 } 121 122 rows.rowCount++ 123 rows.columnIdx = 0 124 125 for { 126 msg, err := rows.conn.rxMsg() 127 if err != nil { 128 rows.fatal(err) 129 return false 130 } 131 132 switch msg := msg.(type) { 133 case *pgproto3.RowDescription: 134 rows.fields = rows.conn.rxRowDescription(msg) 135 for i := range rows.fields { 136 if dt, ok := rows.conn.ConnInfo.DataTypeForOID(rows.fields[i].DataType); ok { 137 rows.fields[i].DataTypeName = dt.Name 138 rows.fields[i].FormatCode = TextFormatCode 139 } else { 140 rows.fatal(errors.Errorf("unknown oid: %d", rows.fields[i].DataType)) 141 return false 142 } 143 } 144 case *pgproto3.DataRow: 145 if len(msg.Values) != len(rows.fields) { 146 rows.fatal(ProtocolError(fmt.Sprintf("Row description field count (%v) and data row field count (%v) do not match", len(rows.fields), len(msg.Values)))) 147 return false 148 } 149 150 rows.values = msg.Values 151 return true 152 case *pgproto3.CommandComplete: 153 if rows.batch != nil { 154 rows.batch.pendingCommandComplete = false 155 } 156 rows.Close() 157 return false 158 159 default: 160 err = rows.conn.processContextFreeMsg(msg) 161 if err != nil { 162 rows.fatal(err) 163 return false 164 } 165 } 166 } 167} 168 169func (rows *Rows) nextColumn() ([]byte, *FieldDescription, bool) { 170 if rows.closed { 171 return nil, nil, false 172 } 173 if len(rows.fields) <= rows.columnIdx { 174 rows.fatal(ProtocolError("No next column available")) 175 return nil, nil, false 176 } 177 178 buf := rows.values[rows.columnIdx] 179 fd := &rows.fields[rows.columnIdx] 180 rows.columnIdx++ 181 return buf, fd, true 182} 183 184type scanArgError struct { 185 col int 186 err error 187} 188 189func (e scanArgError) Error() string { 190 return fmt.Sprintf("can't scan into dest[%d]: %v", e.col, e.err) 191} 192 193// Scan reads the values from the current row into dest values positionally. 194// dest can include pointers to core types, values implementing the Scanner 195// interface, []byte, and nil. []byte will skip the decoding process and directly 196// copy the raw bytes received from PostgreSQL. nil will skip the value entirely. 197func (rows *Rows) Scan(dest ...interface{}) (err error) { 198 if len(rows.fields) != len(dest) { 199 err = errors.Errorf("Scan received wrong number of arguments, got %d but expected %d", len(dest), len(rows.fields)) 200 rows.fatal(err) 201 return err 202 } 203 204 for i, d := range dest { 205 buf, fd, _ := rows.nextColumn() 206 207 if d == nil { 208 continue 209 } 210 211 if s, ok := d.(pgtype.BinaryDecoder); ok && fd.FormatCode == BinaryFormatCode { 212 err = s.DecodeBinary(rows.conn.ConnInfo, buf) 213 if err != nil { 214 rows.fatal(scanArgError{col: i, err: err}) 215 } 216 } else if s, ok := d.(pgtype.TextDecoder); ok && fd.FormatCode == TextFormatCode { 217 err = s.DecodeText(rows.conn.ConnInfo, buf) 218 if err != nil { 219 rows.fatal(scanArgError{col: i, err: err}) 220 } 221 } else { 222 if dt, ok := rows.conn.ConnInfo.DataTypeForOID(fd.DataType); ok { 223 value := dt.Value 224 switch fd.FormatCode { 225 case TextFormatCode: 226 if textDecoder, ok := value.(pgtype.TextDecoder); ok { 227 err = textDecoder.DecodeText(rows.conn.ConnInfo, buf) 228 if err != nil { 229 rows.fatal(scanArgError{col: i, err: err}) 230 } 231 } else { 232 rows.fatal(scanArgError{col: i, err: errors.Errorf("%T is not a pgtype.TextDecoder", value)}) 233 } 234 case BinaryFormatCode: 235 if binaryDecoder, ok := value.(pgtype.BinaryDecoder); ok { 236 err = binaryDecoder.DecodeBinary(rows.conn.ConnInfo, buf) 237 if err != nil { 238 rows.fatal(scanArgError{col: i, err: err}) 239 } 240 } else { 241 rows.fatal(scanArgError{col: i, err: errors.Errorf("%T is not a pgtype.BinaryDecoder", value)}) 242 } 243 default: 244 rows.fatal(scanArgError{col: i, err: errors.Errorf("unknown format code: %v", fd.FormatCode)}) 245 } 246 247 if rows.Err() == nil { 248 if scanner, ok := d.(sql.Scanner); ok { 249 sqlSrc, err := pgtype.DatabaseSQLValue(rows.conn.ConnInfo, value) 250 if err != nil { 251 rows.fatal(err) 252 } 253 err = scanner.Scan(sqlSrc) 254 if err != nil { 255 rows.fatal(scanArgError{col: i, err: err}) 256 } 257 } else if err := value.AssignTo(d); err != nil { 258 rows.fatal(scanArgError{col: i, err: err}) 259 } 260 } 261 } else { 262 rows.fatal(scanArgError{col: i, err: errors.Errorf("unknown oid: %v", fd.DataType)}) 263 } 264 } 265 266 if rows.Err() != nil { 267 return rows.Err() 268 } 269 } 270 271 return nil 272} 273 274// Values returns an array of the row values 275func (rows *Rows) Values() ([]interface{}, error) { 276 if rows.closed { 277 return nil, errors.New("rows is closed") 278 } 279 280 values := make([]interface{}, 0, len(rows.fields)) 281 282 for range rows.fields { 283 buf, fd, _ := rows.nextColumn() 284 285 if buf == nil { 286 values = append(values, nil) 287 continue 288 } 289 290 if dt, ok := rows.conn.ConnInfo.DataTypeForOID(fd.DataType); ok { 291 value := reflect.New(reflect.ValueOf(dt.Value).Elem().Type()).Interface().(pgtype.Value) 292 293 switch fd.FormatCode { 294 case TextFormatCode: 295 decoder := value.(pgtype.TextDecoder) 296 if decoder == nil { 297 decoder = &pgtype.GenericText{} 298 } 299 err := decoder.DecodeText(rows.conn.ConnInfo, buf) 300 if err != nil { 301 rows.fatal(err) 302 } 303 values = append(values, decoder.(pgtype.Value).Get()) 304 case BinaryFormatCode: 305 decoder := value.(pgtype.BinaryDecoder) 306 if decoder == nil { 307 decoder = &pgtype.GenericBinary{} 308 } 309 err := decoder.DecodeBinary(rows.conn.ConnInfo, buf) 310 if err != nil { 311 rows.fatal(err) 312 } 313 values = append(values, value.Get()) 314 default: 315 rows.fatal(errors.New("Unknown format code")) 316 } 317 } else { 318 rows.fatal(errors.New("Unknown type")) 319 } 320 321 if rows.Err() != nil { 322 return nil, rows.Err() 323 } 324 } 325 326 return values, rows.Err() 327} 328 329// Query executes sql with args. If there is an error the returned *Rows will 330// be returned in an error state. So it is allowed to ignore the error returned 331// from Query and handle it in *Rows. 332func (c *Conn) Query(sql string, args ...interface{}) (*Rows, error) { 333 return c.QueryEx(context.Background(), sql, nil, args...) 334} 335 336func (c *Conn) getRows(sql string, args []interface{}) *Rows { 337 if len(c.preallocatedRows) == 0 { 338 c.preallocatedRows = make([]Rows, 64) 339 } 340 341 r := &c.preallocatedRows[len(c.preallocatedRows)-1] 342 c.preallocatedRows = c.preallocatedRows[0 : len(c.preallocatedRows)-1] 343 344 r.conn = c 345 r.startTime = c.lastActivityTime 346 r.sql = sql 347 r.args = args 348 349 return r 350} 351 352// QueryRow is a convenience wrapper over Query. Any error that occurs while 353// querying is deferred until calling Scan on the returned *Row. That *Row will 354// error with ErrNoRows if no rows are returned. 355func (c *Conn) QueryRow(sql string, args ...interface{}) *Row { 356 rows, _ := c.Query(sql, args...) 357 return (*Row)(rows) 358} 359 360type QueryExOptions struct { 361 // When ParameterOIDs are present and the query is not a prepared statement, 362 // then ParameterOIDs and ResultFormatCodes will be used to avoid an extra 363 // network round-trip. 364 ParameterOIDs []pgtype.OID 365 ResultFormatCodes []int16 366 367 SimpleProtocol bool 368} 369 370func (c *Conn) QueryEx(ctx context.Context, sql string, options *QueryExOptions, args ...interface{}) (rows *Rows, err error) { 371 c.lastActivityTime = time.Now() 372 rows = c.getRows(sql, args) 373 374 err = c.waitForPreviousCancelQuery(ctx) 375 if err != nil { 376 rows.fatal(err) 377 return rows, err 378 } 379 380 if err := c.ensureConnectionReadyForQuery(); err != nil { 381 rows.fatal(err) 382 return rows, err 383 } 384 385 if err := c.lock(); err != nil { 386 rows.fatal(err) 387 return rows, err 388 } 389 rows.unlockConn = true 390 391 err = c.initContext(ctx) 392 if err != nil { 393 rows.fatal(err) 394 return rows, rows.err 395 } 396 397 if (options == nil && c.config.PreferSimpleProtocol) || (options != nil && options.SimpleProtocol) { 398 err = c.sanitizeAndSendSimpleQuery(sql, args...) 399 if err != nil { 400 rows.fatal(err) 401 return rows, err 402 } 403 404 return rows, nil 405 } 406 407 if options != nil && len(options.ParameterOIDs) > 0 { 408 409 buf, err := c.buildOneRoundTripQueryEx(c.wbuf, sql, options, args) 410 if err != nil { 411 rows.fatal(err) 412 return rows, err 413 } 414 415 buf = appendSync(buf) 416 417 n, err := c.conn.Write(buf) 418 if err != nil && fatalWriteErr(n, err) { 419 rows.fatal(err) 420 c.die(err) 421 return rows, err 422 } 423 c.pendingReadyForQueryCount++ 424 425 fieldDescriptions, err := c.readUntilRowDescription() 426 if err != nil { 427 rows.fatal(err) 428 return rows, err 429 } 430 431 if len(options.ResultFormatCodes) == 0 { 432 for i := range fieldDescriptions { 433 fieldDescriptions[i].FormatCode = TextFormatCode 434 } 435 } else if len(options.ResultFormatCodes) == 1 { 436 fc := options.ResultFormatCodes[0] 437 for i := range fieldDescriptions { 438 fieldDescriptions[i].FormatCode = fc 439 } 440 } else { 441 for i := range options.ResultFormatCodes { 442 fieldDescriptions[i].FormatCode = options.ResultFormatCodes[i] 443 } 444 } 445 446 rows.sql = sql 447 rows.fields = fieldDescriptions 448 return rows, nil 449 } 450 451 ps, ok := c.preparedStatements[sql] 452 if !ok { 453 var err error 454 ps, err = c.prepareEx("", sql, nil) 455 if err != nil { 456 rows.fatal(err) 457 return rows, rows.err 458 } 459 } 460 rows.sql = ps.SQL 461 rows.fields = ps.FieldDescriptions 462 463 err = c.sendPreparedQuery(ps, args...) 464 if err != nil { 465 rows.fatal(err) 466 } 467 468 return rows, rows.err 469} 470 471func (c *Conn) buildOneRoundTripQueryEx(buf []byte, sql string, options *QueryExOptions, arguments []interface{}) ([]byte, error) { 472 if len(arguments) != len(options.ParameterOIDs) { 473 return nil, errors.Errorf("mismatched number of arguments (%d) and options.ParameterOIDs (%d)", len(arguments), len(options.ParameterOIDs)) 474 } 475 476 if len(options.ParameterOIDs) > 65535 { 477 return nil, errors.Errorf("Number of QueryExOptions ParameterOIDs must be between 0 and 65535, received %d", len(options.ParameterOIDs)) 478 } 479 480 buf = appendParse(buf, "", sql, options.ParameterOIDs) 481 buf = appendDescribe(buf, 'S', "") 482 buf, err := appendBind(buf, "", "", c.ConnInfo, options.ParameterOIDs, arguments, options.ResultFormatCodes) 483 if err != nil { 484 return nil, err 485 } 486 buf = appendExecute(buf, "", 0) 487 488 return buf, nil 489} 490 491func (c *Conn) readUntilRowDescription() ([]FieldDescription, error) { 492 for { 493 msg, err := c.rxMsg() 494 if err != nil { 495 return nil, err 496 } 497 498 switch msg := msg.(type) { 499 case *pgproto3.ParameterDescription: 500 case *pgproto3.RowDescription: 501 fieldDescriptions := c.rxRowDescription(msg) 502 for i := range fieldDescriptions { 503 if dt, ok := c.ConnInfo.DataTypeForOID(fieldDescriptions[i].DataType); ok { 504 fieldDescriptions[i].DataTypeName = dt.Name 505 } else { 506 return nil, errors.Errorf("unknown oid: %d", fieldDescriptions[i].DataType) 507 } 508 } 509 return fieldDescriptions, nil 510 default: 511 if err := c.processContextFreeMsg(msg); err != nil { 512 return nil, err 513 } 514 } 515 } 516} 517 518func (c *Conn) sanitizeAndSendSimpleQuery(sql string, args ...interface{}) (err error) { 519 if c.RuntimeParams["standard_conforming_strings"] != "on" { 520 return errors.New("simple protocol queries must be run with standard_conforming_strings=on") 521 } 522 523 if c.RuntimeParams["client_encoding"] != "UTF8" { 524 return errors.New("simple protocol queries must be run with client_encoding=UTF8") 525 } 526 527 valueArgs := make([]interface{}, len(args)) 528 for i, a := range args { 529 valueArgs[i], err = convertSimpleArgument(c.ConnInfo, a) 530 if err != nil { 531 return err 532 } 533 } 534 535 sql, err = sanitize.SanitizeSQL(sql, valueArgs...) 536 if err != nil { 537 return err 538 } 539 540 return c.sendSimpleQuery(sql) 541} 542 543func (c *Conn) QueryRowEx(ctx context.Context, sql string, options *QueryExOptions, args ...interface{}) *Row { 544 rows, _ := c.QueryEx(ctx, sql, options, args...) 545 return (*Row)(rows) 546} 547