1 //===- PassRegistry.cpp - Pass Registration Utilities ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Pass/PassRegistry.h"
10 #include "mlir/Pass/Pass.h"
11 #include "mlir/Pass/PassManager.h"
12 #include "llvm/ADT/DenseMap.h"
13 #include "llvm/Support/Format.h"
14 #include "llvm/Support/ManagedStatic.h"
15 #include "llvm/Support/MemoryBuffer.h"
16 #include "llvm/Support/SourceMgr.h"
17
18 using namespace mlir;
19 using namespace detail;
20
21 /// Static mapping of all of the registered passes.
22 static llvm::ManagedStatic<llvm::StringMap<PassInfo>> passRegistry;
23
24 /// A mapping of the above pass registry entries to the corresponding TypeID
25 /// of the pass that they generate.
26 static llvm::ManagedStatic<llvm::StringMap<TypeID>> passRegistryTypeIDs;
27
28 /// Static mapping of all of the registered pass pipelines.
29 static llvm::ManagedStatic<llvm::StringMap<PassPipelineInfo>>
30 passPipelineRegistry;
31
32 /// Utility to create a default registry function from a pass instance.
33 static PassRegistryFunction
buildDefaultRegistryFn(const PassAllocatorFunction & allocator)34 buildDefaultRegistryFn(const PassAllocatorFunction &allocator) {
35 return [=](OpPassManager &pm, StringRef options,
36 function_ref<LogicalResult(const Twine &)> errorHandler) {
37 std::unique_ptr<Pass> pass = allocator();
38 LogicalResult result = pass->initializeOptions(options);
39 if ((pm.getNesting() == OpPassManager::Nesting::Explicit) &&
40 pass->getOpName() && *pass->getOpName() != pm.getOpName())
41 return errorHandler(llvm::Twine("Can't add pass '") + pass->getName() +
42 "' restricted to '" + *pass->getOpName() +
43 "' on a PassManager intended to run on '" +
44 pm.getOpName() + "', did you intend to nest?");
45 pm.addPass(std::move(pass));
46 return result;
47 };
48 }
49
50 /// Utility to print the help string for a specific option.
printOptionHelp(StringRef arg,StringRef desc,size_t indent,size_t descIndent,bool isTopLevel)51 static void printOptionHelp(StringRef arg, StringRef desc, size_t indent,
52 size_t descIndent, bool isTopLevel) {
53 size_t numSpaces = descIndent - indent - 4;
54 llvm::outs().indent(indent)
55 << "--" << llvm::left_justify(arg, numSpaces) << "- " << desc << '\n';
56 }
57
58 //===----------------------------------------------------------------------===//
59 // PassRegistry
60 //===----------------------------------------------------------------------===//
61
62 /// Print the help information for this pass. This includes the argument,
63 /// description, and any pass options. `descIndent` is the indent that the
64 /// descriptions should be aligned.
printHelpStr(size_t indent,size_t descIndent) const65 void PassRegistryEntry::printHelpStr(size_t indent, size_t descIndent) const {
66 printOptionHelp(getPassArgument(), getPassDescription(), indent, descIndent,
67 /*isTopLevel=*/true);
68 // If this entry has options, print the help for those as well.
69 optHandler([=](const PassOptions &options) {
70 options.printHelp(indent, descIndent);
71 });
72 }
73
74 /// Return the maximum width required when printing the options of this
75 /// entry.
getOptionWidth() const76 size_t PassRegistryEntry::getOptionWidth() const {
77 size_t maxLen = 0;
78 optHandler([&](const PassOptions &options) mutable {
79 maxLen = options.getOptionWidth() + 2;
80 });
81 return maxLen;
82 }
83
84 //===----------------------------------------------------------------------===//
85 // PassPipelineInfo
86 //===----------------------------------------------------------------------===//
87
registerPassPipeline(StringRef arg,StringRef description,const PassRegistryFunction & function,std::function<void (function_ref<void (const PassOptions &)>)> optHandler)88 void mlir::registerPassPipeline(
89 StringRef arg, StringRef description, const PassRegistryFunction &function,
90 std::function<void(function_ref<void(const PassOptions &)>)> optHandler) {
91 PassPipelineInfo pipelineInfo(arg, description, function, optHandler);
92 bool inserted = passPipelineRegistry->try_emplace(arg, pipelineInfo).second;
93 assert(inserted && "Pass pipeline registered multiple times");
94 (void)inserted;
95 }
96
97 //===----------------------------------------------------------------------===//
98 // PassInfo
99 //===----------------------------------------------------------------------===//
100
PassInfo(StringRef arg,StringRef description,const PassAllocatorFunction & allocator)101 PassInfo::PassInfo(StringRef arg, StringRef description,
102 const PassAllocatorFunction &allocator)
103 : PassRegistryEntry(
104 arg, description, buildDefaultRegistryFn(allocator),
105 // Use a temporary pass to provide an options instance.
106 [=](function_ref<void(const PassOptions &)> optHandler) {
107 optHandler(allocator()->passOptions);
108 }) {}
109
registerPass(StringRef arg,StringRef description,const PassAllocatorFunction & function)110 void mlir::registerPass(StringRef arg, StringRef description,
111 const PassAllocatorFunction &function) {
112 PassInfo passInfo(arg, description, function);
113 passRegistry->try_emplace(arg, passInfo);
114
115 // Verify that the registered pass has the same ID as any registered to this
116 // arg before it.
117 TypeID entryTypeID = function()->getTypeID();
118 auto it = passRegistryTypeIDs->try_emplace(arg, entryTypeID).first;
119 if (it->second != entryTypeID)
120 llvm::report_fatal_error(
121 "pass allocator creates a different pass than previously "
122 "registered for pass " +
123 arg);
124 }
125
registerPass(const PassAllocatorFunction & function)126 void mlir::registerPass(const PassAllocatorFunction &function) {
127 std::unique_ptr<Pass> pass = function();
128 StringRef arg = pass->getArgument();
129 if (arg.empty())
130 llvm::report_fatal_error(
131 "Trying to register a pass that does not override `getArgument()`: " +
132 pass->getName());
133 registerPass(arg, pass->getDescription(), function);
134 }
135
136 /// Returns the pass info for the specified pass argument or null if unknown.
lookupPassInfo(StringRef passArg)137 const PassInfo *mlir::Pass::lookupPassInfo(StringRef passArg) {
138 auto it = passRegistry->find(passArg);
139 return it == passRegistry->end() ? nullptr : &it->second;
140 }
141
142 //===----------------------------------------------------------------------===//
143 // PassOptions
144 //===----------------------------------------------------------------------===//
145
146 /// Out of line virtual function to provide home for the class.
anchor()147 void detail::PassOptions::OptionBase::anchor() {}
148
149 /// Copy the option values from 'other'.
copyOptionValuesFrom(const PassOptions & other)150 void detail::PassOptions::copyOptionValuesFrom(const PassOptions &other) {
151 assert(options.size() == other.options.size());
152 if (options.empty())
153 return;
154 for (auto optionsIt : llvm::zip(options, other.options))
155 std::get<0>(optionsIt)->copyValueFrom(*std::get<1>(optionsIt));
156 }
157
parseFromString(StringRef options)158 LogicalResult detail::PassOptions::parseFromString(StringRef options) {
159 // TODO: Handle escaping strings.
160 // NOTE: `options` is modified in place to always refer to the unprocessed
161 // part of the string.
162 while (!options.empty()) {
163 size_t spacePos = options.find(' ');
164 StringRef arg = options;
165 if (spacePos != StringRef::npos) {
166 arg = options.substr(0, spacePos);
167 options = options.substr(spacePos + 1);
168 } else {
169 options = StringRef();
170 }
171 if (arg.empty())
172 continue;
173
174 // At this point, arg refers to everything that is non-space in options
175 // upto the next space, and options refers to the rest of the string after
176 // that point.
177
178 // Split the individual option on '=' to form key and value. If there is no
179 // '=', then value is `StringRef()`.
180 size_t equalPos = arg.find('=');
181 StringRef key = arg;
182 StringRef value;
183 if (equalPos != StringRef::npos) {
184 key = arg.substr(0, equalPos);
185 value = arg.substr(equalPos + 1);
186 }
187 auto it = OptionsMap.find(key);
188 if (it == OptionsMap.end()) {
189 llvm::errs() << "<Pass-Options-Parser>: no such option " << key << "\n";
190 return failure();
191 }
192 if (llvm::cl::ProvidePositionalOption(it->second, value, 0))
193 return failure();
194 }
195
196 return success();
197 }
198
199 /// Print the options held by this struct in a form that can be parsed via
200 /// 'parseFromString'.
print(raw_ostream & os)201 void detail::PassOptions::print(raw_ostream &os) {
202 // If there are no options, there is nothing left to do.
203 if (OptionsMap.empty())
204 return;
205
206 // Sort the options to make the ordering deterministic.
207 SmallVector<OptionBase *, 4> orderedOps(options.begin(), options.end());
208 auto compareOptionArgs = [](OptionBase *const *lhs, OptionBase *const *rhs) {
209 return (*lhs)->getArgStr().compare((*rhs)->getArgStr());
210 };
211 llvm::array_pod_sort(orderedOps.begin(), orderedOps.end(), compareOptionArgs);
212
213 // Interleave the options with ' '.
214 os << '{';
215 llvm::interleave(
216 orderedOps, os, [&](OptionBase *option) { option->print(os); }, " ");
217 os << '}';
218 }
219
220 /// Print the help string for the options held by this struct. `descIndent` is
221 /// the indent within the stream that the descriptions should be aligned.
printHelp(size_t indent,size_t descIndent) const222 void detail::PassOptions::printHelp(size_t indent, size_t descIndent) const {
223 // Sort the options to make the ordering deterministic.
224 SmallVector<OptionBase *, 4> orderedOps(options.begin(), options.end());
225 auto compareOptionArgs = [](OptionBase *const *lhs, OptionBase *const *rhs) {
226 return (*lhs)->getArgStr().compare((*rhs)->getArgStr());
227 };
228 llvm::array_pod_sort(orderedOps.begin(), orderedOps.end(), compareOptionArgs);
229 for (OptionBase *option : orderedOps) {
230 // TODO: printOptionInfo assumes a specific indent and will
231 // print options with values with incorrect indentation. We should add
232 // support to llvm::cl::Option for passing in a base indent to use when
233 // printing.
234 llvm::outs().indent(indent);
235 option->getOption()->printOptionInfo(descIndent - indent);
236 }
237 }
238
239 /// Return the maximum width required when printing the help string.
getOptionWidth() const240 size_t detail::PassOptions::getOptionWidth() const {
241 size_t max = 0;
242 for (auto *option : options)
243 max = std::max(max, option->getOption()->getOptionWidth());
244 return max;
245 }
246
247 //===----------------------------------------------------------------------===//
248 // TextualPassPipeline Parser
249 //===----------------------------------------------------------------------===//
250
251 namespace {
252 /// This class represents a textual description of a pass pipeline.
253 class TextualPipeline {
254 public:
255 /// Try to initialize this pipeline with the given pipeline text.
256 /// `errorStream` is the output stream to emit errors to.
257 LogicalResult initialize(StringRef text, raw_ostream &errorStream);
258
259 /// Add the internal pipeline elements to the provided pass manager.
260 LogicalResult
261 addToPipeline(OpPassManager &pm,
262 function_ref<LogicalResult(const Twine &)> errorHandler) const;
263
264 private:
265 /// A functor used to emit errors found during pipeline handling. The first
266 /// parameter corresponds to the raw location within the pipeline string. This
267 /// should always return failure.
268 using ErrorHandlerT = function_ref<LogicalResult(const char *, Twine)>;
269
270 /// A struct to capture parsed pass pipeline names.
271 ///
272 /// A pipeline is defined as a series of names, each of which may in itself
273 /// recursively contain a nested pipeline. A name is either the name of a pass
274 /// (e.g. "cse") or the name of an operation type (e.g. "func"). If the name
275 /// is the name of a pass, the InnerPipeline is empty, since passes cannot
276 /// contain inner pipelines.
277 struct PipelineElement {
PipelineElement__anone08a6f2d0811::TextualPipeline::PipelineElement278 PipelineElement(StringRef name) : name(name), registryEntry(nullptr) {}
279
280 StringRef name;
281 StringRef options;
282 const PassRegistryEntry *registryEntry;
283 std::vector<PipelineElement> innerPipeline;
284 };
285
286 /// Parse the given pipeline text into the internal pipeline vector. This
287 /// function only parses the structure of the pipeline, and does not resolve
288 /// its elements.
289 LogicalResult parsePipelineText(StringRef text, ErrorHandlerT errorHandler);
290
291 /// Resolve the elements of the pipeline, i.e. connect passes and pipelines to
292 /// the corresponding registry entry.
293 LogicalResult
294 resolvePipelineElements(MutableArrayRef<PipelineElement> elements,
295 ErrorHandlerT errorHandler);
296
297 /// Resolve a single element of the pipeline.
298 LogicalResult resolvePipelineElement(PipelineElement &element,
299 ErrorHandlerT errorHandler);
300
301 /// Add the given pipeline elements to the provided pass manager.
302 LogicalResult
303 addToPipeline(ArrayRef<PipelineElement> elements, OpPassManager &pm,
304 function_ref<LogicalResult(const Twine &)> errorHandler) const;
305
306 std::vector<PipelineElement> pipeline;
307 };
308
309 } // end anonymous namespace
310
311 /// Try to initialize this pipeline with the given pipeline text. An option is
312 /// given to enable accurate error reporting.
initialize(StringRef text,raw_ostream & errorStream)313 LogicalResult TextualPipeline::initialize(StringRef text,
314 raw_ostream &errorStream) {
315 if (text.empty())
316 return success();
317
318 // Build a source manager to use for error reporting.
319 llvm::SourceMgr pipelineMgr;
320 pipelineMgr.AddNewSourceBuffer(
321 llvm::MemoryBuffer::getMemBuffer(text, "MLIR Textual PassPipeline Parser",
322 /*RequiresNullTerminator=*/false),
323 llvm::SMLoc());
324 auto errorHandler = [&](const char *rawLoc, Twine msg) {
325 pipelineMgr.PrintMessage(errorStream, llvm::SMLoc::getFromPointer(rawLoc),
326 llvm::SourceMgr::DK_Error, msg);
327 return failure();
328 };
329
330 // Parse the provided pipeline string.
331 if (failed(parsePipelineText(text, errorHandler)))
332 return failure();
333 return resolvePipelineElements(pipeline, errorHandler);
334 }
335
336 /// Add the internal pipeline elements to the provided pass manager.
addToPipeline(OpPassManager & pm,function_ref<LogicalResult (const Twine &)> errorHandler) const337 LogicalResult TextualPipeline::addToPipeline(
338 OpPassManager &pm,
339 function_ref<LogicalResult(const Twine &)> errorHandler) const {
340 return addToPipeline(pipeline, pm, errorHandler);
341 }
342
343 /// Parse the given pipeline text into the internal pipeline vector. This
344 /// function only parses the structure of the pipeline, and does not resolve
345 /// its elements.
parsePipelineText(StringRef text,ErrorHandlerT errorHandler)346 LogicalResult TextualPipeline::parsePipelineText(StringRef text,
347 ErrorHandlerT errorHandler) {
348 SmallVector<std::vector<PipelineElement> *, 4> pipelineStack = {&pipeline};
349 for (;;) {
350 std::vector<PipelineElement> &pipeline = *pipelineStack.back();
351 size_t pos = text.find_first_of(",(){");
352 pipeline.emplace_back(/*name=*/text.substr(0, pos).trim());
353
354 // If we have a single terminating name, we're done.
355 if (pos == StringRef::npos)
356 break;
357
358 text = text.substr(pos);
359 char sep = text[0];
360
361 // Handle pulling ... from 'pass{...}' out as PipelineElement.options.
362 if (sep == '{') {
363 text = text.substr(1);
364
365 // Skip over everything until the closing '}' and store as options.
366 size_t close = StringRef::npos;
367 for (unsigned i = 0, e = text.size(), braceCount = 1; i < e; ++i) {
368 if (text[i] == '{') {
369 ++braceCount;
370 continue;
371 }
372 if (text[i] == '}' && --braceCount == 0) {
373 close = i;
374 break;
375 }
376 }
377
378 // Check to see if a closing options brace was found.
379 if (close == StringRef::npos) {
380 return errorHandler(
381 /*rawLoc=*/text.data() - 1,
382 "missing closing '}' while processing pass options");
383 }
384 pipeline.back().options = text.substr(0, close);
385 text = text.substr(close + 1);
386
387 // Skip checking for '(' because nested pipelines cannot have options.
388 } else if (sep == '(') {
389 text = text.substr(1);
390
391 // Push the inner pipeline onto the stack to continue processing.
392 pipelineStack.push_back(&pipeline.back().innerPipeline);
393 continue;
394 }
395
396 // When handling the close parenthesis, we greedily consume them to avoid
397 // empty strings in the pipeline.
398 while (text.consume_front(")")) {
399 // If we try to pop the outer pipeline we have unbalanced parentheses.
400 if (pipelineStack.size() == 1)
401 return errorHandler(/*rawLoc=*/text.data() - 1,
402 "encountered extra closing ')' creating unbalanced "
403 "parentheses while parsing pipeline");
404
405 pipelineStack.pop_back();
406 }
407
408 // Check if we've finished parsing.
409 if (text.empty())
410 break;
411
412 // Otherwise, the end of an inner pipeline always has to be followed by
413 // a comma, and then we can continue.
414 if (!text.consume_front(","))
415 return errorHandler(text.data(), "expected ',' after parsing pipeline");
416 }
417
418 // Check for unbalanced parentheses.
419 if (pipelineStack.size() > 1)
420 return errorHandler(
421 text.data(),
422 "encountered unbalanced parentheses while parsing pipeline");
423
424 assert(pipelineStack.back() == &pipeline &&
425 "wrong pipeline at the bottom of the stack");
426 return success();
427 }
428
429 /// Resolve the elements of the pipeline, i.e. connect passes and pipelines to
430 /// the corresponding registry entry.
resolvePipelineElements(MutableArrayRef<PipelineElement> elements,ErrorHandlerT errorHandler)431 LogicalResult TextualPipeline::resolvePipelineElements(
432 MutableArrayRef<PipelineElement> elements, ErrorHandlerT errorHandler) {
433 for (auto &elt : elements)
434 if (failed(resolvePipelineElement(elt, errorHandler)))
435 return failure();
436 return success();
437 }
438
439 /// Resolve a single element of the pipeline.
440 LogicalResult
resolvePipelineElement(PipelineElement & element,ErrorHandlerT errorHandler)441 TextualPipeline::resolvePipelineElement(PipelineElement &element,
442 ErrorHandlerT errorHandler) {
443 // If the inner pipeline of this element is not empty, this is an operation
444 // pipeline.
445 if (!element.innerPipeline.empty())
446 return resolvePipelineElements(element.innerPipeline, errorHandler);
447 // Otherwise, this must be a pass or pass pipeline.
448 // Check to see if a pipeline was registered with this name.
449 auto pipelineRegistryIt = passPipelineRegistry->find(element.name);
450 if (pipelineRegistryIt != passPipelineRegistry->end()) {
451 element.registryEntry = &pipelineRegistryIt->second;
452 return success();
453 }
454
455 // If not, then this must be a specific pass name.
456 if ((element.registryEntry = Pass::lookupPassInfo(element.name)))
457 return success();
458
459 // Emit an error for the unknown pass.
460 auto *rawLoc = element.name.data();
461 return errorHandler(rawLoc, "'" + element.name +
462 "' does not refer to a "
463 "registered pass or pass pipeline");
464 }
465
466 /// Add the given pipeline elements to the provided pass manager.
addToPipeline(ArrayRef<PipelineElement> elements,OpPassManager & pm,function_ref<LogicalResult (const Twine &)> errorHandler) const467 LogicalResult TextualPipeline::addToPipeline(
468 ArrayRef<PipelineElement> elements, OpPassManager &pm,
469 function_ref<LogicalResult(const Twine &)> errorHandler) const {
470 for (auto &elt : elements) {
471 if (elt.registryEntry) {
472 if (failed(elt.registryEntry->addToPipeline(pm, elt.options,
473 errorHandler))) {
474 return errorHandler("failed to add `" + elt.name + "` with options `" +
475 elt.options + "`");
476 }
477 } else if (failed(addToPipeline(elt.innerPipeline, pm.nest(elt.name),
478 errorHandler))) {
479 return errorHandler("failed to add `" + elt.name + "` with options `" +
480 elt.options + "` to inner pipeline");
481 }
482 }
483 return success();
484 }
485
486 /// This function parses the textual representation of a pass pipeline, and adds
487 /// the result to 'pm' on success. This function returns failure if the given
488 /// pipeline was invalid. 'errorStream' is an optional parameter that, if
489 /// non-null, will be used to emit errors found during parsing.
parsePassPipeline(StringRef pipeline,OpPassManager & pm,raw_ostream & errorStream)490 LogicalResult mlir::parsePassPipeline(StringRef pipeline, OpPassManager &pm,
491 raw_ostream &errorStream) {
492 TextualPipeline pipelineParser;
493 if (failed(pipelineParser.initialize(pipeline, errorStream)))
494 return failure();
495 auto errorHandler = [&](Twine msg) {
496 errorStream << msg << "\n";
497 return failure();
498 };
499 if (failed(pipelineParser.addToPipeline(pm, errorHandler)))
500 return failure();
501 return success();
502 }
503
504 //===----------------------------------------------------------------------===//
505 // PassNameParser
506 //===----------------------------------------------------------------------===//
507
508 namespace {
509 /// This struct represents the possible data entries in a parsed pass pipeline
510 /// list.
511 struct PassArgData {
PassArgData__anone08a6f2d0b11::PassArgData512 PassArgData() : registryEntry(nullptr) {}
PassArgData__anone08a6f2d0b11::PassArgData513 PassArgData(const PassRegistryEntry *registryEntry)
514 : registryEntry(registryEntry) {}
515
516 /// This field is used when the parsed option corresponds to a registered pass
517 /// or pass pipeline.
518 const PassRegistryEntry *registryEntry;
519
520 /// This field is set when instance specific pass options have been provided
521 /// on the command line.
522 StringRef options;
523
524 /// This field is used when the parsed option corresponds to an explicit
525 /// pipeline.
526 TextualPipeline pipeline;
527 };
528 } // end anonymous namespace
529
530 namespace llvm {
531 namespace cl {
532 /// Define a valid OptionValue for the command line pass argument.
533 template <>
534 struct OptionValue<PassArgData> final
535 : OptionValueBase<PassArgData, /*isClass=*/true> {
OptionValuellvm::cl::OptionValue536 OptionValue(const PassArgData &value) { this->setValue(value); }
537 OptionValue() = default;
anchorllvm::cl::OptionValue538 void anchor() override {}
539
hasValuellvm::cl::OptionValue540 bool hasValue() const { return true; }
getValuellvm::cl::OptionValue541 const PassArgData &getValue() const { return value; }
setValuellvm::cl::OptionValue542 void setValue(const PassArgData &value) { this->value = value; }
543
544 PassArgData value;
545 };
546 } // end namespace cl
547 } // end namespace llvm
548
549 namespace {
550
551 /// The name for the command line option used for parsing the textual pass
552 /// pipeline.
553 static constexpr StringLiteral passPipelineArg = "pass-pipeline";
554
555 /// Adds command line option for each registered pass or pass pipeline, as well
556 /// as textual pass pipelines.
557 struct PassNameParser : public llvm::cl::parser<PassArgData> {
PassNameParser__anone08a6f2d0c11::PassNameParser558 PassNameParser(llvm::cl::Option &opt) : llvm::cl::parser<PassArgData>(opt) {}
559
560 void initialize();
561 void printOptionInfo(const llvm::cl::Option &opt,
562 size_t globalWidth) const override;
563 size_t getOptionWidth(const llvm::cl::Option &opt) const override;
564 bool parse(llvm::cl::Option &opt, StringRef argName, StringRef arg,
565 PassArgData &value);
566
567 /// If true, this parser only parses entries that correspond to a concrete
568 /// pass registry entry, and does not add a `pass-pipeline` argument, does not
569 /// include the options for pass entries, and does not include pass pipelines
570 /// entries.
571 bool passNamesOnly = false;
572 };
573 } // namespace
574
initialize()575 void PassNameParser::initialize() {
576 llvm::cl::parser<PassArgData>::initialize();
577
578 /// Add an entry for the textual pass pipeline option.
579 if (!passNamesOnly) {
580 addLiteralOption(passPipelineArg, PassArgData(),
581 "A textual description of a pass pipeline to run");
582 }
583
584 /// Add the pass entries.
585 for (const auto &kv : *passRegistry) {
586 addLiteralOption(kv.second.getPassArgument(), &kv.second,
587 kv.second.getPassDescription());
588 }
589 /// Add the pass pipeline entries.
590 if (!passNamesOnly) {
591 for (const auto &kv : *passPipelineRegistry) {
592 addLiteralOption(kv.second.getPassArgument(), &kv.second,
593 kv.second.getPassDescription());
594 }
595 }
596 }
597
printOptionInfo(const llvm::cl::Option & opt,size_t globalWidth) const598 void PassNameParser::printOptionInfo(const llvm::cl::Option &opt,
599 size_t globalWidth) const {
600 // If this parser is just parsing pass names, print a simplified option
601 // string.
602 if (passNamesOnly) {
603 llvm::outs() << " --" << opt.ArgStr << "=<pass-arg>";
604 opt.printHelpStr(opt.HelpStr, globalWidth, opt.ArgStr.size() + 18);
605 return;
606 }
607
608 // Print the information for the top-level option.
609 if (opt.hasArgStr()) {
610 llvm::outs() << " --" << opt.ArgStr;
611 opt.printHelpStr(opt.HelpStr, globalWidth, opt.ArgStr.size() + 7);
612 } else {
613 llvm::outs() << " " << opt.HelpStr << '\n';
614 }
615
616 // Print the top-level pipeline argument.
617 printOptionHelp(passPipelineArg,
618 "A textual description of a pass pipeline to run",
619 /*indent=*/4, globalWidth, /*isTopLevel=*/!opt.hasArgStr());
620
621 // Functor used to print the ordered entries of a registration map.
622 auto printOrderedEntries = [&](StringRef header, auto &map) {
623 llvm::SmallVector<PassRegistryEntry *, 32> orderedEntries;
624 for (auto &kv : map)
625 orderedEntries.push_back(&kv.second);
626 llvm::array_pod_sort(
627 orderedEntries.begin(), orderedEntries.end(),
628 [](PassRegistryEntry *const *lhs, PassRegistryEntry *const *rhs) {
629 return (*lhs)->getPassArgument().compare((*rhs)->getPassArgument());
630 });
631
632 llvm::outs().indent(4) << header << ":\n";
633 for (PassRegistryEntry *entry : orderedEntries)
634 entry->printHelpStr(/*indent=*/6, globalWidth);
635 };
636
637 // Print the available passes.
638 printOrderedEntries("Passes", *passRegistry);
639
640 // Print the available pass pipelines.
641 if (!passPipelineRegistry->empty())
642 printOrderedEntries("Pass Pipelines", *passPipelineRegistry);
643 }
644
getOptionWidth(const llvm::cl::Option & opt) const645 size_t PassNameParser::getOptionWidth(const llvm::cl::Option &opt) const {
646 size_t maxWidth = llvm::cl::parser<PassArgData>::getOptionWidth(opt) + 2;
647
648 // Check for any wider pass or pipeline options.
649 for (auto &entry : *passRegistry)
650 maxWidth = std::max(maxWidth, entry.second.getOptionWidth() + 4);
651 for (auto &entry : *passPipelineRegistry)
652 maxWidth = std::max(maxWidth, entry.second.getOptionWidth() + 4);
653 return maxWidth;
654 }
655
parse(llvm::cl::Option & opt,StringRef argName,StringRef arg,PassArgData & value)656 bool PassNameParser::parse(llvm::cl::Option &opt, StringRef argName,
657 StringRef arg, PassArgData &value) {
658 // Handle the pipeline option explicitly.
659 if (argName == passPipelineArg)
660 return failed(value.pipeline.initialize(arg, llvm::errs()));
661
662 // Otherwise, default to the base for handling.
663 if (llvm::cl::parser<PassArgData>::parse(opt, argName, arg, value))
664 return true;
665 value.options = arg;
666 return false;
667 }
668
669 //===----------------------------------------------------------------------===//
670 // PassPipelineCLParser
671 //===----------------------------------------------------------------------===//
672
673 namespace mlir {
674 namespace detail {
675 struct PassPipelineCLParserImpl {
PassPipelineCLParserImplmlir::detail::PassPipelineCLParserImpl676 PassPipelineCLParserImpl(StringRef arg, StringRef description,
677 bool passNamesOnly)
678 : passList(arg, llvm::cl::desc(description)) {
679 passList.getParser().passNamesOnly = passNamesOnly;
680 passList.setValueExpectedFlag(llvm::cl::ValueExpected::ValueOptional);
681 }
682
683 /// Returns true if the given pass registry entry was registered at the
684 /// top-level of the parser, i.e. not within an explicit textual pipeline.
containsmlir::detail::PassPipelineCLParserImpl685 bool contains(const PassRegistryEntry *entry) const {
686 return llvm::any_of(passList, [&](const PassArgData &data) {
687 return data.registryEntry == entry;
688 });
689 }
690
691 /// The set of passes and pass pipelines to run.
692 llvm::cl::list<PassArgData, bool, PassNameParser> passList;
693 };
694 } // end namespace detail
695 } // end namespace mlir
696
697 /// Construct a pass pipeline parser with the given command line description.
PassPipelineCLParser(StringRef arg,StringRef description)698 PassPipelineCLParser::PassPipelineCLParser(StringRef arg, StringRef description)
699 : impl(std::make_unique<detail::PassPipelineCLParserImpl>(
700 arg, description, /*passNamesOnly=*/false)) {}
~PassPipelineCLParser()701 PassPipelineCLParser::~PassPipelineCLParser() {}
702
703 /// Returns true if this parser contains any valid options to add.
hasAnyOccurrences() const704 bool PassPipelineCLParser::hasAnyOccurrences() const {
705 return impl->passList.getNumOccurrences() != 0;
706 }
707
708 /// Returns true if the given pass registry entry was registered at the
709 /// top-level of the parser, i.e. not within an explicit textual pipeline.
contains(const PassRegistryEntry * entry) const710 bool PassPipelineCLParser::contains(const PassRegistryEntry *entry) const {
711 return impl->contains(entry);
712 }
713
714 /// Adds the passes defined by this parser entry to the given pass manager.
addToPipeline(OpPassManager & pm,function_ref<LogicalResult (const Twine &)> errorHandler) const715 LogicalResult PassPipelineCLParser::addToPipeline(
716 OpPassManager &pm,
717 function_ref<LogicalResult(const Twine &)> errorHandler) const {
718 for (auto &passIt : impl->passList) {
719 if (passIt.registryEntry) {
720 if (failed(passIt.registryEntry->addToPipeline(pm, passIt.options,
721 errorHandler)))
722 return failure();
723 } else {
724 OpPassManager::Nesting nesting = pm.getNesting();
725 pm.setNesting(OpPassManager::Nesting::Explicit);
726 LogicalResult status = passIt.pipeline.addToPipeline(pm, errorHandler);
727 pm.setNesting(nesting);
728 if (failed(status))
729 return failure();
730 }
731 }
732 return success();
733 }
734
//===----------------------------------------------------------------------===//
// PassNameCLParser
//===----------------------------------------------------------------------===//
737
738 /// Construct a pass pipeline parser with the given command line description.
PassNameCLParser(StringRef arg,StringRef description)739 PassNameCLParser::PassNameCLParser(StringRef arg, StringRef description)
740 : impl(std::make_unique<detail::PassPipelineCLParserImpl>(
741 arg, description, /*passNamesOnly=*/true)) {
742 impl->passList.setMiscFlag(llvm::cl::CommaSeparated);
743 }
~PassNameCLParser()744 PassNameCLParser::~PassNameCLParser() {}
745
746 /// Returns true if this parser contains any valid options to add.
hasAnyOccurrences() const747 bool PassNameCLParser::hasAnyOccurrences() const {
748 return impl->passList.getNumOccurrences() != 0;
749 }
750
751 /// Returns true if the given pass registry entry was registered at the
752 /// top-level of the parser, i.e. not within an explicit textual pipeline.
contains(const PassRegistryEntry * entry) const753 bool PassNameCLParser::contains(const PassRegistryEntry *entry) const {
754 return impl->contains(entry);
755 }
756