1// Copyright 2019 The Prometheus Authors
2// Licensed under the Apache License, Version 2.0 (the "License");
3// you may not use this file except in compliance with the License.
4// You may obtain a copy of the License at
5//
6// http://www.apache.org/licenses/LICENSE-2.0
7//
8// Unless required by applicable law or agreed to in writing, software
9// distributed under the License is distributed on an "AS IS" BASIS,
10// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11// See the License for the specific language governing permissions and
12// limitations under the License.
13
14package https
15
16import (
17	"crypto/tls"
18	"crypto/x509"
19	"errors"
20	"fmt"
21	"io/ioutil"
22	"net"
23	"net/http"
24	"regexp"
25	"sync"
26	"testing"
27	"time"
28)
29
30var (
31	port = getPort()
32
33	ErrorMap = map[string]*regexp.Regexp{
34		"HTTP Response to HTTPS":       regexp.MustCompile(`server gave HTTP response to HTTPS client`),
35		"No such file":                 regexp.MustCompile(`no such file`),
36		"Invalid argument":             regexp.MustCompile(`invalid argument`),
37		"YAML error":                   regexp.MustCompile(`yaml`),
38		"Invalid ClientAuth":           regexp.MustCompile(`invalid ClientAuth`),
39		"TLS handshake":                regexp.MustCompile(`tls`),
40		"HTTP Request to HTTPS server": regexp.MustCompile(`HTTP`),
41		"Invalid CertPath":             regexp.MustCompile(`missing TLSCertPath`),
42		"Invalid KeyPath":              regexp.MustCompile(`missing TLSKeyPath`),
43		"ClientCA set without policy":  regexp.MustCompile(`Client CA's have been configured without a Client Auth Policy`),
44	}
45)
46
// getPort asks the kernel for an unused TCP port by listening on ":0",
// releases the listener again, and returns the port formatted as ":N".
// It panics if the probe listener cannot be opened.
func getPort() string {
	l, err := net.Listen("tcp", ":0")
	if err != nil {
		panic(err)
	}
	defer l.Close()

	addr := l.Addr().(*net.TCPAddr)
	return fmt.Sprintf(":%d", addr.Port)
}
56
// TestInputs describes one table-driven case; its Test method runs it as a
// subtest.
type TestInputs struct {
	// Name is passed to t.Run as the subtest name.
	Name           string
	// NOTE(review): Server is not read by Test in this file (Test builds its
	// own server, or uses nil when UseNilServer is set) — confirm whether any
	// other file still relies on it.
	Server         func() *http.Server
	// UseNilServer makes Test pass a nil *http.Server to Listen.
	UseNilServer   bool
	// YAMLConfigPath is the web config file handed to Listen; "" is also a
	// valid case.
	YAMLConfigPath string
	// ExpectedError must match the first error reported by the server or
	// client; nil means the case must succeed.
	ExpectedError  *regexp.Regexp
	// UseTLSClient selects an https:// request through getTLSClient instead
	// of a plain http:// request through http.DefaultClient.
	UseTLSClient   bool
}
65
66func TestYAMLFiles(t *testing.T) {
67	testTables := []*TestInputs{
68		{
69			Name:           `path to config yml invalid`,
70			YAMLConfigPath: "somefile",
71			ExpectedError:  ErrorMap["No such file"],
72		},
73		{
74			Name:           `empty config yml`,
75			YAMLConfigPath: "testdata/tls_config_empty.yml",
76			ExpectedError:  ErrorMap["Invalid CertPath"],
77		},
78		{
79			Name:           `invalid config yml (invalid structure)`,
80			YAMLConfigPath: "testdata/tls_config_junk.yml",
81			ExpectedError:  ErrorMap["YAML error"],
82		},
83		{
84			Name:           `invalid config yml (cert path empty)`,
85			YAMLConfigPath: "testdata/tls_config_noAuth_certPath_empty.bad.yml",
86			ExpectedError:  ErrorMap["Invalid CertPath"],
87		},
88		{
89			Name:           `invalid config yml (key path empty)`,
90			YAMLConfigPath: "testdata/tls_config_noAuth_keyPath_empty.bad.yml",
91			ExpectedError:  ErrorMap["Invalid KeyPath"],
92		},
93		{
94			Name:           `invalid config yml (cert path and key path empty)`,
95			YAMLConfigPath: "testdata/tls_config_noAuth_certPath_keyPath_empty.bad.yml",
96			ExpectedError:  ErrorMap["Invalid CertPath"],
97		},
98		{
99			Name:           `invalid config yml (cert path invalid)`,
100			YAMLConfigPath: "testdata/tls_config_noAuth_certPath_invalid.bad.yml",
101			ExpectedError:  ErrorMap["No such file"],
102		},
103		{
104			Name:           `invalid config yml (key path invalid)`,
105			YAMLConfigPath: "testdata/tls_config_noAuth_keyPath_invalid.bad.yml",
106			ExpectedError:  ErrorMap["No such file"],
107		},
108		{
109			Name:           `invalid config yml (cert path and key path invalid)`,
110			YAMLConfigPath: "testdata/tls_config_noAuth_certPath_keyPath_invalid.bad.yml",
111			ExpectedError:  ErrorMap["No such file"],
112		},
113		{
114			Name:           `invalid config yml (invalid ClientAuth)`,
115			YAMLConfigPath: "testdata/tls_config_noAuth.bad.yml",
116			ExpectedError:  ErrorMap["ClientCA set without policy"],
117		},
118		{
119			Name:           `invalid config yml (invalid ClientCAs filepath)`,
120			YAMLConfigPath: "testdata/tls_config_auth_clientCAs_invalid.bad.yml",
121			ExpectedError:  ErrorMap["No such file"],
122		},
123	}
124	for _, testInputs := range testTables {
125		t.Run(testInputs.Name, testInputs.Test)
126	}
127}
128
129func TestServerBehaviour(t *testing.T) {
130	testTables := []*TestInputs{
131		{
132			Name:           `empty string YAMLConfigPath and default client`,
133			YAMLConfigPath: "",
134			ExpectedError:  nil,
135		},
136		{
137			Name:           `empty string YAMLConfigPath and TLS client`,
138			YAMLConfigPath: "",
139			UseTLSClient:   true,
140			ExpectedError:  ErrorMap["HTTP Response to HTTPS"],
141		},
142		{
143			Name:           `valid tls config yml and default client`,
144			YAMLConfigPath: "testdata/tls_config_noAuth.good.yml",
145			ExpectedError:  ErrorMap["HTTP Request to HTTPS server"],
146		},
147		{
148			Name:           `valid tls config yml and tls client`,
149			YAMLConfigPath: "testdata/tls_config_noAuth.good.yml",
150			UseTLSClient:   true,
151			ExpectedError:  nil,
152		},
153	}
154	for _, testInputs := range testTables {
155		t.Run(testInputs.Name, testInputs.Test)
156	}
157}
158
159func TestConfigReloading(t *testing.T) {
160	errorChannel := make(chan error, 1)
161	var once sync.Once
162	recordConnectionError := func(err error) {
163		once.Do(func() {
164			errorChannel <- err
165		})
166	}
167	defer func() {
168		if recover() != nil {
169			recordConnectionError(errors.New("Panic in test function"))
170		}
171	}()
172
173	goodYAMLPath := "testdata/tls_config_noAuth.good.yml"
174	badYAMLPath := "testdata/tls_config_noAuth.good.blocking.yml"
175
176	server := &http.Server{
177		Addr: port,
178		Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
179			w.Write([]byte("Hello World!"))
180		}),
181	}
182	defer func() {
183		server.Close()
184	}()
185
186	go func() {
187		defer func() {
188			if recover() != nil {
189				recordConnectionError(errors.New("Panic starting server"))
190			}
191		}()
192		err := Listen(server, badYAMLPath)
193		recordConnectionError(err)
194	}()
195
196	client := getTLSClient()
197
198	TestClientConnection := func() error {
199		time.Sleep(250 * time.Millisecond)
200		r, err := client.Get("https://localhost" + port)
201		if err != nil {
202			return (err)
203		}
204		body, err := ioutil.ReadAll(r.Body)
205		if err != nil {
206			return (err)
207		}
208		if string(body) != "Hello World!" {
209			return (errors.New(string(body)))
210		}
211		return (nil)
212	}
213
214	err := TestClientConnection()
215	if err == nil {
216		recordConnectionError(errors.New("connection accepted but should have failed"))
217	} else {
218		swapFileContents(goodYAMLPath, badYAMLPath)
219		defer swapFileContents(goodYAMLPath, badYAMLPath)
220		err = TestClientConnection()
221		if err != nil {
222			recordConnectionError(errors.New("connection failed but should have been accepted"))
223		} else {
224
225			recordConnectionError(nil)
226		}
227	}
228
229	err = <-errorChannel
230	if err != nil {
231		t.Errorf(" *** Failed test: %s *** Returned error: %v", "TestConfigReloading", err)
232	}
233}
234
235func (test *TestInputs) Test(t *testing.T) {
236	errorChannel := make(chan error, 1)
237	var once sync.Once
238	recordConnectionError := func(err error) {
239		once.Do(func() {
240			errorChannel <- err
241		})
242	}
243	defer func() {
244		if recover() != nil {
245			recordConnectionError(errors.New("Panic in test function"))
246		}
247	}()
248
249	var server *http.Server
250	if test.UseNilServer {
251		server = nil
252	} else {
253		server = &http.Server{
254			Addr: port,
255			Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
256				w.Write([]byte("Hello World!"))
257			}),
258		}
259		defer func() {
260			server.Close()
261		}()
262	}
263	go func() {
264		defer func() {
265			if recover() != nil {
266				recordConnectionError(errors.New("Panic starting server"))
267			}
268		}()
269		err := Listen(server, test.YAMLConfigPath)
270		recordConnectionError(err)
271	}()
272
273	var ClientConnection func() (*http.Response, error)
274	if test.UseTLSClient {
275		ClientConnection = func() (*http.Response, error) {
276			client := getTLSClient()
277			return client.Get("https://localhost" + port)
278		}
279	} else {
280		ClientConnection = func() (*http.Response, error) {
281			client := http.DefaultClient
282			return client.Get("http://localhost" + port)
283		}
284	}
285	go func() {
286		time.Sleep(250 * time.Millisecond)
287		r, err := ClientConnection()
288		if err != nil {
289			recordConnectionError(err)
290			return
291		}
292		body, err := ioutil.ReadAll(r.Body)
293		if err != nil {
294			recordConnectionError(err)
295			return
296		}
297		if string(body) != "Hello World!" {
298			recordConnectionError(errors.New(string(body)))
299			return
300		}
301		recordConnectionError(nil)
302	}()
303	err := <-errorChannel
304	if test.isCorrectError(err) == false {
305		if test.ExpectedError == nil {
306			t.Logf("Expected no error, got error: %v", err)
307		} else {
308			t.Logf("Expected error matching regular expression: %v", test.ExpectedError)
309			t.Logf("Got: %v", err)
310		}
311		t.Fail()
312	}
313}
314
315func (test *TestInputs) isCorrectError(returnedError error) bool {
316	switch {
317	case returnedError == nil && test.ExpectedError == nil:
318	case returnedError != nil && test.ExpectedError != nil && test.ExpectedError.MatchString(returnedError.Error()):
319	default:
320		return false
321	}
322	return true
323}
324
// getTLSClient returns an HTTP client whose TLS configuration trusts only the
// CA certificates found in testdata/tls-ca-chain.pem.
//
// It panics, including the underlying cause, if the file cannot be read. It
// also panics if the file yields no PEM certificates. Without that check the
// client would silently trust nothing, and every TLS request would fail later
// with a certificate-verification error instead.
func getTLSClient() *http.Client {
	cert, err := ioutil.ReadFile("testdata/tls-ca-chain.pem")
	if err != nil {
		panic(fmt.Sprintf("Unable to start TLS client. Check cert path: %v", err))
	}

	caCertPool := x509.NewCertPool()
	if !caCertPool.AppendCertsFromPEM(cert) {
		panic("Unable to start TLS client. No certificates parsed from testdata/tls-ca-chain.pem")
	}

	return &http.Client{
		Transport: &http.Transport{
			TLSClientConfig: &tls.Config{
				RootCAs: caCertPool,
			},
		},
	}
}
343
// swapFileContents exchanges the contents of file1 and file2.
//
// Both files are read in full before either is written, so a read failure
// leaves both untouched. If writing file2 fails after file1 has already been
// overwritten, it makes a best-effort attempt to restore file1's original
// contents. Otherwise those bytes would exist nowhere on disk. The write
// error is returned, annotated if the rollback also failed.
//
// The 0644 mode only matters if WriteFile has to create a file. Existing
// files keep their permissions.
func swapFileContents(file1, file2 string) error {
	content1, err := ioutil.ReadFile(file1)
	if err != nil {
		return err
	}
	content2, err := ioutil.ReadFile(file2)
	if err != nil {
		return err
	}
	if err := ioutil.WriteFile(file1, content2, 0644); err != nil {
		return err
	}
	if err := ioutil.WriteFile(file2, content1, 0644); err != nil {
		// Roll file1 back so content1 is not lost.
		if rerr := ioutil.WriteFile(file1, content1, 0644); rerr != nil {
			return fmt.Errorf("%v; additionally failed to restore %s: %v", err, file1, rerr)
		}
		return err
	}
	return nil
}
363