1 //===--- LoopUnrolling.cpp - Unroll loops -----------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
/// This file contains functions which are used to decide if a loop is worth
/// unrolling. Moreover, these functions manage the stack of loops that is
/// tracked by the ProgramState.
12 ///
13 //===----------------------------------------------------------------------===//
14
#include "clang/ASTMatchers/ASTMatchers.h"
#include "clang/ASTMatchers/ASTMatchFinder.h"
#include "clang/StaticAnalyzer/Core/PathSensitive/CallEvent.h"
#include "clang/StaticAnalyzer/Core/PathSensitive/CheckerContext.h"
#include "clang/StaticAnalyzer/Core/PathSensitive/LoopUnrolling.h"

#include <cstdint>
20
21 using namespace clang;
22 using namespace ento;
23 using namespace clang::ast_matchers;
24
/// Upper limit on the combined step count (the loop's own estimated step
/// count multiplied by the step budget of the enclosing loop-stack entry) up
/// to which a loop is completely unrolled.
static constexpr int MAXIMUM_STEP_UNROLLED = 128;
26
27 struct LoopState {
28 private:
29 enum Kind { Normal, Unrolled } K;
30 const Stmt *LoopStmt;
31 const LocationContext *LCtx;
32 unsigned maxStep;
LoopStateLoopState33 LoopState(Kind InK, const Stmt *S, const LocationContext *L, unsigned N)
34 : K(InK), LoopStmt(S), LCtx(L), maxStep(N) {}
35
36 public:
getNormalLoopState37 static LoopState getNormal(const Stmt *S, const LocationContext *L,
38 unsigned N) {
39 return LoopState(Normal, S, L, N);
40 }
getUnrolledLoopState41 static LoopState getUnrolled(const Stmt *S, const LocationContext *L,
42 unsigned N) {
43 return LoopState(Unrolled, S, L, N);
44 }
isUnrolledLoopState45 bool isUnrolled() const { return K == Unrolled; }
getMaxStepLoopState46 unsigned getMaxStep() const { return maxStep; }
getLoopStmtLoopState47 const Stmt *getLoopStmt() const { return LoopStmt; }
getLocationContextLoopState48 const LocationContext *getLocationContext() const { return LCtx; }
operator ==LoopState49 bool operator==(const LoopState &X) const {
50 return K == X.K && LoopStmt == X.LoopStmt;
51 }
ProfileLoopState52 void Profile(llvm::FoldingSetNodeID &ID) const {
53 ID.AddInteger(K);
54 ID.AddPointer(LoopStmt);
55 ID.AddPointer(LCtx);
56 ID.AddInteger(maxStep);
57 }
58 };
59
// The tracked stack of loops. Each entry is a loop that encloses the element
// currently being simulated; entries are pushed with add() when a loop is
// entered, so the innermost loop is at the head of the list. Each entry also
// records whether we decided to unroll that loop.
// TODO: The loop stack should not need to be in the program state since it is
// lexical in nature. Instead, the stack of loops should be tracked in the
// LocationContext.
REGISTER_LIST_WITH_PROGRAMSTATE(LoopStack, LoopState)
67
68 namespace clang {
69 namespace ento {
70
isLoopStmt(const Stmt * S)71 static bool isLoopStmt(const Stmt *S) {
72 return S && (isa<ForStmt>(S) || isa<WhileStmt>(S) || isa<DoStmt>(S));
73 }
74
processLoopEnd(const Stmt * LoopStmt,ProgramStateRef State)75 ProgramStateRef processLoopEnd(const Stmt *LoopStmt, ProgramStateRef State) {
76 auto LS = State->get<LoopStack>();
77 if (!LS.isEmpty() && LS.getHead().getLoopStmt() == LoopStmt)
78 State = State->set<LoopStack>(LS.getTail());
79 return State;
80 }
81
simpleCondition(StringRef BindName)82 static internal::Matcher<Stmt> simpleCondition(StringRef BindName) {
83 return binaryOperator(anyOf(hasOperatorName("<"), hasOperatorName(">"),
84 hasOperatorName("<="), hasOperatorName(">="),
85 hasOperatorName("!=")),
86 hasEitherOperand(ignoringParenImpCasts(declRefExpr(
87 to(varDecl(hasType(isInteger())).bind(BindName))))),
88 hasEitherOperand(ignoringParenImpCasts(
89 integerLiteral().bind("boundNum"))))
90 .bind("conditionOperator");
91 }
92
93 static internal::Matcher<Stmt>
changeIntBoundNode(internal::Matcher<Decl> VarNodeMatcher)94 changeIntBoundNode(internal::Matcher<Decl> VarNodeMatcher) {
95 return anyOf(
96 unaryOperator(anyOf(hasOperatorName("--"), hasOperatorName("++")),
97 hasUnaryOperand(ignoringParenImpCasts(
98 declRefExpr(to(varDecl(VarNodeMatcher)))))),
99 binaryOperator(isAssignmentOperator(),
100 hasLHS(ignoringParenImpCasts(
101 declRefExpr(to(varDecl(VarNodeMatcher)))))));
102 }
103
104 static internal::Matcher<Stmt>
callByRef(internal::Matcher<Decl> VarNodeMatcher)105 callByRef(internal::Matcher<Decl> VarNodeMatcher) {
106 return callExpr(forEachArgumentWithParam(
107 declRefExpr(to(varDecl(VarNodeMatcher))),
108 parmVarDecl(hasType(references(qualType(unless(isConstQualified())))))));
109 }
110
111 static internal::Matcher<Stmt>
assignedToRef(internal::Matcher<Decl> VarNodeMatcher)112 assignedToRef(internal::Matcher<Decl> VarNodeMatcher) {
113 return declStmt(hasDescendant(varDecl(
114 allOf(hasType(referenceType()),
115 hasInitializer(anyOf(
116 initListExpr(has(declRefExpr(to(varDecl(VarNodeMatcher))))),
117 declRefExpr(to(varDecl(VarNodeMatcher)))))))));
118 }
119
120 static internal::Matcher<Stmt>
getAddrTo(internal::Matcher<Decl> VarNodeMatcher)121 getAddrTo(internal::Matcher<Decl> VarNodeMatcher) {
122 return unaryOperator(
123 hasOperatorName("&"),
124 hasUnaryOperand(declRefExpr(hasDeclaration(VarNodeMatcher))));
125 }
126
hasSuspiciousStmt(StringRef NodeName)127 static internal::Matcher<Stmt> hasSuspiciousStmt(StringRef NodeName) {
128 return hasDescendant(stmt(
129 anyOf(gotoStmt(), switchStmt(), returnStmt(),
130 // Escaping and not known mutation of the loop counter is handled
131 // by exclusion of assigning and address-of operators and
132 // pass-by-ref function calls on the loop counter from the body.
133 changeIntBoundNode(equalsBoundNode(std::string(NodeName))),
134 callByRef(equalsBoundNode(std::string(NodeName))),
135 getAddrTo(equalsBoundNode(std::string(NodeName))),
136 assignedToRef(equalsBoundNode(std::string(NodeName))))));
137 }
138
forLoopMatcher()139 static internal::Matcher<Stmt> forLoopMatcher() {
140 return forStmt(
141 hasCondition(simpleCondition("initVarName")),
142 // Initialization should match the form: 'int i = 6' or 'i = 42'.
143 hasLoopInit(
144 anyOf(declStmt(hasSingleDecl(
145 varDecl(allOf(hasInitializer(ignoringParenImpCasts(
146 integerLiteral().bind("initNum"))),
147 equalsBoundNode("initVarName"))))),
148 binaryOperator(hasLHS(declRefExpr(to(varDecl(
149 equalsBoundNode("initVarName"))))),
150 hasRHS(ignoringParenImpCasts(
151 integerLiteral().bind("initNum")))))),
152 // Incrementation should be a simple increment or decrement
153 // operator call.
154 hasIncrement(unaryOperator(
155 anyOf(hasOperatorName("++"), hasOperatorName("--")),
156 hasUnaryOperand(declRefExpr(
157 to(varDecl(allOf(equalsBoundNode("initVarName"),
158 hasType(isInteger())))))))),
159 unless(hasBody(hasSuspiciousStmt("initVarName")))).bind("forLoop");
160 }
161
isPossiblyEscaped(const VarDecl * VD,ExplodedNode * N)162 static bool isPossiblyEscaped(const VarDecl *VD, ExplodedNode *N) {
163 // Global variables assumed as escaped variables.
164 if (VD->hasGlobalStorage())
165 return true;
166
167 const bool isParm = isa<ParmVarDecl>(VD);
168 // Reference parameters are assumed as escaped variables.
169 if (isParm && VD->getType()->isReferenceType())
170 return true;
171
172 while (!N->pred_empty()) {
173 // FIXME: getStmtForDiagnostics() does nasty things in order to provide
174 // a valid statement for body farms, do we need this behavior here?
175 const Stmt *S = N->getStmtForDiagnostics();
176 if (!S) {
177 N = N->getFirstPred();
178 continue;
179 }
180
181 if (const DeclStmt *DS = dyn_cast<DeclStmt>(S)) {
182 for (const Decl *D : DS->decls()) {
183 // Once we reach the declaration of the VD we can return.
184 if (D->getCanonicalDecl() == VD)
185 return false;
186 }
187 }
188 // Check the usage of the pass-by-ref function calls and adress-of operator
189 // on VD and reference initialized by VD.
190 ASTContext &ASTCtx =
191 N->getLocationContext()->getAnalysisDeclContext()->getASTContext();
192 auto Match =
193 match(stmt(anyOf(callByRef(equalsNode(VD)), getAddrTo(equalsNode(VD)),
194 assignedToRef(equalsNode(VD)))),
195 *S, ASTCtx);
196 if (!Match.empty())
197 return true;
198
199 N = N->getFirstPred();
200 }
201
202 // Parameter declaration will not be found.
203 if (isParm)
204 return false;
205
206 llvm_unreachable("Reached root without finding the declaration of VD");
207 }
208
shouldCompletelyUnroll(const Stmt * LoopStmt,ASTContext & ASTCtx,ExplodedNode * Pred,unsigned & maxStep)209 bool shouldCompletelyUnroll(const Stmt *LoopStmt, ASTContext &ASTCtx,
210 ExplodedNode *Pred, unsigned &maxStep) {
211
212 if (!isLoopStmt(LoopStmt))
213 return false;
214
215 // TODO: Match the cases where the bound is not a concrete literal but an
216 // integer with known value
217 auto Matches = match(forLoopMatcher(), *LoopStmt, ASTCtx);
218 if (Matches.empty())
219 return false;
220
221 auto CounterVar = Matches[0].getNodeAs<VarDecl>("initVarName");
222 llvm::APInt BoundNum =
223 Matches[0].getNodeAs<IntegerLiteral>("boundNum")->getValue();
224 llvm::APInt InitNum =
225 Matches[0].getNodeAs<IntegerLiteral>("initNum")->getValue();
226 auto CondOp = Matches[0].getNodeAs<BinaryOperator>("conditionOperator");
227 if (InitNum.getBitWidth() != BoundNum.getBitWidth()) {
228 InitNum = InitNum.zextOrSelf(BoundNum.getBitWidth());
229 BoundNum = BoundNum.zextOrSelf(InitNum.getBitWidth());
230 }
231
232 if (CondOp->getOpcode() == BO_GE || CondOp->getOpcode() == BO_LE)
233 maxStep = (BoundNum - InitNum + 1).abs().getZExtValue();
234 else
235 maxStep = (BoundNum - InitNum).abs().getZExtValue();
236
237 // Check if the counter of the loop is not escaped before.
238 return !isPossiblyEscaped(CounterVar->getCanonicalDecl(), Pred);
239 }
240
madeNewBranch(ExplodedNode * N,const Stmt * LoopStmt)241 bool madeNewBranch(ExplodedNode *N, const Stmt *LoopStmt) {
242 const Stmt *S = nullptr;
243 while (!N->pred_empty()) {
244 if (N->succ_size() > 1)
245 return true;
246
247 ProgramPoint P = N->getLocation();
248 if (Optional<BlockEntrance> BE = P.getAs<BlockEntrance>())
249 S = BE->getBlock()->getTerminatorStmt();
250
251 if (S == LoopStmt)
252 return false;
253
254 N = N->getFirstPred();
255 }
256
257 llvm_unreachable("Reached root without encountering the previous step");
258 }
259
260 // updateLoopStack is called on every basic block, therefore it needs to be fast
updateLoopStack(const Stmt * LoopStmt,ASTContext & ASTCtx,ExplodedNode * Pred,unsigned maxVisitOnPath)261 ProgramStateRef updateLoopStack(const Stmt *LoopStmt, ASTContext &ASTCtx,
262 ExplodedNode *Pred, unsigned maxVisitOnPath) {
263 auto State = Pred->getState();
264 auto LCtx = Pred->getLocationContext();
265
266 if (!isLoopStmt(LoopStmt))
267 return State;
268
269 auto LS = State->get<LoopStack>();
270 if (!LS.isEmpty() && LoopStmt == LS.getHead().getLoopStmt() &&
271 LCtx == LS.getHead().getLocationContext()) {
272 if (LS.getHead().isUnrolled() && madeNewBranch(Pred, LoopStmt)) {
273 State = State->set<LoopStack>(LS.getTail());
274 State = State->add<LoopStack>(
275 LoopState::getNormal(LoopStmt, LCtx, maxVisitOnPath));
276 }
277 return State;
278 }
279 unsigned maxStep;
280 if (!shouldCompletelyUnroll(LoopStmt, ASTCtx, Pred, maxStep)) {
281 State = State->add<LoopStack>(
282 LoopState::getNormal(LoopStmt, LCtx, maxVisitOnPath));
283 return State;
284 }
285
286 unsigned outerStep = (LS.isEmpty() ? 1 : LS.getHead().getMaxStep());
287
288 unsigned innerMaxStep = maxStep * outerStep;
289 if (innerMaxStep > MAXIMUM_STEP_UNROLLED)
290 State = State->add<LoopStack>(
291 LoopState::getNormal(LoopStmt, LCtx, maxVisitOnPath));
292 else
293 State = State->add<LoopStack>(
294 LoopState::getUnrolled(LoopStmt, LCtx, innerMaxStep));
295 return State;
296 }
297
isUnrolledState(ProgramStateRef State)298 bool isUnrolledState(ProgramStateRef State) {
299 auto LS = State->get<LoopStack>();
300 if (LS.isEmpty() || !LS.getHead().isUnrolled())
301 return false;
302 return true;
303 }
304 }
305 }
306