1package vault 2 3import ( 4 "bytes" 5 "context" 6 "crypto/ecdsa" 7 "crypto/elliptic" 8 "crypto/rand" 9 "crypto/sha256" 10 "crypto/tls" 11 "crypto/x509" 12 "crypto/x509/pkix" 13 "encoding/base64" 14 "encoding/pem" 15 "errors" 16 "fmt" 17 "io" 18 "io/ioutil" 19 "math/big" 20 mathrand "math/rand" 21 "net" 22 "net/http" 23 "os" 24 "os/exec" 25 "path/filepath" 26 "sync" 27 "sync/atomic" 28 "time" 29 30 hclog "github.com/hashicorp/go-hclog" 31 log "github.com/hashicorp/go-hclog" 32 "github.com/mitchellh/copystructure" 33 34 "golang.org/x/crypto/ed25519" 35 "golang.org/x/crypto/ssh" 36 "golang.org/x/net/http2" 37 38 cleanhttp "github.com/hashicorp/go-cleanhttp" 39 "github.com/hashicorp/vault/api" 40 "github.com/hashicorp/vault/audit" 41 "github.com/hashicorp/vault/helper/namespace" 42 "github.com/hashicorp/vault/helper/reload" 43 dbMysql "github.com/hashicorp/vault/plugins/database/mysql" 44 dbPostgres "github.com/hashicorp/vault/plugins/database/postgresql" 45 "github.com/hashicorp/vault/sdk/framework" 46 "github.com/hashicorp/vault/sdk/helper/consts" 47 "github.com/hashicorp/vault/sdk/helper/logging" 48 "github.com/hashicorp/vault/sdk/helper/salt" 49 "github.com/hashicorp/vault/sdk/logical" 50 "github.com/hashicorp/vault/sdk/physical" 51 testing "github.com/mitchellh/go-testing-interface" 52 53 physInmem "github.com/hashicorp/vault/sdk/physical/inmem" 54) 55 56// This file contains a number of methods that are useful for unit 57// tests within other packages. 58 59const ( 60 testSharedPublicKey = ` 61ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQC9i+hFxZHGo6KblVme4zrAcJstR6I0PTJozW286X4WyvPnkMYDQ5mnhEYC7UWCvjoTWbPEXPX7NjhRtwQTGD67bV+lrxgfyzK1JZbUXK4PwgKJvQD+XyyWYMzDgGSQY61KUSqCxymSm/9NZkPU3ElaQ9xQuTzPpztM4ROfb8f2Yv6/ZESZsTo0MTAkp8Pcy+WkioI/uJ1H7zqs0EA4OMY4aDJRu0UtP4rTVeYNEAuRXdX+eH4aW3KMvhzpFTjMbaJHJXlEeUm2SaX5TNQyTOvghCeQILfYIL/Ca2ij8iwCmulwdV6eQGfd4VDu40PvSnmfoaE38o6HaPnX0kUcnKiT 62` 63 testSharedPrivateKey = ` 64-----BEGIN RSA PRIVATE KEY----- 65MIIEogIBAAKCAQEAvYvoRcWRxqOim5VZnuM6wHCbLUeiND0yaM1tvOl+Fsrz55DG 66A0OZp4RGAu1Fgr46E1mzxFz1+zY4UbcEExg+u21fpa8YH8sytSWW1FyuD8ICib0A 67/l8slmDMw4BkkGOtSlEqgscpkpv/TWZD1NxJWkPcULk8z6c7TOETn2/H9mL+v2RE 68mbE6NDEwJKfD3MvlpIqCP7idR+86rNBAODjGOGgyUbtFLT+K01XmDRALkV3V/nh+ 69GltyjL4c6RU4zG2iRyV5RHlJtkml+UzUMkzr4IQnkCC32CC/wmtoo/IsAprpcHVe 70nkBn3eFQ7uND70p5n6GhN/KOh2j519JFHJyokwIDAQABAoIBAHX7VOvBC3kCN9/x 71+aPdup84OE7Z7MvpX6w+WlUhXVugnmsAAVDczhKoUc/WktLLx2huCGhsmKvyVuH+ 72MioUiE+vx75gm3qGx5xbtmOfALVMRLopjCnJYf6EaFA0ZeQ+NwowNW7Lu0PHmAU8 73Z3JiX8IwxTz14DU82buDyewO7v+cEr97AnERe3PUcSTDoUXNaoNxjNpEJkKREY6h 744hAY676RT/GsRcQ8tqe/rnCqPHNd7JGqL+207FK4tJw7daoBjQyijWuB7K5chSal 75oPInylM6b13ASXuOAOT/2uSUBWmFVCZPDCmnZxy2SdnJGbsJAMl7Ma3MUlaGvVI+ 76Tfh1aQkCgYEA4JlNOabTb3z42wz6mz+Nz3JRwbawD+PJXOk5JsSnV7DtPtfgkK9y 776FTQdhnozGWShAvJvc+C4QAihs9AlHXoaBY5bEU7R/8UK/pSqwzam+MmxmhVDV7G 78IMQPV0FteoXTaJSikhZ88mETTegI2mik+zleBpVxvfdhE5TR+lq8Br0CgYEA2AwJ 79CUD5CYUSj09PluR0HHqamWOrJkKPFPwa+5eiTTCzfBBxImYZh7nXnWuoviXC0sg2 80AuvCW+uZ48ygv/D8gcz3j1JfbErKZJuV+TotK9rRtNIF5Ub7qysP7UjyI7zCssVM 81kuDd9LfRXaB/qGAHNkcDA8NxmHW3gpln4CFdSY8CgYANs4xwfercHEWaJ1qKagAe 82rZyrMpffAEhicJ/Z65lB0jtG4CiE6w8ZeUMWUVJQVcnwYD+4YpZbX4S7sJ0B8Ydy 83AhkSr86D/92dKTIt2STk6aCN7gNyQ1vW198PtaAWH1/cO2UHgHOy3ZUt5X/Uwxl9 84cex4flln+1Viumts2GgsCQKBgCJH7psgSyPekK5auFdKEr5+Gc/jB8I/Z3K9+g4X 855nH3G1PBTCJYLw7hRzw8W/8oALzvddqKzEFHphiGXK94Lqjt/A4q1OdbCrhiE68D 86My21P/dAKB1UYRSs9Y8CNyHCjuZM9jSMJ8vv6vG/SOJPsnVDWVAckAbQDvlTHC9t 87O98zAoGAcbW6uFDkrv0XMCpB9Su3KaNXOR0wzag+WIFQRXCcoTvxVi9iYfUReQPi 88oOyBJU/HMVvBfv4g+OVFLVgSwwm6owwsouZ0+D/LasbuHqYyqYqdyPJQYzWA2Y+F 89+B6f4RoPdSXj24JHPg/ioRxjaj094UXJxua2yfkcecGNEuBQHSs= 90-----END RSA PRIVATE KEY----- 91` 92) 93 94// TestCore returns a pure in-memory, uninitialized core for testing. 95func TestCore(t testing.T) *Core { 96 return TestCoreWithSeal(t, nil, false) 97} 98 99// TestCoreRaw returns a pure in-memory, uninitialized core for testing. The raw 100// storage endpoints are enabled with this core. 101func TestCoreRaw(t testing.T) *Core { 102 return TestCoreWithSeal(t, nil, true) 103} 104 105// TestCoreNewSeal returns a pure in-memory, uninitialized core with 106// the new seal configuration. 107func TestCoreNewSeal(t testing.T) *Core { 108 seal := NewTestSeal(t, nil) 109 return TestCoreWithSeal(t, seal, false) 110} 111 112// TestCoreWithConfig returns a pure in-memory, uninitialized core with the 113// specified core configurations overridden for testing. 114func TestCoreWithConfig(t testing.T, conf *CoreConfig) *Core { 115 return TestCoreWithSealAndUI(t, conf) 116} 117 118// TestCoreWithSeal returns a pure in-memory, uninitialized core with the 119// specified seal for testing. 120func TestCoreWithSeal(t testing.T, testSeal Seal, enableRaw bool) *Core { 121 conf := &CoreConfig{ 122 Seal: testSeal, 123 EnableUI: false, 124 EnableRaw: enableRaw, 125 BuiltinRegistry: NewMockBuiltinRegistry(), 126 } 127 return TestCoreWithSealAndUI(t, conf) 128} 129 130func TestCoreUI(t testing.T, enableUI bool) *Core { 131 conf := &CoreConfig{ 132 EnableUI: enableUI, 133 EnableRaw: true, 134 BuiltinRegistry: NewMockBuiltinRegistry(), 135 } 136 return TestCoreWithSealAndUI(t, conf) 137} 138 139func TestCoreWithSealAndUI(t testing.T, opts *CoreConfig) *Core { 140 logger := logging.NewVaultLogger(log.Trace) 141 physicalBackend, err := physInmem.NewInmem(nil, logger) 142 if err != nil { 143 t.Fatal(err) 144 } 145 146 // Start off with base test core config 147 conf := testCoreConfig(t, physicalBackend, logger) 148 149 // Override config values with ones that gets passed in 150 conf.EnableUI = opts.EnableUI 151 conf.EnableRaw = opts.EnableRaw 152 conf.Seal = opts.Seal 153 conf.LicensingConfig = opts.LicensingConfig 154 conf.DisableKeyEncodingChecks = opts.DisableKeyEncodingChecks 155 156 if opts.Logger != nil { 157 conf.Logger = opts.Logger 158 } 159 160 for k, v := range opts.LogicalBackends { 161 conf.LogicalBackends[k] = v 162 } 163 for k, v := range opts.CredentialBackends { 164 conf.CredentialBackends[k] = v 165 } 166 167 for k, v := range opts.AuditBackends { 168 conf.AuditBackends[k] = v 169 } 170 171 c, err := NewCore(conf) 172 if err != nil { 173 t.Fatalf("err: %s", err) 174 } 175 176 return c 177} 178 179func testCoreConfig(t testing.T, physicalBackend physical.Backend, logger log.Logger) *CoreConfig { 180 t.Helper() 181 noopAudits := map[string]audit.Factory{ 182 "noop": func(_ context.Context, config *audit.BackendConfig) (audit.Backend, error) { 183 view := &logical.InmemStorage{} 184 view.Put(context.Background(), &logical.StorageEntry{ 185 Key: "salt", 186 Value: []byte("foo"), 187 }) 188 config.SaltConfig = &salt.Config{ 189 HMAC: sha256.New, 190 HMACType: "hmac-sha256", 191 } 192 config.SaltView = view 193 194 n := &noopAudit{ 195 Config: config, 196 } 197 n.formatter.AuditFormatWriter = &audit.JSONFormatWriter{ 198 SaltFunc: n.Salt, 199 } 200 return n, nil 201 }, 202 } 203 204 noopBackends := make(map[string]logical.Factory) 205 noopBackends["noop"] = func(ctx context.Context, config *logical.BackendConfig) (logical.Backend, error) { 206 b := new(framework.Backend) 207 b.Setup(ctx, config) 208 b.BackendType = logical.TypeCredential 209 return b, nil 210 } 211 noopBackends["http"] = func(ctx context.Context, config *logical.BackendConfig) (logical.Backend, error) { 212 return new(rawHTTP), nil 213 } 214 215 credentialBackends := make(map[string]logical.Factory) 216 for backendName, backendFactory := range noopBackends { 217 credentialBackends[backendName] = backendFactory 218 } 219 for backendName, backendFactory := range testCredentialBackends { 220 credentialBackends[backendName] = backendFactory 221 } 222 223 logicalBackends := make(map[string]logical.Factory) 224 for backendName, backendFactory := range noopBackends { 225 logicalBackends[backendName] = backendFactory 226 } 227 228 logicalBackends["kv"] = LeasedPassthroughBackendFactory 229 for backendName, backendFactory := range testLogicalBackends { 230 logicalBackends[backendName] = backendFactory 231 } 232 233 conf := &CoreConfig{ 234 Physical: physicalBackend, 235 AuditBackends: noopAudits, 236 LogicalBackends: logicalBackends, 237 CredentialBackends: credentialBackends, 238 DisableMlock: true, 239 Logger: logger, 240 BuiltinRegistry: NewMockBuiltinRegistry(), 241 } 242 243 return conf 244} 245 246// TestCoreInit initializes the core with a single key, and returns 247// the key that must be used to unseal the core and a root token. 248func TestCoreInit(t testing.T, core *Core) ([][]byte, string) { 249 t.Helper() 250 secretShares, _, root := TestCoreInitClusterWrapperSetup(t, core, nil) 251 return secretShares, root 252} 253 254func TestCoreInitClusterWrapperSetup(t testing.T, core *Core, handler http.Handler) ([][]byte, [][]byte, string) { 255 t.Helper() 256 core.SetClusterHandler(handler) 257 258 barrierConfig := &SealConfig{ 259 SecretShares: 3, 260 SecretThreshold: 3, 261 } 262 263 // If we support storing barrier keys, then set that to equal the min threshold to unseal 264 if core.seal.StoredKeysSupported() { 265 barrierConfig.StoredShares = barrierConfig.SecretThreshold 266 } 267 268 recoveryConfig := &SealConfig{ 269 SecretShares: 3, 270 SecretThreshold: 3, 271 } 272 273 result, err := core.Initialize(context.Background(), &InitParams{ 274 BarrierConfig: barrierConfig, 275 RecoveryConfig: recoveryConfig, 276 }) 277 if err != nil { 278 t.Fatalf("err: %s", err) 279 } 280 return result.SecretShares, result.RecoveryShares, result.RootToken 281} 282 283func TestCoreUnseal(core *Core, key []byte) (bool, error) { 284 return core.Unseal(key) 285} 286 287func TestCoreUnsealWithRecoveryKeys(core *Core, key []byte) (bool, error) { 288 return core.UnsealWithRecoveryKeys(key) 289} 290 291// TestCoreUnsealed returns a pure in-memory core that is already 292// initialized and unsealed. 293func TestCoreUnsealed(t testing.T) (*Core, [][]byte, string) { 294 t.Helper() 295 core := TestCore(t) 296 return testCoreUnsealed(t, core) 297} 298 299// TestCoreUnsealedRaw returns a pure in-memory core that is already 300// initialized, unsealed, and with raw endpoints enabled. 301func TestCoreUnsealedRaw(t testing.T) (*Core, [][]byte, string) { 302 t.Helper() 303 core := TestCoreRaw(t) 304 return testCoreUnsealed(t, core) 305} 306 307// TestCoreUnsealedWithConfig returns a pure in-memory core that is already 308// initialized, unsealed, with the any provided core config values overridden. 309func TestCoreUnsealedWithConfig(t testing.T, conf *CoreConfig) (*Core, [][]byte, string) { 310 t.Helper() 311 core := TestCoreWithConfig(t, conf) 312 return testCoreUnsealed(t, core) 313} 314 315func testCoreUnsealed(t testing.T, core *Core) (*Core, [][]byte, string) { 316 t.Helper() 317 keys, token := TestCoreInit(t, core) 318 for _, key := range keys { 319 if _, err := TestCoreUnseal(core, TestKeyCopy(key)); err != nil { 320 t.Fatalf("unseal err: %s", err) 321 } 322 } 323 324 if core.Sealed() { 325 t.Fatal("should not be sealed") 326 } 327 328 testCoreAddSecretMount(t, core, token) 329 330 return core, keys, token 331} 332 333func testCoreAddSecretMount(t testing.T, core *Core, token string) { 334 kvReq := &logical.Request{ 335 Operation: logical.UpdateOperation, 336 ClientToken: token, 337 Path: "sys/mounts/secret", 338 Data: map[string]interface{}{ 339 "type": "kv", 340 "path": "secret/", 341 "description": "key/value secret storage", 342 "options": map[string]string{ 343 "version": "1", 344 }, 345 }, 346 } 347 resp, err := core.HandleRequest(namespace.RootContext(nil), kvReq) 348 if err != nil { 349 t.Fatal(err) 350 } 351 if resp.IsError() { 352 t.Fatal(err) 353 } 354 355} 356 357func TestCoreUnsealedBackend(t testing.T, backend physical.Backend) (*Core, [][]byte, string) { 358 t.Helper() 359 logger := logging.NewVaultLogger(log.Trace) 360 conf := testCoreConfig(t, backend, logger) 361 conf.Seal = NewTestSeal(t, nil) 362 363 core, err := NewCore(conf) 364 if err != nil { 365 t.Fatalf("err: %s", err) 366 } 367 368 keys, token := TestCoreInit(t, core) 369 for _, key := range keys { 370 if _, err := TestCoreUnseal(core, TestKeyCopy(key)); err != nil { 371 t.Fatalf("unseal err: %s", err) 372 } 373 } 374 375 if err := core.UnsealWithStoredKeys(context.Background()); err != nil { 376 t.Fatal(err) 377 } 378 379 if core.Sealed() { 380 t.Fatal("should not be sealed") 381 } 382 383 return core, keys, token 384} 385 386// TestKeyCopy is a silly little function to just copy the key so that 387// it can be used with Unseal easily. 388func TestKeyCopy(key []byte) []byte { 389 result := make([]byte, len(key)) 390 copy(result, key) 391 return result 392} 393 394func TestDynamicSystemView(c *Core) *dynamicSystemView { 395 me := &MountEntry{ 396 Config: MountConfig{ 397 DefaultLeaseTTL: 24 * time.Hour, 398 MaxLeaseTTL: 2 * 24 * time.Hour, 399 }, 400 } 401 402 return &dynamicSystemView{c, me} 403} 404 405// TestAddTestPlugin registers the testFunc as part of the plugin command to the 406// plugin catalog. If provided, uses tmpDir as the plugin directory. 407func TestAddTestPlugin(t testing.T, c *Core, name string, pluginType consts.PluginType, testFunc string, env []string, tempDir string) { 408 file, err := os.Open(os.Args[0]) 409 if err != nil { 410 t.Fatal(err) 411 } 412 defer file.Close() 413 414 dirPath := filepath.Dir(os.Args[0]) 415 fileName := filepath.Base(os.Args[0]) 416 417 if tempDir != "" { 418 fi, err := file.Stat() 419 if err != nil { 420 t.Fatal(err) 421 } 422 423 // Copy over the file to the temp dir 424 dst := filepath.Join(tempDir, fileName) 425 out, err := os.OpenFile(dst, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, fi.Mode()) 426 if err != nil { 427 t.Fatal(err) 428 } 429 defer out.Close() 430 431 if _, err = io.Copy(out, file); err != nil { 432 t.Fatal(err) 433 } 434 err = out.Sync() 435 if err != nil { 436 t.Fatal(err) 437 } 438 439 dirPath = tempDir 440 } 441 442 // Determine plugin directory full path, evaluating potential symlink path 443 fullPath, err := filepath.EvalSymlinks(dirPath) 444 if err != nil { 445 t.Fatal(err) 446 } 447 448 reader, err := os.Open(filepath.Join(fullPath, fileName)) 449 if err != nil { 450 t.Fatal(err) 451 } 452 defer reader.Close() 453 454 // Find out the sha256 455 hash := sha256.New() 456 457 _, err = io.Copy(hash, reader) 458 if err != nil { 459 t.Fatal(err) 460 } 461 462 sum := hash.Sum(nil) 463 464 // Set core's plugin directory and plugin catalog directory 465 c.pluginDirectory = fullPath 466 c.pluginCatalog.directory = fullPath 467 468 args := []string{fmt.Sprintf("--test.run=%s", testFunc)} 469 err = c.pluginCatalog.Set(context.Background(), name, pluginType, fileName, args, env, sum) 470 if err != nil { 471 t.Fatal(err) 472 } 473} 474 475var testLogicalBackends = map[string]logical.Factory{} 476var testCredentialBackends = map[string]logical.Factory{} 477 478// StartSSHHostTestServer starts the test server which responds to SSH 479// authentication. Used to test the SSH secret backend. 480func StartSSHHostTestServer() (string, error) { 481 pubKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(testSharedPublicKey)) 482 if err != nil { 483 return "", fmt.Errorf("error parsing public key") 484 } 485 serverConfig := &ssh.ServerConfig{ 486 PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { 487 if bytes.Compare(pubKey.Marshal(), key.Marshal()) == 0 { 488 return &ssh.Permissions{}, nil 489 } else { 490 return nil, fmt.Errorf("key does not match") 491 } 492 }, 493 } 494 signer, err := ssh.ParsePrivateKey([]byte(testSharedPrivateKey)) 495 if err != nil { 496 panic("Error parsing private key") 497 } 498 serverConfig.AddHostKey(signer) 499 500 soc, err := net.Listen("tcp", "127.0.0.1:0") 501 if err != nil { 502 return "", fmt.Errorf("error listening to connection") 503 } 504 505 go func() { 506 for { 507 conn, err := soc.Accept() 508 if err != nil { 509 panic(fmt.Sprintf("Error accepting incoming connection: %s", err)) 510 } 511 defer conn.Close() 512 sshConn, chanReqs, _, err := ssh.NewServerConn(conn, serverConfig) 513 if err != nil { 514 panic(fmt.Sprintf("Handshaking error: %v", err)) 515 } 516 517 go func() { 518 for chanReq := range chanReqs { 519 go func(chanReq ssh.NewChannel) { 520 if chanReq.ChannelType() != "session" { 521 chanReq.Reject(ssh.UnknownChannelType, "unknown channel type") 522 return 523 } 524 525 ch, requests, err := chanReq.Accept() 526 if err != nil { 527 panic(fmt.Sprintf("Error accepting channel: %s", err)) 528 } 529 530 go func(ch ssh.Channel, in <-chan *ssh.Request) { 531 for req := range in { 532 executeServerCommand(ch, req) 533 } 534 }(ch, requests) 535 }(chanReq) 536 } 537 sshConn.Close() 538 }() 539 } 540 }() 541 return soc.Addr().String(), nil 542} 543 544// This executes the commands requested to be run on the server. 545// Used to test the SSH secret backend. 546func executeServerCommand(ch ssh.Channel, req *ssh.Request) { 547 command := string(req.Payload[4:]) 548 cmd := exec.Command("/bin/bash", []string{"-c", command}...) 549 req.Reply(true, nil) 550 551 cmd.Stdout = ch 552 cmd.Stderr = ch 553 cmd.Stdin = ch 554 555 err := cmd.Start() 556 if err != nil { 557 panic(fmt.Sprintf("Error starting the command: '%s'", err)) 558 } 559 560 go func() { 561 _, err := cmd.Process.Wait() 562 if err != nil { 563 panic(fmt.Sprintf("Error while waiting for command to finish:'%s'", err)) 564 } 565 ch.Close() 566 }() 567} 568 569// This adds a credential backend for the test core. This needs to be 570// invoked before the test core is created. 571func AddTestCredentialBackend(name string, factory logical.Factory) error { 572 if name == "" { 573 return fmt.Errorf("missing backend name") 574 } 575 if factory == nil { 576 return fmt.Errorf("missing backend factory function") 577 } 578 testCredentialBackends[name] = factory 579 return nil 580} 581 582// This adds a logical backend for the test core. This needs to be 583// invoked before the test core is created. 584func AddTestLogicalBackend(name string, factory logical.Factory) error { 585 if name == "" { 586 return fmt.Errorf("missing backend name") 587 } 588 if factory == nil { 589 return fmt.Errorf("missing backend factory function") 590 } 591 testLogicalBackends[name] = factory 592 return nil 593} 594 595type noopAudit struct { 596 Config *audit.BackendConfig 597 salt *salt.Salt 598 saltMutex sync.RWMutex 599 formatter audit.AuditFormatter 600 records [][]byte 601 l sync.RWMutex 602} 603 604func (n *noopAudit) GetHash(ctx context.Context, data string) (string, error) { 605 salt, err := n.Salt(ctx) 606 if err != nil { 607 return "", err 608 } 609 return salt.GetIdentifiedHMAC(data), nil 610} 611 612func (n *noopAudit) LogRequest(ctx context.Context, in *logical.LogInput) error { 613 n.l.Lock() 614 defer n.l.Unlock() 615 var w bytes.Buffer 616 err := n.formatter.FormatRequest(ctx, &w, audit.FormatterConfig{}, in) 617 if err != nil { 618 return err 619 } 620 n.records = append(n.records, w.Bytes()) 621 return nil 622} 623 624func (n *noopAudit) LogResponse(ctx context.Context, in *logical.LogInput) error { 625 n.l.Lock() 626 defer n.l.Unlock() 627 var w bytes.Buffer 628 err := n.formatter.FormatResponse(ctx, &w, audit.FormatterConfig{}, in) 629 if err != nil { 630 return err 631 } 632 n.records = append(n.records, w.Bytes()) 633 return nil 634} 635 636func (n *noopAudit) Reload(_ context.Context) error { 637 return nil 638} 639 640func (n *noopAudit) Invalidate(_ context.Context) { 641 n.saltMutex.Lock() 642 defer n.saltMutex.Unlock() 643 n.salt = nil 644} 645 646func (n *noopAudit) Salt(ctx context.Context) (*salt.Salt, error) { 647 n.saltMutex.RLock() 648 if n.salt != nil { 649 defer n.saltMutex.RUnlock() 650 return n.salt, nil 651 } 652 n.saltMutex.RUnlock() 653 n.saltMutex.Lock() 654 defer n.saltMutex.Unlock() 655 if n.salt != nil { 656 return n.salt, nil 657 } 658 salt, err := salt.NewSalt(ctx, n.Config.SaltView, n.Config.SaltConfig) 659 if err != nil { 660 return nil, err 661 } 662 n.salt = salt 663 return salt, nil 664} 665 666func AddNoopAudit(conf *CoreConfig) { 667 conf.AuditBackends = map[string]audit.Factory{ 668 "noop": func(_ context.Context, config *audit.BackendConfig) (audit.Backend, error) { 669 view := &logical.InmemStorage{} 670 view.Put(context.Background(), &logical.StorageEntry{ 671 Key: "salt", 672 Value: []byte("foo"), 673 }) 674 n := &noopAudit{ 675 Config: config, 676 } 677 n.formatter.AuditFormatWriter = &audit.JSONFormatWriter{ 678 SaltFunc: n.Salt, 679 } 680 return n, nil 681 }, 682 } 683} 684 685type rawHTTP struct{} 686 687func (n *rawHTTP) HandleRequest(ctx context.Context, req *logical.Request) (*logical.Response, error) { 688 return &logical.Response{ 689 Data: map[string]interface{}{ 690 logical.HTTPStatusCode: 200, 691 logical.HTTPContentType: "plain/text", 692 logical.HTTPRawBody: []byte("hello world"), 693 }, 694 }, nil 695} 696 697func (n *rawHTTP) HandleExistenceCheck(ctx context.Context, req *logical.Request) (bool, bool, error) { 698 return false, false, nil 699} 700 701func (n *rawHTTP) SpecialPaths() *logical.Paths { 702 return &logical.Paths{Unauthenticated: []string{"*"}} 703} 704 705func (n *rawHTTP) System() logical.SystemView { 706 return logical.StaticSystemView{ 707 DefaultLeaseTTLVal: time.Hour * 24, 708 MaxLeaseTTLVal: time.Hour * 24 * 32, 709 } 710} 711 712func (n *rawHTTP) Logger() log.Logger { 713 return logging.NewVaultLogger(log.Trace) 714} 715 716func (n *rawHTTP) Cleanup(ctx context.Context) { 717 // noop 718} 719 720func (n *rawHTTP) Initialize(ctx context.Context, req *logical.InitializationRequest) error { 721 return nil 722} 723 724func (n *rawHTTP) InvalidateKey(context.Context, string) { 725 // noop 726} 727 728func (n *rawHTTP) Setup(ctx context.Context, config *logical.BackendConfig) error { 729 // noop 730 return nil 731} 732 733func (n *rawHTTP) Type() logical.BackendType { 734 return logical.TypeLogical 735} 736 737func GenerateRandBytes(length int) ([]byte, error) { 738 if length < 0 { 739 return nil, fmt.Errorf("length must be >= 0") 740 } 741 742 buf := make([]byte, length) 743 if length == 0 { 744 return buf, nil 745 } 746 747 n, err := rand.Read(buf) 748 if err != nil { 749 return nil, err 750 } 751 if n != length { 752 return nil, fmt.Errorf("unable to read %d bytes; only read %d", length, n) 753 } 754 755 return buf, nil 756} 757 758func TestWaitActive(t testing.T, core *Core) { 759 t.Helper() 760 if err := TestWaitActiveWithError(core); err != nil { 761 t.Fatal(err) 762 } 763} 764 765func TestWaitActiveWithError(core *Core) error { 766 start := time.Now() 767 var standby bool 768 var err error 769 for time.Now().Sub(start) < 30*time.Second { 770 standby, err = core.Standby() 771 if err != nil { 772 return err 773 } 774 if !standby { 775 break 776 } 777 } 778 if standby { 779 return errors.New("should not be in standby mode") 780 } 781 return nil 782} 783 784type TestCluster struct { 785 BarrierKeys [][]byte 786 RecoveryKeys [][]byte 787 CACert *x509.Certificate 788 CACertBytes []byte 789 CACertPEM []byte 790 CACertPEMFile string 791 CAKey *ecdsa.PrivateKey 792 CAKeyPEM []byte 793 Cores []*TestClusterCore 794 ID string 795 RootToken string 796 RootCAs *x509.CertPool 797 TempDir string 798 ClientAuthRequired bool 799 Logger log.Logger 800} 801 802func (c *TestCluster) Start() { 803 for _, core := range c.Cores { 804 if core.Server != nil { 805 for _, ln := range core.Listeners { 806 go core.Server.Serve(ln) 807 } 808 } 809 } 810} 811 812// UnsealCores uses the cluster barrier keys to unseal the test cluster cores 813func (c *TestCluster) UnsealCores(t testing.T) { 814 if err := c.UnsealCoresWithError(); err != nil { 815 t.Fatal(err) 816 } 817} 818 819func (c *TestCluster) UnsealCoresWithError() error { 820 numCores := len(c.Cores) 821 822 // Unseal first core 823 for _, key := range c.BarrierKeys { 824 if _, err := c.Cores[0].Unseal(TestKeyCopy(key)); err != nil { 825 return fmt.Errorf("unseal err: %s", err) 826 } 827 } 828 829 // Verify unsealed 830 if c.Cores[0].Sealed() { 831 return fmt.Errorf("should not be sealed") 832 } 833 834 if err := TestWaitActiveWithError(c.Cores[0].Core); err != nil { 835 return err 836 } 837 838 // Unseal other cores 839 for i := 1; i < numCores; i++ { 840 for _, key := range c.BarrierKeys { 841 if _, err := c.Cores[i].Core.Unseal(TestKeyCopy(key)); err != nil { 842 return fmt.Errorf("unseal err: %s", err) 843 } 844 } 845 } 846 847 // Let them come fully up to standby 848 time.Sleep(2 * time.Second) 849 850 // Ensure cluster connection info is populated. 851 // Other cores should not come up as leaders. 852 for i := 1; i < numCores; i++ { 853 isLeader, _, _, err := c.Cores[i].Leader() 854 if err != nil { 855 return err 856 } 857 if isLeader { 858 return fmt.Errorf("core[%d] should not be leader", i) 859 } 860 } 861 862 return nil 863} 864 865func (c *TestCluster) UnsealCore(t testing.T, core *TestClusterCore) { 866 for _, key := range c.BarrierKeys { 867 if _, err := core.Core.Unseal(TestKeyCopy(key)); err != nil { 868 t.Fatalf("unseal err: %s", err) 869 } 870 } 871} 872 873func (c *TestCluster) EnsureCoresSealed(t testing.T) { 874 t.Helper() 875 if err := c.ensureCoresSealed(); err != nil { 876 t.Fatal(err) 877 } 878} 879 880func (c *TestClusterCore) Seal(t testing.T) { 881 t.Helper() 882 if err := c.Core.sealInternal(); err != nil { 883 t.Fatal(err) 884 } 885} 886 887func CleanupClusters(clusters []*TestCluster) { 888 wg := &sync.WaitGroup{} 889 for _, cluster := range clusters { 890 wg.Add(1) 891 lc := cluster 892 go func() { 893 defer wg.Done() 894 lc.Cleanup() 895 }() 896 } 897 wg.Wait() 898} 899 900func (c *TestCluster) Cleanup() { 901 c.Logger.Info("cleaning up vault cluster") 902 for _, core := range c.Cores { 903 core.CoreConfig.Logger.SetLevel(log.Error) 904 } 905 906 // Close listeners 907 wg := &sync.WaitGroup{} 908 for _, core := range c.Cores { 909 wg.Add(1) 910 lc := core 911 912 go func() { 913 defer wg.Done() 914 if lc.Listeners != nil { 915 for _, ln := range lc.Listeners { 916 ln.Close() 917 } 918 } 919 if lc.licensingStopCh != nil { 920 close(lc.licensingStopCh) 921 lc.licensingStopCh = nil 922 } 923 924 if err := lc.Shutdown(); err != nil { 925 lc.Logger().Error("error during shutdown; abandoning sealing", "error", err) 926 } else { 927 timeout := time.Now().Add(60 * time.Second) 928 for { 929 if time.Now().After(timeout) { 930 lc.Logger().Error("timeout waiting for core to seal") 931 } 932 if lc.Sealed() { 933 break 934 } 935 time.Sleep(250 * time.Millisecond) 936 } 937 } 938 }() 939 } 940 941 wg.Wait() 942 943 // Remove any temp dir that exists 944 if c.TempDir != "" { 945 os.RemoveAll(c.TempDir) 946 } 947 948 // Give time to actually shut down/clean up before the next test 949 time.Sleep(time.Second) 950} 951 952func (c *TestCluster) ensureCoresSealed() error { 953 for _, core := range c.Cores { 954 if err := core.Shutdown(); err != nil { 955 return err 956 } 957 timeout := time.Now().Add(60 * time.Second) 958 for { 959 if time.Now().After(timeout) { 960 return fmt.Errorf("timeout waiting for core to seal") 961 } 962 if core.Sealed() { 963 break 964 } 965 time.Sleep(250 * time.Millisecond) 966 } 967 } 968 return nil 969} 970 971// UnsealWithStoredKeys uses stored keys to unseal the test cluster cores 972func (c *TestCluster) UnsealWithStoredKeys(t testing.T) error { 973 for _, core := range c.Cores { 974 if err := core.UnsealWithStoredKeys(context.Background()); err != nil { 975 return err 976 } 977 timeout := time.Now().Add(60 * time.Second) 978 for { 979 if time.Now().After(timeout) { 980 return fmt.Errorf("timeout waiting for core to unseal") 981 } 982 if !core.Sealed() { 983 break 984 } 985 time.Sleep(250 * time.Millisecond) 986 } 987 } 988 return nil 989} 990 991func SetReplicationFailureMode(core *TestClusterCore, mode uint32) { 992 atomic.StoreUint32(core.Core.replicationFailure, mode) 993} 994 995type TestListener struct { 996 net.Listener 997 Address *net.TCPAddr 998} 999 1000type TestClusterCore struct { 1001 *Core 1002 CoreConfig *CoreConfig 1003 Client *api.Client 1004 Handler http.Handler 1005 Listeners []*TestListener 1006 ReloadFuncs *map[string][]reload.ReloadFunc 1007 ReloadFuncsLock *sync.RWMutex 1008 Server *http.Server 1009 ServerCert *x509.Certificate 1010 ServerCertBytes []byte 1011 ServerCertPEM []byte 1012 ServerKey *ecdsa.PrivateKey 1013 ServerKeyPEM []byte 1014 TLSConfig *tls.Config 1015 UnderlyingStorage physical.Backend 1016 UnderlyingRawStorage physical.Backend 1017 Barrier SecurityBarrier 1018 NodeID string 1019} 1020 1021type TestClusterOptions struct { 1022 KeepStandbysSealed bool 1023 SkipInit bool 1024 HandlerFunc func(*HandlerProperties) http.Handler 1025 BaseListenAddress string 1026 NumCores int 1027 SealFunc func() Seal 1028 Logger log.Logger 1029 TempDir string 1030 CACert []byte 1031 CAKey *ecdsa.PrivateKey 1032 PhysicalFactory func(hclog.Logger) (physical.Backend, error) 1033 FirstCoreNumber int 1034 RequireClientAuth bool 1035} 1036 1037var DefaultNumCores = 3 1038 1039type certInfo struct { 1040 cert *x509.Certificate 1041 certPEM []byte 1042 certBytes []byte 1043 key *ecdsa.PrivateKey 1044 keyPEM []byte 1045} 1046 1047// NewTestCluster creates a new test cluster based on the provided core config 1048// and test cluster options. 1049// 1050// N.B. Even though a single base CoreConfig is provided, NewTestCluster will instantiate a 1051// core config for each core it creates. If separate seal per core is desired, opts.SealFunc 1052// can be provided to generate a seal for each one. Otherwise, the provided base.Seal will be 1053// shared among cores. NewCore's default behavior is to generate a new DefaultSeal if the 1054// provided Seal in coreConfig (i.e. base.Seal) is nil. 1055// 1056// If opts.Logger is provided, it takes precedence and will be used as the cluster 1057// logger and will be the basis for each core's logger. If no opts.Logger is 1058// given, one will be generated based on t.Name() for the cluster logger, and if 1059// no base.Logger is given will also be used as the basis for each core's logger. 1060func NewTestCluster(t testing.T, base *CoreConfig, opts *TestClusterOptions) *TestCluster { 1061 var err error 1062 1063 var numCores int 1064 if opts == nil || opts.NumCores == 0 { 1065 numCores = DefaultNumCores 1066 } else { 1067 numCores = opts.NumCores 1068 } 1069 1070 var firstCoreNumber int 1071 if opts != nil { 1072 firstCoreNumber = opts.FirstCoreNumber 1073 } 1074 1075 certIPs := []net.IP{ 1076 net.IPv6loopback, 1077 net.ParseIP("127.0.0.1"), 1078 } 1079 var baseAddr *net.TCPAddr 1080 if opts != nil && opts.BaseListenAddress != "" { 1081 baseAddr, err = net.ResolveTCPAddr("tcp", opts.BaseListenAddress) 1082 if err != nil { 1083 t.Fatal("could not parse given base IP") 1084 } 1085 certIPs = append(certIPs, baseAddr.IP) 1086 } 1087 1088 var testCluster TestCluster 1089 1090 if opts != nil && opts.Logger != nil { 1091 testCluster.Logger = opts.Logger 1092 } else { 1093 testCluster.Logger = logging.NewVaultLogger(log.Trace).Named(t.Name()) 1094 } 1095 1096 if opts != nil && opts.TempDir != "" { 1097 if _, err := os.Stat(opts.TempDir); os.IsNotExist(err) { 1098 if err := os.MkdirAll(opts.TempDir, 0700); err != nil { 1099 t.Fatal(err) 1100 } 1101 } 1102 testCluster.TempDir = opts.TempDir 1103 } else { 1104 tempDir, err := ioutil.TempDir("", "vault-test-cluster-") 1105 if err != nil { 1106 t.Fatal(err) 1107 } 1108 testCluster.TempDir = tempDir 1109 } 1110 1111 var caKey *ecdsa.PrivateKey 1112 if opts != nil && opts.CAKey != nil { 1113 caKey = opts.CAKey 1114 } else { 1115 caKey, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader) 1116 if err != nil { 1117 t.Fatal(err) 1118 } 1119 } 1120 testCluster.CAKey = caKey 1121 var caBytes []byte 1122 if opts != nil && len(opts.CACert) > 0 { 1123 caBytes = opts.CACert 1124 } else { 1125 caCertTemplate := &x509.Certificate{ 1126 Subject: pkix.Name{ 1127 CommonName: "localhost", 1128 }, 1129 DNSNames: []string{"localhost"}, 1130 IPAddresses: certIPs, 1131 KeyUsage: x509.KeyUsage(x509.KeyUsageCertSign | x509.KeyUsageCRLSign), 1132 SerialNumber: big.NewInt(mathrand.Int63()), 1133 NotBefore: time.Now().Add(-30 * time.Second), 1134 NotAfter: time.Now().Add(262980 * time.Hour), 1135 BasicConstraintsValid: true, 1136 IsCA: true, 1137 } 1138 caBytes, err = x509.CreateCertificate(rand.Reader, caCertTemplate, caCertTemplate, caKey.Public(), caKey) 1139 if err != nil { 1140 t.Fatal(err) 1141 } 1142 } 1143 caCert, err := x509.ParseCertificate(caBytes) 1144 if err != nil { 1145 t.Fatal(err) 1146 } 1147 testCluster.CACert = caCert 1148 testCluster.CACertBytes = caBytes 1149 testCluster.RootCAs = x509.NewCertPool() 1150 testCluster.RootCAs.AddCert(caCert) 1151 caCertPEMBlock := &pem.Block{ 1152 Type: "CERTIFICATE", 1153 Bytes: caBytes, 1154 } 1155 testCluster.CACertPEM = pem.EncodeToMemory(caCertPEMBlock) 1156 testCluster.CACertPEMFile = filepath.Join(testCluster.TempDir, "ca_cert.pem") 1157 err = ioutil.WriteFile(testCluster.CACertPEMFile, testCluster.CACertPEM, 0755) 1158 if err != nil { 1159 t.Fatal(err) 1160 } 1161 marshaledCAKey, err := x509.MarshalECPrivateKey(caKey) 1162 if err != nil { 1163 t.Fatal(err) 1164 } 1165 caKeyPEMBlock := &pem.Block{ 1166 Type: "EC PRIVATE KEY", 1167 Bytes: marshaledCAKey, 1168 } 1169 testCluster.CAKeyPEM = pem.EncodeToMemory(caKeyPEMBlock) 1170 err = ioutil.WriteFile(filepath.Join(testCluster.TempDir, "ca_key.pem"), testCluster.CAKeyPEM, 0755) 1171 if err != nil { 1172 t.Fatal(err) 1173 } 1174 1175 var certInfoSlice []*certInfo 1176 1177 // 1178 // Certs generation 1179 // 1180 for i := 0; i < numCores; i++ { 1181 key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) 1182 if err != nil { 1183 t.Fatal(err) 1184 } 1185 certTemplate := &x509.Certificate{ 1186 Subject: pkix.Name{ 1187 CommonName: "localhost", 1188 }, 1189 DNSNames: []string{"localhost"}, 1190 IPAddresses: certIPs, 1191 ExtKeyUsage: []x509.ExtKeyUsage{ 1192 x509.ExtKeyUsageServerAuth, 1193 x509.ExtKeyUsageClientAuth, 1194 }, 1195 KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageKeyAgreement, 1196 SerialNumber: big.NewInt(mathrand.Int63()), 1197 NotBefore: time.Now().Add(-30 * time.Second), 1198 NotAfter: time.Now().Add(262980 * time.Hour), 1199 } 1200 certBytes, err := x509.CreateCertificate(rand.Reader, certTemplate, caCert, key.Public(), caKey) 1201 if err != nil { 1202 t.Fatal(err) 1203 } 1204 cert, err := x509.ParseCertificate(certBytes) 1205 if err != nil { 1206 t.Fatal(err) 1207 } 1208 certPEMBlock := &pem.Block{ 1209 Type: "CERTIFICATE", 1210 Bytes: certBytes, 1211 } 1212 certPEM := pem.EncodeToMemory(certPEMBlock) 1213 marshaledKey, err := x509.MarshalECPrivateKey(key) 1214 if err != nil { 1215 t.Fatal(err) 1216 } 1217 keyPEMBlock := &pem.Block{ 1218 Type: "EC PRIVATE KEY", 1219 Bytes: marshaledKey, 1220 } 1221 keyPEM := pem.EncodeToMemory(keyPEMBlock) 1222 1223 certInfoSlice = append(certInfoSlice, &certInfo{ 1224 cert: cert, 1225 certPEM: certPEM, 1226 certBytes: certBytes, 1227 key: key, 1228 keyPEM: keyPEM, 1229 }) 1230 } 1231 1232 // 1233 // Listener setup 1234 // 1235 ports := make([]int, numCores) 1236 if baseAddr != nil { 1237 for i := 0; i < numCores; i++ { 1238 ports[i] = baseAddr.Port + i 1239 } 1240 } else { 1241 baseAddr = &net.TCPAddr{ 1242 IP: net.ParseIP("127.0.0.1"), 1243 Port: 0, 1244 } 1245 } 1246 1247 listeners := [][]*TestListener{} 1248 servers := []*http.Server{} 1249 handlers := []http.Handler{} 1250 tlsConfigs := []*tls.Config{} 1251 certGetters := []*reload.CertificateGetter{} 1252 for i := 0; i < numCores; i++ { 1253 baseAddr.Port = ports[i] 1254 ln, err := net.ListenTCP("tcp", baseAddr) 1255 if err != nil { 1256 t.Fatal(err) 1257 } 1258 certFile := filepath.Join(testCluster.TempDir, fmt.Sprintf("node%d_port_%d_cert.pem", i+1, ln.Addr().(*net.TCPAddr).Port)) 1259 keyFile := filepath.Join(testCluster.TempDir, fmt.Sprintf("node%d_port_%d_key.pem", i+1, ln.Addr().(*net.TCPAddr).Port)) 1260 err = ioutil.WriteFile(certFile, certInfoSlice[i].certPEM, 0755) 1261 if err != nil { 1262 t.Fatal(err) 1263 } 1264 err = ioutil.WriteFile(keyFile, certInfoSlice[i].keyPEM, 0755) 1265 if err != nil { 1266 t.Fatal(err) 1267 } 1268 tlsCert, err := tls.X509KeyPair(certInfoSlice[i].certPEM, certInfoSlice[i].keyPEM) 1269 if err != nil { 1270 t.Fatal(err) 1271 } 1272 certGetter := reload.NewCertificateGetter(certFile, keyFile, "") 1273 certGetters = append(certGetters, certGetter) 1274 tlsConfig := &tls.Config{ 1275 Certificates: []tls.Certificate{tlsCert}, 1276 RootCAs: testCluster.RootCAs, 1277 ClientCAs: testCluster.RootCAs, 1278 ClientAuth: tls.RequestClientCert, 1279 NextProtos: []string{"h2", "http/1.1"}, 1280 GetCertificate: certGetter.GetCertificate, 1281 } 1282 if opts != nil && opts.RequireClientAuth { 1283 tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert 1284 testCluster.ClientAuthRequired = true 1285 } 1286 tlsConfig.BuildNameToCertificate() 1287 tlsConfigs = append(tlsConfigs, tlsConfig) 1288 lns := []*TestListener{&TestListener{ 1289 Listener: tls.NewListener(ln, tlsConfig), 1290 Address: ln.Addr().(*net.TCPAddr), 1291 }, 1292 } 1293 listeners = append(listeners, lns) 1294 var handler http.Handler = http.NewServeMux() 1295 handlers = append(handlers, handler) 1296 server := &http.Server{ 1297 Handler: handler, 1298 ErrorLog: testCluster.Logger.StandardLogger(nil), 1299 } 1300 servers = append(servers, server) 1301 } 1302 1303 // Create three cores with the same physical and different redirect/cluster 1304 // addrs. 1305 // N.B.: On OSX, instead of random ports, it assigns new ports to new 1306 // listeners sequentially. Aside from being a bad idea in a security sense, 1307 // it also broke tests that assumed it was OK to just use the port above 1308 // the redirect addr. This has now been changed to 105 ports above, but if 1309 // we ever do more than three nodes in a cluster it may need to be bumped. 1310 // Note: it's 105 so that we don't conflict with a running Consul by 1311 // default. 1312 coreConfig := &CoreConfig{ 1313 LogicalBackends: make(map[string]logical.Factory), 1314 CredentialBackends: make(map[string]logical.Factory), 1315 AuditBackends: make(map[string]audit.Factory), 1316 RedirectAddr: fmt.Sprintf("https://127.0.0.1:%d", listeners[0][0].Address.Port), 1317 ClusterAddr: "https://127.0.0.1:0", 1318 DisableMlock: true, 1319 EnableUI: true, 1320 EnableRaw: true, 1321 BuiltinRegistry: NewMockBuiltinRegistry(), 1322 } 1323 1324 if base != nil { 1325 coreConfig.DisableCache = base.DisableCache 1326 coreConfig.EnableUI = base.EnableUI 1327 coreConfig.DefaultLeaseTTL = base.DefaultLeaseTTL 1328 coreConfig.MaxLeaseTTL = base.MaxLeaseTTL 1329 coreConfig.CacheSize = base.CacheSize 1330 coreConfig.PluginDirectory = base.PluginDirectory 1331 coreConfig.Seal = base.Seal 1332 coreConfig.DevToken = base.DevToken 1333 coreConfig.EnableRaw = base.EnableRaw 1334 coreConfig.DisableSealWrap = base.DisableSealWrap 1335 coreConfig.DevLicenseDuration = base.DevLicenseDuration 1336 coreConfig.DisableCache = base.DisableCache 1337 coreConfig.LicensingConfig = base.LicensingConfig 1338 coreConfig.DisablePerformanceStandby = base.DisablePerformanceStandby 1339 coreConfig.MetricsHelper = base.MetricsHelper 1340 if base.BuiltinRegistry != nil { 1341 coreConfig.BuiltinRegistry = base.BuiltinRegistry 1342 } 1343 1344 if !coreConfig.DisableMlock { 1345 base.DisableMlock = false 1346 } 1347 1348 if base.Physical != nil { 1349 coreConfig.Physical = base.Physical 1350 } 1351 1352 if base.HAPhysical != nil { 1353 coreConfig.HAPhysical = base.HAPhysical 1354 } 1355 1356 // Used to set something non-working to test fallback 1357 switch base.ClusterAddr { 1358 case "empty": 1359 coreConfig.ClusterAddr = "" 1360 case "": 1361 default: 1362 coreConfig.ClusterAddr = base.ClusterAddr 1363 } 1364 1365 if base.LogicalBackends != nil { 1366 for k, v := range base.LogicalBackends { 1367 coreConfig.LogicalBackends[k] = v 1368 } 1369 } 1370 if base.CredentialBackends != nil { 1371 for k, v := range base.CredentialBackends { 1372 coreConfig.CredentialBackends[k] = v 1373 } 1374 } 1375 if base.AuditBackends != nil { 1376 for k, v := range base.AuditBackends { 1377 coreConfig.AuditBackends[k] = v 1378 } 1379 } 1380 if base.Logger != nil { 1381 coreConfig.Logger = base.Logger 1382 } 1383 1384 coreConfig.ClusterCipherSuites = base.ClusterCipherSuites 1385 1386 coreConfig.DisableCache = base.DisableCache 1387 1388 coreConfig.DevToken = base.DevToken 1389 coreConfig.CounterSyncInterval = base.CounterSyncInterval 1390 1391 } 1392 1393 addAuditBackend := len(coreConfig.AuditBackends) == 0 1394 if addAuditBackend { 1395 AddNoopAudit(coreConfig) 1396 } 1397 1398 if coreConfig.Physical == nil && (opts == nil || opts.PhysicalFactory == nil) { 1399 coreConfig.Physical, err = physInmem.NewInmem(nil, testCluster.Logger) 1400 if err != nil { 1401 t.Fatal(err) 1402 } 1403 } 1404 if coreConfig.HAPhysical == nil && (opts == nil || opts.PhysicalFactory == nil) { 1405 haPhys, err := physInmem.NewInmemHA(nil, testCluster.Logger) 1406 if err != nil { 1407 t.Fatal(err) 1408 } 1409 coreConfig.HAPhysical = haPhys.(physical.HABackend) 1410 } 1411 1412 pubKey, priKey, err := testGenerateCoreKeys() 1413 if err != nil { 1414 t.Fatalf("err: %v", err) 1415 } 1416 1417 cores := []*Core{} 1418 coreConfigs := []*CoreConfig{} 1419 for i := 0; i < numCores; i++ { 1420 localConfig := *coreConfig 1421 localConfig.RedirectAddr = fmt.Sprintf("https://127.0.0.1:%d", listeners[i][0].Address.Port) 1422 1423 // if opts.SealFunc is provided, use that to generate a seal for the config instead 1424 if opts != nil && opts.SealFunc != nil { 1425 localConfig.Seal = opts.SealFunc() 1426 } 1427 1428 if coreConfig.Logger == nil || (opts != nil && opts.Logger != nil) { 1429 localConfig.Logger = testCluster.Logger.Named(fmt.Sprintf("core%d", i)) 1430 } 1431 1432 if opts != nil && opts.PhysicalFactory != nil { 1433 localConfig.Physical, err = opts.PhysicalFactory(localConfig.Logger) 1434 if err != nil { 1435 t.Fatalf("err: %v", err) 1436 } 1437 1438 if haPhysical, ok := localConfig.Physical.(physical.HABackend); ok { 1439 localConfig.HAPhysical = haPhysical 1440 } 1441 } 1442 1443 switch { 1444 case localConfig.LicensingConfig != nil: 1445 if pubKey != nil { 1446 localConfig.LicensingConfig.AdditionalPublicKeys = append(localConfig.LicensingConfig.AdditionalPublicKeys, pubKey.(ed25519.PublicKey)) 1447 } 1448 default: 1449 localConfig.LicensingConfig = testGetLicensingConfig(pubKey) 1450 } 1451 1452 c, err := NewCore(&localConfig) 1453 if err != nil { 1454 t.Fatalf("err: %v", err) 1455 } 1456 c.coreNumber = firstCoreNumber + i 1457 cores = append(cores, c) 1458 coreConfigs = append(coreConfigs, &localConfig) 1459 if opts != nil && opts.HandlerFunc != nil { 1460 handlers[i] = opts.HandlerFunc(&HandlerProperties{ 1461 Core: c, 1462 MaxRequestDuration: DefaultMaxRequestDuration, 1463 }) 1464 servers[i].Handler = handlers[i] 1465 } 1466 1467 // Set this in case the Seal was manually set before the core was 1468 // created 1469 if localConfig.Seal != nil { 1470 localConfig.Seal.SetCore(c) 1471 } 1472 } 1473 1474 // 1475 // Clustering setup 1476 // 1477 clusterAddrGen := func(lns []*TestListener) []*net.TCPAddr { 1478 ret := make([]*net.TCPAddr, len(lns)) 1479 for i, ln := range lns { 1480 ret[i] = &net.TCPAddr{ 1481 IP: ln.Address.IP, 1482 Port: 0, 1483 } 1484 } 1485 return ret 1486 } 1487 1488 for i := 0; i < numCores; i++ { 1489 if coreConfigs[i].ClusterAddr != "" { 1490 cores[i].SetClusterListenerAddrs(clusterAddrGen(listeners[i])) 1491 cores[i].SetClusterHandler(handlers[i]) 1492 } 1493 } 1494 1495 if opts == nil || !opts.SkipInit { 1496 bKeys, rKeys, root := TestCoreInitClusterWrapperSetup(t, cores[0], handlers[0]) 1497 barrierKeys, _ := copystructure.Copy(bKeys) 1498 testCluster.BarrierKeys = barrierKeys.([][]byte) 1499 recoveryKeys, _ := copystructure.Copy(rKeys) 1500 testCluster.RecoveryKeys = recoveryKeys.([][]byte) 1501 testCluster.RootToken = root 1502 1503 // Write root token and barrier keys 1504 err = ioutil.WriteFile(filepath.Join(testCluster.TempDir, "root_token"), []byte(root), 0755) 1505 if err != nil { 1506 t.Fatal(err) 1507 } 1508 var buf bytes.Buffer 1509 for i, key := range testCluster.BarrierKeys { 1510 buf.Write([]byte(base64.StdEncoding.EncodeToString(key))) 1511 if i < len(testCluster.BarrierKeys)-1 { 1512 buf.WriteRune('\n') 1513 } 1514 } 1515 err = ioutil.WriteFile(filepath.Join(testCluster.TempDir, "barrier_keys"), buf.Bytes(), 0755) 1516 if err != nil { 1517 t.Fatal(err) 1518 } 1519 for i, key := range testCluster.RecoveryKeys { 1520 buf.Write([]byte(base64.StdEncoding.EncodeToString(key))) 1521 if i < len(testCluster.RecoveryKeys)-1 { 1522 buf.WriteRune('\n') 1523 } 1524 } 1525 err = ioutil.WriteFile(filepath.Join(testCluster.TempDir, "recovery_keys"), buf.Bytes(), 0755) 1526 if err != nil { 1527 t.Fatal(err) 1528 } 1529 1530 // Unseal first core 1531 for _, key := range bKeys { 1532 if _, err := cores[0].Unseal(TestKeyCopy(key)); err != nil { 1533 t.Fatalf("unseal err: %s", err) 1534 } 1535 } 1536 1537 ctx := context.Background() 1538 1539 // If stored keys is supported, the above will no no-op, so trigger auto-unseal 1540 // using stored keys to try to unseal 1541 if err := cores[0].UnsealWithStoredKeys(ctx); err != nil { 1542 t.Fatal(err) 1543 } 1544 1545 // Verify unsealed 1546 if cores[0].Sealed() { 1547 t.Fatal("should not be sealed") 1548 } 1549 1550 TestWaitActive(t, cores[0]) 1551 1552 // Existing tests rely on this; we can make a toggle to disable it 1553 // later if we want 1554 kvReq := &logical.Request{ 1555 Operation: logical.UpdateOperation, 1556 ClientToken: testCluster.RootToken, 1557 Path: "sys/mounts/secret", 1558 Data: map[string]interface{}{ 1559 "type": "kv", 1560 "path": "secret/", 1561 "description": "key/value secret storage", 1562 "options": map[string]string{ 1563 "version": "1", 1564 }, 1565 }, 1566 } 1567 resp, err := cores[0].HandleRequest(namespace.RootContext(ctx), kvReq) 1568 if err != nil { 1569 t.Fatal(err) 1570 } 1571 if resp.IsError() { 1572 t.Fatal(err) 1573 } 1574 1575 // Unseal other cores unless otherwise specified 1576 if (opts == nil || !opts.KeepStandbysSealed) && numCores > 1 { 1577 for i := 1; i < numCores; i++ { 1578 for _, key := range bKeys { 1579 if _, err := cores[i].Unseal(TestKeyCopy(key)); err != nil { 1580 t.Fatalf("unseal err: %s", err) 1581 } 1582 } 1583 1584 // If stored keys is supported, the above will no no-op, so trigger auto-unseal 1585 // using stored keys 1586 if err := cores[i].UnsealWithStoredKeys(ctx); err != nil { 1587 t.Fatal(err) 1588 } 1589 } 1590 1591 // Let them come fully up to standby 1592 time.Sleep(2 * time.Second) 1593 1594 // Ensure cluster connection info is populated. 1595 // Other cores should not come up as leaders. 1596 for i := 1; i < numCores; i++ { 1597 isLeader, _, _, err := cores[i].Leader() 1598 if err != nil { 1599 t.Fatal(err) 1600 } 1601 if isLeader { 1602 t.Fatalf("core[%d] should not be leader", i) 1603 } 1604 } 1605 } 1606 1607 // 1608 // Set test cluster core(s) and test cluster 1609 // 1610 cluster, err := cores[0].Cluster(context.Background()) 1611 if err != nil { 1612 t.Fatal(err) 1613 } 1614 testCluster.ID = cluster.ID 1615 1616 if addAuditBackend { 1617 // Enable auditing. 1618 auditReq := &logical.Request{ 1619 Operation: logical.UpdateOperation, 1620 ClientToken: testCluster.RootToken, 1621 Path: "sys/audit/noop", 1622 Data: map[string]interface{}{ 1623 "type": "noop", 1624 }, 1625 } 1626 resp, err = cores[0].HandleRequest(namespace.RootContext(ctx), auditReq) 1627 if err != nil { 1628 t.Fatal(err) 1629 } 1630 1631 if resp.IsError() { 1632 t.Fatal(err) 1633 } 1634 } 1635 } 1636 1637 getAPIClient := func(port int, tlsConfig *tls.Config) *api.Client { 1638 transport := cleanhttp.DefaultPooledTransport() 1639 transport.TLSClientConfig = tlsConfig.Clone() 1640 if err := http2.ConfigureTransport(transport); err != nil { 1641 t.Fatal(err) 1642 } 1643 client := &http.Client{ 1644 Transport: transport, 1645 CheckRedirect: func(*http.Request, []*http.Request) error { 1646 // This can of course be overridden per-test by using its own client 1647 return fmt.Errorf("redirects not allowed in these tests") 1648 }, 1649 } 1650 config := api.DefaultConfig() 1651 if config.Error != nil { 1652 t.Fatal(config.Error) 1653 } 1654 config.Address = fmt.Sprintf("https://127.0.0.1:%d", port) 1655 config.HttpClient = client 1656 config.MaxRetries = 0 1657 apiClient, err := api.NewClient(config) 1658 if err != nil { 1659 t.Fatal(err) 1660 } 1661 if opts == nil || !opts.SkipInit { 1662 apiClient.SetToken(testCluster.RootToken) 1663 } 1664 return apiClient 1665 } 1666 1667 var ret []*TestClusterCore 1668 for i := 0; i < numCores; i++ { 1669 tcc := &TestClusterCore{ 1670 Core: cores[i], 1671 CoreConfig: coreConfigs[i], 1672 ServerKey: certInfoSlice[i].key, 1673 ServerKeyPEM: certInfoSlice[i].keyPEM, 1674 ServerCert: certInfoSlice[i].cert, 1675 ServerCertBytes: certInfoSlice[i].certBytes, 1676 ServerCertPEM: certInfoSlice[i].certPEM, 1677 Listeners: listeners[i], 1678 Handler: handlers[i], 1679 Server: servers[i], 1680 TLSConfig: tlsConfigs[i], 1681 Client: getAPIClient(listeners[i][0].Address.Port, tlsConfigs[i]), 1682 Barrier: cores[i].barrier, 1683 NodeID: fmt.Sprintf("core-%d", i), 1684 UnderlyingRawStorage: coreConfigs[i].Physical, 1685 } 1686 tcc.ReloadFuncs = &cores[i].reloadFuncs 1687 tcc.ReloadFuncsLock = &cores[i].reloadFuncsLock 1688 tcc.ReloadFuncsLock.Lock() 1689 (*tcc.ReloadFuncs)["listener|tcp"] = []reload.ReloadFunc{certGetters[i].Reload} 1690 tcc.ReloadFuncsLock.Unlock() 1691 1692 testAdjustTestCore(base, tcc) 1693 1694 ret = append(ret, tcc) 1695 } 1696 1697 testCluster.Cores = ret 1698 1699 testExtraClusterCoresTestSetup(t, priKey, testCluster.Cores) 1700 1701 return &testCluster 1702} 1703 1704func NewMockBuiltinRegistry() *mockBuiltinRegistry { 1705 return &mockBuiltinRegistry{ 1706 forTesting: map[string]consts.PluginType{ 1707 "mysql-database-plugin": consts.PluginTypeDatabase, 1708 "postgresql-database-plugin": consts.PluginTypeDatabase, 1709 }, 1710 } 1711} 1712 1713type mockBuiltinRegistry struct { 1714 forTesting map[string]consts.PluginType 1715} 1716 1717func (m *mockBuiltinRegistry) Get(name string, pluginType consts.PluginType) (func() (interface{}, error), bool) { 1718 testPluginType, ok := m.forTesting[name] 1719 if !ok { 1720 return nil, false 1721 } 1722 if pluginType != testPluginType { 1723 return nil, false 1724 } 1725 if name == "postgresql-database-plugin" { 1726 return dbPostgres.New, true 1727 } 1728 return dbMysql.New(dbMysql.MetadataLen, dbMysql.MetadataLen, dbMysql.UsernameLen), true 1729} 1730 1731// Keys only supports getting a realistic list of the keys for database plugins. 1732func (m *mockBuiltinRegistry) Keys(pluginType consts.PluginType) []string { 1733 if pluginType != consts.PluginTypeDatabase { 1734 return []string{} 1735 } 1736 /* 1737 This is a hard-coded reproduction of the db plugin keys in helper/builtinplugins/registry.go. 1738 The registry isn't directly used because it causes import cycles. 1739 */ 1740 return []string{ 1741 "mysql-database-plugin", 1742 "mysql-aurora-database-plugin", 1743 "mysql-rds-database-plugin", 1744 "mysql-legacy-database-plugin", 1745 "postgresql-database-plugin", 1746 "elasticsearch-database-plugin", 1747 "mssql-database-plugin", 1748 "cassandra-database-plugin", 1749 "mongodb-database-plugin", 1750 "hana-database-plugin", 1751 "influxdb-database-plugin", 1752 } 1753} 1754 1755func (m *mockBuiltinRegistry) Contains(name string, pluginType consts.PluginType) bool { 1756 return false 1757} 1758 1759type NoopAudit struct { 1760 Config *audit.BackendConfig 1761 ReqErr error 1762 ReqAuth []*logical.Auth 1763 Req []*logical.Request 1764 ReqHeaders []map[string][]string 1765 ReqNonHMACKeys []string 1766 ReqErrs []error 1767 1768 RespErr error 1769 RespAuth []*logical.Auth 1770 RespReq []*logical.Request 1771 Resp []*logical.Response 1772 RespNonHMACKeys []string 1773 RespReqNonHMACKeys []string 1774 RespErrs []error 1775 1776 salt *salt.Salt 1777 saltMutex sync.RWMutex 1778} 1779 1780func (n *NoopAudit) LogRequest(ctx context.Context, in *logical.LogInput) error { 1781 n.ReqAuth = append(n.ReqAuth, in.Auth) 1782 n.Req = append(n.Req, in.Request) 1783 n.ReqHeaders = append(n.ReqHeaders, in.Request.Headers) 1784 n.ReqNonHMACKeys = in.NonHMACReqDataKeys 1785 n.ReqErrs = append(n.ReqErrs, in.OuterErr) 1786 return n.ReqErr 1787} 1788 1789func (n *NoopAudit) LogResponse(ctx context.Context, in *logical.LogInput) error { 1790 n.RespAuth = append(n.RespAuth, in.Auth) 1791 n.RespReq = append(n.RespReq, in.Request) 1792 n.Resp = append(n.Resp, in.Response) 1793 n.RespErrs = append(n.RespErrs, in.OuterErr) 1794 1795 if in.Response != nil { 1796 n.RespNonHMACKeys = in.NonHMACRespDataKeys 1797 n.RespReqNonHMACKeys = in.NonHMACReqDataKeys 1798 } 1799 1800 return n.RespErr 1801} 1802 1803func (n *NoopAudit) Salt(ctx context.Context) (*salt.Salt, error) { 1804 n.saltMutex.RLock() 1805 if n.salt != nil { 1806 defer n.saltMutex.RUnlock() 1807 return n.salt, nil 1808 } 1809 n.saltMutex.RUnlock() 1810 n.saltMutex.Lock() 1811 defer n.saltMutex.Unlock() 1812 if n.salt != nil { 1813 return n.salt, nil 1814 } 1815 salt, err := salt.NewSalt(ctx, n.Config.SaltView, n.Config.SaltConfig) 1816 if err != nil { 1817 return nil, err 1818 } 1819 n.salt = salt 1820 return salt, nil 1821} 1822 1823func (n *NoopAudit) GetHash(ctx context.Context, data string) (string, error) { 1824 salt, err := n.Salt(ctx) 1825 if err != nil { 1826 return "", err 1827 } 1828 return salt.GetIdentifiedHMAC(data), nil 1829} 1830 1831func (n *NoopAudit) Reload(ctx context.Context) error { 1832 return nil 1833} 1834 1835func (n *NoopAudit) Invalidate(ctx context.Context) { 1836 n.saltMutex.Lock() 1837 defer n.saltMutex.Unlock() 1838 n.salt = nil 1839} 1840