1// Copyright (c) 2016, Daniel Martí <mvdan@mvdan.cc>
2// See LICENSE for licensing information
3
4package syntax
5
6import (
7	"fmt"
8	"io"
9	"reflect"
10)
11
12func walkStmts(stmts []*Stmt, last []Comment, f func(Node) bool) {
13	for _, s := range stmts {
14		Walk(s, f)
15	}
16	for _, c := range last {
17		Walk(&c, f)
18	}
19}
20
21func walkWords(words []*Word, f func(Node) bool) {
22	for _, w := range words {
23		Walk(w, f)
24	}
25}
26
27// Walk traverses a syntax tree in depth-first order: It starts by calling
28// f(node); node must not be nil. If f returns true, Walk invokes f
29// recursively for each of the non-nil children of node, followed by
30// f(nil).
31func Walk(node Node, f func(Node) bool) {
32	if !f(node) {
33		return
34	}
35
36	switch x := node.(type) {
37	case *File:
38		walkStmts(x.Stmts, x.Last, f)
39	case *Comment:
40	case *Stmt:
41		for _, c := range x.Comments {
42			if !x.End().After(c.Pos()) {
43				defer Walk(&c, f)
44				break
45			}
46			Walk(&c, f)
47		}
48		if x.Cmd != nil {
49			Walk(x.Cmd, f)
50		}
51		for _, r := range x.Redirs {
52			Walk(r, f)
53		}
54	case *Assign:
55		if x.Name != nil {
56			Walk(x.Name, f)
57		}
58		if x.Value != nil {
59			Walk(x.Value, f)
60		}
61		if x.Index != nil {
62			Walk(x.Index, f)
63		}
64		if x.Array != nil {
65			Walk(x.Array, f)
66		}
67	case *Redirect:
68		if x.N != nil {
69			Walk(x.N, f)
70		}
71		Walk(x.Word, f)
72		if x.Hdoc != nil {
73			Walk(x.Hdoc, f)
74		}
75	case *CallExpr:
76		for _, a := range x.Assigns {
77			Walk(a, f)
78		}
79		walkWords(x.Args, f)
80	case *Subshell:
81		walkStmts(x.Stmts, x.Last, f)
82	case *Block:
83		walkStmts(x.Stmts, x.Last, f)
84	case *IfClause:
85		walkStmts(x.Cond, x.CondLast, f)
86		walkStmts(x.Then, x.ThenLast, f)
87		if x.Else != nil {
88			Walk(x.Else, f)
89		}
90	case *WhileClause:
91		walkStmts(x.Cond, x.CondLast, f)
92		walkStmts(x.Do, x.DoLast, f)
93	case *ForClause:
94		Walk(x.Loop, f)
95		walkStmts(x.Do, x.DoLast, f)
96	case *WordIter:
97		Walk(x.Name, f)
98		walkWords(x.Items, f)
99	case *CStyleLoop:
100		if x.Init != nil {
101			Walk(x.Init, f)
102		}
103		if x.Cond != nil {
104			Walk(x.Cond, f)
105		}
106		if x.Post != nil {
107			Walk(x.Post, f)
108		}
109	case *BinaryCmd:
110		Walk(x.X, f)
111		Walk(x.Y, f)
112	case *FuncDecl:
113		Walk(x.Name, f)
114		Walk(x.Body, f)
115	case *Word:
116		for _, wp := range x.Parts {
117			Walk(wp, f)
118		}
119	case *Lit:
120	case *SglQuoted:
121	case *DblQuoted:
122		for _, wp := range x.Parts {
123			Walk(wp, f)
124		}
125	case *CmdSubst:
126		walkStmts(x.Stmts, x.Last, f)
127	case *ParamExp:
128		Walk(x.Param, f)
129		if x.Index != nil {
130			Walk(x.Index, f)
131		}
132		if x.Repl != nil {
133			if x.Repl.Orig != nil {
134				Walk(x.Repl.Orig, f)
135			}
136			if x.Repl.With != nil {
137				Walk(x.Repl.With, f)
138			}
139		}
140		if x.Exp != nil && x.Exp.Word != nil {
141			Walk(x.Exp.Word, f)
142		}
143	case *ArithmExp:
144		Walk(x.X, f)
145	case *ArithmCmd:
146		Walk(x.X, f)
147	case *BinaryArithm:
148		Walk(x.X, f)
149		Walk(x.Y, f)
150	case *BinaryTest:
151		Walk(x.X, f)
152		Walk(x.Y, f)
153	case *UnaryArithm:
154		Walk(x.X, f)
155	case *UnaryTest:
156		Walk(x.X, f)
157	case *ParenArithm:
158		Walk(x.X, f)
159	case *ParenTest:
160		Walk(x.X, f)
161	case *CaseClause:
162		Walk(x.Word, f)
163		for _, ci := range x.Items {
164			Walk(ci, f)
165		}
166		for _, c := range x.Last {
167			Walk(&c, f)
168		}
169	case *CaseItem:
170		for _, c := range x.Comments {
171			if c.Pos().After(x.Pos()) {
172				defer Walk(&c, f)
173				break
174			}
175			Walk(&c, f)
176		}
177		walkWords(x.Patterns, f)
178		walkStmts(x.Stmts, x.Last, f)
179	case *TestClause:
180		Walk(x.X, f)
181	case *DeclClause:
182		for _, a := range x.Args {
183			Walk(a, f)
184		}
185	case *ArrayExpr:
186		for _, el := range x.Elems {
187			Walk(el, f)
188		}
189		for _, c := range x.Last {
190			Walk(&c, f)
191		}
192	case *ArrayElem:
193		for _, c := range x.Comments {
194			if c.Pos().After(x.Pos()) {
195				defer Walk(&c, f)
196				break
197			}
198			Walk(&c, f)
199		}
200		if x.Index != nil {
201			Walk(x.Index, f)
202		}
203		if x.Value != nil {
204			Walk(x.Value, f)
205		}
206	case *ExtGlob:
207		Walk(x.Pattern, f)
208	case *ProcSubst:
209		walkStmts(x.Stmts, x.Last, f)
210	case *TimeClause:
211		if x.Stmt != nil {
212			Walk(x.Stmt, f)
213		}
214	case *CoprocClause:
215		if x.Name != nil {
216			Walk(x.Name, f)
217		}
218		Walk(x.Stmt, f)
219	case *LetClause:
220		for _, expr := range x.Exprs {
221			Walk(expr, f)
222		}
223	default:
224		panic(fmt.Sprintf("syntax.Walk: unexpected node type %T", x))
225	}
226
227	f(nil)
228}
229
230// DebugPrint prints the provided syntax tree, spanning multiple lines and with
231// indentation. Can be useful to investigate the content of a syntax tree.
232func DebugPrint(w io.Writer, node Node) error {
233	p := debugPrinter{out: w}
234	p.print(reflect.ValueOf(node))
235	return p.err
236}
237
// debugPrinter holds the state DebugPrint uses while it prints a value via
// reflection.
type debugPrinter struct {
	// out receives all printed text.
	out   io.Writer
	// level is the current nesting depth; each level adds one indentation
	// marker at the start of a new line.
	level int
	// err records the error returned by a write to out, if any.
	err   error
}
243
244func (p *debugPrinter) printf(format string, args ...interface{}) {
245	_, err := fmt.Fprintf(p.out, format, args...)
246	if err != nil && p.err == nil {
247		p.err = err
248	}
249}
250
251func (p *debugPrinter) newline() {
252	p.printf("\n")
253	for i := 0; i < p.level; i++ {
254		p.printf(".  ")
255	}
256}
257
258func (p *debugPrinter) print(x reflect.Value) {
259	switch x.Kind() {
260	case reflect.Interface:
261		if x.IsNil() {
262			p.printf("nil")
263			return
264		}
265		p.print(x.Elem())
266	case reflect.Ptr:
267		if x.IsNil() {
268			p.printf("nil")
269			return
270		}
271		p.printf("*")
272		p.print(x.Elem())
273	case reflect.Slice:
274		p.printf("%s (len = %d) {", x.Type(), x.Len())
275		if x.Len() > 0 {
276			p.level++
277			p.newline()
278			for i := 0; i < x.Len(); i++ {
279				p.printf("%d: ", i)
280				p.print(x.Index(i))
281				if i == x.Len()-1 {
282					p.level--
283				}
284				p.newline()
285			}
286		}
287		p.printf("}")
288
289	case reflect.Struct:
290		if v, ok := x.Interface().(Pos); ok {
291			p.printf("%v:%v", v.Line(), v.Col())
292			return
293		}
294		t := x.Type()
295		p.printf("%s {", t)
296		p.level++
297		p.newline()
298		for i := 0; i < t.NumField(); i++ {
299			p.printf("%s: ", t.Field(i).Name)
300			p.print(x.Field(i))
301			if i == x.NumField()-1 {
302				p.level--
303			}
304			p.newline()
305		}
306		p.printf("}")
307	default:
308		p.printf("%#v", x.Interface())
309	}
310}
311