# Chapter 6: Lowering to LLVM and Code Generation

[TOC]

In the [previous chapter](Ch-5.md), we introduced the
[dialect conversion](../../DialectConversion.md) framework and partially lowered
many of the `Toy` operations to affine loop nests for optimization. In this
chapter, we will finally lower to LLVM for code generation.

## Lowering to LLVM

For this lowering, we will again use the dialect conversion framework to perform
the heavy lifting. However, this time, we will be performing a full conversion
to the [LLVM dialect](../../Dialects/LLVM.md). Thankfully, we have already
lowered all but one of the `toy` operations, with the last being `toy.print`.
Before going over the conversion to LLVM, let's lower the `toy.print` operation.
We will lower this operation to a non-affine loop nest that invokes `printf` for
each element. Note that, because the dialect conversion framework supports
[transitive lowering](../../../getting_started/Glossary.md/#transitive-lowering), we don't need to
directly emit operations in the LLVM dialect. By transitive lowering, we mean
that the conversion framework may apply multiple patterns to fully legalize an
operation. In this example, we are generating a structured loop nest instead of
the branch form in the LLVM dialect. As long as we then have a lowering from the
loop operations to LLVM, the lowering will still succeed.

During lowering, we can get or build the declaration for `printf` as follows:

```c++
/// Return a symbol reference to the printf function, inserting it into the
/// module if necessary.
static FlatSymbolRefAttr getOrInsertPrintf(PatternRewriter &rewriter,
                                           ModuleOp module,
                                           LLVM::LLVMDialect *llvmDialect) {
  auto *context = module.getContext();
  if (module.lookupSymbol<LLVM::LLVMFuncOp>("printf"))
    return SymbolRefAttr::get("printf", context);

  // Create a function declaration for printf, the signature is:
  //   * `i32 (i8*, ...)`
  auto llvmI32Ty = IntegerType::get(context, 32);
  auto llvmI8PtrTy =
      LLVM::LLVMPointerType::get(IntegerType::get(context, 8));
  auto llvmFnType = LLVM::LLVMFunctionType::get(llvmI32Ty, llvmI8PtrTy,
                                                /*isVarArg=*/true);

  // Insert the printf function into the body of the parent module.
  PatternRewriter::InsertionGuard insertGuard(rewriter);
  rewriter.setInsertionPointToStart(module.getBody());
  rewriter.create<LLVM::LLVMFuncOp>(module.getLoc(), "printf", llvmFnType);
  return SymbolRefAttr::get("printf", context);
}
```
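To make the goal of the `toy.print` lowering concrete, the following plain C++ sketch mimics what the generated loop nest does for a rank-2 buffer: one `printf`-style call per element, followed by a newline after each row. It is not taken from the tutorial sources. The function name `formatTensor`, the row-major layout, and the `"%f "` format string are assumptions made for illustration (the format string is inferred from the program output shown later in this chapter), and the text is collected in a string instead of being written to standard output so that it can be inspected.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdio>
#include <string>
#include <vector>

// Hypothetical helper: renders a row-major rows x cols buffer the way the
// lowered loop nest would print it, but into a string instead of stdout.
std::string formatTensor(const std::vector<double> &data, std::size_t rows,
                         std::size_t cols) {
  assert(data.size() == rows * cols && "buffer size must match the shape");
  std::string out;
  char buffer[512]; // large enough for any double formatted with "%f "
  for (std::size_t i = 0; i < rows; ++i) {   // outer loop: one row per trip
    for (std::size_t j = 0; j < cols; ++j) { // inner loop: one element per trip
      // Stands in for the per-element `printf("%f ", element)` call.
      std::snprintf(buffer, sizeof(buffer), "%f ", data[i * cols + j]);
      out += buffer;
    }
    out += '\n'; // newline after every row
  }
  return out;
}
```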

Now that the lowering for the `toy.print` operation has been defined, we can
specify the components necessary for the lowering. These are largely the same as
the components defined in the [previous chapter](Ch-5.md).

### Conversion Target

For this conversion, aside from the top-level module, we will be lowering
everything to the LLVM dialect.

```c++
  mlir::ConversionTarget target(getContext());
  target.addLegalDialect<mlir::LLVM::LLVMDialect>();
  target.addLegalOp<mlir::ModuleOp>();
```
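As a rough mental model of what these two calls declare, the sketch below treats a conversion target as a set of legal dialect prefixes plus a set of individually legal operation names. It is a simplified stand-in written in plain C++, not MLIR's `ConversionTarget`. The `ToyTarget` type, its methods, and `makeLLVMOnlyTarget` are hypothetical, and spelling the top-level module operation as `module` is an assumption.

```cpp
#include <cassert>
#include <set>
#include <string>

// Hypothetical, simplified model of a conversion target.
struct ToyTarget {
  std::set<std::string> legalDialects; // dialect prefixes, e.g. "llvm"
  std::set<std::string> legalOps;      // individually legal op names

  void addLegalDialect(const std::string &dialect) {
    legalDialects.insert(dialect);
  }
  void addLegalOp(const std::string &op) { legalOps.insert(op); }

  // An op is legal if it was listed explicitly, or if its dialect prefix
  // (the text before the first '.') names a legal dialect.
  bool isLegal(const std::string &op) const {
    if (legalOps.count(op) != 0)
      return true;
    const std::string::size_type dot = op.find('.');
    if (dot == std::string::npos)
      return false;
    return legalDialects.count(op.substr(0, dot)) != 0;
  }
};

// Builds the target described in this section: everything in the LLVM
// dialect is legal, plus the top-level module operation.
ToyTarget makeLLVMOnlyTarget() {
  ToyTarget target;
  target.addLegalDialect("llvm");
  target.addLegalOp("module"); // assumed spelling of the module op name
  return target;
}
```

With this model, `makeLLVMOnlyTarget().isLegal("llvm.add")` is true, while `toy.print` and `affine.for` are reported as illegal and therefore must be rewritten by some pattern.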

### Type Converter

This lowering will also transform the MemRef types which are currently being
operated on into a representation in LLVM. To perform this conversion, we use a
TypeConverter as part of the lowering. This converter specifies how one type
maps to another. This is necessary now that we are performing more complicated
lowerings involving block arguments. Given that we don't have any
Toy-dialect-specific types that need to be lowered, the default converter is
enough for our use case.

```c++
  LLVMTypeConverter typeConverter(&getContext());
```

### Conversion Patterns

Now that the conversion target has been defined, we need to provide the patterns
used for lowering. At this point in the compilation process, we have a
combination of `toy`, `affine`, and `std` operations. Luckily, the `std` and
`affine` dialects already provide the set of patterns needed to transform them
into LLVM dialect. These patterns allow for lowering the IR in multiple stages
by relying on [transitive lowering](../../../getting_started/Glossary.md/#transitive-lowering).

```c++
  mlir::RewritePatternSet patterns(&getContext());
  mlir::populateAffineToStdConversionPatterns(patterns, &getContext());
  mlir::populateLoopToStdConversionPatterns(patterns, &getContext());
  mlir::populateStdToLLVMConversionPatterns(typeConverter, patterns);

  // The only remaining operation, to lower from the `toy` dialect, is the
  // PrintOp.
  patterns.add<PrintOpLowering>(&getContext());
```
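The effect of transitive lowering can be sketched with a toy rewrite driver in plain C++. Each "pattern" replaces one operation name with a list of other operation names, and the driver keeps rewriting until every operation belongs to the `llvm` dialect or no pattern applies. This only illustrates the idea and does not reflect how MLIR's conversion driver is implemented. The names `PatternMap`, `isLLVMOp`, and `legalize`, as well as the step budget, are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical pattern table: op name -> names of the ops that replace it.
using PatternMap = std::map<std::string, std::vector<std::string>>;

// In this sketch, an op is legal when its name starts with "llvm.".
bool isLLVMOp(const std::string &op) { return op.rfind("llvm.", 0) == 0; }

// Repeatedly applies patterns until all ops are legal. Returns false if an
// illegal op has no pattern or if the step budget runs out (a crude guard
// against cyclic patterns). On failure, `ops` is left partially lowered.
bool legalize(std::vector<std::string> &ops, const PatternMap &patterns,
              std::size_t maxSteps = 1000) {
  for (std::size_t step = 0; step < maxSteps; ++step) {
    std::vector<std::string> next;
    bool changed = false;
    for (const std::string &op : ops) {
      if (isLLVMOp(op)) {
        next.push_back(op); // already legal, keep as is
        continue;
      }
      PatternMap::const_iterator it = patterns.find(op);
      if (it == patterns.end())
        return false; // nothing can legalize this op
      next.insert(next.end(), it->second.begin(), it->second.end());
      changed = true;
    }
    if (!changed)
      return true; // every op is legal
    ops = std::move(next);
  }
  return false;
}
```

For example, given the illustrative patterns `toy.print -> {loop.for, llvm.call}`, `loop.for -> {std.br}`, and `std.br -> {llvm.br}`, calling `legalize` on `{toy.print}` succeeds and leaves `{llvm.br, llvm.call}`, even though no single pattern maps `toy.print` directly to the `llvm` dialect.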

### Full Lowering

We want to completely lower to LLVM, so we use a `FullConversion`. This ensures
that only legal operations will remain after the conversion.

```c++
  mlir::ModuleOp module = getOperation();
  if (mlir::failed(mlir::applyFullConversion(module, target, patterns)))
    signalPassFailure();
```

Looking back at our current working example:

```mlir
func @main() {
  %0 = toy.constant dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
  %2 = toy.transpose(%0 : tensor<2x3xf64>) to tensor<3x2xf64>
  %3 = toy.mul %2, %2 : tensor<3x2xf64>
  toy.print %3 : tensor<3x2xf64>
  toy.return
}
```

We can now lower down to the LLVM dialect, which produces the following code:

```mlir
llvm.func @free(!llvm<"i8*">)
llvm.func @printf(!llvm<"i8*">, ...) -> i32
llvm.func @malloc(i64) -> !llvm<"i8*">
llvm.func @main() {
  %0 = llvm.mlir.constant(1.000000e+00 : f64) : f64
  %1 = llvm.mlir.constant(2.000000e+00 : f64) : f64

  ...

^bb16:
  %221 = llvm.extractvalue %25[0 : index] : !llvm<"{ double*, i64, [2 x i64], [2 x i64] }">
  %222 = llvm.mlir.constant(0 : index) : i64
  %223 = llvm.mlir.constant(2 : index) : i64
  %224 = llvm.mul %214, %223 : i64
  %225 = llvm.add %222, %224 : i64
  %226 = llvm.mlir.constant(1 : index) : i64
  %227 = llvm.mul %219, %226 : i64
  %228 = llvm.add %225, %227 : i64
  %229 = llvm.getelementptr %221[%228] : (!llvm<"double*">, i64) -> !llvm<"double*">
  %230 = llvm.load %229 : !llvm<"double*">
  %231 = llvm.call @printf(%207, %230) : (!llvm<"i8*">, f64) -> i32
  %232 = llvm.add %219, %218 : i64
  llvm.br ^bb15(%232 : i64)

  ...

^bb18:
  %235 = llvm.extractvalue %65[0 : index] : !llvm<"{ double*, i64, [2 x i64], [2 x i64] }">
  %236 = llvm.bitcast %235 : !llvm<"double*"> to !llvm<"i8*">
  llvm.call @free(%236) : (!llvm<"i8*">) -> ()
  %237 = llvm.extractvalue %45[0 : index] : !llvm<"{ double*, i64, [2 x i64], [2 x i64] }">
  %238 = llvm.bitcast %237 : !llvm<"double*"> to !llvm<"i8*">
  llvm.call @free(%238) : (!llvm<"i8*">) -> ()
  %239 = llvm.extractvalue %25[0 : index] : !llvm<"{ double*, i64, [2 x i64], [2 x i64] }">
  %240 = llvm.bitcast %239 : !llvm<"double*"> to !llvm<"i8*">
  llvm.call @free(%240) : (!llvm<"i8*">) -> ()
  llvm.return
}
```
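The arithmetic in `^bb16` is the usual strided address computation for a rank-2 buffer: the element at `(i, j)` lives at `offset + i * stride0 + j * stride1` from the base pointer. In the IR, the offset `0` and the strides `2` and `1` appear as constants because the 3x2 shape is static. A minimal C++ sketch of the same computation is given below. The `MemRef2D` struct mirrors the four-field LLVM struct type visible in the IR, but its field names and the helper functions are assumptions chosen for readability.

```cpp
#include <cassert>
#include <cstdint>

// Mirrors `{ double*, i64, [2 x i64], [2 x i64] }` from the IR above.
// The field names are illustrative assumptions.
struct MemRef2D {
  double *data;
  std::int64_t offset;
  std::int64_t sizes[2];
  std::int64_t strides[2];
};

// Same mul/add chain as in ^bb16: offset + i * stride0 + j * stride1.
std::int64_t linearIndex(const MemRef2D &m, std::int64_t i, std::int64_t j) {
  return m.offset + i * m.strides[0] + j * m.strides[1];
}

// Counterpart of the getelementptr + load pair.
double loadElement(const MemRef2D &m, std::int64_t i, std::int64_t j) {
  return m.data[linearIndex(m, i, j)];
}

// Wraps a contiguous row-major buffer: strides are {cols, 1}, offset is 0.
MemRef2D makeRowMajor(double *data, std::int64_t rows, std::int64_t cols) {
  return MemRef2D{data, 0, {rows, cols}, {cols, 1}};
}
```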

See [Conversion to the LLVM IR Dialect](../../ConversionToLLVMDialect.md) for
more in-depth details on lowering to the LLVM dialect.

## CodeGen: Getting Out of MLIR

At this point we are right at the cusp of code generation. We can generate code
in the LLVM dialect, so now we just need to export to LLVM IR and set up a JIT to
run it.

### Emitting LLVM IR

Now that our module consists only of operations in the LLVM dialect, we can
export to LLVM IR. To do this programmatically, we can invoke the following
utility:

```c++
  llvm::LLVMContext llvmContext;
  std::unique_ptr<llvm::Module> llvmModule =
      mlir::translateModuleToLLVMIR(module, llvmContext);
  if (!llvmModule) {
    /* ... an error was encountered ... */
  }
```

Exporting our module to LLVM IR generates:

```llvm
define void @main() {
  ...

102:
  %103 = extractvalue { double*, i64, [2 x i64], [2 x i64] } %8, 0
  %104 = mul i64 %96, 2
  %105 = add i64 0, %104
  %106 = mul i64 %100, 1
  %107 = add i64 %105, %106
  %108 = getelementptr double, double* %103, i64 %107
  %109 = load double, double* %108
  %110 = call i32 (i8*, ...) @printf(i8* getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double %109)
  %111 = add i64 %100, 1
  br label %99

  ...

115:
  %116 = extractvalue { double*, i64, [2 x i64], [2 x i64] } %24, 0
  %117 = bitcast double* %116 to i8*
  call void @free(i8* %117)
  %118 = extractvalue { double*, i64, [2 x i64], [2 x i64] } %16, 0
  %119 = bitcast double* %118 to i8*
  call void @free(i8* %119)
  %120 = extractvalue { double*, i64, [2 x i64], [2 x i64] } %8, 0
  %121 = bitcast double* %120 to i8*
  call void @free(i8* %121)
  ret void
}
```

If we enable optimization on the generated LLVM IR, we can trim this down quite
a bit:

```llvm
define void @main() {
  %0 = tail call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double 1.000000e+00)
  %1 = tail call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double 1.600000e+01)
  %putchar = tail call i32 @putchar(i32 10)
  %2 = tail call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double 4.000000e+00)
  %3 = tail call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double 2.500000e+01)
  %putchar.1 = tail call i32 @putchar(i32 10)
  %4 = tail call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double 9.000000e+00)
  %5 = tail call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double 3.600000e+01)
  %putchar.2 = tail call i32 @putchar(i32 10)
  ret void
}
```
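The constants that survive optimization can be checked by hand: transposing the 2x3 input and multiplying the result element-wise by itself yields exactly the six values passed to `printf`. The C++ sketch below reproduces that computation. The helper names `transpose`, `mulElementwise`, and `workingExample`, and the flat row-major representation, are assumptions for illustration and are not taken from the tutorial sources. `workingExample()` returns `{1, 16, 4, 25, 9, 36}`, the same values in the same order as the `double` arguments in the optimized IR.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Transposes a row-major rows x cols matrix into a row-major cols x rows one.
std::vector<double> transpose(const std::vector<double> &in, std::size_t rows,
                              std::size_t cols) {
  assert(in.size() == rows * cols && "buffer size must match the shape");
  std::vector<double> out(in.size());
  for (std::size_t i = 0; i < rows; ++i)
    for (std::size_t j = 0; j < cols; ++j)
      out[j * rows + i] = in[i * cols + j];
  return out;
}

// Element-wise product of two equally sized buffers.
std::vector<double> mulElementwise(const std::vector<double> &a,
                                   const std::vector<double> &b) {
  assert(a.size() == b.size() && "operands must have the same size");
  std::vector<double> out(a.size());
  for (std::size_t k = 0; k < a.size(); ++k)
    out[k] = a[k] * b[k];
  return out;
}

// The working example: transpose the 2x3 constant, then multiply the 3x2
// result by itself.
std::vector<double> workingExample() {
  const std::vector<double> input = {1, 2, 3, 4, 5, 6}; // 2x3, row-major
  const std::vector<double> t = transpose(input, 2, 3); // 3x2, row-major
  return mulElementwise(t, t);
}
```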

The full code listing for dumping LLVM IR can be found in
`examples/toy/Ch6/toyc.cpp` in the `dumpLLVMIR()` function:

```c++
int dumpLLVMIR(mlir::ModuleOp module) {
  // Translate the module, that contains the LLVM dialect, to LLVM IR. Use a
  // fresh LLVM IR context. (Note that LLVM is not thread-safe and any
  // concurrent use of a context requires external locking.)
  llvm::LLVMContext llvmContext;
  auto llvmModule = mlir::translateModuleToLLVMIR(module, llvmContext);
  if (!llvmModule) {
    llvm::errs() << "Failed to emit LLVM IR\n";
    return -1;
  }

  // Initialize LLVM targets.
  llvm::InitializeNativeTarget();
  llvm::InitializeNativeTargetAsmPrinter();
  mlir::ExecutionEngine::setupTargetTriple(llvmModule.get());

  /// Optionally run an optimization pipeline over the llvm module.
  auto optPipeline = mlir::makeOptimizingTransformer(
      /*optLevel=*/EnableOpt ? 3 : 0, /*sizeLevel=*/0,
      /*targetMachine=*/nullptr);
  if (auto err = optPipeline(llvmModule.get())) {
    llvm::errs() << "Failed to optimize LLVM IR " << err << "\n";
    return -1;
  }
  llvm::errs() << *llvmModule << "\n";
  return 0;
}
```

### Setting up a JIT

Setting up a JIT to run the module containing the LLVM dialect can be done using
the `mlir::ExecutionEngine` infrastructure. This is a utility wrapper around
LLVM's JIT that accepts `.mlir` as input. The full code listing for setting up
the JIT can be found in `examples/toy/Ch6/toyc.cpp` in the `runJit()` function:

```c++
int runJit(mlir::ModuleOp module) {
  // Initialize LLVM targets.
  llvm::InitializeNativeTarget();
  llvm::InitializeNativeTargetAsmPrinter();

  // An optimization pipeline to use within the execution engine.
  auto optPipeline = mlir::makeOptimizingTransformer(
      /*optLevel=*/EnableOpt ? 3 : 0, /*sizeLevel=*/0,
      /*targetMachine=*/nullptr);

  // Create an MLIR execution engine. The execution engine eagerly JIT-compiles
  // the module.
  auto maybeEngine = mlir::ExecutionEngine::create(module,
      /*llvmModuleBuilder=*/nullptr, optPipeline);
  assert(maybeEngine && "failed to construct an execution engine");
  auto &engine = maybeEngine.get();

  // Invoke the JIT-compiled function.
  auto invocationResult = engine->invoke("main");
  if (invocationResult) {
    llvm::errs() << "JIT invocation failed\n";
    return -1;
  }

  return 0;
}
```

You can play around with it from the build directory:

```shell
$ echo 'def main() { print([[1, 2], [3, 4]]); }' | ./bin/toyc-ch6 -emit=jit
1.000000 2.000000
3.000000 4.000000
```

You can also play with `-emit=mlir`, `-emit=mlir-affine`, `-emit=mlir-llvm`, and
`-emit=llvm` to compare the various levels of IR involved. Also try options like
[`--print-ir-after-all`](../../PassManagement.md/#ir-printing) to track the
evolution of the IR throughout the pipeline.

The example code used throughout this section can be found in
`test/Examples/Toy/Ch6/llvm-lowering.mlir`.

So far, we have worked with primitive data types. In the
[next chapter](Ch-7.md), we will add a composite `struct` type.