1 //===- Transforms/IPO/SampleContextTracker.h --------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// \file
10 /// This file provides the interface for context-sensitive profile tracker used
11 /// by CSSPGO.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #ifndef LLVM_TRANSFORMS_IPO_SAMPLECONTEXTTRACKER_H
16 #define LLVM_TRANSFORMS_IPO_SAMPLECONTEXTTRACKER_H
17 
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ProfileData/SampleProf.h"
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <map>
#include <queue>
#include <string>
#include <unordered_map>
#include <vector>
24 
25 namespace llvm {
26 class CallBase;
27 class DILocation;
28 class Function;
29 class Instruction;
30 
31 // Internal trie tree representation used for tracking context tree and sample
32 // profiles. The path from root node to a given node represents the context of
// that node's profile.
34 class ContextTrieNode {
35 public:
36   ContextTrieNode(ContextTrieNode *Parent = nullptr,
37                   StringRef FName = StringRef(),
38                   FunctionSamples *FSamples = nullptr,
39                   LineLocation CallLoc = {0, 0})
40       : ParentContext(Parent), FuncName(FName), FuncSamples(FSamples),
41         CallSiteLoc(CallLoc){};
42   ContextTrieNode *getChildContext(const LineLocation &CallSite,
43                                    StringRef ChildName);
44   ContextTrieNode *getHottestChildContext(const LineLocation &CallSite);
45   ContextTrieNode *getOrCreateChildContext(const LineLocation &CallSite,
46                                            StringRef ChildName,
47                                            bool AllowCreate = true);
48   void removeChildContext(const LineLocation &CallSite, StringRef ChildName);
49   std::map<uint64_t, ContextTrieNode> &getAllChildContext();
50   StringRef getFuncName() const;
51   FunctionSamples *getFunctionSamples() const;
52   void setFunctionSamples(FunctionSamples *FSamples);
53   Optional<uint32_t> getFunctionSize() const;
54   void addFunctionSize(uint32_t FSize);
55   LineLocation getCallSiteLoc() const;
56   ContextTrieNode *getParentContext() const;
57   void setParentContext(ContextTrieNode *Parent);
58   void setCallSiteLoc(const LineLocation &Loc);
59   void dumpNode();
60   void dumpTree();
61 
62 private:
63   // Map line+discriminator location to child context
64   std::map<uint64_t, ContextTrieNode> AllChildContext;
65 
66   // Link to parent context node
67   ContextTrieNode *ParentContext;
68 
69   // Function name for current context
70   StringRef FuncName;
71 
72   // Function Samples for current context
73   FunctionSamples *FuncSamples;
74 
75   // Function size for current context
76   Optional<uint32_t> FuncSize;
77 
78   // Callsite location in parent context
79   LineLocation CallSiteLoc;
80 };
81 
// Profile tracker that manages profiles and their associated contexts. It
83 // provides interfaces used by sample profile loader to query context profile or
84 // base profile for given function or location; it also manages context tree
85 // manipulation that is needed to accommodate inline decisions so we have
86 // accurate post-inline profile for functions. Internally context profiles
87 // are organized in a trie, with each node representing profile for specific
88 // calling context and the context is identified by path from root to the node.
89 class SampleContextTracker {
90 public:
91   using ContextSamplesTy = std::vector<FunctionSamples *>;
92 
93   SampleContextTracker() = default;
94   SampleContextTracker(SampleProfileMap &Profiles,
95                        const DenseMap<uint64_t, StringRef> *GUIDToFuncNameMap);
96   // Populate the FuncToCtxtProfiles map after the trie is built.
97   void populateFuncToCtxtMap();
98   // Query context profile for a specific callee with given name at a given
99   // call-site. The full context is identified by location of call instruction.
100   FunctionSamples *getCalleeContextSamplesFor(const CallBase &Inst,
101                                               StringRef CalleeName);
102   // Get samples for indirect call targets for call site at given location.
103   std::vector<const FunctionSamples *>
104   getIndirectCalleeContextSamplesFor(const DILocation *DIL);
105   // Query context profile for a given location. The full context
106   // is identified by input DILocation.
107   FunctionSamples *getContextSamplesFor(const DILocation *DIL);
108   // Query context profile for a given sample contxt of a function.
109   FunctionSamples *getContextSamplesFor(const SampleContext &Context);
110   // Get all context profile for given function.
111   ContextSamplesTy &getAllContextSamplesFor(const Function &Func);
112   ContextSamplesTy &getAllContextSamplesFor(StringRef Name);
113   ContextTrieNode *getOrCreateContextPath(const SampleContext &Context,
114                                           bool AllowCreate);
115   // Query base profile for a given function. A base profile is a merged view
116   // of all context profiles for contexts that are not inlined.
117   FunctionSamples *getBaseSamplesFor(const Function &Func,
118                                      bool MergeContext = true);
119   // Query base profile for a given function by name.
120   FunctionSamples *getBaseSamplesFor(StringRef Name, bool MergeContext = true);
121   // Retrieve the context trie node for given profile context
122   ContextTrieNode *getContextFor(const SampleContext &Context);
123   // Get real function name for a given trie node.
124   StringRef getFuncNameFor(ContextTrieNode *Node) const;
125   // Mark a context profile as inlined when function is inlined.
126   // This makes sure that inlined context profile will be excluded in
127   // function's base profile.
128   void markContextSamplesInlined(const FunctionSamples *InlinedSamples);
129   ContextTrieNode &getRootContext();
130   void promoteMergeContextSamplesTree(const Instruction &Inst,
131                                       StringRef CalleeName);
132 
133   // Create a merged conext-less profile map.
134   void createContextLessProfileMap(SampleProfileMap &ContextLessProfiles);
135   ContextTrieNode *
136   getContextNodeForProfile(const FunctionSamples *FSamples) const {
137     auto I = ProfileToNodeMap.find(FSamples);
138     if (I == ProfileToNodeMap.end())
139       return nullptr;
140     return I->second;
141   }
142   StringMap<ContextSamplesTy> &getFuncToCtxtProfiles() {
143     return FuncToCtxtProfiles;
144   }
145 
146   class Iterator : public std::iterator<std::forward_iterator_tag,
147                                         const ContextTrieNode *> {
148     std::queue<ContextTrieNode *> NodeQueue;
149 
150   public:
151     explicit Iterator() = default;
152     explicit Iterator(ContextTrieNode *Node) { NodeQueue.push(Node); }
153     Iterator &operator++() {
154       assert(!NodeQueue.empty() && "Iterator already at the end");
155       ContextTrieNode *Node = NodeQueue.front();
156       NodeQueue.pop();
157       for (auto &It : Node->getAllChildContext())
158         NodeQueue.push(&It.second);
159       return *this;
160     }
161 
162     Iterator operator++(int) {
163       assert(!NodeQueue.empty() && "Iterator already at the end");
164       Iterator Ret = *this;
165       ++(*this);
166       return Ret;
167     }
168     bool operator==(const Iterator &Other) const {
169       if (NodeQueue.empty() && Other.NodeQueue.empty())
170         return true;
171       if (NodeQueue.empty() || Other.NodeQueue.empty())
172         return false;
173       return NodeQueue.front() == Other.NodeQueue.front();
174     }
175     bool operator!=(const Iterator &Other) const { return !(*this == Other); }
176     ContextTrieNode *operator*() const {
177       assert(!NodeQueue.empty() && "Invalid access to end iterator");
178       return NodeQueue.front();
179     }
180   };
181 
182   Iterator begin() { return Iterator(&RootContext); }
183   Iterator end() { return Iterator(); }
184 
185 #ifndef NDEBUG
186   // Get a context string from root to current node.
187   std::string getContextString(const FunctionSamples &FSamples) const;
188   std::string getContextString(ContextTrieNode *Node) const;
189 #endif
190   // Dump the internal context profile trie.
191   void dump();
192 
193 private:
194   ContextTrieNode *getContextFor(const DILocation *DIL);
195   ContextTrieNode *getCalleeContextFor(const DILocation *DIL,
196                                        StringRef CalleeName);
197   ContextTrieNode *getTopLevelContextNode(StringRef FName);
198   ContextTrieNode &addTopLevelContextNode(StringRef FName);
199   ContextTrieNode &promoteMergeContextSamplesTree(ContextTrieNode &NodeToPromo);
200   void mergeContextNode(ContextTrieNode &FromNode, ContextTrieNode &ToNode);
201   ContextTrieNode &
202   promoteMergeContextSamplesTree(ContextTrieNode &FromNode,
203                                  ContextTrieNode &ToNodeParent);
204   ContextTrieNode &moveContextSamples(ContextTrieNode &ToNodeParent,
205                                       const LineLocation &CallSite,
206                                       ContextTrieNode &&NodeToMove);
207   void setContextNode(const FunctionSamples *FSample, ContextTrieNode *Node) {
208     ProfileToNodeMap[FSample] = Node;
209   }
210   // Map from function name to context profiles (excluding base profile)
211   StringMap<ContextSamplesTy> FuncToCtxtProfiles;
212 
213   // Map from current FunctionSample to the belonged context trie.
214   std::unordered_map<const FunctionSamples *, ContextTrieNode *>
215       ProfileToNodeMap;
216 
217   // Map from function guid to real function names. Only used in md5 mode.
218   const DenseMap<uint64_t, StringRef> *GUIDToFuncNameMap;
219 
220   // Root node for context trie tree
221   ContextTrieNode RootContext;
222 };
223 
224 } // end namespace llvm
225 #endif // LLVM_TRANSFORMS_IPO_SAMPLECONTEXTTRACKER_H
226