// Copyright 2016 The OPA Authors.  All rights reserved.
// Use of this source code is governed by an Apache2
// license that can be found in the LICENSE file.

package ast

import (
	"fmt"
	"sort"
	"strconv"
	"strings"

	"github.com/open-policy-agent/opa/metrics"
	"github.com/open-policy-agent/opa/util"
)

// CompileErrorLimitDefault is the default number of errors a compiler will
// allow before exiting.
const CompileErrorLimitDefault = 10

// errLimitReached is appended to the compiler's errors and used as the panic
// value when the error limit is hit (see Compiler.err and Compiler.compile).
var errLimitReached = NewError(CompileErr, nil, "error limit reached")

// Compiler contains the state of a compilation process.
type Compiler struct {

	// Errors contains errors that occurred during the compilation process.
	// If there are one or more errors, the compilation process is considered
	// "failed".
	Errors Errors

	// Modules contains the compiled modules. The compiled modules are the
	// output of the compilation process. If the compilation process failed,
	// there is no guarantee about the state of the modules.
	Modules map[string]*Module

	// ModuleTree organizes the modules into a tree where each node is keyed by
	// an element in the module's package path. E.g., given modules containing
	// the following package directives: "a", "a.b", "a.c", and "a.b", the
	// resulting module tree would be:
	//
	//  root
	//    |
	//    +--- data (no modules)
	//           |
	//           +--- a (1 module)
	//                |
	//                +--- b (2 modules)
	//                |
	//                +--- c (1 module)
	//
	ModuleTree *ModuleTreeNode

	// RuleTree organizes rules into a tree where each node is keyed by an
	// element in the rule's path. The rule path is the concatenation of the
	// containing package and the stringified rule name. E.g., given the
	// following module:
	//
	//  package ex
	//  p[1] { true }
	//  p[2] { true }
	//  q = true
	//
	//  root
	//    |
	//    +--- data (no rules)
	//           |
	//           +--- ex (no rules)
	//                |
	//                +--- p (2 rules)
	//                |
	//                +--- q (1 rule)
	RuleTree *TreeNode

	// Graph contains dependencies between rules. An edge (u,v) is added to the
	// graph if rule 'u' refers to the virtual document defined by 'v'.
	Graph *Graph

	// TypeEnv holds type information for values inferred by the compiler.
	TypeEnv *TypeEnv

	// RewrittenVars is a mapping of variables that have been rewritten
	// with the key being the generated name and value being the original.
	RewrittenVars map[Var]Var

	// localvargen supplies generated variable names to the rewrite stages. It
	// is created by initLocalVarGen.
	localvargen *localVarGenerator

	// moduleLoader is the optional callback installed by WithModuleLoader.
	moduleLoader ModuleLoader

	// ruleIndices holds the indices built by buildRuleIndices, keyed by rule
	// path (see RuleIndex).
	ruleIndices *util.HashMap

	// stages lists the built-in compilation stages in the order that compile
	// runs them.
	stages []struct {
		name       string
		metricName string
		f          func()
	}

	// maxErrs is the error limit enforced by err. Zero or a negative value
	// disables the limit.
	maxErrs int

	sorted []string // list of sorted module names

	// pathExists is the optional callback installed by WithPathConflictsCheck.
	pathExists func([]string) (bool, error)

	// after holds caller-registered stages, keyed by the name of the built-in
	// stage they run after (see WithStageAfter).
	after map[string][]CompilerStageDefinition

	// metrics is optional. When set, each stage is timed under its metric name.
	metrics metrics.Metrics

	// builtins maps built-in function names to their definitions.
	builtins map[string]*Builtin

	// unsafeBuiltinsMap holds the names registered through WithUnsafeBuiltins.
	unsafeBuiltinsMap map[string]struct{}
}

// CompilerStage defines the interface for stages in the compiler.
type CompilerStage func(*Compiler) *Error

// CompilerStageDefinition defines a compiler stage
type CompilerStageDefinition struct {
	Name       string
	MetricName string
	Stage      CompilerStage
}

// QueryContext contains contextual information for running an ad-hoc query.
//
// Ad-hoc queries can be run in the context of a package and imports may be
// included to provide concise access to data.
type QueryContext struct {
	Package *Package
	Imports []*Import
}

// NewQueryContext returns a new QueryContext object.
func NewQueryContext() *QueryContext {
	return &QueryContext{}
}

