1 //===- Pass.cpp - Pass infrastructure implementation ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements common pass infrastructure.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "mlir/Pass/Pass.h"
14 #include "PassDetail.h"
15 #include "mlir/IR/Diagnostics.h"
16 #include "mlir/IR/Dialect.h"
17 #include "mlir/IR/Verifier.h"
18 #include "mlir/Support/FileUtilities.h"
19 #include "llvm/ADT/STLExtras.h"
20 #include "llvm/ADT/ScopeExit.h"
21 #include "llvm/ADT/SetVector.h"
22 #include "llvm/Support/CommandLine.h"
23 #include "llvm/Support/CrashRecoveryContext.h"
24 #include "llvm/Support/Mutex.h"
25 #include "llvm/Support/Parallel.h"
26 #include "llvm/Support/Signals.h"
27 #include "llvm/Support/Threading.h"
28 #include "llvm/Support/ToolOutputFile.h"
29
30 using namespace mlir;
31 using namespace mlir::detail;
32
33 //===----------------------------------------------------------------------===//
34 // Pass
35 //===----------------------------------------------------------------------===//
36
/// Out-of-line virtual method. Defining it in this file anchors Pass's vtable
/// and type metadata here, so they are emitted into a single object file
/// instead of into every translation unit that includes the header.
void Pass::anchor() {}
40
41 /// Attempt to initialize the options of this pass from the given string.
initializeOptions(StringRef options)42 LogicalResult Pass::initializeOptions(StringRef options) {
43 return passOptions.parseFromString(options);
44 }
45
46 /// Copy the option values from 'other', which is another instance of this
47 /// pass.
copyOptionValuesFrom(const Pass * other)48 void Pass::copyOptionValuesFrom(const Pass *other) {
49 passOptions.copyOptionValuesFrom(other->passOptions);
50 }
51
52 /// Prints out the pass in the textual representation of pipelines. If this is
53 /// an adaptor pass, print with the op_name(sub_pass,...) format.
printAsTextualPipeline(raw_ostream & os)54 void Pass::printAsTextualPipeline(raw_ostream &os) {
55 // Special case for adaptors to use the 'op_name(sub_passes)' format.
56 if (auto *adaptor = dyn_cast<OpToOpPassAdaptor>(this)) {
57 llvm::interleaveComma(adaptor->getPassManagers(), os,
58 [&](OpPassManager &pm) {
59 os << pm.getOpName() << "(";
60 pm.printAsTextualPipeline(os);
61 os << ")";
62 });
63 return;
64 }
65 // Otherwise, print the pass argument followed by its options. If the pass
66 // doesn't have an argument, print the name of the pass to give some indicator
67 // of what pass was run.
68 StringRef argument = getArgument();
69 if (!argument.empty())
70 os << argument;
71 else
72 os << "unknown<" << getName() << ">";
73 passOptions.print(os);
74 }
75
76 //===----------------------------------------------------------------------===//
77 // OpPassManagerImpl
78 //===----------------------------------------------------------------------===//
79
80 namespace mlir {
81 namespace detail {
82 struct OpPassManagerImpl {
OpPassManagerImplmlir::detail::OpPassManagerImpl83 OpPassManagerImpl(Identifier identifier, OpPassManager::Nesting nesting)
84 : name(identifier.str()), identifier(identifier),
85 initializationGeneration(0), nesting(nesting) {}
OpPassManagerImplmlir::detail::OpPassManagerImpl86 OpPassManagerImpl(StringRef name, OpPassManager::Nesting nesting)
87 : name(name), initializationGeneration(0), nesting(nesting) {}
88
89 /// Merge the passes of this pass manager into the one provided.
90 void mergeInto(OpPassManagerImpl &rhs);
91
92 /// Nest a new operation pass manager for the given operation kind under this
93 /// pass manager.
94 OpPassManager &nest(Identifier nestedName);
95 OpPassManager &nest(StringRef nestedName);
96
97 /// Add the given pass to this pass manager. If this pass has a concrete
98 /// operation type, it must be the same type as this pass manager.
99 void addPass(std::unique_ptr<Pass> pass);
100
101 /// Coalesce adjacent AdaptorPasses into one large adaptor. This runs
102 /// recursively through the pipeline graph.
103 void coalesceAdjacentAdaptorPasses();
104
105 /// Split all of AdaptorPasses such that each adaptor only contains one leaf
106 /// pass.
107 void splitAdaptorPasses();
108
109 /// Return the operation name of this pass manager as an identifier.
getOpNamemlir::detail::OpPassManagerImpl110 Identifier getOpName(MLIRContext &context) {
111 if (!identifier)
112 identifier = Identifier::get(name, &context);
113 return *identifier;
114 }
115
116 /// The name of the operation that passes of this pass manager operate on.
117 std::string name;
118
119 /// The cached identifier (internalized in the context) for the name of the
120 /// operation that passes of this pass manager operate on.
121 Optional<Identifier> identifier;
122
123 /// The set of passes to run as part of this pass manager.
124 std::vector<std::unique_ptr<Pass>> passes;
125
126 /// The current initialization generation of this pass manager. This is used
127 /// to indicate when a pass manager should be reinitialized.
128 unsigned initializationGeneration;
129
130 /// Control the implicit nesting of passes that mismatch the name set for this
131 /// OpPassManager.
132 OpPassManager::Nesting nesting;
133 };
134 } // end namespace detail
135 } // end namespace mlir
136
mergeInto(OpPassManagerImpl & rhs)137 void OpPassManagerImpl::mergeInto(OpPassManagerImpl &rhs) {
138 assert(name == rhs.name && "merging unrelated pass managers");
139 for (auto &pass : passes)
140 rhs.passes.push_back(std::move(pass));
141 passes.clear();
142 }
143
nest(Identifier nestedName)144 OpPassManager &OpPassManagerImpl::nest(Identifier nestedName) {
145 OpPassManager nested(nestedName, nesting);
146 auto *adaptor = new OpToOpPassAdaptor(std::move(nested));
147 addPass(std::unique_ptr<Pass>(adaptor));
148 return adaptor->getPassManagers().front();
149 }
150
nest(StringRef nestedName)151 OpPassManager &OpPassManagerImpl::nest(StringRef nestedName) {
152 OpPassManager nested(nestedName, nesting);
153 auto *adaptor = new OpToOpPassAdaptor(std::move(nested));
154 addPass(std::unique_ptr<Pass>(adaptor));
155 return adaptor->getPassManagers().front();
156 }
157
addPass(std::unique_ptr<Pass> pass)158 void OpPassManagerImpl::addPass(std::unique_ptr<Pass> pass) {
159 // If this pass runs on a different operation than this pass manager, then
160 // implicitly nest a pass manager for this operation if enabled.
161 auto passOpName = pass->getOpName();
162 if (passOpName && passOpName->str() != name) {
163 if (nesting == OpPassManager::Nesting::Implicit)
164 return nest(*passOpName).addPass(std::move(pass));
165 llvm::report_fatal_error(llvm::Twine("Can't add pass '") + pass->getName() +
166 "' restricted to '" + *passOpName +
167 "' on a PassManager intended to run on '" + name +
168 "', did you intend to nest?");
169 }
170
171 passes.emplace_back(std::move(pass));
172 }
173
coalesceAdjacentAdaptorPasses()174 void OpPassManagerImpl::coalesceAdjacentAdaptorPasses() {
175 // Bail out early if there are no adaptor passes.
176 if (llvm::none_of(passes, [](std::unique_ptr<Pass> &pass) {
177 return isa<OpToOpPassAdaptor>(pass.get());
178 }))
179 return;
180
181 // Walk the pass list and merge adjacent adaptors.
182 OpToOpPassAdaptor *lastAdaptor = nullptr;
183 for (auto it = passes.begin(), e = passes.end(); it != e; ++it) {
184 // Check to see if this pass is an adaptor.
185 if (auto *currentAdaptor = dyn_cast<OpToOpPassAdaptor>(it->get())) {
186 // If it is the first adaptor in a possible chain, remember it and
187 // continue.
188 if (!lastAdaptor) {
189 lastAdaptor = currentAdaptor;
190 continue;
191 }
192
193 // Otherwise, merge into the existing adaptor and delete the current one.
194 currentAdaptor->mergeInto(*lastAdaptor);
195 it->reset();
196 } else if (lastAdaptor) {
197 // If this pass is not an adaptor, then coalesce and forget any existing
198 // adaptor.
199 for (auto &pm : lastAdaptor->getPassManagers())
200 pm.getImpl().coalesceAdjacentAdaptorPasses();
201 lastAdaptor = nullptr;
202 }
203 }
204
205 // If there was an adaptor at the end of the manager, coalesce it as well.
206 if (lastAdaptor) {
207 for (auto &pm : lastAdaptor->getPassManagers())
208 pm.getImpl().coalesceAdjacentAdaptorPasses();
209 }
210
211 // Now that the adaptors have been merged, erase the empty slot corresponding
212 // to the merged adaptors that were nulled-out in the loop above.
213 llvm::erase_if(passes, std::logical_not<std::unique_ptr<Pass>>());
214 }
215
splitAdaptorPasses()216 void OpPassManagerImpl::splitAdaptorPasses() {
217 std::vector<std::unique_ptr<Pass>> oldPasses;
218 std::swap(passes, oldPasses);
219
220 for (std::unique_ptr<Pass> &pass : oldPasses) {
221 // If this pass isn't an adaptor, move it directly to the new pass list.
222 auto *currentAdaptor = dyn_cast<OpToOpPassAdaptor>(pass.get());
223 if (!currentAdaptor) {
224 addPass(std::move(pass));
225 continue;
226 }
227
228 // Otherwise, split the adaptors of each manager within the adaptor.
229 for (OpPassManager &adaptorPM : currentAdaptor->getPassManagers()) {
230 adaptorPM.getImpl().splitAdaptorPasses();
231 for (std::unique_ptr<Pass> &nestedPass : adaptorPM.getImpl().passes)
232 nest(adaptorPM.getOpName()).addPass(std::move(nestedPass));
233 }
234 }
235 }
236
237 //===----------------------------------------------------------------------===//
238 // OpPassManager
239 //===----------------------------------------------------------------------===//
240
/// Construct an empty pass manager for operations named `name`.
OpPassManager::OpPassManager(Identifier name, Nesting nesting)
    : impl(std::make_unique<OpPassManagerImpl>(name, nesting)) {}
