1 //===-- ProfiledCallGraph.h - Profiled Call Graph ----------------- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef LLVM_TOOLS_LLVM_PROFGEN_PROFILEDCALLGRAPH_H
10 #define LLVM_TOOLS_LLVM_PROFGEN_PROFILEDCALLGRAPH_H
11 
12 #include "llvm/ADT/GraphTraits.h"
13 #include "llvm/ADT/StringMap.h"
14 #include "llvm/ADT/StringRef.h"
15 #include "llvm/ProfileData/SampleProf.h"
16 #include "llvm/ProfileData/SampleProfReader.h"
17 #include "llvm/Transforms/IPO/SampleContextTracker.h"
18 #include <queue>
19 #include <set>
20 
21 using namespace llvm;
22 using namespace sampleprof;
23 
24 namespace llvm {
25 namespace sampleprof {
26 
27 struct ProfiledCallGraphNode {
NameProfiledCallGraphNode28   ProfiledCallGraphNode(StringRef FName = StringRef()) : Name(FName) {}
29   StringRef Name;
30 
31   struct ProfiledCallGraphNodeComparer {
operatorProfiledCallGraphNode::ProfiledCallGraphNodeComparer32     bool operator()(const ProfiledCallGraphNode *L,
33                     const ProfiledCallGraphNode *R) const {
34       return L->Name < R->Name;
35     }
36   };
37   std::set<ProfiledCallGraphNode *, ProfiledCallGraphNodeComparer> Callees;
38 };
39 
40 class ProfiledCallGraph {
41 public:
42   using iterator = std::set<ProfiledCallGraphNode *>::iterator;
43 
44   // Constructor for non-CS profile.
ProfiledCallGraph(StringMap<FunctionSamples> & ProfileMap)45   ProfiledCallGraph(StringMap<FunctionSamples> &ProfileMap) {
46     assert(!FunctionSamples::ProfileIsCS && "CS profile is not handled here");
47     for (const auto &Samples : ProfileMap) {
48       addProfiledCalls(Samples.second);
49     }
50   }
51 
52   // Constructor for CS profile.
ProfiledCallGraph(SampleContextTracker & ContextTracker)53   ProfiledCallGraph(SampleContextTracker &ContextTracker) {
54     // BFS traverse the context profile trie to add call edges for calls shown
55     // in context.
56     std::queue<ContextTrieNode *> Queue;
57     for (auto &Child : ContextTracker.getRootContext().getAllChildContext()) {
58       ContextTrieNode *Callee = &Child.second;
59       addProfiledFunction(Callee->getFuncName());
60       Queue.push(Callee);
61     }
62 
63     while (!Queue.empty()) {
64       ContextTrieNode *Caller = Queue.front();
65       Queue.pop();
66       // Add calls for context. When AddNodeWithSamplesOnly is true, both caller
67       // and callee need to have context profile.
68       // Note that callsite target samples are completely ignored since they can
69       // conflict with the context edges, which are formed by context
70       // compression during profile generation, for cyclic SCCs. This may
71       // further result in an SCC order incompatible with the purely
72       // context-based one, which may in turn block context-based inlining.
73       for (auto &Child : Caller->getAllChildContext()) {
74         ContextTrieNode *Callee = &Child.second;
75         addProfiledFunction(Callee->getFuncName());
76         Queue.push(Callee);
77         addProfiledCall(Caller->getFuncName(), Callee->getFuncName());
78       }
79     }
80   }
81 
begin()82   iterator begin() { return Root.Callees.begin(); }
end()83   iterator end() { return Root.Callees.end(); }
getEntryNode()84   ProfiledCallGraphNode *getEntryNode() { return &Root; }
addProfiledFunction(StringRef Name)85   void addProfiledFunction(StringRef Name) {
86     if (!ProfiledFunctions.count(Name)) {
87       // Link to synthetic root to make sure every node is reachable
88       // from root. This does not affect SCC order.
89       ProfiledFunctions[Name] = ProfiledCallGraphNode(Name);
90       Root.Callees.insert(&ProfiledFunctions[Name]);
91     }
92   }
93 
addProfiledCall(StringRef CallerName,StringRef CalleeName)94   void addProfiledCall(StringRef CallerName, StringRef CalleeName) {
95     assert(ProfiledFunctions.count(CallerName));
96     auto CalleeIt = ProfiledFunctions.find(CalleeName);
97     if (CalleeIt == ProfiledFunctions.end()) {
98       return;
99     }
100     ProfiledFunctions[CallerName].Callees.insert(&CalleeIt->second);
101   }
102 
addProfiledCalls(const FunctionSamples & Samples)103   void addProfiledCalls(const FunctionSamples &Samples) {
104     addProfiledFunction(Samples.getFuncName());
105 
106     for (const auto &Sample : Samples.getBodySamples()) {
107       for (const auto &Target : Sample.second.getCallTargets()) {
108         addProfiledFunction(Target.first());
109         addProfiledCall(Samples.getFuncName(), Target.first());
110       }
111     }
112 
113     for (const auto &CallsiteSamples : Samples.getCallsiteSamples()) {
114       for (const auto &InlinedSamples : CallsiteSamples.second) {
115         addProfiledFunction(InlinedSamples.first);
116         addProfiledCall(Samples.getFuncName(), InlinedSamples.first);
117         addProfiledCalls(InlinedSamples.second);
118       }
119     }
120   }
121 
122 private:
123   ProfiledCallGraphNode Root;
124   StringMap<ProfiledCallGraphNode> ProfiledFunctions;
125 };
126 
127 } // end namespace sampleprof
128 
129 template <> struct GraphTraits<ProfiledCallGraphNode *> {
130   using NodeRef = ProfiledCallGraphNode *;
131   using ChildIteratorType = std::set<ProfiledCallGraphNode *>::iterator;
132 
133   static NodeRef getEntryNode(NodeRef PCGN) { return PCGN; }
134   static ChildIteratorType child_begin(NodeRef N) { return N->Callees.begin(); }
135   static ChildIteratorType child_end(NodeRef N) { return N->Callees.end(); }
136 };
137 
138 template <>
139 struct GraphTraits<ProfiledCallGraph *>
140     : public GraphTraits<ProfiledCallGraphNode *> {
141   static NodeRef getEntryNode(ProfiledCallGraph *PCG) {
142     return PCG->getEntryNode();
143   }
144 
145   static ChildIteratorType nodes_begin(ProfiledCallGraph *PCG) {
146     return PCG->begin();
147   }
148 
149   static ChildIteratorType nodes_end(ProfiledCallGraph *PCG) {
150     return PCG->end();
151   }
152 };
153 
154 } // end namespace llvm
155 
156 #endif
157