// WithPackage sets the pkg on qc.
127func (qc *QueryContext) WithPackage(pkg *Package) *QueryContext { 128 if qc == nil { 129 qc = NewQueryContext() 130 } 131 qc.Package = pkg 132 return qc 133} 134 135// WithImports sets the imports on qc. 136func (qc *QueryContext) WithImports(imports []*Import) *QueryContext { 137 if qc == nil { 138 qc = NewQueryContext() 139 } 140 qc.Imports = imports 141 return qc 142} 143 144// Copy returns a deep copy of qc. 145func (qc *QueryContext) Copy() *QueryContext { 146 if qc == nil { 147 return nil 148 } 149 cpy := *qc 150 if cpy.Package != nil { 151 cpy.Package = qc.Package.Copy() 152 } 153 cpy.Imports = make([]*Import, len(qc.Imports)) 154 for i := range qc.Imports { 155 cpy.Imports[i] = qc.Imports[i].Copy() 156 } 157 return &cpy 158} 159 160// QueryCompiler defines the interface for compiling ad-hoc queries. 161type QueryCompiler interface { 162 163 // Compile should be called to compile ad-hoc queries. The return value is 164 // the compiled version of the query. 165 Compile(q Body) (Body, error) 166 167 // TypeEnv returns the type environment built after running type checking 168 // on the query. 169 TypeEnv() *TypeEnv 170 171 // WithContext sets the QueryContext on the QueryCompiler. Subsequent calls 172 // to Compile will take the QueryContext into account. 173 WithContext(qctx *QueryContext) QueryCompiler 174 175 // WithUnsafeBuiltins sets the built-in functions to treat as unsafe and not 176 // allow inside of queries. By default the query compiler inherits the 177 // compiler's unsafe built-in functions. This function allows callers to 178 // override that set. If an empty (non-nil) map is provided, all built-ins 179 // are allowed. 180 WithUnsafeBuiltins(unsafe map[string]struct{}) QueryCompiler 181 182 // WithStageAfter registers a stage to run during query compilation after 183 // the named stage. 
184 WithStageAfter(after string, stage QueryCompilerStageDefinition) QueryCompiler 185 186 // RewrittenVars maps generated vars in the compiled query to vars from the 187 // parsed query. For example, given the query "input := 1" the rewritten 188 // query would be "__local0__ = 1". The mapping would then be {__local0__: input}. 189 RewrittenVars() map[Var]Var 190} 191 192// QueryCompilerStage defines the interface for stages in the query compiler. 193type QueryCompilerStage func(QueryCompiler, Body) (Body, error) 194 195// QueryCompilerStageDefinition defines a QueryCompiler stage 196type QueryCompilerStageDefinition struct { 197 Name string 198 MetricName string 199 Stage QueryCompilerStage 200} 201 202const compileStageMetricPrefex = "ast_compile_stage_" 203 204// NewCompiler returns a new empty compiler. 205func NewCompiler() *Compiler { 206 207 c := &Compiler{ 208 Modules: map[string]*Module{}, 209 TypeEnv: NewTypeEnv(), 210 RewrittenVars: map[Var]Var{}, 211 ruleIndices: util.NewHashMap(func(a, b util.T) bool { 212 r1, r2 := a.(Ref), b.(Ref) 213 return r1.Equal(r2) 214 }, func(x util.T) int { 215 return x.(Ref).Hash() 216 }), 217 maxErrs: CompileErrorLimitDefault, 218 after: map[string][]CompilerStageDefinition{}, 219 unsafeBuiltinsMap: map[string]struct{}{}, 220 } 221 222 c.ModuleTree = NewModuleTree(nil) 223 c.RuleTree = NewRuleTree(c.ModuleTree) 224 225 // Initialize the compiler with the statically compiled built-in functions. 226 // If the caller customizes the compiler, a copy will be made. 227 c.builtins = BuiltinMap 228 checker := newTypeChecker() 229 c.TypeEnv = checker.checkLanguageBuiltins(nil, c.builtins) 230 231 c.stages = []struct { 232 name string 233 metricName string 234 f func() 235 }{ 236 // Reference resolution should run first as it may be used to lazily 237 // load additional modules. If any stages run before resolution, they 238 // need to be re-run after resolution. 
239 {"ResolveRefs", "compile_stage_resolve_refs", c.resolveAllRefs}, 240 241 // The local variable generator must be initialized after references are 242 // resolved and the dynamic module loader has run but before subsequent 243 // stages that need to generate variables. 244 {"InitLocalVarGen", "compile_stage_init_local_var_gen", c.initLocalVarGen}, 245 246 {"RewriteLocalVars", "compile_stage_rewrite_local_vars", c.rewriteLocalVars}, 247 {"RewriteExprTerms", "compile_stage_rewrite_expr_terms", c.rewriteExprTerms}, 248 {"SetModuleTree", "compile_stage_set_module_tree", c.setModuleTree}, 249 {"SetRuleTree", "compile_stage_set_rule_tree", c.setRuleTree}, 250 {"SetGraph", "compile_stage_set_graph", c.setGraph}, 251 {"RewriteComprehensionTerms", "compile_stage_rewrite_comprehension_terms", c.rewriteComprehensionTerms}, 252 {"RewriteRefsInHead", "compile_stage_rewrite_refs_in_head", c.rewriteRefsInHead}, 253 {"RewriteWithValues", "compile_stage_rewrite_with_values", c.rewriteWithModifiers}, 254 {"CheckRuleConflicts", "compile_stage_check_rule_conflicts", c.checkRuleConflicts}, 255 {"CheckUndefinedFuncs", "compile_stage_check_undefined_funcs", c.checkUndefinedFuncs}, 256 {"CheckSafetyRuleHeads", "compile_stage_check_safety_rule_heads", c.checkSafetyRuleHeads}, 257 {"CheckSafetyRuleBodies", "compile_stage_check_safety_rule_bodies", c.checkSafetyRuleBodies}, 258 {"RewriteEquals", "compile_stage_rewrite_equals", c.rewriteEquals}, 259 {"RewriteDynamicTerms", "compile_stage_rewrite_dynamic_terms", c.rewriteDynamicTerms}, 260 {"CheckRecursion", "compile_stage_check_recursion", c.checkRecursion}, 261 {"CheckTypes", "compile_stage_check_types", c.checkTypes}, 262 {"CheckUnsafeBuiltins", "compile_state_check_unsafe_builtins", c.checkUnsafeBuiltins}, 263 {"BuildRuleIndices", "compile_stage_rebuild_indices", c.buildRuleIndices}, 264 } 265 266 return c 267} 268 269// SetErrorLimit sets the number of errors the compiler can encounter before it 270// quits. 
Zero or a negative number indicates no limit. 271func (c *Compiler) SetErrorLimit(limit int) *Compiler { 272 c.maxErrs = limit 273 return c 274} 275 276// WithPathConflictsCheck enables base-virtual document conflict 277// detection. The compiler will check that rules don't overlap with 278// paths that exist as determined by the provided callable. 279func (c *Compiler) WithPathConflictsCheck(fn func([]string) (bool, error)) *Compiler { 280 c.pathExists = fn 281 return c 282} 283 284// WithStageAfter registers a stage to run during compilation after 285// the named stage. 286func (c *Compiler) WithStageAfter(after string, stage CompilerStageDefinition) *Compiler { 287 c.after[after] = append(c.after[after], stage) 288 return c 289} 290 291// WithMetrics will set a metrics.Metrics and be used for profiling 292// the Compiler instance. 293func (c *Compiler) WithMetrics(metrics metrics.Metrics) *Compiler { 294 c.metrics = metrics 295 return c 296} 297 298// WithBuiltins adds a set of custom built-in functions to the compiler. 299func (c *Compiler) WithBuiltins(builtins map[string]*Builtin) *Compiler { 300 if len(builtins) == 0 { 301 return c 302 } 303 cpy := make(map[string]*Builtin, len(c.builtins)+len(builtins)) 304 for k, v := range c.builtins { 305 cpy[k] = v 306 } 307 for k, v := range builtins { 308 cpy[k] = v 309 } 310 c.builtins = cpy 311 // Build type env for custom functions and wrap existing one. 312 checker := newTypeChecker() 313 c.TypeEnv = checker.checkLanguageBuiltins(c.TypeEnv, builtins) 314 return c 315} 316 317// WithUnsafeBuiltins will add all built-ins in the map to the "blacklist". 318func (c *Compiler) WithUnsafeBuiltins(unsafeBuiltins map[string]struct{}) *Compiler { 319 for name := range unsafeBuiltins { 320 c.unsafeBuiltinsMap[name] = struct{}{} 321 } 322 return c 323} 324 325// QueryCompiler returns a new QueryCompiler object. 
326func (c *Compiler) QueryCompiler() QueryCompiler { 327 return newQueryCompiler(c) 328} 329 330// Compile runs the compilation process on the input modules. The compiled 331// version of the modules and associated data structures are stored on the 332// compiler. If the compilation process fails for any reason, the compiler will 333// contain a slice of errors. 334func (c *Compiler) Compile(modules map[string]*Module) { 335 336 c.Modules = make(map[string]*Module, len(modules)) 337 338 for k, v := range modules { 339 c.Modules[k] = v.Copy() 340 c.sorted = append(c.sorted, k) 341 } 342 343 sort.Strings(c.sorted) 344 345 c.compile() 346} 347 348// Failed returns true if a compilation error has been encountered. 349func (c *Compiler) Failed() bool { 350 return len(c.Errors) > 0 351} 352 353// GetArity returns the number of args a function referred to by ref takes. If 354// ref refers to built-in function, the built-in declaration is consulted, 355// otherwise, the ref is used to perform a ruleset lookup. 356func (c *Compiler) GetArity(ref Ref) int { 357 if bi := c.builtins[ref.String()]; bi != nil { 358 return len(bi.Decl.Args()) 359 } 360 rules := c.GetRulesExact(ref) 361 if len(rules) == 0 { 362 return -1 363 } 364 return len(rules[0].Head.Args) 365} 366 367// GetRulesExact returns a slice of rules referred to by the reference. 368// 369// E.g., given the following module: 370// 371// package a.b.c 372// 373// p[k] = v { ... } # rule1 374// p[k1] = v1 { ... } # rule2 375// 376// The following calls yield the rules on the right. 
377// 378// GetRulesExact("data.a.b.c.p") => [rule1, rule2] 379// GetRulesExact("data.a.b.c.p.x") => nil 380// GetRulesExact("data.a.b.c") => nil 381func (c *Compiler) GetRulesExact(ref Ref) (rules []*Rule) { 382 node := c.RuleTree 383 384 for _, x := range ref { 385 if node = node.Child(x.Value); node == nil { 386 return nil 387 } 388 } 389 390 return extractRules(node.Values) 391} 392 393// GetRulesForVirtualDocument returns a slice of rules that produce the virtual 394// document referred to by the reference. 395// 396// E.g., given the following module: 397// 398// package a.b.c 399// 400// p[k] = v { ... } # rule1 401// p[k1] = v1 { ... } # rule2 402// 403// The following calls yield the rules on the right. 404// 405// GetRulesForVirtualDocument("data.a.b.c.p") => [rule1, rule2] 406// GetRulesForVirtualDocument("data.a.b.c.p.x") => [rule1, rule2] 407// GetRulesForVirtualDocument("data.a.b.c") => nil 408func (c *Compiler) GetRulesForVirtualDocument(ref Ref) (rules []*Rule) { 409 410 node := c.RuleTree 411 412 for _, x := range ref { 413 if node = node.Child(x.Value); node == nil { 414 return nil 415 } 416 if len(node.Values) > 0 { 417 return extractRules(node.Values) 418 } 419 } 420 421 return extractRules(node.Values) 422} 423 424// GetRulesWithPrefix returns a slice of rules that share the prefix ref. 425// 426// E.g., given the following module: 427// 428// package a.b.c 429// 430// p[x] = y { ... } # rule1 431// p[k] = v { ... } # rule2 432// q { ... } # rule3 433// 434// The following calls yield the rules on the right. 
435// 436// GetRulesWithPrefix("data.a.b.c.p") => [rule1, rule2] 437// GetRulesWithPrefix("data.a.b.c.p.a") => nil 438// GetRulesWithPrefix("data.a.b.c") => [rule1, rule2, rule3] 439func (c *Compiler) GetRulesWithPrefix(ref Ref) (rules []*Rule) { 440 441 node := c.RuleTree 442 443 for _, x := range ref { 444 if node = node.Child(x.Value); node == nil { 445 return nil 446 } 447 } 448 449 var acc func(node *TreeNode) 450 451 acc = func(node *TreeNode) { 452 rules = append(rules, extractRules(node.Values)...) 453 for _, child := range node.Children { 454 if child.Hide { 455 continue 456 } 457 acc(child) 458 } 459 } 460 461 acc(node) 462 463 return rules 464} 465 466func extractRules(s []util.T) (rules []*Rule) { 467 for _, r := range s { 468 rules = append(rules, r.(*Rule)) 469 } 470 return rules 471} 472 473// GetRules returns a slice of rules that are referred to by ref. 474// 475// E.g., given the following module: 476// 477// package a.b.c 478// 479// p[x] = y { q[x] = y; ... } # rule1 480// q[x] = y { ... } # rule2 481// 482// The following calls yield the rules on the right. 483// 484// GetRules("data.a.b.c.p") => [rule1] 485// GetRules("data.a.b.c.p.x") => [rule1] 486// GetRules("data.a.b.c.q") => [rule2] 487// GetRules("data.a.b.c") => [rule1, rule2] 488// GetRules("data.a.b.d") => nil 489func (c *Compiler) GetRules(ref Ref) (rules []*Rule) { 490 491 set := map[*Rule]struct{}{} 492 493 for _, rule := range c.GetRulesForVirtualDocument(ref) { 494 set[rule] = struct{}{} 495 } 496 497 for _, rule := range c.GetRulesWithPrefix(ref) { 498 set[rule] = struct{}{} 499 } 500 501 for rule := range set { 502 rules = append(rules, rule) 503 } 504 505 return rules 506} 507 508// GetRulesDynamic returns a slice of rules that could be referred to by a ref. 509// When parts of the ref are statically known, we use that information to narrow 510// down which rules the ref could refer to, but in the most general case this 511// will be an over-approximation. 
512// 513// E.g., given the following modules: 514// 515// package a.b.c 516// 517// r1 = 1 # rule1 518// 519// and: 520// 521// package a.d.c 522// 523// r2 = 2 # rule2 524// 525// The following calls yield the rules on the right. 526// 527// GetRulesDynamic("data.a[x].c[y]") => [rule1, rule2] 528// GetRulesDynamic("data.a[x].c.r2") => [rule2] 529// GetRulesDynamic("data.a.b[x][y]") => [rule1] 530func (c *Compiler) GetRulesDynamic(ref Ref) (rules []*Rule) { 531 node := c.RuleTree 532 533 set := map[*Rule]struct{}{} 534 var walk func(node *TreeNode, i int) 535 walk = func(node *TreeNode, i int) { 536 if i >= len(ref) { 537 // We've reached the end of the reference and want to collect everything 538 // under this "prefix". 539 node.DepthFirst(func(descendant *TreeNode) bool { 540 insertRules(set, descendant.Values) 541 return descendant.Hide 542 }) 543 } else if i == 0 || IsConstant(ref[i].Value) { 544 // The head of the ref is always grounded. In case another part of the 545 // ref is also grounded, we can lookup the exact child. If it's not found 546 // we can immediately return... 547 if child := node.Child(ref[i].Value); child == nil { 548 return 549 } else if len(child.Values) > 0 { 550 // If there are any rules at this position, it's what the ref would 551 // refer to. We can just append those and stop here. 552 insertRules(set, child.Values) 553 } else { 554 // Otherwise, we continue using the child node. 555 walk(child, i+1) 556 } 557 } else { 558 // This part of the ref is a dynamic term. We can't know what it refers 559 // to and will just need to try all of the children. 560 for _, child := range node.Children { 561 if child.Hide { 562 continue 563 } 564 insertRules(set, child.Values) 565 walk(child, i+1) 566 } 567 } 568 } 569 570 walk(node, 0) 571 for rule := range set { 572 rules = append(rules, rule) 573 } 574 return rules 575} 576 577// Utility: add all rule values to the set. 
578func insertRules(set map[*Rule]struct{}, rules []util.T) { 579 for _, rule := range rules { 580 set[rule.(*Rule)] = struct{}{} 581 } 582} 583 584// RuleIndex returns a RuleIndex built for the rule set referred to by path. 585// The path must refer to the rule set exactly, i.e., given a rule set at path 586// data.a.b.c.p, refs data.a.b.c.p.x and data.a.b.c would not return a 587// RuleIndex built for the rule. 588func (c *Compiler) RuleIndex(path Ref) RuleIndex { 589 r, ok := c.ruleIndices.Get(path) 590 if !ok { 591 return nil 592 } 593 return r.(RuleIndex) 594} 595 596// ModuleLoader defines the interface that callers can implement to enable lazy 597// loading of modules during compilation. 598type ModuleLoader func(resolved map[string]*Module) (parsed map[string]*Module, err error) 599 600// WithModuleLoader sets f as the ModuleLoader on the compiler. 601// 602// The compiler will invoke the ModuleLoader after resolving all references in 603// the current set of input modules. The ModuleLoader can return a new 604// collection of parsed modules that are to be included in the compilation 605// process. This process will repeat until the ModuleLoader returns an empty 606// collection or an error. If an error is returned, compilation will stop 607// immediately. 608func (c *Compiler) WithModuleLoader(f ModuleLoader) *Compiler { 609 c.moduleLoader = f 610 return c 611} 612 613// buildRuleIndices constructs indices for rules. 614func (c *Compiler) buildRuleIndices() { 615 616 c.RuleTree.DepthFirst(func(node *TreeNode) bool { 617 if len(node.Values) == 0 { 618 return false 619 } 620 index := newBaseDocEqIndex(func(ref Ref) bool { 621 return isVirtual(c.RuleTree, ref.GroundPrefix()) 622 }) 623 if rules := extractRules(node.Values); index.Build(rules) { 624 c.ruleIndices.Put(rules[0].Path(), index) 625 } 626 return false 627 }) 628 629} 630 631// checkRecursion ensures that there are no recursive definitions, i.e., there are 632// no cycles in the Graph. 
633func (c *Compiler) checkRecursion() { 634 eq := func(a, b util.T) bool { 635 return a.(*Rule) == b.(*Rule) 636 } 637 638 c.RuleTree.DepthFirst(func(node *TreeNode) bool { 639 for _, rule := range node.Values { 640 for node := rule.(*Rule); node != nil; node = node.Else { 641 c.checkSelfPath(node.Loc(), eq, node, node) 642 } 643 } 644 return false 645 }) 646} 647 648func (c *Compiler) checkSelfPath(loc *Location, eq func(a, b util.T) bool, a, b util.T) { 649 tr := NewGraphTraversal(c.Graph) 650 if p := util.DFSPath(tr, eq, a, b); len(p) > 0 { 651 n := []string{} 652 for _, x := range p { 653 n = append(n, astNodeToString(x)) 654 } 655 c.err(NewError(RecursionErr, loc, "rule %v is recursive: %v", astNodeToString(a), strings.Join(n, " -> "))) 656 } 657} 658 659func astNodeToString(x interface{}) string { 660 switch x := x.(type) { 661 case *Rule: 662 return string(x.Head.Name) 663 default: 664 panic("not reached") 665 } 666} 667 668// checkRuleConflicts ensures that rules definitions are not in conflict. 
669func (c *Compiler) checkRuleConflicts() { 670 c.RuleTree.DepthFirst(func(node *TreeNode) bool { 671 if len(node.Values) == 0 { 672 return false 673 } 674 675 kinds := map[DocKind]struct{}{} 676 defaultRules := 0 677 arities := map[int]struct{}{} 678 declared := false 679 680 for _, rule := range node.Values { 681 r := rule.(*Rule) 682 kinds[r.Head.DocKind()] = struct{}{} 683 arities[len(r.Head.Args)] = struct{}{} 684 if r.Head.Assign { 685 declared = true 686 } 687 if r.Default { 688 defaultRules++ 689 } 690 } 691 692 name := Var(node.Key.(String)) 693 694 if declared && len(node.Values) > 1 { 695 c.err(NewError(TypeErr, node.Values[0].(*Rule).Loc(), "rule named %v redeclared at %v", name, node.Values[1].(*Rule).Loc())) 696 } else if len(kinds) > 1 || len(arities) > 1 { 697 c.err(NewError(TypeErr, node.Values[0].(*Rule).Loc(), "conflicting rules named %v found", name)) 698 } else if defaultRules > 1 { 699 c.err(NewError(TypeErr, node.Values[0].(*Rule).Loc(), "multiple default rules named %s found", name)) 700 } 701 702 return false 703 }) 704 705 if c.pathExists != nil { 706 for _, err := range CheckPathConflicts(c, c.pathExists) { 707 c.err(err) 708 } 709 } 710 711 c.ModuleTree.DepthFirst(func(node *ModuleTreeNode) bool { 712 for _, mod := range node.Modules { 713 for _, rule := range mod.Rules { 714 if childNode, ok := node.Children[String(rule.Head.Name)]; ok { 715 for _, childMod := range childNode.Modules { 716 msg := fmt.Sprintf("%v conflicts with rule defined at %v", childMod.Package, rule.Loc()) 717 c.err(NewError(TypeErr, mod.Package.Loc(), msg)) 718 } 719 } 720 } 721 } 722 return false 723 }) 724} 725 726func (c *Compiler) checkUndefinedFuncs() { 727 for _, name := range c.sorted { 728 m := c.Modules[name] 729 for _, err := range checkUndefinedFuncs(m, c.GetArity) { 730 c.err(err) 731 } 732 } 733} 734 735func checkUndefinedFuncs(x interface{}, arity func(Ref) int) Errors { 736 737 var errs Errors 738 739 WalkExprs(x, func(expr *Expr) bool { 740 if 
!expr.IsCall() { 741 return false 742 } 743 ref := expr.Operator() 744 if arity(ref) >= 0 { 745 return false 746 } 747 errs = append(errs, NewError(TypeErr, expr.Loc(), "undefined function %v", ref)) 748 return true 749 }) 750 751 return errs 752} 753 754// checkSafetyRuleBodies ensures that variables appearing in negated expressions or non-target 755// positions of built-in expressions will be bound when evaluating the rule from left 756// to right, re-ordering as necessary. 757func (c *Compiler) checkSafetyRuleBodies() { 758 for _, name := range c.sorted { 759 m := c.Modules[name] 760 WalkRules(m, func(r *Rule) bool { 761 safe := ReservedVars.Copy() 762 safe.Update(r.Head.Args.Vars()) 763 r.Body = c.checkBodySafety(safe, m, r.Body) 764 return false 765 }) 766 } 767} 768 769func (c *Compiler) checkBodySafety(safe VarSet, m *Module, b Body) Body { 770 reordered, unsafe := reorderBodyForSafety(c.builtins, c.GetArity, safe, b) 771 if errs := safetyErrorSlice(unsafe); len(errs) > 0 { 772 for _, err := range errs { 773 c.err(err) 774 } 775 return b 776 } 777 return reordered 778} 779 780var safetyCheckVarVisitorParams = VarVisitorParams{ 781 SkipRefCallHead: true, 782 SkipClosures: true, 783} 784 785// checkSafetyRuleHeads ensures that variables appearing in the head of a 786// rule also appear in the body. 787func (c *Compiler) checkSafetyRuleHeads() { 788 789 for _, name := range c.sorted { 790 m := c.Modules[name] 791 WalkRules(m, func(r *Rule) bool { 792 safe := r.Body.Vars(safetyCheckVarVisitorParams) 793 safe.Update(r.Head.Args.Vars()) 794 unsafe := r.Head.Vars().Diff(safe) 795 for v := range unsafe { 796 if !v.IsGenerated() { 797 c.err(NewError(UnsafeVarErr, r.Loc(), "var %v is unsafe", v)) 798 } 799 } 800 return false 801 }) 802 } 803} 804 805// checkTypes runs the type checker on all rules. The type checker builds a 806// TypeEnv that is stored on the compiler. 
807func (c *Compiler) checkTypes() { 808 // Recursion is caught in earlier step, so this cannot fail. 809 sorted, _ := c.Graph.Sort() 810 checker := newTypeChecker().WithVarRewriter(rewriteVarsInRef(c.RewrittenVars)) 811 env, errs := checker.CheckTypes(c.TypeEnv, sorted) 812 for _, err := range errs { 813 c.err(err) 814 } 815 c.TypeEnv = env 816} 817 818func (c *Compiler) checkUnsafeBuiltins() { 819 for _, name := range c.sorted { 820 errs := checkUnsafeBuiltins(c.unsafeBuiltinsMap, c.Modules[name]) 821 for _, err := range errs { 822 c.err(err) 823 } 824 } 825} 826 827func (c *Compiler) runStage(metricName string, f func()) { 828 if c.metrics != nil { 829 c.metrics.Timer(metricName).Start() 830 defer c.metrics.Timer(metricName).Stop() 831 } 832 f() 833} 834 835func (c *Compiler) runStageAfter(metricName string, s CompilerStage) *Error { 836 if c.metrics != nil { 837 c.metrics.Timer(metricName).Start() 838 defer c.metrics.Timer(metricName).Stop() 839 } 840 return s(c) 841} 842 843func (c *Compiler) compile() { 844 defer func() { 845 if r := recover(); r != nil && r != errLimitReached { 846 panic(r) 847 } 848 }() 849 850 for _, s := range c.stages { 851 c.runStage(s.metricName, s.f) 852 if c.Failed() { 853 return 854 } 855 for _, s := range c.after[s.name] { 856 err := c.runStageAfter(s.MetricName, s.Stage) 857 if err != nil { 858 c.err(err) 859 } 860 } 861 } 862} 863 864func (c *Compiler) err(err *Error) { 865 if c.maxErrs > 0 && len(c.Errors) >= c.maxErrs { 866 c.Errors = append(c.Errors, errLimitReached) 867 panic(errLimitReached) 868 } 869 c.Errors = append(c.Errors, err) 870} 871 872func (c *Compiler) getExports() *util.HashMap { 873 874 rules := util.NewHashMap(func(a, b util.T) bool { 875 r1 := a.(Ref) 876 r2 := a.(Ref) 877 return r1.Equal(r2) 878 }, func(v util.T) int { 879 return v.(Ref).Hash() 880 }) 881 882 for _, name := range c.sorted { 883 mod := c.Modules[name] 884 rv, ok := rules.Get(mod.Package.Path) 885 if !ok { 886 rv = []Var{} 887 } 888 rvs := 
rv.([]Var) 889 890 for _, rule := range mod.Rules { 891 rvs = append(rvs, rule.Head.Name) 892 } 893 rules.Put(mod.Package.Path, rvs) 894 } 895 896 return rules 897} 898 899// resolveAllRefs resolves references in expressions to their fully qualified values. 900// 901// For instance, given the following module: 902// 903// package a.b 904// import data.foo.bar 905// p[x] { bar[_] = x } 906// 907// The reference "bar[_]" would be resolved to "data.foo.bar[_]". 908func (c *Compiler) resolveAllRefs() { 909 910 rules := c.getExports() 911 912 for _, name := range c.sorted { 913 mod := c.Modules[name] 914 915 var ruleExports []Var 916 if x, ok := rules.Get(mod.Package.Path); ok { 917 ruleExports = x.([]Var) 918 } 919 920 globals := getGlobals(mod.Package, ruleExports, mod.Imports) 921 922 WalkRules(mod, func(rule *Rule) bool { 923 err := resolveRefsInRule(globals, rule) 924 if err != nil { 925 c.err(NewError(CompileErr, rule.Location, err.Error())) 926 } 927 return false 928 }) 929 930 // Once imports have been resolved, they are no longer needed. 
931 mod.Imports = nil 932 } 933 934 if c.moduleLoader != nil { 935 936 parsed, err := c.moduleLoader(c.Modules) 937 if err != nil { 938 c.err(NewError(CompileErr, nil, err.Error())) 939 return 940 } 941 942 if len(parsed) == 0 { 943 return 944 } 945 946 for id, module := range parsed { 947 c.Modules[id] = module.Copy() 948 c.sorted = append(c.sorted, id) 949 } 950 951 sort.Strings(c.sorted) 952 c.resolveAllRefs() 953 } 954} 955 956func (c *Compiler) initLocalVarGen() { 957 c.localvargen = newLocalVarGeneratorForModuleSet(c.sorted, c.Modules) 958} 959 960func (c *Compiler) rewriteComprehensionTerms() { 961 f := newEqualityFactory(c.localvargen) 962 for _, name := range c.sorted { 963 mod := c.Modules[name] 964 rewriteComprehensionTerms(f, mod) 965 } 966} 967 968func (c *Compiler) rewriteExprTerms() { 969 for _, name := range c.sorted { 970 mod := c.Modules[name] 971 WalkRules(mod, func(rule *Rule) bool { 972 rewriteExprTermsInHead(c.localvargen, rule) 973 rule.Body = rewriteExprTermsInBody(c.localvargen, rule.Body) 974 return false 975 }) 976 } 977} 978 979// rewriteTermsInHead will rewrite rules so that the head does not contain any 980// terms that require evaluation (e.g., refs or comprehensions). If the key or 981// value contains or more of these terms, the key or value will be moved into 982// the body and assigned to a new variable. The new variable will replace the 983// key or value in the head. 
984// 985// For instance, given the following rule: 986// 987// p[{"foo": data.foo[i]}] { i < 100 } 988// 989// The rule would be re-written as: 990// 991// p[__local0__] { i < 100; __local0__ = {"foo": data.foo[i]} } 992func (c *Compiler) rewriteRefsInHead() { 993 f := newEqualityFactory(c.localvargen) 994 for _, name := range c.sorted { 995 mod := c.Modules[name] 996 WalkRules(mod, func(rule *Rule) bool { 997 if requiresEval(rule.Head.Key) { 998 expr := f.Generate(rule.Head.Key) 999 rule.Head.Key = expr.Operand(0) 1000 rule.Body.Append(expr) 1001 } 1002 if requiresEval(rule.Head.Value) { 1003 expr := f.Generate(rule.Head.Value) 1004 rule.Head.Value = expr.Operand(0) 1005 rule.Body.Append(expr) 1006 } 1007 for i := 0; i < len(rule.Head.Args); i++ { 1008 if requiresEval(rule.Head.Args[i]) { 1009 expr := f.Generate(rule.Head.Args[i]) 1010 rule.Head.Args[i] = expr.Operand(0) 1011 rule.Body.Append(expr) 1012 } 1013 } 1014 return false 1015 }) 1016 } 1017} 1018 1019func (c *Compiler) rewriteEquals() { 1020 for _, name := range c.sorted { 1021 mod := c.Modules[name] 1022 rewriteEquals(mod) 1023 } 1024} 1025 1026func (c *Compiler) rewriteDynamicTerms() { 1027 f := newEqualityFactory(c.localvargen) 1028 for _, name := range c.sorted { 1029 mod := c.Modules[name] 1030 WalkRules(mod, func(rule *Rule) bool { 1031 rule.Body = rewriteDynamics(f, rule.Body) 1032 return false 1033 }) 1034 } 1035} 1036 1037func (c *Compiler) rewriteLocalVars() { 1038 1039 for _, name := range c.sorted { 1040 mod := c.Modules[name] 1041 gen := c.localvargen 1042 1043 WalkRules(mod, func(rule *Rule) bool { 1044 1045 var errs Errors 1046 1047 // Rewrite assignments contained in head of rule. Assignments can 1048 // occur in rule head if they're inside a comprehension. Note, 1049 // assigned vars in comprehensions in the head will be rewritten 1050 // first to preserve scoping rules. 
For example: 1051 // 1052 // p = [x | x := 1] { x := 2 } becomes p = [__local0__ | __local0__ = 1] { __local1__ = 2 } 1053 // 1054 // This behaviour is consistent scoping inside the body. For example: 1055 // 1056 // p = xs { x := 2; xs = [x | x := 1] } becomes p = xs { __local0__ = 2; xs = [__local1__ | __local1__ = 1] } 1057 WalkTerms(rule.Head, func(term *Term) bool { 1058 stop := false 1059 stack := newLocalDeclaredVars() 1060 switch v := term.Value.(type) { 1061 case *ArrayComprehension: 1062 errs = rewriteDeclaredVarsInArrayComprehension(gen, stack, v, errs) 1063 stop = true 1064 case *SetComprehension: 1065 errs = rewriteDeclaredVarsInSetComprehension(gen, stack, v, errs) 1066 stop = true 1067 case *ObjectComprehension: 1068 errs = rewriteDeclaredVarsInObjectComprehension(gen, stack, v, errs) 1069 stop = true 1070 } 1071 1072 for k, v := range stack.rewritten { 1073 c.RewrittenVars[k] = v 1074 } 1075 1076 return stop 1077 }) 1078 1079 for _, err := range errs { 1080 c.err(err) 1081 } 1082 1083 // Rewrite assignments in body. 1084 used := NewVarSet() 1085 1086 if rule.Head.Key != nil { 1087 used.Update(rule.Head.Key.Vars()) 1088 } 1089 1090 if rule.Head.Value != nil { 1091 used.Update(rule.Head.Value.Vars()) 1092 } 1093 1094 stack := newLocalDeclaredVars() 1095 1096 c.rewriteLocalArgVars(gen, stack, rule) 1097 1098 body, declared, errs := rewriteLocalVars(gen, stack, used, rule.Body) 1099 for _, err := range errs { 1100 c.err(err) 1101 } 1102 1103 // For rewritten vars use the collection of all variables that 1104 // were in the stack at some point in time. 1105 for k, v := range stack.rewritten { 1106 c.RewrittenVars[k] = v 1107 } 1108 1109 rule.Body = body 1110 1111 // Rewrite vars in head that refer to locally declared vars in the body. 
1112 vis := NewGenericVisitor(func(x interface{}) bool { 1113 1114 term, ok := x.(*Term) 1115 if !ok { 1116 return false 1117 } 1118 1119 switch v := term.Value.(type) { 1120 case Object: 1121 // Make a copy of the object because the keys may be mutated. 1122 cpy, _ := v.Map(func(k, v *Term) (*Term, *Term, error) { 1123 if vark, ok := k.Value.(Var); ok { 1124 if gv, ok := declared[vark]; ok { 1125 k = k.Copy() 1126 k.Value = gv 1127 } 1128 } 1129 return k, v, nil 1130 }) 1131 term.Value = cpy 1132 case Var: 1133 if gv, ok := declared[v]; ok { 1134 term.Value = gv 1135 return true 1136 } 1137 } 1138 1139 return false 1140 }) 1141 1142 vis.Walk(rule.Head.Args) 1143 1144 if rule.Head.Key != nil { 1145 vis.Walk(rule.Head.Key) 1146 } 1147 1148 if rule.Head.Value != nil { 1149 vis.Walk(rule.Head.Value) 1150 } 1151 1152 return false 1153 }) 1154 } 1155} 1156 1157func (c *Compiler) rewriteLocalArgVars(gen *localVarGenerator, stack *localDeclaredVars, rule *Rule) { 1158 1159 vis := &ruleArgLocalRewriter{ 1160 stack: stack, 1161 gen: gen, 1162 } 1163 1164 for i := range rule.Head.Args { 1165 Walk(vis, rule.Head.Args[i]) 1166 } 1167 1168 for i := range vis.errs { 1169 c.err(vis.errs[i]) 1170 } 1171} 1172 1173type ruleArgLocalRewriter struct { 1174 stack *localDeclaredVars 1175 gen *localVarGenerator 1176 errs []*Error 1177} 1178 1179func (vis *ruleArgLocalRewriter) Visit(x interface{}) Visitor { 1180 1181 t, ok := x.(*Term) 1182 if !ok { 1183 return vis 1184 } 1185 1186 switch v := t.Value.(type) { 1187 case Var: 1188 gv, ok := vis.stack.Declared(v) 1189 if !ok { 1190 gv = vis.gen.Generate() 1191 vis.stack.Insert(v, gv, argVar) 1192 } 1193 t.Value = gv 1194 return nil 1195 case Object: 1196 if cpy, err := v.Map(func(k, v *Term) (*Term, *Term, error) { 1197 vcpy := v.Copy() 1198 Walk(vis, vcpy) 1199 return k, vcpy, nil 1200 }); err != nil { 1201 vis.errs = append(vis.errs, NewError(CompileErr, t.Location, err.Error())) 1202 } else { 1203 t.Value = cpy 1204 } 1205 return nil 
1206 case Null, Boolean, Number, String, *ArrayComprehension, *SetComprehension, *ObjectComprehension, Set: 1207 // Scalars are no-ops. Comprehensions are handled above. Sets must not 1208 // contain variables. 1209 return nil 1210 case Call: 1211 vis.errs = append(vis.errs, NewError(CompileErr, t.Location, "rule arguments cannot contain calls")) 1212 return nil 1213 default: 1214 // Recurse on refs and arrays. Any embedded 1215 // variables can be rewritten. 1216 return vis 1217 } 1218} 1219 1220func (c *Compiler) rewriteWithModifiers() { 1221 f := newEqualityFactory(c.localvargen) 1222 for _, name := range c.sorted { 1223 mod := c.Modules[name] 1224 t := NewGenericTransformer(func(x interface{}) (interface{}, error) { 1225 body, ok := x.(Body) 1226 if !ok { 1227 return x, nil 1228 } 1229 body, err := rewriteWithModifiersInBody(c, f, body) 1230 if err != nil { 1231 c.err(err) 1232 } 1233 1234 return body, nil 1235 }) 1236 Transform(t, mod) 1237 } 1238} 1239 1240func (c *Compiler) setModuleTree() { 1241 c.ModuleTree = NewModuleTree(c.Modules) 1242} 1243 1244func (c *Compiler) setRuleTree() { 1245 c.RuleTree = NewRuleTree(c.ModuleTree) 1246} 1247 1248func (c *Compiler) setGraph() { 1249 c.Graph = NewGraph(c.Modules, c.GetRulesDynamic) 1250} 1251 1252type queryCompiler struct { 1253 compiler *Compiler 1254 qctx *QueryContext 1255 typeEnv *TypeEnv 1256 rewritten map[Var]Var 1257 after map[string][]QueryCompilerStageDefinition 1258 unsafeBuiltins map[string]struct{} 1259} 1260 1261func newQueryCompiler(compiler *Compiler) QueryCompiler { 1262 qc := &queryCompiler{ 1263 compiler: compiler, 1264 qctx: nil, 1265 after: map[string][]QueryCompilerStageDefinition{}, 1266 } 1267 return qc 1268} 1269 1270func (qc *queryCompiler) WithContext(qctx *QueryContext) QueryCompiler { 1271 qc.qctx = qctx 1272 return qc 1273} 1274 1275func (qc *queryCompiler) WithStageAfter(after string, stage QueryCompilerStageDefinition) QueryCompiler { 1276 qc.after[after] = append(qc.after[after], 
stage) 1277 return qc 1278} 1279 1280func (qc *queryCompiler) WithUnsafeBuiltins(unsafe map[string]struct{}) QueryCompiler { 1281 qc.unsafeBuiltins = unsafe 1282 return qc 1283} 1284 1285func (qc *queryCompiler) RewrittenVars() map[Var]Var { 1286 return qc.rewritten 1287} 1288 1289func (qc *queryCompiler) runStage(metricName string, qctx *QueryContext, query Body, s func(*QueryContext, Body) (Body, error)) (Body, error) { 1290 if qc.compiler.metrics != nil { 1291 qc.compiler.metrics.Timer(metricName).Start() 1292 defer qc.compiler.metrics.Timer(metricName).Stop() 1293 } 1294 return s(qctx, query) 1295} 1296 1297func (qc *queryCompiler) runStageAfter(metricName string, query Body, s QueryCompilerStage) (Body, error) { 1298 if qc.compiler.metrics != nil { 1299 qc.compiler.metrics.Timer(metricName).Start() 1300 defer qc.compiler.metrics.Timer(metricName).Stop() 1301 } 1302 return s(qc, query) 1303} 1304 1305func (qc *queryCompiler) Compile(query Body) (Body, error) { 1306 1307 query = query.Copy() 1308 1309 stages := []struct { 1310 name string 1311 metricName string 1312 f func(*QueryContext, Body) (Body, error) 1313 }{ 1314 {"ResolveRefs", "query_compile_stage_resolve_refs", qc.resolveRefs}, 1315 {"RewriteLocalVars", "query_compile_stage_rewrite_local_vars", qc.rewriteLocalVars}, 1316 {"RewriteExprTerms", "query_compile_stage_rewrite_expr_terms", qc.rewriteExprTerms}, 1317 {"RewriteComprehensionTerms", "query_compile_stage_rewrite_comprehension_terms", qc.rewriteComprehensionTerms}, 1318 {"RewriteWithValues", "query_compile_stage_rewrite_with_values", qc.rewriteWithModifiers}, 1319 {"CheckUndefinedFuncs", "query_compile_stage_check_undefined_funcs", qc.checkUndefinedFuncs}, 1320 {"CheckSafety", "query_compile_stage_check_safety", qc.checkSafety}, 1321 {"RewriteDynamicTerms", "query_compile_stage_rewrite_dynamic_terms", qc.rewriteDynamicTerms}, 1322 {"CheckTypes", "query_compile_stage_check_types", qc.checkTypes}, 1323 {"CheckUnsafeBuiltins", 
"query_compile_stage_check_unsafe_builtins", qc.checkUnsafeBuiltins}, 1324 } 1325 1326 qctx := qc.qctx.Copy() 1327 1328 for _, s := range stages { 1329 var err error 1330 query, err = qc.runStage(s.metricName, qctx, query, s.f) 1331 if err != nil { 1332 return nil, qc.applyErrorLimit(err) 1333 } 1334 for _, s := range qc.after[s.name] { 1335 query, err = qc.runStageAfter(s.MetricName, query, s.Stage) 1336 if err != nil { 1337 return nil, qc.applyErrorLimit(err) 1338 } 1339 } 1340 } 1341 1342 return query, nil 1343} 1344 1345func (qc *queryCompiler) TypeEnv() *TypeEnv { 1346 return qc.typeEnv 1347} 1348 1349func (qc *queryCompiler) applyErrorLimit(err error) error { 1350 if errs, ok := err.(Errors); ok { 1351 if qc.compiler.maxErrs > 0 && len(errs) > qc.compiler.maxErrs { 1352 err = append(errs[:qc.compiler.maxErrs], errLimitReached) 1353 } 1354 } 1355 return err 1356} 1357 1358func (qc *queryCompiler) resolveRefs(qctx *QueryContext, body Body) (Body, error) { 1359 1360 var globals map[Var]Ref 1361 1362 if qctx != nil && qctx.Package != nil { 1363 var ruleExports []Var 1364 rules := qc.compiler.getExports() 1365 if exist, ok := rules.Get(qctx.Package.Path); ok { 1366 ruleExports = exist.([]Var) 1367 } 1368 1369 globals = getGlobals(qctx.Package, ruleExports, qc.qctx.Imports) 1370 qctx.Imports = nil 1371 } 1372 1373 ignore := &declaredVarStack{declaredVars(body)} 1374 1375 return resolveRefsInBody(globals, ignore, body), nil 1376} 1377 1378func (qc *queryCompiler) rewriteComprehensionTerms(_ *QueryContext, body Body) (Body, error) { 1379 gen := newLocalVarGenerator("q", body) 1380 f := newEqualityFactory(gen) 1381 node, err := rewriteComprehensionTerms(f, body) 1382 if err != nil { 1383 return nil, err 1384 } 1385 return node.(Body), nil 1386} 1387 1388func (qc *queryCompiler) rewriteDynamicTerms(_ *QueryContext, body Body) (Body, error) { 1389 gen := newLocalVarGenerator("q", body) 1390 f := newEqualityFactory(gen) 1391 return rewriteDynamics(f, body), nil 1392} 
1393 1394func (qc *queryCompiler) rewriteExprTerms(_ *QueryContext, body Body) (Body, error) { 1395 gen := newLocalVarGenerator("q", body) 1396 return rewriteExprTermsInBody(gen, body), nil 1397} 1398 1399func (qc *queryCompiler) rewriteLocalVars(_ *QueryContext, body Body) (Body, error) { 1400 gen := newLocalVarGenerator("q", body) 1401 stack := newLocalDeclaredVars() 1402 body, _, err := rewriteLocalVars(gen, stack, nil, body) 1403 if len(err) != 0 { 1404 return nil, err 1405 } 1406 qc.rewritten = make(map[Var]Var, len(stack.rewritten)) 1407 for k, v := range stack.rewritten { 1408 // The vars returned during the rewrite will include all seen vars, 1409 // even if they're not declared with an assignment operation. We don't 1410 // want to include these inside the rewritten set though. 1411 qc.rewritten[k] = v 1412 } 1413 return body, nil 1414} 1415 1416func (qc *queryCompiler) checkUndefinedFuncs(_ *QueryContext, body Body) (Body, error) { 1417 if errs := checkUndefinedFuncs(body, qc.compiler.GetArity); len(errs) > 0 { 1418 return nil, errs 1419 } 1420 return body, nil 1421} 1422 1423func (qc *queryCompiler) checkSafety(_ *QueryContext, body Body) (Body, error) { 1424 safe := ReservedVars.Copy() 1425 reordered, unsafe := reorderBodyForSafety(qc.compiler.builtins, qc.compiler.GetArity, safe, body) 1426 if errs := safetyErrorSlice(unsafe); len(errs) > 0 { 1427 return nil, errs 1428 } 1429 return reordered, nil 1430} 1431 1432func (qc *queryCompiler) checkTypes(qctx *QueryContext, body Body) (Body, error) { 1433 var errs Errors 1434 checker := newTypeChecker().WithVarRewriter(rewriteVarsInRef(qc.rewritten, qc.compiler.RewrittenVars)) 1435 qc.typeEnv, errs = checker.CheckBody(qc.compiler.TypeEnv, body) 1436 if len(errs) > 0 { 1437 return nil, errs 1438 } 1439 return body, nil 1440} 1441 1442func (qc *queryCompiler) checkUnsafeBuiltins(qctx *QueryContext, body Body) (Body, error) { 1443 var unsafe map[string]struct{} 1444 if qc.unsafeBuiltins != nil { 1445 unsafe = 
qc.unsafeBuiltins 1446 } else { 1447 unsafe = qc.compiler.unsafeBuiltinsMap 1448 } 1449 errs := checkUnsafeBuiltins(unsafe, body) 1450 if len(errs) > 0 { 1451 return nil, errs 1452 } 1453 return body, nil 1454} 1455 1456func (qc *queryCompiler) rewriteWithModifiers(qctx *QueryContext, body Body) (Body, error) { 1457 f := newEqualityFactory(newLocalVarGenerator("q", body)) 1458 body, err := rewriteWithModifiersInBody(qc.compiler, f, body) 1459 if err != nil { 1460 return nil, Errors{err} 1461 } 1462 return body, nil 1463} 1464 1465// ModuleTreeNode represents a node in the module tree. The module 1466// tree is keyed by the package path. 1467type ModuleTreeNode struct { 1468 Key Value 1469 Modules []*Module 1470 Children map[Value]*ModuleTreeNode 1471 Hide bool 1472} 1473 1474// NewModuleTree returns a new ModuleTreeNode that represents the root 1475// of the module tree populated with the given modules. 1476func NewModuleTree(mods map[string]*Module) *ModuleTreeNode { 1477 root := &ModuleTreeNode{ 1478 Children: map[Value]*ModuleTreeNode{}, 1479 } 1480 for _, m := range mods { 1481 node := root 1482 for i, x := range m.Package.Path { 1483 c, ok := node.Children[x.Value] 1484 if !ok { 1485 var hide bool 1486 if i == 1 && x.Value.Compare(SystemDocumentKey) == 0 { 1487 hide = true 1488 } 1489 c = &ModuleTreeNode{ 1490 Key: x.Value, 1491 Children: map[Value]*ModuleTreeNode{}, 1492 Hide: hide, 1493 } 1494 node.Children[x.Value] = c 1495 } 1496 node = c 1497 } 1498 node.Modules = append(node.Modules, m) 1499 } 1500 return root 1501} 1502 1503// Size returns the number of modules in the tree. 1504func (n *ModuleTreeNode) Size() int { 1505 s := len(n.Modules) 1506 for _, c := range n.Children { 1507 s += c.Size() 1508 } 1509 return s 1510} 1511 1512// DepthFirst performs a depth-first traversal of the module tree rooted at n. 1513// If f returns true, traversal will not continue to the children of n. 
1514func (n *ModuleTreeNode) DepthFirst(f func(node *ModuleTreeNode) bool) { 1515 if !f(n) { 1516 for _, node := range n.Children { 1517 node.DepthFirst(f) 1518 } 1519 } 1520} 1521 1522// TreeNode represents a node in the rule tree. The rule tree is keyed by 1523// rule path. 1524type TreeNode struct { 1525 Key Value 1526 Values []util.T 1527 Children map[Value]*TreeNode 1528 Hide bool 1529} 1530 1531// NewRuleTree returns a new TreeNode that represents the root 1532// of the rule tree populated with the given rules. 1533func NewRuleTree(mtree *ModuleTreeNode) *TreeNode { 1534 1535 ruleSets := map[String][]util.T{} 1536 1537 // Build rule sets for this package. 1538 for _, mod := range mtree.Modules { 1539 for _, rule := range mod.Rules { 1540 key := String(rule.Head.Name) 1541 ruleSets[key] = append(ruleSets[key], rule) 1542 } 1543 } 1544 1545 // Each rule set becomes a leaf node. 1546 children := map[Value]*TreeNode{} 1547 1548 for key, rules := range ruleSets { 1549 children[key] = &TreeNode{ 1550 Key: key, 1551 Children: nil, 1552 Values: rules, 1553 } 1554 } 1555 1556 // Each module in subpackage becomes child node. 1557 for _, child := range mtree.Children { 1558 children[child.Key] = NewRuleTree(child) 1559 } 1560 1561 return &TreeNode{ 1562 Key: mtree.Key, 1563 Values: nil, 1564 Children: children, 1565 Hide: mtree.Hide, 1566 } 1567} 1568 1569// Size returns the number of rules in the tree. 1570func (n *TreeNode) Size() int { 1571 s := len(n.Values) 1572 for _, c := range n.Children { 1573 s += c.Size() 1574 } 1575 return s 1576} 1577 1578// Child returns n's child with key k. 1579func (n *TreeNode) Child(k Value) *TreeNode { 1580 switch k.(type) { 1581 case String, Var: 1582 return n.Children[k] 1583 } 1584 return nil 1585} 1586 1587// DepthFirst performs a depth-first traversal of the rule tree rooted at n. If 1588// f returns true, traversal will not continue to the children of n. 
1589func (n *TreeNode) DepthFirst(f func(node *TreeNode) bool) { 1590 if !f(n) { 1591 for _, node := range n.Children { 1592 node.DepthFirst(f) 1593 } 1594 } 1595} 1596 1597// Graph represents the graph of dependencies between rules. 1598type Graph struct { 1599 adj map[util.T]map[util.T]struct{} 1600 nodes map[util.T]struct{} 1601 sorted []util.T 1602} 1603 1604// NewGraph returns a new Graph based on modules. The list function must return 1605// the rules referred to directly by the ref. 1606func NewGraph(modules map[string]*Module, list func(Ref) []*Rule) *Graph { 1607 1608 graph := &Graph{ 1609 adj: map[util.T]map[util.T]struct{}{}, 1610 nodes: map[util.T]struct{}{}, 1611 sorted: nil, 1612 } 1613 1614 // Create visitor to walk a rule AST and add edges to the rule graph for 1615 // each dependency. 1616 vis := func(a *Rule) *GenericVisitor { 1617 stop := false 1618 return NewGenericVisitor(func(x interface{}) bool { 1619 switch x := x.(type) { 1620 case Ref: 1621 for _, b := range list(x) { 1622 for node := b; node != nil; node = node.Else { 1623 graph.addDependency(a, node) 1624 } 1625 } 1626 case *Rule: 1627 if stop { 1628 // Do not recurse into else clauses (which will be handled 1629 // by the outer visitor.) 1630 return true 1631 } 1632 stop = true 1633 } 1634 return false 1635 }) 1636 } 1637 1638 // Walk over all rules, add them to graph, and build adjencency lists. 1639 for _, module := range modules { 1640 WalkRules(module, func(a *Rule) bool { 1641 graph.addNode(a) 1642 vis(a).Walk(a) 1643 return false 1644 }) 1645 } 1646 1647 return graph 1648} 1649 1650// Dependencies returns the set of rules that x depends on. 1651func (g *Graph) Dependencies(x util.T) map[util.T]struct{} { 1652 return g.adj[x] 1653} 1654 1655// Sort returns a slice of rules sorted by dependencies. If a cycle is found, 1656// ok is set to false. 
1657func (g *Graph) Sort() (sorted []util.T, ok bool) { 1658 if g.sorted != nil { 1659 return g.sorted, true 1660 } 1661 1662 sort := &graphSort{ 1663 sorted: make([]util.T, 0, len(g.nodes)), 1664 deps: g.Dependencies, 1665 marked: map[util.T]struct{}{}, 1666 temp: map[util.T]struct{}{}, 1667 } 1668 1669 for node := range g.nodes { 1670 if !sort.Visit(node) { 1671 return nil, false 1672 } 1673 } 1674 1675 g.sorted = sort.sorted 1676 return g.sorted, true 1677} 1678 1679func (g *Graph) addDependency(u util.T, v util.T) { 1680 1681 if _, ok := g.nodes[u]; !ok { 1682 g.addNode(u) 1683 } 1684 1685 if _, ok := g.nodes[v]; !ok { 1686 g.addNode(v) 1687 } 1688 1689 edges, ok := g.adj[u] 1690 if !ok { 1691 edges = map[util.T]struct{}{} 1692 g.adj[u] = edges 1693 } 1694 1695 edges[v] = struct{}{} 1696} 1697 1698func (g *Graph) addNode(n util.T) { 1699 g.nodes[n] = struct{}{} 1700} 1701 1702type graphSort struct { 1703 sorted []util.T 1704 deps func(util.T) map[util.T]struct{} 1705 marked map[util.T]struct{} 1706 temp map[util.T]struct{} 1707} 1708 1709func (sort *graphSort) Marked(node util.T) bool { 1710 _, marked := sort.marked[node] 1711 return marked 1712} 1713 1714func (sort *graphSort) Visit(node util.T) (ok bool) { 1715 if _, ok := sort.temp[node]; ok { 1716 return false 1717 } 1718 if sort.Marked(node) { 1719 return true 1720 } 1721 sort.temp[node] = struct{}{} 1722 for other := range sort.deps(node) { 1723 if !sort.Visit(other) { 1724 return false 1725 } 1726 } 1727 sort.marked[node] = struct{}{} 1728 delete(sort.temp, node) 1729 sort.sorted = append(sort.sorted, node) 1730 return true 1731} 1732 1733// GraphTraversal is a Traversal that understands the dependency graph 1734type GraphTraversal struct { 1735 graph *Graph 1736 visited map[util.T]struct{} 1737} 1738 1739// NewGraphTraversal returns a Traversal for the dependency graph 1740func NewGraphTraversal(graph *Graph) *GraphTraversal { 1741 return &GraphTraversal{ 1742 graph: graph, 1743 visited: 
map[util.T]struct{}{}, 1744 } 1745} 1746 1747// Edges lists all dependency connections for a given node 1748func (g *GraphTraversal) Edges(x util.T) []util.T { 1749 r := []util.T{} 1750 for v := range g.graph.Dependencies(x) { 1751 r = append(r, v) 1752 } 1753 return r 1754} 1755 1756// Visited returns whether a node has been visited, setting a node to visited if not 1757func (g *GraphTraversal) Visited(u util.T) bool { 1758 _, ok := g.visited[u] 1759 g.visited[u] = struct{}{} 1760 return ok 1761} 1762 1763type unsafePair struct { 1764 Expr *Expr 1765 Vars VarSet 1766} 1767 1768type unsafeVarLoc struct { 1769 Var Var 1770 Loc *Location 1771} 1772 1773type unsafeVars map[*Expr]VarSet 1774 1775func (vs unsafeVars) Add(e *Expr, v Var) { 1776 if u, ok := vs[e]; ok { 1777 u[v] = struct{}{} 1778 } else { 1779 vs[e] = VarSet{v: struct{}{}} 1780 } 1781} 1782 1783func (vs unsafeVars) Set(e *Expr, s VarSet) { 1784 vs[e] = s 1785} 1786 1787func (vs unsafeVars) Update(o unsafeVars) { 1788 for k, v := range o { 1789 if _, ok := vs[k]; !ok { 1790 vs[k] = VarSet{} 1791 } 1792 vs[k].Update(v) 1793 } 1794} 1795 1796func (vs unsafeVars) Vars() (result []unsafeVarLoc) { 1797 1798 locs := map[Var]*Location{} 1799 1800 // If var appears in multiple sets then pick first by location. 
1801 for expr, vars := range vs { 1802 for v := range vars { 1803 if locs[v].Compare(expr.Location) > 0 { 1804 locs[v] = expr.Location 1805 } 1806 } 1807 } 1808 1809 for v, loc := range locs { 1810 result = append(result, unsafeVarLoc{ 1811 Var: v, 1812 Loc: loc, 1813 }) 1814 } 1815 1816 sort.Slice(result, func(i, j int) bool { 1817 return result[i].Loc.Compare(result[j].Loc) < 0 1818 }) 1819 1820 return result 1821} 1822 1823func (vs unsafeVars) Slice() (result []unsafePair) { 1824 for expr, vs := range vs { 1825 result = append(result, unsafePair{ 1826 Expr: expr, 1827 Vars: vs, 1828 }) 1829 } 1830 return 1831} 1832 1833// reorderBodyForSafety returns a copy of the body ordered such that 1834// left to right evaluation of the body will not encounter unbound variables 1835// in input positions or negated expressions. 1836// 1837// Expressions are added to the re-ordered body as soon as they are considered 1838// safe. If multiple expressions become safe in the same pass, they are added 1839// in their original order. This results in minimal re-ordering of the body. 1840// 1841// If the body cannot be reordered to ensure safety, the second return value 1842// contains a mapping of expressions to unsafe variables in those expressions. 
1843func reorderBodyForSafety(builtins map[string]*Builtin, arity func(Ref) int, globals VarSet, body Body) (Body, unsafeVars) { 1844 1845 body, unsafe := reorderBodyForClosures(builtins, arity, globals, body) 1846 if len(unsafe) != 0 { 1847 return nil, unsafe 1848 } 1849 1850 reordered := Body{} 1851 safe := VarSet{} 1852 1853 for _, e := range body { 1854 for v := range e.Vars(safetyCheckVarVisitorParams) { 1855 if globals.Contains(v) { 1856 safe.Add(v) 1857 } else { 1858 unsafe.Add(e, v) 1859 } 1860 } 1861 } 1862 1863 for { 1864 n := len(reordered) 1865 1866 for _, e := range body { 1867 if reordered.Contains(e) { 1868 continue 1869 } 1870 1871 safe.Update(outputVarsForExpr(e, builtins, arity, safe)) 1872 1873 for v := range unsafe[e] { 1874 if safe.Contains(v) { 1875 delete(unsafe[e], v) 1876 } 1877 } 1878 1879 if len(unsafe[e]) == 0 { 1880 delete(unsafe, e) 1881 reordered = append(reordered, e) 1882 } 1883 } 1884 1885 if len(reordered) == n { 1886 break 1887 } 1888 } 1889 1890 // Recursively visit closures and perform the safety checks on them. 1891 // Update the globals at each expression to include the variables that could 1892 // be closed over. 1893 g := globals.Copy() 1894 for i, e := range reordered { 1895 if i > 0 { 1896 g.Update(reordered[i-1].Vars(safetyCheckVarVisitorParams)) 1897 } 1898 vis := &bodySafetyVisitor{ 1899 builtins: builtins, 1900 arity: arity, 1901 current: e, 1902 globals: g, 1903 unsafe: unsafe, 1904 } 1905 NewGenericVisitor(vis.Visit).Walk(e) 1906 } 1907 1908 // Need to reset expression indices as re-ordering may have 1909 // changed them. 
1910 setExprIndices(reordered) 1911 1912 return reordered, unsafe 1913} 1914 1915type bodySafetyVisitor struct { 1916 builtins map[string]*Builtin 1917 arity func(Ref) int 1918 current *Expr 1919 globals VarSet 1920 unsafe unsafeVars 1921} 1922 1923func (vis *bodySafetyVisitor) Visit(x interface{}) bool { 1924 switch x := x.(type) { 1925 case *Expr: 1926 cpy := *vis 1927 cpy.current = x 1928 1929 switch ts := x.Terms.(type) { 1930 case *SomeDecl: 1931 NewGenericVisitor(cpy.Visit).Walk(ts) 1932 case []*Term: 1933 for _, t := range ts { 1934 NewGenericVisitor(cpy.Visit).Walk(t) 1935 } 1936 case *Term: 1937 NewGenericVisitor(cpy.Visit).Walk(ts) 1938 } 1939 for i := range x.With { 1940 NewGenericVisitor(cpy.Visit).Walk(x.With[i]) 1941 } 1942 return true 1943 case *ArrayComprehension: 1944 vis.checkArrayComprehensionSafety(x) 1945 return true 1946 case *ObjectComprehension: 1947 vis.checkObjectComprehensionSafety(x) 1948 return true 1949 case *SetComprehension: 1950 vis.checkSetComprehensionSafety(x) 1951 return true 1952 } 1953 return false 1954} 1955 1956// Check term for safety. This is analogous to the rule head safety check. 1957func (vis *bodySafetyVisitor) checkComprehensionSafety(tv VarSet, body Body) Body { 1958 bv := body.Vars(safetyCheckVarVisitorParams) 1959 bv.Update(vis.globals) 1960 uv := tv.Diff(bv) 1961 for v := range uv { 1962 vis.unsafe.Add(vis.current, v) 1963 } 1964 1965 // Check body for safety, reordering as necessary. 
1966 r, u := reorderBodyForSafety(vis.builtins, vis.arity, vis.globals, body) 1967 if len(u) == 0 { 1968 return r 1969 } 1970 1971 vis.unsafe.Update(u) 1972 return body 1973} 1974 1975func (vis *bodySafetyVisitor) checkArrayComprehensionSafety(ac *ArrayComprehension) { 1976 ac.Body = vis.checkComprehensionSafety(ac.Term.Vars(), ac.Body) 1977} 1978 1979func (vis *bodySafetyVisitor) checkObjectComprehensionSafety(oc *ObjectComprehension) { 1980 tv := oc.Key.Vars() 1981 tv.Update(oc.Value.Vars()) 1982 oc.Body = vis.checkComprehensionSafety(tv, oc.Body) 1983} 1984 1985func (vis *bodySafetyVisitor) checkSetComprehensionSafety(sc *SetComprehension) { 1986 sc.Body = vis.checkComprehensionSafety(sc.Term.Vars(), sc.Body) 1987} 1988 1989// reorderBodyForClosures returns a copy of the body ordered such that 1990// expressions (such as array comprehensions) that close over variables are ordered 1991// after other expressions that contain the same variable in an output position. 1992func reorderBodyForClosures(builtins map[string]*Builtin, arity func(Ref) int, globals VarSet, body Body) (Body, unsafeVars) { 1993 1994 reordered := Body{} 1995 unsafe := unsafeVars{} 1996 1997 for { 1998 n := len(reordered) 1999 2000 for _, e := range body { 2001 if reordered.Contains(e) { 2002 continue 2003 } 2004 2005 // Collect vars that are contained in closures within this 2006 // expression. 2007 vs := VarSet{} 2008 WalkClosures(e, func(x interface{}) bool { 2009 vis := &VarVisitor{vars: vs} 2010 vis.Walk(x) 2011 return true 2012 }) 2013 2014 // Compute vars that are closed over from the body but not yet 2015 // contained in the output position of an expression in the reordered 2016 // body. These vars are considered unsafe. 
2017 cv := vs.Intersect(body.Vars(safetyCheckVarVisitorParams)).Diff(globals) 2018 uv := cv.Diff(outputVarsForBody(reordered, builtins, arity, globals)) 2019 2020 if len(uv) == 0 { 2021 reordered = append(reordered, e) 2022 delete(unsafe, e) 2023 } else { 2024 unsafe.Set(e, uv) 2025 } 2026 } 2027 2028 if len(reordered) == n { 2029 break 2030 } 2031 } 2032 2033 return reordered, unsafe 2034} 2035 2036func outputVarsForBody(body Body, builtins map[string]*Builtin, arity func(Ref) int, safe VarSet) VarSet { 2037 o := safe.Copy() 2038 for _, e := range body { 2039 o.Update(outputVarsForExpr(e, builtins, arity, o)) 2040 } 2041 return o.Diff(safe) 2042} 2043 2044func outputVarsForExpr(expr *Expr, builtins map[string]*Builtin, arity func(Ref) int, safe VarSet) VarSet { 2045 2046 // Negated expressions must be safe. 2047 if expr.Negated { 2048 return VarSet{} 2049 } 2050 2051 // With modifier inputs must be safe. 2052 for _, with := range expr.With { 2053 unsafe := false 2054 WalkVars(with, func(v Var) bool { 2055 if !safe.Contains(v) { 2056 unsafe = true 2057 return true 2058 } 2059 return false 2060 }) 2061 if unsafe { 2062 return VarSet{} 2063 } 2064 } 2065 2066 if !expr.IsCall() { 2067 return outputVarsForExprRefs(expr, safe) 2068 } 2069 2070 terms := expr.Terms.([]*Term) 2071 name := terms[0].String() 2072 2073 if b := builtins[name]; b != nil { 2074 if b.Name == Equality.Name { 2075 return outputVarsForExprEq(expr, safe) 2076 } 2077 return outputVarsForExprBuiltin(expr, b, safe) 2078 } 2079 2080 return outputVarsForExprCall(expr, arity, safe, terms) 2081} 2082 2083func outputVarsForExprBuiltin(expr *Expr, b *Builtin, safe VarSet) VarSet { 2084 2085 output := outputVarsForExprRefs(expr, safe) 2086 terms := expr.Terms.([]*Term) 2087 2088 // Check that all input terms are safe. 
2089 for i, t := range terms[1:] { 2090 if b.IsTargetPos(i) { 2091 continue 2092 } 2093 vis := NewVarVisitor().WithParams(VarVisitorParams{ 2094 SkipClosures: true, 2095 SkipSets: true, 2096 SkipObjectKeys: true, 2097 SkipRefHead: true, 2098 }) 2099 vis.Walk(t) 2100 unsafe := vis.Vars().Diff(output).Diff(safe) 2101 if len(unsafe) > 0 { 2102 return VarSet{} 2103 } 2104 } 2105 2106 // Add vars in target positions to result. 2107 for i, t := range terms[1:] { 2108 if b.IsTargetPos(i) { 2109 vis := NewVarVisitor().WithParams(VarVisitorParams{ 2110 SkipRefHead: true, 2111 SkipSets: true, 2112 SkipObjectKeys: true, 2113 SkipClosures: true, 2114 }) 2115 vis.Walk(t) 2116 output.Update(vis.vars) 2117 } 2118 } 2119 2120 return output 2121} 2122 2123func outputVarsForExprEq(expr *Expr, safe VarSet) VarSet { 2124 if !validEqAssignArgCount(expr) { 2125 return safe 2126 } 2127 output := outputVarsForExprRefs(expr, safe) 2128 output.Update(safe) 2129 output.Update(Unify(output, expr.Operand(0), expr.Operand(1))) 2130 return output.Diff(safe) 2131} 2132 2133func outputVarsForExprCall(expr *Expr, arity func(Ref) int, safe VarSet, terms []*Term) VarSet { 2134 2135 output := outputVarsForExprRefs(expr, safe) 2136 2137 ref, ok := terms[0].Value.(Ref) 2138 if !ok { 2139 return VarSet{} 2140 } 2141 2142 numArgs := arity(ref) 2143 if numArgs == -1 { 2144 return VarSet{} 2145 } 2146 2147 numInputTerms := numArgs + 1 2148 2149 if numInputTerms >= len(terms) { 2150 return output 2151 } 2152 2153 vis := NewVarVisitor().WithParams(VarVisitorParams{ 2154 SkipClosures: true, 2155 SkipSets: true, 2156 SkipObjectKeys: true, 2157 SkipRefHead: true, 2158 }) 2159 2160 vis.Walk(Args(terms[:numInputTerms])) 2161 unsafe := vis.Vars().Diff(output).Diff(safe) 2162 2163 if len(unsafe) > 0 { 2164 return VarSet{} 2165 } 2166 2167 vis = NewVarVisitor().WithParams(VarVisitorParams{ 2168 SkipRefHead: true, 2169 SkipSets: true, 2170 SkipObjectKeys: true, 2171 SkipClosures: true, 2172 }) 2173 2174 
vis.Walk(Args(terms[numInputTerms:]))
	output.Update(vis.vars)
	return output
}

