1 //===--- UseAnyOfAllOfCheck.cpp - clang-tidy-------------------------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 
10 #include "UseAnyOfAllOfCheck.h"
11 #include "clang/AST/ASTContext.h"
12 #include "clang/ASTMatchers/ASTMatchFinder.h"
13 #include "clang/Analysis/Analyses/ExprMutationAnalyzer.h"
14 #include "clang/Frontend/CompilerInstance.h"
15 
16 using namespace clang::ast_matchers;
17 
18 namespace clang {
19 namespace {
20 /// Matches a Stmt whose parent is a CompoundStmt, and which is directly
21 /// followed by a Stmt matching the inner matcher.
AST_MATCHER_P(Stmt,nextStmt,ast_matchers::internal::Matcher<Stmt>,InnerMatcher)22 AST_MATCHER_P(Stmt, nextStmt, ast_matchers::internal::Matcher<Stmt>,
23               InnerMatcher) {
24   DynTypedNodeList Parents = Finder->getASTContext().getParents(Node);
25   if (Parents.size() != 1)
26     return false;
27 
28   auto *C = Parents[0].get<CompoundStmt>();
29   if (!C)
30     return false;
31 
32   const auto *I = llvm::find(C->body(), &Node);
33   assert(I != C->body_end() && "C is parent of Node");
34   if (++I == C->body_end())
35     return false; // Node is last statement.
36 
37   return InnerMatcher.matches(**I, Finder, Builder);
38 }
39 } // namespace
40 
41 namespace tidy {
42 namespace readability {
43 
registerMatchers(MatchFinder * Finder)44 void UseAnyOfAllOfCheck::registerMatchers(MatchFinder *Finder) {
45   auto Returns = [](bool V) {
46     return returnStmt(hasReturnValue(cxxBoolLiteral(equals(V))));
47   };
48 
49   auto ReturnsButNotTrue =
50       returnStmt(hasReturnValue(unless(cxxBoolLiteral(equals(true)))));
51   auto ReturnsButNotFalse =
52       returnStmt(hasReturnValue(unless(cxxBoolLiteral(equals(false)))));
53 
54   Finder->addMatcher(
55       cxxForRangeStmt(
56           nextStmt(Returns(false).bind("final_return")),
57           hasBody(allOf(hasDescendant(Returns(true)),
58                         unless(anyOf(hasDescendant(breakStmt()),
59                                      hasDescendant(gotoStmt()),
60                                      hasDescendant(ReturnsButNotTrue))))))
61           .bind("any_of_loop"),
62       this);
63 
64   Finder->addMatcher(
65       cxxForRangeStmt(
66           nextStmt(Returns(true).bind("final_return")),
67           hasBody(allOf(hasDescendant(Returns(false)),
68                         unless(anyOf(hasDescendant(breakStmt()),
69                                      hasDescendant(gotoStmt()),
70                                      hasDescendant(ReturnsButNotFalse))))))
71           .bind("all_of_loop"),
72       this);
73 }
74 
isViableLoop(const CXXForRangeStmt & S,ASTContext & Context)75 static bool isViableLoop(const CXXForRangeStmt &S, ASTContext &Context) {
76 
77   ExprMutationAnalyzer Mutations(*S.getBody(), Context);
78   if (Mutations.isMutated(S.getLoopVariable()))
79     return false;
80   const auto Matches =
81       match(findAll(declRefExpr().bind("decl_ref")), *S.getBody(), Context);
82 
83   return llvm::none_of(Matches, [&Mutations](auto &DeclRef) {
84     // TODO: allow modifications of loop-local variables
85     return Mutations.isMutated(
86         DeclRef.template getNodeAs<DeclRefExpr>("decl_ref")->getDecl());
87   });
88 }
89 
check(const MatchFinder::MatchResult & Result)90 void UseAnyOfAllOfCheck::check(const MatchFinder::MatchResult &Result) {
91 
92   if (const auto *S = Result.Nodes.getNodeAs<CXXForRangeStmt>("any_of_loop")) {
93     if (!isViableLoop(*S, *Result.Context))
94       return;
95 
96     diag(S->getForLoc(), "replace loop by 'std%select{|::ranges}0::any_of()'")
97         << getLangOpts().CPlusPlus20;
98   } else if (const auto *S =
99                  Result.Nodes.getNodeAs<CXXForRangeStmt>("all_of_loop")) {
100     if (!isViableLoop(*S, *Result.Context))
101       return;
102 
103     diag(S->getForLoc(), "replace loop by 'std%select{|::ranges}0::all_of()'")
104         << getLangOpts().CPlusPlus20;
105   }
106 }
107 
108 } // namespace readability
109 } // namespace tidy
110 } // namespace clang
111