1// Copyright 2017 The OPA Authors.  All rights reserved.
2// Use of this source code is governed by an Apache2
3// license that can be found in the LICENSE file.
4
5package ast
6
7import (
8	"encoding/json"
9	"fmt"
10	"reflect"
11	"strings"
12	"testing"
13
14	"github.com/open-policy-agent/opa/types"
15	"github.com/open-policy-agent/opa/util"
16	"github.com/open-policy-agent/opa/util/test"
17)
18
// TestCheckInference verifies that the type checker infers the expected
// types for variables bound in standalone query bodies. The cases cover
// unification of scalars and composites, references into local values,
// builtin calls, comprehensions, and type unioning across heterogeneous
// collections.
func TestCheckInference(t *testing.T) {

	// The three fake builtins below are declared with no input arguments
	// (nil) and only a result type. The *-builtin cases further down rely on
	// the single call operand being checked against that declared type, so
	// that the types of vars nested inside the operand can be inferred.
	//
	// NOTE(review): RegisterBuiltin appears to mutate shared package state
	// (BuiltinMap is read by the loop below), and TestCheckBuiltinErrors
	// re-registers fake_builtin_2 with a different declaration. The outcome of
	// these cases may therefore depend on test execution order.

	// fake_builtin_1([str1,str2])
	RegisterBuiltin(&Builtin{
		Name: "fake_builtin_1",
		Decl: types.NewFunction(
			nil,
			types.NewArray(
				[]types.Type{types.S, types.S}, nil,
			),
		),
	})

	// fake_builtin_2({"a":str1,"b":str2})
	RegisterBuiltin(&Builtin{
		Name: "fake_builtin_2",
		Decl: types.NewFunction(
			nil,
			types.NewObject(
				[]*types.StaticProperty{
					{Key: "a", Value: types.S},
					{Key: "b", Value: types.S},
				}, nil,
			),
		),
	})

	// fake_builtin_3({str1,str2,...})
	RegisterBuiltin(&Builtin{
		Name: "fake_builtin_3",
		Decl: types.NewFunction(
			nil,
			types.NewSet(types.S),
		),
	})

	// Each case maps a query to the types expected for (a subset of) the
	// vars it binds. A nil expected map only asserts that the body checks
	// without error.
	tests := []struct {
		note     string
		query    string
		expected map[Var]types.Type
	}{
		{"trivial", `x = 1`, map[Var]types.Type{
			Var("x"): types.N,
		}},
		{"one-level", "y = 1; x = y", map[Var]types.Type{
			Var("x"): types.N,
			Var("y"): types.N,
		}},
		{"two-level", "z = 1; y = z; x = y", map[Var]types.Type{
			Var("x"): types.N,
			Var("y"): types.N,
			Var("z"): types.N,
		}},
		{"array-nested", "[x, 1] = [true, y]", map[Var]types.Type{
			Var("x"): types.B,
			Var("y"): types.N,
		}},
		{"array-transitive", "y = [[2], 1]; [[x], 1] = y", map[Var]types.Type{
			Var("x"): types.N,
			Var("y"): types.NewArray(
				[]types.Type{
					types.NewArray([]types.Type{types.N}, nil),
					types.N,
				}, nil),
		}},
		{"array-embedded", `[1, "2", x] = data.foo`, map[Var]types.Type{
			Var("x"): types.A,
		}},
		{"object-nested", `{"a": "foo", "b": {"c": x}} = {"a": y, "b": {"c": 2}}`, map[Var]types.Type{
			Var("x"): types.N,
			Var("y"): types.S,
		}},
		{"object-transitive", `y = {"a": "foo", "b": 2}; {"a": z, "b": x} = y`, map[Var]types.Type{
			Var("x"): types.N,
			Var("z"): types.S,
		}},
		{"object-embedded", `{"1": "2", "2": x} = data.foo`, map[Var]types.Type{
			Var("x"): types.A,
		}},
		{"object-numeric-key", `x = {1: 2}; y = 1; x[y]`, map[Var]types.Type{
			Var("x"): types.NewObject([]*types.StaticProperty{{Key: json.Number("1"), Value: types.N}}, nil),
			Var("y"): types.N,
		}},
		{"object-object-key", `x = {{{}: 1}: 1}`, map[Var]types.Type{
			Var("x"): types.NewObject(
				nil,
				types.NewDynamicProperty(
					types.NewObject(
						[]*types.StaticProperty{types.NewStaticProperty(map[string]interface{}{}, types.N)},
						nil,
					),
					types.N,
				),
			),
		}},
		{"sets", `x = {1, 2}; y = {{"foo", 1}, x}`, map[Var]types.Type{
			Var("x"): types.NewSet(types.N),
			Var("y"): types.NewSet(
				types.NewAny(
					types.NewSet(
						types.NewAny(types.N, types.S),
					),
					types.NewSet(
						types.N,
					),
				),
			),
		}},
		{"sets-nested", `{"a", 1, 2} = {1,2,3}`, nil},
		{"sets-composite-ref-operand", `s = {[1, 2], [3, 4]}; s[[x, y]]`, map[Var]types.Type{
			Var("x"): types.N,
			Var("y"): types.N,
			Var("s"): types.NewSet(types.NewArray([]types.Type{types.N, types.N}, nil)),
		}},
		{"empty-composites", `
				obj = {};
				arr = [];
				set = set();
				obj[i] = v1;
				arr[j] = v2;
				set[v3];
				obj = {"foo": "bar"};
				arr = [1];
				set = {1,2,3}
				`, map[Var]types.Type{
			Var("obj"): types.NewObject(nil, types.NewDynamicProperty(types.A, types.A)),
			Var("i"):   types.A,
			Var("v1"):  types.A,
			Var("arr"): types.NewArray(nil, types.A),
			Var("j"):   types.N,
			Var("v2"):  types.A,
			Var("set"): types.NewSet(types.A),
			Var("v3"):  types.A,
		}},
		{"empty-composite-property", `
			obj = {};
			obj.foo = x;
			obj[i].foo = y
		`, map[Var]types.Type{
			Var("x"): types.A,
			Var("y"): types.A,
		}},
		{"local-reference", `
			a = [
				1,
				{
					"foo": [
						{"bar": null},
						-1,
						{"bar": true}
					]
				},
				3];

			x = a[1].foo[_].bar`, map[Var]types.Type{
			Var("x"): types.NewAny(types.NewNull(), types.B),
		}},
		{"local-reference-var", `

			a = [
				{
					"a": null,
					"b": {
						"foo": {
							"c": {1,},
						},
						"bar": {
							"c": {"hello",},
						},
					},
				},
				{
					"a": null,
					"b": {
						"foo": {
							"c": {1,},
						},
						"bar": {
							"c": {true,},
						},
					},
				},
			];
			x = a[i].b[j].c[k]
			`, map[Var]types.Type{
			Var("i"): types.N,
			Var("j"): types.S,
			Var("k"): types.NewAny(types.S, types.N, types.B),
			Var("x"): types.NewAny(types.S, types.N, types.B),
		}},
		{"local-reference-var-any", `
			a = [[], {}];
			a[_][i]
		`, map[Var]types.Type{
			Var("i"): types.A,
		}},
		{"local-reference-nested", `
			a = [["foo"], 0, {"bar": "baz"}, 2];
			b = [0,1,2,3];
			a[b[_]][k] = v
			`, map[Var]types.Type{
			Var("k"): types.NewAny(types.S, types.N),
		}},
		{"simple-built-in", "plus(1,2,x)", map[Var]types.Type{
			Var("x"): types.N,
		}},
		{"simple-built-in-exists", "plus(1,2,x); plus(x,2,y)", map[Var]types.Type{
			Var("x"): types.N,
			Var("y"): types.N,
		}},
		{"array-builtin", `fake_builtin_1([x,"foo"])`, map[Var]types.Type{
			Var("x"): types.S,
		}},
		{"object-builtin", `fake_builtin_2({"a": "foo", "b": x})`, map[Var]types.Type{
			Var("x"): types.S,
		}},
		{"set-builtin", `fake_builtin_3({"foo", x})`, map[Var]types.Type{
			Var("x"): types.S,
		}},
		{"array-comprehension-ref-closure", `a = [1,"foo",3]; x = [ i | a[_] = i ]`, map[Var]types.Type{
			Var("x"): types.NewArray(nil, types.NewAny(types.N, types.S)),
		}},
		{"array-comprehension-var-closure", `x = 1; y = [ i | x = i ]`, map[Var]types.Type{
			Var("y"): types.NewArray(nil, types.N),
		}},
		{"dynamic-object-value", `q = {"a": "b", "c": "d"}; {k: [v]} = {k: [q[k]]}`, map[Var]types.Type{
			Var("k"): types.S,
			Var("v"): types.A,
		}},
		{
			note:  "type unioning: arrays",
			query: `x = [[1], ["foo"]]; x[_] = [y]`,
			expected: map[Var]types.Type{
				Var("y"): types.NewAny(
					types.N, types.S,
				),
			},
		},
		{
			note:  "type unioning: sets",
			query: `x = {[1], ["foo"]}; x[[y]]`,
			expected: map[Var]types.Type{
				Var("y"): types.NewAny(
					types.N, types.S,
				),
			},
		},
		{
			note:  "type unioning: object values",
			query: `x = {"a": [1], "b": ["foo"]}; x[_] = [y]`,
			expected: map[Var]types.Type{
				Var("y"): types.NewAny(
					types.N, types.S,
				),
			},
		},
	}

	for _, tc := range tests {
		test.Subtest(t, tc.note, func(t *testing.T) {
			body := MustParseBody(tc.query)
			checker := newTypeChecker()
			// Seed the environment with the builtin declarations (including
			// the fakes registered above) so that calls can be checked.
			env := checker.checkLanguageBuiltins(nil, BuiltinMap)
			env, err := checker.CheckBody(env, body)
			if len(err) != 0 {
				t.Fatalf("Unexpected error: %v", err)
			}
			// Only vars listed in tc.expected are inspected. A nil expected
			// type asserts that the var was left untyped.
			for k, tpe := range tc.expected {
				result := env.Get(k)
				if tpe == nil {
					if result != nil {
						t.Errorf("Expected %v type to be unset but got: %v", k, result)
					}
				} else {
					if result == nil {
						t.Errorf("Expected to infer %v => %v but got nil", k, tpe)
					} else if types.Compare(tpe, result) != 0 {
						t.Errorf("Expected to infer %v => %v but got %v", k, tpe, result)
					}
				}
			}
		})
	}
}
303
304func TestCheckInferenceRules(t *testing.T) {
305
306	// Rules must have refs resolved, safe ordering, etc. Each pair is a
307	// (package path, rule) tuple. The test constructs the Rule objects to
308	// run the inference on from these inputs.
309	ruleset1 := [][2]string{
310		{`a`, `trivial = true { true }`},
311		{`a`, `complete = [{"foo": x}] { x = 1 }`},
312		{`a`, `partialset[{"foo": x}] { y = "bar"; x = y }`},
313		{`a`, `partialobj[x] = {"foo": y} { y = "bar"; x = y }`},
314		{`b`, `trivial_ref = x { x = data.a.trivial }`},
315		{`b`, `transitive_ref = [x] { y = data.b.trivial_ref; x = y }`},
316		{`c`, `else_kw = null { false } else = 100 { true } else = "foo" { true }`},
317		{`iteration`, `arr = [[1], ["two"], {"x": true}, ["four"]] { true }`},
318		{`iteration`, `values[x] { data.iteration.arr[_][_] = x } `},
319		{`iteration`, `keys[i] { data.iteration.arr[_][i] = _ } `},
320		{`disjunction`, `partialset[1] { true }`},
321		{`disjunction`, `partialset[x] { x = "foo" }`},
322		{`disjunction`, `partialset[3] { true }`},
323		{`disjunction`, `partialobj[x] = y { y = "bar"; x = "foo" }`},
324		{`disjunction`, `partialobj[x] = y { y = 100; x = "foo" }`},
325		{`disjunction`, `complete = 1 { true }`},
326		{`disjunction`, `complete = x { x = "foo" }`},
327		{`prefix.a.b.c`, `d = true { true }`},
328		{`prefix.i.j.k`, `p = 1 { true }`},
329		{`prefix.i.j.k`, `p = "foo" { true }`},
330		{`default_rule`, `default x = 1`},
331		{`default_rule`, `x = "foo" { true }`},
332		{`unknown_type`, `p = [x] { x = data.deadbeef }`},
333		{`nested_ref`, `inner = {"a": 0, "b": "1"} { true }`},
334		{`nested_ref`, `middle = [[1, true], ["foo", false]] { true }`},
335		{`nested_ref`, `p = x { data.nested_ref.middle[data.nested_ref.inner.a][0] = x }`},
336		{`number_key`, `q[x] = y { a = ["a", "b"]; y = a[x] }`},
337		{`non_leaf`, `p[x] { data.prefix.i[x][_] }`},
338	}
339
340	tests := []struct {
341		note     string
342		rules    [][2]string
343		ref      string
344		expected types.Type
345	}{
346		{"trivial", ruleset1, `data.a.trivial`, types.B},
347
348		{"complete-doc", ruleset1, `data.a.complete`, types.NewArray(
349			[]types.Type{types.NewObject(
350				[]*types.StaticProperty{{
351					Key: "foo", Value: types.N,
352				}},
353				nil,
354			)},
355			nil,
356		)},
357
358		{"complete-doc-suffix", ruleset1, `data.a.complete[0].foo`, types.N},
359
360		{"else-kw", ruleset1, "data.c.else_kw", types.NewAny(types.NewNull(), types.N, types.S)},
361
362		{"partial-set-doc", ruleset1, `data.a.partialset`, types.NewSet(
363			types.NewObject(
364				[]*types.StaticProperty{{
365					Key: "foo", Value: types.S,
366				}},
367				nil,
368			),
369		)},
370
371		{"partial-object-doc", ruleset1, "data.a.partialobj", types.NewObject(
372			nil,
373			types.NewDynamicProperty(types.S, types.NewObject(
374				[]*types.StaticProperty{{
375					Key: "foo", Value: types.S,
376				}},
377				nil,
378			)),
379		)},
380
381		{"partial-object-doc-suffix", ruleset1, `data.a.partialobj.somekey.foo`, types.S},
382
383		{"partial-object-doc-number-suffix", ruleset1, "data.number_key.q[1]", types.S},
384
385		{"iteration", ruleset1, "data.iteration.values", types.NewSet(
386			types.NewAny(
387				types.S,
388				types.N,
389				types.B),
390		)},
391
392		{"iteration-keys", ruleset1, "data.iteration.keys", types.NewSet(
393			types.NewAny(
394				types.S,
395				types.N,
396			),
397		)},
398
399		{"disj-complete-doc", ruleset1, "data.disjunction.complete", types.NewAny(
400			types.S,
401			types.N,
402		)},
403
404		{"disj-partial-set-doc", ruleset1, "data.disjunction.partialset", types.NewSet(
405			types.NewAny(
406				types.S,
407				types.N),
408		)},
409
410		{"disj-partial-obj-doc", ruleset1, "data.disjunction.partialobj", types.NewObject(
411			nil,
412			types.NewDynamicProperty(types.S, types.NewAny(types.S, types.N)),
413		)},
414
415		{"ref", ruleset1, "data.b.trivial_ref", types.B},
416
417		{"ref-transitive", ruleset1, "data.b.transitive_ref", types.NewArray(
418			[]types.Type{
419				types.B,
420			},
421			nil,
422		)},
423
424		{"prefix", ruleset1, `data.prefix.a.b`, types.NewObject(
425			[]*types.StaticProperty{{
426				Key: "c", Value: types.NewObject(
427					[]*types.StaticProperty{{Key: "d", Value: types.B}},
428					types.NewDynamicProperty(types.S, types.A),
429				),
430			}},
431			types.NewDynamicProperty(types.S, types.A),
432		)},
433
434		// Check that prefixes that iterate fallback to any.
435		{"prefix-iter", ruleset1, `data.prefix.i.j[k]`, types.A},
436
437		// Check that iteration targeting a rule (but nonetheless prefixed) falls back to any.
438		{"prefix-iter-2", ruleset1, `data.prefix.i.j[k].p`, types.A},
439
440		{"default-rule", ruleset1, "data.default_rule.x", types.NewAny(
441			types.S,
442			types.N,
443		)},
444
445		{"unknown-type", ruleset1, "data.unknown_type.p", types.NewArray(
446			[]types.Type{
447				types.A,
448			},
449			nil,
450		)},
451
452		{"nested-ref", ruleset1, "data.nested_ref.p", types.NewAny(
453			types.S,
454			types.N,
455		)},
456
457		{"non-leaf", ruleset1, "data.non_leaf.p", types.NewSet(
458			types.S,
459		)},
460	}
461
462	for _, tc := range tests {
463		test.Subtest(t, tc.note, func(t *testing.T) {
464			var elems []util.T
465
466			// Convert test rules into rule slice for call.
467			for i := range tc.rules {
468				pkg := MustParsePackage(`package ` + tc.rules[i][0])
469				rule := MustParseRule(tc.rules[i][1])
470				module := &Module{
471					Package: pkg,
472					Rules:   []*Rule{rule},
473				}
474				rule.Module = module
475				elems = append(elems, rule)
476				for next := rule.Else; next != nil; next = next.Else {
477					next.Module = module
478					elems = append(elems, next)
479				}
480			}
481
482			ref := MustParseRef(tc.ref)
483			checker := newTypeChecker()
484			env, err := checker.CheckTypes(nil, elems)
485
486			if err != nil {
487				t.Fatalf("Unexpected error %v:", err)
488			}
489
490			result := env.Get(ref)
491			if tc.expected == nil {
492				if result != nil {
493					t.Errorf("Expected %v type to be unset but got: %v", ref, result)
494				}
495			} else {
496				if result == nil {
497					t.Errorf("Expected to infer %v => %v but got nil", ref, tc.expected)
498				} else if types.Compare(tc.expected, result) != 0 {
499					t.Errorf("Expected to infer %v => %v but got %v", ref, tc.expected, result)
500				}
501			}
502		})
503	}
504
505}
506
507func TestCheckErrorSuppression(t *testing.T) {
508
509	query := `arr = [1,2,3]; arr[0].deadbeef = 1`
510
511	_, errs := newTypeChecker().CheckBody(nil, MustParseBody(query))
512	if len(errs) != 1 {
513		t.Fatalf("Expected exactly one error but got: %v", errs)
514	}
515
516	_, ok := errs[0].Details.(*RefErrUnsupportedDetail)
517	if !ok {
518		t.Fatalf("Expected ref error but got: %v", errs)
519	}
520
521	query = `_ = [true | count(1)]`
522
523	_, errs = newTypeChecker().CheckBody(newTypeChecker().checkLanguageBuiltins(nil, BuiltinMap), MustParseBody(query))
524	if len(errs) != 1 {
525		t.Fatalf("Expected exactly one error but got: %v", errs)
526	}
527
528	_, ok = errs[0].Details.(*ArgErrDetail)
529	if !ok {
530		t.Fatalf("Expected arg error but got: %v", errs)
531	}
532
533}
534
535func TestCheckBadCardinality(t *testing.T) {
536	tests := []struct {
537		body string
538		exp  []types.Type
539	}{
540		{
541			body: "plus(1)",
542			exp:  []types.Type{types.N},
543		},
544		{
545			body: "plus(1, 2, 3, 4)",
546			exp:  []types.Type{types.N, types.N, types.N, types.N},
547		},
548	}
549	for _, test := range tests {
550		body := MustParseBody(test.body)
551		tc := newTypeChecker()
552		env := tc.checkLanguageBuiltins(nil, BuiltinMap)
553		_, err := tc.CheckBody(env, body)
554		if len(err) != 1 || err[0].Code != TypeErr {
555			t.Fatalf("Expected 1 type error from %v but got: %v", body, err)
556		}
557		detail, ok := err[0].Details.(*ArgErrDetail)
558		if !ok {
559			t.Fatalf("Expected argument error details but got: %v", err)
560		}
561		if len(test.exp) != len(detail.Have) {
562			t.Fatalf("Expected arg types %v but got: %v", test.exp, detail.Have)
563		}
564		for i := range test.exp {
565			if types.Compare(test.exp[i], detail.Have[i]) != 0 {
566				t.Fatalf("Expected types for %v to be %v but got: %v", body[0], test.exp, detail.Have)
567			}
568		}
569	}
570}
571
572func TestCheckMatchErrors(t *testing.T) {
573	tests := []struct {
574		note  string
575		query string
576	}{
577		{"null", "null = true"},
578		{"boolean", "true = null"},
579		{"number", "1 = null"},
580		{"string", `"hello" = null`},
581		{"array", "[1,2,3] = null"},
582		{"array-nested", `[1,2,3] = [1,2,"3"]`},
583		{"array-nested-2", `[1,2] = [1,2,3]`},
584		{"array-dynamic", `[ true | true ] = [x | a = [1, "foo"]; x = a[_]]`},
585		{"object", `{"a": 1, "b": 2} = null`},
586		{"object-nested", `{"a": 1, "b": "2"} = {"a": 1, "b": 2}`},
587		{"object-nested-2", `{"a": 1} = {"a": 1, "b": "2"}`},
588		{"set", "{1,2,3} = null"},
589		{"any", `x = ["str", 1]; x[_] = null`},
590	}
591	for _, tc := range tests {
592		test.Subtest(t, tc.note, func(t *testing.T) {
593			body := MustParseBody(tc.query)
594			checker := newTypeChecker()
595			_, err := checker.CheckBody(nil, body)
596			if len(err) != 1 {
597				t.Fatalf("Expected exactly one error from %v, but got:\n%v", body, err)
598			}
599		})
600	}
601}
602
603func TestCheckBuiltinErrors(t *testing.T) {
604
605	RegisterBuiltin(&Builtin{
606		Name: "fake_builtin_2",
607		Decl: types.NewFunction(
608			types.Args(
609				types.NewAny(types.NewObject(
610					[]*types.StaticProperty{
611						{Key: "a", Value: types.S},
612						{Key: "b", Value: types.S},
613					}, nil),
614				),
615			),
616			types.NewObject(
617				[]*types.StaticProperty{
618					{Key: "b", Value: types.S},
619					{Key: "c", Value: types.S},
620				}, nil,
621			),
622		),
623	})
624
625	tests := []struct {
626		note  string
627		query string
628	}{
629		{"trivial", "plus(true, 1, x)"},
630		{"refs", "x = [null]; plus(x[0], 1, y)"},
631		{"array comprehensions", `sum([null | true], x)`},
632		{"arrays-any", `sum([1,2,"3",4], x)`},
633		{"arrays-bad-input", `contains([1,2,3], "x")`},
634		{"objects-any", `fake_builtin_2({"a": a, "c": c})`},
635		{"objects-bad-input", `sum({"a": 1, "b": 2}, x)`},
636		{"sets-any", `sum({1,2,"3",4}, x)`},
637		{"virtual-ref", `plus(data.test.p, data.deabeef, 0)`},
638	}
639
640	env := newTestEnv([]string{
641		`p = "foo" { true }`,
642		`f(x) = x { true }`,
643	})
644
645	for _, tc := range tests {
646		test.Subtest(t, tc.note, func(t *testing.T) {
647			body := MustParseBody(tc.query)
648			checker := newTypeChecker()
649			_, err := checker.CheckBody(env, body)
650			if len(err) != 1 {
651				t.Fatalf("Expected exactly one error from %v but got:\n%v", body, err)
652			}
653		})
654	}
655}
656
657func TestVoidBuiltins(t *testing.T) {
658
659	// Void builtins are used in test cases.
660	RegisterBuiltin(&Builtin{
661		Name: "fake_void_builtin",
662		Decl: types.NewFunction(
663			types.Args(types.N),
664			nil,
665		),
666	})
667
668	tests := []struct {
669		query   string
670		wantErr bool
671	}{
672		{"fake_void_builtin(1)", false},
673		{"fake_void_builtin()", true},
674		{"fake_void_builtin(1,2)", true},
675		{"fake_void_builtin(true)", true},
676	}
677
678	for _, tc := range tests {
679		body := MustParseBody(tc.query)
680		checker := newTypeChecker()
681		_, errs := checker.CheckBody(newTestEnv(nil), body)
682		if len(errs) != 0 && !tc.wantErr {
683			t.Fatal(errs)
684		} else if len(errs) == 0 && tc.wantErr {
685			t.Fatal("Expected error")
686		}
687	}
688}
689
690func TestCheckRefErrUnsupported(t *testing.T) {
691
692	query := `arr = [[1,2],[3,4]]; arr[1][0].deadbeef`
693
694	_, errs := newTypeChecker().CheckBody(nil, MustParseBody(query))
695	if len(errs) != 1 {
696		t.Fatalf("Expected exactly one error but got: %v", errs)
697	}
698
699	details, ok := errs[0].Details.(*RefErrUnsupportedDetail)
700	if !ok {
701		t.Fatalf("Expected ref err unsupported but got: %v", errs)
702	}
703
704	wantRef := MustParseRef(`arr[1][0].deadbeef`)
705	wantPos := 2
706	wantHave := types.N
707
708	if !wantRef.Equal(details.Ref) ||
709		wantPos != details.Pos ||
710		types.Compare(wantHave, details.Have) != 0 {
711		t.Fatalf("Expected (%v, %v, %v) but got: (%v, %v, %v)", wantRef, wantPos, wantHave, details.Ref, details.Pos, details.Have)
712	}
713
714}
715
716func TestCheckRefErrInvalid(t *testing.T) {
717
718	env := newTestEnv([]string{
719		`p { true }`,
720		`q = {"foo": 1, "bar": 2} { true }`,
721	})
722
723	tests := []struct {
724		note  string
725		query string
726		ref   string
727		pos   int
728		have  types.Type
729		want  types.Type
730		oneOf []Value
731	}{
732		{
733			note:  "bad non-leaf var",
734			query: `x = 1; data.test[x]`,
735			ref:   `data.test[x]`,
736			pos:   2,
737			have:  types.N,
738			want:  types.S,
739			oneOf: []Value{String("p"), String("q")},
740		},
741		{
742			note:  "bad non-leaf ref",
743			query: `arr = [1]; data.test[arr[0]]`,
744			ref:   `data.test[arr[0]]`,
745			pos:   2,
746			have:  types.N,
747			want:  types.S,
748			oneOf: []Value{String("p"), String("q")},
749		},
750		{
751			note:  "bad leaf ref",
752			query: `arr = [1]; data.test.q[arr[0]]`,
753			ref:   `data.test.q[arr[0]]`,
754			pos:   3,
755			have:  types.N,
756			want:  types.S,
757			oneOf: []Value{String("bar"), String("foo")},
758		},
759		{
760			note:  "bad leaf var",
761			query: `x = 1; data.test.q[x]`,
762			ref:   `data.test.q[x]`,
763			pos:   3,
764			have:  types.N,
765			want:  types.S,
766			oneOf: []Value{String("bar"), String("foo")},
767		},
768		{
769			note:  "bad array index value",
770			query: "arr = [[1,2],[3],[4]]; arr[0].dead.beef = x",
771			ref:   "arr[0].dead.beef",
772			pos:   2,
773			want:  types.N,
774		},
775		{
776			note:  "bad set element value",
777			query: `s = {{1,2},{3,4}}; x = {1,2}; s[x].deadbeef`,
778			ref:   "s[x].deadbeef",
779			pos:   2,
780			want:  types.N,
781		},
782		{
783			note:  "bad object key value",
784			query: `arr = [{"a": 1, "c": 3}, {"b": 2}]; arr[0].b`,
785			ref:   "arr[0].b",
786			pos:   2,
787			want:  types.S,
788			oneOf: []Value{String("a"), String("c")},
789		},
790		{
791			note:  "bad non-leaf value",
792			query: `data.test[1]`,
793			ref:   "data.test[1]",
794			pos:   2,
795			want:  types.S,
796			oneOf: []Value{String("p"), String("q")},
797		},
798		{
799			note:  "composite ref into non-set",
800			query: `data.test.q[[1, 2]]`,
801			ref:   "data.test.q[[1, 2]]",
802			pos:   3,
803			have:  types.NewObject([]*types.StaticProperty{types.NewStaticProperty("bar", types.N), types.NewStaticProperty("foo", types.N)}, nil),
804			want:  types.NewSet(types.A),
805		},
806		{
807			note:  "composite ref type error 1",
808			query: `a = {[1], [2], [3]}; a[["foo"]]`,
809			ref:   `a[["foo"]]`,
810			pos:   1,
811			have:  types.NewArray([]types.Type{types.S}, nil),
812			want:  types.NewArray([]types.Type{types.N}, nil),
813		},
814		{
815			note:  "composite ref type error 2",
816			query: `a = {{"a": 2}}; a[{"a": "foo"}]`,
817			ref:   `a[{"a": "foo"}]`,
818			pos:   1,
819			have:  types.NewObject([]*types.StaticProperty{types.NewStaticProperty("a", types.S)}, nil),
820			want:  types.NewObject([]*types.StaticProperty{types.NewStaticProperty("a", types.N)}, nil),
821		},
822	}
823
824	for _, tc := range tests {
825		test.Subtest(t, tc.note, func(t *testing.T) {
826
827			_, errs := newTypeChecker().CheckBody(env, MustParseBody(tc.query))
828			if len(errs) != 1 {
829				t.Fatalf("Expected exactly one error but got: %v", errs)
830			}
831
832			details, ok := errs[0].Details.(*RefErrInvalidDetail)
833			if !ok {
834				t.Fatalf("Expected ref error invalid but got: %v", errs)
835			}
836
837			wantRef := MustParseRef(tc.ref)
838
839			if details.Pos != tc.pos ||
840				!details.Ref.Equal(wantRef) ||
841				types.Compare(details.Want, tc.want) != 0 ||
842				types.Compare(details.Have, tc.have) != 0 ||
843				!reflect.DeepEqual(details.OneOf, tc.oneOf) {
844				t.Fatalf("Expected (%v, %v, %v, %v, %v) but got: (%v, %v, %v, %v, %v)", wantRef, tc.pos, tc.have, tc.want, tc.oneOf, details.Ref, details.Pos, details.Have, details.Want, details.OneOf)
845			}
846		})
847	}
848}
849
850func TestFunctionsTypeInference(t *testing.T) {
851	functions := []string{
852		`foo([a, b]) = y { split(a, b, y) }`,
853		`bar(x) = y { count(x, y) }`,
854		`baz([x, y]) = z { sprintf("%s%s", [x, y], z) }`,
855		`qux({"bar": x, "foo": y}) = {a: b} { upper(y, a); json.unmarshal(x, b) }`,
856		`corge(x) = y { qux({"bar": x, "foo": x}, a); baz([a["{5: true}"], "BUZ"], y) }`,
857	}
858	body := strings.Join(functions, "\n")
859	base := fmt.Sprintf("package base\n%s", body)
860
861	c := NewCompiler()
862	if c.Compile(map[string]*Module{"base": MustParseModule(base)}); c.Failed() {
863		t.Fatalf("Failed to compile base module: %v", c.Errors)
864	}
865
866	tests := []struct {
867		body    string
868		wantErr bool
869	}{
870		{
871			`fn(_) = y { data.base.foo(["hello", 5], y) }`,
872			true,
873		},
874		{
875			`fn(_) = y { data.base.foo(["hello", "ll"], y) }`,
876			false,
877		},
878		{
879			`fn(_) = y { data.base.baz(["hello", "ll"], y) }`,
880			false,
881		},
882		{
883			`fn(_) = y { data.base.baz([5, ["foo", "bar", true]], y) }`,
884			false,
885		},
886		{
887			`fn(_) = y { data.base.baz(["hello", {"a": "b", "c": 3}], y) }`,
888			false,
889		},
890		{
891			`fn(_) = y { data.base.corge("this is not json", y) }`,
892			false,
893		},
894		{
895			`fn(x) = y { data.non_existent(x, a); y = a[0] }`,
896			true,
897		},
898		{
899			`fn(x) = y { y = [x] }`,
900			false,
901		},
902		{
903			`f(x) = y { [x] = y }`,
904			false,
905		},
906		{
907			`fn(x) = y { y = {"k": x} }`,
908			false,
909		},
910		{
911			`f(x) = y { {"k": x} = y }`,
912			false,
913		},
914		{
915			`p { [data.base.foo] }`,
916			true,
917		},
918		{
919			`p { x = data.base.foo }`,
920			true,
921		},
922		{
923			`p { data.base.foo(data.base.bar) }`,
924			true,
925		},
926	}
927
928	for n, test := range tests {
929		t.Run(fmt.Sprintf("Test Case %d", n), func(t *testing.T) {
930			mod := MustParseModule(fmt.Sprintf("package test\n%s", test.body))
931			c := NewCompiler()
932			c.Compile(map[string]*Module{"base": MustParseModule(base), "mod": mod})
933			if test.wantErr && !c.Failed() {
934				t.Errorf("Expected error but got success")
935			} else if !test.wantErr && c.Failed() {
936				t.Errorf("Expected success but got error: %v", c.Errors)
937			}
938		})
939	}
940}
941
942func TestFunctionTypeInferenceUnappliedWithObjectVarKey(t *testing.T) {
943
944	// Run type inference on a function that constructs an object with a key
945	// from args in the head.
946	module := MustParseModule(`
947		package test
948
949		f(x) = y { y = {x: 1} }
950	`)
951
952	env, err := newTypeChecker().CheckTypes(newTypeChecker().checkLanguageBuiltins(nil, BuiltinMap), []util.T{
953		module.Rules[0],
954	})
955
956	if len(err) > 0 {
957		t.Fatal(err)
958	}
959
960	// Check inferred type for reference to function.
961	tpe := env.Get(MustParseRef("data.test.f"))
962	exp := types.NewFunction([]types.Type{types.A}, types.NewObject(nil, types.NewDynamicProperty(types.A, types.N)))
963
964	if types.Compare(tpe, exp) != 0 {
965		t.Fatalf("Expected %v but got %v", exp, tpe)
966	}
967}
968
969func TestCheckValidErrors(t *testing.T) {
970
971	module := MustParseModule(`
972		package test
973
974		p {
975			concat("", 1)  # type error
976		}
977
978		q {
979			r(1)
980		}
981
982		r(x) = x`)
983
984	module2 := MustParseModule(`
985		package test
986
987		b {
988			a(1)		# call erroneous function
989		}
990
991		a(x) {
992			max("foo")  # max requires an array
993		}
994
995		m {
996			1 / "foo"	# type error
997		}
998
999		n {
1000			m			# call erroneous rule
1001		}`)
1002
1003	module3 := MustParseModule(`
1004		package test
1005
1006		x := {"a" : 1}
1007
1008		y {
1009			z
1010		}
1011
1012		z {
1013			x[1] == 1	# undefined reference error
1014		}`)
1015
1016	tests := map[string]struct {
1017		module *Module
1018		numErr int
1019		query  []string
1020	}{
1021		"single_type_error":         {module: module, numErr: 1, query: []string{`data.test.p`}},
1022		"multiple_type_error":       {module: module2, numErr: 2, query: []string{`data.test.a`, `data.test.m`}},
1023		"undefined_reference_error": {module: module3, numErr: 1, query: []string{`data.test.z`}},
1024	}
1025
1026	for name, tc := range tests {
1027		t.Run(name, func(t *testing.T) {
1028			c := NewCompiler()
1029			c.Compile(map[string]*Module{"test": tc.module})
1030
1031			if !c.Failed() {
1032				t.Errorf("Expected error but got success")
1033			}
1034
1035			if len(c.Errors) != tc.numErr {
1036				t.Fatalf("Expected %v error(s) but got: %v", tc.numErr, c.Errors)
1037			}
1038
1039			// check type of the rule/function that contains an error
1040			for _, q := range tc.query {
1041				tpe := c.TypeEnv.Get(MustParseRef(q))
1042
1043				if types.Compare(tpe, types.NewAny()) != 0 {
1044					t.Fatalf("Expected Any type but got %v", tpe)
1045				}
1046			}
1047		})
1048	}
1049}
1050
1051func TestCheckErrorDetails(t *testing.T) {
1052
1053	tests := []struct {
1054		detail   ErrorDetails
1055		expected []string
1056	}{
1057		{
1058			detail: &RefErrUnsupportedDetail{
1059				Ref:  MustParseRef("data.foo[x]"),
1060				Pos:  1,
1061				Have: types.N,
1062			},
1063			expected: []string{
1064				"data.foo[x]",
1065				"^^^^^^^^",
1066				"have: number",
1067			},
1068		},
1069		{
1070			detail: &RefErrInvalidDetail{
1071				Ref:  MustParseRef("data.foo[x]"),
1072				Pos:  2,
1073				Have: types.N,
1074				Want: types.S,
1075			},
1076			expected: []string{
1077				"data.foo[x]",
1078				"         ^",
1079				"         have (type): number",
1080				"         want (type): string",
1081			},
1082		},
1083		{
1084			detail: &RefErrInvalidDetail{
1085				Ref:  MustParseRef("data.foo[100]"),
1086				Pos:  2,
1087				Want: types.S,
1088				OneOf: []Value{
1089					String("a"),
1090					String("b"),
1091				},
1092			},
1093			expected: []string{
1094				"data.foo[100]",
1095				"         ^",
1096				"         have: 100",
1097				`         want (one of): ["a" "b"]`,
1098			},
1099		},
1100		{
1101			detail: &ArgErrDetail{
1102				Have: []types.Type{
1103					types.N,
1104					types.S,
1105				},
1106				Want: []types.Type{
1107					types.S,
1108					types.S,
1109				},
1110			},
1111			expected: []string{
1112				"have: (number, string)",
1113				"want: (string, string)",
1114			},
1115		},
1116	}
1117
1118	for _, tc := range tests {
1119		if !reflect.DeepEqual(tc.detail.Lines(), tc.expected) {
1120			t.Errorf("Expected %v for %v but got: %v", tc.expected, tc.detail, tc.detail.Lines())
1121		}
1122	}
1123}
1124
1125func TestCheckErrorOrdering(t *testing.T) {
1126
1127	mod := MustParseModule(`
1128		package test
1129
1130		q = true
1131
1132		p { data.test.q = 1 }  # type error: bool = number
1133		p { data.test.q = 2 }  # type error: bool = number
1134	`)
1135
1136	input := make([]util.T, len(mod.Rules))
1137	inputReversed := make([]util.T, len(mod.Rules))
1138
1139	for i := range input {
1140		input[i] = mod.Rules[i]
1141		inputReversed[i] = mod.Rules[i]
1142	}
1143
1144	tmp := inputReversed[1]
1145	inputReversed[1] = inputReversed[2]
1146	inputReversed[2] = tmp
1147
1148	_, errs1 := newTypeChecker().CheckTypes(nil, input)
1149	_, errs2 := newTypeChecker().CheckTypes(nil, inputReversed)
1150
1151	if errs1.Error() != errs2.Error() {
1152		t.Fatalf("Expected error slices to be equal. errs1:\n\n%v\n\nerrs2:\n\n%v\n\n", errs1, errs2)
1153	}
1154}
1155
1156func TestRewrittenVarsInErrors(t *testing.T) {
1157
1158	_, errs := newTypeChecker().WithVarRewriter(rewriteVarsInRef(map[Var]Var{
1159		"__local0__": "foo",
1160		"__local1__": "bar",
1161	})).CheckBody(nil, MustParseBody(`__local0__ = [[1]]; __local1__ = "bar"; __local0__[0][__local1__]`))
1162
1163	if len(errs) != 1 {
1164		t.Fatal("expected exactly one error but got:", len(errs))
1165	}
1166
1167	detail, ok := errs[0].Details.(*RefErrInvalidDetail)
1168	if !ok {
1169		t.Fatal("expected invalid ref detail but got:", errs[0].Details)
1170	}
1171
1172	if !detail.Ref.Equal(MustParseRef("foo[0][bar]")) {
1173		t.Fatal("expected ref to be foo[0][bar] but got:", detail.Ref)
1174	}
1175
1176}
1177
1178func newTestEnv(rs []string) *TypeEnv {
1179	module := MustParseModule(`
1180		package test
1181	`)
1182
1183	var elems []util.T
1184
1185	for i := range rs {
1186		rule := MustParseRule(rs[i])
1187		rule.Module = module
1188		elems = append(elems, rule)
1189		for next := rule.Else; next != nil; next = next.Else {
1190			next.Module = module
1191			elems = append(elems, next)
1192		}
1193	}
1194
1195	env, err := newTypeChecker().CheckTypes(newTypeChecker().checkLanguageBuiltins(nil, BuiltinMap), elems)
1196	if len(err) > 0 {
1197		panic(err)
1198	}
1199
1200	return env
1201}
1202