// outputVarsForExprRefs collects output variables from the refs found in expr.
// For every ref visited by WalkRefs whose head variable is contained in safe,
// the vars reported by r.OutputVars() are added to the result.
//
// The head of each visited ref is type-asserted to Var without a check, so a
// ref with a non-Var head would panic here.
func outputVarsForExprRefs(expr *Expr, safe VarSet) VarSet {
	output := VarSet{}
	WalkRefs(expr, func(r Ref) bool {
		if safe.Contains(r[0].Value.(Var)) {
			output.Update(r.OutputVars())
			return false
		}
		return true
	})
	return output
}

// equalityFactory produces generated equality expressions that bind a fresh
// local variable (obtained from gen) to an existing term.
type equalityFactory struct {
	gen *localVarGenerator
}

// newEqualityFactory returns an equalityFactory that draws variable names
// from gen.
func newEqualityFactory(gen *localVarGenerator) *equalityFactory {
	return &equalityFactory{gen}
}

// Generate returns a new expression of the form "<fresh var> = other" built
// with Equality.Expr. Both the fresh variable term and the expression take
// other's location, and the expression is flagged as Generated.
func (f *equalityFactory) Generate(other *Term) *Expr {
	term := NewTerm(f.gen.Generate()).SetLocation(other.Location)
	expr := Equality.Expr(term, other)
	expr.Generated = true
	expr.Location = other.Location
	return expr
}

