1 //===- PassRegistry.cpp - Pass Registration Utilities ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Pass/PassRegistry.h"
10 #include "mlir/Pass/Pass.h"
11 #include "mlir/Pass/PassManager.h"
12 #include "llvm/ADT/DenseMap.h"
13 #include "llvm/Support/Format.h"
14 #include "llvm/Support/ManagedStatic.h"
15 #include "llvm/Support/MemoryBuffer.h"
16 #include "llvm/Support/SourceMgr.h"
17 
18 using namespace mlir;
19 using namespace detail;
20 
21 /// Static mapping of all of the registered passes.
22 static llvm::ManagedStatic<llvm::StringMap<PassInfo>> passRegistry;
23 
24 /// A mapping of the above pass registry entries to the corresponding TypeID
25 /// of the pass that they generate.
26 static llvm::ManagedStatic<llvm::StringMap<TypeID>> passRegistryTypeIDs;
27 
28 /// Static mapping of all of the registered pass pipelines.
29 static llvm::ManagedStatic<llvm::StringMap<PassPipelineInfo>>
30     passPipelineRegistry;
31 
32 /// Utility to create a default registry function from a pass instance.
33 static PassRegistryFunction
buildDefaultRegistryFn(const PassAllocatorFunction & allocator)34 buildDefaultRegistryFn(const PassAllocatorFunction &allocator) {
35   return [=](OpPassManager &pm, StringRef options,
36              function_ref<LogicalResult(const Twine &)> errorHandler) {
37     std::unique_ptr<Pass> pass = allocator();
38     LogicalResult result = pass->initializeOptions(options);
39     if ((pm.getNesting() == OpPassManager::Nesting::Explicit) &&
40         pass->getOpName() && *pass->getOpName() != pm.getOpName())
41       return errorHandler(llvm::Twine("Can't add pass '") + pass->getName() +
42                           "' restricted to '" + *pass->getOpName() +
43                           "' on a PassManager intended to run on '" +
44                           pm.getOpName() + "', did you intend to nest?");
45     pm.addPass(std::move(pass));
46     return result;
47   };
48 }
49 
50 /// Utility to print the help string for a specific option.
printOptionHelp(StringRef arg,StringRef desc,size_t indent,size_t descIndent,bool isTopLevel)51 static void printOptionHelp(StringRef arg, StringRef desc, size_t indent,
52                             size_t descIndent, bool isTopLevel) {
53   size_t numSpaces = descIndent - indent - 4;
54   llvm::outs().indent(indent)
55       << "--" << llvm::left_justify(arg, numSpaces) << "-   " << desc << '\n';
56 }
57 
58 //===----------------------------------------------------------------------===//
59 // PassRegistry
60 //===----------------------------------------------------------------------===//
61 
62 /// Print the help information for this pass. This includes the argument,
63 /// description, and any pass options. `descIndent` is the indent that the
64 /// descriptions should be aligned.
printHelpStr(size_t indent,size_t descIndent) const65 void PassRegistryEntry::printHelpStr(size_t indent, size_t descIndent) const {
66   printOptionHelp(getPassArgument(), getPassDescription(), indent, descIndent,
67                   /*isTopLevel=*/true);
68   // If this entry has options, print the help for those as well.
69   optHandler([=](const PassOptions &options) {
70     options.printHelp(indent, descIndent);
71   });
72 }
73 
74 /// Return the maximum width required when printing the options of this
75 /// entry.
getOptionWidth() const76 size_t PassRegistryEntry::getOptionWidth() const {
77   size_t maxLen = 0;
78   optHandler([&](const PassOptions &options) mutable {
79     maxLen = options.getOptionWidth() + 2;
80   });
81   return maxLen;
82 }
83 
84 //===----------------------------------------------------------------------===//
85 // PassPipelineInfo
86 //===----------------------------------------------------------------------===//
87 
registerPassPipeline(StringRef arg,StringRef description,const PassRegistryFunction & function,std::function<void (function_ref<void (const PassOptions &)>)> optHandler)88 void mlir::registerPassPipeline(
89     StringRef arg, StringRef description, const PassRegistryFunction &function,
90     std::function<void(function_ref<void(const PassOptions &)>)> optHandler) {
91   PassPipelineInfo pipelineInfo(arg, description, function, optHandler);
92   bool inserted = passPipelineRegistry->try_emplace(arg, pipelineInfo).second;
93   assert(inserted && "Pass pipeline registered multiple times");
94   (void)inserted;
95 }
96 
97 //===----------------------------------------------------------------------===//
98 // PassInfo
99 //===----------------------------------------------------------------------===//
100 
PassInfo(StringRef arg,StringRef description,const PassAllocatorFunction & allocator)101 PassInfo::PassInfo(StringRef arg, StringRef description,
102                    const PassAllocatorFunction &allocator)
103     : PassRegistryEntry(
104           arg, description, buildDefaultRegistryFn(allocator),
105           // Use a temporary pass to provide an options instance.
106           [=](function_ref<void(const PassOptions &)> optHandler) {
107             optHandler(allocator()->passOptions);
108           }) {}
109 
registerPass(StringRef arg,StringRef description,const PassAllocatorFunction & function)110 void mlir::registerPass(StringRef arg, StringRef description,
111                         const PassAllocatorFunction &function) {
112   PassInfo passInfo(arg, description, function);
113   passRegistry->try_emplace(arg, passInfo);
114 
115   // Verify that the registered pass has the same ID as any registered to this
116   // arg before it.
117   TypeID entryTypeID = function()->getTypeID();
118   auto it = passRegistryTypeIDs->try_emplace(arg, entryTypeID).first;
119   if (it->second != entryTypeID)
120     llvm::report_fatal_error(
121         "pass allocator creates a different pass than previously "
122         "registered for pass " +
123         arg);
124 }
125 
registerPass(const PassAllocatorFunction & function)126 void mlir::registerPass(const PassAllocatorFunction &function) {
127   std::unique_ptr<Pass> pass = function();
128   StringRef arg = pass->getArgument();
129   if (arg.empty())
130     llvm::report_fatal_error(
131         "Trying to register a pass that does not override `getArgument()`: " +
132         pass->getName());
133   registerPass(arg, pass->getDescription(), function);
134 }
135 
136 /// Returns the pass info for the specified pass argument or null if unknown.
lookupPassInfo(StringRef passArg)137 const PassInfo *mlir::Pass::lookupPassInfo(StringRef passArg) {
138   auto it = passRegistry->find(passArg);
139   return it == passRegistry->end() ? nullptr : &it->second;
140 }
141 
142 //===----------------------------------------------------------------------===//
143 // PassOptions
144 //===----------------------------------------------------------------------===//
145 
146 /// Out of line virtual function to provide home for the class.
anchor()147 void detail::PassOptions::OptionBase::anchor() {}
148 
149 /// Copy the option values from 'other'.
copyOptionValuesFrom(const PassOptions & other)150 void detail::PassOptions::copyOptionValuesFrom(const PassOptions &other) {
151   assert(options.size() == other.options.size());
152   if (options.empty())
153     return;
154   for (auto optionsIt : llvm::zip(options, other.options))
155     std::get<0>(optionsIt)->copyValueFrom(*std::get<1>(optionsIt));
156 }
157 
parseFromString(StringRef options)158 LogicalResult detail::PassOptions::parseFromString(StringRef options) {
159   // TODO: Handle escaping strings.
160   // NOTE: `options` is modified in place to always refer to the unprocessed
161   // part of the string.
162   while (!options.empty()) {
163     size_t spacePos = options.find(' ');
164     StringRef arg = options;
165     if (spacePos != StringRef::npos) {
166       arg = options.substr(0, spacePos);
167       options = options.substr(spacePos + 1);
168     } else {
169       options = StringRef();
170     }
171     if (arg.empty())
172       continue;
173 
174     // At this point, arg refers to everything that is non-space in options
175     // upto the next space, and options refers to the rest of the string after
176     // that point.
177 
178     // Split the individual option on '=' to form key and value. If there is no
179     // '=', then value is `StringRef()`.
180     size_t equalPos = arg.find('=');
181     StringRef key = arg;
182     StringRef value;
183     if (equalPos != StringRef::npos) {
184       key = arg.substr(0, equalPos);
185       value = arg.substr(equalPos + 1);
186     }
187     auto it = OptionsMap.find(key);
188     if (it == OptionsMap.end()) {
189       llvm::errs() << "<Pass-Options-Parser>: no such option " << key << "\n";
190       return failure();
191     }
192     if (llvm::cl::ProvidePositionalOption(it->second, value, 0))
193       return failure();
194   }
195 
196   return success();
197 }
198 
199 /// Print the options held by this struct in a form that can be parsed via
200 /// 'parseFromString'.
print(raw_ostream & os)201 void detail::PassOptions::print(raw_ostream &os) {
202   // If there are no options, there is nothing left to do.
203   if (OptionsMap.empty())
204     return;
205 
206   // Sort the options to make the ordering deterministic.
207   SmallVector<OptionBase *, 4> orderedOps(options.begin(), options.end());
208   auto compareOptionArgs = [](OptionBase *const *lhs, OptionBase *const *rhs) {
209     return (*lhs)->getArgStr().compare((*rhs)->getArgStr());
210   };
211   llvm::array_pod_sort(orderedOps.begin(), orderedOps.end(), compareOptionArgs);
212 
213   // Interleave the options with ' '.
214   os << '{';
215   llvm::interleave(
216       orderedOps, os, [&](OptionBase *option) { option->print(os); }, " ");
217   os << '}';
218 }
219 
220 /// Print the help string for the options held by this struct. `descIndent` is
221 /// the indent within the stream that the descriptions should be aligned.
printHelp(size_t indent,size_t descIndent) const222 void detail::PassOptions::printHelp(size_t indent, size_t descIndent) const {
223   // Sort the options to make the ordering deterministic.
224   SmallVector<OptionBase *, 4> orderedOps(options.begin(), options.end());
225   auto compareOptionArgs = [](OptionBase *const *lhs, OptionBase *const *rhs) {
226     return (*lhs)->getArgStr().compare((*rhs)->getArgStr());
227   };
228   llvm::array_pod_sort(orderedOps.begin(), orderedOps.end(), compareOptionArgs);
229   for (OptionBase *option : orderedOps) {
230     // TODO: printOptionInfo assumes a specific indent and will
231     // print options with values with incorrect indentation. We should add
232     // support to llvm::cl::Option for passing in a base indent to use when
233     // printing.
234     llvm::outs().indent(indent);
235     option->getOption()->printOptionInfo(descIndent - indent);
236   }
237 }
238 
239 /// Return the maximum width required when printing the help string.
getOptionWidth() const240 size_t detail::PassOptions::getOptionWidth() const {
241   size_t max = 0;
242   for (auto *option : options)
243     max = std::max(max, option->getOption()->getOptionWidth());
244   return max;
245 }
246 
247 //===----------------------------------------------------------------------===//
248 // TextualPassPipeline Parser
249 //===----------------------------------------------------------------------===//
250 
251 namespace {
252 /// This class represents a textual description of a pass pipeline.
253 class TextualPipeline {
254 public:
255   /// Try to initialize this pipeline with the given pipeline text.
256   /// `errorStream` is the output stream to emit errors to.
257   LogicalResult initialize(StringRef text, raw_ostream &errorStream);
258 
259   /// Add the internal pipeline elements to the provided pass manager.
260   LogicalResult
261   addToPipeline(OpPassManager &pm,
262                 function_ref<LogicalResult(const Twine &)> errorHandler) const;
263 
264 private:
265   /// A functor used to emit errors found during pipeline handling. The first
266   /// parameter corresponds to the raw location within the pipeline string. This
267   /// should always return failure.
268   using ErrorHandlerT = function_ref<LogicalResult(const char *, Twine)>;
269 
270   /// A struct to capture parsed pass pipeline names.
271   ///
272   /// A pipeline is defined as a series of names, each of which may in itself
273   /// recursively contain a nested pipeline. A name is either the name of a pass
274   /// (e.g. "cse") or the name of an operation type (e.g. "func"). If the name
275   /// is the name of a pass, the InnerPipeline is empty, since passes cannot
276   /// contain inner pipelines.
277   struct PipelineElement {
PipelineElement__anone08a6f2d0811::TextualPipeline::PipelineElement278     PipelineElement(StringRef name) : name(name), registryEntry(nullptr) {}
279 
280     StringRef name;
281     StringRef options;
282     const PassRegistryEntry *registryEntry;
283     std::vector<PipelineElement> innerPipeline;
284   };
285 
286   /// Parse the given pipeline text into the internal pipeline vector. This
287   /// function only parses the structure of the pipeline, and does not resolve
288   /// its elements.
289   LogicalResult parsePipelineText(StringRef text, ErrorHandlerT errorHandler);
290 
291   /// Resolve the elements of the pipeline, i.e. connect passes and pipelines to
292   /// the corresponding registry entry.
293   LogicalResult
294   resolvePipelineElements(MutableArrayRef<PipelineElement> elements,
295                           ErrorHandlerT errorHandler);
296 
297   /// Resolve a single element of the pipeline.
298   LogicalResult resolvePipelineElement(PipelineElement &element,
299                                        ErrorHandlerT errorHandler);
300 
301   /// Add the given pipeline elements to the provided pass manager.
302   LogicalResult
303   addToPipeline(ArrayRef<PipelineElement> elements, OpPassManager &pm,
304                 function_ref<LogicalResult(const Twine &)> errorHandler) const;
305 
306   std::vector<PipelineElement> pipeline;
307 };
308 
309 } // end anonymous namespace
310 
311 /// Try to initialize this pipeline with the given pipeline text. An option is
312 /// given to enable accurate error reporting.
initialize(StringRef text,raw_ostream & errorStream)313 LogicalResult TextualPipeline::initialize(StringRef text,
314                                           raw_ostream &errorStream) {
315   if (text.empty())
316     return success();
317 
318   // Build a source manager to use for error reporting.
319   llvm::SourceMgr pipelineMgr;
320   pipelineMgr.AddNewSourceBuffer(
321       llvm::MemoryBuffer::getMemBuffer(text, "MLIR Textual PassPipeline Parser",
322                                        /*RequiresNullTerminator=*/false),
323       llvm::SMLoc());
324   auto errorHandler = [&](const char *rawLoc, Twine msg) {
325     pipelineMgr.PrintMessage(errorStream, llvm::SMLoc::getFromPointer(rawLoc),
326                              llvm::SourceMgr::DK_Error, msg);
327     return failure();
328   };
329 
330   // Parse the provided pipeline string.
331   if (failed(parsePipelineText(text, errorHandler)))
332     return failure();
333   return resolvePipelineElements(pipeline, errorHandler);
334 }
335 
336 /// Add the internal pipeline elements to the provided pass manager.
addToPipeline(OpPassManager & pm,function_ref<LogicalResult (const Twine &)> errorHandler) const337 LogicalResult TextualPipeline::addToPipeline(
338     OpPassManager &pm,
339     function_ref<LogicalResult(const Twine &)> errorHandler) const {
340   return addToPipeline(pipeline, pm, errorHandler);
341 }
342 
343 /// Parse the given pipeline text into the internal pipeline vector. This
344 /// function only parses the structure of the pipeline, and does not resolve
345 /// its elements.
parsePipelineText(StringRef text,ErrorHandlerT errorHandler)346 LogicalResult TextualPipeline::parsePipelineText(StringRef text,
347                                                  ErrorHandlerT errorHandler) {
348   SmallVector<std::vector<PipelineElement> *, 4> pipelineStack = {&pipeline};
349   for (;;) {
350     std::vector<PipelineElement> &pipeline = *pipelineStack.back();
351     size_t pos = text.find_first_of(",(){");
352     pipeline.emplace_back(/*name=*/text.substr(0, pos).trim());
353 
354     // If we have a single terminating name, we're done.
355     if (pos == StringRef::npos)
356       break;
357 
358     text = text.substr(pos);
359     char sep = text[0];
360 
361     // Handle pulling ... from 'pass{...}' out as PipelineElement.options.
362     if (sep == '{') {
363       text = text.substr(1);
364 
365       // Skip over everything until the closing '}' and store as options.
366       size_t close = StringRef::npos;
367       for (unsigned i = 0, e = text.size(), braceCount = 1; i < e; ++i) {
368         if (text[i] == '{') {
369           ++braceCount;
370           continue;
371         }
372         if (text[i] == '}' && --braceCount == 0) {
373           close = i;
374           break;
375         }
376       }
377 
378       // Check to see if a closing options brace was found.
379       if (close == StringRef::npos) {
380         return errorHandler(
381             /*rawLoc=*/text.data() - 1,
382             "missing closing '}' while processing pass options");
383       }
384       pipeline.back().options = text.substr(0, close);
385       text = text.substr(close + 1);
386 
387       // Skip checking for '(' because nested pipelines cannot have options.
388     } else if (sep == '(') {
389       text = text.substr(1);
390 
391       // Push the inner pipeline onto the stack to continue processing.
392       pipelineStack.push_back(&pipeline.back().innerPipeline);
393       continue;
394     }
395 
396     // When handling the close parenthesis, we greedily consume them to avoid
397     // empty strings in the pipeline.
398     while (text.consume_front(")")) {
399       // If we try to pop the outer pipeline we have unbalanced parentheses.
400       if (pipelineStack.size() == 1)
401         return errorHandler(/*rawLoc=*/text.data() - 1,
402                             "encountered extra closing ')' creating unbalanced "
403                             "parentheses while parsing pipeline");
404 
405       pipelineStack.pop_back();
406     }
407 
408     // Check if we've finished parsing.
409     if (text.empty())
410       break;
411 
412     // Otherwise, the end of an inner pipeline always has to be followed by
413     // a comma, and then we can continue.
414     if (!text.consume_front(","))
415       return errorHandler(text.data(), "expected ',' after parsing pipeline");
416   }
417 
418   // Check for unbalanced parentheses.
419   if (pipelineStack.size() > 1)
420     return errorHandler(
421         text.data(),
422         "encountered unbalanced parentheses while parsing pipeline");
423 
424   assert(pipelineStack.back() == &pipeline &&
425          "wrong pipeline at the bottom of the stack");
426   return success();
427 }
428 
429 /// Resolve the elements of the pipeline, i.e. connect passes and pipelines to
430 /// the corresponding registry entry.
resolvePipelineElements(MutableArrayRef<PipelineElement> elements,ErrorHandlerT errorHandler)431 LogicalResult TextualPipeline::resolvePipelineElements(
432     MutableArrayRef<PipelineElement> elements, ErrorHandlerT errorHandler) {
433   for (auto &elt : elements)
434     if (failed(resolvePipelineElement(elt, errorHandler)))
435       return failure();
436   return success();
437 }
438 
439 /// Resolve a single element of the pipeline.
440 LogicalResult
resolvePipelineElement(PipelineElement & element,ErrorHandlerT errorHandler)441 TextualPipeline::resolvePipelineElement(PipelineElement &element,
442                                         ErrorHandlerT errorHandler) {
443   // If the inner pipeline of this element is not empty, this is an operation
444   // pipeline.
445   if (!element.innerPipeline.empty())
446     return resolvePipelineElements(element.innerPipeline, errorHandler);
447   // Otherwise, this must be a pass or pass pipeline.
448   // Check to see if a pipeline was registered with this name.
449   auto pipelineRegistryIt = passPipelineRegistry->find(element.name);
450   if (pipelineRegistryIt != passPipelineRegistry->end()) {
451     element.registryEntry = &pipelineRegistryIt->second;
452     return success();
453   }
454 
455   // If not, then this must be a specific pass name.
456   if ((element.registryEntry = Pass::lookupPassInfo(element.name)))
457     return success();
458 
459   // Emit an error for the unknown pass.
460   auto *rawLoc = element.name.data();
461   return errorHandler(rawLoc, "'" + element.name +
462                                   "' does not refer to a "
463                                   "registered pass or pass pipeline");
464 }
465 
466 /// Add the given pipeline elements to the provided pass manager.
addToPipeline(ArrayRef<PipelineElement> elements,OpPassManager & pm,function_ref<LogicalResult (const Twine &)> errorHandler) const467 LogicalResult TextualPipeline::addToPipeline(
468     ArrayRef<PipelineElement> elements, OpPassManager &pm,
469     function_ref<LogicalResult(const Twine &)> errorHandler) const {
470   for (auto &elt : elements) {
471     if (elt.registryEntry) {
472       if (failed(elt.registryEntry->addToPipeline(pm, elt.options,
473                                                   errorHandler))) {
474         return errorHandler("failed to add `" + elt.name + "` with options `" +
475                             elt.options + "`");
476       }
477     } else if (failed(addToPipeline(elt.innerPipeline, pm.nest(elt.name),
478                                     errorHandler))) {
479       return errorHandler("failed to add `" + elt.name + "` with options `" +
480                           elt.options + "` to inner pipeline");
481     }
482   }
483   return success();
484 }
485 
486 /// This function parses the textual representation of a pass pipeline, and adds
487 /// the result to 'pm' on success. This function returns failure if the given
488 /// pipeline was invalid. 'errorStream' is an optional parameter that, if
489 /// non-null, will be used to emit errors found during parsing.
parsePassPipeline(StringRef pipeline,OpPassManager & pm,raw_ostream & errorStream)490 LogicalResult mlir::parsePassPipeline(StringRef pipeline, OpPassManager &pm,
491                                       raw_ostream &errorStream) {
492   TextualPipeline pipelineParser;
493   if (failed(pipelineParser.initialize(pipeline, errorStream)))
494     return failure();
495   auto errorHandler = [&](Twine msg) {
496     errorStream << msg << "\n";
497     return failure();
498   };
499   if (failed(pipelineParser.addToPipeline(pm, errorHandler)))
500     return failure();
501   return success();
502 }
503 
504 //===----------------------------------------------------------------------===//
505 // PassNameParser
506 //===----------------------------------------------------------------------===//
507 
508 namespace {
509 /// This struct represents the possible data entries in a parsed pass pipeline
510 /// list.
511 struct PassArgData {
PassArgData__anone08a6f2d0b11::PassArgData512   PassArgData() : registryEntry(nullptr) {}
PassArgData__anone08a6f2d0b11::PassArgData513   PassArgData(const PassRegistryEntry *registryEntry)
514       : registryEntry(registryEntry) {}
515 
516   /// This field is used when the parsed option corresponds to a registered pass
517   /// or pass pipeline.
518   const PassRegistryEntry *registryEntry;
519 
520   /// This field is set when instance specific pass options have been provided
521   /// on the command line.
522   StringRef options;
523 
524   /// This field is used when the parsed option corresponds to an explicit
525   /// pipeline.
526   TextualPipeline pipeline;
527 };
528 } // end anonymous namespace
529 
530 namespace llvm {
531 namespace cl {
532 /// Define a valid OptionValue for the command line pass argument.
533 template <>
534 struct OptionValue<PassArgData> final
535     : OptionValueBase<PassArgData, /*isClass=*/true> {
OptionValuellvm::cl::OptionValue536   OptionValue(const PassArgData &value) { this->setValue(value); }
537   OptionValue() = default;
anchorllvm::cl::OptionValue538   void anchor() override {}
539 
hasValuellvm::cl::OptionValue540   bool hasValue() const { return true; }
getValuellvm::cl::OptionValue541   const PassArgData &getValue() const { return value; }
setValuellvm::cl::OptionValue542   void setValue(const PassArgData &value) { this->value = value; }
543 
544   PassArgData value;
545 };
546 } // end namespace cl
547 } // end namespace llvm
548 
549 namespace {
550 
551 /// The name for the command line option used for parsing the textual pass
552 /// pipeline.
553 static constexpr StringLiteral passPipelineArg = "pass-pipeline";
554 
555 /// Adds command line option for each registered pass or pass pipeline, as well
556 /// as textual pass pipelines.
557 struct PassNameParser : public llvm::cl::parser<PassArgData> {
PassNameParser__anone08a6f2d0c11::PassNameParser558   PassNameParser(llvm::cl::Option &opt) : llvm::cl::parser<PassArgData>(opt) {}
559 
560   void initialize();
561   void printOptionInfo(const llvm::cl::Option &opt,
562                        size_t globalWidth) const override;
563   size_t getOptionWidth(const llvm::cl::Option &opt) const override;
564   bool parse(llvm::cl::Option &opt, StringRef argName, StringRef arg,
565              PassArgData &value);
566 
567   /// If true, this parser only parses entries that correspond to a concrete
568   /// pass registry entry, and does not add a `pass-pipeline` argument, does not
569   /// include the options for pass entries, and does not include pass pipelines
570   /// entries.
571   bool passNamesOnly = false;
572 };
573 } // namespace
574 
initialize()575 void PassNameParser::initialize() {
576   llvm::cl::parser<PassArgData>::initialize();
577 
578   /// Add an entry for the textual pass pipeline option.
579   if (!passNamesOnly) {
580     addLiteralOption(passPipelineArg, PassArgData(),
581                      "A textual description of a pass pipeline to run");
582   }
583 
584   /// Add the pass entries.
585   for (const auto &kv : *passRegistry) {
586     addLiteralOption(kv.second.getPassArgument(), &kv.second,
587                      kv.second.getPassDescription());
588   }
589   /// Add the pass pipeline entries.
590   if (!passNamesOnly) {
591     for (const auto &kv : *passPipelineRegistry) {
592       addLiteralOption(kv.second.getPassArgument(), &kv.second,
593                        kv.second.getPassDescription());
594     }
595   }
596 }
597 
printOptionInfo(const llvm::cl::Option & opt,size_t globalWidth) const598 void PassNameParser::printOptionInfo(const llvm::cl::Option &opt,
599                                      size_t globalWidth) const {
600   // If this parser is just parsing pass names, print a simplified option
601   // string.
602   if (passNamesOnly) {
603     llvm::outs() << "  --" << opt.ArgStr << "=<pass-arg>";
604     opt.printHelpStr(opt.HelpStr, globalWidth, opt.ArgStr.size() + 18);
605     return;
606   }
607 
608   // Print the information for the top-level option.
609   if (opt.hasArgStr()) {
610     llvm::outs() << "  --" << opt.ArgStr;
611     opt.printHelpStr(opt.HelpStr, globalWidth, opt.ArgStr.size() + 7);
612   } else {
613     llvm::outs() << "  " << opt.HelpStr << '\n';
614   }
615 
616   // Print the top-level pipeline argument.
617   printOptionHelp(passPipelineArg,
618                   "A textual description of a pass pipeline to run",
619                   /*indent=*/4, globalWidth, /*isTopLevel=*/!opt.hasArgStr());
620 
621   // Functor used to print the ordered entries of a registration map.
622   auto printOrderedEntries = [&](StringRef header, auto &map) {
623     llvm::SmallVector<PassRegistryEntry *, 32> orderedEntries;
624     for (auto &kv : map)
625       orderedEntries.push_back(&kv.second);
626     llvm::array_pod_sort(
627         orderedEntries.begin(), orderedEntries.end(),
628         [](PassRegistryEntry *const *lhs, PassRegistryEntry *const *rhs) {
629           return (*lhs)->getPassArgument().compare((*rhs)->getPassArgument());
630         });
631 
632     llvm::outs().indent(4) << header << ":\n";
633     for (PassRegistryEntry *entry : orderedEntries)
634       entry->printHelpStr(/*indent=*/6, globalWidth);
635   };
636 
637   // Print the available passes.
638   printOrderedEntries("Passes", *passRegistry);
639 
640   // Print the available pass pipelines.
641   if (!passPipelineRegistry->empty())
642     printOrderedEntries("Pass Pipelines", *passPipelineRegistry);
643 }
644 
getOptionWidth(const llvm::cl::Option & opt) const645 size_t PassNameParser::getOptionWidth(const llvm::cl::Option &opt) const {
646   size_t maxWidth = llvm::cl::parser<PassArgData>::getOptionWidth(opt) + 2;
647 
648   // Check for any wider pass or pipeline options.
649   for (auto &entry : *passRegistry)
650     maxWidth = std::max(maxWidth, entry.second.getOptionWidth() + 4);
651   for (auto &entry : *passPipelineRegistry)
652     maxWidth = std::max(maxWidth, entry.second.getOptionWidth() + 4);
653   return maxWidth;
654 }
655 
parse(llvm::cl::Option & opt,StringRef argName,StringRef arg,PassArgData & value)656 bool PassNameParser::parse(llvm::cl::Option &opt, StringRef argName,
657                            StringRef arg, PassArgData &value) {
658   // Handle the pipeline option explicitly.
659   if (argName == passPipelineArg)
660     return failed(value.pipeline.initialize(arg, llvm::errs()));
661 
662   // Otherwise, default to the base for handling.
663   if (llvm::cl::parser<PassArgData>::parse(opt, argName, arg, value))
664     return true;
665   value.options = arg;
666   return false;
667 }
668 
669 //===----------------------------------------------------------------------===//
670 // PassPipelineCLParser
671 //===----------------------------------------------------------------------===//
672 
673 namespace mlir {
674 namespace detail {
675 struct PassPipelineCLParserImpl {
PassPipelineCLParserImplmlir::detail::PassPipelineCLParserImpl676   PassPipelineCLParserImpl(StringRef arg, StringRef description,
677                            bool passNamesOnly)
678       : passList(arg, llvm::cl::desc(description)) {
679     passList.getParser().passNamesOnly = passNamesOnly;
680     passList.setValueExpectedFlag(llvm::cl::ValueExpected::ValueOptional);
681   }
682 
683   /// Returns true if the given pass registry entry was registered at the
684   /// top-level of the parser, i.e. not within an explicit textual pipeline.
containsmlir::detail::PassPipelineCLParserImpl685   bool contains(const PassRegistryEntry *entry) const {
686     return llvm::any_of(passList, [&](const PassArgData &data) {
687       return data.registryEntry == entry;
688     });
689   }
690 
691   /// The set of passes and pass pipelines to run.
692   llvm::cl::list<PassArgData, bool, PassNameParser> passList;
693 };
694 } // end namespace detail
695 } // end namespace mlir
696 
697 /// Construct a pass pipeline parser with the given command line description.
PassPipelineCLParser(StringRef arg,StringRef description)698 PassPipelineCLParser::PassPipelineCLParser(StringRef arg, StringRef description)
699     : impl(std::make_unique<detail::PassPipelineCLParserImpl>(
700           arg, description, /*passNamesOnly=*/false)) {}
~PassPipelineCLParser()701 PassPipelineCLParser::~PassPipelineCLParser() {}
702 
703 /// Returns true if this parser contains any valid options to add.
hasAnyOccurrences() const704 bool PassPipelineCLParser::hasAnyOccurrences() const {
705   return impl->passList.getNumOccurrences() != 0;
706 }
707 
708 /// Returns true if the given pass registry entry was registered at the
709 /// top-level of the parser, i.e. not within an explicit textual pipeline.
contains(const PassRegistryEntry * entry) const710 bool PassPipelineCLParser::contains(const PassRegistryEntry *entry) const {
711   return impl->contains(entry);
712 }
713 
714 /// Adds the passes defined by this parser entry to the given pass manager.
addToPipeline(OpPassManager & pm,function_ref<LogicalResult (const Twine &)> errorHandler) const715 LogicalResult PassPipelineCLParser::addToPipeline(
716     OpPassManager &pm,
717     function_ref<LogicalResult(const Twine &)> errorHandler) const {
718   for (auto &passIt : impl->passList) {
719     if (passIt.registryEntry) {
720       if (failed(passIt.registryEntry->addToPipeline(pm, passIt.options,
721                                                      errorHandler)))
722         return failure();
723     } else {
724       OpPassManager::Nesting nesting = pm.getNesting();
725       pm.setNesting(OpPassManager::Nesting::Explicit);
726       LogicalResult status = passIt.pipeline.addToPipeline(pm, errorHandler);
727       pm.setNesting(nesting);
728       if (failed(status))
729         return failure();
730     }
731   }
732   return success();
733 }
734 
//===----------------------------------------------------------------------===//
// PassNameCLParser
//===----------------------------------------------------------------------===//
737 
738 /// Construct a pass pipeline parser with the given command line description.
PassNameCLParser(StringRef arg,StringRef description)739 PassNameCLParser::PassNameCLParser(StringRef arg, StringRef description)
740     : impl(std::make_unique<detail::PassPipelineCLParserImpl>(
741           arg, description, /*passNamesOnly=*/true)) {
742   impl->passList.setMiscFlag(llvm::cl::CommaSeparated);
743 }
~PassNameCLParser()744 PassNameCLParser::~PassNameCLParser() {}
745 
746 /// Returns true if this parser contains any valid options to add.
hasAnyOccurrences() const747 bool PassNameCLParser::hasAnyOccurrences() const {
748   return impl->passList.getNumOccurrences() != 0;
749 }
750 
751 /// Returns true if the given pass registry entry was registered at the
752 /// top-level of the parser, i.e. not within an explicit textual pipeline.
contains(const PassRegistryEntry * entry) const753 bool PassNameCLParser::contains(const PassRegistryEntry *entry) const {
754   return impl->contains(entry);
755 }
756