1package storagepacker 2 3import ( 4 "context" 5 "crypto/md5" 6 "fmt" 7 "strconv" 8 "strings" 9 10 "github.com/golang/protobuf/proto" 11 "github.com/hashicorp/errwrap" 12 "github.com/hashicorp/go-hclog" 13 log "github.com/hashicorp/go-hclog" 14 "github.com/hashicorp/vault/sdk/helper/compressutil" 15 "github.com/hashicorp/vault/sdk/helper/locksutil" 16 "github.com/hashicorp/vault/sdk/logical" 17) 18 19const ( 20 bucketCount = 256 21 // StoragePackerBucketsPrefix is the default storage key prefix under which 22 // bucket data will be stored. 23 StoragePackerBucketsPrefix = "packer/buckets/" 24) 25 26// StoragePacker packs items into a specific number of buckets by hashing 27// its identifier and indexing on it. Currently this supports only 256 bucket entries and 28// hence relies on the first byte of the hash value for indexing. 29type StoragePacker struct { 30 view logical.Storage 31 logger log.Logger 32 storageLocks []*locksutil.LockEntry 33 viewPrefix string 34} 35 36// View returns the storage view configured to be used by the packer 37func (s *StoragePacker) View() logical.Storage { 38 return s.view 39} 40 41// GetBucket returns a bucket for a given key 42func (s *StoragePacker) GetBucket(key string) (*Bucket, error) { 43 if key == "" { 44 return nil, fmt.Errorf("missing bucket key") 45 } 46 47 lock := locksutil.LockForKey(s.storageLocks, key) 48 lock.RLock() 49 defer lock.RUnlock() 50 51 // Read from storage 52 storageEntry, err := s.view.Get(context.Background(), key) 53 if err != nil { 54 return nil, errwrap.Wrapf("failed to read packed storage entry: {{err}}", err) 55 } 56 if storageEntry == nil { 57 return nil, nil 58 } 59 60 uncompressedData, notCompressed, err := compressutil.Decompress(storageEntry.Value) 61 if err != nil { 62 return nil, errwrap.Wrapf("failed to decompress packed storage entry: {{err}}", err) 63 } 64 if notCompressed { 65 uncompressedData = storageEntry.Value 66 } 67 68 var bucket Bucket 69 err = proto.Unmarshal(uncompressedData, &bucket) 70 if err != nil { 71 return nil, errwrap.Wrapf("failed to decode packed storage entry: {{err}}", err) 72 } 73 74 return &bucket, nil 75} 76 77// upsert either inserts a new item into the bucket or updates an existing one 78// if an item with a matching key is already present. 79func (s *Bucket) upsert(item *Item) error { 80 if s == nil { 81 return fmt.Errorf("nil storage bucket") 82 } 83 84 if item == nil { 85 return fmt.Errorf("nil item") 86 } 87 88 if item.ID == "" { 89 return fmt.Errorf("missing item ID") 90 } 91 92 // Look for an item with matching key and don't modify the collection while 93 // iterating 94 foundIdx := -1 95 for itemIdx, bucketItems := range s.Items { 96 if bucketItems.ID == item.ID { 97 foundIdx = itemIdx 98 break 99 } 100 } 101 102 // If there is no match, append the item, otherwise update it 103 if foundIdx == -1 { 104 s.Items = append(s.Items, item) 105 } else { 106 s.Items[foundIdx] = item 107 } 108 109 return nil 110} 111 112// BucketKey returns the storage key of the bucket where the given item will be 113// stored. 114func (s *StoragePacker) BucketKey(itemID string) string { 115 hf := md5.New() 116 input := []byte(itemID) 117 n, err := hf.Write(input) 118 // Make linter happy 119 if err != nil || n != len(input) { 120 return "" 121 } 122 index := uint8(hf.Sum(nil)[0]) 123 return s.viewPrefix + strconv.Itoa(int(index)) 124} 125 126// DeleteItem removes the item from the respective bucket 127func (s *StoragePacker) DeleteItem(_ context.Context, itemID string) error { 128 return s.DeleteMultipleItems(context.Background(), nil, itemID) 129} 130 131func (s *StoragePacker) DeleteMultipleItems(ctx context.Context, logger hclog.Logger, itemIDs ...string) error { 132 var err error 133 switch len(itemIDs) { 134 case 0: 135 // Nothing 136 return nil 137 138 case 1: 139 logger = hclog.NewNullLogger() 140 fallthrough 141 142 default: 143 lockIndexes := make(map[string]struct{}, len(s.storageLocks)) 144 for _, itemID := range itemIDs { 145 bucketKey := s.BucketKey(itemID) 146 if _, ok := lockIndexes[bucketKey]; !ok { 147 lockIndexes[bucketKey] = struct{}{} 148 } 149 } 150 151 lockKeys := make([]string, 0, len(lockIndexes)) 152 for k := range lockIndexes { 153 lockKeys = append(lockKeys, k) 154 } 155 156 locks := locksutil.LocksForKeys(s.storageLocks, lockKeys) 157 for _, lock := range locks { 158 lock.Lock() 159 defer lock.Unlock() 160 } 161 } 162 163 if logger == nil { 164 logger = hclog.NewNullLogger() 165 } 166 167 bucketCache := make(map[string]*Bucket, len(s.storageLocks)) 168 169 logger.Debug("deleting multiple items from storagepacker; caching and deleting from buckets", "total_items", len(itemIDs)) 170 171 var pctDone int 172 for idx, itemID := range itemIDs { 173 bucketKey := s.BucketKey(itemID) 174 175 bucket, bucketFound := bucketCache[bucketKey] 176 if !bucketFound { 177 // Read from storage 178 storageEntry, err := s.view.Get(context.Background(), bucketKey) 179 if err != nil { 180 return errwrap.Wrapf("failed to read packed storage value: {{err}}", err) 181 } 182 if storageEntry == nil { 183 return nil 184 } 185 186 uncompressedData, notCompressed, err := compressutil.Decompress(storageEntry.Value) 187 if err != nil { 188 return errwrap.Wrapf("failed to decompress packed storage value: {{err}}", err) 189 } 190 if notCompressed { 191 uncompressedData = storageEntry.Value 192 } 193 194 bucket = new(Bucket) 195 err = proto.Unmarshal(uncompressedData, bucket) 196 if err != nil { 197 return errwrap.Wrapf("failed decoding packed storage entry: {{err}}", err) 198 } 199 } 200 201 // Look for a matching storage entry 202 foundIdx := -1 203 for itemIdx, item := range bucket.Items { 204 if item.ID == itemID { 205 foundIdx = itemIdx 206 break 207 } 208 } 209 210 // If there is a match, remove it from the collection and persist the 211 // resulting collection 212 if foundIdx != -1 { 213 bucket.Items[foundIdx] = bucket.Items[len(bucket.Items)-1] 214 bucket.Items = bucket.Items[:len(bucket.Items)-1] 215 if !bucketFound { 216 bucketCache[bucketKey] = bucket 217 } 218 } 219 220 newPctDone := idx * 100.0 / len(itemIDs) 221 if int(newPctDone) > pctDone { 222 pctDone = int(newPctDone) 223 logger.Trace("bucket item removal progress", "percent", pctDone, "items_removed", idx) 224 } 225 } 226 227 logger.Debug("persisting buckets", "total_buckets", len(bucketCache)) 228 229 // Persist all buckets in the cache; these will be the ones that had 230 // deletions 231 pctDone = 0 232 idx := 0 233 for _, bucket := range bucketCache { 234 // Fail if the context is canceled, the storage calls will fail anyways 235 if ctx.Err() != nil { 236 return ctx.Err() 237 } 238 239 err = s.putBucket(ctx, bucket) 240 if err != nil { 241 return err 242 } 243 244 newPctDone := idx * 100.0 / len(bucketCache) 245 if int(newPctDone) > pctDone { 246 pctDone = int(newPctDone) 247 logger.Trace("bucket persistence progress", "percent", pctDone, "buckets_persisted", idx) 248 } 249 250 idx++ 251 } 252 253 return nil 254} 255 256func (s *StoragePacker) putBucket(ctx context.Context, bucket *Bucket) error { 257 if bucket == nil { 258 return fmt.Errorf("nil bucket entry") 259 } 260 261 if bucket.Key == "" { 262 return fmt.Errorf("missing key") 263 } 264 265 if !strings.HasPrefix(bucket.Key, s.viewPrefix) { 266 return fmt.Errorf("incorrect prefix; bucket entry key should have %q prefix", s.viewPrefix) 267 } 268 269 marshaledBucket, err := proto.Marshal(bucket) 270 if err != nil { 271 return errwrap.Wrapf("failed to marshal bucket: {{err}}", err) 272 } 273 274 compressedBucket, err := compressutil.Compress(marshaledBucket, &compressutil.CompressionConfig{ 275 Type: compressutil.CompressionTypeSnappy, 276 }) 277 if err != nil { 278 return errwrap.Wrapf("failed to compress packed bucket: {{err}}", err) 279 } 280 281 // Store the compressed value 282 err = s.view.Put(ctx, &logical.StorageEntry{ 283 Key: bucket.Key, 284 Value: compressedBucket, 285 }) 286 if err != nil { 287 return errwrap.Wrapf("failed to persist packed storage entry: {{err}}", err) 288 } 289 290 return nil 291} 292 293// GetItem fetches the storage entry for a given key from its corresponding 294// bucket. 295func (s *StoragePacker) GetItem(itemID string) (*Item, error) { 296 if itemID == "" { 297 return nil, fmt.Errorf("empty item ID") 298 } 299 300 bucketKey := s.BucketKey(itemID) 301 302 // Fetch the bucket entry 303 bucket, err := s.GetBucket(bucketKey) 304 if err != nil { 305 return nil, errwrap.Wrapf("failed to read packed storage item: {{err}}", err) 306 } 307 if bucket == nil { 308 return nil, nil 309 } 310 311 // Look for a matching storage entry in the bucket items 312 for _, item := range bucket.Items { 313 if item.ID == itemID { 314 return item, nil 315 } 316 } 317 318 return nil, nil 319} 320 321// PutItem stores the given item in its respective bucket 322func (s *StoragePacker) PutItem(_ context.Context, item *Item) error { 323 if item == nil { 324 return fmt.Errorf("nil item") 325 } 326 327 if item.ID == "" { 328 return fmt.Errorf("missing ID in item") 329 } 330 331 var err error 332 bucketKey := s.BucketKey(item.ID) 333 334 bucket := &Bucket{ 335 Key: bucketKey, 336 } 337 338 // In this case, we persist the storage entry regardless of the read 339 // storageEntry below is nil or not. Hence, directly acquire write lock 340 // even to read the entry. 341 lock := locksutil.LockForKey(s.storageLocks, bucketKey) 342 lock.Lock() 343 defer lock.Unlock() 344 345 // Check if there is an existing bucket for a given key 346 storageEntry, err := s.view.Get(context.Background(), bucketKey) 347 if err != nil { 348 return errwrap.Wrapf("failed to read packed storage bucket entry: {{err}}", err) 349 } 350 351 if storageEntry == nil { 352 // If the bucket entry does not exist, this will be the only item the 353 // bucket that is going to be persisted. 354 bucket.Items = []*Item{ 355 item, 356 } 357 } else { 358 uncompressedData, notCompressed, err := compressutil.Decompress(storageEntry.Value) 359 if err != nil { 360 return errwrap.Wrapf("failed to decompress packed storage entry: {{err}}", err) 361 } 362 if notCompressed { 363 uncompressedData = storageEntry.Value 364 } 365 366 err = proto.Unmarshal(uncompressedData, bucket) 367 if err != nil { 368 return errwrap.Wrapf("failed to decode packed storage entry: {{err}}", err) 369 } 370 371 err = bucket.upsert(item) 372 if err != nil { 373 return errwrap.Wrapf("failed to update entry in packed storage entry: {{err}}", err) 374 } 375 } 376 377 return s.putBucket(context.Background(), bucket) 378} 379 380// NewStoragePacker creates a new storage packer for a given view 381func NewStoragePacker(view logical.Storage, logger log.Logger, viewPrefix string) (*StoragePacker, error) { 382 if view == nil { 383 return nil, fmt.Errorf("nil view") 384 } 385 386 if viewPrefix == "" { 387 viewPrefix = StoragePackerBucketsPrefix 388 } 389 390 if !strings.HasSuffix(viewPrefix, "/") { 391 viewPrefix = viewPrefix + "/" 392 } 393 394 // Create a new packer object for the given view 395 packer := &StoragePacker{ 396 view: view, 397 viewPrefix: viewPrefix, 398 logger: logger, 399 storageLocks: locksutil.CreateLocks(), 400 } 401 402 return packer, nil 403} 404