1// Copyright (C) 2020 Storj Labs, Inc. 2// See LICENSE for copying information. 3 4package tagsql 5 6import ( 7 "context" 8 "database/sql" 9 10 "github.com/zeebo/errs" 11 12 "storj.io/private/traces" 13) 14 15// Stmt is an interface for *sql.Stmt. 16type Stmt interface { 17 // Exec and other methods take a context for tracing 18 // purposes, but do not pass the context to the underlying database query. 19 Exec(ctx context.Context, args ...interface{}) (sql.Result, error) 20 Query(ctx context.Context, args ...interface{}) (Rows, error) 21 QueryRow(ctx context.Context, args ...interface{}) *sql.Row 22 23 // ExecContext and other Context methods take a context for tracing and also 24 // pass the context to the underlying database, if this tagsql instance is 25 // configured to do so. (By default, lib/pq does not ever, and 26 // mattn/go-sqlite3 does not for transactions). 27 ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) 28 QueryContext(ctx context.Context, args ...interface{}) (Rows, error) 29 QueryRowContext(ctx context.Context, args ...interface{}) *sql.Row 30 31 Close() error 32} 33 34// sqlStmt implements Stmt, which optionally disables contexts. 35type sqlStmt struct { 36 query string 37 stmt *sql.Stmt 38 useContext bool 39 tracker *tracker 40} 41 42func (s *sqlStmt) Close() error { 43 return errs.Combine(s.tracker.close(), s.stmt.Close()) 44} 45 46func (s *sqlStmt) Exec(ctx context.Context, args ...interface{}) (_ sql.Result, err error) { 47 traces.Tag(ctx, traces.TagDB) 48 defer mon.Task()(&ctx, s.query, args)(&err) 49 50 return s.stmt.Exec(args...) 51} 52 53func (s *sqlStmt) ExecContext(ctx context.Context, args ...interface{}) (_ sql.Result, err error) { 54 traces.Tag(ctx, traces.TagDB) 55 defer mon.Task()(&ctx, s.query, args)(&err) 56 57 if !s.useContext { 58 return s.stmt.Exec(args...) 59 } 60 return s.stmt.ExecContext(ctx, args...) 61} 62 63func (s *sqlStmt) Query(ctx context.Context, args ...interface{}) (_ Rows, err error) { 64 traces.Tag(ctx, traces.TagDB) 65 defer mon.Task()(&ctx, s.query, args)(&err) 66 67 return s.tracker.wrapRows(s.stmt.Query(args...)) 68} 69 70func (s *sqlStmt) QueryContext(ctx context.Context, args ...interface{}) (_ Rows, err error) { 71 traces.Tag(ctx, traces.TagDB) 72 defer mon.Task()(&ctx, s.query, args)(&err) 73 74 if !s.useContext { 75 return s.tracker.wrapRows(s.stmt.Query(args...)) 76 } 77 return s.tracker.wrapRows(s.stmt.QueryContext(ctx, args...)) 78} 79 80func (s *sqlStmt) QueryRow(ctx context.Context, args ...interface{}) *sql.Row { 81 traces.Tag(ctx, traces.TagDB) 82 defer mon.Task()(&ctx, s.query, args)(nil) 83 84 return s.stmt.QueryRow(args...) 85} 86 87func (s *sqlStmt) QueryRowContext(ctx context.Context, args ...interface{}) *sql.Row { 88 traces.Tag(ctx, traces.TagDB) 89 defer mon.Task()(&ctx, s.query, args)(nil) 90 91 if !s.useContext { 92 return s.stmt.QueryRow(args...) 93 } 94 return s.stmt.QueryRowContext(ctx, args...) 95} 96