1package keysutil
2
3import (
4	"context"
5	"reflect"
6	"strconv"
7	"sync"
8	"testing"
9	"time"
10
11	"github.com/hashicorp/vault/sdk/helper/jsonutil"
12	"github.com/hashicorp/vault/sdk/logical"
13	"github.com/mitchellh/copystructure"
14)
15
16func TestPolicy_KeyEntryMapUpgrade(t *testing.T) {
17	now := time.Now()
18	old := map[int]KeyEntry{
19		1: {
20			Key:                []byte("samplekey"),
21			HMACKey:            []byte("samplehmackey"),
22			CreationTime:       now,
23			FormattedPublicKey: "sampleformattedpublickey",
24		},
25		2: {
26			Key:                []byte("samplekey2"),
27			HMACKey:            []byte("samplehmackey2"),
28			CreationTime:       now.Add(10 * time.Second),
29			FormattedPublicKey: "sampleformattedpublickey2",
30		},
31	}
32
33	oldEncoded, err := jsonutil.EncodeJSON(old)
34	if err != nil {
35		t.Fatal(err)
36	}
37
38	var new keyEntryMap
39	err = jsonutil.DecodeJSON(oldEncoded, &new)
40	if err != nil {
41		t.Fatal(err)
42	}
43
44	newEncoded, err := jsonutil.EncodeJSON(&new)
45	if err != nil {
46		t.Fatal(err)
47	}
48
49	if string(oldEncoded) != string(newEncoded) {
50		t.Fatalf("failed to upgrade key entry map;\nold: %q\nnew: %q", string(oldEncoded), string(newEncoded))
51	}
52}
53
54func Test_KeyUpgrade(t *testing.T) {
55	lockManagerWithCache, _ := NewLockManager(true, 0)
56	lockManagerWithoutCache, _ := NewLockManager(false, 0)
57	testKeyUpgradeCommon(t, lockManagerWithCache)
58	testKeyUpgradeCommon(t, lockManagerWithoutCache)
59}
60
61func testKeyUpgradeCommon(t *testing.T, lm *LockManager) {
62	ctx := context.Background()
63
64	storage := &logical.InmemStorage{}
65	p, upserted, err := lm.GetPolicy(ctx, PolicyRequest{
66		Upsert:  true,
67		Storage: storage,
68		KeyType: KeyType_AES256_GCM96,
69		Name:    "test",
70	})
71	if err != nil {
72		t.Fatal(err)
73	}
74	if p == nil {
75		t.Fatal("nil policy")
76	}
77	if !upserted {
78		t.Fatal("expected an upsert")
79	}
80	if !lm.useCache {
81		p.Unlock()
82	}
83
84	testBytes := make([]byte, len(p.Keys["1"].Key))
85	copy(testBytes, p.Keys["1"].Key)
86
87	p.Key = p.Keys["1"].Key
88	p.Keys = nil
89	p.MigrateKeyToKeysMap()
90	if p.Key != nil {
91		t.Fatal("policy.Key is not nil")
92	}
93	if len(p.Keys) != 1 {
94		t.Fatal("policy.Keys is the wrong size")
95	}
96	if !reflect.DeepEqual(testBytes, p.Keys["1"].Key) {
97		t.Fatal("key mismatch")
98	}
99}
100
101func Test_ArchivingUpgrade(t *testing.T) {
102	lockManagerWithCache, _ := NewLockManager(true, 0)
103	lockManagerWithoutCache, _ := NewLockManager(false, 0)
104	testArchivingUpgradeCommon(t, lockManagerWithCache)
105	testArchivingUpgradeCommon(t, lockManagerWithoutCache)
106}
107
// testArchivingUpgradeCommon simulates a policy that was stored without an
// archive (archive entry deleted, ArchiveVersion reset to 0) and verifies that
// loading it again through lm rebuilds the archive automatically.
//
// It then exercises DeletePolicy. Deletion must fail while DeletionAllowed is
// unset, and the policy must remain available. Once the flag is persisted,
// deletion succeeds and GetPolicy no longer returns the policy. When lm caches
// policies, the test also checks whether the cache holds the entry at each step.
func testArchivingUpgradeCommon(t *testing.T, lm *LockManager) {
	ctx := context.Background()

	// First, we generate a policy and rotate it a number of times. Each time
	// we'll ensure that the archive version and the number of keys in the main
	// keys object both track the latest version, since the min decryption
	// version is never changed here

	storage := &logical.InmemStorage{}
	p, _, err := lm.GetPolicy(ctx, PolicyRequest{
		Upsert:  true,
		Storage: storage,
		KeyType: KeyType_AES256_GCM96,
		Name:    "test",
	})
	if err != nil {
		t.Fatal(err)
	}
	if p == nil {
		t.Fatal("nil policy")
	}
	// Without the cache the returned policy is locked; release it so the
	// policy methods and later lock-manager calls below can proceed.
	if !lm.useCache {
		p.Unlock()
	}

	// Store the initial key in the archive
	// (index 0 is a blank placeholder so indexes match key versions)
	keysArchive := []KeyEntry{KeyEntry{}, p.Keys["1"]}
	checkKeys(t, ctx, p, storage, keysArchive, "initial", 1, 1, 1)

	for i := 2; i <= 10; i++ {
		err = p.Rotate(ctx, storage)
		if err != nil {
			t.Fatal(err)
		}
		keysArchive = append(keysArchive, p.Keys[strconv.Itoa(i)])
		checkKeys(t, ctx, p, storage, keysArchive, "rotate", i, i, i)
	}

	// Now, wipe the archive and set the archive version to zero
	// (this mimics a policy persisted before archiving was introduced)
	err = storage.Delete(ctx, "archive/test")
	if err != nil {
		t.Fatal(err)
	}
	p.ArchiveVersion = 0

	// Store it, but without calling persist, so we don't trigger
	// handleArchiving()
	buf, err := p.Serialize()
	if err != nil {
		t.Fatal(err)
	}

	// Write the policy into storage
	err = storage.Put(ctx, &logical.StorageEntry{
		Key:   "policy/" + p.Name,
		Value: buf,
	})
	if err != nil {
		t.Fatal(err)
	}

	// If we're caching, expire from the cache since we modified it
	// under-the-hood
	if lm.useCache {
		lm.cache.Delete("test")
	}

	// Now get the policy again; the upgrade should happen automatically
	p, _, err = lm.GetPolicy(ctx, PolicyRequest{
		Storage: storage,
		Name:    "test",
	})
	if err != nil {
		t.Fatal(err)
	}
	if p == nil {
		t.Fatal("nil policy")
	}
	if !lm.useCache {
		p.Unlock()
	}

	// The rebuilt archive must again contain all ten versions.
	checkKeys(t, ctx, p, storage, keysArchive, "upgrade", 10, 10, 10)

	// Let's check some deletion logic while we're at it

	// The policy should be in there
	if lm.useCache {
		_, ok := lm.cache.Load("test")
		if !ok {
			t.Fatal("nil policy in cache")
		}
	}

	// First we'll do this wrong, by not setting the deletion flag
	err = lm.DeletePolicy(ctx, storage, "test")
	if err == nil {
		t.Fatal("got nil error, but should not have been able to delete since we didn't set the deletion flag on the policy")
	}

	// The policy should still be in there
	if lm.useCache {
		_, ok := lm.cache.Load("test")
		if !ok {
			t.Fatal("nil policy in cache")
		}
	}

	p, _, err = lm.GetPolicy(ctx, PolicyRequest{
		Storage: storage,
		Name:    "test",
	})
	if err != nil {
		t.Fatal(err)
	}
	if p == nil {
		t.Fatal("policy nil after bad delete")
	}
	if !lm.useCache {
		p.Unlock()
	}

	// Now do it properly
	// (the flag must be persisted, not just set in memory, before deleting)
	p.DeletionAllowed = true
	err = p.Persist(ctx, storage)
	if err != nil {
		t.Fatal(err)
	}
	err = lm.DeletePolicy(ctx, storage, "test")
	if err != nil {
		t.Fatal(err)
	}

	// The policy should *not* be in there
	if lm.useCache {
		_, ok := lm.cache.Load("test")
		if ok {
			t.Fatal("non-nil policy in cache")
		}
	}

	// A non-upsert lookup after deletion returns a nil policy without error.
	p, _, err = lm.GetPolicy(ctx, PolicyRequest{
		Storage: storage,
		Name:    "test",
	})
	if err != nil {
		t.Fatal(err)
	}
	if p != nil {
		t.Fatal("policy not nil after delete")
	}
}
260
261func Test_Archiving(t *testing.T) {
262	lockManagerWithCache, _ := NewLockManager(true, 0)
263	lockManagerWithoutCache, _ := NewLockManager(false, 0)
264	testArchivingUpgradeCommon(t, lockManagerWithCache)
265	testArchivingUpgradeCommon(t, lockManagerWithoutCache)
266}
267
268func testArchivingCommon(t *testing.T, lm *LockManager) {
269	ctx := context.Background()
270
271	// First, we generate a policy and rotate it a number of times. Each time
272	// we'll ensure that we have the expected number of keys in the archive and
273	// the main keys object, which without changing the min version should be
274	// zero and latest, respectively
275
276	storage := &logical.InmemStorage{}
277	p, _, err := lm.GetPolicy(ctx, PolicyRequest{
278		Upsert:  true,
279		Storage: storage,
280		KeyType: KeyType_AES256_GCM96,
281		Name:    "test",
282	})
283	if err != nil {
284		t.Fatal(err)
285	}
286	if p == nil {
287		t.Fatal("nil policy")
288	}
289	if !lm.useCache {
290		p.Unlock()
291	}
292
293	// Store the initial key in the archive
294	keysArchive := []KeyEntry{KeyEntry{}, p.Keys["1"]}
295	checkKeys(t, ctx, p, storage, keysArchive, "initial", 1, 1, 1)
296
297	for i := 2; i <= 10; i++ {
298		err = p.Rotate(ctx, storage)
299		if err != nil {
300			t.Fatal(err)
301		}
302		keysArchive = append(keysArchive, p.Keys[strconv.Itoa(i)])
303		checkKeys(t, ctx, p, storage, keysArchive, "rotate", i, i, i)
304	}
305
306	// Move the min decryption version up
307	for i := 1; i <= 10; i++ {
308		p.MinDecryptionVersion = i
309
310		err = p.Persist(ctx, storage)
311		if err != nil {
312			t.Fatal(err)
313		}
314		// We expect to find:
315		// * The keys in archive are the same as the latest version
316		// * The latest version is constant
317		// * The number of keys in the policy itself is from the min
318		// decryption version up to the latest version, so for e.g. 7 and
319		// 10, you'd need 7, 8, 9, and 10 -- IOW, latest version - min
320		// decryption version plus 1 (the min decryption version key
321		// itself)
322		checkKeys(t, ctx, p, storage, keysArchive, "minadd", 10, 10, p.LatestVersion-p.MinDecryptionVersion+1)
323	}
324
325	// Move the min decryption version down
326	for i := 10; i >= 1; i-- {
327		p.MinDecryptionVersion = i
328
329		err = p.Persist(ctx, storage)
330		if err != nil {
331			t.Fatal(err)
332		}
333		// We expect to find:
334		// * The keys in archive are never removed so same as the latest version
335		// * The latest version is constant
336		// * The number of keys in the policy itself is from the min
337		// decryption version up to the latest version, so for e.g. 7 and
338		// 10, you'd need 7, 8, 9, and 10 -- IOW, latest version - min
339		// decryption version plus 1 (the min decryption version key
340		// itself)
341		checkKeys(t, ctx, p, storage, keysArchive, "minsub", 10, 10, p.LatestVersion-p.MinDecryptionVersion+1)
342	}
343}
344
345func checkKeys(t *testing.T,
346	ctx context.Context,
347	p *Policy,
348	storage logical.Storage,
349	keysArchive []KeyEntry,
350	action string,
351	archiveVer, latestVer, keysSize int) {
352
353	// Sanity check
354	if len(keysArchive) != latestVer+1 {
355		t.Fatalf("latest expected key version is %d, expected test keys archive size is %d, "+
356			"but keys archive is of size %d", latestVer, latestVer+1, len(keysArchive))
357	}
358
359	archive, err := p.LoadArchive(ctx, storage)
360	if err != nil {
361		t.Fatal(err)
362	}
363
364	badArchiveVer := false
365	if archiveVer == 0 {
366		if len(archive.Keys) != 0 || p.ArchiveVersion != 0 {
367			badArchiveVer = true
368		}
369	} else {
370		// We need to subtract one because we have the indexes match key
371		// versions, which start at 1. So for an archive version of 1, we
372		// actually have two entries -- a blank 0 entry, and the key at spot 1
373		if archiveVer != len(archive.Keys)-1 || archiveVer != p.ArchiveVersion {
374			badArchiveVer = true
375		}
376	}
377	if badArchiveVer {
378		t.Fatalf(
379			"expected archive version %d, found length of archive keys %d and policy archive version %d",
380			archiveVer, len(archive.Keys), p.ArchiveVersion,
381		)
382	}
383
384	if latestVer != p.LatestVersion {
385		t.Fatalf(
386			"expected latest version %d, found %d",
387			latestVer, p.LatestVersion,
388		)
389	}
390
391	if keysSize != len(p.Keys) {
392		t.Fatalf(
393			"expected keys size %d, found %d, action is %s, policy is \n%#v\n",
394			keysSize, len(p.Keys), action, p,
395		)
396	}
397
398	for i := p.MinDecryptionVersion; i <= p.LatestVersion; i++ {
399		if _, ok := p.Keys[strconv.Itoa(i)]; !ok {
400			t.Fatalf(
401				"expected key %d, did not find it in policy keys", i,
402			)
403		}
404	}
405
406	for i := p.MinDecryptionVersion; i <= p.LatestVersion; i++ {
407		ver := strconv.Itoa(i)
408		if !p.Keys[ver].CreationTime.Equal(keysArchive[i].CreationTime) {
409			t.Fatalf("key %d not equivalent between policy keys and test keys archive; policy keys:\n%#v\ntest keys archive:\n%#v\n", i, p.Keys[ver], keysArchive[i])
410		}
411		polKey := p.Keys[ver]
412		polKey.CreationTime = keysArchive[i].CreationTime
413		p.Keys[ver] = polKey
414		if !reflect.DeepEqual(p.Keys[ver], keysArchive[i]) {
415			t.Fatalf("key %d not equivalent between policy keys and test keys archive; policy keys:\n%#v\ntest keys archive:\n%#v\n", i, p.Keys[ver], keysArchive[i])
416		}
417	}
418
419	for i := 1; i < len(archive.Keys); i++ {
420		if !reflect.DeepEqual(archive.Keys[i].Key, keysArchive[i].Key) {
421			t.Fatalf("key %d not equivalent between policy archive and test keys archive; policy archive:\n%#v\ntest keys archive:\n%#v\n", i, archive.Keys[i].Key, keysArchive[i].Key)
422		}
423	}
424}
425
426func Test_StorageErrorSafety(t *testing.T) {
427	ctx := context.Background()
428	lm, _ := NewLockManager(true, 0)
429
430	storage := &logical.InmemStorage{}
431	p, _, err := lm.GetPolicy(ctx, PolicyRequest{
432		Upsert:  true,
433		Storage: storage,
434		KeyType: KeyType_AES256_GCM96,
435		Name:    "test",
436	})
437	if err != nil {
438		t.Fatal(err)
439	}
440	if p == nil {
441		t.Fatal("nil policy")
442	}
443
444	// Store the initial key in the archive
445	keysArchive := []KeyEntry{KeyEntry{}, p.Keys["1"]}
446	checkKeys(t, ctx, p, storage, keysArchive, "initial", 1, 1, 1)
447
448	// We use checkKeys here just for sanity; it doesn't really handle cases of
449	// errors below so we do more targeted testing later
450	for i := 2; i <= 5; i++ {
451		err = p.Rotate(ctx, storage)
452		if err != nil {
453			t.Fatal(err)
454		}
455		keysArchive = append(keysArchive, p.Keys[strconv.Itoa(i)])
456		checkKeys(t, ctx, p, storage, keysArchive, "rotate", i, i, i)
457	}
458
459	underlying := storage.Underlying()
460	underlying.FailPut(true)
461
462	priorLen := len(p.Keys)
463
464	err = p.Rotate(ctx, storage)
465	if err == nil {
466		t.Fatal("expected error")
467	}
468
469	if len(p.Keys) != priorLen {
470		t.Fatal("length of keys should not have changed")
471	}
472}
473
474func Test_BadUpgrade(t *testing.T) {
475	ctx := context.Background()
476	lm, _ := NewLockManager(true, 0)
477	storage := &logical.InmemStorage{}
478	p, _, err := lm.GetPolicy(ctx, PolicyRequest{
479		Upsert:  true,
480		Storage: storage,
481		KeyType: KeyType_AES256_GCM96,
482		Name:    "test",
483	})
484	if err != nil {
485		t.Fatal(err)
486	}
487	if p == nil {
488		t.Fatal("nil policy")
489	}
490
491	orig, err := copystructure.Copy(p)
492	if err != nil {
493		t.Fatal(err)
494	}
495	orig.(*Policy).l = p.l
496
497	p.Key = p.Keys["1"].Key
498	p.Keys = nil
499	p.MinDecryptionVersion = 0
500
501	if err := p.Upgrade(ctx, storage); err != nil {
502		t.Fatal(err)
503	}
504
505	k := p.Keys["1"]
506	o := orig.(*Policy).Keys["1"]
507	k.CreationTime = o.CreationTime
508	k.HMACKey = o.HMACKey
509	p.Keys["1"] = k
510	p.versionPrefixCache = sync.Map{}
511
512	if !reflect.DeepEqual(orig, p) {
513		t.Fatalf("not equal:\n%#v\n%#v", orig, p)
514	}
515
516	// Do it again with a failing storage call
517	underlying := storage.Underlying()
518	underlying.FailPut(true)
519
520	p.Key = p.Keys["1"].Key
521	p.Keys = nil
522	p.MinDecryptionVersion = 0
523
524	if err := p.Upgrade(ctx, storage); err == nil {
525		t.Fatal("expected error")
526	}
527
528	if p.MinDecryptionVersion == 1 {
529		t.Fatal("min decryption version was changed")
530	}
531	if p.Keys != nil {
532		t.Fatal("found upgraded keys")
533	}
534	if p.Key == nil {
535		t.Fatal("non-upgraded key not found")
536	}
537}
538
539func Test_BadArchive(t *testing.T) {
540	ctx := context.Background()
541	lm, _ := NewLockManager(true, 0)
542	storage := &logical.InmemStorage{}
543	p, _, err := lm.GetPolicy(ctx, PolicyRequest{
544		Upsert:  true,
545		Storage: storage,
546		KeyType: KeyType_AES256_GCM96,
547		Name:    "test",
548	})
549	if err != nil {
550		t.Fatal(err)
551	}
552	if p == nil {
553		t.Fatal("nil policy")
554	}
555
556	for i := 2; i <= 10; i++ {
557		err = p.Rotate(ctx, storage)
558		if err != nil {
559			t.Fatal(err)
560		}
561	}
562
563	p.MinDecryptionVersion = 5
564	if err := p.Persist(ctx, storage); err != nil {
565		t.Fatal(err)
566	}
567	if p.ArchiveVersion != 10 {
568		t.Fatalf("unexpected archive version %d", p.ArchiveVersion)
569	}
570	if len(p.Keys) != 6 {
571		t.Fatalf("unexpected key length %d", len(p.Keys))
572	}
573
574	// Set back
575	p.MinDecryptionVersion = 1
576	if err := p.Persist(ctx, storage); err != nil {
577		t.Fatal(err)
578	}
579	if p.ArchiveVersion != 10 {
580		t.Fatalf("unexpected archive version %d", p.ArchiveVersion)
581	}
582	if len(p.Keys) != 10 {
583		t.Fatalf("unexpected key length %d", len(p.Keys))
584	}
585
586	// Run it again but we'll turn off storage along the way
587	p.MinDecryptionVersion = 5
588	if err := p.Persist(ctx, storage); err != nil {
589		t.Fatal(err)
590	}
591	if p.ArchiveVersion != 10 {
592		t.Fatalf("unexpected archive version %d", p.ArchiveVersion)
593	}
594	if len(p.Keys) != 6 {
595		t.Fatalf("unexpected key length %d", len(p.Keys))
596	}
597
598	underlying := storage.Underlying()
599	underlying.FailPut(true)
600
601	// Set back, which should cause p.Keys to be changed if the persist works,
602	// but it doesn't
603	p.MinDecryptionVersion = 1
604	if err := p.Persist(ctx, storage); err == nil {
605		t.Fatal("expected error during put")
606	}
607	if p.ArchiveVersion != 10 {
608		t.Fatalf("unexpected archive version %d", p.ArchiveVersion)
609	}
610	// Here's the expected change
611	if len(p.Keys) != 6 {
612		t.Fatalf("unexpected key length %d", len(p.Keys))
613	}
614}
615