1 //===- TestVectorToVectorConversion.cpp - Test VectorTransfers lowering ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <type_traits>
10 
11 #include "mlir/Analysis/SliceAnalysis.h"
12 #include "mlir/Dialect/Affine/IR/AffineOps.h"
13 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
14 #include "mlir/Dialect/SCF/SCF.h"
15 #include "mlir/Dialect/StandardOps/IR/Ops.h"
16 #include "mlir/Dialect/Vector/VectorOps.h"
17 #include "mlir/Dialect/Vector/VectorTransforms.h"
18 #include "mlir/Pass/Pass.h"
19 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
20 
21 using namespace mlir;
22 using namespace mlir::vector;
23 namespace {
24 
25 struct TestVectorToVectorConversion
26     : public PassWrapper<TestVectorToVectorConversion, FunctionPass> {
runOnFunction__anonfe31e6f10111::TestVectorToVectorConversion27   void runOnFunction() override {
28     OwningRewritePatternList patterns;
29     auto *ctx = &getContext();
30     patterns.insert<UnrollVectorPattern<AddFOp>>(
31         ctx, UnrollVectorOptions().setNativeShape(ArrayRef<int64_t>{2, 2}));
32     patterns.insert<UnrollVectorPattern<vector::ContractionOp>>(
33         ctx, UnrollVectorOptions().setNativeShape(ArrayRef<int64_t>{2, 2, 2}));
34     populateVectorToVectorCanonicalizationPatterns(patterns, ctx);
35     populateVectorToVectorTransformationPatterns(patterns, ctx);
36     applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
37   }
38 };
39 
40 struct TestVectorSlicesConversion
41     : public PassWrapper<TestVectorSlicesConversion, FunctionPass> {
runOnFunction__anonfe31e6f10111::TestVectorSlicesConversion42   void runOnFunction() override {
43     OwningRewritePatternList patterns;
44     populateVectorSlicesLoweringPatterns(patterns, &getContext());
45     applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
46   }
47 };
48 
49 struct TestVectorContractionConversion
50     : public PassWrapper<TestVectorContractionConversion, FunctionPass> {
51   TestVectorContractionConversion() = default;
TestVectorContractionConversion__anonfe31e6f10111::TestVectorContractionConversion52   TestVectorContractionConversion(const TestVectorContractionConversion &pass) {
53   }
54 
55   Option<bool> lowerToFlatMatrix{
56       *this, "vector-lower-matrix-intrinsics",
57       llvm::cl::desc("Lower vector.contract to llvm.intr.matrix.multiply"),
58       llvm::cl::init(false)};
59   Option<bool> lowerToFlatTranspose{
60       *this, "vector-flat-transpose",
61       llvm::cl::desc("Lower 2-D vector.transpose to vector.flat_transpose"),
62       llvm::cl::init(false)};
63   Option<bool> lowerToOuterProduct{
64       *this, "vector-outerproduct",
65       llvm::cl::desc("Lower vector.contract to vector.outerproduct"),
66       llvm::cl::init(false)};
67   Option<bool> lowerToFilterOuterProduct{
68       *this, "vector-filter-outerproduct",
69       llvm::cl::desc("Lower vector.contract to vector.outerproduct but not for "
70                      "vectors of size 4."),
71       llvm::cl::init(false)};
72 
runOnFunction__anonfe31e6f10111::TestVectorContractionConversion73   void runOnFunction() override {
74     OwningRewritePatternList patterns;
75 
76     // Test on one pattern in isolation.
77     if (lowerToOuterProduct) {
78       VectorContractLowering lowering = VectorContractLowering::OuterProduct;
79       VectorTransformsOptions options{lowering};
80       patterns.insert<ContractionOpToOuterProductOpLowering>(options,
81                                                              &getContext());
82       applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
83       return;
84     }
85 
86     // Test on one pattern in isolation.
87     if (lowerToFilterOuterProduct) {
88       VectorContractLowering lowering = VectorContractLowering::OuterProduct;
89       VectorTransformsOptions options{lowering};
90       patterns.insert<ContractionOpToOuterProductOpLowering>(
91           options, &getContext(), [](vector::ContractionOp op) {
92             // Only lowers vector.contract where the lhs as a type vector<MxNx?>
93             // where M is not 4.
94             if (op.getRhsType().getShape()[0] == 4)
95               return failure();
96             return success();
97           });
98       applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
99       return;
100     }
101 
102     // Test on all contract lowering patterns.
103     VectorContractLowering contractLowering = VectorContractLowering::Dot;
104     if (lowerToFlatMatrix)
105       contractLowering = VectorContractLowering::Matmul;
106     VectorTransposeLowering transposeLowering =
107         VectorTransposeLowering::EltWise;
108     if (lowerToFlatTranspose)
109       transposeLowering = VectorTransposeLowering::Flat;
110     VectorTransformsOptions options{contractLowering, transposeLowering};
111     populateVectorContractLoweringPatterns(patterns, &getContext(), options);
112     applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
113   }
114 };
115 
116 struct TestVectorUnrollingPatterns
117     : public PassWrapper<TestVectorUnrollingPatterns, FunctionPass> {
118   TestVectorUnrollingPatterns() = default;
TestVectorUnrollingPatterns__anonfe31e6f10111::TestVectorUnrollingPatterns119   TestVectorUnrollingPatterns(const TestVectorUnrollingPatterns &pass) {}
runOnFunction__anonfe31e6f10111::TestVectorUnrollingPatterns120   void runOnFunction() override {
121     MLIRContext *ctx = &getContext();
122     OwningRewritePatternList patterns;
123     patterns.insert<UnrollVectorPattern<AddFOp>>(
124         ctx, UnrollVectorOptions().setNativeShape(ArrayRef<int64_t>{2, 2}));
125 
126     if (unrollBasedOnType) {
127       UnrollVectorOptions::NativeShapeFnType nativeShapeFn =
128           [](Operation *op) -> Optional<SmallVector<int64_t, 4>> {
129         vector::ContractionOp contractOp = cast<vector::ContractionOp>(op);
130         SmallVector<int64_t, 4> nativeShape = {4, 4, 2};
131         if (auto floatType = contractOp.getLhsType()
132                                  .getElementType()
133                                  .dyn_cast<FloatType>()) {
134           if (floatType.getWidth() == 16) {
135             nativeShape[2] = 4;
136           }
137         }
138         return nativeShape;
139       };
140       patterns.insert<UnrollVectorPattern<vector::ContractionOp>>(
141           ctx, UnrollVectorOptions().setNativeShapeFn(nativeShapeFn));
142     } else {
143       patterns.insert<UnrollVectorPattern<vector::ContractionOp>>(
144           ctx,
145           UnrollVectorOptions().setNativeShape(ArrayRef<int64_t>{2, 2, 2}));
146     }
147     populateVectorToVectorCanonicalizationPatterns(patterns, ctx);
148     populateVectorToVectorTransformationPatterns(patterns, ctx);
149     applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
150   }
151 
152   Option<bool> unrollBasedOnType{
153       *this, "unroll-based-on-type",
154       llvm::cl::desc("Set the unroll factor based on type of the operation"),
155       llvm::cl::init(false)};
156 };
157 
158 struct TestVectorDistributePatterns
159     : public PassWrapper<TestVectorDistributePatterns, FunctionPass> {
160   TestVectorDistributePatterns() = default;
TestVectorDistributePatterns__anonfe31e6f10111::TestVectorDistributePatterns161   TestVectorDistributePatterns(const TestVectorDistributePatterns &pass) {}
getDependentDialects__anonfe31e6f10111::TestVectorDistributePatterns162   void getDependentDialects(DialectRegistry &registry) const override {
163     registry.insert<VectorDialect>();
164     registry.insert<AffineDialect>();
165   }
166   Option<int32_t> multiplicity{
167       *this, "distribution-multiplicity",
168       llvm::cl::desc("Set the multiplicity used for distributing vector"),
169       llvm::cl::init(32)};
runOnFunction__anonfe31e6f10111::TestVectorDistributePatterns170   void runOnFunction() override {
171     MLIRContext *ctx = &getContext();
172     OwningRewritePatternList patterns;
173     FuncOp func = getFunction();
174     func.walk([&](AddFOp op) {
175       OpBuilder builder(op);
176       Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp(
177           builder, op.getOperation(), func.getArgument(0), multiplicity);
178       if (ops.hasValue()) {
179         SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert});
180         op.getResult().replaceAllUsesExcept(ops->insert.getResult(), extractOp);
181       }
182     });
183     patterns.insert<PointwiseExtractPattern>(ctx);
184     populateVectorToVectorTransformationPatterns(patterns, ctx);
185     applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
186   }
187 };
188 
189 struct TestVectorToLoopPatterns
190     : public PassWrapper<TestVectorToLoopPatterns, FunctionPass> {
191   TestVectorToLoopPatterns() = default;
TestVectorToLoopPatterns__anonfe31e6f10111::TestVectorToLoopPatterns192   TestVectorToLoopPatterns(const TestVectorToLoopPatterns &pass) {}
getDependentDialects__anonfe31e6f10111::TestVectorToLoopPatterns193   void getDependentDialects(DialectRegistry &registry) const override {
194     registry.insert<VectorDialect>();
195     registry.insert<AffineDialect>();
196   }
197   Option<int32_t> multiplicity{
198       *this, "distribution-multiplicity",
199       llvm::cl::desc("Set the multiplicity used for distributing vector"),
200       llvm::cl::init(32)};
runOnFunction__anonfe31e6f10111::TestVectorToLoopPatterns201   void runOnFunction() override {
202     MLIRContext *ctx = &getContext();
203     OwningRewritePatternList patterns;
204     FuncOp func = getFunction();
205     func.walk([&](AddFOp op) {
206       // Check that the operation type can be broken down into a loop.
207       VectorType type = op.getType().dyn_cast<VectorType>();
208       if (!type || type.getRank() != 1 ||
209           type.getNumElements() % multiplicity != 0)
210         return mlir::WalkResult::advance();
211       auto filterAlloc = [](Operation *op) {
212         if (isa<ConstantOp, AllocOp, CallOp>(op))
213           return false;
214         return true;
215       };
216       auto dependentOps = getSlice(op, filterAlloc);
217       // Create a loop and move instructions from the Op slice into the loop.
218       OpBuilder builder(op);
219       auto zero = builder.create<ConstantOp>(
220           op.getLoc(), builder.getIndexType(),
221           builder.getIntegerAttr(builder.getIndexType(), 0));
222       auto one = builder.create<ConstantOp>(
223           op.getLoc(), builder.getIndexType(),
224           builder.getIntegerAttr(builder.getIndexType(), 1));
225       auto numIter = builder.create<ConstantOp>(
226           op.getLoc(), builder.getIndexType(),
227           builder.getIntegerAttr(builder.getIndexType(), multiplicity));
228       auto forOp = builder.create<scf::ForOp>(op.getLoc(), zero, numIter, one);
229       for (Operation *it : dependentOps) {
230         it->moveBefore(forOp.getBody()->getTerminator());
231       }
232       // break up the original op and let the patterns propagate.
233       Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp(
234           builder, op.getOperation(), forOp.getInductionVar(), multiplicity);
235       if (ops.hasValue()) {
236         SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert});
237         op.getResult().replaceAllUsesExcept(ops->insert.getResult(), extractOp);
238       }
239       return mlir::WalkResult::interrupt();
240     });
241     patterns.insert<PointwiseExtractPattern>(ctx);
242     populateVectorToVectorTransformationPatterns(patterns, ctx);
243     applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
244   }
245 };
246 
247 struct TestVectorTransferUnrollingPatterns
248     : public PassWrapper<TestVectorTransferUnrollingPatterns, FunctionPass> {
getDependentDialects__anonfe31e6f10111::TestVectorTransferUnrollingPatterns249   void getDependentDialects(DialectRegistry &registry) const override {
250     registry.insert<AffineDialect>();
251   }
runOnFunction__anonfe31e6f10111::TestVectorTransferUnrollingPatterns252   void runOnFunction() override {
253     MLIRContext *ctx = &getContext();
254     OwningRewritePatternList patterns;
255     patterns.insert<UnrollVectorPattern<vector::TransferReadOp>>(
256         ctx, UnrollVectorOptions().setNativeShape(ArrayRef<int64_t>{2, 2}));
257     patterns.insert<UnrollVectorPattern<vector::TransferWriteOp>>(
258         ctx, UnrollVectorOptions().setNativeShape(ArrayRef<int64_t>{2, 2}));
259     populateVectorToVectorCanonicalizationPatterns(patterns, ctx);
260     populateVectorToVectorTransformationPatterns(patterns, ctx);
261     applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
262   }
263 };
264 
265 struct TestVectorTransferFullPartialSplitPatterns
266     : public PassWrapper<TestVectorTransferFullPartialSplitPatterns,
267                          FunctionPass> {
268   TestVectorTransferFullPartialSplitPatterns() = default;
TestVectorTransferFullPartialSplitPatterns__anonfe31e6f10111::TestVectorTransferFullPartialSplitPatterns269   TestVectorTransferFullPartialSplitPatterns(
270       const TestVectorTransferFullPartialSplitPatterns &pass) {}
271 
getDependentDialects__anonfe31e6f10111::TestVectorTransferFullPartialSplitPatterns272   void getDependentDialects(DialectRegistry &registry) const override {
273     registry.insert<AffineDialect, linalg::LinalgDialect, scf::SCFDialect>();
274   }
275 
276   Option<bool> useLinalgOps{
277       *this, "use-linalg-copy",
278       llvm::cl::desc("Split using a unmasked vector.transfer + linalg.fill + "
279                      "linalg.copy operations."),
280       llvm::cl::init(false)};
runOnFunction__anonfe31e6f10111::TestVectorTransferFullPartialSplitPatterns281   void runOnFunction() override {
282     MLIRContext *ctx = &getContext();
283     OwningRewritePatternList patterns;
284     VectorTransformsOptions options;
285     if (useLinalgOps)
286       options.setVectorTransferSplit(VectorTransferSplit::LinalgCopy);
287     else
288       options.setVectorTransferSplit(VectorTransferSplit::VectorTransfer);
289     patterns.insert<VectorTransferFullPartialRewriter>(ctx, options);
290     applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
291   }
292 };
293 
294 } // end anonymous namespace
295 
296 namespace mlir {
297 namespace test {
registerTestVectorConversions()298 void registerTestVectorConversions() {
299   PassRegistration<TestVectorToVectorConversion> vectorToVectorPass(
300       "test-vector-to-vector-conversion",
301       "Test conversion patterns between ops in the vector dialect");
302 
303   PassRegistration<TestVectorSlicesConversion> slicesPass(
304       "test-vector-slices-conversion",
305       "Test conversion patterns that lower slices ops in the vector dialect");
306 
307   PassRegistration<TestVectorContractionConversion> contractionPass(
308       "test-vector-contraction-conversion",
309       "Test conversion patterns that lower contract ops in the vector dialect");
310 
311   PassRegistration<TestVectorUnrollingPatterns> contractionUnrollingPass(
312       "test-vector-unrolling-patterns",
313       "Test conversion patterns to unroll contract ops in the vector dialect");
314 
315   PassRegistration<TestVectorTransferUnrollingPatterns> transferOpUnrollingPass(
316       "test-vector-transfer-unrolling-patterns",
317       "Test conversion patterns to unroll transfer ops in the vector dialect");
318 
319   PassRegistration<TestVectorTransferFullPartialSplitPatterns>
320       vectorTransformFullPartialPass("test-vector-transfer-full-partial-split",
321                                      "Test conversion patterns to split "
322                                      "transfer ops via scf.if + linalg ops");
323   PassRegistration<TestVectorDistributePatterns> distributePass(
324       "test-vector-distribute-patterns",
325       "Test conversion patterns to distribute vector ops in the vector "
326       "dialect");
327   PassRegistration<TestVectorToLoopPatterns> vectorToForLoop(
328       "test-vector-to-forloop",
329       "Test conversion patterns to break up a vector op into a for loop");
330 }
331 } // namespace test
332 } // namespace mlir
333