1 //===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Instrumentation-based profile-guided optimization
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "CodeGenPGO.h"
14 #include "CodeGenFunction.h"
15 #include "CoverageMappingGen.h"
16 #include "clang/AST/RecursiveASTVisitor.h"
17 #include "clang/AST/StmtVisitor.h"
18 #include "llvm/IR/Intrinsics.h"
19 #include "llvm/IR/MDBuilder.h"
20 #include "llvm/Support/CommandLine.h"
21 #include "llvm/Support/Endian.h"
22 #include "llvm/Support/FileSystem.h"
23 #include "llvm/Support/MD5.h"
24 
25 static llvm::cl::opt<bool>
26     EnableValueProfiling("enable-value-profiling", llvm::cl::ZeroOrMore,
27                          llvm::cl::desc("Enable value profiling"),
28                          llvm::cl::Hidden, llvm::cl::init(false));
29 
30 using namespace clang;
31 using namespace CodeGen;
32 
setFuncName(StringRef Name,llvm::GlobalValue::LinkageTypes Linkage)33 void CodeGenPGO::setFuncName(StringRef Name,
34                              llvm::GlobalValue::LinkageTypes Linkage) {
35   llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
36   FuncName = llvm::getPGOFuncName(
37       Name, Linkage, CGM.getCodeGenOpts().MainFileName,
38       PGOReader ? PGOReader->getVersion() : llvm::IndexedInstrProf::Version);
39 
40   // If we're generating a profile, create a variable for the name.
41   if (CGM.getCodeGenOpts().hasProfileClangInstr())
42     FuncNameVar = llvm::createPGOFuncNameVar(CGM.getModule(), Linkage, FuncName);
43 }
44 
setFuncName(llvm::Function * Fn)45 void CodeGenPGO::setFuncName(llvm::Function *Fn) {
46   setFuncName(Fn->getName(), Fn->getLinkage());
47   // Create PGOFuncName meta data.
48   llvm::createPGOFuncNameMetadata(*Fn, FuncName);
49 }
50 
/// Revisions of the PGO control-flow hash algorithm.
///
/// Values are spelled out explicitly: they are ordered, and code compares
/// them with `<` / `>=` to decide which hash features are enabled.
enum PGOHashVersion : unsigned {
  PGO_HASH_V1 = 0,
  PGO_HASH_V2 = 1,
  PGO_HASH_V3 = 2,

