1package keysutil 2 3import ( 4 "context" 5 "crypto/rand" 6 "reflect" 7 "strconv" 8 "sync" 9 "testing" 10 "time" 11 12 "github.com/hashicorp/vault/sdk/helper/jsonutil" 13 "github.com/hashicorp/vault/sdk/logical" 14 "github.com/mitchellh/copystructure" 15) 16 17func TestPolicy_KeyEntryMapUpgrade(t *testing.T) { 18 now := time.Now() 19 old := map[int]KeyEntry{ 20 1: { 21 Key: []byte("samplekey"), 22 HMACKey: []byte("samplehmackey"), 23 CreationTime: now, 24 FormattedPublicKey: "sampleformattedpublickey", 25 }, 26 2: { 27 Key: []byte("samplekey2"), 28 HMACKey: []byte("samplehmackey2"), 29 CreationTime: now.Add(10 * time.Second), 30 FormattedPublicKey: "sampleformattedpublickey2", 31 }, 32 } 33 34 oldEncoded, err := jsonutil.EncodeJSON(old) 35 if err != nil { 36 t.Fatal(err) 37 } 38 39 var new keyEntryMap 40 err = jsonutil.DecodeJSON(oldEncoded, &new) 41 if err != nil { 42 t.Fatal(err) 43 } 44 45 newEncoded, err := jsonutil.EncodeJSON(&new) 46 if err != nil { 47 t.Fatal(err) 48 } 49 50 if string(oldEncoded) != string(newEncoded) { 51 t.Fatalf("failed to upgrade key entry map;\nold: %q\nnew: %q", string(oldEncoded), string(newEncoded)) 52 } 53} 54 55func Test_KeyUpgrade(t *testing.T) { 56 lockManagerWithCache, _ := NewLockManager(true, 0) 57 lockManagerWithoutCache, _ := NewLockManager(false, 0) 58 testKeyUpgradeCommon(t, lockManagerWithCache) 59 testKeyUpgradeCommon(t, lockManagerWithoutCache) 60} 61 62func testKeyUpgradeCommon(t *testing.T, lm *LockManager) { 63 ctx := context.Background() 64 65 storage := &logical.InmemStorage{} 66 p, upserted, err := lm.GetPolicy(ctx, PolicyRequest{ 67 Upsert: true, 68 Storage: storage, 69 KeyType: KeyType_AES256_GCM96, 70 Name: "test", 71 }, rand.Reader) 72 if err != nil { 73 t.Fatal(err) 74 } 75 if p == nil { 76 t.Fatal("nil policy") 77 } 78 if !upserted { 79 t.Fatal("expected an upsert") 80 } 81 if !lm.useCache { 82 p.Unlock() 83 } 84 85 testBytes := make([]byte, len(p.Keys["1"].Key)) 86 copy(testBytes, p.Keys["1"].Key) 87 88 p.Key = p.Keys["1"].Key 89 p.Keys = nil 90 p.MigrateKeyToKeysMap() 91 if p.Key != nil { 92 t.Fatal("policy.Key is not nil") 93 } 94 if len(p.Keys) != 1 { 95 t.Fatal("policy.Keys is the wrong size") 96 } 97 if !reflect.DeepEqual(testBytes, p.Keys["1"].Key) { 98 t.Fatal("key mismatch") 99 } 100} 101 102func Test_ArchivingUpgrade(t *testing.T) { 103 lockManagerWithCache, _ := NewLockManager(true, 0) 104 lockManagerWithoutCache, _ := NewLockManager(false, 0) 105 testArchivingUpgradeCommon(t, lockManagerWithCache) 106 testArchivingUpgradeCommon(t, lockManagerWithoutCache) 107} 108 109func testArchivingUpgradeCommon(t *testing.T, lm *LockManager) { 110 ctx := context.Background() 111 112 // First, we generate a policy and rotate it a number of times. Each time 113 // we'll ensure that we have the expected number of keys in the archive and 114 // the main keys object, which without changing the min version should be 115 // zero and latest, respectively 116 117 storage := &logical.InmemStorage{} 118 p, _, err := lm.GetPolicy(ctx, PolicyRequest{ 119 Upsert: true, 120 Storage: storage, 121 KeyType: KeyType_AES256_GCM96, 122 Name: "test", 123 }, rand.Reader) 124 if err != nil { 125 t.Fatal(err) 126 } 127 if p == nil { 128 t.Fatal("nil policy") 129 } 130 if !lm.useCache { 131 p.Unlock() 132 } 133 134 // Store the initial key in the archive 135 keysArchive := []KeyEntry{KeyEntry{}, p.Keys["1"]} 136 checkKeys(t, ctx, p, storage, keysArchive, "initial", 1, 1, 1) 137 138 for i := 2; i <= 10; i++ { 139 err = p.Rotate(ctx, storage, rand.Reader) 140 if err != nil { 141 t.Fatal(err) 142 } 143 keysArchive = append(keysArchive, p.Keys[strconv.Itoa(i)]) 144 checkKeys(t, ctx, p, storage, keysArchive, "rotate", i, i, i) 145 } 146 147 // Now, wipe the archive and set the archive version to zero 148 err = storage.Delete(ctx, "archive/test") 149 if err != nil { 150 t.Fatal(err) 151 } 152 p.ArchiveVersion = 0 153 154 // Store it, but without calling persist, so we don't trigger 155 // handleArchiving() 156 buf, err := p.Serialize() 157 if err != nil { 158 t.Fatal(err) 159 } 160 161 // Write the policy into storage 162 err = storage.Put(ctx, &logical.StorageEntry{ 163 Key: "policy/" + p.Name, 164 Value: buf, 165 }) 166 if err != nil { 167 t.Fatal(err) 168 } 169 170 // If we're caching, expire from the cache since we modified it 171 // under-the-hood 172 if lm.useCache { 173 lm.cache.Delete("test") 174 } 175 176 // Now get the policy again; the upgrade should happen automatically 177 p, _, err = lm.GetPolicy(ctx, PolicyRequest{ 178 Storage: storage, 179 Name: "test", 180 }, rand.Reader) 181 if err != nil { 182 t.Fatal(err) 183 } 184 if p == nil { 185 t.Fatal("nil policy") 186 } 187 if !lm.useCache { 188 p.Unlock() 189 } 190 191 checkKeys(t, ctx, p, storage, keysArchive, "upgrade", 10, 10, 10) 192 193 // Let's check some deletion logic while we're at it 194 195 // The policy should be in there 196 if lm.useCache { 197 _, ok := lm.cache.Load("test") 198 if !ok { 199 t.Fatal("nil policy in cache") 200 } 201 } 202 203 // First we'll do this wrong, by not setting the deletion flag 204 err = lm.DeletePolicy(ctx, storage, "test") 205 if err == nil { 206 t.Fatal("got nil error, but should not have been able to delete since we didn't set the deletion flag on the policy") 207 } 208 209 // The policy should still be in there 210 if lm.useCache { 211 _, ok := lm.cache.Load("test") 212 if !ok { 213 t.Fatal("nil policy in cache") 214 } 215 } 216 217 p, _, err = lm.GetPolicy(ctx, PolicyRequest{ 218 Storage: storage, 219 Name: "test", 220 }, rand.Reader) 221 if err != nil { 222 t.Fatal(err) 223 } 224 if p == nil { 225 t.Fatal("policy nil after bad delete") 226 } 227 if !lm.useCache { 228 p.Unlock() 229 } 230 231 // Now do it properly 232 p.DeletionAllowed = true 233 err = p.Persist(ctx, storage) 234 if err != nil { 235 t.Fatal(err) 236 } 237 err = lm.DeletePolicy(ctx, storage, "test") 238 if err != nil { 239 t.Fatal(err) 240 } 241 242 // The policy should *not* be in there 243 if lm.useCache { 244 _, ok := lm.cache.Load("test") 245 if ok { 246 t.Fatal("non-nil policy in cache") 247 } 248 } 249 250 p, _, err = lm.GetPolicy(ctx, PolicyRequest{ 251 Storage: storage, 252 Name: "test", 253 }, rand.Reader) 254 if err != nil { 255 t.Fatal(err) 256 } 257 if p != nil { 258 t.Fatal("policy not nil after delete") 259 } 260} 261 262func Test_Archiving(t *testing.T) { 263 lockManagerWithCache, _ := NewLockManager(true, 0) 264 lockManagerWithoutCache, _ := NewLockManager(false, 0) 265 testArchivingUpgradeCommon(t, lockManagerWithCache) 266 testArchivingUpgradeCommon(t, lockManagerWithoutCache) 267} 268 269func testArchivingCommon(t *testing.T, lm *LockManager) { 270 ctx := context.Background() 271 272 // First, we generate a policy and rotate it a number of times. Each time 273 // we'll ensure that we have the expected number of keys in the archive and 274 // the main keys object, which without changing the min version should be 275 // zero and latest, respectively 276 277 storage := &logical.InmemStorage{} 278 p, _, err := lm.GetPolicy(ctx, PolicyRequest{ 279 Upsert: true, 280 Storage: storage, 281 KeyType: KeyType_AES256_GCM96, 282 Name: "test", 283 }, rand.Reader) 284 if err != nil { 285 t.Fatal(err) 286 } 287 if p == nil { 288 t.Fatal("nil policy") 289 } 290 if !lm.useCache { 291 p.Unlock() 292 } 293 294 // Store the initial key in the archive 295 keysArchive := []KeyEntry{KeyEntry{}, p.Keys["1"]} 296 checkKeys(t, ctx, p, storage, keysArchive, "initial", 1, 1, 1) 297 298 for i := 2; i <= 10; i++ { 299 err = p.Rotate(ctx, storage, rand.Reader) 300 if err != nil { 301 t.Fatal(err) 302 } 303 keysArchive = append(keysArchive, p.Keys[strconv.Itoa(i)]) 304 checkKeys(t, ctx, p, storage, keysArchive, "rotate", i, i, i) 305 } 306 307 // Move the min decryption version up 308 for i := 1; i <= 10; i++ { 309 p.MinDecryptionVersion = i 310 311 err = p.Persist(ctx, storage) 312 if err != nil { 313 t.Fatal(err) 314 } 315 // We expect to find: 316 // * The keys in archive are the same as the latest version 317 // * The latest version is constant 318 // * The number of keys in the policy itself is from the min 319 // decryption version up to the latest version, so for e.g. 7 and 320 // 10, you'd need 7, 8, 9, and 10 -- IOW, latest version - min 321 // decryption version plus 1 (the min decryption version key 322 // itself) 323 checkKeys(t, ctx, p, storage, keysArchive, "minadd", 10, 10, p.LatestVersion-p.MinDecryptionVersion+1) 324 } 325 326 // Move the min decryption version down 327 for i := 10; i >= 1; i-- { 328 p.MinDecryptionVersion = i 329 330 err = p.Persist(ctx, storage) 331 if err != nil { 332 t.Fatal(err) 333 } 334 // We expect to find: 335 // * The keys in archive are never removed so same as the latest version 336 // * The latest version is constant 337 // * The number of keys in the policy itself is from the min 338 // decryption version up to the latest version, so for e.g. 7 and 339 // 10, you'd need 7, 8, 9, and 10 -- IOW, latest version - min 340 // decryption version plus 1 (the min decryption version key 341 // itself) 342 checkKeys(t, ctx, p, storage, keysArchive, "minsub", 10, 10, p.LatestVersion-p.MinDecryptionVersion+1) 343 } 344} 345 346func checkKeys(t *testing.T, 347 ctx context.Context, 348 p *Policy, 349 storage logical.Storage, 350 keysArchive []KeyEntry, 351 action string, 352 archiveVer, latestVer, keysSize int) { 353 354 // Sanity check 355 if len(keysArchive) != latestVer+1 { 356 t.Fatalf("latest expected key version is %d, expected test keys archive size is %d, "+ 357 "but keys archive is of size %d", latestVer, latestVer+1, len(keysArchive)) 358 } 359 360 archive, err := p.LoadArchive(ctx, storage) 361 if err != nil { 362 t.Fatal(err) 363 } 364 365 badArchiveVer := false 366 if archiveVer == 0 { 367 if len(archive.Keys) != 0 || p.ArchiveVersion != 0 { 368 badArchiveVer = true 369 } 370 } else { 371 // We need to subtract one because we have the indexes match key 372 // versions, which start at 1. So for an archive version of 1, we 373 // actually have two entries -- a blank 0 entry, and the key at spot 1 374 if archiveVer != len(archive.Keys)-1 || archiveVer != p.ArchiveVersion { 375 badArchiveVer = true 376 } 377 } 378 if badArchiveVer { 379 t.Fatalf( 380 "expected archive version %d, found length of archive keys %d and policy archive version %d", 381 archiveVer, len(archive.Keys), p.ArchiveVersion, 382 ) 383 } 384 385 if latestVer != p.LatestVersion { 386 t.Fatalf( 387 "expected latest version %d, found %d", 388 latestVer, p.LatestVersion, 389 ) 390 } 391 392 if keysSize != len(p.Keys) { 393 t.Fatalf( 394 "expected keys size %d, found %d, action is %s, policy is \n%#v\n", 395 keysSize, len(p.Keys), action, p, 396 ) 397 } 398 399 for i := p.MinDecryptionVersion; i <= p.LatestVersion; i++ { 400 if _, ok := p.Keys[strconv.Itoa(i)]; !ok { 401 t.Fatalf( 402 "expected key %d, did not find it in policy keys", i, 403 ) 404 } 405 } 406 407 for i := p.MinDecryptionVersion; i <= p.LatestVersion; i++ { 408 ver := strconv.Itoa(i) 409 if !p.Keys[ver].CreationTime.Equal(keysArchive[i].CreationTime) { 410 t.Fatalf("key %d not equivalent between policy keys and test keys archive; policy keys:\n%#v\ntest keys archive:\n%#v\n", i, p.Keys[ver], keysArchive[i]) 411 } 412 polKey := p.Keys[ver] 413 polKey.CreationTime = keysArchive[i].CreationTime 414 p.Keys[ver] = polKey 415 if !reflect.DeepEqual(p.Keys[ver], keysArchive[i]) { 416 t.Fatalf("key %d not equivalent between policy keys and test keys archive; policy keys:\n%#v\ntest keys archive:\n%#v\n", i, p.Keys[ver], keysArchive[i]) 417 } 418 } 419 420 for i := 1; i < len(archive.Keys); i++ { 421 if !reflect.DeepEqual(archive.Keys[i].Key, keysArchive[i].Key) { 422 t.Fatalf("key %d not equivalent between policy archive and test keys archive; policy archive:\n%#v\ntest keys archive:\n%#v\n", i, archive.Keys[i].Key, keysArchive[i].Key) 423 } 424 } 425} 426 427func Test_StorageErrorSafety(t *testing.T) { 428 ctx := context.Background() 429 lm, _ := NewLockManager(true, 0) 430 431 storage := &logical.InmemStorage{} 432 p, _, err := lm.GetPolicy(ctx, PolicyRequest{ 433 Upsert: true, 434 Storage: storage, 435 KeyType: KeyType_AES256_GCM96, 436 Name: "test", 437 }, rand.Reader) 438 if err != nil { 439 t.Fatal(err) 440 } 441 if p == nil { 442 t.Fatal("nil policy") 443 } 444 445 // Store the initial key in the archive 446 keysArchive := []KeyEntry{KeyEntry{}, p.Keys["1"]} 447 checkKeys(t, ctx, p, storage, keysArchive, "initial", 1, 1, 1) 448 449 // We use checkKeys here just for sanity; it doesn't really handle cases of 450 // errors below so we do more targeted testing later 451 for i := 2; i <= 5; i++ { 452 err = p.Rotate(ctx, storage, rand.Reader) 453 if err != nil { 454 t.Fatal(err) 455 } 456 keysArchive = append(keysArchive, p.Keys[strconv.Itoa(i)]) 457 checkKeys(t, ctx, p, storage, keysArchive, "rotate", i, i, i) 458 } 459 460 underlying := storage.Underlying() 461 underlying.FailPut(true) 462 463 priorLen := len(p.Keys) 464 465 err = p.Rotate(ctx, storage, rand.Reader) 466 if err == nil { 467 t.Fatal("expected error") 468 } 469 470 if len(p.Keys) != priorLen { 471 t.Fatal("length of keys should not have changed") 472 } 473} 474 475func Test_BadUpgrade(t *testing.T) { 476 ctx := context.Background() 477 lm, _ := NewLockManager(true, 0) 478 storage := &logical.InmemStorage{} 479 p, _, err := lm.GetPolicy(ctx, PolicyRequest{ 480 Upsert: true, 481 Storage: storage, 482 KeyType: KeyType_AES256_GCM96, 483 Name: "test", 484 }, rand.Reader) 485 if err != nil { 486 t.Fatal(err) 487 } 488 if p == nil { 489 t.Fatal("nil policy") 490 } 491 492 orig, err := copystructure.Copy(p) 493 if err != nil { 494 t.Fatal(err) 495 } 496 orig.(*Policy).l = p.l 497 498 p.Key = p.Keys["1"].Key 499 p.Keys = nil 500 p.MinDecryptionVersion = 0 501 502 if err := p.Upgrade(ctx, storage, rand.Reader); err != nil { 503 t.Fatal(err) 504 } 505 506 k := p.Keys["1"] 507 o := orig.(*Policy).Keys["1"] 508 k.CreationTime = o.CreationTime 509 k.HMACKey = o.HMACKey 510 p.Keys["1"] = k 511 p.versionPrefixCache = sync.Map{} 512 513 if !reflect.DeepEqual(orig, p) { 514 t.Fatalf("not equal:\n%#v\n%#v", orig, p) 515 } 516 517 // Do it again with a failing storage call 518 underlying := storage.Underlying() 519 underlying.FailPut(true) 520 521 p.Key = p.Keys["1"].Key 522 p.Keys = nil 523 p.MinDecryptionVersion = 0 524 525 if err := p.Upgrade(ctx, storage, rand.Reader); err == nil { 526 t.Fatal("expected error") 527 } 528 529 if p.MinDecryptionVersion == 1 { 530 t.Fatal("min decryption version was changed") 531 } 532 if p.Keys != nil { 533 t.Fatal("found upgraded keys") 534 } 535 if p.Key == nil { 536 t.Fatal("non-upgraded key not found") 537 } 538} 539 540func Test_BadArchive(t *testing.T) { 541 ctx := context.Background() 542 lm, _ := NewLockManager(true, 0) 543 storage := &logical.InmemStorage{} 544 p, _, err := lm.GetPolicy(ctx, PolicyRequest{ 545 Upsert: true, 546 Storage: storage, 547 KeyType: KeyType_AES256_GCM96, 548 Name: "test", 549 }, rand.Reader) 550 if err != nil { 551 t.Fatal(err) 552 } 553 if p == nil { 554 t.Fatal("nil policy") 555 } 556 557 for i := 2; i <= 10; i++ { 558 err = p.Rotate(ctx, storage, rand.Reader) 559 if err != nil { 560 t.Fatal(err) 561 } 562 } 563 564 p.MinDecryptionVersion = 5 565 if err := p.Persist(ctx, storage); err != nil { 566 t.Fatal(err) 567 } 568 if p.ArchiveVersion != 10 { 569 t.Fatalf("unexpected archive version %d", p.ArchiveVersion) 570 } 571 if len(p.Keys) != 6 { 572 t.Fatalf("unexpected key length %d", len(p.Keys)) 573 } 574 575 // Set back 576 p.MinDecryptionVersion = 1 577 if err := p.Persist(ctx, storage); err != nil { 578 t.Fatal(err) 579 } 580 if p.ArchiveVersion != 10 { 581 t.Fatalf("unexpected archive version %d", p.ArchiveVersion) 582 } 583 if len(p.Keys) != 10 { 584 t.Fatalf("unexpected key length %d", len(p.Keys)) 585 } 586 587 // Run it again but we'll turn off storage along the way 588 p.MinDecryptionVersion = 5 589 if err := p.Persist(ctx, storage); err != nil { 590 t.Fatal(err) 591 } 592 if p.ArchiveVersion != 10 { 593 t.Fatalf("unexpected archive version %d", p.ArchiveVersion) 594 } 595 if len(p.Keys) != 6 { 596 t.Fatalf("unexpected key length %d", len(p.Keys)) 597 } 598 599 underlying := storage.Underlying() 600 underlying.FailPut(true) 601 602 // Set back, which should cause p.Keys to be changed if the persist works, 603 // but it doesn't 604 p.MinDecryptionVersion = 1 605 if err := p.Persist(ctx, storage); err == nil { 606 t.Fatal("expected error during put") 607 } 608 if p.ArchiveVersion != 10 { 609 t.Fatalf("unexpected archive version %d", p.ArchiveVersion) 610 } 611 // Here's the expected change 612 if len(p.Keys) != 6 { 613 t.Fatalf("unexpected key length %d", len(p.Keys)) 614 } 615} 616