1/*
2Copyright 2016 The Kubernetes Authors.
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8    http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15*/
16
17package options
18
import (
	"bytes"
	cryptorand "crypto/rand"
	"crypto/rsa"
	"crypto/tls"
	"crypto/x509"
	"crypto/x509/pkix"
	"encoding/base64"
	"encoding/pem"
	"fmt"
	"io/ioutil"
	"math/big"
	"net"
	"os"
	"path/filepath"
	"reflect"
	"strconv"
	"strings"
	"sync"
	"testing"
	"time"

	"k8s.io/apimachinery/pkg/runtime"
	"k8s.io/apimachinery/pkg/runtime/serializer"
	"k8s.io/apimachinery/pkg/version"
	"k8s.io/apiserver/pkg/server"
	"k8s.io/client-go/discovery"
	restclient "k8s.io/client-go/rest"
	cliflag "k8s.io/component-base/cli/flag"
)
48
// setUp builds a fresh server.Config whose codecs come from a brand-new, empty
// runtime.Scheme. The returned value is a copy of the config that
// server.NewConfig allocates. The *testing.T parameter is currently unused.
func setUp(t *testing.T) server.Config {
	codecFactory := serializer.NewCodecFactory(runtime.NewScheme())
	return *server.NewConfig(codecFactory)
}
57
// TestCertSpec describes a self-signed certificate that the tests generate
// (or reuse from testdata/) for a serving endpoint.
type TestCertSpec struct {
	// host is the certificate's primary name. The generator adds it as an IP
	// SAN when it parses as an IP address and as a DNS SAN otherwise.
	host string
	// names are additional DNS subject alternative names.
	names []string
	// ips are additional IP subject alternative names.
	ips []string
}

