1// Copyright 2016 The OPA Authors. All rights reserved. 2// Use of this source code is governed by an Apache2 3// license that can be found in the LICENSE file. 4 5package ast 6 7// Visitor defines the interface for iterating AST elements. The Visit function 8// can return a Visitor w which will be used to visit the children of the AST 9// element v. If the Visit function returns nil, the children will not be 10// visited. This is deprecated. 11type Visitor interface { 12 Visit(v interface{}) (w Visitor) 13} 14 15// BeforeAndAfterVisitor wraps Visitor to provide hooks for being called before 16// and after the AST has been visited. This is deprecated. 17type BeforeAndAfterVisitor interface { 18 Visitor 19 Before(x interface{}) 20 After(x interface{}) 21} 22 23// Walk iterates the AST by calling the Visit function on the Visitor 24// v for x before recursing. This is deprecated. 25func Walk(v Visitor, x interface{}) { 26 if bav, ok := v.(BeforeAndAfterVisitor); !ok { 27 walk(v, x) 28 } else { 29 bav.Before(x) 30 defer bav.After(x) 31 walk(bav, x) 32 } 33} 34 35// WalkBeforeAndAfter iterates the AST by calling the Visit function on the 36// Visitor v for x before recursing. This is deprecated. 37func WalkBeforeAndAfter(v BeforeAndAfterVisitor, x interface{}) { 38 Walk(v, x) 39} 40 41func walk(v Visitor, x interface{}) { 42 w := v.Visit(x) 43 if w == nil { 44 return 45 } 46 switch x := x.(type) { 47 case *Module: 48 Walk(w, x.Package) 49 for _, i := range x.Imports { 50 Walk(w, i) 51 } 52 for _, r := range x.Rules { 53 Walk(w, r) 54 } 55 for _, c := range x.Comments { 56 Walk(w, c) 57 } 58 case *Package: 59 Walk(w, x.Path) 60 case *Import: 61 Walk(w, x.Path) 62 Walk(w, x.Alias) 63 case *Rule: 64 Walk(w, x.Head) 65 Walk(w, x.Body) 66 if x.Else != nil { 67 Walk(w, x.Else) 68 } 69 case *Head: 70 Walk(w, x.Name) 71 Walk(w, x.Args) 72 if x.Key != nil { 73 Walk(w, x.Key) 74 } 75 if x.Value != nil { 76 Walk(w, x.Value) 77 } 78 case Body: 79 for _, e := range x { 80 Walk(w, e) 81 } 82 case Args: 83 for _, t := range x { 84 Walk(w, t) 85 } 86 case *Expr: 87 switch ts := x.Terms.(type) { 88 case *SomeDecl: 89 Walk(w, ts) 90 case []*Term: 91 for _, t := range ts { 92 Walk(w, t) 93 } 94 case *Term: 95 Walk(w, ts) 96 } 97 for i := range x.With { 98 Walk(w, x.With[i]) 99 } 100 case *With: 101 Walk(w, x.Target) 102 Walk(w, x.Value) 103 case *Term: 104 Walk(w, x.Value) 105 case Ref: 106 for _, t := range x { 107 Walk(w, t) 108 } 109 case Object: 110 x.Foreach(func(k, vv *Term) { 111 Walk(w, k) 112 Walk(w, vv) 113 }) 114 case Array: 115 for _, t := range x { 116 Walk(w, t) 117 } 118 case Set: 119 x.Foreach(func(t *Term) { 120 Walk(w, t) 121 }) 122 case *ArrayComprehension: 123 Walk(w, x.Term) 124 Walk(w, x.Body) 125 case *ObjectComprehension: 126 Walk(w, x.Key) 127 Walk(w, x.Value) 128 Walk(w, x.Body) 129 case *SetComprehension: 130 Walk(w, x.Term) 131 Walk(w, x.Body) 132 case Call: 133 for _, t := range x { 134 Walk(w, t) 135 } 136 } 137} 138 139// WalkVars calls the function f on all vars under x. If the function f 140// returns true, AST nodes under the last node will not be visited. 141func WalkVars(x interface{}, f func(Var) bool) { 142 vis := &GenericVisitor{func(x interface{}) bool { 143 if v, ok := x.(Var); ok { 144 return f(v) 145 } 146 return false 147 }} 148 vis.Walk(x) 149} 150 151// WalkClosures calls the function f on all closures under x. If the function f 152// returns true, AST nodes under the last node will not be visited. 153func WalkClosures(x interface{}, f func(interface{}) bool) { 154 vis := &GenericVisitor{func(x interface{}) bool { 155 switch x.(type) { 156 case *ArrayComprehension, *ObjectComprehension, *SetComprehension: 157 return f(x) 158 } 159 return false 160 }} 161 vis.Walk(x) 162} 163 164// WalkRefs calls the function f on all references under x. If the function f 165// returns true, AST nodes under the last node will not be visited. 166func WalkRefs(x interface{}, f func(Ref) bool) { 167 vis := &GenericVisitor{func(x interface{}) bool { 168 if r, ok := x.(Ref); ok { 169 return f(r) 170 } 171 return false 172 }} 173 vis.Walk(x) 174} 175 176// WalkTerms calls the function f on all terms under x. If the function f 177// returns true, AST nodes under the last node will not be visited. 178func WalkTerms(x interface{}, f func(*Term) bool) { 179 vis := &GenericVisitor{func(x interface{}) bool { 180 if term, ok := x.(*Term); ok { 181 return f(term) 182 } 183 return false 184 }} 185 vis.Walk(x) 186} 187 188// WalkWiths calls the function f on all with modifiers under x. If the function f 189// returns true, AST nodes under the last node will not be visited. 190func WalkWiths(x interface{}, f func(*With) bool) { 191 vis := &GenericVisitor{func(x interface{}) bool { 192 if w, ok := x.(*With); ok { 193 return f(w) 194 } 195 return false 196 }} 197 vis.Walk(x) 198} 199 200// WalkExprs calls the function f on all expressions under x. If the function f 201// returns true, AST nodes under the last node will not be visited. 202func WalkExprs(x interface{}, f func(*Expr) bool) { 203 vis := &GenericVisitor{func(x interface{}) bool { 204 if r, ok := x.(*Expr); ok { 205 return f(r) 206 } 207 return false 208 }} 209 vis.Walk(x) 210} 211 212// WalkBodies calls the function f on all bodies under x. If the function f 213// returns true, AST nodes under the last node will not be visited. 214func WalkBodies(x interface{}, f func(Body) bool) { 215 vis := &GenericVisitor{func(x interface{}) bool { 216 if b, ok := x.(Body); ok { 217 return f(b) 218 } 219 return false 220 }} 221 vis.Walk(x) 222} 223 224// WalkRules calls the function f on all rules under x. If the function f 225// returns true, AST nodes under the last node will not be visited. 226func WalkRules(x interface{}, f func(*Rule) bool) { 227 vis := &GenericVisitor{func(x interface{}) bool { 228 if r, ok := x.(*Rule); ok { 229 stop := f(r) 230 // NOTE(tsandall): since rules cannot be embedded inside of queries 231 // we can stop early if there is no else block. 232 if stop || r.Else == nil { 233 return true 234 } 235 } 236 return false 237 }} 238 vis.Walk(x) 239} 240 241// WalkNodes calls the function f on all nodes under x. If the function f 242// returns true, AST nodes under the last node will not be visited. 243func WalkNodes(x interface{}, f func(Node) bool) { 244 vis := &GenericVisitor{func(x interface{}) bool { 245 if n, ok := x.(Node); ok { 246 return f(n) 247 } 248 return false 249 }} 250 vis.Walk(x) 251} 252 253// GenericVisitor provides a utility to walk over AST nodes using a 254// closure. If the closure returns true, the visitor will not walk 255// over AST nodes under x. 256type GenericVisitor struct { 257 f func(x interface{}) bool 258} 259 260// NewGenericVisitor returns a new GenericVisitor that will invoke the function 261// f on AST nodes. 262func NewGenericVisitor(f func(x interface{}) bool) *GenericVisitor { 263 return &GenericVisitor{f} 264} 265 266// Walk iterates the AST by calling the function f on the 267// GenericVisitor before recursing. Contrary to the generic Walk, this 268// does not require allocating the visitor from heap. 269func (vis *GenericVisitor) Walk(x interface{}) { 270 if vis.f(x) { 271 return 272 } 273 274 switch x := x.(type) { 275 case *Module: 276 vis.Walk(x.Package) 277 for _, i := range x.Imports { 278 vis.Walk(i) 279 } 280 for _, r := range x.Rules { 281 vis.Walk(r) 282 } 283 for _, c := range x.Comments { 284 vis.Walk(c) 285 } 286 case *Package: 287 vis.Walk(x.Path) 288 case *Import: 289 vis.Walk(x.Path) 290 vis.Walk(x.Alias) 291 case *Rule: 292 vis.Walk(x.Head) 293 vis.Walk(x.Body) 294 if x.Else != nil { 295 vis.Walk(x.Else) 296 } 297 case *Head: 298 vis.Walk(x.Name) 299 vis.Walk(x.Args) 300 if x.Key != nil { 301 vis.Walk(x.Key) 302 } 303 if x.Value != nil { 304 vis.Walk(x.Value) 305 } 306 case Body: 307 for _, e := range x { 308 vis.Walk(e) 309 } 310 case Args: 311 for _, t := range x { 312 vis.Walk(t) 313 } 314 case *Expr: 315 switch ts := x.Terms.(type) { 316 case *SomeDecl: 317 vis.Walk(ts) 318 case []*Term: 319 for _, t := range ts { 320 vis.Walk(t) 321 } 322 case *Term: 323 vis.Walk(ts) 324 } 325 for i := range x.With { 326 vis.Walk(x.With[i]) 327 } 328 case *With: 329 vis.Walk(x.Target) 330 vis.Walk(x.Value) 331 case *Term: 332 vis.Walk(x.Value) 333 case Ref: 334 for _, t := range x { 335 vis.Walk(t) 336 } 337 case Object: 338 for _, k := range x.Keys() { 339 vis.Walk(k) 340 vis.Walk(x.Get(k)) 341 } 342 case Array: 343 for _, t := range x { 344 vis.Walk(t) 345 } 346 case Set: 347 for _, t := range x.Slice() { 348 vis.Walk(t) 349 } 350 case *ArrayComprehension: 351 vis.Walk(x.Term) 352 vis.Walk(x.Body) 353 case *ObjectComprehension: 354 vis.Walk(x.Key) 355 vis.Walk(x.Value) 356 vis.Walk(x.Body) 357 case *SetComprehension: 358 vis.Walk(x.Term) 359 vis.Walk(x.Body) 360 case Call: 361 for _, t := range x { 362 vis.Walk(t) 363 } 364 } 365} 366 367// BeforeAfterVisitor provides a utility to walk over AST nodes using 368// closures. If the before closure returns true, the visitor will not 369// walk over AST nodes under x. The after closure is invoked always 370// after visiting a node. 371type BeforeAfterVisitor struct { 372 before func(x interface{}) bool 373 after func(x interface{}) 374} 375 376// NewBeforeAfterVisitor returns a new BeforeAndAfterVisitor that 377// will invoke the functions before and after AST nodes. 378func NewBeforeAfterVisitor(before func(x interface{}) bool, after func(x interface{})) *BeforeAfterVisitor { 379 return &BeforeAfterVisitor{before, after} 380} 381 382// Walk iterates the AST by calling the functions on the 383// BeforeAndAfterVisitor before and after recursing. Contrary to the 384// generic Walk, this does not require allocating the visitor from 385// heap. 386func (vis *BeforeAfterVisitor) Walk(x interface{}) { 387 defer vis.after(x) 388 if vis.before(x) { 389 return 390 } 391 392 switch x := x.(type) { 393 case *Module: 394 vis.Walk(x.Package) 395 for _, i := range x.Imports { 396 vis.Walk(i) 397 } 398 for _, r := range x.Rules { 399 vis.Walk(r) 400 } 401 for _, c := range x.Comments { 402 vis.Walk(c) 403 } 404 case *Package: 405 vis.Walk(x.Path) 406 case *Import: 407 vis.Walk(x.Path) 408 vis.Walk(x.Alias) 409 case *Rule: 410 vis.Walk(x.Head) 411 vis.Walk(x.Body) 412 if x.Else != nil { 413 vis.Walk(x.Else) 414 } 415 case *Head: 416 vis.Walk(x.Name) 417 vis.Walk(x.Args) 418 if x.Key != nil { 419 vis.Walk(x.Key) 420 } 421 if x.Value != nil { 422 vis.Walk(x.Value) 423 } 424 case Body: 425 for _, e := range x { 426 vis.Walk(e) 427 } 428 case Args: 429 for _, t := range x { 430 vis.Walk(t) 431 } 432 case *Expr: 433 switch ts := x.Terms.(type) { 434 case *SomeDecl: 435 vis.Walk(ts) 436 case []*Term: 437 for _, t := range ts { 438 vis.Walk(t) 439 } 440 case *Term: 441 vis.Walk(ts) 442 } 443 for i := range x.With { 444 vis.Walk(x.With[i]) 445 } 446 case *With: 447 vis.Walk(x.Target) 448 vis.Walk(x.Value) 449 case *Term: 450 vis.Walk(x.Value) 451 case Ref: 452 for _, t := range x { 453 vis.Walk(t) 454 } 455 case Object: 456 for _, k := range x.Keys() { 457 vis.Walk(k) 458 vis.Walk(x.Get(k)) 459 } 460 case Array: 461 for _, t := range x { 462 vis.Walk(t) 463 } 464 case Set: 465 for _, t := range x.Slice() { 466 vis.Walk(t) 467 } 468 case *ArrayComprehension: 469 vis.Walk(x.Term) 470 vis.Walk(x.Body) 471 case *ObjectComprehension: 472 vis.Walk(x.Key) 473 vis.Walk(x.Value) 474 vis.Walk(x.Body) 475 case *SetComprehension: 476 vis.Walk(x.Term) 477 vis.Walk(x.Body) 478 case Call: 479 for _, t := range x { 480 vis.Walk(t) 481 } 482 } 483} 484 485// VarVisitor walks AST nodes under a given node and collects all encountered 486// variables. The collected variables can be controlled by specifying 487// VarVisitorParams when creating the visitor. 488type VarVisitor struct { 489 params VarVisitorParams 490 vars VarSet 491} 492 493// VarVisitorParams contains settings for a VarVisitor. 494type VarVisitorParams struct { 495 SkipRefHead bool 496 SkipRefCallHead bool 497 SkipObjectKeys bool 498 SkipClosures bool 499 SkipWithTarget bool 500 SkipSets bool 501} 502 503// NewVarVisitor returns a new VarVisitor object. 504func NewVarVisitor() *VarVisitor { 505 return &VarVisitor{ 506 vars: NewVarSet(), 507 } 508} 509 510// WithParams sets the parameters in params on vis. 511func (vis *VarVisitor) WithParams(params VarVisitorParams) *VarVisitor { 512 vis.params = params 513 return vis 514} 515 516// Vars returns a VarSet that contains collected vars. 517func (vis *VarVisitor) Vars() VarSet { 518 return vis.vars 519} 520 521func (vis *VarVisitor) visit(v interface{}) bool { 522 if vis.params.SkipObjectKeys { 523 if o, ok := v.(Object); ok { 524 for _, k := range o.Keys() { 525 vis.Walk(o.Get(k)) 526 } 527 return true 528 } 529 } 530 if vis.params.SkipRefHead { 531 if r, ok := v.(Ref); ok { 532 for _, t := range r[1:] { 533 vis.Walk(t) 534 } 535 return true 536 } 537 } 538 if vis.params.SkipClosures { 539 switch v.(type) { 540 case *ArrayComprehension, *ObjectComprehension, *SetComprehension: 541 return true 542 } 543 } 544 if vis.params.SkipWithTarget { 545 if v, ok := v.(*With); ok { 546 vis.Walk(v.Value) 547 return true 548 } 549 } 550 if vis.params.SkipSets { 551 if _, ok := v.(Set); ok { 552 return true 553 } 554 } 555 if vis.params.SkipRefCallHead { 556 switch v := v.(type) { 557 case *Expr: 558 if terms, ok := v.Terms.([]*Term); ok { 559 for _, t := range terms[0].Value.(Ref)[1:] { 560 vis.Walk(t) 561 } 562 for i := 1; i < len(terms); i++ { 563 vis.Walk(terms[i]) 564 } 565 for _, w := range v.With { 566 vis.Walk(w) 567 } 568 return true 569 } 570 case Call: 571 operator := v[0].Value.(Ref) 572 for i := 1; i < len(operator); i++ { 573 vis.Walk(operator[i]) 574 } 575 for i := 1; i < len(v); i++ { 576 vis.Walk(v[i]) 577 } 578 return true 579 } 580 } 581 if v, ok := v.(Var); ok { 582 vis.vars.Add(v) 583 } 584 return false 585} 586 587// Walk iterates the AST by calling the function f on the 588// GenericVisitor before recursing. Contrary to the generic Walk, this 589// does not require allocating the visitor from heap. 590func (vis *VarVisitor) Walk(x interface{}) { 591 if vis.visit(x) { 592 return 593 } 594 595 switch x := x.(type) { 596 case *Module: 597 vis.Walk(x.Package) 598 for _, i := range x.Imports { 599 vis.Walk(i) 600 } 601 for _, r := range x.Rules { 602 vis.Walk(r) 603 } 604 for _, c := range x.Comments { 605 vis.Walk(c) 606 } 607 case *Package: 608 vis.Walk(x.Path) 609 case *Import: 610 vis.Walk(x.Path) 611 vis.Walk(x.Alias) 612 case *Rule: 613 vis.Walk(x.Head) 614 vis.Walk(x.Body) 615 if x.Else != nil { 616 vis.Walk(x.Else) 617 } 618 case *Head: 619 vis.Walk(x.Name) 620 vis.Walk(x.Args) 621 if x.Key != nil { 622 vis.Walk(x.Key) 623 } 624 if x.Value != nil { 625 vis.Walk(x.Value) 626 } 627 case Body: 628 for _, e := range x { 629 vis.Walk(e) 630 } 631 case Args: 632 for _, t := range x { 633 vis.Walk(t) 634 } 635 case *Expr: 636 switch ts := x.Terms.(type) { 637 case *SomeDecl: 638 vis.Walk(ts) 639 case []*Term: 640 for _, t := range ts { 641 vis.Walk(t) 642 } 643 case *Term: 644 vis.Walk(ts) 645 } 646 for i := range x.With { 647 vis.Walk(x.With[i]) 648 } 649 case *With: 650 vis.Walk(x.Target) 651 vis.Walk(x.Value) 652 case *Term: 653 vis.Walk(x.Value) 654 case Ref: 655 for _, t := range x { 656 vis.Walk(t) 657 } 658 case Object: 659 for _, k := range x.Keys() { 660 vis.Walk(k) 661 vis.Walk(x.Get(k)) 662 } 663 case Array: 664 for _, t := range x { 665 vis.Walk(t) 666 } 667 case Set: 668 for _, t := range x.Slice() { 669 vis.Walk(t) 670 } 671 case *ArrayComprehension: 672 vis.Walk(x.Term) 673 vis.Walk(x.Body) 674 case *ObjectComprehension: 675 vis.Walk(x.Key) 676 vis.Walk(x.Value) 677 vis.Walk(x.Body) 678 case *SetComprehension: 679 vis.Walk(x.Term) 680 vis.Walk(x.Body) 681 case Call: 682 for _, t := range x { 683 vis.Walk(t) 684 } 685 } 686} 687