1// Unless explicitly stated otherwise all files in this repository are licensed 2// under the Apache License Version 2.0. 3// This product includes software developed at Datadog (https://www.datadoghq.com/). 4// Copyright 2016 Datadog, Inc. 5 6package vault 7 8import ( 9 "encoding/json" 10 "fmt" 11 "io/ioutil" 12 "net/http" 13 "net/http/httptest" 14 "testing" 15 16 "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/ext" 17 "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/mocktracer" 18 19 "github.com/hashicorp/vault/api" 20 "github.com/stretchr/testify/assert" 21) 22 23const secretMountPath = "/ns1/ns2/secret" 24 25func setupServer(t *testing.T) (*httptest.Server, func()) { 26 storage := make(map[string]string) 27 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 28 switch r.Method { 29 case http.MethodPut: 30 slurp, err := ioutil.ReadAll(r.Body) 31 if err != nil { 32 t.Fatal(err) 33 } 34 defer r.Body.Close() 35 storage[r.URL.Path] = string(slurp) 36 fmt.Fprintln(w, "{}") 37 case http.MethodGet: 38 val, ok := storage[r.URL.Path] 39 if !ok { 40 http.Error(w, "No data for key.", http.StatusNotFound) 41 return 42 } 43 secret := api.Secret{Data: make(map[string]interface{})} 44 json.Unmarshal([]byte(val), &secret.Data) 45 if err := json.NewEncoder(w).Encode(secret); err != nil { 46 t.Fatal(err) 47 } 48 } 49 })) 50 return ts, func() { 51 ts.Close() 52 } 53} 54 55func setupClient(ts *httptest.Server) (*api.Client, error) { 56 config := &api.Config{ 57 HttpClient: NewHTTPClient(), 58 Address: ts.URL, 59 } 60 client, err := api.NewClient(config) 61 if err != nil { 62 return nil, err 63 } 64 return client, nil 65} 66 67func TestNewHTTPClient(t *testing.T) { 68 ts, cleanup := setupServer(t) 69 defer cleanup() 70 71 client, err := setupClient(ts) 72 if err != nil { 73 t.Fatal(err) 74 } 75 testMountReadWrite(client, t) 76} 77 78func TestWrapHTTPClient(t *testing.T) { 79 ts, cleanup := setupServer(t) 80 defer cleanup() 81 82 httpClient := http.Client{} 83 config := &api.Config{ 84 HttpClient: WrapHTTPClient(&httpClient), 85 Address: ts.URL, 86 } 87 client, err := api.NewClient(config) 88 if err != nil { 89 t.Fatal(err) 90 } 91 client.SetToken("myroot") 92 93 testMountReadWrite(client, t) 94} 95 96// mountKV mounts the K/V engine on secretMountPath and returns a function to unmount it. 97// See: https://www.vaultproject.io/docs/secrets/ 98func mountKV(c *api.Client, t *testing.T) func() { 99 secretMount := api.MountInput{ 100 Type: "kv", 101 Description: "Test KV Store", 102 Local: true, 103 } 104 if err := c.Sys().Mount(secretMountPath, &secretMount); err != nil { 105 t.Fatal(err) 106 } 107 return func() { 108 c.Sys().Unmount(secretMountPath) 109 } 110} 111 112func testMountReadWrite(c *api.Client, t *testing.T) { 113 key := secretMountPath + "/test" 114 fullPath := "/v1" + key 115 data := map[string]interface{}{"Key1": "Val1", "Key2": "Val2"} 116 117 t.Run("mount", func(t *testing.T) { 118 assert := assert.New(t) 119 mt := mocktracer.Start() 120 defer mt.Stop() 121 defer mountKV(c, t)() 122 123 spans := mt.FinishedSpans() 124 assert.Len(spans, 1) 125 span := spans[0] 126 127 // Mount operation 128 assert.Equal("vault", span.Tag(ext.ServiceName)) 129 assert.Equal("/v1/sys/mounts/ns1/ns2/secret", span.Tag(ext.HTTPURL)) 130 assert.Equal(http.MethodPost, span.Tag(ext.HTTPMethod)) 131 assert.Equal(http.MethodPost+" /v1/sys/mounts/ns1/ns2/secret", span.Tag(ext.ResourceName)) 132 assert.Equal(ext.SpanTypeHTTP, span.Tag(ext.SpanType)) 133 assert.Equal(200, span.Tag(ext.HTTPCode)) 134 assert.Nil(span.Tag(ext.Error)) 135 assert.Nil(span.Tag(ext.ErrorMsg)) 136 assert.Nil(span.Tag("vault.namespace")) 137 }) 138 139 t.Run("write", func(t *testing.T) { 140 assert := assert.New(t) 141 mt := mocktracer.Start() 142 defer mt.Stop() 143 defer mountKV(c, t)() 144 145 // Write key 146 _, err := c.Logical().Write(key, data) 147 if err != nil { 148 t.Fatal(err) 149 } 150 spans := mt.FinishedSpans() 151 assert.Len(spans, 2) 152 span := spans[1] 153 154 assert.Equal("vault", span.Tag(ext.ServiceName)) 155 assert.Equal(fullPath, span.Tag(ext.HTTPURL)) 156 assert.Equal(http.MethodPut, span.Tag(ext.HTTPMethod)) 157 assert.Equal(http.MethodPut+" "+fullPath, span.Tag(ext.ResourceName)) 158 assert.Equal(ext.SpanTypeHTTP, span.Tag(ext.SpanType)) 159 assert.Equal(200, span.Tag(ext.HTTPCode)) 160 assert.Nil(span.Tag(ext.Error)) 161 assert.Nil(span.Tag(ext.ErrorMsg)) 162 assert.Nil(span.Tag("vault.namespace")) 163 }) 164 165 t.Run("read", func(t *testing.T) { 166 assert := assert.New(t) 167 mt := mocktracer.Start() 168 defer mt.Stop() 169 defer mountKV(c, t)() 170 171 // Write the key first 172 _, err := c.Logical().Write(key, data) 173 if err != nil { 174 t.Fatal(err) 175 } 176 // Read key 177 secret, err := c.Logical().Read(key) 178 if err != nil { 179 t.Fatal(err) 180 } 181 spans := mt.FinishedSpans() 182 assert.Len(spans, 3) 183 span := spans[2] 184 185 assert.Equal(secret.Data["Key1"], data["Key1"]) 186 assert.Equal(secret.Data["Key2"], data["Key2"]) 187 assert.Equal("vault", span.Tag(ext.ServiceName)) 188 assert.Equal(fullPath, span.Tag(ext.HTTPURL)) 189 assert.Equal(http.MethodGet, span.Tag(ext.HTTPMethod)) 190 assert.Equal(http.MethodGet+" "+fullPath, span.Tag(ext.ResourceName)) 191 assert.Equal(ext.SpanTypeHTTP, span.Tag(ext.SpanType)) 192 assert.Equal(200, span.Tag(ext.HTTPCode)) 193 assert.Nil(span.Tag(ext.Error)) 194 assert.Nil(span.Tag(ext.ErrorMsg)) 195 assert.Nil(span.Tag("vault.namespace")) 196 }) 197} 198 199func TestReadError(t *testing.T) { 200 assert := assert.New(t) 201 mt := mocktracer.Start() 202 defer mt.Stop() 203 204 ts, cleanup := setupServer(t) 205 defer cleanup() 206 client, err := setupClient(ts) 207 if err != nil { 208 t.Fatal(err) 209 } 210 defer mountKV(client, t)() 211 212 key := "/some/bad/key" 213 fullPath := "/v1" + key 214 secret, err := client.Logical().Read(key) 215 if err == nil { 216 t.Fatalf("Expected error when reading key from %s, but it returned: %#v", key, secret) 217 } 218 spans := mt.FinishedSpans() 219 assert.Len(spans, 2) 220 span := spans[1] 221 222 // Read key error 223 assert.Equal("vault", span.Tag(ext.ServiceName)) 224 assert.Equal(fullPath, span.Tag(ext.HTTPURL)) 225 assert.Equal(http.MethodGet, span.Tag(ext.HTTPMethod)) 226 assert.Equal(http.MethodGet+" "+fullPath, span.Tag(ext.ResourceName)) 227 assert.Equal(ext.SpanTypeHTTP, span.Tag(ext.SpanType)) 228 assert.Equal(404, span.Tag(ext.HTTPCode)) 229 assert.Equal(true, span.Tag(ext.Error)) 230 assert.NotNil(span.Tag(ext.ErrorMsg)) 231 assert.Nil(span.Tag("vault.namespace")) 232} 233 234func TestNamespace(t *testing.T) { 235 ts, cleanup := setupServer(t) 236 defer cleanup() 237 client, err := setupClient(ts) 238 if err != nil { 239 t.Fatal(err) 240 } 241 defer mountKV(client, t)() 242 243 namespace := "/some/namespace" 244 client.SetNamespace(namespace) 245 key := secretMountPath + "/testNamespace" 246 fullPath := "/v1" + key 247 248 t.Run("write", func(t *testing.T) { 249 assert := assert.New(t) 250 mt := mocktracer.Start() 251 defer mt.Stop() 252 253 // Write key with namespace 254 data := map[string]interface{}{"Key1": "Val1", "Key2": "Val2"} 255 _, err = client.Logical().Write(key, data) 256 if err != nil { 257 t.Fatal(err) 258 } 259 spans := mt.FinishedSpans() 260 assert.Len(spans, 1) 261 span := spans[0] 262 263 assert.Equal("vault", span.Tag(ext.ServiceName)) 264 assert.Equal(fullPath, span.Tag(ext.HTTPURL)) 265 assert.Equal(http.MethodPut, span.Tag(ext.HTTPMethod)) 266 assert.Equal(http.MethodPut+" "+fullPath, span.Tag(ext.ResourceName)) 267 assert.Equal(ext.SpanTypeHTTP, span.Tag(ext.SpanType)) 268 assert.Equal(200, span.Tag(ext.HTTPCode)) 269 assert.Nil(span.Tag(ext.Error)) 270 assert.Nil(span.Tag(ext.ErrorMsg)) 271 assert.Equal(namespace, span.Tag("vault.namespace")) 272 }) 273 274 t.Run("read", func(t *testing.T) { 275 assert := assert.New(t) 276 mt := mocktracer.Start() 277 defer mt.Stop() 278 279 // Write key with namespace first 280 data := map[string]interface{}{"Key1": "Val1", "Key2": "Val2"} 281 _, err = client.Logical().Write(key, data) 282 if err != nil { 283 t.Fatal(err) 284 } 285 // Read key with namespace 286 _, err = client.Logical().Read(key) 287 if err != nil { 288 t.Fatal(err) 289 } 290 spans := mt.FinishedSpans() 291 assert.Len(spans, 2) 292 span := spans[1] 293 294 assert.Equal("vault", span.Tag(ext.ServiceName)) 295 assert.Equal(fullPath, span.Tag(ext.HTTPURL)) 296 assert.Equal(http.MethodGet, span.Tag(ext.HTTPMethod)) 297 assert.Equal(http.MethodGet+" "+fullPath, span.Tag(ext.ResourceName)) 298 assert.Equal(ext.SpanTypeHTTP, span.Tag(ext.SpanType)) 299 assert.Equal(200, span.Tag(ext.HTTPCode)) 300 assert.Nil(span.Tag(ext.Error)) 301 assert.Nil(span.Tag(ext.ErrorMsg)) 302 assert.Equal(namespace, span.Tag("vault.namespace")) 303 }) 304} 305 306func TestOption(t *testing.T) { 307 ts, cleanup := setupServer(t) 308 defer cleanup() 309 310 for ttName, tt := range map[string]struct { 311 opts []Option 312 test func(assert *assert.Assertions, span mocktracer.Span) 313 }{ 314 "DefaultOptions": { 315 opts: []Option{}, 316 test: func(assert *assert.Assertions, span mocktracer.Span) { 317 assert.Equal(defaultServiceName, span.Tag(ext.ServiceName)) 318 assert.Nil(span.Tag(ext.EventSampleRate)) 319 }, 320 }, 321 "CustomServiceName": { 322 opts: []Option{WithServiceName("someServiceName")}, 323 test: func(assert *assert.Assertions, span mocktracer.Span) { 324 assert.Equal("someServiceName", span.Tag(ext.ServiceName)) 325 }, 326 }, 327 "WithAnalyticsTrue": { 328 opts: []Option{WithAnalytics(true)}, 329 test: func(assert *assert.Assertions, span mocktracer.Span) { 330 assert.Equal(1.0, span.Tag(ext.EventSampleRate)) 331 }, 332 }, 333 "WithAnalyticsFalse": { 334 opts: []Option{WithAnalytics(false)}, 335 test: func(assert *assert.Assertions, span mocktracer.Span) { 336 assert.Nil(span.Tag(ext.EventSampleRate)) 337 }, 338 }, 339 "WithAnalyticsLastOptionWins": { 340 opts: []Option{WithAnalyticsRate(0.7), WithAnalytics(true)}, 341 test: func(assert *assert.Assertions, span mocktracer.Span) { 342 assert.Equal(1.0, span.Tag(ext.EventSampleRate)) 343 }, 344 }, 345 "WithAnalyticsRateMax": { 346 opts: []Option{WithAnalyticsRate(1.0)}, 347 test: func(assert *assert.Assertions, span mocktracer.Span) { 348 assert.Equal(1.0, span.Tag(ext.EventSampleRate)) 349 }, 350 }, 351 "WithAnalyticsRateMin": { 352 opts: []Option{WithAnalyticsRate(0.0)}, 353 test: func(assert *assert.Assertions, span mocktracer.Span) { 354 assert.Equal(0.0, span.Tag(ext.EventSampleRate)) 355 }, 356 }, 357 "WithAnalyticsRateLastOptionWins": { 358 opts: []Option{WithAnalytics(true), WithAnalyticsRate(0.7)}, 359 test: func(assert *assert.Assertions, span mocktracer.Span) { 360 assert.Equal(0.7, span.Tag(ext.EventSampleRate)) 361 }, 362 }, 363 } { 364 t.Run(ttName, func(t *testing.T) { 365 assert := assert.New(t) 366 config := &api.Config{ 367 HttpClient: NewHTTPClient(tt.opts...), 368 Address: ts.URL, 369 } 370 client, err := api.NewClient(config) 371 if err != nil { 372 t.Fatal(err) 373 } 374 defer mountKV(client, t)() 375 376 mt := mocktracer.Start() 377 defer mt.Stop() 378 379 _, err = client.Logical().Write( 380 secretMountPath+"/key", 381 map[string]interface{}{"Key1": "Val1", "Key2": "Val2"}, 382 ) 383 if err != nil { 384 t.Fatal(err) 385 } 386 spans := mt.FinishedSpans() 387 assert.Len(spans, 1) 388 span := spans[0] 389 tt.test(assert, span) 390 }) 391 } 392} 393