1// Copyright (c) 2019, Daniel Martí <mvdan@mvdan.cc> 2// See LICENSE for licensing information 3 4// Package format exposes gofumpt's formatting in an API similar to go/format. 5// In general, the APIs are only guaranteed to work well when the input source 6// is in canonical gofmt format. 7package format 8 9import ( 10 "bytes" 11 "fmt" 12 "go/ast" 13 "go/format" 14 "go/parser" 15 "go/token" 16 "reflect" 17 "regexp" 18 "sort" 19 "strings" 20 "unicode" 21 "unicode/utf8" 22 23 "github.com/google/go-cmp/cmp" 24 "golang.org/x/mod/semver" 25 "golang.org/x/tools/go/ast/astutil" 26) 27 28type Options struct { 29 // LangVersion corresponds to the Go language version a piece of code is 30 // written in. The version is used to decide whether to apply formatting 31 // rules which require new language features. When inside a Go module, 32 // LangVersion should generally be specified as the result of: 33 // 34 // go list -m -f {{.GoVersion}} 35 // 36 // LangVersion is treated as a semantic version, which might start with 37 // a "v" prefix. Like Go versions, it might also be incomplete; "1.14" 38 // is equivalent to "1.14.0". When empty, it is equivalent to "v1", to 39 // not use language features which could break programs. 40 LangVersion string 41 42 ExtraRules bool 43} 44 45// Source formats src in gofumpt's format, assuming that src holds a valid Go 46// source file. 47func Source(src []byte, opts Options) ([]byte, error) { 48 fset := token.NewFileSet() 49 file, err := parser.ParseFile(fset, "", src, parser.ParseComments) 50 if err != nil { 51 return nil, err 52 } 53 54 File(fset, file, opts) 55 56 var buf bytes.Buffer 57 if err := format.Node(&buf, fset, file); err != nil { 58 return nil, err 59 } 60 return buf.Bytes(), nil 61} 62 63// File modifies a file and fset in place to follow gofumpt's format. The 64// changes might include manipulating adding or removing newlines in fset, 65// modifying the position of nodes, or modifying literal values. 66func File(fset *token.FileSet, file *ast.File, opts Options) { 67 if opts.LangVersion == "" { 68 opts.LangVersion = "v1" 69 } else if opts.LangVersion[0] != 'v' { 70 opts.LangVersion = "v" + opts.LangVersion 71 } 72 if !semver.IsValid(opts.LangVersion) { 73 panic(fmt.Sprintf("invalid semver string: %q", opts.LangVersion)) 74 } 75 f := &fumpter{ 76 File: fset.File(file.Pos()), 77 fset: fset, 78 astFile: file, 79 Options: opts, 80 } 81 pre := func(c *astutil.Cursor) bool { 82 f.applyPre(c) 83 if _, ok := c.Node().(*ast.BlockStmt); ok { 84 f.blockLevel++ 85 } 86 return true 87 } 88 post := func(c *astutil.Cursor) bool { 89 if _, ok := c.Node().(*ast.BlockStmt); ok { 90 f.blockLevel-- 91 } 92 return true 93 } 94 astutil.Apply(file, pre, post) 95} 96 97// Multiline nodes which could fit on a single line under this many 98// bytes may be collapsed onto a single line. 99const shortLineLimit = 60 100 101var rxOctalInteger = regexp.MustCompile(`\A0[0-7_]+\z`) 102 103type fumpter struct { 104 Options 105 106 *token.File 107 fset *token.FileSet 108 109 astFile *ast.File 110 111 blockLevel int 112} 113 114func (f *fumpter) commentsBetween(p1, p2 token.Pos) []*ast.CommentGroup { 115 comments := f.astFile.Comments 116 i1 := sort.Search(len(comments), func(i int) bool { 117 return comments[i].Pos() >= p1 118 }) 119 comments = comments[i1:] 120 i2 := sort.Search(len(comments), func(i int) bool { 121 return comments[i].Pos() >= p2 122 }) 123 comments = comments[:i2] 124 return comments 125} 126 127func (f *fumpter) inlineComment(pos token.Pos) *ast.Comment { 128 comments := f.astFile.Comments 129 i := sort.Search(len(comments), func(i int) bool { 130 return comments[i].Pos() >= pos 131 }) 132 if i >= len(comments) { 133 return nil 134 } 135 line := f.Line(pos) 136 for _, comment := range comments[i].List { 137 if f.Line(comment.Pos()) == line { 138 return comment 139 } 140 } 141 return nil 142} 143 144// addNewline is a hack to let us force a newline at a certain position. 145func (f *fumpter) addNewline(at token.Pos) { 146 offset := f.Offset(at) 147 148 field := reflect.ValueOf(f.File).Elem().FieldByName("lines") 149 n := field.Len() 150 lines := make([]int, 0, n+1) 151 for i := 0; i < n; i++ { 152 cur := int(field.Index(i).Int()) 153 if offset == cur { 154 // This newline already exists; do nothing. Duplicate 155 // newlines can't exist. 156 return 157 } 158 if offset >= 0 && offset < cur { 159 lines = append(lines, offset) 160 offset = -1 161 } 162 lines = append(lines, cur) 163 } 164 if offset >= 0 { 165 lines = append(lines, offset) 166 } 167 if !f.SetLines(lines) { 168 panic(fmt.Sprintf("could not set lines to %v", lines)) 169 } 170} 171 172// removeNewlines removes all newlines between two positions, so that they end 173// up on the same line. 174func (f *fumpter) removeLines(fromLine, toLine int) { 175 for fromLine < toLine { 176 f.MergeLine(fromLine) 177 toLine-- 178 } 179} 180 181// removeLinesBetween is like removeLines, but it leaves one newline between the 182// two positions. 183func (f *fumpter) removeLinesBetween(from, to token.Pos) { 184 f.removeLines(f.Line(from)+1, f.Line(to)) 185} 186 187type byteCounter int 188 189func (b *byteCounter) Write(p []byte) (n int, err error) { 190 *b += byteCounter(len(p)) 191 return len(p), nil 192} 193 194func (f *fumpter) printLength(node ast.Node) int { 195 var count byteCounter 196 if err := format.Node(&count, f.fset, node); err != nil { 197 panic(fmt.Sprintf("unexpected print error: %v", err)) 198 } 199 200 // Add the space taken by an inline comment. 201 if c := f.inlineComment(node.End()); c != nil { 202 fmt.Fprintf(&count, " %s", c.Text) 203 } 204 205 // Add an approximation of the indentation level. We can't know the 206 // number of tabs go/printer will add ahead of time. Trying to print the 207 // entire top-level declaration would tell us that, but then it's near 208 // impossible to reliably find our node again. 209 return int(count) + (f.blockLevel * 8) 210} 211 212// rxCommentDirective covers all common Go comment directives: 213// 214// //go: | standard Go directives, like go:noinline 215// //someword: | similar to the syntax above, like lint:ignore 216// //line | inserted line information for cmd/compile 217// //export | to mark cgo funcs for exporting 218// //extern | C function declarations for gccgo 219// //sys(nb)? | syscall function wrapper prototypes 220// //nolint | nolint directive for golangci 221var rxCommentDirective = regexp.MustCompile(`^([a-z]+:|line\b|export\b|extern\b|sys(nb)?\b|nolint\b)`) 222 223// visit takes either an ast.Node or a []ast.Stmt. 224func (f *fumpter) applyPre(c *astutil.Cursor) { 225 switch node := c.Node().(type) { 226 case *ast.File: 227 var lastMulti bool 228 var lastEnd token.Pos 229 for _, decl := range node.Decls { 230 pos := decl.Pos() 231 comments := f.commentsBetween(lastEnd, pos) 232 if len(comments) > 0 { 233 pos = comments[0].Pos() 234 } 235 236 // multiline top-level declarations should be separated 237 multi := f.Line(pos) < f.Line(decl.End()) 238 if multi && lastMulti && f.Line(lastEnd)+1 == f.Line(pos) { 239 f.addNewline(lastEnd) 240 } 241 242 lastMulti = multi 243 lastEnd = decl.End() 244 } 245 246 // Join contiguous lone var/const/import lines; abort if there 247 // are empty lines or comments in between. 248 newDecls := make([]ast.Decl, 0, len(node.Decls)) 249 for i := 0; i < len(node.Decls); { 250 newDecls = append(newDecls, node.Decls[i]) 251 start, ok := node.Decls[i].(*ast.GenDecl) 252 if !ok { 253 i++ 254 continue 255 } 256 lastPos := start.Pos() 257 for i++; i < len(node.Decls); { 258 cont, ok := node.Decls[i].(*ast.GenDecl) 259 if !ok || cont.Tok != start.Tok || cont.Lparen != token.NoPos || 260 f.Line(lastPos) < f.Line(cont.Pos())-1 { 261 break 262 } 263 start.Specs = append(start.Specs, cont.Specs...) 264 if c := f.inlineComment(cont.End()); c != nil { 265 // don't move an inline comment outside 266 start.Rparen = c.End() 267 } 268 lastPos = cont.Pos() 269 i++ 270 } 271 } 272 node.Decls = newDecls 273 274 // Comments aren't nodes, so they're not walked by default. 275 groupLoop: 276 for _, group := range node.Comments { 277 for _, comment := range group.List { 278 body := strings.TrimPrefix(comment.Text, "//") 279 if body == comment.Text { 280 // /*-style comment 281 continue groupLoop 282 } 283 if rxCommentDirective.MatchString(body) { 284 // this line is a directive 285 continue groupLoop 286 } 287 r, _ := utf8.DecodeRuneInString(body) 288 if !unicode.IsLetter(r) && !unicode.IsNumber(r) && !unicode.IsSpace(r) { 289 // this line could be code like "//{" 290 continue groupLoop 291 } 292 } 293 // If none of the comment group's lines look like a 294 // directive or code, add spaces, if needed. 295 for _, comment := range group.List { 296 body := strings.TrimPrefix(comment.Text, "//") 297 r, _ := utf8.DecodeRuneInString(body) 298 if !unicode.IsSpace(r) { 299 comment.Text = "// " + strings.TrimPrefix(comment.Text, "//") 300 } 301 } 302 } 303 304 case *ast.DeclStmt: 305 decl, ok := node.Decl.(*ast.GenDecl) 306 if !ok || decl.Tok != token.VAR || len(decl.Specs) != 1 { 307 break // e.g. const name = "value" 308 } 309 spec := decl.Specs[0].(*ast.ValueSpec) 310 if spec.Type != nil { 311 break // e.g. var name Type 312 } 313 tok := token.ASSIGN 314 names := make([]ast.Expr, len(spec.Names)) 315 for i, name := range spec.Names { 316 names[i] = name 317 if name.Name != "_" { 318 tok = token.DEFINE 319 } 320 } 321 c.Replace(&ast.AssignStmt{ 322 Lhs: names, 323 Tok: tok, 324 Rhs: spec.Values, 325 }) 326 327 case *ast.GenDecl: 328 if node.Tok == token.IMPORT && node.Lparen.IsValid() { 329 f.joinStdImports(node) 330 } 331 332 // Single var declarations shouldn't use parentheses, unless 333 // there's a comment on the grouped declaration. 334 if node.Tok == token.VAR && len(node.Specs) == 1 && 335 node.Lparen.IsValid() && node.Doc == nil { 336 specPos := node.Specs[0].Pos() 337 specEnd := node.Specs[0].End() 338 339 if len(f.commentsBetween(node.TokPos, specPos)) > 0 { 340 // If the single spec has any comment, it must 341 // go before the entire declaration now. 342 node.TokPos = specPos 343 } else { 344 f.removeLines(f.Line(node.TokPos), f.Line(specPos)) 345 } 346 f.removeLines(f.Line(specEnd), f.Line(node.Rparen)) 347 348 // Remove the parentheses. go/printer will automatically 349 // get rid of the newlines. 350 node.Lparen = token.NoPos 351 node.Rparen = token.NoPos 352 } 353 354 case *ast.BlockStmt: 355 f.stmts(node.List) 356 comments := f.commentsBetween(node.Lbrace, node.Rbrace) 357 if len(node.List) == 0 && len(comments) == 0 { 358 f.removeLinesBetween(node.Lbrace, node.Rbrace) 359 break 360 } 361 362 isFuncBody := false 363 switch c.Parent().(type) { 364 case *ast.FuncDecl: 365 isFuncBody = true 366 case *ast.FuncLit: 367 isFuncBody = true 368 } 369 370 if len(node.List) > 1 && !isFuncBody { 371 // only if we have a single statement, or if 372 // it's a func body. 373 break 374 } 375 var bodyPos, bodyEnd token.Pos 376 377 if len(node.List) > 0 { 378 bodyPos = node.List[0].Pos() 379 bodyEnd = node.List[len(node.List)-1].End() 380 } 381 if len(comments) > 0 { 382 if pos := comments[0].Pos(); !bodyPos.IsValid() || pos < bodyPos { 383 bodyPos = pos 384 } 385 if pos := comments[len(comments)-1].End(); !bodyPos.IsValid() || pos > bodyEnd { 386 bodyEnd = pos 387 } 388 } 389 390 f.removeLinesBetween(node.Lbrace, bodyPos) 391 f.removeLinesBetween(bodyEnd, node.Rbrace) 392 393 case *ast.CompositeLit: 394 if len(node.Elts) == 0 { 395 // doesn't have elements 396 break 397 } 398 openLine := f.Line(node.Lbrace) 399 closeLine := f.Line(node.Rbrace) 400 if openLine == closeLine { 401 // all in a single line 402 break 403 } 404 405 newlineAroundElems := false 406 newlineBetweenElems := false 407 lastLine := openLine 408 for i, elem := range node.Elts { 409 if f.Line(elem.Pos()) > lastLine { 410 if i == 0 { 411 newlineAroundElems = true 412 } else { 413 newlineBetweenElems = true 414 } 415 } 416 lastLine = f.Line(elem.End()) 417 } 418 if closeLine > lastLine { 419 newlineAroundElems = true 420 } 421 422 if newlineBetweenElems || newlineAroundElems { 423 first := node.Elts[0] 424 if openLine == f.Line(first.Pos()) { 425 // We want the newline right after the brace. 426 f.addNewline(node.Lbrace + 1) 427 closeLine = f.Line(node.Rbrace) 428 } 429 last := node.Elts[len(node.Elts)-1] 430 if closeLine == f.Line(last.End()) { 431 // We want the newline right before the brace. 432 f.addNewline(node.Rbrace) 433 } 434 } 435 436 // If there's a newline between any consecutive elements, there 437 // must be a newline between all composite literal elements. 438 if !newlineBetweenElems { 439 break 440 } 441 for i1, elem1 := range node.Elts { 442 i2 := i1 + 1 443 if i2 >= len(node.Elts) { 444 break 445 } 446 elem2 := node.Elts[i2] 447 // TODO: do we care about &{}? 448 _, ok1 := elem1.(*ast.CompositeLit) 449 _, ok2 := elem2.(*ast.CompositeLit) 450 if !ok1 && !ok2 { 451 continue 452 } 453 if f.Line(elem1.End()) == f.Line(elem2.Pos()) { 454 f.addNewline(elem1.End()) 455 } 456 } 457 458 case *ast.CaseClause: 459 f.stmts(node.Body) 460 openLine := f.Line(node.Case) 461 closeLine := f.Line(node.Colon) 462 if openLine == closeLine { 463 // nothing to do 464 break 465 } 466 if len(f.commentsBetween(node.Case, node.Colon)) > 0 { 467 // don't move comments 468 break 469 } 470 if f.printLength(node) > shortLineLimit { 471 // too long to collapse 472 break 473 } 474 f.removeLines(openLine, closeLine) 475 476 case *ast.CommClause: 477 f.stmts(node.Body) 478 479 case *ast.FieldList: 480 // Merging adjacent fields (e.g. parameters) is disabled by default. 481 if !f.ExtraRules { 482 break 483 } 484 switch c.Parent().(type) { 485 case *ast.FuncDecl, *ast.FuncType, *ast.InterfaceType: 486 node.List = f.mergeAdjacentFields(node.List) 487 c.Replace(node) 488 case *ast.StructType: 489 // Do not merge adjacent fields in structs. 490 } 491 492 case *ast.BasicLit: 493 // Octal number literals were introduced in 1.13. 494 if semver.Compare(f.LangVersion, "v1.13") >= 0 { 495 if node.Kind == token.INT && rxOctalInteger.MatchString(node.Value) { 496 node.Value = "0o" + node.Value[1:] 497 c.Replace(node) 498 } 499 } 500 } 501} 502 503func (f *fumpter) stmts(list []ast.Stmt) { 504 for i, stmt := range list { 505 ifs, ok := stmt.(*ast.IfStmt) 506 if !ok || i < 1 { 507 continue // not an if following another statement 508 } 509 as, ok := list[i-1].(*ast.AssignStmt) 510 if !ok || as.Tok != token.DEFINE || 511 !identEqual(as.Lhs[len(as.Lhs)-1], "err") { 512 continue // not "..., err := ..." 513 } 514 be, ok := ifs.Cond.(*ast.BinaryExpr) 515 if !ok || ifs.Init != nil || ifs.Else != nil { 516 continue // complex if 517 } 518 if be.Op != token.NEQ || !identEqual(be.X, "err") || 519 !identEqual(be.Y, "nil") { 520 continue // not "err != nil" 521 } 522 f.removeLinesBetween(as.End(), ifs.Pos()) 523 } 524} 525 526func identEqual(expr ast.Expr, name string) bool { 527 id, ok := expr.(*ast.Ident) 528 return ok && id.Name == name 529} 530 531// joinStdImports ensures that all standard library imports are together and at 532// the top of the imports list. 533func (f *fumpter) joinStdImports(d *ast.GenDecl) { 534 var std, other []ast.Spec 535 firstGroup := true 536 lastEnd := d.Pos() 537 needsSort := false 538 for i, spec := range d.Specs { 539 spec := spec.(*ast.ImportSpec) 540 if coms := f.commentsBetween(lastEnd, spec.Pos()); len(coms) > 0 { 541 lastEnd = coms[len(coms)-1].End() 542 } 543 if i > 0 && firstGroup && f.Line(spec.Pos()) > f.Line(lastEnd)+1 { 544 firstGroup = false 545 } else { 546 // We're still in the first group, update lastEnd. 547 lastEnd = spec.End() 548 } 549 550 // First, separate the non-std imports. 551 if strings.Contains(spec.Path.Value, ".") { 552 other = append(other, spec) 553 continue 554 } 555 // To be conservative, if an import has a name or an inline 556 // comment, and isn't part of the top group, treat it as non-std. 557 if !firstGroup && (spec.Name != nil || spec.Comment != nil) { 558 other = append(other, spec) 559 continue 560 } 561 562 // If we're moving this std import further up, reset its 563 // position, to avoid breaking comments. 564 if !firstGroup || len(other) > 0 { 565 setPos(reflect.ValueOf(spec), d.Pos()) 566 needsSort = true 567 } 568 std = append(std, spec) 569 } 570 // Ensure there is an empty line between std imports and other imports. 571 if len(std) > 0 && len(other) > 0 && f.Line(std[len(std)-1].End())+1 >= f.Line(other[0].Pos()) { 572 // We add two newlines, as that's necessary in some edge cases. 573 // For example, if the std and non-std imports were together and 574 // without indentation, adding one newline isn't enough. Two 575 // empty lines will be printed as one by go/printer, anyway. 576 f.addNewline(other[0].Pos() - 1) 577 f.addNewline(other[0].Pos()) 578 } 579 // Finally, join the imports, keeping std at the top. 580 d.Specs = append(std, other...) 581 582 // If we moved any std imports to the first group, we need to sort them 583 // again. 584 if needsSort { 585 ast.SortImports(f.fset, f.astFile) 586 } 587} 588 589// mergeAdjacentFields returns fields with adjacent fields merged if possible. 590func (f *fumpter) mergeAdjacentFields(fields []*ast.Field) []*ast.Field { 591 // If there are less than two fields then there is nothing to merge. 592 if len(fields) < 2 { 593 return fields 594 } 595 596 // Otherwise, iterate over adjacent pairs of fields, merging if possible, 597 // and mutating fields. Elements of fields may be mutated (if merged with 598 // following fields), discarded (if merged with a preceeding field), or left 599 // unchanged. 600 i := 0 601 for j := 1; j < len(fields); j++ { 602 if f.shouldMergeAdjacentFields(fields[i], fields[j]) { 603 fields[i].Names = append(fields[i].Names, fields[j].Names...) 604 } else { 605 i++ 606 fields[i] = fields[j] 607 } 608 } 609 return fields[:i+1] 610} 611 612func (f *fumpter) shouldMergeAdjacentFields(f1, f2 *ast.Field) bool { 613 if len(f1.Names) == 0 || len(f2.Names) == 0 { 614 // Both must have names for the merge to work. 615 return false 616 } 617 if f.Line(f1.Pos()) != f.Line(f2.Pos()) { 618 // Trust the user if they used separate lines. 619 return false 620 } 621 622 // Only merge if the types are equal. 623 opt := cmp.Comparer(func(x, y token.Pos) bool { return true }) 624 return cmp.Equal(f1.Type, f2.Type, opt) 625} 626 627var posType = reflect.TypeOf(token.NoPos) 628 629// setPos recursively sets all position fields in the node v to pos. 630func setPos(v reflect.Value, pos token.Pos) { 631 if v.Kind() == reflect.Ptr { 632 v = v.Elem() 633 } 634 if !v.IsValid() { 635 return 636 } 637 if v.Type() == posType { 638 v.Set(reflect.ValueOf(pos)) 639 } 640 if v.Kind() == reflect.Struct { 641 for i := 0; i < v.NumField(); i++ { 642 setPos(v.Field(i), pos) 643 } 644 } 645} 646