1package check
2
3import (
4	"bytes"
5	"go/ast"
6	"go/parser"
7	"go/printer"
8	"go/token"
9	"os"
10)
11
// indent returns s with the string with inserted at the start of every
// non-empty line. Both '\n' and '\r' count as line terminators, so LF, CRLF
// and bare-CR text are all handled. Empty lines receive no prefix.
func indent(s, with string) string {
	// Build the result in one buffer instead of splicing into s, which
	// re-copied the whole string for every indented line.
	out := make([]byte, 0, len(s)+len(with))
	eol := true // true while positioned at the start of a line
	for i := 0; i < len(s); i++ {
		c := s[i]
		switch {
		case c == '\n' || c == '\r':
			// Any terminator (re)starts a line; consecutive terminators
			// simply keep eol set, so blank lines stay unindented.
			eol = true
		case eol:
			// First real character of a line: emit the prefix before it.
			eol = false
			out = append(out, with...)
		}
		out = append(out, c)
	}
	return string(out)
}
28
29func printLine(filename string, line int) (string, error) {
30	fset := token.NewFileSet()
31	file, err := os.Open(filename)
32	if err != nil {
33		return "", err
34	}
35	fnode, err := parser.ParseFile(fset, filename, file, parser.ParseComments)
36	if err != nil {
37		return "", err
38	}
39	config := &printer.Config{Mode: printer.UseSpaces, Tabwidth: 4}
40	lp := &linePrinter{fset: fset, fnode: fnode, line: line, config: config}
41	ast.Walk(lp, fnode)
42	result := lp.output.Bytes()
43	// Comments leave \n at the end.
44	n := len(result)
45	for n > 0 && result[n-1] == '\n' {
46		n--
47	}
48	return string(result[:n]), nil
49}
50
// linePrinter is an ast.Visitor that walks a parsed file, tracks the
// statement covering a target line, and prints it into output.
type linePrinter struct {
	// config is the printer configuration used when writing to output.
	config *printer.Config
	// fset resolves node positions to line and column numbers.
	fset *token.FileSet
	// fnode is the parsed file; its comment groups are consulted when printing.
	fnode *ast.File
	// line is the target line number, compared against token.Position.Line.
	line int
	// output accumulates the printed text.
	output bytes.Buffer
	// stmt is the most recently seen non-block statement spanning line that
	// has not been printed yet; it is reset to nil once emitted.
	stmt ast.Stmt
}
59
60func (lp *linePrinter) emit() bool {
61	if lp.stmt != nil {
62		lp.trim(lp.stmt)
63		lp.printWithComments(lp.stmt)
64		lp.stmt = nil
65		return true
66	}
67	return false
68}
69
70func (lp *linePrinter) printWithComments(n ast.Node) {
71	nfirst := lp.fset.Position(n.Pos()).Line
72	nlast := lp.fset.Position(n.End()).Line
73	for _, g := range lp.fnode.Comments {
74		cfirst := lp.fset.Position(g.Pos()).Line
75		clast := lp.fset.Position(g.End()).Line
76		if clast == nfirst-1 && lp.fset.Position(n.Pos()).Column == lp.fset.Position(g.Pos()).Column {
77			for _, c := range g.List {
78				lp.output.WriteString(c.Text)
79				lp.output.WriteByte('\n')
80			}
81		}
82		if cfirst >= nfirst && cfirst <= nlast && n.End() <= g.List[0].Slash {
83			// The printer will not include the comment if it starts past
84			// the node itself. Trick it into printing by overlapping the
85			// slash with the end of the statement.
86			g.List[0].Slash = n.End() - 1
87		}
88	}
89	node := &printer.CommentedNode{n, lp.fnode.Comments}
90	lp.config.Fprint(&lp.output, lp.fset, node)
91}
92
93func (lp *linePrinter) Visit(n ast.Node) (w ast.Visitor) {
94	if n == nil {
95		if lp.output.Len() == 0 {
96			lp.emit()
97		}
98		return nil
99	}
100	first := lp.fset.Position(n.Pos()).Line
101	last := lp.fset.Position(n.End()).Line
102	if first <= lp.line && last >= lp.line {
103		// Print the innermost statement containing the line.
104		if stmt, ok := n.(ast.Stmt); ok {
105			if _, ok := n.(*ast.BlockStmt); !ok {
106				lp.stmt = stmt
107			}
108		}
109		if first == lp.line && lp.emit() {
110			return nil
111		}
112		return lp
113	}
114	return nil
115}
116
117func (lp *linePrinter) trim(n ast.Node) bool {
118	stmt, ok := n.(ast.Stmt)
119	if !ok {
120		return true
121	}
122	line := lp.fset.Position(n.Pos()).Line
123	if line != lp.line {
124		return false
125	}
126	switch stmt := stmt.(type) {
127	case *ast.IfStmt:
128		stmt.Body = lp.trimBlock(stmt.Body)
129	case *ast.SwitchStmt:
130		stmt.Body = lp.trimBlock(stmt.Body)
131	case *ast.TypeSwitchStmt:
132		stmt.Body = lp.trimBlock(stmt.Body)
133	case *ast.CaseClause:
134		stmt.Body = lp.trimList(stmt.Body)
135	case *ast.CommClause:
136		stmt.Body = lp.trimList(stmt.Body)
137	case *ast.BlockStmt:
138		stmt.List = lp.trimList(stmt.List)
139	}
140	return true
141}
142
143func (lp *linePrinter) trimBlock(stmt *ast.BlockStmt) *ast.BlockStmt {
144	if !lp.trim(stmt) {
145		return lp.emptyBlock(stmt)
146	}
147	stmt.Rbrace = stmt.Lbrace
148	return stmt
149}
150
151func (lp *linePrinter) trimList(stmts []ast.Stmt) []ast.Stmt {
152	for i := 0; i != len(stmts); i++ {
153		if !lp.trim(stmts[i]) {
154			stmts[i] = lp.emptyStmt(stmts[i])
155			break
156		}
157	}
158	return stmts
159}
160
161func (lp *linePrinter) emptyStmt(n ast.Node) *ast.ExprStmt {
162	return &ast.ExprStmt{&ast.Ellipsis{n.Pos(), nil}}
163}
164
165func (lp *linePrinter) emptyBlock(n ast.Node) *ast.BlockStmt {
166	p := n.Pos()
167	return &ast.BlockStmt{p, []ast.Stmt{lp.emptyStmt(n)}, p}
168}
169