// Unless explicitly stated otherwise all files in this repository are licensed
// under the Apache License Version 2.0.
// This product includes software developed at Datadog (https://www.datadoghq.com/).
// Copyright 2016 Datadog, Inc.
5
6package sql // import "gopkg.in/DataDog/dd-trace-go.v1/contrib/database/sql"
7
8import (
9	"context"
10	"database/sql/driver"
11	"fmt"
12	"math"
13	"time"
14
15	"gopkg.in/DataDog/dd-trace-go.v1/ddtrace"
16	"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/ext"
17	"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer"
18)
19
// Compile-time assertion that *tracedConn satisfies driver.Conn.
var _ driver.Conn = (*tracedConn)(nil)
21
// queryType names the kind of database operation a span describes. Its string
// value is what tryTrace writes to the "sql.query_type" tag, and it is also
// used as the span's resource name when no query text is available.
type queryType string

// The operations that can be traced.
//
// Each constant states its type explicitly: in a const block an explicit
// "= value" stops the implicit repetition of the previous type, so leaving the
// type off would silently declare an untyped string constant instead of a
// queryType.
const (
	queryTypeQuery    queryType = "Query"
	queryTypePing     queryType = "Ping"
	queryTypePrepare  queryType = "Prepare"
	queryTypeExec     queryType = "Exec"
	queryTypeBegin    queryType = "Begin"
	queryTypeClose    queryType = "Close"
	queryTypeCommit   queryType = "Commit"
	queryTypeRollback queryType = "Rollback"
)
34
// tracedConn wraps a driver.Conn and reports the operations performed on it
// as spans. Both fields are embedded, so the methods of the wrapped connection
// and of traceParams are promoted onto tracedConn.
type tracedConn struct {
	driver.Conn  // the wrapped connection that performs the actual work
	*traceParams // tracing configuration shared with the values this connection creates
}
39
40func (tc *tracedConn) BeginTx(ctx context.Context, opts driver.TxOptions) (tx driver.Tx, err error) {
41	start := time.Now()
42	if connBeginTx, ok := tc.Conn.(driver.ConnBeginTx); ok {
43		tx, err = connBeginTx.BeginTx(ctx, opts)
44		tc.tryTrace(ctx, queryTypeBegin, "", start, err)
45		if err != nil {
46			return nil, err
47		}
48		return &tracedTx{tx, tc.traceParams, ctx}, nil
49	}
50	tx, err = tc.Conn.Begin()
51	tc.tryTrace(ctx, queryTypeBegin, "", start, err)
52	if err != nil {
53		return nil, err
54	}
55	return &tracedTx{tx, tc.traceParams, ctx}, nil
56}
57
58func (tc *tracedConn) PrepareContext(ctx context.Context, query string) (stmt driver.Stmt, err error) {
59	start := time.Now()
60	if connPrepareCtx, ok := tc.Conn.(driver.ConnPrepareContext); ok {
61		stmt, err := connPrepareCtx.PrepareContext(ctx, query)
62		tc.tryTrace(ctx, queryTypePrepare, query, start, err)
63		if err != nil {
64			return nil, err
65		}
66		return &tracedStmt{stmt, tc.traceParams, ctx, query}, nil
67	}
68	stmt, err = tc.Prepare(query)
69	tc.tryTrace(ctx, queryTypePrepare, query, start, err)
70	if err != nil {
71		return nil, err
72	}
73	return &tracedStmt{stmt, tc.traceParams, ctx, query}, nil
74}
75
76func (tc *tracedConn) Exec(query string, args []driver.Value) (driver.Result, error) {
77	if execer, ok := tc.Conn.(driver.Execer); ok {
78		return execer.Exec(query, args)
79	}
80	return nil, driver.ErrSkip
81}
82
83func (tc *tracedConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (r driver.Result, err error) {
84	start := time.Now()
85	if execContext, ok := tc.Conn.(driver.ExecerContext); ok {
86		r, err := execContext.ExecContext(ctx, query, args)
87		tc.tryTrace(ctx, queryTypeExec, query, start, err)
88		return r, err
89	}
90	dargs, err := namedValueToValue(args)
91	if err != nil {
92		return nil, err
93	}
94	select {
95	case <-ctx.Done():
96		return nil, ctx.Err()
97	default:
98	}
99	r, err = tc.Exec(query, dargs)
100	tc.tryTrace(ctx, queryTypeExec, query, start, err)
101	return r, err
102}
103
104// tracedConn has a Ping method in order to implement the pinger interface
105func (tc *tracedConn) Ping(ctx context.Context) (err error) {
106	start := time.Now()
107	if pinger, ok := tc.Conn.(driver.Pinger); ok {
108		err = pinger.Ping(ctx)
109	}
110	tc.tryTrace(ctx, queryTypePing, "", start, err)
111	return err
112}
113
114func (tc *tracedConn) Query(query string, args []driver.Value) (driver.Rows, error) {
115	if queryer, ok := tc.Conn.(driver.Queryer); ok {
116		return queryer.Query(query, args)
117	}
118	return nil, driver.ErrSkip
119}
120
121func (tc *tracedConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (rows driver.Rows, err error) {
122	start := time.Now()
123	if queryerContext, ok := tc.Conn.(driver.QueryerContext); ok {
124		rows, err := queryerContext.QueryContext(ctx, query, args)
125		tc.tryTrace(ctx, queryTypeQuery, query, start, err)
126		return rows, err
127	}
128	dargs, err := namedValueToValue(args)
129	if err != nil {
130		return nil, err
131	}
132	select {
133	case <-ctx.Done():
134		return nil, ctx.Err()
135	default:
136	}
137	rows, err = tc.Query(query, dargs)
138	tc.tryTrace(ctx, queryTypeQuery, query, start, err)
139	return rows, err
140}
141
142func (tc *tracedConn) CheckNamedValue(value *driver.NamedValue) error {
143	if checker, ok := tc.Conn.(driver.NamedValueChecker); ok {
144		return checker.CheckNamedValue(value)
145	}
146	return driver.ErrSkip
147}
148
149var _ driver.SessionResetter = (*tracedConn)(nil)
150
151// ResetSession implements driver.SessionResetter
152func (tc *tracedConn) ResetSession(ctx context.Context) error {
153	if resetter, ok := tc.Conn.(driver.SessionResetter); ok {
154		return resetter.ResetSession(ctx)
155	}
156	return driver.ErrSkip
157}
158
// traceParams stores all information related to tracing the driver.Conn.
// It is embedded by pointer in tracedConn, and its tryTrace method turns one
// database operation into one span.
type traceParams struct {
	cfg        *config           // supplies the service name and analytics rate read by tryTrace
	driverName string            // prefix of the span operation name, "<driverName>.query"
	meta       map[string]string // tags set on every span created by tryTrace
}
165
// contextKey is the unexported type of the keys this package stores in a
// context.Context, which keeps them from colliding with keys defined by other
// packages.
type contextKey int

// spanTagsKey is the context key under which WithSpanTags stores its tags.
// The associated value is a map[string]string.
const spanTagsKey = contextKey(0)

// WithSpanTags creates a new context containing the given set of tags. They will be added
// to any query created with the returned context.
//
// The map is stored as given, not copied.
func WithSpanTags(ctx context.Context, tags map[string]string) context.Context {
	tagged := context.WithValue(ctx, spanTagsKey, tags)
	return tagged
}
175
176// tryTrace will create a span using the given arguments, but will act as a no-op when err is driver.ErrSkip.
177func (tp *traceParams) tryTrace(ctx context.Context, qtype queryType, query string, startTime time.Time, err error) {
178	if err == driver.ErrSkip {
179		// Not a user error: driver is telling sql package that an
180		// optional interface method is not implemented. There is
181		// nothing to trace here.
182		// See: https://github.com/DataDog/dd-trace-go/issues/270
183		return
184	}
185	name := fmt.Sprintf("%s.query", tp.driverName)
186	opts := []ddtrace.StartSpanOption{
187		tracer.ServiceName(tp.cfg.serviceName),
188		tracer.SpanType(ext.SpanTypeSQL),
189		tracer.StartTime(startTime),
190	}
191	if !math.IsNaN(tp.cfg.analyticsRate) {
192		opts = append(opts, tracer.Tag(ext.EventSampleRate, tp.cfg.analyticsRate))
193	}
194	span, _ := tracer.StartSpanFromContext(ctx, name, opts...)
195	resource := string(qtype)
196	if query != "" {
197		resource = query
198	}
199	span.SetTag("sql.query_type", string(qtype))
200	span.SetTag(ext.ResourceName, resource)
201	for k, v := range tp.meta {
202		span.SetTag(k, v)
203	}
204	if meta, ok := ctx.Value(spanTagsKey).(map[string]string); ok {
205		for k, v := range meta {
206			span.SetTag(k, v)
207		}
208	}
209	span.Finish(tracer.WithError(err))
210}
211