1package keysutil
2
3import (
4	"context"
5	"encoding/base64"
6	"errors"
7	"fmt"
8	"io"
9	"sync"
10	"sync/atomic"
11	"time"
12
13	"github.com/hashicorp/errwrap"
14	"github.com/hashicorp/vault/sdk/helper/jsonutil"
15	"github.com/hashicorp/vault/sdk/helper/locksutil"
16	"github.com/hashicorp/vault/sdk/logical"
17)
18
// Lock-mode flags: a boolean where true means the lock is exclusive and false
// means it is shared.
const (
	exclusive = true
	shared    = false
)

// currentConvergentVersion is the convergent-encryption version constant for
// this package (presumably the newest supported version — confirm at use sites).
const currentConvergentVersion = 3

// errNeedExclusiveLock is the error for operations that require an exclusive
// lock.
var errNeedExclusiveLock = errors.New("an exclusive lock is needed for this operation")
28
// PolicyRequest holds values used when requesting a policy. Most values are
// only used during an upsert.
//
// LockManager.GetPolicy always uses Storage and Name. KeyType, Derived,
// Convergent, Exportable and AllowPlaintextBackup are only consulted there
// when the policy is not found in the cache or storage and Upsert is true, in
// which case they describe the policy to create.
type PolicyRequest struct {
	// The storage to load the policy from (and to persist it to on upsert)
	Storage logical.Storage

	// The name of the policy; used as the cache key, the per-name lock key,
	// and the suffix of the "policy/" storage path
	Name string

	// The key type of a newly created policy
	KeyType KeyType

	// Whether it should be derived; a derived policy is created with
	// Kdf_hkdf_sha256 as its KDF
	Derived bool

	// Whether to enable convergent encryption. For AES and ChaCha20 key types
	// this requires Derived; ECDSA, ED25519 and RSA key types reject it.
	Convergent bool

	// Whether to allow export
	Exportable bool

	// Whether to create the policy if it does not already exist
	Upsert bool

	// Whether to allow plaintext backup
	AllowPlaintextBackup bool
}
56
// LockManager coordinates access to named policies. Operations on a given
// policy name are serialized through a per-name lock, and, when useCache is
// set, loaded policies are kept in an in-memory cache keyed by name.
type LockManager struct {
	// useCache reports whether cache is used; NewLockManager leaves cache nil
	// when it is false.
	useCache bool
	// cache maps policy names to *Policy values (only when useCache is true).
	cache    Cache
	// keyLocks is the lock set from which per-name locks are selected with
	// locksutil.LockForKey.
	keyLocks []*locksutil.LockEntry
}
62
63func NewLockManager(useCache bool, cacheSize int) (*LockManager, error) {
64	// determine the type of cache to create
65	var cache Cache
66	switch {
67	case !useCache:
68	case cacheSize < 0:
69		return nil, errors.New("cache size must be greater or equal to zero")
70	case cacheSize == 0:
71		cache = NewTransitSyncMap()
72	case cacheSize > 0:
73		newLRUCache, err := NewTransitLRU(cacheSize)
74		if err != nil {
75			return nil, errwrap.Wrapf("failed to create cache: {{err}}", err)
76		}
77		cache = newLRUCache
78	}
79
80	lm := &LockManager{
81		useCache: useCache,
82		cache:    cache,
83		keyLocks: locksutil.CreateLocks(),
84	}
85
86	return lm, nil
87}
88
89func (lm *LockManager) GetCacheSize() int {
90	if !lm.useCache {
91		return 0
92	}
93	return lm.cache.Size()
94}
95
96func (lm *LockManager) GetUseCache() bool {
97	return lm.useCache
98}
99
100func (lm *LockManager) InvalidatePolicy(name string) {
101	if lm.useCache {
102		lm.cache.Delete(name)
103	}
104}
105
106// RestorePolicy acquires an exclusive lock on the policy name and restores the
107// given policy along with the archive.
108func (lm *LockManager) RestorePolicy(ctx context.Context, storage logical.Storage, name, backup string, force bool) error {
109	backupBytes, err := base64.StdEncoding.DecodeString(backup)
110	if err != nil {
111		return err
112	}
113
114	var keyData KeyData
115	err = jsonutil.DecodeJSON(backupBytes, &keyData)
116	if err != nil {
117		return err
118	}
119
120	// Set a different name if desired
121	if name != "" {
122		keyData.Policy.Name = name
123	}
124
125	name = keyData.Policy.Name
126
127	// Grab the exclusive lock as we'll be modifying disk
128	lock := locksutil.LockForKey(lm.keyLocks, name)
129	lock.Lock()
130	defer lock.Unlock()
131
132	var ok bool
133	var pRaw interface{}
134
135	// If the policy is in cache and 'force' is not specified, error out. Anywhere
136	// that would put it in the cache will also be protected by the mutex above,
137	// so we don't need to re-check the cache later.
138	if lm.useCache {
139		pRaw, ok = lm.cache.Load(name)
140		if ok && !force {
141			return fmt.Errorf("key %q already exists", name)
142		}
143	}
144
145	// Conditionally look up the policy from storage, depending on the use of
146	// 'force' and if the policy was found in cache.
147	//
148	// - If was not found in cache and we are not using 'force', look for it in
149	// storage. If found, error out.
150	//
151	// - If it was found in cache and we are using 'force', pRaw will not be nil
152	// and we do not look the policy up from storage
153	//
154	// - If it was found in cache and we are not using 'force', we should have
155	// returned above with error
156	var p *Policy
157	if pRaw == nil {
158		p, err = lm.getPolicyFromStorage(ctx, storage, name)
159		if err != nil {
160			return err
161		}
162		if p != nil && !force {
163			return fmt.Errorf("key %q already exists", name)
164		}
165	}
166
167	// If both pRaw and p above are nil and 'force' is specified, we don't need to
168	// grab policy locks as we have ensured it doesn't already exist, so there
169	// will be no races as nothing else has this pointer. If 'force' was not used,
170	// an error would have been returned by now if the policy already existed
171	if pRaw != nil {
172		p = pRaw.(*Policy)
173	}
174	if p != nil {
175		p.l.Lock()
176		defer p.l.Unlock()
177	}
178
179	// Restore the archived keys
180	if keyData.ArchivedKeys != nil {
181		err = keyData.Policy.storeArchive(ctx, storage, keyData.ArchivedKeys)
182		if err != nil {
183			return errwrap.Wrapf(fmt.Sprintf("failed to restore archived keys for key %q: {{err}}", name), err)
184		}
185	}
186
187	// Mark that policy as a restored key
188	keyData.Policy.RestoreInfo = &RestoreInfo{
189		Time:    time.Now(),
190		Version: keyData.Policy.LatestVersion,
191	}
192
193	// Restore the policy. This will also attempt to adjust the archive.
194	err = keyData.Policy.Persist(ctx, storage)
195	if err != nil {
196		return errwrap.Wrapf(fmt.Sprintf("failed to restore the policy %q: {{err}}", name), err)
197	}
198
199	keyData.Policy.l = new(sync.RWMutex)
200
201	// Update the cache to contain the restored policy
202	if lm.useCache {
203		lm.cache.Store(name, keyData.Policy)
204	}
205	return nil
206}
207
208func (lm *LockManager) BackupPolicy(ctx context.Context, storage logical.Storage, name string) (string, error) {
209	var p *Policy
210	var err error
211
212	// Backup writes information about when the backup took place, so we get an
213	// exclusive lock here
214	lock := locksutil.LockForKey(lm.keyLocks, name)
215	lock.Lock()
216	defer lock.Unlock()
217
218	var ok bool
219	var pRaw interface{}
220
221	if lm.useCache {
222		pRaw, ok = lm.cache.Load(name)
223	}
224	if ok {
225		p = pRaw.(*Policy)
226		p.l.Lock()
227		defer p.l.Unlock()
228	} else {
229		// If the policy doesn't exit in storage, error out
230		p, err = lm.getPolicyFromStorage(ctx, storage, name)
231		if err != nil {
232			return "", err
233		}
234		if p == nil {
235			return "", fmt.Errorf(fmt.Sprintf("key %q not found", name))
236		}
237	}
238
239	if atomic.LoadUint32(&p.deleted) == 1 {
240		return "", fmt.Errorf(fmt.Sprintf("key %q not found", name))
241	}
242
243	backup, err := p.Backup(ctx, storage)
244	if err != nil {
245		return "", err
246	}
247
248	return backup, nil
249}
250
251// When the function returns, if caching was disabled, the Policy's lock must
252// be unlocked when the caller is done (and it should not be re-locked).
253func (lm *LockManager) GetPolicy(ctx context.Context, req PolicyRequest, rand io.Reader) (retP *Policy, retUpserted bool, retErr error) {
254	var p *Policy
255	var err error
256	var ok bool
257	var pRaw interface{}
258
259	// Check if it's in our cache. If so, return right away.
260	if lm.useCache {
261		pRaw, ok = lm.cache.Load(req.Name)
262	}
263	if ok {
264		p = pRaw.(*Policy)
265		if atomic.LoadUint32(&p.deleted) == 1 {
266			return nil, false, nil
267		}
268		return p, false, nil
269	}
270
271	// We're not using the cache, or it wasn't found; get an exclusive lock.
272	// This ensures that any other process writing the actual storage will be
273	// finished before we load from storage.
274	lock := locksutil.LockForKey(lm.keyLocks, req.Name)
275	lock.Lock()
276
277	// If we are using the cache, defer the lock unlock; otherwise we will
278	// return from here with the lock still held.
279	cleanup := func() {
280		switch {
281		// If using the cache we always unlock, the caller locks the policy
282		// themselves
283		case lm.useCache:
284			lock.Unlock()
285
286		// If not using the cache, if we aren't returning a policy the caller
287		// doesn't have a lock, so we must unlock
288		case retP == nil:
289			lock.Unlock()
290		}
291	}
292
293	// Check the cache again
294	if lm.useCache {
295		pRaw, ok = lm.cache.Load(req.Name)
296	}
297	if ok {
298		p = pRaw.(*Policy)
299		if atomic.LoadUint32(&p.deleted) == 1 {
300			cleanup()
301			return nil, false, nil
302		}
303		retP = p
304		cleanup()
305		return
306	}
307
308	// Load it from storage
309	p, err = lm.getPolicyFromStorage(ctx, req.Storage, req.Name)
310	if err != nil {
311		cleanup()
312		return nil, false, err
313	}
314	// We don't need to lock the policy as there would be no other holders of
315	// the pointer
316
317	if p == nil {
318		// This is the only place we upsert a new policy, so if upsert is not
319		// specified, or the lock type is wrong, unlock before returning
320		if !req.Upsert {
321			cleanup()
322			return nil, false, nil
323		}
324
325		// We create the policy here, then at the end we do a LoadOrStore. If
326		// it's been loaded since we last checked the cache, we return an error
327		// to the user to let them know that their request can't be satisfied
328		// because we don't know if the parameters match.
329
330		switch req.KeyType {
331		case KeyType_AES128_GCM96, KeyType_AES256_GCM96, KeyType_ChaCha20_Poly1305:
332			if req.Convergent && !req.Derived {
333				cleanup()
334				return nil, false, fmt.Errorf("convergent encryption requires derivation to be enabled")
335			}
336
337		case KeyType_ECDSA_P256, KeyType_ECDSA_P384, KeyType_ECDSA_P521:
338			if req.Derived || req.Convergent {
339				cleanup()
340				return nil, false, fmt.Errorf("key derivation and convergent encryption not supported for keys of type %v", req.KeyType)
341			}
342
343		case KeyType_ED25519:
344			if req.Convergent {
345				cleanup()
346				return nil, false, fmt.Errorf("convergent encryption not supported for keys of type %v", req.KeyType)
347			}
348
349		case KeyType_RSA2048, KeyType_RSA4096:
350			if req.Derived || req.Convergent {
351				cleanup()
352				return nil, false, fmt.Errorf("key derivation and convergent encryption not supported for keys of type %v", req.KeyType)
353			}
354
355		default:
356			cleanup()
357			return nil, false, fmt.Errorf("unsupported key type %v", req.KeyType)
358		}
359
360		p = &Policy{
361			l:                    new(sync.RWMutex),
362			Name:                 req.Name,
363			Type:                 req.KeyType,
364			Derived:              req.Derived,
365			Exportable:           req.Exportable,
366			AllowPlaintextBackup: req.AllowPlaintextBackup,
367		}
368
369		if req.Derived {
370			p.KDF = Kdf_hkdf_sha256
371			if req.Convergent {
372				p.ConvergentEncryption = true
373				// As of version 3 we store the version within each key, so we
374				// set to -1 to indicate that the value in the policy has no
375				// meaning. We still, for backwards compatibility, fall back to
376				// this value if the key doesn't have one, which means it will
377				// only be -1 in the case where every key version is >= 3
378				p.ConvergentVersion = -1
379			}
380		}
381
382		// Performs the actual persist and does setup
383		err = p.Rotate(ctx, req.Storage, rand)
384		if err != nil {
385			cleanup()
386			return nil, false, err
387		}
388
389		if lm.useCache {
390			lm.cache.Store(req.Name, p)
391		} else {
392			p.l = &lock.RWMutex
393			p.writeLocked = true
394		}
395
396		// We don't need to worry about upgrading since it will be a new policy
397		retP = p
398		retUpserted = true
399		cleanup()
400		return
401	}
402
403	if p.NeedsUpgrade() {
404		if err := p.Upgrade(ctx, req.Storage, rand); err != nil {
405			cleanup()
406			return nil, false, err
407		}
408	}
409
410	if lm.useCache {
411		lm.cache.Store(req.Name, p)
412	} else {
413		p.l = &lock.RWMutex
414		p.writeLocked = true
415	}
416
417	retP = p
418	cleanup()
419	return
420}
421
422func (lm *LockManager) DeletePolicy(ctx context.Context, storage logical.Storage, name string) error {
423	var p *Policy
424	var err error
425	var ok bool
426	var pRaw interface{}
427
428	// We may be writing to disk, so grab an exclusive lock. This prevents bad
429	// behavior when the cache is turned off. We also lock the shared policy
430	// object to make sure no requests are in flight.
431	lock := locksutil.LockForKey(lm.keyLocks, name)
432	lock.Lock()
433	defer lock.Unlock()
434
435	if lm.useCache {
436		pRaw, ok = lm.cache.Load(name)
437	}
438	if ok {
439		p = pRaw.(*Policy)
440		p.l.Lock()
441		defer p.l.Unlock()
442	}
443
444	if p == nil {
445		p, err = lm.getPolicyFromStorage(ctx, storage, name)
446		if err != nil {
447			return err
448		}
449		if p == nil {
450			return fmt.Errorf("could not delete key; not found")
451		}
452	}
453
454	if !p.DeletionAllowed {
455		return fmt.Errorf("deletion is not allowed for this key")
456	}
457
458	atomic.StoreUint32(&p.deleted, 1)
459
460	if lm.useCache {
461		lm.cache.Delete(name)
462	}
463
464	err = storage.Delete(ctx, "policy/"+name)
465	if err != nil {
466		return errwrap.Wrapf(fmt.Sprintf("error deleting key %q: {{err}}", name), err)
467	}
468
469	err = storage.Delete(ctx, "archive/"+name)
470	if err != nil {
471		return errwrap.Wrapf(fmt.Sprintf("error deleting key %q archive: {{err}}", name), err)
472	}
473
474	return nil
475}
476
477func (lm *LockManager) getPolicyFromStorage(ctx context.Context, storage logical.Storage, name string) (*Policy, error) {
478	return LoadPolicy(ctx, storage, "policy/"+name)
479}
480