//===- IndVarSimplify.cpp - Induction Variable Elimination ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This transformation analyzes and transforms the induction variables (and
// computations derived from them) into simpler forms suitable for subsequent
// analysis and transformation.
//
// If the trip count of a loop is computable, this pass also makes the following
// changes:
//   1. The exit condition for the loop is canonicalized to compare the
//      induction value against the exit value.  This turns loops like:
//        'for (i = 7; i*i < 1000; ++i)' into 'for (i = 0; i != 25; ++i)'
//   2. Any use outside of the loop of an expression derived from the indvar
//      is changed to compute the derived value outside of the loop, eliminating
//      the dependence on the exit value of the induction variable.  If the only
//      purpose of the loop is to compute the exit value of some derived
//      expression, this transformation will make the loop dead.
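//      As an illustrative sketch of change #2 (hypothetical source, not a
//      specific test case): given 'for (i = 0; i != n; ++i) last = i;'
//      followed by 'use(last)', the use can be rewritten to 'use(n - 1)',
//      computed outside the loop, at which point the loop itself is dead.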
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Scalar/IndVarSimplify.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/LoopPass.h"
#include "llvm/Analysis/MemorySSA.h"
#include "llvm/Analysis/MemorySSAUpdater.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/ConstantRange.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Operator.h"
#include "llvm/IR/PassManager.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Use.h"
#include "llvm/IR/User.h"
#include "llvm/IR/Value.h"
#include "llvm/IR/ValueHandle.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Scalar/LoopPassManager.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Transforms/Utils/LoopUtils.h"
#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
#include "llvm/Transforms/Utils/SimplifyIndVar.h"
#include <cassert>
#include <cstdint>
#include <utility>

using namespace llvm;

#define DEBUG_TYPE "indvars"

STATISTIC(NumWidened     , "Number of indvars widened");
STATISTIC(NumReplaced    , "Number of exit values replaced");
STATISTIC(NumLFTR        , "Number of loop exit tests replaced");
STATISTIC(NumElimExt     , "Number of IV sign/zero extends eliminated");
STATISTIC(NumElimIV      , "Number of congruent IVs eliminated");

// Trip count verification can be enabled by default under NDEBUG if we
// implement a strong expression equivalence checker in SCEV. Until then, we
// use the verify-indvars flag, which may assert in some cases.
static cl::opt<bool> VerifyIndvars(
    "verify-indvars", cl::Hidden,
    cl::desc("Verify the ScalarEvolution result after running indvars. Has no "
             "effect in release builds. (Note: this adds additional SCEV "
             "queries potentially changing the analysis result)"));

static cl::opt<ReplaceExitVal> ReplaceExitValue(
    "replexitval", cl::Hidden, cl::init(OnlyCheapRepl),
    cl::desc("Choose the strategy to replace exit value in IndVarSimplify"),
    cl::values(clEnumValN(NeverRepl, "never", "never replace exit value"),
               clEnumValN(OnlyCheapRepl, "cheap",
                          "only replace exit value when the cost is cheap"),
               clEnumValN(NoHardUse, "noharduse",
                          "only replace exit values when loop def likely dead"),
               clEnumValN(AlwaysRepl, "always",
                          "always replace exit value whenever possible")));

static cl::opt<bool> UsePostIncrementRanges(
  "indvars-post-increment-ranges", cl::Hidden,
  cl::desc("Use post increment control-dependent ranges in IndVarSimplify"),
  cl::init(true));

static cl::opt<bool>
DisableLFTR("disable-lftr", cl::Hidden, cl::init(false),
            cl::desc("Disable Linear Function Test Replace optimization"));

static cl::opt<bool>
LoopPredication("indvars-predicate-loops", cl::Hidden, cl::init(true),
                cl::desc("Predicate conditions in read only loops"));

static cl::opt<bool>
AllowIVWidening("indvars-widen-indvars", cl::Hidden, cl::init(true),
                cl::desc("Allow widening of indvars to eliminate s/zext"));

namespace {

struct RewritePhi;

class IndVarSimplify {
  LoopInfo *LI;
  ScalarEvolution *SE;
  DominatorTree *DT;
  const DataLayout &DL;
  TargetLibraryInfo *TLI;
  const TargetTransformInfo *TTI;
  std::unique_ptr<MemorySSAUpdater> MSSAU;

  SmallVector<WeakTrackingVH, 16> DeadInsts;
  bool WidenIndVars;

  bool handleFloatingPointIV(Loop *L, PHINode *PH);
  bool rewriteNonIntegerIVs(Loop *L);

  bool simplifyAndExtend(Loop *L, SCEVExpander &Rewriter, LoopInfo *LI);
  /// Try to eliminate loop exits based on analyzable exit counts
  bool optimizeLoopExits(Loop *L, SCEVExpander &Rewriter);
  /// Try to form loop invariant tests for loop exits by changing how many
  /// iterations of the loop run when that is unobservable.
  bool predicateLoopExits(Loop *L, SCEVExpander &Rewriter);

  bool rewriteFirstIterationLoopExitValues(Loop *L);

  bool linearFunctionTestReplace(Loop *L, BasicBlock *ExitingBB,
                                 const SCEV *ExitCount,
                                 PHINode *IndVar, SCEVExpander &Rewriter);

  bool sinkUnusedInvariants(Loop *L);

public:
  IndVarSimplify(LoopInfo *LI, ScalarEvolution *SE, DominatorTree *DT,
                 const DataLayout &DL, TargetLibraryInfo *TLI,
                 TargetTransformInfo *TTI, MemorySSA *MSSA, bool WidenIndVars)
      : LI(LI), SE(SE), DT(DT), DL(DL), TLI(TLI), TTI(TTI),
        WidenIndVars(WidenIndVars) {
    if (MSSA)
      MSSAU = std::make_unique<MemorySSAUpdater>(MSSA);
  }

  bool run(Loop *L);
};

} // end anonymous namespace

//===----------------------------------------------------------------------===//
// rewriteNonIntegerIVs and helpers. Prefer integer IVs.
//===----------------------------------------------------------------------===//

/// Convert APF to an integer, if possible.
static bool ConvertToSInt(const APFloat &APF, int64_t &IntVal) {
  bool isExact = false;
  // See if we can convert this to an int64_t
  uint64_t UIntVal;
  if (APF.convertToInteger(makeMutableArrayRef(UIntVal), 64, true,
                           APFloat::rmTowardZero, &isExact) != APFloat::opOK ||
      !isExact)
    return false;
  IntVal = UIntVal;
  return true;
}

/// If the loop has a floating-point induction variable, insert a corresponding
/// integer induction variable if possible.
/// For example,
/// for(double i = 0; i < 10000; ++i)
///   bar(i)
/// is converted into
/// for(int i = 0; i < 10000; ++i)
///   bar((double)i);
bool IndVarSimplify::handleFloatingPointIV(Loop *L, PHINode *PN) {
  unsigned IncomingEdge = L->contains(PN->getIncomingBlock(0));
  unsigned BackEdge     = IncomingEdge^1;

  // Check incoming value.
  auto *InitValueVal = dyn_cast<ConstantFP>(PN->getIncomingValue(IncomingEdge));

  int64_t InitValue;
  if (!InitValueVal || !ConvertToSInt(InitValueVal->getValueAPF(), InitValue))
    return false;

  // Check the IV increment. Reject this PN if the increment operation is not
  // an fadd or if the increment value cannot be represented by an integer.
  auto *Incr = dyn_cast<BinaryOperator>(PN->getIncomingValue(BackEdge));
  if (Incr == nullptr || Incr->getOpcode() != Instruction::FAdd) return false;

  // If this is not an add of the PHI with a constantfp, or if the constant fp
  // is not an integer, bail out.
  ConstantFP *IncValueVal = dyn_cast<ConstantFP>(Incr->getOperand(1));
  int64_t IncValue;
  if (IncValueVal == nullptr || Incr->getOperand(0) != PN ||
      !ConvertToSInt(IncValueVal->getValueAPF(), IncValue))
    return false;

  // Check Incr uses. One user is PN and the other user is an exit condition
  // used by the conditional terminator.
  Value::user_iterator IncrUse = Incr->user_begin();
  Instruction *U1 = cast<Instruction>(*IncrUse++);
  if (IncrUse == Incr->user_end()) return false;
  Instruction *U2 = cast<Instruction>(*IncrUse++);
  if (IncrUse != Incr->user_end()) return false;

  // Find exit condition, which is an fcmp.  If it doesn't exist, or if it isn't
  // only used by a branch, we can't transform it.
  FCmpInst *Compare = dyn_cast<FCmpInst>(U1);
  if (!Compare)
    Compare = dyn_cast<FCmpInst>(U2);
  if (!Compare || !Compare->hasOneUse() ||
      !isa<BranchInst>(Compare->user_back()))
    return false;

  BranchInst *TheBr = cast<BranchInst>(Compare->user_back());

  // We need to verify that the branch actually controls the iteration count
  // of the loop.  If not, the new IV can overflow and no one will notice.
  // The branch block must be in the loop and one of the successors must be out
  // of the loop.
  assert(TheBr->isConditional() && "Can't use fcmp if not conditional");
  if (!L->contains(TheBr->getParent()) ||
      (L->contains(TheBr->getSuccessor(0)) &&
       L->contains(TheBr->getSuccessor(1))))
    return false;

  // If it isn't a comparison with an integer-as-fp (the exit value), we can't
  // transform it.
  ConstantFP *ExitValueVal = dyn_cast<ConstantFP>(Compare->getOperand(1));
  int64_t ExitValue;
  if (ExitValueVal == nullptr ||
      !ConvertToSInt(ExitValueVal->getValueAPF(), ExitValue))
    return false;

  // Find new predicate for integer comparison.
  CmpInst::Predicate NewPred = CmpInst::BAD_ICMP_PREDICATE;
  switch (Compare->getPredicate()) {
  default: return false;  // Unknown comparison.
  case CmpInst::FCMP_OEQ:
  case CmpInst::FCMP_UEQ: NewPred = CmpInst::ICMP_EQ; break;
  case CmpInst::FCMP_ONE:
  case CmpInst::FCMP_UNE: NewPred = CmpInst::ICMP_NE; break;
  case CmpInst::FCMP_OGT:
  case CmpInst::FCMP_UGT: NewPred = CmpInst::ICMP_SGT; break;
  case CmpInst::FCMP_OGE:
  case CmpInst::FCMP_UGE: NewPred = CmpInst::ICMP_SGE; break;
  case CmpInst::FCMP_OLT:
  case CmpInst::FCMP_ULT: NewPred = CmpInst::ICMP_SLT; break;
  case CmpInst::FCMP_OLE:
  case CmpInst::FCMP_ULE: NewPred = CmpInst::ICMP_SLE; break;
  }

  // We convert the floating point induction variable to a signed i32 value if
  // we can.  This is only safe if the comparison will not overflow in a way
  // that won't be trapped by the integer equivalent operations.  Check for this
  // now.
  // TODO: We could use i64 if it is native and the range requires it.

  // The start/stride/exit values must all fit in signed i32.
  if (!isInt<32>(InitValue) || !isInt<32>(IncValue) || !isInt<32>(ExitValue))
    return false;

  // If not actually striding (add x, 0.0), avoid touching the code.
  if (IncValue == 0)
    return false;

  // Positive and negative strides have different safety conditions.
  if (IncValue > 0) {
    // If we have a positive stride, we require the init to be less than the
    // exit value.
    if (InitValue >= ExitValue)
      return false;

    uint32_t Range = uint32_t(ExitValue-InitValue);
    // Check for infinite loop, either:
    // while (i <= Exit) or until (i > Exit)
    if (NewPred == CmpInst::ICMP_SLE || NewPred == CmpInst::ICMP_SGT) {
      if (++Range == 0) return false;  // Range overflows.
    }

    unsigned Leftover = Range % uint32_t(IncValue);

    // If this is an equality comparison, we require that the strided value
    // exactly land on the exit value, otherwise the IV condition will wrap
    // around and do things the fp IV wouldn't.
    if ((NewPred == CmpInst::ICMP_EQ || NewPred == CmpInst::ICMP_NE) &&
        Leftover != 0)
      return false;

    // If the stride would wrap around the i32 before exiting, we can't
    // transform the IV.
    if (Leftover != 0 && int32_t(ExitValue+IncValue) < ExitValue)
      return false;
  } else {
    // If we have a negative stride, we require the init to be greater than the
    // exit value.
    if (InitValue <= ExitValue)
      return false;

    uint32_t Range = uint32_t(InitValue-ExitValue);
    // Check for infinite loop, either:
    // while (i >= Exit) or until (i < Exit)
    if (NewPred == CmpInst::ICMP_SGE || NewPred == CmpInst::ICMP_SLT) {
      if (++Range == 0) return false;  // Range overflows.
    }

    unsigned Leftover = Range % uint32_t(-IncValue);

    // If this is an equality comparison, we require that the strided value
    // exactly land on the exit value, otherwise the IV condition will wrap
    // around and do things the fp IV wouldn't.
    if ((NewPred == CmpInst::ICMP_EQ || NewPred == CmpInst::ICMP_NE) &&
        Leftover != 0)
      return false;

    // If the stride would wrap around the i32 before exiting, we can't
    // transform the IV.
    if (Leftover != 0 && int32_t(ExitValue+IncValue) > ExitValue)
      return false;
  }

  IntegerType *Int32Ty = Type::getInt32Ty(PN->getContext());

  // Insert new integer induction variable.
  PHINode *NewPHI = PHINode::Create(Int32Ty, 2, PN->getName()+".int", PN);
  NewPHI->addIncoming(ConstantInt::get(Int32Ty, InitValue),
                      PN->getIncomingBlock(IncomingEdge));

  Value *NewAdd =
    BinaryOperator::CreateAdd(NewPHI, ConstantInt::get(Int32Ty, IncValue),
                              Incr->getName()+".int", Incr);
  NewPHI->addIncoming(NewAdd, PN->getIncomingBlock(BackEdge));

  ICmpInst *NewCompare = new ICmpInst(TheBr, NewPred, NewAdd,
                                      ConstantInt::get(Int32Ty, ExitValue),
                                      Compare->getName());

  // In the following deletions, PN may become dead and may be deleted.
  // Use a WeakTrackingVH to observe whether this happens.
  WeakTrackingVH WeakPH = PN;

  // Delete the old floating point exit comparison.  The branch starts using the
  // new comparison.
  NewCompare->takeName(Compare);
  Compare->replaceAllUsesWith(NewCompare);
  RecursivelyDeleteTriviallyDeadInstructions(Compare, TLI, MSSAU.get());

  // Delete the old floating point increment.
  Incr->replaceAllUsesWith(UndefValue::get(Incr->getType()));
  RecursivelyDeleteTriviallyDeadInstructions(Incr, TLI, MSSAU.get());

  // If the FP induction variable still has uses, this is because something else
  // in the loop uses its value.  In order to canonicalize the induction
  // variable, we chose to eliminate the IV and rewrite it in terms of an
  // int->fp cast.
  //
  // We give preference to sitofp over uitofp because it is faster on most
  // platforms.
  if (WeakPH) {
    Value *Conv = new SIToFPInst(NewPHI, PN->getType(), "indvar.conv",
                                 &*PN->getParent()->getFirstInsertionPt());
    PN->replaceAllUsesWith(Conv);
    RecursivelyDeleteTriviallyDeadInstructions(PN, TLI, MSSAU.get());
  }
  return true;
}

bool IndVarSimplify::rewriteNonIntegerIVs(Loop *L) {
  // First step.  Check to see if there are any floating-point recurrences.
  // If there are, change them into integer recurrences, permitting analysis by
  // the SCEV routines.
  BasicBlock *Header = L->getHeader();

  SmallVector<WeakTrackingVH, 8> PHIs;
  for (PHINode &PN : Header->phis())
    PHIs.push_back(&PN);

  bool Changed = false;
  for (unsigned i = 0, e = PHIs.size(); i != e; ++i)
    if (PHINode *PN = dyn_cast_or_null<PHINode>(&*PHIs[i]))
      Changed |= handleFloatingPointIV(L, PN);

  // If the loop previously had a floating-point IV, ScalarEvolution
  // may not have been able to compute a trip count. Now that we've done some
  // rewriting, the trip count may be computable.
  if (Changed)
    SE->forgetLoop(L);
  return Changed;
}

//===---------------------------------------------------------------------===//
// rewriteFirstIterationLoopExitValues: Rewrite loop exit values if we know
// they will exit at the first iteration.
//===---------------------------------------------------------------------===//

/// Check to see if this loop has loop-invariant conditions which lead to loop
/// exits. If so, we know that if the exit path is taken, it is at the first
/// loop iteration. This lets us predict exit values of PHI nodes that live in
/// the loop header.
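///
/// As an illustrative sketch (hypothetical IR shape, not a specific test):
/// given a header phi 'i = phi [0, %preheader], [i.next, %latch]' and an exit
/// whose branch dominates the latch and is guarded by a loop-invariant
/// condition, the LCSSA phi for i in that exit block can take i's preheader
/// value (0), since the exit can only be taken on the first iteration.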
bool IndVarSimplify::rewriteFirstIterationLoopExitValues(Loop *L) {
  // Verify the input to the pass is already in LCSSA form.
  assert(L->isLCSSAForm(*DT));

  SmallVector<BasicBlock *, 8> ExitBlocks;
  L->getUniqueExitBlocks(ExitBlocks);

  bool MadeAnyChanges = false;
  for (auto *ExitBB : ExitBlocks) {
    // If there are no more PHI nodes in this exit block, then no more
    // values defined inside the loop are used on this path.
    for (PHINode &PN : ExitBB->phis()) {
      for (unsigned IncomingValIdx = 0, E = PN.getNumIncomingValues();
           IncomingValIdx != E; ++IncomingValIdx) {
        auto *IncomingBB = PN.getIncomingBlock(IncomingValIdx);

        // Can we prove that the exit must run on the first iteration if it
        // runs at all?  (i.e. early exits are fine for our purposes, but
        // traces which lead to this exit being taken on the 2nd iteration
        // aren't.)  Note that this is about whether the exit branch is
        // executed, not about whether it is taken.
        if (!L->getLoopLatch() ||
            !DT->dominates(IncomingBB, L->getLoopLatch()))
          continue;

        // Get condition that leads to the exit path.
        auto *TermInst = IncomingBB->getTerminator();

        Value *Cond = nullptr;
        if (auto *BI = dyn_cast<BranchInst>(TermInst)) {
          // Must be a conditional branch, otherwise the block
          // should not be in the loop.
          Cond = BI->getCondition();
        } else if (auto *SI = dyn_cast<SwitchInst>(TermInst))
          Cond = SI->getCondition();
        else
          continue;

        if (!L->isLoopInvariant(Cond))
          continue;

        auto *ExitVal = dyn_cast<PHINode>(PN.getIncomingValue(IncomingValIdx));

        // Only deal with PHIs in the loop header.
        if (!ExitVal || ExitVal->getParent() != L->getHeader())
          continue;

        // If ExitVal is a PHI on the loop header, then we know its
        // value along this exit because the exit can only be taken
        // on the first iteration.
        auto *LoopPreheader = L->getLoopPreheader();
        assert(LoopPreheader && "Invalid loop");
        int PreheaderIdx = ExitVal->getBasicBlockIndex(LoopPreheader);
        if (PreheaderIdx != -1) {
          assert(ExitVal->getParent() == L->getHeader() &&
                 "ExitVal must be in loop header");
          MadeAnyChanges = true;
          PN.setIncomingValue(IncomingValIdx,
                              ExitVal->getIncomingValue(PreheaderIdx));
        }
      }
    }
  }
  return MadeAnyChanges;
}

//===----------------------------------------------------------------------===//
//  IV Widening - Extend the width of an IV to cover its widest uses.
//===----------------------------------------------------------------------===//

/// Update information about the induction variable that is extended by this
/// sign or zero extend operation. This is used to determine the final width of
/// the IV before actually widening it.
static void visitIVCast(CastInst *Cast, WideIVInfo &WI,
                        ScalarEvolution *SE,
                        const TargetTransformInfo *TTI) {
  bool IsSigned = Cast->getOpcode() == Instruction::SExt;
  if (!IsSigned && Cast->getOpcode() != Instruction::ZExt)
    return;

  Type *Ty = Cast->getType();
  uint64_t Width = SE->getTypeSizeInBits(Ty);
  if (!Cast->getModule()->getDataLayout().isLegalInteger(Width))
    return;

  // Check that `Cast` actually extends the induction variable (we rely on this
  // later).  This takes care of cases where `Cast` is extending a truncation of
  // the narrow induction variable, and thus can end up being narrower than the
  // "narrow" induction variable.
  uint64_t NarrowIVWidth = SE->getTypeSizeInBits(WI.NarrowIV->getType());
  if (NarrowIVWidth >= Width)
    return;

  // Cast is either an sext or zext up to this point.
  // We should not widen an indvar if arithmetic on the wider indvar is more
  // expensive than that on the narrower indvar. We check only the cost of ADD
  // because at least an ADD is required to increment the induction variable. We
  // could compute more comprehensively the cost of all instructions on the
  // induction variable when necessary.
  if (TTI &&
      TTI->getArithmeticInstrCost(Instruction::Add, Ty) >
          TTI->getArithmeticInstrCost(Instruction::Add,
                                      Cast->getOperand(0)->getType())) {
    return;
  }

  if (!WI.WidestNativeType) {
    WI.WidestNativeType = SE->getEffectiveSCEVType(Ty);
    WI.IsSigned = IsSigned;
    return;
  }

  // We extend the IV to satisfy the sign of its first user, arbitrarily.
  if (WI.IsSigned != IsSigned)
    return;

  if (Width > SE->getTypeSizeInBits(WI.WidestNativeType))
    WI.WidestNativeType = SE->getEffectiveSCEVType(Ty);
}

//===----------------------------------------------------------------------===//
//  Live IV Reduction - Minimize IVs live across the loop.
//===----------------------------------------------------------------------===//

//===----------------------------------------------------------------------===//
//  Simplification of IV users based on SCEV evaluation.
//===----------------------------------------------------------------------===//

namespace {

class IndVarSimplifyVisitor : public IVVisitor {
  ScalarEvolution *SE;
  const TargetTransformInfo *TTI;
  PHINode *IVPhi;

public:
  WideIVInfo WI;

  IndVarSimplifyVisitor(PHINode *IV, ScalarEvolution *SCEV,
                        const TargetTransformInfo *TTI,
                        const DominatorTree *DTree)
    : SE(SCEV), TTI(TTI), IVPhi(IV) {
    DT = DTree;
    WI.NarrowIV = IVPhi;
  }

  // Implement the interface used by simplifyUsersOfIV.
  void visitCast(CastInst *Cast) override { visitIVCast(Cast, WI, SE, TTI); }
};

} // end anonymous namespace

/// Iteratively perform simplification on a worklist of IV users. Each
/// successive simplification may push more users which may themselves be
/// candidates for simplification.
///
/// Sign/Zero extend elimination is interleaved with IV simplification.
bool IndVarSimplify::simplifyAndExtend(Loop *L,
                                       SCEVExpander &Rewriter,
                                       LoopInfo *LI) {
  SmallVector<WideIVInfo, 8> WideIVs;

  auto *GuardDecl = L->getBlocks()[0]->getModule()->getFunction(
          Intrinsic::getName(Intrinsic::experimental_guard));
  bool HasGuards = GuardDecl && !GuardDecl->use_empty();

  SmallVector<PHINode*, 8> LoopPhis;
  for (BasicBlock::iterator I = L->getHeader()->begin(); isa<PHINode>(I); ++I) {
    LoopPhis.push_back(cast<PHINode>(I));
  }
  // Each round of simplification iterates through the SimplifyIVUsers worklist
  // for all current phis, then determines whether any IVs can be
  // widened. Widening adds new phis to LoopPhis, inducing another round of
  // simplification on the wide IVs.
  bool Changed = false;
  while (!LoopPhis.empty()) {
    // Evaluate as many IV expressions as possible before widening any IVs. This
    // forces SCEV to set no-wrap flags before evaluating sign/zero
    // extension. The first time SCEV attempts to normalize sign/zero extension,
    // the result becomes final. So for the most predictable results, we delay
    // evaluation of sign/zero extension until needed, and avoid running other
    // SCEV-based analyses prior to simplifyAndExtend.
    do {
      PHINode *CurrIV = LoopPhis.pop_back_val();

      // Information about sign/zero extensions of CurrIV.
      IndVarSimplifyVisitor Visitor(CurrIV, SE, TTI, DT);

      Changed |= simplifyUsersOfIV(CurrIV, SE, DT, LI, TTI, DeadInsts, Rewriter,
                                   &Visitor);

      if (Visitor.WI.WidestNativeType) {
        WideIVs.push_back(Visitor.WI);
      }
    } while (!LoopPhis.empty());

    // Continue if we disallowed widening.
    if (!WidenIndVars)
      continue;

    for (; !WideIVs.empty(); WideIVs.pop_back()) {
      unsigned ElimExt;
      unsigned Widened;
      if (PHINode *WidePhi = createWideIV(WideIVs.back(), LI, SE, Rewriter,
                                          DT, DeadInsts, ElimExt, Widened,
                                          HasGuards, UsePostIncrementRanges)) {
        NumElimExt += ElimExt;
        NumWidened += Widened;
        Changed = true;
        LoopPhis.push_back(WidePhi);
      }
    }
  }
  return Changed;
}

//===----------------------------------------------------------------------===//
//  linearFunctionTestReplace and its kin. Rewrite the loop exit condition.
//===----------------------------------------------------------------------===//

/// Given a Value which is hoped to be part of an add recurrence in the given
/// loop, return the associated PHI node if so.  Otherwise, return null.  Note
/// that this is less general than SCEV's AddRec checking.
static PHINode *getLoopPhiForCounter(Value *IncV, Loop *L) {
  Instruction *IncI = dyn_cast<Instruction>(IncV);
  if (!IncI)
    return nullptr;

  switch (IncI->getOpcode()) {
  case Instruction::Add:
  case Instruction::Sub:
    break;
  case Instruction::GetElementPtr:
    // An IV counter must preserve its type.
    if (IncI->getNumOperands() == 2)
      break;
    LLVM_FALLTHROUGH;
  default:
    return nullptr;
  }

  PHINode *Phi = dyn_cast<PHINode>(IncI->getOperand(0));
  if (Phi && Phi->getParent() == L->getHeader()) {
    if (L->isLoopInvariant(IncI->getOperand(1)))
      return Phi;
    return nullptr;
  }
  if (IncI->getOpcode() == Instruction::GetElementPtr)
    return nullptr;

  // Allow add/sub to be commuted.
  Phi = dyn_cast<PHINode>(IncI->getOperand(1));
  if (Phi && Phi->getParent() == L->getHeader()) {
    if (L->isLoopInvariant(IncI->getOperand(0)))
      return Phi;
  }
  return nullptr;
}

/// Whether the current loop exit test is based on this value.  Currently this
/// is limited to a direct use in the loop condition.
static bool isLoopExitTestBasedOn(Value *V, BasicBlock *ExitingBB) {
  BranchInst *BI = cast<BranchInst>(ExitingBB->getTerminator());
  ICmpInst *ICmp = dyn_cast<ICmpInst>(BI->getCondition());
  // TODO: Allow non-icmp loop test.
  if (!ICmp)
    return false;

  // TODO: Allow indirect use.
  return ICmp->getOperand(0) == V || ICmp->getOperand(1) == V;
}

/// linearFunctionTestReplace policy. Return true unless we can show that the
/// current exit test is already sufficiently canonical.
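///
/// As an illustrative example, a test that is already canonical has the shape
/// (hypothetical IR):
///   %exitcond = icmp ne i64 %iv.next, %n   ; %n loop-invariant
///   br i1 %exitcond, label %loop, label %exit
/// where %iv.next is the latch increment of a simple counter phi in the loop
/// header; anything else (non-icmp test, non-EQ/NE predicate, non-invariant
/// RHS, or a non-counter LHS) makes this function return true.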
static bool needsLFTR(Loop *L, BasicBlock *ExitingBB) {
  assert(L->getLoopLatch() && "Must be in simplified form");

  // Avoid converting a constant or loop invariant test back to a runtime
  // test.  This is critical for when SCEV's cached ExitCount is less precise
  // than the current IR (such as after we've proven a particular exit is
  // actually dead and thus the BE count never reaches our ExitCount.)
  BranchInst *BI = cast<BranchInst>(ExitingBB->getTerminator());
  if (L->isLoopInvariant(BI->getCondition()))
    return false;

  // Do LFTR to simplify the exit condition to an ICMP.
  ICmpInst *Cond = dyn_cast<ICmpInst>(BI->getCondition());
  if (!Cond)
    return true;

  // Do LFTR to simplify the exit ICMP to EQ/NE
  ICmpInst::Predicate Pred = Cond->getPredicate();
  if (Pred != ICmpInst::ICMP_NE && Pred != ICmpInst::ICMP_EQ)
    return true;

  // Look for a loop invariant RHS
  Value *LHS = Cond->getOperand(0);
  Value *RHS = Cond->getOperand(1);
  if (!L->isLoopInvariant(RHS)) {
    if (!L->isLoopInvariant(LHS))
      return true;
    std::swap(LHS, RHS);
  }
  // Look for a simple IV counter LHS
  PHINode *Phi = dyn_cast<PHINode>(LHS);
  if (!Phi)
    Phi = getLoopPhiForCounter(LHS, L);

  if (!Phi)
    return true;

  // Do LFTR if PHI node is defined in the loop, but is *not* a counter.
  int Idx = Phi->getBasicBlockIndex(L->getLoopLatch());
  if (Idx < 0)
    return true;

  // Do LFTR if the exit condition's IV is *not* a simple counter.
  Value *IncV = Phi->getIncomingValue(Idx);
  return Phi != getLoopPhiForCounter(IncV, L);
}

/// Return true if undefined behavior would provably be executed on the path to
/// OnPathTo if Root produced a poison result.  Note that this doesn't say
/// anything about whether OnPathTo is actually executed or whether Root is
/// actually poison.  This can be used to assess whether a new use of Root can
/// be added at a location which is control equivalent with OnPathTo (such as
/// immediately before it) without introducing UB which didn't previously
/// exist.  Note that a false result conveys no information.
static bool mustExecuteUBIfPoisonOnPathTo(Instruction *Root,
                                          Instruction *OnPathTo,
                                          DominatorTree *DT) {
  // The basic approach is to assume Root is poison, propagate poison forward
  // through all users we can easily track, and then check whether any of those
  // users are provably UB and must execute before our exiting block might
  // exit.

  // The set of all recursive users we've visited (which are assumed to all be
  // poison because of said visit)
  SmallSet<const Value *, 16> KnownPoison;
  SmallVector<const Instruction*, 16> Worklist;
  Worklist.push_back(Root);
  while (!Worklist.empty()) {
    const Instruction *I = Worklist.pop_back_val();

    // If we know this must trigger UB on a path leading to our target.
    if (mustTriggerUB(I, KnownPoison) && DT->dominates(I, OnPathTo))
      return true;

    // If we can't analyze propagation through this instruction, just skip it
    // and transitive users.  Safe as false is a conservative result.
    if (!propagatesPoison(cast<Operator>(I)) && I != Root)
      continue;

    if (KnownPoison.insert(I).second)
      for (const User *User : I->users())
        Worklist.push_back(cast<Instruction>(User));
  }

  // Might be non-UB, or might have a path we couldn't prove must execute on
  // way to exiting bb.
  return false;
}

/// Recursive helper for hasConcreteDef(). Unfortunately, this currently boils
/// down to checking that all operands are constant and listing instructions
/// that may hide undef.
static bool hasConcreteDefImpl(Value *V, SmallPtrSetImpl<Value*> &Visited,
                               unsigned Depth) {
  if (isa<Constant>(V))
    return !isa<UndefValue>(V);

  if (Depth >= 6)
    return false;

  // Conservatively handle non-constant non-instructions. For example, Arguments
  // may be undef.
  Instruction *I = dyn_cast<Instruction>(V);
  if (!I)
    return false;

  // Load and return values may be undef.
  if (I->mayReadFromMemory() || isa<CallInst>(I) || isa<InvokeInst>(I))
    return false;

  // Optimistically handle other instructions.
  for (Value *Op : I->operands()) {
    if (!Visited.insert(Op).second)
      continue;
    if (!hasConcreteDefImpl(Op, Visited, Depth+1))
      return false;
  }
  return true;
}

/// Return true if the given value is concrete. We must prove that undef can
/// never reach it.
///
/// TODO: If we decide that this is a good approach to checking for undef, we
/// may factor it into a common location.
static bool hasConcreteDef(Value *V) {
  SmallPtrSet<Value*, 8> Visited;
  Visited.insert(V);
  return hasConcreteDefImpl(V, Visited, 0);
}

/// Return true if this IV has no uses other than the (soon to be rewritten)
/// loop exit test and its own increment; i.e., the IV is almost dead.
static bool AlmostDeadIV(PHINode *Phi, BasicBlock *LatchBlock, Value *Cond) {
  int LatchIdx = Phi->getBasicBlockIndex(LatchBlock);
  Value *IncV = Phi->getIncomingValue(LatchIdx);

  for (User *U : Phi->users())
    if (U != Cond && U != IncV) return false;

  for (User *U : IncV->users())
    if (U != Cond && U != Phi) return false;
  return true;
}

/// Return true if the given phi is a "counter" in L.  A counter is an
/// add recurrence (of integer or pointer type) with an arbitrary start, and a
/// step of 1.  Note that L must have exactly one latch.
static bool isLoopCounter(PHINode* Phi, Loop *L,
                          ScalarEvolution *SE) {
  assert(Phi->getParent() == L->getHeader());
  assert(L->getLoopLatch());

  if (!SE->isSCEVable(Phi->getType()))
    return false;

  const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(SE->getSCEV(Phi));
  if (!AR || AR->getLoop() != L || !AR->isAffine())
    return false;

  const SCEV *Step = dyn_cast<SCEVConstant>(AR->getStepRecurrence(*SE));
  if (!Step || !Step->isOne())
    return false;

  int LatchIdx = Phi->getBasicBlockIndex(L->getLoopLatch());
  Value *IncV = Phi->getIncomingValue(LatchIdx);
  return (getLoopPhiForCounter(IncV, L) == Phi);
}

/// Search the loop header for a loop counter (an add rec with a step of one)
/// suitable for use by LFTR.  If multiple counters are available, select the
/// "best" one based on profitability heuristics.
///
/// BECount may be an i8* pointer type. The pointer difference is already a
/// valid count without scaling the address stride, so it remains a pointer
/// expression as far as SCEV is concerned.
static PHINode *FindLoopCounter(Loop *L, BasicBlock *ExitingBB,
                                const SCEV *BECount,
                                ScalarEvolution *SE, DominatorTree *DT) {
  uint64_t BCWidth = SE->getTypeSizeInBits(BECount->getType());

  Value *Cond = cast<BranchInst>(ExitingBB->getTerminator())->getCondition();

  // Loop over all of the PHI nodes, looking for a simple counter.
  PHINode *BestPhi = nullptr;
  const SCEV *BestInit = nullptr;
  BasicBlock *LatchBlock = L->getLoopLatch();
  assert(LatchBlock && "Must be in simplified form");
  const DataLayout &DL = L->getHeader()->getModule()->getDataLayout();

  for (BasicBlock::iterator I = L->getHeader()->begin(); isa<PHINode>(I); ++I) {
    PHINode *Phi = cast<PHINode>(I);
    if (!isLoopCounter(Phi, L, SE))
      continue;

    // Avoid comparing an integer IV against a pointer Limit.
    if (BECount->getType()->isPointerTy() && !Phi->getType()->isPointerTy())
      continue;

    const auto *AR = cast<SCEVAddRecExpr>(SE->getSCEV(Phi));

    // AR may be a pointer type, while BECount is an integer type.
    // AR may be wider than BECount. With eq/ne tests overflow is immaterial.
    // AR may not be a narrower type, or we may never exit.
    uint64_t PhiWidth = SE->getTypeSizeInBits(AR->getType());
    if (PhiWidth < BCWidth || !DL.isLegalInteger(PhiWidth))
      continue;

    // Avoid reusing a potentially undef value to compute other values that may
    // have originally had a concrete definition.
    if (!hasConcreteDef(Phi)) {
      // We explicitly allow unknown phis as long as they are already used by
      // the loop exit test.  This is legal since performing LFTR could not
      // increase the number of undef users.
      Value *IncPhi = Phi->getIncomingValueForBlock(LatchBlock);
      if (!isLoopExitTestBasedOn(Phi, ExitingBB) &&
          !isLoopExitTestBasedOn(IncPhi, ExitingBB))
        continue;
    }

    // Avoid introducing undefined behavior due to poison which didn't exist in
    // the original program.  (Annoyingly, the rules for poison and undef
    // propagation are distinct, so this does NOT cover the undef case above.)
    // We have to ensure that we don't introduce UB by introducing a use on an
    // iteration where said IV produces poison.  Our strategy here differs for
    // pointers and integer IVs.  For integers, we strip and reinfer as needed,
    // see code in linearFunctionTestReplace.  For pointers, we restrict
    // transforms as there is no good way to reinfer inbounds once lost.
    if (!Phi->getType()->isIntegerTy() &&
        !mustExecuteUBIfPoisonOnPathTo(Phi, ExitingBB->getTerminator(), DT))
      continue;

    const SCEV *Init = AR->getStart();

    if (BestPhi && !AlmostDeadIV(BestPhi, LatchBlock, Cond)) {
      // Don't force a live loop counter if another IV can be used.
      if (AlmostDeadIV(Phi, LatchBlock, Cond))
        continue;

      // Prefer to count-from-zero. This is a more "canonical" counter form. It
      // also prefers integer to pointer IVs.
      if (BestInit->isZero() != Init->isZero()) {
        if (BestInit->isZero())
          continue;
      }
      // If two IVs both count from zero or both count from nonzero then the
      // narrower is likely a dead phi that has been widened. Use the wider phi
      // to allow the other to be eliminated.
      else if (PhiWidth <= SE->getTypeSizeInBits(BestPhi->getType()))
        continue;
    }
    BestPhi = Phi;
    BestInit = Init;
  }
  return BestPhi;
}

/// Insert an IR expression which computes the value held by the IV IndVar
/// (which must be a loop counter with unit stride) after the backedge of loop
/// L is taken ExitCount times.
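///
/// As an illustrative sketch of the arithmetic below: for a counter starting
/// at IVInit, the pre-increment limit is IVLimit = IVInit + ExitCount, and the
/// post-increment limit adds one more, IVLimit = IVInit + ExitCount + 1, both
/// with 2's complement wrapping.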
static Value *genLoopLimit(PHINode *IndVar, BasicBlock *ExitingBB,
                           const SCEV *ExitCount, bool UsePostInc, Loop *L,
                           SCEVExpander &Rewriter, ScalarEvolution *SE) {
  assert(isLoopCounter(IndVar, L, SE));
  const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(SE->getSCEV(IndVar));
  const SCEV *IVInit = AR->getStart();

  // IVInit may be a pointer while ExitCount is an integer when FindLoopCounter
  // finds a valid pointer IV. Sign extend ExitCount in order to materialize a
  // GEP. Avoid running SCEVExpander on a new pointer value, instead reusing
  // the existing GEPs whenever possible.
  if (IndVar->getType()->isPointerTy() &&
      !ExitCount->getType()->isPointerTy()) {
    // IVOffset will be the new GEP offset that is interpreted by GEP as a
    // signed value. ExitCount on the other hand represents the loop trip count,
    // which is an unsigned value. FindLoopCounter only allows induction
    // variables that have a positive unit stride of one. This means we don't
    // have to handle the case of negative offsets (yet) and just need to zero
    // extend ExitCount.
    Type *OfsTy = SE->getEffectiveSCEVType(IVInit->getType());
    const SCEV *IVOffset = SE->getTruncateOrZeroExtend(ExitCount, OfsTy);
    if (UsePostInc)
      IVOffset = SE->getAddExpr(IVOffset, SE->getOne(OfsTy));

    // Expand the code for the iteration count.
    assert(SE->isLoopInvariant(IVOffset, L) &&
           "Computed iteration count is not loop invariant!");

    // We could handle pointer IVs other than i8*, but we need to compensate for
    // gep index scaling.
    assert(SE->getSizeOfExpr(IntegerType::getInt64Ty(IndVar->getContext()),
                             cast<PointerType>(IndVar->getType())
                                 ->getElementType())->isOne() &&
           "unit stride pointer IV must be i8*");

    const SCEV *IVLimit = SE->getAddExpr(IVInit, IVOffset);
    BranchInst *BI = cast<BranchInst>(ExitingBB->getTerminator());
    return Rewriter.expandCodeFor(IVLimit, IndVar->getType(), BI);
  } else {
    // In any other case, convert both IVInit and ExitCount to integers before
    // comparing. This may result in SCEV expansion of pointers, but in practice
    // SCEV will fold the pointer arithmetic away as such:
    // BECount = (IVEnd - IVInit - 1) => IVLimit = IVInit (postinc).
    //
    // Valid Cases: (1) both integers is most common; (2) both may be pointers
    // for simple memset-style loops.
    //
    // IVInit integer and ExitCount pointer would only occur if a canonical IV
    // were generated on top of case #2, which is not expected.

    assert(AR->getStepRecurrence(*SE)->isOne() && "only handles unit stride");
    // For unit stride, IVCount = Start + ExitCount with 2's complement
    // overflow.

    // For integer IVs, truncate the IV before computing IVInit + BECount,
    // unless we know a priori that the limit must be a constant when evaluated
    // in the bitwidth of the IV.  We prefer (potentially) keeping a truncate
    // of the IV in the loop over a (potentially) expensive expansion of the
    // widened exit count add(zext(add)) expression.
    if (SE->getTypeSizeInBits(IVInit->getType())
        > SE->getTypeSizeInBits(ExitCount->getType())) {
      if (isa<SCEVConstant>(IVInit) && isa<SCEVConstant>(ExitCount))
        ExitCount = SE->getZeroExtendExpr(ExitCount, IVInit->getType());
      else
        IVInit = SE->getTruncateExpr(IVInit, ExitCount->getType());
    }

    const SCEV *IVLimit = SE->getAddExpr(IVInit, ExitCount);

    if (UsePostInc)
      IVLimit = SE->getAddExpr(IVLimit, SE->getOne(IVLimit->getType()));

    // Expand the code for the iteration count.
    assert(SE->isLoopInvariant(IVLimit, L) &&
           "Computed iteration count is not loop invariant!");
    // Ensure that we generate the same type as IndVar, or a smaller integer
    // type. In the presence of null pointer values, we have an integer type
    // SCEV expression (IVInit) for a pointer type IV value (IndVar).
    Type *LimitTy = ExitCount->getType()->isPointerTy() ?
      IndVar->getType() : ExitCount->getType();
    BranchInst *BI = cast<BranchInst>(ExitingBB->getTerminator());
    return Rewriter.expandCodeFor(IVLimit, LimitTy, BI);
  }
}

/// This method rewrites the exit condition of the loop to be a canonical !=
/// comparison against the incremented loop induction variable.  This pass is
/// able to rewrite the exit tests of any loop where the SCEV analysis can
/// determine a loop-invariant trip count of the loop, which is actually a much
/// broader range than just linear tests.
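///
/// As an illustrative example (hypothetical IR), an exit test such as
///   %cmp = icmp slt i32 %iv, %n
///   br i1 %cmp, label %loop, label %exit
/// is rewritten to
///   %exitcond = icmp ne i32 %iv.next, %limit
///   br i1 %exitcond, label %loop, label %exit
/// where %limit is the loop-invariant value produced by genLoopLimit above.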
bool IndVarSimplify::
linearFunctionTestReplace(Loop *L, BasicBlock *ExitingBB,
                          const SCEV *ExitCount,
                          PHINode *IndVar, SCEVExpander &Rewriter) {
  assert(L->getLoopLatch() && "Loop no longer in simplified form?");
  assert(isLoopCounter(IndVar, L, SE));
  Instruction * const IncVar =
    cast<Instruction>(IndVar->getIncomingValueForBlock(L->getLoopLatch()));

  // Initialize CmpIndVar to the preincremented IV.
  Value *CmpIndVar = IndVar;
  bool UsePostInc = false;

  // If the exiting block is the same as the backedge block, we prefer to
  // compare against the post-incremented value, otherwise we must compare
  // against the preincremented value.
  if (ExitingBB == L->getLoopLatch()) {
    // For pointer IVs, we chose to not strip inbounds which requires us not
    // to add a potentially UB introducing use.  We need to either a) show
    // the loop test we're modifying is already in post-inc form, or b) show
    // that adding a use must not introduce UB.
    bool SafeToPostInc =
        IndVar->getType()->isIntegerTy() ||
        isLoopExitTestBasedOn(IncVar, ExitingBB) ||
        mustExecuteUBIfPoisonOnPathTo(IncVar, ExitingBB->getTerminator(), DT);
    if (SafeToPostInc) {
      UsePostInc = true;
      CmpIndVar = IncVar;
    }
  }

  // It may be necessary to drop nowrap flags on the incrementing instruction
  // if either LFTR moves from a pre-inc check to a post-inc check (in which
  // case the increment might have previously been poison on the last iteration
  // only) or if LFTR switches to a different IV that was previously dynamically
  // dead (and as such may be arbitrarily poison). We remove any nowrap flags
  // that SCEV didn't infer for the post-inc addrec (even if we use a pre-inc
  // check), because the pre-inc addrec flags may be adopted from the original
  // instruction, while SCEV has to explicitly prove the post-inc nowrap flags.
  // TODO: This handling is inaccurate for one case: If we switch to a
  // dynamically dead IV that wraps on the first loop iteration only, which is
  // not covered by the post-inc addrec. (If the new IV was not dynamically
  // dead, it could not be poison on the first iteration in the first place.)
  if (auto *BO = dyn_cast<BinaryOperator>(IncVar)) {
    const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(SE->getSCEV(IncVar));
    if (BO->hasNoUnsignedWrap())
      BO->setHasNoUnsignedWrap(AR->hasNoUnsignedWrap());
    if (BO->hasNoSignedWrap())
      BO->setHasNoSignedWrap(AR->hasNoSignedWrap());
  }

  Value *ExitCnt = genLoopLimit(
      IndVar, ExitingBB, ExitCount, UsePostInc, L, Rewriter, SE);
  assert(ExitCnt->getType()->isPointerTy() ==
             IndVar->getType()->isPointerTy() &&
         "genLoopLimit missed a cast");

  // Insert a new icmp_ne or icmp_eq instruction before the branch.
  BranchInst *BI = cast<BranchInst>(ExitingBB->getTerminator());
  ICmpInst::Predicate P;
  if (L->contains(BI->getSuccessor(0)))
    P = ICmpInst::ICMP_NE;
  else
    P = ICmpInst::ICMP_EQ;

  IRBuilder<> Builder(BI);

  // The new loop exit condition should reuse the debug location of the
  // original loop exit condition.
  if (auto *Cond = dyn_cast<Instruction>(BI->getCondition()))
    Builder.SetCurrentDebugLocation(Cond->getDebugLoc());

  // For integer IVs, if we evaluated the limit in the narrower bitwidth to
  // avoid the expensive expansion of the limit expression in the wider type,
  // emit a truncate to narrow the IV to the ExitCount type.  This is safe
  // since we know (from the exit count bitwidth), that we can't self-wrap in
  // the narrower type.
  unsigned CmpIndVarSize = SE->getTypeSizeInBits(CmpIndVar->getType());
  unsigned ExitCntSize = SE->getTypeSizeInBits(ExitCnt->getType());
  if (CmpIndVarSize > ExitCntSize) {
    assert(!CmpIndVar->getType()->isPointerTy() &&
           !ExitCnt->getType()->isPointerTy());

    // Before resorting to actually inserting the truncate, use the same
    // reasoning as from SimplifyIndvar::eliminateTrunc to see if we can extend
    // the other side of the comparison instead.  We still evaluate the limit
    // in the narrower bitwidth, we just prefer a zext/sext outside the loop to
    // a truncate within it.
    bool Extended = false;
    const SCEV *IV = SE->getSCEV(CmpIndVar);
    const SCEV *TruncatedIV = SE->getTruncateExpr(SE->getSCEV(CmpIndVar),
                                                  ExitCnt->getType());
    const SCEV *ZExtTrunc =
      SE->getZeroExtendExpr(TruncatedIV, CmpIndVar->getType());

    if (ZExtTrunc == IV) {
      Extended = true;
      ExitCnt = Builder.CreateZExt(ExitCnt, IndVar->getType(),
                                   "wide.trip.count");
    } else {
      const SCEV *SExtTrunc =
        SE->getSignExtendExpr(TruncatedIV, CmpIndVar->getType());
      if (SExtTrunc == IV) {
        Extended = true;
        ExitCnt = Builder.CreateSExt(ExitCnt, IndVar->getType(),
                                     "wide.trip.count");
      }
    }

    if (Extended) {
      bool Discard;
      L->makeLoopInvariant(ExitCnt, Discard);
    } else
      CmpIndVar = Builder.CreateTrunc(CmpIndVar, ExitCnt->getType(),
                                      "lftr.wideiv");
  }
  LLVM_DEBUG(dbgs() << "INDVARS: Rewriting loop exit condition to:\n"
                    << "      LHS:" << *CmpIndVar << '\n'
                    << "       op:\t" << (P == ICmpInst::ICMP_NE ? "!=" : "==")
                    << "\n"
                    << "      RHS:\t" << *ExitCnt << "\n"
                    << "ExitCount:\t" << *ExitCount << "\n"
                    << "  was: " << *BI->getCondition() << "\n");

  Value *Cond = Builder.CreateICmp(P, CmpIndVar, ExitCnt, "exitcond");
  Value *OrigCond = BI->getCondition();
  // It's tempting to use replaceAllUsesWith here to fully replace the old
  // comparison, but that's not immediately safe, since users of the old
  // comparison may not be dominated by the new comparison. Instead, just
  // update the branch to use the new comparison; in the common case this
  // will make the old comparison dead.
  BI->setCondition(Cond);
  DeadInsts.emplace_back(OrigCond);

  ++NumLFTR;
  return true;
}

//===----------------------------------------------------------------------===//
//  sinkUnusedInvariants. A late subpass to cleanup loop preheaders.
//===----------------------------------------------------------------------===//

/// If there's a single exit block, sink any loop-invariant values that
/// were defined in the preheader but not used inside the loop into the
/// exit block to reduce register pressure in the loop.
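///
/// As an illustrative sketch (hypothetical IR): a preheader instruction
///   %a = add i32 %x, %y
/// with no uses in or before the loop can be moved into the unique exit
/// block, so %a is no longer live across the loop body.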
sinkUnusedInvariants(Loop * L)1207 bool IndVarSimplify::sinkUnusedInvariants(Loop *L) {
  BasicBlock *ExitBlock = L->getExitBlock();
  if (!ExitBlock) return false;

  BasicBlock *Preheader = L->getLoopPreheader();
  if (!Preheader) return false;

  bool MadeAnyChanges = false;
  BasicBlock::iterator InsertPt = ExitBlock->getFirstInsertionPt();
  BasicBlock::iterator I(Preheader->getTerminator());
  while (I != Preheader->begin()) {
    --I;
    // New instructions were inserted at the end of the preheader.
    if (isa<PHINode>(I))
      break;

    // Don't move instructions which might have side effects, since the side
    // effects need to complete before instructions inside the loop.  Also
    // don't move instructions which might read memory, since the loop may
    // modify memory. Note that it's okay if the instruction might have
    // undefined behavior: LoopSimplify guarantees that the preheader
    // dominates the exit block.
    if (I->mayHaveSideEffects() || I->mayReadFromMemory())
      continue;

    // Skip debug info intrinsics.
    if (isa<DbgInfoIntrinsic>(I))
      continue;

    // Skip eh pad instructions.
    if (I->isEHPad())
      continue;

    // Don't sink allocas: we never want to sink static allocas out of the
    // entry block, and correctly sinking dynamic allocas requires
    // checks for stacksave/stackrestore intrinsics.
    // FIXME: Refactor this check somehow?
    if (isa<AllocaInst>(I))
      continue;

    // Determine if there is a use in or before the loop (direct or
    // otherwise).
    bool UsedInLoop = false;
    for (Use &U : I->uses()) {
      Instruction *User = cast<Instruction>(U.getUser());
      BasicBlock *UseBB = User->getParent();
      if (PHINode *P = dyn_cast<PHINode>(User)) {
        unsigned i =
          PHINode::getIncomingValueNumForOperand(U.getOperandNo());
        UseBB = P->getIncomingBlock(i);
      }
      if (UseBB == Preheader || L->contains(UseBB)) {
        UsedInLoop = true;
        break;
      }
    }

    // If there is, the def must remain in the preheader.
    if (UsedInLoop)
      continue;

    // Otherwise, sink it to the exit block.
    Instruction *ToMove = &*I;
    bool Done = false;

    if (I != Preheader->begin()) {
      // Skip debug info intrinsics.
      do {
        --I;
      } while (isa<DbgInfoIntrinsic>(I) && I != Preheader->begin());

      if (isa<DbgInfoIntrinsic>(I) && I == Preheader->begin())
        Done = true;
    } else {
      Done = true;
    }

    MadeAnyChanges = true;
    ToMove->moveBefore(*ExitBlock, InsertPt);
    if (Done) break;
    InsertPt = ToMove->getIterator();
  }

  return MadeAnyChanges;
}

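/// Install \p NewCond as the condition of the exit branch \p BI, queueing the
/// old condition for later deletion if nothing else uses it.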
static void replaceExitCond(BranchInst *BI, Value *NewCond,
                            SmallVectorImpl<WeakTrackingVH> &DeadInsts) {
  auto *OldCond = BI->getCondition();
  BI->setCondition(NewCond);
  if (OldCond->use_empty())
    DeadInsts.emplace_back(OldCond);
}

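/// Fold the exit branch of \p ExitingBB to a constant: always taken if
/// \p IsTaken, never taken otherwise.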
static void foldExit(const Loop *L, BasicBlock *ExitingBB, bool IsTaken,
                     SmallVectorImpl<WeakTrackingVH> &DeadInsts) {
  BranchInst *BI = cast<BranchInst>(ExitingBB->getTerminator());
  bool ExitIfTrue = !L->contains(*succ_begin(ExitingBB));
  auto *OldCond = BI->getCondition();
  auto *NewCond =
      ConstantInt::get(OldCond->getType(), IsTaken ? ExitIfTrue : !ExitIfTrue);
  replaceExitCond(BI, NewCond, DeadInsts);
}

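/// Replace the condition of the exit branch in \p ExitingBB with a freshly
/// expanded, loop-invariant 'InvariantLHS InvariantPred InvariantRHS' check.
/// The predicate is given in "condition holds => stay in the loop" form, so
/// it is inverted here when the block exits on a true condition.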
static void replaceWithInvariantCond(
    const Loop *L, BasicBlock *ExitingBB, ICmpInst::Predicate InvariantPred,
    const SCEV *InvariantLHS, const SCEV *InvariantRHS, SCEVExpander &Rewriter,
    SmallVectorImpl<WeakTrackingVH> &DeadInsts) {
  BranchInst *BI = cast<BranchInst>(ExitingBB->getTerminator());
  Rewriter.setInsertPoint(BI);
  auto *LHSV = Rewriter.expandCodeFor(InvariantLHS);
  auto *RHSV = Rewriter.expandCodeFor(InvariantRHS);
  bool ExitIfTrue = !L->contains(*succ_begin(ExitingBB));
  if (ExitIfTrue)
    InvariantPred = ICmpInst::getInversePredicate(InvariantPred);
  IRBuilder<> Builder(BI);
  auto *NewCond = Builder.CreateICmp(InvariantPred, LHSV, RHSV,
                                     BI->getCondition()->getName());
  replaceExitCond(BI, NewCond, DeadInsts);
}

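/// Try to simplify an exit whose exact count SCEV could not compute. If the
/// (normalized) branch condition can be proven to hold on every iteration,
/// the exit is folded away entirely. Otherwise, we ask SCEV for a
/// loop-invariant condition that is equivalent to the original one during
/// the first \p MaxIter iterations and rewrite the branch to use it.
///
/// For example (an illustrative sketch with a loop-invariant 'cap'):
///
///   for (i = 0; i != len; i++)       // runs at most 'len' iterations
///     if (i u>= cap) break;          // exit with unknown exit count
///
/// Since 'i' only increases, 'i u< cap' holds throughout the first 'len'
/// iterations iff 'len - 1 u< cap' holds, which is loop-invariant.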
static bool optimizeLoopExitWithUnknownExitCount(
    const Loop *L, BranchInst *BI, BasicBlock *ExitingBB,
    const SCEV *MaxIter, bool Inverted, bool SkipLastIter,
    ScalarEvolution *SE, SCEVExpander &Rewriter,
    SmallVectorImpl<WeakTrackingVH> &DeadInsts) {
  ICmpInst::Predicate Pred;
  Value *LHS, *RHS;
  using namespace PatternMatch;
  BasicBlock *TrueSucc, *FalseSucc;
  if (!match(BI, m_Br(m_ICmp(Pred, m_Value(LHS), m_Value(RHS)),
                      m_BasicBlock(TrueSucc), m_BasicBlock(FalseSucc))))
    return false;

  assert((L->contains(TrueSucc) != L->contains(FalseSucc)) &&
         "Not a loop exit!");

  // 'LHS pred RHS' should now mean that we stay in the loop.
  if (L->contains(FalseSucc))
    Pred = CmpInst::getInversePredicate(Pred);

  // If we are proving loop exit, invert the predicate.
  if (Inverted)
    Pred = CmpInst::getInversePredicate(Pred);

  const SCEV *LHSS = SE->getSCEVAtScope(LHS, L);
  const SCEV *RHSS = SE->getSCEVAtScope(RHS, L);
  // Can we prove it to be trivially true?
  if (SE->isKnownPredicateAt(Pred, LHSS, RHSS, BI)) {
    foldExit(L, ExitingBB, Inverted, DeadInsts);
    return true;
  }
  // The logic below works only for the non-inverted condition.
  if (Inverted)
    return false;

  auto *ARTy = LHSS->getType();
  auto *MaxIterTy = MaxIter->getType();
  // If possible, adjust types.
  if (SE->getTypeSizeInBits(ARTy) > SE->getTypeSizeInBits(MaxIterTy))
    MaxIter = SE->getZeroExtendExpr(MaxIter, ARTy);
  else if (SE->getTypeSizeInBits(ARTy) < SE->getTypeSizeInBits(MaxIterTy)) {
    const SCEV *MinusOne = SE->getMinusOne(ARTy);
    auto *MaxAllowedIter = SE->getZeroExtendExpr(MinusOne, MaxIterTy);
    if (SE->isKnownPredicateAt(ICmpInst::ICMP_ULE, MaxIter, MaxAllowedIter, BI))
      MaxIter = SE->getTruncateExpr(MaxIter, ARTy);
  }

  if (SkipLastIter) {
    const SCEV *One = SE->getOne(MaxIter->getType());
    MaxIter = SE->getMinusSCEV(MaxIter, One);
  }

  // Check if there is a loop-invariant predicate equivalent to our check.
  auto LIP = SE->getLoopInvariantExitCondDuringFirstIterations(Pred, LHSS, RHSS,
                                                               L, BI, MaxIter);
  if (!LIP)
    return false;

  // Can we prove it to be trivially true?
  if (SE->isKnownPredicateAt(LIP->Pred, LIP->LHS, LIP->RHS, BI))
    foldExit(L, ExitingBB, Inverted, DeadInsts);
  else
    replaceWithInvariantCond(L, ExitingBB, LIP->Pred, LIP->LHS, LIP->RHS,
                             Rewriter, DeadInsts);

  return true;
}

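/// Walk the rewriteable exiting blocks of \p L in dominance order and
/// simplify their exit tests using known exit counts: an exit which provably
/// fires on the first iteration is rewritten to exit unconditionally, an
/// exit which provably cannot fire before a dominating exit (or whose count
/// duplicates a dominating one) is rewritten to never fire, and exits with
/// unknown counts are handed to optimizeLoopExitWithUnknownExitCount.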
bool IndVarSimplify::optimizeLoopExits(Loop *L, SCEVExpander &Rewriter) {
  SmallVector<BasicBlock*, 16> ExitingBlocks;
  L->getExitingBlocks(ExitingBlocks);

  // Remove all exits which aren't both rewriteable and guaranteed to execute
  // on every iteration.
  llvm::erase_if(ExitingBlocks, [&](BasicBlock *ExitingBB) {
    // If our exiting block exits multiple loops, we can only rewrite the
    // innermost one.  Otherwise, we're changing how many times the innermost
    // loop runs before it exits.
    if (LI->getLoopFor(ExitingBB) != L)
      return true;

    // Can't rewrite non-branch yet.
    BranchInst *BI = dyn_cast<BranchInst>(ExitingBB->getTerminator());
    if (!BI)
      return true;

    // If already constant, nothing to do.
    if (isa<Constant>(BI->getCondition()))
      return true;

    // Likewise, the loop latch must be dominated by the exiting BB.
    if (!DT->dominates(ExitingBB, L->getLoopLatch()))
      return true;

    return false;
  });

  if (ExitingBlocks.empty())
    return false;

  // Get a symbolic upper bound on the loop backedge taken count.
  const SCEV *MaxExitCount = SE->getSymbolicMaxBackedgeTakenCount(L);
  if (isa<SCEVCouldNotCompute>(MaxExitCount))
    return false;

  // Visit our exit blocks in order of dominance. Since all exits must
  // dominate the latch, there is a total dominance order between them.
  llvm::sort(ExitingBlocks, [&](BasicBlock *A, BasicBlock *B) {
               // std::sort sorts in ascending order, so we want the inverse of
               // the normal dominance relation.
               if (A == B) return false;
               if (DT->properlyDominates(A, B))
                 return true;
               else {
                 assert(DT->properlyDominates(B, A) &&
                        "expected total dominance order!");
                 return false;
               }
  });
#ifndef NDEBUG
  for (unsigned i = 1; i < ExitingBlocks.size(); i++) {
    assert(DT->dominates(ExitingBlocks[i-1], ExitingBlocks[i]));
  }
#endif

  bool Changed = false;
  bool SkipLastIter = false;
  SmallSet<const SCEV*, 8> DominatingExitCounts;
  for (BasicBlock *ExitingBB : ExitingBlocks) {
    const SCEV *ExitCount = SE->getExitCount(L, ExitingBB);
    if (isa<SCEVCouldNotCompute>(ExitCount)) {
      // Okay, we do not know the exit count here. Can we at least prove that
      // it will remain the same within the iteration space?
      auto *BI = cast<BranchInst>(ExitingBB->getTerminator());
      auto OptimizeCond = [&](bool Inverted, bool SkipLastIter) {
        return optimizeLoopExitWithUnknownExitCount(
            L, BI, ExitingBB, MaxExitCount, Inverted, SkipLastIter, SE,
            Rewriter, DeadInsts);
      };

      // TODO: We might have proved that we can skip the last iteration for
      // this check. In this case, we only want to check the condition on the
      // pre-last iteration (MaxExitCount - 1). However, there is a nasty
      // corner case:
      //
      //   for (i = len; i != 0; i--) { ... check (i ult X) ... }
      //
      // If we could not prove that len != 0, then we also could not prove
      // that (len - 1) is not UINT_MAX. If we simply query (len - 1), then
      // OptimizeCond will likely not prove anything for it, even if it could
      // prove the same fact for len.
      //
      // As a temporary solution, we query both the last and the pre-last
      // iterations in the hope that we will be able to prove triviality for
      // at least one of them. We can stop querying MaxExitCount for this case
      // once SCEV understands that (MaxExitCount - 1) will not overflow here.
      if (OptimizeCond(false, false) || OptimizeCond(true, false))
        Changed = true;
      else if (SkipLastIter)
        if (OptimizeCond(false, true) || OptimizeCond(true, true))
          Changed = true;
      continue;
    }

    if (MaxExitCount == ExitCount)
      // If the loop has more than one iteration, all further checks will be
      // executed one iteration less.
      SkipLastIter = true;

    // If we know we'd exit on the first iteration, rewrite the exit to
    // reflect this.  This does not imply the loop must exit through this
    // exit; there may be an earlier one taken on the first iteration.
    // TODO: Given we know the backedge can't be taken, we should go ahead
    // and break it.  Or at least, kill all the header phis and simplify.
    if (ExitCount->isZero()) {
      foldExit(L, ExitingBB, true, DeadInsts);
      Changed = true;
      continue;
    }

    // If we end up with a pointer exit count, bail.  Note that we can end up
    // with a pointer exit count for one exiting block, and not for another in
    // the same loop.
    if (!ExitCount->getType()->isIntegerTy() ||
        !MaxExitCount->getType()->isIntegerTy())
      continue;

    Type *WiderType =
      SE->getWiderType(MaxExitCount->getType(), ExitCount->getType());
    ExitCount = SE->getNoopOrZeroExtend(ExitCount, WiderType);
    MaxExitCount = SE->getNoopOrZeroExtend(MaxExitCount, WiderType);
    assert(MaxExitCount->getType() == ExitCount->getType());

    // Can we prove that some other exit must be taken strictly before this
    // one?
    if (SE->isLoopEntryGuardedByCond(L, CmpInst::ICMP_ULT,
                                     MaxExitCount, ExitCount)) {
      foldExit(L, ExitingBB, false, DeadInsts);
      Changed = true;
      continue;
    }

    // As we run, keep track of which exit counts we've encountered.  If we
    // find a duplicate, we've found an exit which would have exited on the
    // exiting iteration, but (from the visit order) strictly follows another
    // which does the same and is thus dead.
    if (!DominatingExitCounts.insert(ExitCount).second) {
      foldExit(L, ExitingBB, false, DeadInsts);
      Changed = true;
      continue;
    }

    // TODO: There might be another opportunity to leverage SCEV's reasoning
    // here.  If we kept track of the min of dominating exits seen so far, we
    // could discharge exits with EC >= MDEC. This is less powerful than the
    // existing transform (since later exits aren't considered), but
    // potentially more powerful for any case where SCEV can prove a >=u b,
    // but neither a == b nor a >u b.  Such a case is not currently known.
  }
  return Changed;
}

bool IndVarSimplify::predicateLoopExits(Loop *L, SCEVExpander &Rewriter) {
  SmallVector<BasicBlock*, 16> ExitingBlocks;
  L->getExitingBlocks(ExitingBlocks);

  // See if we can rewrite our exit conditions into a loop-invariant form. If
  // we have a read-only loop, and we can tell that we must exit down a path
  // which does not need any of the values computed within the loop, we can
  // rewrite the loop to exit on the first iteration.  Note that this neither
  // a) tells us the loop must exit on the first iteration (unless *all*
  // exits are predicatable) nor b) tells us *which* exit might be taken.
  // This transformation looks a lot like a restricted form of dead loop
  // elimination, but restricted to read-only loops and without necessarily
  // needing to kill the loop entirely.
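  //
  // For example (an illustrative sketch with loop-invariant 'a' and 'b'):
  //
  //   for (i = 0; i != a; i++)
  //     if (i == b) break;
  //
  // is read-only with exact backedge-taken count umin(a, b), so each exit
  // test can be replaced by a preheader-computed comparison of that exit's
  // count against umin(a, b); the loop then exits on its first iteration
  // down whichever path it would eventually have taken.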
  if (!LoopPredication)
    return false;

  if (!SE->hasLoopInvariantBackedgeTakenCount(L))
    return false;

  // Note: ExactBTC is the exact backedge taken count *iff* the loop exits
  // through *explicit* control flow.  We have to eliminate the possibility of
  // implicit exits (see below) before we know it's truly exact.
  const SCEV *ExactBTC = SE->getBackedgeTakenCount(L);
  if (isa<SCEVCouldNotCompute>(ExactBTC) ||
      !SE->isLoopInvariant(ExactBTC, L) ||
      !isSafeToExpand(ExactBTC, *SE))
    return false;

  // If we end up with a pointer exit count, bail.  It may be unsized.
  if (!ExactBTC->getType()->isIntegerTy())
    return false;

  auto BadExit = [&](BasicBlock *ExitingBB) {
    // If our exiting block exits multiple loops, we can only rewrite the
    // innermost one.  Otherwise, we're changing how many times the innermost
    // loop runs before it exits.
    if (LI->getLoopFor(ExitingBB) != L)
      return true;

    // Can't rewrite non-branch yet.
    BranchInst *BI = dyn_cast<BranchInst>(ExitingBB->getTerminator());
    if (!BI)
      return true;

    // If already constant, nothing to do.
    if (isa<Constant>(BI->getCondition()))
      return true;

    // If the exit block has phis, we need to be able to compute the values
    // within the loop which contains them.  This assumes trivial LCSSA phis
    // have already been removed; TODO: generalize.
    BasicBlock *ExitBlock =
        BI->getSuccessor(L->contains(BI->getSuccessor(0)) ? 1 : 0);
    if (!ExitBlock->phis().empty())
      return true;

    const SCEV *ExitCount = SE->getExitCount(L, ExitingBB);
    assert(!isa<SCEVCouldNotCompute>(ExactBTC) &&
           "implied by having an exact trip count");
    if (!SE->isLoopInvariant(ExitCount, L) ||
        !isSafeToExpand(ExitCount, *SE))
      return true;

    // If we end up with a pointer exit count, bail.  It may be unsized.
    if (!ExitCount->getType()->isIntegerTy())
      return true;

    return false;
  };

  // If we have any exits which can't be predicated themselves, then we can't
  // predicate any exit which isn't guaranteed to execute before it.  Consider
  // two exits (a) and (b) which would both exit on the same iteration.  If we
  // can predicate (b), but not (a), and (a) precedes (b) along some path, then
  // we could convert a loop from exiting through (a) to one exiting through
  // (b).  Note that this problem exists only for exits with the same exit
  // count, and we could be more aggressive when exit counts are known to be
  // unequal.
  llvm::sort(ExitingBlocks,
            [&](BasicBlock *A, BasicBlock *B) {
              // std::sort sorts in ascending order, so we want the inverse of
              // the normal dominance relation, plus a tie breaker for blocks
              // unordered by dominance.
              if (DT->properlyDominates(A, B)) return true;
              if (DT->properlyDominates(B, A)) return false;
              return A->getName() < B->getName();
            });
  // Check to see if our exit blocks are a total order (i.e. a linear chain of
  // exits before the backedge).  If they aren't, reasoning about reachability
  // is complicated and we choose not to for now.
  for (unsigned i = 1; i < ExitingBlocks.size(); i++)
    if (!DT->dominates(ExitingBlocks[i-1], ExitingBlocks[i]))
      return false;

  // Given our sorted total order, we know that exit[j] must be evaluated
  // after all exit[i] such that j > i.
  for (unsigned i = 0, e = ExitingBlocks.size(); i < e; i++)
    if (BadExit(ExitingBlocks[i])) {
      ExitingBlocks.resize(i);
      break;
    }

  if (ExitingBlocks.empty())
    return false;

  // We rely on not being able to reach an exiting block on a later iteration
  // than its statically computed exit count.  The implementation of
  // getExitCount currently guarantees this invariant, but assert it here so
  // that breakage is obvious if this ever changes.
  assert(llvm::all_of(ExitingBlocks, [&](BasicBlock *ExitingBB) {
        return DT->dominates(ExitingBB, L->getLoopLatch());
      }));

  // At this point, ExitingBlocks consists of only those blocks which are
  // predicatable.  Given that, we know we have at least one exit we can
  // predicate if the loop doesn't have side effects and doesn't have any
  // implicit exits (because then our exact BTC isn't actually exact).
  // FIXME: As structured, this is O(I^2) for loop nests.  We could bail out
  // for outer loops, but that seems less than ideal.  MemorySSA can find
  // memory writes; is that enough for *all* side effects?
  for (BasicBlock *BB : L->blocks())
    for (auto &I : *BB)
      // TODO: isGuaranteedToTransferExecutionToSuccessor
      if (I.mayHaveSideEffects() || I.mayThrow())
        return false;

  bool Changed = false;
  // Finally, do the actual predication for all predicatable blocks.  A couple
  // of notes here:
  // 1) We don't bother to constant fold dominated exits with identical exit
  //    counts; that's simply a form of CSE/equality propagation and we leave
  //    it for dedicated passes.
  // 2) We insert the comparison at the branch.  Hoisting introduces additional
  //    legality constraints and we leave that to dedicated logic.  We want to
  //    predicate even if we can't insert a loop-invariant expression, as
  //    peeling or unrolling will likely reduce the cost of the otherwise
  //    loop-varying check.
  Rewriter.setInsertPoint(L->getLoopPreheader()->getTerminator());
  IRBuilder<> B(L->getLoopPreheader()->getTerminator());
  Value *ExactBTCV = nullptr; // Lazily generated if needed.
  for (BasicBlock *ExitingBB : ExitingBlocks) {
    const SCEV *ExitCount = SE->getExitCount(L, ExitingBB);

    auto *BI = cast<BranchInst>(ExitingBB->getTerminator());
    Value *NewCond;
    if (ExitCount == ExactBTC) {
      NewCond = L->contains(BI->getSuccessor(0)) ?
        B.getFalse() : B.getTrue();
    } else {
      Value *ECV = Rewriter.expandCodeFor(ExitCount);
      if (!ExactBTCV)
        ExactBTCV = Rewriter.expandCodeFor(ExactBTC);
      Value *RHS = ExactBTCV;
      if (ECV->getType() != RHS->getType()) {
        Type *WiderTy = SE->getWiderType(ECV->getType(), RHS->getType());
        ECV = B.CreateZExt(ECV, WiderTy);
        RHS = B.CreateZExt(RHS, WiderTy);
      }
      auto Pred = L->contains(BI->getSuccessor(0)) ?
        ICmpInst::ICMP_NE : ICmpInst::ICMP_EQ;
      NewCond = B.CreateICmp(Pred, ECV, RHS);
    }
    Value *OldCond = BI->getCondition();
    BI->setCondition(NewCond);
    if (OldCond->use_empty())
      DeadInsts.emplace_back(OldCond);
    Changed = true;
  }

  return Changed;
}

//===----------------------------------------------------------------------===//
//  IndVarSimplify driver. Manage several subpasses of IV simplification.
//===----------------------------------------------------------------------===//

bool IndVarSimplify::run(Loop *L) {
  // We need (and expect!) the incoming loop to be in LCSSA.
  assert(L->isRecursivelyLCSSAForm(*DT, *LI) &&
         "LCSSA required to run indvars!");

  // If LoopSimplify form is not available, stay out of trouble. Some notes:
  //  - LSR currently only supports LoopSimplify-form loops. Indvars'
  //    canonicalization can be a pessimization without LSR to "clean up"
  //    afterwards.
  //  - We depend on having a preheader; in particular,
  //    Loop::getCanonicalInductionVariable only supports loops with preheaders,
  //    and we're in trouble if we can't find the induction variable even when
  //    we've manually inserted one.
  //  - LFTR relies on having a single backedge.
  if (!L->isLoopSimplifyForm())
    return false;

#ifndef NDEBUG
  // Used below for a consistency check only.
  // Note: Since the result returned by ScalarEvolution may depend on the order
  // in which previous results are added to its cache, the call to
  // getBackedgeTakenCount() may change the results of following SCEV queries.
  const SCEV *BackedgeTakenCount;
  if (VerifyIndvars)
    BackedgeTakenCount = SE->getBackedgeTakenCount(L);
#endif

  bool Changed = false;
  // If there are any floating-point recurrences, attempt to
  // transform them to use integer recurrences.
  Changed |= rewriteNonIntegerIVs(L);

  // Create a rewriter object which we'll use to transform the code with.
  SCEVExpander Rewriter(*SE, DL, "indvars");
#ifndef NDEBUG
  Rewriter.setDebugType(DEBUG_TYPE);
#endif

  // Eliminate redundant IV users.
  //
  // Simplification works best when run before other consumers of SCEV. We
  // attempt to avoid evaluating SCEVs for sign/zero extend operations until
  // other expressions involving loop IVs have been evaluated. This helps SCEV
  // set no-wrap flags before normalizing sign/zero extension.
  Rewriter.disableCanonicalMode();
  Changed |= simplifyAndExtend(L, Rewriter, LI);

  // Check to see if we can compute the final value of any expressions
  // that are recurrent in the loop, and substitute the exit values from the
  // loop into any instructions outside of the loop that use the final values
  // of the current expressions.
  if (ReplaceExitValue != NeverRepl) {
    if (int Rewrites = rewriteLoopExitValues(L, LI, TLI, SE, TTI, Rewriter, DT,
                                             ReplaceExitValue, DeadInsts)) {
      NumReplaced += Rewrites;
      Changed = true;
    }
  }

  // Eliminate redundant IV cycles.
  NumElimIV += Rewriter.replaceCongruentIVs(L, DT, DeadInsts);

  // Try to eliminate loop exits based on analyzable exit counts.
  if (optimizeLoopExits(L, Rewriter)) {
    Changed = true;
    // Given we've changed exit counts, notify SCEV. Some nested loops may
    // share the same folded exit basic block, so we need to notify the
    // topmost loop.
    SE->forgetTopmostLoop(L);
  }

  // Try to form loop-invariant tests for loop exits by changing how many
  // iterations of the loop run when that is unobservable.
  if (predicateLoopExits(L, Rewriter)) {
    Changed = true;
    // Given we've changed exit counts, notify SCEV.
    SE->forgetLoop(L);
  }

  // If we have a trip count expression, rewrite the loop's exit condition
  // using it.
  if (!DisableLFTR) {
    BasicBlock *PreHeader = L->getLoopPreheader();
    BranchInst *PreHeaderBR = cast<BranchInst>(PreHeader->getTerminator());

    SmallVector<BasicBlock*, 16> ExitingBlocks;
    L->getExitingBlocks(ExitingBlocks);
    for (BasicBlock *ExitingBB : ExitingBlocks) {
      // Can't rewrite non-branch yet.
      if (!isa<BranchInst>(ExitingBB->getTerminator()))
        continue;

      // If our exiting block exits multiple loops, we can only rewrite the
      // innermost one.  Otherwise, we're changing how many times the innermost
      // loop runs before it exits.
      if (LI->getLoopFor(ExitingBB) != L)
        continue;

      if (!needsLFTR(L, ExitingBB))
        continue;

      const SCEV *ExitCount = SE->getExitCount(L, ExitingBB);
      if (isa<SCEVCouldNotCompute>(ExitCount))
        continue;

      // This was handled above, but as we form SCEVs, we can sometimes refine
      // existing ones; this allows exit counts to be folded to zero which
      // weren't when optimizeLoopExits saw them.  Arguably, we should iterate
      // until stable to handle cases like this better.
      if (ExitCount->isZero())
        continue;

      PHINode *IndVar = FindLoopCounter(L, ExitingBB, ExitCount, SE, DT);
      if (!IndVar)
        continue;

      // Avoid high cost expansions.  Note: This heuristic is questionable in
      // that our definition of "high cost" is not exactly principled.
      if (Rewriter.isHighCostExpansion(ExitCount, L, SCEVCheapExpansionBudget,
                                       TTI, PreHeaderBR))
        continue;

      // Check preconditions for proper SCEVExpander operation. SCEV does not
      // express SCEVExpander's dependencies, such as LoopSimplify. Instead
      // any pass that uses the SCEVExpander must do it. This does not work
      // well for loop passes because SCEVExpander makes assumptions about
      // all loops, while LoopPassManager only forces the current loop to be
      // simplified.
      //
      // FIXME: SCEV expansion has no way to bail out, so the caller must
      // explicitly check any assumptions made by SCEV. Brittle.
      const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(ExitCount);
      if (!AR || AR->getLoop()->getLoopPreheader())
        Changed |= linearFunctionTestReplace(L, ExitingBB,
                                             ExitCount, IndVar,
                                             Rewriter);
    }
  }
  // Clear the rewriter cache, because values that are in the rewriter's cache
  // can be deleted in the loop below, causing the AssertingVH in the cache to
  // trigger.
  Rewriter.clear();

  // Now that we're done iterating through lists, clean up any instructions
  // which are now dead.
  while (!DeadInsts.empty()) {
    Value *V = DeadInsts.pop_back_val();

    if (PHINode *PHI = dyn_cast_or_null<PHINode>(V))
      Changed |= RecursivelyDeleteDeadPHINode(PHI, TLI, MSSAU.get());
    else if (Instruction *Inst = dyn_cast_or_null<Instruction>(V))
      Changed |=
          RecursivelyDeleteTriviallyDeadInstructions(Inst, TLI, MSSAU.get());
  }

  // The Rewriter may not be used from this point on.

  // Loop-invariant instructions in the preheader that aren't used in the
  // loop may be sunk below the loop to reduce register pressure.
  Changed |= sinkUnusedInvariants(L);

  // rewriteFirstIterationLoopExitValues does not rely on the computation of
  // trip count and therefore can further simplify exit values in addition to
  // rewriteLoopExitValues.
  Changed |= rewriteFirstIterationLoopExitValues(L);

  // Clean up dead instructions.
  Changed |= DeleteDeadPHIs(L->getHeader(), TLI, MSSAU.get());

  // Check a post-condition.
  assert(L->isRecursivelyLCSSAForm(*DT, *LI) &&
         "Indvars did not preserve LCSSA!");

  // Verify that LFTR and any other changes have not interfered with SCEV's
  // ability to compute trip count.  We may have *changed* the exit count, but
  // only by reducing it.
#ifndef NDEBUG
  if (VerifyIndvars && !isa<SCEVCouldNotCompute>(BackedgeTakenCount)) {
    SE->forgetLoop(L);
    const SCEV *NewBECount = SE->getBackedgeTakenCount(L);
    if (SE->getTypeSizeInBits(BackedgeTakenCount->getType()) <
        SE->getTypeSizeInBits(NewBECount->getType()))
      NewBECount = SE->getTruncateOrNoop(NewBECount,
                                         BackedgeTakenCount->getType());
    else
      BackedgeTakenCount = SE->getTruncateOrNoop(BackedgeTakenCount,
                                                 NewBECount->getType());
    assert(!SE->isKnownPredicate(ICmpInst::ICMP_ULT, BackedgeTakenCount,
                                 NewBECount) && "indvars must preserve SCEV");
  }
  if (VerifyMemorySSA && MSSAU)
    MSSAU->getMemorySSA()->verifyMemorySSA();
#endif

  return Changed;
}

PreservedAnalyses IndVarSimplifyPass::run(Loop &L, LoopAnalysisManager &AM,
                                          LoopStandardAnalysisResults &AR,
                                          LPMUpdater &) {
  Function *F = L.getHeader()->getParent();
  const DataLayout &DL = F->getParent()->getDataLayout();

  IndVarSimplify IVS(&AR.LI, &AR.SE, &AR.DT, DL, &AR.TLI, &AR.TTI, AR.MSSA,
                     WidenIndVars && AllowIVWidening);
  if (!IVS.run(&L))
    return PreservedAnalyses::all();

  auto PA = getLoopPassPreservedAnalyses();
  PA.preserveSet<CFGAnalyses>();
  if (AR.MSSA)
    PA.preserve<MemorySSAAnalysis>();
  return PA;
}

namespace {

struct IndVarSimplifyLegacyPass : public LoopPass {
  static char ID; // Pass identification, replacement for typeid

  IndVarSimplifyLegacyPass() : LoopPass(ID) {
    initializeIndVarSimplifyLegacyPassPass(*PassRegistry::getPassRegistry());
  }

  bool runOnLoop(Loop *L, LPPassManager &LPM) override {
    if (skipLoop(L))
      return false;

    auto *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
    auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
    auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
    auto *TLIP = getAnalysisIfAvailable<TargetLibraryInfoWrapperPass>();
    auto *TLI = TLIP ? &TLIP->getTLI(*L->getHeader()->getParent()) : nullptr;
    auto *TTIP = getAnalysisIfAvailable<TargetTransformInfoWrapperPass>();
    auto *TTI = TTIP ? &TTIP->getTTI(*L->getHeader()->getParent()) : nullptr;
    const DataLayout &DL = L->getHeader()->getModule()->getDataLayout();
    auto *MSSAAnalysis = getAnalysisIfAvailable<MemorySSAWrapperPass>();
    MemorySSA *MSSA = nullptr;
    if (MSSAAnalysis)
      MSSA = &MSSAAnalysis->getMSSA();

    IndVarSimplify IVS(LI, SE, DT, DL, TLI, TTI, MSSA, AllowIVWidening);
    return IVS.run(L);
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
    AU.addPreserved<MemorySSAWrapperPass>();
    getLoopAnalysisUsage(AU);
  }
};

} // end anonymous namespace

char IndVarSimplifyLegacyPass::ID = 0;

INITIALIZE_PASS_BEGIN(IndVarSimplifyLegacyPass, "indvars",
                      "Induction Variable Simplification", false, false)
INITIALIZE_PASS_DEPENDENCY(LoopPass)
INITIALIZE_PASS_END(IndVarSimplifyLegacyPass, "indvars",
                    "Induction Variable Simplification", false, false)

Pass *llvm::createIndVarSimplifyPass() {
  return new IndVarSimplifyLegacyPass();
}