package keysutil

import (
	"context"
	"encoding/base64"
	"errors"
	"fmt"
	"sync"
	"sync/atomic"
	"time"

	"github.com/hashicorp/errwrap"
	"github.com/hashicorp/vault/sdk/helper/jsonutil"
	"github.com/hashicorp/vault/sdk/helper/locksutil"
	"github.com/hashicorp/vault/sdk/logical"
)

const (
	shared                   = false
	exclusive                = true
	currentConvergentVersion = 3
)

var (
	errNeedExclusiveLock = errors.New("an exclusive lock is needed for this operation")
)

// PolicyRequest holds values used when requesting a policy. Most values are
// only used during an upsert.
type PolicyRequest struct {
	// The storage to use
	Storage logical.Storage

	// The name of the policy
	Name string

	// The key type
	KeyType KeyType

	// Whether it should be derived
	Derived bool

	// Whether to enable convergent encryption
	Convergent bool

	// Whether to allow export
	Exportable bool

	// Whether to upsert
	Upsert bool

	// Whether to allow plaintext backup
	AllowPlaintextBackup bool
}

type LockManager struct {
	useCache bool
	cache    Cache
	keyLocks []*locksutil.LockEntry
}

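// NewLockManager constructs a LockManager. A negative cacheSize is rejected,
// zero selects an unbounded syncmap-backed cache, and a positive value selects
// an LRU cache bounded to that many entries. A rough usage sketch (the sizes
// below are illustrative, not recommendations):
//
//	lm, err := NewLockManager(true, 0)       // caching with no size bound
//	lmLRU, err := NewLockManager(true, 4096) // caching bounded to 4096 policies
//	lmNone, err := NewLockManager(false, 0)  // no cache; callers manage policy locks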
func NewLockManager(useCache bool, cacheSize int) (*LockManager, error) {
	// Determine the type of cache to create
	var cache Cache
	switch {
	case !useCache:
	case cacheSize < 0:
		return nil, errors.New("cache size must be greater than or equal to zero")
	case cacheSize == 0:
		cache = NewTransitSyncMap()
	case cacheSize > 0:
		newLRUCache, err := NewTransitLRU(cacheSize)
		if err != nil {
			return nil, errwrap.Wrapf("failed to create cache: {{err}}", err)
		}
		cache = newLRUCache
	}

	lm := &LockManager{
		useCache: useCache,
		cache:    cache,
		keyLocks: locksutil.CreateLocks(),
	}

	return lm, nil
}

func (lm *LockManager) GetCacheSize() int {
	if !lm.useCache {
		return 0
	}
	return lm.cache.Size()
}

func (lm *LockManager) GetUseCache() bool {
	return lm.useCache
}

func (lm *LockManager) InvalidatePolicy(name string) {
	if lm.useCache {
		lm.cache.Delete(name)
	}
}

// RestorePolicy acquires an exclusive lock on the policy name and restores the
// given policy along with the archive.
func (lm *LockManager) RestorePolicy(ctx context.Context, storage logical.Storage, name, backup string, force bool) error {
	backupBytes, err := base64.StdEncoding.DecodeString(backup)
	if err != nil {
		return err
	}

	var keyData KeyData
	err = jsonutil.DecodeJSON(backupBytes, &keyData)
	if err != nil {
		return err
	}

	// Set a different name if desired
	if name != "" {
		keyData.Policy.Name = name
	}

	name = keyData.Policy.Name

	// Grab the exclusive lock as we'll be modifying disk
	lock := locksutil.LockForKey(lm.keyLocks, name)
	lock.Lock()
	defer lock.Unlock()

	var ok bool
	var pRaw interface{}

	// If the policy is in cache and 'force' is not specified, error out. Anywhere
	// that would put it in the cache will also be protected by the mutex above,
	// so we don't need to re-check the cache later.
	if lm.useCache {
		pRaw, ok = lm.cache.Load(name)
		if ok && !force {
			return fmt.Errorf("key %q already exists", name)
		}
	}

	// Conditionally look up the policy from storage, depending on the use of
	// 'force' and whether the policy was found in cache.
	//
	// - If it was not found in cache and we are not using 'force', look for it
	// in storage. If found, error out.
	//
	// - If it was found in cache and we are using 'force', pRaw will not be nil
	// and we do not look the policy up from storage.
	//
	// - If it was found in cache and we are not using 'force', we will have
	// returned above with an error.
	var p *Policy
	if pRaw == nil {
		p, err = lm.getPolicyFromStorage(ctx, storage, name)
		if err != nil {
			return err
		}
		if p != nil && !force {
			return fmt.Errorf("key %q already exists", name)
		}
	}

	// If both pRaw and p above are nil and 'force' is specified, we don't need to
	// grab policy locks as we have ensured it doesn't already exist, so there
	// will be no races as nothing else has this pointer. If 'force' was not used,
	// an error would have been returned by now if the policy already existed
	if pRaw != nil {
		p = pRaw.(*Policy)
	}
	if p != nil {
		p.l.Lock()
		defer p.l.Unlock()
	}

	// Restore the archived keys
	if keyData.ArchivedKeys != nil {
		err = keyData.Policy.storeArchive(ctx, storage, keyData.ArchivedKeys)
		if err != nil {
			return errwrap.Wrapf(fmt.Sprintf("failed to restore archived keys for key %q: {{err}}", name), err)
		}
	}

	// Mark the policy as a restored key
	keyData.Policy.RestoreInfo = &RestoreInfo{
		Time:    time.Now(),
		Version: keyData.Policy.LatestVersion,
	}

	// Restore the policy. This will also attempt to adjust the archive.
	err = keyData.Policy.Persist(ctx, storage)
	if err != nil {
		return errwrap.Wrapf(fmt.Sprintf("failed to restore the policy %q: {{err}}", name), err)
	}

	keyData.Policy.l = new(sync.RWMutex)

	// Update the cache to contain the restored policy
	if lm.useCache {
		lm.cache.Store(name, keyData.Policy)
	}
	return nil
}

func (lm *LockManager) BackupPolicy(ctx context.Context, storage logical.Storage, name string) (string, error) {
	var p *Policy
	var err error

	// Backup writes information about when the backup took place, so we get an
	// exclusive lock here
	lock := locksutil.LockForKey(lm.keyLocks, name)
	lock.Lock()
	defer lock.Unlock()

	var ok bool
	var pRaw interface{}

	if lm.useCache {
		pRaw, ok = lm.cache.Load(name)
	}
	if ok {
		p = pRaw.(*Policy)
		p.l.Lock()
		defer p.l.Unlock()
	} else {
		// If the policy doesn't exist in storage, error out
		p, err = lm.getPolicyFromStorage(ctx, storage, name)
		if err != nil {
			return "", err
		}
		if p == nil {
			return "", fmt.Errorf("key %q not found", name)
		}
	}

	if atomic.LoadUint32(&p.deleted) == 1 {
		return "", fmt.Errorf("key %q not found", name)
	}

	backup, err := p.Backup(ctx, storage)
	if err != nil {
		return "", err
	}

	return backup, nil
}
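
// A backup/restore round trip, sketched under the assumption that the key was
// created with plaintext backup allowed (the key names here are illustrative):
//
//	backup, err := lm.BackupPolicy(ctx, storage, "my-key")
//	if err != nil {
//		return err
//	}
//	// Restore under a new name, overwriting any existing key of that name.
//	if err := lm.RestorePolicy(ctx, storage, "my-key-copy", backup, true); err != nil {
//		return err
//	}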

// GetPolicy returns the policy for the given request, upserting a new policy
// if requested and none exists. When the function returns, if caching was
// disabled, the Policy's lock must be unlocked when the caller is done (and it
// should not be re-locked).
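//
// A minimal caller sketch, assuming the cache may be disabled and using this
// package's Policy.Unlock helper (the key name and parameters below are
// illustrative):
//
//	p, _, err := lm.GetPolicy(ctx, PolicyRequest{
//		Storage: storage,
//		Name:    "my-key",
//		KeyType: KeyType_AES256_GCM96,
//		Upsert:  true,
//	})
//	if err != nil {
//		return err
//	}
//	if p == nil {
//		return errors.New("policy not found")
//	}
//	// With caching disabled, GetPolicy hands back the write lock it took;
//	// release it exactly once when done.
//	if !lm.GetUseCache() {
//		defer p.Unlock()
//	}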
func (lm *LockManager) GetPolicy(ctx context.Context, req PolicyRequest) (retP *Policy, retUpserted bool, retErr error) {
	var p *Policy
	var err error
	var ok bool
	var pRaw interface{}

	// Check if it's in our cache. If so, return right away.
	if lm.useCache {
		pRaw, ok = lm.cache.Load(req.Name)
	}
	if ok {
		p = pRaw.(*Policy)
		if atomic.LoadUint32(&p.deleted) == 1 {
			return nil, false, nil
		}
		return p, false, nil
	}

	// We're not using the cache, or it wasn't found; get an exclusive lock.
	// This ensures that any other process writing the actual storage will be
	// finished before we load from storage.
	lock := locksutil.LockForKey(lm.keyLocks, req.Name)
	lock.Lock()

	// The cleanup func handles the lock: if we are using the cache we always
	// unlock before returning; otherwise we return with the lock still held
	// whenever we hand a policy back to the caller.
	cleanup := func() {
		switch {
		// If using the cache we always unlock; the caller locks the policy
		// themselves
		case lm.useCache:
			lock.Unlock()

		// If not using the cache and we aren't returning a policy, the caller
		// has no lock to release, so we must unlock
		case retP == nil:
			lock.Unlock()
		}
	}

	// Check the cache again
	if lm.useCache {
		pRaw, ok = lm.cache.Load(req.Name)
	}
	if ok {
		p = pRaw.(*Policy)
		if atomic.LoadUint32(&p.deleted) == 1 {
			cleanup()
			return nil, false, nil
		}
		retP = p
		cleanup()
		return
	}

	// Load it from storage
	p, err = lm.getPolicyFromStorage(ctx, req.Storage, req.Name)
	if err != nil {
		cleanup()
		return nil, false, err
	}
	// We don't need to lock the policy as there would be no other holders of
	// the pointer

	if p == nil {
		// This is the only place we upsert a new policy, so if upsert is not
		// specified, unlock before returning
		if !req.Upsert {
			cleanup()
			return nil, false, nil
		}

		// We create the policy here and store it in the cache at the end. The
		// exclusive lock held above prevents another writer from creating the
		// same policy concurrently, and we re-checked the cache after taking
		// the lock.

		switch req.KeyType {
		case KeyType_AES256_GCM96, KeyType_ChaCha20_Poly1305:
			if req.Convergent && !req.Derived {
				cleanup()
				return nil, false, fmt.Errorf("convergent encryption requires derivation to be enabled")
			}

		case KeyType_ECDSA_P256:
			if req.Derived || req.Convergent {
				cleanup()
				return nil, false, fmt.Errorf("key derivation and convergent encryption not supported for keys of type %v", req.KeyType)
			}

		case KeyType_ED25519:
			if req.Convergent {
				cleanup()
				return nil, false, fmt.Errorf("convergent encryption not supported for keys of type %v", req.KeyType)
			}

		case KeyType_RSA2048, KeyType_RSA4096:
			if req.Derived || req.Convergent {
				cleanup()
				return nil, false, fmt.Errorf("key derivation and convergent encryption not supported for keys of type %v", req.KeyType)
			}

		default:
			cleanup()
			return nil, false, fmt.Errorf("unsupported key type %v", req.KeyType)
		}

		p = &Policy{
			l:                    new(sync.RWMutex),
			Name:                 req.Name,
			Type:                 req.KeyType,
			Derived:              req.Derived,
			Exportable:           req.Exportable,
			AllowPlaintextBackup: req.AllowPlaintextBackup,
		}

		if req.Derived {
			p.KDF = Kdf_hkdf_sha256
			if req.Convergent {
				p.ConvergentEncryption = true
				// As of version 3 we store the version within each key, so we
				// set to -1 to indicate that the value in the policy has no
				// meaning. We still, for backwards compatibility, fall back to
				// this value if the key doesn't have one, which means it will
				// only be -1 in the case where every key version is >= 3
				p.ConvergentVersion = -1
			}
		}

		// Performs the actual persist and does setup
		err = p.Rotate(ctx, req.Storage)
		if err != nil {
			cleanup()
			return nil, false, err
		}

		if lm.useCache {
			lm.cache.Store(req.Name, p)
		} else {
			p.l = &lock.RWMutex
			p.writeLocked = true
		}

		// We don't need to worry about upgrading since it will be a new policy
		retP = p
		retUpserted = true
		cleanup()
		return
	}

	if p.NeedsUpgrade() {
		if err := p.Upgrade(ctx, req.Storage); err != nil {
			cleanup()
			return nil, false, err
		}
	}

	if lm.useCache {
		lm.cache.Store(req.Name, p)
	} else {
		p.l = &lock.RWMutex
		p.writeLocked = true
	}

	retP = p
	cleanup()
	return
}

func (lm *LockManager) DeletePolicy(ctx context.Context, storage logical.Storage, name string) error {
	var p *Policy
	var err error
	var ok bool
	var pRaw interface{}

	// We may be writing to disk, so grab an exclusive lock. This prevents bad
	// behavior when the cache is turned off. We also lock the shared policy
	// object to make sure no requests are in flight.
	lock := locksutil.LockForKey(lm.keyLocks, name)
	lock.Lock()
	defer lock.Unlock()

	if lm.useCache {
		pRaw, ok = lm.cache.Load(name)
	}
	if ok {
		p = pRaw.(*Policy)
		p.l.Lock()
		defer p.l.Unlock()
	}

	if p == nil {
		p, err = lm.getPolicyFromStorage(ctx, storage, name)
		if err != nil {
			return err
		}
		if p == nil {
			return fmt.Errorf("could not delete key; not found")
		}
	}

	if !p.DeletionAllowed {
		return fmt.Errorf("deletion is not allowed for this key")
	}

	atomic.StoreUint32(&p.deleted, 1)

	if lm.useCache {
		lm.cache.Delete(name)
	}

	err = storage.Delete(ctx, "policy/"+name)
	if err != nil {
		return errwrap.Wrapf(fmt.Sprintf("error deleting key %q: {{err}}", name), err)
	}

	err = storage.Delete(ctx, "archive/"+name)
	if err != nil {
		return errwrap.Wrapf(fmt.Sprintf("error deleting key %q archive: {{err}}", name), err)
	}

	return nil
}
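
// Deletion only succeeds once the policy's DeletionAllowed flag has been set
// (in the transit secrets engine this is done through the key's config
// endpoint). A sketch of the call, with an illustrative key name:
//
//	if err := lm.DeletePolicy(ctx, storage, "my-key"); err != nil {
//		return err
//	}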

func (lm *LockManager) getPolicyFromStorage(ctx context.Context, storage logical.Storage, name string) (*Policy, error) {
	return LoadPolicy(ctx, storage, "policy/"+name)
}