1 //===- TestPassManager.cpp - Test pass manager functionality --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/BuiltinOps.h"
10 #include "mlir/Pass/Pass.h"
11 #include "mlir/Pass/PassManager.h"
12 
13 using namespace mlir;
14 
15 namespace {
16 struct TestModulePass
17     : public PassWrapper<TestModulePass, OperationPass<ModuleOp>> {
runOnOperation__anond6b7e0de0111::TestModulePass18   void runOnOperation() final {}
getArgument__anond6b7e0de0111::TestModulePass19   StringRef getArgument() const final { return "test-module-pass"; }
getDescription__anond6b7e0de0111::TestModulePass20   StringRef getDescription() const final {
21     return "Test a module pass in the pass manager";
22   }
23 };
24 struct TestFunctionPass : public PassWrapper<TestFunctionPass, FunctionPass> {
runOnFunction__anond6b7e0de0111::TestFunctionPass25   void runOnFunction() final {}
getArgument__anond6b7e0de0111::TestFunctionPass26   StringRef getArgument() const final { return "test-function-pass"; }
getDescription__anond6b7e0de0111::TestFunctionPass27   StringRef getDescription() const final {
28     return "Test a function pass in the pass manager";
29   }
30 };
31 class TestOptionsPass : public PassWrapper<TestOptionsPass, FunctionPass> {
32 public:
33   struct Options : public PassPipelineOptions<Options> {
34     ListOption<int> listOption{*this, "list",
35                                llvm::cl::MiscFlags::CommaSeparated,
36                                llvm::cl::desc("Example list option")};
37     ListOption<std::string> stringListOption{
38         *this, "string-list", llvm::cl::MiscFlags::CommaSeparated,
39         llvm::cl::desc("Example string list option")};
40     Option<std::string> stringOption{*this, "string",
41                                      llvm::cl::desc("Example string option")};
42   };
43   TestOptionsPass() = default;
TestOptionsPass(const TestOptionsPass &)44   TestOptionsPass(const TestOptionsPass &) {}
TestOptionsPass(const Options & options)45   TestOptionsPass(const Options &options) {
46     listOption = options.listOption;
47     stringOption = options.stringOption;
48     stringListOption = options.stringListOption;
49   }
50 
runOnFunction()51   void runOnFunction() final {}
getArgument() const52   StringRef getArgument() const final { return "test-options-pass"; }
getDescription() const53   StringRef getDescription() const final {
54     return "Test options parsing capabilities";
55   }
56 
57   ListOption<int> listOption{*this, "list", llvm::cl::MiscFlags::CommaSeparated,
58                              llvm::cl::desc("Example list option")};
59   ListOption<std::string> stringListOption{
60       *this, "string-list", llvm::cl::MiscFlags::CommaSeparated,
61       llvm::cl::desc("Example string list option")};
62   Option<std::string> stringOption{*this, "string",
63                                    llvm::cl::desc("Example string option")};
64 };
65 
66 /// A test pass that always aborts to enable testing the crash recovery
67 /// mechanism of the pass manager.
68 class TestCrashRecoveryPass
69     : public PassWrapper<TestCrashRecoveryPass, OperationPass<>> {
runOnOperation()70   void runOnOperation() final { abort(); }
getArgument() const71   StringRef getArgument() const final { return "test-pass-crash"; }
getDescription() const72   StringRef getDescription() const final {
73     return "Test a pass in the pass manager that always crashes";
74   }
75 };
76 
77 /// A test pass that always fails to enable testing the failure recovery
78 /// mechanisms of the pass manager.
79 class TestFailurePass : public PassWrapper<TestFailurePass, OperationPass<>> {
runOnOperation()80   void runOnOperation() final { signalPassFailure(); }
getArgument() const81   StringRef getArgument() const final { return "test-pass-failure"; }
getDescription() const82   StringRef getDescription() const final {
83     return "Test a pass in the pass manager that always fails";
84   }
85 };
86 
87 /// A test pass that contains a statistic.
88 struct TestStatisticPass
89     : public PassWrapper<TestStatisticPass, OperationPass<>> {
90   TestStatisticPass() = default;
TestStatisticPass__anond6b7e0de0111::TestStatisticPass91   TestStatisticPass(const TestStatisticPass &) {}
getArgument__anond6b7e0de0111::TestStatisticPass92   StringRef getArgument() const final { return "test-stats-pass"; }
getDescription__anond6b7e0de0111::TestStatisticPass93   StringRef getDescription() const final { return "Test pass statistics"; }
94 
95   Statistic opCount{this, "num-ops", "Number of operations counted"};
96 
runOnOperation__anond6b7e0de0111::TestStatisticPass97   void runOnOperation() final {
98     getOperation()->walk([&](Operation *) { ++opCount; });
99   }
100 };
101 } // end anonymous namespace
102 
testNestedPipeline(OpPassManager & pm)103 static void testNestedPipeline(OpPassManager &pm) {
104   // Nest a module pipeline that contains:
105   /// A module pass.
106   auto &modulePM = pm.nest<ModuleOp>();
107   modulePM.addPass(std::make_unique<TestModulePass>());
108   /// A nested function pass.
109   auto &nestedFunctionPM = modulePM.nest<FuncOp>();
110   nestedFunctionPM.addPass(std::make_unique<TestFunctionPass>());
111 
112   // Nest a function pipeline that contains a single pass.
113   auto &functionPM = pm.nest<FuncOp>();
114   functionPM.addPass(std::make_unique<TestFunctionPass>());
115 }
116 
testNestedPipelineTextual(OpPassManager & pm)117 static void testNestedPipelineTextual(OpPassManager &pm) {
118   (void)parsePassPipeline("test-pm-nested-pipeline", pm);
119 }
120 
121 namespace mlir {
registerPassManagerTestPass()122 void registerPassManagerTestPass() {
123   PassRegistration<TestOptionsPass>();
124 
125   PassRegistration<TestModulePass>();
126 
127   PassRegistration<TestFunctionPass>();
128 
129   PassRegistration<TestCrashRecoveryPass>();
130   PassRegistration<TestFailurePass>();
131 
132   PassRegistration<TestStatisticPass>();
133 
134   PassPipelineRegistration<>("test-pm-nested-pipeline",
135                              "Test a nested pipeline in the pass manager",
136                              testNestedPipeline);
137   PassPipelineRegistration<>("test-textual-pm-nested-pipeline",
138                              "Test a nested pipeline in the pass manager",
139                              testNestedPipelineTextual);
140   PassPipelineRegistration<>(
141       "test-dump-pipeline",
142       "Dumps the pipeline build so far for debugging purposes",
143       [](OpPassManager &pm) {
144         pm.printAsTextualPipeline(llvm::errs());
145         llvm::errs() << "\n";
146       });
147 
148   PassPipelineRegistration<TestOptionsPass::Options>
149       registerOptionsPassPipeline(
150           "test-options-pass-pipeline",
151           "Parses options using pass pipeline registration",
152           [](OpPassManager &pm, const TestOptionsPass::Options &options) {
153             pm.addPass(std::make_unique<TestOptionsPass>(options));
154           });
155 }
156 } // namespace mlir
157