1// Unless explicitly stated otherwise all files in this repository are licensed
2// under the Apache License Version 2.0.
3// This product includes software developed at Datadog (https://www.datadoghq.com/).
4// Copyright 2016 Datadog, Inc.
5
6package gorm
7
8import (
9	"context"
10	"fmt"
11	"log"
12	"os"
13	"testing"
14
15	sqltrace "gopkg.in/DataDog/dd-trace-go.v1/contrib/database/sql"
16	"gopkg.in/DataDog/dd-trace-go.v1/contrib/internal/sqltest"
17	"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/ext"
18	"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/mocktracer"
19	"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer"
20	"gopkg.in/DataDog/dd-trace-go.v1/internal/globalconfig"
21
22	"github.com/go-sql-driver/mysql"
23	"github.com/jinzhu/gorm"
24	"github.com/lib/pq"
25	"github.com/stretchr/testify/assert"
26)
27
// tableName is the SQL table every test in this package runs against.
// The name has to be unique across the whole repository.
const tableName = "testgorm"
30
31func TestMain(m *testing.M) {
32	_, ok := os.LookupEnv("INTEGRATION")
33	if !ok {
34		fmt.Println("--- SKIP: to enable integration test, set the INTEGRATION environment variable")
35		os.Exit(0)
36	}
37	defer sqltest.Prepare(tableName)()
38	os.Exit(m.Run())
39}
40
41func TestMySQL(t *testing.T) {
42	sqltrace.Register("mysql", &mysql.MySQLDriver{}, sqltrace.WithServiceName("mysql-test"))
43	db, err := Open("mysql", "test:test@tcp(127.0.0.1:3306)/test")
44	if err != nil {
45		log.Fatal(err)
46	}
47	defer db.Close()
48
49	testConfig := &sqltest.Config{
50		DB:         db.DB(),
51		DriverName: "mysql",
52		TableName:  tableName,
53		ExpectName: "mysql.query",
54		ExpectTags: map[string]interface{}{
55			ext.ServiceName: "mysql-test",
56			ext.SpanType:    ext.SpanTypeSQL,
57			ext.TargetHost:  "127.0.0.1",
58			ext.TargetPort:  "3306",
59			"db.user":       "test",
60			"db.name":       "test",
61		},
62	}
63	sqltest.RunAll(t, testConfig)
64}
65
66func TestPostgres(t *testing.T) {
67	sqltrace.Register("postgres", &pq.Driver{})
68	db, err := Open("postgres", "postgres://postgres:postgres@127.0.0.1:5432/postgres?sslmode=disable")
69	if err != nil {
70		log.Fatal(err)
71	}
72	defer db.Close()
73
74	testConfig := &sqltest.Config{
75		DB:         db.DB(),
76		DriverName: "postgres",
77		TableName:  tableName,
78		ExpectName: "postgres.query",
79		ExpectTags: map[string]interface{}{
80			ext.ServiceName: "postgres.db",
81			ext.SpanType:    ext.SpanTypeSQL,
82			ext.TargetHost:  "127.0.0.1",
83			ext.TargetPort:  "5432",
84			"db.user":       "postgres",
85			"db.name":       "postgres",
86		},
87	}
88	sqltest.RunAll(t, testConfig)
89}
90
91type Product struct {
92	gorm.Model
93	Code  string
94	Price uint
95}
96
97func TestCallbacks(t *testing.T) {
98	assert := assert.New(t)
99	mt := mocktracer.Start()
100	defer mt.Stop()
101
102	sqltrace.Register("postgres", &pq.Driver{})
103	db, err := Open("postgres", "postgres://postgres:postgres@127.0.0.1:5432/postgres?sslmode=disable")
104	if err != nil {
105		log.Fatal(err)
106	}
107	defer db.Close()
108	db.AutoMigrate(&Product{})
109
110	t.Run("create", func(t *testing.T) {
111		parentSpan, ctx := tracer.StartSpanFromContext(context.Background(), "http.request",
112			tracer.ServiceName("fake-http-server"),
113			tracer.SpanType(ext.SpanTypeWeb),
114		)
115
116		db = WithContext(ctx, db)
117		db.Create(&Product{Code: "L1212", Price: 1000})
118
119		parentSpan.Finish()
120
121		spans := mt.FinishedSpans()
122		assert.True(len(spans) >= 3)
123
124		span := spans[len(spans)-3]
125		assert.Equal("gorm.create", span.OperationName())
126		assert.Equal(ext.SpanTypeSQL, span.Tag(ext.SpanType))
127		assert.Equal(
128			`INSERT INTO "products" ("created_at","updated_at","deleted_at","code","price") VALUES ($1,$2,$3,$4,$5) RETURNING "products"."id"`,
129			span.Tag(ext.ResourceName))
130	})
131
132	t.Run("query", func(t *testing.T) {
133		parentSpan, ctx := tracer.StartSpanFromContext(context.Background(), "http.request",
134			tracer.ServiceName("fake-http-server"),
135			tracer.SpanType(ext.SpanTypeWeb),
136		)
137
138		db = WithContext(ctx, db)
139		var product Product
140		db.First(&product, "code = ?", "L1212")
141
142		parentSpan.Finish()
143
144		spans := mt.FinishedSpans()
145		assert.True(len(spans) >= 2)
146
147		span := spans[len(spans)-2]
148		assert.Equal("gorm.query", span.OperationName())
149		assert.Equal(ext.SpanTypeSQL, span.Tag(ext.SpanType))
150		assert.Equal(
151			`SELECT * FROM "products"  WHERE "products"."deleted_at" IS NULL AND ((code = $1)) ORDER BY "products"."id" ASC LIMIT 1`,
152			span.Tag(ext.ResourceName))
153	})
154
155	t.Run("update", func(t *testing.T) {
156		parentSpan, ctx := tracer.StartSpanFromContext(context.Background(), "http.request",
157			tracer.ServiceName("fake-http-server"),
158			tracer.SpanType(ext.SpanTypeWeb),
159		)
160
161		db = WithContext(ctx, db)
162		var product Product
163		db.First(&product, "code = ?", "L1212")
164		db.Model(&product).Update("Price", 2000)
165
166		parentSpan.Finish()
167
168		spans := mt.FinishedSpans()
169		assert.True(len(spans) >= 3)
170
171		span := spans[len(spans)-3]
172		assert.Equal("gorm.update", span.OperationName())
173		assert.Equal(ext.SpanTypeSQL, span.Tag(ext.SpanType))
174		assert.Equal(
175			`UPDATE "products" SET "price" = $1, "updated_at" = $2  WHERE "products"."deleted_at" IS NULL AND "products"."id" = $3`,
176			span.Tag(ext.ResourceName))
177	})
178
179	t.Run("delete", func(t *testing.T) {
180		parentSpan, ctx := tracer.StartSpanFromContext(context.Background(), "http.request",
181			tracer.ServiceName("fake-http-server"),
182			tracer.SpanType(ext.SpanTypeWeb),
183		)
184
185		db = WithContext(ctx, db)
186		var product Product
187		db.First(&product, "code = ?", "L1212")
188		db.Delete(&product)
189
190		parentSpan.Finish()
191
192		spans := mt.FinishedSpans()
193		assert.True(len(spans) >= 3)
194
195		span := spans[len(spans)-3]
196		assert.Equal("gorm.delete", span.OperationName())
197		assert.Equal(ext.SpanTypeSQL, span.Tag(ext.SpanType))
198		assert.Equal(
199			`UPDATE "products" SET "deleted_at"=$1  WHERE "products"."deleted_at" IS NULL AND "products"."id" = $2`,
200			span.Tag(ext.ResourceName))
201	})
202}
203
204func TestAnalyticsSettings(t *testing.T) {
205	mt := mocktracer.Start()
206	defer mt.Stop()
207
208	sqltrace.Register("postgres", &pq.Driver{})
209	db, err := Open("postgres", "postgres://postgres:postgres@127.0.0.1:5432/postgres?sslmode=disable")
210	if err != nil {
211		log.Fatal(err)
212	}
213	defer db.Close()
214	db.AutoMigrate(&Product{})
215
216	assertRate := func(t *testing.T, mt mocktracer.Tracer, rate interface{}, opts ...Option) {
217		db, err := Open("postgres", "postgres://postgres:postgres@127.0.0.1:5432/postgres?sslmode=disable", opts...)
218		if err != nil {
219			log.Fatal(err)
220		}
221		defer db.Close()
222		db.AutoMigrate(&Product{})
223
224		parentSpan, ctx := tracer.StartSpanFromContext(context.Background(), "http.request",
225			tracer.ServiceName("fake-http-server"),
226			tracer.SpanType(ext.SpanTypeWeb),
227		)
228
229		db = WithContext(ctx, db)
230		db.Create(&Product{Code: "L1212", Price: 1000})
231
232		parentSpan.Finish()
233
234		spans := mt.FinishedSpans()
235		assert.True(t, len(spans) > 3)
236		s := spans[len(spans)-3]
237		assert.Equal(t, rate, s.Tag(ext.EventSampleRate))
238	}
239
240	t.Run("defaults", func(t *testing.T) {
241		mt := mocktracer.Start()
242		defer mt.Stop()
243
244		assertRate(t, mt, nil)
245	})
246
247	t.Run("global", func(t *testing.T) {
248		t.Skip("global flag disabled")
249		mt := mocktracer.Start()
250		defer mt.Stop()
251
252		rate := globalconfig.AnalyticsRate()
253		defer globalconfig.SetAnalyticsRate(rate)
254		globalconfig.SetAnalyticsRate(0.4)
255
256		assertRate(t, mt, 0.4)
257	})
258
259	t.Run("enabled", func(t *testing.T) {
260		mt := mocktracer.Start()
261		defer mt.Stop()
262
263		assertRate(t, mt, 1.0, WithAnalytics(true))
264	})
265
266	t.Run("disabled", func(t *testing.T) {
267		mt := mocktracer.Start()
268		defer mt.Stop()
269
270		assertRate(t, mt, nil, WithAnalytics(false))
271	})
272
273	t.Run("override", func(t *testing.T) {
274		mt := mocktracer.Start()
275		defer mt.Stop()
276
277		rate := globalconfig.AnalyticsRate()
278		defer globalconfig.SetAnalyticsRate(rate)
279		globalconfig.SetAnalyticsRate(0.4)
280
281		assertRate(t, mt, 0.23, WithAnalyticsRate(0.23))
282	})
283}
284
285func TestContext(t *testing.T) {
286	sqltrace.Register("postgres", &pq.Driver{})
287	db, err := Open("postgres", "postgres://postgres:postgres@127.0.0.1:5432/postgres?sslmode=disable")
288	if err != nil {
289		log.Fatal(err)
290	}
291	defer db.Close()
292
293	t.Run("with", func(t *testing.T) {
294		type key string
295		testCtx := context.WithValue(context.Background(), key("test context"), true)
296		db := WithContext(testCtx, db)
297		ctx := ContextFromDB(db)
298		assert.Equal(t, testCtx, ctx)
299	})
300
301	t.Run("without", func(t *testing.T) {
302		ctx := ContextFromDB(db)
303		assert.Equal(t, context.Background(), ctx)
304	})
305}
306
307func TestCustomTags(t *testing.T) {
308	assert := assert.New(t)
309	mt := mocktracer.Start()
310	defer mt.Stop()
311
312	sqltrace.Register("postgres", &pq.Driver{})
313	db, err := Open("postgres", "postgres://postgres:postgres@127.0.0.1:5432/postgres?sslmode=disable",
314		WithCustomTag("custom_tag", func(scope *gorm.Scope) interface{} {
315			return scope.SQLVars[3]
316		}),
317	)
318	if err != nil {
319		log.Fatal(err)
320	}
321	defer db.Close()
322	db.AutoMigrate(&Product{})
323
324	parentSpan, ctx := tracer.StartSpanFromContext(context.Background(), "http.request",
325		tracer.ServiceName("fake-http-server"),
326		tracer.SpanType(ext.SpanTypeWeb),
327	)
328
329	db = WithContext(ctx, db)
330	db.Create(&Product{Code: "L1212", Price: 1000})
331
332	parentSpan.Finish()
333
334	spans := mt.FinishedSpans()
335	assert.True(len(spans) >= 3)
336
337	// We deterministically expect the span to be the third last,
338	// followed by the underlying postgres DB trace and the above http.request span.
339	span := spans[len(spans)-3]
340	assert.Equal("gorm.create", span.OperationName())
341	assert.Equal(ext.SpanTypeSQL, span.Tag(ext.SpanType))
342	assert.Equal("L1212", span.Tag("custom_tag"))
343	assert.Equal(
344		`INSERT INTO "products" ("created_at","updated_at","deleted_at","code","price") VALUES ($1,$2,$3,$4,$5) RETURNING "products"."id"`,
345		span.Tag(ext.ResourceName))
346}
347
348func TestError(t *testing.T) {
349	mt := mocktracer.Start()
350	defer mt.Stop()
351
352	assertErrCheck := func(t *testing.T, mt mocktracer.Tracer, errExist bool, opts ...Option) {
353		sqltrace.Register("postgres", &pq.Driver{})
354		db, err := Open("postgres", "postgres://postgres:postgres@127.0.0.1:5432/postgres?sslmode=disable", opts...)
355		if err != nil {
356			log.Fatal(err)
357		}
358		defer db.Close()
359		db.AutoMigrate(&Product{})
360
361		db = WithContext(context.Background(), db)
362		db.Find(&Product{}, Product{Code: "L1210", Price: 2000})
363
364		spans := mt.FinishedSpans()
365		assert.True(t, len(spans) > 1)
366
367		// Get last span (gorm.db)
368		s := spans[len(spans)-1]
369		assert.Equal(t, errExist, s.Tag(ext.Error) != nil)
370	}
371
372	t.Run("defaults", func(t *testing.T) {
373		mt := mocktracer.Start()
374		defer mt.Stop()
375
376		assertErrCheck(t, mt, true)
377	})
378
379	t.Run("errcheck", func(t *testing.T) {
380		mt := mocktracer.Start()
381		defer mt.Stop()
382
383		errFn := func(err error) bool {
384			if err == gorm.ErrRecordNotFound {
385				return false
386			}
387			return true
388		}
389		assertErrCheck(t, mt, false, WithErrorCheck(errFn))
390	})
391}
392