1// Copyright 2016 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// Package lostcancel defines an Analyzer that checks for failure to
6// call a context cancellation function.
7package lostcancel
8
9import (
10	"fmt"
11	"go/ast"
12	"go/types"
13
14	"golang.org/x/tools/go/analysis"
15	"golang.org/x/tools/go/analysis/passes/ctrlflow"
16	"golang.org/x/tools/go/analysis/passes/inspect"
17	"golang.org/x/tools/go/ast/inspector"
18	"golang.org/x/tools/go/cfg"
19)
20
21const Doc = `check cancel func returned by context.WithCancel is called
22
23The cancellation function returned by context.WithCancel, WithTimeout,
24and WithDeadline must be called or the new context will remain live
25until its parent context is cancelled.
26(The background context is never cancelled.)`
27
28var Analyzer = &analysis.Analyzer{
29	Name: "lostcancel",
30	Doc:  Doc,
31	Run:  run,
32	Requires: []*analysis.Analyzer{
33		inspect.Analyzer,
34		ctrlflow.Analyzer,
35	},
36}
37
38const debug = false
39
40var contextPackage = "context"
41
42// checkLostCancel reports a failure to the call the cancel function
43// returned by context.WithCancel, either because the variable was
44// assigned to the blank identifier, or because there exists a
45// control-flow path from the call to a return statement and that path
46// does not "use" the cancel function.  Any reference to the variable
47// counts as a use, even within a nested function literal.
48// If the variable's scope is larger than the function
49// containing the assignment, we assume that other uses exist.
50//
51// checkLostCancel analyzes a single named or literal function.
52func run(pass *analysis.Pass) (interface{}, error) {
53	// Fast path: bypass check if file doesn't use context.WithCancel.
54	if !hasImport(pass.Pkg, contextPackage) {
55		return nil, nil
56	}
57
58	// Call runFunc for each Func{Decl,Lit}.
59	inspect := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)
60	nodeTypes := []ast.Node{
61		(*ast.FuncLit)(nil),
62		(*ast.FuncDecl)(nil),
63	}
64	inspect.Preorder(nodeTypes, func(n ast.Node) {
65		runFunc(pass, n)
66	})
67	return nil, nil
68}
69
70func runFunc(pass *analysis.Pass, node ast.Node) {
71	// Find scope of function node
72	var funcScope *types.Scope
73	switch v := node.(type) {
74	case *ast.FuncLit:
75		funcScope = pass.TypesInfo.Scopes[v.Type]
76	case *ast.FuncDecl:
77		funcScope = pass.TypesInfo.Scopes[v.Type]
78	}
79
80	// Maps each cancel variable to its defining ValueSpec/AssignStmt.
81	cancelvars := make(map[*types.Var]ast.Node)
82
83	// TODO(adonovan): opt: refactor to make a single pass
84	// over the AST using inspect.WithStack and node types
85	// {FuncDecl,FuncLit,CallExpr,SelectorExpr}.
86
87	// Find the set of cancel vars to analyze.
88	stack := make([]ast.Node, 0, 32)
89	ast.Inspect(node, func(n ast.Node) bool {
90		switch n.(type) {
91		case *ast.FuncLit:
92			if len(stack) > 0 {
93				return false // don't stray into nested functions
94			}
95		case nil:
96			stack = stack[:len(stack)-1] // pop
97			return true
98		}
99		stack = append(stack, n) // push
100
101		// Look for [{AssignStmt,ValueSpec} CallExpr SelectorExpr]:
102		//
103		//   ctx, cancel    := context.WithCancel(...)
104		//   ctx, cancel     = context.WithCancel(...)
105		//   var ctx, cancel = context.WithCancel(...)
106		//
107		if !isContextWithCancel(pass.TypesInfo, n) || !isCall(stack[len(stack)-2]) {
108			return true
109		}
110		var id *ast.Ident // id of cancel var
111		stmt := stack[len(stack)-3]
112		switch stmt := stmt.(type) {
113		case *ast.ValueSpec:
114			if len(stmt.Names) > 1 {
115				id = stmt.Names[1]
116			}
117		case *ast.AssignStmt:
118			if len(stmt.Lhs) > 1 {
119				id, _ = stmt.Lhs[1].(*ast.Ident)
120			}
121		}
122		if id != nil {
123			if id.Name == "_" {
124				pass.ReportRangef(id,
125					"the cancel function returned by context.%s should be called, not discarded, to avoid a context leak",
126					n.(*ast.SelectorExpr).Sel.Name)
127			} else if v, ok := pass.TypesInfo.Uses[id].(*types.Var); ok {
128				// If the cancel variable is defined outside function scope,
129				// do not analyze it.
130				if funcScope.Contains(v.Pos()) {
131					cancelvars[v] = stmt
132				}
133			} else if v, ok := pass.TypesInfo.Defs[id].(*types.Var); ok {
134				cancelvars[v] = stmt
135			}
136		}
137		return true
138	})
139
140	if len(cancelvars) == 0 {
141		return // no need to inspect CFG
142	}
143
144	// Obtain the CFG.
145	cfgs := pass.ResultOf[ctrlflow.Analyzer].(*ctrlflow.CFGs)
146	var g *cfg.CFG
147	var sig *types.Signature
148	switch node := node.(type) {
149	case *ast.FuncDecl:
150		sig, _ = pass.TypesInfo.Defs[node.Name].Type().(*types.Signature)
151		if node.Name.Name == "main" && sig.Recv() == nil && pass.Pkg.Name() == "main" {
152			// Returning from main.main terminates the process,
153			// so there's no need to cancel contexts.
154			return
155		}
156		g = cfgs.FuncDecl(node)
157
158	case *ast.FuncLit:
159		sig, _ = pass.TypesInfo.Types[node.Type].Type.(*types.Signature)
160		g = cfgs.FuncLit(node)
161	}
162	if sig == nil {
163		return // missing type information
164	}
165
166	// Print CFG.
167	if debug {
168		fmt.Println(g.Format(pass.Fset))
169	}
170
171	// Examine the CFG for each variable in turn.
172	// (It would be more efficient to analyze all cancelvars in a
173	// single pass over the AST, but seldom is there more than one.)
174	for v, stmt := range cancelvars {
175		if ret := lostCancelPath(pass, g, v, stmt, sig); ret != nil {
176			lineno := pass.Fset.Position(stmt.Pos()).Line
177			pass.ReportRangef(stmt, "the %s function is not used on all paths (possible context leak)", v.Name())
178			pass.ReportRangef(ret, "this return statement may be reached without using the %s var defined on line %d", v.Name(), lineno)
179		}
180	}
181}
182
// isCall reports whether n is a function call expression.
func isCall(n ast.Node) bool {
	switch n.(type) {
	case *ast.CallExpr:
		return true
	default:
		return false
	}
}
184
// hasImport reports whether a package with import path path appears in
// pkg.Imports().
func hasImport(pkg *types.Package, path string) bool {
	imports := pkg.Imports()
	for i := range imports {
		if imports[i].Path() == path {
			return true
		}
	}
	return false
}
193
194// isContextWithCancel reports whether n is one of the qualified identifiers
195// context.With{Cancel,Timeout,Deadline}.
196func isContextWithCancel(info *types.Info, n ast.Node) bool {
197	sel, ok := n.(*ast.SelectorExpr)
198	if !ok {
199		return false
200	}
201	switch sel.Sel.Name {
202	case "WithCancel", "WithTimeout", "WithDeadline":
203	default:
204		return false
205	}
206	if x, ok := sel.X.(*ast.Ident); ok {
207		if pkgname, ok := info.Uses[x].(*types.PkgName); ok {
208			return pkgname.Imported().Path() == contextPackage
209		}
210		// Import failed, so we can't check package path.
211		// Just check the local package name (heuristic).
212		return x.Name == "context"
213	}
214	return false
215}
216
217// lostCancelPath finds a path through the CFG, from stmt (which defines
218// the 'cancel' variable v) to a return statement, that doesn't "use" v.
219// If it finds one, it returns the return statement (which may be synthetic).
220// sig is the function's type, if known.
221func lostCancelPath(pass *analysis.Pass, g *cfg.CFG, v *types.Var, stmt ast.Node, sig *types.Signature) *ast.ReturnStmt {
222	vIsNamedResult := sig != nil && tupleContains(sig.Results(), v)
223
224	// uses reports whether stmts contain a "use" of variable v.
225	uses := func(pass *analysis.Pass, v *types.Var, stmts []ast.Node) bool {
226		found := false
227		for _, stmt := range stmts {
228			ast.Inspect(stmt, func(n ast.Node) bool {
229				switch n := n.(type) {
230				case *ast.Ident:
231					if pass.TypesInfo.Uses[n] == v {
232						found = true
233					}
234				case *ast.ReturnStmt:
235					// A naked return statement counts as a use
236					// of the named result variables.
237					if n.Results == nil && vIsNamedResult {
238						found = true
239					}
240				}
241				return !found
242			})
243		}
244		return found
245	}
246
247	// blockUses computes "uses" for each block, caching the result.
248	memo := make(map[*cfg.Block]bool)
249	blockUses := func(pass *analysis.Pass, v *types.Var, b *cfg.Block) bool {
250		res, ok := memo[b]
251		if !ok {
252			res = uses(pass, v, b.Nodes)
253			memo[b] = res
254		}
255		return res
256	}
257
258	// Find the var's defining block in the CFG,
259	// plus the rest of the statements of that block.
260	var defblock *cfg.Block
261	var rest []ast.Node
262outer:
263	for _, b := range g.Blocks {
264		for i, n := range b.Nodes {
265			if n == stmt {
266				defblock = b
267				rest = b.Nodes[i+1:]
268				break outer
269			}
270		}
271	}
272	if defblock == nil {
273		panic("internal error: can't find defining block for cancel var")
274	}
275
276	// Is v "used" in the remainder of its defining block?
277	if uses(pass, v, rest) {
278		return nil
279	}
280
281	// Does the defining block return without using v?
282	if ret := defblock.Return(); ret != nil {
283		return ret
284	}
285
286	// Search the CFG depth-first for a path, from defblock to a
287	// return block, in which v is never "used".
288	seen := make(map[*cfg.Block]bool)
289	var search func(blocks []*cfg.Block) *ast.ReturnStmt
290	search = func(blocks []*cfg.Block) *ast.ReturnStmt {
291		for _, b := range blocks {
292			if seen[b] {
293				continue
294			}
295			seen[b] = true
296
297			// Prune the search if the block uses v.
298			if blockUses(pass, v, b) {
299				continue
300			}
301
302			// Found path to return statement?
303			if ret := b.Return(); ret != nil {
304				if debug {
305					fmt.Printf("found path to return in block %s\n", b)
306				}
307				return ret // found
308			}
309
310			// Recur
311			if ret := search(b.Succs); ret != nil {
312				if debug {
313					fmt.Printf(" from block %s\n", b)
314				}
315				return ret
316			}
317		}
318		return nil
319	}
320	return search(defblock.Succs)
321}
322
// tupleContains reports whether v is one of the variables of tuple, compared
// by identity. A nil tuple has length zero and so contains nothing.
func tupleContains(tuple *types.Tuple, v *types.Var) bool {
	n := tuple.Len()
	for i := 0; i < n; i++ {
		if member := tuple.At(i); member == v {
			return true
		}
	}
	return false
}
331