1 //===- IRPrinting.cpp -----------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "PassDetail.h"
10 #include "mlir/Pass/PassManager.h"
11 #include "llvm/Support/Format.h"
12 #include "llvm/Support/FormatVariadic.h"
13 #include "llvm/Support/SHA1.h"
14
15 using namespace mlir;
16 using namespace mlir::detail;
17
18 namespace {
19 //===----------------------------------------------------------------------===//
20 // OperationFingerPrint
21 //===----------------------------------------------------------------------===//
22
23 /// A unique fingerprint for a specific operation, and all of it's internal
24 /// operations.
25 class OperationFingerPrint {
26 public:
OperationFingerPrint(Operation * topOp)27 OperationFingerPrint(Operation *topOp) {
28 llvm::SHA1 hasher;
29
30 // Hash each of the operations based upon their mutable bits:
31 topOp->walk([&](Operation *op) {
32 // - Operation pointer
33 addDataToHash(hasher, op);
34 // - Attributes
35 addDataToHash(hasher, op->getAttrDictionary());
36 // - Blocks in Regions
37 for (Region ®ion : op->getRegions()) {
38 for (Block &block : region) {
39 addDataToHash(hasher, &block);
40 for (BlockArgument arg : block.getArguments())
41 addDataToHash(hasher, arg);
42 }
43 }
44 // - Location
45 addDataToHash(hasher, op->getLoc().getAsOpaquePointer());
46 // - Operands
47 for (Value operand : op->getOperands())
48 addDataToHash(hasher, operand);
49 // - Successors
50 for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i)
51 addDataToHash(hasher, op->getSuccessor(i));
52 });
53 hash = hasher.result();
54 }
55
operator ==(const OperationFingerPrint & other) const56 bool operator==(const OperationFingerPrint &other) const {
57 return hash == other.hash;
58 }
operator !=(const OperationFingerPrint & other) const59 bool operator!=(const OperationFingerPrint &other) const {
60 return !(*this == other);
61 }
62
63 private:
addDataToHash(llvm::SHA1 & hasher,const T & data)64 template <typename T> void addDataToHash(llvm::SHA1 &hasher, const T &data) {
65 hasher.update(
66 ArrayRef<uint8_t>(reinterpret_cast<const uint8_t *>(&data), sizeof(T)));
67 }
68
69 SmallString<20> hash;
70 };
71
72 //===----------------------------------------------------------------------===//
73 // IRPrinter
74 //===----------------------------------------------------------------------===//
75
76 class IRPrinterInstrumentation : public PassInstrumentation {
77 public:
IRPrinterInstrumentation(std::unique_ptr<PassManager::IRPrinterConfig> config)78 IRPrinterInstrumentation(std::unique_ptr<PassManager::IRPrinterConfig> config)
79 : config(std::move(config)) {}
80
81 private:
82 /// Instrumentation hooks.
83 void runBeforePass(Pass *pass, Operation *op) override;
84 void runAfterPass(Pass *pass, Operation *op) override;
85 void runAfterPassFailed(Pass *pass, Operation *op) override;
86
87 /// Configuration to use.
88 std::unique_ptr<PassManager::IRPrinterConfig> config;
89
90 /// The following is a set of fingerprints for operations that are currently
91 /// being operated on in a pass. This field is only used when the
92 /// configuration asked for change detection.
93 DenseMap<Pass *, OperationFingerPrint> beforePassFingerPrints;
94 };
95 } // end anonymous namespace
96
printIR(Operation * op,bool printModuleScope,raw_ostream & out,OpPrintingFlags flags)97 static void printIR(Operation *op, bool printModuleScope, raw_ostream &out,
98 OpPrintingFlags flags) {
99 // Otherwise, check to see if we are not printing at module scope.
100 if (!printModuleScope)
101 return op->print(out << " //----- //\n",
102 op->getBlock() ? flags.useLocalScope() : flags);
103
104 // Otherwise, we are printing at module scope.
105 out << " ('" << op->getName() << "' operation";
106 if (auto symbolName =
107 op->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName()))
108 out << ": @" << symbolName.getValue();
109 out << ") //----- //\n";
110
111 // Find the top-level operation.
112 auto *topLevelOp = op;
113 while (auto *parentOp = topLevelOp->getParentOp())
114 topLevelOp = parentOp;
115 topLevelOp->print(out, flags);
116 }
117
118 /// Instrumentation hooks.
runBeforePass(Pass * pass,Operation * op)119 void IRPrinterInstrumentation::runBeforePass(Pass *pass, Operation *op) {
120 if (isa<OpToOpPassAdaptor>(pass))
121 return;
122 // If the config asked to detect changes, record the current fingerprint.
123 if (config->shouldPrintAfterOnlyOnChange())
124 beforePassFingerPrints.try_emplace(pass, op);
125
126 config->printBeforeIfEnabled(pass, op, [&](raw_ostream &out) {
127 out << "// -----// IR Dump Before " << pass->getName();
128 printIR(op, config->shouldPrintAtModuleScope(), out,
129 config->getOpPrintingFlags());
130 out << "\n\n";
131 });
132 }
133
runAfterPass(Pass * pass,Operation * op)134 void IRPrinterInstrumentation::runAfterPass(Pass *pass, Operation *op) {
135 if (isa<OpToOpPassAdaptor>(pass))
136 return;
137
138 // Check to see if we are only printing on failure.
139 if (config->shouldPrintAfterOnlyOnFailure())
140 return;
141
142 // If the config asked to detect changes, compare the current fingerprint with
143 // the previous.
144 if (config->shouldPrintAfterOnlyOnChange()) {
145 auto fingerPrintIt = beforePassFingerPrints.find(pass);
146 assert(fingerPrintIt != beforePassFingerPrints.end() &&
147 "expected valid fingerprint");
148 // If the fingerprints are the same, we don't print the IR.
149 if (fingerPrintIt->second == OperationFingerPrint(op)) {
150 beforePassFingerPrints.erase(fingerPrintIt);
151 return;
152 }
153 beforePassFingerPrints.erase(fingerPrintIt);
154 }
155
156 config->printAfterIfEnabled(pass, op, [&](raw_ostream &out) {
157 out << "// -----// IR Dump After " << pass->getName();
158 printIR(op, config->shouldPrintAtModuleScope(), out,
159 config->getOpPrintingFlags());
160 out << "\n\n";
161 });
162 }
163
runAfterPassFailed(Pass * pass,Operation * op)164 void IRPrinterInstrumentation::runAfterPassFailed(Pass *pass, Operation *op) {
165 if (isa<OpToOpPassAdaptor>(pass))
166 return;
167 if (config->shouldPrintAfterOnlyOnChange())
168 beforePassFingerPrints.erase(pass);
169
170 config->printAfterIfEnabled(pass, op, [&](raw_ostream &out) {
171 out << formatv("// -----// IR Dump After {0} Failed", pass->getName());
172 printIR(op, config->shouldPrintAtModuleScope(), out,
173 OpPrintingFlags().printGenericOpForm());
174 out << "\n\n";
175 });
176 }
177
178 //===----------------------------------------------------------------------===//
179 // IRPrinterConfig
180 //===----------------------------------------------------------------------===//
181
182 /// Initialize the configuration.
IRPrinterConfig(bool printModuleScope,bool printAfterOnlyOnChange,bool printAfterOnlyOnFailure,OpPrintingFlags opPrintingFlags)183 PassManager::IRPrinterConfig::IRPrinterConfig(bool printModuleScope,
184 bool printAfterOnlyOnChange,
185 bool printAfterOnlyOnFailure,
186 OpPrintingFlags opPrintingFlags)
187 : printModuleScope(printModuleScope),
188 printAfterOnlyOnChange(printAfterOnlyOnChange),
189 printAfterOnlyOnFailure(printAfterOnlyOnFailure),
190 opPrintingFlags(opPrintingFlags) {}
~IRPrinterConfig()191 PassManager::IRPrinterConfig::~IRPrinterConfig() {}
192
193 /// A hook that may be overridden by a derived config that checks if the IR
194 /// of 'operation' should be dumped *before* the pass 'pass' has been
195 /// executed. If the IR should be dumped, 'printCallback' should be invoked
196 /// with the stream to dump into.
printBeforeIfEnabled(Pass * pass,Operation * operation,PrintCallbackFn printCallback)197 void PassManager::IRPrinterConfig::printBeforeIfEnabled(
198 Pass *pass, Operation *operation, PrintCallbackFn printCallback) {
199 // By default, never print.
200 }
201
202 /// A hook that may be overridden by a derived config that checks if the IR
203 /// of 'operation' should be dumped *after* the pass 'pass' has been
204 /// executed. If the IR should be dumped, 'printCallback' should be invoked
205 /// with the stream to dump into.
printAfterIfEnabled(Pass * pass,Operation * operation,PrintCallbackFn printCallback)206 void PassManager::IRPrinterConfig::printAfterIfEnabled(
207 Pass *pass, Operation *operation, PrintCallbackFn printCallback) {
208 // By default, never print.
209 }
210
211 //===----------------------------------------------------------------------===//
212 // PassManager
213 //===----------------------------------------------------------------------===//
214
215 namespace {
216 /// Simple wrapper config that allows for the simpler interface defined above.
217 struct BasicIRPrinterConfig : public PassManager::IRPrinterConfig {
BasicIRPrinterConfig__anon87e2fe8e0611::BasicIRPrinterConfig218 BasicIRPrinterConfig(
219 std::function<bool(Pass *, Operation *)> shouldPrintBeforePass,
220 std::function<bool(Pass *, Operation *)> shouldPrintAfterPass,
221 bool printModuleScope, bool printAfterOnlyOnChange,
222 bool printAfterOnlyOnFailure, OpPrintingFlags opPrintingFlags,
223 raw_ostream &out)
224 : IRPrinterConfig(printModuleScope, printAfterOnlyOnChange,
225 printAfterOnlyOnFailure, opPrintingFlags),
226 shouldPrintBeforePass(shouldPrintBeforePass),
227 shouldPrintAfterPass(shouldPrintAfterPass), out(out) {
228 assert((shouldPrintBeforePass || shouldPrintAfterPass) &&
229 "expected at least one valid filter function");
230 }
231
printBeforeIfEnabled__anon87e2fe8e0611::BasicIRPrinterConfig232 void printBeforeIfEnabled(Pass *pass, Operation *operation,
233 PrintCallbackFn printCallback) final {
234 if (shouldPrintBeforePass && shouldPrintBeforePass(pass, operation))
235 printCallback(out);
236 }
237
printAfterIfEnabled__anon87e2fe8e0611::BasicIRPrinterConfig238 void printAfterIfEnabled(Pass *pass, Operation *operation,
239 PrintCallbackFn printCallback) final {
240 if (shouldPrintAfterPass && shouldPrintAfterPass(pass, operation))
241 printCallback(out);
242 }
243
244 /// Filter functions for before and after pass execution.
245 std::function<bool(Pass *, Operation *)> shouldPrintBeforePass;
246 std::function<bool(Pass *, Operation *)> shouldPrintAfterPass;
247
248 /// The stream to output to.
249 raw_ostream &out;
250 };
251 } // end anonymous namespace
252
253 /// Add an instrumentation to print the IR before and after pass execution,
254 /// using the provided configuration.
enableIRPrinting(std::unique_ptr<IRPrinterConfig> config)255 void PassManager::enableIRPrinting(std::unique_ptr<IRPrinterConfig> config) {
256 if (config->shouldPrintAtModuleScope() &&
257 getContext()->isMultithreadingEnabled())
258 llvm::report_fatal_error("IR printing can't be setup on a pass-manager "
259 "without disabling multi-threading first.");
260 addInstrumentation(
261 std::make_unique<IRPrinterInstrumentation>(std::move(config)));
262 }
263
264 /// Add an instrumentation to print the IR before and after pass execution.
enableIRPrinting(std::function<bool (Pass *,Operation *)> shouldPrintBeforePass,std::function<bool (Pass *,Operation *)> shouldPrintAfterPass,bool printModuleScope,bool printAfterOnlyOnChange,bool printAfterOnlyOnFailure,raw_ostream & out,OpPrintingFlags opPrintingFlags)265 void PassManager::enableIRPrinting(
266 std::function<bool(Pass *, Operation *)> shouldPrintBeforePass,
267 std::function<bool(Pass *, Operation *)> shouldPrintAfterPass,
268 bool printModuleScope, bool printAfterOnlyOnChange,
269 bool printAfterOnlyOnFailure, raw_ostream &out,
270 OpPrintingFlags opPrintingFlags) {
271 enableIRPrinting(std::make_unique<BasicIRPrinterConfig>(
272 std::move(shouldPrintBeforePass), std::move(shouldPrintAfterPass),
273 printModuleScope, printAfterOnlyOnChange, printAfterOnlyOnFailure,
274 opPrintingFlags, out));
275 }
276