1package transit
2
3import (
4	"reflect"
5	"testing"
6
7	"github.com/hashicorp/vault/logical"
8)
9
10var (
11	keysArchive []KeyEntry
12)
13
14func resetKeysArchive() {
15	keysArchive = []KeyEntry{KeyEntry{}}
16}
17
18func Test_KeyUpgrade(t *testing.T) {
19	testKeyUpgradeCommon(t, newLockManager(false))
20	testKeyUpgradeCommon(t, newLockManager(true))
21}
22
23func testKeyUpgradeCommon(t *testing.T, lm *lockManager) {
24	storage := &logical.InmemStorage{}
25	p, lock, upserted, err := lm.GetPolicyUpsert(storage, "test", false)
26	if lock != nil {
27		defer lock.RUnlock()
28	}
29	if err != nil {
30		t.Fatal(err)
31	}
32	if p == nil {
33		t.Fatal("nil policy")
34	}
35	if !upserted {
36		t.Fatal("expected an upsert")
37	}
38
39	testBytes := make([]byte, len(p.Keys[1].Key))
40	copy(testBytes, p.Keys[1].Key)
41
42	p.Key = p.Keys[1].Key
43	p.Keys = nil
44	p.migrateKeyToKeysMap()
45	if p.Key != nil {
46		t.Fatal("policy.Key is not nil")
47	}
48	if len(p.Keys) != 1 {
49		t.Fatal("policy.Keys is the wrong size")
50	}
51	if !reflect.DeepEqual(testBytes, p.Keys[1].Key) {
52		t.Fatal("key mismatch")
53	}
54}
55
56func Test_ArchivingUpgrade(t *testing.T) {
57	testArchivingUpgradeCommon(t, newLockManager(false))
58	testArchivingUpgradeCommon(t, newLockManager(true))
59}
60
61func testArchivingUpgradeCommon(t *testing.T, lm *lockManager) {
62	resetKeysArchive()
63
64	// First, we generate a policy and rotate it a number of times. Each time
65	// we'll ensure that we have the expected number of keys in the archive and
66	// the main keys object, which without changing the min version should be
67	// zero and latest, respectively
68
69	storage := &logical.InmemStorage{}
70
71	p, lock, _, err := lm.GetPolicyUpsert(storage, "test", false)
72	if err != nil {
73		t.Fatal(err)
74	}
75	if p == nil || lock == nil {
76		t.Fatal("nil policy or lock")
77	}
78	lock.RUnlock()
79
80	// Store the initial key in the archive
81	keysArchive = append(keysArchive, p.Keys[1])
82	checkKeys(t, p, storage, "initial", 1, 1, 1)
83
84	for i := 2; i <= 10; i++ {
85		err = p.rotate(storage)
86		if err != nil {
87			t.Fatal(err)
88		}
89		keysArchive = append(keysArchive, p.Keys[i])
90		checkKeys(t, p, storage, "rotate", i, i, i)
91	}
92
93	// Now, wipe the archive and set the archive version to zero
94	err = storage.Delete("archive/test")
95	if err != nil {
96		t.Fatal(err)
97	}
98	p.ArchiveVersion = 0
99
100	// Store it, but without calling persist, so we don't trigger
101	// handleArchiving()
102	buf, err := p.Serialize()
103	if err != nil {
104		t.Fatal(err)
105	}
106
107	// Write the policy into storage
108	err = storage.Put(&logical.StorageEntry{
109		Key:   "policy/" + p.Name,
110		Value: buf,
111	})
112	if err != nil {
113		t.Fatal(err)
114	}
115
116	// If we're caching, expire from the cache since we modified it
117	// under-the-hood
118	if lm.CacheActive() {
119		delete(lm.cache, "test")
120	}
121
122	// Now get the policy again; the upgrade should happen automatically
123	p, lock, err = lm.GetPolicyShared(storage, "test")
124	if err != nil {
125		t.Fatal(err)
126	}
127	if p == nil || lock == nil {
128		t.Fatal("nil policy or lock")
129	}
130	lock.RUnlock()
131
132	checkKeys(t, p, storage, "upgrade", 10, 10, 10)
133
134	// Let's check some deletion logic while we're at it
135
136	// The policy should be in there
137	if lm.CacheActive() && lm.cache["test"] == nil {
138		t.Fatal("nil policy in cache")
139	}
140
141	// First we'll do this wrong, by not setting the deletion flag
142	err = lm.DeletePolicy(storage, "test")
143	if err == nil {
144		t.Fatal("got nil error, but should not have been able to delete since we didn't set the deletion flag on the policy")
145	}
146
147	// The policy should still be in there
148	if lm.CacheActive() && lm.cache["test"] == nil {
149		t.Fatal("nil policy in cache")
150	}
151
152	p, lock, err = lm.GetPolicyShared(storage, "test")
153	if err != nil {
154		t.Fatal(err)
155	}
156	if p == nil || lock == nil {
157		t.Fatal("policy or lock nil after bad delete")
158	}
159	lock.RUnlock()
160
161	// Now do it properly
162	p.DeletionAllowed = true
163	err = p.Persist(storage)
164	if err != nil {
165		t.Fatal(err)
166	}
167	err = lm.DeletePolicy(storage, "test")
168	if err != nil {
169		t.Fatal(err)
170	}
171
172	// The policy should *not* be in there
173	if lm.CacheActive() && lm.cache["test"] != nil {
174		t.Fatal("non-nil policy in cache")
175	}
176
177	p, lock, err = lm.GetPolicyShared(storage, "test")
178	if err != nil {
179		t.Fatal(err)
180	}
181	if p != nil || lock != nil {
182		t.Fatal("policy or lock not nil after delete")
183	}
184}
185
186func Test_Archiving(t *testing.T) {
187	testArchivingCommon(t, newLockManager(false))
188	testArchivingCommon(t, newLockManager(true))
189}
190
191func testArchivingCommon(t *testing.T, lm *lockManager) {
192	resetKeysArchive()
193
194	// First, we generate a policy and rotate it a number of times. Each time
195	// we'll ensure that we have the expected number of keys in the archive and
196	// the main keys object, which without changing the min version should be
197	// zero and latest, respectively
198
199	storage := &logical.InmemStorage{}
200
201	p, lock, _, err := lm.GetPolicyUpsert(storage, "test", false)
202	if lock != nil {
203		defer lock.RUnlock()
204	}
205	if err != nil {
206		t.Fatal(err)
207	}
208	if p == nil {
209		t.Fatal("nil policy")
210	}
211
212	// Store the initial key in the archive
213	keysArchive = append(keysArchive, p.Keys[1])
214	checkKeys(t, p, storage, "initial", 1, 1, 1)
215
216	for i := 2; i <= 10; i++ {
217		err = p.rotate(storage)
218		if err != nil {
219			t.Fatal(err)
220		}
221		keysArchive = append(keysArchive, p.Keys[i])
222		checkKeys(t, p, storage, "rotate", i, i, i)
223	}
224
225	// Move the min decryption version up
226	for i := 1; i <= 10; i++ {
227		p.MinDecryptionVersion = i
228
229		err = p.Persist(storage)
230		if err != nil {
231			t.Fatal(err)
232		}
233		// We expect to find:
234		// * The keys in archive are the same as the latest version
235		// * The latest version is constant
236		// * The number of keys in the policy itself is from the min
237		// decryption version up to the latest version, so for e.g. 7 and
238		// 10, you'd need 7, 8, 9, and 10 -- IOW, latest version - min
239		// decryption version plus 1 (the min decryption version key
240		// itself)
241		checkKeys(t, p, storage, "minadd", 10, 10, p.LatestVersion-p.MinDecryptionVersion+1)
242	}
243
244	// Move the min decryption version down
245	for i := 10; i >= 1; i-- {
246		p.MinDecryptionVersion = i
247
248		err = p.Persist(storage)
249		if err != nil {
250			t.Fatal(err)
251		}
252		// We expect to find:
253		// * The keys in archive are never removed so same as the latest version
254		// * The latest version is constant
255		// * The number of keys in the policy itself is from the min
256		// decryption version up to the latest version, so for e.g. 7 and
257		// 10, you'd need 7, 8, 9, and 10 -- IOW, latest version - min
258		// decryption version plus 1 (the min decryption version key
259		// itself)
260		checkKeys(t, p, storage, "minsub", 10, 10, p.LatestVersion-p.MinDecryptionVersion+1)
261	}
262}
263
264func checkKeys(t *testing.T,
265	policy *Policy,
266	storage logical.Storage,
267	action string,
268	archiveVer, latestVer, keysSize int) {
269
270	// Sanity check
271	if len(keysArchive) != latestVer+1 {
272		t.Fatalf("latest expected key version is %d, expected test keys archive size is %d, "+
273			"but keys archive is of size %d", latestVer, latestVer+1, len(keysArchive))
274	}
275
276	archive, err := policy.loadArchive(storage)
277	if err != nil {
278		t.Fatal(err)
279	}
280
281	badArchiveVer := false
282	if archiveVer == 0 {
283		if len(archive.Keys) != 0 || policy.ArchiveVersion != 0 {
284			badArchiveVer = true
285		}
286	} else {
287		// We need to subtract one because we have the indexes match key
288		// versions, which start at 1. So for an archive version of 1, we
289		// actually have two entries -- a blank 0 entry, and the key at spot 1
290		if archiveVer != len(archive.Keys)-1 || archiveVer != policy.ArchiveVersion {
291			badArchiveVer = true
292		}
293	}
294	if badArchiveVer {
295		t.Fatalf(
296			"expected archive version %d, found length of archive keys %d and policy archive version %d",
297			archiveVer, len(archive.Keys), policy.ArchiveVersion,
298		)
299	}
300
301	if latestVer != policy.LatestVersion {
302		t.Fatalf(
303			"expected latest version %d, found %d",
304			latestVer, policy.LatestVersion,
305		)
306	}
307
308	if keysSize != len(policy.Keys) {
309		t.Fatalf(
310			"expected keys size %d, found %d, action is %s, policy is \n%#v\n",
311			keysSize, len(policy.Keys), action, policy,
312		)
313	}
314
315	for i := policy.MinDecryptionVersion; i <= policy.LatestVersion; i++ {
316		if _, ok := policy.Keys[i]; !ok {
317			t.Fatalf(
318				"expected key %d, did not find it in policy keys", i,
319			)
320		}
321	}
322
323	for i := policy.MinDecryptionVersion; i <= policy.LatestVersion; i++ {
324		if !reflect.DeepEqual(policy.Keys[i], keysArchive[i]) {
325			t.Fatalf("key %d not equivalent between policy keys and test keys archive", i)
326		}
327	}
328
329	for i := 1; i < len(archive.Keys); i++ {
330		if !reflect.DeepEqual(archive.Keys[i].Key, keysArchive[i].Key) {
331			t.Fatalf("key %d not equivalent between policy archive and test keys archive", i)
332		}
333	}
334}
335