1// Copyright 2016 The OPA Authors.  All rights reserved.
2// Use of this source code is governed by an Apache2
3// license that can be found in the LICENSE file.
4
5package ast
6
7import (
8	"fmt"
9	"sort"
10	"strconv"
11	"strings"
12
13	"github.com/open-policy-agent/opa/metrics"
14	"github.com/open-policy-agent/opa/util"
15)
16
// CompileErrorLimitDefault is the default number of errors a compiler will
// allow before exiting.
const CompileErrorLimitDefault = 10

// errLimitReached is the sentinel error recorded, and panicked with, by
// Compiler.err once the configured error limit has been reached.
// Compiler.compile recovers from this specific panic value.
var errLimitReached = NewError(CompileErr, nil, "error limit reached")
22
// Compiler contains the state of a compilation process.
type Compiler struct {

	// Errors contains errors that occurred during the compilation process.
	// If there are one or more errors, the compilation process is considered
	// "failed".
	Errors Errors

	// Modules contains the compiled modules. The compiled modules are the
	// output of the compilation process. If the compilation process failed,
	// there is no guarantee about the state of the modules.
	Modules map[string]*Module

	// ModuleTree organizes the modules into a tree where each node is keyed by
	// an element in the module's package path. E.g., given modules containing
	// the following package directives: "a", "a.b", "a.c", and "a.b", the
	// resulting module tree would be:
	//
	//  root
	//    |
	//    +--- data (no modules)
	//           |
	//           +--- a (1 module)
	//                |
	//                +--- b (2 modules)
	//                |
	//                +--- c (1 module)
	//
	ModuleTree *ModuleTreeNode

	// RuleTree organizes rules into a tree where each node is keyed by an
	// element in the rule's path. The rule path is the concatenation of the
	// containing package and the stringified rule name. E.g., given the
	// following module:
	//
	//  package ex
	//  p[1] { true }
	//  p[2] { true }
	//  q = true
	//
	// the resulting rule tree would be:
	//
	//  root
	//    |
	//    +--- data (no rules)
	//           |
	//           +--- ex (no rules)
	//                |
	//                +--- p (2 rules)
	//                |
	//                +--- q (1 rule)
	RuleTree *TreeNode

	// Graph contains dependencies between rules. An edge (u,v) is added to the
	// graph if rule 'u' refers to the virtual document defined by 'v'.
	Graph *Graph

	// TypeEnv holds type information for values inferred by the compiler.
	TypeEnv *TypeEnv

	// RewrittenVars is a mapping of variables that have been rewritten
	// with the key being the generated name and value being the original.
	RewrittenVars map[Var]Var

	// localvargen generates local variable names; it is (re)created by the
	// InitLocalVarGen stage.
	localvargen  *localVarGenerator
	// moduleLoader, if set, is invoked after reference resolution to lazily
	// load additional modules.
	moduleLoader ModuleLoader
	// ruleIndices maps rule paths (Ref) to the indices built by the
	// BuildRuleIndices stage.
	ruleIndices  *util.HashMap
	// stages is the ordered list of built-in compilation stages.
	stages       []struct {
		name       string
		metricName string
		f          func()
	}
	// maxErrs is the error limit; zero or a negative number means no limit.
	maxErrs           int
	sorted            []string // list of sorted module names
	// pathExists, if set, enables base-virtual document conflict detection.
	pathExists        func([]string) (bool, error)
	// after holds the extra stages to run after the built-in stage with the
	// given name.
	after             map[string][]CompilerStageDefinition
	// metrics, if set, is used to time each stage.
	metrics           metrics.Metrics
	// builtins holds the built-in functions known to the compiler, by name.
	builtins          map[string]*Builtin
	// unsafeBuiltinsMap holds the names of built-ins treated as unsafe.
	unsafeBuiltinsMap map[string]struct{}
}
101
// CompilerStage defines the interface for stages in the compiler. A stage
// receives the compiler it runs on and returns a non-nil *Error to report a
// failure.
type CompilerStage func(*Compiler) *Error
104
// CompilerStageDefinition defines a compiler stage: its name, the name of the
// metric timer used while it runs, and the stage function itself.
type CompilerStageDefinition struct {
	Name       string
	MetricName string
	Stage      CompilerStage
}
111
// QueryContext contains contextual information for running an ad-hoc query.
//
// Ad-hoc queries can be run in the context of a package and imports may be
// included to provide concise access to data. Both fields are optional and
// are unset on a context returned by NewQueryContext.
type QueryContext struct {
	Package *Package
	Imports []*Import
}
120
121// NewQueryContext returns a new QueryContext object.
122func NewQueryContext() *QueryContext {
123	return &QueryContext{}
124}
125
126// WithPackage sets the pkg on qc.
127func (qc *QueryContext) WithPackage(pkg *Package) *QueryContext {
128	if qc == nil {
129		qc = NewQueryContext()
130	}
131	qc.Package = pkg
132	return qc
133}
134
135// WithImports sets the imports on qc.
136func (qc *QueryContext) WithImports(imports []*Import) *QueryContext {
137	if qc == nil {
138		qc = NewQueryContext()
139	}
140	qc.Imports = imports
141	return qc
142}
143
144// Copy returns a deep copy of qc.
145func (qc *QueryContext) Copy() *QueryContext {
146	if qc == nil {
147		return nil
148	}
149	cpy := *qc
150	if cpy.Package != nil {
151		cpy.Package = qc.Package.Copy()
152	}
153	cpy.Imports = make([]*Import, len(qc.Imports))
154	for i := range qc.Imports {
155		cpy.Imports[i] = qc.Imports[i].Copy()
156	}
157	return &cpy
158}
159
// QueryCompiler defines the interface for compiling ad-hoc queries. The
// With* methods return a QueryCompiler so that configuration calls can be
// chained.
type QueryCompiler interface {

	// Compile should be called to compile ad-hoc queries. The return value is
	// the compiled version of the query.
	Compile(q Body) (Body, error)

	// TypeEnv returns the type environment built after running type checking
	// on the query.
	TypeEnv() *TypeEnv

	// WithContext sets the QueryContext on the QueryCompiler. Subsequent calls
	// to Compile will take the QueryContext into account.
	WithContext(qctx *QueryContext) QueryCompiler

	// WithUnsafeBuiltins sets the built-in functions to treat as unsafe and not
	// allow inside of queries. By default the query compiler inherits the
	// compiler's unsafe built-in functions. This function allows callers to
	// override that set. If an empty (non-nil) map is provided, all built-ins
	// are allowed.
	WithUnsafeBuiltins(unsafe map[string]struct{}) QueryCompiler

	// WithStageAfter registers a stage to run during query compilation after
	// the named stage.
	WithStageAfter(after string, stage QueryCompilerStageDefinition) QueryCompiler

	// RewrittenVars maps generated vars in the compiled query to vars from the
	// parsed query. For example, given the query "input := 1" the rewritten
	// query would be "__local0__ = 1". The mapping would then be {__local0__: input}.
	RewrittenVars() map[Var]Var
}
191
// QueryCompilerStage defines the interface for stages in the query compiler.
// A stage receives the query compiler and a query body and returns a query
// body or an error.
type QueryCompilerStage func(QueryCompiler, Body) (Body, error)
194
// QueryCompilerStageDefinition defines a QueryCompiler stage: its name, the
// name of its metric, and the stage function itself.
type QueryCompilerStageDefinition struct {
	Name       string
	MetricName string
	Stage      QueryCompilerStage
}
201
// compileStageMetricPrefex holds the string "ast_compile_stage_".
// NOTE(review): presumably a prefix for compile stage metric names; its use
// is not visible in this part of the file — confirm. The identifier spelling
// ("Prefex") is kept because other code may reference it.
const compileStageMetricPrefex = "ast_compile_stage_"
203
204// NewCompiler returns a new empty compiler.
205func NewCompiler() *Compiler {
206
207	c := &Compiler{
208		Modules:       map[string]*Module{},
209		TypeEnv:       NewTypeEnv(),
210		RewrittenVars: map[Var]Var{},
211		ruleIndices: util.NewHashMap(func(a, b util.T) bool {
212			r1, r2 := a.(Ref), b.(Ref)
213			return r1.Equal(r2)
214		}, func(x util.T) int {
215			return x.(Ref).Hash()
216		}),
217		maxErrs:           CompileErrorLimitDefault,
218		after:             map[string][]CompilerStageDefinition{},
219		unsafeBuiltinsMap: map[string]struct{}{},
220	}
221
222	c.ModuleTree = NewModuleTree(nil)
223	c.RuleTree = NewRuleTree(c.ModuleTree)
224
225	// Initialize the compiler with the statically compiled built-in functions.
226	// If the caller customizes the compiler, a copy will be made.
227	c.builtins = BuiltinMap
228	checker := newTypeChecker()
229	c.TypeEnv = checker.checkLanguageBuiltins(nil, c.builtins)
230
231	c.stages = []struct {
232		name       string
233		metricName string
234		f          func()
235	}{
236		// Reference resolution should run first as it may be used to lazily
237		// load additional modules. If any stages run before resolution, they
238		// need to be re-run after resolution.
239		{"ResolveRefs", "compile_stage_resolve_refs", c.resolveAllRefs},
240
241		// The local variable generator must be initialized after references are
242		// resolved and the dynamic module loader has run but before subsequent
243		// stages that need to generate variables.
244		{"InitLocalVarGen", "compile_stage_init_local_var_gen", c.initLocalVarGen},
245
246		{"RewriteLocalVars", "compile_stage_rewrite_local_vars", c.rewriteLocalVars},
247		{"RewriteExprTerms", "compile_stage_rewrite_expr_terms", c.rewriteExprTerms},
248		{"SetModuleTree", "compile_stage_set_module_tree", c.setModuleTree},
249		{"SetRuleTree", "compile_stage_set_rule_tree", c.setRuleTree},
250		{"SetGraph", "compile_stage_set_graph", c.setGraph},
251		{"RewriteComprehensionTerms", "compile_stage_rewrite_comprehension_terms", c.rewriteComprehensionTerms},
252		{"RewriteRefsInHead", "compile_stage_rewrite_refs_in_head", c.rewriteRefsInHead},
253		{"RewriteWithValues", "compile_stage_rewrite_with_values", c.rewriteWithModifiers},
254		{"CheckRuleConflicts", "compile_stage_check_rule_conflicts", c.checkRuleConflicts},
255		{"CheckUndefinedFuncs", "compile_stage_check_undefined_funcs", c.checkUndefinedFuncs},
256		{"CheckSafetyRuleHeads", "compile_stage_check_safety_rule_heads", c.checkSafetyRuleHeads},
257		{"CheckSafetyRuleBodies", "compile_stage_check_safety_rule_bodies", c.checkSafetyRuleBodies},
258		{"RewriteEquals", "compile_stage_rewrite_equals", c.rewriteEquals},
259		{"RewriteDynamicTerms", "compile_stage_rewrite_dynamic_terms", c.rewriteDynamicTerms},
260		{"CheckRecursion", "compile_stage_check_recursion", c.checkRecursion},
261		{"CheckTypes", "compile_stage_check_types", c.checkTypes},
262		{"CheckUnsafeBuiltins", "compile_state_check_unsafe_builtins", c.checkUnsafeBuiltins},
263		{"BuildRuleIndices", "compile_stage_rebuild_indices", c.buildRuleIndices},
264	}
265
266	return c
267}
268
269// SetErrorLimit sets the number of errors the compiler can encounter before it
270// quits. Zero or a negative number indicates no limit.
271func (c *Compiler) SetErrorLimit(limit int) *Compiler {
272	c.maxErrs = limit
273	return c
274}
275
276// WithPathConflictsCheck enables base-virtual document conflict
277// detection. The compiler will check that rules don't overlap with
278// paths that exist as determined by the provided callable.
279func (c *Compiler) WithPathConflictsCheck(fn func([]string) (bool, error)) *Compiler {
280	c.pathExists = fn
281	return c
282}
283
284// WithStageAfter registers a stage to run during compilation after
285// the named stage.
286func (c *Compiler) WithStageAfter(after string, stage CompilerStageDefinition) *Compiler {
287	c.after[after] = append(c.after[after], stage)
288	return c
289}
290
291// WithMetrics will set a metrics.Metrics and be used for profiling
292// the Compiler instance.
293func (c *Compiler) WithMetrics(metrics metrics.Metrics) *Compiler {
294	c.metrics = metrics
295	return c
296}
297
298// WithBuiltins adds a set of custom built-in functions to the compiler.
299func (c *Compiler) WithBuiltins(builtins map[string]*Builtin) *Compiler {
300	if len(builtins) == 0 {
301		return c
302	}
303	cpy := make(map[string]*Builtin, len(c.builtins)+len(builtins))
304	for k, v := range c.builtins {
305		cpy[k] = v
306	}
307	for k, v := range builtins {
308		cpy[k] = v
309	}
310	c.builtins = cpy
311	// Build type env for custom functions and wrap existing one.
312	checker := newTypeChecker()
313	c.TypeEnv = checker.checkLanguageBuiltins(c.TypeEnv, builtins)
314	return c
315}
316
317// WithUnsafeBuiltins will add all built-ins in the map to the "blacklist".
318func (c *Compiler) WithUnsafeBuiltins(unsafeBuiltins map[string]struct{}) *Compiler {
319	for name := range unsafeBuiltins {
320		c.unsafeBuiltinsMap[name] = struct{}{}
321	}
322	return c
323}
324
325// QueryCompiler returns a new QueryCompiler object.
326func (c *Compiler) QueryCompiler() QueryCompiler {
327	return newQueryCompiler(c)
328}
329
330// Compile runs the compilation process on the input modules. The compiled
331// version of the modules and associated data structures are stored on the
332// compiler. If the compilation process fails for any reason, the compiler will
333// contain a slice of errors.
334func (c *Compiler) Compile(modules map[string]*Module) {
335
336	c.Modules = make(map[string]*Module, len(modules))
337
338	for k, v := range modules {
339		c.Modules[k] = v.Copy()
340		c.sorted = append(c.sorted, k)
341	}
342
343	sort.Strings(c.sorted)
344
345	c.compile()
346}
347
348// Failed returns true if a compilation error has been encountered.
349func (c *Compiler) Failed() bool {
350	return len(c.Errors) > 0
351}
352
353// GetArity returns the number of args a function referred to by ref takes. If
354// ref refers to built-in function, the built-in declaration is consulted,
355// otherwise, the ref is used to perform a ruleset lookup.
356func (c *Compiler) GetArity(ref Ref) int {
357	if bi := c.builtins[ref.String()]; bi != nil {
358		return len(bi.Decl.Args())
359	}
360	rules := c.GetRulesExact(ref)
361	if len(rules) == 0 {
362		return -1
363	}
364	return len(rules[0].Head.Args)
365}
366
367// GetRulesExact returns a slice of rules referred to by the reference.
368//
369// E.g., given the following module:
370//
371//	package a.b.c
372//
373//	p[k] = v { ... }    # rule1
374//  p[k1] = v1 { ... }  # rule2
375//
376// The following calls yield the rules on the right.
377//
378//  GetRulesExact("data.a.b.c.p")   => [rule1, rule2]
379//  GetRulesExact("data.a.b.c.p.x") => nil
380//  GetRulesExact("data.a.b.c")     => nil
381func (c *Compiler) GetRulesExact(ref Ref) (rules []*Rule) {
382	node := c.RuleTree
383
384	for _, x := range ref {
385		if node = node.Child(x.Value); node == nil {
386			return nil
387		}
388	}
389
390	return extractRules(node.Values)
391}
392
393// GetRulesForVirtualDocument returns a slice of rules that produce the virtual
394// document referred to by the reference.
395//
396// E.g., given the following module:
397//
398//	package a.b.c
399//
400//	p[k] = v { ... }    # rule1
401//  p[k1] = v1 { ... }  # rule2
402//
403// The following calls yield the rules on the right.
404//
405//  GetRulesForVirtualDocument("data.a.b.c.p")   => [rule1, rule2]
406//  GetRulesForVirtualDocument("data.a.b.c.p.x") => [rule1, rule2]
407//  GetRulesForVirtualDocument("data.a.b.c")     => nil
408func (c *Compiler) GetRulesForVirtualDocument(ref Ref) (rules []*Rule) {
409
410	node := c.RuleTree
411
412	for _, x := range ref {
413		if node = node.Child(x.Value); node == nil {
414			return nil
415		}
416		if len(node.Values) > 0 {
417			return extractRules(node.Values)
418		}
419	}
420
421	return extractRules(node.Values)
422}
423
424// GetRulesWithPrefix returns a slice of rules that share the prefix ref.
425//
426// E.g., given the following module:
427//
428//  package a.b.c
429//
430//  p[x] = y { ... }  # rule1
431//  p[k] = v { ... }  # rule2
432//  q { ... }         # rule3
433//
434// The following calls yield the rules on the right.
435//
436//  GetRulesWithPrefix("data.a.b.c.p")   => [rule1, rule2]
437//  GetRulesWithPrefix("data.a.b.c.p.a") => nil
438//  GetRulesWithPrefix("data.a.b.c")     => [rule1, rule2, rule3]
439func (c *Compiler) GetRulesWithPrefix(ref Ref) (rules []*Rule) {
440
441	node := c.RuleTree
442
443	for _, x := range ref {
444		if node = node.Child(x.Value); node == nil {
445			return nil
446		}
447	}
448
449	var acc func(node *TreeNode)
450
451	acc = func(node *TreeNode) {
452		rules = append(rules, extractRules(node.Values)...)
453		for _, child := range node.Children {
454			if child.Hide {
455				continue
456			}
457			acc(child)
458		}
459	}
460
461	acc(node)
462
463	return rules
464}
465
466func extractRules(s []util.T) (rules []*Rule) {
467	for _, r := range s {
468		rules = append(rules, r.(*Rule))
469	}
470	return rules
471}
472
473// GetRules returns a slice of rules that are referred to by ref.
474//
475// E.g., given the following module:
476//
477//  package a.b.c
478//
479//  p[x] = y { q[x] = y; ... } # rule1
480//  q[x] = y { ... }           # rule2
481//
482// The following calls yield the rules on the right.
483//
484//  GetRules("data.a.b.c.p")	=> [rule1]
485//  GetRules("data.a.b.c.p.x")	=> [rule1]
486//  GetRules("data.a.b.c.q")	=> [rule2]
487//  GetRules("data.a.b.c")		=> [rule1, rule2]
488//  GetRules("data.a.b.d")		=> nil
489func (c *Compiler) GetRules(ref Ref) (rules []*Rule) {
490
491	set := map[*Rule]struct{}{}
492
493	for _, rule := range c.GetRulesForVirtualDocument(ref) {
494		set[rule] = struct{}{}
495	}
496
497	for _, rule := range c.GetRulesWithPrefix(ref) {
498		set[rule] = struct{}{}
499	}
500
501	for rule := range set {
502		rules = append(rules, rule)
503	}
504
505	return rules
506}
507
508// GetRulesDynamic returns a slice of rules that could be referred to by a ref.
509// When parts of the ref are statically known, we use that information to narrow
510// down which rules the ref could refer to, but in the most general case this
511// will be an over-approximation.
512//
513// E.g., given the following modules:
514//
515//  package a.b.c
516//
517//  r1 = 1  # rule1
518//
519// and:
520//
521//  package a.d.c
522//
523//  r2 = 2  # rule2
524//
525// The following calls yield the rules on the right.
526//
527//  GetRulesDynamic("data.a[x].c[y]") => [rule1, rule2]
528//  GetRulesDynamic("data.a[x].c.r2") => [rule2]
529//  GetRulesDynamic("data.a.b[x][y]") => [rule1]
530func (c *Compiler) GetRulesDynamic(ref Ref) (rules []*Rule) {
531	node := c.RuleTree
532
533	set := map[*Rule]struct{}{}
534	var walk func(node *TreeNode, i int)
535	walk = func(node *TreeNode, i int) {
536		if i >= len(ref) {
537			// We've reached the end of the reference and want to collect everything
538			// under this "prefix".
539			node.DepthFirst(func(descendant *TreeNode) bool {
540				insertRules(set, descendant.Values)
541				return descendant.Hide
542			})
543		} else if i == 0 || IsConstant(ref[i].Value) {
544			// The head of the ref is always grounded.  In case another part of the
545			// ref is also grounded, we can lookup the exact child.  If it's not found
546			// we can immediately return...
547			if child := node.Child(ref[i].Value); child == nil {
548				return
549			} else if len(child.Values) > 0 {
550				// If there are any rules at this position, it's what the ref would
551				// refer to.  We can just append those and stop here.
552				insertRules(set, child.Values)
553			} else {
554				// Otherwise, we continue using the child node.
555				walk(child, i+1)
556			}
557		} else {
558			// This part of the ref is a dynamic term.  We can't know what it refers
559			// to and will just need to try all of the children.
560			for _, child := range node.Children {
561				if child.Hide {
562					continue
563				}
564				insertRules(set, child.Values)
565				walk(child, i+1)
566			}
567		}
568	}
569
570	walk(node, 0)
571	for rule := range set {
572		rules = append(rules, rule)
573	}
574	return rules
575}
576
577// Utility: add all rule values to the set.
578func insertRules(set map[*Rule]struct{}, rules []util.T) {
579	for _, rule := range rules {
580		set[rule.(*Rule)] = struct{}{}
581	}
582}
583
584// RuleIndex returns a RuleIndex built for the rule set referred to by path.
585// The path must refer to the rule set exactly, i.e., given a rule set at path
586// data.a.b.c.p, refs data.a.b.c.p.x and data.a.b.c would not return a
587// RuleIndex built for the rule.
588func (c *Compiler) RuleIndex(path Ref) RuleIndex {
589	r, ok := c.ruleIndices.Get(path)
590	if !ok {
591		return nil
592	}
593	return r.(RuleIndex)
594}
595
// ModuleLoader defines the interface that callers can implement to enable lazy
// loading of modules during compilation. The loader is called with the
// compiler's current set of modules and returns the additional parsed modules
// to include, keyed by module name, or an error.
type ModuleLoader func(resolved map[string]*Module) (parsed map[string]*Module, err error)
599
600// WithModuleLoader sets f as the ModuleLoader on the compiler.
601//
602// The compiler will invoke the ModuleLoader after resolving all references in
603// the current set of input modules. The ModuleLoader can return a new
604// collection of parsed modules that are to be included in the compilation
605// process. This process will repeat until the ModuleLoader returns an empty
606// collection or an error. If an error is returned, compilation will stop
607// immediately.
608func (c *Compiler) WithModuleLoader(f ModuleLoader) *Compiler {
609	c.moduleLoader = f
610	return c
611}
612
613// buildRuleIndices constructs indices for rules.
614func (c *Compiler) buildRuleIndices() {
615
616	c.RuleTree.DepthFirst(func(node *TreeNode) bool {
617		if len(node.Values) == 0 {
618			return false
619		}
620		index := newBaseDocEqIndex(func(ref Ref) bool {
621			return isVirtual(c.RuleTree, ref.GroundPrefix())
622		})
623		if rules := extractRules(node.Values); index.Build(rules) {
624			c.ruleIndices.Put(rules[0].Path(), index)
625		}
626		return false
627	})
628
629}
630
631// checkRecursion ensures that there are no recursive definitions, i.e., there are
632// no cycles in the Graph.
633func (c *Compiler) checkRecursion() {
634	eq := func(a, b util.T) bool {
635		return a.(*Rule) == b.(*Rule)
636	}
637
638	c.RuleTree.DepthFirst(func(node *TreeNode) bool {
639		for _, rule := range node.Values {
640			for node := rule.(*Rule); node != nil; node = node.Else {
641				c.checkSelfPath(node.Loc(), eq, node, node)
642			}
643		}
644		return false
645	})
646}
647
648func (c *Compiler) checkSelfPath(loc *Location, eq func(a, b util.T) bool, a, b util.T) {
649	tr := NewGraphTraversal(c.Graph)
650	if p := util.DFSPath(tr, eq, a, b); len(p) > 0 {
651		n := []string{}
652		for _, x := range p {
653			n = append(n, astNodeToString(x))
654		}
655		c.err(NewError(RecursionErr, loc, "rule %v is recursive: %v", astNodeToString(a), strings.Join(n, " -> ")))
656	}
657}
658
659func astNodeToString(x interface{}) string {
660	switch x := x.(type) {
661	case *Rule:
662		return string(x.Head.Name)
663	default:
664		panic("not reached")
665	}
666}
667
668// checkRuleConflicts ensures that rules definitions are not in conflict.
669func (c *Compiler) checkRuleConflicts() {
670	c.RuleTree.DepthFirst(func(node *TreeNode) bool {
671		if len(node.Values) == 0 {
672			return false
673		}
674
675		kinds := map[DocKind]struct{}{}
676		defaultRules := 0
677		arities := map[int]struct{}{}
678		declared := false
679
680		for _, rule := range node.Values {
681			r := rule.(*Rule)
682			kinds[r.Head.DocKind()] = struct{}{}
683			arities[len(r.Head.Args)] = struct{}{}
684			if r.Head.Assign {
685				declared = true
686			}
687			if r.Default {
688				defaultRules++
689			}
690		}
691
692		name := Var(node.Key.(String))
693
694		if declared && len(node.Values) > 1 {
695			c.err(NewError(TypeErr, node.Values[0].(*Rule).Loc(), "rule named %v redeclared at %v", name, node.Values[1].(*Rule).Loc()))
696		} else if len(kinds) > 1 || len(arities) > 1 {
697			c.err(NewError(TypeErr, node.Values[0].(*Rule).Loc(), "conflicting rules named %v found", name))
698		} else if defaultRules > 1 {
699			c.err(NewError(TypeErr, node.Values[0].(*Rule).Loc(), "multiple default rules named %s found", name))
700		}
701
702		return false
703	})
704
705	if c.pathExists != nil {
706		for _, err := range CheckPathConflicts(c, c.pathExists) {
707			c.err(err)
708		}
709	}
710
711	c.ModuleTree.DepthFirst(func(node *ModuleTreeNode) bool {
712		for _, mod := range node.Modules {
713			for _, rule := range mod.Rules {
714				if childNode, ok := node.Children[String(rule.Head.Name)]; ok {
715					for _, childMod := range childNode.Modules {
716						msg := fmt.Sprintf("%v conflicts with rule defined at %v", childMod.Package, rule.Loc())
717						c.err(NewError(TypeErr, mod.Package.Loc(), msg))
718					}
719				}
720			}
721		}
722		return false
723	})
724}
725
726func (c *Compiler) checkUndefinedFuncs() {
727	for _, name := range c.sorted {
728		m := c.Modules[name]
729		for _, err := range checkUndefinedFuncs(m, c.GetArity) {
730			c.err(err)
731		}
732	}
733}
734
735func checkUndefinedFuncs(x interface{}, arity func(Ref) int) Errors {
736
737	var errs Errors
738
739	WalkExprs(x, func(expr *Expr) bool {
740		if !expr.IsCall() {
741			return false
742		}
743		ref := expr.Operator()
744		if arity(ref) >= 0 {
745			return false
746		}
747		errs = append(errs, NewError(TypeErr, expr.Loc(), "undefined function %v", ref))
748		return true
749	})
750
751	return errs
752}
753
754// checkSafetyRuleBodies ensures that variables appearing in negated expressions or non-target
755// positions of built-in expressions will be bound when evaluating the rule from left
756// to right, re-ordering as necessary.
757func (c *Compiler) checkSafetyRuleBodies() {
758	for _, name := range c.sorted {
759		m := c.Modules[name]
760		WalkRules(m, func(r *Rule) bool {
761			safe := ReservedVars.Copy()
762			safe.Update(r.Head.Args.Vars())
763			r.Body = c.checkBodySafety(safe, m, r.Body)
764			return false
765		})
766	}
767}
768
769func (c *Compiler) checkBodySafety(safe VarSet, m *Module, b Body) Body {
770	reordered, unsafe := reorderBodyForSafety(c.builtins, c.GetArity, safe, b)
771	if errs := safetyErrorSlice(unsafe); len(errs) > 0 {
772		for _, err := range errs {
773			c.err(err)
774		}
775		return b
776	}
777	return reordered
778}
779
// safetyCheckVarVisitorParams are the variable visitor parameters used when
// collecting body variables for the rule head safety check
// (checkSafetyRuleHeads).
var safetyCheckVarVisitorParams = VarVisitorParams{
	SkipRefCallHead: true,
	SkipClosures:    true,
}
784
785// checkSafetyRuleHeads ensures that variables appearing in the head of a
786// rule also appear in the body.
787func (c *Compiler) checkSafetyRuleHeads() {
788
789	for _, name := range c.sorted {
790		m := c.Modules[name]
791		WalkRules(m, func(r *Rule) bool {
792			safe := r.Body.Vars(safetyCheckVarVisitorParams)
793			safe.Update(r.Head.Args.Vars())
794			unsafe := r.Head.Vars().Diff(safe)
795			for v := range unsafe {
796				if !v.IsGenerated() {
797					c.err(NewError(UnsafeVarErr, r.Loc(), "var %v is unsafe", v))
798				}
799			}
800			return false
801		})
802	}
803}
804
805// checkTypes runs the type checker on all rules. The type checker builds a
806// TypeEnv that is stored on the compiler.
807func (c *Compiler) checkTypes() {
808	// Recursion is caught in earlier step, so this cannot fail.
809	sorted, _ := c.Graph.Sort()
810	checker := newTypeChecker().WithVarRewriter(rewriteVarsInRef(c.RewrittenVars))
811	env, errs := checker.CheckTypes(c.TypeEnv, sorted)
812	for _, err := range errs {
813		c.err(err)
814	}
815	c.TypeEnv = env
816}
817
818func (c *Compiler) checkUnsafeBuiltins() {
819	for _, name := range c.sorted {
820		errs := checkUnsafeBuiltins(c.unsafeBuiltinsMap, c.Modules[name])
821		for _, err := range errs {
822			c.err(err)
823		}
824	}
825}
826
827func (c *Compiler) runStage(metricName string, f func()) {
828	if c.metrics != nil {
829		c.metrics.Timer(metricName).Start()
830		defer c.metrics.Timer(metricName).Stop()
831	}
832	f()
833}
834
835func (c *Compiler) runStageAfter(metricName string, s CompilerStage) *Error {
836	if c.metrics != nil {
837		c.metrics.Timer(metricName).Start()
838		defer c.metrics.Timer(metricName).Stop()
839	}
840	return s(c)
841}
842
843func (c *Compiler) compile() {
844	defer func() {
845		if r := recover(); r != nil && r != errLimitReached {
846			panic(r)
847		}
848	}()
849
850	for _, s := range c.stages {
851		c.runStage(s.metricName, s.f)
852		if c.Failed() {
853			return
854		}
855		for _, s := range c.after[s.name] {
856			err := c.runStageAfter(s.MetricName, s.Stage)
857			if err != nil {
858				c.err(err)
859			}
860		}
861	}
862}
863
864func (c *Compiler) err(err *Error) {
865	if c.maxErrs > 0 && len(c.Errors) >= c.maxErrs {
866		c.Errors = append(c.Errors, errLimitReached)
867		panic(errLimitReached)
868	}
869	c.Errors = append(c.Errors, err)
870}
871
872func (c *Compiler) getExports() *util.HashMap {
873
874	rules := util.NewHashMap(func(a, b util.T) bool {
875		r1 := a.(Ref)
876		r2 := a.(Ref)
877		return r1.Equal(r2)
878	}, func(v util.T) int {
879		return v.(Ref).Hash()
880	})
881
882	for _, name := range c.sorted {
883		mod := c.Modules[name]
884		rv, ok := rules.Get(mod.Package.Path)
885		if !ok {
886			rv = []Var{}
887		}
888		rvs := rv.([]Var)
889
890		for _, rule := range mod.Rules {
891			rvs = append(rvs, rule.Head.Name)
892		}
893		rules.Put(mod.Package.Path, rvs)
894	}
895
896	return rules
897}
898
899// resolveAllRefs resolves references in expressions to their fully qualified values.
900//
901// For instance, given the following module:
902//
903// package a.b
904// import data.foo.bar
905// p[x] { bar[_] = x }
906//
907// The reference "bar[_]" would be resolved to "data.foo.bar[_]".
908func (c *Compiler) resolveAllRefs() {
909
910	rules := c.getExports()
911
912	for _, name := range c.sorted {
913		mod := c.Modules[name]
914
915		var ruleExports []Var
916		if x, ok := rules.Get(mod.Package.Path); ok {
917			ruleExports = x.([]Var)
918		}
919
920		globals := getGlobals(mod.Package, ruleExports, mod.Imports)
921
922		WalkRules(mod, func(rule *Rule) bool {
923			err := resolveRefsInRule(globals, rule)
924			if err != nil {
925				c.err(NewError(CompileErr, rule.Location, err.Error()))
926			}
927			return false
928		})
929
930		// Once imports have been resolved, they are no longer needed.
931		mod.Imports = nil
932	}
933
934	if c.moduleLoader != nil {
935
936		parsed, err := c.moduleLoader(c.Modules)
937		if err != nil {
938			c.err(NewError(CompileErr, nil, err.Error()))
939			return
940		}
941
942		if len(parsed) == 0 {
943			return
944		}
945
946		for id, module := range parsed {
947			c.Modules[id] = module.Copy()
948			c.sorted = append(c.sorted, id)
949		}
950
951		sort.Strings(c.sorted)
952		c.resolveAllRefs()
953	}
954}
955
956func (c *Compiler) initLocalVarGen() {
957	c.localvargen = newLocalVarGeneratorForModuleSet(c.sorted, c.Modules)
958}
959
960func (c *Compiler) rewriteComprehensionTerms() {
961	f := newEqualityFactory(c.localvargen)
962	for _, name := range c.sorted {
963		mod := c.Modules[name]
964		rewriteComprehensionTerms(f, mod)
965	}
966}
967
968func (c *Compiler) rewriteExprTerms() {
969	for _, name := range c.sorted {
970		mod := c.Modules[name]
971		WalkRules(mod, func(rule *Rule) bool {
972			rewriteExprTermsInHead(c.localvargen, rule)
973			rule.Body = rewriteExprTermsInBody(c.localvargen, rule.Body)
974			return false
975		})
976	}
977}
978
979// rewriteTermsInHead will rewrite rules so that the head does not contain any
980// terms that require evaluation (e.g., refs or comprehensions). If the key or
981// value contains or more of these terms, the key or value will be moved into
982// the body and assigned to a new variable. The new variable will replace the
983// key or value in the head.
984//
985// For instance, given the following rule:
986//
987// p[{"foo": data.foo[i]}] { i < 100 }
988//
989// The rule would be re-written as:
990//
991// p[__local0__] { i < 100; __local0__ = {"foo": data.foo[i]} }
992func (c *Compiler) rewriteRefsInHead() {
993	f := newEqualityFactory(c.localvargen)
994	for _, name := range c.sorted {
995		mod := c.Modules[name]
996		WalkRules(mod, func(rule *Rule) bool {
997			if requiresEval(rule.Head.Key) {
998				expr := f.Generate(rule.Head.Key)
999				rule.Head.Key = expr.Operand(0)
1000				rule.Body.Append(expr)
1001			}
1002			if requiresEval(rule.Head.Value) {
1003				expr := f.Generate(rule.Head.Value)
1004				rule.Head.Value = expr.Operand(0)
1005				rule.Body.Append(expr)
1006			}
1007			for i := 0; i < len(rule.Head.Args); i++ {
1008				if requiresEval(rule.Head.Args[i]) {
1009					expr := f.Generate(rule.Head.Args[i])
1010					rule.Head.Args[i] = expr.Operand(0)
1011					rule.Body.Append(expr)
1012				}
1013			}
1014			return false
1015		})
1016	}
1017}
1018
1019func (c *Compiler) rewriteEquals() {
1020	for _, name := range c.sorted {
1021		mod := c.Modules[name]
1022		rewriteEquals(mod)
1023	}
1024}
1025
1026func (c *Compiler) rewriteDynamicTerms() {
1027	f := newEqualityFactory(c.localvargen)
1028	for _, name := range c.sorted {
1029		mod := c.Modules[name]
1030		WalkRules(mod, func(rule *Rule) bool {
1031			rule.Body = rewriteDynamics(f, rule.Body)
1032			return false
1033		})
1034	}
1035}
1036
// rewriteLocalVars rewrites locally declared variables in every rule of every
// module (visited in c.sorted order). For each rule it processes, in order:
// comprehensions embedded in the head, the rule arguments, the body, and
// finally the variables in the head that refer to variables declared in the
// body. Errors are reported through c.err, and the entries of each stack's
// rewritten map are copied into c.RewrittenVars.
func (c *Compiler) rewriteLocalVars() {

	for _, name := range c.sorted {
		mod := c.Modules[name]
		gen := c.localvargen

		WalkRules(mod, func(rule *Rule) bool {

			var errs Errors

			// Rewrite assignments contained in head of rule. Assignments can
			// occur in rule head if they're inside a comprehension. Note,
			// assigned vars in comprehensions in the head will be rewritten
			// first to preserve scoping rules. For example:
			//
			// p = [x | x := 1] { x := 2 } becomes p = [__local0__ | __local0__ = 1] { __local1__ = 2 }
			//
			// This behaviour is consistent with scoping inside the body. For example:
			//
			// p = xs { x := 2; xs = [x | x := 1] } becomes p = xs { __local0__ = 2; xs = [__local1__ | __local1__ = 1] }
			WalkTerms(rule.Head, func(term *Term) bool {
				stop := false
				// Each visited term gets its own stack, so declarations made
				// in one head comprehension are not visible to another.
				stack := newLocalDeclaredVars()
				switch v := term.Value.(type) {
				case *ArrayComprehension:
					errs = rewriteDeclaredVarsInArrayComprehension(gen, stack, v, errs)
					stop = true
				case *SetComprehension:
					errs = rewriteDeclaredVarsInSetComprehension(gen, stack, v, errs)
					stop = true
				case *ObjectComprehension:
					errs = rewriteDeclaredVarsInObjectComprehension(gen, stack, v, errs)
					stop = true
				}

				for k, v := range stack.rewritten {
					c.RewrittenVars[k] = v
				}

				return stop
			})

			// Report errors accumulated while rewriting head comprehensions.
			for _, err := range errs {
				c.err(err)
			}

			// Rewrite assignments in body.
			used := NewVarSet()

			if rule.Head.Key != nil {
				used.Update(rule.Head.Key.Vars())
			}

			if rule.Head.Value != nil {
				used.Update(rule.Head.Value.Vars())
			}

			stack := newLocalDeclaredVars()

			// Arguments are rewritten first, using the same stack that is
			// then passed to the body rewrite.
			c.rewriteLocalArgVars(gen, stack, rule)

			// errs is the variable declared at the top of this closure; only
			// body and declared are new here, so errs is overwritten with the
			// errors from the body rewrite.
			body, declared, errs := rewriteLocalVars(gen, stack, used, rule.Body)
			for _, err := range errs {
				c.err(err)
			}

			// For rewritten vars use the collection of all variables that
			// were in the stack at some point in time.
			for k, v := range stack.rewritten {
				c.RewrittenVars[k] = v
			}

			rule.Body = body

			// Rewrite vars in head that refer to locally declared vars in the body.
			vis := NewGenericVisitor(func(x interface{}) bool {

				term, ok := x.(*Term)
				if !ok {
					return false
				}

				switch v := term.Value.(type) {
				case Object:
					// Make a copy of the object because the keys may be mutated.
					// The error from Map is discarded; the callback below
					// always returns a nil error.
					cpy, _ := v.Map(func(k, v *Term) (*Term, *Term, error) {
						if vark, ok := k.Value.(Var); ok {
							if gv, ok := declared[vark]; ok {
								k = k.Copy()
								k.Value = gv
							}
						}
						return k, v, nil
					})
					term.Value = cpy
				case Var:
					if gv, ok := declared[v]; ok {
						term.Value = gv
						return true
					}
				}

				return false
			})

			vis.Walk(rule.Head.Args)

			if rule.Head.Key != nil {
				vis.Walk(rule.Head.Key)
			}

			if rule.Head.Value != nil {
				vis.Walk(rule.Head.Value)
			}

			return false
		})
	}
}
1156
1157func (c *Compiler) rewriteLocalArgVars(gen *localVarGenerator, stack *localDeclaredVars, rule *Rule) {
1158
1159	vis := &ruleArgLocalRewriter{
1160		stack: stack,
1161		gen:   gen,
1162	}
1163
1164	for i := range rule.Head.Args {
1165		Walk(vis, rule.Head.Args[i])
1166	}
1167
1168	for i := range vis.errs {
1169		c.err(vis.errs[i])
1170	}
1171}
1172
// ruleArgLocalRewriter is a Visitor that rewrites variables appearing in rule
// head arguments to generated variables.
type ruleArgLocalRewriter struct {
	// stack is consulted for an existing declaration of each variable and
	// receives a new entry when none exists.
	stack *localDeclaredVars
	// gen produces the replacement variables.
	gen *localVarGenerator
	// errs accumulates errors encountered during the walk.
	errs []*Error
}
1178
1179func (vis *ruleArgLocalRewriter) Visit(x interface{}) Visitor {
1180
1181	t, ok := x.(*Term)
1182	if !ok {
1183		return vis
1184	}
1185
1186	switch v := t.Value.(type) {
1187	case Var:
1188		gv, ok := vis.stack.Declared(v)
1189		if !ok {
1190			gv = vis.gen.Generate()
1191			vis.stack.Insert(v, gv, argVar)
1192		}
1193		t.Value = gv
1194		return nil
1195	case Object:
1196		if cpy, err := v.Map(func(k, v *Term) (*Term, *Term, error) {
1197			vcpy := v.Copy()
1198			Walk(vis, vcpy)
1199			return k, vcpy, nil
1200		}); err != nil {
1201			vis.errs = append(vis.errs, NewError(CompileErr, t.Location, err.Error()))
1202		} else {
1203			t.Value = cpy
1204		}
1205		return nil
1206	case Null, Boolean, Number, String, *ArrayComprehension, *SetComprehension, *ObjectComprehension, Set:
1207		// Scalars are no-ops. Comprehensions are handled above. Sets must not
1208		// contain variables.
1209		return nil
1210	case Call:
1211		vis.errs = append(vis.errs, NewError(CompileErr, t.Location, "rule arguments cannot contain calls"))
1212		return nil
1213	default:
1214		// Recurse on refs and arrays. Any embedded
1215		// variables can be rewritten.
1216		return vis
1217	}
1218}
1219
1220func (c *Compiler) rewriteWithModifiers() {
1221	f := newEqualityFactory(c.localvargen)
1222	for _, name := range c.sorted {
1223		mod := c.Modules[name]
1224		t := NewGenericTransformer(func(x interface{}) (interface{}, error) {
1225			body, ok := x.(Body)
1226			if !ok {
1227				return x, nil
1228			}
1229			body, err := rewriteWithModifiersInBody(c, f, body)
1230			if err != nil {
1231				c.err(err)
1232			}
1233
1234			return body, nil
1235		})
1236		Transform(t, mod)
1237	}
1238}
1239
// setModuleTree rebuilds c.ModuleTree from the compiler's current modules.
func (c *Compiler) setModuleTree() {
	c.ModuleTree = NewModuleTree(c.Modules)
}
1243
// setRuleTree rebuilds c.RuleTree from c.ModuleTree, so the module tree must
// already be set when this is called.
func (c *Compiler) setRuleTree() {
	c.RuleTree = NewRuleTree(c.ModuleTree)
}
1247
// setGraph rebuilds c.Graph from the compiler's modules, using
// c.GetRulesDynamic to list the rules referred to by each ref.
func (c *Compiler) setGraph() {
	c.Graph = NewGraph(c.Modules, c.GetRulesDynamic)
}
1251
// queryCompiler holds the state used to compile a single query against a
// Compiler.
type queryCompiler struct {
	// compiler supplies builtins, metrics, the error limit, and type
	// information to the query stages.
	compiler *Compiler
	// qctx is the context set via WithContext; Compile works on a copy.
	qctx *QueryContext
	// typeEnv is populated by the checkTypes stage.
	typeEnv *TypeEnv
	// rewritten is populated by the rewriteLocalVars stage.
	rewritten map[Var]Var
	// after maps a built-in stage name to the extra stages registered via
	// WithStageAfter to run once that stage completes.
	after map[string][]QueryCompilerStageDefinition
	// unsafeBuiltins, when non-nil, takes precedence over the compiler's
	// unsafe builtin map in the checkUnsafeBuiltins stage.
	unsafeBuiltins map[string]struct{}
}
1260
1261func newQueryCompiler(compiler *Compiler) QueryCompiler {
1262	qc := &queryCompiler{
1263		compiler: compiler,
1264		qctx:     nil,
1265		after:    map[string][]QueryCompilerStageDefinition{},
1266	}
1267	return qc
1268}
1269
// WithContext sets the query context used by subsequent Compile calls and
// returns qc to allow chaining.
func (qc *queryCompiler) WithContext(qctx *QueryContext) QueryCompiler {
	qc.qctx = qctx
	return qc
}
1274
1275func (qc *queryCompiler) WithStageAfter(after string, stage QueryCompilerStageDefinition) QueryCompiler {
1276	qc.after[after] = append(qc.after[after], stage)
1277	return qc
1278}
1279
// WithUnsafeBuiltins sets the unsafe builtin set for this query compiler.
// When non-nil it is used by the checkUnsafeBuiltins stage instead of the
// compiler's own map. It returns qc to allow chaining.
func (qc *queryCompiler) WithUnsafeBuiltins(unsafe map[string]struct{}) QueryCompiler {
	qc.unsafeBuiltins = unsafe
	return qc
}
1284
// RewrittenVars returns the variable map recorded by the rewriteLocalVars
// stage of the most recent Compile call. It is nil if that stage has not
// completed successfully yet.
func (qc *queryCompiler) RewrittenVars() map[Var]Var {
	return qc.rewritten
}
1288
1289func (qc *queryCompiler) runStage(metricName string, qctx *QueryContext, query Body, s func(*QueryContext, Body) (Body, error)) (Body, error) {
1290	if qc.compiler.metrics != nil {
1291		qc.compiler.metrics.Timer(metricName).Start()
1292		defer qc.compiler.metrics.Timer(metricName).Stop()
1293	}
1294	return s(qctx, query)
1295}
1296
1297func (qc *queryCompiler) runStageAfter(metricName string, query Body, s QueryCompilerStage) (Body, error) {
1298	if qc.compiler.metrics != nil {
1299		qc.compiler.metrics.Timer(metricName).Start()
1300		defer qc.compiler.metrics.Timer(metricName).Stop()
1301	}
1302	return s(qc, query)
1303}
1304
1305func (qc *queryCompiler) Compile(query Body) (Body, error) {
1306
1307	query = query.Copy()
1308
1309	stages := []struct {
1310		name       string
1311		metricName string
1312		f          func(*QueryContext, Body) (Body, error)
1313	}{
1314		{"ResolveRefs", "query_compile_stage_resolve_refs", qc.resolveRefs},
1315		{"RewriteLocalVars", "query_compile_stage_rewrite_local_vars", qc.rewriteLocalVars},
1316		{"RewriteExprTerms", "query_compile_stage_rewrite_expr_terms", qc.rewriteExprTerms},
1317		{"RewriteComprehensionTerms", "query_compile_stage_rewrite_comprehension_terms", qc.rewriteComprehensionTerms},
1318		{"RewriteWithValues", "query_compile_stage_rewrite_with_values", qc.rewriteWithModifiers},
1319		{"CheckUndefinedFuncs", "query_compile_stage_check_undefined_funcs", qc.checkUndefinedFuncs},
1320		{"CheckSafety", "query_compile_stage_check_safety", qc.checkSafety},
1321		{"RewriteDynamicTerms", "query_compile_stage_rewrite_dynamic_terms", qc.rewriteDynamicTerms},
1322		{"CheckTypes", "query_compile_stage_check_types", qc.checkTypes},
1323		{"CheckUnsafeBuiltins", "query_compile_stage_check_unsafe_builtins", qc.checkUnsafeBuiltins},
1324	}
1325
1326	qctx := qc.qctx.Copy()
1327
1328	for _, s := range stages {
1329		var err error
1330		query, err = qc.runStage(s.metricName, qctx, query, s.f)
1331		if err != nil {
1332			return nil, qc.applyErrorLimit(err)
1333		}
1334		for _, s := range qc.after[s.name] {
1335			query, err = qc.runStageAfter(s.MetricName, query, s.Stage)
1336			if err != nil {
1337				return nil, qc.applyErrorLimit(err)
1338			}
1339		}
1340	}
1341
1342	return query, nil
1343}
1344
// TypeEnv returns the type environment recorded by the checkTypes stage of
// the most recent Compile call.
func (qc *queryCompiler) TypeEnv() *TypeEnv {
	return qc.typeEnv
}
1348
1349func (qc *queryCompiler) applyErrorLimit(err error) error {
1350	if errs, ok := err.(Errors); ok {
1351		if qc.compiler.maxErrs > 0 && len(errs) > qc.compiler.maxErrs {
1352			err = append(errs[:qc.compiler.maxErrs], errLimitReached)
1353		}
1354	}
1355	return err
1356}
1357
1358func (qc *queryCompiler) resolveRefs(qctx *QueryContext, body Body) (Body, error) {
1359
1360	var globals map[Var]Ref
1361
1362	if qctx != nil && qctx.Package != nil {
1363		var ruleExports []Var
1364		rules := qc.compiler.getExports()
1365		if exist, ok := rules.Get(qctx.Package.Path); ok {
1366			ruleExports = exist.([]Var)
1367		}
1368
1369		globals = getGlobals(qctx.Package, ruleExports, qc.qctx.Imports)
1370		qctx.Imports = nil
1371	}
1372
1373	ignore := &declaredVarStack{declaredVars(body)}
1374
1375	return resolveRefsInBody(globals, ignore, body), nil
1376}
1377
1378func (qc *queryCompiler) rewriteComprehensionTerms(_ *QueryContext, body Body) (Body, error) {
1379	gen := newLocalVarGenerator("q", body)
1380	f := newEqualityFactory(gen)
1381	node, err := rewriteComprehensionTerms(f, body)
1382	if err != nil {
1383		return nil, err
1384	}
1385	return node.(Body), nil
1386}
1387
1388func (qc *queryCompiler) rewriteDynamicTerms(_ *QueryContext, body Body) (Body, error) {
1389	gen := newLocalVarGenerator("q", body)
1390	f := newEqualityFactory(gen)
1391	return rewriteDynamics(f, body), nil
1392}
1393
1394func (qc *queryCompiler) rewriteExprTerms(_ *QueryContext, body Body) (Body, error) {
1395	gen := newLocalVarGenerator("q", body)
1396	return rewriteExprTermsInBody(gen, body), nil
1397}
1398
1399func (qc *queryCompiler) rewriteLocalVars(_ *QueryContext, body Body) (Body, error) {
1400	gen := newLocalVarGenerator("q", body)
1401	stack := newLocalDeclaredVars()
1402	body, _, err := rewriteLocalVars(gen, stack, nil, body)
1403	if len(err) != 0 {
1404		return nil, err
1405	}
1406	qc.rewritten = make(map[Var]Var, len(stack.rewritten))
1407	for k, v := range stack.rewritten {
1408		// The vars returned during the rewrite will include all seen vars,
1409		// even if they're not declared with an assignment operation. We don't
1410		// want to include these inside the rewritten set though.
1411		qc.rewritten[k] = v
1412	}
1413	return body, nil
1414}
1415
1416func (qc *queryCompiler) checkUndefinedFuncs(_ *QueryContext, body Body) (Body, error) {
1417	if errs := checkUndefinedFuncs(body, qc.compiler.GetArity); len(errs) > 0 {
1418		return nil, errs
1419	}
1420	return body, nil
1421}
1422
1423func (qc *queryCompiler) checkSafety(_ *QueryContext, body Body) (Body, error) {
1424	safe := ReservedVars.Copy()
1425	reordered, unsafe := reorderBodyForSafety(qc.compiler.builtins, qc.compiler.GetArity, safe, body)
1426	if errs := safetyErrorSlice(unsafe); len(errs) > 0 {
1427		return nil, errs
1428	}
1429	return reordered, nil
1430}
1431
1432func (qc *queryCompiler) checkTypes(qctx *QueryContext, body Body) (Body, error) {
1433	var errs Errors
1434	checker := newTypeChecker().WithVarRewriter(rewriteVarsInRef(qc.rewritten, qc.compiler.RewrittenVars))
1435	qc.typeEnv, errs = checker.CheckBody(qc.compiler.TypeEnv, body)
1436	if len(errs) > 0 {
1437		return nil, errs
1438	}
1439	return body, nil
1440}
1441
1442func (qc *queryCompiler) checkUnsafeBuiltins(qctx *QueryContext, body Body) (Body, error) {
1443	var unsafe map[string]struct{}
1444	if qc.unsafeBuiltins != nil {
1445		unsafe = qc.unsafeBuiltins
1446	} else {
1447		unsafe = qc.compiler.unsafeBuiltinsMap
1448	}
1449	errs := checkUnsafeBuiltins(unsafe, body)
1450	if len(errs) > 0 {
1451		return nil, errs
1452	}
1453	return body, nil
1454}
1455
1456func (qc *queryCompiler) rewriteWithModifiers(qctx *QueryContext, body Body) (Body, error) {
1457	f := newEqualityFactory(newLocalVarGenerator("q", body))
1458	body, err := rewriteWithModifiersInBody(qc.compiler, f, body)
1459	if err != nil {
1460		return nil, Errors{err}
1461	}
1462	return body, nil
1463}
1464
// ModuleTreeNode represents a node in the module tree. The module
// tree is keyed by the package path.
type ModuleTreeNode struct {
	// Key is the package path element this node corresponds to. It is unset
	// on the root node.
	Key Value
	// Modules holds the modules whose package path ends at this node.
	Modules []*Module
	// Children maps the next package path element to its node.
	Children map[Value]*ModuleTreeNode
	// Hide is set by NewModuleTree on the node at path index 1 whose key
	// equals SystemDocumentKey.
	Hide bool
}
1473
1474// NewModuleTree returns a new ModuleTreeNode that represents the root
1475// of the module tree populated with the given modules.
1476func NewModuleTree(mods map[string]*Module) *ModuleTreeNode {
1477	root := &ModuleTreeNode{
1478		Children: map[Value]*ModuleTreeNode{},
1479	}
1480	for _, m := range mods {
1481		node := root
1482		for i, x := range m.Package.Path {
1483			c, ok := node.Children[x.Value]
1484			if !ok {
1485				var hide bool
1486				if i == 1 && x.Value.Compare(SystemDocumentKey) == 0 {
1487					hide = true
1488				}
1489				c = &ModuleTreeNode{
1490					Key:      x.Value,
1491					Children: map[Value]*ModuleTreeNode{},
1492					Hide:     hide,
1493				}
1494				node.Children[x.Value] = c
1495			}
1496			node = c
1497		}
1498		node.Modules = append(node.Modules, m)
1499	}
1500	return root
1501}
1502
1503// Size returns the number of modules in the tree.
1504func (n *ModuleTreeNode) Size() int {
1505	s := len(n.Modules)
1506	for _, c := range n.Children {
1507		s += c.Size()
1508	}
1509	return s
1510}
1511
1512// DepthFirst performs a depth-first traversal of the module tree rooted at n.
1513// If f returns true, traversal will not continue to the children of n.
1514func (n *ModuleTreeNode) DepthFirst(f func(node *ModuleTreeNode) bool) {
1515	if !f(n) {
1516		for _, node := range n.Children {
1517			node.DepthFirst(f)
1518		}
1519	}
1520}
1521
// TreeNode represents a node in the rule tree. The rule tree is keyed by
// rule path.
type TreeNode struct {
	// Key is the path element this node corresponds to.
	Key Value
	// Values holds the rules sharing this node's name. NewRuleTree sets it
	// on rule nodes and leaves it nil on package nodes.
	Values []util.T
	// Children maps the next path element to its node. NewRuleTree leaves it
	// nil on rule nodes.
	Children map[Value]*TreeNode
	// Hide is copied from the corresponding ModuleTreeNode.
	Hide bool
}
1530
1531// NewRuleTree returns a new TreeNode that represents the root
1532// of the rule tree populated with the given rules.
1533func NewRuleTree(mtree *ModuleTreeNode) *TreeNode {
1534
1535	ruleSets := map[String][]util.T{}
1536
1537	// Build rule sets for this package.
1538	for _, mod := range mtree.Modules {
1539		for _, rule := range mod.Rules {
1540			key := String(rule.Head.Name)
1541			ruleSets[key] = append(ruleSets[key], rule)
1542		}
1543	}
1544
1545	// Each rule set becomes a leaf node.
1546	children := map[Value]*TreeNode{}
1547
1548	for key, rules := range ruleSets {
1549		children[key] = &TreeNode{
1550			Key:      key,
1551			Children: nil,
1552			Values:   rules,
1553		}
1554	}
1555
1556	// Each module in subpackage becomes child node.
1557	for _, child := range mtree.Children {
1558		children[child.Key] = NewRuleTree(child)
1559	}
1560
1561	return &TreeNode{
1562		Key:      mtree.Key,
1563		Values:   nil,
1564		Children: children,
1565		Hide:     mtree.Hide,
1566	}
1567}
1568
1569// Size returns the number of rules in the tree.
1570func (n *TreeNode) Size() int {
1571	s := len(n.Values)
1572	for _, c := range n.Children {
1573		s += c.Size()
1574	}
1575	return s
1576}
1577
1578// Child returns n's child with key k.
1579func (n *TreeNode) Child(k Value) *TreeNode {
1580	switch k.(type) {
1581	case String, Var:
1582		return n.Children[k]
1583	}
1584	return nil
1585}
1586
1587// DepthFirst performs a depth-first traversal of the rule tree rooted at n. If
1588// f returns true, traversal will not continue to the children of n.
1589func (n *TreeNode) DepthFirst(f func(node *TreeNode) bool) {
1590	if !f(n) {
1591		for _, node := range n.Children {
1592			node.DepthFirst(f)
1593		}
1594	}
1595}
1596
// Graph represents the graph of dependencies between rules.
type Graph struct {
	// adj maps a node to the set of nodes it depends on.
	adj map[util.T]map[util.T]struct{}
	// nodes is the set of all nodes added to the graph.
	nodes map[util.T]struct{}
	// sorted caches the result of a successful Sort; nil until then.
	sorted []util.T
}
1603
// NewGraph returns a new Graph based on modules. The list function must return
// the rules referred to directly by the ref.
//
// Every rule visited by WalkRules becomes a node. For every ref found inside a
// rule, an edge is added from that rule to each rule returned by list and to
// every rule on that rule's Else chain.
func NewGraph(modules map[string]*Module, list func(Ref) []*Rule) *Graph {

	graph := &Graph{
		adj:    map[util.T]map[util.T]struct{}{},
		nodes:  map[util.T]struct{}{},
		sorted: nil,
	}

	// Create visitor to walk a rule AST and add edges to the rule graph for
	// each dependency.
	vis := func(a *Rule) *GenericVisitor {
		// stop is false only until the first *Rule is seen. The walk starts
		// at a itself, so that first rule is a; any *Rule seen afterwards is
		// nested and is skipped.
		stop := false
		return NewGenericVisitor(func(x interface{}) bool {
			switch x := x.(type) {
			case Ref:
				for _, b := range list(x) {
					// Depend on the referenced rule and on each of its else
					// clauses.
					for node := b; node != nil; node = node.Else {
						graph.addDependency(a, node)
					}
				}
			case *Rule:
				if stop {
					// Do not recurse into else clauses (which will be handled
					// by the outer visitor.)
					return true
				}
				stop = true
			}
			return false
		})
	}

	// Walk over all rules, add them to graph, and build adjacency lists.
	for _, module := range modules {
		WalkRules(module, func(a *Rule) bool {
			graph.addNode(a)
			vis(a).Walk(a)
			return false
		})
	}

	return graph
}
1649
// Dependencies returns the set of rules that x depends on. The returned map
// is the graph's internal adjacency set (not a copy) and is nil when x has no
// recorded dependencies.
func (g *Graph) Dependencies(x util.T) map[util.T]struct{} {
	return g.adj[x]
}
1654
1655// Sort returns a slice of rules sorted by dependencies. If a cycle is found,
1656// ok is set to false.
1657func (g *Graph) Sort() (sorted []util.T, ok bool) {
1658	if g.sorted != nil {
1659		return g.sorted, true
1660	}
1661
1662	sort := &graphSort{
1663		sorted: make([]util.T, 0, len(g.nodes)),
1664		deps:   g.Dependencies,
1665		marked: map[util.T]struct{}{},
1666		temp:   map[util.T]struct{}{},
1667	}
1668
1669	for node := range g.nodes {
1670		if !sort.Visit(node) {
1671			return nil, false
1672		}
1673	}
1674
1675	g.sorted = sort.sorted
1676	return g.sorted, true
1677}
1678
1679func (g *Graph) addDependency(u util.T, v util.T) {
1680
1681	if _, ok := g.nodes[u]; !ok {
1682		g.addNode(u)
1683	}
1684
1685	if _, ok := g.nodes[v]; !ok {
1686		g.addNode(v)
1687	}
1688
1689	edges, ok := g.adj[u]
1690	if !ok {
1691		edges = map[util.T]struct{}{}
1692		g.adj[u] = edges
1693	}
1694
1695	edges[v] = struct{}{}
1696}
1697
// addNode adds n to the graph's node set. Adding a node that is already
// present has no effect.
func (g *Graph) addNode(n util.T) {
	g.nodes[n] = struct{}{}
}
1701
// graphSort holds the state of a depth-first dependency sort over a graph.
type graphSort struct {
	// sorted is the output order built so far; a node is appended after all
	// of its dependencies have been visited.
	sorted []util.T
	// deps returns the set of nodes that a node depends on.
	deps func(util.T) map[util.T]struct{}
	// marked contains the nodes whose visit has completed.
	marked map[util.T]struct{}
	// temp contains the nodes whose visit is in progress; reaching one of
	// them again means there is a cycle.
	temp map[util.T]struct{}
}
1708
1709func (sort *graphSort) Marked(node util.T) bool {
1710	_, marked := sort.marked[node]
1711	return marked
1712}
1713
1714func (sort *graphSort) Visit(node util.T) (ok bool) {
1715	if _, ok := sort.temp[node]; ok {
1716		return false
1717	}
1718	if sort.Marked(node) {
1719		return true
1720	}
1721	sort.temp[node] = struct{}{}
1722	for other := range sort.deps(node) {
1723		if !sort.Visit(other) {
1724			return false
1725		}
1726	}
1727	sort.marked[node] = struct{}{}
1728	delete(sort.temp, node)
1729	sort.sorted = append(sort.sorted, node)
1730	return true
1731}
1732
// GraphTraversal is a Traversal that understands the dependency graph
type GraphTraversal struct {
	// graph is the dependency graph being traversed.
	graph *Graph
	// visited is the set of nodes that Visited has been called with.
	visited map[util.T]struct{}
}
1738
1739// NewGraphTraversal returns a Traversal for the dependency graph
1740func NewGraphTraversal(graph *Graph) *GraphTraversal {
1741	return &GraphTraversal{
1742		graph:   graph,
1743		visited: map[util.T]struct{}{},
1744	}
1745}
1746
1747// Edges lists all dependency connections for a given node
1748func (g *GraphTraversal) Edges(x util.T) []util.T {
1749	r := []util.T{}
1750	for v := range g.graph.Dependencies(x) {
1751		r = append(r, v)
1752	}
1753	return r
1754}
1755
1756// Visited returns whether a node has been visited, setting a node to visited if not
1757func (g *GraphTraversal) Visited(u util.T) bool {
1758	_, ok := g.visited[u]
1759	g.visited[u] = struct{}{}
1760	return ok
1761}
1762
// unsafePair couples an expression with the set of variables that are unsafe
// in it.
type unsafePair struct {
	Expr *Expr
	Vars VarSet
}
1767
// unsafeVarLoc couples an unsafe variable with the location chosen to report
// it at.
type unsafeVarLoc struct {
	Var Var
	Loc *Location
}
1772
// unsafeVars maps an expression to the set of variables that are unsafe in
// that expression.
type unsafeVars map[*Expr]VarSet
1774
1775func (vs unsafeVars) Add(e *Expr, v Var) {
1776	if u, ok := vs[e]; ok {
1777		u[v] = struct{}{}
1778	} else {
1779		vs[e] = VarSet{v: struct{}{}}
1780	}
1781}
1782
// Set replaces the unsafe variable set for expression e with s. The set is
// stored as given, not copied.
func (vs unsafeVars) Set(e *Expr, s VarSet) {
	vs[e] = s
}
1786
1787func (vs unsafeVars) Update(o unsafeVars) {
1788	for k, v := range o {
1789		if _, ok := vs[k]; !ok {
1790			vs[k] = VarSet{}
1791		}
1792		vs[k].Update(v)
1793	}
1794}
1795
1796func (vs unsafeVars) Vars() (result []unsafeVarLoc) {
1797
1798	locs := map[Var]*Location{}
1799
1800	// If var appears in multiple sets then pick first by location.
1801	for expr, vars := range vs {
1802		for v := range vars {
1803			if locs[v].Compare(expr.Location) > 0 {
1804				locs[v] = expr.Location
1805			}
1806		}
1807	}
1808
1809	for v, loc := range locs {
1810		result = append(result, unsafeVarLoc{
1811			Var: v,
1812			Loc: loc,
1813		})
1814	}
1815
1816	sort.Slice(result, func(i, j int) bool {
1817		return result[i].Loc.Compare(result[j].Loc) < 0
1818	})
1819
1820	return result
1821}
1822
1823func (vs unsafeVars) Slice() (result []unsafePair) {
1824	for expr, vs := range vs {
1825		result = append(result, unsafePair{
1826			Expr: expr,
1827			Vars: vs,
1828		})
1829	}
1830	return
1831}
1832
// reorderBodyForSafety returns a copy of the body ordered such that
// left to right evaluation of the body will not encounter unbound variables
// in input positions or negated expressions.
//
// Expressions are added to the re-ordered body as soon as they are considered
// safe. If multiple expressions become safe in the same pass, they are added
// in their original order. This results in minimal re-ordering of the body.
//
// If the body cannot be reordered to ensure safety, the second return value
// contains a mapping of expressions to unsafe variables in those expressions.
func reorderBodyForSafety(builtins map[string]*Builtin, arity func(Ref) int, globals VarSet, body Body) (Body, unsafeVars) {

	// First order expressions containing closures after the expressions that
	// bind the variables they close over. If that fails, give up early.
	body, unsafe := reorderBodyForClosures(builtins, arity, globals, body)
	if len(unsafe) != 0 {
		return nil, unsafe
	}

	reordered := Body{}
	safe := VarSet{}

	// Seed the state: global variables used by the body are safe from the
	// start; every other variable is unsafe until an expression outputs it.
	// unsafe is empty at this point and is reused as the accumulator.
	for _, e := range body {
		for v := range e.Vars(safetyCheckVarVisitorParams) {
			if globals.Contains(v) {
				safe.Add(v)
			} else {
				unsafe.Add(e, v)
			}
		}
	}

	// Repeat passes over the body until a pass adds no expression to the
	// reordered body.
	for {
		n := len(reordered)

		for _, e := range body {
			if reordered.Contains(e) {
				continue
			}

			// Variables output by e (given what is safe so far) become safe.
			safe.Update(outputVarsForExpr(e, builtins, arity, safe))

			for v := range unsafe[e] {
				if safe.Contains(v) {
					delete(unsafe[e], v)
				}
			}

			// e is safe once none of its variables remain unsafe.
			if len(unsafe[e]) == 0 {
				delete(unsafe, e)
				reordered = append(reordered, e)
			}
		}

		if len(reordered) == n {
			break
		}
	}

	// Recursively visit closures and perform the safety checks on them.
	// Update the globals at each expression to include the variables that could
	// be closed over.
	g := globals.Copy()
	for i, e := range reordered {
		if i > 0 {
			// g accumulates the variables of all preceding expressions.
			g.Update(reordered[i-1].Vars(safetyCheckVarVisitorParams))
		}
		vis := &bodySafetyVisitor{
			builtins: builtins,
			arity:    arity,
			current:  e,
			globals:  g,
			unsafe:   unsafe,
		}
		NewGenericVisitor(vis.Visit).Walk(e)
	}

	// Need to reset expression indices as re-ordering may have
	// changed them.
	setExprIndices(reordered)

	return reordered, unsafe
}
1914
// bodySafetyVisitor performs safety checks on the comprehensions nested
// inside an expression.
type bodySafetyVisitor struct {
	// builtins and arity are passed through to reorderBodyForSafety.
	builtins map[string]*Builtin
	arity    func(Ref) int
	// current is the expression that unsafe variables are attributed to.
	current *Expr
	// globals is the set of variables treated as safe inside comprehension
	// bodies.
	globals VarSet
	// unsafe collects the unsafe variables that are found.
	unsafe unsafeVars
}
1922
1923func (vis *bodySafetyVisitor) Visit(x interface{}) bool {
1924	switch x := x.(type) {
1925	case *Expr:
1926		cpy := *vis
1927		cpy.current = x
1928
1929		switch ts := x.Terms.(type) {
1930		case *SomeDecl:
1931			NewGenericVisitor(cpy.Visit).Walk(ts)
1932		case []*Term:
1933			for _, t := range ts {
1934				NewGenericVisitor(cpy.Visit).Walk(t)
1935			}
1936		case *Term:
1937			NewGenericVisitor(cpy.Visit).Walk(ts)
1938		}
1939		for i := range x.With {
1940			NewGenericVisitor(cpy.Visit).Walk(x.With[i])
1941		}
1942		return true
1943	case *ArrayComprehension:
1944		vis.checkArrayComprehensionSafety(x)
1945		return true
1946	case *ObjectComprehension:
1947		vis.checkObjectComprehensionSafety(x)
1948		return true
1949	case *SetComprehension:
1950		vis.checkSetComprehensionSafety(x)
1951		return true
1952	}
1953	return false
1954}
1955
1956// Check term for safety. This is analogous to the rule head safety check.
1957func (vis *bodySafetyVisitor) checkComprehensionSafety(tv VarSet, body Body) Body {
1958	bv := body.Vars(safetyCheckVarVisitorParams)
1959	bv.Update(vis.globals)
1960	uv := tv.Diff(bv)
1961	for v := range uv {
1962		vis.unsafe.Add(vis.current, v)
1963	}
1964
1965	// Check body for safety, reordering as necessary.
1966	r, u := reorderBodyForSafety(vis.builtins, vis.arity, vis.globals, body)
1967	if len(u) == 0 {
1968		return r
1969	}
1970
1971	vis.unsafe.Update(u)
1972	return body
1973}
1974
1975func (vis *bodySafetyVisitor) checkArrayComprehensionSafety(ac *ArrayComprehension) {
1976	ac.Body = vis.checkComprehensionSafety(ac.Term.Vars(), ac.Body)
1977}
1978
1979func (vis *bodySafetyVisitor) checkObjectComprehensionSafety(oc *ObjectComprehension) {
1980	tv := oc.Key.Vars()
1981	tv.Update(oc.Value.Vars())
1982	oc.Body = vis.checkComprehensionSafety(tv, oc.Body)
1983}
1984
1985func (vis *bodySafetyVisitor) checkSetComprehensionSafety(sc *SetComprehension) {
1986	sc.Body = vis.checkComprehensionSafety(sc.Term.Vars(), sc.Body)
1987}
1988
// reorderBodyForClosures returns a copy of the body ordered such that
// expressions (such as array comprehensions) that close over variables are ordered
// after other expressions that contain the same variable in an output position.
//
// The second return value maps each expression that could not be placed to
// the closed-over variables that were still unbound for it; it is empty when
// every expression was placed.
func reorderBodyForClosures(builtins map[string]*Builtin, arity func(Ref) int, globals VarSet, body Body) (Body, unsafeVars) {

	reordered := Body{}
	unsafe := unsafeVars{}

	// Repeat passes over the body until a pass places no new expression.
	for {
		n := len(reordered)

		for _, e := range body {
			if reordered.Contains(e) {
				continue
			}

			// Collect vars that are contained in closures within this
			// expression.
			vs := VarSet{}
			WalkClosures(e, func(x interface{}) bool {
				vis := &VarVisitor{vars: vs}
				vis.Walk(x)
				return true
			})

			// Compute vars that are closed over from the body but not yet
			// contained in the output position of an expression in the reordered
			// body. These vars are considered unsafe.
			cv := vs.Intersect(body.Vars(safetyCheckVarVisitorParams)).Diff(globals)
			uv := cv.Diff(outputVarsForBody(reordered, builtins, arity, globals))

			if len(uv) == 0 {
				reordered = append(reordered, e)
				// Clear any entry recorded for e in an earlier pass.
				delete(unsafe, e)
			} else {
				unsafe.Set(e, uv)
			}
		}

		if len(reordered) == n {
			break
		}
	}

	return reordered, unsafe
}
2035
2036func outputVarsForBody(body Body, builtins map[string]*Builtin, arity func(Ref) int, safe VarSet) VarSet {
2037	o := safe.Copy()
2038	for _, e := range body {
2039		o.Update(outputVarsForExpr(e, builtins, arity, o))
2040	}
2041	return o.Diff(safe)
2042}
2043
2044func outputVarsForExpr(expr *Expr, builtins map[string]*Builtin, arity func(Ref) int, safe VarSet) VarSet {
2045
2046	// Negated expressions must be safe.
2047	if expr.Negated {
2048		return VarSet{}
2049	}
2050
2051	// With modifier inputs must be safe.
2052	for _, with := range expr.With {
2053		unsafe := false
2054		WalkVars(with, func(v Var) bool {
2055			if !safe.Contains(v) {
2056				unsafe = true
2057				return true
2058			}
2059			return false
2060		})
2061		if unsafe {
2062			return VarSet{}
2063		}
2064	}
2065
2066	if !expr.IsCall() {
2067		return outputVarsForExprRefs(expr, safe)
2068	}
2069
2070	terms := expr.Terms.([]*Term)
2071	name := terms[0].String()
2072
2073	if b := builtins[name]; b != nil {
2074		if b.Name == Equality.Name {
2075			return outputVarsForExprEq(expr, safe)
2076		}
2077		return outputVarsForExprBuiltin(expr, b, safe)
2078	}
2079
2080	return outputVarsForExprCall(expr, arity, safe, terms)
2081}
2082
2083func outputVarsForExprBuiltin(expr *Expr, b *Builtin, safe VarSet) VarSet {
2084
2085	output := outputVarsForExprRefs(expr, safe)
2086	terms := expr.Terms.([]*Term)
2087
2088	// Check that all input terms are safe.
2089	for i, t := range terms[1:] {
2090		if b.IsTargetPos(i) {
2091			continue
2092		}
2093		vis := NewVarVisitor().WithParams(VarVisitorParams{
2094			SkipClosures:   true,
2095			SkipSets:       true,
2096			SkipObjectKeys: true,
2097			SkipRefHead:    true,
2098		})
2099		vis.Walk(t)
2100		unsafe := vis.Vars().Diff(output).Diff(safe)
2101		if len(unsafe) > 0 {
2102			return VarSet{}
2103		}
2104	}
2105
2106	// Add vars in target positions to result.
2107	for i, t := range terms[1:] {
2108		if b.IsTargetPos(i) {
2109			vis := NewVarVisitor().WithParams(VarVisitorParams{
2110				SkipRefHead:    true,
2111				SkipSets:       true,
2112				SkipObjectKeys: true,
2113				SkipClosures:   true,
2114			})
2115			vis.Walk(t)
2116			output.Update(vis.vars)
2117		}
2118	}
2119
2120	return output
2121}
2122
2123func outputVarsForExprEq(expr *Expr, safe VarSet) VarSet {
2124	if !validEqAssignArgCount(expr) {
2125		return safe
2126	}
2127	output := outputVarsForExprRefs(expr, safe)
2128	output.Update(safe)
2129	output.Update(Unify(output, expr.Operand(0), expr.Operand(1)))
2130	return output.Diff(safe)
2131}
2132
2133func outputVarsForExprCall(expr *Expr, arity func(Ref) int, safe VarSet, terms []*Term) VarSet {
2134
2135	output := outputVarsForExprRefs(expr, safe)
2136
2137	ref, ok := terms[0].Value.(Ref)
2138	if !ok {
2139		return VarSet{}
2140	}
2141
2142	numArgs := arity(ref)
2143	if numArgs == -1 {
2144		return VarSet{}
2145	}
2146
2147	numInputTerms := numArgs + 1
2148
2149	if numInputTerms >= len(terms) {
2150		return output
2151	}
2152
2153	vis := NewVarVisitor().WithParams(VarVisitorParams{
2154		SkipClosures:   true,
2155		SkipSets:       true,
2156		SkipObjectKeys: true,
2157		SkipRefHead:    true,
2158	})
2159
2160	vis.Walk(Args(terms[:numInputTerms]))
2161	unsafe := vis.Vars().Diff(output).Diff(safe)
2162
2163	if len(unsafe) > 0 {
2164		return VarSet{}
2165	}
2166
2167	vis = NewVarVisitor().WithParams(VarVisitorParams{
2168		SkipRefHead:    true,
2169		SkipSets:       true,
2170		SkipObjectKeys: true,
2171		SkipClosures:   true,
2172	})
2173
2174	vis.Walk(Args(terms[numInputTerms:]))
2175	output.Update(vis.vars)
2176	return output
2177}
2178
2179func outputVarsForExprRefs(expr *Expr, safe VarSet) VarSet {
2180	output := VarSet{}
2181	WalkRefs(expr, func(r Ref) bool {
2182		if safe.Contains(r[0].Value.(Var)) {
2183			output.Update(r.OutputVars())
2184			return false
2185		}
2186		return true
2187	})
2188	return output
2189}
2190
// equalityFactory builds generated equality expressions that bind a term to a
// fresh local variable obtained from gen.
type equalityFactory struct {
	gen *localVarGenerator // source of the fresh variable names
}
2194
2195func newEqualityFactory(gen *localVarGenerator) *equalityFactory {
2196	return &equalityFactory{gen}
2197}
2198
2199func (f *equalityFactory) Generate(other *Term) *Expr {
2200	term := NewTerm(f.gen.Generate()).SetLocation(other.Location)
2201	expr := Equality.Expr(term, other)
2202	expr.Generated = true
2203	expr.Location = other.Location
2204	return expr
2205}
2206
// localVarGenerator produces variable names of the form
// "__local<suffix><n>__" that do not collide with the vars in exclude.
type localVarGenerator struct {
	exclude VarSet // names Generate must never return
	suffix  string // text placed between "__local" and the counter
	next    int    // counter used for the next candidate name
}
2212
2213func newLocalVarGeneratorForModuleSet(sorted []string, modules map[string]*Module) *localVarGenerator {
2214	exclude := NewVarSet()
2215	vis := &VarVisitor{vars: exclude}
2216	for _, key := range sorted {
2217		vis.Walk(modules[key])
2218	}
2219	return &localVarGenerator{exclude: exclude, next: 0}
2220}
2221
2222func newLocalVarGenerator(suffix string, node interface{}) *localVarGenerator {
2223	exclude := NewVarSet()
2224	vis := &VarVisitor{vars: exclude}
2225	vis.Walk(node)
2226	return &localVarGenerator{exclude: exclude, suffix: suffix, next: 0}
2227}
2228
2229func (l *localVarGenerator) Generate() Var {
2230	for {
2231		result := Var("__local" + l.suffix + strconv.Itoa(l.next) + "__")
2232		l.next++
2233		if !l.exclude.Contains(result) {
2234			return result
2235		}
2236	}
2237}
2238
2239func getGlobals(pkg *Package, rules []Var, imports []*Import) map[Var]Ref {
2240
2241	globals := map[Var]Ref{}
2242
2243	// Populate globals with exports within the package.
2244	for _, v := range rules {
2245		global := append(Ref{}, pkg.Path...)
2246		global = append(global, &Term{Value: String(v)})
2247		globals[v] = global
2248	}
2249
2250	// Populate globals with imports.
2251	for _, i := range imports {
2252		if len(i.Alias) > 0 {
2253			path := i.Path.Value.(Ref)
2254			globals[i.Alias] = path
2255		} else {
2256			path := i.Path.Value.(Ref)
2257			if len(path) == 1 {
2258				globals[path[0].Value.(Var)] = path
2259			} else {
2260				v := path[len(path)-1].Value.(String)
2261				globals[Var(v)] = path
2262			}
2263		}
2264	}
2265
2266	return globals
2267}
2268
2269func requiresEval(x *Term) bool {
2270	if x == nil {
2271		return false
2272	}
2273	return ContainsRefs(x) || ContainsComprehensions(x)
2274}
2275
2276func resolveRef(globals map[Var]Ref, ignore *declaredVarStack, ref Ref) Ref {
2277
2278	r := Ref{}
2279	for i, x := range ref {
2280		switch v := x.Value.(type) {
2281		case Var:
2282			if g, ok := globals[v]; ok && !ignore.Contains(v) {
2283				cpy := g.Copy()
2284				for i := range cpy {
2285					cpy[i].SetLocation(x.Location)
2286				}
2287				if i == 0 {
2288					r = cpy
2289				} else {
2290					r = append(r, NewTerm(cpy).SetLocation(x.Location))
2291				}
2292			} else {
2293				r = append(r, x)
2294			}
2295		case Ref, Array, Object, Set, *ArrayComprehension, *SetComprehension, *ObjectComprehension, Call:
2296			r = append(r, resolveRefsInTerm(globals, ignore, x))
2297		default:
2298			r = append(r, x)
2299		}
2300	}
2301
2302	return r
2303}
2304
// resolveRefsInRule resolves global names in the key, value and body of rule,
// updating the rule in place.
//
// Variables collected from the rule's args and variables declared in the
// rule's body are pushed onto an ignore stack so that they are not replaced
// by globals of the same name. An error is returned, and the rule is left
// untouched, if an arg is a root document ref.
func resolveRefsInRule(globals map[Var]Ref, rule *Rule) error {
	ignore := &declaredVarStack{}

	vars := NewVarSet()
	// vis is declared before it is assigned because the callback below
	// refers to it in order to walk object values.
	var vis *GenericVisitor
	var err error

	// Walk args to collect vars and transform body so that callers can shadow
	// root documents.
	vis = NewGenericVisitor(func(x interface{}) bool {
		// Once an error has been recorded nothing else is inspected.
		if err != nil {
			return true
		}
		switch x := x.(type) {
		case Var:
			vars.Add(x)

		// Object keys cannot be pattern matched so only walk values.
		// NOTE(review): this case falls through to the final "return false";
		// if false tells the visitor to keep descending, keys are still
		// visited afterwards -- confirm against GenericVisitor.
		case Object:
			for _, k := range x.Keys() {
				vis.Walk(x.Get(k))
			}

		// Skip terms that could contain vars that cannot be pattern matched.
		case Set, *ArrayComprehension, *SetComprehension, *ObjectComprehension, Call:
			return true

		case *Term:
			if _, ok := x.Value.(Ref); ok {
				if RootDocumentRefs.Contains(x) {
					// We could support args named input, data, etc. however
					// this would require rewriting terms in the head and body.
					// Preventing root document shadowing is simpler, and
					// arguably, will prevent confusing names from being used.
					err = fmt.Errorf("args must not shadow %v (use a different variable name)", x)
					return true
				}
			}
		}
		return false
	})

	vis.Walk(rule.Head.Args)

	if err != nil {
		return err
	}

	// Arg vars first, then vars declared in the body.
	ignore.Push(vars)
	ignore.Push(declaredVars(rule.Body))

	if rule.Head.Key != nil {
		rule.Head.Key = resolveRefsInTerm(globals, ignore, rule.Head.Key)
	}

	if rule.Head.Value != nil {
		rule.Head.Value = resolveRefsInTerm(globals, ignore, rule.Head.Value)
	}

	rule.Body = resolveRefsInBody(globals, ignore, rule.Body)
	return nil
}
2367
2368func resolveRefsInBody(globals map[Var]Ref, ignore *declaredVarStack, body Body) Body {
2369	r := Body{}
2370	for _, expr := range body {
2371		r = append(r, resolveRefsInExpr(globals, ignore, expr))
2372	}
2373	return r
2374}
2375
2376func resolveRefsInExpr(globals map[Var]Ref, ignore *declaredVarStack, expr *Expr) *Expr {
2377	cpy := *expr
2378	switch ts := expr.Terms.(type) {
2379	case *Term:
2380		cpy.Terms = resolveRefsInTerm(globals, ignore, ts)
2381	case []*Term:
2382		buf := make([]*Term, len(ts))
2383		for i := 0; i < len(ts); i++ {
2384			buf[i] = resolveRefsInTerm(globals, ignore, ts[i])
2385		}
2386		cpy.Terms = buf
2387	}
2388	for _, w := range cpy.With {
2389		w.Target = resolveRefsInTerm(globals, ignore, w.Target)
2390		w.Value = resolveRefsInTerm(globals, ignore, w.Value)
2391	}
2392	return &cpy
2393}
2394
2395func resolveRefsInTerm(globals map[Var]Ref, ignore *declaredVarStack, term *Term) *Term {
2396	switch v := term.Value.(type) {
2397	case Var:
2398		if g, ok := globals[v]; ok && !ignore.Contains(v) {
2399			cpy := g.Copy()
2400			for i := range cpy {
2401				cpy[i].SetLocation(term.Location)
2402			}
2403			return NewTerm(cpy).SetLocation(term.Location)
2404		}
2405		return term
2406	case Ref:
2407		fqn := resolveRef(globals, ignore, v)
2408		cpy := *term
2409		cpy.Value = fqn
2410		return &cpy
2411	case Object:
2412		cpy := *term
2413		cpy.Value, _ = v.Map(func(k, v *Term) (*Term, *Term, error) {
2414			k = resolveRefsInTerm(globals, ignore, k)
2415			v = resolveRefsInTerm(globals, ignore, v)
2416			return k, v, nil
2417		})
2418		return &cpy
2419	case Array:
2420		cpy := *term
2421		cpy.Value = Array(resolveRefsInTermSlice(globals, ignore, v))
2422		return &cpy
2423	case Call:
2424		cpy := *term
2425		cpy.Value = Call(resolveRefsInTermSlice(globals, ignore, v))
2426		return &cpy
2427	case Set:
2428		s, _ := v.Map(func(e *Term) (*Term, error) {
2429			return resolveRefsInTerm(globals, ignore, e), nil
2430		})
2431		cpy := *term
2432		cpy.Value = s
2433		return &cpy
2434	case *ArrayComprehension:
2435		ac := &ArrayComprehension{}
2436		ignore.Push(declaredVars(v.Body))
2437		ac.Term = resolveRefsInTerm(globals, ignore, v.Term)
2438		ac.Body = resolveRefsInBody(globals, ignore, v.Body)
2439		cpy := *term
2440		cpy.Value = ac
2441		ignore.Pop()
2442		return &cpy
2443	case *ObjectComprehension:
2444		oc := &ObjectComprehension{}
2445		ignore.Push(declaredVars(v.Body))
2446		oc.Key = resolveRefsInTerm(globals, ignore, v.Key)
2447		oc.Value = resolveRefsInTerm(globals, ignore, v.Value)
2448		oc.Body = resolveRefsInBody(globals, ignore, v.Body)
2449		cpy := *term
2450		cpy.Value = oc
2451		ignore.Pop()
2452		return &cpy
2453	case *SetComprehension:
2454		sc := &SetComprehension{}
2455		ignore.Push(declaredVars(v.Body))
2456		sc.Term = resolveRefsInTerm(globals, ignore, v.Term)
2457		sc.Body = resolveRefsInBody(globals, ignore, v.Body)
2458		cpy := *term
2459		cpy.Value = sc
2460		ignore.Pop()
2461		return &cpy
2462	default:
2463		return term
2464	}
2465}
2466
2467func resolveRefsInTermSlice(globals map[Var]Ref, ignore *declaredVarStack, terms []*Term) []*Term {
2468	cpy := make([]*Term, len(terms))
2469	for i := 0; i < len(terms); i++ {
2470		cpy[i] = resolveRefsInTerm(globals, ignore, terms[i])
2471	}
2472	return cpy
2473}
2474
// declaredVarStack is a stack of variable sets. The last element is the top
// of the stack; Contains searches every element.
type declaredVarStack []VarSet
2476
2477func (s declaredVarStack) Contains(v Var) bool {
2478	for i := len(s) - 1; i >= 0; i-- {
2479		if _, ok := s[i][v]; ok {
2480			return ok
2481		}
2482	}
2483	return false
2484}
2485
2486func (s declaredVarStack) Add(v Var) {
2487	s[len(s)-1].Add(v)
2488}
2489
2490func (s *declaredVarStack) Push(vs VarSet) {
2491	*s = append(*s, vs)
2492}
2493
2494func (s *declaredVarStack) Pop() {
2495	curr := *s
2496	*s = curr[:len(curr)-1]
2497}
2498
2499func declaredVars(x interface{}) VarSet {
2500	vars := NewVarSet()
2501	vis := NewGenericVisitor(func(x interface{}) bool {
2502		switch x := x.(type) {
2503		case *Expr:
2504			if x.IsAssignment() && validEqAssignArgCount(x) {
2505				WalkVars(x.Operand(0), func(v Var) bool {
2506					vars.Add(v)
2507					return false
2508				})
2509			} else if decl, ok := x.Terms.(*SomeDecl); ok {
2510				for i := range decl.Symbols {
2511					vars.Add(decl.Symbols[i].Value.(Var))
2512				}
2513			}
2514		case *ArrayComprehension, *SetComprehension, *ObjectComprehension:
2515			return true
2516		}
2517		return false
2518	})
2519	vis.Walk(x)
2520	return vars
2521}
2522
2523// rewriteComprehensionTerms will rewrite comprehensions so that the term part
2524// is bound to a variable in the body. This allows any type of term to be used
2525// in the term part (even if the term requires evaluation.)
2526//
2527// For instance, given the following comprehension:
2528//
2529// [x[0] | x = y[_]; y = [1,2,3]]
2530//
2531// The comprehension would be rewritten as:
2532//
2533// [__local0__ | x = y[_]; y = [1,2,3]; __local0__ = x[0]]
2534func rewriteComprehensionTerms(f *equalityFactory, node interface{}) (interface{}, error) {
2535	return TransformComprehensions(node, func(x interface{}) (Value, error) {
2536		switch x := x.(type) {
2537		case *ArrayComprehension:
2538			if requiresEval(x.Term) {
2539				expr := f.Generate(x.Term)
2540				x.Term = expr.Operand(0)
2541				x.Body.Append(expr)
2542			}
2543			return x, nil
2544		case *SetComprehension:
2545			if requiresEval(x.Term) {
2546				expr := f.Generate(x.Term)
2547				x.Term = expr.Operand(0)
2548				x.Body.Append(expr)
2549			}
2550			return x, nil
2551		case *ObjectComprehension:
2552			if requiresEval(x.Key) {
2553				expr := f.Generate(x.Key)
2554				x.Key = expr.Operand(0)
2555				x.Body.Append(expr)
2556			}
2557			if requiresEval(x.Value) {
2558				expr := f.Generate(x.Value)
2559				x.Value = expr.Operand(0)
2560				x.Body.Append(expr)
2561			}
2562			return x, nil
2563		}
2564		panic("illegal type")
2565	})
2566}
2567
2568// rewriteEquals will rewrite exprs under x as unification calls instead of ==
2569// calls. For example:
2570//
2571// data.foo == data.bar is rewritten as data.foo = data.bar
2572//
2573// This stage should only run the safety check (since == is a built-in with no
2574// outputs, so the inputs must not be marked as safe.)
2575//
2576// This stage is not executed by the query compiler by default because when
2577// callers specify == instead of = they expect to receive a true/false/undefined
2578// result back whereas with = the result is only ever true/undefined. For
2579// partial evaluation cases we do want to rewrite == to = to simplify the
2580// result.
2581func rewriteEquals(x interface{}) {
2582	doubleEq := Equal.Ref()
2583	unifyOp := Equality.Ref()
2584	WalkExprs(x, func(x *Expr) bool {
2585		if x.IsCall() {
2586			operator := x.Operator()
2587			if operator.Equal(doubleEq) && len(x.Operands()) == 2 {
2588				x.SetOperator(NewTerm(unifyOp))
2589			}
2590		}
2591		return false
2592	})
2593}
2594
2595// rewriteDynamics will rewrite the body so that dynamic terms (i.e., refs and
2596// comprehensions) are bound to vars earlier in the query. This translation
2597// results in eager evaluation.
2598//
2599// For instance, given the following query:
2600//
2601// foo(data.bar) = 1
2602//
2603// The rewritten version will be:
2604//
2605// __local0__ = data.bar; foo(__local0__) = 1
2606func rewriteDynamics(f *equalityFactory, body Body) Body {
2607	result := make(Body, 0, len(body))
2608	for _, expr := range body {
2609		if expr.IsEquality() {
2610			result = rewriteDynamicsEqExpr(f, expr, result)
2611		} else if expr.IsCall() {
2612			result = rewriteDynamicsCallExpr(f, expr, result)
2613		} else {
2614			result = rewriteDynamicsTermExpr(f, expr, result)
2615		}
2616	}
2617	return result
2618}
2619
2620func appendExpr(body Body, expr *Expr) Body {
2621	body.Append(expr)
2622	return body
2623}
2624
2625func rewriteDynamicsEqExpr(f *equalityFactory, expr *Expr, result Body) Body {
2626	if !validEqAssignArgCount(expr) {
2627		return appendExpr(result, expr)
2628	}
2629	terms := expr.Terms.([]*Term)
2630	result, terms[1] = rewriteDynamicsInTerm(expr, f, terms[1], result)
2631	result, terms[2] = rewriteDynamicsInTerm(expr, f, terms[2], result)
2632	return appendExpr(result, expr)
2633}
2634
2635func rewriteDynamicsCallExpr(f *equalityFactory, expr *Expr, result Body) Body {
2636	terms := expr.Terms.([]*Term)
2637	for i := 1; i < len(terms); i++ {
2638		result, terms[i] = rewriteDynamicsOne(expr, f, terms[i], result)
2639	}
2640	return appendExpr(result, expr)
2641}
2642
2643func rewriteDynamicsTermExpr(f *equalityFactory, expr *Expr, result Body) Body {
2644	term := expr.Terms.(*Term)
2645	result, expr.Terms = rewriteDynamicsInTerm(expr, f, term, result)
2646	return appendExpr(result, expr)
2647}
2648
2649func rewriteDynamicsInTerm(original *Expr, f *equalityFactory, term *Term, result Body) (Body, *Term) {
2650	switch v := term.Value.(type) {
2651	case Ref:
2652		for i := 1; i < len(v); i++ {
2653			result, v[i] = rewriteDynamicsOne(original, f, v[i], result)
2654		}
2655	case *ArrayComprehension:
2656		v.Body = rewriteDynamics(f, v.Body)
2657	case *SetComprehension:
2658		v.Body = rewriteDynamics(f, v.Body)
2659	case *ObjectComprehension:
2660		v.Body = rewriteDynamics(f, v.Body)
2661	default:
2662		result, term = rewriteDynamicsOne(original, f, term, result)
2663	}
2664	return result, term
2665}
2666
2667func rewriteDynamicsOne(original *Expr, f *equalityFactory, term *Term, result Body) (Body, *Term) {
2668	switch v := term.Value.(type) {
2669	case Ref:
2670		for i := 1; i < len(v); i++ {
2671			result, v[i] = rewriteDynamicsOne(original, f, v[i], result)
2672		}
2673		generated := f.Generate(term)
2674		generated.With = original.With
2675		result.Append(generated)
2676		return result, result[len(result)-1].Operand(0)
2677	case Array:
2678		for i := 0; i < len(v); i++ {
2679			result, v[i] = rewriteDynamicsOne(original, f, v[i], result)
2680		}
2681		return result, term
2682	case Object:
2683		cpy := NewObject()
2684		for _, key := range v.Keys() {
2685			value := v.Get(key)
2686			result, key = rewriteDynamicsOne(original, f, key, result)
2687			result, value = rewriteDynamicsOne(original, f, value, result)
2688			cpy.Insert(key, value)
2689		}
2690		return result, NewTerm(cpy).SetLocation(term.Location)
2691	case Set:
2692		cpy := NewSet()
2693		for _, term := range v.Slice() {
2694			var rw *Term
2695			result, rw = rewriteDynamicsOne(original, f, term, result)
2696			cpy.Add(rw)
2697		}
2698		return result, NewTerm(cpy).SetLocation(term.Location)
2699	case *ArrayComprehension:
2700		var extra *Expr
2701		v.Body, extra = rewriteDynamicsComprehensionBody(original, f, v.Body, term)
2702		result.Append(extra)
2703		return result, result[len(result)-1].Operand(0)
2704	case *SetComprehension:
2705		var extra *Expr
2706		v.Body, extra = rewriteDynamicsComprehensionBody(original, f, v.Body, term)
2707		result.Append(extra)
2708		return result, result[len(result)-1].Operand(0)
2709	case *ObjectComprehension:
2710		var extra *Expr
2711		v.Body, extra = rewriteDynamicsComprehensionBody(original, f, v.Body, term)
2712		result.Append(extra)
2713		return result, result[len(result)-1].Operand(0)
2714	}
2715	return result, term
2716}
2717
2718func rewriteDynamicsComprehensionBody(original *Expr, f *equalityFactory, body Body, term *Term) (Body, *Expr) {
2719	body = rewriteDynamics(f, body)
2720	generated := f.Generate(term)
2721	generated.With = original.With
2722	return body, generated
2723}
2724
2725func rewriteExprTermsInHead(gen *localVarGenerator, rule *Rule) {
2726	if rule.Head.Key != nil {
2727		support, output := expandExprTerm(gen, rule.Head.Key)
2728		for i := range support {
2729			rule.Body.Append(support[i])
2730		}
2731		rule.Head.Key = output
2732	}
2733	if rule.Head.Value != nil {
2734		support, output := expandExprTerm(gen, rule.Head.Value)
2735		for i := range support {
2736			rule.Body.Append(support[i])
2737		}
2738		rule.Head.Value = output
2739	}
2740}
2741
2742func rewriteExprTermsInBody(gen *localVarGenerator, body Body) Body {
2743	cpy := make(Body, 0, len(body))
2744	for i := 0; i < len(body); i++ {
2745		for _, expr := range expandExpr(gen, body[i]) {
2746			cpy.Append(expr)
2747		}
2748	}
2749	return cpy
2750}
2751
2752func expandExpr(gen *localVarGenerator, expr *Expr) (result []*Expr) {
2753	for i := range expr.With {
2754		extras, value := expandExprTerm(gen, expr.With[i].Value)
2755		expr.With[i].Value = value
2756		result = append(result, extras...)
2757	}
2758	switch terms := expr.Terms.(type) {
2759	case *Term:
2760		extras, term := expandExprTerm(gen, terms)
2761		if len(expr.With) > 0 {
2762			for i := range extras {
2763				extras[i].With = expr.With
2764			}
2765		}
2766		result = append(result, extras...)
2767		expr.Terms = term
2768		result = append(result, expr)
2769	case []*Term:
2770		for i := 1; i < len(terms); i++ {
2771			var extras []*Expr
2772			extras, terms[i] = expandExprTerm(gen, terms[i])
2773			if len(expr.With) > 0 {
2774				for i := range extras {
2775					extras[i].With = expr.With
2776				}
2777			}
2778			result = append(result, extras...)
2779		}
2780		result = append(result, expr)
2781	}
2782	return
2783}
2784
2785func expandExprTerm(gen *localVarGenerator, term *Term) (support []*Expr, output *Term) {
2786	output = term
2787	switch v := term.Value.(type) {
2788	case Call:
2789		for i := 1; i < len(v); i++ {
2790			var extras []*Expr
2791			extras, v[i] = expandExprTerm(gen, v[i])
2792			support = append(support, extras...)
2793		}
2794		output = NewTerm(gen.Generate()).SetLocation(term.Location)
2795		expr := v.MakeExpr(output).SetLocation(term.Location)
2796		expr.Generated = true
2797		support = append(support, expr)
2798	case Ref:
2799		support = expandExprRef(gen, v)
2800	case Array:
2801		support = expandExprTermSlice(gen, v)
2802	case Object:
2803		cpy, _ := v.Map(func(k, v *Term) (*Term, *Term, error) {
2804			extras1, expandedKey := expandExprTerm(gen, k)
2805			extras2, expandedValue := expandExprTerm(gen, v)
2806			support = append(support, extras1...)
2807			support = append(support, extras2...)
2808			return expandedKey, expandedValue, nil
2809		})
2810		output = NewTerm(cpy).SetLocation(term.Location)
2811	case Set:
2812		cpy, _ := v.Map(func(x *Term) (*Term, error) {
2813			extras, expanded := expandExprTerm(gen, x)
2814			support = append(support, extras...)
2815			return expanded, nil
2816		})
2817		output = NewTerm(cpy).SetLocation(term.Location)
2818	case *ArrayComprehension:
2819		support, term := expandExprTerm(gen, v.Term)
2820		for i := range support {
2821			v.Body.Append(support[i])
2822		}
2823		v.Term = term
2824		v.Body = rewriteExprTermsInBody(gen, v.Body)
2825	case *SetComprehension:
2826		support, term := expandExprTerm(gen, v.Term)
2827		for i := range support {
2828			v.Body.Append(support[i])
2829		}
2830		v.Term = term
2831		v.Body = rewriteExprTermsInBody(gen, v.Body)
2832	case *ObjectComprehension:
2833		support, key := expandExprTerm(gen, v.Key)
2834		for i := range support {
2835			v.Body.Append(support[i])
2836		}
2837		v.Key = key
2838		support, value := expandExprTerm(gen, v.Value)
2839		for i := range support {
2840			v.Body.Append(support[i])
2841		}
2842		v.Value = value
2843		v.Body = rewriteExprTermsInBody(gen, v.Body)
2844	}
2845	return
2846}
2847
2848func expandExprRef(gen *localVarGenerator, v []*Term) (support []*Expr) {
2849	// Start by calling a normal expandExprTerm on all terms.
2850	support = expandExprTermSlice(gen, v)
2851
2852	// Rewrite references in order to support indirect references.  We rewrite
2853	// e.g.
2854	//
2855	//     [1, 2, 3][i]
2856	//
2857	// to
2858	//
2859	//     __local_var = [1, 2, 3]
2860	//     __local_var[i]
2861	//
2862	// to support these.  This only impacts the reference subject, i.e. the
2863	// first item in the slice.
2864	var subject = v[0]
2865	switch subject.Value.(type) {
2866	case Array, Object, Set, *ArrayComprehension, *SetComprehension, *ObjectComprehension, Call:
2867		f := newEqualityFactory(gen)
2868		assignToLocal := f.Generate(subject)
2869		support = append(support, assignToLocal)
2870		v[0] = assignToLocal.Operand(0)
2871	}
2872	return
2873}
2874
2875func expandExprTermSlice(gen *localVarGenerator, v []*Term) (support []*Expr) {
2876	for i := 0; i < len(v); i++ {
2877		var extras []*Expr
2878		extras, v[i] = expandExprTerm(gen, v[i])
2879		support = append(support, extras...)
2880	}
2881	return
2882}
2883
// localDeclaredVars tracks the variables seen and declared while local
// variables are rewritten. It is a stack of scopes: Push adds a scope, Pop
// removes the innermost one, and the last element of vars is the current
// scope.
type localDeclaredVars struct {
	vars []*declaredVarSet

	// rewritten contains a mapping of *all* user-defined variables
	// that have been rewritten (keyed by the generated name), whereas
	// vars contains the state from the current query (not any nested
	// queries, and all vars seen).
	rewritten map[Var]Var
}
2893
// varOccurrence describes how a variable has been encountered in a scope.
type varOccurrence int

