1package keysutil
2
3import (
4	"context"
5	"crypto/rand"
6	"reflect"
7	"strconv"
8	"sync"
9	"testing"
10	"time"
11
12	"github.com/hashicorp/vault/sdk/helper/jsonutil"
13	"github.com/hashicorp/vault/sdk/logical"
14	"github.com/mitchellh/copystructure"
15)
16
17func TestPolicy_KeyEntryMapUpgrade(t *testing.T) {
18	now := time.Now()
19	old := map[int]KeyEntry{
20		1: {
21			Key:                []byte("samplekey"),
22			HMACKey:            []byte("samplehmackey"),
23			CreationTime:       now,
24			FormattedPublicKey: "sampleformattedpublickey",
25		},
26		2: {
27			Key:                []byte("samplekey2"),
28			HMACKey:            []byte("samplehmackey2"),
29			CreationTime:       now.Add(10 * time.Second),
30			FormattedPublicKey: "sampleformattedpublickey2",
31		},
32	}
33
34	oldEncoded, err := jsonutil.EncodeJSON(old)
35	if err != nil {
36		t.Fatal(err)
37	}
38
39	var new keyEntryMap
40	err = jsonutil.DecodeJSON(oldEncoded, &new)
41	if err != nil {
42		t.Fatal(err)
43	}
44
45	newEncoded, err := jsonutil.EncodeJSON(&new)
46	if err != nil {
47		t.Fatal(err)
48	}
49
50	if string(oldEncoded) != string(newEncoded) {
51		t.Fatalf("failed to upgrade key entry map;\nold: %q\nnew: %q", string(oldEncoded), string(newEncoded))
52	}
53}
54
55func Test_KeyUpgrade(t *testing.T) {
56	lockManagerWithCache, _ := NewLockManager(true, 0)
57	lockManagerWithoutCache, _ := NewLockManager(false, 0)
58	testKeyUpgradeCommon(t, lockManagerWithCache)
59	testKeyUpgradeCommon(t, lockManagerWithoutCache)
60}
61
62func testKeyUpgradeCommon(t *testing.T, lm *LockManager) {
63	ctx := context.Background()
64
65	storage := &logical.InmemStorage{}
66	p, upserted, err := lm.GetPolicy(ctx, PolicyRequest{
67		Upsert:  true,
68		Storage: storage,
69		KeyType: KeyType_AES256_GCM96,
70		Name:    "test",
71	}, rand.Reader)
72	if err != nil {
73		t.Fatal(err)
74	}
75	if p == nil {
76		t.Fatal("nil policy")
77	}
78	if !upserted {
79		t.Fatal("expected an upsert")
80	}
81	if !lm.useCache {
82		p.Unlock()
83	}
84
85	testBytes := make([]byte, len(p.Keys["1"].Key))
86	copy(testBytes, p.Keys["1"].Key)
87
88	p.Key = p.Keys["1"].Key
89	p.Keys = nil
90	p.MigrateKeyToKeysMap()
91	if p.Key != nil {
92		t.Fatal("policy.Key is not nil")
93	}
94	if len(p.Keys) != 1 {
95		t.Fatal("policy.Keys is the wrong size")
96	}
97	if !reflect.DeepEqual(testBytes, p.Keys["1"].Key) {
98		t.Fatal("key mismatch")
99	}
100}
101
102func Test_ArchivingUpgrade(t *testing.T) {
103	lockManagerWithCache, _ := NewLockManager(true, 0)
104	lockManagerWithoutCache, _ := NewLockManager(false, 0)
105	testArchivingUpgradeCommon(t, lockManagerWithCache)
106	testArchivingUpgradeCommon(t, lockManagerWithoutCache)
107}
108
109func testArchivingUpgradeCommon(t *testing.T, lm *LockManager) {
110	ctx := context.Background()
111
112	// First, we generate a policy and rotate it a number of times. Each time
113	// we'll ensure that we have the expected number of keys in the archive and
114	// the main keys object, which without changing the min version should be
115	// zero and latest, respectively
116
117	storage := &logical.InmemStorage{}
118	p, _, err := lm.GetPolicy(ctx, PolicyRequest{
119		Upsert:  true,
120		Storage: storage,
121		KeyType: KeyType_AES256_GCM96,
122		Name:    "test",
123	}, rand.Reader)
124	if err != nil {
125		t.Fatal(err)
126	}
127	if p == nil {
128		t.Fatal("nil policy")
129	}
130	if !lm.useCache {
131		p.Unlock()
132	}
133
134	// Store the initial key in the archive
135	keysArchive := []KeyEntry{KeyEntry{}, p.Keys["1"]}
136	checkKeys(t, ctx, p, storage, keysArchive, "initial", 1, 1, 1)
137
138	for i := 2; i <= 10; i++ {
139		err = p.Rotate(ctx, storage, rand.Reader)
140		if err != nil {
141			t.Fatal(err)
142		}
143		keysArchive = append(keysArchive, p.Keys[strconv.Itoa(i)])
144		checkKeys(t, ctx, p, storage, keysArchive, "rotate", i, i, i)
145	}
146
147	// Now, wipe the archive and set the archive version to zero
148	err = storage.Delete(ctx, "archive/test")
149	if err != nil {
150		t.Fatal(err)
151	}
152	p.ArchiveVersion = 0
153
154	// Store it, but without calling persist, so we don't trigger
155	// handleArchiving()
156	buf, err := p.Serialize()
157	if err != nil {
158		t.Fatal(err)
159	}
160
161	// Write the policy into storage
162	err = storage.Put(ctx, &logical.StorageEntry{
163		Key:   "policy/" + p.Name,
164		Value: buf,
165	})
166	if err != nil {
167		t.Fatal(err)
168	}
169
170	// If we're caching, expire from the cache since we modified it
171	// under-the-hood
172	if lm.useCache {
173		lm.cache.Delete("test")
174	}
175
176	// Now get the policy again; the upgrade should happen automatically
177	p, _, err = lm.GetPolicy(ctx, PolicyRequest{
178		Storage: storage,
179		Name:    "test",
180	}, rand.Reader)
181	if err != nil {
182		t.Fatal(err)
183	}
184	if p == nil {
185		t.Fatal("nil policy")
186	}
187	if !lm.useCache {
188		p.Unlock()
189	}
190
191	checkKeys(t, ctx, p, storage, keysArchive, "upgrade", 10, 10, 10)
192
193	// Let's check some deletion logic while we're at it
194
195	// The policy should be in there
196	if lm.useCache {
197		_, ok := lm.cache.Load("test")
198		if !ok {
199			t.Fatal("nil policy in cache")
200		}
201	}
202
203	// First we'll do this wrong, by not setting the deletion flag
204	err = lm.DeletePolicy(ctx, storage, "test")
205	if err == nil {
206		t.Fatal("got nil error, but should not have been able to delete since we didn't set the deletion flag on the policy")
207	}
208
209	// The policy should still be in there
210	if lm.useCache {
211		_, ok := lm.cache.Load("test")
212		if !ok {
213			t.Fatal("nil policy in cache")
214		}
215	}
216
217	p, _, err = lm.GetPolicy(ctx, PolicyRequest{
218		Storage: storage,
219		Name:    "test",
220	}, rand.Reader)
221	if err != nil {
222		t.Fatal(err)
223	}
224	if p == nil {
225		t.Fatal("policy nil after bad delete")
226	}
227	if !lm.useCache {
228		p.Unlock()
229	}
230
231	// Now do it properly
232	p.DeletionAllowed = true
233	err = p.Persist(ctx, storage)
234	if err != nil {
235		t.Fatal(err)
236	}
237	err = lm.DeletePolicy(ctx, storage, "test")
238	if err != nil {
239		t.Fatal(err)
240	}
241
242	// The policy should *not* be in there
243	if lm.useCache {
244		_, ok := lm.cache.Load("test")
245		if ok {
246			t.Fatal("non-nil policy in cache")
247		}
248	}
249
250	p, _, err = lm.GetPolicy(ctx, PolicyRequest{
251		Storage: storage,
252		Name:    "test",
253	}, rand.Reader)
254	if err != nil {
255		t.Fatal(err)
256	}
257	if p != nil {
258		t.Fatal("policy not nil after delete")
259	}
260}
261
262func Test_Archiving(t *testing.T) {
263	lockManagerWithCache, _ := NewLockManager(true, 0)
264	lockManagerWithoutCache, _ := NewLockManager(false, 0)
265	testArchivingUpgradeCommon(t, lockManagerWithCache)
266	testArchivingUpgradeCommon(t, lockManagerWithoutCache)
267}
268
269func testArchivingCommon(t *testing.T, lm *LockManager) {
270	ctx := context.Background()
271
272	// First, we generate a policy and rotate it a number of times. Each time
273	// we'll ensure that we have the expected number of keys in the archive and
274	// the main keys object, which without changing the min version should be
275	// zero and latest, respectively
276
277	storage := &logical.InmemStorage{}
278	p, _, err := lm.GetPolicy(ctx, PolicyRequest{
279		Upsert:  true,
280		Storage: storage,
281		KeyType: KeyType_AES256_GCM96,
282		Name:    "test",
283	}, rand.Reader)
284	if err != nil {
285		t.Fatal(err)
286	}
287	if p == nil {
288		t.Fatal("nil policy")
289	}
290	if !lm.useCache {
291		p.Unlock()
292	}
293
294	// Store the initial key in the archive
295	keysArchive := []KeyEntry{KeyEntry{}, p.Keys["1"]}
296	checkKeys(t, ctx, p, storage, keysArchive, "initial", 1, 1, 1)
297
298	for i := 2; i <= 10; i++ {
299		err = p.Rotate(ctx, storage, rand.Reader)
300		if err != nil {
301			t.Fatal(err)
302		}
303		keysArchive = append(keysArchive, p.Keys[strconv.Itoa(i)])
304		checkKeys(t, ctx, p, storage, keysArchive, "rotate", i, i, i)
305	}
306
307	// Move the min decryption version up
308	for i := 1; i <= 10; i++ {
309		p.MinDecryptionVersion = i
310
311		err = p.Persist(ctx, storage)
312		if err != nil {
313			t.Fatal(err)
314		}
315		// We expect to find:
316		// * The keys in archive are the same as the latest version
317		// * The latest version is constant
318		// * The number of keys in the policy itself is from the min
319		// decryption version up to the latest version, so for e.g. 7 and
320		// 10, you'd need 7, 8, 9, and 10 -- IOW, latest version - min
321		// decryption version plus 1 (the min decryption version key
322		// itself)
323		checkKeys(t, ctx, p, storage, keysArchive, "minadd", 10, 10, p.LatestVersion-p.MinDecryptionVersion+1)
324	}
325
326	// Move the min decryption version down
327	for i := 10; i >= 1; i-- {
328		p.MinDecryptionVersion = i
329
330		err = p.Persist(ctx, storage)
331		if err != nil {
332			t.Fatal(err)
333		}
334		// We expect to find:
335		// * The keys in archive are never removed so same as the latest version
336		// * The latest version is constant
337		// * The number of keys in the policy itself is from the min
338		// decryption version up to the latest version, so for e.g. 7 and
339		// 10, you'd need 7, 8, 9, and 10 -- IOW, latest version - min
340		// decryption version plus 1 (the min decryption version key
341		// itself)
342		checkKeys(t, ctx, p, storage, keysArchive, "minsub", 10, 10, p.LatestVersion-p.MinDecryptionVersion+1)
343	}
344}
345
346func checkKeys(t *testing.T,
347	ctx context.Context,
348	p *Policy,
349	storage logical.Storage,
350	keysArchive []KeyEntry,
351	action string,
352	archiveVer, latestVer, keysSize int) {
353
354	// Sanity check
355	if len(keysArchive) != latestVer+1 {
356		t.Fatalf("latest expected key version is %d, expected test keys archive size is %d, "+
357			"but keys archive is of size %d", latestVer, latestVer+1, len(keysArchive))
358	}
359
360	archive, err := p.LoadArchive(ctx, storage)
361	if err != nil {
362		t.Fatal(err)
363	}
364
365	badArchiveVer := false
366	if archiveVer == 0 {
367		if len(archive.Keys) != 0 || p.ArchiveVersion != 0 {
368			badArchiveVer = true
369		}
370	} else {
371		// We need to subtract one because we have the indexes match key
372		// versions, which start at 1. So for an archive version of 1, we
373		// actually have two entries -- a blank 0 entry, and the key at spot 1
374		if archiveVer != len(archive.Keys)-1 || archiveVer != p.ArchiveVersion {
375			badArchiveVer = true
376		}
377	}
378	if badArchiveVer {
379		t.Fatalf(
380			"expected archive version %d, found length of archive keys %d and policy archive version %d",
381			archiveVer, len(archive.Keys), p.ArchiveVersion,
382		)
383	}
384
385	if latestVer != p.LatestVersion {
386		t.Fatalf(
387			"expected latest version %d, found %d",
388			latestVer, p.LatestVersion,
389		)
390	}
391
392	if keysSize != len(p.Keys) {
393		t.Fatalf(
394			"expected keys size %d, found %d, action is %s, policy is \n%#v\n",
395			keysSize, len(p.Keys), action, p,
396		)
397	}
398
399	for i := p.MinDecryptionVersion; i <= p.LatestVersion; i++ {
400		if _, ok := p.Keys[strconv.Itoa(i)]; !ok {
401			t.Fatalf(
402				"expected key %d, did not find it in policy keys", i,
403			)
404		}
405	}
406
407	for i := p.MinDecryptionVersion; i <= p.LatestVersion; i++ {
408		ver := strconv.Itoa(i)
409		if !p.Keys[ver].CreationTime.Equal(keysArchive[i].CreationTime) {
410			t.Fatalf("key %d not equivalent between policy keys and test keys archive; policy keys:\n%#v\ntest keys archive:\n%#v\n", i, p.Keys[ver], keysArchive[i])
411		}
412		polKey := p.Keys[ver]
413		polKey.CreationTime = keysArchive[i].CreationTime
414		p.Keys[ver] = polKey
415		if !reflect.DeepEqual(p.Keys[ver], keysArchive[i]) {
416			t.Fatalf("key %d not equivalent between policy keys and test keys archive; policy keys:\n%#v\ntest keys archive:\n%#v\n", i, p.Keys[ver], keysArchive[i])
417		}
418	}
419
420	for i := 1; i < len(archive.Keys); i++ {
421		if !reflect.DeepEqual(archive.Keys[i].Key, keysArchive[i].Key) {
422			t.Fatalf("key %d not equivalent between policy archive and test keys archive; policy archive:\n%#v\ntest keys archive:\n%#v\n", i, archive.Keys[i].Key, keysArchive[i].Key)
423		}
424	}
425}
426
427func Test_StorageErrorSafety(t *testing.T) {
428	ctx := context.Background()
429	lm, _ := NewLockManager(true, 0)
430
431	storage := &logical.InmemStorage{}
432	p, _, err := lm.GetPolicy(ctx, PolicyRequest{
433		Upsert:  true,
434		Storage: storage,
435		KeyType: KeyType_AES256_GCM96,
436		Name:    "test",
437	}, rand.Reader)
438	if err != nil {
439		t.Fatal(err)
440	}
441	if p == nil {
442		t.Fatal("nil policy")
443	}
444
445	// Store the initial key in the archive
446	keysArchive := []KeyEntry{KeyEntry{}, p.Keys["1"]}
447	checkKeys(t, ctx, p, storage, keysArchive, "initial", 1, 1, 1)
448
449	// We use checkKeys here just for sanity; it doesn't really handle cases of
450	// errors below so we do more targeted testing later
451	for i := 2; i <= 5; i++ {
452		err = p.Rotate(ctx, storage, rand.Reader)
453		if err != nil {
454			t.Fatal(err)
455		}
456		keysArchive = append(keysArchive, p.Keys[strconv.Itoa(i)])
457		checkKeys(t, ctx, p, storage, keysArchive, "rotate", i, i, i)
458	}
459
460	underlying := storage.Underlying()
461	underlying.FailPut(true)
462
463	priorLen := len(p.Keys)
464
465	err = p.Rotate(ctx, storage, rand.Reader)
466	if err == nil {
467		t.Fatal("expected error")
468	}
469
470	if len(p.Keys) != priorLen {
471		t.Fatal("length of keys should not have changed")
472	}
473}
474
475func Test_BadUpgrade(t *testing.T) {
476	ctx := context.Background()
477	lm, _ := NewLockManager(true, 0)
478	storage := &logical.InmemStorage{}
479	p, _, err := lm.GetPolicy(ctx, PolicyRequest{
480		Upsert:  true,
481		Storage: storage,
482		KeyType: KeyType_AES256_GCM96,
483		Name:    "test",
484	}, rand.Reader)
485	if err != nil {
486		t.Fatal(err)
487	}
488	if p == nil {
489		t.Fatal("nil policy")
490	}
491
492	orig, err := copystructure.Copy(p)
493	if err != nil {
494		t.Fatal(err)
495	}
496	orig.(*Policy).l = p.l
497
498	p.Key = p.Keys["1"].Key
499	p.Keys = nil
500	p.MinDecryptionVersion = 0
501
502	if err := p.Upgrade(ctx, storage, rand.Reader); err != nil {
503		t.Fatal(err)
504	}
505
506	k := p.Keys["1"]
507	o := orig.(*Policy).Keys["1"]
508	k.CreationTime = o.CreationTime
509	k.HMACKey = o.HMACKey
510	p.Keys["1"] = k
511	p.versionPrefixCache = sync.Map{}
512
513	if !reflect.DeepEqual(orig, p) {
514		t.Fatalf("not equal:\n%#v\n%#v", orig, p)
515	}
516
517	// Do it again with a failing storage call
518	underlying := storage.Underlying()
519	underlying.FailPut(true)
520
521	p.Key = p.Keys["1"].Key
522	p.Keys = nil
523	p.MinDecryptionVersion = 0
524
525	if err := p.Upgrade(ctx, storage, rand.Reader); err == nil {
526		t.Fatal("expected error")
527	}
528
529	if p.MinDecryptionVersion == 1 {
530		t.Fatal("min decryption version was changed")
531	}
532	if p.Keys != nil {
533		t.Fatal("found upgraded keys")
534	}
535	if p.Key == nil {
536		t.Fatal("non-upgraded key not found")
537	}
538}
539
540func Test_BadArchive(t *testing.T) {
541	ctx := context.Background()
542	lm, _ := NewLockManager(true, 0)
543	storage := &logical.InmemStorage{}
544	p, _, err := lm.GetPolicy(ctx, PolicyRequest{
545		Upsert:  true,
546		Storage: storage,
547		KeyType: KeyType_AES256_GCM96,
548		Name:    "test",
549	}, rand.Reader)
550	if err != nil {
551		t.Fatal(err)
552	}
553	if p == nil {
554		t.Fatal("nil policy")
555	}
556
557	for i := 2; i <= 10; i++ {
558		err = p.Rotate(ctx, storage, rand.Reader)
559		if err != nil {
560			t.Fatal(err)
561		}
562	}
563
564	p.MinDecryptionVersion = 5
565	if err := p.Persist(ctx, storage); err != nil {
566		t.Fatal(err)
567	}
568	if p.ArchiveVersion != 10 {
569		t.Fatalf("unexpected archive version %d", p.ArchiveVersion)
570	}
571	if len(p.Keys) != 6 {
572		t.Fatalf("unexpected key length %d", len(p.Keys))
573	}
574
575	// Set back
576	p.MinDecryptionVersion = 1
577	if err := p.Persist(ctx, storage); err != nil {
578		t.Fatal(err)
579	}
580	if p.ArchiveVersion != 10 {
581		t.Fatalf("unexpected archive version %d", p.ArchiveVersion)
582	}
583	if len(p.Keys) != 10 {
584		t.Fatalf("unexpected key length %d", len(p.Keys))
585	}
586
587	// Run it again but we'll turn off storage along the way
588	p.MinDecryptionVersion = 5
589	if err := p.Persist(ctx, storage); err != nil {
590		t.Fatal(err)
591	}
592	if p.ArchiveVersion != 10 {
593		t.Fatalf("unexpected archive version %d", p.ArchiveVersion)
594	}
595	if len(p.Keys) != 6 {
596		t.Fatalf("unexpected key length %d", len(p.Keys))
597	}
598
599	underlying := storage.Underlying()
600	underlying.FailPut(true)
601
602	// Set back, which should cause p.Keys to be changed if the persist works,
603	// but it doesn't
604	p.MinDecryptionVersion = 1
605	if err := p.Persist(ctx, storage); err == nil {
606		t.Fatal("expected error during put")
607	}
608	if p.ArchiveVersion != 10 {
609		t.Fatalf("unexpected archive version %d", p.ArchiveVersion)
610	}
611	// Here's the expected change
612	if len(p.Keys) != 6 {
613		t.Fatalf("unexpected key length %d", len(p.Keys))
614	}
615}
616