1 //===- TestLinalgCodegenStrategy.cpp - Test Linalg codegen strategy -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements logic for testing the Linalg codegen strategy.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Affine/IR/AffineOps.h"
14 #include "mlir/Dialect/GPU/GPUDialect.h"
15 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
16 #include "mlir/Dialect/Linalg/Transforms/CodegenStrategy.h"
17 #include "mlir/Dialect/Linalg/Utils/Utils.h"
18 #include "mlir/Dialect/StandardOps/IR/Ops.h"
19 #include "mlir/Dialect/Vector/VectorOps.h"
20 #include "mlir/IR/PatternMatch.h"
21 #include "mlir/Pass/Pass.h"
22 
23 #include "llvm/ADT/SetVector.h"
24 
25 using namespace mlir;
26 using namespace mlir::linalg;
27 
28 namespace {
29 struct TestLinalgCodegenStrategy
30     : public PassWrapper<TestLinalgCodegenStrategy, FunctionPass> {
getArgument__anon9131ef5b0111::TestLinalgCodegenStrategy31   StringRef getArgument() const final { return "test-linalg-codegen-strategy"; }
getDescription__anon9131ef5b0111::TestLinalgCodegenStrategy32   StringRef getDescription() const final {
33     return "Test Linalg Codegen Strategy.";
34   }
35   TestLinalgCodegenStrategy() = default;
TestLinalgCodegenStrategy__anon9131ef5b0111::TestLinalgCodegenStrategy36   TestLinalgCodegenStrategy(const TestLinalgCodegenStrategy &pass) {}
37 
getDependentDialects__anon9131ef5b0111::TestLinalgCodegenStrategy38   void getDependentDialects(DialectRegistry &registry) const override {
39     // clang-format off
40     registry.insert<AffineDialect,
41                     gpu::GPUDialect,
42                     linalg::LinalgDialect,
43                     memref::MemRefDialect,
44                     scf::SCFDialect,
45                     StandardOpsDialect,
46                     vector::VectorDialect>();
47     // clang-format on
48   }
49 
50   template <typename LinalgNamedOp>
51   void applyStrategyToNamedLinalgOp();
52 
53   void runOnFunction() override;
54 
55   void runStrategy(LinalgTilingOptions tilingOptions,
56                    LinalgTilingOptions registerTilingOptions,
57                    vector::VectorContractLowering vectorContractLowering,
58                    vector::VectorTransferSplit vectorTransferSplit);
59 
60   ListOption<int64_t> tileSizes{*this, "tile-sizes",
61                                 llvm::cl::MiscFlags::CommaSeparated,
62                                 llvm::cl::desc("Specifies the tile sizes.")};
63   ListOption<unsigned> tileInterchange{
64       *this, "tile-interchange", llvm::cl::MiscFlags::CommaSeparated,
65       llvm::cl::desc("Specifies the tile interchange.")};
66 
67   Option<bool> promote{
68       *this, "promote",
69       llvm::cl::desc("Promote the tile into a small aligned memory buffer."),
70       llvm::cl::init(false)};
71   Option<bool> promoteFullTile{
72       *this, "promote-full-tile-pad",
73       llvm::cl::desc("Pad the small aligned memory buffer to the tile sizes."),
74       llvm::cl::init(false)};
75   ListOption<int64_t> registerTileSizes{
76       *this, "register-tile-sizes", llvm::cl::MiscFlags::CommaSeparated,
77       llvm::cl::desc(
78           "Specifies the size of the register tile that will be used "
79           " to vectorize")};
80   Option<bool> registerPromote{
81       *this, "register-promote",
82       llvm::cl::desc(
83           "Promote the register tile into a small aligned memory buffer."),
84       llvm::cl::init(false)};
85   Option<bool> registerPromoteFullTile{
86       *this, "register-promote-full-tile-pad",
87       llvm::cl::desc("Pad the small aligned memory buffer to the tile sizes."),
88       llvm::cl::init(false)};
89   Option<bool> generalize{*this, "generalize",
90                           llvm::cl::desc("Generalize named operations."),
91                           llvm::cl::init(false)};
92   ListOption<int64_t> iteratorInterchange{
93       *this, "iterator-interchange", llvm::cl::MiscFlags::CommaSeparated,
94       llvm::cl::desc("Specifies the iterator interchange.")};
95   Option<bool> vectorize{
96       *this, "vectorize",
97       llvm::cl::desc("Rewrite the linalg op as a vector operation."),
98       llvm::cl::init(false)};
99   Option<std::string> splitVectorTransfersTo{
100       *this, "split-transfers",
101       llvm::cl::desc(
102           "Split vector transfers between slow (masked) and fast "
103           "(unmasked) variants. Possible options are:\n"
104           "\tnone: keep unsplit vector.transfer and pay the full price\n"
105           "\tlinalg-copy: use linalg.fill + linalg.copy for the slow path\n"
106           "\tvector-transfers: use extra small unmasked vector.transfer for"
107           " the slow path\n"),
108       llvm::cl::init("none")};
109   Option<std::string> vectorizeContractionTo{
110       *this, "vectorize-contraction-to",
111       llvm::cl::desc("the type of vector op to use for linalg contractions"),
112       llvm::cl::init("outerproduct")};
113   Option<bool> unrollVectorTransfers{
114       *this, "unroll-vector-transfers",
115       llvm::cl::desc("Enable full unrolling of vector.transfer operations"),
116       llvm::cl::init(false)};
117   Option<std::string> anchorOpName{
118       *this, "anchor-op",
119       llvm::cl::desc(
120           "Which single linalg op is the anchor for the codegen strategy to "
121           "latch on:\n"
122           "\tlinalg.matmul: anchor on linalg.matmul\n"
123           "\tlinalg.matmul_column_major: anchor on linalg.matmul_column_major\n"
124           "\tlinalg.copy: anchor on linalg.copy\n"
125           "\tlinalg.fill: anchor on linalg.fill\n"),
126       llvm::cl::init("")};
127   Option<std::string> anchorFuncOpName{
128       *this, "anchor-func",
129       llvm::cl::desc(
130           "Which single func op is the anchor for the codegen strategy to "
131           "latch on."),
132       llvm::cl::init("")};
133 };
134 
runStrategy(LinalgTilingOptions tilingOptions,LinalgTilingOptions registerTilingOptions,vector::VectorContractLowering vectorContractLowering,vector::VectorTransferSplit vectorTransferSplit)135 void TestLinalgCodegenStrategy::runStrategy(
136     LinalgTilingOptions tilingOptions,
137     LinalgTilingOptions registerTilingOptions,
138     vector::VectorContractLowering vectorContractLowering,
139     vector::VectorTransferSplit vectorTransferSplit) {
140   assert(!anchorOpName.empty());
141   CodegenStrategy strategy;
142   StringRef genericOpName = GenericOp::getOperationName();
143   strategy.tileIf(!tileSizes.empty(), anchorOpName, tilingOptions)
144       .promoteIf(promote, anchorOpName,
145                  LinalgPromotionOptions()
146                      .setAlignment(16)
147                      .setUseFullTileBuffersByDefault(promoteFullTile))
148       .tileIf(!registerTileSizes.empty(), anchorOpName, registerTilingOptions)
149       .promoteIf(registerPromote, anchorOpName,
150                  LinalgPromotionOptions()
151                      .setAlignment(16)
152                      .setUseFullTileBuffersByDefault(registerPromoteFullTile))
153       .generalizeIf(generalize, anchorOpName)
154       .interchangeIf(!iteratorInterchange.empty(), iteratorInterchange)
155       .vectorizeIf(vectorize, generalize ? genericOpName : anchorOpName)
156       .setEnableVectorTransferPartialRewrite(true)
157       .setEnableVectorContractLowering(true)
158       .setEnableVectorToSCFConversion(true)
159       .setVectorTransformsOptions(
160           vector::VectorTransformsOptions()
161               .setVectorTransformsOptions(vectorContractLowering)
162               .setVectorTransferSplit(vectorTransferSplit))
163       .setVectorTransferToSCFOptions(
164           VectorTransferToSCFOptions().setUnroll(unrollVectorTransfers));
165   (void)strategy.transform(getFunction());
166 }
167 } // end anonymous namespace
168 
169 /// Apply transformations specified as patterns.
runOnFunction()170 void TestLinalgCodegenStrategy::runOnFunction() {
171   if (!anchorFuncOpName.empty() && anchorFuncOpName != getFunction().getName())
172     return;
173 
174   LinalgTilingOptions tilingOptions;
175   if (!tileSizes.empty())
176     tilingOptions = tilingOptions.setTileSizes(tileSizes);
177   if (!tileInterchange.empty())
178     tilingOptions = tilingOptions.setInterchange(tileInterchange);
179 
180   LinalgTilingOptions registerTilingOptions;
181   if (!registerTileSizes.empty())
182     registerTilingOptions =
183         registerTilingOptions.setTileSizes(registerTileSizes);
184 
185   vector::VectorContractLowering vectorContractLowering =
186       llvm::StringSwitch<vector::VectorContractLowering>(
187           vectorizeContractionTo.getValue())
188           .Case("matrixintrinsics", vector::VectorContractLowering::Matmul)
189           .Case("dot", vector::VectorContractLowering::Dot)
190           .Case("outerproduct", vector::VectorContractLowering::OuterProduct)
191           .Default(vector::VectorContractLowering::OuterProduct);
192   vector::VectorTransferSplit vectorTransferSplit =
193       llvm::StringSwitch<vector::VectorTransferSplit>(
194           splitVectorTransfersTo.getValue())
195           .Case("none", vector::VectorTransferSplit::None)
196           .Case("linalg-copy", vector::VectorTransferSplit::LinalgCopy)
197           .Case("vector-transfers", vector::VectorTransferSplit::VectorTransfer)
198           .Default(vector::VectorTransferSplit::None);
199 
200   runStrategy(tilingOptions, registerTilingOptions, vectorContractLowering,
201               vectorTransferSplit);
202 }
203 
204 namespace mlir {
205 namespace test {
registerTestLinalgCodegenStrategy()206 void registerTestLinalgCodegenStrategy() {
207   PassRegistration<TestLinalgCodegenStrategy>();
208 }
209 } // namespace test
210 } // namespace mlir
211