1package vault 2 3import ( 4 "bytes" 5 "context" 6 "crypto/ecdsa" 7 "crypto/elliptic" 8 "crypto/rand" 9 "crypto/sha256" 10 "crypto/tls" 11 "crypto/x509" 12 "crypto/x509/pkix" 13 "encoding/base64" 14 "encoding/pem" 15 "errors" 16 "fmt" 17 "io" 18 "io/ioutil" 19 "math/big" 20 mathrand "math/rand" 21 "net" 22 "net/http" 23 "os" 24 "os/exec" 25 "path/filepath" 26 "sync" 27 "sync/atomic" 28 "time" 29 30 "github.com/armon/go-metrics" 31 hclog "github.com/hashicorp/go-hclog" 32 log "github.com/hashicorp/go-hclog" 33 "github.com/hashicorp/vault/helper/metricsutil" 34 "github.com/mitchellh/copystructure" 35 36 "golang.org/x/crypto/ed25519" 37 "golang.org/x/crypto/ssh" 38 "golang.org/x/net/http2" 39 40 cleanhttp "github.com/hashicorp/go-cleanhttp" 41 "github.com/hashicorp/vault/api" 42 "github.com/hashicorp/vault/audit" 43 "github.com/hashicorp/vault/command/server" 44 "github.com/hashicorp/vault/helper/namespace" 45 "github.com/hashicorp/vault/helper/reload" 46 dbMysql "github.com/hashicorp/vault/plugins/database/mysql" 47 dbPostgres "github.com/hashicorp/vault/plugins/database/postgresql" 48 "github.com/hashicorp/vault/sdk/framework" 49 "github.com/hashicorp/vault/sdk/helper/consts" 50 "github.com/hashicorp/vault/sdk/helper/logging" 51 "github.com/hashicorp/vault/sdk/helper/salt" 52 "github.com/hashicorp/vault/sdk/logical" 53 "github.com/hashicorp/vault/sdk/physical" 54 testing "github.com/mitchellh/go-testing-interface" 55 56 physInmem "github.com/hashicorp/vault/sdk/physical/inmem" 57) 58 59// This file contains a number of methods that are useful for unit 60// tests within other packages. 61 62const ( 63 testSharedPublicKey = ` 64ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQC9i+hFxZHGo6KblVme4zrAcJstR6I0PTJozW286X4WyvPnkMYDQ5mnhEYC7UWCvjoTWbPEXPX7NjhRtwQTGD67bV+lrxgfyzK1JZbUXK4PwgKJvQD+XyyWYMzDgGSQY61KUSqCxymSm/9NZkPU3ElaQ9xQuTzPpztM4ROfb8f2Yv6/ZESZsTo0MTAkp8Pcy+WkioI/uJ1H7zqs0EA4OMY4aDJRu0UtP4rTVeYNEAuRXdX+eH4aW3KMvhzpFTjMbaJHJXlEeUm2SaX5TNQyTOvghCeQILfYIL/Ca2ij8iwCmulwdV6eQGfd4VDu40PvSnmfoaE38o6HaPnX0kUcnKiT 65` 66 testSharedPrivateKey = ` 67-----BEGIN RSA PRIVATE KEY----- 68MIIEogIBAAKCAQEAvYvoRcWRxqOim5VZnuM6wHCbLUeiND0yaM1tvOl+Fsrz55DG 69A0OZp4RGAu1Fgr46E1mzxFz1+zY4UbcEExg+u21fpa8YH8sytSWW1FyuD8ICib0A 70/l8slmDMw4BkkGOtSlEqgscpkpv/TWZD1NxJWkPcULk8z6c7TOETn2/H9mL+v2RE 71mbE6NDEwJKfD3MvlpIqCP7idR+86rNBAODjGOGgyUbtFLT+K01XmDRALkV3V/nh+ 72GltyjL4c6RU4zG2iRyV5RHlJtkml+UzUMkzr4IQnkCC32CC/wmtoo/IsAprpcHVe 73nkBn3eFQ7uND70p5n6GhN/KOh2j519JFHJyokwIDAQABAoIBAHX7VOvBC3kCN9/x 74+aPdup84OE7Z7MvpX6w+WlUhXVugnmsAAVDczhKoUc/WktLLx2huCGhsmKvyVuH+ 75MioUiE+vx75gm3qGx5xbtmOfALVMRLopjCnJYf6EaFA0ZeQ+NwowNW7Lu0PHmAU8 76Z3JiX8IwxTz14DU82buDyewO7v+cEr97AnERe3PUcSTDoUXNaoNxjNpEJkKREY6h 774hAY676RT/GsRcQ8tqe/rnCqPHNd7JGqL+207FK4tJw7daoBjQyijWuB7K5chSal 78oPInylM6b13ASXuOAOT/2uSUBWmFVCZPDCmnZxy2SdnJGbsJAMl7Ma3MUlaGvVI+ 79Tfh1aQkCgYEA4JlNOabTb3z42wz6mz+Nz3JRwbawD+PJXOk5JsSnV7DtPtfgkK9y 806FTQdhnozGWShAvJvc+C4QAihs9AlHXoaBY5bEU7R/8UK/pSqwzam+MmxmhVDV7G 81IMQPV0FteoXTaJSikhZ88mETTegI2mik+zleBpVxvfdhE5TR+lq8Br0CgYEA2AwJ 82CUD5CYUSj09PluR0HHqamWOrJkKPFPwa+5eiTTCzfBBxImYZh7nXnWuoviXC0sg2 83AuvCW+uZ48ygv/D8gcz3j1JfbErKZJuV+TotK9rRtNIF5Ub7qysP7UjyI7zCssVM 84kuDd9LfRXaB/qGAHNkcDA8NxmHW3gpln4CFdSY8CgYANs4xwfercHEWaJ1qKagAe 85rZyrMpffAEhicJ/Z65lB0jtG4CiE6w8ZeUMWUVJQVcnwYD+4YpZbX4S7sJ0B8Ydy 86AhkSr86D/92dKTIt2STk6aCN7gNyQ1vW198PtaAWH1/cO2UHgHOy3ZUt5X/Uwxl9 87cex4flln+1Viumts2GgsCQKBgCJH7psgSyPekK5auFdKEr5+Gc/jB8I/Z3K9+g4X 885nH3G1PBTCJYLw7hRzw8W/8oALzvddqKzEFHphiGXK94Lqjt/A4q1OdbCrhiE68D 89My21P/dAKB1UYRSs9Y8CNyHCjuZM9jSMJ8vv6vG/SOJPsnVDWVAckAbQDvlTHC9t 90O98zAoGAcbW6uFDkrv0XMCpB9Su3KaNXOR0wzag+WIFQRXCcoTvxVi9iYfUReQPi 91oOyBJU/HMVvBfv4g+OVFLVgSwwm6owwsouZ0+D/LasbuHqYyqYqdyPJQYzWA2Y+F 92+B6f4RoPdSXj24JHPg/ioRxjaj094UXJxua2yfkcecGNEuBQHSs= 93-----END RSA PRIVATE KEY----- 94` 95) 96 97// TestCore returns a pure in-memory, uninitialized core for testing. 98func TestCore(t testing.T) *Core { 99 return TestCoreWithSeal(t, nil, false) 100} 101 102// TestCoreRaw returns a pure in-memory, uninitialized core for testing. The raw 103// storage endpoints are enabled with this core. 104func TestCoreRaw(t testing.T) *Core { 105 return TestCoreWithSeal(t, nil, true) 106} 107 108// TestCoreNewSeal returns a pure in-memory, uninitialized core with 109// the new seal configuration. 110func TestCoreNewSeal(t testing.T) *Core { 111 seal := NewTestSeal(t, nil) 112 return TestCoreWithSeal(t, seal, false) 113} 114 115// TestCoreWithConfig returns a pure in-memory, uninitialized core with the 116// specified core configurations overridden for testing. 117func TestCoreWithConfig(t testing.T, conf *CoreConfig) *Core { 118 return TestCoreWithSealAndUI(t, conf) 119} 120 121// TestCoreWithSeal returns a pure in-memory, uninitialized core with the 122// specified seal for testing. 123func TestCoreWithSeal(t testing.T, testSeal Seal, enableRaw bool) *Core { 124 conf := &CoreConfig{ 125 Seal: testSeal, 126 EnableUI: false, 127 EnableRaw: enableRaw, 128 BuiltinRegistry: NewMockBuiltinRegistry(), 129 } 130 return TestCoreWithSealAndUI(t, conf) 131} 132 133func TestCoreUI(t testing.T, enableUI bool) *Core { 134 conf := &CoreConfig{ 135 EnableUI: enableUI, 136 EnableRaw: true, 137 BuiltinRegistry: NewMockBuiltinRegistry(), 138 } 139 return TestCoreWithSealAndUI(t, conf) 140} 141 142func TestCoreWithSealAndUI(t testing.T, opts *CoreConfig) *Core { 143 logger := logging.NewVaultLogger(log.Trace) 144 physicalBackend, err := physInmem.NewInmem(nil, logger) 145 if err != nil { 146 t.Fatal(err) 147 } 148 149 errInjector := physical.NewErrorInjector(physicalBackend, 0, logger) 150 151 // Start off with base test core config 152 conf := testCoreConfig(t, errInjector, logger) 153 154 // Override config values with ones that gets passed in 155 conf.EnableUI = opts.EnableUI 156 conf.EnableRaw = opts.EnableRaw 157 conf.Seal = opts.Seal 158 conf.LicensingConfig = opts.LicensingConfig 159 conf.DisableKeyEncodingChecks = opts.DisableKeyEncodingChecks 160 conf.MetricsHelper = opts.MetricsHelper 161 162 if opts.Logger != nil { 163 conf.Logger = opts.Logger 164 } 165 166 for k, v := range opts.LogicalBackends { 167 conf.LogicalBackends[k] = v 168 } 169 for k, v := range opts.CredentialBackends { 170 conf.CredentialBackends[k] = v 171 } 172 173 for k, v := range opts.AuditBackends { 174 conf.AuditBackends[k] = v 175 } 176 177 c, err := NewCore(conf) 178 if err != nil { 179 t.Fatalf("err: %s", err) 180 } 181 182 return c 183} 184 185func testCoreConfig(t testing.T, physicalBackend physical.Backend, logger log.Logger) *CoreConfig { 186 t.Helper() 187 noopAudits := map[string]audit.Factory{ 188 "noop": func(_ context.Context, config *audit.BackendConfig) (audit.Backend, error) { 189 view := &logical.InmemStorage{} 190 view.Put(context.Background(), &logical.StorageEntry{ 191 Key: "salt", 192 Value: []byte("foo"), 193 }) 194 config.SaltConfig = &salt.Config{ 195 HMAC: sha256.New, 196 HMACType: "hmac-sha256", 197 } 198 config.SaltView = view 199 200 n := &noopAudit{ 201 Config: config, 202 } 203 n.formatter.AuditFormatWriter = &audit.JSONFormatWriter{ 204 SaltFunc: n.Salt, 205 } 206 return n, nil 207 }, 208 } 209 210 noopBackends := make(map[string]logical.Factory) 211 noopBackends["noop"] = func(ctx context.Context, config *logical.BackendConfig) (logical.Backend, error) { 212 b := new(framework.Backend) 213 b.Setup(ctx, config) 214 b.BackendType = logical.TypeCredential 215 return b, nil 216 } 217 noopBackends["http"] = func(ctx context.Context, config *logical.BackendConfig) (logical.Backend, error) { 218 return new(rawHTTP), nil 219 } 220 221 credentialBackends := make(map[string]logical.Factory) 222 for backendName, backendFactory := range noopBackends { 223 credentialBackends[backendName] = backendFactory 224 } 225 for backendName, backendFactory := range testCredentialBackends { 226 credentialBackends[backendName] = backendFactory 227 } 228 229 logicalBackends := make(map[string]logical.Factory) 230 for backendName, backendFactory := range noopBackends { 231 logicalBackends[backendName] = backendFactory 232 } 233 234 logicalBackends["kv"] = LeasedPassthroughBackendFactory 235 for backendName, backendFactory := range testLogicalBackends { 236 logicalBackends[backendName] = backendFactory 237 } 238 239 conf := &CoreConfig{ 240 Physical: physicalBackend, 241 AuditBackends: noopAudits, 242 LogicalBackends: logicalBackends, 243 CredentialBackends: credentialBackends, 244 DisableMlock: true, 245 Logger: logger, 246 BuiltinRegistry: NewMockBuiltinRegistry(), 247 } 248 249 return conf 250} 251 252// TestCoreInit initializes the core with a single key, and returns 253// the key that must be used to unseal the core and a root token. 254func TestCoreInit(t testing.T, core *Core) ([][]byte, string) { 255 t.Helper() 256 secretShares, _, root := TestCoreInitClusterWrapperSetup(t, core, nil) 257 return secretShares, root 258} 259 260func TestCoreInitClusterWrapperSetup(t testing.T, core *Core, handler http.Handler) ([][]byte, [][]byte, string) { 261 t.Helper() 262 core.SetClusterHandler(handler) 263 264 barrierConfig := &SealConfig{ 265 SecretShares: 3, 266 SecretThreshold: 3, 267 } 268 269 switch core.seal.StoredKeysSupported() { 270 case StoredKeysNotSupported: 271 barrierConfig.StoredShares = 0 272 default: 273 barrierConfig.StoredShares = 1 274 } 275 276 recoveryConfig := &SealConfig{ 277 SecretShares: 3, 278 SecretThreshold: 3, 279 } 280 281 initParams := &InitParams{ 282 BarrierConfig: barrierConfig, 283 RecoveryConfig: recoveryConfig, 284 } 285 if core.seal.StoredKeysSupported() == StoredKeysNotSupported { 286 initParams.LegacyShamirSeal = true 287 } 288 result, err := core.Initialize(context.Background(), initParams) 289 if err != nil { 290 t.Fatalf("err: %s", err) 291 } 292 return result.SecretShares, result.RecoveryShares, result.RootToken 293} 294 295func TestCoreUnseal(core *Core, key []byte) (bool, error) { 296 return core.Unseal(key) 297} 298 299func TestCoreUnsealWithRecoveryKeys(core *Core, key []byte) (bool, error) { 300 return core.UnsealWithRecoveryKeys(key) 301} 302 303// TestCoreUnsealed returns a pure in-memory core that is already 304// initialized and unsealed. 305func TestCoreUnsealed(t testing.T) (*Core, [][]byte, string) { 306 t.Helper() 307 core := TestCore(t) 308 return testCoreUnsealed(t, core) 309} 310 311// TestCoreUnsealedRaw returns a pure in-memory core that is already 312// initialized, unsealed, and with raw endpoints enabled. 313func TestCoreUnsealedRaw(t testing.T) (*Core, [][]byte, string) { 314 t.Helper() 315 core := TestCoreRaw(t) 316 return testCoreUnsealed(t, core) 317} 318 319// TestCoreUnsealedWithConfig returns a pure in-memory core that is already 320// initialized, unsealed, with the any provided core config values overridden. 321func TestCoreUnsealedWithConfig(t testing.T, conf *CoreConfig) (*Core, [][]byte, string) { 322 t.Helper() 323 core := TestCoreWithConfig(t, conf) 324 return testCoreUnsealed(t, core) 325} 326 327func testCoreUnsealed(t testing.T, core *Core) (*Core, [][]byte, string) { 328 t.Helper() 329 keys, token := TestCoreInit(t, core) 330 for _, key := range keys { 331 if _, err := TestCoreUnseal(core, TestKeyCopy(key)); err != nil { 332 t.Fatalf("unseal err: %s", err) 333 } 334 } 335 336 if core.Sealed() { 337 t.Fatal("should not be sealed") 338 } 339 340 testCoreAddSecretMount(t, core, token) 341 342 return core, keys, token 343} 344 345func testCoreAddSecretMount(t testing.T, core *Core, token string) { 346 kvReq := &logical.Request{ 347 Operation: logical.UpdateOperation, 348 ClientToken: token, 349 Path: "sys/mounts/secret", 350 Data: map[string]interface{}{ 351 "type": "kv", 352 "path": "secret/", 353 "description": "key/value secret storage", 354 "options": map[string]string{ 355 "version": "1", 356 }, 357 }, 358 } 359 resp, err := core.HandleRequest(namespace.RootContext(nil), kvReq) 360 if err != nil { 361 t.Fatal(err) 362 } 363 if resp.IsError() { 364 t.Fatal(err) 365 } 366 367} 368 369func TestCoreUnsealedBackend(t testing.T, backend physical.Backend) (*Core, [][]byte, string) { 370 t.Helper() 371 logger := logging.NewVaultLogger(log.Trace) 372 conf := testCoreConfig(t, backend, logger) 373 conf.Seal = NewTestSeal(t, nil) 374 375 core, err := NewCore(conf) 376 if err != nil { 377 t.Fatalf("err: %s", err) 378 } 379 380 keys, token := TestCoreInit(t, core) 381 for _, key := range keys { 382 if _, err := TestCoreUnseal(core, TestKeyCopy(key)); err != nil { 383 t.Fatalf("unseal err: %s", err) 384 } 385 } 386 387 if err := core.UnsealWithStoredKeys(context.Background()); err != nil { 388 t.Fatal(err) 389 } 390 391 if core.Sealed() { 392 t.Fatal("should not be sealed") 393 } 394 395 return core, keys, token 396} 397 398// TestKeyCopy is a silly little function to just copy the key so that 399// it can be used with Unseal easily. 400func TestKeyCopy(key []byte) []byte { 401 result := make([]byte, len(key)) 402 copy(result, key) 403 return result 404} 405 406func TestDynamicSystemView(c *Core) *dynamicSystemView { 407 me := &MountEntry{ 408 Config: MountConfig{ 409 DefaultLeaseTTL: 24 * time.Hour, 410 MaxLeaseTTL: 2 * 24 * time.Hour, 411 }, 412 } 413 414 return &dynamicSystemView{c, me} 415} 416 417// TestAddTestPlugin registers the testFunc as part of the plugin command to the 418// plugin catalog. If provided, uses tmpDir as the plugin directory. 419func TestAddTestPlugin(t testing.T, c *Core, name string, pluginType consts.PluginType, testFunc string, env []string, tempDir string) { 420 file, err := os.Open(os.Args[0]) 421 if err != nil { 422 t.Fatal(err) 423 } 424 defer file.Close() 425 426 dirPath := filepath.Dir(os.Args[0]) 427 fileName := filepath.Base(os.Args[0]) 428 429 if tempDir != "" { 430 fi, err := file.Stat() 431 if err != nil { 432 t.Fatal(err) 433 } 434 435 // Copy over the file to the temp dir 436 dst := filepath.Join(tempDir, fileName) 437 out, err := os.OpenFile(dst, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, fi.Mode()) 438 if err != nil { 439 t.Fatal(err) 440 } 441 defer out.Close() 442 443 if _, err = io.Copy(out, file); err != nil { 444 t.Fatal(err) 445 } 446 err = out.Sync() 447 if err != nil { 448 t.Fatal(err) 449 } 450 451 dirPath = tempDir 452 } 453 454 // Determine plugin directory full path, evaluating potential symlink path 455 fullPath, err := filepath.EvalSymlinks(dirPath) 456 if err != nil { 457 t.Fatal(err) 458 } 459 460 reader, err := os.Open(filepath.Join(fullPath, fileName)) 461 if err != nil { 462 t.Fatal(err) 463 } 464 defer reader.Close() 465 466 // Find out the sha256 467 hash := sha256.New() 468 469 _, err = io.Copy(hash, reader) 470 if err != nil { 471 t.Fatal(err) 472 } 473 474 sum := hash.Sum(nil) 475 476 // Set core's plugin directory and plugin catalog directory 477 c.pluginDirectory = fullPath 478 c.pluginCatalog.directory = fullPath 479 480 args := []string{fmt.Sprintf("--test.run=%s", testFunc)} 481 err = c.pluginCatalog.Set(context.Background(), name, pluginType, fileName, args, env, sum) 482 if err != nil { 483 t.Fatal(err) 484 } 485} 486 487var testLogicalBackends = map[string]logical.Factory{} 488var testCredentialBackends = map[string]logical.Factory{} 489 490// StartSSHHostTestServer starts the test server which responds to SSH 491// authentication. Used to test the SSH secret backend. 492func StartSSHHostTestServer() (string, error) { 493 pubKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(testSharedPublicKey)) 494 if err != nil { 495 return "", fmt.Errorf("error parsing public key") 496 } 497 serverConfig := &ssh.ServerConfig{ 498 PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { 499 if bytes.Compare(pubKey.Marshal(), key.Marshal()) == 0 { 500 return &ssh.Permissions{}, nil 501 } else { 502 return nil, fmt.Errorf("key does not match") 503 } 504 }, 505 } 506 signer, err := ssh.ParsePrivateKey([]byte(testSharedPrivateKey)) 507 if err != nil { 508 panic("Error parsing private key") 509 } 510 serverConfig.AddHostKey(signer) 511 512 soc, err := net.Listen("tcp", "127.0.0.1:0") 513 if err != nil { 514 return "", fmt.Errorf("error listening to connection") 515 } 516 517 go func() { 518 for { 519 conn, err := soc.Accept() 520 if err != nil { 521 panic(fmt.Sprintf("Error accepting incoming connection: %s", err)) 522 } 523 defer conn.Close() 524 sshConn, chanReqs, _, err := ssh.NewServerConn(conn, serverConfig) 525 if err != nil { 526 panic(fmt.Sprintf("Handshaking error: %v", err)) 527 } 528 529 go func() { 530 for chanReq := range chanReqs { 531 go func(chanReq ssh.NewChannel) { 532 if chanReq.ChannelType() != "session" { 533 chanReq.Reject(ssh.UnknownChannelType, "unknown channel type") 534 return 535 } 536 537 ch, requests, err := chanReq.Accept() 538 if err != nil { 539 panic(fmt.Sprintf("Error accepting channel: %s", err)) 540 } 541 542 go func(ch ssh.Channel, in <-chan *ssh.Request) { 543 for req := range in { 544 executeServerCommand(ch, req) 545 } 546 }(ch, requests) 547 }(chanReq) 548 } 549 sshConn.Close() 550 }() 551 } 552 }() 553 return soc.Addr().String(), nil 554} 555 556// This executes the commands requested to be run on the server. 557// Used to test the SSH secret backend. 558func executeServerCommand(ch ssh.Channel, req *ssh.Request) { 559 command := string(req.Payload[4:]) 560 cmd := exec.Command("/bin/bash", []string{"-c", command}...) 561 req.Reply(true, nil) 562 563 cmd.Stdout = ch 564 cmd.Stderr = ch 565 cmd.Stdin = ch 566 567 err := cmd.Start() 568 if err != nil { 569 panic(fmt.Sprintf("Error starting the command: '%s'", err)) 570 } 571 572 go func() { 573 _, err := cmd.Process.Wait() 574 if err != nil { 575 panic(fmt.Sprintf("Error while waiting for command to finish:'%s'", err)) 576 } 577 ch.Close() 578 }() 579} 580 581// This adds a credential backend for the test core. This needs to be 582// invoked before the test core is created. 583func AddTestCredentialBackend(name string, factory logical.Factory) error { 584 if name == "" { 585 return fmt.Errorf("missing backend name") 586 } 587 if factory == nil { 588 return fmt.Errorf("missing backend factory function") 589 } 590 testCredentialBackends[name] = factory 591 return nil 592} 593 594// This adds a logical backend for the test core. This needs to be 595// invoked before the test core is created. 596func AddTestLogicalBackend(name string, factory logical.Factory) error { 597 if name == "" { 598 return fmt.Errorf("missing backend name") 599 } 600 if factory == nil { 601 return fmt.Errorf("missing backend factory function") 602 } 603 testLogicalBackends[name] = factory 604 return nil 605} 606 607type noopAudit struct { 608 Config *audit.BackendConfig 609 salt *salt.Salt 610 saltMutex sync.RWMutex 611 formatter audit.AuditFormatter 612 records [][]byte 613 l sync.RWMutex 614} 615 616func (n *noopAudit) GetHash(ctx context.Context, data string) (string, error) { 617 salt, err := n.Salt(ctx) 618 if err != nil { 619 return "", err 620 } 621 return salt.GetIdentifiedHMAC(data), nil 622} 623 624func (n *noopAudit) LogRequest(ctx context.Context, in *logical.LogInput) error { 625 n.l.Lock() 626 defer n.l.Unlock() 627 var w bytes.Buffer 628 err := n.formatter.FormatRequest(ctx, &w, audit.FormatterConfig{}, in) 629 if err != nil { 630 return err 631 } 632 n.records = append(n.records, w.Bytes()) 633 return nil 634} 635 636func (n *noopAudit) LogResponse(ctx context.Context, in *logical.LogInput) error { 637 n.l.Lock() 638 defer n.l.Unlock() 639 var w bytes.Buffer 640 err := n.formatter.FormatResponse(ctx, &w, audit.FormatterConfig{}, in) 641 if err != nil { 642 return err 643 } 644 n.records = append(n.records, w.Bytes()) 645 return nil 646} 647 648func (n *noopAudit) Reload(_ context.Context) error { 649 return nil 650} 651 652func (n *noopAudit) Invalidate(_ context.Context) { 653 n.saltMutex.Lock() 654 defer n.saltMutex.Unlock() 655 n.salt = nil 656} 657 658func (n *noopAudit) Salt(ctx context.Context) (*salt.Salt, error) { 659 n.saltMutex.RLock() 660 if n.salt != nil { 661 defer n.saltMutex.RUnlock() 662 return n.salt, nil 663 } 664 n.saltMutex.RUnlock() 665 n.saltMutex.Lock() 666 defer n.saltMutex.Unlock() 667 if n.salt != nil { 668 return n.salt, nil 669 } 670 salt, err := salt.NewSalt(ctx, n.Config.SaltView, n.Config.SaltConfig) 671 if err != nil { 672 return nil, err 673 } 674 n.salt = salt 675 return salt, nil 676} 677 678func AddNoopAudit(conf *CoreConfig) { 679 conf.AuditBackends = map[string]audit.Factory{ 680 "noop": func(_ context.Context, config *audit.BackendConfig) (audit.Backend, error) { 681 view := &logical.InmemStorage{} 682 view.Put(context.Background(), &logical.StorageEntry{ 683 Key: "salt", 684 Value: []byte("foo"), 685 }) 686 n := &noopAudit{ 687 Config: config, 688 } 689 n.formatter.AuditFormatWriter = &audit.JSONFormatWriter{ 690 SaltFunc: n.Salt, 691 } 692 return n, nil 693 }, 694 } 695} 696 697type rawHTTP struct{} 698 699func (n *rawHTTP) HandleRequest(ctx context.Context, req *logical.Request) (*logical.Response, error) { 700 return &logical.Response{ 701 Data: map[string]interface{}{ 702 logical.HTTPStatusCode: 200, 703 logical.HTTPContentType: "plain/text", 704 logical.HTTPRawBody: []byte("hello world"), 705 }, 706 }, nil 707} 708 709func (n *rawHTTP) HandleExistenceCheck(ctx context.Context, req *logical.Request) (bool, bool, error) { 710 return false, false, nil 711} 712 713func (n *rawHTTP) SpecialPaths() *logical.Paths { 714 return &logical.Paths{Unauthenticated: []string{"*"}} 715} 716 717func (n *rawHTTP) System() logical.SystemView { 718 return logical.StaticSystemView{ 719 DefaultLeaseTTLVal: time.Hour * 24, 720 MaxLeaseTTLVal: time.Hour * 24 * 32, 721 } 722} 723 724func (n *rawHTTP) Logger() log.Logger { 725 return logging.NewVaultLogger(log.Trace) 726} 727 728func (n *rawHTTP) Cleanup(ctx context.Context) { 729 // noop 730} 731 732func (n *rawHTTP) Initialize(ctx context.Context, req *logical.InitializationRequest) error { 733 return nil 734} 735 736func (n *rawHTTP) InvalidateKey(context.Context, string) { 737 // noop 738} 739 740func (n *rawHTTP) Setup(ctx context.Context, config *logical.BackendConfig) error { 741 // noop 742 return nil 743} 744 745func (n *rawHTTP) Type() logical.BackendType { 746 return logical.TypeLogical 747} 748 749func GenerateRandBytes(length int) ([]byte, error) { 750 if length < 0 { 751 return nil, fmt.Errorf("length must be >= 0") 752 } 753 754 buf := make([]byte, length) 755 if length == 0 { 756 return buf, nil 757 } 758 759 n, err := rand.Read(buf) 760 if err != nil { 761 return nil, err 762 } 763 if n != length { 764 return nil, fmt.Errorf("unable to read %d bytes; only read %d", length, n) 765 } 766 767 return buf, nil 768} 769 770func TestWaitActive(t testing.T, core *Core) { 771 t.Helper() 772 if err := TestWaitActiveWithError(core); err != nil { 773 t.Fatal(err) 774 } 775} 776 777func TestWaitActiveWithError(core *Core) error { 778 start := time.Now() 779 var standby bool 780 var err error 781 for time.Now().Sub(start) < 30*time.Second { 782 standby, err = core.Standby() 783 if err != nil { 784 return err 785 } 786 if !standby { 787 break 788 } 789 } 790 if standby { 791 return errors.New("should not be in standby mode") 792 } 793 return nil 794} 795 796type TestCluster struct { 797 BarrierKeys [][]byte 798 RecoveryKeys [][]byte 799 CACert *x509.Certificate 800 CACertBytes []byte 801 CACertPEM []byte 802 CACertPEMFile string 803 CAKey *ecdsa.PrivateKey 804 CAKeyPEM []byte 805 Cores []*TestClusterCore 806 ID string 807 RootToken string 808 RootCAs *x509.CertPool 809 TempDir string 810 ClientAuthRequired bool 811 Logger log.Logger 812 CleanupFunc func() 813 SetupFunc func() 814} 815 816func (c *TestCluster) Start() { 817 for _, core := range c.Cores { 818 if core.Server != nil { 819 for _, ln := range core.Listeners { 820 go core.Server.Serve(ln) 821 } 822 } 823 } 824 if c.SetupFunc != nil { 825 c.SetupFunc() 826 } 827} 828 829// UnsealCores uses the cluster barrier keys to unseal the test cluster cores 830func (c *TestCluster) UnsealCores(t testing.T) { 831 t.Helper() 832 if err := c.UnsealCoresWithError(false); err != nil { 833 t.Fatal(err) 834 } 835} 836 837func (c *TestCluster) UnsealCoresWithError(useStoredKeys bool) error { 838 unseal := func(core *Core) error { 839 for _, key := range c.BarrierKeys { 840 if _, err := core.Unseal(TestKeyCopy(key)); err != nil { 841 return err 842 } 843 } 844 return nil 845 } 846 if useStoredKeys { 847 unseal = func(core *Core) error { 848 return core.UnsealWithStoredKeys(context.Background()) 849 } 850 } 851 852 // Unseal first core 853 if err := unseal(c.Cores[0].Core); err != nil { 854 return fmt.Errorf("unseal core %d err: %s", 0, err) 855 } 856 857 // Verify unsealed 858 if c.Cores[0].Sealed() { 859 return fmt.Errorf("should not be sealed") 860 } 861 862 if err := TestWaitActiveWithError(c.Cores[0].Core); err != nil { 863 return err 864 } 865 866 // Unseal other cores 867 for i := 1; i < len(c.Cores); i++ { 868 if err := unseal(c.Cores[i].Core); err != nil { 869 return fmt.Errorf("unseal core %d err: %s", i, err) 870 } 871 } 872 873 // Let them come fully up to standby 874 time.Sleep(2 * time.Second) 875 876 // Ensure cluster connection info is populated. 877 // Other cores should not come up as leaders. 878 for i := 1; i < len(c.Cores); i++ { 879 isLeader, _, _, err := c.Cores[i].Leader() 880 if err != nil { 881 return err 882 } 883 if isLeader { 884 return fmt.Errorf("core[%d] should not be leader", i) 885 } 886 } 887 888 return nil 889} 890 891func (c *TestCluster) UnsealCore(t testing.T, core *TestClusterCore) { 892 for _, key := range c.BarrierKeys { 893 if _, err := core.Core.Unseal(TestKeyCopy(key)); err != nil { 894 t.Fatalf("unseal err: %s", err) 895 } 896 } 897} 898 899func (c *TestCluster) EnsureCoresSealed(t testing.T) { 900 t.Helper() 901 if err := c.ensureCoresSealed(); err != nil { 902 t.Fatal(err) 903 } 904} 905 906func (c *TestClusterCore) Seal(t testing.T) { 907 t.Helper() 908 if err := c.Core.sealInternal(); err != nil { 909 t.Fatal(err) 910 } 911} 912 913func CleanupClusters(clusters []*TestCluster) { 914 wg := &sync.WaitGroup{} 915 for _, cluster := range clusters { 916 wg.Add(1) 917 lc := cluster 918 go func() { 919 defer wg.Done() 920 lc.Cleanup() 921 }() 922 } 923 wg.Wait() 924} 925 926func (c *TestCluster) Cleanup() { 927 c.Logger.Info("cleaning up vault cluster") 928 for _, core := range c.Cores { 929 core.CoreConfig.Logger.SetLevel(log.Error) 930 } 931 932 // Close listeners 933 wg := &sync.WaitGroup{} 934 for _, core := range c.Cores { 935 wg.Add(1) 936 lc := core 937 938 go func() { 939 defer wg.Done() 940 if lc.Listeners != nil { 941 for _, ln := range lc.Listeners { 942 ln.Close() 943 } 944 } 945 if lc.licensingStopCh != nil { 946 close(lc.licensingStopCh) 947 lc.licensingStopCh = nil 948 } 949 950 if err := lc.Shutdown(); err != nil { 951 lc.Logger().Error("error during shutdown; abandoning sealing", "error", err) 952 } else { 953 timeout := time.Now().Add(60 * time.Second) 954 for { 955 if time.Now().After(timeout) { 956 lc.Logger().Error("timeout waiting for core to seal") 957 } 958 if lc.Sealed() { 959 break 960 } 961 time.Sleep(250 * time.Millisecond) 962 } 963 } 964 }() 965 } 966 967 wg.Wait() 968 969 // Remove any temp dir that exists 970 if c.TempDir != "" { 971 os.RemoveAll(c.TempDir) 972 } 973 974 // Give time to actually shut down/clean up before the next test 975 time.Sleep(time.Second) 976 if c.CleanupFunc != nil { 977 c.CleanupFunc() 978 } 979} 980 981func (c *TestCluster) ensureCoresSealed() error { 982 for _, core := range c.Cores { 983 if err := core.Shutdown(); err != nil { 984 return err 985 } 986 timeout := time.Now().Add(60 * time.Second) 987 for { 988 if time.Now().After(timeout) { 989 return fmt.Errorf("timeout waiting for core to seal") 990 } 991 if core.Sealed() { 992 break 993 } 994 time.Sleep(250 * time.Millisecond) 995 } 996 } 997 return nil 998} 999 1000func SetReplicationFailureMode(core *TestClusterCore, mode uint32) { 1001 atomic.StoreUint32(core.Core.replicationFailure, mode) 1002} 1003 1004type TestListener struct { 1005 net.Listener 1006 Address *net.TCPAddr 1007} 1008 1009type TestClusterCore struct { 1010 *Core 1011 CoreConfig *CoreConfig 1012 Client *api.Client 1013 Handler http.Handler 1014 Listeners []*TestListener 1015 ReloadFuncs *map[string][]reload.ReloadFunc 1016 ReloadFuncsLock *sync.RWMutex 1017 Server *http.Server 1018 ServerCert *x509.Certificate 1019 ServerCertBytes []byte 1020 ServerCertPEM []byte 1021 ServerKey *ecdsa.PrivateKey 1022 ServerKeyPEM []byte 1023 TLSConfig *tls.Config 1024 UnderlyingStorage physical.Backend 1025 UnderlyingRawStorage physical.Backend 1026 Barrier SecurityBarrier 1027 NodeID string 1028} 1029 1030type PhysicalBackendBundle struct { 1031 Backend physical.Backend 1032 HABackend physical.HABackend 1033 Cleanup func() 1034} 1035 1036type TestClusterOptions struct { 1037 KeepStandbysSealed bool 1038 SkipInit bool 1039 HandlerFunc func(*HandlerProperties) http.Handler 1040 DefaultHandlerProperties HandlerProperties 1041 BaseListenAddress string 1042 NumCores int 1043 SealFunc func() Seal 1044 Logger log.Logger 1045 TempDir string 1046 CACert []byte 1047 CAKey *ecdsa.PrivateKey 1048 // PhysicalFactory is used to create backends. 1049 // The int argument is the index of the core within the cluster, i.e. first 1050 // core in cluster will have 0, second 1, etc. 1051 // If the backend is shared across the cluster (i.e. is not Raft) then it 1052 // should return nil when coreIdx != 0. 1053 PhysicalFactory func(t testing.T, coreIdx int, logger hclog.Logger) *PhysicalBackendBundle 1054 // FirstCoreNumber is used to assign a unique number to each core within 1055 // a multi-cluster setup. 1056 FirstCoreNumber int 1057 RequireClientAuth bool 1058 // SetupFunc is called after the cluster is started. 1059 SetupFunc func(t testing.T, c *TestCluster) 1060} 1061 1062var DefaultNumCores = 3 1063 1064type certInfo struct { 1065 cert *x509.Certificate 1066 certPEM []byte 1067 certBytes []byte 1068 key *ecdsa.PrivateKey 1069 keyPEM []byte 1070} 1071 1072// NewTestCluster creates a new test cluster based on the provided core config 1073// and test cluster options. 1074// 1075// N.B. Even though a single base CoreConfig is provided, NewTestCluster will instantiate a 1076// core config for each core it creates. If separate seal per core is desired, opts.SealFunc 1077// can be provided to generate a seal for each one. Otherwise, the provided base.Seal will be 1078// shared among cores. NewCore's default behavior is to generate a new DefaultSeal if the 1079// provided Seal in coreConfig (i.e. base.Seal) is nil. 1080// 1081// If opts.Logger is provided, it takes precedence and will be used as the cluster 1082// logger and will be the basis for each core's logger. If no opts.Logger is 1083// given, one will be generated based on t.Name() for the cluster logger, and if 1084// no base.Logger is given will also be used as the basis for each core's logger. 1085func NewTestCluster(t testing.T, base *CoreConfig, opts *TestClusterOptions) *TestCluster { 1086 var err error 1087 1088 var numCores int 1089 if opts == nil || opts.NumCores == 0 { 1090 numCores = DefaultNumCores 1091 } else { 1092 numCores = opts.NumCores 1093 } 1094 1095 var firstCoreNumber int 1096 if opts != nil { 1097 firstCoreNumber = opts.FirstCoreNumber 1098 } 1099 1100 certIPs := []net.IP{ 1101 net.IPv6loopback, 1102 net.ParseIP("127.0.0.1"), 1103 } 1104 var baseAddr *net.TCPAddr 1105 if opts != nil && opts.BaseListenAddress != "" { 1106 baseAddr, err = net.ResolveTCPAddr("tcp", opts.BaseListenAddress) 1107 if err != nil { 1108 t.Fatal("could not parse given base IP") 1109 } 1110 certIPs = append(certIPs, baseAddr.IP) 1111 } 1112 1113 var testCluster TestCluster 1114 1115 if opts != nil && opts.Logger != nil { 1116 testCluster.Logger = opts.Logger 1117 } else { 1118 testCluster.Logger = logging.NewVaultLogger(log.Trace).Named(t.Name()) 1119 } 1120 1121 if opts != nil && opts.TempDir != "" { 1122 if _, err := os.Stat(opts.TempDir); os.IsNotExist(err) { 1123 if err := os.MkdirAll(opts.TempDir, 0700); err != nil { 1124 t.Fatal(err) 1125 } 1126 } 1127 testCluster.TempDir = opts.TempDir 1128 } else { 1129 tempDir, err := ioutil.TempDir("", "vault-test-cluster-") 1130 if err != nil { 1131 t.Fatal(err) 1132 } 1133 testCluster.TempDir = tempDir 1134 } 1135 1136 var caKey *ecdsa.PrivateKey 1137 if opts != nil && opts.CAKey != nil { 1138 caKey = opts.CAKey 1139 } else { 1140 caKey, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader) 1141 if err != nil { 1142 t.Fatal(err) 1143 } 1144 } 1145 testCluster.CAKey = caKey 1146 var caBytes []byte 1147 if opts != nil && len(opts.CACert) > 0 { 1148 caBytes = opts.CACert 1149 } else { 1150 caCertTemplate := &x509.Certificate{ 1151 Subject: pkix.Name{ 1152 CommonName: "localhost", 1153 }, 1154 DNSNames: []string{"localhost"}, 1155 IPAddresses: certIPs, 1156 KeyUsage: x509.KeyUsage(x509.KeyUsageCertSign | x509.KeyUsageCRLSign), 1157 SerialNumber: big.NewInt(mathrand.Int63()), 1158 NotBefore: time.Now().Add(-30 * time.Second), 1159 NotAfter: time.Now().Add(262980 * time.Hour), 1160 BasicConstraintsValid: true, 1161 IsCA: true, 1162 } 1163 caBytes, err = x509.CreateCertificate(rand.Reader, caCertTemplate, caCertTemplate, caKey.Public(), caKey) 1164 if err != nil { 1165 t.Fatal(err) 1166 } 1167 } 1168 caCert, err := x509.ParseCertificate(caBytes) 1169 if err != nil { 1170 t.Fatal(err) 1171 } 1172 testCluster.CACert = caCert 1173 testCluster.CACertBytes = caBytes 1174 testCluster.RootCAs = x509.NewCertPool() 1175 testCluster.RootCAs.AddCert(caCert) 1176 caCertPEMBlock := &pem.Block{ 1177 Type: "CERTIFICATE", 1178 Bytes: caBytes, 1179 } 1180 testCluster.CACertPEM = pem.EncodeToMemory(caCertPEMBlock) 1181 testCluster.CACertPEMFile = filepath.Join(testCluster.TempDir, "ca_cert.pem") 1182 err = ioutil.WriteFile(testCluster.CACertPEMFile, testCluster.CACertPEM, 0755) 1183 if err != nil { 1184 t.Fatal(err) 1185 } 1186 marshaledCAKey, err := x509.MarshalECPrivateKey(caKey) 1187 if err != nil { 1188 t.Fatal(err) 1189 } 1190 caKeyPEMBlock := &pem.Block{ 1191 Type: "EC PRIVATE KEY", 1192 Bytes: marshaledCAKey, 1193 } 1194 testCluster.CAKeyPEM = pem.EncodeToMemory(caKeyPEMBlock) 1195 err = ioutil.WriteFile(filepath.Join(testCluster.TempDir, "ca_key.pem"), testCluster.CAKeyPEM, 0755) 1196 if err != nil { 1197 t.Fatal(err) 1198 } 1199 1200 var certInfoSlice []*certInfo 1201 1202 // 1203 // Certs generation 1204 // 1205 for i := 0; i < numCores; i++ { 1206 key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) 1207 if err != nil { 1208 t.Fatal(err) 1209 } 1210 certTemplate := &x509.Certificate{ 1211 Subject: pkix.Name{ 1212 CommonName: "localhost", 1213 }, 1214 DNSNames: []string{"localhost"}, 1215 IPAddresses: certIPs, 1216 ExtKeyUsage: []x509.ExtKeyUsage{ 1217 x509.ExtKeyUsageServerAuth, 1218 x509.ExtKeyUsageClientAuth, 1219 }, 1220 KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageKeyAgreement, 1221 SerialNumber: big.NewInt(mathrand.Int63()), 1222 NotBefore: time.Now().Add(-30 * time.Second), 1223 NotAfter: time.Now().Add(262980 * time.Hour), 1224 } 1225 certBytes, err := x509.CreateCertificate(rand.Reader, certTemplate, caCert, key.Public(), caKey) 1226 if err != nil { 1227 t.Fatal(err) 1228 } 1229 cert, err := x509.ParseCertificate(certBytes) 1230 if err != nil { 1231 t.Fatal(err) 1232 } 1233 certPEMBlock := &pem.Block{ 1234 Type: "CERTIFICATE", 1235 Bytes: certBytes, 1236 } 1237 certPEM := pem.EncodeToMemory(certPEMBlock) 1238 marshaledKey, err := x509.MarshalECPrivateKey(key) 1239 if err != nil { 1240 t.Fatal(err) 1241 } 1242 keyPEMBlock := &pem.Block{ 1243 Type: "EC PRIVATE KEY", 1244 Bytes: marshaledKey, 1245 } 1246 keyPEM := pem.EncodeToMemory(keyPEMBlock) 1247 1248 certInfoSlice = append(certInfoSlice, &certInfo{ 1249 cert: cert, 1250 certPEM: certPEM, 1251 certBytes: certBytes, 1252 key: key, 1253 keyPEM: keyPEM, 1254 }) 1255 } 1256 1257 // 1258 // Listener setup 1259 // 1260 ports := make([]int, numCores) 1261 if baseAddr != nil { 1262 for i := 0; i < numCores; i++ { 1263 ports[i] = baseAddr.Port + i 1264 } 1265 } else { 1266 baseAddr = &net.TCPAddr{ 1267 IP: net.ParseIP("127.0.0.1"), 1268 Port: 0, 1269 } 1270 } 1271 1272 listeners := [][]*TestListener{} 1273 servers := []*http.Server{} 1274 handlers := []http.Handler{} 1275 tlsConfigs := []*tls.Config{} 1276 certGetters := []*reload.CertificateGetter{} 1277 for i := 0; i < numCores; i++ { 1278 baseAddr.Port = ports[i] 1279 ln, err := net.ListenTCP("tcp", baseAddr) 1280 if err != nil { 1281 t.Fatal(err) 1282 } 1283 certFile := filepath.Join(testCluster.TempDir, fmt.Sprintf("node%d_port_%d_cert.pem", i+1, ln.Addr().(*net.TCPAddr).Port)) 1284 keyFile := filepath.Join(testCluster.TempDir, fmt.Sprintf("node%d_port_%d_key.pem", i+1, ln.Addr().(*net.TCPAddr).Port)) 1285 err = ioutil.WriteFile(certFile, certInfoSlice[i].certPEM, 0755) 1286 if err != nil { 1287 t.Fatal(err) 1288 } 1289 err = ioutil.WriteFile(keyFile, certInfoSlice[i].keyPEM, 0755) 1290 if err != nil { 1291 t.Fatal(err) 1292 } 1293 tlsCert, err := tls.X509KeyPair(certInfoSlice[i].certPEM, certInfoSlice[i].keyPEM) 1294 if err != nil { 1295 t.Fatal(err) 1296 } 1297 certGetter := reload.NewCertificateGetter(certFile, keyFile, "") 1298 certGetters = append(certGetters, certGetter) 1299 tlsConfig := &tls.Config{ 1300 Certificates: []tls.Certificate{tlsCert}, 1301 RootCAs: testCluster.RootCAs, 1302 ClientCAs: testCluster.RootCAs, 1303 ClientAuth: tls.RequestClientCert, 1304 NextProtos: []string{"h2", "http/1.1"}, 1305 GetCertificate: certGetter.GetCertificate, 1306 } 1307 if opts != nil && opts.RequireClientAuth { 1308 tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert 1309 testCluster.ClientAuthRequired = true 1310 } 1311 tlsConfig.BuildNameToCertificate() 1312 tlsConfigs = append(tlsConfigs, tlsConfig) 1313 lns := []*TestListener{&TestListener{ 1314 Listener: tls.NewListener(ln, tlsConfig), 1315 Address: ln.Addr().(*net.TCPAddr), 1316 }, 1317 } 1318 listeners = append(listeners, lns) 1319 var handler http.Handler = http.NewServeMux() 1320 handlers = append(handlers, handler) 1321 server := &http.Server{ 1322 Handler: handler, 1323 ErrorLog: testCluster.Logger.StandardLogger(nil), 1324 } 1325 servers = append(servers, server) 1326 } 1327 1328 // Create three cores with the same physical and different redirect/cluster 1329 // addrs. 1330 // N.B.: On OSX, instead of random ports, it assigns new ports to new 1331 // listeners sequentially. Aside from being a bad idea in a security sense, 1332 // it also broke tests that assumed it was OK to just use the port above 1333 // the redirect addr. This has now been changed to 105 ports above, but if 1334 // we ever do more than three nodes in a cluster it may need to be bumped. 1335 // Note: it's 105 so that we don't conflict with a running Consul by 1336 // default. 1337 coreConfig := &CoreConfig{ 1338 LogicalBackends: make(map[string]logical.Factory), 1339 CredentialBackends: make(map[string]logical.Factory), 1340 AuditBackends: make(map[string]audit.Factory), 1341 RedirectAddr: fmt.Sprintf("https://127.0.0.1:%d", listeners[0][0].Address.Port), 1342 ClusterAddr: "https://127.0.0.1:0", 1343 DisableMlock: true, 1344 EnableUI: true, 1345 EnableRaw: true, 1346 BuiltinRegistry: NewMockBuiltinRegistry(), 1347 } 1348 1349 if base != nil { 1350 coreConfig.RawConfig = base.RawConfig 1351 coreConfig.DisableCache = base.DisableCache 1352 coreConfig.EnableUI = base.EnableUI 1353 coreConfig.DefaultLeaseTTL = base.DefaultLeaseTTL 1354 coreConfig.MaxLeaseTTL = base.MaxLeaseTTL 1355 coreConfig.CacheSize = base.CacheSize 1356 coreConfig.PluginDirectory = base.PluginDirectory 1357 coreConfig.Seal = base.Seal 1358 coreConfig.DevToken = base.DevToken 1359 coreConfig.EnableRaw = base.EnableRaw 1360 coreConfig.DisableSealWrap = base.DisableSealWrap 1361 coreConfig.DevLicenseDuration = base.DevLicenseDuration 1362 coreConfig.DisableCache = base.DisableCache 1363 coreConfig.LicensingConfig = base.LicensingConfig 1364 coreConfig.DisablePerformanceStandby = base.DisablePerformanceStandby 1365 coreConfig.MetricsHelper = base.MetricsHelper 1366 coreConfig.SecureRandomReader = base.SecureRandomReader 1367 if base.BuiltinRegistry != nil { 1368 coreConfig.BuiltinRegistry = base.BuiltinRegistry 1369 } 1370 1371 if !coreConfig.DisableMlock { 1372 base.DisableMlock = false 1373 } 1374 1375 if base.Physical != nil { 1376 coreConfig.Physical = base.Physical 1377 } 1378 1379 if base.HAPhysical != nil { 1380 coreConfig.HAPhysical = base.HAPhysical 1381 } 1382 1383 // Used to set something non-working to test fallback 1384 switch base.ClusterAddr { 1385 case "empty": 1386 coreConfig.ClusterAddr = "" 1387 case "": 1388 default: 1389 coreConfig.ClusterAddr = base.ClusterAddr 1390 } 1391 1392 if base.LogicalBackends != nil { 1393 for k, v := range base.LogicalBackends { 1394 coreConfig.LogicalBackends[k] = v 1395 } 1396 } 1397 if base.CredentialBackends != nil { 1398 for k, v := range base.CredentialBackends { 1399 coreConfig.CredentialBackends[k] = v 1400 } 1401 } 1402 if base.AuditBackends != nil { 1403 for k, v := range base.AuditBackends { 1404 coreConfig.AuditBackends[k] = v 1405 } 1406 } 1407 if base.Logger != nil { 1408 coreConfig.Logger = base.Logger 1409 } 1410 1411 coreConfig.ClusterCipherSuites = base.ClusterCipherSuites 1412 1413 coreConfig.DisableCache = base.DisableCache 1414 1415 coreConfig.DevToken = base.DevToken 1416 coreConfig.CounterSyncInterval = base.CounterSyncInterval 1417 coreConfig.RecoveryMode = base.RecoveryMode 1418 } 1419 1420 if coreConfig.RawConfig == nil { 1421 coreConfig.RawConfig = new(server.Config) 1422 } 1423 1424 addAuditBackend := len(coreConfig.AuditBackends) == 0 1425 if addAuditBackend { 1426 AddNoopAudit(coreConfig) 1427 } 1428 1429 if coreConfig.Physical == nil && (opts == nil || opts.PhysicalFactory == nil) { 1430 coreConfig.Physical, err = physInmem.NewInmem(nil, testCluster.Logger) 1431 if err != nil { 1432 t.Fatal(err) 1433 } 1434 } 1435 if coreConfig.HAPhysical == nil && (opts == nil || opts.PhysicalFactory == nil) { 1436 haPhys, err := physInmem.NewInmemHA(nil, testCluster.Logger) 1437 if err != nil { 1438 t.Fatal(err) 1439 } 1440 coreConfig.HAPhysical = haPhys.(physical.HABackend) 1441 } 1442 1443 pubKey, priKey, err := testGenerateCoreKeys() 1444 if err != nil { 1445 t.Fatalf("err: %v", err) 1446 } 1447 1448 cleanupFuncs := []func(){} 1449 cores := []*Core{} 1450 coreConfigs := []*CoreConfig{} 1451 for i := 0; i < numCores; i++ { 1452 localConfig := *coreConfig 1453 localConfig.RedirectAddr = fmt.Sprintf("https://127.0.0.1:%d", listeners[i][0].Address.Port) 1454 1455 // if opts.SealFunc is provided, use that to generate a seal for the config instead 1456 if opts != nil && opts.SealFunc != nil { 1457 localConfig.Seal = opts.SealFunc() 1458 } 1459 1460 if coreConfig.Logger == nil || (opts != nil && opts.Logger != nil) { 1461 localConfig.Logger = testCluster.Logger.Named(fmt.Sprintf("core%d", i)) 1462 } 1463 if opts != nil && opts.PhysicalFactory != nil { 1464 physBundle := opts.PhysicalFactory(t, i, localConfig.Logger) 1465 switch { 1466 case physBundle == nil && coreConfig.Physical != nil: 1467 case physBundle == nil && coreConfig.Physical == nil: 1468 t.Fatal("PhysicalFactory produced no physical and none in CoreConfig") 1469 case physBundle != nil: 1470 testCluster.Logger.Info("created physical backend", "instance", i) 1471 coreConfig.Physical = physBundle.Backend 1472 localConfig.Physical = physBundle.Backend 1473 base.Physical = physBundle.Backend 1474 haBackend := physBundle.HABackend 1475 if haBackend == nil { 1476 if ha, ok := physBundle.Backend.(physical.HABackend); ok { 1477 haBackend = ha 1478 } 1479 } 1480 coreConfig.HAPhysical = haBackend 1481 localConfig.HAPhysical = haBackend 1482 if physBundle.Cleanup != nil { 1483 cleanupFuncs = append(cleanupFuncs, physBundle.Cleanup) 1484 } 1485 } 1486 } 1487 1488 switch { 1489 case localConfig.LicensingConfig != nil: 1490 if pubKey != nil { 1491 localConfig.LicensingConfig.AdditionalPublicKeys = append(localConfig.LicensingConfig.AdditionalPublicKeys, pubKey.(ed25519.PublicKey)) 1492 } 1493 default: 1494 localConfig.LicensingConfig = testGetLicensingConfig(pubKey) 1495 } 1496 1497 if localConfig.MetricsHelper == nil { 1498 inm := metrics.NewInmemSink(10*time.Second, time.Minute) 1499 metrics.DefaultInmemSignal(inm) 1500 localConfig.MetricsHelper = metricsutil.NewMetricsHelper(inm, false) 1501 } 1502 1503 c, err := NewCore(&localConfig) 1504 if err != nil { 1505 t.Fatalf("err: %v", err) 1506 } 1507 c.coreNumber = firstCoreNumber + i 1508 cores = append(cores, c) 1509 coreConfigs = append(coreConfigs, &localConfig) 1510 if opts != nil && opts.HandlerFunc != nil { 1511 props := opts.DefaultHandlerProperties 1512 props.Core = c 1513 if props.MaxRequestDuration == 0 { 1514 props.MaxRequestDuration = DefaultMaxRequestDuration 1515 } 1516 handlers[i] = opts.HandlerFunc(&props) 1517 servers[i].Handler = handlers[i] 1518 } 1519 1520 // Set this in case the Seal was manually set before the core was 1521 // created 1522 if localConfig.Seal != nil { 1523 localConfig.Seal.SetCore(c) 1524 } 1525 } 1526 1527 // 1528 // Clustering setup 1529 // 1530 clusterAddrGen := func(lns []*TestListener) []*net.TCPAddr { 1531 ret := make([]*net.TCPAddr, len(lns)) 1532 for i, ln := range lns { 1533 ret[i] = &net.TCPAddr{ 1534 IP: ln.Address.IP, 1535 Port: 0, 1536 } 1537 } 1538 return ret 1539 } 1540 1541 for i := 0; i < numCores; i++ { 1542 if coreConfigs[i].ClusterAddr != "" { 1543 cores[i].SetClusterListenerAddrs(clusterAddrGen(listeners[i])) 1544 cores[i].SetClusterHandler(handlers[i]) 1545 } 1546 } 1547 1548 if opts == nil || !opts.SkipInit { 1549 bKeys, rKeys, root := TestCoreInitClusterWrapperSetup(t, cores[0], handlers[0]) 1550 barrierKeys, _ := copystructure.Copy(bKeys) 1551 testCluster.BarrierKeys = barrierKeys.([][]byte) 1552 recoveryKeys, _ := copystructure.Copy(rKeys) 1553 testCluster.RecoveryKeys = recoveryKeys.([][]byte) 1554 testCluster.RootToken = root 1555 1556 // Write root token and barrier keys 1557 err = ioutil.WriteFile(filepath.Join(testCluster.TempDir, "root_token"), []byte(root), 0755) 1558 if err != nil { 1559 t.Fatal(err) 1560 } 1561 var buf bytes.Buffer 1562 for i, key := range testCluster.BarrierKeys { 1563 buf.Write([]byte(base64.StdEncoding.EncodeToString(key))) 1564 if i < len(testCluster.BarrierKeys)-1 { 1565 buf.WriteRune('\n') 1566 } 1567 } 1568 err = ioutil.WriteFile(filepath.Join(testCluster.TempDir, "barrier_keys"), buf.Bytes(), 0755) 1569 if err != nil { 1570 t.Fatal(err) 1571 } 1572 for i, key := range testCluster.RecoveryKeys { 1573 buf.Write([]byte(base64.StdEncoding.EncodeToString(key))) 1574 if i < len(testCluster.RecoveryKeys)-1 { 1575 buf.WriteRune('\n') 1576 } 1577 } 1578 err = ioutil.WriteFile(filepath.Join(testCluster.TempDir, "recovery_keys"), buf.Bytes(), 0755) 1579 if err != nil { 1580 t.Fatal(err) 1581 } 1582 1583 // Unseal first core 1584 for _, key := range bKeys { 1585 if _, err := cores[0].Unseal(TestKeyCopy(key)); err != nil { 1586 t.Fatalf("unseal err: %s", err) 1587 } 1588 } 1589 1590 ctx := context.Background() 1591 1592 // If stored keys is supported, the above will no no-op, so trigger auto-unseal 1593 // using stored keys to try to unseal 1594 if err := cores[0].UnsealWithStoredKeys(ctx); err != nil { 1595 t.Fatal(err) 1596 } 1597 1598 // Verify unsealed 1599 if cores[0].Sealed() { 1600 t.Fatal("should not be sealed") 1601 } 1602 1603 TestWaitActive(t, cores[0]) 1604 1605 // Existing tests rely on this; we can make a toggle to disable it 1606 // later if we want 1607 kvReq := &logical.Request{ 1608 Operation: logical.UpdateOperation, 1609 ClientToken: testCluster.RootToken, 1610 Path: "sys/mounts/secret", 1611 Data: map[string]interface{}{ 1612 "type": "kv", 1613 "path": "secret/", 1614 "description": "key/value secret storage", 1615 "options": map[string]string{ 1616 "version": "1", 1617 }, 1618 }, 1619 } 1620 resp, err := cores[0].HandleRequest(namespace.RootContext(ctx), kvReq) 1621 if err != nil { 1622 t.Fatal(err) 1623 } 1624 if resp.IsError() { 1625 t.Fatal(err) 1626 } 1627 1628 cfg, err := cores[0].seal.BarrierConfig(ctx) 1629 if err != nil { 1630 t.Fatal(err) 1631 } 1632 1633 // Unseal other cores unless otherwise specified 1634 if (opts == nil || !opts.KeepStandbysSealed) && numCores > 1 { 1635 for i := 1; i < numCores; i++ { 1636 cores[i].seal.SetCachedBarrierConfig(cfg) 1637 for _, key := range bKeys { 1638 if _, err := cores[i].Unseal(TestKeyCopy(key)); err != nil { 1639 t.Fatalf("unseal err: %s", err) 1640 } 1641 } 1642 1643 // If stored keys is supported, the above will no no-op, so trigger auto-unseal 1644 // using stored keys 1645 if err := cores[i].UnsealWithStoredKeys(ctx); err != nil { 1646 t.Fatal(err) 1647 } 1648 } 1649 1650 // Let them come fully up to standby 1651 time.Sleep(2 * time.Second) 1652 1653 // Ensure cluster connection info is populated. 1654 // Other cores should not come up as leaders. 1655 for i := 1; i < numCores; i++ { 1656 isLeader, _, _, err := cores[i].Leader() 1657 if err != nil { 1658 t.Fatal(err) 1659 } 1660 if isLeader { 1661 t.Fatalf("core[%d] should not be leader", i) 1662 } 1663 } 1664 } 1665 1666 // 1667 // Set test cluster core(s) and test cluster 1668 // 1669 cluster, err := cores[0].Cluster(context.Background()) 1670 if err != nil { 1671 t.Fatal(err) 1672 } 1673 testCluster.ID = cluster.ID 1674 1675 if addAuditBackend { 1676 // Enable auditing. 1677 auditReq := &logical.Request{ 1678 Operation: logical.UpdateOperation, 1679 ClientToken: testCluster.RootToken, 1680 Path: "sys/audit/noop", 1681 Data: map[string]interface{}{ 1682 "type": "noop", 1683 }, 1684 } 1685 resp, err = cores[0].HandleRequest(namespace.RootContext(ctx), auditReq) 1686 if err != nil { 1687 t.Fatal(err) 1688 } 1689 1690 if resp.IsError() { 1691 t.Fatal(err) 1692 } 1693 } 1694 } 1695 1696 getAPIClient := func(port int, tlsConfig *tls.Config) *api.Client { 1697 transport := cleanhttp.DefaultPooledTransport() 1698 transport.TLSClientConfig = tlsConfig.Clone() 1699 if err := http2.ConfigureTransport(transport); err != nil { 1700 t.Fatal(err) 1701 } 1702 client := &http.Client{ 1703 Transport: transport, 1704 CheckRedirect: func(*http.Request, []*http.Request) error { 1705 // This can of course be overridden per-test by using its own client 1706 return fmt.Errorf("redirects not allowed in these tests") 1707 }, 1708 } 1709 config := api.DefaultConfig() 1710 if config.Error != nil { 1711 t.Fatal(config.Error) 1712 } 1713 config.Address = fmt.Sprintf("https://127.0.0.1:%d", port) 1714 config.HttpClient = client 1715 config.MaxRetries = 0 1716 apiClient, err := api.NewClient(config) 1717 if err != nil { 1718 t.Fatal(err) 1719 } 1720 if opts == nil || !opts.SkipInit { 1721 apiClient.SetToken(testCluster.RootToken) 1722 } 1723 return apiClient 1724 } 1725 1726 var ret []*TestClusterCore 1727 for i := 0; i < numCores; i++ { 1728 tcc := &TestClusterCore{ 1729 Core: cores[i], 1730 CoreConfig: coreConfigs[i], 1731 ServerKey: certInfoSlice[i].key, 1732 ServerKeyPEM: certInfoSlice[i].keyPEM, 1733 ServerCert: certInfoSlice[i].cert, 1734 ServerCertBytes: certInfoSlice[i].certBytes, 1735 ServerCertPEM: certInfoSlice[i].certPEM, 1736 Listeners: listeners[i], 1737 Handler: handlers[i], 1738 Server: servers[i], 1739 TLSConfig: tlsConfigs[i], 1740 Client: getAPIClient(listeners[i][0].Address.Port, tlsConfigs[i]), 1741 Barrier: cores[i].barrier, 1742 NodeID: fmt.Sprintf("core-%d", i), 1743 UnderlyingRawStorage: coreConfigs[i].Physical, 1744 } 1745 tcc.ReloadFuncs = &cores[i].reloadFuncs 1746 tcc.ReloadFuncsLock = &cores[i].reloadFuncsLock 1747 tcc.ReloadFuncsLock.Lock() 1748 (*tcc.ReloadFuncs)["listener|tcp"] = []reload.ReloadFunc{certGetters[i].Reload} 1749 tcc.ReloadFuncsLock.Unlock() 1750 1751 testAdjustTestCore(base, tcc) 1752 1753 ret = append(ret, tcc) 1754 } 1755 1756 testCluster.Cores = ret 1757 1758 testExtraClusterCoresTestSetup(t, priKey, testCluster.Cores) 1759 1760 testCluster.CleanupFunc = func() { 1761 for _, c := range cleanupFuncs { 1762 c() 1763 } 1764 } 1765 if opts != nil { 1766 if opts.SetupFunc != nil { 1767 testCluster.SetupFunc = func() { 1768 opts.SetupFunc(t, &testCluster) 1769 } 1770 } 1771 } 1772 1773 return &testCluster 1774} 1775 1776func NewMockBuiltinRegistry() *mockBuiltinRegistry { 1777 return &mockBuiltinRegistry{ 1778 forTesting: map[string]consts.PluginType{ 1779 "mysql-database-plugin": consts.PluginTypeDatabase, 1780 "postgresql-database-plugin": consts.PluginTypeDatabase, 1781 }, 1782 } 1783} 1784 1785type mockBuiltinRegistry struct { 1786 forTesting map[string]consts.PluginType 1787} 1788 1789func (m *mockBuiltinRegistry) Get(name string, pluginType consts.PluginType) (func() (interface{}, error), bool) { 1790 testPluginType, ok := m.forTesting[name] 1791 if !ok { 1792 return nil, false 1793 } 1794 if pluginType != testPluginType { 1795 return nil, false 1796 } 1797 if name == "postgresql-database-plugin" { 1798 return dbPostgres.New, true 1799 } 1800 return dbMysql.New(dbMysql.MetadataLen, dbMysql.MetadataLen, dbMysql.UsernameLen), true 1801} 1802 1803// Keys only supports getting a realistic list of the keys for database plugins. 1804func (m *mockBuiltinRegistry) Keys(pluginType consts.PluginType) []string { 1805 if pluginType != consts.PluginTypeDatabase { 1806 return []string{} 1807 } 1808 /* 1809 This is a hard-coded reproduction of the db plugin keys in helper/builtinplugins/registry.go. 1810 The registry isn't directly used because it causes import cycles. 1811 */ 1812 return []string{ 1813 "mysql-database-plugin", 1814 "mysql-aurora-database-plugin", 1815 "mysql-rds-database-plugin", 1816 "mysql-legacy-database-plugin", 1817 "postgresql-database-plugin", 1818 "elasticsearch-database-plugin", 1819 "mssql-database-plugin", 1820 "cassandra-database-plugin", 1821 "mongodb-database-plugin", 1822 "hana-database-plugin", 1823 "influxdb-database-plugin", 1824 } 1825} 1826 1827func (m *mockBuiltinRegistry) Contains(name string, pluginType consts.PluginType) bool { 1828 return false 1829} 1830 1831type NoopAudit struct { 1832 Config *audit.BackendConfig 1833 ReqErr error 1834 ReqAuth []*logical.Auth 1835 Req []*logical.Request 1836 ReqHeaders []map[string][]string 1837 ReqNonHMACKeys []string 1838 ReqErrs []error 1839 1840 RespErr error 1841 RespAuth []*logical.Auth 1842 RespReq []*logical.Request 1843 Resp []*logical.Response 1844 RespNonHMACKeys []string 1845 RespReqNonHMACKeys []string 1846 RespErrs []error 1847 1848 salt *salt.Salt 1849 saltMutex sync.RWMutex 1850} 1851 1852func (n *NoopAudit) LogRequest(ctx context.Context, in *logical.LogInput) error { 1853 n.ReqAuth = append(n.ReqAuth, in.Auth) 1854 n.Req = append(n.Req, in.Request) 1855 n.ReqHeaders = append(n.ReqHeaders, in.Request.Headers) 1856 n.ReqNonHMACKeys = in.NonHMACReqDataKeys 1857 n.ReqErrs = append(n.ReqErrs, in.OuterErr) 1858 return n.ReqErr 1859} 1860 1861func (n *NoopAudit) LogResponse(ctx context.Context, in *logical.LogInput) error { 1862 n.RespAuth = append(n.RespAuth, in.Auth) 1863 n.RespReq = append(n.RespReq, in.Request) 1864 n.Resp = append(n.Resp, in.Response) 1865 n.RespErrs = append(n.RespErrs, in.OuterErr) 1866 1867 if in.Response != nil { 1868 n.RespNonHMACKeys = in.NonHMACRespDataKeys 1869 n.RespReqNonHMACKeys = in.NonHMACReqDataKeys 1870 } 1871 1872 return n.RespErr 1873} 1874 1875func (n *NoopAudit) Salt(ctx context.Context) (*salt.Salt, error) { 1876 n.saltMutex.RLock() 1877 if n.salt != nil { 1878 defer n.saltMutex.RUnlock() 1879 return n.salt, nil 1880 } 1881 n.saltMutex.RUnlock() 1882 n.saltMutex.Lock() 1883 defer n.saltMutex.Unlock() 1884 if n.salt != nil { 1885 return n.salt, nil 1886 } 1887 salt, err := salt.NewSalt(ctx, n.Config.SaltView, n.Config.SaltConfig) 1888 if err != nil { 1889 return nil, err 1890 } 1891 n.salt = salt 1892 return salt, nil 1893} 1894 1895func (n *NoopAudit) GetHash(ctx context.Context, data string) (string, error) { 1896 salt, err := n.Salt(ctx) 1897 if err != nil { 1898 return "", err 1899 } 1900 return salt.GetIdentifiedHMAC(data), nil 1901} 1902 1903func (n *NoopAudit) Reload(ctx context.Context) error { 1904 return nil 1905} 1906 1907func (n *NoopAudit) Invalidate(ctx context.Context) { 1908 n.saltMutex.Lock() 1909 defer n.saltMutex.Unlock() 1910 n.salt = nil 1911} 1912