1 //===- LegalizeStandardForSPIRV.cpp - Legalize ops for SPIR-V lowering ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This transformation pass legalizes operations before the conversion to SPIR-V
10 // dialect to handle ops that cannot be lowered directly.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "../PassDetail.h"
15 #include "mlir/Conversion/StandardToSPIRV/StandardToSPIRV.h"
16 #include "mlir/Conversion/StandardToSPIRV/StandardToSPIRVPass.h"
17 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
18 #include "mlir/Dialect/StandardOps/IR/Ops.h"
19 #include "mlir/Dialect/Vector/VectorOps.h"
20 #include "mlir/IR/BuiltinTypes.h"
21 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
22
23 using namespace mlir;
24
25 /// Helpers to access the memref operand for each op.
getMemRefOperand(LoadOp op)26 static Value getMemRefOperand(LoadOp op) { return op.memref(); }
27
getMemRefOperand(vector::TransferReadOp op)28 static Value getMemRefOperand(vector::TransferReadOp op) { return op.source(); }
29
getMemRefOperand(StoreOp op)30 static Value getMemRefOperand(StoreOp op) { return op.memref(); }
31
getMemRefOperand(vector::TransferWriteOp op)32 static Value getMemRefOperand(vector::TransferWriteOp op) {
33 return op.source();
34 }
35
36 namespace {
37 /// Merges subview operation with load/transferRead operation.
38 template <typename OpTy>
39 class LoadOpOfSubViewFolder final : public OpRewritePattern<OpTy> {
40 public:
41 using OpRewritePattern<OpTy>::OpRewritePattern;
42
43 LogicalResult matchAndRewrite(OpTy loadOp,
44 PatternRewriter &rewriter) const override;
45
46 private:
47 void replaceOp(OpTy loadOp, SubViewOp subViewOp,
48 ArrayRef<Value> sourceIndices,
49 PatternRewriter &rewriter) const;
50 };
51
52 /// Merges subview operation with store/transferWriteOp operation.
53 template <typename OpTy>
54 class StoreOpOfSubViewFolder final : public OpRewritePattern<OpTy> {
55 public:
56 using OpRewritePattern<OpTy>::OpRewritePattern;
57
58 LogicalResult matchAndRewrite(OpTy storeOp,
59 PatternRewriter &rewriter) const override;
60
61 private:
62 void replaceOp(OpTy StoreOp, SubViewOp subViewOp,
63 ArrayRef<Value> sourceIndices,
64 PatternRewriter &rewriter) const;
65 };
66
67 template <>
replaceOp(LoadOp loadOp,SubViewOp subViewOp,ArrayRef<Value> sourceIndices,PatternRewriter & rewriter) const68 void LoadOpOfSubViewFolder<LoadOp>::replaceOp(LoadOp loadOp,
69 SubViewOp subViewOp,
70 ArrayRef<Value> sourceIndices,
71 PatternRewriter &rewriter) const {
72 rewriter.replaceOpWithNewOp<LoadOp>(loadOp, subViewOp.source(),
73 sourceIndices);
74 }
75
76 template <>
replaceOp(vector::TransferReadOp loadOp,SubViewOp subViewOp,ArrayRef<Value> sourceIndices,PatternRewriter & rewriter) const77 void LoadOpOfSubViewFolder<vector::TransferReadOp>::replaceOp(
78 vector::TransferReadOp loadOp, SubViewOp subViewOp,
79 ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
80 rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
81 loadOp, loadOp.getVectorType(), subViewOp.source(), sourceIndices,
82 loadOp.permutation_map(), loadOp.padding(), loadOp.maskedAttr());
83 }
84
85 template <>
replaceOp(StoreOp storeOp,SubViewOp subViewOp,ArrayRef<Value> sourceIndices,PatternRewriter & rewriter) const86 void StoreOpOfSubViewFolder<StoreOp>::replaceOp(
87 StoreOp storeOp, SubViewOp subViewOp, ArrayRef<Value> sourceIndices,
88 PatternRewriter &rewriter) const {
89 rewriter.replaceOpWithNewOp<StoreOp>(storeOp, storeOp.value(),
90 subViewOp.source(), sourceIndices);
91 }
92
93 template <>
replaceOp(vector::TransferWriteOp tranferWriteOp,SubViewOp subViewOp,ArrayRef<Value> sourceIndices,PatternRewriter & rewriter) const94 void StoreOpOfSubViewFolder<vector::TransferWriteOp>::replaceOp(
95 vector::TransferWriteOp tranferWriteOp, SubViewOp subViewOp,
96 ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
97 rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
98 tranferWriteOp, tranferWriteOp.vector(), subViewOp.source(),
99 sourceIndices, tranferWriteOp.permutation_map(),
100 tranferWriteOp.maskedAttr());
101 }
102 } // namespace
103
104 //===----------------------------------------------------------------------===//
105 // Utility functions for op legalization.
106 //===----------------------------------------------------------------------===//
107
108 /// Given the 'indices' of an load/store operation where the memref is a result
109 /// of a subview op, returns the indices w.r.t to the source memref of the
110 /// subview op. For example
111 ///
112 /// %0 = ... : memref<12x42xf32>
113 /// %1 = subview %0[%arg0, %arg1][][%stride1, %stride2] : memref<12x42xf32> to
114 /// memref<4x4xf32, offset=?, strides=[?, ?]>
115 /// %2 = load %1[%i1, %i2] : memref<4x4xf32, offset=?, strides=[?, ?]>
116 ///
117 /// could be folded into
118 ///
119 /// %2 = load %0[%arg0 + %i1 * %stride1][%arg1 + %i2 * %stride2] :
120 /// memref<12x42xf32>
121 static LogicalResult
resolveSourceIndices(Location loc,PatternRewriter & rewriter,SubViewOp subViewOp,ValueRange indices,SmallVectorImpl<Value> & sourceIndices)122 resolveSourceIndices(Location loc, PatternRewriter &rewriter,
123 SubViewOp subViewOp, ValueRange indices,
124 SmallVectorImpl<Value> &sourceIndices) {
125 // TODO: Aborting when the offsets are static. There might be a way to fold
126 // the subview op with load even if the offsets have been canonicalized
127 // away.
128 SmallVector<Range, 4> opRanges = subViewOp.getOrCreateRanges(rewriter, loc);
129 auto opOffsets = llvm::map_range(opRanges, [](Range r) { return r.offset; });
130 auto opStrides = llvm::map_range(opRanges, [](Range r) { return r.stride; });
131 assert(opRanges.size() == indices.size() &&
132 "expected as many indices as rank of subview op result type");
133
134 // New indices for the load are the current indices * subview_stride +
135 // subview_offset.
136 sourceIndices.resize(indices.size());
137 for (auto index : llvm::enumerate(indices)) {
138 auto offset = *(opOffsets.begin() + index.index());
139 auto stride = *(opStrides.begin() + index.index());
140 auto mul = rewriter.create<MulIOp>(loc, index.value(), stride);
141 sourceIndices[index.index()] =
142 rewriter.create<AddIOp>(loc, offset, mul).getResult();
143 }
144 return success();
145 }
146
147 //===----------------------------------------------------------------------===//
148 // Folding SubViewOp and LoadOp/TransferReadOp.
149 //===----------------------------------------------------------------------===//
150
151 template <typename OpTy>
152 LogicalResult
matchAndRewrite(OpTy loadOp,PatternRewriter & rewriter) const153 LoadOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy loadOp,
154 PatternRewriter &rewriter) const {
155 auto subViewOp = getMemRefOperand(loadOp).template getDefiningOp<SubViewOp>();
156 if (!subViewOp) {
157 return failure();
158 }
159 SmallVector<Value, 4> sourceIndices;
160 if (failed(resolveSourceIndices(loadOp.getLoc(), rewriter, subViewOp,
161 loadOp.indices(), sourceIndices)))
162 return failure();
163
164 replaceOp(loadOp, subViewOp, sourceIndices, rewriter);
165 return success();
166 }
167
168 //===----------------------------------------------------------------------===//
169 // Folding SubViewOp and StoreOp/TransferWriteOp.
170 //===----------------------------------------------------------------------===//
171
172 template <typename OpTy>
173 LogicalResult
matchAndRewrite(OpTy storeOp,PatternRewriter & rewriter) const174 StoreOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy storeOp,
175 PatternRewriter &rewriter) const {
176 auto subViewOp =
177 getMemRefOperand(storeOp).template getDefiningOp<SubViewOp>();
178 if (!subViewOp) {
179 return failure();
180 }
181 SmallVector<Value, 4> sourceIndices;
182 if (failed(resolveSourceIndices(storeOp.getLoc(), rewriter, subViewOp,
183 storeOp.indices(), sourceIndices)))
184 return failure();
185
186 replaceOp(storeOp, subViewOp, sourceIndices, rewriter);
187 return success();
188 }
189
190 //===----------------------------------------------------------------------===//
191 // Hook for adding patterns.
192 //===----------------------------------------------------------------------===//
193
populateStdLegalizationPatternsForSPIRVLowering(MLIRContext * context,OwningRewritePatternList & patterns)194 void mlir::populateStdLegalizationPatternsForSPIRVLowering(
195 MLIRContext *context, OwningRewritePatternList &patterns) {
196 patterns.insert<LoadOpOfSubViewFolder<LoadOp>,
197 LoadOpOfSubViewFolder<vector::TransferReadOp>,
198 StoreOpOfSubViewFolder<StoreOp>,
199 StoreOpOfSubViewFolder<vector::TransferWriteOp>>(context);
200 }
201
202 //===----------------------------------------------------------------------===//
203 // Pass for testing just the legalization patterns.
204 //===----------------------------------------------------------------------===//
205
206 namespace {
207 struct SPIRVLegalization final
208 : public LegalizeStandardForSPIRVBase<SPIRVLegalization> {
209 void runOnOperation() override;
210 };
211 } // namespace
212
runOnOperation()213 void SPIRVLegalization::runOnOperation() {
214 OwningRewritePatternList patterns;
215 auto *context = &getContext();
216 populateStdLegalizationPatternsForSPIRVLowering(context, patterns);
217 applyPatternsAndFoldGreedily(getOperation()->getRegions(),
218 std::move(patterns));
219 }
220
createLegalizeStdOpsForSPIRVLoweringPass()221 std::unique_ptr<Pass> mlir::createLegalizeStdOpsForSPIRVLoweringPass() {
222 return std::make_unique<SPIRVLegalization>();
223 }
224