1// Copyright 2017 The OPA Authors.  All rights reserved.
2// Use of this source code is governed by an Apache2
3// license that can be found in the LICENSE file.
4
5package topdown
6
7import (
8	"bytes"
9	"context"
10	"fmt"
11	"sync"
12	"testing"
13	"text/template"
14
15	"github.com/open-policy-agent/opa/ast"
16	"github.com/open-policy-agent/opa/storage"
17	"github.com/open-policy-agent/opa/storage/inmem"
18	"github.com/open-policy-agent/opa/util"
19)
20
21func BenchmarkLargeJSON(b *testing.B) {
22	data := generateLargeJSONBenchmarkData()
23	ctx := context.Background()
24	store := inmem.NewFromObject(data)
25	compiler := ast.NewCompiler()
26
27	if compiler.Compile(nil); compiler.Failed() {
28		b.Fatal(compiler.Errors)
29	}
30
31	b.ResetTimer()
32
33	// Read data.values N times inside query.
34	query := ast.MustParseBody("data.keys[_] = x; data.values = y")
35
36	for i := 0; i < b.N; i++ {
37
38		err := storage.Txn(ctx, store, storage.TransactionParams{}, func(txn storage.Transaction) error {
39
40			q := NewQuery(query).
41				WithCompiler(compiler).
42				WithStore(store).
43				WithTransaction(txn)
44
45			_, err := q.Run(ctx)
46			if err != nil {
47				return err
48			}
49
50			return nil
51		})
52
53		if err != nil {
54			b.Fatal(err)
55		}
56
57	}
58}
59
// generateLargeJSONBenchmarkData builds the base document used by
// BenchmarkLargeJSON. The returned map has two entries:
//
//   - "keys": a slice of 100 nil values that a query can iterate over
//   - "values": an object with 100,000 entries of the form "key<i>": "value<i>"
func generateLargeJSONBenchmarkData() map[string]interface{} {
	const (
		numKeys   = 100
		numValues = 100 * 1000 // 100,000 entries is about 2MB on disk
	)

	// make zero-fills the slice and the zero value of interface{} is nil, so
	// no explicit initialization loop is required.
	keys := make([]interface{}, numKeys)

	// The final size is known up front; pre-sizing avoids repeated map growth.
	values := make(map[string]interface{}, numValues)
	for i := 0; i < numValues; i++ {
		values[fmt.Sprintf("key%d", i)] = fmt.Sprintf("value%d", i)
	}

	return map[string]interface{}{
		"keys":   keys,
		"values": values,
	}
}
79
80func BenchmarkConcurrency1(b *testing.B) {
81	benchmarkConcurrency(b, getParams(1, 0))
82}
83
84func BenchmarkConcurrency2(b *testing.B) {
85	benchmarkConcurrency(b, getParams(2, 0))
86}
87
88func BenchmarkConcurrency4(b *testing.B) {
89	benchmarkConcurrency(b, getParams(4, 0))
90}
91
92func BenchmarkConcurrency8(b *testing.B) {
93	benchmarkConcurrency(b, getParams(8, 0))
94}
95
96func BenchmarkConcurrency4Readers1Writer(b *testing.B) {
97	benchmarkConcurrency(b, getParams(4, 1))
98}
99
100func BenchmarkConcurrency8Writers(b *testing.B) {
101	benchmarkConcurrency(b, getParams(0, 8))
102}
103
104func benchmarkConcurrency(b *testing.B, params []storage.TransactionParams) {
105
106	mod, data := generateConcurrencyBenchmarkData()
107	ctx := context.Background()
108	store := inmem.NewFromObject(data)
109	mods := map[string]*ast.Module{"module": mod}
110	compiler := ast.NewCompiler()
111
112	if compiler.Compile(mods); compiler.Failed() {
113		b.Fatalf("Unexpected compiler error: %v", compiler.Errors)
114	}
115
116	b.ResetTimer()
117
118	for i := 0; i < b.N; i++ {
119		wg := new(sync.WaitGroup)
120		queriesPerCore := 1000 / len(params)
121		for j := 0; j < len(params); j++ {
122			param := params[j] // capture j'th params before goroutine
123			wg.Add(1)
124			go func() {
125				defer wg.Done()
126				for k := 0; k < queriesPerCore; k++ {
127					txn := storage.NewTransactionOrDie(ctx, store, param)
128					query := NewQuery(ast.MustParseBody("data.test.p = x")).
129						WithCompiler(compiler).
130						WithStore(store).
131						WithTransaction(txn)
132					rs, err := query.Run(ctx)
133					if err != nil {
134						b.Fatalf("Unexpected topdown query error: %v", err)
135					}
136					if len(rs) != 1 || !rs[0][ast.Var("x")].Equal(ast.BooleanTerm(true)) {
137						b.Fatalf("Unexpected undefined/extra/bad result: %v", rs)
138					}
139					store.Abort(ctx, txn)
140				}
141			}()
142		}
143
144		wg.Wait()
145	}
146}
147
148func getParams(nReaders, nWriters int) (sl []storage.TransactionParams) {
149	for i := 0; i < nReaders; i++ {
150		sl = append(sl, storage.TransactionParams{})
151	}
152	for i := 0; i < nWriters; i++ {
153		sl = append(sl, storage.WriteParams)
154	}
155	return sl
156}
157
158func generateConcurrencyBenchmarkData() (*ast.Module, map[string]interface{}) {
159	obj := util.MustUnmarshalJSON([]byte(`
160		{
161			"objs": [
162				{
163					"attr1": "get",
164					"path": "/foo/bar",
165					"user": "bob"
166				},
167				{
168					"attr1": "set",
169					"path": "/foo/bar/baz",
170					"user": "alice"
171				},
172				{
173					"attr1": "get",
174					"path": "/foo",
175					"groups": [
176						"admin",
177						"eng"
178					]
179				},
180				{
181					"path": "/foo/bar",
182					"user": "alice"
183				}
184			]
185		}
186		`))
187
188	mod := `package test
189
190	import data.objs
191
192	p {
193		objs[i].attr1 = "get"
194		objs[i].groups[j] = "eng"
195	}
196
197	p {
198		objs[i].user = "alice"
199	}
200	`
201
202	return ast.MustParseModule(mod), obj.(map[string]interface{})
203}
204
205func BenchmarkVirtualDocs1x1(b *testing.B) {
206	runVirtualDocsBenchmark(b, 1, 1)
207}
208
209func BenchmarkVirtualDocs10x1(b *testing.B) {
210	runVirtualDocsBenchmark(b, 10, 1)
211}
212
213func BenchmarkVirtualDocs100x1(b *testing.B) {
214	runVirtualDocsBenchmark(b, 100, 1)
215}
216
217func BenchmarkVirtualDocs1000x1(b *testing.B) {
218	runVirtualDocsBenchmark(b, 1000, 1)
219}
220
221func BenchmarkVirtualDocs10x10(b *testing.B) {
222	runVirtualDocsBenchmark(b, 10, 10)
223}
224
225func BenchmarkVirtualDocs100x10(b *testing.B) {
226	runVirtualDocsBenchmark(b, 100, 10)
227}
228
229func BenchmarkVirtualDocs1000x10(b *testing.B) {
230	runVirtualDocsBenchmark(b, 1000, 10)
231}
232
233func runVirtualDocsBenchmark(b *testing.B, numTotalRules, numHitRules int) {
234
235	mod, input := generateVirtualDocsBenchmarkData(numTotalRules, numHitRules)
236	ctx := context.Background()
237	compiler := ast.NewCompiler()
238	mods := map[string]*ast.Module{"module": mod}
239	store := inmem.New()
240	txn := storage.NewTransactionOrDie(ctx, store)
241	if compiler.Compile(mods); compiler.Failed() {
242		b.Fatalf("Unexpected compiler error: %v", compiler.Errors)
243	}
244
245	query := NewQuery(ast.MustParseBody("data.a.b.c.allow = x")).
246		WithCompiler(compiler).
247		WithStore(store).
248		WithTransaction(txn).
249		WithInput(input)
250
251	b.ResetTimer()
252
253	for i := 0; i < b.N; i++ {
254		func() {
255			rs, err := query.Run(ctx)
256			if err != nil {
257				b.Fatalf("Unexpected topdown query error: %v", err)
258			}
259			if len(rs) != 1 || !rs[0][ast.Var("x")].Equal(ast.BooleanTerm(true)) {
260				b.Fatalf("Unexpecfted undefined/extra/bad result: %v", rs)
261			}
262		}()
263
264	}
265}
266
267func generateVirtualDocsBenchmarkData(numTotalRules, numHitRules int) (*ast.Module, *ast.Term) {
268
269	hitRule := `
270	allow {
271		input.method = "POST"
272		input.path = ["accounts", account_id]
273		input.user_id = account_id
274	}
275	`
276
277	missRule := `
278	allow {
279		input.method = "GET"
280		input.path = ["salaries", account_id]
281		input.user_id = account_id
282	}
283	`
284
285	testModuleTmpl := `
286	package a.b.c
287
288	{{range .MissRules }}
289		{{ . }}
290	{{end}}
291
292	{{range .HitRules }}
293		{{ . }}
294	{{end}}
295	`
296
297	tmpl, err := template.New("Test").Parse(testModuleTmpl)
298	if err != nil {
299		panic(err)
300	}
301
302	var buf bytes.Buffer
303
304	var missRules []string
305
306	if numTotalRules > numHitRules {
307		missRules = make([]string, numTotalRules-numHitRules)
308		for i := range missRules {
309			missRules[i] = missRule
310		}
311	}
312
313	hitRules := make([]string, numHitRules)
314	for i := range hitRules {
315		hitRules[i] = hitRule
316	}
317
318	params := struct {
319		MissRules []string
320		HitRules  []string
321	}{
322		MissRules: missRules,
323		HitRules:  hitRules,
324	}
325
326	err = tmpl.Execute(&buf, params)
327	if err != nil {
328		panic(err)
329	}
330
331	input := ast.MustParseTerm(`{
332			"path": ["accounts", "alice"],
333			"method": "POST",
334			"user_id": "alice"
335		}`)
336
337	return ast.MustParseModule(buf.String()), input
338}
339
340func BenchmarkPartialEval(b *testing.B) {
341	sizes := []int{1, 10, 100, 1000}
342	for _, n := range sizes {
343		b.Run(fmt.Sprint(n), func(b *testing.B) {
344			runPartialEvalBenchmark(b, n)
345		})
346	}
347}
348
349func BenchmarkPartialEvalCompile(b *testing.B) {
350	sizes := []int{1, 10, 100, 1000}
351	for _, n := range sizes {
352		b.Run(fmt.Sprint(n), func(b *testing.B) {
353			runPartialEvalCompileBenchmark(b, n)
354		})
355	}
356}
357
358func runPartialEvalBenchmark(b *testing.B, numRoles int) {
359
360	ctx := context.Background()
361	compiler := ast.NewCompiler()
362	compiler.Compile(map[string]*ast.Module{
363		"authz": ast.MustParseModule(partialEvalBenchmarkPolicy),
364	})
365
366	if compiler.Failed() {
367		b.Fatal(compiler.Errors)
368	}
369
370	var partials []ast.Body
371	var support []*ast.Module
372	data := generatePartialEvalBenchmarkData(numRoles)
373	store := inmem.NewFromObject(data)
374
375	err := storage.Txn(ctx, store, storage.TransactionParams{}, func(txn storage.Transaction) error {
376		query := NewQuery(ast.MustParseBody("data.authz.allow = true")).
377			WithUnknowns([]*ast.Term{ast.MustParseTerm("input")}).
378			WithCompiler(compiler).
379			WithStore(store).
380			WithTransaction(txn)
381		var err error
382		partials, support, err = query.PartialRun(ctx)
383		return err
384	})
385	if err != nil {
386		b.Fatal(err)
387	}
388
389	if len(partials) != numRoles {
390		b.Fatal("Expected exactly one partial query result but got:", partials)
391	} else if len(support) != 0 {
392		b.Fatal("Expected no partial support results but got:", support)
393	}
394
395	module := ast.MustParseModule(`package partial.authz`)
396
397	for _, query := range partials {
398		rule := &ast.Rule{
399			Head:   ast.NewHead(ast.Var("allow"), nil, ast.BooleanTerm(true)),
400			Body:   query,
401			Module: module,
402		}
403		module.Rules = append(module.Rules, rule)
404	}
405
406	compiler = ast.NewCompiler()
407	compiler.Compile(map[string]*ast.Module{
408		"partial": module,
409	})
410	if compiler.Failed() {
411		b.Fatal(compiler.Errors)
412	}
413
414	input := generatePartialEvalBenchmarkInput(numRoles)
415
416	err = storage.Txn(ctx, store, storage.TransactionParams{}, func(txn storage.Transaction) error {
417		query := NewQuery(ast.MustParseBody("data.partial.authz.allow = true")).
418			WithCompiler(compiler).
419			WithStore(store).
420			WithTransaction(txn).
421			WithInput(input)
422		b.ResetTimer()
423		for i := 0; i < b.N; i++ {
424			qrs, err := query.Run(ctx)
425			if len(qrs) != 1 || err != nil {
426				b.Fatal("Unexpected query result:", qrs, "err:", err)
427			}
428		}
429		return nil
430	})
431	if err != nil {
432		b.Fatal(err)
433	}
434}
435
436func runPartialEvalCompileBenchmark(b *testing.B, numRoles int) {
437
438	ctx := context.Background()
439	data := generatePartialEvalBenchmarkData(numRoles)
440	store := inmem.NewFromObject(data)
441
442	err := storage.Txn(ctx, store, storage.TransactionParams{}, func(txn storage.Transaction) error {
443
444		b.ResetTimer()
445
446		for i := 0; i < b.N; i++ {
447			// compile original policy
448			compiler := ast.NewCompiler()
449			compiler.Compile(map[string]*ast.Module{
450				"authz": ast.MustParseModule(partialEvalBenchmarkPolicy),
451			})
452			if compiler.Failed() {
453				return compiler.Errors
454			}
455
456			// run partial evaluation
457			var partials []ast.Body
458			var support []*ast.Module
459			query := NewQuery(ast.MustParseBody("data.authz.allow = true")).
460				WithUnknowns([]*ast.Term{ast.MustParseTerm("input")}).
461				WithCompiler(compiler).
462				WithStore(store).
463				WithTransaction(txn)
464			var err error
465			partials, support, err = query.PartialRun(ctx)
466			if err != nil {
467				return err
468			}
469
470			if len(partials) != numRoles {
471				b.Fatal("Expected exactly one partial query result but got:", partials)
472			} else if len(support) != 0 {
473				b.Fatal("Expected no partial support results but got:", support)
474			}
475
476			// recompile output
477			module := ast.MustParseModule(`package partial.authz`)
478
479			for _, query := range partials {
480				rule := &ast.Rule{
481					Head:   ast.NewHead(ast.Var("allow"), nil, ast.BooleanTerm(true)),
482					Body:   query,
483					Module: module,
484				}
485				module.Rules = append(module.Rules, rule)
486			}
487
488			compiler = ast.NewCompiler()
489			compiler.Compile(map[string]*ast.Module{
490				"test": module,
491			})
492
493			if compiler.Failed() {
494				b.Fatal(compiler.Errors)
495			}
496		}
497
498		return nil
499	})
500
501	if err != nil {
502		b.Fatal(err)
503	}
504}
505
// partialEvalBenchmarkPolicy is the Rego policy shared by the partial
// evaluation benchmarks. allow defaults to false and holds when some role
// name is in both of these sets:
//
//   - user_has_role: roles of entries in data.bindings whose iss and group
//     equal input.iss and input.group
//   - role_has_permission: names of entries in data.roles whose operation and
//     resource equal input.operation and input.resource
const partialEvalBenchmarkPolicy = `package authz

	default allow = false

	allow {
		user_has_role[role_name]
		role_has_permission[role_name]
	}

	user_has_role[role_name] {
		data.bindings[_] = binding
		binding.iss = input.iss
		binding.group = input.group
		role_name = binding.role
	}

	role_has_permission[role_name] {
		data.roles[_] = role
		role.name = role_name
		role.operation = input.operation
		role.resource = input.resource
	}
	`
