1 //===--- ElseAfterReturnCheck.cpp - clang-tidy-----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "ElseAfterReturnCheck.h"
10 #include "clang/AST/ASTContext.h"
11 #include "clang/ASTMatchers/ASTMatchFinder.h"
12 #include "clang/Lex/Lexer.h"
13 #include "clang/Lex/Preprocessor.h"
14 #include "clang/Tooling/FixIt.h"
15 #include "llvm/ADT/SmallVector.h"
16 
17 using namespace clang::ast_matchers;
18 
19 namespace clang {
20 namespace tidy {
21 namespace readability {
22 
23 namespace {
24 
25 class PPConditionalCollector : public PPCallbacks {
26 public:
PPConditionalCollector(ElseAfterReturnCheck::ConditionalBranchMap & Collections,const SourceManager & SM)27   PPConditionalCollector(
28       ElseAfterReturnCheck::ConditionalBranchMap &Collections,
29       const SourceManager &SM)
30       : Collections(Collections), SM(SM) {}
Endif(SourceLocation Loc,SourceLocation IfLoc)31   void Endif(SourceLocation Loc, SourceLocation IfLoc) override {
32     if (!SM.isWrittenInSameFile(Loc, IfLoc))
33       return;
34     SmallVectorImpl<SourceRange> &Collection = Collections[SM.getFileID(Loc)];
35     assert(Collection.empty() || Collection.back().getEnd() < Loc);
36     Collection.emplace_back(IfLoc, Loc);
37   }
38 
39 private:
40   ElseAfterReturnCheck::ConditionalBranchMap &Collections;
41   const SourceManager &SM;
42 };
43 
44 } // namespace
45 
// Bind name shared by every control-flow interrupting statement in the matcher.
static constexpr char InterruptingStr[] = "interrupting";
// Diagnostic text; %0 is the interrupting keyword (return/continue/...).
static constexpr char WarningMessage[] = "do not use 'else' after '%0'";
// Option keys read in the constructor and written by storeOptions().
static constexpr char WarnOnUnfixableStr[] = "WarnOnUnfixable";
static constexpr char WarnOnConditionVariablesStr[] = "WarnOnConditionVariables";
50 
findUsage(const Stmt * Node,int64_t DeclIdentifier)51 static const DeclRefExpr *findUsage(const Stmt *Node, int64_t DeclIdentifier) {
52   if (!Node)
53     return nullptr;
54   if (const auto *DeclRef = dyn_cast<DeclRefExpr>(Node)) {
55     if (DeclRef->getDecl()->getID() == DeclIdentifier)
56       return DeclRef;
57   } else {
58     for (const Stmt *ChildNode : Node->children()) {
59       if (const DeclRefExpr *Result = findUsage(ChildNode, DeclIdentifier))
60         return Result;
61     }
62   }
63   return nullptr;
64 }
65 
66 static const DeclRefExpr *
findUsageRange(const Stmt * Node,const llvm::ArrayRef<int64_t> & DeclIdentifiers)67 findUsageRange(const Stmt *Node,
68                const llvm::ArrayRef<int64_t> &DeclIdentifiers) {
69   if (!Node)
70     return nullptr;
71   if (const auto *DeclRef = dyn_cast<DeclRefExpr>(Node)) {
72     if (llvm::is_contained(DeclIdentifiers, DeclRef->getDecl()->getID()))
73       return DeclRef;
74   } else {
75     for (const Stmt *ChildNode : Node->children()) {
76       if (const DeclRefExpr *Result =
77               findUsageRange(ChildNode, DeclIdentifiers))
78         return Result;
79     }
80   }
81   return nullptr;
82 }
83 
checkInitDeclUsageInElse(const IfStmt * If)84 static const DeclRefExpr *checkInitDeclUsageInElse(const IfStmt *If) {
85   const auto *InitDeclStmt = dyn_cast_or_null<DeclStmt>(If->getInit());
86   if (!InitDeclStmt)
87     return nullptr;
88   if (InitDeclStmt->isSingleDecl()) {
89     const Decl *InitDecl = InitDeclStmt->getSingleDecl();
90     assert(isa<VarDecl>(InitDecl) && "SingleDecl must be a VarDecl");
91     return findUsage(If->getElse(), InitDecl->getID());
92   }
93   llvm::SmallVector<int64_t, 4> DeclIdentifiers;
94   for (const Decl *ChildDecl : InitDeclStmt->decls()) {
95     assert(isa<VarDecl>(ChildDecl) && "Init Decls must be a VarDecl");
96     DeclIdentifiers.push_back(ChildDecl->getID());
97   }
98   return findUsageRange(If->getElse(), DeclIdentifiers);
99 }
100 
checkConditionVarUsageInElse(const IfStmt * If)101 static const DeclRefExpr *checkConditionVarUsageInElse(const IfStmt *If) {
102   if (const VarDecl *CondVar = If->getConditionVariable())
103     return findUsage(If->getElse(), CondVar->getID());
104   return nullptr;
105 }
106 
containsDeclInScope(const Stmt * Node)107 static bool containsDeclInScope(const Stmt *Node) {
108   if (isa<DeclStmt>(Node))
109     return true;
110   if (const auto *Compound = dyn_cast<CompoundStmt>(Node))
111     return llvm::any_of(Compound->body(), [](const Stmt *SubNode) {
112       return isa<DeclStmt>(SubNode);
113     });
114   return false;
115 }
116 
/// Attaches fix-its to \p Diag that remove the 'else' keyword at \p ElseLoc
/// and put the body of the else branch \p Else in its place.
///
/// Locations are mapped to expansion locations (Remap) before offsets are
/// applied. The offsets assume exactly one separator character after '{',
/// after 'else', and before '}'. NOTE(review): presumably a newline or space;
/// other layouts are not special-cased.
static void removeElseAndBrackets(DiagnosticBuilder &Diag, ASTContext &Context,
                           const Stmt *Else, SourceLocation ElseLoc) {
  // Maps a possibly macro-based location to its expansion location.
  auto Remap = [&](SourceLocation Loc) {
    return Context.getSourceManager().getExpansionLoc(Loc);
  };
  // Length in characters of the token starting at Loc.
  auto TokLen = [&](SourceLocation Loc) {
    return Lexer::MeasureTokenLength(Loc, Context.getSourceManager(),
                                     Context.getLangOpts());
  };

  if (const auto *CS = dyn_cast<CompoundStmt>(Else)) {
    // Braced else: drop the 'else' token...
    Diag << tooling::fixit::createRemoval(ElseLoc);
    SourceLocation LBrace = CS->getLBracLoc();
    SourceLocation RBrace = CS->getRBracLoc();
    // ...and compute the text between the braces, skipping '{' plus one
    // character after it and stopping one character before '}'.
    // NOTE(review): for an empty `{}` RangeStart lies past RangeEnd; confirm
    // getSourceText yields an empty replacement in that case.
    SourceLocation RangeStart =
        Remap(LBrace).getLocWithOffset(TokLen(LBrace) + 1);
    SourceLocation RangeEnd = Remap(RBrace).getLocWithOffset(-1);

    llvm::StringRef Repl = Lexer::getSourceText(
        CharSourceRange::getTokenRange(RangeStart, RangeEnd),
        Context.getSourceManager(), Context.getLangOpts());
    // Replace the whole braced block (braces included) with its inner text.
    // NOTE(review): unlike RangeStart/RangeEnd, this range is not remapped.
    Diag << tooling::fixit::createReplacement(CS->getSourceRange(), Repl);
  } else {
    // Unbraced else: replace everything from 'else' through the end of the
    // statement with the statement text that follows 'else' and one
    // separator character.
    SourceLocation ElseExpandedLoc = Remap(ElseLoc);
    SourceLocation EndLoc = Remap(Else->getEndLoc());

    llvm::StringRef Repl = Lexer::getSourceText(
        CharSourceRange::getTokenRange(
            ElseExpandedLoc.getLocWithOffset(TokLen(ElseLoc) + 1), EndLoc),
        Context.getSourceManager(), Context.getLangOpts());
    Diag << tooling::fixit::createReplacement(
        SourceRange(ElseExpandedLoc, EndLoc), Repl);
  }
}
151 
ElseAfterReturnCheck(StringRef Name,ClangTidyContext * Context)152 ElseAfterReturnCheck::ElseAfterReturnCheck(StringRef Name,
153                                            ClangTidyContext *Context)
154     : ClangTidyCheck(Name, Context),
155       WarnOnUnfixable(Options.get(WarnOnUnfixableStr, true)),
156       WarnOnConditionVariables(Options.get(WarnOnConditionVariablesStr, true)) {
157 }
158 
storeOptions(ClangTidyOptions::OptionMap & Opts)159 void ElseAfterReturnCheck::storeOptions(ClangTidyOptions::OptionMap &Opts) {
160   Options.store(Opts, WarnOnUnfixableStr, WarnOnUnfixable);
161   Options.store(Opts, WarnOnConditionVariablesStr, WarnOnConditionVariables);
162 }
163 
registerPPCallbacks(const SourceManager & SM,Preprocessor * PP,Preprocessor * ModuleExpanderPP)164 void ElseAfterReturnCheck::registerPPCallbacks(const SourceManager &SM,
165                                                Preprocessor *PP,
166                                                Preprocessor *ModuleExpanderPP) {
167   PP->addPPCallbacks(
168       std::make_unique<PPConditionalCollector>(this->PPConditionals, SM));
169 }
170 
registerMatchers(MatchFinder * Finder)171 void ElseAfterReturnCheck::registerMatchers(MatchFinder *Finder) {
172   const auto InterruptsControlFlow = stmt(anyOf(
173       returnStmt().bind(InterruptingStr), continueStmt().bind(InterruptingStr),
174       breakStmt().bind(InterruptingStr), cxxThrowExpr().bind(InterruptingStr)));
175   Finder->addMatcher(
176       compoundStmt(
177           forEach(ifStmt(unless(isConstexpr()),
178                          hasThen(stmt(
179                              anyOf(InterruptsControlFlow,
180                                    compoundStmt(has(InterruptsControlFlow))))),
181                          hasElse(stmt().bind("else")))
182                       .bind("if")))
183           .bind("cs"),
184       this);
185 }
186 
/// Returns true if a recorded preprocessor conditional block begins before
/// \p StartLoc and ends at or after \p StartLoc but before \p EndLoc. In other
/// words, an #endif lies between the two locations while its opening
/// directive precedes \p StartLoc. Both locations are compared as expansion
/// locations, using the ranges stored in \p ConditionalBranchMap for the file
/// containing \p EndLoc. Returns false if the locations are in different
/// files, expand to the same location, or no ranges are recorded for the file.
static bool hasPreprocessorBranchEndBetweenLocations(
    const ElseAfterReturnCheck::ConditionalBranchMap &ConditionalBranchMap,
    const SourceManager &SM, SourceLocation StartLoc, SourceLocation EndLoc) {

  // Compare positions in the file text, not inside macro definitions.
  SourceLocation ExpandedStartLoc = SM.getExpansionLoc(StartLoc);
  SourceLocation ExpandedEndLoc = SM.getExpansionLoc(EndLoc);
  if (!SM.isWrittenInSameFile(ExpandedStartLoc, ExpandedEndLoc))
    return false;

  // StartLoc and EndLoc expand to the same macro.
  if (ExpandedStartLoc == ExpandedEndLoc)
    return false;

  assert(ExpandedStartLoc < ExpandedEndLoc);

  auto Iter = ConditionalBranchMap.find(SM.getFileID(ExpandedEndLoc));

  if (Iter == ConditionalBranchMap.end() || Iter->getSecond().empty())
    return false;

  const SmallVectorImpl<SourceRange> &ConditionalBranches = Iter->getSecond();

  // The lower_bound below requires the ranges to be ordered by end location.
  assert(llvm::is_sorted(ConditionalBranches,
                         [](const SourceRange &LHS, const SourceRange &RHS) {
                           return LHS.getEnd() < RHS.getEnd();
                         }));

  // First conditional block that ends after ExpandedStartLoc.
  const auto *Begin =
      llvm::lower_bound(ConditionalBranches, ExpandedStartLoc,
                        [](const SourceRange &LHS, const SourceLocation &RHS) {
                          return LHS.getEnd() < RHS;
                        });
  const auto *End = ConditionalBranches.end();
  // Among blocks that end before ExpandedEndLoc, one that also began before
  // ExpandedStartLoc straddles the start location.
  for (; Begin != End && Begin->getEnd() < ExpandedEndLoc; ++Begin)
    if (Begin->getBegin() < ExpandedStartLoc)
      return true;
  return false;
}
226 
getControlFlowString(const Stmt & Stmt)227 static StringRef getControlFlowString(const Stmt &Stmt) {
228   if (isa<ReturnStmt>(Stmt))
229     return "return";
230   if (isa<ContinueStmt>(Stmt))
231     return "continue";
232   if (isa<BreakStmt>(Stmt))
233     return "break";
234   if (isa<CXXThrowExpr>(Stmt))
235     return "throw";
236   llvm_unreachable("Unknown control flow interruptor");
237 }
238 
check(const MatchFinder::MatchResult & Result)239 void ElseAfterReturnCheck::check(const MatchFinder::MatchResult &Result) {
240   const auto *If = Result.Nodes.getNodeAs<IfStmt>("if");
241   const auto *Else = Result.Nodes.getNodeAs<Stmt>("else");
242   const auto *OuterScope = Result.Nodes.getNodeAs<CompoundStmt>("cs");
243   const auto *Interrupt = Result.Nodes.getNodeAs<Stmt>(InterruptingStr);
244   SourceLocation ElseLoc = If->getElseLoc();
245 
246   if (hasPreprocessorBranchEndBetweenLocations(
247           PPConditionals, *Result.SourceManager, Interrupt->getBeginLoc(),
248           ElseLoc))
249     return;
250 
251   bool IsLastInScope = OuterScope->body_back() == If;
252   StringRef ControlFlowInterruptor = getControlFlowString(*Interrupt);
253 
254   if (!IsLastInScope && containsDeclInScope(Else)) {
255     if (WarnOnUnfixable) {
256       // Warn, but don't attempt an autofix.
257       diag(ElseLoc, WarningMessage) << ControlFlowInterruptor;
258     }
259     return;
260   }
261 
262   if (checkConditionVarUsageInElse(If) != nullptr) {
263     if (!WarnOnConditionVariables)
264       return;
265     if (IsLastInScope) {
266       // If the if statement is the last statement its enclosing statements
267       // scope, we can pull the decl out of the if statement.
268       DiagnosticBuilder Diag = diag(ElseLoc, WarningMessage)
269                                << ControlFlowInterruptor
270                                << SourceRange(ElseLoc);
271       if (checkInitDeclUsageInElse(If) != nullptr) {
272         Diag << tooling::fixit::createReplacement(
273                     SourceRange(If->getIfLoc()),
274                     (tooling::fixit::getText(*If->getInit(), *Result.Context) +
275                      llvm::StringRef("\n"))
276                         .str())
277              << tooling::fixit::createRemoval(If->getInit()->getSourceRange());
278       }
279       const DeclStmt *VDeclStmt = If->getConditionVariableDeclStmt();
280       const VarDecl *VDecl = If->getConditionVariable();
281       std::string Repl =
282           (tooling::fixit::getText(*VDeclStmt, *Result.Context) +
283            llvm::StringRef(";\n") +
284            tooling::fixit::getText(If->getIfLoc(), *Result.Context))
285               .str();
286       Diag << tooling::fixit::createReplacement(SourceRange(If->getIfLoc()),
287                                                 Repl)
288            << tooling::fixit::createReplacement(VDeclStmt->getSourceRange(),
289                                                 VDecl->getName());
290       removeElseAndBrackets(Diag, *Result.Context, Else, ElseLoc);
291     } else if (WarnOnUnfixable) {
292       // Warn, but don't attempt an autofix.
293       diag(ElseLoc, WarningMessage) << ControlFlowInterruptor;
294     }
295     return;
296   }
297 
298   if (checkInitDeclUsageInElse(If) != nullptr) {
299     if (!WarnOnConditionVariables)
300       return;
301     if (IsLastInScope) {
302       // If the if statement is the last statement its enclosing statements
303       // scope, we can pull the decl out of the if statement.
304       DiagnosticBuilder Diag = diag(ElseLoc, WarningMessage)
305                                << ControlFlowInterruptor
306                                << SourceRange(ElseLoc);
307       Diag << tooling::fixit::createReplacement(
308                   SourceRange(If->getIfLoc()),
309                   (tooling::fixit::getText(*If->getInit(), *Result.Context) +
310                    "\n" +
311                    tooling::fixit::getText(If->getIfLoc(), *Result.Context))
312                       .str())
313            << tooling::fixit::createRemoval(If->getInit()->getSourceRange());
314       removeElseAndBrackets(Diag, *Result.Context, Else, ElseLoc);
315     } else if (WarnOnUnfixable) {
316       // Warn, but don't attempt an autofix.
317       diag(ElseLoc, WarningMessage) << ControlFlowInterruptor;
318     }
319     return;
320   }
321 
322   DiagnosticBuilder Diag = diag(ElseLoc, WarningMessage)
323                            << ControlFlowInterruptor << SourceRange(ElseLoc);
324   removeElseAndBrackets(Diag, *Result.Context, Else, ElseLoc);
325 }
326 
327 } // namespace readability
328 } // namespace tidy
329 } // namespace clang
330