1// Copyright (c) 2019, Daniel Martí <mvdan@mvdan.cc> 2// See LICENSE for licensing information 3 4// Package format exposes gofumpt's formatting in an API similar to go/format. 5// In general, the APIs are only guaranteed to work well when the input source 6// is in canonical gofmt format. 7package format 8 9import ( 10 "bytes" 11 "fmt" 12 "go/ast" 13 "go/format" 14 "go/parser" 15 "go/token" 16 "reflect" 17 "regexp" 18 "sort" 19 "strconv" 20 "strings" 21 "unicode" 22 "unicode/utf8" 23 24 "github.com/google/go-cmp/cmp" 25 "golang.org/x/mod/semver" 26 "golang.org/x/tools/go/ast/astutil" 27) 28 29type Options struct { 30 // LangVersion corresponds to the Go language version a piece of code is 31 // written in. The version is used to decide whether to apply formatting 32 // rules which require new language features. When inside a Go module, 33 // LangVersion should generally be specified as the result of: 34 // 35 // go list -m -f {{.GoVersion}} 36 // 37 // LangVersion is treated as a semantic version, which might start with 38 // a "v" prefix. Like Go versions, it might also be incomplete; "1.14" 39 // is equivalent to "1.14.0". When empty, it is equivalent to "v1", to 40 // not use language features which could break programs. 41 LangVersion string 42 43 ExtraRules bool 44} 45 46// Source formats src in gofumpt's format, assuming that src holds a valid Go 47// source file. 48func Source(src []byte, opts Options) ([]byte, error) { 49 fset := token.NewFileSet() 50 file, err := parser.ParseFile(fset, "", src, parser.ParseComments) 51 if err != nil { 52 return nil, err 53 } 54 55 File(fset, file, opts) 56 57 var buf bytes.Buffer 58 if err := format.Node(&buf, fset, file); err != nil { 59 return nil, err 60 } 61 return buf.Bytes(), nil 62} 63 64// File modifies a file and fset in place to follow gofumpt's format. The 65// changes might include manipulating adding or removing newlines in fset, 66// modifying the position of nodes, or modifying literal values. 67func File(fset *token.FileSet, file *ast.File, opts Options) { 68 if opts.LangVersion == "" { 69 opts.LangVersion = "v1" 70 } else if opts.LangVersion[0] != 'v' { 71 opts.LangVersion = "v" + opts.LangVersion 72 } 73 if !semver.IsValid(opts.LangVersion) { 74 panic(fmt.Sprintf("invalid semver string: %q", opts.LangVersion)) 75 } 76 f := &fumpter{ 77 File: fset.File(file.Pos()), 78 fset: fset, 79 astFile: file, 80 Options: opts, 81 } 82 pre := func(c *astutil.Cursor) bool { 83 f.applyPre(c) 84 if _, ok := c.Node().(*ast.BlockStmt); ok { 85 f.blockLevel++ 86 } 87 return true 88 } 89 post := func(c *astutil.Cursor) bool { 90 if _, ok := c.Node().(*ast.BlockStmt); ok { 91 f.blockLevel-- 92 } 93 return true 94 } 95 astutil.Apply(file, pre, post) 96} 97 98// Multiline nodes which could fit on a single line under this many 99// bytes may be collapsed onto a single line. 100const shortLineLimit = 60 101 102var rxOctalInteger = regexp.MustCompile(`\A0[0-7_]+\z`) 103 104type fumpter struct { 105 Options 106 107 *token.File 108 fset *token.FileSet 109 110 astFile *ast.File 111 112 blockLevel int 113} 114 115func (f *fumpter) commentsBetween(p1, p2 token.Pos) []*ast.CommentGroup { 116 comments := f.astFile.Comments 117 i1 := sort.Search(len(comments), func(i int) bool { 118 return comments[i].Pos() >= p1 119 }) 120 comments = comments[i1:] 121 i2 := sort.Search(len(comments), func(i int) bool { 122 return comments[i].Pos() >= p2 123 }) 124 comments = comments[:i2] 125 return comments 126} 127 128func (f *fumpter) inlineComment(pos token.Pos) *ast.Comment { 129 comments := f.astFile.Comments 130 i := sort.Search(len(comments), func(i int) bool { 131 return comments[i].Pos() >= pos 132 }) 133 if i >= len(comments) { 134 return nil 135 } 136 line := f.Line(pos) 137 for _, comment := range comments[i].List { 138 if f.Line(comment.Pos()) == line { 139 return comment 140 } 141 } 142 return nil 143} 144 145// addNewline is a hack to let us force a newline at a certain position. 146func (f *fumpter) addNewline(at token.Pos) { 147 offset := f.Offset(at) 148 149 field := reflect.ValueOf(f.File).Elem().FieldByName("lines") 150 n := field.Len() 151 lines := make([]int, 0, n+1) 152 for i := 0; i < n; i++ { 153 cur := int(field.Index(i).Int()) 154 if offset == cur { 155 // This newline already exists; do nothing. Duplicate 156 // newlines can't exist. 157 return 158 } 159 if offset >= 0 && offset < cur { 160 lines = append(lines, offset) 161 offset = -1 162 } 163 lines = append(lines, cur) 164 } 165 if offset >= 0 { 166 lines = append(lines, offset) 167 } 168 if !f.SetLines(lines) { 169 panic(fmt.Sprintf("could not set lines to %v", lines)) 170 } 171} 172 173// removeLines removes all newlines between two positions, so that they end 174// up on the same line. 175func (f *fumpter) removeLines(fromLine, toLine int) { 176 for fromLine < toLine { 177 f.MergeLine(fromLine) 178 toLine-- 179 } 180} 181 182// removeLinesBetween is like removeLines, but it leaves one newline between the 183// two positions. 184func (f *fumpter) removeLinesBetween(from, to token.Pos) { 185 f.removeLines(f.Line(from)+1, f.Line(to)) 186} 187 188type byteCounter int 189 190func (b *byteCounter) Write(p []byte) (n int, err error) { 191 *b += byteCounter(len(p)) 192 return len(p), nil 193} 194 195func (f *fumpter) printLength(node ast.Node) int { 196 var count byteCounter 197 if err := format.Node(&count, f.fset, node); err != nil { 198 panic(fmt.Sprintf("unexpected print error: %v", err)) 199 } 200 201 // Add the space taken by an inline comment. 202 if c := f.inlineComment(node.End()); c != nil { 203 fmt.Fprintf(&count, " %s", c.Text) 204 } 205 206 // Add an approximation of the indentation level. We can't know the 207 // number of tabs go/printer will add ahead of time. Trying to print the 208 // entire top-level declaration would tell us that, but then it's near 209 // impossible to reliably find our node again. 210 return int(count) + (f.blockLevel * 8) 211} 212 213// rxCommentDirective covers all common Go comment directives: 214// 215// //go: | standard Go directives, like go:noinline 216// //some-words: | similar to the syntax above, like lint:ignore or go-sumtype:decl 217// //line | inserted line information for cmd/compile 218// //export | to mark cgo funcs for exporting 219// //extern | C function declarations for gccgo 220// //sys(nb)? | syscall function wrapper prototypes 221// //nolint | nolint directive for golangci 222// 223// Note that the "some-words:" matching expects a letter afterward, such as 224// "go:generate", to prevent matching false positives like "https://site". 225var rxCommentDirective = regexp.MustCompile(`^([a-z-]+:[a-z]+|line\b|export\b|extern\b|sys(nb)?\b|nolint\b)`) 226 227// visit takes either an ast.Node or a []ast.Stmt. 228func (f *fumpter) applyPre(c *astutil.Cursor) { 229 switch node := c.Node().(type) { 230 case *ast.File: 231 var lastMulti bool 232 var lastEnd token.Pos 233 for _, decl := range node.Decls { 234 pos := decl.Pos() 235 comments := f.commentsBetween(lastEnd, pos) 236 if len(comments) > 0 { 237 pos = comments[0].Pos() 238 } 239 240 // multiline top-level declarations should be separated 241 multi := f.Line(pos) < f.Line(decl.End()) 242 if multi && lastMulti && f.Line(lastEnd)+1 == f.Line(pos) { 243 f.addNewline(lastEnd) 244 } 245 246 lastMulti = multi 247 lastEnd = decl.End() 248 } 249 250 // Join contiguous lone var/const/import lines; abort if there 251 // are empty lines or comments in between. 252 newDecls := make([]ast.Decl, 0, len(node.Decls)) 253 for i := 0; i < len(node.Decls); { 254 newDecls = append(newDecls, node.Decls[i]) 255 start, ok := node.Decls[i].(*ast.GenDecl) 256 if !ok || isCgoImport(start) { 257 i++ 258 continue 259 } 260 lastPos := start.Pos() 261 for i++; i < len(node.Decls); { 262 cont, ok := node.Decls[i].(*ast.GenDecl) 263 if !ok || cont.Tok != start.Tok || cont.Lparen != token.NoPos || 264 f.Line(lastPos) < f.Line(cont.Pos())-1 || isCgoImport(cont) { 265 break 266 } 267 start.Specs = append(start.Specs, cont.Specs...) 268 if c := f.inlineComment(cont.End()); c != nil { 269 // don't move an inline comment outside 270 start.Rparen = c.End() 271 } 272 lastPos = cont.Pos() 273 i++ 274 } 275 } 276 node.Decls = newDecls 277 278 // Comments aren't nodes, so they're not walked by default. 279 groupLoop: 280 for _, group := range node.Comments { 281 for _, comment := range group.List { 282 body := strings.TrimPrefix(comment.Text, "//") 283 if body == comment.Text { 284 // /*-style comment 285 continue groupLoop 286 } 287 if rxCommentDirective.MatchString(body) { 288 // this line is a directive 289 continue groupLoop 290 } 291 r, _ := utf8.DecodeRuneInString(body) 292 if !unicode.IsLetter(r) && !unicode.IsNumber(r) && !unicode.IsSpace(r) { 293 // this line could be code like "//{" 294 continue groupLoop 295 } 296 } 297 // If none of the comment group's lines look like a 298 // directive or code, add spaces, if needed. 299 for _, comment := range group.List { 300 body := strings.TrimPrefix(comment.Text, "//") 301 r, _ := utf8.DecodeRuneInString(body) 302 if !unicode.IsSpace(r) { 303 comment.Text = "// " + strings.TrimPrefix(comment.Text, "//") 304 } 305 } 306 } 307 308 case *ast.DeclStmt: 309 decl, ok := node.Decl.(*ast.GenDecl) 310 if !ok || decl.Tok != token.VAR || len(decl.Specs) != 1 { 311 break // e.g. const name = "value" 312 } 313 spec := decl.Specs[0].(*ast.ValueSpec) 314 if spec.Type != nil { 315 break // e.g. var name Type 316 } 317 tok := token.ASSIGN 318 names := make([]ast.Expr, len(spec.Names)) 319 for i, name := range spec.Names { 320 names[i] = name 321 if name.Name != "_" { 322 tok = token.DEFINE 323 } 324 } 325 c.Replace(&ast.AssignStmt{ 326 Lhs: names, 327 Tok: tok, 328 Rhs: spec.Values, 329 }) 330 331 case *ast.GenDecl: 332 if node.Tok == token.IMPORT && node.Lparen.IsValid() { 333 f.joinStdImports(node) 334 } 335 336 // Single var declarations shouldn't use parentheses, unless 337 // there's a comment on the grouped declaration. 338 if node.Tok == token.VAR && len(node.Specs) == 1 && 339 node.Lparen.IsValid() && node.Doc == nil { 340 specPos := node.Specs[0].Pos() 341 specEnd := node.Specs[0].End() 342 343 if len(f.commentsBetween(node.TokPos, specPos)) > 0 { 344 // If the single spec has any comment, it must 345 // go before the entire declaration now. 346 node.TokPos = specPos 347 } else { 348 f.removeLines(f.Line(node.TokPos), f.Line(specPos)) 349 } 350 f.removeLines(f.Line(specEnd), f.Line(node.Rparen)) 351 352 // Remove the parentheses. go/printer will automatically 353 // get rid of the newlines. 354 node.Lparen = token.NoPos 355 node.Rparen = token.NoPos 356 } 357 358 case *ast.BlockStmt: 359 f.stmts(node.List) 360 comments := f.commentsBetween(node.Lbrace, node.Rbrace) 361 if len(node.List) == 0 && len(comments) == 0 { 362 f.removeLinesBetween(node.Lbrace, node.Rbrace) 363 break 364 } 365 366 var sign *ast.FuncType 367 var cond ast.Expr 368 switch parent := c.Parent().(type) { 369 case *ast.FuncDecl: 370 sign = parent.Type 371 case *ast.FuncLit: 372 sign = parent.Type 373 case *ast.IfStmt: 374 cond = parent.Cond 375 case *ast.ForStmt: 376 cond = parent.Cond 377 } 378 379 if len(node.List) > 1 && sign == nil { 380 // only if we have a single statement, or if 381 // it's a func body. 382 break 383 } 384 var bodyPos, bodyEnd token.Pos 385 386 if len(node.List) > 0 { 387 bodyPos = node.List[0].Pos() 388 bodyEnd = node.List[len(node.List)-1].End() 389 } 390 if len(comments) > 0 { 391 if pos := comments[0].Pos(); !bodyPos.IsValid() || pos < bodyPos { 392 bodyPos = pos 393 } 394 if pos := comments[len(comments)-1].End(); !bodyPos.IsValid() || pos > bodyEnd { 395 bodyEnd = pos 396 } 397 } 398 399 f.removeLinesBetween(bodyEnd, node.Rbrace) 400 401 if cond != nil && f.Line(cond.Pos()) != f.Line(cond.End()) { 402 // The body is preceded by a multi-line condition, so an 403 // empty line can help readability. 404 return 405 } 406 if sign != nil { 407 var lastParam *ast.Field 408 if l := sign.Results; l != nil && len(l.List) > 0 { 409 lastParam = l.List[len(l.List)-1] 410 } else if l := sign.Params; l != nil && len(l.List) > 0 { 411 lastParam = l.List[len(l.List)-1] 412 } 413 endLine := f.Line(sign.End()) 414 if lastParam != nil && f.Line(sign.Pos()) != endLine && f.Line(lastParam.Pos()) == endLine { 415 // The body is preceded by a multi-line function 416 // signature, and the empty line helps readability. 417 return 418 } 419 } 420 421 f.removeLinesBetween(node.Lbrace, bodyPos) 422 423 case *ast.CompositeLit: 424 if len(node.Elts) == 0 { 425 // doesn't have elements 426 break 427 } 428 openLine := f.Line(node.Lbrace) 429 closeLine := f.Line(node.Rbrace) 430 if openLine == closeLine { 431 // all in a single line 432 break 433 } 434 435 newlineAroundElems := false 436 newlineBetweenElems := false 437 lastLine := openLine 438 for i, elem := range node.Elts { 439 if f.Line(elem.Pos()) > lastLine { 440 if i == 0 { 441 newlineAroundElems = true 442 } else { 443 newlineBetweenElems = true 444 } 445 } 446 lastLine = f.Line(elem.End()) 447 } 448 if closeLine > lastLine { 449 newlineAroundElems = true 450 } 451 452 if newlineBetweenElems || newlineAroundElems { 453 first := node.Elts[0] 454 if openLine == f.Line(first.Pos()) { 455 // We want the newline right after the brace. 456 f.addNewline(node.Lbrace + 1) 457 closeLine = f.Line(node.Rbrace) 458 } 459 last := node.Elts[len(node.Elts)-1] 460 if closeLine == f.Line(last.End()) { 461 // We want the newline right before the brace. 462 f.addNewline(node.Rbrace) 463 } 464 } 465 466 // If there's a newline between any consecutive elements, there 467 // must be a newline between all composite literal elements. 468 if !newlineBetweenElems { 469 break 470 } 471 for i1, elem1 := range node.Elts { 472 i2 := i1 + 1 473 if i2 >= len(node.Elts) { 474 break 475 } 476 elem2 := node.Elts[i2] 477 // TODO: do we care about &{}? 478 _, ok1 := elem1.(*ast.CompositeLit) 479 _, ok2 := elem2.(*ast.CompositeLit) 480 if !ok1 && !ok2 { 481 continue 482 } 483 if f.Line(elem1.End()) == f.Line(elem2.Pos()) { 484 f.addNewline(elem1.End()) 485 } 486 } 487 488 case *ast.CaseClause: 489 f.stmts(node.Body) 490 openLine := f.Line(node.Case) 491 closeLine := f.Line(node.Colon) 492 if openLine == closeLine { 493 // nothing to do 494 break 495 } 496 if len(f.commentsBetween(node.Case, node.Colon)) > 0 { 497 // don't move comments 498 break 499 } 500 if f.printLength(node) > shortLineLimit { 501 // too long to collapse 502 break 503 } 504 f.removeLines(openLine, closeLine) 505 506 case *ast.CommClause: 507 f.stmts(node.Body) 508 509 case *ast.FieldList: 510 if node.NumFields() == 0 && f.inlineComment(node.Pos()) == nil { 511 // Empty field lists should not contain a newline. 512 // Do not join the two lines if the first has an inline 513 // comment, as that can result in broken formatting. 514 openLine := f.Line(node.Pos()) 515 closeLine := f.Line(node.End()) 516 f.removeLines(openLine, closeLine) 517 } 518 519 // Merging adjacent fields (e.g. parameters) is disabled by default. 520 if !f.ExtraRules { 521 break 522 } 523 switch c.Parent().(type) { 524 case *ast.FuncDecl, *ast.FuncType, *ast.InterfaceType: 525 node.List = f.mergeAdjacentFields(node.List) 526 c.Replace(node) 527 case *ast.StructType: 528 // Do not merge adjacent fields in structs. 529 } 530 531 case *ast.BasicLit: 532 // Octal number literals were introduced in 1.13. 533 if semver.Compare(f.LangVersion, "v1.13") >= 0 { 534 if node.Kind == token.INT && rxOctalInteger.MatchString(node.Value) { 535 node.Value = "0o" + node.Value[1:] 536 c.Replace(node) 537 } 538 } 539 } 540} 541 542func (f *fumpter) stmts(list []ast.Stmt) { 543 for i, stmt := range list { 544 ifs, ok := stmt.(*ast.IfStmt) 545 if !ok || i < 1 { 546 continue // not an if following another statement 547 } 548 as, ok := list[i-1].(*ast.AssignStmt) 549 if !ok || as.Tok != token.DEFINE || 550 !identEqual(as.Lhs[len(as.Lhs)-1], "err") { 551 continue // not "..., err := ..." 552 } 553 be, ok := ifs.Cond.(*ast.BinaryExpr) 554 if !ok || ifs.Init != nil || ifs.Else != nil { 555 continue // complex if 556 } 557 if be.Op != token.NEQ || !identEqual(be.X, "err") || 558 !identEqual(be.Y, "nil") { 559 continue // not "err != nil" 560 } 561 f.removeLinesBetween(as.End(), ifs.Pos()) 562 } 563} 564 565func identEqual(expr ast.Expr, name string) bool { 566 id, ok := expr.(*ast.Ident) 567 return ok && id.Name == name 568} 569 570// isCgoImport returns true if the declaration is simply: 571// 572// import "C" 573// 574// Note that parentheses do not affect the result. 575func isCgoImport(decl *ast.GenDecl) bool { 576 if decl.Tok != token.IMPORT || len(decl.Specs) != 1 { 577 return false 578 } 579 spec := decl.Specs[0].(*ast.ImportSpec) 580 return spec.Path.Value == `"C"` 581} 582 583// joinStdImports ensures that all standard library imports are together and at 584// the top of the imports list. 585func (f *fumpter) joinStdImports(d *ast.GenDecl) { 586 var std, other []ast.Spec 587 firstGroup := true 588 lastEnd := d.Pos() 589 needsSort := false 590 for i, spec := range d.Specs { 591 spec := spec.(*ast.ImportSpec) 592 if coms := f.commentsBetween(lastEnd, spec.Pos()); len(coms) > 0 { 593 lastEnd = coms[len(coms)-1].End() 594 } 595 if i > 0 && firstGroup && f.Line(spec.Pos()) > f.Line(lastEnd)+1 { 596 firstGroup = false 597 } else { 598 // We're still in the first group, update lastEnd. 599 lastEnd = spec.End() 600 } 601 602 path, _ := strconv.Unquote(spec.Path.Value) 603 switch { 604 // Imports with a period are definitely third party. 605 case strings.Contains(path, "."): 606 fallthrough 607 // "test" and "example" are reserved as per golang.org/issue/37641. 608 // "internal" is unreachable. 609 case strings.HasPrefix(path, "test/") || 610 strings.HasPrefix(path, "example/") || 611 strings.HasPrefix(path, "internal/"): 612 fallthrough 613 // To be conservative, if an import has a name or an inline 614 // comment, and isn't part of the top group, treat it as non-std. 615 case !firstGroup && (spec.Name != nil || spec.Comment != nil): 616 other = append(other, spec) 617 continue 618 } 619 620 // If we're moving this std import further up, reset its 621 // position, to avoid breaking comments. 622 if !firstGroup || len(other) > 0 { 623 setPos(reflect.ValueOf(spec), d.Pos()) 624 needsSort = true 625 } 626 std = append(std, spec) 627 } 628 // Ensure there is an empty line between std imports and other imports. 629 if len(std) > 0 && len(other) > 0 && f.Line(std[len(std)-1].End())+1 >= f.Line(other[0].Pos()) { 630 // We add two newlines, as that's necessary in some edge cases. 631 // For example, if the std and non-std imports were together and 632 // without indentation, adding one newline isn't enough. Two 633 // empty lines will be printed as one by go/printer, anyway. 634 f.addNewline(other[0].Pos() - 1) 635 f.addNewline(other[0].Pos()) 636 } 637 // Finally, join the imports, keeping std at the top. 638 d.Specs = append(std, other...) 639 640 // If we moved any std imports to the first group, we need to sort them 641 // again. 642 if needsSort { 643 ast.SortImports(f.fset, f.astFile) 644 } 645} 646 647// mergeAdjacentFields returns fields with adjacent fields merged if possible. 648func (f *fumpter) mergeAdjacentFields(fields []*ast.Field) []*ast.Field { 649 // If there are less than two fields then there is nothing to merge. 650 if len(fields) < 2 { 651 return fields 652 } 653 654 // Otherwise, iterate over adjacent pairs of fields, merging if possible, 655 // and mutating fields. Elements of fields may be mutated (if merged with 656 // following fields), discarded (if merged with a preceeding field), or left 657 // unchanged. 658 i := 0 659 for j := 1; j < len(fields); j++ { 660 if f.shouldMergeAdjacentFields(fields[i], fields[j]) { 661 fields[i].Names = append(fields[i].Names, fields[j].Names...) 662 } else { 663 i++ 664 fields[i] = fields[j] 665 } 666 } 667 return fields[:i+1] 668} 669 670func (f *fumpter) shouldMergeAdjacentFields(f1, f2 *ast.Field) bool { 671 if len(f1.Names) == 0 || len(f2.Names) == 0 { 672 // Both must have names for the merge to work. 673 return false 674 } 675 if f.Line(f1.Pos()) != f.Line(f2.Pos()) { 676 // Trust the user if they used separate lines. 677 return false 678 } 679 680 // Only merge if the types are equal. 681 opt := cmp.Comparer(func(x, y token.Pos) bool { return true }) 682 return cmp.Equal(f1.Type, f2.Type, opt) 683} 684 685var posType = reflect.TypeOf(token.NoPos) 686 687// setPos recursively sets all position fields in the node v to pos. 688func setPos(v reflect.Value, pos token.Pos) { 689 if v.Kind() == reflect.Ptr { 690 v = v.Elem() 691 } 692 if !v.IsValid() { 693 return 694 } 695 if v.Type() == posType { 696 v.Set(reflect.ValueOf(pos)) 697 } 698 if v.Kind() == reflect.Struct { 699 for i := 0; i < v.NumField(); i++ { 700 setPos(v.Field(i), pos) 701 } 702 } 703} 704