OpPassManager::OpPassManager(StringRef name, Nesting nesting)
    : impl(std::make_unique<OpPassManagerImpl>(name, nesting)) {}

/// Take over the implementation of `rhs`, leaving it without one.
OpPassManager::OpPassManager(OpPassManager &&rhs)
    : impl(std::exchange(rhs.impl, nullptr)) {}

/// Deep-copy `rhs` by delegating to copy assignment.
OpPassManager::OpPassManager(const OpPassManager &rhs) { operator=(rhs); }
/// Copy-assign from `rhs`. This replaces the state of this pass manager with a
/// deep copy of `rhs`: the same operation name, nesting mode and
/// initialization generation, with every pass cloned.
OpPassManager &OpPassManager::operator=(const OpPassManager &rhs) {
  // Self-assignment must be a no-op. Otherwise `impl.reset` below destroys
  // `rhs.impl` before its passes are cloned, which silently empties the
  // pipeline.
  if (this == &rhs)
    return *this;
  impl.reset(new OpPassManagerImpl(rhs.impl->name, rhs.impl->nesting));
  impl->initializationGeneration = rhs.impl->initializationGeneration;
  for (auto &pass : rhs.impl->passes)
    impl->passes.emplace_back(pass->clone());
  return *this;
}
254
~OpPassManager()255 OpPassManager::~OpPassManager() {}
256
begin()257 OpPassManager::pass_iterator OpPassManager::begin() {
258 return MutableArrayRef<std::unique_ptr<Pass>>{impl->passes}.begin();
259 }
end()260 OpPassManager::pass_iterator OpPassManager::end() {
261 return MutableArrayRef<std::unique_ptr<Pass>>{impl->passes}.end();
262 }
263
begin() const264 OpPassManager::const_pass_iterator OpPassManager::begin() const {
265 return ArrayRef<std::unique_ptr<Pass>>{impl->passes}.begin();
266 }
end() const267 OpPassManager::const_pass_iterator OpPassManager::end() const {
268 return ArrayRef<std::unique_ptr<Pass>>{impl->passes}.end();
269 }
270
271 /// Nest a new operation pass manager for the given operation kind under this
272 /// pass manager.
nest(Identifier nestedName)273 OpPassManager &OpPassManager::nest(Identifier nestedName) {
274 return impl->nest(nestedName);
275 }
nest(StringRef nestedName)276 OpPassManager &OpPassManager::nest(StringRef nestedName) {
277 return impl->nest(nestedName);
278 }
279
280 /// Add the given pass to this pass manager. If this pass has a concrete
281 /// operation type, it must be the same type as this pass manager.
addPass(std::unique_ptr<Pass> pass)282 void OpPassManager::addPass(std::unique_ptr<Pass> pass) {
283 impl->addPass(std::move(pass));
284 }
285
286 /// Returns the number of passes held by this manager.
size() const287 size_t OpPassManager::size() const { return impl->passes.size(); }
288
289 /// Returns the internal implementation instance.
getImpl()290 OpPassManagerImpl &OpPassManager::getImpl() { return *impl; }
291
292 /// Return the operation name that this pass manager operates on.
getOpName() const293 StringRef OpPassManager::getOpName() const { return impl->name; }
294
295 /// Return the operation name that this pass manager operates on.
getOpName(MLIRContext & context) const296 Identifier OpPassManager::getOpName(MLIRContext &context) const {
297 return impl->getOpName(context);
298 }
299
300 /// Prints out the given passes as the textual representation of a pipeline.
printAsTextualPipeline(ArrayRef<std::unique_ptr<Pass>> passes,raw_ostream & os)301 static void printAsTextualPipeline(ArrayRef<std::unique_ptr<Pass>> passes,
302 raw_ostream &os) {
303 llvm::interleaveComma(passes, os, [&](const std::unique_ptr<Pass> &pass) {
304 pass->printAsTextualPipeline(os);
305 });
306 }
307
308 /// Prints out the passes of the pass manager as the textual representation
309 /// of pipelines.
printAsTextualPipeline(raw_ostream & os)310 void OpPassManager::printAsTextualPipeline(raw_ostream &os) {
311 ::printAsTextualPipeline(impl->passes, os);
312 }
313
dump()314 void OpPassManager::dump() {
315 llvm::errs() << "Pass Manager with " << impl->passes.size() << " passes: ";
316 ::printAsTextualPipeline(impl->passes, llvm::errs());
317 llvm::errs() << "\n";
318 }
319
registerDialectsForPipeline(const OpPassManager & pm,DialectRegistry & dialects)320 static void registerDialectsForPipeline(const OpPassManager &pm,
321 DialectRegistry &dialects) {
322 for (const Pass &pass : pm.getPasses())
323 pass.getDependentDialects(dialects);
324 }
325
getDependentDialects(DialectRegistry & dialects) const326 void OpPassManager::getDependentDialects(DialectRegistry &dialects) const {
327 registerDialectsForPipeline(*this, dialects);
328 }
329
setNesting(Nesting nesting)330 void OpPassManager::setNesting(Nesting nesting) { impl->nesting = nesting; }
331
getNesting()332 OpPassManager::Nesting OpPassManager::getNesting() { return impl->nesting; }
333
initialize(MLIRContext * context,unsigned newInitGeneration)334 void OpPassManager::initialize(MLIRContext *context,
335 unsigned newInitGeneration) {
336 if (impl->initializationGeneration == newInitGeneration)
337 return;
338 impl->initializationGeneration = newInitGeneration;
339 for (Pass &pass : getPasses()) {
340 // If this pass isn't an adaptor, directly initialize it.
341 auto *adaptor = dyn_cast<OpToOpPassAdaptor>(&pass);
342 if (!adaptor) {
343 pass.initialize(context);
344 continue;
345 }
346
347 // Otherwise, initialize each of the adaptors pass managers.
348 for (OpPassManager &adaptorPM : adaptor->getPassManagers())
349 adaptorPM.initialize(context, newInitGeneration);
350 }
351 }
352
353 //===----------------------------------------------------------------------===//
354 // OpToOpPassAdaptor
355 //===----------------------------------------------------------------------===//
356
run(Pass * pass,Operation * op,AnalysisManager am,bool verifyPasses,unsigned parentInitGeneration)357 LogicalResult OpToOpPassAdaptor::run(Pass *pass, Operation *op,
358 AnalysisManager am, bool verifyPasses,
359 unsigned parentInitGeneration) {
360 if (!op->getName().getAbstractOperation())
361 return op->emitOpError()
362 << "trying to schedule a pass on an unregistered operation";
363 if (!op->getName().getAbstractOperation()->hasProperty(
364 OperationProperty::IsolatedFromAbove))
365 return op->emitOpError() << "trying to schedule a pass on an operation not "
366 "marked as 'IsolatedFromAbove'";
367
368 // Initialize the pass state with a callback for the pass to dynamically
369 // execute a pipeline on the currently visited operation.
370 PassInstrumentor *pi = am.getPassInstrumentor();
371 PassInstrumentation::PipelineParentInfo parentInfo = {llvm::get_threadid(),
372 pass};
373 auto dynamic_pipeline_callback = [&](OpPassManager &pipeline,
374 Operation *root) -> LogicalResult {
375 if (!op->isAncestor(root))
376 return root->emitOpError()
377 << "Trying to schedule a dynamic pipeline on an "
378 "operation that isn't "
379 "nested under the current operation the pass is processing";
380 assert(pipeline.getOpName() == root->getName().getStringRef());
381
382 // Initialize the user provided pipeline and execute the pipeline.
383 pipeline.initialize(root->getContext(), parentInitGeneration);
384 AnalysisManager nestedAm = root == op ? am : am.nest(root);
385 return OpToOpPassAdaptor::runPipeline(pipeline.getPasses(), root, nestedAm,
386 verifyPasses, parentInitGeneration,
387 pi, &parentInfo);
388 };
389 pass->passState.emplace(op, am, dynamic_pipeline_callback);
390
391 // Instrument before the pass has run.
392 if (pi)
393 pi->runBeforePass(pass, op);
394
395 // Invoke the virtual runOnOperation method.
396 if (auto *adaptor = dyn_cast<OpToOpPassAdaptor>(pass))
397 adaptor->runOnOperation(verifyPasses);
398 else
399 pass->runOnOperation();
400 bool passFailed = pass->passState->irAndPassFailed.getInt();
401
402 // Invalidate any non preserved analyses.
403 am.invalidate(pass->passState->preservedAnalyses);
404
405 // Run the verifier if this pass didn't fail already.
406 if (!passFailed && verifyPasses)
407 passFailed = failed(verify(op));
408
409 // Instrument after the pass has run.
410 if (pi) {
411 if (passFailed)
412 pi->runAfterPassFailed(pass, op);
413 else
414 pi->runAfterPass(pass, op);
415 }
416
417 // Return if the pass signaled a failure.
418 return failure(passFailed);
419 }
420
421 /// Run the given operation and analysis manager on a provided op pass manager.
runPipeline(iterator_range<OpPassManager::pass_iterator> passes,Operation * op,AnalysisManager am,bool verifyPasses,unsigned parentInitGeneration,PassInstrumentor * instrumentor,const PassInstrumentation::PipelineParentInfo * parentInfo)422 LogicalResult OpToOpPassAdaptor::runPipeline(
423 iterator_range<OpPassManager::pass_iterator> passes, Operation *op,
424 AnalysisManager am, bool verifyPasses, unsigned parentInitGeneration,
425 PassInstrumentor *instrumentor,
426 const PassInstrumentation::PipelineParentInfo *parentInfo) {
427 assert((!instrumentor || parentInfo) &&
428 "expected parent info if instrumentor is provided");
429 auto scope_exit = llvm::make_scope_exit([&] {
430 // Clear out any computed operation analyses. These analyses won't be used
431 // any more in this pipeline, and this helps reduce the current working set
432 // of memory. If preserving these analyses becomes important in the future
433 // we can re-evaluate this.
434 am.clear();
435 });
436
437 // Run the pipeline over the provided operation.
438 if (instrumentor)
439 instrumentor->runBeforePipeline(op->getName().getIdentifier(), *parentInfo);
440 for (Pass &pass : passes)
441 if (failed(run(&pass, op, am, verifyPasses, parentInitGeneration)))
442 return failure();
443 if (instrumentor)
444 instrumentor->runAfterPipeline(op->getName().getIdentifier(), *parentInfo);
445 return success();
446 }
447
448 /// Find an operation pass manager that can operate on an operation of the given
449 /// type, or nullptr if one does not exist.
findPassManagerFor(MutableArrayRef<OpPassManager> mgrs,StringRef name)450 static OpPassManager *findPassManagerFor(MutableArrayRef<OpPassManager> mgrs,
451 StringRef name) {
452 auto it = llvm::find_if(
453 mgrs, [&](OpPassManager &mgr) { return mgr.getOpName() == name; });
454 return it == mgrs.end() ? nullptr : &*it;
455 }
456
457 /// Find an operation pass manager that can operate on an operation of the given
458 /// type, or nullptr if one does not exist.
findPassManagerFor(MutableArrayRef<OpPassManager> mgrs,Identifier name,MLIRContext & context)459 static OpPassManager *findPassManagerFor(MutableArrayRef<OpPassManager> mgrs,
460 Identifier name,
461 MLIRContext &context) {
462 auto it = llvm::find_if(
463 mgrs, [&](OpPassManager &mgr) { return mgr.getOpName(context) == name; });
464 return it == mgrs.end() ? nullptr : &*it;
465 }
466
OpToOpPassAdaptor(OpPassManager && mgr)467 OpToOpPassAdaptor::OpToOpPassAdaptor(OpPassManager &&mgr) {
468 mgrs.emplace_back(std::move(mgr));
469 }
470
getDependentDialects(DialectRegistry & dialects) const471 void OpToOpPassAdaptor::getDependentDialects(DialectRegistry &dialects) const {
472 for (auto &pm : mgrs)
473 pm.getDependentDialects(dialects);
474 }
475
476 /// Merge the current pass adaptor into given 'rhs'.
mergeInto(OpToOpPassAdaptor & rhs)477 void OpToOpPassAdaptor::mergeInto(OpToOpPassAdaptor &rhs) {
478 for (auto &pm : mgrs) {
479 // If an existing pass manager exists, then merge the given pass manager
480 // into it.
481 if (auto *existingPM = findPassManagerFor(rhs.mgrs, pm.getOpName())) {
482 pm.getImpl().mergeInto(existingPM->getImpl());
483 } else {
484 // Otherwise, add the given pass manager to the list.
485 rhs.mgrs.emplace_back(std::move(pm));
486 }
487 }
488 mgrs.clear();
489
490 // After coalescing, sort the pass managers within rhs by name.
491 llvm::array_pod_sort(rhs.mgrs.begin(), rhs.mgrs.end(),
492 [](const OpPassManager *lhs, const OpPassManager *rhs) {
493 return lhs->getOpName().compare(rhs->getOpName());
494 });
495 }
496
497 /// Returns the adaptor pass name.
getAdaptorName()498 std::string OpToOpPassAdaptor::getAdaptorName() {
499 std::string name = "Pipeline Collection : [";
500 llvm::raw_string_ostream os(name);
501 llvm::interleaveComma(getPassManagers(), os, [&](OpPassManager &pm) {
502 os << '\'' << pm.getOpName() << '\'';
503 });
504 os << ']';
505 return os.str();
506 }
507
runOnOperation()508 void OpToOpPassAdaptor::runOnOperation() {
509 llvm_unreachable(
510 "Unexpected call to Pass::runOnOperation() on OpToOpPassAdaptor");
511 }
512
513 /// Run the held pipeline over all nested operations.
runOnOperation(bool verifyPasses)514 void OpToOpPassAdaptor::runOnOperation(bool verifyPasses) {
515 if (getContext().isMultithreadingEnabled())
516 runOnOperationAsyncImpl(verifyPasses);
517 else
518 runOnOperationImpl(verifyPasses);
519 }
520
521 /// Run this pass adaptor synchronously.
runOnOperationImpl(bool verifyPasses)522 void OpToOpPassAdaptor::runOnOperationImpl(bool verifyPasses) {
523 auto am = getAnalysisManager();
524 PassInstrumentation::PipelineParentInfo parentInfo = {llvm::get_threadid(),
525 this};
526 auto *instrumentor = am.getPassInstrumentor();
527 for (auto ®ion : getOperation()->getRegions()) {
528 for (auto &block : region) {
529 for (auto &op : block) {
530 auto *mgr = findPassManagerFor(mgrs, op.getName().getIdentifier(),
531 *op.getContext());
532 if (!mgr)
533 continue;
534
535 // Run the held pipeline over the current operation.
536 unsigned initGeneration = mgr->impl->initializationGeneration;
537 if (failed(runPipeline(mgr->getPasses(), &op, am.nest(&op),
538 verifyPasses, initGeneration, instrumentor,
539 &parentInfo)))
540 return signalPassFailure();
541 }
542 }
543 }
544 }
545
546 /// Utility functor that checks if the two ranges of pass managers have a size
547 /// mismatch.
hasSizeMismatch(ArrayRef<OpPassManager> lhs,ArrayRef<OpPassManager> rhs)548 static bool hasSizeMismatch(ArrayRef<OpPassManager> lhs,
549 ArrayRef<OpPassManager> rhs) {
550 return lhs.size() != rhs.size() ||
551 llvm::any_of(llvm::seq<size_t>(0, lhs.size()),
552 [&](size_t i) { return lhs[i].size() != rhs[i].size(); });
553 }
554
555 /// Run this pass adaptor synchronously.
runOnOperationAsyncImpl(bool verifyPasses)556 void OpToOpPassAdaptor::runOnOperationAsyncImpl(bool verifyPasses) {
557 AnalysisManager am = getAnalysisManager();
558
559 // Create the async executors if they haven't been created, or if the main
560 // pipeline has changed.
561 if (asyncExecutors.empty() || hasSizeMismatch(asyncExecutors.front(), mgrs))
562 asyncExecutors.assign(llvm::hardware_concurrency().compute_thread_count(),
563 mgrs);
564
565 // Run a prepass over the operation to collect the nested operations to
566 // execute over. This ensures that an analysis manager exists for each
567 // operation, as well as providing a queue of operations to execute over.
568 std::vector<std::pair<Operation *, AnalysisManager>> opAMPairs;
569 for (auto ®ion : getOperation()->getRegions()) {
570 for (auto &block : region) {
571 for (auto &op : block) {
572 // Add this operation iff the name matches any of the pass managers.
573 if (findPassManagerFor(mgrs, op.getName().getIdentifier(),
574 getContext()))
575 opAMPairs.emplace_back(&op, am.nest(&op));
576 }
577 }
578 }
579
580 // A parallel diagnostic handler that provides deterministic diagnostic
581 // ordering.
582 ParallelDiagnosticHandler diagHandler(&getContext());
583
584 // An index for the current operation/analysis manager pair.
585 std::atomic<unsigned> opIt(0);
586
587 // Get the current thread for this adaptor.
588 PassInstrumentation::PipelineParentInfo parentInfo = {llvm::get_threadid(),
589 this};
590 auto *instrumentor = am.getPassInstrumentor();
591
592 // An atomic failure variable for the async executors.
593 std::atomic<bool> passFailed(false);
594 llvm::parallelForEach(
595 asyncExecutors.begin(),
596 std::next(asyncExecutors.begin(),
597 std::min(asyncExecutors.size(), opAMPairs.size())),
598 [&](MutableArrayRef<OpPassManager> pms) {
599 for (auto e = opAMPairs.size(); !passFailed && opIt < e;) {
600 // Get the next available operation index.
601 unsigned nextID = opIt++;
602 if (nextID >= e)
603 break;
604
605 // Set the order id for this thread in the diagnostic handler.
606 diagHandler.setOrderIDForThread(nextID);
607
608 // Get the pass manager for this operation and execute it.
609 auto &it = opAMPairs[nextID];
610 auto *pm = findPassManagerFor(
611 pms, it.first->getName().getIdentifier(), getContext());
612 assert(pm && "expected valid pass manager for operation");
613
614 unsigned initGeneration = pm->impl->initializationGeneration;
615 LogicalResult pipelineResult =
616 runPipeline(pm->getPasses(), it.first, it.second, verifyPasses,
617 initGeneration, instrumentor, &parentInfo);
618
619 // Drop this thread from being tracked by the diagnostic handler.
620 // After this task has finished, the thread may be used outside of
621 // this pass manager context meaning that we don't want to track
622 // diagnostics from it anymore.
623 diagHandler.eraseOrderIDForThread();
624
625 // Handle a failed pipeline result.
626 if (failed(pipelineResult)) {
627 passFailed = true;
628 break;
629 }
630 }
631 });
632
633 // Signal a failure if any of the executors failed.
634 if (passFailed)
635 signalPassFailure();
636 }
637
638 //===----------------------------------------------------------------------===//
639 // PassCrashReproducer
640 //===----------------------------------------------------------------------===//
641
642 namespace {
643 /// This class contains all of the context for generating a recovery reproducer.
644 /// Each recovery context is registered globally to allow for generating
645 /// reproducers when a signal is raised, such as a segfault.
646 struct RecoveryReproducerContext {
647 RecoveryReproducerContext(MutableArrayRef<std::unique_ptr<Pass>> passes,
648 Operation *op,
649 PassManager::ReproducerStreamFactory &crashStream,
650 bool disableThreads, bool verifyPasses);
651 ~RecoveryReproducerContext();
652
653 /// Generate a reproducer with the current context.
654 LogicalResult generate(std::string &error);
655
656 private:
657 /// This function is invoked in the event of a crash.
658 static void crashHandler(void *);
659
660 /// Register a signal handler to run in the event of a crash.
661 static void registerSignalHandler();
662
663 /// The textual description of the currently executing pipeline.
664 std::string pipeline;
665
666 /// The MLIR operation representing the IR before the crash.
667 Operation *preCrashOperation;
668
669 /// The factory for the reproducer output stream to use when generating the
670 /// reproducer.
671 PassManager::ReproducerStreamFactory &crashStreamFactory;
672
673 /// Various pass manager and context flags.
674 bool disableThreads;
675 bool verifyPasses;
676
677 /// The current set of active reproducer contexts. This is used in the event
678 /// of a crash. This is not thread_local as the pass manager may produce any
679 /// number of child threads. This uses a set to allow for multiple MLIR pass
680 /// managers to be running at the same time.
681 static llvm::ManagedStatic<llvm::sys::SmartMutex<true>> reproducerMutex;
682 static llvm::ManagedStatic<
683 llvm::SmallSetVector<RecoveryReproducerContext *, 1>>
684 reproducerSet;
685 };
686
687 /// Instance of ReproducerStream backed by file.
688 struct FileReproducerStream : public PassManager::ReproducerStream {
FileReproducerStream__anon4c2169be0c11::FileReproducerStream689 FileReproducerStream(std::unique_ptr<llvm::ToolOutputFile> outputFile)
690 : outputFile(std::move(outputFile)) {}
691 ~FileReproducerStream() override;
692
693 /// Description of the reproducer stream.
694 StringRef description() override;
695
696 /// Stream on which to output reprooducer.
697 raw_ostream &os() override;
698
699 private:
700 /// ToolOutputFile corresponding to opened `filename`.
701 std::unique_ptr<llvm::ToolOutputFile> outputFile = nullptr;
702 };
703
704 } // end anonymous namespace
705
// Out-of-line definitions of the static members of RecoveryReproducerContext:
// the mutex taken whenever `reproducerSet` is modified, and the set of
// currently active reproducer contexts.
llvm::ManagedStatic<llvm::sys::SmartMutex<true>>
    RecoveryReproducerContext::reproducerMutex;
