1 //===- llvm/unittest/CodeGen/PassManager.cpp - PassManager tests ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "llvm/ADT/Triple.h"
#include "llvm/Analysis/CGSCCPassManager.h"
#include "llvm/Analysis/LoopAnalysisManager.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/CodeGen/MachineModuleInfo.h"
#include "llvm/CodeGen/MachinePassManager.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/Passes/PassBuilder.h"
#include "llvm/Support/Host.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/TargetRegistry.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Target/TargetMachine.h"
#include "gtest/gtest.h"
24 
25 using namespace llvm;
26 
27 namespace {
28 
29 class TestFunctionAnalysis : public AnalysisInfoMixin<TestFunctionAnalysis> {
30 public:
31   struct Result {
Result__anon1d198ac80111::TestFunctionAnalysis::Result32     Result(int Count) : InstructionCount(Count) {}
33     int InstructionCount;
34   };
35 
36   /// Run the analysis pass over the function and return a result.
run(Function & F,FunctionAnalysisManager & AM)37   Result run(Function &F, FunctionAnalysisManager &AM) {
38     int Count = 0;
39     for (Function::iterator BBI = F.begin(), BBE = F.end(); BBI != BBE; ++BBI)
40       for (BasicBlock::iterator II = BBI->begin(), IE = BBI->end(); II != IE;
41            ++II)
42         ++Count;
43     return Result(Count);
44   }
45 
46 private:
47   friend AnalysisInfoMixin<TestFunctionAnalysis>;
48   static AnalysisKey Key;
49 };
50 
51 AnalysisKey TestFunctionAnalysis::Key;
52 
53 class TestMachineFunctionAnalysis
54     : public AnalysisInfoMixin<TestMachineFunctionAnalysis> {
55 public:
56   struct Result {
Result__anon1d198ac80111::TestMachineFunctionAnalysis::Result57     Result(int Count) : InstructionCount(Count) {}
58     int InstructionCount;
59   };
60 
61   /// Run the analysis pass over the machine function and return a result.
run(MachineFunction & MF,MachineFunctionAnalysisManager::Base & AM)62   Result run(MachineFunction &MF, MachineFunctionAnalysisManager::Base &AM) {
63     auto &MFAM = static_cast<MachineFunctionAnalysisManager &>(AM);
64     // Query function analysis result.
65     TestFunctionAnalysis::Result &FAR =
66         MFAM.getResult<TestFunctionAnalysis>(MF.getFunction());
67     // + 5
68     return FAR.InstructionCount;
69   }
70 
71 private:
72   friend AnalysisInfoMixin<TestMachineFunctionAnalysis>;
73   static AnalysisKey Key;
74 };
75 
76 AnalysisKey TestMachineFunctionAnalysis::Key;
77 
// Messages carried by the StringErrors that TestMachineFunctionPass returns
// from its doInitialization / doFinalization hooks when forced to fail.
const std::string DoInitErrMsg{"doInitialization failed"};
const std::string DoFinalErrMsg{"doFinalization failed"};
80 
81 struct TestMachineFunctionPass : public PassInfoMixin<TestMachineFunctionPass> {
TestMachineFunctionPass__anon1d198ac80111::TestMachineFunctionPass82   TestMachineFunctionPass(int &Count, std::vector<int> &BeforeInitialization,
83                           std::vector<int> &BeforeFinalization,
84                           std::vector<int> &MachineFunctionPassCount)
85       : Count(Count), BeforeInitialization(BeforeInitialization),
86         BeforeFinalization(BeforeFinalization),
87         MachineFunctionPassCount(MachineFunctionPassCount) {}
88 
doInitialization__anon1d198ac80111::TestMachineFunctionPass89   Error doInitialization(Module &M, MachineFunctionAnalysisManager &MFAM) {
90     // Force doInitialization fail by starting with big `Count`.
91     if (Count > 10000)
92       return make_error<StringError>(DoInitErrMsg, inconvertibleErrorCode());
93 
94     // + 1
95     ++Count;
96     BeforeInitialization.push_back(Count);
97     return Error::success();
98   }
doFinalization__anon1d198ac80111::TestMachineFunctionPass99   Error doFinalization(Module &M, MachineFunctionAnalysisManager &MFAM) {
100     // Force doFinalization fail by starting with big `Count`.
101     if (Count > 1000)
102       return make_error<StringError>(DoFinalErrMsg, inconvertibleErrorCode());
103 
104     // + 1
105     ++Count;
106     BeforeFinalization.push_back(Count);
107     return Error::success();
108   }
109 
run__anon1d198ac80111::TestMachineFunctionPass110   PreservedAnalyses run(MachineFunction &MF,
111                         MachineFunctionAnalysisManager &MFAM) {
112     // Query function analysis result.
113     TestFunctionAnalysis::Result &FAR =
114         MFAM.getResult<TestFunctionAnalysis>(MF.getFunction());
115     // 3 + 1 + 1 = 5
116     Count += FAR.InstructionCount;
117 
118     // Query module analysis result.
119     MachineModuleInfo &MMI =
120         MFAM.getResult<MachineModuleAnalysis>(*MF.getFunction().getParent());
121     // 1 + 1 + 1 = 3
122     Count += (MMI.getModule() == MF.getFunction().getParent());
123 
124     // Query machine function analysis result.
125     TestMachineFunctionAnalysis::Result &MFAR =
126         MFAM.getResult<TestMachineFunctionAnalysis>(MF);
127     // 3 + 1 + 1 = 5
128     Count += MFAR.InstructionCount;
129 
130     MachineFunctionPassCount.push_back(Count);
131 
132     return PreservedAnalyses::none();
133   }
134 
135   int &Count;
136   std::vector<int> &BeforeInitialization;
137   std::vector<int> &BeforeFinalization;
138   std::vector<int> &MachineFunctionPassCount;
139 };
140 
141 struct TestMachineModulePass : public PassInfoMixin<TestMachineModulePass> {
TestMachineModulePass__anon1d198ac80111::TestMachineModulePass142   TestMachineModulePass(int &Count, std::vector<int> &MachineModulePassCount)
143       : Count(Count), MachineModulePassCount(MachineModulePassCount) {}
144 
run__anon1d198ac80111::TestMachineModulePass145   Error run(Module &M, MachineFunctionAnalysisManager &MFAM) {
146     MachineModuleInfo &MMI = MFAM.getResult<MachineModuleAnalysis>(M);
147     // + 1
148     Count += (MMI.getModule() == &M);
149     MachineModulePassCount.push_back(Count);
150     return Error::success();
151   }
152 
run__anon1d198ac80111::TestMachineModulePass153   PreservedAnalyses run(MachineFunction &MF,
154                         MachineFunctionAnalysisManager &AM) {
155     llvm_unreachable(
156         "This should never be reached because this is machine module pass");
157   }
158 
159   int &Count;
160   std::vector<int> &MachineModulePassCount;
161 };
162 
parseIR(LLVMContext & Context,const char * IR)163 std::unique_ptr<Module> parseIR(LLVMContext &Context, const char *IR) {
164   SMDiagnostic Err;
165   return parseAssemblyString(IR, Err, Context);
166 }
167 
168 class PassManagerTest : public ::testing::Test {
169 protected:
170   LLVMContext Context;
171   std::unique_ptr<Module> M;
172   std::unique_ptr<TargetMachine> TM;
173 
174 public:
PassManagerTest()175   PassManagerTest()
176       : M(parseIR(Context, "define void @f() {\n"
177                            "entry:\n"
178                            "  call void @g()\n"
179                            "  call void @h()\n"
180                            "  ret void\n"
181                            "}\n"
182                            "define void @g() {\n"
183                            "  ret void\n"
184                            "}\n"
185                            "define void @h() {\n"
186                            "  ret void\n"
187                            "}\n")) {
188     // MachineModuleAnalysis needs a TargetMachine instance.
189     llvm::InitializeAllTargets();
190 
191     std::string TripleName = Triple::normalize(sys::getDefaultTargetTriple());
192     std::string Error;
193     const Target *TheTarget =
194         TargetRegistry::lookupTarget(TripleName, Error);
195     if (!TheTarget)
196       return;
197 
198     TargetOptions Options;
199     TM.reset(TheTarget->createTargetMachine(TripleName, "", "",
200                                             Options, None));
201   }
202 };
203 
TEST_F(PassManagerTest,Basic)204 TEST_F(PassManagerTest, Basic) {
205   if (!TM)
206     GTEST_SKIP();
207 
208   LLVMTargetMachine *LLVMTM = static_cast<LLVMTargetMachine *>(TM.get());
209   M->setDataLayout(TM->createDataLayout());
210 
211   LoopAnalysisManager LAM;
212   FunctionAnalysisManager FAM;
213   CGSCCAnalysisManager CGAM;
214   ModuleAnalysisManager MAM;
215   PassBuilder PB(TM.get());
216   PB.registerModuleAnalyses(MAM);
217   PB.registerFunctionAnalyses(FAM);
218   PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);
219 
220   FAM.registerPass([&] { return TestFunctionAnalysis(); });
221   FAM.registerPass([&] { return PassInstrumentationAnalysis(); });
222   MAM.registerPass([&] { return MachineModuleAnalysis(LLVMTM); });
223   MAM.registerPass([&] { return PassInstrumentationAnalysis(); });
224 
225   MachineFunctionAnalysisManager MFAM;
226   {
227     // Test move assignment.
228     MachineFunctionAnalysisManager NestedMFAM(FAM, MAM);
229     NestedMFAM.registerPass([&] { return PassInstrumentationAnalysis(); });
230     NestedMFAM.registerPass([&] { return TestMachineFunctionAnalysis(); });
231     MFAM = std::move(NestedMFAM);
232   }
233 
234   int Count = 0;
235   std::vector<int> BeforeInitialization[2];
236   std::vector<int> BeforeFinalization[2];
237   std::vector<int> TestMachineFunctionCount[2];
238   std::vector<int> TestMachineModuleCount[2];
239 
240   MachineFunctionPassManager MFPM;
241   {
242     // Test move assignment.
243     MachineFunctionPassManager NestedMFPM;
244     NestedMFPM.addPass(TestMachineModulePass(Count, TestMachineModuleCount[0]));
245     NestedMFPM.addPass(TestMachineFunctionPass(Count, BeforeInitialization[0],
246                                                BeforeFinalization[0],
247                                                TestMachineFunctionCount[0]));
248     NestedMFPM.addPass(TestMachineModulePass(Count, TestMachineModuleCount[1]));
249     NestedMFPM.addPass(TestMachineFunctionPass(Count, BeforeInitialization[1],
250                                                BeforeFinalization[1],
251                                                TestMachineFunctionCount[1]));
252     MFPM = std::move(NestedMFPM);
253   }
254 
255   ASSERT_FALSE(errorToBool(MFPM.run(*M, MFAM)));
256 
257   // Check first machine module pass
258   EXPECT_EQ(1u, TestMachineModuleCount[0].size());
259   EXPECT_EQ(3, TestMachineModuleCount[0][0]);
260 
261   // Check first machine function pass
262   EXPECT_EQ(1u, BeforeInitialization[0].size());
263   EXPECT_EQ(1, BeforeInitialization[0][0]);
264   EXPECT_EQ(3u, TestMachineFunctionCount[0].size());
265   EXPECT_EQ(10, TestMachineFunctionCount[0][0]);
266   EXPECT_EQ(13, TestMachineFunctionCount[0][1]);
267   EXPECT_EQ(16, TestMachineFunctionCount[0][2]);
268   EXPECT_EQ(1u, BeforeFinalization[0].size());
269   EXPECT_EQ(31, BeforeFinalization[0][0]);
270 
271   // Check second machine module pass
272   EXPECT_EQ(1u, TestMachineModuleCount[1].size());
273   EXPECT_EQ(17, TestMachineModuleCount[1][0]);
274 
275   // Check second machine function pass
276   EXPECT_EQ(1u, BeforeInitialization[1].size());
277   EXPECT_EQ(2, BeforeInitialization[1][0]);
278   EXPECT_EQ(3u, TestMachineFunctionCount[1].size());
279   EXPECT_EQ(24, TestMachineFunctionCount[1][0]);
280   EXPECT_EQ(27, TestMachineFunctionCount[1][1]);
281   EXPECT_EQ(30, TestMachineFunctionCount[1][2]);
282   EXPECT_EQ(1u, BeforeFinalization[1].size());
283   EXPECT_EQ(32, BeforeFinalization[1][0]);
284 
285   EXPECT_EQ(32, Count);
286 
287   // doInitialization returns error
288   Count = 10000;
289   MFPM.addPass(TestMachineFunctionPass(Count, BeforeInitialization[1],
290                                        BeforeFinalization[1],
291                                        TestMachineFunctionCount[1]));
292   std::string Message;
293   llvm::handleAllErrors(MFPM.run(*M, MFAM), [&](llvm::StringError &Error) {
294     Message = Error.getMessage();
295   });
296   EXPECT_EQ(Message, DoInitErrMsg);
297 
298   // doFinalization returns error
299   Count = 1000;
300   MFPM.addPass(TestMachineFunctionPass(Count, BeforeInitialization[1],
301                                        BeforeFinalization[1],
302                                        TestMachineFunctionCount[1]));
303   llvm::handleAllErrors(MFPM.run(*M, MFAM), [&](llvm::StringError &Error) {
304     Message = Error.getMessage();
305   });
306   EXPECT_EQ(Message, DoFinalErrMsg);
307 }
308 
309 } // namespace
310