1 //===- BranchProbabilityInfo.h - Branch Probability Analysis ----*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This pass is used to evaluate branch probabilties. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #ifndef LLVM_ANALYSIS_BRANCHPROBABILITYINFO_H 14 #define LLVM_ANALYSIS_BRANCHPROBABILITYINFO_H 15 16 #include "llvm/ADT/DenseMap.h" 17 #include "llvm/ADT/DenseMapInfo.h" 18 #include "llvm/ADT/DenseSet.h" 19 #include "llvm/ADT/SmallPtrSet.h" 20 #include "llvm/IR/BasicBlock.h" 21 #include "llvm/IR/CFG.h" 22 #include "llvm/IR/PassManager.h" 23 #include "llvm/IR/ValueHandle.h" 24 #include "llvm/Pass.h" 25 #include "llvm/Support/BranchProbability.h" 26 #include "llvm/Support/Casting.h" 27 #include <algorithm> 28 #include <cassert> 29 #include <cstdint> 30 #include <utility> 31 32 namespace llvm { 33 34 class Function; 35 class LoopInfo; 36 class raw_ostream; 37 class PostDominatorTree; 38 class TargetLibraryInfo; 39 class Value; 40 41 /// Analysis providing branch probability information. 42 /// 43 /// This is a function analysis which provides information on the relative 44 /// probabilities of each "edge" in the function's CFG where such an edge is 45 /// defined by a pair (PredBlock and an index in the successors). The 46 /// probability of an edge from one block is always relative to the 47 /// probabilities of other edges from the block. The probabilites of all edges 48 /// from a block sum to exactly one (100%). 49 /// We use a pair (PredBlock and an index in the successors) to uniquely 50 /// identify an edge, since we can have multiple edges from Src to Dst. 51 /// As an example, we can have a switch which jumps to Dst with value 0 and 52 /// value 10. 53 class BranchProbabilityInfo { 54 public: 55 BranchProbabilityInfo() = default; 56 57 BranchProbabilityInfo(const Function &F, const LoopInfo &LI, 58 const TargetLibraryInfo *TLI = nullptr, 59 PostDominatorTree *PDT = nullptr) { 60 calculate(F, LI, TLI, PDT); 61 } 62 BranchProbabilityInfo(BranchProbabilityInfo && Arg)63 BranchProbabilityInfo(BranchProbabilityInfo &&Arg) 64 : Probs(std::move(Arg.Probs)), LastF(Arg.LastF), 65 PostDominatedByUnreachable(std::move(Arg.PostDominatedByUnreachable)), 66 PostDominatedByColdCall(std::move(Arg.PostDominatedByColdCall)) {} 67 68 BranchProbabilityInfo(const BranchProbabilityInfo &) = delete; 69 BranchProbabilityInfo &operator=(const BranchProbabilityInfo &) = delete; 70 71 BranchProbabilityInfo &operator=(BranchProbabilityInfo &&RHS) { 72 releaseMemory(); 73 Probs = std::move(RHS.Probs); 74 PostDominatedByColdCall = std::move(RHS.PostDominatedByColdCall); 75 PostDominatedByUnreachable = std::move(RHS.PostDominatedByUnreachable); 76 return *this; 77 } 78 79 bool invalidate(Function &, const PreservedAnalyses &PA, 80 FunctionAnalysisManager::Invalidator &); 81 82 void releaseMemory(); 83 84 void print(raw_ostream &OS) const; 85 86 /// Get an edge's probability, relative to other out-edges of the Src. 87 /// 88 /// This routine provides access to the fractional probability between zero 89 /// (0%) and one (100%) of this edge executing, relative to other edges 90 /// leaving the 'Src' block. The returned probability is never zero, and can 91 /// only be one if the source block has only one successor. 92 BranchProbability getEdgeProbability(const BasicBlock *Src, 93 unsigned IndexInSuccessors) const; 94 95 /// Get the probability of going from Src to Dst. 96 /// 97 /// It returns the sum of all probabilities for edges from Src to Dst. 98 BranchProbability getEdgeProbability(const BasicBlock *Src, 99 const BasicBlock *Dst) const; 100 101 BranchProbability getEdgeProbability(const BasicBlock *Src, 102 const_succ_iterator Dst) const; 103 104 /// Test if an edge is hot relative to other out-edges of the Src. 105 /// 106 /// Check whether this edge out of the source block is 'hot'. We define hot 107 /// as having a relative probability >= 80%. 108 bool isEdgeHot(const BasicBlock *Src, const BasicBlock *Dst) const; 109 110 /// Retrieve the hot successor of a block if one exists. 111 /// 112 /// Given a basic block, look through its successors and if one exists for 113 /// which \see isEdgeHot would return true, return that successor block. 114 const BasicBlock *getHotSucc(const BasicBlock *BB) const; 115 116 /// Print an edge's probability. 117 /// 118 /// Retrieves an edge's probability similarly to \see getEdgeProbability, but 119 /// then prints that probability to the provided stream. That stream is then 120 /// returned. 121 raw_ostream &printEdgeProbability(raw_ostream &OS, const BasicBlock *Src, 122 const BasicBlock *Dst) const; 123 124 protected: 125 /// Set the raw edge probability for the given edge. 126 /// 127 /// This allows a pass to explicitly set the edge probability for an edge. It 128 /// can be used when updating the CFG to update and preserve the branch 129 /// probability information. Read the implementation of how these edge 130 /// probabilities are calculated carefully before using! 131 void setEdgeProbability(const BasicBlock *Src, unsigned IndexInSuccessors, 132 BranchProbability Prob); 133 134 public: 135 /// Set the raw probabilities for all edges from the given block. 136 /// 137 /// This allows a pass to explicitly set edge probabilities for a block. It 138 /// can be used when updating the CFG to update the branch probability 139 /// information. 140 void setEdgeProbability(const BasicBlock *Src, 141 const SmallVectorImpl<BranchProbability> &Probs); 142 getBranchProbStackProtector(bool IsLikely)143 static BranchProbability getBranchProbStackProtector(bool IsLikely) { 144 static const BranchProbability LikelyProb((1u << 20) - 1, 1u << 20); 145 return IsLikely ? LikelyProb : LikelyProb.getCompl(); 146 } 147 148 void calculate(const Function &F, const LoopInfo &LI, 149 const TargetLibraryInfo *TLI, PostDominatorTree *PDT); 150 151 /// Forget analysis results for the given basic block. 152 void eraseBlock(const BasicBlock *BB); 153 154 // Use to track SCCs for handling irreducible loops. 155 using SccMap = DenseMap<const BasicBlock *, int>; 156 using SccHeaderMap = DenseMap<const BasicBlock *, bool>; 157 using SccHeaderMaps = std::vector<SccHeaderMap>; 158 struct SccInfo { 159 SccMap SccNums; 160 SccHeaderMaps SccHeaders; 161 }; 162 163 private: 164 // We need to store CallbackVH's in order to correctly handle basic block 165 // removal. 166 class BasicBlockCallbackVH final : public CallbackVH { 167 BranchProbabilityInfo *BPI; 168 deleted()169 void deleted() override { 170 assert(BPI != nullptr); 171 BPI->eraseBlock(cast<BasicBlock>(getValPtr())); 172 BPI->Handles.erase(*this); 173 } 174 175 public: 176 BasicBlockCallbackVH(const Value *V, BranchProbabilityInfo *BPI = nullptr) CallbackVH(const_cast<Value * > (V))177 : CallbackVH(const_cast<Value *>(V)), BPI(BPI) {} 178 }; 179 180 DenseSet<BasicBlockCallbackVH, DenseMapInfo<Value*>> Handles; 181 182 // Since we allow duplicate edges from one basic block to another, we use 183 // a pair (PredBlock and an index in the successors) to specify an edge. 184 using Edge = std::pair<const BasicBlock *, unsigned>; 185 186 // Default weight value. Used when we don't have information about the edge. 187 // TODO: DEFAULT_WEIGHT makes sense during static predication, when none of 188 // the successors have a weight yet. But it doesn't make sense when providing 189 // weight to an edge that may have siblings with non-zero weights. This can 190 // be handled various ways, but it's probably fine for an edge with unknown 191 // weight to just "inherit" the non-zero weight of an adjacent successor. 192 static const uint32_t DEFAULT_WEIGHT = 16; 193 194 DenseMap<Edge, BranchProbability> Probs; 195 196 /// Track the last function we run over for printing. 197 const Function *LastF = nullptr; 198 199 /// Track the set of blocks directly succeeded by a returning block. 200 SmallPtrSet<const BasicBlock *, 16> PostDominatedByUnreachable; 201 202 /// Track the set of blocks that always lead to a cold call. 203 SmallPtrSet<const BasicBlock *, 16> PostDominatedByColdCall; 204 205 void computePostDominatedByUnreachable(const Function &F, 206 PostDominatorTree *PDT); 207 void computePostDominatedByColdCall(const Function &F, 208 PostDominatorTree *PDT); 209 bool calcUnreachableHeuristics(const BasicBlock *BB); 210 bool calcMetadataWeights(const BasicBlock *BB); 211 bool calcColdCallHeuristics(const BasicBlock *BB); 212 bool calcPointerHeuristics(const BasicBlock *BB); 213 bool calcLoopBranchHeuristics(const BasicBlock *BB, const LoopInfo &LI, 214 SccInfo &SccI); 215 bool calcZeroHeuristics(const BasicBlock *BB, const TargetLibraryInfo *TLI); 216 bool calcFloatingPointHeuristics(const BasicBlock *BB); 217 bool calcInvokeHeuristics(const BasicBlock *BB); 218 }; 219 220 /// Analysis pass which computes \c BranchProbabilityInfo. 221 class BranchProbabilityAnalysis 222 : public AnalysisInfoMixin<BranchProbabilityAnalysis> { 223 friend AnalysisInfoMixin<BranchProbabilityAnalysis>; 224 225 static AnalysisKey Key; 226 227 public: 228 /// Provide the result type for this analysis pass. 229 using Result = BranchProbabilityInfo; 230 231 /// Run the analysis pass over a function and produce BPI. 232 BranchProbabilityInfo run(Function &F, FunctionAnalysisManager &AM); 233 }; 234 235 /// Printer pass for the \c BranchProbabilityAnalysis results. 236 class BranchProbabilityPrinterPass 237 : public PassInfoMixin<BranchProbabilityPrinterPass> { 238 raw_ostream &OS; 239 240 public: BranchProbabilityPrinterPass(raw_ostream & OS)241 explicit BranchProbabilityPrinterPass(raw_ostream &OS) : OS(OS) {} 242 243 PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM); 244 }; 245 246 /// Legacy analysis pass which computes \c BranchProbabilityInfo. 247 class BranchProbabilityInfoWrapperPass : public FunctionPass { 248 BranchProbabilityInfo BPI; 249 250 public: 251 static char ID; 252 253 BranchProbabilityInfoWrapperPass(); 254 getBPI()255 BranchProbabilityInfo &getBPI() { return BPI; } getBPI()256 const BranchProbabilityInfo &getBPI() const { return BPI; } 257 258 void getAnalysisUsage(AnalysisUsage &AU) const override; 259 bool runOnFunction(Function &F) override; 260 void releaseMemory() override; 261 void print(raw_ostream &OS, const Module *M = nullptr) const override; 262 }; 263 264 } // end namespace llvm 265 266 #endif // LLVM_ANALYSIS_BRANCHPROBABILITYINFO_H 267