1 //===--- LoopUnrolling.cpp - Unroll loops -----------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 /// 9 /// This file contains functions which are used to decide if a loop worth to be 10 /// unrolled. Moreover, these functions manages the stack of loop which is 11 /// tracked by the ProgramState. 12 /// 13 //===----------------------------------------------------------------------===// 14 15 #include "clang/ASTMatchers/ASTMatchers.h" 16 #include "clang/ASTMatchers/ASTMatchFinder.h" 17 #include "clang/StaticAnalyzer/Core/PathSensitive/CallEvent.h" 18 #include "clang/StaticAnalyzer/Core/PathSensitive/CheckerContext.h" 19 #include "clang/StaticAnalyzer/Core/PathSensitive/LoopUnrolling.h" 20 21 using namespace clang; 22 using namespace ento; 23 using namespace clang::ast_matchers; 24 25 static const int MAXIMUM_STEP_UNROLLED = 128; 26 27 struct LoopState { 28 private: 29 enum Kind { Normal, Unrolled } K; 30 const Stmt *LoopStmt; 31 const LocationContext *LCtx; 32 unsigned maxStep; 33 LoopState(Kind InK, const Stmt *S, const LocationContext *L, unsigned N) 34 : K(InK), LoopStmt(S), LCtx(L), maxStep(N) {} 35 36 public: 37 static LoopState getNormal(const Stmt *S, const LocationContext *L, 38 unsigned N) { 39 return LoopState(Normal, S, L, N); 40 } 41 static LoopState getUnrolled(const Stmt *S, const LocationContext *L, 42 unsigned N) { 43 return LoopState(Unrolled, S, L, N); 44 } 45 bool isUnrolled() const { return K == Unrolled; } 46 unsigned getMaxStep() const { return maxStep; } 47 const Stmt *getLoopStmt() const { return LoopStmt; } 48 const LocationContext *getLocationContext() const { return LCtx; } 49 bool operator==(const LoopState &X) const { 50 return K == X.K && LoopStmt == X.LoopStmt; 51 } 52 void Profile(llvm::FoldingSetNodeID &ID) const { 53 ID.AddInteger(K); 54 ID.AddPointer(LoopStmt); 55 ID.AddPointer(LCtx); 56 ID.AddInteger(maxStep); 57 } 58 }; 59 60 // The tracked stack of loops. The stack indicates that which loops the 61 // simulated element contained by. The loops are marked depending if we decided 62 // to unroll them. 63 // TODO: The loop stack should not need to be in the program state since it is 64 // lexical in nature. Instead, the stack of loops should be tracked in the 65 // LocationContext. 66 REGISTER_LIST_WITH_PROGRAMSTATE(LoopStack, LoopState) 67 68 namespace clang { 69 namespace ento { 70 71 static bool isLoopStmt(const Stmt *S) { 72 return isa_and_nonnull<ForStmt, WhileStmt, DoStmt>(S); 73 } 74 75 ProgramStateRef processLoopEnd(const Stmt *LoopStmt, ProgramStateRef State) { 76 auto LS = State->get<LoopStack>(); 77 if (!LS.isEmpty() && LS.getHead().getLoopStmt() == LoopStmt) 78 State = State->set<LoopStack>(LS.getTail()); 79 return State; 80 } 81 82 static internal::Matcher<Stmt> simpleCondition(StringRef BindName, 83 StringRef RefName) { 84 return binaryOperator( 85 anyOf(hasOperatorName("<"), hasOperatorName(">"), 86 hasOperatorName("<="), hasOperatorName(">="), 87 hasOperatorName("!=")), 88 hasEitherOperand(ignoringParenImpCasts( 89 declRefExpr(to(varDecl(hasType(isInteger())).bind(BindName))) 90 .bind(RefName))), 91 hasEitherOperand( 92 ignoringParenImpCasts(integerLiteral().bind("boundNum")))) 93 .bind("conditionOperator"); 94 } 95 96 static internal::Matcher<Stmt> 97 changeIntBoundNode(internal::Matcher<Decl> VarNodeMatcher) { 98 return anyOf( 99 unaryOperator(anyOf(hasOperatorName("--"), hasOperatorName("++")), 100 hasUnaryOperand(ignoringParenImpCasts( 101 declRefExpr(to(varDecl(VarNodeMatcher)))))), 102 binaryOperator(isAssignmentOperator(), 103 hasLHS(ignoringParenImpCasts( 104 declRefExpr(to(varDecl(VarNodeMatcher))))))); 105 } 106 107 static internal::Matcher<Stmt> 108 callByRef(internal::Matcher<Decl> VarNodeMatcher) { 109 return callExpr(forEachArgumentWithParam( 110 declRefExpr(to(varDecl(VarNodeMatcher))), 111 parmVarDecl(hasType(references(qualType(unless(isConstQualified()))))))); 112 } 113 114 static internal::Matcher<Stmt> 115 assignedToRef(internal::Matcher<Decl> VarNodeMatcher) { 116 return declStmt(hasDescendant(varDecl( 117 allOf(hasType(referenceType()), 118 hasInitializer(anyOf( 119 initListExpr(has(declRefExpr(to(varDecl(VarNodeMatcher))))), 120 declRefExpr(to(varDecl(VarNodeMatcher))))))))); 121 } 122 123 static internal::Matcher<Stmt> 124 getAddrTo(internal::Matcher<Decl> VarNodeMatcher) { 125 return unaryOperator( 126 hasOperatorName("&"), 127 hasUnaryOperand(declRefExpr(hasDeclaration(VarNodeMatcher)))); 128 } 129 130 static internal::Matcher<Stmt> hasSuspiciousStmt(StringRef NodeName) { 131 return hasDescendant(stmt( 132 anyOf(gotoStmt(), switchStmt(), returnStmt(), 133 // Escaping and not known mutation of the loop counter is handled 134 // by exclusion of assigning and address-of operators and 135 // pass-by-ref function calls on the loop counter from the body. 136 changeIntBoundNode(equalsBoundNode(std::string(NodeName))), 137 callByRef(equalsBoundNode(std::string(NodeName))), 138 getAddrTo(equalsBoundNode(std::string(NodeName))), 139 assignedToRef(equalsBoundNode(std::string(NodeName)))))); 140 } 141 142 static internal::Matcher<Stmt> forLoopMatcher() { 143 return forStmt( 144 hasCondition(simpleCondition("initVarName", "initVarRef")), 145 // Initialization should match the form: 'int i = 6' or 'i = 42'. 146 hasLoopInit( 147 anyOf(declStmt(hasSingleDecl( 148 varDecl(allOf(hasInitializer(ignoringParenImpCasts( 149 integerLiteral().bind("initNum"))), 150 equalsBoundNode("initVarName"))))), 151 binaryOperator(hasLHS(declRefExpr(to(varDecl( 152 equalsBoundNode("initVarName"))))), 153 hasRHS(ignoringParenImpCasts( 154 integerLiteral().bind("initNum")))))), 155 // Incrementation should be a simple increment or decrement 156 // operator call. 157 hasIncrement(unaryOperator( 158 anyOf(hasOperatorName("++"), hasOperatorName("--")), 159 hasUnaryOperand(declRefExpr( 160 to(varDecl(allOf(equalsBoundNode("initVarName"), 161 hasType(isInteger())))))))), 162 unless(hasBody(hasSuspiciousStmt("initVarName")))) 163 .bind("forLoop"); 164 } 165 166 static bool isCapturedByReference(ExplodedNode *N, const DeclRefExpr *DR) { 167 168 // Get the lambda CXXRecordDecl 169 assert(DR->refersToEnclosingVariableOrCapture()); 170 const LocationContext *LocCtxt = N->getLocationContext(); 171 const Decl *D = LocCtxt->getDecl(); 172 const auto *MD = cast<CXXMethodDecl>(D); 173 assert(MD && MD->getParent()->isLambda() && 174 "Captured variable should only be seen while evaluating a lambda"); 175 const CXXRecordDecl *LambdaCXXRec = MD->getParent(); 176 177 // Lookup the fields of the lambda 178 llvm::DenseMap<const VarDecl *, FieldDecl *> LambdaCaptureFields; 179 FieldDecl *LambdaThisCaptureField; 180 LambdaCXXRec->getCaptureFields(LambdaCaptureFields, LambdaThisCaptureField); 181 182 // Check if the counter is captured by reference 183 const VarDecl *VD = cast<VarDecl>(DR->getDecl()->getCanonicalDecl()); 184 assert(VD); 185 const FieldDecl *FD = LambdaCaptureFields[VD]; 186 assert(FD && "Captured variable without a corresponding field"); 187 return FD->getType()->isReferenceType(); 188 } 189 190 // A loop counter is considered escaped if: 191 // case 1: It is a global variable. 192 // case 2: It is a reference parameter or a reference capture. 193 // case 3: It is assigned to a non-const reference variable or parameter. 194 // case 4: Has its address taken. 195 static bool isPossiblyEscaped(ExplodedNode *N, const DeclRefExpr *DR) { 196 const VarDecl *VD = cast<VarDecl>(DR->getDecl()->getCanonicalDecl()); 197 assert(VD); 198 // Case 1: 199 if (VD->hasGlobalStorage()) 200 return true; 201 202 const bool IsRefParamOrCapture = 203 isa<ParmVarDecl>(VD) || DR->refersToEnclosingVariableOrCapture(); 204 // Case 2: 205 if ((DR->refersToEnclosingVariableOrCapture() && 206 isCapturedByReference(N, DR)) || 207 (IsRefParamOrCapture && VD->getType()->isReferenceType())) 208 return true; 209 210 while (!N->pred_empty()) { 211 // FIXME: getStmtForDiagnostics() does nasty things in order to provide 212 // a valid statement for body farms, do we need this behavior here? 213 const Stmt *S = N->getStmtForDiagnostics(); 214 if (!S) { 215 N = N->getFirstPred(); 216 continue; 217 } 218 219 if (const DeclStmt *DS = dyn_cast<DeclStmt>(S)) { 220 for (const Decl *D : DS->decls()) { 221 // Once we reach the declaration of the VD we can return. 222 if (D->getCanonicalDecl() == VD) 223 return false; 224 } 225 } 226 // Check the usage of the pass-by-ref function calls and adress-of operator 227 // on VD and reference initialized by VD. 228 ASTContext &ASTCtx = 229 N->getLocationContext()->getAnalysisDeclContext()->getASTContext(); 230 // Case 3 and 4: 231 auto Match = 232 match(stmt(anyOf(callByRef(equalsNode(VD)), getAddrTo(equalsNode(VD)), 233 assignedToRef(equalsNode(VD)))), 234 *S, ASTCtx); 235 if (!Match.empty()) 236 return true; 237 238 N = N->getFirstPred(); 239 } 240 241 // Reference parameter and reference capture will not be found. 242 if (IsRefParamOrCapture) 243 return false; 244 245 llvm_unreachable("Reached root without finding the declaration of VD"); 246 } 247 248 bool shouldCompletelyUnroll(const Stmt *LoopStmt, ASTContext &ASTCtx, 249 ExplodedNode *Pred, unsigned &maxStep) { 250 251 if (!isLoopStmt(LoopStmt)) 252 return false; 253 254 // TODO: Match the cases where the bound is not a concrete literal but an 255 // integer with known value 256 auto Matches = match(forLoopMatcher(), *LoopStmt, ASTCtx); 257 if (Matches.empty()) 258 return false; 259 260 const auto *CounterVarRef = Matches[0].getNodeAs<DeclRefExpr>("initVarRef"); 261 llvm::APInt BoundNum = 262 Matches[0].getNodeAs<IntegerLiteral>("boundNum")->getValue(); 263 llvm::APInt InitNum = 264 Matches[0].getNodeAs<IntegerLiteral>("initNum")->getValue(); 265 auto CondOp = Matches[0].getNodeAs<BinaryOperator>("conditionOperator"); 266 if (InitNum.getBitWidth() != BoundNum.getBitWidth()) { 267 InitNum = InitNum.zext(BoundNum.getBitWidth()); 268 BoundNum = BoundNum.zext(InitNum.getBitWidth()); 269 } 270 271 if (CondOp->getOpcode() == BO_GE || CondOp->getOpcode() == BO_LE) 272 maxStep = (BoundNum - InitNum + 1).abs().getZExtValue(); 273 else 274 maxStep = (BoundNum - InitNum).abs().getZExtValue(); 275 276 // Check if the counter of the loop is not escaped before. 277 return !isPossiblyEscaped(Pred, CounterVarRef); 278 } 279 280 bool madeNewBranch(ExplodedNode *N, const Stmt *LoopStmt) { 281 const Stmt *S = nullptr; 282 while (!N->pred_empty()) { 283 if (N->succ_size() > 1) 284 return true; 285 286 ProgramPoint P = N->getLocation(); 287 if (Optional<BlockEntrance> BE = P.getAs<BlockEntrance>()) 288 S = BE->getBlock()->getTerminatorStmt(); 289 290 if (S == LoopStmt) 291 return false; 292 293 N = N->getFirstPred(); 294 } 295 296 llvm_unreachable("Reached root without encountering the previous step"); 297 } 298 299 // updateLoopStack is called on every basic block, therefore it needs to be fast 300 ProgramStateRef updateLoopStack(const Stmt *LoopStmt, ASTContext &ASTCtx, 301 ExplodedNode *Pred, unsigned maxVisitOnPath) { 302 auto State = Pred->getState(); 303 auto LCtx = Pred->getLocationContext(); 304 305 if (!isLoopStmt(LoopStmt)) 306 return State; 307 308 auto LS = State->get<LoopStack>(); 309 if (!LS.isEmpty() && LoopStmt == LS.getHead().getLoopStmt() && 310 LCtx == LS.getHead().getLocationContext()) { 311 if (LS.getHead().isUnrolled() && madeNewBranch(Pred, LoopStmt)) { 312 State = State->set<LoopStack>(LS.getTail()); 313 State = State->add<LoopStack>( 314 LoopState::getNormal(LoopStmt, LCtx, maxVisitOnPath)); 315 } 316 return State; 317 } 318 unsigned maxStep; 319 if (!shouldCompletelyUnroll(LoopStmt, ASTCtx, Pred, maxStep)) { 320 State = State->add<LoopStack>( 321 LoopState::getNormal(LoopStmt, LCtx, maxVisitOnPath)); 322 return State; 323 } 324 325 unsigned outerStep = (LS.isEmpty() ? 1 : LS.getHead().getMaxStep()); 326 327 unsigned innerMaxStep = maxStep * outerStep; 328 if (innerMaxStep > MAXIMUM_STEP_UNROLLED) 329 State = State->add<LoopStack>( 330 LoopState::getNormal(LoopStmt, LCtx, maxVisitOnPath)); 331 else 332 State = State->add<LoopStack>( 333 LoopState::getUnrolled(LoopStmt, LCtx, innerMaxStep)); 334 return State; 335 } 336 337 bool isUnrolledState(ProgramStateRef State) { 338 auto LS = State->get<LoopStack>(); 339 if (LS.isEmpty() || !LS.getHead().isUnrolled()) 340 return false; 341 return true; 342 } 343 } 344 } 345