1package physical
2
3import (
4	"fmt"
5	"log"
6	"math/rand"
7	"os"
8	"reflect"
9	"testing"
10	"time"
11
12	"github.com/hashicorp/consul/api"
13)
14
// consulConf is the raw key/value configuration handed to newConsulBackend
// by the test helpers in this file.
type consulConf map[string]string

// addrCount is the host suffix that the next testHostIP call will use.
var addrCount int
20
21func testHostIP() string {
22	a := addrCount
23	addrCount++
24	return fmt.Sprintf("127.0.0.%d", a)
25}
26
27func testConsulBackend(t *testing.T) *ConsulBackend {
28	return testConsulBackendConfig(t, &consulConf{})
29}
30
31func testConsulBackendConfig(t *testing.T, conf *consulConf) *ConsulBackend {
32	logger := log.New(os.Stderr, "", log.LstdFlags)
33	be, err := newConsulBackend(*conf, logger)
34	if err != nil {
35		t.Fatalf("Expected Consul to initialize: %v", err)
36	}
37
38	c, ok := be.(*ConsulBackend)
39	if !ok {
40		t.Fatalf("Expected ConsulBackend")
41	}
42
43	return c
44}
45
46func testConsul_testConsulBackend(t *testing.T) {
47	c := testConsulBackend(t)
48	if c == nil {
49		t.Fatalf("bad")
50	}
51}
52
53func testActiveFunc(activePct float64) activeFunction {
54	return func() bool {
55		var active bool
56		standbyProb := rand.Float64()
57		if standbyProb > activePct {
58			active = true
59		}
60		return active
61	}
62}
63
64func testSealedFunc(sealedPct float64) sealedFunction {
65	return func() bool {
66		var sealed bool
67		unsealedProb := rand.Float64()
68		if unsealedProb > sealedPct {
69			sealed = true
70		}
71		return sealed
72	}
73}
74
75func TestConsul_newConsulBackend(t *testing.T) {
76	tests := []struct {
77		name          string
78		consulConfig  map[string]string
79		fail          bool
80		advertiseAddr string
81		checkTimeout  time.Duration
82		path          string
83		service       string
84		address       string
85		scheme        string
86		token         string
87		max_parallel  int
88		disableReg    bool
89	}{
90		{
91			name:          "Valid default config",
92			consulConfig:  map[string]string{},
93			checkTimeout:  5 * time.Second,
94			advertiseAddr: "http://127.0.0.1:8200",
95			path:          "vault/",
96			service:       "vault",
97			address:       "127.0.0.1:8500",
98			scheme:        "http",
99			token:         "",
100			max_parallel:  4,
101			disableReg:    false,
102		},
103		{
104			name: "Valid modified config",
105			consulConfig: map[string]string{
106				"path":                 "seaTech/",
107				"service":              "astronomy",
108				"advertiseAddr":        "http://127.0.0.2:8200",
109				"check_timeout":        "6s",
110				"address":              "127.0.0.2",
111				"scheme":               "https",
112				"token":                "deadbeef-cafeefac-deadc0de-feedface",
113				"max_parallel":         "4",
114				"disable_registration": "false",
115			},
116			checkTimeout:  6 * time.Second,
117			path:          "seaTech/",
118			service:       "astronomy",
119			advertiseAddr: "http://127.0.0.2:8200",
120			address:       "127.0.0.2",
121			scheme:        "https",
122			token:         "deadbeef-cafeefac-deadc0de-feedface",
123			max_parallel:  4,
124		},
125		{
126			name: "check timeout too short",
127			fail: true,
128			consulConfig: map[string]string{
129				"check_timeout": "99ms",
130			},
131		},
132	}
133
134	for _, test := range tests {
135		logger := log.New(os.Stderr, "", log.LstdFlags)
136		be, err := newConsulBackend(test.consulConfig, logger)
137		if test.fail {
138			if err == nil {
139				t.Fatalf(`Expected config "%s" to fail`, test.name)
140			} else {
141				continue
142			}
143		} else if !test.fail && err != nil {
144			t.Fatalf("Expected config %s to not fail: %v", test.name, err)
145		}
146
147		c, ok := be.(*ConsulBackend)
148		if !ok {
149			t.Fatalf("Expected ConsulBackend: %s", test.name)
150		}
151		c.disableRegistration = true
152
153		if c.disableRegistration == false {
154			addr := os.Getenv("CONSUL_HTTP_ADDR")
155			if addr == "" {
156				continue
157			}
158		}
159
160		var shutdownCh ShutdownChannel
161		if err := c.RunServiceDiscovery(shutdownCh, test.advertiseAddr, testActiveFunc(0.5), testSealedFunc(0.5)); err != nil {
162			t.Fatalf("bad: %v", err)
163		}
164
165		if test.checkTimeout != c.checkTimeout {
166			t.Errorf("bad: %v != %v", test.checkTimeout, c.checkTimeout)
167		}
168
169		if test.path != c.path {
170			t.Errorf("bad: %s %v != %v", test.name, test.path, c.path)
171		}
172
173		if test.service != c.serviceName {
174			t.Errorf("bad: %v != %v", test.service, c.serviceName)
175		}
176
177		// FIXME(sean@): Unable to test max_parallel
178		// if test.max_parallel != cap(c.permitPool) {
179		// 	t.Errorf("bad: %v != %v", test.max_parallel, cap(c.permitPool))
180		// }
181	}
182}
183
184func TestConsul_serviceTags(t *testing.T) {
185	tests := []struct {
186		active bool
187		tags   []string
188	}{
189		{
190			active: true,
191			tags:   []string{"active"},
192		},
193		{
194			active: false,
195			tags:   []string{"standby"},
196		},
197	}
198
199	for _, test := range tests {
200		tags := serviceTags(test.active)
201		if !reflect.DeepEqual(tags[:], test.tags[:]) {
202			t.Errorf("Bad %v: %v %v", test.active, tags, test.tags)
203		}
204	}
205}
206
207func TestConsul_setAdvertiseAddr(t *testing.T) {
208	tests := []struct {
209		addr string
210		host string
211		port int64
212		pass bool
213	}{
214		{
215			addr: "http://127.0.0.1:8200/",
216			host: "127.0.0.1",
217			port: 8200,
218			pass: true,
219		},
220		{
221			addr: "http://127.0.0.1:8200",
222			host: "127.0.0.1",
223			port: 8200,
224			pass: true,
225		},
226		{
227			addr: "https://127.0.0.1:8200",
228			host: "127.0.0.1",
229			port: 8200,
230			pass: true,
231		},
232		{
233			addr: "unix:///tmp/.vault.addr.sock",
234			host: "/tmp/.vault.addr.sock",
235			port: -1,
236			pass: true,
237		},
238		{
239			addr: "127.0.0.1:8200",
240			pass: false,
241		},
242		{
243			addr: "127.0.0.1",
244			pass: false,
245		},
246	}
247	for _, test := range tests {
248		c := testConsulBackend(t)
249		err := c.setAdvertiseAddr(test.addr)
250		if test.pass {
251			if err != nil {
252				t.Fatalf("bad: %v", err)
253			}
254		} else {
255			if err == nil {
256				t.Fatalf("bad, expected fail")
257			} else {
258				continue
259			}
260		}
261
262		if c.advertiseHost != test.host {
263			t.Fatalf("bad: %v != %v", c.advertiseHost, test.host)
264		}
265
266		if c.advertisePort != test.port {
267			t.Fatalf("bad: %v != %v", c.advertisePort, test.port)
268		}
269	}
270}
271
272func TestConsul_NotifyActiveStateChange(t *testing.T) {
273	addr := os.Getenv("CONSUL_HTTP_ADDR")
274	if addr == "" {
275		t.Skipf("No consul process running, skipping test")
276	}
277
278	c := testConsulBackend(t)
279
280	if err := c.NotifyActiveStateChange(); err != nil {
281		t.Fatalf("bad: %v", err)
282	}
283}
284
285func TestConsul_NotifySealedStateChange(t *testing.T) {
286	addr := os.Getenv("CONSUL_HTTP_ADDR")
287	if addr == "" {
288		t.Skipf("No consul process running, skipping test")
289	}
290
291	c := testConsulBackend(t)
292
293	if err := c.NotifySealedStateChange(); err != nil {
294		t.Fatalf("bad: %v", err)
295	}
296}
297
298func TestConsul_checkID(t *testing.T) {
299	c := testConsulBackend(t)
300	if c.checkID() != "vault-sealed-check" {
301		t.Errorf("bad")
302	}
303}
304
305func TestConsul_serviceID(t *testing.T) {
306	passingTests := []struct {
307		name          string
308		advertiseAddr string
309		serviceName   string
310		expected      string
311	}{
312		{
313			name:          "valid host w/o slash",
314			advertiseAddr: "http://127.0.0.1:8200",
315			serviceName:   "sea-tech-astronomy",
316			expected:      "sea-tech-astronomy:127.0.0.1:8200",
317		},
318		{
319			name:          "valid host w/ slash",
320			advertiseAddr: "http://127.0.0.1:8200/",
321			serviceName:   "sea-tech-astronomy",
322			expected:      "sea-tech-astronomy:127.0.0.1:8200",
323		},
324		{
325			name:          "valid https host w/ slash",
326			advertiseAddr: "https://127.0.0.1:8200/",
327			serviceName:   "sea-tech-astronomy",
328			expected:      "sea-tech-astronomy:127.0.0.1:8200",
329		},
330	}
331
332	for _, test := range passingTests {
333		c := testConsulBackendConfig(t, &consulConf{
334			"service": test.serviceName,
335		})
336
337		if err := c.setAdvertiseAddr(test.advertiseAddr); err != nil {
338			t.Fatalf("bad: %s %v", test.name, err)
339		}
340
341		serviceID := c.serviceID()
342		if serviceID != test.expected {
343			t.Fatalf("bad: %v != %v", serviceID, test.expected)
344		}
345	}
346}
347
348func TestConsulBackend(t *testing.T) {
349	addr := os.Getenv("CONSUL_HTTP_ADDR")
350	if addr == "" {
351		t.Skipf("No consul process running, skipping test")
352	}
353
354	conf := api.DefaultConfig()
355	conf.Address = addr
356	client, err := api.NewClient(conf)
357	if err != nil {
358		t.Fatalf("err: %v", err)
359	}
360
361	randPath := fmt.Sprintf("vault-%d/", time.Now().Unix())
362	defer func() {
363		client.KV().DeleteTree(randPath, nil)
364	}()
365
366	logger := log.New(os.Stderr, "", log.LstdFlags)
367	b, err := NewBackend("consul", logger, map[string]string{
368		"address":      addr,
369		"path":         randPath,
370		"max_parallel": "256",
371	})
372	if err != nil {
373		t.Fatalf("err: %s", err)
374	}
375
376	testBackend(t, b)
377	testBackend_ListPrefix(t, b)
378}
379
380func TestConsulHABackend(t *testing.T) {
381	addr := os.Getenv("CONSUL_HTTP_ADDR")
382	if addr == "" {
383		t.Skipf("No consul process running, skipping test")
384	}
385
386	conf := api.DefaultConfig()
387	conf.Address = addr
388	client, err := api.NewClient(conf)
389	if err != nil {
390		t.Fatalf("err: %v", err)
391	}
392
393	randPath := fmt.Sprintf("vault-%d/", time.Now().Unix())
394	defer func() {
395		client.KV().DeleteTree(randPath, nil)
396	}()
397
398	logger := log.New(os.Stderr, "", log.LstdFlags)
399	b, err := NewBackend("consul", logger, map[string]string{
400		"address":      addr,
401		"path":         randPath,
402		"max_parallel": "-1",
403	})
404	if err != nil {
405		t.Fatalf("err: %s", err)
406	}
407
408	ha, ok := b.(HABackend)
409	if !ok {
410		t.Fatalf("consul does not implement HABackend")
411	}
412	testHABackend(t, ha, ha)
413
414	detect, ok := b.(AdvertiseDetect)
415	if !ok {
416		t.Fatalf("consul does not implement AdvertiseDetect")
417	}
418	host, err := detect.DetectHostAddr()
419	if err != nil {
420		t.Fatalf("err: %s", err)
421	}
422	if host == "" {
423		t.Fatalf("bad addr: %v", host)
424	}
425}
426