1 //===-- LoopUnrollAndJam.cpp - Loop unrolling utilities -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements loop unroll and jam as a routine, much like
10 // LoopUnroll.cpp implements loop unroll.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "llvm/ADT/ArrayRef.h"
15 #include "llvm/ADT/DenseMap.h"
16 #include "llvm/ADT/Optional.h"
17 #include "llvm/ADT/STLExtras.h"
18 #include "llvm/ADT/SmallPtrSet.h"
19 #include "llvm/ADT/SmallVector.h"
20 #include "llvm/ADT/Statistic.h"
21 #include "llvm/ADT/StringRef.h"
22 #include "llvm/ADT/Twine.h"
23 #include "llvm/ADT/iterator_range.h"
24 #include "llvm/Analysis/AssumptionCache.h"
25 #include "llvm/Analysis/DependenceAnalysis.h"
26 #include "llvm/Analysis/DomTreeUpdater.h"
27 #include "llvm/Analysis/LoopInfo.h"
28 #include "llvm/Analysis/LoopIterator.h"
29 #include "llvm/Analysis/MustExecute.h"
30 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
31 #include "llvm/Analysis/ScalarEvolution.h"
32 #include "llvm/IR/BasicBlock.h"
33 #include "llvm/IR/DebugInfoMetadata.h"
34 #include "llvm/IR/DebugLoc.h"
35 #include "llvm/IR/DiagnosticInfo.h"
36 #include "llvm/IR/Dominators.h"
37 #include "llvm/IR/Function.h"
38 #include "llvm/IR/Instruction.h"
39 #include "llvm/IR/Instructions.h"
40 #include "llvm/IR/IntrinsicInst.h"
41 #include "llvm/IR/User.h"
42 #include "llvm/IR/Value.h"
43 #include "llvm/IR/ValueHandle.h"
44 #include "llvm/IR/ValueMap.h"
45 #include "llvm/Support/Casting.h"
46 #include "llvm/Support/Debug.h"
47 #include "llvm/Support/ErrorHandling.h"
48 #include "llvm/Support/GenericDomTree.h"
49 #include "llvm/Support/raw_ostream.h"
50 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
51 #include "llvm/Transforms/Utils/Cloning.h"
52 #include "llvm/Transforms/Utils/LoopUtils.h"
53 #include "llvm/Transforms/Utils/UnrollLoop.h"
54 #include "llvm/Transforms/Utils/ValueMapper.h"
55 #include <assert.h>
56 #include <memory>
57 #include <type_traits>
58 #include <vector>
59 
60 using namespace llvm;
61 
62 #define DEBUG_TYPE "loop-unroll-and-jam"
63 
64 STATISTIC(NumUnrolledAndJammed, "Number of loops unroll and jammed");
65 STATISTIC(NumCompletelyUnrolledAndJammed, "Number of loops unroll and jammed");
66 
67 typedef SmallPtrSet<BasicBlock *, 4> BasicBlockSet;
68 
69 // Partition blocks in an outer/inner loop pair into blocks before and after
70 // the loop
71 static bool partitionLoopBlocks(Loop &L, BasicBlockSet &ForeBlocks,
72                                 BasicBlockSet &AftBlocks, DominatorTree &DT) {
73   Loop *SubLoop = L.getSubLoops()[0];
74   BasicBlock *SubLoopLatch = SubLoop->getLoopLatch();
75 
76   for (BasicBlock *BB : L.blocks()) {
77     if (!SubLoop->contains(BB)) {
78       if (DT.dominates(SubLoopLatch, BB))
79         AftBlocks.insert(BB);
80       else
81         ForeBlocks.insert(BB);
82     }
83   }
84 
85   // Check that all blocks in ForeBlocks together dominate the subloop
86   // TODO: This might ideally be done better with a dominator/postdominators.
87   BasicBlock *SubLoopPreHeader = SubLoop->getLoopPreheader();
88   for (BasicBlock *BB : ForeBlocks) {
89     if (BB == SubLoopPreHeader)
90       continue;
91     Instruction *TI = BB->getTerminator();
92     for (BasicBlock *Succ : successors(TI))
93       if (!ForeBlocks.count(Succ))
94         return false;
95   }
96 
97   return true;
98 }
99 
100 /// Partition blocks in a loop nest into blocks before and after each inner
101 /// loop.
102 static bool partitionOuterLoopBlocks(
103     Loop &Root, Loop &JamLoop, BasicBlockSet &JamLoopBlocks,
104     DenseMap<Loop *, BasicBlockSet> &ForeBlocksMap,
105     DenseMap<Loop *, BasicBlockSet> &AftBlocksMap, DominatorTree &DT) {
106   JamLoopBlocks.insert(JamLoop.block_begin(), JamLoop.block_end());
107 
108   for (Loop *L : Root.getLoopsInPreorder()) {
109     if (L == &JamLoop)
110       break;
111 
112     if (!partitionLoopBlocks(*L, ForeBlocksMap[L], AftBlocksMap[L], DT))
113       return false;
114   }
115 
116   return true;
117 }
118 
119 // TODO Remove when UnrollAndJamLoop changed to support unroll and jamming more
120 // than 2 levels loop.
121 static bool partitionOuterLoopBlocks(Loop *L, Loop *SubLoop,
122                                      BasicBlockSet &ForeBlocks,
123                                      BasicBlockSet &SubLoopBlocks,
124                                      BasicBlockSet &AftBlocks,
125                                      DominatorTree *DT) {
126   SubLoopBlocks.insert(SubLoop->block_begin(), SubLoop->block_end());
127   return partitionLoopBlocks(*L, ForeBlocks, AftBlocks, *DT);
128 }
129 
130 // Looks at the phi nodes in Header for values coming from Latch. For these
131 // instructions and all their operands calls Visit on them, keeping going for
132 // all the operands in AftBlocks. Returns false if Visit returns false,
133 // otherwise returns true. This is used to process the instructions in the
134 // Aft blocks that need to be moved before the subloop. It is used in two
135 // places. One to check that the required set of instructions can be moved
136 // before the loop. Then to collect the instructions to actually move in
137 // moveHeaderPhiOperandsToForeBlocks.
138 template <typename T>
139 static bool processHeaderPhiOperands(BasicBlock *Header, BasicBlock *Latch,
140                                      BasicBlockSet &AftBlocks, T Visit) {
141   SmallVector<Instruction *, 8> Worklist;
142   SmallPtrSet<Instruction *, 8> VisitedInstr;
143   for (auto &Phi : Header->phis()) {
144     Value *V = Phi.getIncomingValueForBlock(Latch);
145     if (Instruction *I = dyn_cast<Instruction>(V))
146       Worklist.push_back(I);
147   }
148 
149   while (!Worklist.empty()) {
150     Instruction *I = Worklist.pop_back_val();
151     if (!Visit(I))
152       return false;
153     VisitedInstr.insert(I);
154 
155     if (AftBlocks.count(I->getParent()))
156       for (auto &U : I->operands())
157         if (Instruction *II = dyn_cast<Instruction>(U))
158           if (!VisitedInstr.count(II))
159             Worklist.push_back(II);
160   }
161 
162   return true;
163 }
164 
165 // Move the phi operands of Header from Latch out of AftBlocks to InsertLoc.
166 static void moveHeaderPhiOperandsToForeBlocks(BasicBlock *Header,
167                                               BasicBlock *Latch,
168                                               Instruction *InsertLoc,
169                                               BasicBlockSet &AftBlocks) {
170   // We need to ensure we move the instructions in the correct order,
171   // starting with the earliest required instruction and moving forward.
172   std::vector<Instruction *> Visited;
173   processHeaderPhiOperands(Header, Latch, AftBlocks,
174                            [&Visited, &AftBlocks](Instruction *I) {
175                              if (AftBlocks.count(I->getParent()))
176                                Visited.push_back(I);
177                              return true;
178                            });
179 
180   // Move all instructions in program order to before the InsertLoc
181   BasicBlock *InsertLocBB = InsertLoc->getParent();
182   for (Instruction *I : reverse(Visited)) {
183     if (I->getParent() != InsertLocBB)
184       I->moveBefore(InsertLoc);
185   }
186 }
187 
188 /*
189   This method performs Unroll and Jam. For a simple loop like:
190   for (i = ..)
191     Fore(i)
192     for (j = ..)
193       SubLoop(i, j)
194     Aft(i)
195 
196   Instead of doing normal inner or outer unrolling, we do:
197   for (i = .., i+=2)
198     Fore(i)
199     Fore(i+1)
200     for (j = ..)
201       SubLoop(i, j)
202       SubLoop(i+1, j)
203     Aft(i)
204     Aft(i+1)
205 
  So the outer loop is essentially unrolled and then the inner loops are fused
207   ("jammed") together into a single loop. This can increase speed when there
208   are loads in SubLoop that are invariant to i, as they become shared between
209   the now jammed inner loops.
210 
  We do this by splitting the blocks in the loop into Fore, Subloop and Aft.
212   Fore blocks are those before the inner loop, Aft are those after. Normal
213   Unroll code is used to copy each of these sets of blocks and the results are
214   combined together into the final form above.
215 
216   isSafeToUnrollAndJam should be used prior to calling this to make sure the
  unrolling will be valid. Checking profitability is also advisable.
218 
219   If EpilogueLoop is non-null, it receives the epilogue loop (if it was
220   necessary to create one and not fully unrolled).
221 */
222 LoopUnrollResult
223 llvm::UnrollAndJamLoop(Loop *L, unsigned Count, unsigned TripCount,
224                        unsigned TripMultiple, bool UnrollRemainder,
225                        LoopInfo *LI, ScalarEvolution *SE, DominatorTree *DT,
226                        AssumptionCache *AC, const TargetTransformInfo *TTI,
227                        OptimizationRemarkEmitter *ORE, Loop **EpilogueLoop) {
228 
229   // When we enter here we should have already checked that it is safe
230   BasicBlock *Header = L->getHeader();
231   assert(Header && "No header.");
232   assert(L->getSubLoops().size() == 1);
233   Loop *SubLoop = *L->begin();
234 
235   // Don't enter the unroll code if there is nothing to do.
236   if (TripCount == 0 && Count < 2) {
237     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; almost nothing to do\n");
238     return LoopUnrollResult::Unmodified;
239   }
240 
241   assert(Count > 0);
242   assert(TripMultiple > 0);
243   assert(TripCount == 0 || TripCount % TripMultiple == 0);
244 
245   // Are we eliminating the loop control altogether?
246   bool CompletelyUnroll = (Count == TripCount);
247 
248   // We use the runtime remainder in cases where we don't know trip multiple
249   if (TripMultiple % Count != 0) {
250     if (!UnrollRuntimeLoopRemainder(L, Count, /*AllowExpensiveTripCount*/ false,
251                                     /*UseEpilogRemainder*/ true,
252                                     UnrollRemainder, /*ForgetAllSCEV*/ false,
253                                     LI, SE, DT, AC, TTI, true, EpilogueLoop)) {
254       LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; remainder loop could not be "
255                            "generated when assuming runtime trip count\n");
256       return LoopUnrollResult::Unmodified;
257     }
258   }
259 
260   // Notify ScalarEvolution that the loop will be substantially changed,
261   // if not outright eliminated.
262   if (SE) {
263     SE->forgetLoop(L);
264     SE->forgetLoop(SubLoop);
265   }
266 
267   using namespace ore;
268   // Report the unrolling decision.
269   if (CompletelyUnroll) {
270     LLVM_DEBUG(dbgs() << "COMPLETELY UNROLL AND JAMMING loop %"
271                       << Header->getName() << " with trip count " << TripCount
272                       << "!\n");
273     ORE->emit(OptimizationRemark(DEBUG_TYPE, "FullyUnrolled", L->getStartLoc(),
274                                  L->getHeader())
275               << "completely unroll and jammed loop with "
276               << NV("UnrollCount", TripCount) << " iterations");
277   } else {
278     auto DiagBuilder = [&]() {
279       OptimizationRemark Diag(DEBUG_TYPE, "PartialUnrolled", L->getStartLoc(),
280                               L->getHeader());
281       return Diag << "unroll and jammed loop by a factor of "
282                   << NV("UnrollCount", Count);
283     };
284 
285     LLVM_DEBUG(dbgs() << "UNROLL AND JAMMING loop %" << Header->getName()
286                       << " by " << Count);
287     if (TripMultiple != 1) {
288       LLVM_DEBUG(dbgs() << " with " << TripMultiple << " trips per branch");
289       ORE->emit([&]() {
290         return DiagBuilder() << " with " << NV("TripMultiple", TripMultiple)
291                              << " trips per branch";
292       });
293     } else {
294       LLVM_DEBUG(dbgs() << " with run-time trip count");
295       ORE->emit([&]() { return DiagBuilder() << " with run-time trip count"; });
296     }
297     LLVM_DEBUG(dbgs() << "!\n");
298   }
299 
300   BasicBlock *Preheader = L->getLoopPreheader();
301   BasicBlock *LatchBlock = L->getLoopLatch();
302   assert(Preheader && "No preheader");
303   assert(LatchBlock && "No latch block");
304   BranchInst *BI = dyn_cast<BranchInst>(LatchBlock->getTerminator());
305   assert(BI && !BI->isUnconditional());
306   bool ContinueOnTrue = L->contains(BI->getSuccessor(0));
307   BasicBlock *LoopExit = BI->getSuccessor(ContinueOnTrue);
308   bool SubLoopContinueOnTrue = SubLoop->contains(
309       SubLoop->getLoopLatch()->getTerminator()->getSuccessor(0));
310 
311   // Partition blocks in an outer/inner loop pair into blocks before and after
312   // the loop
313   BasicBlockSet SubLoopBlocks;
314   BasicBlockSet ForeBlocks;
315   BasicBlockSet AftBlocks;
316   partitionOuterLoopBlocks(L, SubLoop, ForeBlocks, SubLoopBlocks, AftBlocks,
317                            DT);
318 
319   // We keep track of the entering/first and exiting/last block of each of
320   // Fore/SubLoop/Aft in each iteration. This helps make the stapling up of
321   // blocks easier.
322   std::vector<BasicBlock *> ForeBlocksFirst;
323   std::vector<BasicBlock *> ForeBlocksLast;
324   std::vector<BasicBlock *> SubLoopBlocksFirst;
325   std::vector<BasicBlock *> SubLoopBlocksLast;
326   std::vector<BasicBlock *> AftBlocksFirst;
327   std::vector<BasicBlock *> AftBlocksLast;
328   ForeBlocksFirst.push_back(Header);
329   ForeBlocksLast.push_back(SubLoop->getLoopPreheader());
330   SubLoopBlocksFirst.push_back(SubLoop->getHeader());
331   SubLoopBlocksLast.push_back(SubLoop->getExitingBlock());
332   AftBlocksFirst.push_back(SubLoop->getExitBlock());
333   AftBlocksLast.push_back(L->getExitingBlock());
334   // Maps Blocks[0] -> Blocks[It]
335   ValueToValueMapTy LastValueMap;
336 
337   // Move any instructions from fore phi operands from AftBlocks into Fore.
338   moveHeaderPhiOperandsToForeBlocks(
339       Header, LatchBlock, ForeBlocksLast[0]->getTerminator(), AftBlocks);
340 
341   // The current on-the-fly SSA update requires blocks to be processed in
342   // reverse postorder so that LastValueMap contains the correct value at each
343   // exit.
344   LoopBlocksDFS DFS(L);
345   DFS.perform(LI);
346   // Stash the DFS iterators before adding blocks to the loop.
347   LoopBlocksDFS::RPOIterator BlockBegin = DFS.beginRPO();
348   LoopBlocksDFS::RPOIterator BlockEnd = DFS.endRPO();
349 
350   // When a FSDiscriminator is enabled, we don't need to add the multiply
351   // factors to the discriminators.
352   if (Header->getParent()->isDebugInfoForProfiling() && !EnableFSDiscriminator)
353     for (BasicBlock *BB : L->getBlocks())
354       for (Instruction &I : *BB)
355         if (!isa<DbgInfoIntrinsic>(&I))
356           if (const DILocation *DIL = I.getDebugLoc()) {
357             auto NewDIL = DIL->cloneByMultiplyingDuplicationFactor(Count);
358             if (NewDIL)
359               I.setDebugLoc(*NewDIL);
360             else
361               LLVM_DEBUG(dbgs()
362                          << "Failed to create new discriminator: "
363                          << DIL->getFilename() << " Line: " << DIL->getLine());
364           }
365 
366   // Copy all blocks
367   for (unsigned It = 1; It != Count; ++It) {
368     SmallVector<BasicBlock *, 8> NewBlocks;
369     // Maps Blocks[It] -> Blocks[It-1]
370     DenseMap<Value *, Value *> PrevItValueMap;
371     SmallDenseMap<const Loop *, Loop *, 4> NewLoops;
372     NewLoops[L] = L;
373     NewLoops[SubLoop] = SubLoop;
374 
375     for (LoopBlocksDFS::RPOIterator BB = BlockBegin; BB != BlockEnd; ++BB) {
376       ValueToValueMapTy VMap;
377       BasicBlock *New = CloneBasicBlock(*BB, VMap, "." + Twine(It));
378       Header->getParent()->getBasicBlockList().push_back(New);
379 
380       // Tell LI about New.
381       addClonedBlockToLoopInfo(*BB, New, LI, NewLoops);
382 
383       if (ForeBlocks.count(*BB)) {
384         if (*BB == ForeBlocksFirst[0])
385           ForeBlocksFirst.push_back(New);
386         if (*BB == ForeBlocksLast[0])
387           ForeBlocksLast.push_back(New);
388       } else if (SubLoopBlocks.count(*BB)) {
389         if (*BB == SubLoopBlocksFirst[0])
390           SubLoopBlocksFirst.push_back(New);
391         if (*BB == SubLoopBlocksLast[0])
392           SubLoopBlocksLast.push_back(New);
393       } else if (AftBlocks.count(*BB)) {
394         if (*BB == AftBlocksFirst[0])
395           AftBlocksFirst.push_back(New);
396         if (*BB == AftBlocksLast[0])
397           AftBlocksLast.push_back(New);
398       } else {
399         llvm_unreachable("BB being cloned should be in Fore/Sub/Aft");
400       }
401 
402       // Update our running maps of newest clones
403       PrevItValueMap[New] = (It == 1 ? *BB : LastValueMap[*BB]);
404       LastValueMap[*BB] = New;
405       for (ValueToValueMapTy::iterator VI = VMap.begin(), VE = VMap.end();
406            VI != VE; ++VI) {
407         PrevItValueMap[VI->second] =
408             const_cast<Value *>(It == 1 ? VI->first : LastValueMap[VI->first]);
409         LastValueMap[VI->first] = VI->second;
410       }
411 
412       NewBlocks.push_back(New);
413 
414       // Update DomTree:
415       if (*BB == ForeBlocksFirst[0])
416         DT->addNewBlock(New, ForeBlocksLast[It - 1]);
417       else if (*BB == SubLoopBlocksFirst[0])
418         DT->addNewBlock(New, SubLoopBlocksLast[It - 1]);
419       else if (*BB == AftBlocksFirst[0])
420         DT->addNewBlock(New, AftBlocksLast[It - 1]);
421       else {
422         // Each set of blocks (Fore/Sub/Aft) will have the same internal domtree
423         // structure.
424         auto BBDomNode = DT->getNode(*BB);
425         auto BBIDom = BBDomNode->getIDom();
426         BasicBlock *OriginalBBIDom = BBIDom->getBlock();
427         assert(OriginalBBIDom);
428         assert(LastValueMap[cast<Value>(OriginalBBIDom)]);
429         DT->addNewBlock(
430             New, cast<BasicBlock>(LastValueMap[cast<Value>(OriginalBBIDom)]));
431       }
432     }
433 
434     // Remap all instructions in the most recent iteration
435     remapInstructionsInBlocks(NewBlocks, LastValueMap);
436     for (BasicBlock *NewBlock : NewBlocks) {
437       for (Instruction &I : *NewBlock) {
438         if (auto *II = dyn_cast<AssumeInst>(&I))
439           AC->registerAssumption(II);
440       }
441     }
442 
443     // Alter the ForeBlocks phi's, pointing them at the latest version of the
444     // value from the previous iteration's phis
445     for (PHINode &Phi : ForeBlocksFirst[It]->phis()) {
446       Value *OldValue = Phi.getIncomingValueForBlock(AftBlocksLast[It]);
447       assert(OldValue && "should have incoming edge from Aft[It]");
448       Value *NewValue = OldValue;
449       if (Value *PrevValue = PrevItValueMap[OldValue])
450         NewValue = PrevValue;
451 
452       assert(Phi.getNumOperands() == 2);
453       Phi.setIncomingBlock(0, ForeBlocksLast[It - 1]);
454       Phi.setIncomingValue(0, NewValue);
455       Phi.removeIncomingValue(1);
456     }
457   }
458 
459   // Now that all the basic blocks for the unrolled iterations are in place,
460   // finish up connecting the blocks and phi nodes. At this point LastValueMap
461   // is the last unrolled iterations values.
462 
463   // Update Phis in BB from OldBB to point to NewBB and use the latest value
464   // from LastValueMap
465   auto updatePHIBlocksAndValues = [](BasicBlock *BB, BasicBlock *OldBB,
466                                      BasicBlock *NewBB,
467                                      ValueToValueMapTy &LastValueMap) {
468     for (PHINode &Phi : BB->phis()) {
469       for (unsigned b = 0; b < Phi.getNumIncomingValues(); ++b) {
470         if (Phi.getIncomingBlock(b) == OldBB) {
471           Value *OldValue = Phi.getIncomingValue(b);
472           if (Value *LastValue = LastValueMap[OldValue])
473             Phi.setIncomingValue(b, LastValue);
474           Phi.setIncomingBlock(b, NewBB);
475           break;
476         }
477       }
478     }
479   };
480   // Move all the phis from Src into Dest
481   auto movePHIs = [](BasicBlock *Src, BasicBlock *Dest) {
482     Instruction *insertPoint = Dest->getFirstNonPHI();
483     while (PHINode *Phi = dyn_cast<PHINode>(Src->begin()))
484       Phi->moveBefore(insertPoint);
485   };
486 
487   // Update the PHI values outside the loop to point to the last block
488   updatePHIBlocksAndValues(LoopExit, AftBlocksLast[0], AftBlocksLast.back(),
489                            LastValueMap);
490 
491   // Update ForeBlocks successors and phi nodes
492   BranchInst *ForeTerm =
493       cast<BranchInst>(ForeBlocksLast.back()->getTerminator());
494   assert(ForeTerm->getNumSuccessors() == 1 && "Expecting one successor");
495   ForeTerm->setSuccessor(0, SubLoopBlocksFirst[0]);
496 
497   if (CompletelyUnroll) {
498     while (PHINode *Phi = dyn_cast<PHINode>(ForeBlocksFirst[0]->begin())) {
499       Phi->replaceAllUsesWith(Phi->getIncomingValueForBlock(Preheader));
500       Phi->getParent()->getInstList().erase(Phi);
501     }
502   } else {
503     // Update the PHI values to point to the last aft block
504     updatePHIBlocksAndValues(ForeBlocksFirst[0], AftBlocksLast[0],
505                              AftBlocksLast.back(), LastValueMap);
506   }
507 
508   for (unsigned It = 1; It != Count; It++) {
509     // Remap ForeBlock successors from previous iteration to this
510     BranchInst *ForeTerm =
511         cast<BranchInst>(ForeBlocksLast[It - 1]->getTerminator());
512     assert(ForeTerm->getNumSuccessors() == 1 && "Expecting one successor");
513     ForeTerm->setSuccessor(0, ForeBlocksFirst[It]);
514   }
515 
516   // Subloop successors and phis
517   BranchInst *SubTerm =
518       cast<BranchInst>(SubLoopBlocksLast.back()->getTerminator());
519   SubTerm->setSuccessor(!SubLoopContinueOnTrue, SubLoopBlocksFirst[0]);
520   SubTerm->setSuccessor(SubLoopContinueOnTrue, AftBlocksFirst[0]);
521   SubLoopBlocksFirst[0]->replacePhiUsesWith(ForeBlocksLast[0],
522                                             ForeBlocksLast.back());
523   SubLoopBlocksFirst[0]->replacePhiUsesWith(SubLoopBlocksLast[0],
524                                             SubLoopBlocksLast.back());
525 
526   for (unsigned It = 1; It != Count; It++) {
527     // Replace the conditional branch of the previous iteration subloop with an
528     // unconditional one to this one
529     BranchInst *SubTerm =
530         cast<BranchInst>(SubLoopBlocksLast[It - 1]->getTerminator());
531     BranchInst::Create(SubLoopBlocksFirst[It], SubTerm);
532     SubTerm->eraseFromParent();
533 
534     SubLoopBlocksFirst[It]->replacePhiUsesWith(ForeBlocksLast[It],
535                                                ForeBlocksLast.back());
536     SubLoopBlocksFirst[It]->replacePhiUsesWith(SubLoopBlocksLast[It],
537                                                SubLoopBlocksLast.back());
538     movePHIs(SubLoopBlocksFirst[It], SubLoopBlocksFirst[0]);
539   }
540 
541   // Aft blocks successors and phis
542   BranchInst *AftTerm = cast<BranchInst>(AftBlocksLast.back()->getTerminator());
543   if (CompletelyUnroll) {
544     BranchInst::Create(LoopExit, AftTerm);
545     AftTerm->eraseFromParent();
546   } else {
547     AftTerm->setSuccessor(!ContinueOnTrue, ForeBlocksFirst[0]);
548     assert(AftTerm->getSuccessor(ContinueOnTrue) == LoopExit &&
549            "Expecting the ContinueOnTrue successor of AftTerm to be LoopExit");
550   }
551   AftBlocksFirst[0]->replacePhiUsesWith(SubLoopBlocksLast[0],
552                                         SubLoopBlocksLast.back());
553 
554   for (unsigned It = 1; It != Count; It++) {
555     // Replace the conditional branch of the previous iteration subloop with an
556     // unconditional one to this one
557     BranchInst *AftTerm =
558         cast<BranchInst>(AftBlocksLast[It - 1]->getTerminator());
559     BranchInst::Create(AftBlocksFirst[It], AftTerm);
560     AftTerm->eraseFromParent();
561 
562     AftBlocksFirst[It]->replacePhiUsesWith(SubLoopBlocksLast[It],
563                                            SubLoopBlocksLast.back());
564     movePHIs(AftBlocksFirst[It], AftBlocksFirst[0]);
565   }
566 
567   DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
568   // Dominator Tree. Remove the old links between Fore, Sub and Aft, adding the
569   // new ones required.
570   if (Count != 1) {
571     SmallVector<DominatorTree::UpdateType, 4> DTUpdates;
572     DTUpdates.emplace_back(DominatorTree::UpdateKind::Delete, ForeBlocksLast[0],
573                            SubLoopBlocksFirst[0]);
574     DTUpdates.emplace_back(DominatorTree::UpdateKind::Delete,
575                            SubLoopBlocksLast[0], AftBlocksFirst[0]);
576 
577     DTUpdates.emplace_back(DominatorTree::UpdateKind::Insert,
578                            ForeBlocksLast.back(), SubLoopBlocksFirst[0]);
579     DTUpdates.emplace_back(DominatorTree::UpdateKind::Insert,
580                            SubLoopBlocksLast.back(), AftBlocksFirst[0]);
581     DTU.applyUpdatesPermissive(DTUpdates);
582   }
583 
584   // Merge adjacent basic blocks, if possible.
585   SmallPtrSet<BasicBlock *, 16> MergeBlocks;
586   MergeBlocks.insert(ForeBlocksLast.begin(), ForeBlocksLast.end());
587   MergeBlocks.insert(SubLoopBlocksLast.begin(), SubLoopBlocksLast.end());
588   MergeBlocks.insert(AftBlocksLast.begin(), AftBlocksLast.end());
589 
590   MergeBlockSuccessorsIntoGivenBlocks(MergeBlocks, L, &DTU, LI);
591 
592   // Apply updates to the DomTree.
593   DT = &DTU.getDomTree();
594 
595   // At this point, the code is well formed.  We now do a quick sweep over the
596   // inserted code, doing constant propagation and dead code elimination as we
597   // go.
598   simplifyLoopAfterUnroll(SubLoop, true, LI, SE, DT, AC, TTI);
599   simplifyLoopAfterUnroll(L, !CompletelyUnroll && Count > 1, LI, SE, DT, AC,
600                           TTI);
601 
602   NumCompletelyUnrolledAndJammed += CompletelyUnroll;
603   ++NumUnrolledAndJammed;
604 
605   // Update LoopInfo if the loop is completely removed.
606   if (CompletelyUnroll)
607     LI->erase(L);
608 
609 #ifndef NDEBUG
610   // We shouldn't have done anything to break loop simplify form or LCSSA.
611   Loop *OutestLoop = SubLoop->getParentLoop()
612                          ? SubLoop->getParentLoop()->getParentLoop()
613                                ? SubLoop->getParentLoop()->getParentLoop()
614                                : SubLoop->getParentLoop()
615                          : SubLoop;
616   assert(DT->verify());
617   LI->verify(*DT);
618   assert(OutestLoop->isRecursivelyLCSSAForm(*DT, *LI));
619   if (!CompletelyUnroll)
620     assert(L->isLoopSimplifyForm());
621   assert(SubLoop->isLoopSimplifyForm());
622   SE->verify();
623 #endif
624 
625   return CompletelyUnroll ? LoopUnrollResult::FullyUnrolled
626                           : LoopUnrollResult::PartiallyUnrolled;
627 }
628 
629 static bool getLoadsAndStores(BasicBlockSet &Blocks,
630                               SmallVector<Instruction *, 4> &MemInstr) {
631   // Scan the BBs and collect legal loads and stores.
632   // Returns false if non-simple loads/stores are found.
633   for (BasicBlock *BB : Blocks) {
634     for (Instruction &I : *BB) {
635       if (auto *Ld = dyn_cast<LoadInst>(&I)) {
636         if (!Ld->isSimple())
637           return false;
638         MemInstr.push_back(&I);
639       } else if (auto *St = dyn_cast<StoreInst>(&I)) {
640         if (!St->isSimple())
641           return false;
642         MemInstr.push_back(&I);
643       } else if (I.mayReadOrWriteMemory()) {
644         return false;
645       }
646     }
647   }
648   return true;
649 }
650 
651 static bool preservesForwardDependence(Instruction *Src, Instruction *Dst,
652                                        unsigned UnrollLevel, unsigned JamLevel,
653                                        bool Sequentialized, Dependence *D) {
654   // UnrollLevel might carry the dependency Src --> Dst
655   // Does a different loop after unrolling?
656   for (unsigned CurLoopDepth = UnrollLevel + 1; CurLoopDepth <= JamLevel;
657        ++CurLoopDepth) {
658     auto JammedDir = D->getDirection(CurLoopDepth);
659     if (JammedDir == Dependence::DVEntry::LT)
660       return true;
661 
662     if (JammedDir & Dependence::DVEntry::GT)
663       return false;
664   }
665 
666   return true;
667 }
668 
669 static bool preservesBackwardDependence(Instruction *Src, Instruction *Dst,
670                                         unsigned UnrollLevel, unsigned JamLevel,
671                                         bool Sequentialized, Dependence *D) {
672   // UnrollLevel might carry the dependency Dst --> Src
673   for (unsigned CurLoopDepth = UnrollLevel + 1; CurLoopDepth <= JamLevel;
674        ++CurLoopDepth) {
675     auto JammedDir = D->getDirection(CurLoopDepth);
676     if (JammedDir == Dependence::DVEntry::GT)
677       return true;
678 
679     if (JammedDir & Dependence::DVEntry::LT)
680       return false;
681   }
682 
683   // Backward dependencies are only preserved if not interleaved.
684   return Sequentialized;
685 }
686 
// Check whether it is semantically safe to unroll-and-jam Src and Dst,
// considering any potential dependency between them.
//
// @param UnrollLevel    The level of the loop being unrolled
// @param JamLevel       The level of the loop being jammed; if Src and Dst are
//                       on different levels, the outermost common loop counts
//                       as jammed level
// @param Sequentialized Passed through to preservesForward/BackwardDependence;
//                       the backward check only accepts an undecided
//                       dependence when this is true.
// @param DI             Dependence analysis used to query Src -> Dst.
//
// @return true if is safe and false if there is a dependency violation.
static bool checkDependency(Instruction *Src, Instruction *Dst,
                            unsigned UnrollLevel, unsigned JamLevel,
                            bool Sequentialized, DependenceInfo &DI) {
  assert(UnrollLevel <= JamLevel &&
         "Expecting JamLevel to be at least UnrollLevel");

  // An instruction never conflicts with itself here.
  if (Src == Dst)
    return true;
  // Ignore Input dependencies.
  if (isa<LoadInst>(Src) && isa<LoadInst>(Dst))
    return true;

  // Check whether unroll-and-jam may violate a dependency.
  // By construction, every dependency will be lexicographically non-negative
  // (if it was, it would violate the current execution order), such as
  //   (0,0,>,*,*)
  // Unroll-and-jam changes the GT execution of two executions to the same
  // iteration of the chosen unroll level. That is, a GT dependence becomes a GE
  // dependence (or EQ, if we fully unrolled the loop) at the loop's position:
  //   (0,0,>=,*,*)
  // Now, the dependency is not necessarily non-negative anymore, i.e.
  // unroll-and-jam may violate correctness.
  std::unique_ptr<Dependence> D = DI.depends(Src, Dst, true);
  // No dependence reported: nothing can be violated.
  if (!D)
    return true;
  assert(D->isOrdered() && "Expected an output, flow or anti dep.");

  // A confused dependence carries no usable direction info; be conservative.
  if (D->isConfused()) {
    LLVM_DEBUG(dbgs() << "  Confused dependency between:\n"
                      << "  " << *Src << "\n"
                      << "  " << *Dst << "\n");
    return false;
  }

  // If outer levels (levels enclosing the loop being unroll-and-jammed) have a
  // non-equal direction, then the locations accessed in the inner levels cannot
  // overlap in memory. We assumes the indexes never overlap into neighboring
  // dimensions.
  for (unsigned CurLoopDepth = 1; CurLoopDepth < UnrollLevel; ++CurLoopDepth)
    if (!(D->getDirection(CurLoopDepth) & Dependence::DVEntry::EQ))
      return true;

  auto UnrollDirection = D->getDirection(UnrollLevel);

  // If the distance carried by the unrolled loop is 0, then after unrolling
  // that distance will become non-zero resulting in non-overlapping accesses in
  // the inner loops.
  if (UnrollDirection == Dependence::DVEntry::EQ)
    return true;

  // A possibly-LT direction at the unroll level may be a forward dependence
  // that the jammed levels must still respect.
  if (UnrollDirection & Dependence::DVEntry::LT &&
      !preservesForwardDependence(Src, Dst, UnrollLevel, JamLevel,
                                  Sequentialized, D.get()))
    return false;

  // A possibly-GT direction at the unroll level may be a backward dependence.
  if (UnrollDirection & Dependence::DVEntry::GT &&
      !preservesBackwardDependence(Src, Dst, UnrollLevel, JamLevel,
                                   Sequentialized, D.get()))
    return false;

  return true;
}
757 
758 static bool
759 checkDependencies(Loop &Root, const BasicBlockSet &SubLoopBlocks,
760                   const DenseMap<Loop *, BasicBlockSet> &ForeBlocksMap,
761                   const DenseMap<Loop *, BasicBlockSet> &AftBlocksMap,
762                   DependenceInfo &DI, LoopInfo &LI) {
763   SmallVector<BasicBlockSet, 8> AllBlocks;
764   for (Loop *L : Root.getLoopsInPreorder())
765     if (ForeBlocksMap.find(L) != ForeBlocksMap.end())
766       AllBlocks.push_back(ForeBlocksMap.lookup(L));
767   AllBlocks.push_back(SubLoopBlocks);
768   for (Loop *L : Root.getLoopsInPreorder())
769     if (AftBlocksMap.find(L) != AftBlocksMap.end())
770       AllBlocks.push_back(AftBlocksMap.lookup(L));
771 
772   unsigned LoopDepth = Root.getLoopDepth();
773   SmallVector<Instruction *, 4> EarlierLoadsAndStores;
774   SmallVector<Instruction *, 4> CurrentLoadsAndStores;
775   for (BasicBlockSet &Blocks : AllBlocks) {
776     CurrentLoadsAndStores.clear();
777     if (!getLoadsAndStores(Blocks, CurrentLoadsAndStores))
778       return false;
779 
780     Loop *CurLoop = LI.getLoopFor((*Blocks.begin())->front().getParent());
781     unsigned CurLoopDepth = CurLoop->getLoopDepth();
782 
783     for (auto *Earlier : EarlierLoadsAndStores) {
784       Loop *EarlierLoop = LI.getLoopFor(Earlier->getParent());
785       unsigned EarlierDepth = EarlierLoop->getLoopDepth();
786       unsigned CommonLoopDepth = std::min(EarlierDepth, CurLoopDepth);
787       for (auto *Later : CurrentLoadsAndStores) {
788         if (!checkDependency(Earlier, Later, LoopDepth, CommonLoopDepth, false,
789                              DI))
790           return false;
791       }
792     }
793 
794     size_t NumInsts = CurrentLoadsAndStores.size();
795     for (size_t I = 0; I < NumInsts; ++I) {
796       for (size_t J = I; J < NumInsts; ++J) {
797         if (!checkDependency(CurrentLoadsAndStores[I], CurrentLoadsAndStores[J],
798                              LoopDepth, CurLoopDepth, true, DI))
799           return false;
800       }
801     }
802 
803     EarlierLoadsAndStores.append(CurrentLoadsAndStores.begin(),
804                                  CurrentLoadsAndStores.end());
805   }
806   return true;
807 }
808 
809 static bool isEligibleLoopForm(const Loop &Root) {
810   // Root must have a child.
811   if (Root.getSubLoops().size() != 1)
812     return false;
813 
814   const Loop *L = &Root;
815   do {
816     // All loops in Root need to be in simplify and rotated form.
817     if (!L->isLoopSimplifyForm())
818       return false;
819 
820     if (!L->isRotatedForm())
821       return false;
822 
823     if (L->getHeader()->hasAddressTaken()) {
824       LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Address taken\n");
825       return false;
826     }
827 
828     unsigned SubLoopsSize = L->getSubLoops().size();
829     if (SubLoopsSize == 0)
830       return true;
831 
832     // Only one child is allowed.
833     if (SubLoopsSize != 1)
834       return false;
835 
836     // Only loops with a single exit block can be unrolled and jammed.
837     // The function getExitBlock() is used for this check, rather than
838     // getUniqueExitBlock() to ensure loops with mulitple exit edges are
839     // disallowed.
840     if (!L->getExitBlock()) {
841       LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; only loops with single exit "
842                            "blocks can be unrolled and jammed.\n");
843       return false;
844     }
845 
846     // Only loops with a single exiting block can be unrolled and jammed.
847     if (!L->getExitingBlock()) {
848       LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; only loops with single "
849                            "exiting blocks can be unrolled and jammed.\n");
850       return false;
851     }
852 
853     L = L->getSubLoops()[0];
854   } while (L);
855 
856   return true;
857 }
858 
859 static Loop *getInnerMostLoop(Loop *L) {
860   while (!L->getSubLoops().empty())
861     L = L->getSubLoops()[0];
862   return L;
863 }
864 
865 bool llvm::isSafeToUnrollAndJam(Loop *L, ScalarEvolution &SE, DominatorTree &DT,
866                                 DependenceInfo &DI, LoopInfo &LI) {
867   if (!isEligibleLoopForm(*L)) {
868     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Ineligible loop form\n");
869     return false;
870   }
871 
872   /* We currently handle outer loops like this:
873         |
874     ForeFirst    <------\   }
875      Blocks             |   } ForeBlocks of L
876     ForeLast            |   }
877         |               |
878        ...              |
879         |               |
880     ForeFirst    <----\ |   }
881      Blocks           | |   } ForeBlocks of a inner loop of L
882     ForeLast          | |   }
883         |             | |
884     JamLoopFirst  <\  | |   }
885      Blocks        |  | |   } JamLoopBlocks of the innermost loop
886     JamLoopLast   -/  | |   }
887         |             | |
888     AftFirst          | |   }
889      Blocks           | |   } AftBlocks of a inner loop of L
890     AftLast     ------/ |   }
891         |               |
892        ...              |
893         |               |
894     AftFirst            |   }
895      Blocks             |   } AftBlocks of L
896     AftLast     --------/   }
897         |
898 
899     There are (theoretically) any number of blocks in ForeBlocks, SubLoopBlocks
900     and AftBlocks, providing that there is one edge from Fores to SubLoops,
901     one edge from SubLoops to Afts and a single outer loop exit (from Afts).
902     In practice we currently limit Aft blocks to a single block, and limit
903     things further in the profitablility checks of the unroll and jam pass.
904 
905     Because of the way we rearrange basic blocks, we also require that
906     the Fore blocks of L on all unrolled iterations are safe to move before the
907     blocks of the direct child of L of all iterations. So we require that the
908     phi node looping operands of ForeHeader can be moved to at least the end of
909     ForeEnd, so that we can arrange cloned Fore Blocks before the subloop and
910     match up Phi's correctly.
911 
912     i.e. The old order of blocks used to be
913            (F1)1 (F2)1 J1_1 J1_2 (A2)1 (A1)1 (F1)2 (F2)2 J2_1 J2_2 (A2)2 (A1)2.
914          It needs to be safe to transform this to
915            (F1)1 (F1)2 (F2)1 (F2)2 J1_1 J1_2 J2_1 J2_2 (A2)1 (A2)2 (A1)1 (A1)2.
916 
917     There are then a number of checks along the lines of no calls, no
918     exceptions, inner loop IV is consistent, etc. Note that for loops requiring
919     runtime unrolling, UnrollRuntimeLoopRemainder can also fail in
920     UnrollAndJamLoop if the trip count cannot be easily calculated.
921   */
922 
923   // Split blocks into Fore/SubLoop/Aft based on dominators
924   Loop *JamLoop = getInnerMostLoop(L);
925   BasicBlockSet SubLoopBlocks;
926   DenseMap<Loop *, BasicBlockSet> ForeBlocksMap;
927   DenseMap<Loop *, BasicBlockSet> AftBlocksMap;
928   if (!partitionOuterLoopBlocks(*L, *JamLoop, SubLoopBlocks, ForeBlocksMap,
929                                 AftBlocksMap, DT)) {
930     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Incompatible loop layout\n");
931     return false;
932   }
933 
934   // Aft blocks may need to move instructions to fore blocks, which becomes more
935   // difficult if there are multiple (potentially conditionally executed)
936   // blocks. For now we just exclude loops with multiple aft blocks.
937   if (AftBlocksMap[L].size() != 1) {
938     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Can't currently handle "
939                          "multiple blocks after the loop\n");
940     return false;
941   }
942 
943   // Check inner loop backedge count is consistent on all iterations of the
944   // outer loop
945   if (any_of(L->getLoopsInPreorder(), [&SE](Loop *SubLoop) {
946         return !hasIterationCountInvariantInParent(SubLoop, SE);
947       })) {
948     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Inner loop iteration count is "
949                          "not consistent on each iteration\n");
950     return false;
951   }
952 
953   // Check the loop safety info for exceptions.
954   SimpleLoopSafetyInfo LSI;
955   LSI.computeLoopSafetyInfo(L);
956   if (LSI.anyBlockMayThrow()) {
957     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Something may throw\n");
958     return false;
959   }
960 
961   // We've ruled out the easy stuff and now need to check that there are no
962   // interdependencies which may prevent us from moving the:
963   //  ForeBlocks before Subloop and AftBlocks.
964   //  Subloop before AftBlocks.
965   //  ForeBlock phi operands before the subloop
966 
967   // Make sure we can move all instructions we need to before the subloop
968   BasicBlock *Header = L->getHeader();
969   BasicBlock *Latch = L->getLoopLatch();
970   BasicBlockSet AftBlocks = AftBlocksMap[L];
971   Loop *SubLoop = L->getSubLoops()[0];
972   if (!processHeaderPhiOperands(
973           Header, Latch, AftBlocks, [&AftBlocks, &SubLoop](Instruction *I) {
974             if (SubLoop->contains(I->getParent()))
975               return false;
976             if (AftBlocks.count(I->getParent())) {
977               // If we hit a phi node in afts we know we are done (probably
978               // LCSSA)
979               if (isa<PHINode>(I))
980                 return false;
981               // Can't move instructions with side effects or memory
982               // reads/writes
983               if (I->mayHaveSideEffects() || I->mayReadOrWriteMemory())
984                 return false;
985             }
986             // Keep going
987             return true;
988           })) {
989     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; can't move required "
990                          "instructions after subloop to before it\n");
991     return false;
992   }
993 
994   // Check for memory dependencies which prohibit the unrolling we are doing.
995   // Because of the way we are unrolling Fore/Sub/Aft blocks, we need to check
996   // there are no dependencies between Fore-Sub, Fore-Aft, Sub-Aft and Sub-Sub.
997   if (!checkDependencies(*L, SubLoopBlocks, ForeBlocksMap, AftBlocksMap, DI,
998                          LI)) {
999     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; failed dependency check\n");
1000     return false;
1001   }
1002 
1003   return true;
1004 }
1005