1/*
2Copyright 2019 The Kubernetes Authors.
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8    http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15*/
16
17package dynamiccertificates
18
19import (
20	"bytes"
21	"crypto/rand"
22	"crypto/rsa"
23	"crypto/tls"
24	"crypto/x509"
25	"crypto/x509/pkix"
26	"encoding/base64"
27	"encoding/pem"
28	"fmt"
29	"math/big"
30	"net"
31	"testing"
32	"time"
33
34	"github.com/stretchr/testify/assert"
35)
36
// testCertSpec describes a self-signed certificate to generate for a test case.
type testCertSpec struct {
	// host is the primary subject; generateSelfSignedCertKey turns it into an
	// IP SAN when it parses as an IP address and a DNS SAN otherwise.
	host string
	// names are additional DNS names placed in the certificate.
	names []string
	// ips are additional IP addresses (textual form) placed in the certificate.
	ips []string
}

// namedtestCertSpec pairs a certificate spec with the explicit SNI names the
// certificate is registered under.
type namedtestCertSpec struct {
	testCertSpec
	// explicitNames mirror the names given explicitly via --tls-sni-cert-key;
	// may be empty.
	explicitNames []string
}
46
47func TestBuiltNamedCertificates(t *testing.T) {
48	tests := []struct {
49		certs         []namedtestCertSpec
50		explicitNames []string
51		expected      map[string]int // name to certs[*] index
52		errorString   string
53	}{
54		{
55			// empty certs
56			expected: map[string]int{},
57		},
58		{
59			// only one cert
60			certs: []namedtestCertSpec{
61				{
62					testCertSpec: testCertSpec{
63						host: "test.com",
64					},
65				},
66			},
67			expected: map[string]int{
68				"test.com": 0,
69			},
70		},
71		{
72			// ip as cns are ignored
73			certs: []namedtestCertSpec{
74				{
75					testCertSpec: testCertSpec{
76						host:  "1.2.3.4",
77						names: []string{"test.com"},
78					},
79				},
80			},
81			expected: map[string]int{
82				"test.com": 0,
83			},
84		},
85		{
86			// ips are ignored
87			certs: []namedtestCertSpec{
88				{
89					testCertSpec: testCertSpec{
90						host: "test.com",
91						ips:  []string{"1.2.3.4"},
92					},
93				},
94			},
95			expected: map[string]int{
96				"test.com": 0,
97			},
98		},
99		{
100			// two certs with the same name
101			certs: []namedtestCertSpec{
102				{
103					testCertSpec: testCertSpec{
104						host: "test.com",
105					},
106				},
107				{
108					testCertSpec: testCertSpec{
109						host: "test.com",
110					},
111				},
112			},
113			expected: map[string]int{
114				"test.com": 0,
115			},
116		},
117		{
118			// two certs with different names
119			certs: []namedtestCertSpec{
120				{
121					testCertSpec: testCertSpec{
122						host: "test2.com",
123					},
124				},
125				{
126					testCertSpec: testCertSpec{
127						host: "test1.com",
128					},
129				},
130			},
131			expected: map[string]int{
132				"test1.com": 1,
133				"test2.com": 0,
134			},
135		},
136		{
137			// two certs with the same name, explicit trumps
138			certs: []namedtestCertSpec{
139				{
140					testCertSpec: testCertSpec{
141						host: "test.com",
142					},
143				},
144				{
145					testCertSpec: testCertSpec{
146						host: "test.com",
147					},
148					explicitNames: []string{"test.com"},
149				},
150			},
151			expected: map[string]int{
152				"test.com": 1,
153			},
154		},
155		{
156			// certs with partial overlap; ips are ignored
157			certs: []namedtestCertSpec{
158				{
159					testCertSpec: testCertSpec{
160						host:  "a",
161						names: []string{"a.test.com", "test.com"},
162					},
163				},
164				{
165					testCertSpec: testCertSpec{
166						host:  "b",
167						names: []string{"b.test.com", "test.com"},
168					},
169				},
170			},
171			expected: map[string]int{
172				"a": 0, "b": 1,
173				"a.test.com": 0, "b.test.com": 1,
174				"test.com": 0,
175			},
176		},
177		{
178			// wildcards
179			certs: []namedtestCertSpec{
180				{
181					testCertSpec: testCertSpec{
182						host:  "a",
183						names: []string{"a.test.com", "test.com"},
184					},
185					explicitNames: []string{"*.test.com", "test.com"},
186				},
187				{
188					testCertSpec: testCertSpec{
189						host:  "b",
190						names: []string{"b.test.com", "test.com"},
191					},
192					explicitNames: []string{"dev.test.com", "test.com"},
193				}},
194			expected: map[string]int{
195				"test.com":     0,
196				"*.test.com":   0,
197				"dev.test.com": 1,
198			},
199		},
200	}
201
202NextTest:
203	for i, test := range tests {
204		var sniCerts []SNICertKeyContentProvider
205		bySignature := map[string]int{} // index in test.certs by cert signature
206		for j, c := range test.certs {
207			certProvider, err := createTestTLSCerts(c.testCertSpec, c.explicitNames)
208			if err != nil {
209				t.Errorf("%d - failed to create cert %d: %v", i, j, err)
210				continue NextTest
211			}
212
213			sniCerts = append(sniCerts, certProvider)
214
215			sig, err := certSignature(certProvider)
216			if err != nil {
217				t.Errorf("%d - failed to get signature for %d: %v", i, j, err)
218				continue NextTest
219			}
220			bySignature[sig] = j
221		}
222
223		c := DynamicServingCertificateController{sniCerts: sniCerts}
224		content, err := c.newTLSContent()
225		assert.NoError(t, err)
226
227		certMap, err := c.BuildNamedCertificates(content.sniCerts)
228		if err == nil && len(test.errorString) != 0 {
229			t.Errorf("%d - expected no error, got: %v", i, err)
230		} else if err != nil && err.Error() != test.errorString {
231			t.Errorf("%d - expected error %q, got: %v", i, test.errorString, err)
232		} else {
233			got := map[string]int{}
234			for name, cert := range certMap {
235				x509Certs, err := x509.ParseCertificates(cert.Certificate[0])
236				assert.NoError(t, err, "%d - invalid certificate for %q", i, name)
237				assert.True(t, len(x509Certs) > 0, "%d - expected at least one x509 cert in tls cert for %q", i, name)
238				got[name] = bySignature[x509CertSignature(x509Certs[0])]
239			}
240
241			assert.EqualValues(t, test.expected, got, "%d - wrong certificate map", i)
242		}
243	}
244}
245
// parseIPList converts every textual address in ips with net.ParseIP while
// keeping the input order. An entry that does not parse is kept as a nil
// element. A nil or empty input yields a nil slice.
func parseIPList(ips []string) []net.IP {
	var parsed []net.IP
	for idx := range ips {
		parsed = append(parsed, net.ParseIP(ips[idx]))
	}
	return parsed
}
253
254func createTestTLSCerts(spec testCertSpec, names []string) (certProvider SNICertKeyContentProvider, err error) {
255	certPem, keyPem, err := generateSelfSignedCertKey(spec.host, parseIPList(spec.ips), spec.names)
256	if err != nil {
257		return nil, err
258	}
259
260	return NewStaticSNICertKeyContent("test-cert", certPem, keyPem, names...)
261}
262
// x509CertSignature returns cert's raw signature bytes encoded as standard,
// padded base64. The result is used as a string key to identify a certificate.
func x509CertSignature(cert *x509.Certificate) string {
	enc := base64.StdEncoding
	out := make([]byte, enc.EncodedLen(len(cert.Signature)))
	enc.Encode(out, cert.Signature)
	return string(out)
}
266
267func certSignature(certProvider CertKeyContentProvider) (string, error) {
268	currentCert, currentKey := certProvider.CurrentCertKeyContent()
269
270	tlsCert, err := tls.X509KeyPair(currentCert, currentKey)
271	if err != nil {
272		return "", err
273	}
274
275	x509Certs, err := x509.ParseCertificates(tlsCert.Certificate[0])
276	if err != nil {
277		return "", err
278	}
279	return x509CertSignature(x509Certs[0]), nil
280}
281
// generateSelfSignedCertKey returns a PEM-encoded self-signed certificate and
// its PEM-encoded PKCS#1 private key for a freshly generated 2048-bit RSA key.
//
// If host parses as an IP address, it becomes the first IP SAN; otherwise it
// becomes the first DNS SAN. alternateIPs and alternateDNS are appended after
// it. The CommonName is "<host>@<unix seconds>" and the serial number is fixed
// at 1. The certificate is a CA (server-auth EKU), valid from the Unix epoch
// until roughly 100 years from now.
func generateSelfSignedCertKey(host string, alternateIPs []net.IP, alternateDNS []string) ([]byte, []byte, error) {
	key, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		return nil, nil, err
	}

	// Work out the subject alternative names first; host always leads its list.
	var ipSANs []net.IP
	var dnsSANs []string
	if ip := net.ParseIP(host); ip != nil {
		ipSANs = append(ipSANs, ip)
	} else {
		dnsSANs = append(dnsSANs, host)
	}
	ipSANs = append(ipSANs, alternateIPs...)
	dnsSANs = append(dnsSANs, alternateDNS...)

	template := &x509.Certificate{
		SerialNumber: big.NewInt(1),
		Subject: pkix.Name{
			CommonName: fmt.Sprintf("%s@%d", host, time.Now().Unix()),
		},
		NotBefore: time.Unix(0, 0),
		NotAfter:  time.Now().Add(100 * 365 * 24 * time.Hour),

		IPAddresses: ipSANs,
		DNSNames:    dnsSANs,

		KeyUsage:              x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
		BasicConstraintsValid: true,
		IsCA:                  true,
	}

	// Self-signed: the template is both subject and issuer.
	der, err := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key)
	if err != nil {
		return nil, nil, err
	}

	toPEM := func(blockType string, body []byte) ([]byte, error) {
		var buf bytes.Buffer
		if err := pem.Encode(&buf, &pem.Block{Type: blockType, Bytes: body}); err != nil {
			return nil, err
		}
		return buf.Bytes(), nil
	}

	certPEM, err := toPEM("CERTIFICATE", der)
	if err != nil {
		return nil, nil, err
	}
	keyPEM, err := toPEM("RSA PRIVATE KEY", x509.MarshalPKCS1PrivateKey(key))
	if err != nil {
		return nil, nil, err
	}
	return certPEM, keyPEM, nil
}
333