1package vault 2 3import ( 4 "bytes" 5 "context" 6 "crypto/ecdsa" 7 "crypto/elliptic" 8 "crypto/rand" 9 "crypto/sha256" 10 "crypto/tls" 11 "crypto/x509" 12 "crypto/x509/pkix" 13 "encoding/base64" 14 "encoding/pem" 15 "errors" 16 "fmt" 17 "io" 18 "io/ioutil" 19 "math/big" 20 mathrand "math/rand" 21 "net" 22 "net/http" 23 "os" 24 "os/exec" 25 "path/filepath" 26 "sync" 27 "sync/atomic" 28 "time" 29 30 log "github.com/hashicorp/go-hclog" 31 "github.com/mitchellh/copystructure" 32 33 "golang.org/x/crypto/ssh" 34 "golang.org/x/net/http2" 35 36 cleanhttp "github.com/hashicorp/go-cleanhttp" 37 "github.com/hashicorp/vault/api" 38 "github.com/hashicorp/vault/audit" 39 "github.com/hashicorp/vault/helper/consts" 40 "github.com/hashicorp/vault/helper/logging" 41 "github.com/hashicorp/vault/helper/reload" 42 "github.com/hashicorp/vault/helper/salt" 43 "github.com/hashicorp/vault/logical" 44 "github.com/hashicorp/vault/logical/framework" 45 "github.com/hashicorp/vault/physical" 46 dbMysql "github.com/hashicorp/vault/plugins/database/mysql" 47 dbPostgres "github.com/hashicorp/vault/plugins/database/postgresql" 48 testing "github.com/mitchellh/go-testing-interface" 49 50 physInmem "github.com/hashicorp/vault/physical/inmem" 51) 52 53// This file contains a number of methods that are useful for unit 54// tests within other packages. 
const (
	// testSharedPublicKey is the only client public key accepted by the SSH
	// server started by StartSSHHostTestServer (its PublicKeyCallback compares
	// the offered key's wire encoding against this one).
	testSharedPublicKey = `
ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQC9i+hFxZHGo6KblVme4zrAcJstR6I0PTJozW286X4WyvPnkMYDQ5mnhEYC7UWCvjoTWbPEXPX7NjhRtwQTGD67bV+lrxgfyzK1JZbUXK4PwgKJvQD+XyyWYMzDgGSQY61KUSqCxymSm/9NZkPU3ElaQ9xQuTzPpztM4ROfb8f2Yv6/ZESZsTo0MTAkp8Pcy+WkioI/uJ1H7zqs0EA4OMY4aDJRu0UtP4rTVeYNEAuRXdX+eH4aW3KMvhzpFTjMbaJHJXlEeUm2SaX5TNQyTOvghCeQILfYIL/Ca2ij8iwCmulwdV6eQGfd4VDu40PvSnmfoaE38o6HaPnX0kUcnKiT
`
	// testSharedPrivateKey is parsed by StartSSHHostTestServer and installed
	// as that server's host key.
	testSharedPrivateKey = `
-----BEGIN RSA PRIVATE KEY-----
MIIEogIBAAKCAQEAvYvoRcWRxqOim5VZnuM6wHCbLUeiND0yaM1tvOl+Fsrz55DG
A0OZp4RGAu1Fgr46E1mzxFz1+zY4UbcEExg+u21fpa8YH8sytSWW1FyuD8ICib0A
/l8slmDMw4BkkGOtSlEqgscpkpv/TWZD1NxJWkPcULk8z6c7TOETn2/H9mL+v2RE
mbE6NDEwJKfD3MvlpIqCP7idR+86rNBAODjGOGgyUbtFLT+K01XmDRALkV3V/nh+
GltyjL4c6RU4zG2iRyV5RHlJtkml+UzUMkzr4IQnkCC32CC/wmtoo/IsAprpcHVe
nkBn3eFQ7uND70p5n6GhN/KOh2j519JFHJyokwIDAQABAoIBAHX7VOvBC3kCN9/x
+aPdup84OE7Z7MvpX6w+WlUhXVugnmsAAVDczhKoUc/WktLLx2huCGhsmKvyVuH+
MioUiE+vx75gm3qGx5xbtmOfALVMRLopjCnJYf6EaFA0ZeQ+NwowNW7Lu0PHmAU8
Z3JiX8IwxTz14DU82buDyewO7v+cEr97AnERe3PUcSTDoUXNaoNxjNpEJkKREY6h
4hAY676RT/GsRcQ8tqe/rnCqPHNd7JGqL+207FK4tJw7daoBjQyijWuB7K5chSal
oPInylM6b13ASXuOAOT/2uSUBWmFVCZPDCmnZxy2SdnJGbsJAMl7Ma3MUlaGvVI+
Tfh1aQkCgYEA4JlNOabTb3z42wz6mz+Nz3JRwbawD+PJXOk5JsSnV7DtPtfgkK9y
6FTQdhnozGWShAvJvc+C4QAihs9AlHXoaBY5bEU7R/8UK/pSqwzam+MmxmhVDV7G
IMQPV0FteoXTaJSikhZ88mETTegI2mik+zleBpVxvfdhE5TR+lq8Br0CgYEA2AwJ
CUD5CYUSj09PluR0HHqamWOrJkKPFPwa+5eiTTCzfBBxImYZh7nXnWuoviXC0sg2
AuvCW+uZ48ygv/D8gcz3j1JfbErKZJuV+TotK9rRtNIF5Ub7qysP7UjyI7zCssVM
kuDd9LfRXaB/qGAHNkcDA8NxmHW3gpln4CFdSY8CgYANs4xwfercHEWaJ1qKagAe
rZyrMpffAEhicJ/Z65lB0jtG4CiE6w8ZeUMWUVJQVcnwYD+4YpZbX4S7sJ0B8Ydy
AhkSr86D/92dKTIt2STk6aCN7gNyQ1vW198PtaAWH1/cO2UHgHOy3ZUt5X/Uwxl9
cex4flln+1Viumts2GgsCQKBgCJH7psgSyPekK5auFdKEr5+Gc/jB8I/Z3K9+g4X
5nH3G1PBTCJYLw7hRzw8W/8oALzvddqKzEFHphiGXK94Lqjt/A4q1OdbCrhiE68D
My21P/dAKB1UYRSs9Y8CNyHCjuZM9jSMJ8vv6vG/SOJPsnVDWVAckAbQDvlTHC9t
O98zAoGAcbW6uFDkrv0XMCpB9Su3KaNXOR0wzag+WIFQRXCcoTvxVi9iYfUReQPi
oOyBJU/HMVvBfv4g+OVFLVgSwwm6owwsouZ0+D/LasbuHqYyqYqdyPJQYzWA2Y+F
+B6f4RoPdSXj24JHPg/ioRxjaj094UXJxua2yfkcecGNEuBQHSs=
-----END RSA PRIVATE KEY-----
`
)

// TestCore returns a pure in-memory, uninitialized core for testing. It passes
// a nil seal (so NewCore chooses the default) and leaves the raw storage
// endpoints disabled.
func TestCore(t testing.T) *Core {
	return TestCoreWithSeal(t, nil, false)
}

// TestCoreRaw returns a pure in-memory, uninitialized core for testing. The raw
// storage endpoints are enabled with this core.
func TestCoreRaw(t testing.T) *Core {
	return TestCoreWithSeal(t, nil, true)
}

// TestCoreNewSeal returns a pure in-memory, uninitialized core whose seal is a
// freshly created NewTestSeal(t, nil); raw storage endpoints stay disabled.
func TestCoreNewSeal(t testing.T) *Core {
	seal := NewTestSeal(t, nil)
	return TestCoreWithSeal(t, seal, false)
}