const (
	// newVar is the zero value; it is what a lookup yields for a variable
	// that has no recorded occurrence.
	newVar varOccurrence = iota
	// argVar: redeclaring such a variable is reported as "arg ... redeclared"
	// (presumably recorded for rule arguments -- confirm where it is set).
	argVar
	// seenVar marks a variable that was referenced before any declaration.
	seenVar
	// assignedVar marks a variable introduced on the left of an assignment.
	assignedVar
	// declaredVar marks a variable introduced by a some declaration.
	declaredVar
)
2903
// declaredVarSet holds the variable bookkeeping for a single scope.
type declaredVarSet struct {
	vs         map[Var]Var           // original name -> name used after rewriting
	reverse    map[Var]Var           // name used after rewriting -> original name
	occurrence map[Var]varOccurrence // original name -> how it was encountered
}
2909
2910func newDeclaredVarSet() *declaredVarSet {
2911	return &declaredVarSet{
2912		vs:         map[Var]Var{},
2913		reverse:    map[Var]Var{},
2914		occurrence: map[Var]varOccurrence{},
2915	}
2916}
2917
2918func newLocalDeclaredVars() *localDeclaredVars {
2919	return &localDeclaredVars{
2920		vars:      []*declaredVarSet{newDeclaredVarSet()},
2921		rewritten: map[Var]Var{},
2922	}
2923}
2924
2925func (s *localDeclaredVars) Push() {
2926	s.vars = append(s.vars, newDeclaredVarSet())
2927}
2928
2929func (s *localDeclaredVars) Pop() *declaredVarSet {
2930	sl := s.vars
2931	curr := sl[len(sl)-1]
2932	s.vars = sl[:len(sl)-1]
2933	return curr
2934}
2935
2936func (s localDeclaredVars) Peek() *declaredVarSet {
2937	return s.vars[len(s.vars)-1]
2938}
2939
2940func (s localDeclaredVars) Insert(x, y Var, occurrence varOccurrence) {
2941	elem := s.vars[len(s.vars)-1]
2942	elem.vs[x] = y
2943	elem.reverse[y] = x
2944	elem.occurrence[x] = occurrence
2945
2946	// If the variable has been rewritten (where x != y, with y being
2947	// the generated value), store it in the map of rewritten vars.
2948	// Assume that the generated values are unique for the compilation.
2949	if !x.Equal(y) {
2950		s.rewritten[y] = x
2951	}
2952}
2953
2954func (s localDeclaredVars) Declared(x Var) (y Var, ok bool) {
2955	for i := len(s.vars) - 1; i >= 0; i-- {
2956		if y, ok = s.vars[i].vs[x]; ok {
2957			return
2958		}
2959	}
2960	return
2961}
2962
2963// Occurrence returns a flag that indicates whether x has occurred in the
2964// current scope.
2965func (s localDeclaredVars) Occurrence(x Var) varOccurrence {
2966	return s.vars[len(s.vars)-1].occurrence[x]
2967}
2968
2969// GlobalOccurrence returns a flag that indicates whether x has occurred in the
2970// global scope.
2971func (s localDeclaredVars) GlobalOccurrence(x Var) (varOccurrence, bool) {
2972	for i := len(s.vars) - 1; i >= 0; i-- {
2973		if occ, ok := s.vars[i].occurrence[x]; ok {
2974			return occ, true
2975		}
2976	}
2977	return newVar, false
2978}
2979
2980// rewriteLocalVars rewrites bodies to remove assignment/declaration
2981// expressions. For example:
2982//
2983// a := 1; p[a]
2984//
2985// Is rewritten to:
2986//
2987// __local0__ = 1; p[__local0__]
2988//
2989// During rewriting, assignees are validated to prevent use before declaration.
2990func rewriteLocalVars(g *localVarGenerator, stack *localDeclaredVars, used VarSet, body Body) (Body, map[Var]Var, Errors) {
2991	var errs Errors
2992	body, errs = rewriteDeclaredVarsInBody(g, stack, used, body, errs)
2993	return body, stack.Pop().vs, errs
2994}
2995
2996func rewriteDeclaredVarsInBody(g *localVarGenerator, stack *localDeclaredVars, used VarSet, body Body, errs Errors) (Body, Errors) {
2997
2998	var cpy Body
2999
3000	for i := range body {
3001		var expr *Expr
3002		if body[i].IsAssignment() {
3003			expr, errs = rewriteDeclaredAssignment(g, stack, body[i], errs)
3004		} else if decl, ok := body[i].Terms.(*SomeDecl); ok {
3005			errs = rewriteSomeDeclStatement(g, stack, decl, errs)
3006		} else {
3007			expr, errs = rewriteDeclaredVarsInExpr(g, stack, body[i], errs)
3008		}
3009		if expr != nil {
3010			cpy.Append(expr)
3011		}
3012	}
3013
3014	// If the body only contained a var statement it will be empty at this
3015	// point. Append true to the body to ensure that it's non-empty (zero length
3016	// bodies are not supported.)
3017	if len(cpy) == 0 {
3018		cpy.Append(NewExpr(BooleanTerm(true)))
3019	}
3020
3021	return cpy, checkUnusedDeclaredVars(body[0].Loc(), stack, used, cpy, errs)
3022}
3023
3024func checkUnusedDeclaredVars(loc *Location, stack *localDeclaredVars, used VarSet, cpy Body, errs Errors) Errors {
3025
3026	// NOTE(tsandall): Do not generate more errors if there are existing
3027	// declaration errors.
3028	if len(errs) > 0 {
3029		return errs
3030	}
3031
3032	dvs := stack.Peek()
3033	declared := NewVarSet()
3034
3035	for v, occ := range dvs.occurrence {
3036		if occ == declaredVar {
3037			declared.Add(dvs.vs[v])
3038		}
3039	}
3040
3041	bodyvars := cpy.Vars(VarVisitorParams{})
3042
3043	for v := range used {
3044		if gv, ok := stack.Declared(v); ok {
3045			bodyvars.Add(gv)
3046		} else {
3047			bodyvars.Add(v)
3048		}
3049	}
3050
3051	unused := declared.Diff(bodyvars).Diff(used)
3052
3053	for _, gv := range unused.Sorted() {
3054		errs = append(errs, NewError(CompileErr, loc, "declared var %v unused", dvs.reverse[gv]))
3055	}
3056
3057	return errs
3058}
3059
3060func rewriteSomeDeclStatement(g *localVarGenerator, stack *localDeclaredVars, decl *SomeDecl, errs Errors) Errors {
3061	for i := range decl.Symbols {
3062		v := decl.Symbols[i].Value.(Var)
3063		if _, err := rewriteDeclaredVar(g, stack, v, declaredVar); err != nil {
3064			errs = append(errs, NewError(CompileErr, decl.Loc(), err.Error()))
3065		}
3066	}
3067	return errs
3068}
3069
3070func rewriteDeclaredVarsInExpr(g *localVarGenerator, stack *localDeclaredVars, expr *Expr, errs Errors) (*Expr, Errors) {
3071	vis := NewGenericVisitor(func(x interface{}) bool {
3072		var stop bool
3073		switch x := x.(type) {
3074		case *Term:
3075			stop, errs = rewriteDeclaredVarsInTerm(g, stack, x, errs)
3076		case *With:
3077			_, errs = rewriteDeclaredVarsInTerm(g, stack, x.Value, errs)
3078			stop = true
3079		}
3080		return stop
3081	})
3082	vis.Walk(expr)
3083	return expr, errs
3084}
3085
3086func rewriteDeclaredAssignment(g *localVarGenerator, stack *localDeclaredVars, expr *Expr, errs Errors) (*Expr, Errors) {
3087
3088	if expr.Negated {
3089		errs = append(errs, NewError(CompileErr, expr.Location, "cannot assign vars inside negated expression"))
3090		return expr, errs
3091	}
3092
3093	numErrsBefore := len(errs)
3094
3095	if !validEqAssignArgCount(expr) {
3096		return expr, errs
3097	}
3098
3099	// Rewrite terms on right hand side capture seen vars and recursively
3100	// process comprehensions before left hand side is processed. Also
3101	// rewrite with modifier.
3102	errs = rewriteDeclaredVarsInTermRecursive(g, stack, expr.Operand(1), errs)
3103
3104	for _, w := range expr.With {
3105		errs = rewriteDeclaredVarsInTermRecursive(g, stack, w.Value, errs)
3106	}
3107
3108	// Rewrite vars on left hand side with unique names. Catch redeclaration
3109	// and invalid term types here.
3110	var vis func(t *Term) bool
3111
3112	vis = func(t *Term) bool {
3113		switch v := t.Value.(type) {
3114		case Var:
3115			if gv, err := rewriteDeclaredVar(g, stack, v, assignedVar); err != nil {
3116				errs = append(errs, NewError(CompileErr, t.Location, err.Error()))
3117			} else {
3118				t.Value = gv
3119			}
3120			return true
3121		case Array:
3122			return false
3123		case Object:
3124			v.Foreach(func(_, v *Term) {
3125				WalkTerms(v, vis)
3126			})
3127			return true
3128		case Ref:
3129			if RootDocumentRefs.Contains(t) {
3130				if gv, err := rewriteDeclaredVar(g, stack, v[0].Value.(Var), assignedVar); err != nil {
3131					errs = append(errs, NewError(CompileErr, t.Location, err.Error()))
3132				} else {
3133					t.Value = gv
3134				}
3135				return true
3136			}
3137		}
3138		errs = append(errs, NewError(CompileErr, t.Location, "cannot assign to %v", TypeName(t.Value)))
3139		return true
3140	}
3141
3142	WalkTerms(expr.Operand(0), vis)
3143
3144	if len(errs) == numErrsBefore {
3145		loc := expr.Operator()[0].Location
3146		expr.SetOperator(RefTerm(VarTerm(Equality.Name).SetLocation(loc)).SetLocation(loc))
3147	}
3148
3149	return expr, errs
3150}
3151
// rewriteDeclaredVarsInTerm rewrites the declared variables of a single term
// in place. The returned bool is handed back by the callers in this file to
// the walker as the result of the visit callback: it is true when the term
// has been fully handled here (variables, root document refs, objects, sets
// and comprehensions) and false when it has not (other refs and every
// remaining kind of term).
func rewriteDeclaredVarsInTerm(g *localVarGenerator, stack *localDeclaredVars, term *Term, errs Errors) (bool, Errors) {
	switch v := term.Value.(type) {
	case Var:
		if gv, ok := stack.Declared(v); ok {
			term.Value = gv
		} else if stack.Occurrence(v) == newVar {
			// First sighting in the current scope: remember that the var
			// was referenced so that a later declaration can be rejected.
			stack.Insert(v, v, seenVar)
		}
	case Ref:
		if RootDocumentRefs.Contains(term) {
			x := v[0].Value.(Var)
			// Replace the root document ref by the local of the same name
			// if that name has been recorded as anything other than
			// merely seen.
			if occ, ok := stack.GlobalOccurrence(x); ok && occ != seenVar {
				gv, _ := stack.Declared(x)
				term.Value = gv
			}

			return true, errs
		}
		return false, errs
	case Object:
		// Keys are copied before being rewritten; values are rewritten in
		// place. The mapping function never fails, so the error is ignored.
		cpy, _ := v.Map(func(k, v *Term) (*Term, *Term, error) {
			kcpy := k.Copy()
			errs = rewriteDeclaredVarsInTermRecursive(g, stack, kcpy, errs)
			errs = rewriteDeclaredVarsInTermRecursive(g, stack, v, errs)
			return kcpy, v, nil
		})
		term.Value = cpy
	case Set:
		// Elements are copied before being rewritten. The mapping function
		// never fails, so the error is ignored.
		cpy, _ := v.Map(func(elem *Term) (*Term, error) {
			elemcpy := elem.Copy()
			errs = rewriteDeclaredVarsInTermRecursive(g, stack, elemcpy, errs)
			return elemcpy, nil
		})
		term.Value = cpy
	case *ArrayComprehension:
		errs = rewriteDeclaredVarsInArrayComprehension(g, stack, v, errs)
	case *SetComprehension:
		errs = rewriteDeclaredVarsInSetComprehension(g, stack, v, errs)
	case *ObjectComprehension:
		errs = rewriteDeclaredVarsInObjectComprehension(g, stack, v, errs)
	default:
		return false, errs
	}
	return true, errs
}
3197
3198func rewriteDeclaredVarsInTermRecursive(g *localVarGenerator, stack *localDeclaredVars, term *Term, errs Errors) Errors {
3199	WalkNodes(term, func(n Node) bool {
3200		var stop bool
3201		switch n := n.(type) {
3202		case *With:
3203			_, errs = rewriteDeclaredVarsInTerm(g, stack, n.Value, errs)
3204			stop = true
3205		case *Term:
3206			stop, errs = rewriteDeclaredVarsInTerm(g, stack, n, errs)
3207		}
3208		return stop
3209	})
3210	return errs
3211}
3212
3213func rewriteDeclaredVarsInArrayComprehension(g *localVarGenerator, stack *localDeclaredVars, v *ArrayComprehension, errs Errors) Errors {
3214	stack.Push()
3215	v.Body, errs = rewriteDeclaredVarsInBody(g, stack, nil, v.Body, errs)
3216	errs = rewriteDeclaredVarsInTermRecursive(g, stack, v.Term, errs)
3217	stack.Pop()
3218	return errs
3219}
3220
3221func rewriteDeclaredVarsInSetComprehension(g *localVarGenerator, stack *localDeclaredVars, v *SetComprehension, errs Errors) Errors {
3222	stack.Push()
3223	v.Body, errs = rewriteDeclaredVarsInBody(g, stack, nil, v.Body, errs)
3224	errs = rewriteDeclaredVarsInTermRecursive(g, stack, v.Term, errs)
3225	stack.Pop()
3226	return errs
3227}
3228
3229func rewriteDeclaredVarsInObjectComprehension(g *localVarGenerator, stack *localDeclaredVars, v *ObjectComprehension, errs Errors) Errors {
3230	stack.Push()
3231	v.Body, errs = rewriteDeclaredVarsInBody(g, stack, nil, v.Body, errs)
3232	errs = rewriteDeclaredVarsInTermRecursive(g, stack, v.Key, errs)
3233	errs = rewriteDeclaredVarsInTermRecursive(g, stack, v.Value, errs)
3234	stack.Pop()
3235	return errs
3236}
3237
3238func rewriteDeclaredVar(g *localVarGenerator, stack *localDeclaredVars, v Var, occ varOccurrence) (gv Var, err error) {
3239	switch stack.Occurrence(v) {
3240	case seenVar:
3241		return gv, fmt.Errorf("var %v referenced above", v)
3242	case assignedVar:
3243		return gv, fmt.Errorf("var %v assigned above", v)
3244	case declaredVar:
3245		return gv, fmt.Errorf("var %v declared above", v)
3246	case argVar:
3247		return gv, fmt.Errorf("arg %v redeclared", v)
3248	}
3249	gv = g.Generate()
3250	stack.Insert(v, gv, occ)
3251	return
3252}
3253
3254// rewriteWithModifiersInBody will rewrite the body so that with modifiers do
3255// not contain terms that require evaluation as values. If this function
3256// encounters an invalid with modifier target then it will raise an error.
3257func rewriteWithModifiersInBody(c *Compiler, f *equalityFactory, body Body) (Body, *Error) {
3258	var result Body
3259	for i := range body {
3260		exprs, err := rewriteWithModifier(c, f, body[i])
3261		if err != nil {
3262			return nil, err
3263		}
3264		if len(exprs) > 0 {
3265			for _, expr := range exprs {
3266				result.Append(expr)
3267			}
3268		} else {
3269			result.Append(body[i])
3270		}
3271	}
3272	return result, nil
3273}
3274
3275func rewriteWithModifier(c *Compiler, f *equalityFactory, expr *Expr) ([]*Expr, *Error) {
3276
3277	var result []*Expr
3278	for i := range expr.With {
3279		err := validateTarget(c, expr.With[i].Target)
3280		if err != nil {
3281			return nil, err
3282		}
3283
3284		if requiresEval(expr.With[i].Value) {
3285			eq := f.Generate(expr.With[i].Value)
3286			result = append(result, eq)
3287			expr.With[i].Value = eq.Operand(0)
3288		}
3289	}
3290
3291	// If any of the with modifiers in this expression were rewritten then result
3292	// will be non-empty. In this case, the expression will have been modified and
3293	// it should also be added to the result.
3294	if len(result) > 0 {
3295		result = append(result, expr)
3296	}
3297	return result, nil
3298}
3299
3300func validateTarget(c *Compiler, term *Term) *Error {
3301	if !isInputRef(term) && !isDataRef(term) {
3302		return NewError(TypeErr, term.Location, "with keyword target must start with %v or %v", InputRootDocument, DefaultRootDocument)
3303	}
3304
3305	if isDataRef(term) {
3306		ref := term.Value.(Ref)
3307		node := c.RuleTree
3308		for i := 0; i < len(ref)-1; i++ {
3309			child := node.Child(ref[i].Value)
3310			if child == nil {
3311				break
3312			} else if len(child.Values) > 0 {
3313				return NewError(CompileErr, term.Loc(), "with keyword cannot partially replace virtual document(s)")
3314			}
3315			node = child
3316		}
3317
3318		if node != nil {
3319			if child := node.Child(ref[len(ref)-1].Value); child != nil {
3320				for _, value := range child.Values {
3321					if len(value.(*Rule).Head.Args) > 0 {
3322						return NewError(CompileErr, term.Loc(), "with keyword cannot replace functions")
3323					}
3324				}
3325			}
3326		}
3327
3328	}
3329	return nil
3330}
3331
3332func isInputRef(term *Term) bool {
3333	if ref, ok := term.Value.(Ref); ok {
3334		if ref.HasPrefix(InputRootRef) {
3335			return true
3336		}
3337	}
3338	return false
3339}
3340
3341func isDataRef(term *Term) bool {
3342	if ref, ok := term.Value.(Ref); ok {
3343		if ref.HasPrefix(DefaultRootRef) {
3344			return true
3345		}
3346	}
3347	return false
3348}
3349
3350func isVirtual(node *TreeNode, ref Ref) bool {
3351	for i := 0; i < len(ref); i++ {
3352		child := node.Child(ref[i].Value)
3353		if child == nil {
3354			return false
3355		} else if len(child.Values) > 0 {
3356			return true
3357		}
3358		node = child
3359	}
3360	return true
3361}
3362
3363func safetyErrorSlice(unsafe unsafeVars) (result Errors) {
3364
3365	if len(unsafe) == 0 {
3366		return
3367	}
3368
3369	for _, pair := range unsafe.Vars() {
3370		if !pair.Var.IsGenerated() {
3371			result = append(result, NewError(UnsafeVarErr, pair.Loc, "var %v is unsafe", pair.Var))
3372		}
3373	}
3374
3375	if len(result) > 0 {
3376		return
3377	}
3378
3379	// If the expression contains unsafe generated variables, report which
3380	// expressions are unsafe instead of the variables that are unsafe (since
3381	// the latter are not meaningful to the user.)
3382	pairs := unsafe.Slice()
3383
3384	sort.Slice(pairs, func(i, j int) bool {
3385		return pairs[i].Expr.Location.Compare(pairs[j].Expr.Location) < 0
3386	})
3387
3388	// Report at most one error per generated variable.
3389	seen := NewVarSet()
3390
3391	for _, expr := range pairs {
3392		before := len(seen)
3393		for v := range expr.Vars {
3394			if v.IsGenerated() {
3395				seen.Add(v)
3396			}
3397		}
3398		if len(seen) > before {
3399			result = append(result, NewError(UnsafeVarErr, expr.Expr.Location, "expression is unsafe"))
3400		}
3401	}
3402
3403	return
3404}
3405
3406func checkUnsafeBuiltins(unsafeBuiltinsMap map[string]struct{}, node interface{}) Errors {
3407	errs := make(Errors, 0)
3408	WalkExprs(node, func(x *Expr) bool {
3409		if x.IsCall() {
3410			operator := x.Operator().String()
3411			if _, ok := unsafeBuiltinsMap[operator]; ok {
3412				errs = append(errs, NewError(TypeErr, x.Loc(), "unsafe built-in function calls in expression: %v", operator))
3413			}
3414		}
3415		return false
3416	})
3417	return errs
3418}
3419
3420func rewriteVarsInRef(vars ...map[Var]Var) func(Ref) Ref {
3421	return func(node Ref) Ref {
3422		i, _ := TransformVars(node, func(v Var) (Value, error) {
3423			for _, m := range vars {
3424				if u, ok := m[v]; ok {
3425					return u, nil
3426				}
3427			}
3428			return v, nil
3429		})
3430		return i.(Ref)
3431	}
3432}
3433
3434func rewriteVarsNop(node Ref) Ref {
3435	return node
3436}
3437