1// Copyright 2016 The OPA Authors.  All rights reserved.
2// Use of this source code is governed by an Apache2
3// license that can be found in the LICENSE file.
4
5// Package inmem implements an in-memory version of the policy engine's storage
6// layer.
7//
8// The in-memory store is used as the default storage layer implementation. The
9// in-memory store supports multi-reader/single-writer concurrency with
10// rollback.
11//
12// Callers should assume the in-memory store does not make copies of written
13// data. Once data is written to the in-memory store, it should not be modified
14// (outside of calling Store.Write). Furthermore, data read from the in-memory
15// store should be treated as read-only.
16package inmem
17
18import (
19	"context"
20	"fmt"
21	"io"
22	"sync"
23	"sync/atomic"
24
25	"github.com/open-policy-agent/opa/ast"
26	"github.com/open-policy-agent/opa/storage"
27	"github.com/open-policy-agent/opa/util"
28)
29
30// New returns an empty in-memory store.
31func New() storage.Store {
32	return &store{
33		data:     map[string]interface{}{},
34		triggers: map[*handle]storage.TriggerConfig{},
35		policies: map[string][]byte{},
36		indices:  newIndices(),
37	}
38}
39
40// NewFromObject returns a new in-memory store from the supplied data object.
41func NewFromObject(data map[string]interface{}) storage.Store {
42	db := New()
43	ctx := context.Background()
44	txn, err := db.NewTransaction(ctx, storage.WriteParams)
45	if err != nil {
46		panic(err)
47	}
48	if err := db.Write(ctx, txn, storage.AddOp, storage.Path{}, data); err != nil {
49		panic(err)
50	}
51	if err := db.Commit(ctx, txn); err != nil {
52		panic(err)
53	}
54	return db
55}
56
57// NewFromReader returns a new in-memory store from a reader that produces a
58// JSON serialized object. This function is for test purposes.
59func NewFromReader(r io.Reader) storage.Store {
60	d := util.NewJSONDecoder(r)
61	var data map[string]interface{}
62	if err := d.Decode(&data); err != nil {
63		panic(err)
64	}
65	return NewFromObject(data)
66}
67
// store is the in-memory storage implementation returned by New.
//
// Locking as used by the methods in this file: a write transaction holds wmu
// from NewTransaction until Commit or Abort; a read transaction holds rmu in
// shared (read) mode for its lifetime; and Commit of a write transaction
// additionally takes rmu exclusively while it runs.
type store struct {
	rmu      sync.RWMutex                      // reader-writer lock: read-held by read transactions, write-held during a write Commit
	wmu      sync.Mutex                        // writer lock: held by the open write transaction
	xid      uint64                            // last generated transaction id (advanced with atomic.AddUint64)
	data     map[string]interface{}            // raw data
	policies map[string][]byte                 // raw policies
	triggers map[*handle]storage.TriggerConfig // registered triggers, keyed by the handle returned from Register
	indices  *indices                          // data ref indices; replaced with a fresh set on every write Commit
}
77
// handle identifies one registered trigger. Register returns a *handle and
// uses it as the key into the store's triggers map; Unregister removes that
// entry again.
type handle struct {
	db *store // store the trigger was registered on
}
81
82func (db *store) NewTransaction(ctx context.Context, params ...storage.TransactionParams) (storage.Transaction, error) {
83	var write bool
84	var context *storage.Context
85	if len(params) > 0 {
86		write = params[0].Write
87		context = params[0].Context
88	}
89	xid := atomic.AddUint64(&db.xid, uint64(1))
90	if write {
91		db.wmu.Lock()
92	} else {
93		db.rmu.RLock()
94	}
95	return newTransaction(xid, write, context, db), nil
96}
97
98func (db *store) Commit(ctx context.Context, txn storage.Transaction) error {
99	underlying, err := db.underlying(txn)
100	if err != nil {
101		return err
102	}
103	if underlying.write {
104		db.rmu.Lock()
105		event := underlying.Commit()
106		db.indices = newIndices()
107		db.runOnCommitTriggers(ctx, txn, event)
108		// Mark the transaction stale after executing triggers so they can
109		// perform store operations if needed.
110		underlying.stale = true
111		db.rmu.Unlock()
112		db.wmu.Unlock()
113	} else {
114		db.rmu.RUnlock()
115	}
116	return nil
117}
118
119func (db *store) Abort(ctx context.Context, txn storage.Transaction) {
120	underlying, err := db.underlying(txn)
121	if err != nil {
122		panic(err)
123	}
124	underlying.stale = true
125	if underlying.write {
126		db.wmu.Unlock()
127	} else {
128		db.rmu.RUnlock()
129	}
130}
131
132func (db *store) ListPolicies(_ context.Context, txn storage.Transaction) ([]string, error) {
133	underlying, err := db.underlying(txn)
134	if err != nil {
135		return nil, err
136	}
137	return underlying.ListPolicies(), nil
138}
139
140func (db *store) GetPolicy(_ context.Context, txn storage.Transaction, id string) ([]byte, error) {
141	underlying, err := db.underlying(txn)
142	if err != nil {
143		return nil, err
144	}
145	return underlying.GetPolicy(id)
146}
147
148func (db *store) UpsertPolicy(_ context.Context, txn storage.Transaction, id string, bs []byte) error {
149	underlying, err := db.underlying(txn)
150	if err != nil {
151		return err
152	}
153	return underlying.UpsertPolicy(id, bs)
154}
155
156func (db *store) DeletePolicy(_ context.Context, txn storage.Transaction, id string) error {
157	underlying, err := db.underlying(txn)
158	if err != nil {
159		return err
160	}
161	if _, err := underlying.GetPolicy(id); err != nil {
162		return err
163	}
164	return underlying.DeletePolicy(id)
165}
166
167func (db *store) Register(ctx context.Context, txn storage.Transaction, config storage.TriggerConfig) (storage.TriggerHandle, error) {
168	underlying, err := db.underlying(txn)
169	if err != nil {
170		return nil, err
171	}
172	if !underlying.write {
173		return nil, &storage.Error{
174			Code:    storage.InvalidTransactionErr,
175			Message: "triggers must be registered with a write transaction",
176		}
177	}
178	h := &handle{db}
179	db.triggers[h] = config
180	return h, nil
181}
182
183func (db *store) Read(ctx context.Context, txn storage.Transaction, path storage.Path) (interface{}, error) {
184	underlying, err := db.underlying(txn)
185	if err != nil {
186		return nil, err
187	}
188	return underlying.Read(path)
189}
190
191func (db *store) Write(ctx context.Context, txn storage.Transaction, op storage.PatchOp, path storage.Path, value interface{}) error {
192	underlying, err := db.underlying(txn)
193	if err != nil {
194		return err
195	}
196	val := util.Reference(value)
197	if err := util.RoundTrip(val); err != nil {
198		return err
199	}
200	return underlying.Write(op, path, *val)
201}
202
203func (db *store) Build(ctx context.Context, txn storage.Transaction, ref ast.Ref) (storage.Index, error) {
204	underlying, err := db.underlying(txn)
205	if err != nil {
206		return nil, err
207	}
208	if underlying.write {
209		return nil, &storage.Error{
210			Code:    storage.IndexingNotSupportedErr,
211			Message: "in-memory store does not support indexing on write transactions",
212		}
213	}
214	return db.indices.Build(ctx, db, txn, ref)
215}
216
217func (h *handle) Unregister(ctx context.Context, txn storage.Transaction) {
218	underlying, err := h.db.underlying(txn)
219	if err != nil {
220		panic(err)
221	}
222	if !underlying.write {
223		panic(&storage.Error{
224			Code:    storage.InvalidTransactionErr,
225			Message: "triggers must be unregistered with a write transaction",
226		})
227	}
228	delete(h.db.triggers, h)
229}
230
231func (db *store) runOnCommitTriggers(ctx context.Context, txn storage.Transaction, event storage.TriggerEvent) {
232	for _, t := range db.triggers {
233		t.OnCommit(ctx, txn, event)
234	}
235}
236
237func (db *store) underlying(txn storage.Transaction) (*transaction, error) {
238	underlying, ok := txn.(*transaction)
239	if !ok {
240		return nil, &storage.Error{
241			Code:    storage.InvalidTransactionErr,
242			Message: fmt.Sprintf("unexpected transaction type %T", txn),
243		}
244	}
245	if underlying.db != db {
246		return nil, &storage.Error{
247			Code:    storage.InvalidTransactionErr,
248			Message: "unknown transaction",
249		}
250	}
251	if underlying.stale {
252		return nil, &storage.Error{
253			Code:    storage.InvalidTransactionErr,
254			Message: "stale transaction",
255		}
256	}
257	return underlying, nil
258}
259
// Message fragments shared by the error helpers and patch handling in this
// package.
var (
	doesNotExistMsg        = "document does not exist"
	rootMustBeObjectMsg    = "root must be object"
	rootCannotBeRemovedMsg = "root cannot be removed"
	outOfRangeMsg          = "array index out of range"
	arrayIndexTypeMsg      = "array index must be integer"
)
265
266func invalidPatchError(f string, a ...interface{}) *storage.Error {
267	return &storage.Error{
268		Code:    storage.InvalidPatchErr,
269		Message: fmt.Sprintf(f, a...),
270	}
271}
272
273func notFoundError(path storage.Path) *storage.Error {
274	return notFoundErrorHint(path, doesNotExistMsg)
275}
276
277func notFoundErrorHint(path storage.Path, hint string) *storage.Error {
278	return notFoundErrorf("%v: %v", path.String(), hint)
279}
280
281func notFoundErrorf(f string, a ...interface{}) *storage.Error {
282	msg := fmt.Sprintf(f, a...)
283	return &storage.Error{
284		Code:    storage.NotFoundErr,
285		Message: msg,
286	}
287}
288