1// Copyright (c) 2019, Daniel Martí <mvdan@mvdan.cc>
2// See LICENSE for licensing information
3
4// Package format exposes gofumpt's formatting in an API similar to go/format.
5// In general, the APIs are only guaranteed to work well when the input source
6// is in canonical gofmt format.
7package format
8
9import (
10	"bytes"
11	"fmt"
12	"go/ast"
13	"go/format"
14	"go/parser"
15	"go/token"
16	"reflect"
17	"regexp"
18	"sort"
19	"strconv"
20	"strings"
21	"unicode"
22	"unicode/utf8"
23
24	"github.com/google/go-cmp/cmp"
25	"golang.org/x/mod/semver"
26	"golang.org/x/tools/go/ast/astutil"
27)
28
// Options configures gofumpt's formatting. The zero value is valid: it
// behaves like LangVersion "v1" with ExtraRules disabled.
type Options struct {
	// LangVersion corresponds to the Go language version a piece of code is
	// written in. The version is used to decide whether to apply formatting
	// rules which require new language features. When inside a Go module,
	// LangVersion should generally be specified as the result of:
	//
	//     go list -m -f {{.GoVersion}}
	//
	// LangVersion is treated as a semantic version, which might start with
	// a "v" prefix. Like Go versions, it might also be incomplete; "1.14"
	// is equivalent to "1.14.0". When empty, it is equivalent to "v1", to
	// not use language features which could break programs.
	LangVersion string

	// ExtraRules enables formatting rules that are not applied by default,
	// such as merging adjacent fields of the same type that sit on one line
	// (e.g. parameters "a int, b int" become "a, b int").
	ExtraRules bool
}
45
46// Source formats src in gofumpt's format, assuming that src holds a valid Go
47// source file.
48func Source(src []byte, opts Options) ([]byte, error) {
49	fset := token.NewFileSet()
50	file, err := parser.ParseFile(fset, "", src, parser.ParseComments)
51	if err != nil {
52		return nil, err
53	}
54
55	File(fset, file, opts)
56
57	var buf bytes.Buffer
58	if err := format.Node(&buf, fset, file); err != nil {
59		return nil, err
60	}
61	return buf.Bytes(), nil
62}
63
64// File modifies a file and fset in place to follow gofumpt's format. The
65// changes might include manipulating adding or removing newlines in fset,
66// modifying the position of nodes, or modifying literal values.
67func File(fset *token.FileSet, file *ast.File, opts Options) {
68	if opts.LangVersion == "" {
69		opts.LangVersion = "v1"
70	} else if opts.LangVersion[0] != 'v' {
71		opts.LangVersion = "v" + opts.LangVersion
72	}
73	if !semver.IsValid(opts.LangVersion) {
74		panic(fmt.Sprintf("invalid semver string: %q", opts.LangVersion))
75	}
76	f := &fumpter{
77		File:    fset.File(file.Pos()),
78		fset:    fset,
79		astFile: file,
80		Options: opts,
81	}
82	pre := func(c *astutil.Cursor) bool {
83		f.applyPre(c)
84		if _, ok := c.Node().(*ast.BlockStmt); ok {
85			f.blockLevel++
86		}
87		return true
88	}
89	post := func(c *astutil.Cursor) bool {
90		if _, ok := c.Node().(*ast.BlockStmt); ok {
91			f.blockLevel--
92		}
93		return true
94	}
95	astutil.Apply(file, pre, post)
96}
97
// Multiline nodes which could fit on a single line under this many
// bytes may be collapsed onto a single line. In this file it is compared
// against printLength when deciding whether a multi-line case clause
// header can be joined onto one line.
const shortLineLimit = 60