529
// generatePartialEvalBenchmarkData builds the base document for the partial
// evaluation benchmarks. For every i in [0, numRoles) it creates
//
//   - a role with name "role-<i>", operation "operation-<i>" and resource
//     "resource-<i>", and
//   - a binding with name "binding-<i>", iss "iss-<i>", group "group-<i>"
//     that refers to the role by name.
//
// The result maps "roles" and "bindings" to slices of those objects.
func generatePartialEvalBenchmarkData(numRoles int) map[string]interface{} {
	var (
		roles    = make([]interface{}, numRoles)
		bindings = make([]interface{}, numRoles)
	)

	for idx := range roles {
		roleName := fmt.Sprintf("role-%d", idx)

		roles[idx] = map[string]interface{}{
			"name":      roleName,
			"operation": fmt.Sprintf("operation-%d", idx),
			"resource":  fmt.Sprintf("resource-%d", idx),
		}

		bindings[idx] = map[string]interface{}{
			"name":  fmt.Sprintf("binding-%d", idx),
			"iss":   fmt.Sprintf("iss-%d", idx),
			"group": fmt.Sprintf("group-%d", idx),
			"role":  roleName,
		}
	}

	result := map[string]interface{}{}
	result["roles"] = roles
	result["bindings"] = bindings
	return result
}
553
554func generatePartialEvalBenchmarkInput(numRoles int) *ast.Term {
555
556	tmpl, err := template.New("Test").Parse(`{
557		"operation": "operation-{{ . }}",
558		"resource": "resource-{{ . }}",
559		"iss": "iss-{{ . }}",
560		"group": "group-{{ . }}"
561	}`)
562	if err != nil {
563		panic(err)
564	}
565
566	var buf bytes.Buffer
567
568	err = tmpl.Execute(&buf, numRoles-1)
569	if err != nil {
570		panic(err)
571	}
572
573	return ast.MustParseTerm(buf.String())
574}
575
576func BenchmarkWalk(b *testing.B) {
577
578	ctx := context.Background()
579	sizes := []int{100, 1000, 2000, 3000}
580
581	for _, n := range sizes {
582		b.Run(fmt.Sprint(n), func(b *testing.B) {
583			data := genWalkBenchmarkData(n)
584			store := inmem.NewFromObject(data)
585			compiler := ast.NewCompiler()
586			b.ResetTimer()
587			for i := 0; i < b.N; i++ {
588				err := storage.Txn(ctx, store, storage.TransactionParams{}, func(txn storage.Transaction) error {
589					query := ast.MustParseBody(fmt.Sprintf(`walk(data, [["arr", %v], x])`, n-1))
590					compiledQuery, err := compiler.QueryCompiler().Compile(query)
591					if err != nil {
592						b.Fatal(err)
593					}
594					q := NewQuery(compiledQuery).
595						WithStore(store).
596						WithCompiler(compiler).
597						WithTransaction(txn)
598					rs, err := q.Run(ctx)
599					if err != nil || len(rs) != 1 || !rs[0][ast.Var("x")].Equal(ast.IntNumberTerm(n-1)) {
600						b.Fatal("Unexpected result:", rs, "err:", err)
601					}
602					return nil
603				})
604				if err != nil {
605					b.Fatal(err)
606				}
607			}
608		})
609	}
610
611}
612
// genWalkBenchmarkData returns a document whose "arr" entry is a slice of
// length n where the element at index i is the int i.
func genWalkBenchmarkData(n int) map[string]interface{} {
	arr := make([]interface{}, n)
	for idx := range arr {
		arr[idx] = idx
	}
	doc := map[string]interface{}{}
	doc["arr"] = arr
	return doc
}
622
623func BenchmarkComprehensionIndexing(b *testing.B) {
624	ctx := context.Background()
625	sizes := []int{10, 100, 1000}
626	for _, n := range sizes {
627		b.Run(fmt.Sprint(n), func(b *testing.B) {
628			data := genComprehensionIndexingData(n)
629			store := inmem.NewFromObject(data)
630			compiler := ast.MustCompileModules(map[string]string{
631				"test.rego": `
632				package test
633
634				p {
635					v := data.items[_]
636					ks := [k | some k; v == data.items[k]]
637				}
638			`,
639			})
640			query, err := compiler.QueryCompiler().Compile(ast.MustParseBody(`data.test.p = true`))
641			if err != nil {
642				b.Fatal(err)
643			}
644			b.ResetTimer()
645			for i := 0; i < b.N; i++ {
646				err = storage.Txn(ctx, store, storage.TransactionParams{}, func(txn storage.Transaction) error {
647					q := NewQuery(query).WithStore(store).WithCompiler(compiler).WithTransaction(txn)
648					rs, err := q.Run(ctx)
649					if err != nil || len(rs) != 1 {
650						b.Fatal("Unexpected result:", rs, "err:", err)
651					}
652					return nil
653				})
654				if err != nil {
655					b.Fatal(err)
656				}
657
658			}
659		})
660	}
661}
662
// genComprehensionIndexingData returns a document whose "items" entry is an
// object with n entries; each key is the decimal string of an index in
// [0, n) and its value is that same string.
func genComprehensionIndexingData(n int) map[string]interface{} {
	items := make(map[string]interface{})
	for idx := 0; idx < n; idx++ {
		label := fmt.Sprint(idx)
		items[label] = label
	}
	doc := map[string]interface{}{}
	doc["items"] = items
	return doc
}
670