//===-------- LoopDataPrefetch.cpp - Loop Data Prefetching Pass -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a Loop Data Prefetching Pass.
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Scalar/LoopDataPrefetch.h"
#include "llvm/InitializePasses.h"

#define DEBUG_TYPE "loop-data-prefetch"
#include "llvm/ADT/DepthFirstIterator.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/CodeMetrics.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
#include "llvm/Transforms/Utils/ValueMapper.h"
using namespace llvm;

// By default, write prefetching is disabled; targets can enable it through
// TTI::enableWritePrefetching() or with this flag.
static cl::opt<bool>
PrefetchWrites("loop-prefetch-writes", cl::Hidden, cl::init(false),
               cl::desc("Prefetch write addresses"));

static cl::opt<unsigned>
    PrefetchDistance("prefetch-distance",
                     cl::desc("Number of instructions to prefetch ahead"),
                     cl::Hidden);

static cl::opt<unsigned>
    MinPrefetchStride("min-prefetch-stride",
                      cl::desc("Min stride to add prefetches"), cl::Hidden);

static cl::opt<unsigned> MaxPrefetchIterationsAhead(
    "max-prefetch-iters-ahead",
    cl::desc("Max number of iterations to prefetch ahead"), cl::Hidden);

STATISTIC(NumPrefetches, "Number of prefetches inserted");

namespace {

/// Loop prefetch implementation class.
class LoopDataPrefetch {
public:
  LoopDataPrefetch(AssumptionCache *AC, DominatorTree *DT, LoopInfo *LI,
                   ScalarEvolution *SE, const TargetTransformInfo *TTI,
                   OptimizationRemarkEmitter *ORE)
      : AC(AC), DT(DT), LI(LI), SE(SE), TTI(TTI), ORE(ORE) {}

  bool run();

private:
  bool runOnLoop(Loop *L);

  /// Check if the stride of the accesses is large enough to
  /// warrant a prefetch.
  bool isStrideLargeEnough(const SCEVAddRecExpr *AR, unsigned TargetMinStride);

  unsigned getMinPrefetchStride(unsigned NumMemAccesses,
                                unsigned NumStridedMemAccesses,
                                unsigned NumPrefetches,
                                bool HasCall) {
    if (MinPrefetchStride.getNumOccurrences() > 0)
      return MinPrefetchStride;
    return TTI->getMinPrefetchStride(NumMemAccesses, NumStridedMemAccesses,
                                     NumPrefetches, HasCall);
  }

  unsigned getPrefetchDistance() {
    if (PrefetchDistance.getNumOccurrences() > 0)
      return PrefetchDistance;
    return TTI->getPrefetchDistance();
  }

  unsigned getMaxPrefetchIterationsAhead() {
    if (MaxPrefetchIterationsAhead.getNumOccurrences() > 0)
      return MaxPrefetchIterationsAhead;
    return TTI->getMaxPrefetchIterationsAhead();
  }

  bool doPrefetchWrites() {
    if (PrefetchWrites.getNumOccurrences() > 0)
      return PrefetchWrites;
    return TTI->enableWritePrefetching();
  }

  AssumptionCache *AC;
  DominatorTree *DT;
  LoopInfo *LI;
  ScalarEvolution *SE;
  const TargetTransformInfo *TTI;
  OptimizationRemarkEmitter *ORE;
};

/// Legacy class for inserting loop data prefetches.
class LoopDataPrefetchLegacyPass : public FunctionPass {
public:
  static char ID; // Pass ID, replacement for typeid
  LoopDataPrefetchLegacyPass() : FunctionPass(ID) {
    initializeLoopDataPrefetchLegacyPassPass(*PassRegistry::getPassRegistry());
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<AssumptionCacheTracker>();
    AU.addRequired<DominatorTreeWrapperPass>();
    AU.addPreserved<DominatorTreeWrapperPass>();
    AU.addRequired<LoopInfoWrapperPass>();
    AU.addPreserved<LoopInfoWrapperPass>();
    AU.addRequired<OptimizationRemarkEmitterWrapperPass>();
    AU.addRequired<ScalarEvolutionWrapperPass>();
    AU.addPreserved<ScalarEvolutionWrapperPass>();
    AU.addRequired<TargetTransformInfoWrapperPass>();
  }

  bool runOnFunction(Function &F) override;
};
} // end anonymous namespace

char LoopDataPrefetchLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(LoopDataPrefetchLegacyPass, "loop-data-prefetch",
                      "Loop Data Prefetch", false, false)
INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass)
INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
INITIALIZE_PASS_END(LoopDataPrefetchLegacyPass, "loop-data-prefetch",
                    "Loop Data Prefetch", false, false)

FunctionPass *llvm::createLoopDataPrefetchPass() {
  return new LoopDataPrefetchLegacyPass();
}

bool LoopDataPrefetch::isStrideLargeEnough(const SCEVAddRecExpr *AR,
                                           unsigned TargetMinStride) {
  // If the target is happy with any stride, there is nothing to check.
  if (TargetMinStride <= 1)
    return true;

  const auto *ConstStride = dyn_cast<SCEVConstant>(AR->getStepRecurrence(*SE));
  // If MinStride is set, don't prefetch unless we can ensure that stride is
  // larger.
  if (!ConstStride)
    return false;

  unsigned AbsStride = std::abs(ConstStride->getAPInt().getSExtValue());
  return TargetMinStride <= AbsStride;
}

PreservedAnalyses LoopDataPrefetchPass::run(Function &F,
                                            FunctionAnalysisManager &AM) {
  DominatorTree *DT = &AM.getResult<DominatorTreeAnalysis>(F);
  LoopInfo *LI = &AM.getResult<LoopAnalysis>(F);
  ScalarEvolution *SE = &AM.getResult<ScalarEvolutionAnalysis>(F);
  AssumptionCache *AC = &AM.getResult<AssumptionAnalysis>(F);
  OptimizationRemarkEmitter *ORE =
      &AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
  const TargetTransformInfo *TTI = &AM.getResult<TargetIRAnalysis>(F);

  LoopDataPrefetch LDP(AC, DT, LI, SE, TTI, ORE);
  bool Changed = LDP.run();

  if (Changed) {
    PreservedAnalyses PA;
    PA.preserve<DominatorTreeAnalysis>();
    PA.preserve<LoopAnalysis>();
    return PA;
  }

  return PreservedAnalyses::all();
}

bool LoopDataPrefetchLegacyPass::runOnFunction(Function &F) {
  if (skipFunction(F))
    return false;

  DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
  LoopInfo *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
  ScalarEvolution *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
  AssumptionCache *AC =
      &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
  OptimizationRemarkEmitter *ORE =
      &getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE();
  const TargetTransformInfo *TTI =
      &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);

  LoopDataPrefetch LDP(AC, DT, LI, SE, TTI, ORE);
  return LDP.run();
}

bool LoopDataPrefetch::run() {
  // If PrefetchDistance is not set, don't run the pass.  This gives an
  // opportunity for targets to run this pass for selected subtargets only
  // (whose TTI sets PrefetchDistance).
  if (getPrefetchDistance() == 0)
    return false;
  assert(TTI->getCacheLineSize() && "Cache line size is not set for target");

  bool MadeChange = false;

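  // Walk the loop nest of the function depth-first; runOnLoop only transforms
  // innermost loops, so outer loops are simply traversed.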
  for (Loop *I : *LI)
    for (auto L = df_begin(I), LE = df_end(I); L != LE; ++L)
      MadeChange |= runOnLoop(*L);

  return MadeChange;
}

/// A record for a potential prefetch made during the initial scan of the
/// loop. This is used to let a single prefetch target multiple memory accesses.
struct Prefetch {
  /// The address formula for this prefetch as returned by ScalarEvolution.
  const SCEVAddRecExpr *LSCEVAddRec;
  /// The point of insertion for the prefetch instruction.
  Instruction *InsertPt;
  /// True if targeting a write memory access.
  bool Writes;
  /// The (first seen) prefetched instruction.
  Instruction *MemI;

  /// Constructor to create a new Prefetch for \p I.
  Prefetch(const SCEVAddRecExpr *L, Instruction *I)
      : LSCEVAddRec(L), InsertPt(nullptr), Writes(false), MemI(nullptr) {
    addInstruction(I);
  }

  /// Add the instruction \p I to this prefetch. If it is not the first
  /// instruction, 'InsertPt' and 'Writes' will be updated as required.
  /// \param PtrDiff The known constant address difference of \p I relative to
  /// the first added instruction.
  void addInstruction(Instruction *I, DominatorTree *DT = nullptr,
                      int64_t PtrDiff = 0) {
    if (!InsertPt) {
      MemI = I;
      InsertPt = I;
      Writes = isa<StoreInst>(I);
    } else {
      BasicBlock *PrefBB = InsertPt->getParent();
      BasicBlock *InsBB = I->getParent();
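      // If the new access lives in a different block, hoist the insertion
      // point to the nearest common dominator so one prefetch covers both.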
      if (PrefBB != InsBB) {
        BasicBlock *DomBB = DT->findNearestCommonDominator(PrefBB, InsBB);
        if (DomBB != PrefBB)
          InsertPt = DomBB->getTerminator();
      }

      if (isa<StoreInst>(I) && PtrDiff == 0)
        Writes = true;
    }
  }
};

bool LoopDataPrefetch::runOnLoop(Loop *L) {
  bool MadeChange = false;

  // Only prefetch in the inner-most loop
  if (!L->empty())
    return MadeChange;

  SmallPtrSet<const Value *, 32> EphValues;
  CodeMetrics::collectEphemeralValues(L, AC, EphValues);

  // Calculate the number of iterations ahead to prefetch
  CodeMetrics Metrics;
  bool HasCall = false;
  for (const auto BB : L->blocks()) {
    // If the loop already has prefetches, then assume that the user knows
    // what they are doing and don't add any more.
    for (auto &I : *BB) {
      if (isa<CallInst>(&I) || isa<InvokeInst>(&I)) {
        if (const Function *F = cast<CallBase>(I).getCalledFunction()) {
          if (F->getIntrinsicID() == Intrinsic::prefetch)
            return MadeChange;
          if (TTI->isLoweredToCall(F))
            HasCall = true;
        } else { // indirect call.
          HasCall = true;
        }
      }
    }
    Metrics.analyzeBasicBlock(BB, *TTI, EphValues);
  }
  unsigned LoopSize = Metrics.NumInsts;
  if (!LoopSize)
    LoopSize = 1;

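  // Convert the target's prefetch distance (in instructions) into a number of
  // loop iterations; e.g., a distance of 300 instructions over a loop body of
  // 40 instructions gives ItersAhead = 7.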
  unsigned ItersAhead = getPrefetchDistance() / LoopSize;
  if (!ItersAhead)
    ItersAhead = 1;

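  // Bail out if we would need to prefetch further ahead than the target
  // considers profitable.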
  if (ItersAhead > getMaxPrefetchIterationsAhead())
    return MadeChange;

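  // If the loop is known to run fewer iterations than we would prefetch ahead,
  // the prefetched addresses would never be accessed.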
  unsigned ConstantMaxTripCount = SE->getSmallConstantMaxTripCount(L);
  if (ConstantMaxTripCount && ConstantMaxTripCount < ItersAhead + 1)
    return MadeChange;

  unsigned NumMemAccesses = 0;
  unsigned NumStridedMemAccesses = 0;
  SmallVector<Prefetch, 16> Prefetches;
  for (const auto BB : L->blocks())
    for (auto &I : *BB) {
      Value *PtrValue;
      Instruction *MemI;

      if (LoadInst *LMemI = dyn_cast<LoadInst>(&I)) {
        MemI = LMemI;
        PtrValue = LMemI->getPointerOperand();
      } else if (StoreInst *SMemI = dyn_cast<StoreInst>(&I)) {
        if (!doPrefetchWrites()) continue;
        MemI = SMemI;
        PtrValue = SMemI->getPointerOperand();
      } else continue;

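      // Only consider pointers in the default (0) address space.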
      unsigned PtrAddrSpace = PtrValue->getType()->getPointerAddressSpace();
      if (PtrAddrSpace)
        continue;
      NumMemAccesses++;
      if (L->isLoopInvariant(PtrValue))
        continue;

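      // We can only compute the address of a future iteration if the pointer
      // evolves as an add recurrence in this loop.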
      const SCEV *LSCEV = SE->getSCEV(PtrValue);
      const SCEVAddRecExpr *LSCEVAddRec = dyn_cast<SCEVAddRecExpr>(LSCEV);
      if (!LSCEVAddRec)
        continue;
      NumStridedMemAccesses++;

      // We don't want to double prefetch individual cache lines. If this
      // access is known to be within one cache line of some other one that
      // has already been prefetched, then don't prefetch this one as well.
      bool DupPref = false;
      for (auto &Pref : Prefetches) {
        const SCEV *PtrDiff = SE->getMinusSCEV(LSCEVAddRec, Pref.LSCEVAddRec);
        if (const SCEVConstant *ConstPtrDiff =
            dyn_cast<SCEVConstant>(PtrDiff)) {
          int64_t PD = std::abs(ConstPtrDiff->getValue()->getSExtValue());
          if (PD < (int64_t) TTI->getCacheLineSize()) {
            Pref.addInstruction(MemI, DT, PD);
            DupPref = true;
            break;
          }
        }
      }
      if (!DupPref)
        Prefetches.push_back(Prefetch(LSCEVAddRec, MemI));
    }

  unsigned TargetMinStride =
    getMinPrefetchStride(NumMemAccesses, NumStridedMemAccesses,
                         Prefetches.size(), HasCall);

  LLVM_DEBUG(dbgs() << "Prefetching " << ItersAhead
             << " iterations ahead (loop size: " << LoopSize << ") in "
             << L->getHeader()->getParent()->getName() << ": " << *L);
  LLVM_DEBUG(dbgs() << "Loop has: "
             << NumMemAccesses << " memory accesses, "
             << NumStridedMemAccesses << " strided memory accesses, "
             << Prefetches.size() << " potential prefetch(es), "
             << "a minimum stride of " << TargetMinStride << ", "
             << (HasCall ? "calls" : "no calls") << ".\n");

  for (auto &P : Prefetches) {
    // Check if the stride of the accesses is large enough to warrant a
    // prefetch.
    if (!isStrideLargeEnough(P.LSCEVAddRec, TargetMinStride))
      continue;

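    // Compute the address this access will have ItersAhead iterations from
    // now: the current add recurrence plus ItersAhead times its step.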
    const SCEV *NextLSCEV = SE->getAddExpr(P.LSCEVAddRec, SE->getMulExpr(
      SE->getConstant(P.LSCEVAddRec->getType(), ItersAhead),
      P.LSCEVAddRec->getStepRecurrence(*SE)));
    if (!isSafeToExpand(NextLSCEV, *SE))
      continue;

    BasicBlock *BB = P.InsertPt->getParent();
    Type *I8Ptr = Type::getInt8PtrTy(BB->getContext(), 0/*PtrAddrSpace*/);
    SCEVExpander SCEVE(*SE, BB->getModule()->getDataLayout(), "prefaddr");
    Value *PrefPtrValue = SCEVE.expandCodeFor(NextLSCEV, I8Ptr, P.InsertPt);

    IRBuilder<> Builder(P.InsertPt);
    Module *M = BB->getParent()->getParent();
    Type *I32 = Type::getInt32Ty(BB->getContext());
    Function *PrefetchFunc = Intrinsic::getDeclaration(
        M, Intrinsic::prefetch, PrefPtrValue->getType());
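    // llvm.prefetch(ptr, rw, locality, cache-type): rw is 1 for stores,
    // locality 3 requests maximum temporal locality, and 1 selects the data
    // cache.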
    Builder.CreateCall(
        PrefetchFunc,
        {PrefPtrValue,
         ConstantInt::get(I32, P.Writes),
         ConstantInt::get(I32, 3), ConstantInt::get(I32, 1)});
    ++NumPrefetches;
    LLVM_DEBUG(dbgs() << "  Access: "
               << *P.MemI->getOperand(isa<LoadInst>(P.MemI) ? 0 : 1)
               << ", SCEV: " << *P.LSCEVAddRec << "\n");
    ORE->emit([&]() {
        return OptimizationRemark(DEBUG_TYPE, "Prefetched", P.MemI)
          << "prefetched memory access";
      });

    MadeChange = true;
  }

  return MadeChange;
}