1 //===--- LoopUnrolling.cpp - Unroll loops -----------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
/// This file contains functions which are used to decide whether a loop is
/// worth unrolling. These functions also manage the stack of loops that is
/// tracked by the ProgramState.
12 ///
13 //===----------------------------------------------------------------------===//
14
#include "clang/StaticAnalyzer/Core/PathSensitive/LoopUnrolling.h"
#include "clang/ASTMatchers/ASTMatchFinder.h"
#include "clang/ASTMatchers/ASTMatchers.h"
#include "clang/StaticAnalyzer/Core/PathSensitive/CallEvent.h"
#include "clang/StaticAnalyzer/Core/PathSensitive/CheckerContext.h"

#include <cstdint>
#include <limits>
20
21 using namespace clang;
22 using namespace ento;
23 using namespace clang::ast_matchers;
24
/// Upper bound on the total step count (a loop's own step count multiplied by
/// the max step of the innermost loop already on the loop stack) for which a
/// loop is still completely unrolled. Declared unsigned because it is only
/// compared against unsigned step counts.
static constexpr unsigned MAXIMUM_STEP_UNROLLED = 128;
26
27 struct LoopState {
28 private:
29 enum Kind { Normal, Unrolled } K;
30 const Stmt *LoopStmt;
31 const LocationContext *LCtx;
32 unsigned maxStep;
LoopStateLoopState33 LoopState(Kind InK, const Stmt *S, const LocationContext *L, unsigned N)
34 : K(InK), LoopStmt(S), LCtx(L), maxStep(N) {}
35
36 public:
getNormalLoopState37 static LoopState getNormal(const Stmt *S, const LocationContext *L,
38 unsigned N) {
39 return LoopState(Normal, S, L, N);
40 }
getUnrolledLoopState41 static LoopState getUnrolled(const Stmt *S, const LocationContext *L,
42 unsigned N) {
43 return LoopState(Unrolled, S, L, N);
44 }
isUnrolledLoopState45 bool isUnrolled() const { return K == Unrolled; }
getMaxStepLoopState46 unsigned getMaxStep() const { return maxStep; }
getLoopStmtLoopState47 const Stmt *getLoopStmt() const { return LoopStmt; }
getLocationContextLoopState48 const LocationContext *getLocationContext() const { return LCtx; }
operator ==LoopState49 bool operator==(const LoopState &X) const {
50 return K == X.K && LoopStmt == X.LoopStmt;
51 }
ProfileLoopState52 void Profile(llvm::FoldingSetNodeID &ID) const {
53 ID.AddInteger(K);
54 ID.AddPointer(LoopStmt);
55 ID.AddPointer(LCtx);
56 ID.AddInteger(maxStep);
57 }
58 };
59
// The tracked stack of loops. For the element currently being simulated, the
// stack records which loops it is nested in; each entry is marked according
// to whether we decided to unroll that loop.
// TODO: The loop stack should not need to be in the program state since it is
// lexical in nature. Instead, the stack of loops should be tracked in the
// LocationContext.
REGISTER_LIST_WITH_PROGRAMSTATE(LoopStack, LoopState)
67
68 namespace clang {
69 namespace ento {
70
isLoopStmt(const Stmt * S)71 static bool isLoopStmt(const Stmt *S) {
72 return S && (isa<ForStmt>(S) || isa<WhileStmt>(S) || isa<DoStmt>(S));
73 }
74
processLoopEnd(const Stmt * LoopStmt,ProgramStateRef State)75 ProgramStateRef processLoopEnd(const Stmt *LoopStmt, ProgramStateRef State) {
76 auto LS = State->get<LoopStack>();
77 if (!LS.isEmpty() && LS.getHead().getLoopStmt() == LoopStmt)
78 State = State->set<LoopStack>(LS.getTail());
79 return State;
80 }
81
simpleCondition(StringRef BindName,StringRef RefName)82 static internal::Matcher<Stmt> simpleCondition(StringRef BindName,
83 StringRef RefName) {
84 return binaryOperator(
85 anyOf(hasOperatorName("<"), hasOperatorName(">"),
86 hasOperatorName("<="), hasOperatorName(">="),
87 hasOperatorName("!=")),
88 hasEitherOperand(ignoringParenImpCasts(
89 declRefExpr(to(varDecl(hasType(isInteger())).bind(BindName)))
90 .bind(RefName))),
91 hasEitherOperand(
92 ignoringParenImpCasts(integerLiteral().bind("boundNum"))))
93 .bind("conditionOperator");
94 }
95
96 static internal::Matcher<Stmt>
changeIntBoundNode(internal::Matcher<Decl> VarNodeMatcher)97 changeIntBoundNode(internal::Matcher<Decl> VarNodeMatcher) {
98 return anyOf(
99 unaryOperator(anyOf(hasOperatorName("--"), hasOperatorName("++")),
100 hasUnaryOperand(ignoringParenImpCasts(
101 declRefExpr(to(varDecl(VarNodeMatcher)))))),
102 binaryOperator(isAssignmentOperator(),
103 hasLHS(ignoringParenImpCasts(
104 declRefExpr(to(varDecl(VarNodeMatcher)))))));
105 }
106
107 static internal::Matcher<Stmt>
callByRef(internal::Matcher<Decl> VarNodeMatcher)108 callByRef(internal::Matcher<Decl> VarNodeMatcher) {
109 return callExpr(forEachArgumentWithParam(
110 declRefExpr(to(varDecl(VarNodeMatcher))),
111 parmVarDecl(hasType(references(qualType(unless(isConstQualified())))))));
112 }
113
114 static internal::Matcher<Stmt>
assignedToRef(internal::Matcher<Decl> VarNodeMatcher)115 assignedToRef(internal::Matcher<Decl> VarNodeMatcher) {
116 return declStmt(hasDescendant(varDecl(
117 allOf(hasType(referenceType()),
118 hasInitializer(anyOf(
119 initListExpr(has(declRefExpr(to(varDecl(VarNodeMatcher))))),
120 declRefExpr(to(varDecl(VarNodeMatcher)))))))));
121 }
122
123 static internal::Matcher<Stmt>
getAddrTo(internal::Matcher<Decl> VarNodeMatcher)124 getAddrTo(internal::Matcher<Decl> VarNodeMatcher) {
125 return unaryOperator(
126 hasOperatorName("&"),
127 hasUnaryOperand(declRefExpr(hasDeclaration(VarNodeMatcher))));
128 }
129
hasSuspiciousStmt(StringRef NodeName)130 static internal::Matcher<Stmt> hasSuspiciousStmt(StringRef NodeName) {
131 return hasDescendant(stmt(
132 anyOf(gotoStmt(), switchStmt(), returnStmt(),
133 // Escaping and not known mutation of the loop counter is handled
134 // by exclusion of assigning and address-of operators and
135 // pass-by-ref function calls on the loop counter from the body.
136 changeIntBoundNode(equalsBoundNode(std::string(NodeName))),
137 callByRef(equalsBoundNode(std::string(NodeName))),
138 getAddrTo(equalsBoundNode(std::string(NodeName))),
139 assignedToRef(equalsBoundNode(std::string(NodeName))))));
140 }
141
forLoopMatcher()142 static internal::Matcher<Stmt> forLoopMatcher() {
143 return forStmt(
144 hasCondition(simpleCondition("initVarName", "initVarRef")),
145 // Initialization should match the form: 'int i = 6' or 'i = 42'.
146 hasLoopInit(
147 anyOf(declStmt(hasSingleDecl(
148 varDecl(allOf(hasInitializer(ignoringParenImpCasts(
149 integerLiteral().bind("initNum"))),
150 equalsBoundNode("initVarName"))))),
151 binaryOperator(hasLHS(declRefExpr(to(varDecl(
152 equalsBoundNode("initVarName"))))),
153 hasRHS(ignoringParenImpCasts(
154 integerLiteral().bind("initNum")))))),
155 // Incrementation should be a simple increment or decrement
156 // operator call.
157 hasIncrement(unaryOperator(
158 anyOf(hasOperatorName("++"), hasOperatorName("--")),
159 hasUnaryOperand(declRefExpr(
160 to(varDecl(allOf(equalsBoundNode("initVarName"),
161 hasType(isInteger())))))))),
162 unless(hasBody(hasSuspiciousStmt("initVarName"))))
163 .bind("forLoop");
164 }
165
isCapturedByReference(ExplodedNode * N,const DeclRefExpr * DR)166 static bool isCapturedByReference(ExplodedNode *N, const DeclRefExpr *DR) {
167
168 // Get the lambda CXXRecordDecl
169 assert(DR->refersToEnclosingVariableOrCapture());
170 const LocationContext *LocCtxt = N->getLocationContext();
171 const Decl *D = LocCtxt->getDecl();
172 const auto *MD = cast<CXXMethodDecl>(D);
173 assert(MD && MD->getParent()->isLambda() &&
174 "Captured variable should only be seen while evaluating a lambda");
175 const CXXRecordDecl *LambdaCXXRec = MD->getParent();
176
177 // Lookup the fields of the lambda
178 llvm::DenseMap<const VarDecl *, FieldDecl *> LambdaCaptureFields;
179 FieldDecl *LambdaThisCaptureField;
180 LambdaCXXRec->getCaptureFields(LambdaCaptureFields, LambdaThisCaptureField);
181
182 // Check if the counter is captured by reference
183 const VarDecl *VD = cast<VarDecl>(DR->getDecl()->getCanonicalDecl());
184 assert(VD);
185 const FieldDecl *FD = LambdaCaptureFields[VD];
186 assert(FD && "Captured variable without a corresponding field");
187 return FD->getType()->isReferenceType();
188 }
189
// A loop counter is considered escaped if:
// case 1: It is a global variable.
// case 2: It is a reference parameter or a reference capture.
// case 3: It is assigned to a non-const reference variable or parameter.
// case 4: Has its address taken.
//
// Cases 1 and 2 are decided from the variable referenced by DR. Cases 3 and 4
// are decided by walking backwards from N along first predecessors, matching
// each node's statement against callByRef(), getAddrTo() and assignedToRef().
// The walk stops with "not escaped" once it reaches the DeclStmt declaring
// the variable.
// NOTE(review): assignedToRef() does not check constness, so binding even a
// const reference to the counter is reported as an escape (conservative).
static bool isPossiblyEscaped(ExplodedNode *N, const DeclRefExpr *DR) {
  const VarDecl *VD = cast<VarDecl>(DR->getDecl()->getCanonicalDecl());
  assert(VD);
  // Case 1:
  if (VD->hasGlobalStorage())
    return true;

  // Despite the name, this is true for every parameter (not only reference
  // ones) and for every captured variable.
  const bool IsRefParamOrCapture =
      isa<ParmVarDecl>(VD) || DR->refersToEnclosingVariableOrCapture();
  // Case 2:
  if ((DR->refersToEnclosingVariableOrCapture() &&
       isCapturedByReference(N, DR)) ||
      (IsRefParamOrCapture && VD->getType()->isReferenceType()))
    return true;

  // Walk the path backwards one first-predecessor at a time until the
  // declaration of VD (or the root) is reached.
  while (!N->pred_empty()) {
    // FIXME: getStmtForDiagnostics() does nasty things in order to provide
    // a valid statement for body farms, do we need this behavior here?
    const Stmt *S = N->getStmtForDiagnostics();
    if (!S) {
      N = N->getFirstPred();
      continue;
    }

    if (const DeclStmt *DS = dyn_cast<DeclStmt>(S)) {
      for (const Decl *D : DS->decls()) {
        // Once we reach the declaration of the VD we can return.
        if (D->getCanonicalDecl() == VD)
          return false;
      }
    }
    // Check the usage of the pass-by-ref function calls and address-of
    // operator on VD and reference initialized by VD.
    ASTContext &ASTCtx =
        N->getLocationContext()->getAnalysisDeclContext()->getASTContext();
    // Case 3 and 4:
    auto Match =
        match(stmt(anyOf(callByRef(equalsNode(VD)), getAddrTo(equalsNode(VD)),
                         assignedToRef(equalsNode(VD)))),
              *S, ASTCtx);
    if (!Match.empty())
      return true;

    N = N->getFirstPred();
  }

  // Parameters are not declared by a DeclStmt, and a captured variable's
  // declaration may not lie on this path either, so reaching the root is
  // expected for them (reference-typed ones already returned true above).
  if (IsRefParamOrCapture)
    return false;

  llvm_unreachable("Reached root without finding the declaration of VD");
}
247
shouldCompletelyUnroll(const Stmt * LoopStmt,ASTContext & ASTCtx,ExplodedNode * Pred,unsigned & maxStep)248 bool shouldCompletelyUnroll(const Stmt *LoopStmt, ASTContext &ASTCtx,
249 ExplodedNode *Pred, unsigned &maxStep) {
250
251 if (!isLoopStmt(LoopStmt))
252 return false;
253
254 // TODO: Match the cases where the bound is not a concrete literal but an
255 // integer with known value
256 auto Matches = match(forLoopMatcher(), *LoopStmt, ASTCtx);
257 if (Matches.empty())
258 return false;
259
260 const auto *CounterVarRef = Matches[0].getNodeAs<DeclRefExpr>("initVarRef");
261 llvm::APInt BoundNum =
262 Matches[0].getNodeAs<IntegerLiteral>("boundNum")->getValue();
263 llvm::APInt InitNum =
264 Matches[0].getNodeAs<IntegerLiteral>("initNum")->getValue();
265 auto CondOp = Matches[0].getNodeAs<BinaryOperator>("conditionOperator");
266 if (InitNum.getBitWidth() != BoundNum.getBitWidth()) {
267 InitNum = InitNum.zextOrSelf(BoundNum.getBitWidth());
268 BoundNum = BoundNum.zextOrSelf(InitNum.getBitWidth());
269 }
270
271 if (CondOp->getOpcode() == BO_GE || CondOp->getOpcode() == BO_LE)
272 maxStep = (BoundNum - InitNum + 1).abs().getZExtValue();
273 else
274 maxStep = (BoundNum - InitNum).abs().getZExtValue();
275
276 // Check if the counter of the loop is not escaped before.
277 return !isPossiblyEscaped(Pred, CounterVarRef);
278 }
279
madeNewBranch(ExplodedNode * N,const Stmt * LoopStmt)280 bool madeNewBranch(ExplodedNode *N, const Stmt *LoopStmt) {
281 const Stmt *S = nullptr;
282 while (!N->pred_empty()) {
283 if (N->succ_size() > 1)
284 return true;
285
286 ProgramPoint P = N->getLocation();
287 if (Optional<BlockEntrance> BE = P.getAs<BlockEntrance>())
288 S = BE->getBlock()->getTerminatorStmt();
289
290 if (S == LoopStmt)
291 return false;
292
293 N = N->getFirstPred();
294 }
295
296 llvm_unreachable("Reached root without encountering the previous step");
297 }
298
299 // updateLoopStack is called on every basic block, therefore it needs to be fast
updateLoopStack(const Stmt * LoopStmt,ASTContext & ASTCtx,ExplodedNode * Pred,unsigned maxVisitOnPath)300 ProgramStateRef updateLoopStack(const Stmt *LoopStmt, ASTContext &ASTCtx,
301 ExplodedNode *Pred, unsigned maxVisitOnPath) {
302 auto State = Pred->getState();
303 auto LCtx = Pred->getLocationContext();
304
305 if (!isLoopStmt(LoopStmt))
306 return State;
307
308 auto LS = State->get<LoopStack>();
309 if (!LS.isEmpty() && LoopStmt == LS.getHead().getLoopStmt() &&
310 LCtx == LS.getHead().getLocationContext()) {
311 if (LS.getHead().isUnrolled() && madeNewBranch(Pred, LoopStmt)) {
312 State = State->set<LoopStack>(LS.getTail());
313 State = State->add<LoopStack>(
314 LoopState::getNormal(LoopStmt, LCtx, maxVisitOnPath));
315 }
316 return State;
317 }
318 unsigned maxStep;
319 if (!shouldCompletelyUnroll(LoopStmt, ASTCtx, Pred, maxStep)) {
320 State = State->add<LoopStack>(
321 LoopState::getNormal(LoopStmt, LCtx, maxVisitOnPath));
322 return State;
323 }
324
325 unsigned outerStep = (LS.isEmpty() ? 1 : LS.getHead().getMaxStep());
326
327 unsigned innerMaxStep = maxStep * outerStep;
328 if (innerMaxStep > MAXIMUM_STEP_UNROLLED)
329 State = State->add<LoopStack>(
330 LoopState::getNormal(LoopStmt, LCtx, maxVisitOnPath));
331 else
332 State = State->add<LoopStack>(
333 LoopState::getUnrolled(LoopStmt, LCtx, innerMaxStep));
334 return State;
335 }
336
isUnrolledState(ProgramStateRef State)337 bool isUnrolledState(ProgramStateRef State) {
338 auto LS = State->get<LoopStack>();
339 if (LS.isEmpty() || !LS.getHead().isUnrolled())
340 return false;
341 return true;
342 }
343 }
344 }
345