// rxOctalInteger matches integer literals in the legacy octal form: a
// leading 0 followed only by octal digits and underscores, such as 0755 or
// 0_644. Such literals can be rewritten with the explicit 0o prefix.
var rxOctalInteger = regexp.MustCompile(`\A0[0-7_]+\z`)
103
// fumpter holds the state used while formatting a single file.
type fumpter struct {
	// Options are embedded so that rules can read f.LangVersion and
	// f.ExtraRules directly.
	Options

	// The file's line table. It is embedded so that methods such as Line,
	// Offset, MergeLine and SetLines can be called directly on the fumpter.
	*token.File
	fset *token.FileSet

	// astFile is the parsed file being rewritten; its Comments are
	// searched by commentsBetween and inlineComment.
	astFile *ast.File

	// blockLevel counts the *ast.BlockStmt nodes enclosing the node being
	// visited; printLength multiplies it by 8 to approximate indentation.
	blockLevel int
}
114
115func (f *fumpter) commentsBetween(p1, p2 token.Pos) []*ast.CommentGroup {
116	comments := f.astFile.Comments
117	i1 := sort.Search(len(comments), func(i int) bool {
118		return comments[i].Pos() >= p1
119	})
120	comments = comments[i1:]
121	i2 := sort.Search(len(comments), func(i int) bool {
122		return comments[i].Pos() >= p2
123	})
124	comments = comments[:i2]
125	return comments
126}
127
128func (f *fumpter) inlineComment(pos token.Pos) *ast.Comment {
129	comments := f.astFile.Comments
130	i := sort.Search(len(comments), func(i int) bool {
131		return comments[i].Pos() >= pos
132	})
133	if i >= len(comments) {
134		return nil
135	}
136	line := f.Line(pos)
137	for _, comment := range comments[i].List {
138		if f.Line(comment.Pos()) == line {
139			return comment
140		}
141	}
142	return nil
143}
144
145// addNewline is a hack to let us force a newline at a certain position.
146func (f *fumpter) addNewline(at token.Pos) {
147	offset := f.Offset(at)
148
149	field := reflect.ValueOf(f.File).Elem().FieldByName("lines")
150	n := field.Len()
151	lines := make([]int, 0, n+1)
152	for i := 0; i < n; i++ {
153		cur := int(field.Index(i).Int())
154		if offset == cur {
155			// This newline already exists; do nothing. Duplicate
156			// newlines can't exist.
157			return
158		}
159		if offset >= 0 && offset < cur {
160			lines = append(lines, offset)
161			offset = -1
162		}
163		lines = append(lines, cur)
164	}
165	if offset >= 0 {
166		lines = append(lines, offset)
167	}
168	if !f.SetLines(lines) {
169		panic(fmt.Sprintf("could not set lines to %v", lines))
170	}
171}
172
173// removeLines removes all newlines between two positions, so that they end
174// up on the same line.
175func (f *fumpter) removeLines(fromLine, toLine int) {
176	for fromLine < toLine {
177		f.MergeLine(fromLine)
178		toLine--
179	}
180}
181
182// removeLinesBetween is like removeLines, but it leaves one newline between the
183// two positions.
184func (f *fumpter) removeLinesBetween(from, to token.Pos) {
185	f.removeLines(f.Line(from)+1, f.Line(to))
186}
187
// byteCounter is an io.Writer that discards its input and only tallies the
// number of bytes written to it.
type byteCounter int

