1package keysutil 2 3import ( 4 "context" 5 "reflect" 6 "strconv" 7 "sync" 8 "testing" 9 "time" 10 11 "github.com/hashicorp/vault/sdk/helper/jsonutil" 12 "github.com/hashicorp/vault/sdk/logical" 13 "github.com/mitchellh/copystructure" 14) 15 16func TestPolicy_KeyEntryMapUpgrade(t *testing.T) { 17 now := time.Now() 18 old := map[int]KeyEntry{ 19 1: { 20 Key: []byte("samplekey"), 21 HMACKey: []byte("samplehmackey"), 22 CreationTime: now, 23 FormattedPublicKey: "sampleformattedpublickey", 24 }, 25 2: { 26 Key: []byte("samplekey2"), 27 HMACKey: []byte("samplehmackey2"), 28 CreationTime: now.Add(10 * time.Second), 29 FormattedPublicKey: "sampleformattedpublickey2", 30 }, 31 } 32 33 oldEncoded, err := jsonutil.EncodeJSON(old) 34 if err != nil { 35 t.Fatal(err) 36 } 37 38 var new keyEntryMap 39 err = jsonutil.DecodeJSON(oldEncoded, &new) 40 if err != nil { 41 t.Fatal(err) 42 } 43 44 newEncoded, err := jsonutil.EncodeJSON(&new) 45 if err != nil { 46 t.Fatal(err) 47 } 48 49 if string(oldEncoded) != string(newEncoded) { 50 t.Fatalf("failed to upgrade key entry map;\nold: %q\nnew: %q", string(oldEncoded), string(newEncoded)) 51 } 52} 53 54func Test_KeyUpgrade(t *testing.T) { 55 lockManagerWithCache, _ := NewLockManager(true, 0) 56 lockManagerWithoutCache, _ := NewLockManager(false, 0) 57 testKeyUpgradeCommon(t, lockManagerWithCache) 58 testKeyUpgradeCommon(t, lockManagerWithoutCache) 59} 60 61func testKeyUpgradeCommon(t *testing.T, lm *LockManager) { 62 ctx := context.Background() 63 64 storage := &logical.InmemStorage{} 65 p, upserted, err := lm.GetPolicy(ctx, PolicyRequest{ 66 Upsert: true, 67 Storage: storage, 68 KeyType: KeyType_AES256_GCM96, 69 Name: "test", 70 }) 71 if err != nil { 72 t.Fatal(err) 73 } 74 if p == nil { 75 t.Fatal("nil policy") 76 } 77 if !upserted { 78 t.Fatal("expected an upsert") 79 } 80 if !lm.useCache { 81 p.Unlock() 82 } 83 84 testBytes := make([]byte, len(p.Keys["1"].Key)) 85 copy(testBytes, p.Keys["1"].Key) 86 87 p.Key = p.Keys["1"].Key 88 p.Keys = nil 89 p.MigrateKeyToKeysMap() 90 if p.Key != nil { 91 t.Fatal("policy.Key is not nil") 92 } 93 if len(p.Keys) != 1 { 94 t.Fatal("policy.Keys is the wrong size") 95 } 96 if !reflect.DeepEqual(testBytes, p.Keys["1"].Key) { 97 t.Fatal("key mismatch") 98 } 99} 100 101func Test_ArchivingUpgrade(t *testing.T) { 102 lockManagerWithCache, _ := NewLockManager(true, 0) 103 lockManagerWithoutCache, _ := NewLockManager(false, 0) 104 testArchivingUpgradeCommon(t, lockManagerWithCache) 105 testArchivingUpgradeCommon(t, lockManagerWithoutCache) 106} 107 108func testArchivingUpgradeCommon(t *testing.T, lm *LockManager) { 109 ctx := context.Background() 110 111 // First, we generate a policy and rotate it a number of times. Each time 112 // we'll ensure that we have the expected number of keys in the archive and 113 // the main keys object, which without changing the min version should be 114 // zero and latest, respectively 115 116 storage := &logical.InmemStorage{} 117 p, _, err := lm.GetPolicy(ctx, PolicyRequest{ 118 Upsert: true, 119 Storage: storage, 120 KeyType: KeyType_AES256_GCM96, 121 Name: "test", 122 }) 123 if err != nil { 124 t.Fatal(err) 125 } 126 if p == nil { 127 t.Fatal("nil policy") 128 } 129 if !lm.useCache { 130 p.Unlock() 131 } 132 133 // Store the initial key in the archive 134 keysArchive := []KeyEntry{KeyEntry{}, p.Keys["1"]} 135 checkKeys(t, ctx, p, storage, keysArchive, "initial", 1, 1, 1) 136 137 for i := 2; i <= 10; i++ { 138 err = p.Rotate(ctx, storage) 139 if err != nil { 140 t.Fatal(err) 141 } 142 keysArchive = append(keysArchive, p.Keys[strconv.Itoa(i)]) 143 checkKeys(t, ctx, p, storage, keysArchive, "rotate", i, i, i) 144 } 145 146 // Now, wipe the archive and set the archive version to zero 147 err = storage.Delete(ctx, "archive/test") 148 if err != nil { 149 t.Fatal(err) 150 } 151 p.ArchiveVersion = 0 152 153 // Store it, but without calling persist, so we don't trigger 154 // handleArchiving() 155 buf, err := p.Serialize() 156 if err != nil { 157 t.Fatal(err) 158 } 159 160 // Write the policy into storage 161 err = storage.Put(ctx, &logical.StorageEntry{ 162 Key: "policy/" + p.Name, 163 Value: buf, 164 }) 165 if err != nil { 166 t.Fatal(err) 167 } 168 169 // If we're caching, expire from the cache since we modified it 170 // under-the-hood 171 if lm.useCache { 172 lm.cache.Delete("test") 173 } 174 175 // Now get the policy again; the upgrade should happen automatically 176 p, _, err = lm.GetPolicy(ctx, PolicyRequest{ 177 Storage: storage, 178 Name: "test", 179 }) 180 if err != nil { 181 t.Fatal(err) 182 } 183 if p == nil { 184 t.Fatal("nil policy") 185 } 186 if !lm.useCache { 187 p.Unlock() 188 } 189 190 checkKeys(t, ctx, p, storage, keysArchive, "upgrade", 10, 10, 10) 191 192 // Let's check some deletion logic while we're at it 193 194 // The policy should be in there 195 if lm.useCache { 196 _, ok := lm.cache.Load("test") 197 if !ok { 198 t.Fatal("nil policy in cache") 199 } 200 } 201 202 // First we'll do this wrong, by not setting the deletion flag 203 err = lm.DeletePolicy(ctx, storage, "test") 204 if err == nil { 205 t.Fatal("got nil error, but should not have been able to delete since we didn't set the deletion flag on the policy") 206 } 207 208 // The policy should still be in there 209 if lm.useCache { 210 _, ok := lm.cache.Load("test") 211 if !ok { 212 t.Fatal("nil policy in cache") 213 } 214 } 215 216 p, _, err = lm.GetPolicy(ctx, PolicyRequest{ 217 Storage: storage, 218 Name: "test", 219 }) 220 if err != nil { 221 t.Fatal(err) 222 } 223 if p == nil { 224 t.Fatal("policy nil after bad delete") 225 } 226 if !lm.useCache { 227 p.Unlock() 228 } 229 230 // Now do it properly 231 p.DeletionAllowed = true 232 err = p.Persist(ctx, storage) 233 if err != nil { 234 t.Fatal(err) 235 } 236 err = lm.DeletePolicy(ctx, storage, "test") 237 if err != nil { 238 t.Fatal(err) 239 } 240 241 // The policy should *not* be in there 242 if lm.useCache { 243 _, ok := lm.cache.Load("test") 244 if ok { 245 t.Fatal("non-nil policy in cache") 246 } 247 } 248 249 p, _, err = lm.GetPolicy(ctx, PolicyRequest{ 250 Storage: storage, 251 Name: "test", 252 }) 253 if err != nil { 254 t.Fatal(err) 255 } 256 if p != nil { 257 t.Fatal("policy not nil after delete") 258 } 259} 260 261func Test_Archiving(t *testing.T) { 262 lockManagerWithCache, _ := NewLockManager(true, 0) 263 lockManagerWithoutCache, _ := NewLockManager(false, 0) 264 testArchivingUpgradeCommon(t, lockManagerWithCache) 265 testArchivingUpgradeCommon(t, lockManagerWithoutCache) 266} 267 268func testArchivingCommon(t *testing.T, lm *LockManager) { 269 ctx := context.Background() 270 271 // First, we generate a policy and rotate it a number of times. Each time 272 // we'll ensure that we have the expected number of keys in the archive and 273 // the main keys object, which without changing the min version should be 274 // zero and latest, respectively 275 276 storage := &logical.InmemStorage{} 277 p, _, err := lm.GetPolicy(ctx, PolicyRequest{ 278 Upsert: true, 279 Storage: storage, 280 KeyType: KeyType_AES256_GCM96, 281 Name: "test", 282 }) 283 if err != nil { 284 t.Fatal(err) 285 } 286 if p == nil { 287 t.Fatal("nil policy") 288 } 289 if !lm.useCache { 290 p.Unlock() 291 } 292 293 // Store the initial key in the archive 294 keysArchive := []KeyEntry{KeyEntry{}, p.Keys["1"]} 295 checkKeys(t, ctx, p, storage, keysArchive, "initial", 1, 1, 1) 296 297 for i := 2; i <= 10; i++ { 298 err = p.Rotate(ctx, storage) 299 if err != nil { 300 t.Fatal(err) 301 } 302 keysArchive = append(keysArchive, p.Keys[strconv.Itoa(i)]) 303 checkKeys(t, ctx, p, storage, keysArchive, "rotate", i, i, i) 304 } 305 306 // Move the min decryption version up 307 for i := 1; i <= 10; i++ { 308 p.MinDecryptionVersion = i 309 310 err = p.Persist(ctx, storage) 311 if err != nil { 312 t.Fatal(err) 313 } 314 // We expect to find: 315 // * The keys in archive are the same as the latest version 316 // * The latest version is constant 317 // * The number of keys in the policy itself is from the min 318 // decryption version up to the latest version, so for e.g. 7 and 319 // 10, you'd need 7, 8, 9, and 10 -- IOW, latest version - min 320 // decryption version plus 1 (the min decryption version key 321 // itself) 322 checkKeys(t, ctx, p, storage, keysArchive, "minadd", 10, 10, p.LatestVersion-p.MinDecryptionVersion+1) 323 } 324 325 // Move the min decryption version down 326 for i := 10; i >= 1; i-- { 327 p.MinDecryptionVersion = i 328 329 err = p.Persist(ctx, storage) 330 if err != nil { 331 t.Fatal(err) 332 } 333 // We expect to find: 334 // * The keys in archive are never removed so same as the latest version 335 // * The latest version is constant 336 // * The number of keys in the policy itself is from the min 337 // decryption version up to the latest version, so for e.g. 7 and 338 // 10, you'd need 7, 8, 9, and 10 -- IOW, latest version - min 339 // decryption version plus 1 (the min decryption version key 340 // itself) 341 checkKeys(t, ctx, p, storage, keysArchive, "minsub", 10, 10, p.LatestVersion-p.MinDecryptionVersion+1) 342 } 343} 344 345func checkKeys(t *testing.T, 346 ctx context.Context, 347 p *Policy, 348 storage logical.Storage, 349 keysArchive []KeyEntry, 350 action string, 351 archiveVer, latestVer, keysSize int) { 352 353 // Sanity check 354 if len(keysArchive) != latestVer+1 { 355 t.Fatalf("latest expected key version is %d, expected test keys archive size is %d, "+ 356 "but keys archive is of size %d", latestVer, latestVer+1, len(keysArchive)) 357 } 358 359 archive, err := p.LoadArchive(ctx, storage) 360 if err != nil { 361 t.Fatal(err) 362 } 363 364 badArchiveVer := false 365 if archiveVer == 0 { 366 if len(archive.Keys) != 0 || p.ArchiveVersion != 0 { 367 badArchiveVer = true 368 } 369 } else { 370 // We need to subtract one because we have the indexes match key 371 // versions, which start at 1. So for an archive version of 1, we 372 // actually have two entries -- a blank 0 entry, and the key at spot 1 373 if archiveVer != len(archive.Keys)-1 || archiveVer != p.ArchiveVersion { 374 badArchiveVer = true 375 } 376 } 377 if badArchiveVer { 378 t.Fatalf( 379 "expected archive version %d, found length of archive keys %d and policy archive version %d", 380 archiveVer, len(archive.Keys), p.ArchiveVersion, 381 ) 382 } 383 384 if latestVer != p.LatestVersion { 385 t.Fatalf( 386 "expected latest version %d, found %d", 387 latestVer, p.LatestVersion, 388 ) 389 } 390 391 if keysSize != len(p.Keys) { 392 t.Fatalf( 393 "expected keys size %d, found %d, action is %s, policy is \n%#v\n", 394 keysSize, len(p.Keys), action, p, 395 ) 396 } 397 398 for i := p.MinDecryptionVersion; i <= p.LatestVersion; i++ { 399 if _, ok := p.Keys[strconv.Itoa(i)]; !ok { 400 t.Fatalf( 401 "expected key %d, did not find it in policy keys", i, 402 ) 403 } 404 } 405 406 for i := p.MinDecryptionVersion; i <= p.LatestVersion; i++ { 407 ver := strconv.Itoa(i) 408 if !p.Keys[ver].CreationTime.Equal(keysArchive[i].CreationTime) { 409 t.Fatalf("key %d not equivalent between policy keys and test keys archive; policy keys:\n%#v\ntest keys archive:\n%#v\n", i, p.Keys[ver], keysArchive[i]) 410 } 411 polKey := p.Keys[ver] 412 polKey.CreationTime = keysArchive[i].CreationTime 413 p.Keys[ver] = polKey 414 if !reflect.DeepEqual(p.Keys[ver], keysArchive[i]) { 415 t.Fatalf("key %d not equivalent between policy keys and test keys archive; policy keys:\n%#v\ntest keys archive:\n%#v\n", i, p.Keys[ver], keysArchive[i]) 416 } 417 } 418 419 for i := 1; i < len(archive.Keys); i++ { 420 if !reflect.DeepEqual(archive.Keys[i].Key, keysArchive[i].Key) { 421 t.Fatalf("key %d not equivalent between policy archive and test keys archive; policy archive:\n%#v\ntest keys archive:\n%#v\n", i, archive.Keys[i].Key, keysArchive[i].Key) 422 } 423 } 424} 425 426func Test_StorageErrorSafety(t *testing.T) { 427 ctx := context.Background() 428 lm, _ := NewLockManager(true, 0) 429 430 storage := &logical.InmemStorage{} 431 p, _, err := lm.GetPolicy(ctx, PolicyRequest{ 432 Upsert: true, 433 Storage: storage, 434 KeyType: KeyType_AES256_GCM96, 435 Name: "test", 436 }) 437 if err != nil { 438 t.Fatal(err) 439 } 440 if p == nil { 441 t.Fatal("nil policy") 442 } 443 444 // Store the initial key in the archive 445 keysArchive := []KeyEntry{KeyEntry{}, p.Keys["1"]} 446 checkKeys(t, ctx, p, storage, keysArchive, "initial", 1, 1, 1) 447 448 // We use checkKeys here just for sanity; it doesn't really handle cases of 449 // errors below so we do more targeted testing later 450 for i := 2; i <= 5; i++ { 451 err = p.Rotate(ctx, storage) 452 if err != nil { 453 t.Fatal(err) 454 } 455 keysArchive = append(keysArchive, p.Keys[strconv.Itoa(i)]) 456 checkKeys(t, ctx, p, storage, keysArchive, "rotate", i, i, i) 457 } 458 459 underlying := storage.Underlying() 460 underlying.FailPut(true) 461 462 priorLen := len(p.Keys) 463 464 err = p.Rotate(ctx, storage) 465 if err == nil { 466 t.Fatal("expected error") 467 } 468 469 if len(p.Keys) != priorLen { 470 t.Fatal("length of keys should not have changed") 471 } 472} 473 474func Test_BadUpgrade(t *testing.T) { 475 ctx := context.Background() 476 lm, _ := NewLockManager(true, 0) 477 storage := &logical.InmemStorage{} 478 p, _, err := lm.GetPolicy(ctx, PolicyRequest{ 479 Upsert: true, 480 Storage: storage, 481 KeyType: KeyType_AES256_GCM96, 482 Name: "test", 483 }) 484 if err != nil { 485 t.Fatal(err) 486 } 487 if p == nil { 488 t.Fatal("nil policy") 489 } 490 491 orig, err := copystructure.Copy(p) 492 if err != nil { 493 t.Fatal(err) 494 } 495 orig.(*Policy).l = p.l 496 497 p.Key = p.Keys["1"].Key 498 p.Keys = nil 499 p.MinDecryptionVersion = 0 500 501 if err := p.Upgrade(ctx, storage); err != nil { 502 t.Fatal(err) 503 } 504 505 k := p.Keys["1"] 506 o := orig.(*Policy).Keys["1"] 507 k.CreationTime = o.CreationTime 508 k.HMACKey = o.HMACKey 509 p.Keys["1"] = k 510 p.versionPrefixCache = sync.Map{} 511 512 if !reflect.DeepEqual(orig, p) { 513 t.Fatalf("not equal:\n%#v\n%#v", orig, p) 514 } 515 516 // Do it again with a failing storage call 517 underlying := storage.Underlying() 518 underlying.FailPut(true) 519 520 p.Key = p.Keys["1"].Key 521 p.Keys = nil 522 p.MinDecryptionVersion = 0 523 524 if err := p.Upgrade(ctx, storage); err == nil { 525 t.Fatal("expected error") 526 } 527 528 if p.MinDecryptionVersion == 1 { 529 t.Fatal("min decryption version was changed") 530 } 531 if p.Keys != nil { 532 t.Fatal("found upgraded keys") 533 } 534 if p.Key == nil { 535 t.Fatal("non-upgraded key not found") 536 } 537} 538 539func Test_BadArchive(t *testing.T) { 540 ctx := context.Background() 541 lm, _ := NewLockManager(true, 0) 542 storage := &logical.InmemStorage{} 543 p, _, err := lm.GetPolicy(ctx, PolicyRequest{ 544 Upsert: true, 545 Storage: storage, 546 KeyType: KeyType_AES256_GCM96, 547 Name: "test", 548 }) 549 if err != nil { 550 t.Fatal(err) 551 } 552 if p == nil { 553 t.Fatal("nil policy") 554 } 555 556 for i := 2; i <= 10; i++ { 557 err = p.Rotate(ctx, storage) 558 if err != nil { 559 t.Fatal(err) 560 } 561 } 562 563 p.MinDecryptionVersion = 5 564 if err := p.Persist(ctx, storage); err != nil { 565 t.Fatal(err) 566 } 567 if p.ArchiveVersion != 10 { 568 t.Fatalf("unexpected archive version %d", p.ArchiveVersion) 569 } 570 if len(p.Keys) != 6 { 571 t.Fatalf("unexpected key length %d", len(p.Keys)) 572 } 573 574 // Set back 575 p.MinDecryptionVersion = 1 576 if err := p.Persist(ctx, storage); err != nil { 577 t.Fatal(err) 578 } 579 if p.ArchiveVersion != 10 { 580 t.Fatalf("unexpected archive version %d", p.ArchiveVersion) 581 } 582 if len(p.Keys) != 10 { 583 t.Fatalf("unexpected key length %d", len(p.Keys)) 584 } 585 586 // Run it again but we'll turn off storage along the way 587 p.MinDecryptionVersion = 5 588 if err := p.Persist(ctx, storage); err != nil { 589 t.Fatal(err) 590 } 591 if p.ArchiveVersion != 10 { 592 t.Fatalf("unexpected archive version %d", p.ArchiveVersion) 593 } 594 if len(p.Keys) != 6 { 595 t.Fatalf("unexpected key length %d", len(p.Keys)) 596 } 597 598 underlying := storage.Underlying() 599 underlying.FailPut(true) 600 601 // Set back, which should cause p.Keys to be changed if the persist works, 602 // but it doesn't 603 p.MinDecryptionVersion = 1 604 if err := p.Persist(ctx, storage); err == nil { 605 t.Fatal("expected error during put") 606 } 607 if p.ArchiveVersion != 10 { 608 t.Fatalf("unexpected archive version %d", p.ArchiveVersion) 609 } 610 // Here's the expected change 611 if len(p.Keys) != 6 { 612 t.Fatalf("unexpected key length %d", len(p.Keys)) 613 } 614} 615