//===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Instrumentation-based profile-guided optimization
//
//===----------------------------------------------------------------------===//

#include "CodeGenPGO.h"
#include "CodeGenFunction.h"
#include "CoverageMappingGen.h"
#include "clang/AST/RecursiveASTVisitor.h"
#include "clang/AST/StmtVisitor.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Endian.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/MD5.h"
#include <optional>

static llvm::cl::opt<bool>
    EnableValueProfiling("enable-value-profiling",
                         llvm::cl::desc("Enable value profiling"),
                         llvm::cl::Hidden, llvm::cl::init(false));

using namespace clang;
using namespace CodeGen;

void CodeGenPGO::setFuncName(StringRef Name,
                             llvm::GlobalValue::LinkageTypes Linkage) {
  llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
  FuncName = llvm::getPGOFuncName(
      Name, Linkage, CGM.getCodeGenOpts().MainFileName,
      PGOReader ? PGOReader->getVersion() : llvm::IndexedInstrProf::Version);

  // If we're generating a profile, create a variable for the name.
  if (CGM.getCodeGenOpts().hasProfileClangInstr())
    FuncNameVar = llvm::createPGOFuncNameVar(CGM.getModule(), Linkage, FuncName);
}

void CodeGenPGO::setFuncName(llvm::Function *Fn) {
  setFuncName(Fn->getName(), Fn->getLinkage());
  // Create PGOFuncName metadata.
  llvm::createPGOFuncNameMetadata(*Fn, FuncName);
}

/// The version of the PGO hash algorithm.
enum PGOHashVersion : unsigned {
  PGO_HASH_V1,
  PGO_HASH_V2,
  PGO_HASH_V3,

  // Keep this set to the latest hash version.
  PGO_HASH_LATEST = PGO_HASH_V3
};

namespace {
/// Stable hasher for PGO region counters.
///
/// PGOHash produces a stable hash of a given function's control flow.
///
/// Changing the output of this hash will invalidate all previously generated
/// profiles -- i.e., don't do it.
///
/// \note When this hash does eventually change (years?), we still need to
/// support old hashes. We'll need to pull in the version number from the
/// profile data format and use the matching hash function.
class PGOHash {
  uint64_t Working;
  unsigned Count;
  PGOHashVersion HashVersion;
  llvm::MD5 MD5;

  static const int NumBitsPerType = 6;
  static const unsigned NumTypesPerWord = sizeof(uint64_t) * 8 / NumBitsPerType;
  static const unsigned TooBig = 1u << NumBitsPerType;

public:
  /// Hash values for AST nodes.
  ///
  /// Distinct values for AST nodes that have region counters attached.
  ///
  /// These values must be stable. All new members must be added at the end,
  /// and no members should be removed. Changing the enumeration value for an
  /// AST node will affect the hash of every function that contains that node.
  enum HashType : unsigned char {
    None = 0,
    LabelStmt = 1,
    WhileStmt,
    DoStmt,
    ForStmt,
    CXXForRangeStmt,
    ObjCForCollectionStmt,
    SwitchStmt,
    CaseStmt,
    DefaultStmt,
    IfStmt,
    CXXTryStmt,
    CXXCatchStmt,
    ConditionalOperator,
    BinaryOperatorLAnd,
    BinaryOperatorLOr,
    BinaryConditionalOperator,
    // The preceding values are available with PGO_HASH_V1.

    EndOfScope,
    IfThenBranch,
    IfElseBranch,
    GotoStmt,
    IndirectGotoStmt,
    BreakStmt,
    ContinueStmt,
    ReturnStmt,
    ThrowExpr,
    UnaryOperatorLNot,
    BinaryOperatorLT,
    BinaryOperatorGT,
    BinaryOperatorLE,
    BinaryOperatorGE,
    BinaryOperatorEQ,
    BinaryOperatorNE,
    // The preceding values are available since PGO_HASH_V2.

    // Keep this last. It's for the static assert that follows.
    LastHashType
  };
  static_assert(LastHashType <= TooBig, "Too many types in HashType");

  PGOHash(PGOHashVersion HashVersion)
      : Working(0), Count(0), HashVersion(HashVersion) {}
  void combine(HashType Type);
  uint64_t finalize();
  PGOHashVersion getHashVersion() const { return HashVersion; }
};
const int PGOHash::NumBitsPerType;
const unsigned PGOHash::NumTypesPerWord;
const unsigned PGOHash::TooBig;

/// Get the PGO hash version used in the given indexed profile.
static PGOHashVersion getPGOHashVersion(llvm::IndexedInstrProfReader *PGOReader,
                                        CodeGenModule &CGM) {
  if (PGOReader->getVersion() <= 4)
    return PGO_HASH_V1;
  if (PGOReader->getVersion() <= 5)
    return PGO_HASH_V2;
  return PGO_HASH_V3;
}

/// A RecursiveASTVisitor that fills a map of statements to PGO counters.
struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> {
  using Base = RecursiveASTVisitor<MapRegionCounters>;

  /// The next counter value to assign.
  unsigned NextCounter;
  /// The function hash.
  PGOHash Hash;
  /// The map of statements to counters.
  llvm::DenseMap<const Stmt *, unsigned> &CounterMap;
  /// The profile version.
  uint64_t ProfileVersion;

  MapRegionCounters(PGOHashVersion HashVersion, uint64_t ProfileVersion,
                    llvm::DenseMap<const Stmt *, unsigned> &CounterMap)
      : NextCounter(0), Hash(HashVersion), CounterMap(CounterMap),
        ProfileVersion(ProfileVersion) {}

  // Blocks and lambdas are handled as separate functions, so we need not
  // traverse them in the parent context.
  bool TraverseBlockExpr(BlockExpr *BE) { return true; }
  bool TraverseLambdaExpr(LambdaExpr *LE) {
    // Traverse the captures, but not the body.
    for (auto C : zip(LE->captures(), LE->capture_inits()))
      TraverseLambdaCapture(LE, &std::get<0>(C), std::get<1>(C));
    return true;
  }
  bool TraverseCapturedStmt(CapturedStmt *CS) { return true; }

  bool VisitDecl(const Decl *D) {
    switch (D->getKind()) {
    default:
      break;
    case Decl::Function:
    case Decl::CXXMethod:
    case Decl::CXXConstructor:
    case Decl::CXXDestructor:
    case Decl::CXXConversion:
    case Decl::ObjCMethod:
    case Decl::Block:
    case Decl::Captured:
      CounterMap[D->getBody()] = NextCounter++;
      break;
    }
    return true;
  }

  /// If \p S gets a fresh counter, update the counter mappings. Return the
  /// V1 hash of \p S.
  PGOHash::HashType updateCounterMappings(Stmt *S) {
    auto Type = getHashType(PGO_HASH_V1, S);
    if (Type != PGOHash::None)
      CounterMap[S] = NextCounter++;
    return Type;
  }

  /// The RHS of all logical operators gets a fresh counter in order to count
  /// how many times the RHS evaluates to true or false, depending on the
  /// semantics of the operator. This is only done for profile versions >= v7
  /// to preserve backward compatibility with older profiles.
  bool VisitBinaryOperator(BinaryOperator *S) {
    if (ProfileVersion >= llvm::IndexedInstrProf::Version7)
      if (S->isLogicalOp() &&
          CodeGenFunction::isInstrumentedCondition(S->getRHS()))
        CounterMap[S->getRHS()] = NextCounter++;
    return Base::VisitBinaryOperator(S);
  }

  /// Include \p S in the function hash.
  bool VisitStmt(Stmt *S) {
    auto Type = updateCounterMappings(S);
    if (Hash.getHashVersion() != PGO_HASH_V1)
      Type = getHashType(Hash.getHashVersion(), S);
    if (Type != PGOHash::None)
      Hash.combine(Type);
    return true;
  }

  bool TraverseIfStmt(IfStmt *If) {
    // If we used the V1 hash, use the default traversal.
    if (Hash.getHashVersion() == PGO_HASH_V1)
      return Base::TraverseIfStmt(If);

    // Otherwise, keep track of which branch we're in while traversing.
    VisitStmt(If);
    for (Stmt *CS : If->children()) {
      if (!CS)
        continue;
      if (CS == If->getThen())
        Hash.combine(PGOHash::IfThenBranch);
      else if (CS == If->getElse())
        Hash.combine(PGOHash::IfElseBranch);
      TraverseStmt(CS);
    }
    Hash.combine(PGOHash::EndOfScope);
    return true;
  }

  // If the statement type \p N is nestable, and its nesting impacts profile
  // stability, define a custom traversal which tracks the end of the statement
  // in the hash (provided we're not using the V1 hash).
#define DEFINE_NESTABLE_TRAVERSAL(N)                                           \
  bool Traverse##N(N *S) {                                                     \
    Base::Traverse##N(S);                                                      \
    if (Hash.getHashVersion() != PGO_HASH_V1)                                  \
      Hash.combine(PGOHash::EndOfScope);                                       \
    return true;                                                               \
  }
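  // For reference, DEFINE_NESTABLE_TRAVERSAL(WhileStmt) expands to:
  //
  //   bool TraverseWhileStmt(WhileStmt *S) {
  //     Base::TraverseWhileStmt(S);
  //     if (Hash.getHashVersion() != PGO_HASH_V1)
  //       Hash.combine(PGOHash::EndOfScope);
  //     return true;
  //   }
  //
  // i.e., the default traversal runs first, and then an EndOfScope marker is
  // folded into the hash so that sibling and nested statements hash
  // differently.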

  DEFINE_NESTABLE_TRAVERSAL(WhileStmt)
  DEFINE_NESTABLE_TRAVERSAL(DoStmt)
  DEFINE_NESTABLE_TRAVERSAL(ForStmt)
  DEFINE_NESTABLE_TRAVERSAL(CXXForRangeStmt)
  DEFINE_NESTABLE_TRAVERSAL(ObjCForCollectionStmt)
  DEFINE_NESTABLE_TRAVERSAL(CXXTryStmt)
  DEFINE_NESTABLE_TRAVERSAL(CXXCatchStmt)

  /// Get version \p HashVersion of the PGO hash for \p S.
  PGOHash::HashType getHashType(PGOHashVersion HashVersion, const Stmt *S) {
    switch (S->getStmtClass()) {
    default:
      break;
    case Stmt::LabelStmtClass:
      return PGOHash::LabelStmt;
    case Stmt::WhileStmtClass:
      return PGOHash::WhileStmt;
    case Stmt::DoStmtClass:
      return PGOHash::DoStmt;
    case Stmt::ForStmtClass:
      return PGOHash::ForStmt;
    case Stmt::CXXForRangeStmtClass:
      return PGOHash::CXXForRangeStmt;
    case Stmt::ObjCForCollectionStmtClass:
      return PGOHash::ObjCForCollectionStmt;
    case Stmt::SwitchStmtClass:
      return PGOHash::SwitchStmt;
    case Stmt::CaseStmtClass:
      return PGOHash::CaseStmt;
    case Stmt::DefaultStmtClass:
      return PGOHash::DefaultStmt;
    case Stmt::IfStmtClass:
      return PGOHash::IfStmt;
    case Stmt::CXXTryStmtClass:
      return PGOHash::CXXTryStmt;
    case Stmt::CXXCatchStmtClass:
      return PGOHash::CXXCatchStmt;
    case Stmt::ConditionalOperatorClass:
      return PGOHash::ConditionalOperator;
    case Stmt::BinaryConditionalOperatorClass:
      return PGOHash::BinaryConditionalOperator;
    case Stmt::BinaryOperatorClass: {
      const BinaryOperator *BO = cast<BinaryOperator>(S);
      if (BO->getOpcode() == BO_LAnd)
        return PGOHash::BinaryOperatorLAnd;
      if (BO->getOpcode() == BO_LOr)
        return PGOHash::BinaryOperatorLOr;
      if (HashVersion >= PGO_HASH_V2) {
        switch (BO->getOpcode()) {
        default:
          break;
        case BO_LT:
          return PGOHash::BinaryOperatorLT;
        case BO_GT:
          return PGOHash::BinaryOperatorGT;
        case BO_LE:
          return PGOHash::BinaryOperatorLE;
        case BO_GE:
          return PGOHash::BinaryOperatorGE;
        case BO_EQ:
          return PGOHash::BinaryOperatorEQ;
        case BO_NE:
          return PGOHash::BinaryOperatorNE;
        }
      }
      break;
    }
    }

    if (HashVersion >= PGO_HASH_V2) {
      switch (S->getStmtClass()) {
      default:
        break;
      case Stmt::GotoStmtClass:
        return PGOHash::GotoStmt;
      case Stmt::IndirectGotoStmtClass:
        return PGOHash::IndirectGotoStmt;
      case Stmt::BreakStmtClass:
        return PGOHash::BreakStmt;
      case Stmt::ContinueStmtClass:
        return PGOHash::ContinueStmt;
      case Stmt::ReturnStmtClass:
        return PGOHash::ReturnStmt;
      case Stmt::CXXThrowExprClass:
        return PGOHash::ThrowExpr;
      case Stmt::UnaryOperatorClass: {
        const UnaryOperator *UO = cast<UnaryOperator>(S);
        if (UO->getOpcode() == UO_LNot)
          return PGOHash::UnaryOperatorLNot;
        break;
      }
      }
    }

    return PGOHash::None;
  }
};

/// A StmtVisitor that propagates the raw counts through the AST and
/// records the count at statements where the value may change.
struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
  /// PGO state.
  CodeGenPGO &PGO;

  /// A flag that is set when the current count should be recorded on the
  /// next statement, such as at the exit of a loop.
  bool RecordNextStmtCount;

  /// The count at the current location in the traversal.
  uint64_t CurrentCount;

  /// The map of statements to count values.
  llvm::DenseMap<const Stmt *, uint64_t> &CountMap;

  /// BreakContinueStack - Keep counts of breaks and continues inside loops.
  struct BreakContinue {
    uint64_t BreakCount;
    uint64_t ContinueCount;
    BreakContinue() : BreakCount(0), ContinueCount(0) {}
  };
  SmallVector<BreakContinue, 8> BreakContinueStack;

  ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap,
                      CodeGenPGO &PGO)
      : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {}

  void RecordStmtCount(const Stmt *S) {
    if (RecordNextStmtCount) {
      CountMap[S] = CurrentCount;
      RecordNextStmtCount = false;
    }
  }

  /// Set and return the current count.
  uint64_t setCount(uint64_t Count) {
    CurrentCount = Count;
    return Count;
  }

  void VisitStmt(const Stmt *S) {
    RecordStmtCount(S);
    for (const Stmt *Child : S->children())
      if (Child)
        this->Visit(Child);
  }

  void VisitFunctionDecl(const FunctionDecl *D) {
    // Counter tracks entry to the function body.
    uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
    CountMap[D->getBody()] = BodyCount;
    Visit(D->getBody());
  }

  // Skip lambda expressions. We visit these as FunctionDecls when we're
  // generating them and aren't interested in the body when generating a
  // parent context.
  void VisitLambdaExpr(const LambdaExpr *LE) {}

  void VisitCapturedDecl(const CapturedDecl *D) {
    // Counter tracks entry to the capture body.
    uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
    CountMap[D->getBody()] = BodyCount;
    Visit(D->getBody());
  }

  void VisitObjCMethodDecl(const ObjCMethodDecl *D) {
    // Counter tracks entry to the method body.
    uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
    CountMap[D->getBody()] = BodyCount;
    Visit(D->getBody());
  }

  void VisitBlockDecl(const BlockDecl *D) {
    // Counter tracks entry to the block body.
    uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
    CountMap[D->getBody()] = BodyCount;
    Visit(D->getBody());
  }

  void VisitReturnStmt(const ReturnStmt *S) {
    RecordStmtCount(S);
    if (S->getRetValue())
      Visit(S->getRetValue());
    CurrentCount = 0;
    RecordNextStmtCount = true;
  }

  void VisitCXXThrowExpr(const CXXThrowExpr *E) {
    RecordStmtCount(E);
    if (E->getSubExpr())
      Visit(E->getSubExpr());
    CurrentCount = 0;
    RecordNextStmtCount = true;
  }

  void VisitGotoStmt(const GotoStmt *S) {
    RecordStmtCount(S);
    CurrentCount = 0;
    RecordNextStmtCount = true;
  }

  void VisitLabelStmt(const LabelStmt *S) {
    RecordNextStmtCount = false;
    // Counter tracks the block following the label.
    uint64_t BlockCount = setCount(PGO.getRegionCount(S));
    CountMap[S] = BlockCount;
    Visit(S->getSubStmt());
  }

  void VisitBreakStmt(const BreakStmt *S) {
    RecordStmtCount(S);
    assert(!BreakContinueStack.empty() && "break not in a loop or switch!");
    BreakContinueStack.back().BreakCount += CurrentCount;
    CurrentCount = 0;
    RecordNextStmtCount = true;
  }

  void VisitContinueStmt(const ContinueStmt *S) {
    RecordStmtCount(S);
    assert(!BreakContinueStack.empty() && "continue stmt not in a loop!");
    BreakContinueStack.back().ContinueCount += CurrentCount;
    CurrentCount = 0;
    RecordNextStmtCount = true;
  }

  void VisitWhileStmt(const WhileStmt *S) {
    RecordStmtCount(S);
    uint64_t ParentCount = CurrentCount;

    BreakContinueStack.push_back(BreakContinue());
    // Visit the body region first so the break/continue adjustments can be
    // included when visiting the condition.
    uint64_t BodyCount = setCount(PGO.getRegionCount(S));
    CountMap[S->getBody()] = CurrentCount;
    Visit(S->getBody());
    uint64_t BackedgeCount = CurrentCount;

    // ...then go back and propagate counts through the condition. The count
    // at the start of the condition is the sum of the incoming edges,
    // the backedge from the end of the loop body, and the edges from
    // continue statements.
    BreakContinue BC = BreakContinueStack.pop_back_val();
    uint64_t CondCount =
        setCount(ParentCount + BackedgeCount + BC.ContinueCount);
    CountMap[S->getCond()] = CondCount;
    Visit(S->getCond());
    setCount(BC.BreakCount + CondCount - BodyCount);
    RecordNextStmtCount = true;
  }
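  // For example (hypothetical counts): if the while statement is reached once
  // (ParentCount = 1), its body counter recorded 10 executions
  // (BodyCount = 10), and there are no breaks or continues, the condition is
  // evaluated CondCount = 1 + 10 + 0 = 11 times and the count that falls out
  // of the loop is CondCount - BodyCount = 1.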

  void VisitDoStmt(const DoStmt *S) {
    RecordStmtCount(S);
    uint64_t LoopCount = PGO.getRegionCount(S);

    BreakContinueStack.push_back(BreakContinue());
    // The count doesn't include the fallthrough from the parent scope. Add it.
    uint64_t BodyCount = setCount(LoopCount + CurrentCount);
    CountMap[S->getBody()] = BodyCount;
    Visit(S->getBody());
    uint64_t BackedgeCount = CurrentCount;

    BreakContinue BC = BreakContinueStack.pop_back_val();
    // The count at the start of the condition is equal to the count at the
    // end of the body, plus any continues.
    uint64_t CondCount = setCount(BackedgeCount + BC.ContinueCount);
    CountMap[S->getCond()] = CondCount;
    Visit(S->getCond());
    setCount(BC.BreakCount + CondCount - LoopCount);
    RecordNextStmtCount = true;
  }

  void VisitForStmt(const ForStmt *S) {
    RecordStmtCount(S);
    if (S->getInit())
      Visit(S->getInit());

    uint64_t ParentCount = CurrentCount;

    BreakContinueStack.push_back(BreakContinue());
    // Visit the body region first. (This is basically the same as a while
    // loop; see further comments in VisitWhileStmt.)
    uint64_t BodyCount = setCount(PGO.getRegionCount(S));
    CountMap[S->getBody()] = BodyCount;
    Visit(S->getBody());
    uint64_t BackedgeCount = CurrentCount;
    BreakContinue BC = BreakContinueStack.pop_back_val();

    // The increment is essentially part of the body but it needs to include
    // the count for all the continue statements.
    if (S->getInc()) {
      uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
      CountMap[S->getInc()] = IncCount;
      Visit(S->getInc());
    }

    // ...then go back and propagate counts through the condition.
    uint64_t CondCount =
        setCount(ParentCount + BackedgeCount + BC.ContinueCount);
    if (S->getCond()) {
      CountMap[S->getCond()] = CondCount;
      Visit(S->getCond());
    }
    setCount(BC.BreakCount + CondCount - BodyCount);
    RecordNextStmtCount = true;
  }

  void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
    RecordStmtCount(S);
    if (S->getInit())
      Visit(S->getInit());
    Visit(S->getLoopVarStmt());
    Visit(S->getRangeStmt());
    Visit(S->getBeginStmt());
    Visit(S->getEndStmt());

    uint64_t ParentCount = CurrentCount;
    BreakContinueStack.push_back(BreakContinue());
    // Visit the body region first. (This is basically the same as a while
    // loop; see further comments in VisitWhileStmt.)
    uint64_t BodyCount = setCount(PGO.getRegionCount(S));
    CountMap[S->getBody()] = BodyCount;
    Visit(S->getBody());
    uint64_t BackedgeCount = CurrentCount;
    BreakContinue BC = BreakContinueStack.pop_back_val();

    // The increment is essentially part of the body but it needs to include
    // the count for all the continue statements.
    uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
    CountMap[S->getInc()] = IncCount;
    Visit(S->getInc());

    // ...then go back and propagate counts through the condition.
    uint64_t CondCount =
        setCount(ParentCount + BackedgeCount + BC.ContinueCount);
    CountMap[S->getCond()] = CondCount;
    Visit(S->getCond());
    setCount(BC.BreakCount + CondCount - BodyCount);
    RecordNextStmtCount = true;
  }

  void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
    RecordStmtCount(S);
    Visit(S->getElement());
    uint64_t ParentCount = CurrentCount;
    BreakContinueStack.push_back(BreakContinue());
    // Counter tracks the body of the loop.
    uint64_t BodyCount = setCount(PGO.getRegionCount(S));
    CountMap[S->getBody()] = BodyCount;
    Visit(S->getBody());
    uint64_t BackedgeCount = CurrentCount;
    BreakContinue BC = BreakContinueStack.pop_back_val();

    setCount(BC.BreakCount + ParentCount + BackedgeCount + BC.ContinueCount -
             BodyCount);
    RecordNextStmtCount = true;
  }

  void VisitSwitchStmt(const SwitchStmt *S) {
    RecordStmtCount(S);
    if (S->getInit())
      Visit(S->getInit());
    Visit(S->getCond());
    CurrentCount = 0;
    BreakContinueStack.push_back(BreakContinue());
    Visit(S->getBody());
    // If the switch is inside a loop, add the continue counts.
    BreakContinue BC = BreakContinueStack.pop_back_val();
    if (!BreakContinueStack.empty())
      BreakContinueStack.back().ContinueCount += BC.ContinueCount;
    // Counter tracks the exit block of the switch.
    setCount(PGO.getRegionCount(S));
    RecordNextStmtCount = true;
  }

  void VisitSwitchCase(const SwitchCase *S) {
    RecordNextStmtCount = false;
    // Counter for this particular case. This counts only jumps from the
    // switch header and does not include fallthrough from the case before
    // this one.
    uint64_t CaseCount = PGO.getRegionCount(S);
    setCount(CurrentCount + CaseCount);
    // We need the count without fallthrough in the mapping, so it's more
    // useful for branch probabilities.
    CountMap[S] = CaseCount;
    RecordNextStmtCount = true;
    Visit(S->getSubStmt());
  }

  void VisitIfStmt(const IfStmt *S) {
    RecordStmtCount(S);

    if (S->isConsteval()) {
      const Stmt *Stm = S->isNegatedConsteval() ? S->getThen() : S->getElse();
      if (Stm)
        Visit(Stm);
      return;
    }

    uint64_t ParentCount = CurrentCount;
    if (S->getInit())
      Visit(S->getInit());
    Visit(S->getCond());

    // Counter tracks the "then" part of an if statement. The count for
    // the "else" part, if it exists, will be calculated from this counter.
    uint64_t ThenCount = setCount(PGO.getRegionCount(S));
    CountMap[S->getThen()] = ThenCount;
    Visit(S->getThen());
    uint64_t OutCount = CurrentCount;

    uint64_t ElseCount = ParentCount - ThenCount;
    if (S->getElse()) {
      setCount(ElseCount);
      CountMap[S->getElse()] = ElseCount;
      Visit(S->getElse());
      OutCount += CurrentCount;
    } else
      OutCount += ElseCount;
    setCount(OutCount);
    RecordNextStmtCount = true;
  }
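  // For example (hypothetical counts): if the if statement is reached 100
  // times (ParentCount = 100) and the "then" counter recorded 70 executions
  // (ThenCount = 70), the "else" branch (or the fall-through when there is no
  // else) gets ElseCount = 100 - 70 = 30; absent early exits in either branch,
  // OutCount is 70 + 30 = 100 again.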

  void VisitCXXTryStmt(const CXXTryStmt *S) {
    RecordStmtCount(S);
    Visit(S->getTryBlock());
    for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
      Visit(S->getHandler(I));
    // Counter tracks the continuation block of the try statement.
    setCount(PGO.getRegionCount(S));
    RecordNextStmtCount = true;
  }

  void VisitCXXCatchStmt(const CXXCatchStmt *S) {
    RecordNextStmtCount = false;
    // Counter tracks the catch statement's handler block.
    uint64_t CatchCount = setCount(PGO.getRegionCount(S));
    CountMap[S] = CatchCount;
    Visit(S->getHandlerBlock());
  }

  void VisitAbstractConditionalOperator(const AbstractConditionalOperator *E) {
    RecordStmtCount(E);
    uint64_t ParentCount = CurrentCount;
    Visit(E->getCond());

    // Counter tracks the "true" part of a conditional operator. The
    // count in the "false" part will be calculated from this counter.
    uint64_t TrueCount = setCount(PGO.getRegionCount(E));
    CountMap[E->getTrueExpr()] = TrueCount;
    Visit(E->getTrueExpr());
    uint64_t OutCount = CurrentCount;

    uint64_t FalseCount = setCount(ParentCount - TrueCount);
    CountMap[E->getFalseExpr()] = FalseCount;
    Visit(E->getFalseExpr());
    OutCount += CurrentCount;

    setCount(OutCount);
    RecordNextStmtCount = true;
  }

  void VisitBinLAnd(const BinaryOperator *E) {
    RecordStmtCount(E);
    uint64_t ParentCount = CurrentCount;
    Visit(E->getLHS());
    // Counter tracks the right hand side of a logical and operator.
    uint64_t RHSCount = setCount(PGO.getRegionCount(E));
    CountMap[E->getRHS()] = RHSCount;
    Visit(E->getRHS());
    setCount(ParentCount + RHSCount - CurrentCount);
    RecordNextStmtCount = true;
  }
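  // For example (hypothetical counts): for "a && b" reached 10 times
  // (ParentCount = 10) with the RHS counter recording 6 executions
  // (RHSCount = 6), the expression short-circuits 4 times; if the RHS has no
  // early exits, CurrentCount is still 6 afterwards, so the count after the
  // whole expression is 10 + 6 - 6 = 10. The same algebra applies to
  // VisitBinLOr below.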

  void VisitBinLOr(const BinaryOperator *E) {
    RecordStmtCount(E);
    uint64_t ParentCount = CurrentCount;
    Visit(E->getLHS());
    // Counter tracks the right hand side of a logical or operator.
    uint64_t RHSCount = setCount(PGO.getRegionCount(E));
    CountMap[E->getRHS()] = RHSCount;
    Visit(E->getRHS());
    setCount(ParentCount + RHSCount - CurrentCount);
    RecordNextStmtCount = true;
  }
};
} // end anonymous namespace

void PGOHash::combine(HashType Type) {
  // Check that we never combine 0 and only have six bits.
  assert(Type && "Hash is invalid: unexpected type 0");
  assert(unsigned(Type) < TooBig && "Hash is invalid: too many types");

  // Pass through MD5 if enough work has built up.
  if (Count && Count % NumTypesPerWord == 0) {
    using namespace llvm::support;
    uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working);
    MD5.update(llvm::ArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
    Working = 0;
  }

  // Accumulate the current type.
  ++Count;
  Working = Working << NumBitsPerType | Type;
}
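// With NumBitsPerType = 6, NumTypesPerWord is 64 / 6 = 10: the first ten
// combine() calls pack their type codes into Working, and the eleventh call
// flushes that word (byte-swapped to little-endian) into the MD5 context
// before starting a new one. Hashes over ten or fewer nodes therefore never
// touch MD5; see finalize() below.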

uint64_t PGOHash::finalize() {
  // Use Working as the hash directly if we never used MD5.
  if (Count <= NumTypesPerWord)
    // No need to byte swap here, since none of the math was endian-dependent.
    // This number will be byte-swapped as required on endianness transitions,
    // so we will see the same value on the other side.
    return Working;

  // Check for remaining work in Working.
  if (Working) {
    // Keep the buggy behavior from v1 and v2 for backward-compatibility. This
    // is buggy because it converts a uint64_t into an array of uint8_t.
    if (HashVersion < PGO_HASH_V3) {
      MD5.update({(uint8_t)Working});
    } else {
      using namespace llvm::support;
      uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working);
      MD5.update(llvm::ArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
    }
  }

  // Finalize the MD5 and return the hash.
  llvm::MD5::MD5Result Result;
  MD5.final(Result);
  return Result.low();
}

void CodeGenPGO::assignRegionCounters(GlobalDecl GD, llvm::Function *Fn) {
  const Decl *D = GD.getDecl();
  if (!D->hasBody())
    return;

  // Skip CUDA/HIP kernel launch stub functions.
  if (CGM.getLangOpts().CUDA && !CGM.getLangOpts().CUDAIsDevice &&
      D->hasAttr<CUDAGlobalAttr>())
    return;

  bool InstrumentRegions = CGM.getCodeGenOpts().hasProfileClangInstr();
  llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
  if (!InstrumentRegions && !PGOReader)
    return;
  if (D->isImplicit())
    return;
  // Constructors and destructors may be represented by several functions in
  // IR. If so, instrument only the base variant; the others are implemented by
  // delegation to the base one and would otherwise be counted twice.
  if (CGM.getTarget().getCXXABI().hasConstructorVariants()) {
    if (const auto *CCD = dyn_cast<CXXConstructorDecl>(D))
      if (GD.getCtorType() != Ctor_Base &&
          CodeGenFunction::IsConstructorDelegationValid(CCD))
        return;
  }
  if (isa<CXXDestructorDecl>(D) && GD.getDtorType() != Dtor_Base)
    return;

  CGM.ClearUnusedCoverageMapping(D);
  if (Fn->hasFnAttribute(llvm::Attribute::NoProfile))
    return;
  if (Fn->hasFnAttribute(llvm::Attribute::SkipProfile))
    return;

  setFuncName(Fn);

  mapRegionCounters(D);
  if (CGM.getCodeGenOpts().CoverageMapping)
    emitCounterRegionMapping(D);
  if (PGOReader) {
    SourceManager &SM = CGM.getContext().getSourceManager();
    loadRegionCounts(PGOReader, SM.isInMainFile(D->getLocation()));
    computeRegionCounts(D);
    applyFunctionAttributes(PGOReader, Fn);
  }
}

void CodeGenPGO::mapRegionCounters(const Decl *D) {
  // Use the latest hash version when inserting instrumentation, but use the
  // version in the indexed profile if we're reading PGO data.
  PGOHashVersion HashVersion = PGO_HASH_LATEST;
  uint64_t ProfileVersion = llvm::IndexedInstrProf::Version;
  if (auto *PGOReader = CGM.getPGOReader()) {
    HashVersion = getPGOHashVersion(PGOReader, CGM);
    ProfileVersion = PGOReader->getVersion();
  }

  RegionCounterMap.reset(new llvm::DenseMap<const Stmt *, unsigned>);
  MapRegionCounters Walker(HashVersion, ProfileVersion, *RegionCounterMap);
  if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
    Walker.TraverseDecl(const_cast<FunctionDecl *>(FD));
  else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
    Walker.TraverseDecl(const_cast<ObjCMethodDecl *>(MD));
  else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
    Walker.TraverseDecl(const_cast<BlockDecl *>(BD));
  else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
    Walker.TraverseDecl(const_cast<CapturedDecl *>(CD));
  assert(Walker.NextCounter > 0 && "no entry counter mapped for decl");
  NumRegionCounters = Walker.NextCounter;
  FunctionHash = Walker.Hash.finalize();
}

bool CodeGenPGO::skipRegionMappingForDecl(const Decl *D) {
  if (!D->getBody())
    return true;

  // Skip host-only functions in the CUDA device compilation and device-only
  // functions in the host compilation. Just roughly filter them out based on
  // the function attributes. If there are effectively host-only or device-only
  // ones, their coverage mapping may still be generated.
  if (CGM.getLangOpts().CUDA &&
      ((CGM.getLangOpts().CUDAIsDevice && !D->hasAttr<CUDADeviceAttr>() &&
        !D->hasAttr<CUDAGlobalAttr>()) ||
       (!CGM.getLangOpts().CUDAIsDevice &&
        (D->hasAttr<CUDAGlobalAttr>() ||
         (!D->hasAttr<CUDAHostAttr>() && D->hasAttr<CUDADeviceAttr>())))))
    return true;

  // Don't map the functions in system headers.
  const auto &SM = CGM.getContext().getSourceManager();
  auto Loc = D->getBody()->getBeginLoc();
  return SM.isInSystemHeader(Loc);
}

void CodeGenPGO::emitCounterRegionMapping(const Decl *D) {
  if (skipRegionMappingForDecl(D))
    return;

  std::string CoverageMapping;
  llvm::raw_string_ostream OS(CoverageMapping);
  CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
                                CGM.getContext().getSourceManager(),
                                CGM.getLangOpts(), RegionCounterMap.get());
  MappingGen.emitCounterMapping(D, OS);
  OS.flush();

  if (CoverageMapping.empty())
    return;

  CGM.getCoverageMapping()->addFunctionMappingRecord(
      FuncNameVar, FuncName, FunctionHash, CoverageMapping);
}

void
CodeGenPGO::emitEmptyCounterMapping(const Decl *D, StringRef Name,
                                    llvm::GlobalValue::LinkageTypes Linkage) {
  if (skipRegionMappingForDecl(D))
    return;

  std::string CoverageMapping;
  llvm::raw_string_ostream OS(CoverageMapping);
  CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
                                CGM.getContext().getSourceManager(),
                                CGM.getLangOpts());
  MappingGen.emitEmptyMapping(D, OS);
  OS.flush();

  if (CoverageMapping.empty())
    return;

  setFuncName(Name, Linkage);
  CGM.getCoverageMapping()->addFunctionMappingRecord(
      FuncNameVar, FuncName, FunctionHash, CoverageMapping, false);
}

void CodeGenPGO::computeRegionCounts(const Decl *D) {
  StmtCountMap.reset(new llvm::DenseMap<const Stmt *, uint64_t>);
  ComputeRegionCounts Walker(*StmtCountMap, *this);
  if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
    Walker.VisitFunctionDecl(FD);
  else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
    Walker.VisitObjCMethodDecl(MD);
  else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
    Walker.VisitBlockDecl(BD);
  else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
    Walker.VisitCapturedDecl(const_cast<CapturedDecl *>(CD));
}

void
CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader,
                                    llvm::Function *Fn) {
  if (!haveRegionCounts())
    return;

  uint64_t FunctionCount = getRegionCount(nullptr);
  Fn->setEntryCount(FunctionCount);
}

void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, const Stmt *S,
                                      llvm::Value *StepV) {
  if (!CGM.getCodeGenOpts().hasProfileClangInstr() || !RegionCounterMap)
    return;
  if (!Builder.GetInsertBlock())
    return;

  unsigned Counter = (*RegionCounterMap)[S];
  auto *I8PtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext());

  llvm::Value *Args[] = {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy),
                         Builder.getInt64(FunctionHash),
                         Builder.getInt32(NumRegionCounters),
                         Builder.getInt32(Counter), StepV};
  if (!StepV)
    Builder.CreateCall(CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment),
                       ArrayRef(Args, 4));
  else
    Builder.CreateCall(
        CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment_step),
        ArrayRef(Args));
}
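// For a counter index K this lowers to roughly the following IR (hypothetical
// function name; the exact pointer types depend on the LLVM version):
//
//   call void @llvm.instrprof.increment(i8* <addr of __profn_foo>,
//                                       i64 <FunctionHash>,
//                                       i32 <NumRegionCounters>, i32 K)
//
// or, when StepV is provided, the llvm.instrprof.increment.step variant with
// an extra i64 step operand.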

void CodeGenPGO::setValueProfilingFlag(llvm::Module &M) {
  if (CGM.getCodeGenOpts().hasProfileClangInstr())
    M.addModuleFlag(llvm::Module::Warning, "EnableValueProfiling",
                    uint32_t(EnableValueProfiling));
}

// This method either inserts a call to the profile run-time during
// instrumentation or puts profile data into metadata for PGO use.
void CodeGenPGO::valueProfile(CGBuilderTy &Builder, uint32_t ValueKind,
                              llvm::Instruction *ValueSite, llvm::Value *ValuePtr) {

  if (!EnableValueProfiling)
    return;

  if (!ValuePtr || !ValueSite || !Builder.GetInsertBlock())
    return;

  if (isa<llvm::Constant>(ValuePtr))
    return;

  bool InstrumentValueSites = CGM.getCodeGenOpts().hasProfileClangInstr();
  if (InstrumentValueSites && RegionCounterMap) {
    auto BuilderInsertPoint = Builder.saveIP();
    Builder.SetInsertPoint(ValueSite);
    llvm::Value *Args[5] = {
        llvm::ConstantExpr::getBitCast(FuncNameVar, Builder.getInt8PtrTy()),
        Builder.getInt64(FunctionHash),
        Builder.CreatePtrToInt(ValuePtr, Builder.getInt64Ty()),
        Builder.getInt32(ValueKind),
        Builder.getInt32(NumValueSites[ValueKind]++)
    };
    Builder.CreateCall(
        CGM.getIntrinsic(llvm::Intrinsic::instrprof_value_profile), Args);
    Builder.restoreIP(BuilderInsertPoint);
    return;
  }

  llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
  if (PGOReader && haveRegionCounts()) {
    // We record the three most frequently called functions at each call site.
    // The profile metadata contains a "VP" string identifying it as value
    // profiling data, then a uint32_t value for the value profiling kind, a
    // uint64_t value for the total number of times the call is executed,
    // followed by (function hash, execution count) uint64_t pairs for each
    // recorded target function.
    if (NumValueSites[ValueKind] >= ProfRecord->getNumValueSites(ValueKind))
      return;

    llvm::annotateValueSite(CGM.getModule(), *ValueSite, *ProfRecord,
                            (llvm::InstrProfValueKind)ValueKind,
                            NumValueSites[ValueKind]);

    NumValueSites[ValueKind]++;
  }
}

void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader,
                                  bool IsInMainFile) {
  CGM.getPGOStats().addVisited(IsInMainFile);
  RegionCounts.clear();
  llvm::Expected<llvm::InstrProfRecord> RecordExpected =
      PGOReader->getInstrProfRecord(FuncName, FunctionHash);
  if (auto E = RecordExpected.takeError()) {
    auto IPE = llvm::InstrProfError::take(std::move(E));
    if (IPE == llvm::instrprof_error::unknown_function)
      CGM.getPGOStats().addMissing(IsInMainFile);
    else if (IPE == llvm::instrprof_error::hash_mismatch)
      CGM.getPGOStats().addMismatched(IsInMainFile);
    else if (IPE == llvm::instrprof_error::malformed)
      // TODO: Consider a more specific warning for this case.
      CGM.getPGOStats().addMismatched(IsInMainFile);
    return;
  }
  ProfRecord =
      std::make_unique<llvm::InstrProfRecord>(std::move(RecordExpected.get()));
  RegionCounts = ProfRecord->Counts;
}

/// Calculate what to divide by to scale weights.
///
/// Given the maximum weight, calculate a divisor that will scale all the
/// weights to strictly less than UINT32_MAX.
static uint64_t calculateWeightScale(uint64_t MaxWeight) {
  return MaxWeight < UINT32_MAX ? 1 : MaxWeight / UINT32_MAX + 1;
}

/// Scale an individual branch weight (and add 1).
///
/// Scale a 64-bit weight down to 32-bits using \c Scale.
///
/// According to Laplace's Rule of Succession, it is better to compute the
/// weight based on the count plus 1, so universally add 1 to the value.
///
/// \pre \c Scale was calculated by \a calculateWeightScale() with a weight no
/// less than \c Weight.
static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) {
  assert(Scale && "scale by 0?");
  uint64_t Scaled = Weight / Scale + 1;
  assert(Scaled <= UINT32_MAX && "overflow 32-bits");
  return Scaled;
}
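// For example (hypothetical counts): with MaxWeight = 2^33, the scale is
// 2^33 / UINT32_MAX + 1 = 3, so a branch taken 2^33 times maps to
// 2^33 / 3 + 1 (about 2.86e9, which fits in 32 bits) while a branch never
// taken maps to 0 / 3 + 1 = 1, staying non-zero per Laplace's rule.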

llvm::MDNode *CodeGenFunction::createProfileWeights(uint64_t TrueCount,
                                                    uint64_t FalseCount) const {
  // Check for empty weights.
  if (!TrueCount && !FalseCount)
    return nullptr;

  // Calculate how to scale down to 32-bits.
  uint64_t Scale = calculateWeightScale(std::max(TrueCount, FalseCount));

  llvm::MDBuilder MDHelper(CGM.getLLVMContext());
  return MDHelper.createBranchWeights(scaleBranchWeight(TrueCount, Scale),
                                      scaleBranchWeight(FalseCount, Scale));
}
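// For example (hypothetical counts): TrueCount = 1000 and FalseCount = 0 give
// Scale = 1 and branch weights of 1001 and 1, strongly biasing the branch
// toward the true edge without declaring the false edge impossible.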
1091
1092 llvm::MDNode *
createProfileWeights(ArrayRef<uint64_t> Weights) const1093 CodeGenFunction::createProfileWeights(ArrayRef<uint64_t> Weights) const {
1094 // We need at least two elements to create meaningful weights.
1095 if (Weights.size() < 2)
1096 return nullptr;
1097
1098 // Check for empty weights.
1099 uint64_t MaxWeight = *std::max_element(Weights.begin(), Weights.end());
1100 if (MaxWeight == 0)
1101 return nullptr;
1102
1103 // Calculate how to scale down to 32-bits.
1104 uint64_t Scale = calculateWeightScale(MaxWeight);
1105
1106 SmallVector<uint32_t, 16> ScaledWeights;
1107 ScaledWeights.reserve(Weights.size());
1108 for (uint64_t W : Weights)
1109 ScaledWeights.push_back(scaleBranchWeight(W, Scale));
1110
1111 llvm::MDBuilder MDHelper(CGM.getLLVMContext());
1112 return MDHelper.createBranchWeights(ScaledWeights);
1113 }
1114
1115 llvm::MDNode *
createProfileWeightsForLoop(const Stmt * Cond,uint64_t LoopCount) const1116 CodeGenFunction::createProfileWeightsForLoop(const Stmt *Cond,
1117 uint64_t LoopCount) const {
1118 if (!PGO.haveRegionCounts())
1119 return nullptr;
1120 std::optional<uint64_t> CondCount = PGO.getStmtCount(Cond);
1121 if (!CondCount || *CondCount == 0)
1122 return nullptr;
1123 return createProfileWeights(LoopCount,
1124 std::max(*CondCount, LoopCount) - LoopCount);
1125 }
1126