1 //===- LoopExtractor.cpp - Extract each loop into a new function ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
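// A pass wrapper around CodeExtractor that extracts each top-level loop into
// its own new function. If a function is nothing more than a minimal wrapper
// around a single loop (the entry block branches straight to the loop header
// and every loop exit returns), that loop is left in place, because extracting
// it would only recreate the same shape; its sub-loops may still be extracted.
// This pass is most useful for debugging via bugpoint.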
13 //
14 //===----------------------------------------------------------------------===//
15 
16 #include "llvm/Transforms/IPO/LoopExtractor.h"
17 #include "llvm/ADT/Statistic.h"
18 #include "llvm/Analysis/AssumptionCache.h"
19 #include "llvm/Analysis/LoopInfo.h"
20 #include "llvm/IR/Dominators.h"
21 #include "llvm/IR/Instructions.h"
22 #include "llvm/IR/Module.h"
23 #include "llvm/IR/PassManager.h"
24 #include "llvm/InitializePasses.h"
25 #include "llvm/Pass.h"
26 #include "llvm/Transforms/IPO.h"
27 #include "llvm/Transforms/Utils.h"
28 #include "llvm/Transforms/Utils/CodeExtractor.h"
29 using namespace llvm;
30 
#define DEBUG_TYPE "loop-extract"

// Incremented once per loop that CodeExtractor successfully outlines into a new
// function (see LoopExtractor::extractLoop), for both pass-manager entry points.
STATISTIC(NumExtracted, "Number of loops extracted");
34 
35 namespace {
36 struct LoopExtractorLegacyPass : public ModulePass {
37   static char ID; // Pass identification, replacement for typeid
38 
39   unsigned NumLoops;
40 
41   explicit LoopExtractorLegacyPass(unsigned NumLoops = ~0)
42       : ModulePass(ID), NumLoops(NumLoops) {
43     initializeLoopExtractorLegacyPassPass(*PassRegistry::getPassRegistry());
44   }
45 
46   bool runOnModule(Module &M) override;
47 
48   void getAnalysisUsage(AnalysisUsage &AU) const override {
49     AU.addRequiredID(BreakCriticalEdgesID);
50     AU.addRequired<DominatorTreeWrapperPass>();
51     AU.addRequired<LoopInfoWrapperPass>();
52     AU.addPreserved<LoopInfoWrapperPass>();
53     AU.addRequiredID(LoopSimplifyID);
54     AU.addUsedIfAvailable<AssumptionCacheTracker>();
55   }
56 };
57 
58 struct LoopExtractor {
59   explicit LoopExtractor(
60       unsigned NumLoops,
61       function_ref<DominatorTree &(Function &)> LookupDomTree,
62       function_ref<LoopInfo &(Function &)> LookupLoopInfo,
63       function_ref<AssumptionCache *(Function &)> LookupAssumptionCache)
64       : NumLoops(NumLoops), LookupDomTree(LookupDomTree),
65         LookupLoopInfo(LookupLoopInfo),
66         LookupAssumptionCache(LookupAssumptionCache) {}
67   bool runOnModule(Module &M);
68 
69 private:
70   // The number of natural loops to extract from the program into functions.
71   unsigned NumLoops;
72 
73   function_ref<DominatorTree &(Function &)> LookupDomTree;
74   function_ref<LoopInfo &(Function &)> LookupLoopInfo;
75   function_ref<AssumptionCache *(Function &)> LookupAssumptionCache;
76 
77   bool runOnFunction(Function &F);
78 
79   bool extractLoops(Loop::iterator From, Loop::iterator To, LoopInfo &LI,
80                     DominatorTree &DT);
81   bool extractLoop(Loop *L, LoopInfo &LI, DominatorTree &DT);
82 };
83 } // namespace
84 
// Registers the legacy "loop-extract" pass. Its declared dependencies match the
// required passes listed in LoopExtractorLegacyPass::getAnalysisUsage.
// AssumptionCacheTracker is intentionally not listed, because it is only
// "used if available".
char LoopExtractorLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(LoopExtractorLegacyPass, "loop-extract",
                      "Extract loops into new functions", false, false)
INITIALIZE_PASS_DEPENDENCY(BreakCriticalEdges)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
INITIALIZE_PASS_END(LoopExtractorLegacyPass, "loop-extract",
                    "Extract loops into new functions", false, false)
94 
95 namespace {
96   /// SingleLoopExtractor - For bugpoint.
97 struct SingleLoopExtractor : public LoopExtractorLegacyPass {
98   static char ID; // Pass identification, replacement for typeid
99   SingleLoopExtractor() : LoopExtractorLegacyPass(1) {}
100 };
101 } // End anonymous namespace
102 
103 char SingleLoopExtractor::ID = 0;
104 INITIALIZE_PASS(SingleLoopExtractor, "loop-extract-single",
105                 "Extract at most one loop into a new function", false, false)
106 
107 // createLoopExtractorPass - This pass extracts all natural loops from the
108 // program into a function if it can.
109 //
110 Pass *llvm::createLoopExtractorPass() { return new LoopExtractorLegacyPass(); }
111 
112 bool LoopExtractorLegacyPass::runOnModule(Module &M) {
113   if (skipModule(M))
114     return false;
115 
116   bool Changed = false;
117   auto LookupDomTree = [this](Function &F) -> DominatorTree & {
118     return this->getAnalysis<DominatorTreeWrapperPass>(F).getDomTree();
119   };
120   auto LookupLoopInfo = [this, &Changed](Function &F) -> LoopInfo & {
121     return this->getAnalysis<LoopInfoWrapperPass>(F, &Changed).getLoopInfo();
122   };
123   auto LookupACT = [this](Function &F) -> AssumptionCache * {
124     if (auto *ACT = this->getAnalysisIfAvailable<AssumptionCacheTracker>())
125       return ACT->lookupAssumptionCache(F);
126     return nullptr;
127   };
128   return LoopExtractor(NumLoops, LookupDomTree, LookupLoopInfo, LookupACT)
129              .runOnModule(M) ||
130          Changed;
131 }
132 
133 bool LoopExtractor::runOnModule(Module &M) {
134   if (M.empty())
135     return false;
136 
137   if (!NumLoops)
138     return false;
139 
140   bool Changed = false;
141 
142   // The end of the function list may change (new functions will be added at the
143   // end), so we run from the first to the current last.
144   auto I = M.begin(), E = --M.end();
145   while (true) {
146     Function &F = *I;
147 
148     Changed |= runOnFunction(F);
149     if (!NumLoops)
150       break;
151 
152     // If this is the last function.
153     if (I == E)
154       break;
155 
156     ++I;
157   }
158   return Changed;
159 }
160 
161 bool LoopExtractor::runOnFunction(Function &F) {
162   // Do not modify `optnone` functions.
163   if (F.hasOptNone())
164     return false;
165 
166   if (F.empty())
167     return false;
168 
169   bool Changed = false;
170   LoopInfo &LI = LookupLoopInfo(F);
171 
172   // If there are no loops in the function.
173   if (LI.empty())
174     return Changed;
175 
176   DominatorTree &DT = LookupDomTree(F);
177 
178   // If there is more than one top-level loop in this function, extract all of
179   // the loops.
180   if (std::next(LI.begin()) != LI.end())
181     return Changed | extractLoops(LI.begin(), LI.end(), LI, DT);
182 
183   // Otherwise there is exactly one top-level loop.
184   Loop *TLL = *LI.begin();
185 
186   // If the loop is in LoopSimplify form, then extract it only if this function
187   // is more than a minimal wrapper around the loop.
188   if (TLL->isLoopSimplifyForm()) {
189     bool ShouldExtractLoop = false;
190 
191     // Extract the loop if the entry block doesn't branch to the loop header.
192     Instruction *EntryTI = F.getEntryBlock().getTerminator();
193     if (!isa<BranchInst>(EntryTI) ||
194         !cast<BranchInst>(EntryTI)->isUnconditional() ||
195         EntryTI->getSuccessor(0) != TLL->getHeader()) {
196       ShouldExtractLoop = true;
197     } else {
198       // Check to see if any exits from the loop are more than just return
199       // blocks.
200       SmallVector<BasicBlock *, 8> ExitBlocks;
201       TLL->getExitBlocks(ExitBlocks);
202       for (auto *ExitBlock : ExitBlocks)
203         if (!isa<ReturnInst>(ExitBlock->getTerminator())) {
204           ShouldExtractLoop = true;
205           break;
206         }
207     }
208 
209     if (ShouldExtractLoop)
210       return Changed | extractLoop(TLL, LI, DT);
211   }
212 
213   // Okay, this function is a minimal container around the specified loop.
214   // If we extract the loop, we will continue to just keep extracting it
215   // infinitely... so don't extract it. However, if the loop contains any
216   // sub-loops, extract them.
217   return Changed | extractLoops(TLL->begin(), TLL->end(), LI, DT);
218 }
219 
220 bool LoopExtractor::extractLoops(Loop::iterator From, Loop::iterator To,
221                                  LoopInfo &LI, DominatorTree &DT) {
222   bool Changed = false;
223   SmallVector<Loop *, 8> Loops;
224 
225   // Save the list of loops, as it may change.
226   Loops.assign(From, To);
227   for (Loop *L : Loops) {
228     // If LoopSimplify form is not available, stay out of trouble.
229     if (!L->isLoopSimplifyForm())
230       continue;
231 
232     Changed |= extractLoop(L, LI, DT);
233     if (!NumLoops)
234       break;
235   }
236   return Changed;
237 }
238 
239 bool LoopExtractor::extractLoop(Loop *L, LoopInfo &LI, DominatorTree &DT) {
240   assert(NumLoops != 0);
241   Function &Func = *L->getHeader()->getParent();
242   AssumptionCache *AC = LookupAssumptionCache(Func);
243   CodeExtractorAnalysisCache CEAC(Func);
244   CodeExtractor Extractor(DT, *L, false, nullptr, nullptr, AC);
245   if (Extractor.extractCodeRegion(CEAC)) {
246     LI.erase(L);
247     --NumLoops;
248     ++NumExtracted;
249     return true;
250   }
251   return false;
252 }
253 
254 // createSingleLoopExtractorPass - This pass extracts one natural loop from the
255 // program into a function if it can.  This is used by bugpoint.
256 //
257 Pass *llvm::createSingleLoopExtractorPass() {
258   return new SingleLoopExtractor();
259 }
260 
261 PreservedAnalyses LoopExtractorPass::run(Module &M, ModuleAnalysisManager &AM) {
262   auto &FAM = AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
263   auto LookupDomTree = [&FAM](Function &F) -> DominatorTree & {
264     return FAM.getResult<DominatorTreeAnalysis>(F);
265   };
266   auto LookupLoopInfo = [&FAM](Function &F) -> LoopInfo & {
267     return FAM.getResult<LoopAnalysis>(F);
268   };
269   auto LookupAssumptionCache = [&FAM](Function &F) -> AssumptionCache * {
270     return FAM.getCachedResult<AssumptionAnalysis>(F);
271   };
272   if (!LoopExtractor(NumLoops, LookupDomTree, LookupLoopInfo,
273                      LookupAssumptionCache)
274            .runOnModule(M))
275     return PreservedAnalyses::all();
276 
277   PreservedAnalyses PA;
278   PA.preserve<LoopAnalysis>();
279   return PA;
280 }
281 
282 void LoopExtractorPass::printPipeline(
283     raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) {
284   static_cast<PassInfoMixin<LoopExtractorPass> *>(this)->printPipeline(
285       OS, MapClassName2PassName);
286   OS << "<";
287   if (NumLoops == 1)
288     OS << "single";
289   OS << ">";
290 }
291