// Write adds len(p) to the counter; it never fails.
func (b *byteCounter) Write(p []byte) (n int, err error) {
	n = len(p)
	*b += byteCounter(n)
	return n, nil
}
194
195func (f *fumpter) printLength(node ast.Node) int {
196	var count byteCounter
197	if err := format.Node(&count, f.fset, node); err != nil {
198		panic(fmt.Sprintf("unexpected print error: %v", err))
199	}
200
201	// Add the space taken by an inline comment.
202	if c := f.inlineComment(node.End()); c != nil {
203		fmt.Fprintf(&count, " %s", c.Text)
204	}
205
206	// Add an approximation of the indentation level. We can't know the
207	// number of tabs go/printer will add ahead of time. Trying to print the
208	// entire top-level declaration would tell us that, but then it's near
209	// impossible to reliably find our node again.
210	return int(count) + (f.blockLevel * 8)
211}
212
// rxCommentDirective covers all common Go comment directives. It is matched
// against the text that follows "//"; a comment group containing any
// matching line is left alone instead of getting a space after "//".
//
//   //go:         | standard Go directives, like go:noinline
//   //some-words: | similar to the syntax above, like lint:ignore or go-sumtype:decl
//   //line        | inserted line information for cmd/compile
//   //export      | to mark cgo funcs for exporting
//   //extern      | C function declarations for gccgo
//   //sys(nb)?    | syscall function wrapper prototypes
//   //nolint      | nolint directive for golangci
//
// Note that the "some-words:" matching expects a lowercase letter afterward,
// such as "go:generate", to prevent matching false positives like
// "https://site". The \b after the single-word directives stops prefixes
// like "//lineage" from matching.
var rxCommentDirective = regexp.MustCompile(`^([a-z-]+:[a-z]+|line\b|export\b|extern\b|sys(nb)?\b|nolint\b)`)
226
227// visit takes either an ast.Node or a []ast.Stmt.
228func (f *fumpter) applyPre(c *astutil.Cursor) {
229	switch node := c.Node().(type) {
230	case *ast.File:
231		var lastMulti bool
232		var lastEnd token.Pos
233		for _, decl := range node.Decls {
234			pos := decl.Pos()
235			comments := f.commentsBetween(lastEnd, pos)
236			if len(comments) > 0 {
237				pos = comments[0].Pos()
238			}
239
240			// multiline top-level declarations should be separated
241			multi := f.Line(pos) < f.Line(decl.End())
242			if multi && lastMulti && f.Line(lastEnd)+1 == f.Line(pos) {
243				f.addNewline(lastEnd)
244			}
245
246			lastMulti = multi
247			lastEnd = decl.End()
248		}
249
250		// Join contiguous lone var/const/import lines; abort if there
251		// are empty lines or comments in between.
252		newDecls := make([]ast.Decl, 0, len(node.Decls))
253		for i := 0; i < len(node.Decls); {
254			newDecls = append(newDecls, node.Decls[i])
255			start, ok := node.Decls[i].(*ast.GenDecl)
256			if !ok || isCgoImport(start) {
257				i++
258				continue
259			}
260			lastPos := start.Pos()
261			for i++; i < len(node.Decls); {
262				cont, ok := node.Decls[i].(*ast.GenDecl)
263				if !ok || cont.Tok != start.Tok || cont.Lparen != token.NoPos ||
264					f.Line(lastPos) < f.Line(cont.Pos())-1 || isCgoImport(cont) {
265					break
266				}
267				start.Specs = append(start.Specs, cont.Specs...)
268				if c := f.inlineComment(cont.End()); c != nil {
269					// don't move an inline comment outside
270					start.Rparen = c.End()
271				}
272				lastPos = cont.Pos()
273				i++
274			}
275		}
276		node.Decls = newDecls
277
278		// Comments aren't nodes, so they're not walked by default.
279	groupLoop:
280		for _, group := range node.Comments {
281			for _, comment := range group.List {
282				body := strings.TrimPrefix(comment.Text, "//")
283				if body == comment.Text {
284					// /*-style comment
285					continue groupLoop
286				}
287				if rxCommentDirective.MatchString(body) {
288					// this line is a directive
289					continue groupLoop
290				}
291				r, _ := utf8.DecodeRuneInString(body)
292				if !unicode.IsLetter(r) && !unicode.IsNumber(r) && !unicode.IsSpace(r) {
293					// this line could be code like "//{"
294					continue groupLoop
295				}
296			}
297			// If none of the comment group's lines look like a
298			// directive or code, add spaces, if needed.
299			for _, comment := range group.List {
300				body := strings.TrimPrefix(comment.Text, "//")
301				r, _ := utf8.DecodeRuneInString(body)
302				if !unicode.IsSpace(r) {
303					comment.Text = "// " + strings.TrimPrefix(comment.Text, "//")
304				}
305			}
306		}
307
308	case *ast.DeclStmt:
309		decl, ok := node.Decl.(*ast.GenDecl)
310		if !ok || decl.Tok != token.VAR || len(decl.Specs) != 1 {
311			break // e.g. const name = "value"
312		}
313		spec := decl.Specs[0].(*ast.ValueSpec)
314		if spec.Type != nil {
315			break // e.g. var name Type
316		}
317		tok := token.ASSIGN
318		names := make([]ast.Expr, len(spec.Names))
319		for i, name := range spec.Names {
320			names[i] = name
321			if name.Name != "_" {
322				tok = token.DEFINE
323			}
324		}
325		c.Replace(&ast.AssignStmt{
326			Lhs: names,
327			Tok: tok,
328			Rhs: spec.Values,
329		})
330
331	case *ast.GenDecl:
332		if node.Tok == token.IMPORT && node.Lparen.IsValid() {
333			f.joinStdImports(node)
334		}
335
336		// Single var declarations shouldn't use parentheses, unless
337		// there's a comment on the grouped declaration.
338		if node.Tok == token.VAR && len(node.Specs) == 1 &&
339			node.Lparen.IsValid() && node.Doc == nil {
340			specPos := node.Specs[0].Pos()
341			specEnd := node.Specs[0].End()
342
343			if len(f.commentsBetween(node.TokPos, specPos)) > 0 {
344				// If the single spec has any comment, it must
345				// go before the entire declaration now.
346				node.TokPos = specPos
347			} else {
348				f.removeLines(f.Line(node.TokPos), f.Line(specPos))
349			}
350			f.removeLines(f.Line(specEnd), f.Line(node.Rparen))
351
352			// Remove the parentheses. go/printer will automatically
353			// get rid of the newlines.
354			node.Lparen = token.NoPos
355			node.Rparen = token.NoPos
356		}
357
358	case *ast.BlockStmt:
359		f.stmts(node.List)
360		comments := f.commentsBetween(node.Lbrace, node.Rbrace)
361		if len(node.List) == 0 && len(comments) == 0 {
362			f.removeLinesBetween(node.Lbrace, node.Rbrace)
363			break
364		}
365
366		var sign *ast.FuncType
367		var cond ast.Expr
368		switch parent := c.Parent().(type) {
369		case *ast.FuncDecl:
370			sign = parent.Type
371		case *ast.FuncLit:
372			sign = parent.Type
373		case *ast.IfStmt:
374			cond = parent.Cond
375		case *ast.ForStmt:
376			cond = parent.Cond
377		}
378
379		if len(node.List) > 1 && sign == nil {
380			// only if we have a single statement, or if
381			// it's a func body.
382			break
383		}
384		var bodyPos, bodyEnd token.Pos
385
386		if len(node.List) > 0 {
387			bodyPos = node.List[0].Pos()
388			bodyEnd = node.List[len(node.List)-1].End()
389		}
390		if len(comments) > 0 {
391			if pos := comments[0].Pos(); !bodyPos.IsValid() || pos < bodyPos {
392				bodyPos = pos
393			}
394			if pos := comments[len(comments)-1].End(); !bodyPos.IsValid() || pos > bodyEnd {
395				bodyEnd = pos
396			}
397		}
398
399		f.removeLinesBetween(bodyEnd, node.Rbrace)
400
401		if cond != nil && f.Line(cond.Pos()) != f.Line(cond.End()) {
402			// The body is preceded by a multi-line condition, so an
403			// empty line can help readability.
404			return
405		}
406		if sign != nil {
407			var lastParam *ast.Field
408			if l := sign.Results; l != nil && len(l.List) > 0 {
409				lastParam = l.List[len(l.List)-1]
410			} else if l := sign.Params; l != nil && len(l.List) > 0 {
411				lastParam = l.List[len(l.List)-1]
412			}
413			endLine := f.Line(sign.End())
414			if lastParam != nil && f.Line(sign.Pos()) != endLine && f.Line(lastParam.Pos()) == endLine {
415				// The body is preceded by a multi-line function
416				// signature, and the empty line helps readability.
417				return
418			}
419		}
420
421		f.removeLinesBetween(node.Lbrace, bodyPos)
422
423	case *ast.CompositeLit:
424		if len(node.Elts) == 0 {
425			// doesn't have elements
426			break
427		}
428		openLine := f.Line(node.Lbrace)
429		closeLine := f.Line(node.Rbrace)
430		if openLine == closeLine {
431			// all in a single line
432			break
433		}
434
435		newlineAroundElems := false
436		newlineBetweenElems := false
437		lastLine := openLine
438		for i, elem := range node.Elts {
439			if f.Line(elem.Pos()) > lastLine {
440				if i == 0 {
441					newlineAroundElems = true
442				} else {
443					newlineBetweenElems = true
444				}
445			}
446			lastLine = f.Line(elem.End())
447		}
448		if closeLine > lastLine {
449			newlineAroundElems = true
450		}
451
452		if newlineBetweenElems || newlineAroundElems {
453			first := node.Elts[0]
454			if openLine == f.Line(first.Pos()) {
455				// We want the newline right after the brace.
456				f.addNewline(node.Lbrace + 1)
457				closeLine = f.Line(node.Rbrace)
458			}
459			last := node.Elts[len(node.Elts)-1]
460			if closeLine == f.Line(last.End()) {
461				// We want the newline right before the brace.
462				f.addNewline(node.Rbrace)
463			}
464		}
465
466		// If there's a newline between any consecutive elements, there
467		// must be a newline between all composite literal elements.
468		if !newlineBetweenElems {
469			break
470		}
471		for i1, elem1 := range node.Elts {
472			i2 := i1 + 1
473			if i2 >= len(node.Elts) {
474				break
475			}
476			elem2 := node.Elts[i2]
477			// TODO: do we care about &{}?
478			_, ok1 := elem1.(*ast.CompositeLit)
479			_, ok2 := elem2.(*ast.CompositeLit)
480			if !ok1 && !ok2 {
481				continue
482			}
483			if f.Line(elem1.End()) == f.Line(elem2.Pos()) {
484				f.addNewline(elem1.End())
485			}
486		}
487
488	case *ast.CaseClause:
489		f.stmts(node.Body)
490		openLine := f.Line(node.Case)
491		closeLine := f.Line(node.Colon)
492		if openLine == closeLine {
493			// nothing to do
494			break
495		}
496		if len(f.commentsBetween(node.Case, node.Colon)) > 0 {
497			// don't move comments
498			break
499		}
500		if f.printLength(node) > shortLineLimit {
501			// too long to collapse
502			break
503		}
504		f.removeLines(openLine, closeLine)
505
506	case *ast.CommClause:
507		f.stmts(node.Body)
508
509	case *ast.FieldList:
510		if node.NumFields() == 0 && f.inlineComment(node.Pos()) == nil {
511			// Empty field lists should not contain a newline.
512			// Do not join the two lines if the first has an inline
513			// comment, as that can result in broken formatting.
514			openLine := f.Line(node.Pos())
515			closeLine := f.Line(node.End())
516			f.removeLines(openLine, closeLine)
517		}
518
519		// Merging adjacent fields (e.g. parameters) is disabled by default.
520		if !f.ExtraRules {
521			break
522		}
523		switch c.Parent().(type) {
524		case *ast.FuncDecl, *ast.FuncType, *ast.InterfaceType:
525			node.List = f.mergeAdjacentFields(node.List)
526			c.Replace(node)
527		case *ast.StructType:
528			// Do not merge adjacent fields in structs.
529		}
530
531	case *ast.BasicLit:
532		// Octal number literals were introduced in 1.13.
533		if semver.Compare(f.LangVersion, "v1.13") >= 0 {
534			if node.Kind == token.INT && rxOctalInteger.MatchString(node.Value) {
535				node.Value = "0o" + node.Value[1:]
536				c.Replace(node)
537			}
538		}
539	}
540}
541
542func (f *fumpter) stmts(list []ast.Stmt) {
543	for i, stmt := range list {
544		ifs, ok := stmt.(*ast.IfStmt)
545		if !ok || i < 1 {
546			continue // not an if following another statement
547		}
548		as, ok := list[i-1].(*ast.AssignStmt)
549		if !ok || as.Tok != token.DEFINE ||
550			!identEqual(as.Lhs[len(as.Lhs)-1], "err") {
551			continue // not "..., err := ..."
552		}
553		be, ok := ifs.Cond.(*ast.BinaryExpr)
554		if !ok || ifs.Init != nil || ifs.Else != nil {
555			continue // complex if
556		}
557		if be.Op != token.NEQ || !identEqual(be.X, "err") ||
558			!identEqual(be.Y, "nil") {
559			continue // not "err != nil"
560		}
561		f.removeLinesBetween(as.End(), ifs.Pos())
562	}
563}
564
// identEqual reports whether expr is an identifier spelled exactly name.
func identEqual(expr ast.Expr, name string) bool {
	if id, ok := expr.(*ast.Ident); ok {
		return id.Name == name
	}
	return false
}
569
// isCgoImport reports whether decl is an import declaration holding exactly
// one spec, whose path is "C":
//
//   import "C"
//
// Parentheses and an import name do not affect the result.
func isCgoImport(decl *ast.GenDecl) bool {
	if decl.Tok == token.IMPORT && len(decl.Specs) == 1 {
		return decl.Specs[0].(*ast.ImportSpec).Path.Value == `"C"`
	}
	return false
}
582
583// joinStdImports ensures that all standard library imports are together and at
584// the top of the imports list.
585func (f *fumpter) joinStdImports(d *ast.GenDecl) {
586	var std, other []ast.Spec
587	firstGroup := true
588	lastEnd := d.Pos()
589	needsSort := false
590	for i, spec := range d.Specs {
591		spec := spec.(*ast.ImportSpec)
592		if coms := f.commentsBetween(lastEnd, spec.Pos()); len(coms) > 0 {
593			lastEnd = coms[len(coms)-1].End()
594		}
595		if i > 0 && firstGroup && f.Line(spec.Pos()) > f.Line(lastEnd)+1 {
596			firstGroup = false
597		} else {
598			// We're still in the first group, update lastEnd.
599			lastEnd = spec.End()
600		}
601
602		path, _ := strconv.Unquote(spec.Path.Value)
603		switch {
604		// Imports with a period are definitely third party.
605		case strings.Contains(path, "."):
606			fallthrough
607		// "test" and "example" are reserved as per golang.org/issue/37641.
608		// "internal" is unreachable.
609		case strings.HasPrefix(path, "test/") ||
610			strings.HasPrefix(path, "example/") ||
611			strings.HasPrefix(path, "internal/"):
612			fallthrough
613		// To be conservative, if an import has a name or an inline
614		// comment, and isn't part of the top group, treat it as non-std.
615		case !firstGroup && (spec.Name != nil || spec.Comment != nil):
616			other = append(other, spec)
617			continue
618		}
619
620		// If we're moving this std import further up, reset its
621		// position, to avoid breaking comments.
622		if !firstGroup || len(other) > 0 {
623			setPos(reflect.ValueOf(spec), d.Pos())
624			needsSort = true
625		}
626		std = append(std, spec)
627	}
628	// Ensure there is an empty line between std imports and other imports.
629	if len(std) > 0 && len(other) > 0 && f.Line(std[len(std)-1].End())+1 >= f.Line(other[0].Pos()) {
630		// We add two newlines, as that's necessary in some edge cases.
631		// For example, if the std and non-std imports were together and
632		// without indentation, adding one newline isn't enough. Two
633		// empty lines will be printed as one by go/printer, anyway.
634		f.addNewline(other[0].Pos() - 1)
635		f.addNewline(other[0].Pos())
636	}
637	// Finally, join the imports, keeping std at the top.
638	d.Specs = append(std, other...)
639
640	// If we moved any std imports to the first group, we need to sort them
641	// again.
642	if needsSort {
643		ast.SortImports(f.fset, f.astFile)
644	}
645}
646
647// mergeAdjacentFields returns fields with adjacent fields merged if possible.
648func (f *fumpter) mergeAdjacentFields(fields []*ast.Field) []*ast.Field {
649	// If there are less than two fields then there is nothing to merge.
650	if len(fields) < 2 {
651		return fields
652	}
653
654	// Otherwise, iterate over adjacent pairs of fields, merging if possible,
655	// and mutating fields. Elements of fields may be mutated (if merged with
656	// following fields), discarded (if merged with a preceeding field), or left
657	// unchanged.
658	i := 0
659	for j := 1; j < len(fields); j++ {
660		if f.shouldMergeAdjacentFields(fields[i], fields[j]) {
661			fields[i].Names = append(fields[i].Names, fields[j].Names...)
662		} else {
663			i++
664			fields[i] = fields[j]
665		}
666	}
667	return fields[:i+1]
668}
669
670func (f *fumpter) shouldMergeAdjacentFields(f1, f2 *ast.Field) bool {
671	if len(f1.Names) == 0 || len(f2.Names) == 0 {
672		// Both must have names for the merge to work.
673		return false
674	}
675	if f.Line(f1.Pos()) != f.Line(f2.Pos()) {
676		// Trust the user if they used separate lines.
677		return false
678	}
679
680	// Only merge if the types are equal.
681	opt := cmp.Comparer(func(x, y token.Pos) bool { return true })
682	return cmp.Equal(f1.Type, f2.Type, opt)
683}
684
// posType is the reflect.Type of token.Pos; setPos uses it to spot position
// fields.
var posType = reflect.TypeOf(token.NoPos)

// setPos overwrites with pos every token.Pos it can reach from v by
// following pointers and struct fields. Nil pointers are skipped, and
// slices, maps and interfaces are not traversed (so, for example, the
// comments inside a *ast.CommentGroup keep their positions). v must be a
// pointer, or otherwise settable, wherever a position is found.
func setPos(v reflect.Value, pos token.Pos) {
	if v.Kind() == reflect.Ptr {
		v = v.Elem()
	}
	switch {
	case !v.IsValid():
		// A nil pointer; nothing to update.
	case v.Type() == posType:
		v.Set(reflect.ValueOf(pos))
	case v.Kind() == reflect.Struct:
		for i, n := 0, v.NumField(); i < n; i++ {
			setPos(v.Field(i), pos)
		}
	}
}
704