1package vault
2
3import (
4	"context"
5	"fmt"
6	"strings"
7
8	"github.com/hashicorp/vault/helper/namespace"
9
10	"github.com/hashicorp/errwrap"
11	multierror "github.com/hashicorp/go-multierror"
12	"github.com/hashicorp/vault/sdk/helper/strutil"
13	"github.com/hashicorp/vault/sdk/logical"
14)
15
16// reloadPluginMounts reloads provided mounts, regardless of
17// plugin name, as long as the backend type is plugin.
18func (c *Core) reloadMatchingPluginMounts(ctx context.Context, mounts []string) error {
19	c.mountsLock.RLock()
20	defer c.mountsLock.RUnlock()
21	c.authLock.RLock()
22	defer c.authLock.RUnlock()
23
24	ns, err := namespace.FromContext(ctx)
25	if err != nil {
26		return err
27	}
28
29	var errors error
30	for _, mount := range mounts {
31		entry := c.router.MatchingMountEntry(ctx, mount)
32		if entry == nil {
33			errors = multierror.Append(errors, fmt.Errorf("cannot fetch mount entry on %q", mount))
34			continue
35		}
36
37		var isAuth bool
38		fullPath := c.router.MatchingMount(ctx, mount)
39		if strings.HasPrefix(fullPath, credentialRoutePrefix) {
40			isAuth = true
41		}
42
43		// We dont reload mounts that are not in the same namespace
44		if ns.ID != entry.Namespace().ID {
45			continue
46		}
47
48		err := c.reloadBackendCommon(ctx, entry, isAuth)
49		if err != nil {
50			errors = multierror.Append(errors, errwrap.Wrapf(fmt.Sprintf("cannot reload plugin on %q: {{err}}", mount), err))
51			continue
52		}
53		c.logger.Info("successfully reloaded plugin", "plugin", entry.Accessor, "path", entry.Path)
54	}
55	return errors
56}
57
58// reloadPlugin reloads all mounted backends that are of
59// plugin pluginName (name of the plugin as registered in
60// the plugin catalog).
61func (c *Core) reloadMatchingPlugin(ctx context.Context, pluginName string) error {
62	c.mountsLock.RLock()
63	defer c.mountsLock.RUnlock()
64	c.authLock.RLock()
65	defer c.authLock.RUnlock()
66
67	ns, err := namespace.FromContext(ctx)
68	if err != nil {
69		return err
70	}
71
72	// Filter mount entries that only matches the plugin name
73	for _, entry := range c.mounts.Entries {
74		// We dont reload mounts that are not in the same namespace
75		if ns.ID != entry.Namespace().ID {
76			continue
77		}
78		if entry.Type == pluginName || (entry.Type == "plugin" && entry.Config.PluginName == pluginName) {
79			err := c.reloadBackendCommon(ctx, entry, false)
80			if err != nil {
81				return err
82			}
83			c.logger.Info("successfully reloaded plugin", "plugin", pluginName, "path", entry.Path)
84		}
85	}
86
87	// Filter auth mount entries that ony matches the plugin name
88	for _, entry := range c.auth.Entries {
89		// We dont reload mounts that are not in the same namespace
90		if ns.ID != entry.Namespace().ID {
91			continue
92		}
93
94		if entry.Type == pluginName || (entry.Type == "plugin" && entry.Config.PluginName == pluginName) {
95			err := c.reloadBackendCommon(ctx, entry, true)
96			if err != nil {
97				return err
98			}
99			c.logger.Info("successfully reloaded plugin", "plugin", entry.Accessor, "path", entry.Path)
100		}
101	}
102
103	return nil
104}
105
106// reloadBackendCommon is a generic method to reload a backend provided a
107// MountEntry.
108func (c *Core) reloadBackendCommon(ctx context.Context, entry *MountEntry, isAuth bool) error {
109	// Make sure our cache is up-to-date. Since some singleton mounts can be
110	// tuned, we do this before the below check.
111	entry.SyncCache()
112
113	// We don't want to reload the singleton mounts. They often have specific
114	// inmemory elements and we don't want to touch them here.
115	if strutil.StrListContains(singletonMounts, entry.Type) {
116		c.logger.Debug("skipping reload of singleton mount", "type", entry.Type)
117		return nil
118	}
119
120	path := entry.Path
121
122	if isAuth {
123		path = credentialRoutePrefix + path
124	}
125
126	// Fast-path out if the backend doesn't exist
127	raw, ok := c.router.root.Get(entry.Namespace().Path + path)
128	if !ok {
129		return nil
130	}
131
132	re := raw.(*routeEntry)
133
134	// Grab the lock, this allows requests to drain before we cleanup the
135	// client.
136	re.l.Lock()
137	defer re.l.Unlock()
138
139	// Only call Cleanup if backend is initialized
140	if re.backend != nil {
141		// Call backend's Cleanup routine
142		re.backend.Cleanup(ctx)
143	}
144
145	view := re.storageView
146	viewPath := entry.UUID + "/"
147	switch entry.Table {
148	case mountTableType:
149		viewPath = backendBarrierPrefix + viewPath
150	case credentialTableType:
151		viewPath = credentialBarrierPrefix + viewPath
152	}
153
154	removePathCheckers(c, entry, viewPath)
155
156	sysView := c.mountEntrySysView(entry)
157
158	nilMount, err := preprocessMount(c, entry, view.(*BarrierView))
159	if err != nil {
160		return err
161	}
162
163	var backend logical.Backend
164	if !isAuth {
165		// Dispense a new backend
166		backend, err = c.newLogicalBackend(ctx, entry, sysView, view)
167	} else {
168		backend, err = c.newCredentialBackend(ctx, entry, sysView, view)
169	}
170	if err != nil {
171		return err
172	}
173	if backend == nil {
174		return fmt.Errorf("nil backend of type %q returned from creation function", entry.Type)
175	}
176
177	addPathCheckers(c, entry, backend, viewPath)
178
179	if nilMount {
180		backend.Cleanup(ctx)
181		backend = nil
182	}
183
184	// Set the backend back
185	re.backend = backend
186
187	if backend != nil {
188		// Set paths as well
189		paths := backend.SpecialPaths()
190		if paths != nil {
191			re.rootPaths.Store(pathsToRadix(paths.Root))
192			re.loginPaths.Store(pathsToRadix(paths.Unauthenticated))
193		}
194	}
195
196	return nil
197}
198