1 //===- llvm/unittest/CodeGen/PassManager.cpp - PassManager tests ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "llvm/ADT/Triple.h"
10 #include "llvm/Analysis/CGSCCPassManager.h"
11 #include "llvm/Analysis/LoopAnalysisManager.h"
12 #include "llvm/AsmParser/Parser.h"
13 #include "llvm/CodeGen/MachineModuleInfo.h"
14 #include "llvm/CodeGen/MachinePassManager.h"
15 #include "llvm/IR/LLVMContext.h"
16 #include "llvm/IR/Module.h"
17 #include "llvm/Passes/PassBuilder.h"
18 #include "llvm/Support/Host.h"
19 #include "llvm/Support/SourceMgr.h"
20 #include "llvm/Support/TargetRegistry.h"
21 #include "llvm/Support/TargetSelect.h"
22 #include "llvm/Target/TargetMachine.h"
23 #include "gtest/gtest.h"
24
25 using namespace llvm;
26
27 namespace {
28
29 class TestFunctionAnalysis : public AnalysisInfoMixin<TestFunctionAnalysis> {
30 public:
31 struct Result {
Result__anon1d198ac80111::TestFunctionAnalysis::Result32 Result(int Count) : InstructionCount(Count) {}
33 int InstructionCount;
34 };
35
36 /// Run the analysis pass over the function and return a result.
run(Function & F,FunctionAnalysisManager & AM)37 Result run(Function &F, FunctionAnalysisManager &AM) {
38 int Count = 0;
39 for (Function::iterator BBI = F.begin(), BBE = F.end(); BBI != BBE; ++BBI)
40 for (BasicBlock::iterator II = BBI->begin(), IE = BBI->end(); II != IE;
41 ++II)
42 ++Count;
43 return Result(Count);
44 }
45
46 private:
47 friend AnalysisInfoMixin<TestFunctionAnalysis>;
48 static AnalysisKey Key;
49 };
50
51 AnalysisKey TestFunctionAnalysis::Key;
52
53 class TestMachineFunctionAnalysis
54 : public AnalysisInfoMixin<TestMachineFunctionAnalysis> {
55 public:
56 struct Result {
Result__anon1d198ac80111::TestMachineFunctionAnalysis::Result57 Result(int Count) : InstructionCount(Count) {}
58 int InstructionCount;
59 };
60
61 /// Run the analysis pass over the machine function and return a result.
run(MachineFunction & MF,MachineFunctionAnalysisManager::Base & AM)62 Result run(MachineFunction &MF, MachineFunctionAnalysisManager::Base &AM) {
63 auto &MFAM = static_cast<MachineFunctionAnalysisManager &>(AM);
64 // Query function analysis result.
65 TestFunctionAnalysis::Result &FAR =
66 MFAM.getResult<TestFunctionAnalysis>(MF.getFunction());
67 // + 5
68 return FAR.InstructionCount;
69 }
70
71 private:
72 friend AnalysisInfoMixin<TestMachineFunctionAnalysis>;
73 static AnalysisKey Key;
74 };
75
76 AnalysisKey TestMachineFunctionAnalysis::Key;
77
// Messages carried by the StringErrors that TestMachineFunctionPass returns
// from doInitialization / doFinalization; the test compares against them.
const std::string DoInitErrMsg("doInitialization failed");
const std::string DoFinalErrMsg("doFinalization failed");
80
81 struct TestMachineFunctionPass : public PassInfoMixin<TestMachineFunctionPass> {
TestMachineFunctionPass__anon1d198ac80111::TestMachineFunctionPass82 TestMachineFunctionPass(int &Count, std::vector<int> &BeforeInitialization,
83 std::vector<int> &BeforeFinalization,
84 std::vector<int> &MachineFunctionPassCount)
85 : Count(Count), BeforeInitialization(BeforeInitialization),
86 BeforeFinalization(BeforeFinalization),
87 MachineFunctionPassCount(MachineFunctionPassCount) {}
88
doInitialization__anon1d198ac80111::TestMachineFunctionPass89 Error doInitialization(Module &M, MachineFunctionAnalysisManager &MFAM) {
90 // Force doInitialization fail by starting with big `Count`.
91 if (Count > 10000)
92 return make_error<StringError>(DoInitErrMsg, inconvertibleErrorCode());
93
94 // + 1
95 ++Count;
96 BeforeInitialization.push_back(Count);
97 return Error::success();
98 }
doFinalization__anon1d198ac80111::TestMachineFunctionPass99 Error doFinalization(Module &M, MachineFunctionAnalysisManager &MFAM) {
100 // Force doFinalization fail by starting with big `Count`.
101 if (Count > 1000)
102 return make_error<StringError>(DoFinalErrMsg, inconvertibleErrorCode());
103
104 // + 1
105 ++Count;
106 BeforeFinalization.push_back(Count);
107 return Error::success();
108 }
109
run__anon1d198ac80111::TestMachineFunctionPass110 PreservedAnalyses run(MachineFunction &MF,
111 MachineFunctionAnalysisManager &MFAM) {
112 // Query function analysis result.
113 TestFunctionAnalysis::Result &FAR =
114 MFAM.getResult<TestFunctionAnalysis>(MF.getFunction());
115 // 3 + 1 + 1 = 5
116 Count += FAR.InstructionCount;
117
118 // Query module analysis result.
119 MachineModuleInfo &MMI =
120 MFAM.getResult<MachineModuleAnalysis>(*MF.getFunction().getParent());
121 // 1 + 1 + 1 = 3
122 Count += (MMI.getModule() == MF.getFunction().getParent());
123
124 // Query machine function analysis result.
125 TestMachineFunctionAnalysis::Result &MFAR =
126 MFAM.getResult<TestMachineFunctionAnalysis>(MF);
127 // 3 + 1 + 1 = 5
128 Count += MFAR.InstructionCount;
129
130 MachineFunctionPassCount.push_back(Count);
131
132 return PreservedAnalyses::none();
133 }
134
135 int &Count;
136 std::vector<int> &BeforeInitialization;
137 std::vector<int> &BeforeFinalization;
138 std::vector<int> &MachineFunctionPassCount;
139 };
140
141 struct TestMachineModulePass : public PassInfoMixin<TestMachineModulePass> {
TestMachineModulePass__anon1d198ac80111::TestMachineModulePass142 TestMachineModulePass(int &Count, std::vector<int> &MachineModulePassCount)
143 : Count(Count), MachineModulePassCount(MachineModulePassCount) {}
144
run__anon1d198ac80111::TestMachineModulePass145 Error run(Module &M, MachineFunctionAnalysisManager &MFAM) {
146 MachineModuleInfo &MMI = MFAM.getResult<MachineModuleAnalysis>(M);
147 // + 1
148 Count += (MMI.getModule() == &M);
149 MachineModulePassCount.push_back(Count);
150 return Error::success();
151 }
152
run__anon1d198ac80111::TestMachineModulePass153 PreservedAnalyses run(MachineFunction &MF,
154 MachineFunctionAnalysisManager &AM) {
155 llvm_unreachable(
156 "This should never be reached because this is machine module pass");
157 }
158
159 int &Count;
160 std::vector<int> &MachineModulePassCount;
161 };
162
parseIR(LLVMContext & Context,const char * IR)163 std::unique_ptr<Module> parseIR(LLVMContext &Context, const char *IR) {
164 SMDiagnostic Err;
165 return parseAssemblyString(IR, Err, Context);
166 }
167
168 class PassManagerTest : public ::testing::Test {
169 protected:
170 LLVMContext Context;
171 std::unique_ptr<Module> M;
172 std::unique_ptr<TargetMachine> TM;
173
174 public:
PassManagerTest()175 PassManagerTest()
176 : M(parseIR(Context, "define void @f() {\n"
177 "entry:\n"
178 " call void @g()\n"
179 " call void @h()\n"
180 " ret void\n"
181 "}\n"
182 "define void @g() {\n"
183 " ret void\n"
184 "}\n"
185 "define void @h() {\n"
186 " ret void\n"
187 "}\n")) {
188 // MachineModuleAnalysis needs a TargetMachine instance.
189 llvm::InitializeAllTargets();
190
191 std::string TripleName = Triple::normalize(sys::getDefaultTargetTriple());
192 std::string Error;
193 const Target *TheTarget =
194 TargetRegistry::lookupTarget(TripleName, Error);
195 if (!TheTarget)
196 return;
197
198 TargetOptions Options;
199 TM.reset(TheTarget->createTargetMachine(TripleName, "", "",
200 Options, None));
201 }
202 };
203
/// Runs a MachineFunctionPassManager with interleaved machine module and
/// machine function passes. It checks the counter values recorded at every
/// hook, then checks that errors from doInitialization and doFinalization
/// are propagated out of MachineFunctionPassManager::run.
TEST_F(PassManagerTest, Basic) {
  // Report a failure to parse the fixture IR instead of dereferencing a
  // null module below.
  ASSERT_TRUE(M) << "failed to parse the test IR";
  if (!TM)
    GTEST_SKIP();

  LLVMTargetMachine *LLVMTM = static_cast<LLVMTargetMachine *>(TM.get());
  M->setDataLayout(TM->createDataLayout());

  LoopAnalysisManager LAM;
  FunctionAnalysisManager FAM;
  CGSCCAnalysisManager CGAM;
  ModuleAnalysisManager MAM;
  PassBuilder PB(TM.get());
  PB.registerModuleAnalyses(MAM);
  PB.registerFunctionAnalyses(FAM);
  PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);

  FAM.registerPass([&] { return TestFunctionAnalysis(); });
  FAM.registerPass([&] { return PassInstrumentationAnalysis(); });
  MAM.registerPass([&] { return MachineModuleAnalysis(LLVMTM); });
  MAM.registerPass([&] { return PassInstrumentationAnalysis(); });

  MachineFunctionAnalysisManager MFAM;
  {
    // Test move assignment.
    MachineFunctionAnalysisManager NestedMFAM(FAM, MAM);
    NestedMFAM.registerPass([&] { return PassInstrumentationAnalysis(); });
    NestedMFAM.registerPass([&] { return TestMachineFunctionAnalysis(); });
    MFAM = std::move(NestedMFAM);
  }

  // Every test pass below shares this counter; each vector records the
  // counter value seen at one hook of one pass instance.
  int Count = 0;
  std::vector<int> BeforeInitialization[2];
  std::vector<int> BeforeFinalization[2];
  std::vector<int> TestMachineFunctionCount[2];
  std::vector<int> TestMachineModuleCount[2];

  MachineFunctionPassManager MFPM;
  {
    // Test move assignment. Pass order: module pass 0, function pass 0,
    // module pass 1, function pass 1.
    MachineFunctionPassManager NestedMFPM;
    NestedMFPM.addPass(TestMachineModulePass(Count, TestMachineModuleCount[0]));
    NestedMFPM.addPass(TestMachineFunctionPass(Count, BeforeInitialization[0],
                                               BeforeFinalization[0],
                                               TestMachineFunctionCount[0]));
    NestedMFPM.addPass(TestMachineModulePass(Count, TestMachineModuleCount[1]));
    NestedMFPM.addPass(TestMachineFunctionPass(Count, BeforeInitialization[1],
                                               BeforeFinalization[1],
                                               TestMachineFunctionCount[1]));
    MFPM = std::move(NestedMFPM);
  }

  ASSERT_FALSE(errorToBool(MFPM.run(*M, MFAM)));

  // Expected counter values, from the increments made by the test passes:
  //  - doInitialization of function pass 0, then pass 1:      1, 2
  //  - module pass 0 (+1):                                     3
  //  - function pass 0 on @f, @g, @h. Each run adds the IR instruction
  //    count (3, 1, 1), +1 for the module check, and the machine analysis
  //    result (same count):                                    10, 13, 16
  //  - module pass 1:                                          17
  //  - function pass 1 on @f, @g, @h:                          24, 27, 30
  //  - doFinalization of function pass 0, then pass 1:         31, 32

  // Check first machine module pass
  EXPECT_EQ(1u, TestMachineModuleCount[0].size());
  EXPECT_EQ(3, TestMachineModuleCount[0][0]);

  // Check first machine function pass
  EXPECT_EQ(1u, BeforeInitialization[0].size());
  EXPECT_EQ(1, BeforeInitialization[0][0]);
  EXPECT_EQ(3u, TestMachineFunctionCount[0].size());
  EXPECT_EQ(10, TestMachineFunctionCount[0][0]);
  EXPECT_EQ(13, TestMachineFunctionCount[0][1]);
  EXPECT_EQ(16, TestMachineFunctionCount[0][2]);
  EXPECT_EQ(1u, BeforeFinalization[0].size());
  EXPECT_EQ(31, BeforeFinalization[0][0]);

  // Check second machine module pass
  EXPECT_EQ(1u, TestMachineModuleCount[1].size());
  EXPECT_EQ(17, TestMachineModuleCount[1][0]);

  // Check second machine function pass
  EXPECT_EQ(1u, BeforeInitialization[1].size());
  EXPECT_EQ(2, BeforeInitialization[1][0]);
  EXPECT_EQ(3u, TestMachineFunctionCount[1].size());
  EXPECT_EQ(24, TestMachineFunctionCount[1][0]);
  EXPECT_EQ(27, TestMachineFunctionCount[1][1]);
  EXPECT_EQ(30, TestMachineFunctionCount[1][2]);
  EXPECT_EQ(1u, BeforeFinalization[1].size());
  EXPECT_EQ(32, BeforeFinalization[1][0]);

  EXPECT_EQ(32, Count);

  // doInitialization returns an error once the shared counter has been
  // preset to a large value.
  Count = 10000;
  MFPM.addPass(TestMachineFunctionPass(Count, BeforeInitialization[1],
                                       BeforeFinalization[1],
                                       TestMachineFunctionCount[1]));
  std::string Message;
  llvm::handleAllErrors(MFPM.run(*M, MFAM), [&](llvm::StringError &E) {
    Message = E.getMessage();
  });
  EXPECT_EQ(Message, DoInitErrMsg);

  // doFinalization returns an error: with Count preset to 1000, every
  // doInitialization succeeds, and by the time doFinalization runs the
  // counter has grown past the doFinalization failure threshold. Reset
  // Message so this check cannot see the previous check's result.
  Count = 1000;
  Message.clear();
  MFPM.addPass(TestMachineFunctionPass(Count, BeforeInitialization[1],
                                       BeforeFinalization[1],
                                       TestMachineFunctionCount[1]));
  llvm::handleAllErrors(MFPM.run(*M, MFAM), [&](llvm::StringError &E) {
    Message = E.getMessage();
  });
  EXPECT_EQ(Message, DoFinalErrMsg);
}
308
309 } // namespace
310