package keysutil

import (
	"context"
	"encoding/base64"
	"errors"
	"fmt"
	"io"
	"sync"
	"sync/atomic"
	"time"

	"github.com/hashicorp/errwrap"
	"github.com/hashicorp/vault/sdk/helper/jsonutil"
	"github.com/hashicorp/vault/sdk/helper/locksutil"
	"github.com/hashicorp/vault/sdk/logical"
)

const (
	shared                   = false
	exclusive                = true
	currentConvergentVersion = 3
)

var errNeedExclusiveLock = errors.New("an exclusive lock is needed for this operation")

// PolicyRequest holds values used when requesting a policy. Most values are
// only used during an upsert.
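//
// A minimal sketch of a request that upserts a derived AES-256-GCM key (the
// field values are illustrative, not defaults):
//
//	req := PolicyRequest{
//		Storage: storage, // a logical.Storage supplied by the caller
//		Name:    "my-key",
//		KeyType: KeyType_AES256_GCM96,
//		Derived: true,
//		Upsert:  true,
//	}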
type PolicyRequest struct {
	// The storage to use
	Storage logical.Storage

	// The name of the policy
	Name string

	// The key type
	KeyType KeyType

	// Whether it should be derived
	Derived bool

	// Whether to enable convergent encryption
	Convergent bool

	// Whether to allow export
	Exportable bool

	// Whether to upsert
	Upsert bool

	// Whether to allow plaintext backup
	AllowPlaintextBackup bool
}

type LockManager struct {
	useCache bool
	cache    Cache
	keyLocks []*locksutil.LockEntry
}

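// NewLockManager returns a LockManager whose cache depends on cacheSize: a
// negative size is invalid, zero selects the sync-map-backed cache, and a
// positive size selects an LRU cache of that capacity. A minimal usage
// sketch (the size 1024 is illustrative):
//
//	lm, err := NewLockManager(true, 1024)
//	if err != nil {
//		return err
//	}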
func NewLockManager(useCache bool, cacheSize int) (*LockManager, error) {
	// Determine the type of cache to create
	var cache Cache
	switch {
	case !useCache:
	case cacheSize < 0:
		return nil, errors.New("cache size must be greater than or equal to zero")
	case cacheSize == 0:
		cache = NewTransitSyncMap()
	case cacheSize > 0:
		newLRUCache, err := NewTransitLRU(cacheSize)
		if err != nil {
			return nil, errwrap.Wrapf("failed to create cache: {{err}}", err)
		}
		cache = newLRUCache
	}

	lm := &LockManager{
		useCache: useCache,
		cache:    cache,
		keyLocks: locksutil.CreateLocks(),
	}

	return lm, nil
}

func (lm *LockManager) GetCacheSize() int {
	if !lm.useCache {
		return 0
	}
	return lm.cache.Size()
}

func (lm *LockManager) GetUseCache() bool {
	return lm.useCache
}

func (lm *LockManager) InvalidatePolicy(name string) {
	if lm.useCache {
		lm.cache.Delete(name)
	}
}

// RestorePolicy acquires an exclusive lock on the policy name and restores the
// given policy along with the archive.
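//
// A minimal sketch of restoring from a base64 backup string previously
// produced by BackupPolicy (ctx, storage, and backup come from the caller; a
// nonempty name overrides the name stored in the backup, and force=true
// overwrites an existing key):
//
//	if err := lm.RestorePolicy(ctx, storage, "restored-key", backup, false); err != nil {
//		return err
//	}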
func (lm *LockManager) RestorePolicy(ctx context.Context, storage logical.Storage, name, backup string, force bool) error {
	backupBytes, err := base64.StdEncoding.DecodeString(backup)
	if err != nil {
		return err
	}

	var keyData KeyData
	err = jsonutil.DecodeJSON(backupBytes, &keyData)
	if err != nil {
		return err
	}

	// Set a different name if desired
	if name != "" {
		keyData.Policy.Name = name
	}

	name = keyData.Policy.Name

	// Grab the exclusive lock as we'll be modifying disk
	lock := locksutil.LockForKey(lm.keyLocks, name)
	lock.Lock()
	defer lock.Unlock()

	var ok bool
	var pRaw interface{}

	// If the policy is in cache and 'force' is not specified, error out. Anywhere
	// that would put it in the cache will also be protected by the mutex above,
	// so we don't need to re-check the cache later.
	if lm.useCache {
		pRaw, ok = lm.cache.Load(name)
		if ok && !force {
			return fmt.Errorf("key %q already exists", name)
		}
	}

	// Conditionally look up the policy from storage, depending on the use of
	// 'force' and whether the policy was found in cache.
	//
	// - If it was not found in cache and we are not using 'force', look for it
	// in storage. If found, error out.
	//
	// - If it was found in cache and we are using 'force', pRaw will not be nil
	// and we do not look the policy up from storage.
	//
	// - If it was found in cache and we are not using 'force', we will have
	// returned above with an error.
	var p *Policy
	if pRaw == nil {
		p, err = lm.getPolicyFromStorage(ctx, storage, name)
		if err != nil {
			return err
		}
		if p != nil && !force {
			return fmt.Errorf("key %q already exists", name)
		}
	}

	// If both pRaw and p above are nil and 'force' is specified, we don't need to
	// grab policy locks as we have ensured it doesn't already exist, so there
	// will be no races as nothing else has this pointer. If 'force' was not used,
	// an error would have been returned by now if the policy already existed.
	if pRaw != nil {
		p = pRaw.(*Policy)
	}
	if p != nil {
		p.l.Lock()
		defer p.l.Unlock()
	}

	// Restore the archived keys
	if keyData.ArchivedKeys != nil {
		err = keyData.Policy.storeArchive(ctx, storage, keyData.ArchivedKeys)
		if err != nil {
			return errwrap.Wrapf(fmt.Sprintf("failed to restore archived keys for key %q: {{err}}", name), err)
		}
	}

	// Mark the policy as a restored key
	keyData.Policy.RestoreInfo = &RestoreInfo{
		Time:    time.Now(),
		Version: keyData.Policy.LatestVersion,
	}

	// Restore the policy. This will also attempt to adjust the archive.
	err = keyData.Policy.Persist(ctx, storage)
	if err != nil {
		return errwrap.Wrapf(fmt.Sprintf("failed to restore the policy %q: {{err}}", name), err)
	}

	keyData.Policy.l = new(sync.RWMutex)

	// Update the cache to contain the restored policy
	if lm.useCache {
		lm.cache.Store(name, keyData.Policy)
	}
	return nil
}

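// BackupPolicy returns a base64-encoded backup of the named policy, suitable
// for a later RestorePolicy call. A minimal usage sketch:
//
//	backup, err := lm.BackupPolicy(ctx, storage, "my-key")
//	if err != nil {
//		// the key may not exist, or plaintext backup may be disallowed
//	}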
func (lm *LockManager) BackupPolicy(ctx context.Context, storage logical.Storage, name string) (string, error) {
	var p *Policy
	var err error

	// Backup writes information about when the backup took place, so we get an
	// exclusive lock here
	lock := locksutil.LockForKey(lm.keyLocks, name)
	lock.Lock()
	defer lock.Unlock()

	var ok bool
	var pRaw interface{}

	if lm.useCache {
		pRaw, ok = lm.cache.Load(name)
	}
	if ok {
		p = pRaw.(*Policy)
		p.l.Lock()
		defer p.l.Unlock()
	} else {
		// If the policy doesn't exist in storage, error out
		p, err = lm.getPolicyFromStorage(ctx, storage, name)
		if err != nil {
			return "", err
		}
		if p == nil {
			return "", fmt.Errorf("key %q not found", name)
		}
	}

	if atomic.LoadUint32(&p.deleted) == 1 {
		return "", fmt.Errorf("key %q not found", name)
	}

	backup, err := p.Backup(ctx, storage)
	if err != nil {
		return "", err
	}

	return backup, nil
}

// GetPolicy returns the policy for the given request. When the function
// returns, if caching was disabled, the returned Policy's lock is held and
// must be unlocked by the caller when it is done (and it must not be
// re-locked).
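//
// A minimal within-package sketch of the caller-side contract (ctx, req, and
// rand are supplied by the caller; the Policy's l field is unexported, so
// external callers would need an exported unlock helper):
//
//	p, _, err := lm.GetPolicy(ctx, req, rand)
//	if err != nil {
//		return err
//	}
//	if p == nil {
//		return fmt.Errorf("key %q not found", req.Name)
//	}
//	if !lm.GetUseCache() {
//		defer p.l.Unlock() // GetPolicy handed us the write lock
//	}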
func (lm *LockManager) GetPolicy(ctx context.Context, req PolicyRequest, rand io.Reader) (retP *Policy, retUpserted bool, retErr error) {
	var p *Policy
	var err error
	var ok bool
	var pRaw interface{}

	// Check if it's in our cache. If so, return right away.
	if lm.useCache {
		pRaw, ok = lm.cache.Load(req.Name)
	}
	if ok {
		p = pRaw.(*Policy)
		if atomic.LoadUint32(&p.deleted) == 1 {
			return nil, false, nil
		}
		return p, false, nil
	}

	// We're not using the cache, or it wasn't found; get an exclusive lock.
	// This ensures that any other process writing the actual storage will be
	// finished before we load from storage.
	lock := locksutil.LockForKey(lm.keyLocks, req.Name)
	lock.Lock()

	// If we are using the cache, cleanup unlocks the storage lock before we
	// return; otherwise we may return from here with the lock still held.
	cleanup := func() {
		switch {
		// If using the cache we always unlock; the caller locks the policy
		// themselves
		case lm.useCache:
			lock.Unlock()

		// If not using the cache and we aren't returning a policy, the caller
		// doesn't hold a lock, so we must unlock
		case retP == nil:
			lock.Unlock()
		}
	}

	// Check the cache again
	if lm.useCache {
		pRaw, ok = lm.cache.Load(req.Name)
	}
	if ok {
		p = pRaw.(*Policy)
		if atomic.LoadUint32(&p.deleted) == 1 {
			cleanup()
			return nil, false, nil
		}
		retP = p
		cleanup()
		return
	}

	// Load it from storage
	p, err = lm.getPolicyFromStorage(ctx, req.Storage, req.Name)
	if err != nil {
		cleanup()
		return nil, false, err
	}
	// We don't need to lock the policy as there would be no other holders of
	// the pointer

	if p == nil {
		// This is the only place we upsert a new policy, so if upsert is not
		// specified, unlock before returning
		if !req.Upsert {
			cleanup()
			return nil, false, nil
		}

		// We create the policy here and only store it once it is fully set
		// up. The exclusive lock above, together with the cache re-check
		// after acquiring it, ensures no other goroutine can be creating the
		// same policy concurrently.

		switch req.KeyType {
		case KeyType_AES128_GCM96, KeyType_AES256_GCM96, KeyType_ChaCha20_Poly1305:
			if req.Convergent && !req.Derived {
				cleanup()
				return nil, false, fmt.Errorf("convergent encryption requires derivation to be enabled")
			}

		case KeyType_ECDSA_P256, KeyType_ECDSA_P384, KeyType_ECDSA_P521:
			if req.Derived || req.Convergent {
				cleanup()
				return nil, false, fmt.Errorf("key derivation and convergent encryption not supported for keys of type %v", req.KeyType)
			}

		case KeyType_ED25519:
			if req.Convergent {
				cleanup()
				return nil, false, fmt.Errorf("convergent encryption not supported for keys of type %v", req.KeyType)
			}

		case KeyType_RSA2048, KeyType_RSA3072, KeyType_RSA4096:
			if req.Derived || req.Convergent {
				cleanup()
				return nil, false, fmt.Errorf("key derivation and convergent encryption not supported for keys of type %v", req.KeyType)
			}

		default:
			cleanup()
			return nil, false, fmt.Errorf("unsupported key type %v", req.KeyType)
		}

		p = &Policy{
			l:                    new(sync.RWMutex),
			Name:                 req.Name,
			Type:                 req.KeyType,
			Derived:              req.Derived,
			Exportable:           req.Exportable,
			AllowPlaintextBackup: req.AllowPlaintextBackup,
		}

		if req.Derived {
			p.KDF = Kdf_hkdf_sha256
			if req.Convergent {
				p.ConvergentEncryption = true
				// As of version 3 we store the version within each key, so we
				// set to -1 to indicate that the value in the policy has no
				// meaning. We still, for backwards compatibility, fall back to
				// this value if the key doesn't have one, which means it will
				// only be -1 in the case where every key version is >= 3
				p.ConvergentVersion = -1
			}
		}

		// Performs the actual persist and does setup
		err = p.Rotate(ctx, req.Storage, rand)
		if err != nil {
			cleanup()
			return nil, false, err
		}

		if lm.useCache {
			lm.cache.Store(req.Name, p)
		} else {
			p.l = &lock.RWMutex
			p.writeLocked = true
		}

		// We don't need to worry about upgrading since it will be a new policy
		retP = p
		retUpserted = true
		cleanup()
		return
	}

	if p.NeedsUpgrade() {
		if err := p.Upgrade(ctx, req.Storage, rand); err != nil {
			cleanup()
			return nil, false, err
		}
	}

	if lm.useCache {
		lm.cache.Store(req.Name, p)
	} else {
		p.l = &lock.RWMutex
		p.writeLocked = true
	}

	retP = p
	cleanup()
	return
}

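// DeletePolicy removes the named policy and its archive from storage and
// invalidates any cached copy, provided the policy's DeletionAllowed flag is
// set. A minimal usage sketch:
//
//	if err := lm.DeletePolicy(ctx, storage, "my-key"); err != nil {
//		return err
//	}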
func (lm *LockManager) DeletePolicy(ctx context.Context, storage logical.Storage, name string) error {
	var p *Policy
	var err error
	var ok bool
	var pRaw interface{}

	// We may be writing to disk, so grab an exclusive lock. This prevents bad
	// behavior when the cache is turned off. We also lock the shared policy
	// object to make sure no requests are in flight.
	lock := locksutil.LockForKey(lm.keyLocks, name)
	lock.Lock()
	defer lock.Unlock()

	if lm.useCache {
		pRaw, ok = lm.cache.Load(name)
	}
	if ok {
		p = pRaw.(*Policy)
		p.l.Lock()
		defer p.l.Unlock()
	}

	if p == nil {
		p, err = lm.getPolicyFromStorage(ctx, storage, name)
		if err != nil {
			return err
		}
		if p == nil {
			return fmt.Errorf("could not delete key; not found")
		}
	}

	if !p.DeletionAllowed {
		return fmt.Errorf("deletion is not allowed for this key")
	}

	atomic.StoreUint32(&p.deleted, 1)

	if lm.useCache {
		lm.cache.Delete(name)
	}

	err = storage.Delete(ctx, "policy/"+name)
	if err != nil {
		return errwrap.Wrapf(fmt.Sprintf("error deleting key %q: {{err}}", name), err)
	}

	err = storage.Delete(ctx, "archive/"+name)
	if err != nil {
		return errwrap.Wrapf(fmt.Sprintf("error deleting key %q archive: {{err}}", name), err)
	}

	return nil
}

func (lm *LockManager) getPolicyFromStorage(ctx context.Context, storage logical.Storage, name string) (*Policy, error) {
	return LoadPolicy(ctx, storage, "policy/"+name)
}