# Chapter 6: Lowering to LLVM and Code Generation

[TOC]

In the [previous chapter](Ch-5.md), we introduced the
[dialect conversion](../../DialectConversion.md) framework and partially lowered
many of the `Toy` operations to affine loop nests for optimization. In this
chapter, we will finally lower to LLVM for code generation.

## Lowering to LLVM

For this lowering, we will again use the dialect conversion framework to perform
the heavy lifting. However, this time, we will be performing a full conversion
to the [LLVM dialect](../../Dialects/LLVM.md). Thankfully, we have already
lowered all but one of the `toy` operations, with the last being `toy.print`.
Before going over the conversion to LLVM, let's lower the `toy.print` operation.
We will lower this operation to a non-affine loop nest that invokes `printf` for
each element. Note that, because the dialect conversion framework supports
[transitive lowering](../../../getting_started/Glossary.md#transitive-lowering), we don't need to
directly emit operations in the LLVM dialect. By transitive lowering, we mean
that the conversion framework may apply multiple patterns to fully legalize an
operation. In this example, we are generating a structured loop nest instead of
the branch-form in the LLVM dialect. As long as we then have a lowering from the
loop operations to LLVM, the lowering will still succeed.

During lowering we can get, or build, the declaration for `printf` as follows:

```c++
/// Return a symbol reference to the printf function, inserting it into the
/// module if necessary.
static FlatSymbolRefAttr getOrInsertPrintf(PatternRewriter &rewriter,
                                           ModuleOp module,
                                           LLVM::LLVMDialect *llvmDialect) {
  auto *context = module.getContext();
  if (module.lookupSymbol<LLVM::LLVMFuncOp>("printf"))
    return SymbolRefAttr::get("printf", context);

  // Create a function declaration for printf, the signature is:
  //   * `i32 (i8*, ...)`
  auto llvmI32Ty = LLVM::LLVMType::getInt32Ty(llvmDialect);
  auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
  auto llvmFnType = LLVM::LLVMType::getFunctionTy(llvmI32Ty, llvmI8PtrTy,
                                                  /*isVarArg=*/true);

  // Insert the printf function into the body of the parent module.
  PatternRewriter::InsertionGuard insertGuard(rewriter);
  rewriter.setInsertionPointToStart(module.getBody());
  rewriter.create<LLVM::LLVMFuncOp>(module.getLoc(), "printf", llvmFnType);
  return SymbolRefAttr::get("printf", context);
}
```

Now that the lowering for the `toy.print` operation has been defined, we can
specify the components necessary for the lowering. These are largely the same as
the components defined in the [previous chapter](Ch-5.md).

### Conversion Target

For this conversion, aside from the top-level module, we will be lowering
everything to the LLVM dialect.

```c++
  mlir::ConversionTarget target(getContext());
  target.addLegalDialect<mlir::LLVM::LLVMDialect>();
  target.addLegalOp<mlir::ModuleOp, mlir::ModuleTerminatorOp>();
```

### Type Converter

This lowering will also transform the MemRef types which are currently being
operated on into a representation in LLVM. To perform this conversion, we use a
TypeConverter as part of the lowering. This converter specifies how one type
maps to another. This is necessary now that we are performing more complicated
lowerings involving block arguments. Given that we don't have any
Toy-dialect-specific types that need to be lowered, the default converter is
enough for our use case.
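
To make the MemRef-to-LLVM mapping more concrete, here is a minimal plain C++
sketch of the struct type that shows up in the lowered IR in this chapter,
`{ double*, i64, [2 x i64], [2 x i64] }`, together with the index arithmetic
the lowered loads perform. The struct, its field names, and both helper
functions are hypothetical and are not part of MLIR's API. Reading the fields as
pointer, offset, sizes, and strides (in that order) is an assumption based on
the shape of the struct and on the `mul`/`add`/`getelementptr` sequence in the
generated code.

```c++
#include <array>
#include <cassert>
#include <cstdint>

// Hypothetical mirror of `{ double*, i64, [2 x i64], [2 x i64] }`.
// Field names and meanings are assumptions made for illustration.
struct MemRef2DDescriptor {
  double *data;                         // start of the element buffer
  std::int64_t offset;                  // offset of the first element, in elements
  std::array<std::int64_t, 2> sizes;    // extent of each dimension
  std::array<std::int64_t, 2> strides;  // step, in elements, along each dimension
};

// Build a descriptor for a densely packed, row-major rows x cols buffer.
inline MemRef2DDescriptor makeRowMajor(double *data, std::int64_t rows,
                                       std::int64_t cols) {
  return MemRef2DDescriptor{data, 0, {rows, cols}, {cols, 1}};
}

// Compute offset + row * strides[0] + col * strides[1], the same shape of
// arithmetic as the mul/add chain that precedes each getelementptr.
inline std::int64_t linearIndex(const MemRef2DDescriptor &m, std::int64_t row,
                                std::int64_t col) {
  assert(row >= 0 && row < m.sizes[0] && "row out of bounds");
  assert(col >= 0 && col < m.sizes[1] && "column out of bounds");
  return m.offset + row * m.strides[0] + col * m.strides[1];
}

// Load one element through the descriptor.
inline double loadElement(const MemRef2DDescriptor &m, std::int64_t row,
                          std::int64_t col) {
  return m.data[linearIndex(m, row, col)];
}
```

For a 3x2 buffer, `makeRowMajor(buffer, 3, 2)` yields strides `{2, 1}`, so
element `(2, 1)` is read from linear index `0 + 2 * 2 + 1 * 1 = 5`. The
constants `2` and `1` in that expression correspond to the multipliers that
appear in the lowered IR for the `3x2` result.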

```c++
  LLVMTypeConverter typeConverter(&getContext());
```

### Conversion Patterns

Now that the conversion target has been defined, we need to provide the patterns
used for lowering. At this point in the compilation process, we have a
combination of `toy`, `affine`, and `std` operations. Luckily, the `std` and
`affine` dialects already provide the set of patterns needed to transform them
into the LLVM dialect. These patterns allow for lowering the IR in multiple stages
by relying on [transitive lowering](../../../getting_started/Glossary.md#transitive-lowering).

```c++
  mlir::OwningRewritePatternList patterns;
  mlir::populateAffineToStdConversionPatterns(patterns, &getContext());
  mlir::populateLoopToStdConversionPatterns(patterns, &getContext());
  mlir::populateStdToLLVMConversionPatterns(typeConverter, patterns);

  // The only remaining operation to lower from the `toy` dialect is the
  // PrintOp.
  patterns.insert<PrintOpLowering>(&getContext());
```

### Full Lowering

We want to completely lower to LLVM, so we use a `FullConversion`. This ensures
that only legal operations will remain after the conversion.

```c++
  mlir::ModuleOp module = getOperation();
  if (mlir::failed(mlir::applyFullConversion(module, target, patterns)))
    signalPassFailure();
```

Looking back at our current working example:

```mlir
func @main() {
  %0 = toy.constant dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
  %2 = toy.transpose(%0 : tensor<2x3xf64>) to tensor<3x2xf64>
  %3 = toy.mul %2, %2 : tensor<3x2xf64>
  toy.print %3 : tensor<3x2xf64>
  toy.return
}
```

We can now lower down to the LLVM dialect, which produces the following code:

```mlir
llvm.func @free(!llvm<"i8*">)
llvm.func @printf(!llvm<"i8*">, ...) -> !llvm.i32
llvm.func @malloc(!llvm.i64) -> !llvm<"i8*">
llvm.func @main() {
  %0 = llvm.mlir.constant(1.000000e+00 : f64) : !llvm.double
  %1 = llvm.mlir.constant(2.000000e+00 : f64) : !llvm.double

  ...

^bb16:
  %221 = llvm.extractvalue %25[0 : index] : !llvm<"{ double*, i64, [2 x i64], [2 x i64] }">
  %222 = llvm.mlir.constant(0 : index) : !llvm.i64
  %223 = llvm.mlir.constant(2 : index) : !llvm.i64
  %224 = llvm.mul %214, %223 : !llvm.i64
  %225 = llvm.add %222, %224 : !llvm.i64
  %226 = llvm.mlir.constant(1 : index) : !llvm.i64
  %227 = llvm.mul %219, %226 : !llvm.i64
  %228 = llvm.add %225, %227 : !llvm.i64
  %229 = llvm.getelementptr %221[%228] : (!llvm<"double*">, !llvm.i64) -> !llvm<"double*">
  %230 = llvm.load %229 : !llvm<"double*">
  %231 = llvm.call @printf(%207, %230) : (!llvm<"i8*">, !llvm.double) -> !llvm.i32
  %232 = llvm.add %219, %218 : !llvm.i64
  llvm.br ^bb15(%232 : !llvm.i64)

  ...

^bb18:
  %235 = llvm.extractvalue %65[0 : index] : !llvm<"{ double*, i64, [2 x i64], [2 x i64] }">
  %236 = llvm.bitcast %235 : !llvm<"double*"> to !llvm<"i8*">
  llvm.call @free(%236) : (!llvm<"i8*">) -> ()
  %237 = llvm.extractvalue %45[0 : index] : !llvm<"{ double*, i64, [2 x i64], [2 x i64] }">
  %238 = llvm.bitcast %237 : !llvm<"double*"> to !llvm<"i8*">
  llvm.call @free(%238) : (!llvm<"i8*">) -> ()
  %239 = llvm.extractvalue %25[0 : index] : !llvm<"{ double*, i64, [2 x i64], [2 x i64] }">
  %240 = llvm.bitcast %239 : !llvm<"double*"> to !llvm<"i8*">
  llvm.call @free(%240) : (!llvm<"i8*">) -> ()
  llvm.return
}
```

See [Conversion to the LLVM IR Dialect](../../ConversionToLLVMDialect.md) for
more in-depth details on lowering to the LLVM dialect.

## CodeGen: Getting Out of MLIR

At this point we are right at the cusp of code generation.
We can generate code
in the LLVM dialect, so now we just need to export to LLVM IR and set up a JIT to
run it.

### Emitting LLVM IR

Now that our module consists only of operations in the LLVM dialect, we can
export to LLVM IR. To do this programmatically, we can invoke the following
utility:

```c++
  std::unique_ptr<llvm::Module> llvmModule = mlir::translateModuleToLLVMIR(module);
  if (!llvmModule)
    /* ... an error was encountered ... */
```

Exporting our module to LLVM IR generates:

```llvm
define void @main() {
  ...

102:
  %103 = extractvalue { double*, i64, [2 x i64], [2 x i64] } %8, 0
  %104 = mul i64 %96, 2
  %105 = add i64 0, %104
  %106 = mul i64 %100, 1
  %107 = add i64 %105, %106
  %108 = getelementptr double, double* %103, i64 %107
  %109 = load double, double* %108
  %110 = call i32 (i8*, ...) @printf(i8* getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double %109)
  %111 = add i64 %100, 1
  br label %99

  ...

115:
  %116 = extractvalue { double*, i64, [2 x i64], [2 x i64] } %24, 0
  %117 = bitcast double* %116 to i8*
  call void @free(i8* %117)
  %118 = extractvalue { double*, i64, [2 x i64], [2 x i64] } %16, 0
  %119 = bitcast double* %118 to i8*
  call void @free(i8* %119)
  %120 = extractvalue { double*, i64, [2 x i64], [2 x i64] } %8, 0
  %121 = bitcast double* %120 to i8*
  call void @free(i8* %121)
  ret void
}
```

If we enable optimization on the generated LLVM IR, we can trim this down quite
a bit:

```llvm
define void @main() {
  %0 = tail call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double 1.000000e+00)
  %1 = tail call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double 1.600000e+01)
  %putchar = tail call i32 @putchar(i32 10)
  %2 = tail call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double 4.000000e+00)
  %3 = tail call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double 2.500000e+01)
  %putchar.1 = tail call i32 @putchar(i32 10)
  %4 = tail call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double 9.000000e+00)
  %5 = tail call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double 3.600000e+01)
  %putchar.2 = tail call i32 @putchar(i32 10)
  ret void
}
```

The full code listing for dumping LLVM IR can be found in
`examples/toy/Ch6/toyc.cpp` in the `dumpLLVMIR()` function:

```c++
int dumpLLVMIR(mlir::ModuleOp module) {
  // Translate the module, which contains the LLVM dialect, to LLVM IR. Use a
  // fresh LLVM IR context. (Note that LLVM is not thread-safe and any
  // concurrent use of a context requires external locking.)
  llvm::LLVMContext llvmContext;
  auto llvmModule = mlir::translateModuleToLLVMIR(module, llvmContext);
  if (!llvmModule) {
    llvm::errs() << "Failed to emit LLVM IR\n";
    return -1;
  }

  // Initialize LLVM targets.
  llvm::InitializeNativeTarget();
  llvm::InitializeNativeTargetAsmPrinter();
  mlir::ExecutionEngine::setupTargetTriple(llvmModule.get());

  /// Optionally run an optimization pipeline over the llvm module.
  auto optPipeline = mlir::makeOptimizingTransformer(
      /*optLevel=*/EnableOpt ? 3 : 0, /*sizeLevel=*/0,
      /*targetMachine=*/nullptr);
  if (auto err = optPipeline(llvmModule.get())) {
    llvm::errs() << "Failed to optimize LLVM IR " << err << "\n";
    return -1;
  }
  llvm::errs() << *llvmModule << "\n";
  return 0;
}
```

### Setting up a JIT

Setting up a JIT to run the module containing the LLVM dialect can be done using
the `mlir::ExecutionEngine` infrastructure. This is a utility wrapper around
LLVM's JIT that accepts `.mlir` as input. The full code listing for setting up
the JIT can be found in `examples/toy/Ch6/toyc.cpp` in the `runJit()` function:

```c++
int runJit(mlir::ModuleOp module) {
  // Initialize LLVM targets.
  llvm::InitializeNativeTarget();
  llvm::InitializeNativeTargetAsmPrinter();

  // An optimization pipeline to use within the execution engine.
  auto optPipeline = mlir::makeOptimizingTransformer(
      /*optLevel=*/EnableOpt ? 3 : 0, /*sizeLevel=*/0,
      /*targetMachine=*/nullptr);

  // Create an MLIR execution engine. The execution engine eagerly JIT-compiles
  // the module.
  auto maybeEngine = mlir::ExecutionEngine::create(module, optPipeline);
  assert(maybeEngine && "failed to construct an execution engine");
  auto &engine = maybeEngine.get();

  // Invoke the JIT-compiled function.
  auto invocationResult = engine->invoke("main");
  if (invocationResult) {
    llvm::errs() << "JIT invocation failed\n";
    return -1;
  }

  return 0;
}
```

You can play around with it from the build directory:

```shell
$ echo 'def main() { print([[1, 2], [3, 4]]); }' | ./bin/toyc-ch6 -emit=jit
1.000000 2.000000
3.000000 4.000000
```

You can also play with `-emit=mlir`, `-emit=mlir-affine`, `-emit=mlir-llvm`, and
`-emit=llvm` to compare the various levels of IR involved.
Also try options like
[`--print-ir-after-all`](../../PassManagement.md#ir-printing) to track the
evolution of the IR throughout the pipeline.

The example code used throughout this section can be found in
`test/Examples/Toy/Ch6/llvm-lowering.mlir`.

So far, we have worked with primitive data types. In the
[next chapter](Ch-7.md), we will add a composite `struct` type.