1// Copyright (C) 2020 Storj Labs, Inc.
2// See LICENSE for copying information.
3
4package tagsql
5
6import (
7	"context"
8	"database/sql"
9
10	"github.com/zeebo/errs"
11
12	"storj.io/private/traces"
13)
14
15// Stmt is an interface for *sql.Stmt.
16type Stmt interface {
17	// Exec and other methods take a context for tracing
18	// purposes, but do not pass the context to the underlying database query.
19	Exec(ctx context.Context, args ...interface{}) (sql.Result, error)
20	Query(ctx context.Context, args ...interface{}) (Rows, error)
21	QueryRow(ctx context.Context, args ...interface{}) *sql.Row
22
23	// ExecContext and other Context methods take a context for tracing and also
24	// pass the context to the underlying database, if this tagsql instance is
25	// configured to do so. (By default, lib/pq does not ever, and
26	// mattn/go-sqlite3 does not for transactions).
27	ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error)
28	QueryContext(ctx context.Context, args ...interface{}) (Rows, error)
29	QueryRowContext(ctx context.Context, args ...interface{}) *sql.Row
30
31	Close() error
32}
33
// sqlStmt implements Stmt by wrapping a *sql.Stmt. When useContext is false,
// its *Context methods fall back to the context-free *sql.Stmt methods, so the
// caller's context is used only for tracing and monitoring.
type sqlStmt struct {
	// query is the SQL text of the statement; it is reported to mon.Task on
	// every call.
	query      string
	// stmt is the underlying prepared statement that all calls delegate to.
	stmt       *sql.Stmt
	// useContext controls whether the *Context methods forward ctx to stmt.
	useContext bool
	// tracker wraps the Rows returned by Query/QueryContext and is closed
	// together with stmt in Close.
	tracker    *tracker
}
41
42func (s *sqlStmt) Close() error {
43	return errs.Combine(s.tracker.close(), s.stmt.Close())
44}
45
46func (s *sqlStmt) Exec(ctx context.Context, args ...interface{}) (_ sql.Result, err error) {
47	traces.Tag(ctx, traces.TagDB)
48	defer mon.Task()(&ctx, s.query, args)(&err)
49
50	return s.stmt.Exec(args...)
51}
52
53func (s *sqlStmt) ExecContext(ctx context.Context, args ...interface{}) (_ sql.Result, err error) {
54	traces.Tag(ctx, traces.TagDB)
55	defer mon.Task()(&ctx, s.query, args)(&err)
56
57	if !s.useContext {
58		return s.stmt.Exec(args...)
59	}
60	return s.stmt.ExecContext(ctx, args...)
61}
62
63func (s *sqlStmt) Query(ctx context.Context, args ...interface{}) (_ Rows, err error) {
64	traces.Tag(ctx, traces.TagDB)
65	defer mon.Task()(&ctx, s.query, args)(&err)
66
67	return s.tracker.wrapRows(s.stmt.Query(args...))
68}
69
70func (s *sqlStmt) QueryContext(ctx context.Context, args ...interface{}) (_ Rows, err error) {
71	traces.Tag(ctx, traces.TagDB)
72	defer mon.Task()(&ctx, s.query, args)(&err)
73
74	if !s.useContext {
75		return s.tracker.wrapRows(s.stmt.Query(args...))
76	}
77	return s.tracker.wrapRows(s.stmt.QueryContext(ctx, args...))
78}
79
80func (s *sqlStmt) QueryRow(ctx context.Context, args ...interface{}) *sql.Row {
81	traces.Tag(ctx, traces.TagDB)
82	defer mon.Task()(&ctx, s.query, args)(nil)
83
84	return s.stmt.QueryRow(args...)
85}
86
87func (s *sqlStmt) QueryRowContext(ctx context.Context, args ...interface{}) *sql.Row {
88	traces.Tag(ctx, traces.TagDB)
89	defer mon.Task()(&ctx, s.query, args)(nil)
90
91	if !s.useContext {
92		return s.stmt.QueryRow(args...)
93	}
94	return s.stmt.QueryRowContext(ctx, args...)
95}
96