1 //====- LowerToAffineLoops.cpp - Partial lowering from Toy to Affine+Std --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a partial lowering of Toy operations to a combination of
10 // affine loops and standard operations. This lowering expects that all calls
11 // have been inlined, and all shapes have been resolved.
12 //
13 //===----------------------------------------------------------------------===//
14
15 #include "toy/Dialect.h"
16 #include "toy/Passes.h"
17
18 #include "mlir/Dialect/Affine/IR/AffineOps.h"
19 #include "mlir/Dialect/StandardOps/IR/Ops.h"
20 #include "mlir/Pass/Pass.h"
21 #include "mlir/Transforms/DialectConversion.h"
22 #include "llvm/ADT/Sequence.h"
23
24 using namespace mlir;
25
26 //===----------------------------------------------------------------------===//
27 // ToyToAffine RewritePatterns
28 //===----------------------------------------------------------------------===//
29
30 /// Convert the given TensorType into the corresponding MemRefType.
convertTensorToMemRef(TensorType type)31 static MemRefType convertTensorToMemRef(TensorType type) {
32 assert(type.hasRank() && "expected only ranked shapes");
33 return MemRefType::get(type.getShape(), type.getElementType());
34 }
35
36 /// Insert an allocation and deallocation for the given MemRefType.
insertAllocAndDealloc(MemRefType type,Location loc,PatternRewriter & rewriter)37 static Value insertAllocAndDealloc(MemRefType type, Location loc,
38 PatternRewriter &rewriter) {
39 auto alloc = rewriter.create<AllocOp>(loc, type);
40
41 // Make sure to allocate at the beginning of the block.
42 auto *parentBlock = alloc->getBlock();
43 alloc->moveBefore(&parentBlock->front());
44
45 // Make sure to deallocate this alloc at the end of the block. This is fine
46 // as toy functions have no control flow.
47 auto dealloc = rewriter.create<DeallocOp>(loc, alloc);
48 dealloc->moveBefore(&parentBlock->back());
49 return alloc;
50 }
51
/// This defines the function type used to process an iteration of a lowered
/// loop. It takes as input an OpBuilder, a range of memRefOperands
/// corresponding to the operands of the input operation, and the range of loop
/// induction variables for the iteration. It returns a value to store at the
/// current index of the iteration.
using LoopIterationFn = function_ref<Value(
    OpBuilder &builder, ValueRange memRefOperands, ValueRange loopIvs)>;
