1package pki
2
3import (
4	"context"
5	"crypto/ecdsa"
6	"crypto/elliptic"
7	"crypto/rand"
8	"crypto/rsa"
9	"crypto/x509"
10	"crypto/x509/pkix"
11	"encoding/json"
12	"encoding/pem"
13	"math/big"
14	mathrand "math/rand"
15	"strings"
16	"testing"
17	"time"
18
19	"github.com/go-test/deep"
20	"github.com/hashicorp/vault/api"
21	vaulthttp "github.com/hashicorp/vault/http"
22	"github.com/hashicorp/vault/sdk/helper/certutil"
23	"github.com/hashicorp/vault/sdk/logical"
24	"github.com/hashicorp/vault/vault"
25)
26
27func TestBackend_CA_Steps(t *testing.T) {
28	var b *backend
29
30	factory := func(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
31		be, err := Factory(ctx, conf)
32		if err == nil {
33			b = be.(*backend)
34		}
35		return be, err
36	}
37
38	coreConfig := &vault.CoreConfig{
39		LogicalBackends: map[string]logical.Factory{
40			"pki": factory,
41		},
42	}
43	cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{
44		HandlerFunc: vaulthttp.Handler,
45	})
46	cluster.Start()
47	defer cluster.Cleanup()
48
49	client := cluster.Cores[0].Client
50
51	// Set RSA/EC CA certificates
52	var rsaCAKey, rsaCACert, ecCAKey, ecCACert string
53	{
54		cak, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
55		if err != nil {
56			panic(err)
57		}
58		marshaledKey, err := x509.MarshalECPrivateKey(cak)
59		if err != nil {
60			panic(err)
61		}
62		keyPEMBlock := &pem.Block{
63			Type:  "EC PRIVATE KEY",
64			Bytes: marshaledKey,
65		}
66		ecCAKey = strings.TrimSpace(string(pem.EncodeToMemory(keyPEMBlock)))
67		if err != nil {
68			panic(err)
69		}
70		subjKeyID, err := certutil.GetSubjKeyID(cak)
71		if err != nil {
72			panic(err)
73		}
74		caCertTemplate := &x509.Certificate{
75			Subject: pkix.Name{
76				CommonName: "root.localhost",
77			},
78			SubjectKeyId:          subjKeyID,
79			DNSNames:              []string{"root.localhost"},
80			KeyUsage:              x509.KeyUsage(x509.KeyUsageCertSign | x509.KeyUsageCRLSign),
81			SerialNumber:          big.NewInt(mathrand.Int63()),
82			NotAfter:              time.Now().Add(262980 * time.Hour),
83			BasicConstraintsValid: true,
84			IsCA:                  true,
85		}
86		caBytes, err := x509.CreateCertificate(rand.Reader, caCertTemplate, caCertTemplate, cak.Public(), cak)
87		if err != nil {
88			panic(err)
89		}
90		caCertPEMBlock := &pem.Block{
91			Type:  "CERTIFICATE",
92			Bytes: caBytes,
93		}
94		ecCACert = strings.TrimSpace(string(pem.EncodeToMemory(caCertPEMBlock)))
95
96		rak, err := rsa.GenerateKey(rand.Reader, 2048)
97		if err != nil {
98			panic(err)
99		}
100		marshaledKey = x509.MarshalPKCS1PrivateKey(rak)
101		keyPEMBlock = &pem.Block{
102			Type:  "RSA PRIVATE KEY",
103			Bytes: marshaledKey,
104		}
105		rsaCAKey = strings.TrimSpace(string(pem.EncodeToMemory(keyPEMBlock)))
106		if err != nil {
107			panic(err)
108		}
109		_, err = certutil.GetSubjKeyID(rak)
110		if err != nil {
111			panic(err)
112		}
113		caBytes, err = x509.CreateCertificate(rand.Reader, caCertTemplate, caCertTemplate, rak.Public(), rak)
114		if err != nil {
115			panic(err)
116		}
117		caCertPEMBlock = &pem.Block{
118			Type:  "CERTIFICATE",
119			Bytes: caBytes,
120		}
121		rsaCACert = strings.TrimSpace(string(pem.EncodeToMemory(caCertPEMBlock)))
122	}
123
124	// Setup backends
125	var rsaRoot, rsaInt, ecRoot, ecInt *backend
126	{
127		if err := client.Sys().Mount("rsaroot", &api.MountInput{
128			Type: "pki",
129			Config: api.MountConfigInput{
130				DefaultLeaseTTL: "16h",
131				MaxLeaseTTL:     "60h",
132			},
133		}); err != nil {
134			t.Fatal(err)
135		}
136		rsaRoot = b
137
138		if err := client.Sys().Mount("rsaint", &api.MountInput{
139			Type: "pki",
140			Config: api.MountConfigInput{
141				DefaultLeaseTTL: "16h",
142				MaxLeaseTTL:     "60h",
143			},
144		}); err != nil {
145			t.Fatal(err)
146		}
147		rsaInt = b
148
149		if err := client.Sys().Mount("ecroot", &api.MountInput{
150			Type: "pki",
151			Config: api.MountConfigInput{
152				DefaultLeaseTTL: "16h",
153				MaxLeaseTTL:     "60h",
154			},
155		}); err != nil {
156			t.Fatal(err)
157		}
158		ecRoot = b
159
160		if err := client.Sys().Mount("ecint", &api.MountInput{
161			Type: "pki",
162			Config: api.MountConfigInput{
163				DefaultLeaseTTL: "16h",
164				MaxLeaseTTL:     "60h",
165			},
166		}); err != nil {
167			t.Fatal(err)
168		}
169		ecInt = b
170	}
171
172	t.Run("teststeps", func(t *testing.T) {
173		t.Run("rsa", func(t *testing.T) {
174			t.Parallel()
175			subClient, err := client.Clone()
176			if err != nil {
177				t.Fatal(err)
178			}
179			subClient.SetToken(client.Token())
180			runSteps(t, rsaRoot, rsaInt, subClient, "rsaroot/", "rsaint/", rsaCACert, rsaCAKey)
181		})
182		t.Run("ec", func(t *testing.T) {
183			t.Parallel()
184			subClient, err := client.Clone()
185			if err != nil {
186				t.Fatal(err)
187			}
188			subClient.SetToken(client.Token())
189			runSteps(t, ecRoot, ecInt, subClient, "ecroot/", "ecint/", ecCACert, ecCAKey)
190		})
191	})
192}
193
194func runSteps(t *testing.T, rootB, intB *backend, client *api.Client, rootName, intName, caCert, caKey string) {
195	//  Load CA cert/key in and ensure we can fetch it back in various formats,
196	//  unauthenticated
197	{
198		// Attempt import but only provide one the cert
199		{
200			_, err := client.Logical().Write(rootName+"config/ca", map[string]interface{}{
201				"pem_bundle": caCert,
202			})
203			if err == nil {
204				t.Fatal("expected error")
205			}
206		}
207
208		// Same but with only the key
209		{
210			_, err := client.Logical().Write(rootName+"config/ca", map[string]interface{}{
211				"pem_bundle": caKey,
212			})
213			if err == nil {
214				t.Fatal("expected error")
215			}
216		}
217
218		// Import CA bundle
219		{
220			_, err := client.Logical().Write(rootName+"config/ca", map[string]interface{}{
221				"pem_bundle": strings.Join([]string{caKey, caCert}, "\n"),
222			})
223			if err != nil {
224				t.Fatal(err)
225			}
226		}
227
228		prevToken := client.Token()
229		client.SetToken("")
230
231		// cert/ca path
232		{
233			resp, err := client.Logical().Read(rootName + "cert/ca")
234			if err != nil {
235				t.Fatal(err)
236			}
237			if resp == nil {
238				t.Fatal("nil response")
239			}
240			if diff := deep.Equal(resp.Data["certificate"].(string), caCert); diff != nil {
241				t.Fatal(diff)
242			}
243		}
244		// ca/pem path (raw string)
245		{
246			req := &logical.Request{
247				Path:      "ca/pem",
248				Operation: logical.ReadOperation,
249				Storage:   rootB.storage,
250			}
251			resp, err := rootB.HandleRequest(context.Background(), req)
252			if err != nil {
253				t.Fatal(err)
254			}
255			if resp == nil {
256				t.Fatal("nil response")
257			}
258			if diff := deep.Equal(resp.Data["http_raw_body"].([]byte), []byte(caCert)); diff != nil {
259				t.Fatal(diff)
260			}
261			if resp.Data["http_content_type"].(string) != "application/pkix-cert" {
262				t.Fatal("wrong content type")
263			}
264		}
265
266		// ca (raw DER bytes)
267		{
268			req := &logical.Request{
269				Path:      "ca",
270				Operation: logical.ReadOperation,
271				Storage:   rootB.storage,
272			}
273			resp, err := rootB.HandleRequest(context.Background(), req)
274			if err != nil {
275				t.Fatal(err)
276			}
277			if resp == nil {
278				t.Fatal("nil response")
279			}
280			rawBytes := resp.Data["http_raw_body"].([]byte)
281			pemBytes := strings.TrimSpace(string(pem.EncodeToMemory(&pem.Block{
282				Type:  "CERTIFICATE",
283				Bytes: rawBytes,
284			})))
285			if diff := deep.Equal(pemBytes, caCert); diff != nil {
286				t.Fatal(diff)
287			}
288			if resp.Data["http_content_type"].(string) != "application/pkix-cert" {
289				t.Fatal("wrong content type")
290			}
291		}
292
293		client.SetToken(prevToken)
294	}
295
296	// Configure an expiry on the CRL and verify what comes back
297	{
298		// Set CRL config
299		{
300			_, err := client.Logical().Write(rootName+"config/crl", map[string]interface{}{
301				"expiry": "16h",
302			})
303			if err != nil {
304				t.Fatal(err)
305			}
306		}
307
308		// Verify it
309		{
310			resp, err := client.Logical().Read(rootName + "config/crl")
311			if err != nil {
312				t.Fatal(err)
313			}
314			if resp == nil {
315				t.Fatal("nil response")
316			}
317			if resp.Data["expiry"].(string) != "16h" {
318				t.Fatal("expected a 16 hour expiry")
319			}
320		}
321	}
322
323	// Test generating a root, an intermediate, signing it, setting signed, and
324	// revoking it
325
326	// We'll need this later
327	var intSerialNumber string
328	{
329		// First, delete the existing CA info
330		{
331			_, err := client.Logical().Delete(rootName + "root")
332			if err != nil {
333				t.Fatal(err)
334			}
335		}
336
337		var rootPEM, rootKey, rootPEMBundle string
338		// Test exported root generation
339		{
340			resp, err := client.Logical().Write(rootName+"root/generate/exported", map[string]interface{}{
341				"common_name": "Root Cert",
342				"ttl":         "180h",
343			})
344			if err != nil {
345				t.Fatal(err)
346			}
347			if resp == nil {
348				t.Fatal("nil response")
349			}
350			rootPEM = resp.Data["certificate"].(string)
351			rootKey = resp.Data["private_key"].(string)
352			rootPEMBundle = strings.Join([]string{rootPEM, rootKey}, "\n")
353			// This is really here to keep the use checker happy
354			if rootPEMBundle == "" {
355				t.Fatal("bad root pem bundle")
356			}
357		}
358
359		var intPEM, intCSR, intKey string
360		// Test exported intermediate CSR generation
361		{
362			resp, err := client.Logical().Write(intName+"intermediate/generate/exported", map[string]interface{}{
363				"common_name": "intermediate.cert.com",
364				"ttl":         "180h",
365			})
366			if err != nil {
367				t.Fatal(err)
368			}
369			if resp == nil {
370				t.Fatal("nil response")
371			}
372			intCSR = resp.Data["csr"].(string)
373			intKey = resp.Data["private_key"].(string)
374			// This is really here to keep the use checker happy
375			if intCSR == "" || intKey == "" {
376				t.Fatal("int csr or key empty")
377			}
378		}
379
380		// Test signing
381		{
382			resp, err := client.Logical().Write(rootName+"root/sign-intermediate", map[string]interface{}{
383				"common_name": "intermediate.cert.com",
384				"ttl":         "10s",
385				"csr":         intCSR,
386			})
387			if err != nil {
388				t.Fatal(err)
389			}
390			if resp == nil {
391				t.Fatal("nil response")
392			}
393			intPEM = resp.Data["certificate"].(string)
394			intSerialNumber = resp.Data["serial_number"].(string)
395		}
396
397		// Test setting signed
398		{
399			resp, err := client.Logical().Write(intName+"intermediate/set-signed", map[string]interface{}{
400				"certificate": intPEM,
401			})
402			if err != nil {
403				t.Fatal(err)
404			}
405			if resp != nil {
406				t.Fatal("expected nil response")
407			}
408		}
409
410		// Verify we can find it via the root
411		{
412			resp, err := client.Logical().Read(rootName + "cert/" + intSerialNumber)
413			if err != nil {
414				t.Fatal(err)
415			}
416			if resp == nil {
417				t.Fatal("nil response")
418			}
419			if resp.Data["revocation_time"].(json.Number).String() != "0" {
420				t.Fatal("expected a zero revocation time")
421			}
422		}
423
424		// Revoke the intermediate
425		{
426			resp, err := client.Logical().Write(rootName+"revoke", map[string]interface{}{
427				"serial_number": intSerialNumber,
428			})
429			if err != nil {
430				t.Fatal(err)
431			}
432			if resp == nil {
433				t.Fatal("nil response")
434			}
435		}
436	}
437
438	verifyRevocation := func(t *testing.T, serial string, shouldFind bool) {
439		t.Helper()
440		// Verify it is now revoked
441		{
442			resp, err := client.Logical().Read(rootName + "cert/" + intSerialNumber)
443			if err != nil {
444				t.Fatal(err)
445			}
446			switch shouldFind {
447			case true:
448				if resp == nil {
449					t.Fatal("nil response")
450				}
451				if resp.Data["revocation_time"].(json.Number).String() == "0" {
452					t.Fatal("expected a non-zero revocation time")
453				}
454			default:
455				if resp != nil {
456					t.Fatalf("expected nil response, got %#v", *resp)
457				}
458			}
459		}
460
461		// Fetch the CRL and make sure it shows up
462		{
463			req := &logical.Request{
464				Path:      "crl",
465				Operation: logical.ReadOperation,
466				Storage:   rootB.storage,
467			}
468			resp, err := rootB.HandleRequest(context.Background(), req)
469			if err != nil {
470				t.Fatal(err)
471			}
472			if resp == nil {
473				t.Fatal("nil response")
474			}
475			crlBytes := resp.Data["http_raw_body"].([]byte)
476			certList, err := x509.ParseCRL(crlBytes)
477			if err != nil {
478				t.Fatal(err)
479			}
480			switch shouldFind {
481			case true:
482				revokedList := certList.TBSCertList.RevokedCertificates
483				if len(revokedList) != 1 {
484					t.Fatalf("bad length of revoked list: %d", len(revokedList))
485				}
486				revokedString := certutil.GetHexFormatted(revokedList[0].SerialNumber.Bytes(), ":")
487				if revokedString != intSerialNumber {
488					t.Fatalf("bad revoked serial: %s", revokedString)
489				}
490			default:
491				revokedList := certList.TBSCertList.RevokedCertificates
492				if len(revokedList) != 0 {
493					t.Fatalf("bad length of revoked list: %d", len(revokedList))
494				}
495			}
496		}
497	}
498
499	// Validate current state of revoked certificates
500	verifyRevocation(t, intSerialNumber, true)
501
502	// Give time for the safety buffer to pass before tidying
503	time.Sleep(10 * time.Second)
504
505	// Test tidying
506	{
507		// Run with a high safety buffer, nothing should happen
508		{
509			resp, err := client.Logical().Write(rootName+"tidy", map[string]interface{}{
510				"safety_buffer":      "3h",
511				"tidy_cert_store":    true,
512				"tidy_revoked_certs": true,
513			})
514			if err != nil {
515				t.Fatal(err)
516			}
517			if resp == nil {
518				t.Fatal("expected warnings")
519			}
520
521			// Wait a few seconds as it runs in a goroutine
522			time.Sleep(5 * time.Second)
523
524			// Check to make sure we still find the cert and see it on the CRL
525			verifyRevocation(t, intSerialNumber, true)
526		}
527
528		// Run with both values set false, nothing should happen
529		{
530			resp, err := client.Logical().Write(rootName+"tidy", map[string]interface{}{
531				"safety_buffer":      "1s",
532				"tidy_cert_store":    false,
533				"tidy_revoked_certs": false,
534			})
535			if err != nil {
536				t.Fatal(err)
537			}
538			if resp == nil {
539				t.Fatal("expected warnings")
540			}
541
542			// Wait a few seconds as it runs in a goroutine
543			time.Sleep(5 * time.Second)
544
545			// Check to make sure we still find the cert and see it on the CRL
546			verifyRevocation(t, intSerialNumber, true)
547		}
548
549		// Run with a short safety buffer and both set to true, both should be cleared
550		{
551			resp, err := client.Logical().Write(rootName+"tidy", map[string]interface{}{
552				"safety_buffer":      "1s",
553				"tidy_cert_store":    true,
554				"tidy_revoked_certs": true,
555			})
556			if err != nil {
557				t.Fatal(err)
558			}
559			if resp == nil {
560				t.Fatal("expected warnings")
561			}
562
563			// Wait a few seconds as it runs in a goroutine
564			time.Sleep(5 * time.Second)
565
566			// Check to make sure we still find the cert and see it on the CRL
567			verifyRevocation(t, intSerialNumber, false)
568		}
569	}
570}
571