1 //===- BranchProbability.h - Branch Probability Wrapper ---------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Definition of BranchProbability shared by IR and Machine Instructions.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef LLVM_SUPPORT_BRANCHPROBABILITY_H
14 #define LLVM_SUPPORT_BRANCHPROBABILITY_H
15 
16 #include "llvm/Support/DataTypes.h"
17 #include <algorithm>
18 #include <cassert>
19 #include <iterator>
20 #include <numeric>
21 
22 namespace llvm {
23 
class raw_ostream;

// This class represents Branch Probability as a non-negative fraction that is
// no greater than 1. It uses a fixed-point-like implementation, in which the
// denominator is always a constant value (here we use 1<<31 for maximum
// precision). A reserved numerator value marks an *unknown* probability:
// unknown values compare equal to each other, but they must not take part in
// arithmetic or ordering comparisons (this is asserted).
class BranchProbability {
  // Numerator
  uint32_t N;

  // Denominator, which is a constant value.
  static constexpr uint32_t D = 1u << 31;
  // Numerator reserved for an unknown probability. It lies outside the valid
  // range [0, D], so it cannot be confused with a real probability.
  static constexpr uint32_t UnknownN = UINT32_MAX;

  // Construct a BranchProbability with only numerator assuming the denominator
  // is 1<<31. For internal use only.
  explicit BranchProbability(uint32_t n) : N(n) {}

public:
  // Default-constructed probabilities are unknown.
  BranchProbability() : N(UnknownN) {}
  // Construct Numerator / Denominator (defined out of line).
  BranchProbability(uint32_t Numerator, uint32_t Denominator);

  bool isZero() const { return N == 0; }
  bool isUnknown() const { return N == UnknownN; }

  static BranchProbability getZero() { return BranchProbability(0); }
  static BranchProbability getOne() { return BranchProbability(D); }
  static BranchProbability getUnknown() { return BranchProbability(UnknownN); }
  // Create a BranchProbability object with the given numerator and 1<<31
  // as denominator.
  static BranchProbability getRaw(uint32_t N) { return BranchProbability(N); }
  // Create a BranchProbability object from 64-bit integers.
  static BranchProbability getBranchProbability(uint64_t Numerator,
                                                uint64_t Denominator);

  // Normalize given probabilities so that the sum of them becomes approximately
  // one.
  template <class ProbabilityIter>
  static void normalizeProbabilities(ProbabilityIter Begin,
                                     ProbabilityIter End);

  uint32_t getNumerator() const { return N; }
  static uint32_t getDenominator() { return D; }

  // Return (1 - Probability). The probability must not be unknown: the
  // unknown numerator exceeds D, so D - N would wrap around.
  BranchProbability getCompl() const {
    assert(N != UnknownN &&
           "Unknown probability cannot participate in arithmetics.");
    return BranchProbability(D - N);
  }

  raw_ostream &print(raw_ostream &OS) const;

  void dump() const;

  /// Scale a large integer.
  ///
  /// Scales \c Num.  Guarantees full precision.  Returns the floor of the
  /// result.
  ///
  /// \return \c Num times \c this.
  uint64_t scale(uint64_t Num) const;

  /// Scale a large integer by the inverse.
  ///
  /// Scales \c Num by the inverse of \c this.  Guarantees full precision.
  /// Returns the floor of the result.
  ///
  /// \return \c Num divided by \c this.
  uint64_t scaleByInverse(uint64_t Num) const;

  // Add a probability; the sum saturates at one.
  BranchProbability &operator+=(BranchProbability RHS) {
    assert(N != UnknownN && RHS.N != UnknownN &&
           "Unknown probability cannot participate in arithmetics.");
    // Saturate the result in case of overflow.
    N = (uint64_t(N) + RHS.N > D) ? D : N + RHS.N;
    return *this;
  }

  // Subtract a probability; the difference saturates at zero.
  BranchProbability &operator-=(BranchProbability RHS) {
    assert(N != UnknownN && RHS.N != UnknownN &&
           "Unknown probability cannot participate in arithmetics.");
    // Saturate the result in case of underflow.
    N = N < RHS.N ? 0 : N - RHS.N;
    return *this;
  }

  // Multiply by a probability, rounding to the nearest representable value.
  BranchProbability &operator*=(BranchProbability RHS) {
    assert(N != UnknownN && RHS.N != UnknownN &&
           "Unknown probability cannot participate in arithmetics.");
    N = (static_cast<uint64_t>(N) * RHS.N + D / 2) / D;
    return *this;
  }

  // Multiply by an integer; the product saturates at one.
  BranchProbability &operator*=(uint32_t RHS) {
    assert(N != UnknownN &&
           "Unknown probability cannot participate in arithmetics.");
    N = (uint64_t(N) * RHS > D) ? D : N * RHS;
    return *this;
  }

  // Divide by a probability, rounding to the nearest representable value.
  // A ratio of probabilities can exceed one (and its raw numerator may not
  // even fit in 32 bits), so the quotient saturates at one, consistent with
  // the other saturating operators.
  BranchProbability &operator/=(BranchProbability RHS) {
    assert(N != UnknownN && RHS.N != UnknownN &&
           "Unknown probability cannot participate in arithmetics.");
    assert(RHS.N > 0 && "The divider cannot be zero.");
    // N * D <= 2^63, so the intermediate cannot overflow 64 bits.
    uint64_t Quotient = (static_cast<uint64_t>(N) * D + RHS.N / 2) / RHS.N;
    // Saturate the result in case of overflow.
    N = Quotient > D ? D : static_cast<uint32_t>(Quotient);
    return *this;
  }

  // Divide by a non-zero integer, truncating toward zero.
  BranchProbability &operator/=(uint32_t RHS) {
    assert(N != UnknownN &&
           "Unknown probability cannot participate in arithmetics.");
    assert(RHS > 0 && "The divider cannot be zero.");
    N /= RHS;
    return *this;
  }

  BranchProbability operator+(BranchProbability RHS) const {
    BranchProbability Prob(*this);
    Prob += RHS;
    return Prob;
  }

  BranchProbability operator-(BranchProbability RHS) const {
    BranchProbability Prob(*this);
    Prob -= RHS;
    return Prob;
  }

  BranchProbability operator*(BranchProbability RHS) const {
    BranchProbability Prob(*this);
    Prob *= RHS;
    return Prob;
  }

  BranchProbability operator*(uint32_t RHS) const {
    BranchProbability Prob(*this);
    Prob *= RHS;
    return Prob;
  }

  BranchProbability operator/(BranchProbability RHS) const {
    BranchProbability Prob(*this);
    Prob /= RHS;
    return Prob;
  }

  BranchProbability operator/(uint32_t RHS) const {
    BranchProbability Prob(*this);
    Prob /= RHS;
    return Prob;
  }

  // Equality compares raw numerators, so two unknown probabilities are equal.
  bool operator==(BranchProbability RHS) const { return N == RHS.N; }
  bool operator!=(BranchProbability RHS) const { return !(*this == RHS); }

  bool operator<(BranchProbability RHS) const {
    assert(N != UnknownN && RHS.N != UnknownN &&
           "Unknown probability cannot participate in comparisons.");
    return N < RHS.N;
  }

  bool operator>(BranchProbability RHS) const {
    assert(N != UnknownN && RHS.N != UnknownN &&
           "Unknown probability cannot participate in comparisons.");
    return RHS < *this;
  }

  bool operator<=(BranchProbability RHS) const {
    assert(N != UnknownN && RHS.N != UnknownN &&
           "Unknown probability cannot participate in comparisons.");
    return !(RHS < *this);
  }

  bool operator>=(BranchProbability RHS) const {
    assert(N != UnknownN && RHS.N != UnknownN &&
           "Unknown probability cannot participate in comparisons.");
    return !(*this < RHS);
  }
};
199 
200 inline raw_ostream &operator<<(raw_ostream &OS, BranchProbability Prob) {
201   return Prob.print(OS);
202 }
203 
204 template <class ProbabilityIter>
normalizeProbabilities(ProbabilityIter Begin,ProbabilityIter End)205 void BranchProbability::normalizeProbabilities(ProbabilityIter Begin,
206                                                ProbabilityIter End) {
207   if (Begin == End)
208     return;
209 
210   unsigned UnknownProbCount = 0;
211   uint64_t Sum = std::accumulate(Begin, End, uint64_t(0),
212                                  [&](uint64_t S, const BranchProbability &BP) {
213                                    if (!BP.isUnknown())
214                                      return S + BP.N;
215                                    UnknownProbCount++;
216                                    return S;
217                                  });
218 
219   if (UnknownProbCount > 0) {
220     BranchProbability ProbForUnknown = BranchProbability::getZero();
221     // If the sum of all known probabilities is less than one, evenly distribute
222     // the complement of sum to unknown probabilities. Otherwise, set unknown
223     // probabilities to zeros and continue to normalize known probabilities.
224     if (Sum < BranchProbability::getDenominator())
225       ProbForUnknown = BranchProbability::getRaw(
226           (BranchProbability::getDenominator() - Sum) / UnknownProbCount);
227 
228     std::replace_if(Begin, End,
229                     [](const BranchProbability &BP) { return BP.isUnknown(); },
230                     ProbForUnknown);
231 
232     if (Sum <= BranchProbability::getDenominator())
233       return;
234   }
235 
236   if (Sum == 0) {
237     BranchProbability BP(1, std::distance(Begin, End));
238     std::fill(Begin, End, BP);
239     return;
240   }
241 
242   for (auto I = Begin; I != End; ++I)
243     I->N = (I->N * uint64_t(D) + Sum / 2) / Sum;
244 }
245 
246 }
247 
248 #endif
249