1// Copyright (c) 2019, Daniel Martí <mvdan@mvdan.cc>
2// See LICENSE for licensing information
3
4// Package format exposes gofumpt's formatting in an API similar to go/format.
5// In general, the APIs are only guaranteed to work well when the input source
6// is in canonical gofmt format.
7package format
8
9import (
10	"bytes"
11	"fmt"
12	"go/ast"
13	"go/format"
14	"go/parser"
15	"go/token"
16	"reflect"
17	"regexp"
18	"sort"
19	"strings"
20	"unicode"
21	"unicode/utf8"
22
23	"github.com/google/go-cmp/cmp"
24	"golang.org/x/mod/semver"
25	"golang.org/x/tools/go/ast/astutil"
26)
27
// Options holds the settings which control how File and Source format code.
type Options struct {
	// LangVersion corresponds to the Go language version a piece of code is
	// written in. The version is used to decide whether to apply formatting
	// rules which require new language features. When inside a Go module,
	// LangVersion should generally be specified as the result of:
	//
	//     go list -m -f {{.GoVersion}}
	//
	// LangVersion is treated as a semantic version, which might start with
	// a "v" prefix. Like Go versions, it might also be incomplete; "1.14"
	// is equivalent to "1.14.0". When empty, it is equivalent to "v1", to
	// not use language features which could break programs.
	LangVersion string

	// ExtraRules enables formatting rules which are disabled by default,
	// such as merging adjacent fields (e.g. parameters) of the same type
	// which are declared on the same line.
	ExtraRules bool
}
44
45// Source formats src in gofumpt's format, assuming that src holds a valid Go
46// source file.
47func Source(src []byte, opts Options) ([]byte, error) {
48	fset := token.NewFileSet()
49	file, err := parser.ParseFile(fset, "", src, parser.ParseComments)
50	if err != nil {
51		return nil, err
52	}
53
54	File(fset, file, opts)
55
56	var buf bytes.Buffer
57	if err := format.Node(&buf, fset, file); err != nil {
58		return nil, err
59	}
60	return buf.Bytes(), nil
61}
62
63// File modifies a file and fset in place to follow gofumpt's format. The
64// changes might include manipulating adding or removing newlines in fset,
65// modifying the position of nodes, or modifying literal values.
66func File(fset *token.FileSet, file *ast.File, opts Options) {
67	if opts.LangVersion == "" {
68		opts.LangVersion = "v1"
69	} else if opts.LangVersion[0] != 'v' {
70		opts.LangVersion = "v" + opts.LangVersion
71	}
72	if !semver.IsValid(opts.LangVersion) {
73		panic(fmt.Sprintf("invalid semver string: %q", opts.LangVersion))
74	}
75	f := &fumpter{
76		File:    fset.File(file.Pos()),
77		fset:    fset,
78		astFile: file,
79		Options: opts,
80	}
81	pre := func(c *astutil.Cursor) bool {
82		f.applyPre(c)
83		if _, ok := c.Node().(*ast.BlockStmt); ok {
84			f.blockLevel++
85		}
86		return true
87	}
88	post := func(c *astutil.Cursor) bool {
89		if _, ok := c.Node().(*ast.BlockStmt); ok {
90			f.blockLevel--
91		}
92		return true
93	}
94	astutil.Apply(file, pre, post)
95}
96
// Multiline nodes which could fit on a single line under this many
// bytes may be collapsed onto a single line.
const shortLineLimit = 60

