1// Unless explicitly stated otherwise all files in this repository are licensed
2// under the Apache License Version 2.0.
3// This product includes software developed at Datadog (https://www.datadoghq.com/).
4// Copyright 2016 Datadog, Inc.
5
6package vault
7
import (
	"encoding/json"
	"fmt"
	"io/ioutil"
	"net/http"
	"net/http/httptest"
	"sync"
	"testing"

	"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/ext"
	"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/mocktracer"

	"github.com/hashicorp/vault/api"
	"github.com/stretchr/testify/assert"
)
22
const (
	// secretMountPath is the path the tests mount the K/V engine on; every
	// key the tests write lives below it.
	secretMountPath = "/ns1/ns2/secret"
)
24
25func setupServer(t *testing.T) (*httptest.Server, func()) {
26	storage := make(map[string]string)
27	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
28		switch r.Method {
29		case http.MethodPut:
30			slurp, err := ioutil.ReadAll(r.Body)
31			if err != nil {
32				t.Fatal(err)
33			}
34			defer r.Body.Close()
35			storage[r.URL.Path] = string(slurp)
36			fmt.Fprintln(w, "{}")
37		case http.MethodGet:
38			val, ok := storage[r.URL.Path]
39			if !ok {
40				http.Error(w, "No data for key.", http.StatusNotFound)
41				return
42			}
43			secret := api.Secret{Data: make(map[string]interface{})}
44			json.Unmarshal([]byte(val), &secret.Data)
45			if err := json.NewEncoder(w).Encode(secret); err != nil {
46				t.Fatal(err)
47			}
48		}
49	}))
50	return ts, func() {
51		ts.Close()
52	}
53}
54
55func setupClient(ts *httptest.Server) (*api.Client, error) {
56	config := &api.Config{
57		HttpClient: NewHTTPClient(),
58		Address:    ts.URL,
59	}
60	client, err := api.NewClient(config)
61	if err != nil {
62		return nil, err
63	}
64	return client, nil
65}
66
67func TestNewHTTPClient(t *testing.T) {
68	ts, cleanup := setupServer(t)
69	defer cleanup()
70
71	client, err := setupClient(ts)
72	if err != nil {
73		t.Fatal(err)
74	}
75	testMountReadWrite(client, t)
76}
77
78func TestWrapHTTPClient(t *testing.T) {
79	ts, cleanup := setupServer(t)
80	defer cleanup()
81
82	httpClient := http.Client{}
83	config := &api.Config{
84		HttpClient: WrapHTTPClient(&httpClient),
85		Address:    ts.URL,
86	}
87	client, err := api.NewClient(config)
88	if err != nil {
89		t.Fatal(err)
90	}
91	client.SetToken("myroot")
92
93	testMountReadWrite(client, t)
94}
95
96// mountKV mounts the K/V engine on secretMountPath and returns a function to unmount it.
97// See: https://www.vaultproject.io/docs/secrets/
98func mountKV(c *api.Client, t *testing.T) func() {
99	secretMount := api.MountInput{
100		Type:        "kv",
101		Description: "Test KV Store",
102		Local:       true,
103	}
104	if err := c.Sys().Mount(secretMountPath, &secretMount); err != nil {
105		t.Fatal(err)
106	}
107	return func() {
108		c.Sys().Unmount(secretMountPath)
109	}
110}
111
112func testMountReadWrite(c *api.Client, t *testing.T) {
113	key := secretMountPath + "/test"
114	fullPath := "/v1" + key
115	data := map[string]interface{}{"Key1": "Val1", "Key2": "Val2"}
116
117	t.Run("mount", func(t *testing.T) {
118		assert := assert.New(t)
119		mt := mocktracer.Start()
120		defer mt.Stop()
121		defer mountKV(c, t)()
122
123		spans := mt.FinishedSpans()
124		assert.Len(spans, 1)
125		span := spans[0]
126
127		// Mount operation
128		assert.Equal("vault", span.Tag(ext.ServiceName))
129		assert.Equal("/v1/sys/mounts/ns1/ns2/secret", span.Tag(ext.HTTPURL))
130		assert.Equal(http.MethodPost, span.Tag(ext.HTTPMethod))
131		assert.Equal(http.MethodPost+" /v1/sys/mounts/ns1/ns2/secret", span.Tag(ext.ResourceName))
132		assert.Equal(ext.SpanTypeHTTP, span.Tag(ext.SpanType))
133		assert.Equal(200, span.Tag(ext.HTTPCode))
134		assert.Nil(span.Tag(ext.Error))
135		assert.Nil(span.Tag(ext.ErrorMsg))
136		assert.Nil(span.Tag("vault.namespace"))
137	})
138
139	t.Run("write", func(t *testing.T) {
140		assert := assert.New(t)
141		mt := mocktracer.Start()
142		defer mt.Stop()
143		defer mountKV(c, t)()
144
145		// Write key
146		_, err := c.Logical().Write(key, data)
147		if err != nil {
148			t.Fatal(err)
149		}
150		spans := mt.FinishedSpans()
151		assert.Len(spans, 2)
152		span := spans[1]
153
154		assert.Equal("vault", span.Tag(ext.ServiceName))
155		assert.Equal(fullPath, span.Tag(ext.HTTPURL))
156		assert.Equal(http.MethodPut, span.Tag(ext.HTTPMethod))
157		assert.Equal(http.MethodPut+" "+fullPath, span.Tag(ext.ResourceName))
158		assert.Equal(ext.SpanTypeHTTP, span.Tag(ext.SpanType))
159		assert.Equal(200, span.Tag(ext.HTTPCode))
160		assert.Nil(span.Tag(ext.Error))
161		assert.Nil(span.Tag(ext.ErrorMsg))
162		assert.Nil(span.Tag("vault.namespace"))
163	})
164
165	t.Run("read", func(t *testing.T) {
166		assert := assert.New(t)
167		mt := mocktracer.Start()
168		defer mt.Stop()
169		defer mountKV(c, t)()
170
171		// Write the key first
172		_, err := c.Logical().Write(key, data)
173		if err != nil {
174			t.Fatal(err)
175		}
176		// Read key
177		secret, err := c.Logical().Read(key)
178		if err != nil {
179			t.Fatal(err)
180		}
181		spans := mt.FinishedSpans()
182		assert.Len(spans, 3)
183		span := spans[2]
184
185		assert.Equal(secret.Data["Key1"], data["Key1"])
186		assert.Equal(secret.Data["Key2"], data["Key2"])
187		assert.Equal("vault", span.Tag(ext.ServiceName))
188		assert.Equal(fullPath, span.Tag(ext.HTTPURL))
189		assert.Equal(http.MethodGet, span.Tag(ext.HTTPMethod))
190		assert.Equal(http.MethodGet+" "+fullPath, span.Tag(ext.ResourceName))
191		assert.Equal(ext.SpanTypeHTTP, span.Tag(ext.SpanType))
192		assert.Equal(200, span.Tag(ext.HTTPCode))
193		assert.Nil(span.Tag(ext.Error))
194		assert.Nil(span.Tag(ext.ErrorMsg))
195		assert.Nil(span.Tag("vault.namespace"))
196	})
197}
198
199func TestReadError(t *testing.T) {
200	assert := assert.New(t)
201	mt := mocktracer.Start()
202	defer mt.Stop()
203
204	ts, cleanup := setupServer(t)
205	defer cleanup()
206	client, err := setupClient(ts)
207	if err != nil {
208		t.Fatal(err)
209	}
210	defer mountKV(client, t)()
211
212	key := "/some/bad/key"
213	fullPath := "/v1" + key
214	secret, err := client.Logical().Read(key)
215	if err == nil {
216		t.Fatalf("Expected error when reading key from %s, but it returned: %#v", key, secret)
217	}
218	spans := mt.FinishedSpans()
219	assert.Len(spans, 2)
220	span := spans[1]
221
222	// Read key error
223	assert.Equal("vault", span.Tag(ext.ServiceName))
224	assert.Equal(fullPath, span.Tag(ext.HTTPURL))
225	assert.Equal(http.MethodGet, span.Tag(ext.HTTPMethod))
226	assert.Equal(http.MethodGet+" "+fullPath, span.Tag(ext.ResourceName))
227	assert.Equal(ext.SpanTypeHTTP, span.Tag(ext.SpanType))
228	assert.Equal(404, span.Tag(ext.HTTPCode))
229	assert.Equal(true, span.Tag(ext.Error))
230	assert.NotNil(span.Tag(ext.ErrorMsg))
231	assert.Nil(span.Tag("vault.namespace"))
232}
233
234func TestNamespace(t *testing.T) {
235	ts, cleanup := setupServer(t)
236	defer cleanup()
237	client, err := setupClient(ts)
238	if err != nil {
239		t.Fatal(err)
240	}
241	defer mountKV(client, t)()
242
243	namespace := "/some/namespace"
244	client.SetNamespace(namespace)
245	key := secretMountPath + "/testNamespace"
246	fullPath := "/v1" + key
247
248	t.Run("write", func(t *testing.T) {
249		assert := assert.New(t)
250		mt := mocktracer.Start()
251		defer mt.Stop()
252
253		// Write key with namespace
254		data := map[string]interface{}{"Key1": "Val1", "Key2": "Val2"}
255		_, err = client.Logical().Write(key, data)
256		if err != nil {
257			t.Fatal(err)
258		}
259		spans := mt.FinishedSpans()
260		assert.Len(spans, 1)
261		span := spans[0]
262
263		assert.Equal("vault", span.Tag(ext.ServiceName))
264		assert.Equal(fullPath, span.Tag(ext.HTTPURL))
265		assert.Equal(http.MethodPut, span.Tag(ext.HTTPMethod))
266		assert.Equal(http.MethodPut+" "+fullPath, span.Tag(ext.ResourceName))
267		assert.Equal(ext.SpanTypeHTTP, span.Tag(ext.SpanType))
268		assert.Equal(200, span.Tag(ext.HTTPCode))
269		assert.Nil(span.Tag(ext.Error))
270		assert.Nil(span.Tag(ext.ErrorMsg))
271		assert.Equal(namespace, span.Tag("vault.namespace"))
272	})
273
274	t.Run("read", func(t *testing.T) {
275		assert := assert.New(t)
276		mt := mocktracer.Start()
277		defer mt.Stop()
278
279		// Write key with namespace first
280		data := map[string]interface{}{"Key1": "Val1", "Key2": "Val2"}
281		_, err = client.Logical().Write(key, data)
282		if err != nil {
283			t.Fatal(err)
284		}
285		// Read key with namespace
286		_, err = client.Logical().Read(key)
287		if err != nil {
288			t.Fatal(err)
289		}
290		spans := mt.FinishedSpans()
291		assert.Len(spans, 2)
292		span := spans[1]
293
294		assert.Equal("vault", span.Tag(ext.ServiceName))
295		assert.Equal(fullPath, span.Tag(ext.HTTPURL))
296		assert.Equal(http.MethodGet, span.Tag(ext.HTTPMethod))
297		assert.Equal(http.MethodGet+" "+fullPath, span.Tag(ext.ResourceName))
298		assert.Equal(ext.SpanTypeHTTP, span.Tag(ext.SpanType))
299		assert.Equal(200, span.Tag(ext.HTTPCode))
300		assert.Nil(span.Tag(ext.Error))
301		assert.Nil(span.Tag(ext.ErrorMsg))
302		assert.Equal(namespace, span.Tag("vault.namespace"))
303	})
304}
305
306func TestOption(t *testing.T) {
307	ts, cleanup := setupServer(t)
308	defer cleanup()
309
310	for ttName, tt := range map[string]struct {
311		opts []Option
312		test func(assert *assert.Assertions, span mocktracer.Span)
313	}{
314		"DefaultOptions": {
315			opts: []Option{},
316			test: func(assert *assert.Assertions, span mocktracer.Span) {
317				assert.Equal(defaultServiceName, span.Tag(ext.ServiceName))
318				assert.Nil(span.Tag(ext.EventSampleRate))
319			},
320		},
321		"CustomServiceName": {
322			opts: []Option{WithServiceName("someServiceName")},
323			test: func(assert *assert.Assertions, span mocktracer.Span) {
324				assert.Equal("someServiceName", span.Tag(ext.ServiceName))
325			},
326		},
327		"WithAnalyticsTrue": {
328			opts: []Option{WithAnalytics(true)},
329			test: func(assert *assert.Assertions, span mocktracer.Span) {
330				assert.Equal(1.0, span.Tag(ext.EventSampleRate))
331			},
332		},
333		"WithAnalyticsFalse": {
334			opts: []Option{WithAnalytics(false)},
335			test: func(assert *assert.Assertions, span mocktracer.Span) {
336				assert.Nil(span.Tag(ext.EventSampleRate))
337			},
338		},
339		"WithAnalyticsLastOptionWins": {
340			opts: []Option{WithAnalyticsRate(0.7), WithAnalytics(true)},
341			test: func(assert *assert.Assertions, span mocktracer.Span) {
342				assert.Equal(1.0, span.Tag(ext.EventSampleRate))
343			},
344		},
345		"WithAnalyticsRateMax": {
346			opts: []Option{WithAnalyticsRate(1.0)},
347			test: func(assert *assert.Assertions, span mocktracer.Span) {
348				assert.Equal(1.0, span.Tag(ext.EventSampleRate))
349			},
350		},
351		"WithAnalyticsRateMin": {
352			opts: []Option{WithAnalyticsRate(0.0)},
353			test: func(assert *assert.Assertions, span mocktracer.Span) {
354				assert.Equal(0.0, span.Tag(ext.EventSampleRate))
355			},
356		},
357		"WithAnalyticsRateLastOptionWins": {
358			opts: []Option{WithAnalytics(true), WithAnalyticsRate(0.7)},
359			test: func(assert *assert.Assertions, span mocktracer.Span) {
360				assert.Equal(0.7, span.Tag(ext.EventSampleRate))
361			},
362		},
363	} {
364		t.Run(ttName, func(t *testing.T) {
365			assert := assert.New(t)
366			config := &api.Config{
367				HttpClient: NewHTTPClient(tt.opts...),
368				Address:    ts.URL,
369			}
370			client, err := api.NewClient(config)
371			if err != nil {
372				t.Fatal(err)
373			}
374			defer mountKV(client, t)()
375
376			mt := mocktracer.Start()
377			defer mt.Stop()
378
379			_, err = client.Logical().Write(
380				secretMountPath+"/key",
381				map[string]interface{}{"Key1": "Val1", "Key2": "Val2"},
382			)
383			if err != nil {
384				t.Fatal(err)
385			}
386			spans := mt.FinishedSpans()
387			assert.Len(spans, 1)
388			span := spans[0]
389			tt.test(assert, span)
390		})
391	}
392}
393