1 //===--- ElseAfterReturnCheck.cpp - clang-tidy-----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "ElseAfterReturnCheck.h"
10 #include "clang/AST/ASTContext.h"
11 #include "clang/ASTMatchers/ASTMatchFinder.h"
12 #include "clang/Lex/Lexer.h"
13 #include "clang/Tooling/FixIt.h"
14 #include "llvm/ADT/SmallVector.h"
15 
16 using namespace clang::ast_matchers;
17 
18 namespace clang {
19 namespace tidy {
20 namespace readability {
21 
// Binding names used by the matcher. Each one doubles as the keyword that is
// substituted into the diagnostic text.
static constexpr char ReturnStr[] = "return";
static constexpr char ContinueStr[] = "continue";
static constexpr char BreakStr[] = "break";
static constexpr char ThrowStr[] = "throw";

// Diagnostic text; '%0' receives one of the keywords above.
static constexpr char WarningMessage[] = "do not use 'else' after '%0'";

// Keys of the check's configuration options.
static constexpr char WarnOnUnfixableStr[] = "WarnOnUnfixable";
static constexpr char WarnOnConditionVariablesStr[] =
    "WarnOnConditionVariables";
29 
findUsage(const Stmt * Node,int64_t DeclIdentifier)30 static const DeclRefExpr *findUsage(const Stmt *Node, int64_t DeclIdentifier) {
31   if (!Node)
32     return nullptr;
33   if (const auto *DeclRef = dyn_cast<DeclRefExpr>(Node)) {
34     if (DeclRef->getDecl()->getID() == DeclIdentifier)
35       return DeclRef;
36   } else {
37     for (const Stmt *ChildNode : Node->children()) {
38       if (const DeclRefExpr *Result = findUsage(ChildNode, DeclIdentifier))
39         return Result;
40     }
41   }
42   return nullptr;
43 }
44 
45 static const DeclRefExpr *
findUsageRange(const Stmt * Node,const llvm::ArrayRef<int64_t> & DeclIdentifiers)46 findUsageRange(const Stmt *Node,
47                const llvm::ArrayRef<int64_t> &DeclIdentifiers) {
48   if (!Node)
49     return nullptr;
50   if (const auto *DeclRef = dyn_cast<DeclRefExpr>(Node)) {
51     if (llvm::is_contained(DeclIdentifiers, DeclRef->getDecl()->getID()))
52       return DeclRef;
53   } else {
54     for (const Stmt *ChildNode : Node->children()) {
55       if (const DeclRefExpr *Result =
56               findUsageRange(ChildNode, DeclIdentifiers))
57         return Result;
58     }
59   }
60   return nullptr;
61 }
62 
checkInitDeclUsageInElse(const IfStmt * If)63 static const DeclRefExpr *checkInitDeclUsageInElse(const IfStmt *If) {
64   const auto *InitDeclStmt = dyn_cast_or_null<DeclStmt>(If->getInit());
65   if (!InitDeclStmt)
66     return nullptr;
67   if (InitDeclStmt->isSingleDecl()) {
68     const Decl *InitDecl = InitDeclStmt->getSingleDecl();
69     assert(isa<VarDecl>(InitDecl) && "SingleDecl must be a VarDecl");
70     return findUsage(If->getElse(), InitDecl->getID());
71   }
72   llvm::SmallVector<int64_t, 4> DeclIdentifiers;
73   for (const Decl *ChildDecl : InitDeclStmt->decls()) {
74     assert(isa<VarDecl>(ChildDecl) && "Init Decls must be a VarDecl");
75     DeclIdentifiers.push_back(ChildDecl->getID());
76   }
77   return findUsageRange(If->getElse(), DeclIdentifiers);
78 }
79 
checkConditionVarUsageInElse(const IfStmt * If)80 static const DeclRefExpr *checkConditionVarUsageInElse(const IfStmt *If) {
81   if (const VarDecl *CondVar = If->getConditionVariable())
82     return findUsage(If->getElse(), CondVar->getID());
83   return nullptr;
84 }
85 
containsDeclInScope(const Stmt * Node)86 static bool containsDeclInScope(const Stmt *Node) {
87   if (isa<DeclStmt>(Node))
88     return true;
89   if (const auto *Compound = dyn_cast<CompoundStmt>(Node))
90     return llvm::any_of(Compound->body(), [](const Stmt *SubNode) {
91       return isa<DeclStmt>(SubNode);
92     });
93   return false;
94 }
95 
removeElseAndBrackets(DiagnosticBuilder & Diag,ASTContext & Context,const Stmt * Else,SourceLocation ElseLoc)96 static void removeElseAndBrackets(DiagnosticBuilder &Diag, ASTContext &Context,
97                            const Stmt *Else, SourceLocation ElseLoc) {
98   auto Remap = [&](SourceLocation Loc) {
99     return Context.getSourceManager().getExpansionLoc(Loc);
100   };
101   auto TokLen = [&](SourceLocation Loc) {
102     return Lexer::MeasureTokenLength(Loc, Context.getSourceManager(),
103                                      Context.getLangOpts());
104   };
105 
106   if (const auto *CS = dyn_cast<CompoundStmt>(Else)) {
107     Diag << tooling::fixit::createRemoval(ElseLoc);
108     SourceLocation LBrace = CS->getLBracLoc();
109     SourceLocation RBrace = CS->getRBracLoc();
110     SourceLocation RangeStart =
111         Remap(LBrace).getLocWithOffset(TokLen(LBrace) + 1);
112     SourceLocation RangeEnd = Remap(RBrace).getLocWithOffset(-1);
113 
114     llvm::StringRef Repl = Lexer::getSourceText(
115         CharSourceRange::getTokenRange(RangeStart, RangeEnd),
116         Context.getSourceManager(), Context.getLangOpts());
117     Diag << tooling::fixit::createReplacement(CS->getSourceRange(), Repl);
118   } else {
119     SourceLocation ElseExpandedLoc = Remap(ElseLoc);
120     SourceLocation EndLoc = Remap(Else->getEndLoc());
121 
122     llvm::StringRef Repl = Lexer::getSourceText(
123         CharSourceRange::getTokenRange(
124             ElseExpandedLoc.getLocWithOffset(TokLen(ElseLoc) + 1), EndLoc),
125         Context.getSourceManager(), Context.getLangOpts());
126     Diag << tooling::fixit::createReplacement(
127         SourceRange(ElseExpandedLoc, EndLoc), Repl);
128   }
129 }
130 
ElseAfterReturnCheck(StringRef Name,ClangTidyContext * Context)131 ElseAfterReturnCheck::ElseAfterReturnCheck(StringRef Name,
132                                            ClangTidyContext *Context)
133     : ClangTidyCheck(Name, Context),
134       WarnOnUnfixable(Options.get(WarnOnUnfixableStr, true)),
135       WarnOnConditionVariables(Options.get(WarnOnConditionVariablesStr, true)) {
136 }
137 
storeOptions(ClangTidyOptions::OptionMap & Opts)138 void ElseAfterReturnCheck::storeOptions(ClangTidyOptions::OptionMap &Opts) {
139   Options.store(Opts, WarnOnUnfixableStr, WarnOnUnfixable);
140   Options.store(Opts, WarnOnConditionVariablesStr, WarnOnConditionVariables);
141 }
142 
registerMatchers(MatchFinder * Finder)143 void ElseAfterReturnCheck::registerMatchers(MatchFinder *Finder) {
144   const auto InterruptsControlFlow =
145       stmt(anyOf(returnStmt().bind(ReturnStr), continueStmt().bind(ContinueStr),
146                  breakStmt().bind(BreakStr),
147                  expr(ignoringImplicit(cxxThrowExpr().bind(ThrowStr)))));
148   Finder->addMatcher(
149       compoundStmt(
150           forEach(ifStmt(unless(isConstexpr()),
151                          hasThen(stmt(
152                              anyOf(InterruptsControlFlow,
153                                    compoundStmt(has(InterruptsControlFlow))))),
154                          hasElse(stmt().bind("else")))
155                       .bind("if")))
156           .bind("cs"),
157       this);
158 }
159 
check(const MatchFinder::MatchResult & Result)160 void ElseAfterReturnCheck::check(const MatchFinder::MatchResult &Result) {
161   const auto *If = Result.Nodes.getNodeAs<IfStmt>("if");
162   const auto *Else = Result.Nodes.getNodeAs<Stmt>("else");
163   const auto *OuterScope = Result.Nodes.getNodeAs<CompoundStmt>("cs");
164 
165   bool IsLastInScope = OuterScope->body_back() == If;
166   SourceLocation ElseLoc = If->getElseLoc();
167 
168   auto ControlFlowInterruptor = [&]() -> llvm::StringRef {
169     for (llvm::StringRef BindingName :
170          {ReturnStr, ContinueStr, BreakStr, ThrowStr})
171       if (Result.Nodes.getNodeAs<Stmt>(BindingName))
172         return BindingName;
173     return {};
174   }();
175 
176   if (!IsLastInScope && containsDeclInScope(Else)) {
177     if (WarnOnUnfixable) {
178       // Warn, but don't attempt an autofix.
179       diag(ElseLoc, WarningMessage) << ControlFlowInterruptor;
180     }
181     return;
182   }
183 
184   if (checkConditionVarUsageInElse(If) != nullptr) {
185     if (!WarnOnConditionVariables)
186       return;
187     if (IsLastInScope) {
188       // If the if statement is the last statement its enclosing statements
189       // scope, we can pull the decl out of the if statement.
190       DiagnosticBuilder Diag = diag(ElseLoc, WarningMessage)
191                                << ControlFlowInterruptor;
192       if (checkInitDeclUsageInElse(If) != nullptr) {
193         Diag << tooling::fixit::createReplacement(
194                     SourceRange(If->getIfLoc()),
195                     (tooling::fixit::getText(*If->getInit(), *Result.Context) +
196                      llvm::StringRef("\n"))
197                         .str())
198              << tooling::fixit::createRemoval(If->getInit()->getSourceRange());
199       }
200       const DeclStmt *VDeclStmt = If->getConditionVariableDeclStmt();
201       const VarDecl *VDecl = If->getConditionVariable();
202       std::string Repl =
203           (tooling::fixit::getText(*VDeclStmt, *Result.Context) +
204            llvm::StringRef(";\n") +
205            tooling::fixit::getText(If->getIfLoc(), *Result.Context))
206               .str();
207       Diag << tooling::fixit::createReplacement(SourceRange(If->getIfLoc()),
208                                                 Repl)
209            << tooling::fixit::createReplacement(VDeclStmt->getSourceRange(),
210                                                 VDecl->getName());
211       removeElseAndBrackets(Diag, *Result.Context, Else, ElseLoc);
212     } else if (WarnOnUnfixable) {
213       // Warn, but don't attempt an autofix.
214       diag(ElseLoc, WarningMessage) << ControlFlowInterruptor;
215     }
216     return;
217   }
218 
219   if (checkInitDeclUsageInElse(If) != nullptr) {
220     if (!WarnOnConditionVariables)
221       return;
222     if (IsLastInScope) {
223       // If the if statement is the last statement its enclosing statements
224       // scope, we can pull the decl out of the if statement.
225       DiagnosticBuilder Diag = diag(ElseLoc, WarningMessage)
226                                << ControlFlowInterruptor;
227       Diag << tooling::fixit::createReplacement(
228                   SourceRange(If->getIfLoc()),
229                   (tooling::fixit::getText(*If->getInit(), *Result.Context) +
230                    "\n" +
231                    tooling::fixit::getText(If->getIfLoc(), *Result.Context))
232                       .str())
233            << tooling::fixit::createRemoval(If->getInit()->getSourceRange());
234       removeElseAndBrackets(Diag, *Result.Context, Else, ElseLoc);
235     } else if (WarnOnUnfixable) {
236       // Warn, but don't attempt an autofix.
237       diag(ElseLoc, WarningMessage) << ControlFlowInterruptor;
238     }
239     return;
240   }
241 
242   DiagnosticBuilder Diag = diag(ElseLoc, WarningMessage)
243                            << ControlFlowInterruptor;
244   removeElseAndBrackets(Diag, *Result.Context, Else, ElseLoc);
245 }
246 
247 } // namespace readability
248 } // namespace tidy
249 } // namespace clang
250