1package checkers
2
3import (
4	"go/ast"
5
6	"github.com/go-critic/go-critic/checkers/internal/astwalk"
7	"github.com/go-critic/go-critic/framework/linter"
8	"github.com/go-toolsmith/astfmt"
9)
10
11func init() {
12	var info linter.CheckerInfo
13	info.Name = "unnecessaryDefer"
14	info.Tags = []string{"diagnostic", "experimental"}
15	info.Summary = "Detects redundantly deferred calls"
16	info.Before = `
17func() {
18	defer os.Remove(filename)
19}`
20	info.After = `
21func() {
22	os.Remove(filename)
23}`
24
25	collection.AddChecker(&info, func(ctx *linter.CheckerContext) (linter.FileWalker, error) {
26		return astwalk.WalkerForFuncDecl(&unnecessaryDeferChecker{ctx: ctx}), nil
27	})
28}
29
// unnecessaryDeferChecker flags a defer statement that is the last statement
// of a function body, or that directly precedes a return whose results are
// all constant expressions.
//
// Fields:
//   - ctx supplies type information (TypesInfo) and is used to report warnings.
//   - isFunc records whether the block statement about to be checked is
//     treated as a function body; it is set by VisitFuncDecl and updated by
//     Visit as the AST walk proceeds.
type unnecessaryDeferChecker struct {
	astwalk.WalkHandler
	ctx    *linter.CheckerContext
	isFunc bool
}
35
36// Visit implements the ast.Visitor. This visitor keeps track of the block
37// statement belongs to a function or any other block. If the block is not a
38// function and ends with a defer statement that should be OK since it's
39// defering the outer function.
40func (c *unnecessaryDeferChecker) Visit(node ast.Node) ast.Visitor {
41	switch n := node.(type) {
42	case *ast.FuncDecl, *ast.FuncLit:
43		c.isFunc = true
44	case *ast.BlockStmt:
45		c.checkDeferBeforeReturn(n)
46	default:
47		c.isFunc = false
48	}
49
50	return c
51}
52
53func (c *unnecessaryDeferChecker) VisitFuncDecl(funcDecl *ast.FuncDecl) {
54	// We always start as a function (*ast.FuncDecl.Body passed)
55	c.isFunc = true
56
57	ast.Walk(c, funcDecl.Body)
58}
59
60func (c *unnecessaryDeferChecker) checkDeferBeforeReturn(funcDecl *ast.BlockStmt) {
61	// Check if we have an explicit return or if it's just the end of the scope.
62	explicitReturn := false
63	retIndex := len(funcDecl.List)
64	for i, stmt := range funcDecl.List {
65		retStmt, ok := stmt.(*ast.ReturnStmt)
66		if !ok {
67			continue
68		}
69		explicitReturn = true
70		if !c.isTrivialReturn(retStmt) {
71			continue
72		}
73		retIndex = i
74		break
75	}
76	if retIndex == 0 {
77		return
78	}
79
80	if deferStmt, ok := funcDecl.List[retIndex-1].(*ast.DeferStmt); ok {
81		// If the block is a function and ending with return or if we have an
82		// explicit return in any other block we should warn about
83		// unnecessary defer.
84		if c.isFunc || explicitReturn {
85			c.warn(deferStmt)
86		}
87	}
88}
89
90func (c *unnecessaryDeferChecker) isTrivialReturn(ret *ast.ReturnStmt) bool {
91	for _, e := range ret.Results {
92		if !c.isConstExpr(e) {
93			return false
94		}
95	}
96	return true
97}
98
99func (c *unnecessaryDeferChecker) isConstExpr(e ast.Expr) bool {
100	return c.ctx.TypesInfo.Types[e].Value != nil
101}
102
103func (c *unnecessaryDeferChecker) warn(deferStmt *ast.DeferStmt) {
104	s := astfmt.Sprint(deferStmt)
105	if fnlit, ok := deferStmt.Call.Fun.(*ast.FuncLit); ok {
106		// To avoid long and multi-line warning messages,
107		// collapse the function literals.
108		s = "defer " + astfmt.Sprint(fnlit.Type) + "{...}(...)"
109	}
110	c.ctx.Warn(deferStmt, "%s is placed just before return", s)
111}
112