1package pki 2 3import ( 4 "context" 5 "crypto/ecdsa" 6 "crypto/elliptic" 7 "crypto/rand" 8 "crypto/rsa" 9 "crypto/x509" 10 "crypto/x509/pkix" 11 "encoding/json" 12 "encoding/pem" 13 "math/big" 14 mathrand "math/rand" 15 "strings" 16 "testing" 17 "time" 18 19 "github.com/go-test/deep" 20 "github.com/hashicorp/vault/api" 21 vaulthttp "github.com/hashicorp/vault/http" 22 "github.com/hashicorp/vault/sdk/helper/certutil" 23 "github.com/hashicorp/vault/sdk/logical" 24 "github.com/hashicorp/vault/vault" 25) 26 27func TestBackend_CA_Steps(t *testing.T) { 28 var b *backend 29 30 factory := func(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) { 31 be, err := Factory(ctx, conf) 32 if err == nil { 33 b = be.(*backend) 34 } 35 return be, err 36 } 37 38 coreConfig := &vault.CoreConfig{ 39 LogicalBackends: map[string]logical.Factory{ 40 "pki": factory, 41 }, 42 } 43 cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{ 44 HandlerFunc: vaulthttp.Handler, 45 }) 46 cluster.Start() 47 defer cluster.Cleanup() 48 49 client := cluster.Cores[0].Client 50 51 // Set RSA/EC CA certificates 52 var rsaCAKey, rsaCACert, ecCAKey, ecCACert string 53 { 54 cak, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) 55 if err != nil { 56 panic(err) 57 } 58 marshaledKey, err := x509.MarshalECPrivateKey(cak) 59 if err != nil { 60 panic(err) 61 } 62 keyPEMBlock := &pem.Block{ 63 Type: "EC PRIVATE KEY", 64 Bytes: marshaledKey, 65 } 66 ecCAKey = strings.TrimSpace(string(pem.EncodeToMemory(keyPEMBlock))) 67 if err != nil { 68 panic(err) 69 } 70 subjKeyID, err := certutil.GetSubjKeyID(cak) 71 if err != nil { 72 panic(err) 73 } 74 caCertTemplate := &x509.Certificate{ 75 Subject: pkix.Name{ 76 CommonName: "root.localhost", 77 }, 78 SubjectKeyId: subjKeyID, 79 DNSNames: []string{"root.localhost"}, 80 KeyUsage: x509.KeyUsage(x509.KeyUsageCertSign | x509.KeyUsageCRLSign), 81 SerialNumber: big.NewInt(mathrand.Int63()), 82 NotAfter: time.Now().Add(262980 * time.Hour), 83 BasicConstraintsValid: true, 84 IsCA: true, 85 } 86 caBytes, err := x509.CreateCertificate(rand.Reader, caCertTemplate, caCertTemplate, cak.Public(), cak) 87 if err != nil { 88 panic(err) 89 } 90 caCertPEMBlock := &pem.Block{ 91 Type: "CERTIFICATE", 92 Bytes: caBytes, 93 } 94 ecCACert = strings.TrimSpace(string(pem.EncodeToMemory(caCertPEMBlock))) 95 96 rak, err := rsa.GenerateKey(rand.Reader, 2048) 97 if err != nil { 98 panic(err) 99 } 100 marshaledKey = x509.MarshalPKCS1PrivateKey(rak) 101 keyPEMBlock = &pem.Block{ 102 Type: "RSA PRIVATE KEY", 103 Bytes: marshaledKey, 104 } 105 rsaCAKey = strings.TrimSpace(string(pem.EncodeToMemory(keyPEMBlock))) 106 if err != nil { 107 panic(err) 108 } 109 _, err = certutil.GetSubjKeyID(rak) 110 if err != nil { 111 panic(err) 112 } 113 caBytes, err = x509.CreateCertificate(rand.Reader, caCertTemplate, caCertTemplate, rak.Public(), rak) 114 if err != nil { 115 panic(err) 116 } 117 caCertPEMBlock = &pem.Block{ 118 Type: "CERTIFICATE", 119 Bytes: caBytes, 120 } 121 rsaCACert = strings.TrimSpace(string(pem.EncodeToMemory(caCertPEMBlock))) 122 } 123 124 // Setup backends 125 var rsaRoot, rsaInt, ecRoot, ecInt *backend 126 { 127 if err := client.Sys().Mount("rsaroot", &api.MountInput{ 128 Type: "pki", 129 Config: api.MountConfigInput{ 130 DefaultLeaseTTL: "16h", 131 MaxLeaseTTL: "60h", 132 }, 133 }); err != nil { 134 t.Fatal(err) 135 } 136 rsaRoot = b 137 138 if err := client.Sys().Mount("rsaint", &api.MountInput{ 139 Type: "pki", 140 Config: api.MountConfigInput{ 141 DefaultLeaseTTL: "16h", 142 MaxLeaseTTL: "60h", 143 }, 144 }); err != nil { 145 t.Fatal(err) 146 } 147 rsaInt = b 148 149 if err := client.Sys().Mount("ecroot", &api.MountInput{ 150 Type: "pki", 151 Config: api.MountConfigInput{ 152 DefaultLeaseTTL: "16h", 153 MaxLeaseTTL: "60h", 154 }, 155 }); err != nil { 156 t.Fatal(err) 157 } 158 ecRoot = b 159 160 if err := client.Sys().Mount("ecint", &api.MountInput{ 161 Type: "pki", 162 Config: api.MountConfigInput{ 163 DefaultLeaseTTL: "16h", 164 MaxLeaseTTL: "60h", 165 }, 166 }); err != nil { 167 t.Fatal(err) 168 } 169 ecInt = b 170 } 171 172 t.Run("teststeps", func(t *testing.T) { 173 t.Run("rsa", func(t *testing.T) { 174 t.Parallel() 175 subClient, err := client.Clone() 176 if err != nil { 177 t.Fatal(err) 178 } 179 subClient.SetToken(client.Token()) 180 runSteps(t, rsaRoot, rsaInt, subClient, "rsaroot/", "rsaint/", rsaCACert, rsaCAKey) 181 }) 182 t.Run("ec", func(t *testing.T) { 183 t.Parallel() 184 subClient, err := client.Clone() 185 if err != nil { 186 t.Fatal(err) 187 } 188 subClient.SetToken(client.Token()) 189 runSteps(t, ecRoot, ecInt, subClient, "ecroot/", "ecint/", ecCACert, ecCAKey) 190 }) 191 }) 192} 193 194func runSteps(t *testing.T, rootB, intB *backend, client *api.Client, rootName, intName, caCert, caKey string) { 195 // Load CA cert/key in and ensure we can fetch it back in various formats, 196 // unauthenticated 197 { 198 // Attempt import but only provide one the cert 199 { 200 _, err := client.Logical().Write(rootName+"config/ca", map[string]interface{}{ 201 "pem_bundle": caCert, 202 }) 203 if err == nil { 204 t.Fatal("expected error") 205 } 206 } 207 208 // Same but with only the key 209 { 210 _, err := client.Logical().Write(rootName+"config/ca", map[string]interface{}{ 211 "pem_bundle": caKey, 212 }) 213 if err == nil { 214 t.Fatal("expected error") 215 } 216 } 217 218 // Import CA bundle 219 { 220 _, err := client.Logical().Write(rootName+"config/ca", map[string]interface{}{ 221 "pem_bundle": strings.Join([]string{caKey, caCert}, "\n"), 222 }) 223 if err != nil { 224 t.Fatal(err) 225 } 226 } 227 228 prevToken := client.Token() 229 client.SetToken("") 230 231 // cert/ca path 232 { 233 resp, err := client.Logical().Read(rootName + "cert/ca") 234 if err != nil { 235 t.Fatal(err) 236 } 237 if resp == nil { 238 t.Fatal("nil response") 239 } 240 if diff := deep.Equal(resp.Data["certificate"].(string), caCert); diff != nil { 241 t.Fatal(diff) 242 } 243 } 244 // ca/pem path (raw string) 245 { 246 req := &logical.Request{ 247 Path: "ca/pem", 248 Operation: logical.ReadOperation, 249 Storage: rootB.storage, 250 } 251 resp, err := rootB.HandleRequest(context.Background(), req) 252 if err != nil { 253 t.Fatal(err) 254 } 255 if resp == nil { 256 t.Fatal("nil response") 257 } 258 if diff := deep.Equal(resp.Data["http_raw_body"].([]byte), []byte(caCert)); diff != nil { 259 t.Fatal(diff) 260 } 261 if resp.Data["http_content_type"].(string) != "application/pkix-cert" { 262 t.Fatal("wrong content type") 263 } 264 } 265 266 // ca (raw DER bytes) 267 { 268 req := &logical.Request{ 269 Path: "ca", 270 Operation: logical.ReadOperation, 271 Storage: rootB.storage, 272 } 273 resp, err := rootB.HandleRequest(context.Background(), req) 274 if err != nil { 275 t.Fatal(err) 276 } 277 if resp == nil { 278 t.Fatal("nil response") 279 } 280 rawBytes := resp.Data["http_raw_body"].([]byte) 281 pemBytes := strings.TrimSpace(string(pem.EncodeToMemory(&pem.Block{ 282 Type: "CERTIFICATE", 283 Bytes: rawBytes, 284 }))) 285 if diff := deep.Equal(pemBytes, caCert); diff != nil { 286 t.Fatal(diff) 287 } 288 if resp.Data["http_content_type"].(string) != "application/pkix-cert" { 289 t.Fatal("wrong content type") 290 } 291 } 292 293 client.SetToken(prevToken) 294 } 295 296 // Configure an expiry on the CRL and verify what comes back 297 { 298 // Set CRL config 299 { 300 _, err := client.Logical().Write(rootName+"config/crl", map[string]interface{}{ 301 "expiry": "16h", 302 }) 303 if err != nil { 304 t.Fatal(err) 305 } 306 } 307 308 // Verify it 309 { 310 resp, err := client.Logical().Read(rootName + "config/crl") 311 if err != nil { 312 t.Fatal(err) 313 } 314 if resp == nil { 315 t.Fatal("nil response") 316 } 317 if resp.Data["expiry"].(string) != "16h" { 318 t.Fatal("expected a 16 hour expiry") 319 } 320 } 321 } 322 323 // Test generating a root, an intermediate, signing it, setting signed, and 324 // revoking it 325 326 // We'll need this later 327 var intSerialNumber string 328 { 329 // First, delete the existing CA info 330 { 331 _, err := client.Logical().Delete(rootName + "root") 332 if err != nil { 333 t.Fatal(err) 334 } 335 } 336 337 var rootPEM, rootKey, rootPEMBundle string 338 // Test exported root generation 339 { 340 resp, err := client.Logical().Write(rootName+"root/generate/exported", map[string]interface{}{ 341 "common_name": "Root Cert", 342 "ttl": "180h", 343 }) 344 if err != nil { 345 t.Fatal(err) 346 } 347 if resp == nil { 348 t.Fatal("nil response") 349 } 350 rootPEM = resp.Data["certificate"].(string) 351 rootKey = resp.Data["private_key"].(string) 352 rootPEMBundle = strings.Join([]string{rootPEM, rootKey}, "\n") 353 // This is really here to keep the use checker happy 354 if rootPEMBundle == "" { 355 t.Fatal("bad root pem bundle") 356 } 357 } 358 359 var intPEM, intCSR, intKey string 360 // Test exported intermediate CSR generation 361 { 362 resp, err := client.Logical().Write(intName+"intermediate/generate/exported", map[string]interface{}{ 363 "common_name": "intermediate.cert.com", 364 "ttl": "180h", 365 }) 366 if err != nil { 367 t.Fatal(err) 368 } 369 if resp == nil { 370 t.Fatal("nil response") 371 } 372 intCSR = resp.Data["csr"].(string) 373 intKey = resp.Data["private_key"].(string) 374 // This is really here to keep the use checker happy 375 if intCSR == "" || intKey == "" { 376 t.Fatal("int csr or key empty") 377 } 378 } 379 380 // Test signing 381 { 382 resp, err := client.Logical().Write(rootName+"root/sign-intermediate", map[string]interface{}{ 383 "common_name": "intermediate.cert.com", 384 "ttl": "10s", 385 "csr": intCSR, 386 }) 387 if err != nil { 388 t.Fatal(err) 389 } 390 if resp == nil { 391 t.Fatal("nil response") 392 } 393 intPEM = resp.Data["certificate"].(string) 394 intSerialNumber = resp.Data["serial_number"].(string) 395 } 396 397 // Test setting signed 398 { 399 resp, err := client.Logical().Write(intName+"intermediate/set-signed", map[string]interface{}{ 400 "certificate": intPEM, 401 }) 402 if err != nil { 403 t.Fatal(err) 404 } 405 if resp != nil { 406 t.Fatal("expected nil response") 407 } 408 } 409 410 // Verify we can find it via the root 411 { 412 resp, err := client.Logical().Read(rootName + "cert/" + intSerialNumber) 413 if err != nil { 414 t.Fatal(err) 415 } 416 if resp == nil { 417 t.Fatal("nil response") 418 } 419 if resp.Data["revocation_time"].(json.Number).String() != "0" { 420 t.Fatal("expected a zero revocation time") 421 } 422 } 423 424 // Revoke the intermediate 425 { 426 resp, err := client.Logical().Write(rootName+"revoke", map[string]interface{}{ 427 "serial_number": intSerialNumber, 428 }) 429 if err != nil { 430 t.Fatal(err) 431 } 432 if resp == nil { 433 t.Fatal("nil response") 434 } 435 } 436 } 437 438 verifyRevocation := func(t *testing.T, serial string, shouldFind bool) { 439 t.Helper() 440 // Verify it is now revoked 441 { 442 resp, err := client.Logical().Read(rootName + "cert/" + intSerialNumber) 443 if err != nil { 444 t.Fatal(err) 445 } 446 switch shouldFind { 447 case true: 448 if resp == nil { 449 t.Fatal("nil response") 450 } 451 if resp.Data["revocation_time"].(json.Number).String() == "0" { 452 t.Fatal("expected a non-zero revocation time") 453 } 454 default: 455 if resp != nil { 456 t.Fatalf("expected nil response, got %#v", *resp) 457 } 458 } 459 } 460 461 // Fetch the CRL and make sure it shows up 462 { 463 req := &logical.Request{ 464 Path: "crl", 465 Operation: logical.ReadOperation, 466 Storage: rootB.storage, 467 } 468 resp, err := rootB.HandleRequest(context.Background(), req) 469 if err != nil { 470 t.Fatal(err) 471 } 472 if resp == nil { 473 t.Fatal("nil response") 474 } 475 crlBytes := resp.Data["http_raw_body"].([]byte) 476 certList, err := x509.ParseCRL(crlBytes) 477 if err != nil { 478 t.Fatal(err) 479 } 480 switch shouldFind { 481 case true: 482 revokedList := certList.TBSCertList.RevokedCertificates 483 if len(revokedList) != 1 { 484 t.Fatalf("bad length of revoked list: %d", len(revokedList)) 485 } 486 revokedString := certutil.GetHexFormatted(revokedList[0].SerialNumber.Bytes(), ":") 487 if revokedString != intSerialNumber { 488 t.Fatalf("bad revoked serial: %s", revokedString) 489 } 490 default: 491 revokedList := certList.TBSCertList.RevokedCertificates 492 if len(revokedList) != 0 { 493 t.Fatalf("bad length of revoked list: %d", len(revokedList)) 494 } 495 } 496 } 497 } 498 499 // Validate current state of revoked certificates 500 verifyRevocation(t, intSerialNumber, true) 501 502 // Give time for the safety buffer to pass before tidying 503 time.Sleep(10 * time.Second) 504 505 // Test tidying 506 { 507 // Run with a high safety buffer, nothing should happen 508 { 509 resp, err := client.Logical().Write(rootName+"tidy", map[string]interface{}{ 510 "safety_buffer": "3h", 511 "tidy_cert_store": true, 512 "tidy_revoked_certs": true, 513 }) 514 if err != nil { 515 t.Fatal(err) 516 } 517 if resp == nil { 518 t.Fatal("expected warnings") 519 } 520 521 // Wait a few seconds as it runs in a goroutine 522 time.Sleep(5 * time.Second) 523 524 // Check to make sure we still find the cert and see it on the CRL 525 verifyRevocation(t, intSerialNumber, true) 526 } 527 528 // Run with both values set false, nothing should happen 529 { 530 resp, err := client.Logical().Write(rootName+"tidy", map[string]interface{}{ 531 "safety_buffer": "1s", 532 "tidy_cert_store": false, 533 "tidy_revoked_certs": false, 534 }) 535 if err != nil { 536 t.Fatal(err) 537 } 538 if resp == nil { 539 t.Fatal("expected warnings") 540 } 541 542 // Wait a few seconds as it runs in a goroutine 543 time.Sleep(5 * time.Second) 544 545 // Check to make sure we still find the cert and see it on the CRL 546 verifyRevocation(t, intSerialNumber, true) 547 } 548 549 // Run with a short safety buffer and both set to true, both should be cleared 550 { 551 resp, err := client.Logical().Write(rootName+"tidy", map[string]interface{}{ 552 "safety_buffer": "1s", 553 "tidy_cert_store": true, 554 "tidy_revoked_certs": true, 555 }) 556 if err != nil { 557 t.Fatal(err) 558 } 559 if resp == nil { 560 t.Fatal("expected warnings") 561 } 562 563 // Wait a few seconds as it runs in a goroutine 564 time.Sleep(5 * time.Second) 565 566 // Check to make sure we still find the cert and see it on the CRL 567 verifyRevocation(t, intSerialNumber, false) 568 } 569 } 570} 571