1package keysutil 2 3import ( 4 "context" 5 "encoding/base64" 6 "errors" 7 "fmt" 8 "sync" 9 "sync/atomic" 10 "time" 11 12 "github.com/hashicorp/errwrap" 13 "github.com/hashicorp/vault/sdk/helper/jsonutil" 14 "github.com/hashicorp/vault/sdk/helper/locksutil" 15 "github.com/hashicorp/vault/sdk/logical" 16) 17 18const ( 19 shared = false 20 exclusive = true 21 currentConvergentVersion = 3 22) 23 24var ( 25 errNeedExclusiveLock = errors.New("an exclusive lock is needed for this operation") 26) 27 28// PolicyRequest holds values used when requesting a policy. Most values are 29// only used during an upsert. 30type PolicyRequest struct { 31 // The storage to use 32 Storage logical.Storage 33 34 // The name of the policy 35 Name string 36 37 // The key type 38 KeyType KeyType 39 40 // Whether it should be derived 41 Derived bool 42 43 // Whether to enable convergent encryption 44 Convergent bool 45 46 // Whether to allow export 47 Exportable bool 48 49 // Whether to upsert 50 Upsert bool 51 52 // Whether to allow plaintext backup 53 AllowPlaintextBackup bool 54} 55 56type LockManager struct { 57 useCache bool 58 cache Cache 59 keyLocks []*locksutil.LockEntry 60} 61 62func NewLockManager(useCache bool, cacheSize int) (*LockManager, error) { 63 // determine the type of cache to create 64 var cache Cache 65 switch { 66 case !useCache: 67 case cacheSize < 0: 68 return nil, errors.New("cache size must be greater or equal to zero") 69 case cacheSize == 0: 70 cache = NewTransitSyncMap() 71 case cacheSize > 0: 72 newLRUCache, err := NewTransitLRU(cacheSize) 73 if err != nil { 74 return nil, errwrap.Wrapf("failed to create cache: {{err}}", err) 75 } 76 cache = newLRUCache 77 } 78 79 lm := &LockManager{ 80 useCache: useCache, 81 cache: cache, 82 keyLocks: locksutil.CreateLocks(), 83 } 84 85 return lm, nil 86} 87 88func (lm *LockManager) GetCacheSize() int { 89 if !lm.useCache { 90 return 0 91 } 92 return lm.cache.Size() 93} 94 95func (lm *LockManager) GetUseCache() bool { 96 return lm.useCache 97} 98 99func (lm *LockManager) InvalidatePolicy(name string) { 100 if lm.useCache { 101 lm.cache.Delete(name) 102 } 103} 104 105// RestorePolicy acquires an exclusive lock on the policy name and restores the 106// given policy along with the archive. 107func (lm *LockManager) RestorePolicy(ctx context.Context, storage logical.Storage, name, backup string, force bool) error { 108 backupBytes, err := base64.StdEncoding.DecodeString(backup) 109 if err != nil { 110 return err 111 } 112 113 var keyData KeyData 114 err = jsonutil.DecodeJSON(backupBytes, &keyData) 115 if err != nil { 116 return err 117 } 118 119 // Set a different name if desired 120 if name != "" { 121 keyData.Policy.Name = name 122 } 123 124 name = keyData.Policy.Name 125 126 // Grab the exclusive lock as we'll be modifying disk 127 lock := locksutil.LockForKey(lm.keyLocks, name) 128 lock.Lock() 129 defer lock.Unlock() 130 131 var ok bool 132 var pRaw interface{} 133 134 // If the policy is in cache and 'force' is not specified, error out. Anywhere 135 // that would put it in the cache will also be protected by the mutex above, 136 // so we don't need to re-check the cache later. 137 if lm.useCache { 138 pRaw, ok = lm.cache.Load(name) 139 if ok && !force { 140 return fmt.Errorf("key %q already exists", name) 141 } 142 } 143 144 // Conditionally look up the policy from storage, depending on the use of 145 // 'force' and if the policy was found in cache. 146 // 147 // - If was not found in cache and we are not using 'force', look for it in 148 // storage. If found, error out. 149 // 150 // - If it was found in cache and we are using 'force', pRaw will not be nil 151 // and we do not look the policy up from storage 152 // 153 // - If it was found in cache and we are not using 'force', we should have 154 // returned above with error 155 var p *Policy 156 if pRaw == nil { 157 p, err = lm.getPolicyFromStorage(ctx, storage, name) 158 if err != nil { 159 return err 160 } 161 if p != nil && !force { 162 return fmt.Errorf("key %q already exists", name) 163 } 164 } 165 166 // If both pRaw and p above are nil and 'force' is specified, we don't need to 167 // grab policy locks as we have ensured it doesn't already exist, so there 168 // will be no races as nothing else has this pointer. If 'force' was not used, 169 // an error would have been returned by now if the policy already existed 170 if pRaw != nil { 171 p = pRaw.(*Policy) 172 } 173 if p != nil { 174 p.l.Lock() 175 defer p.l.Unlock() 176 } 177 178 // Restore the archived keys 179 if keyData.ArchivedKeys != nil { 180 err = keyData.Policy.storeArchive(ctx, storage, keyData.ArchivedKeys) 181 if err != nil { 182 return errwrap.Wrapf(fmt.Sprintf("failed to restore archived keys for key %q: {{err}}", name), err) 183 } 184 } 185 186 // Mark that policy as a restored key 187 keyData.Policy.RestoreInfo = &RestoreInfo{ 188 Time: time.Now(), 189 Version: keyData.Policy.LatestVersion, 190 } 191 192 // Restore the policy. This will also attempt to adjust the archive. 193 err = keyData.Policy.Persist(ctx, storage) 194 if err != nil { 195 return errwrap.Wrapf(fmt.Sprintf("failed to restore the policy %q: {{err}}", name), err) 196 } 197 198 keyData.Policy.l = new(sync.RWMutex) 199 200 // Update the cache to contain the restored policy 201 if lm.useCache { 202 lm.cache.Store(name, keyData.Policy) 203 } 204 return nil 205} 206 207func (lm *LockManager) BackupPolicy(ctx context.Context, storage logical.Storage, name string) (string, error) { 208 var p *Policy 209 var err error 210 211 // Backup writes information about when the backup took place, so we get an 212 // exclusive lock here 213 lock := locksutil.LockForKey(lm.keyLocks, name) 214 lock.Lock() 215 defer lock.Unlock() 216 217 var ok bool 218 var pRaw interface{} 219 220 if lm.useCache { 221 pRaw, ok = lm.cache.Load(name) 222 } 223 if ok { 224 p = pRaw.(*Policy) 225 p.l.Lock() 226 defer p.l.Unlock() 227 } else { 228 // If the policy doesn't exit in storage, error out 229 p, err = lm.getPolicyFromStorage(ctx, storage, name) 230 if err != nil { 231 return "", err 232 } 233 if p == nil { 234 return "", fmt.Errorf(fmt.Sprintf("key %q not found", name)) 235 } 236 } 237 238 if atomic.LoadUint32(&p.deleted) == 1 { 239 return "", fmt.Errorf(fmt.Sprintf("key %q not found", name)) 240 } 241 242 backup, err := p.Backup(ctx, storage) 243 if err != nil { 244 return "", err 245 } 246 247 return backup, nil 248} 249 250// When the function returns, if caching was disabled, the Policy's lock must 251// be unlocked when the caller is done (and it should not be re-locked). 252func (lm *LockManager) GetPolicy(ctx context.Context, req PolicyRequest) (retP *Policy, retUpserted bool, retErr error) { 253 var p *Policy 254 var err error 255 var ok bool 256 var pRaw interface{} 257 258 // Check if it's in our cache. If so, return right away. 259 if lm.useCache { 260 pRaw, ok = lm.cache.Load(req.Name) 261 } 262 if ok { 263 p = pRaw.(*Policy) 264 if atomic.LoadUint32(&p.deleted) == 1 { 265 return nil, false, nil 266 } 267 return p, false, nil 268 } 269 270 // We're not using the cache, or it wasn't found; get an exclusive lock. 271 // This ensures that any other process writing the actual storage will be 272 // finished before we load from storage. 273 lock := locksutil.LockForKey(lm.keyLocks, req.Name) 274 lock.Lock() 275 276 // If we are using the cache, defer the lock unlock; otherwise we will 277 // return from here with the lock still held. 278 cleanup := func() { 279 switch { 280 // If using the cache we always unlock, the caller locks the policy 281 // themselves 282 case lm.useCache: 283 lock.Unlock() 284 285 // If not using the cache, if we aren't returning a policy the caller 286 // doesn't have a lock, so we must unlock 287 case retP == nil: 288 lock.Unlock() 289 } 290 } 291 292 // Check the cache again 293 if lm.useCache { 294 pRaw, ok = lm.cache.Load(req.Name) 295 } 296 if ok { 297 p = pRaw.(*Policy) 298 if atomic.LoadUint32(&p.deleted) == 1 { 299 cleanup() 300 return nil, false, nil 301 } 302 retP = p 303 cleanup() 304 return 305 } 306 307 // Load it from storage 308 p, err = lm.getPolicyFromStorage(ctx, req.Storage, req.Name) 309 if err != nil { 310 cleanup() 311 return nil, false, err 312 } 313 // We don't need to lock the policy as there would be no other holders of 314 // the pointer 315 316 if p == nil { 317 // This is the only place we upsert a new policy, so if upsert is not 318 // specified, or the lock type is wrong, unlock before returning 319 if !req.Upsert { 320 cleanup() 321 return nil, false, nil 322 } 323 324 // We create the policy here, then at the end we do a LoadOrStore. If 325 // it's been loaded since we last checked the cache, we return an error 326 // to the user to let them know that their request can't be satisfied 327 // because we don't know if the parameters match. 328 329 switch req.KeyType { 330 case KeyType_AES256_GCM96, KeyType_ChaCha20_Poly1305: 331 if req.Convergent && !req.Derived { 332 cleanup() 333 return nil, false, fmt.Errorf("convergent encryption requires derivation to be enabled") 334 } 335 336 case KeyType_ECDSA_P256: 337 if req.Derived || req.Convergent { 338 cleanup() 339 return nil, false, fmt.Errorf("key derivation and convergent encryption not supported for keys of type %v", req.KeyType) 340 } 341 342 case KeyType_ED25519: 343 if req.Convergent { 344 cleanup() 345 return nil, false, fmt.Errorf("convergent encryption not supported for keys of type %v", req.KeyType) 346 } 347 348 case KeyType_RSA2048, KeyType_RSA4096: 349 if req.Derived || req.Convergent { 350 cleanup() 351 return nil, false, fmt.Errorf("key derivation and convergent encryption not supported for keys of type %v", req.KeyType) 352 } 353 354 default: 355 cleanup() 356 return nil, false, fmt.Errorf("unsupported key type %v", req.KeyType) 357 } 358 359 p = &Policy{ 360 l: new(sync.RWMutex), 361 Name: req.Name, 362 Type: req.KeyType, 363 Derived: req.Derived, 364 Exportable: req.Exportable, 365 AllowPlaintextBackup: req.AllowPlaintextBackup, 366 } 367 368 if req.Derived { 369 p.KDF = Kdf_hkdf_sha256 370 if req.Convergent { 371 p.ConvergentEncryption = true 372 // As of version 3 we store the version within each key, so we 373 // set to -1 to indicate that the value in the policy has no 374 // meaning. We still, for backwards compatibility, fall back to 375 // this value if the key doesn't have one, which means it will 376 // only be -1 in the case where every key version is >= 3 377 p.ConvergentVersion = -1 378 } 379 } 380 381 // Performs the actual persist and does setup 382 err = p.Rotate(ctx, req.Storage) 383 if err != nil { 384 cleanup() 385 return nil, false, err 386 } 387 388 if lm.useCache { 389 lm.cache.Store(req.Name, p) 390 } else { 391 p.l = &lock.RWMutex 392 p.writeLocked = true 393 } 394 395 // We don't need to worry about upgrading since it will be a new policy 396 retP = p 397 retUpserted = true 398 cleanup() 399 return 400 } 401 402 if p.NeedsUpgrade() { 403 if err := p.Upgrade(ctx, req.Storage); err != nil { 404 cleanup() 405 return nil, false, err 406 } 407 } 408 409 if lm.useCache { 410 lm.cache.Store(req.Name, p) 411 } else { 412 p.l = &lock.RWMutex 413 p.writeLocked = true 414 } 415 416 retP = p 417 cleanup() 418 return 419} 420 421func (lm *LockManager) DeletePolicy(ctx context.Context, storage logical.Storage, name string) error { 422 var p *Policy 423 var err error 424 var ok bool 425 var pRaw interface{} 426 427 // We may be writing to disk, so grab an exclusive lock. This prevents bad 428 // behavior when the cache is turned off. We also lock the shared policy 429 // object to make sure no requests are in flight. 430 lock := locksutil.LockForKey(lm.keyLocks, name) 431 lock.Lock() 432 defer lock.Unlock() 433 434 if lm.useCache { 435 pRaw, ok = lm.cache.Load(name) 436 } 437 if ok { 438 p = pRaw.(*Policy) 439 p.l.Lock() 440 defer p.l.Unlock() 441 } 442 443 if p == nil { 444 p, err = lm.getPolicyFromStorage(ctx, storage, name) 445 if err != nil { 446 return err 447 } 448 if p == nil { 449 return fmt.Errorf("could not delete key; not found") 450 } 451 } 452 453 if !p.DeletionAllowed { 454 return fmt.Errorf("deletion is not allowed for this key") 455 } 456 457 atomic.StoreUint32(&p.deleted, 1) 458 459 if lm.useCache { 460 lm.cache.Delete(name) 461 } 462 463 err = storage.Delete(ctx, "policy/"+name) 464 if err != nil { 465 return errwrap.Wrapf(fmt.Sprintf("error deleting key %q: {{err}}", name), err) 466 } 467 468 err = storage.Delete(ctx, "archive/"+name) 469 if err != nil { 470 return errwrap.Wrapf(fmt.Sprintf("error deleting key %q archive: {{err}}", name), err) 471 } 472 473 return nil 474} 475 476func (lm *LockManager) getPolicyFromStorage(ctx context.Context, storage logical.Storage, name string) (*Policy, error) { 477 return LoadPolicy(ctx, storage, "policy/"+name) 478} 479