1 //===- TestVectorToVectorConversion.cpp - Test VectorTransfers lowering ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include <type_traits>
10
11 #include "mlir/Analysis/SliceAnalysis.h"
12 #include "mlir/Dialect/Affine/IR/AffineOps.h"
13 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
14 #include "mlir/Dialect/MemRef/IR/MemRef.h"
15 #include "mlir/Dialect/SCF/SCF.h"
16 #include "mlir/Dialect/StandardOps/IR/Ops.h"
17 #include "mlir/Dialect/Vector/VectorOps.h"
18 #include "mlir/Dialect/Vector/VectorTransforms.h"
19 #include "mlir/Pass/Pass.h"
20 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
21
22 using namespace mlir;
23 using namespace mlir::vector;
24 namespace {
25
26 struct TestVectorToVectorConversion
27 : public PassWrapper<TestVectorToVectorConversion, FunctionPass> {
28 TestVectorToVectorConversion() = default;
TestVectorToVectorConversion__anon6622089b0111::TestVectorToVectorConversion29 TestVectorToVectorConversion(const TestVectorToVectorConversion &pass) {}
getArgument__anon6622089b0111::TestVectorToVectorConversion30 StringRef getArgument() const final {
31 return "test-vector-to-vector-conversion";
32 }
getDescription__anon6622089b0111::TestVectorToVectorConversion33 StringRef getDescription() const final {
34 return "Test conversion patterns between ops in the vector dialect";
35 }
36
getDependentDialects__anon6622089b0111::TestVectorToVectorConversion37 void getDependentDialects(DialectRegistry ®istry) const override {
38 registry.insert<AffineDialect>();
39 }
40
41 Option<bool> unroll{*this, "unroll", llvm::cl::desc("Include unrolling"),
42 llvm::cl::init(false)};
43
runOnFunction__anon6622089b0111::TestVectorToVectorConversion44 void runOnFunction() override {
45 auto *ctx = &getContext();
46 RewritePatternSet patterns(ctx);
47 if (unroll) {
48 populateVectorUnrollPatterns(
49 patterns,
50 UnrollVectorOptions().setNativeShapeFn(getShape).setFilterConstraint(
51 filter));
52 }
53 populateVectorToVectorCanonicalizationPatterns(patterns);
54 populateBubbleVectorBitCastOpPatterns(patterns);
55 populateCastAwayVectorLeadingOneDimPatterns(patterns);
56 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
57 }
58
59 private:
60 // Return the target shape based on op type.
getShape__anon6622089b0111::TestVectorToVectorConversion61 static Optional<SmallVector<int64_t, 4>> getShape(Operation *op) {
62 if (isa<AddFOp, SelectOp, CmpFOp>(op))
63 return SmallVector<int64_t, 4>(2, 2);
64 if (isa<vector::ContractionOp>(op))
65 return SmallVector<int64_t, 4>(3, 2);
66 // For transfer ops, just propagate the shape coming from
67 // InsertStridedSlices/ExtractStridedSlices.
68 if (auto readOp = dyn_cast<vector::TransferReadOp>(op)) {
69 VectorType dstVec;
70 for (Operation *users : readOp->getUsers()) {
71 auto extract = dyn_cast<ExtractStridedSliceOp>(users);
72 if (!extract)
73 return llvm::None;
74 auto vecType = extract.getResult().getType().cast<VectorType>();
75 if (dstVec && dstVec != vecType)
76 return llvm::None;
77 dstVec = vecType;
78 }
79 return SmallVector<int64_t, 4>(dstVec.getShape().begin(),
80 dstVec.getShape().end());
81 }
82 if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
83 auto insert = writeOp.vector().getDefiningOp<InsertStridedSliceOp>();
84 if (!insert)
85 return llvm::None;
86 ArrayRef<int64_t> shape = insert.getSourceVectorType().getShape();
87 return SmallVector<int64_t, 4>(shape.begin(), shape.end());
88 }
89 return llvm::None;
90 }
91
filter__anon6622089b0111::TestVectorToVectorConversion92 static LogicalResult filter(Operation *op) {
93 return success(isa<AddFOp, SelectOp, CmpFOp, ContractionOp, TransferReadOp,
94 TransferWriteOp>(op));
95 }
96 };
97
98 struct TestVectorContractionConversion
99 : public PassWrapper<TestVectorContractionConversion, FunctionPass> {
getArgument__anon6622089b0111::TestVectorContractionConversion100 StringRef getArgument() const final {
101 return "test-vector-contraction-conversion";
102 }
getDescription__anon6622089b0111::TestVectorContractionConversion103 StringRef getDescription() const final {
104 return "Test conversion patterns that lower contract ops in the vector "
105 "dialect";
106 }
107 TestVectorContractionConversion() = default;
TestVectorContractionConversion__anon6622089b0111::TestVectorContractionConversion108 TestVectorContractionConversion(const TestVectorContractionConversion &pass) {
109 }
110
111 Option<bool> lowerToFlatMatrix{
112 *this, "vector-lower-matrix-intrinsics",
113 llvm::cl::desc("Lower vector.contract to llvm.intr.matrix.multiply"),
114 llvm::cl::init(false)};
115 Option<bool> lowerToFlatTranspose{
116 *this, "vector-flat-transpose",
117 llvm::cl::desc("Lower 2-D vector.transpose to vector.flat_transpose"),
118 llvm::cl::init(false)};
119 Option<bool> lowerToOuterProduct{
120 *this, "vector-outerproduct",
121 llvm::cl::desc("Lower vector.contract to vector.outerproduct"),
122 llvm::cl::init(false)};
123 Option<bool> lowerToFilterOuterProduct{
124 *this, "vector-filter-outerproduct",
125 llvm::cl::desc("Lower vector.contract to vector.outerproduct but not for "
126 "vectors of size 4."),
127 llvm::cl::init(false)};
128
runOnFunction__anon6622089b0111::TestVectorContractionConversion129 void runOnFunction() override {
130 RewritePatternSet patterns(&getContext());
131
132 // Test on one pattern in isolation.
133 if (lowerToOuterProduct) {
134 VectorContractLowering lowering = VectorContractLowering::OuterProduct;
135 VectorTransformsOptions options{lowering};
136 patterns.add<ContractionOpToOuterProductOpLowering>(options,
137 &getContext());
138 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
139 return;
140 }
141
142 // Test on one pattern in isolation.
143 if (lowerToFilterOuterProduct) {
144 VectorContractLowering lowering = VectorContractLowering::OuterProduct;
145 VectorTransformsOptions options{lowering};
146 patterns.add<ContractionOpToOuterProductOpLowering>(
147 options, &getContext(), [](vector::ContractionOp op) {
148 // Only lowers vector.contract where the lhs as a type vector<MxNx?>
149 // where M is not 4.
150 if (op.getRhsType().getShape()[0] == 4)
151 return failure();
152 return success();
153 });
154 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
155 return;
156 }
157
158 // Test on all contract lowering patterns.
159 VectorContractLowering contractLowering = VectorContractLowering::Dot;
160 if (lowerToFlatMatrix)
161 contractLowering = VectorContractLowering::Matmul;
162 VectorTransposeLowering transposeLowering =
163 VectorTransposeLowering::EltWise;
164 if (lowerToFlatTranspose)
165 transposeLowering = VectorTransposeLowering::Flat;
166 VectorTransformsOptions options{contractLowering, transposeLowering};
167 populateVectorContractLoweringPatterns(patterns, options);
168 populateVectorTransposeLoweringPatterns(patterns, options);
169 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
170 }
171 };
172
173 struct TestVectorUnrollingPatterns
174 : public PassWrapper<TestVectorUnrollingPatterns, FunctionPass> {
getArgument__anon6622089b0111::TestVectorUnrollingPatterns175 StringRef getArgument() const final {
176 return "test-vector-unrolling-patterns";
177 }
getDescription__anon6622089b0111::TestVectorUnrollingPatterns178 StringRef getDescription() const final {
179 return "Test conversion patterns to unroll contract ops in the vector "
180 "dialect";
181 }
182 TestVectorUnrollingPatterns() = default;
TestVectorUnrollingPatterns__anon6622089b0111::TestVectorUnrollingPatterns183 TestVectorUnrollingPatterns(const TestVectorUnrollingPatterns &pass) {}
runOnFunction__anon6622089b0111::TestVectorUnrollingPatterns184 void runOnFunction() override {
185 MLIRContext *ctx = &getContext();
186 RewritePatternSet patterns(ctx);
187 populateVectorUnrollPatterns(
188 patterns, UnrollVectorOptions()
189 .setNativeShape(ArrayRef<int64_t>{2, 2})
190 .setFilterConstraint([](Operation *op) {
191 return success(isa<AddFOp, vector::FMAOp>(op));
192 }));
193
194 if (unrollBasedOnType) {
195 UnrollVectorOptions::NativeShapeFnType nativeShapeFn =
196 [](Operation *op) -> Optional<SmallVector<int64_t, 4>> {
197 vector::ContractionOp contractOp = cast<vector::ContractionOp>(op);
198 SmallVector<int64_t, 4> nativeShape = {4, 4, 2};
199 if (auto floatType = contractOp.getLhsType()
200 .getElementType()
201 .dyn_cast<FloatType>()) {
202 if (floatType.getWidth() == 16) {
203 nativeShape[2] = 4;
204 }
205 }
206 return nativeShape;
207 };
208 populateVectorUnrollPatterns(patterns,
209 UnrollVectorOptions()
210 .setNativeShapeFn(nativeShapeFn)
211 .setFilterConstraint([](Operation *op) {
212 return success(isa<ContractionOp>(op));
213 }));
214 } else {
215 populateVectorUnrollPatterns(
216 patterns, UnrollVectorOptions()
217 .setNativeShape(ArrayRef<int64_t>{2, 2, 2})
218 .setFilterConstraint([](Operation *op) {
219 return success(isa<ContractionOp>(op));
220 }));
221 }
222 populateVectorToVectorCanonicalizationPatterns(patterns);
223 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
224 }
225
226 Option<bool> unrollBasedOnType{
227 *this, "unroll-based-on-type",
228 llvm::cl::desc("Set the unroll factor based on type of the operation"),
229 llvm::cl::init(false)};
230 };
231
232 struct TestVectorDistributePatterns
233 : public PassWrapper<TestVectorDistributePatterns, FunctionPass> {
getArgument__anon6622089b0111::TestVectorDistributePatterns234 StringRef getArgument() const final {
235 return "test-vector-distribute-patterns";
236 }
getDescription__anon6622089b0111::TestVectorDistributePatterns237 StringRef getDescription() const final {
238 return "Test conversion patterns to distribute vector ops in the vector "
239 "dialect";
240 }
241 TestVectorDistributePatterns() = default;
TestVectorDistributePatterns__anon6622089b0111::TestVectorDistributePatterns242 TestVectorDistributePatterns(const TestVectorDistributePatterns &pass) {}
getDependentDialects__anon6622089b0111::TestVectorDistributePatterns243 void getDependentDialects(DialectRegistry ®istry) const override {
244 registry.insert<VectorDialect>();
245 registry.insert<AffineDialect>();
246 }
247 ListOption<int32_t> multiplicity{
248 *this, "distribution-multiplicity", llvm::cl::MiscFlags::CommaSeparated,
249 llvm::cl::desc("Set the multiplicity used for distributing vector")};
250
runOnFunction__anon6622089b0111::TestVectorDistributePatterns251 void runOnFunction() override {
252 MLIRContext *ctx = &getContext();
253 RewritePatternSet patterns(ctx);
254 FuncOp func = getFunction();
255 func.walk([&](AddFOp op) {
256 OpBuilder builder(op);
257 if (auto vecType = op.getType().dyn_cast<VectorType>()) {
258 SmallVector<int64_t, 2> mul;
259 SmallVector<AffineExpr, 2> perm;
260 SmallVector<Value, 2> ids;
261 unsigned count = 0;
262 // Remove the multiplicity of 1 and calculate the affine map based on
263 // the multiplicity.
264 SmallVector<int32_t, 4> m(multiplicity.begin(), multiplicity.end());
265 for (unsigned i = 0, e = vecType.getRank(); i < e; i++) {
266 if (i < m.size() && m[i] != 1 && vecType.getDimSize(i) % m[i] == 0) {
267 mul.push_back(m[i]);
268 ids.push_back(func.getArgument(count++));
269 perm.push_back(getAffineDimExpr(i, ctx));
270 }
271 }
272 auto map = AffineMap::get(op.getType().cast<VectorType>().getRank(), 0,
273 perm, ctx);
274 Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp(
275 builder, op.getOperation(), ids, mul, map);
276 if (ops.hasValue()) {
277 SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert});
278 op.getResult().replaceAllUsesExcept(ops->insert.getResult(),
279 extractOp);
280 }
281 }
282 });
283 populatePropagateVectorDistributionPatterns(patterns);
284 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
285 }
286 };
287
288 struct TestVectorToLoopPatterns
289 : public PassWrapper<TestVectorToLoopPatterns, FunctionPass> {
getArgument__anon6622089b0111::TestVectorToLoopPatterns290 StringRef getArgument() const final { return "test-vector-to-forloop"; }
getDescription__anon6622089b0111::TestVectorToLoopPatterns291 StringRef getDescription() const final {
292 return "Test conversion patterns to break up a vector op into a for loop";
293 }
294 TestVectorToLoopPatterns() = default;
TestVectorToLoopPatterns__anon6622089b0111::TestVectorToLoopPatterns295 TestVectorToLoopPatterns(const TestVectorToLoopPatterns &pass) {}
getDependentDialects__anon6622089b0111::TestVectorToLoopPatterns296 void getDependentDialects(DialectRegistry ®istry) const override {
297 registry.insert<VectorDialect>();
298 registry.insert<AffineDialect>();
299 }
300 Option<int32_t> multiplicity{
301 *this, "distribution-multiplicity",
302 llvm::cl::desc("Set the multiplicity used for distributing vector"),
303 llvm::cl::init(32)};
runOnFunction__anon6622089b0111::TestVectorToLoopPatterns304 void runOnFunction() override {
305 MLIRContext *ctx = &getContext();
306 RewritePatternSet patterns(ctx);
307 FuncOp func = getFunction();
308 func.walk([&](AddFOp op) {
309 // Check that the operation type can be broken down into a loop.
310 VectorType type = op.getType().dyn_cast<VectorType>();
311 if (!type || type.getRank() != 1 ||
312 type.getNumElements() % multiplicity != 0)
313 return mlir::WalkResult::advance();
314 auto filterAlloc = [](Operation *op) {
315 if (isa<ConstantOp, memref::AllocOp, CallOp>(op))
316 return false;
317 return true;
318 };
319 auto dependentOps = getSlice(op, filterAlloc);
320 // Create a loop and move instructions from the Op slice into the loop.
321 OpBuilder builder(op);
322 auto zero = builder.create<ConstantOp>(
323 op.getLoc(), builder.getIndexType(),
324 builder.getIntegerAttr(builder.getIndexType(), 0));
325 auto one = builder.create<ConstantOp>(
326 op.getLoc(), builder.getIndexType(),
327 builder.getIntegerAttr(builder.getIndexType(), 1));
328 auto numIter = builder.create<ConstantOp>(
329 op.getLoc(), builder.getIndexType(),
330 builder.getIntegerAttr(builder.getIndexType(), multiplicity));
331 auto forOp = builder.create<scf::ForOp>(op.getLoc(), zero, numIter, one);
332 for (Operation *it : dependentOps) {
333 it->moveBefore(forOp.getBody()->getTerminator());
334 }
335 auto map = AffineMap::getMultiDimIdentityMap(1, ctx);
336 // break up the original op and let the patterns propagate.
337 Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp(
338 builder, op.getOperation(), {forOp.getInductionVar()}, {multiplicity},
339 map);
340 if (ops.hasValue()) {
341 SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert});
342 op.getResult().replaceAllUsesExcept(ops->insert.getResult(), extractOp);
343 }
344 return mlir::WalkResult::interrupt();
345 });
346 populatePropagateVectorDistributionPatterns(patterns);
347 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
348 }
349 };
350
351 struct TestVectorTransferUnrollingPatterns
352 : public PassWrapper<TestVectorTransferUnrollingPatterns, FunctionPass> {
getDependentDialects__anon6622089b0111::TestVectorTransferUnrollingPatterns353 void getDependentDialects(DialectRegistry ®istry) const override {
354 registry.insert<AffineDialect>();
355 }
getArgument__anon6622089b0111::TestVectorTransferUnrollingPatterns356 StringRef getArgument() const final {
357 return "test-vector-transfer-unrolling-patterns";
358 }
getDescription__anon6622089b0111::TestVectorTransferUnrollingPatterns359 StringRef getDescription() const final {
360 return "Test conversion patterns to unroll transfer ops in the vector "
361 "dialect";
362 }
runOnFunction__anon6622089b0111::TestVectorTransferUnrollingPatterns363 void runOnFunction() override {
364 MLIRContext *ctx = &getContext();
365 RewritePatternSet patterns(ctx);
366 populateVectorUnrollPatterns(
367 patterns,
368 UnrollVectorOptions()
369 .setNativeShape(ArrayRef<int64_t>{2, 2})
370 .setFilterConstraint([](Operation *op) {
371 return success(
372 isa<vector::TransferReadOp, vector::TransferWriteOp>(op));
373 }));
374 populateVectorToVectorCanonicalizationPatterns(patterns);
375 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
376 }
377 };
378
379 struct TestVectorTransferFullPartialSplitPatterns
380 : public PassWrapper<TestVectorTransferFullPartialSplitPatterns,
381 FunctionPass> {
getArgument__anon6622089b0111::TestVectorTransferFullPartialSplitPatterns382 StringRef getArgument() const final {
383 return "test-vector-transfer-full-partial-split";
384 }
getDescription__anon6622089b0111::TestVectorTransferFullPartialSplitPatterns385 StringRef getDescription() const final {
386 return "Test conversion patterns to split "
387 "transfer ops via scf.if + linalg ops";
388 }
389 TestVectorTransferFullPartialSplitPatterns() = default;
TestVectorTransferFullPartialSplitPatterns__anon6622089b0111::TestVectorTransferFullPartialSplitPatterns390 TestVectorTransferFullPartialSplitPatterns(
391 const TestVectorTransferFullPartialSplitPatterns &pass) {}
392
getDependentDialects__anon6622089b0111::TestVectorTransferFullPartialSplitPatterns393 void getDependentDialects(DialectRegistry ®istry) const override {
394 registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect,
395 scf::SCFDialect>();
396 }
397
398 Option<bool> useLinalgOps{
399 *this, "use-linalg-copy",
400 llvm::cl::desc("Split using a unmasked vector.transfer + linalg.fill + "
401 "linalg.copy operations."),
402 llvm::cl::init(false)};
runOnFunction__anon6622089b0111::TestVectorTransferFullPartialSplitPatterns403 void runOnFunction() override {
404 MLIRContext *ctx = &getContext();
405 RewritePatternSet patterns(ctx);
406 VectorTransformsOptions options;
407 if (useLinalgOps)
408 options.setVectorTransferSplit(VectorTransferSplit::LinalgCopy);
409 else
410 options.setVectorTransferSplit(VectorTransferSplit::VectorTransfer);
411 patterns.add<VectorTransferFullPartialRewriter>(ctx, options);
412 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
413 }
414 };
415
416 struct TestVectorTransferOpt
417 : public PassWrapper<TestVectorTransferOpt, FunctionPass> {
getArgument__anon6622089b0111::TestVectorTransferOpt418 StringRef getArgument() const final { return "test-vector-transferop-opt"; }
getDescription__anon6622089b0111::TestVectorTransferOpt419 StringRef getDescription() const final {
420 return "Test optimization transformations for transfer ops";
421 }
runOnFunction__anon6622089b0111::TestVectorTransferOpt422 void runOnFunction() override { transferOpflowOpt(getFunction()); }
423 };
424
425 struct TestVectorTransferLoweringPatterns
426 : public PassWrapper<TestVectorTransferLoweringPatterns, FunctionPass> {
getDependentDialects__anon6622089b0111::TestVectorTransferLoweringPatterns427 void getDependentDialects(DialectRegistry ®istry) const override {
428 registry.insert<memref::MemRefDialect>();
429 }
getArgument__anon6622089b0111::TestVectorTransferLoweringPatterns430 StringRef getArgument() const final {
431 return "test-vector-transfer-lowering-patterns";
432 }
getDescription__anon6622089b0111::TestVectorTransferLoweringPatterns433 StringRef getDescription() const final {
434 return "Test conversion patterns to lower transfer ops to other vector ops";
435 }
runOnFunction__anon6622089b0111::TestVectorTransferLoweringPatterns436 void runOnFunction() override {
437 RewritePatternSet patterns(&getContext());
438 populateVectorTransferLoweringPatterns(patterns);
439 populateVectorTransferPermutationMapLoweringPatterns(patterns);
440 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
441 }
442 };
443
444 struct TestVectorMultiReductionLoweringPatterns
445 : public PassWrapper<TestVectorMultiReductionLoweringPatterns,
446 FunctionPass> {
getDependentDialects__anon6622089b0111::TestVectorMultiReductionLoweringPatterns447 void getDependentDialects(DialectRegistry ®istry) const override {
448 registry.insert<memref::MemRefDialect>();
449 }
getArgument__anon6622089b0111::TestVectorMultiReductionLoweringPatterns450 StringRef getArgument() const final {
451 return "test-vector-multi-reduction-lowering-patterns";
452 }
getDescription__anon6622089b0111::TestVectorMultiReductionLoweringPatterns453 StringRef getDescription() const final {
454 return "Test conversion patterns to lower vector.multi_reduction to other "
455 "vector ops";
456 }
runOnFunction__anon6622089b0111::TestVectorMultiReductionLoweringPatterns457 void runOnFunction() override {
458 RewritePatternSet patterns(&getContext());
459 populateVectorMultiReductionLoweringPatterns(patterns);
460 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
461 }
462 };
463
464 } // end anonymous namespace
465
466 namespace mlir {
467 namespace test {
registerTestVectorConversions()468 void registerTestVectorConversions() {
469 PassRegistration<TestVectorToVectorConversion>();
470
471 PassRegistration<TestVectorContractionConversion>();
472
473 PassRegistration<TestVectorUnrollingPatterns>();
474
475 PassRegistration<TestVectorTransferUnrollingPatterns>();
476
477 PassRegistration<TestVectorTransferFullPartialSplitPatterns>();
478
479 PassRegistration<TestVectorDistributePatterns>();
480
481 PassRegistration<TestVectorToLoopPatterns>();
482
483 PassRegistration<TestVectorTransferOpt>();
484
485 PassRegistration<TestVectorTransferLoweringPatterns>();
486
487 PassRegistration<TestVectorMultiReductionLoweringPatterns>();
488 }
489 } // namespace test
490 } // namespace mlir
491