1// Licensed to Elasticsearch B.V. under one or more contributor
2// license agreements. See the NOTICE file distributed with
3// this work for additional information regarding copyright
4// ownership. Elasticsearch B.V. licenses this file to you under
5// the Apache License, Version 2.0 (the "License"); you may
6// not use this file except in compliance with the License.
7// You may obtain a copy of the License at
8//
9//     http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18package apmsql
19
20import (
21	"context"
22	"database/sql/driver"
23	"errors"
24
25	"go.elastic.co/apm"
26)
27
28func newConn(in driver.Conn, d *tracingDriver, dsnInfo DSNInfo) driver.Conn {
29	conn := &conn{Conn: in, driver: d}
30	conn.dsnInfo = dsnInfo
31	conn.namedValueChecker, _ = in.(namedValueChecker)
32	conn.pinger, _ = in.(driver.Pinger)
33	conn.queryer, _ = in.(driver.Queryer)
34	conn.queryerContext, _ = in.(driver.QueryerContext)
35	conn.connPrepareContext, _ = in.(driver.ConnPrepareContext)
36	conn.execer, _ = in.(driver.Execer)
37	conn.execerContext, _ = in.(driver.ExecerContext)
38	conn.connBeginTx, _ = in.(driver.ConnBeginTx)
39	conn.connGo110.init(in)
40	if in, ok := in.(driver.ConnBeginTx); ok {
41		return &connBeginTx{conn, in}
42	}
43	return conn
44}
45
// conn wraps a driver.Conn so that its operations are traced as APM spans.
// Methods that conn does not define itself are promoted from the embedded
// driver.Conn (and connGo110); PrepareContext, for example, falls back to the
// promoted Prepare.
type conn struct {
	driver.Conn
	// NOTE(review): initialised by connGo110.init in newConn; judging by the
	// name it carries Go 1.10-specific driver support — confirm.
	connGo110
	driver  *tracingDriver // span types and query-signature function for spans
	dsnInfo DSNInfo        // database name and user recorded on each span

	// Optional interfaces implemented by the wrapped Conn, detected once by
	// newConn via comma-ok type assertions. A nil field means the wrapped
	// Conn does not implement that interface.
	// NOTE(review): connBeginTx appears to be write-only here; the
	// connBeginTx wrapper type keeps its own copy of the same value.
	namedValueChecker  namedValueChecker
	pinger             driver.Pinger
	queryer            driver.Queryer
	queryerContext     driver.QueryerContext
	connPrepareContext driver.ConnPrepareContext
	execer             driver.Execer
	execerContext      driver.ExecerContext
	connBeginTx        driver.ConnBeginTx
}
61
62func (c *conn) startStmtSpan(ctx context.Context, stmt, spanType string) (*apm.Span, context.Context) {
63	return c.startSpan(ctx, c.driver.querySignature(stmt), spanType, stmt)
64}
65
66func (c *conn) startSpan(ctx context.Context, name, spanType, stmt string) (*apm.Span, context.Context) {
67	span, ctx := apm.StartSpan(ctx, name, spanType)
68	if !span.Dropped() {
69		span.Context.SetDatabase(apm.DatabaseSpanContext{
70			Instance:  c.dsnInfo.Database,
71			Statement: stmt,
72			Type:      "sql",
73			User:      c.dsnInfo.User,
74		})
75	}
76	return span, ctx
77}
78
79func (c *conn) finishSpan(ctx context.Context, span *apm.Span, resultError *error) {
80	if *resultError == driver.ErrSkip {
81		// TODO(axw) mark span as abandoned,
82		// so it's not sent and not counted
83		// in the span limit. Ideally remove
84		// from the slice so memory is kept
85		// in check.
86		return
87	}
88	switch *resultError {
89	case nil, driver.ErrBadConn, context.Canceled:
90		// ErrBadConn is used by the connection pooling
91		// logic in database/sql, and so is expected and
92		// should not be reported.
93		//
94		// context.Canceled means the callers canceled
95		// the operation, so this is also expected.
96	default:
97		if e := apm.CaptureError(ctx, *resultError); e != nil {
98			e.Send()
99		}
100	}
101	span.End()
102}
103
104func (c *conn) Ping(ctx context.Context) (resultError error) {
105	if c.pinger == nil {
106		return nil
107	}
108	span, ctx := c.startSpan(ctx, "ping", c.driver.pingSpanType, "")
109	defer c.finishSpan(ctx, span, &resultError)
110	return c.pinger.Ping(ctx)
111}
112
113func (c *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (_ driver.Rows, resultError error) {
114	if c.queryerContext == nil && c.queryer == nil {
115		return nil, driver.ErrSkip
116	}
117	span, ctx := c.startStmtSpan(ctx, query, c.driver.querySpanType)
118	defer c.finishSpan(ctx, span, &resultError)
119
120	if c.queryerContext != nil {
121		return c.queryerContext.QueryContext(ctx, query, args)
122	}
123	dargs, err := namedValueToValue(args)
124	if err != nil {
125		return nil, err
126	}
127	select {
128	default:
129	case <-ctx.Done():
130		return nil, ctx.Err()
131	}
132	return c.queryer.Query(query, dargs)
133}
134
135func (*conn) Query(query string, args []driver.Value) (driver.Rows, error) {
136	return nil, errors.New("Query should never be called")
137}
138
139func (c *conn) PrepareContext(ctx context.Context, query string) (_ driver.Stmt, resultError error) {
140	span, ctx := c.startStmtSpan(ctx, query, c.driver.prepareSpanType)
141	defer c.finishSpan(ctx, span, &resultError)
142	var stmt driver.Stmt
143	var err error
144	if c.connPrepareContext != nil {
145		stmt, err = c.connPrepareContext.PrepareContext(ctx, query)
146	} else {
147		stmt, err = c.Prepare(query)
148		if err == nil {
149			select {
150			default:
151			case <-ctx.Done():
152				stmt.Close()
153				return nil, ctx.Err()
154			}
155		}
156	}
157	if stmt != nil {
158		stmt = newStmt(stmt, c, query)
159	}
160	return stmt, err
161}
162
163func (c *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (_ driver.Result, resultError error) {
164	if c.execerContext == nil && c.execer == nil {
165		return nil, driver.ErrSkip
166	}
167	span, ctx := c.startStmtSpan(ctx, query, c.driver.execSpanType)
168	defer c.finishSpan(ctx, span, &resultError)
169
170	if c.execerContext != nil {
171		return c.execerContext.ExecContext(ctx, query, args)
172	}
173	dargs, err := namedValueToValue(args)
174	if err != nil {
175		return nil, err
176	}
177	select {
178	default:
179	case <-ctx.Done():
180		return nil, ctx.Err()
181	}
182	return c.execer.Exec(query, dargs)
183}
184
185func (*conn) Exec(query string, args []driver.Value) (driver.Result, error) {
186	return nil, errors.New("Exec should never be called")
187}
188
189func (c *conn) CheckNamedValue(nv *driver.NamedValue) error {
190	return checkNamedValue(nv, c.namedValueChecker)
191}
192
// connBeginTx is the wrapper newConn returns when the underlying connection
// implements driver.ConnBeginTx. It behaves like *conn and adds a BeginTx
// method.
type connBeginTx struct {
	*conn
	// connBeginTx is the underlying connection viewed as a driver.ConnBeginTx.
	// It shadows the promoted conn.connBeginTx field.
	connBeginTx driver.ConnBeginTx
}
197
198func (c *connBeginTx) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
199	// TODO(axw) instrument commit/rollback?
200	return c.connBeginTx.BeginTx(ctx, opts)
201}
202