1// Copyright 2013 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5//go:generate go run mkstdlib.go
6
7// Package imports implements a Go pretty-printer (like package "go/format")
8// that also adds or removes import statements as necessary.
9package imports
10
11import (
12	"bufio"
13	"bytes"
14	"fmt"
15	"go/ast"
16	"go/format"
17	"go/parser"
18	"go/printer"
19	"go/token"
20	"io"
21	"regexp"
22	"strconv"
23	"strings"
24
25	"golang.org/x/tools/go/ast/astutil"
26)
27
// Options is golang.org/x/tools/imports.Options with extra internal-only options.
//
// Comments and AllErrors also select parser.ParseComments and
// parser.AllErrors when the source is parsed. TabIndent and TabWidth
// configure the go/printer output.
//
// NOTE(review): the "if nil *Options provided" defaults noted on the fields
// below are not applied anywhere in this file: Process, FixImports,
// ApplyFixes and formatFile all dereference opt directly, so callers here
// must pass a non-nil *Options. Presumably the public
// golang.org/x/tools/imports wrapper substitutes those defaults — confirm.
type Options struct {
	Env *ProcessEnv // The environment to use. Note: this contains the cached module and filesystem state.

	// LocalPrefix is a comma-separated string of import path prefixes, which, if
	// set, instructs Process to sort the import paths with the given prefixes
	// into another group after 3rd-party packages.
	LocalPrefix string

	Fragment  bool // Accept fragment of a source file (no package statement)
	AllErrors bool // Report all errors (not just the first 10 on different lines)

	Comments  bool // Print comments (true if nil *Options provided)
	TabIndent bool // Use tabs for indent (true if nil *Options provided)
	TabWidth  int  // Tab width (8 if nil *Options provided)

	FormatOnly bool // Disable the insertion and deletion of imports
}
46
47// Process implements golang.org/x/tools/imports.Process with explicit context in opt.Env.
48func Process(filename string, src []byte, opt *Options) (formatted []byte, err error) {
49	fileSet := token.NewFileSet()
50	file, adjust, err := parse(fileSet, filename, src, opt)
51	if err != nil {
52		return nil, err
53	}
54
55	if !opt.FormatOnly {
56		if err := fixImports(fileSet, file, filename, opt.Env); err != nil {
57			return nil, err
58		}
59	}
60	return formatFile(fileSet, file, src, adjust, opt)
61}
62
63// FixImports returns a list of fixes to the imports that, when applied,
64// will leave the imports in the same state as Process. src and opt must
65// be specified.
66//
67// Note that filename's directory influences which imports can be chosen,
68// so it is important that filename be accurate.
69func FixImports(filename string, src []byte, opt *Options) (fixes []*ImportFix, err error) {
70	fileSet := token.NewFileSet()
71	file, _, err := parse(fileSet, filename, src, opt)
72	if err != nil {
73		return nil, err
74	}
75
76	return getFixes(fileSet, file, filename, opt.Env)
77}
78
79// ApplyFixes applies all of the fixes to the file and formats it. extraMode
80// is added in when parsing the file. src and opts must be specified, but no
81// env is needed.
82func ApplyFixes(fixes []*ImportFix, filename string, src []byte, opt *Options, extraMode parser.Mode) (formatted []byte, err error) {
83	// Don't use parse() -- we don't care about fragments or statement lists
84	// here, and we need to work with unparseable files.
85	fileSet := token.NewFileSet()
86	parserMode := parser.Mode(0)
87	if opt.Comments {
88		parserMode |= parser.ParseComments
89	}
90	if opt.AllErrors {
91		parserMode |= parser.AllErrors
92	}
93	parserMode |= extraMode
94
95	file, err := parser.ParseFile(fileSet, filename, src, parserMode)
96	if file == nil {
97		return nil, err
98	}
99
100	// Apply the fixes to the file.
101	apply(fileSet, file, fixes)
102
103	return formatFile(fileSet, file, src, nil, opt)
104}
105
106func formatFile(fileSet *token.FileSet, file *ast.File, src []byte, adjust func(orig []byte, src []byte) []byte, opt *Options) ([]byte, error) {
107	mergeImports(fileSet, file)
108	sortImports(opt.LocalPrefix, fileSet, file)
109	imps := astutil.Imports(fileSet, file)
110	var spacesBefore []string // import paths we need spaces before
111	for _, impSection := range imps {
112		// Within each block of contiguous imports, see if any
113		// import lines are in different group numbers. If so,
114		// we'll need to put a space between them so it's
115		// compatible with gofmt.
116		lastGroup := -1
117		for _, importSpec := range impSection {
118			importPath, _ := strconv.Unquote(importSpec.Path.Value)
119			groupNum := importGroup(opt.LocalPrefix, importPath)
120			if groupNum != lastGroup && lastGroup != -1 {
121				spacesBefore = append(spacesBefore, importPath)
122			}
123			lastGroup = groupNum
124		}
125
126	}
127
128	printerMode := printer.UseSpaces
129	if opt.TabIndent {
130		printerMode |= printer.TabIndent
131	}
132	printConfig := &printer.Config{Mode: printerMode, Tabwidth: opt.TabWidth}
133
134	var buf bytes.Buffer
135	err := printConfig.Fprint(&buf, fileSet, file)
136	if err != nil {
137		return nil, err
138	}
139	out := buf.Bytes()
140	if adjust != nil {
141		out = adjust(src, out)
142	}
143	if len(spacesBefore) > 0 {
144		out, err = addImportSpaces(bytes.NewReader(out), spacesBefore)
145		if err != nil {
146			return nil, err
147		}
148	}
149
150	out, err = format.Source(out)
151	if err != nil {
152		return nil, err
153	}
154	return out, nil
155}
156
// parse parses src, which was read from filename,
// as a Go source file or statement list.
//
// Three interpretations are tried in order:
//  1. src as a complete source file;
//  2. if opt.Fragment is set and (1) failed with an "expected 'package'"
//     error, src as a declaration list prefixed with "package main;";
//  3. if (2) failed with an "expected declaration" error, src as a
//     statement list wrapped in "package p; func _() {" ... "}".
//
// The returned adjust func, when non-nil, receives the original input and
// the printed output of the returned file. It strips the synthesized
// wrapping from the output and re-applies orig's surrounding whitespace via
// matchSpace. adjust is nil when src parsed as a complete file. It is also
// nil when a declaration list contains func main (see containsMainFunc); in
// that case the synthesized "package main" clause is kept in the output. On
// failure, the error from the last attempted parse is returned.
func parse(fset *token.FileSet, filename string, src []byte, opt *Options) (*ast.File, func(orig, src []byte) []byte, error) {
	parserMode := parser.Mode(0)
	if opt.Comments {
		parserMode |= parser.ParseComments
	}
	if opt.AllErrors {
		parserMode |= parser.AllErrors
	}

	// Try as whole source file.
	file, err := parser.ParseFile(fset, filename, src, parserMode)
	if err == nil {
		return file, nil, nil
	}
	// If the error is that the source file didn't begin with a
	// package line and we accept fragmented input, fall through to
	// try as a source fragment. Stop and return on any other error.
	// The condition is detected by a substring match on the error text.
	if !opt.Fragment || !strings.Contains(err.Error(), "expected 'package'") {
		return nil, nil, err
	}

	// If this is a declaration list, make it a source file
	// by inserting a package clause.
	// Insert using a ;, not a newline, so that parse errors are on
	// the correct line.
	const prefix = "package main;"
	psrc := append([]byte(prefix), src...)
	file, err = parser.ParseFile(fset, filename, psrc, parserMode)
	if err == nil {
		// Gofmt will turn the ; into a \n.
		// Do that ourselves now and update the file contents,
		// so that positions and line numbers are correct going forward.
		psrc[len(prefix)-1] = '\n'
		fset.File(file.Package).SetLinesForContent(psrc)

		// If a main function exists, we will assume this is a main
		// package and leave the file. Because no adjust func is
		// returned, the synthesized "package main" clause stays in the
		// formatted output.
		if containsMainFunc(file) {
			return file, nil, nil
		}

		adjust := func(orig, src []byte) []byte {
			// Remove the package clause.
			// Only len(prefix) bytes are dropped here; matchSpace trims
			// whatever leading whitespace remains.
			src = src[len(prefix):]
			return matchSpace(orig, src)
		}
		return file, adjust, nil
	}
	// If the error is that the source file didn't begin with a
	// declaration, fall through to try as a statement list.
	// Stop and return on any other error.
	if !strings.Contains(err.Error(), "expected declaration") {
		return nil, nil, err
	}

	// If this is a statement list, make it a source file
	// by inserting a package clause and turning the list
	// into a function body. This handles expressions too.
	// Insert using a ;, not a newline, so that the line numbers
	// in fsrc match the ones in src.
	fsrc := append(append([]byte("package p; func _() {"), src...), '}')
	file, err = parser.ParseFile(fset, filename, fsrc, parserMode)
	if err == nil {
		adjust := func(orig, src []byte) []byte {
			// Remove the wrapping.
			// Gofmt has turned the ; into a \n\n.
			src = src[len("package p\n\nfunc _() {"):]
			src = src[:len(src)-len("}\n")]
			// Gofmt has also indented the function body one level.
			// Remove that indent.
			// NOTE(review): this replaces every "\n\t" in the output,
			// including any inside multi-line raw string literals, which
			// would then lose a leading tab — verify whether that matters.
			src = bytes.Replace(src, []byte("\n\t"), []byte("\n"), -1)
			return matchSpace(orig, src)
		}
		return file, adjust, nil
	}

	// Failed, and out of options.
	return nil, nil, err
}
238
// containsMainFunc reports whether file declares a top-level function with
// the signature 'func main()': named main, with no receiver, no parameters
// and no results. A method named main (func (T) main()) is not a program
// entry point and therefore does not count.
func containsMainFunc(file *ast.File) bool {
	for _, decl := range file.Decls {
		f, ok := decl.(*ast.FuncDecl)
		if !ok || f.Recv != nil || f.Name.Name != "main" {
			continue
		}
		// NumFields is nil-safe, so a missing Params or Results list is
		// treated as empty.
		if f.Type.Params.NumFields() != 0 || f.Type.Results.NumFields() != 0 {
			continue
		}
		return true
	}
	return false
}
262
// cutSpace splits b into its leading whitespace, the content in between, and
// its trailing whitespace, where whitespace means ' ', '\t' or '\n'. When b
// consists entirely of whitespace (and is non-empty), before and middle are
// nil and after is all of b.
func cutSpace(b []byte) (before, middle, after []byte) {
	isSpace := func(c byte) bool { return c == ' ' || c == '\t' || c == '\n' }

	start := 0
	for start < len(b) && isSpace(b[start]) {
		start++
	}
	end := len(b)
	for end > 0 && isSpace(b[end-1]) {
		end--
	}

	// The scans cross only when every byte is whitespace.
	if start > end {
		return nil, nil, b[end:]
	}
	return b[:start], b[start:end], b[end:]
}
277
278// matchSpace reformats src to use the same space context as orig.
279// 1) If orig begins with blank lines, matchSpace inserts them at the beginning of src.
280// 2) matchSpace copies the indentation of the first non-blank line in orig
281//    to every non-blank line in src.
282// 3) matchSpace copies the trailing space from orig and uses it in place
283//   of src's trailing space.
284func matchSpace(orig []byte, src []byte) []byte {
285	before, _, after := cutSpace(orig)
286	i := bytes.LastIndex(before, []byte{'\n'})
287	before, indent := before[:i+1], before[i+1:]
288
289	_, src, _ = cutSpace(src)
290
291	var b bytes.Buffer
292	b.Write(before)
293	for len(src) > 0 {
294		line := src
295		if i := bytes.IndexByte(line, '\n'); i >= 0 {
296			line, src = line[:i+1], line[i+1:]
297		} else {
298			src = nil
299		}
300		if len(line) > 0 && line[0] != '\n' { // not blank
301			b.Write(indent)
302		}
303		b.Write(line)
304	}
305	b.Write(after)
306	return b.Bytes()
307}
308
// impLine matches an indented import spec line and captures its quoted
// import path; an optional package name may precede the path.
var impLine = regexp.MustCompile(`^\s+(?:[\w\.]+\s+)?"(.+)"`)

