1 //===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Instrumentation-based profile-guided optimization
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "CodeGenPGO.h"
14 #include "CodeGenFunction.h"
15 #include "CoverageMappingGen.h"
16 #include "clang/AST/RecursiveASTVisitor.h"
17 #include "clang/AST/StmtVisitor.h"
18 #include "llvm/IR/Intrinsics.h"
19 #include "llvm/IR/MDBuilder.h"
20 #include "llvm/Support/CommandLine.h"
21 #include "llvm/Support/Endian.h"
22 #include "llvm/Support/FileSystem.h"
23 #include "llvm/Support/MD5.h"
24
// Hidden knob to additionally enable value-profiling instrumentation (on top
// of the ordinary region counters) when generating a Clang profile.
static llvm::cl::opt<bool>
    EnableValueProfiling("enable-value-profiling", llvm::cl::ZeroOrMore,
                         llvm::cl::desc("Enable value profiling"),
                         llvm::cl::Hidden, llvm::cl::init(false));

using namespace clang;
using namespace CodeGen;
32
setFuncName(StringRef Name,llvm::GlobalValue::LinkageTypes Linkage)33 void CodeGenPGO::setFuncName(StringRef Name,
34 llvm::GlobalValue::LinkageTypes Linkage) {
35 llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
36 FuncName = llvm::getPGOFuncName(
37 Name, Linkage, CGM.getCodeGenOpts().MainFileName,
38 PGOReader ? PGOReader->getVersion() : llvm::IndexedInstrProf::Version);
39
40 // If we're generating a profile, create a variable for the name.
41 if (CGM.getCodeGenOpts().hasProfileClangInstr())
42 FuncNameVar = llvm::createPGOFuncNameVar(CGM.getModule(), Linkage, FuncName);
43 }
44
/// Derive the PGO name from an IR function's name and linkage, then attach
/// it to the function as metadata.
void CodeGenPGO::setFuncName(llvm::Function *Fn) {
  setFuncName(Fn->getName(), Fn->getLinkage());
  // Create PGOFuncName meta data.
  llvm::createPGOFuncNameMetadata(*Fn, FuncName);
}
50
/// The version of the PGO hash algorithm.
///
/// The hash must stay reproducible across releases: when reading an indexed
/// profile we select the version that profile was produced with (see
/// getPGOHashVersion below), and when instrumenting we emit the latest.
enum PGOHashVersion : unsigned {
  PGO_HASH_V1,
  PGO_HASH_V2,
  PGO_HASH_V3,

