1// Copyright (c) 2017, Daniel Martí <mvdan@mvdan.cc>
2// See LICENSE for licensing information
3
4package gogrep
5
6import (
7	"bytes"
8	"fmt"
9	"go/ast"
10	"go/parser"
11	"go/scanner"
12	"go/token"
13	"strings"
14	"text/template"
15)
16
17func transformSource(expr string) (string, []posOffset, error) {
18	toks, err := tokenize([]byte(expr))
19	if err != nil {
20		return "", nil, fmt.Errorf("cannot tokenize expr: %v", err)
21	}
22	var offs []posOffset
23	lbuf := lineColBuffer{line: 1, col: 1}
24	lastLit := false
25	for _, t := range toks {
26		if lbuf.offs >= t.pos.Offset && lastLit && t.lit != "" {
27			_, _ = lbuf.WriteString(" ")
28		}
29		for lbuf.offs < t.pos.Offset {
30			_, _ = lbuf.WriteString(" ")
31		}
32		if t.lit == "" {
33			_, _ = lbuf.WriteString(t.tok.String())
34			lastLit = false
35			continue
36		}
37		_, _ = lbuf.WriteString(t.lit)
38		lastLit = strings.TrimSpace(t.lit) != ""
39	}
40	// trailing newlines can cause issues with commas
41	return strings.TrimSpace(lbuf.String()), offs, nil
42}
43
44func parseExpr(fset *token.FileSet, expr string) (ast.Node, error) {
45	exprStr, offs, err := transformSource(expr)
46	if err != nil {
47		return nil, err
48	}
49	node, _, err := parseDetectingNode(fset, exprStr)
50	if err != nil {
51		err = subPosOffsets(err, offs...)
52		return nil, fmt.Errorf("cannot parse expr: %v", err)
53	}
54	return node, nil
55}
56
// lineColBuffer is a bytes.Buffer that also tracks the current line,
// column, and offset of the text written so far, so that output can be
// aligned with positions in the original source.
type lineColBuffer struct {
	bytes.Buffer
	line, col, offs int
}
61
62func (l *lineColBuffer) WriteString(s string) (n int, err error) {
63	for _, r := range s {
64		if r == '\n' {
65			l.line++
66			l.col = 1
67		} else {
68			l.col++
69		}
70		l.offs++
71	}
72	return l.Buffer.WriteString(s)
73}
74
// The templates below wrap a pattern into a complete Go file so that
// go/parser can process it; parseDetectingNode tries them in turn.

// tmplDecl parses the pattern as one or more top-level declarations.
var tmplDecl = template.Must(template.New("").Parse(`` +
	`package p; {{ . }}`))

// tmplBlock parses the pattern as a block statement, spliced in as an
// if statement's body.
var tmplBlock = template.Must(template.New("").Parse(`` +
	`package p; func _() { if true {{ . }} else {} }`))

// tmplExprs parses the pattern as one or more comma-separated value
// expressions, via a composite literal.
var tmplExprs = template.Must(template.New("").Parse(`` +
	`package p; var _ = []interface{}{ {{ . }}, }`))

// tmplStmts parses the pattern as one or more statements in a function body.
var tmplStmts = template.Must(template.New("").Parse(`` +
	`package p; func _() { {{ . }} }`))

// tmplType parses the pattern as a type expression.
var tmplType = template.Must(template.New("").Parse(`` +
	`package p; var _ {{ . }}`))

// tmplValSpec parses the pattern as a var spec (e.g. "x int = 3").
var tmplValSpec = template.Must(template.New("").Parse(`` +
	`package p; var {{ . }}`))