llvm::ManagedStatic<llvm::SmallSetVector<RecoveryReproducerContext *, 1>>
    RecoveryReproducerContext::reproducerSet;
710
RecoveryReproducerContext(MutableArrayRef<std::unique_ptr<Pass>> passes,Operation * op,PassManager::ReproducerStreamFactory & crashStreamFactory,bool disableThreads,bool verifyPasses)711 RecoveryReproducerContext::RecoveryReproducerContext(
712 MutableArrayRef<std::unique_ptr<Pass>> passes, Operation *op,
713 PassManager::ReproducerStreamFactory &crashStreamFactory,
714 bool disableThreads, bool verifyPasses)
715 : preCrashOperation(op->clone()), crashStreamFactory(crashStreamFactory),
716 disableThreads(disableThreads), verifyPasses(verifyPasses) {
717 // Grab the textual pipeline being executed..
718 {
719 llvm::raw_string_ostream pipelineOS(pipeline);
720 ::printAsTextualPipeline(passes, pipelineOS);
721 }
722
723 // Make sure that the handler is registered, and update the current context.
724 llvm::sys::SmartScopedLock<true> producerLock(*reproducerMutex);
725 if (reproducerSet->empty())
726 llvm::CrashRecoveryContext::Enable();
727 registerSignalHandler();
728 reproducerSet->insert(this);
729 }
730
~RecoveryReproducerContext()731 RecoveryReproducerContext::~RecoveryReproducerContext() {
732 // Erase the cloned preCrash IR that we cached.
733 preCrashOperation->erase();
734
735 llvm::sys::SmartScopedLock<true> producerLock(*reproducerMutex);
736 reproducerSet->remove(this);
737 if (reproducerSet->empty())
738 llvm::CrashRecoveryContext::Disable();
739 }
740
741 /// Description of the reproducer stream.
description()742 StringRef FileReproducerStream::description() {
743 return outputFile->getFilename();
744 }
745
746 /// Stream on which to output reproducer.
os()747 raw_ostream &FileReproducerStream::os() { return outputFile->os(); }
748
~FileReproducerStream()749 FileReproducerStream::~FileReproducerStream() { outputFile->keep(); }
750
generate(std::string & error)751 LogicalResult RecoveryReproducerContext::generate(std::string &error) {
752 std::unique_ptr<PassManager::ReproducerStream> crashStream =
753 crashStreamFactory(error);
754 if (!crashStream)
755 return failure();
756
757 // Output the current pass manager configuration.
758 auto &os = crashStream->os();
759 os << "// configuration: -pass-pipeline='" << pipeline << "'";
760 if (disableThreads)
761 os << " -mlir-disable-threading";
762 if (verifyPasses)
763 os << " -verify-each";
764 os << '\n';
765
766 // Output the .mlir module.
767 preCrashOperation->print(os);
768
769 bool shouldPrintOnOp =
770 preCrashOperation->getContext()->shouldPrintOpOnDiagnostic();
771 preCrashOperation->getContext()->printOpOnDiagnostic(false);
772 preCrashOperation->emitError()
773 << "A failure has been detected while processing the MLIR module, a "
774 "reproducer has been generated in '"
775 << crashStream->description() << "'";
776 preCrashOperation->getContext()->printOpOnDiagnostic(shouldPrintOnOp);
777 return success();
778 }
779
crashHandler(void *)780 void RecoveryReproducerContext::crashHandler(void *) {
781 // Walk the current stack of contexts and generate a reproducer for each one.
782 // We can't know for certain which one was the cause, so we need to generate
783 // a reproducer for all of them.
784 std::string ignored;
785 for (RecoveryReproducerContext *context : *reproducerSet)
786 context->generate(ignored);
787 }
788
registerSignalHandler()789 void RecoveryReproducerContext::registerSignalHandler() {
790 // Ensure that the handler is only registered once.
791 static bool registered =
792 (llvm::sys::AddSignalHandler(crashHandler, nullptr), false);
793 (void)registered;
794 }
795
796 /// Run the pass manager with crash recover enabled.
runWithCrashRecovery(Operation * op,AnalysisManager am)797 LogicalResult PassManager::runWithCrashRecovery(Operation *op,
798 AnalysisManager am) {
799 // If this isn't a local producer, run all of the passes in recovery mode.
800 if (!localReproducer)
801 return runWithCrashRecovery(impl->passes, op, am);
802
803 // Split the passes within adaptors to ensure that each pass can be run in
804 // isolation.
805 impl->splitAdaptorPasses();
806
807 // If this is a local producer, run each of the passes individually.
808 MutableArrayRef<std::unique_ptr<Pass>> passes = impl->passes;
809 for (std::unique_ptr<Pass> &pass : passes)
810 if (failed(runWithCrashRecovery(pass, op, am)))
811 return failure();
812 return success();
813 }
814
815 /// Run the given passes with crash recover enabled.
816 LogicalResult
runWithCrashRecovery(MutableArrayRef<std::unique_ptr<Pass>> passes,Operation * op,AnalysisManager am)817 PassManager::runWithCrashRecovery(MutableArrayRef<std::unique_ptr<Pass>> passes,
818 Operation *op, AnalysisManager am) {
819 RecoveryReproducerContext context(passes, op, crashReproducerStreamFactory,
820 !getContext()->isMultithreadingEnabled(),
821 verifyPasses);
822
823 // Safely invoke the passes within a recovery context.
824 LogicalResult passManagerResult = failure();
825 llvm::CrashRecoveryContext recoveryContext;
826 recoveryContext.RunSafelyOnThread([&] {
827 for (std::unique_ptr<Pass> &pass : passes)
828 if (failed(OpToOpPassAdaptor::run(pass.get(), op, am, verifyPasses,
829 impl->initializationGeneration)))
830 return;
831 passManagerResult = success();
832 });
833 if (succeeded(passManagerResult))
834 return success();
835
836 std::string error;
837 if (failed(context.generate(error)))
838 return op->emitError("<MLIR-PassManager-Crash-Reproducer>: ") << error;
839 return failure();
840 }
841
842 //===----------------------------------------------------------------------===//
843 // PassManager
844 //===----------------------------------------------------------------------===//
845
PassManager(MLIRContext * ctx,Nesting nesting,StringRef operationName)846 PassManager::PassManager(MLIRContext *ctx, Nesting nesting,
847 StringRef operationName)
848 : OpPassManager(Identifier::get(operationName, ctx), nesting), context(ctx),
849 passTiming(false), localReproducer(false), verifyPasses(true) {}
850
~PassManager()851 PassManager::~PassManager() {}
852
enableVerifier(bool enabled)853 void PassManager::enableVerifier(bool enabled) { verifyPasses = enabled; }
854
855 /// Run the passes within this manager on the provided operation.
run(Operation * op)856 LogicalResult PassManager::run(Operation *op) {
857 MLIRContext *context = getContext();
858 assert(op->getName().getIdentifier() == getOpName(*context) &&
859 "operation has a different name than the PassManager");
860
861 // Before running, make sure to coalesce any adjacent pass adaptors in the
862 // pipeline.
863 getImpl().coalesceAdjacentAdaptorPasses();
864
865 // Register all dialects for the current pipeline.
866 DialectRegistry dependentDialects;
867 getDependentDialects(dependentDialects);
868 dependentDialects.loadAll(context);
869
870 // Initialize all of the passes within the pass manager with a new generation.
871 initialize(context, impl->initializationGeneration + 1);
872
873 // Construct a top level analysis manager for the pipeline.
874 ModuleAnalysisManager am(op, instrumentor.get());
875
876 // Notify the context that we start running a pipeline for book keeping.
877 context->enterMultiThreadedExecution();
878
879 // If reproducer generation is enabled, run the pass manager with crash
880 // handling enabled.
881 LogicalResult result =
882 crashReproducerStreamFactory
883 ? runWithCrashRecovery(op, am)
884 : OpToOpPassAdaptor::runPipeline(getPasses(), op, am, verifyPasses,
885 impl->initializationGeneration);
886
887 // Notify the context that the run is done.
888 context->exitMultiThreadedExecution();
889
890 // Dump all of the pass statistics if necessary.
891 if (passStatisticsMode)
892 dumpStatistics();
893 return result;
894 }
895
896 /// Enable support for the pass manager to generate a reproducer on the event
897 /// of a crash or a pass failure. `outputFile` is a .mlir filename used to write
898 /// the generated reproducer. If `genLocalReproducer` is true, the pass manager
899 /// will attempt to generate a local reproducer that contains the smallest
900 /// pipeline.
enableCrashReproducerGeneration(StringRef outputFile,bool genLocalReproducer)901 void PassManager::enableCrashReproducerGeneration(StringRef outputFile,
902 bool genLocalReproducer) {
903 // Capture the filename by value in case outputFile is out of scope when
904 // invoked.
905 std::string filename = outputFile.str();
906 enableCrashReproducerGeneration(
907 [filename](std::string &error) -> std::unique_ptr<ReproducerStream> {
908 std::unique_ptr<llvm::ToolOutputFile> outputFile =
909 mlir::openOutputFile(filename, &error);
910 if (!outputFile) {
911 error = "Failed to create reproducer stream: " + error;
912 return nullptr;
913 }
914 return std::make_unique<FileReproducerStream>(std::move(outputFile));
915 },
916 genLocalReproducer);
917 }
918
919 /// Enable support for the pass manager to generate a reproducer on the event
920 /// of a crash or a pass failure. `factory` is used to construct the streams
921 /// to write the generated reproducer to. If `genLocalReproducer` is true, the
922 /// pass manager will attempt to generate a local reproducer that contains the
923 /// smallest pipeline.
enableCrashReproducerGeneration(ReproducerStreamFactory factory,bool genLocalReproducer)924 void PassManager::enableCrashReproducerGeneration(
925 ReproducerStreamFactory factory, bool genLocalReproducer) {
926 crashReproducerStreamFactory = factory;
927 localReproducer = genLocalReproducer;
928 }
929
930 /// Add the provided instrumentation to the pass manager.
addInstrumentation(std::unique_ptr<PassInstrumentation> pi)931 void PassManager::addInstrumentation(std::unique_ptr<PassInstrumentation> pi) {
932 if (!instrumentor)
933 instrumentor = std::make_unique<PassInstrumentor>();
934
935 instrumentor->addInstrumentation(std::move(pi));
936 }
937
938 //===----------------------------------------------------------------------===//
939 // AnalysisManager
940 //===----------------------------------------------------------------------===//
941
942 /// Get an analysis manager for the given operation, which must be a proper
943 /// descendant of the current operation represented by this analysis manager.
nest(Operation * op)944 AnalysisManager AnalysisManager::nest(Operation *op) {
945 Operation *currentOp = impl->getOperation();
946 assert(currentOp->isProperAncestor(op) &&
947 "expected valid descendant operation");
948
949 // Check for the base case where the provided operation is immediately nested.
950 if (currentOp == op->getParentOp())
951 return nestImmediate(op);
952
953 // Otherwise, we need to collect all ancestors up to the current operation.
954 SmallVector<Operation *, 4> opAncestors;
955 do {
956 opAncestors.push_back(op);
957 op = op->getParentOp();
958 } while (op != currentOp);
959
960 AnalysisManager result = *this;
961 for (Operation *op : llvm::reverse(opAncestors))
962 result = result.nestImmediate(op);
963 return result;
964 }
965
966 /// Get an analysis manager for the given immediately nested child operation.
nestImmediate(Operation * op)967 AnalysisManager AnalysisManager::nestImmediate(Operation *op) {
968 assert(impl->getOperation() == op->getParentOp() &&
969 "expected immediate child operation");
970
971 auto it = impl->childAnalyses.find(op);
972 if (it == impl->childAnalyses.end())
973 it = impl->childAnalyses
974 .try_emplace(op, std::make_unique<NestedAnalysisMap>(op, impl))
975 .first;
976 return {it->second.get()};
977 }
978
979 /// Invalidate any non preserved analyses.
invalidate(const detail::PreservedAnalyses & pa)980 void detail::NestedAnalysisMap::invalidate(
981 const detail::PreservedAnalyses &pa) {
982 // If all analyses were preserved, then there is nothing to do here.
983 if (pa.isAll())
984 return;
985
986 // Invalidate the analyses for the current operation directly.
987 analyses.invalidate(pa);
988
989 // If no analyses were preserved, then just simply clear out the child
990 // analysis results.
991 if (pa.isNone()) {
992 childAnalyses.clear();
993 return;
994 }
995
996 // Otherwise, invalidate each child analysis map.
997 SmallVector<NestedAnalysisMap *, 8> mapsToInvalidate(1, this);
998 while (!mapsToInvalidate.empty()) {
999 auto *map = mapsToInvalidate.pop_back_val();
1000 for (auto &analysisPair : map->childAnalyses) {
1001 analysisPair.second->invalidate(pa);
1002 if (!analysisPair.second->childAnalyses.empty())
1003 mapsToInvalidate.push_back(analysisPair.second.get());
1004 }
1005 }
1006 }
1007
1008 //===----------------------------------------------------------------------===//
1009 // PassInstrumentation
1010 //===----------------------------------------------------------------------===//
1011
~PassInstrumentation()1012 PassInstrumentation::~PassInstrumentation() {}
1013
1014 //===----------------------------------------------------------------------===//
1015 // PassInstrumentor
1016 //===----------------------------------------------------------------------===//
1017
1018 namespace mlir {
1019 namespace detail {
1020 struct PassInstrumentorImpl {
1021 /// Mutex to keep instrumentation access thread-safe.
1022 llvm::sys::SmartMutex<true> mutex;
1023
1024 /// Set of registered instrumentations.
1025 std::vector<std::unique_ptr<PassInstrumentation>> instrumentations;
1026 };
1027 } // end namespace detail
1028 } // end namespace mlir
1029
PassInstrumentor()1030 PassInstrumentor::PassInstrumentor() : impl(new PassInstrumentorImpl()) {}
~PassInstrumentor()1031 PassInstrumentor::~PassInstrumentor() {}
1032
1033 /// See PassInstrumentation::runBeforePipeline for details.
runBeforePipeline(Identifier name,const PassInstrumentation::PipelineParentInfo & parentInfo)1034 void PassInstrumentor::runBeforePipeline(
1035 Identifier name,
1036 const PassInstrumentation::PipelineParentInfo &parentInfo) {
1037 llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex);
1038 for (auto &instr : impl->instrumentations)
1039 instr->runBeforePipeline(name, parentInfo);
1040 }
1041
1042 /// See PassInstrumentation::runAfterPipeline for details.
runAfterPipeline(Identifier name,const PassInstrumentation::PipelineParentInfo & parentInfo)1043 void PassInstrumentor::runAfterPipeline(
1044 Identifier name,
1045 const PassInstrumentation::PipelineParentInfo &parentInfo) {
1046 llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex);
1047 for (auto &instr : llvm::reverse(impl->instrumentations))
1048 instr->runAfterPipeline(name, parentInfo);
1049 }
1050
1051 /// See PassInstrumentation::runBeforePass for details.
runBeforePass(Pass * pass,Operation * op)1052 void PassInstrumentor::runBeforePass(Pass *pass, Operation *op) {
1053 llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex);
1054 for (auto &instr : impl->instrumentations)
1055 instr->runBeforePass(pass, op);
1056 }
1057
1058 /// See PassInstrumentation::runAfterPass for details.
runAfterPass(Pass * pass,Operation * op)1059 void PassInstrumentor::runAfterPass(Pass *pass, Operation *op) {
1060 llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex);
1061 for (auto &instr : llvm::reverse(impl->instrumentations))
1062 instr->runAfterPass(pass, op);
1063 }
1064
1065 /// See PassInstrumentation::runAfterPassFailed for details.
runAfterPassFailed(Pass * pass,Operation * op)1066 void PassInstrumentor::runAfterPassFailed(Pass *pass, Operation *op) {
1067 llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex);
1068 for (auto &instr : llvm::reverse(impl->instrumentations))
1069 instr->runAfterPassFailed(pass, op);
1070 }
1071
1072 /// See PassInstrumentation::runBeforeAnalysis for details.
runBeforeAnalysis(StringRef name,TypeID id,Operation * op)1073 void PassInstrumentor::runBeforeAnalysis(StringRef name, TypeID id,
1074 Operation *op) {
1075 llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex);
1076 for (auto &instr : impl->instrumentations)
1077 instr->runBeforeAnalysis(name, id, op);
1078 }
1079
1080 /// See PassInstrumentation::runAfterAnalysis for details.
runAfterAnalysis(StringRef name,TypeID id,Operation * op)1081 void PassInstrumentor::runAfterAnalysis(StringRef name, TypeID id,
1082 Operation *op) {
1083 llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex);
1084 for (auto &instr : llvm::reverse(impl->instrumentations))
1085 instr->runAfterAnalysis(name, id, op);
1086 }
1087
1088 /// Add the given instrumentation to the collection.
addInstrumentation(std::unique_ptr<PassInstrumentation> pi)1089 void PassInstrumentor::addInstrumentation(
1090 std::unique_ptr<PassInstrumentation> pi) {
1091 llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex);
1092 impl->instrumentations.emplace_back(std::move(pi));
1093 }
1094