1 //===- Transforms/IPO/SampleContextTracker.h --------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// \file
10 /// This file provides the interface for context-sensitive profile tracker used
11 /// by CSSPGO.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #ifndef LLVM_TRANSFORMS_IPO_SAMPLECONTEXTTRACKER_H
16 #define LLVM_TRANSFORMS_IPO_SAMPLECONTEXTTRACKER_H
17 
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/iterator.h"
#include "llvm/ProfileData/SampleProf.h"
#include <cassert>
#include <map>
#include <optional>
#include <queue>
#include <string>
#include <unordered_map>
#include <vector>
25 
26 namespace llvm {
27 class CallBase;
28 class DILocation;
29 class Function;
30 class Instruction;
31 
// Internal trie representation used for tracking the context tree and sample
// profiles. The path from the root node to a given node represents the calling
// context of that node's profile.
35 class ContextTrieNode {
36 public:
37   ContextTrieNode(ContextTrieNode *Parent = nullptr,
38                   StringRef FName = StringRef(),
39                   FunctionSamples *FSamples = nullptr,
40                   LineLocation CallLoc = {0, 0})
41       : ParentContext(Parent), FuncName(FName), FuncSamples(FSamples),
42         CallSiteLoc(CallLoc){};
43   ContextTrieNode *getChildContext(const LineLocation &CallSite,
44                                    StringRef ChildName);
45   ContextTrieNode *getHottestChildContext(const LineLocation &CallSite);
46   ContextTrieNode *getOrCreateChildContext(const LineLocation &CallSite,
47                                            StringRef ChildName,
48                                            bool AllowCreate = true);
49   void removeChildContext(const LineLocation &CallSite, StringRef ChildName);
50   std::map<uint64_t, ContextTrieNode> &getAllChildContext();
51   StringRef getFuncName() const;
52   FunctionSamples *getFunctionSamples() const;
53   void setFunctionSamples(FunctionSamples *FSamples);
54   std::optional<uint32_t> getFunctionSize() const;
55   void addFunctionSize(uint32_t FSize);
56   LineLocation getCallSiteLoc() const;
57   ContextTrieNode *getParentContext() const;
58   void setParentContext(ContextTrieNode *Parent);
59   void setCallSiteLoc(const LineLocation &Loc);
60   void dumpNode();
61   void dumpTree();
62 
63 private:
64   // Map line+discriminator location to child context
65   std::map<uint64_t, ContextTrieNode> AllChildContext;
66 
67   // Link to parent context node
68   ContextTrieNode *ParentContext;
69 
70   // Function name for current context
71   StringRef FuncName;
72 
73   // Function Samples for current context
74   FunctionSamples *FuncSamples;
75 
76   // Function size for current context
77   std::optional<uint32_t> FuncSize;
78 
79   // Callsite location in parent context
80   LineLocation CallSiteLoc;
81 };
82 
// Profile tracker that manages profiles and their associated contexts. It
// provides interfaces used by the sample profile loader to query the context
// profile or base profile for a given function or location; it also manages
// the context tree manipulation needed to accommodate inline decisions, so we
// have accurate post-inline profiles for functions. Internally, context
// profiles are organized in a trie: each node represents the profile for a
// specific calling context, and that context is identified by the path from
// the root to the node.
90 class SampleContextTracker {
91 public:
92   using ContextSamplesTy = std::vector<FunctionSamples *>;
93 
94   SampleContextTracker() = default;
95   SampleContextTracker(SampleProfileMap &Profiles,
96                        const DenseMap<uint64_t, StringRef> *GUIDToFuncNameMap);
97   // Populate the FuncToCtxtProfiles map after the trie is built.
98   void populateFuncToCtxtMap();
99   // Query context profile for a specific callee with given name at a given
100   // call-site. The full context is identified by location of call instruction.
101   FunctionSamples *getCalleeContextSamplesFor(const CallBase &Inst,
102                                               StringRef CalleeName);
103   // Get samples for indirect call targets for call site at given location.
104   std::vector<const FunctionSamples *>
105   getIndirectCalleeContextSamplesFor(const DILocation *DIL);
106   // Query context profile for a given location. The full context
107   // is identified by input DILocation.
108   FunctionSamples *getContextSamplesFor(const DILocation *DIL);
109   // Query context profile for a given sample contxt of a function.
110   FunctionSamples *getContextSamplesFor(const SampleContext &Context);
111   // Get all context profile for given function.
112   ContextSamplesTy &getAllContextSamplesFor(const Function &Func);
113   ContextSamplesTy &getAllContextSamplesFor(StringRef Name);
114   ContextTrieNode *getOrCreateContextPath(const SampleContext &Context,
115                                           bool AllowCreate);
116   // Query base profile for a given function. A base profile is a merged view
117   // of all context profiles for contexts that are not inlined.
118   FunctionSamples *getBaseSamplesFor(const Function &Func,
119                                      bool MergeContext = true);
120   // Query base profile for a given function by name.
121   FunctionSamples *getBaseSamplesFor(StringRef Name, bool MergeContext = true);
122   // Retrieve the context trie node for given profile context
123   ContextTrieNode *getContextFor(const SampleContext &Context);
124   // Get real function name for a given trie node.
125   StringRef getFuncNameFor(ContextTrieNode *Node) const;
126   // Mark a context profile as inlined when function is inlined.
127   // This makes sure that inlined context profile will be excluded in
128   // function's base profile.
129   void markContextSamplesInlined(const FunctionSamples *InlinedSamples);
130   ContextTrieNode &getRootContext();
131   void promoteMergeContextSamplesTree(const Instruction &Inst,
132                                       StringRef CalleeName);
133 
134   // Create a merged conext-less profile map.
135   void createContextLessProfileMap(SampleProfileMap &ContextLessProfiles);
136   ContextTrieNode *
137   getContextNodeForProfile(const FunctionSamples *FSamples) const {
138     auto I = ProfileToNodeMap.find(FSamples);
139     if (I == ProfileToNodeMap.end())
140       return nullptr;
141     return I->second;
142   }
143   StringMap<ContextSamplesTy> &getFuncToCtxtProfiles() {
144     return FuncToCtxtProfiles;
145   }
146 
147   class Iterator : public llvm::iterator_facade_base<
148                        Iterator, std::forward_iterator_tag, ContextTrieNode *,
149                        std::ptrdiff_t, ContextTrieNode **, ContextTrieNode *> {
150     std::queue<ContextTrieNode *> NodeQueue;
151 
152   public:
153     explicit Iterator() = default;
154     explicit Iterator(ContextTrieNode *Node) { NodeQueue.push(Node); }
155     Iterator &operator++() {
156       assert(!NodeQueue.empty() && "Iterator already at the end");
157       ContextTrieNode *Node = NodeQueue.front();
158       NodeQueue.pop();
159       for (auto &It : Node->getAllChildContext())
160         NodeQueue.push(&It.second);
161       return *this;
162     }
163 
164     bool operator==(const Iterator &Other) const {
165       if (NodeQueue.empty() && Other.NodeQueue.empty())
166         return true;
167       if (NodeQueue.empty() || Other.NodeQueue.empty())
168         return false;
169       return NodeQueue.front() == Other.NodeQueue.front();
170     }
171 
172     ContextTrieNode *operator*() const {
173       assert(!NodeQueue.empty() && "Invalid access to end iterator");
174       return NodeQueue.front();
175     }
176   };
177 
178   Iterator begin() { return Iterator(&RootContext); }
179   Iterator end() { return Iterator(); }
180 
181 #ifndef NDEBUG
182   // Get a context string from root to current node.
183   std::string getContextString(const FunctionSamples &FSamples) const;
184   std::string getContextString(ContextTrieNode *Node) const;
185 #endif
186   // Dump the internal context profile trie.
187   void dump();
188 
189 private:
190   ContextTrieNode *getContextFor(const DILocation *DIL);
191   ContextTrieNode *getCalleeContextFor(const DILocation *DIL,
192                                        StringRef CalleeName);
193   ContextTrieNode *getTopLevelContextNode(StringRef FName);
194   ContextTrieNode &addTopLevelContextNode(StringRef FName);
195   ContextTrieNode &promoteMergeContextSamplesTree(ContextTrieNode &NodeToPromo);
196   void mergeContextNode(ContextTrieNode &FromNode, ContextTrieNode &ToNode);
197   ContextTrieNode &
198   promoteMergeContextSamplesTree(ContextTrieNode &FromNode,
199                                  ContextTrieNode &ToNodeParent);
200   ContextTrieNode &moveContextSamples(ContextTrieNode &ToNodeParent,
201                                       const LineLocation &CallSite,
202                                       ContextTrieNode &&NodeToMove);
203   void setContextNode(const FunctionSamples *FSample, ContextTrieNode *Node) {
204     ProfileToNodeMap[FSample] = Node;
205   }
206   // Map from function name to context profiles (excluding base profile)
207   StringMap<ContextSamplesTy> FuncToCtxtProfiles;
208 
209   // Map from current FunctionSample to the belonged context trie.
210   std::unordered_map<const FunctionSamples *, ContextTrieNode *>
211       ProfileToNodeMap;
212 
213   // Map from function guid to real function names. Only used in md5 mode.
214   const DenseMap<uint64_t, StringRef> *GUIDToFuncNameMap;
215 
216   // Root node for context trie tree
217   ContextTrieNode RootContext;
218 };
219 
220 } // end namespace llvm
221 #endif // LLVM_TRANSFORMS_IPO_SAMPLECONTEXTTRACKER_H
222