1 //====- LowerToLLVM.cpp - Lowering from Toy+Affine+Std to LLVM ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements full lowering of Toy operations to LLVM MLIR dialect.
10 // 'toy.print' is lowered to a loop nest that calls `printf` on each element of
11 // the input array. The file also sets up the ToyToLLVMLoweringPass. This pass
12 // lowers the combination of Affine + SCF + Standard dialects to the LLVM one:
13 //
14 //                         Affine --
15 //                                  |
16 //                                  v
17 //                                  Standard --> LLVM (Dialect)
18 //                                  ^
19 //                                  |
20 //     'toy.print' --> Loop (SCF) --
21 //
22 //===----------------------------------------------------------------------===//
23 
24 #include "toy/Dialect.h"
25 #include "toy/Passes.h"
26 
27 #include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
28 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
29 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
30 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
31 #include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
32 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
33 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
34 #include "mlir/Dialect/Affine/IR/AffineOps.h"
35 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
36 #include "mlir/Dialect/MemRef/IR/MemRef.h"
37 #include "mlir/Dialect/SCF/SCF.h"
38 #include "mlir/Dialect/StandardOps/IR/Ops.h"
39 #include "mlir/Pass/Pass.h"
40 #include "mlir/Transforms/DialectConversion.h"
41 #include "llvm/ADT/Sequence.h"
42 
43 using namespace mlir;
44 
45 //===----------------------------------------------------------------------===//
46 // ToyToLLVM RewritePatterns
47 //===----------------------------------------------------------------------===//
48 
49 namespace {
50 /// Lowers `toy.print` to a loop nest calling `printf` on each of the individual
51 /// elements of the array.
52 class PrintOpLowering : public ConversionPattern {
53 public:
PrintOpLowering(MLIRContext * context)54   explicit PrintOpLowering(MLIRContext *context)
55       : ConversionPattern(toy::PrintOp::getOperationName(), 1, context) {}
56 
57   LogicalResult
matchAndRewrite(Operation * op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const58   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
59                   ConversionPatternRewriter &rewriter) const override {
60     auto memRefType = (*op->operand_type_begin()).cast<MemRefType>();
61     auto memRefShape = memRefType.getShape();
62     auto loc = op->getLoc();
63 
64     ModuleOp parentModule = op->getParentOfType<ModuleOp>();
65 
66     // Get a symbol reference to the printf function, inserting it if necessary.
67     auto printfRef = getOrInsertPrintf(rewriter, parentModule);
68     Value formatSpecifierCst = getOrCreateGlobalString(
69         loc, rewriter, "frmt_spec", StringRef("%f \0", 4), parentModule);
70     Value newLineCst = getOrCreateGlobalString(
71         loc, rewriter, "nl", StringRef("\n\0", 2), parentModule);
72 
73     // Create a loop for each of the dimensions within the shape.
74     SmallVector<Value, 4> loopIvs;
75     for (unsigned i = 0, e = memRefShape.size(); i != e; ++i) {
76       auto lowerBound = rewriter.create<ConstantIndexOp>(loc, 0);
77       auto upperBound = rewriter.create<ConstantIndexOp>(loc, memRefShape[i]);
78       auto step = rewriter.create<ConstantIndexOp>(loc, 1);
79       auto loop =
80           rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
81       for (Operation &nested : *loop.getBody())
82         rewriter.eraseOp(&nested);
83       loopIvs.push_back(loop.getInductionVar());
84 
85       // Terminate the loop body.
86       rewriter.setInsertionPointToEnd(loop.getBody());
87 
88       // Insert a newline after each of the inner dimensions of the shape.
89       if (i != e - 1)
90         rewriter.create<CallOp>(loc, printfRef, rewriter.getIntegerType(32),
91                                 newLineCst);
92       rewriter.create<scf::YieldOp>(loc);
93       rewriter.setInsertionPointToStart(loop.getBody());
94     }
95 
96     // Generate a call to printf for the current element of the loop.
97     auto printOp = cast<toy::PrintOp>(op);
98     auto elementLoad =
99         rewriter.create<memref::LoadOp>(loc, printOp.input(), loopIvs);
100     rewriter.create<CallOp>(loc, printfRef, rewriter.getIntegerType(32),
101                             ArrayRef<Value>({formatSpecifierCst, elementLoad}));
102 
103     // Notify the rewriter that this operation has been removed.
104     rewriter.eraseOp(op);
105     return success();
106   }
107 
108 private:
109   /// Return a symbol reference to the printf function, inserting it into the
110   /// module if necessary.
getOrInsertPrintf(PatternRewriter & rewriter,ModuleOp module)111   static FlatSymbolRefAttr getOrInsertPrintf(PatternRewriter &rewriter,
112                                              ModuleOp module) {
113     auto *context = module.getContext();
114     if (module.lookupSymbol<LLVM::LLVMFuncOp>("printf"))
115       return SymbolRefAttr::get(context, "printf");
116 
117     // Create a function declaration for printf, the signature is:
118     //   * `i32 (i8*, ...)`
119     auto llvmI32Ty = IntegerType::get(context, 32);
120     auto llvmI8PtrTy = LLVM::LLVMPointerType::get(IntegerType::get(context, 8));
121     auto llvmFnType = LLVM::LLVMFunctionType::get(llvmI32Ty, llvmI8PtrTy,
122                                                   /*isVarArg=*/true);
123 
124     // Insert the printf function into the body of the parent module.
125     PatternRewriter::InsertionGuard insertGuard(rewriter);
126     rewriter.setInsertionPointToStart(module.getBody());
127     rewriter.create<LLVM::LLVMFuncOp>(module.getLoc(), "printf", llvmFnType);
128     return SymbolRefAttr::get(context, "printf");
129   }
130 
131   /// Return a value representing an access into a global string with the given
132   /// name, creating the string if necessary.
getOrCreateGlobalString(Location loc,OpBuilder & builder,StringRef name,StringRef value,ModuleOp module)133   static Value getOrCreateGlobalString(Location loc, OpBuilder &builder,
134                                        StringRef name, StringRef value,
135                                        ModuleOp module) {
136     // Create the global at the entry of the module.
137     LLVM::GlobalOp global;
138     if (!(global = module.lookupSymbol<LLVM::GlobalOp>(name))) {
139       OpBuilder::InsertionGuard insertGuard(builder);
140       builder.setInsertionPointToStart(module.getBody());
141       auto type = LLVM::LLVMArrayType::get(
142           IntegerType::get(builder.getContext(), 8), value.size());
143       global = builder.create<LLVM::GlobalOp>(loc, type, /*isConstant=*/true,
144                                               LLVM::Linkage::Internal, name,
145                                               builder.getStringAttr(value),
146                                               /*alignment=*/0);
147     }
148 
149     // Get the pointer to the first character in the global string.
150     Value globalPtr = builder.create<LLVM::AddressOfOp>(loc, global);
151     Value cst0 = builder.create<LLVM::ConstantOp>(
152         loc, IntegerType::get(builder.getContext(), 64),
153         builder.getIntegerAttr(builder.getIndexType(), 0));
154     return builder.create<LLVM::GEPOp>(
155         loc,
156         LLVM::LLVMPointerType::get(IntegerType::get(builder.getContext(), 8)),
157         globalPtr, ArrayRef<Value>({cst0, cst0}));
158   }
159 };
160 } // end anonymous namespace
161 
162 //===----------------------------------------------------------------------===//
163 // ToyToLLVMLoweringPass
164 //===----------------------------------------------------------------------===//
165 
166 namespace {
167 struct ToyToLLVMLoweringPass
168     : public PassWrapper<ToyToLLVMLoweringPass, OperationPass<ModuleOp>> {
getDependentDialects__anon7a11ba050211::ToyToLLVMLoweringPass169   void getDependentDialects(DialectRegistry &registry) const override {
170     registry.insert<LLVM::LLVMDialect, scf::SCFDialect>();
171   }
172   void runOnOperation() final;
173 };
174 } // end anonymous namespace
175 
runOnOperation()176 void ToyToLLVMLoweringPass::runOnOperation() {
177   // The first thing to define is the conversion target. This will define the
178   // final target for this lowering. For this lowering, we are only targeting
179   // the LLVM dialect.
180   LLVMConversionTarget target(getContext());
181   target.addLegalOp<ModuleOp>();
182 
183   // During this lowering, we will also be lowering the MemRef types, that are
184   // currently being operated on, to a representation in LLVM. To perform this
185   // conversion we use a TypeConverter as part of the lowering. This converter
186   // details how one type maps to another. This is necessary now that we will be
187   // doing more complicated lowerings, involving loop region arguments.
188   LLVMTypeConverter typeConverter(&getContext());
189 
190   // Now that the conversion target has been defined, we need to provide the
191   // patterns used for lowering. At this point of the compilation process, we
192   // have a combination of `toy`, `affine`, and `std` operations. Luckily, there
193   // are already exists a set of patterns to transform `affine` and `std`
194   // dialects. These patterns lowering in multiple stages, relying on transitive
195   // lowerings. Transitive lowering, or A->B->C lowering, is when multiple
196   // patterns must be applied to fully transform an illegal operation into a
197   // set of legal ones.
198   RewritePatternSet patterns(&getContext());
199   populateAffineToStdConversionPatterns(patterns);
200   populateLoopToStdConversionPatterns(patterns);
201   populateMemRefToLLVMConversionPatterns(typeConverter, patterns);
202   populateStdToLLVMConversionPatterns(typeConverter, patterns);
203 
204   // The only remaining operation to lower from the `toy` dialect, is the
205   // PrintOp.
206   patterns.add<PrintOpLowering>(&getContext());
207 
208   // We want to completely lower to LLVM, so we use a `FullConversion`. This
209   // ensures that only legal operations will remain after the conversion.
210   auto module = getOperation();
211   if (failed(applyFullConversion(module, target, std::move(patterns))))
212     signalPassFailure();
213 }
214 
215 /// Create a pass for lowering operations the remaining `Toy` operations, as
216 /// well as `Affine` and `Std`, to the LLVM dialect for codegen.
createLowerToLLVMPass()217 std::unique_ptr<mlir::Pass> mlir::toy::createLowerToLLVMPass() {
218   return std::make_unique<ToyToLLVMLoweringPass>();
219 }
220