//====- LowerToAffineLoops.cpp - Partial lowering from Toy to Affine+Std --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a partial lowering of Toy operations to a combination of
// affine loops, memref operations and standard operations. This lowering
// expects that all calls have been inlined, and all shapes have been resolved.
//
//===----------------------------------------------------------------------===//

#include "toy/Dialect.h"
#include "toy/Passes.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/Sequence.h"

using namespace mlir;

//===----------------------------------------------------------------------===//
// ToyToAffine RewritePatterns
//===----------------------------------------------------------------------===//

/// Convert the given TensorType into the corresponding MemRefType.
static MemRefType convertTensorToMemRef(TensorType type) {
  assert(type.hasRank() && "expected only ranked shapes");
  return MemRefType::get(type.getShape(), type.getElementType());
}
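
// For instance, 'tensor<2x3xf64>' is converted to 'memref<2x3xf64>'.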

/// Insert an allocation and deallocation for the given MemRefType.
static Value insertAllocAndDealloc(MemRefType type, Location loc,
                                   PatternRewriter &rewriter) {
  auto alloc = rewriter.create<memref::AllocOp>(loc, type);

  // Make sure to allocate at the beginning of the block.
  auto *parentBlock = alloc->getBlock();
  alloc->moveBefore(&parentBlock->front());

  // Make sure to deallocate this alloc at the end of the block. This is fine
  // as toy functions have no control flow.
  auto dealloc = rewriter.create<memref::DeallocOp>(loc, alloc);
  dealloc->moveBefore(&parentBlock->back());
  return alloc;
}
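
// After this runs, the surrounding block is shaped roughly like (a sketch for
// a 2x3 buffer, not verbatim output):
//
//   %0 = memref.alloc() : memref<2x3xf64>   // hoisted to the block start
//   ...
//   memref.dealloc %0 : memref<2x3xf64>     // moved before the block's last op
//   <last op>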

/// This defines the function type used to process an iteration of a lowered
/// loop. It takes as input an OpBuilder, a range of memRefOperands
/// corresponding to the operands of the input operation, and the range of loop
/// induction variables for the iteration. It returns a value to store at the
/// current index of the iteration.
using LoopIterationFn = function_ref<Value(
    OpBuilder &rewriter, ValueRange memRefOperands, ValueRange loopIvs)>;

static void lowerOpToLoops(Operation *op, ValueRange operands,
                           PatternRewriter &rewriter,
                           LoopIterationFn processIteration) {
  auto tensorType = (*op->result_type_begin()).cast<TensorType>();
  auto loc = op->getLoc();

  // Insert an allocation and deallocation for the result of this operation.
  auto memRefType = convertTensorToMemRef(tensorType);
  auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter);

  // Create a nest of affine loops, with one loop per dimension of the shape.
  // The buildAffineLoopNest function takes a callback that is used to construct
  // the body of the innermost loop given a builder, a location and a range of
  // loop induction variables.
  SmallVector<int64_t, 4> lowerBounds(tensorType.getRank(), /*Value=*/0);
  SmallVector<int64_t, 4> steps(tensorType.getRank(), /*Value=*/1);
  buildAffineLoopNest(
      rewriter, loc, lowerBounds, tensorType.getShape(), steps,
      [&](OpBuilder &nestedBuilder, Location loc, ValueRange ivs) {
        // Call the processing function with the rewriter, the memref operands,
        // and the loop induction variables. This function will return the value
        // to store at the current index.
        Value valueToStore = processIteration(nestedBuilder, operands, ivs);
        nestedBuilder.create<AffineStoreOp>(loc, valueToStore, alloc, ivs);
      });

  // Replace this operation with the generated alloc.
  rewriter.replaceOp(op, alloc);
}
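
// For a 2x3 result this builds a loop nest shaped roughly like (where %0 is
// the hoisted allocation):
//
//   affine.for %i = 0 to 2 {
//     affine.for %j = 0 to 3 {
//       %v = ...  // value returned by 'processIteration'
//       affine.store %v, %0[%i, %j] : memref<2x3xf64>
//     }
//   }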

namespace {
//===----------------------------------------------------------------------===//
// ToyToAffine RewritePatterns: Binary operations
//===----------------------------------------------------------------------===//

template <typename BinaryOp, typename LoweredBinaryOp>
struct BinaryOpLowering : public ConversionPattern {
  BinaryOpLowering(MLIRContext *ctx)
      : ConversionPattern(BinaryOp::getOperationName(), 1, ctx) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    auto loc = op->getLoc();
    lowerOpToLoops(
        op, operands, rewriter,
        [loc](OpBuilder &builder, ValueRange memRefOperands,
              ValueRange loopIvs) {
          // Generate an adaptor for the remapped operands of the BinaryOp.
          // This allows for using the nice named accessors that are generated
          // by the ODS.
          typename BinaryOp::Adaptor binaryAdaptor(memRefOperands);

          // Generate loads for the elements of 'lhs' and 'rhs' at the inner
          // loop.
          auto loadedLhs =
              builder.create<AffineLoadOp>(loc, binaryAdaptor.lhs(), loopIvs);
          auto loadedRhs =
              builder.create<AffineLoadOp>(loc, binaryAdaptor.rhs(), loopIvs);

          // Create the binary operation performed on the loaded values.
          return builder.create<LoweredBinaryOp>(loc, loadedLhs, loadedRhs);
        });
    return success();
  }
};
using AddOpLowering = BinaryOpLowering<toy::AddOp, AddFOp>;
using MulOpLowering = BinaryOpLowering<toy::MulOp, MulFOp>;
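
// As an illustration, with these patterns a 'toy.add' of two 2x3 tensors
// lowers to something like:
//
//   affine.for %i = 0 to 2 {
//     affine.for %j = 0 to 3 {
//       %3 = affine.load %1[%i, %j] : memref<2x3xf64>
//       %4 = affine.load %2[%i, %j] : memref<2x3xf64>
//       %5 = addf %3, %4 : f64
//       affine.store %5, %0[%i, %j] : memref<2x3xf64>
//     }
//   }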

//===----------------------------------------------------------------------===//
// ToyToAffine RewritePatterns: Constant operations
//===----------------------------------------------------------------------===//

struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
  using OpRewritePattern<toy::ConstantOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(toy::ConstantOp op,
                                PatternRewriter &rewriter) const final {
    DenseElementsAttr constantValue = op.value();
    Location loc = op.getLoc();

    // When lowering the constant operation, we allocate and assign the constant
    // values to a corresponding memref allocation.
    auto tensorType = op.getType().cast<TensorType>();
    auto memRefType = convertTensorToMemRef(tensorType);
    auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter);

    // We will be generating constant indices up to the largest dimension.
    // Create these constants up-front to avoid large amounts of redundant
    // operations.
    auto valueShape = memRefType.getShape();
    SmallVector<Value, 8> constantIndices;

    if (!valueShape.empty()) {
      for (auto i : llvm::seq<int64_t>(
               0, *std::max_element(valueShape.begin(), valueShape.end())))
        constantIndices.push_back(rewriter.create<ConstantIndexOp>(loc, i));
    } else {
      // This is the case of a tensor of rank 0.
      constantIndices.push_back(rewriter.create<ConstantIndexOp>(loc, 0));
    }

    // The constant operation represents a multi-dimensional constant, so we
    // will need to generate a store for each of the elements. The following
    // functor recursively walks the dimensions of the constant shape,
    // generating a store when the recursion hits the base case.
    SmallVector<Value, 2> indices;
    auto valueIt = constantValue.value_begin<FloatAttr>();
    std::function<void(uint64_t)> storeElements = [&](uint64_t dimension) {
      // The last dimension is the base case of the recursion, at this point
      // we store the element at the given index.
      if (dimension == valueShape.size()) {
        rewriter.create<AffineStoreOp>(
            loc, rewriter.create<ConstantOp>(loc, *valueIt++), alloc,
            llvm::makeArrayRef(indices));
        return;
      }

      // Otherwise, iterate over the current dimension and add the indices to
      // the list.
      for (uint64_t i = 0, e = valueShape[dimension]; i != e; ++i) {
        indices.push_back(constantIndices[i]);
        storeElements(dimension + 1);
        indices.pop_back();
      }
    };

    // Start the element storing recursion from the first dimension.
    storeElements(/*dimension=*/0);

    // Replace this operation with the generated alloc.
    rewriter.replaceOp(op, alloc);
    return success();
  }
};
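
// Schematically, a 2x3 'toy.constant' becomes one index constant per distinct
// index, one f64 constant per element, and one store per element, roughly:
//
//   %c0 = constant 0 : index
//   %cst = constant 1.000000e+00 : f64
//   ...
//   affine.store %cst, %0[%c0, %c0] : memref<2x3xf64>
//   ...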

//===----------------------------------------------------------------------===//
// ToyToAffine RewritePatterns: Return operations
//===----------------------------------------------------------------------===//

struct ReturnOpLowering : public OpRewritePattern<toy::ReturnOp> {
  using OpRewritePattern<toy::ReturnOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(toy::ReturnOp op,
                                PatternRewriter &rewriter) const final {
    // During this lowering, we expect that all function calls have been
    // inlined.
    if (op.hasOperand())
      return failure();

    // We lower "toy.return" directly to "std.return".
    rewriter.replaceOpWithNewOp<ReturnOp>(op);
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ToyToAffine RewritePatterns: Transpose operations
//===----------------------------------------------------------------------===//

struct TransposeOpLowering : public ConversionPattern {
  TransposeOpLowering(MLIRContext *ctx)
      : ConversionPattern(toy::TransposeOp::getOperationName(), 1, ctx) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    auto loc = op->getLoc();
    lowerOpToLoops(op, operands, rewriter,
                   [loc](OpBuilder &builder, ValueRange memRefOperands,
                         ValueRange loopIvs) {
                     // Generate an adaptor for the remapped operands of the
                     // TransposeOp. This allows for using the nice named
                     // accessors that are generated by the ODS.
                     toy::TransposeOpAdaptor transposeAdaptor(memRefOperands);
                     Value input = transposeAdaptor.input();

                     // Transpose the elements by generating a load from the
                     // reverse indices.
                     SmallVector<Value, 2> reverseIvs(llvm::reverse(loopIvs));
                     return builder.create<AffineLoadOp>(loc, input,
                                                         reverseIvs);
                   });
    return success();
  }
};
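
// For a 2x3 input this lowers 'toy.transpose' to roughly:
//
//   affine.for %i = 0 to 3 {
//     affine.for %j = 0 to 2 {
//       %v = affine.load %1[%j, %i] : memref<2x3xf64>
//       affine.store %v, %0[%i, %j] : memref<3x2xf64>
//     }
//   }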

} // end anonymous namespace.

//===----------------------------------------------------------------------===//
// ToyToAffineLoweringPass
//===----------------------------------------------------------------------===//

/// This is a partial lowering to affine loops of the toy operations that are
/// computationally intensive (like matmul, for example) while keeping the
/// rest of the code in the Toy dialect.
namespace {
struct ToyToAffineLoweringPass
    : public PassWrapper<ToyToAffineLoweringPass, FunctionPass> {
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<AffineDialect, memref::MemRefDialect, StandardOpsDialect>();
  }
  void runOnFunction() final;
};
} // end anonymous namespace.

void ToyToAffineLoweringPass::runOnFunction() {
  auto function = getFunction();

  // We only lower the main function as we expect that all other functions have
  // been inlined.
  if (function.getName() != "main")
    return;

  // Verify that the given main has no inputs and no results.
  if (function.getNumArguments() || function.getType().getNumResults()) {
    function.emitError("expected 'main' to have 0 inputs and 0 results");
    return signalPassFailure();
  }

  // The first thing to define is the conversion target. This will define the
  // final target for this lowering.
  ConversionTarget target(getContext());

  // We define the specific operations, or dialects, that are legal targets for
  // this lowering. In our case, we are lowering to a combination of the
  // `Affine`, `MemRef` and `Standard` dialects.
  target.addLegalDialect<AffineDialect, memref::MemRefDialect,
                         StandardOpsDialect>();

  // We also define the Toy dialect as illegal so that the conversion will fail
  // if any Toy operations are *not* converted. Given that we actually want a
  // partial lowering, we explicitly mark the Toy operation that we don't want
  // to lower, `toy.print`, as `legal`.
  target.addIllegalDialect<toy::ToyDialect>();
  target.addLegalOp<toy::PrintOp>();

  // Now that the conversion target has been defined, we just need to provide
  // the set of patterns that will lower the Toy operations.
  RewritePatternSet patterns(&getContext());
  patterns.add<AddOpLowering, ConstantOpLowering, MulOpLowering,
               ReturnOpLowering, TransposeOpLowering>(&getContext());

  // With the target and rewrite patterns defined, we can now attempt the
  // conversion. The conversion will signal failure if any of our `illegal`
  // operations were not converted successfully.
  if (failed(
          applyPartialConversion(getFunction(), target, std::move(patterns))))
    signalPassFailure();
}

/// Create a pass for lowering a subset of the Toy IR (e.g. matmul) to the
/// `Affine` and `Std` dialects.
std::unique_ptr<Pass> mlir::toy::createLowerToAffinePass() {
  return std::make_unique<ToyToAffineLoweringPass>();
}
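
// A typical driver runs this pass on every function nested under the module,
// e.g. (a sketch mirroring the Toy tutorial's toyc driver; 'context' and
// 'module' are assumed to exist):
//
//   mlir::PassManager pm(&context);
//   pm.addNestedPass<mlir::FuncOp>(mlir::toy::createLowerToAffinePass());
//   if (failed(pm.run(module)))
//     /* handle failure */;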