1 //===- TestLinalgCodegenStrategy.cpp - Test Linalg codegen strategy -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements logic for testing the Linalg codegen strategy.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Affine/IR/AffineOps.h"
14 #include "mlir/Dialect/GPU/GPUDialect.h"
15 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
16 #include "mlir/Dialect/Linalg/Transforms/CodegenStrategy.h"
17 #include "mlir/Dialect/Linalg/Utils/Utils.h"
18 #include "mlir/Dialect/StandardOps/IR/Ops.h"
19 #include "mlir/Dialect/Vector/VectorOps.h"
20 #include "mlir/IR/PatternMatch.h"
21 #include "mlir/Pass/Pass.h"
22 
23 #include "llvm/ADT/SetVector.h"
24 
25 using namespace mlir;
26 using namespace mlir::linalg;
27 
28 namespace {
29 struct TestLinalgCodegenStrategy
30     : public PassWrapper<TestLinalgCodegenStrategy, FunctionPass> {
31   TestLinalgCodegenStrategy() = default;
TestLinalgCodegenStrategy__anonae3f22c10111::TestLinalgCodegenStrategy32   TestLinalgCodegenStrategy(const TestLinalgCodegenStrategy &pass) {}
33 
getDependentDialects__anonae3f22c10111::TestLinalgCodegenStrategy34   void getDependentDialects(DialectRegistry &registry) const override {
35     // clang-format off
36     registry.insert<AffineDialect,
37                     gpu::GPUDialect,
38                     linalg::LinalgDialect,
39                     scf::SCFDialect,
40                     StandardOpsDialect,
41                     vector::VectorDialect>();
42     // clang-format on
43   }
44 
45   void runOnFunction() override;
46 
47   ListOption<int64_t> tileSizes{*this, "tile-sizes",
48                                 llvm::cl::MiscFlags::CommaSeparated,
49                                 llvm::cl::desc("Specifies the tile sizes.")};
50   Option<bool> promote{
51       *this, "promote",
52       llvm::cl::desc("Promote the tile into a small aligned memory buffer."),
53       llvm::cl::init(false)};
54   Option<bool> promoteFullTile{
55       *this, "promote-full-tile-pad",
56       llvm::cl::desc("Pad the small aligned memory buffer to the tile sizes."),
57       llvm::cl::init(false)};
58   ListOption<int64_t> registerTileSizes{
59       *this, "register-tile-sizes", llvm::cl::MiscFlags::CommaSeparated,
60       llvm::cl::desc(
61           "Specifies the size of the register tile that will be used "
62           " to vectorize")};
63   Option<bool> registerPromote{
64       *this, "register-promote",
65       llvm::cl::desc(
66           "Promote the register tile into a small aligned memory buffer."),
67       llvm::cl::init(false)};
68   Option<bool> registerPromoteFullTile{
69       *this, "register-promote-full-tile-pad",
70       llvm::cl::desc("Pad the small aligned memory buffer to the tile sizes."),
71       llvm::cl::init(false)};
72   Option<bool> vectorize{
73       *this, "vectorize",
74       llvm::cl::desc("Rewrite the linalg op as a vector operation."),
75       llvm::cl::init(false)};
76   Option<std::string> splitVectorTransfersTo{
77       *this, "split-transfers",
78       llvm::cl::desc(
79           "Split vector transfers between slow (masked) and fast "
80           "(unmasked) variants. Possible options are:\n"
81           "\tnone: keep unsplit vector.transfer and pay the full price\n"
82           "\tlinalg-copy: use linalg.fill + linalg.copy for the slow path\n"
83           "\tvector-transfers: use extra small unmasked vector.transfer for"
84           " the slow path\n"),
85       llvm::cl::init("none")};
86   Option<std::string> vectorizeContractionTo{
87       *this, "vectorize-contraction-to",
88       llvm::cl::desc("the type of vector op to use for linalg contractions"),
89       llvm::cl::init("outerproduct")};
90   Option<bool> unrollVectorTransfers{
91       *this, "unroll-vector-transfers",
92       llvm::cl::desc("Enable full unrolling of vector.transfer operations"),
93       llvm::cl::init(false)};
94 };
95 } // end anonymous namespace
96 
97 /// Apply transformations specified as patterns.
runOnFunction()98 void TestLinalgCodegenStrategy::runOnFunction() {
99   LinalgTilingOptions tilingOptions;
100   if (!tileSizes.empty())
101     tilingOptions = tilingOptions.setTileSizes(tileSizes);
102 
103   LinalgTilingOptions registerTilingOptions;
104   if (!registerTileSizes.empty())
105     registerTilingOptions =
106         registerTilingOptions.setTileSizes(registerTileSizes);
107 
108   vector::VectorContractLowering vectorContractLowering =
109       llvm::StringSwitch<vector::VectorContractLowering>(
110           vectorizeContractionTo.getValue())
111           .Case("matrixintrinsics", vector::VectorContractLowering::Matmul)
112           .Case("dot", vector::VectorContractLowering::Dot)
113           .Case("outerproduct", vector::VectorContractLowering::OuterProduct)
114           .Default(vector::VectorContractLowering::OuterProduct);
115   vector::VectorTransferSplit vectorTransferSplit =
116       llvm::StringSwitch<vector::VectorTransferSplit>(
117           splitVectorTransfersTo.getValue())
118           .Case("none", vector::VectorTransferSplit::None)
119           .Case("linalg-copy", vector::VectorTransferSplit::LinalgCopy)
120           .Case("vector-transfers", vector::VectorTransferSplit::VectorTransfer)
121           .Default(vector::VectorTransferSplit::None);
122 
123   CodegenStrategy strategy;
124   strategy.tileIf<MatmulOp>(!tileSizes.empty(), tilingOptions)
125       .promoteIf<MatmulOp>(promote,
126                            LinalgPromotionOptions()
127                                .setAlignment(16)
128                                .setUseFullTileBuffersByDefault(promoteFullTile))
129       .tileIf<MatmulOp>(!registerTileSizes.empty(), registerTilingOptions)
130       .promoteIf<MatmulOp>(registerPromote, LinalgPromotionOptions()
131                                                 .setAlignment(16)
132                                                 .setUseFullTileBuffersByDefault(
133                                                     registerPromoteFullTile))
134       .vectorizeIf<MatmulOp>(vectorize)
135       .setVectorTransformsOptions(
136           vector::VectorTransformsOptions()
137               .setVectorTransformsOptions(vectorContractLowering)
138               .setVectorTransferSplit(vectorTransferSplit))
139       .setVectorTransferToSCFOptions(
140           VectorTransferToSCFOptions().setUnroll(unrollVectorTransfers));
141 
142   strategy.transform(getFunction());
143 }
144 
145 namespace mlir {
146 namespace test {
registerTestLinalgCodegenStrategy()147 void registerTestLinalgCodegenStrategy() {
148   PassRegistration<TestLinalgCodegenStrategy> testLinalgCodegenStrategyPass(
149       "test-linalg-codegen-strategy", "Test Linalg Codegen Strategy.");
150 }
151 } // namespace test
152 } // namespace mlir
153