1// Unless explicitly stated otherwise all files in this repository are licensed
2// under the Apache License Version 2.0.
3// This product includes software developed at Datadog (https://www.datadoghq.com/).
4// Copyright 2016 Datadog, Inc.
5
6// Package sql provides functions to trace the database/sql package (https://golang.org/pkg/database/sql).
7// It will automatically augment operations such as connections, statements and transactions with tracing.
8//
9// We start by telling the package which driver we will be using. For example, if we are using "github.com/lib/pq",
10// we would do as follows:
11//
12// 	sqltrace.Register("pq", pq.Driver{})
13//	db, err := sqltrace.Open("pq", "postgres://pqgotest:password@localhost...")
14//
15// The rest of our application would continue as usual, but with tracing enabled.
16//
17package sql
18
import (
	"context"
	"database/sql"
	"database/sql/driver"
	"errors"
	"math"
	"reflect"
	"sync"

	"gopkg.in/DataDog/dd-trace-go.v1/contrib/database/sql/internal"
	"gopkg.in/DataDog/dd-trace-go.v1/internal/log"
)
30
31// registeredDrivers holds a registry of all drivers registered via the sqltrace package.
32var registeredDrivers = &driverRegistry{
33	keys:    make(map[reflect.Type]string),
34	drivers: make(map[string]driver.Driver),
35	configs: make(map[string]*config),
36}
37
38type driverRegistry struct {
39	// keys maps driver types to their registered names.
40	keys map[reflect.Type]string
41	// drivers maps keys to their registered driver.
42	drivers map[string]driver.Driver
43	// configs maps keys to their registered configuration.
44	configs map[string]*config
45}
46
47// isRegistered reports whether the name matches an existing entry
48// in the driver registry.
49func (d *driverRegistry) isRegistered(name string) bool {
50	_, ok := d.configs[name]
51	return ok
52}
53
54// add adds the driver with the given name and config to the registry.
55func (d *driverRegistry) add(name string, driver driver.Driver, cfg *config) {
56	if d.isRegistered(name) {
57		return
58	}
59	d.keys[reflect.TypeOf(driver)] = name
60	d.drivers[name] = driver
61	d.configs[name] = cfg
62}
63
64// name returns the name of the driver stored in the registry.
65func (d *driverRegistry) name(driver driver.Driver) (string, bool) {
66	name, ok := d.keys[reflect.TypeOf(driver)]
67	return name, ok
68}
69
70// driver returns the driver stored in the registry with the provided name.
71func (d *driverRegistry) driver(name string) (driver.Driver, bool) {
72	driver, ok := d.drivers[name]
73	return driver, ok
74}
75
76// config returns the config stored in the registry with the provided name.
77func (d *driverRegistry) config(name string) (*config, bool) {
78	config, ok := d.configs[name]
79	return config, ok
80}
81
82// Register tells the sql integration package about the driver that we will be tracing. It must
83// be called before Open, if that connection is to be traced. It uses the driverName suffixed
84// with ".db" as the default service name.
85func Register(driverName string, driver driver.Driver, opts ...RegisterOption) {
86	if driver == nil {
87		panic("sqltrace: Register driver is nil")
88	}
89	if registeredDrivers.isRegistered(driverName) {
90		// already registered, don't change things
91		return
92	}
93
94	cfg := new(config)
95	defaults(cfg)
96	for _, fn := range opts {
97		fn(cfg)
98	}
99	if cfg.serviceName == "" {
100		cfg.serviceName = driverName + ".db"
101	}
102	log.Debug("contrib/database/sql: Registering driver: %s %#v", driverName, cfg)
103	registeredDrivers.add(driverName, driver, cfg)
104}
105
var (
	// errNotRegistered is what Open returns when it is asked for a driver name
	// that was never passed to Register in this package.
	errNotRegistered = errors.New("sqltrace: Register must be called before Open")
)
109
// tracedConnector is the driver.Connector this package hands to sql.OpenDB.
// It wraps another connector and returns every connection that connector
// opens wrapped in a tracedConn.
type tracedConnector struct {
	connector  driver.Connector // wrapped connector that opens the real connections
	driverName string           // name the driver was registered under via Register
	cfg        *config          // tracing configuration passed to each tracedConn
}
115
116func (t *tracedConnector) Connect(c context.Context) (driver.Conn, error) {
117	conn, err := t.connector.Connect(c)
118	if err != nil {
119		return nil, err
120	}
121	tp := &traceParams{
122		driverName: t.driverName,
123		cfg:        t.cfg,
124	}
125	if dc, ok := t.connector.(*dsnConnector); ok {
126		tp.meta, _ = internal.ParseDSN(t.driverName, dc.dsn)
127	} else if t.cfg.dsn != "" {
128		tp.meta, _ = internal.ParseDSN(t.driverName, t.cfg.dsn)
129	}
130	return &tracedConn{conn, tp}, err
131}
132
133func (t *tracedConnector) Driver() driver.Driver {
134	return t.connector.Driver()
135}
136
// dsnConnector adapts a plain driver.Driver plus a data source name to the
// driver.Connector interface. It mirrors the unexported connector that the Go
// standard library builds inside sql.Open.
type dsnConnector struct {
	dsn    string
	driver driver.Driver
}

// Compile-time check that dsnConnector satisfies driver.Connector.
var _ driver.Connector = dsnConnector{}

// Connect opens a new connection by passing the stored DSN to the driver's
// Open method. The context argument is not consulted.
func (dc dsnConnector) Connect(context.Context) (driver.Conn, error) {
	drv, name := dc.driver, dc.dsn
	return drv.Open(name)
}

// Driver returns the driver this connector was built from.
func (dc dsnConnector) Driver() driver.Driver { return dc.driver }
150
151// OpenDB returns connection to a DB using a the traced version of the given driver. In order for OpenDB
152// to work, the driver must first be registered using Register. If this did not occur, OpenDB will panic.
153func OpenDB(c driver.Connector, opts ...Option) *sql.DB {
154	name, ok := registeredDrivers.name(c.Driver())
155	if !ok {
156		panic("sqltrace.OpenDB: driver is not registered via sqltrace.Register")
157	}
158	rc, _ := registeredDrivers.config(name)
159	cfg := new(config)
160	defaults(cfg)
161	for _, fn := range opts {
162		fn(cfg)
163	}
164	// use registered config for unset options
165	if cfg.serviceName == "" {
166		cfg.serviceName = rc.serviceName
167	}
168	if math.IsNaN(cfg.analyticsRate) {
169		cfg.analyticsRate = rc.analyticsRate
170	}
171	tc := &tracedConnector{
172		connector:  c,
173		driverName: name,
174		cfg:        cfg,
175	}
176	return sql.OpenDB(tc)
177}
178
179// Open returns connection to a DB using a the traced version of the given driver. In order for Open
180// to work, the driver must first be registered using Register. If this did not occur, Open will
181// return an error.
182func Open(driverName, dataSourceName string, opts ...Option) (*sql.DB, error) {
183	if !registeredDrivers.isRegistered(driverName) {
184		return nil, errNotRegistered
185	}
186	d, _ := registeredDrivers.driver(driverName)
187	return OpenDB(&dsnConnector{dsn: dataSourceName, driver: d}, opts...), nil
188}
189