1 //===- SampleProf.h - Sampling profiling format support ---------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains common definitions used in the reading and writing of
10 // sample profile data.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef LLVM_PROFILEDATA_SAMPLEPROF_H
15 #define LLVM_PROFILEDATA_SAMPLEPROF_H
16 
17 #include "llvm/ADT/DenseSet.h"
18 #include "llvm/ADT/SmallVector.h"
19 #include "llvm/ADT/StringMap.h"
20 #include "llvm/ADT/StringRef.h"
21 #include "llvm/ADT/StringSet.h"
22 #include "llvm/IR/Function.h"
23 #include "llvm/IR/GlobalValue.h"
24 #include "llvm/IR/Module.h"
25 #include "llvm/Support/Allocator.h"
26 #include "llvm/Support/Debug.h"
27 #include "llvm/Support/ErrorOr.h"
28 #include "llvm/Support/MathExtras.h"
29 #include "llvm/Support/raw_ostream.h"
30 #include <algorithm>
31 #include <cstdint>
32 #include <map>
33 #include <set>
34 #include <string>
35 #include <system_error>
36 #include <utility>
37 
namespace llvm {

/// Returns the error category used for every sampleprof_error value.
const std::error_category &sampleprof_category();

/// Error codes reported while reading or writing sample profiles.
/// Values are consecutive integers starting at 0 (success).
enum class sampleprof_error {
  success = 0,
  bad_magic = 1,
  unsupported_version = 2,
  too_large = 3,
  truncated = 4,
  malformed = 5,
  unrecognized_format = 6,
  unsupported_writing_format = 7,
  truncated_name_table = 8,
  not_implemented = 9,
  counter_overflow = 10,
  ostream_seek_unsupported = 11,
  compress_failed = 12,
  uncompress_failed = 13,
  zlib_unavailable = 14,
  hash_mismatch = 15
};

/// Wrap \p E in a std::error_code bound to sampleprof_category().
inline std::error_code make_error_code(sampleprof_error E) {
  const int Value = static_cast<int>(E);
  return {Value, sampleprof_category()};
}

/// Fold \p Result into \p Accumulator and return the accumulated status.
///
/// Only the first failure is retained: once \p Accumulator holds an error it
/// is never overwritten, because later errors may be secondary effects of the
/// initial problem.
inline sampleprof_error MergeResult(sampleprof_error &Accumulator,
                                    sampleprof_error Result) {
  // While still successful, simply adopt the new status (a success Result
  // leaves the accumulator at success).
  if (Accumulator == sampleprof_error::success)
    Accumulator = Result;
  return Accumulator;
}

} // end namespace llvm
76 
77 namespace std {
78 
79 template <>
80 struct is_error_code_enum<llvm::sampleprof_error> : std::true_type {};
81 
82 } // end namespace std
83 
84 namespace llvm {
85 namespace sampleprof {
86 
/// Sample profile formats. The chosen format value is OR-ed into the low byte
/// of the magic number built by SPMagic().
enum SampleProfileFormat {
  SPF_None = 0,
  SPF_Text = 0x1,
  SPF_Compact_Binary = 0x2,
  SPF_GCC = 0x3,
  SPF_Ext_Binary = 0x4,
  // Also the default format argument of SPMagic().
  SPF_Binary = 0xff
};
95 
96 static inline uint64_t SPMagic(SampleProfileFormat Format = SPF_Binary) {
97   return uint64_t('S') << (64 - 8) | uint64_t('P') << (64 - 16) |
98          uint64_t('R') << (64 - 24) | uint64_t('O') << (64 - 32) |
99          uint64_t('F') << (64 - 40) | uint64_t('4') << (64 - 48) |
100          uint64_t('2') << (64 - 56) | uint64_t(Format);
101 }
102 
103 /// Get the proper representation of a string according to whether the
104 /// current Format uses MD5 to represent the string.
105 static inline StringRef getRepInFormat(StringRef Name, bool UseMD5,
106                                        std::string &GUIDBuf) {
107   if (Name.empty())
108     return Name;
109   GUIDBuf = std::to_string(Function::getGUID(Name));
110   return UseMD5 ? StringRef(GUIDBuf) : Name;
111 }
112 
/// Returns the sample profile format version number.
static inline uint64_t SPVersion() {
  const uint64_t CurrentVersion = 103;
  return CurrentVersion;
}
114 
// Section Type used by SampleProfileExtBinaryBaseReader and
// SampleProfileExtBinaryBaseWriter. Never change the existing
// value of enum. Only append new ones.
enum SecType {
  SecInValid = 0,
  SecProfSummary = 1,
  SecNameTable = 2,
  SecProfileSymbolList = 3,
  SecFuncOffsetTable = 4,
  SecFuncMetadata = 5,
  // Marker for the first type of profile section. SecLBRProfile currently
  // shares this value.
  SecFuncProfileFirst = 32,
  SecLBRProfile = SecFuncProfileFirst
};
129 
130 static inline std::string getSecName(SecType Type) {
131   switch (Type) {
132   case SecInValid:
133     return "InvalidSection";
134   case SecProfSummary:
135     return "ProfileSummarySection";
136   case SecNameTable:
137     return "NameTableSection";
138   case SecProfileSymbolList:
139     return "ProfileSymbolListSection";
140   case SecFuncOffsetTable:
141     return "FuncOffsetTableSection";
142   case SecFuncMetadata:
143     return "FunctionMetadata";
144   case SecLBRProfile:
145     return "LBRProfileSection";
146   }
147   llvm_unreachable("A SecType has no name for output");
148 }
149 
// Entry type of section header table used by SampleProfileExtBinaryBaseReader
// and SampleProfileExtBinaryBaseWriter.
struct SecHdrTableEntry {
  // Kind of section this entry describes.
  SecType Type;
  // Common flags are kept in the low 32 bits and section-specific flags in the
  // high 32 bits (see addSecFlag / removeSecFlag / hasSecFlag).
  uint64_t Flags;
  // Presumably the byte offset of the section in the profile — TODO confirm.
  uint64_t Offset;
  // Presumably the byte size of the section — TODO confirm.
  uint64_t Size;
  // The index indicating the location of the current entry in
  // SectionHdrLayout table.
  uint32_t LayoutIndex;
};
161 
// Flags common for all sections are defined here. In SecHdrTableEntry::Flags,
// common flags will be saved in the lower 32bits and section specific flags
// will be saved in the higher 32 bits.
enum class SecCommonFlags : uint32_t {
  SecFlagInValid = 0,
  // Presumably marks the section contents as compressed — confirm against the
  // reader/writer.
  SecFlagCompress = (1 << 0),
  // Indicate the section contains only profile without context.
  SecFlagFlat = (1 << 1)
};
171 
// Section specific flags are defined here.
// !!!Note: Every time a new enum class is created here, please add
// a new check in verifySecFlag.
// Flags valid only for the SecNameTable section.
enum class SecNameTableFlags : uint32_t {
  SecFlagInValid = 0,
  // Presumably: names in the table are stored as MD5 values — confirm.
  SecFlagMD5Name = (1 << 0),
  // Store MD5 in fixed length instead of ULEB128 so NameTable can be
  // accessed like an array.
  SecFlagFixedLengthMD5 = (1 << 1)
};
// Flags valid only for the SecProfSummary section (enforced by verifySecFlag).
enum class SecProfSummaryFlags : uint32_t {
  SecFlagInValid = 0,
  /// SecFlagPartial means the profile is for common/shared code.
  /// The common profile is usually merged from profiles collected
  /// from running other targets.
  SecFlagPartial = (1 << 0)
};
189 
// Flags valid only for the SecFuncMetadata section (enforced by verifySecFlag).
enum class SecFuncMetadataFlags : uint32_t {
  SecFlagInvalid = 0,
  // Presumably set when the profile is probe-based — confirm against the
  // reader/writer.
  SecFlagIsProbeBased = (1 << 0),
};
194 
195 // Verify section specific flag is used for the correct section.
196 template <class SecFlagType>
197 static inline void verifySecFlag(SecType Type, SecFlagType Flag) {
198   // No verification is needed for common flags.
199   if (std::is_same<SecCommonFlags, SecFlagType>())
200     return;
201 
202   // Verification starts here for section specific flag.
203   bool IsFlagLegal = false;
204   switch (Type) {
205   case SecNameTable:
206     IsFlagLegal = std::is_same<SecNameTableFlags, SecFlagType>();
207     break;
208   case SecProfSummary:
209     IsFlagLegal = std::is_same<SecProfSummaryFlags, SecFlagType>();
210     break;
211   case SecFuncMetadata:
212     IsFlagLegal = std::is_same<SecFuncMetadataFlags, SecFlagType>();
213     break;
214   default:
215     break;
216   }
217   if (!IsFlagLegal)
218     llvm_unreachable("Misuse of a flag in an incompatible section");
219 }
220 
221 template <class SecFlagType>
222 static inline void addSecFlag(SecHdrTableEntry &Entry, SecFlagType Flag) {
223   verifySecFlag(Entry.Type, Flag);
224   auto FVal = static_cast<uint64_t>(Flag);
225   bool IsCommon = std::is_same<SecCommonFlags, SecFlagType>();
226   Entry.Flags |= IsCommon ? FVal : (FVal << 32);
227 }
228 
229 template <class SecFlagType>
230 static inline void removeSecFlag(SecHdrTableEntry &Entry, SecFlagType Flag) {
231   verifySecFlag(Entry.Type, Flag);
232   auto FVal = static_cast<uint64_t>(Flag);
233   bool IsCommon = std::is_same<SecCommonFlags, SecFlagType>();
234   Entry.Flags &= ~(IsCommon ? FVal : (FVal << 32));
235 }
236 
237 template <class SecFlagType>
238 static inline bool hasSecFlag(const SecHdrTableEntry &Entry, SecFlagType Flag) {
239   verifySecFlag(Entry.Type, Flag);
240   auto FVal = static_cast<uint64_t>(Flag);
241   bool IsCommon = std::is_same<SecCommonFlags, SecFlagType>();
242   return Entry.Flags & (IsCommon ? FVal : (FVal << 32));
243 }
244 
245 /// Represents the relative location of an instruction.
246 ///
247 /// Instruction locations are specified by the line offset from the
248 /// beginning of the function (marked by the line where the function
249 /// header is) and the discriminator value within that line.
250 ///
251 /// The discriminator value is useful to distinguish instructions
252 /// that are on the same line but belong to different basic blocks
253 /// (e.g., the two post-increment instructions in "if (p) x++; else y++;").
254 struct LineLocation {
255   LineLocation(uint32_t L, uint32_t D) : LineOffset(L), Discriminator(D) {}
256 
257   void print(raw_ostream &OS) const;
258   void dump() const;
259 
260   bool operator<(const LineLocation &O) const {
261     return LineOffset < O.LineOffset ||
262            (LineOffset == O.LineOffset && Discriminator < O.Discriminator);
263   }
264 
265   bool operator==(const LineLocation &O) const {
266     return LineOffset == O.LineOffset && Discriminator == O.Discriminator;
267   }
268 
269   bool operator!=(const LineLocation &O) const {
270     return LineOffset != O.LineOffset || Discriminator != O.Discriminator;
271   }
272 
273   uint32_t LineOffset;
274   uint32_t Discriminator;
275 };
276 
277 raw_ostream &operator<<(raw_ostream &OS, const LineLocation &Loc);
278 
279 /// Representation of a single sample record.
280 ///
281 /// A sample record is represented by a positive integer value, which
282 /// indicates how frequently was the associated line location executed.
283 ///
284 /// Additionally, if the associated location contains a function call,
285 /// the record will hold a list of all the possible called targets. For
286 /// direct calls, this will be the exact function being invoked. For
287 /// indirect calls (function pointers, virtual table dispatch), this
288 /// will be a list of one or more functions.
289 class SampleRecord {
290 public:
291   using CallTarget = std::pair<StringRef, uint64_t>;
292   struct CallTargetComparator {
293     bool operator()(const CallTarget &LHS, const CallTarget &RHS) const {
294       if (LHS.second != RHS.second)
295         return LHS.second > RHS.second;
296 
297       return LHS.first < RHS.first;
298     }
299   };
300 
301   using SortedCallTargetSet = std::set<CallTarget, CallTargetComparator>;
302   using CallTargetMap = StringMap<uint64_t>;
303   SampleRecord() = default;
304 
305   /// Increment the number of samples for this record by \p S.
306   /// Optionally scale sample count \p S by \p Weight.
307   ///
308   /// Sample counts accumulate using saturating arithmetic, to avoid wrapping
309   /// around unsigned integers.
310   sampleprof_error addSamples(uint64_t S, uint64_t Weight = 1) {
311     bool Overflowed;
312     NumSamples = SaturatingMultiplyAdd(S, Weight, NumSamples, &Overflowed);
313     return Overflowed ? sampleprof_error::counter_overflow
314                       : sampleprof_error::success;
315   }
316 
317   /// Add called function \p F with samples \p S.
318   /// Optionally scale sample count \p S by \p Weight.
319   ///
320   /// Sample counts accumulate using saturating arithmetic, to avoid wrapping
321   /// around unsigned integers.
322   sampleprof_error addCalledTarget(StringRef F, uint64_t S,
323                                    uint64_t Weight = 1) {
324     uint64_t &TargetSamples = CallTargets[F];
325     bool Overflowed;
326     TargetSamples =
327         SaturatingMultiplyAdd(S, Weight, TargetSamples, &Overflowed);
328     return Overflowed ? sampleprof_error::counter_overflow
329                       : sampleprof_error::success;
330   }
331 
332   /// Return true if this sample record contains function calls.
333   bool hasCalls() const { return !CallTargets.empty(); }
334 
335   uint64_t getSamples() const { return NumSamples; }
336   const CallTargetMap &getCallTargets() const { return CallTargets; }
337   const SortedCallTargetSet getSortedCallTargets() const {
338     return SortCallTargets(CallTargets);
339   }
340 
341   /// Sort call targets in descending order of call frequency.
342   static const SortedCallTargetSet SortCallTargets(const CallTargetMap &Targets) {
343     SortedCallTargetSet SortedTargets;
344     for (const auto &I : Targets) {
345       SortedTargets.emplace(I.first(), I.second);
346     }
347     return SortedTargets;
348   }
349 
350   /// Prorate call targets by a distribution factor.
351   static const CallTargetMap adjustCallTargets(const CallTargetMap &Targets,
352                                                float DistributionFactor) {
353     CallTargetMap AdjustedTargets;
354     for (const auto &I : Targets) {
355       AdjustedTargets[I.first()] = I.second * DistributionFactor;
356     }
357     return AdjustedTargets;
358   }
359 
360   /// Merge the samples in \p Other into this record.
361   /// Optionally scale sample counts by \p Weight.
362   sampleprof_error merge(const SampleRecord &Other, uint64_t Weight = 1) {
363     sampleprof_error Result = addSamples(Other.getSamples(), Weight);
364     for (const auto &I : Other.getCallTargets()) {
365       MergeResult(Result, addCalledTarget(I.first(), I.second, Weight));
366     }
367     return Result;
368   }
369 
370   void print(raw_ostream &OS, unsigned Indent) const;
371   void dump() const;
372 
373 private:
374   uint64_t NumSamples = 0;
375   CallTargetMap CallTargets;
376 };
377 
378 raw_ostream &operator<<(raw_ostream &OS, const SampleRecord &Sample);
379 
// State of context associated with FunctionSamples.
// Values are bit flags; SampleContext::setState ORs them together, so several
// may be set at once.
enum ContextStateMask {
  UnknownContext = 0x0,   // Profile without context
  RawContext = 0x1,       // Full context profile from input profile
  SyntheticContext = 0x2, // Synthetic context created for context promotion
  InlinedContext = 0x4,   // Profile for context that is inlined into caller
  MergedContext = 0x8     // Profile for context merged into base profile
};
388 
389 // Sample context for FunctionSamples. It consists of the calling context,
390 // the function name and context state. Internally sample context is represented
391 // using StringRef, which is also the input for constructing a `SampleContext`.
392 // It can accept and represent both full context string as well as context-less
393 // function name.
394 // Example of full context string (note the wrapping `[]`):
395 //    `[main:3 @ _Z5funcAi:1 @ _Z8funcLeafi]`
396 // Example of context-less function name (same as AutoFDO):
397 //    `_Z8funcLeafi`
398 class SampleContext {
399 public:
400   SampleContext() : State(UnknownContext) {}
401   SampleContext(StringRef ContextStr,
402                 ContextStateMask CState = UnknownContext) {
403     setContext(ContextStr, CState);
404   }
405 
406   // Promote context by removing top frames (represented by `ContextStrToRemove`).
407   // Note that with string representation of context, the promotion is effectively
408   // a substr operation with `ContextStrToRemove` removed from left.
409   void promoteOnPath(StringRef ContextStrToRemove) {
410     assert(FullContext.startswith(ContextStrToRemove));
411 
412     // Remove leading context and frame separator " @ ".
413     FullContext = FullContext.substr(ContextStrToRemove.size() + 3);
414     CallingContext = CallingContext.substr(ContextStrToRemove.size() + 3);
415   }
416 
417   // Split the top context frame (left-most substr) from context.
418   static std::pair<StringRef, StringRef>
419   splitContextString(StringRef ContextStr) {
420     return ContextStr.split(" @ ");
421   }
422 
423   // Decode context string for a frame to get function name and location.
424   // `ContextStr` is in the form of `FuncName:StartLine.Discriminator`.
425   static void decodeContextString(StringRef ContextStr, StringRef &FName,
426                                   LineLocation &LineLoc) {
427     // Get function name
428     auto EntrySplit = ContextStr.split(':');
429     FName = EntrySplit.first;
430 
431     LineLoc = {0, 0};
432     if (!EntrySplit.second.empty()) {
433       // Get line offset, use signed int for getAsInteger so string will
434       // be parsed as signed.
435       int LineOffset = 0;
436       auto LocSplit = EntrySplit.second.split('.');
437       LocSplit.first.getAsInteger(10, LineOffset);
438       LineLoc.LineOffset = LineOffset;
439 
440       // Get discriminator
441       if (!LocSplit.second.empty())
442         LocSplit.second.getAsInteger(10, LineLoc.Discriminator);
443     }
444   }
445 
446   operator StringRef() const { return FullContext; }
447   bool hasState(ContextStateMask S) { return State & (uint32_t)S; }
448   void setState(ContextStateMask S) { State |= (uint32_t)S; }
449   void clearState(ContextStateMask S) { State &= (uint32_t)~S; }
450   bool hasContext() const { return State != UnknownContext; }
451   bool isBaseContext() const { return CallingContext.empty(); }
452   StringRef getNameWithoutContext() const { return Name; }
453   StringRef getCallingContext() const { return CallingContext; }
454   StringRef getNameWithContext(bool WithBracket = false) const {
455     return WithBracket ? InputContext : FullContext;
456   }
457 
458 private:
459   // Give a context string, decode and populate internal states like
460   // Function name, Calling context and context state. Example of input
461   // `ContextStr`: `[main:3 @ _Z5funcAi:1 @ _Z8funcLeafi]`
462   void setContext(StringRef ContextStr, ContextStateMask CState) {
463     assert(!ContextStr.empty());
464     InputContext = ContextStr;
465     // Note that `[]` wrapped input indicates a full context string, otherwise
466     // it's treated as context-less function name only.
467     bool HasContext = ContextStr.startswith("[");
468     if (!HasContext && CState == UnknownContext) {
469       State = UnknownContext;
470       Name = FullContext = ContextStr;
471     } else {
472       // Assume raw context profile if unspecified
473       if (CState == UnknownContext)
474         State = RawContext;
475       else
476         State = CState;
477 
478       // Remove encapsulating '[' and ']' if any
479       if (HasContext)
480         FullContext = ContextStr.substr(1, ContextStr.size() - 2);
481       else
482         FullContext = ContextStr;
483 
484       // Caller is to the left of callee in context string
485       auto NameContext = FullContext.rsplit(" @ ");
486       if (NameContext.second.empty()) {
487         Name = NameContext.first;
488         CallingContext = NameContext.second;
489       } else {
490         Name = NameContext.second;
491         CallingContext = NameContext.first;
492       }
493     }
494   }
495 
496   // Input context string including bracketed calling context and leaf function
497   // name
498   StringRef InputContext;
499   // Full context string including calling context and leaf function name
500   StringRef FullContext;
501   // Function name for the associated sample profile
502   StringRef Name;
503   // Calling context (leaf function excluded) for the associated sample profile
504   StringRef CallingContext;
505   // State of the associated sample profile
506   uint32_t State;
507 };
508 
class FunctionSamples;
class SampleProfileReaderItaniumRemapper;

// Per-location sample records for a function body, ordered by LineLocation.
using BodySampleMap = std::map<LineLocation, SampleRecord>;
// NOTE: Using a StringMap here makes parsed profiles consume around 17% more
// memory, which is *very* significant for large profiles.
// std::less<> is a transparent comparator, which allows heterogeneous lookup
// with other key types comparable to std::string.
using FunctionSamplesMap = std::map<std::string, FunctionSamples, std::less<>>;
// Callee profiles recorded at each callsite location.
using CallsiteSampleMap = std::map<LineLocation, FunctionSamplesMap>;
517 
518 /// Representation of the samples collected for a function.
519 ///
520 /// This data structure contains all the collected samples for the body
521 /// of a function. Each sample corresponds to a LineLocation instance
522 /// within the body of the function.
523 class FunctionSamples {
524 public:
  FunctionSamples() = default;

  // Print this profile to OS (dbgs() by default) with Indent leading spaces.
  // Defined out of line.
  void print(raw_ostream &OS = dbgs(), unsigned Indent = 0) const;
  // Defined out of line; presumably prints to dbgs() — confirm.
  void dump() const;
529 
530   sampleprof_error addTotalSamples(uint64_t Num, uint64_t Weight = 1) {
531     bool Overflowed;
532     TotalSamples =
533         SaturatingMultiplyAdd(Num, Weight, TotalSamples, &Overflowed);
534     return Overflowed ? sampleprof_error::counter_overflow
535                       : sampleprof_error::success;
536   }
537 
538   void setTotalSamples(uint64_t Num) { TotalSamples = Num; }
539 
540   sampleprof_error addHeadSamples(uint64_t Num, uint64_t Weight = 1) {
541     bool Overflowed;
542     TotalHeadSamples =
543         SaturatingMultiplyAdd(Num, Weight, TotalHeadSamples, &Overflowed);
544     return Overflowed ? sampleprof_error::counter_overflow
545                       : sampleprof_error::success;
546   }
547 
548   sampleprof_error addBodySamples(uint32_t LineOffset, uint32_t Discriminator,
549                                   uint64_t Num, uint64_t Weight = 1) {
550     return BodySamples[LineLocation(LineOffset, Discriminator)].addSamples(
551         Num, Weight);
552   }
553 
554   sampleprof_error addCalledTargetSamples(uint32_t LineOffset,
555                                           uint32_t Discriminator,
556                                           StringRef FName, uint64_t Num,
557                                           uint64_t Weight = 1) {
558     return BodySamples[LineLocation(LineOffset, Discriminator)].addCalledTarget(
559         FName, Num, Weight);
560   }
561 
562   /// Return the number of samples collected at the given location.
563   /// Each location is specified by \p LineOffset and \p Discriminator.
564   /// If the location is not found in profile, return error.
565   ErrorOr<uint64_t> findSamplesAt(uint32_t LineOffset,
566                                   uint32_t Discriminator) const {
567     const auto &ret = BodySamples.find(LineLocation(LineOffset, Discriminator));
568     if (ret == BodySamples.end()) {
569       // For CSSPGO, in order to conserve profile size, we no longer write out
570       // locations profile for those not hit during training, so we need to
571       // treat them as zero instead of error here.
572       if (ProfileIsCS)
573         return 0;
574       return std::error_code();
575       // A missing counter for a probe likely means the probe was not executed.
576       // Treat it as a zero count instead of an unknown count to help edge
577       // weight inference.
578       if (FunctionSamples::ProfileIsProbeBased)
579         return 0;
580       return std::error_code();
581     } else {
582       return ret->second.getSamples();
583     }
584   }
585 
586   /// Returns the call target map collected at a given location.
587   /// Each location is specified by \p LineOffset and \p Discriminator.
588   /// If the location is not found in profile, return error.
589   ErrorOr<SampleRecord::CallTargetMap>
590   findCallTargetMapAt(uint32_t LineOffset, uint32_t Discriminator) const {
591     const auto &ret = BodySamples.find(LineLocation(LineOffset, Discriminator));
592     if (ret == BodySamples.end())
593       return std::error_code();
594     return ret->second.getCallTargets();
595   }
596 
597   /// Returns the call target map collected at a given location specified by \p
598   /// CallSite. If the location is not found in profile, return error.
599   ErrorOr<SampleRecord::CallTargetMap>
600   findCallTargetMapAt(const LineLocation &CallSite) const {
601     const auto &Ret = BodySamples.find(CallSite);
602     if (Ret == BodySamples.end())
603       return std::error_code();
604     return Ret->second.getCallTargets();
605   }
606 
607   /// Return the function samples at the given callsite location.
608   FunctionSamplesMap &functionSamplesAt(const LineLocation &Loc) {
609     return CallsiteSamples[Loc];
610   }
611 
612   /// Returns the FunctionSamplesMap at the given \p Loc.
613   const FunctionSamplesMap *
614   findFunctionSamplesMapAt(const LineLocation &Loc) const {
615     auto iter = CallsiteSamples.find(Loc);
616     if (iter == CallsiteSamples.end())
617       return nullptr;
618     return &iter->second;
619   }
620 
  /// Returns a pointer to FunctionSamples at the given callsite location
  /// \p Loc with callee \p CalleeName. If no callsite can be found, relax
  /// the restriction to return the FunctionSamples at callsite location
  /// \p Loc with the maximum total sample count. If \p Remapper is not
  /// nullptr, use \p Remapper to find FunctionSamples with equivalent name
  /// as \p CalleeName.
  // Defined out of line; presumably returns nullptr when nothing is recorded
  // at Loc — confirm against the definition.
  const FunctionSamples *
  findFunctionSamplesAt(const LineLocation &Loc, StringRef CalleeName,
                        SampleProfileReaderItaniumRemapper *Remapper) const;
630 
631   bool empty() const { return TotalSamples == 0; }
632 
633   /// Return the total number of samples collected inside the function.
634   uint64_t getTotalSamples() const { return TotalSamples; }
635 
636   /// Return the total number of branch samples that have the function as the
637   /// branch target. This should be equivalent to the sample of the first
638   /// instruction of the symbol. But as we directly get this info for raw
639   /// profile without referring to potentially inaccurate debug info, this
640   /// gives more accurate profile data and is preferred for standalone symbols.
641   uint64_t getHeadSamples() const { return TotalHeadSamples; }
642 
643   /// Return the sample count of the first instruction of the function.
644   /// The function can be either a standalone symbol or an inlined function.
645   uint64_t getEntrySamples() const {
646     if (FunctionSamples::ProfileIsCS && getHeadSamples()) {
647       // For CS profile, if we already have more accurate head samples
648       // counted by branch sample from caller, use them as entry samples.
649       return getHeadSamples();
650     }
651     uint64_t Count = 0;
652     // Use either BodySamples or CallsiteSamples which ever has the smaller
653     // lineno.
654     if (!BodySamples.empty() &&
655         (CallsiteSamples.empty() ||
656          BodySamples.begin()->first < CallsiteSamples.begin()->first))
657       Count = BodySamples.begin()->second.getSamples();
658     else if (!CallsiteSamples.empty()) {
659       // An indirect callsite may be promoted to several inlined direct calls.
660       // We need to get the sum of them.
661       for (const auto &N_FS : CallsiteSamples.begin()->second)
662         Count += N_FS.second.getEntrySamples();
663     }
664     // Return at least 1 if total sample is not 0.
665     return Count ? Count : TotalSamples > 0;
666   }
667 
668   /// Return all the samples collected in the body of the function.
669   const BodySampleMap &getBodySamples() const { return BodySamples; }
670 
671   /// Return all the callsite samples collected in the body of the function.
672   const CallsiteSampleMap &getCallsiteSamples() const {
673     return CallsiteSamples;
674   }
675 
676   /// Return the maximum of sample counts in a function body including functions
677   /// inlined in it.
678   uint64_t getMaxCountInside() const {
679     uint64_t MaxCount = 0;
680     for (const auto &L : getBodySamples())
681       MaxCount = std::max(MaxCount, L.second.getSamples());
682     for (const auto &C : getCallsiteSamples())
683       for (const FunctionSamplesMap::value_type &F : C.second)
684         MaxCount = std::max(MaxCount, F.second.getMaxCountInside());
685     return MaxCount;
686   }
687 
688   /// Merge the samples in \p Other into this one.
689   /// Optionally scale samples by \p Weight.
690   sampleprof_error merge(const FunctionSamples &Other, uint64_t Weight = 1) {
691     sampleprof_error Result = sampleprof_error::success;
692     Name = Other.getName();
693     if (!GUIDToFuncNameMap)
694       GUIDToFuncNameMap = Other.GUIDToFuncNameMap;
695     if (Context.getNameWithContext(true).empty())
696       Context = Other.getContext();
697     if (FunctionHash == 0) {
698       // Set the function hash code for the target profile.
699       FunctionHash = Other.getFunctionHash();
700     } else if (FunctionHash != Other.getFunctionHash()) {
701       // The two profiles coming with different valid hash codes indicates
702       // either:
703       // 1. They are same-named static functions from different compilation
704       // units (without using -unique-internal-linkage-names), or
705       // 2. They are really the same function but from different compilations.
706       // Let's bail out in either case for now, which means one profile is
707       // dropped.
708       return sampleprof_error::hash_mismatch;
709     }
710 
711     MergeResult(Result, addTotalSamples(Other.getTotalSamples(), Weight));
712     MergeResult(Result, addHeadSamples(Other.getHeadSamples(), Weight));
713     for (const auto &I : Other.getBodySamples()) {
714       const LineLocation &Loc = I.first;
715       const SampleRecord &Rec = I.second;
716       MergeResult(Result, BodySamples[Loc].merge(Rec, Weight));
717     }
718     for (const auto &I : Other.getCallsiteSamples()) {
719       const LineLocation &Loc = I.first;
720       FunctionSamplesMap &FSMap = functionSamplesAt(Loc);
721       for (const auto &Rec : I.second)
722         MergeResult(Result, FSMap[Rec.first].merge(Rec.second, Weight));
723     }
724     return Result;
725   }
726 
727   /// Recursively traverses all children, if the total sample count of the
728   /// corresponding function is no less than \p Threshold, add its corresponding
729   /// GUID to \p S. Also traverse the BodySamples to add hot CallTarget's GUID
730   /// to \p S.
731   void findInlinedFunctions(DenseSet<GlobalValue::GUID> &S, const Module *M,
732                             uint64_t Threshold) const {
733     if (TotalSamples <= Threshold)
734       return;
735     auto isDeclaration = [](const Function *F) {
736       return !F || F->isDeclaration();
737     };
738     if (isDeclaration(M->getFunction(getFuncName()))) {
739       // Add to the import list only when it's defined out of module.
740       S.insert(getGUID(Name));
741     }
742     // Import hot CallTargets, which may not be available in IR because full
743     // profile annotation cannot be done until backend compilation in ThinLTO.
744     for (const auto &BS : BodySamples)
745       for (const auto &TS : BS.second.getCallTargets())
746         if (TS.getValue() > Threshold) {
747           const Function *Callee = M->getFunction(getFuncName(TS.getKey()));
748           if (isDeclaration(Callee))
749             S.insert(getGUID(TS.getKey()));
750         }
751     for (const auto &CS : CallsiteSamples)
752       for (const auto &NameFS : CS.second)
753         NameFS.second.findInlinedFunctions(S, M, Threshold);
754   }
755 
756   /// Set the name of the function.
757   void setName(StringRef FunctionName) { Name = FunctionName; }
758 
759   /// Return the function name.
760   StringRef getName() const { return Name; }
761 
762   /// Return function name with context.
763   StringRef getNameWithContext(bool WithBracket = false) const {
764     return FunctionSamples::ProfileIsCS
765                ? Context.getNameWithContext(WithBracket)
766                : Name;
767   }
768 
769   /// Return the original function name.
770   StringRef getFuncName() const { return getFuncName(Name); }
771 
772   void setFunctionHash(uint64_t Hash) { FunctionHash = Hash; }
773 
774   uint64_t getFunctionHash() const { return FunctionHash; }
775 
776   /// Return the canonical name for a function, taking into account
777   /// suffix elision policy attributes.
778   static StringRef getCanonicalFnName(const Function &F) {
779     auto AttrName = "sample-profile-suffix-elision-policy";
780     auto Attr = F.getFnAttribute(AttrName).getValueAsString();
781     return getCanonicalFnName(F.getName(), Attr);
782   }
783 
784   static StringRef getCanonicalFnName(StringRef FnName, StringRef Attr = "") {
785     static const char *knownSuffixes[] = { ".llvm.", ".part." };
786     if (Attr == "" || Attr == "all") {
787       return FnName.split('.').first;
788     } else if (Attr == "selected") {
789       StringRef Cand(FnName);
790       for (const auto &Suf : knownSuffixes) {
791         StringRef Suffix(Suf);
792         auto It = Cand.rfind(Suffix);
793         if (It == StringRef::npos)
794           return Cand;
795         auto Dit = Cand.rfind('.');
796         if (Dit == It + Suffix.size() - 1)
797           Cand = Cand.substr(0, It);
798       }
799       return Cand;
800     } else if (Attr == "none") {
801       return FnName;
802     } else {
803       assert(false && "internal error: unknown suffix elision policy");
804     }
805     return FnName;
806   }
807 
808   /// Translate \p Name into its original name.
809   /// When profile doesn't use MD5, \p Name needs no translation.
810   /// When profile uses MD5, \p Name in current FunctionSamples
811   /// is actually GUID of the original function name. getFuncName will
812   /// translate \p Name in current FunctionSamples into its original name
813   /// by looking up in the function map GUIDToFuncNameMap.
814   /// If the original name doesn't exist in the map, return empty StringRef.
815   StringRef getFuncName(StringRef Name) const {
816     if (!UseMD5)
817       return Name;
818 
819     assert(GUIDToFuncNameMap && "GUIDToFuncNameMap needs to be popluated first");
820     return GUIDToFuncNameMap->lookup(std::stoull(Name.data()));
821   }
822 
  /// Returns the line offset of \p DIL relative to the start line of its
  /// subprogram.
  /// We assume that a single function will not exceed 65535 LOC.
  static unsigned getOffset(const DILocation *DIL);
826 
  /// Returns a unique call site identifier for a given debug location of a call
  /// instruction. This is a wrapper over two scenarios, the probe-based profile
  /// and the regular profile, that hides implementation details from the
  /// sample loader and the context tracker.
  static LineLocation getCallSiteIdentifier(const DILocation *DIL);
832 
  /// Get the FunctionSamples of the inline instance where DIL originates
  /// from.
  ///
  /// The FunctionSamples of the instruction (Machine or IR) associated to
  /// \p DIL is the inlined instance in which that instruction is coming from.
  /// We traverse the inline stack of that instruction, and match it with the
  /// tree nodes in the profile.
  ///
  /// \returns the FunctionSamples pointer to the inlined instance.
  /// (NOTE(review): presumably nullptr when no matching instance exists —
  /// confirm against the out-of-line definition.)
  /// If \p Remapper is not nullptr, it will be used to find matching
  /// FunctionSamples with not exactly the same but equivalent name.
  const FunctionSamples *findFunctionSamples(
      const DILocation *DIL,
      SampleProfileReaderItaniumRemapper *Remapper = nullptr) const;
847 
  /// Whether the profile is pseudo-probe based. Static: a single flag shared
  /// by every FunctionSamples object.
  static bool ProfileIsProbeBased;

  /// Whether the profile is context-sensitive (CS). Static: shared by every
  /// FunctionSamples object. When set, getNameWithContext() returns the
  /// context-qualified name instead of the plain name.
  static bool ProfileIsCS;
851 
852   SampleContext &getContext() const { return Context; }
853 
854   void setContext(const SampleContext &FContext) { Context = FContext; }
855 
  /// Format of the profile. Static: shared by every FunctionSamples object.
  static SampleProfileFormat Format;

  /// Whether the profile uses MD5 to represent string. When set, the names
  /// held by FunctionSamples are decimal GUID strings (see getGUID and
  /// getFuncName(StringRef)).
  static bool UseMD5;

  /// GUIDToFuncNameMap saves the mapping from GUID to the symbol name, for
  /// all the function symbols defined or declared in current module.
  /// getFuncName(StringRef) asserts that it is set when UseMD5 is true.
  /// NOTE(review): raw pointer that is never freed here, so presumably owned
  /// by whoever populates it — confirm.
  DenseMap<uint64_t, StringRef> *GUIDToFuncNameMap = nullptr;
864 
865   // Assume the input \p Name is a name coming from FunctionSamples itself.
866   // If UseMD5 is true, the name is already a GUID and we
867   // don't want to return the GUID of GUID.
868   static uint64_t getGUID(StringRef Name) {
869     return UseMD5 ? std::stoull(Name.data()) : Function::getGUID(Name);
870   }
871 
  // Find all the names in the current FunctionSamples including names in
  // all the inline instances and names of call targets, and add them to
  // \p NameSet. (NOTE(review): whether existing contents of NameSet are kept
  // is up to the out-of-line definition — presumably insert-only; confirm.)
  void findAllNames(DenseSet<StringRef> &NameSet) const;
875 
private:
  /// Mangled name of the function. When UseMD5 is set, this holds the
  /// function's GUID written as a decimal string instead.
  StringRef Name;

  /// CFG hash value for the function (0 until setFunctionHash is called).
  uint64_t FunctionHash = 0;

  /// Calling context for function profile. It is mutable, which lets the
  /// const getContext() accessor return a non-const reference to it.
  mutable SampleContext Context;

  /// Total number of samples collected inside this function.
  ///
  /// Samples are cumulative, they include all the samples collected
  /// inside this function and all its inlined callees.
  uint64_t TotalSamples = 0;

  /// Total number of samples collected at the head of the function.
  /// This is an approximation of the number of calls made to this function
  /// at runtime.
  uint64_t TotalHeadSamples = 0;

  /// Map instruction locations to collected samples.
  ///
  /// Each entry in this map contains the number of samples
  /// collected at the corresponding line offset. All line locations
  /// are an offset from the start of the function.
  BodySampleMap BodySamples;

  /// Map call sites to collected samples for the called function.
  ///
  /// Each entry in this map corresponds to all the samples
  /// collected for the inlined function call at the given
  /// location. For example, given:
  ///
  ///     void foo() {
  ///  1    bar();
  ///  ...
  ///  8    baz();
  ///     }
  ///
  /// If the bar() and baz() calls were inlined inside foo(), this
  /// map will contain two entries.  One for all the samples collected
  /// in the call to bar() at line offset 1, the other for all the samples
  /// collected in the call to baz() at line offset 8.
  CallsiteSampleMap CallsiteSamples;
921 };
922 
/// Print a textual representation of \p FS to \p OS. The definition lives
/// elsewhere; presumably it returns \p OS to allow chaining.
raw_ostream &operator<<(raw_ostream &OS, const FunctionSamples &FS);
924 
925 /// Sort a LocationT->SampleT map by LocationT.
926 ///
927 /// It produces a sorted list of <LocationT, SampleT> records by ascending
928 /// order of LocationT.
929 template <class LocationT, class SampleT> class SampleSorter {
930 public:
931   using SamplesWithLoc = std::pair<const LocationT, SampleT>;
932   using SamplesWithLocList = SmallVector<const SamplesWithLoc *, 20>;
933 
934   SampleSorter(const std::map<LocationT, SampleT> &Samples) {
935     for (const auto &I : Samples)
936       V.push_back(&I);
937     llvm::stable_sort(V, [](const SamplesWithLoc *A, const SamplesWithLoc *B) {
938       return A->first < B->first;
939     });
940   }
941 
942   const SamplesWithLocList &get() const { return V; }
943 
944 private:
945   SamplesWithLocList V;
946 };
947 
948 /// ProfileSymbolList records the list of function symbols shown up
949 /// in the binary used to generate the profile. It is useful to
950 /// to discriminate a function being so cold as not to shown up
951 /// in the profile and a function newly added.
952 class ProfileSymbolList {
953 public:
954   /// copy indicates whether we need to copy the underlying memory
955   /// for the input Name.
956   void add(StringRef Name, bool copy = false) {
957     if (!copy) {
958       Syms.insert(Name);
959       return;
960     }
961     Syms.insert(Name.copy(Allocator));
962   }
963 
964   bool contains(StringRef Name) { return Syms.count(Name); }
965 
966   void merge(const ProfileSymbolList &List) {
967     for (auto Sym : List.Syms)
968       add(Sym, true);
969   }
970 
971   unsigned size() { return Syms.size(); }
972 
973   void setToCompress(bool TC) { ToCompress = TC; }
974   bool toCompress() { return ToCompress; }
975 
976   std::error_code read(const uint8_t *Data, uint64_t ListSize);
977   std::error_code write(raw_ostream &OS);
978   void dump(raw_ostream &OS = dbgs()) const;
979 
980 private:
981   // Determine whether or not to compress the symbol list when
982   // writing it into profile. The variable is unused when the symbol
983   // list is read from an existing profile.
984   bool ToCompress = false;
985   DenseSet<StringRef> Syms;
986   BumpPtrAllocator Allocator;
987 };
988 
989 } // end namespace sampleprof
990 } // end namespace llvm
991 
992 #endif // LLVM_PROFILEDATA_SAMPLEPROF_H
993