// localVarGenerator hands out variable names of the form
// "__local<suffix><n>__", skipping any name contained in exclude.
type localVarGenerator struct {
	exclude VarSet // names that must never be generated
	suffix  string // inserted between "__local" and the counter
	next    int    // counter used for the next candidate name
}

// newLocalVarGeneratorForModuleSet returns a generator (with an empty suffix)
// that excludes every variable a VarVisitor finds while walking the modules
// named in sorted.
func newLocalVarGeneratorForModuleSet(sorted []string, modules map[string]*Module) *localVarGenerator {
	exclude := NewVarSet()
	vis := &VarVisitor{vars: exclude}
	for _, key := range sorted {
		vis.Walk(modules[key])
	}
	return &localVarGenerator{exclude: exclude, next: 0}
}

// newLocalVarGenerator returns a generator using the given suffix that
// excludes every variable a VarVisitor finds while walking node.
func newLocalVarGenerator(suffix string, node interface{}) *localVarGenerator {
	exclude := NewVarSet()
	vis := &VarVisitor{vars: exclude}
	vis.Walk(node)
	return &localVarGenerator{exclude: exclude, suffix: suffix, next: 0}
}

// Generate returns the next "__local<suffix><n>__" variable that is not in the
// exclude set. The counter advances for every candidate, including skipped
// ones, so a name is never handed out twice by the same generator.
func (l *localVarGenerator) Generate() Var {
	for {
		result := Var("__local" + l.suffix + strconv.Itoa(l.next) + "__")
		l.next++
		if !l.exclude.Contains(result) {
			return result
		}
	}
}

