1// Unless explicitly stated otherwise all files in this repository are licensed
2// under the Apache License Version 2.0.
3// This product includes software developed at Datadog (https://www.datadoghq.com/).
4// Copyright 2016 Datadog, Inc.
5
6package sql
7
8import (
9	"context"
10	"database/sql/driver"
11	"errors"
12	"time"
13)
14
// Compile-time assertion that *tracedStmt satisfies driver.Stmt; the embedded
// driver.Stmt supplies every driver.Stmt method tracedStmt does not define.
var _ driver.Stmt = (*tracedStmt)(nil)

// tracedStmt is a traced wrapper around a driver.Stmt. Methods that
// tracedStmt does not define itself are promoted from the embedded statement.
type tracedStmt struct {
	// The wrapped driver statement that does the actual work.
	driver.Stmt
	// Provides the promoted tryTrace method used to report spans.
	*traceParams
	// Context passed to tryTrace for Close, which has no context parameter.
	// NOTE(review): presumably the context the statement was prepared with;
	// confirm at the construction site.
	ctx   context.Context
	// SQL text attached to the Exec and Query spans.
	query string
}
24
25// Close sends a span before closing a statement
26func (s *tracedStmt) Close() (err error) {
27	start := time.Now()
28	err = s.Stmt.Close()
29	s.tryTrace(s.ctx, queryTypeClose, "", start, err)
30	return err
31}
32
33// ExecContext is needed to implement the driver.StmtExecContext interface
34func (s *tracedStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (res driver.Result, err error) {
35	start := time.Now()
36	if stmtExecContext, ok := s.Stmt.(driver.StmtExecContext); ok {
37		res, err := stmtExecContext.ExecContext(ctx, args)
38		s.tryTrace(ctx, queryTypeExec, s.query, start, err)
39		return res, err
40	}
41	dargs, err := namedValueToValue(args)
42	if err != nil {
43		return nil, err
44	}
45	select {
46	case <-ctx.Done():
47		return nil, ctx.Err()
48	default:
49	}
50	res, err = s.Exec(dargs)
51	s.tryTrace(ctx, queryTypeExec, s.query, start, err)
52	return res, err
53}
54
55// QueryContext is needed to implement the driver.StmtQueryContext interface
56func (s *tracedStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (rows driver.Rows, err error) {
57	start := time.Now()
58	if stmtQueryContext, ok := s.Stmt.(driver.StmtQueryContext); ok {
59		rows, err := stmtQueryContext.QueryContext(ctx, args)
60		s.tryTrace(ctx, queryTypeQuery, s.query, start, err)
61		return rows, err
62	}
63	dargs, err := namedValueToValue(args)
64	if err != nil {
65		return nil, err
66	}
67	select {
68	case <-ctx.Done():
69		return nil, ctx.Err()
70	default:
71	}
72	rows, err = s.Query(dargs)
73	s.tryTrace(ctx, queryTypeQuery, s.query, start, err)
74	return rows, err
75}
76
// namedValueToValue converts named into the positional []driver.Value form
// accepted by the context-less driver.Stmt Exec and Query methods, preserving
// argument order. It returns an error if any argument carries a name, because
// a positional slice cannot express one. For empty or nil input it returns an
// empty, non-nil slice. Behaves like the helper of the same name in the
// standard library's database/sql package (src/database/sql/ctxutil.go).
func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) {
	values := make([]driver.Value, 0, len(named))
	for i := range named {
		if named[i].Name != "" {
			return nil, errors.New("sql: driver does not support the use of Named Parameters")
		}
		values = append(values, named[i].Value)
	}
	return values, nil
}
88