1 //===- BranchProbability.h - Branch Probability Wrapper ---------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Definition of BranchProbability shared by IR and Machine Instructions. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #ifndef LLVM_SUPPORT_BRANCHPROBABILITY_H 14 #define LLVM_SUPPORT_BRANCHPROBABILITY_H 15 16 #include "llvm/Support/DataTypes.h" 17 #include <algorithm> 18 #include <cassert> 19 #include <climits> 20 #include <numeric> 21 22 namespace llvm { 23 24 class raw_ostream; 25 26 // This class represents Branch Probability as a non-negative fraction that is 27 // no greater than 1. It uses a fixed-point-like implementation, in which the 28 // denominator is always a constant value (here we use 1<<31 for maximum 29 // precision). 30 class BranchProbability { 31 // Numerator 32 uint32_t N; 33 34 // Denominator, which is a constant value. 35 static constexpr uint32_t D = 1u << 31; 36 static constexpr uint32_t UnknownN = UINT32_MAX; 37 38 // Construct a BranchProbability with only numerator assuming the denominator 39 // is 1<<31. For internal use only. 40 explicit BranchProbability(uint32_t n) : N(n) {} 41 42 public: 43 BranchProbability() : N(UnknownN) {} 44 BranchProbability(uint32_t Numerator, uint32_t Denominator); 45 46 bool isZero() const { return N == 0; } 47 bool isUnknown() const { return N == UnknownN; } 48 49 static BranchProbability getZero() { return BranchProbability(0); } 50 static BranchProbability getOne() { return BranchProbability(D); } 51 static BranchProbability getUnknown() { return BranchProbability(UnknownN); } 52 // Create a BranchProbability object with the given numerator and 1<<31 53 // as denominator. 54 static BranchProbability getRaw(uint32_t N) { return BranchProbability(N); } 55 // Create a BranchProbability object from 64-bit integers. 56 static BranchProbability getBranchProbability(uint64_t Numerator, 57 uint64_t Denominator); 58 59 // Normalize given probabilties so that the sum of them becomes approximate 60 // one. 61 template <class ProbabilityIter> 62 static void normalizeProbabilities(ProbabilityIter Begin, 63 ProbabilityIter End); 64 65 uint32_t getNumerator() const { return N; } 66 static uint32_t getDenominator() { return D; } 67 68 // Return (1 - Probability). 69 BranchProbability getCompl() const { return BranchProbability(D - N); } 70 71 raw_ostream &print(raw_ostream &OS) const; 72 73 void dump() const; 74 75 /// Scale a large integer. 76 /// 77 /// Scales \c Num. Guarantees full precision. Returns the floor of the 78 /// result. 79 /// 80 /// \return \c Num times \c this. 81 uint64_t scale(uint64_t Num) const; 82 83 /// Scale a large integer by the inverse. 84 /// 85 /// Scales \c Num by the inverse of \c this. Guarantees full precision. 86 /// Returns the floor of the result. 87 /// 88 /// \return \c Num divided by \c this. 89 uint64_t scaleByInverse(uint64_t Num) const; 90 91 BranchProbability &operator+=(BranchProbability RHS) { 92 assert(N != UnknownN && RHS.N != UnknownN && 93 "Unknown probability cannot participate in arithmetics."); 94 // Saturate the result in case of overflow. 95 N = (uint64_t(N) + RHS.N > D) ? D : N + RHS.N; 96 return *this; 97 } 98 99 BranchProbability &operator-=(BranchProbability RHS) { 100 assert(N != UnknownN && RHS.N != UnknownN && 101 "Unknown probability cannot participate in arithmetics."); 102 // Saturate the result in case of underflow. 103 N = N < RHS.N ? 0 : N - RHS.N; 104 return *this; 105 } 106 107 BranchProbability &operator*=(BranchProbability RHS) { 108 assert(N != UnknownN && RHS.N != UnknownN && 109 "Unknown probability cannot participate in arithmetics."); 110 N = (static_cast<uint64_t>(N) * RHS.N + D / 2) / D; 111 return *this; 112 } 113 114 BranchProbability &operator*=(uint32_t RHS) { 115 assert(N != UnknownN && 116 "Unknown probability cannot participate in arithmetics."); 117 N = (uint64_t(N) * RHS > D) ? D : N * RHS; 118 return *this; 119 } 120 121 BranchProbability &operator/=(BranchProbability RHS) { 122 assert(N != UnknownN && RHS.N != UnknownN && 123 "Unknown probability cannot participate in arithmetics."); 124 N = (static_cast<uint64_t>(N) * D + RHS.N / 2) / RHS.N; 125 return *this; 126 } 127 128 BranchProbability &operator/=(uint32_t RHS) { 129 assert(N != UnknownN && 130 "Unknown probability cannot participate in arithmetics."); 131 assert(RHS > 0 && "The divider cannot be zero."); 132 N /= RHS; 133 return *this; 134 } 135 136 BranchProbability operator+(BranchProbability RHS) const { 137 BranchProbability Prob(*this); 138 Prob += RHS; 139 return Prob; 140 } 141 142 BranchProbability operator-(BranchProbability RHS) const { 143 BranchProbability Prob(*this); 144 Prob -= RHS; 145 return Prob; 146 } 147 148 BranchProbability operator*(BranchProbability RHS) const { 149 BranchProbability Prob(*this); 150 Prob *= RHS; 151 return Prob; 152 } 153 154 BranchProbability operator*(uint32_t RHS) const { 155 BranchProbability Prob(*this); 156 Prob *= RHS; 157 return Prob; 158 } 159 160 BranchProbability operator/(BranchProbability RHS) const { 161 BranchProbability Prob(*this); 162 Prob /= RHS; 163 return Prob; 164 } 165 166 BranchProbability operator/(uint32_t RHS) const { 167 BranchProbability Prob(*this); 168 Prob /= RHS; 169 return Prob; 170 } 171 172 bool operator==(BranchProbability RHS) const { return N == RHS.N; } 173 bool operator!=(BranchProbability RHS) const { return !(*this == RHS); } 174 175 bool operator<(BranchProbability RHS) const { 176 assert(N != UnknownN && RHS.N != UnknownN && 177 "Unknown probability cannot participate in comparisons."); 178 return N < RHS.N; 179 } 180 181 bool operator>(BranchProbability RHS) const { 182 assert(N != UnknownN && RHS.N != UnknownN && 183 "Unknown probability cannot participate in comparisons."); 184 return RHS < *this; 185 } 186 187 bool operator<=(BranchProbability RHS) const { 188 assert(N != UnknownN && RHS.N != UnknownN && 189 "Unknown probability cannot participate in comparisons."); 190 return !(RHS < *this); 191 } 192 193 bool operator>=(BranchProbability RHS) const { 194 assert(N != UnknownN && RHS.N != UnknownN && 195 "Unknown probability cannot participate in comparisons."); 196 return !(*this < RHS); 197 } 198 }; 199 200 inline raw_ostream &operator<<(raw_ostream &OS, BranchProbability Prob) { 201 return Prob.print(OS); 202 } 203 204 template <class ProbabilityIter> 205 void BranchProbability::normalizeProbabilities(ProbabilityIter Begin, 206 ProbabilityIter End) { 207 if (Begin == End) 208 return; 209 210 unsigned UnknownProbCount = 0; 211 uint64_t Sum = std::accumulate(Begin, End, uint64_t(0), 212 [&](uint64_t S, const BranchProbability &BP) { 213 if (!BP.isUnknown()) 214 return S + BP.N; 215 UnknownProbCount++; 216 return S; 217 }); 218 219 if (UnknownProbCount > 0) { 220 BranchProbability ProbForUnknown = BranchProbability::getZero(); 221 // If the sum of all known probabilities is less than one, evenly distribute 222 // the complement of sum to unknown probabilities. Otherwise, set unknown 223 // probabilities to zeros and continue to normalize known probabilities. 224 if (Sum < BranchProbability::getDenominator()) 225 ProbForUnknown = BranchProbability::getRaw( 226 (BranchProbability::getDenominator() - Sum) / UnknownProbCount); 227 228 std::replace_if(Begin, End, 229 [](const BranchProbability &BP) { return BP.isUnknown(); }, 230 ProbForUnknown); 231 232 if (Sum <= BranchProbability::getDenominator()) 233 return; 234 } 235 236 if (Sum == 0) { 237 BranchProbability BP(1, std::distance(Begin, End)); 238 std::fill(Begin, End, BP); 239 return; 240 } 241 242 for (auto I = Begin; I != End; ++I) 243 I->N = (I->N * uint64_t(D) + Sum / 2) / Sum; 244 } 245 246 } 247 248 #endif 249