1 //===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // Instrumentation-based profile-guided optimization
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "CodeGenPGO.h"
15 #include "CodeGenFunction.h"
16 #include "CoverageMappingGen.h"
17 #include "clang/AST/RecursiveASTVisitor.h"
18 #include "clang/AST/StmtVisitor.h"
19 #include "llvm/IR/Intrinsics.h"
20 #include "llvm/IR/MDBuilder.h"
21 #include "llvm/ProfileData/InstrProfReader.h"
22 #include "llvm/Support/Endian.h"
23 #include "llvm/Support/FileSystem.h"
24 #include "llvm/Support/MD5.h"
25 
26 using namespace clang;
27 using namespace CodeGen;
28 
setFuncName(StringRef Name,llvm::GlobalValue::LinkageTypes Linkage)29 void CodeGenPGO::setFuncName(StringRef Name,
30                              llvm::GlobalValue::LinkageTypes Linkage) {
31   StringRef RawFuncName = Name;
32 
33   // Function names may be prefixed with a binary '1' to indicate
34   // that the backend should not modify the symbols due to any platform
35   // naming convention. Do not include that '1' in the PGO profile name.
36   if (RawFuncName[0] == '\1')
37     RawFuncName = RawFuncName.substr(1);
38 
39   FuncName = RawFuncName;
40   if (llvm::GlobalValue::isLocalLinkage(Linkage)) {
41     // For local symbols, prepend the main file name to distinguish them.
42     // Do not include the full path in the file name since there's no guarantee
43     // that it will stay the same, e.g., if the files are checked out from
44     // version control in different locations.
45     if (CGM.getCodeGenOpts().MainFileName.empty())
46       FuncName = FuncName.insert(0, "<unknown>:");
47     else
48       FuncName = FuncName.insert(0, CGM.getCodeGenOpts().MainFileName + ":");
49   }
50 
51   // If we're generating a profile, create a variable for the name.
52   if (CGM.getCodeGenOpts().ProfileInstrGenerate)
53     createFuncNameVar(Linkage);
54 }
55 
setFuncName(llvm::Function * Fn)56 void CodeGenPGO::setFuncName(llvm::Function *Fn) {
57   setFuncName(Fn->getName(), Fn->getLinkage());
58 }
59 
createFuncNameVar(llvm::GlobalValue::LinkageTypes Linkage)60 void CodeGenPGO::createFuncNameVar(llvm::GlobalValue::LinkageTypes Linkage) {
61   // Usually, we want to match the function's linkage, but
62   // available_externally and extern_weak both have the wrong semantics.
63   if (Linkage == llvm::GlobalValue::ExternalWeakLinkage)
64     Linkage = llvm::GlobalValue::LinkOnceAnyLinkage;
65   else if (Linkage == llvm::GlobalValue::AvailableExternallyLinkage)
66     Linkage = llvm::GlobalValue::LinkOnceODRLinkage;
67 
68   auto *Value =
69       llvm::ConstantDataArray::getString(CGM.getLLVMContext(), FuncName, false);
70   FuncNameVar =
71       new llvm::GlobalVariable(CGM.getModule(), Value->getType(), true, Linkage,
72                                Value, "__llvm_profile_name_" + FuncName);
73 
74   // Hide the symbol so that we correctly get a copy for each executable.
75   if (!llvm::GlobalValue::isLocalLinkage(FuncNameVar->getLinkage()))
76     FuncNameVar->setVisibility(llvm::GlobalValue::HiddenVisibility);
77 }
78 
79 namespace {
80 /// \brief Stable hasher for PGO region counters.
81 ///
82 /// PGOHash produces a stable hash of a given function's control flow.
83 ///
84 /// Changing the output of this hash will invalidate all previously generated
85 /// profiles -- i.e., don't do it.
86 ///
87 /// \note  When this hash does eventually change (years?), we still need to
88 /// support old hashes.  We'll need to pull in the version number from the
89 /// profile data format and use the matching hash function.
90 class PGOHash {
91   uint64_t Working;
92   unsigned Count;
93   llvm::MD5 MD5;
94 
95   static const int NumBitsPerType = 6;
96   static const unsigned NumTypesPerWord = sizeof(uint64_t) * 8 / NumBitsPerType;
97   static const unsigned TooBig = 1u << NumBitsPerType;
98 
99 public:
100   /// \brief Hash values for AST nodes.
101   ///
102   /// Distinct values for AST nodes that have region counters attached.
103   ///
104   /// These values must be stable.  All new members must be added at the end,
105   /// and no members should be removed.  Changing the enumeration value for an
106   /// AST node will affect the hash of every function that contains that node.
107   enum HashType : unsigned char {
108     None = 0,
109     LabelStmt = 1,
110     WhileStmt,
111     DoStmt,
112     ForStmt,
113     CXXForRangeStmt,
114     ObjCForCollectionStmt,
115     SwitchStmt,
116     CaseStmt,
117     DefaultStmt,
118     IfStmt,
119     CXXTryStmt,
120     CXXCatchStmt,
121     ConditionalOperator,
122     BinaryOperatorLAnd,
123     BinaryOperatorLOr,
124     BinaryConditionalOperator,
125 
126     // Keep this last.  It's for the static assert that follows.
127     LastHashType
128   };
129   static_assert(LastHashType <= TooBig, "Too many types in HashType");
130 
131   // TODO: When this format changes, take in a version number here, and use the
132   // old hash calculation for file formats that used the old hash.
PGOHash()133   PGOHash() : Working(0), Count(0) {}
134   void combine(HashType Type);
135   uint64_t finalize();
136 };
// Namespace-scope definitions for the static constants that PGOHash declares
// (and initializes) inside the class body.
const int PGOHash::NumBitsPerType;
const unsigned PGOHash::NumTypesPerWord;
const unsigned PGOHash::TooBig;
140 
141   /// A RecursiveASTVisitor that fills a map of statements to PGO counters.
142   struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> {
143     /// The next counter value to assign.
144     unsigned NextCounter;
145     /// The function hash.
146     PGOHash Hash;
147     /// The map of statements to counters.
148     llvm::DenseMap<const Stmt *, unsigned> &CounterMap;
149 
MapRegionCounters__anon65fb66840111::MapRegionCounters150     MapRegionCounters(llvm::DenseMap<const Stmt *, unsigned> &CounterMap)
151         : NextCounter(0), CounterMap(CounterMap) {}
152 
153     // Blocks and lambdas are handled as separate functions, so we need not
154     // traverse them in the parent context.
TraverseBlockExpr__anon65fb66840111::MapRegionCounters155     bool TraverseBlockExpr(BlockExpr *BE) { return true; }
TraverseLambdaBody__anon65fb66840111::MapRegionCounters156     bool TraverseLambdaBody(LambdaExpr *LE) { return true; }
TraverseCapturedStmt__anon65fb66840111::MapRegionCounters157     bool TraverseCapturedStmt(CapturedStmt *CS) { return true; }
158 
VisitDecl__anon65fb66840111::MapRegionCounters159     bool VisitDecl(const Decl *D) {
160       switch (D->getKind()) {
161       default:
162         break;
163       case Decl::Function:
164       case Decl::CXXMethod:
165       case Decl::CXXConstructor:
166       case Decl::CXXDestructor:
167       case Decl::CXXConversion:
168       case Decl::ObjCMethod:
169       case Decl::Block:
170       case Decl::Captured:
171         CounterMap[D->getBody()] = NextCounter++;
172         break;
173       }
174       return true;
175     }
176 
VisitStmt__anon65fb66840111::MapRegionCounters177     bool VisitStmt(const Stmt *S) {
178       auto Type = getHashType(S);
179       if (Type == PGOHash::None)
180         return true;
181 
182       CounterMap[S] = NextCounter++;
183       Hash.combine(Type);
184       return true;
185     }
getHashType__anon65fb66840111::MapRegionCounters186     PGOHash::HashType getHashType(const Stmt *S) {
187       switch (S->getStmtClass()) {
188       default:
189         break;
190       case Stmt::LabelStmtClass:
191         return PGOHash::LabelStmt;
192       case Stmt::WhileStmtClass:
193         return PGOHash::WhileStmt;
194       case Stmt::DoStmtClass:
195         return PGOHash::DoStmt;
196       case Stmt::ForStmtClass:
197         return PGOHash::ForStmt;
198       case Stmt::CXXForRangeStmtClass:
199         return PGOHash::CXXForRangeStmt;
200       case Stmt::ObjCForCollectionStmtClass:
201         return PGOHash::ObjCForCollectionStmt;
202       case Stmt::SwitchStmtClass:
203         return PGOHash::SwitchStmt;
204       case Stmt::CaseStmtClass:
205         return PGOHash::CaseStmt;
206       case Stmt::DefaultStmtClass:
207         return PGOHash::DefaultStmt;
208       case Stmt::IfStmtClass:
209         return PGOHash::IfStmt;
210       case Stmt::CXXTryStmtClass:
211         return PGOHash::CXXTryStmt;
212       case Stmt::CXXCatchStmtClass:
213         return PGOHash::CXXCatchStmt;
214       case Stmt::ConditionalOperatorClass:
215         return PGOHash::ConditionalOperator;
216       case Stmt::BinaryConditionalOperatorClass:
217         return PGOHash::BinaryConditionalOperator;
218       case Stmt::BinaryOperatorClass: {
219         const BinaryOperator *BO = cast<BinaryOperator>(S);
220         if (BO->getOpcode() == BO_LAnd)
221           return PGOHash::BinaryOperatorLAnd;
222         if (BO->getOpcode() == BO_LOr)
223           return PGOHash::BinaryOperatorLOr;
224         break;
225       }
226       }
227       return PGOHash::None;
228     }
229   };
230 
231   /// A StmtVisitor that propagates the raw counts through the AST and
232   /// records the count at statements where the value may change.
233   struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
234     /// PGO state.
235     CodeGenPGO &PGO;
236 
237     /// A flag that is set when the current count should be recorded on the
238     /// next statement, such as at the exit of a loop.
239     bool RecordNextStmtCount;
240 
241     /// The map of statements to count values.
242     llvm::DenseMap<const Stmt *, uint64_t> &CountMap;
243 
244     /// BreakContinueStack - Keep counts of breaks and continues inside loops.
245     struct BreakContinue {
246       uint64_t BreakCount;
247       uint64_t ContinueCount;
BreakContinue__anon65fb66840111::ComputeRegionCounts::BreakContinue248       BreakContinue() : BreakCount(0), ContinueCount(0) {}
249     };
250     SmallVector<BreakContinue, 8> BreakContinueStack;
251 
ComputeRegionCounts__anon65fb66840111::ComputeRegionCounts252     ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap,
253                         CodeGenPGO &PGO)
254         : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {}
255 
RecordStmtCount__anon65fb66840111::ComputeRegionCounts256     void RecordStmtCount(const Stmt *S) {
257       if (RecordNextStmtCount) {
258         CountMap[S] = PGO.getCurrentRegionCount();
259         RecordNextStmtCount = false;
260       }
261     }
262 
VisitStmt__anon65fb66840111::ComputeRegionCounts263     void VisitStmt(const Stmt *S) {
264       RecordStmtCount(S);
265       for (Stmt::const_child_range I = S->children(); I; ++I) {
266         if (*I)
267          this->Visit(*I);
268       }
269     }
270 
VisitFunctionDecl__anon65fb66840111::ComputeRegionCounts271     void VisitFunctionDecl(const FunctionDecl *D) {
272       // Counter tracks entry to the function body.
273       RegionCounter Cnt(PGO, D->getBody());
274       Cnt.beginRegion();
275       CountMap[D->getBody()] = PGO.getCurrentRegionCount();
276       Visit(D->getBody());
277     }
278 
279     // Skip lambda expressions. We visit these as FunctionDecls when we're
280     // generating them and aren't interested in the body when generating a
281     // parent context.
VisitLambdaExpr__anon65fb66840111::ComputeRegionCounts282     void VisitLambdaExpr(const LambdaExpr *LE) {}
283 
VisitCapturedDecl__anon65fb66840111::ComputeRegionCounts284     void VisitCapturedDecl(const CapturedDecl *D) {
285       // Counter tracks entry to the capture body.
286       RegionCounter Cnt(PGO, D->getBody());
287       Cnt.beginRegion();
288       CountMap[D->getBody()] = PGO.getCurrentRegionCount();
289       Visit(D->getBody());
290     }
291 
VisitObjCMethodDecl__anon65fb66840111::ComputeRegionCounts292     void VisitObjCMethodDecl(const ObjCMethodDecl *D) {
293       // Counter tracks entry to the method body.
294       RegionCounter Cnt(PGO, D->getBody());
295       Cnt.beginRegion();
296       CountMap[D->getBody()] = PGO.getCurrentRegionCount();
297       Visit(D->getBody());
298     }
299 
VisitBlockDecl__anon65fb66840111::ComputeRegionCounts300     void VisitBlockDecl(const BlockDecl *D) {
301       // Counter tracks entry to the block body.
302       RegionCounter Cnt(PGO, D->getBody());
303       Cnt.beginRegion();
304       CountMap[D->getBody()] = PGO.getCurrentRegionCount();
305       Visit(D->getBody());
306     }
307 
VisitReturnStmt__anon65fb66840111::ComputeRegionCounts308     void VisitReturnStmt(const ReturnStmt *S) {
309       RecordStmtCount(S);
310       if (S->getRetValue())
311         Visit(S->getRetValue());
312       PGO.setCurrentRegionUnreachable();
313       RecordNextStmtCount = true;
314     }
315 
VisitGotoStmt__anon65fb66840111::ComputeRegionCounts316     void VisitGotoStmt(const GotoStmt *S) {
317       RecordStmtCount(S);
318       PGO.setCurrentRegionUnreachable();
319       RecordNextStmtCount = true;
320     }
321 
VisitLabelStmt__anon65fb66840111::ComputeRegionCounts322     void VisitLabelStmt(const LabelStmt *S) {
323       RecordNextStmtCount = false;
324       // Counter tracks the block following the label.
325       RegionCounter Cnt(PGO, S);
326       Cnt.beginRegion();
327       CountMap[S] = PGO.getCurrentRegionCount();
328       Visit(S->getSubStmt());
329     }
330 
VisitBreakStmt__anon65fb66840111::ComputeRegionCounts331     void VisitBreakStmt(const BreakStmt *S) {
332       RecordStmtCount(S);
333       assert(!BreakContinueStack.empty() && "break not in a loop or switch!");
334       BreakContinueStack.back().BreakCount += PGO.getCurrentRegionCount();
335       PGO.setCurrentRegionUnreachable();
336       RecordNextStmtCount = true;
337     }
338 
VisitContinueStmt__anon65fb66840111::ComputeRegionCounts339     void VisitContinueStmt(const ContinueStmt *S) {
340       RecordStmtCount(S);
341       assert(!BreakContinueStack.empty() && "continue stmt not in a loop!");
342       BreakContinueStack.back().ContinueCount += PGO.getCurrentRegionCount();
343       PGO.setCurrentRegionUnreachable();
344       RecordNextStmtCount = true;
345     }
346 
VisitWhileStmt__anon65fb66840111::ComputeRegionCounts347     void VisitWhileStmt(const WhileStmt *S) {
348       RecordStmtCount(S);
349       // Counter tracks the body of the loop.
350       RegionCounter Cnt(PGO, S);
351       BreakContinueStack.push_back(BreakContinue());
352       // Visit the body region first so the break/continue adjustments can be
353       // included when visiting the condition.
354       Cnt.beginRegion();
355       CountMap[S->getBody()] = PGO.getCurrentRegionCount();
356       Visit(S->getBody());
357       Cnt.adjustForControlFlow();
358 
359       // ...then go back and propagate counts through the condition. The count
360       // at the start of the condition is the sum of the incoming edges,
361       // the backedge from the end of the loop body, and the edges from
362       // continue statements.
363       BreakContinue BC = BreakContinueStack.pop_back_val();
364       Cnt.setCurrentRegionCount(Cnt.getParentCount() +
365                                 Cnt.getAdjustedCount() + BC.ContinueCount);
366       CountMap[S->getCond()] = PGO.getCurrentRegionCount();
367       Visit(S->getCond());
368       Cnt.adjustForControlFlow();
369       Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
370       RecordNextStmtCount = true;
371     }
372 
VisitDoStmt__anon65fb66840111::ComputeRegionCounts373     void VisitDoStmt(const DoStmt *S) {
374       RecordStmtCount(S);
375       // Counter tracks the body of the loop.
376       RegionCounter Cnt(PGO, S);
377       BreakContinueStack.push_back(BreakContinue());
378       Cnt.beginRegion(/*AddIncomingFallThrough=*/true);
379       CountMap[S->getBody()] = PGO.getCurrentRegionCount();
380       Visit(S->getBody());
381       Cnt.adjustForControlFlow();
382 
383       BreakContinue BC = BreakContinueStack.pop_back_val();
384       // The count at the start of the condition is equal to the count at the
385       // end of the body. The adjusted count does not include either the
386       // fall-through count coming into the loop or the continue count, so add
387       // both of those separately. This is coincidentally the same equation as
388       // with while loops but for different reasons.
389       Cnt.setCurrentRegionCount(Cnt.getParentCount() +
390                                 Cnt.getAdjustedCount() + BC.ContinueCount);
391       CountMap[S->getCond()] = PGO.getCurrentRegionCount();
392       Visit(S->getCond());
393       Cnt.adjustForControlFlow();
394       Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
395       RecordNextStmtCount = true;
396     }
397 
VisitForStmt__anon65fb66840111::ComputeRegionCounts398     void VisitForStmt(const ForStmt *S) {
399       RecordStmtCount(S);
400       if (S->getInit())
401         Visit(S->getInit());
402       // Counter tracks the body of the loop.
403       RegionCounter Cnt(PGO, S);
404       BreakContinueStack.push_back(BreakContinue());
405       // Visit the body region first. (This is basically the same as a while
406       // loop; see further comments in VisitWhileStmt.)
407       Cnt.beginRegion();
408       CountMap[S->getBody()] = PGO.getCurrentRegionCount();
409       Visit(S->getBody());
410       Cnt.adjustForControlFlow();
411 
412       // The increment is essentially part of the body but it needs to include
413       // the count for all the continue statements.
414       if (S->getInc()) {
415         Cnt.setCurrentRegionCount(PGO.getCurrentRegionCount() +
416                                   BreakContinueStack.back().ContinueCount);
417         CountMap[S->getInc()] = PGO.getCurrentRegionCount();
418         Visit(S->getInc());
419         Cnt.adjustForControlFlow();
420       }
421 
422       BreakContinue BC = BreakContinueStack.pop_back_val();
423 
424       // ...then go back and propagate counts through the condition.
425       if (S->getCond()) {
426         Cnt.setCurrentRegionCount(Cnt.getParentCount() +
427                                   Cnt.getAdjustedCount() +
428                                   BC.ContinueCount);
429         CountMap[S->getCond()] = PGO.getCurrentRegionCount();
430         Visit(S->getCond());
431         Cnt.adjustForControlFlow();
432       }
433       Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
434       RecordNextStmtCount = true;
435     }
436 
VisitCXXForRangeStmt__anon65fb66840111::ComputeRegionCounts437     void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
438       RecordStmtCount(S);
439       Visit(S->getRangeStmt());
440       Visit(S->getBeginEndStmt());
441       // Counter tracks the body of the loop.
442       RegionCounter Cnt(PGO, S);
443       BreakContinueStack.push_back(BreakContinue());
444       // Visit the body region first. (This is basically the same as a while
445       // loop; see further comments in VisitWhileStmt.)
446       Cnt.beginRegion();
447       CountMap[S->getLoopVarStmt()] = PGO.getCurrentRegionCount();
448       Visit(S->getLoopVarStmt());
449       Visit(S->getBody());
450       Cnt.adjustForControlFlow();
451 
452       // The increment is essentially part of the body but it needs to include
453       // the count for all the continue statements.
454       Cnt.setCurrentRegionCount(PGO.getCurrentRegionCount() +
455                                 BreakContinueStack.back().ContinueCount);
456       CountMap[S->getInc()] = PGO.getCurrentRegionCount();
457       Visit(S->getInc());
458       Cnt.adjustForControlFlow();
459 
460       BreakContinue BC = BreakContinueStack.pop_back_val();
461 
462       // ...then go back and propagate counts through the condition.
463       Cnt.setCurrentRegionCount(Cnt.getParentCount() +
464                                 Cnt.getAdjustedCount() +
465                                 BC.ContinueCount);
466       CountMap[S->getCond()] = PGO.getCurrentRegionCount();
467       Visit(S->getCond());
468       Cnt.adjustForControlFlow();
469       Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
470       RecordNextStmtCount = true;
471     }
472 
VisitObjCForCollectionStmt__anon65fb66840111::ComputeRegionCounts473     void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
474       RecordStmtCount(S);
475       Visit(S->getElement());
476       // Counter tracks the body of the loop.
477       RegionCounter Cnt(PGO, S);
478       BreakContinueStack.push_back(BreakContinue());
479       Cnt.beginRegion();
480       CountMap[S->getBody()] = PGO.getCurrentRegionCount();
481       Visit(S->getBody());
482       BreakContinue BC = BreakContinueStack.pop_back_val();
483       Cnt.adjustForControlFlow();
484       Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
485       RecordNextStmtCount = true;
486     }
487 
VisitSwitchStmt__anon65fb66840111::ComputeRegionCounts488     void VisitSwitchStmt(const SwitchStmt *S) {
489       RecordStmtCount(S);
490       Visit(S->getCond());
491       PGO.setCurrentRegionUnreachable();
492       BreakContinueStack.push_back(BreakContinue());
493       Visit(S->getBody());
494       // If the switch is inside a loop, add the continue counts.
495       BreakContinue BC = BreakContinueStack.pop_back_val();
496       if (!BreakContinueStack.empty())
497         BreakContinueStack.back().ContinueCount += BC.ContinueCount;
498       // Counter tracks the exit block of the switch.
499       RegionCounter ExitCnt(PGO, S);
500       ExitCnt.beginRegion();
501       RecordNextStmtCount = true;
502     }
503 
VisitCaseStmt__anon65fb66840111::ComputeRegionCounts504     void VisitCaseStmt(const CaseStmt *S) {
505       RecordNextStmtCount = false;
506       // Counter for this particular case. This counts only jumps from the
507       // switch header and does not include fallthrough from the case before
508       // this one.
509       RegionCounter Cnt(PGO, S);
510       Cnt.beginRegion(/*AddIncomingFallThrough=*/true);
511       CountMap[S] = Cnt.getCount();
512       RecordNextStmtCount = true;
513       Visit(S->getSubStmt());
514     }
515 
VisitDefaultStmt__anon65fb66840111::ComputeRegionCounts516     void VisitDefaultStmt(const DefaultStmt *S) {
517       RecordNextStmtCount = false;
518       // Counter for this default case. This does not include fallthrough from
519       // the previous case.
520       RegionCounter Cnt(PGO, S);
521       Cnt.beginRegion(/*AddIncomingFallThrough=*/true);
522       CountMap[S] = Cnt.getCount();
523       RecordNextStmtCount = true;
524       Visit(S->getSubStmt());
525     }
526 
VisitIfStmt__anon65fb66840111::ComputeRegionCounts527     void VisitIfStmt(const IfStmt *S) {
528       RecordStmtCount(S);
529       // Counter tracks the "then" part of an if statement. The count for
530       // the "else" part, if it exists, will be calculated from this counter.
531       RegionCounter Cnt(PGO, S);
532       Visit(S->getCond());
533 
534       Cnt.beginRegion();
535       CountMap[S->getThen()] = PGO.getCurrentRegionCount();
536       Visit(S->getThen());
537       Cnt.adjustForControlFlow();
538 
539       if (S->getElse()) {
540         Cnt.beginElseRegion();
541         CountMap[S->getElse()] = PGO.getCurrentRegionCount();
542         Visit(S->getElse());
543         Cnt.adjustForControlFlow();
544       }
545       Cnt.applyAdjustmentsToRegion(0);
546       RecordNextStmtCount = true;
547     }
548 
VisitCXXTryStmt__anon65fb66840111::ComputeRegionCounts549     void VisitCXXTryStmt(const CXXTryStmt *S) {
550       RecordStmtCount(S);
551       Visit(S->getTryBlock());
552       for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
553         Visit(S->getHandler(I));
554       // Counter tracks the continuation block of the try statement.
555       RegionCounter Cnt(PGO, S);
556       Cnt.beginRegion();
557       RecordNextStmtCount = true;
558     }
559 
VisitCXXCatchStmt__anon65fb66840111::ComputeRegionCounts560     void VisitCXXCatchStmt(const CXXCatchStmt *S) {
561       RecordNextStmtCount = false;
562       // Counter tracks the catch statement's handler block.
563       RegionCounter Cnt(PGO, S);
564       Cnt.beginRegion();
565       CountMap[S] = PGO.getCurrentRegionCount();
566       Visit(S->getHandlerBlock());
567     }
568 
VisitAbstractConditionalOperator__anon65fb66840111::ComputeRegionCounts569     void VisitAbstractConditionalOperator(
570         const AbstractConditionalOperator *E) {
571       RecordStmtCount(E);
572       // Counter tracks the "true" part of a conditional operator. The
573       // count in the "false" part will be calculated from this counter.
574       RegionCounter Cnt(PGO, E);
575       Visit(E->getCond());
576 
577       Cnt.beginRegion();
578       CountMap[E->getTrueExpr()] = PGO.getCurrentRegionCount();
579       Visit(E->getTrueExpr());
580       Cnt.adjustForControlFlow();
581 
582       Cnt.beginElseRegion();
583       CountMap[E->getFalseExpr()] = PGO.getCurrentRegionCount();
584       Visit(E->getFalseExpr());
585       Cnt.adjustForControlFlow();
586 
587       Cnt.applyAdjustmentsToRegion(0);
588       RecordNextStmtCount = true;
589     }
590 
VisitBinLAnd__anon65fb66840111::ComputeRegionCounts591     void VisitBinLAnd(const BinaryOperator *E) {
592       RecordStmtCount(E);
593       // Counter tracks the right hand side of a logical and operator.
594       RegionCounter Cnt(PGO, E);
595       Visit(E->getLHS());
596       Cnt.beginRegion();
597       CountMap[E->getRHS()] = PGO.getCurrentRegionCount();
598       Visit(E->getRHS());
599       Cnt.adjustForControlFlow();
600       Cnt.applyAdjustmentsToRegion(0);
601       RecordNextStmtCount = true;
602     }
603 
VisitBinLOr__anon65fb66840111::ComputeRegionCounts604     void VisitBinLOr(const BinaryOperator *E) {
605       RecordStmtCount(E);
606       // Counter tracks the right hand side of a logical or operator.
607       RegionCounter Cnt(PGO, E);
608       Visit(E->getLHS());
609       Cnt.beginRegion();
610       CountMap[E->getRHS()] = PGO.getCurrentRegionCount();
611       Visit(E->getRHS());
612       Cnt.adjustForControlFlow();
613       Cnt.applyAdjustmentsToRegion(0);
614       RecordNextStmtCount = true;
615     }
616   };
617 }
618 
combine(HashType Type)619 void PGOHash::combine(HashType Type) {
620   // Check that we never combine 0 and only have six bits.
621   assert(Type && "Hash is invalid: unexpected type 0");
622   assert(unsigned(Type) < TooBig && "Hash is invalid: too many types");
623 
624   // Pass through MD5 if enough work has built up.
625   if (Count && Count % NumTypesPerWord == 0) {
626     using namespace llvm::support;
627     uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working);
628     MD5.update(llvm::makeArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
629     Working = 0;
630   }
631 
632   // Accumulate the current type.
633   ++Count;
634   Working = Working << NumBitsPerType | Type;
635 }
636 
finalize()637 uint64_t PGOHash::finalize() {
638   // Use Working as the hash directly if we never used MD5.
639   if (Count <= NumTypesPerWord)
640     // No need to byte swap here, since none of the math was endian-dependent.
641     // This number will be byte-swapped as required on endianness transitions,
642     // so we will see the same value on the other side.
643     return Working;
644 
645   // Check for remaining work in Working.
646   if (Working)
647     MD5.update(Working);
648 
649   // Finalize the MD5 and return the hash.
650   llvm::MD5::MD5Result Result;
651   MD5.final(Result);
652   using namespace llvm::support;
653   return endian::read<uint64_t, little, unaligned>(Result);
654 }
655 
checkGlobalDecl(GlobalDecl GD)656 void CodeGenPGO::checkGlobalDecl(GlobalDecl GD) {
657   // Make sure we only emit coverage mapping for one constructor/destructor.
658   // Clang emits several functions for the constructor and the destructor of
659   // a class. Every function is instrumented, but we only want to provide
660   // coverage for one of them. Because of that we only emit the coverage mapping
661   // for the base constructor/destructor.
662   if ((isa<CXXConstructorDecl>(GD.getDecl()) &&
663        GD.getCtorType() != Ctor_Base) ||
664       (isa<CXXDestructorDecl>(GD.getDecl()) &&
665        GD.getDtorType() != Dtor_Base)) {
666     SkipCoverageMapping = true;
667   }
668 }
669 
assignRegionCounters(const Decl * D,llvm::Function * Fn)670 void CodeGenPGO::assignRegionCounters(const Decl *D, llvm::Function *Fn) {
671   bool InstrumentRegions = CGM.getCodeGenOpts().ProfileInstrGenerate;
672   llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
673   if (!InstrumentRegions && !PGOReader)
674     return;
675   if (D->isImplicit())
676     return;
677   CGM.ClearUnusedCoverageMapping(D);
678   setFuncName(Fn);
679 
680   mapRegionCounters(D);
681   if (CGM.getCodeGenOpts().CoverageMapping)
682     emitCounterRegionMapping(D);
683   if (PGOReader) {
684     SourceManager &SM = CGM.getContext().getSourceManager();
685     loadRegionCounts(PGOReader, SM.isInMainFile(D->getLocation()));
686     computeRegionCounts(D);
687     applyFunctionAttributes(PGOReader, Fn);
688   }
689 }
690 
mapRegionCounters(const Decl * D)691 void CodeGenPGO::mapRegionCounters(const Decl *D) {
692   RegionCounterMap.reset(new llvm::DenseMap<const Stmt *, unsigned>);
693   MapRegionCounters Walker(*RegionCounterMap);
694   if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
695     Walker.TraverseDecl(const_cast<FunctionDecl *>(FD));
696   else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
697     Walker.TraverseDecl(const_cast<ObjCMethodDecl *>(MD));
698   else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
699     Walker.TraverseDecl(const_cast<BlockDecl *>(BD));
700   else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
701     Walker.TraverseDecl(const_cast<CapturedDecl *>(CD));
702   assert(Walker.NextCounter > 0 && "no entry counter mapped for decl");
703   NumRegionCounters = Walker.NextCounter;
704   FunctionHash = Walker.Hash.finalize();
705 }
706 
emitCounterRegionMapping(const Decl * D)707 void CodeGenPGO::emitCounterRegionMapping(const Decl *D) {
708   if (SkipCoverageMapping)
709     return;
710   // Don't map the functions inside the system headers
711   auto Loc = D->getBody()->getLocStart();
712   if (CGM.getContext().getSourceManager().isInSystemHeader(Loc))
713     return;
714 
715   std::string CoverageMapping;
716   llvm::raw_string_ostream OS(CoverageMapping);
717   CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
718                                 CGM.getContext().getSourceManager(),
719                                 CGM.getLangOpts(), RegionCounterMap.get());
720   MappingGen.emitCounterMapping(D, OS);
721   OS.flush();
722 
723   if (CoverageMapping.empty())
724     return;
725 
726   CGM.getCoverageMapping()->addFunctionMappingRecord(
727       FuncNameVar, FuncName, FunctionHash, CoverageMapping);
728 }
729 
730 void
emitEmptyCounterMapping(const Decl * D,StringRef FuncName,llvm::GlobalValue::LinkageTypes Linkage)731 CodeGenPGO::emitEmptyCounterMapping(const Decl *D, StringRef FuncName,
732                                     llvm::GlobalValue::LinkageTypes Linkage) {
733   if (SkipCoverageMapping)
734     return;
735   setFuncName(FuncName, Linkage);
736 
737   // Don't map the functions inside the system headers
738   auto Loc = D->getBody()->getLocStart();
739   if (CGM.getContext().getSourceManager().isInSystemHeader(Loc))
740     return;
741 
742   std::string CoverageMapping;
743   llvm::raw_string_ostream OS(CoverageMapping);
744   CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
745                                 CGM.getContext().getSourceManager(),
746                                 CGM.getLangOpts());
747   MappingGen.emitEmptyMapping(D, OS);
748   OS.flush();
749 
750   if (CoverageMapping.empty())
751     return;
752 
753   CGM.getCoverageMapping()->addFunctionMappingRecord(
754       FuncNameVar, FuncName, FunctionHash, CoverageMapping);
755 }
756 
computeRegionCounts(const Decl * D)757 void CodeGenPGO::computeRegionCounts(const Decl *D) {
758   StmtCountMap.reset(new llvm::DenseMap<const Stmt *, uint64_t>);
759   ComputeRegionCounts Walker(*StmtCountMap, *this);
760   if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
761     Walker.VisitFunctionDecl(FD);
762   else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
763     Walker.VisitObjCMethodDecl(MD);
764   else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
765     Walker.VisitBlockDecl(BD);
766   else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
767     Walker.VisitCapturedDecl(const_cast<CapturedDecl *>(CD));
768 }
769 
770 void
applyFunctionAttributes(llvm::IndexedInstrProfReader * PGOReader,llvm::Function * Fn)771 CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader,
772                                     llvm::Function *Fn) {
773   if (!haveRegionCounts())
774     return;
775 
776   uint64_t MaxFunctionCount = PGOReader->getMaximumFunctionCount();
777   uint64_t FunctionCount = getRegionCount(0);
778   if (FunctionCount >= (uint64_t)(0.3 * (double)MaxFunctionCount))
779     // Turn on InlineHint attribute for hot functions.
780     // FIXME: 30% is from preliminary tuning on SPEC, it may not be optimal.
781     Fn->addFnAttr(llvm::Attribute::InlineHint);
782   else if (FunctionCount <= (uint64_t)(0.01 * (double)MaxFunctionCount))
783     // Turn on Cold attribute for cold functions.
784     // FIXME: 1% is from preliminary tuning on SPEC, it may not be optimal.
785     Fn->addFnAttr(llvm::Attribute::Cold);
786 }
787 
emitCounterIncrement(CGBuilderTy & Builder,unsigned Counter)788 void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, unsigned Counter) {
789   if (!CGM.getCodeGenOpts().ProfileInstrGenerate || !RegionCounterMap)
790     return;
791   if (!Builder.GetInsertPoint())
792     return;
793   auto *I8PtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext());
794   Builder.CreateCall4(CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment),
795                       llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy),
796                       Builder.getInt64(FunctionHash),
797                       Builder.getInt32(NumRegionCounters),
798                       Builder.getInt32(Counter));
799 }
800 
loadRegionCounts(llvm::IndexedInstrProfReader * PGOReader,bool IsInMainFile)801 void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader,
802                                   bool IsInMainFile) {
803   CGM.getPGOStats().addVisited(IsInMainFile);
804   RegionCounts.clear();
805   if (std::error_code EC =
806           PGOReader->getFunctionCounts(FuncName, FunctionHash, RegionCounts)) {
807     if (EC == llvm::instrprof_error::unknown_function)
808       CGM.getPGOStats().addMissing(IsInMainFile);
809     else if (EC == llvm::instrprof_error::hash_mismatch)
810       CGM.getPGOStats().addMismatched(IsInMainFile);
811     else if (EC == llvm::instrprof_error::malformed)
812       // TODO: Consider a more specific warning for this case.
813       CGM.getPGOStats().addMismatched(IsInMainFile);
814     RegionCounts.clear();
815   }
816 }
817 
/// \brief Pick the divisor used to bring 64-bit weights into 32-bit range.
///
/// \p MaxWeight is the largest weight in the set being scaled. The returned
/// divisor is 1 when \p MaxWeight is already below UINT32_MAX; otherwise it
/// is large enough that every weight divided by it is strictly less than
/// UINT32_MAX.
static uint64_t calculateWeightScale(uint64_t MaxWeight) {
  if (MaxWeight < UINT32_MAX)
    return 1;
  return MaxWeight / UINT32_MAX + 1;
}
825 
/// \brief Reduce one 64-bit branch weight to 32 bits, then add 1.
///
/// \p Weight is divided by \p Scale and the quotient is incremented. The
/// increment follows Laplace's Rule of Succession: weights are better
/// computed from the count plus 1, so 1 is added unconditionally.
///
/// \pre \c Scale was calculated by \a calculateWeightScale() with a weight no
/// greater than \c Weight.
static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) {
  assert(Scale && "scale by 0?");
  const uint64_t Quotient = Weight / Scale;
  const uint64_t Scaled = Quotient + 1;
  assert(Scaled <= UINT32_MAX && "overflow 32-bits");
  return static_cast<uint32_t>(Scaled);
}
841 
createBranchWeights(uint64_t TrueCount,uint64_t FalseCount)842 llvm::MDNode *CodeGenPGO::createBranchWeights(uint64_t TrueCount,
843                                               uint64_t FalseCount) {
844   // Check for empty weights.
845   if (!TrueCount && !FalseCount)
846     return nullptr;
847 
848   // Calculate how to scale down to 32-bits.
849   uint64_t Scale = calculateWeightScale(std::max(TrueCount, FalseCount));
850 
851   llvm::MDBuilder MDHelper(CGM.getLLVMContext());
852   return MDHelper.createBranchWeights(scaleBranchWeight(TrueCount, Scale),
853                                       scaleBranchWeight(FalseCount, Scale));
854 }
855 
createBranchWeights(ArrayRef<uint64_t> Weights)856 llvm::MDNode *CodeGenPGO::createBranchWeights(ArrayRef<uint64_t> Weights) {
857   // We need at least two elements to create meaningful weights.
858   if (Weights.size() < 2)
859     return nullptr;
860 
861   // Check for empty weights.
862   uint64_t MaxWeight = *std::max_element(Weights.begin(), Weights.end());
863   if (MaxWeight == 0)
864     return nullptr;
865 
866   // Calculate how to scale down to 32-bits.
867   uint64_t Scale = calculateWeightScale(MaxWeight);
868 
869   SmallVector<uint32_t, 16> ScaledWeights;
870   ScaledWeights.reserve(Weights.size());
871   for (uint64_t W : Weights)
872     ScaledWeights.push_back(scaleBranchWeight(W, Scale));
873 
874   llvm::MDBuilder MDHelper(CGM.getLLVMContext());
875   return MDHelper.createBranchWeights(ScaledWeights);
876 }
877 
createLoopWeights(const Stmt * Cond,RegionCounter & Cnt)878 llvm::MDNode *CodeGenPGO::createLoopWeights(const Stmt *Cond,
879                                             RegionCounter &Cnt) {
880   if (!haveRegionCounts())
881     return nullptr;
882   uint64_t LoopCount = Cnt.getCount();
883   uint64_t CondCount = 0;
884   bool Found = getStmtCount(Cond, CondCount);
885   assert(Found && "missing expected loop condition count");
886   (void)Found;
887   if (CondCount == 0)
888     return nullptr;
889   return createBranchWeights(LoopCount,
890                              std::max(CondCount, LoopCount) - LoopCount);
891 }
892