// TestCoreWithConfig returns a pure in-memory, uninitialized core with the
// specified core configurations overridden for testing. Only the fields that
// TestCoreWithSealAndUI copies out of conf take effect.
func TestCoreWithConfig(t testing.T, conf *CoreConfig) *Core {
	return TestCoreWithSealAndUI(t, conf)
}

// TestCoreWithSeal returns a pure in-memory, uninitialized core with the
// specified seal for testing. A nil testSeal leaves seal selection to NewCore,
// and enableRaw controls whether the raw storage endpoints are enabled.
117func TestCoreWithSeal(t testing.T, testSeal Seal, enableRaw bool) *Core { 118 conf := &CoreConfig{ 119 Seal: testSeal, 120 EnableUI: false, 121 EnableRaw: enableRaw, 122 BuiltinRegistry: NewMockBuiltinRegistry(), 123 } 124 return TestCoreWithSealAndUI(t, conf) 125} 126 127func TestCoreUI(t testing.T, enableUI bool) *Core { 128 conf := &CoreConfig{ 129 EnableUI: enableUI, 130 EnableRaw: true, 131 BuiltinRegistry: NewMockBuiltinRegistry(), 132 } 133 return TestCoreWithSealAndUI(t, conf) 134} 135 136func TestCoreWithSealAndUI(t testing.T, opts *CoreConfig) *Core { 137 logger := logging.NewVaultLogger(log.Trace) 138 physicalBackend, err := physInmem.NewInmem(nil, logger) 139 if err != nil { 140 t.Fatal(err) 141 } 142 143 // Start off with base test core config 144 conf := testCoreConfig(t, physicalBackend, logger) 145 146 // Override config values with ones that gets passed in 147 conf.EnableUI = opts.EnableUI 148 conf.EnableRaw = opts.EnableRaw 149 conf.Seal = opts.Seal 150 conf.LicensingConfig = opts.LicensingConfig 151 conf.DisableKeyEncodingChecks = opts.DisableKeyEncodingChecks 152 153 for k, v := range opts.LogicalBackends { 154 conf.LogicalBackends[k] = v 155 } 156 for k, v := range opts.CredentialBackends { 157 conf.CredentialBackends[k] = v 158 } 159 160 c, err := NewCore(conf) 161 if err != nil { 162 t.Fatalf("err: %s", err) 163 } 164 165 return c 166} 167 168func testCoreConfig(t testing.T, physicalBackend physical.Backend, logger log.Logger) *CoreConfig { 169 t.Helper() 170 noopAudits := map[string]audit.Factory{ 171 "noop": func(_ context.Context, config *audit.BackendConfig) (audit.Backend, error) { 172 view := &logical.InmemStorage{} 173 view.Put(context.Background(), &logical.StorageEntry{ 174 Key: "salt", 175 Value: []byte("foo"), 176 }) 177 config.SaltConfig = &salt.Config{ 178 HMAC: sha256.New, 179 HMACType: "hmac-sha256", 180 } 181 config.SaltView = view 182 return &noopAudit{ 183 Config: config, 184 }, nil 185 }, 186 } 187 188 noopBackends := 
make(map[string]logical.Factory) 189 noopBackends["noop"] = func(ctx context.Context, config *logical.BackendConfig) (logical.Backend, error) { 190 b := new(framework.Backend) 191 b.Setup(ctx, config) 192 b.BackendType = logical.TypeCredential 193 return b, nil 194 } 195 noopBackends["http"] = func(ctx context.Context, config *logical.BackendConfig) (logical.Backend, error) { 196 return new(rawHTTP), nil 197 } 198 199 credentialBackends := make(map[string]logical.Factory) 200 for backendName, backendFactory := range noopBackends { 201 credentialBackends[backendName] = backendFactory 202 } 203 for backendName, backendFactory := range testCredentialBackends { 204 credentialBackends[backendName] = backendFactory 205 } 206 207 logicalBackends := make(map[string]logical.Factory) 208 for backendName, backendFactory := range noopBackends { 209 logicalBackends[backendName] = backendFactory 210 } 211 212 logicalBackends["kv"] = LeasedPassthroughBackendFactory 213 for backendName, backendFactory := range testLogicalBackends { 214 logicalBackends[backendName] = backendFactory 215 } 216 217 conf := &CoreConfig{ 218 Physical: physicalBackend, 219 AuditBackends: noopAudits, 220 LogicalBackends: logicalBackends, 221 CredentialBackends: credentialBackends, 222 DisableMlock: true, 223 Logger: logger, 224 BuiltinRegistry: NewMockBuiltinRegistry(), 225 } 226 227 return conf 228} 229 230// TestCoreInit initializes the core with a single key, and returns 231// the key that must be used to unseal the core and a root token. 
232func TestCoreInit(t testing.T, core *Core) ([][]byte, string) { 233 t.Helper() 234 secretShares, _, root := TestCoreInitClusterWrapperSetup(t, core, nil, nil) 235 return secretShares, root 236} 237 238func TestCoreInitClusterWrapperSetup(t testing.T, core *Core, clusterAddrs []*net.TCPAddr, handler http.Handler) ([][]byte, [][]byte, string) { 239 t.Helper() 240 core.SetClusterListenerAddrs(clusterAddrs) 241 core.SetClusterHandler(handler) 242 243 barrierConfig := &SealConfig{ 244 SecretShares: 3, 245 SecretThreshold: 3, 246 } 247 248 // If we support storing barrier keys, then set that to equal the min threshold to unseal 249 if core.seal.StoredKeysSupported() { 250 barrierConfig.StoredShares = barrierConfig.SecretThreshold 251 } 252 253 recoveryConfig := &SealConfig{ 254 SecretShares: 3, 255 SecretThreshold: 3, 256 } 257 258 result, err := core.Initialize(context.Background(), &InitParams{ 259 BarrierConfig: barrierConfig, 260 RecoveryConfig: recoveryConfig, 261 }) 262 if err != nil { 263 t.Fatalf("err: %s", err) 264 } 265 return result.SecretShares, result.RecoveryShares, result.RootToken 266} 267 268func TestCoreUnseal(core *Core, key []byte) (bool, error) { 269 return core.Unseal(key) 270} 271 272func TestCoreUnsealWithRecoveryKeys(core *Core, key []byte) (bool, error) { 273 return core.UnsealWithRecoveryKeys(key) 274} 275 276// TestCoreUnsealed returns a pure in-memory core that is already 277// initialized and unsealed. 278func TestCoreUnsealed(t testing.T) (*Core, [][]byte, string) { 279 t.Helper() 280 core := TestCore(t) 281 return testCoreUnsealed(t, core) 282} 283 284// TestCoreUnsealedRaw returns a pure in-memory core that is already 285// initialized, unsealed, and with raw endpoints enabled. 
286func TestCoreUnsealedRaw(t testing.T) (*Core, [][]byte, string) { 287 t.Helper() 288 core := TestCoreRaw(t) 289 return testCoreUnsealed(t, core) 290} 291 292// TestCoreUnsealedWithConfig returns a pure in-memory core that is already 293// initialized, unsealed, with the any provided core config values overridden. 294func TestCoreUnsealedWithConfig(t testing.T, conf *CoreConfig) (*Core, [][]byte, string) { 295 t.Helper() 296 core := TestCoreWithConfig(t, conf) 297 return testCoreUnsealed(t, core) 298} 299 300func testCoreUnsealed(t testing.T, core *Core) (*Core, [][]byte, string) { 301 t.Helper() 302 keys, token := TestCoreInit(t, core) 303 for _, key := range keys { 304 if _, err := TestCoreUnseal(core, TestKeyCopy(key)); err != nil { 305 t.Fatalf("unseal err: %s", err) 306 } 307 } 308 309 if core.Sealed() { 310 t.Fatal("should not be sealed") 311 } 312 313 return core, keys, token 314} 315 316func TestCoreUnsealedBackend(t testing.T, backend physical.Backend) (*Core, [][]byte, string) { 317 t.Helper() 318 logger := logging.NewVaultLogger(log.Trace) 319 conf := testCoreConfig(t, backend, logger) 320 conf.Seal = NewTestSeal(t, nil) 321 322 core, err := NewCore(conf) 323 if err != nil { 324 t.Fatalf("err: %s", err) 325 } 326 327 keys, token := TestCoreInit(t, core) 328 for _, key := range keys { 329 if _, err := TestCoreUnseal(core, TestKeyCopy(key)); err != nil { 330 t.Fatalf("unseal err: %s", err) 331 } 332 } 333 334 if err := core.UnsealWithStoredKeys(context.Background()); err != nil { 335 t.Fatal(err) 336 } 337 338 if core.Sealed() { 339 t.Fatal("should not be sealed") 340 } 341 342 return core, keys, token 343} 344 345// TestKeyCopy is a silly little function to just copy the key so that 346// it can be used with Unseal easily. 
347func TestKeyCopy(key []byte) []byte { 348 result := make([]byte, len(key)) 349 copy(result, key) 350 return result 351} 352 353func TestDynamicSystemView(c *Core) *dynamicSystemView { 354 me := &MountEntry{ 355 Config: MountConfig{ 356 DefaultLeaseTTL: 24 * time.Hour, 357 MaxLeaseTTL: 2 * 24 * time.Hour, 358 }, 359 } 360 361 return &dynamicSystemView{c, me} 362} 363 364// TestAddTestPlugin registers the testFunc as part of the plugin command to the 365// plugin catalog. If provided, uses tmpDir as the plugin directory. 366func TestAddTestPlugin(t testing.T, c *Core, name string, pluginType consts.PluginType, testFunc string, env []string, tempDir string) { 367 file, err := os.Open(os.Args[0]) 368 if err != nil { 369 t.Fatal(err) 370 } 371 defer file.Close() 372 373 dirPath := filepath.Dir(os.Args[0]) 374 fileName := filepath.Base(os.Args[0]) 375 376 if tempDir != "" { 377 fi, err := file.Stat() 378 if err != nil { 379 t.Fatal(err) 380 } 381 382 // Copy over the file to the temp dir 383 dst := filepath.Join(tempDir, fileName) 384 out, err := os.OpenFile(dst, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, fi.Mode()) 385 if err != nil { 386 t.Fatal(err) 387 } 388 defer out.Close() 389 390 if _, err = io.Copy(out, file); err != nil { 391 t.Fatal(err) 392 } 393 err = out.Sync() 394 if err != nil { 395 t.Fatal(err) 396 } 397 398 dirPath = tempDir 399 } 400 401 // Determine plugin directory full path, evaluating potential symlink path 402 fullPath, err := filepath.EvalSymlinks(dirPath) 403 if err != nil { 404 t.Fatal(err) 405 } 406 407 reader, err := os.Open(filepath.Join(fullPath, fileName)) 408 if err != nil { 409 t.Fatal(err) 410 } 411 defer reader.Close() 412 413 // Find out the sha256 414 hash := sha256.New() 415 416 _, err = io.Copy(hash, reader) 417 if err != nil { 418 t.Fatal(err) 419 } 420 421 sum := hash.Sum(nil) 422 423 // Set core's plugin directory and plugin catalog directory 424 c.pluginDirectory = fullPath 425 c.pluginCatalog.directory = fullPath 426 427 args := 
[]string{fmt.Sprintf("--test.run=%s", testFunc)} 428 err = c.pluginCatalog.Set(context.Background(), name, pluginType, fileName, args, env, sum) 429 if err != nil { 430 t.Fatal(err) 431 } 432} 433 434var testLogicalBackends = map[string]logical.Factory{} 435var testCredentialBackends = map[string]logical.Factory{} 436 437// StartSSHHostTestServer starts the test server which responds to SSH 438// authentication. Used to test the SSH secret backend. 439func StartSSHHostTestServer() (string, error) { 440 pubKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(testSharedPublicKey)) 441 if err != nil { 442 return "", fmt.Errorf("error parsing public key") 443 } 444 serverConfig := &ssh.ServerConfig{ 445 PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { 446 if bytes.Compare(pubKey.Marshal(), key.Marshal()) == 0 { 447 return &ssh.Permissions{}, nil 448 } else { 449 return nil, fmt.Errorf("key does not match") 450 } 451 }, 452 } 453 signer, err := ssh.ParsePrivateKey([]byte(testSharedPrivateKey)) 454 if err != nil { 455 panic("Error parsing private key") 456 } 457 serverConfig.AddHostKey(signer) 458 459 soc, err := net.Listen("tcp", "127.0.0.1:0") 460 if err != nil { 461 return "", fmt.Errorf("error listening to connection") 462 } 463 464 go func() { 465 for { 466 conn, err := soc.Accept() 467 if err != nil { 468 panic(fmt.Sprintf("Error accepting incoming connection: %s", err)) 469 } 470 defer conn.Close() 471 sshConn, chanReqs, _, err := ssh.NewServerConn(conn, serverConfig) 472 if err != nil { 473 panic(fmt.Sprintf("Handshaking error: %v", err)) 474 } 475 476 go func() { 477 for chanReq := range chanReqs { 478 go func(chanReq ssh.NewChannel) { 479 if chanReq.ChannelType() != "session" { 480 chanReq.Reject(ssh.UnknownChannelType, "unknown channel type") 481 return 482 } 483 484 ch, requests, err := chanReq.Accept() 485 if err != nil { 486 panic(fmt.Sprintf("Error accepting channel: %s", err)) 487 } 488 489 go func(ch 
ssh.Channel, in <-chan *ssh.Request) { 490 for req := range in { 491 executeServerCommand(ch, req) 492 } 493 }(ch, requests) 494 }(chanReq) 495 } 496 sshConn.Close() 497 }() 498 } 499 }() 500 return soc.Addr().String(), nil 501} 502 503// This executes the commands requested to be run on the server. 504// Used to test the SSH secret backend. 505func executeServerCommand(ch ssh.Channel, req *ssh.Request) { 506 command := string(req.Payload[4:]) 507 cmd := exec.Command("/bin/bash", []string{"-c", command}...) 508 req.Reply(true, nil) 509 510 cmd.Stdout = ch 511 cmd.Stderr = ch 512 cmd.Stdin = ch 513 514 err := cmd.Start() 515 if err != nil { 516 panic(fmt.Sprintf("Error starting the command: '%s'", err)) 517 } 518 519 go func() { 520 _, err := cmd.Process.Wait() 521 if err != nil { 522 panic(fmt.Sprintf("Error while waiting for command to finish:'%s'", err)) 523 } 524 ch.Close() 525 }() 526} 527 528// This adds a credential backend for the test core. This needs to be 529// invoked before the test core is created. 530func AddTestCredentialBackend(name string, factory logical.Factory) error { 531 if name == "" { 532 return fmt.Errorf("missing backend name") 533 } 534 if factory == nil { 535 return fmt.Errorf("missing backend factory function") 536 } 537 testCredentialBackends[name] = factory 538 return nil 539} 540 541// This adds a logical backend for the test core. This needs to be 542// invoked before the test core is created. 
543func AddTestLogicalBackend(name string, factory logical.Factory) error { 544 if name == "" { 545 return fmt.Errorf("missing backend name") 546 } 547 if factory == nil { 548 return fmt.Errorf("missing backend factory function") 549 } 550 testLogicalBackends[name] = factory 551 return nil 552} 553 554type noopAudit struct { 555 Config *audit.BackendConfig 556 salt *salt.Salt 557 saltMutex sync.RWMutex 558} 559 560func (n *noopAudit) GetHash(ctx context.Context, data string) (string, error) { 561 salt, err := n.Salt(ctx) 562 if err != nil { 563 return "", err 564 } 565 return salt.GetIdentifiedHMAC(data), nil 566} 567 568func (n *noopAudit) LogRequest(_ context.Context, _ *audit.LogInput) error { 569 return nil 570} 571 572func (n *noopAudit) LogResponse(_ context.Context, _ *audit.LogInput) error { 573 return nil 574} 575 576func (n *noopAudit) Reload(_ context.Context) error { 577 return nil 578} 579 580func (n *noopAudit) Invalidate(_ context.Context) { 581 n.saltMutex.Lock() 582 defer n.saltMutex.Unlock() 583 n.salt = nil 584} 585 586func (n *noopAudit) Salt(ctx context.Context) (*salt.Salt, error) { 587 n.saltMutex.RLock() 588 if n.salt != nil { 589 defer n.saltMutex.RUnlock() 590 return n.salt, nil 591 } 592 n.saltMutex.RUnlock() 593 n.saltMutex.Lock() 594 defer n.saltMutex.Unlock() 595 if n.salt != nil { 596 return n.salt, nil 597 } 598 salt, err := salt.NewSalt(ctx, n.Config.SaltView, n.Config.SaltConfig) 599 if err != nil { 600 return nil, err 601 } 602 n.salt = salt 603 return salt, nil 604} 605 606type rawHTTP struct{} 607 608func (n *rawHTTP) HandleRequest(ctx context.Context, req *logical.Request) (*logical.Response, error) { 609 return &logical.Response{ 610 Data: map[string]interface{}{ 611 logical.HTTPStatusCode: 200, 612 logical.HTTPContentType: "plain/text", 613 logical.HTTPRawBody: []byte("hello world"), 614 }, 615 }, nil 616} 617 618func (n *rawHTTP) HandleExistenceCheck(ctx context.Context, req *logical.Request) (bool, bool, error) { 619 return 
false, false, nil 620} 621 622func (n *rawHTTP) SpecialPaths() *logical.Paths { 623 return &logical.Paths{Unauthenticated: []string{"*"}} 624} 625 626func (n *rawHTTP) System() logical.SystemView { 627 return logical.StaticSystemView{ 628 DefaultLeaseTTLVal: time.Hour * 24, 629 MaxLeaseTTLVal: time.Hour * 24 * 32, 630 } 631} 632 633func (n *rawHTTP) Logger() log.Logger { 634 return logging.NewVaultLogger(log.Trace) 635} 636 637func (n *rawHTTP) Cleanup(ctx context.Context) { 638 // noop 639} 640 641func (n *rawHTTP) Initialize(ctx context.Context) error { 642 // noop 643 return nil 644} 645 646func (n *rawHTTP) InvalidateKey(context.Context, string) { 647 // noop 648} 649 650func (n *rawHTTP) Setup(ctx context.Context, config *logical.BackendConfig) error { 651 // noop 652 return nil 653} 654 655func (n *rawHTTP) Type() logical.BackendType { 656 return logical.TypeLogical 657} 658 659func GenerateRandBytes(length int) ([]byte, error) { 660 if length < 0 { 661 return nil, fmt.Errorf("length must be >= 0") 662 } 663 664 buf := make([]byte, length) 665 if length == 0 { 666 return buf, nil 667 } 668 669 n, err := rand.Read(buf) 670 if err != nil { 671 return nil, err 672 } 673 if n != length { 674 return nil, fmt.Errorf("unable to read %d bytes; only read %d", length, n) 675 } 676 677 return buf, nil 678} 679 680func TestWaitActive(t testing.T, core *Core) { 681 t.Helper() 682 if err := TestWaitActiveWithError(core); err != nil { 683 t.Fatal(err) 684 } 685} 686 687func TestWaitActiveWithError(core *Core) error { 688 start := time.Now() 689 var standby bool 690 var err error 691 for time.Now().Sub(start) < 30*time.Second { 692 standby, err = core.Standby() 693 if err != nil { 694 return err 695 } 696 if !standby { 697 break 698 } 699 } 700 if standby { 701 return errors.New("should not be in standby mode") 702 } 703 return nil 704} 705 706type TestCluster struct { 707 BarrierKeys [][]byte 708 RecoveryKeys [][]byte 709 CACert *x509.Certificate 710 CACertBytes []byte 711 
CACertPEM []byte 712 CACertPEMFile string 713 CAKey *ecdsa.PrivateKey 714 CAKeyPEM []byte 715 Cores []*TestClusterCore 716 ID string 717 RootToken string 718 RootCAs *x509.CertPool 719 TempDir string 720} 721 722func (c *TestCluster) Start() { 723 for _, core := range c.Cores { 724 if core.Server != nil { 725 for _, ln := range core.Listeners { 726 go core.Server.Serve(ln) 727 } 728 } 729 } 730} 731 732// UnsealCores uses the cluster barrier keys to unseal the test cluster cores 733func (c *TestCluster) UnsealCores(t testing.T) { 734 if err := c.UnsealCoresWithError(); err != nil { 735 t.Fatal(err) 736 } 737} 738 739func (c *TestCluster) UnsealCoresWithError() error { 740 numCores := len(c.Cores) 741 742 // Unseal first core 743 for _, key := range c.BarrierKeys { 744 if _, err := c.Cores[0].Unseal(TestKeyCopy(key)); err != nil { 745 return fmt.Errorf("unseal err: %s", err) 746 } 747 } 748 749 // Verify unsealed 750 if c.Cores[0].Sealed() { 751 return fmt.Errorf("should not be sealed") 752 } 753 754 if err := TestWaitActiveWithError(c.Cores[0].Core); err != nil { 755 return err 756 } 757 758 // Unseal other cores 759 for i := 1; i < numCores; i++ { 760 for _, key := range c.BarrierKeys { 761 if _, err := c.Cores[i].Core.Unseal(TestKeyCopy(key)); err != nil { 762 return fmt.Errorf("unseal err: %s", err) 763 } 764 } 765 } 766 767 // Let them come fully up to standby 768 time.Sleep(2 * time.Second) 769 770 // Ensure cluster connection info is populated. 771 // Other cores should not come up as leaders. 
772 for i := 1; i < numCores; i++ { 773 isLeader, _, _, err := c.Cores[i].Leader() 774 if err != nil { 775 return err 776 } 777 if isLeader { 778 return fmt.Errorf("core[%d] should not be leader", i) 779 } 780 } 781 782 return nil 783} 784 785func (c *TestCluster) EnsureCoresSealed(t testing.T) { 786 t.Helper() 787 if err := c.ensureCoresSealed(); err != nil { 788 t.Fatal(err) 789 } 790} 791 792func (c *TestClusterCore) Seal(t testing.T) { 793 t.Helper() 794 if err := c.Core.sealInternal(); err != nil { 795 t.Fatal(err) 796 } 797} 798 799func CleanupClusters(clusters []*TestCluster) { 800 wg := &sync.WaitGroup{} 801 for _, cluster := range clusters { 802 wg.Add(1) 803 lc := cluster 804 go func() { 805 defer wg.Done() 806 lc.Cleanup() 807 }() 808 } 809 wg.Wait() 810} 811 812func (c *TestCluster) Cleanup() { 813 // Close listeners 814 wg := &sync.WaitGroup{} 815 for _, core := range c.Cores { 816 wg.Add(1) 817 lc := core 818 819 go func() { 820 defer wg.Done() 821 if lc.Listeners != nil { 822 for _, ln := range lc.Listeners { 823 ln.Close() 824 } 825 } 826 if lc.licensingStopCh != nil { 827 close(lc.licensingStopCh) 828 lc.licensingStopCh = nil 829 } 830 831 if err := lc.Shutdown(); err != nil { 832 lc.Logger().Error("error during shutdown; abandoning sealing", "error", err) 833 } else { 834 timeout := time.Now().Add(60 * time.Second) 835 for { 836 if time.Now().After(timeout) { 837 lc.Logger().Error("timeout waiting for core to seal") 838 } 839 if lc.Sealed() { 840 break 841 } 842 time.Sleep(250 * time.Millisecond) 843 } 844 } 845 }() 846 } 847 848 wg.Wait() 849 850 // Remove any temp dir that exists 851 if c.TempDir != "" { 852 os.RemoveAll(c.TempDir) 853 } 854 855 // Give time to actually shut down/clean up before the next test 856 time.Sleep(time.Second) 857} 858 859func (c *TestCluster) ensureCoresSealed() error { 860 for _, core := range c.Cores { 861 if err := core.Shutdown(); err != nil { 862 return err 863 } 864 timeout := time.Now().Add(60 * time.Second) 
865 for { 866 if time.Now().After(timeout) { 867 return fmt.Errorf("timeout waiting for core to seal") 868 } 869 if core.Sealed() { 870 break 871 } 872 time.Sleep(250 * time.Millisecond) 873 } 874 } 875 return nil 876} 877 878// UnsealWithStoredKeys uses stored keys to unseal the test cluster cores 879func (c *TestCluster) UnsealWithStoredKeys(t testing.T) error { 880 for _, core := range c.Cores { 881 if err := core.UnsealWithStoredKeys(context.Background()); err != nil { 882 return err 883 } 884 timeout := time.Now().Add(60 * time.Second) 885 for { 886 if time.Now().After(timeout) { 887 return fmt.Errorf("timeout waiting for core to unseal") 888 } 889 if !core.Sealed() { 890 break 891 } 892 time.Sleep(250 * time.Millisecond) 893 } 894 } 895 return nil 896} 897 898func SetReplicationFailureMode(core *TestClusterCore, mode uint32) { 899 atomic.StoreUint32(core.Core.replicationFailure, mode) 900} 901 902type TestListener struct { 903 net.Listener 904 Address *net.TCPAddr 905} 906 907type TestClusterCore struct { 908 *Core 909 CoreConfig *CoreConfig 910 Client *api.Client 911 Handler http.Handler 912 Listeners []*TestListener 913 ReloadFuncs *map[string][]reload.ReloadFunc 914 ReloadFuncsLock *sync.RWMutex 915 Server *http.Server 916 ServerCert *x509.Certificate 917 ServerCertBytes []byte 918 ServerCertPEM []byte 919 ServerKey *ecdsa.PrivateKey 920 ServerKeyPEM []byte 921 TLSConfig *tls.Config 922 UnderlyingStorage physical.Backend 923} 924 925type TestClusterOptions struct { 926 KeepStandbysSealed bool 927 SkipInit bool 928 HandlerFunc func(*HandlerProperties) http.Handler 929 BaseListenAddress string 930 NumCores int 931 SealFunc func() Seal 932 Logger log.Logger 933 TempDir string 934 CACert []byte 935 CAKey *ecdsa.PrivateKey 936} 937 938var DefaultNumCores = 3 939 940type certInfo struct { 941 cert *x509.Certificate 942 certPEM []byte 943 certBytes []byte 944 key *ecdsa.PrivateKey 945 keyPEM []byte 946} 947 948// NewTestCluster creates a new test cluster based on 
the provided core config 949// and test cluster options. 950// 951// N.B. Even though a single base CoreConfig is provided, NewTestCluster will instantiate a 952// core config for each core it creates. If separate seal per core is desired, opts.SealFunc 953// can be provided to generate a seal for each one. Otherwise, the provided base.Seal will be 954// shared among cores. NewCore's default behavior is to generate a new DefaultSeal if the 955// provided Seal in coreConfig (i.e. base.Seal) is nil. 956func NewTestCluster(t testing.T, base *CoreConfig, opts *TestClusterOptions) *TestCluster { 957 var err error 958 959 var numCores int 960 if opts == nil || opts.NumCores == 0 { 961 numCores = DefaultNumCores 962 } else { 963 numCores = opts.NumCores 964 } 965 966 certIPs := []net.IP{ 967 net.IPv6loopback, 968 net.ParseIP("127.0.0.1"), 969 } 970 var baseAddr *net.TCPAddr 971 if opts != nil && opts.BaseListenAddress != "" { 972 baseAddr, err = net.ResolveTCPAddr("tcp", opts.BaseListenAddress) 973 if err != nil { 974 t.Fatal("could not parse given base IP") 975 } 976 certIPs = append(certIPs, baseAddr.IP) 977 } 978 979 var testCluster TestCluster 980 if opts != nil && opts.TempDir != "" { 981 if _, err := os.Stat(opts.TempDir); os.IsNotExist(err) { 982 if err := os.MkdirAll(opts.TempDir, 0700); err != nil { 983 t.Fatal(err) 984 } 985 } 986 testCluster.TempDir = opts.TempDir 987 } else { 988 tempDir, err := ioutil.TempDir("", "vault-test-cluster-") 989 if err != nil { 990 t.Fatal(err) 991 } 992 testCluster.TempDir = tempDir 993 } 994 995 var caKey *ecdsa.PrivateKey 996 if opts != nil && opts.CAKey != nil { 997 caKey = opts.CAKey 998 } else { 999 caKey, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader) 1000 if err != nil { 1001 t.Fatal(err) 1002 } 1003 } 1004 testCluster.CAKey = caKey 1005 var caBytes []byte 1006 if opts != nil && len(opts.CACert) > 0 { 1007 caBytes = opts.CACert 1008 } else { 1009 caCertTemplate := &x509.Certificate{ 1010 Subject: pkix.Name{ 1011 
CommonName: "localhost", 1012 }, 1013 DNSNames: []string{"localhost"}, 1014 IPAddresses: certIPs, 1015 KeyUsage: x509.KeyUsage(x509.KeyUsageCertSign | x509.KeyUsageCRLSign), 1016 SerialNumber: big.NewInt(mathrand.Int63()), 1017 NotBefore: time.Now().Add(-30 * time.Second), 1018 NotAfter: time.Now().Add(262980 * time.Hour), 1019 BasicConstraintsValid: true, 1020 IsCA: true, 1021 } 1022 caBytes, err = x509.CreateCertificate(rand.Reader, caCertTemplate, caCertTemplate, caKey.Public(), caKey) 1023 if err != nil { 1024 t.Fatal(err) 1025 } 1026 } 1027 caCert, err := x509.ParseCertificate(caBytes) 1028 if err != nil { 1029 t.Fatal(err) 1030 } 1031 testCluster.CACert = caCert 1032 testCluster.CACertBytes = caBytes 1033 testCluster.RootCAs = x509.NewCertPool() 1034 testCluster.RootCAs.AddCert(caCert) 1035 caCertPEMBlock := &pem.Block{ 1036 Type: "CERTIFICATE", 1037 Bytes: caBytes, 1038 } 1039 testCluster.CACertPEM = pem.EncodeToMemory(caCertPEMBlock) 1040 testCluster.CACertPEMFile = filepath.Join(testCluster.TempDir, "ca_cert.pem") 1041 err = ioutil.WriteFile(testCluster.CACertPEMFile, testCluster.CACertPEM, 0755) 1042 if err != nil { 1043 t.Fatal(err) 1044 } 1045 marshaledCAKey, err := x509.MarshalECPrivateKey(caKey) 1046 if err != nil { 1047 t.Fatal(err) 1048 } 1049 caKeyPEMBlock := &pem.Block{ 1050 Type: "EC PRIVATE KEY", 1051 Bytes: marshaledCAKey, 1052 } 1053 testCluster.CAKeyPEM = pem.EncodeToMemory(caKeyPEMBlock) 1054 err = ioutil.WriteFile(filepath.Join(testCluster.TempDir, "ca_key.pem"), testCluster.CAKeyPEM, 0755) 1055 if err != nil { 1056 t.Fatal(err) 1057 } 1058 1059 var certInfoSlice []*certInfo 1060 1061 // 1062 // Certs generation 1063 // 1064 for i := 0; i < numCores; i++ { 1065 key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) 1066 if err != nil { 1067 t.Fatal(err) 1068 } 1069 certTemplate := &x509.Certificate{ 1070 Subject: pkix.Name{ 1071 CommonName: "localhost", 1072 }, 1073 DNSNames: []string{"localhost"}, 1074 IPAddresses: certIPs, 1075 
ExtKeyUsage: []x509.ExtKeyUsage{ 1076 x509.ExtKeyUsageServerAuth, 1077 x509.ExtKeyUsageClientAuth, 1078 }, 1079 KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageKeyAgreement, 1080 SerialNumber: big.NewInt(mathrand.Int63()), 1081 NotBefore: time.Now().Add(-30 * time.Second), 1082 NotAfter: time.Now().Add(262980 * time.Hour), 1083 } 1084 certBytes, err := x509.CreateCertificate(rand.Reader, certTemplate, caCert, key.Public(), caKey) 1085 if err != nil { 1086 t.Fatal(err) 1087 } 1088 cert, err := x509.ParseCertificate(certBytes) 1089 if err != nil { 1090 t.Fatal(err) 1091 } 1092 certPEMBlock := &pem.Block{ 1093 Type: "CERTIFICATE", 1094 Bytes: certBytes, 1095 } 1096 certPEM := pem.EncodeToMemory(certPEMBlock) 1097 marshaledKey, err := x509.MarshalECPrivateKey(key) 1098 if err != nil { 1099 t.Fatal(err) 1100 } 1101 keyPEMBlock := &pem.Block{ 1102 Type: "EC PRIVATE KEY", 1103 Bytes: marshaledKey, 1104 } 1105 keyPEM := pem.EncodeToMemory(keyPEMBlock) 1106 1107 certInfoSlice = append(certInfoSlice, &certInfo{ 1108 cert: cert, 1109 certPEM: certPEM, 1110 certBytes: certBytes, 1111 key: key, 1112 keyPEM: keyPEM, 1113 }) 1114 } 1115 1116 // 1117 // Listener setup 1118 // 1119 logger := logging.NewVaultLogger(log.Trace) 1120 ports := make([]int, numCores) 1121 if baseAddr != nil { 1122 for i := 0; i < numCores; i++ { 1123 ports[i] = baseAddr.Port + i 1124 } 1125 } else { 1126 baseAddr = &net.TCPAddr{ 1127 IP: net.ParseIP("127.0.0.1"), 1128 Port: 0, 1129 } 1130 } 1131 1132 listeners := [][]*TestListener{} 1133 servers := []*http.Server{} 1134 handlers := []http.Handler{} 1135 tlsConfigs := []*tls.Config{} 1136 certGetters := []*reload.CertificateGetter{} 1137 for i := 0; i < numCores; i++ { 1138 baseAddr.Port = ports[i] 1139 ln, err := net.ListenTCP("tcp", baseAddr) 1140 if err != nil { 1141 t.Fatal(err) 1142 } 1143 certFile := filepath.Join(testCluster.TempDir, fmt.Sprintf("node%d_port_%d_cert.pem", i+1, ln.Addr().(*net.TCPAddr).Port)) 
1144 keyFile := filepath.Join(testCluster.TempDir, fmt.Sprintf("node%d_port_%d_key.pem", i+1, ln.Addr().(*net.TCPAddr).Port)) 1145 err = ioutil.WriteFile(certFile, certInfoSlice[i].certPEM, 0755) 1146 if err != nil { 1147 t.Fatal(err) 1148 } 1149 err = ioutil.WriteFile(keyFile, certInfoSlice[i].keyPEM, 0755) 1150 if err != nil { 1151 t.Fatal(err) 1152 } 1153 tlsCert, err := tls.X509KeyPair(certInfoSlice[i].certPEM, certInfoSlice[i].keyPEM) 1154 if err != nil { 1155 t.Fatal(err) 1156 } 1157 certGetter := reload.NewCertificateGetter(certFile, keyFile, "") 1158 certGetters = append(certGetters, certGetter) 1159 tlsConfig := &tls.Config{ 1160 Certificates: []tls.Certificate{tlsCert}, 1161 RootCAs: testCluster.RootCAs, 1162 ClientCAs: testCluster.RootCAs, 1163 ClientAuth: tls.RequestClientCert, 1164 NextProtos: []string{"h2", "http/1.1"}, 1165 GetCertificate: certGetter.GetCertificate, 1166 } 1167 tlsConfig.BuildNameToCertificate() 1168 tlsConfigs = append(tlsConfigs, tlsConfig) 1169 lns := []*TestListener{&TestListener{ 1170 Listener: tls.NewListener(ln, tlsConfig), 1171 Address: ln.Addr().(*net.TCPAddr), 1172 }, 1173 } 1174 listeners = append(listeners, lns) 1175 var handler http.Handler = http.NewServeMux() 1176 handlers = append(handlers, handler) 1177 server := &http.Server{ 1178 Handler: handler, 1179 ErrorLog: logger.StandardLogger(nil), 1180 } 1181 servers = append(servers, server) 1182 } 1183 1184 // Create three cores with the same physical and different redirect/cluster 1185 // addrs. 1186 // N.B.: On OSX, instead of random ports, it assigns new ports to new 1187 // listeners sequentially. Aside from being a bad idea in a security sense, 1188 // it also broke tests that assumed it was OK to just use the port above 1189 // the redirect addr. This has now been changed to 105 ports above, but if 1190 // we ever do more than three nodes in a cluster it may need to be bumped. 
1191 // Note: it's 105 so that we don't conflict with a running Consul by 1192 // default. 1193 coreConfig := &CoreConfig{ 1194 LogicalBackends: make(map[string]logical.Factory), 1195 CredentialBackends: make(map[string]logical.Factory), 1196 AuditBackends: make(map[string]audit.Factory), 1197 RedirectAddr: fmt.Sprintf("https://127.0.0.1:%d", listeners[0][0].Address.Port), 1198 ClusterAddr: fmt.Sprintf("https://127.0.0.1:%d", listeners[0][0].Address.Port+105), 1199 DisableMlock: true, 1200 EnableUI: true, 1201 EnableRaw: true, 1202 BuiltinRegistry: NewMockBuiltinRegistry(), 1203 } 1204 1205 if base != nil { 1206 coreConfig.DisableCache = base.DisableCache 1207 coreConfig.EnableUI = base.EnableUI 1208 coreConfig.DefaultLeaseTTL = base.DefaultLeaseTTL 1209 coreConfig.MaxLeaseTTL = base.MaxLeaseTTL 1210 coreConfig.CacheSize = base.CacheSize 1211 coreConfig.PluginDirectory = base.PluginDirectory 1212 coreConfig.Seal = base.Seal 1213 coreConfig.DevToken = base.DevToken 1214 coreConfig.EnableRaw = base.EnableRaw 1215 coreConfig.DisableSealWrap = base.DisableSealWrap 1216 coreConfig.DevLicenseDuration = base.DevLicenseDuration 1217 coreConfig.DisableCache = base.DisableCache 1218 if base.BuiltinRegistry != nil { 1219 coreConfig.BuiltinRegistry = base.BuiltinRegistry 1220 } 1221 1222 if !coreConfig.DisableMlock { 1223 base.DisableMlock = false 1224 } 1225 1226 if base.Physical != nil { 1227 coreConfig.Physical = base.Physical 1228 } 1229 1230 if base.HAPhysical != nil { 1231 coreConfig.HAPhysical = base.HAPhysical 1232 } 1233 1234 // Used to set something non-working to test fallback 1235 switch base.ClusterAddr { 1236 case "empty": 1237 coreConfig.ClusterAddr = "" 1238 case "": 1239 default: 1240 coreConfig.ClusterAddr = base.ClusterAddr 1241 } 1242 1243 if base.LogicalBackends != nil { 1244 for k, v := range base.LogicalBackends { 1245 coreConfig.LogicalBackends[k] = v 1246 } 1247 } 1248 if base.CredentialBackends != nil { 1249 for k, v := range base.CredentialBackends { 
1250 coreConfig.CredentialBackends[k] = v 1251 } 1252 } 1253 if base.AuditBackends != nil { 1254 for k, v := range base.AuditBackends { 1255 coreConfig.AuditBackends[k] = v 1256 } 1257 } 1258 if base.Logger != nil { 1259 coreConfig.Logger = base.Logger 1260 } 1261 1262 coreConfig.ClusterCipherSuites = base.ClusterCipherSuites 1263 1264 coreConfig.DisableCache = base.DisableCache 1265 1266 coreConfig.DevToken = base.DevToken 1267 } 1268 1269 if coreConfig.Physical == nil { 1270 coreConfig.Physical, err = physInmem.NewInmem(nil, logger) 1271 if err != nil { 1272 t.Fatal(err) 1273 } 1274 } 1275 if coreConfig.HAPhysical == nil { 1276 haPhys, err := physInmem.NewInmemHA(nil, logger) 1277 if err != nil { 1278 t.Fatal(err) 1279 } 1280 coreConfig.HAPhysical = haPhys.(physical.HABackend) 1281 } 1282 1283 pubKey, priKey, err := testGenerateCoreKeys() 1284 if err != nil { 1285 t.Fatalf("err: %v", err) 1286 } 1287 1288 cores := []*Core{} 1289 coreConfigs := []*CoreConfig{} 1290 for i := 0; i < numCores; i++ { 1291 localConfig := *coreConfig 1292 localConfig.RedirectAddr = fmt.Sprintf("https://127.0.0.1:%d", listeners[i][0].Address.Port) 1293 if localConfig.ClusterAddr != "" { 1294 localConfig.ClusterAddr = fmt.Sprintf("https://127.0.0.1:%d", listeners[i][0].Address.Port+105) 1295 } 1296 1297 // if opts.SealFunc is provided, use that to generate a seal for the config instead 1298 if opts != nil && opts.SealFunc != nil { 1299 localConfig.Seal = opts.SealFunc() 1300 } 1301 1302 if opts != nil && opts.Logger != nil { 1303 localConfig.Logger = opts.Logger.Named(fmt.Sprintf("core%d", i)) 1304 } 1305 1306 localConfig.LicensingConfig = testGetLicensingConfig(pubKey) 1307 1308 c, err := NewCore(&localConfig) 1309 if err != nil { 1310 t.Fatalf("err: %v", err) 1311 } 1312 cores = append(cores, c) 1313 coreConfigs = append(coreConfigs, &localConfig) 1314 if opts != nil && opts.HandlerFunc != nil { 1315 handlers[i] = opts.HandlerFunc(&HandlerProperties{ 1316 Core: c, 1317 
MaxRequestDuration: DefaultMaxRequestDuration, 1318 }) 1319 servers[i].Handler = handlers[i] 1320 } 1321 1322 // Set this in case the Seal was manually set before the core was 1323 // created 1324 if localConfig.Seal != nil { 1325 localConfig.Seal.SetCore(c) 1326 } 1327 } 1328 1329 // 1330 // Clustering setup 1331 // 1332 clusterAddrGen := func(lns []*TestListener) []*net.TCPAddr { 1333 ret := make([]*net.TCPAddr, len(lns)) 1334 for i, ln := range lns { 1335 ret[i] = &net.TCPAddr{ 1336 IP: ln.Address.IP, 1337 Port: ln.Address.Port + 105, 1338 } 1339 } 1340 return ret 1341 } 1342 1343 for i := 0; i < numCores; i++ { 1344 if coreConfigs[i].ClusterAddr != "" { 1345 cores[i].SetClusterListenerAddrs(clusterAddrGen(listeners[i])) 1346 cores[i].SetClusterHandler(handlers[i]) 1347 } 1348 } 1349 1350 if opts == nil || !opts.SkipInit { 1351 bKeys, rKeys, root := TestCoreInitClusterWrapperSetup(t, cores[0], clusterAddrGen(listeners[0]), handlers[0]) 1352 barrierKeys, _ := copystructure.Copy(bKeys) 1353 testCluster.BarrierKeys = barrierKeys.([][]byte) 1354 recoveryKeys, _ := copystructure.Copy(rKeys) 1355 testCluster.RecoveryKeys = recoveryKeys.([][]byte) 1356 testCluster.RootToken = root 1357 1358 // Write root token and barrier keys 1359 err = ioutil.WriteFile(filepath.Join(testCluster.TempDir, "root_token"), []byte(root), 0755) 1360 if err != nil { 1361 t.Fatal(err) 1362 } 1363 var buf bytes.Buffer 1364 for i, key := range testCluster.BarrierKeys { 1365 buf.Write([]byte(base64.StdEncoding.EncodeToString(key))) 1366 if i < len(testCluster.BarrierKeys)-1 { 1367 buf.WriteRune('\n') 1368 } 1369 } 1370 err = ioutil.WriteFile(filepath.Join(testCluster.TempDir, "barrier_keys"), buf.Bytes(), 0755) 1371 if err != nil { 1372 t.Fatal(err) 1373 } 1374 for i, key := range testCluster.RecoveryKeys { 1375 buf.Write([]byte(base64.StdEncoding.EncodeToString(key))) 1376 if i < len(testCluster.RecoveryKeys)-1 { 1377 buf.WriteRune('\n') 1378 } 1379 } 1380 err = 
ioutil.WriteFile(filepath.Join(testCluster.TempDir, "recovery_keys"), buf.Bytes(), 0755) 1381 if err != nil { 1382 t.Fatal(err) 1383 } 1384 1385 // Unseal first core 1386 for _, key := range bKeys { 1387 if _, err := cores[0].Unseal(TestKeyCopy(key)); err != nil { 1388 t.Fatalf("unseal err: %s", err) 1389 } 1390 } 1391 1392 ctx := context.Background() 1393 1394 // If stored keys is supported, the above will no no-op, so trigger auto-unseal 1395 // using stored keys to try to unseal 1396 if err := cores[0].UnsealWithStoredKeys(ctx); err != nil { 1397 t.Fatal(err) 1398 } 1399 1400 // Verify unsealed 1401 if cores[0].Sealed() { 1402 t.Fatal("should not be sealed") 1403 } 1404 1405 TestWaitActive(t, cores[0]) 1406 1407 // Unseal other cores unless otherwise specified 1408 if (opts == nil || !opts.KeepStandbysSealed) && numCores > 1 { 1409 for i := 1; i < numCores; i++ { 1410 for _, key := range bKeys { 1411 if _, err := cores[i].Unseal(TestKeyCopy(key)); err != nil { 1412 t.Fatalf("unseal err: %s", err) 1413 } 1414 } 1415 1416 // If stored keys is supported, the above will no no-op, so trigger auto-unseal 1417 // using stored keys 1418 if err := cores[i].UnsealWithStoredKeys(ctx); err != nil { 1419 t.Fatal(err) 1420 } 1421 } 1422 1423 // Let them come fully up to standby 1424 time.Sleep(2 * time.Second) 1425 1426 // Ensure cluster connection info is populated. 1427 // Other cores should not come up as leaders. 
1428 for i := 1; i < numCores; i++ { 1429 isLeader, _, _, err := cores[i].Leader() 1430 if err != nil { 1431 t.Fatal(err) 1432 } 1433 if isLeader { 1434 t.Fatalf("core[%d] should not be leader", i) 1435 } 1436 } 1437 } 1438 1439 // 1440 // Set test cluster core(s) and test cluster 1441 // 1442 cluster, err := cores[0].Cluster(context.Background()) 1443 if err != nil { 1444 t.Fatal(err) 1445 } 1446 testCluster.ID = cluster.ID 1447 } 1448 1449 getAPIClient := func(port int, tlsConfig *tls.Config) *api.Client { 1450 transport := cleanhttp.DefaultPooledTransport() 1451 transport.TLSClientConfig = tlsConfig.Clone() 1452 if err := http2.ConfigureTransport(transport); err != nil { 1453 t.Fatal(err) 1454 } 1455 client := &http.Client{ 1456 Transport: transport, 1457 CheckRedirect: func(*http.Request, []*http.Request) error { 1458 // This can of course be overridden per-test by using its own client 1459 return fmt.Errorf("redirects not allowed in these tests") 1460 }, 1461 } 1462 config := api.DefaultConfig() 1463 if config.Error != nil { 1464 t.Fatal(config.Error) 1465 } 1466 config.Address = fmt.Sprintf("https://127.0.0.1:%d", port) 1467 config.HttpClient = client 1468 config.MaxRetries = 0 1469 apiClient, err := api.NewClient(config) 1470 if err != nil { 1471 t.Fatal(err) 1472 } 1473 if opts == nil || !opts.SkipInit { 1474 apiClient.SetToken(testCluster.RootToken) 1475 } 1476 return apiClient 1477 } 1478 1479 var ret []*TestClusterCore 1480 for i := 0; i < numCores; i++ { 1481 tcc := &TestClusterCore{ 1482 Core: cores[i], 1483 CoreConfig: coreConfigs[i], 1484 ServerKey: certInfoSlice[i].key, 1485 ServerKeyPEM: certInfoSlice[i].keyPEM, 1486 ServerCert: certInfoSlice[i].cert, 1487 ServerCertBytes: certInfoSlice[i].certBytes, 1488 ServerCertPEM: certInfoSlice[i].certPEM, 1489 Listeners: listeners[i], 1490 Handler: handlers[i], 1491 Server: servers[i], 1492 TLSConfig: tlsConfigs[i], 1493 Client: getAPIClient(listeners[i][0].Address.Port, tlsConfigs[i]), 1494 } 1495 
tcc.ReloadFuncs = &cores[i].reloadFuncs
		tcc.ReloadFuncsLock = &cores[i].reloadFuncsLock
		tcc.ReloadFuncsLock.Lock()
		(*tcc.ReloadFuncs)["listener|tcp"] = []reload.ReloadFunc{certGetters[i].Reload}
		tcc.ReloadFuncsLock.Unlock()

		testAdjustTestCore(base, tcc)

		ret = append(ret, tcc)
	}

	testCluster.Cores = ret

	testExtraClusterCoresTestSetup(t, priKey, testCluster.Cores)

	return &testCluster
}

