1 //===- PassManagerTest.cpp - PassManager unit tests -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Pass/PassManager.h"
10 #include "mlir/IR/Builders.h"
11 #include "mlir/IR/Function.h"
12 #include "mlir/Pass/Pass.h"
13 #include "gtest/gtest.h"
14
15 using namespace mlir;
16 using namespace mlir::detail;
17
18 namespace {
19 /// Analysis that operates on any operation.
20 struct GenericAnalysis {
GenericAnalysis__anon097f9e560111::GenericAnalysis21 GenericAnalysis(Operation *op) : isFunc(isa<FuncOp>(op)) {}
22 const bool isFunc;
23 };
24
25 /// Analysis that operates on a specific operation.
26 struct OpSpecificAnalysis {
OpSpecificAnalysis__anon097f9e560111::OpSpecificAnalysis27 OpSpecificAnalysis(FuncOp op) : isSecret(op.getName() == "secret") {}
28 const bool isSecret;
29 };
30
31 /// Simple pass to annotate a FuncOp with the results of analysis.
32 /// Note: not using FunctionPass as it skip external functions.
33 struct AnnotateFunctionPass
34 : public PassWrapper<AnnotateFunctionPass, OperationPass<FuncOp>> {
runOnOperation__anon097f9e560111::AnnotateFunctionPass35 void runOnOperation() override {
36 FuncOp op = getOperation();
37 Builder builder(op.getParentOfType<ModuleOp>());
38
39 auto &ga = getAnalysis<GenericAnalysis>();
40 auto &sa = getAnalysis<OpSpecificAnalysis>();
41
42 op.setAttr("isFunc", builder.getBoolAttr(ga.isFunc));
43 op.setAttr("isSecret", builder.getBoolAttr(sa.isSecret));
44 }
45 };
46
/// Runs AnnotateFunctionPass over a module containing a "secret" and a
/// "not_secret" function and checks that both the generic analysis and the
/// FuncOp-specific analysis were computed for each function.
TEST(PassManagerTest, OpSpecificAnalysis) {
  MLIRContext context;
  Builder builder(&context);

  // Create a module with 2 functions.
  OwningModuleRef module(ModuleOp::create(UnknownLoc::get(&context)));
  for (StringRef name : {"secret", "not_secret"}) {
    FuncOp func =
        FuncOp::create(builder.getUnknownLoc(), name,
                       builder.getFunctionType(llvm::None, llvm::None));
    module->push_back(func);
  }

  // Instantiate and run our pass. If the pipeline failed, the annotations
  // checked below are not expected to exist, so stop here.
  PassManager pm(&context);
  pm.addNestedPass<FuncOp>(std::make_unique<AnnotateFunctionPass>());
  LogicalResult result = pm.run(module.get());
  ASSERT_TRUE(succeeded(result));

  // Verify that each function got annotated with expected attributes.
  // dyn_cast_or_null turns a missing attribute into a clean test failure
  // rather than calling isa<> on a null attribute.
  for (FuncOp func : module->getOps<FuncOp>()) {
    auto isFuncAttr = func.getAttr("isFunc").dyn_cast_or_null<BoolAttr>();
    ASSERT_TRUE(isFuncAttr);
    EXPECT_TRUE(isFuncAttr.getValue());

    bool isSecret = func.getName() == "secret";
    auto isSecretAttr = func.getAttr("isSecret").dyn_cast_or_null<BoolAttr>();
    ASSERT_TRUE(isSecretAttr);
    EXPECT_EQ(isSecretAttr.getValue(), isSecret);
  }
}
76
77 namespace {
78 struct InvalidPass : Pass {
InvalidPass__anon097f9e560111::__anon097f9e560211::InvalidPass79 InvalidPass() : Pass(TypeID::get<InvalidPass>(), StringRef("invalid_op")) {}
getName__anon097f9e560111::__anon097f9e560211::InvalidPass80 StringRef getName() const override { return "Invalid Pass"; }
runOnOperation__anon097f9e560111::__anon097f9e560211::InvalidPass81 void runOnOperation() override {}
82
83 /// A clone method to create a copy of this pass.
clonePass__anon097f9e560111::__anon097f9e560211::InvalidPass84 std::unique_ptr<Pass> clonePass() const override {
85 return std::make_unique<InvalidPass>(
86 *static_cast<const InvalidPass *>(this));
87 }
88 };
89 } // anonymous namespace
90
/// Checks that:
///  - running a pass nested on an unregistered operation fails and reports a
///    diagnostic on that operation, and
///  - adding the same pass directly to the top-level pass manager is fatal.
TEST(PassManagerTest, InvalidPass) {
  MLIRContext context;
  Location loc = UnknownLoc::get(&context);

  // Build a module whose body holds one unregistered "invalid_op" operation.
  OwningModuleRef module(ModuleOp::create(loc));
  OpBuilder builder(&module->getBodyRegion());
  OperationState state(loc, "invalid_op");
  builder.insert(Operation::create(state));

  // Capture the most recently emitted diagnostic so it can be inspected once
  // the pipeline has run.
  std::unique_ptr<Diagnostic> diagnostic;
  context.getDiagEngine().registerHandler([&](Diagnostic &diag) {
    diagnostic = std::make_unique<Diagnostic>(std::move(diag));
  });

  // Schedule the pass on the unregistered op and expect the run to fail.
  PassManager pm(&context);
  pm.nest("invalid_op").addPass(std::make_unique<InvalidPass>());
  LogicalResult result = pm.run(module.get());
  EXPECT_TRUE(failed(result));
  ASSERT_TRUE(diagnostic);
  EXPECT_EQ(
      diagnostic->str(),
      "'invalid_op' op trying to schedule a pass on an unregistered operation");

  // Adding the pass without nesting must abort the process.
  ASSERT_DEATH(pm.addPass(std::make_unique<InvalidPass>()), "");
}
122
123 } // end namespace
124