// addImportSpaces copies r to the returned bytes, inserting a blank line
// before each import spec whose path equals the next entry of breaks.
// breaks must therefore be listed in the order the paths appear in the
// source. Only lines after the first line starting with "import", and before
// the next line starting with "var", "func", "const" or "type", are examined.
// The final line is copied even when it lacks a trailing newline.
func addImportSpaces(r io.Reader, breaks []string) ([]byte, error) {
	var out bytes.Buffer
	in := bufio.NewReader(r)
	inImports := false
	done := false
	for {
		s, err := in.ReadString('\n')
		if err != nil && err != io.EOF {
			return nil, err
		}
		// ReadString returns the final, unterminated line together with
		// io.EOF, so s must be processed before the EOF check. Otherwise
		// input without a trailing newline loses its last line.
		if s != "" {
			if !inImports && !done && strings.HasPrefix(s, "import") {
				inImports = true
			}
			if inImports && (strings.HasPrefix(s, "var") ||
				strings.HasPrefix(s, "func") ||
				strings.HasPrefix(s, "const") ||
				strings.HasPrefix(s, "type")) {
				done = true
				inImports = false
			}
			if inImports && len(breaks) > 0 {
				if m := impLine.FindStringSubmatch(s); m != nil && m[1] == breaks[0] {
					out.WriteByte('\n')
					breaks = breaks[1:]
				}
			}
			fmt.Fprint(&out, s)
		}
		if err == io.EOF {
			break
		}
	}
	return out.Bytes(), nil
}
347