1package rule
2
3import (
4	"fmt"
5	"go/ast"
6	"go/token"
7
8	"github.com/mgechev/revive/lint"
9)
10
// IncrementDecrementRule lints assignments that add or subtract the integer
// literal 1 (x += 1, x -= 1) and suggests the equivalent x++ / x--.
type IncrementDecrementRule struct{}
13
14// Apply applies the rule to given file.
15func (r *IncrementDecrementRule) Apply(file *lint.File, _ lint.Arguments) []lint.Failure {
16	var failures []lint.Failure
17
18	fileAst := file.AST
19	walker := lintIncrementDecrement{
20		file: file,
21		onFailure: func(failure lint.Failure) {
22			failures = append(failures, failure)
23		},
24	}
25
26	ast.Walk(walker, fileAst)
27
28	return failures
29}
30
31// Name returns the rule name.
32func (r *IncrementDecrementRule) Name() string {
33	return "increment-decrement"
34}
35
36type lintIncrementDecrement struct {
37	file      *lint.File
38	fileAst   *ast.File
39	onFailure func(lint.Failure)
40}
41
42func (w lintIncrementDecrement) Visit(n ast.Node) ast.Visitor {
43	as, ok := n.(*ast.AssignStmt)
44	if !ok {
45		return w
46	}
47	if len(as.Lhs) != 1 {
48		return w
49	}
50	if !isOne(as.Rhs[0]) {
51		return w
52	}
53	var suffix string
54	switch as.Tok {
55	case token.ADD_ASSIGN:
56		suffix = "++"
57	case token.SUB_ASSIGN:
58		suffix = "--"
59	default:
60		return w
61	}
62	w.onFailure(lint.Failure{
63		Confidence: 0.8,
64		Node:       as,
65		Category:   "unary-op",
66		Failure:    fmt.Sprintf("should replace %s with %s%s", w.file.Render(as), w.file.Render(as.Lhs[0]), suffix),
67	})
68	return w
69}
70
// isOne reports whether expr is an integer literal spelled exactly "1".
// Other spellings of one (for example 0x1 or 1.0) are not matched, and a
// non-literal expression yields false.
func isOne(expr ast.Expr) bool {
	switch lit := expr.(type) {
	case *ast.BasicLit:
		return lit.Kind == token.INT && lit.Value == "1"
	default:
		return false
	}
}
75