1// Unless explicitly stated otherwise all files in this repository are licensed 2// under the Apache License Version 2.0. 3// This product includes software developed at Datadog (https://www.datadoghq.com/). 4// Copyright 2016 Datadog, Inc. 5 6package gorm 7 8import ( 9 "context" 10 "fmt" 11 "log" 12 "os" 13 "testing" 14 15 sqltrace "gopkg.in/DataDog/dd-trace-go.v1/contrib/database/sql" 16 "gopkg.in/DataDog/dd-trace-go.v1/contrib/internal/sqltest" 17 "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/ext" 18 "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/mocktracer" 19 "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" 20 "gopkg.in/DataDog/dd-trace-go.v1/internal/globalconfig" 21 22 "github.com/go-sql-driver/mysql" 23 "github.com/jinzhu/gorm" 24 "github.com/lib/pq" 25 "github.com/stretchr/testify/assert" 26) 27 28// tableName holds the SQL table that these tests will be run against. It must be unique cross-repo. 29const tableName = "testgorm" 30 31func TestMain(m *testing.M) { 32 _, ok := os.LookupEnv("INTEGRATION") 33 if !ok { 34 fmt.Println("--- SKIP: to enable integration test, set the INTEGRATION environment variable") 35 os.Exit(0) 36 } 37 defer sqltest.Prepare(tableName)() 38 os.Exit(m.Run()) 39} 40 41func TestMySQL(t *testing.T) { 42 sqltrace.Register("mysql", &mysql.MySQLDriver{}, sqltrace.WithServiceName("mysql-test")) 43 db, err := Open("mysql", "test:test@tcp(127.0.0.1:3306)/test") 44 if err != nil { 45 log.Fatal(err) 46 } 47 defer db.Close() 48 49 testConfig := &sqltest.Config{ 50 DB: db.DB(), 51 DriverName: "mysql", 52 TableName: tableName, 53 ExpectName: "mysql.query", 54 ExpectTags: map[string]interface{}{ 55 ext.ServiceName: "mysql-test", 56 ext.SpanType: ext.SpanTypeSQL, 57 ext.TargetHost: "127.0.0.1", 58 ext.TargetPort: "3306", 59 "db.user": "test", 60 "db.name": "test", 61 }, 62 } 63 sqltest.RunAll(t, testConfig) 64} 65 66func TestPostgres(t *testing.T) { 67 sqltrace.Register("postgres", &pq.Driver{}) 68 db, err := Open("postgres", "postgres://postgres:postgres@127.0.0.1:5432/postgres?sslmode=disable") 69 if err != nil { 70 log.Fatal(err) 71 } 72 defer db.Close() 73 74 testConfig := &sqltest.Config{ 75 DB: db.DB(), 76 DriverName: "postgres", 77 TableName: tableName, 78 ExpectName: "postgres.query", 79 ExpectTags: map[string]interface{}{ 80 ext.ServiceName: "postgres.db", 81 ext.SpanType: ext.SpanTypeSQL, 82 ext.TargetHost: "127.0.0.1", 83 ext.TargetPort: "5432", 84 "db.user": "postgres", 85 "db.name": "postgres", 86 }, 87 } 88 sqltest.RunAll(t, testConfig) 89} 90 91type Product struct { 92 gorm.Model 93 Code string 94 Price uint 95} 96 97func TestCallbacks(t *testing.T) { 98 assert := assert.New(t) 99 mt := mocktracer.Start() 100 defer mt.Stop() 101 102 sqltrace.Register("postgres", &pq.Driver{}) 103 db, err := Open("postgres", "postgres://postgres:postgres@127.0.0.1:5432/postgres?sslmode=disable") 104 if err != nil { 105 log.Fatal(err) 106 } 107 defer db.Close() 108 db.AutoMigrate(&Product{}) 109 110 t.Run("create", func(t *testing.T) { 111 parentSpan, ctx := tracer.StartSpanFromContext(context.Background(), "http.request", 112 tracer.ServiceName("fake-http-server"), 113 tracer.SpanType(ext.SpanTypeWeb), 114 ) 115 116 db = WithContext(ctx, db) 117 db.Create(&Product{Code: "L1212", Price: 1000}) 118 119 parentSpan.Finish() 120 121 spans := mt.FinishedSpans() 122 assert.True(len(spans) >= 3) 123 124 span := spans[len(spans)-3] 125 assert.Equal("gorm.create", span.OperationName()) 126 assert.Equal(ext.SpanTypeSQL, span.Tag(ext.SpanType)) 127 assert.Equal( 128 `INSERT INTO "products" ("created_at","updated_at","deleted_at","code","price") VALUES ($1,$2,$3,$4,$5) RETURNING "products"."id"`, 129 span.Tag(ext.ResourceName)) 130 }) 131 132 t.Run("query", func(t *testing.T) { 133 parentSpan, ctx := tracer.StartSpanFromContext(context.Background(), "http.request", 134 tracer.ServiceName("fake-http-server"), 135 tracer.SpanType(ext.SpanTypeWeb), 136 ) 137 138 db = WithContext(ctx, db) 139 var product Product 140 db.First(&product, "code = ?", "L1212") 141 142 parentSpan.Finish() 143 144 spans := mt.FinishedSpans() 145 assert.True(len(spans) >= 2) 146 147 span := spans[len(spans)-2] 148 assert.Equal("gorm.query", span.OperationName()) 149 assert.Equal(ext.SpanTypeSQL, span.Tag(ext.SpanType)) 150 assert.Equal( 151 `SELECT * FROM "products" WHERE "products"."deleted_at" IS NULL AND ((code = $1)) ORDER BY "products"."id" ASC LIMIT 1`, 152 span.Tag(ext.ResourceName)) 153 }) 154 155 t.Run("update", func(t *testing.T) { 156 parentSpan, ctx := tracer.StartSpanFromContext(context.Background(), "http.request", 157 tracer.ServiceName("fake-http-server"), 158 tracer.SpanType(ext.SpanTypeWeb), 159 ) 160 161 db = WithContext(ctx, db) 162 var product Product 163 db.First(&product, "code = ?", "L1212") 164 db.Model(&product).Update("Price", 2000) 165 166 parentSpan.Finish() 167 168 spans := mt.FinishedSpans() 169 assert.True(len(spans) >= 3) 170 171 span := spans[len(spans)-3] 172 assert.Equal("gorm.update", span.OperationName()) 173 assert.Equal(ext.SpanTypeSQL, span.Tag(ext.SpanType)) 174 assert.Equal( 175 `UPDATE "products" SET "price" = $1, "updated_at" = $2 WHERE "products"."deleted_at" IS NULL AND "products"."id" = $3`, 176 span.Tag(ext.ResourceName)) 177 }) 178 179 t.Run("delete", func(t *testing.T) { 180 parentSpan, ctx := tracer.StartSpanFromContext(context.Background(), "http.request", 181 tracer.ServiceName("fake-http-server"), 182 tracer.SpanType(ext.SpanTypeWeb), 183 ) 184 185 db = WithContext(ctx, db) 186 var product Product 187 db.First(&product, "code = ?", "L1212") 188 db.Delete(&product) 189 190 parentSpan.Finish() 191 192 spans := mt.FinishedSpans() 193 assert.True(len(spans) >= 3) 194 195 span := spans[len(spans)-3] 196 assert.Equal("gorm.delete", span.OperationName()) 197 assert.Equal(ext.SpanTypeSQL, span.Tag(ext.SpanType)) 198 assert.Equal( 199 `UPDATE "products" SET "deleted_at"=$1 WHERE "products"."deleted_at" IS NULL AND "products"."id" = $2`, 200 span.Tag(ext.ResourceName)) 201 }) 202} 203 204func TestAnalyticsSettings(t *testing.T) { 205 mt := mocktracer.Start() 206 defer mt.Stop() 207 208 sqltrace.Register("postgres", &pq.Driver{}) 209 db, err := Open("postgres", "postgres://postgres:postgres@127.0.0.1:5432/postgres?sslmode=disable") 210 if err != nil { 211 log.Fatal(err) 212 } 213 defer db.Close() 214 db.AutoMigrate(&Product{}) 215 216 assertRate := func(t *testing.T, mt mocktracer.Tracer, rate interface{}, opts ...Option) { 217 db, err := Open("postgres", "postgres://postgres:postgres@127.0.0.1:5432/postgres?sslmode=disable", opts...) 218 if err != nil { 219 log.Fatal(err) 220 } 221 defer db.Close() 222 db.AutoMigrate(&Product{}) 223 224 parentSpan, ctx := tracer.StartSpanFromContext(context.Background(), "http.request", 225 tracer.ServiceName("fake-http-server"), 226 tracer.SpanType(ext.SpanTypeWeb), 227 ) 228 229 db = WithContext(ctx, db) 230 db.Create(&Product{Code: "L1212", Price: 1000}) 231 232 parentSpan.Finish() 233 234 spans := mt.FinishedSpans() 235 assert.True(t, len(spans) > 3) 236 s := spans[len(spans)-3] 237 assert.Equal(t, rate, s.Tag(ext.EventSampleRate)) 238 } 239 240 t.Run("defaults", func(t *testing.T) { 241 mt := mocktracer.Start() 242 defer mt.Stop() 243 244 assertRate(t, mt, nil) 245 }) 246 247 t.Run("global", func(t *testing.T) { 248 t.Skip("global flag disabled") 249 mt := mocktracer.Start() 250 defer mt.Stop() 251 252 rate := globalconfig.AnalyticsRate() 253 defer globalconfig.SetAnalyticsRate(rate) 254 globalconfig.SetAnalyticsRate(0.4) 255 256 assertRate(t, mt, 0.4) 257 }) 258 259 t.Run("enabled", func(t *testing.T) { 260 mt := mocktracer.Start() 261 defer mt.Stop() 262 263 assertRate(t, mt, 1.0, WithAnalytics(true)) 264 }) 265 266 t.Run("disabled", func(t *testing.T) { 267 mt := mocktracer.Start() 268 defer mt.Stop() 269 270 assertRate(t, mt, nil, WithAnalytics(false)) 271 }) 272 273 t.Run("override", func(t *testing.T) { 274 mt := mocktracer.Start() 275 defer mt.Stop() 276 277 rate := globalconfig.AnalyticsRate() 278 defer globalconfig.SetAnalyticsRate(rate) 279 globalconfig.SetAnalyticsRate(0.4) 280 281 assertRate(t, mt, 0.23, WithAnalyticsRate(0.23)) 282 }) 283} 284 285func TestContext(t *testing.T) { 286 sqltrace.Register("postgres", &pq.Driver{}) 287 db, err := Open("postgres", "postgres://postgres:postgres@127.0.0.1:5432/postgres?sslmode=disable") 288 if err != nil { 289 log.Fatal(err) 290 } 291 defer db.Close() 292 293 t.Run("with", func(t *testing.T) { 294 type key string 295 testCtx := context.WithValue(context.Background(), key("test context"), true) 296 db := WithContext(testCtx, db) 297 ctx := ContextFromDB(db) 298 assert.Equal(t, testCtx, ctx) 299 }) 300 301 t.Run("without", func(t *testing.T) { 302 ctx := ContextFromDB(db) 303 assert.Equal(t, context.Background(), ctx) 304 }) 305} 306 307func TestCustomTags(t *testing.T) { 308 assert := assert.New(t) 309 mt := mocktracer.Start() 310 defer mt.Stop() 311 312 sqltrace.Register("postgres", &pq.Driver{}) 313 db, err := Open("postgres", "postgres://postgres:postgres@127.0.0.1:5432/postgres?sslmode=disable", 314 WithCustomTag("custom_tag", func(scope *gorm.Scope) interface{} { 315 return scope.SQLVars[3] 316 }), 317 ) 318 if err != nil { 319 log.Fatal(err) 320 } 321 defer db.Close() 322 db.AutoMigrate(&Product{}) 323 324 parentSpan, ctx := tracer.StartSpanFromContext(context.Background(), "http.request", 325 tracer.ServiceName("fake-http-server"), 326 tracer.SpanType(ext.SpanTypeWeb), 327 ) 328 329 db = WithContext(ctx, db) 330 db.Create(&Product{Code: "L1212", Price: 1000}) 331 332 parentSpan.Finish() 333 334 spans := mt.FinishedSpans() 335 assert.True(len(spans) >= 3) 336 337 // We deterministically expect the span to be the third last, 338 // followed by the underlying postgres DB trace and the above http.request span. 339 span := spans[len(spans)-3] 340 assert.Equal("gorm.create", span.OperationName()) 341 assert.Equal(ext.SpanTypeSQL, span.Tag(ext.SpanType)) 342 assert.Equal("L1212", span.Tag("custom_tag")) 343 assert.Equal( 344 `INSERT INTO "products" ("created_at","updated_at","deleted_at","code","price") VALUES ($1,$2,$3,$4,$5) RETURNING "products"."id"`, 345 span.Tag(ext.ResourceName)) 346} 347 348func TestError(t *testing.T) { 349 mt := mocktracer.Start() 350 defer mt.Stop() 351 352 assertErrCheck := func(t *testing.T, mt mocktracer.Tracer, errExist bool, opts ...Option) { 353 sqltrace.Register("postgres", &pq.Driver{}) 354 db, err := Open("postgres", "postgres://postgres:postgres@127.0.0.1:5432/postgres?sslmode=disable", opts...) 355 if err != nil { 356 log.Fatal(err) 357 } 358 defer db.Close() 359 db.AutoMigrate(&Product{}) 360 361 db = WithContext(context.Background(), db) 362 db.Find(&Product{}, Product{Code: "L1210", Price: 2000}) 363 364 spans := mt.FinishedSpans() 365 assert.True(t, len(spans) > 1) 366 367 // Get last span (gorm.db) 368 s := spans[len(spans)-1] 369 assert.Equal(t, errExist, s.Tag(ext.Error) != nil) 370 } 371 372 t.Run("defaults", func(t *testing.T) { 373 mt := mocktracer.Start() 374 defer mt.Stop() 375 376 assertErrCheck(t, mt, true) 377 }) 378 379 t.Run("errcheck", func(t *testing.T) { 380 mt := mocktracer.Start() 381 defer mt.Stop() 382 383 errFn := func(err error) bool { 384 if err == gorm.ErrRecordNotFound { 385 return false 386 } 387 return true 388 } 389 assertErrCheck(t, mt, false, WithErrorCheck(errFn)) 390 }) 391} 392