# Chapter 6: Lowering to LLVM and Code Generation

[TOC]

In the [previous chapter](Ch-5.md), we introduced the
[dialect conversion](../../DialectConversion.md) framework and partially lowered
many of the `Toy` operations to affine loop nests for optimization. In this
chapter, we will finally lower to LLVM for code generation.

## Lowering to LLVM

For this lowering, we will again use the dialect conversion framework to perform
the heavy lifting. However, this time, we will be performing a full conversion
to the [LLVM dialect](../../Dialects/LLVM.md). Thankfully, we have already
lowered all but one of the `toy` operations, the last being `toy.print`.
Before going over the conversion to LLVM, let's lower the `toy.print` operation.
We will lower this operation to a non-affine loop nest that invokes `printf` for
each element. Note that, because the dialect conversion framework supports
[transitive lowering](../../../getting_started/Glossary.md#transitive-lowering), we don't need to
directly emit operations in the LLVM dialect. By transitive lowering, we mean
that the conversion framework may apply multiple patterns to fully legalize an
operation. In this example, we are generating a structured loop nest instead of
the branch-based form used by the LLVM dialect. As long as we then have a lowering
from the loop operations to LLVM, the lowering will still succeed.

During lowering, we can get, or build, the declaration for `printf` as follows:

```c++
/// Return a symbol reference to the printf function, inserting it into the
/// module if necessary.
static FlatSymbolRefAttr getOrInsertPrintf(PatternRewriter &rewriter,
                                           ModuleOp module,
                                           LLVM::LLVMDialect *llvmDialect) {
  auto *context = module.getContext();
  if (module.lookupSymbol<LLVM::LLVMFuncOp>("printf"))
    return SymbolRefAttr::get("printf", context);

  // Create a function declaration for printf, the signature is:
  //   * `i32 (i8*, ...)`
  auto llvmI32Ty = LLVM::LLVMType::getInt32Ty(llvmDialect);
  auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
  auto llvmFnType = LLVM::LLVMType::getFunctionTy(llvmI32Ty, llvmI8PtrTy,
                                                  /*isVarArg=*/true);

  // Insert the printf function into the body of the parent module.
  PatternRewriter::InsertionGuard insertGuard(rewriter);
  rewriter.setInsertionPointToStart(module.getBody());
  rewriter.create<LLVM::LLVMFuncOp>(module.getLoc(), "printf", llvmFnType);
  return SymbolRefAttr::get("printf", context);
}
```
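
The helper above follows a get-or-insert idiom: look the symbol up first and only create the declaration when it is missing, so lowering many `toy.print` operations still yields a single `printf` declaration in the module. The sketch below models that idiom with a plain `std::map` standing in for the module's symbol table; `ToyModule` and `getOrInsertDecl` are hypothetical names, not MLIR APIs.

```c++
#include <cassert>
#include <map>
#include <string>

// Hypothetical stand-in for a module's symbol table: symbol name -> signature.
struct ToyModule {
  std::map<std::string, std::string> symbols;
};

// Return the symbol name, inserting a declaration only if it is not already
// present (mirroring the lookupSymbol-then-create flow of getOrInsertPrintf).
std::string getOrInsertDecl(ToyModule &module, const std::string &name,
                            const std::string &signature) {
  if (module.symbols.count(name))
    return name;  // Already declared: reuse the existing declaration.
  module.symbols.emplace(name, signature);
  return name;
}
```

Calling `getOrInsertDecl` twice with `"printf"` leaves exactly one entry in the table, which is the property the real helper relies on.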

Now that the lowering for the printf operation has been defined, we can specify
the components necessary for the lowering. These are largely the same as the
components defined in the [previous chapter](Ch-5.md).

### Conversion Target

For this conversion, aside from the top-level module, we will be lowering
everything to the LLVM dialect.

```c++
  mlir::ConversionTarget target(getContext());
  target.addLegalDialect<mlir::LLVM::LLVMDialect>();
  target.addLegalOp<mlir::ModuleOp, mlir::ModuleTerminatorOp>();
```
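
Conceptually, the target above answers one question per operation: is it legal? A minimal sketch of that check, assuming operations are identified by `dialect.op` names: an operation is legal if its dialect is marked legal or if the operation itself is listed. `ToyTarget` is hypothetical and much simpler than `mlir::ConversionTarget`, which also supports explicitly illegal operations and dynamic legality callbacks.

```c++
#include <cassert>
#include <set>
#include <string>

// Hypothetical, simplified model of a conversion target.
struct ToyTarget {
  std::set<std::string> legalDialects;  // e.g. "llvm"
  std::set<std::string> legalOps;       // e.g. "module", "module_terminator"

  bool isLegal(const std::string &opName) const {
    if (legalOps.count(opName))
      return true;  // Explicitly legal operation.
    // The dialect is the prefix before the first '.', e.g. "llvm" in "llvm.call".
    auto dot = opName.find('.');
    if (dot == std::string::npos)
      return false;
    return legalDialects.count(opName.substr(0, dot)) != 0;
  }
};
```

With `llvm` as the only legal dialect plus the two module operations, `llvm.call` and `module` are legal, while `toy.print` and `affine.for` are not and must be rewritten.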

### Type Converter

This lowering will also transform the MemRef types which are currently being
operated on into a representation in LLVM. To perform this conversion, we use a
`TypeConverter` as part of the lowering. This converter specifies how one type
maps to another. This is necessary now that we are performing more complicated
lowerings involving block arguments. Given that we don't have any
Toy-dialect-specific types that need to be lowered, the default converter is
enough for our use case.

```c++
  LLVMTypeConverter typeConverter(&getContext());
```
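
For a rank-2 MemRef of `f64`, the LLVM IR later in this chapter shows the converted type as `{ double*, i64, [2 x i64], [2 x i64] }`: a data pointer, an offset, the sizes, and the strides. The struct below mirrors that layout in plain C++ purely as an illustration; `MemRefDescriptor2D` and `makeContiguous` are hypothetical names, and the exact descriptor layout depends on the MLIR version, so treat this as a sketch of the layout shown here rather than a stable ABI.

```c++
#include <cassert>
#include <cstdint>

// Hypothetical C++ mirror of the descriptor shown in the generated IR:
// { double*, i64, [2 x i64], [2 x i64] }.
struct MemRefDescriptor2D {
  double *data;        // pointer to the underlying buffer
  int64_t offset;      // element offset into `data`
  int64_t sizes[2];    // extent of each dimension
  int64_t strides[2];  // distance, in elements, between consecutive indices
};

// Build a descriptor for a contiguous, row-major buffer with zero offset.
MemRefDescriptor2D makeContiguous(double *data, int64_t rows, int64_t cols) {
  assert(rows >= 0 && cols >= 0);
  return MemRefDescriptor2D{data, /*offset=*/0, {rows, cols}, {cols, 1}};
}
```

For the 3x2 result used in this chapter, the row-major strides come out as `{2, 1}`, the same constants that appear in the address arithmetic of the lowered IR.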

### Conversion Patterns

Now that the conversion target has been defined, we need to provide the patterns
used for lowering. At this point in the compilation process, we have a
combination of `toy`, `affine`, and `std` operations. Luckily, the `std` and
`affine` dialects already provide the set of patterns needed to transform them
into the LLVM dialect. These patterns allow for lowering the IR in multiple stages
by relying on [transitive lowering](../../../getting_started/Glossary.md#transitive-lowering).

```c++
  mlir::OwningRewritePatternList patterns;
  mlir::populateAffineToStdConversionPatterns(patterns, &getContext());
  mlir::populateLoopToStdConversionPatterns(patterns, &getContext());
  mlir::populateStdToLLVMConversionPatterns(typeConverter, patterns);

  // The only remaining operation to lower from the `toy` dialect is the
  // PrintOp.
  patterns.insert<PrintOpLowering>(&getContext());
```
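
Transitive lowering can be pictured as repeatedly rewriting each illegal operation until only legal ones remain. The sketch below models this with a hypothetical table of one-step rewrites (for example, an `affine.for` becoming a `loop.for`, which in turn becomes branches); the operation names are illustrative, and a real pattern does far more than rename an operation. `RewriteTable`, `isLegalOp`, and `lowerToLegal` are hypothetical helpers, and the table is assumed to be acyclic.

```c++
#include <cassert>
#include <map>
#include <optional>
#include <string>
#include <vector>

// Hypothetical one-step rewrite rules: an op name maps to the ops replacing it.
using RewriteTable = std::map<std::string, std::vector<std::string>>;

// Simplification: an op is legal here if it belongs to the `llvm` dialect.
bool isLegalOp(const std::string &op) { return op.rfind("llvm.", 0) == 0; }

// Recursively apply rewrites until every resulting op is legal. Returns
// std::nullopt if some illegal op has no applicable rewrite. Assumes the
// table contains no cycles.
std::optional<std::vector<std::string>>
lowerToLegal(const std::string &op, const RewriteTable &rewrites) {
  if (isLegalOp(op))
    return std::vector<std::string>{op};
  auto it = rewrites.find(op);
  if (it == rewrites.end())
    return std::nullopt;  // No pattern applies: legalization fails.
  std::vector<std::string> result;
  for (const std::string &next : it->second) {
    auto lowered = lowerToLegal(next, rewrites);
    if (!lowered)
      return std::nullopt;
    result.insert(result.end(), lowered->begin(), lowered->end());
  }
  return result;
}
```

No single rule maps `affine.for` directly to LLVM; the result is reached only by chaining several rules, which is the point of transitive lowering.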

### Full Lowering

We want to completely lower to LLVM, so we use a full conversion
(`applyFullConversion`). This ensures that only legal operations will remain
after the conversion.

```c++
  mlir::ModuleOp module = getOperation();
  if (mlir::failed(mlir::applyFullConversion(module, target, patterns)))
    signalPassFailure();
```
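
The difference between a full and a partial conversion lies in what may be left over. As a simplified sketch (the `Legality` enum and the `conversionSucceeds` helper are hypothetical, not MLIR APIs): a partial conversion tolerates operations whose legality the target does not specify, while a full conversion, like the one used here, fails if any remaining operation is not legal.

```c++
#include <cassert>
#include <vector>

// Hypothetical legality classification of an op remaining after rewriting.
enum class Legality { Legal, Illegal, Unknown };

// Decide whether a conversion succeeded given the ops still in the IR.
//  - Full conversion: every remaining op must be Legal.
//  - Partial conversion: only explicitly Illegal ops cause failure.
bool conversionSucceeds(const std::vector<Legality> &remaining, bool full) {
  for (Legality l : remaining) {
    if (l == Legality::Illegal)
      return false;
    if (full && l == Legality::Unknown)
      return false;
  }
  return true;
}
```

This is why the full conversion is the right choice before exporting to LLVM IR: any leftover non-LLVM operation would make the export impossible, so it is better to fail at this point.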

Looking back at our current working example:

```mlir
func @main() {
  %0 = toy.constant dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
  %2 = toy.transpose(%0 : tensor<2x3xf64>) to tensor<3x2xf64>
  %3 = toy.mul %2, %2 : tensor<3x2xf64>
  toy.print %3 : tensor<3x2xf64>
  toy.return
}
```

We can now lower down to the LLVM dialect, which produces the following code:

```mlir
llvm.func @free(!llvm<"i8*">)
llvm.func @printf(!llvm<"i8*">, ...) -> !llvm.i32
llvm.func @malloc(!llvm.i64) -> !llvm<"i8*">
llvm.func @main() {
  %0 = llvm.mlir.constant(1.000000e+00 : f64) : !llvm.double
  %1 = llvm.mlir.constant(2.000000e+00 : f64) : !llvm.double

  ...

^bb16:
  %221 = llvm.extractvalue %25[0 : index] : !llvm<"{ double*, i64, [2 x i64], [2 x i64] }">
  %222 = llvm.mlir.constant(0 : index) : !llvm.i64
  %223 = llvm.mlir.constant(2 : index) : !llvm.i64
  %224 = llvm.mul %214, %223 : !llvm.i64
  %225 = llvm.add %222, %224 : !llvm.i64
  %226 = llvm.mlir.constant(1 : index) : !llvm.i64
  %227 = llvm.mul %219, %226 : !llvm.i64
  %228 = llvm.add %225, %227 : !llvm.i64
  %229 = llvm.getelementptr %221[%228] : (!llvm<"double*">, !llvm.i64) -> !llvm<"double*">
  %230 = llvm.load %229 : !llvm<"double*">
  %231 = llvm.call @printf(%207, %230) : (!llvm<"i8*">, !llvm.double) -> !llvm.i32
  %232 = llvm.add %219, %218 : !llvm.i64
  llvm.br ^bb15(%232 : !llvm.i64)

  ...

^bb18:
  %235 = llvm.extractvalue %65[0 : index] : !llvm<"{ double*, i64, [2 x i64], [2 x i64] }">
  %236 = llvm.bitcast %235 : !llvm<"double*"> to !llvm<"i8*">
  llvm.call @free(%236) : (!llvm<"i8*">) -> ()
  %237 = llvm.extractvalue %45[0 : index] : !llvm<"{ double*, i64, [2 x i64], [2 x i64] }">
  %238 = llvm.bitcast %237 : !llvm<"double*"> to !llvm<"i8*">
  llvm.call @free(%238) : (!llvm<"i8*">) -> ()
  %239 = llvm.extractvalue %25[0 : index] : !llvm<"{ double*, i64, [2 x i64], [2 x i64] }">
  %240 = llvm.bitcast %239 : !llvm<"double*"> to !llvm<"i8*">
  llvm.call @free(%240) : (!llvm<"i8*">) -> ()
  llvm.return
}
```
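
The arithmetic in `^bb16` is the strided address computation for element `(i, j)` of the 3x2 result buffer: it starts from the offset (`%222`, here `0`), adds `i * 2` and `j * 1` (the row and column strides), and feeds the sum into `getelementptr` before the value is loaded and printed. Written as plain C++, with hypothetical names `linearIndex` and `loadElement`, the same computation is:

```c++
#include <cassert>
#include <cstdint>

// Linearized index of element (i, j) in a strided 2-D buffer, mirroring the
// llvm.mul / llvm.add sequence in ^bb16 (offset 0 and strides {2, 1} there).
int64_t linearIndex(int64_t offset, const int64_t strides[2], int64_t i,
                    int64_t j) {
  return offset + i * strides[0] + j * strides[1];
}

// Counterpart of the getelementptr + load pair: read element (i, j).
double loadElement(const double *data, int64_t offset,
                   const int64_t strides[2], int64_t i, int64_t j) {
  assert(i >= 0 && j >= 0);
  return data[linearIndex(offset, strides, i, j)];
}
```

Keeping the offset and strides as runtime values is what lets the same lowered code handle non-contiguous views; for the contiguous buffers in this example, LLVM's optimizer can fold them away.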

See [Conversion to the LLVM IR Dialect](../../ConversionToLLVMDialect.md) for
more in-depth details on lowering to the LLVM dialect.

## CodeGen: Getting Out of MLIR

At this point, we are right at the cusp of code generation. We can generate code
in the LLVM dialect, so now we just need to export to LLVM IR and set up a JIT to
run it.

### Emitting LLVM IR

Now that our module consists only of operations in the LLVM dialect, we can
export it to LLVM IR. To do this programmatically, we can invoke the following
utility:

```c++
  llvm::LLVMContext llvmContext;
  std::unique_ptr<llvm::Module> llvmModule =
      mlir::translateModuleToLLVMIR(module, llvmContext);
  if (!llvmModule) {
    /* ... an error was encountered ... */
  }
```

Exporting our module to LLVM IR generates:

```llvm
define void @main() {
  ...

102:
  %103 = extractvalue { double*, i64, [2 x i64], [2 x i64] } %8, 0
  %104 = mul i64 %96, 2
  %105 = add i64 0, %104
  %106 = mul i64 %100, 1
  %107 = add i64 %105, %106
  %108 = getelementptr double, double* %103, i64 %107
  %109 = load double, double* %108
  %110 = call i32 (i8*, ...) @printf(i8* getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double %109)
  %111 = add i64 %100, 1
  br label %99

  ...

115:
  %116 = extractvalue { double*, i64, [2 x i64], [2 x i64] } %24, 0
  %117 = bitcast double* %116 to i8*
  call void @free(i8* %117)
  %118 = extractvalue { double*, i64, [2 x i64], [2 x i64] } %16, 0
  %119 = bitcast double* %118 to i8*
  call void @free(i8* %119)
  %120 = extractvalue { double*, i64, [2 x i64], [2 x i64] } %8, 0
  %121 = bitcast double* %120 to i8*
  call void @free(i8* %121)
  ret void
}
```

If we enable optimization on the generated LLVM IR, we can trim this down quite
a bit:

```llvm
define void @main() {
  %0 = tail call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double 1.000000e+00)
  %1 = tail call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double 1.600000e+01)
  %putchar = tail call i32 @putchar(i32 10)
  %2 = tail call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double 4.000000e+00)
  %3 = tail call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double 2.500000e+01)
  %putchar.1 = tail call i32 @putchar(i32 10)
  %4 = tail call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double 9.000000e+00)
  %5 = tail call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double 3.600000e+01)
  %putchar.2 = tail call i32 @putchar(i32 10)
  ret void
}
```
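
The optimized IR has folded the whole computation into constants: the six `printf` calls print the elements of `transpose(a) * transpose(a)` for the 2x3 input `a`, row by row, with a `putchar(10)` (a newline) after each row of two. The sketch below recomputes those values in plain C++ so the folded constants are easy to check; `transposeSquare` is a hypothetical helper, not part of the Toy compiler.

```c++
#include <cassert>
#include <cstddef>
#include <vector>

// Compute transpose(a) * transpose(a) (element-wise) for a row-major
// rows x cols matrix `a`, returning the cols x rows result in row-major order.
std::vector<double> transposeSquare(const std::vector<double> &a,
                                    std::size_t rows, std::size_t cols) {
  assert(a.size() == rows * cols);
  std::vector<double> result(rows * cols);
  for (std::size_t i = 0; i < cols; ++i) {    // rows of the transposed matrix
    for (std::size_t j = 0; j < rows; ++j) {  // columns of the transposed matrix
      double t = a[j * cols + i];             // transpose(a)[i][j] == a[j][i]
      result[i * rows + j] = t * t;           // element-wise multiply by itself
    }
  }
  return result;
}
```

For `a = [[1, 2, 3], [4, 5, 6]]`, the result in row-major order is `1, 16, 4, 25, 9, 36`, the same sequence of constants passed to `printf` above.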

The full code listing for dumping LLVM IR can be found in
`examples/toy/Ch6/toyc.cpp` in the `dumpLLVMIR()` function:

```c++
int dumpLLVMIR(mlir::ModuleOp module) {
  // Translate the module, which contains the LLVM dialect, to LLVM IR. Use a
  // fresh LLVM IR context. (Note that LLVM is not thread-safe and any
  // concurrent use of a context requires external locking.)
  llvm::LLVMContext llvmContext;
  auto llvmModule = mlir::translateModuleToLLVMIR(module, llvmContext);
  if (!llvmModule) {
    llvm::errs() << "Failed to emit LLVM IR\n";
    return -1;
  }

  // Initialize LLVM targets.
  llvm::InitializeNativeTarget();
  llvm::InitializeNativeTargetAsmPrinter();
  mlir::ExecutionEngine::setupTargetTriple(llvmModule.get());

  // Optionally run an optimization pipeline over the LLVM module.
  auto optPipeline = mlir::makeOptimizingTransformer(
      /*optLevel=*/EnableOpt ? 3 : 0, /*sizeLevel=*/0,
      /*targetMachine=*/nullptr);
  if (auto err = optPipeline(llvmModule.get())) {
    llvm::errs() << "Failed to optimize LLVM IR " << err << "\n";
    return -1;
  }
  llvm::errs() << *llvmModule << "\n";
  return 0;
}
```

### Setting up a JIT

Setting up a JIT to run the module containing the LLVM dialect can be done using
the `mlir::ExecutionEngine` infrastructure. This is a utility wrapper around
LLVM's JIT that takes an MLIR module in the LLVM dialect as input. The full code
listing for setting up the JIT can be found in `Ch6/toyc.cpp` in the `runJit()`
function:

```c++
int runJit(mlir::ModuleOp module) {
  // Initialize LLVM targets.
  llvm::InitializeNativeTarget();
  llvm::InitializeNativeTargetAsmPrinter();

  // An optimization pipeline to use within the execution engine.
  auto optPipeline = mlir::makeOptimizingTransformer(
      /*optLevel=*/EnableOpt ? 3 : 0, /*sizeLevel=*/0,
      /*targetMachine=*/nullptr);

  // Create an MLIR execution engine. The execution engine eagerly JIT-compiles
  // the module.
  auto maybeEngine = mlir::ExecutionEngine::create(module, optPipeline);
  assert(maybeEngine && "failed to construct an execution engine");
  auto &engine = maybeEngine.get();

  // Invoke the JIT-compiled function.
  auto invocationResult = engine->invoke("main");
  if (invocationResult) {
    llvm::errs() << "JIT invocation failed\n";
    return -1;
  }

  return 0;
}
```

You can play around with it from the build directory:

```shell
$ echo 'def main() { print([[1, 2], [3, 4]]); }' | ./bin/toyc-ch6 -emit=jit
1.000000 2.000000
3.000000 4.000000
```

You can also play with `-emit=mlir`, `-emit=mlir-affine`, `-emit=mlir-llvm`, and
`-emit=llvm` to compare the various levels of IR involved. Also try options like
[`--print-ir-after-all`](../../PassManagement.md#ir-printing) to track the
evolution of the IR throughout the pipeline.

The example code used throughout this section can be found in
`test/Examples/Toy/Ch6/llvm-lowering.mlir`.

So far, we have worked with primitive data types. In the
[next chapter](Ch-7.md), we will add a composite `struct` type.