// rxOctalInteger matches a whole integer literal written with a leading zero
// followed by one or more octal digits or underscores, such as 0755.
var rxOctalInteger = regexp.MustCompile(`\A0[0-7_]+\z`)
102
// fumpter carries the state used while formatting a single file.
type fumpter struct {
	Options

	// The embedded token.File is the one holding the file being formatted;
	// its line table is what addNewline and removeLines modify.
	*token.File
	fset *token.FileSet

	// astFile is the syntax tree being formatted, also used to look up
	// comments by position.
	astFile *ast.File

	// blockLevel counts the block statements enclosing the node currently
	// being visited. printLength uses it to approximate indentation.
	blockLevel int
}
113
114func (f *fumpter) commentsBetween(p1, p2 token.Pos) []*ast.CommentGroup {
115	comments := f.astFile.Comments
116	i1 := sort.Search(len(comments), func(i int) bool {
117		return comments[i].Pos() >= p1
118	})
119	comments = comments[i1:]
120	i2 := sort.Search(len(comments), func(i int) bool {
121		return comments[i].Pos() >= p2
122	})
123	comments = comments[:i2]
124	return comments
125}
126
127func (f *fumpter) inlineComment(pos token.Pos) *ast.Comment {
128	comments := f.astFile.Comments
129	i := sort.Search(len(comments), func(i int) bool {
130		return comments[i].Pos() >= pos
131	})
132	if i >= len(comments) {
133		return nil
134	}
135	line := f.Line(pos)
136	for _, comment := range comments[i].List {
137		if f.Line(comment.Pos()) == line {
138			return comment
139		}
140	}
141	return nil
142}
143
144// addNewline is a hack to let us force a newline at a certain position.
145func (f *fumpter) addNewline(at token.Pos) {
146	offset := f.Offset(at)
147
148	field := reflect.ValueOf(f.File).Elem().FieldByName("lines")
149	n := field.Len()
150	lines := make([]int, 0, n+1)
151	for i := 0; i < n; i++ {
152		cur := int(field.Index(i).Int())
153		if offset == cur {
154			// This newline already exists; do nothing. Duplicate
155			// newlines can't exist.
156			return
157		}
158		if offset >= 0 && offset < cur {
159			lines = append(lines, offset)
160			offset = -1
161		}
162		lines = append(lines, cur)
163	}
164	if offset >= 0 {
165		lines = append(lines, offset)
166	}
167	if !f.SetLines(lines) {
168		panic(fmt.Sprintf("could not set lines to %v", lines))
169	}
170}
171
172// removeNewlines removes all newlines between two positions, so that they end
173// up on the same line.
174func (f *fumpter) removeLines(fromLine, toLine int) {
175	for fromLine < toLine {
176		f.MergeLine(fromLine)
177		toLine--
178	}
179}
180
181// removeLinesBetween is like removeLines, but it leaves one newline between the
182// two positions.
183func (f *fumpter) removeLinesBetween(from, to token.Pos) {
184	f.removeLines(f.Line(from)+1, f.Line(to))
185}
186
// byteCounter is a writer which discards everything written to it, only
// keeping track of the total number of bytes it has received.
type byteCounter int

