1// Copyright 2020 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package source
6
7import (
8	"bytes"
9	"fmt"
10	"go/ast"
11	"go/format"
12	"go/parser"
13	"go/token"
14	"go/types"
15	"strings"
16	"unicode"
17
18	"golang.org/x/tools/go/analysis"
19	"golang.org/x/tools/go/ast/astutil"
20	"golang.org/x/tools/internal/analysisinternal"
21	"golang.org/x/tools/internal/span"
22)
23
24func extractVariable(fset *token.FileSet, rng span.Range, src []byte, file *ast.File, _ *types.Package, info *types.Info) (*analysis.SuggestedFix, error) {
25	expr, path, ok, err := canExtractVariable(rng, file)
26	if !ok {
27		return nil, fmt.Errorf("extractVariable: cannot extract %s: %v", fset.Position(rng.Start), err)
28	}
29
30	// Create new AST node for extracted code.
31	var lhsNames []string
32	switch expr := expr.(type) {
33	// TODO: stricter rules for selectorExpr.
34	case *ast.BasicLit, *ast.CompositeLit, *ast.IndexExpr, *ast.SliceExpr,
35		*ast.UnaryExpr, *ast.BinaryExpr, *ast.SelectorExpr:
36		lhsNames = append(lhsNames, generateAvailableIdentifier(expr.Pos(), file, path, info, "x", 0))
37	case *ast.CallExpr:
38		tup, ok := info.TypeOf(expr).(*types.Tuple)
39		if !ok {
40			// If the call expression only has one return value, we can treat it the
41			// same as our standard extract variable case.
42			lhsNames = append(lhsNames,
43				generateAvailableIdentifier(expr.Pos(), file, path, info, "x", 0))
44			break
45		}
46		for i := 0; i < tup.Len(); i++ {
47			// Generate a unique variable for each return value.
48			lhsNames = append(lhsNames,
49				generateAvailableIdentifier(expr.Pos(), file, path, info, "x", i))
50		}
51	default:
52		return nil, fmt.Errorf("cannot extract %T", expr)
53	}
54
55	insertBeforeStmt := analysisinternal.StmtToInsertVarBefore(path)
56	if insertBeforeStmt == nil {
57		return nil, fmt.Errorf("cannot find location to insert extraction")
58	}
59	tok := fset.File(expr.Pos())
60	if tok == nil {
61		return nil, fmt.Errorf("no file for pos %v", fset.Position(file.Pos()))
62	}
63	newLineIndent := "\n" + calculateIndentation(src, tok, insertBeforeStmt)
64
65	lhs := strings.Join(lhsNames, ", ")
66	assignStmt := &ast.AssignStmt{
67		Lhs: []ast.Expr{ast.NewIdent(lhs)},
68		Tok: token.DEFINE,
69		Rhs: []ast.Expr{expr},
70	}
71	var buf bytes.Buffer
72	if err := format.Node(&buf, fset, assignStmt); err != nil {
73		return nil, err
74	}
75	assignment := strings.ReplaceAll(buf.String(), "\n", newLineIndent) + newLineIndent
76
77	return &analysis.SuggestedFix{
78		TextEdits: []analysis.TextEdit{
79			{
80				Pos:     rng.Start,
81				End:     rng.End,
82				NewText: []byte(lhs),
83			},
84			{
85				Pos:     insertBeforeStmt.Pos(),
86				End:     insertBeforeStmt.Pos(),
87				NewText: []byte(assignment),
88			},
89		},
90	}, nil
91}
92
93// canExtractVariable reports whether the code in the given range can be
94// extracted to a variable.
95func canExtractVariable(rng span.Range, file *ast.File) (ast.Expr, []ast.Node, bool, error) {
96	if rng.Start == rng.End {
97		return nil, nil, false, fmt.Errorf("start and end are equal")
98	}
99	path, _ := astutil.PathEnclosingInterval(file, rng.Start, rng.End)
100	if len(path) == 0 {
101		return nil, nil, false, fmt.Errorf("no path enclosing interval")
102	}
103	for _, n := range path {
104		if _, ok := n.(*ast.ImportSpec); ok {
105			return nil, nil, false, fmt.Errorf("cannot extract variable in an import block")
106		}
107	}
108	node := path[0]
109	if rng.Start != node.Pos() || rng.End != node.End() {
110		return nil, nil, false, fmt.Errorf("range does not map to an AST node")
111	}
112	expr, ok := node.(ast.Expr)
113	if !ok {
114		return nil, nil, false, fmt.Errorf("node is not an expression")
115	}
116	switch expr.(type) {
117	case *ast.BasicLit, *ast.CompositeLit, *ast.IndexExpr, *ast.CallExpr,
118		*ast.SliceExpr, *ast.UnaryExpr, *ast.BinaryExpr, *ast.SelectorExpr:
119		return expr, path, true, nil
120	}
121	return nil, nil, false, fmt.Errorf("cannot extract an %T to a variable", expr)
122}
123
// calculateIndentation returns the text that precedes insertBeforeStmt on its
// own line, i.e. everything between the start of that line and the start of
// the node. Newly inserted lines are prefixed with this text so that they line
// up with the statement they are inserted next to.
func calculateIndentation(content []byte, tok *token.File, insertBeforeStmt ast.Node) string {
	stmtPos := insertBeforeStmt.Pos()
	lineStart := tok.LineStart(tok.Line(stmtPos))
	return string(content[tok.Offset(lineStart):tok.Offset(stmtPos)])
}
134
135// generateAvailableIdentifier adjusts the new function name until there are no collisons in scope.
136// Possible collisions include other function and variable names.
137func generateAvailableIdentifier(pos token.Pos, file *ast.File, path []ast.Node, info *types.Info, prefix string, idx int) string {
138	scopes := CollectScopes(info, path, pos)
139	name := prefix + fmt.Sprintf("%d", idx)
140	for file.Scope.Lookup(name) != nil || !isValidName(name, scopes) {
141		idx++
142		name = fmt.Sprintf("%v%d", prefix, idx)
143	}
144	return name
145}
146
// isValidName reports whether name is unused in every one of the given
// scopes. Nil scopes are skipped. Each scope is consulted on its own; parent
// scopes are only considered if they are themselves present in the slice.
func isValidName(name string, scopes []*types.Scope) bool {
	for i := range scopes {
		if s := scopes[i]; s != nil && s.Lookup(name) != nil {
			return false
		}
	}
	return true
}
159
// returnVariable keeps track of the information we need to properly introduce a new variable
// that we will return in the extracted function. One returnVariable describes a single
// additional result of the extracted function.
type returnVariable struct {
	// name is the identifier that is used on the left-hand side of the call to
	// the extracted function.
	name ast.Expr
	// decl is the declaration of the variable. It is used in the type signature of the
	// extracted function and for variable declarations.
	decl *ast.Field
	// zeroVal is the "zero value" of the type of the variable. It is used in a return
	// statement in the extracted function.
	zeroVal ast.Expr
}
173
174// extractFunction refactors the selected block of code into a new function.
175// It also replaces the selected block of code with a call to the extracted
176// function. First, we manually adjust the selection range. We remove trailing
177// and leading whitespace characters to ensure the range is precisely bounded
178// by AST nodes. Next, we determine the variables that will be the paramters
179// and return values of the extracted function. Lastly, we construct the call
180// of the function and insert this call as well as the extracted function into
181// their proper locations.
182func extractFunction(fset *token.FileSet, rng span.Range, src []byte, file *ast.File, pkg *types.Package, info *types.Info) (*analysis.SuggestedFix, error) {
183	p, ok, err := canExtractFunction(fset, rng, src, file, info)
184	if !ok {
185		return nil, fmt.Errorf("extractFunction: cannot extract %s: %v",
186			fset.Position(rng.Start), err)
187	}
188	tok, path, rng, outer, start := p.tok, p.path, p.rng, p.outer, p.start
189	fileScope := info.Scopes[file]
190	if fileScope == nil {
191		return nil, fmt.Errorf("extractFunction: file scope is empty")
192	}
193	pkgScope := fileScope.Parent()
194	if pkgScope == nil {
195		return nil, fmt.Errorf("extractFunction: package scope is empty")
196	}
197
198	// TODO: Support non-nested return statements.
199	// A return statement is non-nested if its parent node is equal to the parent node
200	// of the first node in the selection. These cases must be handled seperately because
201	// non-nested return statements are guaranteed to execute. Our control flow does not
202	// properly consider these situations yet.
203	var retStmts []*ast.ReturnStmt
204	var hasNonNestedReturn bool
205	startParent := findParent(outer, start)
206	ast.Inspect(outer, func(n ast.Node) bool {
207		if n == nil {
208			return false
209		}
210		if n.Pos() < rng.Start || n.End() > rng.End {
211			return n.Pos() <= rng.End
212		}
213		ret, ok := n.(*ast.ReturnStmt)
214		if !ok {
215			return true
216		}
217		if findParent(outer, n) == startParent {
218			hasNonNestedReturn = true
219			return false
220		}
221		retStmts = append(retStmts, ret)
222		return false
223	})
224	if hasNonNestedReturn {
225		return nil, fmt.Errorf("extractFunction: selected block contains non-nested return")
226	}
227	containsReturnStatement := len(retStmts) > 0
228
229	// Now that we have determined the correct range for the selection block,
230	// we must determine the signature of the extracted function. We will then replace
231	// the block with an assignment statement that calls the extracted function with
232	// the appropriate parameters and return values.
233	variables, err := collectFreeVars(info, file, fileScope, pkgScope, rng, path[0])
234	if err != nil {
235		return nil, err
236	}
237
238	var (
239		params, returns         []ast.Expr     // used when calling the extracted function
240		paramTypes, returnTypes []*ast.Field   // used in the signature of the extracted function
241		uninitialized           []types.Object // vars we will need to initialize before the call
242	)
243
244	// Avoid duplicates while traversing vars and uninitialzed.
245	seenVars := make(map[types.Object]ast.Expr)
246	seenUninitialized := make(map[types.Object]struct{})
247
248	// Some variables on the left-hand side of our assignment statement may be free. If our
249	// selection begins in the same scope in which the free variable is defined, we can
250	// redefine it in our assignment statement. See the following example, where 'b' and
251	// 'err' (both free variables) can be redefined in the second funcCall() while maintaing
252	// correctness.
253	//
254	//
255	// Not Redefined:
256	//
257	// a, err := funcCall()
258	// var b int
259	// b, err = funcCall()
260	//
261	// Redefined:
262	//
263	// a, err := funcCall()
264	// b, err := funcCall()
265	//
266	// We track the number of free variables that can be redefined to maintain our preference
267	// of using "x, y, z := fn()" style assignment statements.
268	var canRedefineCount int
269
270	// Each identifier in the selected block must become (1) a parameter to the
271	// extracted function, (2) a return value of the extracted function, or (3) a local
272	// variable in the extracted function. Determine the outcome(s) for each variable
273	// based on whether it is free, altered within the selected block, and used outside
274	// of the selected block.
275	for _, v := range variables {
276		if _, ok := seenVars[v.obj]; ok {
277			continue
278		}
279		typ := analysisinternal.TypeExpr(fset, file, pkg, v.obj.Type())
280		if typ == nil {
281			return nil, fmt.Errorf("nil AST expression for type: %v", v.obj.Name())
282		}
283		seenVars[v.obj] = typ
284		identifier := ast.NewIdent(v.obj.Name())
285		// An identifier must meet three conditions to become a return value of the
286		// extracted function. (1) its value must be defined or reassigned within
287		// the selection (isAssigned), (2) it must be used at least once after the
288		// selection (isUsed), and (3) its first use after the selection
289		// cannot be its own reassignment or redefinition (objOverriden).
290		if v.obj.Parent() == nil {
291			return nil, fmt.Errorf("parent nil")
292		}
293		isUsed, firstUseAfter := objUsed(info, span.NewRange(fset, rng.End, v.obj.Parent().End()), v.obj)
294		if v.assigned && isUsed && !varOverridden(info, firstUseAfter, v.obj, v.free, outer) {
295			returnTypes = append(returnTypes, &ast.Field{Type: typ})
296			returns = append(returns, identifier)
297			if !v.free {
298				uninitialized = append(uninitialized, v.obj)
299			} else if v.obj.Parent().Pos() == startParent.Pos() {
300				canRedefineCount++
301			}
302		}
303		// An identifier must meet two conditions to become a parameter of the
304		// extracted function. (1) it must be free (isFree), and (2) its first
305		// use within the selection cannot be its own definition (isDefined).
306		if v.free && !v.defined {
307			params = append(params, identifier)
308			paramTypes = append(paramTypes, &ast.Field{
309				Names: []*ast.Ident{identifier},
310				Type:  typ,
311			})
312		}
313	}
314
315	// Find the function literal that encloses the selection. The enclosing function literal
316	// may not be the enclosing function declaration (i.e. 'outer'). For example, in the
317	// following block:
318	//
319	// func main() {
320	//     ast.Inspect(node, func(n ast.Node) bool {
321	//         v := 1 // this line extracted
322	//         return true
323	//     })
324	// }
325	//
326	// 'outer' is main(). However, the extracted selection most directly belongs to
327	// the anonymous function literal, the second argument of ast.Inspect(). We use the
328	// enclosing function literal to determine the proper return types for return statements
329	// within the selection. We still need the enclosing function declaration because this is
330	// the top-level declaration. We inspect the top-level declaration to look for variables
331	// as well as for code replacement.
332	enclosing := outer.Type
333	for _, p := range path {
334		if p == enclosing {
335			break
336		}
337		if fl, ok := p.(*ast.FuncLit); ok {
338			enclosing = fl.Type
339			break
340		}
341	}
342
343	// We put the selection in a constructed file. We can then traverse and edit
344	// the extracted selection without modifying the original AST.
345	startOffset := tok.Offset(rng.Start)
346	endOffset := tok.Offset(rng.End)
347	selection := src[startOffset:endOffset]
348	extractedBlock, err := parseBlockStmt(fset, selection)
349	if err != nil {
350		return nil, err
351	}
352
353	// We need to account for return statements in the selected block, as they will complicate
354	// the logical flow of the extracted function. See the following example, where ** denotes
355	// the range to be extracted.
356	//
357	// Before:
358	//
359	// func _() int {
360	//     a := 1
361	//     b := 2
362	//     **if a == b {
363	//         return a
364	//     }**
365	//     ...
366	// }
367	//
368	// After:
369	//
370	// func _() int {
371	//     a := 1
372	//     b := 2
373	//     cond0, ret0 := x0(a, b)
374	//     if cond0 {
375	//         return ret0
376	//     }
377	//     ...
378	// }
379	//
380	// func x0(a int, b int) (bool, int) {
381	//     if a == b {
382	//         return true, a
383	//     }
384	//     return false, 0
385	// }
386	//
387	// We handle returns by adding an additional boolean return value to the extracted function.
388	// This bool reports whether the original function would have returned. Because the
389	// extracted selection contains a return statement, we must also add the types in the
390	// return signature of the enclosing function to the return signature of the
391	// extracted function. We then add an extra if statement checking this boolean value
392	// in the original function. If the condition is met, the original function should
393	// return a value, mimicking the functionality of the original return statement(s)
394	// in the selection.
395
396	var retVars []*returnVariable
397	var ifReturn *ast.IfStmt
398	if containsReturnStatement {
399		// The selected block contained return statements, so we have to modify the
400		// signature of the extracted function as described above. Adjust all of
401		// the return statements in the extracted function to reflect this change in
402		// signature.
403		if err := adjustReturnStatements(returnTypes, seenVars, fset, file,
404			pkg, extractedBlock); err != nil {
405			return nil, err
406		}
407		// Collect the additional return values and types needed to accomodate return
408		// statements in the selection. Update the type signature of the extracted
409		// function and construct the if statement that will be inserted in the enclosing
410		// function.
411		retVars, ifReturn, err = generateReturnInfo(enclosing, pkg, path, file, info, fset, rng.Start)
412		if err != nil {
413			return nil, err
414		}
415	}
416
417	// Add a return statement to the end of the new function. This return statement must include
418	// the values for the types of the original extracted function signature and (if a return
419	// statement is present in the selection) enclosing function signature.
420	hasReturnValues := len(returns)+len(retVars) > 0
421	if hasReturnValues {
422		extractedBlock.List = append(extractedBlock.List, &ast.ReturnStmt{
423			Results: append(returns, getZeroVals(retVars)...),
424		})
425	}
426
427	// Construct the appropriate call to the extracted function.
428	// We must meet two conditions to use ":=" instead of '='. (1) there must be at least
429	// one variable on the lhs that is uninitailized (non-free) prior to the assignment.
430	// (2) all of the initialized (free) variables on the lhs must be able to be redefined.
431	sym := token.ASSIGN
432	canDefineCount := len(uninitialized) + canRedefineCount
433	canDefine := len(uninitialized)+len(retVars) > 0 && canDefineCount == len(returns)
434	if canDefine {
435		sym = token.DEFINE
436	}
437	funName := generateAvailableIdentifier(rng.Start, file, path, info, "fn", 0)
438	extractedFunCall := generateFuncCall(hasReturnValues, params,
439		append(returns, getNames(retVars)...), funName, sym)
440
441	// Build the extracted function.
442	newFunc := &ast.FuncDecl{
443		Name: ast.NewIdent(funName),
444		Type: &ast.FuncType{
445			Params:  &ast.FieldList{List: paramTypes},
446			Results: &ast.FieldList{List: append(returnTypes, getDecls(retVars)...)},
447		},
448		Body: extractedBlock,
449	}
450
451	// Create variable declarations for any identifiers that need to be initialized prior to
452	// calling the extracted function. We do not manually initialize variables if every return
453	// value is unitialized. We can use := to initialize the variables in this situation.
454	var declarations []ast.Stmt
455	if canDefineCount != len(returns) {
456		declarations = initializeVars(uninitialized, retVars, seenUninitialized, seenVars)
457	}
458
459	var declBuf, replaceBuf, newFuncBuf, ifBuf bytes.Buffer
460	if err := format.Node(&declBuf, fset, declarations); err != nil {
461		return nil, err
462	}
463	if err := format.Node(&replaceBuf, fset, extractedFunCall); err != nil {
464		return nil, err
465	}
466	if ifReturn != nil {
467		if err := format.Node(&ifBuf, fset, ifReturn); err != nil {
468			return nil, err
469		}
470	}
471	if err := format.Node(&newFuncBuf, fset, newFunc); err != nil {
472		return nil, err
473	}
474
475	// We're going to replace the whole enclosing function,
476	// so preserve the text before and after the selected block.
477	outerStart := tok.Offset(outer.Pos())
478	outerEnd := tok.Offset(outer.End())
479	before := src[outerStart:startOffset]
480	after := src[endOffset:outerEnd]
481	newLineIndent := "\n" + calculateIndentation(src, tok, start)
482
483	var fullReplacement strings.Builder
484	fullReplacement.Write(before)
485	if declBuf.Len() > 0 { // add any initializations, if needed
486		initializations := strings.ReplaceAll(declBuf.String(), "\n", newLineIndent) +
487			newLineIndent
488		fullReplacement.WriteString(initializations)
489	}
490	fullReplacement.Write(replaceBuf.Bytes()) // call the extracted function
491	if ifBuf.Len() > 0 {                      // add the if statement below the function call, if needed
492		ifstatement := newLineIndent +
493			strings.ReplaceAll(ifBuf.String(), "\n", newLineIndent)
494		fullReplacement.WriteString(ifstatement)
495	}
496	fullReplacement.Write(after)
497	fullReplacement.WriteString("\n\n")       // add newlines after the enclosing function
498	fullReplacement.Write(newFuncBuf.Bytes()) // insert the extracted function
499
500	return &analysis.SuggestedFix{
501		TextEdits: []analysis.TextEdit{{
502			Pos:     outer.Pos(),
503			End:     outer.End(),
504			NewText: []byte(fullReplacement.String()),
505		}},
506	}, nil
507}
508
509// adjustRangeForWhitespace adjusts the given range to exclude unnecessary leading or
510// trailing whitespace characters from selection. In the following example, each line
511// of the if statement is indented once. There are also two extra spaces after the
512// closing bracket before the line break.
513//
514// \tif (true) {
515// \t    _ = 1
516// \t}  \n
517//
518// By default, a valid range begins at 'if' and ends at the first whitespace character
519// after the '}'. But, users are likely to highlight full lines rather than adjusting
520// their cursors for whitespace. To support this use case, we must manually adjust the
521// ranges to match the correct AST node. In this particular example, we would adjust
522// rng.Start forward by one byte, and rng.End backwards by two bytes.
523func adjustRangeForWhitespace(rng span.Range, tok *token.File, content []byte) span.Range {
524	offset := tok.Offset(rng.Start)
525	for offset < len(content) {
526		if !unicode.IsSpace(rune(content[offset])) {
527			break
528		}
529		// Move forwards one byte to find a non-whitespace character.
530		offset += 1
531	}
532	rng.Start = tok.Pos(offset)
533
534	// Move backwards to find a non-whitespace character.
535	offset = tok.Offset(rng.End)
536	for o := offset - 1; 0 <= o && o < len(content); o-- {
537		if !unicode.IsSpace(rune(content[o])) {
538			break
539		}
540		offset = o
541	}
542	rng.End = tok.Pos(offset)
543	return rng
544}
545
546// findParent finds the parent AST node of the given target node, if the target is a
547// descendant of the starting node.
548func findParent(start ast.Node, target ast.Node) ast.Node {
549	var parent ast.Node
550	analysisinternal.WalkASTWithParent(start, func(n, p ast.Node) bool {
551		if n == target {
552			parent = p
553			return false
554		}
555		return true
556	})
557	return parent
558}
559
// variable describes the status of a variable within a selection. Values of
// this type are produced by collectFreeVars, one per referenced object.
type variable struct {
	// obj is the type-checker object that the variable denotes.
	obj types.Object

	// free reports whether the variable is a free variable, meaning it should
	// be a parameter to the extracted function.
	free bool

	// assigned reports whether the variable is assigned to in the selection.
	assigned bool

	// defined reports whether the variable is defined in the selection.
	defined bool
}
574
// collectFreeVars maps each identifier in the given range to whether it is "free."
// Given a range, a variable in that range is defined as "free" if it is declared
// outside of the range and neither at the file scope nor package scope. These free
// variables will be used as arguments in the extracted function. It also returns a
// list of identifiers that may need to be returned by the extracted function.
// Some of the code in this function has been adapted from tools/cmd/guru/freevars.go.
//
// The result contains one entry per occurrence of an object inside the range, in the
// order in which the occurrences are visited; occurrences of the same object share a
// single *variable value, so the same pointer may appear more than once.
func collectFreeVars(info *types.Info, file *ast.File, fileScope, pkgScope *types.Scope, rng span.Range, node ast.Node) ([]*variable, error) {
	// id returns non-nil if n denotes an object that is referenced by the span
	// and defined either within the span or in the lexical environment. The bool
	// return value acts as an indicator for where it was defined.
	id := func(n *ast.Ident) (types.Object, bool) {
		obj := info.Uses[n]
		if obj == nil {
			// n is not a use; report whatever it defines (possibly nil) as not free.
			return info.Defs[n], false
		}
		if obj.Name() == "_" {
			return nil, false // exclude objects denoting '_'
		}
		if _, ok := obj.(*types.PkgName); ok {
			return nil, false // imported package
		}
		if !(file.Pos() <= obj.Pos() && obj.Pos() <= file.End()) {
			return nil, false // not defined in this file
		}
		scope := obj.Parent()
		if scope == nil {
			return nil, false // e.g. interface method, struct field
		}
		if scope == fileScope || scope == pkgScope {
			return nil, false // defined at file or package scope
		}
		if rng.Start <= obj.Pos() && obj.Pos() <= rng.End {
			return obj, false // defined within selection => not free
		}
		return obj, true
	}
	// sel returns non-nil if n denotes a selection o.x.y that is referenced by the
	// span and defined either within the span or in the lexical environment. The bool
	// return value acts as an indicator for where it was defined.
	var sel func(n *ast.SelectorExpr) (types.Object, bool)
	sel = func(n *ast.SelectorExpr) (types.Object, bool) {
		// Only the leftmost operand of the selector chain is resolved.
		switch x := astutil.Unparen(n.X).(type) {
		case *ast.SelectorExpr:
			return sel(x)
		case *ast.Ident:
			return id(x)
		}
		return nil, false
	}
	// seen holds the status record of every object found in the selection,
	// firstUseIn the position of its earliest occurrence, and vars the objects
	// in visiting order (with repetitions).
	seen := make(map[types.Object]*variable)
	firstUseIn := make(map[types.Object]token.Pos)
	var vars []types.Object
	ast.Inspect(node, func(n ast.Node) bool {
		if n == nil {
			return false
		}
		// Only nodes that lie entirely within the selection are considered.
		if rng.Start <= n.Pos() && n.End() <= rng.End {
			var obj types.Object
			var isFree, prune bool
			switch n := n.(type) {
			case *ast.Ident:
				obj, isFree = id(n)
			case *ast.SelectorExpr:
				obj, isFree = sel(n)
				// Do not descend into a resolved selector expression.
				prune = true
			}
			if obj != nil {
				// Note that a later occurrence replaces the record of an
				// earlier one; assigned/defined are only filled in afterwards.
				seen[obj] = &variable{
					obj:  obj,
					free: isFree,
				}
				vars = append(vars, obj)
				// Find the first time that the object is used in the selection.
				first, ok := firstUseIn[obj]
				if !ok || n.Pos() < first {
					firstUseIn[obj] = n.Pos()
				}
				if prune {
					return false
				}
			}
		}
		// Keep walking until we are past the end of the selection.
		return n.Pos() <= rng.End
	})

	// Find identifiers that are initialized or whose values are altered at some
	// point in the selected block. For example, in a selected block from lines 2-4,
	// variables x, y, and z are included in assigned. However, in a selected block
	// from lines 3-4, only variables y and z are included in assigned.
	//
	// 1: var a int
	// 2: var x int
	// 3: y := 3
	// 4: z := x + a
	//
	ast.Inspect(node, func(n ast.Node) bool {
		if n == nil {
			return false
		}
		// Skip nodes that are not fully inside the selection, but keep
		// descending while we have not walked past its end.
		if n.Pos() < rng.Start || n.End() > rng.End {
			return n.Pos() <= rng.End
		}
		switch n := n.(type) {
		case *ast.AssignStmt:
			for _, assignment := range n.Lhs {
				// Only plain identifiers on the left-hand side are tracked.
				lhs, ok := assignment.(*ast.Ident)
				if !ok {
					continue
				}
				obj, _ := id(lhs)
				if obj == nil {
					continue
				}
				if _, ok := seen[obj]; !ok {
					continue
				}
				seen[obj].assigned = true
				if n.Tok != token.DEFINE {
					continue
				}
				// Find identifiers that are defined prior to being used
				// elsewhere in the selection.
				// TODO: Include identifiers that are assigned prior to being
				// used elsewhere in the selection. Then, change the assignment
				// to a definition in the extracted function.
				if firstUseIn[obj] != lhs.Pos() {
					continue
				}
				// Ensure that the object is not used in its own re-definition.
				// For example:
				// var f float64
				// f, e := math.Frexp(f)
				//
				// The object is marked as defined as soon as one right-hand
				// side expression is found that does not reference it.
				for _, expr := range n.Rhs {
					if referencesObj(info, expr, obj) {
						continue
					}
					if _, ok := seen[obj]; !ok {
						continue
					}
					seen[obj].defined = true
					break
				}
			}
			return false
		case *ast.DeclStmt:
			gen, ok := n.Decl.(*ast.GenDecl)
			if !ok {
				return false
			}
			// Every name declared by a value spec counts as assigned.
			for _, spec := range gen.Specs {
				vSpecs, ok := spec.(*ast.ValueSpec)
				if !ok {
					continue
				}
				for _, vSpec := range vSpecs.Names {
					obj, _ := id(vSpec)
					if obj == nil {
						continue
					}
					if _, ok := seen[obj]; !ok {
						continue
					}
					seen[obj].assigned = true
				}
			}
			return false
		case *ast.IncDecStmt:
			// x++ and x-- alter x when x is a plain identifier.
			if ident, ok := n.X.(*ast.Ident); !ok {
				return false
			} else if obj, _ := id(ident); obj == nil {
				return false
			} else {
				if _, ok := seen[obj]; !ok {
					return false
				}
				seen[obj].assigned = true
			}
		}
		return true
	})
	// Translate the visiting order of objects into their status records.
	var variables []*variable
	for _, obj := range vars {
		v, ok := seen[obj]
		if !ok {
			return nil, fmt.Errorf("no seen types.Object for %v", obj)
		}
		variables = append(variables, v)
	}
	return variables, nil
}
765
// referencesObj reports whether any identifier inside expr is recorded in
// info.Uses as a use of obj.
func referencesObj(info *types.Info, expr ast.Expr, obj types.Object) bool {
	found := false
	ast.Inspect(expr, func(n ast.Node) bool {
		switch n := n.(type) {
		case nil:
			return false
		case *ast.Ident:
			if info.Uses[n] == obj {
				found = true
			}
			// Identifiers are leaves; there is nothing below them to visit.
			return false
		}
		return true
	})
	return found
}
786
// fnExtractParams bundles the values that canExtractFunction computes while
// validating a selection, so that extractFunction can reuse them.
type fnExtractParams struct {
	// tok is the token.File of the file being edited.
	tok *token.File
	// path is the AST path enclosing the (whitespace-adjusted) selection,
	// innermost node first.
	path []ast.Node
	// rng is the selection after leading and trailing whitespace was trimmed.
	rng span.Range
	// outer is the function declaration that encloses the selection.
	outer *ast.FuncDecl
	// start is the first node found that begins at rng.Start and ends within
	// the selection.
	start ast.Node
}
794
795// canExtractFunction reports whether the code in the given range can be
796// extracted to a function.
797func canExtractFunction(fset *token.FileSet, rng span.Range, src []byte, file *ast.File, _ *types.Info) (*fnExtractParams, bool, error) {
798	if rng.Start == rng.End {
799		return nil, false, fmt.Errorf("start and end are equal")
800	}
801	tok := fset.File(file.Pos())
802	if tok == nil {
803		return nil, false, fmt.Errorf("no file for pos %v", fset.Position(file.Pos()))
804	}
805	rng = adjustRangeForWhitespace(rng, tok, src)
806	path, _ := astutil.PathEnclosingInterval(file, rng.Start, rng.End)
807	if len(path) == 0 {
808		return nil, false, fmt.Errorf("no path enclosing interval")
809	}
810	// Node that encloses the selection must be a statement.
811	// TODO: Support function extraction for an expression.
812	_, ok := path[0].(ast.Stmt)
813	if !ok {
814		return nil, false, fmt.Errorf("node is not a statement")
815	}
816
817	// Find the function declaration that encloses the selection.
818	var outer *ast.FuncDecl
819	for _, p := range path {
820		if p, ok := p.(*ast.FuncDecl); ok {
821			outer = p
822			break
823		}
824	}
825	if outer == nil {
826		return nil, false, fmt.Errorf("no enclosing function")
827	}
828
829	// Find the nodes at the start and end of the selection.
830	var start, end ast.Node
831	ast.Inspect(outer, func(n ast.Node) bool {
832		if n == nil {
833			return false
834		}
835		// Do not override 'start' with a node that begins at the same location
836		// but is nested further from 'outer'.
837		if start == nil && n.Pos() == rng.Start && n.End() <= rng.End {
838			start = n
839		}
840		if end == nil && n.End() == rng.End && n.Pos() >= rng.Start {
841			end = n
842		}
843		return n.Pos() <= rng.End
844	})
845	if start == nil || end == nil {
846		return nil, false, fmt.Errorf("range does not map to AST nodes")
847	}
848	return &fnExtractParams{
849		tok:   tok,
850		path:  path,
851		rng:   rng,
852		outer: outer,
853		start: start,
854	}, true, nil
855}
856
857// objUsed checks if the object is used within the range. It returns the first occurence of
858// the object in the range, if it exists.
859func objUsed(info *types.Info, rng span.Range, obj types.Object) (bool, *ast.Ident) {
860	var firstUse *ast.Ident
861	for id, objUse := range info.Uses {
862		if obj != objUse {
863			continue
864		}
865		if id.Pos() < rng.Start || id.End() > rng.End {
866			continue
867		}
868		if firstUse == nil || id.Pos() < firstUse.Pos() {
869			firstUse = id
870		}
871	}
872	return firstUse != nil, firstUse
873}
874
875// varOverridden traverses the given AST node until we find the given identifier. Then, we
876// examine the occurrence of the given identifier and check for (1) whether the identifier
877// is being redefined. If the identifier is free, we also check for (2) whether the identifier
878// is being reassigned. We will not include an identifier in the return statement of the
879// extracted function if it meets one of the above conditions.
880func varOverridden(info *types.Info, firstUse *ast.Ident, obj types.Object, isFree bool, node ast.Node) bool {
881	var isOverriden bool
882	ast.Inspect(node, func(n ast.Node) bool {
883		if n == nil {
884			return false
885		}
886		assignment, ok := n.(*ast.AssignStmt)
887		if !ok {
888			return true
889		}
890		// A free variable is initialized prior to the selection. We can always reassign
891		// this variable after the selection because it has already been defined.
892		// Conversely, a non-free variable is initialized within the selection. Thus, we
893		// cannot reassign this variable after the selection unless it is initialized and
894		// returned by the extracted function.
895		if !isFree && assignment.Tok == token.ASSIGN {
896			return false
897		}
898		for _, assigned := range assignment.Lhs {
899			ident, ok := assigned.(*ast.Ident)
900			// Check if we found the first use of the identifier.
901			if !ok || ident != firstUse {
902				continue
903			}
904			objUse := info.Uses[ident]
905			if objUse == nil || objUse != obj {
906				continue
907			}
908			// Ensure that the object is not used in its own definition.
909			// For example:
910			// var f float64
911			// f, e := math.Frexp(f)
912			for _, expr := range assignment.Rhs {
913				if referencesObj(info, expr, obj) {
914					return false
915				}
916			}
917			isOverriden = true
918			return false
919		}
920		return false
921	})
922	return isOverriden
923}
924
// parseBlockStmt parses src as the body of a synthetic function in a synthetic
// "package main" file and returns that body as a block statement. Positions in
// the returned block are registered in fset.
//
// An error is returned if the synthesized file fails to parse, or if it does
// not yield a function declaration with a body as its first declaration.
func parseBlockStmt(fset *token.FileSet, src []byte) (*ast.BlockStmt, error) {
	// The closing brace is placed on its own line so that a line comment at the
	// very end of src cannot comment it out. The text before src is unchanged,
	// so positions of the parsed statements are unaffected.
	text := "package main\nfunc _() { " + string(src) + "\n}"
	extract, err := parser.ParseFile(fset, "", text, 0)
	if err != nil {
		return nil, err
	}
	if len(extract.Decls) == 0 {
		return nil, fmt.Errorf("parsed file does not contain any declarations")
	}
	decl, ok := extract.Decls[0].(*ast.FuncDecl)
	if !ok {
		return nil, fmt.Errorf("parsed file does not contain expected function declaration")
	}
	if decl.Body == nil {
		return nil, fmt.Errorf("extracted function has no body")
	}
	return decl.Body, nil
}
945
946// generateReturnInfo generates the information we need to adjust the return statements and
947// signature of the extracted function. We prepare names, signatures, and "zero values" that
948// represent the new variables. We also use this information to construct the if statement that
949// is inserted below the call to the extracted function.
950func generateReturnInfo(enclosing *ast.FuncType, pkg *types.Package, path []ast.Node, file *ast.File, info *types.Info, fset *token.FileSet, pos token.Pos) ([]*returnVariable, *ast.IfStmt, error) {
951	// Generate information for the added bool value.
952	cond := &ast.Ident{Name: generateAvailableIdentifier(pos, file, path, info, "cond", 0)}
953	retVars := []*returnVariable{
954		{
955			name:    cond,
956			decl:    &ast.Field{Type: ast.NewIdent("bool")},
957			zeroVal: ast.NewIdent("false"),
958		},
959	}
960	// Generate information for the values in the return signature of the enclosing function.
961	if enclosing.Results != nil {
962		for i, field := range enclosing.Results.List {
963			typ := info.TypeOf(field.Type)
964			if typ == nil {
965				return nil, nil, fmt.Errorf(
966					"failed type conversion, AST expression: %T", field.Type)
967			}
968			expr := analysisinternal.TypeExpr(fset, file, pkg, typ)
969			if expr == nil {
970				return nil, nil, fmt.Errorf("nil AST expression")
971			}
972			retVars = append(retVars, &returnVariable{
973				name: ast.NewIdent(generateAvailableIdentifier(pos, file,
974					path, info, "ret", i)),
975				decl: &ast.Field{Type: expr},
976				zeroVal: analysisinternal.ZeroValue(
977					fset, file, pkg, typ),
978			})
979		}
980	}
981	// Create the return statement for the enclosing function. We must exclude the variable
982	// for the condition of the if statement (cond) from the return statement.
983	ifReturn := &ast.IfStmt{
984		Cond: cond,
985		Body: &ast.BlockStmt{
986			List: []ast.Stmt{&ast.ReturnStmt{Results: getNames(retVars)[1:]}},
987		},
988	}
989	return retVars, ifReturn, nil
990}
991
992// adjustReturnStatements adds "zero values" of the given types to each return statement
993// in the given AST node.
994func adjustReturnStatements(returnTypes []*ast.Field, seenVars map[types.Object]ast.Expr, fset *token.FileSet, file *ast.File, pkg *types.Package, extractedBlock *ast.BlockStmt) error {
995	var zeroVals []ast.Expr
996	// Create "zero values" for each type.
997	for _, returnType := range returnTypes {
998		var val ast.Expr
999		for obj, typ := range seenVars {
1000			if typ != returnType.Type {
1001				continue
1002			}
1003			val = analysisinternal.ZeroValue(fset, file, pkg, obj.Type())
1004			break
1005		}
1006		if val == nil {
1007			return fmt.Errorf(
1008				"could not find matching AST expression for %T", returnType.Type)
1009		}
1010		zeroVals = append(zeroVals, val)
1011	}
1012	// Add "zero values" to each return statement.
1013	// The bool reports whether the enclosing function should return after calling the
1014	// extracted function. We set the bool to 'true' because, if these return statements
1015	// execute, the extracted function terminates early, and the enclosing function must
1016	// return as well.
1017	zeroVals = append(zeroVals, ast.NewIdent("true"))
1018	ast.Inspect(extractedBlock, func(n ast.Node) bool {
1019		if n == nil {
1020			return false
1021		}
1022		if n, ok := n.(*ast.ReturnStmt); ok {
1023			n.Results = append(zeroVals, n.Results...)
1024			return false
1025		}
1026		return true
1027	})
1028	return nil
1029}
1030
// generateFuncCall builds the node that invokes the extracted function name with
// params as arguments. When hasReturnVals is false the result is the bare call
// expression; otherwise it is an assignment statement that stores the call's
// results into returns using the given assignment token.
func generateFuncCall(hasReturnVals bool, params, returns []ast.Expr, name string, token token.Token) ast.Node {
	call := &ast.CallExpr{
		Fun:  ast.NewIdent(name),
		Args: params,
	}
	if !hasReturnVals {
		return call
	}
	return &ast.AssignStmt{
		Lhs: returns,
		Tok: token,
		Rhs: []ast.Expr{call},
	}
}
1053
1054// initializeVars creates variable declarations, if needed.
1055// Our preference is to replace the selected block with an "x, y, z := fn()" style
1056// assignment statement. We can use this style when all of the variables in the
1057// extracted function's return statement are either not defined prior to the extracted block
1058// or can be safely redefined. However, for example, if z is already defined
1059// in a different scope, we replace the selected block with:
1060//
1061// var x int
1062// var y string
1063// x, y, z = fn()
1064func initializeVars(uninitialized []types.Object, retVars []*returnVariable, seenUninitialized map[types.Object]struct{}, seenVars map[types.Object]ast.Expr) []ast.Stmt {
1065	var declarations []ast.Stmt
1066	for _, obj := range uninitialized {
1067		if _, ok := seenUninitialized[obj]; ok {
1068			continue
1069		}
1070		seenUninitialized[obj] = struct{}{}
1071		valSpec := &ast.ValueSpec{
1072			Names: []*ast.Ident{ast.NewIdent(obj.Name())},
1073			Type:  seenVars[obj],
1074		}
1075		genDecl := &ast.GenDecl{
1076			Tok:   token.VAR,
1077			Specs: []ast.Spec{valSpec},
1078		}
1079		declarations = append(declarations, &ast.DeclStmt{Decl: genDecl})
1080	}
1081	// Each variable added from a return statement in the selection
1082	// must be initialized.
1083	for i, retVar := range retVars {
1084		n := retVar.name.(*ast.Ident)
1085		valSpec := &ast.ValueSpec{
1086			Names: []*ast.Ident{n},
1087			Type:  retVars[i].decl.Type,
1088		}
1089		genDecl := &ast.GenDecl{
1090			Tok:   token.VAR,
1091			Specs: []ast.Spec{valSpec},
1092		}
1093		declarations = append(declarations, &ast.DeclStmt{Decl: genDecl})
1094	}
1095	return declarations
1096}
1097
1098// getNames returns the names from the given list of returnVariable.
1099func getNames(retVars []*returnVariable) []ast.Expr {
1100	var names []ast.Expr
1101	for _, retVar := range retVars {
1102		names = append(names, retVar.name)
1103	}
1104	return names
1105}
1106
1107// getZeroVals returns the "zero values" from the given list of returnVariable.
1108func getZeroVals(retVars []*returnVariable) []ast.Expr {
1109	var zvs []ast.Expr
1110	for _, retVar := range retVars {
1111		zvs = append(zvs, retVar.zeroVal)
1112	}
1113	return zvs
1114}
1115
1116// getDecls returns the declarations from the given list of returnVariable.
1117func getDecls(retVars []*returnVariable) []*ast.Field {
1118	var decls []*ast.Field
1119	for _, retVar := range retVars {
1120		decls = append(decls, retVar.decl)
1121	}
1122	return decls
1123}
1124