// getGlobals returns a mapping from the names visible inside a package to the
// refs they stand for:
//
//   - every var in rules maps to pkg.Path extended with the rule name as a
//     String term;
//   - every import with an alias maps the alias to the import path;
//   - an import without an alias maps the head var of a single-element path,
//     or otherwise the last path element (type-asserted to String), to the
//     import path.
//
// Imports are processed after rules, so an import overwrites a rule entry
// that has the same name.
func getGlobals(pkg *Package, rules []Var, imports []*Import) map[Var]Ref {

	globals := map[Var]Ref{}

	// Populate globals with exports within the package.
	for _, v := range rules {
		// Appending to an empty Ref copies pkg.Path instead of aliasing it.
		global := append(Ref{}, pkg.Path...)
		global = append(global, &Term{Value: String(v)})
		globals[v] = global
	}

	// Populate globals with imports.
	for _, i := range imports {
		if len(i.Alias) > 0 {
			path := i.Path.Value.(Ref)
			globals[i.Alias] = path
		} else {
			path := i.Path.Value.(Ref)
			if len(path) == 1 {
				globals[path[0].Value.(Var)] = path
			} else {
				v := path[len(path)-1].Value.(String)
				globals[Var(v)] = path
			}
		}
	}

	return globals
}

// requiresEval reports whether x is non-nil and contains refs or
// comprehensions (as determined by ContainsRefs and ContainsComprehensions).
func requiresEval(x *Term) bool {
	if x == nil {
		return false
	}
	return ContainsRefs(x) || ContainsComprehensions(x)
}