  // Keep this set to the latest hash version.
  PGO_HASH_LATEST = PGO_HASH_V3
};
60
61 namespace {
/// Stable hasher for PGO region counters.
///
/// PGOHash produces a stable hash of a given function's control flow.
///
/// Changing the output of this hash will invalidate all previously generated
/// profiles -- i.e., don't do it.
///
/// \note When this hash does eventually change (years?), we still need to
/// support old hashes.  We'll need to pull in the version number from the
/// profile data format and use the matching hash function.
class PGOHash {
  // Accumulates up to NumTypesPerWord 6-bit type codes before they are
  // flushed into the MD5 stream (see combine()).
  uint64_t Working;
  // Total number of type codes combined so far.
  unsigned Count;
  PGOHashVersion HashVersion;
  llvm::MD5 MD5;

  static const int NumBitsPerType = 6;
  static const unsigned NumTypesPerWord = sizeof(uint64_t) * 8 / NumBitsPerType;
  static const unsigned TooBig = 1u << NumBitsPerType;

public:
  /// Hash values for AST nodes.
  ///
  /// Distinct values for AST nodes that have region counters attached.
  ///
  /// These values must be stable.  All new members must be added at the end,
  /// and no members should be removed.  Changing the enumeration value for an
  /// AST node will affect the hash of every function that contains that node.
  enum HashType : unsigned char {
    None = 0,
    LabelStmt = 1,
    WhileStmt,
    DoStmt,
    ForStmt,
    CXXForRangeStmt,
    ObjCForCollectionStmt,
    SwitchStmt,
    CaseStmt,
    DefaultStmt,
    IfStmt,
    CXXTryStmt,
    CXXCatchStmt,
    ConditionalOperator,
    BinaryOperatorLAnd,
    BinaryOperatorLOr,
    BinaryConditionalOperator,
    // The preceding values are available with PGO_HASH_V1.

    EndOfScope,
    IfThenBranch,
    IfElseBranch,
    GotoStmt,
    IndirectGotoStmt,
    BreakStmt,
    ContinueStmt,
    ReturnStmt,
    ThrowExpr,
    UnaryOperatorLNot,
    BinaryOperatorLT,
    BinaryOperatorGT,
    BinaryOperatorLE,
    BinaryOperatorGE,
    BinaryOperatorEQ,
    BinaryOperatorNE,
    // The preceding values are available since PGO_HASH_V2.

    // Keep this last.  It's for the static assert that follows.
    LastHashType
  };
  // Every HashType must fit in NumBitsPerType bits.
  static_assert(LastHashType <= TooBig, "Too many types in HashType");

  PGOHash(PGOHashVersion HashVersion)
      : Working(0), Count(0), HashVersion(HashVersion), MD5() {}
  void combine(HashType Type);
  uint64_t finalize();
  PGOHashVersion getHashVersion() const { return HashVersion; }
};
// Out-of-line definitions for the static const members (required when they
// are odr-used prior to C++17 inline variables).
const int PGOHash::NumBitsPerType;
const unsigned PGOHash::NumTypesPerWord;
const unsigned PGOHash::TooBig;
142
143 /// Get the PGO hash version used in the given indexed profile.
getPGOHashVersion(llvm::IndexedInstrProfReader * PGOReader,CodeGenModule & CGM)144 static PGOHashVersion getPGOHashVersion(llvm::IndexedInstrProfReader *PGOReader,
145 CodeGenModule &CGM) {
146 if (PGOReader->getVersion() <= 4)
147 return PGO_HASH_V1;
148 if (PGOReader->getVersion() <= 5)
149 return PGO_HASH_V2;
150 return PGO_HASH_V3;
151 }
152
/// A RecursiveASTVisitor that fills a map of statements to PGO counters.
///
/// NOTE(review): the traversal order and the set of nodes that feed the hash
/// determine the function hash -- any change here invalidates old profiles.
struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> {
  using Base = RecursiveASTVisitor<MapRegionCounters>;

  /// The next counter value to assign.
  unsigned NextCounter;
  /// The function hash.
  PGOHash Hash;
  /// The map of statements to counters.
  llvm::DenseMap<const Stmt *, unsigned> &CounterMap;
  /// The profile version.
  uint64_t ProfileVersion;

  MapRegionCounters(PGOHashVersion HashVersion, uint64_t ProfileVersion,
                    llvm::DenseMap<const Stmt *, unsigned> &CounterMap)
      : NextCounter(0), Hash(HashVersion), CounterMap(CounterMap),
        ProfileVersion(ProfileVersion) {}

  // Blocks and lambdas are handled as separate functions, so we need not
  // traverse them in the parent context.
  bool TraverseBlockExpr(BlockExpr *BE) { return true; }
  bool TraverseLambdaExpr(LambdaExpr *LE) {
    // Traverse the captures, but not the body.
    for (auto C : zip(LE->captures(), LE->capture_inits()))
      TraverseLambdaCapture(LE, &std::get<0>(C), std::get<1>(C));
    return true;
  }
  bool TraverseCapturedStmt(CapturedStmt *CS) { return true; }

  /// Assign a counter to the body of each function-like declaration.
  bool VisitDecl(const Decl *D) {
    switch (D->getKind()) {
    default:
      break;
    case Decl::Function:
    case Decl::CXXMethod:
    case Decl::CXXConstructor:
    case Decl::CXXDestructor:
    case Decl::CXXConversion:
    case Decl::ObjCMethod:
    case Decl::Block:
    case Decl::Captured:
      CounterMap[D->getBody()] = NextCounter++;
      break;
    }
    return true;
  }

  /// If \p S gets a fresh counter, update the counter mappings. Return the
  /// V1 hash of \p S.
  PGOHash::HashType updateCounterMappings(Stmt *S) {
    // Counter assignment always follows the V1 node set so that the counter
    // layout stays compatible; only the hash varies by version.
    auto Type = getHashType(PGO_HASH_V1, S);
    if (Type != PGOHash::None)
      CounterMap[S] = NextCounter++;
    return Type;
  }

  /// The RHS of all logical operators gets a fresh counter in order to count
  /// how many times the RHS evaluates to true or false, depending on the
  /// semantics of the operator. This is only valid for ">= v7" of the profile
  /// version so that we facilitate backward compatibility.
  bool VisitBinaryOperator(BinaryOperator *S) {
    if (ProfileVersion >= llvm::IndexedInstrProf::Version7)
      if (S->isLogicalOp() &&
          CodeGenFunction::isInstrumentedCondition(S->getRHS()))
        CounterMap[S->getRHS()] = NextCounter++;
    return Base::VisitBinaryOperator(S);
  }

  /// Include \p S in the function hash.
  bool VisitStmt(Stmt *S) {
    auto Type = updateCounterMappings(S);
    // For V2+ hashes, re-classify the node with the wider node set before
    // folding it into the hash.
    if (Hash.getHashVersion() != PGO_HASH_V1)
      Type = getHashType(Hash.getHashVersion(), S);
    if (Type != PGOHash::None)
      Hash.combine(Type);
    return true;
  }

  bool TraverseIfStmt(IfStmt *If) {
    // If we used the V1 hash, use the default traversal.
    if (Hash.getHashVersion() == PGO_HASH_V1)
      return Base::TraverseIfStmt(If);

    // Otherwise, keep track of which branch we're in while traversing.
    VisitStmt(If);
    for (Stmt *CS : If->children()) {
      if (!CS)
        continue;
      if (CS == If->getThen())
        Hash.combine(PGOHash::IfThenBranch);
      else if (CS == If->getElse())
        Hash.combine(PGOHash::IfElseBranch);
      TraverseStmt(CS);
    }
    Hash.combine(PGOHash::EndOfScope);
    return true;
  }

// If the statement type \p N is nestable, and its nesting impacts profile
// stability, define a custom traversal which tracks the end of the statement
// in the hash (provided we're not using the V1 hash).
#define DEFINE_NESTABLE_TRAVERSAL(N)                                           \
  bool Traverse##N(N *S) {                                                     \
    Base::Traverse##N(S);                                                      \
    if (Hash.getHashVersion() != PGO_HASH_V1)                                  \
      Hash.combine(PGOHash::EndOfScope);                                       \
    return true;                                                               \
  }

  DEFINE_NESTABLE_TRAVERSAL(WhileStmt)
  DEFINE_NESTABLE_TRAVERSAL(DoStmt)
  DEFINE_NESTABLE_TRAVERSAL(ForStmt)
  DEFINE_NESTABLE_TRAVERSAL(CXXForRangeStmt)
  DEFINE_NESTABLE_TRAVERSAL(ObjCForCollectionStmt)
  DEFINE_NESTABLE_TRAVERSAL(CXXTryStmt)
  DEFINE_NESTABLE_TRAVERSAL(CXXCatchStmt)

  /// Get version \p HashVersion of the PGO hash for \p S.
  PGOHash::HashType getHashType(PGOHashVersion HashVersion, const Stmt *S) {
    switch (S->getStmtClass()) {
    default:
      break;
    case Stmt::LabelStmtClass:
      return PGOHash::LabelStmt;
    case Stmt::WhileStmtClass:
      return PGOHash::WhileStmt;
    case Stmt::DoStmtClass:
      return PGOHash::DoStmt;
    case Stmt::ForStmtClass:
      return PGOHash::ForStmt;
    case Stmt::CXXForRangeStmtClass:
      return PGOHash::CXXForRangeStmt;
    case Stmt::ObjCForCollectionStmtClass:
      return PGOHash::ObjCForCollectionStmt;
    case Stmt::SwitchStmtClass:
      return PGOHash::SwitchStmt;
    case Stmt::CaseStmtClass:
      return PGOHash::CaseStmt;
    case Stmt::DefaultStmtClass:
      return PGOHash::DefaultStmt;
    case Stmt::IfStmtClass:
      return PGOHash::IfStmt;
    case Stmt::CXXTryStmtClass:
      return PGOHash::CXXTryStmt;
    case Stmt::CXXCatchStmtClass:
      return PGOHash::CXXCatchStmt;
    case Stmt::ConditionalOperatorClass:
      return PGOHash::ConditionalOperator;
    case Stmt::BinaryConditionalOperatorClass:
      return PGOHash::BinaryConditionalOperator;
    case Stmt::BinaryOperatorClass: {
      const BinaryOperator *BO = cast<BinaryOperator>(S);
      if (BO->getOpcode() == BO_LAnd)
        return PGOHash::BinaryOperatorLAnd;
      if (BO->getOpcode() == BO_LOr)
        return PGOHash::BinaryOperatorLOr;
      // Comparison operators only contribute to V2+ hashes.
      if (HashVersion >= PGO_HASH_V2) {
        switch (BO->getOpcode()) {
        default:
          break;
        case BO_LT:
          return PGOHash::BinaryOperatorLT;
        case BO_GT:
          return PGOHash::BinaryOperatorGT;
        case BO_LE:
          return PGOHash::BinaryOperatorLE;
        case BO_GE:
          return PGOHash::BinaryOperatorGE;
        case BO_EQ:
          return PGOHash::BinaryOperatorEQ;
        case BO_NE:
          return PGOHash::BinaryOperatorNE;
        }
      }
      break;
    }
    }

    // Statement classes that only exist in the V2+ node set.
    if (HashVersion >= PGO_HASH_V2) {
      switch (S->getStmtClass()) {
      default:
        break;
      case Stmt::GotoStmtClass:
        return PGOHash::GotoStmt;
      case Stmt::IndirectGotoStmtClass:
        return PGOHash::IndirectGotoStmt;
      case Stmt::BreakStmtClass:
        return PGOHash::BreakStmt;
      case Stmt::ContinueStmtClass:
        return PGOHash::ContinueStmt;
      case Stmt::ReturnStmtClass:
        return PGOHash::ReturnStmt;
      case Stmt::CXXThrowExprClass:
        return PGOHash::ThrowExpr;
      case Stmt::UnaryOperatorClass: {
        const UnaryOperator *UO = cast<UnaryOperator>(S);
        if (UO->getOpcode() == UO_LNot)
          return PGOHash::UnaryOperatorLNot;
        break;
      }
      }
    }

    return PGOHash::None;
  }
};
359
/// A StmtVisitor that propagates the raw counts through the AST and
/// records the count at statements where the value may change.
///
/// Only loop bodies, branch arms, and similar "regions" carry real counters;
/// this visitor derives the counts for everything else by adding and
/// subtracting counter values along the control-flow edges.
struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
  /// PGO state.
  CodeGenPGO &PGO;

  /// A flag that is set when the current count should be recorded on the
  /// next statement, such as at the exit of a loop.
  bool RecordNextStmtCount;

  /// The count at the current location in the traversal.
  uint64_t CurrentCount;

  /// The map of statements to count values.
  llvm::DenseMap<const Stmt *, uint64_t> &CountMap;

  /// BreakContinueStack - Keep counts of breaks and continues inside loops.
  struct BreakContinue {
    uint64_t BreakCount;
    uint64_t ContinueCount;
    BreakContinue() : BreakCount(0), ContinueCount(0) {}
  };
  SmallVector<BreakContinue, 8> BreakContinueStack;

  ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap,
                      CodeGenPGO &PGO)
      : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {}

  // Record the current count on \p S if a preceding statement (return, goto,
  // loop exit, ...) requested that the next statement capture it.
  void RecordStmtCount(const Stmt *S) {
    if (RecordNextStmtCount) {
      CountMap[S] = CurrentCount;
      RecordNextStmtCount = false;
    }
  }

  /// Set and return the current count.
  uint64_t setCount(uint64_t Count) {
    CurrentCount = Count;
    return Count;
  }

  // Default: propagate the current count through all children.
  void VisitStmt(const Stmt *S) {
    RecordStmtCount(S);
    for (const Stmt *Child : S->children())
      if (Child)
        this->Visit(Child);
  }

  void VisitFunctionDecl(const FunctionDecl *D) {
    // Counter tracks entry to the function body.
    uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
    CountMap[D->getBody()] = BodyCount;
    Visit(D->getBody());
  }

  // Skip lambda expressions. We visit these as FunctionDecls when we're
  // generating them and aren't interested in the body when generating a
  // parent context.
  void VisitLambdaExpr(const LambdaExpr *LE) {}

  void VisitCapturedDecl(const CapturedDecl *D) {
    // Counter tracks entry to the capture body.
    uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
    CountMap[D->getBody()] = BodyCount;
    Visit(D->getBody());
  }

  void VisitObjCMethodDecl(const ObjCMethodDecl *D) {
    // Counter tracks entry to the method body.
    uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
    CountMap[D->getBody()] = BodyCount;
    Visit(D->getBody());
  }

  void VisitBlockDecl(const BlockDecl *D) {
    // Counter tracks entry to the block body.
    uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
    CountMap[D->getBody()] = BodyCount;
    Visit(D->getBody());
  }

  void VisitReturnStmt(const ReturnStmt *S) {
    RecordStmtCount(S);
    if (S->getRetValue())
      Visit(S->getRetValue());
    // Control never falls through a return; anything after it starts at 0.
    CurrentCount = 0;
    RecordNextStmtCount = true;
  }

  void VisitCXXThrowExpr(const CXXThrowExpr *E) {
    RecordStmtCount(E);
    if (E->getSubExpr())
      Visit(E->getSubExpr());
    // A throw terminates the current path, like a return.
    CurrentCount = 0;
    RecordNextStmtCount = true;
  }

  void VisitGotoStmt(const GotoStmt *S) {
    RecordStmtCount(S);
    CurrentCount = 0;
    RecordNextStmtCount = true;
  }

  void VisitLabelStmt(const LabelStmt *S) {
    RecordNextStmtCount = false;
    // Counter tracks the block following the label.
    uint64_t BlockCount = setCount(PGO.getRegionCount(S));
    CountMap[S] = BlockCount;
    Visit(S->getSubStmt());
  }

  void VisitBreakStmt(const BreakStmt *S) {
    RecordStmtCount(S);
    assert(!BreakContinueStack.empty() && "break not in a loop or switch!");
    // Credit this path's count to the innermost loop/switch exit.
    BreakContinueStack.back().BreakCount += CurrentCount;
    CurrentCount = 0;
    RecordNextStmtCount = true;
  }

  void VisitContinueStmt(const ContinueStmt *S) {
    RecordStmtCount(S);
    assert(!BreakContinueStack.empty() && "continue stmt not in a loop!");
    // Credit this path's count to the innermost loop's continue target.
    BreakContinueStack.back().ContinueCount += CurrentCount;
    CurrentCount = 0;
    RecordNextStmtCount = true;
  }

  void VisitWhileStmt(const WhileStmt *S) {
    RecordStmtCount(S);
    uint64_t ParentCount = CurrentCount;

    BreakContinueStack.push_back(BreakContinue());
    // Visit the body region first so the break/continue adjustments can be
    // included when visiting the condition.
    uint64_t BodyCount = setCount(PGO.getRegionCount(S));
    CountMap[S->getBody()] = CurrentCount;
    Visit(S->getBody());
    uint64_t BackedgeCount = CurrentCount;

    // ...then go back and propagate counts through the condition. The count
    // at the start of the condition is the sum of the incoming edges,
    // the backedge from the end of the loop body, and the edges from
    // continue statements.
    BreakContinue BC = BreakContinueStack.pop_back_val();
    uint64_t CondCount =
        setCount(ParentCount + BackedgeCount + BC.ContinueCount);
    CountMap[S->getCond()] = CondCount;
    Visit(S->getCond());
    // Loop exit = breaks plus the condition evaluations that didn't enter
    // the body.
    setCount(BC.BreakCount + CondCount - BodyCount);
    RecordNextStmtCount = true;
  }

  void VisitDoStmt(const DoStmt *S) {
    RecordStmtCount(S);
    uint64_t LoopCount = PGO.getRegionCount(S);

    BreakContinueStack.push_back(BreakContinue());
    // The count doesn't include the fallthrough from the parent scope. Add it.
    uint64_t BodyCount = setCount(LoopCount + CurrentCount);
    CountMap[S->getBody()] = BodyCount;
    Visit(S->getBody());
    uint64_t BackedgeCount = CurrentCount;

    BreakContinue BC = BreakContinueStack.pop_back_val();
    // The count at the start of the condition is equal to the count at the
    // end of the body, plus any continues.
    uint64_t CondCount = setCount(BackedgeCount + BC.ContinueCount);
    CountMap[S->getCond()] = CondCount;
    Visit(S->getCond());
    setCount(BC.BreakCount + CondCount - LoopCount);
    RecordNextStmtCount = true;
  }

  void VisitForStmt(const ForStmt *S) {
    RecordStmtCount(S);
    if (S->getInit())
      Visit(S->getInit());

    uint64_t ParentCount = CurrentCount;

    BreakContinueStack.push_back(BreakContinue());
    // Visit the body region first. (This is basically the same as a while
    // loop; see further comments in VisitWhileStmt.)
    uint64_t BodyCount = setCount(PGO.getRegionCount(S));
    CountMap[S->getBody()] = BodyCount;
    Visit(S->getBody());
    uint64_t BackedgeCount = CurrentCount;
    BreakContinue BC = BreakContinueStack.pop_back_val();

    // The increment is essentially part of the body but it needs to include
    // the count for all the continue statements.
    if (S->getInc()) {
      uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
      CountMap[S->getInc()] = IncCount;
      Visit(S->getInc());
    }

    // ...then go back and propagate counts through the condition.
    uint64_t CondCount =
        setCount(ParentCount + BackedgeCount + BC.ContinueCount);
    if (S->getCond()) {
      CountMap[S->getCond()] = CondCount;
      Visit(S->getCond());
    }
    setCount(BC.BreakCount + CondCount - BodyCount);
    RecordNextStmtCount = true;
  }

  void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
    RecordStmtCount(S);
    if (S->getInit())
      Visit(S->getInit());
    Visit(S->getLoopVarStmt());
    Visit(S->getRangeStmt());
    Visit(S->getBeginStmt());
    Visit(S->getEndStmt());

    uint64_t ParentCount = CurrentCount;
    BreakContinueStack.push_back(BreakContinue());
    // Visit the body region first. (This is basically the same as a while
    // loop; see further comments in VisitWhileStmt.)
    uint64_t BodyCount = setCount(PGO.getRegionCount(S));
    CountMap[S->getBody()] = BodyCount;
    Visit(S->getBody());
    uint64_t BackedgeCount = CurrentCount;
    BreakContinue BC = BreakContinueStack.pop_back_val();

    // The increment is essentially part of the body but it needs to include
    // the count for all the continue statements.
    uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
    CountMap[S->getInc()] = IncCount;
    Visit(S->getInc());

    // ...then go back and propagate counts through the condition.
    uint64_t CondCount =
        setCount(ParentCount + BackedgeCount + BC.ContinueCount);
    CountMap[S->getCond()] = CondCount;
    Visit(S->getCond());
    setCount(BC.BreakCount + CondCount - BodyCount);
    RecordNextStmtCount = true;
  }

  void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
    RecordStmtCount(S);
    Visit(S->getElement());
    uint64_t ParentCount = CurrentCount;
    BreakContinueStack.push_back(BreakContinue());
    // Counter tracks the body of the loop.
    uint64_t BodyCount = setCount(PGO.getRegionCount(S));
    CountMap[S->getBody()] = BodyCount;
    Visit(S->getBody());
    uint64_t BackedgeCount = CurrentCount;
    BreakContinue BC = BreakContinueStack.pop_back_val();

    setCount(BC.BreakCount + ParentCount + BackedgeCount + BC.ContinueCount -
             BodyCount);
    RecordNextStmtCount = true;
  }

  void VisitSwitchStmt(const SwitchStmt *S) {
    RecordStmtCount(S);
    if (S->getInit())
      Visit(S->getInit());
    Visit(S->getCond());
    // No statement in the body is reached by fallthrough from the switch
    // header; each case starts from its own counter.
    CurrentCount = 0;
    BreakContinueStack.push_back(BreakContinue());
    Visit(S->getBody());
    // If the switch is inside a loop, add the continue counts.
    BreakContinue BC = BreakContinueStack.pop_back_val();
    if (!BreakContinueStack.empty())
      BreakContinueStack.back().ContinueCount += BC.ContinueCount;
    // Counter tracks the exit block of the switch.
    setCount(PGO.getRegionCount(S));
    RecordNextStmtCount = true;
  }

  void VisitSwitchCase(const SwitchCase *S) {
    RecordNextStmtCount = false;
    // Counter for this particular case. This counts only jumps from the
    // switch header and does not include fallthrough from the case before
    // this one.
    uint64_t CaseCount = PGO.getRegionCount(S);
    setCount(CurrentCount + CaseCount);
    // We need the count without fallthrough in the mapping, so it's more useful
    // for branch probabilities.
    CountMap[S] = CaseCount;
    RecordNextStmtCount = true;
    Visit(S->getSubStmt());
  }

  void VisitIfStmt(const IfStmt *S) {
    RecordStmtCount(S);
    uint64_t ParentCount = CurrentCount;
    if (S->getInit())
      Visit(S->getInit());
    Visit(S->getCond());

    // Counter tracks the "then" part of an if statement. The count for
    // the "else" part, if it exists, will be calculated from this counter.
    uint64_t ThenCount = setCount(PGO.getRegionCount(S));
    CountMap[S->getThen()] = ThenCount;
    Visit(S->getThen());
    uint64_t OutCount = CurrentCount;

    uint64_t ElseCount = ParentCount - ThenCount;
    if (S->getElse()) {
      setCount(ElseCount);
      CountMap[S->getElse()] = ElseCount;
      Visit(S->getElse());
      OutCount += CurrentCount;
    } else
      OutCount += ElseCount;
    setCount(OutCount);
    RecordNextStmtCount = true;
  }

  void VisitCXXTryStmt(const CXXTryStmt *S) {
    RecordStmtCount(S);
    Visit(S->getTryBlock());
    for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
      Visit(S->getHandler(I));
    // Counter tracks the continuation block of the try statement.
    setCount(PGO.getRegionCount(S));
    RecordNextStmtCount = true;
  }

  void VisitCXXCatchStmt(const CXXCatchStmt *S) {
    RecordNextStmtCount = false;
    // Counter tracks the catch statement's handler block.
    uint64_t CatchCount = setCount(PGO.getRegionCount(S));
    CountMap[S] = CatchCount;
    Visit(S->getHandlerBlock());
  }

  void VisitAbstractConditionalOperator(const AbstractConditionalOperator *E) {
    RecordStmtCount(E);
    uint64_t ParentCount = CurrentCount;
    Visit(E->getCond());

    // Counter tracks the "true" part of a conditional operator. The
    // count in the "false" part will be calculated from this counter.
    uint64_t TrueCount = setCount(PGO.getRegionCount(E));
    CountMap[E->getTrueExpr()] = TrueCount;
    Visit(E->getTrueExpr());
    uint64_t OutCount = CurrentCount;

    uint64_t FalseCount = setCount(ParentCount - TrueCount);
    CountMap[E->getFalseExpr()] = FalseCount;
    Visit(E->getFalseExpr());
    OutCount += CurrentCount;

    setCount(OutCount);
    RecordNextStmtCount = true;
  }

  void VisitBinLAnd(const BinaryOperator *E) {
    RecordStmtCount(E);
    uint64_t ParentCount = CurrentCount;
    Visit(E->getLHS());
    // Counter tracks the right hand side of a logical and operator.
    uint64_t RHSCount = setCount(PGO.getRegionCount(E));
    CountMap[E->getRHS()] = RHSCount;
    Visit(E->getRHS());
    setCount(ParentCount + RHSCount - CurrentCount);
    RecordNextStmtCount = true;
  }

  void VisitBinLOr(const BinaryOperator *E) {
    RecordStmtCount(E);
    uint64_t ParentCount = CurrentCount;
    Visit(E->getLHS());
    // Counter tracks the right hand side of a logical or operator.
    uint64_t RHSCount = setCount(PGO.getRegionCount(E));
    CountMap[E->getRHS()] = RHSCount;
    Visit(E->getRHS());
    setCount(ParentCount + RHSCount - CurrentCount);
    RecordNextStmtCount = true;
  }
};
739 } // end anonymous namespace
740
/// Fold one 6-bit node-type code into the hash. Codes are packed into a
/// 64-bit word and flushed into the MD5 stream once the word is full, so
/// short functions never pay for MD5 at all (see finalize()).
void PGOHash::combine(HashType Type) {
  // Check that we never combine 0 and only have six bits.
  assert(Type && "Hash is invalid: unexpected type 0");
  assert(unsigned(Type) < TooBig && "Hash is invalid: too many types");

  // Pass through MD5 if enough work has built up.
  if (Count && Count % NumTypesPerWord == 0) {
    using namespace llvm::support;
    // Byte-swap to little-endian so the MD5 input -- and therefore the hash
    // -- is identical on big- and little-endian hosts.
    uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working);
    MD5.update(llvm::makeArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
    Working = 0;
  }

  // Accumulate the current type.
  ++Count;
  Working = Working << NumBitsPerType | Type;
}
758
/// Produce the final 64-bit hash. For functions with few enough nodes the
/// packed word itself is the hash; otherwise the remaining bits are flushed
/// into MD5 and the low 64 bits of the digest are returned.
uint64_t PGOHash::finalize() {
  // Use Working as the hash directly if we never used MD5.
  if (Count <= NumTypesPerWord)
    // No need to byte swap here, since none of the math was endian-dependent.
    // This number will be byte-swapped as required on endianness transitions,
    // so we will see the same value on the other side.
    return Working;

  // Check for remaining work in Working.
  if (Working) {
    // Keep the buggy behavior from v1 and v2 for backward-compatibility. This
    // is buggy because it converts a uint64_t into an array of uint8_t.
    if (HashVersion < PGO_HASH_V3) {
      // Only the low byte of Working reaches the MD5 stream here -- do NOT
      // "fix" this; it would change every V1/V2 hash.
      MD5.update({(uint8_t)Working});
    } else {
      using namespace llvm::support;
      uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working);
      MD5.update(llvm::makeArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
    }
  }

  // Finalize the MD5 and return the hash.
  llvm::MD5::MD5Result Result;
  MD5.final(Result);
  return Result.low();
}
785
/// Set up PGO for the function \p Fn backing declaration \p GD: name the
/// function for profiling, map its statements to counters, and -- depending
/// on mode -- emit a coverage mapping and/or load counts from an indexed
/// profile. Does nothing when neither instrumenting nor using a profile.
void CodeGenPGO::assignRegionCounters(GlobalDecl GD, llvm::Function *Fn) {
  const Decl *D = GD.getDecl();
  if (!D->hasBody())
    return;

  // Skip CUDA/HIP kernel launch stub functions.
  if (CGM.getLangOpts().CUDA && !CGM.getLangOpts().CUDAIsDevice &&
      D->hasAttr<CUDAGlobalAttr>())
    return;

  bool InstrumentRegions = CGM.getCodeGenOpts().hasProfileClangInstr();
  llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
  if (!InstrumentRegions && !PGOReader)
    return;
  // Compiler-synthesized functions carry no user code to profile.
  if (D->isImplicit())
    return;
  // Constructors and destructors may be represented by several functions in IR.
  // If so, instrument only base variant, others are implemented by delegation
  // to the base one, it would be counted twice otherwise.
  if (CGM.getTarget().getCXXABI().hasConstructorVariants()) {
    if (const auto *CCD = dyn_cast<CXXConstructorDecl>(D))
      if (GD.getCtorType() != Ctor_Base &&
          CodeGenFunction::IsConstructorDelegationValid(CCD))
        return;
  }
  if (isa<CXXDestructorDecl>(D) && GD.getDtorType() != Dtor_Base)
    return;

  CGM.ClearUnusedCoverageMapping(D);
  // Respect the IR-level noprofile attribute.
  if (Fn->hasFnAttribute(llvm::Attribute::NoProfile))
    return;

  setFuncName(Fn);

  mapRegionCounters(D);
  if (CGM.getCodeGenOpts().CoverageMapping)
    emitCounterRegionMapping(D);
  if (PGOReader) {
    SourceManager &SM = CGM.getContext().getSourceManager();
    loadRegionCounts(PGOReader, SM.isInMainFile(D->getLocation()));
    computeRegionCounts(D);
    applyFunctionAttributes(PGOReader, Fn);
  }
}
830
/// Walk the AST of \p D, assigning a fresh counter index to each statement
/// that needs one and computing the function's PGO hash along the way. The
/// results land in RegionCounterMap, NumRegionCounters, and FunctionHash.
void CodeGenPGO::mapRegionCounters(const Decl *D) {
  // Use the latest hash version when inserting instrumentation, but use the
  // version in the indexed profile if we're reading PGO data.
  PGOHashVersion HashVersion = PGO_HASH_LATEST;
  uint64_t ProfileVersion = llvm::IndexedInstrProf::Version;
  if (auto *PGOReader = CGM.getPGOReader()) {
    HashVersion = getPGOHashVersion(PGOReader, CGM);
    ProfileVersion = PGOReader->getVersion();
  }

  RegionCounterMap.reset(new llvm::DenseMap<const Stmt *, unsigned>);
  MapRegionCounters Walker(HashVersion, ProfileVersion, *RegionCounterMap);
  // Dispatch on the concrete function-like declaration kind; the const_casts
  // are needed because RecursiveASTVisitor traverses non-const nodes. Other
  // declaration kinds are deliberately not traversed.
  if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
    Walker.TraverseDecl(const_cast<FunctionDecl *>(FD));
  else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
    Walker.TraverseDecl(const_cast<ObjCMethodDecl *>(MD));
  else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
    Walker.TraverseDecl(const_cast<BlockDecl *>(BD));
  else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
    Walker.TraverseDecl(const_cast<CapturedDecl *>(CD));
  assert(Walker.NextCounter > 0 && "no entry counter mapped for decl");
  NumRegionCounters = Walker.NextCounter;
  FunctionHash = Walker.Hash.finalize();
}
855
skipRegionMappingForDecl(const Decl * D)856 bool CodeGenPGO::skipRegionMappingForDecl(const Decl *D) {
857 if (!D->getBody())
858 return true;
859
860 // Skip host-only functions in the CUDA device compilation and device-only
861 // functions in the host compilation. Just roughly filter them out based on
862 // the function attributes. If there are effectively host-only or device-only
863 // ones, their coverage mapping may still be generated.
864 if (CGM.getLangOpts().CUDA &&
865 ((CGM.getLangOpts().CUDAIsDevice && !D->hasAttr<CUDADeviceAttr>() &&
866 !D->hasAttr<CUDAGlobalAttr>()) ||
867 (!CGM.getLangOpts().CUDAIsDevice &&
868 (D->hasAttr<CUDAGlobalAttr>() ||
869 (!D->hasAttr<CUDAHostAttr>() && D->hasAttr<CUDADeviceAttr>())))))
870 return true;
871
872 // Don't map the functions in system headers.
873 const auto &SM = CGM.getContext().getSourceManager();
874 auto Loc = D->getBody()->getBeginLoc();
875 return SM.isInSystemHeader(Loc);
876 }
877
emitCounterRegionMapping(const Decl * D)878 void CodeGenPGO::emitCounterRegionMapping(const Decl *D) {
879 if (skipRegionMappingForDecl(D))
880 return;
881
882 std::string CoverageMapping;
883 llvm::raw_string_ostream OS(CoverageMapping);
884 CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
885 CGM.getContext().getSourceManager(),
886 CGM.getLangOpts(), RegionCounterMap.get());
887 MappingGen.emitCounterMapping(D, OS);
888 OS.flush();
889
890 if (CoverageMapping.empty())
891 return;
892
893 CGM.getCoverageMapping()->addFunctionMappingRecord(
894 FuncNameVar, FuncName, FunctionHash, CoverageMapping);
895 }
896
897 void
emitEmptyCounterMapping(const Decl * D,StringRef Name,llvm::GlobalValue::LinkageTypes Linkage)898 CodeGenPGO::emitEmptyCounterMapping(const Decl *D, StringRef Name,
899 llvm::GlobalValue::LinkageTypes Linkage) {
900 if (skipRegionMappingForDecl(D))
901 return;
902
903 std::string CoverageMapping;
904 llvm::raw_string_ostream OS(CoverageMapping);
905 CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
906 CGM.getContext().getSourceManager(),
907 CGM.getLangOpts());
908 MappingGen.emitEmptyMapping(D, OS);
909 OS.flush();
910
911 if (CoverageMapping.empty())
912 return;
913
914 setFuncName(Name, Linkage);
915 CGM.getCoverageMapping()->addFunctionMappingRecord(
916 FuncNameVar, FuncName, FunctionHash, CoverageMapping, false);
917 }
918
computeRegionCounts(const Decl * D)919 void CodeGenPGO::computeRegionCounts(const Decl *D) {
920 StmtCountMap.reset(new llvm::DenseMap<const Stmt *, uint64_t>);
921 ComputeRegionCounts Walker(*StmtCountMap, *this);
922 if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
923 Walker.VisitFunctionDecl(FD);
924 else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
925 Walker.VisitObjCMethodDecl(MD);
926 else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
927 Walker.VisitBlockDecl(BD);
928 else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
929 Walker.VisitCapturedDecl(const_cast<CapturedDecl *>(CD));
930 }
931
932 void
applyFunctionAttributes(llvm::IndexedInstrProfReader * PGOReader,llvm::Function * Fn)933 CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader,
934 llvm::Function *Fn) {
935 if (!haveRegionCounts())
936 return;
937
938 uint64_t FunctionCount = getRegionCount(nullptr);
939 Fn->setEntryCount(FunctionCount);
940 }
941
emitCounterIncrement(CGBuilderTy & Builder,const Stmt * S,llvm::Value * StepV)942 void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, const Stmt *S,
943 llvm::Value *StepV) {
944 if (!CGM.getCodeGenOpts().hasProfileClangInstr() || !RegionCounterMap)
945 return;
946 if (!Builder.GetInsertBlock())
947 return;
948
949 unsigned Counter = (*RegionCounterMap)[S];
950 auto *I8PtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext());
951
952 llvm::Value *Args[] = {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy),
953 Builder.getInt64(FunctionHash),
954 Builder.getInt32(NumRegionCounters),
955 Builder.getInt32(Counter), StepV};
956 if (!StepV)
957 Builder.CreateCall(CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment),
958 makeArrayRef(Args, 4));
959 else
960 Builder.CreateCall(
961 CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment_step),
962 makeArrayRef(Args));
963 }
964
setValueProfilingFlag(llvm::Module & M)965 void CodeGenPGO::setValueProfilingFlag(llvm::Module &M) {
966 if (CGM.getCodeGenOpts().hasProfileClangInstr())
967 M.addModuleFlag(llvm::Module::Warning, "EnableValueProfiling",
968 uint32_t(EnableValueProfiling));
969 }
970
// This method either inserts a call to the profile run-time during
// instrumentation or puts profile data into metadata for PGO use.
void CodeGenPGO::valueProfile(CGBuilderTy &Builder, uint32_t ValueKind,
                              llvm::Instruction *ValueSite, llvm::Value *ValuePtr) {

  if (!EnableValueProfiling)
    return;

  // Bail out when there is no value to record, no instruction to anchor the
  // site to, or no valid insertion point (unreachable code).
  if (!ValuePtr || !ValueSite || !Builder.GetInsertBlock())
    return;

  // A constant's value is known statically; profiling it adds nothing.
  if (isa<llvm::Constant>(ValuePtr))
    return;

  bool InstrumentValueSites = CGM.getCodeGenOpts().hasProfileClangInstr();
  if (InstrumentValueSites && RegionCounterMap) {
    // Instrumentation mode: emit a call to the value-profiling intrinsic
    // immediately before the value site, saving and restoring the builder's
    // insertion point so surrounding codegen is unaffected. Note the site
    // index for this kind (NumValueSites[ValueKind]) is consumed and
    // incremented here.
    auto BuilderInsertPoint = Builder.saveIP();
    Builder.SetInsertPoint(ValueSite);
    llvm::Value *Args[5] = {
        llvm::ConstantExpr::getBitCast(FuncNameVar, Builder.getInt8PtrTy()),
        Builder.getInt64(FunctionHash),
        Builder.CreatePtrToInt(ValuePtr, Builder.getInt64Ty()),
        Builder.getInt32(ValueKind),
        Builder.getInt32(NumValueSites[ValueKind]++)
    };
    Builder.CreateCall(
        CGM.getIntrinsic(llvm::Intrinsic::instrprof_value_profile), Args);
    Builder.restoreIP(BuilderInsertPoint);
    return;
  }

  // Profile-use mode: annotate the value site with data from the loaded
  // profile record instead of emitting instrumentation.
  llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
  if (PGOReader && haveRegionCounts()) {
    // We record the top most called three functions at each call site.
    // Profile metadata contains "VP" string identifying this metadata
    // as value profiling data, then a uint32_t value for the value profiling
    // kind, a uint64_t value for the total number of times the call is
    // executed, followed by the function hash and execution count (uint64_t)
    // pairs for each function.
    //
    // Sites beyond what the record knows about (e.g. from code changes since
    // profiling) get no annotation.
    if (NumValueSites[ValueKind] >= ProfRecord->getNumValueSites(ValueKind))
      return;

    llvm::annotateValueSite(CGM.getModule(), *ValueSite, *ProfRecord,
                            (llvm::InstrProfValueKind)ValueKind,
                            NumValueSites[ValueKind]);

    NumValueSites[ValueKind]++;
  }
}
1020
loadRegionCounts(llvm::IndexedInstrProfReader * PGOReader,bool IsInMainFile)1021 void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader,
1022 bool IsInMainFile) {
1023 CGM.getPGOStats().addVisited(IsInMainFile);
1024 RegionCounts.clear();
1025 llvm::Expected<llvm::InstrProfRecord> RecordExpected =
1026 PGOReader->getInstrProfRecord(FuncName, FunctionHash);
1027 if (auto E = RecordExpected.takeError()) {
1028 auto IPE = llvm::InstrProfError::take(std::move(E));
1029 if (IPE == llvm::instrprof_error::unknown_function)
1030 CGM.getPGOStats().addMissing(IsInMainFile);
1031 else if (IPE == llvm::instrprof_error::hash_mismatch)
1032 CGM.getPGOStats().addMismatched(IsInMainFile);
1033 else if (IPE == llvm::instrprof_error::malformed)
1034 // TODO: Consider a more specific warning for this case.
1035 CGM.getPGOStats().addMismatched(IsInMainFile);
1036 return;
1037 }
1038 ProfRecord =
1039 std::make_unique<llvm::InstrProfRecord>(std::move(RecordExpected.get()));
1040 RegionCounts = ProfRecord->Counts;
1041 }
1042
/// Calculate what to divide by to scale weights.
///
/// Given the maximum weight, calculate a divisor that will scale all the
/// weights to strictly less than UINT32_MAX.
static uint64_t calculateWeightScale(uint64_t MaxWeight) {
  // Small weights already fit; otherwise pick the smallest divisor that
  // brings MaxWeight strictly under UINT32_MAX.
  if (MaxWeight < UINT32_MAX)
    return 1;
  return MaxWeight / UINT32_MAX + 1;
}
1050
/// Scale an individual branch weight (and add 1).
///
/// Scale a 64-bit weight down to 32-bits using \c Scale.
///
/// According to Laplace's Rule of Succession, it is better to compute the
/// weight based on the count plus 1, so universally add 1 to the value.
///
/// \pre \c Scale was calculated by \a calculateWeightScale() with a weight no
/// greater than \c Weight.
static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) {
  assert(Scale && "scale by 0?");
  const uint64_t Result = Weight / Scale + 1;
  assert(Result <= UINT32_MAX && "overflow 32-bits");
  return Result;
}
1066
createProfileWeights(uint64_t TrueCount,uint64_t FalseCount) const1067 llvm::MDNode *CodeGenFunction::createProfileWeights(uint64_t TrueCount,
1068 uint64_t FalseCount) const {
1069 // Check for empty weights.
1070 if (!TrueCount && !FalseCount)
1071 return nullptr;
1072
1073 // Calculate how to scale down to 32-bits.
1074 uint64_t Scale = calculateWeightScale(std::max(TrueCount, FalseCount));
1075
1076 llvm::MDBuilder MDHelper(CGM.getLLVMContext());
1077 return MDHelper.createBranchWeights(scaleBranchWeight(TrueCount, Scale),
1078 scaleBranchWeight(FalseCount, Scale));
1079 }
1080
1081 llvm::MDNode *
createProfileWeights(ArrayRef<uint64_t> Weights) const1082 CodeGenFunction::createProfileWeights(ArrayRef<uint64_t> Weights) const {
1083 // We need at least two elements to create meaningful weights.
1084 if (Weights.size() < 2)
1085 return nullptr;
1086
1087 // Check for empty weights.
1088 uint64_t MaxWeight = *std::max_element(Weights.begin(), Weights.end());
1089 if (MaxWeight == 0)
1090 return nullptr;
1091
1092 // Calculate how to scale down to 32-bits.
1093 uint64_t Scale = calculateWeightScale(MaxWeight);
1094
1095 SmallVector<uint32_t, 16> ScaledWeights;
1096 ScaledWeights.reserve(Weights.size());
1097 for (uint64_t W : Weights)
1098 ScaledWeights.push_back(scaleBranchWeight(W, Scale));
1099
1100 llvm::MDBuilder MDHelper(CGM.getLLVMContext());
1101 return MDHelper.createBranchWeights(ScaledWeights);
1102 }
1103
1104 llvm::MDNode *
createProfileWeightsForLoop(const Stmt * Cond,uint64_t LoopCount) const1105 CodeGenFunction::createProfileWeightsForLoop(const Stmt *Cond,
1106 uint64_t LoopCount) const {
1107 if (!PGO.haveRegionCounts())
1108 return nullptr;
1109 Optional<uint64_t> CondCount = PGO.getStmtCount(Cond);
1110 if (!CondCount || *CondCount == 0)
1111 return nullptr;
1112 return createProfileWeights(LoopCount,
1113 std::max(*CondCount, LoopCount) - LoopCount);
1114 }
1115