1package vault
2
3import (
4	"context"
5	"fmt"
6	"math/rand"
7	"path"
8	"time"
9
10	uuid "github.com/hashicorp/go-uuid"
11	"github.com/hashicorp/vault/helper/namespace"
12	"github.com/hashicorp/vault/sdk/logical"
13)
14
// basicLeaseTestInfo captures identifying details of a lease created by the
// test helpers in this file.
type basicLeaseTestInfo struct {
	// id is the lease ID. mount is a mount identifier (presumably the mount
	// accessor — confirm with callers); AddIrrevocableLease leaves it empty.
	id, mount string
	// expire is the lease's expiration time.
	expire time.Time
}
20
21// add an irrevocable lease for test purposes
22// returns the lease ID and expire time
23func (c *Core) AddIrrevocableLease(ctx context.Context, pathPrefix string) (*basicLeaseTestInfo, error) {
24	exp := c.expiration
25
26	uuid, err := uuid.GenerateUUID()
27	if err != nil {
28		return nil, fmt.Errorf("error generating uuid: %w", err)
29	}
30
31	ns, err := namespace.FromContext(ctx)
32	if err != nil {
33		return nil, fmt.Errorf("error getting namespace from context: %w", err)
34	}
35	if ns == nil {
36		ns = namespace.RootNamespace
37	}
38
39	leaseID := path.Join(pathPrefix, "lease"+uuid)
40
41	if ns != namespace.RootNamespace {
42		leaseID = fmt.Sprintf("%s.%s", leaseID, ns.ID)
43	}
44
45	randomTimeDelta := time.Duration(rand.Int31n(24))
46	le := &leaseEntry{
47		LeaseID:    leaseID,
48		Path:       pathPrefix,
49		namespace:  ns,
50		IssueTime:  time.Now(),
51		ExpireTime: time.Now().Add(randomTimeDelta * time.Hour),
52		RevokeErr:  "some error message",
53	}
54
55	exp.pendingLock.Lock()
56	defer exp.pendingLock.Unlock()
57
58	if err := exp.persistEntry(context.Background(), le); err != nil {
59		return nil, fmt.Errorf("error persisting irrevocable lease: %w", err)
60	}
61
62	exp.updatePendingInternal(le)
63
64	return &basicLeaseTestInfo{
65		id:     le.LeaseID,
66		expire: le.ExpireTime,
67	}, nil
68}
69
70// InjectIrrevocableLeases injects `count` irrevocable leases (currently to a
71// single mount).
72// It returns a map of the mount accessor to the number of leases stored there
73func (c *Core) InjectIrrevocableLeases(ctx context.Context, count int) (map[string]int, error) {
74	out := make(map[string]int)
75	for i := 0; i < count; i++ {
76		le, err := c.AddIrrevocableLease(ctx, "foo/")
77		if err != nil {
78			return nil, err
79		}
80
81		mountAccessor := c.expiration.getLeaseMountAccessor(ctx, le.id)
82		if _, ok := out[mountAccessor]; !ok {
83			out[mountAccessor] = 0
84		}
85
86		out[mountAccessor]++
87	}
88
89	return out, nil
90}
91
// backend describes one mount for mountNoopBackends to create.
type backend struct {
	// path is the mount path; it is also the key in mountNoopBackends' result.
	path string
	// ns is the namespace placed on the context used to mount this backend
	// and to look up its mount entry.
	ns   *namespace.Namespace
}
96
97// set up multiple mounts, and return a mapping of the path to the mount accessor
98func mountNoopBackends(c *Core, backends []*backend) (map[string]string, error) {
99	// enable the noop backend
100	c.logicalBackends["noop"] = func(ctx context.Context, config *logical.BackendConfig) (logical.Backend, error) {
101		return &NoopBackend{}, nil
102	}
103
104	pathToMount := make(map[string]string)
105	for _, backend := range backends {
106		me := &MountEntry{
107			Table: mountTableType,
108			Path:  backend.path,
109			Type:  "noop",
110		}
111
112		nsCtx := namespace.ContextWithNamespace(context.Background(), backend.ns)
113		if err := c.mount(nsCtx, me); err != nil {
114			return nil, fmt.Errorf("error mounting backend %s: %w", backend.path, err)
115		}
116
117		mount := c.router.MatchingMountEntry(nsCtx, backend.path)
118		if mount == nil {
119			return nil, fmt.Errorf("couldn't find mount for path %s", backend.path)
120		}
121		pathToMount[backend.path] = mount.Accessor
122	}
123
124	return pathToMount, nil
125}
126
127func (c *Core) FetchLeaseCountToRevoke() int {
128	c.expiration.pendingLock.RLock()
129	defer c.expiration.pendingLock.RUnlock()
130	return c.expiration.leaseCount
131}
132