// resolveRef returns a new ref built from ref in which every Var element that
// names a global (and is not contained in ignore) is replaced by a copy of
// the global's ref. Composite elements are resolved recursively through
// resolveRefsInTerm; all other elements are appended unchanged.
func resolveRef(globals map[Var]Ref, ignore *declaredVarStack, ref Ref) Ref {

	r := Ref{}
	for i, x := range ref {
		switch v := x.Value.(type) {
		case Var:
			if g, ok := globals[v]; ok && !ignore.Contains(v) {
				cpy := g.Copy()
				// The copied terms take the location of the var they replace.
				// Note that this loop's i shadows the outer index only inside
				// the loop body.
				for i := range cpy {
					cpy[i].SetLocation(x.Location)
				}
				if i == 0 {
					// A global in head position becomes the prefix of the
					// resolved ref.
					r = cpy
				} else {
					// A global in any other position is nested as a single
					// ref-valued element.
					r = append(r, NewTerm(cpy).SetLocation(x.Location))
				}
			} else {
				r = append(r, x)
			}
		case Ref, Array, Object, Set, *ArrayComprehension, *SetComprehension, *ObjectComprehension, Call:
			r = append(r, resolveRefsInTerm(globals, ignore, x))
		default:
			r = append(r, x)
		}
	}

	return r
}

// resolveRefsInRule rewrites the head key, head value and body of rule so that
// vars naming globals are replaced with the refs stored in globals. Vars that
// are collected from the head args, and vars declared in the body (see
// declaredVars), are not resolved. The head args themselves are not rewritten
// here.
//
// An error is returned, and the rule is left untouched, if the args contain a
// ref term that is one of RootDocumentRefs.
func resolveRefsInRule(globals map[Var]Ref, rule *Rule) error {
	ignore := &declaredVarStack{}

	vars := NewVarSet()
	var vis *GenericVisitor
	var err error

	// Walk args to collect vars and transform body so that callers can shadow
	// root documents.
	vis = NewGenericVisitor(func(x interface{}) bool {
		// Once an error has been recorded, stop descending anywhere.
		if err != nil {
			return true
		}
		switch x := x.(type) {
		case Var:
			vars.Add(x)

		// Object keys cannot be pattern matched so only walk values.
		//
		// NOTE(review): this case does not return true, so after the explicit
		// walk over the values the visitor is not told to skip the object
		// itself. Confirm whether GenericVisitor then also descends into the
		// keys, which would contradict the comment above.
		case Object:
			for _, k := range x.Keys() {
				vis.Walk(x.Get(k))
			}

		// Skip terms that could contain vars that cannot be pattern matched.
		case Set, *ArrayComprehension, *SetComprehension, *ObjectComprehension, Call:
			return true

		case *Term:
			if _, ok := x.Value.(Ref); ok {
				if RootDocumentRefs.Contains(x) {
					// We could support args named input, data, etc. however
					// this would require rewriting terms in the head and body.
					// Preventing root document shadowing is simpler, and
					// arguably, will prevent confusing names from being used.
					err = fmt.Errorf("args must not shadow %v (use a different variable name)", x)
					return true
				}
			}
		}
		return false
	})

	vis.Walk(rule.Head.Args)

	if err != nil {
		return err
	}

	// Two scopes are ignored during resolution: the vars found in the args
	// and the vars declared in the rule body.
	ignore.Push(vars)
	ignore.Push(declaredVars(rule.Body))

	if rule.Head.Key != nil {
		rule.Head.Key = resolveRefsInTerm(globals, ignore, rule.Head.Key)
	}

	if rule.Head.Value != nil {
		rule.Head.Value = resolveRefsInTerm(globals, ignore, rule.Head.Value)
	}

	rule.Body = resolveRefsInBody(globals, ignore, rule.Body)
	return nil
}

// resolveRefsInBody returns a new body holding the result of
// resolveRefsInExpr for every expression in body. The result is never nil.
func resolveRefsInBody(globals map[Var]Ref, ignore *declaredVarStack, body Body) Body {
	r := Body{}
	for _, expr := range body {
		r = append(r, resolveRefsInExpr(globals, ignore, expr))
	}
	return r
}

// resolveRefsInExpr returns a shallow copy of expr whose terms have been
// resolved with resolveRefsInTerm. The target and value of each with modifier
// are resolved as well.
func resolveRefsInExpr(globals map[Var]Ref, ignore *declaredVarStack, expr *Expr) *Expr {
	cpy := *expr
	switch ts := expr.Terms.(type) {
	case *Term:
		cpy.Terms = resolveRefsInTerm(globals, ignore, ts)
	case []*Term:
		buf := make([]*Term, len(ts))
		for i := 0; i < len(ts); i++ {
			buf[i] = resolveRefsInTerm(globals, ignore, ts[i])
		}
		cpy.Terms = buf
	}
	// NOTE(review): cpy.With is the same slice as expr.With (the struct copy
	// above is shallow). Assuming the elements are pointers, these assignments
	// also update the modifiers of the original expr — confirm this is
	// intended.
	for _, w := range cpy.With {
		w.Target = resolveRefsInTerm(globals, ignore, w.Target)
		w.Value = resolveRefsInTerm(globals, ignore, w.Value)
	}
	return &cpy
}

// resolveRefsInTerm returns term with vars that name globals (and are not in
// ignore) replaced by the corresponding refs.
//
// Terms that need no structural change (the default case, and vars that are
// not resolved) are returned as-is. For every other kind a shallow copy of the
// term is returned with a newly built value. While a comprehension is being
// processed, the vars declared in its body are pushed onto ignore and popped
// again afterwards.
func resolveRefsInTerm(globals map[Var]Ref, ignore *declaredVarStack, term *Term) *Term {
	switch v := term.Value.(type) {
	case Var:
		if g, ok := globals[v]; ok && !ignore.Contains(v) {
			cpy := g.Copy()
			// The copied terms take the location of the var they replace.
			for i := range cpy {
				cpy[i].SetLocation(term.Location)
			}
			return NewTerm(cpy).SetLocation(term.Location)
		}
		return term
	case Ref:
		fqn := resolveRef(globals, ignore, v)
		cpy := *term
		cpy.Value = fqn
		return &cpy
	case Object:
		cpy := *term
		// The callback never returns an error, so the error result of Map is
		// discarded.
		cpy.Value, _ = v.Map(func(k, v *Term) (*Term, *Term, error) {
			k = resolveRefsInTerm(globals, ignore, k)
			v = resolveRefsInTerm(globals, ignore, v)
			return k, v, nil
		})
		return &cpy
	case Array:
		cpy := *term
		cpy.Value = Array(resolveRefsInTermSlice(globals, ignore, v))
		return &cpy
	case Call:
		cpy := *term
		cpy.Value = Call(resolveRefsInTermSlice(globals, ignore, v))
		return &cpy
	case Set:
		// The callback never returns an error, so the error result of Map is
		// discarded.
		s, _ := v.Map(func(e *Term) (*Term, error) {
			return resolveRefsInTerm(globals, ignore, e), nil
		})
		cpy := *term
		cpy.Value = s
		return &cpy
	case *ArrayComprehension:
		ac := &ArrayComprehension{}
		ignore.Push(declaredVars(v.Body))
		ac.Term = resolveRefsInTerm(globals, ignore, v.Term)
		ac.Body = resolveRefsInBody(globals, ignore, v.Body)
		cpy := *term
		cpy.Value = ac
		ignore.Pop()
		return &cpy
	case *ObjectComprehension:
		oc := &ObjectComprehension{}
		ignore.Push(declaredVars(v.Body))
		oc.Key = resolveRefsInTerm(globals, ignore, v.Key)
		oc.Value = resolveRefsInTerm(globals, ignore, v.Value)
		oc.Body = resolveRefsInBody(globals, ignore, v.Body)
		cpy := *term
		cpy.Value = oc
		ignore.Pop()
		return &cpy
	case *SetComprehension:
		sc := &SetComprehension{}
		ignore.Push(declaredVars(v.Body))
		sc.Term = resolveRefsInTerm(globals, ignore, v.Term)
		sc.Body = resolveRefsInBody(globals, ignore, v.Body)
		cpy := *term
		cpy.Value = sc
		ignore.Pop()
		return &cpy
	default:
		return term
	}
}

// resolveRefsInTermSlice returns a new slice holding the result of
// resolveRefsInTerm for every element of terms.
func resolveRefsInTermSlice(globals map[Var]Ref, ignore *declaredVarStack, terms []*Term) []*Term {
	cpy := make([]*Term, len(terms))
	for i := 0; i < len(terms); i++ {
		cpy[i] = resolveRefsInTerm(globals, ignore, terms[i])
	}
	return cpy
}

// declaredVarStack is a stack of variable sets, one per scope. The last
// element is the innermost scope.
type declaredVarStack []VarSet

// Contains reports whether v is present in any scope on the stack, searching
// from the innermost scope outwards.
func (s declaredVarStack) Contains(v Var) bool {
	for i := len(s) - 1; i >= 0; i-- {
		if _, ok := s[i][v]; ok {
			return ok
		}
	}
	return false
}

// Add inserts v into the innermost scope. It panics (index out of range) if
// the stack is empty.
func (s declaredVarStack) Add(v Var) {
	s[len(s)-1].Add(v)
}

// Push makes vs the new innermost scope.
func (s *declaredVarStack) Push(vs VarSet) {
	*s = append(*s, vs)
}

// Pop discards the innermost scope. It panics (slice bounds out of range) if
// the stack is empty.
func (s *declaredVarStack) Pop() {
	curr := *s
	*s = curr[:len(curr)-1]
}

// declaredVars returns the set of variables declared under x:
//
//   - the vars found in the first operand of every assignment expression that
//     passes validEqAssignArgCount, and
//   - the symbols of every SomeDecl expression (each type-asserted to Var).
//
// The visitor returns true for comprehensions, which (following the "skip"
// convention used by the visitor in resolveRefsInRule) keeps declarations
// nested inside comprehensions out of the result.
func declaredVars(x interface{}) VarSet {
	vars := NewVarSet()
	vis := NewGenericVisitor(func(x interface{}) bool {
		switch x := x.(type) {
		case *Expr:
			if x.IsAssignment() && validEqAssignArgCount(x) {
				WalkVars(x.Operand(0), func(v Var) bool {
					vars.Add(v)
					return false
				})
			} else if decl, ok := x.Terms.(*SomeDecl); ok {
				for i := range decl.Symbols {
					vars.Add(decl.Symbols[i].Value.(Var))
				}
			}
		case *ArrayComprehension, *SetComprehension, *ObjectComprehension:
			return true
		}
		return false
	})
	vis.Walk(x)
	return vars
}

// rewriteComprehensionTerms will rewrite comprehensions so that the term part
// is bound to a variable in the body. This allows any type of term to be used
// in the term part (even if the term requires evaluation.)
//
// For instance, given the following comprehension:
//
//	[x[0] | x = y[_]; y = [1,2,3]]
//
// The comprehension would be rewritten as:
//
//	[__local0__ | x = y[_]; y = [1,2,3]; __local0__ = x[0]]
//
// Only terms for which requiresEval returns true are rewritten. The
// comprehensions are modified in place: the generated equality is appended to
// the comprehension body and the term (or key/value) is replaced by the
// generated variable. The callback panics if TransformComprehensions passes
// it anything other than a comprehension.
func rewriteComprehensionTerms(f *equalityFactory, node interface{}) (interface{}, error) {
	return TransformComprehensions(node, func(x interface{}) (Value, error) {
		switch x := x.(type) {
		case *ArrayComprehension:
			if requiresEval(x.Term) {
				expr := f.Generate(x.Term)
				x.Term = expr.Operand(0)
				x.Body.Append(expr)
			}
			return x, nil
		case *SetComprehension:
			if requiresEval(x.Term) {
				expr := f.Generate(x.Term)
				x.Term = expr.Operand(0)
				x.Body.Append(expr)
			}
			return x, nil
		case *ObjectComprehension:
			// The key is processed before the value, so the key receives the
			// earlier generated variable.
			if requiresEval(x.Key) {
				expr := f.Generate(x.Key)
				x.Key = expr.Operand(0)
				x.Body.Append(expr)
			}
			if requiresEval(x.Value) {
				expr := f.Generate(x.Value)
				x.Value = expr.Operand(0)
				x.Body.Append(expr)
			}
			return x, nil
		}
		panic("illegal type")
	})
}

// rewriteEquals will rewrite exprs under x as unification calls instead of ==
// calls. For example:
//
//	data.foo == data.bar is rewritten as data.foo = data.bar
//
// This stage should only run after the safety check (since == is a built-in
// with no outputs, so the inputs must not be marked as safe.)
//
// This stage is not executed by the query compiler by default because when
// callers specify == instead of = they expect to receive a true/false/undefined
// result back whereas with = the result is only ever true/undefined. For
// partial evaluation cases we do want to rewrite == to = to simplify the
// result.
func rewriteEquals(x interface{}) {
	doubleEq := Equal.Ref()
	unifyOp := Equality.Ref()
	WalkExprs(x, func(x *Expr) bool {
		if x.IsCall() {
			operator := x.Operator()
			// Only two-operand calls to == are rewritten; the operator is
			// replaced in place on the visited expression.
			if operator.Equal(doubleEq) && len(x.Operands()) == 2 {
				x.SetOperator(NewTerm(unifyOp))
			}
		}
		return false
	})
}

// rewriteDynamics will rewrite the body so that dynamic terms (i.e., refs and
// comprehensions) are bound to vars earlier in the query. This translation
// results in eager evaluation.
//
// For instance, given the following query:
//
//	foo(data.bar) = 1
//
// The rewritten version will be:
//
//	__local0__ = data.bar; foo(__local0__) = 1
//
// A new body is returned; the expressions of the input body are reused (and
// modified in place) with the generated expressions inserted before them.
func rewriteDynamics(f *equalityFactory, body Body) Body {
	result := make(Body, 0, len(body))
	for _, expr := range body {
		// Equality is tested first, so equality expressions never reach the
		// generic call branch.
		if expr.IsEquality() {
			result = rewriteDynamicsEqExpr(f, expr, result)
		} else if expr.IsCall() {
			result = rewriteDynamicsCallExpr(f, expr, result)
		} else {
			result = rewriteDynamicsTermExpr(f, expr, result)
		}
	}
	return result
}

// appendExpr appends expr to body using Body.Append and returns the updated
// body.
func appendExpr(body Body, expr *Expr) Body {
	body.Append(expr)
	return body
}

// rewriteDynamicsEqExpr rewrites both operands of an equality expression with
// rewriteDynamicsInTerm and appends the expression to result. Expressions
// that fail validEqAssignArgCount are appended unchanged.
func rewriteDynamicsEqExpr(f *equalityFactory, expr *Expr, result Body) Body {
	if !validEqAssignArgCount(expr) {
		return appendExpr(result, expr)
	}
	terms := expr.Terms.([]*Term)
	result, terms[1] = rewriteDynamicsInTerm(expr, f, terms[1], result)
	result, terms[2] = rewriteDynamicsInTerm(expr, f, terms[2], result)
	return appendExpr(result, expr)
}

// rewriteDynamicsCallExpr rewrites every term of a call expression from index
// 1 onwards with rewriteDynamicsOne and appends the expression to result.
func rewriteDynamicsCallExpr(f *equalityFactory, expr *Expr, result Body) Body {
	terms := expr.Terms.([]*Term)
	for i := 1; i < len(terms); i++ {
		result, terms[i] = rewriteDynamicsOne(expr, f, terms[i], result)
	}
	return appendExpr(result, expr)
}

// rewriteDynamicsTermExpr rewrites the single term of expr with
// rewriteDynamicsInTerm and appends the expression to result. expr.Terms is
// type-asserted to *Term without a check.
func rewriteDynamicsTermExpr(f *equalityFactory, expr *Expr, result Body) Body {
	term := expr.Terms.(*Term)
	result, expr.Terms = rewriteDynamicsInTerm(expr, f, term, result)
	return appendExpr(result, expr)
}

