1 //===- PassManagerTest.cpp - PassManager unit tests -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Pass/PassManager.h"
10 #include "mlir/IR/Builders.h"
11 #include "mlir/IR/BuiltinOps.h"
12 #include "mlir/Pass/Pass.h"
13 #include "gtest/gtest.h"
14
15 using namespace mlir;
16 using namespace mlir::detail;
17
18 namespace {
19 /// Analysis that operates on any operation.
20 struct GenericAnalysis {
GenericAnalysis__anon8fac1c090111::GenericAnalysis21 GenericAnalysis(Operation *op) : isFunc(isa<FuncOp>(op)) {}
22 const bool isFunc;
23 };
24
25 /// Analysis that operates on a specific operation.
26 struct OpSpecificAnalysis {
OpSpecificAnalysis__anon8fac1c090111::OpSpecificAnalysis27 OpSpecificAnalysis(FuncOp op) : isSecret(op.getName() == "secret") {}
28 const bool isSecret;
29 };
30
31 /// Simple pass to annotate a FuncOp with the results of analysis.
32 /// Note: not using FunctionPass as it skip external functions.
33 struct AnnotateFunctionPass
34 : public PassWrapper<AnnotateFunctionPass, OperationPass<FuncOp>> {
runOnOperation__anon8fac1c090111::AnnotateFunctionPass35 void runOnOperation() override {
36 FuncOp op = getOperation();
37 Builder builder(op->getParentOfType<ModuleOp>());
38
39 auto &ga = getAnalysis<GenericAnalysis>();
40 auto &sa = getAnalysis<OpSpecificAnalysis>();
41
42 op->setAttr("isFunc", builder.getBoolAttr(ga.isFunc));
43 op->setAttr("isSecret", builder.getBoolAttr(sa.isSecret));
44 }
45 };
46
/// Runs AnnotateFunctionPass, nested on FuncOp, over a module containing the
/// functions "secret" and "not_secret", then checks that every function carries
/// the "isFunc" and "isSecret" boolean attributes with the expected values.
TEST(PassManagerTest, OpSpecificAnalysis) {
  MLIRContext context;
  Builder builder(&context);

  // Create a module with 2 functions.
  OwningModuleRef module(ModuleOp::create(UnknownLoc::get(&context)));
  for (StringRef name : {"secret", "not_secret"}) {
    FuncOp func =
        FuncOp::create(builder.getUnknownLoc(), name,
                       builder.getFunctionType(llvm::None, llvm::None));
    func.setPrivate();
    module->push_back(func);
  }

  // Instantiate and run our pass.
  PassManager pm(&context);
  pm.addNestedPass<FuncOp>(std::make_unique<AnnotateFunctionPass>());
  LogicalResult result = pm.run(module.get());
  EXPECT_TRUE(succeeded(result));

  // Verify that each function got annotated with expected attributes.
  // The attributes are fetched with getAttrOfType so that an absent (or
  // differently typed) attribute yields a null BoolAttr and is reported as a
  // regular test failure, rather than invoking isa<> on a null attribute.
  for (FuncOp func : module->getOps<FuncOp>()) {
    BoolAttr isFuncAttr = func->getAttrOfType<BoolAttr>("isFunc");
    ASSERT_TRUE(static_cast<bool>(isFuncAttr));
    EXPECT_TRUE(isFuncAttr.getValue());

    bool isSecret = func.getName() == "secret";
    BoolAttr isSecretAttr = func->getAttrOfType<BoolAttr>("isSecret");
    ASSERT_TRUE(static_cast<bool>(isSecretAttr));
    EXPECT_EQ(isSecretAttr.getValue(), isSecret);
  }
}
77
78 namespace {
79 struct InvalidPass : Pass {
InvalidPass__anon8fac1c090111::__anon8fac1c090211::InvalidPass80 InvalidPass() : Pass(TypeID::get<InvalidPass>(), StringRef("invalid_op")) {}
getName__anon8fac1c090111::__anon8fac1c090211::InvalidPass81 StringRef getName() const override { return "Invalid Pass"; }
runOnOperation__anon8fac1c090111::__anon8fac1c090211::InvalidPass82 void runOnOperation() override {}
83
84 /// A clone method to create a copy of this pass.
clonePass__anon8fac1c090111::__anon8fac1c090211::InvalidPass85 std::unique_ptr<Pass> clonePass() const override {
86 return std::make_unique<InvalidPass>(
87 *static_cast<const InvalidPass *>(this));
88 }
89 };
90 } // anonymous namespace
91
/// Schedules InvalidPass on an "invalid_op" operation nested in a module and
/// expects the run to fail with a specific diagnostic; then expects that adding
/// the same pass directly to the top-level pass manager dies.
TEST(PassManagerTest, InvalidPass) {
  MLIRContext context;
  auto unknownLoc = UnknownLoc::get(&context);

  // Build a module whose body holds a single "invalid_op" operation.
  OwningModuleRef module(ModuleOp::create(unknownLoc));
  OpBuilder builder(&module->getBodyRegion());
  OperationState state(unknownLoc, "invalid_op");
  builder.insert(Operation::create(state));

  // Capture the diagnostic handed to the handler so its message can be
  // inspected once the pipeline has run.
  std::unique_ptr<Diagnostic> diagnostic;
  context.getDiagEngine().registerHandler([&diagnostic](Diagnostic &diag) {
    diagnostic = std::make_unique<Diagnostic>(std::move(diag));
  });

  // Nest the pass under "invalid_op" and run the pipeline on the module.
  PassManager pm(&context);
  OpPassManager &nestedPm = pm.nest("invalid_op");
  nestedPm.addPass(std::make_unique<InvalidPass>());
  LogicalResult result = pm.run(module.get());

  // The run must fail and must have produced the expected diagnostic.
  EXPECT_TRUE(failed(result));
  ASSERT_TRUE(diagnostic != nullptr);
  EXPECT_EQ(
      diagnostic->str(),
      "'invalid_op' op trying to schedule a pass on an unregistered operation");

  // Check that adding the pass at the top-level triggers a fatal error.
  ASSERT_DEATH(pm.addPass(std::make_unique<InvalidPass>()), "");
}
123
124 } // end namespace
125