1 //===- TestVectorToVectorConversion.cpp - Test VectorTransfers lowering ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include <type_traits>
10
11 #include "mlir/Analysis/SliceAnalysis.h"
12 #include "mlir/Dialect/Affine/IR/AffineOps.h"
13 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
14 #include "mlir/Dialect/SCF/SCF.h"
15 #include "mlir/Dialect/StandardOps/IR/Ops.h"
16 #include "mlir/Dialect/Vector/VectorOps.h"
17 #include "mlir/Dialect/Vector/VectorTransforms.h"
18 #include "mlir/Pass/Pass.h"
19 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
20
21 using namespace mlir;
22 using namespace mlir::vector;
23 namespace {
24
25 struct TestVectorToVectorConversion
26 : public PassWrapper<TestVectorToVectorConversion, FunctionPass> {
runOnFunction__anonfe31e6f10111::TestVectorToVectorConversion27 void runOnFunction() override {
28 OwningRewritePatternList patterns;
29 auto *ctx = &getContext();
30 patterns.insert<UnrollVectorPattern<AddFOp>>(
31 ctx, UnrollVectorOptions().setNativeShape(ArrayRef<int64_t>{2, 2}));
32 patterns.insert<UnrollVectorPattern<vector::ContractionOp>>(
33 ctx, UnrollVectorOptions().setNativeShape(ArrayRef<int64_t>{2, 2, 2}));
34 populateVectorToVectorCanonicalizationPatterns(patterns, ctx);
35 populateVectorToVectorTransformationPatterns(patterns, ctx);
36 applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
37 }
38 };
39
40 struct TestVectorSlicesConversion
41 : public PassWrapper<TestVectorSlicesConversion, FunctionPass> {
runOnFunction__anonfe31e6f10111::TestVectorSlicesConversion42 void runOnFunction() override {
43 OwningRewritePatternList patterns;
44 populateVectorSlicesLoweringPatterns(patterns, &getContext());
45 applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
46 }
47 };
48
49 struct TestVectorContractionConversion
50 : public PassWrapper<TestVectorContractionConversion, FunctionPass> {
51 TestVectorContractionConversion() = default;
TestVectorContractionConversion__anonfe31e6f10111::TestVectorContractionConversion52 TestVectorContractionConversion(const TestVectorContractionConversion &pass) {
53 }
54
55 Option<bool> lowerToFlatMatrix{
56 *this, "vector-lower-matrix-intrinsics",
57 llvm::cl::desc("Lower vector.contract to llvm.intr.matrix.multiply"),
58 llvm::cl::init(false)};
59 Option<bool> lowerToFlatTranspose{
60 *this, "vector-flat-transpose",
61 llvm::cl::desc("Lower 2-D vector.transpose to vector.flat_transpose"),
62 llvm::cl::init(false)};
63 Option<bool> lowerToOuterProduct{
64 *this, "vector-outerproduct",
65 llvm::cl::desc("Lower vector.contract to vector.outerproduct"),
66 llvm::cl::init(false)};
67 Option<bool> lowerToFilterOuterProduct{
68 *this, "vector-filter-outerproduct",
69 llvm::cl::desc("Lower vector.contract to vector.outerproduct but not for "
70 "vectors of size 4."),
71 llvm::cl::init(false)};
72
runOnFunction__anonfe31e6f10111::TestVectorContractionConversion73 void runOnFunction() override {
74 OwningRewritePatternList patterns;
75
76 // Test on one pattern in isolation.
77 if (lowerToOuterProduct) {
78 VectorContractLowering lowering = VectorContractLowering::OuterProduct;
79 VectorTransformsOptions options{lowering};
80 patterns.insert<ContractionOpToOuterProductOpLowering>(options,
81 &getContext());
82 applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
83 return;
84 }
85
86 // Test on one pattern in isolation.
87 if (lowerToFilterOuterProduct) {
88 VectorContractLowering lowering = VectorContractLowering::OuterProduct;
89 VectorTransformsOptions options{lowering};
90 patterns.insert<ContractionOpToOuterProductOpLowering>(
91 options, &getContext(), [](vector::ContractionOp op) {
92 // Only lowers vector.contract where the lhs as a type vector<MxNx?>
93 // where M is not 4.
94 if (op.getRhsType().getShape()[0] == 4)
95 return failure();
96 return success();
97 });
98 applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
99 return;
100 }
101
102 // Test on all contract lowering patterns.
103 VectorContractLowering contractLowering = VectorContractLowering::Dot;
104 if (lowerToFlatMatrix)
105 contractLowering = VectorContractLowering::Matmul;
106 VectorTransposeLowering transposeLowering =
107 VectorTransposeLowering::EltWise;
108 if (lowerToFlatTranspose)
109 transposeLowering = VectorTransposeLowering::Flat;
110 VectorTransformsOptions options{contractLowering, transposeLowering};
111 populateVectorContractLoweringPatterns(patterns, &getContext(), options);
112 applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
113 }
114 };
115
116 struct TestVectorUnrollingPatterns
117 : public PassWrapper<TestVectorUnrollingPatterns, FunctionPass> {
118 TestVectorUnrollingPatterns() = default;
TestVectorUnrollingPatterns__anonfe31e6f10111::TestVectorUnrollingPatterns119 TestVectorUnrollingPatterns(const TestVectorUnrollingPatterns &pass) {}
runOnFunction__anonfe31e6f10111::TestVectorUnrollingPatterns120 void runOnFunction() override {
121 MLIRContext *ctx = &getContext();
122 OwningRewritePatternList patterns;
123 patterns.insert<UnrollVectorPattern<AddFOp>>(
124 ctx, UnrollVectorOptions().setNativeShape(ArrayRef<int64_t>{2, 2}));
125
126 if (unrollBasedOnType) {
127 UnrollVectorOptions::NativeShapeFnType nativeShapeFn =
128 [](Operation *op) -> Optional<SmallVector<int64_t, 4>> {
129 vector::ContractionOp contractOp = cast<vector::ContractionOp>(op);
130 SmallVector<int64_t, 4> nativeShape = {4, 4, 2};
131 if (auto floatType = contractOp.getLhsType()
132 .getElementType()
133 .dyn_cast<FloatType>()) {
134 if (floatType.getWidth() == 16) {
135 nativeShape[2] = 4;
136 }
137 }
138 return nativeShape;
139 };
140 patterns.insert<UnrollVectorPattern<vector::ContractionOp>>(
141 ctx, UnrollVectorOptions().setNativeShapeFn(nativeShapeFn));
142 } else {
143 patterns.insert<UnrollVectorPattern<vector::ContractionOp>>(
144 ctx,
145 UnrollVectorOptions().setNativeShape(ArrayRef<int64_t>{2, 2, 2}));
146 }
147 populateVectorToVectorCanonicalizationPatterns(patterns, ctx);
148 populateVectorToVectorTransformationPatterns(patterns, ctx);
149 applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
150 }
151
152 Option<bool> unrollBasedOnType{
153 *this, "unroll-based-on-type",
154 llvm::cl::desc("Set the unroll factor based on type of the operation"),
155 llvm::cl::init(false)};
156 };
157
158 struct TestVectorDistributePatterns
159 : public PassWrapper<TestVectorDistributePatterns, FunctionPass> {
160 TestVectorDistributePatterns() = default;
TestVectorDistributePatterns__anonfe31e6f10111::TestVectorDistributePatterns161 TestVectorDistributePatterns(const TestVectorDistributePatterns &pass) {}
getDependentDialects__anonfe31e6f10111::TestVectorDistributePatterns162 void getDependentDialects(DialectRegistry ®istry) const override {
163 registry.insert<VectorDialect>();
164 registry.insert<AffineDialect>();
165 }
166 Option<int32_t> multiplicity{
167 *this, "distribution-multiplicity",
168 llvm::cl::desc("Set the multiplicity used for distributing vector"),
169 llvm::cl::init(32)};
runOnFunction__anonfe31e6f10111::TestVectorDistributePatterns170 void runOnFunction() override {
171 MLIRContext *ctx = &getContext();
172 OwningRewritePatternList patterns;
173 FuncOp func = getFunction();
174 func.walk([&](AddFOp op) {
175 OpBuilder builder(op);
176 Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp(
177 builder, op.getOperation(), func.getArgument(0), multiplicity);
178 if (ops.hasValue()) {
179 SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert});
180 op.getResult().replaceAllUsesExcept(ops->insert.getResult(), extractOp);
181 }
182 });
183 patterns.insert<PointwiseExtractPattern>(ctx);
184 populateVectorToVectorTransformationPatterns(patterns, ctx);
185 applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
186 }
187 };
188
189 struct TestVectorToLoopPatterns
190 : public PassWrapper<TestVectorToLoopPatterns, FunctionPass> {
191 TestVectorToLoopPatterns() = default;
TestVectorToLoopPatterns__anonfe31e6f10111::TestVectorToLoopPatterns192 TestVectorToLoopPatterns(const TestVectorToLoopPatterns &pass) {}
getDependentDialects__anonfe31e6f10111::TestVectorToLoopPatterns193 void getDependentDialects(DialectRegistry ®istry) const override {
194 registry.insert<VectorDialect>();
195 registry.insert<AffineDialect>();
196 }
197 Option<int32_t> multiplicity{
198 *this, "distribution-multiplicity",
199 llvm::cl::desc("Set the multiplicity used for distributing vector"),
200 llvm::cl::init(32)};
runOnFunction__anonfe31e6f10111::TestVectorToLoopPatterns201 void runOnFunction() override {
202 MLIRContext *ctx = &getContext();
203 OwningRewritePatternList patterns;
204 FuncOp func = getFunction();
205 func.walk([&](AddFOp op) {
206 // Check that the operation type can be broken down into a loop.
207 VectorType type = op.getType().dyn_cast<VectorType>();
208 if (!type || type.getRank() != 1 ||
209 type.getNumElements() % multiplicity != 0)
210 return mlir::WalkResult::advance();
211 auto filterAlloc = [](Operation *op) {
212 if (isa<ConstantOp, AllocOp, CallOp>(op))
213 return false;
214 return true;
215 };
216 auto dependentOps = getSlice(op, filterAlloc);
217 // Create a loop and move instructions from the Op slice into the loop.
218 OpBuilder builder(op);
219 auto zero = builder.create<ConstantOp>(
220 op.getLoc(), builder.getIndexType(),
221 builder.getIntegerAttr(builder.getIndexType(), 0));
222 auto one = builder.create<ConstantOp>(
223 op.getLoc(), builder.getIndexType(),
224 builder.getIntegerAttr(builder.getIndexType(), 1));
225 auto numIter = builder.create<ConstantOp>(
226 op.getLoc(), builder.getIndexType(),
227 builder.getIntegerAttr(builder.getIndexType(), multiplicity));
228 auto forOp = builder.create<scf::ForOp>(op.getLoc(), zero, numIter, one);
229 for (Operation *it : dependentOps) {
230 it->moveBefore(forOp.getBody()->getTerminator());
231 }
232 // break up the original op and let the patterns propagate.
233 Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp(
234 builder, op.getOperation(), forOp.getInductionVar(), multiplicity);
235 if (ops.hasValue()) {
236 SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert});
237 op.getResult().replaceAllUsesExcept(ops->insert.getResult(), extractOp);
238 }
239 return mlir::WalkResult::interrupt();
240 });
241 patterns.insert<PointwiseExtractPattern>(ctx);
242 populateVectorToVectorTransformationPatterns(patterns, ctx);
243 applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
244 }
245 };
246
247 struct TestVectorTransferUnrollingPatterns
248 : public PassWrapper<TestVectorTransferUnrollingPatterns, FunctionPass> {
getDependentDialects__anonfe31e6f10111::TestVectorTransferUnrollingPatterns249 void getDependentDialects(DialectRegistry ®istry) const override {
250 registry.insert<AffineDialect>();
251 }
runOnFunction__anonfe31e6f10111::TestVectorTransferUnrollingPatterns252 void runOnFunction() override {
253 MLIRContext *ctx = &getContext();
254 OwningRewritePatternList patterns;
255 patterns.insert<UnrollVectorPattern<vector::TransferReadOp>>(
256 ctx, UnrollVectorOptions().setNativeShape(ArrayRef<int64_t>{2, 2}));
257 patterns.insert<UnrollVectorPattern<vector::TransferWriteOp>>(
258 ctx, UnrollVectorOptions().setNativeShape(ArrayRef<int64_t>{2, 2}));
259 populateVectorToVectorCanonicalizationPatterns(patterns, ctx);
260 populateVectorToVectorTransformationPatterns(patterns, ctx);
261 applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
262 }
263 };
264
265 struct TestVectorTransferFullPartialSplitPatterns
266 : public PassWrapper<TestVectorTransferFullPartialSplitPatterns,
267 FunctionPass> {
268 TestVectorTransferFullPartialSplitPatterns() = default;
TestVectorTransferFullPartialSplitPatterns__anonfe31e6f10111::TestVectorTransferFullPartialSplitPatterns269 TestVectorTransferFullPartialSplitPatterns(
270 const TestVectorTransferFullPartialSplitPatterns &pass) {}
271
getDependentDialects__anonfe31e6f10111::TestVectorTransferFullPartialSplitPatterns272 void getDependentDialects(DialectRegistry ®istry) const override {
273 registry.insert<AffineDialect, linalg::LinalgDialect, scf::SCFDialect>();
274 }
275
276 Option<bool> useLinalgOps{
277 *this, "use-linalg-copy",
278 llvm::cl::desc("Split using a unmasked vector.transfer + linalg.fill + "
279 "linalg.copy operations."),
280 llvm::cl::init(false)};
runOnFunction__anonfe31e6f10111::TestVectorTransferFullPartialSplitPatterns281 void runOnFunction() override {
282 MLIRContext *ctx = &getContext();
283 OwningRewritePatternList patterns;
284 VectorTransformsOptions options;
285 if (useLinalgOps)
286 options.setVectorTransferSplit(VectorTransferSplit::LinalgCopy);
287 else
288 options.setVectorTransferSplit(VectorTransferSplit::VectorTransfer);
289 patterns.insert<VectorTransferFullPartialRewriter>(ctx, options);
290 applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
291 }
292 };
293
294 } // end anonymous namespace
295
296 namespace mlir {
297 namespace test {
registerTestVectorConversions()298 void registerTestVectorConversions() {
299 PassRegistration<TestVectorToVectorConversion> vectorToVectorPass(
300 "test-vector-to-vector-conversion",
301 "Test conversion patterns between ops in the vector dialect");
302
303 PassRegistration<TestVectorSlicesConversion> slicesPass(
304 "test-vector-slices-conversion",
305 "Test conversion patterns that lower slices ops in the vector dialect");
306
307 PassRegistration<TestVectorContractionConversion> contractionPass(
308 "test-vector-contraction-conversion",
309 "Test conversion patterns that lower contract ops in the vector dialect");
310
311 PassRegistration<TestVectorUnrollingPatterns> contractionUnrollingPass(
312 "test-vector-unrolling-patterns",
313 "Test conversion patterns to unroll contract ops in the vector dialect");
314
315 PassRegistration<TestVectorTransferUnrollingPatterns> transferOpUnrollingPass(
316 "test-vector-transfer-unrolling-patterns",
317 "Test conversion patterns to unroll transfer ops in the vector dialect");
318
319 PassRegistration<TestVectorTransferFullPartialSplitPatterns>
320 vectorTransformFullPartialPass("test-vector-transfer-full-partial-split",
321 "Test conversion patterns to split "
322 "transfer ops via scf.if + linalg ops");
323 PassRegistration<TestVectorDistributePatterns> distributePass(
324 "test-vector-distribute-patterns",
325 "Test conversion patterns to distribute vector ops in the vector "
326 "dialect");
327 PassRegistration<TestVectorToLoopPatterns> vectorToForLoop(
328 "test-vector-to-forloop",
329 "Test conversion patterns to break up a vector op into a for loop");
330 }
331 } // namespace test
332 } // namespace mlir
333