1package transit 2 3import ( 4 "reflect" 5 "testing" 6 7 "github.com/hashicorp/vault/logical" 8) 9 10var ( 11 keysArchive []KeyEntry 12) 13 14func resetKeysArchive() { 15 keysArchive = []KeyEntry{KeyEntry{}} 16} 17 18func Test_KeyUpgrade(t *testing.T) { 19 testKeyUpgradeCommon(t, newLockManager(false)) 20 testKeyUpgradeCommon(t, newLockManager(true)) 21} 22 23func testKeyUpgradeCommon(t *testing.T, lm *lockManager) { 24 storage := &logical.InmemStorage{} 25 p, lock, upserted, err := lm.GetPolicyUpsert(storage, "test", false) 26 if lock != nil { 27 defer lock.RUnlock() 28 } 29 if err != nil { 30 t.Fatal(err) 31 } 32 if p == nil { 33 t.Fatal("nil policy") 34 } 35 if !upserted { 36 t.Fatal("expected an upsert") 37 } 38 39 testBytes := make([]byte, len(p.Keys[1].Key)) 40 copy(testBytes, p.Keys[1].Key) 41 42 p.Key = p.Keys[1].Key 43 p.Keys = nil 44 p.migrateKeyToKeysMap() 45 if p.Key != nil { 46 t.Fatal("policy.Key is not nil") 47 } 48 if len(p.Keys) != 1 { 49 t.Fatal("policy.Keys is the wrong size") 50 } 51 if !reflect.DeepEqual(testBytes, p.Keys[1].Key) { 52 t.Fatal("key mismatch") 53 } 54} 55 56func Test_ArchivingUpgrade(t *testing.T) { 57 testArchivingUpgradeCommon(t, newLockManager(false)) 58 testArchivingUpgradeCommon(t, newLockManager(true)) 59} 60 61func testArchivingUpgradeCommon(t *testing.T, lm *lockManager) { 62 resetKeysArchive() 63 64 // First, we generate a policy and rotate it a number of times. Each time 65 // we'll ensure that we have the expected number of keys in the archive and 66 // the main keys object, which without changing the min version should be 67 // zero and latest, respectively 68 69 storage := &logical.InmemStorage{} 70 71 p, lock, _, err := lm.GetPolicyUpsert(storage, "test", false) 72 if err != nil { 73 t.Fatal(err) 74 } 75 if p == nil || lock == nil { 76 t.Fatal("nil policy or lock") 77 } 78 lock.RUnlock() 79 80 // Store the initial key in the archive 81 keysArchive = append(keysArchive, p.Keys[1]) 82 checkKeys(t, p, storage, "initial", 1, 1, 1) 83 84 for i := 2; i <= 10; i++ { 85 err = p.rotate(storage) 86 if err != nil { 87 t.Fatal(err) 88 } 89 keysArchive = append(keysArchive, p.Keys[i]) 90 checkKeys(t, p, storage, "rotate", i, i, i) 91 } 92 93 // Now, wipe the archive and set the archive version to zero 94 err = storage.Delete("archive/test") 95 if err != nil { 96 t.Fatal(err) 97 } 98 p.ArchiveVersion = 0 99 100 // Store it, but without calling persist, so we don't trigger 101 // handleArchiving() 102 buf, err := p.Serialize() 103 if err != nil { 104 t.Fatal(err) 105 } 106 107 // Write the policy into storage 108 err = storage.Put(&logical.StorageEntry{ 109 Key: "policy/" + p.Name, 110 Value: buf, 111 }) 112 if err != nil { 113 t.Fatal(err) 114 } 115 116 // If we're caching, expire from the cache since we modified it 117 // under-the-hood 118 if lm.CacheActive() { 119 delete(lm.cache, "test") 120 } 121 122 // Now get the policy again; the upgrade should happen automatically 123 p, lock, err = lm.GetPolicyShared(storage, "test") 124 if err != nil { 125 t.Fatal(err) 126 } 127 if p == nil || lock == nil { 128 t.Fatal("nil policy or lock") 129 } 130 lock.RUnlock() 131 132 checkKeys(t, p, storage, "upgrade", 10, 10, 10) 133 134 // Let's check some deletion logic while we're at it 135 136 // The policy should be in there 137 if lm.CacheActive() && lm.cache["test"] == nil { 138 t.Fatal("nil policy in cache") 139 } 140 141 // First we'll do this wrong, by not setting the deletion flag 142 err = lm.DeletePolicy(storage, "test") 143 if err == nil { 144 t.Fatal("got nil error, but should not have been able to delete since we didn't set the deletion flag on the policy") 145 } 146 147 // The policy should still be in there 148 if lm.CacheActive() && lm.cache["test"] == nil { 149 t.Fatal("nil policy in cache") 150 } 151 152 p, lock, err = lm.GetPolicyShared(storage, "test") 153 if err != nil { 154 t.Fatal(err) 155 } 156 if p == nil || lock == nil { 157 t.Fatal("policy or lock nil after bad delete") 158 } 159 lock.RUnlock() 160 161 // Now do it properly 162 p.DeletionAllowed = true 163 err = p.Persist(storage) 164 if err != nil { 165 t.Fatal(err) 166 } 167 err = lm.DeletePolicy(storage, "test") 168 if err != nil { 169 t.Fatal(err) 170 } 171 172 // The policy should *not* be in there 173 if lm.CacheActive() && lm.cache["test"] != nil { 174 t.Fatal("non-nil policy in cache") 175 } 176 177 p, lock, err = lm.GetPolicyShared(storage, "test") 178 if err != nil { 179 t.Fatal(err) 180 } 181 if p != nil || lock != nil { 182 t.Fatal("policy or lock not nil after delete") 183 } 184} 185 186func Test_Archiving(t *testing.T) { 187 testArchivingCommon(t, newLockManager(false)) 188 testArchivingCommon(t, newLockManager(true)) 189} 190 191func testArchivingCommon(t *testing.T, lm *lockManager) { 192 resetKeysArchive() 193 194 // First, we generate a policy and rotate it a number of times. Each time 195 // we'll ensure that we have the expected number of keys in the archive and 196 // the main keys object, which without changing the min version should be 197 // zero and latest, respectively 198 199 storage := &logical.InmemStorage{} 200 201 p, lock, _, err := lm.GetPolicyUpsert(storage, "test", false) 202 if lock != nil { 203 defer lock.RUnlock() 204 } 205 if err != nil { 206 t.Fatal(err) 207 } 208 if p == nil { 209 t.Fatal("nil policy") 210 } 211 212 // Store the initial key in the archive 213 keysArchive = append(keysArchive, p.Keys[1]) 214 checkKeys(t, p, storage, "initial", 1, 1, 1) 215 216 for i := 2; i <= 10; i++ { 217 err = p.rotate(storage) 218 if err != nil { 219 t.Fatal(err) 220 } 221 keysArchive = append(keysArchive, p.Keys[i]) 222 checkKeys(t, p, storage, "rotate", i, i, i) 223 } 224 225 // Move the min decryption version up 226 for i := 1; i <= 10; i++ { 227 p.MinDecryptionVersion = i 228 229 err = p.Persist(storage) 230 if err != nil { 231 t.Fatal(err) 232 } 233 // We expect to find: 234 // * The keys in archive are the same as the latest version 235 // * The latest version is constant 236 // * The number of keys in the policy itself is from the min 237 // decryption version up to the latest version, so for e.g. 7 and 238 // 10, you'd need 7, 8, 9, and 10 -- IOW, latest version - min 239 // decryption version plus 1 (the min decryption version key 240 // itself) 241 checkKeys(t, p, storage, "minadd", 10, 10, p.LatestVersion-p.MinDecryptionVersion+1) 242 } 243 244 // Move the min decryption version down 245 for i := 10; i >= 1; i-- { 246 p.MinDecryptionVersion = i 247 248 err = p.Persist(storage) 249 if err != nil { 250 t.Fatal(err) 251 } 252 // We expect to find: 253 // * The keys in archive are never removed so same as the latest version 254 // * The latest version is constant 255 // * The number of keys in the policy itself is from the min 256 // decryption version up to the latest version, so for e.g. 7 and 257 // 10, you'd need 7, 8, 9, and 10 -- IOW, latest version - min 258 // decryption version plus 1 (the min decryption version key 259 // itself) 260 checkKeys(t, p, storage, "minsub", 10, 10, p.LatestVersion-p.MinDecryptionVersion+1) 261 } 262} 263 264func checkKeys(t *testing.T, 265 policy *Policy, 266 storage logical.Storage, 267 action string, 268 archiveVer, latestVer, keysSize int) { 269 270 // Sanity check 271 if len(keysArchive) != latestVer+1 { 272 t.Fatalf("latest expected key version is %d, expected test keys archive size is %d, "+ 273 "but keys archive is of size %d", latestVer, latestVer+1, len(keysArchive)) 274 } 275 276 archive, err := policy.loadArchive(storage) 277 if err != nil { 278 t.Fatal(err) 279 } 280 281 badArchiveVer := false 282 if archiveVer == 0 { 283 if len(archive.Keys) != 0 || policy.ArchiveVersion != 0 { 284 badArchiveVer = true 285 } 286 } else { 287 // We need to subtract one because we have the indexes match key 288 // versions, which start at 1. So for an archive version of 1, we 289 // actually have two entries -- a blank 0 entry, and the key at spot 1 290 if archiveVer != len(archive.Keys)-1 || archiveVer != policy.ArchiveVersion { 291 badArchiveVer = true 292 } 293 } 294 if badArchiveVer { 295 t.Fatalf( 296 "expected archive version %d, found length of archive keys %d and policy archive version %d", 297 archiveVer, len(archive.Keys), policy.ArchiveVersion, 298 ) 299 } 300 301 if latestVer != policy.LatestVersion { 302 t.Fatalf( 303 "expected latest version %d, found %d", 304 latestVer, policy.LatestVersion, 305 ) 306 } 307 308 if keysSize != len(policy.Keys) { 309 t.Fatalf( 310 "expected keys size %d, found %d, action is %s, policy is \n%#v\n", 311 keysSize, len(policy.Keys), action, policy, 312 ) 313 } 314 315 for i := policy.MinDecryptionVersion; i <= policy.LatestVersion; i++ { 316 if _, ok := policy.Keys[i]; !ok { 317 t.Fatalf( 318 "expected key %d, did not find it in policy keys", i, 319 ) 320 } 321 } 322 323 for i := policy.MinDecryptionVersion; i <= policy.LatestVersion; i++ { 324 if !reflect.DeepEqual(policy.Keys[i], keysArchive[i]) { 325 t.Fatalf("key %d not equivalent between policy keys and test keys archive", i) 326 } 327 } 328 329 for i := 1; i < len(archive.Keys); i++ { 330 if !reflect.DeepEqual(archive.Keys[i].Key, keysArchive[i].Key) { 331 t.Fatalf("key %d not equivalent between policy archive and test keys archive", i) 332 } 333 } 334} 335