// Write adds len(p) to the counter and reports all of p as written. It never
// returns an error.
func (b *byteCounter) Write(p []byte) (n int, err error) {
	n = len(p)
	*b += byteCounter(n)
	return n, nil
}
193
194func (f *fumpter) printLength(node ast.Node) int {
195	var count byteCounter
196	if err := format.Node(&count, f.fset, node); err != nil {
197		panic(fmt.Sprintf("unexpected print error: %v", err))
198	}
199
200	// Add the space taken by an inline comment.
201	if c := f.inlineComment(node.End()); c != nil {
202		fmt.Fprintf(&count, " %s", c.Text)
203	}
204
205	// Add an approximation of the indentation level. We can't know the
206	// number of tabs go/printer will add ahead of time. Trying to print the
207	// entire top-level declaration would tell us that, but then it's near
208	// impossible to reliably find our node again.
209	return int(count) + (f.blockLevel * 8)
210}
211
// rxCommentDirective covers all common Go comment directives:
//
//   //go:        | standard Go directives, like go:noinline
//   //someword:  | similar to the syntax above, like lint:ignore
//   //line       | inserted line information for cmd/compile
//   //export     | to mark cgo funcs for exporting
//   //extern     | C function declarations for gccgo
//   //sys(nb)?   | syscall function wrapper prototypes
//   //nolint     | nolint directive for golangci
//
// The pattern is anchored at the start, and is meant to be matched against
// the text of a comment once its leading "//" has been removed.
var rxCommentDirective = regexp.MustCompile(`^([a-z]+:|line\b|export\b|extern\b|sys(nb)?\b|nolint\b)`)
222
223// visit takes either an ast.Node or a []ast.Stmt.
224func (f *fumpter) applyPre(c *astutil.Cursor) {
225	switch node := c.Node().(type) {
226	case *ast.File:
227		var lastMulti bool
228		var lastEnd token.Pos
229		for _, decl := range node.Decls {
230			pos := decl.Pos()
231			comments := f.commentsBetween(lastEnd, pos)
232			if len(comments) > 0 {
233				pos = comments[0].Pos()
234			}
235
236			// multiline top-level declarations should be separated
237			multi := f.Line(pos) < f.Line(decl.End())
238			if multi && lastMulti && f.Line(lastEnd)+1 == f.Line(pos) {
239				f.addNewline(lastEnd)
240			}
241
242			lastMulti = multi
243			lastEnd = decl.End()
244		}
245
246		// Join contiguous lone var/const/import lines; abort if there
247		// are empty lines or comments in between.
248		newDecls := make([]ast.Decl, 0, len(node.Decls))
249		for i := 0; i < len(node.Decls); {
250			newDecls = append(newDecls, node.Decls[i])
251			start, ok := node.Decls[i].(*ast.GenDecl)
252			if !ok {
253				i++
254				continue
255			}
256			lastPos := start.Pos()
257			for i++; i < len(node.Decls); {
258				cont, ok := node.Decls[i].(*ast.GenDecl)
259				if !ok || cont.Tok != start.Tok || cont.Lparen != token.NoPos ||
260					f.Line(lastPos) < f.Line(cont.Pos())-1 {
261					break
262				}
263				start.Specs = append(start.Specs, cont.Specs...)
264				if c := f.inlineComment(cont.End()); c != nil {
265					// don't move an inline comment outside
266					start.Rparen = c.End()
267				}
268				lastPos = cont.Pos()
269				i++
270			}
271		}
272		node.Decls = newDecls
273
274		// Comments aren't nodes, so they're not walked by default.
275	groupLoop:
276		for _, group := range node.Comments {
277			for _, comment := range group.List {
278				body := strings.TrimPrefix(comment.Text, "//")
279				if body == comment.Text {
280					// /*-style comment
281					continue groupLoop
282				}
283				if rxCommentDirective.MatchString(body) {
284					// this line is a directive
285					continue groupLoop
286				}
287				r, _ := utf8.DecodeRuneInString(body)
288				if !unicode.IsLetter(r) && !unicode.IsNumber(r) && !unicode.IsSpace(r) {
289					// this line could be code like "//{"
290					continue groupLoop
291				}
292			}
293			// If none of the comment group's lines look like a
294			// directive or code, add spaces, if needed.
295			for _, comment := range group.List {
296				body := strings.TrimPrefix(comment.Text, "//")
297				r, _ := utf8.DecodeRuneInString(body)
298				if !unicode.IsSpace(r) {
299					comment.Text = "// " + strings.TrimPrefix(comment.Text, "//")
300				}
301			}
302		}
303
304	case *ast.DeclStmt:
305		decl, ok := node.Decl.(*ast.GenDecl)
306		if !ok || decl.Tok != token.VAR || len(decl.Specs) != 1 {
307			break // e.g. const name = "value"
308		}
309		spec := decl.Specs[0].(*ast.ValueSpec)
310		if spec.Type != nil {
311			break // e.g. var name Type
312		}
313		tok := token.ASSIGN
314		names := make([]ast.Expr, len(spec.Names))
315		for i, name := range spec.Names {
316			names[i] = name
317			if name.Name != "_" {
318				tok = token.DEFINE
319			}
320		}
321		c.Replace(&ast.AssignStmt{
322			Lhs: names,
323			Tok: tok,
324			Rhs: spec.Values,
325		})
326
327	case *ast.GenDecl:
328		if node.Tok == token.IMPORT && node.Lparen.IsValid() {
329			f.joinStdImports(node)
330		}
331
332		// Single var declarations shouldn't use parentheses, unless
333		// there's a comment on the grouped declaration.
334		if node.Tok == token.VAR && len(node.Specs) == 1 &&
335			node.Lparen.IsValid() && node.Doc == nil {
336			specPos := node.Specs[0].Pos()
337			specEnd := node.Specs[0].End()
338
339			if len(f.commentsBetween(node.TokPos, specPos)) > 0 {
340				// If the single spec has any comment, it must
341				// go before the entire declaration now.
342				node.TokPos = specPos
343			} else {
344				f.removeLines(f.Line(node.TokPos), f.Line(specPos))
345			}
346			f.removeLines(f.Line(specEnd), f.Line(node.Rparen))
347
348			// Remove the parentheses. go/printer will automatically
349			// get rid of the newlines.
350			node.Lparen = token.NoPos
351			node.Rparen = token.NoPos
352		}
353
354	case *ast.BlockStmt:
355		f.stmts(node.List)
356		comments := f.commentsBetween(node.Lbrace, node.Rbrace)
357		if len(node.List) == 0 && len(comments) == 0 {
358			f.removeLinesBetween(node.Lbrace, node.Rbrace)
359			break
360		}
361
362		isFuncBody := false
363		switch c.Parent().(type) {
364		case *ast.FuncDecl:
365			isFuncBody = true
366		case *ast.FuncLit:
367			isFuncBody = true
368		}
369
370		if len(node.List) > 1 && !isFuncBody {
371			// only if we have a single statement, or if
372			// it's a func body.
373			break
374		}
375		var bodyPos, bodyEnd token.Pos
376
377		if len(node.List) > 0 {
378			bodyPos = node.List[0].Pos()
379			bodyEnd = node.List[len(node.List)-1].End()
380		}
381		if len(comments) > 0 {
382			if pos := comments[0].Pos(); !bodyPos.IsValid() || pos < bodyPos {
383				bodyPos = pos
384			}
385			if pos := comments[len(comments)-1].End(); !bodyPos.IsValid() || pos > bodyEnd {
386				bodyEnd = pos
387			}
388		}
389
390		f.removeLinesBetween(node.Lbrace, bodyPos)
391		f.removeLinesBetween(bodyEnd, node.Rbrace)
392
393	case *ast.CompositeLit:
394		if len(node.Elts) == 0 {
395			// doesn't have elements
396			break
397		}
398		openLine := f.Line(node.Lbrace)
399		closeLine := f.Line(node.Rbrace)
400		if openLine == closeLine {
401			// all in a single line
402			break
403		}
404
405		newlineAroundElems := false
406		newlineBetweenElems := false
407		lastLine := openLine
408		for i, elem := range node.Elts {
409			if f.Line(elem.Pos()) > lastLine {
410				if i == 0 {
411					newlineAroundElems = true
412				} else {
413					newlineBetweenElems = true
414				}
415			}
416			lastLine = f.Line(elem.End())
417		}
418		if closeLine > lastLine {
419			newlineAroundElems = true
420		}
421
422		if newlineBetweenElems || newlineAroundElems {
423			first := node.Elts[0]
424			if openLine == f.Line(first.Pos()) {
425				// We want the newline right after the brace.
426				f.addNewline(node.Lbrace + 1)
427				closeLine = f.Line(node.Rbrace)
428			}
429			last := node.Elts[len(node.Elts)-1]
430			if closeLine == f.Line(last.End()) {
431				// We want the newline right before the brace.
432				f.addNewline(node.Rbrace)
433			}
434		}
435
436		// If there's a newline between any consecutive elements, there
437		// must be a newline between all composite literal elements.
438		if !newlineBetweenElems {
439			break
440		}
441		for i1, elem1 := range node.Elts {
442			i2 := i1 + 1
443			if i2 >= len(node.Elts) {
444				break
445			}
446			elem2 := node.Elts[i2]
447			// TODO: do we care about &{}?
448			_, ok1 := elem1.(*ast.CompositeLit)
449			_, ok2 := elem2.(*ast.CompositeLit)
450			if !ok1 && !ok2 {
451				continue
452			}
453			if f.Line(elem1.End()) == f.Line(elem2.Pos()) {
454				f.addNewline(elem1.End())
455			}
456		}
457
458	case *ast.CaseClause:
459		f.stmts(node.Body)
460		openLine := f.Line(node.Case)
461		closeLine := f.Line(node.Colon)
462		if openLine == closeLine {
463			// nothing to do
464			break
465		}
466		if len(f.commentsBetween(node.Case, node.Colon)) > 0 {
467			// don't move comments
468			break
469		}
470		if f.printLength(node) > shortLineLimit {
471			// too long to collapse
472			break
473		}
474		f.removeLines(openLine, closeLine)
475
476	case *ast.CommClause:
477		f.stmts(node.Body)
478
479	case *ast.FieldList:
480		// Merging adjacent fields (e.g. parameters) is disabled by default.
481		if !f.ExtraRules {
482			break
483		}
484		switch c.Parent().(type) {
485		case *ast.FuncDecl, *ast.FuncType, *ast.InterfaceType:
486			node.List = f.mergeAdjacentFields(node.List)
487			c.Replace(node)
488		case *ast.StructType:
489			// Do not merge adjacent fields in structs.
490		}
491
492	case *ast.BasicLit:
493		// Octal number literals were introduced in 1.13.
494		if semver.Compare(f.LangVersion, "v1.13") >= 0 {
495			if node.Kind == token.INT && rxOctalInteger.MatchString(node.Value) {
496				node.Value = "0o" + node.Value[1:]
497				c.Replace(node)
498			}
499		}
500	}
501}
502
503func (f *fumpter) stmts(list []ast.Stmt) {
504	for i, stmt := range list {
505		ifs, ok := stmt.(*ast.IfStmt)
506		if !ok || i < 1 {
507			continue // not an if following another statement
508		}
509		as, ok := list[i-1].(*ast.AssignStmt)
510		if !ok || as.Tok != token.DEFINE ||
511			!identEqual(as.Lhs[len(as.Lhs)-1], "err") {
512			continue // not "..., err := ..."
513		}
514		be, ok := ifs.Cond.(*ast.BinaryExpr)
515		if !ok || ifs.Init != nil || ifs.Else != nil {
516			continue // complex if
517		}
518		if be.Op != token.NEQ || !identEqual(be.X, "err") ||
519			!identEqual(be.Y, "nil") {
520			continue // not "err != nil"
521		}
522		f.removeLinesBetween(as.End(), ifs.Pos())
523	}
524}
525
// identEqual reports whether expr is an identifier with the given name.
func identEqual(expr ast.Expr, name string) bool {
	if id, ok := expr.(*ast.Ident); ok {
		return id.Name == name
	}
	return false
}
530
531// joinStdImports ensures that all standard library imports are together and at
532// the top of the imports list.
533func (f *fumpter) joinStdImports(d *ast.GenDecl) {
534	var std, other []ast.Spec
535	firstGroup := true
536	lastEnd := d.Pos()
537	needsSort := false
538	for i, spec := range d.Specs {
539		spec := spec.(*ast.ImportSpec)
540		if coms := f.commentsBetween(lastEnd, spec.Pos()); len(coms) > 0 {
541			lastEnd = coms[len(coms)-1].End()
542		}
543		if i > 0 && firstGroup && f.Line(spec.Pos()) > f.Line(lastEnd)+1 {
544			firstGroup = false
545		} else {
546			// We're still in the first group, update lastEnd.
547			lastEnd = spec.End()
548		}
549
550		// First, separate the non-std imports.
551		if strings.Contains(spec.Path.Value, ".") {
552			other = append(other, spec)
553			continue
554		}
555		// To be conservative, if an import has a name or an inline
556		// comment, and isn't part of the top group, treat it as non-std.
557		if !firstGroup && (spec.Name != nil || spec.Comment != nil) {
558			other = append(other, spec)
559			continue
560		}
561
562		// If we're moving this std import further up, reset its
563		// position, to avoid breaking comments.
564		if !firstGroup || len(other) > 0 {
565			setPos(reflect.ValueOf(spec), d.Pos())
566			needsSort = true
567		}
568		std = append(std, spec)
569	}
570	// Ensure there is an empty line between std imports and other imports.
571	if len(std) > 0 && len(other) > 0 && f.Line(std[len(std)-1].End())+1 >= f.Line(other[0].Pos()) {
572		// We add two newlines, as that's necessary in some edge cases.
573		// For example, if the std and non-std imports were together and
574		// without indentation, adding one newline isn't enough. Two
575		// empty lines will be printed as one by go/printer, anyway.
576		f.addNewline(other[0].Pos() - 1)
577		f.addNewline(other[0].Pos())
578	}
579	// Finally, join the imports, keeping std at the top.
580	d.Specs = append(std, other...)
581
582	// If we moved any std imports to the first group, we need to sort them
583	// again.
584	if needsSort {
585		ast.SortImports(f.fset, f.astFile)
586	}
587}
588
589// mergeAdjacentFields returns fields with adjacent fields merged if possible.
590func (f *fumpter) mergeAdjacentFields(fields []*ast.Field) []*ast.Field {
591	// If there are less than two fields then there is nothing to merge.
592	if len(fields) < 2 {
593		return fields
594	}
595
596	// Otherwise, iterate over adjacent pairs of fields, merging if possible,
597	// and mutating fields. Elements of fields may be mutated (if merged with
598	// following fields), discarded (if merged with a preceeding field), or left
599	// unchanged.
600	i := 0
601	for j := 1; j < len(fields); j++ {
602		if f.shouldMergeAdjacentFields(fields[i], fields[j]) {
603			fields[i].Names = append(fields[i].Names, fields[j].Names...)
604		} else {
605			i++
606			fields[i] = fields[j]
607		}
608	}
609	return fields[:i+1]
610}
611
612func (f *fumpter) shouldMergeAdjacentFields(f1, f2 *ast.Field) bool {
613	if len(f1.Names) == 0 || len(f2.Names) == 0 {
614		// Both must have names for the merge to work.
615		return false
616	}
617	if f.Line(f1.Pos()) != f.Line(f2.Pos()) {
618		// Trust the user if they used separate lines.
619		return false
620	}
621
622	// Only merge if the types are equal.
623	opt := cmp.Comparer(func(x, y token.Pos) bool { return true })
624	return cmp.Equal(f1.Type, f2.Type, opt)
625}
626
// posType is the reflection type of token.Pos, used by setPos to recognize
// position values.
var posType = reflect.TypeOf(token.NoPos)

// setPos sets the position values reachable from v to pos. It follows
// pointers and recurses into struct fields; it does not descend into slices,
// maps, or interface values. A nil pointer is left alone.
func setPos(v reflect.Value, pos token.Pos) {
	if v.Kind() == reflect.Ptr {
		v = v.Elem()
	}
	switch {
	case !v.IsValid():
		// v was a nil pointer; there is nothing to update.
		return
	case v.Type() == posType:
		v.SetInt(int64(pos))
	case v.Kind() == reflect.Struct:
		for i, n := 0, v.NumField(); i < n; i++ {
			setPos(v.Field(i), pos)
		}
	}
}
646