1package vault
2
3import (
4	"bytes"
5	"context"
6	"crypto/ecdsa"
7	"crypto/elliptic"
8	"crypto/rand"
9	"crypto/sha256"
10	"crypto/tls"
11	"crypto/x509"
12	"crypto/x509/pkix"
13	"encoding/base64"
14	"encoding/pem"
15	"errors"
16	"fmt"
17	"io"
18	"io/ioutil"
19	"math/big"
20	mathrand "math/rand"
21	"net"
22	"net/http"
23	"os"
24	"os/exec"
25	"path/filepath"
26	"sync"
27	"sync/atomic"
28	"time"
29
30	"github.com/armon/go-metrics"
31	hclog "github.com/hashicorp/go-hclog"
32	log "github.com/hashicorp/go-hclog"
33	"github.com/hashicorp/vault/helper/metricsutil"
34	"github.com/mitchellh/copystructure"
35
36	"golang.org/x/crypto/ed25519"
37	"golang.org/x/crypto/ssh"
38	"golang.org/x/net/http2"
39
40	cleanhttp "github.com/hashicorp/go-cleanhttp"
41	"github.com/hashicorp/vault/api"
42	"github.com/hashicorp/vault/audit"
43	"github.com/hashicorp/vault/command/server"
44	"github.com/hashicorp/vault/helper/namespace"
45	"github.com/hashicorp/vault/helper/reload"
46	dbMysql "github.com/hashicorp/vault/plugins/database/mysql"
47	dbPostgres "github.com/hashicorp/vault/plugins/database/postgresql"
48	"github.com/hashicorp/vault/sdk/framework"
49	"github.com/hashicorp/vault/sdk/helper/consts"
50	"github.com/hashicorp/vault/sdk/helper/logging"
51	"github.com/hashicorp/vault/sdk/helper/salt"
52	"github.com/hashicorp/vault/sdk/logical"
53	"github.com/hashicorp/vault/sdk/physical"
54	testing "github.com/mitchellh/go-testing-interface"
55
56	physInmem "github.com/hashicorp/vault/sdk/physical/inmem"
57)
58
59// This file contains a number of methods that are useful for unit
60// tests within other packages.
61
// testSharedPublicKey is the authorized_keys-format RSA public key that
// StartSSHHostTestServer accepts from connecting clients.
const testSharedPublicKey = `
ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQC9i+hFxZHGo6KblVme4zrAcJstR6I0PTJozW286X4WyvPnkMYDQ5mnhEYC7UWCvjoTWbPEXPX7NjhRtwQTGD67bV+lrxgfyzK1JZbUXK4PwgKJvQD+XyyWYMzDgGSQY61KUSqCxymSm/9NZkPU3ElaQ9xQuTzPpztM4ROfb8f2Yv6/ZESZsTo0MTAkp8Pcy+WkioI/uJ1H7zqs0EA4OMY4aDJRu0UtP4rTVeYNEAuRXdX+eH4aW3KMvhzpFTjMbaJHJXlEeUm2SaX5TNQyTOvghCeQILfYIL/Ca2ij8iwCmulwdV6eQGfd4VDu40PvSnmfoaE38o6HaPnX0kUcnKiT
`