// NamedTestCertSpec is a TestCertSpec used as an SNI certificate.
type NamedTestCertSpec struct {
	TestCertSpec
	// explicitNames are forwarded as the Names of the cliflag.NamedCertKey
	// built for this cert, i.e. the explicit names of --tls-sni-cert-key.
	explicitNames []string
}
67
// TestServerRunWithSNI starts a secure server with a default serving cert and
// optional SNI certs. It then checks (1) which certificate the server presents
// for a given client ServerName and (2) that the server's loopback client
// config can reach the server and read its version.
//
// Certificates are cached under testdata/ and generated on first use.
func TestServerRunWithSNI(t *testing.T) {
	// ExpectedCertIndex is the index into SNICerts of the certificate the
	// server is expected to present, or -1 for the default serving cert.
	tests := map[string]struct {
		Cert              TestCertSpec
		SNICerts          []NamedTestCertSpec
		ExpectedCertIndex int

		// passed in the client hello info, "localhost" if unset
		ServerName string

		// optional ip or hostname to pass to NewLoopbackClientConfig
		LoopbackClientBindAddressOverride string
		ExpectLoopbackClientError         bool
	}{
		"only one cert": {
			Cert: TestCertSpec{
				host: "localhost",
				ips:  []string{"127.0.0.1"},
			},
			ExpectedCertIndex: -1,
		},
		"cert with multiple alternate names": {
			Cert: TestCertSpec{
				host:  "localhost",
				names: []string{"test.com"},
				ips:   []string{"127.0.0.1"},
			},
			ExpectedCertIndex: -1,
			ServerName:        "test.com",
		},
		"one SNI and the default cert with the same name": {
			Cert: TestCertSpec{
				host: "localhost",
				ips:  []string{"127.0.0.1"},
			},
			SNICerts: []NamedTestCertSpec{
				{
					TestCertSpec: TestCertSpec{
						host: "localhost",
					},
				},
			},
			ExpectedCertIndex: 0,
		},
		"matching SNI cert": {
			Cert: TestCertSpec{
				host: "localhost",
				ips:  []string{"127.0.0.1"},
			},
			SNICerts: []NamedTestCertSpec{
				{
					TestCertSpec: TestCertSpec{
						host: "test.com",
					},
				},
			},
			ExpectedCertIndex: 0,
			ServerName:        "test.com",
		},
		"matching IP in SNI cert and the server cert": {
			// IPs must not be passed via SNI. Hence, the ServerName in the
			// HELLO packet is empty and the server should select the non-SNI cert.
			Cert: TestCertSpec{
				host: "localhost",
				ips:  []string{"10.0.0.1", "127.0.0.1"},
			},
			SNICerts: []NamedTestCertSpec{
				{
					TestCertSpec: TestCertSpec{
						host: "test.com",
						ips:  []string{"10.0.0.1"},
					},
				},
			},
			ExpectedCertIndex: -1,
			ServerName:        "10.0.0.1",
		},
		"wildcards": {
			Cert: TestCertSpec{
				host: "localhost",
				ips:  []string{"127.0.0.1"},
			},
			SNICerts: []NamedTestCertSpec{
				{
					TestCertSpec: TestCertSpec{
						host:  "test.com",
						names: []string{"*.test.com"},
					},
				},
			},
			ExpectedCertIndex: 0,
			ServerName:        "www.test.com",
		},

		"loopback: LoopbackClientServerNameOverride not on any cert": {
			Cert: TestCertSpec{
				host: "test.com",
			},
			SNICerts: []NamedTestCertSpec{
				{
					TestCertSpec: TestCertSpec{
						host: "localhost",
					},
				},
			},
			ExpectedCertIndex: 0,
		},
		"loopback: LoopbackClientServerNameOverride on server cert": {
			Cert: TestCertSpec{
				host: server.LoopbackClientServerNameOverride,
			},
			SNICerts: []NamedTestCertSpec{
				{
					TestCertSpec: TestCertSpec{
						host: "localhost",
					},
				},
			},
			ExpectedCertIndex: 0,
		},
		"loopback: LoopbackClientServerNameOverride on SNI cert": {
			Cert: TestCertSpec{
				host: "localhost",
			},
			SNICerts: []NamedTestCertSpec{
				{
					TestCertSpec: TestCertSpec{
						host: server.LoopbackClientServerNameOverride,
					},
				},
			},
			ExpectedCertIndex: -1,
		},
		"loopback: bind to 0.0.0.0 => loopback uses localhost": {
			Cert: TestCertSpec{
				host: "localhost",
			},
			ExpectedCertIndex:                 -1,
			LoopbackClientBindAddressOverride: "0.0.0.0",
		},
	}

	// specToName derives a filesystem-safe directory name from a cert spec.
	specToName := func(spec TestCertSpec) string {
		name := spec.host + "_" + strings.Join(spec.names, ",") + "_" + strings.Join(spec.ips, ",")
		return strings.Replace(name, "*", "star", -1)
	}

	for title := range tests {
		test := tests[title]
		t.Run(title, func(t *testing.T) {
			t.Parallel()
			// create server cert
			certDir := "testdata/" + specToName(test.Cert)
			serverCertBundleFile := filepath.Join(certDir, "cert")
			serverKeyFile := filepath.Join(certDir, "key")
			err := getOrCreateTestCertFiles(serverCertBundleFile, serverKeyFile, test.Cert)
			if err != nil {
				t.Fatalf("failed to create server cert: %v", err)
			}
			ca, err := caCertFromBundle(serverCertBundleFile)
			if err != nil {
				t.Fatalf("failed to extract ca cert from server cert bundle: %v", err)
			}
			caCerts := []*x509.Certificate{ca}

			// create SNI certs
			var namedCertKeys []cliflag.NamedCertKey
			serverSig, err := certFileSignature(serverCertBundleFile, serverKeyFile)
			if err != nil {
				t.Fatalf("failed to get server cert signature: %v", err)
			}
			// signature -> index into SNICerts (-1 for the default serving cert)
			signatures := map[string]int{
				serverSig: -1,
			}
			for j, c := range test.SNICerts {
				sniDir := filepath.Join(certDir, specToName(c.TestCertSpec))
				certBundleFile := filepath.Join(sniDir, "cert")
				keyFile := filepath.Join(sniDir, "key")
				err := getOrCreateTestCertFiles(certBundleFile, keyFile, c.TestCertSpec)
				if err != nil {
					t.Fatalf("failed to create SNI cert %d: %v", j, err)
				}

				namedCertKeys = append(namedCertKeys, cliflag.NamedCertKey{
					KeyFile:  keyFile,
					CertFile: certBundleFile,
					Names:    c.explicitNames,
				})

				ca, err := caCertFromBundle(certBundleFile)
				if err != nil {
					t.Fatalf("failed to extract ca cert from SNI cert %d: %v", j, err)
				}
				caCerts = append(caCerts, ca)

				// store index in namedCertKeys with the signature as the key
				sig, err := certFileSignature(certBundleFile, keyFile)
				if err != nil {
					t.Fatalf("failed get SNI cert %d signature: %v", j, err)
				}
				signatures[sig] = j
			}

			stopCh := make(chan struct{})
			defer close(stopCh)

			// launch server
			config := setUp(t)

			v := fakeVersion()
			config.Version = &v

			config.EnableIndex = true
			secureOptions := (&SecureServingOptions{
				BindAddress: net.ParseIP("127.0.0.1"),
				BindPort:    6443,
				ServerCert: GeneratableKeyCert{
					CertKey: CertKey{
						CertFile: serverCertBundleFile,
						KeyFile:  serverKeyFile,
					},
				},
				SNICertKeys: namedCertKeys,
			}).WithLoopback()
			// use a random free port
			ln, err := net.Listen("tcp", "127.0.0.1:0")
			if err != nil {
				t.Fatalf("failed to listen on 127.0.0.1:0: %v", err)
			}

			secureOptions.Listener = ln
			// get port
			secureOptions.BindPort = ln.Addr().(*net.TCPAddr).Port
			config.LoopbackClientConfig = &restclient.Config{}
			if err := secureOptions.ApplyTo(&config.SecureServing, &config.LoopbackClientConfig); err != nil {
				t.Fatalf("failed applying the SecureServingOptions: %v", err)
			}

			s, err := config.Complete(nil).New("test", server.NewEmptyDelegate())
			if err != nil {
				t.Fatalf("failed creating the server: %v", err)
			}

			// add poststart hook to know when the server is up.
			startedCh := make(chan struct{})
			s.AddPostStartHookOrDie("test-notifier", func(context server.PostStartHookContext) error {
				close(startedCh)
				return nil
			})
			preparedServer := s.PrepareRun()
			// t.Fatal must only be called from the goroutine running the test, so
			// the result of Run is handed back over a channel instead. The buffer
			// lets the goroutine finish even after this subtest has returned and
			// nobody is receiving any more.
			runErrCh := make(chan error, 1)
			go func() {
				runErrCh <- preparedServer.Run(stopCh)
			}()

			// load ca certificates into a pool
			roots := x509.NewCertPool()
			for _, caCert := range caCerts {
				roots.AddCert(caCert)
			}

			// Wait for the server to come up. Fail fast instead of blocking
			// forever if Run exits before the post-start hook fires.
			select {
			case <-startedCh:
			case err := <-runErrCh:
				t.Fatalf("server exited before it was started: %v", err)
			}

			// try to dial
			addr := fmt.Sprintf("localhost:%d", secureOptions.BindPort)
			t.Logf("Dialing %s as %q", addr, test.ServerName)
			conn, err := tls.Dial("tcp", addr, &tls.Config{
				RootCAs:    roots,
				ServerName: test.ServerName, // used for SNI in the client HELLO packet
			})
			if err != nil {
				t.Fatalf("failed to connect: %v", err)
			}
			defer conn.Close()

			// check returned server certificate
			sig := x509CertSignature(conn.ConnectionState().PeerCertificates[0])
			gotCertIndex, found := signatures[sig]
			if !found {
				t.Errorf("unknown signature returned from server: %s", sig)
			} else if gotCertIndex != test.ExpectedCertIndex {
				t.Errorf("expected cert index %d, got cert index %d", test.ExpectedCertIndex, gotCertIndex)
			}

			// check that the loopback client can connect
			host := "127.0.0.1"
			if len(test.LoopbackClientBindAddressOverride) != 0 {
				host = test.LoopbackClientBindAddressOverride
			}
			s.LoopbackClientConfig.Host = net.JoinHostPort(host, strconv.Itoa(secureOptions.BindPort))
			client, err := discovery.NewDiscoveryClientForConfig(s.LoopbackClientConfig)
			if err != nil {
				t.Fatalf("failed to create loopback client: %v", err)
			}
			got, err := client.ServerVersion()
			// ExpectLoopbackClientError is judged on the actual loopback request.
			// Previously it inspected a stale err left over from tls.Dial.
			if test.ExpectLoopbackClientError {
				if err == nil {
					t.Fatalf("expected loopback client to fail, but it got version info %v", got)
				}
				return
			}
			if err != nil {
				t.Fatalf("failed to connect with loopback client: %v", err)
			}
			if expected := &v; !reflect.DeepEqual(got, expected) {
				t.Errorf("loopback client didn't get correct version info: expected=%v got=%v", expected, got)
			}
		})
	}
}
382
383func parseIPList(ips []string) []net.IP {
384	var netIPs []net.IP
385	for _, ip := range ips {
386		netIPs = append(netIPs, net.ParseIP(ip))
387	}
388	return netIPs
389}
390
// testCertFilesMu serializes getOrCreateTestCertFiles. TestServerRunWithSNI
// runs its subtests in parallel, and several of them use the same cert paths
// (e.g. testdata/localhost__127.0.0.1). Without the lock, two subtests could
// both find the files missing, generate different key pairs and interleave
// their writes, leaving a certificate on disk that does not match its key.
var testCertFilesMu sync.Mutex

