1 //===-------- MIRSampleProfile.cpp: MIRSampleFDO (For FSAFDO) -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file provides the implementation of the MIRSampleProfile loader, mainly
10 // for flow sensitive SampleFDO.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "llvm/CodeGen/MIRSampleProfile.h"
15 #include "llvm/ADT/DenseMap.h"
16 #include "llvm/ADT/DenseSet.h"
17 #include "llvm/Analysis/BlockFrequencyInfoImpl.h"
18 #include "llvm/CodeGen/MachineBlockFrequencyInfo.h"
19 #include "llvm/CodeGen/MachineBranchProbabilityInfo.h"
20 #include "llvm/CodeGen/MachineDominators.h"
21 #include "llvm/CodeGen/MachineLoopInfo.h"
22 #include "llvm/CodeGen/MachineOptimizationRemarkEmitter.h"
23 #include "llvm/CodeGen/MachinePostDominators.h"
24 #include "llvm/CodeGen/Passes.h"
25 #include "llvm/IR/Function.h"
26 #include "llvm/InitializePasses.h"
27 #include "llvm/Support/CommandLine.h"
28 #include "llvm/Support/Debug.h"
29 #include "llvm/Support/raw_ostream.h"
30 #include "llvm/Transforms/Utils/SampleProfileLoaderBaseImpl.h"
31 #include "llvm/Transforms/Utils/SampleProfileLoaderBaseUtil.h"
32 
33 using namespace llvm;
34 using namespace sampleprof;
35 using namespace llvm::sampleprofutil;
36 using ProfileCount = Function::ProfileCount;
37 
38 #define DEBUG_TYPE "fs-profile-loader"
39 
40 static cl::opt<bool> ShowFSBranchProb(
41     "show-fs-branchprob", cl::Hidden, cl::init(false),
42     cl::desc("Print setting flow sensitive branch probabilities"));
43 static cl::opt<unsigned> FSProfileDebugProbDiffThreshold(
44     "fs-profile-debug-prob-diff-threshold", cl::init(10),
45     cl::desc("Only show debug message if the branch probility is greater than "
46              "this value (in percentage)."));
47 
48 static cl::opt<unsigned> FSProfileDebugBWThreshold(
49     "fs-profile-debug-bw-threshold", cl::init(10000),
50     cl::desc("Only show debug message if the source branch weight is greater "
51              " than this value."));
52 
53 static cl::opt<bool> ViewBFIBefore("fs-viewbfi-before", cl::Hidden,
54                                    cl::init(false),
55                                    cl::desc("View BFI before MIR loader"));
56 static cl::opt<bool> ViewBFIAfter("fs-viewbfi-after", cl::Hidden,
57                                   cl::init(false),
58                                   cl::desc("View BFI after MIR loader"));
59 
// Unique identifier of the legacy pass; its address is what identifies it.
char MIRProfileLoaderPass::ID = 0;

// Register the pass under DEBUG_TYPE ("fs-profile-loader") together with the
// machine analyses it requests in getAnalysisUsage().
INITIALIZE_PASS_BEGIN(MIRProfileLoaderPass, DEBUG_TYPE,
                      "Load MIR Sample Profile",
                      /* cfg = */ false, /* is_analysis = */ false)
INITIALIZE_PASS_DEPENDENCY(MachineBlockFrequencyInfo)
INITIALIZE_PASS_DEPENDENCY(MachineDominatorTree)
INITIALIZE_PASS_DEPENDENCY(MachinePostDominatorTree)
INITIALIZE_PASS_DEPENDENCY(MachineLoopInfo)
INITIALIZE_PASS_DEPENDENCY(MachineOptimizationRemarkEmitterPass)
INITIALIZE_PASS_END(MIRProfileLoaderPass, DEBUG_TYPE, "Load MIR Sample Profile",
                    /* cfg = */ false, /* is_analysis = */ false)

