//===-- LoopUnrollAndJam.cpp - Loop unrolling utilities -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements loop unroll and jam as a routine, much like
// LoopUnroll.cpp implements loop unroll.
//
//===----------------------------------------------------------------------===//

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Twine.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/DependenceAnalysis.h"
#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/LoopIterator.h"
#include "llvm/Analysis/MustExecute.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/DebugLoc.h"
#include "llvm/IR/DiagnosticInfo.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Use.h"
#include "llvm/IR/User.h"
#include "llvm/IR/Value.h"
#include "llvm/IR/ValueHandle.h"
#include "llvm/IR/ValueMap.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/GenericDomTree.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Utils/LoopUtils.h"
#include "llvm/Transforms/Utils/UnrollLoop.h"
#include "llvm/Transforms/Utils/ValueMapper.h"
#include <assert.h>
#include <memory>
#include <type_traits>
#include <vector>

using namespace llvm;

#define DEBUG_TYPE "loop-unroll-and-jam"

STATISTIC(NumUnrolledAndJammed, "Number of loops unroll and jammed");
STATISTIC(NumCompletelyUnrolledAndJammed,
          "Number of loops completely unroll and jammed");

typedef SmallPtrSet<BasicBlock *, 4> BasicBlockSet;

// Partition blocks in an outer/inner loop pair into blocks before and after
// the loop
static bool partitionLoopBlocks(Loop &L, BasicBlockSet &ForeBlocks,
                                BasicBlockSet &AftBlocks, DominatorTree &DT) {
  Loop *SubLoop = L.getSubLoops()[0];
  BasicBlock *SubLoopLatch = SubLoop->getLoopLatch();

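  // A block outside the subloop that is dominated by the subloop latch can
  // only be reached after the subloop has executed, so it belongs in Aft;
  // every other block outside the subloop runs before it and belongs in Fore.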
  for (BasicBlock *BB : L.blocks()) {
    if (!SubLoop->contains(BB)) {
      if (DT.dominates(SubLoopLatch, BB))
        AftBlocks.insert(BB);
      else
        ForeBlocks.insert(BB);
    }
  }

  // Check that all blocks in ForeBlocks together dominate the subloop
  // TODO: This might ideally be done better with dominators/postdominators.
  BasicBlock *SubLoopPreHeader = SubLoop->getLoopPreheader();
  for (BasicBlock *BB : ForeBlocks) {
    if (BB == SubLoopPreHeader)
      continue;
    Instruction *TI = BB->getTerminator();
    for (BasicBlock *Succ : successors(TI))
      if (!ForeBlocks.count(Succ))
        return false;
  }

  return true;
}

/// Partition blocks in a loop nest into blocks before and after each inner
/// loop.
static bool partitionOuterLoopBlocks(
    Loop &Root, Loop &JamLoop, BasicBlockSet &JamLoopBlocks,
    DenseMap<Loop *, BasicBlockSet> &ForeBlocksMap,
    DenseMap<Loop *, BasicBlockSet> &AftBlocksMap, DominatorTree &DT) {
  JamLoopBlocks.insert(JamLoop.block_begin(), JamLoop.block_end());

  for (Loop *L : Root.getLoopsInPreorder()) {
    if (L == &JamLoop)
      break;

    if (!partitionLoopBlocks(*L, ForeBlocksMap[L], AftBlocksMap[L], DT))
      return false;
  }

  return true;
}

// TODO: Remove when UnrollAndJamLoop is changed to support unroll and jamming
// more than two loop levels.
static bool partitionOuterLoopBlocks(Loop *L, Loop *SubLoop,
                                     BasicBlockSet &ForeBlocks,
                                     BasicBlockSet &SubLoopBlocks,
                                     BasicBlockSet &AftBlocks,
                                     DominatorTree *DT) {
  SubLoopBlocks.insert(SubLoop->block_begin(), SubLoop->block_end());
  return partitionLoopBlocks(*L, ForeBlocks, AftBlocks, *DT);
}

// Looks at the phi nodes in Header for values coming from Latch. For these
// instructions and all their operands, calls Visit on them, recursing into
// operands that are defined in AftBlocks. Returns false if Visit returns
// false, otherwise returns true. This is used to process the instructions in
// the Aft blocks that need to be moved before the subloop. It is used in two
// places: once to check that the required set of instructions can be moved
// before the loop, and then in moveHeaderPhiOperandsToForeBlocks to collect
// the instructions to actually move.
template <typename T>
static bool processHeaderPhiOperands(BasicBlock *Header, BasicBlock *Latch,
                                     BasicBlockSet &AftBlocks, T Visit) {
  SmallVector<Instruction *, 8> Worklist;
  SmallPtrSet<Instruction *, 8> VisitedInstr;
  for (auto &Phi : Header->phis()) {
    Value *V = Phi.getIncomingValueForBlock(Latch);
    if (Instruction *I = dyn_cast<Instruction>(V))
      Worklist.push_back(I);
  }

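  // Walk the worklist, visiting each instruction at most once. Operands are
  // only followed for instructions that live in an Aft block; values defined
  // elsewhere are still visited, but their operands are not explored further.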
  while (!Worklist.empty()) {
    Instruction *I = Worklist.pop_back_val();
    if (!Visit(I))
      return false;
    VisitedInstr.insert(I);

    if (AftBlocks.count(I->getParent()))
      for (auto &U : I->operands())
        if (Instruction *II = dyn_cast<Instruction>(U))
          if (!VisitedInstr.count(II))
            Worklist.push_back(II);
  }

  return true;
}

// Move the phi operands of Header from Latch out of AftBlocks to InsertLoc.
static void moveHeaderPhiOperandsToForeBlocks(BasicBlock *Header,
                                              BasicBlock *Latch,
                                              Instruction *InsertLoc,
                                              BasicBlockSet &AftBlocks) {
  // We need to ensure we move the instructions in the correct order,
  // starting with the earliest required instruction and moving forward.
  std::vector<Instruction *> Visited;
  processHeaderPhiOperands(Header, Latch, AftBlocks,
                           [&Visited, &AftBlocks](Instruction *I) {
                             if (AftBlocks.count(I->getParent()))
                               Visited.push_back(I);
                             return true;
                           });

  // Move all instructions in program order to before the InsertLoc
  BasicBlock *InsertLocBB = InsertLoc->getParent();
  for (Instruction *I : reverse(Visited)) {
    if (I->getParent() != InsertLocBB)
      I->moveBefore(InsertLoc);
  }
}

/*
  This method performs Unroll and Jam. For a simple loop like:
  for (i = ..)
    Fore(i)
    for (j = ..)
      SubLoop(i, j)
    Aft(i)

  Instead of doing normal inner or outer unrolling, we do:
  for (i = .., i+=2)
    Fore(i)
    Fore(i+1)
    for (j = ..)
      SubLoop(i, j)
      SubLoop(i+1, j)
    Aft(i)
    Aft(i+1)

  So the outer loop is essentially unrolled and then the inner loops are fused
  ("jammed") together into a single loop. This can increase speed when there
  are loads in SubLoop that are invariant to i, as they become shared between
  the now jammed inner loops.

  We do this by splitting the blocks in the loop into Fore, Subloop and Aft.
  Fore blocks are those before the inner loop, Aft are those after. Normal
  Unroll code is used to copy each of these sets of blocks and the results are
  combined together into the final form above.

  isSafeToUnrollAndJam should be used prior to calling this to make sure the
  unrolling will be valid. Checking profitability is also advisable.

  If EpilogueLoop is non-null, it receives the epilogue loop (if it was
  necessary to create one and not fully unrolled).
*/
LoopUnrollResult
llvm::UnrollAndJamLoop(Loop *L, unsigned Count, unsigned TripCount,
                       unsigned TripMultiple, bool UnrollRemainder,
                       LoopInfo *LI, ScalarEvolution *SE, DominatorTree *DT,
                       AssumptionCache *AC, const TargetTransformInfo *TTI,
                       OptimizationRemarkEmitter *ORE, Loop **EpilogueLoop) {

  // When we enter here we should have already checked that it is safe
  BasicBlock *Header = L->getHeader();
  assert(Header && "No header.");
  assert(L->getSubLoops().size() == 1);
  Loop *SubLoop = *L->begin();

  // Don't enter the unroll code if there is nothing to do.
  if (TripCount == 0 && Count < 2) {
    LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; almost nothing to do\n");
    return LoopUnrollResult::Unmodified;
  }

  assert(Count > 0);
  assert(TripMultiple > 0);
  assert(TripCount == 0 || TripCount % TripMultiple == 0);

  // Are we eliminating the loop control altogether?
  bool CompletelyUnroll = (Count == TripCount);

  // We use the runtime remainder in cases where we don't know trip multiple
  if (TripMultiple % Count != 0) {
    if (!UnrollRuntimeLoopRemainder(L, Count, /*AllowExpensiveTripCount*/ false,
                                    /*UseEpilogRemainder*/ true,
                                    UnrollRemainder, /*ForgetAllSCEV*/ false,
                                    LI, SE, DT, AC, TTI, true, EpilogueLoop)) {
      LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; remainder loop could not be "
                           "generated when assuming runtime trip count\n");
      return LoopUnrollResult::Unmodified;
    }
  }

  // Notify ScalarEvolution that the loop will be substantially changed,
  // if not outright eliminated.
  if (SE) {
    SE->forgetLoop(L);
    SE->forgetLoop(SubLoop);
  }

  using namespace ore;
  // Report the unrolling decision.
  if (CompletelyUnroll) {
    LLVM_DEBUG(dbgs() << "COMPLETELY UNROLL AND JAMMING loop %"
                      << Header->getName() << " with trip count " << TripCount
                      << "!\n");
    ORE->emit(OptimizationRemark(DEBUG_TYPE, "FullyUnrolled", L->getStartLoc(),
                                 L->getHeader())
              << "completely unroll and jammed loop with "
              << NV("UnrollCount", TripCount) << " iterations");
  } else {
    auto DiagBuilder = [&]() {
      OptimizationRemark Diag(DEBUG_TYPE, "PartialUnrolled", L->getStartLoc(),
                              L->getHeader());
      return Diag << "unroll and jammed loop by a factor of "
                  << NV("UnrollCount", Count);
    };

    LLVM_DEBUG(dbgs() << "UNROLL AND JAMMING loop %" << Header->getName()
                      << " by " << Count);
    if (TripMultiple != 1) {
      LLVM_DEBUG(dbgs() << " with " << TripMultiple << " trips per branch");
      ORE->emit([&]() {
        return DiagBuilder() << " with " << NV("TripMultiple", TripMultiple)
                             << " trips per branch";
      });
    } else {
      LLVM_DEBUG(dbgs() << " with run-time trip count");
      ORE->emit([&]() { return DiagBuilder() << " with run-time trip count"; });
    }
    LLVM_DEBUG(dbgs() << "!\n");
  }

  BasicBlock *Preheader = L->getLoopPreheader();
  BasicBlock *LatchBlock = L->getLoopLatch();
  assert(Preheader && "No preheader");
  assert(LatchBlock && "No latch block");
  BranchInst *BI = dyn_cast<BranchInst>(LatchBlock->getTerminator());
  assert(BI && !BI->isUnconditional());
  bool ContinueOnTrue = L->contains(BI->getSuccessor(0));
  BasicBlock *LoopExit = BI->getSuccessor(ContinueOnTrue);
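  // Work out which successor of the subloop latch branch stays inside the
  // subloop, so that the jammed copies can be rewired with the same polarity.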
  bool SubLoopContinueOnTrue = SubLoop->contains(
      SubLoop->getLoopLatch()->getTerminator()->getSuccessor(0));

  // Partition blocks in an outer/inner loop pair into blocks before and after
  // the loop
  BasicBlockSet SubLoopBlocks;
  BasicBlockSet ForeBlocks;
  BasicBlockSet AftBlocks;
  partitionOuterLoopBlocks(L, SubLoop, ForeBlocks, SubLoopBlocks, AftBlocks,
                           DT);

  // We keep track of the entering/first and exiting/last block of each of
  // Fore/SubLoop/Aft in each iteration. This helps make the stapling up of
  // blocks easier.
  std::vector<BasicBlock *> ForeBlocksFirst;
  std::vector<BasicBlock *> ForeBlocksLast;
  std::vector<BasicBlock *> SubLoopBlocksFirst;
  std::vector<BasicBlock *> SubLoopBlocksLast;
  std::vector<BasicBlock *> AftBlocksFirst;
  std::vector<BasicBlock *> AftBlocksLast;
  ForeBlocksFirst.push_back(Header);
  ForeBlocksLast.push_back(SubLoop->getLoopPreheader());
  SubLoopBlocksFirst.push_back(SubLoop->getHeader());
  SubLoopBlocksLast.push_back(SubLoop->getExitingBlock());
  AftBlocksFirst.push_back(SubLoop->getExitBlock());
  AftBlocksLast.push_back(L->getExitingBlock());
  // Maps Blocks[0] -> Blocks[It]
  ValueToValueMapTy LastValueMap;

  // Move any instructions in the Aft blocks that the outer loop's header phis
  // depend on (via their latch operands) up into the Fore blocks.
  moveHeaderPhiOperandsToForeBlocks(
      Header, LatchBlock, ForeBlocksLast[0]->getTerminator(), AftBlocks);

  // The current on-the-fly SSA update requires blocks to be processed in
  // reverse postorder so that LastValueMap contains the correct value at each
  // exit.
  LoopBlocksDFS DFS(L);
  DFS.perform(LI);
  // Stash the DFS iterators before adding blocks to the loop.
  LoopBlocksDFS::RPOIterator BlockBegin = DFS.beginRPO();
  LoopBlocksDFS::RPOIterator BlockEnd = DFS.endRPO();

  // When an FSDiscriminator is enabled, we don't need to add the multiply
  // factors to the discriminators.
  if (Header->getParent()->isDebugInfoForProfiling() && !EnableFSDiscriminator)
    for (BasicBlock *BB : L->getBlocks())
      for (Instruction &I : *BB)
        if (!isa<DbgInfoIntrinsic>(&I))
          if (const DILocation *DIL = I.getDebugLoc()) {
            auto NewDIL = DIL->cloneByMultiplyingDuplicationFactor(Count);
            if (NewDIL)
              I.setDebugLoc(NewDIL.getValue());
            else
              LLVM_DEBUG(dbgs()
                         << "Failed to create new discriminator: "
                         << DIL->getFilename() << " Line: " << DIL->getLine());
          }

  // Copy all blocks
  for (unsigned It = 1; It != Count; ++It) {
    SmallVector<BasicBlock *, 8> NewBlocks;
    // Maps Blocks[It] -> Blocks[It-1]
    DenseMap<Value *, Value *> PrevItValueMap;
    SmallDenseMap<const Loop *, Loop *, 4> NewLoops;
    NewLoops[L] = L;
    NewLoops[SubLoop] = SubLoop;

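    // Clone every block of the original iteration in reverse postorder,
    // register each clone with LoopInfo and the DomTree, and record it in the
    // per-iteration value maps so later uses can be remapped onto the clones.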
    for (LoopBlocksDFS::RPOIterator BB = BlockBegin; BB != BlockEnd; ++BB) {
      ValueToValueMapTy VMap;
      BasicBlock *New = CloneBasicBlock(*BB, VMap, "." + Twine(It));
      Header->getParent()->getBasicBlockList().push_back(New);

      // Tell LI about New.
      addClonedBlockToLoopInfo(*BB, New, LI, NewLoops);

      if (ForeBlocks.count(*BB)) {
        if (*BB == ForeBlocksFirst[0])
          ForeBlocksFirst.push_back(New);
        if (*BB == ForeBlocksLast[0])
          ForeBlocksLast.push_back(New);
      } else if (SubLoopBlocks.count(*BB)) {
        if (*BB == SubLoopBlocksFirst[0])
          SubLoopBlocksFirst.push_back(New);
        if (*BB == SubLoopBlocksLast[0])
          SubLoopBlocksLast.push_back(New);
      } else if (AftBlocks.count(*BB)) {
        if (*BB == AftBlocksFirst[0])
          AftBlocksFirst.push_back(New);
        if (*BB == AftBlocksLast[0])
          AftBlocksLast.push_back(New);
      } else {
        llvm_unreachable("BB being cloned should be in Fore/Sub/Aft");
      }

      // Update our running maps of newest clones
      PrevItValueMap[New] = (It == 1 ? *BB : LastValueMap[*BB]);
      LastValueMap[*BB] = New;
      for (ValueToValueMapTy::iterator VI = VMap.begin(), VE = VMap.end();
           VI != VE; ++VI) {
        PrevItValueMap[VI->second] =
            const_cast<Value *>(It == 1 ? VI->first : LastValueMap[VI->first]);
        LastValueMap[VI->first] = VI->second;
      }

      NewBlocks.push_back(New);

      // Update DomTree:
      if (*BB == ForeBlocksFirst[0])
        DT->addNewBlock(New, ForeBlocksLast[It - 1]);
      else if (*BB == SubLoopBlocksFirst[0])
        DT->addNewBlock(New, SubLoopBlocksLast[It - 1]);
      else if (*BB == AftBlocksFirst[0])
        DT->addNewBlock(New, AftBlocksLast[It - 1]);
      else {
        // Each set of blocks (Fore/Sub/Aft) will have the same internal domtree
        // structure.
        auto BBDomNode = DT->getNode(*BB);
        auto BBIDom = BBDomNode->getIDom();
        BasicBlock *OriginalBBIDom = BBIDom->getBlock();
        assert(OriginalBBIDom);
        assert(LastValueMap[cast<Value>(OriginalBBIDom)]);
        DT->addNewBlock(
            New, cast<BasicBlock>(LastValueMap[cast<Value>(OriginalBBIDom)]));
      }
    }

    // Remap all instructions in the most recent iteration
    remapInstructionsInBlocks(NewBlocks, LastValueMap);
    for (BasicBlock *NewBlock : NewBlocks) {
      for (Instruction &I : *NewBlock) {
        if (auto *II = dyn_cast<AssumeInst>(&I))
          AC->registerAssumption(II);
      }
    }

    // Alter the ForeBlocks phis, pointing them at the latest version of the
    // value from the previous iteration's phis
    for (PHINode &Phi : ForeBlocksFirst[It]->phis()) {
      Value *OldValue = Phi.getIncomingValueForBlock(AftBlocksLast[It]);
      assert(OldValue && "should have incoming edge from Aft[It]");
      Value *NewValue = OldValue;
      if (Value *PrevValue = PrevItValueMap[OldValue])
        NewValue = PrevValue;

      assert(Phi.getNumOperands() == 2);
      Phi.setIncomingBlock(0, ForeBlocksLast[It - 1]);
      Phi.setIncomingValue(0, NewValue);
      Phi.removeIncomingValue(1);
    }
  }

  // Now that all the basic blocks for the unrolled iterations are in place,
  // finish up connecting the blocks and phi nodes. At this point LastValueMap
  // holds the values of the last unrolled iteration.

  // Update Phis in BB from OldBB to point to NewBB and use the latest value
  // from LastValueMap
  auto updatePHIBlocksAndValues = [](BasicBlock *BB, BasicBlock *OldBB,
                                     BasicBlock *NewBB,
                                     ValueToValueMapTy &LastValueMap) {
    for (PHINode &Phi : BB->phis()) {
      for (unsigned b = 0; b < Phi.getNumIncomingValues(); ++b) {
        if (Phi.getIncomingBlock(b) == OldBB) {
          Value *OldValue = Phi.getIncomingValue(b);
          if (Value *LastValue = LastValueMap[OldValue])
            Phi.setIncomingValue(b, LastValue);
          Phi.setIncomingBlock(b, NewBB);
          break;
        }
      }
    }
  };
  // Move all the phis from Src into Dest
  auto movePHIs = [](BasicBlock *Src, BasicBlock *Dest) {
    Instruction *insertPoint = Dest->getFirstNonPHI();
    while (PHINode *Phi = dyn_cast<PHINode>(Src->begin()))
      Phi->moveBefore(insertPoint);
  };

  // Update the PHI values outside the loop to point to the last block
  updatePHIBlocksAndValues(LoopExit, AftBlocksLast[0], AftBlocksLast.back(),
                           LastValueMap);

  // Update ForeBlocks successors and phi nodes
  BranchInst *ForeTerm =
      cast<BranchInst>(ForeBlocksLast.back()->getTerminator());
  assert(ForeTerm->getNumSuccessors() == 1 && "Expecting one successor");
  ForeTerm->setSuccessor(0, SubLoopBlocksFirst[0]);

  if (CompletelyUnroll) {
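    // The outer loop control is going away entirely, so each header phi can
    // simply be folded to its incoming value from the preheader and removed.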
    while (PHINode *Phi = dyn_cast<PHINode>(ForeBlocksFirst[0]->begin())) {
      Phi->replaceAllUsesWith(Phi->getIncomingValueForBlock(Preheader));
      Phi->getParent()->getInstList().erase(Phi);
    }
  } else {
    // Update the PHI values to point to the last aft block
    updatePHIBlocksAndValues(ForeBlocksFirst[0], AftBlocksLast[0],
                             AftBlocksLast.back(), LastValueMap);
  }

  for (unsigned It = 1; It != Count; It++) {
    // Remap ForeBlock successors from previous iteration to this
    BranchInst *ForeTerm =
        cast<BranchInst>(ForeBlocksLast[It - 1]->getTerminator());
    assert(ForeTerm->getNumSuccessors() == 1 && "Expecting one successor");
    ForeTerm->setSuccessor(0, ForeBlocksFirst[It]);
  }

  // Subloop successors and phis
  BranchInst *SubTerm =
      cast<BranchInst>(SubLoopBlocksLast.back()->getTerminator());
  SubTerm->setSuccessor(!SubLoopContinueOnTrue, SubLoopBlocksFirst[0]);
  SubTerm->setSuccessor(SubLoopContinueOnTrue, AftBlocksFirst[0]);
  SubLoopBlocksFirst[0]->replacePhiUsesWith(ForeBlocksLast[0],
                                            ForeBlocksLast.back());
  SubLoopBlocksFirst[0]->replacePhiUsesWith(SubLoopBlocksLast[0],
                                            SubLoopBlocksLast.back());

  for (unsigned It = 1; It != Count; It++) {
    // Replace the conditional branch of the previous iteration's subloop with
    // an unconditional branch to this iteration's subloop.
    BranchInst *SubTerm =
        cast<BranchInst>(SubLoopBlocksLast[It - 1]->getTerminator());
    BranchInst::Create(SubLoopBlocksFirst[It], SubTerm);
    SubTerm->eraseFromParent();

    SubLoopBlocksFirst[It]->replacePhiUsesWith(ForeBlocksLast[It],
                                               ForeBlocksLast.back());
    SubLoopBlocksFirst[It]->replacePhiUsesWith(SubLoopBlocksLast[It],
                                               SubLoopBlocksLast.back());
    movePHIs(SubLoopBlocksFirst[It], SubLoopBlocksFirst[0]);
  }

  // Aft blocks successors and phis
  BranchInst *AftTerm = cast<BranchInst>(AftBlocksLast.back()->getTerminator());
  if (CompletelyUnroll) {
    BranchInst::Create(LoopExit, AftTerm);
    AftTerm->eraseFromParent();
  } else {
    AftTerm->setSuccessor(!ContinueOnTrue, ForeBlocksFirst[0]);
    assert(AftTerm->getSuccessor(ContinueOnTrue) == LoopExit &&
           "Expecting the ContinueOnTrue successor of AftTerm to be LoopExit");
  }
  AftBlocksFirst[0]->replacePhiUsesWith(SubLoopBlocksLast[0],
                                        SubLoopBlocksLast.back());

  for (unsigned It = 1; It != Count; It++) {
    // Replace the conditional branch of the previous iteration's Aft block
    // with an unconditional branch to this iteration's Aft block.
    BranchInst *AftTerm =
        cast<BranchInst>(AftBlocksLast[It - 1]->getTerminator());
    BranchInst::Create(AftBlocksFirst[It], AftTerm);
    AftTerm->eraseFromParent();

    AftBlocksFirst[It]->replacePhiUsesWith(SubLoopBlocksLast[It],
                                           SubLoopBlocksLast.back());
    movePHIs(AftBlocksFirst[It], AftBlocksFirst[0]);
  }

  DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
  // Dominator Tree. Remove the old links between Fore, Sub and Aft, adding the
  // new ones required.
  if (Count != 1) {
    SmallVector<DominatorTree::UpdateType, 4> DTUpdates;
    DTUpdates.emplace_back(DominatorTree::UpdateKind::Delete, ForeBlocksLast[0],
                           SubLoopBlocksFirst[0]);
    DTUpdates.emplace_back(DominatorTree::UpdateKind::Delete,
                           SubLoopBlocksLast[0], AftBlocksFirst[0]);

    DTUpdates.emplace_back(DominatorTree::UpdateKind::Insert,
                           ForeBlocksLast.back(), SubLoopBlocksFirst[0]);
    DTUpdates.emplace_back(DominatorTree::UpdateKind::Insert,
                           SubLoopBlocksLast.back(), AftBlocksFirst[0]);
    DTU.applyUpdatesPermissive(DTUpdates);
  }

  // Merge adjacent basic blocks, if possible.
  SmallPtrSet<BasicBlock *, 16> MergeBlocks;
  MergeBlocks.insert(ForeBlocksLast.begin(), ForeBlocksLast.end());
  MergeBlocks.insert(SubLoopBlocksLast.begin(), SubLoopBlocksLast.end());
  MergeBlocks.insert(AftBlocksLast.begin(), AftBlocksLast.end());

  MergeBlockSuccessorsIntoGivenBlocks(MergeBlocks, L, &DTU, LI);

  // Apply updates to the DomTree.
  DT = &DTU.getDomTree();

  // At this point, the code is well formed.  We now do a quick sweep over the
  // inserted code, doing constant propagation and dead code elimination as we
  // go.
  simplifyLoopAfterUnroll(SubLoop, true, LI, SE, DT, AC, TTI);
  simplifyLoopAfterUnroll(L, !CompletelyUnroll && Count > 1, LI, SE, DT, AC,
                          TTI);

  NumCompletelyUnrolledAndJammed += CompletelyUnroll;
  ++NumUnrolledAndJammed;

  // Update LoopInfo if the loop is completely removed.
  if (CompletelyUnroll)
    LI->erase(L);

#ifndef NDEBUG
  // We shouldn't have done anything to break loop simplify form or LCSSA.
  Loop *OutestLoop = SubLoop->getParentLoop()
                         ? SubLoop->getParentLoop()->getParentLoop()
                               ? SubLoop->getParentLoop()->getParentLoop()
                               : SubLoop->getParentLoop()
                         : SubLoop;
  assert(DT->verify());
  LI->verify(*DT);
  assert(OutestLoop->isRecursivelyLCSSAForm(*DT, *LI));
  if (!CompletelyUnroll)
    assert(L->isLoopSimplifyForm());
  assert(SubLoop->isLoopSimplifyForm());
  SE->verify();
#endif

  return CompletelyUnroll ? LoopUnrollResult::FullyUnrolled
                          : LoopUnrollResult::PartiallyUnrolled;
}

static bool getLoadsAndStores(BasicBlockSet &Blocks,
                              SmallVector<Instruction *, 4> &MemInstr) {
  // Scan the BBs and collect legal loads and stores.
  // Returns false if non-simple loads/stores are found.
  for (BasicBlock *BB : Blocks) {
    for (Instruction &I : *BB) {
      if (auto *Ld = dyn_cast<LoadInst>(&I)) {
        if (!Ld->isSimple())
          return false;
        MemInstr.push_back(&I);
      } else if (auto *St = dyn_cast<StoreInst>(&I)) {
        if (!St->isSimple())
          return false;
        MemInstr.push_back(&I);
      } else if (I.mayReadOrWriteMemory()) {
        return false;
      }
    }
  }
  return true;
}

static bool preservesForwardDependence(Instruction *Src, Instruction *Dst,
                                       unsigned UnrollLevel, unsigned JamLevel,
                                       bool Sequentialized, Dependence *D) {
  // The UnrollLevel loop might carry the dependency Src --> Dst. Check whether
  // a loop at a deeper (jammed) level still carries it after unrolling.
  for (unsigned CurLoopDepth = UnrollLevel + 1; CurLoopDepth <= JamLevel;
       ++CurLoopDepth) {
    auto JammedDir = D->getDirection(CurLoopDepth);
    if (JammedDir == Dependence::DVEntry::LT)
      return true;

    if (JammedDir & Dependence::DVEntry::GT)
      return false;
  }

  return true;
}

static bool preservesBackwardDependence(Instruction *Src, Instruction *Dst,
                                        unsigned UnrollLevel, unsigned JamLevel,
                                        bool Sequentialized, Dependence *D) {
  // UnrollLevel might carry the dependency Dst --> Src
  for (unsigned CurLoopDepth = UnrollLevel + 1; CurLoopDepth <= JamLevel;
       ++CurLoopDepth) {
    auto JammedDir = D->getDirection(CurLoopDepth);
    if (JammedDir == Dependence::DVEntry::GT)
      return true;

    if (JammedDir & Dependence::DVEntry::LT)
      return false;
  }

  // Backward dependencies are only preserved if not interleaved.
  return Sequentialized;
}

// Check whether unroll and jam is semantically safe for Src and Dst,
// considering any potential dependency between them.
//
// @param UnrollLevel The level of the loop being unrolled
// @param JamLevel    The level of the loop being jammed; if Src and Dst are on
// different levels, the outermost common loop counts as jammed level
//
// @return true if it is safe and false if there is a dependency violation.
static bool checkDependency(Instruction *Src, Instruction *Dst,
                            unsigned UnrollLevel, unsigned JamLevel,
                            bool Sequentialized, DependenceInfo &DI) {
  assert(UnrollLevel <= JamLevel &&
         "Expecting JamLevel to be at least UnrollLevel");

  if (Src == Dst)
    return true;
  // Ignore Input dependencies.
  if (isa<LoadInst>(Src) && isa<LoadInst>(Dst))
    return true;

  // Check whether unroll-and-jam may violate a dependency.
  // By construction, every dependency will be lexicographically non-negative
  // (if it were not, it would violate the current execution order), such as
  //   (0,0,>,*,*)
  // Unroll-and-jam changes the GT execution of two executions to the same
  // iteration of the chosen unroll level. That is, a GT dependence becomes a GE
  // dependence (or EQ, if we fully unrolled the loop) at the loop's position:
  //   (0,0,>=,*,*)
  // Now, the dependency is not necessarily non-negative anymore, i.e.
  // unroll-and-jam may violate correctness.
  std::unique_ptr<Dependence> D = DI.depends(Src, Dst, true);
  if (!D)
    return true;
  assert(D->isOrdered() && "Expected an output, flow or anti dep.");

  if (D->isConfused()) {
    LLVM_DEBUG(dbgs() << "  Confused dependency between:\n"
                      << "  " << *Src << "\n"
                      << "  " << *Dst << "\n");
    return false;
  }

  // If outer levels (levels enclosing the loop being unroll-and-jammed) have a
  // non-equal direction, then the locations accessed in the inner levels cannot
  // overlap in memory. We assume the indexes never overlap into neighboring
  // dimensions.
  for (unsigned CurLoopDepth = 1; CurLoopDepth < UnrollLevel; ++CurLoopDepth)
    if (!(D->getDirection(CurLoopDepth) & Dependence::DVEntry::EQ))
      return true;

  auto UnrollDirection = D->getDirection(UnrollLevel);

  // If the distance carried by the unrolled loop is 0, then after unrolling
  // that distance will become non-zero resulting in non-overlapping accesses in
  // the inner loops.
  if (UnrollDirection == Dependence::DVEntry::EQ)
    return true;

  if (UnrollDirection & Dependence::DVEntry::LT &&
      !preservesForwardDependence(Src, Dst, UnrollLevel, JamLevel,
                                  Sequentialized, D.get()))
    return false;

  if (UnrollDirection & Dependence::DVEntry::GT &&
      !preservesBackwardDependence(Src, Dst, UnrollLevel, JamLevel,
                                   Sequentialized, D.get()))
    return false;

  return true;
}

static bool
checkDependencies(Loop &Root, const BasicBlockSet &SubLoopBlocks,
                  const DenseMap<Loop *, BasicBlockSet> &ForeBlocksMap,
                  const DenseMap<Loop *, BasicBlockSet> &AftBlocksMap,
                  DependenceInfo &DI, LoopInfo &LI) {
  SmallVector<BasicBlockSet, 8> AllBlocks;
  for (Loop *L : Root.getLoopsInPreorder())
    if (ForeBlocksMap.find(L) != ForeBlocksMap.end())
      AllBlocks.push_back(ForeBlocksMap.lookup(L));
  AllBlocks.push_back(SubLoopBlocks);
  for (Loop *L : Root.getLoopsInPreorder())
    if (AftBlocksMap.find(L) != AftBlocksMap.end())
      AllBlocks.push_back(AftBlocksMap.lookup(L));

  unsigned LoopDepth = Root.getLoopDepth();
  SmallVector<Instruction *, 4> EarlierLoadsAndStores;
  SmallVector<Instruction *, 4> CurrentLoadsAndStores;
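  // Check each pair of memory accesses: accesses in earlier block sets against
  // the current set (these get interleaved by the jam, so Sequentialized is
  // false), and accesses within the current set against each other (these keep
  // their original order, so Sequentialized is true).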
  for (BasicBlockSet &Blocks : AllBlocks) {
    CurrentLoadsAndStores.clear();
    if (!getLoadsAndStores(Blocks, CurrentLoadsAndStores))
      return false;

    Loop *CurLoop = LI.getLoopFor((*Blocks.begin())->front().getParent());
    unsigned CurLoopDepth = CurLoop->getLoopDepth();

    for (auto *Earlier : EarlierLoadsAndStores) {
      Loop *EarlierLoop = LI.getLoopFor(Earlier->getParent());
      unsigned EarlierDepth = EarlierLoop->getLoopDepth();
      unsigned CommonLoopDepth = std::min(EarlierDepth, CurLoopDepth);
      for (auto *Later : CurrentLoadsAndStores) {
        if (!checkDependency(Earlier, Later, LoopDepth, CommonLoopDepth, false,
                             DI))
          return false;
      }
    }

    size_t NumInsts = CurrentLoadsAndStores.size();
    for (size_t I = 0; I < NumInsts; ++I) {
      for (size_t J = I; J < NumInsts; ++J) {
        if (!checkDependency(CurrentLoadsAndStores[I], CurrentLoadsAndStores[J],
                             LoopDepth, CurLoopDepth, true, DI))
          return false;
      }
    }

    EarlierLoadsAndStores.append(CurrentLoadsAndStores.begin(),
                                 CurrentLoadsAndStores.end());
  }
  return true;
}

static bool isEligibleLoopForm(const Loop &Root) {
  // Root must have a child.
  if (Root.getSubLoops().size() != 1)
    return false;

  const Loop *L = &Root;
  do {
    // All loops in Root need to be in simplify and rotated form.
    if (!L->isLoopSimplifyForm())
      return false;

    if (!L->isRotatedForm())
      return false;

    if (L->getHeader()->hasAddressTaken()) {
      LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Address taken\n");
      return false;
    }

    unsigned SubLoopsSize = L->getSubLoops().size();
    if (SubLoopsSize == 0)
      return true;

    // Only one child is allowed.
    if (SubLoopsSize != 1)
      return false;

    // Only loops with a single exit block can be unrolled and jammed.
    // The function getExitBlock() is used for this check, rather than
    // getUniqueExitBlock() to ensure loops with multiple exit edges are
    // disallowed.
    if (!L->getExitBlock()) {
      LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; only loops with single exit "
                           "blocks can be unrolled and jammed.\n");
      return false;
    }

    // Only loops with a single exiting block can be unrolled and jammed.
    if (!L->getExitingBlock()) {
      LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; only loops with single "
                           "exiting blocks can be unrolled and jammed.\n");
      return false;
    }

    L = L->getSubLoops()[0];
  } while (L);

  return true;
}

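// Walk down to the innermost loop of the nest by always taking the first
// subloop; the caller has already checked (via isEligibleLoopForm) that each
// level has a single child.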
static Loop *getInnerMostLoop(Loop *L) {
  while (!L->getSubLoops().empty())
    L = L->getSubLoops()[0];
  return L;
}

bool llvm::isSafeToUnrollAndJam(Loop *L, ScalarEvolution &SE, DominatorTree &DT,
                                DependenceInfo &DI, LoopInfo &LI) {
  if (!isEligibleLoopForm(*L)) {
    LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Ineligible loop form\n");
    return false;
  }

  /* We currently handle outer loops like this:
        |
    ForeFirst    <------\   }
     Blocks             |   } ForeBlocks of L
    ForeLast            |   }
        |               |
       ...              |
        |               |
    ForeFirst    <----\ |   }
     Blocks           | |   } ForeBlocks of an inner loop of L
    ForeLast          | |   }
        |             | |
    JamLoopFirst  <\  | |   }
     Blocks        |  | |   } JamLoopBlocks of the innermost loop
    JamLoopLast   -/  | |   }
        |             | |
    AftFirst          | |   }
     Blocks           | |   } AftBlocks of an inner loop of L
    AftLast     ------/ |   }
        |               |
       ...              |
        |               |
    AftFirst            |   }
     Blocks             |   } AftBlocks of L
    AftLast     --------/   }
        |

    There are (theoretically) any number of blocks in ForeBlocks, SubLoopBlocks
    and AftBlocks, provided that there is one edge from Fores to SubLoops,
    one edge from SubLoops to Afts and a single outer loop exit (from Afts).
    In practice we currently limit Aft blocks to a single block, and limit
    things further in the profitability checks of the unroll and jam pass.

    Because of the way we rearrange basic blocks, we also require that
    the Fore blocks of L on all unrolled iterations are safe to move before the
    blocks of the direct child of L of all iterations. So we require that the
    phi node looping operands of ForeHeader can be moved to at least the end of
    ForeEnd, so that we can arrange cloned Fore Blocks before the subloop and
    match up the phis correctly.

    i.e. The old order of blocks used to be
           (F1)1 (F2)1 J1_1 J1_2 (A2)1 (A1)1 (F1)2 (F2)2 J2_1 J2_2 (A2)2 (A1)2.
         It needs to be safe to transform this to
           (F1)1 (F1)2 (F2)1 (F2)2 J1_1 J1_2 J2_1 J2_2 (A2)1 (A2)2 (A1)1 (A1)2.

    There are then a number of checks along the lines of no calls, no
    exceptions, inner loop IV is consistent, etc. Note that for loops requiring
    runtime unrolling, UnrollRuntimeLoopRemainder can also fail in
    UnrollAndJamLoop if the trip count cannot be easily calculated.
  */

  // Split blocks into Fore/SubLoop/Aft based on dominators
  Loop *JamLoop = getInnerMostLoop(L);
  BasicBlockSet SubLoopBlocks;
  DenseMap<Loop *, BasicBlockSet> ForeBlocksMap;
  DenseMap<Loop *, BasicBlockSet> AftBlocksMap;
  if (!partitionOuterLoopBlocks(*L, *JamLoop, SubLoopBlocks, ForeBlocksMap,
                                AftBlocksMap, DT)) {
    LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Incompatible loop layout\n");
    return false;
  }

  // Aft blocks may need to move instructions to fore blocks, which becomes more
  // difficult if there are multiple (potentially conditionally executed)
  // blocks. For now we just exclude loops with multiple aft blocks.
  if (AftBlocksMap[L].size() != 1) {
    LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Can't currently handle "
                         "multiple blocks after the loop\n");
    return false;
  }

  // Check inner loop backedge count is consistent on all iterations of the
  // outer loop
  if (any_of(L->getLoopsInPreorder(), [&SE](Loop *SubLoop) {
        return !hasIterationCountInvariantInParent(SubLoop, SE);
      })) {
    LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Inner loop iteration count is "
                         "not consistent on each iteration\n");
    return false;
  }

  // Check the loop safety info for exceptions.
  SimpleLoopSafetyInfo LSI;
  LSI.computeLoopSafetyInfo(L);
  if (LSI.anyBlockMayThrow()) {
    LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Something may throw\n");
    return false;
  }

  // We've ruled out the easy stuff and now need to check that there are no
  // interdependencies which may prevent us from moving the:
  //  ForeBlocks before Subloop and AftBlocks.
  //  Subloop before AftBlocks.
  //  ForeBlock phi operands before the subloop

  // Make sure we can move all instructions we need to before the subloop
  BasicBlock *Header = L->getHeader();
  BasicBlock *Latch = L->getLoopLatch();
  BasicBlockSet AftBlocks = AftBlocksMap[L];
  Loop *SubLoop = L->getSubLoops()[0];
  if (!processHeaderPhiOperands(
          Header, Latch, AftBlocks, [&AftBlocks, &SubLoop](Instruction *I) {
            if (SubLoop->contains(I->getParent()))
              return false;
            if (AftBlocks.count(I->getParent())) {
              // If we hit a phi node in afts we know we are done (probably
              // LCSSA)
              if (isa<PHINode>(I))
                return false;
              // Can't move instructions with side effects or memory
              // reads/writes
              if (I->mayHaveSideEffects() || I->mayReadOrWriteMemory())
                return false;
            }
            // Keep going
            return true;
          })) {
    LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; can't move required "
                         "instructions after subloop to before it\n");
    return false;
  }

  // Check for memory dependencies which prohibit the unrolling we are doing.
  // Because of the way we are unrolling Fore/Sub/Aft blocks, we need to check
  // there are no dependencies between Fore-Sub, Fore-Aft, Sub-Aft and Sub-Sub.
  if (!checkDependencies(*L, SubLoopBlocks, ForeBlocksMap, AftBlocksMap, DI,
                         LI)) {
    LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; failed dependency check\n");
    return false;
  }

  return true;
}