1package server
2
3import (
4	"crypto/tls"
5	"crypto/x509"
6	"fmt"
7	"io/ioutil"
8	"math/rand"
9	"net"
10	"os"
11	"testing"
12	"time"
13
14	"github.com/mitchellh/cli"
15)
16
17func TestTCPListener(t *testing.T) {
18	ln, _, _, err := tcpListenerFactory(map[string]interface{}{
19		"address":     "127.0.0.1:0",
20		"tls_disable": "1",
21	}, nil, cli.NewMockUi())
22	if err != nil {
23		t.Fatalf("err: %s", err)
24	}
25
26	connFn := func(lnReal net.Listener) (net.Conn, error) {
27		return net.Dial("tcp", ln.Addr().String())
28	}
29
30	testListenerImpl(t, ln, connFn, "")
31}
32
33// TestTCPListener_tls tests both TLS generally and also the reload capability
34// of core, system backend, and the listener logic
35func TestTCPListener_tls(t *testing.T) {
36	wd, _ := os.Getwd()
37	wd += "/test-fixtures/reload/"
38
39	td, err := ioutil.TempDir("", fmt.Sprintf("vault-test-%d", rand.New(rand.NewSource(time.Now().Unix())).Int63()))
40	if err != nil {
41		t.Fatal(err)
42	}
43	defer os.RemoveAll(td)
44
45	// Setup initial certs
46	inBytes, _ := ioutil.ReadFile(wd + "reload_ca.pem")
47	certPool := x509.NewCertPool()
48	ok := certPool.AppendCertsFromPEM(inBytes)
49	if !ok {
50		t.Fatal("not ok when appending CA cert")
51	}
52
53	ln, _, _, err := tcpListenerFactory(map[string]interface{}{
54		"address":                            "127.0.0.1:0",
55		"tls_cert_file":                      wd + "reload_foo.pem",
56		"tls_key_file":                       wd + "reload_foo.key",
57		"tls_require_and_verify_client_cert": "true",
58		"tls_client_ca_file":                 wd + "reload_ca.pem",
59	}, nil, cli.NewMockUi())
60	if err != nil {
61		t.Fatalf("err: %s", err)
62	}
63	cwd, _ := os.Getwd()
64
65	clientCert, _ := tls.LoadX509KeyPair(
66		cwd+"/test-fixtures/reload/reload_foo.pem",
67		cwd+"/test-fixtures/reload/reload_foo.key")
68
69	connFn := func(clientCerts bool) func(net.Listener) (net.Conn, error) {
70		return func(lnReal net.Listener) (net.Conn, error) {
71			conf := &tls.Config{
72				RootCAs: certPool,
73			}
74			if clientCerts {
75				conf.Certificates = []tls.Certificate{clientCert}
76			}
77			conn, err := tls.Dial("tcp", ln.Addr().String(), conf)
78
79			if err != nil {
80				return nil, err
81			}
82			if err = conn.Handshake(); err != nil {
83				return nil, err
84			}
85			return conn, nil
86		}
87	}
88
89	testListenerImpl(t, ln, connFn(true), "foo.example.com")
90
91	ln, _, _, err = tcpListenerFactory(map[string]interface{}{
92		"address":                            "127.0.0.1:0",
93		"tls_cert_file":                      wd + "reload_foo.pem",
94		"tls_key_file":                       wd + "reload_foo.key",
95		"tls_require_and_verify_client_cert": "true",
96		"tls_disable_client_certs":           "true",
97		"tls_client_ca_file":                 wd + "reload_ca.pem",
98	}, nil, cli.NewMockUi())
99	if err == nil {
100		t.Fatal("expected error due to mutually exclusive client cert options")
101	}
102
103	ln, _, _, err = tcpListenerFactory(map[string]interface{}{
104		"address":                  "127.0.0.1:0",
105		"tls_cert_file":            wd + "reload_foo.pem",
106		"tls_key_file":             wd + "reload_foo.key",
107		"tls_disable_client_certs": "true",
108		"tls_client_ca_file":       wd + "reload_ca.pem",
109	}, nil, cli.NewMockUi())
110	if err != nil {
111		t.Fatalf("err: %s", err)
112	}
113
114	testListenerImpl(t, ln, connFn(false), "foo.example.com")
115}
116