1// Copyright 2016 The OPA Authors.  All rights reserved.
2// Use of this source code is governed by an Apache2
3// license that can be found in the LICENSE file.
4
5package inmem
6
7import (
8	"context"
9	"encoding/json"
10	"fmt"
11	"hash/fnv"
12	"strings"
13	"sync"
14
15	"github.com/open-policy-agent/opa/ast"
16	"github.com/open-policy-agent/opa/storage"
17	"github.com/open-policy-agent/opa/util"
18)
19
// indices contains a mapping of non-ground references to values to sets of bindings.
//
//  +------+------------------------------------+
//  | ref1 | val1 | bindings-1, bindings-2, ... |
//  |      +------+-----------------------------+
//  |      | val2 | bindings-m, bindings-n, ... |
//  |      +------+-----------------------------+
//  |      | .... | ...                         |
//  +------+------+-----------------------------+
//  | ref2 | .... | ...                         |
//  +------+------+-----------------------------+
//  | ...                                       |
//  +-------------------------------------------+
//
// The "value" is the data value stored at the location referred to by the ground
// reference obtained by plugging bindings into the non-ground reference that is the
// index key.
//
type indices struct {
	// mu is held by Build for the duration of the lookup and (if needed)
	// construction of an index.
	mu sync.Mutex
	// table maps a reference's hash code to the head of a singly linked
	// chain of nodes whose keys share that hash code.
	table map[int]*indicesNode
}
42
// indicesNode is one entry in a hash chain of the indices table.
type indicesNode struct {
	key  ast.Ref       // reference the index was built for
	val  *bindingIndex // index built for key
	next *indicesNode  // next entry in the same hash bucket, or nil
}
48
49func newIndices() *indices {
50	return &indices{
51		table: map[int]*indicesNode{},
52	}
53}
54
55func (ind *indices) Build(ctx context.Context, store storage.Store, txn storage.Transaction, ref ast.Ref) (*bindingIndex, error) {
56
57	ind.mu.Lock()
58	defer ind.mu.Unlock()
59
60	if exist := ind.get(ref); exist != nil {
61		return exist, nil
62	}
63
64	index := newBindingIndex()
65
66	if err := iterStorage(ctx, store, txn, ref, ast.EmptyRef(), ast.NewValueMap(), index.Add); err != nil {
67		return nil, err
68	}
69
70	hashCode := ref.Hash()
71	head := ind.table[hashCode]
72	entry := &indicesNode{
73		key:  ref,
74		val:  index,
75		next: head,
76	}
77
78	ind.table[hashCode] = entry
79
80	return index, nil
81}
82
83func (ind *indices) get(ref ast.Ref) *bindingIndex {
84	node := ind.getNode(ref)
85	if node != nil {
86		return node.val
87	}
88	return nil
89}
90
91func (ind *indices) iter(iter func(ast.Ref, *bindingIndex) error) error {
92	for _, head := range ind.table {
93		for entry := head; entry != nil; entry = entry.next {
94			if err := iter(entry.key, entry.val); err != nil {
95				return err
96			}
97		}
98	}
99	return nil
100}
101
102func (ind *indices) getNode(ref ast.Ref) *indicesNode {
103	hashCode := ref.Hash()
104	for entry := ind.table[hashCode]; entry != nil; entry = entry.next {
105		if entry.key.Equal(ref) {
106			return entry
107		}
108	}
109	return nil
110}
111
112func (ind *indices) String() string {
113	buf := []string{}
114	for _, head := range ind.table {
115		for entry := head; entry != nil; entry = entry.next {
116			str := fmt.Sprintf("%v: %v", entry.key, entry.val)
117			buf = append(buf, str)
118		}
119	}
120	return "{" + strings.Join(buf, ", ") + "}"
121}
122
const (
	// triggerID is a fixed identifier string. It is not referenced in the
	// code visible here.
	// NOTE(review): presumably the ID under which an index-maintenance
	// storage trigger is registered - confirm against the caller.
	triggerID = "org.openpolicyagent/index-maintenance"
)
126
// bindingIndex contains a mapping of values to bindings.
type bindingIndex struct {
	// table maps hash(value) to the head of a singly linked chain of nodes
	// whose keys share that hash code.
	table map[int]*indexNode
}
131
// indexNode is one entry in a hash chain of a bindingIndex.
type indexNode struct {
	key  interface{} // value being indexed
	val  *bindingSet // bindings recorded for key
	next *indexNode  // next entry in the same hash bucket, or nil
}
137
138func newBindingIndex() *bindingIndex {
139	return &bindingIndex{
140		table: map[int]*indexNode{},
141	}
142}
143
144func (ind *bindingIndex) Add(val interface{}, bindings *ast.ValueMap) {
145
146	node := ind.getNode(val)
147	if node != nil {
148		node.val.Add(bindings)
149		return
150	}
151
152	hashCode := hash(val)
153	bindingsSet := newBindingSet()
154	bindingsSet.Add(bindings)
155
156	entry := &indexNode{
157		key:  val,
158		val:  bindingsSet,
159		next: ind.table[hashCode],
160	}
161
162	ind.table[hashCode] = entry
163}
164
165func (ind *bindingIndex) Lookup(_ context.Context, _ storage.Transaction, val interface{}, iter storage.IndexIterator) error {
166	node := ind.getNode(val)
167	if node == nil {
168		return nil
169	}
170	return node.val.Iter(iter)
171}
172
173func (ind *bindingIndex) getNode(val interface{}) *indexNode {
174	hashCode := hash(val)
175	head := ind.table[hashCode]
176	for entry := head; entry != nil; entry = entry.next {
177		if util.Compare(entry.key, val) == 0 {
178			return entry
179		}
180	}
181	return nil
182}
183
184func (ind *bindingIndex) String() string {
185
186	buf := []string{}
187
188	for _, head := range ind.table {
189		for entry := head; entry != nil; entry = entry.next {
190			str := fmt.Sprintf("%v: %v", entry.key, entry.val)
191			buf = append(buf, str)
192		}
193	}
194
195	return "{" + strings.Join(buf, ", ") + "}"
196}
197
// bindingSetNode is one entry in a hash chain of a bindingSet.
type bindingSetNode struct {
	val  *ast.ValueMap   // bindings stored in this entry
	next *bindingSetNode // next entry in the same hash bucket, or nil
}
202
// bindingSet is a set of bindings; Add skips bindings that are Equal to ones
// already stored.
type bindingSet struct {
	// table maps a bindings hash code to the head of a singly linked chain
	// of nodes whose values share that hash code.
	table map[int]*bindingSetNode
}
206
207func newBindingSet() *bindingSet {
208	return &bindingSet{
209		table: map[int]*bindingSetNode{},
210	}
211}
212
213func (set *bindingSet) Add(val *ast.ValueMap) {
214	node := set.getNode(val)
215	if node != nil {
216		return
217	}
218	hashCode := val.Hash()
219	head := set.table[hashCode]
220	set.table[hashCode] = &bindingSetNode{val, head}
221}
222
223func (set *bindingSet) Iter(iter func(*ast.ValueMap) error) error {
224	for _, head := range set.table {
225		for entry := head; entry != nil; entry = entry.next {
226			if err := iter(entry.val); err != nil {
227				return err
228			}
229		}
230	}
231	return nil
232}
233
234func (set *bindingSet) String() string {
235	buf := []string{}
236	set.Iter(func(bindings *ast.ValueMap) error {
237		buf = append(buf, bindings.String())
238		return nil
239	})
240	return "{" + strings.Join(buf, ", ") + "}"
241}
242
243func (set *bindingSet) getNode(val *ast.ValueMap) *bindingSetNode {
244	hashCode := val.Hash()
245	for entry := set.table[hashCode]; entry != nil; entry = entry.next {
246		if entry.val.Equal(val) {
247			return entry
248		}
249	}
250	return nil
251}
252
// hash computes a hash code for a JSON-like value.
//
//   - nil and false hash to 0, true hashes to 1.
//   - strings and json.Number values hash to the 64-bit FNV-1a sum of their
//     bytes, converted to int.
//   - arrays hash to the sum of their element hashes, and objects to the sum
//     of the hashes of all keys and values, so element order does not affect
//     the result.
//
// Any other type causes a panic.
//
// NOTE(review): numbers are hashed by their literal text, so json.Number("1")
// and json.Number("1.0") produce different codes. Confirm that the equality
// check used alongside this hash treats such values as distinct.
func hash(v interface{}) int {
	var text string

	switch x := v.(type) {
	case nil:
		return 0
	case bool:
		if !x {
			return 0
		}
		return 1
	case string:
		text = x
	case json.Number:
		text = string(x)
	case []interface{}:
		sum := 0
		for i := range x {
			sum += hash(x[i])
		}
		return sum
	case map[string]interface{}:
		sum := 0
		for key, elem := range x {
			sum += hash(key)
			sum += hash(elem)
		}
		return sum
	default:
		panic(fmt.Sprintf("illegal argument: %v (%T)", v, v))
	}

	// Strings and numbers share the same byte-wise FNV-1a hashing.
	hasher := fnv.New64a()
	hasher.Write([]byte(text))
	return int(hasher.Sum64())
}
285
286func iterStorage(ctx context.Context, store storage.Store, txn storage.Transaction, nonGround, ground ast.Ref, bindings *ast.ValueMap, iter func(interface{}, *ast.ValueMap)) error {
287
288	if len(nonGround) == 0 {
289		path, err := storage.NewPathForRef(ground)
290		if err != nil {
291			return err
292		}
293		node, err := store.Read(ctx, txn, path)
294		if err != nil {
295			if storage.IsNotFound(err) {
296				return nil
297			}
298			return err
299		}
300		iter(node, bindings)
301		return nil
302	}
303
304	head := nonGround[0]
305	tail := nonGround[1:]
306
307	headVar, isVar := head.Value.(ast.Var)
308
309	if !isVar || len(ground) == 0 {
310		ground = append(ground, head)
311		return iterStorage(ctx, store, txn, tail, ground, bindings, iter)
312	}
313
314	path, err := storage.NewPathForRef(ground)
315	if err != nil {
316		return err
317	}
318
319	node, err := store.Read(ctx, txn, path)
320	if err != nil {
321		if storage.IsNotFound(err) {
322			return nil
323		}
324		return err
325	}
326
327	switch node := node.(type) {
328	case map[string]interface{}:
329		for key := range node {
330			ground = append(ground, ast.StringTerm(key))
331			cpy := bindings.Copy()
332			cpy.Put(headVar, ast.String(key))
333			err := iterStorage(ctx, store, txn, tail, ground, cpy, iter)
334			if err != nil {
335				return err
336			}
337			ground = ground[:len(ground)-1]
338		}
339	case []interface{}:
340		for i := range node {
341			idx := ast.IntNumberTerm(i)
342			ground = append(ground, idx)
343			cpy := bindings.Copy()
344			cpy.Put(headVar, idx.Value)
345			err := iterStorage(ctx, store, txn, tail, ground, cpy, iter)
346			if err != nil {
347				return err
348			}
349			ground = ground[:len(ground)-1]
350		}
351	}
352
353	return nil
354}
355