1 //===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Instrumentation-based profile-guided optimization
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "CodeGenPGO.h"
14 #include "CodeGenFunction.h"
15 #include "CoverageMappingGen.h"
16 #include "clang/AST/RecursiveASTVisitor.h"
17 #include "clang/AST/StmtVisitor.h"
18 #include "llvm/IR/Intrinsics.h"
19 #include "llvm/IR/MDBuilder.h"
20 #include "llvm/Support/CommandLine.h"
21 #include "llvm/Support/Endian.h"
22 #include "llvm/Support/FileSystem.h"
23 #include "llvm/Support/MD5.h"
24
25 static llvm::cl::opt<bool>
26 EnableValueProfiling("enable-value-profiling", llvm::cl::ZeroOrMore,
27 llvm::cl::desc("Enable value profiling"),
28 llvm::cl::Hidden, llvm::cl::init(false));
29
30 using namespace clang;
31 using namespace CodeGen;
32
setFuncName(StringRef Name,llvm::GlobalValue::LinkageTypes Linkage)33 void CodeGenPGO::setFuncName(StringRef Name,
34 llvm::GlobalValue::LinkageTypes Linkage) {
35 llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
36 FuncName = llvm::getPGOFuncName(
37 Name, Linkage, CGM.getCodeGenOpts().MainFileName,
38 PGOReader ? PGOReader->getVersion() : llvm::IndexedInstrProf::Version);
39
40 // If we're generating a profile, create a variable for the name.
41 if (CGM.getCodeGenOpts().hasProfileClangInstr())
42 FuncNameVar = llvm::createPGOFuncNameVar(CGM.getModule(), Linkage, FuncName);
43 }
44
setFuncName(llvm::Function * Fn)45 void CodeGenPGO::setFuncName(llvm::Function *Fn) {
46 setFuncName(Fn->getName(), Fn->getLinkage());
47 // Create PGOFuncName meta data.
48 llvm::createPGOFuncNameMetadata(*Fn, FuncName);
49 }
50
/// Versions of the PGO structural-hash algorithm. Older versions stay defined
/// so that profiles written with them can still be matched when read back.
enum PGOHashVersion : unsigned {
  PGO_HASH_V1 = 0,
  PGO_HASH_V2 = 1,
  PGO_HASH_V3 = 2,

