1 //====- LowerToLLVM.cpp - Lowering from Toy+Affine+Std to LLVM ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements full lowering of Toy operations to LLVM MLIR dialect.
10 // 'toy.print' is lowered to a loop nest that calls `printf` on each element of
11 // the input array. The file also sets up the ToyToLLVMLoweringPass. This pass
12 // lowers the combination of Affine + SCF + Standard dialects to the LLVM one:
13 //
//                         Affine --
//                                  |
//                                  v
//                                  Standard --> LLVM (Dialect)
//                                  ^
//                                  |
//     'toy.print' --> Loop (SCF) --
21 //
22 //===----------------------------------------------------------------------===//
23
24 #include "toy/Dialect.h"
25 #include "toy/Passes.h"
26
27 #include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
28 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
29 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
30 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
31 #include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
32 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
33 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
34 #include "mlir/Dialect/Affine/IR/AffineOps.h"
35 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
36 #include "mlir/Dialect/MemRef/IR/MemRef.h"
37 #include "mlir/Dialect/SCF/SCF.h"
38 #include "mlir/Dialect/StandardOps/IR/Ops.h"
39 #include "mlir/Pass/Pass.h"
40 #include "mlir/Transforms/DialectConversion.h"
41 #include "llvm/ADT/Sequence.h"
42
43 using namespace mlir;
44
45 //===----------------------------------------------------------------------===//
46 // ToyToLLVM RewritePatterns
47 //===----------------------------------------------------------------------===//
48
49 namespace {
50 /// Lowers `toy.print` to a loop nest calling `printf` on each of the individual
51 /// elements of the array.
52 class PrintOpLowering : public ConversionPattern {
53 public:
PrintOpLowering(MLIRContext * context)54 explicit PrintOpLowering(MLIRContext *context)
55 : ConversionPattern(toy::PrintOp::getOperationName(), 1, context) {}
56
57 LogicalResult
matchAndRewrite(Operation * op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const58 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
59 ConversionPatternRewriter &rewriter) const override {
60 auto memRefType = (*op->operand_type_begin()).cast<MemRefType>();
61 auto memRefShape = memRefType.getShape();
62 auto loc = op->getLoc();
63
64 ModuleOp parentModule = op->getParentOfType<ModuleOp>();
65
66 // Get a symbol reference to the printf function, inserting it if necessary.
67 auto printfRef = getOrInsertPrintf(rewriter, parentModule);
68 Value formatSpecifierCst = getOrCreateGlobalString(
69 loc, rewriter, "frmt_spec", StringRef("%f \0", 4), parentModule);
70 Value newLineCst = getOrCreateGlobalString(
71 loc, rewriter, "nl", StringRef("\n\0", 2), parentModule);
72
73 // Create a loop for each of the dimensions within the shape.
74 SmallVector<Value, 4> loopIvs;
75 for (unsigned i = 0, e = memRefShape.size(); i != e; ++i) {
76 auto lowerBound = rewriter.create<ConstantIndexOp>(loc, 0);
77 auto upperBound = rewriter.create<ConstantIndexOp>(loc, memRefShape[i]);
78 auto step = rewriter.create<ConstantIndexOp>(loc, 1);
79 auto loop =
80 rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
81 for (Operation &nested : *loop.getBody())
82 rewriter.eraseOp(&nested);
83 loopIvs.push_back(loop.getInductionVar());
84
85 // Terminate the loop body.
86 rewriter.setInsertionPointToEnd(loop.getBody());
87
88 // Insert a newline after each of the inner dimensions of the shape.
89 if (i != e - 1)
90 rewriter.create<CallOp>(loc, printfRef, rewriter.getIntegerType(32),
91 newLineCst);
92 rewriter.create<scf::YieldOp>(loc);
93 rewriter.setInsertionPointToStart(loop.getBody());
94 }
95
96 // Generate a call to printf for the current element of the loop.
97 auto printOp = cast<toy::PrintOp>(op);
98 auto elementLoad =
99 rewriter.create<memref::LoadOp>(loc, printOp.input(), loopIvs);
100 rewriter.create<CallOp>(loc, printfRef, rewriter.getIntegerType(32),
101 ArrayRef<Value>({formatSpecifierCst, elementLoad}));
102
103 // Notify the rewriter that this operation has been removed.
104 rewriter.eraseOp(op);
105 return success();
106 }
107
108 private:
109 /// Return a symbol reference to the printf function, inserting it into the
110 /// module if necessary.
getOrInsertPrintf(PatternRewriter & rewriter,ModuleOp module)111 static FlatSymbolRefAttr getOrInsertPrintf(PatternRewriter &rewriter,
112 ModuleOp module) {
113 auto *context = module.getContext();
114 if (module.lookupSymbol<LLVM::LLVMFuncOp>("printf"))
115 return SymbolRefAttr::get(context, "printf");
116
117 // Create a function declaration for printf, the signature is:
118 // * `i32 (i8*, ...)`
119 auto llvmI32Ty = IntegerType::get(context, 32);
120 auto llvmI8PtrTy = LLVM::LLVMPointerType::get(IntegerType::get(context, 8));
121 auto llvmFnType = LLVM::LLVMFunctionType::get(llvmI32Ty, llvmI8PtrTy,
122 /*isVarArg=*/true);
123
124 // Insert the printf function into the body of the parent module.
125 PatternRewriter::InsertionGuard insertGuard(rewriter);
126 rewriter.setInsertionPointToStart(module.getBody());
127 rewriter.create<LLVM::LLVMFuncOp>(module.getLoc(), "printf", llvmFnType);
128 return SymbolRefAttr::get(context, "printf");
129 }
130
131 /// Return a value representing an access into a global string with the given
132 /// name, creating the string if necessary.
getOrCreateGlobalString(Location loc,OpBuilder & builder,StringRef name,StringRef value,ModuleOp module)133 static Value getOrCreateGlobalString(Location loc, OpBuilder &builder,
134 StringRef name, StringRef value,
135 ModuleOp module) {
136 // Create the global at the entry of the module.
137 LLVM::GlobalOp global;
138 if (!(global = module.lookupSymbol<LLVM::GlobalOp>(name))) {
139 OpBuilder::InsertionGuard insertGuard(builder);
140 builder.setInsertionPointToStart(module.getBody());
141 auto type = LLVM::LLVMArrayType::get(
142 IntegerType::get(builder.getContext(), 8), value.size());
143 global = builder.create<LLVM::GlobalOp>(loc, type, /*isConstant=*/true,
144 LLVM::Linkage::Internal, name,
145 builder.getStringAttr(value),
146 /*alignment=*/0);
147 }
148
149 // Get the pointer to the first character in the global string.
150 Value globalPtr = builder.create<LLVM::AddressOfOp>(loc, global);
151 Value cst0 = builder.create<LLVM::ConstantOp>(
152 loc, IntegerType::get(builder.getContext(), 64),
153 builder.getIntegerAttr(builder.getIndexType(), 0));
154 return builder.create<LLVM::GEPOp>(
155 loc,
156 LLVM::LLVMPointerType::get(IntegerType::get(builder.getContext(), 8)),
157 globalPtr, ArrayRef<Value>({cst0, cst0}));
158 }
159 };
160 } // end anonymous namespace
161
162 //===----------------------------------------------------------------------===//
163 // ToyToLLVMLoweringPass
164 //===----------------------------------------------------------------------===//
165
166 namespace {
167 struct ToyToLLVMLoweringPass
168 : public PassWrapper<ToyToLLVMLoweringPass, OperationPass<ModuleOp>> {
getDependentDialects__anon7a11ba050211::ToyToLLVMLoweringPass169 void getDependentDialects(DialectRegistry ®istry) const override {
170 registry.insert<LLVM::LLVMDialect, scf::SCFDialect>();
171 }
172 void runOnOperation() final;
173 };
174 } // end anonymous namespace
175
runOnOperation()176 void ToyToLLVMLoweringPass::runOnOperation() {
177 // The first thing to define is the conversion target. This will define the
178 // final target for this lowering. For this lowering, we are only targeting
179 // the LLVM dialect.
180 LLVMConversionTarget target(getContext());
181 target.addLegalOp<ModuleOp>();
182
183 // During this lowering, we will also be lowering the MemRef types, that are
184 // currently being operated on, to a representation in LLVM. To perform this
185 // conversion we use a TypeConverter as part of the lowering. This converter
186 // details how one type maps to another. This is necessary now that we will be
187 // doing more complicated lowerings, involving loop region arguments.
188 LLVMTypeConverter typeConverter(&getContext());
189
190 // Now that the conversion target has been defined, we need to provide the
191 // patterns used for lowering. At this point of the compilation process, we
192 // have a combination of `toy`, `affine`, and `std` operations. Luckily, there
193 // are already exists a set of patterns to transform `affine` and `std`
194 // dialects. These patterns lowering in multiple stages, relying on transitive
195 // lowerings. Transitive lowering, or A->B->C lowering, is when multiple
196 // patterns must be applied to fully transform an illegal operation into a
197 // set of legal ones.
198 RewritePatternSet patterns(&getContext());
199 populateAffineToStdConversionPatterns(patterns);
200 populateLoopToStdConversionPatterns(patterns);
201 populateMemRefToLLVMConversionPatterns(typeConverter, patterns);
202 populateStdToLLVMConversionPatterns(typeConverter, patterns);
203
204 // The only remaining operation to lower from the `toy` dialect, is the
205 // PrintOp.
206 patterns.add<PrintOpLowering>(&getContext());
207
208 // We want to completely lower to LLVM, so we use a `FullConversion`. This
209 // ensures that only legal operations will remain after the conversion.
210 auto module = getOperation();
211 if (failed(applyFullConversion(module, target, std::move(patterns))))
212 signalPassFailure();
213 }
214
215 /// Create a pass for lowering operations the remaining `Toy` operations, as
216 /// well as `Affine` and `Std`, to the LLVM dialect for codegen.
createLowerToLLVMPass()217 std::unique_ptr<mlir::Pass> mlir::toy::createLowerToLLVMPass() {
218 return std::make_unique<ToyToLLVMLoweringPass>();
219 }
220