// Copyright 2016 The OPA Authors.  All rights reserved.
// Use of this source code is governed by an Apache2
// license that can be found in the LICENSE file.

package inmem

import (
	"context"
	"encoding/json"
	"fmt"
	"hash/fnv"
	"strings"
	"sync"

	"github.com/open-policy-agent/opa/ast"
	"github.com/open-policy-agent/opa/storage"
	"github.com/open-policy-agent/opa/util"
)

// indices contains a mapping of non-ground references to values to sets of bindings.
//
//	+------+------------------------------------+
//	| ref1 | val1 | bindings-1, bindings-2, ... |
//	|      +------+-----------------------------+
//	|      | val2 | bindings-m, bindings-m, ... |
//	|      +------+-----------------------------+
//	|      | .... | ...                         |
//	+------+------+-----------------------------+
//	| ref2 | .... | ...                         |
//	+------+------+-----------------------------+
//	| ...                                       |
//	+-------------------------------------------+
//
// The "value" is the data value stored at the location referred to by the ground
// reference obtained by plugging bindings into the non-ground reference that is the
// index key.
37// 38type indices struct { 39 mu sync.Mutex 40 table map[int]*indicesNode 41} 42 43type indicesNode struct { 44 key ast.Ref 45 val *bindingIndex 46 next *indicesNode 47} 48 49func newIndices() *indices { 50 return &indices{ 51 table: map[int]*indicesNode{}, 52 } 53} 54 55func (ind *indices) Build(ctx context.Context, store storage.Store, txn storage.Transaction, ref ast.Ref) (*bindingIndex, error) { 56 57 ind.mu.Lock() 58 defer ind.mu.Unlock() 59 60 if exist := ind.get(ref); exist != nil { 61 return exist, nil 62 } 63 64 index := newBindingIndex() 65 66 if err := iterStorage(ctx, store, txn, ref, ast.EmptyRef(), ast.NewValueMap(), index.Add); err != nil { 67 return nil, err 68 } 69 70 hashCode := ref.Hash() 71 head := ind.table[hashCode] 72 entry := &indicesNode{ 73 key: ref, 74 val: index, 75 next: head, 76 } 77 78 ind.table[hashCode] = entry 79 80 return index, nil 81} 82 83func (ind *indices) get(ref ast.Ref) *bindingIndex { 84 node := ind.getNode(ref) 85 if node != nil { 86 return node.val 87 } 88 return nil 89} 90 91func (ind *indices) iter(iter func(ast.Ref, *bindingIndex) error) error { 92 for _, head := range ind.table { 93 for entry := head; entry != nil; entry = entry.next { 94 if err := iter(entry.key, entry.val); err != nil { 95 return err 96 } 97 } 98 } 99 return nil 100} 101 102func (ind *indices) getNode(ref ast.Ref) *indicesNode { 103 hashCode := ref.Hash() 104 for entry := ind.table[hashCode]; entry != nil; entry = entry.next { 105 if entry.key.Equal(ref) { 106 return entry 107 } 108 } 109 return nil 110} 111 112func (ind *indices) String() string { 113 buf := []string{} 114 for _, head := range ind.table { 115 for entry := head; entry != nil; entry = entry.next { 116 str := fmt.Sprintf("%v: %v", entry.key, entry.val) 117 buf = append(buf, str) 118 } 119 } 120 return "{" + strings.Join(buf, ", ") + "}" 121} 122 123const ( 124 triggerID = "org.openpolicyagent/index-maintenance" 125) 126 127// bindingIndex contains a mapping of values to bindings. 
128type bindingIndex struct { 129 table map[int]*indexNode 130} 131 132type indexNode struct { 133 key interface{} 134 val *bindingSet 135 next *indexNode 136} 137 138func newBindingIndex() *bindingIndex { 139 return &bindingIndex{ 140 table: map[int]*indexNode{}, 141 } 142} 143 144func (ind *bindingIndex) Add(val interface{}, bindings *ast.ValueMap) { 145 146 node := ind.getNode(val) 147 if node != nil { 148 node.val.Add(bindings) 149 return 150 } 151 152 hashCode := hash(val) 153 bindingsSet := newBindingSet() 154 bindingsSet.Add(bindings) 155 156 entry := &indexNode{ 157 key: val, 158 val: bindingsSet, 159 next: ind.table[hashCode], 160 } 161 162 ind.table[hashCode] = entry 163} 164 165func (ind *bindingIndex) Lookup(_ context.Context, _ storage.Transaction, val interface{}, iter storage.IndexIterator) error { 166 node := ind.getNode(val) 167 if node == nil { 168 return nil 169 } 170 return node.val.Iter(iter) 171} 172 173func (ind *bindingIndex) getNode(val interface{}) *indexNode { 174 hashCode := hash(val) 175 head := ind.table[hashCode] 176 for entry := head; entry != nil; entry = entry.next { 177 if util.Compare(entry.key, val) == 0 { 178 return entry 179 } 180 } 181 return nil 182} 183 184func (ind *bindingIndex) String() string { 185 186 buf := []string{} 187 188 for _, head := range ind.table { 189 for entry := head; entry != nil; entry = entry.next { 190 str := fmt.Sprintf("%v: %v", entry.key, entry.val) 191 buf = append(buf, str) 192 } 193 } 194 195 return "{" + strings.Join(buf, ", ") + "}" 196} 197 198type bindingSetNode struct { 199 val *ast.ValueMap 200 next *bindingSetNode 201} 202 203type bindingSet struct { 204 table map[int]*bindingSetNode 205} 206 207func newBindingSet() *bindingSet { 208 return &bindingSet{ 209 table: map[int]*bindingSetNode{}, 210 } 211} 212 213func (set *bindingSet) Add(val *ast.ValueMap) { 214 node := set.getNode(val) 215 if node != nil { 216 return 217 } 218 hashCode := val.Hash() 219 head := set.table[hashCode] 220 
set.table[hashCode] = &bindingSetNode{val, head} 221} 222 223func (set *bindingSet) Iter(iter func(*ast.ValueMap) error) error { 224 for _, head := range set.table { 225 for entry := head; entry != nil; entry = entry.next { 226 if err := iter(entry.val); err != nil { 227 return err 228 } 229 } 230 } 231 return nil 232} 233 234func (set *bindingSet) String() string { 235 buf := []string{} 236 set.Iter(func(bindings *ast.ValueMap) error { 237 buf = append(buf, bindings.String()) 238 return nil 239 }) 240 return "{" + strings.Join(buf, ", ") + "}" 241} 242 243func (set *bindingSet) getNode(val *ast.ValueMap) *bindingSetNode { 244 hashCode := val.Hash() 245 for entry := set.table[hashCode]; entry != nil; entry = entry.next { 246 if entry.val.Equal(val) { 247 return entry 248 } 249 } 250 return nil 251} 252 253func hash(v interface{}) int { 254 switch v := v.(type) { 255 case []interface{}: 256 var h int 257 for _, e := range v { 258 h += hash(e) 259 } 260 return h 261 case map[string]interface{}: 262 var h int 263 for k, v := range v { 264 h += hash(k) + hash(v) 265 } 266 return h 267 case string: 268 h := fnv.New64a() 269 h.Write([]byte(v)) 270 return int(h.Sum64()) 271 case bool: 272 if v { 273 return 1 274 } 275 return 0 276 case nil: 277 return 0 278 case json.Number: 279 h := fnv.New64a() 280 h.Write([]byte(v)) 281 return int(h.Sum64()) 282 } 283 panic(fmt.Sprintf("illegal argument: %v (%T)", v, v)) 284} 285 286func iterStorage(ctx context.Context, store storage.Store, txn storage.Transaction, nonGround, ground ast.Ref, bindings *ast.ValueMap, iter func(interface{}, *ast.ValueMap)) error { 287 288 if len(nonGround) == 0 { 289 path, err := storage.NewPathForRef(ground) 290 if err != nil { 291 return err 292 } 293 node, err := store.Read(ctx, txn, path) 294 if err != nil { 295 if storage.IsNotFound(err) { 296 return nil 297 } 298 return err 299 } 300 iter(node, bindings) 301 return nil 302 } 303 304 head := nonGround[0] 305 tail := nonGround[1:] 306 307 headVar, 
isVar := head.Value.(ast.Var) 308 309 if !isVar || len(ground) == 0 { 310 ground = append(ground, head) 311 return iterStorage(ctx, store, txn, tail, ground, bindings, iter) 312 } 313 314 path, err := storage.NewPathForRef(ground) 315 if err != nil { 316 return err 317 } 318 319 node, err := store.Read(ctx, txn, path) 320 if err != nil { 321 if storage.IsNotFound(err) { 322 return nil 323 } 324 return err 325 } 326 327 switch node := node.(type) { 328 case map[string]interface{}: 329 for key := range node { 330 ground = append(ground, ast.StringTerm(key)) 331 cpy := bindings.Copy() 332 cpy.Put(headVar, ast.String(key)) 333 err := iterStorage(ctx, store, txn, tail, ground, cpy, iter) 334 if err != nil { 335 return err 336 } 337 ground = ground[:len(ground)-1] 338 } 339 case []interface{}: 340 for i := range node { 341 idx := ast.IntNumberTerm(i) 342 ground = append(ground, idx) 343 cpy := bindings.Copy() 344 cpy.Put(headVar, idx.Value) 345 err := iterStorage(ctx, store, txn, tail, ground, cpy, iter) 346 if err != nil { 347 return err 348 } 349 ground = ground[:len(ground)-1] 350 } 351 } 352 353 return nil 354} 355