1package vault
2
3import (
4	"bytes"
5	"context"
6	"crypto/ecdsa"
7	"crypto/elliptic"
8	"crypto/rand"
9	"crypto/sha256"
10	"crypto/tls"
11	"crypto/x509"
12	"crypto/x509/pkix"
13	"encoding/base64"
14	"encoding/pem"
15	"errors"
16	"fmt"
17	"io"
18	"io/ioutil"
19	"math/big"
20	mathrand "math/rand"
21	"net"
22	"net/http"
23	"os"
24	"os/exec"
25	"path/filepath"
26	"sync"
27	"sync/atomic"
28	"time"
29
30	hclog "github.com/hashicorp/go-hclog"
31	log "github.com/hashicorp/go-hclog"
32	"github.com/mitchellh/copystructure"
33
34	"golang.org/x/crypto/ed25519"
35	"golang.org/x/crypto/ssh"
36	"golang.org/x/net/http2"
37
38	cleanhttp "github.com/hashicorp/go-cleanhttp"
39	"github.com/hashicorp/vault/api"
40	"github.com/hashicorp/vault/audit"
41	"github.com/hashicorp/vault/helper/namespace"
42	"github.com/hashicorp/vault/helper/reload"
43	dbMysql "github.com/hashicorp/vault/plugins/database/mysql"
44	dbPostgres "github.com/hashicorp/vault/plugins/database/postgresql"
45	"github.com/hashicorp/vault/sdk/framework"
46	"github.com/hashicorp/vault/sdk/helper/consts"
47	"github.com/hashicorp/vault/sdk/helper/logging"
48	"github.com/hashicorp/vault/sdk/helper/salt"
49	"github.com/hashicorp/vault/sdk/logical"
50	"github.com/hashicorp/vault/sdk/physical"
51	testing "github.com/mitchellh/go-testing-interface"
52
53	physInmem "github.com/hashicorp/vault/sdk/physical/inmem"
54)
55
56// This file contains a number of methods that are useful for unit
57// tests within other packages.
58
// SSH keys used by the SSH test server in StartSSHHostTestServer.
// testSharedPublicKey is the only client key that server accepts, and
// testSharedPrivateKey is installed as the server's host key.
// Test-only material: never use these keys outside of tests.
const (
	// testSharedPublicKey is an RSA public key in authorized_keys format
	// (parsed with ssh.ParseAuthorizedKey).
	testSharedPublicKey = `
ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQC9i+hFxZHGo6KblVme4zrAcJstR6I0PTJozW286X4WyvPnkMYDQ5mnhEYC7UWCvjoTWbPEXPX7NjhRtwQTGD67bV+lrxgfyzK1JZbUXK4PwgKJvQD+XyyWYMzDgGSQY61KUSqCxymSm/9NZkPU3ElaQ9xQuTzPpztM4ROfb8f2Yv6/ZESZsTo0MTAkp8Pcy+WkioI/uJ1H7zqs0EA4OMY4aDJRu0UtP4rTVeYNEAuRXdX+eH4aW3KMvhzpFTjMbaJHJXlEeUm2SaX5TNQyTOvghCeQILfYIL/Ca2ij8iwCmulwdV6eQGfd4VDu40PvSnmfoaE38o6HaPnX0kUcnKiT
`
	// testSharedPrivateKey is a PEM-encoded ("RSA PRIVATE KEY", i.e. PKCS#1)
	// RSA private key (parsed with ssh.ParsePrivateKey).
	testSharedPrivateKey = `
-----BEGIN RSA PRIVATE KEY-----
MIIEogIBAAKCAQEAvYvoRcWRxqOim5VZnuM6wHCbLUeiND0yaM1tvOl+Fsrz55DG
A0OZp4RGAu1Fgr46E1mzxFz1+zY4UbcEExg+u21fpa8YH8sytSWW1FyuD8ICib0A
/l8slmDMw4BkkGOtSlEqgscpkpv/TWZD1NxJWkPcULk8z6c7TOETn2/H9mL+v2RE
mbE6NDEwJKfD3MvlpIqCP7idR+86rNBAODjGOGgyUbtFLT+K01XmDRALkV3V/nh+
GltyjL4c6RU4zG2iRyV5RHlJtkml+UzUMkzr4IQnkCC32CC/wmtoo/IsAprpcHVe
nkBn3eFQ7uND70p5n6GhN/KOh2j519JFHJyokwIDAQABAoIBAHX7VOvBC3kCN9/x
+aPdup84OE7Z7MvpX6w+WlUhXVugnmsAAVDczhKoUc/WktLLx2huCGhsmKvyVuH+
MioUiE+vx75gm3qGx5xbtmOfALVMRLopjCnJYf6EaFA0ZeQ+NwowNW7Lu0PHmAU8
Z3JiX8IwxTz14DU82buDyewO7v+cEr97AnERe3PUcSTDoUXNaoNxjNpEJkKREY6h
4hAY676RT/GsRcQ8tqe/rnCqPHNd7JGqL+207FK4tJw7daoBjQyijWuB7K5chSal
oPInylM6b13ASXuOAOT/2uSUBWmFVCZPDCmnZxy2SdnJGbsJAMl7Ma3MUlaGvVI+
Tfh1aQkCgYEA4JlNOabTb3z42wz6mz+Nz3JRwbawD+PJXOk5JsSnV7DtPtfgkK9y
6FTQdhnozGWShAvJvc+C4QAihs9AlHXoaBY5bEU7R/8UK/pSqwzam+MmxmhVDV7G
IMQPV0FteoXTaJSikhZ88mETTegI2mik+zleBpVxvfdhE5TR+lq8Br0CgYEA2AwJ
CUD5CYUSj09PluR0HHqamWOrJkKPFPwa+5eiTTCzfBBxImYZh7nXnWuoviXC0sg2
AuvCW+uZ48ygv/D8gcz3j1JfbErKZJuV+TotK9rRtNIF5Ub7qysP7UjyI7zCssVM
kuDd9LfRXaB/qGAHNkcDA8NxmHW3gpln4CFdSY8CgYANs4xwfercHEWaJ1qKagAe
rZyrMpffAEhicJ/Z65lB0jtG4CiE6w8ZeUMWUVJQVcnwYD+4YpZbX4S7sJ0B8Ydy
AhkSr86D/92dKTIt2STk6aCN7gNyQ1vW198PtaAWH1/cO2UHgHOy3ZUt5X/Uwxl9
cex4flln+1Viumts2GgsCQKBgCJH7psgSyPekK5auFdKEr5+Gc/jB8I/Z3K9+g4X
5nH3G1PBTCJYLw7hRzw8W/8oALzvddqKzEFHphiGXK94Lqjt/A4q1OdbCrhiE68D
My21P/dAKB1UYRSs9Y8CNyHCjuZM9jSMJ8vv6vG/SOJPsnVDWVAckAbQDvlTHC9t
O98zAoGAcbW6uFDkrv0XMCpB9Su3KaNXOR0wzag+WIFQRXCcoTvxVi9iYfUReQPi
oOyBJU/HMVvBfv4g+OVFLVgSwwm6owwsouZ0+D/LasbuHqYyqYqdyPJQYzWA2Y+F
+B6f4RoPdSXj24JHPg/ioRxjaj094UXJxua2yfkcecGNEuBQHSs=
-----END RSA PRIVATE KEY-----
`
)
93
94// TestCore returns a pure in-memory, uninitialized core for testing.
95func TestCore(t testing.T) *Core {
96	return TestCoreWithSeal(t, nil, false)
97}
98
99// TestCoreRaw returns a pure in-memory, uninitialized core for testing. The raw
100// storage endpoints are enabled with this core.
101func TestCoreRaw(t testing.T) *Core {
102	return TestCoreWithSeal(t, nil, true)
103}
104
105// TestCoreNewSeal returns a pure in-memory, uninitialized core with
106// the new seal configuration.
107func TestCoreNewSeal(t testing.T) *Core {
108	seal := NewTestSeal(t, nil)
109	return TestCoreWithSeal(t, seal, false)
110}
111
112// TestCoreWithConfig returns a pure in-memory, uninitialized core with the
113// specified core configurations overridden for testing.
114func TestCoreWithConfig(t testing.T, conf *CoreConfig) *Core {
115	return TestCoreWithSealAndUI(t, conf)
116}
117
118// TestCoreWithSeal returns a pure in-memory, uninitialized core with the
119// specified seal for testing.
120func TestCoreWithSeal(t testing.T, testSeal Seal, enableRaw bool) *Core {
121	conf := &CoreConfig{
122		Seal:            testSeal,
123		EnableUI:        false,
124		EnableRaw:       enableRaw,
125		BuiltinRegistry: NewMockBuiltinRegistry(),
126	}
127	return TestCoreWithSealAndUI(t, conf)
128}
129
130func TestCoreUI(t testing.T, enableUI bool) *Core {
131	conf := &CoreConfig{
132		EnableUI:        enableUI,
133		EnableRaw:       true,
134		BuiltinRegistry: NewMockBuiltinRegistry(),
135	}
136	return TestCoreWithSealAndUI(t, conf)
137}
138
139func TestCoreWithSealAndUI(t testing.T, opts *CoreConfig) *Core {
140	logger := logging.NewVaultLogger(log.Trace)
141	physicalBackend, err := physInmem.NewInmem(nil, logger)
142	if err != nil {
143		t.Fatal(err)
144	}
145
146	// Start off with base test core config
147	conf := testCoreConfig(t, physicalBackend, logger)
148
149	// Override config values with ones that gets passed in
150	conf.EnableUI = opts.EnableUI
151	conf.EnableRaw = opts.EnableRaw
152	conf.Seal = opts.Seal
153	conf.LicensingConfig = opts.LicensingConfig
154	conf.DisableKeyEncodingChecks = opts.DisableKeyEncodingChecks
155
156	if opts.Logger != nil {
157		conf.Logger = opts.Logger
158	}
159
160	for k, v := range opts.LogicalBackends {
161		conf.LogicalBackends[k] = v
162	}
163	for k, v := range opts.CredentialBackends {
164		conf.CredentialBackends[k] = v
165	}
166
167	for k, v := range opts.AuditBackends {
168		conf.AuditBackends[k] = v
169	}
170
171	c, err := NewCore(conf)
172	if err != nil {
173		t.Fatalf("err: %s", err)
174	}
175
176	return c
177}
178
179func testCoreConfig(t testing.T, physicalBackend physical.Backend, logger log.Logger) *CoreConfig {
180	t.Helper()
181	noopAudits := map[string]audit.Factory{
182		"noop": func(_ context.Context, config *audit.BackendConfig) (audit.Backend, error) {
183			view := &logical.InmemStorage{}
184			view.Put(context.Background(), &logical.StorageEntry{
185				Key:   "salt",
186				Value: []byte("foo"),
187			})
188			config.SaltConfig = &salt.Config{
189				HMAC:     sha256.New,
190				HMACType: "hmac-sha256",
191			}
192			config.SaltView = view
193
194			n := &noopAudit{
195				Config: config,
196			}
197			n.formatter.AuditFormatWriter = &audit.JSONFormatWriter{
198				SaltFunc: n.Salt,
199			}
200			return n, nil
201		},
202	}
203
204	noopBackends := make(map[string]logical.Factory)
205	noopBackends["noop"] = func(ctx context.Context, config *logical.BackendConfig) (logical.Backend, error) {
206		b := new(framework.Backend)
207		b.Setup(ctx, config)
208		b.BackendType = logical.TypeCredential
209		return b, nil
210	}
211	noopBackends["http"] = func(ctx context.Context, config *logical.BackendConfig) (logical.Backend, error) {
212		return new(rawHTTP), nil
213	}
214
215	credentialBackends := make(map[string]logical.Factory)
216	for backendName, backendFactory := range noopBackends {
217		credentialBackends[backendName] = backendFactory
218	}
219	for backendName, backendFactory := range testCredentialBackends {
220		credentialBackends[backendName] = backendFactory
221	}
222
223	logicalBackends := make(map[string]logical.Factory)
224	for backendName, backendFactory := range noopBackends {
225		logicalBackends[backendName] = backendFactory
226	}
227
228	logicalBackends["kv"] = LeasedPassthroughBackendFactory
229	for backendName, backendFactory := range testLogicalBackends {
230		logicalBackends[backendName] = backendFactory
231	}
232
233	conf := &CoreConfig{
234		Physical:           physicalBackend,
235		AuditBackends:      noopAudits,
236		LogicalBackends:    logicalBackends,
237		CredentialBackends: credentialBackends,
238		DisableMlock:       true,
239		Logger:             logger,
240		BuiltinRegistry:    NewMockBuiltinRegistry(),
241	}
242
243	return conf
244}
245
246// TestCoreInit initializes the core with a single key, and returns
247// the key that must be used to unseal the core and a root token.
248func TestCoreInit(t testing.T, core *Core) ([][]byte, string) {
249	t.Helper()
250	secretShares, _, root := TestCoreInitClusterWrapperSetup(t, core, nil)
251	return secretShares, root
252}
253
254func TestCoreInitClusterWrapperSetup(t testing.T, core *Core, handler http.Handler) ([][]byte, [][]byte, string) {
255	t.Helper()
256	core.SetClusterHandler(handler)
257
258	barrierConfig := &SealConfig{
259		SecretShares:    3,
260		SecretThreshold: 3,
261	}
262
263	// If we support storing barrier keys, then set that to equal the min threshold to unseal
264	if core.seal.StoredKeysSupported() {
265		barrierConfig.StoredShares = barrierConfig.SecretThreshold
266	}
267
268	recoveryConfig := &SealConfig{
269		SecretShares:    3,
270		SecretThreshold: 3,
271	}
272
273	result, err := core.Initialize(context.Background(), &InitParams{
274		BarrierConfig:  barrierConfig,
275		RecoveryConfig: recoveryConfig,
276	})
277	if err != nil {
278		t.Fatalf("err: %s", err)
279	}
280	return result.SecretShares, result.RecoveryShares, result.RootToken
281}
282
283func TestCoreUnseal(core *Core, key []byte) (bool, error) {
284	return core.Unseal(key)
285}
286
287func TestCoreUnsealWithRecoveryKeys(core *Core, key []byte) (bool, error) {
288	return core.UnsealWithRecoveryKeys(key)
289}
290
291// TestCoreUnsealed returns a pure in-memory core that is already
292// initialized and unsealed.
293func TestCoreUnsealed(t testing.T) (*Core, [][]byte, string) {
294	t.Helper()
295	core := TestCore(t)
296	return testCoreUnsealed(t, core)
297}
298
299// TestCoreUnsealedRaw returns a pure in-memory core that is already
300// initialized, unsealed, and with raw endpoints enabled.
301func TestCoreUnsealedRaw(t testing.T) (*Core, [][]byte, string) {
302	t.Helper()
303	core := TestCoreRaw(t)
304	return testCoreUnsealed(t, core)
305}
306
307// TestCoreUnsealedWithConfig returns a pure in-memory core that is already
308// initialized, unsealed, with the any provided core config values overridden.
309func TestCoreUnsealedWithConfig(t testing.T, conf *CoreConfig) (*Core, [][]byte, string) {
310	t.Helper()
311	core := TestCoreWithConfig(t, conf)
312	return testCoreUnsealed(t, core)
313}
314
315func testCoreUnsealed(t testing.T, core *Core) (*Core, [][]byte, string) {
316	t.Helper()
317	keys, token := TestCoreInit(t, core)
318	for _, key := range keys {
319		if _, err := TestCoreUnseal(core, TestKeyCopy(key)); err != nil {
320			t.Fatalf("unseal err: %s", err)
321		}
322	}
323
324	if core.Sealed() {
325		t.Fatal("should not be sealed")
326	}
327
328	testCoreAddSecretMount(t, core, token)
329
330	return core, keys, token
331}
332
333func testCoreAddSecretMount(t testing.T, core *Core, token string) {
334	kvReq := &logical.Request{
335		Operation:   logical.UpdateOperation,
336		ClientToken: token,
337		Path:        "sys/mounts/secret",
338		Data: map[string]interface{}{
339			"type":        "kv",
340			"path":        "secret/",
341			"description": "key/value secret storage",
342			"options": map[string]string{
343				"version": "1",
344			},
345		},
346	}
347	resp, err := core.HandleRequest(namespace.RootContext(nil), kvReq)
348	if err != nil {
349		t.Fatal(err)
350	}
351	if resp.IsError() {
352		t.Fatal(err)
353	}
354
355}
356
357func TestCoreUnsealedBackend(t testing.T, backend physical.Backend) (*Core, [][]byte, string) {
358	t.Helper()
359	logger := logging.NewVaultLogger(log.Trace)
360	conf := testCoreConfig(t, backend, logger)
361	conf.Seal = NewTestSeal(t, nil)
362
363	core, err := NewCore(conf)
364	if err != nil {
365		t.Fatalf("err: %s", err)
366	}
367
368	keys, token := TestCoreInit(t, core)
369	for _, key := range keys {
370		if _, err := TestCoreUnseal(core, TestKeyCopy(key)); err != nil {
371			t.Fatalf("unseal err: %s", err)
372		}
373	}
374
375	if err := core.UnsealWithStoredKeys(context.Background()); err != nil {
376		t.Fatal(err)
377	}
378
379	if core.Sealed() {
380		t.Fatal("should not be sealed")
381	}
382
383	return core, keys, token
384}
385
// TestKeyCopy returns an independent copy of key. Callers can then hand the
// copy to Unseal without sharing the original slice's backing array. The
// result is never nil, even for an empty or nil key.
func TestKeyCopy(key []byte) []byte {
	return append(make([]byte, 0, len(key)), key...)
}
393
394func TestDynamicSystemView(c *Core) *dynamicSystemView {
395	me := &MountEntry{
396		Config: MountConfig{
397			DefaultLeaseTTL: 24 * time.Hour,
398			MaxLeaseTTL:     2 * 24 * time.Hour,
399		},
400	}
401
402	return &dynamicSystemView{c, me}
403}
404
405// TestAddTestPlugin registers the testFunc as part of the plugin command to the
406// plugin catalog. If provided, uses tmpDir as the plugin directory.
407func TestAddTestPlugin(t testing.T, c *Core, name string, pluginType consts.PluginType, testFunc string, env []string, tempDir string) {
408	file, err := os.Open(os.Args[0])
409	if err != nil {
410		t.Fatal(err)
411	}
412	defer file.Close()
413
414	dirPath := filepath.Dir(os.Args[0])
415	fileName := filepath.Base(os.Args[0])
416
417	if tempDir != "" {
418		fi, err := file.Stat()
419		if err != nil {
420			t.Fatal(err)
421		}
422
423		// Copy over the file to the temp dir
424		dst := filepath.Join(tempDir, fileName)
425		out, err := os.OpenFile(dst, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, fi.Mode())
426		if err != nil {
427			t.Fatal(err)
428		}
429		defer out.Close()
430
431		if _, err = io.Copy(out, file); err != nil {
432			t.Fatal(err)
433		}
434		err = out.Sync()
435		if err != nil {
436			t.Fatal(err)
437		}
438
439		dirPath = tempDir
440	}
441
442	// Determine plugin directory full path, evaluating potential symlink path
443	fullPath, err := filepath.EvalSymlinks(dirPath)
444	if err != nil {
445		t.Fatal(err)
446	}
447
448	reader, err := os.Open(filepath.Join(fullPath, fileName))
449	if err != nil {
450		t.Fatal(err)
451	}
452	defer reader.Close()
453
454	// Find out the sha256
455	hash := sha256.New()
456
457	_, err = io.Copy(hash, reader)
458	if err != nil {
459		t.Fatal(err)
460	}
461
462	sum := hash.Sum(nil)
463
464	// Set core's plugin directory and plugin catalog directory
465	c.pluginDirectory = fullPath
466	c.pluginCatalog.directory = fullPath
467
468	args := []string{fmt.Sprintf("--test.run=%s", testFunc)}
469	err = c.pluginCatalog.Set(context.Background(), name, pluginType, fileName, args, env, sum)
470	if err != nil {
471		t.Fatal(err)
472	}
473}
474
475var testLogicalBackends = map[string]logical.Factory{}
476var testCredentialBackends = map[string]logical.Factory{}
477
478// StartSSHHostTestServer starts the test server which responds to SSH
479// authentication. Used to test the SSH secret backend.
480func StartSSHHostTestServer() (string, error) {
481	pubKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(testSharedPublicKey))
482	if err != nil {
483		return "", fmt.Errorf("error parsing public key")
484	}
485	serverConfig := &ssh.ServerConfig{
486		PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
487			if bytes.Compare(pubKey.Marshal(), key.Marshal()) == 0 {
488				return &ssh.Permissions{}, nil
489			} else {
490				return nil, fmt.Errorf("key does not match")
491			}
492		},
493	}
494	signer, err := ssh.ParsePrivateKey([]byte(testSharedPrivateKey))
495	if err != nil {
496		panic("Error parsing private key")
497	}
498	serverConfig.AddHostKey(signer)
499
500	soc, err := net.Listen("tcp", "127.0.0.1:0")
501	if err != nil {
502		return "", fmt.Errorf("error listening to connection")
503	}
504
505	go func() {
506		for {
507			conn, err := soc.Accept()
508			if err != nil {
509				panic(fmt.Sprintf("Error accepting incoming connection: %s", err))
510			}
511			defer conn.Close()
512			sshConn, chanReqs, _, err := ssh.NewServerConn(conn, serverConfig)
513			if err != nil {
514				panic(fmt.Sprintf("Handshaking error: %v", err))
515			}
516
517			go func() {
518				for chanReq := range chanReqs {
519					go func(chanReq ssh.NewChannel) {
520						if chanReq.ChannelType() != "session" {
521							chanReq.Reject(ssh.UnknownChannelType, "unknown channel type")
522							return
523						}
524
525						ch, requests, err := chanReq.Accept()
526						if err != nil {
527							panic(fmt.Sprintf("Error accepting channel: %s", err))
528						}
529
530						go func(ch ssh.Channel, in <-chan *ssh.Request) {
531							for req := range in {
532								executeServerCommand(ch, req)
533							}
534						}(ch, requests)
535					}(chanReq)
536				}
537				sshConn.Close()
538			}()
539		}
540	}()
541	return soc.Addr().String(), nil
542}
543
544// This executes the commands requested to be run on the server.
545// Used to test the SSH secret backend.
546func executeServerCommand(ch ssh.Channel, req *ssh.Request) {
547	command := string(req.Payload[4:])
548	cmd := exec.Command("/bin/bash", []string{"-c", command}...)
549	req.Reply(true, nil)
550
551	cmd.Stdout = ch
552	cmd.Stderr = ch
553	cmd.Stdin = ch
554
555	err := cmd.Start()
556	if err != nil {
557		panic(fmt.Sprintf("Error starting the command: '%s'", err))
558	}
559
560	go func() {
561		_, err := cmd.Process.Wait()
562		if err != nil {
563			panic(fmt.Sprintf("Error while waiting for command to finish:'%s'", err))
564		}
565		ch.Close()
566	}()
567}
568
569// This adds a credential backend for the test core. This needs to be
570// invoked before the test core is created.
571func AddTestCredentialBackend(name string, factory logical.Factory) error {
572	if name == "" {
573		return fmt.Errorf("missing backend name")
574	}
575	if factory == nil {
576		return fmt.Errorf("missing backend factory function")
577	}
578	testCredentialBackends[name] = factory
579	return nil
580}
581
582// This adds a logical backend for the test core. This needs to be
583// invoked before the test core is created.
584func AddTestLogicalBackend(name string, factory logical.Factory) error {
585	if name == "" {
586		return fmt.Errorf("missing backend name")
587	}
588	if factory == nil {
589		return fmt.Errorf("missing backend factory function")
590	}
591	testLogicalBackends[name] = factory
592	return nil
593}
594
595type noopAudit struct {
596	Config    *audit.BackendConfig
597	salt      *salt.Salt
598	saltMutex sync.RWMutex
599	formatter audit.AuditFormatter
600	records   [][]byte
601	l         sync.RWMutex
602}
603
604func (n *noopAudit) GetHash(ctx context.Context, data string) (string, error) {
605	salt, err := n.Salt(ctx)
606	if err != nil {
607		return "", err
608	}
609	return salt.GetIdentifiedHMAC(data), nil
610}
611
612func (n *noopAudit) LogRequest(ctx context.Context, in *logical.LogInput) error {
613	n.l.Lock()
614	defer n.l.Unlock()
615	var w bytes.Buffer
616	err := n.formatter.FormatRequest(ctx, &w, audit.FormatterConfig{}, in)
617	if err != nil {
618		return err
619	}
620	n.records = append(n.records, w.Bytes())
621	return nil
622}
623
624func (n *noopAudit) LogResponse(ctx context.Context, in *logical.LogInput) error {
625	n.l.Lock()
626	defer n.l.Unlock()
627	var w bytes.Buffer
628	err := n.formatter.FormatResponse(ctx, &w, audit.FormatterConfig{}, in)
629	if err != nil {
630		return err
631	}
632	n.records = append(n.records, w.Bytes())
633	return nil
634}
635
636func (n *noopAudit) Reload(_ context.Context) error {
637	return nil
638}
639
640func (n *noopAudit) Invalidate(_ context.Context) {
641	n.saltMutex.Lock()
642	defer n.saltMutex.Unlock()
643	n.salt = nil
644}
645
646func (n *noopAudit) Salt(ctx context.Context) (*salt.Salt, error) {
647	n.saltMutex.RLock()
648	if n.salt != nil {
649		defer n.saltMutex.RUnlock()
650		return n.salt, nil
651	}
652	n.saltMutex.RUnlock()
653	n.saltMutex.Lock()
654	defer n.saltMutex.Unlock()
655	if n.salt != nil {
656		return n.salt, nil
657	}
658	salt, err := salt.NewSalt(ctx, n.Config.SaltView, n.Config.SaltConfig)
659	if err != nil {
660		return nil, err
661	}
662	n.salt = salt
663	return salt, nil
664}
665
666func AddNoopAudit(conf *CoreConfig) {
667	conf.AuditBackends = map[string]audit.Factory{
668		"noop": func(_ context.Context, config *audit.BackendConfig) (audit.Backend, error) {
669			view := &logical.InmemStorage{}
670			view.Put(context.Background(), &logical.StorageEntry{
671				Key:   "salt",
672				Value: []byte("foo"),
673			})
674			n := &noopAudit{
675				Config: config,
676			}
677			n.formatter.AuditFormatWriter = &audit.JSONFormatWriter{
678				SaltFunc: n.Salt,
679			}
680			return n, nil
681		},
682	}
683}
684
685type rawHTTP struct{}
686
687func (n *rawHTTP) HandleRequest(ctx context.Context, req *logical.Request) (*logical.Response, error) {
688	return &logical.Response{
689		Data: map[string]interface{}{
690			logical.HTTPStatusCode:  200,
691			logical.HTTPContentType: "plain/text",
692			logical.HTTPRawBody:     []byte("hello world"),
693		},
694	}, nil
695}
696
697func (n *rawHTTP) HandleExistenceCheck(ctx context.Context, req *logical.Request) (bool, bool, error) {
698	return false, false, nil
699}
700
701func (n *rawHTTP) SpecialPaths() *logical.Paths {
702	return &logical.Paths{Unauthenticated: []string{"*"}}
703}
704
705func (n *rawHTTP) System() logical.SystemView {
706	return logical.StaticSystemView{
707		DefaultLeaseTTLVal: time.Hour * 24,
708		MaxLeaseTTLVal:     time.Hour * 24 * 32,
709	}
710}
711
712func (n *rawHTTP) Logger() log.Logger {
713	return logging.NewVaultLogger(log.Trace)
714}
715
716func (n *rawHTTP) Cleanup(ctx context.Context) {
717	// noop
718}
719
720func (n *rawHTTP) Initialize(ctx context.Context, req *logical.InitializationRequest) error {
721	return nil
722}
723
724func (n *rawHTTP) InvalidateKey(context.Context, string) {
725	// noop
726}
727
728func (n *rawHTTP) Setup(ctx context.Context, config *logical.BackendConfig) error {
729	// noop
730	return nil
731}
732
733func (n *rawHTTP) Type() logical.BackendType {
734	return logical.TypeLogical
735}
736
// GenerateRandBytes returns length bytes read from crypto/rand. A negative
// length is an error. A zero length yields an empty, non-nil slice without
// touching the random source.
func GenerateRandBytes(length int) ([]byte, error) {
	switch {
	case length < 0:
		return nil, fmt.Errorf("length must be >= 0")
	case length == 0:
		return []byte{}, nil
	}

	out := make([]byte, length)
	read, err := rand.Read(out)
	if err != nil {
		return nil, err
	}
	if read != length {
		return nil, fmt.Errorf("unable to read %d bytes; only read %d", length, read)
	}
	return out, nil
}
757
758func TestWaitActive(t testing.T, core *Core) {
759	t.Helper()
760	if err := TestWaitActiveWithError(core); err != nil {
761		t.Fatal(err)
762	}
763}
764
765func TestWaitActiveWithError(core *Core) error {
766	start := time.Now()
767	var standby bool
768	var err error
769	for time.Now().Sub(start) < 30*time.Second {
770		standby, err = core.Standby()
771		if err != nil {
772			return err
773		}
774		if !standby {
775			break
776		}
777	}
778	if standby {
779		return errors.New("should not be in standby mode")
780	}
781	return nil
782}
783
// TestCluster is a group of test cores that share a CA (generated by
// NewTestCluster or supplied through TestClusterOptions), together with the
// key material and token needed to unseal and use them.
type TestCluster struct {
	// BarrierKeys are the key shares used by UnsealCores / UnsealCore.
	BarrierKeys        [][]byte
	RecoveryKeys       [][]byte
	CACert             *x509.Certificate
	CACertBytes        []byte
	CACertPEM          []byte
	CACertPEMFile      string
	CAKey              *ecdsa.PrivateKey
	CAKeyPEM           []byte
	Cores              []*TestClusterCore
	ID                 string
	RootToken          string
	// RootCAs is a pool containing CACert.
	RootCAs            *x509.CertPool
	// TempDir is removed by Cleanup when non-empty.
	TempDir            string
	ClientAuthRequired bool
	// Logger is the cluster-level logger (used e.g. by Cleanup).
	Logger             log.Logger
}
801
802func (c *TestCluster) Start() {
803	for _, core := range c.Cores {
804		if core.Server != nil {
805			for _, ln := range core.Listeners {
806				go core.Server.Serve(ln)
807			}
808		}
809	}
810}
811
812// UnsealCores uses the cluster barrier keys to unseal the test cluster cores
813func (c *TestCluster) UnsealCores(t testing.T) {
814	if err := c.UnsealCoresWithError(); err != nil {
815		t.Fatal(err)
816	}
817}
818
819func (c *TestCluster) UnsealCoresWithError() error {
820	numCores := len(c.Cores)
821
822	// Unseal first core
823	for _, key := range c.BarrierKeys {
824		if _, err := c.Cores[0].Unseal(TestKeyCopy(key)); err != nil {
825			return fmt.Errorf("unseal err: %s", err)
826		}
827	}
828
829	// Verify unsealed
830	if c.Cores[0].Sealed() {
831		return fmt.Errorf("should not be sealed")
832	}
833
834	if err := TestWaitActiveWithError(c.Cores[0].Core); err != nil {
835		return err
836	}
837
838	// Unseal other cores
839	for i := 1; i < numCores; i++ {
840		for _, key := range c.BarrierKeys {
841			if _, err := c.Cores[i].Core.Unseal(TestKeyCopy(key)); err != nil {
842				return fmt.Errorf("unseal err: %s", err)
843			}
844		}
845	}
846
847	// Let them come fully up to standby
848	time.Sleep(2 * time.Second)
849
850	// Ensure cluster connection info is populated.
851	// Other cores should not come up as leaders.
852	for i := 1; i < numCores; i++ {
853		isLeader, _, _, err := c.Cores[i].Leader()
854		if err != nil {
855			return err
856		}
857		if isLeader {
858			return fmt.Errorf("core[%d] should not be leader", i)
859		}
860	}
861
862	return nil
863}
864
865func (c *TestCluster) UnsealCore(t testing.T, core *TestClusterCore) {
866	for _, key := range c.BarrierKeys {
867		if _, err := core.Core.Unseal(TestKeyCopy(key)); err != nil {
868			t.Fatalf("unseal err: %s", err)
869		}
870	}
871}
872
873func (c *TestCluster) EnsureCoresSealed(t testing.T) {
874	t.Helper()
875	if err := c.ensureCoresSealed(); err != nil {
876		t.Fatal(err)
877	}
878}
879
880func (c *TestClusterCore) Seal(t testing.T) {
881	t.Helper()
882	if err := c.Core.sealInternal(); err != nil {
883		t.Fatal(err)
884	}
885}
886
887func CleanupClusters(clusters []*TestCluster) {
888	wg := &sync.WaitGroup{}
889	for _, cluster := range clusters {
890		wg.Add(1)
891		lc := cluster
892		go func() {
893			defer wg.Done()
894			lc.Cleanup()
895		}()
896	}
897	wg.Wait()
898}
899
900func (c *TestCluster) Cleanup() {
901	c.Logger.Info("cleaning up vault cluster")
902	for _, core := range c.Cores {
903		core.CoreConfig.Logger.SetLevel(log.Error)
904	}
905
906	// Close listeners
907	wg := &sync.WaitGroup{}
908	for _, core := range c.Cores {
909		wg.Add(1)
910		lc := core
911
912		go func() {
913			defer wg.Done()
914			if lc.Listeners != nil {
915				for _, ln := range lc.Listeners {
916					ln.Close()
917				}
918			}
919			if lc.licensingStopCh != nil {
920				close(lc.licensingStopCh)
921				lc.licensingStopCh = nil
922			}
923
924			if err := lc.Shutdown(); err != nil {
925				lc.Logger().Error("error during shutdown; abandoning sealing", "error", err)
926			} else {
927				timeout := time.Now().Add(60 * time.Second)
928				for {
929					if time.Now().After(timeout) {
930						lc.Logger().Error("timeout waiting for core to seal")
931					}
932					if lc.Sealed() {
933						break
934					}
935					time.Sleep(250 * time.Millisecond)
936				}
937			}
938		}()
939	}
940
941	wg.Wait()
942
943	// Remove any temp dir that exists
944	if c.TempDir != "" {
945		os.RemoveAll(c.TempDir)
946	}
947
948	// Give time to actually shut down/clean up before the next test
949	time.Sleep(time.Second)
950}
951
952func (c *TestCluster) ensureCoresSealed() error {
953	for _, core := range c.Cores {
954		if err := core.Shutdown(); err != nil {
955			return err
956		}
957		timeout := time.Now().Add(60 * time.Second)
958		for {
959			if time.Now().After(timeout) {
960				return fmt.Errorf("timeout waiting for core to seal")
961			}
962			if core.Sealed() {
963				break
964			}
965			time.Sleep(250 * time.Millisecond)
966		}
967	}
968	return nil
969}
970
971// UnsealWithStoredKeys uses stored keys to unseal the test cluster cores
972func (c *TestCluster) UnsealWithStoredKeys(t testing.T) error {
973	for _, core := range c.Cores {
974		if err := core.UnsealWithStoredKeys(context.Background()); err != nil {
975			return err
976		}
977		timeout := time.Now().Add(60 * time.Second)
978		for {
979			if time.Now().After(timeout) {
980				return fmt.Errorf("timeout waiting for core to unseal")
981			}
982			if !core.Sealed() {
983				break
984			}
985			time.Sleep(250 * time.Millisecond)
986		}
987	}
988	return nil
989}
990
991func SetReplicationFailureMode(core *TestClusterCore, mode uint32) {
992	atomic.StoreUint32(core.Core.replicationFailure, mode)
993}
994
// TestListener wraps a net.Listener used by a test cluster core together with
// a TCP address (presumably the one it listens on; confirm at construction).
type TestListener struct {
	net.Listener
	Address *net.TCPAddr
}
999
// TestClusterCore is one member of a TestCluster. It embeds the Core and
// carries the configuration, client, HTTP serving pieces and TLS material
// associated with it.
type TestClusterCore struct {
	*Core
	// CoreConfig is the config this core was built from; Cleanup lowers its
	// Logger's level.
	CoreConfig           *CoreConfig
	Client               *api.Client
	Handler              http.Handler
	// Listeners are served by TestCluster.Start and closed by Cleanup.
	Listeners            []*TestListener
	ReloadFuncs          *map[string][]reload.ReloadFunc
	ReloadFuncsLock      *sync.RWMutex
	// Server, when non-nil, is started on Listeners by TestCluster.Start.
	Server               *http.Server
	ServerCert           *x509.Certificate
	ServerCertBytes      []byte
	ServerCertPEM        []byte
	ServerKey            *ecdsa.PrivateKey
	ServerKeyPEM         []byte
	TLSConfig            *tls.Config
	UnderlyingStorage    physical.Backend
	UnderlyingRawStorage physical.Backend
	Barrier              SecurityBarrier
	NodeID               string
}
1020
// TestClusterOptions customizes NewTestCluster. Zero values select defaults.
type TestClusterOptions struct {
	KeepStandbysSealed bool
	SkipInit           bool
	HandlerFunc        func(*HandlerProperties) http.Handler
	// BaseListenAddress, if set, is resolved as a TCP address and its IP is
	// added to the generated CA certificate's IP addresses.
	BaseListenAddress  string
	// NumCores is the cluster size; 0 means DefaultNumCores.
	NumCores           int
	// SealFunc, if set, generates a separate seal for each core.
	SealFunc           func() Seal
	// Logger, if set, becomes the cluster logger.
	Logger             log.Logger
	// TempDir, if set, is created when missing and used as the cluster's
	// temp dir; otherwise a fresh temp dir is created.
	TempDir            string
	// CACert / CAKey, if set, replace the generated CA certificate and key.
	CACert             []byte
	CAKey              *ecdsa.PrivateKey
	PhysicalFactory    func(hclog.Logger) (physical.Backend, error)
	FirstCoreNumber    int
	RequireClientAuth  bool
}
1036
// DefaultNumCores is the number of cores NewTestCluster creates when opts is
// nil or opts.NumCores is 0.
var DefaultNumCores = 3
1038
// certInfo bundles a certificate and its ECDSA private key with their
// encodings. certBytes presumably holds the raw DER form (confirm at the
// construction site), and certPEM/keyPEM hold PEM encodings.
type certInfo struct {
	cert      *x509.Certificate
	certPEM   []byte
	certBytes []byte
	key       *ecdsa.PrivateKey
	keyPEM    []byte
}
1046
1047// NewTestCluster creates a new test cluster based on the provided core config
1048// and test cluster options.
1049//
1050// N.B. Even though a single base CoreConfig is provided, NewTestCluster will instantiate a
1051// core config for each core it creates. If separate seal per core is desired, opts.SealFunc
1052// can be provided to generate a seal for each one. Otherwise, the provided base.Seal will be
1053// shared among cores. NewCore's default behavior is to generate a new DefaultSeal if the
1054// provided Seal in coreConfig (i.e. base.Seal) is nil.
1055//
1056// If opts.Logger is provided, it takes precedence and will be used as the cluster
1057// logger and will be the basis for each core's logger.  If no opts.Logger is
1058// given, one will be generated based on t.Name() for the cluster logger, and if
1059// no base.Logger is given will also be used as the basis for each core's logger.
1060func NewTestCluster(t testing.T, base *CoreConfig, opts *TestClusterOptions) *TestCluster {
1061	var err error
1062
1063	var numCores int
1064	if opts == nil || opts.NumCores == 0 {
1065		numCores = DefaultNumCores
1066	} else {
1067		numCores = opts.NumCores
1068	}
1069
1070	var firstCoreNumber int
1071	if opts != nil {
1072		firstCoreNumber = opts.FirstCoreNumber
1073	}
1074
1075	certIPs := []net.IP{
1076		net.IPv6loopback,
1077		net.ParseIP("127.0.0.1"),
1078	}
1079	var baseAddr *net.TCPAddr
1080	if opts != nil && opts.BaseListenAddress != "" {
1081		baseAddr, err = net.ResolveTCPAddr("tcp", opts.BaseListenAddress)
1082		if err != nil {
1083			t.Fatal("could not parse given base IP")
1084		}
1085		certIPs = append(certIPs, baseAddr.IP)
1086	}
1087
1088	var testCluster TestCluster
1089
1090	if opts != nil && opts.Logger != nil {
1091		testCluster.Logger = opts.Logger
1092	} else {
1093		testCluster.Logger = logging.NewVaultLogger(log.Trace).Named(t.Name())
1094	}
1095
1096	if opts != nil && opts.TempDir != "" {
1097		if _, err := os.Stat(opts.TempDir); os.IsNotExist(err) {
1098			if err := os.MkdirAll(opts.TempDir, 0700); err != nil {
1099				t.Fatal(err)
1100			}
1101		}
1102		testCluster.TempDir = opts.TempDir
1103	} else {
1104		tempDir, err := ioutil.TempDir("", "vault-test-cluster-")
1105		if err != nil {
1106			t.Fatal(err)
1107		}
1108		testCluster.TempDir = tempDir
1109	}
1110
1111	var caKey *ecdsa.PrivateKey
1112	if opts != nil && opts.CAKey != nil {
1113		caKey = opts.CAKey
1114	} else {
1115		caKey, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
1116		if err != nil {
1117			t.Fatal(err)
1118		}
1119	}
1120	testCluster.CAKey = caKey
1121	var caBytes []byte
1122	if opts != nil && len(opts.CACert) > 0 {
1123		caBytes = opts.CACert
1124	} else {
1125		caCertTemplate := &x509.Certificate{
1126			Subject: pkix.Name{
1127				CommonName: "localhost",
1128			},
1129			DNSNames:              []string{"localhost"},
1130			IPAddresses:           certIPs,
1131			KeyUsage:              x509.KeyUsage(x509.KeyUsageCertSign | x509.KeyUsageCRLSign),
1132			SerialNumber:          big.NewInt(mathrand.Int63()),
1133			NotBefore:             time.Now().Add(-30 * time.Second),
1134			NotAfter:              time.Now().Add(262980 * time.Hour),
1135			BasicConstraintsValid: true,
1136			IsCA:                  true,
1137		}
1138		caBytes, err = x509.CreateCertificate(rand.Reader, caCertTemplate, caCertTemplate, caKey.Public(), caKey)
1139		if err != nil {
1140			t.Fatal(err)
1141		}
1142	}
1143	caCert, err := x509.ParseCertificate(caBytes)
1144	if err != nil {
1145		t.Fatal(err)
1146	}
1147	testCluster.CACert = caCert
1148	testCluster.CACertBytes = caBytes
1149	testCluster.RootCAs = x509.NewCertPool()
1150	testCluster.RootCAs.AddCert(caCert)
1151	caCertPEMBlock := &pem.Block{
1152		Type:  "CERTIFICATE",
1153		Bytes: caBytes,
1154	}
1155	testCluster.CACertPEM = pem.EncodeToMemory(caCertPEMBlock)
1156	testCluster.CACertPEMFile = filepath.Join(testCluster.TempDir, "ca_cert.pem")
1157	err = ioutil.WriteFile(testCluster.CACertPEMFile, testCluster.CACertPEM, 0755)
1158	if err != nil {
1159		t.Fatal(err)
1160	}
1161	marshaledCAKey, err := x509.MarshalECPrivateKey(caKey)
1162	if err != nil {
1163		t.Fatal(err)
1164	}
1165	caKeyPEMBlock := &pem.Block{
1166		Type:  "EC PRIVATE KEY",
1167		Bytes: marshaledCAKey,
1168	}
1169	testCluster.CAKeyPEM = pem.EncodeToMemory(caKeyPEMBlock)
1170	err = ioutil.WriteFile(filepath.Join(testCluster.TempDir, "ca_key.pem"), testCluster.CAKeyPEM, 0755)
1171	if err != nil {
1172		t.Fatal(err)
1173	}
1174
1175	var certInfoSlice []*certInfo
1176
1177	//
1178	// Certs generation
1179	//
1180	for i := 0; i < numCores; i++ {
1181		key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
1182		if err != nil {
1183			t.Fatal(err)
1184		}
1185		certTemplate := &x509.Certificate{
1186			Subject: pkix.Name{
1187				CommonName: "localhost",
1188			},
1189			DNSNames:    []string{"localhost"},
1190			IPAddresses: certIPs,
1191			ExtKeyUsage: []x509.ExtKeyUsage{
1192				x509.ExtKeyUsageServerAuth,
1193				x509.ExtKeyUsageClientAuth,
1194			},
1195			KeyUsage:     x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageKeyAgreement,
1196			SerialNumber: big.NewInt(mathrand.Int63()),
1197			NotBefore:    time.Now().Add(-30 * time.Second),
1198			NotAfter:     time.Now().Add(262980 * time.Hour),
1199		}
1200		certBytes, err := x509.CreateCertificate(rand.Reader, certTemplate, caCert, key.Public(), caKey)
1201		if err != nil {
1202			t.Fatal(err)
1203		}
1204		cert, err := x509.ParseCertificate(certBytes)
1205		if err != nil {
1206			t.Fatal(err)
1207		}
1208		certPEMBlock := &pem.Block{
1209			Type:  "CERTIFICATE",
1210			Bytes: certBytes,
1211		}
1212		certPEM := pem.EncodeToMemory(certPEMBlock)
1213		marshaledKey, err := x509.MarshalECPrivateKey(key)
1214		if err != nil {
1215			t.Fatal(err)
1216		}
1217		keyPEMBlock := &pem.Block{
1218			Type:  "EC PRIVATE KEY",
1219			Bytes: marshaledKey,
1220		}
1221		keyPEM := pem.EncodeToMemory(keyPEMBlock)
1222
1223		certInfoSlice = append(certInfoSlice, &certInfo{
1224			cert:      cert,
1225			certPEM:   certPEM,
1226			certBytes: certBytes,
1227			key:       key,
1228			keyPEM:    keyPEM,
1229		})
1230	}
1231
1232	//
1233	// Listener setup
1234	//
1235	ports := make([]int, numCores)
1236	if baseAddr != nil {
1237		for i := 0; i < numCores; i++ {
1238			ports[i] = baseAddr.Port + i
1239		}
1240	} else {
1241		baseAddr = &net.TCPAddr{
1242			IP:   net.ParseIP("127.0.0.1"),
1243			Port: 0,
1244		}
1245	}
1246
1247	listeners := [][]*TestListener{}
1248	servers := []*http.Server{}
1249	handlers := []http.Handler{}
1250	tlsConfigs := []*tls.Config{}
1251	certGetters := []*reload.CertificateGetter{}
1252	for i := 0; i < numCores; i++ {
1253		baseAddr.Port = ports[i]
1254		ln, err := net.ListenTCP("tcp", baseAddr)
1255		if err != nil {
1256			t.Fatal(err)
1257		}
1258		certFile := filepath.Join(testCluster.TempDir, fmt.Sprintf("node%d_port_%d_cert.pem", i+1, ln.Addr().(*net.TCPAddr).Port))
1259		keyFile := filepath.Join(testCluster.TempDir, fmt.Sprintf("node%d_port_%d_key.pem", i+1, ln.Addr().(*net.TCPAddr).Port))
1260		err = ioutil.WriteFile(certFile, certInfoSlice[i].certPEM, 0755)
1261		if err != nil {
1262			t.Fatal(err)
1263		}
1264		err = ioutil.WriteFile(keyFile, certInfoSlice[i].keyPEM, 0755)
1265		if err != nil {
1266			t.Fatal(err)
1267		}
1268		tlsCert, err := tls.X509KeyPair(certInfoSlice[i].certPEM, certInfoSlice[i].keyPEM)
1269		if err != nil {
1270			t.Fatal(err)
1271		}
1272		certGetter := reload.NewCertificateGetter(certFile, keyFile, "")
1273		certGetters = append(certGetters, certGetter)
1274		tlsConfig := &tls.Config{
1275			Certificates:   []tls.Certificate{tlsCert},
1276			RootCAs:        testCluster.RootCAs,
1277			ClientCAs:      testCluster.RootCAs,
1278			ClientAuth:     tls.RequestClientCert,
1279			NextProtos:     []string{"h2", "http/1.1"},
1280			GetCertificate: certGetter.GetCertificate,
1281		}
1282		if opts != nil && opts.RequireClientAuth {
1283			tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
1284			testCluster.ClientAuthRequired = true
1285		}
1286		tlsConfig.BuildNameToCertificate()
1287		tlsConfigs = append(tlsConfigs, tlsConfig)
1288		lns := []*TestListener{&TestListener{
1289			Listener: tls.NewListener(ln, tlsConfig),
1290			Address:  ln.Addr().(*net.TCPAddr),
1291		},
1292		}
1293		listeners = append(listeners, lns)
1294		var handler http.Handler = http.NewServeMux()
1295		handlers = append(handlers, handler)
1296		server := &http.Server{
1297			Handler:  handler,
1298			ErrorLog: testCluster.Logger.StandardLogger(nil),
1299		}
1300		servers = append(servers, server)
1301	}
1302
1303	// Create three cores with the same physical and different redirect/cluster
1304	// addrs.
1305	// N.B.: On OSX, instead of random ports, it assigns new ports to new
1306	// listeners sequentially. Aside from being a bad idea in a security sense,
1307	// it also broke tests that assumed it was OK to just use the port above
1308	// the redirect addr. This has now been changed to 105 ports above, but if
1309	// we ever do more than three nodes in a cluster it may need to be bumped.
1310	// Note: it's 105 so that we don't conflict with a running Consul by
1311	// default.
1312	coreConfig := &CoreConfig{
1313		LogicalBackends:    make(map[string]logical.Factory),
1314		CredentialBackends: make(map[string]logical.Factory),
1315		AuditBackends:      make(map[string]audit.Factory),
1316		RedirectAddr:       fmt.Sprintf("https://127.0.0.1:%d", listeners[0][0].Address.Port),
1317		ClusterAddr:        "https://127.0.0.1:0",
1318		DisableMlock:       true,
1319		EnableUI:           true,
1320		EnableRaw:          true,
1321		BuiltinRegistry:    NewMockBuiltinRegistry(),
1322	}
1323
1324	if base != nil {
1325		coreConfig.DisableCache = base.DisableCache
1326		coreConfig.EnableUI = base.EnableUI
1327		coreConfig.DefaultLeaseTTL = base.DefaultLeaseTTL
1328		coreConfig.MaxLeaseTTL = base.MaxLeaseTTL
1329		coreConfig.CacheSize = base.CacheSize
1330		coreConfig.PluginDirectory = base.PluginDirectory
1331		coreConfig.Seal = base.Seal
1332		coreConfig.DevToken = base.DevToken
1333		coreConfig.EnableRaw = base.EnableRaw
1334		coreConfig.DisableSealWrap = base.DisableSealWrap
1335		coreConfig.DevLicenseDuration = base.DevLicenseDuration
1336		coreConfig.DisableCache = base.DisableCache
1337		coreConfig.LicensingConfig = base.LicensingConfig
1338		coreConfig.DisablePerformanceStandby = base.DisablePerformanceStandby
1339		coreConfig.MetricsHelper = base.MetricsHelper
1340		if base.BuiltinRegistry != nil {
1341			coreConfig.BuiltinRegistry = base.BuiltinRegistry
1342		}
1343
1344		if !coreConfig.DisableMlock {
1345			base.DisableMlock = false
1346		}
1347
1348		if base.Physical != nil {
1349			coreConfig.Physical = base.Physical
1350		}
1351
1352		if base.HAPhysical != nil {
1353			coreConfig.HAPhysical = base.HAPhysical
1354		}
1355
1356		// Used to set something non-working to test fallback
1357		switch base.ClusterAddr {
1358		case "empty":
1359			coreConfig.ClusterAddr = ""
1360		case "":
1361		default:
1362			coreConfig.ClusterAddr = base.ClusterAddr
1363		}
1364
1365		if base.LogicalBackends != nil {
1366			for k, v := range base.LogicalBackends {
1367				coreConfig.LogicalBackends[k] = v
1368			}
1369		}
1370		if base.CredentialBackends != nil {
1371			for k, v := range base.CredentialBackends {
1372				coreConfig.CredentialBackends[k] = v
1373			}
1374		}
1375		if base.AuditBackends != nil {
1376			for k, v := range base.AuditBackends {
1377				coreConfig.AuditBackends[k] = v
1378			}
1379		}
1380		if base.Logger != nil {
1381			coreConfig.Logger = base.Logger
1382		}
1383
1384		coreConfig.ClusterCipherSuites = base.ClusterCipherSuites
1385
1386		coreConfig.DisableCache = base.DisableCache
1387
1388		coreConfig.DevToken = base.DevToken
1389		coreConfig.CounterSyncInterval = base.CounterSyncInterval
1390
1391	}
1392
1393	addAuditBackend := len(coreConfig.AuditBackends) == 0
1394	if addAuditBackend {
1395		AddNoopAudit(coreConfig)
1396	}
1397
1398	if coreConfig.Physical == nil && (opts == nil || opts.PhysicalFactory == nil) {
1399		coreConfig.Physical, err = physInmem.NewInmem(nil, testCluster.Logger)
1400		if err != nil {
1401			t.Fatal(err)
1402		}
1403	}
1404	if coreConfig.HAPhysical == nil && (opts == nil || opts.PhysicalFactory == nil) {
1405		haPhys, err := physInmem.NewInmemHA(nil, testCluster.Logger)
1406		if err != nil {
1407			t.Fatal(err)
1408		}
1409		coreConfig.HAPhysical = haPhys.(physical.HABackend)
1410	}
1411
1412	pubKey, priKey, err := testGenerateCoreKeys()
1413	if err != nil {
1414		t.Fatalf("err: %v", err)
1415	}
1416
1417	cores := []*Core{}
1418	coreConfigs := []*CoreConfig{}
1419	for i := 0; i < numCores; i++ {
1420		localConfig := *coreConfig
1421		localConfig.RedirectAddr = fmt.Sprintf("https://127.0.0.1:%d", listeners[i][0].Address.Port)
1422
1423		// if opts.SealFunc is provided, use that to generate a seal for the config instead
1424		if opts != nil && opts.SealFunc != nil {
1425			localConfig.Seal = opts.SealFunc()
1426		}
1427
1428		if coreConfig.Logger == nil || (opts != nil && opts.Logger != nil) {
1429			localConfig.Logger = testCluster.Logger.Named(fmt.Sprintf("core%d", i))
1430		}
1431
1432		if opts != nil && opts.PhysicalFactory != nil {
1433			localConfig.Physical, err = opts.PhysicalFactory(localConfig.Logger)
1434			if err != nil {
1435				t.Fatalf("err: %v", err)
1436			}
1437
1438			if haPhysical, ok := localConfig.Physical.(physical.HABackend); ok {
1439				localConfig.HAPhysical = haPhysical
1440			}
1441		}
1442
1443		switch {
1444		case localConfig.LicensingConfig != nil:
1445			if pubKey != nil {
1446				localConfig.LicensingConfig.AdditionalPublicKeys = append(localConfig.LicensingConfig.AdditionalPublicKeys, pubKey.(ed25519.PublicKey))
1447			}
1448		default:
1449			localConfig.LicensingConfig = testGetLicensingConfig(pubKey)
1450		}
1451
1452		c, err := NewCore(&localConfig)
1453		if err != nil {
1454			t.Fatalf("err: %v", err)
1455		}
1456		c.coreNumber = firstCoreNumber + i
1457		cores = append(cores, c)
1458		coreConfigs = append(coreConfigs, &localConfig)
1459		if opts != nil && opts.HandlerFunc != nil {
1460			handlers[i] = opts.HandlerFunc(&HandlerProperties{
1461				Core:               c,
1462				MaxRequestDuration: DefaultMaxRequestDuration,
1463			})
1464			servers[i].Handler = handlers[i]
1465		}
1466
1467		// Set this in case the Seal was manually set before the core was
1468		// created
1469		if localConfig.Seal != nil {
1470			localConfig.Seal.SetCore(c)
1471		}
1472	}
1473
1474	//
1475	// Clustering setup
1476	//
1477	clusterAddrGen := func(lns []*TestListener) []*net.TCPAddr {
1478		ret := make([]*net.TCPAddr, len(lns))
1479		for i, ln := range lns {
1480			ret[i] = &net.TCPAddr{
1481				IP:   ln.Address.IP,
1482				Port: 0,
1483			}
1484		}
1485		return ret
1486	}
1487
1488	for i := 0; i < numCores; i++ {
1489		if coreConfigs[i].ClusterAddr != "" {
1490			cores[i].SetClusterListenerAddrs(clusterAddrGen(listeners[i]))
1491			cores[i].SetClusterHandler(handlers[i])
1492		}
1493	}
1494
1495	if opts == nil || !opts.SkipInit {
1496		bKeys, rKeys, root := TestCoreInitClusterWrapperSetup(t, cores[0], handlers[0])
1497		barrierKeys, _ := copystructure.Copy(bKeys)
1498		testCluster.BarrierKeys = barrierKeys.([][]byte)
1499		recoveryKeys, _ := copystructure.Copy(rKeys)
1500		testCluster.RecoveryKeys = recoveryKeys.([][]byte)
1501		testCluster.RootToken = root
1502
1503		// Write root token and barrier keys
1504		err = ioutil.WriteFile(filepath.Join(testCluster.TempDir, "root_token"), []byte(root), 0755)
1505		if err != nil {
1506			t.Fatal(err)
1507		}
1508		var buf bytes.Buffer
1509		for i, key := range testCluster.BarrierKeys {
1510			buf.Write([]byte(base64.StdEncoding.EncodeToString(key)))
1511			if i < len(testCluster.BarrierKeys)-1 {
1512				buf.WriteRune('\n')
1513			}
1514		}
1515		err = ioutil.WriteFile(filepath.Join(testCluster.TempDir, "barrier_keys"), buf.Bytes(), 0755)
1516		if err != nil {
1517			t.Fatal(err)
1518		}
1519		for i, key := range testCluster.RecoveryKeys {
1520			buf.Write([]byte(base64.StdEncoding.EncodeToString(key)))
1521			if i < len(testCluster.RecoveryKeys)-1 {
1522				buf.WriteRune('\n')
1523			}
1524		}
1525		err = ioutil.WriteFile(filepath.Join(testCluster.TempDir, "recovery_keys"), buf.Bytes(), 0755)
1526		if err != nil {
1527			t.Fatal(err)
1528		}
1529
1530		// Unseal first core
1531		for _, key := range bKeys {
1532			if _, err := cores[0].Unseal(TestKeyCopy(key)); err != nil {
1533				t.Fatalf("unseal err: %s", err)
1534			}
1535		}
1536
1537		ctx := context.Background()
1538
1539		// If stored keys is supported, the above will no no-op, so trigger auto-unseal
1540		// using stored keys to try to unseal
1541		if err := cores[0].UnsealWithStoredKeys(ctx); err != nil {
1542			t.Fatal(err)
1543		}
1544
1545		// Verify unsealed
1546		if cores[0].Sealed() {
1547			t.Fatal("should not be sealed")
1548		}
1549
1550		TestWaitActive(t, cores[0])
1551
1552		// Existing tests rely on this; we can make a toggle to disable it
1553		// later if we want
1554		kvReq := &logical.Request{
1555			Operation:   logical.UpdateOperation,
1556			ClientToken: testCluster.RootToken,
1557			Path:        "sys/mounts/secret",
1558			Data: map[string]interface{}{
1559				"type":        "kv",
1560				"path":        "secret/",
1561				"description": "key/value secret storage",
1562				"options": map[string]string{
1563					"version": "1",
1564				},
1565			},
1566		}
1567		resp, err := cores[0].HandleRequest(namespace.RootContext(ctx), kvReq)
1568		if err != nil {
1569			t.Fatal(err)
1570		}
1571		if resp.IsError() {
1572			t.Fatal(err)
1573		}
1574
1575		// Unseal other cores unless otherwise specified
1576		if (opts == nil || !opts.KeepStandbysSealed) && numCores > 1 {
1577			for i := 1; i < numCores; i++ {
1578				for _, key := range bKeys {
1579					if _, err := cores[i].Unseal(TestKeyCopy(key)); err != nil {
1580						t.Fatalf("unseal err: %s", err)
1581					}
1582				}
1583
1584				// If stored keys is supported, the above will no no-op, so trigger auto-unseal
1585				// using stored keys
1586				if err := cores[i].UnsealWithStoredKeys(ctx); err != nil {
1587					t.Fatal(err)
1588				}
1589			}
1590
1591			// Let them come fully up to standby
1592			time.Sleep(2 * time.Second)
1593
1594			// Ensure cluster connection info is populated.
1595			// Other cores should not come up as leaders.
1596			for i := 1; i < numCores; i++ {
1597				isLeader, _, _, err := cores[i].Leader()
1598				if err != nil {
1599					t.Fatal(err)
1600				}
1601				if isLeader {
1602					t.Fatalf("core[%d] should not be leader", i)
1603				}
1604			}
1605		}
1606
1607		//
1608		// Set test cluster core(s) and test cluster
1609		//
1610		cluster, err := cores[0].Cluster(context.Background())
1611		if err != nil {
1612			t.Fatal(err)
1613		}
1614		testCluster.ID = cluster.ID
1615
1616		if addAuditBackend {
1617			// Enable auditing.
1618			auditReq := &logical.Request{
1619				Operation:   logical.UpdateOperation,
1620				ClientToken: testCluster.RootToken,
1621				Path:        "sys/audit/noop",
1622				Data: map[string]interface{}{
1623					"type": "noop",
1624				},
1625			}
1626			resp, err = cores[0].HandleRequest(namespace.RootContext(ctx), auditReq)
1627			if err != nil {
1628				t.Fatal(err)
1629			}
1630
1631			if resp.IsError() {
1632				t.Fatal(err)
1633			}
1634		}
1635	}
1636
1637	getAPIClient := func(port int, tlsConfig *tls.Config) *api.Client {
1638		transport := cleanhttp.DefaultPooledTransport()
1639		transport.TLSClientConfig = tlsConfig.Clone()
1640		if err := http2.ConfigureTransport(transport); err != nil {
1641			t.Fatal(err)
1642		}
1643		client := &http.Client{
1644			Transport: transport,
1645			CheckRedirect: func(*http.Request, []*http.Request) error {
1646				// This can of course be overridden per-test by using its own client
1647				return fmt.Errorf("redirects not allowed in these tests")
1648			},
1649		}
1650		config := api.DefaultConfig()
1651		if config.Error != nil {
1652			t.Fatal(config.Error)
1653		}
1654		config.Address = fmt.Sprintf("https://127.0.0.1:%d", port)
1655		config.HttpClient = client
1656		config.MaxRetries = 0
1657		apiClient, err := api.NewClient(config)
1658		if err != nil {
1659			t.Fatal(err)
1660		}
1661		if opts == nil || !opts.SkipInit {
1662			apiClient.SetToken(testCluster.RootToken)
1663		}
1664		return apiClient
1665	}
1666
1667	var ret []*TestClusterCore
1668	for i := 0; i < numCores; i++ {
1669		tcc := &TestClusterCore{
1670			Core:                 cores[i],
1671			CoreConfig:           coreConfigs[i],
1672			ServerKey:            certInfoSlice[i].key,
1673			ServerKeyPEM:         certInfoSlice[i].keyPEM,
1674			ServerCert:           certInfoSlice[i].cert,
1675			ServerCertBytes:      certInfoSlice[i].certBytes,
1676			ServerCertPEM:        certInfoSlice[i].certPEM,
1677			Listeners:            listeners[i],
1678			Handler:              handlers[i],
1679			Server:               servers[i],
1680			TLSConfig:            tlsConfigs[i],
1681			Client:               getAPIClient(listeners[i][0].Address.Port, tlsConfigs[i]),
1682			Barrier:              cores[i].barrier,
1683			NodeID:               fmt.Sprintf("core-%d", i),
1684			UnderlyingRawStorage: coreConfigs[i].Physical,
1685		}
1686		tcc.ReloadFuncs = &cores[i].reloadFuncs
1687		tcc.ReloadFuncsLock = &cores[i].reloadFuncsLock
1688		tcc.ReloadFuncsLock.Lock()
1689		(*tcc.ReloadFuncs)["listener|tcp"] = []reload.ReloadFunc{certGetters[i].Reload}
1690		tcc.ReloadFuncsLock.Unlock()
1691
1692		testAdjustTestCore(base, tcc)
1693
1694		ret = append(ret, tcc)
1695	}
1696
1697	testCluster.Cores = ret
1698
1699	testExtraClusterCoresTestSetup(t, priKey, testCluster.Cores)
1700
1701	return &testCluster
1702}
1703
1704func NewMockBuiltinRegistry() *mockBuiltinRegistry {
1705	return &mockBuiltinRegistry{
1706		forTesting: map[string]consts.PluginType{
1707			"mysql-database-plugin":      consts.PluginTypeDatabase,
1708			"postgresql-database-plugin": consts.PluginTypeDatabase,
1709		},
1710	}
1711}
1712
// mockBuiltinRegistry is a stand-in builtin plugin registry used by test
// clusters. forTesting maps each plugin name that Get can resolve to the
// plugin type it must be requested with.
type mockBuiltinRegistry struct {
	forTesting map[string]consts.PluginType
}
1716
1717func (m *mockBuiltinRegistry) Get(name string, pluginType consts.PluginType) (func() (interface{}, error), bool) {
1718	testPluginType, ok := m.forTesting[name]
1719	if !ok {
1720		return nil, false
1721	}
1722	if pluginType != testPluginType {
1723		return nil, false
1724	}
1725	if name == "postgresql-database-plugin" {
1726		return dbPostgres.New, true
1727	}
1728	return dbMysql.New(dbMysql.MetadataLen, dbMysql.MetadataLen, dbMysql.UsernameLen), true
1729}
1730
1731// Keys only supports getting a realistic list of the keys for database plugins.
1732func (m *mockBuiltinRegistry) Keys(pluginType consts.PluginType) []string {
1733	if pluginType != consts.PluginTypeDatabase {
1734		return []string{}
1735	}
1736	/*
1737		This is a hard-coded reproduction of the db plugin keys in helper/builtinplugins/registry.go.
1738		The registry isn't directly used because it causes import cycles.
1739	*/
1740	return []string{
1741		"mysql-database-plugin",
1742		"mysql-aurora-database-plugin",
1743		"mysql-rds-database-plugin",
1744		"mysql-legacy-database-plugin",
1745		"postgresql-database-plugin",
1746		"elasticsearch-database-plugin",
1747		"mssql-database-plugin",
1748		"cassandra-database-plugin",
1749		"mongodb-database-plugin",
1750		"hana-database-plugin",
1751		"influxdb-database-plugin",
1752	}
1753}
1754
// Contains always reports false, whatever the name and plugin type, even for
// names that Get resolves or that Keys lists.
// NOTE(review): presumably intentional for the mock — confirm no caller expects
// Contains to agree with Get.
func (m *mockBuiltinRegistry) Contains(name string, pluginType consts.PluginType) bool {
	return false
}
1758
// NoopAudit is an audit backend for tests that writes no audit log. It records
// everything passed to LogRequest and LogResponse so tests can inspect it
// afterwards, and returns ReqErr/RespErr from those calls.
//
// NOTE(review): the recording fields are updated without any locking; if
// several goroutines log through the same instance this may race — confirm
// before relying on it under -race.
type NoopAudit struct {
	// Request-side state: Config supplies the salt view/config used by Salt;
	// ReqErr is returned by LogRequest; the remaining Req* fields are
	// appended to (or, for ReqNonHMACKeys, overwritten) on each LogRequest.
	Config         *audit.BackendConfig
	ReqErr         error
	ReqAuth        []*logical.Auth
	Req            []*logical.Request
	ReqHeaders     []map[string][]string
	ReqNonHMACKeys []string
	ReqErrs        []error

	// Response-side state: RespErr is returned by LogResponse; the remaining
	// Resp* fields are recorded by LogResponse.
	RespErr            error
	RespAuth           []*logical.Auth
	RespReq            []*logical.Request
	Resp               []*logical.Response
	RespNonHMACKeys    []string
	RespReqNonHMACKeys []string
	RespErrs           []error

	// salt is created lazily by Salt and cleared by Invalidate; saltMutex
	// guards it.
	salt      *salt.Salt
	saltMutex sync.RWMutex
}
1779
1780func (n *NoopAudit) LogRequest(ctx context.Context, in *logical.LogInput) error {
1781	n.ReqAuth = append(n.ReqAuth, in.Auth)
1782	n.Req = append(n.Req, in.Request)
1783	n.ReqHeaders = append(n.ReqHeaders, in.Request.Headers)
1784	n.ReqNonHMACKeys = in.NonHMACReqDataKeys
1785	n.ReqErrs = append(n.ReqErrs, in.OuterErr)
1786	return n.ReqErr
1787}
1788
1789func (n *NoopAudit) LogResponse(ctx context.Context, in *logical.LogInput) error {
1790	n.RespAuth = append(n.RespAuth, in.Auth)
1791	n.RespReq = append(n.RespReq, in.Request)
1792	n.Resp = append(n.Resp, in.Response)
1793	n.RespErrs = append(n.RespErrs, in.OuterErr)
1794
1795	if in.Response != nil {
1796		n.RespNonHMACKeys = in.NonHMACRespDataKeys
1797		n.RespReqNonHMACKeys = in.NonHMACReqDataKeys
1798	}
1799
1800	return n.RespErr
1801}
1802
1803func (n *NoopAudit) Salt(ctx context.Context) (*salt.Salt, error) {
1804	n.saltMutex.RLock()
1805	if n.salt != nil {
1806		defer n.saltMutex.RUnlock()
1807		return n.salt, nil
1808	}
1809	n.saltMutex.RUnlock()
1810	n.saltMutex.Lock()
1811	defer n.saltMutex.Unlock()
1812	if n.salt != nil {
1813		return n.salt, nil
1814	}
1815	salt, err := salt.NewSalt(ctx, n.Config.SaltView, n.Config.SaltConfig)
1816	if err != nil {
1817		return nil, err
1818	}
1819	n.salt = salt
1820	return salt, nil
1821}
1822
1823func (n *NoopAudit) GetHash(ctx context.Context, data string) (string, error) {
1824	salt, err := n.Salt(ctx)
1825	if err != nil {
1826		return "", err
1827	}
1828	return salt.GetIdentifiedHMAC(data), nil
1829}
1830
// Reload is a no-op for NoopAudit and always returns nil.
func (n *NoopAudit) Reload(ctx context.Context) error {
	return nil
}
1834
1835func (n *NoopAudit) Invalidate(ctx context.Context) {
1836	n.saltMutex.Lock()
1837	defer n.saltMutex.Unlock()
1838	n.salt = nil
1839}
1840