1 //===-------- MIRSampleProfile.cpp: MIRSampleFDO (For FSAFDO) -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file provides the implementation of the MIRSampleProfile loader, mainly
10 // for flow sensitive SampleFDO.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "llvm/CodeGen/MIRSampleProfile.h"
15 #include "llvm/ADT/DenseMap.h"
16 #include "llvm/ADT/DenseSet.h"
17 #include "llvm/Analysis/BlockFrequencyInfoImpl.h"
18 #include "llvm/IR/Function.h"
19 #include "llvm/Support/CommandLine.h"
20 #include "llvm/Support/Debug.h"
21 #include "llvm/Support/raw_ostream.h"
22 #include "llvm/Transforms/Utils/SampleProfileLoaderBaseImpl.h"
23 #include "llvm/Transforms/Utils/SampleProfileLoaderBaseUtil.h"
24 
25 using namespace llvm;
26 using namespace sampleprof;
27 using namespace llvm::sampleprofutil;
28 using ProfileCount = Function::ProfileCount;
29 
30 #define DEBUG_TYPE "fs-profile-loader"
31 
32 static cl::opt<bool> ShowFSBranchProb(
33     "show-fs-branchprob", cl::Hidden, cl::init(false),
34     cl::desc("Print setting flow sensitive branch probabilities"));
35 static cl::opt<unsigned> FSProfileDebugProbDiffThreshold(
36     "fs-profile-debug-prob-diff-threshold", cl::init(10),
37     cl::desc("Only show debug message if the branch probility is greater than "
38              "this value (in percentage)."));
39 
40 static cl::opt<unsigned> FSProfileDebugBWThreshold(
41     "fs-profile-debug-bw-threshold", cl::init(10000),
42     cl::desc("Only show debug message if the source branch weight is greater "
43              " than this value."));
44 
45 static cl::opt<bool> ViewBFIBefore("fs-viewbfi-before", cl::Hidden,
46                                    cl::init(false),
47                                    cl::desc("View BFI before MIR loader"));
48 static cl::opt<bool> ViewBFIAfter("fs-viewbfi-after", cl::Hidden,
49                                   cl::init(false),
50                                   cl::desc("View BFI after MIR loader"));
51 
// Legacy pass-manager ID for MIRProfileLoaderPass (handed to
// MachineFunctionPass by the constructor).
char MIRProfileLoaderPass::ID = 0;

// Register the pass under DEBUG_TYPE ("fs-profile-loader") along with the
// machine analyses it depends on; getAnalysisUsage requests the same set.
INITIALIZE_PASS_BEGIN(MIRProfileLoaderPass, DEBUG_TYPE,
                      "Load MIR Sample Profile",
                      /* cfg = */ false, /* is_analysis = */ false)
INITIALIZE_PASS_DEPENDENCY(MachineBlockFrequencyInfo)
INITIALIZE_PASS_DEPENDENCY(MachineDominatorTree)
INITIALIZE_PASS_DEPENDENCY(MachinePostDominatorTree)
INITIALIZE_PASS_DEPENDENCY(MachineLoopInfo)
INITIALIZE_PASS_DEPENDENCY(MachineOptimizationRemarkEmitterPass)
INITIALIZE_PASS_END(MIRProfileLoaderPass, DEBUG_TYPE, "Load MIR Sample Profile",
                    /* cfg = */ false, /* is_analysis = */ false)

