1// Copyright 2019 The Prometheus Authors 2// Licensed under the Apache License, Version 2.0 (the "License"); 3// you may not use this file except in compliance with the License. 4// You may obtain a copy of the License at 5// 6// http://www.apache.org/licenses/LICENSE-2.0 7// 8// Unless required by applicable law or agreed to in writing, software 9// distributed under the License is distributed on an "AS IS" BASIS, 10// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11// See the License for the specific language governing permissions and 12// limitations under the License. 13 14package https 15 16import ( 17 "crypto/tls" 18 "crypto/x509" 19 "errors" 20 "fmt" 21 "io/ioutil" 22 "net" 23 "net/http" 24 "regexp" 25 "sync" 26 "testing" 27 "time" 28) 29 30var ( 31 port = getPort() 32 33 ErrorMap = map[string]*regexp.Regexp{ 34 "HTTP Response to HTTPS": regexp.MustCompile(`server gave HTTP response to HTTPS client`), 35 "No such file": regexp.MustCompile(`no such file`), 36 "Invalid argument": regexp.MustCompile(`invalid argument`), 37 "YAML error": regexp.MustCompile(`yaml`), 38 "Invalid ClientAuth": regexp.MustCompile(`invalid ClientAuth`), 39 "TLS handshake": regexp.MustCompile(`tls`), 40 "HTTP Request to HTTPS server": regexp.MustCompile(`HTTP`), 41 "Invalid CertPath": regexp.MustCompile(`missing TLSCertPath`), 42 "Invalid KeyPath": regexp.MustCompile(`missing TLSKeyPath`), 43 "ClientCA set without policy": regexp.MustCompile(`Client CA's have been configured without a Client Auth Policy`), 44 } 45) 46 47func getPort() string { 48 listener, err := net.Listen("tcp", ":0") 49 if err != nil { 50 panic(err) 51 } 52 defer listener.Close() 53 p := listener.Addr().(*net.TCPAddr).Port 54 return fmt.Sprintf(":%v", p) 55} 56 57type TestInputs struct { 58 Name string 59 Server func() *http.Server 60 UseNilServer bool 61 YAMLConfigPath string 62 ExpectedError *regexp.Regexp 63 UseTLSClient bool 64} 65 66func TestYAMLFiles(t *testing.T) { 67 testTables := []*TestInputs{ 68 { 69 Name: `path to config yml invalid`, 70 YAMLConfigPath: "somefile", 71 ExpectedError: ErrorMap["No such file"], 72 }, 73 { 74 Name: `empty config yml`, 75 YAMLConfigPath: "testdata/tls_config_empty.yml", 76 ExpectedError: ErrorMap["Invalid CertPath"], 77 }, 78 { 79 Name: `invalid config yml (invalid structure)`, 80 YAMLConfigPath: "testdata/tls_config_junk.yml", 81 ExpectedError: ErrorMap["YAML error"], 82 }, 83 { 84 Name: `invalid config yml (cert path empty)`, 85 YAMLConfigPath: "testdata/tls_config_noAuth_certPath_empty.bad.yml", 86 ExpectedError: ErrorMap["Invalid CertPath"], 87 }, 88 { 89 Name: `invalid config yml (key path empty)`, 90 YAMLConfigPath: "testdata/tls_config_noAuth_keyPath_empty.bad.yml", 91 ExpectedError: ErrorMap["Invalid KeyPath"], 92 }, 93 { 94 Name: `invalid config yml (cert path and key path empty)`, 95 YAMLConfigPath: "testdata/tls_config_noAuth_certPath_keyPath_empty.bad.yml", 96 ExpectedError: ErrorMap["Invalid CertPath"], 97 }, 98 { 99 Name: `invalid config yml (cert path invalid)`, 100 YAMLConfigPath: "testdata/tls_config_noAuth_certPath_invalid.bad.yml", 101 ExpectedError: ErrorMap["No such file"], 102 }, 103 { 104 Name: `invalid config yml (key path invalid)`, 105 YAMLConfigPath: "testdata/tls_config_noAuth_keyPath_invalid.bad.yml", 106 ExpectedError: ErrorMap["No such file"], 107 }, 108 { 109 Name: `invalid config yml (cert path and key path invalid)`, 110 YAMLConfigPath: "testdata/tls_config_noAuth_certPath_keyPath_invalid.bad.yml", 111 ExpectedError: ErrorMap["No such file"], 112 }, 113 { 114 Name: `invalid config yml (invalid ClientAuth)`, 115 YAMLConfigPath: "testdata/tls_config_noAuth.bad.yml", 116 ExpectedError: ErrorMap["ClientCA set without policy"], 117 }, 118 { 119 Name: `invalid config yml (invalid ClientCAs filepath)`, 120 YAMLConfigPath: "testdata/tls_config_auth_clientCAs_invalid.bad.yml", 121 ExpectedError: ErrorMap["No such file"], 122 }, 123 } 124 for _, testInputs := range testTables { 125 t.Run(testInputs.Name, testInputs.Test) 126 } 127} 128 129func TestServerBehaviour(t *testing.T) { 130 testTables := []*TestInputs{ 131 { 132 Name: `empty string YAMLConfigPath and default client`, 133 YAMLConfigPath: "", 134 ExpectedError: nil, 135 }, 136 { 137 Name: `empty string YAMLConfigPath and TLS client`, 138 YAMLConfigPath: "", 139 UseTLSClient: true, 140 ExpectedError: ErrorMap["HTTP Response to HTTPS"], 141 }, 142 { 143 Name: `valid tls config yml and default client`, 144 YAMLConfigPath: "testdata/tls_config_noAuth.good.yml", 145 ExpectedError: ErrorMap["HTTP Request to HTTPS server"], 146 }, 147 { 148 Name: `valid tls config yml and tls client`, 149 YAMLConfigPath: "testdata/tls_config_noAuth.good.yml", 150 UseTLSClient: true, 151 ExpectedError: nil, 152 }, 153 } 154 for _, testInputs := range testTables { 155 t.Run(testInputs.Name, testInputs.Test) 156 } 157} 158 159func TestConfigReloading(t *testing.T) { 160 errorChannel := make(chan error, 1) 161 var once sync.Once 162 recordConnectionError := func(err error) { 163 once.Do(func() { 164 errorChannel <- err 165 }) 166 } 167 defer func() { 168 if recover() != nil { 169 recordConnectionError(errors.New("Panic in test function")) 170 } 171 }() 172 173 goodYAMLPath := "testdata/tls_config_noAuth.good.yml" 174 badYAMLPath := "testdata/tls_config_noAuth.good.blocking.yml" 175 176 server := &http.Server{ 177 Addr: port, 178 Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 179 w.Write([]byte("Hello World!")) 180 }), 181 } 182 defer func() { 183 server.Close() 184 }() 185 186 go func() { 187 defer func() { 188 if recover() != nil { 189 recordConnectionError(errors.New("Panic starting server")) 190 } 191 }() 192 err := Listen(server, badYAMLPath) 193 recordConnectionError(err) 194 }() 195 196 client := getTLSClient() 197 198 TestClientConnection := func() error { 199 time.Sleep(250 * time.Millisecond) 200 r, err := client.Get("https://localhost" + port) 201 if err != nil { 202 return (err) 203 } 204 body, err := ioutil.ReadAll(r.Body) 205 if err != nil { 206 return (err) 207 } 208 if string(body) != "Hello World!" { 209 return (errors.New(string(body))) 210 } 211 return (nil) 212 } 213 214 err := TestClientConnection() 215 if err == nil { 216 recordConnectionError(errors.New("connection accepted but should have failed")) 217 } else { 218 swapFileContents(goodYAMLPath, badYAMLPath) 219 defer swapFileContents(goodYAMLPath, badYAMLPath) 220 err = TestClientConnection() 221 if err != nil { 222 recordConnectionError(errors.New("connection failed but should have been accepted")) 223 } else { 224 225 recordConnectionError(nil) 226 } 227 } 228 229 err = <-errorChannel 230 if err != nil { 231 t.Errorf(" *** Failed test: %s *** Returned error: %v", "TestConfigReloading", err) 232 } 233} 234 235func (test *TestInputs) Test(t *testing.T) { 236 errorChannel := make(chan error, 1) 237 var once sync.Once 238 recordConnectionError := func(err error) { 239 once.Do(func() { 240 errorChannel <- err 241 }) 242 } 243 defer func() { 244 if recover() != nil { 245 recordConnectionError(errors.New("Panic in test function")) 246 } 247 }() 248 249 var server *http.Server 250 if test.UseNilServer { 251 server = nil 252 } else { 253 server = &http.Server{ 254 Addr: port, 255 Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 256 w.Write([]byte("Hello World!")) 257 }), 258 } 259 defer func() { 260 server.Close() 261 }() 262 } 263 go func() { 264 defer func() { 265 if recover() != nil { 266 recordConnectionError(errors.New("Panic starting server")) 267 } 268 }() 269 err := Listen(server, test.YAMLConfigPath) 270 recordConnectionError(err) 271 }() 272 273 var ClientConnection func() (*http.Response, error) 274 if test.UseTLSClient { 275 ClientConnection = func() (*http.Response, error) { 276 client := getTLSClient() 277 return client.Get("https://localhost" + port) 278 } 279 } else { 280 ClientConnection = func() (*http.Response, error) { 281 client := http.DefaultClient 282 return client.Get("http://localhost" + port) 283 } 284 } 285 go func() { 286 time.Sleep(250 * time.Millisecond) 287 r, err := ClientConnection() 288 if err != nil { 289 recordConnectionError(err) 290 return 291 } 292 body, err := ioutil.ReadAll(r.Body) 293 if err != nil { 294 recordConnectionError(err) 295 return 296 } 297 if string(body) != "Hello World!" { 298 recordConnectionError(errors.New(string(body))) 299 return 300 } 301 recordConnectionError(nil) 302 }() 303 err := <-errorChannel 304 if test.isCorrectError(err) == false { 305 if test.ExpectedError == nil { 306 t.Logf("Expected no error, got error: %v", err) 307 } else { 308 t.Logf("Expected error matching regular expression: %v", test.ExpectedError) 309 t.Logf("Got: %v", err) 310 } 311 t.Fail() 312 } 313} 314 315func (test *TestInputs) isCorrectError(returnedError error) bool { 316 switch { 317 case returnedError == nil && test.ExpectedError == nil: 318 case returnedError != nil && test.ExpectedError != nil && test.ExpectedError.MatchString(returnedError.Error()): 319 default: 320 return false 321 } 322 return true 323} 324 325func getTLSClient() *http.Client { 326 cert, err := ioutil.ReadFile("testdata/tls-ca-chain.pem") 327 if err != nil { 328 panic("Unable to start TLS client. Check cert path") 329 } 330 client := &http.Client{ 331 Transport: &http.Transport{ 332 TLSClientConfig: &tls.Config{ 333 RootCAs: func() *x509.CertPool { 334 caCertPool := x509.NewCertPool() 335 caCertPool.AppendCertsFromPEM(cert) 336 return caCertPool 337 }(), 338 }, 339 }, 340 } 341 return client 342} 343 344func swapFileContents(file1, file2 string) error { 345 content1, err := ioutil.ReadFile(file1) 346 if err != nil { 347 return err 348 } 349 content2, err := ioutil.ReadFile(file2) 350 if err != nil { 351 return err 352 } 353 err = ioutil.WriteFile(file1, content2, 0644) 354 if err != nil { 355 return err 356 } 357 err = ioutil.WriteFile(file2, content1, 0644) 358 if err != nil { 359 return err 360 } 361 return nil 362} 363