1 //===- TestVectorToVectorConversion.cpp - Test VectorTransfers lowering ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <type_traits>
10 
11 #include "mlir/Analysis/SliceAnalysis.h"
12 #include "mlir/Dialect/Affine/IR/AffineOps.h"
13 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
14 #include "mlir/Dialect/MemRef/IR/MemRef.h"
15 #include "mlir/Dialect/SCF/SCF.h"
16 #include "mlir/Dialect/StandardOps/IR/Ops.h"
17 #include "mlir/Dialect/Vector/VectorOps.h"
18 #include "mlir/Dialect/Vector/VectorTransforms.h"
19 #include "mlir/Pass/Pass.h"
20 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
21 
22 using namespace mlir;
23 using namespace mlir::vector;
24 namespace {
25 
26 struct TestVectorToVectorConversion
27     : public PassWrapper<TestVectorToVectorConversion, FunctionPass> {
28   TestVectorToVectorConversion() = default;
TestVectorToVectorConversion__anon6622089b0111::TestVectorToVectorConversion29   TestVectorToVectorConversion(const TestVectorToVectorConversion &pass) {}
getArgument__anon6622089b0111::TestVectorToVectorConversion30   StringRef getArgument() const final {
31     return "test-vector-to-vector-conversion";
32   }
getDescription__anon6622089b0111::TestVectorToVectorConversion33   StringRef getDescription() const final {
34     return "Test conversion patterns between ops in the vector dialect";
35   }
36 
getDependentDialects__anon6622089b0111::TestVectorToVectorConversion37   void getDependentDialects(DialectRegistry &registry) const override {
38     registry.insert<AffineDialect>();
39   }
40 
41   Option<bool> unroll{*this, "unroll", llvm::cl::desc("Include unrolling"),
42                       llvm::cl::init(false)};
43 
runOnFunction__anon6622089b0111::TestVectorToVectorConversion44   void runOnFunction() override {
45     auto *ctx = &getContext();
46     RewritePatternSet patterns(ctx);
47     if (unroll) {
48       populateVectorUnrollPatterns(
49           patterns,
50           UnrollVectorOptions().setNativeShapeFn(getShape).setFilterConstraint(
51               filter));
52     }
53     populateVectorToVectorCanonicalizationPatterns(patterns);
54     populateBubbleVectorBitCastOpPatterns(patterns);
55     populateCastAwayVectorLeadingOneDimPatterns(patterns);
56     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
57   }
58 
59 private:
60   // Return the target shape based on op type.
getShape__anon6622089b0111::TestVectorToVectorConversion61   static Optional<SmallVector<int64_t, 4>> getShape(Operation *op) {
62     if (isa<AddFOp, SelectOp, CmpFOp>(op))
63       return SmallVector<int64_t, 4>(2, 2);
64     if (isa<vector::ContractionOp>(op))
65       return SmallVector<int64_t, 4>(3, 2);
66     // For transfer ops, just propagate the shape coming from
67     // InsertStridedSlices/ExtractStridedSlices.
68     if (auto readOp = dyn_cast<vector::TransferReadOp>(op)) {
69       VectorType dstVec;
70       for (Operation *users : readOp->getUsers()) {
71         auto extract = dyn_cast<ExtractStridedSliceOp>(users);
72         if (!extract)
73           return llvm::None;
74         auto vecType = extract.getResult().getType().cast<VectorType>();
75         if (dstVec && dstVec != vecType)
76           return llvm::None;
77         dstVec = vecType;
78       }
79       return SmallVector<int64_t, 4>(dstVec.getShape().begin(),
80                                      dstVec.getShape().end());
81     }
82     if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
83       auto insert = writeOp.vector().getDefiningOp<InsertStridedSliceOp>();
84       if (!insert)
85         return llvm::None;
86       ArrayRef<int64_t> shape = insert.getSourceVectorType().getShape();
87       return SmallVector<int64_t, 4>(shape.begin(), shape.end());
88     }
89     return llvm::None;
90   }
91 
filter__anon6622089b0111::TestVectorToVectorConversion92   static LogicalResult filter(Operation *op) {
93     return success(isa<AddFOp, SelectOp, CmpFOp, ContractionOp, TransferReadOp,
94                        TransferWriteOp>(op));
95   }
96 };
97 
98 struct TestVectorContractionConversion
99     : public PassWrapper<TestVectorContractionConversion, FunctionPass> {
getArgument__anon6622089b0111::TestVectorContractionConversion100   StringRef getArgument() const final {
101     return "test-vector-contraction-conversion";
102   }
getDescription__anon6622089b0111::TestVectorContractionConversion103   StringRef getDescription() const final {
104     return "Test conversion patterns that lower contract ops in the vector "
105            "dialect";
106   }
107   TestVectorContractionConversion() = default;
TestVectorContractionConversion__anon6622089b0111::TestVectorContractionConversion108   TestVectorContractionConversion(const TestVectorContractionConversion &pass) {
109   }
110 
111   Option<bool> lowerToFlatMatrix{
112       *this, "vector-lower-matrix-intrinsics",
113       llvm::cl::desc("Lower vector.contract to llvm.intr.matrix.multiply"),
114       llvm::cl::init(false)};
115   Option<bool> lowerToFlatTranspose{
116       *this, "vector-flat-transpose",
117       llvm::cl::desc("Lower 2-D vector.transpose to vector.flat_transpose"),
118       llvm::cl::init(false)};
119   Option<bool> lowerToOuterProduct{
120       *this, "vector-outerproduct",
121       llvm::cl::desc("Lower vector.contract to vector.outerproduct"),
122       llvm::cl::init(false)};
123   Option<bool> lowerToFilterOuterProduct{
124       *this, "vector-filter-outerproduct",
125       llvm::cl::desc("Lower vector.contract to vector.outerproduct but not for "
126                      "vectors of size 4."),
127       llvm::cl::init(false)};
128 
runOnFunction__anon6622089b0111::TestVectorContractionConversion129   void runOnFunction() override {
130     RewritePatternSet patterns(&getContext());
131 
132     // Test on one pattern in isolation.
133     if (lowerToOuterProduct) {
134       VectorContractLowering lowering = VectorContractLowering::OuterProduct;
135       VectorTransformsOptions options{lowering};
136       patterns.add<ContractionOpToOuterProductOpLowering>(options,
137                                                           &getContext());
138       (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
139       return;
140     }
141 
142     // Test on one pattern in isolation.
143     if (lowerToFilterOuterProduct) {
144       VectorContractLowering lowering = VectorContractLowering::OuterProduct;
145       VectorTransformsOptions options{lowering};
146       patterns.add<ContractionOpToOuterProductOpLowering>(
147           options, &getContext(), [](vector::ContractionOp op) {
148             // Only lowers vector.contract where the lhs as a type vector<MxNx?>
149             // where M is not 4.
150             if (op.getRhsType().getShape()[0] == 4)
151               return failure();
152             return success();
153           });
154       (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
155       return;
156     }
157 
158     // Test on all contract lowering patterns.
159     VectorContractLowering contractLowering = VectorContractLowering::Dot;
160     if (lowerToFlatMatrix)
161       contractLowering = VectorContractLowering::Matmul;
162     VectorTransposeLowering transposeLowering =
163         VectorTransposeLowering::EltWise;
164     if (lowerToFlatTranspose)
165       transposeLowering = VectorTransposeLowering::Flat;
166     VectorTransformsOptions options{contractLowering, transposeLowering};
167     populateVectorContractLoweringPatterns(patterns, options);
168     populateVectorTransposeLoweringPatterns(patterns, options);
169     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
170   }
171 };
172 
173 struct TestVectorUnrollingPatterns
174     : public PassWrapper<TestVectorUnrollingPatterns, FunctionPass> {
getArgument__anon6622089b0111::TestVectorUnrollingPatterns175   StringRef getArgument() const final {
176     return "test-vector-unrolling-patterns";
177   }
getDescription__anon6622089b0111::TestVectorUnrollingPatterns178   StringRef getDescription() const final {
179     return "Test conversion patterns to unroll contract ops in the vector "
180            "dialect";
181   }
182   TestVectorUnrollingPatterns() = default;
TestVectorUnrollingPatterns__anon6622089b0111::TestVectorUnrollingPatterns183   TestVectorUnrollingPatterns(const TestVectorUnrollingPatterns &pass) {}
runOnFunction__anon6622089b0111::TestVectorUnrollingPatterns184   void runOnFunction() override {
185     MLIRContext *ctx = &getContext();
186     RewritePatternSet patterns(ctx);
187     populateVectorUnrollPatterns(
188         patterns, UnrollVectorOptions()
189                       .setNativeShape(ArrayRef<int64_t>{2, 2})
190                       .setFilterConstraint([](Operation *op) {
191                         return success(isa<AddFOp, vector::FMAOp>(op));
192                       }));
193 
194     if (unrollBasedOnType) {
195       UnrollVectorOptions::NativeShapeFnType nativeShapeFn =
196           [](Operation *op) -> Optional<SmallVector<int64_t, 4>> {
197         vector::ContractionOp contractOp = cast<vector::ContractionOp>(op);
198         SmallVector<int64_t, 4> nativeShape = {4, 4, 2};
199         if (auto floatType = contractOp.getLhsType()
200                                  .getElementType()
201                                  .dyn_cast<FloatType>()) {
202           if (floatType.getWidth() == 16) {
203             nativeShape[2] = 4;
204           }
205         }
206         return nativeShape;
207       };
208       populateVectorUnrollPatterns(patterns,
209                                    UnrollVectorOptions()
210                                        .setNativeShapeFn(nativeShapeFn)
211                                        .setFilterConstraint([](Operation *op) {
212                                          return success(isa<ContractionOp>(op));
213                                        }));
214     } else {
215       populateVectorUnrollPatterns(
216           patterns, UnrollVectorOptions()
217                         .setNativeShape(ArrayRef<int64_t>{2, 2, 2})
218                         .setFilterConstraint([](Operation *op) {
219                           return success(isa<ContractionOp>(op));
220                         }));
221     }
222     populateVectorToVectorCanonicalizationPatterns(patterns);
223     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
224   }
225 
226   Option<bool> unrollBasedOnType{
227       *this, "unroll-based-on-type",
228       llvm::cl::desc("Set the unroll factor based on type of the operation"),
229       llvm::cl::init(false)};
230 };
231 
232 struct TestVectorDistributePatterns
233     : public PassWrapper<TestVectorDistributePatterns, FunctionPass> {
getArgument__anon6622089b0111::TestVectorDistributePatterns234   StringRef getArgument() const final {
235     return "test-vector-distribute-patterns";
236   }
getDescription__anon6622089b0111::TestVectorDistributePatterns237   StringRef getDescription() const final {
238     return "Test conversion patterns to distribute vector ops in the vector "
239            "dialect";
240   }
241   TestVectorDistributePatterns() = default;
TestVectorDistributePatterns__anon6622089b0111::TestVectorDistributePatterns242   TestVectorDistributePatterns(const TestVectorDistributePatterns &pass) {}
getDependentDialects__anon6622089b0111::TestVectorDistributePatterns243   void getDependentDialects(DialectRegistry &registry) const override {
244     registry.insert<VectorDialect>();
245     registry.insert<AffineDialect>();
246   }
247   ListOption<int32_t> multiplicity{
248       *this, "distribution-multiplicity", llvm::cl::MiscFlags::CommaSeparated,
249       llvm::cl::desc("Set the multiplicity used for distributing vector")};
250 
runOnFunction__anon6622089b0111::TestVectorDistributePatterns251   void runOnFunction() override {
252     MLIRContext *ctx = &getContext();
253     RewritePatternSet patterns(ctx);
254     FuncOp func = getFunction();
255     func.walk([&](AddFOp op) {
256       OpBuilder builder(op);
257       if (auto vecType = op.getType().dyn_cast<VectorType>()) {
258         SmallVector<int64_t, 2> mul;
259         SmallVector<AffineExpr, 2> perm;
260         SmallVector<Value, 2> ids;
261         unsigned count = 0;
262         // Remove the multiplicity of 1 and calculate the affine map based on
263         // the multiplicity.
264         SmallVector<int32_t, 4> m(multiplicity.begin(), multiplicity.end());
265         for (unsigned i = 0, e = vecType.getRank(); i < e; i++) {
266           if (i < m.size() && m[i] != 1 && vecType.getDimSize(i) % m[i] == 0) {
267             mul.push_back(m[i]);
268             ids.push_back(func.getArgument(count++));
269             perm.push_back(getAffineDimExpr(i, ctx));
270           }
271         }
272         auto map = AffineMap::get(op.getType().cast<VectorType>().getRank(), 0,
273                                   perm, ctx);
274         Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp(
275             builder, op.getOperation(), ids, mul, map);
276         if (ops.hasValue()) {
277           SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert});
278           op.getResult().replaceAllUsesExcept(ops->insert.getResult(),
279                                               extractOp);
280         }
281       }
282     });
283     populatePropagateVectorDistributionPatterns(patterns);
284     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
285   }
286 };
287 
288 struct TestVectorToLoopPatterns
289     : public PassWrapper<TestVectorToLoopPatterns, FunctionPass> {
getArgument__anon6622089b0111::TestVectorToLoopPatterns290   StringRef getArgument() const final { return "test-vector-to-forloop"; }
getDescription__anon6622089b0111::TestVectorToLoopPatterns291   StringRef getDescription() const final {
292     return "Test conversion patterns to break up a vector op into a for loop";
293   }
294   TestVectorToLoopPatterns() = default;
TestVectorToLoopPatterns__anon6622089b0111::TestVectorToLoopPatterns295   TestVectorToLoopPatterns(const TestVectorToLoopPatterns &pass) {}
getDependentDialects__anon6622089b0111::TestVectorToLoopPatterns296   void getDependentDialects(DialectRegistry &registry) const override {
297     registry.insert<VectorDialect>();
298     registry.insert<AffineDialect>();
299   }
300   Option<int32_t> multiplicity{
301       *this, "distribution-multiplicity",
302       llvm::cl::desc("Set the multiplicity used for distributing vector"),
303       llvm::cl::init(32)};
runOnFunction__anon6622089b0111::TestVectorToLoopPatterns304   void runOnFunction() override {
305     MLIRContext *ctx = &getContext();
306     RewritePatternSet patterns(ctx);
307     FuncOp func = getFunction();
308     func.walk([&](AddFOp op) {
309       // Check that the operation type can be broken down into a loop.
310       VectorType type = op.getType().dyn_cast<VectorType>();
311       if (!type || type.getRank() != 1 ||
312           type.getNumElements() % multiplicity != 0)
313         return mlir::WalkResult::advance();
314       auto filterAlloc = [](Operation *op) {
315         if (isa<ConstantOp, memref::AllocOp, CallOp>(op))
316           return false;
317         return true;
318       };
319       auto dependentOps = getSlice(op, filterAlloc);
320       // Create a loop and move instructions from the Op slice into the loop.
321       OpBuilder builder(op);
322       auto zero = builder.create<ConstantOp>(
323           op.getLoc(), builder.getIndexType(),
324           builder.getIntegerAttr(builder.getIndexType(), 0));
325       auto one = builder.create<ConstantOp>(
326           op.getLoc(), builder.getIndexType(),
327           builder.getIntegerAttr(builder.getIndexType(), 1));
328       auto numIter = builder.create<ConstantOp>(
329           op.getLoc(), builder.getIndexType(),
330           builder.getIntegerAttr(builder.getIndexType(), multiplicity));
331       auto forOp = builder.create<scf::ForOp>(op.getLoc(), zero, numIter, one);
332       for (Operation *it : dependentOps) {
333         it->moveBefore(forOp.getBody()->getTerminator());
334       }
335       auto map = AffineMap::getMultiDimIdentityMap(1, ctx);
336       // break up the original op and let the patterns propagate.
337       Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp(
338           builder, op.getOperation(), {forOp.getInductionVar()}, {multiplicity},
339           map);
340       if (ops.hasValue()) {
341         SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert});
342         op.getResult().replaceAllUsesExcept(ops->insert.getResult(), extractOp);
343       }
344       return mlir::WalkResult::interrupt();
345     });
346     populatePropagateVectorDistributionPatterns(patterns);
347     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
348   }
349 };
350 
351 struct TestVectorTransferUnrollingPatterns
352     : public PassWrapper<TestVectorTransferUnrollingPatterns, FunctionPass> {
getDependentDialects__anon6622089b0111::TestVectorTransferUnrollingPatterns353   void getDependentDialects(DialectRegistry &registry) const override {
354     registry.insert<AffineDialect>();
355   }
getArgument__anon6622089b0111::TestVectorTransferUnrollingPatterns356   StringRef getArgument() const final {
357     return "test-vector-transfer-unrolling-patterns";
358   }
getDescription__anon6622089b0111::TestVectorTransferUnrollingPatterns359   StringRef getDescription() const final {
360     return "Test conversion patterns to unroll transfer ops in the vector "
361            "dialect";
362   }
runOnFunction__anon6622089b0111::TestVectorTransferUnrollingPatterns363   void runOnFunction() override {
364     MLIRContext *ctx = &getContext();
365     RewritePatternSet patterns(ctx);
366     populateVectorUnrollPatterns(
367         patterns,
368         UnrollVectorOptions()
369             .setNativeShape(ArrayRef<int64_t>{2, 2})
370             .setFilterConstraint([](Operation *op) {
371               return success(
372                   isa<vector::TransferReadOp, vector::TransferWriteOp>(op));
373             }));
374     populateVectorToVectorCanonicalizationPatterns(patterns);
375     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
376   }
377 };
378 
379 struct TestVectorTransferFullPartialSplitPatterns
380     : public PassWrapper<TestVectorTransferFullPartialSplitPatterns,
381                          FunctionPass> {
getArgument__anon6622089b0111::TestVectorTransferFullPartialSplitPatterns382   StringRef getArgument() const final {
383     return "test-vector-transfer-full-partial-split";
384   }
getDescription__anon6622089b0111::TestVectorTransferFullPartialSplitPatterns385   StringRef getDescription() const final {
386     return "Test conversion patterns to split "
387            "transfer ops via scf.if + linalg ops";
388   }
389   TestVectorTransferFullPartialSplitPatterns() = default;
TestVectorTransferFullPartialSplitPatterns__anon6622089b0111::TestVectorTransferFullPartialSplitPatterns390   TestVectorTransferFullPartialSplitPatterns(
391       const TestVectorTransferFullPartialSplitPatterns &pass) {}
392 
getDependentDialects__anon6622089b0111::TestVectorTransferFullPartialSplitPatterns393   void getDependentDialects(DialectRegistry &registry) const override {
394     registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect,
395                     scf::SCFDialect>();
396   }
397 
398   Option<bool> useLinalgOps{
399       *this, "use-linalg-copy",
400       llvm::cl::desc("Split using a unmasked vector.transfer + linalg.fill + "
401                      "linalg.copy operations."),
402       llvm::cl::init(false)};
runOnFunction__anon6622089b0111::TestVectorTransferFullPartialSplitPatterns403   void runOnFunction() override {
404     MLIRContext *ctx = &getContext();
405     RewritePatternSet patterns(ctx);
406     VectorTransformsOptions options;
407     if (useLinalgOps)
408       options.setVectorTransferSplit(VectorTransferSplit::LinalgCopy);
409     else
410       options.setVectorTransferSplit(VectorTransferSplit::VectorTransfer);
411     patterns.add<VectorTransferFullPartialRewriter>(ctx, options);
412     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
413   }
414 };
415 
416 struct TestVectorTransferOpt
417     : public PassWrapper<TestVectorTransferOpt, FunctionPass> {
getArgument__anon6622089b0111::TestVectorTransferOpt418   StringRef getArgument() const final { return "test-vector-transferop-opt"; }
getDescription__anon6622089b0111::TestVectorTransferOpt419   StringRef getDescription() const final {
420     return "Test optimization transformations for transfer ops";
421   }
runOnFunction__anon6622089b0111::TestVectorTransferOpt422   void runOnFunction() override { transferOpflowOpt(getFunction()); }
423 };
424 
425 struct TestVectorTransferLoweringPatterns
426     : public PassWrapper<TestVectorTransferLoweringPatterns, FunctionPass> {
getDependentDialects__anon6622089b0111::TestVectorTransferLoweringPatterns427   void getDependentDialects(DialectRegistry &registry) const override {
428     registry.insert<memref::MemRefDialect>();
429   }
getArgument__anon6622089b0111::TestVectorTransferLoweringPatterns430   StringRef getArgument() const final {
431     return "test-vector-transfer-lowering-patterns";
432   }
getDescription__anon6622089b0111::TestVectorTransferLoweringPatterns433   StringRef getDescription() const final {
434     return "Test conversion patterns to lower transfer ops to other vector ops";
435   }
runOnFunction__anon6622089b0111::TestVectorTransferLoweringPatterns436   void runOnFunction() override {
437     RewritePatternSet patterns(&getContext());
438     populateVectorTransferLoweringPatterns(patterns);
439     populateVectorTransferPermutationMapLoweringPatterns(patterns);
440     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
441   }
442 };
443 
444 struct TestVectorMultiReductionLoweringPatterns
445     : public PassWrapper<TestVectorMultiReductionLoweringPatterns,
446                          FunctionPass> {
getDependentDialects__anon6622089b0111::TestVectorMultiReductionLoweringPatterns447   void getDependentDialects(DialectRegistry &registry) const override {
448     registry.insert<memref::MemRefDialect>();
449   }
getArgument__anon6622089b0111::TestVectorMultiReductionLoweringPatterns450   StringRef getArgument() const final {
451     return "test-vector-multi-reduction-lowering-patterns";
452   }
getDescription__anon6622089b0111::TestVectorMultiReductionLoweringPatterns453   StringRef getDescription() const final {
454     return "Test conversion patterns to lower vector.multi_reduction to other "
455            "vector ops";
456   }
runOnFunction__anon6622089b0111::TestVectorMultiReductionLoweringPatterns457   void runOnFunction() override {
458     RewritePatternSet patterns(&getContext());
459     populateVectorMultiReductionLoweringPatterns(patterns);
460     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
461   }
462 };
463 
464 } // end anonymous namespace
465 
466 namespace mlir {
467 namespace test {
registerTestVectorConversions()468 void registerTestVectorConversions() {
469   PassRegistration<TestVectorToVectorConversion>();
470 
471   PassRegistration<TestVectorContractionConversion>();
472 
473   PassRegistration<TestVectorUnrollingPatterns>();
474 
475   PassRegistration<TestVectorTransferUnrollingPatterns>();
476 
477   PassRegistration<TestVectorTransferFullPartialSplitPatterns>();
478 
479   PassRegistration<TestVectorDistributePatterns>();
480 
481   PassRegistration<TestVectorToLoopPatterns>();
482 
483   PassRegistration<TestVectorTransferOpt>();
484 
485   PassRegistration<TestVectorTransferLoweringPatterns>();
486 
487   PassRegistration<TestVectorMultiReductionLoweringPatterns>();
488 }
489 } // namespace test
490 } // namespace mlir
491