// rewriteDynamicsInTerm rewrites a term that sits directly in an expression.
// Unlike rewriteDynamicsOne it does not bind a top-level ref or comprehension
// to a generated variable: for a ref only the elements after the head are
// rewritten, and for a comprehension only the body is rewritten. Every other
// kind of term is handed to rewriteDynamicsOne.
func rewriteDynamicsInTerm(original *Expr, f *equalityFactory, term *Term, result Body) (Body, *Term) {
	switch v := term.Value.(type) {
	case Ref:
		for i := 1; i < len(v); i++ {
			result, v[i] = rewriteDynamicsOne(original, f, v[i], result)
		}
	case *ArrayComprehension:
		v.Body = rewriteDynamics(f, v.Body)
	case *SetComprehension:
		v.Body = rewriteDynamics(f, v.Body)
	case *ObjectComprehension:
		v.Body = rewriteDynamics(f, v.Body)
	default:
		result, term = rewriteDynamicsOne(original, f, term, result)
	}
	return result, term
}

// rewriteDynamicsOne rewrites term so that it no longer contains refs or
// comprehensions, appending the generated equality expressions to result. It
// returns the updated result together with the term that should be used in
// place of the input.
//
// Refs and comprehensions are bound to a generated variable (the generated
// expression carries the with modifiers of original) and that variable is
// returned. Arrays are rewritten in place; objects and sets are rebuilt as
// new values. Any other term is returned unchanged.
func rewriteDynamicsOne(original *Expr, f *equalityFactory, term *Term, result Body) (Body, *Term) {
	switch v := term.Value.(type) {
	case Ref:
		// Rewrite nested dynamic elements first so their bindings precede the
		// binding of the ref itself.
		for i := 1; i < len(v); i++ {
			result, v[i] = rewriteDynamicsOne(original, f, v[i], result)
		}
		generated := f.Generate(term)
		generated.With = original.With
		result.Append(generated)
		return result, result[len(result)-1].Operand(0)
	case Array:
		for i := 0; i < len(v); i++ {
			result, v[i] = rewriteDynamicsOne(original, f, v[i], result)
		}
		return result, term
	case Object:
		cpy := NewObject()
		for _, key := range v.Keys() {
			value := v.Get(key)
			result, key = rewriteDynamicsOne(original, f, key, result)
			result, value = rewriteDynamicsOne(original, f, value, result)
			cpy.Insert(key, value)
		}
		return result, NewTerm(cpy).SetLocation(term.Location)
	case Set:
		cpy := NewSet()
		for _, term := range v.Slice() {
			var rw *Term
			result, rw = rewriteDynamicsOne(original, f, term, result)
			cpy.Add(rw)
		}
		return result, NewTerm(cpy).SetLocation(term.Location)
	case *ArrayComprehension:
		var extra *Expr
		v.Body, extra = rewriteDynamicsComprehensionBody(original, f, v.Body, term)
		result.Append(extra)
		return result, result[len(result)-1].Operand(0)
	case *SetComprehension:
		var extra *Expr
		v.Body, extra = rewriteDynamicsComprehensionBody(original, f, v.Body, term)
		result.Append(extra)
		return result, result[len(result)-1].Operand(0)
	case *ObjectComprehension:
		var extra *Expr
		v.Body, extra = rewriteDynamicsComprehensionBody(original, f, v.Body, term)
		result.Append(extra)
		return result, result[len(result)-1].Operand(0)
	}
	return result, term
}

// rewriteDynamicsComprehensionBody rewrites a comprehension body with
// rewriteDynamics and returns it together with a generated equality
// expression that binds term (the comprehension) to a fresh variable. The
// generated expression carries the with modifiers of original.
func rewriteDynamicsComprehensionBody(original *Expr, f *equalityFactory, body Body, term *Term) (Body, *Expr) {
	body = rewriteDynamics(f, body)
	generated := f.Generate(term)
	generated.With = original.With
	return body, generated
}

// rewriteExprTermsInHead expands the key and value of the rule head with
// expandExprTerm. The supporting expressions are appended to the rule body
// and the head key/value are replaced by the expanded terms. The key is
// processed before the value.
func rewriteExprTermsInHead(gen *localVarGenerator, rule *Rule) {
	if rule.Head.Key != nil {
		support, output := expandExprTerm(gen, rule.Head.Key)
		for i := range support {
			rule.Body.Append(support[i])
		}
		rule.Head.Key = output
	}
	if rule.Head.Value != nil {
		support, output := expandExprTerm(gen, rule.Head.Value)
		for i := range support {
			rule.Body.Append(support[i])
		}
		rule.Head.Value = output
	}
}

// rewriteExprTermsInBody returns a new body in which every expression of body
// has been replaced by the expressions produced by expandExpr.
func rewriteExprTermsInBody(gen *localVarGenerator, body Body) Body {
	cpy := make(Body, 0, len(body))
	for i := 0; i < len(body); i++ {
		for _, expr := range expandExpr(gen, body[i]) {
			cpy.Append(expr)
		}
	}
	return cpy
}

// expandExpr expands the terms of expr with expandExprTerm and returns the
// supporting expressions followed by expr itself (which is modified in
// place).
//
// The values of the with modifiers are expanded first. Supporting expressions
// produced for the expression's own terms inherit expr.With; those produced
// for the with values do not have With set here. For a []*Term expression the
// expansion starts at index 1. An expression whose Terms has any other type
// yields only the with-value support.
func expandExpr(gen *localVarGenerator, expr *Expr) (result []*Expr) {
	for i := range expr.With {
		extras, value := expandExprTerm(gen, expr.With[i].Value)
		expr.With[i].Value = value
		result = append(result, extras...)
	}
	switch terms := expr.Terms.(type) {
	case *Term:
		extras, term := expandExprTerm(gen, terms)
		if len(expr.With) > 0 {
			for i := range extras {
				extras[i].With = expr.With
			}
		}
		result = append(result, extras...)
		expr.Terms = term
		result = append(result, expr)
	case []*Term:
		for i := 1; i < len(terms); i++ {
			var extras []*Expr
			extras, terms[i] = expandExprTerm(gen, terms[i])
			if len(expr.With) > 0 {
				// This i indexes extras and shadows the term index above.
				for i := range extras {
					extras[i].With = expr.With
				}
			}
			result = append(result, extras...)
		}
		result = append(result, expr)
	}
	return
}

// expandExprTerm expands call terms nested inside term. It returns the
// supporting expressions that must be evaluated first and the term to use in
// place of the input.
//
//   - Call: the arguments (index 1 onwards) are expanded, then the call is
//     turned into a generated expression that writes to a fresh variable; the
//     variable is the output.
//   - Ref and Array: elements are expanded in place; the output is the input
//     term.
//   - Object and Set: a new value is built with Map; the output is a new
//     term.
//   - Comprehensions: the support for the term (or key/value) is appended to
//     the comprehension's own body, which is then rewritten with
//     rewriteExprTermsInBody. Nothing is returned as support.
func expandExprTerm(gen *localVarGenerator, term *Term) (support []*Expr, output *Term) {
	output = term
	switch v := term.Value.(type) {
	case Call:
		for i := 1; i < len(v); i++ {
			var extras []*Expr
			extras, v[i] = expandExprTerm(gen, v[i])
			support = append(support, extras...)
		}
		output = NewTerm(gen.Generate()).SetLocation(term.Location)
		expr := v.MakeExpr(output).SetLocation(term.Location)
		expr.Generated = true
		support = append(support, expr)
	case Ref:
		support = expandExprRef(gen, v)
	case Array:
		support = expandExprTermSlice(gen, v)
	case Object:
		// The callback never returns an error, so the error result of Map is
		// discarded.
		cpy, _ := v.Map(func(k, v *Term) (*Term, *Term, error) {
			extras1, expandedKey := expandExprTerm(gen, k)
			extras2, expandedValue := expandExprTerm(gen, v)
			support = append(support, extras1...)
			support = append(support, extras2...)
			return expandedKey, expandedValue, nil
		})
		output = NewTerm(cpy).SetLocation(term.Location)
	case Set:
		// The callback never returns an error, so the error result of Map is
		// discarded.
		cpy, _ := v.Map(func(x *Term) (*Term, error) {
			extras, expanded := expandExprTerm(gen, x)
			support = append(support, extras...)
			return expanded, nil
		})
		output = NewTerm(cpy).SetLocation(term.Location)
	case *ArrayComprehension:
		// The := below declares a new support (and term) local to this case,
		// shadowing the named result: the support belongs in the
		// comprehension body and must not be returned to the caller.
		support, term := expandExprTerm(gen, v.Term)
		for i := range support {
			v.Body.Append(support[i])
		}
		v.Term = term
		v.Body = rewriteExprTermsInBody(gen, v.Body)
	case *SetComprehension:
		// See the ArrayComprehension case: support is deliberately shadowed.
		support, term := expandExprTerm(gen, v.Term)
		for i := range support {
			v.Body.Append(support[i])
		}
		v.Term = term
		v.Body = rewriteExprTermsInBody(gen, v.Body)
	case *ObjectComprehension:
		// See the ArrayComprehension case: support is deliberately shadowed.
		support, key := expandExprTerm(gen, v.Key)
		for i := range support {
			v.Body.Append(support[i])
		}
		v.Key = key
		support, value := expandExprTerm(gen, v.Value)
		for i := range support {
			v.Body.Append(support[i])
		}
		v.Value = value
		v.Body = rewriteExprTermsInBody(gen, v.Body)
	}
	return
}

// expandExprRef expands the elements of a ref in place and returns the
// supporting expressions. v[0] is read unconditionally, so v must not be
// empty.
func expandExprRef(gen *localVarGenerator, v []*Term) (support []*Expr) {
	// Start by calling a normal expandExprTerm on all terms.
	support = expandExprTermSlice(gen, v)

	// Rewrite references in order to support indirect references. We rewrite
	// e.g.
	//
	//     [1, 2, 3][i]
	//
	// to
	//
	//     __local_var = [1, 2, 3]
	//     __local_var[i]
	//
	// to support these. This only impacts the reference subject, i.e. the
	// first item in the slice.
	var subject = v[0]
	switch subject.Value.(type) {
	case Array, Object, Set, *ArrayComprehension, *SetComprehension, *ObjectComprehension, Call:
		f := newEqualityFactory(gen)
		assignToLocal := f.Generate(subject)
		support = append(support, assignToLocal)
		v[0] = assignToLocal.Operand(0)
	}
	return
}

// expandExprTermSlice expands every element of v in place with expandExprTerm
// and returns the collected supporting expressions in element order.
func expandExprTermSlice(gen *localVarGenerator, v []*Term) (support []*Expr) {
	for i := 0; i < len(v); i++ {
		var extras []*Expr
		extras, v[i] = expandExprTerm(gen, v[i])
		support = append(support, extras...)
	}
	return
}

// localDeclaredVars tracks, per scope, the variables seen while rewriting
// local variable declarations. The last element of vars is the innermost
// scope.
type localDeclaredVars struct {
	vars []*declaredVarSet

	// rewritten contains a mapping of *all* user-defined variables
	// that have been rewritten whereas vars contains the state
	// from the current query (but not any nested queries, and all
	// vars seen).
	rewritten map[Var]Var
}

// varOccurrence records how a variable was first encountered in a scope.
type varOccurrence int

const (
	// newVar is the zero value; it is what a lookup in an occurrence map
	// yields for a variable that has not been recorded.
	newVar varOccurrence = iota
	argVar
	seenVar
	assignedVar
	declaredVar
)

// declaredVarSet holds the bookkeeping for a single scope.
type declaredVarSet struct {
	vs         map[Var]Var           // original name -> name used after rewriting
	reverse    map[Var]Var           // name used after rewriting -> original name
	occurrence map[Var]varOccurrence // original name -> how it was recorded
}

// newDeclaredVarSet returns an empty, ready-to-use declaredVarSet.
func newDeclaredVarSet() *declaredVarSet {
	return &declaredVarSet{
		vs:         map[Var]Var{},
		reverse:    map[Var]Var{},
		occurrence: map[Var]varOccurrence{},
	}
}

// newLocalDeclaredVars returns a localDeclaredVars that starts with a single
// empty scope.
func newLocalDeclaredVars() *localDeclaredVars {
	return &localDeclaredVars{
		vars:      []*declaredVarSet{newDeclaredVarSet()},
		rewritten: map[Var]Var{},
	}
}

// Push opens a new, empty innermost scope.
func (s *localDeclaredVars) Push() {
	s.vars = append(s.vars, newDeclaredVarSet())
}

// Pop removes and returns the innermost scope. It panics (index out of range)
// if there is no scope left.
func (s *localDeclaredVars) Pop() *declaredVarSet {
	sl := s.vars
	curr := sl[len(sl)-1]
	s.vars = sl[:len(sl)-1]
	return curr
}

// Peek returns the innermost scope without removing it.
func (s localDeclaredVars) Peek() *declaredVarSet {
	return s.vars[len(s.vars)-1]
}

// Insert records in the innermost scope that x is represented by y and was
// encountered as occurrence. Although the receiver is a value, the scope and
// the rewritten map are reference types, so the updates are visible to the
// caller.
func (s localDeclaredVars) Insert(x, y Var, occurrence varOccurrence) {
	elem := s.vars[len(s.vars)-1]
	elem.vs[x] = y
	elem.reverse[y] = x
	elem.occurrence[x] = occurrence

	// If the variable has been rewritten (where x != y, with y being
	// the generated value), store it in the map of rewritten vars.
	// Assume that the generated values are unique for the compilation.
	if !x.Equal(y) {
		s.rewritten[y] = x
	}
}

// Declared returns the name recorded for x, searching from the innermost
// scope outwards. ok is false if no scope has an entry for x.
func (s localDeclaredVars) Declared(x Var) (y Var, ok bool) {
	for i := len(s.vars) - 1; i >= 0; i-- {
		if y, ok = s.vars[i].vs[x]; ok {
			return
		}
	}
	return
}

// Occurrence returns a flag that indicates whether x has occurred in the
// current scope. The result is newVar if the current scope has no record
// of x.
func (s localDeclaredVars) Occurrence(x Var) varOccurrence {
	return s.vars[len(s.vars)-1].occurrence[x]
}

// GlobalOccurrence returns the occurrence recorded for x in the innermost
// scope on the stack that has one; all scopes are searched, from the
// innermost outwards. The boolean result is false (and the occurrence is
// newVar) if no scope has a record of x.
func (s localDeclaredVars) GlobalOccurrence(x Var) (varOccurrence, bool) {
	for i := len(s.vars) - 1; i >= 0; i-- {
		if occ, ok := s.vars[i].occurrence[x]; ok {
			return occ, true
		}
	}
	return newVar, false
}

