1 //===- Standard pass instrumentations handling ----------------*- C++ -*--===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 /// \file
9 ///
/// This file defines IR-printing pass instrumentation callbacks as well as
/// the StandardInstrumentations class that manages standard pass
/// instrumentations.
12 ///
13 //===----------------------------------------------------------------------===//
14 
15 #include "llvm/Passes/StandardInstrumentations.h"
16 #include "llvm/ADT/Optional.h"
17 #include "llvm/Analysis/CallGraphSCCPass.h"
18 #include "llvm/Analysis/LazyCallGraph.h"
19 #include "llvm/Analysis/LoopInfo.h"
20 #include "llvm/IR/Function.h"
21 #include "llvm/IR/IRPrintingPasses.h"
22 #include "llvm/IR/Module.h"
23 #include "llvm/IR/PassInstrumentation.h"
24 #include "llvm/Support/Debug.h"
25 #include "llvm/Support/FormatVariadic.h"
26 #include "llvm/Support/raw_ostream.h"
27 
28 using namespace llvm;
29 
30 namespace {
31 
32 /// Extracting Module out of \p IR unit. Also fills a textual description
33 /// of \p IR for use in header when printing.
34 Optional<std::pair<const Module *, std::string>> unwrapModule(Any IR) {
35   if (any_isa<const Module *>(IR))
36     return std::make_pair(any_cast<const Module *>(IR), std::string());
37 
38   if (any_isa<const Function *>(IR)) {
39     const Function *F = any_cast<const Function *>(IR);
40     if (!llvm::isFunctionInPrintList(F->getName()))
41       return None;
42     const Module *M = F->getParent();
43     return std::make_pair(M, formatv(" (function: {0})", F->getName()).str());
44   }
45 
46   if (any_isa<const LazyCallGraph::SCC *>(IR)) {
47     const LazyCallGraph::SCC *C = any_cast<const LazyCallGraph::SCC *>(IR);
48     for (const LazyCallGraph::Node &N : *C) {
49       const Function &F = N.getFunction();
50       if (!F.isDeclaration() && isFunctionInPrintList(F.getName())) {
51         const Module *M = F.getParent();
52         return std::make_pair(M, formatv(" (scc: {0})", C->getName()).str());
53       }
54     }
55     return None;
56   }
57 
58   if (any_isa<const Loop *>(IR)) {
59     const Loop *L = any_cast<const Loop *>(IR);
60     const Function *F = L->getHeader()->getParent();
61     if (!isFunctionInPrintList(F->getName()))
62       return None;
63     const Module *M = F->getParent();
64     std::string LoopName;
65     raw_string_ostream ss(LoopName);
66     L->getHeader()->printAsOperand(ss, false);
67     return std::make_pair(M, formatv(" (loop: {0})", ss.str()).str());
68   }
69 
70   llvm_unreachable("Unknown IR unit");
71 }
72 
73 void printIR(const Module *M, StringRef Banner, StringRef Extra = StringRef()) {
74   dbgs() << Banner << Extra << "\n";
75   M->print(dbgs(), nullptr, false);
76 }
77 void printIR(const Function *F, StringRef Banner,
78              StringRef Extra = StringRef()) {
79   if (!llvm::isFunctionInPrintList(F->getName()))
80     return;
81   dbgs() << Banner << Extra << "\n" << static_cast<const Value &>(*F);
82 }
83 void printIR(const LazyCallGraph::SCC *C, StringRef Banner,
84              StringRef Extra = StringRef()) {
85   bool BannerPrinted = false;
86   for (const LazyCallGraph::Node &N : *C) {
87     const Function &F = N.getFunction();
88     if (!F.isDeclaration() && llvm::isFunctionInPrintList(F.getName())) {
89       if (!BannerPrinted) {
90         dbgs() << Banner << Extra << "\n";
91         BannerPrinted = true;
92       }
93       F.print(dbgs());
94     }
95   }
96 }
97 void printIR(const Loop *L, StringRef Banner) {
98   const Function *F = L->getHeader()->getParent();
99   if (!llvm::isFunctionInPrintList(F->getName()))
100     return;
101   llvm::printLoop(const_cast<Loop &>(*L), dbgs(), Banner);
102 }
103 
104 /// Generic IR-printing helper that unpacks a pointer to IRUnit wrapped into
105 /// llvm::Any and does actual print job.
106 void unwrapAndPrint(Any IR, StringRef Banner, bool ForceModule = false) {
107   if (ForceModule) {
108     if (auto UnwrappedModule = unwrapModule(IR))
109       printIR(UnwrappedModule->first, Banner, UnwrappedModule->second);
110     return;
111   }
112 
113   if (any_isa<const Module *>(IR)) {
114     const Module *M = any_cast<const Module *>(IR);
115     assert(M && "module should be valid for printing");
116     printIR(M, Banner);
117     return;
118   }
119 
120   if (any_isa<const Function *>(IR)) {
121     const Function *F = any_cast<const Function *>(IR);
122     assert(F && "function should be valid for printing");
123     printIR(F, Banner);
124     return;
125   }
126 
127   if (any_isa<const LazyCallGraph::SCC *>(IR)) {
128     const LazyCallGraph::SCC *C = any_cast<const LazyCallGraph::SCC *>(IR);
129     assert(C && "scc should be valid for printing");
130     std::string Extra = formatv(" (scc: {0})", C->getName());
131     printIR(C, Banner, Extra);
132     return;
133   }
134 
135   if (any_isa<const Loop *>(IR)) {
136     const Loop *L = any_cast<const Loop *>(IR);
137     assert(L && "Loop should be valid for printing");
138     printIR(L, Banner);
139     return;
140   }
141   llvm_unreachable("Unknown wrapped IR type");
142 }
143 
144 } // namespace
145 
146 PrintIRInstrumentation::~PrintIRInstrumentation() {
147   assert(ModuleDescStack.empty() && "ModuleDescStack is not empty at exit");
148 }
149 
150 void PrintIRInstrumentation::pushModuleDesc(StringRef PassID, Any IR) {
151   assert(StoreModuleDesc);
152   const Module *M = nullptr;
153   std::string Extra;
154   if (auto UnwrappedModule = unwrapModule(IR))
155     std::tie(M, Extra) = UnwrappedModule.getValue();
156   ModuleDescStack.emplace_back(M, Extra, PassID);
157 }
158 
159 PrintIRInstrumentation::PrintModuleDesc
160 PrintIRInstrumentation::popModuleDesc(StringRef PassID) {
161   assert(!ModuleDescStack.empty() && "empty ModuleDescStack");
162   PrintModuleDesc ModuleDesc = ModuleDescStack.pop_back_val();
163   assert(std::get<2>(ModuleDesc).equals(PassID) && "malformed ModuleDescStack");
164   return ModuleDesc;
165 }
166 
167 bool PrintIRInstrumentation::printBeforePass(StringRef PassID, Any IR) {
168   if (PassID.startswith("PassManager<") || PassID.contains("PassAdaptor<"))
169     return true;
170 
171   // Saving Module for AfterPassInvalidated operations.
172   // Note: here we rely on a fact that we do not change modules while
173   // traversing the pipeline, so the latest captured module is good
174   // for all print operations that has not happen yet.
175   if (StoreModuleDesc && llvm::shouldPrintAfterPass(PassID))
176     pushModuleDesc(PassID, IR);
177 
178   if (!llvm::shouldPrintBeforePass(PassID))
179     return true;
180 
181   SmallString<20> Banner = formatv("*** IR Dump Before {0} ***", PassID);
182   unwrapAndPrint(IR, Banner, llvm::forcePrintModuleIR());
183   return true;
184 }
185 
186 void PrintIRInstrumentation::printAfterPass(StringRef PassID, Any IR) {
187   if (PassID.startswith("PassManager<") || PassID.contains("PassAdaptor<"))
188     return;
189 
190   if (!llvm::shouldPrintAfterPass(PassID))
191     return;
192 
193   if (StoreModuleDesc)
194     popModuleDesc(PassID);
195 
196   SmallString<20> Banner = formatv("*** IR Dump After {0} ***", PassID);
197   unwrapAndPrint(IR, Banner, llvm::forcePrintModuleIR());
198 }
199 
200 void PrintIRInstrumentation::printAfterPassInvalidated(StringRef PassID) {
201   if (!StoreModuleDesc || !llvm::shouldPrintAfterPass(PassID))
202     return;
203 
204   if (PassID.startswith("PassManager<") || PassID.contains("PassAdaptor<"))
205     return;
206 
207   const Module *M;
208   std::string Extra;
209   StringRef StoredPassID;
210   std::tie(M, Extra, StoredPassID) = popModuleDesc(PassID);
211   // Additional filtering (e.g. -filter-print-func) can lead to module
212   // printing being skipped.
213   if (!M)
214     return;
215 
216   SmallString<20> Banner =
217       formatv("*** IR Dump After {0} *** invalidated: ", PassID);
218   printIR(M, Banner, Extra);
219 }
220 
221 void PrintIRInstrumentation::registerCallbacks(
222     PassInstrumentationCallbacks &PIC) {
223   // BeforePass callback is not just for printing, it also saves a Module
224   // for later use in AfterPassInvalidated.
225   StoreModuleDesc = llvm::forcePrintModuleIR() && llvm::shouldPrintAfterPass();
226   if (llvm::shouldPrintBeforePass() || StoreModuleDesc)
227     PIC.registerBeforePassCallback(
228         [this](StringRef P, Any IR) { return this->printBeforePass(P, IR); });
229 
230   if (llvm::shouldPrintAfterPass()) {
231     PIC.registerAfterPassCallback(
232         [this](StringRef P, Any IR) { this->printAfterPass(P, IR); });
233     PIC.registerAfterPassInvalidatedCallback(
234         [this](StringRef P) { this->printAfterPassInvalidated(P); });
235   }
236 }
237 
238 void StandardInstrumentations::registerCallbacks(
239     PassInstrumentationCallbacks &PIC) {
240   PrintIR.registerCallbacks(PIC);
241   TimePasses.registerCallbacks(PIC);
242 }
243