1package checkers 2 3import ( 4 "fmt" 5 "go/ast" 6 "go/token" 7 "strconv" 8 9 "github.com/go-critic/go-critic/checkers/internal/astwalk" 10 "github.com/go-critic/go-critic/checkers/internal/lintutil" 11 "github.com/go-critic/go-critic/framework/linter" 12 "github.com/go-toolsmith/astcast" 13 "github.com/go-toolsmith/astcopy" 14 "github.com/go-toolsmith/astequal" 15 "github.com/go-toolsmith/astp" 16 "github.com/go-toolsmith/typep" 17 "golang.org/x/tools/go/ast/astutil" 18) 19 20func init() { 21 var info linter.CheckerInfo 22 info.Name = "boolExprSimplify" 23 info.Tags = []string{"style", "experimental"} 24 info.Summary = "Detects bool expressions that can be simplified" 25 info.Before = ` 26a := !(elapsed >= expectElapsedMin) 27b := !(x) == !(y)` 28 info.After = ` 29a := elapsed < expectElapsedMin 30b := (x) == (y)` 31 32 collection.AddChecker(&info, func(ctx *linter.CheckerContext) (linter.FileWalker, error) { 33 return astwalk.WalkerForExpr(&boolExprSimplifyChecker{ctx: ctx}), nil 34 }) 35} 36 37type boolExprSimplifyChecker struct { 38 astwalk.WalkHandler 39 ctx *linter.CheckerContext 40 hasFloats bool 41} 42 43func (c *boolExprSimplifyChecker) VisitExpr(x ast.Expr) { 44 if !astp.IsBinaryExpr(x) && !astp.IsUnaryExpr(x) { 45 return 46 } 47 48 // Throw away non-bool expressions and avoid redundant 49 // AST copying below. 50 if typ := c.ctx.TypeOf(x); typ == nil || !typep.HasBoolKind(typ.Underlying()) { 51 return 52 } 53 54 // We'll loose all types info after a copy, 55 // this is why we record valuable info before doing it. 56 c.hasFloats = lintutil.ContainsNode(x, func(n ast.Node) bool { 57 if x, ok := n.(*ast.BinaryExpr); ok { 58 return typep.HasFloatProp(c.ctx.TypeOf(x.X).Underlying()) || 59 typep.HasFloatProp(c.ctx.TypeOf(x.Y).Underlying()) 60 } 61 return false 62 }) 63 64 y := c.simplifyBool(astcopy.Expr(x)) 65 if !astequal.Expr(x, y) { 66 c.warn(x, y) 67 } 68} 69 70func (c *boolExprSimplifyChecker) simplifyBool(x ast.Expr) ast.Expr { 71 return astutil.Apply(x, nil, func(cur *astutil.Cursor) bool { 72 return c.doubleNegation(cur) || 73 c.negatedEquals(cur) || 74 c.invertComparison(cur) || 75 c.combineChecks(cur) || 76 c.removeIncDec(cur) || 77 c.foldRanges(cur) || 78 true 79 }).(ast.Expr) 80} 81 82func (c *boolExprSimplifyChecker) doubleNegation(cur *astutil.Cursor) bool { 83 neg1 := astcast.ToUnaryExpr(cur.Node()) 84 neg2 := astcast.ToUnaryExpr(astutil.Unparen(neg1.X)) 85 if neg1.Op == token.NOT && neg2.Op == token.NOT { 86 cur.Replace(astutil.Unparen(neg2.X)) 87 return true 88 } 89 return false 90} 91 92func (c *boolExprSimplifyChecker) negatedEquals(cur *astutil.Cursor) bool { 93 x, ok := cur.Node().(*ast.BinaryExpr) 94 if !ok || x.Op != token.EQL { 95 return false 96 } 97 neg1 := astcast.ToUnaryExpr(x.X) 98 neg2 := astcast.ToUnaryExpr(x.Y) 99 if neg1.Op == token.NOT && neg2.Op == token.NOT { 100 x.X = neg1.X 101 x.Y = neg2.X 102 return true 103 } 104 return false 105} 106 107func (c *boolExprSimplifyChecker) invertComparison(cur *astutil.Cursor) bool { 108 if c.hasFloats { // See #673 109 return false 110 } 111 112 neg := astcast.ToUnaryExpr(cur.Node()) 113 cmp := astcast.ToBinaryExpr(astutil.Unparen(neg.X)) 114 if neg.Op != token.NOT { 115 return false 116 } 117 118 // Replace operator to its negated form. 119 switch cmp.Op { 120 case token.EQL: 121 cmp.Op = token.NEQ 122 case token.NEQ: 123 cmp.Op = token.EQL 124 case token.LSS: 125 cmp.Op = token.GEQ 126 case token.GTR: 127 cmp.Op = token.LEQ 128 case token.LEQ: 129 cmp.Op = token.GTR 130 case token.GEQ: 131 cmp.Op = token.LSS 132 133 default: 134 return false 135 } 136 cur.Replace(cmp) 137 return true 138} 139 140func (c *boolExprSimplifyChecker) isSafe(x ast.Expr) bool { 141 return typep.SideEffectFree(c.ctx.TypesInfo, x) 142} 143 144func (c *boolExprSimplifyChecker) combineChecks(cur *astutil.Cursor) bool { 145 or, ok := cur.Node().(*ast.BinaryExpr) 146 if !ok || or.Op != token.LOR { 147 return false 148 } 149 150 lhs := astcast.ToBinaryExpr(astutil.Unparen(or.X)) 151 rhs := astcast.ToBinaryExpr(astutil.Unparen(or.Y)) 152 153 if !astequal.Expr(lhs.X, rhs.X) || !astequal.Expr(lhs.Y, rhs.Y) { 154 return false 155 } 156 if !c.isSafe(lhs.X) || !c.isSafe(lhs.Y) { 157 return false 158 } 159 160 combTable := [...]struct { 161 x token.Token 162 y token.Token 163 result token.Token 164 }{ 165 {token.GTR, token.EQL, token.GEQ}, 166 {token.EQL, token.GTR, token.GEQ}, 167 {token.LSS, token.EQL, token.LEQ}, 168 {token.EQL, token.LSS, token.LEQ}, 169 } 170 for _, comb := range &combTable { 171 if comb.x == lhs.Op && comb.y == rhs.Op { 172 lhs.Op = comb.result 173 cur.Replace(lhs) 174 return true 175 } 176 } 177 return false 178} 179 180func (c *boolExprSimplifyChecker) removeIncDec(cur *astutil.Cursor) bool { 181 cmp := astcast.ToBinaryExpr(cur.Node()) 182 183 matchOneWay := func(op token.Token, x, y *ast.BinaryExpr) bool { 184 if x.Op != op || astcast.ToBasicLit(x.Y).Value != "1" { 185 return false 186 } 187 if y.Op == op && astcast.ToBasicLit(y.Y).Value == "1" { 188 return false 189 } 190 return true 191 } 192 replace := func(lhsOp, rhsOp, replacement token.Token) bool { 193 lhs := astcast.ToBinaryExpr(cmp.X) 194 rhs := astcast.ToBinaryExpr(cmp.Y) 195 switch { 196 case matchOneWay(lhsOp, lhs, rhs): 197 cmp.X = lhs.X 198 cmp.Op = replacement 199 cur.Replace(cmp) 200 return true 201 case matchOneWay(rhsOp, rhs, lhs): 202 cmp.Y = rhs.X 203 cmp.Op = replacement 204 cur.Replace(cmp) 205 return true 206 default: 207 return false 208 } 209 } 210 211 switch cmp.Op { 212 case token.GTR: 213 // `x > y-1` => `x >= y` 214 // `x+1 > y` => `x >= y` 215 return replace(token.ADD, token.SUB, token.GEQ) 216 217 case token.GEQ: 218 // `x >= y+1` => `x > y` 219 // `x-1 >= y` => `x > y` 220 return replace(token.SUB, token.ADD, token.GTR) 221 222 case token.LSS: 223 // `x < y+1` => `x <= y` 224 // `x-1 < y` => `x <= y` 225 return replace(token.SUB, token.ADD, token.LEQ) 226 227 case token.LEQ: 228 // `x <= y-1` => `x < y` 229 // `x+1 <= y` => `x < y` 230 return replace(token.ADD, token.SUB, token.LSS) 231 232 default: 233 return false 234 } 235} 236 237func (c *boolExprSimplifyChecker) foldRanges(cur *astutil.Cursor) bool { 238 if c.hasFloats { // See #848 239 return false 240 } 241 242 e, ok := cur.Node().(*ast.BinaryExpr) 243 if !ok { 244 return false 245 } 246 lhs := astcast.ToBinaryExpr(e.X) 247 rhs := astcast.ToBinaryExpr(e.Y) 248 if !c.isSafe(lhs.X) || !c.isSafe(rhs.X) { 249 return false 250 } 251 if !astequal.Expr(lhs.X, rhs.X) { 252 return false 253 } 254 255 c1, ok := c.int64val(lhs.Y) 256 if !ok { 257 return false 258 } 259 c2, ok := c.int64val(rhs.Y) 260 if !ok { 261 return false 262 } 263 264 type combination struct { 265 lhsOp token.Token 266 rhsOp token.Token 267 rhsDiff int64 268 resDelta int64 269 } 270 match := func(comb *combination) bool { 271 if lhs.Op != comb.lhsOp || rhs.Op != comb.rhsOp { 272 return false 273 } 274 if c2-c1 != comb.rhsDiff { 275 return false 276 } 277 return true 278 } 279 280 switch e.Op { 281 case token.LAND: 282 combTable := [...]combination{ 283 // `x > c && x < c+2` => `x == c+1` 284 {token.GTR, token.LSS, 2, 1}, 285 // `x >= c && x < c+1` => `x == c` 286 {token.GEQ, token.LSS, 1, 0}, 287 // `x > c && x <= c+1` => `x == c+1` 288 {token.GTR, token.LEQ, 1, 1}, 289 // `x >= c && x <= c` => `x == c` 290 {token.GEQ, token.LEQ, 0, 0}, 291 } 292 for i := range combTable { 293 comb := combTable[i] 294 if match(&comb) { 295 lhs.Op = token.EQL 296 v := c1 + comb.resDelta 297 lhs.Y.(*ast.BasicLit).Value = fmt.Sprint(v) 298 cur.Replace(lhs) 299 return true 300 } 301 } 302 303 case token.LOR: 304 combTable := [...]combination{ 305 // `x < c || x > c` => `x != c` 306 {token.LSS, token.GTR, 0, 0}, 307 // `x <= c || x > c+1` => `x != c+1` 308 {token.LEQ, token.GTR, 1, 1}, 309 // `x < c || x >= c+1` => `x != c` 310 {token.LSS, token.GEQ, 1, 0}, 311 // `x <= c || x >= c+2` => `x != c+1` 312 {token.LEQ, token.GEQ, 2, 1}, 313 } 314 for i := range combTable { 315 comb := combTable[i] 316 if match(&comb) { 317 lhs.Op = token.NEQ 318 v := c1 + comb.resDelta 319 lhs.Y.(*ast.BasicLit).Value = fmt.Sprint(v) 320 cur.Replace(lhs) 321 return true 322 } 323 } 324 } 325 326 return false 327} 328 329func (c *boolExprSimplifyChecker) int64val(x ast.Expr) (int64, bool) { 330 // TODO(quasilyte): if we had types info, we could use TypesInfo.Types[x].Value, 331 // but since copying erases leaves us without it, only basic literals are handled. 332 lit, ok := x.(*ast.BasicLit) 333 if !ok { 334 return 0, false 335 } 336 v, err := strconv.ParseInt(lit.Value, 10, 64) 337 if err != nil { 338 return 0, false 339 } 340 return v, true 341} 342 343func (c *boolExprSimplifyChecker) warn(cause, suggestion ast.Expr) { 344 c.SkipChilds = true 345 c.ctx.Warn(cause, "can simplify `%s` to `%s`", cause, suggestion) 346} 347