1// Copyright 2016 The OPA Authors.  All rights reserved.
2// Use of this source code is governed by an Apache2
3// license that can be found in the LICENSE file.
4
5package ast
6
// Visitor defines the interface for iterating AST elements. The Visit function
// can return a Visitor w which will be used to visit the children of the AST
// element v. If the Visit function returns nil, the children will not be
// visited.
//
// Deprecated: use GenericVisitor or another visitor implementation instead.
type Visitor interface {
	Visit(v interface{}) (w Visitor)
}
14
15// BeforeAndAfterVisitor wraps Visitor to provide hooks for being called before
16// and after the AST has been visited. This is deprecated.
17type BeforeAndAfterVisitor interface {
18	Visitor
19	Before(x interface{})
20	After(x interface{})
21}
22
23// Walk iterates the AST by calling the Visit function on the Visitor
24// v for x before recursing. This is deprecated.
25func Walk(v Visitor, x interface{}) {
26	if bav, ok := v.(BeforeAndAfterVisitor); !ok {
27		walk(v, x)
28	} else {
29		bav.Before(x)
30		defer bav.After(x)
31		walk(bav, x)
32	}
33}
34
35// WalkBeforeAndAfter iterates the AST by calling the Visit function on the
36// Visitor v for x before recursing. This is deprecated.
37func WalkBeforeAndAfter(v BeforeAndAfterVisitor, x interface{}) {
38	Walk(v, x)
39}
40
// walk is the traversal behind the deprecated Walk. It calls v.Visit(x) and,
// if that returns a non-nil Visitor w, visits each child of x through
// Walk(w, child), so Before/After hooks apply at every level whenever w
// implements BeforeAndAfterVisitor. A nil result from Visit prunes x's
// children. Values whose type is not listed below have no children.
//
// NOTE(review): unlike GenericVisitor.Walk, objects and sets are traversed
// with Foreach rather than Keys()/Slice(); confirm the two iteration orders
// agree before assuming identical visit order across the walkers.
func walk(v Visitor, x interface{}) {
	w := v.Visit(x)
	if w == nil {
		return
	}
	switch x := x.(type) {
	case *Module:
		Walk(w, x.Package)
		for _, i := range x.Imports {
			Walk(w, i)
		}
		for _, r := range x.Rules {
			Walk(w, r)
		}
		for _, c := range x.Comments {
			Walk(w, c)
		}
	case *Package:
		Walk(w, x.Path)
	case *Import:
		Walk(w, x.Path)
		Walk(w, x.Alias)
	case *Rule:
		Walk(w, x.Head)
		Walk(w, x.Body)
		if x.Else != nil {
			Walk(w, x.Else)
		}
	case *Head:
		Walk(w, x.Name)
		Walk(w, x.Args)
		// Key and Value are optional and only visited when set.
		if x.Key != nil {
			Walk(w, x.Key)
		}
		if x.Value != nil {
			Walk(w, x.Value)
		}
	case Body:
		for _, e := range x {
			Walk(w, e)
		}
	case Args:
		for _, t := range x {
			Walk(w, t)
		}
	case *Expr:
		// Terms holds a *SomeDecl, a []*Term or a single *Term; any other
		// dynamic type contributes no children. With modifiers come last.
		switch ts := x.Terms.(type) {
		case *SomeDecl:
			Walk(w, ts)
		case []*Term:
			for _, t := range ts {
				Walk(w, t)
			}
		case *Term:
			Walk(w, ts)
		}
		for i := range x.With {
			Walk(w, x.With[i])
		}
	case *With:
		Walk(w, x.Target)
		Walk(w, x.Value)
	case *Term:
		Walk(w, x.Value)
	case Ref:
		for _, t := range x {
			Walk(w, t)
		}
	case Object:
		// Each key is visited immediately before its associated value.
		x.Foreach(func(k, vv *Term) {
			Walk(w, k)
			Walk(w, vv)
		})
	case Array:
		for _, t := range x {
			Walk(w, t)
		}
	case Set:
		x.Foreach(func(t *Term) {
			Walk(w, t)
		})
	case *ArrayComprehension:
		Walk(w, x.Term)
		Walk(w, x.Body)
	case *ObjectComprehension:
		Walk(w, x.Key)
		Walk(w, x.Value)
		Walk(w, x.Body)
	case *SetComprehension:
		Walk(w, x.Term)
		Walk(w, x.Body)
	case Call:
		for _, t := range x {
			Walk(w, t)
		}
	}
}
138
139// WalkVars calls the function f on all vars under x. If the function f
140// returns true, AST nodes under the last node will not be visited.
141func WalkVars(x interface{}, f func(Var) bool) {
142	vis := &GenericVisitor{func(x interface{}) bool {
143		if v, ok := x.(Var); ok {
144			return f(v)
145		}
146		return false
147	}}
148	vis.Walk(x)
149}
150
151// WalkClosures calls the function f on all closures under x. If the function f
152// returns true, AST nodes under the last node will not be visited.
153func WalkClosures(x interface{}, f func(interface{}) bool) {
154	vis := &GenericVisitor{func(x interface{}) bool {
155		switch x.(type) {
156		case *ArrayComprehension, *ObjectComprehension, *SetComprehension:
157			return f(x)
158		}
159		return false
160	}}
161	vis.Walk(x)
162}
163
164// WalkRefs calls the function f on all references under x. If the function f
165// returns true, AST nodes under the last node will not be visited.
166func WalkRefs(x interface{}, f func(Ref) bool) {
167	vis := &GenericVisitor{func(x interface{}) bool {
168		if r, ok := x.(Ref); ok {
169			return f(r)
170		}
171		return false
172	}}
173	vis.Walk(x)
174}
175
176// WalkTerms calls the function f on all terms under x. If the function f
177// returns true, AST nodes under the last node will not be visited.
178func WalkTerms(x interface{}, f func(*Term) bool) {
179	vis := &GenericVisitor{func(x interface{}) bool {
180		if term, ok := x.(*Term); ok {
181			return f(term)
182		}
183		return false
184	}}
185	vis.Walk(x)
186}
187
188// WalkWiths calls the function f on all with modifiers under x. If the function f
189// returns true, AST nodes under the last node will not be visited.
190func WalkWiths(x interface{}, f func(*With) bool) {
191	vis := &GenericVisitor{func(x interface{}) bool {
192		if w, ok := x.(*With); ok {
193			return f(w)
194		}
195		return false
196	}}
197	vis.Walk(x)
198}
199
200// WalkExprs calls the function f on all expressions under x. If the function f
201// returns true, AST nodes under the last node will not be visited.
202func WalkExprs(x interface{}, f func(*Expr) bool) {
203	vis := &GenericVisitor{func(x interface{}) bool {
204		if r, ok := x.(*Expr); ok {
205			return f(r)
206		}
207		return false
208	}}
209	vis.Walk(x)
210}
211
212// WalkBodies calls the function f on all bodies under x. If the function f
213// returns true, AST nodes under the last node will not be visited.
214func WalkBodies(x interface{}, f func(Body) bool) {
215	vis := &GenericVisitor{func(x interface{}) bool {
216		if b, ok := x.(Body); ok {
217			return f(b)
218		}
219		return false
220	}}
221	vis.Walk(x)
222}
223
224// WalkRules calls the function f on all rules under x. If the function f
225// returns true, AST nodes under the last node will not be visited.
226func WalkRules(x interface{}, f func(*Rule) bool) {
227	vis := &GenericVisitor{func(x interface{}) bool {
228		if r, ok := x.(*Rule); ok {
229			stop := f(r)
230			// NOTE(tsandall): since rules cannot be embedded inside of queries
231			// we can stop early if there is no else block.
232			if stop || r.Else == nil {
233				return true
234			}
235		}
236		return false
237	}}
238	vis.Walk(x)
239}
240
241// WalkNodes calls the function f on all nodes under x. If the function f
242// returns true, AST nodes under the last node will not be visited.
243func WalkNodes(x interface{}, f func(Node) bool) {
244	vis := &GenericVisitor{func(x interface{}) bool {
245		if n, ok := x.(Node); ok {
246			return f(n)
247		}
248		return false
249	}}
250	vis.Walk(x)
251}
252
// GenericVisitor provides a utility to walk over AST nodes using a
// closure. Walk calls the closure on every node before its children; if
// the closure returns true for a node x, the visitor will not walk over
// the AST nodes under x.
type GenericVisitor struct {
	// f is invoked on each visited node; a true result prunes its children.
	f func(x interface{}) bool
}
259
260// NewGenericVisitor returns a new GenericVisitor that will invoke the function
261// f on AST nodes.
262func NewGenericVisitor(f func(x interface{}) bool) *GenericVisitor {
263	return &GenericVisitor{f}
264}
265
// Walk iterates the AST rooted at x in pre-order. The visitor's closure is
// called on each node before any of that node's children, and a true result
// prunes those children. Values whose type is not listed below are passed to
// the closure but have no children to descend into. Contrary to the
// deprecated package-level Walk, this does not require a Visitor value to be
// returned (and possibly heap-allocated) for every node.
func (vis *GenericVisitor) Walk(x interface{}) {
	if vis.f(x) {
		return
	}

	switch x := x.(type) {
	case *Module:
		vis.Walk(x.Package)
		for _, i := range x.Imports {
			vis.Walk(i)
		}
		for _, r := range x.Rules {
			vis.Walk(r)
		}
		for _, c := range x.Comments {
			vis.Walk(c)
		}
	case *Package:
		vis.Walk(x.Path)
	case *Import:
		vis.Walk(x.Path)
		vis.Walk(x.Alias)
	case *Rule:
		vis.Walk(x.Head)
		vis.Walk(x.Body)
		if x.Else != nil {
			vis.Walk(x.Else)
		}
	case *Head:
		vis.Walk(x.Name)
		vis.Walk(x.Args)
		// Key and Value are optional and only visited when set.
		if x.Key != nil {
			vis.Walk(x.Key)
		}
		if x.Value != nil {
			vis.Walk(x.Value)
		}
	case Body:
		for _, e := range x {
			vis.Walk(e)
		}
	case Args:
		for _, t := range x {
			vis.Walk(t)
		}
	case *Expr:
		// Terms holds a *SomeDecl, a []*Term or a single *Term; any other
		// dynamic type contributes no children. With modifiers come last.
		switch ts := x.Terms.(type) {
		case *SomeDecl:
			vis.Walk(ts)
		case []*Term:
			for _, t := range ts {
				vis.Walk(t)
			}
		case *Term:
			vis.Walk(ts)
		}
		for i := range x.With {
			vis.Walk(x.With[i])
		}
	case *With:
		vis.Walk(x.Target)
		vis.Walk(x.Value)
	case *Term:
		vis.Walk(x.Value)
	case Ref:
		for _, t := range x {
			vis.Walk(t)
		}
	case Object:
		// Each key is visited immediately before its associated value.
		for _, k := range x.Keys() {
			vis.Walk(k)
			vis.Walk(x.Get(k))
		}
	case Array:
		for _, t := range x {
			vis.Walk(t)
		}
	case Set:
		for _, t := range x.Slice() {
			vis.Walk(t)
		}
	case *ArrayComprehension:
		vis.Walk(x.Term)
		vis.Walk(x.Body)
	case *ObjectComprehension:
		vis.Walk(x.Key)
		vis.Walk(x.Value)
		vis.Walk(x.Body)
	case *SetComprehension:
		vis.Walk(x.Term)
		vis.Walk(x.Body)
	case Call:
		for _, t := range x {
			vis.Walk(t)
		}
	}
}
366
// BeforeAfterVisitor provides a utility to walk over AST nodes using
// closures. If the before closure returns true, the visitor will not
// walk over AST nodes under x. The after closure is always invoked
// once a node has been handled, including when before pruned it.
// Walk calls both closures unconditionally, so neither may be nil.
type BeforeAfterVisitor struct {
	// before is called on each node ahead of its children; true prunes them.
	before func(x interface{}) bool
	// after is deferred at the start of each node's Walk call.
	after  func(x interface{})
}
375
376// NewBeforeAfterVisitor returns a new BeforeAndAfterVisitor that
377// will invoke the functions before and after AST nodes.
378func NewBeforeAfterVisitor(before func(x interface{}) bool, after func(x interface{})) *BeforeAfterVisitor {
379	return &BeforeAfterVisitor{before, after}
380}
381
// Walk iterates the AST rooted at x. It calls vis.before on each node before
// that node's children, and calls vis.after once the node has been handled.
// If before returns true, the node's children are skipped, but after is still
// invoked for the node because it is deferred on entry. Contrary to the
// deprecated package-level Walk, this does not require a Visitor value to be
// returned (and possibly heap-allocated) for every node.
func (vis *BeforeAfterVisitor) Walk(x interface{}) {
	defer vis.after(x)
	if vis.before(x) {
		return
	}

	switch x := x.(type) {
	case *Module:
		vis.Walk(x.Package)
		for _, i := range x.Imports {
			vis.Walk(i)
		}
		for _, r := range x.Rules {
			vis.Walk(r)
		}
		for _, c := range x.Comments {
			vis.Walk(c)
		}
	case *Package:
		vis.Walk(x.Path)
	case *Import:
		vis.Walk(x.Path)
		vis.Walk(x.Alias)
	case *Rule:
		vis.Walk(x.Head)
		vis.Walk(x.Body)
		if x.Else != nil {
			vis.Walk(x.Else)
		}
	case *Head:
		vis.Walk(x.Name)
		vis.Walk(x.Args)
		// Key and Value are optional and only visited when set.
		if x.Key != nil {
			vis.Walk(x.Key)
		}
		if x.Value != nil {
			vis.Walk(x.Value)
		}
	case Body:
		for _, e := range x {
			vis.Walk(e)
		}
	case Args:
		for _, t := range x {
			vis.Walk(t)
		}
	case *Expr:
		// Terms holds a *SomeDecl, a []*Term or a single *Term; any other
		// dynamic type contributes no children. With modifiers come last.
		switch ts := x.Terms.(type) {
		case *SomeDecl:
			vis.Walk(ts)
		case []*Term:
			for _, t := range ts {
				vis.Walk(t)
			}
		case *Term:
			vis.Walk(ts)
		}
		for i := range x.With {
			vis.Walk(x.With[i])
		}
	case *With:
		vis.Walk(x.Target)
		vis.Walk(x.Value)
	case *Term:
		vis.Walk(x.Value)
	case Ref:
		for _, t := range x {
			vis.Walk(t)
		}
	case Object:
		// Each key is visited immediately before its associated value.
		for _, k := range x.Keys() {
			vis.Walk(k)
			vis.Walk(x.Get(k))
		}
	case Array:
		for _, t := range x {
			vis.Walk(t)
		}
	case Set:
		for _, t := range x.Slice() {
			vis.Walk(t)
		}
	case *ArrayComprehension:
		vis.Walk(x.Term)
		vis.Walk(x.Body)
	case *ObjectComprehension:
		vis.Walk(x.Key)
		vis.Walk(x.Value)
		vis.Walk(x.Body)
	case *SetComprehension:
		vis.Walk(x.Term)
		vis.Walk(x.Body)
	case Call:
		for _, t := range x {
			vis.Walk(t)
		}
	}
}
484
// VarVisitor walks AST nodes under a given node and collects all encountered
// variables. Which variables are collected can be controlled by passing
// VarVisitorParams to WithParams after creating the visitor with
// NewVarVisitor.
type VarVisitor struct {
	// params holds the skip rules consulted by visit for every node.
	params VarVisitorParams
	// vars accumulates each Var recorded during Walk; Vars returns it.
	vars   VarSet
}
492
// VarVisitorParams contains settings for a VarVisitor. All flags default to
// false, in which case every Var reached by VarVisitor.Walk is collected.
// Setting a flag changes how a specific kind of node is handled:
//
//   - SkipRefHead: for a Ref, only the terms after the first are walked, so
//     the ref's head term is not visited.
//   - SkipRefCallHead: for a call (an *Expr whose Terms is a []*Term, or a
//     Call value), the first term of the operator Ref is not visited. The
//     rest of the operator, the operands, and (for an *Expr) its with
//     modifiers are still walked.
//   - SkipObjectKeys: for an Object, only the values are walked.
//   - SkipClosures: array, object and set comprehensions are not walked.
//   - SkipWithTarget: for a *With, only its Value is walked.
//   - SkipSets: Set values are not walked.
//
// NOTE(review): SkipRefCallHead type-asserts the call operator to Ref without
// a check, so it panics if a call's first term is not a Ref.
type VarVisitorParams struct {
	SkipRefHead     bool
	SkipRefCallHead bool
	SkipObjectKeys  bool
	SkipClosures    bool
	SkipWithTarget  bool
	SkipSets        bool
}
502
503// NewVarVisitor returns a new VarVisitor object.
504func NewVarVisitor() *VarVisitor {
505	return &VarVisitor{
506		vars: NewVarSet(),
507	}
508}
509
// WithParams sets the parameters in params on vis, replacing any previously
// configured parameters, and returns vis so that the call can be chained
// (e.g. NewVarVisitor().WithParams(p)).
func (vis *VarVisitor) WithParams(params VarVisitorParams) *VarVisitor {
	vis.params = params
	return vis
}
515
// Vars returns the VarSet holding the vars collected so far. The visitor's
// own set is returned, not a copy.
// NOTE(review): if VarSet is a map type (it is not visible here), mutations
// by the caller are shared with the visitor. Confirm before relying on it.
func (vis *VarVisitor) Vars() VarSet {
	return vis.vars
}
520
521func (vis *VarVisitor) visit(v interface{}) bool {
522	if vis.params.SkipObjectKeys {
523		if o, ok := v.(Object); ok {
524			for _, k := range o.Keys() {
525				vis.Walk(o.Get(k))
526			}
527			return true
528		}
529	}
530	if vis.params.SkipRefHead {
531		if r, ok := v.(Ref); ok {
532			for _, t := range r[1:] {
533				vis.Walk(t)
534			}
535			return true
536		}
537	}
538	if vis.params.SkipClosures {
539		switch v.(type) {
540		case *ArrayComprehension, *ObjectComprehension, *SetComprehension:
541			return true
542		}
543	}
544	if vis.params.SkipWithTarget {
545		if v, ok := v.(*With); ok {
546			vis.Walk(v.Value)
547			return true
548		}
549	}
550	if vis.params.SkipSets {
551		if _, ok := v.(Set); ok {
552			return true
553		}
554	}
555	if vis.params.SkipRefCallHead {
556		switch v := v.(type) {
557		case *Expr:
558			if terms, ok := v.Terms.([]*Term); ok {
559				for _, t := range terms[0].Value.(Ref)[1:] {
560					vis.Walk(t)
561				}
562				for i := 1; i < len(terms); i++ {
563					vis.Walk(terms[i])
564				}
565				for _, w := range v.With {
566					vis.Walk(w)
567				}
568				return true
569			}
570		case Call:
571			operator := v[0].Value.(Ref)
572			for i := 1; i < len(operator); i++ {
573				vis.Walk(operator[i])
574			}
575			for i := 1; i < len(v); i++ {
576				vis.Walk(v[i])
577			}
578			return true
579		}
580	}
581	if v, ok := v.(Var); ok {
582		vis.vars.Add(v)
583	}
584	return false
585}
586
// Walk collects vars from the AST rooted at x into the visitor's VarSet. Each
// node is first handed to visit, which applies the VarVisitorParams skip
// rules and records Var nodes. When visit reports that it has already handled
// (or pruned) a node, Walk does not descend into that node's children itself.
// Otherwise the children are walked in the same order as GenericVisitor.Walk.
func (vis *VarVisitor) Walk(x interface{}) {
	if vis.visit(x) {
		return
	}

	switch x := x.(type) {
	case *Module:
		vis.Walk(x.Package)
		for _, i := range x.Imports {
			vis.Walk(i)
		}
		for _, r := range x.Rules {
			vis.Walk(r)
		}
		for _, c := range x.Comments {
			vis.Walk(c)
		}
	case *Package:
		vis.Walk(x.Path)
	case *Import:
		vis.Walk(x.Path)
		vis.Walk(x.Alias)
	case *Rule:
		vis.Walk(x.Head)
		vis.Walk(x.Body)
		if x.Else != nil {
			vis.Walk(x.Else)
		}
	case *Head:
		vis.Walk(x.Name)
		vis.Walk(x.Args)
		// Key and Value are optional and only visited when set.
		if x.Key != nil {
			vis.Walk(x.Key)
		}
		if x.Value != nil {
			vis.Walk(x.Value)
		}
	case Body:
		for _, e := range x {
			vis.Walk(e)
		}
	case Args:
		for _, t := range x {
			vis.Walk(t)
		}
	case *Expr:
		// Terms holds a *SomeDecl, a []*Term or a single *Term; any other
		// dynamic type contributes no children. With modifiers come last.
		switch ts := x.Terms.(type) {
		case *SomeDecl:
			vis.Walk(ts)
		case []*Term:
			for _, t := range ts {
				vis.Walk(t)
			}
		case *Term:
			vis.Walk(ts)
		}
		for i := range x.With {
			vis.Walk(x.With[i])
		}
	case *With:
		vis.Walk(x.Target)
		vis.Walk(x.Value)
	case *Term:
		vis.Walk(x.Value)
	case Ref:
		for _, t := range x {
			vis.Walk(t)
		}
	case Object:
		// Each key is visited immediately before its associated value.
		for _, k := range x.Keys() {
			vis.Walk(k)
			vis.Walk(x.Get(k))
		}
	case Array:
		for _, t := range x {
			vis.Walk(t)
		}
	case Set:
		for _, t := range x.Slice() {
			vis.Walk(t)
		}
	case *ArrayComprehension:
		vis.Walk(x.Term)
		vis.Walk(x.Body)
	case *ObjectComprehension:
		vis.Walk(x.Key)
		vis.Walk(x.Value)
		vis.Walk(x.Body)
	case *SetComprehension:
		vis.Walk(x.Term)
		vis.Walk(x.Body)
	case Call:
		for _, t := range x {
			vis.Walk(t)
		}
	}
}
687