// Public alias of the pass ID; the name is declared outside this file in
// namespace llvm.
char &llvm::MIRProfileLoaderPassID = MIRProfileLoaderPass::ID;
66 
createMIRProfileLoaderPass(std::string File,std::string RemappingFile,FSDiscriminatorPass P)67 FunctionPass *llvm::createMIRProfileLoaderPass(std::string File,
68                                                std::string RemappingFile,
69                                                FSDiscriminatorPass P) {
70   return new MIRProfileLoaderPass(File, RemappingFile, P);
71 }
72 
73 namespace llvm {
74 
// Internal option used to control BFI display only after MBP pass.
// Defined in CodeGen/MachineBlockFrequencyInfo.cpp:
// -view-block-layout-with-bfi={none | fraction | integer | count}
// This file additionally requires it to be something other than GVDT_None
// before -fs-viewbfi-before / -fs-viewbfi-after will display a graph.
extern cl::opt<GVDAGType> ViewBlockLayoutWithBFI;

// Command line option to specify the name of the function for CFG dump
// Defined in Analysis/BlockFrequencyInfo.cpp:  -view-bfi-func-name=
// When non-empty, the -fs-viewbfi-* graphs are only shown for the function
// with this name.
extern cl::opt<std::string> ViewBlockFreqFuncName;
83 
84 namespace afdo_detail {
85 template <> struct IRTraits<MachineBasicBlock> {
86   using InstructionT = MachineInstr;
87   using BasicBlockT = MachineBasicBlock;
88   using FunctionT = MachineFunction;
89   using BlockFrequencyInfoT = MachineBlockFrequencyInfo;
90   using LoopT = MachineLoop;
91   using LoopInfoPtrT = MachineLoopInfo *;
92   using DominatorTreePtrT = MachineDominatorTree *;
93   using PostDominatorTreePtrT = MachinePostDominatorTree *;
94   using PostDominatorTreeT = MachinePostDominatorTree;
95   using OptRemarkEmitterT = MachineOptimizationRemarkEmitter;
96   using OptRemarkAnalysisT = MachineOptimizationRemarkAnalysis;
97   using PredRangeT = iterator_range<std::vector<MachineBasicBlock *>::iterator>;
98   using SuccRangeT = iterator_range<std::vector<MachineBasicBlock *>::iterator>;
getFunctionllvm::afdo_detail::IRTraits99   static Function &getFunction(MachineFunction &F) { return F.getFunction(); }
getEntryBBllvm::afdo_detail::IRTraits100   static const MachineBasicBlock *getEntryBB(const MachineFunction *F) {
101     return GraphTraits<const MachineFunction *>::getEntryNode(F);
102   }
getPredecessorsllvm::afdo_detail::IRTraits103   static PredRangeT getPredecessors(MachineBasicBlock *BB) {
104     return BB->predecessors();
105   }
getSuccessorsllvm::afdo_detail::IRTraits106   static SuccRangeT getSuccessors(MachineBasicBlock *BB) {
107     return BB->successors();
108   }
109 };
110 } // namespace afdo_detail
111 
112 class MIRProfileLoader final
113     : public SampleProfileLoaderBaseImpl<MachineBasicBlock> {
114 public:
setInitVals(MachineDominatorTree * MDT,MachinePostDominatorTree * MPDT,MachineLoopInfo * MLI,MachineBlockFrequencyInfo * MBFI,MachineOptimizationRemarkEmitter * MORE)115   void setInitVals(MachineDominatorTree *MDT, MachinePostDominatorTree *MPDT,
116                    MachineLoopInfo *MLI, MachineBlockFrequencyInfo *MBFI,
117                    MachineOptimizationRemarkEmitter *MORE) {
118     DT = MDT;
119     PDT = MPDT;
120     LI = MLI;
121     BFI = MBFI;
122     ORE = MORE;
123   }
setFSPass(FSDiscriminatorPass Pass)124   void setFSPass(FSDiscriminatorPass Pass) {
125     P = Pass;
126     LowBit = getFSPassBitBegin(P);
127     HighBit = getFSPassBitEnd(P);
128     assert(LowBit < HighBit && "HighBit needs to be greater than Lowbit");
129   }
130 
MIRProfileLoader(StringRef Name,StringRef RemapName)131   MIRProfileLoader(StringRef Name, StringRef RemapName)
132       : SampleProfileLoaderBaseImpl(std::string(Name), std::string(RemapName)) {
133   }
134 
135   void setBranchProbs(MachineFunction &F);
136   bool runOnFunction(MachineFunction &F);
137   bool doInitialization(Module &M);
isValid() const138   bool isValid() const { return ProfileIsValid; }
139 
140 protected:
141   friend class SampleCoverageTracker;
142 
143   /// Hold the information of the basic block frequency.
144   MachineBlockFrequencyInfo *BFI;
145 
146   /// PassNum is the sequence number this pass is called, start from 1.
147   FSDiscriminatorPass P;
148 
149   // LowBit in the FS discriminator used by this instance. Note the number is
150   // 0-based. Base discrimnator use bit 0 to bit 11.
151   unsigned LowBit;
152   // HighwBit in the FS discriminator used by this instance. Note the number
153   // is 0-based.
154   unsigned HighBit;
155 
156   bool ProfileIsValid = true;
157 };
158 
// Machine-IR specialization of the base-class hook that would compute
// dominator/loop analyses. It is intentionally empty: the MIR loader receives
// DT, PDT and LI from MIRProfileLoader::setInitVals, which
// MIRProfileLoaderPass::runOnMachineFunction calls with the analyses it
// obtained via getAnalysis<>().
template <>
void SampleProfileLoaderBaseImpl<
    MachineBasicBlock>::computeDominanceAndLoopInfo(MachineFunction &F) {}
162 
setBranchProbs(MachineFunction & F)163 void MIRProfileLoader::setBranchProbs(MachineFunction &F) {
164   LLVM_DEBUG(dbgs() << "\nPropagation complete. Setting branch probs\n");
165   for (auto &BI : F) {
166     MachineBasicBlock *BB = &BI;
167     if (BB->succ_size() < 2)
168       continue;
169     const MachineBasicBlock *EC = EquivalenceClass[BB];
170     uint64_t BBWeight = BlockWeights[EC];
171     uint64_t SumEdgeWeight = 0;
172     for (MachineBasicBlock::succ_iterator SI = BB->succ_begin(),
173                                           SE = BB->succ_end();
174          SI != SE; ++SI) {
175       MachineBasicBlock *Succ = *SI;
176       Edge E = std::make_pair(BB, Succ);
177       SumEdgeWeight += EdgeWeights[E];
178     }
179 
180     if (BBWeight != SumEdgeWeight) {
181       LLVM_DEBUG(dbgs() << "BBweight is not equal to SumEdgeWeight: BBWWeight="
182                         << BBWeight << " SumEdgeWeight= " << SumEdgeWeight
183                         << "\n");
184       BBWeight = SumEdgeWeight;
185     }
186     if (BBWeight == 0) {
187       LLVM_DEBUG(dbgs() << "SKIPPED. All branch weights are zero.\n");
188       continue;
189     }
190 
191 #ifndef NDEBUG
192     uint64_t BBWeightOrig = BBWeight;
193 #endif
194     uint32_t MaxWeight = std::numeric_limits<uint32_t>::max();
195     uint32_t Factor = 1;
196     if (BBWeight > MaxWeight) {
197       Factor = BBWeight / MaxWeight + 1;
198       BBWeight /= Factor;
199       LLVM_DEBUG(dbgs() << "Scaling weights by " << Factor << "\n");
200     }
201 
202     for (MachineBasicBlock::succ_iterator SI = BB->succ_begin(),
203                                           SE = BB->succ_end();
204          SI != SE; ++SI) {
205       MachineBasicBlock *Succ = *SI;
206       Edge E = std::make_pair(BB, Succ);
207       uint64_t EdgeWeight = EdgeWeights[E];
208       EdgeWeight /= Factor;
209 
210       assert(BBWeight >= EdgeWeight &&
211              "BBweight is larger than EdgeWeight -- should not happen.\n");
212 
213       BranchProbability OldProb = BFI->getMBPI()->getEdgeProbability(BB, SI);
214       BranchProbability NewProb(EdgeWeight, BBWeight);
215       if (OldProb == NewProb)
216         continue;
217       BB->setSuccProbability(SI, NewProb);
218 #ifndef NDEBUG
219       if (!ShowFSBranchProb)
220         continue;
221       bool Show = false;
222       BranchProbability Diff;
223       if (OldProb > NewProb)
224         Diff = OldProb - NewProb;
225       else
226         Diff = NewProb - OldProb;
227       Show = (Diff >= BranchProbability(FSProfileDebugProbDiffThreshold, 100));
228       Show &= (BBWeightOrig >= FSProfileDebugBWThreshold);
229 
230       auto DIL = BB->findBranchDebugLoc();
231       auto SuccDIL = Succ->findBranchDebugLoc();
232       if (Show) {
233         dbgs() << "Set branch fs prob: MBB (" << BB->getNumber() << " -> "
234                << Succ->getNumber() << "): ";
235         if (DIL)
236           dbgs() << DIL->getFilename() << ":" << DIL->getLine() << ":"
237                  << DIL->getColumn();
238         if (SuccDIL)
239           dbgs() << "-->" << SuccDIL->getFilename() << ":" << SuccDIL->getLine()
240                  << ":" << SuccDIL->getColumn();
241         dbgs() << " W=" << BBWeightOrig << "  " << OldProb << " --> " << NewProb
242                << "\n";
243       }
244 #endif
245     }
246   }
247 }
248 
doInitialization(Module & M)249 bool MIRProfileLoader::doInitialization(Module &M) {
250   auto &Ctx = M.getContext();
251 
252   auto ReaderOrErr = sampleprof::SampleProfileReader::create(Filename, Ctx, P,
253                                                              RemappingFilename);
254   if (std::error_code EC = ReaderOrErr.getError()) {
255     std::string Msg = "Could not open profile: " + EC.message();
256     Ctx.diagnose(DiagnosticInfoSampleProfile(Filename, Msg));
257     return false;
258   }
259 
260   Reader = std::move(ReaderOrErr.get());
261   Reader->setModule(&M);
262   ProfileIsValid = (Reader->read() == sampleprof_error::success);
263   Reader->getSummary();
264 
265   return true;
266 }
267 
runOnFunction(MachineFunction & MF)268 bool MIRProfileLoader::runOnFunction(MachineFunction &MF) {
269   Function &Func = MF.getFunction();
270   clearFunctionData(false);
271   Samples = Reader->getSamplesFor(Func);
272   if (!Samples || Samples->empty())
273     return false;
274 
275   if (getFunctionLoc(MF) == 0)
276     return false;
277 
278   DenseSet<GlobalValue::GUID> InlinedGUIDs;
279   bool Changed = computeAndPropagateWeights(MF, InlinedGUIDs);
280 
281   // Set the new BPI, BFI.
282   setBranchProbs(MF);
283 
284   return Changed;
285 }
286 
287 } // namespace llvm
288 
MIRProfileLoaderPass(std::string FileName,std::string RemappingFileName,FSDiscriminatorPass P)289 MIRProfileLoaderPass::MIRProfileLoaderPass(std::string FileName,
290                                            std::string RemappingFileName,
291                                            FSDiscriminatorPass P)
292     : MachineFunctionPass(ID), ProfileFileName(FileName), P(P),
293       MIRSampleLoader(
294           std::make_unique<MIRProfileLoader>(FileName, RemappingFileName)) {
295   LowBit = getFSPassBitBegin(P);
296   HighBit = getFSPassBitEnd(P);
297   assert(LowBit < HighBit && "HighBit needs to be greater than Lowbit");
298 }
299 
runOnMachineFunction(MachineFunction & MF)300 bool MIRProfileLoaderPass::runOnMachineFunction(MachineFunction &MF) {
301   if (!MIRSampleLoader->isValid())
302     return false;
303 
304   LLVM_DEBUG(dbgs() << "MIRProfileLoader pass working on Func: "
305                     << MF.getFunction().getName() << "\n");
306   MBFI = &getAnalysis<MachineBlockFrequencyInfo>();
307   MIRSampleLoader->setInitVals(
308       &getAnalysis<MachineDominatorTree>(),
309       &getAnalysis<MachinePostDominatorTree>(), &getAnalysis<MachineLoopInfo>(),
310       MBFI, &getAnalysis<MachineOptimizationRemarkEmitterPass>().getORE());
311 
312   MF.RenumberBlocks();
313   if (ViewBFIBefore && ViewBlockLayoutWithBFI != GVDT_None &&
314       (ViewBlockFreqFuncName.empty() ||
315        MF.getFunction().getName().equals(ViewBlockFreqFuncName))) {
316     MBFI->view("MIR_Prof_loader_b." + MF.getName(), false);
317   }
318 
319   bool Changed = MIRSampleLoader->runOnFunction(MF);
320 
321   if (ViewBFIAfter && ViewBlockLayoutWithBFI != GVDT_None &&
322       (ViewBlockFreqFuncName.empty() ||
323        MF.getFunction().getName().equals(ViewBlockFreqFuncName))) {
324     MBFI->view("MIR_prof_loader_a." + MF.getName(), false);
325   }
326 
327   return Changed;
328 }
329 
doInitialization(Module & M)330 bool MIRProfileLoaderPass::doInitialization(Module &M) {
331   LLVM_DEBUG(dbgs() << "MIRProfileLoader pass working on Module " << M.getName()
332                     << "\n");
333 
334   MIRSampleLoader->setFSPass(P);
335   return MIRSampleLoader->doInitialization(M);
336 }
337 
getAnalysisUsage(AnalysisUsage & AU) const338 void MIRProfileLoaderPass::getAnalysisUsage(AnalysisUsage &AU) const {
339   AU.setPreservesAll();
340   AU.addRequired<MachineBlockFrequencyInfo>();
341   AU.addRequired<MachineDominatorTree>();
342   AU.addRequired<MachinePostDominatorTree>();
343   AU.addRequiredTransitive<MachineLoopInfo>();
344   AU.addRequired<MachineOptimizationRemarkEmitterPass>();
345   MachineFunctionPass::getAnalysisUsage(AU);
346 }
347