1 //===- PassManagerTest.cpp - PassManager unit tests -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Pass/PassManager.h"
10 #include "mlir/IR/Builders.h"
11 #include "mlir/IR/BuiltinOps.h"
12 #include "mlir/Pass/Pass.h"
13 #include "gtest/gtest.h"
14
15 using namespace mlir;
16 using namespace mlir::detail;
17
18 namespace {
19 /// Analysis that operates on any operation.
20 struct GenericAnalysis {
GenericAnalysis__anon6b73f1430111::GenericAnalysis21 GenericAnalysis(Operation *op) : isFunc(isa<FuncOp>(op)) {}
22 const bool isFunc;
23 };
24
25 /// Analysis that operates on a specific operation.
26 struct OpSpecificAnalysis {
OpSpecificAnalysis__anon6b73f1430111::OpSpecificAnalysis27 OpSpecificAnalysis(FuncOp op) : isSecret(op.getName() == "secret") {}
28 const bool isSecret;
29 };
30
31 /// Simple pass to annotate a FuncOp with the results of analysis.
32 /// Note: not using FunctionPass as it skip external functions.
33 struct AnnotateFunctionPass
34 : public PassWrapper<AnnotateFunctionPass, OperationPass<FuncOp>> {
runOnOperation__anon6b73f1430111::AnnotateFunctionPass35 void runOnOperation() override {
36 FuncOp op = getOperation();
37 Builder builder(op->getParentOfType<ModuleOp>());
38
39 auto &ga = getAnalysis<GenericAnalysis>();
40 auto &sa = getAnalysis<OpSpecificAnalysis>();
41
42 op->setAttr("isFunc", builder.getBoolAttr(ga.isFunc));
43 op->setAttr("isSecret", builder.getBoolAttr(sa.isSecret));
44 }
45 };
46
/// Runs AnnotateFunctionPass over a module containing two private function
/// declarations and checks the attributes it attaches to each function.
TEST(PassManagerTest, OpSpecificAnalysis) {
  MLIRContext context;
  Builder builder(&context);

  // Create a module with 2 private functions (declarations, no bodies).
  OwningModuleRef module(ModuleOp::create(UnknownLoc::get(&context)));
  for (StringRef name : {"secret", "not_secret"}) {
    FuncOp func =
        FuncOp::create(builder.getUnknownLoc(), name,
                       builder.getFunctionType(llvm::None, llvm::None));
    func.setPrivate();
    module->push_back(func);
  }

  // Instantiate and run our pass. The attribute checks below are meaningless
  // if the pipeline failed, so treat success as a precondition.
  PassManager pm(&context);
  pm.addNestedPass<FuncOp>(std::make_unique<AnnotateFunctionPass>());
  LogicalResult result = pm.run(module.get());
  ASSERT_TRUE(succeeded(result));

  // Verify that each function got annotated with expected attributes.
  // getAttrOfType yields a null BoolAttr when the attribute is missing or has
  // a different type, so a missing annotation fails the test cleanly instead
  // of tripping the null-attribute assertion in Attribute::isa.
  unsigned numChecked = 0;
  for (FuncOp func : module->getOps<FuncOp>()) {
    BoolAttr isFuncAttr = func->getAttrOfType<BoolAttr>("isFunc");
    ASSERT_TRUE(isFuncAttr);
    EXPECT_TRUE(isFuncAttr.getValue());

    bool isSecret = func.getName() == "secret";
    BoolAttr isSecretAttr = func->getAttrOfType<BoolAttr>("isSecret");
    ASSERT_TRUE(isSecretAttr);
    EXPECT_EQ(isSecretAttr.getValue(), isSecret);
    ++numChecked;
  }
  // Guard against the loop above passing vacuously.
  EXPECT_EQ(numChecked, 2u);
}
77
78 namespace {
79 struct InvalidPass : Pass {
InvalidPass__anon6b73f1430111::__anon6b73f1430211::InvalidPass80 InvalidPass() : Pass(TypeID::get<InvalidPass>(), StringRef("invalid_op")) {}
getName__anon6b73f1430111::__anon6b73f1430211::InvalidPass81 StringRef getName() const override { return "Invalid Pass"; }
runOnOperation__anon6b73f1430111::__anon6b73f1430211::InvalidPass82 void runOnOperation() override {}
83
84 /// A clone method to create a copy of this pass.
clonePass__anon6b73f1430111::__anon6b73f1430211::InvalidPass85 std::unique_ptr<Pass> clonePass() const override {
86 return std::make_unique<InvalidPass>(
87 *static_cast<const InvalidPass *>(this));
88 }
89 };
90 } // anonymous namespace
91
/// Checks two things about InvalidPass:
/// - scheduling it on an unregistered operation fails with a diagnostic;
/// - adding it directly to the top-level pass manager is a fatal error.
TEST(PassManagerTest, InvalidPass) {
  MLIRContext context;
  context.allowUnregisteredDialects();

  // Create a module
  OwningModuleRef module(ModuleOp::create(UnknownLoc::get(&context)));

  // Add a single "invalid_op" operation
  OpBuilder builder(&module->getBodyRegion());
  OperationState state(UnknownLoc::get(&context), "invalid_op");
  builder.insert(Operation::create(state));

  // Register a diagnostic handler to capture the diagnostic so that we can
  // check it later. Each new diagnostic replaces the previously captured one.
  std::unique_ptr<Diagnostic> diagnostic;
  context.getDiagEngine().registerHandler([&](Diagnostic &diag) {
    diagnostic = std::make_unique<Diagnostic>(std::move(diag));
  });

  // Instantiate and run our pass.
  PassManager pm(&context);
  pm.nest("invalid_op").addPass(std::make_unique<InvalidPass>());
  LogicalResult result = pm.run(module.get());
  EXPECT_TRUE(failed(result));
  ASSERT_TRUE(diagnostic != nullptr);
  EXPECT_EQ(
      diagnostic->str(),
      "'invalid_op' op trying to schedule a pass on an unregistered operation");

  // Check that adding the pass at the top-level triggers a fatal error.
  ASSERT_DEATH(pm.addPass(std::make_unique<InvalidPass>()), "");
}
124
125 } // end namespace
126