1// Copyright 2020 The Go Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style 3// license that can be found in the LICENSE file. 4 5package source 6 7import ( 8 "bytes" 9 "fmt" 10 "go/ast" 11 "go/format" 12 "go/parser" 13 "go/token" 14 "go/types" 15 "strings" 16 "unicode" 17 18 "golang.org/x/tools/go/analysis" 19 "golang.org/x/tools/go/ast/astutil" 20 "golang.org/x/tools/internal/analysisinternal" 21 "golang.org/x/tools/internal/span" 22) 23 24func extractVariable(fset *token.FileSet, rng span.Range, src []byte, file *ast.File, _ *types.Package, info *types.Info) (*analysis.SuggestedFix, error) { 25 expr, path, ok, err := canExtractVariable(rng, file) 26 if !ok { 27 return nil, fmt.Errorf("extractVariable: cannot extract %s: %v", fset.Position(rng.Start), err) 28 } 29 30 // Create new AST node for extracted code. 31 var lhsNames []string 32 switch expr := expr.(type) { 33 // TODO: stricter rules for selectorExpr. 34 case *ast.BasicLit, *ast.CompositeLit, *ast.IndexExpr, *ast.SliceExpr, 35 *ast.UnaryExpr, *ast.BinaryExpr, *ast.SelectorExpr: 36 lhsNames = append(lhsNames, generateAvailableIdentifier(expr.Pos(), file, path, info, "x", 0)) 37 case *ast.CallExpr: 38 tup, ok := info.TypeOf(expr).(*types.Tuple) 39 if !ok { 40 // If the call expression only has one return value, we can treat it the 41 // same as our standard extract variable case. 42 lhsNames = append(lhsNames, 43 generateAvailableIdentifier(expr.Pos(), file, path, info, "x", 0)) 44 break 45 } 46 for i := 0; i < tup.Len(); i++ { 47 // Generate a unique variable for each return value. 48 lhsNames = append(lhsNames, 49 generateAvailableIdentifier(expr.Pos(), file, path, info, "x", i)) 50 } 51 default: 52 return nil, fmt.Errorf("cannot extract %T", expr) 53 } 54 55 insertBeforeStmt := analysisinternal.StmtToInsertVarBefore(path) 56 if insertBeforeStmt == nil { 57 return nil, fmt.Errorf("cannot find location to insert extraction") 58 } 59 tok := fset.File(expr.Pos()) 60 if tok == nil { 61 return nil, fmt.Errorf("no file for pos %v", fset.Position(file.Pos())) 62 } 63 newLineIndent := "\n" + calculateIndentation(src, tok, insertBeforeStmt) 64 65 lhs := strings.Join(lhsNames, ", ") 66 assignStmt := &ast.AssignStmt{ 67 Lhs: []ast.Expr{ast.NewIdent(lhs)}, 68 Tok: token.DEFINE, 69 Rhs: []ast.Expr{expr}, 70 } 71 var buf bytes.Buffer 72 if err := format.Node(&buf, fset, assignStmt); err != nil { 73 return nil, err 74 } 75 assignment := strings.ReplaceAll(buf.String(), "\n", newLineIndent) + newLineIndent 76 77 return &analysis.SuggestedFix{ 78 TextEdits: []analysis.TextEdit{ 79 { 80 Pos: rng.Start, 81 End: rng.End, 82 NewText: []byte(lhs), 83 }, 84 { 85 Pos: insertBeforeStmt.Pos(), 86 End: insertBeforeStmt.Pos(), 87 NewText: []byte(assignment), 88 }, 89 }, 90 }, nil 91} 92 93// canExtractVariable reports whether the code in the given range can be 94// extracted to a variable. 95func canExtractVariable(rng span.Range, file *ast.File) (ast.Expr, []ast.Node, bool, error) { 96 if rng.Start == rng.End { 97 return nil, nil, false, fmt.Errorf("start and end are equal") 98 } 99 path, _ := astutil.PathEnclosingInterval(file, rng.Start, rng.End) 100 if len(path) == 0 { 101 return nil, nil, false, fmt.Errorf("no path enclosing interval") 102 } 103 for _, n := range path { 104 if _, ok := n.(*ast.ImportSpec); ok { 105 return nil, nil, false, fmt.Errorf("cannot extract variable in an import block") 106 } 107 } 108 node := path[0] 109 if rng.Start != node.Pos() || rng.End != node.End() { 110 return nil, nil, false, fmt.Errorf("range does not map to an AST node") 111 } 112 expr, ok := node.(ast.Expr) 113 if !ok { 114 return nil, nil, false, fmt.Errorf("node is not an expression") 115 } 116 switch expr.(type) { 117 case *ast.BasicLit, *ast.CompositeLit, *ast.IndexExpr, *ast.CallExpr, 118 *ast.SliceExpr, *ast.UnaryExpr, *ast.BinaryExpr, *ast.SelectorExpr: 119 return expr, path, true, nil 120 } 121 return nil, nil, false, fmt.Errorf("cannot extract an %T to a variable", expr) 122} 123 124// Calculate indentation for insertion. 125// When inserting lines of code, we must ensure that the lines have consistent 126// formatting (i.e. the proper indentation). To do so, we observe the indentation on the 127// line of code on which the insertion occurs. 128func calculateIndentation(content []byte, tok *token.File, insertBeforeStmt ast.Node) string { 129 line := tok.Line(insertBeforeStmt.Pos()) 130 lineOffset := tok.Offset(tok.LineStart(line)) 131 stmtOffset := tok.Offset(insertBeforeStmt.Pos()) 132 return string(content[lineOffset:stmtOffset]) 133} 134 135// generateAvailableIdentifier adjusts the new function name until there are no collisons in scope. 136// Possible collisions include other function and variable names. 137func generateAvailableIdentifier(pos token.Pos, file *ast.File, path []ast.Node, info *types.Info, prefix string, idx int) string { 138 scopes := CollectScopes(info, path, pos) 139 name := prefix + fmt.Sprintf("%d", idx) 140 for file.Scope.Lookup(name) != nil || !isValidName(name, scopes) { 141 idx++ 142 name = fmt.Sprintf("%v%d", prefix, idx) 143 } 144 return name 145} 146 147// isValidName checks for variable collision in scope. 148func isValidName(name string, scopes []*types.Scope) bool { 149 for _, scope := range scopes { 150 if scope == nil { 151 continue 152 } 153 if scope.Lookup(name) != nil { 154 return false 155 } 156 } 157 return true 158} 159 160// returnVariable keeps track of the information we need to properly introduce a new variable 161// that we will return in the extracted function. 162type returnVariable struct { 163 // name is the identifier that is used on the left-hand side of the call to 164 // the extracted function. 165 name ast.Expr 166 // decl is the declaration of the variable. It is used in the type signature of the 167 // extracted function and for variable declarations. 168 decl *ast.Field 169 // zeroVal is the "zero value" of the type of the variable. It is used in a return 170 // statement in the extracted function. 171 zeroVal ast.Expr 172} 173 174// extractFunction refactors the selected block of code into a new function. 175// It also replaces the selected block of code with a call to the extracted 176// function. First, we manually adjust the selection range. We remove trailing 177// and leading whitespace characters to ensure the range is precisely bounded 178// by AST nodes. Next, we determine the variables that will be the paramters 179// and return values of the extracted function. Lastly, we construct the call 180// of the function and insert this call as well as the extracted function into 181// their proper locations. 182func extractFunction(fset *token.FileSet, rng span.Range, src []byte, file *ast.File, pkg *types.Package, info *types.Info) (*analysis.SuggestedFix, error) { 183 p, ok, err := canExtractFunction(fset, rng, src, file, info) 184 if !ok { 185 return nil, fmt.Errorf("extractFunction: cannot extract %s: %v", 186 fset.Position(rng.Start), err) 187 } 188 tok, path, rng, outer, start := p.tok, p.path, p.rng, p.outer, p.start 189 fileScope := info.Scopes[file] 190 if fileScope == nil { 191 return nil, fmt.Errorf("extractFunction: file scope is empty") 192 } 193 pkgScope := fileScope.Parent() 194 if pkgScope == nil { 195 return nil, fmt.Errorf("extractFunction: package scope is empty") 196 } 197 198 // TODO: Support non-nested return statements. 199 // A return statement is non-nested if its parent node is equal to the parent node 200 // of the first node in the selection. These cases must be handled seperately because 201 // non-nested return statements are guaranteed to execute. Our control flow does not 202 // properly consider these situations yet. 203 var retStmts []*ast.ReturnStmt 204 var hasNonNestedReturn bool 205 startParent := findParent(outer, start) 206 ast.Inspect(outer, func(n ast.Node) bool { 207 if n == nil { 208 return false 209 } 210 if n.Pos() < rng.Start || n.End() > rng.End { 211 return n.Pos() <= rng.End 212 } 213 ret, ok := n.(*ast.ReturnStmt) 214 if !ok { 215 return true 216 } 217 if findParent(outer, n) == startParent { 218 hasNonNestedReturn = true 219 return false 220 } 221 retStmts = append(retStmts, ret) 222 return false 223 }) 224 if hasNonNestedReturn { 225 return nil, fmt.Errorf("extractFunction: selected block contains non-nested return") 226 } 227 containsReturnStatement := len(retStmts) > 0 228 229 // Now that we have determined the correct range for the selection block, 230 // we must determine the signature of the extracted function. We will then replace 231 // the block with an assignment statement that calls the extracted function with 232 // the appropriate parameters and return values. 233 variables, err := collectFreeVars(info, file, fileScope, pkgScope, rng, path[0]) 234 if err != nil { 235 return nil, err 236 } 237 238 var ( 239 params, returns []ast.Expr // used when calling the extracted function 240 paramTypes, returnTypes []*ast.Field // used in the signature of the extracted function 241 uninitialized []types.Object // vars we will need to initialize before the call 242 ) 243 244 // Avoid duplicates while traversing vars and uninitialzed. 245 seenVars := make(map[types.Object]ast.Expr) 246 seenUninitialized := make(map[types.Object]struct{}) 247 248 // Some variables on the left-hand side of our assignment statement may be free. If our 249 // selection begins in the same scope in which the free variable is defined, we can 250 // redefine it in our assignment statement. See the following example, where 'b' and 251 // 'err' (both free variables) can be redefined in the second funcCall() while maintaing 252 // correctness. 253 // 254 // 255 // Not Redefined: 256 // 257 // a, err := funcCall() 258 // var b int 259 // b, err = funcCall() 260 // 261 // Redefined: 262 // 263 // a, err := funcCall() 264 // b, err := funcCall() 265 // 266 // We track the number of free variables that can be redefined to maintain our preference 267 // of using "x, y, z := fn()" style assignment statements. 268 var canRedefineCount int 269 270 // Each identifier in the selected block must become (1) a parameter to the 271 // extracted function, (2) a return value of the extracted function, or (3) a local 272 // variable in the extracted function. Determine the outcome(s) for each variable 273 // based on whether it is free, altered within the selected block, and used outside 274 // of the selected block. 275 for _, v := range variables { 276 if _, ok := seenVars[v.obj]; ok { 277 continue 278 } 279 typ := analysisinternal.TypeExpr(fset, file, pkg, v.obj.Type()) 280 if typ == nil { 281 return nil, fmt.Errorf("nil AST expression for type: %v", v.obj.Name()) 282 } 283 seenVars[v.obj] = typ 284 identifier := ast.NewIdent(v.obj.Name()) 285 // An identifier must meet three conditions to become a return value of the 286 // extracted function. (1) its value must be defined or reassigned within 287 // the selection (isAssigned), (2) it must be used at least once after the 288 // selection (isUsed), and (3) its first use after the selection 289 // cannot be its own reassignment or redefinition (objOverriden). 290 if v.obj.Parent() == nil { 291 return nil, fmt.Errorf("parent nil") 292 } 293 isUsed, firstUseAfter := objUsed(info, span.NewRange(fset, rng.End, v.obj.Parent().End()), v.obj) 294 if v.assigned && isUsed && !varOverridden(info, firstUseAfter, v.obj, v.free, outer) { 295 returnTypes = append(returnTypes, &ast.Field{Type: typ}) 296 returns = append(returns, identifier) 297 if !v.free { 298 uninitialized = append(uninitialized, v.obj) 299 } else if v.obj.Parent().Pos() == startParent.Pos() { 300 canRedefineCount++ 301 } 302 } 303 // An identifier must meet two conditions to become a parameter of the 304 // extracted function. (1) it must be free (isFree), and (2) its first 305 // use within the selection cannot be its own definition (isDefined). 306 if v.free && !v.defined { 307 params = append(params, identifier) 308 paramTypes = append(paramTypes, &ast.Field{ 309 Names: []*ast.Ident{identifier}, 310 Type: typ, 311 }) 312 } 313 } 314 315 // Find the function literal that encloses the selection. The enclosing function literal 316 // may not be the enclosing function declaration (i.e. 'outer'). For example, in the 317 // following block: 318 // 319 // func main() { 320 // ast.Inspect(node, func(n ast.Node) bool { 321 // v := 1 // this line extracted 322 // return true 323 // }) 324 // } 325 // 326 // 'outer' is main(). However, the extracted selection most directly belongs to 327 // the anonymous function literal, the second argument of ast.Inspect(). We use the 328 // enclosing function literal to determine the proper return types for return statements 329 // within the selection. We still need the enclosing function declaration because this is 330 // the top-level declaration. We inspect the top-level declaration to look for variables 331 // as well as for code replacement. 332 enclosing := outer.Type 333 for _, p := range path { 334 if p == enclosing { 335 break 336 } 337 if fl, ok := p.(*ast.FuncLit); ok { 338 enclosing = fl.Type 339 break 340 } 341 } 342 343 // We put the selection in a constructed file. We can then traverse and edit 344 // the extracted selection without modifying the original AST. 345 startOffset := tok.Offset(rng.Start) 346 endOffset := tok.Offset(rng.End) 347 selection := src[startOffset:endOffset] 348 extractedBlock, err := parseBlockStmt(fset, selection) 349 if err != nil { 350 return nil, err 351 } 352 353 // We need to account for return statements in the selected block, as they will complicate 354 // the logical flow of the extracted function. See the following example, where ** denotes 355 // the range to be extracted. 356 // 357 // Before: 358 // 359 // func _() int { 360 // a := 1 361 // b := 2 362 // **if a == b { 363 // return a 364 // }** 365 // ... 366 // } 367 // 368 // After: 369 // 370 // func _() int { 371 // a := 1 372 // b := 2 373 // cond0, ret0 := x0(a, b) 374 // if cond0 { 375 // return ret0 376 // } 377 // ... 378 // } 379 // 380 // func x0(a int, b int) (bool, int) { 381 // if a == b { 382 // return true, a 383 // } 384 // return false, 0 385 // } 386 // 387 // We handle returns by adding an additional boolean return value to the extracted function. 388 // This bool reports whether the original function would have returned. Because the 389 // extracted selection contains a return statement, we must also add the types in the 390 // return signature of the enclosing function to the return signature of the 391 // extracted function. We then add an extra if statement checking this boolean value 392 // in the original function. If the condition is met, the original function should 393 // return a value, mimicking the functionality of the original return statement(s) 394 // in the selection. 395 396 var retVars []*returnVariable 397 var ifReturn *ast.IfStmt 398 if containsReturnStatement { 399 // The selected block contained return statements, so we have to modify the 400 // signature of the extracted function as described above. Adjust all of 401 // the return statements in the extracted function to reflect this change in 402 // signature. 403 if err := adjustReturnStatements(returnTypes, seenVars, fset, file, 404 pkg, extractedBlock); err != nil { 405 return nil, err 406 } 407 // Collect the additional return values and types needed to accomodate return 408 // statements in the selection. Update the type signature of the extracted 409 // function and construct the if statement that will be inserted in the enclosing 410 // function. 411 retVars, ifReturn, err = generateReturnInfo(enclosing, pkg, path, file, info, fset, rng.Start) 412 if err != nil { 413 return nil, err 414 } 415 } 416 417 // Add a return statement to the end of the new function. This return statement must include 418 // the values for the types of the original extracted function signature and (if a return 419 // statement is present in the selection) enclosing function signature. 420 hasReturnValues := len(returns)+len(retVars) > 0 421 if hasReturnValues { 422 extractedBlock.List = append(extractedBlock.List, &ast.ReturnStmt{ 423 Results: append(returns, getZeroVals(retVars)...), 424 }) 425 } 426 427 // Construct the appropriate call to the extracted function. 428 // We must meet two conditions to use ":=" instead of '='. (1) there must be at least 429 // one variable on the lhs that is uninitailized (non-free) prior to the assignment. 430 // (2) all of the initialized (free) variables on the lhs must be able to be redefined. 431 sym := token.ASSIGN 432 canDefineCount := len(uninitialized) + canRedefineCount 433 canDefine := len(uninitialized)+len(retVars) > 0 && canDefineCount == len(returns) 434 if canDefine { 435 sym = token.DEFINE 436 } 437 funName := generateAvailableIdentifier(rng.Start, file, path, info, "fn", 0) 438 extractedFunCall := generateFuncCall(hasReturnValues, params, 439 append(returns, getNames(retVars)...), funName, sym) 440 441 // Build the extracted function. 442 newFunc := &ast.FuncDecl{ 443 Name: ast.NewIdent(funName), 444 Type: &ast.FuncType{ 445 Params: &ast.FieldList{List: paramTypes}, 446 Results: &ast.FieldList{List: append(returnTypes, getDecls(retVars)...)}, 447 }, 448 Body: extractedBlock, 449 } 450 451 // Create variable declarations for any identifiers that need to be initialized prior to 452 // calling the extracted function. We do not manually initialize variables if every return 453 // value is unitialized. We can use := to initialize the variables in this situation. 454 var declarations []ast.Stmt 455 if canDefineCount != len(returns) { 456 declarations = initializeVars(uninitialized, retVars, seenUninitialized, seenVars) 457 } 458 459 var declBuf, replaceBuf, newFuncBuf, ifBuf bytes.Buffer 460 if err := format.Node(&declBuf, fset, declarations); err != nil { 461 return nil, err 462 } 463 if err := format.Node(&replaceBuf, fset, extractedFunCall); err != nil { 464 return nil, err 465 } 466 if ifReturn != nil { 467 if err := format.Node(&ifBuf, fset, ifReturn); err != nil { 468 return nil, err 469 } 470 } 471 if err := format.Node(&newFuncBuf, fset, newFunc); err != nil { 472 return nil, err 473 } 474 475 // We're going to replace the whole enclosing function, 476 // so preserve the text before and after the selected block. 477 outerStart := tok.Offset(outer.Pos()) 478 outerEnd := tok.Offset(outer.End()) 479 before := src[outerStart:startOffset] 480 after := src[endOffset:outerEnd] 481 newLineIndent := "\n" + calculateIndentation(src, tok, start) 482 483 var fullReplacement strings.Builder 484 fullReplacement.Write(before) 485 if declBuf.Len() > 0 { // add any initializations, if needed 486 initializations := strings.ReplaceAll(declBuf.String(), "\n", newLineIndent) + 487 newLineIndent 488 fullReplacement.WriteString(initializations) 489 } 490 fullReplacement.Write(replaceBuf.Bytes()) // call the extracted function 491 if ifBuf.Len() > 0 { // add the if statement below the function call, if needed 492 ifstatement := newLineIndent + 493 strings.ReplaceAll(ifBuf.String(), "\n", newLineIndent) 494 fullReplacement.WriteString(ifstatement) 495 } 496 fullReplacement.Write(after) 497 fullReplacement.WriteString("\n\n") // add newlines after the enclosing function 498 fullReplacement.Write(newFuncBuf.Bytes()) // insert the extracted function 499 500 return &analysis.SuggestedFix{ 501 TextEdits: []analysis.TextEdit{{ 502 Pos: outer.Pos(), 503 End: outer.End(), 504 NewText: []byte(fullReplacement.String()), 505 }}, 506 }, nil 507} 508 509// adjustRangeForWhitespace adjusts the given range to exclude unnecessary leading or 510// trailing whitespace characters from selection. In the following example, each line 511// of the if statement is indented once. There are also two extra spaces after the 512// closing bracket before the line break. 513// 514// \tif (true) { 515// \t _ = 1 516// \t} \n 517// 518// By default, a valid range begins at 'if' and ends at the first whitespace character 519// after the '}'. But, users are likely to highlight full lines rather than adjusting 520// their cursors for whitespace. To support this use case, we must manually adjust the 521// ranges to match the correct AST node. In this particular example, we would adjust 522// rng.Start forward by one byte, and rng.End backwards by two bytes. 523func adjustRangeForWhitespace(rng span.Range, tok *token.File, content []byte) span.Range { 524 offset := tok.Offset(rng.Start) 525 for offset < len(content) { 526 if !unicode.IsSpace(rune(content[offset])) { 527 break 528 } 529 // Move forwards one byte to find a non-whitespace character. 530 offset += 1 531 } 532 rng.Start = tok.Pos(offset) 533 534 // Move backwards to find a non-whitespace character. 535 offset = tok.Offset(rng.End) 536 for o := offset - 1; 0 <= o && o < len(content); o-- { 537 if !unicode.IsSpace(rune(content[o])) { 538 break 539 } 540 offset = o 541 } 542 rng.End = tok.Pos(offset) 543 return rng 544} 545 546// findParent finds the parent AST node of the given target node, if the target is a 547// descendant of the starting node. 548func findParent(start ast.Node, target ast.Node) ast.Node { 549 var parent ast.Node 550 analysisinternal.WalkASTWithParent(start, func(n, p ast.Node) bool { 551 if n == target { 552 parent = p 553 return false 554 } 555 return true 556 }) 557 return parent 558} 559 560// variable describes the status of a variable within a selection. 561type variable struct { 562 obj types.Object 563 564 // free reports whether the variable is a free variable, meaning it should 565 // be a parameter to the extracted function. 566 free bool 567 568 // assigned reports whether the variable is assigned to in the selection. 569 assigned bool 570 571 // defined reports whether the variable is defined in the selection. 572 defined bool 573} 574 575// collectFreeVars maps each identifier in the given range to whether it is "free." 576// Given a range, a variable in that range is defined as "free" if it is declared 577// outside of the range and neither at the file scope nor package scope. These free 578// variables will be used as arguments in the extracted function. It also returns a 579// list of identifiers that may need to be returned by the extracted function. 580// Some of the code in this function has been adapted from tools/cmd/guru/freevars.go. 581func collectFreeVars(info *types.Info, file *ast.File, fileScope, pkgScope *types.Scope, rng span.Range, node ast.Node) ([]*variable, error) { 582 // id returns non-nil if n denotes an object that is referenced by the span 583 // and defined either within the span or in the lexical environment. The bool 584 // return value acts as an indicator for where it was defined. 585 id := func(n *ast.Ident) (types.Object, bool) { 586 obj := info.Uses[n] 587 if obj == nil { 588 return info.Defs[n], false 589 } 590 if obj.Name() == "_" { 591 return nil, false // exclude objects denoting '_' 592 } 593 if _, ok := obj.(*types.PkgName); ok { 594 return nil, false // imported package 595 } 596 if !(file.Pos() <= obj.Pos() && obj.Pos() <= file.End()) { 597 return nil, false // not defined in this file 598 } 599 scope := obj.Parent() 600 if scope == nil { 601 return nil, false // e.g. interface method, struct field 602 } 603 if scope == fileScope || scope == pkgScope { 604 return nil, false // defined at file or package scope 605 } 606 if rng.Start <= obj.Pos() && obj.Pos() <= rng.End { 607 return obj, false // defined within selection => not free 608 } 609 return obj, true 610 } 611 // sel returns non-nil if n denotes a selection o.x.y that is referenced by the 612 // span and defined either within the span or in the lexical environment. The bool 613 // return value acts as an indicator for where it was defined. 614 var sel func(n *ast.SelectorExpr) (types.Object, bool) 615 sel = func(n *ast.SelectorExpr) (types.Object, bool) { 616 switch x := astutil.Unparen(n.X).(type) { 617 case *ast.SelectorExpr: 618 return sel(x) 619 case *ast.Ident: 620 return id(x) 621 } 622 return nil, false 623 } 624 seen := make(map[types.Object]*variable) 625 firstUseIn := make(map[types.Object]token.Pos) 626 var vars []types.Object 627 ast.Inspect(node, func(n ast.Node) bool { 628 if n == nil { 629 return false 630 } 631 if rng.Start <= n.Pos() && n.End() <= rng.End { 632 var obj types.Object 633 var isFree, prune bool 634 switch n := n.(type) { 635 case *ast.Ident: 636 obj, isFree = id(n) 637 case *ast.SelectorExpr: 638 obj, isFree = sel(n) 639 prune = true 640 } 641 if obj != nil { 642 seen[obj] = &variable{ 643 obj: obj, 644 free: isFree, 645 } 646 vars = append(vars, obj) 647 // Find the first time that the object is used in the selection. 648 first, ok := firstUseIn[obj] 649 if !ok || n.Pos() < first { 650 firstUseIn[obj] = n.Pos() 651 } 652 if prune { 653 return false 654 } 655 } 656 } 657 return n.Pos() <= rng.End 658 }) 659 660 // Find identifiers that are initialized or whose values are altered at some 661 // point in the selected block. For example, in a selected block from lines 2-4, 662 // variables x, y, and z are included in assigned. However, in a selected block 663 // from lines 3-4, only variables y and z are included in assigned. 664 // 665 // 1: var a int 666 // 2: var x int 667 // 3: y := 3 668 // 4: z := x + a 669 // 670 ast.Inspect(node, func(n ast.Node) bool { 671 if n == nil { 672 return false 673 } 674 if n.Pos() < rng.Start || n.End() > rng.End { 675 return n.Pos() <= rng.End 676 } 677 switch n := n.(type) { 678 case *ast.AssignStmt: 679 for _, assignment := range n.Lhs { 680 lhs, ok := assignment.(*ast.Ident) 681 if !ok { 682 continue 683 } 684 obj, _ := id(lhs) 685 if obj == nil { 686 continue 687 } 688 if _, ok := seen[obj]; !ok { 689 continue 690 } 691 seen[obj].assigned = true 692 if n.Tok != token.DEFINE { 693 continue 694 } 695 // Find identifiers that are defined prior to being used 696 // elsewhere in the selection. 697 // TODO: Include identifiers that are assigned prior to being 698 // used elsewhere in the selection. Then, change the assignment 699 // to a definition in the extracted function. 700 if firstUseIn[obj] != lhs.Pos() { 701 continue 702 } 703 // Ensure that the object is not used in its own re-definition. 704 // For example: 705 // var f float64 706 // f, e := math.Frexp(f) 707 for _, expr := range n.Rhs { 708 if referencesObj(info, expr, obj) { 709 continue 710 } 711 if _, ok := seen[obj]; !ok { 712 continue 713 } 714 seen[obj].defined = true 715 break 716 } 717 } 718 return false 719 case *ast.DeclStmt: 720 gen, ok := n.Decl.(*ast.GenDecl) 721 if !ok { 722 return false 723 } 724 for _, spec := range gen.Specs { 725 vSpecs, ok := spec.(*ast.ValueSpec) 726 if !ok { 727 continue 728 } 729 for _, vSpec := range vSpecs.Names { 730 obj, _ := id(vSpec) 731 if obj == nil { 732 continue 733 } 734 if _, ok := seen[obj]; !ok { 735 continue 736 } 737 seen[obj].assigned = true 738 } 739 } 740 return false 741 case *ast.IncDecStmt: 742 if ident, ok := n.X.(*ast.Ident); !ok { 743 return false 744 } else if obj, _ := id(ident); obj == nil { 745 return false 746 } else { 747 if _, ok := seen[obj]; !ok { 748 return false 749 } 750 seen[obj].assigned = true 751 } 752 } 753 return true 754 }) 755 var variables []*variable 756 for _, obj := range vars { 757 v, ok := seen[obj] 758 if !ok { 759 return nil, fmt.Errorf("no seen types.Object for %v", obj) 760 } 761 variables = append(variables, v) 762 } 763 return variables, nil 764} 765 766// referencesObj checks whether the given object appears in the given expression. 767func referencesObj(info *types.Info, expr ast.Expr, obj types.Object) bool { 768 var hasObj bool 769 ast.Inspect(expr, func(n ast.Node) bool { 770 if n == nil { 771 return false 772 } 773 ident, ok := n.(*ast.Ident) 774 if !ok { 775 return true 776 } 777 objUse := info.Uses[ident] 778 if obj == objUse { 779 hasObj = true 780 return false 781 } 782 return false 783 }) 784 return hasObj 785} 786 787type fnExtractParams struct { 788 tok *token.File 789 path []ast.Node 790 rng span.Range 791 outer *ast.FuncDecl 792 start ast.Node 793} 794 795// canExtractFunction reports whether the code in the given range can be 796// extracted to a function. 797func canExtractFunction(fset *token.FileSet, rng span.Range, src []byte, file *ast.File, _ *types.Info) (*fnExtractParams, bool, error) { 798 if rng.Start == rng.End { 799 return nil, false, fmt.Errorf("start and end are equal") 800 } 801 tok := fset.File(file.Pos()) 802 if tok == nil { 803 return nil, false, fmt.Errorf("no file for pos %v", fset.Position(file.Pos())) 804 } 805 rng = adjustRangeForWhitespace(rng, tok, src) 806 path, _ := astutil.PathEnclosingInterval(file, rng.Start, rng.End) 807 if len(path) == 0 { 808 return nil, false, fmt.Errorf("no path enclosing interval") 809 } 810 // Node that encloses the selection must be a statement. 811 // TODO: Support function extraction for an expression. 812 _, ok := path[0].(ast.Stmt) 813 if !ok { 814 return nil, false, fmt.Errorf("node is not a statement") 815 } 816 817 // Find the function declaration that encloses the selection. 818 var outer *ast.FuncDecl 819 for _, p := range path { 820 if p, ok := p.(*ast.FuncDecl); ok { 821 outer = p 822 break 823 } 824 } 825 if outer == nil { 826 return nil, false, fmt.Errorf("no enclosing function") 827 } 828 829 // Find the nodes at the start and end of the selection. 830 var start, end ast.Node 831 ast.Inspect(outer, func(n ast.Node) bool { 832 if n == nil { 833 return false 834 } 835 // Do not override 'start' with a node that begins at the same location 836 // but is nested further from 'outer'. 837 if start == nil && n.Pos() == rng.Start && n.End() <= rng.End { 838 start = n 839 } 840 if end == nil && n.End() == rng.End && n.Pos() >= rng.Start { 841 end = n 842 } 843 return n.Pos() <= rng.End 844 }) 845 if start == nil || end == nil { 846 return nil, false, fmt.Errorf("range does not map to AST nodes") 847 } 848 return &fnExtractParams{ 849 tok: tok, 850 path: path, 851 rng: rng, 852 outer: outer, 853 start: start, 854 }, true, nil 855} 856 857// objUsed checks if the object is used within the range. It returns the first occurence of 858// the object in the range, if it exists. 859func objUsed(info *types.Info, rng span.Range, obj types.Object) (bool, *ast.Ident) { 860 var firstUse *ast.Ident 861 for id, objUse := range info.Uses { 862 if obj != objUse { 863 continue 864 } 865 if id.Pos() < rng.Start || id.End() > rng.End { 866 continue 867 } 868 if firstUse == nil || id.Pos() < firstUse.Pos() { 869 firstUse = id 870 } 871 } 872 return firstUse != nil, firstUse 873} 874 875// varOverridden traverses the given AST node until we find the given identifier. Then, we 876// examine the occurrence of the given identifier and check for (1) whether the identifier 877// is being redefined. If the identifier is free, we also check for (2) whether the identifier 878// is being reassigned. We will not include an identifier in the return statement of the 879// extracted function if it meets one of the above conditions. 880func varOverridden(info *types.Info, firstUse *ast.Ident, obj types.Object, isFree bool, node ast.Node) bool { 881 var isOverriden bool 882 ast.Inspect(node, func(n ast.Node) bool { 883 if n == nil { 884 return false 885 } 886 assignment, ok := n.(*ast.AssignStmt) 887 if !ok { 888 return true 889 } 890 // A free variable is initialized prior to the selection. We can always reassign 891 // this variable after the selection because it has already been defined. 892 // Conversely, a non-free variable is initialized within the selection. Thus, we 893 // cannot reassign this variable after the selection unless it is initialized and 894 // returned by the extracted function. 895 if !isFree && assignment.Tok == token.ASSIGN { 896 return false 897 } 898 for _, assigned := range assignment.Lhs { 899 ident, ok := assigned.(*ast.Ident) 900 // Check if we found the first use of the identifier. 901 if !ok || ident != firstUse { 902 continue 903 } 904 objUse := info.Uses[ident] 905 if objUse == nil || objUse != obj { 906 continue 907 } 908 // Ensure that the object is not used in its own definition. 909 // For example: 910 // var f float64 911 // f, e := math.Frexp(f) 912 for _, expr := range assignment.Rhs { 913 if referencesObj(info, expr, obj) { 914 return false 915 } 916 } 917 isOverriden = true 918 return false 919 } 920 return false 921 }) 922 return isOverriden 923} 924 925// parseExtraction generates an AST file from the given text. We then return the portion of the 926// file that represents the text. 927func parseBlockStmt(fset *token.FileSet, src []byte) (*ast.BlockStmt, error) { 928 text := "package main\nfunc _() { " + string(src) + " }" 929 extract, err := parser.ParseFile(fset, "", text, 0) 930 if err != nil { 931 return nil, err 932 } 933 if len(extract.Decls) == 0 { 934 return nil, fmt.Errorf("parsed file does not contain any declarations") 935 } 936 decl, ok := extract.Decls[0].(*ast.FuncDecl) 937 if !ok { 938 return nil, fmt.Errorf("parsed file does not contain expected function declaration") 939 } 940 if decl.Body == nil { 941 return nil, fmt.Errorf("extracted function has no body") 942 } 943 return decl.Body, nil 944} 945 946// generateReturnInfo generates the information we need to adjust the return statements and 947// signature of the extracted function. We prepare names, signatures, and "zero values" that 948// represent the new variables. We also use this information to construct the if statement that 949// is inserted below the call to the extracted function. 950func generateReturnInfo(enclosing *ast.FuncType, pkg *types.Package, path []ast.Node, file *ast.File, info *types.Info, fset *token.FileSet, pos token.Pos) ([]*returnVariable, *ast.IfStmt, error) { 951 // Generate information for the added bool value. 952 cond := &ast.Ident{Name: generateAvailableIdentifier(pos, file, path, info, "cond", 0)} 953 retVars := []*returnVariable{ 954 { 955 name: cond, 956 decl: &ast.Field{Type: ast.NewIdent("bool")}, 957 zeroVal: ast.NewIdent("false"), 958 }, 959 } 960 // Generate information for the values in the return signature of the enclosing function. 961 if enclosing.Results != nil { 962 for i, field := range enclosing.Results.List { 963 typ := info.TypeOf(field.Type) 964 if typ == nil { 965 return nil, nil, fmt.Errorf( 966 "failed type conversion, AST expression: %T", field.Type) 967 } 968 expr := analysisinternal.TypeExpr(fset, file, pkg, typ) 969 if expr == nil { 970 return nil, nil, fmt.Errorf("nil AST expression") 971 } 972 retVars = append(retVars, &returnVariable{ 973 name: ast.NewIdent(generateAvailableIdentifier(pos, file, 974 path, info, "ret", i)), 975 decl: &ast.Field{Type: expr}, 976 zeroVal: analysisinternal.ZeroValue( 977 fset, file, pkg, typ), 978 }) 979 } 980 } 981 // Create the return statement for the enclosing function. We must exclude the variable 982 // for the condition of the if statement (cond) from the return statement. 983 ifReturn := &ast.IfStmt{ 984 Cond: cond, 985 Body: &ast.BlockStmt{ 986 List: []ast.Stmt{&ast.ReturnStmt{Results: getNames(retVars)[1:]}}, 987 }, 988 } 989 return retVars, ifReturn, nil 990} 991 992// adjustReturnStatements adds "zero values" of the given types to each return statement 993// in the given AST node. 994func adjustReturnStatements(returnTypes []*ast.Field, seenVars map[types.Object]ast.Expr, fset *token.FileSet, file *ast.File, pkg *types.Package, extractedBlock *ast.BlockStmt) error { 995 var zeroVals []ast.Expr 996 // Create "zero values" for each type. 997 for _, returnType := range returnTypes { 998 var val ast.Expr 999 for obj, typ := range seenVars { 1000 if typ != returnType.Type { 1001 continue 1002 } 1003 val = analysisinternal.ZeroValue(fset, file, pkg, obj.Type()) 1004 break 1005 } 1006 if val == nil { 1007 return fmt.Errorf( 1008 "could not find matching AST expression for %T", returnType.Type) 1009 } 1010 zeroVals = append(zeroVals, val) 1011 } 1012 // Add "zero values" to each return statement. 1013 // The bool reports whether the enclosing function should return after calling the 1014 // extracted function. We set the bool to 'true' because, if these return statements 1015 // execute, the extracted function terminates early, and the enclosing function must 1016 // return as well. 1017 zeroVals = append(zeroVals, ast.NewIdent("true")) 1018 ast.Inspect(extractedBlock, func(n ast.Node) bool { 1019 if n == nil { 1020 return false 1021 } 1022 if n, ok := n.(*ast.ReturnStmt); ok { 1023 n.Results = append(zeroVals, n.Results...) 1024 return false 1025 } 1026 return true 1027 }) 1028 return nil 1029} 1030 1031// generateFuncCall constructs a call expression for the extracted function, described by the 1032// given parameters and return variables. 1033func generateFuncCall(hasReturnVals bool, params, returns []ast.Expr, name string, token token.Token) ast.Node { 1034 var replace ast.Node 1035 if hasReturnVals { 1036 callExpr := &ast.CallExpr{ 1037 Fun: ast.NewIdent(name), 1038 Args: params, 1039 } 1040 replace = &ast.AssignStmt{ 1041 Lhs: returns, 1042 Tok: token, 1043 Rhs: []ast.Expr{callExpr}, 1044 } 1045 } else { 1046 replace = &ast.CallExpr{ 1047 Fun: ast.NewIdent(name), 1048 Args: params, 1049 } 1050 } 1051 return replace 1052} 1053 1054// initializeVars creates variable declarations, if needed. 1055// Our preference is to replace the selected block with an "x, y, z := fn()" style 1056// assignment statement. We can use this style when all of the variables in the 1057// extracted function's return statement are either not defined prior to the extracted block 1058// or can be safely redefined. However, for example, if z is already defined 1059// in a different scope, we replace the selected block with: 1060// 1061// var x int 1062// var y string 1063// x, y, z = fn() 1064func initializeVars(uninitialized []types.Object, retVars []*returnVariable, seenUninitialized map[types.Object]struct{}, seenVars map[types.Object]ast.Expr) []ast.Stmt { 1065 var declarations []ast.Stmt 1066 for _, obj := range uninitialized { 1067 if _, ok := seenUninitialized[obj]; ok { 1068 continue 1069 } 1070 seenUninitialized[obj] = struct{}{} 1071 valSpec := &ast.ValueSpec{ 1072 Names: []*ast.Ident{ast.NewIdent(obj.Name())}, 1073 Type: seenVars[obj], 1074 } 1075 genDecl := &ast.GenDecl{ 1076 Tok: token.VAR, 1077 Specs: []ast.Spec{valSpec}, 1078 } 1079 declarations = append(declarations, &ast.DeclStmt{Decl: genDecl}) 1080 } 1081 // Each variable added from a return statement in the selection 1082 // must be initialized. 1083 for i, retVar := range retVars { 1084 n := retVar.name.(*ast.Ident) 1085 valSpec := &ast.ValueSpec{ 1086 Names: []*ast.Ident{n}, 1087 Type: retVars[i].decl.Type, 1088 } 1089 genDecl := &ast.GenDecl{ 1090 Tok: token.VAR, 1091 Specs: []ast.Spec{valSpec}, 1092 } 1093 declarations = append(declarations, &ast.DeclStmt{Decl: genDecl}) 1094 } 1095 return declarations 1096} 1097 1098// getNames returns the names from the given list of returnVariable. 1099func getNames(retVars []*returnVariable) []ast.Expr { 1100 var names []ast.Expr 1101 for _, retVar := range retVars { 1102 names = append(names, retVar.name) 1103 } 1104 return names 1105} 1106 1107// getZeroVals returns the "zero values" from the given list of returnVariable. 1108func getZeroVals(retVars []*returnVariable) []ast.Expr { 1109 var zvs []ast.Expr 1110 for _, retVar := range retVars { 1111 zvs = append(zvs, retVar.zeroVal) 1112 } 1113 return zvs 1114} 1115 1116// getDecls returns the declarations from the given list of returnVariable. 1117func getDecls(retVars []*returnVariable) []*ast.Field { 1118 var decls []*ast.Field 1119 for _, retVar := range retVars { 1120 decls = append(decls, retVar.decl) 1121 } 1122 return decls 1123} 1124