// NewMockBuiltinRegistry returns a builtin-plugin registry for tests. It
// registers exactly two database plugins: "mysql-database-plugin" and
// "postgresql-database-plugin".
func NewMockBuiltinRegistry() *mockBuiltinRegistry {
	plugins := map[string]consts.PluginType{
		"mysql-database-plugin":      consts.PluginTypeDatabase,
		"postgresql-database-plugin": consts.PluginTypeDatabase,
	}
	return &mockBuiltinRegistry{forTesting: plugins}
}

// mockBuiltinRegistry is a minimal stand-in for the real builtin plugin
// registry, used when building test clusters.
type mockBuiltinRegistry struct {
	// forTesting maps every plugin name this mock can serve to its plugin type.
	forTesting map[string]consts.PluginType
}

// Get returns a factory for the named plugin, but only when that name is
// registered in the mock under the requested plugin type. Otherwise it
// returns (nil, false). "postgresql-database-plugin" yields the PostgreSQL
// factory. Any other registered name (currently only
// "mysql-database-plugin") yields a MySQL factory.
func (m *mockBuiltinRegistry) Get(name string, pluginType consts.PluginType) (func() (interface{}, error), bool) {
	if registered, found := m.forTesting[name]; !found || registered != pluginType {
		return nil, false
	}

	switch name {
	case "postgresql-database-plugin":
		return dbPostgres.New, true
	default:
		// NOTE(review): the three arguments are dbMysql's own length
		// constants; their exact meaning is defined by dbMysql.New.
		return dbMysql.New(dbMysql.MetadataLen, dbMysql.MetadataLen, dbMysql.UsernameLen), true
	}
}

// Keys only supports getting a realistic list of the keys for database plugins.
// For any other plugin type it returns an empty, non-nil slice.
func (m *mockBuiltinRegistry) Keys(pluginType consts.PluginType) []string {
	if pluginType == consts.PluginTypeDatabase {
		// This is a hard-coded reproduction of the db plugin keys in
		// helper/builtinplugins/registry.go. The registry isn't directly used
		// because it causes import cycles. The list includes names that Get
		// cannot serve.
		return []string{
			"mysql-database-plugin",
			"mysql-aurora-database-plugin",
			"mysql-rds-database-plugin",
			"mysql-legacy-database-plugin",
			"postgresql-database-plugin",
			"mssql-database-plugin",
			"cassandra-database-plugin",
			"mongodb-database-plugin",
			"hana-database-plugin",
			"influxdb-database-plugin",
		}
	}
	return []string{}
}

// Contains always reports false, regardless of name or plugin type. This
// includes the names that Get can serve.
func (m *mockBuiltinRegistry) Contains(name string, pluginType consts.PluginType) bool {
	return false
}