59
lowerOpToLoops(Operation * op,ArrayRef<Value> operands,PatternRewriter & rewriter,LoopIterationFn processIteration)60 static void lowerOpToLoops(Operation *op, ArrayRef<Value> operands,
61 PatternRewriter &rewriter,
62 LoopIterationFn processIteration) {
63 auto tensorType = (*op->result_type_begin()).cast<TensorType>();
64 auto loc = op->getLoc();
65
66 // Insert an allocation and deallocation for the result of this operation.
67 auto memRefType = convertTensorToMemRef(tensorType);
68 auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter);
69
70 // Create a nest of affine loops, with one loop per dimension of the shape.
71 // The buildAffineLoopNest function takes a callback that is used to construct
72 // the body of the innermost loop given a builder, a location and a range of
73 // loop induction variables.
74 SmallVector<int64_t, 4> lowerBounds(tensorType.getRank(), /*Value=*/0);
75 SmallVector<int64_t, 4> steps(tensorType.getRank(), /*Value=*/1);
76 buildAffineLoopNest(
77 rewriter, loc, lowerBounds, tensorType.getShape(), steps,
78 [&](OpBuilder &nestedBuilder, Location loc, ValueRange ivs) {
79 // Call the processing function with the rewriter, the memref operands,
80 // and the loop induction variables. This function will return the value
81 // to store at the current index.
82 Value valueToStore = processIteration(nestedBuilder, operands, ivs);
83 nestedBuilder.create<AffineStoreOp>(loc, valueToStore, alloc, ivs);
84 });
85
86 // Replace this operation with the generated alloc.
87 rewriter.replaceOp(op, alloc);
88 }
89
90 namespace {
91 //===----------------------------------------------------------------------===//
92 // ToyToAffine RewritePatterns: Binary operations
93 //===----------------------------------------------------------------------===//
94
95 template <typename BinaryOp, typename LoweredBinaryOp>
96 struct BinaryOpLowering : public ConversionPattern {
BinaryOpLowering__anon0f27571d0211::BinaryOpLowering97 BinaryOpLowering(MLIRContext *ctx)
98 : ConversionPattern(BinaryOp::getOperationName(), 1, ctx) {}
99
100 LogicalResult
matchAndRewrite__anon0f27571d0211::BinaryOpLowering101 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
102 ConversionPatternRewriter &rewriter) const final {
103 auto loc = op->getLoc();
104 lowerOpToLoops(
105 op, operands, rewriter,
106 [loc](OpBuilder &builder, ValueRange memRefOperands,
107 ValueRange loopIvs) {
108 // Generate an adaptor for the remapped operands of the BinaryOp. This
109 // allows for using the nice named accessors that are generated by the
110 // ODS.
111 typename BinaryOp::Adaptor binaryAdaptor(memRefOperands);
112
113 // Generate loads for the element of 'lhs' and 'rhs' at the inner
114 // loop.
115 auto loadedLhs =
116 builder.create<AffineLoadOp>(loc, binaryAdaptor.lhs(), loopIvs);
117 auto loadedRhs =
118 builder.create<AffineLoadOp>(loc, binaryAdaptor.rhs(), loopIvs);
119
120 // Create the binary operation performed on the loaded values.
121 return builder.create<LoweredBinaryOp>(loc, loadedLhs, loadedRhs);
122 });
123 return success();
124 }
125 };
126 using AddOpLowering = BinaryOpLowering<toy::AddOp, AddFOp>;
127 using MulOpLowering = BinaryOpLowering<toy::MulOp, MulFOp>;
128
129 //===----------------------------------------------------------------------===//
130 // ToyToAffine RewritePatterns: Constant operations
131 //===----------------------------------------------------------------------===//
132
133 struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
134 using OpRewritePattern<toy::ConstantOp>::OpRewritePattern;
135
matchAndRewrite__anon0f27571d0211::ConstantOpLowering136 LogicalResult matchAndRewrite(toy::ConstantOp op,
137 PatternRewriter &rewriter) const final {
138 DenseElementsAttr constantValue = op.value();
139 Location loc = op.getLoc();
140
141 // When lowering the constant operation, we allocate and assign the constant
142 // values to a corresponding memref allocation.
143 auto tensorType = op.getType().cast<TensorType>();
144 auto memRefType = convertTensorToMemRef(tensorType);
145 auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter);
146
147 // We will be generating constant indices up-to the largest dimension.
148 // Create these constants up-front to avoid large amounts of redundant
149 // operations.
150 auto valueShape = memRefType.getShape();
151 SmallVector<Value, 8> constantIndices;
152
153 if (!valueShape.empty()) {
154 for (auto i : llvm::seq<int64_t>(
155 0, *std::max_element(valueShape.begin(), valueShape.end())))
156 constantIndices.push_back(rewriter.create<ConstantIndexOp>(loc, i));
157 } else {
158 // This is the case of a tensor of rank 0.
159 constantIndices.push_back(rewriter.create<ConstantIndexOp>(loc, 0));
160 }
161 // The constant operation represents a multi-dimensional constant, so we
162 // will need to generate a store for each of the elements. The following
163 // functor recursively walks the dimensions of the constant shape,
164 // generating a store when the recursion hits the base case.
165 SmallVector<Value, 2> indices;
166 auto valueIt = constantValue.getValues<FloatAttr>().begin();
167 std::function<void(uint64_t)> storeElements = [&](uint64_t dimension) {
168 // The last dimension is the base case of the recursion, at this point
169 // we store the element at the given index.
170 if (dimension == valueShape.size()) {
171 rewriter.create<AffineStoreOp>(
172 loc, rewriter.create<ConstantOp>(loc, *valueIt++), alloc,
173 llvm::makeArrayRef(indices));
174 return;
175 }
176
177 // Otherwise, iterate over the current dimension and add the indices to
178 // the list.
179 for (uint64_t i = 0, e = valueShape[dimension]; i != e; ++i) {
180 indices.push_back(constantIndices[i]);
181 storeElements(dimension + 1);
182 indices.pop_back();
183 }
184 };
185
186 // Start the element storing recursion from the first dimension.
187 storeElements(/*dimension=*/0);
188
189 // Replace this operation with the generated alloc.
190 rewriter.replaceOp(op, alloc);
191 return success();
192 }
193 };
194
195 //===----------------------------------------------------------------------===//
196 // ToyToAffine RewritePatterns: Return operations
197 //===----------------------------------------------------------------------===//
198
199 struct ReturnOpLowering : public OpRewritePattern<toy::ReturnOp> {
200 using OpRewritePattern<toy::ReturnOp>::OpRewritePattern;
201
matchAndRewrite__anon0f27571d0211::ReturnOpLowering202 LogicalResult matchAndRewrite(toy::ReturnOp op,
203 PatternRewriter &rewriter) const final {
204 // During this lowering, we expect that all function calls have been
205 // inlined.
206 if (op.hasOperand())
207 return failure();
208
209 // We lower "toy.return" directly to "std.return".
210 rewriter.replaceOpWithNewOp<ReturnOp>(op);
211 return success();
212 }
213 };
214
215 //===----------------------------------------------------------------------===//
216 // ToyToAffine RewritePatterns: Transpose operations
217 //===----------------------------------------------------------------------===//
218
219 struct TransposeOpLowering : public ConversionPattern {
TransposeOpLowering__anon0f27571d0211::TransposeOpLowering220 TransposeOpLowering(MLIRContext *ctx)
221 : ConversionPattern(toy::TransposeOp::getOperationName(), 1, ctx) {}
222
223 LogicalResult
matchAndRewrite__anon0f27571d0211::TransposeOpLowering224 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
225 ConversionPatternRewriter &rewriter) const final {
226 auto loc = op->getLoc();
227 lowerOpToLoops(op, operands, rewriter,
228 [loc](OpBuilder &builder, ValueRange memRefOperands,
229 ValueRange loopIvs) {
230 // Generate an adaptor for the remapped operands of the
231 // TransposeOp. This allows for using the nice named
232 // accessors that are generated by the ODS.
233 toy::TransposeOpAdaptor transposeAdaptor(memRefOperands);
234 Value input = transposeAdaptor.input();
235
236 // Transpose the elements by generating a load from the
237 // reverse indices.
238 SmallVector<Value, 2> reverseIvs(llvm::reverse(loopIvs));
239 return builder.create<AffineLoadOp>(loc, input,
240 reverseIvs);
241 });
242 return success();
243 }
244 };
245
246 } // end anonymous namespace.
247
248 //===----------------------------------------------------------------------===//
249 // ToyToAffineLoweringPass
250 //===----------------------------------------------------------------------===//
251
252 /// This is a partial lowering to affine loops of the toy operations that are
253 /// computationally intensive (like matmul for example...) while keeping the
254 /// rest of the code in the Toy dialect.
255 namespace {
struct ToyToAffineLoweringPass
    : public PassWrapper<ToyToAffineLoweringPass, FunctionPass> {
  // Declare the dialects this pass may create operations in, so the context
  // loads them before the pass runs.
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<AffineDialect, StandardOpsDialect>();
  }
  // Entry point; defined out-of-line below.
  void runOnFunction() final;
};
263 } // end anonymous namespace.
264
runOnFunction()265 void ToyToAffineLoweringPass::runOnFunction() {
266 auto function = getFunction();
267
268 // We only lower the main function as we expect that all other functions have
269 // been inlined.
270 if (function.getName() != "main")
271 return;
272
273 // Verify that the given main has no inputs and results.
274 if (function.getNumArguments() || function.getType().getNumResults()) {
275 function.emitError("expected 'main' to have 0 inputs and 0 results");
276 return signalPassFailure();
277 }
278
279 // The first thing to define is the conversion target. This will define the
280 // final target for this lowering.
281 ConversionTarget target(getContext());
282
283 // We define the specific operations, or dialects, that are legal targets for
284 // this lowering. In our case, we are lowering to a combination of the
285 // `Affine` and `Standard` dialects.
286 target.addLegalDialect<AffineDialect, StandardOpsDialect>();
287
288 // We also define the Toy dialect as Illegal so that the conversion will fail
289 // if any of these operations are *not* converted. Given that we actually want
290 // a partial lowering, we explicitly mark the Toy operations that don't want
291 // to lower, `toy.print`, as `legal`.
292 target.addIllegalDialect<toy::ToyDialect>();
293 target.addLegalOp<toy::PrintOp>();
294
295 // Now that the conversion target has been defined, we just need to provide
296 // the set of patterns that will lower the Toy operations.
297 OwningRewritePatternList patterns;
298 patterns.insert<AddOpLowering, ConstantOpLowering, MulOpLowering,
299 ReturnOpLowering, TransposeOpLowering>(&getContext());
300
301 // With the target and rewrite patterns defined, we can now attempt the
302 // conversion. The conversion will signal failure if any of our `illegal`
303 // operations were not converted successfully.
304 if (failed(
305 applyPartialConversion(getFunction(), target, std::move(patterns))))
306 signalPassFailure();
307 }
308
/// Create a pass for lowering operations in the `Affine` and `Std` dialects,
/// for a subset of the Toy IR (e.g. matmul). Returns ownership of the newly
/// constructed pass to the caller.
std::unique_ptr<Pass> mlir::toy::createLowerToAffinePass() {
  return std::make_unique<ToyToAffineLoweringPass>();
}
314