1 //===- LegalizeStandardForSPIRV.cpp - Legalize ops for SPIR-V lowering ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This transformation pass legalizes operations before the conversion to SPIR-V
10 // dialect to handle ops that cannot be lowered directly.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "../PassDetail.h"
15 #include "mlir/Conversion/StandardToSPIRV/StandardToSPIRV.h"
16 #include "mlir/Conversion/StandardToSPIRV/StandardToSPIRVPass.h"
17 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
18 #include "mlir/Dialect/StandardOps/IR/Ops.h"
19 #include "mlir/Dialect/Vector/VectorOps.h"
20 #include "mlir/IR/BuiltinTypes.h"
21 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
22 
23 using namespace mlir;
24 
25 /// Helpers to access the memref operand for each op.
getMemRefOperand(LoadOp op)26 static Value getMemRefOperand(LoadOp op) { return op.memref(); }
27 
getMemRefOperand(vector::TransferReadOp op)28 static Value getMemRefOperand(vector::TransferReadOp op) { return op.source(); }
29 
getMemRefOperand(StoreOp op)30 static Value getMemRefOperand(StoreOp op) { return op.memref(); }
31 
getMemRefOperand(vector::TransferWriteOp op)32 static Value getMemRefOperand(vector::TransferWriteOp op) {
33   return op.source();
34 }
35 
36 namespace {
37 /// Merges subview operation with load/transferRead operation.
38 template <typename OpTy>
39 class LoadOpOfSubViewFolder final : public OpRewritePattern<OpTy> {
40 public:
41   using OpRewritePattern<OpTy>::OpRewritePattern;
42 
43   LogicalResult matchAndRewrite(OpTy loadOp,
44                                 PatternRewriter &rewriter) const override;
45 
46 private:
47   void replaceOp(OpTy loadOp, SubViewOp subViewOp,
48                  ArrayRef<Value> sourceIndices,
49                  PatternRewriter &rewriter) const;
50 };
51 
52 /// Merges subview operation with store/transferWriteOp operation.
53 template <typename OpTy>
54 class StoreOpOfSubViewFolder final : public OpRewritePattern<OpTy> {
55 public:
56   using OpRewritePattern<OpTy>::OpRewritePattern;
57 
58   LogicalResult matchAndRewrite(OpTy storeOp,
59                                 PatternRewriter &rewriter) const override;
60 
61 private:
62   void replaceOp(OpTy StoreOp, SubViewOp subViewOp,
63                  ArrayRef<Value> sourceIndices,
64                  PatternRewriter &rewriter) const;
65 };
66 
67 template <>
replaceOp(LoadOp loadOp,SubViewOp subViewOp,ArrayRef<Value> sourceIndices,PatternRewriter & rewriter) const68 void LoadOpOfSubViewFolder<LoadOp>::replaceOp(LoadOp loadOp,
69                                               SubViewOp subViewOp,
70                                               ArrayRef<Value> sourceIndices,
71                                               PatternRewriter &rewriter) const {
72   rewriter.replaceOpWithNewOp<LoadOp>(loadOp, subViewOp.source(),
73                                       sourceIndices);
74 }
75 
76 template <>
replaceOp(vector::TransferReadOp loadOp,SubViewOp subViewOp,ArrayRef<Value> sourceIndices,PatternRewriter & rewriter) const77 void LoadOpOfSubViewFolder<vector::TransferReadOp>::replaceOp(
78     vector::TransferReadOp loadOp, SubViewOp subViewOp,
79     ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
80   rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
81       loadOp, loadOp.getVectorType(), subViewOp.source(), sourceIndices,
82       loadOp.permutation_map(), loadOp.padding(), loadOp.maskedAttr());
83 }
84 
85 template <>
replaceOp(StoreOp storeOp,SubViewOp subViewOp,ArrayRef<Value> sourceIndices,PatternRewriter & rewriter) const86 void StoreOpOfSubViewFolder<StoreOp>::replaceOp(
87     StoreOp storeOp, SubViewOp subViewOp, ArrayRef<Value> sourceIndices,
88     PatternRewriter &rewriter) const {
89   rewriter.replaceOpWithNewOp<StoreOp>(storeOp, storeOp.value(),
90                                        subViewOp.source(), sourceIndices);
91 }
92 
93 template <>
replaceOp(vector::TransferWriteOp tranferWriteOp,SubViewOp subViewOp,ArrayRef<Value> sourceIndices,PatternRewriter & rewriter) const94 void StoreOpOfSubViewFolder<vector::TransferWriteOp>::replaceOp(
95     vector::TransferWriteOp tranferWriteOp, SubViewOp subViewOp,
96     ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
97   rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
98       tranferWriteOp, tranferWriteOp.vector(), subViewOp.source(),
99       sourceIndices, tranferWriteOp.permutation_map(),
100       tranferWriteOp.maskedAttr());
101 }
102 } // namespace
103 
104 //===----------------------------------------------------------------------===//
105 // Utility functions for op legalization.
106 //===----------------------------------------------------------------------===//
107 
108 /// Given the 'indices' of an load/store operation where the memref is a result
109 /// of a subview op, returns the indices w.r.t to the source memref of the
110 /// subview op. For example
111 ///
112 /// %0 = ... : memref<12x42xf32>
113 /// %1 = subview %0[%arg0, %arg1][][%stride1, %stride2] : memref<12x42xf32> to
114 ///          memref<4x4xf32, offset=?, strides=[?, ?]>
115 /// %2 = load %1[%i1, %i2] : memref<4x4xf32, offset=?, strides=[?, ?]>
116 ///
117 /// could be folded into
118 ///
119 /// %2 = load %0[%arg0 + %i1 * %stride1][%arg1 + %i2 * %stride2] :
120 ///          memref<12x42xf32>
121 static LogicalResult
resolveSourceIndices(Location loc,PatternRewriter & rewriter,SubViewOp subViewOp,ValueRange indices,SmallVectorImpl<Value> & sourceIndices)122 resolveSourceIndices(Location loc, PatternRewriter &rewriter,
123                      SubViewOp subViewOp, ValueRange indices,
124                      SmallVectorImpl<Value> &sourceIndices) {
125   // TODO: Aborting when the offsets are static. There might be a way to fold
126   // the subview op with load even if the offsets have been canonicalized
127   // away.
128   SmallVector<Range, 4> opRanges = subViewOp.getOrCreateRanges(rewriter, loc);
129   auto opOffsets = llvm::map_range(opRanges, [](Range r) { return r.offset; });
130   auto opStrides = llvm::map_range(opRanges, [](Range r) { return r.stride; });
131   assert(opRanges.size() == indices.size() &&
132          "expected as many indices as rank of subview op result type");
133 
134   // New indices for the load are the current indices * subview_stride +
135   // subview_offset.
136   sourceIndices.resize(indices.size());
137   for (auto index : llvm::enumerate(indices)) {
138     auto offset = *(opOffsets.begin() + index.index());
139     auto stride = *(opStrides.begin() + index.index());
140     auto mul = rewriter.create<MulIOp>(loc, index.value(), stride);
141     sourceIndices[index.index()] =
142         rewriter.create<AddIOp>(loc, offset, mul).getResult();
143   }
144   return success();
145 }
146 
147 //===----------------------------------------------------------------------===//
148 // Folding SubViewOp and LoadOp/TransferReadOp.
149 //===----------------------------------------------------------------------===//
150 
151 template <typename OpTy>
152 LogicalResult
matchAndRewrite(OpTy loadOp,PatternRewriter & rewriter) const153 LoadOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy loadOp,
154                                              PatternRewriter &rewriter) const {
155   auto subViewOp = getMemRefOperand(loadOp).template getDefiningOp<SubViewOp>();
156   if (!subViewOp) {
157     return failure();
158   }
159   SmallVector<Value, 4> sourceIndices;
160   if (failed(resolveSourceIndices(loadOp.getLoc(), rewriter, subViewOp,
161                                   loadOp.indices(), sourceIndices)))
162     return failure();
163 
164   replaceOp(loadOp, subViewOp, sourceIndices, rewriter);
165   return success();
166 }
167 
168 //===----------------------------------------------------------------------===//
169 // Folding SubViewOp and StoreOp/TransferWriteOp.
170 //===----------------------------------------------------------------------===//
171 
172 template <typename OpTy>
173 LogicalResult
matchAndRewrite(OpTy storeOp,PatternRewriter & rewriter) const174 StoreOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy storeOp,
175                                               PatternRewriter &rewriter) const {
176   auto subViewOp =
177       getMemRefOperand(storeOp).template getDefiningOp<SubViewOp>();
178   if (!subViewOp) {
179     return failure();
180   }
181   SmallVector<Value, 4> sourceIndices;
182   if (failed(resolveSourceIndices(storeOp.getLoc(), rewriter, subViewOp,
183                                   storeOp.indices(), sourceIndices)))
184     return failure();
185 
186   replaceOp(storeOp, subViewOp, sourceIndices, rewriter);
187   return success();
188 }
189 
190 //===----------------------------------------------------------------------===//
191 // Hook for adding patterns.
192 //===----------------------------------------------------------------------===//
193 
populateStdLegalizationPatternsForSPIRVLowering(MLIRContext * context,OwningRewritePatternList & patterns)194 void mlir::populateStdLegalizationPatternsForSPIRVLowering(
195     MLIRContext *context, OwningRewritePatternList &patterns) {
196   patterns.insert<LoadOpOfSubViewFolder<LoadOp>,
197                   LoadOpOfSubViewFolder<vector::TransferReadOp>,
198                   StoreOpOfSubViewFolder<StoreOp>,
199                   StoreOpOfSubViewFolder<vector::TransferWriteOp>>(context);
200 }
201 
202 //===----------------------------------------------------------------------===//
203 // Pass for testing just the legalization patterns.
204 //===----------------------------------------------------------------------===//
205 
206 namespace {
207 struct SPIRVLegalization final
208     : public LegalizeStandardForSPIRVBase<SPIRVLegalization> {
209   void runOnOperation() override;
210 };
211 } // namespace
212 
runOnOperation()213 void SPIRVLegalization::runOnOperation() {
214   OwningRewritePatternList patterns;
215   auto *context = &getContext();
216   populateStdLegalizationPatternsForSPIRVLowering(context, patterns);
217   applyPatternsAndFoldGreedily(getOperation()->getRegions(),
218                                std::move(patterns));
219 }
220 
createLegalizeStdOpsForSPIRVLoweringPass()221 std::unique_ptr<Pass> mlir::createLegalizeStdOpsForSPIRVLoweringPass() {
222   return std::make_unique<SPIRVLegalization>();
223 }
224