// Exported alias of the pass ID so other code can refer to this pass.
char &llvm::MIRProfileLoaderPassID = MIRProfileLoaderPass::ID;
74 
75 FunctionPass *llvm::createMIRProfileLoaderPass(std::string File,
76                                                std::string RemappingFile,
77                                                FSDiscriminatorPass P) {
78   return new MIRProfileLoaderPass(File, RemappingFile, P);
79 }
80 
81 namespace llvm {
82 
83 // Internal option used to control BFI display only after MBP pass.
84 // Defined in CodeGen/MachineBlockFrequencyInfo.cpp:
85 // -view-block-layout-with-bfi={none | fraction | integer | count}
86 extern cl::opt<GVDAGType> ViewBlockLayoutWithBFI;
87 
88 // Command line option to specify the name of the function for CFG dump
89 // Defined in Analysis/BlockFrequencyInfo.cpp:  -view-bfi-func-name=
90 extern cl::opt<std::string> ViewBlockFreqFuncName;
91 
92 namespace afdo_detail {
93 template <> struct IRTraits<MachineBasicBlock> {
94   using InstructionT = MachineInstr;
95   using BasicBlockT = MachineBasicBlock;
96   using FunctionT = MachineFunction;
97   using BlockFrequencyInfoT = MachineBlockFrequencyInfo;
98   using LoopT = MachineLoop;
99   using LoopInfoPtrT = MachineLoopInfo *;
100   using DominatorTreePtrT = MachineDominatorTree *;
101   using PostDominatorTreePtrT = MachinePostDominatorTree *;
102   using PostDominatorTreeT = MachinePostDominatorTree;
103   using OptRemarkEmitterT = MachineOptimizationRemarkEmitter;
104   using OptRemarkAnalysisT = MachineOptimizationRemarkAnalysis;
105   using PredRangeT = iterator_range<std::vector<MachineBasicBlock *>::iterator>;
106   using SuccRangeT = iterator_range<std::vector<MachineBasicBlock *>::iterator>;
107   static Function &getFunction(MachineFunction &F) { return F.getFunction(); }
108   static const MachineBasicBlock *getEntryBB(const MachineFunction *F) {
109     return GraphTraits<const MachineFunction *>::getEntryNode(F);
110   }
111   static PredRangeT getPredecessors(MachineBasicBlock *BB) {
112     return BB->predecessors();
113   }
114   static SuccRangeT getSuccessors(MachineBasicBlock *BB) {
115     return BB->successors();
116   }
117 };
118 } // namespace afdo_detail
119 
120 class MIRProfileLoader final
121     : public SampleProfileLoaderBaseImpl<MachineBasicBlock> {
122 public:
123   void setInitVals(MachineDominatorTree *MDT, MachinePostDominatorTree *MPDT,
124                    MachineLoopInfo *MLI, MachineBlockFrequencyInfo *MBFI,
125                    MachineOptimizationRemarkEmitter *MORE) {
126     DT = MDT;
127     PDT = MPDT;
128     LI = MLI;
129     BFI = MBFI;
130     ORE = MORE;
131   }
132   void setFSPass(FSDiscriminatorPass Pass) {
133     P = Pass;
134     LowBit = getFSPassBitBegin(P);
135     HighBit = getFSPassBitEnd(P);
136     assert(LowBit < HighBit && "HighBit needs to be greater than Lowbit");
137   }
138 
139   MIRProfileLoader(StringRef Name, StringRef RemapName)
140       : SampleProfileLoaderBaseImpl(std::string(Name), std::string(RemapName)) {
141   }
142 
143   void setBranchProbs(MachineFunction &F);
144   bool runOnFunction(MachineFunction &F);
145   bool doInitialization(Module &M);
146   bool isValid() const { return ProfileIsValid; }
147 
148 protected:
149   friend class SampleCoverageTracker;
150 
151   /// Hold the information of the basic block frequency.
152   MachineBlockFrequencyInfo *BFI;
153 
154   /// PassNum is the sequence number this pass is called, start from 1.
155   FSDiscriminatorPass P;
156 
157   // LowBit in the FS discriminator used by this instance. Note the number is
158   // 0-based. Base discrimnator use bit 0 to bit 11.
159   unsigned LowBit;
160   // HighwBit in the FS discriminator used by this instance. Note the number
161   // is 0-based.
162   unsigned HighBit;
163 
164   bool ProfileIsValid = true;
165 };
166 
167 template <>
168 void SampleProfileLoaderBaseImpl<
169     MachineBasicBlock>::computeDominanceAndLoopInfo(MachineFunction &F) {}
170 
171 void MIRProfileLoader::setBranchProbs(MachineFunction &F) {
172   LLVM_DEBUG(dbgs() << "\nPropagation complete. Setting branch probs\n");
173   for (auto &BI : F) {
174     MachineBasicBlock *BB = &BI;
175     if (BB->succ_size() < 2)
176       continue;
177     const MachineBasicBlock *EC = EquivalenceClass[BB];
178     uint64_t BBWeight = BlockWeights[EC];
179     uint64_t SumEdgeWeight = 0;
180     for (MachineBasicBlock *Succ : BB->successors()) {
181       Edge E = std::make_pair(BB, Succ);
182       SumEdgeWeight += EdgeWeights[E];
183     }
184 
185     if (BBWeight != SumEdgeWeight) {
186       LLVM_DEBUG(dbgs() << "BBweight is not equal to SumEdgeWeight: BBWWeight="
187                         << BBWeight << " SumEdgeWeight= " << SumEdgeWeight
188                         << "\n");
189       BBWeight = SumEdgeWeight;
190     }
191     if (BBWeight == 0) {
192       LLVM_DEBUG(dbgs() << "SKIPPED. All branch weights are zero.\n");
193       continue;
194     }
195 
196 #ifndef NDEBUG
197     uint64_t BBWeightOrig = BBWeight;
198 #endif
199     uint32_t MaxWeight = std::numeric_limits<uint32_t>::max();
200     uint32_t Factor = 1;
201     if (BBWeight > MaxWeight) {
202       Factor = BBWeight / MaxWeight + 1;
203       BBWeight /= Factor;
204       LLVM_DEBUG(dbgs() << "Scaling weights by " << Factor << "\n");
205     }
206 
207     for (MachineBasicBlock::succ_iterator SI = BB->succ_begin(),
208                                           SE = BB->succ_end();
209          SI != SE; ++SI) {
210       MachineBasicBlock *Succ = *SI;
211       Edge E = std::make_pair(BB, Succ);
212       uint64_t EdgeWeight = EdgeWeights[E];
213       EdgeWeight /= Factor;
214 
215       assert(BBWeight >= EdgeWeight &&
216              "BBweight is larger than EdgeWeight -- should not happen.\n");
217 
218       BranchProbability OldProb = BFI->getMBPI()->getEdgeProbability(BB, SI);
219       BranchProbability NewProb(EdgeWeight, BBWeight);
220       if (OldProb == NewProb)
221         continue;
222       BB->setSuccProbability(SI, NewProb);
223 #ifndef NDEBUG
224       if (!ShowFSBranchProb)
225         continue;
226       bool Show = false;
227       BranchProbability Diff;
228       if (OldProb > NewProb)
229         Diff = OldProb - NewProb;
230       else
231         Diff = NewProb - OldProb;
232       Show = (Diff >= BranchProbability(FSProfileDebugProbDiffThreshold, 100));
233       Show &= (BBWeightOrig >= FSProfileDebugBWThreshold);
234 
235       auto DIL = BB->findBranchDebugLoc();
236       auto SuccDIL = Succ->findBranchDebugLoc();
237       if (Show) {
238         dbgs() << "Set branch fs prob: MBB (" << BB->getNumber() << " -> "
239                << Succ->getNumber() << "): ";
240         if (DIL)
241           dbgs() << DIL->getFilename() << ":" << DIL->getLine() << ":"
242                  << DIL->getColumn();
243         if (SuccDIL)
244           dbgs() << "-->" << SuccDIL->getFilename() << ":" << SuccDIL->getLine()
245                  << ":" << SuccDIL->getColumn();
246         dbgs() << " W=" << BBWeightOrig << "  " << OldProb << " --> " << NewProb
247                << "\n";
248       }
249 #endif
250     }
251   }
252 }
253 
254 bool MIRProfileLoader::doInitialization(Module &M) {
255   auto &Ctx = M.getContext();
256 
257   auto ReaderOrErr = sampleprof::SampleProfileReader::create(Filename, Ctx, P,
258                                                              RemappingFilename);
259   if (std::error_code EC = ReaderOrErr.getError()) {
260     std::string Msg = "Could not open profile: " + EC.message();
261     Ctx.diagnose(DiagnosticInfoSampleProfile(Filename, Msg));
262     return false;
263   }
264 
265   Reader = std::move(ReaderOrErr.get());
266   Reader->setModule(&M);
267   ProfileIsValid = (Reader->read() == sampleprof_error::success);
268   Reader->getSummary();
269 
270   return true;
271 }
272 
273 bool MIRProfileLoader::runOnFunction(MachineFunction &MF) {
274   Function &Func = MF.getFunction();
275   clearFunctionData(false);
276   Samples = Reader->getSamplesFor(Func);
277   if (!Samples || Samples->empty())
278     return false;
279 
280   if (getFunctionLoc(MF) == 0)
281     return false;
282 
283   DenseSet<GlobalValue::GUID> InlinedGUIDs;
284   bool Changed = computeAndPropagateWeights(MF, InlinedGUIDs);
285 
286   // Set the new BPI, BFI.
287   setBranchProbs(MF);
288 
289   return Changed;
290 }
291 
292 } // namespace llvm
293 
294 MIRProfileLoaderPass::MIRProfileLoaderPass(std::string FileName,
295                                            std::string RemappingFileName,
296                                            FSDiscriminatorPass P)
297     : MachineFunctionPass(ID), ProfileFileName(FileName), P(P),
298       MIRSampleLoader(
299           std::make_unique<MIRProfileLoader>(FileName, RemappingFileName)) {
300   LowBit = getFSPassBitBegin(P);
301   HighBit = getFSPassBitEnd(P);
302   assert(LowBit < HighBit && "HighBit needs to be greater than Lowbit");
303 }
304 
305 bool MIRProfileLoaderPass::runOnMachineFunction(MachineFunction &MF) {
306   if (!MIRSampleLoader->isValid())
307     return false;
308 
309   LLVM_DEBUG(dbgs() << "MIRProfileLoader pass working on Func: "
310                     << MF.getFunction().getName() << "\n");
311   MBFI = &getAnalysis<MachineBlockFrequencyInfo>();
312   MIRSampleLoader->setInitVals(
313       &getAnalysis<MachineDominatorTree>(),
314       &getAnalysis<MachinePostDominatorTree>(), &getAnalysis<MachineLoopInfo>(),
315       MBFI, &getAnalysis<MachineOptimizationRemarkEmitterPass>().getORE());
316 
317   MF.RenumberBlocks();
318   if (ViewBFIBefore && ViewBlockLayoutWithBFI != GVDT_None &&
319       (ViewBlockFreqFuncName.empty() ||
320        MF.getFunction().getName().equals(ViewBlockFreqFuncName))) {
321     MBFI->view("MIR_Prof_loader_b." + MF.getName(), false);
322   }
323 
324   bool Changed = MIRSampleLoader->runOnFunction(MF);
325   if (Changed)
326     MBFI->calculate(MF, *MBFI->getMBPI(), *&getAnalysis<MachineLoopInfo>());
327 
328   if (ViewBFIAfter && ViewBlockLayoutWithBFI != GVDT_None &&
329       (ViewBlockFreqFuncName.empty() ||
330        MF.getFunction().getName().equals(ViewBlockFreqFuncName))) {
331     MBFI->view("MIR_prof_loader_a." + MF.getName(), false);
332   }
333 
334   return Changed;
335 }
336 
337 bool MIRProfileLoaderPass::doInitialization(Module &M) {
338   LLVM_DEBUG(dbgs() << "MIRProfileLoader pass working on Module " << M.getName()
339                     << "\n");
340 
341   MIRSampleLoader->setFSPass(P);
342   return MIRSampleLoader->doInitialization(M);
343 }
344 
345 void MIRProfileLoaderPass::getAnalysisUsage(AnalysisUsage &AU) const {
346   AU.setPreservesAll();
347   AU.addRequired<MachineBlockFrequencyInfo>();
348   AU.addRequired<MachineDominatorTree>();
349   AU.addRequired<MachinePostDominatorTree>();
350   AU.addRequiredTransitive<MachineLoopInfo>();
351   AU.addRequired<MachineOptimizationRemarkEmitterPass>();
352   MachineFunctionPass::getAnalysisUsage(AU);
353 }
354