//====- LowerToAffineLoops.cpp - Partial lowering from Toy to Affine+Std --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a partial lowering of Toy operations to a combination of
// affine loops, memref operations and standard operations. This lowering
// expects that all calls have been inlined, and all shapes have been resolved.
//
//===----------------------------------------------------------------------===//

#include "toy/Dialect.h"
#include "toy/Passes.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/Sequence.h"

using namespace mlir;

//===----------------------------------------------------------------------===//
// ToyToAffine RewritePatterns
//===----------------------------------------------------------------------===//

/// Convert the given TensorType into the corresponding MemRefType.
static MemRefType convertTensorToMemRef(TensorType type) {
  assert(type.hasRank() && "expected only ranked shapes");
  return MemRefType::get(type.getShape(), type.getElementType());
}
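
// For example, a `tensor<2x3xf64>` converts to `memref<2x3xf64>`: the shape
// and element type carry over unchanged.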

/// Insert an allocation and deallocation for the given MemRefType.
static Value insertAllocAndDealloc(MemRefType type, Location loc,
                                   PatternRewriter &rewriter) {
  auto alloc = rewriter.create<memref::AllocOp>(loc, type);

  // Make sure to allocate at the beginning of the block.
  auto *parentBlock = alloc->getBlock();
  alloc->moveBefore(&parentBlock->front());

  // Make sure to deallocate this alloc at the end of the block. This is fine
  // as toy functions have no control flow.
  auto dealloc = rewriter.create<memref::DeallocOp>(loc, alloc);
  dealloc->moveBefore(&parentBlock->back());
  return alloc;
}
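
// After this helper runs, the enclosing block is bracketed by the allocation
// and deallocation, roughly (a sketch; `...` is the block's existing body):
//
//   %0 = memref.alloc() : memref<...>
//   ...
//   memref.dealloc %0 : memref<...>
//   return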

/// This defines the function type used to process an iteration of a lowered
/// loop. It takes as input an OpBuilder, a range of memRefOperands
/// corresponding to the operands of the input operation, and the range of loop
/// induction variables for the iteration. It returns a value to store at the
/// current index of the iteration.
using LoopIterationFn = function_ref<Value(
    OpBuilder &rewriter, ValueRange memRefOperands, ValueRange loopIvs)>;
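
// A minimal sketch of such a callback (assuming `loc` is captured from the
// enclosing pattern, as the patterns below do):
//
//   [loc](OpBuilder &builder, ValueRange memRefOperands, ValueRange loopIvs) {
//     return builder.create<AffineLoadOp>(loc, memRefOperands[0], loopIvs);
//   }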

static void lowerOpToLoops(Operation *op, ValueRange operands,
                           PatternRewriter &rewriter,
                           LoopIterationFn processIteration) {
  auto tensorType = (*op->result_type_begin()).cast<TensorType>();
  auto loc = op->getLoc();

  // Insert an allocation and deallocation for the result of this operation.
  auto memRefType = convertTensorToMemRef(tensorType);
  auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter);

  // Create a nest of affine loops, with one loop per dimension of the shape.
  // The buildAffineLoopNest function takes a callback that is used to construct
  // the body of the innermost loop given a builder, a location and a range of
  // loop induction variables.
  SmallVector<int64_t, 4> lowerBounds(tensorType.getRank(), /*Value=*/0);
  SmallVector<int64_t, 4> steps(tensorType.getRank(), /*Value=*/1);
  buildAffineLoopNest(
      rewriter, loc, lowerBounds, tensorType.getShape(), steps,
      [&](OpBuilder &nestedBuilder, Location loc, ValueRange ivs) {
        // Call the processing function with the rewriter, the memref operands,
        // and the loop induction variables. This function will return the value
        // to store at the current index.
        Value valueToStore = processIteration(nestedBuilder, operands, ivs);
        nestedBuilder.create<AffineStoreOp>(loc, valueToStore, alloc, ivs);
      });

  // Replace this operation with the generated alloc.
  rewriter.replaceOp(op, alloc);
}
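
// For a result of shape <2x3>, this produces a skeleton roughly like
// (a sketch; the value stored comes from `processIteration`):
//
//   %0 = memref.alloc() : memref<2x3xf64>
//   affine.for %i = 0 to 2 {
//     affine.for %j = 0 to 3 {
//       %v = ...
//       affine.store %v, %0[%i, %j] : memref<2x3xf64>
//     }
//   }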

namespace {
//===----------------------------------------------------------------------===//
// ToyToAffine RewritePatterns: Binary operations
//===----------------------------------------------------------------------===//

template <typename BinaryOp, typename LoweredBinaryOp>
struct BinaryOpLowering : public ConversionPattern {
  BinaryOpLowering(MLIRContext *ctx)
      : ConversionPattern(BinaryOp::getOperationName(), 1, ctx) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    auto loc = op->getLoc();
    lowerOpToLoops(
        op, operands, rewriter,
        [loc](OpBuilder &builder, ValueRange memRefOperands,
              ValueRange loopIvs) {
          // Generate an adaptor for the remapped operands of the BinaryOp. This
          // allows for using the nice named accessors that are generated by the
          // ODS.
          typename BinaryOp::Adaptor binaryAdaptor(memRefOperands);

          // Generate loads for the element of 'lhs' and 'rhs' at the inner
          // loop.
          auto loadedLhs =
              builder.create<AffineLoadOp>(loc, binaryAdaptor.lhs(), loopIvs);
          auto loadedRhs =
              builder.create<AffineLoadOp>(loc, binaryAdaptor.rhs(), loopIvs);

          // Create the binary operation performed on the loaded values.
          return builder.create<LoweredBinaryOp>(loc, loadedLhs, loadedRhs);
        });
    return success();
  }
};
using AddOpLowering = BinaryOpLowering<toy::AddOp, AddFOp>;
using MulOpLowering = BinaryOpLowering<toy::MulOp, MulFOp>;
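
// With these instantiations, a `toy.add` of <2x3> operands lowers roughly to
// (a sketch):
//
//   affine.for %i = 0 to 2 {
//     affine.for %j = 0 to 3 {
//       %l = affine.load %lhs[%i, %j] : memref<2x3xf64>
//       %r = affine.load %rhs[%i, %j] : memref<2x3xf64>
//       %s = addf %l, %r : f64
//       affine.store %s, %0[%i, %j] : memref<2x3xf64>
//     }
//   }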

//===----------------------------------------------------------------------===//
// ToyToAffine RewritePatterns: Constant operations
//===----------------------------------------------------------------------===//

struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
  using OpRewritePattern<toy::ConstantOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(toy::ConstantOp op,
                                PatternRewriter &rewriter) const final {
    DenseElementsAttr constantValue = op.value();
    Location loc = op.getLoc();

    // When lowering the constant operation, we allocate and assign the constant
    // values to a corresponding memref allocation.
    auto tensorType = op.getType().cast<TensorType>();
    auto memRefType = convertTensorToMemRef(tensorType);
    auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter);

    // We will be generating constant indices up to the largest dimension.
    // Create these constants up-front to avoid large amounts of redundant
    // operations.
    auto valueShape = memRefType.getShape();
    SmallVector<Value, 8> constantIndices;

    if (!valueShape.empty()) {
      for (auto i : llvm::seq<int64_t>(
               0, *std::max_element(valueShape.begin(), valueShape.end())))
        constantIndices.push_back(rewriter.create<ConstantIndexOp>(loc, i));
    } else {
      // This is the case of a tensor of rank 0.
      constantIndices.push_back(rewriter.create<ConstantIndexOp>(loc, 0));
    }

    // The constant operation represents a multi-dimensional constant, so we
    // will need to generate a store for each of the elements. The following
    // functor recursively walks the dimensions of the constant shape,
    // generating a store when the recursion hits the base case.
    SmallVector<Value, 2> indices;
    auto valueIt = constantValue.value_begin<FloatAttr>();
    std::function<void(uint64_t)> storeElements = [&](uint64_t dimension) {
      // The last dimension is the base case of the recursion, at this point
      // we store the element at the given index.
      if (dimension == valueShape.size()) {
        rewriter.create<AffineStoreOp>(
            loc, rewriter.create<ConstantOp>(loc, *valueIt++), alloc,
            llvm::makeArrayRef(indices));
        return;
      }

      // Otherwise, iterate over the current dimension and add the indices to
      // the list.
      for (uint64_t i = 0, e = valueShape[dimension]; i != e; ++i) {
        indices.push_back(constantIndices[i]);
        storeElements(dimension + 1);
        indices.pop_back();
      }
    };

    // Start the element storing recursion from the first dimension.
    storeElements(/*dimension=*/0);

    // Replace this operation with the generated alloc.
    rewriter.replaceOp(op, alloc);
    return success();
  }
};
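
// A small constant such as `dense<[1.0, 2.0]>` thus unrolls into one
// `affine.store` per element, roughly (a sketch):
//
//   %c0 = constant 0 : index
//   %c1 = constant 1 : index
//   %cst = constant 1.000000e+00 : f64
//   affine.store %cst, %0[%c0] : memref<2xf64>
//   %cst_0 = constant 2.000000e+00 : f64
//   affine.store %cst_0, %0[%c1] : memref<2xf64>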

//===----------------------------------------------------------------------===//
// ToyToAffine RewritePatterns: Return operations
//===----------------------------------------------------------------------===//

struct ReturnOpLowering : public OpRewritePattern<toy::ReturnOp> {
  using OpRewritePattern<toy::ReturnOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(toy::ReturnOp op,
                                PatternRewriter &rewriter) const final {
    // During this lowering, we expect that all function calls have been
    // inlined.
    if (op.hasOperand())
      return failure();

    // We lower "toy.return" directly to "std.return".
    rewriter.replaceOpWithNewOp<ReturnOp>(op);
    return success();
  }
};
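
// Only the operand-less `toy.return` in `main` should survive to this point
// (all calls have been inlined), so the rejection above is not expected to
// fire on well-formed input.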

//===----------------------------------------------------------------------===//
// ToyToAffine RewritePatterns: Transpose operations
//===----------------------------------------------------------------------===//

struct TransposeOpLowering : public ConversionPattern {
  TransposeOpLowering(MLIRContext *ctx)
      : ConversionPattern(toy::TransposeOp::getOperationName(), 1, ctx) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    auto loc = op->getLoc();
    lowerOpToLoops(op, operands, rewriter,
                   [loc](OpBuilder &builder, ValueRange memRefOperands,
                         ValueRange loopIvs) {
                     // Generate an adaptor for the remapped operands of the
                     // TransposeOp. This allows for using the nice named
                     // accessors that are generated by the ODS.
                     toy::TransposeOpAdaptor transposeAdaptor(memRefOperands);
                     Value input = transposeAdaptor.input();

                     // Transpose the elements by generating a load from the
                     // reverse indices.
                     SmallVector<Value, 2> reverseIvs(llvm::reverse(loopIvs));
                     return builder.create<AffineLoadOp>(loc, input,
                                                         reverseIvs);
                   });
    return success();
  }
};
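
// Lowering `toy.transpose` of a <2x3> input therefore loads with swapped
// indices, roughly (a sketch):
//
//   affine.for %i = 0 to 3 {
//     affine.for %j = 0 to 2 {
//       %v = affine.load %in[%j, %i] : memref<2x3xf64>
//       affine.store %v, %out[%i, %j] : memref<3x2xf64>
//     }
//   }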

} // end anonymous namespace.

//===----------------------------------------------------------------------===//
// ToyToAffineLoweringPass
//===----------------------------------------------------------------------===//

/// This is a partial lowering to affine loops of the toy operations that are
/// computationally intensive (like matmul for example...) while keeping the
/// rest of the code in the Toy dialect.
namespace {
struct ToyToAffineLoweringPass
    : public PassWrapper<ToyToAffineLoweringPass, FunctionPass> {
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<AffineDialect, memref::MemRefDialect, StandardOpsDialect>();
  }
  void runOnFunction() final;
};
} // end anonymous namespace.

void ToyToAffineLoweringPass::runOnFunction() {
  auto function = getFunction();

  // We only lower the main function as we expect that all other functions have
  // been inlined.
  if (function.getName() != "main")
    return;

  // Verify that the given main has no inputs and no results.
  if (function.getNumArguments() || function.getType().getNumResults()) {
    function.emitError("expected 'main' to have 0 inputs and 0 results");
    return signalPassFailure();
  }

  // The first thing to define is the conversion target. This will define the
  // final target for this lowering.
  ConversionTarget target(getContext());

  // We define the specific operations, or dialects, that are legal targets for
  // this lowering. In our case, we are lowering to a combination of the
  // `Affine`, `MemRef` and `Standard` dialects.
  target.addLegalDialect<AffineDialect, memref::MemRefDialect,
                         StandardOpsDialect>();

  // We also define the Toy dialect as illegal so that the conversion will fail
  // if any of these operations are *not* converted. Given that we actually want
  // a partial lowering, we explicitly mark the Toy operations that we don't
  // want to lower, `toy.print`, as `legal`.
  target.addIllegalDialect<toy::ToyDialect>();
  target.addLegalOp<toy::PrintOp>();

  // Now that the conversion target has been defined, we just need to provide
  // the set of patterns that will lower the Toy operations.
  RewritePatternSet patterns(&getContext());
  patterns.add<AddOpLowering, ConstantOpLowering, MulOpLowering,
               ReturnOpLowering, TransposeOpLowering>(&getContext());

  // With the target and rewrite patterns defined, we can now attempt the
  // conversion. The conversion will signal failure if any of our `illegal`
  // operations were not converted successfully.
  if (failed(
          applyPartialConversion(getFunction(), target, std::move(patterns))))
    signalPassFailure();
}

/// Create a pass for lowering operations in the `Affine` and `Std` dialects,
/// for a subset of the Toy IR (e.g. matmul).
std::unique_ptr<Pass> mlir::toy::createLowerToAffinePass() {
  return std::make_unique<ToyToAffineLoweringPass>();
}
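
// A typical driver wires this pass into a function-level pipeline, roughly
// (a sketch; the actual setup lives in the toyc driver, not in this file):
//
//   mlir::PassManager pm(&context);
//   mlir::OpPassManager &optPM = pm.nest<mlir::FuncOp>();
//   optPM.addPass(mlir::toy::createLowerToAffinePass());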