1 //===- Translation.cpp - Translation registry -----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Definitions of the translation registry.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Translation.h"
14 #include "mlir/IR/AsmState.h"
15 #include "mlir/IR/BuiltinOps.h"
16 #include "mlir/IR/Verifier.h"
17 #include "mlir/Parser.h"
18 #include "mlir/Support/FileUtilities.h"
19 #include "mlir/Support/ToolUtilities.h"
20 #include "llvm/Support/InitLLVM.h"
21 #include "llvm/Support/SourceMgr.h"
22 #include "llvm/Support/ToolOutputFile.h"
23 
24 using namespace mlir;
25 
26 //===----------------------------------------------------------------------===//
27 // Translation Registry
28 //===----------------------------------------------------------------------===//
29 
30 /// Get the mutable static map between registered file-to-file MLIR translations
31 /// and the TranslateFunctions that perform those translations.
getTranslationRegistry()32 static llvm::StringMap<TranslateFunction> &getTranslationRegistry() {
33   static llvm::StringMap<TranslateFunction> translationRegistry;
34   return translationRegistry;
35 }
36 
37 /// Register the given translation.
registerTranslation(StringRef name,const TranslateFunction & function)38 static void registerTranslation(StringRef name,
39                                 const TranslateFunction &function) {
40   auto &translationRegistry = getTranslationRegistry();
41   if (translationRegistry.find(name) != translationRegistry.end())
42     llvm::report_fatal_error(
43         "Attempting to overwrite an existing <file-to-file> function");
44   assert(function &&
45          "Attempting to register an empty translate <file-to-file> function");
46   translationRegistry[name] = function;
47 }
48 
TranslateRegistration(StringRef name,const TranslateFunction & function)49 TranslateRegistration::TranslateRegistration(
50     StringRef name, const TranslateFunction &function) {
51   registerTranslation(name, function);
52 }
53 
54 //===----------------------------------------------------------------------===//
55 // Translation to MLIR
56 //===----------------------------------------------------------------------===//
57 
58 // Puts `function` into the to-MLIR translation registry unless there is already
59 // a function registered for the same name.
registerTranslateToMLIRFunction(StringRef name,const TranslateSourceMgrToMLIRFunction & function)60 static void registerTranslateToMLIRFunction(
61     StringRef name, const TranslateSourceMgrToMLIRFunction &function) {
62   auto wrappedFn = [function](llvm::SourceMgr &sourceMgr, raw_ostream &output,
63                               MLIRContext *context) {
64     OwningModuleRef module = function(sourceMgr, context);
65     if (!module || failed(verify(*module)))
66       return failure();
67     module->print(output);
68     return success();
69   };
70   registerTranslation(name, wrappedFn);
71 }
72 
TranslateToMLIRRegistration(StringRef name,const TranslateSourceMgrToMLIRFunction & function)73 TranslateToMLIRRegistration::TranslateToMLIRRegistration(
74     StringRef name, const TranslateSourceMgrToMLIRFunction &function) {
75   registerTranslateToMLIRFunction(name, function);
76 }
77 
78 /// Wraps `function` with a lambda that extracts a StringRef from a source
79 /// manager and registers the wrapper lambda as a to-MLIR conversion.
TranslateToMLIRRegistration(StringRef name,const TranslateStringRefToMLIRFunction & function)80 TranslateToMLIRRegistration::TranslateToMLIRRegistration(
81     StringRef name, const TranslateStringRefToMLIRFunction &function) {
82   registerTranslateToMLIRFunction(
83       name, [function](llvm::SourceMgr &sourceMgr, MLIRContext *ctx) {
84         const llvm::MemoryBuffer *buffer =
85             sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID());
86         return function(buffer->getBuffer(), ctx);
87       });
88 }
89 
90 //===----------------------------------------------------------------------===//
91 // Translation from MLIR
92 //===----------------------------------------------------------------------===//
93 
TranslateFromMLIRRegistration(StringRef name,const TranslateFromMLIRFunction & function,std::function<void (DialectRegistry &)> dialectRegistration)94 TranslateFromMLIRRegistration::TranslateFromMLIRRegistration(
95     StringRef name, const TranslateFromMLIRFunction &function,
96     std::function<void(DialectRegistry &)> dialectRegistration) {
97   registerTranslation(name, [function, dialectRegistration](
98                                 llvm::SourceMgr &sourceMgr, raw_ostream &output,
99                                 MLIRContext *context) {
100     dialectRegistration(context->getDialectRegistry());
101     auto module = OwningModuleRef(parseSourceFile(sourceMgr, context));
102     if (!module)
103       return failure();
104     return function(module.get(), output);
105   });
106 }
107 
108 //===----------------------------------------------------------------------===//
109 // Translation Parser
110 //===----------------------------------------------------------------------===//
111 
TranslationParser(llvm::cl::Option & opt)112 TranslationParser::TranslationParser(llvm::cl::Option &opt)
113     : llvm::cl::parser<const TranslateFunction *>(opt) {
114   for (const auto &kv : getTranslationRegistry())
115     addLiteralOption(kv.first(), &kv.second, kv.first());
116 }
117 
printOptionInfo(const llvm::cl::Option & o,size_t globalWidth) const118 void TranslationParser::printOptionInfo(const llvm::cl::Option &o,
119                                         size_t globalWidth) const {
120   TranslationParser *tp = const_cast<TranslationParser *>(this);
121   llvm::array_pod_sort(tp->Values.begin(), tp->Values.end(),
122                        [](const TranslationParser::OptionInfo *lhs,
123                           const TranslationParser::OptionInfo *rhs) {
124                          return lhs->Name.compare(rhs->Name);
125                        });
126   llvm::cl::parser<const TranslateFunction *>::printOptionInfo(o, globalWidth);
127 }
128 
/// Entry point for an mlir-translate-style tool.
///
/// The function:
/// - Registers and parses the command-line options: the positional input
///   file, `-o`, `-split-input-file`, `-verify-diagnostics`, the asm printer
///   and MLIRContext options, and a required flag selecting one registered
///   translation.
/// - Opens the input and output files.
/// - Runs the selected translation over the input, or over each chunk of it
///   when `-split-input-file` is set. Each buffer gets a fresh MLIRContext.
///
/// \param argc, argv The process command line.
/// \param toolName   Overview string passed to ParseCommandLineOptions.
/// \returns failure() if the input or output file cannot be opened or if
///          processing fails; success() otherwise. ToolOutputFile::keep() is
///          only called on the success path.
LogicalResult mlir::mlirTranslateMain(int argc, char **argv,
                                      llvm::StringRef toolName) {

  // Tool options. As function-local statics they are constructed, and thereby
  // registered with llvm::cl, on the first call to this function.
  static llvm::cl::opt<std::string> inputFilename(
      llvm::cl::Positional, llvm::cl::desc("<input file>"),
      llvm::cl::init("-"));

  static llvm::cl::opt<std::string> outputFilename(
      "o", llvm::cl::desc("Output filename"), llvm::cl::value_desc("filename"),
      llvm::cl::init("-"));

  static llvm::cl::opt<bool> splitInputFile(
      "split-input-file",
      llvm::cl::desc("Split the input file into pieces and "
                     "process each chunk independently"),
      llvm::cl::init(false));

  static llvm::cl::opt<bool> verifyDiagnostics(
      "verify-diagnostics",
      llvm::cl::desc("Check that emitted diagnostics match "
                     "expected-* lines on the corresponding line"),
      llvm::cl::init(false));

  llvm::InitLLVM y(argc, argv);

  // Add flags for all the registered translations.
  // TranslationParser's constructor reads the translation registry. The set of
  // accepted translation flags is therefore whatever has been registered by
  // the time this line runs.
  llvm::cl::opt<const TranslateFunction *, false, TranslationParser>
      translationRequested("", llvm::cl::desc("Translation to perform"),
                           llvm::cl::Required);
  registerAsmPrinterCLOptions();
  registerMLIRContextCLOptions();
  // All options must be registered before this call to be recognized.
  llvm::cl::ParseCommandLineOptions(argc, argv, toolName);

  std::string errorMessage;
  auto input = openInputFile(inputFilename, &errorMessage);
  if (!input) {
    llvm::errs() << errorMessage << "\n";
    return failure();
  }

  auto output = openOutputFile(outputFilename, &errorMessage);
  if (!output) {
    llvm::errs() << errorMessage << "\n";
    return failure();
  }

  // Processes the memory buffer with a new MLIRContext.
  // Each call builds its own context and SourceMgr, so no IR or diagnostic
  // state is shared between chunks when -split-input-file is used.
  auto processBuffer = [&](std::unique_ptr<llvm::MemoryBuffer> ownedBuffer,
                           raw_ostream &os) {
    MLIRContext context;
    // Op printing on diagnostics is disabled in -verify-diagnostics mode.
    // NOTE(review): presumably so the attached op dump does not interfere
    // with expected-* matching -- confirm.
    context.printOpOnDiagnostic(!verifyDiagnostics);
    llvm::SourceMgr sourceMgr;
    sourceMgr.AddNewSourceBuffer(std::move(ownedBuffer), llvm::SMLoc());

    if (!verifyDiagnostics) {
      // Report diagnostics against the source buffer and propagate the
      // translation's own result.
      SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &context);
      return (*translationRequested)(sourceMgr, os, &context);
    }

    // In the diagnostic verification flow, we ignore whether the translation
    // failed (in most cases, it is expected to fail). Instead, we check if the
    // diagnostics were produced as expected.
    SourceMgrDiagnosticVerifierHandler sourceMgrHandler(sourceMgr, &context);
    (*translationRequested)(sourceMgr, os, &context);
    return sourceMgrHandler.verify();
  };

  // With -split-input-file, splitAndProcessBuffer is given processBuffer as
  // the per-chunk callback; otherwise the whole input is processed at once.
  if (splitInputFile) {
    if (failed(splitAndProcessBuffer(std::move(input), processBuffer,
                                     output->os())))
      return failure();
  } else if (failed(processBuffer(std::move(input), output->os()))) {
    return failure();
  }

  // Only reached on success. Without keep(), ToolOutputFile deletes the
  // output file when it is destroyed.
  output->keep();
  return success();
}
207