// getOrCreateTestCertFiles ensures that certFileName and keyFileName exist.
// If both files are already present they are reused unchanged. Otherwise a
// new self-signed certificate and key are generated from spec and written
// (creating parent directories as needed). The cert is written with mode 0644
// and the private key with mode 0600.
func getOrCreateTestCertFiles(certFileName, keyFileName string, spec TestCertSpec) (err error) {
	testCertFilesMu.Lock()
	defer testCertFilesMu.Unlock()

	_, certStatErr := os.Stat(certFileName)
	_, keyStatErr := os.Stat(keyFileName)
	if certStatErr == nil && keyStatErr == nil {
		return nil
	}

	certPem, keyPem, err := generateSelfSignedCertKey(spec.host, parseIPList(spec.ips), spec.names)
	if err != nil {
		return err
	}

	if err = os.MkdirAll(filepath.Dir(certFileName), os.FileMode(0755)); err != nil {
		return err
	}
	if err = ioutil.WriteFile(certFileName, certPem, os.FileMode(0644)); err != nil {
		return err
	}

	if err = os.MkdirAll(filepath.Dir(keyFileName), os.FileMode(0755)); err != nil {
		return err
	}
	// The private key should not be group/world readable or executable.
	return ioutil.WriteFile(keyFileName, keyPem, os.FileMode(0600))
}
417
// caCertFromBundle reads the PEM file at bundlePath and parses its LAST PEM
// block as an X.509 certificate. For bundles produced by
// generateSelfSignedCertKey this is the single self-signed cert, which serves
// as its own CA. An error is returned if the file cannot be read, contains no
// PEM block, or the last block does not parse as a certificate.
func caCertFromBundle(bundlePath string) (*x509.Certificate, error) {
	data, err := ioutil.ReadFile(bundlePath)
	if err != nil {
		return nil, err
	}

	var last *pem.Block
	for block, remainder := pem.Decode(data); block != nil; block, remainder = pem.Decode(remainder) {
		last = block
	}
	if last == nil {
		return nil, fmt.Errorf("no certificate found in %q", bundlePath)
	}
	return x509.ParseCertificate(last.Bytes)
}
439
// x509CertSignature returns the certificate's raw signature bytes encoded as
// standard, padded base64. The tests use it as a fingerprint to tell which
// certificate the server presented.
func x509CertSignature(cert *x509.Certificate) string {
	encoded := make([]byte, base64.StdEncoding.EncodedLen(len(cert.Signature)))
	base64.StdEncoding.Encode(encoded, cert.Signature)
	return string(encoded)
}
443
// certFileSignature loads the certificate/key pair stored in certFile and
// keyFile with tls.LoadX509KeyPair and returns the fingerprint that
// certSignature computes for it. Any load error is returned unchanged.
func certFileSignature(certFile, keyFile string) (string, error) {
	pair, loadErr := tls.LoadX509KeyPair(certFile, keyFile)
	if loadErr == nil {
		return certSignature(pair)
	}
	return "", loadErr
}
451
// certSignature returns the base64-encoded signature of the leaf (first)
// certificate in cert's chain, as produced by x509CertSignature.
//
// It returns an error instead of panicking when the chain is empty, or when
// its first DER entry is empty: x509.ParseCertificates returns an empty slice
// with a nil error for empty input.
func certSignature(cert tls.Certificate) (string, error) {
	if len(cert.Certificate) == 0 {
		return "", fmt.Errorf("tls certificate has an empty certificate chain")
	}
	x509Certs, err := x509.ParseCertificates(cert.Certificate[0])
	if err != nil {
		return "", err
	}
	if len(x509Certs) == 0 {
		return "", fmt.Errorf("leaf certificate contains no DER data")
	}
	return x509CertSignature(x509Certs[0]), nil
}
459
460func fakeVersion() version.Info {
461	return version.Info{
462		Major:        "42",
463		Minor:        "42",
464		GitVersion:   "42",
465		GitCommit:    "34973274ccef6ab4dfaaf86599792fa9c3fe4689",
466		GitTreeState: "Dirty",
467	}
468}
469
470// generateSelfSignedCertKey creates a self-signed certificate and key for the given host.
471// Host may be an IP or a DNS name
472// You may also specify additional subject alt names (either ip or dns names) for the certificate
473func generateSelfSignedCertKey(host string, alternateIPs []net.IP, alternateDNS []string) ([]byte, []byte, error) {
474	priv, err := rsa.GenerateKey(cryptorand.Reader, 2048)
475	if err != nil {
476		return nil, nil, err
477	}
478
479	template := x509.Certificate{
480		SerialNumber: big.NewInt(1),
481		Subject: pkix.Name{
482			CommonName: fmt.Sprintf("%s@%d", host, time.Now().Unix()),
483		},
484		NotBefore: time.Unix(0, 0),
485		NotAfter:  time.Now().Add(time.Hour * 24 * 365 * 100),
486
487		KeyUsage:              x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
488		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
489		BasicConstraintsValid: true,
490		IsCA:                  true,
491	}
492
493	if ip := net.ParseIP(host); ip != nil {
494		template.IPAddresses = append(template.IPAddresses, ip)
495	} else {
496		template.DNSNames = append(template.DNSNames, host)
497	}
498
499	template.IPAddresses = append(template.IPAddresses, alternateIPs...)
500	template.DNSNames = append(template.DNSNames, alternateDNS...)
501
502	derBytes, err := x509.CreateCertificate(cryptorand.Reader, &template, &template, &priv.PublicKey, priv)
503	if err != nil {
504		return nil, nil, err
505	}
506
507	// Generate cert
508	certBuffer := bytes.Buffer{}
509	if err := pem.Encode(&certBuffer, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil {
510		return nil, nil, err
511	}
512
513	// Generate key
514	keyBuffer := bytes.Buffer{}
515	if err := pem.Encode(&keyBuffer, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)}); err != nil {
516		return nil, nil, err
517	}
518
519	return certBuffer.Bytes(), keyBuffer.Bytes(), nil
520}
521