1package checkers
2
3import (
4	"go/ast"
5	"go/token"
6	"go/types"
7
8	"github.com/go-critic/go-critic/checkers/internal/astwalk"
9	"github.com/go-critic/go-critic/checkers/internal/lintutil"
10	"github.com/go-critic/go-critic/framework/linter"
11	"github.com/go-toolsmith/astcast"
12	"github.com/go-toolsmith/astequal"
13	"github.com/go-toolsmith/typep"
14)
15
16func init() {
17	var info linter.CheckerInfo
18	info.Name = "unlambda"
19	info.Tags = []string{"style"}
20	info.Summary = "Detects function literals that can be simplified"
21	info.Before = `func(x int) int { return fn(x) }`
22	info.After = `fn`
23
24	collection.AddChecker(&info, func(ctx *linter.CheckerContext) (linter.FileWalker, error) {
25		return astwalk.WalkerForExpr(&unlambdaChecker{ctx: ctx}), nil
26	})
27}
28
29type unlambdaChecker struct {
30	astwalk.WalkHandler
31	ctx *linter.CheckerContext
32}
33
34func (c *unlambdaChecker) VisitExpr(x ast.Expr) {
35	fn, ok := x.(*ast.FuncLit)
36	if !ok || len(fn.Body.List) != 1 {
37		return
38	}
39
40	ret, ok := fn.Body.List[0].(*ast.ReturnStmt)
41	if !ok || len(ret.Results) != 1 {
42		return
43	}
44
45	result := astcast.ToCallExpr(ret.Results[0])
46	callable := qualifiedName(result.Fun)
47	if callable == "" {
48		return // Skip tricky cases; only handle simple calls
49	}
50	if isBuiltin(callable) {
51		return // See #762
52	}
53	hasVars := lintutil.ContainsNode(result.Fun, func(n ast.Node) bool {
54		id, ok := n.(*ast.Ident)
55		if !ok {
56			return false
57		}
58		obj, ok := c.ctx.TypesInfo.ObjectOf(id).(*types.Var)
59		if !ok {
60			return false
61		}
62		// Permit only non-pointer struct method values.
63		return !typep.IsStruct(obj.Type().Underlying())
64	})
65	if hasVars {
66		return // See #888 #1007
67	}
68
69	fnType := c.ctx.TypeOf(fn)
70	resultType := c.ctx.TypeOf(result.Fun)
71	if !types.Identical(fnType, resultType) {
72		return
73	}
74	// Now check that all arguments match the parameters.
75	n := 0
76	for _, params := range fn.Type.Params.List {
77		if _, ok := params.Type.(*ast.Ellipsis); ok {
78			if result.Ellipsis == token.NoPos {
79				return
80			}
81			n++
82			continue
83		}
84
85		for _, id := range params.Names {
86			if !astequal.Expr(id, result.Args[n]) {
87				return
88			}
89			n++
90		}
91	}
92
93	if len(result.Args) == n {
94		c.warn(fn, callable)
95	}
96}
97
98func (c *unlambdaChecker) warn(cause ast.Node, suggestion string) {
99	c.ctx.Warn(cause, "replace `%s` with `%s`", cause, suggestion)
100}
101