// testSharedPrivateKey is the PEM-encoded RSA private key that
// StartSSHHostTestServer installs as its host key. It is a test fixture only.
const testSharedPrivateKey = `
-----BEGIN RSA PRIVATE KEY-----
MIIEogIBAAKCAQEAvYvoRcWRxqOim5VZnuM6wHCbLUeiND0yaM1tvOl+Fsrz55DG
A0OZp4RGAu1Fgr46E1mzxFz1+zY4UbcEExg+u21fpa8YH8sytSWW1FyuD8ICib0A
/l8slmDMw4BkkGOtSlEqgscpkpv/TWZD1NxJWkPcULk8z6c7TOETn2/H9mL+v2RE
mbE6NDEwJKfD3MvlpIqCP7idR+86rNBAODjGOGgyUbtFLT+K01XmDRALkV3V/nh+
GltyjL4c6RU4zG2iRyV5RHlJtkml+UzUMkzr4IQnkCC32CC/wmtoo/IsAprpcHVe
nkBn3eFQ7uND70p5n6GhN/KOh2j519JFHJyokwIDAQABAoIBAHX7VOvBC3kCN9/x
+aPdup84OE7Z7MvpX6w+WlUhXVugnmsAAVDczhKoUc/WktLLx2huCGhsmKvyVuH+
MioUiE+vx75gm3qGx5xbtmOfALVMRLopjCnJYf6EaFA0ZeQ+NwowNW7Lu0PHmAU8
Z3JiX8IwxTz14DU82buDyewO7v+cEr97AnERe3PUcSTDoUXNaoNxjNpEJkKREY6h
4hAY676RT/GsRcQ8tqe/rnCqPHNd7JGqL+207FK4tJw7daoBjQyijWuB7K5chSal
oPInylM6b13ASXuOAOT/2uSUBWmFVCZPDCmnZxy2SdnJGbsJAMl7Ma3MUlaGvVI+
Tfh1aQkCgYEA4JlNOabTb3z42wz6mz+Nz3JRwbawD+PJXOk5JsSnV7DtPtfgkK9y
6FTQdhnozGWShAvJvc+C4QAihs9AlHXoaBY5bEU7R/8UK/pSqwzam+MmxmhVDV7G
IMQPV0FteoXTaJSikhZ88mETTegI2mik+zleBpVxvfdhE5TR+lq8Br0CgYEA2AwJ
CUD5CYUSj09PluR0HHqamWOrJkKPFPwa+5eiTTCzfBBxImYZh7nXnWuoviXC0sg2
AuvCW+uZ48ygv/D8gcz3j1JfbErKZJuV+TotK9rRtNIF5Ub7qysP7UjyI7zCssVM
kuDd9LfRXaB/qGAHNkcDA8NxmHW3gpln4CFdSY8CgYANs4xwfercHEWaJ1qKagAe
rZyrMpffAEhicJ/Z65lB0jtG4CiE6w8ZeUMWUVJQVcnwYD+4YpZbX4S7sJ0B8Ydy
AhkSr86D/92dKTIt2STk6aCN7gNyQ1vW198PtaAWH1/cO2UHgHOy3ZUt5X/Uwxl9
cex4flln+1Viumts2GgsCQKBgCJH7psgSyPekK5auFdKEr5+Gc/jB8I/Z3K9+g4X
5nH3G1PBTCJYLw7hRzw8W/8oALzvddqKzEFHphiGXK94Lqjt/A4q1OdbCrhiE68D
My21P/dAKB1UYRSs9Y8CNyHCjuZM9jSMJ8vv6vG/SOJPsnVDWVAckAbQDvlTHC9t
O98zAoGAcbW6uFDkrv0XMCpB9Su3KaNXOR0wzag+WIFQRXCcoTvxVi9iYfUReQPi
oOyBJU/HMVvBfv4g+OVFLVgSwwm6owwsouZ0+D/LasbuHqYyqYqdyPJQYzWA2Y+F
+B6f4RoPdSXj24JHPg/ioRxjaj094UXJxua2yfkcecGNEuBQHSs=
-----END RSA PRIVATE KEY-----
`
96
97// TestCore returns a pure in-memory, uninitialized core for testing.
98func TestCore(t testing.T) *Core {
99	return TestCoreWithSeal(t, nil, false)
100}
101
102// TestCoreRaw returns a pure in-memory, uninitialized core for testing. The raw
103// storage endpoints are enabled with this core.
104func TestCoreRaw(t testing.T) *Core {
105	return TestCoreWithSeal(t, nil, true)
106}
107
108// TestCoreNewSeal returns a pure in-memory, uninitialized core with
109// the new seal configuration.
110func TestCoreNewSeal(t testing.T) *Core {
111	seal := NewTestSeal(t, nil)
112	return TestCoreWithSeal(t, seal, false)
113}
114
115// TestCoreWithConfig returns a pure in-memory, uninitialized core with the
116// specified core configurations overridden for testing.
117func TestCoreWithConfig(t testing.T, conf *CoreConfig) *Core {
118	return TestCoreWithSealAndUI(t, conf)
119}
120
121// TestCoreWithSeal returns a pure in-memory, uninitialized core with the
122// specified seal for testing.
123func TestCoreWithSeal(t testing.T, testSeal Seal, enableRaw bool) *Core {
124	conf := &CoreConfig{
125		Seal:            testSeal,
126		EnableUI:        false,
127		EnableRaw:       enableRaw,
128		BuiltinRegistry: NewMockBuiltinRegistry(),
129	}
130	return TestCoreWithSealAndUI(t, conf)
131}
132
133func TestCoreUI(t testing.T, enableUI bool) *Core {
134	conf := &CoreConfig{
135		EnableUI:        enableUI,
136		EnableRaw:       true,
137		BuiltinRegistry: NewMockBuiltinRegistry(),
138	}
139	return TestCoreWithSealAndUI(t, conf)
140}
141
142func TestCoreWithSealAndUI(t testing.T, opts *CoreConfig) *Core {
143	logger := logging.NewVaultLogger(log.Trace)
144	physicalBackend, err := physInmem.NewInmem(nil, logger)
145	if err != nil {
146		t.Fatal(err)
147	}
148
149	errInjector := physical.NewErrorInjector(physicalBackend, 0, logger)
150
151	// Start off with base test core config
152	conf := testCoreConfig(t, errInjector, logger)
153
154	// Override config values with ones that gets passed in
155	conf.EnableUI = opts.EnableUI
156	conf.EnableRaw = opts.EnableRaw
157	conf.Seal = opts.Seal
158	conf.LicensingConfig = opts.LicensingConfig
159	conf.DisableKeyEncodingChecks = opts.DisableKeyEncodingChecks
160	conf.MetricsHelper = opts.MetricsHelper
161
162	if opts.Logger != nil {
163		conf.Logger = opts.Logger
164	}
165
166	for k, v := range opts.LogicalBackends {
167		conf.LogicalBackends[k] = v
168	}
169	for k, v := range opts.CredentialBackends {
170		conf.CredentialBackends[k] = v
171	}
172
173	for k, v := range opts.AuditBackends {
174		conf.AuditBackends[k] = v
175	}
176
177	c, err := NewCore(conf)
178	if err != nil {
179		t.Fatalf("err: %s", err)
180	}
181
182	return c
183}
184
185func testCoreConfig(t testing.T, physicalBackend physical.Backend, logger log.Logger) *CoreConfig {
186	t.Helper()
187	noopAudits := map[string]audit.Factory{
188		"noop": func(_ context.Context, config *audit.BackendConfig) (audit.Backend, error) {
189			view := &logical.InmemStorage{}
190			view.Put(context.Background(), &logical.StorageEntry{
191				Key:   "salt",
192				Value: []byte("foo"),
193			})
194			config.SaltConfig = &salt.Config{
195				HMAC:     sha256.New,
196				HMACType: "hmac-sha256",
197			}
198			config.SaltView = view
199
200			n := &noopAudit{
201				Config: config,
202			}
203			n.formatter.AuditFormatWriter = &audit.JSONFormatWriter{
204				SaltFunc: n.Salt,
205			}
206			return n, nil
207		},
208	}
209
210	noopBackends := make(map[string]logical.Factory)
211	noopBackends["noop"] = func(ctx context.Context, config *logical.BackendConfig) (logical.Backend, error) {
212		b := new(framework.Backend)
213		b.Setup(ctx, config)
214		b.BackendType = logical.TypeCredential
215		return b, nil
216	}
217	noopBackends["http"] = func(ctx context.Context, config *logical.BackendConfig) (logical.Backend, error) {
218		return new(rawHTTP), nil
219	}
220
221	credentialBackends := make(map[string]logical.Factory)
222	for backendName, backendFactory := range noopBackends {
223		credentialBackends[backendName] = backendFactory
224	}
225	for backendName, backendFactory := range testCredentialBackends {
226		credentialBackends[backendName] = backendFactory
227	}
228
229	logicalBackends := make(map[string]logical.Factory)
230	for backendName, backendFactory := range noopBackends {
231		logicalBackends[backendName] = backendFactory
232	}
233
234	logicalBackends["kv"] = LeasedPassthroughBackendFactory
235	for backendName, backendFactory := range testLogicalBackends {
236		logicalBackends[backendName] = backendFactory
237	}
238
239	conf := &CoreConfig{
240		Physical:           physicalBackend,
241		AuditBackends:      noopAudits,
242		LogicalBackends:    logicalBackends,
243		CredentialBackends: credentialBackends,
244		DisableMlock:       true,
245		Logger:             logger,
246		BuiltinRegistry:    NewMockBuiltinRegistry(),
247	}
248
249	return conf
250}
251
252// TestCoreInit initializes the core with a single key, and returns
253// the key that must be used to unseal the core and a root token.
254func TestCoreInit(t testing.T, core *Core) ([][]byte, string) {
255	t.Helper()
256	secretShares, _, root := TestCoreInitClusterWrapperSetup(t, core, nil)
257	return secretShares, root
258}
259
260func TestCoreInitClusterWrapperSetup(t testing.T, core *Core, handler http.Handler) ([][]byte, [][]byte, string) {
261	t.Helper()
262	core.SetClusterHandler(handler)
263
264	barrierConfig := &SealConfig{
265		SecretShares:    3,
266		SecretThreshold: 3,
267	}
268
269	switch core.seal.StoredKeysSupported() {
270	case StoredKeysNotSupported:
271		barrierConfig.StoredShares = 0
272	default:
273		barrierConfig.StoredShares = 1
274	}
275
276	recoveryConfig := &SealConfig{
277		SecretShares:    3,
278		SecretThreshold: 3,
279	}
280
281	initParams := &InitParams{
282		BarrierConfig:  barrierConfig,
283		RecoveryConfig: recoveryConfig,
284	}
285	if core.seal.StoredKeysSupported() == StoredKeysNotSupported {
286		initParams.LegacyShamirSeal = true
287	}
288	result, err := core.Initialize(context.Background(), initParams)
289	if err != nil {
290		t.Fatalf("err: %s", err)
291	}
292	return result.SecretShares, result.RecoveryShares, result.RootToken
293}
294
295func TestCoreUnseal(core *Core, key []byte) (bool, error) {
296	return core.Unseal(key)
297}
298
299func TestCoreUnsealWithRecoveryKeys(core *Core, key []byte) (bool, error) {
300	return core.UnsealWithRecoveryKeys(key)
301}
302
303// TestCoreUnsealed returns a pure in-memory core that is already
304// initialized and unsealed.
305func TestCoreUnsealed(t testing.T) (*Core, [][]byte, string) {
306	t.Helper()
307	core := TestCore(t)
308	return testCoreUnsealed(t, core)
309}
310
311// TestCoreUnsealedRaw returns a pure in-memory core that is already
312// initialized, unsealed, and with raw endpoints enabled.
313func TestCoreUnsealedRaw(t testing.T) (*Core, [][]byte, string) {
314	t.Helper()
315	core := TestCoreRaw(t)
316	return testCoreUnsealed(t, core)
317}
318
319// TestCoreUnsealedWithConfig returns a pure in-memory core that is already
320// initialized, unsealed, with the any provided core config values overridden.
321func TestCoreUnsealedWithConfig(t testing.T, conf *CoreConfig) (*Core, [][]byte, string) {
322	t.Helper()
323	core := TestCoreWithConfig(t, conf)
324	return testCoreUnsealed(t, core)
325}
326
327func testCoreUnsealed(t testing.T, core *Core) (*Core, [][]byte, string) {
328	t.Helper()
329	keys, token := TestCoreInit(t, core)
330	for _, key := range keys {
331		if _, err := TestCoreUnseal(core, TestKeyCopy(key)); err != nil {
332			t.Fatalf("unseal err: %s", err)
333		}
334	}
335
336	if core.Sealed() {
337		t.Fatal("should not be sealed")
338	}
339
340	testCoreAddSecretMount(t, core, token)
341
342	return core, keys, token
343}
344
345func testCoreAddSecretMount(t testing.T, core *Core, token string) {
346	kvReq := &logical.Request{
347		Operation:   logical.UpdateOperation,
348		ClientToken: token,
349		Path:        "sys/mounts/secret",
350		Data: map[string]interface{}{
351			"type":        "kv",
352			"path":        "secret/",
353			"description": "key/value secret storage",
354			"options": map[string]string{
355				"version": "1",
356			},
357		},
358	}
359	resp, err := core.HandleRequest(namespace.RootContext(nil), kvReq)
360	if err != nil {
361		t.Fatal(err)
362	}
363	if resp.IsError() {
364		t.Fatal(err)
365	}
366
367}
368
369func TestCoreUnsealedBackend(t testing.T, backend physical.Backend) (*Core, [][]byte, string) {
370	t.Helper()
371	logger := logging.NewVaultLogger(log.Trace)
372	conf := testCoreConfig(t, backend, logger)
373	conf.Seal = NewTestSeal(t, nil)
374
375	core, err := NewCore(conf)
376	if err != nil {
377		t.Fatalf("err: %s", err)
378	}
379
380	keys, token := TestCoreInit(t, core)
381	for _, key := range keys {
382		if _, err := TestCoreUnseal(core, TestKeyCopy(key)); err != nil {
383			t.Fatalf("unseal err: %s", err)
384		}
385	}
386
387	if err := core.UnsealWithStoredKeys(context.Background()); err != nil {
388		t.Fatal(err)
389	}
390
391	if core.Sealed() {
392		t.Fatal("should not be sealed")
393	}
394
395	return core, keys, token
396}
397
// TestKeyCopy returns a freshly allocated slice holding the same bytes as
// key, so the copy can be handed to Unseal without sharing memory with the
// original. The result is always non-nil, even for a nil or empty key.
func TestKeyCopy(key []byte) []byte {
	return append(make([]byte, 0, len(key)), key...)
}
405
406func TestDynamicSystemView(c *Core) *dynamicSystemView {
407	me := &MountEntry{
408		Config: MountConfig{
409			DefaultLeaseTTL: 24 * time.Hour,
410			MaxLeaseTTL:     2 * 24 * time.Hour,
411		},
412	}
413
414	return &dynamicSystemView{c, me}
415}
416
417// TestAddTestPlugin registers the testFunc as part of the plugin command to the
418// plugin catalog. If provided, uses tmpDir as the plugin directory.
419func TestAddTestPlugin(t testing.T, c *Core, name string, pluginType consts.PluginType, testFunc string, env []string, tempDir string) {
420	file, err := os.Open(os.Args[0])
421	if err != nil {
422		t.Fatal(err)
423	}
424	defer file.Close()
425
426	dirPath := filepath.Dir(os.Args[0])
427	fileName := filepath.Base(os.Args[0])
428
429	if tempDir != "" {
430		fi, err := file.Stat()
431		if err != nil {
432			t.Fatal(err)
433		}
434
435		// Copy over the file to the temp dir
436		dst := filepath.Join(tempDir, fileName)
437		out, err := os.OpenFile(dst, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, fi.Mode())
438		if err != nil {
439			t.Fatal(err)
440		}
441		defer out.Close()
442
443		if _, err = io.Copy(out, file); err != nil {
444			t.Fatal(err)
445		}
446		err = out.Sync()
447		if err != nil {
448			t.Fatal(err)
449		}
450
451		dirPath = tempDir
452	}
453
454	// Determine plugin directory full path, evaluating potential symlink path
455	fullPath, err := filepath.EvalSymlinks(dirPath)
456	if err != nil {
457		t.Fatal(err)
458	}
459
460	reader, err := os.Open(filepath.Join(fullPath, fileName))
461	if err != nil {
462		t.Fatal(err)
463	}
464	defer reader.Close()
465
466	// Find out the sha256
467	hash := sha256.New()
468
469	_, err = io.Copy(hash, reader)
470	if err != nil {
471		t.Fatal(err)
472	}
473
474	sum := hash.Sum(nil)
475
476	// Set core's plugin directory and plugin catalog directory
477	c.pluginDirectory = fullPath
478	c.pluginCatalog.directory = fullPath
479
480	args := []string{fmt.Sprintf("--test.run=%s", testFunc)}
481	err = c.pluginCatalog.Set(context.Background(), name, pluginType, fileName, args, env, sum)
482	if err != nil {
483		t.Fatal(err)
484	}
485}
486
487var testLogicalBackends = map[string]logical.Factory{}
488var testCredentialBackends = map[string]logical.Factory{}
489
490// StartSSHHostTestServer starts the test server which responds to SSH
491// authentication. Used to test the SSH secret backend.
492func StartSSHHostTestServer() (string, error) {
493	pubKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(testSharedPublicKey))
494	if err != nil {
495		return "", fmt.Errorf("error parsing public key")
496	}
497	serverConfig := &ssh.ServerConfig{
498		PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
499			if bytes.Compare(pubKey.Marshal(), key.Marshal()) == 0 {
500				return &ssh.Permissions{}, nil
501			} else {
502				return nil, fmt.Errorf("key does not match")
503			}
504		},
505	}
506	signer, err := ssh.ParsePrivateKey([]byte(testSharedPrivateKey))
507	if err != nil {
508		panic("Error parsing private key")
509	}
510	serverConfig.AddHostKey(signer)
511
512	soc, err := net.Listen("tcp", "127.0.0.1:0")
513	if err != nil {
514		return "", fmt.Errorf("error listening to connection")
515	}
516
517	go func() {
518		for {
519			conn, err := soc.Accept()
520			if err != nil {
521				panic(fmt.Sprintf("Error accepting incoming connection: %s", err))
522			}
523			defer conn.Close()
524			sshConn, chanReqs, _, err := ssh.NewServerConn(conn, serverConfig)
525			if err != nil {
526				panic(fmt.Sprintf("Handshaking error: %v", err))
527			}
528
529			go func() {
530				for chanReq := range chanReqs {
531					go func(chanReq ssh.NewChannel) {
532						if chanReq.ChannelType() != "session" {
533							chanReq.Reject(ssh.UnknownChannelType, "unknown channel type")
534							return
535						}
536
537						ch, requests, err := chanReq.Accept()
538						if err != nil {
539							panic(fmt.Sprintf("Error accepting channel: %s", err))
540						}
541
542						go func(ch ssh.Channel, in <-chan *ssh.Request) {
543							for req := range in {
544								executeServerCommand(ch, req)
545							}
546						}(ch, requests)
547					}(chanReq)
548				}
549				sshConn.Close()
550			}()
551		}
552	}()
553	return soc.Addr().String(), nil
554}
555
556// This executes the commands requested to be run on the server.
557// Used to test the SSH secret backend.
558func executeServerCommand(ch ssh.Channel, req *ssh.Request) {
559	command := string(req.Payload[4:])
560	cmd := exec.Command("/bin/bash", []string{"-c", command}...)
561	req.Reply(true, nil)
562
563	cmd.Stdout = ch
564	cmd.Stderr = ch
565	cmd.Stdin = ch
566
567	err := cmd.Start()
568	if err != nil {
569		panic(fmt.Sprintf("Error starting the command: '%s'", err))
570	}
571
572	go func() {
573		_, err := cmd.Process.Wait()
574		if err != nil {
575			panic(fmt.Sprintf("Error while waiting for command to finish:'%s'", err))
576		}
577		ch.Close()
578	}()
579}
580
581// This adds a credential backend for the test core. This needs to be
582// invoked before the test core is created.
583func AddTestCredentialBackend(name string, factory logical.Factory) error {
584	if name == "" {
585		return fmt.Errorf("missing backend name")
586	}
587	if factory == nil {
588		return fmt.Errorf("missing backend factory function")
589	}
590	testCredentialBackends[name] = factory
591	return nil
592}
593
594// This adds a logical backend for the test core. This needs to be
595// invoked before the test core is created.
596func AddTestLogicalBackend(name string, factory logical.Factory) error {
597	if name == "" {
598		return fmt.Errorf("missing backend name")
599	}
600	if factory == nil {
601		return fmt.Errorf("missing backend factory function")
602	}
603	testLogicalBackends[name] = factory
604	return nil
605}
606
607type noopAudit struct {
608	Config    *audit.BackendConfig
609	salt      *salt.Salt
610	saltMutex sync.RWMutex
611	formatter audit.AuditFormatter
612	records   [][]byte
613	l         sync.RWMutex
614}
615
616func (n *noopAudit) GetHash(ctx context.Context, data string) (string, error) {
617	salt, err := n.Salt(ctx)
618	if err != nil {
619		return "", err
620	}
621	return salt.GetIdentifiedHMAC(data), nil
622}
623
624func (n *noopAudit) LogRequest(ctx context.Context, in *logical.LogInput) error {
625	n.l.Lock()
626	defer n.l.Unlock()
627	var w bytes.Buffer
628	err := n.formatter.FormatRequest(ctx, &w, audit.FormatterConfig{}, in)
629	if err != nil {
630		return err
631	}
632	n.records = append(n.records, w.Bytes())
633	return nil
634}
635
636func (n *noopAudit) LogResponse(ctx context.Context, in *logical.LogInput) error {
637	n.l.Lock()
638	defer n.l.Unlock()
639	var w bytes.Buffer
640	err := n.formatter.FormatResponse(ctx, &w, audit.FormatterConfig{}, in)
641	if err != nil {
642		return err
643	}
644	n.records = append(n.records, w.Bytes())
645	return nil
646}
647
648func (n *noopAudit) Reload(_ context.Context) error {
649	return nil
650}
651
652func (n *noopAudit) Invalidate(_ context.Context) {
653	n.saltMutex.Lock()
654	defer n.saltMutex.Unlock()
655	n.salt = nil
656}
657
658func (n *noopAudit) Salt(ctx context.Context) (*salt.Salt, error) {
659	n.saltMutex.RLock()
660	if n.salt != nil {
661		defer n.saltMutex.RUnlock()
662		return n.salt, nil
663	}
664	n.saltMutex.RUnlock()
665	n.saltMutex.Lock()
666	defer n.saltMutex.Unlock()
667	if n.salt != nil {
668		return n.salt, nil
669	}
670	salt, err := salt.NewSalt(ctx, n.Config.SaltView, n.Config.SaltConfig)
671	if err != nil {
672		return nil, err
673	}
674	n.salt = salt
675	return salt, nil
676}
677
678func AddNoopAudit(conf *CoreConfig) {
679	conf.AuditBackends = map[string]audit.Factory{
680		"noop": func(_ context.Context, config *audit.BackendConfig) (audit.Backend, error) {
681			view := &logical.InmemStorage{}
682			view.Put(context.Background(), &logical.StorageEntry{
683				Key:   "salt",
684				Value: []byte("foo"),
685			})
686			n := &noopAudit{
687				Config: config,
688			}
689			n.formatter.AuditFormatWriter = &audit.JSONFormatWriter{
690				SaltFunc: n.Salt,
691			}
692			return n, nil
693		},
694	}
695}
696
697type rawHTTP struct{}
698
699func (n *rawHTTP) HandleRequest(ctx context.Context, req *logical.Request) (*logical.Response, error) {
700	return &logical.Response{
701		Data: map[string]interface{}{
702			logical.HTTPStatusCode:  200,
703			logical.HTTPContentType: "plain/text",
704			logical.HTTPRawBody:     []byte("hello world"),
705		},
706	}, nil
707}
708
709func (n *rawHTTP) HandleExistenceCheck(ctx context.Context, req *logical.Request) (bool, bool, error) {
710	return false, false, nil
711}
712
713func (n *rawHTTP) SpecialPaths() *logical.Paths {
714	return &logical.Paths{Unauthenticated: []string{"*"}}
715}
716
717func (n *rawHTTP) System() logical.SystemView {
718	return logical.StaticSystemView{
719		DefaultLeaseTTLVal: time.Hour * 24,
720		MaxLeaseTTLVal:     time.Hour * 24 * 32,
721	}
722}
723
724func (n *rawHTTP) Logger() log.Logger {
725	return logging.NewVaultLogger(log.Trace)
726}
727
728func (n *rawHTTP) Cleanup(ctx context.Context) {
729	// noop
730}
731
732func (n *rawHTTP) Initialize(ctx context.Context, req *logical.InitializationRequest) error {
733	return nil
734}
735
736func (n *rawHTTP) InvalidateKey(context.Context, string) {
737	// noop
738}
739
740func (n *rawHTTP) Setup(ctx context.Context, config *logical.BackendConfig) error {
741	// noop
742	return nil
743}
744
745func (n *rawHTTP) Type() logical.BackendType {
746	return logical.TypeLogical
747}
748
// GenerateRandBytes returns length bytes read from crypto/rand.
//
// A negative length is an error. A zero length yields an empty, non-nil slice
// without touching the random source. A read error, or a read that returns
// fewer bytes than requested, is reported as an error with a nil slice.
func GenerateRandBytes(length int) ([]byte, error) {
	switch {
	case length < 0:
		return nil, fmt.Errorf("length must be >= 0")
	case length == 0:
		return make([]byte, 0), nil
	}

	out := make([]byte, length)
	read, err := rand.Read(out)
	switch {
	case err != nil:
		return nil, err
	case read != length:
		return nil, fmt.Errorf("unable to read %d bytes; only read %d", length, read)
	}
	return out, nil
}
769
770func TestWaitActive(t testing.T, core *Core) {
771	t.Helper()
772	if err := TestWaitActiveWithError(core); err != nil {
773		t.Fatal(err)
774	}
775}
776
777func TestWaitActiveWithError(core *Core) error {
778	start := time.Now()
779	var standby bool
780	var err error
781	for time.Now().Sub(start) < 30*time.Second {
782		standby, err = core.Standby()
783		if err != nil {
784			return err
785		}
786		if !standby {
787			break
788		}
789	}
790	if standby {
791		return errors.New("should not be in standby mode")
792	}
793	return nil
794}
795
796type TestCluster struct {
797	BarrierKeys        [][]byte
798	RecoveryKeys       [][]byte
799	CACert             *x509.Certificate
800	CACertBytes        []byte
801	CACertPEM          []byte
802	CACertPEMFile      string
803	CAKey              *ecdsa.PrivateKey
804	CAKeyPEM           []byte
805	Cores              []*TestClusterCore
806	ID                 string
807	RootToken          string
808	RootCAs            *x509.CertPool
809	TempDir            string
810	ClientAuthRequired bool
811	Logger             log.Logger
812	CleanupFunc        func()
813	SetupFunc          func()
814}
815
816func (c *TestCluster) Start() {
817	for _, core := range c.Cores {
818		if core.Server != nil {
819			for _, ln := range core.Listeners {
820				go core.Server.Serve(ln)
821			}
822		}
823	}
824	if c.SetupFunc != nil {
825		c.SetupFunc()
826	}
827}
828
829// UnsealCores uses the cluster barrier keys to unseal the test cluster cores
830func (c *TestCluster) UnsealCores(t testing.T) {
831	t.Helper()
832	if err := c.UnsealCoresWithError(false); err != nil {
833		t.Fatal(err)
834	}
835}
836
837func (c *TestCluster) UnsealCoresWithError(useStoredKeys bool) error {
838	unseal := func(core *Core) error {
839		for _, key := range c.BarrierKeys {
840			if _, err := core.Unseal(TestKeyCopy(key)); err != nil {
841				return err
842			}
843		}
844		return nil
845	}
846	if useStoredKeys {
847		unseal = func(core *Core) error {
848			return core.UnsealWithStoredKeys(context.Background())
849		}
850	}
851
852	// Unseal first core
853	if err := unseal(c.Cores[0].Core); err != nil {
854		return fmt.Errorf("unseal core %d err: %s", 0, err)
855	}
856
857	// Verify unsealed
858	if c.Cores[0].Sealed() {
859		return fmt.Errorf("should not be sealed")
860	}
861
862	if err := TestWaitActiveWithError(c.Cores[0].Core); err != nil {
863		return err
864	}
865
866	// Unseal other cores
867	for i := 1; i < len(c.Cores); i++ {
868		if err := unseal(c.Cores[i].Core); err != nil {
869			return fmt.Errorf("unseal core %d err: %s", i, err)
870		}
871	}
872
873	// Let them come fully up to standby
874	time.Sleep(2 * time.Second)
875
876	// Ensure cluster connection info is populated.
877	// Other cores should not come up as leaders.
878	for i := 1; i < len(c.Cores); i++ {
879		isLeader, _, _, err := c.Cores[i].Leader()
880		if err != nil {
881			return err
882		}
883		if isLeader {
884			return fmt.Errorf("core[%d] should not be leader", i)
885		}
886	}
887
888	return nil
889}
890
891func (c *TestCluster) UnsealCore(t testing.T, core *TestClusterCore) {
892	for _, key := range c.BarrierKeys {
893		if _, err := core.Core.Unseal(TestKeyCopy(key)); err != nil {
894			t.Fatalf("unseal err: %s", err)
895		}
896	}
897}
898
899func (c *TestCluster) EnsureCoresSealed(t testing.T) {
900	t.Helper()
901	if err := c.ensureCoresSealed(); err != nil {
902		t.Fatal(err)
903	}
904}
905
906func (c *TestClusterCore) Seal(t testing.T) {
907	t.Helper()
908	if err := c.Core.sealInternal(); err != nil {
909		t.Fatal(err)
910	}
911}
912
913func CleanupClusters(clusters []*TestCluster) {
914	wg := &sync.WaitGroup{}
915	for _, cluster := range clusters {
916		wg.Add(1)
917		lc := cluster
918		go func() {
919			defer wg.Done()
920			lc.Cleanup()
921		}()
922	}
923	wg.Wait()
924}
925
926func (c *TestCluster) Cleanup() {
927	c.Logger.Info("cleaning up vault cluster")
928	for _, core := range c.Cores {
929		core.CoreConfig.Logger.SetLevel(log.Error)
930	}
931
932	// Close listeners
933	wg := &sync.WaitGroup{}
934	for _, core := range c.Cores {
935		wg.Add(1)
936		lc := core
937
938		go func() {
939			defer wg.Done()
940			if lc.Listeners != nil {
941				for _, ln := range lc.Listeners {
942					ln.Close()
943				}
944			}
945			if lc.licensingStopCh != nil {
946				close(lc.licensingStopCh)
947				lc.licensingStopCh = nil
948			}
949
950			if err := lc.Shutdown(); err != nil {
951				lc.Logger().Error("error during shutdown; abandoning sealing", "error", err)
952			} else {
953				timeout := time.Now().Add(60 * time.Second)
954				for {
955					if time.Now().After(timeout) {
956						lc.Logger().Error("timeout waiting for core to seal")
957					}
958					if lc.Sealed() {
959						break
960					}
961					time.Sleep(250 * time.Millisecond)
962				}
963			}
964		}()
965	}
966
967	wg.Wait()
968
969	// Remove any temp dir that exists
970	if c.TempDir != "" {
971		os.RemoveAll(c.TempDir)
972	}
973
974	// Give time to actually shut down/clean up before the next test
975	time.Sleep(time.Second)
976	if c.CleanupFunc != nil {
977		c.CleanupFunc()
978	}
979}
980
981func (c *TestCluster) ensureCoresSealed() error {
982	for _, core := range c.Cores {
983		if err := core.Shutdown(); err != nil {
984			return err
985		}
986		timeout := time.Now().Add(60 * time.Second)
987		for {
988			if time.Now().After(timeout) {
989				return fmt.Errorf("timeout waiting for core to seal")
990			}
991			if core.Sealed() {
992				break
993			}
994			time.Sleep(250 * time.Millisecond)
995		}
996	}
997	return nil
998}
999
1000func SetReplicationFailureMode(core *TestClusterCore, mode uint32) {
1001	atomic.StoreUint32(core.Core.replicationFailure, mode)
1002}
1003
// TestListener pairs a network listener with a TCP address for it.
type TestListener struct {
	net.Listener

	// Address is the TCP address associated with the embedded listener.
	Address *net.TCPAddr
}
1008
1009type TestClusterCore struct {
1010	*Core
1011	CoreConfig           *CoreConfig
1012	Client               *api.Client
1013	Handler              http.Handler
1014	Listeners            []*TestListener
1015	ReloadFuncs          *map[string][]reload.ReloadFunc
1016	ReloadFuncsLock      *sync.RWMutex
1017	Server               *http.Server
1018	ServerCert           *x509.Certificate
1019	ServerCertBytes      []byte
1020	ServerCertPEM        []byte
1021	ServerKey            *ecdsa.PrivateKey
1022	ServerKeyPEM         []byte
1023	TLSConfig            *tls.Config
1024	UnderlyingStorage    physical.Backend
1025	UnderlyingRawStorage physical.Backend
1026	Barrier              SecurityBarrier
1027	NodeID               string
1028}
1029
1030type PhysicalBackendBundle struct {
1031	Backend   physical.Backend
1032	HABackend physical.HABackend
1033	Cleanup   func()
1034}
1035
1036type TestClusterOptions struct {
1037	KeepStandbysSealed       bool
1038	SkipInit                 bool
1039	HandlerFunc              func(*HandlerProperties) http.Handler
1040	DefaultHandlerProperties HandlerProperties
1041	BaseListenAddress        string
1042	NumCores                 int
1043	SealFunc                 func() Seal
1044	Logger                   log.Logger
1045	TempDir                  string
1046	CACert                   []byte
1047	CAKey                    *ecdsa.PrivateKey
1048	// PhysicalFactory is used to create backends.
1049	// The int argument is the index of the core within the cluster, i.e. first
1050	// core in cluster will have 0, second 1, etc.
1051	// If the backend is shared across the cluster (i.e. is not Raft) then it
1052	// should return nil when coreIdx != 0.
1053	PhysicalFactory func(t testing.T, coreIdx int, logger hclog.Logger) *PhysicalBackendBundle
1054	// FirstCoreNumber is used to assign a unique number to each core within
1055	// a multi-cluster setup.
1056	FirstCoreNumber   int
1057	RequireClientAuth bool
1058	// SetupFunc is called after the cluster is started.
1059	SetupFunc func(t testing.T, c *TestCluster)
1060}
1061
// DefaultNumCores is the number of cores NewTestCluster creates when the
// options are nil or leave NumCores at zero.
var DefaultNumCores int = 3
1063
// certInfo groups a certificate with its ECDSA private key, holding the
// certificate in parsed, PEM and raw-byte form and the key in parsed and PEM
// form.
type certInfo struct {
	// Certificate representations.
	cert      *x509.Certificate
	certPEM   []byte
	certBytes []byte

	// Private key representations.
	key    *ecdsa.PrivateKey
	keyPEM []byte
}
1071
1072// NewTestCluster creates a new test cluster based on the provided core config
1073// and test cluster options.
1074//
1075// N.B. Even though a single base CoreConfig is provided, NewTestCluster will instantiate a
1076// core config for each core it creates. If separate seal per core is desired, opts.SealFunc
1077// can be provided to generate a seal for each one. Otherwise, the provided base.Seal will be
1078// shared among cores. NewCore's default behavior is to generate a new DefaultSeal if the
1079// provided Seal in coreConfig (i.e. base.Seal) is nil.
1080//
1081// If opts.Logger is provided, it takes precedence and will be used as the cluster
1082// logger and will be the basis for each core's logger.  If no opts.Logger is
1083// given, one will be generated based on t.Name() for the cluster logger, and if
1084// no base.Logger is given will also be used as the basis for each core's logger.
1085func NewTestCluster(t testing.T, base *CoreConfig, opts *TestClusterOptions) *TestCluster {
1086	var err error
1087
1088	var numCores int
1089	if opts == nil || opts.NumCores == 0 {
1090		numCores = DefaultNumCores
1091	} else {
1092		numCores = opts.NumCores
1093	}
1094
1095	var firstCoreNumber int
1096	if opts != nil {
1097		firstCoreNumber = opts.FirstCoreNumber
1098	}
1099
1100	certIPs := []net.IP{
1101		net.IPv6loopback,
1102		net.ParseIP("127.0.0.1"),
1103	}
1104	var baseAddr *net.TCPAddr
1105	if opts != nil && opts.BaseListenAddress != "" {
1106		baseAddr, err = net.ResolveTCPAddr("tcp", opts.BaseListenAddress)
1107		if err != nil {
1108			t.Fatal("could not parse given base IP")
1109		}
1110		certIPs = append(certIPs, baseAddr.IP)
1111	}
1112
1113	var testCluster TestCluster
1114
1115	if opts != nil && opts.Logger != nil {
1116		testCluster.Logger = opts.Logger
1117	} else {
1118		testCluster.Logger = logging.NewVaultLogger(log.Trace).Named(t.Name())
1119	}
1120
1121	if opts != nil && opts.TempDir != "" {
1122		if _, err := os.Stat(opts.TempDir); os.IsNotExist(err) {
1123			if err := os.MkdirAll(opts.TempDir, 0700); err != nil {
1124				t.Fatal(err)
1125			}
1126		}
1127		testCluster.TempDir = opts.TempDir
1128	} else {
1129		tempDir, err := ioutil.TempDir("", "vault-test-cluster-")
1130		if err != nil {
1131			t.Fatal(err)
1132		}
1133		testCluster.TempDir = tempDir
1134	}
1135
1136	var caKey *ecdsa.PrivateKey
1137	if opts != nil && opts.CAKey != nil {
1138		caKey = opts.CAKey
1139	} else {
1140		caKey, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
1141		if err != nil {
1142			t.Fatal(err)
1143		}
1144	}
1145	testCluster.CAKey = caKey
1146	var caBytes []byte
1147	if opts != nil && len(opts.CACert) > 0 {
1148		caBytes = opts.CACert
1149	} else {
1150		caCertTemplate := &x509.Certificate{
1151			Subject: pkix.Name{
1152				CommonName: "localhost",
1153			},
1154			DNSNames:              []string{"localhost"},
1155			IPAddresses:           certIPs,
1156			KeyUsage:              x509.KeyUsage(x509.KeyUsageCertSign | x509.KeyUsageCRLSign),
1157			SerialNumber:          big.NewInt(mathrand.Int63()),
1158			NotBefore:             time.Now().Add(-30 * time.Second),
1159			NotAfter:              time.Now().Add(262980 * time.Hour),
1160			BasicConstraintsValid: true,
1161			IsCA:                  true,
1162		}
1163		caBytes, err = x509.CreateCertificate(rand.Reader, caCertTemplate, caCertTemplate, caKey.Public(), caKey)
1164		if err != nil {
1165			t.Fatal(err)
1166		}
1167	}
1168	caCert, err := x509.ParseCertificate(caBytes)
1169	if err != nil {
1170		t.Fatal(err)
1171	}
1172	testCluster.CACert = caCert
1173	testCluster.CACertBytes = caBytes
1174	testCluster.RootCAs = x509.NewCertPool()
1175	testCluster.RootCAs.AddCert(caCert)
1176	caCertPEMBlock := &pem.Block{
1177		Type:  "CERTIFICATE",
1178		Bytes: caBytes,
1179	}
1180	testCluster.CACertPEM = pem.EncodeToMemory(caCertPEMBlock)
1181	testCluster.CACertPEMFile = filepath.Join(testCluster.TempDir, "ca_cert.pem")
1182	err = ioutil.WriteFile(testCluster.CACertPEMFile, testCluster.CACertPEM, 0755)
1183	if err != nil {
1184		t.Fatal(err)
1185	}
1186	marshaledCAKey, err := x509.MarshalECPrivateKey(caKey)
1187	if err != nil {
1188		t.Fatal(err)
1189	}
1190	caKeyPEMBlock := &pem.Block{
1191		Type:  "EC PRIVATE KEY",
1192		Bytes: marshaledCAKey,
1193	}
1194	testCluster.CAKeyPEM = pem.EncodeToMemory(caKeyPEMBlock)
1195	err = ioutil.WriteFile(filepath.Join(testCluster.TempDir, "ca_key.pem"), testCluster.CAKeyPEM, 0755)
1196	if err != nil {
1197		t.Fatal(err)
1198	}
1199
1200	var certInfoSlice []*certInfo
1201
1202	//
1203	// Certs generation
1204	//
1205	for i := 0; i < numCores; i++ {
1206		key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
1207		if err != nil {
1208			t.Fatal(err)
1209		}
1210		certTemplate := &x509.Certificate{
1211			Subject: pkix.Name{
1212				CommonName: "localhost",
1213			},
1214			DNSNames:    []string{"localhost"},
1215			IPAddresses: certIPs,
1216			ExtKeyUsage: []x509.ExtKeyUsage{
1217				x509.ExtKeyUsageServerAuth,
1218				x509.ExtKeyUsageClientAuth,
1219			},
1220			KeyUsage:     x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageKeyAgreement,
1221			SerialNumber: big.NewInt(mathrand.Int63()),
1222			NotBefore:    time.Now().Add(-30 * time.Second),
1223			NotAfter:     time.Now().Add(262980 * time.Hour),
1224		}
1225		certBytes, err := x509.CreateCertificate(rand.Reader, certTemplate, caCert, key.Public(), caKey)
1226		if err != nil {
1227			t.Fatal(err)
1228		}
1229		cert, err := x509.ParseCertificate(certBytes)
1230		if err != nil {
1231			t.Fatal(err)
1232		}
1233		certPEMBlock := &pem.Block{
1234			Type:  "CERTIFICATE",
1235			Bytes: certBytes,
1236		}
1237		certPEM := pem.EncodeToMemory(certPEMBlock)
1238		marshaledKey, err := x509.MarshalECPrivateKey(key)
1239		if err != nil {
1240			t.Fatal(err)
1241		}
1242		keyPEMBlock := &pem.Block{
1243			Type:  "EC PRIVATE KEY",
1244			Bytes: marshaledKey,
1245		}
1246		keyPEM := pem.EncodeToMemory(keyPEMBlock)
1247
1248		certInfoSlice = append(certInfoSlice, &certInfo{
1249			cert:      cert,
1250			certPEM:   certPEM,
1251			certBytes: certBytes,
1252			key:       key,
1253			keyPEM:    keyPEM,
1254		})
1255	}
1256
1257	//
1258	// Listener setup
1259	//
1260	ports := make([]int, numCores)
1261	if baseAddr != nil {
1262		for i := 0; i < numCores; i++ {
1263			ports[i] = baseAddr.Port + i
1264		}
1265	} else {
1266		baseAddr = &net.TCPAddr{
1267			IP:   net.ParseIP("127.0.0.1"),
1268			Port: 0,
1269		}
1270	}
1271
1272	listeners := [][]*TestListener{}
1273	servers := []*http.Server{}
1274	handlers := []http.Handler{}
1275	tlsConfigs := []*tls.Config{}
1276	certGetters := []*reload.CertificateGetter{}
1277	for i := 0; i < numCores; i++ {
1278		baseAddr.Port = ports[i]
1279		ln, err := net.ListenTCP("tcp", baseAddr)
1280		if err != nil {
1281			t.Fatal(err)
1282		}
1283		certFile := filepath.Join(testCluster.TempDir, fmt.Sprintf("node%d_port_%d_cert.pem", i+1, ln.Addr().(*net.TCPAddr).Port))
1284		keyFile := filepath.Join(testCluster.TempDir, fmt.Sprintf("node%d_port_%d_key.pem", i+1, ln.Addr().(*net.TCPAddr).Port))
1285		err = ioutil.WriteFile(certFile, certInfoSlice[i].certPEM, 0755)
1286		if err != nil {
1287			t.Fatal(err)
1288		}
1289		err = ioutil.WriteFile(keyFile, certInfoSlice[i].keyPEM, 0755)
1290		if err != nil {
1291			t.Fatal(err)
1292		}
1293		tlsCert, err := tls.X509KeyPair(certInfoSlice[i].certPEM, certInfoSlice[i].keyPEM)
1294		if err != nil {
1295			t.Fatal(err)
1296		}
1297		certGetter := reload.NewCertificateGetter(certFile, keyFile, "")
1298		certGetters = append(certGetters, certGetter)
1299		tlsConfig := &tls.Config{
1300			Certificates:   []tls.Certificate{tlsCert},
1301			RootCAs:        testCluster.RootCAs,
1302			ClientCAs:      testCluster.RootCAs,
1303			ClientAuth:     tls.RequestClientCert,
1304			NextProtos:     []string{"h2", "http/1.1"},
1305			GetCertificate: certGetter.GetCertificate,
1306		}
1307		if opts != nil && opts.RequireClientAuth {
1308			tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
1309			testCluster.ClientAuthRequired = true
1310		}
1311		tlsConfig.BuildNameToCertificate()
1312		tlsConfigs = append(tlsConfigs, tlsConfig)
1313		lns := []*TestListener{&TestListener{
1314			Listener: tls.NewListener(ln, tlsConfig),
1315			Address:  ln.Addr().(*net.TCPAddr),
1316		},
1317		}
1318		listeners = append(listeners, lns)
1319		var handler http.Handler = http.NewServeMux()
1320		handlers = append(handlers, handler)
1321		server := &http.Server{
1322			Handler:  handler,
1323			ErrorLog: testCluster.Logger.StandardLogger(nil),
1324		}
1325		servers = append(servers, server)
1326	}
1327
1328	// Create three cores with the same physical and different redirect/cluster
1329	// addrs.
1330	// N.B.: On OSX, instead of random ports, it assigns new ports to new
1331	// listeners sequentially. Aside from being a bad idea in a security sense,
1332	// it also broke tests that assumed it was OK to just use the port above
1333	// the redirect addr. This has now been changed to 105 ports above, but if
1334	// we ever do more than three nodes in a cluster it may need to be bumped.
1335	// Note: it's 105 so that we don't conflict with a running Consul by
1336	// default.
1337	coreConfig := &CoreConfig{
1338		LogicalBackends:    make(map[string]logical.Factory),
1339		CredentialBackends: make(map[string]logical.Factory),
1340		AuditBackends:      make(map[string]audit.Factory),
1341		RedirectAddr:       fmt.Sprintf("https://127.0.0.1:%d", listeners[0][0].Address.Port),
1342		ClusterAddr:        "https://127.0.0.1:0",
1343		DisableMlock:       true,
1344		EnableUI:           true,
1345		EnableRaw:          true,
1346		BuiltinRegistry:    NewMockBuiltinRegistry(),
1347	}
1348
1349	if base != nil {
1350		coreConfig.RawConfig = base.RawConfig
1351		coreConfig.DisableCache = base.DisableCache
1352		coreConfig.EnableUI = base.EnableUI
1353		coreConfig.DefaultLeaseTTL = base.DefaultLeaseTTL
1354		coreConfig.MaxLeaseTTL = base.MaxLeaseTTL
1355		coreConfig.CacheSize = base.CacheSize
1356		coreConfig.PluginDirectory = base.PluginDirectory
1357		coreConfig.Seal = base.Seal
1358		coreConfig.DevToken = base.DevToken
1359		coreConfig.EnableRaw = base.EnableRaw
1360		coreConfig.DisableSealWrap = base.DisableSealWrap
1361		coreConfig.DevLicenseDuration = base.DevLicenseDuration
1362		coreConfig.DisableCache = base.DisableCache
1363		coreConfig.LicensingConfig = base.LicensingConfig
1364		coreConfig.DisablePerformanceStandby = base.DisablePerformanceStandby
1365		coreConfig.MetricsHelper = base.MetricsHelper
1366		coreConfig.SecureRandomReader = base.SecureRandomReader
1367		if base.BuiltinRegistry != nil {
1368			coreConfig.BuiltinRegistry = base.BuiltinRegistry
1369		}
1370
1371		if !coreConfig.DisableMlock {
1372			base.DisableMlock = false
1373		}
1374
1375		if base.Physical != nil {
1376			coreConfig.Physical = base.Physical
1377		}
1378
1379		if base.HAPhysical != nil {
1380			coreConfig.HAPhysical = base.HAPhysical
1381		}
1382
1383		// Used to set something non-working to test fallback
1384		switch base.ClusterAddr {
1385		case "empty":
1386			coreConfig.ClusterAddr = ""
1387		case "":
1388		default:
1389			coreConfig.ClusterAddr = base.ClusterAddr
1390		}
1391
1392		if base.LogicalBackends != nil {
1393			for k, v := range base.LogicalBackends {
1394				coreConfig.LogicalBackends[k] = v
1395			}
1396		}
1397		if base.CredentialBackends != nil {
1398			for k, v := range base.CredentialBackends {
1399				coreConfig.CredentialBackends[k] = v
1400			}
1401		}
1402		if base.AuditBackends != nil {
1403			for k, v := range base.AuditBackends {
1404				coreConfig.AuditBackends[k] = v
1405			}
1406		}
1407		if base.Logger != nil {
1408			coreConfig.Logger = base.Logger
1409		}
1410
1411		coreConfig.ClusterCipherSuites = base.ClusterCipherSuites
1412
1413		coreConfig.DisableCache = base.DisableCache
1414
1415		coreConfig.DevToken = base.DevToken
1416		coreConfig.CounterSyncInterval = base.CounterSyncInterval
1417		coreConfig.RecoveryMode = base.RecoveryMode
1418	}
1419
1420	if coreConfig.RawConfig == nil {
1421		coreConfig.RawConfig = new(server.Config)
1422	}
1423
1424	addAuditBackend := len(coreConfig.AuditBackends) == 0
1425	if addAuditBackend {
1426		AddNoopAudit(coreConfig)
1427	}
1428
1429	if coreConfig.Physical == nil && (opts == nil || opts.PhysicalFactory == nil) {
1430		coreConfig.Physical, err = physInmem.NewInmem(nil, testCluster.Logger)
1431		if err != nil {
1432			t.Fatal(err)
1433		}
1434	}
1435	if coreConfig.HAPhysical == nil && (opts == nil || opts.PhysicalFactory == nil) {
1436		haPhys, err := physInmem.NewInmemHA(nil, testCluster.Logger)
1437		if err != nil {
1438			t.Fatal(err)
1439		}
1440		coreConfig.HAPhysical = haPhys.(physical.HABackend)
1441	}
1442
1443	pubKey, priKey, err := testGenerateCoreKeys()
1444	if err != nil {
1445		t.Fatalf("err: %v", err)
1446	}
1447
1448	cleanupFuncs := []func(){}
1449	cores := []*Core{}
1450	coreConfigs := []*CoreConfig{}
1451	for i := 0; i < numCores; i++ {
1452		localConfig := *coreConfig
1453		localConfig.RedirectAddr = fmt.Sprintf("https://127.0.0.1:%d", listeners[i][0].Address.Port)
1454
1455		// if opts.SealFunc is provided, use that to generate a seal for the config instead
1456		if opts != nil && opts.SealFunc != nil {
1457			localConfig.Seal = opts.SealFunc()
1458		}
1459
1460		if coreConfig.Logger == nil || (opts != nil && opts.Logger != nil) {
1461			localConfig.Logger = testCluster.Logger.Named(fmt.Sprintf("core%d", i))
1462		}
1463		if opts != nil && opts.PhysicalFactory != nil {
1464			physBundle := opts.PhysicalFactory(t, i, localConfig.Logger)
1465			switch {
1466			case physBundle == nil && coreConfig.Physical != nil:
1467			case physBundle == nil && coreConfig.Physical == nil:
1468				t.Fatal("PhysicalFactory produced no physical and none in CoreConfig")
1469			case physBundle != nil:
1470				testCluster.Logger.Info("created physical backend", "instance", i)
1471				coreConfig.Physical = physBundle.Backend
1472				localConfig.Physical = physBundle.Backend
1473				base.Physical = physBundle.Backend
1474				haBackend := physBundle.HABackend
1475				if haBackend == nil {
1476					if ha, ok := physBundle.Backend.(physical.HABackend); ok {
1477						haBackend = ha
1478					}
1479				}
1480				coreConfig.HAPhysical = haBackend
1481				localConfig.HAPhysical = haBackend
1482				if physBundle.Cleanup != nil {
1483					cleanupFuncs = append(cleanupFuncs, physBundle.Cleanup)
1484				}
1485			}
1486		}
1487
1488		switch {
1489		case localConfig.LicensingConfig != nil:
1490			if pubKey != nil {
1491				localConfig.LicensingConfig.AdditionalPublicKeys = append(localConfig.LicensingConfig.AdditionalPublicKeys, pubKey.(ed25519.PublicKey))
1492			}
1493		default:
1494			localConfig.LicensingConfig = testGetLicensingConfig(pubKey)
1495		}
1496
1497		if localConfig.MetricsHelper == nil {
1498			inm := metrics.NewInmemSink(10*time.Second, time.Minute)
1499			metrics.DefaultInmemSignal(inm)
1500			localConfig.MetricsHelper = metricsutil.NewMetricsHelper(inm, false)
1501		}
1502
1503		c, err := NewCore(&localConfig)
1504		if err != nil {
1505			t.Fatalf("err: %v", err)
1506		}
1507		c.coreNumber = firstCoreNumber + i
1508		cores = append(cores, c)
1509		coreConfigs = append(coreConfigs, &localConfig)
1510		if opts != nil && opts.HandlerFunc != nil {
1511			props := opts.DefaultHandlerProperties
1512			props.Core = c
1513			if props.MaxRequestDuration == 0 {
1514				props.MaxRequestDuration = DefaultMaxRequestDuration
1515			}
1516			handlers[i] = opts.HandlerFunc(&props)
1517			servers[i].Handler = handlers[i]
1518		}
1519
1520		// Set this in case the Seal was manually set before the core was
1521		// created
1522		if localConfig.Seal != nil {
1523			localConfig.Seal.SetCore(c)
1524		}
1525	}
1526
1527	//
1528	// Clustering setup
1529	//
1530	clusterAddrGen := func(lns []*TestListener) []*net.TCPAddr {
1531		ret := make([]*net.TCPAddr, len(lns))
1532		for i, ln := range lns {
1533			ret[i] = &net.TCPAddr{
1534				IP:   ln.Address.IP,
1535				Port: 0,
1536			}
1537		}
1538		return ret
1539	}
1540
1541	for i := 0; i < numCores; i++ {
1542		if coreConfigs[i].ClusterAddr != "" {
1543			cores[i].SetClusterListenerAddrs(clusterAddrGen(listeners[i]))
1544			cores[i].SetClusterHandler(handlers[i])
1545		}
1546	}
1547
1548	if opts == nil || !opts.SkipInit {
1549		bKeys, rKeys, root := TestCoreInitClusterWrapperSetup(t, cores[0], handlers[0])
1550		barrierKeys, _ := copystructure.Copy(bKeys)
1551		testCluster.BarrierKeys = barrierKeys.([][]byte)
1552		recoveryKeys, _ := copystructure.Copy(rKeys)
1553		testCluster.RecoveryKeys = recoveryKeys.([][]byte)
1554		testCluster.RootToken = root
1555
1556		// Write root token and barrier keys
1557		err = ioutil.WriteFile(filepath.Join(testCluster.TempDir, "root_token"), []byte(root), 0755)
1558		if err != nil {
1559			t.Fatal(err)
1560		}
1561		var buf bytes.Buffer
1562		for i, key := range testCluster.BarrierKeys {
1563			buf.Write([]byte(base64.StdEncoding.EncodeToString(key)))
1564			if i < len(testCluster.BarrierKeys)-1 {
1565				buf.WriteRune('\n')
1566			}
1567		}
1568		err = ioutil.WriteFile(filepath.Join(testCluster.TempDir, "barrier_keys"), buf.Bytes(), 0755)
1569		if err != nil {
1570			t.Fatal(err)
1571		}
1572		for i, key := range testCluster.RecoveryKeys {
1573			buf.Write([]byte(base64.StdEncoding.EncodeToString(key)))
1574			if i < len(testCluster.RecoveryKeys)-1 {
1575				buf.WriteRune('\n')
1576			}
1577		}
1578		err = ioutil.WriteFile(filepath.Join(testCluster.TempDir, "recovery_keys"), buf.Bytes(), 0755)
1579		if err != nil {
1580			t.Fatal(err)
1581		}
1582
1583		// Unseal first core
1584		for _, key := range bKeys {
1585			if _, err := cores[0].Unseal(TestKeyCopy(key)); err != nil {
1586				t.Fatalf("unseal err: %s", err)
1587			}
1588		}
1589
1590		ctx := context.Background()
1591
1592		// If stored keys is supported, the above will no no-op, so trigger auto-unseal
1593		// using stored keys to try to unseal
1594		if err := cores[0].UnsealWithStoredKeys(ctx); err != nil {
1595			t.Fatal(err)
1596		}
1597
1598		// Verify unsealed
1599		if cores[0].Sealed() {
1600			t.Fatal("should not be sealed")
1601		}
1602
1603		TestWaitActive(t, cores[0])
1604
1605		// Existing tests rely on this; we can make a toggle to disable it
1606		// later if we want
1607		kvReq := &logical.Request{
1608			Operation:   logical.UpdateOperation,
1609			ClientToken: testCluster.RootToken,
1610			Path:        "sys/mounts/secret",
1611			Data: map[string]interface{}{
1612				"type":        "kv",
1613				"path":        "secret/",
1614				"description": "key/value secret storage",
1615				"options": map[string]string{
1616					"version": "1",
1617				},
1618			},
1619		}
1620		resp, err := cores[0].HandleRequest(namespace.RootContext(ctx), kvReq)
1621		if err != nil {
1622			t.Fatal(err)
1623		}
1624		if resp.IsError() {
1625			t.Fatal(err)
1626		}
1627
1628		cfg, err := cores[0].seal.BarrierConfig(ctx)
1629		if err != nil {
1630			t.Fatal(err)
1631		}
1632
1633		// Unseal other cores unless otherwise specified
1634		if (opts == nil || !opts.KeepStandbysSealed) && numCores > 1 {
1635			for i := 1; i < numCores; i++ {
1636				cores[i].seal.SetCachedBarrierConfig(cfg)
1637				for _, key := range bKeys {
1638					if _, err := cores[i].Unseal(TestKeyCopy(key)); err != nil {
1639						t.Fatalf("unseal err: %s", err)
1640					}
1641				}
1642
1643				// If stored keys is supported, the above will no no-op, so trigger auto-unseal
1644				// using stored keys
1645				if err := cores[i].UnsealWithStoredKeys(ctx); err != nil {
1646					t.Fatal(err)
1647				}
1648			}
1649
1650			// Let them come fully up to standby
1651			time.Sleep(2 * time.Second)
1652
1653			// Ensure cluster connection info is populated.
1654			// Other cores should not come up as leaders.
1655			for i := 1; i < numCores; i++ {
1656				isLeader, _, _, err := cores[i].Leader()
1657				if err != nil {
1658					t.Fatal(err)
1659				}
1660				if isLeader {
1661					t.Fatalf("core[%d] should not be leader", i)
1662				}
1663			}
1664		}
1665
1666		//
1667		// Set test cluster core(s) and test cluster
1668		//
1669		cluster, err := cores[0].Cluster(context.Background())
1670		if err != nil {
1671			t.Fatal(err)
1672		}
1673		testCluster.ID = cluster.ID
1674
1675		if addAuditBackend {
1676			// Enable auditing.
1677			auditReq := &logical.Request{
1678				Operation:   logical.UpdateOperation,
1679				ClientToken: testCluster.RootToken,
1680				Path:        "sys/audit/noop",
1681				Data: map[string]interface{}{
1682					"type": "noop",
1683				},
1684			}
1685			resp, err = cores[0].HandleRequest(namespace.RootContext(ctx), auditReq)
1686			if err != nil {
1687				t.Fatal(err)
1688			}
1689
1690			if resp.IsError() {
1691				t.Fatal(err)
1692			}
1693		}
1694	}
1695
1696	getAPIClient := func(port int, tlsConfig *tls.Config) *api.Client {
1697		transport := cleanhttp.DefaultPooledTransport()
1698		transport.TLSClientConfig = tlsConfig.Clone()
1699		if err := http2.ConfigureTransport(transport); err != nil {
1700			t.Fatal(err)
1701		}
1702		client := &http.Client{
1703			Transport: transport,
1704			CheckRedirect: func(*http.Request, []*http.Request) error {
1705				// This can of course be overridden per-test by using its own client
1706				return fmt.Errorf("redirects not allowed in these tests")
1707			},
1708		}
1709		config := api.DefaultConfig()
1710		if config.Error != nil {
1711			t.Fatal(config.Error)
1712		}
1713		config.Address = fmt.Sprintf("https://127.0.0.1:%d", port)
1714		config.HttpClient = client
1715		config.MaxRetries = 0
1716		apiClient, err := api.NewClient(config)
1717		if err != nil {
1718			t.Fatal(err)
1719		}
1720		if opts == nil || !opts.SkipInit {
1721			apiClient.SetToken(testCluster.RootToken)
1722		}
1723		return apiClient
1724	}
1725
1726	var ret []*TestClusterCore
1727	for i := 0; i < numCores; i++ {
1728		tcc := &TestClusterCore{
1729			Core:                 cores[i],
1730			CoreConfig:           coreConfigs[i],
1731			ServerKey:            certInfoSlice[i].key,
1732			ServerKeyPEM:         certInfoSlice[i].keyPEM,
1733			ServerCert:           certInfoSlice[i].cert,
1734			ServerCertBytes:      certInfoSlice[i].certBytes,
1735			ServerCertPEM:        certInfoSlice[i].certPEM,
1736			Listeners:            listeners[i],
1737			Handler:              handlers[i],
1738			Server:               servers[i],
1739			TLSConfig:            tlsConfigs[i],
1740			Client:               getAPIClient(listeners[i][0].Address.Port, tlsConfigs[i]),
1741			Barrier:              cores[i].barrier,
1742			NodeID:               fmt.Sprintf("core-%d", i),
1743			UnderlyingRawStorage: coreConfigs[i].Physical,
1744		}
1745		tcc.ReloadFuncs = &cores[i].reloadFuncs
1746		tcc.ReloadFuncsLock = &cores[i].reloadFuncsLock
1747		tcc.ReloadFuncsLock.Lock()
1748		(*tcc.ReloadFuncs)["listener|tcp"] = []reload.ReloadFunc{certGetters[i].Reload}
1749		tcc.ReloadFuncsLock.Unlock()
1750
1751		testAdjustTestCore(base, tcc)
1752
1753		ret = append(ret, tcc)
1754	}
1755
1756	testCluster.Cores = ret
1757
1758	testExtraClusterCoresTestSetup(t, priKey, testCluster.Cores)
1759
1760	testCluster.CleanupFunc = func() {
1761		for _, c := range cleanupFuncs {
1762			c()
1763		}
1764	}
1765	if opts != nil {
1766		if opts.SetupFunc != nil {
1767			testCluster.SetupFunc = func() {
1768				opts.SetupFunc(t, &testCluster)
1769			}
1770		}
1771	}
1772
1773	return &testCluster
1774}
1775
1776func NewMockBuiltinRegistry() *mockBuiltinRegistry {
1777	return &mockBuiltinRegistry{
1778		forTesting: map[string]consts.PluginType{
1779			"mysql-database-plugin":      consts.PluginTypeDatabase,
1780			"postgresql-database-plugin": consts.PluginTypeDatabase,
1781		},
1782	}
1783}
1784
// mockBuiltinRegistry is a builtin plugin registry stand-in for tests.
type mockBuiltinRegistry struct {
	// forTesting maps the plugin names that Get can resolve to the plugin
	// type each one must be requested with.
	forTesting map[string]consts.PluginType
}
1788
1789func (m *mockBuiltinRegistry) Get(name string, pluginType consts.PluginType) (func() (interface{}, error), bool) {
1790	testPluginType, ok := m.forTesting[name]
1791	if !ok {
1792		return nil, false
1793	}
1794	if pluginType != testPluginType {
1795		return nil, false
1796	}
1797	if name == "postgresql-database-plugin" {
1798		return dbPostgres.New, true
1799	}
1800	return dbMysql.New(dbMysql.MetadataLen, dbMysql.MetadataLen, dbMysql.UsernameLen), true
1801}
1802
1803// Keys only supports getting a realistic list of the keys for database plugins.
1804func (m *mockBuiltinRegistry) Keys(pluginType consts.PluginType) []string {
1805	if pluginType != consts.PluginTypeDatabase {
1806		return []string{}
1807	}
1808	/*
1809		This is a hard-coded reproduction of the db plugin keys in helper/builtinplugins/registry.go.
1810		The registry isn't directly used because it causes import cycles.
1811	*/
1812	return []string{
1813		"mysql-database-plugin",
1814		"mysql-aurora-database-plugin",
1815		"mysql-rds-database-plugin",
1816		"mysql-legacy-database-plugin",
1817		"postgresql-database-plugin",
1818		"elasticsearch-database-plugin",
1819		"mssql-database-plugin",
1820		"cassandra-database-plugin",
1821		"mongodb-database-plugin",
1822		"hana-database-plugin",
1823		"influxdb-database-plugin",
1824	}
1825}
1826
// Contains always reports false, regardless of name and pluginType; it does
// not consult m.forTesting.
func (m *mockBuiltinRegistry) Contains(name string, pluginType consts.PluginType) bool {
	return false
}
1830
// NoopAudit is an audit backend for tests that records what it is asked to
// log in its exported fields instead of writing anything out.
type NoopAudit struct {
	// Config supplies the SaltView and SaltConfig used by Salt.
	Config *audit.BackendConfig
	// ReqErr is the value LogRequest returns.
	ReqErr error
	// ReqAuth, Req, ReqHeaders and ReqErrs each gain one entry per LogRequest
	// call. ReqNonHMACKeys is overwritten on every call.
	ReqAuth        []*logical.Auth
	Req            []*logical.Request
	ReqHeaders     []map[string][]string
	ReqNonHMACKeys []string
	ReqErrs        []error

	// RespErr is the value LogResponse returns.
	RespErr error
	// RespAuth, RespReq, Resp and RespErrs each gain one entry per
	// LogResponse call. RespNonHMACKeys and RespReqNonHMACKeys are overwritten
	// only when the logged input carries a non-nil Response.
	RespAuth           []*logical.Auth
	RespReq            []*logical.Request
	Resp               []*logical.Response
	RespNonHMACKeys    []string
	RespReqNonHMACKeys []string
	RespErrs           []error

	// salt is created on first use by Salt and cleared by Invalidate;
	// saltMutex guards it.
	salt      *salt.Salt
	saltMutex sync.RWMutex
}
1851
1852func (n *NoopAudit) LogRequest(ctx context.Context, in *logical.LogInput) error {
1853	n.ReqAuth = append(n.ReqAuth, in.Auth)
1854	n.Req = append(n.Req, in.Request)
1855	n.ReqHeaders = append(n.ReqHeaders, in.Request.Headers)
1856	n.ReqNonHMACKeys = in.NonHMACReqDataKeys
1857	n.ReqErrs = append(n.ReqErrs, in.OuterErr)
1858	return n.ReqErr
1859}
1860
1861func (n *NoopAudit) LogResponse(ctx context.Context, in *logical.LogInput) error {
1862	n.RespAuth = append(n.RespAuth, in.Auth)
1863	n.RespReq = append(n.RespReq, in.Request)
1864	n.Resp = append(n.Resp, in.Response)
1865	n.RespErrs = append(n.RespErrs, in.OuterErr)
1866
1867	if in.Response != nil {
1868		n.RespNonHMACKeys = in.NonHMACRespDataKeys
1869		n.RespReqNonHMACKeys = in.NonHMACReqDataKeys
1870	}
1871
1872	return n.RespErr
1873}
1874
1875func (n *NoopAudit) Salt(ctx context.Context) (*salt.Salt, error) {
1876	n.saltMutex.RLock()
1877	if n.salt != nil {
1878		defer n.saltMutex.RUnlock()
1879		return n.salt, nil
1880	}
1881	n.saltMutex.RUnlock()
1882	n.saltMutex.Lock()
1883	defer n.saltMutex.Unlock()
1884	if n.salt != nil {
1885		return n.salt, nil
1886	}
1887	salt, err := salt.NewSalt(ctx, n.Config.SaltView, n.Config.SaltConfig)
1888	if err != nil {
1889		return nil, err
1890	}
1891	n.salt = salt
1892	return salt, nil
1893}
1894
1895func (n *NoopAudit) GetHash(ctx context.Context, data string) (string, error) {
1896	salt, err := n.Salt(ctx)
1897	if err != nil {
1898		return "", err
1899	}
1900	return salt.GetIdentifiedHMAC(data), nil
1901}
1902
// Reload is a no-op for the noop audit backend and always returns nil.
func (n *NoopAudit) Reload(ctx context.Context) error {
	return nil
}
1906
1907func (n *NoopAudit) Invalidate(ctx context.Context) {
1908	n.saltMutex.Lock()
1909	defer n.saltMutex.Unlock()
1910	n.salt = nil
1911}
1912