1 //===- TestVectorToVectorConversion.cpp - Test VectorTransfers lowering ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <type_traits>
10 
11 #include "mlir/Analysis/SliceAnalysis.h"
12 #include "mlir/Dialect/Affine/IR/AffineOps.h"
13 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
14 #include "mlir/Dialect/SCF/SCF.h"
15 #include "mlir/Dialect/StandardOps/IR/Ops.h"
16 #include "mlir/Dialect/Vector/VectorOps.h"
17 #include "mlir/Dialect/Vector/VectorTransforms.h"
18 #include "mlir/Pass/Pass.h"
19 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
20 
21 using namespace mlir;
22 using namespace mlir::vector;
23 namespace {
24 
25 struct TestVectorToVectorConversion
26     : public PassWrapper<TestVectorToVectorConversion, FunctionPass> {
27   TestVectorToVectorConversion() = default;
TestVectorToVectorConversion__anonafb160640111::TestVectorToVectorConversion28   TestVectorToVectorConversion(const TestVectorToVectorConversion &pass) {}
29 
getDependentDialects__anonafb160640111::TestVectorToVectorConversion30   void getDependentDialects(DialectRegistry &registry) const override {
31     registry.insert<AffineDialect>();
32   }
33 
34   Option<bool> unroll{*this, "unroll", llvm::cl::desc("Include unrolling"),
35                       llvm::cl::init(false)};
36 
runOnFunction__anonafb160640111::TestVectorToVectorConversion37   void runOnFunction() override {
38     OwningRewritePatternList patterns;
39     auto *ctx = &getContext();
40     if (unroll) {
41       patterns.insert<UnrollVectorPattern>(
42           ctx,
43           UnrollVectorOptions().setNativeShapeFn(getShape).setFilterConstraint(
44               filter));
45     }
46     populateVectorToVectorCanonicalizationPatterns(patterns, ctx);
47     populateVectorToVectorTransformationPatterns(patterns, ctx);
48     applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
49   }
50 
51 private:
52   // Return the target shape based on op type.
getShape__anonafb160640111::TestVectorToVectorConversion53   static Optional<SmallVector<int64_t, 4>> getShape(Operation *op) {
54     if (isa<AddFOp, SelectOp, CmpFOp>(op))
55       return SmallVector<int64_t, 4>(2, 2);
56     if (isa<vector::ContractionOp>(op))
57       return SmallVector<int64_t, 4>(3, 2);
58     return llvm::None;
59   }
60 
filter__anonafb160640111::TestVectorToVectorConversion61   static LogicalResult filter(Operation *op) {
62     return success(isa<AddFOp, SelectOp, CmpFOp, ContractionOp>(op));
63   }
64 };
65 
66 struct TestVectorSlicesConversion
67     : public PassWrapper<TestVectorSlicesConversion, FunctionPass> {
runOnFunction__anonafb160640111::TestVectorSlicesConversion68   void runOnFunction() override {
69     OwningRewritePatternList patterns;
70     populateVectorSlicesLoweringPatterns(patterns, &getContext());
71     applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
72   }
73 };
74 
75 struct TestVectorContractionConversion
76     : public PassWrapper<TestVectorContractionConversion, FunctionPass> {
77   TestVectorContractionConversion() = default;
TestVectorContractionConversion__anonafb160640111::TestVectorContractionConversion78   TestVectorContractionConversion(const TestVectorContractionConversion &pass) {
79   }
80 
81   Option<bool> lowerToFlatMatrix{
82       *this, "vector-lower-matrix-intrinsics",
83       llvm::cl::desc("Lower vector.contract to llvm.intr.matrix.multiply"),
84       llvm::cl::init(false)};
85   Option<bool> lowerToFlatTranspose{
86       *this, "vector-flat-transpose",
87       llvm::cl::desc("Lower 2-D vector.transpose to vector.flat_transpose"),
88       llvm::cl::init(false)};
89   Option<bool> lowerToOuterProduct{
90       *this, "vector-outerproduct",
91       llvm::cl::desc("Lower vector.contract to vector.outerproduct"),
92       llvm::cl::init(false)};
93   Option<bool> lowerToFilterOuterProduct{
94       *this, "vector-filter-outerproduct",
95       llvm::cl::desc("Lower vector.contract to vector.outerproduct but not for "
96                      "vectors of size 4."),
97       llvm::cl::init(false)};
98 
runOnFunction__anonafb160640111::TestVectorContractionConversion99   void runOnFunction() override {
100     OwningRewritePatternList patterns;
101 
102     // Test on one pattern in isolation.
103     if (lowerToOuterProduct) {
104       VectorContractLowering lowering = VectorContractLowering::OuterProduct;
105       VectorTransformsOptions options{lowering};
106       patterns.insert<ContractionOpToOuterProductOpLowering>(options,
107                                                              &getContext());
108       applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
109       return;
110     }
111 
112     // Test on one pattern in isolation.
113     if (lowerToFilterOuterProduct) {
114       VectorContractLowering lowering = VectorContractLowering::OuterProduct;
115       VectorTransformsOptions options{lowering};
116       patterns.insert<ContractionOpToOuterProductOpLowering>(
117           options, &getContext(), [](vector::ContractionOp op) {
118             // Only lowers vector.contract where the lhs as a type vector<MxNx?>
119             // where M is not 4.
120             if (op.getRhsType().getShape()[0] == 4)
121               return failure();
122             return success();
123           });
124       applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
125       return;
126     }
127 
128     // Test on all contract lowering patterns.
129     VectorContractLowering contractLowering = VectorContractLowering::Dot;
130     if (lowerToFlatMatrix)
131       contractLowering = VectorContractLowering::Matmul;
132     VectorTransposeLowering transposeLowering =
133         VectorTransposeLowering::EltWise;
134     if (lowerToFlatTranspose)
135       transposeLowering = VectorTransposeLowering::Flat;
136     VectorTransformsOptions options{contractLowering, transposeLowering};
137     populateVectorContractLoweringPatterns(patterns, &getContext(), options);
138     applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
139   }
140 };
141 
142 struct TestVectorUnrollingPatterns
143     : public PassWrapper<TestVectorUnrollingPatterns, FunctionPass> {
144   TestVectorUnrollingPatterns() = default;
TestVectorUnrollingPatterns__anonafb160640111::TestVectorUnrollingPatterns145   TestVectorUnrollingPatterns(const TestVectorUnrollingPatterns &pass) {}
runOnFunction__anonafb160640111::TestVectorUnrollingPatterns146   void runOnFunction() override {
147     MLIRContext *ctx = &getContext();
148     OwningRewritePatternList patterns;
149     patterns.insert<UnrollVectorPattern>(
150         ctx, UnrollVectorOptions()
151                  .setNativeShape(ArrayRef<int64_t>{2, 2})
152                  .setFilterConstraint(
153                      [](Operation *op) { return success(isa<AddFOp>(op)); }));
154 
155     if (unrollBasedOnType) {
156       UnrollVectorOptions::NativeShapeFnType nativeShapeFn =
157           [](Operation *op) -> Optional<SmallVector<int64_t, 4>> {
158         vector::ContractionOp contractOp = cast<vector::ContractionOp>(op);
159         SmallVector<int64_t, 4> nativeShape = {4, 4, 2};
160         if (auto floatType = contractOp.getLhsType()
161                                  .getElementType()
162                                  .dyn_cast<FloatType>()) {
163           if (floatType.getWidth() == 16) {
164             nativeShape[2] = 4;
165           }
166         }
167         return nativeShape;
168       };
169       patterns.insert<UnrollVectorPattern>(
170           ctx, UnrollVectorOptions()
171                    .setNativeShapeFn(nativeShapeFn)
172                    .setFilterConstraint([](Operation *op) {
173                      return success(isa<ContractionOp>(op));
174                    }));
175     } else {
176       patterns.insert<UnrollVectorPattern>(
177           ctx, UnrollVectorOptions()
178                    .setNativeShape(ArrayRef<int64_t>{2, 2, 2})
179                    .setFilterConstraint([](Operation *op) {
180                      return success(isa<ContractionOp>(op));
181                    }));
182     }
183     populateVectorToVectorCanonicalizationPatterns(patterns, ctx);
184     populateVectorToVectorTransformationPatterns(patterns, ctx);
185     applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
186   }
187 
188   Option<bool> unrollBasedOnType{
189       *this, "unroll-based-on-type",
190       llvm::cl::desc("Set the unroll factor based on type of the operation"),
191       llvm::cl::init(false)};
192 };
193 
194 struct TestVectorDistributePatterns
195     : public PassWrapper<TestVectorDistributePatterns, FunctionPass> {
196   TestVectorDistributePatterns() = default;
TestVectorDistributePatterns__anonafb160640111::TestVectorDistributePatterns197   TestVectorDistributePatterns(const TestVectorDistributePatterns &pass) {}
getDependentDialects__anonafb160640111::TestVectorDistributePatterns198   void getDependentDialects(DialectRegistry &registry) const override {
199     registry.insert<VectorDialect>();
200     registry.insert<AffineDialect>();
201   }
202   ListOption<int32_t> multiplicity{
203       *this, "distribution-multiplicity", llvm::cl::MiscFlags::CommaSeparated,
204       llvm::cl::desc("Set the multiplicity used for distributing vector")};
205 
runOnFunction__anonafb160640111::TestVectorDistributePatterns206   void runOnFunction() override {
207     MLIRContext *ctx = &getContext();
208     OwningRewritePatternList patterns;
209     FuncOp func = getFunction();
210     func.walk([&](AddFOp op) {
211       OpBuilder builder(op);
212       if (auto vecType = op.getType().dyn_cast<VectorType>()) {
213         SmallVector<int64_t, 2> mul;
214         SmallVector<AffineExpr, 2> perm;
215         SmallVector<Value, 2> ids;
216         unsigned count = 0;
217         // Remove the multiplicity of 1 and calculate the affine map based on
218         // the multiplicity.
219         SmallVector<int32_t, 4> m(multiplicity.begin(), multiplicity.end());
220         for (unsigned i = 0, e = vecType.getRank(); i < e; i++) {
221           if (i < m.size() && m[i] != 1 && vecType.getDimSize(i) % m[i] == 0) {
222             mul.push_back(m[i]);
223             ids.push_back(func.getArgument(count++));
224             perm.push_back(getAffineDimExpr(i, ctx));
225           }
226         }
227         auto map = AffineMap::get(op.getType().cast<VectorType>().getRank(), 0,
228                                   perm, ctx);
229         Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp(
230             builder, op.getOperation(), ids, mul, map);
231         if (ops.hasValue()) {
232           SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert});
233           op.getResult().replaceAllUsesExcept(ops->insert.getResult(),
234                                               extractOp);
235         }
236       }
237     });
238     patterns.insert<PointwiseExtractPattern>(ctx);
239     populateVectorToVectorTransformationPatterns(patterns, ctx);
240     applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
241   }
242 };
243 
244 struct TestVectorToLoopPatterns
245     : public PassWrapper<TestVectorToLoopPatterns, FunctionPass> {
246   TestVectorToLoopPatterns() = default;
TestVectorToLoopPatterns__anonafb160640111::TestVectorToLoopPatterns247   TestVectorToLoopPatterns(const TestVectorToLoopPatterns &pass) {}
getDependentDialects__anonafb160640111::TestVectorToLoopPatterns248   void getDependentDialects(DialectRegistry &registry) const override {
249     registry.insert<VectorDialect>();
250     registry.insert<AffineDialect>();
251   }
252   Option<int32_t> multiplicity{
253       *this, "distribution-multiplicity",
254       llvm::cl::desc("Set the multiplicity used for distributing vector"),
255       llvm::cl::init(32)};
runOnFunction__anonafb160640111::TestVectorToLoopPatterns256   void runOnFunction() override {
257     MLIRContext *ctx = &getContext();
258     OwningRewritePatternList patterns;
259     FuncOp func = getFunction();
260     func.walk([&](AddFOp op) {
261       // Check that the operation type can be broken down into a loop.
262       VectorType type = op.getType().dyn_cast<VectorType>();
263       if (!type || type.getRank() != 1 ||
264           type.getNumElements() % multiplicity != 0)
265         return mlir::WalkResult::advance();
266       auto filterAlloc = [](Operation *op) {
267         if (isa<ConstantOp, AllocOp, CallOp>(op))
268           return false;
269         return true;
270       };
271       auto dependentOps = getSlice(op, filterAlloc);
272       // Create a loop and move instructions from the Op slice into the loop.
273       OpBuilder builder(op);
274       auto zero = builder.create<ConstantOp>(
275           op.getLoc(), builder.getIndexType(),
276           builder.getIntegerAttr(builder.getIndexType(), 0));
277       auto one = builder.create<ConstantOp>(
278           op.getLoc(), builder.getIndexType(),
279           builder.getIntegerAttr(builder.getIndexType(), 1));
280       auto numIter = builder.create<ConstantOp>(
281           op.getLoc(), builder.getIndexType(),
282           builder.getIntegerAttr(builder.getIndexType(), multiplicity));
283       auto forOp = builder.create<scf::ForOp>(op.getLoc(), zero, numIter, one);
284       for (Operation *it : dependentOps) {
285         it->moveBefore(forOp.getBody()->getTerminator());
286       }
287       auto map = AffineMap::getMultiDimIdentityMap(1, ctx);
288       // break up the original op and let the patterns propagate.
289       Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp(
290           builder, op.getOperation(), {forOp.getInductionVar()}, {multiplicity},
291           map);
292       if (ops.hasValue()) {
293         SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert});
294         op.getResult().replaceAllUsesExcept(ops->insert.getResult(), extractOp);
295       }
296       return mlir::WalkResult::interrupt();
297     });
298     patterns.insert<PointwiseExtractPattern>(ctx);
299     populateVectorToVectorTransformationPatterns(patterns, ctx);
300     applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
301   }
302 };
303 
304 struct TestVectorTransferUnrollingPatterns
305     : public PassWrapper<TestVectorTransferUnrollingPatterns, FunctionPass> {
getDependentDialects__anonafb160640111::TestVectorTransferUnrollingPatterns306   void getDependentDialects(DialectRegistry &registry) const override {
307     registry.insert<AffineDialect>();
308   }
runOnFunction__anonafb160640111::TestVectorTransferUnrollingPatterns309   void runOnFunction() override {
310     MLIRContext *ctx = &getContext();
311     OwningRewritePatternList patterns;
312     patterns.insert<UnrollVectorPattern>(
313         ctx,
314         UnrollVectorOptions()
315             .setNativeShape(ArrayRef<int64_t>{2, 2})
316             .setFilterConstraint([](Operation *op) {
317               return success(
318                   isa<vector::TransferReadOp, vector::TransferWriteOp>(op));
319             }));
320     populateVectorToVectorCanonicalizationPatterns(patterns, ctx);
321     populateVectorToVectorTransformationPatterns(patterns, ctx);
322     applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
323   }
324 };
325 
326 struct TestVectorTransferFullPartialSplitPatterns
327     : public PassWrapper<TestVectorTransferFullPartialSplitPatterns,
328                          FunctionPass> {
329   TestVectorTransferFullPartialSplitPatterns() = default;
TestVectorTransferFullPartialSplitPatterns__anonafb160640111::TestVectorTransferFullPartialSplitPatterns330   TestVectorTransferFullPartialSplitPatterns(
331       const TestVectorTransferFullPartialSplitPatterns &pass) {}
332 
getDependentDialects__anonafb160640111::TestVectorTransferFullPartialSplitPatterns333   void getDependentDialects(DialectRegistry &registry) const override {
334     registry.insert<AffineDialect, linalg::LinalgDialect, scf::SCFDialect>();
335   }
336 
337   Option<bool> useLinalgOps{
338       *this, "use-linalg-copy",
339       llvm::cl::desc("Split using a unmasked vector.transfer + linalg.fill + "
340                      "linalg.copy operations."),
341       llvm::cl::init(false)};
runOnFunction__anonafb160640111::TestVectorTransferFullPartialSplitPatterns342   void runOnFunction() override {
343     MLIRContext *ctx = &getContext();
344     OwningRewritePatternList patterns;
345     VectorTransformsOptions options;
346     if (useLinalgOps)
347       options.setVectorTransferSplit(VectorTransferSplit::LinalgCopy);
348     else
349       options.setVectorTransferSplit(VectorTransferSplit::VectorTransfer);
350     patterns.insert<VectorTransferFullPartialRewriter>(ctx, options);
351     applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
352   }
353 };
354 
355 struct TestVectorTransferOpt
356     : public PassWrapper<TestVectorTransferOpt, FunctionPass> {
runOnFunction__anonafb160640111::TestVectorTransferOpt357   void runOnFunction() override { transferOpflowOpt(getFunction()); }
358 };
359 
360 } // end anonymous namespace
361 
362 namespace mlir {
363 namespace test {
registerTestVectorConversions()364 void registerTestVectorConversions() {
365   PassRegistration<TestVectorToVectorConversion> vectorToVectorPass(
366       "test-vector-to-vector-conversion",
367       "Test conversion patterns between ops in the vector dialect");
368 
369   PassRegistration<TestVectorSlicesConversion> slicesPass(
370       "test-vector-slices-conversion",
371       "Test conversion patterns that lower slices ops in the vector dialect");
372 
373   PassRegistration<TestVectorContractionConversion> contractionPass(
374       "test-vector-contraction-conversion",
375       "Test conversion patterns that lower contract ops in the vector dialect");
376 
377   PassRegistration<TestVectorUnrollingPatterns> contractionUnrollingPass(
378       "test-vector-unrolling-patterns",
379       "Test conversion patterns to unroll contract ops in the vector dialect");
380 
381   PassRegistration<TestVectorTransferUnrollingPatterns> transferOpUnrollingPass(
382       "test-vector-transfer-unrolling-patterns",
383       "Test conversion patterns to unroll transfer ops in the vector dialect");
384 
385   PassRegistration<TestVectorTransferFullPartialSplitPatterns>
386       vectorTransformFullPartialPass("test-vector-transfer-full-partial-split",
387                                      "Test conversion patterns to split "
388                                      "transfer ops via scf.if + linalg ops");
389   PassRegistration<TestVectorDistributePatterns> distributePass(
390       "test-vector-distribute-patterns",
391       "Test conversion patterns to distribute vector ops in the vector "
392       "dialect");
393   PassRegistration<TestVectorToLoopPatterns> vectorToForLoop(
394       "test-vector-to-forloop",
395       "Test conversion patterns to break up a vector op into a for loop");
396   PassRegistration<TestVectorTransferOpt> transferOpOpt(
397       "test-vector-transferop-opt",
398       "Test optimization transformations for transfer ops");
399 }
400 } // namespace test
401 } // namespace mlir
402