  // Alias for the newest revision; keep it pointing at the latest one.
  PGO_HASH_LATEST = PGO_HASH_V3
};
60 
61 namespace {
/// Stable hasher for PGO region counters.
///
/// PGOHash produces a stable hash of a given function's control flow. Callers
/// feed it a sequence of HashType tags via combine(); each tag occupies
/// NumBitsPerType bits of a 64-bit working word, and full words are passed
/// through MD5. finalize() returns the 64-bit result.
///
/// Changing the output of this hash will invalidate all previously generated
/// profiles -- i.e., don't do it.
///
/// \note  The hash has already been revised (see PGOHashVersion). Every
/// revision must stay reproducible: when reading a profile, the revision is
/// chosen from the indexed profile's format version, and newer behavior is
/// gated on HashVersion.
class PGOHash {
  // Tags accumulated since the last flush into MD5 (see combine()).
  uint64_t Working;
  // Total number of tags combined so far.
  unsigned Count;
  // Which hash revision to reproduce.
  PGOHashVersion HashVersion;
  // Running digest over the flushed working words.
  llvm::MD5 MD5;

  // Each tag is packed into 6 bits, so 64 / 6 = 10 tags fit in one working
  // word, and tag values must be below 1 << 6 = 64.
  static const int NumBitsPerType = 6;
  static const unsigned NumTypesPerWord = sizeof(uint64_t) * 8 / NumBitsPerType;
  static const unsigned TooBig = 1u << NumBitsPerType;

public:
  /// Hash values for AST nodes.
  ///
  /// Distinct values for AST nodes that have region counters attached.
  ///
  /// These values must be stable.  All new members must be added at the end,
  /// and no members should be removed.  Changing the enumeration value for an
  /// AST node will affect the hash of every function that contains that node.
  enum HashType : unsigned char {
    None = 0,
    LabelStmt = 1,
    WhileStmt,
    DoStmt,
    ForStmt,
    CXXForRangeStmt,
    ObjCForCollectionStmt,
    SwitchStmt,
    CaseStmt,
    DefaultStmt,
    IfStmt,
    CXXTryStmt,
    CXXCatchStmt,
    ConditionalOperator,
    BinaryOperatorLAnd,
    BinaryOperatorLOr,
    BinaryConditionalOperator,
    // The preceding values are available with PGO_HASH_V1.

    EndOfScope,
    IfThenBranch,
    IfElseBranch,
    GotoStmt,
    IndirectGotoStmt,
    BreakStmt,
    ContinueStmt,
    ReturnStmt,
    ThrowExpr,
    UnaryOperatorLNot,
    BinaryOperatorLT,
    BinaryOperatorGT,
    BinaryOperatorLE,
    BinaryOperatorGE,
    BinaryOperatorEQ,
    BinaryOperatorNE,
    // The preceding values are available since PGO_HASH_V2.

    // Keep this last.  It's for the static assert that follows.
    LastHashType
  };
  // Every real tag must fit in NumBitsPerType bits.
  static_assert(LastHashType <= TooBig, "Too many types in HashType");

  /// Start an empty hash that reproduces revision \p HashVersion.
  PGOHash(PGOHashVersion HashVersion)
      : Working(0), Count(0), HashVersion(HashVersion), MD5() {}
  /// Append one tag; \p Type must be non-zero and below TooBig.
  void combine(HashType Type);
  /// Produce the 64-bit hash of all tags combined so far.
  uint64_t finalize();
  /// The hash revision this hasher reproduces.
  PGOHashVersion getHashVersion() const { return HashVersion; }
};
// Namespace-scope definitions of the static constants above. Under pre-C++17
// rules these definitions are required if the constants are odr-used.
const int PGOHash::NumBitsPerType;
const unsigned PGOHash::NumTypesPerWord;
const unsigned PGOHash::TooBig;
142 
143 /// Get the PGO hash version used in the given indexed profile.
getPGOHashVersion(llvm::IndexedInstrProfReader * PGOReader,CodeGenModule & CGM)144 static PGOHashVersion getPGOHashVersion(llvm::IndexedInstrProfReader *PGOReader,
145                                         CodeGenModule &CGM) {
146   if (PGOReader->getVersion() <= 4)
147     return PGO_HASH_V1;
148   if (PGOReader->getVersion() <= 5)
149     return PGO_HASH_V2;
150   return PGO_HASH_V3;
151 }
152 
/// A RecursiveASTVisitor that fills a map of statements to PGO counters.
///
/// While walking a function it does two things at once:
///  - assigns consecutive counter indices (starting at 0) to statements that
///    need a region counter, recording them in CounterMap;
///  - feeds a tag for each "interesting" statement into Hash, producing the
///    function's control-flow hash.
/// Both results depend on the exact order in which statements are visited.
struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> {
  using Base = RecursiveASTVisitor<MapRegionCounters>;

  /// The next counter value to assign.
  unsigned NextCounter;
  /// The function hash.
  PGOHash Hash;
  /// The map of statements to counters (caller-owned; filled in place).
  llvm::DenseMap<const Stmt *, unsigned> &CounterMap;
  /// The profile version. Gates the extra per-RHS counters added in
  /// VisitBinaryOperator.
  uint64_t ProfileVersion;

  MapRegionCounters(PGOHashVersion HashVersion, uint64_t ProfileVersion,
                    llvm::DenseMap<const Stmt *, unsigned> &CounterMap)
      : NextCounter(0), Hash(HashVersion), CounterMap(CounterMap),
        ProfileVersion(ProfileVersion) {}

  // Blocks and lambdas are handled as separate functions, so we need not
  // traverse them in the parent context.
  bool TraverseBlockExpr(BlockExpr *BE) { return true; }
  bool TraverseLambdaExpr(LambdaExpr *LE) {
    // Traverse the captures, but not the body.
    for (auto C : zip(LE->captures(), LE->capture_inits()))
      TraverseLambdaCapture(LE, &std::get<0>(C), std::get<1>(C));
    return true;
  }
  // Captured statements are likewise not traversed from the parent.
  bool TraverseCapturedStmt(CapturedStmt *CS) { return true; }

  /// Give the body of each visited function-like declaration (functions,
  /// methods, constructors/destructors, conversions, ObjC methods, blocks and
  /// captured decls) its own counter. Other declarations are ignored.
  bool VisitDecl(const Decl *D) {
    switch (D->getKind()) {
    default:
      break;
    case Decl::Function:
    case Decl::CXXMethod:
    case Decl::CXXConstructor:
    case Decl::CXXDestructor:
    case Decl::CXXConversion:
    case Decl::ObjCMethod:
    case Decl::Block:
    case Decl::Captured:
      CounterMap[D->getBody()] = NextCounter++;
      break;
    }
    return true;
  }

  /// If \p S gets a fresh counter, update the counter mappings. Return the
  /// V1 hash of \p S.
  ///
  /// The V1 classification is used regardless of the active hash version, so
  /// the set of statements that receive counters does not depend on it.
  PGOHash::HashType updateCounterMappings(Stmt *S) {
    auto Type = getHashType(PGO_HASH_V1, S);
    if (Type != PGOHash::None)
      CounterMap[S] = NextCounter++;
    return Type;
  }

  /// The RHS of all logical operators gets a fresh counter in order to count
  /// how many times the RHS evaluates to true or false, depending on the
  /// semantics of the operator. This is only valid for ">= v7" of the profile
  /// version so that we facilitate backward compatibility.
  ///
  /// Only RHS expressions that CodeGenFunction::isInstrumentedCondition
  /// accepts get a counter.
  bool VisitBinaryOperator(BinaryOperator *S) {
    if (ProfileVersion >= llvm::IndexedInstrProf::Version7)
      if (S->isLogicalOp() &&
          CodeGenFunction::isInstrumentedCondition(S->getRHS()))
        CounterMap[S->getRHS()] = NextCounter++;
    return Base::VisitBinaryOperator(S);
  }

  /// Include \p S in the function hash.
  ///
  /// Counter assignment always uses the V1 classification. The hashed tag is
  /// that same V1 tag under PGO_HASH_V1, and is recomputed with the richer
  /// classification for later hash versions.
  bool VisitStmt(Stmt *S) {
    auto Type = updateCounterMappings(S);
    if (Hash.getHashVersion() != PGO_HASH_V1)
      Type = getHashType(Hash.getHashVersion(), S);
    if (Type != PGOHash::None)
      Hash.combine(Type);
    return true;
  }

  /// For hash versions after V1, record branch structure in the hash:
  /// IfThenBranch before the "then" child, IfElseBranch before the "else"
  /// child, and EndOfScope after all children. Other non-null children get no
  /// marker.
  bool TraverseIfStmt(IfStmt *If) {
    // If we used the V1 hash, use the default traversal.
    if (Hash.getHashVersion() == PGO_HASH_V1)
      return Base::TraverseIfStmt(If);

    // Otherwise, keep track of which branch we're in while traversing.
    // The default traversal is bypassed, so visit the if itself explicitly.
    VisitStmt(If);
    for (Stmt *CS : If->children()) {
      if (!CS)
        continue;
      if (CS == If->getThen())
        Hash.combine(PGOHash::IfThenBranch);
      else if (CS == If->getElse())
        Hash.combine(PGOHash::IfElseBranch);
      TraverseStmt(CS);
    }
    Hash.combine(PGOHash::EndOfScope);
    return true;
  }

// If the statement type \p N is nestable, and its nesting impacts profile
// stability, define a custom traversal which tracks the end of the statement
// in the hash (provided we're not using the V1 hash).
// Note: the generated Traverse##N ignores the base traversal's result and
// always returns true.
#define DEFINE_NESTABLE_TRAVERSAL(N)                                           \
  bool Traverse##N(N *S) {                                                     \
    Base::Traverse##N(S);                                                      \
    if (Hash.getHashVersion() != PGO_HASH_V1)                                  \
      Hash.combine(PGOHash::EndOfScope);                                       \
    return true;                                                               \
  }

  DEFINE_NESTABLE_TRAVERSAL(WhileStmt)
  DEFINE_NESTABLE_TRAVERSAL(DoStmt)
  DEFINE_NESTABLE_TRAVERSAL(ForStmt)
  DEFINE_NESTABLE_TRAVERSAL(CXXForRangeStmt)
  DEFINE_NESTABLE_TRAVERSAL(ObjCForCollectionStmt)
  DEFINE_NESTABLE_TRAVERSAL(CXXTryStmt)
  DEFINE_NESTABLE_TRAVERSAL(CXXCatchStmt)

  /// Get version \p HashVersion of the PGO hash for \p S.
  ///
  /// Every version tags loops, switch/case/default, if, try/catch, labels,
  /// conditional operators and && / ||. From PGO_HASH_V2 on, comparison
  /// operators (<, >, <=, >=, ==, !=), goto / indirect goto, break, continue,
  /// return, throw and logical not are tagged as well. Returns PGOHash::None
  /// for anything else.
  PGOHash::HashType getHashType(PGOHashVersion HashVersion, const Stmt *S) {
    switch (S->getStmtClass()) {
    default:
      break;
    case Stmt::LabelStmtClass:
      return PGOHash::LabelStmt;
    case Stmt::WhileStmtClass:
      return PGOHash::WhileStmt;
    case Stmt::DoStmtClass:
      return PGOHash::DoStmt;
    case Stmt::ForStmtClass:
      return PGOHash::ForStmt;
    case Stmt::CXXForRangeStmtClass:
      return PGOHash::CXXForRangeStmt;
    case Stmt::ObjCForCollectionStmtClass:
      return PGOHash::ObjCForCollectionStmt;
    case Stmt::SwitchStmtClass:
      return PGOHash::SwitchStmt;
    case Stmt::CaseStmtClass:
      return PGOHash::CaseStmt;
    case Stmt::DefaultStmtClass:
      return PGOHash::DefaultStmt;
    case Stmt::IfStmtClass:
      return PGOHash::IfStmt;
    case Stmt::CXXTryStmtClass:
      return PGOHash::CXXTryStmt;
    case Stmt::CXXCatchStmtClass:
      return PGOHash::CXXCatchStmt;
    case Stmt::ConditionalOperatorClass:
      return PGOHash::ConditionalOperator;
    case Stmt::BinaryConditionalOperatorClass:
      return PGOHash::BinaryConditionalOperator;
    case Stmt::BinaryOperatorClass: {
      const BinaryOperator *BO = cast<BinaryOperator>(S);
      if (BO->getOpcode() == BO_LAnd)
        return PGOHash::BinaryOperatorLAnd;
      if (BO->getOpcode() == BO_LOr)
        return PGOHash::BinaryOperatorLOr;
      if (HashVersion >= PGO_HASH_V2) {
        switch (BO->getOpcode()) {
        default:
          break;
        case BO_LT:
          return PGOHash::BinaryOperatorLT;
        case BO_GT:
          return PGOHash::BinaryOperatorGT;
        case BO_LE:
          return PGOHash::BinaryOperatorLE;
        case BO_GE:
          return PGOHash::BinaryOperatorGE;
        case BO_EQ:
          return PGOHash::BinaryOperatorEQ;
        case BO_NE:
          return PGOHash::BinaryOperatorNE;
        }
      }
      break;
    }
    }

    // Statement kinds that are only hashed from V2 on.
    if (HashVersion >= PGO_HASH_V2) {
      switch (S->getStmtClass()) {
      default:
        break;
      case Stmt::GotoStmtClass:
        return PGOHash::GotoStmt;
      case Stmt::IndirectGotoStmtClass:
        return PGOHash::IndirectGotoStmt;
      case Stmt::BreakStmtClass:
        return PGOHash::BreakStmt;
      case Stmt::ContinueStmtClass:
        return PGOHash::ContinueStmt;
      case Stmt::ReturnStmtClass:
        return PGOHash::ReturnStmt;
      case Stmt::CXXThrowExprClass:
        return PGOHash::ThrowExpr;
      case Stmt::UnaryOperatorClass: {
        const UnaryOperator *UO = cast<UnaryOperator>(S);
        if (UO->getOpcode() == UO_LNot)
          return PGOHash::UnaryOperatorLNot;
        break;
      }
      }
    }

    return PGOHash::None;
  }
};
359 
360 /// A StmtVisitor that propagates the raw counts through the AST and
361 /// records the count at statements where the value may change.
362 struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
363   /// PGO state.
364   CodeGenPGO &PGO;
365 
366   /// A flag that is set when the current count should be recorded on the
367   /// next statement, such as at the exit of a loop.
368   bool RecordNextStmtCount;
369 
370   /// The count at the current location in the traversal.
371   uint64_t CurrentCount;
372 
373   /// The map of statements to count values.
374   llvm::DenseMap<const Stmt *, uint64_t> &CountMap;
375 
376   /// BreakContinueStack - Keep counts of breaks and continues inside loops.
377   struct BreakContinue {
378     uint64_t BreakCount;
379     uint64_t ContinueCount;
BreakContinue__anon9f6b1bbf0111::ComputeRegionCounts::BreakContinue380     BreakContinue() : BreakCount(0), ContinueCount(0) {}
381   };
382   SmallVector<BreakContinue, 8> BreakContinueStack;
383 
ComputeRegionCounts__anon9f6b1bbf0111::ComputeRegionCounts384   ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap,
385                       CodeGenPGO &PGO)
386       : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {}
387 
RecordStmtCount__anon9f6b1bbf0111::ComputeRegionCounts388   void RecordStmtCount(const Stmt *S) {
389     if (RecordNextStmtCount) {
390       CountMap[S] = CurrentCount;
391       RecordNextStmtCount = false;
392     }
393   }
394 
395   /// Set and return the current count.
setCount__anon9f6b1bbf0111::ComputeRegionCounts396   uint64_t setCount(uint64_t Count) {
397     CurrentCount = Count;
398     return Count;
399   }
400 
VisitStmt__anon9f6b1bbf0111::ComputeRegionCounts401   void VisitStmt(const Stmt *S) {
402     RecordStmtCount(S);
403     for (const Stmt *Child : S->children())
404       if (Child)
405         this->Visit(Child);
406   }
407 
VisitFunctionDecl__anon9f6b1bbf0111::ComputeRegionCounts408   void VisitFunctionDecl(const FunctionDecl *D) {
409     // Counter tracks entry to the function body.
410     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
411     CountMap[D->getBody()] = BodyCount;
412     Visit(D->getBody());
413   }
414 
415   // Skip lambda expressions. We visit these as FunctionDecls when we're
416   // generating them and aren't interested in the body when generating a
417   // parent context.
VisitLambdaExpr__anon9f6b1bbf0111::ComputeRegionCounts418   void VisitLambdaExpr(const LambdaExpr *LE) {}
419 
VisitCapturedDecl__anon9f6b1bbf0111::ComputeRegionCounts420   void VisitCapturedDecl(const CapturedDecl *D) {
421     // Counter tracks entry to the capture body.
422     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
423     CountMap[D->getBody()] = BodyCount;
424     Visit(D->getBody());
425   }
426 
VisitObjCMethodDecl__anon9f6b1bbf0111::ComputeRegionCounts427   void VisitObjCMethodDecl(const ObjCMethodDecl *D) {
428     // Counter tracks entry to the method body.
429     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
430     CountMap[D->getBody()] = BodyCount;
431     Visit(D->getBody());
432   }
433 
VisitBlockDecl__anon9f6b1bbf0111::ComputeRegionCounts434   void VisitBlockDecl(const BlockDecl *D) {
435     // Counter tracks entry to the block body.
436     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
437     CountMap[D->getBody()] = BodyCount;
438     Visit(D->getBody());
439   }
440 
VisitReturnStmt__anon9f6b1bbf0111::ComputeRegionCounts441   void VisitReturnStmt(const ReturnStmt *S) {
442     RecordStmtCount(S);
443     if (S->getRetValue())
444       Visit(S->getRetValue());
445     CurrentCount = 0;
446     RecordNextStmtCount = true;
447   }
448 
VisitCXXThrowExpr__anon9f6b1bbf0111::ComputeRegionCounts449   void VisitCXXThrowExpr(const CXXThrowExpr *E) {
450     RecordStmtCount(E);
451     if (E->getSubExpr())
452       Visit(E->getSubExpr());
453     CurrentCount = 0;
454     RecordNextStmtCount = true;
455   }
456 
VisitGotoStmt__anon9f6b1bbf0111::ComputeRegionCounts457   void VisitGotoStmt(const GotoStmt *S) {
458     RecordStmtCount(S);
459     CurrentCount = 0;
460     RecordNextStmtCount = true;
461   }
462 
VisitLabelStmt__anon9f6b1bbf0111::ComputeRegionCounts463   void VisitLabelStmt(const LabelStmt *S) {
464     RecordNextStmtCount = false;
465     // Counter tracks the block following the label.
466     uint64_t BlockCount = setCount(PGO.getRegionCount(S));
467     CountMap[S] = BlockCount;
468     Visit(S->getSubStmt());
469   }
470 
VisitBreakStmt__anon9f6b1bbf0111::ComputeRegionCounts471   void VisitBreakStmt(const BreakStmt *S) {
472     RecordStmtCount(S);
473     assert(!BreakContinueStack.empty() && "break not in a loop or switch!");
474     BreakContinueStack.back().BreakCount += CurrentCount;
475     CurrentCount = 0;
476     RecordNextStmtCount = true;
477   }
478 
VisitContinueStmt__anon9f6b1bbf0111::ComputeRegionCounts479   void VisitContinueStmt(const ContinueStmt *S) {
480     RecordStmtCount(S);
481     assert(!BreakContinueStack.empty() && "continue stmt not in a loop!");
482     BreakContinueStack.back().ContinueCount += CurrentCount;
483     CurrentCount = 0;
484     RecordNextStmtCount = true;
485   }
486 
VisitWhileStmt__anon9f6b1bbf0111::ComputeRegionCounts487   void VisitWhileStmt(const WhileStmt *S) {
488     RecordStmtCount(S);
489     uint64_t ParentCount = CurrentCount;
490 
491     BreakContinueStack.push_back(BreakContinue());
492     // Visit the body region first so the break/continue adjustments can be
493     // included when visiting the condition.
494     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
495     CountMap[S->getBody()] = CurrentCount;
496     Visit(S->getBody());
497     uint64_t BackedgeCount = CurrentCount;
498 
499     // ...then go back and propagate counts through the condition. The count
500     // at the start of the condition is the sum of the incoming edges,
501     // the backedge from the end of the loop body, and the edges from
502     // continue statements.
503     BreakContinue BC = BreakContinueStack.pop_back_val();
504     uint64_t CondCount =
505         setCount(ParentCount + BackedgeCount + BC.ContinueCount);
506     CountMap[S->getCond()] = CondCount;
507     Visit(S->getCond());
508     setCount(BC.BreakCount + CondCount - BodyCount);
509     RecordNextStmtCount = true;
510   }
511 
VisitDoStmt__anon9f6b1bbf0111::ComputeRegionCounts512   void VisitDoStmt(const DoStmt *S) {
513     RecordStmtCount(S);
514     uint64_t LoopCount = PGO.getRegionCount(S);
515 
516     BreakContinueStack.push_back(BreakContinue());
517     // The count doesn't include the fallthrough from the parent scope. Add it.
518     uint64_t BodyCount = setCount(LoopCount + CurrentCount);
519     CountMap[S->getBody()] = BodyCount;
520     Visit(S->getBody());
521     uint64_t BackedgeCount = CurrentCount;
522 
523     BreakContinue BC = BreakContinueStack.pop_back_val();
524     // The count at the start of the condition is equal to the count at the
525     // end of the body, plus any continues.
526     uint64_t CondCount = setCount(BackedgeCount + BC.ContinueCount);
527     CountMap[S->getCond()] = CondCount;
528     Visit(S->getCond());
529     setCount(BC.BreakCount + CondCount - LoopCount);
530     RecordNextStmtCount = true;
531   }
532 
VisitForStmt__anon9f6b1bbf0111::ComputeRegionCounts533   void VisitForStmt(const ForStmt *S) {
534     RecordStmtCount(S);
535     if (S->getInit())
536       Visit(S->getInit());
537 
538     uint64_t ParentCount = CurrentCount;
539 
540     BreakContinueStack.push_back(BreakContinue());
541     // Visit the body region first. (This is basically the same as a while
542     // loop; see further comments in VisitWhileStmt.)
543     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
544     CountMap[S->getBody()] = BodyCount;
545     Visit(S->getBody());
546     uint64_t BackedgeCount = CurrentCount;
547     BreakContinue BC = BreakContinueStack.pop_back_val();
548 
549     // The increment is essentially part of the body but it needs to include
550     // the count for all the continue statements.
551     if (S->getInc()) {
552       uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
553       CountMap[S->getInc()] = IncCount;
554       Visit(S->getInc());
555     }
556 
557     // ...then go back and propagate counts through the condition.
558     uint64_t CondCount =
559         setCount(ParentCount + BackedgeCount + BC.ContinueCount);
560     if (S->getCond()) {
561       CountMap[S->getCond()] = CondCount;
562       Visit(S->getCond());
563     }
564     setCount(BC.BreakCount + CondCount - BodyCount);
565     RecordNextStmtCount = true;
566   }
567 
VisitCXXForRangeStmt__anon9f6b1bbf0111::ComputeRegionCounts568   void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
569     RecordStmtCount(S);
570     if (S->getInit())
571       Visit(S->getInit());
572     Visit(S->getLoopVarStmt());
573     Visit(S->getRangeStmt());
574     Visit(S->getBeginStmt());
575     Visit(S->getEndStmt());
576 
577     uint64_t ParentCount = CurrentCount;
578     BreakContinueStack.push_back(BreakContinue());
579     // Visit the body region first. (This is basically the same as a while
580     // loop; see further comments in VisitWhileStmt.)
581     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
582     CountMap[S->getBody()] = BodyCount;
583     Visit(S->getBody());
584     uint64_t BackedgeCount = CurrentCount;
585     BreakContinue BC = BreakContinueStack.pop_back_val();
586 
587     // The increment is essentially part of the body but it needs to include
588     // the count for all the continue statements.
589     uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
590     CountMap[S->getInc()] = IncCount;
591     Visit(S->getInc());
592 
593     // ...then go back and propagate counts through the condition.
594     uint64_t CondCount =
595         setCount(ParentCount + BackedgeCount + BC.ContinueCount);
596     CountMap[S->getCond()] = CondCount;
597     Visit(S->getCond());
598     setCount(BC.BreakCount + CondCount - BodyCount);
599     RecordNextStmtCount = true;
600   }
601 
VisitObjCForCollectionStmt__anon9f6b1bbf0111::ComputeRegionCounts602   void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
603     RecordStmtCount(S);
604     Visit(S->getElement());
605     uint64_t ParentCount = CurrentCount;
606     BreakContinueStack.push_back(BreakContinue());
607     // Counter tracks the body of the loop.
608     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
609     CountMap[S->getBody()] = BodyCount;
610     Visit(S->getBody());
611     uint64_t BackedgeCount = CurrentCount;
612     BreakContinue BC = BreakContinueStack.pop_back_val();
613 
614     setCount(BC.BreakCount + ParentCount + BackedgeCount + BC.ContinueCount -
615              BodyCount);
616     RecordNextStmtCount = true;
617   }
618 
VisitSwitchStmt__anon9f6b1bbf0111::ComputeRegionCounts619   void VisitSwitchStmt(const SwitchStmt *S) {
620     RecordStmtCount(S);
621     if (S->getInit())
622       Visit(S->getInit());
623     Visit(S->getCond());
624     CurrentCount = 0;
625     BreakContinueStack.push_back(BreakContinue());
626     Visit(S->getBody());
627     // If the switch is inside a loop, add the continue counts.
628     BreakContinue BC = BreakContinueStack.pop_back_val();
629     if (!BreakContinueStack.empty())
630       BreakContinueStack.back().ContinueCount += BC.ContinueCount;
631     // Counter tracks the exit block of the switch.
632     setCount(PGO.getRegionCount(S));
633     RecordNextStmtCount = true;
634   }
635 
VisitSwitchCase__anon9f6b1bbf0111::ComputeRegionCounts636   void VisitSwitchCase(const SwitchCase *S) {
637     RecordNextStmtCount = false;
638     // Counter for this particular case. This counts only jumps from the
639     // switch header and does not include fallthrough from the case before
640     // this one.
641     uint64_t CaseCount = PGO.getRegionCount(S);
642     setCount(CurrentCount + CaseCount);
643     // We need the count without fallthrough in the mapping, so it's more useful
644     // for branch probabilities.
645     CountMap[S] = CaseCount;
646     RecordNextStmtCount = true;
647     Visit(S->getSubStmt());
648   }
649 
VisitIfStmt__anon9f6b1bbf0111::ComputeRegionCounts650   void VisitIfStmt(const IfStmt *S) {
651     RecordStmtCount(S);
652     uint64_t ParentCount = CurrentCount;
653     if (S->getInit())
654       Visit(S->getInit());
655     Visit(S->getCond());
656 
657     // Counter tracks the "then" part of an if statement. The count for
658     // the "else" part, if it exists, will be calculated from this counter.
659     uint64_t ThenCount = setCount(PGO.getRegionCount(S));
660     CountMap[S->getThen()] = ThenCount;
661     Visit(S->getThen());
662     uint64_t OutCount = CurrentCount;
663 
664     uint64_t ElseCount = ParentCount - ThenCount;
665     if (S->getElse()) {
666       setCount(ElseCount);
667       CountMap[S->getElse()] = ElseCount;
668       Visit(S->getElse());
669       OutCount += CurrentCount;
670     } else
671       OutCount += ElseCount;
672     setCount(OutCount);
673     RecordNextStmtCount = true;
674   }
675 
VisitCXXTryStmt__anon9f6b1bbf0111::ComputeRegionCounts676   void VisitCXXTryStmt(const CXXTryStmt *S) {
677     RecordStmtCount(S);
678     Visit(S->getTryBlock());
679     for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
680       Visit(S->getHandler(I));
681     // Counter tracks the continuation block of the try statement.
682     setCount(PGO.getRegionCount(S));
683     RecordNextStmtCount = true;
684   }
685 
VisitCXXCatchStmt__anon9f6b1bbf0111::ComputeRegionCounts686   void VisitCXXCatchStmt(const CXXCatchStmt *S) {
687     RecordNextStmtCount = false;
688     // Counter tracks the catch statement's handler block.
689     uint64_t CatchCount = setCount(PGO.getRegionCount(S));
690     CountMap[S] = CatchCount;
691     Visit(S->getHandlerBlock());
692   }
693 
VisitAbstractConditionalOperator__anon9f6b1bbf0111::ComputeRegionCounts694   void VisitAbstractConditionalOperator(const AbstractConditionalOperator *E) {
695     RecordStmtCount(E);
696     uint64_t ParentCount = CurrentCount;
697     Visit(E->getCond());
698 
699     // Counter tracks the "true" part of a conditional operator. The
700     // count in the "false" part will be calculated from this counter.
701     uint64_t TrueCount = setCount(PGO.getRegionCount(E));
702     CountMap[E->getTrueExpr()] = TrueCount;
703     Visit(E->getTrueExpr());
704     uint64_t OutCount = CurrentCount;
705 
706     uint64_t FalseCount = setCount(ParentCount - TrueCount);
707     CountMap[E->getFalseExpr()] = FalseCount;
708     Visit(E->getFalseExpr());
709     OutCount += CurrentCount;
710 
711     setCount(OutCount);
712     RecordNextStmtCount = true;
713   }
714 
VisitBinLAnd__anon9f6b1bbf0111::ComputeRegionCounts715   void VisitBinLAnd(const BinaryOperator *E) {
716     RecordStmtCount(E);
717     uint64_t ParentCount = CurrentCount;
718     Visit(E->getLHS());
719     // Counter tracks the right hand side of a logical and operator.
720     uint64_t RHSCount = setCount(PGO.getRegionCount(E));
721     CountMap[E->getRHS()] = RHSCount;
722     Visit(E->getRHS());
723     setCount(ParentCount + RHSCount - CurrentCount);
724     RecordNextStmtCount = true;
725   }
726 
VisitBinLOr__anon9f6b1bbf0111::ComputeRegionCounts727   void VisitBinLOr(const BinaryOperator *E) {
728     RecordStmtCount(E);
729     uint64_t ParentCount = CurrentCount;
730     Visit(E->getLHS());
731     // Counter tracks the right hand side of a logical or operator.
732     uint64_t RHSCount = setCount(PGO.getRegionCount(E));
733     CountMap[E->getRHS()] = RHSCount;
734     Visit(E->getRHS());
735     setCount(ParentCount + RHSCount - CurrentCount);
736     RecordNextStmtCount = true;
737   }
738 };
739 } // end anonymous namespace
740 
combine(HashType Type)741 void PGOHash::combine(HashType Type) {
742   // Check that we never combine 0 and only have six bits.
743   assert(Type && "Hash is invalid: unexpected type 0");
744   assert(unsigned(Type) < TooBig && "Hash is invalid: too many types");
745 
746   // Pass through MD5 if enough work has built up.
747   if (Count && Count % NumTypesPerWord == 0) {
748     using namespace llvm::support;
749     uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working);
750     MD5.update(llvm::makeArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
751     Working = 0;
752   }
753 
754   // Accumulate the current type.
755   ++Count;
756   Working = Working << NumBitsPerType | Type;
757 }
758 
finalize()759 uint64_t PGOHash::finalize() {
760   // Use Working as the hash directly if we never used MD5.
761   if (Count <= NumTypesPerWord)
762     // No need to byte swap here, since none of the math was endian-dependent.
763     // This number will be byte-swapped as required on endianness transitions,
764     // so we will see the same value on the other side.
765     return Working;
766 
767   // Check for remaining work in Working.
768   if (Working) {
769     // Keep the buggy behavior from v1 and v2 for backward-compatibility. This
770     // is buggy because it converts a uint64_t into an array of uint8_t.
771     if (HashVersion < PGO_HASH_V3) {
772       MD5.update({(uint8_t)Working});
773     } else {
774       using namespace llvm::support;
775       uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working);
776       MD5.update(llvm::makeArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
777     }
778   }
779 
780   // Finalize the MD5 and return the hash.
781   llvm::MD5::MD5Result Result;
782   MD5.final(Result);
783   return Result.low();
784 }
785 
assignRegionCounters(GlobalDecl GD,llvm::Function * Fn)786 void CodeGenPGO::assignRegionCounters(GlobalDecl GD, llvm::Function *Fn) {
787   const Decl *D = GD.getDecl();
788   if (!D->hasBody())
789     return;
790 
791   // Skip CUDA/HIP kernel launch stub functions.
792   if (CGM.getLangOpts().CUDA && !CGM.getLangOpts().CUDAIsDevice &&
793       D->hasAttr<CUDAGlobalAttr>())
794     return;
795 
796   bool InstrumentRegions = CGM.getCodeGenOpts().hasProfileClangInstr();
797   llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
798   if (!InstrumentRegions && !PGOReader)
799     return;
800   if (D->isImplicit())
801     return;
802   // Constructors and destructors may be represented by several functions in IR.
803   // If so, instrument only base variant, others are implemented by delegation
804   // to the base one, it would be counted twice otherwise.
805   if (CGM.getTarget().getCXXABI().hasConstructorVariants()) {
806     if (const auto *CCD = dyn_cast<CXXConstructorDecl>(D))
807       if (GD.getCtorType() != Ctor_Base &&
808           CodeGenFunction::IsConstructorDelegationValid(CCD))
809         return;
810   }
811   if (isa<CXXDestructorDecl>(D) && GD.getDtorType() != Dtor_Base)
812     return;
813 
814   CGM.ClearUnusedCoverageMapping(D);
815   if (Fn->hasFnAttribute(llvm::Attribute::NoProfile))
816     return;
817 
818   setFuncName(Fn);
819 
820   mapRegionCounters(D);
821   if (CGM.getCodeGenOpts().CoverageMapping)
822     emitCounterRegionMapping(D);
823   if (PGOReader) {
824     SourceManager &SM = CGM.getContext().getSourceManager();
825     loadRegionCounts(PGOReader, SM.isInMainFile(D->getLocation()));
826     computeRegionCounts(D);
827     applyFunctionAttributes(PGOReader, Fn);
828   }
829 }
830 
mapRegionCounters(const Decl * D)831 void CodeGenPGO::mapRegionCounters(const Decl *D) {
832   // Use the latest hash version when inserting instrumentation, but use the
833   // version in the indexed profile if we're reading PGO data.
834   PGOHashVersion HashVersion = PGO_HASH_LATEST;
835   uint64_t ProfileVersion = llvm::IndexedInstrProf::Version;
836   if (auto *PGOReader = CGM.getPGOReader()) {
837     HashVersion = getPGOHashVersion(PGOReader, CGM);
838     ProfileVersion = PGOReader->getVersion();
839   }
840 
841   RegionCounterMap.reset(new llvm::DenseMap<const Stmt *, unsigned>);
842   MapRegionCounters Walker(HashVersion, ProfileVersion, *RegionCounterMap);
843   if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
844     Walker.TraverseDecl(const_cast<FunctionDecl *>(FD));
845   else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
846     Walker.TraverseDecl(const_cast<ObjCMethodDecl *>(MD));
847   else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
848     Walker.TraverseDecl(const_cast<BlockDecl *>(BD));
849   else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
850     Walker.TraverseDecl(const_cast<CapturedDecl *>(CD));
851   assert(Walker.NextCounter > 0 && "no entry counter mapped for decl");
852   NumRegionCounters = Walker.NextCounter;
853   FunctionHash = Walker.Hash.finalize();
854 }
855 
skipRegionMappingForDecl(const Decl * D)856 bool CodeGenPGO::skipRegionMappingForDecl(const Decl *D) {
857   if (!D->getBody())
858     return true;
859 
860   // Skip host-only functions in the CUDA device compilation and device-only
861   // functions in the host compilation. Just roughly filter them out based on
862   // the function attributes. If there are effectively host-only or device-only
863   // ones, their coverage mapping may still be generated.
864   if (CGM.getLangOpts().CUDA &&
865       ((CGM.getLangOpts().CUDAIsDevice && !D->hasAttr<CUDADeviceAttr>() &&
866         !D->hasAttr<CUDAGlobalAttr>()) ||
867        (!CGM.getLangOpts().CUDAIsDevice &&
868         (D->hasAttr<CUDAGlobalAttr>() ||
869          (!D->hasAttr<CUDAHostAttr>() && D->hasAttr<CUDADeviceAttr>())))))
870     return true;
871 
872   // Don't map the functions in system headers.
873   const auto &SM = CGM.getContext().getSourceManager();
874   auto Loc = D->getBody()->getBeginLoc();
875   return SM.isInSystemHeader(Loc);
876 }
877 
emitCounterRegionMapping(const Decl * D)878 void CodeGenPGO::emitCounterRegionMapping(const Decl *D) {
879   if (skipRegionMappingForDecl(D))
880     return;
881 
882   std::string CoverageMapping;
883   llvm::raw_string_ostream OS(CoverageMapping);
884   CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
885                                 CGM.getContext().getSourceManager(),
886                                 CGM.getLangOpts(), RegionCounterMap.get());
887   MappingGen.emitCounterMapping(D, OS);
888   OS.flush();
889 
890   if (CoverageMapping.empty())
891     return;
892 
893   CGM.getCoverageMapping()->addFunctionMappingRecord(
894       FuncNameVar, FuncName, FunctionHash, CoverageMapping);
895 }
896 
897 void
emitEmptyCounterMapping(const Decl * D,StringRef Name,llvm::GlobalValue::LinkageTypes Linkage)898 CodeGenPGO::emitEmptyCounterMapping(const Decl *D, StringRef Name,
899                                     llvm::GlobalValue::LinkageTypes Linkage) {
900   if (skipRegionMappingForDecl(D))
901     return;
902 
903   std::string CoverageMapping;
904   llvm::raw_string_ostream OS(CoverageMapping);
905   CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
906                                 CGM.getContext().getSourceManager(),
907                                 CGM.getLangOpts());
908   MappingGen.emitEmptyMapping(D, OS);
909   OS.flush();
910 
911   if (CoverageMapping.empty())
912     return;
913 
914   setFuncName(Name, Linkage);
915   CGM.getCoverageMapping()->addFunctionMappingRecord(
916       FuncNameVar, FuncName, FunctionHash, CoverageMapping, false);
917 }
918 
computeRegionCounts(const Decl * D)919 void CodeGenPGO::computeRegionCounts(const Decl *D) {
920   StmtCountMap.reset(new llvm::DenseMap<const Stmt *, uint64_t>);
921   ComputeRegionCounts Walker(*StmtCountMap, *this);
922   if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
923     Walker.VisitFunctionDecl(FD);
924   else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
925     Walker.VisitObjCMethodDecl(MD);
926   else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
927     Walker.VisitBlockDecl(BD);
928   else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
929     Walker.VisitCapturedDecl(const_cast<CapturedDecl *>(CD));
930 }
931 
932 void
applyFunctionAttributes(llvm::IndexedInstrProfReader * PGOReader,llvm::Function * Fn)933 CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader,
934                                     llvm::Function *Fn) {
935   if (!haveRegionCounts())
936     return;
937 
938   uint64_t FunctionCount = getRegionCount(nullptr);
939   Fn->setEntryCount(FunctionCount);
940 }
941 
emitCounterIncrement(CGBuilderTy & Builder,const Stmt * S,llvm::Value * StepV)942 void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, const Stmt *S,
943                                       llvm::Value *StepV) {
944   if (!CGM.getCodeGenOpts().hasProfileClangInstr() || !RegionCounterMap)
945     return;
946   if (!Builder.GetInsertBlock())
947     return;
948 
949   unsigned Counter = (*RegionCounterMap)[S];
950   auto *I8PtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext());
951 
952   llvm::Value *Args[] = {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy),
953                          Builder.getInt64(FunctionHash),
954                          Builder.getInt32(NumRegionCounters),
955                          Builder.getInt32(Counter), StepV};
956   if (!StepV)
957     Builder.CreateCall(CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment),
958                        makeArrayRef(Args, 4));
959   else
960     Builder.CreateCall(
961         CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment_step),
962         makeArrayRef(Args));
963 }
964 
setValueProfilingFlag(llvm::Module & M)965 void CodeGenPGO::setValueProfilingFlag(llvm::Module &M) {
966   if (CGM.getCodeGenOpts().hasProfileClangInstr())
967     M.addModuleFlag(llvm::Module::Warning, "EnableValueProfiling",
968                     uint32_t(EnableValueProfiling));
969 }
970 
971 // This method either inserts a call to the profile run-time during
972 // instrumentation or puts profile data into metadata for PGO use.
valueProfile(CGBuilderTy & Builder,uint32_t ValueKind,llvm::Instruction * ValueSite,llvm::Value * ValuePtr)973 void CodeGenPGO::valueProfile(CGBuilderTy &Builder, uint32_t ValueKind,
974     llvm::Instruction *ValueSite, llvm::Value *ValuePtr) {
975 
976   if (!EnableValueProfiling)
977     return;
978 
979   if (!ValuePtr || !ValueSite || !Builder.GetInsertBlock())
980     return;
981 
982   if (isa<llvm::Constant>(ValuePtr))
983     return;
984 
985   bool InstrumentValueSites = CGM.getCodeGenOpts().hasProfileClangInstr();
986   if (InstrumentValueSites && RegionCounterMap) {
987     auto BuilderInsertPoint = Builder.saveIP();
988     Builder.SetInsertPoint(ValueSite);
989     llvm::Value *Args[5] = {
990         llvm::ConstantExpr::getBitCast(FuncNameVar, Builder.getInt8PtrTy()),
991         Builder.getInt64(FunctionHash),
992         Builder.CreatePtrToInt(ValuePtr, Builder.getInt64Ty()),
993         Builder.getInt32(ValueKind),
994         Builder.getInt32(NumValueSites[ValueKind]++)
995     };
996     Builder.CreateCall(
997         CGM.getIntrinsic(llvm::Intrinsic::instrprof_value_profile), Args);
998     Builder.restoreIP(BuilderInsertPoint);
999     return;
1000   }
1001 
1002   llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
1003   if (PGOReader && haveRegionCounts()) {
1004     // We record the top most called three functions at each call site.
1005     // Profile metadata contains "VP" string identifying this metadata
1006     // as value profiling data, then a uint32_t value for the value profiling
1007     // kind, a uint64_t value for the total number of times the call is
1008     // executed, followed by the function hash and execution count (uint64_t)
1009     // pairs for each function.
1010     if (NumValueSites[ValueKind] >= ProfRecord->getNumValueSites(ValueKind))
1011       return;
1012 
1013     llvm::annotateValueSite(CGM.getModule(), *ValueSite, *ProfRecord,
1014                             (llvm::InstrProfValueKind)ValueKind,
1015                             NumValueSites[ValueKind]);
1016 
1017     NumValueSites[ValueKind]++;
1018   }
1019 }
1020 
loadRegionCounts(llvm::IndexedInstrProfReader * PGOReader,bool IsInMainFile)1021 void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader,
1022                                   bool IsInMainFile) {
1023   CGM.getPGOStats().addVisited(IsInMainFile);
1024   RegionCounts.clear();
1025   llvm::Expected<llvm::InstrProfRecord> RecordExpected =
1026       PGOReader->getInstrProfRecord(FuncName, FunctionHash);
1027   if (auto E = RecordExpected.takeError()) {
1028     auto IPE = llvm::InstrProfError::take(std::move(E));
1029     if (IPE == llvm::instrprof_error::unknown_function)
1030       CGM.getPGOStats().addMissing(IsInMainFile);
1031     else if (IPE == llvm::instrprof_error::hash_mismatch)
1032       CGM.getPGOStats().addMismatched(IsInMainFile);
1033     else if (IPE == llvm::instrprof_error::malformed)
1034       // TODO: Consider a more specific warning for this case.
1035       CGM.getPGOStats().addMismatched(IsInMainFile);
1036     return;
1037   }
1038   ProfRecord =
1039       std::make_unique<llvm::InstrProfRecord>(std::move(RecordExpected.get()));
1040   RegionCounts = ProfRecord->Counts;
1041 }
1042 
/// Compute the divisor used to bring 64-bit branch weights into 32-bit range.
///
/// Returns 1 when \p MaxWeight is already below UINT32_MAX. Otherwise returns
/// MaxWeight / UINT32_MAX + 1. With that divisor, every weight no greater than
/// \p MaxWeight divides to a value strictly less than UINT32_MAX.
static uint64_t calculateWeightScale(uint64_t MaxWeight) {
  if (MaxWeight < UINT32_MAX)
    return 1;
  return MaxWeight / UINT32_MAX + 1;
}
1050 
/// Scale an individual branch weight (and add 1).
///
/// Scales a 64-bit weight down to 32 bits by dividing by \c Scale.
///
/// According to Laplace's Rule of Succession, it is better to compute the
/// weight from the count plus 1, so 1 is always added to the result.
///
/// \pre \c Scale is non-zero.
/// \pre \c Scale was calculated by \a calculateWeightScale() with a weight no
/// LESS than \c Weight (callers pass the maximum of all weights being scaled).
/// Otherwise the result may not fit in 32 bits.
static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) {
  assert(Scale && "scale by 0?");
  uint64_t Scaled = Weight / Scale + 1;
  assert(Scaled <= UINT32_MAX && "overflow 32-bits");
  return static_cast<uint32_t>(Scaled);
}
1066 
createProfileWeights(uint64_t TrueCount,uint64_t FalseCount) const1067 llvm::MDNode *CodeGenFunction::createProfileWeights(uint64_t TrueCount,
1068                                                     uint64_t FalseCount) const {
1069   // Check for empty weights.
1070   if (!TrueCount && !FalseCount)
1071     return nullptr;
1072 
1073   // Calculate how to scale down to 32-bits.
1074   uint64_t Scale = calculateWeightScale(std::max(TrueCount, FalseCount));
1075 
1076   llvm::MDBuilder MDHelper(CGM.getLLVMContext());
1077   return MDHelper.createBranchWeights(scaleBranchWeight(TrueCount, Scale),
1078                                       scaleBranchWeight(FalseCount, Scale));
1079 }
1080 
1081 llvm::MDNode *
createProfileWeights(ArrayRef<uint64_t> Weights) const1082 CodeGenFunction::createProfileWeights(ArrayRef<uint64_t> Weights) const {
1083   // We need at least two elements to create meaningful weights.
1084   if (Weights.size() < 2)
1085     return nullptr;
1086 
1087   // Check for empty weights.
1088   uint64_t MaxWeight = *std::max_element(Weights.begin(), Weights.end());
1089   if (MaxWeight == 0)
1090     return nullptr;
1091 
1092   // Calculate how to scale down to 32-bits.
1093   uint64_t Scale = calculateWeightScale(MaxWeight);
1094 
1095   SmallVector<uint32_t, 16> ScaledWeights;
1096   ScaledWeights.reserve(Weights.size());
1097   for (uint64_t W : Weights)
1098     ScaledWeights.push_back(scaleBranchWeight(W, Scale));
1099 
1100   llvm::MDBuilder MDHelper(CGM.getLLVMContext());
1101   return MDHelper.createBranchWeights(ScaledWeights);
1102 }
1103 
1104 llvm::MDNode *
createProfileWeightsForLoop(const Stmt * Cond,uint64_t LoopCount) const1105 CodeGenFunction::createProfileWeightsForLoop(const Stmt *Cond,
1106                                              uint64_t LoopCount) const {
1107   if (!PGO.haveRegionCounts())
1108     return nullptr;
1109   Optional<uint64_t> CondCount = PGO.getStmtCount(Cond);
1110   if (!CondCount || *CondCount == 0)
1111     return nullptr;
1112   return createProfileWeights(LoopCount,
1113                               std::max(*CondCount, LoopCount) - LoopCount);
1114 }
1115