1 //===- TestVectorToVectorConversion.cpp - Test VectorTransfers lowering ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include <type_traits>
10
11 #include "mlir/Analysis/SliceAnalysis.h"
12 #include "mlir/Dialect/Affine/IR/AffineOps.h"
13 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
14 #include "mlir/Dialect/SCF/SCF.h"
15 #include "mlir/Dialect/StandardOps/IR/Ops.h"
16 #include "mlir/Dialect/Vector/VectorOps.h"
17 #include "mlir/Dialect/Vector/VectorTransforms.h"
18 #include "mlir/Pass/Pass.h"
19 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
20
21 using namespace mlir;
22 using namespace mlir::vector;
23 namespace {
24
25 struct TestVectorToVectorConversion
26 : public PassWrapper<TestVectorToVectorConversion, FunctionPass> {
27 TestVectorToVectorConversion() = default;
TestVectorToVectorConversion__anonafb160640111::TestVectorToVectorConversion28 TestVectorToVectorConversion(const TestVectorToVectorConversion &pass) {}
29
getDependentDialects__anonafb160640111::TestVectorToVectorConversion30 void getDependentDialects(DialectRegistry ®istry) const override {
31 registry.insert<AffineDialect>();
32 }
33
34 Option<bool> unroll{*this, "unroll", llvm::cl::desc("Include unrolling"),
35 llvm::cl::init(false)};
36
runOnFunction__anonafb160640111::TestVectorToVectorConversion37 void runOnFunction() override {
38 OwningRewritePatternList patterns;
39 auto *ctx = &getContext();
40 if (unroll) {
41 patterns.insert<UnrollVectorPattern>(
42 ctx,
43 UnrollVectorOptions().setNativeShapeFn(getShape).setFilterConstraint(
44 filter));
45 }
46 populateVectorToVectorCanonicalizationPatterns(patterns, ctx);
47 populateVectorToVectorTransformationPatterns(patterns, ctx);
48 applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
49 }
50
51 private:
52 // Return the target shape based on op type.
getShape__anonafb160640111::TestVectorToVectorConversion53 static Optional<SmallVector<int64_t, 4>> getShape(Operation *op) {
54 if (isa<AddFOp, SelectOp, CmpFOp>(op))
55 return SmallVector<int64_t, 4>(2, 2);
56 if (isa<vector::ContractionOp>(op))
57 return SmallVector<int64_t, 4>(3, 2);
58 return llvm::None;
59 }
60
filter__anonafb160640111::TestVectorToVectorConversion61 static LogicalResult filter(Operation *op) {
62 return success(isa<AddFOp, SelectOp, CmpFOp, ContractionOp>(op));
63 }
64 };
65
66 struct TestVectorSlicesConversion
67 : public PassWrapper<TestVectorSlicesConversion, FunctionPass> {
runOnFunction__anonafb160640111::TestVectorSlicesConversion68 void runOnFunction() override {
69 OwningRewritePatternList patterns;
70 populateVectorSlicesLoweringPatterns(patterns, &getContext());
71 applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
72 }
73 };
74
75 struct TestVectorContractionConversion
76 : public PassWrapper<TestVectorContractionConversion, FunctionPass> {
77 TestVectorContractionConversion() = default;
TestVectorContractionConversion__anonafb160640111::TestVectorContractionConversion78 TestVectorContractionConversion(const TestVectorContractionConversion &pass) {
79 }
80
81 Option<bool> lowerToFlatMatrix{
82 *this, "vector-lower-matrix-intrinsics",
83 llvm::cl::desc("Lower vector.contract to llvm.intr.matrix.multiply"),
84 llvm::cl::init(false)};
85 Option<bool> lowerToFlatTranspose{
86 *this, "vector-flat-transpose",
87 llvm::cl::desc("Lower 2-D vector.transpose to vector.flat_transpose"),
88 llvm::cl::init(false)};
89 Option<bool> lowerToOuterProduct{
90 *this, "vector-outerproduct",
91 llvm::cl::desc("Lower vector.contract to vector.outerproduct"),
92 llvm::cl::init(false)};
93 Option<bool> lowerToFilterOuterProduct{
94 *this, "vector-filter-outerproduct",
95 llvm::cl::desc("Lower vector.contract to vector.outerproduct but not for "
96 "vectors of size 4."),
97 llvm::cl::init(false)};
98
runOnFunction__anonafb160640111::TestVectorContractionConversion99 void runOnFunction() override {
100 OwningRewritePatternList patterns;
101
102 // Test on one pattern in isolation.
103 if (lowerToOuterProduct) {
104 VectorContractLowering lowering = VectorContractLowering::OuterProduct;
105 VectorTransformsOptions options{lowering};
106 patterns.insert<ContractionOpToOuterProductOpLowering>(options,
107 &getContext());
108 applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
109 return;
110 }
111
112 // Test on one pattern in isolation.
113 if (lowerToFilterOuterProduct) {
114 VectorContractLowering lowering = VectorContractLowering::OuterProduct;
115 VectorTransformsOptions options{lowering};
116 patterns.insert<ContractionOpToOuterProductOpLowering>(
117 options, &getContext(), [](vector::ContractionOp op) {
118 // Only lowers vector.contract where the lhs as a type vector<MxNx?>
119 // where M is not 4.
120 if (op.getRhsType().getShape()[0] == 4)
121 return failure();
122 return success();
123 });
124 applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
125 return;
126 }
127
128 // Test on all contract lowering patterns.
129 VectorContractLowering contractLowering = VectorContractLowering::Dot;
130 if (lowerToFlatMatrix)
131 contractLowering = VectorContractLowering::Matmul;
132 VectorTransposeLowering transposeLowering =
133 VectorTransposeLowering::EltWise;
134 if (lowerToFlatTranspose)
135 transposeLowering = VectorTransposeLowering::Flat;
136 VectorTransformsOptions options{contractLowering, transposeLowering};
137 populateVectorContractLoweringPatterns(patterns, &getContext(), options);
138 applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
139 }
140 };
141
142 struct TestVectorUnrollingPatterns
143 : public PassWrapper<TestVectorUnrollingPatterns, FunctionPass> {
144 TestVectorUnrollingPatterns() = default;
TestVectorUnrollingPatterns__anonafb160640111::TestVectorUnrollingPatterns145 TestVectorUnrollingPatterns(const TestVectorUnrollingPatterns &pass) {}
runOnFunction__anonafb160640111::TestVectorUnrollingPatterns146 void runOnFunction() override {
147 MLIRContext *ctx = &getContext();
148 OwningRewritePatternList patterns;
149 patterns.insert<UnrollVectorPattern>(
150 ctx, UnrollVectorOptions()
151 .setNativeShape(ArrayRef<int64_t>{2, 2})
152 .setFilterConstraint(
153 [](Operation *op) { return success(isa<AddFOp>(op)); }));
154
155 if (unrollBasedOnType) {
156 UnrollVectorOptions::NativeShapeFnType nativeShapeFn =
157 [](Operation *op) -> Optional<SmallVector<int64_t, 4>> {
158 vector::ContractionOp contractOp = cast<vector::ContractionOp>(op);
159 SmallVector<int64_t, 4> nativeShape = {4, 4, 2};
160 if (auto floatType = contractOp.getLhsType()
161 .getElementType()
162 .dyn_cast<FloatType>()) {
163 if (floatType.getWidth() == 16) {
164 nativeShape[2] = 4;
165 }
166 }
167 return nativeShape;
168 };
169 patterns.insert<UnrollVectorPattern>(
170 ctx, UnrollVectorOptions()
171 .setNativeShapeFn(nativeShapeFn)
172 .setFilterConstraint([](Operation *op) {
173 return success(isa<ContractionOp>(op));
174 }));
175 } else {
176 patterns.insert<UnrollVectorPattern>(
177 ctx, UnrollVectorOptions()
178 .setNativeShape(ArrayRef<int64_t>{2, 2, 2})
179 .setFilterConstraint([](Operation *op) {
180 return success(isa<ContractionOp>(op));
181 }));
182 }
183 populateVectorToVectorCanonicalizationPatterns(patterns, ctx);
184 populateVectorToVectorTransformationPatterns(patterns, ctx);
185 applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
186 }
187
188 Option<bool> unrollBasedOnType{
189 *this, "unroll-based-on-type",
190 llvm::cl::desc("Set the unroll factor based on type of the operation"),
191 llvm::cl::init(false)};
192 };
193
194 struct TestVectorDistributePatterns
195 : public PassWrapper<TestVectorDistributePatterns, FunctionPass> {
196 TestVectorDistributePatterns() = default;
TestVectorDistributePatterns__anonafb160640111::TestVectorDistributePatterns197 TestVectorDistributePatterns(const TestVectorDistributePatterns &pass) {}
getDependentDialects__anonafb160640111::TestVectorDistributePatterns198 void getDependentDialects(DialectRegistry ®istry) const override {
199 registry.insert<VectorDialect>();
200 registry.insert<AffineDialect>();
201 }
202 ListOption<int32_t> multiplicity{
203 *this, "distribution-multiplicity", llvm::cl::MiscFlags::CommaSeparated,
204 llvm::cl::desc("Set the multiplicity used for distributing vector")};
205
runOnFunction__anonafb160640111::TestVectorDistributePatterns206 void runOnFunction() override {
207 MLIRContext *ctx = &getContext();
208 OwningRewritePatternList patterns;
209 FuncOp func = getFunction();
210 func.walk([&](AddFOp op) {
211 OpBuilder builder(op);
212 if (auto vecType = op.getType().dyn_cast<VectorType>()) {
213 SmallVector<int64_t, 2> mul;
214 SmallVector<AffineExpr, 2> perm;
215 SmallVector<Value, 2> ids;
216 unsigned count = 0;
217 // Remove the multiplicity of 1 and calculate the affine map based on
218 // the multiplicity.
219 SmallVector<int32_t, 4> m(multiplicity.begin(), multiplicity.end());
220 for (unsigned i = 0, e = vecType.getRank(); i < e; i++) {
221 if (i < m.size() && m[i] != 1 && vecType.getDimSize(i) % m[i] == 0) {
222 mul.push_back(m[i]);
223 ids.push_back(func.getArgument(count++));
224 perm.push_back(getAffineDimExpr(i, ctx));
225 }
226 }
227 auto map = AffineMap::get(op.getType().cast<VectorType>().getRank(), 0,
228 perm, ctx);
229 Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp(
230 builder, op.getOperation(), ids, mul, map);
231 if (ops.hasValue()) {
232 SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert});
233 op.getResult().replaceAllUsesExcept(ops->insert.getResult(),
234 extractOp);
235 }
236 }
237 });
238 patterns.insert<PointwiseExtractPattern>(ctx);
239 populateVectorToVectorTransformationPatterns(patterns, ctx);
240 applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
241 }
242 };
243
244 struct TestVectorToLoopPatterns
245 : public PassWrapper<TestVectorToLoopPatterns, FunctionPass> {
246 TestVectorToLoopPatterns() = default;
TestVectorToLoopPatterns__anonafb160640111::TestVectorToLoopPatterns247 TestVectorToLoopPatterns(const TestVectorToLoopPatterns &pass) {}
getDependentDialects__anonafb160640111::TestVectorToLoopPatterns248 void getDependentDialects(DialectRegistry ®istry) const override {
249 registry.insert<VectorDialect>();
250 registry.insert<AffineDialect>();
251 }
252 Option<int32_t> multiplicity{
253 *this, "distribution-multiplicity",
254 llvm::cl::desc("Set the multiplicity used for distributing vector"),
255 llvm::cl::init(32)};
runOnFunction__anonafb160640111::TestVectorToLoopPatterns256 void runOnFunction() override {
257 MLIRContext *ctx = &getContext();
258 OwningRewritePatternList patterns;
259 FuncOp func = getFunction();
260 func.walk([&](AddFOp op) {
261 // Check that the operation type can be broken down into a loop.
262 VectorType type = op.getType().dyn_cast<VectorType>();
263 if (!type || type.getRank() != 1 ||
264 type.getNumElements() % multiplicity != 0)
265 return mlir::WalkResult::advance();
266 auto filterAlloc = [](Operation *op) {
267 if (isa<ConstantOp, AllocOp, CallOp>(op))
268 return false;
269 return true;
270 };
271 auto dependentOps = getSlice(op, filterAlloc);
272 // Create a loop and move instructions from the Op slice into the loop.
273 OpBuilder builder(op);
274 auto zero = builder.create<ConstantOp>(
275 op.getLoc(), builder.getIndexType(),
276 builder.getIntegerAttr(builder.getIndexType(), 0));
277 auto one = builder.create<ConstantOp>(
278 op.getLoc(), builder.getIndexType(),
279 builder.getIntegerAttr(builder.getIndexType(), 1));
280 auto numIter = builder.create<ConstantOp>(
281 op.getLoc(), builder.getIndexType(),
282 builder.getIntegerAttr(builder.getIndexType(), multiplicity));
283 auto forOp = builder.create<scf::ForOp>(op.getLoc(), zero, numIter, one);
284 for (Operation *it : dependentOps) {
285 it->moveBefore(forOp.getBody()->getTerminator());
286 }
287 auto map = AffineMap::getMultiDimIdentityMap(1, ctx);
288 // break up the original op and let the patterns propagate.
289 Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp(
290 builder, op.getOperation(), {forOp.getInductionVar()}, {multiplicity},
291 map);
292 if (ops.hasValue()) {
293 SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert});
294 op.getResult().replaceAllUsesExcept(ops->insert.getResult(), extractOp);
295 }
296 return mlir::WalkResult::interrupt();
297 });
298 patterns.insert<PointwiseExtractPattern>(ctx);
299 populateVectorToVectorTransformationPatterns(patterns, ctx);
300 applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
301 }
302 };
303
304 struct TestVectorTransferUnrollingPatterns
305 : public PassWrapper<TestVectorTransferUnrollingPatterns, FunctionPass> {
getDependentDialects__anonafb160640111::TestVectorTransferUnrollingPatterns306 void getDependentDialects(DialectRegistry ®istry) const override {
307 registry.insert<AffineDialect>();
308 }
runOnFunction__anonafb160640111::TestVectorTransferUnrollingPatterns309 void runOnFunction() override {
310 MLIRContext *ctx = &getContext();
311 OwningRewritePatternList patterns;
312 patterns.insert<UnrollVectorPattern>(
313 ctx,
314 UnrollVectorOptions()
315 .setNativeShape(ArrayRef<int64_t>{2, 2})
316 .setFilterConstraint([](Operation *op) {
317 return success(
318 isa<vector::TransferReadOp, vector::TransferWriteOp>(op));
319 }));
320 populateVectorToVectorCanonicalizationPatterns(patterns, ctx);
321 populateVectorToVectorTransformationPatterns(patterns, ctx);
322 applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
323 }
324 };
325
326 struct TestVectorTransferFullPartialSplitPatterns
327 : public PassWrapper<TestVectorTransferFullPartialSplitPatterns,
328 FunctionPass> {
329 TestVectorTransferFullPartialSplitPatterns() = default;
TestVectorTransferFullPartialSplitPatterns__anonafb160640111::TestVectorTransferFullPartialSplitPatterns330 TestVectorTransferFullPartialSplitPatterns(
331 const TestVectorTransferFullPartialSplitPatterns &pass) {}
332
getDependentDialects__anonafb160640111::TestVectorTransferFullPartialSplitPatterns333 void getDependentDialects(DialectRegistry ®istry) const override {
334 registry.insert<AffineDialect, linalg::LinalgDialect, scf::SCFDialect>();
335 }
336
337 Option<bool> useLinalgOps{
338 *this, "use-linalg-copy",
339 llvm::cl::desc("Split using a unmasked vector.transfer + linalg.fill + "
340 "linalg.copy operations."),
341 llvm::cl::init(false)};
runOnFunction__anonafb160640111::TestVectorTransferFullPartialSplitPatterns342 void runOnFunction() override {
343 MLIRContext *ctx = &getContext();
344 OwningRewritePatternList patterns;
345 VectorTransformsOptions options;
346 if (useLinalgOps)
347 options.setVectorTransferSplit(VectorTransferSplit::LinalgCopy);
348 else
349 options.setVectorTransferSplit(VectorTransferSplit::VectorTransfer);
350 patterns.insert<VectorTransferFullPartialRewriter>(ctx, options);
351 applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
352 }
353 };
354
355 struct TestVectorTransferOpt
356 : public PassWrapper<TestVectorTransferOpt, FunctionPass> {
runOnFunction__anonafb160640111::TestVectorTransferOpt357 void runOnFunction() override { transferOpflowOpt(getFunction()); }
358 };
359
360 } // end anonymous namespace
361
362 namespace mlir {
363 namespace test {
registerTestVectorConversions()364 void registerTestVectorConversions() {
365 PassRegistration<TestVectorToVectorConversion> vectorToVectorPass(
366 "test-vector-to-vector-conversion",
367 "Test conversion patterns between ops in the vector dialect");
368
369 PassRegistration<TestVectorSlicesConversion> slicesPass(
370 "test-vector-slices-conversion",
371 "Test conversion patterns that lower slices ops in the vector dialect");
372
373 PassRegistration<TestVectorContractionConversion> contractionPass(
374 "test-vector-contraction-conversion",
375 "Test conversion patterns that lower contract ops in the vector dialect");
376
377 PassRegistration<TestVectorUnrollingPatterns> contractionUnrollingPass(
378 "test-vector-unrolling-patterns",
379 "Test conversion patterns to unroll contract ops in the vector dialect");
380
381 PassRegistration<TestVectorTransferUnrollingPatterns> transferOpUnrollingPass(
382 "test-vector-transfer-unrolling-patterns",
383 "Test conversion patterns to unroll transfer ops in the vector dialect");
384
385 PassRegistration<TestVectorTransferFullPartialSplitPatterns>
386 vectorTransformFullPartialPass("test-vector-transfer-full-partial-split",
387 "Test conversion patterns to split "
388 "transfer ops via scf.if + linalg ops");
389 PassRegistration<TestVectorDistributePatterns> distributePass(
390 "test-vector-distribute-patterns",
391 "Test conversion patterns to distribute vector ops in the vector "
392 "dialect");
393 PassRegistration<TestVectorToLoopPatterns> vectorToForLoop(
394 "test-vector-to-forloop",
395 "Test conversion patterns to break up a vector op into a for loop");
396 PassRegistration<TestVectorTransferOpt> transferOpOpt(
397 "test-vector-transferop-opt",
398 "Test optimization transformations for transfer ops");
399 }
400 } // namespace test
401 } // namespace mlir
402