//====- LowerToAffineLoops.cpp - Partial lowering from Toy to Affine+Std --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a partial lowering of Toy operations to a combination
// of affine loops and standard operations. This lowering expects that all
// calls have been inlined, and all shapes have been resolved.
//
//===----------------------------------------------------------------------===//

#include "toy/Dialect.h"
#include "toy/Passes.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/Sequence.h"

using namespace mlir;

//===----------------------------------------------------------------------===//
// ToyToAffine RewritePatterns
//===----------------------------------------------------------------------===//

/// Convert the given TensorType into the corresponding MemRefType.
static MemRefType convertTensorToMemRef(TensorType type) {
  assert(type.hasRank() && "expected only ranked shapes");
  return MemRefType::get(type.getShape(), type.getElementType());
}

/// Insert an allocation and deallocation for the given MemRefType.
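/// The buffer is placed so that the enclosing block reads, roughly:
///
///   %alloc = alloc() : memref<...>   // hoisted to the start of the block
///   ...                              // the lowered computation
///   dealloc %alloc : memref<...>     // just before the block terminator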
static Value insertAllocAndDealloc(MemRefType type, Location loc,
                                   PatternRewriter &rewriter) {
  auto alloc = rewriter.create<AllocOp>(loc, type);

  // Make sure to allocate at the beginning of the block.
  auto *parentBlock = alloc->getBlock();
  alloc->moveBefore(&parentBlock->front());

  // Make sure to deallocate this alloc at the end of the block. This is fine
  // as toy functions have no control flow.
  auto dealloc = rewriter.create<DeallocOp>(loc, alloc);
  dealloc->moveBefore(&parentBlock->back());
  return alloc;
}

/// This defines the function type used to process an iteration of a lowered
/// loop. It takes as input an OpBuilder, a range of memRefOperands
/// corresponding to the operands of the input operation, and the range of loop
/// induction variables for the iteration. It returns a value to store at the
/// current index of the iteration.
using LoopIterationFn = function_ref<Value(
    OpBuilder &builder, ValueRange memRefOperands, ValueRange loopIvs)>;

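/// Lower an operation with a single ranked-tensor result to an affine loop
/// nest that computes and stores one element of the result per iteration; the
/// `processIteration` callback produces the value to store. For a rank-2
/// result of shape <2x3>, for example, the emitted skeleton is roughly:
///
///   affine.for %i = 0 to 2 {
///     affine.for %j = 0 to 3 {
///       %v = ...  // value returned by processIteration
///       affine.store %v, %alloc[%i, %j] : memref<2x3xf64>
///     }
///   }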
static void lowerOpToLoops(Operation *op, ArrayRef<Value> operands,
                           PatternRewriter &rewriter,
                           LoopIterationFn processIteration) {
  auto tensorType = (*op->result_type_begin()).cast<TensorType>();
  auto loc = op->getLoc();

  // Insert an allocation and deallocation for the result of this operation.
  auto memRefType = convertTensorToMemRef(tensorType);
  auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter);

  // Create a nest of affine loops, with one loop per dimension of the shape.
  // The buildAffineLoopNest function takes a callback that is used to construct
  // the body of the innermost loop given a builder, a location and a range of
  // loop induction variables.
  SmallVector<int64_t, 4> lowerBounds(tensorType.getRank(), /*Value=*/0);
  SmallVector<int64_t, 4> steps(tensorType.getRank(), /*Value=*/1);
  buildAffineLoopNest(
      rewriter, loc, lowerBounds, tensorType.getShape(), steps,
      [&](OpBuilder &nestedBuilder, Location loc, ValueRange ivs) {
        // Call the processing function with the rewriter, the memref operands,
        // and the loop induction variables. This function will return the value
        // to store at the current index.
        Value valueToStore = processIteration(nestedBuilder, operands, ivs);
        nestedBuilder.create<AffineStoreOp>(loc, valueToStore, alloc, ivs);
      });

  // Replace this operation with the generated alloc.
  rewriter.replaceOp(op, alloc);
}

namespace {
//===----------------------------------------------------------------------===//
// ToyToAffine RewritePatterns: Binary operations
//===----------------------------------------------------------------------===//

template <typename BinaryOp, typename LoweredBinaryOp>
struct BinaryOpLowering : public ConversionPattern {
  BinaryOpLowering(MLIRContext *ctx)
      : ConversionPattern(BinaryOp::getOperationName(), 1, ctx) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    auto loc = op->getLoc();
    lowerOpToLoops(
        op, operands, rewriter,
        [loc](OpBuilder &builder, ValueRange memRefOperands,
              ValueRange loopIvs) {
          // Generate an adaptor for the remapped operands of the BinaryOp. This
          // allows for using the nice named accessors that are generated by the
          // ODS framework.
          typename BinaryOp::Adaptor binaryAdaptor(memRefOperands);

          // Generate loads for the elements of 'lhs' and 'rhs' within the
          // innermost loop.
          auto loadedLhs =
              builder.create<AffineLoadOp>(loc, binaryAdaptor.lhs(), loopIvs);
          auto loadedRhs =
              builder.create<AffineLoadOp>(loc, binaryAdaptor.rhs(), loopIvs);

          // Create the binary operation performed on the loaded values.
          return builder.create<LoweredBinaryOp>(loc, loadedLhs, loadedRhs);
        });
    return success();
  }
};
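
/// Element-wise lowering of toy.add and toy.mul to the standard `addf` and
/// `mulf` operations applied to the loaded scalar values.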
using AddOpLowering = BinaryOpLowering<toy::AddOp, AddFOp>;
using MulOpLowering = BinaryOpLowering<toy::MulOp, MulFOp>;

//===----------------------------------------------------------------------===//
// ToyToAffine RewritePatterns: Constant operations
//===----------------------------------------------------------------------===//
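
// As a rough sketch, a `toy.constant` holding [[1, 2], [3, 4]] lowers to a
// buffer allocation followed by one affine.store per element:
//
//   %cst = constant 1.000000e+00 : f64
//   %cst_0 = constant 2.000000e+00 : f64
//   ...
//   %c0 = constant 0 : index
//   %c1 = constant 1 : index
//   %0 = alloc() : memref<2x2xf64>
//   affine.store %cst, %0[%c0, %c0] : memref<2x2xf64>
//   affine.store %cst_0, %0[%c0, %c1] : memref<2x2xf64>
//   ...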

struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
  using OpRewritePattern<toy::ConstantOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(toy::ConstantOp op,
                                PatternRewriter &rewriter) const final {
    DenseElementsAttr constantValue = op.value();
    Location loc = op.getLoc();

    // When lowering the constant operation, we allocate and assign the constant
    // values to a corresponding memref allocation.
    auto tensorType = op.getType().cast<TensorType>();
    auto memRefType = convertTensorToMemRef(tensorType);
    auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter);

    // We will be generating constant indices up to the largest dimension.
    // Create these constants up-front to avoid large amounts of redundant
    // operations.
    auto valueShape = memRefType.getShape();
    SmallVector<Value, 8> constantIndices;

    if (!valueShape.empty()) {
      for (auto i : llvm::seq<int64_t>(
               0, *std::max_element(valueShape.begin(), valueShape.end())))
        constantIndices.push_back(rewriter.create<ConstantIndexOp>(loc, i));
    } else {
      // This is the case of a tensor of rank 0.
      constantIndices.push_back(rewriter.create<ConstantIndexOp>(loc, 0));
    }
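
    // For a <2x3> shape, for example, this materializes %c0, %c1, and %c2
    // once up front, rather than recreating an index constant for every
    // store emitted below.
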
    // The constant operation represents a multi-dimensional constant, so we
    // will need to generate a store for each of the elements. The following
    // functor recursively walks the dimensions of the constant shape,
    // generating a store when the recursion hits the base case.
    SmallVector<Value, 2> indices;
    auto valueIt = constantValue.getValues<FloatAttr>().begin();
    std::function<void(uint64_t)> storeElements = [&](uint64_t dimension) {
      // The last dimension is the base case of the recursion; at this point
      // we store the element at the given index.
      if (dimension == valueShape.size()) {
        rewriter.create<AffineStoreOp>(
            loc, rewriter.create<ConstantOp>(loc, *valueIt++), alloc,
            llvm::makeArrayRef(indices));
        return;
      }

      // Otherwise, iterate over the current dimension and add the indices to
      // the list.
      for (uint64_t i = 0, e = valueShape[dimension]; i != e; ++i) {
        indices.push_back(constantIndices[i]);
        storeElements(dimension + 1);
        indices.pop_back();
      }
    };

    // Start the element storing recursion from the first dimension.
    storeElements(/*dimension=*/0);

    // Replace this operation with the generated alloc.
    rewriter.replaceOp(op, alloc);
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ToyToAffine RewritePatterns: Return operations
//===----------------------------------------------------------------------===//

struct ReturnOpLowering : public OpRewritePattern<toy::ReturnOp> {
  using OpRewritePattern<toy::ReturnOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(toy::ReturnOp op,
                                PatternRewriter &rewriter) const final {
    // During this lowering, we expect that all function calls have been
    // inlined.
    if (op.hasOperand())
      return failure();

    // We lower "toy.return" directly to "std.return".
    rewriter.replaceOpWithNewOp<ReturnOp>(op);
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ToyToAffine RewritePatterns: Transpose operations
//===----------------------------------------------------------------------===//
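
// Lowering a transpose of a <2x3> tensor yields, roughly:
//
//   affine.for %arg0 = 0 to 3 {
//     affine.for %arg1 = 0 to 2 {
//       %0 = affine.load %1[%arg1, %arg0] : memref<2x3xf64>
//       affine.store %0, %2[%arg0, %arg1] : memref<3x2xf64>
//     }
//   }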

struct TransposeOpLowering : public ConversionPattern {
  TransposeOpLowering(MLIRContext *ctx)
      : ConversionPattern(toy::TransposeOp::getOperationName(), 1, ctx) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    auto loc = op->getLoc();
    lowerOpToLoops(op, operands, rewriter,
                   [loc](OpBuilder &builder, ValueRange memRefOperands,
                         ValueRange loopIvs) {
                     // Generate an adaptor for the remapped operands of the
                     // TransposeOp. This allows for using the nice named
                     // accessors that are generated by the ODS framework.
                     toy::TransposeOpAdaptor transposeAdaptor(memRefOperands);
                     Value input = transposeAdaptor.input();

                     // Transpose the elements by generating a load from the
                     // reverse indices.
                     SmallVector<Value, 2> reverseIvs(llvm::reverse(loopIvs));
                     return builder.create<AffineLoadOp>(loc, input,
                                                         reverseIvs);
                   });
    return success();
  }
};

} // end anonymous namespace.

//===----------------------------------------------------------------------===//
// ToyToAffineLoweringPass
//===----------------------------------------------------------------------===//

/// This is a partial lowering to affine loops of the Toy operations that are
/// computationally intensive (like matmul, for example) while keeping the
/// rest of the code in the Toy dialect.
namespace {
struct ToyToAffineLoweringPass
    : public PassWrapper<ToyToAffineLoweringPass, FunctionPass> {
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<AffineDialect, StandardOpsDialect>();
  }
  void runOnFunction() final;
};
} // end anonymous namespace.

void ToyToAffineLoweringPass::runOnFunction() {
  auto function = getFunction();

  // We only lower the main function as we expect that all other functions have
  // been inlined.
  if (function.getName() != "main")
    return;

  // Verify that the given main has no inputs and no results.
  if (function.getNumArguments() || function.getType().getNumResults()) {
    function.emitError("expected 'main' to have 0 inputs and 0 results");
    return signalPassFailure();
  }

  // The first thing to define is the conversion target. This will define the
  // final target for this lowering.
  ConversionTarget target(getContext());

  // We define the specific operations, or dialects, that are legal targets for
  // this lowering. In our case, we are lowering to a combination of the
  // `Affine` and `Standard` dialects.
  target.addLegalDialect<AffineDialect, StandardOpsDialect>();

  // We also define the Toy dialect as illegal so that the conversion will fail
  // if any of these operations are *not* converted. Given that we actually want
  // a partial lowering, we explicitly mark the Toy operation we don't want
  // to lower, `toy.print`, as legal.
  target.addIllegalDialect<toy::ToyDialect>();
  target.addLegalOp<toy::PrintOp>();

  // Now that the conversion target has been defined, we just need to provide
  // the set of patterns that will lower the Toy operations.
  OwningRewritePatternList patterns;
  patterns.insert<AddOpLowering, ConstantOpLowering, MulOpLowering,
                  ReturnOpLowering, TransposeOpLowering>(&getContext());

  // With the target and rewrite patterns defined, we can now attempt the
  // conversion. The conversion will signal failure if any of our `illegal`
  // operations were not converted successfully.
  if (failed(
          applyPartialConversion(getFunction(), target, std::move(patterns))))
    signalPassFailure();
}

/// Create a pass for lowering a subset of the Toy IR (e.g. matmul) to the
/// `Affine` and `Std` dialects.
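///
/// A driver would typically schedule it on a pass manager, e.g. (sketch):
///
///   mlir::PassManager pm(&context);
///   pm.addPass(mlir::toy::createLowerToAffinePass());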
std::unique_ptr<Pass> mlir::toy::createLowerToAffinePass() {
  return std::make_unique<ToyToAffineLoweringPass>();
}