1package vault 2 3import ( 4 "bytes" 5 "context" 6 "crypto/ecdsa" 7 "crypto/elliptic" 8 "crypto/rand" 9 "crypto/sha256" 10 "crypto/tls" 11 "crypto/x509" 12 "crypto/x509/pkix" 13 "encoding/base64" 14 "encoding/pem" 15 "errors" 16 "fmt" 17 "io" 18 "io/ioutil" 19 "math/big" 20 mathrand "math/rand" 21 "net" 22 "net/http" 23 "os" 24 "os/exec" 25 "path/filepath" 26 "sync" 27 "sync/atomic" 28 "time" 29 30 "github.com/armon/go-metrics" 31 hclog "github.com/hashicorp/go-hclog" 32 log "github.com/hashicorp/go-hclog" 33 "github.com/hashicorp/vault/helper/metricsutil" 34 "github.com/hashicorp/vault/vault/cluster" 35 "github.com/hashicorp/vault/vault/seal" 36 "github.com/mitchellh/copystructure" 37 38 "golang.org/x/crypto/ed25519" 39 "golang.org/x/crypto/ssh" 40 "golang.org/x/net/http2" 41 42 cleanhttp "github.com/hashicorp/go-cleanhttp" 43 "github.com/hashicorp/vault/api" 44 "github.com/hashicorp/vault/audit" 45 "github.com/hashicorp/vault/command/server" 46 "github.com/hashicorp/vault/helper/namespace" 47 "github.com/hashicorp/vault/internalshared/reloadutil" 48 dbMysql "github.com/hashicorp/vault/plugins/database/mysql" 49 dbPostgres "github.com/hashicorp/vault/plugins/database/postgresql" 50 "github.com/hashicorp/vault/sdk/framework" 51 "github.com/hashicorp/vault/sdk/helper/consts" 52 "github.com/hashicorp/vault/sdk/helper/logging" 53 "github.com/hashicorp/vault/sdk/helper/salt" 54 "github.com/hashicorp/vault/sdk/logical" 55 "github.com/hashicorp/vault/sdk/physical" 56 testing "github.com/mitchellh/go-testing-interface" 57 58 physInmem "github.com/hashicorp/vault/sdk/physical/inmem" 59) 60 61// This file contains a number of methods that are useful for unit 62// tests within other packages. 63 64const ( 65 testSharedPublicKey = ` 66ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQC9i+hFxZHGo6KblVme4zrAcJstR6I0PTJozW286X4WyvPnkMYDQ5mnhEYC7UWCvjoTWbPEXPX7NjhRtwQTGD67bV+lrxgfyzK1JZbUXK4PwgKJvQD+XyyWYMzDgGSQY61KUSqCxymSm/9NZkPU3ElaQ9xQuTzPpztM4ROfb8f2Yv6/ZESZsTo0MTAkp8Pcy+WkioI/uJ1H7zqs0EA4OMY4aDJRu0UtP4rTVeYNEAuRXdX+eH4aW3KMvhzpFTjMbaJHJXlEeUm2SaX5TNQyTOvghCeQILfYIL/Ca2ij8iwCmulwdV6eQGfd4VDu40PvSnmfoaE38o6HaPnX0kUcnKiT 67` 68 testSharedPrivateKey = ` 69-----BEGIN RSA PRIVATE KEY----- 70MIIEogIBAAKCAQEAvYvoRcWRxqOim5VZnuM6wHCbLUeiND0yaM1tvOl+Fsrz55DG 71A0OZp4RGAu1Fgr46E1mzxFz1+zY4UbcEExg+u21fpa8YH8sytSWW1FyuD8ICib0A 72/l8slmDMw4BkkGOtSlEqgscpkpv/TWZD1NxJWkPcULk8z6c7TOETn2/H9mL+v2RE 73mbE6NDEwJKfD3MvlpIqCP7idR+86rNBAODjGOGgyUbtFLT+K01XmDRALkV3V/nh+ 74GltyjL4c6RU4zG2iRyV5RHlJtkml+UzUMkzr4IQnkCC32CC/wmtoo/IsAprpcHVe 75nkBn3eFQ7uND70p5n6GhN/KOh2j519JFHJyokwIDAQABAoIBAHX7VOvBC3kCN9/x 76+aPdup84OE7Z7MvpX6w+WlUhXVugnmsAAVDczhKoUc/WktLLx2huCGhsmKvyVuH+ 77MioUiE+vx75gm3qGx5xbtmOfALVMRLopjCnJYf6EaFA0ZeQ+NwowNW7Lu0PHmAU8 78Z3JiX8IwxTz14DU82buDyewO7v+cEr97AnERe3PUcSTDoUXNaoNxjNpEJkKREY6h 794hAY676RT/GsRcQ8tqe/rnCqPHNd7JGqL+207FK4tJw7daoBjQyijWuB7K5chSal 80oPInylM6b13ASXuOAOT/2uSUBWmFVCZPDCmnZxy2SdnJGbsJAMl7Ma3MUlaGvVI+ 81Tfh1aQkCgYEA4JlNOabTb3z42wz6mz+Nz3JRwbawD+PJXOk5JsSnV7DtPtfgkK9y 826FTQdhnozGWShAvJvc+C4QAihs9AlHXoaBY5bEU7R/8UK/pSqwzam+MmxmhVDV7G 83IMQPV0FteoXTaJSikhZ88mETTegI2mik+zleBpVxvfdhE5TR+lq8Br0CgYEA2AwJ 84CUD5CYUSj09PluR0HHqamWOrJkKPFPwa+5eiTTCzfBBxImYZh7nXnWuoviXC0sg2 85AuvCW+uZ48ygv/D8gcz3j1JfbErKZJuV+TotK9rRtNIF5Ub7qysP7UjyI7zCssVM 86kuDd9LfRXaB/qGAHNkcDA8NxmHW3gpln4CFdSY8CgYANs4xwfercHEWaJ1qKagAe 87rZyrMpffAEhicJ/Z65lB0jtG4CiE6w8ZeUMWUVJQVcnwYD+4YpZbX4S7sJ0B8Ydy 88AhkSr86D/92dKTIt2STk6aCN7gNyQ1vW198PtaAWH1/cO2UHgHOy3ZUt5X/Uwxl9 89cex4flln+1Viumts2GgsCQKBgCJH7psgSyPekK5auFdKEr5+Gc/jB8I/Z3K9+g4X 905nH3G1PBTCJYLw7hRzw8W/8oALzvddqKzEFHphiGXK94Lqjt/A4q1OdbCrhiE68D 91My21P/dAKB1UYRSs9Y8CNyHCjuZM9jSMJ8vv6vG/SOJPsnVDWVAckAbQDvlTHC9t 92O98zAoGAcbW6uFDkrv0XMCpB9Su3KaNXOR0wzag+WIFQRXCcoTvxVi9iYfUReQPi 93oOyBJU/HMVvBfv4g+OVFLVgSwwm6owwsouZ0+D/LasbuHqYyqYqdyPJQYzWA2Y+F 94+B6f4RoPdSXj24JHPg/ioRxjaj094UXJxua2yfkcecGNEuBQHSs= 95-----END RSA PRIVATE KEY----- 96` 97) 98 99// TestCore returns a pure in-memory, uninitialized core for testing. 100func TestCore(t testing.T) *Core { 101 return TestCoreWithSeal(t, nil, false) 102} 103 104// TestCoreRaw returns a pure in-memory, uninitialized core for testing. The raw 105// storage endpoints are enabled with this core. 106func TestCoreRaw(t testing.T) *Core { 107 return TestCoreWithSeal(t, nil, true) 108} 109 110// TestCoreNewSeal returns a pure in-memory, uninitialized core with 111// the new seal configuration. 112func TestCoreNewSeal(t testing.T) *Core { 113 seal := NewTestSeal(t, nil) 114 return TestCoreWithSeal(t, seal, false) 115} 116 117// TestCoreWithConfig returns a pure in-memory, uninitialized core with the 118// specified core configurations overridden for testing. 119func TestCoreWithConfig(t testing.T, conf *CoreConfig) *Core { 120 return TestCoreWithSealAndUI(t, conf) 121} 122 123// TestCoreWithSeal returns a pure in-memory, uninitialized core with the 124// specified seal for testing. 125func TestCoreWithSeal(t testing.T, testSeal Seal, enableRaw bool) *Core { 126 conf := &CoreConfig{ 127 Seal: testSeal, 128 EnableUI: false, 129 EnableRaw: enableRaw, 130 BuiltinRegistry: NewMockBuiltinRegistry(), 131 } 132 return TestCoreWithSealAndUI(t, conf) 133} 134 135func TestCoreUI(t testing.T, enableUI bool) *Core { 136 conf := &CoreConfig{ 137 EnableUI: enableUI, 138 EnableRaw: true, 139 BuiltinRegistry: NewMockBuiltinRegistry(), 140 } 141 return TestCoreWithSealAndUI(t, conf) 142} 143 144func TestCoreWithSealAndUI(t testing.T, opts *CoreConfig) *Core { 145 logger := logging.NewVaultLogger(log.Trace) 146 physicalBackend, err := physInmem.NewInmem(nil, logger) 147 if err != nil { 148 t.Fatal(err) 149 } 150 151 errInjector := physical.NewErrorInjector(physicalBackend, 0, logger) 152 153 // Start off with base test core config 154 conf := testCoreConfig(t, errInjector, logger) 155 156 // Override config values with ones that gets passed in 157 conf.EnableUI = opts.EnableUI 158 conf.EnableRaw = opts.EnableRaw 159 conf.Seal = opts.Seal 160 conf.LicensingConfig = opts.LicensingConfig 161 conf.DisableKeyEncodingChecks = opts.DisableKeyEncodingChecks 162 conf.MetricsHelper = opts.MetricsHelper 163 164 if opts.Logger != nil { 165 conf.Logger = opts.Logger 166 } 167 168 for k, v := range opts.LogicalBackends { 169 conf.LogicalBackends[k] = v 170 } 171 for k, v := range opts.CredentialBackends { 172 conf.CredentialBackends[k] = v 173 } 174 175 for k, v := range opts.AuditBackends { 176 conf.AuditBackends[k] = v 177 } 178 179 c, err := NewCore(conf) 180 if err != nil { 181 t.Fatalf("err: %s", err) 182 } 183 184 return c 185} 186 187func testCoreConfig(t testing.T, physicalBackend physical.Backend, logger log.Logger) *CoreConfig { 188 t.Helper() 189 noopAudits := map[string]audit.Factory{ 190 "noop": func(_ context.Context, config *audit.BackendConfig) (audit.Backend, error) { 191 view := &logical.InmemStorage{} 192 view.Put(context.Background(), &logical.StorageEntry{ 193 Key: "salt", 194 Value: []byte("foo"), 195 }) 196 config.SaltConfig = &salt.Config{ 197 HMAC: sha256.New, 198 HMACType: "hmac-sha256", 199 } 200 config.SaltView = view 201 202 n := &noopAudit{ 203 Config: config, 204 } 205 n.formatter.AuditFormatWriter = &audit.JSONFormatWriter{ 206 SaltFunc: n.Salt, 207 } 208 return n, nil 209 }, 210 } 211 212 noopBackends := make(map[string]logical.Factory) 213 noopBackends["noop"] = func(ctx context.Context, config *logical.BackendConfig) (logical.Backend, error) { 214 b := new(framework.Backend) 215 b.Setup(ctx, config) 216 b.BackendType = logical.TypeCredential 217 return b, nil 218 } 219 noopBackends["http"] = func(ctx context.Context, config *logical.BackendConfig) (logical.Backend, error) { 220 return new(rawHTTP), nil 221 } 222 223 credentialBackends := make(map[string]logical.Factory) 224 for backendName, backendFactory := range noopBackends { 225 credentialBackends[backendName] = backendFactory 226 } 227 for backendName, backendFactory := range testCredentialBackends { 228 credentialBackends[backendName] = backendFactory 229 } 230 231 logicalBackends := make(map[string]logical.Factory) 232 for backendName, backendFactory := range noopBackends { 233 logicalBackends[backendName] = backendFactory 234 } 235 236 logicalBackends["kv"] = LeasedPassthroughBackendFactory 237 for backendName, backendFactory := range testLogicalBackends { 238 logicalBackends[backendName] = backendFactory 239 } 240 241 conf := &CoreConfig{ 242 Physical: physicalBackend, 243 AuditBackends: noopAudits, 244 LogicalBackends: logicalBackends, 245 CredentialBackends: credentialBackends, 246 DisableMlock: true, 247 Logger: logger, 248 BuiltinRegistry: NewMockBuiltinRegistry(), 249 } 250 251 return conf 252} 253 254// TestCoreInit initializes the core with a single key, and returns 255// the key that must be used to unseal the core and a root token. 256func TestCoreInit(t testing.T, core *Core) ([][]byte, string) { 257 t.Helper() 258 secretShares, _, root := TestCoreInitClusterWrapperSetup(t, core, nil) 259 return secretShares, root 260} 261 262func TestCoreInitClusterWrapperSetup(t testing.T, core *Core, handler http.Handler) ([][]byte, [][]byte, string) { 263 t.Helper() 264 core.SetClusterHandler(handler) 265 266 barrierConfig := &SealConfig{ 267 SecretShares: 3, 268 SecretThreshold: 3, 269 } 270 271 switch core.seal.StoredKeysSupported() { 272 case seal.StoredKeysNotSupported: 273 barrierConfig.StoredShares = 0 274 default: 275 barrierConfig.StoredShares = 1 276 } 277 278 recoveryConfig := &SealConfig{ 279 SecretShares: 3, 280 SecretThreshold: 3, 281 } 282 283 initParams := &InitParams{ 284 BarrierConfig: barrierConfig, 285 RecoveryConfig: recoveryConfig, 286 } 287 if core.seal.StoredKeysSupported() == seal.StoredKeysNotSupported { 288 initParams.LegacyShamirSeal = true 289 } 290 result, err := core.Initialize(context.Background(), initParams) 291 if err != nil { 292 t.Fatalf("err: %s", err) 293 } 294 return result.SecretShares, result.RecoveryShares, result.RootToken 295} 296 297func TestCoreUnseal(core *Core, key []byte) (bool, error) { 298 return core.Unseal(key) 299} 300 301func TestCoreUnsealWithRecoveryKeys(core *Core, key []byte) (bool, error) { 302 return core.UnsealWithRecoveryKeys(key) 303} 304 305// TestCoreUnsealed returns a pure in-memory core that is already 306// initialized and unsealed. 307func TestCoreUnsealed(t testing.T) (*Core, [][]byte, string) { 308 t.Helper() 309 core := TestCore(t) 310 return testCoreUnsealed(t, core) 311} 312 313// TestCoreUnsealedRaw returns a pure in-memory core that is already 314// initialized, unsealed, and with raw endpoints enabled. 315func TestCoreUnsealedRaw(t testing.T) (*Core, [][]byte, string) { 316 t.Helper() 317 core := TestCoreRaw(t) 318 return testCoreUnsealed(t, core) 319} 320 321// TestCoreUnsealedWithConfig returns a pure in-memory core that is already 322// initialized, unsealed, with the any provided core config values overridden. 323func TestCoreUnsealedWithConfig(t testing.T, conf *CoreConfig) (*Core, [][]byte, string) { 324 t.Helper() 325 core := TestCoreWithConfig(t, conf) 326 return testCoreUnsealed(t, core) 327} 328 329func testCoreUnsealed(t testing.T, core *Core) (*Core, [][]byte, string) { 330 t.Helper() 331 keys, token := TestCoreInit(t, core) 332 for _, key := range keys { 333 if _, err := TestCoreUnseal(core, TestKeyCopy(key)); err != nil { 334 t.Fatalf("unseal err: %s", err) 335 } 336 } 337 338 if core.Sealed() { 339 t.Fatal("should not be sealed") 340 } 341 342 testCoreAddSecretMount(t, core, token) 343 344 return core, keys, token 345} 346 347func testCoreAddSecretMount(t testing.T, core *Core, token string) { 348 kvReq := &logical.Request{ 349 Operation: logical.UpdateOperation, 350 ClientToken: token, 351 Path: "sys/mounts/secret", 352 Data: map[string]interface{}{ 353 "type": "kv", 354 "path": "secret/", 355 "description": "key/value secret storage", 356 "options": map[string]string{ 357 "version": "1", 358 }, 359 }, 360 } 361 resp, err := core.HandleRequest(namespace.RootContext(nil), kvReq) 362 if err != nil { 363 t.Fatal(err) 364 } 365 if resp.IsError() { 366 t.Fatal(err) 367 } 368 369} 370 371func TestCoreUnsealedBackend(t testing.T, backend physical.Backend) (*Core, [][]byte, string) { 372 t.Helper() 373 logger := logging.NewVaultLogger(log.Trace) 374 conf := testCoreConfig(t, backend, logger) 375 conf.Seal = NewTestSeal(t, nil) 376 377 core, err := NewCore(conf) 378 if err != nil { 379 t.Fatalf("err: %s", err) 380 } 381 382 keys, token := TestCoreInit(t, core) 383 for _, key := range keys { 384 if _, err := TestCoreUnseal(core, TestKeyCopy(key)); err != nil { 385 t.Fatalf("unseal err: %s", err) 386 } 387 } 388 389 if err := core.UnsealWithStoredKeys(context.Background()); err != nil { 390 t.Fatal(err) 391 } 392 393 if core.Sealed() { 394 t.Fatal("should not be sealed") 395 } 396 397 return core, keys, token 398} 399 400// TestKeyCopy is a silly little function to just copy the key so that 401// it can be used with Unseal easily. 402func TestKeyCopy(key []byte) []byte { 403 result := make([]byte, len(key)) 404 copy(result, key) 405 return result 406} 407 408func TestDynamicSystemView(c *Core) *dynamicSystemView { 409 me := &MountEntry{ 410 Config: MountConfig{ 411 DefaultLeaseTTL: 24 * time.Hour, 412 MaxLeaseTTL: 2 * 24 * time.Hour, 413 }, 414 } 415 416 return &dynamicSystemView{c, me} 417} 418 419// TestAddTestPlugin registers the testFunc as part of the plugin command to the 420// plugin catalog. If provided, uses tmpDir as the plugin directory. 421func TestAddTestPlugin(t testing.T, c *Core, name string, pluginType consts.PluginType, testFunc string, env []string, tempDir string) { 422 file, err := os.Open(os.Args[0]) 423 if err != nil { 424 t.Fatal(err) 425 } 426 defer file.Close() 427 428 dirPath := filepath.Dir(os.Args[0]) 429 fileName := filepath.Base(os.Args[0]) 430 431 if tempDir != "" { 432 fi, err := file.Stat() 433 if err != nil { 434 t.Fatal(err) 435 } 436 437 // Copy over the file to the temp dir 438 dst := filepath.Join(tempDir, fileName) 439 out, err := os.OpenFile(dst, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, fi.Mode()) 440 if err != nil { 441 t.Fatal(err) 442 } 443 defer out.Close() 444 445 if _, err = io.Copy(out, file); err != nil { 446 t.Fatal(err) 447 } 448 err = out.Sync() 449 if err != nil { 450 t.Fatal(err) 451 } 452 453 dirPath = tempDir 454 } 455 456 // Determine plugin directory full path, evaluating potential symlink path 457 fullPath, err := filepath.EvalSymlinks(dirPath) 458 if err != nil { 459 t.Fatal(err) 460 } 461 462 reader, err := os.Open(filepath.Join(fullPath, fileName)) 463 if err != nil { 464 t.Fatal(err) 465 } 466 defer reader.Close() 467 468 // Find out the sha256 469 hash := sha256.New() 470 471 _, err = io.Copy(hash, reader) 472 if err != nil { 473 t.Fatal(err) 474 } 475 476 sum := hash.Sum(nil) 477 478 // Set core's plugin directory and plugin catalog directory 479 c.pluginDirectory = fullPath 480 c.pluginCatalog.directory = fullPath 481 482 args := []string{fmt.Sprintf("--test.run=%s", testFunc)} 483 err = c.pluginCatalog.Set(context.Background(), name, pluginType, fileName, args, env, sum) 484 if err != nil { 485 t.Fatal(err) 486 } 487} 488 489var testLogicalBackends = map[string]logical.Factory{} 490var testCredentialBackends = map[string]logical.Factory{} 491 492// StartSSHHostTestServer starts the test server which responds to SSH 493// authentication. Used to test the SSH secret backend. 494func StartSSHHostTestServer() (string, error) { 495 pubKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(testSharedPublicKey)) 496 if err != nil { 497 return "", fmt.Errorf("error parsing public key") 498 } 499 serverConfig := &ssh.ServerConfig{ 500 PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { 501 if bytes.Compare(pubKey.Marshal(), key.Marshal()) == 0 { 502 return &ssh.Permissions{}, nil 503 } else { 504 return nil, fmt.Errorf("key does not match") 505 } 506 }, 507 } 508 signer, err := ssh.ParsePrivateKey([]byte(testSharedPrivateKey)) 509 if err != nil { 510 panic("Error parsing private key") 511 } 512 serverConfig.AddHostKey(signer) 513 514 soc, err := net.Listen("tcp", "127.0.0.1:0") 515 if err != nil { 516 return "", fmt.Errorf("error listening to connection") 517 } 518 519 go func() { 520 for { 521 conn, err := soc.Accept() 522 if err != nil { 523 panic(fmt.Sprintf("Error accepting incoming connection: %s", err)) 524 } 525 defer conn.Close() 526 sshConn, chanReqs, _, err := ssh.NewServerConn(conn, serverConfig) 527 if err != nil { 528 panic(fmt.Sprintf("Handshaking error: %v", err)) 529 } 530 531 go func() { 532 for chanReq := range chanReqs { 533 go func(chanReq ssh.NewChannel) { 534 if chanReq.ChannelType() != "session" { 535 chanReq.Reject(ssh.UnknownChannelType, "unknown channel type") 536 return 537 } 538 539 ch, requests, err := chanReq.Accept() 540 if err != nil { 541 panic(fmt.Sprintf("Error accepting channel: %s", err)) 542 } 543 544 go func(ch ssh.Channel, in <-chan *ssh.Request) { 545 for req := range in { 546 executeServerCommand(ch, req) 547 } 548 }(ch, requests) 549 }(chanReq) 550 } 551 sshConn.Close() 552 }() 553 } 554 }() 555 return soc.Addr().String(), nil 556} 557 558// This executes the commands requested to be run on the server. 559// Used to test the SSH secret backend. 560func executeServerCommand(ch ssh.Channel, req *ssh.Request) { 561 command := string(req.Payload[4:]) 562 cmd := exec.Command("/bin/bash", []string{"-c", command}...) 563 req.Reply(true, nil) 564 565 cmd.Stdout = ch 566 cmd.Stderr = ch 567 cmd.Stdin = ch 568 569 err := cmd.Start() 570 if err != nil { 571 panic(fmt.Sprintf("Error starting the command: '%s'", err)) 572 } 573 574 go func() { 575 _, err := cmd.Process.Wait() 576 if err != nil { 577 panic(fmt.Sprintf("Error while waiting for command to finish:'%s'", err)) 578 } 579 ch.Close() 580 }() 581} 582 583// This adds a credential backend for the test core. This needs to be 584// invoked before the test core is created. 585func AddTestCredentialBackend(name string, factory logical.Factory) error { 586 if name == "" { 587 return fmt.Errorf("missing backend name") 588 } 589 if factory == nil { 590 return fmt.Errorf("missing backend factory function") 591 } 592 testCredentialBackends[name] = factory 593 return nil 594} 595 596// This adds a logical backend for the test core. This needs to be 597// invoked before the test core is created. 598func AddTestLogicalBackend(name string, factory logical.Factory) error { 599 if name == "" { 600 return fmt.Errorf("missing backend name") 601 } 602 if factory == nil { 603 return fmt.Errorf("missing backend factory function") 604 } 605 testLogicalBackends[name] = factory 606 return nil 607} 608 609type noopAudit struct { 610 Config *audit.BackendConfig 611 salt *salt.Salt 612 saltMutex sync.RWMutex 613 formatter audit.AuditFormatter 614 records [][]byte 615 l sync.RWMutex 616} 617 618func (n *noopAudit) GetHash(ctx context.Context, data string) (string, error) { 619 salt, err := n.Salt(ctx) 620 if err != nil { 621 return "", err 622 } 623 return salt.GetIdentifiedHMAC(data), nil 624} 625 626func (n *noopAudit) LogRequest(ctx context.Context, in *logical.LogInput) error { 627 n.l.Lock() 628 defer n.l.Unlock() 629 var w bytes.Buffer 630 err := n.formatter.FormatRequest(ctx, &w, audit.FormatterConfig{}, in) 631 if err != nil { 632 return err 633 } 634 n.records = append(n.records, w.Bytes()) 635 return nil 636} 637 638func (n *noopAudit) LogResponse(ctx context.Context, in *logical.LogInput) error { 639 n.l.Lock() 640 defer n.l.Unlock() 641 var w bytes.Buffer 642 err := n.formatter.FormatResponse(ctx, &w, audit.FormatterConfig{}, in) 643 if err != nil { 644 return err 645 } 646 n.records = append(n.records, w.Bytes()) 647 return nil 648} 649 650func (n *noopAudit) Reload(_ context.Context) error { 651 return nil 652} 653 654func (n *noopAudit) Invalidate(_ context.Context) { 655 n.saltMutex.Lock() 656 defer n.saltMutex.Unlock() 657 n.salt = nil 658} 659 660func (n *noopAudit) Salt(ctx context.Context) (*salt.Salt, error) { 661 n.saltMutex.RLock() 662 if n.salt != nil { 663 defer n.saltMutex.RUnlock() 664 return n.salt, nil 665 } 666 n.saltMutex.RUnlock() 667 n.saltMutex.Lock() 668 defer n.saltMutex.Unlock() 669 if n.salt != nil { 670 return n.salt, nil 671 } 672 salt, err := salt.NewSalt(ctx, n.Config.SaltView, n.Config.SaltConfig) 673 if err != nil { 674 return nil, err 675 } 676 n.salt = salt 677 return salt, nil 678} 679 680func AddNoopAudit(conf *CoreConfig, records **[][]byte) { 681 conf.AuditBackends = map[string]audit.Factory{ 682 "noop": func(_ context.Context, config *audit.BackendConfig) (audit.Backend, error) { 683 view := &logical.InmemStorage{} 684 view.Put(context.Background(), &logical.StorageEntry{ 685 Key: "salt", 686 Value: []byte("foo"), 687 }) 688 n := &noopAudit{ 689 Config: config, 690 } 691 n.formatter.AuditFormatWriter = &audit.JSONFormatWriter{ 692 SaltFunc: n.Salt, 693 } 694 if records != nil { 695 *records = &n.records 696 } 697 return n, nil 698 }, 699 } 700} 701 702type rawHTTP struct{} 703 704func (n *rawHTTP) HandleRequest(ctx context.Context, req *logical.Request) (*logical.Response, error) { 705 return &logical.Response{ 706 Data: map[string]interface{}{ 707 logical.HTTPStatusCode: 200, 708 logical.HTTPContentType: "plain/text", 709 logical.HTTPRawBody: []byte("hello world"), 710 }, 711 }, nil 712} 713 714func (n *rawHTTP) HandleExistenceCheck(ctx context.Context, req *logical.Request) (bool, bool, error) { 715 return false, false, nil 716} 717 718func (n *rawHTTP) SpecialPaths() *logical.Paths { 719 return &logical.Paths{Unauthenticated: []string{"*"}} 720} 721 722func (n *rawHTTP) System() logical.SystemView { 723 return logical.StaticSystemView{ 724 DefaultLeaseTTLVal: time.Hour * 24, 725 MaxLeaseTTLVal: time.Hour * 24 * 32, 726 } 727} 728 729func (n *rawHTTP) Logger() log.Logger { 730 return logging.NewVaultLogger(log.Trace) 731} 732 733func (n *rawHTTP) Cleanup(ctx context.Context) { 734 // noop 735} 736 737func (n *rawHTTP) Initialize(ctx context.Context, req *logical.InitializationRequest) error { 738 return nil 739} 740 741func (n *rawHTTP) InvalidateKey(context.Context, string) { 742 // noop 743} 744 745func (n *rawHTTP) Setup(ctx context.Context, config *logical.BackendConfig) error { 746 // noop 747 return nil 748} 749 750func (n *rawHTTP) Type() logical.BackendType { 751 return logical.TypeLogical 752} 753 754func GenerateRandBytes(length int) ([]byte, error) { 755 if length < 0 { 756 return nil, fmt.Errorf("length must be >= 0") 757 } 758 759 buf := make([]byte, length) 760 if length == 0 { 761 return buf, nil 762 } 763 764 n, err := rand.Read(buf) 765 if err != nil { 766 return nil, err 767 } 768 if n != length { 769 return nil, fmt.Errorf("unable to read %d bytes; only read %d", length, n) 770 } 771 772 return buf, nil 773} 774 775func TestWaitActive(t testing.T, core *Core) { 776 t.Helper() 777 if err := TestWaitActiveWithError(core); err != nil { 778 t.Fatal(err) 779 } 780} 781 782func TestWaitActiveWithError(core *Core) error { 783 start := time.Now() 784 var standby bool 785 var err error 786 for time.Now().Sub(start) < 30*time.Second { 787 standby, err = core.Standby() 788 if err != nil { 789 return err 790 } 791 if !standby { 792 break 793 } 794 } 795 if standby { 796 return errors.New("should not be in standby mode") 797 } 798 return nil 799} 800 801type TestCluster struct { 802 BarrierKeys [][]byte 803 RecoveryKeys [][]byte 804 CACert *x509.Certificate 805 CACertBytes []byte 806 CACertPEM []byte 807 CACertPEMFile string 808 CAKey *ecdsa.PrivateKey 809 CAKeyPEM []byte 810 Cores []*TestClusterCore 811 ID string 812 RootToken string 813 RootCAs *x509.CertPool 814 TempDir string 815 ClientAuthRequired bool 816 Logger log.Logger 817 CleanupFunc func() 818 SetupFunc func() 819} 820 821func (c *TestCluster) Start() { 822 for _, core := range c.Cores { 823 if core.Server != nil { 824 for _, ln := range core.Listeners { 825 go core.Server.Serve(ln) 826 } 827 } 828 } 829 if c.SetupFunc != nil { 830 c.SetupFunc() 831 } 832} 833 834// UnsealCores uses the cluster barrier keys to unseal the test cluster cores 835func (c *TestCluster) UnsealCores(t testing.T) { 836 t.Helper() 837 if err := c.UnsealCoresWithError(false); err != nil { 838 t.Fatal(err) 839 } 840} 841 842func (c *TestCluster) UnsealCoresWithError(useStoredKeys bool) error { 843 unseal := func(core *Core) error { 844 for _, key := range c.BarrierKeys { 845 if _, err := core.Unseal(TestKeyCopy(key)); err != nil { 846 return err 847 } 848 } 849 return nil 850 } 851 if useStoredKeys { 852 unseal = func(core *Core) error { 853 return core.UnsealWithStoredKeys(context.Background()) 854 } 855 } 856 857 // Unseal first core 858 if err := unseal(c.Cores[0].Core); err != nil { 859 return fmt.Errorf("unseal core %d err: %s", 0, err) 860 } 861 862 // Verify unsealed 863 if c.Cores[0].Sealed() { 864 return fmt.Errorf("should not be sealed") 865 } 866 867 if err := TestWaitActiveWithError(c.Cores[0].Core); err != nil { 868 return err 869 } 870 871 // Unseal other cores 872 for i := 1; i < len(c.Cores); i++ { 873 if err := unseal(c.Cores[i].Core); err != nil { 874 return fmt.Errorf("unseal core %d err: %s", i, err) 875 } 876 } 877 878 // Let them come fully up to standby 879 time.Sleep(2 * time.Second) 880 881 // Ensure cluster connection info is populated. 882 // Other cores should not come up as leaders. 883 for i := 1; i < len(c.Cores); i++ { 884 isLeader, _, _, err := c.Cores[i].Leader() 885 if err != nil { 886 return err 887 } 888 if isLeader { 889 return fmt.Errorf("core[%d] should not be leader", i) 890 } 891 } 892 893 return nil 894} 895 896func (c *TestCluster) UnsealCore(t testing.T, core *TestClusterCore) { 897 var keys [][]byte 898 if core.seal.RecoveryKeySupported() { 899 keys = c.RecoveryKeys 900 } else { 901 keys = c.BarrierKeys 902 } 903 for _, key := range keys { 904 if _, err := core.Core.Unseal(TestKeyCopy(key)); err != nil { 905 t.Fatalf("unseal err: %s", err) 906 } 907 } 908} 909 910func (c *TestCluster) EnsureCoresSealed(t testing.T) { 911 t.Helper() 912 if err := c.ensureCoresSealed(); err != nil { 913 t.Fatal(err) 914 } 915} 916 917func (c *TestClusterCore) Seal(t testing.T) { 918 t.Helper() 919 if err := c.Core.sealInternal(); err != nil { 920 t.Fatal(err) 921 } 922} 923 924func CleanupClusters(clusters []*TestCluster) { 925 wg := &sync.WaitGroup{} 926 for _, cluster := range clusters { 927 wg.Add(1) 928 lc := cluster 929 go func() { 930 defer wg.Done() 931 lc.Cleanup() 932 }() 933 } 934 wg.Wait() 935} 936 937func (c *TestCluster) Cleanup() { 938 c.Logger.Info("cleaning up vault cluster") 939 for _, core := range c.Cores { 940 core.CoreConfig.Logger.SetLevel(log.Error) 941 } 942 943 // Close listeners 944 wg := &sync.WaitGroup{} 945 for _, core := range c.Cores { 946 wg.Add(1) 947 lc := core 948 949 go func() { 950 defer wg.Done() 951 if lc.Listeners != nil { 952 for _, ln := range lc.Listeners { 953 ln.Close() 954 } 955 } 956 if lc.licensingStopCh != nil { 957 close(lc.licensingStopCh) 958 lc.licensingStopCh = nil 959 } 960 961 if err := lc.Shutdown(); err != nil { 962 lc.Logger().Error("error during shutdown; abandoning sealing", "error", err) 963 } else { 964 timeout := time.Now().Add(60 * time.Second) 965 for { 966 if time.Now().After(timeout) { 967 lc.Logger().Error("timeout waiting for core to seal") 968 } 969 if lc.Sealed() { 970 break 971 } 972 time.Sleep(250 * time.Millisecond) 973 } 974 } 975 }() 976 } 977 978 wg.Wait() 979 980 // Remove any temp dir that exists 981 if c.TempDir != "" { 982 os.RemoveAll(c.TempDir) 983 } 984 985 // Give time to actually shut down/clean up before the next test 986 time.Sleep(time.Second) 987 if c.CleanupFunc != nil { 988 c.CleanupFunc() 989 } 990} 991 992func (c *TestCluster) ensureCoresSealed() error { 993 for _, core := range c.Cores { 994 if err := core.Shutdown(); err != nil { 995 return err 996 } 997 timeout := time.Now().Add(60 * time.Second) 998 for { 999 if time.Now().After(timeout) { 1000 return fmt.Errorf("timeout waiting for core to seal") 1001 } 1002 if core.Sealed() { 1003 break 1004 } 1005 time.Sleep(250 * time.Millisecond) 1006 } 1007 } 1008 return nil 1009} 1010 1011func SetReplicationFailureMode(core *TestClusterCore, mode uint32) { 1012 atomic.StoreUint32(core.Core.replicationFailure, mode) 1013} 1014 1015type TestListener struct { 1016 net.Listener 1017 Address *net.TCPAddr 1018} 1019 1020type TestClusterCore struct { 1021 *Core 1022 CoreConfig *CoreConfig 1023 Client *api.Client 1024 Handler http.Handler 1025 Listeners []*TestListener 1026 ReloadFuncs *map[string][]reloadutil.ReloadFunc 1027 ReloadFuncsLock *sync.RWMutex 1028 Server *http.Server 1029 ServerCert *x509.Certificate 1030 ServerCertBytes []byte 1031 ServerCertPEM []byte 1032 ServerKey *ecdsa.PrivateKey 1033 ServerKeyPEM []byte 1034 TLSConfig *tls.Config 1035 UnderlyingStorage physical.Backend 1036 UnderlyingRawStorage physical.Backend 1037 Barrier SecurityBarrier 1038 NodeID string 1039} 1040 1041type PhysicalBackendBundle struct { 1042 Backend physical.Backend 1043 HABackend physical.HABackend 1044 Cleanup func() 1045} 1046 1047type TestClusterOptions struct { 1048 KeepStandbysSealed bool 1049 SkipInit bool 1050 HandlerFunc func(*HandlerProperties) http.Handler 1051 DefaultHandlerProperties HandlerProperties 1052 BaseListenAddress string 1053 NumCores int 1054 SealFunc func() Seal 1055 Logger log.Logger 1056 TempDir string 1057 CACert []byte 1058 CAKey *ecdsa.PrivateKey 1059 // PhysicalFactory is used to create backends. 1060 // The int argument is the index of the core within the cluster, i.e. first 1061 // core in cluster will have 0, second 1, etc. 1062 // If the backend is shared across the cluster (i.e. is not Raft) then it 1063 // should return nil when coreIdx != 0. 1064 PhysicalFactory func(t testing.T, coreIdx int, logger hclog.Logger) *PhysicalBackendBundle 1065 // FirstCoreNumber is used to assign a unique number to each core within 1066 // a multi-cluster setup. 1067 FirstCoreNumber int 1068 RequireClientAuth bool 1069 // SetupFunc is called after the cluster is started. 1070 SetupFunc func(t testing.T, c *TestCluster) 1071 PR1103Disabled bool 1072 1073 // ClusterLayers are used to override the default cluster connection layer 1074 ClusterLayers cluster.NetworkLayerSet 1075} 1076 1077var DefaultNumCores = 3 1078 1079type certInfo struct { 1080 cert *x509.Certificate 1081 certPEM []byte 1082 certBytes []byte 1083 key *ecdsa.PrivateKey 1084 keyPEM []byte 1085} 1086 1087// NewTestCluster creates a new test cluster based on the provided core config 1088// and test cluster options. 1089// 1090// N.B. Even though a single base CoreConfig is provided, NewTestCluster will instantiate a 1091// core config for each core it creates. If separate seal per core is desired, opts.SealFunc 1092// can be provided to generate a seal for each one. Otherwise, the provided base.Seal will be 1093// shared among cores. NewCore's default behavior is to generate a new DefaultSeal if the 1094// provided Seal in coreConfig (i.e. base.Seal) is nil. 1095// 1096// If opts.Logger is provided, it takes precedence and will be used as the cluster 1097// logger and will be the basis for each core's logger. If no opts.Logger is 1098// given, one will be generated based on t.Name() for the cluster logger, and if 1099// no base.Logger is given will also be used as the basis for each core's logger. 1100func NewTestCluster(t testing.T, base *CoreConfig, opts *TestClusterOptions) *TestCluster { 1101 var err error 1102 1103 var numCores int 1104 if opts == nil || opts.NumCores == 0 { 1105 numCores = DefaultNumCores 1106 } else { 1107 numCores = opts.NumCores 1108 } 1109 1110 var disablePR1103 bool 1111 if opts != nil && opts.PR1103Disabled { 1112 disablePR1103 = true 1113 } 1114 1115 var firstCoreNumber int 1116 if opts != nil { 1117 firstCoreNumber = opts.FirstCoreNumber 1118 } 1119 1120 certIPs := []net.IP{ 1121 net.IPv6loopback, 1122 net.ParseIP("127.0.0.1"), 1123 } 1124 var baseAddr *net.TCPAddr 1125 if opts != nil && opts.BaseListenAddress != "" { 1126 baseAddr, err = net.ResolveTCPAddr("tcp", opts.BaseListenAddress) 1127 if err != nil { 1128 t.Fatal("could not parse given base IP") 1129 } 1130 certIPs = append(certIPs, baseAddr.IP) 1131 } 1132 1133 var testCluster TestCluster 1134 1135 if opts != nil && opts.Logger != nil { 1136 testCluster.Logger = opts.Logger 1137 } else { 1138 testCluster.Logger = logging.NewVaultLogger(log.Trace).Named(t.Name()) 1139 } 1140 1141 if opts != nil && opts.TempDir != "" { 1142 if _, err := os.Stat(opts.TempDir); os.IsNotExist(err) { 1143 if err := os.MkdirAll(opts.TempDir, 0700); err != nil { 1144 t.Fatal(err) 1145 } 1146 } 1147 testCluster.TempDir = opts.TempDir 1148 } else { 1149 tempDir, err := ioutil.TempDir("", "vault-test-cluster-") 1150 if err != nil { 1151 t.Fatal(err) 1152 } 1153 testCluster.TempDir = tempDir 1154 } 1155 1156 var caKey *ecdsa.PrivateKey 1157 if opts != nil && opts.CAKey != nil { 1158 caKey = opts.CAKey 1159 } else { 1160 caKey, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader) 1161 if err != nil { 1162 t.Fatal(err) 1163 } 1164 } 1165 testCluster.CAKey = caKey 1166 var caBytes []byte 1167 if opts != nil && len(opts.CACert) > 0 { 1168 caBytes = opts.CACert 1169 } else { 1170 caCertTemplate := &x509.Certificate{ 1171 Subject: pkix.Name{ 1172 CommonName: "localhost", 1173 }, 1174 DNSNames: []string{"localhost"}, 1175 IPAddresses: certIPs, 1176 KeyUsage: x509.KeyUsage(x509.KeyUsageCertSign | x509.KeyUsageCRLSign), 1177 SerialNumber: big.NewInt(mathrand.Int63()), 1178 NotBefore: time.Now().Add(-30 * time.Second), 1179 NotAfter: time.Now().Add(262980 * time.Hour), 1180 BasicConstraintsValid: true, 1181 IsCA: true, 1182 } 1183 caBytes, err = x509.CreateCertificate(rand.Reader, caCertTemplate, caCertTemplate, caKey.Public(), caKey) 1184 if err != nil { 1185 t.Fatal(err) 1186 } 1187 } 1188 caCert, err := x509.ParseCertificate(caBytes) 1189 if err != nil { 1190 t.Fatal(err) 1191 } 1192 testCluster.CACert = caCert 1193 testCluster.CACertBytes = caBytes 1194 testCluster.RootCAs = x509.NewCertPool() 1195 testCluster.RootCAs.AddCert(caCert) 1196 caCertPEMBlock := &pem.Block{ 1197 Type: "CERTIFICATE", 1198 Bytes: caBytes, 1199 } 1200 testCluster.CACertPEM = pem.EncodeToMemory(caCertPEMBlock) 1201 testCluster.CACertPEMFile = filepath.Join(testCluster.TempDir, "ca_cert.pem") 1202 err = ioutil.WriteFile(testCluster.CACertPEMFile, testCluster.CACertPEM, 0755) 1203 if err != nil { 1204 t.Fatal(err) 1205 } 1206 marshaledCAKey, err := x509.MarshalECPrivateKey(caKey) 1207 if err != nil { 1208 t.Fatal(err) 1209 } 1210 caKeyPEMBlock := &pem.Block{ 1211 Type: "EC PRIVATE KEY", 1212 Bytes: marshaledCAKey, 1213 } 1214 testCluster.CAKeyPEM = pem.EncodeToMemory(caKeyPEMBlock) 1215 err = ioutil.WriteFile(filepath.Join(testCluster.TempDir, "ca_key.pem"), testCluster.CAKeyPEM, 0755) 1216 if err != nil { 1217 t.Fatal(err) 1218 } 1219 1220 var certInfoSlice []*certInfo 1221 1222 // 1223 // Certs generation 1224 // 1225 for i := 0; i < numCores; i++ { 1226 key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) 1227 if err != nil { 1228 t.Fatal(err) 1229 } 1230 certTemplate := &x509.Certificate{ 1231 Subject: pkix.Name{ 1232 CommonName: "localhost", 1233 }, 1234 // Include host.docker.internal for the sake of benchmark-vault running on MacOS/Windows. 1235 // This allows Prometheus running in docker to scrape the cluster for metrics. 1236 DNSNames: []string{"localhost", "host.docker.internal"}, 1237 IPAddresses: certIPs, 1238 ExtKeyUsage: []x509.ExtKeyUsage{ 1239 x509.ExtKeyUsageServerAuth, 1240 x509.ExtKeyUsageClientAuth, 1241 }, 1242 KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageKeyAgreement, 1243 SerialNumber: big.NewInt(mathrand.Int63()), 1244 NotBefore: time.Now().Add(-30 * time.Second), 1245 NotAfter: time.Now().Add(262980 * time.Hour), 1246 } 1247 certBytes, err := x509.CreateCertificate(rand.Reader, certTemplate, caCert, key.Public(), caKey) 1248 if err != nil { 1249 t.Fatal(err) 1250 } 1251 cert, err := x509.ParseCertificate(certBytes) 1252 if err != nil { 1253 t.Fatal(err) 1254 } 1255 certPEMBlock := &pem.Block{ 1256 Type: "CERTIFICATE", 1257 Bytes: certBytes, 1258 } 1259 certPEM := pem.EncodeToMemory(certPEMBlock) 1260 marshaledKey, err := x509.MarshalECPrivateKey(key) 1261 if err != nil { 1262 t.Fatal(err) 1263 } 1264 keyPEMBlock := &pem.Block{ 1265 Type: "EC PRIVATE KEY", 1266 Bytes: marshaledKey, 1267 } 1268 keyPEM := pem.EncodeToMemory(keyPEMBlock) 1269 1270 certInfoSlice = append(certInfoSlice, &certInfo{ 1271 cert: cert, 1272 certPEM: certPEM, 1273 certBytes: certBytes, 1274 key: key, 1275 keyPEM: keyPEM, 1276 }) 1277 } 1278 1279 // 1280 // Listener setup 1281 // 1282 ports := make([]int, numCores) 1283 if baseAddr != nil { 1284 for i := 0; i < numCores; i++ { 1285 ports[i] = baseAddr.Port + i 1286 } 1287 } else { 1288 baseAddr = &net.TCPAddr{ 1289 IP: net.ParseIP("127.0.0.1"), 1290 Port: 0, 1291 } 1292 } 1293 1294 listeners := [][]*TestListener{} 1295 servers := []*http.Server{} 1296 handlers := []http.Handler{} 1297 tlsConfigs := []*tls.Config{} 1298 certGetters := []*reloadutil.CertificateGetter{} 1299 for i := 0; i < numCores; i++ { 1300 baseAddr.Port = ports[i] 1301 ln, err := net.ListenTCP("tcp", baseAddr) 1302 if err != nil { 1303 t.Fatal(err) 1304 } 1305 certFile := filepath.Join(testCluster.TempDir, fmt.Sprintf("node%d_port_%d_cert.pem", i+1, ln.Addr().(*net.TCPAddr).Port)) 1306 keyFile := filepath.Join(testCluster.TempDir, fmt.Sprintf("node%d_port_%d_key.pem", i+1, ln.Addr().(*net.TCPAddr).Port)) 1307 err = ioutil.WriteFile(certFile, certInfoSlice[i].certPEM, 0755) 1308 if err != nil { 1309 t.Fatal(err) 1310 } 1311 err = ioutil.WriteFile(keyFile, certInfoSlice[i].keyPEM, 0755) 1312 if err != nil { 1313 t.Fatal(err) 1314 } 1315 tlsCert, err := tls.X509KeyPair(certInfoSlice[i].certPEM, certInfoSlice[i].keyPEM) 1316 if err != nil { 1317 t.Fatal(err) 1318 } 1319 certGetter := reloadutil.NewCertificateGetter(certFile, keyFile, "") 1320 certGetters = append(certGetters, certGetter) 1321 certGetter.Reload(nil) 1322 tlsConfig := &tls.Config{ 1323 Certificates: []tls.Certificate{tlsCert}, 1324 RootCAs: testCluster.RootCAs, 1325 ClientCAs: testCluster.RootCAs, 1326 ClientAuth: tls.RequestClientCert, 1327 NextProtos: []string{"h2", "http/1.1"}, 1328 GetCertificate: certGetter.GetCertificate, 1329 } 1330 if opts != nil && opts.RequireClientAuth { 1331 tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert 1332 testCluster.ClientAuthRequired = true 1333 } 1334 tlsConfig.BuildNameToCertificate() 1335 tlsConfigs = append(tlsConfigs, tlsConfig) 1336 lns := []*TestListener{&TestListener{ 1337 Listener: tls.NewListener(ln, tlsConfig), 1338 Address: ln.Addr().(*net.TCPAddr), 1339 }, 1340 } 1341 listeners = append(listeners, lns) 1342 var handler http.Handler = http.NewServeMux() 1343 handlers = append(handlers, handler) 1344 server := &http.Server{ 1345 Handler: handler, 1346 ErrorLog: testCluster.Logger.StandardLogger(nil), 1347 } 1348 servers = append(servers, server) 1349 } 1350 1351 // Create three cores with the same physical and different redirect/cluster 1352 // addrs. 1353 // N.B.: On OSX, instead of random ports, it assigns new ports to new 1354 // listeners sequentially. Aside from being a bad idea in a security sense, 1355 // it also broke tests that assumed it was OK to just use the port above 1356 // the redirect addr. This has now been changed to 105 ports above, but if 1357 // we ever do more than three nodes in a cluster it may need to be bumped. 1358 // Note: it's 105 so that we don't conflict with a running Consul by 1359 // default. 1360 coreConfig := &CoreConfig{ 1361 LogicalBackends: make(map[string]logical.Factory), 1362 CredentialBackends: make(map[string]logical.Factory), 1363 AuditBackends: make(map[string]audit.Factory), 1364 RedirectAddr: fmt.Sprintf("https://127.0.0.1:%d", listeners[0][0].Address.Port), 1365 ClusterAddr: "https://127.0.0.1:0", 1366 DisableMlock: true, 1367 EnableUI: true, 1368 EnableRaw: true, 1369 BuiltinRegistry: NewMockBuiltinRegistry(), 1370 } 1371 1372 if base != nil { 1373 coreConfig.RawConfig = base.RawConfig 1374 coreConfig.DisableCache = base.DisableCache 1375 coreConfig.EnableUI = base.EnableUI 1376 coreConfig.DefaultLeaseTTL = base.DefaultLeaseTTL 1377 coreConfig.MaxLeaseTTL = base.MaxLeaseTTL 1378 coreConfig.CacheSize = base.CacheSize 1379 coreConfig.PluginDirectory = base.PluginDirectory 1380 coreConfig.Seal = base.Seal 1381 coreConfig.DevToken = base.DevToken 1382 coreConfig.EnableRaw = base.EnableRaw 1383 coreConfig.DisableSealWrap = base.DisableSealWrap 1384 coreConfig.DevLicenseDuration = base.DevLicenseDuration 1385 coreConfig.DisableCache = base.DisableCache 1386 coreConfig.LicensingConfig = base.LicensingConfig 1387 coreConfig.DisablePerformanceStandby = base.DisablePerformanceStandby 1388 coreConfig.MetricsHelper = base.MetricsHelper 1389 coreConfig.SecureRandomReader = base.SecureRandomReader 1390 if base.BuiltinRegistry != nil { 1391 coreConfig.BuiltinRegistry = base.BuiltinRegistry 1392 } 1393 1394 if !coreConfig.DisableMlock { 1395 base.DisableMlock = false 1396 } 1397 1398 if base.Physical != nil { 1399 coreConfig.Physical = base.Physical 1400 } 1401 1402 if base.HAPhysical != nil { 1403 coreConfig.HAPhysical = base.HAPhysical 1404 } 1405 1406 // Used to set something non-working to test fallback 1407 switch base.ClusterAddr { 1408 case "empty": 1409 coreConfig.ClusterAddr = "" 1410 case "": 1411 default: 1412 coreConfig.ClusterAddr = base.ClusterAddr 1413 } 1414 1415 if base.LogicalBackends != nil { 1416 for k, v := range base.LogicalBackends { 1417 coreConfig.LogicalBackends[k] = v 1418 } 1419 } 1420 if base.CredentialBackends != nil { 1421 for k, v := range base.CredentialBackends { 1422 coreConfig.CredentialBackends[k] = v 1423 } 1424 } 1425 if base.AuditBackends != nil { 1426 for k, v := range base.AuditBackends { 1427 coreConfig.AuditBackends[k] = v 1428 } 1429 } 1430 if base.Logger != nil { 1431 coreConfig.Logger = base.Logger 1432 } 1433 1434 coreConfig.ClusterCipherSuites = base.ClusterCipherSuites 1435 1436 coreConfig.DisableCache = base.DisableCache 1437 1438 coreConfig.DevToken = base.DevToken 1439 coreConfig.CounterSyncInterval = base.CounterSyncInterval 1440 coreConfig.RecoveryMode = base.RecoveryMode 1441 } 1442 1443 if coreConfig.RawConfig == nil { 1444 coreConfig.RawConfig = new(server.Config) 1445 } 1446 1447 addAuditBackend := len(coreConfig.AuditBackends) == 0 1448 if addAuditBackend { 1449 AddNoopAudit(coreConfig, nil) 1450 } 1451 1452 if coreConfig.Physical == nil && (opts == nil || opts.PhysicalFactory == nil) { 1453 coreConfig.Physical, err = physInmem.NewInmem(nil, testCluster.Logger) 1454 if err != nil { 1455 t.Fatal(err) 1456 } 1457 } 1458 if coreConfig.HAPhysical == nil && (opts == nil || opts.PhysicalFactory == nil) { 1459 haPhys, err := physInmem.NewInmemHA(nil, testCluster.Logger) 1460 if err != nil { 1461 t.Fatal(err) 1462 } 1463 coreConfig.HAPhysical = haPhys.(physical.HABackend) 1464 } 1465 1466 pubKey, priKey, err := testGenerateCoreKeys() 1467 if err != nil { 1468 t.Fatalf("err: %v", err) 1469 } 1470 1471 cleanupFuncs := []func(){} 1472 cores := []*Core{} 1473 coreConfigs := []*CoreConfig{} 1474 for i := 0; i < numCores; i++ { 1475 localConfig := *coreConfig 1476 localConfig.RedirectAddr = fmt.Sprintf("https://127.0.0.1:%d", listeners[i][0].Address.Port) 1477 1478 // if opts.SealFunc is provided, use that to generate a seal for the config instead 1479 if opts != nil && opts.SealFunc != nil { 1480 localConfig.Seal = opts.SealFunc() 1481 } 1482 1483 if coreConfig.Logger == nil || (opts != nil && opts.Logger != nil) { 1484 localConfig.Logger = testCluster.Logger.Named(fmt.Sprintf("core%d", i)) 1485 } 1486 if opts != nil && opts.PhysicalFactory != nil { 1487 physBundle := opts.PhysicalFactory(t, i, localConfig.Logger) 1488 switch { 1489 case physBundle == nil && coreConfig.Physical != nil: 1490 case physBundle == nil && coreConfig.Physical == nil: 1491 t.Fatal("PhysicalFactory produced no physical and none in CoreConfig") 1492 case physBundle != nil: 1493 testCluster.Logger.Info("created physical backend", "instance", i) 1494 coreConfig.Physical = physBundle.Backend 1495 localConfig.Physical = physBundle.Backend 1496 base.Physical = physBundle.Backend 1497 haBackend := physBundle.HABackend 1498 if haBackend == nil { 1499 if ha, ok := physBundle.Backend.(physical.HABackend); ok { 1500 haBackend = ha 1501 } 1502 } 1503 coreConfig.HAPhysical = haBackend 1504 localConfig.HAPhysical = haBackend 1505 if physBundle.Cleanup != nil { 1506 cleanupFuncs = append(cleanupFuncs, physBundle.Cleanup) 1507 } 1508 } 1509 } 1510 1511 if opts != nil && opts.ClusterLayers != nil { 1512 localConfig.ClusterNetworkLayer = opts.ClusterLayers.Layers()[i] 1513 } 1514 1515 switch { 1516 case localConfig.LicensingConfig != nil: 1517 if pubKey != nil { 1518 localConfig.LicensingConfig.AdditionalPublicKeys = append(localConfig.LicensingConfig.AdditionalPublicKeys, pubKey.(ed25519.PublicKey)) 1519 } 1520 default: 1521 localConfig.LicensingConfig = testGetLicensingConfig(pubKey) 1522 } 1523 1524 if localConfig.MetricsHelper == nil { 1525 inm := metrics.NewInmemSink(10*time.Second, time.Minute) 1526 metrics.DefaultInmemSignal(inm) 1527 localConfig.MetricsHelper = metricsutil.NewMetricsHelper(inm, false) 1528 } 1529 1530 c, err := NewCore(&localConfig) 1531 if err != nil { 1532 t.Fatalf("err: %v", err) 1533 } 1534 c.coreNumber = firstCoreNumber + i 1535 c.PR1103disabled = disablePR1103 1536 cores = append(cores, c) 1537 coreConfigs = append(coreConfigs, &localConfig) 1538 if opts != nil && opts.HandlerFunc != nil { 1539 props := opts.DefaultHandlerProperties 1540 props.Core = c 1541 if props.MaxRequestDuration == 0 { 1542 props.MaxRequestDuration = DefaultMaxRequestDuration 1543 } 1544 handlers[i] = opts.HandlerFunc(&props) 1545 servers[i].Handler = handlers[i] 1546 } 1547 1548 // Set this in case the Seal was manually set before the core was 1549 // created 1550 if localConfig.Seal != nil { 1551 localConfig.Seal.SetCore(c) 1552 } 1553 } 1554 1555 // 1556 // Clustering setup 1557 // 1558 clusterAddrGen := func(lns []*TestListener) []*net.TCPAddr { 1559 ret := make([]*net.TCPAddr, len(lns)) 1560 for i, ln := range lns { 1561 ret[i] = &net.TCPAddr{ 1562 IP: ln.Address.IP, 1563 Port: 0, 1564 } 1565 } 1566 return ret 1567 } 1568 1569 for i := 0; i < numCores; i++ { 1570 if coreConfigs[i].ClusterAddr != "" { 1571 cores[i].SetClusterListenerAddrs(clusterAddrGen(listeners[i])) 1572 cores[i].SetClusterHandler(handlers[i]) 1573 } 1574 } 1575 1576 if opts == nil || !opts.SkipInit { 1577 bKeys, rKeys, root := TestCoreInitClusterWrapperSetup(t, cores[0], handlers[0]) 1578 barrierKeys, _ := copystructure.Copy(bKeys) 1579 testCluster.BarrierKeys = barrierKeys.([][]byte) 1580 recoveryKeys, _ := copystructure.Copy(rKeys) 1581 testCluster.RecoveryKeys = recoveryKeys.([][]byte) 1582 testCluster.RootToken = root 1583 1584 // Write root token and barrier keys 1585 err = ioutil.WriteFile(filepath.Join(testCluster.TempDir, "root_token"), []byte(root), 0755) 1586 if err != nil { 1587 t.Fatal(err) 1588 } 1589 var buf bytes.Buffer 1590 for i, key := range testCluster.BarrierKeys { 1591 buf.Write([]byte(base64.StdEncoding.EncodeToString(key))) 1592 if i < len(testCluster.BarrierKeys)-1 { 1593 buf.WriteRune('\n') 1594 } 1595 } 1596 err = ioutil.WriteFile(filepath.Join(testCluster.TempDir, "barrier_keys"), buf.Bytes(), 0755) 1597 if err != nil { 1598 t.Fatal(err) 1599 } 1600 for i, key := range testCluster.RecoveryKeys { 1601 buf.Write([]byte(base64.StdEncoding.EncodeToString(key))) 1602 if i < len(testCluster.RecoveryKeys)-1 { 1603 buf.WriteRune('\n') 1604 } 1605 } 1606 err = ioutil.WriteFile(filepath.Join(testCluster.TempDir, "recovery_keys"), buf.Bytes(), 0755) 1607 if err != nil { 1608 t.Fatal(err) 1609 } 1610 1611 // Unseal first core 1612 for _, key := range bKeys { 1613 if _, err := cores[0].Unseal(TestKeyCopy(key)); err != nil { 1614 t.Fatalf("unseal err: %s", err) 1615 } 1616 } 1617 1618 ctx := context.Background() 1619 1620 // If stored keys is supported, the above will no no-op, so trigger auto-unseal 1621 // using stored keys to try to unseal 1622 if err := cores[0].UnsealWithStoredKeys(ctx); err != nil { 1623 t.Fatal(err) 1624 } 1625 1626 // Verify unsealed 1627 if cores[0].Sealed() { 1628 t.Fatal("should not be sealed") 1629 } 1630 1631 TestWaitActive(t, cores[0]) 1632 1633 // Existing tests rely on this; we can make a toggle to disable it 1634 // later if we want 1635 kvReq := &logical.Request{ 1636 Operation: logical.UpdateOperation, 1637 ClientToken: testCluster.RootToken, 1638 Path: "sys/mounts/secret", 1639 Data: map[string]interface{}{ 1640 "type": "kv", 1641 "path": "secret/", 1642 "description": "key/value secret storage", 1643 "options": map[string]string{ 1644 "version": "1", 1645 }, 1646 }, 1647 } 1648 resp, err := cores[0].HandleRequest(namespace.RootContext(ctx), kvReq) 1649 if err != nil { 1650 t.Fatal(err) 1651 } 1652 if resp.IsError() { 1653 t.Fatal(err) 1654 } 1655 1656 cfg, err := cores[0].seal.BarrierConfig(ctx) 1657 if err != nil { 1658 t.Fatal(err) 1659 } 1660 1661 // Unseal other cores unless otherwise specified 1662 if (opts == nil || !opts.KeepStandbysSealed) && numCores > 1 { 1663 for i := 1; i < numCores; i++ { 1664 cores[i].seal.SetCachedBarrierConfig(cfg) 1665 for _, key := range bKeys { 1666 if _, err := cores[i].Unseal(TestKeyCopy(key)); err != nil { 1667 t.Fatalf("unseal err: %s", err) 1668 } 1669 } 1670 1671 // If stored keys is supported, the above will no no-op, so trigger auto-unseal 1672 // using stored keys 1673 if err := cores[i].UnsealWithStoredKeys(ctx); err != nil { 1674 t.Fatal(err) 1675 } 1676 } 1677 1678 // Let them come fully up to standby 1679 time.Sleep(2 * time.Second) 1680 1681 // Ensure cluster connection info is populated. 1682 // Other cores should not come up as leaders. 1683 for i := 1; i < numCores; i++ { 1684 isLeader, _, _, err := cores[i].Leader() 1685 if err != nil { 1686 t.Fatal(err) 1687 } 1688 if isLeader { 1689 t.Fatalf("core[%d] should not be leader", i) 1690 } 1691 } 1692 } 1693 1694 // 1695 // Set test cluster core(s) and test cluster 1696 // 1697 cluster, err := cores[0].Cluster(context.Background()) 1698 if err != nil { 1699 t.Fatal(err) 1700 } 1701 testCluster.ID = cluster.ID 1702 1703 if addAuditBackend { 1704 // Enable auditing. 1705 auditReq := &logical.Request{ 1706 Operation: logical.UpdateOperation, 1707 ClientToken: testCluster.RootToken, 1708 Path: "sys/audit/noop", 1709 Data: map[string]interface{}{ 1710 "type": "noop", 1711 }, 1712 } 1713 resp, err = cores[0].HandleRequest(namespace.RootContext(ctx), auditReq) 1714 if err != nil { 1715 t.Fatal(err) 1716 } 1717 1718 if resp.IsError() { 1719 t.Fatal(err) 1720 } 1721 } 1722 } 1723 1724 getAPIClient := func(port int, tlsConfig *tls.Config) *api.Client { 1725 transport := cleanhttp.DefaultPooledTransport() 1726 transport.TLSClientConfig = tlsConfig.Clone() 1727 if err := http2.ConfigureTransport(transport); err != nil { 1728 t.Fatal(err) 1729 } 1730 client := &http.Client{ 1731 Transport: transport, 1732 CheckRedirect: func(*http.Request, []*http.Request) error { 1733 // This can of course be overridden per-test by using its own client 1734 return fmt.Errorf("redirects not allowed in these tests") 1735 }, 1736 } 1737 config := api.DefaultConfig() 1738 if config.Error != nil { 1739 t.Fatal(config.Error) 1740 } 1741 config.Address = fmt.Sprintf("https://127.0.0.1:%d", port) 1742 config.HttpClient = client 1743 config.MaxRetries = 0 1744 apiClient, err := api.NewClient(config) 1745 if err != nil { 1746 t.Fatal(err) 1747 } 1748 if opts == nil || !opts.SkipInit { 1749 apiClient.SetToken(testCluster.RootToken) 1750 } 1751 return apiClient 1752 } 1753 1754 var ret []*TestClusterCore 1755 for i := 0; i < numCores; i++ { 1756 tcc := &TestClusterCore{ 1757 Core: cores[i], 1758 CoreConfig: coreConfigs[i], 1759 ServerKey: certInfoSlice[i].key, 1760 ServerKeyPEM: certInfoSlice[i].keyPEM, 1761 ServerCert: certInfoSlice[i].cert, 1762 ServerCertBytes: certInfoSlice[i].certBytes, 1763 ServerCertPEM: certInfoSlice[i].certPEM, 1764 Listeners: listeners[i], 1765 Handler: handlers[i], 1766 Server: servers[i], 1767 TLSConfig: tlsConfigs[i], 1768 Client: getAPIClient(listeners[i][0].Address.Port, tlsConfigs[i]), 1769 Barrier: cores[i].barrier, 1770 NodeID: fmt.Sprintf("core-%d", i), 1771 UnderlyingRawStorage: coreConfigs[i].Physical, 1772 } 1773 tcc.ReloadFuncs = &cores[i].reloadFuncs 1774 tcc.ReloadFuncsLock = &cores[i].reloadFuncsLock 1775 tcc.ReloadFuncsLock.Lock() 1776 (*tcc.ReloadFuncs)["listener|tcp"] = []reloadutil.ReloadFunc{certGetters[i].Reload} 1777 tcc.ReloadFuncsLock.Unlock() 1778 1779 testAdjustTestCore(base, tcc) 1780 1781 ret = append(ret, tcc) 1782 } 1783 1784 testCluster.Cores = ret 1785 1786 testExtraClusterCoresTestSetup(t, priKey, testCluster.Cores) 1787 1788 testCluster.CleanupFunc = func() { 1789 for _, c := range cleanupFuncs { 1790 c() 1791 } 1792 } 1793 if opts != nil { 1794 if opts.SetupFunc != nil { 1795 testCluster.SetupFunc = func() { 1796 opts.SetupFunc(t, &testCluster) 1797 } 1798 } 1799 } 1800 1801 return &testCluster 1802} 1803 1804func NewMockBuiltinRegistry() *mockBuiltinRegistry { 1805 return &mockBuiltinRegistry{ 1806 forTesting: map[string]consts.PluginType{ 1807 "mysql-database-plugin": consts.PluginTypeDatabase, 1808 "postgresql-database-plugin": consts.PluginTypeDatabase, 1809 }, 1810 } 1811} 1812 1813type mockBuiltinRegistry struct { 1814 forTesting map[string]consts.PluginType 1815} 1816 1817func (m *mockBuiltinRegistry) Get(name string, pluginType consts.PluginType) (func() (interface{}, error), bool) { 1818 testPluginType, ok := m.forTesting[name] 1819 if !ok { 1820 return nil, false 1821 } 1822 if pluginType != testPluginType { 1823 return nil, false 1824 } 1825 if name == "postgresql-database-plugin" { 1826 return dbPostgres.New, true 1827 } 1828 return dbMysql.New(dbMysql.MetadataLen, dbMysql.MetadataLen, dbMysql.UsernameLen), true 1829} 1830 1831// Keys only supports getting a realistic list of the keys for database plugins. 1832func (m *mockBuiltinRegistry) Keys(pluginType consts.PluginType) []string { 1833 if pluginType != consts.PluginTypeDatabase { 1834 return []string{} 1835 } 1836 /* 1837 This is a hard-coded reproduction of the db plugin keys in helper/builtinplugins/registry.go. 1838 The registry isn't directly used because it causes import cycles. 1839 */ 1840 return []string{ 1841 "mysql-database-plugin", 1842 "mysql-aurora-database-plugin", 1843 "mysql-rds-database-plugin", 1844 "mysql-legacy-database-plugin", 1845 "postgresql-database-plugin", 1846 "elasticsearch-database-plugin", 1847 "mssql-database-plugin", 1848 "cassandra-database-plugin", 1849 "mongodb-database-plugin", 1850 "mongodbatlas-database-plugin", 1851 "hana-database-plugin", 1852 "influxdb-database-plugin", 1853 "redshift-database-plugin", 1854 } 1855} 1856 1857func (m *mockBuiltinRegistry) Contains(name string, pluginType consts.PluginType) bool { 1858 return false 1859} 1860 1861type NoopAudit struct { 1862 Config *audit.BackendConfig 1863 ReqErr error 1864 ReqAuth []*logical.Auth 1865 Req []*logical.Request 1866 ReqHeaders []map[string][]string 1867 ReqNonHMACKeys []string 1868 ReqErrs []error 1869 1870 RespErr error 1871 RespAuth []*logical.Auth 1872 RespReq []*logical.Request 1873 Resp []*logical.Response 1874 RespNonHMACKeys []string 1875 RespReqNonHMACKeys []string 1876 RespErrs []error 1877 1878 salt *salt.Salt 1879 saltMutex sync.RWMutex 1880} 1881 1882func (n *NoopAudit) LogRequest(ctx context.Context, in *logical.LogInput) error { 1883 n.ReqAuth = append(n.ReqAuth, in.Auth) 1884 n.Req = append(n.Req, in.Request) 1885 n.ReqHeaders = append(n.ReqHeaders, in.Request.Headers) 1886 n.ReqNonHMACKeys = in.NonHMACReqDataKeys 1887 n.ReqErrs = append(n.ReqErrs, in.OuterErr) 1888 return n.ReqErr 1889} 1890 1891func (n *NoopAudit) LogResponse(ctx context.Context, in *logical.LogInput) error { 1892 n.RespAuth = append(n.RespAuth, in.Auth) 1893 n.RespReq = append(n.RespReq, in.Request) 1894 n.Resp = append(n.Resp, in.Response) 1895 n.RespErrs = append(n.RespErrs, in.OuterErr) 1896 1897 if in.Response != nil { 1898 n.RespNonHMACKeys = in.NonHMACRespDataKeys 1899 n.RespReqNonHMACKeys = in.NonHMACReqDataKeys 1900 } 1901 1902 return n.RespErr 1903} 1904 1905func (n *NoopAudit) Salt(ctx context.Context) (*salt.Salt, error) { 1906 n.saltMutex.RLock() 1907 if n.salt != nil { 1908 defer n.saltMutex.RUnlock() 1909 return n.salt, nil 1910 } 1911 n.saltMutex.RUnlock() 1912 n.saltMutex.Lock() 1913 defer n.saltMutex.Unlock() 1914 if n.salt != nil { 1915 return n.salt, nil 1916 } 1917 salt, err := salt.NewSalt(ctx, n.Config.SaltView, n.Config.SaltConfig) 1918 if err != nil { 1919 return nil, err 1920 } 1921 n.salt = salt 1922 return salt, nil 1923} 1924 1925func (n *NoopAudit) GetHash(ctx context.Context, data string) (string, error) { 1926 salt, err := n.Salt(ctx) 1927 if err != nil { 1928 return "", err 1929 } 1930 return salt.GetIdentifiedHMAC(data), nil 1931} 1932 1933func (n *NoopAudit) Reload(ctx context.Context) error { 1934 return nil 1935} 1936 1937func (n *NoopAudit) Invalidate(ctx context.Context) { 1938 n.saltMutex.Lock() 1939 defer n.saltMutex.Unlock() 1940 n.salt = nil 1941} 1942