  // Always alias the newest version; it is the one used when emitting
  // instrumentation rather than reading a profile.
  PGO_HASH_LATEST = PGO_HASH_V3
};
60
61 namespace {
62 /// Stable hasher for PGO region counters.
63 ///
64 /// PGOHash produces a stable hash of a given function's control flow.
65 ///
66 /// Changing the output of this hash will invalidate all previously generated
67 /// profiles -- i.e., don't do it.
68 ///
69 /// \note When this hash does eventually change (years?), we still need to
70 /// support old hashes. We'll need to pull in the version number from the
71 /// profile data format and use the matching hash function.
72 class PGOHash {
73 uint64_t Working;
74 unsigned Count;
75 PGOHashVersion HashVersion;
76 llvm::MD5 MD5;
77
78 static const int NumBitsPerType = 6;
79 static const unsigned NumTypesPerWord = sizeof(uint64_t) * 8 / NumBitsPerType;
80 static const unsigned TooBig = 1u << NumBitsPerType;
81
82 public:
83 /// Hash values for AST nodes.
84 ///
85 /// Distinct values for AST nodes that have region counters attached.
86 ///
87 /// These values must be stable. All new members must be added at the end,
88 /// and no members should be removed. Changing the enumeration value for an
89 /// AST node will affect the hash of every function that contains that node.
90 enum HashType : unsigned char {
91 None = 0,
92 LabelStmt = 1,
93 WhileStmt,
94 DoStmt,
95 ForStmt,
96 CXXForRangeStmt,
97 ObjCForCollectionStmt,
98 SwitchStmt,
99 CaseStmt,
100 DefaultStmt,
101 IfStmt,
102 CXXTryStmt,
103 CXXCatchStmt,
104 ConditionalOperator,
105 BinaryOperatorLAnd,
106 BinaryOperatorLOr,
107 BinaryConditionalOperator,
108 // The preceding values are available with PGO_HASH_V1.
109
110 EndOfScope,
111 IfThenBranch,
112 IfElseBranch,
113 GotoStmt,
114 IndirectGotoStmt,
115 BreakStmt,
116 ContinueStmt,
117 ReturnStmt,
118 ThrowExpr,
119 UnaryOperatorLNot,
120 BinaryOperatorLT,
121 BinaryOperatorGT,
122 BinaryOperatorLE,
123 BinaryOperatorGE,
124 BinaryOperatorEQ,
125 BinaryOperatorNE,
126 // The preceding values are available since PGO_HASH_V2.
127
128 // Keep this last. It's for the static assert that follows.
129 LastHashType
130 };
131 static_assert(LastHashType <= TooBig, "Too many types in HashType");
132
PGOHash(PGOHashVersion HashVersion)133 PGOHash(PGOHashVersion HashVersion)
134 : Working(0), Count(0), HashVersion(HashVersion), MD5() {}
135 void combine(HashType Type);
136 uint64_t finalize();
getHashVersion() const137 PGOHashVersion getHashVersion() const { return HashVersion; }
138 };
139 const int PGOHash::NumBitsPerType;
140 const unsigned PGOHash::NumTypesPerWord;
141 const unsigned PGOHash::TooBig;
142
143 /// Get the PGO hash version used in the given indexed profile.
getPGOHashVersion(llvm::IndexedInstrProfReader * PGOReader,CodeGenModule & CGM)144 static PGOHashVersion getPGOHashVersion(llvm::IndexedInstrProfReader *PGOReader,
145 CodeGenModule &CGM) {
146 if (PGOReader->getVersion() <= 4)
147 return PGO_HASH_V1;
148 if (PGOReader->getVersion() <= 5)
149 return PGO_HASH_V2;
150 return PGO_HASH_V3;
151 }
152
153 /// A RecursiveASTVisitor that fills a map of statements to PGO counters.
154 struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> {
155 using Base = RecursiveASTVisitor<MapRegionCounters>;
156
157 /// The next counter value to assign.
158 unsigned NextCounter;
159 /// The function hash.
160 PGOHash Hash;
161 /// The map of statements to counters.
162 llvm::DenseMap<const Stmt *, unsigned> &CounterMap;
163
MapRegionCounters__anon20f2e76b0111::MapRegionCounters164 MapRegionCounters(PGOHashVersion HashVersion,
165 llvm::DenseMap<const Stmt *, unsigned> &CounterMap)
166 : NextCounter(0), Hash(HashVersion), CounterMap(CounterMap) {}
167
168 // Blocks and lambdas are handled as separate functions, so we need not
169 // traverse them in the parent context.
TraverseBlockExpr__anon20f2e76b0111::MapRegionCounters170 bool TraverseBlockExpr(BlockExpr *BE) { return true; }
TraverseLambdaExpr__anon20f2e76b0111::MapRegionCounters171 bool TraverseLambdaExpr(LambdaExpr *LE) {
172 // Traverse the captures, but not the body.
173 for (auto C : zip(LE->captures(), LE->capture_inits()))
174 TraverseLambdaCapture(LE, &std::get<0>(C), std::get<1>(C));
175 return true;
176 }
TraverseCapturedStmt__anon20f2e76b0111::MapRegionCounters177 bool TraverseCapturedStmt(CapturedStmt *CS) { return true; }
178
VisitDecl__anon20f2e76b0111::MapRegionCounters179 bool VisitDecl(const Decl *D) {
180 switch (D->getKind()) {
181 default:
182 break;
183 case Decl::Function:
184 case Decl::CXXMethod:
185 case Decl::CXXConstructor:
186 case Decl::CXXDestructor:
187 case Decl::CXXConversion:
188 case Decl::ObjCMethod:
189 case Decl::Block:
190 case Decl::Captured:
191 CounterMap[D->getBody()] = NextCounter++;
192 break;
193 }
194 return true;
195 }
196
197 /// If \p S gets a fresh counter, update the counter mappings. Return the
198 /// V1 hash of \p S.
updateCounterMappings__anon20f2e76b0111::MapRegionCounters199 PGOHash::HashType updateCounterMappings(Stmt *S) {
200 auto Type = getHashType(PGO_HASH_V1, S);
201 if (Type != PGOHash::None)
202 CounterMap[S] = NextCounter++;
203 return Type;
204 }
205
206 /// Include \p S in the function hash.
VisitStmt__anon20f2e76b0111::MapRegionCounters207 bool VisitStmt(Stmt *S) {
208 auto Type = updateCounterMappings(S);
209 if (Hash.getHashVersion() != PGO_HASH_V1)
210 Type = getHashType(Hash.getHashVersion(), S);
211 if (Type != PGOHash::None)
212 Hash.combine(Type);
213 return true;
214 }
215
TraverseIfStmt__anon20f2e76b0111::MapRegionCounters216 bool TraverseIfStmt(IfStmt *If) {
217 // If we used the V1 hash, use the default traversal.
218 if (Hash.getHashVersion() == PGO_HASH_V1)
219 return Base::TraverseIfStmt(If);
220
221 // Otherwise, keep track of which branch we're in while traversing.
222 VisitStmt(If);
223 for (Stmt *CS : If->children()) {
224 if (!CS)
225 continue;
226 if (CS == If->getThen())
227 Hash.combine(PGOHash::IfThenBranch);
228 else if (CS == If->getElse())
229 Hash.combine(PGOHash::IfElseBranch);
230 TraverseStmt(CS);
231 }
232 Hash.combine(PGOHash::EndOfScope);
233 return true;
234 }
235
236 // If the statement type \p N is nestable, and its nesting impacts profile
237 // stability, define a custom traversal which tracks the end of the statement
238 // in the hash (provided we're not using the V1 hash).
239 #define DEFINE_NESTABLE_TRAVERSAL(N) \
240 bool Traverse##N(N *S) { \
241 Base::Traverse##N(S); \
242 if (Hash.getHashVersion() != PGO_HASH_V1) \
243 Hash.combine(PGOHash::EndOfScope); \
244 return true; \
245 }
246
247 DEFINE_NESTABLE_TRAVERSAL(WhileStmt)
DEFINE_NESTABLE_TRAVERSAL__anon20f2e76b0111::MapRegionCounters248 DEFINE_NESTABLE_TRAVERSAL(DoStmt)
249 DEFINE_NESTABLE_TRAVERSAL(ForStmt)
250 DEFINE_NESTABLE_TRAVERSAL(CXXForRangeStmt)
251 DEFINE_NESTABLE_TRAVERSAL(ObjCForCollectionStmt)
252 DEFINE_NESTABLE_TRAVERSAL(CXXTryStmt)
253 DEFINE_NESTABLE_TRAVERSAL(CXXCatchStmt)
254
255 /// Get version \p HashVersion of the PGO hash for \p S.
256 PGOHash::HashType getHashType(PGOHashVersion HashVersion, const Stmt *S) {
257 switch (S->getStmtClass()) {
258 default:
259 break;
260 case Stmt::LabelStmtClass:
261 return PGOHash::LabelStmt;
262 case Stmt::WhileStmtClass:
263 return PGOHash::WhileStmt;
264 case Stmt::DoStmtClass:
265 return PGOHash::DoStmt;
266 case Stmt::ForStmtClass:
267 return PGOHash::ForStmt;
268 case Stmt::CXXForRangeStmtClass:
269 return PGOHash::CXXForRangeStmt;
270 case Stmt::ObjCForCollectionStmtClass:
271 return PGOHash::ObjCForCollectionStmt;
272 case Stmt::SwitchStmtClass:
273 return PGOHash::SwitchStmt;
274 case Stmt::CaseStmtClass:
275 return PGOHash::CaseStmt;
276 case Stmt::DefaultStmtClass:
277 return PGOHash::DefaultStmt;
278 case Stmt::IfStmtClass:
279 return PGOHash::IfStmt;
280 case Stmt::CXXTryStmtClass:
281 return PGOHash::CXXTryStmt;
282 case Stmt::CXXCatchStmtClass:
283 return PGOHash::CXXCatchStmt;
284 case Stmt::ConditionalOperatorClass:
285 return PGOHash::ConditionalOperator;
286 case Stmt::BinaryConditionalOperatorClass:
287 return PGOHash::BinaryConditionalOperator;
288 case Stmt::BinaryOperatorClass: {
289 const BinaryOperator *BO = cast<BinaryOperator>(S);
290 if (BO->getOpcode() == BO_LAnd)
291 return PGOHash::BinaryOperatorLAnd;
292 if (BO->getOpcode() == BO_LOr)
293 return PGOHash::BinaryOperatorLOr;
294 if (HashVersion >= PGO_HASH_V2) {
295 switch (BO->getOpcode()) {
296 default:
297 break;
298 case BO_LT:
299 return PGOHash::BinaryOperatorLT;
300 case BO_GT:
301 return PGOHash::BinaryOperatorGT;
302 case BO_LE:
303 return PGOHash::BinaryOperatorLE;
304 case BO_GE:
305 return PGOHash::BinaryOperatorGE;
306 case BO_EQ:
307 return PGOHash::BinaryOperatorEQ;
308 case BO_NE:
309 return PGOHash::BinaryOperatorNE;
310 }
311 }
312 break;
313 }
314 }
315
316 if (HashVersion >= PGO_HASH_V2) {
317 switch (S->getStmtClass()) {
318 default:
319 break;
320 case Stmt::GotoStmtClass:
321 return PGOHash::GotoStmt;
322 case Stmt::IndirectGotoStmtClass:
323 return PGOHash::IndirectGotoStmt;
324 case Stmt::BreakStmtClass:
325 return PGOHash::BreakStmt;
326 case Stmt::ContinueStmtClass:
327 return PGOHash::ContinueStmt;
328 case Stmt::ReturnStmtClass:
329 return PGOHash::ReturnStmt;
330 case Stmt::CXXThrowExprClass:
331 return PGOHash::ThrowExpr;
332 case Stmt::UnaryOperatorClass: {
333 const UnaryOperator *UO = cast<UnaryOperator>(S);
334 if (UO->getOpcode() == UO_LNot)
335 return PGOHash::UnaryOperatorLNot;
336 break;
337 }
338 }
339 }
340
341 return PGOHash::None;
342 }
343 };
344
345 /// A StmtVisitor that propagates the raw counts through the AST and
346 /// records the count at statements where the value may change.
347 struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
348 /// PGO state.
349 CodeGenPGO &PGO;
350
351 /// A flag that is set when the current count should be recorded on the
352 /// next statement, such as at the exit of a loop.
353 bool RecordNextStmtCount;
354
355 /// The count at the current location in the traversal.
356 uint64_t CurrentCount;
357
358 /// The map of statements to count values.
359 llvm::DenseMap<const Stmt *, uint64_t> &CountMap;
360
361 /// BreakContinueStack - Keep counts of breaks and continues inside loops.
362 struct BreakContinue {
363 uint64_t BreakCount;
364 uint64_t ContinueCount;
BreakContinue__anon20f2e76b0111::ComputeRegionCounts::BreakContinue365 BreakContinue() : BreakCount(0), ContinueCount(0) {}
366 };
367 SmallVector<BreakContinue, 8> BreakContinueStack;
368
ComputeRegionCounts__anon20f2e76b0111::ComputeRegionCounts369 ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap,
370 CodeGenPGO &PGO)
371 : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {}
372
RecordStmtCount__anon20f2e76b0111::ComputeRegionCounts373 void RecordStmtCount(const Stmt *S) {
374 if (RecordNextStmtCount) {
375 CountMap[S] = CurrentCount;
376 RecordNextStmtCount = false;
377 }
378 }
379
380 /// Set and return the current count.
setCount__anon20f2e76b0111::ComputeRegionCounts381 uint64_t setCount(uint64_t Count) {
382 CurrentCount = Count;
383 return Count;
384 }
385
VisitStmt__anon20f2e76b0111::ComputeRegionCounts386 void VisitStmt(const Stmt *S) {
387 RecordStmtCount(S);
388 for (const Stmt *Child : S->children())
389 if (Child)
390 this->Visit(Child);
391 }
392
VisitFunctionDecl__anon20f2e76b0111::ComputeRegionCounts393 void VisitFunctionDecl(const FunctionDecl *D) {
394 // Counter tracks entry to the function body.
395 uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
396 CountMap[D->getBody()] = BodyCount;
397 Visit(D->getBody());
398 }
399
400 // Skip lambda expressions. We visit these as FunctionDecls when we're
401 // generating them and aren't interested in the body when generating a
402 // parent context.
VisitLambdaExpr__anon20f2e76b0111::ComputeRegionCounts403 void VisitLambdaExpr(const LambdaExpr *LE) {}
404
VisitCapturedDecl__anon20f2e76b0111::ComputeRegionCounts405 void VisitCapturedDecl(const CapturedDecl *D) {
406 // Counter tracks entry to the capture body.
407 uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
408 CountMap[D->getBody()] = BodyCount;
409 Visit(D->getBody());
410 }
411
VisitObjCMethodDecl__anon20f2e76b0111::ComputeRegionCounts412 void VisitObjCMethodDecl(const ObjCMethodDecl *D) {
413 // Counter tracks entry to the method body.
414 uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
415 CountMap[D->getBody()] = BodyCount;
416 Visit(D->getBody());
417 }
418
VisitBlockDecl__anon20f2e76b0111::ComputeRegionCounts419 void VisitBlockDecl(const BlockDecl *D) {
420 // Counter tracks entry to the block body.
421 uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
422 CountMap[D->getBody()] = BodyCount;
423 Visit(D->getBody());
424 }
425
VisitReturnStmt__anon20f2e76b0111::ComputeRegionCounts426 void VisitReturnStmt(const ReturnStmt *S) {
427 RecordStmtCount(S);
428 if (S->getRetValue())
429 Visit(S->getRetValue());
430 CurrentCount = 0;
431 RecordNextStmtCount = true;
432 }
433
VisitCXXThrowExpr__anon20f2e76b0111::ComputeRegionCounts434 void VisitCXXThrowExpr(const CXXThrowExpr *E) {
435 RecordStmtCount(E);
436 if (E->getSubExpr())
437 Visit(E->getSubExpr());
438 CurrentCount = 0;
439 RecordNextStmtCount = true;
440 }
441
VisitGotoStmt__anon20f2e76b0111::ComputeRegionCounts442 void VisitGotoStmt(const GotoStmt *S) {
443 RecordStmtCount(S);
444 CurrentCount = 0;
445 RecordNextStmtCount = true;
446 }
447
VisitLabelStmt__anon20f2e76b0111::ComputeRegionCounts448 void VisitLabelStmt(const LabelStmt *S) {
449 RecordNextStmtCount = false;
450 // Counter tracks the block following the label.
451 uint64_t BlockCount = setCount(PGO.getRegionCount(S));
452 CountMap[S] = BlockCount;
453 Visit(S->getSubStmt());
454 }
455
VisitBreakStmt__anon20f2e76b0111::ComputeRegionCounts456 void VisitBreakStmt(const BreakStmt *S) {
457 RecordStmtCount(S);
458 assert(!BreakContinueStack.empty() && "break not in a loop or switch!");
459 BreakContinueStack.back().BreakCount += CurrentCount;
460 CurrentCount = 0;
461 RecordNextStmtCount = true;
462 }
463
VisitContinueStmt__anon20f2e76b0111::ComputeRegionCounts464 void VisitContinueStmt(const ContinueStmt *S) {
465 RecordStmtCount(S);
466 assert(!BreakContinueStack.empty() && "continue stmt not in a loop!");
467 BreakContinueStack.back().ContinueCount += CurrentCount;
468 CurrentCount = 0;
469 RecordNextStmtCount = true;
470 }
471
VisitWhileStmt__anon20f2e76b0111::ComputeRegionCounts472 void VisitWhileStmt(const WhileStmt *S) {
473 RecordStmtCount(S);
474 uint64_t ParentCount = CurrentCount;
475
476 BreakContinueStack.push_back(BreakContinue());
477 // Visit the body region first so the break/continue adjustments can be
478 // included when visiting the condition.
479 uint64_t BodyCount = setCount(PGO.getRegionCount(S));
480 CountMap[S->getBody()] = CurrentCount;
481 Visit(S->getBody());
482 uint64_t BackedgeCount = CurrentCount;
483
484 // ...then go back and propagate counts through the condition. The count
485 // at the start of the condition is the sum of the incoming edges,
486 // the backedge from the end of the loop body, and the edges from
487 // continue statements.
488 BreakContinue BC = BreakContinueStack.pop_back_val();
489 uint64_t CondCount =
490 setCount(ParentCount + BackedgeCount + BC.ContinueCount);
491 CountMap[S->getCond()] = CondCount;
492 Visit(S->getCond());
493 setCount(BC.BreakCount + CondCount - BodyCount);
494 RecordNextStmtCount = true;
495 }
496
VisitDoStmt__anon20f2e76b0111::ComputeRegionCounts497 void VisitDoStmt(const DoStmt *S) {
498 RecordStmtCount(S);
499 uint64_t LoopCount = PGO.getRegionCount(S);
500
501 BreakContinueStack.push_back(BreakContinue());
502 // The count doesn't include the fallthrough from the parent scope. Add it.
503 uint64_t BodyCount = setCount(LoopCount + CurrentCount);
504 CountMap[S->getBody()] = BodyCount;
505 Visit(S->getBody());
506 uint64_t BackedgeCount = CurrentCount;
507
508 BreakContinue BC = BreakContinueStack.pop_back_val();
509 // The count at the start of the condition is equal to the count at the
510 // end of the body, plus any continues.
511 uint64_t CondCount = setCount(BackedgeCount + BC.ContinueCount);
512 CountMap[S->getCond()] = CondCount;
513 Visit(S->getCond());
514 setCount(BC.BreakCount + CondCount - LoopCount);
515 RecordNextStmtCount = true;
516 }
517
VisitForStmt__anon20f2e76b0111::ComputeRegionCounts518 void VisitForStmt(const ForStmt *S) {
519 RecordStmtCount(S);
520 if (S->getInit())
521 Visit(S->getInit());
522
523 uint64_t ParentCount = CurrentCount;
524
525 BreakContinueStack.push_back(BreakContinue());
526 // Visit the body region first. (This is basically the same as a while
527 // loop; see further comments in VisitWhileStmt.)
528 uint64_t BodyCount = setCount(PGO.getRegionCount(S));
529 CountMap[S->getBody()] = BodyCount;
530 Visit(S->getBody());
531 uint64_t BackedgeCount = CurrentCount;
532 BreakContinue BC = BreakContinueStack.pop_back_val();
533
534 // The increment is essentially part of the body but it needs to include
535 // the count for all the continue statements.
536 if (S->getInc()) {
537 uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
538 CountMap[S->getInc()] = IncCount;
539 Visit(S->getInc());
540 }
541
542 // ...then go back and propagate counts through the condition.
543 uint64_t CondCount =
544 setCount(ParentCount + BackedgeCount + BC.ContinueCount);
545 if (S->getCond()) {
546 CountMap[S->getCond()] = CondCount;
547 Visit(S->getCond());
548 }
549 setCount(BC.BreakCount + CondCount - BodyCount);
550 RecordNextStmtCount = true;
551 }
552
VisitCXXForRangeStmt__anon20f2e76b0111::ComputeRegionCounts553 void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
554 RecordStmtCount(S);
555 if (S->getInit())
556 Visit(S->getInit());
557 Visit(S->getLoopVarStmt());
558 Visit(S->getRangeStmt());
559 Visit(S->getBeginStmt());
560 Visit(S->getEndStmt());
561
562 uint64_t ParentCount = CurrentCount;
563 BreakContinueStack.push_back(BreakContinue());
564 // Visit the body region first. (This is basically the same as a while
565 // loop; see further comments in VisitWhileStmt.)
566 uint64_t BodyCount = setCount(PGO.getRegionCount(S));
567 CountMap[S->getBody()] = BodyCount;
568 Visit(S->getBody());
569 uint64_t BackedgeCount = CurrentCount;
570 BreakContinue BC = BreakContinueStack.pop_back_val();
571
572 // The increment is essentially part of the body but it needs to include
573 // the count for all the continue statements.
574 uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
575 CountMap[S->getInc()] = IncCount;
576 Visit(S->getInc());
577
578 // ...then go back and propagate counts through the condition.
579 uint64_t CondCount =
580 setCount(ParentCount + BackedgeCount + BC.ContinueCount);
581 CountMap[S->getCond()] = CondCount;
582 Visit(S->getCond());
583 setCount(BC.BreakCount + CondCount - BodyCount);
584 RecordNextStmtCount = true;
585 }
586
VisitObjCForCollectionStmt__anon20f2e76b0111::ComputeRegionCounts587 void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
588 RecordStmtCount(S);
589 Visit(S->getElement());
590 uint64_t ParentCount = CurrentCount;
591 BreakContinueStack.push_back(BreakContinue());
592 // Counter tracks the body of the loop.
593 uint64_t BodyCount = setCount(PGO.getRegionCount(S));
594 CountMap[S->getBody()] = BodyCount;
595 Visit(S->getBody());
596 uint64_t BackedgeCount = CurrentCount;
597 BreakContinue BC = BreakContinueStack.pop_back_val();
598
599 setCount(BC.BreakCount + ParentCount + BackedgeCount + BC.ContinueCount -
600 BodyCount);
601 RecordNextStmtCount = true;
602 }
603
VisitSwitchStmt__anon20f2e76b0111::ComputeRegionCounts604 void VisitSwitchStmt(const SwitchStmt *S) {
605 RecordStmtCount(S);
606 if (S->getInit())
607 Visit(S->getInit());
608 Visit(S->getCond());
609 CurrentCount = 0;
610 BreakContinueStack.push_back(BreakContinue());
611 Visit(S->getBody());
612 // If the switch is inside a loop, add the continue counts.
613 BreakContinue BC = BreakContinueStack.pop_back_val();
614 if (!BreakContinueStack.empty())
615 BreakContinueStack.back().ContinueCount += BC.ContinueCount;
616 // Counter tracks the exit block of the switch.
617 setCount(PGO.getRegionCount(S));
618 RecordNextStmtCount = true;
619 }
620
VisitSwitchCase__anon20f2e76b0111::ComputeRegionCounts621 void VisitSwitchCase(const SwitchCase *S) {
622 RecordNextStmtCount = false;
623 // Counter for this particular case. This counts only jumps from the
624 // switch header and does not include fallthrough from the case before
625 // this one.
626 uint64_t CaseCount = PGO.getRegionCount(S);
627 setCount(CurrentCount + CaseCount);
628 // We need the count without fallthrough in the mapping, so it's more useful
629 // for branch probabilities.
630 CountMap[S] = CaseCount;
631 RecordNextStmtCount = true;
632 Visit(S->getSubStmt());
633 }
634
VisitIfStmt__anon20f2e76b0111::ComputeRegionCounts635 void VisitIfStmt(const IfStmt *S) {
636 RecordStmtCount(S);
637 uint64_t ParentCount = CurrentCount;
638 if (S->getInit())
639 Visit(S->getInit());
640 Visit(S->getCond());
641
642 // Counter tracks the "then" part of an if statement. The count for
643 // the "else" part, if it exists, will be calculated from this counter.
644 uint64_t ThenCount = setCount(PGO.getRegionCount(S));
645 CountMap[S->getThen()] = ThenCount;
646 Visit(S->getThen());
647 uint64_t OutCount = CurrentCount;
648
649 uint64_t ElseCount = ParentCount - ThenCount;
650 if (S->getElse()) {
651 setCount(ElseCount);
652 CountMap[S->getElse()] = ElseCount;
653 Visit(S->getElse());
654 OutCount += CurrentCount;
655 } else
656 OutCount += ElseCount;
657 setCount(OutCount);
658 RecordNextStmtCount = true;
659 }
660
VisitCXXTryStmt__anon20f2e76b0111::ComputeRegionCounts661 void VisitCXXTryStmt(const CXXTryStmt *S) {
662 RecordStmtCount(S);
663 Visit(S->getTryBlock());
664 for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
665 Visit(S->getHandler(I));
666 // Counter tracks the continuation block of the try statement.
667 setCount(PGO.getRegionCount(S));
668 RecordNextStmtCount = true;
669 }
670
VisitCXXCatchStmt__anon20f2e76b0111::ComputeRegionCounts671 void VisitCXXCatchStmt(const CXXCatchStmt *S) {
672 RecordNextStmtCount = false;
673 // Counter tracks the catch statement's handler block.
674 uint64_t CatchCount = setCount(PGO.getRegionCount(S));
675 CountMap[S] = CatchCount;
676 Visit(S->getHandlerBlock());
677 }
678
VisitAbstractConditionalOperator__anon20f2e76b0111::ComputeRegionCounts679 void VisitAbstractConditionalOperator(const AbstractConditionalOperator *E) {
680 RecordStmtCount(E);
681 uint64_t ParentCount = CurrentCount;
682 Visit(E->getCond());
683
684 // Counter tracks the "true" part of a conditional operator. The
685 // count in the "false" part will be calculated from this counter.
686 uint64_t TrueCount = setCount(PGO.getRegionCount(E));
687 CountMap[E->getTrueExpr()] = TrueCount;
688 Visit(E->getTrueExpr());
689 uint64_t OutCount = CurrentCount;
690
691 uint64_t FalseCount = setCount(ParentCount - TrueCount);
692 CountMap[E->getFalseExpr()] = FalseCount;
693 Visit(E->getFalseExpr());
694 OutCount += CurrentCount;
695
696 setCount(OutCount);
697 RecordNextStmtCount = true;
698 }
699
VisitBinLAnd__anon20f2e76b0111::ComputeRegionCounts700 void VisitBinLAnd(const BinaryOperator *E) {
701 RecordStmtCount(E);
702 uint64_t ParentCount = CurrentCount;
703 Visit(E->getLHS());
704 // Counter tracks the right hand side of a logical and operator.
705 uint64_t RHSCount = setCount(PGO.getRegionCount(E));
706 CountMap[E->getRHS()] = RHSCount;
707 Visit(E->getRHS());
708 setCount(ParentCount + RHSCount - CurrentCount);
709 RecordNextStmtCount = true;
710 }
711
VisitBinLOr__anon20f2e76b0111::ComputeRegionCounts712 void VisitBinLOr(const BinaryOperator *E) {
713 RecordStmtCount(E);
714 uint64_t ParentCount = CurrentCount;
715 Visit(E->getLHS());
716 // Counter tracks the right hand side of a logical or operator.
717 uint64_t RHSCount = setCount(PGO.getRegionCount(E));
718 CountMap[E->getRHS()] = RHSCount;
719 Visit(E->getRHS());
720 setCount(ParentCount + RHSCount - CurrentCount);
721 RecordNextStmtCount = true;
722 }
723 };
724 } // end anonymous namespace
725
combine(HashType Type)726 void PGOHash::combine(HashType Type) {
727 // Check that we never combine 0 and only have six bits.
728 assert(Type && "Hash is invalid: unexpected type 0");
729 assert(unsigned(Type) < TooBig && "Hash is invalid: too many types");
730
731 // Pass through MD5 if enough work has built up.
732 if (Count && Count % NumTypesPerWord == 0) {
733 using namespace llvm::support;
734 uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working);
735 MD5.update(llvm::makeArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
736 Working = 0;
737 }
738
739 // Accumulate the current type.
740 ++Count;
741 Working = Working << NumBitsPerType | Type;
742 }
743
finalize()744 uint64_t PGOHash::finalize() {
745 // Use Working as the hash directly if we never used MD5.
746 if (Count <= NumTypesPerWord)
747 // No need to byte swap here, since none of the math was endian-dependent.
748 // This number will be byte-swapped as required on endianness transitions,
749 // so we will see the same value on the other side.
750 return Working;
751
752 // Check for remaining work in Working.
753 if (Working) {
754 // Keep the buggy behavior from v1 and v2 for backward-compatibility. This
755 // is buggy because it converts a uint64_t into an array of uint8_t.
756 if (HashVersion < PGO_HASH_V3) {
757 MD5.update({(uint8_t)Working});
758 } else {
759 using namespace llvm::support;
760 uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working);
761 MD5.update(llvm::makeArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
762 }
763 }
764
765 // Finalize the MD5 and return the hash.
766 llvm::MD5::MD5Result Result;
767 MD5.final(Result);
768 return Result.low();
769 }
770
assignRegionCounters(GlobalDecl GD,llvm::Function * Fn)771 void CodeGenPGO::assignRegionCounters(GlobalDecl GD, llvm::Function *Fn) {
772 const Decl *D = GD.getDecl();
773 if (!D->hasBody())
774 return;
775
776 bool InstrumentRegions = CGM.getCodeGenOpts().hasProfileClangInstr();
777 llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
778 if (!InstrumentRegions && !PGOReader)
779 return;
780 if (D->isImplicit())
781 return;
782 // Constructors and destructors may be represented by several functions in IR.
783 // If so, instrument only base variant, others are implemented by delegation
784 // to the base one, it would be counted twice otherwise.
785 if (CGM.getTarget().getCXXABI().hasConstructorVariants()) {
786 if (const auto *CCD = dyn_cast<CXXConstructorDecl>(D))
787 if (GD.getCtorType() != Ctor_Base &&
788 CodeGenFunction::IsConstructorDelegationValid(CCD))
789 return;
790 }
791 if (isa<CXXDestructorDecl>(D) && GD.getDtorType() != Dtor_Base)
792 return;
793
794 CGM.ClearUnusedCoverageMapping(D);
795 setFuncName(Fn);
796
797 mapRegionCounters(D);
798 if (CGM.getCodeGenOpts().CoverageMapping)
799 emitCounterRegionMapping(D);
800 if (PGOReader) {
801 SourceManager &SM = CGM.getContext().getSourceManager();
802 loadRegionCounts(PGOReader, SM.isInMainFile(D->getLocation()));
803 computeRegionCounts(D);
804 applyFunctionAttributes(PGOReader, Fn);
805 }
806 }
807
mapRegionCounters(const Decl * D)808 void CodeGenPGO::mapRegionCounters(const Decl *D) {
809 // Use the latest hash version when inserting instrumentation, but use the
810 // version in the indexed profile if we're reading PGO data.
811 PGOHashVersion HashVersion = PGO_HASH_LATEST;
812 if (auto *PGOReader = CGM.getPGOReader())
813 HashVersion = getPGOHashVersion(PGOReader, CGM);
814
815 RegionCounterMap.reset(new llvm::DenseMap<const Stmt *, unsigned>);
816 MapRegionCounters Walker(HashVersion, *RegionCounterMap);
817 if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
818 Walker.TraverseDecl(const_cast<FunctionDecl *>(FD));
819 else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
820 Walker.TraverseDecl(const_cast<ObjCMethodDecl *>(MD));
821 else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
822 Walker.TraverseDecl(const_cast<BlockDecl *>(BD));
823 else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
824 Walker.TraverseDecl(const_cast<CapturedDecl *>(CD));
825 assert(Walker.NextCounter > 0 && "no entry counter mapped for decl");
826 NumRegionCounters = Walker.NextCounter;
827 FunctionHash = Walker.Hash.finalize();
828 }
829
skipRegionMappingForDecl(const Decl * D)830 bool CodeGenPGO::skipRegionMappingForDecl(const Decl *D) {
831 if (!D->getBody())
832 return true;
833
834 // Don't map the functions in system headers.
835 const auto &SM = CGM.getContext().getSourceManager();
836 auto Loc = D->getBody()->getBeginLoc();
837 return SM.isInSystemHeader(Loc);
838 }
839
emitCounterRegionMapping(const Decl * D)840 void CodeGenPGO::emitCounterRegionMapping(const Decl *D) {
841 if (skipRegionMappingForDecl(D))
842 return;
843
844 std::string CoverageMapping;
845 llvm::raw_string_ostream OS(CoverageMapping);
846 CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
847 CGM.getContext().getSourceManager(),
848 CGM.getLangOpts(), RegionCounterMap.get());
849 MappingGen.emitCounterMapping(D, OS);
850 OS.flush();
851
852 if (CoverageMapping.empty())
853 return;
854
855 CGM.getCoverageMapping()->addFunctionMappingRecord(
856 FuncNameVar, FuncName, FunctionHash, CoverageMapping);
857 }
858
859 void
emitEmptyCounterMapping(const Decl * D,StringRef Name,llvm::GlobalValue::LinkageTypes Linkage)860 CodeGenPGO::emitEmptyCounterMapping(const Decl *D, StringRef Name,
861 llvm::GlobalValue::LinkageTypes Linkage) {
862 if (skipRegionMappingForDecl(D))
863 return;
864
865 std::string CoverageMapping;
866 llvm::raw_string_ostream OS(CoverageMapping);
867 CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
868 CGM.getContext().getSourceManager(),
869 CGM.getLangOpts());
870 MappingGen.emitEmptyMapping(D, OS);
871 OS.flush();
872
873 if (CoverageMapping.empty())
874 return;
875
876 setFuncName(Name, Linkage);
877 CGM.getCoverageMapping()->addFunctionMappingRecord(
878 FuncNameVar, FuncName, FunctionHash, CoverageMapping, false);
879 }
880
computeRegionCounts(const Decl * D)881 void CodeGenPGO::computeRegionCounts(const Decl *D) {
882 StmtCountMap.reset(new llvm::DenseMap<const Stmt *, uint64_t>);
883 ComputeRegionCounts Walker(*StmtCountMap, *this);
884 if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
885 Walker.VisitFunctionDecl(FD);
886 else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
887 Walker.VisitObjCMethodDecl(MD);
888 else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
889 Walker.VisitBlockDecl(BD);
890 else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
891 Walker.VisitCapturedDecl(const_cast<CapturedDecl *>(CD));
892 }
893
894 void
applyFunctionAttributes(llvm::IndexedInstrProfReader * PGOReader,llvm::Function * Fn)895 CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader,
896 llvm::Function *Fn) {
897 if (!haveRegionCounts())
898 return;
899
900 uint64_t FunctionCount = getRegionCount(nullptr);
901 Fn->setEntryCount(FunctionCount);
902 }
903
emitCounterIncrement(CGBuilderTy & Builder,const Stmt * S,llvm::Value * StepV)904 void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, const Stmt *S,
905 llvm::Value *StepV) {
906 if (!CGM.getCodeGenOpts().hasProfileClangInstr() || !RegionCounterMap)
907 return;
908 if (!Builder.GetInsertBlock())
909 return;
910
911 unsigned Counter = (*RegionCounterMap)[S];
912 auto *I8PtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext());
913
914 llvm::Value *Args[] = {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy),
915 Builder.getInt64(FunctionHash),
916 Builder.getInt32(NumRegionCounters),
917 Builder.getInt32(Counter), StepV};
918 if (!StepV)
919 Builder.CreateCall(CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment),
920 makeArrayRef(Args, 4));
921 else
922 Builder.CreateCall(
923 CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment_step),
924 makeArrayRef(Args));
925 }
926
927 // This method either inserts a call to the profile run-time during
928 // instrumentation or puts profile data into metadata for PGO use.
valueProfile(CGBuilderTy & Builder,uint32_t ValueKind,llvm::Instruction * ValueSite,llvm::Value * ValuePtr)929 void CodeGenPGO::valueProfile(CGBuilderTy &Builder, uint32_t ValueKind,
930 llvm::Instruction *ValueSite, llvm::Value *ValuePtr) {
931
932 if (!EnableValueProfiling)
933 return;
934
935 if (!ValuePtr || !ValueSite || !Builder.GetInsertBlock())
936 return;
937
938 if (isa<llvm::Constant>(ValuePtr))
939 return;
940
941 bool InstrumentValueSites = CGM.getCodeGenOpts().hasProfileClangInstr();
942 if (InstrumentValueSites && RegionCounterMap) {
943 auto BuilderInsertPoint = Builder.saveIP();
944 Builder.SetInsertPoint(ValueSite);
945 llvm::Value *Args[5] = {
946 llvm::ConstantExpr::getBitCast(FuncNameVar, Builder.getInt8PtrTy()),
947 Builder.getInt64(FunctionHash),
948 Builder.CreatePtrToInt(ValuePtr, Builder.getInt64Ty()),
949 Builder.getInt32(ValueKind),
950 Builder.getInt32(NumValueSites[ValueKind]++)
951 };
952 Builder.CreateCall(
953 CGM.getIntrinsic(llvm::Intrinsic::instrprof_value_profile), Args);
954 Builder.restoreIP(BuilderInsertPoint);
955 return;
956 }
957
958 llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
959 if (PGOReader && haveRegionCounts()) {
960 // We record the top most called three functions at each call site.
961 // Profile metadata contains "VP" string identifying this metadata
962 // as value profiling data, then a uint32_t value for the value profiling
963 // kind, a uint64_t value for the total number of times the call is
964 // executed, followed by the function hash and execution count (uint64_t)
965 // pairs for each function.
966 if (NumValueSites[ValueKind] >= ProfRecord->getNumValueSites(ValueKind))
967 return;
968
969 llvm::annotateValueSite(CGM.getModule(), *ValueSite, *ProfRecord,
970 (llvm::InstrProfValueKind)ValueKind,
971 NumValueSites[ValueKind]);
972
973 NumValueSites[ValueKind]++;
974 }
975 }
976
loadRegionCounts(llvm::IndexedInstrProfReader * PGOReader,bool IsInMainFile)977 void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader,
978 bool IsInMainFile) {
979 CGM.getPGOStats().addVisited(IsInMainFile);
980 RegionCounts.clear();
981 llvm::Expected<llvm::InstrProfRecord> RecordExpected =
982 PGOReader->getInstrProfRecord(FuncName, FunctionHash);
983 if (auto E = RecordExpected.takeError()) {
984 auto IPE = llvm::InstrProfError::take(std::move(E));
985 if (IPE == llvm::instrprof_error::unknown_function)
986 CGM.getPGOStats().addMissing(IsInMainFile);
987 else if (IPE == llvm::instrprof_error::hash_mismatch)
988 CGM.getPGOStats().addMismatched(IsInMainFile);
989 else if (IPE == llvm::instrprof_error::malformed)
990 // TODO: Consider a more specific warning for this case.
991 CGM.getPGOStats().addMismatched(IsInMainFile);
992 return;
993 }
994 ProfRecord =
995 std::make_unique<llvm::InstrProfRecord>(std::move(RecordExpected.get()));
996 RegionCounts = ProfRecord->Counts;
997 }
998
/// Compute the divisor used to narrow branch weights to 32 bits.
///
/// \p MaxWeight is the largest weight that will be scaled. The returned
/// divisor makes every weight no greater than \p MaxWeight, once divided,
/// strictly less than UINT32_MAX. Weights that already fit use a divisor of 1.
static uint64_t calculateWeightScale(uint64_t MaxWeight) {
  const uint64_t Limit = UINT32_MAX;
  if (MaxWeight >= Limit)
    return MaxWeight / Limit + 1;
  return 1;
}
1006
/// Scale an individual branch weight (and add 1).
///
/// Scale a 64-bit weight down to 32-bits using \c Scale.
///
/// According to Laplace's Rule of Succession, it is better to compute the
/// weight based on the count plus 1, so universally add 1 to the value.
///
/// \pre \c Scale is non-zero and was calculated by \a calculateWeightScale()
/// with a weight no *less* than \c Weight (typically the maximum of all the
/// weights being scaled together). If the scale came from a smaller weight,
/// \c Weight / \c Scale may not fit in 32 bits.
///
/// \returns Weight / Scale + 1, which the precondition guarantees is at most
/// UINT32_MAX.
static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) {
  assert(Scale && "scale by 0?");
  uint64_t Scaled = Weight / Scale + 1;
  assert(Scaled <= UINT32_MAX && "overflow 32-bits");
  // Narrowing is safe here; make the conversion explicit.
  return static_cast<uint32_t>(Scaled);
}
1022
createProfileWeights(uint64_t TrueCount,uint64_t FalseCount)1023 llvm::MDNode *CodeGenFunction::createProfileWeights(uint64_t TrueCount,
1024 uint64_t FalseCount) {
1025 // Check for empty weights.
1026 if (!TrueCount && !FalseCount)
1027 return nullptr;
1028
1029 // Calculate how to scale down to 32-bits.
1030 uint64_t Scale = calculateWeightScale(std::max(TrueCount, FalseCount));
1031
1032 llvm::MDBuilder MDHelper(CGM.getLLVMContext());
1033 return MDHelper.createBranchWeights(scaleBranchWeight(TrueCount, Scale),
1034 scaleBranchWeight(FalseCount, Scale));
1035 }
1036
1037 llvm::MDNode *
createProfileWeights(ArrayRef<uint64_t> Weights)1038 CodeGenFunction::createProfileWeights(ArrayRef<uint64_t> Weights) {
1039 // We need at least two elements to create meaningful weights.
1040 if (Weights.size() < 2)
1041 return nullptr;
1042
1043 // Check for empty weights.
1044 uint64_t MaxWeight = *std::max_element(Weights.begin(), Weights.end());
1045 if (MaxWeight == 0)
1046 return nullptr;
1047
1048 // Calculate how to scale down to 32-bits.
1049 uint64_t Scale = calculateWeightScale(MaxWeight);
1050
1051 SmallVector<uint32_t, 16> ScaledWeights;
1052 ScaledWeights.reserve(Weights.size());
1053 for (uint64_t W : Weights)
1054 ScaledWeights.push_back(scaleBranchWeight(W, Scale));
1055
1056 llvm::MDBuilder MDHelper(CGM.getLLVMContext());
1057 return MDHelper.createBranchWeights(ScaledWeights);
1058 }
1059
createProfileWeightsForLoop(const Stmt * Cond,uint64_t LoopCount)1060 llvm::MDNode *CodeGenFunction::createProfileWeightsForLoop(const Stmt *Cond,
1061 uint64_t LoopCount) {
1062 if (!PGO.haveRegionCounts())
1063 return nullptr;
1064 Optional<uint64_t> CondCount = PGO.getStmtCount(Cond);
1065 if (!CondCount || *CondCount == 0)
1066 return nullptr;
1067 return createProfileWeights(LoopCount,
1068 std::max(*CondCount, LoopCount) - LoopCount);
1069 }
1070