1package checkers
2
3import (
4	"fmt"
5	"go/ast"
6	"go/token"
7	"strconv"
8
9	"github.com/go-critic/go-critic/checkers/internal/astwalk"
10	"github.com/go-critic/go-critic/checkers/internal/lintutil"
11	"github.com/go-critic/go-critic/framework/linter"
12	"github.com/go-toolsmith/astcast"
13	"github.com/go-toolsmith/astcopy"
14	"github.com/go-toolsmith/astequal"
15	"github.com/go-toolsmith/astp"
16	"github.com/go-toolsmith/typep"
17	"golang.org/x/tools/go/ast/astutil"
18)
19
20func init() {
21	var info linter.CheckerInfo
22	info.Name = "boolExprSimplify"
23	info.Tags = []string{"style", "experimental"}
24	info.Summary = "Detects bool expressions that can be simplified"
25	info.Before = `
26a := !(elapsed >= expectElapsedMin)
27b := !(x) == !(y)`
28	info.After = `
29a := elapsed < expectElapsedMin
30b := (x) == (y)`
31
32	collection.AddChecker(&info, func(ctx *linter.CheckerContext) (linter.FileWalker, error) {
33		return astwalk.WalkerForExpr(&boolExprSimplifyChecker{ctx: ctx}), nil
34	})
35}
36
37type boolExprSimplifyChecker struct {
38	astwalk.WalkHandler
39	ctx       *linter.CheckerContext
40	hasFloats bool
41}
42
43func (c *boolExprSimplifyChecker) VisitExpr(x ast.Expr) {
44	if !astp.IsBinaryExpr(x) && !astp.IsUnaryExpr(x) {
45		return
46	}
47
48	// Throw away non-bool expressions and avoid redundant
49	// AST copying below.
50	if typ := c.ctx.TypeOf(x); typ == nil || !typep.HasBoolKind(typ.Underlying()) {
51		return
52	}
53
54	// We'll loose all types info after a copy,
55	// this is why we record valuable info before doing it.
56	c.hasFloats = lintutil.ContainsNode(x, func(n ast.Node) bool {
57		if x, ok := n.(*ast.BinaryExpr); ok {
58			return typep.HasFloatProp(c.ctx.TypeOf(x.X).Underlying()) ||
59				typep.HasFloatProp(c.ctx.TypeOf(x.Y).Underlying())
60		}
61		return false
62	})
63
64	y := c.simplifyBool(astcopy.Expr(x))
65	if !astequal.Expr(x, y) {
66		c.warn(x, y)
67	}
68}
69
70func (c *boolExprSimplifyChecker) simplifyBool(x ast.Expr) ast.Expr {
71	return astutil.Apply(x, nil, func(cur *astutil.Cursor) bool {
72		return c.doubleNegation(cur) ||
73			c.negatedEquals(cur) ||
74			c.invertComparison(cur) ||
75			c.combineChecks(cur) ||
76			c.removeIncDec(cur) ||
77			c.foldRanges(cur) ||
78			true
79	}).(ast.Expr)
80}
81
82func (c *boolExprSimplifyChecker) doubleNegation(cur *astutil.Cursor) bool {
83	neg1 := astcast.ToUnaryExpr(cur.Node())
84	neg2 := astcast.ToUnaryExpr(astutil.Unparen(neg1.X))
85	if neg1.Op == token.NOT && neg2.Op == token.NOT {
86		cur.Replace(astutil.Unparen(neg2.X))
87		return true
88	}
89	return false
90}
91
92func (c *boolExprSimplifyChecker) negatedEquals(cur *astutil.Cursor) bool {
93	x, ok := cur.Node().(*ast.BinaryExpr)
94	if !ok || x.Op != token.EQL {
95		return false
96	}
97	neg1 := astcast.ToUnaryExpr(x.X)
98	neg2 := astcast.ToUnaryExpr(x.Y)
99	if neg1.Op == token.NOT && neg2.Op == token.NOT {
100		x.X = neg1.X
101		x.Y = neg2.X
102		return true
103	}
104	return false
105}
106
107func (c *boolExprSimplifyChecker) invertComparison(cur *astutil.Cursor) bool {
108	if c.hasFloats { // See #673
109		return false
110	}
111
112	neg := astcast.ToUnaryExpr(cur.Node())
113	cmp := astcast.ToBinaryExpr(astutil.Unparen(neg.X))
114	if neg.Op != token.NOT {
115		return false
116	}
117
118	// Replace operator to its negated form.
119	switch cmp.Op {
120	case token.EQL:
121		cmp.Op = token.NEQ
122	case token.NEQ:
123		cmp.Op = token.EQL
124	case token.LSS:
125		cmp.Op = token.GEQ
126	case token.GTR:
127		cmp.Op = token.LEQ
128	case token.LEQ:
129		cmp.Op = token.GTR
130	case token.GEQ:
131		cmp.Op = token.LSS
132
133	default:
134		return false
135	}
136	cur.Replace(cmp)
137	return true
138}
139
140func (c *boolExprSimplifyChecker) isSafe(x ast.Expr) bool {
141	return typep.SideEffectFree(c.ctx.TypesInfo, x)
142}
143
144func (c *boolExprSimplifyChecker) combineChecks(cur *astutil.Cursor) bool {
145	or, ok := cur.Node().(*ast.BinaryExpr)
146	if !ok || or.Op != token.LOR {
147		return false
148	}
149
150	lhs := astcast.ToBinaryExpr(astutil.Unparen(or.X))
151	rhs := astcast.ToBinaryExpr(astutil.Unparen(or.Y))
152
153	if !astequal.Expr(lhs.X, rhs.X) || !astequal.Expr(lhs.Y, rhs.Y) {
154		return false
155	}
156	if !c.isSafe(lhs.X) || !c.isSafe(lhs.Y) {
157		return false
158	}
159
160	combTable := [...]struct {
161		x      token.Token
162		y      token.Token
163		result token.Token
164	}{
165		{token.GTR, token.EQL, token.GEQ},
166		{token.EQL, token.GTR, token.GEQ},
167		{token.LSS, token.EQL, token.LEQ},
168		{token.EQL, token.LSS, token.LEQ},
169	}
170	for _, comb := range &combTable {
171		if comb.x == lhs.Op && comb.y == rhs.Op {
172			lhs.Op = comb.result
173			cur.Replace(lhs)
174			return true
175		}
176	}
177	return false
178}
179
180func (c *boolExprSimplifyChecker) removeIncDec(cur *astutil.Cursor) bool {
181	cmp := astcast.ToBinaryExpr(cur.Node())
182
183	matchOneWay := func(op token.Token, x, y *ast.BinaryExpr) bool {
184		if x.Op != op || astcast.ToBasicLit(x.Y).Value != "1" {
185			return false
186		}
187		if y.Op == op && astcast.ToBasicLit(y.Y).Value == "1" {
188			return false
189		}
190		return true
191	}
192	replace := func(lhsOp, rhsOp, replacement token.Token) bool {
193		lhs := astcast.ToBinaryExpr(cmp.X)
194		rhs := astcast.ToBinaryExpr(cmp.Y)
195		switch {
196		case matchOneWay(lhsOp, lhs, rhs):
197			cmp.X = lhs.X
198			cmp.Op = replacement
199			cur.Replace(cmp)
200			return true
201		case matchOneWay(rhsOp, rhs, lhs):
202			cmp.Y = rhs.X
203			cmp.Op = replacement
204			cur.Replace(cmp)
205			return true
206		default:
207			return false
208		}
209	}
210
211	switch cmp.Op {
212	case token.GTR:
213		// `x > y-1` => `x >= y`
214		// `x+1 > y` => `x >= y`
215		return replace(token.ADD, token.SUB, token.GEQ)
216
217	case token.GEQ:
218		// `x >= y+1` => `x > y`
219		// `x-1 >= y` => `x > y`
220		return replace(token.SUB, token.ADD, token.GTR)
221
222	case token.LSS:
223		// `x < y+1` => `x <= y`
224		// `x-1 < y` => `x <= y`
225		return replace(token.SUB, token.ADD, token.LEQ)
226
227	case token.LEQ:
228		// `x <= y-1` => `x < y`
229		// `x+1 <= y` => `x < y`
230		return replace(token.ADD, token.SUB, token.LSS)
231
232	default:
233		return false
234	}
235}
236
237func (c *boolExprSimplifyChecker) foldRanges(cur *astutil.Cursor) bool {
238	if c.hasFloats { // See #848
239		return false
240	}
241
242	e, ok := cur.Node().(*ast.BinaryExpr)
243	if !ok {
244		return false
245	}
246	lhs := astcast.ToBinaryExpr(e.X)
247	rhs := astcast.ToBinaryExpr(e.Y)
248	if !c.isSafe(lhs.X) || !c.isSafe(rhs.X) {
249		return false
250	}
251	if !astequal.Expr(lhs.X, rhs.X) {
252		return false
253	}
254
255	c1, ok := c.int64val(lhs.Y)
256	if !ok {
257		return false
258	}
259	c2, ok := c.int64val(rhs.Y)
260	if !ok {
261		return false
262	}
263
264	type combination struct {
265		lhsOp    token.Token
266		rhsOp    token.Token
267		rhsDiff  int64
268		resDelta int64
269	}
270	match := func(comb *combination) bool {
271		if lhs.Op != comb.lhsOp || rhs.Op != comb.rhsOp {
272			return false
273		}
274		if c2-c1 != comb.rhsDiff {
275			return false
276		}
277		return true
278	}
279
280	switch e.Op {
281	case token.LAND:
282		combTable := [...]combination{
283			// `x > c && x < c+2` => `x == c+1`
284			{token.GTR, token.LSS, 2, 1},
285			// `x >= c && x < c+1` => `x == c`
286			{token.GEQ, token.LSS, 1, 0},
287			// `x > c && x <= c+1` => `x == c+1`
288			{token.GTR, token.LEQ, 1, 1},
289			// `x >= c && x <= c` => `x == c`
290			{token.GEQ, token.LEQ, 0, 0},
291		}
292		for i := range combTable {
293			comb := combTable[i]
294			if match(&comb) {
295				lhs.Op = token.EQL
296				v := c1 + comb.resDelta
297				lhs.Y.(*ast.BasicLit).Value = fmt.Sprint(v)
298				cur.Replace(lhs)
299				return true
300			}
301		}
302
303	case token.LOR:
304		combTable := [...]combination{
305			// `x < c || x > c` => `x != c`
306			{token.LSS, token.GTR, 0, 0},
307			// `x <= c || x > c+1` => `x != c+1`
308			{token.LEQ, token.GTR, 1, 1},
309			// `x < c || x >= c+1` => `x != c`
310			{token.LSS, token.GEQ, 1, 0},
311			// `x <= c || x >= c+2` => `x != c+1`
312			{token.LEQ, token.GEQ, 2, 1},
313		}
314		for i := range combTable {
315			comb := combTable[i]
316			if match(&comb) {
317				lhs.Op = token.NEQ
318				v := c1 + comb.resDelta
319				lhs.Y.(*ast.BasicLit).Value = fmt.Sprint(v)
320				cur.Replace(lhs)
321				return true
322			}
323		}
324	}
325
326	return false
327}
328
329func (c *boolExprSimplifyChecker) int64val(x ast.Expr) (int64, bool) {
330	// TODO(quasilyte): if we had types info, we could use TypesInfo.Types[x].Value,
331	// but since copying erases leaves us without it, only basic literals are handled.
332	lit, ok := x.(*ast.BasicLit)
333	if !ok {
334		return 0, false
335	}
336	v, err := strconv.ParseInt(lit.Value, 10, 64)
337	if err != nil {
338		return 0, false
339	}
340	return v, true
341}
342
343func (c *boolExprSimplifyChecker) warn(cause, suggestion ast.Expr) {
344	c.SkipChilds = true
345	c.ctx.Warn(cause, "can simplify `%s` to `%s`", cause, suggestion)
346}
347