// rewriteLocalVars rewrites bodies to remove assignment/declaration
// expressions. For example:
//
//	a := 1; p[a]
//
// Is rewritten to:
//
//	__local0__ = 1; p[__local0__]
//
// During rewriting, assignees are validated to prevent use before declaration.
2990func rewriteLocalVars(g *localVarGenerator, stack *localDeclaredVars, used VarSet, body Body) (Body, map[Var]Var, Errors) { 2991 var errs Errors 2992 body, errs = rewriteDeclaredVarsInBody(g, stack, used, body, errs) 2993 return body, stack.Pop().vs, errs 2994} 2995 2996func rewriteDeclaredVarsInBody(g *localVarGenerator, stack *localDeclaredVars, used VarSet, body Body, errs Errors) (Body, Errors) { 2997 2998 var cpy Body 2999 3000 for i := range body { 3001 var expr *Expr 3002 if body[i].IsAssignment() { 3003 expr, errs = rewriteDeclaredAssignment(g, stack, body[i], errs) 3004 } else if decl, ok := body[i].Terms.(*SomeDecl); ok { 3005 errs = rewriteSomeDeclStatement(g, stack, decl, errs) 3006 } else { 3007 expr, errs = rewriteDeclaredVarsInExpr(g, stack, body[i], errs) 3008 } 3009 if expr != nil { 3010 cpy.Append(expr) 3011 } 3012 } 3013 3014 // If the body only contained a var statement it will be empty at this 3015 // point. Append true to the body to ensure that it's non-empty (zero length 3016 // bodies are not supported.) 3017 if len(cpy) == 0 { 3018 cpy.Append(NewExpr(BooleanTerm(true))) 3019 } 3020 3021 return cpy, checkUnusedDeclaredVars(body[0].Loc(), stack, used, cpy, errs) 3022} 3023 3024func checkUnusedDeclaredVars(loc *Location, stack *localDeclaredVars, used VarSet, cpy Body, errs Errors) Errors { 3025 3026 // NOTE(tsandall): Do not generate more errors if there are existing 3027 // declaration errors. 
3028 if len(errs) > 0 { 3029 return errs 3030 } 3031 3032 dvs := stack.Peek() 3033 declared := NewVarSet() 3034 3035 for v, occ := range dvs.occurrence { 3036 if occ == declaredVar { 3037 declared.Add(dvs.vs[v]) 3038 } 3039 } 3040 3041 bodyvars := cpy.Vars(VarVisitorParams{}) 3042 3043 for v := range used { 3044 if gv, ok := stack.Declared(v); ok { 3045 bodyvars.Add(gv) 3046 } else { 3047 bodyvars.Add(v) 3048 } 3049 } 3050 3051 unused := declared.Diff(bodyvars).Diff(used) 3052 3053 for _, gv := range unused.Sorted() { 3054 errs = append(errs, NewError(CompileErr, loc, "declared var %v unused", dvs.reverse[gv])) 3055 } 3056 3057 return errs 3058} 3059 3060func rewriteSomeDeclStatement(g *localVarGenerator, stack *localDeclaredVars, decl *SomeDecl, errs Errors) Errors { 3061 for i := range decl.Symbols { 3062 v := decl.Symbols[i].Value.(Var) 3063 if _, err := rewriteDeclaredVar(g, stack, v, declaredVar); err != nil { 3064 errs = append(errs, NewError(CompileErr, decl.Loc(), err.Error())) 3065 } 3066 } 3067 return errs 3068} 3069 3070func rewriteDeclaredVarsInExpr(g *localVarGenerator, stack *localDeclaredVars, expr *Expr, errs Errors) (*Expr, Errors) { 3071 vis := NewGenericVisitor(func(x interface{}) bool { 3072 var stop bool 3073 switch x := x.(type) { 3074 case *Term: 3075 stop, errs = rewriteDeclaredVarsInTerm(g, stack, x, errs) 3076 case *With: 3077 _, errs = rewriteDeclaredVarsInTerm(g, stack, x.Value, errs) 3078 stop = true 3079 } 3080 return stop 3081 }) 3082 vis.Walk(expr) 3083 return expr, errs 3084} 3085 3086func rewriteDeclaredAssignment(g *localVarGenerator, stack *localDeclaredVars, expr *Expr, errs Errors) (*Expr, Errors) { 3087 3088 if expr.Negated { 3089 errs = append(errs, NewError(CompileErr, expr.Location, "cannot assign vars inside negated expression")) 3090 return expr, errs 3091 } 3092 3093 numErrsBefore := len(errs) 3094 3095 if !validEqAssignArgCount(expr) { 3096 return expr, errs 3097 } 3098 3099 // Rewrite terms on right hand side capture 
seen vars and recursively 3100 // process comprehensions before left hand side is processed. Also 3101 // rewrite with modifier. 3102 errs = rewriteDeclaredVarsInTermRecursive(g, stack, expr.Operand(1), errs) 3103 3104 for _, w := range expr.With { 3105 errs = rewriteDeclaredVarsInTermRecursive(g, stack, w.Value, errs) 3106 } 3107 3108 // Rewrite vars on left hand side with unique names. Catch redeclaration 3109 // and invalid term types here. 3110 var vis func(t *Term) bool 3111 3112 vis = func(t *Term) bool { 3113 switch v := t.Value.(type) { 3114 case Var: 3115 if gv, err := rewriteDeclaredVar(g, stack, v, assignedVar); err != nil { 3116 errs = append(errs, NewError(CompileErr, t.Location, err.Error())) 3117 } else { 3118 t.Value = gv 3119 } 3120 return true 3121 case Array: 3122 return false 3123 case Object: 3124 v.Foreach(func(_, v *Term) { 3125 WalkTerms(v, vis) 3126 }) 3127 return true 3128 case Ref: 3129 if RootDocumentRefs.Contains(t) { 3130 if gv, err := rewriteDeclaredVar(g, stack, v[0].Value.(Var), assignedVar); err != nil { 3131 errs = append(errs, NewError(CompileErr, t.Location, err.Error())) 3132 } else { 3133 t.Value = gv 3134 } 3135 return true 3136 } 3137 } 3138 errs = append(errs, NewError(CompileErr, t.Location, "cannot assign to %v", TypeName(t.Value))) 3139 return true 3140 } 3141 3142 WalkTerms(expr.Operand(0), vis) 3143 3144 if len(errs) == numErrsBefore { 3145 loc := expr.Operator()[0].Location 3146 expr.SetOperator(RefTerm(VarTerm(Equality.Name).SetLocation(loc)).SetLocation(loc)) 3147 } 3148 3149 return expr, errs 3150} 3151 3152func rewriteDeclaredVarsInTerm(g *localVarGenerator, stack *localDeclaredVars, term *Term, errs Errors) (bool, Errors) { 3153 switch v := term.Value.(type) { 3154 case Var: 3155 if gv, ok := stack.Declared(v); ok { 3156 term.Value = gv 3157 } else if stack.Occurrence(v) == newVar { 3158 stack.Insert(v, v, seenVar) 3159 } 3160 case Ref: 3161 if RootDocumentRefs.Contains(term) { 3162 x := v[0].Value.(Var) 3163 if 
occ, ok := stack.GlobalOccurrence(x); ok && occ != seenVar { 3164 gv, _ := stack.Declared(x) 3165 term.Value = gv 3166 } 3167 3168 return true, errs 3169 } 3170 return false, errs 3171 case Object: 3172 cpy, _ := v.Map(func(k, v *Term) (*Term, *Term, error) { 3173 kcpy := k.Copy() 3174 errs = rewriteDeclaredVarsInTermRecursive(g, stack, kcpy, errs) 3175 errs = rewriteDeclaredVarsInTermRecursive(g, stack, v, errs) 3176 return kcpy, v, nil 3177 }) 3178 term.Value = cpy 3179 case Set: 3180 cpy, _ := v.Map(func(elem *Term) (*Term, error) { 3181 elemcpy := elem.Copy() 3182 errs = rewriteDeclaredVarsInTermRecursive(g, stack, elemcpy, errs) 3183 return elemcpy, nil 3184 }) 3185 term.Value = cpy 3186 case *ArrayComprehension: 3187 errs = rewriteDeclaredVarsInArrayComprehension(g, stack, v, errs) 3188 case *SetComprehension: 3189 errs = rewriteDeclaredVarsInSetComprehension(g, stack, v, errs) 3190 case *ObjectComprehension: 3191 errs = rewriteDeclaredVarsInObjectComprehension(g, stack, v, errs) 3192 default: 3193 return false, errs 3194 } 3195 return true, errs 3196} 3197 3198func rewriteDeclaredVarsInTermRecursive(g *localVarGenerator, stack *localDeclaredVars, term *Term, errs Errors) Errors { 3199 WalkNodes(term, func(n Node) bool { 3200 var stop bool 3201 switch n := n.(type) { 3202 case *With: 3203 _, errs = rewriteDeclaredVarsInTerm(g, stack, n.Value, errs) 3204 stop = true 3205 case *Term: 3206 stop, errs = rewriteDeclaredVarsInTerm(g, stack, n, errs) 3207 } 3208 return stop 3209 }) 3210 return errs 3211} 3212 3213func rewriteDeclaredVarsInArrayComprehension(g *localVarGenerator, stack *localDeclaredVars, v *ArrayComprehension, errs Errors) Errors { 3214 stack.Push() 3215 v.Body, errs = rewriteDeclaredVarsInBody(g, stack, nil, v.Body, errs) 3216 errs = rewriteDeclaredVarsInTermRecursive(g, stack, v.Term, errs) 3217 stack.Pop() 3218 return errs 3219} 3220 3221func rewriteDeclaredVarsInSetComprehension(g *localVarGenerator, stack *localDeclaredVars, v 
*SetComprehension, errs Errors) Errors { 3222 stack.Push() 3223 v.Body, errs = rewriteDeclaredVarsInBody(g, stack, nil, v.Body, errs) 3224 errs = rewriteDeclaredVarsInTermRecursive(g, stack, v.Term, errs) 3225 stack.Pop() 3226 return errs 3227} 3228 3229func rewriteDeclaredVarsInObjectComprehension(g *localVarGenerator, stack *localDeclaredVars, v *ObjectComprehension, errs Errors) Errors { 3230 stack.Push() 3231 v.Body, errs = rewriteDeclaredVarsInBody(g, stack, nil, v.Body, errs) 3232 errs = rewriteDeclaredVarsInTermRecursive(g, stack, v.Key, errs) 3233 errs = rewriteDeclaredVarsInTermRecursive(g, stack, v.Value, errs) 3234 stack.Pop() 3235 return errs 3236} 3237 3238func rewriteDeclaredVar(g *localVarGenerator, stack *localDeclaredVars, v Var, occ varOccurrence) (gv Var, err error) { 3239 switch stack.Occurrence(v) { 3240 case seenVar: 3241 return gv, fmt.Errorf("var %v referenced above", v) 3242 case assignedVar: 3243 return gv, fmt.Errorf("var %v assigned above", v) 3244 case declaredVar: 3245 return gv, fmt.Errorf("var %v declared above", v) 3246 case argVar: 3247 return gv, fmt.Errorf("arg %v redeclared", v) 3248 } 3249 gv = g.Generate() 3250 stack.Insert(v, gv, occ) 3251 return 3252} 3253 3254// rewriteWithModifiersInBody will rewrite the body so that with modifiers do 3255// not contain terms that require evaluation as values. If this function 3256// encounters an invalid with modifier target then it will raise an error. 
3257func rewriteWithModifiersInBody(c *Compiler, f *equalityFactory, body Body) (Body, *Error) { 3258 var result Body 3259 for i := range body { 3260 exprs, err := rewriteWithModifier(c, f, body[i]) 3261 if err != nil { 3262 return nil, err 3263 } 3264 if len(exprs) > 0 { 3265 for _, expr := range exprs { 3266 result.Append(expr) 3267 } 3268 } else { 3269 result.Append(body[i]) 3270 } 3271 } 3272 return result, nil 3273} 3274 3275func rewriteWithModifier(c *Compiler, f *equalityFactory, expr *Expr) ([]*Expr, *Error) { 3276 3277 var result []*Expr 3278 for i := range expr.With { 3279 err := validateTarget(c, expr.With[i].Target) 3280 if err != nil { 3281 return nil, err 3282 } 3283 3284 if requiresEval(expr.With[i].Value) { 3285 eq := f.Generate(expr.With[i].Value) 3286 result = append(result, eq) 3287 expr.With[i].Value = eq.Operand(0) 3288 } 3289 } 3290 3291 // If any of the with modifiers in this expression were rewritten then result 3292 // will be non-empty. In this case, the expression will have been modified and 3293 // it should also be added to the result. 
3294 if len(result) > 0 { 3295 result = append(result, expr) 3296 } 3297 return result, nil 3298} 3299 3300func validateTarget(c *Compiler, term *Term) *Error { 3301 if !isInputRef(term) && !isDataRef(term) { 3302 return NewError(TypeErr, term.Location, "with keyword target must start with %v or %v", InputRootDocument, DefaultRootDocument) 3303 } 3304 3305 if isDataRef(term) { 3306 ref := term.Value.(Ref) 3307 node := c.RuleTree 3308 for i := 0; i < len(ref)-1; i++ { 3309 child := node.Child(ref[i].Value) 3310 if child == nil { 3311 break 3312 } else if len(child.Values) > 0 { 3313 return NewError(CompileErr, term.Loc(), "with keyword cannot partially replace virtual document(s)") 3314 } 3315 node = child 3316 } 3317 3318 if node != nil { 3319 if child := node.Child(ref[len(ref)-1].Value); child != nil { 3320 for _, value := range child.Values { 3321 if len(value.(*Rule).Head.Args) > 0 { 3322 return NewError(CompileErr, term.Loc(), "with keyword cannot replace functions") 3323 } 3324 } 3325 } 3326 } 3327 3328 } 3329 return nil 3330} 3331 3332func isInputRef(term *Term) bool { 3333 if ref, ok := term.Value.(Ref); ok { 3334 if ref.HasPrefix(InputRootRef) { 3335 return true 3336 } 3337 } 3338 return false 3339} 3340 3341func isDataRef(term *Term) bool { 3342 if ref, ok := term.Value.(Ref); ok { 3343 if ref.HasPrefix(DefaultRootRef) { 3344 return true 3345 } 3346 } 3347 return false 3348} 3349 3350func isVirtual(node *TreeNode, ref Ref) bool { 3351 for i := 0; i < len(ref); i++ { 3352 child := node.Child(ref[i].Value) 3353 if child == nil { 3354 return false 3355 } else if len(child.Values) > 0 { 3356 return true 3357 } 3358 node = child 3359 } 3360 return true 3361} 3362 3363func safetyErrorSlice(unsafe unsafeVars) (result Errors) { 3364 3365 if len(unsafe) == 0 { 3366 return 3367 } 3368 3369 for _, pair := range unsafe.Vars() { 3370 if !pair.Var.IsGenerated() { 3371 result = append(result, NewError(UnsafeVarErr, pair.Loc, "var %v is unsafe", pair.Var)) 3372 } 3373 } 
3374 3375 if len(result) > 0 { 3376 return 3377 } 3378 3379 // If the expression contains unsafe generated variables, report which 3380 // expressions are unsafe instead of the variables that are unsafe (since 3381 // the latter are not meaningful to the user.) 3382 pairs := unsafe.Slice() 3383 3384 sort.Slice(pairs, func(i, j int) bool { 3385 return pairs[i].Expr.Location.Compare(pairs[j].Expr.Location) < 0 3386 }) 3387 3388 // Report at most one error per generated variable. 3389 seen := NewVarSet() 3390 3391 for _, expr := range pairs { 3392 before := len(seen) 3393 for v := range expr.Vars { 3394 if v.IsGenerated() { 3395 seen.Add(v) 3396 } 3397 } 3398 if len(seen) > before { 3399 result = append(result, NewError(UnsafeVarErr, expr.Expr.Location, "expression is unsafe")) 3400 } 3401 } 3402 3403 return 3404} 3405 3406func checkUnsafeBuiltins(unsafeBuiltinsMap map[string]struct{}, node interface{}) Errors { 3407 errs := make(Errors, 0) 3408 WalkExprs(node, func(x *Expr) bool { 3409 if x.IsCall() { 3410 operator := x.Operator().String() 3411 if _, ok := unsafeBuiltinsMap[operator]; ok { 3412 errs = append(errs, NewError(TypeErr, x.Loc(), "unsafe built-in function calls in expression: %v", operator)) 3413 } 3414 } 3415 return false 3416 }) 3417 return errs 3418} 3419 3420func rewriteVarsInRef(vars ...map[Var]Var) func(Ref) Ref { 3421 return func(node Ref) Ref { 3422 i, _ := TransformVars(node, func(v Var) (Value, error) { 3423 for _, m := range vars { 3424 if u, ok := m[v]; ok { 3425 return u, nil 3426 } 3427 } 3428 return v, nil 3429 }) 3430 return i.(Ref) 3431 } 3432} 3433 3434func rewriteVarsNop(node Ref) Ref { 3435 return node 3436} 3437