# Chapter 6: Lowering to LLVM and Code Generation

[TOC]

In the [previous chapter](Ch-5.md), we introduced the
[dialect conversion](../../DialectConversion.md) framework and partially lowered
many of the `Toy` operations to affine loop nests for optimization. In this
chapter, we will finally lower to LLVM for code generation.

## Lowering to LLVM

For this lowering, we will again use the dialect conversion framework to perform
the heavy lifting. However, this time, we will be performing a full conversion
to the [LLVM dialect](../../Dialects/LLVM.md). Thankfully, we have already
lowered all but one of the `toy` operations, with the last being `toy.print`.
Before going over the conversion to LLVM, let's lower the `toy.print` operation.
We will lower this operation to a non-affine loop nest that invokes `printf` for
each element. Note that, because the dialect conversion framework supports
[transitive lowering](../../../getting_started/Glossary.md#transitive-lowering),
we don't need to directly emit operations in the LLVM dialect. By transitive
lowering, we mean that the conversion framework may apply multiple patterns to
fully legalize an operation. In this example, we are generating a structured
loop nest instead of the branch form in the LLVM dialect. As long as a lowering
from the loop operations to LLVM is also available, the lowering will still
succeed.
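
As a point of reference, the following plain C++ sketch shows what the lowered `toy.print` loop nest computes for a 2-D `f64` value stored in row-major order: one formatted print per element, then a newline after each row. The function name `printMatrix` and its `FILE *` parameter are hypothetical. The `"%f "` and `"\n"` format strings are an assumption, inferred from the 4-byte `@frmt_spec` global and the `putchar(10)` calls in the optimized LLVM IR later in this chapter.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdio>
#include <vector>

// Hypothetical helper mirroring the loop nest produced by lowering toy.print
// for a 2-D f64 value stored in row-major order. The "%f " and "\n" format
// strings are assumptions, not taken from the Toy sources.
void printMatrix(std::FILE *out, const std::vector<double> &data,
                 std::size_t rows, std::size_t cols) {
  assert(data.size() == rows * cols && "shape does not match data size");
  for (std::size_t i = 0; i < rows; ++i) {
    for (std::size_t j = 0; j < cols; ++j)
      std::fprintf(out, "%f ", data[i * cols + j]); // one call per element
    std::fprintf(out, "\n");                        // newline after each row
  }
}
```

Calling `printMatrix(stdout, {1, 2, 3, 4}, 2, 2)` writes `1.000000 2.000000 ` and then `3.000000 4.000000 `, each followed by a newline. Writing to `stdout` with `std::fprintf` behaves like `printf`, and taking the stream as a parameter only makes the sketch easy to test.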

During lowering, we can get or build the declaration for `printf` as follows:

```c++
/// Return a symbol reference to the printf function, inserting it into the
/// module if necessary.
static FlatSymbolRefAttr getOrInsertPrintf(PatternRewriter &rewriter,
                                           ModuleOp module,
                                           LLVM::LLVMDialect *llvmDialect) {
  auto *context = module.getContext();
  if (module.lookupSymbol<LLVM::LLVMFuncOp>("printf"))
    return SymbolRefAttr::get("printf", context);

  // Create a function declaration for printf. The signature is:
  //   * `i32 (i8*, ...)`
  auto llvmI32Ty = LLVM::LLVMType::getInt32Ty(llvmDialect);
  auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
  auto llvmFnType = LLVM::LLVMType::getFunctionTy(llvmI32Ty, llvmI8PtrTy,
                                                  /*isVarArg=*/true);

  // Insert the printf function into the body of the parent module. The guard
  // restores the rewriter's insertion point when it goes out of scope.
  PatternRewriter::InsertionGuard insertGuard(rewriter);
  rewriter.setInsertionPointToStart(module.getBody());
  rewriter.create<LLVM::LLVMFuncOp>(module.getLoc(), "printf", llvmFnType);
  return SymbolRefAttr::get("printf", context);
}
```
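
The lookup before the insertion matters because the helper runs every time a `toy.print` is lowered. Without it, a program with several prints would try to declare `printf` more than once, and the second declaration would collide with the first in the module's symbol table. The sketch below models only that get-or-insert control flow. It uses a hypothetical `Module` struct whose symbol table is a `std::map` from names to signature strings, and it is not the MLIR API.

```cpp
#include <cassert>
#include <map>
#include <string>

// Hypothetical stand-in for a module's symbol table: name -> signature.
struct Module {
  std::map<std::string, std::string> symbols;
};

// Mirrors the control flow of getOrInsertPrintf: reuse an existing
// declaration if there is one, otherwise insert it exactly once.
std::string getOrInsertPrintf(Module &module) {
  const std::string name = "printf";
  if (module.symbols.count(name))
    return name; // already declared by an earlier lowering
  module.symbols.emplace(name, "i32 (i8*, ...)");
  return name;
}
```

Calling `getOrInsertPrintf` twice on the same `Module` returns `"printf"` both times and leaves exactly one entry in `symbols`.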

Now that the lowering for the `toy.print` operation has been defined, we can
specify the components necessary for the lowering. These are largely the same as
the components defined in the [previous chapter](Ch-5.md).

### Conversion Target

For this conversion, aside from the top-level module, we will be lowering
everything to the LLVM dialect.

```c++
  mlir::ConversionTarget target(getContext());
  target.addLegalDialect<mlir::LLVM::LLVMDialect>();
  target.addLegalOp<mlir::ModuleOp, mlir::ModuleTerminatorOp>();
```
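
Conceptually, the target answers one question for each operation: is it legal? The hypothetical, string-based `Target` below models the two rules used above. An operation is legal if its full name was added with `addLegalOp`, or if its dialect prefix (the part before the first `.`) was added with `addLegalDialect`. The names `module` and `module_terminator` are assumed here as the operation names of `ModuleOp` and `ModuleTerminatorOp`.

```cpp
#include <cassert>
#include <set>
#include <string>

// Hypothetical, string-based model of a ConversionTarget.
struct Target {
  std::set<std::string> legalDialects; // e.g. "llvm"
  std::set<std::string> legalOps;      // e.g. "module"

  bool isLegal(const std::string &opName) const {
    if (legalOps.count(opName))
      return true;
    // The dialect is the prefix before the first '.', e.g. "llvm" in "llvm.call".
    auto dot = opName.find('.');
    return dot != std::string::npos &&
           legalDialects.count(opName.substr(0, dot)) > 0;
  }
};
```

With `Target t{{"llvm"}, {"module", "module_terminator"}}`, both `llvm.call` and `module` are legal, while `toy.print` and `affine.for` are not. Those illegal operations are exactly what the patterns below must eliminate.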

### Type Converter

This lowering will also transform the MemRef types which are currently being
operated on into a representation in LLVM. To perform this conversion, we use a
`TypeConverter` as part of the lowering. This converter specifies how one type
maps to another. This is necessary now that we are performing more complicated
lowerings involving block arguments. Given that we don't have any
Toy-dialect-specific types that need to be lowered, the default converter is
enough for our use case.

```c++
  LLVMTypeConverter typeConverter(&getContext());
```
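
To illustrate what "one type maps to another" means for MemRefs, here is a hypothetical string-level sketch. It maps a statically shaped MemRef type to the descriptor struct seen in the LLVM dialect output later in this chapter: a pointer to the element type, an `i64` offset, and one `[rank x i64]` array each for the sizes and the strides. It is not `LLVMTypeConverter`, and it only handles `f32`/`f64` elements and non-zero rank. The descriptor layout has also changed across MLIR versions (later versions carry a second, aligned pointer).

```cpp
#include <cassert>
#include <cstddef>
#include <string>

// Hypothetical sketch: converts e.g. "memref<3x2xf64>" into the descriptor
// struct type "{ double*, i64, [2 x i64], [2 x i64] }". Only static shapes
// and f32/f64 element types are handled.
std::string convertMemRefType(const std::string &memref) {
  const std::string prefix = "memref<";
  assert(memref.compare(0, prefix.size(), prefix) == 0 && memref.back() == '>');
  std::string body =
      memref.substr(prefix.size(), memref.size() - prefix.size() - 1);

  // Each 'x' ends one dimension; the last component is the element type.
  std::size_t rank = 0, pos = 0, next;
  while ((next = body.find('x', pos)) != std::string::npos) {
    ++rank;
    pos = next + 1;
  }
  assert(rank > 0 && "rank-0 MemRefs are not handled by this sketch");

  std::string elt = body.substr(pos);
  std::string llvmElt = elt == "f64" ? "double" : elt == "f32" ? "float" : "";
  assert(!llvmElt.empty() && "unsupported element type");

  // Sizes and strides each get one i64 per dimension.
  std::string arr = "[" + std::to_string(rank) + " x i64]";
  return "{ " + llvmElt + "*, i64, " + arr + ", " + arr + " }";
}
```

For example, `convertMemRefType("memref<3x2xf64>")` produces `{ double*, i64, [2 x i64], [2 x i64] }`, the struct type that appears in the `llvm.extractvalue` operations of the lowered code.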

### Conversion Patterns

Now that the conversion target has been defined, we need to provide the patterns
used for lowering. At this point in the compilation process, we have a
combination of `toy`, `affine`, and `std` operations. Luckily, the `std` and
`affine` dialects already provide the set of patterns needed to transform them
into the LLVM dialect. These patterns allow for lowering the IR in multiple
stages by relying on
[transitive lowering](../../../getting_started/Glossary.md#transitive-lowering):
`affine` and `loop` operations are lowered toward `std`, and `std` operations
are then lowered to the LLVM dialect.

```c++
  mlir::OwningRewritePatternList patterns;
  mlir::populateAffineToStdConversionPatterns(patterns, &getContext());
  mlir::populateLoopToStdConversionPatterns(patterns, &getContext());
  mlir::populateStdToLLVMConversionPatterns(typeConverter, patterns);

  // The only remaining operation to lower from the `toy` dialect is the
  // PrintOp.
  patterns.insert<PrintOpLowering>(&getContext());
```
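
The staged lowering that these pattern sets rely on can be pictured as a small rewrite system. Each pattern replaces an illegal operation with operations that are closer to the LLVM dialect, and rewriting continues until everything is legal or some operation has no applicable pattern. The sketch below is a deliberately simplified model of that idea, not of MLIR's conversion driver. Operations are plain strings, legality means belonging to the `llvm` dialect, and the pattern table in the usage note uses illustrative operation names only. The sketch also assumes the patterns never form a cycle.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Hypothetical pattern table: illegal op name -> replacement op names.
using Patterns = std::map<std::string, std::vector<std::string>>;

static bool isLegal(const std::string &op) { return op.rfind("llvm.", 0) == 0; }

// Repeatedly applies patterns until every operation is legal. Returns false,
// leaving `ops` untouched, if some illegal operation has no pattern.
// Assumes the patterns terminate (no cycles).
bool legalize(std::vector<std::string> &ops, const Patterns &patterns) {
  // Reverse the input so that popping from the back visits ops in order.
  std::vector<std::string> worklist(ops.rbegin(), ops.rend()), result;
  while (!worklist.empty()) {
    std::string op = worklist.back();
    worklist.pop_back();
    if (isLegal(op)) {
      result.push_back(op);
      continue;
    }
    auto it = patterns.find(op);
    if (it == patterns.end())
      return false; // nothing can legalize this op: the conversion fails
    // Push replacements in reverse so they are also processed in order.
    worklist.insert(worklist.end(), it->second.rbegin(), it->second.rend());
  }
  ops = result;
  return true;
}
```

Consider the hypothetical patterns `toy.print -> {loop.for, llvm.call}`, `loop.for -> {std.br, std.cond_br}`, `std.br -> {llvm.br}`, and `std.cond_br -> {llvm.cond_br}`. With them, legalizing `{toy.print}` succeeds in three stages and yields `{llvm.br, llvm.cond_br, llvm.call}`. An operation with no pattern, such as `toy.unknown`, makes `legalize` return `false`.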

### Full Lowering

We want to completely lower to LLVM, so we use a full conversion
(`applyFullConversion`). This ensures that only legal operations will remain
after the conversion. If any illegal operation cannot be converted, the
conversion fails and we signal a pass failure.

```c++
  mlir::ModuleOp module = getOperation();
  if (mlir::failed(mlir::applyFullConversion(module, target, patterns)))
    signalPassFailure();
```
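
The difference from the partial conversion used in the previous chapter lies in what counts as success. The hypothetical sketch below states the two acceptance rules side by side for the operations left after rewriting, with legality supplied as predicates. A full conversion requires every remaining operation to be legal. A partial conversion fails only if an operation explicitly marked illegal remains, so operations of unknown legality may survive.

```cpp
#include <algorithm>
#include <cassert>
#include <functional>
#include <string>
#include <vector>

using OpPredicate = std::function<bool(const std::string &)>;

// Full conversion: every operation left after rewriting must be legal.
bool fullConversionAccepts(const std::vector<std::string> &remaining,
                           const OpPredicate &isLegal) {
  return std::all_of(remaining.begin(), remaining.end(), isLegal);
}

// Partial conversion: only operations explicitly marked illegal are rejected;
// operations whose legality is unknown may remain.
bool partialConversionAccepts(const std::vector<std::string> &remaining,
                              const OpPredicate &isExplicitlyIllegal) {
  return std::none_of(remaining.begin(), remaining.end(), isExplicitlyIllegal);
}
```

A leftover `affine.for` is therefore fatal here, where only `llvm.*` operations are legal. It would be tolerated by a partial conversion that only marks the `toy` dialect illegal.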

Looking back at our current working example:

```mlir
func @main() {
  %0 = toy.constant dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
  %2 = toy.transpose(%0 : tensor<2x3xf64>) to tensor<3x2xf64>
  %3 = toy.mul %2, %2 : tensor<3x2xf64>
  toy.print %3 : tensor<3x2xf64>
  toy.return
}
```

We can now lower it to the LLVM dialect, which produces the following code:

```mlir
llvm.func @free(!llvm<"i8*">)
llvm.func @printf(!llvm<"i8*">, ...) -> !llvm.i32
llvm.func @malloc(!llvm.i64) -> !llvm<"i8*">
llvm.func @main() {
  %0 = llvm.mlir.constant(1.000000e+00 : f64) : !llvm.double
  %1 = llvm.mlir.constant(2.000000e+00 : f64) : !llvm.double

  ...

^bb16:
  %221 = llvm.extractvalue %25[0 : index] : !llvm<"{ double*, i64, [2 x i64], [2 x i64] }">
  %222 = llvm.mlir.constant(0 : index) : !llvm.i64
  %223 = llvm.mlir.constant(2 : index) : !llvm.i64
  %224 = llvm.mul %214, %223 : !llvm.i64
  %225 = llvm.add %222, %224 : !llvm.i64
  %226 = llvm.mlir.constant(1 : index) : !llvm.i64
  %227 = llvm.mul %219, %226 : !llvm.i64
  %228 = llvm.add %225, %227 : !llvm.i64
  %229 = llvm.getelementptr %221[%228] : (!llvm<"double*">, !llvm.i64) -> !llvm<"double*">
  %230 = llvm.load %229 : !llvm<"double*">
  %231 = llvm.call @printf(%207, %230) : (!llvm<"i8*">, !llvm.double) -> !llvm.i32
  %232 = llvm.add %219, %218 : !llvm.i64
  llvm.br ^bb15(%232 : !llvm.i64)

  ...

^bb18:
  %235 = llvm.extractvalue %65[0 : index] : !llvm<"{ double*, i64, [2 x i64], [2 x i64] }">
  %236 = llvm.bitcast %235 : !llvm<"double*"> to !llvm<"i8*">
  llvm.call @free(%236) : (!llvm<"i8*">) -> ()
  %237 = llvm.extractvalue %45[0 : index] : !llvm<"{ double*, i64, [2 x i64], [2 x i64] }">
  %238 = llvm.bitcast %237 : !llvm<"double*"> to !llvm<"i8*">
  llvm.call @free(%238) : (!llvm<"i8*">) -> ()
  %239 = llvm.extractvalue %25[0 : index] : !llvm<"{ double*, i64, [2 x i64], [2 x i64] }">
  %240 = llvm.bitcast %239 : !llvm<"double*"> to !llvm<"i8*">
  llvm.call @free(%240) : (!llvm<"i8*">) -> ()
  llvm.return
}
```
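
The arithmetic in `^bb16` is the element-address computation for a MemRef descriptor. Field 0 of the `{ double*, i64, [2 x i64], [2 x i64] }` struct is the data pointer. The linear index is `offset + i * stride0 + j * stride1`, which here becomes `0 + %214 * 2 + %219 * 1` for the row-major `3x2` buffer (`%214` and `%219` appear to be the outer and inner loop indices). The hypothetical C++ struct below mirrors the descriptor layout as it appears in this listing. It reads field 1 as the offset, field 2 as the sizes, and field 3 as the strides.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical C++ mirror of the 2-D descriptor in the listing above:
// { double*, i64, [2 x i64], [2 x i64] }.
struct MemRefDescriptor2D {
  double *data;            // field 0: pointer to the buffer
  std::int64_t offset;     // field 1: offset into the buffer, in elements
  std::int64_t sizes[2];   // field 2: extent of each dimension
  std::int64_t strides[2]; // field 3: step between neighbors, in elements
};

// Same computation as ^bb16: linearize (i, j), then load the element.
double loadElement(const MemRefDescriptor2D &d, std::int64_t i, std::int64_t j) {
  assert(i >= 0 && i < d.sizes[0] && j >= 0 && j < d.sizes[1]);
  std::int64_t index = d.offset + i * d.strides[0] + j * d.strides[1];
  return d.data[index];
}
```

For the `3x2` result of the running example stored row-major, `strides` is `{2, 1}`. These are the same constants the listing multiplies by.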

See [Conversion to the LLVM IR Dialect](../../ConversionToLLVMDialect.md) for
more in-depth details on lowering to the LLVM dialect.

## CodeGen: Getting Out of MLIR

At this point, we are right at the cusp of code generation. We can generate code
in the LLVM dialect, so now we just need to export to LLVM IR and set up a JIT
to run it.

### Emitting LLVM IR

Now that our module consists only of operations in the LLVM dialect, we can
export to LLVM IR. To do this programmatically, we can invoke the following
utility:

```c++
  std::unique_ptr<llvm::Module> llvmModule = mlir::translateModuleToLLVMIR(module);
  if (!llvmModule) {
    /* ... an error was encountered ... */
  }
```

Exporting our module to LLVM IR generates:

```llvm
define void @main() {
  ...

102:
  %103 = extractvalue { double*, i64, [2 x i64], [2 x i64] } %8, 0
  %104 = mul i64 %96, 2
  %105 = add i64 0, %104
  %106 = mul i64 %100, 1
  %107 = add i64 %105, %106
  %108 = getelementptr double, double* %103, i64 %107
  %109 = load double, double* %108
  %110 = call i32 (i8*, ...) @printf(i8* getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double %109)
  %111 = add i64 %100, 1
  br label %99

  ...

115:
  %116 = extractvalue { double*, i64, [2 x i64], [2 x i64] } %24, 0
  %117 = bitcast double* %116 to i8*
  call void @free(i8* %117)
  %118 = extractvalue { double*, i64, [2 x i64], [2 x i64] } %16, 0
  %119 = bitcast double* %118 to i8*
  call void @free(i8* %119)
  %120 = extractvalue { double*, i64, [2 x i64], [2 x i64] } %8, 0
  %121 = bitcast double* %120 to i8*
  call void @free(i8* %121)
  ret void
}
```

If we enable optimization on the generated LLVM IR, we can trim this down quite
a bit:

```llvm
define void @main() {
  %0 = tail call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double 1.000000e+00)
  %1 = tail call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double 1.600000e+01)
  %putchar = tail call i32 @putchar(i32 10)
  %2 = tail call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double 4.000000e+00)
  %3 = tail call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double 2.500000e+01)
  %putchar.1 = tail call i32 @putchar(i32 10)
  %4 = tail call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double 9.000000e+00)
  %5 = tail call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double 3.600000e+01)
  %putchar.2 = tail call i32 @putchar(i32 10)
  ret void
}
```
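
Every input to this program is a compile-time constant, which is why the optimizer could fold the computation away entirely. Transposing `[[1, 2, 3], [4, 5, 6]]` gives `[[1, 4], [2, 5], [3, 6]]`, and multiplying that element-wise by itself gives `[[1, 16], [4, 25], [9, 36]]`. These are exactly the values passed to `printf` above, two per row, with a `putchar(10)` newline after each row. The standalone C++ below recomputes those values. `transposeThenSquare` is a hypothetical helper, not part of the Toy sources.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<double>>;

// Recomputes `toy.mul(transpose(a), transpose(a))` from the running example.
Matrix transposeThenSquare(const Matrix &a) {
  assert(!a.empty());
  std::size_t rows = a.size(), cols = a[0].size();
  Matrix t(cols, std::vector<double>(rows));
  for (std::size_t i = 0; i < rows; ++i)
    for (std::size_t j = 0; j < cols; ++j)
      t[j][i] = a[i][j]; // transpose
  for (auto &row : t)
    for (double &v : row)
      v *= v; // element-wise multiply with itself
  return t;
}
```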

The full code listing for dumping LLVM IR can be found in the `dumpLLVMIR()`
function in `examples/toy/Ch6/toyc.cpp`:

```c++
int dumpLLVMIR(mlir::ModuleOp module) {
  // Translate the module, which contains the LLVM dialect, to LLVM IR.
  auto llvmModule = mlir::translateModuleToLLVMIR(module);
  if (!llvmModule) {
    llvm::errs() << "Failed to emit LLVM IR\n";
    return -1;
  }

  // Initialize LLVM targets.
  llvm::InitializeNativeTarget();
  llvm::InitializeNativeTargetAsmPrinter();
  mlir::ExecutionEngine::setupTargetTriple(llvmModule.get());

  // Optionally run an optimization pipeline over the LLVM module.
  // `EnableOpt` is a command-line option defined elsewhere in toyc.cpp.
  auto optPipeline = mlir::makeOptimizingTransformer(
      /*optLevel=*/EnableOpt ? 3 : 0, /*sizeLevel=*/0,
      /*targetMachine=*/nullptr);
  if (auto err = optPipeline(llvmModule.get())) {
    llvm::errs() << "Failed to optimize LLVM IR " << err << "\n";
    return -1;
  }
  llvm::errs() << *llvmModule << "\n";
  return 0;
}
```

### Setting up a JIT

Setting up a JIT to run the module containing the LLVM dialect can be done using
the `mlir::ExecutionEngine` infrastructure. This is a utility wrapper around
LLVM's JIT that accepts an MLIR module in the LLVM dialect as input. The full
code listing for setting up the JIT can be found in the `runJit()` function in
`examples/toy/Ch6/toyc.cpp`:

```c++
int runJit(mlir::ModuleOp module) {
  // Initialize LLVM targets.
  llvm::InitializeNativeTarget();
  llvm::InitializeNativeTargetAsmPrinter();

  // An optimization pipeline to use within the execution engine.
  auto optPipeline = mlir::makeOptimizingTransformer(
      /*optLevel=*/EnableOpt ? 3 : 0, /*sizeLevel=*/0,
      /*targetMachine=*/nullptr);

  // Create an MLIR execution engine. The execution engine eagerly JIT-compiles
  // the module.
  auto maybeEngine = mlir::ExecutionEngine::create(module, optPipeline);
  assert(maybeEngine && "failed to construct an execution engine");
  auto &engine = maybeEngine.get();

  // Invoke the JIT-compiled function. The returned llvm::Error converts to
  // `true` when it holds an error.
  auto invocationResult = engine->invoke("main");
  if (invocationResult) {
    llvm::errs() << "JIT invocation failed\n";
    return -1;
  }

  return 0;
}
```

You can play around with it from the build directory:

```shell
$ echo 'def main() { print([[1, 2], [3, 4]]); }' | ./bin/toyc-ch6 -emit=jit
1.000000 2.000000
3.000000 4.000000
```

You can also play with `-emit=mlir`, `-emit=mlir-affine`, `-emit=mlir-llvm`, and
`-emit=llvm` to compare the various levels of IR involved. Also try options like
[`--print-ir-after-all`](../../PassManagement.md#ir-printing) to track the
evolution of the IR throughout the pipeline.

The example code used throughout this section can be found in
`test/Examples/Toy/Ch6/llvm-lowering.mlir`.

So far, we have worked with primitive data types. In the
[next chapter](Ch-7.md), we will add a composite `struct` type.