1 //===------- LoopBoundSplit.cpp - Split Loop Bound --------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "llvm/Transforms/Scalar/LoopBoundSplit.h"
10 #include "llvm/Analysis/LoopAccessAnalysis.h"
11 #include "llvm/Analysis/LoopAnalysisManager.h"
12 #include "llvm/Analysis/LoopInfo.h"
13 #include "llvm/Analysis/LoopIterator.h"
14 #include "llvm/Analysis/LoopPass.h"
15 #include "llvm/Analysis/MemorySSA.h"
16 #include "llvm/Analysis/MemorySSAUpdater.h"
17 #include "llvm/Analysis/ScalarEvolution.h"
18 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
19 #include "llvm/IR/PatternMatch.h"
20 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
21 #include "llvm/Transforms/Utils/Cloning.h"
22 #include "llvm/Transforms/Utils/LoopSimplify.h"
23 #include "llvm/Transforms/Utils/LoopUtils.h"
24 #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
25
26 #define DEBUG_TYPE "loop-bound-split"
27
28 namespace llvm {
29
30 using namespace PatternMatch;
31
32 namespace {
33 struct ConditionInfo {
34 /// Branch instruction with this condition
35 BranchInst *BI;
36 /// ICmp instruction with this condition
37 ICmpInst *ICmp;
38 /// Preciate info
39 ICmpInst::Predicate Pred;
40 /// AddRec llvm value
41 Value *AddRecValue;
42 /// Bound llvm value
43 Value *BoundValue;
44 /// AddRec SCEV
45 const SCEV *AddRecSCEV;
46 /// Bound SCEV
47 const SCEV *BoundSCEV;
48
ConditionInfollvm::__anonb8dd4c0c0111::ConditionInfo49 ConditionInfo()
50 : BI(nullptr), ICmp(nullptr), Pred(ICmpInst::BAD_ICMP_PREDICATE),
51 AddRecValue(nullptr), BoundValue(nullptr), AddRecSCEV(nullptr),
52 BoundSCEV(nullptr) {}
53 };
54 } // namespace
55
analyzeICmp(ScalarEvolution & SE,ICmpInst * ICmp,ConditionInfo & Cond)56 static void analyzeICmp(ScalarEvolution &SE, ICmpInst *ICmp,
57 ConditionInfo &Cond) {
58 Cond.ICmp = ICmp;
59 if (match(ICmp, m_ICmp(Cond.Pred, m_Value(Cond.AddRecValue),
60 m_Value(Cond.BoundValue)))) {
61 Cond.AddRecSCEV = SE.getSCEV(Cond.AddRecValue);
62 Cond.BoundSCEV = SE.getSCEV(Cond.BoundValue);
63 // Locate AddRec in LHSSCEV and Bound in RHSSCEV.
64 if (isa<SCEVAddRecExpr>(Cond.BoundSCEV) &&
65 !isa<SCEVAddRecExpr>(Cond.AddRecSCEV)) {
66 std::swap(Cond.AddRecValue, Cond.BoundValue);
67 std::swap(Cond.AddRecSCEV, Cond.BoundSCEV);
68 Cond.Pred = ICmpInst::getSwappedPredicate(Cond.Pred);
69 }
70 }
71 }
72
calculateUpperBound(const Loop & L,ScalarEvolution & SE,ConditionInfo & Cond,bool IsExitCond)73 static bool calculateUpperBound(const Loop &L, ScalarEvolution &SE,
74 ConditionInfo &Cond, bool IsExitCond) {
75 if (IsExitCond) {
76 const SCEV *ExitCount = SE.getExitCount(&L, Cond.ICmp->getParent());
77 if (isa<SCEVCouldNotCompute>(ExitCount))
78 return false;
79
80 Cond.BoundSCEV = ExitCount;
81 return true;
82 }
83
84 // For non-exit condtion, if pred is LT, keep existing bound.
85 if (Cond.Pred == ICmpInst::ICMP_SLT || Cond.Pred == ICmpInst::ICMP_ULT)
86 return true;
87
88 // For non-exit condition, if pre is LE, try to convert it to LT.
89 // Range Range
90 // AddRec <= Bound --> AddRec < Bound + 1
91 if (Cond.Pred != ICmpInst::ICMP_ULE && Cond.Pred != ICmpInst::ICMP_SLE)
92 return false;
93
94 if (IntegerType *BoundSCEVIntType =
95 dyn_cast<IntegerType>(Cond.BoundSCEV->getType())) {
96 unsigned BitWidth = BoundSCEVIntType->getBitWidth();
97 APInt Max = ICmpInst::isSigned(Cond.Pred)
98 ? APInt::getSignedMaxValue(BitWidth)
99 : APInt::getMaxValue(BitWidth);
100 const SCEV *MaxSCEV = SE.getConstant(Max);
101 // Check Bound < INT_MAX
102 ICmpInst::Predicate Pred =
103 ICmpInst::isSigned(Cond.Pred) ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT;
104 if (SE.isKnownPredicate(Pred, Cond.BoundSCEV, MaxSCEV)) {
105 const SCEV *BoundPlusOneSCEV =
106 SE.getAddExpr(Cond.BoundSCEV, SE.getOne(BoundSCEVIntType));
107 Cond.BoundSCEV = BoundPlusOneSCEV;
108 Cond.Pred = Pred;
109 return true;
110 }
111 }
112
113 // ToDo: Support ICMP_NE/EQ.
114
115 return false;
116 }
117
hasProcessableCondition(const Loop & L,ScalarEvolution & SE,ICmpInst * ICmp,ConditionInfo & Cond,bool IsExitCond)118 static bool hasProcessableCondition(const Loop &L, ScalarEvolution &SE,
119 ICmpInst *ICmp, ConditionInfo &Cond,
120 bool IsExitCond) {
121 analyzeICmp(SE, ICmp, Cond);
122
123 // The BoundSCEV should be evaluated at loop entry.
124 if (!SE.isAvailableAtLoopEntry(Cond.BoundSCEV, &L))
125 return false;
126
127 const SCEVAddRecExpr *AddRecSCEV = dyn_cast<SCEVAddRecExpr>(Cond.AddRecSCEV);
128 // Allowed AddRec as induction variable.
129 if (!AddRecSCEV)
130 return false;
131
132 if (!AddRecSCEV->isAffine())
133 return false;
134
135 const SCEV *StepRecSCEV = AddRecSCEV->getStepRecurrence(SE);
136 // Allowed constant step.
137 if (!isa<SCEVConstant>(StepRecSCEV))
138 return false;
139
140 ConstantInt *StepCI = cast<SCEVConstant>(StepRecSCEV)->getValue();
141 // Allowed positive step for now.
142 // TODO: Support negative step.
143 if (StepCI->isNegative() || StepCI->isZero())
144 return false;
145
146 // Calculate upper bound.
147 if (!calculateUpperBound(L, SE, Cond, IsExitCond))
148 return false;
149
150 return true;
151 }
152
isProcessableCondBI(const ScalarEvolution & SE,const BranchInst * BI)153 static bool isProcessableCondBI(const ScalarEvolution &SE,
154 const BranchInst *BI) {
155 BasicBlock *TrueSucc = nullptr;
156 BasicBlock *FalseSucc = nullptr;
157 ICmpInst::Predicate Pred;
158 Value *LHS, *RHS;
159 if (!match(BI, m_Br(m_ICmp(Pred, m_Value(LHS), m_Value(RHS)),
160 m_BasicBlock(TrueSucc), m_BasicBlock(FalseSucc))))
161 return false;
162
163 if (!SE.isSCEVable(LHS->getType()))
164 return false;
165 assert(SE.isSCEVable(RHS->getType()) && "Expected RHS's type is SCEVable");
166
167 if (TrueSucc == FalseSucc)
168 return false;
169
170 return true;
171 }
172
canSplitLoopBound(const Loop & L,const DominatorTree & DT,ScalarEvolution & SE,ConditionInfo & Cond)173 static bool canSplitLoopBound(const Loop &L, const DominatorTree &DT,
174 ScalarEvolution &SE, ConditionInfo &Cond) {
175 // Skip function with optsize.
176 if (L.getHeader()->getParent()->hasOptSize())
177 return false;
178
179 // Split only innermost loop.
180 if (!L.isInnermost())
181 return false;
182
183 // Check loop is in simplified form.
184 if (!L.isLoopSimplifyForm())
185 return false;
186
187 // Check loop is in LCSSA form.
188 if (!L.isLCSSAForm(DT))
189 return false;
190
191 // Skip loop that cannot be cloned.
192 if (!L.isSafeToClone())
193 return false;
194
195 BasicBlock *ExitingBB = L.getExitingBlock();
196 // Assumed only one exiting block.
197 if (!ExitingBB)
198 return false;
199
200 BranchInst *ExitingBI = dyn_cast<BranchInst>(ExitingBB->getTerminator());
201 if (!ExitingBI)
202 return false;
203
204 // Allowed only conditional branch with ICmp.
205 if (!isProcessableCondBI(SE, ExitingBI))
206 return false;
207
208 // Check the condition is processable.
209 ICmpInst *ICmp = cast<ICmpInst>(ExitingBI->getCondition());
210 if (!hasProcessableCondition(L, SE, ICmp, Cond, /*IsExitCond*/ true))
211 return false;
212
213 Cond.BI = ExitingBI;
214 return true;
215 }
216
isProfitableToTransform(const Loop & L,const BranchInst * BI)217 static bool isProfitableToTransform(const Loop &L, const BranchInst *BI) {
218 // If the conditional branch splits a loop into two halves, we could
219 // generally say it is profitable.
220 //
221 // ToDo: Add more profitable cases here.
222
223 // Check this branch causes diamond CFG.
224 BasicBlock *Succ0 = BI->getSuccessor(0);
225 BasicBlock *Succ1 = BI->getSuccessor(1);
226
227 BasicBlock *Succ0Succ = Succ0->getSingleSuccessor();
228 BasicBlock *Succ1Succ = Succ1->getSingleSuccessor();
229 if (!Succ0Succ || !Succ1Succ || Succ0Succ != Succ1Succ)
230 return false;
231
232 // ToDo: Calculate each successor's instruction cost.
233
234 return true;
235 }
236
findSplitCandidate(const Loop & L,ScalarEvolution & SE,ConditionInfo & ExitingCond,ConditionInfo & SplitCandidateCond)237 static BranchInst *findSplitCandidate(const Loop &L, ScalarEvolution &SE,
238 ConditionInfo &ExitingCond,
239 ConditionInfo &SplitCandidateCond) {
240 for (auto *BB : L.blocks()) {
241 // Skip condition of backedge.
242 if (L.getLoopLatch() == BB)
243 continue;
244
245 auto *BI = dyn_cast<BranchInst>(BB->getTerminator());
246 if (!BI)
247 continue;
248
249 // Check conditional branch with ICmp.
250 if (!isProcessableCondBI(SE, BI))
251 continue;
252
253 // Skip loop invariant condition.
254 if (L.isLoopInvariant(BI->getCondition()))
255 continue;
256
257 // Check the condition is processable.
258 ICmpInst *ICmp = cast<ICmpInst>(BI->getCondition());
259 if (!hasProcessableCondition(L, SE, ICmp, SplitCandidateCond,
260 /*IsExitCond*/ false))
261 continue;
262
263 if (ExitingCond.BoundSCEV->getType() !=
264 SplitCandidateCond.BoundSCEV->getType())
265 continue;
266
267 SplitCandidateCond.BI = BI;
268 return BI;
269 }
270
271 return nullptr;
272 }
273
splitLoopBound(Loop & L,DominatorTree & DT,LoopInfo & LI,ScalarEvolution & SE,LPMUpdater & U)274 static bool splitLoopBound(Loop &L, DominatorTree &DT, LoopInfo &LI,
275 ScalarEvolution &SE, LPMUpdater &U) {
276 ConditionInfo SplitCandidateCond;
277 ConditionInfo ExitingCond;
278
279 // Check we can split this loop's bound.
280 if (!canSplitLoopBound(L, DT, SE, ExitingCond))
281 return false;
282
283 if (!findSplitCandidate(L, SE, ExitingCond, SplitCandidateCond))
284 return false;
285
286 if (!isProfitableToTransform(L, SplitCandidateCond.BI))
287 return false;
288
289 // Now, we have a split candidate. Let's build a form as below.
290 // +--------------------+
291 // | preheader |
292 // | set up newbound |
293 // +--------------------+
294 // | /----------------\
295 // +--------v----v------+ |
296 // | header |---\ |
297 // | with true condition| | |
298 // +--------------------+ | |
299 // | | |
300 // +--------v-----------+ | |
301 // | if.then.BB | | |
302 // +--------------------+ | |
303 // | | |
304 // +--------v-----------<---/ |
305 // | latch >----------/
306 // | with newbound |
307 // +--------------------+
308 // |
309 // +--------v-----------+
310 // | preheader2 |--------------\
311 // | if (AddRec i != | |
312 // | org bound) | |
313 // +--------------------+ |
314 // | /----------------\ |
315 // +--------v----v------+ | |
316 // | header2 |---\ | |
317 // | conditional branch | | | |
318 // |with false condition| | | |
319 // +--------------------+ | | |
320 // | | | |
321 // +--------v-----------+ | | |
322 // | if.then.BB2 | | | |
323 // +--------------------+ | | |
324 // | | | |
325 // +--------v-----------<---/ | |
326 // | latch2 >----------/ |
327 // | with org bound | |
328 // +--------v-----------+ |
329 // | |
330 // | +---------------+ |
331 // +--> exit <-------/
332 // +---------------+
333
334 // Let's create post loop.
335 SmallVector<BasicBlock *, 8> PostLoopBlocks;
336 Loop *PostLoop;
337 ValueToValueMapTy VMap;
338 BasicBlock *PreHeader = L.getLoopPreheader();
339 BasicBlock *SplitLoopPH = SplitEdge(PreHeader, L.getHeader(), &DT, &LI);
340 PostLoop = cloneLoopWithPreheader(L.getExitBlock(), SplitLoopPH, &L, VMap,
341 ".split", &LI, &DT, PostLoopBlocks);
342 remapInstructionsInBlocks(PostLoopBlocks, VMap);
343
344 // Add conditional branch to check we can skip post-loop in its preheader.
345 BasicBlock *PostLoopPreHeader = PostLoop->getLoopPreheader();
346 IRBuilder<> Builder(PostLoopPreHeader);
347 Instruction *OrigBI = PostLoopPreHeader->getTerminator();
348 ICmpInst::Predicate Pred = ICmpInst::ICMP_NE;
349 Value *Cond =
350 Builder.CreateICmp(Pred, ExitingCond.AddRecValue, ExitingCond.BoundValue);
351 Builder.CreateCondBr(Cond, PostLoop->getHeader(), PostLoop->getExitBlock());
352 OrigBI->eraseFromParent();
353
354 // Create new loop bound and add it into preheader of pre-loop.
355 const SCEV *NewBoundSCEV = ExitingCond.BoundSCEV;
356 const SCEV *SplitBoundSCEV = SplitCandidateCond.BoundSCEV;
357 NewBoundSCEV = ICmpInst::isSigned(ExitingCond.Pred)
358 ? SE.getSMinExpr(NewBoundSCEV, SplitBoundSCEV)
359 : SE.getUMinExpr(NewBoundSCEV, SplitBoundSCEV);
360
361 SCEVExpander Expander(
362 SE, L.getHeader()->getParent()->getParent()->getDataLayout(), "split");
363 Instruction *InsertPt = SplitLoopPH->getTerminator();
364 Value *NewBoundValue =
365 Expander.expandCodeFor(NewBoundSCEV, NewBoundSCEV->getType(), InsertPt);
366 NewBoundValue->setName("new.bound");
367
368 // Replace exiting bound value of pre-loop NewBound.
369 ExitingCond.ICmp->setOperand(1, NewBoundValue);
370
371 // Replace IV's start value of post-loop by NewBound.
372 for (PHINode &PN : L.getHeader()->phis()) {
373 // Find PHI with exiting condition from pre-loop.
374 if (SE.isSCEVable(PN.getType()) && isa<SCEVAddRecExpr>(SE.getSCEV(&PN))) {
375 for (Value *Op : PN.incoming_values()) {
376 if (Op == ExitingCond.AddRecValue) {
377 // Find cloned PHI for post-loop.
378 PHINode *PostLoopPN = cast<PHINode>(VMap[&PN]);
379 PostLoopPN->setIncomingValueForBlock(PostLoopPreHeader,
380 NewBoundValue);
381 }
382 }
383 }
384 }
385
386 // Replace SplitCandidateCond.BI's condition of pre-loop by True.
387 LLVMContext &Context = PreHeader->getContext();
388 SplitCandidateCond.BI->setCondition(ConstantInt::getTrue(Context));
389
390 // Replace cloned SplitCandidateCond.BI's condition in post-loop by False.
391 BranchInst *ClonedSplitCandidateBI =
392 cast<BranchInst>(VMap[SplitCandidateCond.BI]);
393 ClonedSplitCandidateBI->setCondition(ConstantInt::getFalse(Context));
394
395 // Replace exit branch target of pre-loop by post-loop's preheader.
396 if (L.getExitBlock() == ExitingCond.BI->getSuccessor(0))
397 ExitingCond.BI->setSuccessor(0, PostLoopPreHeader);
398 else
399 ExitingCond.BI->setSuccessor(1, PostLoopPreHeader);
400
401 // Update dominator tree.
402 DT.changeImmediateDominator(PostLoopPreHeader, L.getExitingBlock());
403 DT.changeImmediateDominator(PostLoop->getExitBlock(), PostLoopPreHeader);
404
405 // Invalidate cached SE information.
406 SE.forgetLoop(&L);
407
408 // Canonicalize loops.
409 // TODO: Try to update LCSSA information according to above change.
410 formLCSSA(L, DT, &LI, &SE);
411 simplifyLoop(&L, &DT, &LI, &SE, nullptr, nullptr, true);
412 formLCSSA(*PostLoop, DT, &LI, &SE);
413 simplifyLoop(PostLoop, &DT, &LI, &SE, nullptr, nullptr, true);
414
415 // Add new post-loop to loop pass manager.
416 U.addSiblingLoops(PostLoop);
417
418 return true;
419 }
420
run(Loop & L,LoopAnalysisManager & AM,LoopStandardAnalysisResults & AR,LPMUpdater & U)421 PreservedAnalyses LoopBoundSplitPass::run(Loop &L, LoopAnalysisManager &AM,
422 LoopStandardAnalysisResults &AR,
423 LPMUpdater &U) {
424 Function &F = *L.getHeader()->getParent();
425 (void)F;
426
427 LLVM_DEBUG(dbgs() << "Spliting bound of loop in " << F.getName() << ": " << L
428 << "\n");
429
430 if (!splitLoopBound(L, AR.DT, AR.LI, AR.SE, U))
431 return PreservedAnalyses::all();
432
433 assert(AR.DT.verify(DominatorTree::VerificationLevel::Fast));
434 AR.LI.verify(AR.DT);
435
436 return getLoopPassPreservedAnalyses();
437 }
438
439 } // end namespace llvm
440