1 //===- Translation.cpp - Translation registry -----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Definitions of the translation registry.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Translation.h"
14 #include "mlir/IR/AsmState.h"
15 #include "mlir/IR/BuiltinOps.h"
16 #include "mlir/IR/Dialect.h"
17 #include "mlir/IR/Verifier.h"
18 #include "mlir/Parser.h"
19 #include "mlir/Support/FileUtilities.h"
20 #include "mlir/Support/ToolUtilities.h"
21 #include "llvm/Support/InitLLVM.h"
22 #include "llvm/Support/SourceMgr.h"
23 #include "llvm/Support/ToolOutputFile.h"
24 
25 using namespace mlir;
26 
27 //===----------------------------------------------------------------------===//
28 // Translation Registry
29 //===----------------------------------------------------------------------===//
30 
31 /// Get the mutable static map between registered file-to-file MLIR translations
32 /// and the TranslateFunctions that perform those translations.
getTranslationRegistry()33 static llvm::StringMap<TranslateFunction> &getTranslationRegistry() {
34   static llvm::StringMap<TranslateFunction> translationRegistry;
35   return translationRegistry;
36 }
37 
38 /// Register the given translation.
registerTranslation(StringRef name,const TranslateFunction & function)39 static void registerTranslation(StringRef name,
40                                 const TranslateFunction &function) {
41   auto &translationRegistry = getTranslationRegistry();
42   if (translationRegistry.find(name) != translationRegistry.end())
43     llvm::report_fatal_error(
44         "Attempting to overwrite an existing <file-to-file> function");
45   assert(function &&
46          "Attempting to register an empty translate <file-to-file> function");
47   translationRegistry[name] = function;
48 }
49 
TranslateRegistration(StringRef name,const TranslateFunction & function)50 TranslateRegistration::TranslateRegistration(
51     StringRef name, const TranslateFunction &function) {
52   registerTranslation(name, function);
53 }
54 
55 //===----------------------------------------------------------------------===//
56 // Translation to MLIR
57 //===----------------------------------------------------------------------===//
58 
59 // Puts `function` into the to-MLIR translation registry unless there is already
60 // a function registered for the same name.
registerTranslateToMLIRFunction(StringRef name,const TranslateSourceMgrToMLIRFunction & function)61 static void registerTranslateToMLIRFunction(
62     StringRef name, const TranslateSourceMgrToMLIRFunction &function) {
63   auto wrappedFn = [function](llvm::SourceMgr &sourceMgr, raw_ostream &output,
64                               MLIRContext *context) {
65     OwningModuleRef module = function(sourceMgr, context);
66     if (!module || failed(verify(*module)))
67       return failure();
68     module->print(output);
69     return success();
70   };
71   registerTranslation(name, wrappedFn);
72 }
73 
TranslateToMLIRRegistration(StringRef name,const TranslateSourceMgrToMLIRFunction & function)74 TranslateToMLIRRegistration::TranslateToMLIRRegistration(
75     StringRef name, const TranslateSourceMgrToMLIRFunction &function) {
76   registerTranslateToMLIRFunction(name, function);
77 }
78 
79 /// Wraps `function` with a lambda that extracts a StringRef from a source
80 /// manager and registers the wrapper lambda as a to-MLIR conversion.
TranslateToMLIRRegistration(StringRef name,const TranslateStringRefToMLIRFunction & function)81 TranslateToMLIRRegistration::TranslateToMLIRRegistration(
82     StringRef name, const TranslateStringRefToMLIRFunction &function) {
83   registerTranslateToMLIRFunction(
84       name, [function](llvm::SourceMgr &sourceMgr, MLIRContext *ctx) {
85         const llvm::MemoryBuffer *buffer =
86             sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID());
87         return function(buffer->getBuffer(), ctx);
88       });
89 }
90 
91 //===----------------------------------------------------------------------===//
92 // Translation from MLIR
93 //===----------------------------------------------------------------------===//
94 
TranslateFromMLIRRegistration(StringRef name,const TranslateFromMLIRFunction & function,std::function<void (DialectRegistry &)> dialectRegistration)95 TranslateFromMLIRRegistration::TranslateFromMLIRRegistration(
96     StringRef name, const TranslateFromMLIRFunction &function,
97     std::function<void(DialectRegistry &)> dialectRegistration) {
98   registerTranslation(name, [function, dialectRegistration](
99                                 llvm::SourceMgr &sourceMgr, raw_ostream &output,
100                                 MLIRContext *context) {
101     DialectRegistry registry;
102     dialectRegistration(registry);
103     context->appendDialectRegistry(registry);
104     auto module = OwningModuleRef(parseSourceFile(sourceMgr, context));
105     if (!module)
106       return failure();
107     return function(module.get(), output);
108   });
109 }
110 
111 //===----------------------------------------------------------------------===//
112 // Translation Parser
113 //===----------------------------------------------------------------------===//
114 
TranslationParser(llvm::cl::Option & opt)115 TranslationParser::TranslationParser(llvm::cl::Option &opt)
116     : llvm::cl::parser<const TranslateFunction *>(opt) {
117   for (const auto &kv : getTranslationRegistry())
118     addLiteralOption(kv.first(), &kv.second, kv.first());
119 }
120 
printOptionInfo(const llvm::cl::Option & o,size_t globalWidth) const121 void TranslationParser::printOptionInfo(const llvm::cl::Option &o,
122                                         size_t globalWidth) const {
123   TranslationParser *tp = const_cast<TranslationParser *>(this);
124   llvm::array_pod_sort(tp->Values.begin(), tp->Values.end(),
125                        [](const TranslationParser::OptionInfo *lhs,
126                           const TranslationParser::OptionInfo *rhs) {
127                          return lhs->Name.compare(rhs->Name);
128                        });
129   llvm::cl::parser<const TranslateFunction *>::printOptionInfo(o, globalWidth);
130 }
131 
mlirTranslateMain(int argc,char ** argv,llvm::StringRef toolName)132 LogicalResult mlir::mlirTranslateMain(int argc, char **argv,
133                                       llvm::StringRef toolName) {
134 
135   static llvm::cl::opt<std::string> inputFilename(
136       llvm::cl::Positional, llvm::cl::desc("<input file>"),
137       llvm::cl::init("-"));
138 
139   static llvm::cl::opt<std::string> outputFilename(
140       "o", llvm::cl::desc("Output filename"), llvm::cl::value_desc("filename"),
141       llvm::cl::init("-"));
142 
143   static llvm::cl::opt<bool> splitInputFile(
144       "split-input-file",
145       llvm::cl::desc("Split the input file into pieces and "
146                      "process each chunk independently"),
147       llvm::cl::init(false));
148 
149   static llvm::cl::opt<bool> verifyDiagnostics(
150       "verify-diagnostics",
151       llvm::cl::desc("Check that emitted diagnostics match "
152                      "expected-* lines on the corresponding line"),
153       llvm::cl::init(false));
154 
155   llvm::InitLLVM y(argc, argv);
156 
157   // Add flags for all the registered translations.
158   llvm::cl::opt<const TranslateFunction *, false, TranslationParser>
159       translationRequested("", llvm::cl::desc("Translation to perform"),
160                            llvm::cl::Required);
161   registerAsmPrinterCLOptions();
162   registerMLIRContextCLOptions();
163   llvm::cl::ParseCommandLineOptions(argc, argv, toolName);
164 
165   std::string errorMessage;
166   auto input = openInputFile(inputFilename, &errorMessage);
167   if (!input) {
168     llvm::errs() << errorMessage << "\n";
169     return failure();
170   }
171 
172   auto output = openOutputFile(outputFilename, &errorMessage);
173   if (!output) {
174     llvm::errs() << errorMessage << "\n";
175     return failure();
176   }
177 
178   // Processes the memory buffer with a new MLIRContext.
179   auto processBuffer = [&](std::unique_ptr<llvm::MemoryBuffer> ownedBuffer,
180                            raw_ostream &os) {
181     MLIRContext context;
182     context.printOpOnDiagnostic(!verifyDiagnostics);
183     llvm::SourceMgr sourceMgr;
184     sourceMgr.AddNewSourceBuffer(std::move(ownedBuffer), llvm::SMLoc());
185 
186     if (!verifyDiagnostics) {
187       SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &context);
188       return (*translationRequested)(sourceMgr, os, &context);
189     }
190 
191     // In the diagnostic verification flow, we ignore whether the translation
192     // failed (in most cases, it is expected to fail). Instead, we check if the
193     // diagnostics were produced as expected.
194     SourceMgrDiagnosticVerifierHandler sourceMgrHandler(sourceMgr, &context);
195     (void)(*translationRequested)(sourceMgr, os, &context);
196     return sourceMgrHandler.verify();
197   };
198 
199   if (splitInputFile) {
200     if (failed(splitAndProcessBuffer(std::move(input), processBuffer,
201                                      output->os())))
202       return failure();
203   } else if (failed(processBuffer(std::move(input), output->os()))) {
204     return failure();
205   }
206 
207   output->keep();
208   return success();
209 }
210