1/* 2Copyright 2016 The Kubernetes Authors. 3 4Licensed under the Apache License, Version 2.0 (the "License"); 5you may not use this file except in compliance with the License. 6You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10Unless required by applicable law or agreed to in writing, software 11distributed under the License is distributed on an "AS IS" BASIS, 12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13See the License for the specific language governing permissions and 14limitations under the License. 15*/ 16 17package options 18 19import ( 20 "bytes" 21 cryptorand "crypto/rand" 22 "crypto/rsa" 23 "crypto/tls" 24 "crypto/x509" 25 "crypto/x509/pkix" 26 "encoding/base64" 27 "encoding/pem" 28 "fmt" 29 "io/ioutil" 30 "math/big" 31 "net" 32 "os" 33 "path/filepath" 34 "reflect" 35 "strconv" 36 "strings" 37 "testing" 38 "time" 39 40 "k8s.io/apimachinery/pkg/runtime" 41 "k8s.io/apimachinery/pkg/runtime/serializer" 42 "k8s.io/apimachinery/pkg/version" 43 "k8s.io/apiserver/pkg/server" 44 "k8s.io/client-go/discovery" 45 restclient "k8s.io/client-go/rest" 46 cliflag "k8s.io/component-base/cli/flag" 47) 48 49func setUp(t *testing.T) server.Config { 50 scheme := runtime.NewScheme() 51 codecs := serializer.NewCodecFactory(scheme) 52 53 config := server.NewConfig(codecs) 54 55 return *config 56} 57 58type TestCertSpec struct { 59 host string 60 names, ips []string // in certificate 61} 62 63type NamedTestCertSpec struct { 64 TestCertSpec 65 explicitNames []string // as --tls-sni-cert-key explicit names 66} 67 68func TestServerRunWithSNI(t *testing.T) { 69 tests := map[string]struct { 70 Cert TestCertSpec 71 SNICerts []NamedTestCertSpec 72 ExpectedCertIndex int 73 74 // passed in the client hello info, "localhost" if unset 75 ServerName string 76 77 // optional ip or hostname to pass to NewLoopbackClientConfig 78 LoopbackClientBindAddressOverride string 79 ExpectLoopbackClientError bool 80 }{ 81 "only one cert": { 82 Cert: TestCertSpec{ 83 host: "localhost", 84 ips: []string{"127.0.0.1"}, 85 }, 86 ExpectedCertIndex: -1, 87 }, 88 "cert with multiple alternate names": { 89 Cert: TestCertSpec{ 90 host: "localhost", 91 names: []string{"test.com"}, 92 ips: []string{"127.0.0.1"}, 93 }, 94 ExpectedCertIndex: -1, 95 ServerName: "test.com", 96 }, 97 "one SNI and the default cert with the same name": { 98 Cert: TestCertSpec{ 99 host: "localhost", 100 ips: []string{"127.0.0.1"}, 101 }, 102 SNICerts: []NamedTestCertSpec{ 103 { 104 TestCertSpec: TestCertSpec{ 105 host: "localhost", 106 }, 107 }, 108 }, 109 ExpectedCertIndex: 0, 110 }, 111 "matching SNI cert": { 112 Cert: TestCertSpec{ 113 host: "localhost", 114 ips: []string{"127.0.0.1"}, 115 }, 116 SNICerts: []NamedTestCertSpec{ 117 { 118 TestCertSpec: TestCertSpec{ 119 host: "test.com", 120 }, 121 }, 122 }, 123 ExpectedCertIndex: 0, 124 ServerName: "test.com", 125 }, 126 "matching IP in SNI cert and the server cert": { 127 // IPs must not be passed via SNI. Hence, the ServerName in the 128 // HELLO packet is empty and the server should select the non-SNI cert. 129 Cert: TestCertSpec{ 130 host: "localhost", 131 ips: []string{"10.0.0.1", "127.0.0.1"}, 132 }, 133 SNICerts: []NamedTestCertSpec{ 134 { 135 TestCertSpec: TestCertSpec{ 136 host: "test.com", 137 ips: []string{"10.0.0.1"}, 138 }, 139 }, 140 }, 141 ExpectedCertIndex: -1, 142 ServerName: "10.0.0.1", 143 }, 144 "wildcards": { 145 Cert: TestCertSpec{ 146 host: "localhost", 147 ips: []string{"127.0.0.1"}, 148 }, 149 SNICerts: []NamedTestCertSpec{ 150 { 151 TestCertSpec: TestCertSpec{ 152 host: "test.com", 153 names: []string{"*.test.com"}, 154 }, 155 }, 156 }, 157 ExpectedCertIndex: 0, 158 ServerName: "www.test.com", 159 }, 160 161 "loopback: LoopbackClientServerNameOverride not on any cert": { 162 Cert: TestCertSpec{ 163 host: "test.com", 164 }, 165 SNICerts: []NamedTestCertSpec{ 166 { 167 TestCertSpec: TestCertSpec{ 168 host: "localhost", 169 }, 170 }, 171 }, 172 ExpectedCertIndex: 0, 173 }, 174 "loopback: LoopbackClientServerNameOverride on server cert": { 175 Cert: TestCertSpec{ 176 host: server.LoopbackClientServerNameOverride, 177 }, 178 SNICerts: []NamedTestCertSpec{ 179 { 180 TestCertSpec: TestCertSpec{ 181 host: "localhost", 182 }, 183 }, 184 }, 185 ExpectedCertIndex: 0, 186 }, 187 "loopback: LoopbackClientServerNameOverride on SNI cert": { 188 Cert: TestCertSpec{ 189 host: "localhost", 190 }, 191 SNICerts: []NamedTestCertSpec{ 192 { 193 TestCertSpec: TestCertSpec{ 194 host: server.LoopbackClientServerNameOverride, 195 }, 196 }, 197 }, 198 ExpectedCertIndex: -1, 199 }, 200 "loopback: bind to 0.0.0.0 => loopback uses localhost": { 201 Cert: TestCertSpec{ 202 host: "localhost", 203 }, 204 ExpectedCertIndex: -1, 205 LoopbackClientBindAddressOverride: "0.0.0.0", 206 }, 207 } 208 209 specToName := func(spec TestCertSpec) string { 210 name := spec.host + "_" + strings.Join(spec.names, ",") + "_" + strings.Join(spec.ips, ",") 211 return strings.Replace(name, "*", "star", -1) 212 } 213 214 for title := range tests { 215 test := tests[title] 216 t.Run(title, func(t *testing.T) { 217 t.Parallel() 218 // create server cert 219 certDir := "testdata/" + specToName(test.Cert) 220 serverCertBundleFile := filepath.Join(certDir, "cert") 221 serverKeyFile := filepath.Join(certDir, "key") 222 err := getOrCreateTestCertFiles(serverCertBundleFile, serverKeyFile, test.Cert) 223 if err != nil { 224 t.Fatalf("failed to create server cert: %v", err) 225 } 226 ca, err := caCertFromBundle(serverCertBundleFile) 227 if err != nil { 228 t.Fatalf("failed to extract ca cert from server cert bundle: %v", err) 229 } 230 caCerts := []*x509.Certificate{ca} 231 232 // create SNI certs 233 var namedCertKeys []cliflag.NamedCertKey 234 serverSig, err := certFileSignature(serverCertBundleFile, serverKeyFile) 235 if err != nil { 236 t.Fatalf("failed to get server cert signature: %v", err) 237 } 238 signatures := map[string]int{ 239 serverSig: -1, 240 } 241 for j, c := range test.SNICerts { 242 sniDir := filepath.Join(certDir, specToName(c.TestCertSpec)) 243 certBundleFile := filepath.Join(sniDir, "cert") 244 keyFile := filepath.Join(sniDir, "key") 245 err := getOrCreateTestCertFiles(certBundleFile, keyFile, c.TestCertSpec) 246 if err != nil { 247 t.Fatalf("failed to create SNI cert %d: %v", j, err) 248 } 249 250 namedCertKeys = append(namedCertKeys, cliflag.NamedCertKey{ 251 KeyFile: keyFile, 252 CertFile: certBundleFile, 253 Names: c.explicitNames, 254 }) 255 256 ca, err := caCertFromBundle(certBundleFile) 257 if err != nil { 258 t.Fatalf("failed to extract ca cert from SNI cert %d: %v", j, err) 259 } 260 caCerts = append(caCerts, ca) 261 262 // store index in namedCertKeys with the signature as the key 263 sig, err := certFileSignature(certBundleFile, keyFile) 264 if err != nil { 265 t.Fatalf("failed get SNI cert %d signature: %v", j, err) 266 } 267 signatures[sig] = j 268 } 269 270 stopCh := make(chan struct{}) 271 defer close(stopCh) 272 273 // launch server 274 config := setUp(t) 275 276 v := fakeVersion() 277 config.Version = &v 278 279 config.EnableIndex = true 280 secureOptions := (&SecureServingOptions{ 281 BindAddress: net.ParseIP("127.0.0.1"), 282 BindPort: 6443, 283 ServerCert: GeneratableKeyCert{ 284 CertKey: CertKey{ 285 CertFile: serverCertBundleFile, 286 KeyFile: serverKeyFile, 287 }, 288 }, 289 SNICertKeys: namedCertKeys, 290 }).WithLoopback() 291 // use a random free port 292 ln, err := net.Listen("tcp", "127.0.0.1:0") 293 if err != nil { 294 t.Fatalf("failed to listen on 127.0.0.1:0") 295 } 296 297 secureOptions.Listener = ln 298 // get port 299 secureOptions.BindPort = ln.Addr().(*net.TCPAddr).Port 300 config.LoopbackClientConfig = &restclient.Config{} 301 if err := secureOptions.ApplyTo(&config.SecureServing, &config.LoopbackClientConfig); err != nil { 302 t.Fatalf("failed applying the SecureServingOptions: %v", err) 303 } 304 305 s, err := config.Complete(nil).New("test", server.NewEmptyDelegate()) 306 if err != nil { 307 t.Fatalf("failed creating the server: %v", err) 308 } 309 310 // add poststart hook to know when the server is up. 311 startedCh := make(chan struct{}) 312 s.AddPostStartHookOrDie("test-notifier", func(context server.PostStartHookContext) error { 313 close(startedCh) 314 return nil 315 }) 316 preparedServer := s.PrepareRun() 317 go func() { 318 if err := preparedServer.Run(stopCh); err != nil { 319 t.Fatal(err) 320 } 321 }() 322 323 // load ca certificates into a pool 324 roots := x509.NewCertPool() 325 for _, caCert := range caCerts { 326 roots.AddCert(caCert) 327 } 328 329 <-startedCh 330 331 // try to dial 332 addr := fmt.Sprintf("localhost:%d", secureOptions.BindPort) 333 t.Logf("Dialing %s as %q", addr, test.ServerName) 334 conn, err := tls.Dial("tcp", addr, &tls.Config{ 335 RootCAs: roots, 336 ServerName: test.ServerName, // used for SNI in the client HELLO packet 337 }) 338 if err != nil { 339 t.Fatalf("failed to connect: %v", err) 340 } 341 defer conn.Close() 342 343 // check returned server certificate 344 sig := x509CertSignature(conn.ConnectionState().PeerCertificates[0]) 345 gotCertIndex, found := signatures[sig] 346 if !found { 347 t.Errorf("unknown signature returned from server: %s", sig) 348 } 349 if gotCertIndex != test.ExpectedCertIndex { 350 t.Errorf("expected cert index %d, got cert index %d", test.ExpectedCertIndex, gotCertIndex) 351 } 352 353 // check that the loopback client can connect 354 host := "127.0.0.1" 355 if len(test.LoopbackClientBindAddressOverride) != 0 { 356 host = test.LoopbackClientBindAddressOverride 357 } 358 s.LoopbackClientConfig.Host = net.JoinHostPort(host, strconv.Itoa(secureOptions.BindPort)) 359 if test.ExpectLoopbackClientError { 360 if err == nil { 361 t.Fatalf("expected error creating loopback client config") 362 } 363 return 364 } 365 if err != nil { 366 t.Fatalf("failed creating loopback client config: %v", err) 367 } 368 client, err := discovery.NewDiscoveryClientForConfig(s.LoopbackClientConfig) 369 if err != nil { 370 t.Fatalf("failed to create loopback client: %v", err) 371 } 372 got, err := client.ServerVersion() 373 if err != nil { 374 t.Fatalf("failed to connect with loopback client: %v", err) 375 } 376 if expected := &v; !reflect.DeepEqual(got, expected) { 377 t.Errorf("loopback client didn't get correct version info: expected=%v got=%v", expected, got) 378 } 379 }) 380 } 381} 382 383func parseIPList(ips []string) []net.IP { 384 var netIPs []net.IP 385 for _, ip := range ips { 386 netIPs = append(netIPs, net.ParseIP(ip)) 387 } 388 return netIPs 389} 390 391func getOrCreateTestCertFiles(certFileName, keyFileName string, spec TestCertSpec) (err error) { 392 if _, err := os.Stat(certFileName); err == nil { 393 if _, err := os.Stat(keyFileName); err == nil { 394 return nil 395 } 396 } 397 398 certPem, keyPem, err := generateSelfSignedCertKey(spec.host, parseIPList(spec.ips), spec.names) 399 if err != nil { 400 return err 401 } 402 403 os.MkdirAll(filepath.Dir(certFileName), os.FileMode(0755)) 404 err = ioutil.WriteFile(certFileName, certPem, os.FileMode(0755)) 405 if err != nil { 406 return err 407 } 408 409 os.MkdirAll(filepath.Dir(keyFileName), os.FileMode(0755)) 410 err = ioutil.WriteFile(keyFileName, keyPem, os.FileMode(0755)) 411 if err != nil { 412 return err 413 } 414 415 return nil 416} 417 418func caCertFromBundle(bundlePath string) (*x509.Certificate, error) { 419 pemData, err := ioutil.ReadFile(bundlePath) 420 if err != nil { 421 return nil, err 422 } 423 424 // fetch last block 425 var block *pem.Block 426 for { 427 var nextBlock *pem.Block 428 nextBlock, pemData = pem.Decode(pemData) 429 if nextBlock == nil { 430 if block == nil { 431 return nil, fmt.Errorf("no certificate found in %q", bundlePath) 432 433 } 434 return x509.ParseCertificate(block.Bytes) 435 } 436 block = nextBlock 437 } 438} 439 440func x509CertSignature(cert *x509.Certificate) string { 441 return base64.StdEncoding.EncodeToString(cert.Signature) 442} 443 444func certFileSignature(certFile, keyFile string) (string, error) { 445 cert, err := tls.LoadX509KeyPair(certFile, keyFile) 446 if err != nil { 447 return "", err 448 } 449 return certSignature(cert) 450} 451 452func certSignature(cert tls.Certificate) (string, error) { 453 x509Certs, err := x509.ParseCertificates(cert.Certificate[0]) 454 if err != nil { 455 return "", err 456 } 457 return x509CertSignature(x509Certs[0]), nil 458} 459 460func fakeVersion() version.Info { 461 return version.Info{ 462 Major: "42", 463 Minor: "42", 464 GitVersion: "42", 465 GitCommit: "34973274ccef6ab4dfaaf86599792fa9c3fe4689", 466 GitTreeState: "Dirty", 467 } 468} 469 470// generateSelfSignedCertKey creates a self-signed certificate and key for the given host. 471// Host may be an IP or a DNS name 472// You may also specify additional subject alt names (either ip or dns names) for the certificate 473func generateSelfSignedCertKey(host string, alternateIPs []net.IP, alternateDNS []string) ([]byte, []byte, error) { 474 priv, err := rsa.GenerateKey(cryptorand.Reader, 2048) 475 if err != nil { 476 return nil, nil, err 477 } 478 479 template := x509.Certificate{ 480 SerialNumber: big.NewInt(1), 481 Subject: pkix.Name{ 482 CommonName: fmt.Sprintf("%s@%d", host, time.Now().Unix()), 483 }, 484 NotBefore: time.Unix(0, 0), 485 NotAfter: time.Now().Add(time.Hour * 24 * 365 * 100), 486 487 KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, 488 ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, 489 BasicConstraintsValid: true, 490 IsCA: true, 491 } 492 493 if ip := net.ParseIP(host); ip != nil { 494 template.IPAddresses = append(template.IPAddresses, ip) 495 } else { 496 template.DNSNames = append(template.DNSNames, host) 497 } 498 499 template.IPAddresses = append(template.IPAddresses, alternateIPs...) 500 template.DNSNames = append(template.DNSNames, alternateDNS...) 501 502 derBytes, err := x509.CreateCertificate(cryptorand.Reader, &template, &template, &priv.PublicKey, priv) 503 if err != nil { 504 return nil, nil, err 505 } 506 507 // Generate cert 508 certBuffer := bytes.Buffer{} 509 if err := pem.Encode(&certBuffer, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil { 510 return nil, nil, err 511 } 512 513 // Generate key 514 keyBuffer := bytes.Buffer{} 515 if err := pem.Encode(&keyBuffer, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)}); err != nil { 516 return nil, nil, err 517 } 518 519 return certBuffer.Bytes(), keyBuffer.Bytes(), nil 520} 521