1 //===- PassCrashRecovery.cpp - Pass Crash Recovery Implementation ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PassDetail.h"
10 #include "mlir/IR/Diagnostics.h"
11 #include "mlir/IR/Dialect.h"
12 #include "mlir/IR/Verifier.h"
13 #include "mlir/Pass/Pass.h"
14 #include "mlir/Support/FileUtilities.h"
15 #include "llvm/ADT/STLExtras.h"
16 #include "llvm/ADT/ScopeExit.h"
17 #include "llvm/ADT/SetVector.h"
18 #include "llvm/Support/CommandLine.h"
19 #include "llvm/Support/CrashRecoveryContext.h"
20 #include "llvm/Support/Mutex.h"
21 #include "llvm/Support/Signals.h"
22 #include "llvm/Support/Threading.h"
23 #include "llvm/Support/ToolOutputFile.h"
24 
25 using namespace mlir;
26 using namespace mlir::detail;
27 
28 //===----------------------------------------------------------------------===//
29 // RecoveryReproducerContext
30 //===----------------------------------------------------------------------===//
31 
32 namespace mlir {
33 namespace detail {
34 /// This class contains all of the context for generating a recovery reproducer.
35 /// Each recovery context is registered globally to allow for generating
36 /// reproducers when a signal is raised, such as a segfault.
37 struct RecoveryReproducerContext {
38   RecoveryReproducerContext(std::string passPipelineStr, Operation *op,
39                             PassManager::ReproducerStreamFactory &streamFactory,
40                             bool verifyPasses);
41   ~RecoveryReproducerContext();
42 
43   /// Generate a reproducer with the current context.
44   void generate(std::string &description);
45 
46   /// Disable this reproducer context. This prevents the context from generating
47   /// a reproducer in the result of a crash.
48   void disable();
49 
50   /// Enable a previously disabled reproducer context.
51   void enable();
52 
53 private:
54   /// This function is invoked in the event of a crash.
55   static void crashHandler(void *);
56 
57   /// Register a signal handler to run in the event of a crash.
58   static void registerSignalHandler();
59 
60   /// The textual description of the currently executing pipeline.
61   std::string pipeline;
62 
63   /// The MLIR operation representing the IR before the crash.
64   Operation *preCrashOperation;
65 
66   /// The factory for the reproducer output stream to use when generating the
67   /// reproducer.
68   PassManager::ReproducerStreamFactory &streamFactory;
69 
70   /// Various pass manager and context flags.
71   bool disableThreads;
72   bool verifyPasses;
73 
74   /// The current set of active reproducer contexts. This is used in the event
75   /// of a crash. This is not thread_local as the pass manager may produce any
76   /// number of child threads. This uses a set to allow for multiple MLIR pass
77   /// managers to be running at the same time.
78   static llvm::ManagedStatic<llvm::sys::SmartMutex<true>> reproducerMutex;
79   static llvm::ManagedStatic<
80       llvm::SmallSetVector<RecoveryReproducerContext *, 1>>
81       reproducerSet;
82 };
83 } // namespace detail
84 } // namespace mlir
85 
86 llvm::ManagedStatic<llvm::sys::SmartMutex<true>>
87     RecoveryReproducerContext::reproducerMutex;
88 llvm::ManagedStatic<llvm::SmallSetVector<RecoveryReproducerContext *, 1>>
89     RecoveryReproducerContext::reproducerSet;
90 
RecoveryReproducerContext(std::string passPipelineStr,Operation * op,PassManager::ReproducerStreamFactory & streamFactory,bool verifyPasses)91 RecoveryReproducerContext::RecoveryReproducerContext(
92     std::string passPipelineStr, Operation *op,
93     PassManager::ReproducerStreamFactory &streamFactory, bool verifyPasses)
94     : pipeline(std::move(passPipelineStr)), preCrashOperation(op->clone()),
95       streamFactory(streamFactory),
96       disableThreads(!op->getContext()->isMultithreadingEnabled()),
97       verifyPasses(verifyPasses) {
98   enable();
99 }
100 
~RecoveryReproducerContext()101 RecoveryReproducerContext::~RecoveryReproducerContext() {
102   // Erase the cloned preCrash IR that we cached.
103   preCrashOperation->erase();
104   disable();
105 }
106 
generate(std::string & description)107 void RecoveryReproducerContext::generate(std::string &description) {
108   llvm::raw_string_ostream descOS(description);
109 
110   // Try to create a new output stream for this crash reproducer.
111   std::string error;
112   std::unique_ptr<PassManager::ReproducerStream> stream = streamFactory(error);
113   if (!stream) {
114     descOS << "failed to create output stream: " << error;
115     return;
116   }
117   descOS << "reproducer generated at `" << stream->description() << "`";
118 
119   // Output the current pass manager configuration to the crash stream.
120   auto &os = stream->os();
121   os << "// configuration: -pass-pipeline='" << pipeline << "'";
122   if (disableThreads)
123     os << " -mlir-disable-threading";
124   if (verifyPasses)
125     os << " -verify-each";
126   os << '\n';
127 
128   // Output the .mlir module.
129   preCrashOperation->print(os);
130 }
131 
disable()132 void RecoveryReproducerContext::disable() {
133   llvm::sys::SmartScopedLock<true> lock(*reproducerMutex);
134   reproducerSet->remove(this);
135   if (reproducerSet->empty())
136     llvm::CrashRecoveryContext::Disable();
137 }
138 
enable()139 void RecoveryReproducerContext::enable() {
140   llvm::sys::SmartScopedLock<true> lock(*reproducerMutex);
141   if (reproducerSet->empty())
142     llvm::CrashRecoveryContext::Enable();
143   registerSignalHandler();
144   reproducerSet->insert(this);
145 }
146 
crashHandler(void *)147 void RecoveryReproducerContext::crashHandler(void *) {
148   // Walk the current stack of contexts and generate a reproducer for each one.
149   // We can't know for certain which one was the cause, so we need to generate
150   // a reproducer for all of them.
151   for (RecoveryReproducerContext *context : *reproducerSet) {
152     std::string description;
153     context->generate(description);
154 
155     // Emit an error using information only available within the context.
156     context->preCrashOperation->getContext()->printOpOnDiagnostic(false);
157     context->preCrashOperation->emitError()
158         << "A failure has been detected while processing the MLIR module:"
159         << description;
160   }
161 }
162 
registerSignalHandler()163 void RecoveryReproducerContext::registerSignalHandler() {
164   // Ensure that the handler is only registered once.
165   static bool registered =
166       (llvm::sys::AddSignalHandler(crashHandler, nullptr), false);
167   (void)registered;
168 }
169 
170 //===----------------------------------------------------------------------===//
171 // PassCrashReproducerGenerator
172 //===----------------------------------------------------------------------===//
173 
174 struct PassCrashReproducerGenerator::Impl {
ImplPassCrashReproducerGenerator::Impl175   Impl(PassManager::ReproducerStreamFactory &streamFactory,
176        bool localReproducer)
177       : streamFactory(streamFactory), localReproducer(localReproducer) {}
178 
179   /// The factory to use when generating a crash reproducer.
180   PassManager::ReproducerStreamFactory streamFactory;
181 
182   /// Flag indicating if reproducer generation should be localized to the
183   /// failing pass.
184   bool localReproducer;
185 
186   /// A record of all of the currently active reproducer contexts.
187   SmallVector<std::unique_ptr<RecoveryReproducerContext>> activeContexts;
188 
189   /// The set of all currently running passes. Note: This is not populated when
190   /// `localReproducer` is true, as each pass will get its own recovery context.
191   SetVector<std::pair<Pass *, Operation *>> runningPasses;
192 
193   /// Various pass manager flags that get emitted when generating a reproducer.
194   bool pmFlagVerifyPasses;
195 };
196 
PassCrashReproducerGenerator(PassManager::ReproducerStreamFactory & streamFactory,bool localReproducer)197 PassCrashReproducerGenerator::PassCrashReproducerGenerator(
198     PassManager::ReproducerStreamFactory &streamFactory, bool localReproducer)
199     : impl(std::make_unique<Impl>(streamFactory, localReproducer)) {}
~PassCrashReproducerGenerator()200 PassCrashReproducerGenerator::~PassCrashReproducerGenerator() {}
201 
initialize(iterator_range<PassManager::pass_iterator> passes,Operation * op,bool pmFlagVerifyPasses)202 void PassCrashReproducerGenerator::initialize(
203     iterator_range<PassManager::pass_iterator> passes, Operation *op,
204     bool pmFlagVerifyPasses) {
205   assert((!impl->localReproducer ||
206           !op->getContext()->isMultithreadingEnabled()) &&
207          "expected multi-threading to be disabled when generating a local "
208          "reproducer");
209 
210   llvm::CrashRecoveryContext::Enable();
211   impl->pmFlagVerifyPasses = pmFlagVerifyPasses;
212 
213   // If we aren't generating a local reproducer, prepare a reproducer for the
214   // given top-level operation.
215   if (!impl->localReproducer)
216     prepareReproducerFor(passes, op);
217 }
218 
219 static void
formatPassOpReproducerMessage(Diagnostic & os,std::pair<Pass *,Operation * > passOpPair)220 formatPassOpReproducerMessage(Diagnostic &os,
221                               std::pair<Pass *, Operation *> passOpPair) {
222   os << "`" << passOpPair.first->getName() << "` on "
223      << "'" << passOpPair.second->getName() << "' operation";
224   if (SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(passOpPair.second))
225     os << ": @" << symbol.getName();
226 }
227 
finalize(Operation * rootOp,LogicalResult executionResult)228 void PassCrashReproducerGenerator::finalize(Operation *rootOp,
229                                             LogicalResult executionResult) {
230   // Don't generate a reproducer if we have no active contexts.
231   if (impl->activeContexts.empty())
232     return;
233 
234   // If the pass manager execution succeeded, we don't generate any reproducers.
235   if (succeeded(executionResult))
236     return impl->activeContexts.clear();
237 
238   MLIRContext *context = rootOp->getContext();
239   bool shouldPrintOnOp = context->shouldPrintOpOnDiagnostic();
240   context->printOpOnDiagnostic(false);
241   InFlightDiagnostic diag = rootOp->emitError()
242                             << "Failures have been detected while "
243                                "processing an MLIR pass pipeline";
244   context->printOpOnDiagnostic(shouldPrintOnOp);
245 
246   // If we are generating a global reproducer, we include all of the running
247   // passes in the error message for the only active context.
248   if (!impl->localReproducer) {
249     assert(impl->activeContexts.size() == 1 && "expected one active context");
250 
251     // Generate the reproducer.
252     std::string description;
253     impl->activeContexts.front()->generate(description);
254 
255     // Emit an error to the user.
256     Diagnostic &note = diag.attachNote() << "Pipeline failed while executing [";
257     llvm::interleaveComma(impl->runningPasses, note,
258                           [&](const std::pair<Pass *, Operation *> &value) {
259                             formatPassOpReproducerMessage(note, value);
260                           });
261     note << "]: " << description;
262     return;
263   }
264 
265   // If we were generating a local reproducer, we generate a reproducer for the
266   // most recently executing pass using the matching entry from  `runningPasses`
267   // to generate a localized diagnostic message.
268   assert(impl->activeContexts.size() == impl->runningPasses.size() &&
269          "expected running passes to match active contexts");
270 
271   // Generate the reproducer.
272   RecoveryReproducerContext &reproducerContext = *impl->activeContexts.back();
273   std::string description;
274   reproducerContext.generate(description);
275 
276   // Emit an error to the user.
277   Diagnostic &note = diag.attachNote() << "Pipeline failed while executing ";
278   formatPassOpReproducerMessage(note, impl->runningPasses.back());
279   note << ": " << description;
280 
281   impl->activeContexts.clear();
282 }
283 
prepareReproducerFor(Pass * pass,Operation * op)284 void PassCrashReproducerGenerator::prepareReproducerFor(Pass *pass,
285                                                         Operation *op) {
286   // If not tracking local reproducers, we simply remember that this pass is
287   // running.
288   impl->runningPasses.insert(std::make_pair(pass, op));
289   if (!impl->localReproducer)
290     return;
291 
292   // Disable the current pass recovery context, if there is one. This may happen
293   // in the case of dynamic pass pipelines.
294   if (!impl->activeContexts.empty())
295     impl->activeContexts.back()->disable();
296 
297   // Collect all of the parent scopes of this operation.
298   SmallVector<OperationName> scopes;
299   while (Operation *parentOp = op->getParentOp()) {
300     scopes.push_back(op->getName());
301     op = parentOp;
302   }
303 
304   // Emit a pass pipeline string for the current pass running on the current
305   // operation type.
306   std::string passStr;
307   llvm::raw_string_ostream passOS(passStr);
308   for (OperationName scope : llvm::reverse(scopes))
309     passOS << scope << "(";
310   pass->printAsTextualPipeline(passOS);
311   for (unsigned i = 0, e = scopes.size(); i < e; ++i)
312     passOS << ")";
313 
314   impl->activeContexts.push_back(std::make_unique<RecoveryReproducerContext>(
315       passOS.str(), op, impl->streamFactory, impl->pmFlagVerifyPasses));
316 }
prepareReproducerFor(iterator_range<PassManager::pass_iterator> passes,Operation * op)317 void PassCrashReproducerGenerator::prepareReproducerFor(
318     iterator_range<PassManager::pass_iterator> passes, Operation *op) {
319   std::string passStr;
320   llvm::raw_string_ostream passOS(passStr);
321   llvm::interleaveComma(
322       passes, passOS, [&](Pass &pass) { pass.printAsTextualPipeline(passOS); });
323 
324   impl->activeContexts.push_back(std::make_unique<RecoveryReproducerContext>(
325       passOS.str(), op, impl->streamFactory, impl->pmFlagVerifyPasses));
326 }
327 
removeLastReproducerFor(Pass * pass,Operation * op)328 void PassCrashReproducerGenerator::removeLastReproducerFor(Pass *pass,
329                                                            Operation *op) {
330   // We only pop the active context if we are tracking local reproducers.
331   impl->runningPasses.remove(std::make_pair(pass, op));
332   if (impl->localReproducer) {
333     impl->activeContexts.pop_back();
334 
335     // Re-enable the previous pass recovery context, if there was one. This may
336     // happen in the case of dynamic pass pipelines.
337     if (!impl->activeContexts.empty())
338       impl->activeContexts.back()->enable();
339   }
340 }
341 
342 //===----------------------------------------------------------------------===//
343 // CrashReproducerInstrumentation
344 //===----------------------------------------------------------------------===//
345 
346 namespace {
347 struct CrashReproducerInstrumentation : public PassInstrumentation {
CrashReproducerInstrumentation__anon20bf54df0311::CrashReproducerInstrumentation348   CrashReproducerInstrumentation(PassCrashReproducerGenerator &generator)
349       : generator(generator) {}
350   ~CrashReproducerInstrumentation() override = default;
351 
runBeforePass__anon20bf54df0311::CrashReproducerInstrumentation352   void runBeforePass(Pass *pass, Operation *op) override {
353     if (!isa<OpToOpPassAdaptor>(pass))
354       generator.prepareReproducerFor(pass, op);
355   }
356 
runAfterPass__anon20bf54df0311::CrashReproducerInstrumentation357   void runAfterPass(Pass *pass, Operation *op) override {
358     if (!isa<OpToOpPassAdaptor>(pass))
359       generator.removeLastReproducerFor(pass, op);
360   }
361 
runAfterPassFailed__anon20bf54df0311::CrashReproducerInstrumentation362   void runAfterPassFailed(Pass *pass, Operation *op) override {
363     generator.finalize(op, /*executionResult=*/failure());
364   }
365 
366 private:
367   /// The generator used to create crash reproducers.
368   PassCrashReproducerGenerator &generator;
369 };
370 } // end anonymous namespace
371 
372 //===----------------------------------------------------------------------===//
373 // FileReproducerStream
374 //===----------------------------------------------------------------------===//
375 
376 namespace {
377 /// This class represents a default instance of PassManager::ReproducerStream
378 /// that is backed by a file.
379 struct FileReproducerStream : public PassManager::ReproducerStream {
FileReproducerStream__anon20bf54df0411::FileReproducerStream380   FileReproducerStream(std::unique_ptr<llvm::ToolOutputFile> outputFile)
381       : outputFile(std::move(outputFile)) {}
~FileReproducerStream__anon20bf54df0411::FileReproducerStream382   ~FileReproducerStream() override { outputFile->keep(); }
383 
384   /// Returns a description of the reproducer stream.
description__anon20bf54df0411::FileReproducerStream385   StringRef description() override { return outputFile->getFilename(); }
386 
387   /// Returns the stream on which to output the reproducer.
os__anon20bf54df0411::FileReproducerStream388   raw_ostream &os() override { return outputFile->os(); }
389 
390 private:
391   /// ToolOutputFile corresponding to opened `filename`.
392   std::unique_ptr<llvm::ToolOutputFile> outputFile = nullptr;
393 };
394 } // end anonymous namespace
395 
396 //===----------------------------------------------------------------------===//
397 // PassManager
398 //===----------------------------------------------------------------------===//
399 
runWithCrashRecovery(Operation * op,AnalysisManager am)400 LogicalResult PassManager::runWithCrashRecovery(Operation *op,
401                                                 AnalysisManager am) {
402   crashReproGenerator->initialize(getPasses(), op, verifyPasses);
403 
404   // Safely invoke the passes within a recovery context.
405   LogicalResult passManagerResult = failure();
406   llvm::CrashRecoveryContext recoveryContext;
407   recoveryContext.RunSafelyOnThread(
408       [&] { passManagerResult = runPasses(op, am); });
409   crashReproGenerator->finalize(op, passManagerResult);
410   return passManagerResult;
411 }
412 
enableCrashReproducerGeneration(StringRef outputFile,bool genLocalReproducer)413 void PassManager::enableCrashReproducerGeneration(StringRef outputFile,
414                                                   bool genLocalReproducer) {
415   // Capture the filename by value in case outputFile is out of scope when
416   // invoked.
417   std::string filename = outputFile.str();
418   enableCrashReproducerGeneration(
419       [filename](std::string &error) -> std::unique_ptr<ReproducerStream> {
420         std::unique_ptr<llvm::ToolOutputFile> outputFile =
421             mlir::openOutputFile(filename, &error);
422         if (!outputFile) {
423           error = "Failed to create reproducer stream: " + error;
424           return nullptr;
425         }
426         return std::make_unique<FileReproducerStream>(std::move(outputFile));
427       },
428       genLocalReproducer);
429 }
430 
enableCrashReproducerGeneration(ReproducerStreamFactory factory,bool genLocalReproducer)431 void PassManager::enableCrashReproducerGeneration(
432     ReproducerStreamFactory factory, bool genLocalReproducer) {
433   assert(!crashReproGenerator &&
434          "crash reproducer has already been initialized");
435   if (genLocalReproducer && getContext()->isMultithreadingEnabled())
436     llvm::report_fatal_error(
437         "Local crash reproduction can't be setup on a "
438         "pass-manager without disabling multi-threading first.");
439 
440   crashReproGenerator = std::make_unique<PassCrashReproducerGenerator>(
441       factory, genLocalReproducer);
442   addInstrumentation(
443       std::make_unique<CrashReproducerInstrumentation>(*crashReproGenerator));
444 }
445