1package inmem
2
3import (
4	"context"
5	"errors"
6	"fmt"
7	"os"
8	"strconv"
9	"strings"
10	"sync"
11	"sync/atomic"
12
13	log "github.com/hashicorp/go-hclog"
14	"github.com/hashicorp/vault/sdk/physical"
15
16	radix "github.com/armon/go-radix"
17)
18
19// Verify interfaces are satisfied
20var _ physical.Backend = (*InmemBackend)(nil)
21var _ physical.HABackend = (*InmemHABackend)(nil)
22var _ physical.HABackend = (*TransactionalInmemHABackend)(nil)
23var _ physical.Lock = (*InmemLock)(nil)
24var _ physical.Transactional = (*TransactionalInmemBackend)(nil)
25var _ physical.Transactional = (*TransactionalInmemHABackend)(nil)
26
27var (
28	PutDisabledError    = errors.New("put operations disabled in inmem backend")
29	GetDisabledError    = errors.New("get operations disabled in inmem backend")
30	DeleteDisabledError = errors.New("delete operations disabled in inmem backend")
31	ListDisabledError   = errors.New("list operations disabled in inmem backend")
32)
33
// InmemBackend is an in-memory only physical backend. It is useful
// for testing and development situations where the data is not
// expected to be durable.
//
// Locking: the embedded RWMutex guards root. Put and Delete take the write
// lock, and Get and List take the read lock, before delegating to their
// *Internal counterparts, which do no locking of their own. Each of those
// public operations also holds a slot from permitPool while it runs.
type InmemBackend struct {
	sync.RWMutex
	// root holds the stored data: keys are entry keys, values are []byte.
	root         *radix.Tree
	// permitPool caps how many operations run at once; the constructors in
	// this file size it to physical.DefaultParallelOperations or to 1.
	permitPool   *physical.PermitPool
	logger       log.Logger
	// Fault-injection flags, read and written with sync/atomic. A non-zero
	// value makes the matching operation fail with its *DisabledError; they
	// are toggled via FailGet/FailPut/FailDelete/FailList.
	failGet      *uint32
	failPut      *uint32
	failDelete   *uint32
	failList     *uint32
	// logOps is set when VAULT_INMEM_LOG_ALL_OPS is non-empty at construction
	// time; each operation then emits a Trace log line.
	logOps       bool
	// maxValueSize comes from conf "max_value_size". When > 0, PutInternal
	// rejects longer values; <= 0 disables the limit.
	maxValueSize int
}
49
// TransactionalInmemBackend is an InmemBackend that additionally implements
// physical.Transactional through its Transaction method. NewTransactionalInmem
// gives it a permit pool of size 1, so its operations are fully serialized.
type TransactionalInmemBackend struct {
	InmemBackend
}
53
54// NewInmem constructs a new in-memory backend
55func NewInmem(conf map[string]string, logger log.Logger) (physical.Backend, error) {
56	maxValueSize := 0
57	maxValueSizeStr, ok := conf["max_value_size"]
58	if ok {
59		var err error
60		maxValueSize, err = strconv.Atoi(maxValueSizeStr)
61		if err != nil {
62			return nil, err
63		}
64	}
65
66	return &InmemBackend{
67		root:         radix.New(),
68		permitPool:   physical.NewPermitPool(physical.DefaultParallelOperations),
69		logger:       logger,
70		failGet:      new(uint32),
71		failPut:      new(uint32),
72		failDelete:   new(uint32),
73		failList:     new(uint32),
74		logOps:       os.Getenv("VAULT_INMEM_LOG_ALL_OPS") != "",
75		maxValueSize: maxValueSize,
76	}, nil
77}
78
79// Basically for now just creates a permit pool of size 1 so only one operation
80// can run at a time
81func NewTransactionalInmem(conf map[string]string, logger log.Logger) (physical.Backend, error) {
82	maxValueSize := 0
83	maxValueSizeStr, ok := conf["max_value_size"]
84	if ok {
85		var err error
86		maxValueSize, err = strconv.Atoi(maxValueSizeStr)
87		if err != nil {
88			return nil, err
89		}
90	}
91
92	return &TransactionalInmemBackend{
93		InmemBackend: InmemBackend{
94			root:         radix.New(),
95			permitPool:   physical.NewPermitPool(1),
96			logger:       logger,
97			failGet:      new(uint32),
98			failPut:      new(uint32),
99			failDelete:   new(uint32),
100			failList:     new(uint32),
101			logOps:       os.Getenv("VAULT_INMEM_LOG_ALL_OPS") != "",
102			maxValueSize: maxValueSize,
103		},
104	}, nil
105}
106
// Put inserts entry, or overwrites the value already stored under entry.Key.
// It holds a permit and the write lock for the duration of PutInternal; both
// are released (lock first) when it returns.
func (i *InmemBackend) Put(ctx context.Context, entry *physical.Entry) error {
	i.permitPool.Acquire()
	i.Lock()
	defer func() {
		i.Unlock()
		i.permitPool.Release()
	}()

	return i.PutInternal(ctx, entry)
}
117
// PutInternal stores entry in the tree. It does no locking itself; Put calls
// it with the write lock held.
//
// It returns PutDisabledError while FailPut(true) is in effect, ctx.Err() if
// ctx is already done, and an error whose message is
// physical.ErrValueTooLarge when max_value_size is > 0 and entry.Value is
// longer than it.
//
// The value is copied before insertion so that later writes to the caller's
// slice cannot silently alter what is stored.
func (i *InmemBackend) PutInternal(ctx context.Context, entry *physical.Entry) error {
	if i.logOps {
		i.logger.Trace("put", "key", entry.Key)
	}
	if atomic.LoadUint32(i.failPut) != 0 {
		return PutDisabledError
	}

	if err := ctx.Err(); err != nil {
		return err
	}

	if i.maxValueSize > 0 && len(entry.Value) > i.maxValueSize {
		// NOTE(review): "%s" assumes physical.ErrValueTooLarge is a message
		// string rather than an error value; if it is an error, wrap with %w.
		return fmt.Errorf("%s", physical.ErrValueTooLarge)
	}

	// Detach the stored value from the caller's backing array, keeping a nil
	// value nil and an empty value empty.
	value := entry.Value
	if value != nil {
		value = make([]byte, len(entry.Value))
		copy(value, entry.Value)
	}

	i.root.Insert(entry.Key, value)
	return nil
}
139
140func (i *InmemBackend) FailPut(fail bool) {
141	var val uint32
142	if fail {
143		val = 1
144	}
145	atomic.StoreUint32(i.failPut, val)
146}
147
// Get fetches the entry stored under key, or returns (nil, nil) when there is
// none. It holds a permit and the read lock for the duration of GetInternal;
// both are released (lock first) when it returns.
func (i *InmemBackend) Get(ctx context.Context, key string) (*physical.Entry, error) {
	i.permitPool.Acquire()
	i.RLock()
	defer func() {
		i.RUnlock()
		i.permitPool.Release()
	}()

	return i.GetInternal(ctx, key)
}
158
// GetInternal looks key up in the tree without taking any lock; Get calls it
// with the read lock held.
//
// It returns GetDisabledError while FailGet(true) is in effect, ctx.Err() if
// ctx is already done, and (nil, nil) when key is absent. The returned
// Value is the stored slice itself, not a copy.
// NOTE(review): callers must not mutate it.
func (i *InmemBackend) GetInternal(ctx context.Context, key string) (*physical.Entry, error) {
	if i.logOps {
		i.logger.Trace("get", "key", key)
	}
	if atomic.LoadUint32(i.failGet) != 0 {
		return nil, GetDisabledError
	}
	if err := ctx.Err(); err != nil {
		return nil, err
	}

	raw, found := i.root.Get(key)
	if !found {
		return nil, nil
	}
	return &physical.Entry{Key: key, Value: raw.([]byte)}, nil
}
181
182func (i *InmemBackend) FailGet(fail bool) {
183	var val uint32
184	if fail {
185		val = 1
186	}
187	atomic.StoreUint32(i.failGet, val)
188}
189
// Delete permanently removes the entry stored under key, if any. It holds a
// permit and the write lock for the duration of DeleteInternal; both are
// released (lock first) when it returns.
func (i *InmemBackend) Delete(ctx context.Context, key string) error {
	i.permitPool.Acquire()
	i.Lock()
	defer func() {
		i.Unlock()
		i.permitPool.Release()
	}()

	return i.DeleteInternal(ctx, key)
}
200
// DeleteInternal removes key from the tree without taking any lock; Delete
// calls it with the write lock held. Deleting an absent key is not an error.
//
// It returns DeleteDisabledError while FailDelete(true) is in effect and
// ctx.Err() if ctx is already done.
func (i *InmemBackend) DeleteInternal(ctx context.Context, key string) error {
	if i.logOps {
		i.logger.Trace("delete", "key", key)
	}
	if atomic.LoadUint32(i.failDelete) != 0 {
		return DeleteDisabledError
	}
	if err := ctx.Err(); err != nil {
		return err
	}

	i.root.Delete(key)
	return nil
}
217
218func (i *InmemBackend) FailDelete(fail bool) {
219	var val uint32
220	if fail {
221		val = 1
222	}
223	atomic.StoreUint32(i.failDelete, val)
224}
225
// List returns the keys directly under prefix, with deeper keys collapsed to
// their next "/"-terminated path segment. It holds a permit and the read lock
// for the duration of ListInternal; both are released (lock first) when it
// returns.
func (i *InmemBackend) List(ctx context.Context, prefix string) ([]string, error) {
	i.permitPool.Acquire()
	i.RLock()
	defer func() {
		i.RUnlock()
		i.permitPool.Release()
	}()

	return i.ListInternal(ctx, prefix)
}
237
// ListInternal returns the immediate children of prefix without taking any
// lock; List calls it with the read lock held.
//
// A key directly under prefix is reported as its remaining suffix. A key
// nested deeper contributes its first path segment with a trailing "/",
// reported once however many keys share it. Results appear in the order
// WalkPrefix visits keys.
//
// It returns ListDisabledError while FailList(true) is in effect, or
// ctx.Err() if ctx is done before or during the walk.
func (i *InmemBackend) ListInternal(ctx context.Context, prefix string) ([]string, error) {
	if i.logOps {
		i.logger.Trace("list", "prefix", prefix)
	}
	if atomic.LoadUint32(i.failList) != 0 {
		return nil, ListDisabledError
	}
	// Bail out before touching the tree, as the other *Internal methods do.
	if err := ctx.Err(); err != nil {
		return nil, err
	}

	var out []string
	seenDirs := make(map[string]struct{})
	walkFn := func(key string, _ interface{}) bool {
		if ctx.Err() != nil {
			// Returning true stops the walk; the error is reported below.
			return true
		}
		rest := strings.TrimPrefix(key, prefix)
		sep := strings.Index(rest, "/")
		if sep == -1 {
			out = append(out, rest)
			return false
		}
		dir := rest[:sep+1]
		if _, dup := seenDirs[dir]; !dup {
			seenDirs[dir] = struct{}{}
			out = append(out, dir)
		}
		return false
	}
	i.root.WalkPrefix(prefix, walkFn)

	if err := ctx.Err(); err != nil {
		return nil, err
	}
	return out, nil
}
272
273func (i *InmemBackend) FailList(fail bool) {
274	var val uint32
275	if fail {
276		val = 1
277	}
278	atomic.StoreUint32(i.failList, val)
279}
280
// Transaction implements physical.Transactional. It applies txns while
// holding a permit and the write lock for the whole batch, so no Put, Get,
// Delete or List on this backend can interleave with it.
//
// NOTE(review): t is handed to physical.GenericTransactionHandler while the
// non-reentrant write lock is held. This presumably relies on the handler
// calling only the lock-free *Internal methods; confirm against sdk/physical.
func (t *TransactionalInmemBackend) Transaction(ctx context.Context, txns []*physical.TxnEntry) error {
	t.permitPool.Acquire()
	t.Lock()
	defer func() {
		t.Unlock()
		t.permitPool.Release()
	}()

	return physical.GenericTransactionHandler(ctx, t, txns)
}
291