1 //===- CallPrinter.cpp - DOT printer for call graph -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
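// This file defines '-dot-callgraph', which emits a '<name>.callgraph.dot' file
// containing the call graph of a module, where <name> is the module identifier
// or, if given, the value of '-callgraph-dot-filename-prefix'.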
11 //
12 // There is also a pass available to directly call dotty ('-view-callgraph').
13 //
14 //===----------------------------------------------------------------------===//
15 
16 #include "llvm/Analysis/CallPrinter.h"
17 #include "llvm/ADT/DenseMap.h"
18 #include "llvm/ADT/SmallSet.h"
19 #include "llvm/Analysis/BlockFrequencyInfo.h"
20 #include "llvm/Analysis/CallGraph.h"
21 #include "llvm/Analysis/HeatUtils.h"
22 #include "llvm/IR/Instructions.h"
23 #include "llvm/InitializePasses.h"
24 #include "llvm/Support/CommandLine.h"
25 #include "llvm/Support/DOTGraphTraits.h"
26 #include "llvm/Support/GraphWriter.h"
27 
28 using namespace llvm;
29 
// Forward declaration of the generic graph-traits template, specialized for
// CallGraphDOTInfo further down in this file.
namespace llvm {
template <class GraphType> struct GraphTraits;
} // namespace llvm
33 
34 // This option shows static (relative) call counts.
35 // FIXME:
36 // Need to show real counts when profile data is available
37 static cl::opt<bool> ShowHeatColors("callgraph-heat-colors", cl::init(false),
38                                     cl::Hidden,
39                                     cl::desc("Show heat colors in call-graph"));
40 
41 static cl::opt<bool>
42     ShowEdgeWeight("callgraph-show-weights", cl::init(false), cl::Hidden,
43                        cl::desc("Show edges labeled with weights"));
44 
45 static cl::opt<bool>
46     CallMultiGraph("callgraph-multigraph", cl::init(false), cl::Hidden,
47             cl::desc("Show call-multigraph (do not remove parallel edges)"));
48 
49 static cl::opt<std::string> CallGraphDotFilenamePrefix(
50     "callgraph-dot-filename-prefix", cl::Hidden,
51     cl::desc("The prefix used for the CallGraph dot file names."));
52 
53 namespace llvm {
54 
55 class CallGraphDOTInfo {
56 private:
57   Module *M;
58   CallGraph *CG;
59   DenseMap<const Function *, uint64_t> Freq;
60   uint64_t MaxFreq;
61 
62 public:
63   std::function<BlockFrequencyInfo *(Function &)> LookupBFI;
64 
65   CallGraphDOTInfo(Module *M, CallGraph *CG,
66                    function_ref<BlockFrequencyInfo *(Function &)> LookupBFI)
67       : M(M), CG(CG), LookupBFI(LookupBFI) {
68     MaxFreq = 0;
69 
70     for (Function &F : M->getFunctionList()) {
71       uint64_t localSumFreq = 0;
72       SmallSet<Function *, 16> Callers;
73       for (User *U : F.users())
74         if (isa<CallInst>(U))
75           Callers.insert(cast<Instruction>(U)->getFunction());
76       for (Function *Caller : Callers)
77         localSumFreq += getNumOfCalls(*Caller, F);
78       if (localSumFreq >= MaxFreq)
79         MaxFreq = localSumFreq;
80       Freq[&F] = localSumFreq;
81     }
82     if (!CallMultiGraph)
83       removeParallelEdges();
84   }
85 
86   Module *getModule() const { return M; }
87 
88   CallGraph *getCallGraph() const { return CG; }
89 
90   uint64_t getFreq(const Function *F) { return Freq[F]; }
91 
92   uint64_t getMaxFreq() { return MaxFreq; }
93 
94 private:
95   void removeParallelEdges() {
96     for (auto &I : (*CG)) {
97       CallGraphNode *Node = I.second.get();
98 
99       bool FoundParallelEdge = true;
100       while (FoundParallelEdge) {
101         SmallSet<Function *, 16> Visited;
102         FoundParallelEdge = false;
103         for (auto CI = Node->begin(), CE = Node->end(); CI != CE; CI++) {
104           if (!(Visited.insert(CI->second->getFunction())).second) {
105             FoundParallelEdge = true;
106             Node->removeCallEdge(CI);
107             break;
108           }
109         }
110       }
111     }
112   }
113 };
114 
115 template <>
116 struct GraphTraits<CallGraphDOTInfo *>
117     : public GraphTraits<const CallGraphNode *> {
118   static NodeRef getEntryNode(CallGraphDOTInfo *CGInfo) {
119     // Start at the external node!
120     return CGInfo->getCallGraph()->getExternalCallingNode();
121   }
122 
123   typedef std::pair<const Function *const, std::unique_ptr<CallGraphNode>>
124       PairTy;
125   static const CallGraphNode *CGGetValuePtr(const PairTy &P) {
126     return P.second.get();
127   }
128 
129   // nodes_iterator/begin/end - Allow iteration over all nodes in the graph
130   typedef mapped_iterator<CallGraph::const_iterator, decltype(&CGGetValuePtr)>
131       nodes_iterator;
132 
133   static nodes_iterator nodes_begin(CallGraphDOTInfo *CGInfo) {
134     return nodes_iterator(CGInfo->getCallGraph()->begin(), &CGGetValuePtr);
135   }
136   static nodes_iterator nodes_end(CallGraphDOTInfo *CGInfo) {
137     return nodes_iterator(CGInfo->getCallGraph()->end(), &CGGetValuePtr);
138   }
139 };
140 
141 template <>
142 struct DOTGraphTraits<CallGraphDOTInfo *> : public DefaultDOTGraphTraits {
143 
144   DOTGraphTraits(bool isSimple = false) : DefaultDOTGraphTraits(isSimple) {}
145 
146   static std::string getGraphName(CallGraphDOTInfo *CGInfo) {
147     return "Call graph: " +
148            std::string(CGInfo->getModule()->getModuleIdentifier());
149   }
150 
151   static bool isNodeHidden(const CallGraphNode *Node,
152                            const CallGraphDOTInfo *CGInfo) {
153     if (CallMultiGraph || Node->getFunction())
154       return false;
155     return true;
156   }
157 
158   std::string getNodeLabel(const CallGraphNode *Node,
159                            CallGraphDOTInfo *CGInfo) {
160     if (Node == CGInfo->getCallGraph()->getExternalCallingNode())
161       return "external caller";
162     if (Node == CGInfo->getCallGraph()->getCallsExternalNode())
163       return "external callee";
164 
165     if (Function *Func = Node->getFunction())
166       return std::string(Func->getName());
167     return "external node";
168   }
169   static const CallGraphNode *CGGetValuePtr(CallGraphNode::CallRecord P) {
170     return P.second;
171   }
172 
173   // nodes_iterator/begin/end - Allow iteration over all nodes in the graph
174   typedef mapped_iterator<CallGraphNode::const_iterator,
175                           decltype(&CGGetValuePtr)>
176       nodes_iterator;
177 
178   std::string getEdgeAttributes(const CallGraphNode *Node, nodes_iterator I,
179                                 CallGraphDOTInfo *CGInfo) {
180     if (!ShowEdgeWeight)
181       return "";
182 
183     Function *Caller = Node->getFunction();
184     if (Caller == nullptr || Caller->isDeclaration())
185       return "";
186 
187     Function *Callee = (*I)->getFunction();
188     if (Callee == nullptr)
189       return "";
190 
191     uint64_t Counter = getNumOfCalls(*Caller, *Callee);
192     double Width =
193         1 + 2 * (double(Counter) / CGInfo->getMaxFreq());
194     std::string Attrs = "label=\"" + std::to_string(Counter) +
195                         "\" penwidth=" + std::to_string(Width);
196     return Attrs;
197   }
198 
199   std::string getNodeAttributes(const CallGraphNode *Node,
200                                 CallGraphDOTInfo *CGInfo) {
201     Function *F = Node->getFunction();
202     if (F == nullptr)
203       return "";
204     std::string attrs;
205     if (ShowHeatColors) {
206       uint64_t freq = CGInfo->getFreq(F);
207       std::string color = getHeatColor(freq, CGInfo->getMaxFreq());
208       std::string edgeColor = (freq <= (CGInfo->getMaxFreq() / 2))
209                                   ? getHeatColor(0)
210                                   : getHeatColor(1);
211       attrs = "color=\"" + edgeColor + "ff\", style=filled, fillcolor=\"" +
212               color + "80\"";
213     }
214     return attrs;
215   }
216 };
217 
218 } // end llvm namespace
219 
220 namespace {
221 void doCallGraphDOTPrinting(
222     Module &M, function_ref<BlockFrequencyInfo *(Function &)> LookupBFI) {
223   std::string Filename;
224   if (!CallGraphDotFilenamePrefix.empty())
225     Filename = (CallGraphDotFilenamePrefix + ".callgraph.dot");
226   else
227     Filename = (std::string(M.getModuleIdentifier()) + ".callgraph.dot");
228   errs() << "Writing '" << Filename << "'...";
229 
230   std::error_code EC;
231   raw_fd_ostream File(Filename, EC, sys::fs::OF_Text);
232 
233   CallGraph CG(M);
234   CallGraphDOTInfo CFGInfo(&M, &CG, LookupBFI);
235 
236   if (!EC)
237     WriteGraph(File, &CFGInfo);
238   else
239     errs() << "  error opening file for writing!";
240   errs() << "\n";
241 }
242 
243 void viewCallGraph(Module &M,
244                    function_ref<BlockFrequencyInfo *(Function &)> LookupBFI) {
245   CallGraph CG(M);
246   CallGraphDOTInfo CFGInfo(&M, &CG, LookupBFI);
247 
248   std::string Title =
249       DOTGraphTraits<CallGraphDOTInfo *>::getGraphName(&CFGInfo);
250   ViewGraph(&CFGInfo, "callgraph", true, Title);
251 }
252 } // namespace
253 
254 namespace llvm {
255 PreservedAnalyses CallGraphDOTPrinterPass::run(Module &M,
256                                                ModuleAnalysisManager &AM) {
257   FunctionAnalysisManager &FAM =
258       AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
259 
260   auto LookupBFI = [&FAM](Function &F) {
261     return &FAM.getResult<BlockFrequencyAnalysis>(F);
262   };
263 
264   doCallGraphDOTPrinting(M, LookupBFI);
265 
266   return PreservedAnalyses::all();
267 }
268 
269 PreservedAnalyses CallGraphViewerPass::run(Module &M,
270                                            ModuleAnalysisManager &AM) {
271 
272   FunctionAnalysisManager &FAM =
273       AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
274 
275   auto LookupBFI = [&FAM](Function &F) {
276     return &FAM.getResult<BlockFrequencyAnalysis>(F);
277   };
278 
279   viewCallGraph(M, LookupBFI);
280 
281   return PreservedAnalyses::all();
282 }
283 } // namespace llvm
284 
285 namespace {
286 // Viewer
287 class CallGraphViewer : public ModulePass {
288 public:
289   static char ID;
290   CallGraphViewer() : ModulePass(ID) {}
291 
292   void getAnalysisUsage(AnalysisUsage &AU) const override;
293   bool runOnModule(Module &M) override;
294 };
295 
296 void CallGraphViewer::getAnalysisUsage(AnalysisUsage &AU) const {
297   ModulePass::getAnalysisUsage(AU);
298   AU.addRequired<BlockFrequencyInfoWrapperPass>();
299   AU.setPreservesAll();
300 }
301 
302 bool CallGraphViewer::runOnModule(Module &M) {
303   auto LookupBFI = [this](Function &F) {
304     return &this->getAnalysis<BlockFrequencyInfoWrapperPass>(F).getBFI();
305   };
306 
307   viewCallGraph(M, LookupBFI);
308 
309   return false;
310 }
311 
312 // DOT Printer
313 
314 class CallGraphDOTPrinter : public ModulePass {
315 public:
316   static char ID;
317   CallGraphDOTPrinter() : ModulePass(ID) {}
318 
319   void getAnalysisUsage(AnalysisUsage &AU) const override;
320   bool runOnModule(Module &M) override;
321 };
322 
323 void CallGraphDOTPrinter::getAnalysisUsage(AnalysisUsage &AU) const {
324   ModulePass::getAnalysisUsage(AU);
325   AU.addRequired<BlockFrequencyInfoWrapperPass>();
326   AU.setPreservesAll();
327 }
328 
329 bool CallGraphDOTPrinter::runOnModule(Module &M) {
330   auto LookupBFI = [this](Function &F) {
331     return &this->getAnalysis<BlockFrequencyInfoWrapperPass>(F).getBFI();
332   };
333 
334   doCallGraphDOTPrinting(M, LookupBFI);
335 
336   return false;
337 }
338 
339 } // end anonymous namespace
340 
// Legacy pass registration: "view-callgraph" ("View call graph") for the
// viewer pass, and "dot-callgraph" ("Print call graph to 'dot' file") for the
// DOT file printer.
char CallGraphViewer::ID = 0;
INITIALIZE_PASS(CallGraphViewer, "view-callgraph", "View call graph", false,
                false)

char CallGraphDOTPrinter::ID = 0;
INITIALIZE_PASS(CallGraphDOTPrinter, "dot-callgraph",
                "Print call graph to 'dot' file", false, false)
348 
// Factory functions exported from this file so that
// "include/llvm/LinkAllPasses.h" can reference them; otherwise the passes could
// be discarded as unreferenced during link-time optimization.
352 
353 ModulePass *llvm::createCallGraphViewerPass() { return new CallGraphViewer(); }
354 
355 ModulePass *llvm::createCallGraphDOTPrinterPass() {
356   return new CallGraphDOTPrinter();
357 }
358