1package pgx 2 3import ( 4 "context" 5 "database/sql" 6 "fmt" 7 "reflect" 8 "time" 9 10 "github.com/pkg/errors" 11 12 "github.com/jackc/pgx/internal/sanitize" 13 "github.com/jackc/pgx/pgproto3" 14 "github.com/jackc/pgx/pgtype" 15) 16 17// Row is a convenience wrapper over Rows that is returned by QueryRow. 18type Row Rows 19 20// Scan works the same as (*Rows Scan) with the following exceptions. If no 21// rows were found it returns ErrNoRows. If multiple rows are returned it 22// ignores all but the first. 23func (r *Row) Scan(dest ...interface{}) (err error) { 24 rows := (*Rows)(r) 25 26 if rows.Err() != nil { 27 return rows.Err() 28 } 29 30 if !rows.Next() { 31 if rows.Err() == nil { 32 return ErrNoRows 33 } 34 return rows.Err() 35 } 36 37 rows.Scan(dest...) 38 rows.Close() 39 return rows.Err() 40} 41 42// Rows is the result set returned from *Conn.Query. Rows must be closed before 43// the *Conn can be used again. Rows are closed by explicitly calling Close(), 44// calling Next() until it returns false, or when a fatal error occurs. 45type Rows struct { 46 conn *Conn 47 connPool *ConnPool 48 batch *Batch 49 values [][]byte 50 fields []FieldDescription 51 rowCount int 52 columnIdx int 53 err error 54 startTime time.Time 55 sql string 56 args []interface{} 57 unlockConn bool 58 closed bool 59} 60 61func (rows *Rows) FieldDescriptions() []FieldDescription { 62 return rows.fields 63} 64 65// Close closes the rows, making the connection ready for use again. It is safe 66// to call Close after rows is already closed. 67func (rows *Rows) Close() { 68 if rows.closed { 69 return 70 } 71 72 if rows.unlockConn { 73 rows.conn.unlock() 74 rows.unlockConn = false 75 } 76 77 rows.closed = true 78 79 rows.err = rows.conn.termContext(rows.err) 80 81 if rows.err == nil { 82 if rows.conn.shouldLog(LogLevelInfo) { 83 endTime := time.Now() 84 rows.conn.log(LogLevelInfo, "Query", map[string]interface{}{"sql": rows.sql, "args": logQueryArgs(rows.args), "time": endTime.Sub(rows.startTime), "rowCount": rows.rowCount}) 85 } 86 } else if rows.conn.shouldLog(LogLevelError) { 87 rows.conn.log(LogLevelError, "Query", map[string]interface{}{"sql": rows.sql, "args": logQueryArgs(rows.args)}) 88 } 89 90 if rows.batch != nil && rows.err != nil { 91 rows.batch.die(rows.err) 92 } 93 94 if rows.connPool != nil { 95 rows.connPool.Release(rows.conn) 96 } 97} 98 99func (rows *Rows) Err() error { 100 return rows.err 101} 102 103// fatal signals an error occurred after the query was sent to the server. It 104// closes the rows automatically. 105func (rows *Rows) fatal(err error) { 106 if rows.err != nil { 107 return 108 } 109 110 rows.err = err 111 rows.Close() 112} 113 114// Next prepares the next row for reading. It returns true if there is another 115// row and false if no more rows are available. It automatically closes rows 116// when all rows are read. 117func (rows *Rows) Next() bool { 118 if rows.closed { 119 return false 120 } 121 122 rows.rowCount++ 123 rows.columnIdx = 0 124 125 for { 126 msg, err := rows.conn.rxMsg() 127 if err != nil { 128 rows.fatal(err) 129 return false 130 } 131 132 switch msg := msg.(type) { 133 case *pgproto3.RowDescription: 134 rows.fields = rows.conn.rxRowDescription(msg) 135 for i := range rows.fields { 136 if dt, ok := rows.conn.ConnInfo.DataTypeForOID(rows.fields[i].DataType); ok { 137 rows.fields[i].DataTypeName = dt.Name 138 rows.fields[i].FormatCode = TextFormatCode 139 } else { 140 rows.fatal(errors.Errorf("unknown oid: %d", rows.fields[i].DataType)) 141 return false 142 } 143 } 144 case *pgproto3.DataRow: 145 if len(msg.Values) != len(rows.fields) { 146 rows.fatal(ProtocolError(fmt.Sprintf("Row description field count (%v) and data row field count (%v) do not match", len(rows.fields), len(msg.Values)))) 147 return false 148 } 149 150 rows.values = msg.Values 151 return true 152 case *pgproto3.CommandComplete: 153 if rows.batch != nil { 154 rows.batch.pendingCommandComplete = false 155 } 156 rows.Close() 157 return false 158 159 default: 160 err = rows.conn.processContextFreeMsg(msg) 161 if err != nil { 162 rows.fatal(err) 163 return false 164 } 165 } 166 } 167} 168 169func (rows *Rows) nextColumn() ([]byte, *FieldDescription, bool) { 170 if rows.closed { 171 return nil, nil, false 172 } 173 if len(rows.fields) <= rows.columnIdx { 174 rows.fatal(ProtocolError("No next column available")) 175 return nil, nil, false 176 } 177 178 buf := rows.values[rows.columnIdx] 179 fd := &rows.fields[rows.columnIdx] 180 rows.columnIdx++ 181 return buf, fd, true 182} 183 184type scanArgError struct { 185 col int 186 err error 187} 188 189func (e scanArgError) Error() string { 190 return fmt.Sprintf("can't scan into dest[%d]: %v", e.col, e.err) 191} 192 193// Scan reads the values from the current row into dest values positionally. 194// dest can include pointers to core types, values implementing the Scanner 195// interface, []byte, and nil. []byte will skip the decoding process and directly 196// copy the raw bytes received from PostgreSQL. nil will skip the value entirely. 197func (rows *Rows) Scan(dest ...interface{}) (err error) { 198 if len(rows.fields) != len(dest) { 199 err = errors.Errorf("Scan received wrong number of arguments, got %d but expected %d", len(dest), len(rows.fields)) 200 rows.fatal(err) 201 return err 202 } 203 204 for i, d := range dest { 205 buf, fd, _ := rows.nextColumn() 206 207 if d == nil { 208 continue 209 } 210 211 if s, ok := d.(pgtype.BinaryDecoder); ok && fd.FormatCode == BinaryFormatCode { 212 err = s.DecodeBinary(rows.conn.ConnInfo, buf) 213 if err != nil { 214 rows.fatal(scanArgError{col: i, err: err}) 215 } 216 } else if s, ok := d.(pgtype.TextDecoder); ok && fd.FormatCode == TextFormatCode { 217 err = s.DecodeText(rows.conn.ConnInfo, buf) 218 if err != nil { 219 rows.fatal(scanArgError{col: i, err: err}) 220 } 221 } else { 222 if dt, ok := rows.conn.ConnInfo.DataTypeForOID(fd.DataType); ok { 223 value := dt.Value 224 switch fd.FormatCode { 225 case TextFormatCode: 226 if textDecoder, ok := value.(pgtype.TextDecoder); ok { 227 err = textDecoder.DecodeText(rows.conn.ConnInfo, buf) 228 if err != nil { 229 rows.fatal(scanArgError{col: i, err: err}) 230 } 231 } else { 232 rows.fatal(scanArgError{col: i, err: errors.Errorf("%T is not a pgtype.TextDecoder", value)}) 233 } 234 case BinaryFormatCode: 235 if binaryDecoder, ok := value.(pgtype.BinaryDecoder); ok { 236 err = binaryDecoder.DecodeBinary(rows.conn.ConnInfo, buf) 237 if err != nil { 238 rows.fatal(scanArgError{col: i, err: err}) 239 } 240 } else { 241 rows.fatal(scanArgError{col: i, err: errors.Errorf("%T is not a pgtype.BinaryDecoder", value)}) 242 } 243 default: 244 rows.fatal(scanArgError{col: i, err: errors.Errorf("unknown format code: %v", fd.FormatCode)}) 245 } 246 247 if rows.Err() == nil { 248 if scanner, ok := d.(sql.Scanner); ok { 249 sqlSrc, err := pgtype.DatabaseSQLValue(rows.conn.ConnInfo, value) 250 if err != nil { 251 rows.fatal(err) 252 } 253 err = scanner.Scan(sqlSrc) 254 if err != nil { 255 rows.fatal(scanArgError{col: i, err: err}) 256 } 257 } else if err := value.AssignTo(d); err != nil { 258 rows.fatal(scanArgError{col: i, err: err}) 259 } 260 } 261 } else { 262 rows.fatal(scanArgError{col: i, err: errors.Errorf("unknown oid: %v", fd.DataType)}) 263 } 264 } 265 266 if rows.Err() != nil { 267 return rows.Err() 268 } 269 } 270 271 return nil 272} 273 274// Values returns an array of the row values 275func (rows *Rows) Values() ([]interface{}, error) { 276 if rows.closed { 277 return nil, errors.New("rows is closed") 278 } 279 280 values := make([]interface{}, 0, len(rows.fields)) 281 282 for range rows.fields { 283 buf, fd, _ := rows.nextColumn() 284 285 if buf == nil { 286 values = append(values, nil) 287 continue 288 } 289 290 if dt, ok := rows.conn.ConnInfo.DataTypeForOID(fd.DataType); ok { 291 value := reflect.New(reflect.ValueOf(dt.Value).Elem().Type()).Interface().(pgtype.Value) 292 293 switch fd.FormatCode { 294 case TextFormatCode: 295 decoder := value.(pgtype.TextDecoder) 296 if decoder == nil { 297 decoder = &pgtype.GenericText{} 298 } 299 err := decoder.DecodeText(rows.conn.ConnInfo, buf) 300 if err != nil { 301 rows.fatal(err) 302 } 303 values = append(values, decoder.(pgtype.Value).Get()) 304 case BinaryFormatCode: 305 decoder := value.(pgtype.BinaryDecoder) 306 if decoder == nil { 307 decoder = &pgtype.GenericBinary{} 308 } 309 err := decoder.DecodeBinary(rows.conn.ConnInfo, buf) 310 if err != nil { 311 rows.fatal(err) 312 } 313 values = append(values, value.Get()) 314 default: 315 rows.fatal(errors.New("Unknown format code")) 316 } 317 } else { 318 rows.fatal(errors.New("Unknown type")) 319 } 320 321 if rows.Err() != nil { 322 return nil, rows.Err() 323 } 324 } 325 326 return values, rows.Err() 327} 328 329// Query executes sql with args. If there is an error the returned *Rows will 330// be returned in an error state. So it is allowed to ignore the error returned 331// from Query and handle it in *Rows. 332func (c *Conn) Query(sql string, args ...interface{}) (*Rows, error) { 333 return c.QueryEx(context.Background(), sql, nil, args...) 334} 335 336func (c *Conn) getRows(sql string, args []interface{}) *Rows { 337 if len(c.preallocatedRows) == 0 { 338 c.preallocatedRows = make([]Rows, 64) 339 } 340 341 r := &c.preallocatedRows[len(c.preallocatedRows)-1] 342 c.preallocatedRows = c.preallocatedRows[0 : len(c.preallocatedRows)-1] 343 344 r.conn = c 345 r.startTime = c.lastActivityTime 346 r.sql = sql 347 r.args = args 348 349 return r 350} 351 352// QueryRow is a convenience wrapper over Query. Any error that occurs while 353// querying is deferred until calling Scan on the returned *Row. That *Row will 354// error with ErrNoRows if no rows are returned. 355func (c *Conn) QueryRow(sql string, args ...interface{}) *Row { 356 rows, _ := c.Query(sql, args...) 357 return (*Row)(rows) 358} 359 360type QueryExOptions struct { 361 // When ParameterOIDs are present and the query is not a prepared statement, 362 // then ParameterOIDs and ResultFormatCodes will be used to avoid an extra 363 // network round-trip. 364 ParameterOIDs []pgtype.OID 365 ResultFormatCodes []int16 366 367 SimpleProtocol bool 368} 369 370func (c *Conn) QueryEx(ctx context.Context, sql string, options *QueryExOptions, args ...interface{}) (rows *Rows, err error) { 371 c.lastStmtSent = false 372 c.lastActivityTime = time.Now() 373 rows = c.getRows(sql, args) 374 375 err = c.waitForPreviousCancelQuery(ctx) 376 if err != nil { 377 rows.fatal(err) 378 return rows, err 379 } 380 381 if err := c.ensureConnectionReadyForQuery(); err != nil { 382 rows.fatal(err) 383 return rows, err 384 } 385 386 if err := c.lock(); err != nil { 387 rows.fatal(err) 388 return rows, err 389 } 390 rows.unlockConn = true 391 392 err = c.initContext(ctx) 393 if err != nil { 394 rows.fatal(err) 395 return rows, rows.err 396 } 397 398 if (options == nil && c.config.PreferSimpleProtocol) || (options != nil && options.SimpleProtocol) { 399 c.lastStmtSent = true 400 err = c.sanitizeAndSendSimpleQuery(sql, args...) 401 if err != nil { 402 rows.fatal(err) 403 return rows, err 404 } 405 406 return rows, nil 407 } 408 409 if options != nil && len(options.ParameterOIDs) > 0 { 410 411 buf, err := c.buildOneRoundTripQueryEx(c.wbuf, sql, options, args) 412 if err != nil { 413 rows.fatal(err) 414 return rows, err 415 } 416 417 buf = appendSync(buf) 418 419 c.lastStmtSent = true 420 n, err := c.conn.Write(buf) 421 if err != nil && fatalWriteErr(n, err) { 422 rows.fatal(err) 423 c.die(err) 424 return rows, err 425 } 426 c.pendingReadyForQueryCount++ 427 428 fieldDescriptions, err := c.readUntilRowDescription() 429 if err != nil { 430 rows.fatal(err) 431 return rows, err 432 } 433 434 if len(options.ResultFormatCodes) == 0 { 435 for i := range fieldDescriptions { 436 fieldDescriptions[i].FormatCode = TextFormatCode 437 } 438 } else if len(options.ResultFormatCodes) == 1 { 439 fc := options.ResultFormatCodes[0] 440 for i := range fieldDescriptions { 441 fieldDescriptions[i].FormatCode = fc 442 } 443 } else { 444 for i := range options.ResultFormatCodes { 445 fieldDescriptions[i].FormatCode = options.ResultFormatCodes[i] 446 } 447 } 448 449 rows.sql = sql 450 rows.fields = fieldDescriptions 451 return rows, nil 452 } 453 454 ps, ok := c.preparedStatements[sql] 455 if !ok { 456 var err error 457 ps, err = c.prepareEx("", sql, nil) 458 if err != nil { 459 rows.fatal(err) 460 return rows, rows.err 461 } 462 } 463 rows.sql = ps.SQL 464 rows.fields = ps.FieldDescriptions 465 466 c.lastStmtSent = true 467 err = c.sendPreparedQuery(ps, args...) 468 if err != nil { 469 rows.fatal(err) 470 } 471 472 return rows, rows.err 473} 474 475func (c *Conn) buildOneRoundTripQueryEx(buf []byte, sql string, options *QueryExOptions, arguments []interface{}) ([]byte, error) { 476 if len(arguments) != len(options.ParameterOIDs) { 477 return nil, errors.Errorf("mismatched number of arguments (%d) and options.ParameterOIDs (%d)", len(arguments), len(options.ParameterOIDs)) 478 } 479 480 if len(options.ParameterOIDs) > 65535 { 481 return nil, errors.Errorf("Number of QueryExOptions ParameterOIDs must be between 0 and 65535, received %d", len(options.ParameterOIDs)) 482 } 483 484 buf = appendParse(buf, "", sql, options.ParameterOIDs) 485 buf = appendDescribe(buf, 'S', "") 486 buf, err := appendBind(buf, "", "", c.ConnInfo, options.ParameterOIDs, arguments, options.ResultFormatCodes) 487 if err != nil { 488 return nil, err 489 } 490 buf = appendExecute(buf, "", 0) 491 492 return buf, nil 493} 494 495func (c *Conn) readUntilRowDescription() ([]FieldDescription, error) { 496 for { 497 msg, err := c.rxMsg() 498 if err != nil { 499 return nil, err 500 } 501 502 switch msg := msg.(type) { 503 case *pgproto3.ParameterDescription: 504 case *pgproto3.RowDescription: 505 fieldDescriptions := c.rxRowDescription(msg) 506 for i := range fieldDescriptions { 507 if dt, ok := c.ConnInfo.DataTypeForOID(fieldDescriptions[i].DataType); ok { 508 fieldDescriptions[i].DataTypeName = dt.Name 509 } else { 510 return nil, errors.Errorf("unknown oid: %d", fieldDescriptions[i].DataType) 511 } 512 } 513 return fieldDescriptions, nil 514 default: 515 if err := c.processContextFreeMsg(msg); err != nil { 516 return nil, err 517 } 518 } 519 } 520} 521 522func (c *Conn) sanitizeAndSendSimpleQuery(sql string, args ...interface{}) (err error) { 523 if c.RuntimeParams["standard_conforming_strings"] != "on" { 524 return errors.New("simple protocol queries must be run with standard_conforming_strings=on") 525 } 526 527 if c.RuntimeParams["client_encoding"] != "UTF8" { 528 return errors.New("simple protocol queries must be run with client_encoding=UTF8") 529 } 530 531 valueArgs := make([]interface{}, len(args)) 532 for i, a := range args { 533 valueArgs[i], err = convertSimpleArgument(c.ConnInfo, a) 534 if err != nil { 535 return err 536 } 537 } 538 539 sql, err = sanitize.SanitizeSQL(sql, valueArgs...) 540 if err != nil { 541 return err 542 } 543 544 return c.sendSimpleQuery(sql) 545} 546 547func (c *Conn) QueryRowEx(ctx context.Context, sql string, options *QueryExOptions, args ...interface{}) *Row { 548 rows, _ := c.QueryEx(ctx, sql, options, args...) 549 return (*Row)(rows) 550} 551