92
93func execTmpl(tmpl *template.Template, src string) string {
94	var buf bytes.Buffer
95	if err := tmpl.Execute(&buf, src); err != nil {
96		panic(err)
97	}
98	return buf.String()
99}
100
101func noBadNodes(node ast.Node) bool {
102	any := false
103	ast.Inspect(node, func(n ast.Node) bool {
104		if any {
105			return false
106		}
107		switch n.(type) {
108		case *ast.BadExpr, *ast.BadDecl:
109			any = true
110		}
111		return true
112	})
113	return !any
114}
115
116func parseType(fset *token.FileSet, src string) (ast.Expr, *ast.File, error) {
117	asType := execTmpl(tmplType, src)
118	f, err := parser.ParseFile(fset, "", asType, 0)
119	if err != nil {
120		err = subPosOffsets(err, posOffset{1, 1, 17})
121		return nil, nil, err
122	}
123	vs := f.Decls[0].(*ast.GenDecl).Specs[0].(*ast.ValueSpec)
124	return vs.Type, f, nil
125}
126
// parseDetectingNode tries its best to parse the ast.Node contained in src, as
// one of: *ast.File, ast.Decl, ast.Expr, ast.Stmt, *ast.ValueSpec.
// It also returns the *ast.File used for the parsing, so that the returned node
// can be easily type-checked.
//
// It works by attempting a sequence of parses, each wrapping src in a
// template that makes it a valid Go file for one candidate node kind,
// and returning the first attempt that parses with no errors and no
// Bad* nodes. The order of attempts matters; see the comments below.
func parseDetectingNode(fset *token.FileSet, src string) (ast.Node, *ast.File, error) {
	// Reject empty input up front: a single scan for the first token is
	// much cheaper than running every parse attempt below.
	file := fset.AddFile("", fset.Base(), len(src))
	scan := scanner.Scanner{}
	scan.Init(file, []byte(src), nil, 0)
	if _, tok, _ := scan.Scan(); tok == token.EOF {
		return nil, nil, fmt.Errorf("empty source code")
	}
	var mainErr error

	// first try as a whole file
	if f, err := parser.ParseFile(fset, "", src, 0); err == nil && noBadNodes(f) {
		return f, f, nil
	}

	// then as a single declaration, or many
	asDecl := execTmpl(tmplDecl, src)
	if f, err := parser.ParseFile(fset, "", asDecl, 0); err == nil && noBadNodes(f) {
		if len(f.Decls) == 1 {
			return f.Decls[0], f, nil
		}
		return f, f, nil
	}

	// then as a block; otherwise blocks might be mistaken for composite
	// literals further below
	asBlock := execTmpl(tmplBlock, src)
	if f, err := parser.ParseFile(fset, "", asBlock, 0); err == nil && noBadNodes(f) {
		// tmplBlock splices src in as the if statement's body block.
		bl := f.Decls[0].(*ast.FuncDecl).Body
		if len(bl.List) == 1 {
			ifs := bl.List[0].(*ast.IfStmt)
			return ifs.Body, f, nil
		}
	}

	// then as value expressions
	asExprs := execTmpl(tmplExprs, src)
	if f, err := parser.ParseFile(fset, "", asExprs, 0); err == nil && noBadNodes(f) {
		// tmplExprs splices src in as composite literal elements.
		vs := f.Decls[0].(*ast.GenDecl).Specs[0].(*ast.ValueSpec)
		cl := vs.Values[0].(*ast.CompositeLit)
		if len(cl.Elts) == 1 {
			return cl.Elts[0], f, nil
		}
		return exprSlice(cl.Elts), f, nil
	}

	// then try as statements
	asStmts := execTmpl(tmplStmts, src)
	f, err := parser.ParseFile(fset, "", asStmts, 0)
	if err == nil && noBadNodes(f) {
		bl := f.Decls[0].(*ast.FuncDecl).Body
		if len(bl.List) == 1 {
			return bl.List[0], f, nil
		}
		return stmtSlice(bl.List), f, nil
	}
	// Statements is what covers most cases, so it will give
	// the best overall error message. Show positions
	// relative to where the user's code is put in the
	// template.
	// 22 == len("package p; func _() { "), the prefix of tmplStmts.
	mainErr = subPosOffsets(err, posOffset{1, 1, 22})

	// type expressions not yet picked up, for e.g. chans and interfaces
	if typ, f, err := parseType(fset, src); err == nil && noBadNodes(f) {
		return typ, f, nil
	}

	// value specs
	asValSpec := execTmpl(tmplValSpec, src)
	if f, err := parser.ParseFile(fset, "", asValSpec, 0); err == nil && noBadNodes(f) {
		vs := f.Decls[0].(*ast.GenDecl).Specs[0].(*ast.ValueSpec)
		return vs, f, nil
	}
	return nil, nil, mainErr
}
205
// posOffset describes a column adjustment for parse errors: any error
// reported at line atLine with a column at or past atCol has its column
// reduced by offset. It maps error positions in the templated source
// back to positions in the user's original pattern.
type posOffset struct {
	atLine, atCol int
	offset        int
}
210
211func subPosOffsets(err error, offs ...posOffset) error {
212	list, ok := err.(scanner.ErrorList)
213	if !ok {
214		return err
215	}
216	for i, err := range list {
217		for _, off := range offs {
218			if err.Pos.Line != off.atLine {
219				continue
220			}
221			if err.Pos.Column < off.atCol {
222				continue
223			}
224			err.Pos.Column -= off.offset
225		}
226		list[i] = err
227	}
228	return list
229}
230
// fullToken is a scanned token together with its resolved position and,
// for token kinds that carry one, its literal text.
type fullToken struct {
	pos token.Position
	tok token.Token
	lit string
}
236
// caseStatus tracks whether the tokenizer is positioned inside a switch
// or select block, where a '$' wildcard expands to a whole case clause.
type caseStatus uint

