1package storagepacker
2
3import (
4	"context"
5	"crypto/md5"
6	"fmt"
7	"strconv"
8	"strings"
9
10	"github.com/golang/protobuf/proto"
11	"github.com/hashicorp/errwrap"
12	"github.com/hashicorp/go-hclog"
13	log "github.com/hashicorp/go-hclog"
14	"github.com/hashicorp/vault/sdk/helper/compressutil"
15	"github.com/hashicorp/vault/sdk/helper/locksutil"
16	"github.com/hashicorp/vault/sdk/logical"
17)
18
const (
	// bucketCount is the number of buckets items are spread across; bucket
	// selection uses the first byte of an MD5 digest, hence 256.
	// NOTE(review): not referenced within this file — presumably used by
	// sibling files in the package; confirm before removing.
	bucketCount = 256
	// StoragePackerBucketsPrefix is the default storage key prefix under which
	// bucket data will be stored.
	StoragePackerBucketsPrefix = "packer/buckets/"
)
25
26// StoragePacker packs items into a specific number of buckets by hashing
27// its identifier and indexing on it. Currently this supports only 256 bucket entries and
28// hence relies on the first byte of the hash value for indexing.
// StoragePacker packs items into a specific number of buckets by hashing
// its identifier and indexing on it. Currently this supports only 256 bucket entries and
// hence relies on the first byte of the hash value for indexing.
type StoragePacker struct {
	view         logical.Storage        // backing storage for bucket entries
	logger       log.Logger             // logger supplied at construction time
	storageLocks []*locksutil.LockEntry // per-key lock shards guarding bucket reads/writes
	viewPrefix   string                 // key prefix all bucket keys must carry (always "/"-terminated)
}
35
36// View returns the storage view configured to be used by the packer
37func (s *StoragePacker) View() logical.Storage {
38	return s.view
39}
40
41// GetBucket returns a bucket for a given key
42func (s *StoragePacker) GetBucket(key string) (*Bucket, error) {
43	if key == "" {
44		return nil, fmt.Errorf("missing bucket key")
45	}
46
47	lock := locksutil.LockForKey(s.storageLocks, key)
48	lock.RLock()
49	defer lock.RUnlock()
50
51	// Read from storage
52	storageEntry, err := s.view.Get(context.Background(), key)
53	if err != nil {
54		return nil, errwrap.Wrapf("failed to read packed storage entry: {{err}}", err)
55	}
56	if storageEntry == nil {
57		return nil, nil
58	}
59
60	uncompressedData, notCompressed, err := compressutil.Decompress(storageEntry.Value)
61	if err != nil {
62		return nil, errwrap.Wrapf("failed to decompress packed storage entry: {{err}}", err)
63	}
64	if notCompressed {
65		uncompressedData = storageEntry.Value
66	}
67
68	var bucket Bucket
69	err = proto.Unmarshal(uncompressedData, &bucket)
70	if err != nil {
71		return nil, errwrap.Wrapf("failed to decode packed storage entry: {{err}}", err)
72	}
73
74	return &bucket, nil
75}
76
77// upsert either inserts a new item into the bucket or updates an existing one
78// if an item with a matching key is already present.
79func (s *Bucket) upsert(item *Item) error {
80	if s == nil {
81		return fmt.Errorf("nil storage bucket")
82	}
83
84	if item == nil {
85		return fmt.Errorf("nil item")
86	}
87
88	if item.ID == "" {
89		return fmt.Errorf("missing item ID")
90	}
91
92	// Look for an item with matching key and don't modify the collection while
93	// iterating
94	foundIdx := -1
95	for itemIdx, bucketItems := range s.Items {
96		if bucketItems.ID == item.ID {
97			foundIdx = itemIdx
98			break
99		}
100	}
101
102	// If there is no match, append the item, otherwise update it
103	if foundIdx == -1 {
104		s.Items = append(s.Items, item)
105	} else {
106		s.Items[foundIdx] = item
107	}
108
109	return nil
110}
111
112// BucketKey returns the storage key of the bucket where the given item will be
113// stored.
114func (s *StoragePacker) BucketKey(itemID string) string {
115	hf := md5.New()
116	input := []byte(itemID)
117	n, err := hf.Write(input)
118	// Make linter happy
119	if err != nil || n != len(input) {
120		return ""
121	}
122	index := uint8(hf.Sum(nil)[0])
123	return s.viewPrefix + strconv.Itoa(int(index))
124}
125
126// DeleteItem removes the item from the respective bucket
127func (s *StoragePacker) DeleteItem(_ context.Context, itemID string) error {
128	return s.DeleteMultipleItems(context.Background(), nil, itemID)
129}
130
131func (s *StoragePacker) DeleteMultipleItems(ctx context.Context, logger hclog.Logger, itemIDs ...string) error {
132	var err error
133	switch len(itemIDs) {
134	case 0:
135		// Nothing
136		return nil
137
138	case 1:
139		logger = hclog.NewNullLogger()
140		fallthrough
141
142	default:
143		lockIndexes := make(map[string]struct{}, len(s.storageLocks))
144		for _, itemID := range itemIDs {
145			bucketKey := s.BucketKey(itemID)
146			if _, ok := lockIndexes[bucketKey]; !ok {
147				lockIndexes[bucketKey] = struct{}{}
148			}
149		}
150
151		lockKeys := make([]string, 0, len(lockIndexes))
152		for k := range lockIndexes {
153			lockKeys = append(lockKeys, k)
154		}
155
156		locks := locksutil.LocksForKeys(s.storageLocks, lockKeys)
157		for _, lock := range locks {
158			lock.Lock()
159			defer lock.Unlock()
160		}
161	}
162
163	if logger == nil {
164		logger = hclog.NewNullLogger()
165	}
166
167	bucketCache := make(map[string]*Bucket, len(s.storageLocks))
168
169	logger.Debug("deleting multiple items from storagepacker; caching and deleting from buckets", "total_items", len(itemIDs))
170
171	var pctDone int
172	for idx, itemID := range itemIDs {
173		bucketKey := s.BucketKey(itemID)
174
175		bucket, bucketFound := bucketCache[bucketKey]
176		if !bucketFound {
177			// Read from storage
178			storageEntry, err := s.view.Get(context.Background(), bucketKey)
179			if err != nil {
180				return errwrap.Wrapf("failed to read packed storage value: {{err}}", err)
181			}
182			if storageEntry == nil {
183				return nil
184			}
185
186			uncompressedData, notCompressed, err := compressutil.Decompress(storageEntry.Value)
187			if err != nil {
188				return errwrap.Wrapf("failed to decompress packed storage value: {{err}}", err)
189			}
190			if notCompressed {
191				uncompressedData = storageEntry.Value
192			}
193
194			bucket = new(Bucket)
195			err = proto.Unmarshal(uncompressedData, bucket)
196			if err != nil {
197				return errwrap.Wrapf("failed decoding packed storage entry: {{err}}", err)
198			}
199		}
200
201		// Look for a matching storage entry
202		foundIdx := -1
203		for itemIdx, item := range bucket.Items {
204			if item.ID == itemID {
205				foundIdx = itemIdx
206				break
207			}
208		}
209
210		// If there is a match, remove it from the collection and persist the
211		// resulting collection
212		if foundIdx != -1 {
213			bucket.Items[foundIdx] = bucket.Items[len(bucket.Items)-1]
214			bucket.Items = bucket.Items[:len(bucket.Items)-1]
215			if !bucketFound {
216				bucketCache[bucketKey] = bucket
217			}
218		}
219
220		newPctDone := idx * 100.0 / len(itemIDs)
221		if int(newPctDone) > pctDone {
222			pctDone = int(newPctDone)
223			logger.Trace("bucket item removal progress", "percent", pctDone, "items_removed", idx)
224		}
225	}
226
227	logger.Debug("persisting buckets", "total_buckets", len(bucketCache))
228
229	// Persist all buckets in the cache; these will be the ones that had
230	// deletions
231	pctDone = 0
232	idx := 0
233	for _, bucket := range bucketCache {
234		// Fail if the context is canceled, the storage calls will fail anyways
235		if ctx.Err() != nil {
236			return ctx.Err()
237		}
238
239		err = s.putBucket(ctx, bucket)
240		if err != nil {
241			return err
242		}
243
244		newPctDone := idx * 100.0 / len(bucketCache)
245		if int(newPctDone) > pctDone {
246			pctDone = int(newPctDone)
247			logger.Trace("bucket persistence progress", "percent", pctDone, "buckets_persisted", idx)
248		}
249
250		idx++
251	}
252
253	return nil
254}
255
256func (s *StoragePacker) putBucket(ctx context.Context, bucket *Bucket) error {
257	if bucket == nil {
258		return fmt.Errorf("nil bucket entry")
259	}
260
261	if bucket.Key == "" {
262		return fmt.Errorf("missing key")
263	}
264
265	if !strings.HasPrefix(bucket.Key, s.viewPrefix) {
266		return fmt.Errorf("incorrect prefix; bucket entry key should have %q prefix", s.viewPrefix)
267	}
268
269	marshaledBucket, err := proto.Marshal(bucket)
270	if err != nil {
271		return errwrap.Wrapf("failed to marshal bucket: {{err}}", err)
272	}
273
274	compressedBucket, err := compressutil.Compress(marshaledBucket, &compressutil.CompressionConfig{
275		Type: compressutil.CompressionTypeSnappy,
276	})
277	if err != nil {
278		return errwrap.Wrapf("failed to compress packed bucket: {{err}}", err)
279	}
280
281	// Store the compressed value
282	err = s.view.Put(ctx, &logical.StorageEntry{
283		Key:   bucket.Key,
284		Value: compressedBucket,
285	})
286	if err != nil {
287		return errwrap.Wrapf("failed to persist packed storage entry: {{err}}", err)
288	}
289
290	return nil
291}
292
293// GetItem fetches the storage entry for a given key from its corresponding
294// bucket.
295func (s *StoragePacker) GetItem(itemID string) (*Item, error) {
296	if itemID == "" {
297		return nil, fmt.Errorf("empty item ID")
298	}
299
300	bucketKey := s.BucketKey(itemID)
301
302	// Fetch the bucket entry
303	bucket, err := s.GetBucket(bucketKey)
304	if err != nil {
305		return nil, errwrap.Wrapf("failed to read packed storage item: {{err}}", err)
306	}
307	if bucket == nil {
308		return nil, nil
309	}
310
311	// Look for a matching storage entry in the bucket items
312	for _, item := range bucket.Items {
313		if item.ID == itemID {
314			return item, nil
315		}
316	}
317
318	return nil, nil
319}
320
321// PutItem stores the given item in its respective bucket
322func (s *StoragePacker) PutItem(_ context.Context, item *Item) error {
323	if item == nil {
324		return fmt.Errorf("nil item")
325	}
326
327	if item.ID == "" {
328		return fmt.Errorf("missing ID in item")
329	}
330
331	var err error
332	bucketKey := s.BucketKey(item.ID)
333
334	bucket := &Bucket{
335		Key: bucketKey,
336	}
337
338	// In this case, we persist the storage entry regardless of the read
339	// storageEntry below is nil or not. Hence, directly acquire write lock
340	// even to read the entry.
341	lock := locksutil.LockForKey(s.storageLocks, bucketKey)
342	lock.Lock()
343	defer lock.Unlock()
344
345	// Check if there is an existing bucket for a given key
346	storageEntry, err := s.view.Get(context.Background(), bucketKey)
347	if err != nil {
348		return errwrap.Wrapf("failed to read packed storage bucket entry: {{err}}", err)
349	}
350
351	if storageEntry == nil {
352		// If the bucket entry does not exist, this will be the only item the
353		// bucket that is going to be persisted.
354		bucket.Items = []*Item{
355			item,
356		}
357	} else {
358		uncompressedData, notCompressed, err := compressutil.Decompress(storageEntry.Value)
359		if err != nil {
360			return errwrap.Wrapf("failed to decompress packed storage entry: {{err}}", err)
361		}
362		if notCompressed {
363			uncompressedData = storageEntry.Value
364		}
365
366		err = proto.Unmarshal(uncompressedData, bucket)
367		if err != nil {
368			return errwrap.Wrapf("failed to decode packed storage entry: {{err}}", err)
369		}
370
371		err = bucket.upsert(item)
372		if err != nil {
373			return errwrap.Wrapf("failed to update entry in packed storage entry: {{err}}", err)
374		}
375	}
376
377	return s.putBucket(context.Background(), bucket)
378}
379
380// NewStoragePacker creates a new storage packer for a given view
381func NewStoragePacker(view logical.Storage, logger log.Logger, viewPrefix string) (*StoragePacker, error) {
382	if view == nil {
383		return nil, fmt.Errorf("nil view")
384	}
385
386	if viewPrefix == "" {
387		viewPrefix = StoragePackerBucketsPrefix
388	}
389
390	if !strings.HasSuffix(viewPrefix, "/") {
391		viewPrefix = viewPrefix + "/"
392	}
393
394	// Create a new packer object for the given view
395	packer := &StoragePacker{
396		view:         view,
397		viewPrefix:   viewPrefix,
398		logger:       logger,
399		storageLocks: locksutil.CreateLocks(),
400	}
401
402	return packer, nil
403}
404