const (
	caseNone caseStatus = iota // not inside a switch/select, or after a user-written "case"
	caseNeedBlock // saw "switch" or "select"; waiting for its '{'
	caseHere // directly inside the block; wildcards become case clauses
)
244
// tokenize scans src into a list of tokens with resolved positions,
// accepting two characters beyond regular Go — '$' (wildcards) and '~'
// — and rewriting '$' wildcards into encoded identifiers so that the
// result can later be printed and parsed by go/parser.
func tokenize(src []byte) ([]fullToken, error) {
	var s scanner.Scanner
	fset := token.NewFileSet()
	file := fset.AddFile("", fset.Base(), len(src))

	// Remember the first scan error for anything other than the two
	// allowed extra characters; it is returned after the loop.
	var err error
	onError := func(pos token.Position, msg string) {
		switch msg { // allow certain extra chars
		case `illegal character U+0024 '$'`:
		case `illegal character U+007E '~'`:
		default:
			err = fmt.Errorf("%v: %s", pos, msg)
		}
	}

	// we will modify the input source under the scanner's nose to
	// enable some features such as regexes.
	s.Init(file, src, onError, scanner.ScanComments)

	next := func() fullToken {
		pos, tok, lit := s.Scan()
		return fullToken{fset.Position(pos), tok, lit}
	}

	caseStat := caseNone

	var toks []fullToken
	for t := next(); t.tok != token.EOF; t = next() {
		switch t.lit {
		case "$": // continues below
		case "switch", "select", "case":
			// "switch"/"select" arm the case expansion for the next
			// '{'; a user-written "case" disarms it so the following
			// '$' is treated as a plain expression.
			if t.lit == "case" {
				caseStat = caseNone
			} else {
				caseStat = caseNeedBlock
			}
			fallthrough
		default: // regular Go code
			if t.tok == token.LBRACE && caseStat == caseNeedBlock {
				caseStat = caseHere
			}
			toks = append(toks, t)
			continue
		}
		// Only '$' tokens reach this point; consume the wildcard that
		// follows it. (This err intentionally shadows the outer one.)
		wt, err := tokenizeWildcard(t.pos, next)
		if err != nil {
			return nil, err
		}
		if caseStat == caseHere {
			// Expand a bare wildcard inside a switch/select body into
			// a parseable clause: "case <wild>: gogrep_body".
			toks = append(toks, fullToken{wt.pos, token.IDENT, "case"})
		}
		toks = append(toks, wt)
		if caseStat == caseHere {
			toks = append(toks, fullToken{wt.pos, token.COLON, ""})
			toks = append(toks, fullToken{wt.pos, token.IDENT, "gogrep_body"})
		}
	}
	return toks, err
}
304
// varInfo describes a decoded wildcard: its name, and whether it may
// match a sequence of nodes ("$*name") rather than exactly one ("$name").
type varInfo struct {
	Name string
	Seq  bool
}
309
310func tokenizeWildcard(pos token.Position, next func() fullToken) (fullToken, error) {
311	t := next()
312	any := false
313	if t.tok == token.MUL {
314		t = next()
315		any = true
316	}
317	wildName := encodeWildName(t.lit, any)
318	wt := fullToken{pos, token.IDENT, wildName}
319	if t.tok != token.IDENT {
320		return wt, fmt.Errorf("%v: $ must be followed by ident, got %v",
321			t.pos, t.tok)
322	}
323	return wt, nil
324}
325
// wildSeparator delimits the parts of an encoded wildcard identifier.
// Its characters are valid Go identifier letters, so encoded wildcards
// survive go/parser, while being very unlikely to occur in real code.
const wildSeparator = "ᐸᐳ"

// isWildName reports whether s is an encoded wildcard identifier, as
// produced by encodeWildName.
func isWildName(s string) bool {
	return strings.HasPrefix(s, wildSeparator)
}
331
332func encodeWildName(name string, any bool) string {
333	suffix := "v"
334	if any {
335		suffix = "a"
336	}
337	return wildSeparator + name + wildSeparator + suffix
338}
339
340func decodeWildName(s string) varInfo {
341	s = s[len(wildSeparator):]
342	nameEnd := strings.Index(s, wildSeparator)
343	name := s[:nameEnd]
344	s = s[nameEnd:]
345	s = s[len(wildSeparator):]
346	kind := s
347	return varInfo{Name: name, Seq: kind == "a"}
348}
349
350func decodeWildNode(n ast.Node) varInfo {
351	switch n := n.(type) {
352	case *ast.ExprStmt:
353		return decodeWildNode(n.X)
354	case *ast.Ident:
355		if isWildName(n.Name) {
356			return decodeWildName(n.Name)
357		}
358	}
359	return varInfo{}
360}
361